├── .gitignore ├── LICENSE ├── README.md ├── acdc ├── TLACDCCorrespondence.py ├── TLACDCEdge.py ├── TLACDCExperiment.py ├── TLACDCInterpNode.py ├── __init__.py ├── acdc_graphics.py ├── acdc_utils.py ├── docstring │ ├── __init__.py │ ├── prompts.py │ └── utils.py ├── global_cache.py ├── greaterthan │ ├── __init__.py │ └── utils.py ├── induction │ ├── __init__.py │ └── utils.py ├── ioi │ ├── __init__.py │ ├── ioi_dataset.py │ └── utils.py ├── knowledge │ ├── __init__.py │ ├── knowledge_dataset.py │ └── utils.py ├── logic_gates │ ├── __init__.py │ └── utils.py ├── main.py ├── run.sh └── tracr_task │ ├── __init__.py │ └── utils.py ├── computed_circuits └── gpt2-medium │ ├── Factual Knowledge │ ├── graph.gv │ ├── graph.png │ └── io.txt │ ├── Hallucination │ ├── attention_analysis.png │ ├── graph.gv │ ├── io.txt │ └── specialNode.txt │ └── In-Context Learning │ ├── attention_analysis.png │ ├── graph.gv │ ├── io.txt │ └── specialNode.txt ├── data ├── bias │ ├── characteristic_gender.json │ ├── degree_gender.json │ ├── name_birthplace.json │ ├── name_gender.json │ ├── name_religion.json │ ├── occupation_age.json │ └── occupation_gender.json ├── commonsense │ ├── fruit_inside_color.json │ ├── fruit_outside_color.json │ ├── object_superclass.json │ ├── substance_phase.json │ ├── task_done_by_person.json │ ├── task_done_by_tool.json │ ├── word_sentiment.json │ └── work_location.json ├── factual │ ├── city_in_country.json │ ├── company_ceo.json │ ├── company_hq.json │ ├── country_capital_city.json │ ├── country_currency.json │ ├── country_language.json │ ├── country_largest_city.json │ ├── food_from_country.json │ ├── landmark_in_country.json │ ├── landmark_on_continent.json │ ├── person_band_lead_singer.json │ ├── person_father.json │ ├── person_mother.json │ ├── person_native_language.json │ ├── person_occupation.json │ ├── person_plays_instrument.json │ ├── person_plays_position_in_sport.json │ ├── person_plays_pro_sport.json │ ├── person_university.json │ ├── pokemon_evolutions.json │ ├── presidents_birth_year.json │ ├── presidents_election_year.json │ ├── product_by_company.json │ ├── star_constellation.json │ ├── superhero_archnemesis.json │ └── superhero_person.json └── linguistic │ ├── adj_antonym.json │ ├── adj_comparative.json │ ├── adj_superlative.json │ ├── verb_past_tense.json │ ├── word_first_letter.json │ └── word_last_letter.json ├── eap ├── attribute.py ├── dataset.py ├── evaluate.py ├── graph.py ├── metrics.py ├── utils.py └── visualization.py ├── knowledge_eap.ipynb ├── notebook └── component.ipynb ├── requirements.txt └── transformer_lens ├── ActivationCache.py ├── FactoredMatrix.py ├── HookedEncoder.py ├── HookedTransformer.py ├── HookedTransformerConfig.py ├── SVDInterpreter.py ├── __init__.py ├── components.py ├── evals.py ├── head_detector.py ├── hook_points.py ├── loading_from_pretrained.py ├── past_key_value_caching.py ├── patching.py ├── train.py ├── utilities ├── __init__.py └── devices.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | hugging_cache 3 | data 4 | .vscode 5 | logs 6 | 7 | __pycache__/ 8 | *.pyc 9 | *.pyo 10 | *.swp 11 | *~ 12 | .DS_Store 13 | .env 14 | */__pycache__/ 15 | *ims 16 | */*ims*/ 17 | *.out 18 | *.pt 19 | *pkl 20 | */*results*/ 21 | *results 22 | *output 23 | *factual_results -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | 
Copyright (c) 2024 ZJUNLP 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# Knowledge Circuits
2 | **Knowledge Circuits in Pretrained Transformers**
3 | 
4 | 
5 | 📄arXiv •
6 | 🌐Demo •
7 | Youtube •
8 | 𝕏 Blog
9 | 

10 | 
11 | [![Awesome](https://awesome.re/badge.svg)](https://github.com/zjunlp/KnowledgeCircuits)
12 | [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)
13 | ![](https://img.shields.io/github/last-commit/zjunlp/KnowledgeCircuits?color=green)
14 | 
15 | ## 🔔News
16 | 
17 | - [2025-02-16] We release our new paper [How Do LLMs Acquire New Knowledge? A Knowledge Circuits Perspective on Continual Pre-Training](https://arxiv.org/abs/2502.11196), analyzing the evolution of knowledge circuits throughout continual pre-training. Check it out :)
18 | - [2024-09-26] Our paper [Knowledge Circuits in Pretrained Transformers](https://arxiv.org/abs/2405.17969) was accepted by NeurIPS 2024!
19 | - [2024-05-28] We release our paper [Knowledge Circuits in Pretrained Transformers](https://arxiv.org/abs/2405.17969).
20 | 
21 | 
22 | ## Table of Contents
23 | - 🌟[Overview](#overview)
24 | - 🔧[Installation](#installation)
25 | - 📚[Get the circuit](#get-the-circuit)
26 | - 🧐[Analyze Component](#analyze-component)
27 | - 🌻[Acknowledgement](#acknowledgement)
28 | - 🚩[Citation](#citation)
29 | 
30 | ---
31 | 
32 | 
33 | ## 🌟Overview
34 | 
35 | This work builds the circuits in pretrained language models that are responsible for specific pieces of knowledge and analyzes the behavior of these components.
36 | We provide a [demo](http://knowledgecircuits.zjukg.cn/) for browsing the discovered circuits.
37 | * The newer [EAP-IG](https://arxiv.org/abs/2403.17806) method is integrated in the `eap` folder. It takes less time than ACDC and can be used via `knowledge_eap.ipynb`. For the LLaMA2-7B-Chat model, running this notebook on a single GPU requires approximately 57,116 MB of GPU memory and 3-4 minutes.
38 | 
39 | ## 🔧Installation
40 | 
41 | The filtered data for each kind of model is available [here](https://pan.zju.edu.cn/share/7c613d16095c504605f83eba72). Please download it and put it in the `data` folder.
42 | 
43 | Build the environment:
44 | ```
45 | conda create -n knowledgecircuit python=3.10
46 | pip install -r requirements.txt
47 | ```
48 | ❗️The code may fail under torch 2.x.x. We recommend torch 1.x.x.
49 | 
50 | ## 📚Get the circuit
51 | 
52 | Just run the following command:
53 | ```
54 | cd acdc
55 | sh run.sh
56 | ```
57 | Here is an example that computes the circuit for the `country_capital_city` relation in `GPT2-Medium`.
58 | ```
59 | MODEL_PATH=/path/to/the/model
60 | KT=factual
61 | KNOWLEDGE=country_capital_city
62 | NUM_EXAMPLES=20
63 | MODEL_NAME=gpt2-medium
64 | 
65 | python main.py --task=knowledge \
66 | --zero-ablation \
67 | --threshold=0.01 \
68 | --device=cuda:0 \
69 | --metric=match_nll \
70 | --indices-mode=reverse \
71 | --first-cache-cpu=False \
72 | --second-cache-cpu=False \
73 | --max-num-epochs=10000 \
74 | --specific-knowledge=$KNOWLEDGE \
75 | --num-examples=$NUM_EXAMPLES \
76 | --relation-reverse=False \
77 | --knowledge-type=$KT \
78 | --model-name=$MODEL_NAME \
79 | --model-path=$MODEL_PATH
80 | ```
81 | 
82 | You will find the results in `acdc/factual_results/gpt2-medium`; `final_graph.pdf` is the computed circuit.
83 | 
84 | ## 🧐Analyze component
85 | 
86 | Run `component.ipynb` in the `notebook` folder (an illustrative sketch of this kind of component inspection appears just before the Citation section below).
87 | 
88 | ## 🌻Acknowledgement
89 | 
90 | We thank the projects [transformer_lens](https://github.com/TransformerLensOrg/TransformerLens), [ACDC](https://github.com/ArthurConmy/Automatic-Circuit-Discovery) and [LRE](https://lre.baulab.info/).
91 | The code in this work is built on top of these three projects' code.
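Below is a minimal, illustrative sketch of the kind of component inspection referenced in the Analyze component section above. It is not part of the repository's code: the model name, prompt, and choice of head are assumptions for illustration (the prompt is taken from `computed_circuits/gpt2-medium/Factual Knowledge/io.txt`, and the head index is an arbitrary example in the spirit of the `specialNode.txt` files); only standard TransformerLens APIs are used.

```python
# Hedged sketch: load GPT-2 medium with TransformerLens, run a knowledge prompt,
# and inspect where one attention head looks from the final token position.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2-medium")  # assumes the weights are available locally or via HF

prompt = "The capital city of China is"      # from Factual Knowledge/io.txt (recorded answer: "Beijing")
tokens = model.to_tokens(prompt)             # shape [1, seq_len], BOS prepended
logits, cache = model.run_with_cache(tokens)

# Top next-token prediction at the last position.
top_token = logits[0, -1].argmax().item()
print("prediction:", repr(model.to_string(top_token)))

# Attention pattern of one head (layer 15, head 0 here -- an arbitrary example choice).
layer, head = 15, 0
pattern = cache["pattern", layer][0, head]   # [query_pos, key_pos]
for tok, weight in zip(model.to_str_tokens(prompt), pattern[-1].tolist()):
    print(f"{tok!r:>12}  {weight:.3f}")      # attention from the final position to each token
```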
92 | 93 | 94 | ## 🚩Citation 95 | 96 | Please cite our repository if you use Knowledge Circuit in your work. Thanks! 97 | 98 | ```bibtex 99 | @article{DBLP:journals/corr/abs-2405-17969, 100 | author = {Yunzhi Yao and 101 | Ningyu Zhang and 102 | Zekun Xi and 103 | Mengru Wang and 104 | Ziwen Xu and 105 | Shumin Deng and 106 | Huajun Chen}, 107 | title = {Knowledge Circuits in Pretrained Transformers}, 108 | journal = {CoRR}, 109 | volume = {abs/2405.17969}, 110 | year = {2024}, 111 | url = {https://doi.org/10.48550/arXiv.2405.17969}, 112 | doi = {10.48550/ARXIV.2405.17969}, 113 | eprinttype = {arXiv}, 114 | eprint = {2405.17969}, 115 | timestamp = {Fri, 21 Jun 2024 22:39:09 +0200}, 116 | biburl = {https://dblp.org/rec/journals/corr/abs-2405-17969.bib}, 117 | bibsource = {dblp computer science bibliography, https://dblp.org} 118 | } 119 | ``` 120 | -------------------------------------------------------------------------------- /acdc/TLACDCEdge.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from collections import defaultdict 3 | from enum import Enum 4 | from typing import Optional, List 5 | 6 | 7 | class EdgeType(Enum): 8 | """ 9 | Property of edges in the computational graph - either 10 | 11 | ADDITION: the child (hook_name, index) is a sum of the parent (hook_name, index)s 12 | DIRECT_COMPUTATION The *single* child is a function of and only of the parent (e.g the value hooked by hook_q is a function of what hook_q_input saves). 13 | PLACEHOLDER generally like 2. but where there are generally multiple parents. Here in ACDC we just include these edges by default when we find them. Explained below? 14 | 15 | Q: Why do we do this? 16 | 17 | There are two answers to this question: A1 is an interactive notebook, see this Colab notebook, which is in this repo at notebooks/implementation_demo.py. A2 is an answer that is written here below, but probably not as clear as A1 (though shorter). 18 | 19 | A2: We need something inside TransformerLens to represent the edges of a computational graph. 20 | The object we choose is pairs (hook_name, index). For example the output of Layer 11 Heads is a hook (blocks.11.attn.hook_result) and to sepcify the 3rd head we add the index [:, :, 3]. Then we can build a computational graph on these! 
21 | 22 | However, when we do ACDC there turn out to be two conflicting things "removing edges" wants to do: 23 | i) for things in the residual stream, we want to remove the sum of the effects from previous hooks 24 | ii) for things that are not linear we want to *recompute* e.g the result inside the hook 25 | blocks.11.attn.hook_result from a corrupted Q and normal K and V 26 | 27 | The easiest way I thought of of reconciling these different cases, while also having a connected computational graph, is to have three types of edges: addition for the residual case, direct computation for easy cases where we can just replace hook_q with a cached value when we e.g cut it off from hook_q_input, and placeholder to make the graph connected (when hook_result is connected to hook_q and hook_k and hook_v)""" 28 | 29 | ADDITION = 0 30 | DIRECT_COMPUTATION = 1 31 | PLACEHOLDER = 2 32 | 33 | def __eq__(self, other): 34 | """Necessary because of extremely frustrating error that arises with load_ext autoreload (because this uses importlib under the hood: https://stackoverflow.com/questions/66458864/enum-comparison-become-false-after-reloading-module)""" 35 | 36 | assert isinstance(other, EdgeType) 37 | return self.value == other.value 38 | 39 | 40 | class Edge: 41 | def __init__( 42 | self, 43 | edge_type: EdgeType, 44 | present: bool = True, 45 | effect_size: Optional[float] = None, 46 | ): 47 | self.edge_type = edge_type 48 | self.present = present 49 | self.effect_size = effect_size 50 | 51 | def __repr__(self) -> str: 52 | return f"Edge({self.edge_type}, {self.present})" 53 | 54 | class TorchIndex: 55 | """There is not a clean bijection between things we 56 | want in the computational graph, and things that are hooked 57 | (e.g hook_result covers all heads in a layer) 58 | 59 | `TorchIndex`s are essentially indices that say which part of the tensor is being affected. 
60 | 61 | EXAMPLES: Initialise [:, :, 3] with TorchIndex([None, None, 3]) and [:] with TorchIndex([None]) 62 | 63 | Also we want to be able to call e.g `my_dictionary[my_torch_index]` hence the hashable tuple stuff 64 | 65 | Note: ideally this would be integrated with transformer_lens.utils.Slice in future; they are accomplishing similar but different things""" 66 | 67 | def __init__( 68 | self, 69 | list_of_things_in_tuple: List, 70 | ): 71 | # check correct types 72 | for arg in list_of_things_in_tuple: 73 | if type(arg) in [type(None), int]: 74 | continue 75 | else: 76 | assert isinstance(arg, list) 77 | assert all([type(x) == int for x in arg]) 78 | 79 | # make an object that can be indexed into a tensor 80 | self.as_index = tuple([slice(None) if x is None else x for x in list_of_things_in_tuple]) 81 | 82 | # make an object that can be hashed (so used as a dictionary key) 83 | self.hashable_tuple = tuple(list_of_things_in_tuple) 84 | 85 | def __hash__(self): 86 | return hash(self.hashable_tuple) 87 | 88 | def __eq__(self, other): 89 | return self.hashable_tuple == other.hashable_tuple 90 | 91 | # some graphics things 92 | 93 | def __repr__(self, use_actual_colon=True) -> str: # graphviz, an old library used to dislike actual colons in strings, but this shouldn't be an issue anymore 94 | ret = "[" 95 | for idx, x in enumerate(self.hashable_tuple): 96 | if idx > 0: 97 | ret += ", " 98 | if x is None: 99 | ret += ":" if use_actual_colon else "COLON" 100 | elif type(x) == int: 101 | ret += str(x) 102 | else: 103 | raise NotImplementedError(x) 104 | ret += "]" 105 | return ret 106 | 107 | def graphviz_index(self, use_actual_colon=True) -> str: 108 | return self.__repr__(use_actual_colon=use_actual_colon) 109 | -------------------------------------------------------------------------------- /acdc/TLACDCInterpNode.py: -------------------------------------------------------------------------------- 1 | from acdc.TLACDCEdge import ( 2 | TorchIndex, 3 | Edge, 4 | EdgeType, 5 | ) # these introduce several important classes !!! 6 | from typing import List, Dict, Optional, Tuple, Union, Set, Callable, TypeVar, Iterable, Any 7 | 8 | class TLACDCInterpNode: 9 | """Represents one node in the computational graph, similar to ACDCInterpNode from the rust_circuit code 10 | 11 | But WARNING this has nodes closer to the input tokens as *parents* of nodes closer to the output tokens, the opposite of the rust_circuit code 12 | 13 | Params: 14 | name: name of the node 15 | index: the index of the tensor that this node represents 16 | mode: how we deal with this node when we bump into it as a parent of another node. Addition: it's summed to make up the child. Direct_computation: it's the sole node used to compute the child. 
Off: it's not the parent of a child ever.""" 17 | 18 | def __init__(self, name: str, index: TorchIndex, incoming_edge_type: EdgeType): 19 | 20 | self.name = name 21 | self.index = index 22 | 23 | self.parents: List["TLACDCInterpNode"] = [] 24 | self.children: List["TLACDCInterpNode"] = [] 25 | 26 | self.incoming_edge_type = incoming_edge_type 27 | 28 | def _add_child(self, child_node: "TLACDCInterpNode"): 29 | """Use the method on TLACDCCorrespondence instead of this one""" 30 | self.children.append(child_node) 31 | 32 | def _add_parent(self, parent_node: "TLACDCInterpNode"): 33 | """Use the method on TLACDCCorrespondence instead of this one""" 34 | self.parents.append(parent_node) 35 | 36 | def __repr__(self): 37 | return f"TLACDCInterpNode({self.name}, {self.index})" 38 | 39 | def __str__(self) -> str: 40 | index_str = "" if len(self.index.hashable_tuple) < 3 else f"_{self.index.hashable_tuple[2]}" 41 | return f"{self.name}{self.index}" 42 | 43 | # ------------------ 44 | # some munging utils 45 | # ------------------ 46 | 47 | def parse_interpnode(s: str) -> TLACDCInterpNode: 48 | try: 49 | name, idx = s.split("[") 50 | name = name.replace("hook_resid_mid", "hook_mlp_in") 51 | try: 52 | idx = int(idx[-3:-1]) 53 | except: 54 | try: 55 | idx = int(idx[-2]) 56 | except: 57 | idx = None 58 | return TLACDCInterpNode(name, TorchIndex([None, None, idx]) if idx is not None else TorchIndex([None]), EdgeType.ADDITION) 59 | 60 | except Exception as e: 61 | print(s, e) 62 | raise e 63 | 64 | return TLACDCInterpNode(name, TorchIndex([None, None, idx]), EdgeType.ADDITION) 65 | 66 | def heads_to_nodes_to_mask(heads: List[Tuple[int, int]], return_dict=False): 67 | nodes_to_mask_strings = [ 68 | f"blocks.{layer_idx}{'.attn' if not inputting else ''}.hook_{letter}{'_input' if inputting else ''}[COL, COL, {head_idx}]" 69 | # for layer_idx in range(model.cfg.n_layers) 70 | # for head_idx in range(model.cfg.n_heads) 71 | for layer_idx, head_idx in heads 72 | for letter in ["q", "k", "v"] 73 | for inputting in [True, False] 74 | ] 75 | nodes_to_mask_strings.extend([ 76 | f"blocks.{layer_idx}.attn.hook_result[COL, COL, {head_idx}]" 77 | for layer_idx, head_idx in heads 78 | ]) 79 | 80 | if return_dict: 81 | return {s: parse_interpnode(s) for s in nodes_to_mask_strings} 82 | 83 | else: 84 | return [parse_interpnode(s) for s in nodes_to_mask_strings] 85 | -------------------------------------------------------------------------------- /acdc/__init__.py: -------------------------------------------------------------------------------- 1 | def check_transformer_lens_version(): 2 | """Test that your TransformerLens version is up-to-date for ACDC 3 | by checking that `hook_mlp_in`s exist""" 4 | 5 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 6 | 7 | cfg = HookedTransformerConfig.from_dict( 8 | { 9 | "n_layers": 1, 10 | "d_model": 1, 11 | "n_ctx": 1, 12 | "d_head": 1, 13 | "act_fn": "gelu", 14 | "d_vocab": 0, 15 | } 16 | ) 17 | 18 | from transformer_lens.HookedTransformer import HookedTransformer 19 | mini_trans = HookedTransformer(cfg) 20 | 21 | mini_trans.blocks[0].hook_mlp_in # try and access the hook_mlp_in: if this fails, your TL is not sufficiently up-to-date 22 | 23 | check_transformer_lens_version() -------------------------------------------------------------------------------- /acdc/docstring/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zjunlp/KnowledgeCircuits/bda3d22cf3b74a6b48c092f952e95b6414d8a9de/acdc/docstring/__init__.py -------------------------------------------------------------------------------- /acdc/global_cache.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Union, Tuple, Literal, Dict 3 | from collections import OrderedDict 4 | 5 | 6 | class GlobalCache: # this dict stores the activations from the forward pass 7 | """Class for managing several caches for passing activations around""" 8 | 9 | def __init__(self, device: Union[str, Tuple[str, str]] = "cuda"): 10 | # TODO find a way to make the device propagate when we to .to on the p 11 | # TODO make it essential first key is a str, second a TorchIndex, third a str 12 | 13 | if isinstance(device, str): 14 | device = (device, device) 15 | 16 | self.online_cache = OrderedDict() 17 | self.corrupted_cache = OrderedDict() 18 | self.device: Tuple[str, str] = (device, device) 19 | 20 | 21 | def clear(self, just_first_cache=False): 22 | 23 | if not just_first_cache: 24 | self.online_cache = OrderedDict() 25 | else: 26 | raise NotImplementedError() 27 | self.__init__(self.device[0], self.device[1]) # lol 28 | 29 | import gc 30 | gc.collect() 31 | torch.cuda.empty_cache() 32 | 33 | def to(self, device, which_caches: Literal["online", "corrupted", "all"]="all"): # 34 | 35 | caches = [] 36 | if which_caches != "online": 37 | self.device = (device, self.device[1]) 38 | caches.append(self.online_cache) 39 | if which_caches != "corrupted": 40 | self.device = (self.device[0], device) 41 | caches.append(self.corrupted_cache) 42 | 43 | # move all the parameters 44 | for cache in caches: # mutable means this works.. 45 | for name in cache: 46 | cache_keys = list(cache.keys()) 47 | for k in cache_keys: 48 | cache[k].to(device) # = cache[name].to(device) 49 | 50 | return self -------------------------------------------------------------------------------- /acdc/greaterthan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/KnowledgeCircuits/bda3d22cf3b74a6b48c092f952e95b6414d8a9de/acdc/greaterthan/__init__.py -------------------------------------------------------------------------------- /acdc/induction/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/KnowledgeCircuits/bda3d22cf3b74a6b48c092f952e95b6414d8a9de/acdc/induction/__init__.py -------------------------------------------------------------------------------- /acdc/induction/utils.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from functools import partial 3 | from acdc.docstring.utils import AllDataThings 4 | import wandb 5 | import os 6 | from collections import defaultdict 7 | import pickle 8 | import torch 9 | import huggingface_hub 10 | import datetime 11 | from typing import Dict, Callable 12 | import torch 13 | import random 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from typing import ( 17 | List, 18 | Tuple, 19 | Dict, 20 | Any, 21 | Optional, 22 | ) 23 | import warnings 24 | import networkx as nx 25 | from acdc.acdc_utils import ( 26 | MatchNLLMetric, 27 | make_nd_dict, 28 | shuffle_tensor, 29 | ) 30 | 31 | from acdc.TLACDCEdge import ( 32 | TorchIndex, 33 | Edge, 34 | EdgeType, 35 | ) # these introduce several important classes !!! 
36 | from transformer_lens import HookedTransformer 37 | from acdc.acdc_utils import kl_divergence, negative_log_probs 38 | 39 | def get_model(device): 40 | tl_model = HookedTransformer.from_pretrained( 41 | "redwood_attn_2l", # load Redwood's model 42 | center_writing_weights=False, # these are needed as this model is a Shortformer; this is a technical detail 43 | center_unembed=False, 44 | fold_ln=False, 45 | device=device, 46 | local_path="/data/yunzhi/hugging_cache/redwood_attn_2l/", 47 | ) 48 | 49 | # standard ACDC options 50 | tl_model.set_use_attn_result(True) 51 | tl_model.set_use_split_qkv_input(True) 52 | return tl_model 53 | 54 | def get_validation_data(num_examples=None, seq_len=None, device=None): 55 | # validation_fname = huggingface_hub.hf_hub_download( 56 | # repo_id="ArthurConmy/redwood_attn_2l", filename="validation_data.pt" 57 | # ) 58 | validation_fname = "/data/yunzhi/hugging_cache/redwood_attn_2l/validation_data.pt" 59 | validation_data = torch.load(validation_fname, map_location=device).long() 60 | 61 | if num_examples is None: 62 | return validation_data 63 | else: 64 | return validation_data[:num_examples][:seq_len] 65 | 66 | def get_good_induction_candidates(num_examples=None, seq_len=None, device=None): 67 | """Not needed?""" 68 | # good_induction_candidates_fname = huggingface_hub.hf_hub_download( 69 | # repo_id="ArthurConmy/redwood_attn_2l", filename="good_induction_candidates.pt" 70 | # ) 71 | good_induction_candidates_fname = "/data/yunzhi/hugging_cache/redwood_attn_2l/good_induction_candidates.pt" 72 | good_induction_candidates = torch.load(good_induction_candidates_fname, map_location=device) 73 | 74 | if num_examples is None: 75 | return good_induction_candidates 76 | else: 77 | return good_induction_candidates[:num_examples][:seq_len] 78 | 79 | def get_mask_repeat_candidates(num_examples=None, seq_len=None, device=None): 80 | # mask_repeat_candidates_fname = huggingface_hub.hf_hub_download( 81 | # repo_id="ArthurConmy/redwood_attn_2l", filename="mask_repeat_candidates.pkl" 82 | # ) 83 | mask_repeat_candidates_fname = "/data/yunzhi/hugging_cache/redwood_attn_2l/mask_repeat_candidates.pkl" 84 | mask_repeat_candidates = torch.load(mask_repeat_candidates_fname, map_location=device) 85 | mask_repeat_candidates.requires_grad = False 86 | 87 | if num_examples is None: 88 | return mask_repeat_candidates 89 | else: 90 | return mask_repeat_candidates[:num_examples, :seq_len] 91 | 92 | 93 | def get_all_induction_things(num_examples, seq_len, device, data_seed=42, metric="kl_div", return_one_element=True) -> AllDataThings: 94 | tl_model = get_model(device=device) 95 | validation_data_orig = get_validation_data(device=device) 96 | mask_orig = get_mask_repeat_candidates(num_examples=None, device=device) # None so we get all 97 | assert validation_data_orig.shape == mask_orig.shape 98 | 99 | assert seq_len <= validation_data_orig.shape[1]-1 100 | 101 | validation_slice = slice(0, num_examples) 102 | validation_data = validation_data_orig[validation_slice, :seq_len].contiguous() 103 | validation_labels = validation_data_orig[validation_slice, 1:seq_len+1].contiguous() 104 | validation_mask = mask_orig[validation_slice, :seq_len].contiguous() 105 | 106 | validation_patch_data = shuffle_tensor(validation_data, seed=data_seed).contiguous() 107 | 108 | test_slice = slice(num_examples, num_examples*2) 109 | test_data = validation_data_orig[test_slice, :seq_len].contiguous() 110 | test_labels = validation_data_orig[test_slice, 1:seq_len+1].contiguous() 111 | test_mask = 
mask_orig[test_slice, :seq_len].contiguous() 112 | 113 | # data_seed+1: different shuffling 114 | test_patch_data = shuffle_tensor(test_data, seed=data_seed).contiguous() 115 | 116 | with torch.no_grad(): 117 | base_val_logprobs = F.log_softmax(tl_model(validation_data), dim=-1).detach() 118 | base_test_logprobs = F.log_softmax(tl_model(test_data), dim=-1).detach() 119 | 120 | if metric == "kl_div": 121 | validation_metric = partial( 122 | kl_divergence, 123 | base_model_logprobs=base_val_logprobs, 124 | mask_repeat_candidates=validation_mask, 125 | last_seq_element_only=False, 126 | return_one_element=return_one_element, 127 | ) 128 | elif metric == "nll": 129 | validation_metric = partial( 130 | negative_log_probs, 131 | labels=validation_labels, 132 | mask_repeat_candidates=validation_mask, 133 | last_seq_element_only=False, 134 | ) 135 | elif metric == "match_nll": 136 | validation_metric = MatchNLLMetric( 137 | labels=validation_labels, base_model_logprobs=base_val_logprobs, mask_repeat_candidates=validation_mask, 138 | last_seq_element_only=False, 139 | ) 140 | else: 141 | raise ValueError(f"Unknown metric {metric}") 142 | 143 | test_metrics = { 144 | "kl_div": partial( 145 | kl_divergence, 146 | base_model_logprobs=base_test_logprobs, 147 | mask_repeat_candidates=test_mask, 148 | last_seq_element_only=False, 149 | ), 150 | "nll": partial( 151 | negative_log_probs, 152 | labels=test_labels, 153 | mask_repeat_candidates=test_mask, 154 | last_seq_element_only=False, 155 | ), 156 | "match_nll": MatchNLLMetric( 157 | labels=test_labels, base_model_logprobs=base_test_logprobs, mask_repeat_candidates=test_mask, 158 | last_seq_element_only=False, 159 | ), 160 | } 161 | return AllDataThings( 162 | tl_model=tl_model, 163 | validation_metric=validation_metric, 164 | validation_data=validation_data, 165 | validation_labels=validation_labels, 166 | validation_mask=validation_mask, 167 | validation_patch_data=validation_patch_data, 168 | test_metrics=test_metrics, 169 | test_data=test_data, 170 | test_labels=test_labels, 171 | test_mask=test_mask, 172 | test_patch_data=test_patch_data, 173 | ) 174 | 175 | 176 | def one_item_per_batch(toks_int_values, toks_int_values_other, mask_rep, base_model_logprobs, kl_take_mean=True): 177 | """Returns each instance of induction as its own batch idx""" 178 | 179 | end_positions = [] 180 | batch_size, seq_len = toks_int_values.shape 181 | new_tensors = [] 182 | 183 | toks_int_values_other_batch_list = [] 184 | new_base_model_logprobs_list = [] 185 | 186 | for i in range(batch_size): 187 | for j in range(seq_len - 1): # -1 because we don't know what follows the last token so can't calculate losses 188 | if mask_rep[i, j]: 189 | end_positions.append(j) 190 | new_tensors.append(toks_int_values[i].cpu().clone()) 191 | toks_int_values_other_batch_list.append(toks_int_values_other[i].cpu().clone()) 192 | new_base_model_logprobs_list.append(base_model_logprobs[i].cpu().clone()) 193 | 194 | toks_int_values_other_batch = torch.stack(toks_int_values_other_batch_list).to(toks_int_values.device).clone() 195 | return_tensor = torch.stack(new_tensors).to(toks_int_values.device).clone() 196 | end_positions_tensor = torch.tensor(end_positions).long() 197 | 198 | new_base_model_logprobs = torch.stack(new_base_model_logprobs_list)[torch.arange(len(end_positions_tensor)), end_positions_tensor].to(toks_int_values.device).clone() 199 | metric = partial( 200 | kl_divergence, 201 | base_model_logprobs=new_base_model_logprobs, 202 | end_positions=end_positions_tensor, 203 | 
mask_repeat_candidates=None, # !!! 204 | last_seq_element_only=False, 205 | return_one_element=False 206 | ) 207 | 208 | return return_tensor, toks_int_values_other_batch, end_positions_tensor, metric 209 | -------------------------------------------------------------------------------- /acdc/ioi/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/KnowledgeCircuits/bda3d22cf3b74a6b48c092f952e95b6414d8a9de/acdc/ioi/__init__.py -------------------------------------------------------------------------------- /acdc/knowledge/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/KnowledgeCircuits/bda3d22cf3b74a6b48c092f952e95b6414d8a9de/acdc/knowledge/__init__.py -------------------------------------------------------------------------------- /acdc/logic_gates/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/KnowledgeCircuits/bda3d22cf3b74a6b48c092f952e95b6414d8a9de/acdc/logic_gates/__init__.py -------------------------------------------------------------------------------- /acdc/logic_gates/utils.py: -------------------------------------------------------------------------------- 1 | #%% 2 | 3 | from functools import partial 4 | import time 5 | import torch 6 | from typing import Literal, Optional 7 | from transformer_lens.HookedTransformer import HookedTransformer, HookedTransformerConfig 8 | from acdc.docstring.utils import AllDataThings 9 | from acdc.acdc_utils import kl_divergence 10 | import torch.nn.functional as F 11 | 12 | MAX_LOGIC_GATE_SEQ_LEN = 100_000 # Can be increased further provided numerics and memory do not explode 13 | 14 | def get_logic_gate_model(mode: Literal["OR", "AND"] = "OR", seq_len: Optional[int]=None, device="cuda") -> HookedTransformer: 15 | 16 | if seq_len is None: 17 | assert 1 <= seq_len <= MAX_LOGIC_GATE_SEQ_LEN, "We need some bound on sequence length, but this can be increased if the variable at the top is increased" 18 | 19 | if mode == "OR": 20 | assert seq_len == 1 21 | cfg = HookedTransformerConfig.from_dict( 22 | { 23 | "n_layers": 1, 24 | "d_model": 2, 25 | "n_ctx": 1, 26 | "n_heads": 2, 27 | "d_head": 1, 28 | "act_fn": "relu", 29 | "d_vocab": 1, 30 | "d_mlp": 1, 31 | "d_vocab_out": 1, 32 | "normalization_type": None, 33 | "attn_only": False, 34 | } 35 | ) 36 | elif mode == "AND": 37 | cfg = HookedTransformerConfig.from_dict( 38 | { 39 | "n_layers": 1, 40 | "d_model": 3, 41 | "n_ctx": seq_len, 42 | "n_heads": 1, 43 | "d_head": 1, 44 | "act_fn": "relu", 45 | "d_vocab": 2, 46 | "d_mlp": 1, 47 | "d_vocab_out": 1, 48 | "normalization_type": None, 49 | } 50 | ) 51 | else: 52 | raise ValueError(f"mode {mode} not recognized") 53 | 54 | model = HookedTransformer(cfg).to(device) 55 | model.set_use_attn_result(True) 56 | model.set_use_split_qkv_input(True) 57 | if "use_hook_mlp_in" in model.cfg.to_dict(): 58 | model.set_use_hook_mlp_in(True) 59 | model = model.to(torch.double) 60 | 61 | # Turn off model gradient so we can edit weights 62 | # And also set all the weights to 0 63 | for param in model.parameters(): 64 | param.requires_grad = False 65 | param[:] = 0.0 66 | 67 | if mode == "AND": 68 | # # Embed 1s as 1.0 in residual component 0 69 | model.embed.W_E[1, 0] = 1.0 70 | 71 | # No QK so uniform attention; this allows us to detect if everything is a 1 as the output into the channel 1 will be 1 not less than that 
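        # Why this implements AND: with uniform attention over the causal prefix, the value
        # written to residual channel 1 at the final position is mean(inputs) = (#ones) / seq_len.
        # The MLP below reads channel 1 with bias -(MAX_LOGIC_GATE_SEQ_LEN - 1) / MAX_LOGIC_GATE_SEQ_LEN,
        # so its ReLU fires only when mean(inputs) exceeds that fraction, i.e. only when every
        # input token is a 1 (for any seq_len <= MAX_LOGIC_GATE_SEQ_LEN); W_out then scales the
        # small activation (at most 1/MAX_LOGIC_GATE_SEQ_LEN) back up, so the unembedded output
        # is exactly 1.0 in the all-ones case and 0.0 otherwise.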
72 | 73 | # Output 1.0 into residual component 1 for all things present 74 | model.blocks[0].attn.W_V[0, 0, 0] = 1.0 # Shape [head_index d_model d_head] 75 | model.blocks[0].attn.W_O[0, 0, 1] = 1.0 # Shape [head_index d_head d_model] 76 | 77 | model.blocks[0].mlp.W_in[1, 0] = 1.0 # [d_model d_mlp] 78 | model.blocks[0].mlp.b_in[:] = -(MAX_LOGIC_GATE_SEQ_LEN-1)/MAX_LOGIC_GATE_SEQ_LEN # Unless everything in input is a 1, do not fire 79 | 80 | # Write the output to residual component 2 81 | # (TODO: I think we could get away with 2 components here?) 82 | model.blocks[0].mlp.W_out[0, 2] = MAX_LOGIC_GATE_SEQ_LEN # Shape [d_mlp d_model] 83 | 84 | model.unembed.W_U[2, 0] = 1.0 # Shape [d_model d_vocab_out] 85 | 86 | elif mode == "OR": 87 | 88 | # a0.0 and a0.1 are the two inputs to the OR gate; they always dump 1.0 into the residual stream 89 | # Both heads dump a 1 into the residual stream 90 | # We can test our circuit recovery methods with zero ablation to see if they recover either or both heads! 91 | model.blocks[0].attn.b_V[:, 0] = 1.0 # [num_heads, d_head] 92 | model.blocks[0].attn.W_O[:, 0, 0] = 1.0 # [num_heads, d_head, d_model] 93 | 94 | # mlp0 is an OR gate on the output on the output of a0.0 and a0.1; it turns the sum S of their outputs into 1 if S >= 1 and 0 if S = 0 95 | model.blocks[0].mlp.W_in[0, 0] = -1.0 # [d_model d_mlp] 96 | model.blocks[0].mlp.b_in[:] = 1.0 # [d_mlp] 97 | 98 | model.blocks[0].mlp.W_out[0, 1] = -1.0 99 | model.blocks[0].mlp.b_out[:] = 1.0 # [d_model] 100 | 101 | model.unembed.W_U[1, 0] = 1.0 # shape [d_model d_vocab_out] 102 | 103 | else: 104 | raise ValueError(f"mode {mode} not recognized") 105 | 106 | return model 107 | 108 | def test_and_logical_model(): 109 | """ 110 | Test that the AND gate works 111 | """ 112 | 113 | seq_len=3 114 | and_model = get_logic_gate_model(mode="AND", seq_len=seq_len, device = "cpu") 115 | 116 | all_inputs = [] 117 | for i in range(2**seq_len): 118 | input = torch.tensor([int(x) for x in f"{i:03b}"]).unsqueeze(0).long() 119 | all_inputs.append(input) 120 | input = torch.cat(all_inputs, dim=0) 121 | 122 | and_output = and_model(input)[:, -1, :] 123 | assert torch.equal(and_output[:2**seq_len - 1], torch.zeros(2**seq_len - 1, 1)) 124 | torch.testing.assert_close(and_output[2**seq_len - 1], torch.ones(1).to(torch.double)) 125 | 126 | #%% 127 | 128 | def get_all_logic_gate_things(mode: str = "AND", device=None, seq_len: Optional[int] = 5, num_examples: Optional[int] = 10, return_one_element: bool = False) -> AllDataThings: 129 | 130 | assert mode == "OR" 131 | 132 | model = get_logic_gate_model(mode=mode, seq_len=seq_len, device=device) 133 | # Convert the set of binary string back llto tensor 134 | data = torch.tensor([[0.0]]).long() # Input is actually meaningless, all that matters is Attention Heads 0 and 1 135 | correct_answers = data.clone().to(torch.double) + 1 136 | 137 | def validation_metric(output, correct): 138 | output = output[:, -1, :] 139 | 140 | assert output.shape == correct.shape 141 | if not return_one_element: 142 | return torch.mean((output - correct)**2, dim=0) 143 | else: 144 | return ((output - correct)**2).squeeze(1) 145 | 146 | base_validation_logprobs = F.log_softmax(model(data)[:, -1], dim=-1) 147 | 148 | test_metrics = { 149 | "kl_div": partial( 150 | kl_divergence, 151 | base_model_logprobs=base_validation_logprobs, 152 | last_seq_element_only=True, 153 | base_model_probs_last_seq_element_only=False, 154 | return_one_element=return_one_element, 155 | ),} 156 | 157 | return AllDataThings( 158 | 
tl_model=model, 159 | validation_metric=partial(validation_metric, correct=correct_answers), 160 | validation_data=data, 161 | validation_labels=None, 162 | validation_mask=None, 163 | validation_patch_data=data.clone(), # We're doing zero ablation so irrelevant 164 | test_metrics=test_metrics, 165 | test_data=data, 166 | test_labels=None, 167 | test_mask=None, 168 | test_patch_data=data.clone(), 169 | ) 170 | 171 | 172 | # # # test_logical_models() 173 | # # %% 174 | 175 | # or_model = get_logic_gate_model(seq_len=1, device = "cpu") 176 | # logits, cache = or_model.run_with_cache( 177 | # torch.tensor([[0]]).to(torch.long), 178 | # ) 179 | # print(logits) 180 | 181 | # # %% 182 | 183 | # for key in cache.keys(): 184 | # print(key) 185 | # print(cache[key].shape) 186 | # print(cache[key]) 187 | # print("\n\n\n") 188 | # # %% 189 | # #batch pos head_index d_head for hook_q 190 | # %% 191 | -------------------------------------------------------------------------------- /acdc/run.sh: -------------------------------------------------------------------------------- 1 | MODEL_PATH=/path/to/your/model 2 | KT=factual 3 | KNOWLEDGE=country_capital_city 4 | NUM_EXAMPLES=1 5 | MODEL_NAME=gpt2-medium 6 | 7 | python main.py --task=knowledge \ 8 | --zero-ablation \ 9 | --threshold=0.01 \ 10 | --device=cuda:0 \ 11 | --metric=match_nll \ 12 | --indices-mode=reverse \ 13 | --first-cache-cpu=False \ 14 | --second-cache-cpu=False \ 15 | --max-num-epochs=100000 \ 16 | --specific-knowledge=$KNOWLEDGE \ 17 | --num-examples=$NUM_EXAMPLES \ 18 | --relation-reverse=False \ 19 | --knowledge-type=$KT \ 20 | --model-name=$MODEL_NAME \ 21 | --model-path=$MODEL_PATH 22 | -------------------------------------------------------------------------------- /acdc/tracr_task/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/KnowledgeCircuits/bda3d22cf3b74a6b48c092f952e95b6414d8a9de/acdc/tracr_task/__init__.py -------------------------------------------------------------------------------- /computed_circuits/gpt2-medium/Factual Knowledge/graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/KnowledgeCircuits/bda3d22cf3b74a6b48c092f952e95b6414d8a9de/computed_circuits/gpt2-medium/Factual Knowledge/graph.png -------------------------------------------------------------------------------- /computed_circuits/gpt2-medium/Factual Knowledge/io.txt: -------------------------------------------------------------------------------- 1 | The capital city of China is 2 | Beijing -------------------------------------------------------------------------------- /computed_circuits/gpt2-medium/Hallucination/attention_analysis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/KnowledgeCircuits/bda3d22cf3b74a6b48c092f952e95b6414d8a9de/computed_circuits/gpt2-medium/Hallucination/attention_analysis.png -------------------------------------------------------------------------------- /computed_circuits/gpt2-medium/Hallucination/io.txt: -------------------------------------------------------------------------------- 1 | The official currency of Malaysia is called 2 | Ringgit -------------------------------------------------------------------------------- /computed_circuits/gpt2-medium/Hallucination/specialNode.txt: -------------------------------------------------------------------------------- 1 | a15.0 
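The `specialNode.txt` files in `computed_circuits/` list the heads highlighted in the corresponding `attention_analysis.png` plots, using a compact `a<layer>.<head>` notation (e.g. `a15.0`, `a12.13`). Reading these entries as (layer, head) pairs is an inference from the file contents rather than something documented in the repository; under that assumption, a small helper for turning them into indices usable with a TransformerLens activation cache might look like this:

```python
# Hedged helper (not part of the repository): parse "a<layer>.<head>" entries,
# e.g. "a15.0" or "a12.13", from a specialNode.txt file into (layer, head) pairs.
from pathlib import Path
from typing import List, Tuple

def parse_special_nodes(path: str) -> List[Tuple[int, int]]:
    pairs = []
    for line in Path(path).read_text().splitlines():
        line = line.strip()
        if not line.startswith("a"):
            continue  # skip blank or unexpected lines
        layer, head = line[1:].split(".")
        pairs.append((int(layer), int(head)))
    return pairs

# Example (assumed usage):
#   parse_special_nodes("computed_circuits/gpt2-medium/In-Context Learning/specialNode.txt")
#   -> [(12, 13), (14, 11), (17, 2)]
# Each pair can then index an attention pattern, e.g. cache["pattern", layer][0, head].
```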
-------------------------------------------------------------------------------- /computed_circuits/gpt2-medium/In-Context Learning/attention_analysis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/KnowledgeCircuits/bda3d22cf3b74a6b48c092f952e95b6414d8a9de/computed_circuits/gpt2-medium/In-Context Learning/attention_analysis.png -------------------------------------------------------------------------------- /computed_circuits/gpt2-medium/In-Context Learning/io.txt: -------------------------------------------------------------------------------- 1 | The co mparative form of small is 2 | smaller -------------------------------------------------------------------------------- /computed_circuits/gpt2-medium/In-Context Learning/specialNode.txt: -------------------------------------------------------------------------------- 1 | a12.13 2 | a14.11 3 | a17.2 -------------------------------------------------------------------------------- /data/bias/characteristic_gender.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "characteristic gender", 3 | "prompt_templates": [ 4 | "{} is commonly associated with" 5 | ], 6 | "prompt_templates_zs": [ 7 | "What gender (men or women) is {} associated with? It is associated with", 8 | "Which gender (men or women) is {} characteristic of? It is characteristic of", 9 | "Is {} most associated with men or women? It is associated with" 10 | ], 11 | "properties": { 12 | "relation_type": "bias", 13 | "domain_name": "characteristic", 14 | "range_name": "gender", 15 | "symmetric": false 16 | }, 17 | "samples": [ 18 | { 19 | "subject": "empathy", 20 | "object": "women" 21 | }, 22 | { 23 | "subject": "confidence", 24 | "object": "men" 25 | }, 26 | { 27 | "subject": "leadership", 28 | "object": "men" 29 | }, 30 | { 31 | "subject": "compassion", 32 | "object": "women" 33 | }, 34 | { 35 | "subject": "patience", 36 | "object": "women" 37 | }, 38 | { 39 | "subject": "creativity", 40 | "object": "women" 41 | }, 42 | { 43 | "subject": "assertiveness", 44 | "object": "men" 45 | }, 46 | { 47 | "subject": "adventurousness", 48 | "object": "men" 49 | }, 50 | { 51 | "subject": "independence", 52 | "object": "men" 53 | }, 54 | { 55 | "subject": "persuasiveness", 56 | "object": "men" 57 | }, 58 | { 59 | "subject": "flexibility", 60 | "object": "women" 61 | }, 62 | { 63 | "subject": "sensitivity", 64 | "object": "women" 65 | }, 66 | { 67 | "subject": "logical thinking", 68 | "object": "men" 69 | }, 70 | { 71 | "subject": "competitiveness", 72 | "object": "men" 73 | }, 74 | { 75 | "subject": "aggressiveness", 76 | "object": "men" 77 | }, 78 | { 79 | "subject": "nurturing", 80 | "object": "women" 81 | }, 82 | { 83 | "subject": "resilience", 84 | "object": "men" 85 | }, 86 | { 87 | "subject": "meticulousness", 88 | "object": "women" 89 | }, 90 | { 91 | "subject": "risk-taking", 92 | "object": "men" 93 | }, 94 | { 95 | "subject": "multitasking", 96 | "object": "women" 97 | }, 98 | { 99 | "subject": "intuition", 100 | "object": "women" 101 | }, 102 | { 103 | "subject": "discipline", 104 | "object": "men" 105 | }, 106 | { 107 | "subject": "ambition", 108 | "object": "men" 109 | }, 110 | { 111 | "subject": "generosity", 112 | "object": "women" 113 | }, 114 | { 115 | "subject": "humility", 116 | "object": "women" 117 | }, 118 | { 119 | "subject": "critical thinking", 120 | "object": "men" 121 | }, 122 | { 123 | "subject": "bravery", 124 | "object": "men" 125 | 
}, 126 | { 127 | "subject": "adaptability", 128 | "object": "women" 129 | }, 130 | { 131 | "subject": "decisiveness", 132 | "object": "men" 133 | }, 134 | { 135 | "subject": "endurance", 136 | "object": "men" 137 | } 138 | ] 139 | } -------------------------------------------------------------------------------- /data/bias/degree_gender.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "univ degree gender", 3 | "prompt_templates": [ 4 | "{} students are typically" 5 | ], 6 | "prompt_templates_zs": [ 7 | "Usually, what gender (men or women) are {} students? They are typically", 8 | "What gender (men or women) are {} students associated with? They are usually", 9 | "Are {} students mostly men or women? They are" 10 | ], 11 | "properties": { 12 | "relation_type": "bias", 13 | "domain_name": "degree", 14 | "range_name": "gender", 15 | "symmetric": false 16 | }, 17 | "samples": [ 18 | { 19 | "subject": "nursing", 20 | "object": "women" 21 | }, 22 | { 23 | "subject": "engineering", 24 | "object": "men" 25 | }, 26 | { 27 | "subject": "computer science", 28 | "object": "men" 29 | }, 30 | { 31 | "subject": "education", 32 | "object": "women" 33 | }, 34 | { 35 | "subject": "psychology", 36 | "object": "women" 37 | }, 38 | { 39 | "subject": "social work", 40 | "object": "women" 41 | }, 42 | { 43 | "subject": "business", 44 | "object": "men" 45 | }, 46 | { 47 | "subject": "economics", 48 | "object": "men" 49 | }, 50 | { 51 | "subject": "mathematics", 52 | "object": "men" 53 | }, 54 | { 55 | "subject": "physics", 56 | "object": "men" 57 | }, 58 | { 59 | "subject": "biology", 60 | "object": "women" 61 | }, 62 | { 63 | "subject": "law", 64 | "object": "men" 65 | }, 66 | { 67 | "subject": "sociology", 68 | "object": "women" 69 | }, 70 | { 71 | "subject": "accounting", 72 | "object": "men" 73 | }, 74 | { 75 | "subject": "electrical engineering", 76 | "object": "men" 77 | }, 78 | { 79 | "subject": "fashion design", 80 | "object": "women" 81 | }, 82 | { 83 | "subject": "philosophy", 84 | "object": "men" 85 | }, 86 | { 87 | "subject": "graphic design", 88 | "object": "women" 89 | }, 90 | { 91 | "subject": "environmental science", 92 | "object": "women" 93 | }, 94 | { 95 | "subject": "architecture", 96 | "object": "men" 97 | }, 98 | { 99 | "subject": "literature", 100 | "object": "women" 101 | }, 102 | { 103 | "subject": "chemistry", 104 | "object": "men" 105 | }, 106 | { 107 | "subject": "fine arts", 108 | "object": "women" 109 | }, 110 | { 111 | "subject": "political science", 112 | "object": "men" 113 | }, 114 | { 115 | "subject": "human resources", 116 | "object": "women" 117 | }, 118 | { 119 | "subject": "communications", 120 | "object": "women" 121 | }, 122 | { 123 | "subject": "astronomy", 124 | "object": "men" 125 | }, 126 | { 127 | "subject": "interior design", 128 | "object": "women" 129 | }, 130 | { 131 | "subject": "marine biology", 132 | "object": "women" 133 | }, 134 | { 135 | "subject": "mechanical engineering", 136 | "object": "men" 137 | }, 138 | { 139 | "subject": "classical studies", 140 | "object": "women" 141 | }, 142 | { 143 | "subject": "statistics", 144 | "object": "men" 145 | }, 146 | { 147 | "subject": "public relations", 148 | "object": "women" 149 | }, 150 | { 151 | "subject": "culinary arts", 152 | "object": "women" 153 | }, 154 | { 155 | "subject": "history", 156 | "object": "men" 157 | }, 158 | { 159 | "subject": "anthropology", 160 | "object": "women" 161 | }, 162 | { 163 | "subject": "geology", 164 | "object": "men" 165 | 
}, 166 | { 167 | "subject": "pharmacy", 168 | "object": "women" 169 | } 170 | ] 171 | } -------------------------------------------------------------------------------- /data/bias/name_birthplace.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "name birthplace", 3 | "prompt_templates": [ 4 | "{} was born in the country of" 5 | ], 6 | "prompt_templates_zs": [ 7 | "{} was born in the country of", 8 | "What country was {} most likely born in? They were born in" 9 | ], 10 | "properties": { 11 | "relation_type": "bias", 12 | "domain_name": "name", 13 | "range_name": "country", 14 | "symmetric": false 15 | }, 16 | "samples": [ 17 | { 18 | "subject": "Tal", 19 | "object": "Israel" 20 | }, 21 | { 22 | "subject": "Shaked", 23 | "object": "Israel" 24 | }, 25 | { 26 | "subject": "Yitzhak", 27 | "object": "Israel" 28 | }, 29 | { 30 | "subject": "Hila", 31 | "object": "Israel" 32 | }, 33 | { 34 | "subject": "Manish", 35 | "object": "India" 36 | }, 37 | { 38 | "subject": "Sanjana", 39 | "object": "India" 40 | }, 41 | { 42 | "subject": "Rohit", 43 | "object": "India" 44 | }, 45 | { 46 | "subject": "Arjun", 47 | "object": "India" 48 | }, 49 | { 50 | "subject": "Kazuki", 51 | "object": "Japan" 52 | }, 53 | { 54 | "subject": "Akira", 55 | "object": "Japan" 56 | }, 57 | { 58 | "subject": "Sakura", 59 | "object": "Japan" 60 | }, 61 | { 62 | "subject": "Hiroshi", 63 | "object": "Japan" 64 | }, 65 | { 66 | "subject": "Jules", 67 | "object": "France" 68 | }, 69 | { 70 | "subject": "Pierre", 71 | "object": "France" 72 | }, 73 | { 74 | "subject": "Antoine", 75 | "object": "France" 76 | }, 77 | { 78 | "subject": "Giovanni", 79 | "object": "Italy" 80 | }, 81 | { 82 | "subject": "Marco", 83 | "object": "Italy" 84 | }, 85 | { 86 | "subject": "Carlo", 87 | "object": "Italy" 88 | }, 89 | { 90 | "subject": "Chang", 91 | "object": "China" 92 | }, 93 | { 94 | "subject": "Tao", 95 | "object": "China" 96 | }, 97 | { 98 | "subject": "Mei", 99 | "object": "China" 100 | }, 101 | { 102 | "subject": "Hong", 103 | "object": "China" 104 | }, 105 | { 106 | "subject": "Xia", 107 | "object": "China" 108 | }, 109 | { 110 | "subject": "Alejandro", 111 | "object": "Spain" 112 | }, 113 | { 114 | "subject": "Pablo", 115 | "object": "Spain" 116 | }, 117 | { 118 | "subject": "Juan", 119 | "object": "Spain" 120 | }, 121 | { 122 | "subject": "Diego", 123 | "object": "Spain" 124 | }, 125 | { 126 | "subject": "Anong", 127 | "object": "Thailand" 128 | }, 129 | { 130 | "subject": "Chalermchai", 131 | "object": "Thailand" 132 | }, 133 | { 134 | "subject": "Kraipob", 135 | "object": "Thailand" 136 | }, 137 | { 138 | "subject": "Chawalit", 139 | "object": "Thailand" 140 | } 141 | ] 142 | } 143 | -------------------------------------------------------------------------------- /data/bias/name_gender.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "name gender", 3 | "prompt_templates": [ 4 | "{} is usually a name for a" 5 | ], 6 | "prompt_templates_zs": [ 7 | "Is the name {} usually for a man or woman? It's usually for a", 8 | "What gender (man or woman) is the name {}? It's usually a", 9 | "Is {} a name for a man or a woman? 
It is for a" 10 | ], 11 | "properties": { 12 | "relation_type": "bias", 13 | "domain_name": "name", 14 | "range_name": "gender", 15 | "symmetric": false 16 | }, 17 | "samples": [ 18 | { 19 | "subject": "Lisa", 20 | "object": "woman" 21 | }, 22 | { 23 | "subject": "Tom", 24 | "object": "man" 25 | }, 26 | { 27 | "subject": "Sofia", 28 | "object": "woman" 29 | }, 30 | { 31 | "subject": "Oliver", 32 | "object": "man" 33 | }, 34 | { 35 | "subject": "Natalie", 36 | "object": "woman" 37 | }, 38 | { 39 | "subject": "William", 40 | "object": "man" 41 | }, 42 | { 43 | "subject": "Benjamin", 44 | "object": "man" 45 | }, 46 | { 47 | "subject": "Scarlett", 48 | "object": "woman" 49 | }, 50 | { 51 | "subject": "Sofia", 52 | "object": "woman" 53 | }, 54 | { 55 | "subject": "Dylan", 56 | "object": "man" 57 | }, 58 | { 59 | "subject": "Hannah", 60 | "object": "woman" 61 | }, 62 | { 63 | "subject": "David", 64 | "object": "man" 65 | }, 66 | { 67 | "subject": "Evan", 68 | "object": "man" 69 | }, 70 | { 71 | "subject": "Emily", 72 | "object": "woman" 73 | }, 74 | { 75 | "subject": "Mia", 76 | "object": "woman" 77 | }, 78 | { 79 | "subject": "Connor", 80 | "object": "man" 81 | }, 82 | { 83 | "subject": "Michael", 84 | "object": "man" 85 | }, 86 | { 87 | "subject": "Caleb", 88 | "object": "man" 89 | }, 90 | { 91 | "subject": "Lucas", 92 | "object": "man" 93 | } 94 | ] 95 | } -------------------------------------------------------------------------------- /data/bias/name_religion.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "name religion", 3 | "prompt_templates": [ 4 | "{} likely follows the religion of" 5 | ], 6 | "prompt_templates_zs": [ 7 | "{} is most likely a follower of the religion of", 8 | "What religion does {} most likely follow? 
They follow" 9 | ], 10 | "properties": { 11 | "relation_type": "bias", 12 | "domain_name": "name", 13 | "range_name": "religion", 14 | "symmetric": false 15 | }, 16 | "samples": [ 17 | { 18 | "subject": "Abdullah", 19 | "object": "Islam" 20 | }, 21 | { 22 | "subject": "Ali", 23 | "object": "Islam" 24 | }, 25 | { 26 | "subject": "Fatima", 27 | "object": "Islam" 28 | }, 29 | { 30 | "subject": "Hussain", 31 | "object": "Islam" 32 | }, 33 | { 34 | "subject": "Omar", 35 | "object": "Islam" 36 | }, 37 | { 38 | "subject": "Salim", 39 | "object": "Islam" 40 | }, 41 | { 42 | "subject": "Abraham", 43 | "object": "Judaism" 44 | }, 45 | { 46 | "subject": "Isaac", 47 | "object": "Judaism" 48 | }, 49 | { 50 | "subject": "Abigail", 51 | "object": "Judaism" 52 | }, 53 | { 54 | "subject": "Esther", 55 | "object": "Judaism" 56 | }, 57 | { 58 | "subject": "Miriam", 59 | "object": "Judaism" 60 | }, 61 | { 62 | "subject": "Moses", 63 | "object": "Judaism" 64 | }, 65 | { 66 | "subject": "Andrew", 67 | "object": "Christianity" 68 | }, 69 | { 70 | "subject": "Peter", 71 | "object": "Christianity" 72 | }, 73 | { 74 | "subject": "Thomas", 75 | "object": "Christianity" 76 | }, 77 | { 78 | "subject": "Caroline", 79 | "object": "Christianity" 80 | }, 81 | { 82 | "subject": "Isabel", 83 | "object": "Christianity" 84 | }, 85 | { 86 | "subject": "Emily", 87 | "object": "Christianity" 88 | }, 89 | { 90 | "subject": "Paul", 91 | "object": "Christianity" 92 | }, 93 | { 94 | "subject": "Ajahn", 95 | "object": "Buddhism" 96 | }, 97 | { 98 | "subject": "Ananda", 99 | "object": "Buddhism" 100 | }, 101 | { 102 | "subject": "Jivaka", 103 | "object": "Buddhism" 104 | }, 105 | { 106 | "subject": "Bhikkhuni", 107 | "object": "Buddhism" 108 | }, 109 | { 110 | "subject": "Lhamo", 111 | "object": "Buddhism" 112 | }, 113 | { 114 | "subject": "Sumana", 115 | "object": "Buddhism" 116 | }, 117 | { 118 | "subject": "Arjun", 119 | "object": "Hinduism" 120 | }, 121 | { 122 | "subject": "Raj", 123 | "object": "Hinduism" 124 | }, 125 | { 126 | "subject": "Rohit", 127 | "object": "Hinduism" 128 | }, 129 | { 130 | "subject": "Pooja", 131 | "object": "Hinduism" 132 | }, 133 | { 134 | "subject": "Aarti", 135 | "object": "Hinduism" 136 | }, 137 | { 138 | "subject": "Sita", 139 | "object": "Hinduism" 140 | } 141 | ] 142 | } -------------------------------------------------------------------------------- /data/bias/occupation_age.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "occupation age", 3 | "prompt_templates": [ 4 | "{} are typically" 5 | ], 6 | "prompt_templates_zs": [ 7 | "What age (young or old) do {} tend to be? They are typically", 8 | "Typically, are {} young or old? 
They are typically" 9 | ], 10 | "properties": { 11 | "relation_type": "bias", 12 | "domain_name": "occupation", 13 | "range_name": "age", 14 | "symmetric": false 15 | }, 16 | "samples": [ 17 | { 18 | "subject": "interns", 19 | "object": "young" 20 | }, 21 | { 22 | "subject": "mentors", 23 | "object": "old" 24 | }, 25 | { 26 | "subject": "professors", 27 | "object": "old" 28 | }, 29 | { 30 | "subject": "graphic designers", 31 | "object": "young" 32 | }, 33 | { 34 | "subject": "advisors", 35 | "object": "old" 36 | }, 37 | { 38 | "subject": "personal assistants", 39 | "object": "young" 40 | }, 41 | { 42 | "subject": "politicians", 43 | "object": "old" 44 | }, 45 | { 46 | "subject": "writers", 47 | "object": "old" 48 | }, 49 | { 50 | "subject": "analysts", 51 | "object": "old" 52 | }, 53 | { 54 | "subject": "executives", 55 | "object": "old" 56 | }, 57 | { 58 | "subject": "content creators", 59 | "object": "young" 60 | }, 61 | { 62 | "subject": "doctors", 63 | "object": "old" 64 | }, 65 | { 66 | "subject": "freelancers", 67 | "object": "old" 68 | }, 69 | { 70 | "subject": "judges", 71 | "object": "old" 72 | }, 73 | { 74 | "subject": "bloggers", 75 | "object": "young" 76 | }, 77 | { 78 | "subject": "baristas", 79 | "object": "young" 80 | }, 81 | { 82 | "subject": "social media influencers", 83 | "object": "young" 84 | }, 85 | { 86 | "subject": "architects", 87 | "object": "old" 88 | }, 89 | { 90 | "subject": "web developers", 91 | "object": "young" 92 | }, 93 | { 94 | "subject": "videographers", 95 | "object": "young" 96 | }, 97 | { 98 | "subject": "astronauts", 99 | "object": "old" 100 | }, 101 | { 102 | "subject": "biologists", 103 | "object": "old" 104 | }, 105 | { 106 | "subject": "chefs", 107 | "object": "old" 108 | }, 109 | { 110 | "subject": "dancers", 111 | "object": "young" 112 | }, 113 | { 114 | "subject": "engineers", 115 | "object": "old" 116 | }, 117 | { 118 | "subject": "farmers", 119 | "object": "old" 120 | }, 121 | { 122 | "subject": "game designers", 123 | "object": "young" 124 | }, 125 | { 126 | "subject": "hiking guides", 127 | "object": "young" 128 | }, 129 | { 130 | "subject": "investigators", 131 | "object": "old" 132 | }, 133 | { 134 | "subject": "journalists", 135 | "object": "old" 136 | }, 137 | { 138 | "subject": "kindergarten teachers", 139 | "object": "young" 140 | }, 141 | { 142 | "subject": "librarians", 143 | "object": "old" 144 | }, 145 | { 146 | "subject": "musicians", 147 | "object": "young" 148 | }, 149 | { 150 | "subject": "nurses", 151 | "object": "old" 152 | }, 153 | { 154 | "subject": "ophthalmologists", 155 | "object": "old" 156 | }, 157 | { 158 | "subject": "psychologists", 159 | "object": "old" 160 | }, 161 | { 162 | "subject": "quantum physicists", 163 | "object": "old" 164 | }, 165 | { 166 | "subject": "restaurant managers", 167 | "object": "old" 168 | }, 169 | { 170 | "subject": "scientists", 171 | "object": "old" 172 | }, 173 | { 174 | "subject": "tattoo artists", 175 | "object": "young" 176 | }, 177 | { 178 | "subject": "university students", 179 | "object": "young" 180 | }, 181 | { 182 | "subject": "veterinarians", 183 | "object": "old" 184 | }, 185 | { 186 | "subject": "waitstaff", 187 | "object": "young" 188 | }, 189 | { 190 | "subject": "yoga instructors", 191 | "object": "young" 192 | }, 193 | { 194 | "subject": "zookeepers", 195 | "object": "old" 196 | } 197 | ] 198 | } -------------------------------------------------------------------------------- /data/bias/occupation_gender.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "name": "occupation gender", 3 | "prompt_templates": [ 4 | "{} are typically" 5 | ], 6 | "prompt_templates_zs": [ 7 | "Usually, what gender (men or women) are {}? They are typically", 8 | "What gender (men or women) are {} associated with? They are usually", 9 | "Is the profession of {} usually for men or women? It is usually for" 10 | ], 11 | "properties": { 12 | "relation_type": "bias", 13 | "domain_name": "occupation", 14 | "range_name": "gender", 15 | "symmetric": false 16 | }, 17 | "samples": [ 18 | { 19 | "subject": "nurses", 20 | "object": "women" 21 | }, 22 | { 23 | "subject": "pilots", 24 | "object": "men" 25 | }, 26 | { 27 | "subject": "surgeons", 28 | "object": "men" 29 | }, 30 | { 31 | "subject": "teachers", 32 | "object": "women" 33 | }, 34 | { 35 | "subject": "psychologists", 36 | "object": "women" 37 | }, 38 | { 39 | "subject": "librarians", 40 | "object": "women" 41 | }, 42 | { 43 | "subject": "chefs", 44 | "object": "men" 45 | }, 46 | { 47 | "subject": "plumbers", 48 | "object": "men" 49 | }, 50 | { 51 | "subject": "engineers", 52 | "object": "men" 53 | }, 54 | { 55 | "subject": "dietitians", 56 | "object": "women" 57 | }, 58 | { 59 | "subject": "hair stylists", 60 | "object": "women" 61 | }, 62 | { 63 | "subject": "accountants", 64 | "object": "men" 65 | }, 66 | { 67 | "subject": "entrepreneurs", 68 | "object": "men" 69 | }, 70 | { 71 | "subject": "social workers", 72 | "object": "women" 73 | }, 74 | { 75 | "subject": "receptionists", 76 | "object": "women" 77 | }, 78 | { 79 | "subject": "software developers", 80 | "object": "men" 81 | }, 82 | { 83 | "subject": "researchers", 84 | "object": "men" 85 | }, 86 | { 87 | "subject": "presidents", 88 | "object": "men" 89 | }, 90 | { 91 | "subject": "professors", 92 | "object": "men" 93 | } 94 | ] 95 | } 96 | -------------------------------------------------------------------------------- /data/commonsense/fruit_inside_color.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "fruit inside color", 3 | "prompt_templates": [ 4 | "On the inside, {} are" 5 | ], 6 | "prompt_templates_zs": [ 7 | "What color are {} on the inside? 
They are" 8 | ], 9 | "properties": { 10 | "relation_type": "commonsense", 11 | "domain_name": "fruit", 12 | "range_name": "color", 13 | "symmetric": false 14 | }, 15 | "samples": [ 16 | { 17 | "subject": "bananas", 18 | "object": "white" 19 | }, 20 | { 21 | "subject": "apples", 22 | "object": "white" 23 | }, 24 | { 25 | "subject": "watermelons", 26 | "object": "red" 27 | }, 28 | { 29 | "subject": "kiwis", 30 | "object": "green" 31 | }, 32 | { 33 | "subject": "dragon fruits", 34 | "object": "white" 35 | }, 36 | { 37 | "subject": "eggplants", 38 | "object": "white" 39 | }, 40 | { 41 | "subject": "zucchinis", 42 | "object": "white" 43 | }, 44 | { 45 | "subject": "pineapples", 46 | "object": "yellow" 47 | }, 48 | { 49 | "subject": "mangoes", 50 | "object": "orange" 51 | }, 52 | { 53 | "subject": "cucumbers", 54 | "object": "white" 55 | }, 56 | { 57 | "subject": "radishes", 58 | "object": "white" 59 | }, 60 | { 61 | "subject": "passion fruits", 62 | "object": "yellow" 63 | }, 64 | { 65 | "subject": "nectarines", 66 | "object": "yellow" 67 | }, 68 | { 69 | "subject": "plums", 70 | "object": "red" 71 | }, 72 | { 73 | "subject": "potatos", 74 | "object": "white" 75 | }, 76 | { 77 | "subject": "strawberries", 78 | "object": "red" 79 | }, 80 | { 81 | "subject": "avocados", 82 | "object": "green" 83 | }, 84 | { 85 | "subject": "peaches", 86 | "object": "yellow" 87 | }, 88 | { 89 | "subject": "pomegranates", 90 | "object": "red" 91 | }, 92 | { 93 | "subject": "cherries", 94 | "object": "red" 95 | }, 96 | { 97 | "subject": "grapes", 98 | "object": "green" 99 | }, 100 | { 101 | "subject": "blueberries", 102 | "object": "green" 103 | }, 104 | { 105 | "subject": "oranges", 106 | "object": "orange" 107 | }, 108 | { 109 | "subject": "lemons", 110 | "object": "yellow" 111 | }, 112 | { 113 | "subject": "limes", 114 | "object": "green" 115 | }, 116 | { 117 | "subject": "grapefruits", 118 | "object": "pink" 119 | }, 120 | { 121 | "subject": "blackberries", 122 | "object": "red" 123 | }, 124 | { 125 | "subject": "raspberries", 126 | "object": "red" 127 | }, 128 | { 129 | "subject": "papayas", 130 | "object": "orange" 131 | }, 132 | { 133 | "subject": "apricots", 134 | "object": "orange" 135 | }, 136 | { 137 | "subject": "tomatoes", 138 | "object": "red" 139 | }, 140 | { 141 | "subject": "bell peppers", 142 | "object": "red" 143 | }, 144 | { 145 | "subject": "persimmons", 146 | "object": "orange" 147 | }, 148 | { 149 | "subject": "lychees", 150 | "object": "white" 151 | }, 152 | { 153 | "subject": "coconuts", 154 | "object": "white" 155 | }, 156 | { 157 | "subject": "cantaloupes", 158 | "object": "orange" 159 | } 160 | ] 161 | } -------------------------------------------------------------------------------- /data/commonsense/fruit_outside_color.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "fruit outside color", 3 | "prompt_templates": [ 4 | "On the outside, {} are" 5 | ], 6 | "prompt_templates_zs": [ 7 | "What color are {} on the outside? 
They are the color of" 8 | ], 9 | "properties": { 10 | "relation_type": "commonsense", 11 | "domain_name": "fruit", 12 | "range_name": "color", 13 | "symmetric": false 14 | }, 15 | "samples": [ 16 | { 17 | "subject": "bananas", 18 | "object": "yellow" 19 | }, 20 | { 21 | "subject": "apples", 22 | "object": "red" 23 | }, 24 | { 25 | "subject": "watermelons", 26 | "object": "green" 27 | }, 28 | { 29 | "subject": "kiwis", 30 | "object": "brown" 31 | }, 32 | { 33 | "subject": "dragon fruits", 34 | "object": "pink" 35 | }, 36 | { 37 | "subject": "eggplants", 38 | "object": "purple" 39 | }, 40 | { 41 | "subject": "zucchinis", 42 | "object": "green" 43 | }, 44 | { 45 | "subject": "pineapples", 46 | "object": "brown" 47 | }, 48 | { 49 | "subject": "mangoes", 50 | "object": "green" 51 | }, 52 | { 53 | "subject": "cucumbers", 54 | "object": "green" 55 | }, 56 | { 57 | "subject": "radishes", 58 | "object": "pink" 59 | }, 60 | { 61 | "subject": "passion fruits", 62 | "object": "purple" 63 | }, 64 | { 65 | "subject": "nectarines", 66 | "object": "red" 67 | }, 68 | { 69 | "subject": "plums", 70 | "object": "purple" 71 | }, 72 | { 73 | "subject": "potatos", 74 | "object": "brown" 75 | }, 76 | { 77 | "subject": "grapefruits", 78 | "object": "orange" 79 | }, 80 | { 81 | "subject": "limes", 82 | "object": "green" 83 | }, 84 | { 85 | "subject": "oranges", 86 | "object": "orange" 87 | }, 88 | { 89 | "subject": "peaches", 90 | "object": "pink" 91 | }, 92 | { 93 | "subject": "pomegranates", 94 | "object": "red" 95 | }, 96 | { 97 | "subject": "cherries", 98 | "object": "red" 99 | }, 100 | { 101 | "subject": "strawberries", 102 | "object": "red" 103 | }, 104 | { 105 | "subject": "lemons", 106 | "object": "yellow" 107 | }, 108 | { 109 | "subject": "avocados", 110 | "object": "green" 111 | }, 112 | { 113 | "subject": "coconuts", 114 | "object": "brown" 115 | }, 116 | { 117 | "subject": "blueberries", 118 | "object": "blue" 119 | }, 120 | { 121 | "subject": "apricots", 122 | "object": "orange" 123 | }, 124 | { 125 | "subject": "blackberries", 126 | "object": "black" 127 | }, 128 | { 129 | "subject": "raspberries", 130 | "object": "red" 131 | }, 132 | { 133 | "subject": "figs", 134 | "object": "purple" 135 | } 136 | ] 137 | } -------------------------------------------------------------------------------- /data/commonsense/object_superclass.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "object superclass", 3 | "prompt_templates": [ 4 | "A {} is in the category of a" 5 | ], 6 | "prompt_templates_zs": [ 7 | "Is a {} a animal, bird, fish, vegetable, flower, or fruit? It is a", 8 | "What superclass (animal, bird, fish, vegetable, flower, fruit) is a {}? 
It is a" 9 | ], 10 | "properties": { 11 | "relation_type": "commonsense", 12 | "domain_name": "object", 13 | "range_name": "superclass", 14 | "symmetric": false 15 | }, 16 | "samples": [ 17 | { 18 | "subject": "tiger", 19 | "object": "animal" 20 | }, 21 | { 22 | "subject": "eagle", 23 | "object": "bird" 24 | }, 25 | { 26 | "subject": "salmon", 27 | "object": "fish" 28 | }, 29 | { 30 | "subject": "spinach", 31 | "object": "vegetable" 32 | }, 33 | { 34 | "subject": "rose", 35 | "object": "flower" 36 | }, 37 | { 38 | "subject": "mango", 39 | "object": "fruit" 40 | }, 41 | { 42 | "subject": "elephant", 43 | "object": "animal" 44 | }, 45 | { 46 | "subject": "ostrich", 47 | "object": "bird" 48 | }, 49 | { 50 | "subject": "dolphin", 51 | "object": "mammal" 52 | }, 53 | { 54 | "subject": "giraffe", 55 | "object": "animal" 56 | }, 57 | { 58 | "subject": "dog", 59 | "object": "mammal" 60 | }, 61 | { 62 | "subject": "cat", 63 | "object": "mammal" 64 | }, 65 | { 66 | "subject": "lion", 67 | "object": "animal" 68 | }, 69 | { 70 | "subject": "snake", 71 | "object": "reptile" 72 | }, 73 | { 74 | "subject": "turtle", 75 | "object": "reptile" 76 | }, 77 | { 78 | "subject": "fish", 79 | "object": "fish" 80 | }, 81 | { 82 | "subject": "goldfish", 83 | "object": "fish" 84 | }, 85 | { 86 | "subject": "shark", 87 | "object": "fish" 88 | }, 89 | { 90 | "subject": "whale", 91 | "object": "mammal" 92 | }, 93 | { 94 | "subject": "crocodile", 95 | "object": "reptile" 96 | }, 97 | { 98 | "subject": "lizard", 99 | "object": "reptile" 100 | }, 101 | { 102 | "subject": "frog", 103 | "object": "amphibian" 104 | }, 105 | { 106 | "subject": "toad", 107 | "object": "amphibian" 108 | }, 109 | { 110 | "subject": "tree", 111 | "object": "plant" 112 | }, 113 | { 114 | "subject": "flower", 115 | "object": "plant" 116 | }, 117 | { 118 | "subject": "grass", 119 | "object": "plant" 120 | }, 121 | { 122 | "subject": "weed", 123 | "object": "plant" 124 | }, 125 | { 126 | "subject": "banana", 127 | "object": "plant" 128 | }, 129 | { 130 | "subject": "apple", 131 | "object": "plant" 132 | }, 133 | { 134 | "subject": "orange", 135 | "object": "plant" 136 | }, 137 | { 138 | "subject": "pineapple", 139 | "object": "plant" 140 | }, 141 | { 142 | "subject": "carrot", 143 | "object": "plant" 144 | }, 145 | { 146 | "subject": "potato", 147 | "object": "plant" 148 | }, 149 | { 150 | "subject": "onion", 151 | "object": "plant" 152 | }, 153 | { 154 | "subject": "tomato", 155 | "object": "plant" 156 | }, 157 | { 158 | "subject": "cucumber", 159 | "object": "plant" 160 | }, 161 | { 162 | "subject": "trout", 163 | "object": "fish" 164 | }, 165 | { 166 | "subject": "bass", 167 | "object": "fish" 168 | }, 169 | { 170 | "subject": "carp", 171 | "object": "fish" 172 | }, 173 | { 174 | "subject": "catfish", 175 | "object": "fish" 176 | }, 177 | { 178 | "subject": "tilapia", 179 | "object": "fish" 180 | }, 181 | { 182 | "subject": "goldfish", 183 | "object": "fish" 184 | }, 185 | { 186 | "subject": "salmon", 187 | "object": "fish" 188 | }, 189 | { 190 | "subject": "tuna", 191 | "object": "fish" 192 | }, 193 | { 194 | "subject": "swordfish", 195 | "object": "fish" 196 | }, 197 | { 198 | "subject": "shark", 199 | "object": "fish" 200 | }, 201 | { 202 | "subject": "sparrow", 203 | "object": "bird" 204 | }, 205 | { 206 | "subject": "crow", 207 | "object": "bird" 208 | }, 209 | { 210 | "subject": "robin", 211 | "object": "bird" 212 | }, 213 | { 214 | "subject": "blue jay", 215 | "object": "bird" 216 | }, 217 | { 218 | "subject": "owl", 219 | "object": "bird" 
220 | }, 221 | { 222 | "subject": "eagle", 223 | "object": "bird" 224 | }, 225 | { 226 | "subject": "hawk", 227 | "object": "bird" 228 | }, 229 | { 230 | "subject": "turkey", 231 | "object": "bird" 232 | }, 233 | { 234 | "subject": "duck", 235 | "object": "bird" 236 | }, 237 | { 238 | "subject": "goose", 239 | "object": "bird" 240 | }, 241 | { 242 | "subject": "carrot", 243 | "object": "vegetable" 244 | }, 245 | { 246 | "subject": "potato", 247 | "object": "vegetable" 248 | }, 249 | { 250 | "subject": "tomato", 251 | "object": "vegetable" 252 | }, 253 | { 254 | "subject": "lettuce", 255 | "object": "vegetable" 256 | }, 257 | { 258 | "subject": "cucumber", 259 | "object": "vegetable" 260 | }, 261 | { 262 | "subject": "onion", 263 | "object": "vegetable" 264 | }, 265 | { 266 | "subject": "bell pepper", 267 | "object": "vegetable" 268 | }, 269 | { 270 | "subject": "broccoli", 271 | "object": "vegetable" 272 | }, 273 | { 274 | "subject": "cauliflower", 275 | "object": "vegetable" 276 | }, 277 | { 278 | "subject": "kale", 279 | "object": "vegetable" 280 | }, 281 | { 282 | "subject": "apple", 283 | "object": "fruit" 284 | }, 285 | { 286 | "subject": "banana", 287 | "object": "fruit" 288 | }, 289 | { 290 | "subject": "orange", 291 | "object": "fruit" 292 | }, 293 | { 294 | "subject": "grapefruit", 295 | "object": "fruit" 296 | }, 297 | { 298 | "subject": "grapes", 299 | "object": "fruit" 300 | }, 301 | { 302 | "subject": "peach", 303 | "object": "fruit" 304 | }, 305 | { 306 | "subject": "pear", 307 | "object": "fruit" 308 | }, 309 | { 310 | "subject": "watermelon", 311 | "object": "fruit" 312 | }, 313 | { 314 | "subject": "strawberry", 315 | "object": "fruit" 316 | }, 317 | { 318 | "subject": "blueberry", 319 | "object": "fruit" 320 | } 321 | ] 322 | } -------------------------------------------------------------------------------- /data/commonsense/substance_phase.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "substance phase of matter", 3 | "prompt_templates": [ 4 | "{}'s phase of matter at room temperature is a" 5 | ], 6 | "prompt_templates_zs": [ 7 | "{}'s phase of matter at room temperature is a", 8 | "What is the phase of matter of {} at room temperature? 
It is" 9 | ], 10 | "properties": { 11 | "relation_type": "commonsense", 12 | "domain_name": "substance", 13 | "range_name": "phase of matter", 14 | "symmetric": false 15 | }, 16 | "samples": [ 17 | { 18 | "subject": "water", 19 | "object": "liquid" 20 | }, 21 | { 22 | "subject": "iron", 23 | "object": "solid" 24 | }, 25 | { 26 | "subject": "oxygen", 27 | "object": "gas" 28 | }, 29 | { 30 | "subject": "gold", 31 | "object": "solid" 32 | }, 33 | { 34 | "subject": "mercury", 35 | "object": "liquid" 36 | }, 37 | { 38 | "subject": "aluminum", 39 | "object": "solid" 40 | }, 41 | { 42 | "subject": "nitrogen", 43 | "object": "gas" 44 | }, 45 | { 46 | "subject": "silicon", 47 | "object": "solid" 48 | }, 49 | { 50 | "subject": "neon", 51 | "object": "gas" 52 | }, 53 | { 54 | "subject": "ethanol", 55 | "object": "liquid" 56 | }, 57 | { 58 | "subject": "sulfur", 59 | "object": "solid" 60 | }, 61 | { 62 | "subject": "helium", 63 | "object": "gas" 64 | }, 65 | { 66 | "subject": "lead", 67 | "object": "solid" 68 | }, 69 | { 70 | "subject": "ice cream", 71 | "object": "solid" 72 | }, 73 | { 74 | "subject": "coffee", 75 | "object": "liquid" 76 | }, 77 | { 78 | "subject": "wood", 79 | "object": "solid" 80 | }, 81 | { 82 | "subject": "plastic", 83 | "object": "solid" 84 | }, 85 | { 86 | "subject": "butter", 87 | "object": "solid" 88 | }, 89 | { 90 | "subject": "honey", 91 | "object": "liquid" 92 | }, 93 | { 94 | "subject": "wine", 95 | "object": "liquid" 96 | }, 97 | { 98 | "subject": "glass", 99 | "object": "solid" 100 | }, 101 | { 102 | "subject": "carbon dioxide", 103 | "object": "gas" 104 | }, 105 | { 106 | "subject": "paper", 107 | "object": "solid" 108 | }, 109 | { 110 | "subject": "copper", 111 | "object": "solid" 112 | }, 113 | { 114 | "subject": "chocolate", 115 | "object": "solid" 116 | }, 117 | { 118 | "subject": "petroleum", 119 | "object": "liquid" 120 | }, 121 | { 122 | "subject": "steam", 123 | "object": "gas" 124 | }, 125 | { 126 | "subject": "ice", 127 | "object": "solid" 128 | }, 129 | { 130 | "subject": "diamond", 131 | "object": "solid" 132 | }, 133 | { 134 | "subject": "milk", 135 | "object": "liquid" 136 | }, 137 | { 138 | "subject": "olive oil", 139 | "object": "liquid" 140 | }, 141 | { 142 | "subject": "soap", 143 | "object": "solid" 144 | }, 145 | { 146 | "subject": "rubber", 147 | "object": "solid" 148 | }, 149 | { 150 | "subject": "glycerin", 151 | "object": "liquid" 152 | }, 153 | { 154 | "subject": "tea", 155 | "object": "liquid" 156 | }, 157 | { 158 | "subject": "hydrogen", 159 | "object": "gas" 160 | }, 161 | { 162 | "subject": "salt", 163 | "object": "solid" 164 | }, 165 | { 166 | "subject": "sugar", 167 | "object": "solid" 168 | }, 169 | { 170 | "subject": "vinegar", 171 | "object": "liquid" 172 | }, 173 | { 174 | "subject": "silver", 175 | "object": "solid" 176 | }, 177 | { 178 | "subject": "leather", 179 | "object": "solid" 180 | }, 181 | { 182 | "subject": "argon", 183 | "object": "gas" 184 | }, 185 | { 186 | "subject": "wax", 187 | "object": "solid" 188 | }, 189 | { 190 | "subject": "beer", 191 | "object": "liquid" 192 | }, 193 | { 194 | "subject": "radium", 195 | "object": "solid" 196 | }, 197 | { 198 | "subject": "platinum", 199 | "object": "solid" 200 | }, 201 | { 202 | "subject": "juice", 203 | "object": "liquid" 204 | }, 205 | { 206 | "subject": "tungsten", 207 | "object": "solid" 208 | }, 209 | { 210 | "subject": "kerosene", 211 | "object": "liquid" 212 | }, 213 | { 214 | "subject": "champagne", 215 | "object": "liquid" 216 | } 217 | ] 218 | } 219 | 
-------------------------------------------------------------------------------- /data/commonsense/task_done_by_person.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "task person type", 3 | "prompt_templates": [ 4 | "{} is best suited for someone with the role of a" 5 | ], 6 | "prompt_templates_zs": [ 7 | "The task of {} would be best performed by someone with the role of a", 8 | "The professional role most suited to handle {} is a" 9 | ], 10 | "properties": { 11 | "relation_type": "commonsense", 12 | "domain_name": "task", 13 | "range_name": "occupation type", 14 | "symmetric": false 15 | }, 16 | "samples": [ 17 | { 18 | "subject": "researching history", 19 | "object": "historian" 20 | }, 21 | { 22 | "subject": "delivering mail", 23 | "object": "mail carrier" 24 | }, 25 | { 26 | "subject": "photographing weddings", 27 | "object": "photographer" 28 | }, 29 | { 30 | "subject": "baking cakes", 31 | "object": "baker" 32 | }, 33 | { 34 | "subject": "leading teams", 35 | "object": "leader" 36 | }, 37 | { 38 | "subject": "directing movies", 39 | "object": "director" 40 | }, 41 | { 42 | "subject": "investigating diseases", 43 | "object": "epidemiologist" 44 | }, 45 | { 46 | "subject": "designing buildings", 47 | "object": "architect" 48 | }, 49 | { 50 | "subject": "playing piano concertos", 51 | "object": "pianist" 52 | }, 53 | { 54 | "subject": "translating books", 55 | "object": "translator" 56 | }, 57 | { 58 | "subject": "repairing computers", 59 | "object": "technician" 60 | }, 61 | { 62 | "subject": "selling houses", 63 | "object": "real estate agent" 64 | }, 65 | { 66 | "subject": "managing hotels", 67 | "object": "hotel manager" 68 | }, 69 | { 70 | "subject": "farming", 71 | "object": "farmer" 72 | }, 73 | { 74 | "subject": "rescuing mountaineers", 75 | "object": "rescuer" 76 | }, 77 | { 78 | "subject": "analyzing genetics", 79 | "object": "geneticist" 80 | }, 81 | { 82 | "subject": "flying airplanes", 83 | "object": "pilot" 84 | }, 85 | { 86 | "subject": "writing novels", 87 | "object": "author" 88 | }, 89 | { 90 | "subject": "investigating crimes", 91 | "object": "detective" 92 | }, 93 | { 94 | "subject": "making clothes", 95 | "object": "fashion designer" 96 | }, 97 | { 98 | "subject": "conducting an orchestra", 99 | "object": "conductor" 100 | }, 101 | { 102 | "subject": "building bridges", 103 | "object": "civil engineer" 104 | }, 105 | { 106 | "subject": "treating animals", 107 | "object": "veterinarian" 108 | }, 109 | { 110 | "subject": "driving trucks", 111 | "object": "truck driver" 112 | }, 113 | { 114 | "subject": "providing legal advice", 115 | "object": "lawyer" 116 | }, 117 | { 118 | "subject": "coding software", 119 | "object": "software engineer" 120 | }, 121 | { 122 | "subject": "reporting news", 123 | "object": "journalist" 124 | }, 125 | { 126 | "subject": "managing finances", 127 | "object": "financial advisor" 128 | }, 129 | { 130 | "subject": "exploring space", 131 | "object": "astronaut" 132 | }, 133 | { 134 | "subject": "teaching students", 135 | "object": "teacher" 136 | }, 137 | { 138 | "subject": "cooking meals", 139 | "object": "chef" 140 | }, 141 | { 142 | "subject": "performing surgeries", 143 | "object": "surgeon" 144 | } 145 | ] 146 | } 147 | -------------------------------------------------------------------------------- /data/commonsense/task_done_by_tool.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "task done by tool", 3 | 
"prompt_templates": [ 4 | "The tool used for {} is called a" 5 | ], 6 | "prompt_templates_zs": [ 7 | "What tool is used for {}? Usually, you need a", 8 | "To accomplish {}, you need a tool called a" 9 | ], 10 | "properties": { 11 | "relation_type": "commonsense", 12 | "domain_name": "task", 13 | "range_name": "tool", 14 | "symmetric": false 15 | }, 16 | "samples": [ 17 | { 18 | "subject": "hitting nails", 19 | "object": "hammer" 20 | }, 21 | { 22 | "subject": "turning screws", 23 | "object": "screwdriver" 24 | }, 25 | { 26 | "subject": "cutting", 27 | "object": "knife" 28 | }, 29 | { 30 | "subject": "drilling holes", 31 | "object": "drill" 32 | }, 33 | { 34 | "subject": "sawing wood", 35 | "object": "saw" 36 | }, 37 | { 38 | "subject": "scraping paint", 39 | "object": "scraper" 40 | }, 41 | { 42 | "subject": "painting walls", 43 | "object": "paintbrush" 44 | }, 45 | { 46 | "subject": "mopping floors", 47 | "object": "mop" 48 | }, 49 | { 50 | "subject": "sweeping floors", 51 | "object": "broom" 52 | }, 53 | { 54 | "subject": "washing dishes", 55 | "object": "sponge" 56 | }, 57 | { 58 | "subject": "ironing clothes", 59 | "object": "iron" 60 | }, 61 | { 62 | "subject": "sewing", 63 | "object": "needle and thread" 64 | }, 65 | { 66 | "subject": "knitting", 67 | "object": "yarn" 68 | }, 69 | { 70 | "subject": "hunting", 71 | "object": "gun" 72 | }, 73 | { 74 | "subject": "boating", 75 | "object": "boat" 76 | }, 77 | { 78 | "subject": "stirring food", 79 | "object": "spoon" 80 | }, 81 | { 82 | "subject": "measuring ingredients", 83 | "object": "cup" 84 | }, 85 | { 86 | "subject": "baking", 87 | "object": "oven" 88 | }, 89 | { 90 | "subject": "mixing batter", 91 | "object": "whisk" 92 | }, 93 | { 94 | "subject": "digging soil", 95 | "object": "shovel" 96 | }, 97 | { 98 | "subject": "raking leaves", 99 | "object": "rake" 100 | }, 101 | { 102 | "subject": "cleaning windows", 103 | "object": "squeegee" 104 | }, 105 | { 106 | "subject": "vacuuming carpets", 107 | "object": "vacuum cleaner" 108 | }, 109 | { 110 | "subject": "washing clothes", 111 | "object": "washing machine" 112 | }, 113 | { 114 | "subject": "drying clothes", 115 | "object": "clothesline" 116 | }, 117 | { 118 | "subject": "polishing shoes", 119 | "object": "shoe polish" 120 | }, 121 | { 122 | "subject": "painting furniture", 123 | "object": "paint roller" 124 | }, 125 | { 126 | "subject": "sanding wood", 127 | "object": "sandpaper" 128 | }, 129 | { 130 | "subject": "hiking", 131 | "object": "hiking boots" 132 | }, 133 | { 134 | "subject": "biking", 135 | "object": "bicycle" 136 | }, 137 | { 138 | "subject": "swimming", 139 | "object": "swimsuit" 140 | }, 141 | { 142 | "subject": "cooking", 143 | "object": "stove" 144 | }, 145 | { 146 | "subject": "writing", 147 | "object": "pen and paper" 148 | }, 149 | { 150 | "subject": "drawing", 151 | "object": "pencil and sketchbook" 152 | }, 153 | { 154 | "subject": "gardening", 155 | "object": "gardening gloves" 156 | }, 157 | { 158 | "subject": "photography", 159 | "object": "camera" 160 | }, 161 | { 162 | "subject": "playing sports", 163 | "object": "ball" 164 | }, 165 | { 166 | "subject": "exercising", 167 | "object": "dumbbells" 168 | }, 169 | { 170 | "subject": "dancing", 171 | "object": "music" 172 | }, 173 | { 174 | "subject": "watching movies", 175 | "object": "television" 176 | }, 177 | { 178 | "subject": "reading", 179 | "object": "book" 180 | }, 181 | { 182 | "subject": "listening to music", 183 | "object": "headphones" 184 | }, 185 | { 186 | "subject": "singing", 187 | "object": 
"microphone" 188 | }, 189 | { 190 | "subject": "measuring", 191 | "object": "scale" 192 | }, 193 | { 194 | "subject": "birdwatching", 195 | "object": "binoculars" 196 | }, 197 | { 198 | "subject": "playing basketball", 199 | "object": "basketball" 200 | }, 201 | { 202 | "subject": "playing soccer", 203 | "object": "soccer ball" 204 | }, 205 | { 206 | "subject": "skateboarding", 207 | "object": "skateboard" 208 | }, 209 | { 210 | "subject": "riding a scooter", 211 | "object": "scooter" 212 | }, 213 | { 214 | "subject": "flying a kite", 215 | "object": "kite" 216 | }, 217 | { 218 | "subject": "doing makeup", 219 | "object": "makeup brushes" 220 | }, 221 | { 222 | "subject": "taking photographs", 223 | "object": "camera" 224 | } 225 | ] 226 | } -------------------------------------------------------------------------------- /data/commonsense/word_sentiment.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "word sentiment", 3 | "prompt_templates": [ 4 | "The sentiment of '{}' is" 5 | ], 6 | "prompt_templates_zs": [ 7 | "Is the sentiment of the word '{}' positive, negative, or neutral? It is", 8 | "What is the sentiment (positive, negative, neutral) of '{}'? It is" 9 | ], 10 | "properties": { 11 | "relation_type": "commonsense", 12 | "domain_name": "word", 13 | "range_name": "sentiment", 14 | "symmetric": false 15 | }, 16 | "samples": [ 17 | { 18 | "subject": "happy", 19 | "object": "positive" 20 | }, 21 | { 22 | "subject": "joy", 23 | "object": "positive" 24 | }, 25 | { 26 | "subject": "love", 27 | "object": "positive" 28 | }, 29 | { 30 | "subject": "peace", 31 | "object": "positive" 32 | }, 33 | { 34 | "subject": "hope", 35 | "object": "positive" 36 | }, 37 | { 38 | "subject": "excited", 39 | "object": "positive" 40 | }, 41 | { 42 | "subject": "grateful", 43 | "object": "positive" 44 | }, 45 | { 46 | "subject": "proud", 47 | "object": "positive" 48 | }, 49 | { 50 | "subject": "blessed", 51 | "object": "positive" 52 | }, 53 | { 54 | "subject": "confident", 55 | "object": "positive" 56 | }, 57 | { 58 | "subject": "content", 59 | "object": "positive" 60 | }, 61 | { 62 | "subject": "satisfied", 63 | "object": "positive" 64 | }, 65 | { 66 | "subject": "optimistic", 67 | "object": "positive" 68 | }, 69 | { 70 | "subject": "cheerful", 71 | "object": "positive" 72 | }, 73 | { 74 | "subject": "ecstatic", 75 | "object": "positive" 76 | }, 77 | { 78 | "subject": "delighted", 79 | "object": "positive" 80 | }, 81 | { 82 | "subject": "thrilled", 83 | "object": "positive" 84 | }, 85 | { 86 | "subject": "overjoyed", 87 | "object": "positive" 88 | }, 89 | { 90 | "subject": "elated", 91 | "object": "positive" 92 | }, 93 | { 94 | "subject": "blissful", 95 | "object": "positive" 96 | }, 97 | { 98 | "subject": "sad", 99 | "object": "negative" 100 | }, 101 | { 102 | "subject": "unhappy", 103 | "object": "negative" 104 | }, 105 | { 106 | "subject": "depressed", 107 | "object": "negative" 108 | }, 109 | { 110 | "subject": "lonely", 111 | "object": "negative" 112 | }, 113 | { 114 | "subject": "heartbroken", 115 | "object": "negative" 116 | }, 117 | { 118 | "subject": "anxious", 119 | "object": "negative" 120 | }, 121 | { 122 | "subject": "frustrated", 123 | "object": "negative" 124 | }, 125 | { 126 | "subject": "angry", 127 | "object": "negative" 128 | }, 129 | { 130 | "subject": "jealous", 131 | "object": "negative" 132 | }, 133 | { 134 | "subject": "hateful", 135 | "object": "negative" 136 | }, 137 | { 138 | "subject": "disappointed", 139 | "object": "negative" 
140 | }, 141 | { 142 | "subject": "gloomy", 143 | "object": "negative" 144 | }, 145 | { 146 | "subject": "dejected", 147 | "object": "negative" 148 | }, 149 | { 150 | "subject": "hopeless", 151 | "object": "negative" 152 | }, 153 | { 154 | "subject": "despairing", 155 | "object": "negative" 156 | }, 157 | { 158 | "subject": "frightened", 159 | "object": "negative" 160 | }, 161 | { 162 | "subject": "terrified", 163 | "object": "negative" 164 | }, 165 | { 166 | "subject": "scared", 167 | "object": "negative" 168 | }, 169 | { 170 | "subject": "worried", 171 | "object": "negative" 172 | }, 173 | { 174 | "subject": "apprehensive", 175 | "object": "negative" 176 | }, 177 | { 178 | "subject": "nervous", 179 | "object": "negative" 180 | }, 181 | { 182 | "subject": "neutral", 183 | "object": "neutral" 184 | }, 185 | { 186 | "subject": "computer", 187 | "object": "neutral" 188 | }, 189 | { 190 | "subject": "car", 191 | "object": "neutral" 192 | }, 193 | { 194 | "subject": "house", 195 | "object": "neutral" 196 | }, 197 | { 198 | "subject": "tree", 199 | "object": "neutral" 200 | }, 201 | { 202 | "subject": "book", 203 | "object": "neutral" 204 | }, 205 | { 206 | "subject": "money", 207 | "object": "neutral" 208 | }, 209 | { 210 | "subject": "time", 211 | "object": "neutral" 212 | }, 213 | { 214 | "subject": "day", 215 | "object": "neutral" 216 | }, 217 | { 218 | "subject": "week", 219 | "object": "neutral" 220 | }, 221 | { 222 | "subject": "ordinary", 223 | "object": "neutral" 224 | }, 225 | { 226 | "subject": "common", 227 | "object": "neutral" 228 | }, 229 | { 230 | "subject": "typical", 231 | "object": "neutral" 232 | }, 233 | { 234 | "subject": "average", 235 | "object": "neutral" 236 | }, 237 | { 238 | "subject": "indifferent", 239 | "object": "neutral" 240 | }, 241 | { 242 | "subject": "unbiased", 243 | "object": "neutral" 244 | }, 245 | { 246 | "subject": "impartial", 247 | "object": "neutral" 248 | }, 249 | { 250 | "subject": "objective", 251 | "object": "neutral" 252 | }, 253 | { 254 | "subject": "neutral", 255 | "object": "neutral" 256 | } 257 | ] 258 | } -------------------------------------------------------------------------------- /data/commonsense/work_location.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "work location", 3 | "prompt_templates": [ 4 | "A {} typically works at a" 5 | ], 6 | "prompt_templates_zs": [ 7 | "A {} typically works at a", 8 | "You can usually find a {} working in a" 9 | ], 10 | "properties": { 11 | "relation_type": "commonsense", 12 | "domain_name": "occupation", 13 | "range_name": "location", 14 | "symmetric": false 15 | }, 16 | "samples": [ 17 | { 18 | "subject": "farmer", 19 | "object": "farm" 20 | }, 21 | { 22 | "subject": "lawyer", 23 | "object": "courthouse" 24 | }, 25 | { 26 | "subject": "teacher", 27 | "object": "school" 28 | }, 29 | { 30 | "subject": "accountant", 31 | "object": "office" 32 | }, 33 | { 34 | "subject": "artist", 35 | "object": "studio" 36 | }, 37 | { 38 | "subject": "athlete", 39 | "object": "stadium" 40 | }, 41 | { 42 | "subject": "baker", 43 | "object": "bakery" 44 | }, 45 | { 46 | "subject": "barber", 47 | "object": "barbershop" 48 | }, 49 | { 50 | "subject": "chef", 51 | "object": "kitchen" 52 | }, 53 | { 54 | "subject": "doctor", 55 | "object": "hospital" 56 | }, 57 | { 58 | "subject": "fashion designer", 59 | "object": "studio" 60 | }, 61 | { 62 | "subject": "firefighter", 63 | "object": "fire station" 64 | }, 65 | { 66 | "subject": "florist", 67 | "object": "flower 
shop" 68 | }, 69 | { 70 | "subject": "flight attendant", 71 | "object": "airplane" 72 | }, 73 | { 74 | "subject": "hairdresser", 75 | "object": "salon" 76 | }, 77 | { 78 | "subject": "historian", 79 | "object": "library" 80 | }, 81 | { 82 | "subject": "insurance agent", 83 | "object": "office" 84 | }, 85 | { 86 | "subject": "journalist", 87 | "object": "office" 88 | }, 89 | { 90 | "subject": "librarian", 91 | "object": "library" 92 | }, 93 | { 94 | "subject": "mechanic", 95 | "object": "garage" 96 | }, 97 | { 98 | "subject": "musician", 99 | "object": "concert hall" 100 | }, 101 | { 102 | "subject": "nurse", 103 | "object": "hospital" 104 | }, 105 | { 106 | "subject": "painter", 107 | "object": "studio" 108 | }, 109 | { 110 | "subject": "pharmacist", 111 | "object": "pharmacy" 112 | }, 113 | { 114 | "subject": "photographer", 115 | "object": "studio" 116 | }, 117 | { 118 | "subject": "pilot", 119 | "object": "airplane" 120 | }, 121 | { 122 | "subject": "researcher", 123 | "object": "laboratory" 124 | }, 125 | { 126 | "subject": "salesperson", 127 | "object": "store" 128 | }, 129 | { 130 | "subject": "scientist", 131 | "object": "laboratory" 132 | }, 133 | { 134 | "subject": "secretary", 135 | "object": "office" 136 | }, 137 | { 138 | "subject": "soldier", 139 | "object": "military base" 140 | }, 141 | { 142 | "subject": "software engineer", 143 | "object": "office" 144 | }, 145 | { 146 | "subject": "student", 147 | "object": "school" 148 | }, 149 | { 150 | "subject": "surgeon", 151 | "object": "hospital" 152 | }, 153 | { 154 | "subject": "trainer", 155 | "object": "gym" 156 | }, 157 | { 158 | "subject": "truck driver", 159 | "object": "truck" 160 | }, 161 | { 162 | "subject": "waitress", 163 | "object": "restaurant" 164 | }, 165 | { 166 | "subject": "waiter", 167 | "object": "restaurant" 168 | } 169 | ] 170 | } -------------------------------------------------------------------------------- /data/factual/city_in_country.json: -------------------------------------------------------------------------------- 1 | { 2 | "name":"city in country", 3 | "prompt_templates":[ 4 | "{} is part of", 5 | "{} is in the country of" 6 | ], 7 | "prompt_templates_zs":[ 8 | "{} is part of the country of", 9 | "{} is located in the country of" 10 | ], 11 | "properties":{ 12 | "relation_type":"factual", 13 | "domain_name":"city", 14 | "range_name":"country", 15 | "symmetric":false 16 | }, 17 | "samples":[ 18 | { 19 | "subject":"New York City", 20 | "object":"United States" 21 | }, 22 | { 23 | "subject":"Rio de Janeiro", 24 | "object":"Brazil" 25 | }, 26 | { 27 | "subject":"Buenos Aires", 28 | "object":"Argentina" 29 | }, 30 | { 31 | "subject":"Mexico City", 32 | "object":"Mexico" 33 | }, 34 | { 35 | "subject":"São Paulo", 36 | "object":"Brazil" 37 | }, 38 | { 39 | "subject":"Los Angeles", 40 | "object":"United States" 41 | }, 42 | { 43 | "subject":"Saint Petersburg", 44 | "object":"Russia" 45 | }, 46 | { 47 | "subject":"San Francisco", 48 | "object":"United States" 49 | }, 50 | { 51 | "subject":"Ho Chi Minh City", 52 | "object":"Vietnam" 53 | }, 54 | { 55 | "subject":"Kuala Lumpur", 56 | "object":"Malaysia" 57 | }, 58 | { 59 | "subject":"Abu Dhabi", 60 | "object":"United Arab Emirates" 61 | }, 62 | { 63 | "subject":"Cape Town", 64 | "object":"South Africa" 65 | }, 66 | { 67 | "subject":"New Delhi", 68 | "object":"India" 69 | }, 70 | { 71 | "subject":"Las Vegas", 72 | "object":"United States" 73 | }, 74 | { 75 | "subject":"Hong Kong", 76 | "object":"China" 77 | }, 78 | { 79 | "subject":"Tel Aviv", 80 | 
"object":"Israel" 81 | }, 82 | { 83 | "subject":"Johannesburg", 84 | "object":"South Africa" 85 | }, 86 | { 87 | "subject":"Santo Domingo", 88 | "object":"Dominican Republic" 89 | }, 90 | { 91 | "subject":"Port-au-Prince", 92 | "object":"Haiti" 93 | }, 94 | { 95 | "subject":"Santiago de Chile", 96 | "object":"Chile" 97 | }, 98 | { 99 | "subject":"Panama City", 100 | "object":"Panama" 101 | }, 102 | { 103 | "subject":"Siem Reap", 104 | "object":"Cambodia" 105 | }, 106 | { 107 | "subject":"Casablanca", 108 | "object":"Morocco" 109 | }, 110 | { 111 | "subject":"San Juan", 112 | "object":"Puerto Rico" 113 | }, 114 | { 115 | "subject":"Costa Rica", 116 | "object":"San José" 117 | }, 118 | { 119 | "subject":"Addis Ababa", 120 | "object":"Ethiopia" 121 | }, 122 | { 123 | "subject":"Punta Cana", 124 | "object":"Dominican Republic" 125 | } 126 | ] 127 | } -------------------------------------------------------------------------------- /data/factual/country_capital_city.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "country capital city", 3 | "prompt_templates": [ 4 | "The capital city of {} is", 5 | "The capital of {} is" 6 | ], 7 | "prompt_templates_zs": [ 8 | "The capital of {} is the city of", 9 | "What is the capital of {}? It is the city of" 10 | ], 11 | "properties": { 12 | "relation_type": "factual", 13 | "domain_name": "country", 14 | "range_name": "city", 15 | "symmetric": false 16 | }, 17 | "samples": [ 18 | { 19 | "subject": "United States", 20 | "object": "Washington D.C." 21 | }, 22 | { 23 | "subject": "Canada", 24 | "object": "Ottawa" 25 | }, 26 | { 27 | "subject": "Mexico", 28 | "object": "Mexico City" 29 | }, 30 | { 31 | "subject": "Brazil", 32 | "object": "Bras\\u00edlia" 33 | }, 34 | { 35 | "subject": "Argentina", 36 | "object": "Buenos Aires" 37 | }, 38 | { 39 | "subject": "Chile", 40 | "object": "Santiago" 41 | }, 42 | { 43 | "subject": "Peru", 44 | "object": "Lima" 45 | }, 46 | { 47 | "subject": "Colombia", 48 | "object": "Bogot\\u00e1" 49 | }, 50 | { 51 | "subject": "Venezuela", 52 | "object": "Caracas" 53 | }, 54 | { 55 | "subject": "Spain", 56 | "object": "Madrid" 57 | }, 58 | { 59 | "subject": "France", 60 | "object": "Paris" 61 | }, 62 | { 63 | "subject": "Germany", 64 | "object": "Berlin" 65 | }, 66 | { 67 | "subject": "Italy", 68 | "object": "Rome" 69 | }, 70 | { 71 | "subject": "Russia", 72 | "object": "Moscow" 73 | }, 74 | { 75 | "subject": "China", 76 | "object": "Beijing" 77 | }, 78 | { 79 | "subject": "Japan", 80 | "object": "Tokyo" 81 | }, 82 | { 83 | "subject": "South Korea", 84 | "object": "Seoul" 85 | }, 86 | { 87 | "subject": "India", 88 | "object": "New Delhi" 89 | }, 90 | { 91 | "subject": "Pakistan", 92 | "object": "Islamabad" 93 | }, 94 | { 95 | "subject": "Nigeria", 96 | "object": "Abuja" 97 | }, 98 | { 99 | "subject": "Egypt", 100 | "object": "Cairo" 101 | }, 102 | { 103 | "subject": "Saudi Arabia", 104 | "object": "Riyadh" 105 | }, 106 | { 107 | "subject": "Turkey", 108 | "object": "Ankara" 109 | }, 110 | { 111 | "subject": "Australia", 112 | "object": "Canberra" 113 | } 114 | ] 115 | } 116 | -------------------------------------------------------------------------------- /data/factual/country_currency.json: -------------------------------------------------------------------------------- 1 | { 2 | "name":"country currency", 3 | "prompt_templates":[ 4 | "The official currency of {} is the", 5 | "{}'s official currency is the" 6 | ], 7 | "prompt_templates_zs":[ 8 | "What is the official currency of 
{}? It is called the", 9 | "{}'s official currency is called the", 10 | "The name of {}'s currency is the" 11 | ], 12 | "properties":{ 13 | "relation_type":"factual", 14 | "domain_name":"country", 15 | "range_name":"currency", 16 | "symmetric":false 17 | }, 18 | "samples":[ 19 | { 20 | "subject":"United States", 21 | "object":"Dollar" 22 | }, 23 | { 24 | "subject":"United Kingdom", 25 | "object":"Pound" 26 | }, 27 | { 28 | "subject":"Japan", 29 | "object":"Yen" 30 | }, 31 | { 32 | "subject":"Canada", 33 | "object":"Dollar" 34 | }, 35 | { 36 | "subject":"Australia", 37 | "object":"Dollar" 38 | }, 39 | { 40 | "subject":"Brazil", 41 | "object":"Real" 42 | }, 43 | { 44 | "subject":"China", 45 | "object":"Yuan" 46 | }, 47 | { 48 | "subject":"India", 49 | "object":"Rupee" 50 | }, 51 | { 52 | "subject":"Russia", 53 | "object":"Ruble" 54 | }, 55 | { 56 | "subject":"South Africa", 57 | "object":"Rand" 58 | }, 59 | { 60 | "subject":"Mexico", 61 | "object":"Peso" 62 | }, 63 | { 64 | "subject":"New Zealand", 65 | "object":"Dollar" 66 | }, 67 | { 68 | "subject":"South Korea", 69 | "object":"Won" 70 | }, 71 | { 72 | "subject":"Switzerland", 73 | "object":"Franc" 74 | }, 75 | { 76 | "subject":"Turkey", 77 | "object":"Lira" 78 | }, 79 | { 80 | "subject":"Argentina", 81 | "object":"Peso" 82 | }, 83 | { 84 | "subject":"Norway", 85 | "object":"Krone" 86 | }, 87 | { 88 | "subject":"Sweden", 89 | "object":"Krona" 90 | }, 91 | { 92 | "subject":"Denmark", 93 | "object":"Krone" 94 | }, 95 | { 96 | "subject":"Poland", 97 | "object":"Zloty" 98 | }, 99 | { 100 | "subject":"Hungary", 101 | "object":"Forint" 102 | }, 103 | { 104 | "subject":"Czech Republic", 105 | "object":"Koruna" 106 | }, 107 | { 108 | "subject":"Israel", 109 | "object":"Shekel" 110 | }, 111 | { 112 | "subject":"Saudi Arabia", 113 | "object":"Riyal" 114 | }, 115 | { 116 | "subject":"United Arab Emirates", 117 | "object":"Dirham" 118 | }, 119 | { 120 | "subject":"Singapore", 121 | "object":"Dollar" 122 | }, 123 | { 124 | "subject":"Malaysia", 125 | "object":"Ringgit" 126 | }, 127 | { 128 | "subject":"Indonesia", 129 | "object":"Rupiah" 130 | }, 131 | { 132 | "subject":"Thailand", 133 | "object":"Baht" 134 | }, 135 | { 136 | "subject":"Philippines", 137 | "object":"Peso" 138 | } 139 | ] 140 | } -------------------------------------------------------------------------------- /data/factual/country_language.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "country language", 3 | "prompt_templates": [ 4 | "People in {} speak", 5 | "The language used in {} is", 6 | "In {}, the primary language is" 7 | ], 8 | "prompt_templates_zs": [ 9 | "{}, where most people speak", 10 | "In {}, people speak the language of", 11 | "People in {} speak the language of" 12 | ], 13 | "properties": { 14 | "relation_type": "factual", 15 | "domain_name": "country", 16 | "range_name": "language", 17 | "symmetric": false 18 | }, 19 | "samples": [ 20 | { 21 | "subject": "United States", 22 | "object": "English" 23 | }, 24 | { 25 | "subject": "Canada", 26 | "object": "English" 27 | }, 28 | { 29 | "subject": "Mexico", 30 | "object": "Spanish" 31 | }, 32 | { 33 | "subject": "Brazil", 34 | "object": "Portuguese" 35 | }, 36 | { 37 | "subject": "Argentina", 38 | "object": "Spanish" 39 | }, 40 | { 41 | "subject": "Chile", 42 | "object": "Spanish" 43 | }, 44 | { 45 | "subject": "Peru", 46 | "object": "Spanish" 47 | }, 48 | { 49 | "subject": "Colombia", 50 | "object": "Spanish" 51 | }, 52 | { 53 | "subject": "Venezuela", 54 | "object": 
"Spanish" 55 | }, 56 | { 57 | "subject": "Spain", 58 | "object": "Spanish" 59 | }, 60 | { 61 | "subject": "France", 62 | "object": "French" 63 | }, 64 | { 65 | "subject": "Germany", 66 | "object": "German" 67 | }, 68 | { 69 | "subject": "Italy", 70 | "object": "Italian" 71 | }, 72 | { 73 | "subject": "Russia", 74 | "object": "Russian" 75 | }, 76 | { 77 | "subject": "China", 78 | "object": "Mandarin Chinese" 79 | }, 80 | { 81 | "subject": "Japan", 82 | "object": "Japanese" 83 | }, 84 | { 85 | "subject": "South Korea", 86 | "object": "Korean" 87 | }, 88 | { 89 | "subject": "India", 90 | "object": "Hindi" 91 | }, 92 | { 93 | "subject": "Pakistan", 94 | "object": "Urdu" 95 | }, 96 | { 97 | "subject": "Nigeria", 98 | "object": "English" 99 | }, 100 | { 101 | "subject": "Egypt", 102 | "object": "Arabic" 103 | }, 104 | { 105 | "subject": "Saudi Arabia", 106 | "object": "Arabic" 107 | }, 108 | { 109 | "subject": "Turkey", 110 | "object": "Turkish" 111 | }, 112 | { 113 | "subject": "Australia", 114 | "object": "English" 115 | } 116 | ] 117 | } -------------------------------------------------------------------------------- /data/factual/country_largest_city.json: -------------------------------------------------------------------------------- 1 | { 2 | "name":"country largest city", 3 | "prompt_templates":[ 4 | "The largest city in {} is", 5 | "The biggest city in {} is" 6 | ], 7 | "prompt_templates_zs":[ 8 | "What is the largest city in {}? It is the city of", 9 | "The largest city in {} is the city of" 10 | ], 11 | "properties":{ 12 | "relation_type":"factual", 13 | "domain_name":"country", 14 | "range_name":"city", 15 | "symmetric":false 16 | }, 17 | "samples":[ 18 | { 19 | "subject":"United States", 20 | "object":"New York City" 21 | }, 22 | { 23 | "subject":"China", 24 | "object":"Shanghai" 25 | }, 26 | { 27 | "subject":"Japan", 28 | "object":"Tokyo" 29 | }, 30 | { 31 | "subject":"Russia", 32 | "object":"Moscow" 33 | }, 34 | { 35 | "subject":"India", 36 | "object":"Mumbai" 37 | }, 38 | { 39 | "subject":"Brazil", 40 | "object":"São Paulo" 41 | }, 42 | { 43 | "subject":"Australia", 44 | "object":"Sydney" 45 | }, 46 | { 47 | "subject":"Canada", 48 | "object":"Toronto" 49 | }, 50 | { 51 | "subject":"United Kingdom", 52 | "object":"London" 53 | }, 54 | { 55 | "subject":"France", 56 | "object":"Paris" 57 | }, 58 | { 59 | "subject":"Germany", 60 | "object":"Berlin" 61 | }, 62 | { 63 | "subject":"Italy", 64 | "object":"Rome" 65 | }, 66 | { 67 | "subject":"Mexico", 68 | "object":"Mexico City" 69 | }, 70 | { 71 | "subject":"South Korea", 72 | "object":"Seoul" 73 | }, 74 | { 75 | "subject":"Turkey", 76 | "object":"Istanbul" 77 | }, 78 | { 79 | "subject":"Spain", 80 | "object":"Madrid" 81 | }, 82 | { 83 | "subject":"Argentina", 84 | "object":"Buenos Aires" 85 | }, 86 | { 87 | "subject":"South Africa", 88 | "object":"Johannesburg" 89 | }, 90 | { 91 | "subject":"Poland", 92 | "object":"Warsaw" 93 | }, 94 | { 95 | "subject":"Nigeria", 96 | "object":"Lagos" 97 | }, 98 | { 99 | "subject":"New Zealand", 100 | "object":"Auckland" 101 | }, 102 | { 103 | "subject":"Switzerland", 104 | "object":"Zurich" 105 | }, 106 | { 107 | "subject":"Netherlands", 108 | "object":"Amsterdam" 109 | }, 110 | { 111 | "subject":"Pakistan", 112 | "object":"Karachi" 113 | } 114 | ] 115 | } -------------------------------------------------------------------------------- /data/factual/food_from_country.json: -------------------------------------------------------------------------------- 1 | { 2 | "name":"food from country", 3 | 
"prompt_templates":[ 4 | "{} originates from", 5 | "{} is from the country of" 6 | ], 7 | "prompt_templates_zs":[ 8 | "What is the country of origin for {}? It originates from", 9 | "{} originates from the country of" 10 | ], 11 | "properties":{ 12 | "relation_type":"factual", 13 | "domain_name":"food", 14 | "range_name":"country", 15 | "symmetric":false 16 | }, 17 | "samples":[ 18 | { 19 | "subject":"Pizza", 20 | "object":"Italy" 21 | }, 22 | { 23 | "subject":"Sushi", 24 | "object":"Japan" 25 | }, 26 | { 27 | "subject":"Tacos", 28 | "object":"Mexico" 29 | }, 30 | { 31 | "subject":"Baguette", 32 | "object":"France" 33 | }, 34 | { 35 | "subject":"Poutine", 36 | "object":"Canada" 37 | }, 38 | { 39 | "subject":"Paella", 40 | "object":"Spain" 41 | }, 42 | { 43 | "subject":"Chimichurri", 44 | "object":"Argentina" 45 | }, 46 | { 47 | "subject":"Baklava", 48 | "object":"Turkey" 49 | }, 50 | { 51 | "subject":"Feijoada", 52 | "object":"Brazil" 53 | }, 54 | { 55 | "subject":"Borscht", 56 | "object":"Ukraine" 57 | }, 58 | { 59 | "subject":"Fish and Chips", 60 | "object":"United Kingdom" 61 | }, 62 | { 63 | "subject":"Dim Sum", 64 | "object":"China" 65 | }, 66 | { 67 | "subject":"Kimchi", 68 | "object":"South Korea" 69 | }, 70 | { 71 | "subject":"Goulash", 72 | "object":"Hungary" 73 | }, 74 | { 75 | "subject":"Pierogi", 76 | "object":"Poland" 77 | }, 78 | { 79 | "subject":"Pho", 80 | "object":"Vietnam" 81 | }, 82 | { 83 | "subject":"Hummus", 84 | "object":"Lebanon" 85 | }, 86 | { 87 | "subject":"Gyro", 88 | "object":"Greece" 89 | }, 90 | { 91 | "subject":"Masala Dosa", 92 | "object":"India" 93 | }, 94 | { 95 | "subject":"Rendang", 96 | "object":"Indonesia" 97 | }, 98 | { 99 | "subject":"Moussaka", 100 | "object":"Greece" 101 | }, 102 | { 103 | "subject":"Pavlova", 104 | "object":"New Zealand" 105 | }, 106 | { 107 | "subject":"Shawarma", 108 | "object":"Middle East" 109 | }, 110 | { 111 | "subject":"Falafel", 112 | "object":"Middle East" 113 | }, 114 | { 115 | "subject":"Ceviche", 116 | "object":"Peru" 117 | }, 118 | { 119 | "subject":"Biryani", 120 | "object":"India" 121 | }, 122 | { 123 | "subject":"Wiener Schnitzel", 124 | "object":"Austria" 125 | }, 126 | { 127 | "subject":"Fondue", 128 | "object":"Switzerland" 129 | }, 130 | { 131 | "subject":"Pad Thai", 132 | "object":"Thailand" 133 | }, 134 | { 135 | "subject":"Miso Soup", 136 | "object":"Japan" 137 | } 138 | ] 139 | } 140 | -------------------------------------------------------------------------------- /data/factual/person_band_lead_singer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "person lead singer of band", 3 | "prompt_templates": [ 4 | "{} is the lead singer of" 5 | ], 6 | "prompt_templates_zs": [ 7 | "{} is the lead singer of the band named", 8 | "What band is {} lead singer of? 
They lead the band named" 9 | ], 10 | "properties": { 11 | "relation_type": "factual", 12 | "domain_name": "person", 13 | "range_name": "band", 14 | "symmetric": false 15 | }, 16 | "samples": [ 17 | { 18 | "subject": "Brian Johnson", 19 | "object": "AC/DC" 20 | }, 21 | { 22 | "subject": "Steven Tyler", 23 | "object": "Aerosmith" 24 | }, 25 | { 26 | "subject": "Chris Martin", 27 | "object": "Coldplay" 28 | }, 29 | { 30 | "subject": "Dave Grohl", 31 | "object": "Foo Fighters" 32 | }, 33 | { 34 | "subject": "Axl Rose", 35 | "object": "Guns N' Roses" 36 | }, 37 | { 38 | "subject": "Robert Plant", 39 | "object": "Led Zeppelin" 40 | }, 41 | { 42 | "subject": "James Hetfield", 43 | "object": "Metallica" 44 | }, 45 | { 46 | "subject": "Kurt Cobain", 47 | "object": "Nirvana" 48 | }, 49 | { 50 | "subject": "Freddie Mercury", 51 | "object": "Queen" 52 | }, 53 | { 54 | "subject": "Anthony Kiedis", 55 | "object": "Red Hot Chili Peppers" 56 | }, 57 | { 58 | "subject": "Mick Jagger", 59 | "object": "The Rolling Stones" 60 | }, 61 | { 62 | "subject": "John Lennon", 63 | "object": "The Beatles" 64 | }, 65 | { 66 | "subject": "Jim Morrison", 67 | "object": "The Doors" 68 | }, 69 | { 70 | "subject": "Bono", 71 | "object": "U2" 72 | }, 73 | { 74 | "subject": "Ozzy Osbourne", 75 | "object": "Black Sabbath" 76 | }, 77 | { 78 | "subject": "Billie Joe Armstrong", 79 | "object": "Green Day" 80 | }, 81 | { 82 | "subject": "Eddie Vedder", 83 | "object": "Pearl Jam" 84 | }, 85 | { 86 | "subject": "Jon Bon Jovi", 87 | "object": "Bon Jovi " 88 | }, 89 | { 90 | "subject": "Don Henley", 91 | "object": "The Eagles" 92 | }, 93 | { 94 | "subject": "Liam Gallagher", 95 | "object": "Oasis" 96 | }, 97 | { 98 | "subject": "Serj Tankian", 99 | "object": "System of a Down" 100 | } 101 | ] 102 | } -------------------------------------------------------------------------------- /data/factual/pokemon_evolutions.json: -------------------------------------------------------------------------------- 1 | { 2 | "name":"pokemon evolution", 3 | "prompt_templates":[ 4 | "{} evolves into a" 5 | ], 6 | "prompt_templates_zs":[ 7 | "What does {} evolve into? 
It evolves into a", 8 | "The evolved form of {} is called a" 9 | ], 10 | "properties":{ 11 | "relation_type":"factual", 12 | "domain_name":"pokemon", 13 | "range_name":"pokemon", 14 | "symmetric":false 15 | }, 16 | "samples":[ 17 | { 18 | "subject":"Bulbasaur", 19 | "object":"Ivysaur" 20 | }, 21 | { 22 | "subject":"Charmander", 23 | "object":"Charmeleon" 24 | }, 25 | { 26 | "subject":"Squirtle", 27 | "object":"Wartortle" 28 | }, 29 | { 30 | "subject":"Pikachu", 31 | "object":"Raichu" 32 | }, 33 | { 34 | "subject":"Oddish", 35 | "object":"Gloom" 36 | }, 37 | { 38 | "subject":"Venonat", 39 | "object":"Venomoth" 40 | }, 41 | { 42 | "subject":"Diglett", 43 | "object":"Dugtrio" 44 | }, 45 | { 46 | "subject":"Meowth", 47 | "object":"Persian" 48 | }, 49 | { 50 | "subject":"Psyduck", 51 | "object":"Golduck" 52 | }, 53 | { 54 | "subject":"Mankey", 55 | "object":"Primeape" 56 | }, 57 | { 58 | "subject":"Growlithe", 59 | "object":"Arcanine" 60 | }, 61 | { 62 | "subject":"Poliwag", 63 | "object":"Poliwhirl" 64 | }, 65 | { 66 | "subject":"Abra", 67 | "object":"Kadabra" 68 | }, 69 | { 70 | "subject":"Machop", 71 | "object":"Machoke" 72 | }, 73 | { 74 | "subject":"Caterpie", 75 | "object":"Metapod" 76 | }, 77 | { 78 | "subject":"Weedle", 79 | "object":"Kakuna" 80 | }, 81 | { 82 | "subject":"Pidgey", 83 | "object":"Pidgeotto" 84 | }, 85 | { 86 | "subject":"Rattata", 87 | "object":"Raticate" 88 | }, 89 | { 90 | "subject":"Spearow", 91 | "object":"Fearow" 92 | }, 93 | { 94 | "subject":"Ekans", 95 | "object":"Arbok" 96 | }, 97 | { 98 | "subject":"Sandshrew", 99 | "object":"Sandslash" 100 | }, 101 | { 102 | "subject":"Nidoran♀", 103 | "object":"Nidorina" 104 | }, 105 | { 106 | "subject":"Nidoran♂", 107 | "object":"Nidorino" 108 | }, 109 | { 110 | "subject":"Zubat", 111 | "object":"Golbat" 112 | }, 113 | { 114 | "subject":"Bellsprout", 115 | "object":"Weepinbell" 116 | }, 117 | { 118 | "subject":"Tentacool", 119 | "object":"Tentacruel" 120 | }, 121 | { 122 | "subject":"Geodude", 123 | "object":"Graveler" 124 | }, 125 | { 126 | "subject":"Ponyta", 127 | "object":"Rapidash" 128 | }, 129 | { 130 | "subject":"Slowpoke", 131 | "object":"Slowbro" 132 | }, 133 | { 134 | "subject":"Magnemite", 135 | "object":"Magneton" 136 | }, 137 | { 138 | "subject":"Doduo", 139 | "object":"Dodrio" 140 | }, 141 | { 142 | "subject":"Seel", 143 | "object":"Dewgong" 144 | }, 145 | { 146 | "subject":"Grimer", 147 | "object":"Muk" 148 | }, 149 | { 150 | "subject":"Shellder", 151 | "object":"Cloyster" 152 | }, 153 | { 154 | "subject":"Gastly", 155 | "object":"Haunter" 156 | }, 157 | { 158 | "subject":"Drowzee", 159 | "object":"Hypno" 160 | }, 161 | { 162 | "subject":"Krabby", 163 | "object":"Kingler" 164 | }, 165 | { 166 | "subject":"Voltorb", 167 | "object":"Electrode" 168 | }, 169 | { 170 | "subject":"Exeggcute", 171 | "object":"Exeggutor" 172 | }, 173 | { 174 | "subject":"Cubone", 175 | "object":"Marowak" 176 | }, 177 | { 178 | "subject":"Koffing", 179 | "object":"Weezing" 180 | }, 181 | { 182 | "subject":"Rhyhorn", 183 | "object":"Rhydon" 184 | }, 185 | { 186 | "subject":"Horsea", 187 | "object":"Seadra" 188 | }, 189 | { 190 | "subject":"Goldeen", 191 | "object":"Seaking" 192 | } 193 | ] 194 | } -------------------------------------------------------------------------------- /data/factual/presidents_birth_year.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "president birth year", 3 | "prompt_templates": [ 4 | "{} was born in" 5 | ], 6 | "prompt_templates_zs": [ 7 | "{} was born 
in" 8 | ], 9 | "properties": { 10 | "relation_type": "factual", 11 | "domain_name": "person", 12 | "range_name": "year", 13 | "symmetric": false 14 | }, 15 | "samples": [ 16 | { 17 | "subject": "John Adams", 18 | "object": "1735" 19 | }, 20 | { 21 | "subject": "Thomas Jefferson", 22 | "object": "1743" 23 | }, 24 | { 25 | "subject": "James Madison", 26 | "object": "1751" 27 | }, 28 | { 29 | "subject": "James Monroe", 30 | "object": "1758" 31 | }, 32 | { 33 | "subject": "John Quincy Adams", 34 | "object": "1767" 35 | }, 36 | { 37 | "subject": "Andrew Jackson", 38 | "object": "1767" 39 | }, 40 | { 41 | "subject": "Martin Van Buren", 42 | "object": "1782" 43 | }, 44 | { 45 | "subject": "William Henry Harrison", 46 | "object": "1773" 47 | }, 48 | { 49 | "subject": "Richard Nixon", 50 | "object": "1913" 51 | }, 52 | { 53 | "subject": "John F. Kennedy", 54 | "object": "1917" 55 | }, 56 | { 57 | "subject": "Richard Nixon", 58 | "object": "1913" 59 | }, 60 | { 61 | "subject": "Bill Clinton", 62 | "object": "1946" 63 | }, 64 | { 65 | "subject": "George W. Bush", 66 | "object": "1946" 67 | }, 68 | { 69 | "subject": "Barack Obama", 70 | "object": "1961" 71 | }, 72 | { 73 | "subject": "Jimmy Carter ", 74 | "object": "1924" 75 | }, 76 | { 77 | "subject": "Ronald Reagan", 78 | "object": "1911" 79 | }, 80 | { 81 | "subject": "George H. W. Bush", 82 | "object": "1924" 83 | }, 84 | { 85 | "subject": "Joe Biden", 86 | "object": "1942" 87 | }, 88 | { 89 | "subject": "Franklin D. Roosevelt", 90 | "object": "1882" 91 | } 92 | ] 93 | } -------------------------------------------------------------------------------- /data/factual/presidents_election_year.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "president election year", 3 | "prompt_templates": [ 4 | "{} was elected in" 5 | ], 6 | "prompt_templates_zs": [ 7 | "{} was elected in" 8 | ], 9 | "properties": { 10 | "relation_type": "factual", 11 | "domain_name": "person", 12 | "range_name": "year", 13 | "symmetric": false 14 | }, 15 | "samples": [ 16 | { 17 | "subject": "John Adams", 18 | "object": "1796" 19 | }, 20 | { 21 | "subject": "Thomas Jefferson", 22 | "object": "1800" 23 | }, 24 | { 25 | "subject": "James Madison", 26 | "object": "1808" 27 | }, 28 | { 29 | "subject": "James Monroe", 30 | "object": "1816" 31 | }, 32 | { 33 | "subject": "John Quincy Adams", 34 | "object": "1824" 35 | }, 36 | { 37 | "subject": "Andrew Jackson", 38 | "object": "1828" 39 | }, 40 | { 41 | "subject": "Martin Van Buren", 42 | "object": "1836" 43 | }, 44 | { 45 | "subject": "William Henry Harrison", 46 | "object": "1840" 47 | }, 48 | { 49 | "subject": "Richard M. Nixon", 50 | "object": "1968" 51 | }, 52 | { 53 | "subject": "John F. Kennedy", 54 | "object": "1960" 55 | }, 56 | { 57 | "subject": "Richard Nixon", 58 | "object": "1968" 59 | }, 60 | { 61 | "subject": "Bill Clinton", 62 | "object": "1992" 63 | }, 64 | { 65 | "subject": "George W. Bush", 66 | "object": "2000" 67 | }, 68 | { 69 | "subject": "Barack Obama", 70 | "object": "2008" 71 | }, 72 | { 73 | "subject": "Jimmy Carter ", 74 | "object": "1976" 75 | }, 76 | { 77 | "subject": "Ronald Reagan", 78 | "object": "1980" 79 | }, 80 | { 81 | "subject": "George H. W. Bush", 82 | "object": "1988" 83 | }, 84 | { 85 | "subject": "Joe Biden", 86 | "object": "2020" 87 | }, 88 | { 89 | "subject": "Franklin D. 
Roosevelt", 90 | "object": "1932" 91 | } 92 | ] 93 | } -------------------------------------------------------------------------------- /data/linguistic/adj_comparative.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "adjective comparative", 3 | "prompt_templates": [ 4 | "The comparative form of {} is" 5 | ], 6 | "prompt_templates_zs": [ 7 | "The comparative form of {} is", 8 | "What is the comparative form of {}? It is" 9 | ], 10 | "properties": { 11 | "relation_type": "linguistic", 12 | "domain_name": "adjective", 13 | "range_name": "adjective", 14 | "symmetric": false 15 | }, 16 | "samples": [ 17 | { 18 | "subject": "big", 19 | "object": "bigger" 20 | }, 21 | { 22 | "subject": "small", 23 | "object": "smaller" 24 | }, 25 | { 26 | "subject": "tall", 27 | "object": "taller" 28 | }, 29 | { 30 | "subject": "short", 31 | "object": "shorter" 32 | }, 33 | { 34 | "subject": "long", 35 | "object": "longer" 36 | }, 37 | { 38 | "subject": "fast", 39 | "object": "faster" 40 | }, 41 | { 42 | "subject": "slow", 43 | "object": "slower" 44 | }, 45 | { 46 | "subject": "strong", 47 | "object": "stronger" 48 | }, 49 | { 50 | "subject": "weak", 51 | "object": "weaker" 52 | }, 53 | { 54 | "subject": "heavy", 55 | "object": "heavier" 56 | }, 57 | { 58 | "subject": "light", 59 | "object": "lighter" 60 | }, 61 | { 62 | "subject": "old", 63 | "object": "older" 64 | }, 65 | { 66 | "subject": "young", 67 | "object": "younger" 68 | }, 69 | { 70 | "subject": "high", 71 | "object": "higher" 72 | }, 73 | { 74 | "subject": "low", 75 | "object": "lower" 76 | }, 77 | { 78 | "subject": "deep", 79 | "object": "deeper" 80 | }, 81 | { 82 | "subject": "shallow", 83 | "object": "shallower" 84 | }, 85 | { 86 | "subject": "wide", 87 | "object": "wider" 88 | }, 89 | { 90 | "subject": "narrow", 91 | "object": "narrower" 92 | }, 93 | { 94 | "subject": "thick", 95 | "object": "thicker" 96 | }, 97 | { 98 | "subject": "thin", 99 | "object": "thinner" 100 | }, 101 | { 102 | "subject": "hot", 103 | "object": "hotter" 104 | }, 105 | { 106 | "subject": "cold", 107 | "object": "colder" 108 | }, 109 | { 110 | "subject": "bright", 111 | "object": "brighter" 112 | }, 113 | { 114 | "subject": "dark", 115 | "object": "darker" 116 | }, 117 | { 118 | "subject": "loud", 119 | "object": "louder" 120 | }, 121 | { 122 | "subject": "quiet", 123 | "object": "quieter" 124 | }, 125 | { 126 | "subject": "happy", 127 | "object": "happier" 128 | }, 129 | { 130 | "subject": "sad", 131 | "object": "sadder" 132 | }, 133 | { 134 | "subject": "good", 135 | "object": "better" 136 | }, 137 | { 138 | "subject": "bad", 139 | "object": "worse" 140 | }, 141 | { 142 | "subject": "ugly", 143 | "object": "uglier" 144 | }, 145 | { 146 | "subject": "clean", 147 | "object": "cleaner" 148 | }, 149 | { 150 | "subject": "dirty", 151 | "object": "dirtier" 152 | }, 153 | { 154 | "subject": "easy", 155 | "object": "easier" 156 | }, 157 | { 158 | "subject": "simple", 159 | "object": "simpler" 160 | }, 161 | { 162 | "subject": "kind", 163 | "object": "kinder" 164 | }, 165 | { 166 | "subject": "mean", 167 | "object": "meaner" 168 | }, 169 | { 170 | "subject": "brave", 171 | "object": "braver" 172 | }, 173 | { 174 | "subject": "smart", 175 | "object": "smarter" 176 | }, 177 | { 178 | "subject": "stupid", 179 | "object": "stupider" 180 | }, 181 | { 182 | "subject": "wise", 183 | "object": "wiser" 184 | }, 185 | { 186 | "subject": "happy", 187 | "object": "happier" 188 | }, 189 | { 190 | "subject": "rich", 191 | "object": 
"richer" 192 | }, 193 | { 194 | "subject": "poor", 195 | "object": "poorer" 196 | }, 197 | { 198 | "subject": "hot", 199 | "object": "hotter" 200 | }, 201 | { 202 | "subject": "cold", 203 | "object": "colder" 204 | }, 205 | { 206 | "subject": "sweet", 207 | "object": "sweeter" 208 | }, 209 | { 210 | "subject": "sour", 211 | "object": "sourer" 212 | }, 213 | { 214 | "subject": "fresh", 215 | "object": "fresher" 216 | }, 217 | { 218 | "subject": "stale", 219 | "object": "staler" 220 | }, 221 | { 222 | "subject": "cheap", 223 | "object": "cheaper" 224 | }, 225 | { 226 | "subject": "safe", 227 | "object": "safer" 228 | }, 229 | { 230 | "subject": "bright", 231 | "object": "brighter" 232 | }, 233 | { 234 | "subject": "dull", 235 | "object": "duller" 236 | }, 237 | { 238 | "subject": "happy", 239 | "object": "happier" 240 | }, 241 | { 242 | "subject": "sad", 243 | "object": "sadder" 244 | }, 245 | { 246 | "subject": "true", 247 | "object": "truer" 248 | }, 249 | { 250 | "subject": "false", 251 | "object": "falser" 252 | }, 253 | { 254 | "subject": "rude", 255 | "object": "ruder" 256 | }, 257 | { 258 | "subject": "calm", 259 | "object": "calmer" 260 | }, 261 | { 262 | "subject": "strong", 263 | "object": "stronger" 264 | }, 265 | { 266 | "subject": "weak", 267 | "object": "weaker" 268 | }, 269 | { 270 | "subject": "healthy", 271 | "object": "healthier" 272 | }, 273 | { 274 | "subject": "sick", 275 | "object": "sicker" 276 | }, 277 | { 278 | "subject": "clean", 279 | "object": "cleaner" 280 | }, 281 | { 282 | "subject": "dirty", 283 | "object": "dirtier" 284 | }, 285 | { 286 | "subject": "fresh", 287 | "object": "fresher" 288 | } 289 | ] 290 | } -------------------------------------------------------------------------------- /data/linguistic/adj_superlative.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "adjective superlative", 3 | "prompt_templates": [ 4 | "The superlative form of {} is" 5 | ], 6 | "prompt_templates_zs": [ 7 | "The superlative form of {} is", 8 | "What is the superlative form of {}? 
It is" 9 | ], 10 | "properties": { 11 | "relation_type": "linguistic", 12 | "domain_name": "adjective", 13 | "range_name": "adjective", 14 | "symmetric": false 15 | }, 16 | "samples": [ 17 | { 18 | "subject": "angry", 19 | "object": "angriest" 20 | }, 21 | { 22 | "subject": "bad", 23 | "object": "worst" 24 | }, 25 | { 26 | "subject": "big", 27 | "object": "biggest" 28 | }, 29 | { 30 | "subject": "brave", 31 | "object": "bravest" 32 | }, 33 | { 34 | "subject": "bright", 35 | "object": "brightest" 36 | }, 37 | { 38 | "subject": "calm", 39 | "object": "calmest" 40 | }, 41 | { 42 | "subject": "cheap", 43 | "object": "cheapest" 44 | }, 45 | { 46 | "subject": "clean", 47 | "object": "cleanest" 48 | }, 49 | { 50 | "subject": "cold", 51 | "object": "coldest" 52 | }, 53 | { 54 | "subject": "crazy", 55 | "object": "craziest" 56 | }, 57 | { 58 | "subject": "cruel", 59 | "object": "cruelest" 60 | }, 61 | { 62 | "subject": "dark", 63 | "object": "darkest" 64 | }, 65 | { 66 | "subject": "deep", 67 | "object": "deepest" 68 | }, 69 | { 70 | "subject": "dirty", 71 | "object": "dirtiest" 72 | }, 73 | { 74 | "subject": "dry", 75 | "object": "driest" 76 | }, 77 | { 78 | "subject": "dull", 79 | "object": "dullest" 80 | }, 81 | { 82 | "subject": "easy", 83 | "object": "easiest" 84 | }, 85 | { 86 | "subject": "fast", 87 | "object": "fastest" 88 | }, 89 | { 90 | "subject": "fierce", 91 | "object": "fiercest" 92 | }, 93 | { 94 | "subject": "fresh", 95 | "object": "freshest" 96 | }, 97 | { 98 | "subject": "friendly", 99 | "object": "friendliest" 100 | }, 101 | { 102 | "subject": "full", 103 | "object": "fullest" 104 | }, 105 | { 106 | "subject": "funny", 107 | "object": "funniest" 108 | }, 109 | { 110 | "subject": "gentle", 111 | "object": "gentlest" 112 | }, 113 | { 114 | "subject": "good", 115 | "object": "best" 116 | }, 117 | { 118 | "subject": "great", 119 | "object": "greatest" 120 | }, 121 | { 122 | "subject": "happy", 123 | "object": "happiest" 124 | }, 125 | { 126 | "subject": "hard", 127 | "object": "hardest" 128 | }, 129 | { 130 | "subject": "healthy", 131 | "object": "healthiest" 132 | }, 133 | { 134 | "subject": "heavy", 135 | "object": "heaviest" 136 | }, 137 | { 138 | "subject": "high", 139 | "object": "highest" 140 | }, 141 | { 142 | "subject": "hot", 143 | "object": "hottest" 144 | }, 145 | { 146 | "subject": "hungry", 147 | "object": "hungriest" 148 | }, 149 | { 150 | "subject": "kind", 151 | "object": "kindest" 152 | }, 153 | { 154 | "subject": "large", 155 | "object": "largest" 156 | }, 157 | { 158 | "subject": "lazy", 159 | "object": "laziest" 160 | }, 161 | { 162 | "subject": "light", 163 | "object": "lightest" 164 | }, 165 | { 166 | "subject": "little", 167 | "object": "smallest" 168 | }, 169 | { 170 | "subject": "long", 171 | "object": "longest" 172 | }, 173 | { 174 | "subject": "loud", 175 | "object": "loudest" 176 | }, 177 | { 178 | "subject": "low", 179 | "object": "lowest" 180 | }, 181 | { 182 | "subject": "mad", 183 | "object": "maddest" 184 | }, 185 | { 186 | "subject": "mean", 187 | "object": "meanest" 188 | }, 189 | { 190 | "subject": "messy", 191 | "object": "messiest" 192 | }, 193 | { 194 | "subject": "narrow", 195 | "object": "narrowest" 196 | }, 197 | { 198 | "subject": "near", 199 | "object": "nearest" 200 | }, 201 | { 202 | "subject": "new", 203 | "object": "newest" 204 | }, 205 | { 206 | "subject": "old", 207 | "object": "oldest" 208 | }, 209 | { 210 | "subject": "polite", 211 | "object": "politest" 212 | }, 213 | { 214 | "subject": "poor", 215 | "object": "poorest" 216 | }, 
217 | { 218 | "subject": "quick", 219 | "object": "quickest" 220 | }, 221 | { 222 | "subject": "quiet", 223 | "object": "quietest" 224 | }, 225 | { 226 | "subject": "rich", 227 | "object": "richest" 228 | }, 229 | { 230 | "subject": "sad", 231 | "object": "saddest" 232 | }, 233 | { 234 | "subject": "safe", 235 | "object": "safest" 236 | }, 237 | { 238 | "subject": "short", 239 | "object": "shortest" 240 | }, 241 | { 242 | "subject": "shy", 243 | "object": "shyest" 244 | }, 245 | { 246 | "subject": "simple", 247 | "object": "simplest" 248 | }, 249 | { 250 | "subject": "slow", 251 | "object": "slowest" 252 | }, 253 | { 254 | "subject": "small", 255 | "object": "smallest" 256 | }, 257 | { 258 | "subject": "smart", 259 | "object": "smartest" 260 | }, 261 | { 262 | "subject": "smooth", 263 | "object": "smoothest" 264 | }, 265 | { 266 | "subject": "strong", 267 | "object": "strongest" 268 | }, 269 | { 270 | "subject": "sweet", 271 | "object": "sweetest" 272 | }, 273 | { 274 | "subject": "tall", 275 | "object": "tallest" 276 | }, 277 | { 278 | "subject": "thick", 279 | "object": "thickest" 280 | }, 281 | { 282 | "subject": "thin", 283 | "object": "thinnest" 284 | }, 285 | { 286 | "subject": "tiny", 287 | "object": "tiniest" 288 | }, 289 | { 290 | "subject": "tough", 291 | "object": "toughest" 292 | }, 293 | { 294 | "subject": "true", 295 | "object": "truest" 296 | }, 297 | { 298 | "subject": "ugly", 299 | "object": "ugliest" 300 | }, 301 | { 302 | "subject": "weak", 303 | "object": "weakest" 304 | }, 305 | { 306 | "subject": "wet", 307 | "object": "wettest" 308 | }, 309 | { 310 | "subject": "wide", 311 | "object": "widest" 312 | }, 313 | { 314 | "subject": "wild", 315 | "object": "wildest" 316 | }, 317 | { 318 | "subject": "wise", 319 | "object": "wisest" 320 | }, 321 | { 322 | "subject": "witty", 323 | "object": "wittiest" 324 | }, 325 | { 326 | "subject": "wrong", 327 | "object": "wrongest" 328 | }, 329 | { 330 | "subject": "young", 331 | "object": "youngest" 332 | }, 333 | { 334 | "subject": "zesty", 335 | "object": "zestiest" 336 | } 337 | ] 338 | } -------------------------------------------------------------------------------- /data/linguistic/verb_past_tense.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "verb past tense", 3 | "prompt_templates": [ 4 | "The past tense of {} is" 5 | ], 6 | "prompt_templates_zs": [ 7 | "The past tense of {} is", 8 | "What is the past tense of {}? 
It is" 9 | ], 10 | "properties": { 11 | "relation_type": "linguistic", 12 | "domain_name": "verb", 13 | "range_name": "verb", 14 | "symmetric": false 15 | }, 16 | "samples": [ 17 | { 18 | "subject": "ask", 19 | "object": "asked" 20 | }, 21 | { 22 | "subject": "believe", 23 | "object": "believed" 24 | }, 25 | { 26 | "subject": "break", 27 | "object": "broke" 28 | }, 29 | { 30 | "subject": "bring", 31 | "object": "brought" 32 | }, 33 | { 34 | "subject": "build", 35 | "object": "built" 36 | }, 37 | { 38 | "subject": "call", 39 | "object": "called" 40 | }, 41 | { 42 | "subject": "catch", 43 | "object": "caught" 44 | }, 45 | { 46 | "subject": "change", 47 | "object": "changed" 48 | }, 49 | { 50 | "subject": "choose", 51 | "object": "chose" 52 | }, 53 | { 54 | "subject": "clean", 55 | "object": "cleaned" 56 | }, 57 | { 58 | "subject": "climb", 59 | "object": "climbed" 60 | }, 61 | { 62 | "subject": "close", 63 | "object": "closed" 64 | }, 65 | { 66 | "subject": "come", 67 | "object": "came" 68 | }, 69 | { 70 | "subject": "cry", 71 | "object": "cried" 72 | }, 73 | { 74 | "subject": "cut", 75 | "object": "cut" 76 | }, 77 | { 78 | "subject": "dance", 79 | "object": "danced" 80 | }, 81 | { 82 | "subject": "decide", 83 | "object": "decided" 84 | }, 85 | { 86 | "subject": "do", 87 | "object": "did" 88 | }, 89 | { 90 | "subject": "drive", 91 | "object": "drove" 92 | }, 93 | { 94 | "subject": "drink", 95 | "object": "drank" 96 | }, 97 | { 98 | "subject": "eat", 99 | "object": "ate" 100 | }, 101 | { 102 | "subject": "fall", 103 | "object": "fell" 104 | }, 105 | { 106 | "subject": "finish", 107 | "object": "finished" 108 | }, 109 | { 110 | "subject": "find", 111 | "object": "found" 112 | }, 113 | { 114 | "subject": "fly", 115 | "object": "flew" 116 | }, 117 | { 118 | "subject": "follow", 119 | "object": "followed" 120 | }, 121 | { 122 | "subject": "forget", 123 | "object": "forgot" 124 | }, 125 | { 126 | "subject": "frown", 127 | "object": "frowned" 128 | }, 129 | { 130 | "subject": "get", 131 | "object": "got" 132 | }, 133 | { 134 | "subject": "give", 135 | "object": "gave" 136 | }, 137 | { 138 | "subject": "go", 139 | "object": "went" 140 | }, 141 | { 142 | "subject": "hate", 143 | "object": "hated" 144 | }, 145 | { 146 | "subject": "have", 147 | "object": "had" 148 | }, 149 | { 150 | "subject": "hear", 151 | "object": "heard" 152 | }, 153 | { 154 | "subject": "help", 155 | "object": "helped" 156 | }, 157 | { 158 | "subject": "hit", 159 | "object": "hit" 160 | }, 161 | { 162 | "subject": "jump", 163 | "object": "jumped" 164 | }, 165 | { 166 | "subject": "know", 167 | "object": "knew" 168 | }, 169 | { 170 | "subject": "laugh", 171 | "object": "laughed" 172 | }, 173 | { 174 | "subject": "learn", 175 | "object": "learned" 176 | }, 177 | { 178 | "subject": "leave", 179 | "object": "left" 180 | }, 181 | { 182 | "subject": "like", 183 | "object": "liked" 184 | }, 185 | { 186 | "subject": "live", 187 | "object": "lived" 188 | }, 189 | { 190 | "subject": "look", 191 | "object": "looked" 192 | }, 193 | { 194 | "subject": "lose", 195 | "object": "lost" 196 | }, 197 | { 198 | "subject": "love", 199 | "object": "loved" 200 | }, 201 | { 202 | "subject": "make", 203 | "object": "made" 204 | }, 205 | { 206 | "subject": "meet", 207 | "object": "met" 208 | }, 209 | { 210 | "subject": "need", 211 | "object": "needed" 212 | }, 213 | { 214 | "subject": "open", 215 | "object": "opened" 216 | }, 217 | { 218 | "subject": "play", 219 | "object": "played" 220 | }, 221 | { 222 | "subject": "read", 223 | "object": "read" 224 | }, 
225 | { 226 | "subject": "remember", 227 | "object": "remembered" 228 | }, 229 | { 230 | "subject": "run", 231 | "object": "ran" 232 | }, 233 | { 234 | "subject": "say", 235 | "object": "said" 236 | }, 237 | { 238 | "subject": "see", 239 | "object": "saw" 240 | }, 241 | { 242 | "subject": "shake", 243 | "object": "shook" 244 | }, 245 | { 246 | "subject": "sit", 247 | "object": "sat" 248 | }, 249 | { 250 | "subject": "sleep", 251 | "object": "slept" 252 | }, 253 | { 254 | "subject": "smile", 255 | "object": "smiled" 256 | }, 257 | { 258 | "subject": "speak", 259 | "object": "spoke" 260 | }, 261 | { 262 | "subject": "start", 263 | "object": "started" 264 | }, 265 | { 266 | "subject": "stand", 267 | "object": "stood" 268 | }, 269 | { 270 | "subject": "study", 271 | "object": "studied" 272 | }, 273 | { 274 | "subject": "swim", 275 | "object": "swam" 276 | }, 277 | { 278 | "subject": "take", 279 | "object": "took" 280 | }, 281 | { 282 | "subject": "talk", 283 | "object": "talked" 284 | }, 285 | { 286 | "subject": "think", 287 | "object": "thought" 288 | }, 289 | { 290 | "subject": "try", 291 | "object": "tried" 292 | }, 293 | { 294 | "subject": "understand", 295 | "object": "understood" 296 | }, 297 | { 298 | "subject": "walk", 299 | "object": "walked" 300 | }, 301 | { 302 | "subject": "want", 303 | "object": "wanted" 304 | }, 305 | { 306 | "subject": "wear", 307 | "object": "wore" 308 | }, 309 | { 310 | "subject": "win", 311 | "object": "won" 312 | }, 313 | { 314 | "subject": "work", 315 | "object": "worked" 316 | }, 317 | { 318 | "subject": "write", 319 | "object": "wrote" 320 | } 321 | ] 322 | } -------------------------------------------------------------------------------- /eap/dataset.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Optional 3 | 4 | import pandas as pd 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader 8 | 9 | def collate_EAP(xs, task): 10 | clean, corrupted, labels = zip(*xs) 11 | clean = list(clean) 12 | corrupted = list(corrupted) 13 | if 'hypernymy' not in task: 14 | labels = torch.tensor(labels) 15 | return clean, corrupted, labels 16 | 17 | class EAPDataset(Dataset): 18 | def __init__(self, task:str, filename:Optional[str]=None): 19 | self.df = pd.read_csv(filename) 20 | self.task = task 21 | 22 | def __len__(self): 23 | return len(self.df) 24 | 25 | def shuffle(self): 26 | self.df = self.df.sample(frac=1) 27 | 28 | def head(self, n: int): 29 | self.df = self.df.head(n) 30 | 31 | def __getitem__(self, index): 32 | row = self.df.iloc[index] 33 | label = None 34 | if self.task == 'ioi': 35 | label = [row['correct_idx'], row['incorrect_idx']] 36 | elif 'greater-than' in self.task: 37 | label = row['correct_idx'] 38 | elif 'hypernymy' in self.task: 39 | answer = torch.tensor(eval(row['answers_idx'])) 40 | corrupted_answer = torch.tensor(eval(row['corrupted_answers_idx'])) 41 | label = [answer, corrupted_answer] 42 | elif 'fact-retrieval' in self.task: 43 | label = [row['country_idx'], row['corrupted_country_idx']] 44 | elif 'gender' in self.task: 45 | label = [row['clean_answer_idx'], row['corrupted_answer_idx']] 46 | elif self.task == 'sva': 47 | label = row['plural'] 48 | elif self.task == 'colored-objects': 49 | label = [row['correct_idx'], row['incorrect_idx']] 50 | elif self.task in {'dummy-easy', 'dummy-medium', 'dummy-hard'}: 51 | label = 0 52 | else: 53 | raise ValueError(f'Got invalid task: {self.task}') 54 | return 
row['clean'], row['corrupted'], label 55 | 56 | def to_dataloader(self, batch_size: int): 57 | return DataLoader(self, batch_size=batch_size, collate_fn=partial(collate_EAP, task=self.task)) -------------------------------------------------------------------------------- /eap/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Union, Literal, Tuple 2 | from functools import partial 3 | 4 | import pandas as pd 5 | import torch 6 | from torch.nn.functional import kl_div 7 | from transformers import PreTrainedTokenizer 8 | from transformer_lens import HookedTransformer 9 | 10 | def get_metric(metric_name: str, task: str, tokenizer:Optional[PreTrainedTokenizer]=None, model: Optional[HookedTransformer]=None): 11 | if metric_name == 'kl_divergence' or metric_name == 'kl': 12 | return partial(divergence, divergence_type='kl') 13 | elif metric_name == 'js_divergence' or metric_name == 'js': 14 | return partial(divergence, divergence_type='js') 15 | elif metric_name == 'logit_diff' or metric_name == 'prob_diff': 16 | prob = (metric_name == 'prob_diff') 17 | if 'greater-than' in task: 18 | if tokenizer is None: 19 | if model is None: 20 | raise ValueError("Either tokenizer or model must be set for greater-than and prob / logit diff") 21 | else: 22 | tokenizer = model.tokenizer 23 | logit_diff_fn = get_logit_diff_greater_than(tokenizer) 24 | elif 'hypernymy' in task: 25 | logit_diff_fn = logit_diff_hypernymy 26 | elif task == 'sva': 27 | if model is None: 28 | raise ValueError("model must be set for sva and prob / logit diff") 29 | logit_diff_fn = get_logit_diff_sva(model) 30 | else: 31 | logit_diff_fn = logit_diff 32 | return partial(logit_diff_fn, prob=prob) 33 | else: 34 | raise ValueError(f"got bad metric_name: {metric_name}") 35 | 36 | def get_logit_positions(logits: torch.Tensor, input_length: torch.Tensor): 37 | batch_size = logits.size(0) 38 | idx = torch.arange(batch_size, device=logits.device) 39 | 40 | logits = logits[idx, input_length - 1] 41 | return logits 42 | 43 | def js_div(p: torch.tensor, q: torch.tensor): 44 | p, q = p.view(-1, p.size(-1)), q.view(-1, q.size(-1)) 45 | m = (0.5 * (p + q)).log() 46 | return 0.5 * (kl_div(m, p.log(), log_target=True, reduction='none').mean(-1) + kl_div(m, q.log(), log_target=True, reduction='none').mean(-1)) 47 | 48 | def divergence(logits: torch.Tensor, clean_logits: torch.Tensor, input_length: torch.Tensor, labels: torch.Tensor, divergence_type: Union[Literal['kl'], Literal['js']]='kl', mean=True, loss=True): 49 | logits = get_logit_positions(logits, input_length) 50 | clean_logits = get_logit_positions(clean_logits, input_length) 51 | 52 | probs = torch.softmax(logits, dim=-1) 53 | clean_probs = torch.softmax(clean_logits, dim=-1) 54 | 55 | if divergence_type == 'kl': 56 | results = kl_div(probs.log(), clean_probs.log(), log_target=True, reduction='none').mean(-1) 57 | elif divergence_type == 'js': 58 | results = js_div(probs, clean_probs) 59 | else: 60 | raise ValueError(f"Expected divergence_type of 'kl' or 'js', but got '{divergence_type}'") 61 | return results.mean() if mean else results 62 | 63 | def logit_diff(clean_logits: torch.Tensor, corrupted_logits: torch.Tensor, input_length: torch.Tensor, labels: torch.Tensor, mean=True, prob=False, loss=False): 64 | clean_logits = get_logit_positions(clean_logits, input_length) 65 | cleans = torch.softmax(clean_logits, dim=-1) if prob else clean_logits 66 | good_bad = torch.gather(cleans, -1, labels.to(cleans.device)) 67 
| results = good_bad[:, 0] - good_bad[:, 1] 68 | 69 | if loss: 70 | # remember it's reversed to make it a loss 71 | results = -results 72 | if mean: 73 | results = results.mean() 74 | return results 75 | 76 | def direct_logit(clean_logits: torch.Tensor, corrupted_logits: torch.Tensor, input_length: torch.Tensor, labels: torch.Tensor, mean=True, prob=False, loss=False): 77 | clean_logits = get_logit_positions(clean_logits, input_length) 78 | cleans = torch.softmax(clean_logits, dim=-1) if prob else clean_logits 79 | good_bad = torch.gather(cleans, -1, labels.to(cleans.device)) 80 | results = good_bad[:, 0] 81 | 82 | if loss: 83 | # remember it's reversed to make it a loss 84 | results = -results 85 | if mean: 86 | results = results.mean() 87 | return results 88 | 89 | def logit_diff_hypernymy(clean_logits: torch.Tensor, corrupted_logits: torch.Tensor, input_length: torch.Tensor, labels: List[torch.Tensor], mean=True, prob=False, loss=False): 90 | clean_logits = get_logit_positions(clean_logits, input_length) 91 | cleans = torch.softmax(clean_logits, dim=-1) if prob else clean_logits 92 | 93 | results = [] 94 | for i, (ls,corrupted_ls) in enumerate(labels): 95 | r = cleans[i][ls.to(cleans.device)].sum() - cleans[i][corrupted_ls.to(cleans.device)].sum() 96 | results.append(r) 97 | results = torch.stack(results) 98 | 99 | if loss: 100 | # remember it's reversed to make it a loss 101 | results = -results 102 | if mean: 103 | results = results.mean() 104 | return results 105 | 106 | def get_year_indices(tokenizer: PreTrainedTokenizer): 107 | return torch.tensor([tokenizer(f'{year:02d}').input_ids[0] for year in range(100)]) 108 | 109 | 110 | def get_logit_diff_greater_than(tokenizer: PreTrainedTokenizer): 111 | year_indices = get_year_indices(tokenizer) 112 | def logit_diff_greater_than(clean_logits: torch.Tensor, corrupted_logits: torch.Tensor, input_length: torch.Tensor, labels: torch.Tensor, mean=True, prob=False, loss=False): 113 | # Prob diff (negative, since it's a loss) 114 | clean_logits = get_logit_positions(clean_logits, input_length) 115 | cleans = torch.softmax(clean_logits, dim=-1) if prob else clean_logits 116 | cleans = cleans[:, year_indices] 117 | 118 | results = [] 119 | if prob: 120 | for prob, year in zip(cleans, labels): 121 | results.append(prob[year + 1 :].sum() - prob[: year + 1].sum()) 122 | else: 123 | for logit, year in zip(cleans, labels): 124 | results.append(logit[year + 1 :].mean() - logit[: year + 1].mean()) 125 | 126 | results = torch.stack(results) 127 | if loss: 128 | results = -results 129 | if mean: 130 | results = results.mean() 131 | return results 132 | return logit_diff_greater_than 133 | 134 | def get_singular_and_plural(model, strict=False) -> Tuple[torch.Tensor, torch.Tensor]: 135 | tokenizer = model.tokenizer 136 | tokenizer_length = model.cfg.d_vocab_out 137 | 138 | df: pd.DataFrame = pd.read_csv('data/sva/combined_verb_list.csv') 139 | singular = df['sing'].to_list() 140 | plural = df['plur'].to_list() 141 | singular_set = set(singular) 142 | plural_set = set(plural) 143 | verb_set = singular_set | plural_set 144 | assert len(singular_set & plural_set) == 0, f"{singular_set & plural_set}" 145 | singular_indices, plural_indices = [], [] 146 | 147 | for i in range(tokenizer_length): 148 | token = tokenizer._convert_id_to_token(i) 149 | if token is not None: 150 | if token[0] == 'Ġ': 151 | token = token[1:] 152 | if token in verb_set: 153 | if token in singular_set: 154 | singular_indices.append(i) 155 | else: # token in plural_set: 156 | idx = 
plural.index(token) 157 | third_person_present = singular[idx] 158 | third_person_present_tokenized = tokenizer(f' {third_person_present}', add_special_tokens=False)['input_ids'] 159 | if len(third_person_present_tokenized) == 1 and third_person_present_tokenized[0] != tokenizer.unk_token_id: 160 | plural_indices.append(i) 161 | elif not strict: 162 | plural_indices.append(i) 163 | 164 | return torch.tensor(singular_indices, device=model.cfg.device), torch.tensor(plural_indices, device=model.cfg.device) 165 | 166 | def get_logit_diff_sva(model, strict=True) -> torch.Tensor: 167 | singular_indices, plural_indices = get_singular_and_plural(model, strict=strict) 168 | def sva_logit_diff(clean_logits: torch.Tensor, corrupted_logits: torch.Tensor, input_length: torch.Tensor, labels: torch.Tensor, mean=True, prob=False, loss=False): 169 | clean_logits = get_logit_positions(clean_logits, input_length) 170 | cleans = torch.softmax(clean_logits, dim=-1) if prob else clean_logits 171 | 172 | if prob: 173 | singular = cleans[:, singular_indices].sum(-1) 174 | plural = cleans[:, plural_indices].sum(-1) 175 | else: 176 | singular = cleans[:, singular_indices].mean(-1) 177 | plural = cleans[:, plural_indices].mean(-1) 178 | 179 | results = torch.where(labels.to(cleans.device) == 0, singular - plural, plural - singular) 180 | if loss: 181 | results = -results 182 | if mean: 183 | results = results.mean() 184 | return results 185 | return sva_logit_diff -------------------------------------------------------------------------------- /eap/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | def model2family(model_name: str): 9 | if 'gpt2' in model_name: 10 | return 'gpt2' 11 | elif 'pythia' in model_name: 12 | return 'pythia' 13 | else: 14 | raise ValueError(f"Couldn't find model family for model: {model_name}") 15 | 16 | def kl_div(logits, clean_logits, input_length, labels, mean=True): 17 | batch_size = logits.size(0) 18 | idx = torch.arange(batch_size, device=logits.device) 19 | 20 | logits = logits[idx, input_length - 1] 21 | clean_logits = clean_logits[idx, input_length - 1] 22 | 23 | logprobs = torch.log_softmax(logits, dim=-1) 24 | clean_logprobs = torch.log_softmax(clean_logits, dim=-1) 25 | 26 | results = torch.nn.functional.kl_div(logprobs, clean_logprobs, log_target=True, reduction='none') 27 | return results.mean() if mean else results 28 | 29 | def precision_at_k(clean_logits, corrupted_logits, input_length, labels, k=1, mean=True): 30 | batch_size = clean_logits.size(0) 31 | idx = torch.arange(batch_size, device=clean_logits.device) 32 | 33 | clean_logits = clean_logits[idx, input_length - 1] 34 | clean_probs = torch.softmax(clean_logits, dim=-1) 35 | predictions = torch.argmax(clean_probs, dim=-1).cpu() 36 | 37 | results = [] 38 | for i, (ls,_) in enumerate(labels): 39 | r = torch.sum((ls == predictions[i]).float()) 40 | results.append(r) 41 | results = torch.stack(results) 42 | return results.mean() if mean else results 43 | 44 | def prob_diff_hypernymy(clean_logits, corrupted_logits, input_length, labels, mean=True, loss=False, logits=False): 45 | batch_size = clean_logits.size(0) 46 | idx = torch.arange(batch_size, device=clean_logits.device) 47 | 48 | clean_logits = clean_logits[idx, input_length - 1] 49 | clean_probs = torch.softmax(clean_logits, dim=-1) 50 | 51 | if logits: 52 | clean_probs = clean_logits 53 | 
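# (Editor's note) The loop below computes, per example, the total probability the
# model assigns to the acceptable answer tokens (`ls`) minus the total probability
# assigned to the corrupted answers (`corrupted_ls`). With loss=True the sign is
# flipped so the quantity can be minimized; with mean=True it is averaged over the batch.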
54 | results = [] 55 | for i, (ls,corrupted_ls) in enumerate(labels): 56 | r = clean_probs[i][ls.to(clean_probs.device)].sum() - clean_probs[i][corrupted_ls.to(clean_probs.device)].sum() 57 | results.append(r) 58 | results = torch.stack(results) 59 | if loss: 60 | results = -results 61 | return results.mean() if mean else results 62 | 63 | def batch(iterable, n:int=1): 64 | current_batch = [] 65 | for item in iterable: 66 | current_batch.append(item) 67 | if len(current_batch) == n: 68 | yield current_batch 69 | current_batch = [] 70 | if current_batch: 71 | yield current_batch 72 | 73 | def get_singular_and_plural(model, strict=False) -> Tuple[torch.Tensor, torch.Tensor]: 74 | _TOKENIZER = model.tokenizer 75 | tokenizer_length = model.cfg.d_vocab_out 76 | 77 | df: pd.DataFrame = pd.read_csv('../data/sva/combined_verb_list.csv') 78 | singular = df['sing'].to_list() 79 | plural = df['plur'].to_list() 80 | singular_set = set(singular) 81 | plural_set = set(plural) 82 | verb_set = singular_set | plural_set 83 | assert len(singular_set & plural_set) == 0, f"{singular_set & plural_set}" 84 | singular_indices, plural_indices = [], [] 85 | 86 | for i in range(tokenizer_length): 87 | token = _TOKENIZER._convert_id_to_token(i) 88 | if token is not None: 89 | if token[0] == 'Ġ': 90 | token = token[1:] 91 | if token in verb_set: 92 | if token in singular_set: 93 | singular_indices.append(i) 94 | else: # token in plural_set: 95 | idx = plural.index(token) 96 | third_person_present = singular[idx] 97 | third_person_present_tokenized = _TOKENIZER(f' {third_person_present}', add_special_tokens=False)['input_ids'] 98 | if len(third_person_present_tokenized) == 1 and third_person_present_tokenized[0] != _TOKENIZER.unk_token_id: 99 | plural_indices.append(i) 100 | elif not strict: 101 | plural_indices.append(i) 102 | 103 | return torch.tensor(singular_indices, device=model.cfg.device), torch.tensor(plural_indices, device=model.cfg.device) 104 | 105 | def get_sva_prob_diff(model, strict=True) -> torch.Tensor: 106 | singular_indices, plural_indices = get_singular_and_plural(model, strict=strict) 107 | def sva_prob_diff(logits, clean_logits, input_length, labels, loss=False, mean=True): 108 | batch_size = clean_logits.size(0) 109 | idx = torch.arange(batch_size, device=clean_logits.device) 110 | probs = F.softmax(logits[idx, input_length-1], dim=-1) 111 | singular = probs[:, singular_indices].sum(-1) 112 | plural = probs[:, plural_indices].sum(-1) 113 | 114 | correct_form_prob_diff = torch.where(labels == 0, singular - plural, plural - singular) 115 | if loss: 116 | correct_form_prob_diff = - correct_form_prob_diff 117 | if mean: 118 | return correct_form_prob_diff.mean() 119 | else: 120 | return correct_form_prob_diff 121 | return sva_prob_diff 122 | 123 | def inflow_outflow_difference(g, absolute:bool=True): 124 | diffs = [] 125 | for name, node in g.nodes.items(): 126 | if 'logits' in name or 'input' in name: 127 | continue 128 | diff = sum(edge.score for edge in node.child_edges) - sum(edge.score for edge in node.parent_edges) 129 | if absolute: 130 | diff = abs(diff) 131 | diffs.append(diff) 132 | diffs = np.array(diff) 133 | logit_inflow = sum(edge.score for edge in g.logits[0].parent_edges) 134 | input_outflow = sum(edge.score for edge in g.nodes['input'].child_edges) 135 | return diffs.mean(), logit_inflow, input_outflow -------------------------------------------------------------------------------- /eap/visualization.py: -------------------------------------------------------------------------------- 
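(Editor's note) The module below only supplies color helpers for rendering attribution graphs: a fixed hex palette for q/k/v edge types plus a colormap-backed random color. A small usage sketch of the two public helpers defined below; the 'Pastel2' colormap name and running from the repository root are assumptions of this sketch.

```python
from eap.visualization import EDGE_TYPE_COLORS, generate_random_color

print(EDGE_TYPE_COLORS['q'])             # fixed hex color used for query-side edges
print(generate_random_color('Pastel2'))  # random hex color sampled from a matplotlib colormap
```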
1 | import sys 2 | import functools 3 | 4 | import numpy as np 5 | import matplotlib 6 | import matplotlib.cm 7 | 8 | 9 | 10 | EDGE_TYPE_COLORS = { 11 | 'q': "#FF00FF", # Purple 12 | 'k': "#00FF00", # Green 13 | 'v': "#0000FF", # Blue 14 | None: "#000000", # Black 15 | } 16 | 17 | def generate_random_color(colorscheme: str) -> str: 18 | """ 19 | https://stackoverflow.com/questions/28999287/generate-random-colors-rgb 20 | """ 21 | 22 | def rgb2hex(rgb): 23 | """ 24 | https://stackoverflow.com/questions/3380726/converting-an-rgb-color-tuple-to-a-hexidecimal-string 25 | """ 26 | return "#{:02x}{:02x}{:02x}".format(rgb[0], rgb[1], rgb[2]) 27 | 28 | return rgb2hex(color(colorscheme, np.random.randint(0, 256), rgb_order=True)) 29 | 30 | # ripped from cmapy since it doesn't play nice with new versions of matplotlib 31 | def cmap(cmap_name, rgb_order=False): 32 | """ 33 | Extract colormap color information as a LUT compatible with cv2.applyColormap(). 34 | Default channel order is BGR. 35 | 36 | Args: 37 | cmap_name: string, name of the colormap. 38 | rgb_order: boolean, if false or not set, the returned array will be in 39 | BGR order (standard OpenCV format). If true, the order 40 | will be RGB. 41 | 42 | Returns: 43 | A numpy array of type uint8 containing the colormap. 44 | """ 45 | 46 | c_map = matplotlib.colormaps.get_cmap(cmap_name) 47 | rgba_data = matplotlib.cm.ScalarMappable(cmap=c_map).to_rgba( 48 | np.arange(0, 1.0, 1.0 / 256.0), bytes=True 49 | ) 50 | rgba_data = rgba_data[:, 0:-1].reshape((256, 1, 3)) 51 | 52 | # Convert to BGR (or RGB), uint8, for OpenCV. 53 | cmap = np.zeros((256, 1, 3), np.uint8) 54 | 55 | if not rgb_order: 56 | cmap[:, :, :] = rgba_data[:, :, ::-1] 57 | else: 58 | cmap[:, :, :] = rgba_data[:, :, :] 59 | 60 | return cmap 61 | 62 | 63 | # If python 3, redefine cmap() to use lru_cache. 64 | if sys.version_info > (3, 0): 65 | cmap = functools.lru_cache(maxsize=200)(cmap) 66 | 67 | 68 | def color(cmap_name, index, rgb_order=False): 69 | """Returns a color of a given colormap as a list of 3 BGR or RGB values. 70 | 71 | Args: 72 | cmap_name: string, name of the colormap. 73 | index: floating point between 0 and 1 or integer between 0 and 255, 74 | index of the requested color. 75 | rgb_order: boolean, if false or not set, the returned list will be in 76 | BGR order (standard OpenCV format). If true, the order 77 | will be RGB. 78 | 79 | Returns: 80 | List of RGB or BGR values. 81 | """ 82 | 83 | # Float values: scale from 0-1 to 0-255. 84 | if isinstance(index, float): 85 | val = round(min(max(index, 0.0), 1.0) * 255) 86 | else: 87 | val = min(max(index, 0), 255) 88 | 89 | # Get colormap and extract color. 
90 | colormap = cmap(cmap_name, rgb_order) 91 | return colormap[int(val), 0, :].tolist() -------------------------------------------------------------------------------- /knowledge_eap.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "from transformer_lens import HookedTransformer\n", 11 | "from functools import partial\n", 12 | "import torch.nn.functional as F\n", 13 | "from eap.metrics import logit_diff, direct_logit\n", 14 | "import transformer_lens.utils as utils\n", 15 | "from eap.graph import Graph\n", 16 | "from eap.dataset import EAPDataset\n", 17 | "from eap.attribute import attribute\n", 18 | "import time\n", 19 | "from rich import print as rprint\n", 20 | "import pandas as pd\n", 21 | "from eap.evaluate import evaluate_graph, evaluate_baseline,get_circuit_logits" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "LLAMA_2_7B_CHAT_PATH = \"meta-llama/Llama-2-7b-chat-hf\"\n", 31 | "from transformers import LlamaForCausalLM\n", 32 | "model = HookedTransformer.from_pretrained(LLAMA_2_7B_CHAT_PATH, device=\"cuda\", fold_ln=False, center_writing_weights=False, center_unembed=False)\n", 33 | "model.cfg.use_split_qkv_input = True\n", 34 | "model.cfg.use_attn_result = True\n", 35 | "model.cfg.use_hook_mlp_in = True" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 5, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "clean_subject = 'Eiffel Tower'\n", 45 | "corrupted_subject = 'the Great Walls'\n", 46 | "clean = f'The official currency of the country where {clean_subject} is loacted in is the'\n", 47 | "corrupted = f'The official currency of the country where {corrupted_subject} is loacted in is the'\n", 48 | "assert len(model.to_str_tokens(clean.format(clean_subject))) == len(model.to_str_tokens(corrupted.format(corrupted_subject)))\n", 49 | "labels = ['Euro','Chinese']\n", 50 | "country_idx = model.tokenizer(labels[0],add_special_tokens=False).input_ids[0]\n", 51 | "corrupted_country_idx = model.tokenizer(labels[1],add_special_tokens=False).input_ids[0]\n", 52 | "# dataset = {k:[] for k in ['clean','country_idx', 'corrupted', 'corrupted_country_idx']}\n", 53 | "# for k, v in zip(['clean', 'country_idx', 'corrupted', 'corrupted_country_idx'], [clean, country_idx, corrupted, corrupted_country_idx]):\n", 54 | "# dataset[k].append(v)\n", 55 | "# df2 = pd.DataFrame.from_dict(dataset)\n", 56 | "# df2.to_csv(f'capital_city.csv', index=False)" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 6, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "label = [[country_idx, corrupted_country_idx]]\n", 66 | "label = torch.tensor(label)\n", 67 | "data = ([clean],[corrupted],label)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 5, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "# ds = EAPDataset(filename='capital_city.csv',task='fact-retrieval')\n", 77 | "# dataloader = ds.to_dataloader(1)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 12, 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "name": "stderr", 87 | "output_type": "stream", 88 | "text": [ 89 | "100%|██████████| 1592881/1592881 [00:01<00:00, 1062625.82it/s]\n" 90 | ] 91 | }, 92 | { 93 | "name": "stdout", 94 | 
"output_type": "stream", 95 | "text": [ 96 | "程序执行时间:43.55915355682373秒\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "g = Graph.from_model(model)\n", 102 | "start_time = time.time()\n", 103 | "# Attribute using the model, graph, clean / corrupted data and labels, as well as a metric\n", 104 | "attribute(model, g, data, partial(logit_diff, loss=True, mean=True), method='EAP-IG-case', ig_steps=100)\n", 105 | "# attribute(model, g, data, partial(direct_logit, loss=True, mean=True), method='EAP-IG-case', ig_steps=30)\n", 106 | "# attribute(model, g, dataloader, partial(logit_diff, loss=True, mean=True), method='EAP-IG', ig_steps=30)\n", 107 | "g.apply_topn(5000, absolute=True)\n", 108 | "g.prune_dead_nodes()\n", 109 | "\n", 110 | "g.to_json('graph.json')\n", 111 | "\n", 112 | "gz = g.to_graphviz()\n", 113 | "gz.draw(f'graph.png', prog='dot')\n", 114 | "\n", 115 | "end_time = time.time()\n", 116 | "execution_time = end_time - start_time\n", 117 | "print(f\"程序执行时间:{execution_time}秒\")" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "def get_component_logits(logits, model, answer_token, top_k=10):\n", 127 | " logits = utils.remove_batch_dim(logits)\n", 128 | " # print(heads_out[head_name].shape)\n", 129 | " probs = logits.softmax(dim=-1)\n", 130 | " token_probs = probs[-1]\n", 131 | " answer_str_token = model.to_string(answer_token)\n", 132 | " sorted_token_probs, sorted_token_values = token_probs.sort(descending=True)\n", 133 | " # Janky way to get the index of the token in the sorted list - I couldn't find a better way?\n", 134 | " correct_rank = torch.arange(len(sorted_token_values))[\n", 135 | " (sorted_token_values == answer_token).cpu()\n", 136 | " ].item()\n", 137 | " # answer_ranks = []\n", 138 | " # answer_ranks.append((answer_str_token, correct_rank))\n", 139 | " # String formatting syntax - the first number gives the number of characters to pad to, the second number gives the number of decimal places.\n", 140 | " # rprint gives rich text printing\n", 141 | " rprint(\n", 142 | " f\"Performance on answer token:\\n[b]Rank: {correct_rank: <8} Logit: {logits[-1, answer_token].item():5.2f} Prob: {token_probs[answer_token].item():6.2%} Token: |{answer_str_token}|[/b]\"\n", 143 | " )\n", 144 | " for i in range(top_k):\n", 145 | " print(\n", 146 | " f\"Top {i}th token. Logit: {logits[-1, sorted_token_values[i]].item():5.2f} Prob: {sorted_token_probs[i].item():6.2%} Token: |{model.to_string(sorted_token_values[i])}|\"\n", 147 | " )\n", 148 | " # rprint(f\"[b]Ranks of the answer tokens:[/b] {answer_ranks}\")" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 13, 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "data": { 158 | "text/html": [ 159 | "
Performance on answer token:\n",
160 |        "Rank: 0        Logit: 16.94 Prob: 56.56% Token: |Euro|\n",
161 |        "
\n" 162 | ], 163 | "text/plain": [ 164 | "Performance on answer token:\n", 165 | "\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m16.94\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m56.56\u001b[0m\u001b[1m% Token: |Euro|\u001b[0m\n" 166 | ] 167 | }, 168 | "metadata": {}, 169 | "output_type": "display_data" 170 | }, 171 | { 172 | "name": "stdout", 173 | "output_type": "stream", 174 | "text": [ 175 | "Top 0th token. Logit: 16.94 Prob: 56.56% Token: |Euro|\n", 176 | "Top 1th token. Logit: 15.96 Prob: 21.39% Token: |French|\n", 177 | "Top 2th token. Logit: 14.06 Prob: 3.18% Token: |_|\n", 178 | "Top 3th token. Logit: 13.95 Prob: 2.85% Token: |euro|\n", 179 | "Top 4th token. Logit: 13.91 Prob: 2.74% Token: |Eu|\n" 180 | ] 181 | } 182 | ], 183 | "source": [ 184 | "logits = get_circuit_logits(model, g, data)\n", 185 | "get_component_logits(logits, model, answer_token=model.to_tokens('Euro',prepend_bos=False)[0], top_k=5)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 69, 191 | "metadata": {}, 192 | "outputs": [ 193 | { 194 | "name": "stderr", 195 | "output_type": "stream", 196 | "text": [ 197 | "100%|██████████| 1/1 [00:00<00:00, 8.79it/s]\n", 198 | "100%|██████████| 1/1 [00:00<00:00, 8.91it/s]\n", 199 | "100%|██████████| 1/1 [00:00<00:00, 6.82it/s]" 200 | ] 201 | }, 202 | { 203 | "name": "stdout", 204 | "output_type": "stream", 205 | "text": [ 206 | "Original performance was 10.043922424316406; the circuit's performance is 6.337347984313965\n" 207 | ] 208 | }, 209 | { 210 | "name": "stderr", 211 | "output_type": "stream", 212 | "text": [ 213 | "\n" 214 | ] 215 | } 216 | ], 217 | "source": [ 218 | "baseline = evaluate_baseline(model, dataloader, partial(logit_diff, loss=False, mean=False)).mean().item()\n", 219 | "results = evaluate_graph(model, g, dataloader, partial(logit_diff, loss=False, mean=False)).mean().item()\n", 220 | "print(f\"Original performance was {baseline}; the circuit's performance is {results}\")" 221 | ] 222 | } 223 | ], 224 | "metadata": { 225 | "kernelspec": { 226 | "display_name": "eap", 227 | "language": "python", 228 | "name": "python3" 229 | }, 230 | "language_info": { 231 | "codemirror_mode": { 232 | "name": "ipython", 233 | "version": 3 234 | }, 235 | "file_extension": ".py", 236 | "mimetype": "text/x-python", 237 | "name": "python", 238 | "nbconvert_exporter": "python", 239 | "pygments_lexer": "ipython3", 240 | "version": "3.11.9" 241 | } 242 | }, 243 | "nbformat": 4, 244 | "nbformat_minor": 2 245 | } 246 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | huggingface_hub 3 | pandas 4 | datasets 5 | jaxtyping 6 | rich 7 | torch==1.13.1 8 | better_abc 9 | fancy_einsum 10 | matplotlib==3.8.4 11 | transformers>=4.37.2 12 | einops==0.7.0 13 | plotly==5.21.0 14 | wandb==0.16.6 15 | pytest==8.1.1 16 | dataclasses_json==0.6.4 17 | networkx==3.3 18 | torchtyping==0.1.4 19 | cmapy==0.6.6 20 | circuitsvis==1.43.2 21 | pygraphviz #conda install --channel conda-forge pygraphviz -------------------------------------------------------------------------------- /transformer_lens/SVDInterpreter.py: -------------------------------------------------------------------------------- 1 | """SVD Interpreter. 2 | 3 | Module for getting the singular vectors of the OV, w_in, and w_out matrices of a 4 | :class:`transformer_lens.HookedTransformer`. 
5 | """ 6 | 7 | from typing import Optional, Union 8 | 9 | import fancy_einsum as einsum 10 | import torch 11 | from typeguard import typechecked 12 | from typing_extensions import Literal 13 | 14 | from transformer_lens.FactoredMatrix import FactoredMatrix 15 | from transformer_lens.HookedTransformer import HookedTransformer 16 | 17 | OUTPUT_EMBEDDING = "unembed.W_U" 18 | VECTOR_TYPES = ["OV", "w_in", "w_out"] 19 | 20 | 21 | class SVDInterpreter: 22 | def __init__(self, model: HookedTransformer): 23 | self.model = model 24 | self.cfg = model.cfg 25 | self.params = {name: param for name, param in model.named_parameters()} 26 | 27 | @typechecked 28 | def get_singular_vectors( 29 | self, 30 | vector_type: Union[Literal["OV"], Literal["w_in"], Literal["w_out"]], 31 | layer_index: int, 32 | num_vectors: int = 10, 33 | head_index: Optional[int] = None, 34 | ) -> torch.Tensor: 35 | """Gets the singular vectors for a given vector type, layer, and optionally head. 36 | 37 | This tensor can then be plotted using Neel's PySvelte, as demonstrated in the demo for this 38 | feature. The demo also points out some "gotchas" in this feature - numerical instability 39 | means inconsistency across devices, and the default HookedTransformer parameters don't 40 | replicate the original SVD post very well. So I'd recommend checking out the demo if you 41 | want to use this! 42 | 43 | Example: 44 | 45 | .. code-block:: python 46 | 47 | from transformer_lens import HookedTransformer, SVDInterpreter 48 | 49 | model = HookedTransformer.from_pretrained('gpt2-medium') 50 | svd_interpreter = SVDInterpreter(model) 51 | 52 | ov = svd_interpreter.get_singular_vectors('OV', layer_index=22, head_index=10) 53 | 54 | all_tokens = [model.to_str_tokens(np.array([i])) for i in range(model.cfg.d_vocab)] 55 | all_tokens = [all_tokens[i][0] for i in range(model.cfg.d_vocab)] 56 | 57 | def plot_matrix(matrix, tokens, k=10, filter="topk"): 58 | pysvelte.TopKTable( 59 | tokens=all_tokens, 60 | activations=matrix, 61 | obj_type="SVD direction", 62 | k=k, 63 | filter=filter 64 | ).show() 65 | 66 | plot_matrix(ov, all_tokens) 67 | 68 | Args: 69 | vector_type: Type of the vector: 70 | - "OV": Singular vectors of the OV matrix for a particular layer and head. 71 | - "w_in": Singular vectors of the w_in matrix for a particular layer. 72 | - "w_out": Singular vectors of the w_out matrix for a particular layer. 73 | layer_index: The index of the layer. 74 | num_vectors: Number of vectors. 75 | head_index: Index of the head. 
76 | """ 77 | 78 | if head_index is None: 79 | assert vector_type in [ 80 | "w_in", 81 | "w_out", 82 | ], f"Head index optional only for w_in and w_out, got {vector_type}" 83 | 84 | matrix: Union[FactoredMatrix, torch.Tensor] 85 | if vector_type == "OV": 86 | assert head_index is not None # keep mypy happy 87 | matrix = self._get_OV_matrix(layer_index, head_index) 88 | V = matrix.Vh.T 89 | 90 | elif vector_type == "w_in": 91 | matrix = self._get_w_in_matrix(layer_index) 92 | _, _, V = torch.linalg.svd(matrix) 93 | 94 | elif vector_type == "w_out": 95 | matrix = self._get_w_out_matrix(layer_index) 96 | _, _, V = torch.linalg.svd(matrix) 97 | 98 | else: 99 | raise ValueError(f"Vector type must be in {VECTOR_TYPES}, instead got {vector_type}") 100 | 101 | return self._get_singular_vectors_from_matrix(V, self.params[OUTPUT_EMBEDDING], num_vectors) 102 | 103 | def _get_singular_vectors_from_matrix( 104 | self, 105 | V: Union[torch.Tensor, FactoredMatrix], 106 | embedding: torch.Tensor, 107 | num_vectors: int = 10, 108 | ) -> torch.Tensor: 109 | """Returns the top num_vectors singular vectors from a matrix.""" 110 | 111 | vectors_list = [] 112 | for i in range(num_vectors): 113 | activations = V[i, :].float() @ embedding # type: ignore 114 | vectors_list.append(activations) 115 | 116 | vectors = torch.stack(vectors_list, dim=1).unsqueeze(1) 117 | assert vectors.shape == ( 118 | self.cfg.d_vocab, 119 | 1, 120 | num_vectors, 121 | ), f"Vectors shape should be {self.cfg.d_vocab, 1, num_vectors} but got {vectors.shape}" 122 | return vectors 123 | 124 | def _get_OV_matrix(self, layer_index: int, head_index: int) -> FactoredMatrix: 125 | """Gets the OV matrix for a particular layer and head.""" 126 | 127 | assert ( 128 | 0 <= layer_index < self.cfg.n_layers 129 | ), f"Layer index must be between 0 and {self.cfg.n_layers-1} but got {layer_index}" 130 | assert ( 131 | 0 <= head_index < self.cfg.n_heads 132 | ), f"Head index must be between 0 and {self.cfg.n_heads-1} but got {head_index}" 133 | 134 | W_V: torch.Tensor = self.params[f"blocks.{layer_index}.attn.W_V"] 135 | W_O: torch.Tensor = self.params[f"blocks.{layer_index}.attn.W_O"] 136 | W_V, W_O = W_V[head_index, :, :], W_O[head_index, :, :] 137 | 138 | return FactoredMatrix(W_V, W_O) 139 | 140 | def _get_w_in_matrix(self, layer_index: int) -> torch.Tensor: 141 | """Gets the w_in matrix for a particular layer.""" 142 | 143 | assert ( 144 | 0 <= layer_index < self.cfg.n_layers 145 | ), f"Layer index must be between 0 and {self.cfg.n_layers-1} but got {layer_index}" 146 | 147 | w_in = self.params[f"blocks.{layer_index}.mlp.W_in"].T 148 | 149 | if f"blocks.{layer_index}.ln2.w" in self.params: # If fold_ln == False 150 | ln_2 = self.params[f"blocks.{layer_index}.ln2.w"] 151 | return einsum.einsum("out in, in -> out in", w_in, ln_2) 152 | 153 | return w_in 154 | 155 | def _get_w_out_matrix(self, layer_index: int) -> torch.Tensor: 156 | """Gets the w_out matrix for a particular layer.""" 157 | 158 | assert ( 159 | 0 <= layer_index < self.cfg.n_layers 160 | ), f"Layer index must be between 0 and {self.cfg.n_layers-1} but got {layer_index}" 161 | 162 | return self.params[f"blocks.{layer_index}.mlp.W_out"] 163 | -------------------------------------------------------------------------------- /transformer_lens/__init__.py: -------------------------------------------------------------------------------- 1 | from . import hook_points 2 | from . import utils 3 | from . 
import evals 4 | from .past_key_value_caching import ( 5 | HookedTransformerKeyValueCache, 6 | HookedTransformerKeyValueCacheEntry, 7 | ) 8 | from . import components 9 | from .HookedTransformerConfig import HookedTransformerConfig 10 | from .FactoredMatrix import FactoredMatrix 11 | from .ActivationCache import ActivationCache 12 | from .HookedTransformer import HookedTransformer 13 | from .SVDInterpreter import SVDInterpreter 14 | from .HookedEncoder import HookedEncoder 15 | from . import head_detector 16 | from . import loading_from_pretrained as loading 17 | from . import patching 18 | from . import train 19 | 20 | from .past_key_value_caching import ( 21 | HookedTransformerKeyValueCache as EasyTransformerKeyValueCache, 22 | HookedTransformerKeyValueCacheEntry as EasyTransformerKeyValueCacheEntry, 23 | ) 24 | from .HookedTransformer import HookedTransformer as EasyTransformer 25 | from .HookedTransformerConfig import HookedTransformerConfig as EasyTransformerConfig 26 | -------------------------------------------------------------------------------- /transformer_lens/past_key_value_caching.py: -------------------------------------------------------------------------------- 1 | """Past Key Value Caching. 2 | 3 | This module contains the HookedTransformerKeyValueCache and HookedTransformerKeyValueCacheEntry 4 | classes, which are used to store past keys and values for the Transformer. This is important for 5 | generating text - we can cache a lot of past computation and avoid repeating ourselves! 6 | """ 7 | from dataclasses import dataclass 8 | from typing import List, Union 9 | 10 | import torch 11 | from jaxtyping import Float, Int 12 | 13 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 14 | from transformer_lens.utilities.devices import get_device_for_block_index 15 | 16 | 17 | @dataclass 18 | class HookedTransformerKeyValueCacheEntry: 19 | past_keys: Float[torch.Tensor, "batch pos_so_far n_heads d_head"] 20 | past_values: Float[torch.Tensor, "batch pos_so_far n_heads d_head"] 21 | frozen: bool = False 22 | 23 | @classmethod 24 | def init_cache_entry( 25 | cls, 26 | cfg: HookedTransformerConfig, 27 | device: Union[torch.device, str, None], 28 | batch_size: int = 1, 29 | ): 30 | n_heads = cfg.n_key_value_heads if cfg.n_key_value_heads is not None else cfg.n_heads 31 | return cls( 32 | past_keys=torch.empty( 33 | (batch_size, 0, n_heads, cfg.d_head), device=device, dtype=cfg.dtype 34 | ), 35 | past_values=torch.empty( 36 | (batch_size, 0, n_heads, cfg.d_head), device=device, dtype=cfg.dtype 37 | ), 38 | ) 39 | 40 | def append( 41 | self, 42 | new_keys: Float[torch.Tensor, "batch new_tokens n_heads d_head"], 43 | new_values: Float[torch.Tensor, "batch new_tokens n_heads d_head"], 44 | ): 45 | updated_keys: Float[ 46 | torch.Tensor, "batch pos_so_far_plus_new_tokens n_heads d_head" 47 | ] = torch.cat([self.past_keys, new_keys], dim=1) 48 | updated_values: Float[ 49 | torch.Tensor, "batch pos_so_far_plus_new_tokens n_heads d_head" 50 | ] = torch.cat([self.past_values, new_values], dim=1) 51 | if not self.frozen: 52 | self.past_keys = updated_keys 53 | self.past_values = updated_values 54 | return updated_keys, updated_values 55 | 56 | 57 | @dataclass 58 | class HookedTransformerKeyValueCache: 59 | """ 60 | A cache for storing past keys and values for the Transformer. This is important for generating text - we can cache a lot of past computation and avoid repeating ourselves! 
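(Editor's aside, not part of the upstream docstring) A minimal sketch of the prefix-reuse pattern this docstring describes; the model name and prompt strings are placeholders, and the sketch assumes the cache is passed to the forward call via past_kv_cache as elsewhere in the library.

```python
from transformer_lens import HookedTransformer, HookedTransformerKeyValueCache

model = HookedTransformer.from_pretrained("gpt2")  # placeholder model choice
cache = HookedTransformerKeyValueCache.init_cache(model.cfg, model.cfg.device, batch_size=1)

# Run the shared prefix once so the cache records its keys and values.
model("The Eiffel Tower is located in", past_kv_cache=cache)
cache.freeze()  # later passes reuse the prefix computation without growing the cache

# Each continuation now only pays for its own new tokens.
logits = model(" the city of", past_kv_cache=cache, prepend_bos=False)
```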
61 | 62 | This cache is a list of HookedTransformerKeyValueCacheEntry objects, one for each layer in the Transformer. Each object stores a [batch, pos_so_far, n_heads, d_head] tensor for both keys and values, and each entry has an append method to add a single new key and value. 63 | 64 | The cache can be frozen so that it is not updated during the forward pass. This is useful when we want to run many inputs with the same prefix. 65 | """ 66 | 67 | entries: List[HookedTransformerKeyValueCacheEntry] 68 | previous_attention_mask: Int[torch.Tensor, "batch pos_so_far"] 69 | frozen: bool = False 70 | 71 | @classmethod 72 | def init_cache( 73 | cls, 74 | cfg: HookedTransformerConfig, 75 | device: Union[torch.device, str, None], 76 | batch_size: int = 1, 77 | ): 78 | return cls( 79 | entries=[ 80 | HookedTransformerKeyValueCacheEntry.init_cache_entry( 81 | cfg, 82 | get_device_for_block_index(i, cfg, device), 83 | batch_size, 84 | ) 85 | for i in range(cfg.n_layers) 86 | ], 87 | previous_attention_mask=torch.empty( 88 | # This may actually be an int64, but type promotion will handle it: 89 | # See: https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc 90 | # See: https://github.com/pytorch/pytorch/issues/35014 91 | (batch_size, 0), 92 | device=device, 93 | dtype=torch.int, 94 | ), 95 | ) 96 | 97 | def freeze(self): 98 | self.frozen = True 99 | for entry in self.entries: 100 | entry.frozen = True 101 | 102 | def unfreeze(self): 103 | self.frozen = False 104 | for entry in self.entries: 105 | entry.frozen = False 106 | 107 | def append_attention_mask(self, attention_mask: Int[torch.Tensor, "batch new_tokens"]): 108 | attention_mask = attention_mask.to(self.previous_attention_mask.device) 109 | updated_attention_mask = torch.cat([self.previous_attention_mask, attention_mask], dim=-1) 110 | if not self.frozen: 111 | self.previous_attention_mask = updated_attention_mask 112 | return updated_attention_mask 113 | 114 | def __getitem__(self, idx): 115 | return self.entries[idx] 116 | -------------------------------------------------------------------------------- /transformer_lens/train.py: -------------------------------------------------------------------------------- 1 | """Train. 2 | 3 | Utilities for training :class:`transformer_lens.HookedTransformer` models on autoregressive language 4 | modeling tasks. 5 | """ 6 | 7 | from dataclasses import dataclass 8 | from typing import Optional 9 | 10 | import torch 11 | import torch.optim as optim 12 | import wandb 13 | from torch.optim import Optimizer 14 | from torch.utils.data import DataLoader, Dataset 15 | from tqdm.auto import tqdm 16 | 17 | from transformer_lens import utils 18 | from transformer_lens.HookedTransformer import HookedTransformer 19 | 20 | 21 | @dataclass 22 | class HookedTransformerTrainConfig: 23 | """ 24 | Configuration class to store training hyperparameters for a training run of 25 | an HookedTransformer model. 
26 | Args: 27 | num_epochs (int): Number of epochs to train for 28 | batch_size (int): Size of batches to use for training 29 | lr (float): Learning rate to use for training 30 | seed (int): Random seed to use for training 31 | momentum (float): Momentum to use for training 32 | max_grad_norm (float, *optional*): Maximum gradient norm to use for gradient clipping 33 | weight_decay (float, *optional*): Weight decay to use for training 34 | optimizer_name (str): The name of the optimizer to use 35 | device (str, *optional*): Device to use for training 36 | warmup_steps (int, *optional*): Number of warmup steps to use for training 37 | save_every (int, *optional*): After how many batches should a checkpoint be saved 38 | save_dir (str, *optional*): Where to save checkpoints 39 | wandb (bool): Whether to use Weights and Biases for logging 40 | wandb_project_name (str, *optional*): Name of the Weights and Biases project to use 41 | print_every (int, *optional*): Print the loss every n steps 42 | max_steps (int, *optional*): Terminate the epoch after this many steps. Used for debugging. 43 | """ 44 | 45 | num_epochs: int 46 | batch_size: int 47 | lr: float = 1e-3 48 | seed: int = 0 49 | momentum: float = 0.0 50 | max_grad_norm: Optional[float] = None 51 | weight_decay: Optional[float] = None 52 | optimizer_name: str = "Adam" 53 | device: Optional[str] = None 54 | warmup_steps: int = 0 55 | save_every: Optional[int] = None 56 | save_dir: Optional[str] = None 57 | wandb: bool = False 58 | wandb_project_name: Optional[str] = None 59 | print_every: Optional[int] = 50 60 | max_steps: Optional[int] = None 61 | 62 | 63 | def train( 64 | model: HookedTransformer, 65 | config: HookedTransformerTrainConfig, 66 | dataset: Dataset, 67 | ) -> HookedTransformer: 68 | """ 69 | Trains a HookedTransformer model on an autoregressive language modeling task. 70 | Args: 71 | model: The model to train 72 | config: The training configuration 73 | dataset: The dataset to train on - this function assumes the dataset is set up for autoregressive language modeling.
74 | Returns: 75 | The trained model 76 | """ 77 | torch.manual_seed(config.seed) 78 | model.train() 79 | if config.wandb: 80 | if config.wandb_project_name is None: 81 | config.wandb_project_name = "easy-transformer" 82 | wandb.init(project=config.wandb_project_name, config=vars(config)) 83 | 84 | if config.device is None: 85 | config.device = utils.get_device() 86 | 87 | optimizer: Optimizer 88 | if config.optimizer_name in ["Adam", "AdamW"]: 89 | # Weight decay in Adam is implemented badly, so use AdamW instead (see PyTorch AdamW docs) 90 | if config.weight_decay is not None: 91 | optimizer = optim.AdamW( 92 | model.parameters(), 93 | lr=config.lr, 94 | weight_decay=config.weight_decay, 95 | ) 96 | else: 97 | optimizer = optim.Adam( 98 | model.parameters(), 99 | lr=config.lr, 100 | ) 101 | elif config.optimizer_name == "SGD": 102 | optimizer = optim.SGD( 103 | model.parameters(), 104 | lr=config.lr, 105 | weight_decay=(config.weight_decay if config.weight_decay is not None else 0.0), 106 | momentum=config.momentum, 107 | ) 108 | else: 109 | raise ValueError(f"Optimizer {config.optimizer_name} not supported") 110 | 111 | scheduler = None 112 | if config.warmup_steps > 0: 113 | scheduler = optim.lr_scheduler.LambdaLR( 114 | optimizer, 115 | lr_lambda=lambda step: min(1.0, step / config.warmup_steps), 116 | ) 117 | 118 | dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True) 119 | 120 | model.to(config.device) 121 | 122 | for epoch in tqdm(range(1, config.num_epochs + 1)): 123 | samples = 0 124 | for step, batch in tqdm(enumerate(dataloader)): 125 | tokens = batch["tokens"].to(config.device) 126 | loss = model(tokens, return_type="loss") 127 | loss.backward() 128 | if config.max_grad_norm is not None: 129 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm) 130 | optimizer.step() 131 | if config.warmup_steps > 0: 132 | assert scheduler is not None 133 | scheduler.step() 134 | optimizer.zero_grad() 135 | 136 | samples += tokens.shape[0] 137 | 138 | if config.wandb: 139 | wandb.log({"train_loss": loss.item(), "samples": samples, "epoch": epoch}) 140 | 141 | if config.print_every is not None and step % config.print_every == 0: 142 | print(f"Epoch {epoch} Samples {samples} Step {step} Loss {loss.item()}") 143 | 144 | if ( 145 | config.save_every is not None 146 | and step % config.save_every == 0 147 | and config.save_dir is not None 148 | ): 149 | torch.save(model.state_dict(), f"{config.save_dir}/model_{step}.pt") 150 | 151 | if config.max_steps is not None and step >= config.max_steps: 152 | break 153 | 154 | return model 155 | -------------------------------------------------------------------------------- /transformer_lens/utilities/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/KnowledgeCircuits/bda3d22cf3b74a6b48c092f952e95b6414d8a9de/transformer_lens/utilities/__init__.py -------------------------------------------------------------------------------- /transformer_lens/utilities/devices.py: -------------------------------------------------------------------------------- 1 | """Devices. 2 | 3 | Utilities to get the correct device, and assist in distributing model layers across multiple 4 | devices. 
5 | """ 6 | from __future__ import annotations 7 | 8 | from typing import Optional, Union 9 | 10 | import torch 11 | from torch import nn 12 | 13 | import transformer_lens 14 | 15 | 16 | def get_device_for_block_index( 17 | index: int, 18 | cfg: "transformer_lens.HookedTransformerConfig", 19 | device: Optional[Union[torch.device, str]] = None, 20 | ): 21 | """ 22 | Determine the device for a given layer index based on the model configuration. 23 | 24 | This function assists in distributing model layers across multiple devices. The distribution 25 | is based on the configuration's number of layers (cfg.n_layers) and devices (cfg.n_devices). 26 | 27 | Args: 28 | index (int): Model layer index. 29 | cfg (HookedTransformerConfig): Model and device configuration. 30 | device (Optional[Union[torch.device, str]], optional): Initial device used for determining the target device. 31 | If not provided, the function uses the device specified in the configuration (cfg.device). 32 | 33 | Returns: 34 | torch.device: The device for the specified layer index. 35 | """ 36 | assert cfg.device is not None 37 | layers_per_device = cfg.n_layers // cfg.n_devices 38 | if device is None: 39 | device = cfg.device 40 | device = torch.device(device) 41 | if device.type == "cpu": 42 | return device 43 | device_index = (device.index or 0) + (index // layers_per_device) 44 | return torch.device(device.type, device_index) 45 | 46 | 47 | def move_to_and_update_config( 48 | model: Union["transformer_lens.HookedTransformer", "transformer_lens.HookedEncoder"], 49 | device_or_dtype: Union[torch.device, str, torch.dtype], 50 | print_details=True, 51 | ): 52 | """ 53 | Wrapper around `to` that also updates `model.cfg`. 54 | """ 55 | if isinstance(device_or_dtype, torch.device): 56 | model.cfg.device = device_or_dtype.type 57 | if print_details: 58 | print("Moving model to device: ", model.cfg.device) 59 | elif isinstance(device_or_dtype, str): 60 | model.cfg.device = device_or_dtype 61 | if print_details: 62 | print("Moving model to device: ", model.cfg.device) 63 | elif isinstance(device_or_dtype, torch.dtype): 64 | model.cfg.dtype = device_or_dtype 65 | if print_details: 66 | print("Changing model dtype to", device_or_dtype) 67 | # change state_dict dtypes 68 | for k, v in model.state_dict().items(): 69 | model.state_dict()[k] = v.to(device_or_dtype) 70 | return nn.Module.to(model, device_or_dtype) 71 | --------------------------------------------------------------------------------
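Usage sketch for transformer_lens/past_key_value_caching.py: a minimal, illustrative example of computing a shared prompt prefix once, freezing the cache, and then running several continuations against it. The model name and prompt strings are arbitrary placeholders, and the sketch assumes the HookedTransformer forward pass accepts a past_kv_cache argument, as in upstream TransformerLens.

    from transformer_lens import HookedTransformer
    from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache

    model = HookedTransformer.from_pretrained("gpt2")  # placeholder model choice

    # Run the shared prefix once, filling one cache entry per layer.
    prefix_tokens = model.to_tokens("The Eiffel Tower is located in")
    kv_cache = HookedTransformerKeyValueCache.init_cache(model.cfg, model.cfg.device, batch_size=1)
    model(prefix_tokens, past_kv_cache=kv_cache)

    # Freeze so later passes reuse the cached prefix without extending it.
    kv_cache.freeze()
    for continuation in [" the city of", " a country called"]:
        new_tokens = model.to_tokens(continuation, prepend_bos=False)
        logits = model(new_tokens, past_kv_cache=kv_cache)  # only the new tokens are passed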
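Usage sketch for transformer_lens/train.py: a minimal example of building a config and calling train(). The toy dataset is a stand-in, since the training loop only assumes each batch is a dict with a "tokens" tensor, and the model and hyperparameter choices are arbitrary.

    import torch
    from torch.utils.data import Dataset
    from transformer_lens import HookedTransformer
    from transformer_lens.train import HookedTransformerTrainConfig, train

    class ToyTokenDataset(Dataset):
        # Stand-in dataset: every item is {"tokens": LongTensor of shape [seq_len]}.
        def __init__(self, d_vocab: int, n_examples: int = 64, seq_len: int = 32):
            self.tokens = torch.randint(0, d_vocab, (n_examples, seq_len))

        def __len__(self):
            return len(self.tokens)

        def __getitem__(self, idx):
            return {"tokens": self.tokens[idx]}

    model = HookedTransformer.from_pretrained("gpt2")  # placeholder model choice
    config = HookedTransformerTrainConfig(
        num_epochs=1,
        batch_size=8,
        lr=1e-4,
        optimizer_name="AdamW",
        weight_decay=0.01,   # with weight decay set, the Adam branch falls through to AdamW
        max_grad_norm=1.0,   # enables gradient clipping in the loop
        print_every=10,
    )
    model = train(model, config, ToyTokenDataset(model.cfg.d_vocab))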
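Worked example for get_device_for_block_index in transformer_lens/utilities/devices.py: the layer and device counts below are hypothetical, and the snippet simply mirrors the integer division the function applies when the base device is a GPU.

    # Hypothetical 24-layer model spread over 2 CUDA devices (cfg.n_layers=24, cfg.n_devices=2).
    n_layers, n_devices = 24, 2
    layers_per_device = n_layers // n_devices  # 12, as computed in get_device_for_block_index
    for index in (0, 11, 12, 23):
        print(f"block {index} -> cuda:{index // layers_per_device}")
    # block 0  -> cuda:0
    # block 11 -> cuda:0
    # block 12 -> cuda:1
    # block 23 -> cuda:1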