├── .gitignore
├── LICENSE
├── README.md
├── acdc
│   ├── TLACDCCorrespondence.py
│   ├── TLACDCEdge.py
│   ├── TLACDCExperiment.py
│   ├── TLACDCInterpNode.py
│   ├── __init__.py
│   ├── acdc_graphics.py
│   ├── acdc_utils.py
│   ├── docstring
│   │   ├── __init__.py
│   │   ├── prompts.py
│   │   └── utils.py
│   ├── global_cache.py
│   ├── greaterthan
│   │   ├── __init__.py
│   │   └── utils.py
│   ├── induction
│   │   ├── __init__.py
│   │   └── utils.py
│   ├── ioi
│   │   ├── __init__.py
│   │   ├── ioi_dataset.py
│   │   └── utils.py
│   ├── knowledge
│   │   ├── __init__.py
│   │   ├── knowledge_dataset.py
│   │   └── utils.py
│   ├── logic_gates
│   │   ├── __init__.py
│   │   └── utils.py
│   ├── main.py
│   ├── run.sh
│   └── tracr_task
│       ├── __init__.py
│       └── utils.py
├── computed_circuits
│   └── gpt2-medium
│       ├── Factual Knowledge
│       │   ├── graph.gv
│       │   ├── graph.png
│       │   └── io.txt
│       ├── Hallucination
│       │   ├── attention_analysis.png
│       │   ├── graph.gv
│       │   ├── io.txt
│       │   └── specialNode.txt
│       └── In-Context Learning
│           ├── attention_analysis.png
│           ├── graph.gv
│           ├── io.txt
│           └── specialNode.txt
├── data
│   ├── bias
│   │   ├── characteristic_gender.json
│   │   ├── degree_gender.json
│   │   ├── name_birthplace.json
│   │   ├── name_gender.json
│   │   ├── name_religion.json
│   │   ├── occupation_age.json
│   │   └── occupation_gender.json
│   ├── commonsense
│   │   ├── fruit_inside_color.json
│   │   ├── fruit_outside_color.json
│   │   ├── object_superclass.json
│   │   ├── substance_phase.json
│   │   ├── task_done_by_person.json
│   │   ├── task_done_by_tool.json
│   │   ├── word_sentiment.json
│   │   └── work_location.json
│   ├── factual
│   │   ├── city_in_country.json
│   │   ├── company_ceo.json
│   │   ├── company_hq.json
│   │   ├── country_capital_city.json
│   │   ├── country_currency.json
│   │   ├── country_language.json
│   │   ├── country_largest_city.json
│   │   ├── food_from_country.json
│   │   ├── landmark_in_country.json
│   │   ├── landmark_on_continent.json
│   │   ├── person_band_lead_singer.json
│   │   ├── person_father.json
│   │   ├── person_mother.json
│   │   ├── person_native_language.json
│   │   ├── person_occupation.json
│   │   ├── person_plays_instrument.json
│   │   ├── person_plays_position_in_sport.json
│   │   ├── person_plays_pro_sport.json
│   │   ├── person_university.json
│   │   ├── pokemon_evolutions.json
│   │   ├── presidents_birth_year.json
│   │   ├── presidents_election_year.json
│   │   ├── product_by_company.json
│   │   ├── star_constellation.json
│   │   ├── superhero_archnemesis.json
│   │   └── superhero_person.json
│   └── linguistic
│       ├── adj_antonym.json
│       ├── adj_comparative.json
│       ├── adj_superlative.json
│       ├── verb_past_tense.json
│       ├── word_first_letter.json
│       └── word_last_letter.json
├── eap
│   ├── attribute.py
│   ├── dataset.py
│   ├── evaluate.py
│   ├── graph.py
│   ├── metrics.py
│   ├── utils.py
│   └── visualization.py
├── knowledge_eap.ipynb
├── notebook
│   └── component.ipynb
├── requirements.txt
└── transformer_lens
    ├── ActivationCache.py
    ├── FactoredMatrix.py
    ├── HookedEncoder.py
    ├── HookedTransformer.py
    ├── HookedTransformerConfig.py
    ├── SVDInterpreter.py
    ├── __init__.py
    ├── components.py
    ├── evals.py
    ├── head_detector.py
    ├── hook_points.py
    ├── loading_from_pretrained.py
    ├── past_key_value_caching.py
    ├── patching.py
    ├── train.py
    ├── utilities
    │   ├── __init__.py
    │   └── devices.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | hugging_cache
3 | data
4 | .vscode
5 | logs
6 |
7 | __pycache__/
8 | *.pyc
9 | *.pyo
10 | *.swp
11 | *~
12 | .DS_Store
13 | .env
14 | */__pycache__/
15 | *ims
16 | */*ims*/
17 | *.out
18 | *.pt
19 | *pkl
20 | */*results*/
21 | *results
22 | *output
23 | *factual_results
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 ZJUNLP
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Knowledge Circuits
2 | Knowledge Circuits in Pretrained Transformers
3 |
4 |
5 | 📄arXiv •
6 | 🌐Demo •
7 | Youtube •
8 | 𝕏 Blog
9 |
10 |
11 | [](https://github.com/zjunlp/KnowledgeCircuits)
12 | [](https://opensource.org/licenses/MIT)
13 | 
14 |
15 | ## 🔔News
16 |
17 | - [2025-02-16] We release our new paper [How Do LLMs Acquire New Knowledge? A Knowledge Circuits Perspective on Continual Pre-Training](https://arxiv.org/abs/2502.11196), analyzing the evolution of knowledge circuits throughout continual pre-training. Check it out:)
18 | - [2024-09-26] Our paper [Knowledge Circuits in Pretrained Transformers](https://arxiv.org/abs/2405.17969) is accepted by NeurIPS 2024!
19 | - [2024-05-28] We release our paper [Knowledge Circuits in Pretrained Transformers](https://arxiv.org/abs/2405.17969).
20 |
21 |
22 | ## Table of Contents
23 | - 🌟[Overview](#overview)
24 | - 🔧[Installation](#installation)
25 | - 📚[Get the circuit](#get-the-circuit)
26 | - 🧐[Analyze Component](#analyze-component)
27 | - 🌻[Acknowledgement](#acknowledgement)
28 | - 🚩[Citation](#citation)
29 |
30 | ---
31 |
32 |
33 | ## 🌟Overview
34 |
35 | This work aims to identify the circuits in pretrained language models that are responsible for specific pieces of knowledge and to analyze the behavior of these components.
36 | We provide a [demo](http://knowledgecircuits.zjukg.cn/) for browsing the discovered circuits.
37 | * A new method, [EAP-IG](https://arxiv.org/abs/2403.17806), is integrated in the `eap` folder. It takes less time than the ACDC method and can be used via `knowledge_eap.ipynb`. With the LLaMA2-7B-Chat model, running this notebook on a single GPU requires approximately 57,116M of GPU memory and takes 3-4 minutes.
38 |
39 | ## 🔧Installation
40 |
41 | The filtered data for each kind of model is available [here](https://pan.zju.edu.cn/share/7c613d16095c504605f83eba72). Please download it and put it in the `data` folder.
42 |
43 | Build the environment:
44 | ```
45 | conda create -n knowledgecircuit python=3.10
46 | pip install -r requirements.txt
47 | ```
48 | ❗️The code may fail under torch 2.x.x; we recommend torch 1.x.x.
49 |
50 | ## 📚Get the circuit
51 |
52 | Just run the following command:
53 | ```
54 | cd acdc
55 | sh run.sh
56 | ```
57 | Here is an example that computes the circuit for the `country_capital_city` relation in `GPT2-Medium`.
58 | ```
59 | MODEL_PATH=/path/to/the/model
60 | KT=factual
61 | KNOWLEDGE=country_capital_city
62 | NUM_EXAMPLES=20
63 | MODEL_NAME=gpt2-medium
64 |
65 | python main.py --task=knowledge \
66 | --zero-ablation \
67 | --threshold=0.01 \
68 | --device=cuda:0 \
69 | --metric=match_nll \
70 | --indices-mode=reverse \
71 | --first-cache-cpu=False \
72 | --second-cache-cpu=False \
73 | --max-num-epochs=10000 \
74 | --specific-knowledge=$KNOWLEDGE \
75 | --num-examples=$NUM_EXAMPLES \
76 | --relation-reverse=False \
77 | --knowledge-type=$KT \
78 | --model-name=$MODEL_NAME \
79 | --model-path=$MODEL_PATH
80 | ```
81 |
82 | The results will be saved in `acdc/factual_results/gpt2-medium`, and `final_graph.pdf` shows the computed circuit.
83 |
84 | ## 🧐Analyze component
85 |
86 | Run `component.ipynb` in the `notebook` folder.
87 |
88 | ## 🌻Acknowledgement
89 |
90 | We thank the projects [transformer_lens](https://github.com/TransformerLensOrg/TransformerLens), [ACDC](https://github.com/ArthurConmy/Automatic-Circuit-Discovery) and [LRE](https://lre.baulab.info/).
91 | The code in this work is built on top of these three projects.
92 |
93 |
94 | ## 🚩Citation
95 |
96 | Please cite our paper if you use Knowledge Circuits in your work. Thanks!
97 |
98 | ```bibtex
99 | @article{DBLP:journals/corr/abs-2405-17969,
100 | author = {Yunzhi Yao and
101 | Ningyu Zhang and
102 | Zekun Xi and
103 | Mengru Wang and
104 | Ziwen Xu and
105 | Shumin Deng and
106 | Huajun Chen},
107 | title = {Knowledge Circuits in Pretrained Transformers},
108 | journal = {CoRR},
109 | volume = {abs/2405.17969},
110 | year = {2024},
111 | url = {https://doi.org/10.48550/arXiv.2405.17969},
112 | doi = {10.48550/ARXIV.2405.17969},
113 | eprinttype = {arXiv},
114 | eprint = {2405.17969},
115 | timestamp = {Fri, 21 Jun 2024 22:39:09 +0200},
116 | biburl = {https://dblp.org/rec/journals/corr/abs-2405-17969.bib},
117 | bibsource = {dblp computer science bibliography, https://dblp.org}
118 | }
119 | ```
120 |
--------------------------------------------------------------------------------
/acdc/TLACDCEdge.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from collections import defaultdict
3 | from enum import Enum
4 | from typing import Optional, List
5 |
6 |
7 | class EdgeType(Enum):
8 | """
9 | Property of edges in the computational graph - either
10 |
11 | ADDITION: the child (hook_name, index) is a sum of the parent (hook_name, index)s
12 | DIRECT_COMPUTATION: the *single* child is a function of, and only of, the parent (e.g. the value hooked by hook_q is a function of what hook_q_input saves).
13 | PLACEHOLDER: like DIRECT_COMPUTATION, but where there are generally multiple parents. Here in ACDC we just include these edges by default when we find them. Explained below.
14 |
15 | Q: Why do we do this?
16 |
17 | There are two answers to this question: A1 is an interactive notebook, see this Colab notebook, which is in this repo at notebooks/implementation_demo.py. A2 is an answer that is written here below, but probably not as clear as A1 (though shorter).
18 |
19 | A2: We need something inside TransformerLens to represent the edges of a computational graph.
20 | The object we choose is pairs (hook_name, index). For example the output of Layer 11 Heads is a hook (blocks.11.attn.hook_result) and to specify the 3rd head we add the index [:, :, 3]. Then we can build a computational graph on these!
21 |
22 | However, when we do ACDC there turn out to be two conflicting things "removing edges" wants to do:
23 | i) for things in the residual stream, we want to remove the sum of the effects from previous hooks
24 | ii) for things that are not linear we want to *recompute* e.g the result inside the hook
25 | blocks.11.attn.hook_result from a corrupted Q and normal K and V
26 |
27 | The easiest way I thought of to reconcile these different cases, while also having a connected computational graph, is to have three types of edges: addition for the residual case, direct computation for easy cases where we can just replace hook_q with a cached value when we e.g. cut it off from hook_q_input, and placeholder to make the graph connected (when hook_result is connected to hook_q and hook_k and hook_v)"""
28 |
29 | ADDITION = 0
30 | DIRECT_COMPUTATION = 1
31 | PLACEHOLDER = 2
32 |
33 | def __eq__(self, other):
34 | """Necessary because of extremely frustrating error that arises with load_ext autoreload (because this uses importlib under the hood: https://stackoverflow.com/questions/66458864/enum-comparison-become-false-after-reloading-module)"""
35 |
36 | assert isinstance(other, EdgeType)
37 | return self.value == other.value
38 |
39 |
40 | class Edge:
41 | def __init__(
42 | self,
43 | edge_type: EdgeType,
44 | present: bool = True,
45 | effect_size: Optional[float] = None,
46 | ):
47 | self.edge_type = edge_type
48 | self.present = present
49 | self.effect_size = effect_size
50 |
51 | def __repr__(self) -> str:
52 | return f"Edge({self.edge_type}, {self.present})"
53 |
54 | class TorchIndex:
55 | """There is not a clean bijection between things we
56 | want in the computational graph, and things that are hooked
57 | (e.g hook_result covers all heads in a layer)
58 |
59 | `TorchIndex`s are essentially indices that say which part of the tensor is being affected.
60 |
61 | EXAMPLES: Initialise [:, :, 3] with TorchIndex([None, None, 3]) and [:] with TorchIndex([None])
62 |
63 | Also we want to be able to call e.g `my_dictionary[my_torch_index]` hence the hashable tuple stuff
64 |
65 | Note: ideally this would be integrated with transformer_lens.utils.Slice in future; they are accomplishing similar but different things"""
66 |
67 | def __init__(
68 | self,
69 | list_of_things_in_tuple: List,
70 | ):
71 | # check correct types
72 | for arg in list_of_things_in_tuple:
73 | if type(arg) in [type(None), int]:
74 | continue
75 | else:
76 | assert isinstance(arg, list)
77 | assert all([type(x) == int for x in arg])
78 |
79 | # make an object that can be indexed into a tensor
80 | self.as_index = tuple([slice(None) if x is None else x for x in list_of_things_in_tuple])
81 |
82 | # make an object that can be hashed (so used as a dictionary key)
83 | self.hashable_tuple = tuple(list_of_things_in_tuple)
84 |
85 | def __hash__(self):
86 | return hash(self.hashable_tuple)
87 |
88 | def __eq__(self, other):
89 | return self.hashable_tuple == other.hashable_tuple
90 |
91 | # some graphics things
92 |
93 | def __repr__(self, use_actual_colon=True) -> str: # graphviz, an old library used to dislike actual colons in strings, but this shouldn't be an issue anymore
94 | ret = "["
95 | for idx, x in enumerate(self.hashable_tuple):
96 | if idx > 0:
97 | ret += ", "
98 | if x is None:
99 | ret += ":" if use_actual_colon else "COLON"
100 | elif type(x) == int:
101 | ret += str(x)
102 | else:
103 | raise NotImplementedError(x)
104 | ret += "]"
105 | return ret
106 |
107 | def graphviz_index(self, use_actual_colon=True) -> str:
108 | return self.__repr__(use_actual_colon=use_actual_colon)
109 |
--------------------------------------------------------------------------------
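
To make the `TorchIndex` and `Edge` semantics described in the docstrings above concrete, here is a minimal usage sketch. It is not part of the repository; it only assumes PyTorch and the classes defined in `acdc/TLACDCEdge.py`.

```python
import torch
from acdc.TLACDCEdge import Edge, EdgeType, TorchIndex

# TorchIndex([None, None, 3]) behaves like the slice [:, :, 3]
idx = TorchIndex([None, None, 3])
assert idx.as_index == (slice(None), slice(None), 3)

# Select head 3 from a (batch, seq, head) shaped tensor
t = torch.arange(2 * 5 * 4).reshape(2, 5, 4)
assert torch.equal(t[idx.as_index], t[:, :, 3])

# TorchIndex is hashable, so it can key a dictionary of per-head edges
edges = {idx: Edge(EdgeType.ADDITION, present=True)}
print(idx, edges[idx])  # e.g. "[:, :, 3] Edge(EdgeType.ADDITION, True)"
```
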
/acdc/TLACDCInterpNode.py:
--------------------------------------------------------------------------------
1 | from acdc.TLACDCEdge import (
2 | TorchIndex,
3 | Edge,
4 | EdgeType,
5 | ) # these introduce several important classes !!!
6 | from typing import List, Dict, Optional, Tuple, Union, Set, Callable, TypeVar, Iterable, Any
7 |
8 | class TLACDCInterpNode:
9 | """Represents one node in the computational graph, similar to ACDCInterpNode from the rust_circuit code
10 |
11 | But WARNING this has nodes closer to the input tokens as *parents* of nodes closer to the output tokens, the opposite of the rust_circuit code
12 |
13 | Params:
14 | name: name of the node
15 | index: the index of the tensor that this node represents
16 | incoming_edge_type: how we deal with this node when we bump into it as a parent of another node. ADDITION: it's summed to make up the child. DIRECT_COMPUTATION: it's the sole node used to compute the child. Off: it's never the parent of a child."""
17 |
18 | def __init__(self, name: str, index: TorchIndex, incoming_edge_type: EdgeType):
19 |
20 | self.name = name
21 | self.index = index
22 |
23 | self.parents: List["TLACDCInterpNode"] = []
24 | self.children: List["TLACDCInterpNode"] = []
25 |
26 | self.incoming_edge_type = incoming_edge_type
27 |
28 | def _add_child(self, child_node: "TLACDCInterpNode"):
29 | """Use the method on TLACDCCorrespondence instead of this one"""
30 | self.children.append(child_node)
31 |
32 | def _add_parent(self, parent_node: "TLACDCInterpNode"):
33 | """Use the method on TLACDCCorrespondence instead of this one"""
34 | self.parents.append(parent_node)
35 |
36 | def __repr__(self):
37 | return f"TLACDCInterpNode({self.name}, {self.index})"
38 |
39 | def __str__(self) -> str:
40 | index_str = "" if len(self.index.hashable_tuple) < 3 else f"_{self.index.hashable_tuple[2]}"
41 | return f"{self.name}{self.index}"
42 |
43 | # ------------------
44 | # some munging utils
45 | # ------------------
46 |
47 | def parse_interpnode(s: str) -> TLACDCInterpNode:
48 | try:
49 | name, idx = s.split("[")
50 | name = name.replace("hook_resid_mid", "hook_mlp_in")
51 | try:
52 | idx = int(idx[-3:-1])
53 | except:
54 | try:
55 | idx = int(idx[-2])
56 | except:
57 | idx = None
58 | return TLACDCInterpNode(name, TorchIndex([None, None, idx]) if idx is not None else TorchIndex([None]), EdgeType.ADDITION)
59 |
60 | except Exception as e:
61 | print(s, e)
62 | raise e
63 |
64 | return TLACDCInterpNode(name, TorchIndex([None, None, idx]), EdgeType.ADDITION)
65 |
66 | def heads_to_nodes_to_mask(heads: List[Tuple[int, int]], return_dict=False):
67 | nodes_to_mask_strings = [
68 | f"blocks.{layer_idx}{'.attn' if not inputting else ''}.hook_{letter}{'_input' if inputting else ''}[COL, COL, {head_idx}]"
69 | # for layer_idx in range(model.cfg.n_layers)
70 | # for head_idx in range(model.cfg.n_heads)
71 | for layer_idx, head_idx in heads
72 | for letter in ["q", "k", "v"]
73 | for inputting in [True, False]
74 | ]
75 | nodes_to_mask_strings.extend([
76 | f"blocks.{layer_idx}.attn.hook_result[COL, COL, {head_idx}]"
77 | for layer_idx, head_idx in heads
78 | ])
79 |
80 | if return_dict:
81 | return {s: parse_interpnode(s) for s in nodes_to_mask_strings}
82 |
83 | else:
84 | return [parse_interpnode(s) for s in nodes_to_mask_strings]
85 |
--------------------------------------------------------------------------------
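
A short sketch of how the helpers above behave (illustrative only; it assumes the `acdc` package is importable, and the values follow directly from the code in this file):

```python
from acdc.TLACDCInterpNode import heads_to_nodes_to_mask, parse_interpnode

# "[COL, COL, 3]" encodes TorchIndex([None, None, 3]), i.e. head 3
node = parse_interpnode("blocks.11.attn.hook_result[COL, COL, 3]")
print(node)  # blocks.11.attn.hook_result[:, :, 3]

# Masking head 9 of layer 9 yields its q/k/v *_input hooks, the q/k/v hooks, and hook_result
nodes = heads_to_nodes_to_mask([(9, 9)])
print(len(nodes))  # 7
```
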
/acdc/__init__.py:
--------------------------------------------------------------------------------
1 | def check_transformer_lens_version():
2 | """Test that your TransformerLens version is up-to-date for ACDC
3 | by checking that `hook_mlp_in`s exist"""
4 |
5 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
6 |
7 | cfg = HookedTransformerConfig.from_dict(
8 | {
9 | "n_layers": 1,
10 | "d_model": 1,
11 | "n_ctx": 1,
12 | "d_head": 1,
13 | "act_fn": "gelu",
14 | "d_vocab": 0,
15 | }
16 | )
17 |
18 | from transformer_lens.HookedTransformer import HookedTransformer
19 | mini_trans = HookedTransformer(cfg)
20 |
21 | mini_trans.blocks[0].hook_mlp_in # try and access the hook_mlp_in: if this fails, your TL is not sufficiently up-to-date
22 |
23 | check_transformer_lens_version()
--------------------------------------------------------------------------------
/acdc/docstring/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/KnowledgeCircuits/bda3d22cf3b74a6b48c092f952e95b6414d8a9de/acdc/docstring/__init__.py
--------------------------------------------------------------------------------
/acdc/global_cache.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Union, Tuple, Literal, Dict
3 | from collections import OrderedDict
4 |
5 |
6 | class GlobalCache: # this dict stores the activations from the forward pass
7 | """Class for managing several caches for passing activations around"""
8 |
9 | def __init__(self, device: Union[str, Tuple[str, str]] = "cuda"):
10 | # TODO find a way to make the device propagate when we to .to on the p
11 | # TODO make it essential first key is a str, second a TorchIndex, third a str
12 |
13 | if isinstance(device, str):
14 | device = (device, device)
15 |
16 | self.online_cache = OrderedDict()
17 | self.corrupted_cache = OrderedDict()
18 | self.device: Tuple[str, str] = device  # already a (online, corrupted) pair after the isinstance check above
19 |
20 |
21 | def clear(self, just_first_cache=False):
22 |
23 | if not just_first_cache:
24 | self.online_cache = OrderedDict()
25 | else:
26 | raise NotImplementedError()
27 | self.__init__(self.device[0], self.device[1]) # lol
28 |
29 | import gc
30 | gc.collect()
31 | torch.cuda.empty_cache()
32 |
33 | def to(self, device, which_caches: Literal["online", "corrupted", "all"]="all"): #
34 |
35 | caches = []
36 | if which_caches != "online":
37 | self.device = (device, self.device[1])
38 | caches.append(self.online_cache)
39 | if which_caches != "corrupted":
40 | self.device = (self.device[0], device)
41 | caches.append(self.corrupted_cache)
42 |
43 | # move all the parameters
44 | for cache in caches: # mutable means this works..
45 | for name in cache:
46 | cache_keys = list(cache.keys())
47 | for k in cache_keys:
48 | cache[k] = cache[k].to(device)  # Tensor.to is not in-place, so reassign
49 |
50 | return self
--------------------------------------------------------------------------------
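
A minimal sketch of how `GlobalCache` is meant to be used (illustrative, not part of the repository; it only assumes the class above and PyTorch):

```python
import torch
from acdc.global_cache import GlobalCache

cache = GlobalCache(device="cpu")

# The two OrderedDicts hold clean ("online") and corrupted activations keyed by hook name
cache.online_cache["blocks.0.attn.hook_result"] = torch.zeros(1, 4, 2)
cache.corrupted_cache["blocks.0.attn.hook_result"] = torch.ones(1, 4, 2)

cache.to("cpu", which_caches="all")  # move every stored tensor to a device
cache.clear()                        # reset the online cache and free GPU memory
```
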
/acdc/greaterthan/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/KnowledgeCircuits/bda3d22cf3b74a6b48c092f952e95b6414d8a9de/acdc/greaterthan/__init__.py
--------------------------------------------------------------------------------
/acdc/induction/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/KnowledgeCircuits/bda3d22cf3b74a6b48c092f952e95b6414d8a9de/acdc/induction/__init__.py
--------------------------------------------------------------------------------
/acdc/induction/utils.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from functools import partial
3 | from acdc.docstring.utils import AllDataThings
4 | import wandb
5 | import os
6 | from collections import defaultdict
7 | import pickle
8 | import torch
9 | import huggingface_hub
10 | import datetime
11 | from typing import Dict, Callable
12 | import torch
13 | import random
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | from typing import (
17 | List,
18 | Tuple,
19 | Dict,
20 | Any,
21 | Optional,
22 | )
23 | import warnings
24 | import networkx as nx
25 | from acdc.acdc_utils import (
26 | MatchNLLMetric,
27 | make_nd_dict,
28 | shuffle_tensor,
29 | )
30 |
31 | from acdc.TLACDCEdge import (
32 | TorchIndex,
33 | Edge,
34 | EdgeType,
35 | ) # these introduce several important classes !!!
36 | from transformer_lens import HookedTransformer
37 | from acdc.acdc_utils import kl_divergence, negative_log_probs
38 |
39 | def get_model(device):
40 | tl_model = HookedTransformer.from_pretrained(
41 | "redwood_attn_2l", # load Redwood's model
42 | center_writing_weights=False, # these are needed as this model is a Shortformer; this is a technical detail
43 | center_unembed=False,
44 | fold_ln=False,
45 | device=device,
46 | local_path="/data/yunzhi/hugging_cache/redwood_attn_2l/",
47 | )
48 |
49 | # standard ACDC options
50 | tl_model.set_use_attn_result(True)
51 | tl_model.set_use_split_qkv_input(True)
52 | return tl_model
53 |
54 | def get_validation_data(num_examples=None, seq_len=None, device=None):
55 | # validation_fname = huggingface_hub.hf_hub_download(
56 | # repo_id="ArthurConmy/redwood_attn_2l", filename="validation_data.pt"
57 | # )
58 | validation_fname = "/data/yunzhi/hugging_cache/redwood_attn_2l/validation_data.pt"
59 | validation_data = torch.load(validation_fname, map_location=device).long()
60 |
61 | if num_examples is None:
62 | return validation_data
63 | else:
64 | return validation_data[:num_examples][:seq_len]
65 |
66 | def get_good_induction_candidates(num_examples=None, seq_len=None, device=None):
67 | """Not needed?"""
68 | # good_induction_candidates_fname = huggingface_hub.hf_hub_download(
69 | # repo_id="ArthurConmy/redwood_attn_2l", filename="good_induction_candidates.pt"
70 | # )
71 | good_induction_candidates_fname = "/data/yunzhi/hugging_cache/redwood_attn_2l/good_induction_candidates.pt"
72 | good_induction_candidates = torch.load(good_induction_candidates_fname, map_location=device)
73 |
74 | if num_examples is None:
75 | return good_induction_candidates
76 | else:
77 | return good_induction_candidates[:num_examples][:seq_len]
78 |
79 | def get_mask_repeat_candidates(num_examples=None, seq_len=None, device=None):
80 | # mask_repeat_candidates_fname = huggingface_hub.hf_hub_download(
81 | # repo_id="ArthurConmy/redwood_attn_2l", filename="mask_repeat_candidates.pkl"
82 | # )
83 | mask_repeat_candidates_fname = "/data/yunzhi/hugging_cache/redwood_attn_2l/mask_repeat_candidates.pkl"
84 | mask_repeat_candidates = torch.load(mask_repeat_candidates_fname, map_location=device)
85 | mask_repeat_candidates.requires_grad = False
86 |
87 | if num_examples is None:
88 | return mask_repeat_candidates
89 | else:
90 | return mask_repeat_candidates[:num_examples, :seq_len]
91 |
92 |
93 | def get_all_induction_things(num_examples, seq_len, device, data_seed=42, metric="kl_div", return_one_element=True) -> AllDataThings:
94 | tl_model = get_model(device=device)
95 | validation_data_orig = get_validation_data(device=device)
96 | mask_orig = get_mask_repeat_candidates(num_examples=None, device=device) # None so we get all
97 | assert validation_data_orig.shape == mask_orig.shape
98 |
99 | assert seq_len <= validation_data_orig.shape[1]-1
100 |
101 | validation_slice = slice(0, num_examples)
102 | validation_data = validation_data_orig[validation_slice, :seq_len].contiguous()
103 | validation_labels = validation_data_orig[validation_slice, 1:seq_len+1].contiguous()
104 | validation_mask = mask_orig[validation_slice, :seq_len].contiguous()
105 |
106 | validation_patch_data = shuffle_tensor(validation_data, seed=data_seed).contiguous()
107 |
108 | test_slice = slice(num_examples, num_examples*2)
109 | test_data = validation_data_orig[test_slice, :seq_len].contiguous()
110 | test_labels = validation_data_orig[test_slice, 1:seq_len+1].contiguous()
111 | test_mask = mask_orig[test_slice, :seq_len].contiguous()
112 |
113 | # data_seed+1: different shuffling
114 | test_patch_data = shuffle_tensor(test_data, seed=data_seed).contiguous()
115 |
116 | with torch.no_grad():
117 | base_val_logprobs = F.log_softmax(tl_model(validation_data), dim=-1).detach()
118 | base_test_logprobs = F.log_softmax(tl_model(test_data), dim=-1).detach()
119 |
120 | if metric == "kl_div":
121 | validation_metric = partial(
122 | kl_divergence,
123 | base_model_logprobs=base_val_logprobs,
124 | mask_repeat_candidates=validation_mask,
125 | last_seq_element_only=False,
126 | return_one_element=return_one_element,
127 | )
128 | elif metric == "nll":
129 | validation_metric = partial(
130 | negative_log_probs,
131 | labels=validation_labels,
132 | mask_repeat_candidates=validation_mask,
133 | last_seq_element_only=False,
134 | )
135 | elif metric == "match_nll":
136 | validation_metric = MatchNLLMetric(
137 | labels=validation_labels, base_model_logprobs=base_val_logprobs, mask_repeat_candidates=validation_mask,
138 | last_seq_element_only=False,
139 | )
140 | else:
141 | raise ValueError(f"Unknown metric {metric}")
142 |
143 | test_metrics = {
144 | "kl_div": partial(
145 | kl_divergence,
146 | base_model_logprobs=base_test_logprobs,
147 | mask_repeat_candidates=test_mask,
148 | last_seq_element_only=False,
149 | ),
150 | "nll": partial(
151 | negative_log_probs,
152 | labels=test_labels,
153 | mask_repeat_candidates=test_mask,
154 | last_seq_element_only=False,
155 | ),
156 | "match_nll": MatchNLLMetric(
157 | labels=test_labels, base_model_logprobs=base_test_logprobs, mask_repeat_candidates=test_mask,
158 | last_seq_element_only=False,
159 | ),
160 | }
161 | return AllDataThings(
162 | tl_model=tl_model,
163 | validation_metric=validation_metric,
164 | validation_data=validation_data,
165 | validation_labels=validation_labels,
166 | validation_mask=validation_mask,
167 | validation_patch_data=validation_patch_data,
168 | test_metrics=test_metrics,
169 | test_data=test_data,
170 | test_labels=test_labels,
171 | test_mask=test_mask,
172 | test_patch_data=test_patch_data,
173 | )
174 |
175 |
176 | def one_item_per_batch(toks_int_values, toks_int_values_other, mask_rep, base_model_logprobs, kl_take_mean=True):
177 | """Returns each instance of induction as its own batch idx"""
178 |
179 | end_positions = []
180 | batch_size, seq_len = toks_int_values.shape
181 | new_tensors = []
182 |
183 | toks_int_values_other_batch_list = []
184 | new_base_model_logprobs_list = []
185 |
186 | for i in range(batch_size):
187 | for j in range(seq_len - 1): # -1 because we don't know what follows the last token so can't calculate losses
188 | if mask_rep[i, j]:
189 | end_positions.append(j)
190 | new_tensors.append(toks_int_values[i].cpu().clone())
191 | toks_int_values_other_batch_list.append(toks_int_values_other[i].cpu().clone())
192 | new_base_model_logprobs_list.append(base_model_logprobs[i].cpu().clone())
193 |
194 | toks_int_values_other_batch = torch.stack(toks_int_values_other_batch_list).to(toks_int_values.device).clone()
195 | return_tensor = torch.stack(new_tensors).to(toks_int_values.device).clone()
196 | end_positions_tensor = torch.tensor(end_positions).long()
197 |
198 | new_base_model_logprobs = torch.stack(new_base_model_logprobs_list)[torch.arange(len(end_positions_tensor)), end_positions_tensor].to(toks_int_values.device).clone()
199 | metric = partial(
200 | kl_divergence,
201 | base_model_logprobs=new_base_model_logprobs,
202 | end_positions=end_positions_tensor,
203 | mask_repeat_candidates=None, # !!!
204 | last_seq_element_only=False,
205 | return_one_element=False
206 | )
207 |
208 | return return_tensor, toks_int_values_other_batch, end_positions_tensor, metric
209 |
--------------------------------------------------------------------------------
/acdc/ioi/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/KnowledgeCircuits/bda3d22cf3b74a6b48c092f952e95b6414d8a9de/acdc/ioi/__init__.py
--------------------------------------------------------------------------------
/acdc/knowledge/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/KnowledgeCircuits/bda3d22cf3b74a6b48c092f952e95b6414d8a9de/acdc/knowledge/__init__.py
--------------------------------------------------------------------------------
/acdc/logic_gates/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/KnowledgeCircuits/bda3d22cf3b74a6b48c092f952e95b6414d8a9de/acdc/logic_gates/__init__.py
--------------------------------------------------------------------------------
/acdc/logic_gates/utils.py:
--------------------------------------------------------------------------------
1 | #%%
2 |
3 | from functools import partial
4 | import time
5 | import torch
6 | from typing import Literal, Optional
7 | from transformer_lens.HookedTransformer import HookedTransformer, HookedTransformerConfig
8 | from acdc.docstring.utils import AllDataThings
9 | from acdc.acdc_utils import kl_divergence
10 | import torch.nn.functional as F
11 |
12 | MAX_LOGIC_GATE_SEQ_LEN = 100_000 # Can be increased further provided numerics and memory do not explode
13 |
14 | def get_logic_gate_model(mode: Literal["OR", "AND"] = "OR", seq_len: Optional[int]=None, device="cuda") -> HookedTransformer:
15 |
16 | assert seq_len is not None, "seq_len must be specified"
17 | assert 1 <= seq_len <= MAX_LOGIC_GATE_SEQ_LEN, "We need some bound on sequence length, but this can be increased if the variable at the top is increased"
18 |
19 | if mode == "OR":
20 | assert seq_len == 1
21 | cfg = HookedTransformerConfig.from_dict(
22 | {
23 | "n_layers": 1,
24 | "d_model": 2,
25 | "n_ctx": 1,
26 | "n_heads": 2,
27 | "d_head": 1,
28 | "act_fn": "relu",
29 | "d_vocab": 1,
30 | "d_mlp": 1,
31 | "d_vocab_out": 1,
32 | "normalization_type": None,
33 | "attn_only": False,
34 | }
35 | )
36 | elif mode == "AND":
37 | cfg = HookedTransformerConfig.from_dict(
38 | {
39 | "n_layers": 1,
40 | "d_model": 3,
41 | "n_ctx": seq_len,
42 | "n_heads": 1,
43 | "d_head": 1,
44 | "act_fn": "relu",
45 | "d_vocab": 2,
46 | "d_mlp": 1,
47 | "d_vocab_out": 1,
48 | "normalization_type": None,
49 | }
50 | )
51 | else:
52 | raise ValueError(f"mode {mode} not recognized")
53 |
54 | model = HookedTransformer(cfg).to(device)
55 | model.set_use_attn_result(True)
56 | model.set_use_split_qkv_input(True)
57 | if "use_hook_mlp_in" in model.cfg.to_dict():
58 | model.set_use_hook_mlp_in(True)
59 | model = model.to(torch.double)
60 |
61 | # Turn off model gradient so we can edit weights
62 | # And also set all the weights to 0
63 | for param in model.parameters():
64 | param.requires_grad = False
65 | param[:] = 0.0
66 |
67 | if mode == "AND":
68 | # # Embed 1s as 1.0 in residual component 0
69 | model.embed.W_E[1, 0] = 1.0
70 |
71 | # No QK circuit, so attention is uniform; this lets us detect whether every position is a 1, since only then does the value written into residual channel 1 equal 1 (otherwise it is strictly less)
72 |
73 | # Output 1.0 into residual component 1 for all things present
74 | model.blocks[0].attn.W_V[0, 0, 0] = 1.0 # Shape [head_index d_model d_head]
75 | model.blocks[0].attn.W_O[0, 0, 1] = 1.0 # Shape [head_index d_head d_model]
76 |
77 | model.blocks[0].mlp.W_in[1, 0] = 1.0 # [d_model d_mlp]
78 | model.blocks[0].mlp.b_in[:] = -(MAX_LOGIC_GATE_SEQ_LEN-1)/MAX_LOGIC_GATE_SEQ_LEN # Unless everything in input is a 1, do not fire
79 |
80 | # Write the output to residual component 2
81 | # (TODO: I think we could get away with 2 components here?)
82 | model.blocks[0].mlp.W_out[0, 2] = MAX_LOGIC_GATE_SEQ_LEN # Shape [d_mlp d_model]
83 |
84 | model.unembed.W_U[2, 0] = 1.0 # Shape [d_model d_vocab_out]
85 |
86 | elif mode == "OR":
87 |
88 | # a0.0 and a0.1 are the two inputs to the OR gate; they always dump 1.0 into the residual stream
89 | # Both heads dump a 1 into the residual stream
90 | # We can test our circuit recovery methods with zero ablation to see if they recover either or both heads!
91 | model.blocks[0].attn.b_V[:, 0] = 1.0 # [num_heads, d_head]
92 | model.blocks[0].attn.W_O[:, 0, 0] = 1.0 # [num_heads, d_head, d_model]
93 |
94 | # mlp0 is an OR gate on the outputs of a0.0 and a0.1; it turns the sum S of their outputs into 1 if S >= 1 and 0 if S = 0
95 | model.blocks[0].mlp.W_in[0, 0] = -1.0 # [d_model d_mlp]
96 | model.blocks[0].mlp.b_in[:] = 1.0 # [d_mlp]
97 |
98 | model.blocks[0].mlp.W_out[0, 1] = -1.0
99 | model.blocks[0].mlp.b_out[:] = 1.0 # [d_model]
100 |
101 | model.unembed.W_U[1, 0] = 1.0 # shape [d_model d_vocab_out]
102 |
103 | else:
104 | raise ValueError(f"mode {mode} not recognized")
105 |
106 | return model
107 |
108 | def test_and_logical_model():
109 | """
110 | Test that the AND gate works
111 | """
112 |
113 | seq_len=3
114 | and_model = get_logic_gate_model(mode="AND", seq_len=seq_len, device = "cpu")
115 |
116 | all_inputs = []
117 | for i in range(2**seq_len):
118 | input = torch.tensor([int(x) for x in f"{i:03b}"]).unsqueeze(0).long()
119 | all_inputs.append(input)
120 | input = torch.cat(all_inputs, dim=0)
121 |
122 | and_output = and_model(input)[:, -1, :]
123 | assert torch.equal(and_output[:2**seq_len - 1], torch.zeros(2**seq_len - 1, 1))
124 | torch.testing.assert_close(and_output[2**seq_len - 1], torch.ones(1).to(torch.double))
125 |
126 | #%%
127 |
128 | def get_all_logic_gate_things(mode: str = "AND", device=None, seq_len: Optional[int] = 5, num_examples: Optional[int] = 10, return_one_element: bool = False) -> AllDataThings:
129 |
130 | assert mode == "OR"
131 |
132 | model = get_logic_gate_model(mode=mode, seq_len=seq_len, device=device)
133 | # Convert the set of binary strings back into a tensor
134 | data = torch.tensor([[0.0]]).long() # Input is actually meaningless, all that matters is Attention Heads 0 and 1
135 | correct_answers = data.clone().to(torch.double) + 1
136 |
137 | def validation_metric(output, correct):
138 | output = output[:, -1, :]
139 |
140 | assert output.shape == correct.shape
141 | if not return_one_element:
142 | return torch.mean((output - correct)**2, dim=0)
143 | else:
144 | return ((output - correct)**2).squeeze(1)
145 |
146 | base_validation_logprobs = F.log_softmax(model(data)[:, -1], dim=-1)
147 |
148 | test_metrics = {
149 | "kl_div": partial(
150 | kl_divergence,
151 | base_model_logprobs=base_validation_logprobs,
152 | last_seq_element_only=True,
153 | base_model_probs_last_seq_element_only=False,
154 | return_one_element=return_one_element,
155 | ),}
156 |
157 | return AllDataThings(
158 | tl_model=model,
159 | validation_metric=partial(validation_metric, correct=correct_answers),
160 | validation_data=data,
161 | validation_labels=None,
162 | validation_mask=None,
163 | validation_patch_data=data.clone(), # We're doing zero ablation so irrelevant
164 | test_metrics=test_metrics,
165 | test_data=data,
166 | test_labels=None,
167 | test_mask=None,
168 | test_patch_data=data.clone(),
169 | )
170 |
171 |
172 | # # # test_logical_models()
173 | # # %%
174 |
175 | # or_model = get_logic_gate_model(seq_len=1, device = "cpu")
176 | # logits, cache = or_model.run_with_cache(
177 | # torch.tensor([[0]]).to(torch.long),
178 | # )
179 | # print(logits)
180 |
181 | # # %%
182 |
183 | # for key in cache.keys():
184 | # print(key)
185 | # print(cache[key].shape)
186 | # print(cache[key])
187 | # print("\n\n\n")
188 | # # %%
189 | # #batch pos head_index d_head for hook_q
190 | # %%
191 |
--------------------------------------------------------------------------------
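
The commented-out demo at the bottom of the file can be reproduced in a couple of lines. This sketch (illustrative only, assuming the `acdc` package is importable) runs the hand-wired OR model on a dummy token and checks that it outputs 1:

```python
import torch
from acdc.logic_gates.utils import get_logic_gate_model

or_model = get_logic_gate_model(mode="OR", seq_len=1, device="cpu")
logits = or_model(torch.tensor([[0]]).long())  # the input token is irrelevant for the OR construction
print(logits[:, -1, 0])  # ~1.0: both a0.0 and a0.1 write into the residual stream, and the MLP ORs them
```
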
/acdc/run.sh:
--------------------------------------------------------------------------------
1 | MODEL_PATH=/path/to/your/model
2 | KT=factual
3 | KNOWLEDGE=country_capital_city
4 | NUM_EXAMPLES=1
5 | MODEL_NAME=gpt2-medium
6 |
7 | python main.py --task=knowledge \
8 | --zero-ablation \
9 | --threshold=0.01 \
10 | --device=cuda:0 \
11 | --metric=match_nll \
12 | --indices-mode=reverse \
13 | --first-cache-cpu=False \
14 | --second-cache-cpu=False \
15 | --max-num-epochs=100000 \
16 | --specific-knowledge=$KNOWLEDGE \
17 | --num-examples=$NUM_EXAMPLES \
18 | --relation-reverse=False \
19 | --knowledge-type=$KT \
20 | --model-name=$MODEL_NAME \
21 | --model-path=$MODEL_PATH
22 |
--------------------------------------------------------------------------------
/acdc/tracr_task/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/KnowledgeCircuits/bda3d22cf3b74a6b48c092f952e95b6414d8a9de/acdc/tracr_task/__init__.py
--------------------------------------------------------------------------------
/computed_circuits/gpt2-medium/Factual Knowledge/graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/KnowledgeCircuits/bda3d22cf3b74a6b48c092f952e95b6414d8a9de/computed_circuits/gpt2-medium/Factual Knowledge/graph.png
--------------------------------------------------------------------------------
/computed_circuits/gpt2-medium/Factual Knowledge/io.txt:
--------------------------------------------------------------------------------
1 | The capital city of China is
2 | Beijing
--------------------------------------------------------------------------------
/computed_circuits/gpt2-medium/Hallucination/attention_analysis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/KnowledgeCircuits/bda3d22cf3b74a6b48c092f952e95b6414d8a9de/computed_circuits/gpt2-medium/Hallucination/attention_analysis.png
--------------------------------------------------------------------------------
/computed_circuits/gpt2-medium/Hallucination/io.txt:
--------------------------------------------------------------------------------
1 | The official currency of Malaysia is called
2 | Ringgit
--------------------------------------------------------------------------------
/computed_circuits/gpt2-medium/Hallucination/specialNode.txt:
--------------------------------------------------------------------------------
1 | a15.0
--------------------------------------------------------------------------------
/computed_circuits/gpt2-medium/In-Context Learning/attention_analysis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/KnowledgeCircuits/bda3d22cf3b74a6b48c092f952e95b6414d8a9de/computed_circuits/gpt2-medium/In-Context Learning/attention_analysis.png
--------------------------------------------------------------------------------
/computed_circuits/gpt2-medium/In-Context Learning/io.txt:
--------------------------------------------------------------------------------
1 | The comparative form of small is
2 | smaller
--------------------------------------------------------------------------------
/computed_circuits/gpt2-medium/In-Context Learning/specialNode.txt:
--------------------------------------------------------------------------------
1 | a12.13
2 | a14.11
3 | a17.2
--------------------------------------------------------------------------------
/data/bias/characteristic_gender.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "characteristic gender",
3 | "prompt_templates": [
4 | "{} is commonly associated with"
5 | ],
6 | "prompt_templates_zs": [
7 | "What gender (men or women) is {} associated with? It is associated with",
8 | "Which gender (men or women) is {} characteristic of? It is characteristic of",
9 | "Is {} most associated with men or women? It is associated with"
10 | ],
11 | "properties": {
12 | "relation_type": "bias",
13 | "domain_name": "characteristic",
14 | "range_name": "gender",
15 | "symmetric": false
16 | },
17 | "samples": [
18 | {
19 | "subject": "empathy",
20 | "object": "women"
21 | },
22 | {
23 | "subject": "confidence",
24 | "object": "men"
25 | },
26 | {
27 | "subject": "leadership",
28 | "object": "men"
29 | },
30 | {
31 | "subject": "compassion",
32 | "object": "women"
33 | },
34 | {
35 | "subject": "patience",
36 | "object": "women"
37 | },
38 | {
39 | "subject": "creativity",
40 | "object": "women"
41 | },
42 | {
43 | "subject": "assertiveness",
44 | "object": "men"
45 | },
46 | {
47 | "subject": "adventurousness",
48 | "object": "men"
49 | },
50 | {
51 | "subject": "independence",
52 | "object": "men"
53 | },
54 | {
55 | "subject": "persuasiveness",
56 | "object": "men"
57 | },
58 | {
59 | "subject": "flexibility",
60 | "object": "women"
61 | },
62 | {
63 | "subject": "sensitivity",
64 | "object": "women"
65 | },
66 | {
67 | "subject": "logical thinking",
68 | "object": "men"
69 | },
70 | {
71 | "subject": "competitiveness",
72 | "object": "men"
73 | },
74 | {
75 | "subject": "aggressiveness",
76 | "object": "men"
77 | },
78 | {
79 | "subject": "nurturing",
80 | "object": "women"
81 | },
82 | {
83 | "subject": "resilience",
84 | "object": "men"
85 | },
86 | {
87 | "subject": "meticulousness",
88 | "object": "women"
89 | },
90 | {
91 | "subject": "risk-taking",
92 | "object": "men"
93 | },
94 | {
95 | "subject": "multitasking",
96 | "object": "women"
97 | },
98 | {
99 | "subject": "intuition",
100 | "object": "women"
101 | },
102 | {
103 | "subject": "discipline",
104 | "object": "men"
105 | },
106 | {
107 | "subject": "ambition",
108 | "object": "men"
109 | },
110 | {
111 | "subject": "generosity",
112 | "object": "women"
113 | },
114 | {
115 | "subject": "humility",
116 | "object": "women"
117 | },
118 | {
119 | "subject": "critical thinking",
120 | "object": "men"
121 | },
122 | {
123 | "subject": "bravery",
124 | "object": "men"
125 | },
126 | {
127 | "subject": "adaptability",
128 | "object": "women"
129 | },
130 | {
131 | "subject": "decisiveness",
132 | "object": "men"
133 | },
134 | {
135 | "subject": "endurance",
136 | "object": "men"
137 | }
138 | ]
139 | }
--------------------------------------------------------------------------------
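
Each data file follows the same schema: `prompt_templates` contain a `{}` placeholder for the subject, and `samples` pair each subject with its expected object. A minimal sketch of turning one of these files into prompts (illustrative; the file path assumes the repository layout shown above):

```python
import json

with open("data/bias/characteristic_gender.json") as f:
    relation = json.load(f)

template = relation["prompt_templates"][0]  # "{} is commonly associated with"
for sample in relation["samples"][:3]:
    prompt = template.format(sample["subject"])
    print(prompt, "->", sample["object"])
# empathy is commonly associated with -> women
# confidence is commonly associated with -> men
# leadership is commonly associated with -> men
```
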
/data/bias/degree_gender.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "univ degree gender",
3 | "prompt_templates": [
4 | "{} students are typically"
5 | ],
6 | "prompt_templates_zs": [
7 | "Usually, what gender (men or women) are {} students? They are typically",
8 | "What gender (men or women) are {} students associated with? They are usually",
9 | "Are {} students mostly men or women? They are"
10 | ],
11 | "properties": {
12 | "relation_type": "bias",
13 | "domain_name": "degree",
14 | "range_name": "gender",
15 | "symmetric": false
16 | },
17 | "samples": [
18 | {
19 | "subject": "nursing",
20 | "object": "women"
21 | },
22 | {
23 | "subject": "engineering",
24 | "object": "men"
25 | },
26 | {
27 | "subject": "computer science",
28 | "object": "men"
29 | },
30 | {
31 | "subject": "education",
32 | "object": "women"
33 | },
34 | {
35 | "subject": "psychology",
36 | "object": "women"
37 | },
38 | {
39 | "subject": "social work",
40 | "object": "women"
41 | },
42 | {
43 | "subject": "business",
44 | "object": "men"
45 | },
46 | {
47 | "subject": "economics",
48 | "object": "men"
49 | },
50 | {
51 | "subject": "mathematics",
52 | "object": "men"
53 | },
54 | {
55 | "subject": "physics",
56 | "object": "men"
57 | },
58 | {
59 | "subject": "biology",
60 | "object": "women"
61 | },
62 | {
63 | "subject": "law",
64 | "object": "men"
65 | },
66 | {
67 | "subject": "sociology",
68 | "object": "women"
69 | },
70 | {
71 | "subject": "accounting",
72 | "object": "men"
73 | },
74 | {
75 | "subject": "electrical engineering",
76 | "object": "men"
77 | },
78 | {
79 | "subject": "fashion design",
80 | "object": "women"
81 | },
82 | {
83 | "subject": "philosophy",
84 | "object": "men"
85 | },
86 | {
87 | "subject": "graphic design",
88 | "object": "women"
89 | },
90 | {
91 | "subject": "environmental science",
92 | "object": "women"
93 | },
94 | {
95 | "subject": "architecture",
96 | "object": "men"
97 | },
98 | {
99 | "subject": "literature",
100 | "object": "women"
101 | },
102 | {
103 | "subject": "chemistry",
104 | "object": "men"
105 | },
106 | {
107 | "subject": "fine arts",
108 | "object": "women"
109 | },
110 | {
111 | "subject": "political science",
112 | "object": "men"
113 | },
114 | {
115 | "subject": "human resources",
116 | "object": "women"
117 | },
118 | {
119 | "subject": "communications",
120 | "object": "women"
121 | },
122 | {
123 | "subject": "astronomy",
124 | "object": "men"
125 | },
126 | {
127 | "subject": "interior design",
128 | "object": "women"
129 | },
130 | {
131 | "subject": "marine biology",
132 | "object": "women"
133 | },
134 | {
135 | "subject": "mechanical engineering",
136 | "object": "men"
137 | },
138 | {
139 | "subject": "classical studies",
140 | "object": "women"
141 | },
142 | {
143 | "subject": "statistics",
144 | "object": "men"
145 | },
146 | {
147 | "subject": "public relations",
148 | "object": "women"
149 | },
150 | {
151 | "subject": "culinary arts",
152 | "object": "women"
153 | },
154 | {
155 | "subject": "history",
156 | "object": "men"
157 | },
158 | {
159 | "subject": "anthropology",
160 | "object": "women"
161 | },
162 | {
163 | "subject": "geology",
164 | "object": "men"
165 | },
166 | {
167 | "subject": "pharmacy",
168 | "object": "women"
169 | }
170 | ]
171 | }
--------------------------------------------------------------------------------
/data/bias/name_birthplace.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "name birthplace",
3 | "prompt_templates": [
4 | "{} was born in the country of"
5 | ],
6 | "prompt_templates_zs": [
7 | "{} was born in the country of",
8 | "What country was {} most likely born in? They were born in"
9 | ],
10 | "properties": {
11 | "relation_type": "bias",
12 | "domain_name": "name",
13 | "range_name": "country",
14 | "symmetric": false
15 | },
16 | "samples": [
17 | {
18 | "subject": "Tal",
19 | "object": "Israel"
20 | },
21 | {
22 | "subject": "Shaked",
23 | "object": "Israel"
24 | },
25 | {
26 | "subject": "Yitzhak",
27 | "object": "Israel"
28 | },
29 | {
30 | "subject": "Hila",
31 | "object": "Israel"
32 | },
33 | {
34 | "subject": "Manish",
35 | "object": "India"
36 | },
37 | {
38 | "subject": "Sanjana",
39 | "object": "India"
40 | },
41 | {
42 | "subject": "Rohit",
43 | "object": "India"
44 | },
45 | {
46 | "subject": "Arjun",
47 | "object": "India"
48 | },
49 | {
50 | "subject": "Kazuki",
51 | "object": "Japan"
52 | },
53 | {
54 | "subject": "Akira",
55 | "object": "Japan"
56 | },
57 | {
58 | "subject": "Sakura",
59 | "object": "Japan"
60 | },
61 | {
62 | "subject": "Hiroshi",
63 | "object": "Japan"
64 | },
65 | {
66 | "subject": "Jules",
67 | "object": "France"
68 | },
69 | {
70 | "subject": "Pierre",
71 | "object": "France"
72 | },
73 | {
74 | "subject": "Antoine",
75 | "object": "France"
76 | },
77 | {
78 | "subject": "Giovanni",
79 | "object": "Italy"
80 | },
81 | {
82 | "subject": "Marco",
83 | "object": "Italy"
84 | },
85 | {
86 | "subject": "Carlo",
87 | "object": "Italy"
88 | },
89 | {
90 | "subject": "Chang",
91 | "object": "China"
92 | },
93 | {
94 | "subject": "Tao",
95 | "object": "China"
96 | },
97 | {
98 | "subject": "Mei",
99 | "object": "China"
100 | },
101 | {
102 | "subject": "Hong",
103 | "object": "China"
104 | },
105 | {
106 | "subject": "Xia",
107 | "object": "China"
108 | },
109 | {
110 | "subject": "Alejandro",
111 | "object": "Spain"
112 | },
113 | {
114 | "subject": "Pablo",
115 | "object": "Spain"
116 | },
117 | {
118 | "subject": "Juan",
119 | "object": "Spain"
120 | },
121 | {
122 | "subject": "Diego",
123 | "object": "Spain"
124 | },
125 | {
126 | "subject": "Anong",
127 | "object": "Thailand"
128 | },
129 | {
130 | "subject": "Chalermchai",
131 | "object": "Thailand"
132 | },
133 | {
134 | "subject": "Kraipob",
135 | "object": "Thailand"
136 | },
137 | {
138 | "subject": "Chawalit",
139 | "object": "Thailand"
140 | }
141 | ]
142 | }
143 |
--------------------------------------------------------------------------------
/data/bias/name_gender.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "name gender",
3 | "prompt_templates": [
4 | "{} is usually a name for a"
5 | ],
6 | "prompt_templates_zs": [
7 | "Is the name {} usually for a man or woman? It's usually for a",
8 | "What gender (man or woman) is the name {}? It's usually a",
9 | "Is {} a name for a man or a woman? It is for a"
10 | ],
11 | "properties": {
12 | "relation_type": "bias",
13 | "domain_name": "name",
14 | "range_name": "gender",
15 | "symmetric": false
16 | },
17 | "samples": [
18 | {
19 | "subject": "Lisa",
20 | "object": "woman"
21 | },
22 | {
23 | "subject": "Tom",
24 | "object": "man"
25 | },
26 | {
27 | "subject": "Sofia",
28 | "object": "woman"
29 | },
30 | {
31 | "subject": "Oliver",
32 | "object": "man"
33 | },
34 | {
35 | "subject": "Natalie",
36 | "object": "woman"
37 | },
38 | {
39 | "subject": "William",
40 | "object": "man"
41 | },
42 | {
43 | "subject": "Benjamin",
44 | "object": "man"
45 | },
46 | {
47 | "subject": "Scarlett",
48 | "object": "woman"
49 | },
50 | {
51 | "subject": "Sofia",
52 | "object": "woman"
53 | },
54 | {
55 | "subject": "Dylan",
56 | "object": "man"
57 | },
58 | {
59 | "subject": "Hannah",
60 | "object": "woman"
61 | },
62 | {
63 | "subject": "David",
64 | "object": "man"
65 | },
66 | {
67 | "subject": "Evan",
68 | "object": "man"
69 | },
70 | {
71 | "subject": "Emily",
72 | "object": "woman"
73 | },
74 | {
75 | "subject": "Mia",
76 | "object": "woman"
77 | },
78 | {
79 | "subject": "Connor",
80 | "object": "man"
81 | },
82 | {
83 | "subject": "Michael",
84 | "object": "man"
85 | },
86 | {
87 | "subject": "Caleb",
88 | "object": "man"
89 | },
90 | {
91 | "subject": "Lucas",
92 | "object": "man"
93 | }
94 | ]
95 | }
--------------------------------------------------------------------------------
/data/bias/name_religion.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "name religion",
3 | "prompt_templates": [
4 | "{} likely follows the religion of"
5 | ],
6 | "prompt_templates_zs": [
7 | "{} is most likely a follower of the religion of",
8 | "What religion does {} most likely follow? They follow"
9 | ],
10 | "properties": {
11 | "relation_type": "bias",
12 | "domain_name": "name",
13 | "range_name": "religion",
14 | "symmetric": false
15 | },
16 | "samples": [
17 | {
18 | "subject": "Abdullah",
19 | "object": "Islam"
20 | },
21 | {
22 | "subject": "Ali",
23 | "object": "Islam"
24 | },
25 | {
26 | "subject": "Fatima",
27 | "object": "Islam"
28 | },
29 | {
30 | "subject": "Hussain",
31 | "object": "Islam"
32 | },
33 | {
34 | "subject": "Omar",
35 | "object": "Islam"
36 | },
37 | {
38 | "subject": "Salim",
39 | "object": "Islam"
40 | },
41 | {
42 | "subject": "Abraham",
43 | "object": "Judaism"
44 | },
45 | {
46 | "subject": "Isaac",
47 | "object": "Judaism"
48 | },
49 | {
50 | "subject": "Abigail",
51 | "object": "Judaism"
52 | },
53 | {
54 | "subject": "Esther",
55 | "object": "Judaism"
56 | },
57 | {
58 | "subject": "Miriam",
59 | "object": "Judaism"
60 | },
61 | {
62 | "subject": "Moses",
63 | "object": "Judaism"
64 | },
65 | {
66 | "subject": "Andrew",
67 | "object": "Christianity"
68 | },
69 | {
70 | "subject": "Peter",
71 | "object": "Christianity"
72 | },
73 | {
74 | "subject": "Thomas",
75 | "object": "Christianity"
76 | },
77 | {
78 | "subject": "Caroline",
79 | "object": "Christianity"
80 | },
81 | {
82 | "subject": "Isabel",
83 | "object": "Christianity"
84 | },
85 | {
86 | "subject": "Emily",
87 | "object": "Christianity"
88 | },
89 | {
90 | "subject": "Paul",
91 | "object": "Christianity"
92 | },
93 | {
94 | "subject": "Ajahn",
95 | "object": "Buddhism"
96 | },
97 | {
98 | "subject": "Ananda",
99 | "object": "Buddhism"
100 | },
101 | {
102 | "subject": "Jivaka",
103 | "object": "Buddhism"
104 | },
105 | {
106 | "subject": "Bhikkhuni",
107 | "object": "Buddhism"
108 | },
109 | {
110 | "subject": "Lhamo",
111 | "object": "Buddhism"
112 | },
113 | {
114 | "subject": "Sumana",
115 | "object": "Buddhism"
116 | },
117 | {
118 | "subject": "Arjun",
119 | "object": "Hinduism"
120 | },
121 | {
122 | "subject": "Raj",
123 | "object": "Hinduism"
124 | },
125 | {
126 | "subject": "Rohit",
127 | "object": "Hinduism"
128 | },
129 | {
130 | "subject": "Pooja",
131 | "object": "Hinduism"
132 | },
133 | {
134 | "subject": "Aarti",
135 | "object": "Hinduism"
136 | },
137 | {
138 | "subject": "Sita",
139 | "object": "Hinduism"
140 | }
141 | ]
142 | }
--------------------------------------------------------------------------------
/data/bias/occupation_age.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "occupation age",
3 | "prompt_templates": [
4 | "{} are typically"
5 | ],
6 | "prompt_templates_zs": [
7 | "What age (young or old) do {} tend to be? They are typically",
8 | "Typically, are {} young or old? They are typically"
9 | ],
10 | "properties": {
11 | "relation_type": "bias",
12 | "domain_name": "occupation",
13 | "range_name": "age",
14 | "symmetric": false
15 | },
16 | "samples": [
17 | {
18 | "subject": "interns",
19 | "object": "young"
20 | },
21 | {
22 | "subject": "mentors",
23 | "object": "old"
24 | },
25 | {
26 | "subject": "professors",
27 | "object": "old"
28 | },
29 | {
30 | "subject": "graphic designers",
31 | "object": "young"
32 | },
33 | {
34 | "subject": "advisors",
35 | "object": "old"
36 | },
37 | {
38 | "subject": "personal assistants",
39 | "object": "young"
40 | },
41 | {
42 | "subject": "politicians",
43 | "object": "old"
44 | },
45 | {
46 | "subject": "writers",
47 | "object": "old"
48 | },
49 | {
50 | "subject": "analysts",
51 | "object": "old"
52 | },
53 | {
54 | "subject": "executives",
55 | "object": "old"
56 | },
57 | {
58 | "subject": "content creators",
59 | "object": "young"
60 | },
61 | {
62 | "subject": "doctors",
63 | "object": "old"
64 | },
65 | {
66 | "subject": "freelancers",
67 | "object": "old"
68 | },
69 | {
70 | "subject": "judges",
71 | "object": "old"
72 | },
73 | {
74 | "subject": "bloggers",
75 | "object": "young"
76 | },
77 | {
78 | "subject": "baristas",
79 | "object": "young"
80 | },
81 | {
82 | "subject": "social media influencers",
83 | "object": "young"
84 | },
85 | {
86 | "subject": "architects",
87 | "object": "old"
88 | },
89 | {
90 | "subject": "web developers",
91 | "object": "young"
92 | },
93 | {
94 | "subject": "videographers",
95 | "object": "young"
96 | },
97 | {
98 | "subject": "astronauts",
99 | "object": "old"
100 | },
101 | {
102 | "subject": "biologists",
103 | "object": "old"
104 | },
105 | {
106 | "subject": "chefs",
107 | "object": "old"
108 | },
109 | {
110 | "subject": "dancers",
111 | "object": "young"
112 | },
113 | {
114 | "subject": "engineers",
115 | "object": "old"
116 | },
117 | {
118 | "subject": "farmers",
119 | "object": "old"
120 | },
121 | {
122 | "subject": "game designers",
123 | "object": "young"
124 | },
125 | {
126 | "subject": "hiking guides",
127 | "object": "young"
128 | },
129 | {
130 | "subject": "investigators",
131 | "object": "old"
132 | },
133 | {
134 | "subject": "journalists",
135 | "object": "old"
136 | },
137 | {
138 | "subject": "kindergarten teachers",
139 | "object": "young"
140 | },
141 | {
142 | "subject": "librarians",
143 | "object": "old"
144 | },
145 | {
146 | "subject": "musicians",
147 | "object": "young"
148 | },
149 | {
150 | "subject": "nurses",
151 | "object": "old"
152 | },
153 | {
154 | "subject": "ophthalmologists",
155 | "object": "old"
156 | },
157 | {
158 | "subject": "psychologists",
159 | "object": "old"
160 | },
161 | {
162 | "subject": "quantum physicists",
163 | "object": "old"
164 | },
165 | {
166 | "subject": "restaurant managers",
167 | "object": "old"
168 | },
169 | {
170 | "subject": "scientists",
171 | "object": "old"
172 | },
173 | {
174 | "subject": "tattoo artists",
175 | "object": "young"
176 | },
177 | {
178 | "subject": "university students",
179 | "object": "young"
180 | },
181 | {
182 | "subject": "veterinarians",
183 | "object": "old"
184 | },
185 | {
186 | "subject": "waitstaff",
187 | "object": "young"
188 | },
189 | {
190 | "subject": "yoga instructors",
191 | "object": "young"
192 | },
193 | {
194 | "subject": "zookeepers",
195 | "object": "old"
196 | }
197 | ]
198 | }
--------------------------------------------------------------------------------
/data/bias/occupation_gender.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "occupation gender",
3 | "prompt_templates": [
4 | "{} are typically"
5 | ],
6 | "prompt_templates_zs": [
7 | "Usually, what gender (men or women) are {}? They are typically",
8 | "What gender (men or women) are {} associated with? They are usually",
9 | "Is the profession of {} usually for men or women? It is usually for"
10 | ],
11 | "properties": {
12 | "relation_type": "bias",
13 | "domain_name": "occupation",
14 | "range_name": "gender",
15 | "symmetric": false
16 | },
17 | "samples": [
18 | {
19 | "subject": "nurses",
20 | "object": "women"
21 | },
22 | {
23 | "subject": "pilots",
24 | "object": "men"
25 | },
26 | {
27 | "subject": "surgeons",
28 | "object": "men"
29 | },
30 | {
31 | "subject": "teachers",
32 | "object": "women"
33 | },
34 | {
35 | "subject": "psychologists",
36 | "object": "women"
37 | },
38 | {
39 | "subject": "librarians",
40 | "object": "women"
41 | },
42 | {
43 | "subject": "chefs",
44 | "object": "men"
45 | },
46 | {
47 | "subject": "plumbers",
48 | "object": "men"
49 | },
50 | {
51 | "subject": "engineers",
52 | "object": "men"
53 | },
54 | {
55 | "subject": "dietitians",
56 | "object": "women"
57 | },
58 | {
59 | "subject": "hair stylists",
60 | "object": "women"
61 | },
62 | {
63 | "subject": "accountants",
64 | "object": "men"
65 | },
66 | {
67 | "subject": "entrepreneurs",
68 | "object": "men"
69 | },
70 | {
71 | "subject": "social workers",
72 | "object": "women"
73 | },
74 | {
75 | "subject": "receptionists",
76 | "object": "women"
77 | },
78 | {
79 | "subject": "software developers",
80 | "object": "men"
81 | },
82 | {
83 | "subject": "researchers",
84 | "object": "men"
85 | },
86 | {
87 | "subject": "presidents",
88 | "object": "men"
89 | },
90 | {
91 | "subject": "professors",
92 | "object": "men"
93 | }
94 | ]
95 | }
96 |
--------------------------------------------------------------------------------
/data/commonsense/fruit_inside_color.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "fruit inside color",
3 | "prompt_templates": [
4 | "On the inside, {} are"
5 | ],
6 | "prompt_templates_zs": [
7 | "What color are {} on the inside? They are"
8 | ],
9 | "properties": {
10 | "relation_type": "commonsense",
11 | "domain_name": "fruit",
12 | "range_name": "color",
13 | "symmetric": false
14 | },
15 | "samples": [
16 | {
17 | "subject": "bananas",
18 | "object": "white"
19 | },
20 | {
21 | "subject": "apples",
22 | "object": "white"
23 | },
24 | {
25 | "subject": "watermelons",
26 | "object": "red"
27 | },
28 | {
29 | "subject": "kiwis",
30 | "object": "green"
31 | },
32 | {
33 | "subject": "dragon fruits",
34 | "object": "white"
35 | },
36 | {
37 | "subject": "eggplants",
38 | "object": "white"
39 | },
40 | {
41 | "subject": "zucchinis",
42 | "object": "white"
43 | },
44 | {
45 | "subject": "pineapples",
46 | "object": "yellow"
47 | },
48 | {
49 | "subject": "mangoes",
50 | "object": "orange"
51 | },
52 | {
53 | "subject": "cucumbers",
54 | "object": "white"
55 | },
56 | {
57 | "subject": "radishes",
58 | "object": "white"
59 | },
60 | {
61 | "subject": "passion fruits",
62 | "object": "yellow"
63 | },
64 | {
65 | "subject": "nectarines",
66 | "object": "yellow"
67 | },
68 | {
69 | "subject": "plums",
70 | "object": "red"
71 | },
72 | {
73 |             "subject": "potatoes",
74 | "object": "white"
75 | },
76 | {
77 | "subject": "strawberries",
78 | "object": "red"
79 | },
80 | {
81 | "subject": "avocados",
82 | "object": "green"
83 | },
84 | {
85 | "subject": "peaches",
86 | "object": "yellow"
87 | },
88 | {
89 | "subject": "pomegranates",
90 | "object": "red"
91 | },
92 | {
93 | "subject": "cherries",
94 | "object": "red"
95 | },
96 | {
97 | "subject": "grapes",
98 | "object": "green"
99 | },
100 | {
101 | "subject": "blueberries",
102 | "object": "green"
103 | },
104 | {
105 | "subject": "oranges",
106 | "object": "orange"
107 | },
108 | {
109 | "subject": "lemons",
110 | "object": "yellow"
111 | },
112 | {
113 | "subject": "limes",
114 | "object": "green"
115 | },
116 | {
117 | "subject": "grapefruits",
118 | "object": "pink"
119 | },
120 | {
121 | "subject": "blackberries",
122 | "object": "red"
123 | },
124 | {
125 | "subject": "raspberries",
126 | "object": "red"
127 | },
128 | {
129 | "subject": "papayas",
130 | "object": "orange"
131 | },
132 | {
133 | "subject": "apricots",
134 | "object": "orange"
135 | },
136 | {
137 | "subject": "tomatoes",
138 | "object": "red"
139 | },
140 | {
141 | "subject": "bell peppers",
142 | "object": "red"
143 | },
144 | {
145 | "subject": "persimmons",
146 | "object": "orange"
147 | },
148 | {
149 | "subject": "lychees",
150 | "object": "white"
151 | },
152 | {
153 | "subject": "coconuts",
154 | "object": "white"
155 | },
156 | {
157 | "subject": "cantaloupes",
158 | "object": "orange"
159 | }
160 | ]
161 | }
--------------------------------------------------------------------------------
/data/commonsense/fruit_outside_color.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "fruit outside color",
3 | "prompt_templates": [
4 | "On the outside, {} are"
5 | ],
6 | "prompt_templates_zs": [
7 | "What color are {} on the outside? They are the color of"
8 | ],
9 | "properties": {
10 | "relation_type": "commonsense",
11 | "domain_name": "fruit",
12 | "range_name": "color",
13 | "symmetric": false
14 | },
15 | "samples": [
16 | {
17 | "subject": "bananas",
18 | "object": "yellow"
19 | },
20 | {
21 | "subject": "apples",
22 | "object": "red"
23 | },
24 | {
25 | "subject": "watermelons",
26 | "object": "green"
27 | },
28 | {
29 | "subject": "kiwis",
30 | "object": "brown"
31 | },
32 | {
33 | "subject": "dragon fruits",
34 | "object": "pink"
35 | },
36 | {
37 | "subject": "eggplants",
38 | "object": "purple"
39 | },
40 | {
41 | "subject": "zucchinis",
42 | "object": "green"
43 | },
44 | {
45 | "subject": "pineapples",
46 | "object": "brown"
47 | },
48 | {
49 | "subject": "mangoes",
50 | "object": "green"
51 | },
52 | {
53 | "subject": "cucumbers",
54 | "object": "green"
55 | },
56 | {
57 | "subject": "radishes",
58 | "object": "pink"
59 | },
60 | {
61 | "subject": "passion fruits",
62 | "object": "purple"
63 | },
64 | {
65 | "subject": "nectarines",
66 | "object": "red"
67 | },
68 | {
69 | "subject": "plums",
70 | "object": "purple"
71 | },
72 | {
73 |             "subject": "potatoes",
74 | "object": "brown"
75 | },
76 | {
77 | "subject": "grapefruits",
78 | "object": "orange"
79 | },
80 | {
81 | "subject": "limes",
82 | "object": "green"
83 | },
84 | {
85 | "subject": "oranges",
86 | "object": "orange"
87 | },
88 | {
89 | "subject": "peaches",
90 | "object": "pink"
91 | },
92 | {
93 | "subject": "pomegranates",
94 | "object": "red"
95 | },
96 | {
97 | "subject": "cherries",
98 | "object": "red"
99 | },
100 | {
101 | "subject": "strawberries",
102 | "object": "red"
103 | },
104 | {
105 | "subject": "lemons",
106 | "object": "yellow"
107 | },
108 | {
109 | "subject": "avocados",
110 | "object": "green"
111 | },
112 | {
113 | "subject": "coconuts",
114 | "object": "brown"
115 | },
116 | {
117 | "subject": "blueberries",
118 | "object": "blue"
119 | },
120 | {
121 | "subject": "apricots",
122 | "object": "orange"
123 | },
124 | {
125 | "subject": "blackberries",
126 | "object": "black"
127 | },
128 | {
129 | "subject": "raspberries",
130 | "object": "red"
131 | },
132 | {
133 | "subject": "figs",
134 | "object": "purple"
135 | }
136 | ]
137 | }
--------------------------------------------------------------------------------
/data/commonsense/object_superclass.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "object superclass",
3 | "prompt_templates": [
4 | "A {} is in the category of a"
5 | ],
6 | "prompt_templates_zs": [
7 |         "Is a {} an animal, bird, fish, vegetable, flower, or fruit? It is a",
8 | "What superclass (animal, bird, fish, vegetable, flower, fruit) is a {}? It is a"
9 | ],
10 | "properties": {
11 | "relation_type": "commonsense",
12 | "domain_name": "object",
13 | "range_name": "superclass",
14 | "symmetric": false
15 | },
16 | "samples": [
17 | {
18 | "subject": "tiger",
19 | "object": "animal"
20 | },
21 | {
22 | "subject": "eagle",
23 | "object": "bird"
24 | },
25 | {
26 | "subject": "salmon",
27 | "object": "fish"
28 | },
29 | {
30 | "subject": "spinach",
31 | "object": "vegetable"
32 | },
33 | {
34 | "subject": "rose",
35 | "object": "flower"
36 | },
37 | {
38 | "subject": "mango",
39 | "object": "fruit"
40 | },
41 | {
42 | "subject": "elephant",
43 | "object": "animal"
44 | },
45 | {
46 | "subject": "ostrich",
47 | "object": "bird"
48 | },
49 | {
50 | "subject": "dolphin",
51 | "object": "mammal"
52 | },
53 | {
54 | "subject": "giraffe",
55 | "object": "animal"
56 | },
57 | {
58 | "subject": "dog",
59 | "object": "mammal"
60 | },
61 | {
62 | "subject": "cat",
63 | "object": "mammal"
64 | },
65 | {
66 | "subject": "lion",
67 | "object": "animal"
68 | },
69 | {
70 | "subject": "snake",
71 | "object": "reptile"
72 | },
73 | {
74 | "subject": "turtle",
75 | "object": "reptile"
76 | },
77 | {
78 | "subject": "fish",
79 | "object": "fish"
80 | },
81 | {
82 | "subject": "goldfish",
83 | "object": "fish"
84 | },
85 | {
86 | "subject": "shark",
87 | "object": "fish"
88 | },
89 | {
90 | "subject": "whale",
91 | "object": "mammal"
92 | },
93 | {
94 | "subject": "crocodile",
95 | "object": "reptile"
96 | },
97 | {
98 | "subject": "lizard",
99 | "object": "reptile"
100 | },
101 | {
102 | "subject": "frog",
103 | "object": "amphibian"
104 | },
105 | {
106 | "subject": "toad",
107 | "object": "amphibian"
108 | },
109 | {
110 | "subject": "tree",
111 | "object": "plant"
112 | },
113 | {
114 | "subject": "flower",
115 | "object": "plant"
116 | },
117 | {
118 | "subject": "grass",
119 | "object": "plant"
120 | },
121 | {
122 | "subject": "weed",
123 | "object": "plant"
124 | },
125 | {
126 | "subject": "banana",
127 | "object": "plant"
128 | },
129 | {
130 | "subject": "apple",
131 | "object": "plant"
132 | },
133 | {
134 | "subject": "orange",
135 | "object": "plant"
136 | },
137 | {
138 | "subject": "pineapple",
139 | "object": "plant"
140 | },
141 | {
142 | "subject": "carrot",
143 | "object": "plant"
144 | },
145 | {
146 | "subject": "potato",
147 | "object": "plant"
148 | },
149 | {
150 | "subject": "onion",
151 | "object": "plant"
152 | },
153 | {
154 | "subject": "tomato",
155 | "object": "plant"
156 | },
157 | {
158 | "subject": "cucumber",
159 | "object": "plant"
160 | },
161 | {
162 | "subject": "trout",
163 | "object": "fish"
164 | },
165 | {
166 | "subject": "bass",
167 | "object": "fish"
168 | },
169 | {
170 | "subject": "carp",
171 | "object": "fish"
172 | },
173 | {
174 | "subject": "catfish",
175 | "object": "fish"
176 | },
177 | {
178 | "subject": "tilapia",
179 | "object": "fish"
180 | },
181 | {
182 | "subject": "goldfish",
183 | "object": "fish"
184 | },
185 | {
186 | "subject": "salmon",
187 | "object": "fish"
188 | },
189 | {
190 | "subject": "tuna",
191 | "object": "fish"
192 | },
193 | {
194 | "subject": "swordfish",
195 | "object": "fish"
196 | },
197 | {
198 | "subject": "shark",
199 | "object": "fish"
200 | },
201 | {
202 | "subject": "sparrow",
203 | "object": "bird"
204 | },
205 | {
206 | "subject": "crow",
207 | "object": "bird"
208 | },
209 | {
210 | "subject": "robin",
211 | "object": "bird"
212 | },
213 | {
214 | "subject": "blue jay",
215 | "object": "bird"
216 | },
217 | {
218 | "subject": "owl",
219 | "object": "bird"
220 | },
221 | {
222 | "subject": "eagle",
223 | "object": "bird"
224 | },
225 | {
226 | "subject": "hawk",
227 | "object": "bird"
228 | },
229 | {
230 | "subject": "turkey",
231 | "object": "bird"
232 | },
233 | {
234 | "subject": "duck",
235 | "object": "bird"
236 | },
237 | {
238 | "subject": "goose",
239 | "object": "bird"
240 | },
241 | {
242 | "subject": "carrot",
243 | "object": "vegetable"
244 | },
245 | {
246 | "subject": "potato",
247 | "object": "vegetable"
248 | },
249 | {
250 | "subject": "tomato",
251 | "object": "vegetable"
252 | },
253 | {
254 | "subject": "lettuce",
255 | "object": "vegetable"
256 | },
257 | {
258 | "subject": "cucumber",
259 | "object": "vegetable"
260 | },
261 | {
262 | "subject": "onion",
263 | "object": "vegetable"
264 | },
265 | {
266 | "subject": "bell pepper",
267 | "object": "vegetable"
268 | },
269 | {
270 | "subject": "broccoli",
271 | "object": "vegetable"
272 | },
273 | {
274 | "subject": "cauliflower",
275 | "object": "vegetable"
276 | },
277 | {
278 | "subject": "kale",
279 | "object": "vegetable"
280 | },
281 | {
282 | "subject": "apple",
283 | "object": "fruit"
284 | },
285 | {
286 | "subject": "banana",
287 | "object": "fruit"
288 | },
289 | {
290 | "subject": "orange",
291 | "object": "fruit"
292 | },
293 | {
294 | "subject": "grapefruit",
295 | "object": "fruit"
296 | },
297 | {
298 | "subject": "grapes",
299 | "object": "fruit"
300 | },
301 | {
302 | "subject": "peach",
303 | "object": "fruit"
304 | },
305 | {
306 | "subject": "pear",
307 | "object": "fruit"
308 | },
309 | {
310 | "subject": "watermelon",
311 | "object": "fruit"
312 | },
313 | {
314 | "subject": "strawberry",
315 | "object": "fruit"
316 | },
317 | {
318 | "subject": "blueberry",
319 | "object": "fruit"
320 | }
321 | ]
322 | }
--------------------------------------------------------------------------------
/data/commonsense/substance_phase.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "substance phase of matter",
3 | "prompt_templates": [
4 | "{}'s phase of matter at room temperature is a"
5 | ],
6 | "prompt_templates_zs": [
7 | "{}'s phase of matter at room temperature is a",
8 | "What is the phase of matter of {} at room temperature? It is"
9 | ],
10 | "properties": {
11 | "relation_type": "commonsense",
12 | "domain_name": "substance",
13 | "range_name": "phase of matter",
14 | "symmetric": false
15 | },
16 | "samples": [
17 | {
18 | "subject": "water",
19 | "object": "liquid"
20 | },
21 | {
22 | "subject": "iron",
23 | "object": "solid"
24 | },
25 | {
26 | "subject": "oxygen",
27 | "object": "gas"
28 | },
29 | {
30 | "subject": "gold",
31 | "object": "solid"
32 | },
33 | {
34 | "subject": "mercury",
35 | "object": "liquid"
36 | },
37 | {
38 | "subject": "aluminum",
39 | "object": "solid"
40 | },
41 | {
42 | "subject": "nitrogen",
43 | "object": "gas"
44 | },
45 | {
46 | "subject": "silicon",
47 | "object": "solid"
48 | },
49 | {
50 | "subject": "neon",
51 | "object": "gas"
52 | },
53 | {
54 | "subject": "ethanol",
55 | "object": "liquid"
56 | },
57 | {
58 | "subject": "sulfur",
59 | "object": "solid"
60 | },
61 | {
62 | "subject": "helium",
63 | "object": "gas"
64 | },
65 | {
66 | "subject": "lead",
67 | "object": "solid"
68 | },
69 | {
70 | "subject": "ice cream",
71 | "object": "solid"
72 | },
73 | {
74 | "subject": "coffee",
75 | "object": "liquid"
76 | },
77 | {
78 | "subject": "wood",
79 | "object": "solid"
80 | },
81 | {
82 | "subject": "plastic",
83 | "object": "solid"
84 | },
85 | {
86 | "subject": "butter",
87 | "object": "solid"
88 | },
89 | {
90 | "subject": "honey",
91 | "object": "liquid"
92 | },
93 | {
94 | "subject": "wine",
95 | "object": "liquid"
96 | },
97 | {
98 | "subject": "glass",
99 | "object": "solid"
100 | },
101 | {
102 | "subject": "carbon dioxide",
103 | "object": "gas"
104 | },
105 | {
106 | "subject": "paper",
107 | "object": "solid"
108 | },
109 | {
110 | "subject": "copper",
111 | "object": "solid"
112 | },
113 | {
114 | "subject": "chocolate",
115 | "object": "solid"
116 | },
117 | {
118 | "subject": "petroleum",
119 | "object": "liquid"
120 | },
121 | {
122 | "subject": "steam",
123 | "object": "gas"
124 | },
125 | {
126 | "subject": "ice",
127 | "object": "solid"
128 | },
129 | {
130 | "subject": "diamond",
131 | "object": "solid"
132 | },
133 | {
134 | "subject": "milk",
135 | "object": "liquid"
136 | },
137 | {
138 | "subject": "olive oil",
139 | "object": "liquid"
140 | },
141 | {
142 | "subject": "soap",
143 | "object": "solid"
144 | },
145 | {
146 | "subject": "rubber",
147 | "object": "solid"
148 | },
149 | {
150 | "subject": "glycerin",
151 | "object": "liquid"
152 | },
153 | {
154 | "subject": "tea",
155 | "object": "liquid"
156 | },
157 | {
158 | "subject": "hydrogen",
159 | "object": "gas"
160 | },
161 | {
162 | "subject": "salt",
163 | "object": "solid"
164 | },
165 | {
166 | "subject": "sugar",
167 | "object": "solid"
168 | },
169 | {
170 | "subject": "vinegar",
171 | "object": "liquid"
172 | },
173 | {
174 | "subject": "silver",
175 | "object": "solid"
176 | },
177 | {
178 | "subject": "leather",
179 | "object": "solid"
180 | },
181 | {
182 | "subject": "argon",
183 | "object": "gas"
184 | },
185 | {
186 | "subject": "wax",
187 | "object": "solid"
188 | },
189 | {
190 | "subject": "beer",
191 | "object": "liquid"
192 | },
193 | {
194 | "subject": "radium",
195 | "object": "solid"
196 | },
197 | {
198 | "subject": "platinum",
199 | "object": "solid"
200 | },
201 | {
202 | "subject": "juice",
203 | "object": "liquid"
204 | },
205 | {
206 | "subject": "tungsten",
207 | "object": "solid"
208 | },
209 | {
210 | "subject": "kerosene",
211 | "object": "liquid"
212 | },
213 | {
214 | "subject": "champagne",
215 | "object": "liquid"
216 | }
217 | ]
218 | }
219 |
--------------------------------------------------------------------------------
/data/commonsense/task_done_by_person.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "task person type",
3 | "prompt_templates": [
4 | "{} is best suited for someone with the role of a"
5 | ],
6 | "prompt_templates_zs": [
7 | "The task of {} would be best performed by someone with the role of a",
8 | "The professional role most suited to handle {} is a"
9 | ],
10 | "properties": {
11 | "relation_type": "commonsense",
12 | "domain_name": "task",
13 | "range_name": "occupation type",
14 | "symmetric": false
15 | },
16 | "samples": [
17 | {
18 | "subject": "researching history",
19 | "object": "historian"
20 | },
21 | {
22 | "subject": "delivering mail",
23 | "object": "mail carrier"
24 | },
25 | {
26 | "subject": "photographing weddings",
27 | "object": "photographer"
28 | },
29 | {
30 | "subject": "baking cakes",
31 | "object": "baker"
32 | },
33 | {
34 | "subject": "leading teams",
35 | "object": "leader"
36 | },
37 | {
38 | "subject": "directing movies",
39 | "object": "director"
40 | },
41 | {
42 | "subject": "investigating diseases",
43 | "object": "epidemiologist"
44 | },
45 | {
46 | "subject": "designing buildings",
47 | "object": "architect"
48 | },
49 | {
50 | "subject": "playing piano concertos",
51 | "object": "pianist"
52 | },
53 | {
54 | "subject": "translating books",
55 | "object": "translator"
56 | },
57 | {
58 | "subject": "repairing computers",
59 | "object": "technician"
60 | },
61 | {
62 | "subject": "selling houses",
63 | "object": "real estate agent"
64 | },
65 | {
66 | "subject": "managing hotels",
67 | "object": "hotel manager"
68 | },
69 | {
70 | "subject": "farming",
71 | "object": "farmer"
72 | },
73 | {
74 | "subject": "rescuing mountaineers",
75 | "object": "rescuer"
76 | },
77 | {
78 | "subject": "analyzing genetics",
79 | "object": "geneticist"
80 | },
81 | {
82 | "subject": "flying airplanes",
83 | "object": "pilot"
84 | },
85 | {
86 | "subject": "writing novels",
87 | "object": "author"
88 | },
89 | {
90 | "subject": "investigating crimes",
91 | "object": "detective"
92 | },
93 | {
94 | "subject": "making clothes",
95 | "object": "fashion designer"
96 | },
97 | {
98 | "subject": "conducting an orchestra",
99 | "object": "conductor"
100 | },
101 | {
102 | "subject": "building bridges",
103 | "object": "civil engineer"
104 | },
105 | {
106 | "subject": "treating animals",
107 | "object": "veterinarian"
108 | },
109 | {
110 | "subject": "driving trucks",
111 | "object": "truck driver"
112 | },
113 | {
114 | "subject": "providing legal advice",
115 | "object": "lawyer"
116 | },
117 | {
118 | "subject": "coding software",
119 | "object": "software engineer"
120 | },
121 | {
122 | "subject": "reporting news",
123 | "object": "journalist"
124 | },
125 | {
126 | "subject": "managing finances",
127 | "object": "financial advisor"
128 | },
129 | {
130 | "subject": "exploring space",
131 | "object": "astronaut"
132 | },
133 | {
134 | "subject": "teaching students",
135 | "object": "teacher"
136 | },
137 | {
138 | "subject": "cooking meals",
139 | "object": "chef"
140 | },
141 | {
142 | "subject": "performing surgeries",
143 | "object": "surgeon"
144 | }
145 | ]
146 | }
147 |
--------------------------------------------------------------------------------
/data/commonsense/task_done_by_tool.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "task done by tool",
3 | "prompt_templates": [
4 | "The tool used for {} is called a"
5 | ],
6 | "prompt_templates_zs": [
7 | "What tool is used for {}? Usually, you need a",
8 | "To accomplish {}, you need a tool called a"
9 | ],
10 | "properties": {
11 | "relation_type": "commonsense",
12 | "domain_name": "task",
13 | "range_name": "tool",
14 | "symmetric": false
15 | },
16 | "samples": [
17 | {
18 | "subject": "hitting nails",
19 | "object": "hammer"
20 | },
21 | {
22 | "subject": "turning screws",
23 | "object": "screwdriver"
24 | },
25 | {
26 | "subject": "cutting",
27 | "object": "knife"
28 | },
29 | {
30 | "subject": "drilling holes",
31 | "object": "drill"
32 | },
33 | {
34 | "subject": "sawing wood",
35 | "object": "saw"
36 | },
37 | {
38 | "subject": "scraping paint",
39 | "object": "scraper"
40 | },
41 | {
42 | "subject": "painting walls",
43 | "object": "paintbrush"
44 | },
45 | {
46 | "subject": "mopping floors",
47 | "object": "mop"
48 | },
49 | {
50 | "subject": "sweeping floors",
51 | "object": "broom"
52 | },
53 | {
54 | "subject": "washing dishes",
55 | "object": "sponge"
56 | },
57 | {
58 | "subject": "ironing clothes",
59 | "object": "iron"
60 | },
61 | {
62 | "subject": "sewing",
63 | "object": "needle and thread"
64 | },
65 | {
66 | "subject": "knitting",
67 | "object": "yarn"
68 | },
69 | {
70 | "subject": "hunting",
71 | "object": "gun"
72 | },
73 | {
74 | "subject": "boating",
75 | "object": "boat"
76 | },
77 | {
78 | "subject": "stirring food",
79 | "object": "spoon"
80 | },
81 | {
82 | "subject": "measuring ingredients",
83 | "object": "cup"
84 | },
85 | {
86 | "subject": "baking",
87 | "object": "oven"
88 | },
89 | {
90 | "subject": "mixing batter",
91 | "object": "whisk"
92 | },
93 | {
94 | "subject": "digging soil",
95 | "object": "shovel"
96 | },
97 | {
98 | "subject": "raking leaves",
99 | "object": "rake"
100 | },
101 | {
102 | "subject": "cleaning windows",
103 | "object": "squeegee"
104 | },
105 | {
106 | "subject": "vacuuming carpets",
107 | "object": "vacuum cleaner"
108 | },
109 | {
110 | "subject": "washing clothes",
111 | "object": "washing machine"
112 | },
113 | {
114 | "subject": "drying clothes",
115 | "object": "clothesline"
116 | },
117 | {
118 | "subject": "polishing shoes",
119 | "object": "shoe polish"
120 | },
121 | {
122 | "subject": "painting furniture",
123 | "object": "paint roller"
124 | },
125 | {
126 | "subject": "sanding wood",
127 | "object": "sandpaper"
128 | },
129 | {
130 | "subject": "hiking",
131 | "object": "hiking boots"
132 | },
133 | {
134 | "subject": "biking",
135 | "object": "bicycle"
136 | },
137 | {
138 | "subject": "swimming",
139 | "object": "swimsuit"
140 | },
141 | {
142 | "subject": "cooking",
143 | "object": "stove"
144 | },
145 | {
146 | "subject": "writing",
147 | "object": "pen and paper"
148 | },
149 | {
150 | "subject": "drawing",
151 | "object": "pencil and sketchbook"
152 | },
153 | {
154 | "subject": "gardening",
155 | "object": "gardening gloves"
156 | },
157 | {
158 | "subject": "photography",
159 | "object": "camera"
160 | },
161 | {
162 | "subject": "playing sports",
163 | "object": "ball"
164 | },
165 | {
166 | "subject": "exercising",
167 | "object": "dumbbells"
168 | },
169 | {
170 | "subject": "dancing",
171 | "object": "music"
172 | },
173 | {
174 | "subject": "watching movies",
175 | "object": "television"
176 | },
177 | {
178 | "subject": "reading",
179 | "object": "book"
180 | },
181 | {
182 | "subject": "listening to music",
183 | "object": "headphones"
184 | },
185 | {
186 | "subject": "singing",
187 | "object": "microphone"
188 | },
189 | {
190 | "subject": "measuring",
191 | "object": "scale"
192 | },
193 | {
194 | "subject": "birdwatching",
195 | "object": "binoculars"
196 | },
197 | {
198 | "subject": "playing basketball",
199 | "object": "basketball"
200 | },
201 | {
202 | "subject": "playing soccer",
203 | "object": "soccer ball"
204 | },
205 | {
206 | "subject": "skateboarding",
207 | "object": "skateboard"
208 | },
209 | {
210 | "subject": "riding a scooter",
211 | "object": "scooter"
212 | },
213 | {
214 | "subject": "flying a kite",
215 | "object": "kite"
216 | },
217 | {
218 | "subject": "doing makeup",
219 | "object": "makeup brushes"
220 | },
221 | {
222 | "subject": "taking photographs",
223 | "object": "camera"
224 | }
225 | ]
226 | }
--------------------------------------------------------------------------------
/data/commonsense/word_sentiment.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "word sentiment",
3 | "prompt_templates": [
4 | "The sentiment of '{}' is"
5 | ],
6 | "prompt_templates_zs": [
7 | "Is the sentiment of the word '{}' positive, negative, or neutral? It is",
8 | "What is the sentiment (positive, negative, neutral) of '{}'? It is"
9 | ],
10 | "properties": {
11 | "relation_type": "commonsense",
12 | "domain_name": "word",
13 | "range_name": "sentiment",
14 | "symmetric": false
15 | },
16 | "samples": [
17 | {
18 | "subject": "happy",
19 | "object": "positive"
20 | },
21 | {
22 | "subject": "joy",
23 | "object": "positive"
24 | },
25 | {
26 | "subject": "love",
27 | "object": "positive"
28 | },
29 | {
30 | "subject": "peace",
31 | "object": "positive"
32 | },
33 | {
34 | "subject": "hope",
35 | "object": "positive"
36 | },
37 | {
38 | "subject": "excited",
39 | "object": "positive"
40 | },
41 | {
42 | "subject": "grateful",
43 | "object": "positive"
44 | },
45 | {
46 | "subject": "proud",
47 | "object": "positive"
48 | },
49 | {
50 | "subject": "blessed",
51 | "object": "positive"
52 | },
53 | {
54 | "subject": "confident",
55 | "object": "positive"
56 | },
57 | {
58 | "subject": "content",
59 | "object": "positive"
60 | },
61 | {
62 | "subject": "satisfied",
63 | "object": "positive"
64 | },
65 | {
66 | "subject": "optimistic",
67 | "object": "positive"
68 | },
69 | {
70 | "subject": "cheerful",
71 | "object": "positive"
72 | },
73 | {
74 | "subject": "ecstatic",
75 | "object": "positive"
76 | },
77 | {
78 | "subject": "delighted",
79 | "object": "positive"
80 | },
81 | {
82 | "subject": "thrilled",
83 | "object": "positive"
84 | },
85 | {
86 | "subject": "overjoyed",
87 | "object": "positive"
88 | },
89 | {
90 | "subject": "elated",
91 | "object": "positive"
92 | },
93 | {
94 | "subject": "blissful",
95 | "object": "positive"
96 | },
97 | {
98 | "subject": "sad",
99 | "object": "negative"
100 | },
101 | {
102 | "subject": "unhappy",
103 | "object": "negative"
104 | },
105 | {
106 | "subject": "depressed",
107 | "object": "negative"
108 | },
109 | {
110 | "subject": "lonely",
111 | "object": "negative"
112 | },
113 | {
114 | "subject": "heartbroken",
115 | "object": "negative"
116 | },
117 | {
118 | "subject": "anxious",
119 | "object": "negative"
120 | },
121 | {
122 | "subject": "frustrated",
123 | "object": "negative"
124 | },
125 | {
126 | "subject": "angry",
127 | "object": "negative"
128 | },
129 | {
130 | "subject": "jealous",
131 | "object": "negative"
132 | },
133 | {
134 | "subject": "hateful",
135 | "object": "negative"
136 | },
137 | {
138 | "subject": "disappointed",
139 | "object": "negative"
140 | },
141 | {
142 | "subject": "gloomy",
143 | "object": "negative"
144 | },
145 | {
146 | "subject": "dejected",
147 | "object": "negative"
148 | },
149 | {
150 | "subject": "hopeless",
151 | "object": "negative"
152 | },
153 | {
154 | "subject": "despairing",
155 | "object": "negative"
156 | },
157 | {
158 | "subject": "frightened",
159 | "object": "negative"
160 | },
161 | {
162 | "subject": "terrified",
163 | "object": "negative"
164 | },
165 | {
166 | "subject": "scared",
167 | "object": "negative"
168 | },
169 | {
170 | "subject": "worried",
171 | "object": "negative"
172 | },
173 | {
174 | "subject": "apprehensive",
175 | "object": "negative"
176 | },
177 | {
178 | "subject": "nervous",
179 | "object": "negative"
180 | },
181 | {
182 | "subject": "neutral",
183 | "object": "neutral"
184 | },
185 | {
186 | "subject": "computer",
187 | "object": "neutral"
188 | },
189 | {
190 | "subject": "car",
191 | "object": "neutral"
192 | },
193 | {
194 | "subject": "house",
195 | "object": "neutral"
196 | },
197 | {
198 | "subject": "tree",
199 | "object": "neutral"
200 | },
201 | {
202 | "subject": "book",
203 | "object": "neutral"
204 | },
205 | {
206 | "subject": "money",
207 | "object": "neutral"
208 | },
209 | {
210 | "subject": "time",
211 | "object": "neutral"
212 | },
213 | {
214 | "subject": "day",
215 | "object": "neutral"
216 | },
217 | {
218 | "subject": "week",
219 | "object": "neutral"
220 | },
221 | {
222 | "subject": "ordinary",
223 | "object": "neutral"
224 | },
225 | {
226 | "subject": "common",
227 | "object": "neutral"
228 | },
229 | {
230 | "subject": "typical",
231 | "object": "neutral"
232 | },
233 | {
234 | "subject": "average",
235 | "object": "neutral"
236 | },
237 | {
238 | "subject": "indifferent",
239 | "object": "neutral"
240 | },
241 | {
242 | "subject": "unbiased",
243 | "object": "neutral"
244 | },
245 | {
246 | "subject": "impartial",
247 | "object": "neutral"
248 | },
249 | {
250 | "subject": "objective",
251 | "object": "neutral"
252 | },
253 | {
254 | "subject": "neutral",
255 | "object": "neutral"
256 | }
257 | ]
258 | }
--------------------------------------------------------------------------------
/data/commonsense/work_location.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "work location",
3 | "prompt_templates": [
4 | "A {} typically works at a"
5 | ],
6 | "prompt_templates_zs": [
7 | "A {} typically works at a",
8 | "You can usually find a {} working in a"
9 | ],
10 | "properties": {
11 | "relation_type": "commonsense",
12 | "domain_name": "occupation",
13 | "range_name": "location",
14 | "symmetric": false
15 | },
16 | "samples": [
17 | {
18 | "subject": "farmer",
19 | "object": "farm"
20 | },
21 | {
22 | "subject": "lawyer",
23 | "object": "courthouse"
24 | },
25 | {
26 | "subject": "teacher",
27 | "object": "school"
28 | },
29 | {
30 | "subject": "accountant",
31 | "object": "office"
32 | },
33 | {
34 | "subject": "artist",
35 | "object": "studio"
36 | },
37 | {
38 | "subject": "athlete",
39 | "object": "stadium"
40 | },
41 | {
42 | "subject": "baker",
43 | "object": "bakery"
44 | },
45 | {
46 | "subject": "barber",
47 | "object": "barbershop"
48 | },
49 | {
50 | "subject": "chef",
51 | "object": "kitchen"
52 | },
53 | {
54 | "subject": "doctor",
55 | "object": "hospital"
56 | },
57 | {
58 | "subject": "fashion designer",
59 | "object": "studio"
60 | },
61 | {
62 | "subject": "firefighter",
63 | "object": "fire station"
64 | },
65 | {
66 | "subject": "florist",
67 | "object": "flower shop"
68 | },
69 | {
70 | "subject": "flight attendant",
71 | "object": "airplane"
72 | },
73 | {
74 | "subject": "hairdresser",
75 | "object": "salon"
76 | },
77 | {
78 | "subject": "historian",
79 | "object": "library"
80 | },
81 | {
82 | "subject": "insurance agent",
83 | "object": "office"
84 | },
85 | {
86 | "subject": "journalist",
87 | "object": "office"
88 | },
89 | {
90 | "subject": "librarian",
91 | "object": "library"
92 | },
93 | {
94 | "subject": "mechanic",
95 | "object": "garage"
96 | },
97 | {
98 | "subject": "musician",
99 | "object": "concert hall"
100 | },
101 | {
102 | "subject": "nurse",
103 | "object": "hospital"
104 | },
105 | {
106 | "subject": "painter",
107 | "object": "studio"
108 | },
109 | {
110 | "subject": "pharmacist",
111 | "object": "pharmacy"
112 | },
113 | {
114 | "subject": "photographer",
115 | "object": "studio"
116 | },
117 | {
118 | "subject": "pilot",
119 | "object": "airplane"
120 | },
121 | {
122 | "subject": "researcher",
123 | "object": "laboratory"
124 | },
125 | {
126 | "subject": "salesperson",
127 | "object": "store"
128 | },
129 | {
130 | "subject": "scientist",
131 | "object": "laboratory"
132 | },
133 | {
134 | "subject": "secretary",
135 | "object": "office"
136 | },
137 | {
138 | "subject": "soldier",
139 | "object": "military base"
140 | },
141 | {
142 | "subject": "software engineer",
143 | "object": "office"
144 | },
145 | {
146 | "subject": "student",
147 | "object": "school"
148 | },
149 | {
150 | "subject": "surgeon",
151 | "object": "hospital"
152 | },
153 | {
154 | "subject": "trainer",
155 | "object": "gym"
156 | },
157 | {
158 | "subject": "truck driver",
159 | "object": "truck"
160 | },
161 | {
162 | "subject": "waitress",
163 | "object": "restaurant"
164 | },
165 | {
166 | "subject": "waiter",
167 | "object": "restaurant"
168 | }
169 | ]
170 | }
--------------------------------------------------------------------------------
/data/factual/city_in_country.json:
--------------------------------------------------------------------------------
1 | {
2 | "name":"city in country",
3 | "prompt_templates":[
4 | "{} is part of",
5 | "{} is in the country of"
6 | ],
7 | "prompt_templates_zs":[
8 | "{} is part of the country of",
9 | "{} is located in the country of"
10 | ],
11 | "properties":{
12 | "relation_type":"factual",
13 | "domain_name":"city",
14 | "range_name":"country",
15 | "symmetric":false
16 | },
17 | "samples":[
18 | {
19 | "subject":"New York City",
20 | "object":"United States"
21 | },
22 | {
23 | "subject":"Rio de Janeiro",
24 | "object":"Brazil"
25 | },
26 | {
27 | "subject":"Buenos Aires",
28 | "object":"Argentina"
29 | },
30 | {
31 | "subject":"Mexico City",
32 | "object":"Mexico"
33 | },
34 | {
35 | "subject":"São Paulo",
36 | "object":"Brazil"
37 | },
38 | {
39 | "subject":"Los Angeles",
40 | "object":"United States"
41 | },
42 | {
43 | "subject":"Saint Petersburg",
44 | "object":"Russia"
45 | },
46 | {
47 | "subject":"San Francisco",
48 | "object":"United States"
49 | },
50 | {
51 | "subject":"Ho Chi Minh City",
52 | "object":"Vietnam"
53 | },
54 | {
55 | "subject":"Kuala Lumpur",
56 | "object":"Malaysia"
57 | },
58 | {
59 | "subject":"Abu Dhabi",
60 | "object":"United Arab Emirates"
61 | },
62 | {
63 | "subject":"Cape Town",
64 | "object":"South Africa"
65 | },
66 | {
67 | "subject":"New Delhi",
68 | "object":"India"
69 | },
70 | {
71 | "subject":"Las Vegas",
72 | "object":"United States"
73 | },
74 | {
75 | "subject":"Hong Kong",
76 | "object":"China"
77 | },
78 | {
79 | "subject":"Tel Aviv",
80 | "object":"Israel"
81 | },
82 | {
83 | "subject":"Johannesburg",
84 | "object":"South Africa"
85 | },
86 | {
87 | "subject":"Santo Domingo",
88 | "object":"Dominican Republic"
89 | },
90 | {
91 | "subject":"Port-au-Prince",
92 | "object":"Haiti"
93 | },
94 | {
95 | "subject":"Santiago de Chile",
96 | "object":"Chile"
97 | },
98 | {
99 | "subject":"Panama City",
100 | "object":"Panama"
101 | },
102 | {
103 | "subject":"Siem Reap",
104 | "object":"Cambodia"
105 | },
106 | {
107 | "subject":"Casablanca",
108 | "object":"Morocco"
109 | },
110 | {
111 | "subject":"San Juan",
112 | "object":"Puerto Rico"
113 | },
114 | {
115 |          "subject":"San José",
116 |          "object":"Costa Rica"
117 | },
118 | {
119 | "subject":"Addis Ababa",
120 | "object":"Ethiopia"
121 | },
122 | {
123 | "subject":"Punta Cana",
124 | "object":"Dominican Republic"
125 | }
126 | ]
127 | }
--------------------------------------------------------------------------------
/data/factual/country_capital_city.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "country capital city",
3 | "prompt_templates": [
4 | "The capital city of {} is",
5 | "The capital of {} is"
6 | ],
7 | "prompt_templates_zs": [
8 | "The capital of {} is the city of",
9 | "What is the capital of {}? It is the city of"
10 | ],
11 | "properties": {
12 | "relation_type": "factual",
13 | "domain_name": "country",
14 | "range_name": "city",
15 | "symmetric": false
16 | },
17 | "samples": [
18 | {
19 | "subject": "United States",
20 | "object": "Washington D.C."
21 | },
22 | {
23 | "subject": "Canada",
24 | "object": "Ottawa"
25 | },
26 | {
27 | "subject": "Mexico",
28 | "object": "Mexico City"
29 | },
30 | {
31 | "subject": "Brazil",
32 |             "object": "Brasília"
33 | },
34 | {
35 | "subject": "Argentina",
36 | "object": "Buenos Aires"
37 | },
38 | {
39 | "subject": "Chile",
40 | "object": "Santiago"
41 | },
42 | {
43 | "subject": "Peru",
44 | "object": "Lima"
45 | },
46 | {
47 | "subject": "Colombia",
48 |             "object": "Bogotá"
49 | },
50 | {
51 | "subject": "Venezuela",
52 | "object": "Caracas"
53 | },
54 | {
55 | "subject": "Spain",
56 | "object": "Madrid"
57 | },
58 | {
59 | "subject": "France",
60 | "object": "Paris"
61 | },
62 | {
63 | "subject": "Germany",
64 | "object": "Berlin"
65 | },
66 | {
67 | "subject": "Italy",
68 | "object": "Rome"
69 | },
70 | {
71 | "subject": "Russia",
72 | "object": "Moscow"
73 | },
74 | {
75 | "subject": "China",
76 | "object": "Beijing"
77 | },
78 | {
79 | "subject": "Japan",
80 | "object": "Tokyo"
81 | },
82 | {
83 | "subject": "South Korea",
84 | "object": "Seoul"
85 | },
86 | {
87 | "subject": "India",
88 | "object": "New Delhi"
89 | },
90 | {
91 | "subject": "Pakistan",
92 | "object": "Islamabad"
93 | },
94 | {
95 | "subject": "Nigeria",
96 | "object": "Abuja"
97 | },
98 | {
99 | "subject": "Egypt",
100 | "object": "Cairo"
101 | },
102 | {
103 | "subject": "Saudi Arabia",
104 | "object": "Riyadh"
105 | },
106 | {
107 | "subject": "Turkey",
108 | "object": "Ankara"
109 | },
110 | {
111 | "subject": "Australia",
112 | "object": "Canberra"
113 | }
114 | ]
115 | }
116 |
--------------------------------------------------------------------------------
/data/factual/country_currency.json:
--------------------------------------------------------------------------------
1 | {
2 | "name":"country currency",
3 | "prompt_templates":[
4 | "The official currency of {} is the",
5 | "{}'s official currency is the"
6 | ],
7 | "prompt_templates_zs":[
8 | "What is the official currency of {}? It is called the",
9 | "{}'s official currency is called the",
10 | "The name of {}'s currency is the"
11 | ],
12 | "properties":{
13 | "relation_type":"factual",
14 | "domain_name":"country",
15 | "range_name":"currency",
16 | "symmetric":false
17 | },
18 | "samples":[
19 | {
20 | "subject":"United States",
21 | "object":"Dollar"
22 | },
23 | {
24 | "subject":"United Kingdom",
25 | "object":"Pound"
26 | },
27 | {
28 | "subject":"Japan",
29 | "object":"Yen"
30 | },
31 | {
32 | "subject":"Canada",
33 | "object":"Dollar"
34 | },
35 | {
36 | "subject":"Australia",
37 | "object":"Dollar"
38 | },
39 | {
40 | "subject":"Brazil",
41 | "object":"Real"
42 | },
43 | {
44 | "subject":"China",
45 | "object":"Yuan"
46 | },
47 | {
48 | "subject":"India",
49 | "object":"Rupee"
50 | },
51 | {
52 | "subject":"Russia",
53 | "object":"Ruble"
54 | },
55 | {
56 | "subject":"South Africa",
57 | "object":"Rand"
58 | },
59 | {
60 | "subject":"Mexico",
61 | "object":"Peso"
62 | },
63 | {
64 | "subject":"New Zealand",
65 | "object":"Dollar"
66 | },
67 | {
68 | "subject":"South Korea",
69 | "object":"Won"
70 | },
71 | {
72 | "subject":"Switzerland",
73 | "object":"Franc"
74 | },
75 | {
76 | "subject":"Turkey",
77 | "object":"Lira"
78 | },
79 | {
80 | "subject":"Argentina",
81 | "object":"Peso"
82 | },
83 | {
84 | "subject":"Norway",
85 | "object":"Krone"
86 | },
87 | {
88 | "subject":"Sweden",
89 | "object":"Krona"
90 | },
91 | {
92 | "subject":"Denmark",
93 | "object":"Krone"
94 | },
95 | {
96 | "subject":"Poland",
97 | "object":"Zloty"
98 | },
99 | {
100 | "subject":"Hungary",
101 | "object":"Forint"
102 | },
103 | {
104 | "subject":"Czech Republic",
105 | "object":"Koruna"
106 | },
107 | {
108 | "subject":"Israel",
109 | "object":"Shekel"
110 | },
111 | {
112 | "subject":"Saudi Arabia",
113 | "object":"Riyal"
114 | },
115 | {
116 | "subject":"United Arab Emirates",
117 | "object":"Dirham"
118 | },
119 | {
120 | "subject":"Singapore",
121 | "object":"Dollar"
122 | },
123 | {
124 | "subject":"Malaysia",
125 | "object":"Ringgit"
126 | },
127 | {
128 | "subject":"Indonesia",
129 | "object":"Rupiah"
130 | },
131 | {
132 | "subject":"Thailand",
133 | "object":"Baht"
134 | },
135 | {
136 | "subject":"Philippines",
137 | "object":"Peso"
138 | }
139 | ]
140 | }
--------------------------------------------------------------------------------
/data/factual/country_language.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "country language",
3 | "prompt_templates": [
4 | "People in {} speak",
5 | "The language used in {} is",
6 | "In {}, the primary language is"
7 | ],
8 | "prompt_templates_zs": [
9 | "{}, where most people speak",
10 | "In {}, people speak the language of",
11 | "People in {} speak the language of"
12 | ],
13 | "properties": {
14 | "relation_type": "factual",
15 | "domain_name": "country",
16 | "range_name": "language",
17 | "symmetric": false
18 | },
19 | "samples": [
20 | {
21 | "subject": "United States",
22 | "object": "English"
23 | },
24 | {
25 | "subject": "Canada",
26 | "object": "English"
27 | },
28 | {
29 | "subject": "Mexico",
30 | "object": "Spanish"
31 | },
32 | {
33 | "subject": "Brazil",
34 | "object": "Portuguese"
35 | },
36 | {
37 | "subject": "Argentina",
38 | "object": "Spanish"
39 | },
40 | {
41 | "subject": "Chile",
42 | "object": "Spanish"
43 | },
44 | {
45 | "subject": "Peru",
46 | "object": "Spanish"
47 | },
48 | {
49 | "subject": "Colombia",
50 | "object": "Spanish"
51 | },
52 | {
53 | "subject": "Venezuela",
54 | "object": "Spanish"
55 | },
56 | {
57 | "subject": "Spain",
58 | "object": "Spanish"
59 | },
60 | {
61 | "subject": "France",
62 | "object": "French"
63 | },
64 | {
65 | "subject": "Germany",
66 | "object": "German"
67 | },
68 | {
69 | "subject": "Italy",
70 | "object": "Italian"
71 | },
72 | {
73 | "subject": "Russia",
74 | "object": "Russian"
75 | },
76 | {
77 | "subject": "China",
78 | "object": "Mandarin Chinese"
79 | },
80 | {
81 | "subject": "Japan",
82 | "object": "Japanese"
83 | },
84 | {
85 | "subject": "South Korea",
86 | "object": "Korean"
87 | },
88 | {
89 | "subject": "India",
90 | "object": "Hindi"
91 | },
92 | {
93 | "subject": "Pakistan",
94 | "object": "Urdu"
95 | },
96 | {
97 | "subject": "Nigeria",
98 | "object": "English"
99 | },
100 | {
101 | "subject": "Egypt",
102 | "object": "Arabic"
103 | },
104 | {
105 | "subject": "Saudi Arabia",
106 | "object": "Arabic"
107 | },
108 | {
109 | "subject": "Turkey",
110 | "object": "Turkish"
111 | },
112 | {
113 | "subject": "Australia",
114 | "object": "English"
115 | }
116 | ]
117 | }
--------------------------------------------------------------------------------
/data/factual/country_largest_city.json:
--------------------------------------------------------------------------------
1 | {
2 | "name":"country largest city",
3 | "prompt_templates":[
4 | "The largest city in {} is",
5 | "The biggest city in {} is"
6 | ],
7 | "prompt_templates_zs":[
8 | "What is the largest city in {}? It is the city of",
9 | "The largest city in {} is the city of"
10 | ],
11 | "properties":{
12 | "relation_type":"factual",
13 | "domain_name":"country",
14 | "range_name":"city",
15 | "symmetric":false
16 | },
17 | "samples":[
18 | {
19 | "subject":"United States",
20 | "object":"New York City"
21 | },
22 | {
23 | "subject":"China",
24 | "object":"Shanghai"
25 | },
26 | {
27 | "subject":"Japan",
28 | "object":"Tokyo"
29 | },
30 | {
31 | "subject":"Russia",
32 | "object":"Moscow"
33 | },
34 | {
35 | "subject":"India",
36 | "object":"Mumbai"
37 | },
38 | {
39 | "subject":"Brazil",
40 | "object":"São Paulo"
41 | },
42 | {
43 | "subject":"Australia",
44 | "object":"Sydney"
45 | },
46 | {
47 | "subject":"Canada",
48 | "object":"Toronto"
49 | },
50 | {
51 | "subject":"United Kingdom",
52 | "object":"London"
53 | },
54 | {
55 | "subject":"France",
56 | "object":"Paris"
57 | },
58 | {
59 | "subject":"Germany",
60 | "object":"Berlin"
61 | },
62 | {
63 | "subject":"Italy",
64 | "object":"Rome"
65 | },
66 | {
67 | "subject":"Mexico",
68 | "object":"Mexico City"
69 | },
70 | {
71 | "subject":"South Korea",
72 | "object":"Seoul"
73 | },
74 | {
75 | "subject":"Turkey",
76 | "object":"Istanbul"
77 | },
78 | {
79 | "subject":"Spain",
80 | "object":"Madrid"
81 | },
82 | {
83 | "subject":"Argentina",
84 | "object":"Buenos Aires"
85 | },
86 | {
87 | "subject":"South Africa",
88 | "object":"Johannesburg"
89 | },
90 | {
91 | "subject":"Poland",
92 | "object":"Warsaw"
93 | },
94 | {
95 | "subject":"Nigeria",
96 | "object":"Lagos"
97 | },
98 | {
99 | "subject":"New Zealand",
100 | "object":"Auckland"
101 | },
102 | {
103 | "subject":"Switzerland",
104 | "object":"Zurich"
105 | },
106 | {
107 | "subject":"Netherlands",
108 | "object":"Amsterdam"
109 | },
110 | {
111 | "subject":"Pakistan",
112 | "object":"Karachi"
113 | }
114 | ]
115 | }
--------------------------------------------------------------------------------
/data/factual/food_from_country.json:
--------------------------------------------------------------------------------
1 | {
2 | "name":"food from country",
3 | "prompt_templates":[
4 | "{} originates from",
5 | "{} is from the country of"
6 | ],
7 | "prompt_templates_zs":[
8 | "What is the country of origin for {}? It originates from",
9 | "{} originates from the country of"
10 | ],
11 | "properties":{
12 | "relation_type":"factual",
13 | "domain_name":"food",
14 | "range_name":"country",
15 | "symmetric":false
16 | },
17 | "samples":[
18 | {
19 | "subject":"Pizza",
20 | "object":"Italy"
21 | },
22 | {
23 | "subject":"Sushi",
24 | "object":"Japan"
25 | },
26 | {
27 | "subject":"Tacos",
28 | "object":"Mexico"
29 | },
30 | {
31 | "subject":"Baguette",
32 | "object":"France"
33 | },
34 | {
35 | "subject":"Poutine",
36 | "object":"Canada"
37 | },
38 | {
39 | "subject":"Paella",
40 | "object":"Spain"
41 | },
42 | {
43 | "subject":"Chimichurri",
44 | "object":"Argentina"
45 | },
46 | {
47 | "subject":"Baklava",
48 | "object":"Turkey"
49 | },
50 | {
51 | "subject":"Feijoada",
52 | "object":"Brazil"
53 | },
54 | {
55 | "subject":"Borscht",
56 | "object":"Ukraine"
57 | },
58 | {
59 | "subject":"Fish and Chips",
60 | "object":"United Kingdom"
61 | },
62 | {
63 | "subject":"Dim Sum",
64 | "object":"China"
65 | },
66 | {
67 | "subject":"Kimchi",
68 | "object":"South Korea"
69 | },
70 | {
71 | "subject":"Goulash",
72 | "object":"Hungary"
73 | },
74 | {
75 | "subject":"Pierogi",
76 | "object":"Poland"
77 | },
78 | {
79 | "subject":"Pho",
80 | "object":"Vietnam"
81 | },
82 | {
83 | "subject":"Hummus",
84 | "object":"Lebanon"
85 | },
86 | {
87 | "subject":"Gyro",
88 | "object":"Greece"
89 | },
90 | {
91 | "subject":"Masala Dosa",
92 | "object":"India"
93 | },
94 | {
95 | "subject":"Rendang",
96 | "object":"Indonesia"
97 | },
98 | {
99 | "subject":"Moussaka",
100 | "object":"Greece"
101 | },
102 | {
103 | "subject":"Pavlova",
104 | "object":"New Zealand"
105 | },
106 | {
107 | "subject":"Shawarma",
108 | "object":"Middle East"
109 | },
110 | {
111 | "subject":"Falafel",
112 | "object":"Middle East"
113 | },
114 | {
115 | "subject":"Ceviche",
116 | "object":"Peru"
117 | },
118 | {
119 | "subject":"Biryani",
120 | "object":"India"
121 | },
122 | {
123 | "subject":"Wiener Schnitzel",
124 | "object":"Austria"
125 | },
126 | {
127 | "subject":"Fondue",
128 | "object":"Switzerland"
129 | },
130 | {
131 | "subject":"Pad Thai",
132 | "object":"Thailand"
133 | },
134 | {
135 | "subject":"Miso Soup",
136 | "object":"Japan"
137 | }
138 | ]
139 | }
140 |
--------------------------------------------------------------------------------
/data/factual/person_band_lead_singer.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "person lead singer of band",
3 | "prompt_templates": [
4 | "{} is the lead singer of"
5 | ],
6 | "prompt_templates_zs": [
7 | "{} is the lead singer of the band named",
8 |         "What band is {} the lead singer of? They lead the band named"
9 | ],
10 | "properties": {
11 | "relation_type": "factual",
12 | "domain_name": "person",
13 | "range_name": "band",
14 | "symmetric": false
15 | },
16 | "samples": [
17 | {
18 | "subject": "Brian Johnson",
19 | "object": "AC/DC"
20 | },
21 | {
22 | "subject": "Steven Tyler",
23 | "object": "Aerosmith"
24 | },
25 | {
26 | "subject": "Chris Martin",
27 | "object": "Coldplay"
28 | },
29 | {
30 | "subject": "Dave Grohl",
31 | "object": "Foo Fighters"
32 | },
33 | {
34 | "subject": "Axl Rose",
35 | "object": "Guns N' Roses"
36 | },
37 | {
38 | "subject": "Robert Plant",
39 | "object": "Led Zeppelin"
40 | },
41 | {
42 | "subject": "James Hetfield",
43 | "object": "Metallica"
44 | },
45 | {
46 | "subject": "Kurt Cobain",
47 | "object": "Nirvana"
48 | },
49 | {
50 | "subject": "Freddie Mercury",
51 | "object": "Queen"
52 | },
53 | {
54 | "subject": "Anthony Kiedis",
55 | "object": "Red Hot Chili Peppers"
56 | },
57 | {
58 | "subject": "Mick Jagger",
59 | "object": "The Rolling Stones"
60 | },
61 | {
62 | "subject": "John Lennon",
63 | "object": "The Beatles"
64 | },
65 | {
66 | "subject": "Jim Morrison",
67 | "object": "The Doors"
68 | },
69 | {
70 | "subject": "Bono",
71 | "object": "U2"
72 | },
73 | {
74 | "subject": "Ozzy Osbourne",
75 | "object": "Black Sabbath"
76 | },
77 | {
78 | "subject": "Billie Joe Armstrong",
79 | "object": "Green Day"
80 | },
81 | {
82 | "subject": "Eddie Vedder",
83 | "object": "Pearl Jam"
84 | },
85 | {
86 | "subject": "Jon Bon Jovi",
87 |             "object": "Bon Jovi"
88 | },
89 | {
90 | "subject": "Don Henley",
91 | "object": "The Eagles"
92 | },
93 | {
94 | "subject": "Liam Gallagher",
95 | "object": "Oasis"
96 | },
97 | {
98 | "subject": "Serj Tankian",
99 | "object": "System of a Down"
100 | }
101 | ]
102 | }
--------------------------------------------------------------------------------
/data/factual/pokemon_evolutions.json:
--------------------------------------------------------------------------------
1 | {
2 | "name":"pokemon evolution",
3 | "prompt_templates":[
4 | "{} evolves into a"
5 | ],
6 | "prompt_templates_zs":[
7 | "What does {} evolve into? It evolves into a",
8 | "The evolved form of {} is called a"
9 | ],
10 | "properties":{
11 | "relation_type":"factual",
12 | "domain_name":"pokemon",
13 | "range_name":"pokemon",
14 | "symmetric":false
15 | },
16 | "samples":[
17 | {
18 | "subject":"Bulbasaur",
19 | "object":"Ivysaur"
20 | },
21 | {
22 | "subject":"Charmander",
23 | "object":"Charmeleon"
24 | },
25 | {
26 | "subject":"Squirtle",
27 | "object":"Wartortle"
28 | },
29 | {
30 | "subject":"Pikachu",
31 | "object":"Raichu"
32 | },
33 | {
34 | "subject":"Oddish",
35 | "object":"Gloom"
36 | },
37 | {
38 | "subject":"Venonat",
39 | "object":"Venomoth"
40 | },
41 | {
42 | "subject":"Diglett",
43 | "object":"Dugtrio"
44 | },
45 | {
46 | "subject":"Meowth",
47 | "object":"Persian"
48 | },
49 | {
50 | "subject":"Psyduck",
51 | "object":"Golduck"
52 | },
53 | {
54 | "subject":"Mankey",
55 | "object":"Primeape"
56 | },
57 | {
58 | "subject":"Growlithe",
59 | "object":"Arcanine"
60 | },
61 | {
62 | "subject":"Poliwag",
63 | "object":"Poliwhirl"
64 | },
65 | {
66 | "subject":"Abra",
67 | "object":"Kadabra"
68 | },
69 | {
70 | "subject":"Machop",
71 | "object":"Machoke"
72 | },
73 | {
74 | "subject":"Caterpie",
75 | "object":"Metapod"
76 | },
77 | {
78 | "subject":"Weedle",
79 | "object":"Kakuna"
80 | },
81 | {
82 | "subject":"Pidgey",
83 | "object":"Pidgeotto"
84 | },
85 | {
86 | "subject":"Rattata",
87 | "object":"Raticate"
88 | },
89 | {
90 | "subject":"Spearow",
91 | "object":"Fearow"
92 | },
93 | {
94 | "subject":"Ekans",
95 | "object":"Arbok"
96 | },
97 | {
98 | "subject":"Sandshrew",
99 | "object":"Sandslash"
100 | },
101 | {
102 | "subject":"Nidoran♀",
103 | "object":"Nidorina"
104 | },
105 | {
106 | "subject":"Nidoran♂",
107 | "object":"Nidorino"
108 | },
109 | {
110 | "subject":"Zubat",
111 | "object":"Golbat"
112 | },
113 | {
114 | "subject":"Bellsprout",
115 | "object":"Weepinbell"
116 | },
117 | {
118 | "subject":"Tentacool",
119 | "object":"Tentacruel"
120 | },
121 | {
122 | "subject":"Geodude",
123 | "object":"Graveler"
124 | },
125 | {
126 | "subject":"Ponyta",
127 | "object":"Rapidash"
128 | },
129 | {
130 | "subject":"Slowpoke",
131 | "object":"Slowbro"
132 | },
133 | {
134 | "subject":"Magnemite",
135 | "object":"Magneton"
136 | },
137 | {
138 | "subject":"Doduo",
139 | "object":"Dodrio"
140 | },
141 | {
142 | "subject":"Seel",
143 | "object":"Dewgong"
144 | },
145 | {
146 | "subject":"Grimer",
147 | "object":"Muk"
148 | },
149 | {
150 | "subject":"Shellder",
151 | "object":"Cloyster"
152 | },
153 | {
154 | "subject":"Gastly",
155 | "object":"Haunter"
156 | },
157 | {
158 | "subject":"Drowzee",
159 | "object":"Hypno"
160 | },
161 | {
162 | "subject":"Krabby",
163 | "object":"Kingler"
164 | },
165 | {
166 | "subject":"Voltorb",
167 | "object":"Electrode"
168 | },
169 | {
170 | "subject":"Exeggcute",
171 | "object":"Exeggutor"
172 | },
173 | {
174 | "subject":"Cubone",
175 | "object":"Marowak"
176 | },
177 | {
178 | "subject":"Koffing",
179 | "object":"Weezing"
180 | },
181 | {
182 | "subject":"Rhyhorn",
183 | "object":"Rhydon"
184 | },
185 | {
186 | "subject":"Horsea",
187 | "object":"Seadra"
188 | },
189 | {
190 | "subject":"Goldeen",
191 | "object":"Seaking"
192 | }
193 | ]
194 | }
--------------------------------------------------------------------------------
/data/factual/presidents_birth_year.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "president birth year",
3 | "prompt_templates": [
4 | "{} was born in"
5 | ],
6 | "prompt_templates_zs": [
7 | "{} was born in"
8 | ],
9 | "properties": {
10 | "relation_type": "factual",
11 | "domain_name": "person",
12 | "range_name": "year",
13 | "symmetric": false
14 | },
15 | "samples": [
16 | {
17 | "subject": "John Adams",
18 | "object": "1735"
19 | },
20 | {
21 | "subject": "Thomas Jefferson",
22 | "object": "1743"
23 | },
24 | {
25 | "subject": "James Madison",
26 | "object": "1751"
27 | },
28 | {
29 | "subject": "James Monroe",
30 | "object": "1758"
31 | },
32 | {
33 | "subject": "John Quincy Adams",
34 | "object": "1767"
35 | },
36 | {
37 | "subject": "Andrew Jackson",
38 | "object": "1767"
39 | },
40 | {
41 | "subject": "Martin Van Buren",
42 | "object": "1782"
43 | },
44 | {
45 | "subject": "William Henry Harrison",
46 | "object": "1773"
47 | },
48 | {
49 | "subject": "Richard Nixon",
50 | "object": "1913"
51 | },
52 | {
53 | "subject": "John F. Kennedy",
54 | "object": "1917"
55 | },
56 | {
57 | "subject": "Richard Nixon",
58 | "object": "1913"
59 | },
60 | {
61 | "subject": "Bill Clinton",
62 | "object": "1946"
63 | },
64 | {
65 | "subject": "George W. Bush",
66 | "object": "1946"
67 | },
68 | {
69 | "subject": "Barack Obama",
70 | "object": "1961"
71 | },
72 | {
73 |             "subject": "Jimmy Carter",
74 | "object": "1924"
75 | },
76 | {
77 | "subject": "Ronald Reagan",
78 | "object": "1911"
79 | },
80 | {
81 | "subject": "George H. W. Bush",
82 | "object": "1924"
83 | },
84 | {
85 | "subject": "Joe Biden",
86 | "object": "1942"
87 | },
88 | {
89 | "subject": "Franklin D. Roosevelt",
90 | "object": "1882"
91 | }
92 | ]
93 | }
--------------------------------------------------------------------------------
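Each relation file above follows the same schema: a `name`, one or more `prompt_templates` (plus zero-shot variants), a `properties` block, and a list of `samples` with `subject`/`object` pairs. A minimal sketch of turning one of these files into prompts, assuming the repository-relative path shown above:

    import json

    # Minimal sketch: load a relation file and fill its prompt template with each
    # subject (the path is assumed to match the repository layout shown above).
    with open("data/factual/presidents_birth_year.json") as f:
        relation = json.load(f)

    template = relation["prompt_templates"][0]        # "{} was born in"
    for sample in relation["samples"][:3]:
        prompt = template.format(sample["subject"])   # e.g. "John Adams was born in"
        target = sample["object"]                     # e.g. "1735"
        print(prompt, "->", target)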
/data/factual/presidents_election_year.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "president election year",
3 | "prompt_templates": [
4 | "{} was elected in"
5 | ],
6 | "prompt_templates_zs": [
7 | "{} was elected in"
8 | ],
9 | "properties": {
10 | "relation_type": "factual",
11 | "domain_name": "person",
12 | "range_name": "year",
13 | "symmetric": false
14 | },
15 | "samples": [
16 | {
17 | "subject": "John Adams",
18 | "object": "1796"
19 | },
20 | {
21 | "subject": "Thomas Jefferson",
22 | "object": "1800"
23 | },
24 | {
25 | "subject": "James Madison",
26 | "object": "1808"
27 | },
28 | {
29 | "subject": "James Monroe",
30 | "object": "1816"
31 | },
32 | {
33 | "subject": "John Quincy Adams",
34 | "object": "1824"
35 | },
36 | {
37 | "subject": "Andrew Jackson",
38 | "object": "1828"
39 | },
40 | {
41 | "subject": "Martin Van Buren",
42 | "object": "1836"
43 | },
44 | {
45 | "subject": "William Henry Harrison",
46 | "object": "1840"
47 | },
48 | {
49 | "subject": "Richard M. Nixon",
50 | "object": "1968"
51 | },
52 | {
53 | "subject": "John F. Kennedy",
54 | "object": "1960"
55 | },
56 | {
57 |             "subject": "Lyndon B. Johnson",
58 |             "object": "1964"
59 | },
60 | {
61 | "subject": "Bill Clinton",
62 | "object": "1992"
63 | },
64 | {
65 | "subject": "George W. Bush",
66 | "object": "2000"
67 | },
68 | {
69 | "subject": "Barack Obama",
70 | "object": "2008"
71 | },
72 | {
73 |             "subject": "Jimmy Carter",
74 | "object": "1976"
75 | },
76 | {
77 | "subject": "Ronald Reagan",
78 | "object": "1980"
79 | },
80 | {
81 | "subject": "George H. W. Bush",
82 | "object": "1988"
83 | },
84 | {
85 | "subject": "Joe Biden",
86 | "object": "2020"
87 | },
88 | {
89 | "subject": "Franklin D. Roosevelt",
90 | "object": "1932"
91 | }
92 | ]
93 | }
--------------------------------------------------------------------------------
/data/linguistic/adj_comparative.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "adjective comparative",
3 | "prompt_templates": [
4 | "The comparative form of {} is"
5 | ],
6 | "prompt_templates_zs": [
7 | "The comparative form of {} is",
8 | "What is the comparative form of {}? It is"
9 | ],
10 | "properties": {
11 | "relation_type": "linguistic",
12 | "domain_name": "adjective",
13 | "range_name": "adjective",
14 | "symmetric": false
15 | },
16 | "samples": [
17 | {
18 | "subject": "big",
19 | "object": "bigger"
20 | },
21 | {
22 | "subject": "small",
23 | "object": "smaller"
24 | },
25 | {
26 | "subject": "tall",
27 | "object": "taller"
28 | },
29 | {
30 | "subject": "short",
31 | "object": "shorter"
32 | },
33 | {
34 | "subject": "long",
35 | "object": "longer"
36 | },
37 | {
38 | "subject": "fast",
39 | "object": "faster"
40 | },
41 | {
42 | "subject": "slow",
43 | "object": "slower"
44 | },
45 | {
46 | "subject": "strong",
47 | "object": "stronger"
48 | },
49 | {
50 | "subject": "weak",
51 | "object": "weaker"
52 | },
53 | {
54 | "subject": "heavy",
55 | "object": "heavier"
56 | },
57 | {
58 | "subject": "light",
59 | "object": "lighter"
60 | },
61 | {
62 | "subject": "old",
63 | "object": "older"
64 | },
65 | {
66 | "subject": "young",
67 | "object": "younger"
68 | },
69 | {
70 | "subject": "high",
71 | "object": "higher"
72 | },
73 | {
74 | "subject": "low",
75 | "object": "lower"
76 | },
77 | {
78 | "subject": "deep",
79 | "object": "deeper"
80 | },
81 | {
82 | "subject": "shallow",
83 | "object": "shallower"
84 | },
85 | {
86 | "subject": "wide",
87 | "object": "wider"
88 | },
89 | {
90 | "subject": "narrow",
91 | "object": "narrower"
92 | },
93 | {
94 | "subject": "thick",
95 | "object": "thicker"
96 | },
97 | {
98 | "subject": "thin",
99 | "object": "thinner"
100 | },
101 | {
102 | "subject": "hot",
103 | "object": "hotter"
104 | },
105 | {
106 | "subject": "cold",
107 | "object": "colder"
108 | },
109 | {
110 | "subject": "bright",
111 | "object": "brighter"
112 | },
113 | {
114 | "subject": "dark",
115 | "object": "darker"
116 | },
117 | {
118 | "subject": "loud",
119 | "object": "louder"
120 | },
121 | {
122 | "subject": "quiet",
123 | "object": "quieter"
124 | },
125 | {
126 | "subject": "happy",
127 | "object": "happier"
128 | },
129 | {
130 | "subject": "sad",
131 | "object": "sadder"
132 | },
133 | {
134 | "subject": "good",
135 | "object": "better"
136 | },
137 | {
138 | "subject": "bad",
139 | "object": "worse"
140 | },
141 | {
142 | "subject": "ugly",
143 | "object": "uglier"
144 | },
145 | {
146 | "subject": "clean",
147 | "object": "cleaner"
148 | },
149 | {
150 | "subject": "dirty",
151 | "object": "dirtier"
152 | },
153 | {
154 | "subject": "easy",
155 | "object": "easier"
156 | },
157 | {
158 | "subject": "simple",
159 | "object": "simpler"
160 | },
161 | {
162 | "subject": "kind",
163 | "object": "kinder"
164 | },
165 | {
166 | "subject": "mean",
167 | "object": "meaner"
168 | },
169 | {
170 | "subject": "brave",
171 | "object": "braver"
172 | },
173 | {
174 | "subject": "smart",
175 | "object": "smarter"
176 | },
177 | {
178 | "subject": "stupid",
179 | "object": "stupider"
180 | },
181 | {
182 | "subject": "wise",
183 | "object": "wiser"
184 | },
185 | {
186 | "subject": "happy",
187 | "object": "happier"
188 | },
189 | {
190 | "subject": "rich",
191 | "object": "richer"
192 | },
193 | {
194 | "subject": "poor",
195 | "object": "poorer"
196 | },
197 | {
198 | "subject": "hot",
199 | "object": "hotter"
200 | },
201 | {
202 | "subject": "cold",
203 | "object": "colder"
204 | },
205 | {
206 | "subject": "sweet",
207 | "object": "sweeter"
208 | },
209 | {
210 | "subject": "sour",
211 | "object": "sourer"
212 | },
213 | {
214 | "subject": "fresh",
215 | "object": "fresher"
216 | },
217 | {
218 | "subject": "stale",
219 | "object": "staler"
220 | },
221 | {
222 | "subject": "cheap",
223 | "object": "cheaper"
224 | },
225 | {
226 | "subject": "safe",
227 | "object": "safer"
228 | },
229 | {
230 | "subject": "bright",
231 | "object": "brighter"
232 | },
233 | {
234 | "subject": "dull",
235 | "object": "duller"
236 | },
237 | {
238 | "subject": "happy",
239 | "object": "happier"
240 | },
241 | {
242 | "subject": "sad",
243 | "object": "sadder"
244 | },
245 | {
246 | "subject": "true",
247 | "object": "truer"
248 | },
249 | {
250 | "subject": "false",
251 | "object": "falser"
252 | },
253 | {
254 | "subject": "rude",
255 | "object": "ruder"
256 | },
257 | {
258 | "subject": "calm",
259 | "object": "calmer"
260 | },
261 | {
262 | "subject": "strong",
263 | "object": "stronger"
264 | },
265 | {
266 | "subject": "weak",
267 | "object": "weaker"
268 | },
269 | {
270 | "subject": "healthy",
271 | "object": "healthier"
272 | },
273 | {
274 | "subject": "sick",
275 | "object": "sicker"
276 | },
277 | {
278 | "subject": "clean",
279 | "object": "cleaner"
280 | },
281 | {
282 | "subject": "dirty",
283 | "object": "dirtier"
284 | },
285 | {
286 | "subject": "fresh",
287 | "object": "fresher"
288 | }
289 | ]
290 | }
--------------------------------------------------------------------------------
/data/linguistic/adj_superlative.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "adjective superlative",
3 | "prompt_templates": [
4 | "The superlative form of {} is"
5 | ],
6 | "prompt_templates_zs": [
7 | "The superlative form of {} is",
8 | "What is the superlative form of {}? It is"
9 | ],
10 | "properties": {
11 | "relation_type": "linguistic",
12 | "domain_name": "adjective",
13 | "range_name": "adjective",
14 | "symmetric": false
15 | },
16 | "samples": [
17 | {
18 | "subject": "angry",
19 | "object": "angriest"
20 | },
21 | {
22 | "subject": "bad",
23 | "object": "worst"
24 | },
25 | {
26 | "subject": "big",
27 | "object": "biggest"
28 | },
29 | {
30 | "subject": "brave",
31 | "object": "bravest"
32 | },
33 | {
34 | "subject": "bright",
35 | "object": "brightest"
36 | },
37 | {
38 | "subject": "calm",
39 | "object": "calmest"
40 | },
41 | {
42 | "subject": "cheap",
43 | "object": "cheapest"
44 | },
45 | {
46 | "subject": "clean",
47 | "object": "cleanest"
48 | },
49 | {
50 | "subject": "cold",
51 | "object": "coldest"
52 | },
53 | {
54 | "subject": "crazy",
55 | "object": "craziest"
56 | },
57 | {
58 | "subject": "cruel",
59 | "object": "cruelest"
60 | },
61 | {
62 | "subject": "dark",
63 | "object": "darkest"
64 | },
65 | {
66 | "subject": "deep",
67 | "object": "deepest"
68 | },
69 | {
70 | "subject": "dirty",
71 | "object": "dirtiest"
72 | },
73 | {
74 | "subject": "dry",
75 | "object": "driest"
76 | },
77 | {
78 | "subject": "dull",
79 | "object": "dullest"
80 | },
81 | {
82 | "subject": "easy",
83 | "object": "easiest"
84 | },
85 | {
86 | "subject": "fast",
87 | "object": "fastest"
88 | },
89 | {
90 | "subject": "fierce",
91 | "object": "fiercest"
92 | },
93 | {
94 | "subject": "fresh",
95 | "object": "freshest"
96 | },
97 | {
98 | "subject": "friendly",
99 | "object": "friendliest"
100 | },
101 | {
102 | "subject": "full",
103 | "object": "fullest"
104 | },
105 | {
106 | "subject": "funny",
107 | "object": "funniest"
108 | },
109 | {
110 | "subject": "gentle",
111 | "object": "gentlest"
112 | },
113 | {
114 | "subject": "good",
115 | "object": "best"
116 | },
117 | {
118 | "subject": "great",
119 | "object": "greatest"
120 | },
121 | {
122 | "subject": "happy",
123 | "object": "happiest"
124 | },
125 | {
126 | "subject": "hard",
127 | "object": "hardest"
128 | },
129 | {
130 | "subject": "healthy",
131 | "object": "healthiest"
132 | },
133 | {
134 | "subject": "heavy",
135 | "object": "heaviest"
136 | },
137 | {
138 | "subject": "high",
139 | "object": "highest"
140 | },
141 | {
142 | "subject": "hot",
143 | "object": "hottest"
144 | },
145 | {
146 | "subject": "hungry",
147 | "object": "hungriest"
148 | },
149 | {
150 | "subject": "kind",
151 | "object": "kindest"
152 | },
153 | {
154 | "subject": "large",
155 | "object": "largest"
156 | },
157 | {
158 | "subject": "lazy",
159 | "object": "laziest"
160 | },
161 | {
162 | "subject": "light",
163 | "object": "lightest"
164 | },
165 | {
166 | "subject": "little",
167 |             "object": "littlest"
168 | },
169 | {
170 | "subject": "long",
171 | "object": "longest"
172 | },
173 | {
174 | "subject": "loud",
175 | "object": "loudest"
176 | },
177 | {
178 | "subject": "low",
179 | "object": "lowest"
180 | },
181 | {
182 | "subject": "mad",
183 | "object": "maddest"
184 | },
185 | {
186 | "subject": "mean",
187 | "object": "meanest"
188 | },
189 | {
190 | "subject": "messy",
191 | "object": "messiest"
192 | },
193 | {
194 | "subject": "narrow",
195 | "object": "narrowest"
196 | },
197 | {
198 | "subject": "near",
199 | "object": "nearest"
200 | },
201 | {
202 | "subject": "new",
203 | "object": "newest"
204 | },
205 | {
206 | "subject": "old",
207 | "object": "oldest"
208 | },
209 | {
210 | "subject": "polite",
211 | "object": "politest"
212 | },
213 | {
214 | "subject": "poor",
215 | "object": "poorest"
216 | },
217 | {
218 | "subject": "quick",
219 | "object": "quickest"
220 | },
221 | {
222 | "subject": "quiet",
223 | "object": "quietest"
224 | },
225 | {
226 | "subject": "rich",
227 | "object": "richest"
228 | },
229 | {
230 | "subject": "sad",
231 | "object": "saddest"
232 | },
233 | {
234 | "subject": "safe",
235 | "object": "safest"
236 | },
237 | {
238 | "subject": "short",
239 | "object": "shortest"
240 | },
241 | {
242 | "subject": "shy",
243 | "object": "shyest"
244 | },
245 | {
246 | "subject": "simple",
247 | "object": "simplest"
248 | },
249 | {
250 | "subject": "slow",
251 | "object": "slowest"
252 | },
253 | {
254 | "subject": "small",
255 | "object": "smallest"
256 | },
257 | {
258 | "subject": "smart",
259 | "object": "smartest"
260 | },
261 | {
262 | "subject": "smooth",
263 | "object": "smoothest"
264 | },
265 | {
266 | "subject": "strong",
267 | "object": "strongest"
268 | },
269 | {
270 | "subject": "sweet",
271 | "object": "sweetest"
272 | },
273 | {
274 | "subject": "tall",
275 | "object": "tallest"
276 | },
277 | {
278 | "subject": "thick",
279 | "object": "thickest"
280 | },
281 | {
282 | "subject": "thin",
283 | "object": "thinnest"
284 | },
285 | {
286 | "subject": "tiny",
287 | "object": "tiniest"
288 | },
289 | {
290 | "subject": "tough",
291 | "object": "toughest"
292 | },
293 | {
294 | "subject": "true",
295 | "object": "truest"
296 | },
297 | {
298 | "subject": "ugly",
299 | "object": "ugliest"
300 | },
301 | {
302 | "subject": "weak",
303 | "object": "weakest"
304 | },
305 | {
306 | "subject": "wet",
307 | "object": "wettest"
308 | },
309 | {
310 | "subject": "wide",
311 | "object": "widest"
312 | },
313 | {
314 | "subject": "wild",
315 | "object": "wildest"
316 | },
317 | {
318 | "subject": "wise",
319 | "object": "wisest"
320 | },
321 | {
322 | "subject": "witty",
323 | "object": "wittiest"
324 | },
325 | {
326 | "subject": "wrong",
327 | "object": "wrongest"
328 | },
329 | {
330 | "subject": "young",
331 | "object": "youngest"
332 | },
333 | {
334 | "subject": "zesty",
335 | "object": "zestiest"
336 | }
337 | ]
338 | }
--------------------------------------------------------------------------------
/data/linguistic/verb_past_tense.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "verb past tense",
3 | "prompt_templates": [
4 | "The past tense of {} is"
5 | ],
6 | "prompt_templates_zs": [
7 | "The past tense of {} is",
8 | "What is the past tense of {}? It is"
9 | ],
10 | "properties": {
11 | "relation_type": "linguistic",
12 | "domain_name": "verb",
13 | "range_name": "verb",
14 | "symmetric": false
15 | },
16 | "samples": [
17 | {
18 | "subject": "ask",
19 | "object": "asked"
20 | },
21 | {
22 | "subject": "believe",
23 | "object": "believed"
24 | },
25 | {
26 | "subject": "break",
27 | "object": "broke"
28 | },
29 | {
30 | "subject": "bring",
31 | "object": "brought"
32 | },
33 | {
34 | "subject": "build",
35 | "object": "built"
36 | },
37 | {
38 | "subject": "call",
39 | "object": "called"
40 | },
41 | {
42 | "subject": "catch",
43 | "object": "caught"
44 | },
45 | {
46 | "subject": "change",
47 | "object": "changed"
48 | },
49 | {
50 | "subject": "choose",
51 | "object": "chose"
52 | },
53 | {
54 | "subject": "clean",
55 | "object": "cleaned"
56 | },
57 | {
58 | "subject": "climb",
59 | "object": "climbed"
60 | },
61 | {
62 | "subject": "close",
63 | "object": "closed"
64 | },
65 | {
66 | "subject": "come",
67 | "object": "came"
68 | },
69 | {
70 | "subject": "cry",
71 | "object": "cried"
72 | },
73 | {
74 | "subject": "cut",
75 | "object": "cut"
76 | },
77 | {
78 | "subject": "dance",
79 | "object": "danced"
80 | },
81 | {
82 | "subject": "decide",
83 | "object": "decided"
84 | },
85 | {
86 | "subject": "do",
87 | "object": "did"
88 | },
89 | {
90 | "subject": "drive",
91 | "object": "drove"
92 | },
93 | {
94 | "subject": "drink",
95 | "object": "drank"
96 | },
97 | {
98 | "subject": "eat",
99 | "object": "ate"
100 | },
101 | {
102 | "subject": "fall",
103 | "object": "fell"
104 | },
105 | {
106 | "subject": "finish",
107 | "object": "finished"
108 | },
109 | {
110 | "subject": "find",
111 | "object": "found"
112 | },
113 | {
114 | "subject": "fly",
115 | "object": "flew"
116 | },
117 | {
118 | "subject": "follow",
119 | "object": "followed"
120 | },
121 | {
122 | "subject": "forget",
123 | "object": "forgot"
124 | },
125 | {
126 | "subject": "frown",
127 | "object": "frowned"
128 | },
129 | {
130 | "subject": "get",
131 | "object": "got"
132 | },
133 | {
134 | "subject": "give",
135 | "object": "gave"
136 | },
137 | {
138 | "subject": "go",
139 | "object": "went"
140 | },
141 | {
142 | "subject": "hate",
143 | "object": "hated"
144 | },
145 | {
146 | "subject": "have",
147 | "object": "had"
148 | },
149 | {
150 | "subject": "hear",
151 | "object": "heard"
152 | },
153 | {
154 | "subject": "help",
155 | "object": "helped"
156 | },
157 | {
158 | "subject": "hit",
159 | "object": "hit"
160 | },
161 | {
162 | "subject": "jump",
163 | "object": "jumped"
164 | },
165 | {
166 | "subject": "know",
167 | "object": "knew"
168 | },
169 | {
170 | "subject": "laugh",
171 | "object": "laughed"
172 | },
173 | {
174 | "subject": "learn",
175 | "object": "learned"
176 | },
177 | {
178 | "subject": "leave",
179 | "object": "left"
180 | },
181 | {
182 | "subject": "like",
183 | "object": "liked"
184 | },
185 | {
186 | "subject": "live",
187 | "object": "lived"
188 | },
189 | {
190 | "subject": "look",
191 | "object": "looked"
192 | },
193 | {
194 | "subject": "lose",
195 | "object": "lost"
196 | },
197 | {
198 | "subject": "love",
199 | "object": "loved"
200 | },
201 | {
202 | "subject": "make",
203 | "object": "made"
204 | },
205 | {
206 | "subject": "meet",
207 | "object": "met"
208 | },
209 | {
210 | "subject": "need",
211 | "object": "needed"
212 | },
213 | {
214 | "subject": "open",
215 | "object": "opened"
216 | },
217 | {
218 | "subject": "play",
219 | "object": "played"
220 | },
221 | {
222 | "subject": "read",
223 | "object": "read"
224 | },
225 | {
226 | "subject": "remember",
227 | "object": "remembered"
228 | },
229 | {
230 | "subject": "run",
231 | "object": "ran"
232 | },
233 | {
234 | "subject": "say",
235 | "object": "said"
236 | },
237 | {
238 | "subject": "see",
239 | "object": "saw"
240 | },
241 | {
242 | "subject": "shake",
243 | "object": "shook"
244 | },
245 | {
246 | "subject": "sit",
247 | "object": "sat"
248 | },
249 | {
250 | "subject": "sleep",
251 | "object": "slept"
252 | },
253 | {
254 | "subject": "smile",
255 | "object": "smiled"
256 | },
257 | {
258 | "subject": "speak",
259 | "object": "spoke"
260 | },
261 | {
262 | "subject": "start",
263 | "object": "started"
264 | },
265 | {
266 | "subject": "stand",
267 | "object": "stood"
268 | },
269 | {
270 | "subject": "study",
271 | "object": "studied"
272 | },
273 | {
274 | "subject": "swim",
275 | "object": "swam"
276 | },
277 | {
278 | "subject": "take",
279 | "object": "took"
280 | },
281 | {
282 | "subject": "talk",
283 | "object": "talked"
284 | },
285 | {
286 | "subject": "think",
287 | "object": "thought"
288 | },
289 | {
290 | "subject": "try",
291 | "object": "tried"
292 | },
293 | {
294 | "subject": "understand",
295 | "object": "understood"
296 | },
297 | {
298 | "subject": "walk",
299 | "object": "walked"
300 | },
301 | {
302 | "subject": "want",
303 | "object": "wanted"
304 | },
305 | {
306 | "subject": "wear",
307 | "object": "wore"
308 | },
309 | {
310 | "subject": "win",
311 | "object": "won"
312 | },
313 | {
314 | "subject": "work",
315 | "object": "worked"
316 | },
317 | {
318 | "subject": "write",
319 | "object": "wrote"
320 | }
321 | ]
322 | }
--------------------------------------------------------------------------------
/eap/dataset.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Optional
3 |
4 | import pandas as pd
5 | import numpy as np
6 | import torch
7 | from torch.utils.data import Dataset, DataLoader
8 |
9 | def collate_EAP(xs, task):
10 | clean, corrupted, labels = zip(*xs)
11 | clean = list(clean)
12 | corrupted = list(corrupted)
13 | if 'hypernymy' not in task:
14 | labels = torch.tensor(labels)
15 | return clean, corrupted, labels
16 |
17 | class EAPDataset(Dataset):
18 | def __init__(self, task:str, filename:Optional[str]=None):
19 | self.df = pd.read_csv(filename)
20 | self.task = task
21 |
22 | def __len__(self):
23 | return len(self.df)
24 |
25 | def shuffle(self):
26 | self.df = self.df.sample(frac=1)
27 |
28 | def head(self, n: int):
29 | self.df = self.df.head(n)
30 |
31 | def __getitem__(self, index):
32 | row = self.df.iloc[index]
33 | label = None
34 | if self.task == 'ioi':
35 | label = [row['correct_idx'], row['incorrect_idx']]
36 | elif 'greater-than' in self.task:
37 | label = row['correct_idx']
38 | elif 'hypernymy' in self.task:
39 |             answer = torch.tensor(eval(row['answers_idx']))  # column stores a stringified list of answer token ids
40 | corrupted_answer = torch.tensor(eval(row['corrupted_answers_idx']))
41 | label = [answer, corrupted_answer]
42 | elif 'fact-retrieval' in self.task:
43 | label = [row['country_idx'], row['corrupted_country_idx']]
44 | elif 'gender' in self.task:
45 | label = [row['clean_answer_idx'], row['corrupted_answer_idx']]
46 | elif self.task == 'sva':
47 | label = row['plural']
48 | elif self.task == 'colored-objects':
49 | label = [row['correct_idx'], row['incorrect_idx']]
50 | elif self.task in {'dummy-easy', 'dummy-medium', 'dummy-hard'}:
51 | label = 0
52 | else:
53 | raise ValueError(f'Got invalid task: {self.task}')
54 | return row['clean'], row['corrupted'], label
55 |
56 | def to_dataloader(self, batch_size: int):
57 | return DataLoader(self, batch_size=batch_size, collate_fn=partial(collate_EAP, task=self.task))
--------------------------------------------------------------------------------
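`EAPDataset` reads a CSV with `clean` and `corrupted` columns plus task-specific label columns, and `collate_EAP` batches the labels into a tensor for every task except hypernymy. A minimal usage sketch, assuming a fact-retrieval CSV with `country_idx` and `corrupted_country_idx` columns (as built in the notebook later in this repository):

    from eap.dataset import EAPDataset

    # Minimal sketch (the CSV name and its columns are assumptions):
    # columns expected here: clean, corrupted, country_idx, corrupted_country_idx.
    ds = EAPDataset(task='fact-retrieval', filename='capital_city.csv')
    ds.shuffle()
    dataloader = ds.to_dataloader(batch_size=8)

    for clean, corrupted, labels in dataloader:
        # clean/corrupted are lists of strings; labels is a [batch, 2] tensor of
        # (clean answer idx, corrupted answer idx).
        break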
/eap/metrics.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Union, Literal, Tuple
2 | from functools import partial
3 |
4 | import pandas as pd
5 | import torch
6 | from torch.nn.functional import kl_div
7 | from transformers import PreTrainedTokenizer
8 | from transformer_lens import HookedTransformer
9 |
10 | def get_metric(metric_name: str, task: str, tokenizer:Optional[PreTrainedTokenizer]=None, model: Optional[HookedTransformer]=None):
11 | if metric_name == 'kl_divergence' or metric_name == 'kl':
12 | return partial(divergence, divergence_type='kl')
13 | elif metric_name == 'js_divergence' or metric_name == 'js':
14 | return partial(divergence, divergence_type='js')
15 | elif metric_name == 'logit_diff' or metric_name == 'prob_diff':
16 | prob = (metric_name == 'prob_diff')
17 | if 'greater-than' in task:
18 | if tokenizer is None:
19 | if model is None:
20 | raise ValueError("Either tokenizer or model must be set for greater-than and prob / logit diff")
21 | else:
22 | tokenizer = model.tokenizer
23 | logit_diff_fn = get_logit_diff_greater_than(tokenizer)
24 | elif 'hypernymy' in task:
25 | logit_diff_fn = logit_diff_hypernymy
26 | elif task == 'sva':
27 | if model is None:
28 | raise ValueError("model must be set for sva and prob / logit diff")
29 | logit_diff_fn = get_logit_diff_sva(model)
30 | else:
31 | logit_diff_fn = logit_diff
32 | return partial(logit_diff_fn, prob=prob)
33 | else:
34 | raise ValueError(f"got bad metric_name: {metric_name}")
35 |
36 | def get_logit_positions(logits: torch.Tensor, input_length: torch.Tensor):
37 | batch_size = logits.size(0)
38 | idx = torch.arange(batch_size, device=logits.device)
39 |
40 | logits = logits[idx, input_length - 1]
41 | return logits
42 |
43 | def js_div(p: torch.Tensor, q: torch.Tensor):
44 | p, q = p.view(-1, p.size(-1)), q.view(-1, q.size(-1))
45 | m = (0.5 * (p + q)).log()
46 | return 0.5 * (kl_div(m, p.log(), log_target=True, reduction='none').mean(-1) + kl_div(m, q.log(), log_target=True, reduction='none').mean(-1))
47 |
48 | def divergence(logits: torch.Tensor, clean_logits: torch.Tensor, input_length: torch.Tensor, labels: torch.Tensor, divergence_type: Union[Literal['kl'], Literal['js']]='kl', mean=True, loss=True):
49 | logits = get_logit_positions(logits, input_length)
50 | clean_logits = get_logit_positions(clean_logits, input_length)
51 |
52 | probs = torch.softmax(logits, dim=-1)
53 | clean_probs = torch.softmax(clean_logits, dim=-1)
54 |
55 | if divergence_type == 'kl':
56 | results = kl_div(probs.log(), clean_probs.log(), log_target=True, reduction='none').mean(-1)
57 | elif divergence_type == 'js':
58 | results = js_div(probs, clean_probs)
59 | else:
60 | raise ValueError(f"Expected divergence_type of 'kl' or 'js', but got '{divergence_type}'")
61 | return results.mean() if mean else results
62 |
63 | def logit_diff(clean_logits: torch.Tensor, corrupted_logits: torch.Tensor, input_length: torch.Tensor, labels: torch.Tensor, mean=True, prob=False, loss=False):
64 | clean_logits = get_logit_positions(clean_logits, input_length)
65 | cleans = torch.softmax(clean_logits, dim=-1) if prob else clean_logits
66 | good_bad = torch.gather(cleans, -1, labels.to(cleans.device))
67 | results = good_bad[:, 0] - good_bad[:, 1]
68 |
69 | if loss:
70 | # remember it's reversed to make it a loss
71 | results = -results
72 | if mean:
73 | results = results.mean()
74 | return results
75 |
76 | def direct_logit(clean_logits: torch.Tensor, corrupted_logits: torch.Tensor, input_length: torch.Tensor, labels: torch.Tensor, mean=True, prob=False, loss=False):
77 | clean_logits = get_logit_positions(clean_logits, input_length)
78 | cleans = torch.softmax(clean_logits, dim=-1) if prob else clean_logits
79 | good_bad = torch.gather(cleans, -1, labels.to(cleans.device))
80 | results = good_bad[:, 0]
81 |
82 | if loss:
83 | # remember it's reversed to make it a loss
84 | results = -results
85 | if mean:
86 | results = results.mean()
87 | return results
88 |
89 | def logit_diff_hypernymy(clean_logits: torch.Tensor, corrupted_logits: torch.Tensor, input_length: torch.Tensor, labels: List[torch.Tensor], mean=True, prob=False, loss=False):
90 | clean_logits = get_logit_positions(clean_logits, input_length)
91 | cleans = torch.softmax(clean_logits, dim=-1) if prob else clean_logits
92 |
93 | results = []
94 | for i, (ls,corrupted_ls) in enumerate(labels):
95 | r = cleans[i][ls.to(cleans.device)].sum() - cleans[i][corrupted_ls.to(cleans.device)].sum()
96 | results.append(r)
97 | results = torch.stack(results)
98 |
99 | if loss:
100 | # remember it's reversed to make it a loss
101 | results = -results
102 | if mean:
103 | results = results.mean()
104 | return results
105 |
106 | def get_year_indices(tokenizer: PreTrainedTokenizer):
107 | return torch.tensor([tokenizer(f'{year:02d}').input_ids[0] for year in range(100)])
108 |
109 |
110 | def get_logit_diff_greater_than(tokenizer: PreTrainedTokenizer):
111 | year_indices = get_year_indices(tokenizer)
112 | def logit_diff_greater_than(clean_logits: torch.Tensor, corrupted_logits: torch.Tensor, input_length: torch.Tensor, labels: torch.Tensor, mean=True, prob=False, loss=False):
113 | # Prob diff (negative, since it's a loss)
114 | clean_logits = get_logit_positions(clean_logits, input_length)
115 | cleans = torch.softmax(clean_logits, dim=-1) if prob else clean_logits
116 | cleans = cleans[:, year_indices]
117 |
118 | results = []
119 | if prob:
120 | for prob, year in zip(cleans, labels):
121 | results.append(prob[year + 1 :].sum() - prob[: year + 1].sum())
122 | else:
123 | for logit, year in zip(cleans, labels):
124 | results.append(logit[year + 1 :].mean() - logit[: year + 1].mean())
125 |
126 | results = torch.stack(results)
127 | if loss:
128 | results = -results
129 | if mean:
130 | results = results.mean()
131 | return results
132 | return logit_diff_greater_than
133 |
134 | def get_singular_and_plural(model, strict=False) -> Tuple[torch.Tensor, torch.Tensor]:
135 | tokenizer = model.tokenizer
136 | tokenizer_length = model.cfg.d_vocab_out
137 |
138 | df: pd.DataFrame = pd.read_csv('data/sva/combined_verb_list.csv')
139 | singular = df['sing'].to_list()
140 | plural = df['plur'].to_list()
141 | singular_set = set(singular)
142 | plural_set = set(plural)
143 | verb_set = singular_set | plural_set
144 | assert len(singular_set & plural_set) == 0, f"{singular_set & plural_set}"
145 | singular_indices, plural_indices = [], []
146 |
147 | for i in range(tokenizer_length):
148 | token = tokenizer._convert_id_to_token(i)
149 | if token is not None:
150 | if token[0] == 'Ġ':
151 | token = token[1:]
152 | if token in verb_set:
153 | if token in singular_set:
154 | singular_indices.append(i)
155 | else: # token in plural_set:
156 | idx = plural.index(token)
157 | third_person_present = singular[idx]
158 | third_person_present_tokenized = tokenizer(f' {third_person_present}', add_special_tokens=False)['input_ids']
159 | if len(third_person_present_tokenized) == 1 and third_person_present_tokenized[0] != tokenizer.unk_token_id:
160 | plural_indices.append(i)
161 | elif not strict:
162 | plural_indices.append(i)
163 |
164 | return torch.tensor(singular_indices, device=model.cfg.device), torch.tensor(plural_indices, device=model.cfg.device)
165 |
166 | def get_logit_diff_sva(model, strict=True):
167 | singular_indices, plural_indices = get_singular_and_plural(model, strict=strict)
168 | def sva_logit_diff(clean_logits: torch.Tensor, corrupted_logits: torch.Tensor, input_length: torch.Tensor, labels: torch.Tensor, mean=True, prob=False, loss=False):
169 | clean_logits = get_logit_positions(clean_logits, input_length)
170 | cleans = torch.softmax(clean_logits, dim=-1) if prob else clean_logits
171 |
172 | if prob:
173 | singular = cleans[:, singular_indices].sum(-1)
174 | plural = cleans[:, plural_indices].sum(-1)
175 | else:
176 | singular = cleans[:, singular_indices].mean(-1)
177 | plural = cleans[:, plural_indices].mean(-1)
178 |
179 | results = torch.where(labels.to(cleans.device) == 0, singular - plural, plural - singular)
180 | if loss:
181 | results = -results
182 | if mean:
183 | results = results.mean()
184 | return results
185 | return sva_logit_diff
--------------------------------------------------------------------------------
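`get_metric` is a small factory: it returns a callable taking `(logits, clean/corrupted logits, input_length, labels, ...)`, which attribution and evaluation then drive with `loss`/`mean` flags. A minimal wiring sketch for the fact-retrieval task (no tokenizer or model is needed for that task):

    from functools import partial
    from eap.metrics import get_metric

    # Minimal sketch: build the metric once, then specialise it for attribution
    # (as a loss) and for evaluation (raw per-example scores).
    metric = get_metric('logit_diff', task='fact-retrieval')
    attribution_metric = partial(metric, loss=True, mean=True)
    evaluation_metric = partial(metric, loss=False, mean=False)

These mirror the `partial(logit_diff, ...)` calls used in the notebook later in this repository.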
/eap/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import torch
6 | import torch.nn.functional as F
7 |
8 | def model2family(model_name: str):
9 | if 'gpt2' in model_name:
10 | return 'gpt2'
11 | elif 'pythia' in model_name:
12 | return 'pythia'
13 | else:
14 | raise ValueError(f"Couldn't find model family for model: {model_name}")
15 |
16 | def kl_div(logits, clean_logits, input_length, labels, mean=True):
17 | batch_size = logits.size(0)
18 | idx = torch.arange(batch_size, device=logits.device)
19 |
20 | logits = logits[idx, input_length - 1]
21 | clean_logits = clean_logits[idx, input_length - 1]
22 |
23 | logprobs = torch.log_softmax(logits, dim=-1)
24 | clean_logprobs = torch.log_softmax(clean_logits, dim=-1)
25 |
26 | results = torch.nn.functional.kl_div(logprobs, clean_logprobs, log_target=True, reduction='none')
27 | return results.mean() if mean else results
28 |
29 | def precision_at_k(clean_logits, corrupted_logits, input_length, labels, k=1, mean=True):
30 | batch_size = clean_logits.size(0)
31 | idx = torch.arange(batch_size, device=clean_logits.device)
32 |
33 | clean_logits = clean_logits[idx, input_length - 1]
34 | clean_probs = torch.softmax(clean_logits, dim=-1)
35 | predictions = torch.argmax(clean_probs, dim=-1).cpu()
36 |
37 | results = []
38 | for i, (ls,_) in enumerate(labels):
39 | r = torch.sum((ls == predictions[i]).float())
40 | results.append(r)
41 | results = torch.stack(results)
42 | return results.mean() if mean else results
43 |
44 | def prob_diff_hypernymy(clean_logits, corrupted_logits, input_length, labels, mean=True, loss=False, logits=False):
45 | batch_size = clean_logits.size(0)
46 | idx = torch.arange(batch_size, device=clean_logits.device)
47 |
48 | clean_logits = clean_logits[idx, input_length - 1]
49 | clean_probs = torch.softmax(clean_logits, dim=-1)
50 |
51 | if logits:
52 | clean_probs = clean_logits
53 |
54 | results = []
55 | for i, (ls,corrupted_ls) in enumerate(labels):
56 | r = clean_probs[i][ls.to(clean_probs.device)].sum() - clean_probs[i][corrupted_ls.to(clean_probs.device)].sum()
57 | results.append(r)
58 | results = torch.stack(results)
59 | if loss:
60 | results = -results
61 | return results.mean() if mean else results
62 |
63 | def batch(iterable, n:int=1):
64 | current_batch = []
65 | for item in iterable:
66 | current_batch.append(item)
67 | if len(current_batch) == n:
68 | yield current_batch
69 | current_batch = []
70 | if current_batch:
71 | yield current_batch
72 |
73 | def get_singular_and_plural(model, strict=False) -> Tuple[torch.Tensor, torch.Tensor]:
74 | _TOKENIZER = model.tokenizer
75 | tokenizer_length = model.cfg.d_vocab_out
76 |
77 | df: pd.DataFrame = pd.read_csv('../data/sva/combined_verb_list.csv')
78 | singular = df['sing'].to_list()
79 | plural = df['plur'].to_list()
80 | singular_set = set(singular)
81 | plural_set = set(plural)
82 | verb_set = singular_set | plural_set
83 | assert len(singular_set & plural_set) == 0, f"{singular_set & plural_set}"
84 | singular_indices, plural_indices = [], []
85 |
86 | for i in range(tokenizer_length):
87 | token = _TOKENIZER._convert_id_to_token(i)
88 | if token is not None:
89 | if token[0] == 'Ġ':
90 | token = token[1:]
91 | if token in verb_set:
92 | if token in singular_set:
93 | singular_indices.append(i)
94 | else: # token in plural_set:
95 | idx = plural.index(token)
96 | third_person_present = singular[idx]
97 | third_person_present_tokenized = _TOKENIZER(f' {third_person_present}', add_special_tokens=False)['input_ids']
98 | if len(third_person_present_tokenized) == 1 and third_person_present_tokenized[0] != _TOKENIZER.unk_token_id:
99 | plural_indices.append(i)
100 | elif not strict:
101 | plural_indices.append(i)
102 |
103 | return torch.tensor(singular_indices, device=model.cfg.device), torch.tensor(plural_indices, device=model.cfg.device)
104 |
105 | def get_sva_prob_diff(model, strict=True):
106 | singular_indices, plural_indices = get_singular_and_plural(model, strict=strict)
107 | def sva_prob_diff(logits, clean_logits, input_length, labels, loss=False, mean=True):
108 | batch_size = clean_logits.size(0)
109 | idx = torch.arange(batch_size, device=clean_logits.device)
110 | probs = F.softmax(logits[idx, input_length-1], dim=-1)
111 | singular = probs[:, singular_indices].sum(-1)
112 | plural = probs[:, plural_indices].sum(-1)
113 |
114 | correct_form_prob_diff = torch.where(labels == 0, singular - plural, plural - singular)
115 | if loss:
116 | correct_form_prob_diff = - correct_form_prob_diff
117 | if mean:
118 | return correct_form_prob_diff.mean()
119 | else:
120 | return correct_form_prob_diff
121 | return sva_prob_diff
122 |
123 | def inflow_outflow_difference(g, absolute:bool=True):
124 | diffs = []
125 | for name, node in g.nodes.items():
126 | if 'logits' in name or 'input' in name:
127 | continue
128 | diff = sum(edge.score for edge in node.child_edges) - sum(edge.score for edge in node.parent_edges)
129 | if absolute:
130 | diff = abs(diff)
131 | diffs.append(diff)
132 |     diffs = np.array(diffs)
133 | logit_inflow = sum(edge.score for edge in g.logits[0].parent_edges)
134 | input_outflow = sum(edge.score for edge in g.nodes['input'].child_edges)
135 | return diffs.mean(), logit_inflow, input_outflow
--------------------------------------------------------------------------------
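Besides the metric helpers, this file exposes a couple of generic utilities; a short sketch of the two simplest ones (run from the repository root so that `eap` is importable):

    from eap.utils import batch, model2family

    # Minimal sketch of the generic helpers above.
    print(model2family('gpt2-small'))   # 'gpt2'  (substring match on the model name)
    for chunk in batch(range(7), n=3):
        print(chunk)                    # [0, 1, 2] then [3, 4, 5] then [6]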
/eap/visualization.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import functools
3 |
4 | import numpy as np
5 | import matplotlib
6 | import matplotlib.cm
7 |
8 |
9 |
10 | EDGE_TYPE_COLORS = {
11 |     'q': "#FF00FF", # Magenta
12 | 'k': "#00FF00", # Green
13 | 'v': "#0000FF", # Blue
14 | None: "#000000", # Black
15 | }
16 |
17 | def generate_random_color(colorscheme: str) -> str:
18 | """
19 | https://stackoverflow.com/questions/28999287/generate-random-colors-rgb
20 | """
21 |
22 | def rgb2hex(rgb):
23 | """
24 | https://stackoverflow.com/questions/3380726/converting-an-rgb-color-tuple-to-a-hexidecimal-string
25 | """
26 | return "#{:02x}{:02x}{:02x}".format(rgb[0], rgb[1], rgb[2])
27 |
28 | return rgb2hex(color(colorscheme, np.random.randint(0, 256), rgb_order=True))
29 |
30 | # ripped from cmapy since it doesn't play nice with new versions of matplotlib
31 | def cmap(cmap_name, rgb_order=False):
32 | """
33 | Extract colormap color information as a LUT compatible with cv2.applyColormap().
34 | Default channel order is BGR.
35 |
36 | Args:
37 | cmap_name: string, name of the colormap.
38 | rgb_order: boolean, if false or not set, the returned array will be in
39 | BGR order (standard OpenCV format). If true, the order
40 | will be RGB.
41 |
42 | Returns:
43 | A numpy array of type uint8 containing the colormap.
44 | """
45 |
46 | c_map = matplotlib.colormaps.get_cmap(cmap_name)
47 | rgba_data = matplotlib.cm.ScalarMappable(cmap=c_map).to_rgba(
48 | np.arange(0, 1.0, 1.0 / 256.0), bytes=True
49 | )
50 | rgba_data = rgba_data[:, 0:-1].reshape((256, 1, 3))
51 |
52 | # Convert to BGR (or RGB), uint8, for OpenCV.
53 | cmap = np.zeros((256, 1, 3), np.uint8)
54 |
55 | if not rgb_order:
56 | cmap[:, :, :] = rgba_data[:, :, ::-1]
57 | else:
58 | cmap[:, :, :] = rgba_data[:, :, :]
59 |
60 | return cmap
61 |
62 |
63 | # If python 3, redefine cmap() to use lru_cache.
64 | if sys.version_info > (3, 0):
65 | cmap = functools.lru_cache(maxsize=200)(cmap)
66 |
67 |
68 | def color(cmap_name, index, rgb_order=False):
69 | """Returns a color of a given colormap as a list of 3 BGR or RGB values.
70 |
71 | Args:
72 | cmap_name: string, name of the colormap.
73 | index: floating point between 0 and 1 or integer between 0 and 255,
74 | index of the requested color.
75 | rgb_order: boolean, if false or not set, the returned list will be in
76 | BGR order (standard OpenCV format). If true, the order
77 | will be RGB.
78 |
79 | Returns:
80 | List of RGB or BGR values.
81 | """
82 |
83 | # Float values: scale from 0-1 to 0-255.
84 | if isinstance(index, float):
85 | val = round(min(max(index, 0.0), 1.0) * 255)
86 | else:
87 | val = min(max(index, 0), 255)
88 |
89 | # Get colormap and extract color.
90 | colormap = cmap(cmap_name, rgb_order)
91 | return colormap[int(val), 0, :].tolist()
--------------------------------------------------------------------------------
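The colour helpers are used when drawing circuit graphs: `color` indexes a matplotlib colormap as a 3-value list, and `generate_random_color` wraps it into a hex string suitable for graphviz edges. A minimal sketch:

    from eap.visualization import EDGE_TYPE_COLORS, color, generate_random_color

    # Minimal sketch of the colormap helpers above.
    print(EDGE_TYPE_COLORS['q'])                    # '#FF00FF'
    print(color('viridis', 0.5, rgb_order=True))    # [R, G, B] at the colormap midpoint
    print(generate_random_color('viridis'))         # a random hex colour, e.g. '#443983'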
/knowledge_eap.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "from transformer_lens import HookedTransformer\n",
11 | "from functools import partial\n",
12 | "import torch.nn.functional as F\n",
13 | "from eap.metrics import logit_diff, direct_logit\n",
14 | "import transformer_lens.utils as utils\n",
15 | "from eap.graph import Graph\n",
16 | "from eap.dataset import EAPDataset\n",
17 | "from eap.attribute import attribute\n",
18 | "import time\n",
19 | "from rich import print as rprint\n",
20 | "import pandas as pd\n",
21 | "from eap.evaluate import evaluate_graph, evaluate_baseline,get_circuit_logits"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": null,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "LLAMA_2_7B_CHAT_PATH = \"meta-llama/Llama-2-7b-chat-hf\"\n",
31 | "from transformers import LlamaForCausalLM\n",
32 | "model = HookedTransformer.from_pretrained(LLAMA_2_7B_CHAT_PATH, device=\"cuda\", fold_ln=False, center_writing_weights=False, center_unembed=False)\n",
33 | "model.cfg.use_split_qkv_input = True\n",
34 | "model.cfg.use_attn_result = True\n",
35 | "model.cfg.use_hook_mlp_in = True"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": 5,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "clean_subject = 'Eiffel Tower'\n",
45 | "corrupted_subject = 'the Great Walls'\n",
46 |     "clean = f'The official currency of the country where {clean_subject} is located in is the'\n",
47 |     "corrupted = f'The official currency of the country where {corrupted_subject} is located in is the'\n",
48 | "assert len(model.to_str_tokens(clean.format(clean_subject))) == len(model.to_str_tokens(corrupted.format(corrupted_subject)))\n",
49 | "labels = ['Euro','Chinese']\n",
50 | "country_idx = model.tokenizer(labels[0],add_special_tokens=False).input_ids[0]\n",
51 | "corrupted_country_idx = model.tokenizer(labels[1],add_special_tokens=False).input_ids[0]\n",
52 | "# dataset = {k:[] for k in ['clean','country_idx', 'corrupted', 'corrupted_country_idx']}\n",
53 | "# for k, v in zip(['clean', 'country_idx', 'corrupted', 'corrupted_country_idx'], [clean, country_idx, corrupted, corrupted_country_idx]):\n",
54 | "# dataset[k].append(v)\n",
55 | "# df2 = pd.DataFrame.from_dict(dataset)\n",
56 | "# df2.to_csv(f'capital_city.csv', index=False)"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": 6,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "label = [[country_idx, corrupted_country_idx]]\n",
66 | "label = torch.tensor(label)\n",
67 | "data = ([clean],[corrupted],label)"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": 5,
73 | "metadata": {},
74 | "outputs": [],
75 | "source": [
76 | "# ds = EAPDataset(filename='capital_city.csv',task='fact-retrieval')\n",
77 | "# dataloader = ds.to_dataloader(1)"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": 12,
83 | "metadata": {},
84 | "outputs": [
85 | {
86 | "name": "stderr",
87 | "output_type": "stream",
88 | "text": [
89 | "100%|██████████| 1592881/1592881 [00:01<00:00, 1062625.82it/s]\n"
90 | ]
91 | },
92 | {
93 | "name": "stdout",
94 | "output_type": "stream",
95 | "text": [
96 |     "Execution time: 43.55915355682373 seconds\n"
97 | ]
98 | }
99 | ],
100 | "source": [
101 | "g = Graph.from_model(model)\n",
102 | "start_time = time.time()\n",
103 | "# Attribute using the model, graph, clean / corrupted data and labels, as well as a metric\n",
104 | "attribute(model, g, data, partial(logit_diff, loss=True, mean=True), method='EAP-IG-case', ig_steps=100)\n",
105 | "# attribute(model, g, data, partial(direct_logit, loss=True, mean=True), method='EAP-IG-case', ig_steps=30)\n",
106 | "# attribute(model, g, dataloader, partial(logit_diff, loss=True, mean=True), method='EAP-IG', ig_steps=30)\n",
107 | "g.apply_topn(5000, absolute=True)\n",
108 | "g.prune_dead_nodes()\n",
109 | "\n",
110 | "g.to_json('graph.json')\n",
111 | "\n",
112 | "gz = g.to_graphviz()\n",
113 | "gz.draw(f'graph.png', prog='dot')\n",
114 | "\n",
115 | "end_time = time.time()\n",
116 | "execution_time = end_time - start_time\n",
117 |     "print(f\"Execution time: {execution_time} seconds\")"
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": null,
123 | "metadata": {},
124 | "outputs": [],
125 | "source": [
126 | "def get_component_logits(logits, model, answer_token, top_k=10):\n",
127 | " logits = utils.remove_batch_dim(logits)\n",
128 | " # print(heads_out[head_name].shape)\n",
129 | " probs = logits.softmax(dim=-1)\n",
130 | " token_probs = probs[-1]\n",
131 | " answer_str_token = model.to_string(answer_token)\n",
132 | " sorted_token_probs, sorted_token_values = token_probs.sort(descending=True)\n",
133 | " # Janky way to get the index of the token in the sorted list - I couldn't find a better way?\n",
134 | " correct_rank = torch.arange(len(sorted_token_values))[\n",
135 | " (sorted_token_values == answer_token).cpu()\n",
136 | " ].item()\n",
137 | " # answer_ranks = []\n",
138 | " # answer_ranks.append((answer_str_token, correct_rank))\n",
139 | " # String formatting syntax - the first number gives the number of characters to pad to, the second number gives the number of decimal places.\n",
140 | " # rprint gives rich text printing\n",
141 | " rprint(\n",
142 | " f\"Performance on answer token:\\n[b]Rank: {correct_rank: <8} Logit: {logits[-1, answer_token].item():5.2f} Prob: {token_probs[answer_token].item():6.2%} Token: |{answer_str_token}|[/b]\"\n",
143 | " )\n",
144 | " for i in range(top_k):\n",
145 | " print(\n",
146 | " f\"Top {i}th token. Logit: {logits[-1, sorted_token_values[i]].item():5.2f} Prob: {sorted_token_probs[i].item():6.2%} Token: |{model.to_string(sorted_token_values[i])}|\"\n",
147 | " )\n",
148 | " # rprint(f\"[b]Ranks of the answer tokens:[/b] {answer_ranks}\")"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": 13,
154 | "metadata": {},
155 | "outputs": [
156 | {
157 | "data": {
158 | "text/html": [
159 | "Performance on answer token:\n",
160 | "Rank: 0 Logit: 16.94 Prob: 56.56% Token: |Euro|\n",
161 |     "\n"
162 | ],
163 | "text/plain": [
164 | "Performance on answer token:\n",
165 | "\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m16.94\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m56.56\u001b[0m\u001b[1m% Token: |Euro|\u001b[0m\n"
166 | ]
167 | },
168 | "metadata": {},
169 | "output_type": "display_data"
170 | },
171 | {
172 | "name": "stdout",
173 | "output_type": "stream",
174 | "text": [
175 | "Top 0th token. Logit: 16.94 Prob: 56.56% Token: |Euro|\n",
176 | "Top 1th token. Logit: 15.96 Prob: 21.39% Token: |French|\n",
177 | "Top 2th token. Logit: 14.06 Prob: 3.18% Token: |_|\n",
178 | "Top 3th token. Logit: 13.95 Prob: 2.85% Token: |euro|\n",
179 | "Top 4th token. Logit: 13.91 Prob: 2.74% Token: |Eu|\n"
180 | ]
181 | }
182 | ],
183 | "source": [
184 | "logits = get_circuit_logits(model, g, data)\n",
185 | "get_component_logits(logits, model, answer_token=model.to_tokens('Euro',prepend_bos=False)[0], top_k=5)"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": 69,
191 | "metadata": {},
192 | "outputs": [
193 | {
194 | "name": "stderr",
195 | "output_type": "stream",
196 | "text": [
197 | "100%|██████████| 1/1 [00:00<00:00, 8.79it/s]\n",
198 | "100%|██████████| 1/1 [00:00<00:00, 8.91it/s]\n",
199 | "100%|██████████| 1/1 [00:00<00:00, 6.82it/s]"
200 | ]
201 | },
202 | {
203 | "name": "stdout",
204 | "output_type": "stream",
205 | "text": [
206 | "Original performance was 10.043922424316406; the circuit's performance is 6.337347984313965\n"
207 | ]
208 | },
209 | {
210 | "name": "stderr",
211 | "output_type": "stream",
212 | "text": [
213 | "\n"
214 | ]
215 | }
216 | ],
217 | "source": [
218 | "baseline = evaluate_baseline(model, dataloader, partial(logit_diff, loss=False, mean=False)).mean().item()\n",
219 | "results = evaluate_graph(model, g, dataloader, partial(logit_diff, loss=False, mean=False)).mean().item()\n",
220 | "print(f\"Original performance was {baseline}; the circuit's performance is {results}\")"
221 | ]
222 | }
223 | ],
224 | "metadata": {
225 | "kernelspec": {
226 | "display_name": "eap",
227 | "language": "python",
228 | "name": "python3"
229 | },
230 | "language_info": {
231 | "codemirror_mode": {
232 | "name": "ipython",
233 | "version": 3
234 | },
235 | "file_extension": ".py",
236 | "mimetype": "text/x-python",
237 | "name": "python",
238 | "nbconvert_exporter": "python",
239 | "pygments_lexer": "ipython3",
240 | "version": "3.11.9"
241 | }
242 | },
243 | "nbformat": 4,
244 | "nbformat_minor": 2
245 | }
246 |
--------------------------------------------------------------------------------
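Note that the notebook's final cell evaluates `dataloader`, which is only built in the commented-out cells above it. A minimal sketch of that missing step, mirroring those commented lines (it assumes the earlier cells defining `clean`, `corrupted`, `country_idx`, and `corrupted_country_idx` have run; the CSV name is an assumption):

    import pandas as pd
    from eap.dataset import EAPDataset

    # Minimal sketch: write the single example to a CSV and wrap it in the
    # EAPDataset/DataLoader that the last notebook cell expects.
    pd.DataFrame({
        'clean': [clean],
        'country_idx': [country_idx],
        'corrupted': [corrupted],
        'corrupted_country_idx': [corrupted_country_idx],
    }).to_csv('capital_city.csv', index=False)

    ds = EAPDataset(task='fact-retrieval', filename='capital_city.csv')
    dataloader = ds.to_dataloader(1)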
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | huggingface_hub
3 | pandas
4 | datasets
5 | jaxtyping
6 | rich
7 | torch==1.13.1
8 | better_abc
9 | fancy_einsum
10 | matplotlib==3.8.4
11 | transformers>=4.37.2
12 | einops==0.7.0
13 | plotly==5.21.0
14 | wandb==0.16.6
15 | pytest==8.1.1
16 | dataclasses_json==0.6.4
17 | networkx==3.3
18 | torchtyping==0.1.4
19 | cmapy==0.6.6
20 | circuitsvis==1.43.2
21 | pygraphviz #conda install --channel conda-forge pygraphviz
--------------------------------------------------------------------------------
/transformer_lens/SVDInterpreter.py:
--------------------------------------------------------------------------------
1 | """SVD Interpreter.
2 |
3 | Module for getting the singular vectors of the OV, w_in, and w_out matrices of a
4 | :class:`transformer_lens.HookedTransformer`.
5 | """
6 |
7 | from typing import Optional, Union
8 |
9 | import fancy_einsum as einsum
10 | import torch
11 | from typeguard import typechecked
12 | from typing_extensions import Literal
13 |
14 | from transformer_lens.FactoredMatrix import FactoredMatrix
15 | from transformer_lens.HookedTransformer import HookedTransformer
16 |
17 | OUTPUT_EMBEDDING = "unembed.W_U"
18 | VECTOR_TYPES = ["OV", "w_in", "w_out"]
19 |
20 |
21 | class SVDInterpreter:
22 | def __init__(self, model: HookedTransformer):
23 | self.model = model
24 | self.cfg = model.cfg
25 | self.params = {name: param for name, param in model.named_parameters()}
26 |
27 | @typechecked
28 | def get_singular_vectors(
29 | self,
30 | vector_type: Union[Literal["OV"], Literal["w_in"], Literal["w_out"]],
31 | layer_index: int,
32 | num_vectors: int = 10,
33 | head_index: Optional[int] = None,
34 | ) -> torch.Tensor:
35 | """Gets the singular vectors for a given vector type, layer, and optionally head.
36 |
37 | This tensor can then be plotted using Neel's PySvelte, as demonstrated in the demo for this
38 | feature. The demo also points out some "gotchas" in this feature - numerical instability
39 | means inconsistency across devices, and the default HookedTransformer parameters don't
40 | replicate the original SVD post very well. So I'd recommend checking out the demo if you
41 | want to use this!
42 |
43 | Example:
44 |
45 | .. code-block:: python
46 |
47 | from transformer_lens import HookedTransformer, SVDInterpreter
48 |
49 | model = HookedTransformer.from_pretrained('gpt2-medium')
50 | svd_interpreter = SVDInterpreter(model)
51 |
52 | ov = svd_interpreter.get_singular_vectors('OV', layer_index=22, head_index=10)
53 |
54 | all_tokens = [model.to_str_tokens(np.array([i])) for i in range(model.cfg.d_vocab)]
55 | all_tokens = [all_tokens[i][0] for i in range(model.cfg.d_vocab)]
56 |
57 | def plot_matrix(matrix, tokens, k=10, filter="topk"):
58 | pysvelte.TopKTable(
59 | tokens=all_tokens,
60 | activations=matrix,
61 | obj_type="SVD direction",
62 | k=k,
63 | filter=filter
64 | ).show()
65 |
66 | plot_matrix(ov, all_tokens)
67 |
68 | Args:
69 | vector_type: Type of the vector:
70 | - "OV": Singular vectors of the OV matrix for a particular layer and head.
71 | - "w_in": Singular vectors of the w_in matrix for a particular layer.
72 | - "w_out": Singular vectors of the w_out matrix for a particular layer.
73 | layer_index: The index of the layer.
74 | num_vectors: Number of vectors.
75 | head_index: Index of the head.
76 | """
77 |
78 | if head_index is None:
79 | assert vector_type in [
80 | "w_in",
81 | "w_out",
82 | ], f"Head index optional only for w_in and w_out, got {vector_type}"
83 |
84 | matrix: Union[FactoredMatrix, torch.Tensor]
85 | if vector_type == "OV":
86 | assert head_index is not None # keep mypy happy
87 | matrix = self._get_OV_matrix(layer_index, head_index)
88 | V = matrix.Vh.T
89 |
90 | elif vector_type == "w_in":
91 | matrix = self._get_w_in_matrix(layer_index)
92 | _, _, V = torch.linalg.svd(matrix)
93 |
94 | elif vector_type == "w_out":
95 | matrix = self._get_w_out_matrix(layer_index)
96 | _, _, V = torch.linalg.svd(matrix)
97 |
98 | else:
99 | raise ValueError(f"Vector type must be in {VECTOR_TYPES}, instead got {vector_type}")
100 |
101 | return self._get_singular_vectors_from_matrix(V, self.params[OUTPUT_EMBEDDING], num_vectors)
102 |
103 | def _get_singular_vectors_from_matrix(
104 | self,
105 | V: Union[torch.Tensor, FactoredMatrix],
106 | embedding: torch.Tensor,
107 | num_vectors: int = 10,
108 | ) -> torch.Tensor:
109 | """Returns the top num_vectors singular vectors from a matrix."""
110 |
111 | vectors_list = []
112 | for i in range(num_vectors):
113 | activations = V[i, :].float() @ embedding # type: ignore
114 | vectors_list.append(activations)
115 |
116 | vectors = torch.stack(vectors_list, dim=1).unsqueeze(1)
117 | assert vectors.shape == (
118 | self.cfg.d_vocab,
119 | 1,
120 | num_vectors,
121 | ), f"Vectors shape should be {self.cfg.d_vocab, 1, num_vectors} but got {vectors.shape}"
122 | return vectors
123 |
124 | def _get_OV_matrix(self, layer_index: int, head_index: int) -> FactoredMatrix:
125 | """Gets the OV matrix for a particular layer and head."""
126 |
127 | assert (
128 | 0 <= layer_index < self.cfg.n_layers
129 | ), f"Layer index must be between 0 and {self.cfg.n_layers-1} but got {layer_index}"
130 | assert (
131 | 0 <= head_index < self.cfg.n_heads
132 | ), f"Head index must be between 0 and {self.cfg.n_heads-1} but got {head_index}"
133 |
134 | W_V: torch.Tensor = self.params[f"blocks.{layer_index}.attn.W_V"]
135 | W_O: torch.Tensor = self.params[f"blocks.{layer_index}.attn.W_O"]
136 | W_V, W_O = W_V[head_index, :, :], W_O[head_index, :, :]
137 |
138 | return FactoredMatrix(W_V, W_O)
139 |
140 | def _get_w_in_matrix(self, layer_index: int) -> torch.Tensor:
141 | """Gets the w_in matrix for a particular layer."""
142 |
143 | assert (
144 | 0 <= layer_index < self.cfg.n_layers
145 | ), f"Layer index must be between 0 and {self.cfg.n_layers-1} but got {layer_index}"
146 |
147 | w_in = self.params[f"blocks.{layer_index}.mlp.W_in"].T
148 |
149 | if f"blocks.{layer_index}.ln2.w" in self.params: # If fold_ln == False
150 | ln_2 = self.params[f"blocks.{layer_index}.ln2.w"]
151 | return einsum.einsum("out in, in -> out in", w_in, ln_2)
152 |
153 | return w_in
154 |
155 | def _get_w_out_matrix(self, layer_index: int) -> torch.Tensor:
156 | """Gets the w_out matrix for a particular layer."""
157 |
158 | assert (
159 | 0 <= layer_index < self.cfg.n_layers
160 | ), f"Layer index must be between 0 and {self.cfg.n_layers-1} but got {layer_index}"
161 |
162 | return self.params[f"blocks.{layer_index}.mlp.W_out"]
163 |
--------------------------------------------------------------------------------
/transformer_lens/__init__.py:
--------------------------------------------------------------------------------
1 | from . import hook_points
2 | from . import utils
3 | from . import evals
4 | from .past_key_value_caching import (
5 | HookedTransformerKeyValueCache,
6 | HookedTransformerKeyValueCacheEntry,
7 | )
8 | from . import components
9 | from .HookedTransformerConfig import HookedTransformerConfig
10 | from .FactoredMatrix import FactoredMatrix
11 | from .ActivationCache import ActivationCache
12 | from .HookedTransformer import HookedTransformer
13 | from .SVDInterpreter import SVDInterpreter
14 | from .HookedEncoder import HookedEncoder
15 | from . import head_detector
16 | from . import loading_from_pretrained as loading
17 | from . import patching
18 | from . import train
19 |
20 | from .past_key_value_caching import (
21 | HookedTransformerKeyValueCache as EasyTransformerKeyValueCache,
22 | HookedTransformerKeyValueCacheEntry as EasyTransformerKeyValueCacheEntry,
23 | )
24 | from .HookedTransformer import HookedTransformer as EasyTransformer
25 | from .HookedTransformerConfig import HookedTransformerConfig as EasyTransformerConfig
26 |
--------------------------------------------------------------------------------
/transformer_lens/past_key_value_caching.py:
--------------------------------------------------------------------------------
1 | """Past Key Value Caching.
2 |
3 | This module contains the HookedTransformerKeyValueCache and HookedTransformerKeyValueCacheEntry
4 | classes, which are used to store past keys and values for the Transformer. This is important for
5 | generating text - we can cache a lot of past computation and avoid repeating ourselves!
6 | """
7 | from dataclasses import dataclass
8 | from typing import List, Union
9 |
10 | import torch
11 | from jaxtyping import Float, Int
12 |
13 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
14 | from transformer_lens.utilities.devices import get_device_for_block_index
15 |
16 |
17 | @dataclass
18 | class HookedTransformerKeyValueCacheEntry:
19 | past_keys: Float[torch.Tensor, "batch pos_so_far n_heads d_head"]
20 | past_values: Float[torch.Tensor, "batch pos_so_far n_heads d_head"]
21 | frozen: bool = False
22 |
23 | @classmethod
24 | def init_cache_entry(
25 | cls,
26 | cfg: HookedTransformerConfig,
27 | device: Union[torch.device, str, None],
28 | batch_size: int = 1,
29 | ):
30 | n_heads = cfg.n_key_value_heads if cfg.n_key_value_heads is not None else cfg.n_heads
31 | return cls(
32 | past_keys=torch.empty(
33 | (batch_size, 0, n_heads, cfg.d_head), device=device, dtype=cfg.dtype
34 | ),
35 | past_values=torch.empty(
36 | (batch_size, 0, n_heads, cfg.d_head), device=device, dtype=cfg.dtype
37 | ),
38 | )
39 |
40 | def append(
41 | self,
42 | new_keys: Float[torch.Tensor, "batch new_tokens n_heads d_head"],
43 | new_values: Float[torch.Tensor, "batch new_tokens n_heads d_head"],
44 | ):
45 | updated_keys: Float[
46 | torch.Tensor, "batch pos_so_far_plus_new_tokens n_heads d_head"
47 | ] = torch.cat([self.past_keys, new_keys], dim=1)
48 | updated_values: Float[
49 | torch.Tensor, "batch pos_so_far_plus_new_tokens n_heads d_head"
50 | ] = torch.cat([self.past_values, new_values], dim=1)
51 | if not self.frozen:
52 | self.past_keys = updated_keys
53 | self.past_values = updated_values
54 | return updated_keys, updated_values
55 |
56 |
57 | @dataclass
58 | class HookedTransformerKeyValueCache:
59 | """
60 | A cache for storing past keys and values for the Transformer. This is important for generating text - we can cache a lot of past computation and avoid repeating ourselves!
61 |
62 | This cache is a list of HookedTransformerKeyValueCacheEntry objects, one for each layer in the Transformer. Each object stores a [batch, pos_so_far, n_heads, d_head] tensor for both keys and values, and each entry has an append method that concatenates the keys and values of newly processed tokens onto the cached ones.
63 |
64 | The cache can be frozen so that it is not updated during the forward pass. This is useful when we want to run many inputs with the same prefix.
65 | """
66 |
67 | entries: List[HookedTransformerKeyValueCacheEntry]
68 | previous_attention_mask: Int[torch.Tensor, "batch pos_so_far"]
69 | frozen: bool = False
70 |
71 | @classmethod
72 | def init_cache(
73 | cls,
74 | cfg: HookedTransformerConfig,
75 | device: Union[torch.device, str, None],
76 | batch_size: int = 1,
77 | ):
78 | return cls(
79 | entries=[
80 | HookedTransformerKeyValueCacheEntry.init_cache_entry(
81 | cfg,
82 | get_device_for_block_index(i, cfg, device),
83 | batch_size,
84 | )
85 | for i in range(cfg.n_layers)
86 | ],
87 | previous_attention_mask=torch.empty(
88 | # This may actually be an int64, but type promotion will handle it:
89 | # See: https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc
90 | # See: https://github.com/pytorch/pytorch/issues/35014
91 | (batch_size, 0),
92 | device=device,
93 | dtype=torch.int,
94 | ),
95 | )
96 |
97 | def freeze(self):
98 | self.frozen = True
99 | for entry in self.entries:
100 | entry.frozen = True
101 |
102 | def unfreeze(self):
103 | self.frozen = False
104 | for entry in self.entries:
105 | entry.frozen = False
106 |
107 | def append_attention_mask(self, attention_mask: Int[torch.Tensor, "batch new_tokens"]):
108 | attention_mask = attention_mask.to(self.previous_attention_mask.device)
109 | updated_attention_mask = torch.cat([self.previous_attention_mask, attention_mask], dim=-1)
110 | if not self.frozen:
111 | self.previous_attention_mask = updated_attention_mask
112 | return updated_attention_mask
113 |
114 | def __getitem__(self, idx):
115 | return self.entries[idx]
116 |
--------------------------------------------------------------------------------
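A usage sketch for the cache above, following the pattern its docstring describes: run a shared prefix once, freeze the cache, then evaluate many continuations without recomputing the prefix. The model name and the exact forward-call signature (passing the cache via past_kv_cache) are assumptions for illustration, not code taken from this repository:

    from transformer_lens import HookedTransformer, HookedTransformerKeyValueCache

    model = HookedTransformer.from_pretrained("gpt2")
    cache = HookedTransformerKeyValueCache.init_cache(model.cfg, model.cfg.device, batch_size=1)

    # Fill one cache entry per layer with the prefix's keys and values.
    prefix_tokens = model.to_tokens("The capital of France is")
    model(prefix_tokens, past_kv_cache=cache)

    # Freeze so later forward passes read the cached prefix but do not extend it.
    cache.freeze()
    for continuation in [" Paris", " Lyon"]:
        new_tokens = model.to_tokens(continuation, prepend_bos=False)
        logits = model(new_tokens, past_kv_cache=cache)  # only the new tokens are recomputed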
/transformer_lens/train.py:
--------------------------------------------------------------------------------
1 | """Train.
2 |
3 | Utilities for training :class:`transformer_lens.HookedTransformer` models on autoregressive language
4 | modeling tasks.
5 | """
6 |
7 | from dataclasses import dataclass
8 | from typing import Optional
9 |
10 | import torch
11 | import torch.optim as optim
12 | import wandb
13 | from torch.optim import Optimizer
14 | from torch.utils.data import DataLoader, Dataset
15 | from tqdm.auto import tqdm
16 |
17 | from transformer_lens import utils
18 | from transformer_lens.HookedTransformer import HookedTransformer
19 |
20 |
21 | @dataclass
22 | class HookedTransformerTrainConfig:
23 | """
24 | Configuration class to store training hyperparameters for a training run of
25 | a HookedTransformer model.
26 | Args:
27 | num_epochs (int): Number of epochs to train for
28 | batch_size (int): Size of batches to use for training
29 | lr (float): Learning rate to use for training
30 | seed (int): Random seed to use for training
31 | momentum (float): Momentum to use for training (only applies to the SGD optimizer)
32 | max_grad_norm (float, *optional*): Maximum gradient norm to use for gradient clipping
33 | weight_decay (float, *optional*): Weight decay to use for training
34 | optimizer_name (str): The name of the optimizer to use
35 | device (str, *optional*): Device to use for training
36 | warmup_steps (int, *optional*): Number of warmup steps to use for training
37 | save_every (int, *optional*): Save a checkpoint after this many batches
38 | save_dir (str, *optional*): Where to save checkpoints
39 | wandb (bool): Whether to use Weights and Biases for logging
40 | wandb_project_name (str, *optional*): Name of the Weights and Biases project to use
41 | print_every (int, *optional*): Print the loss every n steps
42 | max_steps (int, *optional*): Terminate the epoch after this many steps. Used for debugging.
43 | """
44 |
45 | num_epochs: int
46 | batch_size: int
47 | lr: float = 1e-3
48 | seed: int = 0
49 | momentum: float = 0.0
50 | max_grad_norm: Optional[float] = None
51 | weight_decay: Optional[float] = None
52 | optimizer_name: str = "Adam"
53 | device: Optional[str] = None
54 | warmup_steps: int = 0
55 | save_every: Optional[int] = None
56 | save_dir: Optional[str] = None
57 | wandb: bool = False
58 | wandb_project_name: Optional[str] = None
59 | print_every: Optional[int] = 50
60 | max_steps: Optional[int] = None
61 |
62 |
63 | def train(
64 | model: HookedTransformer,
65 | config: HookedTransformerTrainConfig,
66 | dataset: Dataset,
67 | ) -> HookedTransformer:
68 | """
69 | Trains a HookedTransformer model on an autoregressive language modeling task.
70 | Args:
71 | model: The model to train
72 | config: The training configuration
73 | dataset: The dataset to train on - this function assumes the dataset is set up for autoregressive language modeling.
74 | Returns:
75 | The trained model
76 | """
77 | torch.manual_seed(config.seed)
78 | model.train()
79 | if config.wandb:
80 | if config.wandb_project_name is None:
81 | config.wandb_project_name = "easy-transformer"
82 | wandb.init(project=config.wandb_project_name, config=vars(config))
83 |
84 | if config.device is None:
85 | config.device = utils.get_device()
86 |
87 | optimizer: Optimizer
88 | if config.optimizer_name in ["Adam", "AdamW"]:
89 | # Weight decay in Adam is implemented badly, so use AdamW instead (see PyTorch AdamW docs)
90 | if config.weight_decay is not None:
91 | optimizer = optim.AdamW(
92 | model.parameters(),
93 | lr=config.lr,
94 | weight_decay=config.weight_decay,
95 | )
96 | else:
97 | optimizer = optim.Adam(
98 | model.parameters(),
99 | lr=config.lr,
100 | )
101 | elif config.optimizer_name == "SGD":
102 | optimizer = optim.SGD(
103 | model.parameters(),
104 | lr=config.lr,
105 | weight_decay=(config.weight_decay if config.weight_decay is not None else 0.0),
106 | momentum=config.momentum,
107 | )
108 | else:
109 | raise ValueError(f"Optimizer {config.optimizer_name} not supported")
110 |
111 | scheduler = None
112 | if config.warmup_steps > 0:
113 | scheduler = optim.lr_scheduler.LambdaLR(
114 | optimizer,
115 | lr_lambda=lambda step: min(1.0, step / config.warmup_steps),
116 | )
117 |
118 | dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
119 |
120 | model.to(config.device)
121 |
122 | for epoch in tqdm(range(1, config.num_epochs + 1)):
123 | samples = 0
124 | for step, batch in tqdm(enumerate(dataloader)):
125 | tokens = batch["tokens"].to(config.device)
126 | loss = model(tokens, return_type="loss")
127 | loss.backward()
128 | if config.max_grad_norm is not None:
129 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
130 | optimizer.step()
131 | if config.warmup_steps > 0:
132 | assert scheduler is not None
133 | scheduler.step()
134 | optimizer.zero_grad()
135 |
136 | samples += tokens.shape[0]
137 |
138 | if config.wandb:
139 | wandb.log({"train_loss": loss.item(), "samples": samples, "epoch": epoch})
140 |
141 | if config.print_every is not None and step % config.print_every == 0:
142 | print(f"Epoch {epoch} Samples {samples} Step {step} Loss {loss.item()}")
143 |
144 | if (
145 | config.save_every is not None
146 | and step % config.save_every == 0
147 | and config.save_dir is not None
148 | ):
149 | torch.save(model.state_dict(), f"{config.save_dir}/model_{step}.pt")
150 |
151 | if config.max_steps is not None and step >= config.max_steps:
152 | break
153 |
154 | return model
155 |
--------------------------------------------------------------------------------
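A sketch of driving the train() function above. The loop indexes batch["tokens"], so the dataset must yield dicts containing a "tokens" tensor; the toy dataset and the hyperparameter values below are illustrative assumptions, not part of the repository:

    import torch
    from torch.utils.data import Dataset

    from transformer_lens import HookedTransformer
    from transformer_lens.train import HookedTransformerTrainConfig, train

    class ToyTokenDataset(Dataset):
        """Random token sequences, just to satisfy the batch["tokens"] contract."""

        def __init__(self, n_examples: int, seq_len: int, d_vocab: int):
            self.tokens = torch.randint(0, d_vocab, (n_examples, seq_len))

        def __len__(self):
            return self.tokens.shape[0]

        def __getitem__(self, idx):
            return {"tokens": self.tokens[idx]}

    model = HookedTransformer.from_pretrained("gpt2")
    config = HookedTransformerTrainConfig(
        num_epochs=1,
        batch_size=8,
        lr=1e-4,
        optimizer_name="AdamW",
        weight_decay=0.01,  # a non-None weight decay selects optim.AdamW in the branch above
        warmup_steps=10,
        max_steps=50,       # end the epoch early; handy for a smoke test
    )
    trained = train(model, config, ToyTokenDataset(1000, 64, model.cfg.d_vocab))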
/transformer_lens/utilities/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/KnowledgeCircuits/bda3d22cf3b74a6b48c092f952e95b6414d8a9de/transformer_lens/utilities/__init__.py
--------------------------------------------------------------------------------
/transformer_lens/utilities/devices.py:
--------------------------------------------------------------------------------
1 | """Devices.
2 |
3 | Utilities to get the correct device, and assist in distributing model layers across multiple
4 | devices.
5 | """
6 | from __future__ import annotations
7 |
8 | from typing import Optional, Union
9 |
10 | import torch
11 | from torch import nn
12 |
13 | import transformer_lens
14 |
15 |
16 | def get_device_for_block_index(
17 | index: int,
18 | cfg: "transformer_lens.HookedTransformerConfig",
19 | device: Optional[Union[torch.device, str]] = None,
20 | ):
21 | """
22 | Determine the device for a given layer index based on the model configuration.
23 |
24 | This function assists in distributing model layers across multiple devices. The distribution
25 | is based on the configuration's number of layers (cfg.n_layers) and devices (cfg.n_devices).
26 |
27 | Args:
28 | index (int): Model layer index.
29 | cfg (HookedTransformerConfig): Model and device configuration.
30 | device (Optional[Union[torch.device, str]], optional): Initial device used for determining the target device.
31 | If not provided, the function uses the device specified in the configuration (cfg.device).
32 |
33 | Returns:
34 | torch.device: The device for the specified layer index.
35 | """
36 | assert cfg.device is not None
37 | layers_per_device = cfg.n_layers // cfg.n_devices
38 | if device is None:
39 | device = cfg.device
40 | device = torch.device(device)
41 | if device.type == "cpu":
42 | return device
43 | device_index = (device.index or 0) + (index // layers_per_device)
44 | return torch.device(device.type, device_index)
45 |
46 |
47 | def move_to_and_update_config(
48 | model: Union["transformer_lens.HookedTransformer", "transformer_lens.HookedEncoder"],
49 | device_or_dtype: Union[torch.device, str, torch.dtype],
50 | print_details=True,
51 | ):
52 | """
53 | Wrapper around `to` that also updates `model.cfg`.
54 | """
55 | if isinstance(device_or_dtype, torch.device):
56 | model.cfg.device = device_or_dtype.type
57 | if print_details:
58 | print("Moving model to device: ", model.cfg.device)
59 | elif isinstance(device_or_dtype, str):
60 | model.cfg.device = device_or_dtype
61 | if print_details:
62 | print("Moving model to device: ", model.cfg.device)
63 | elif isinstance(device_or_dtype, torch.dtype):
64 | model.cfg.dtype = device_or_dtype
65 | if print_details:
66 | print("Changing model dtype to", device_or_dtype)
67 | # change state_dict dtypes
68 | for k, v in model.state_dict().items():
69 | model.state_dict()[k] = v.to(device_or_dtype)
70 | return nn.Module.to(model, device_or_dtype)
71 |
--------------------------------------------------------------------------------
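To make the layer-to-device mapping in get_device_for_block_index concrete, here is a small worked sketch (the config values are invented, and instantiating a config with n_devices=2 assumes at least two visible CUDA devices): with n_layers=12 and n_devices=2, layers_per_device is 6, so blocks 0-5 resolve to cuda:0 and blocks 6-11 to cuda:1.

    from transformer_lens import HookedTransformerConfig
    from transformer_lens.utilities.devices import get_device_for_block_index

    # Invented config purely for illustration.
    cfg = HookedTransformerConfig(
        n_layers=12, d_model=768, n_ctx=1024, d_head=64,
        act_fn="gelu", device="cuda", n_devices=2,
    )

    for i in (0, 5, 6, 11):
        print(i, get_device_for_block_index(i, cfg))
    # Expected: 0 cuda:0, 5 cuda:0, 6 cuda:1, 11 cuda:1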