├── .gitignore
├── README.md
├── figures
│   └── methodology.png
├── programs
│   ├── conll
│   │   ├── conll_ner.py
│   │   ├── conll_ner_embeddings.csv
│   │   ├── conll_ner_embeddings.py
│   │   ├── conll_ner_weights.csv
│   │   └── conll_notebook.ipynb
│   ├── rasp
│   │   ├── double_hist
│   │   │   ├── double_hist.py
│   │   │   └── double_hist_weights.csv
│   │   ├── dyck1
│   │   │   ├── dyck1.py
│   │   │   └── dyck1_weights.csv
│   │   ├── dyck2
│   │   │   ├── dyck2.py
│   │   │   └── dyck2_weights.csv
│   │   ├── hist
│   │   │   ├── hist.py
│   │   │   └── hist_weights.csv
│   │   ├── most_freq
│   │   │   ├── most_freq.py
│   │   │   └── most_freq_weights.csv
│   │   ├── reverse
│   │   │   ├── reverse.py
│   │   │   └── reverse_weights.csv
│   │   └── sort
│   │       ├── sort.py
│   │       └── sort_weights.csv
│   ├── rasp_categorical_only
│   │   ├── double_hist
│   │   │   ├── double_hist.py
│   │   │   └── double_hist_weights.csv
│   │   ├── dyck1
│   │   │   ├── dyck1.py
│   │   │   └── dyck1_weights.csv
│   │   ├── dyck2
│   │   │   ├── dyck2.py
│   │   │   └── dyck2_weights.csv
│   │   ├── hist
│   │   │   ├── hist.py
│   │   │   └── hist_weights.csv
│   │   ├── most_freq
│   │   │   ├── most_freq.py
│   │   │   └── most_freq_weights.csv
│   │   ├── reverse
│   │   │   ├── reverse.py
│   │   │   └── reverse_weights.csv
│   │   └── sort
│   │       ├── sort.py
│   │       └── sort_weights.csv
│   └── trec
│       ├── trec.py
│       ├── trec_embeddings.csv
│       ├── trec_embeddings.py
│       └── trec_weights.csv
├── requirements.txt
├── scripts
│   ├── classification.sh
│   ├── classification_short.sh
│   ├── classification_standard.sh
│   ├── conll.sh
│   ├── conll_standard.sh
│   ├── dyck.sh
│   ├── induction.sh
│   └── rasp.sh
├── setup.py
└── src
    ├── __init__.py
    ├── decompile.py
    ├── models
    │   ├── __init__.py
    │   ├── programs.py
    │   └── transformers.py
    ├── run.py
    └── utils
        ├── __init__.py
        ├── analysis_utils.py
        ├── code_utils.py
        ├── data_utils.py
        ├── logging.py
        └── metric_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | _site
2 | .sass-cache
3 | Gemfile.lock
4 | *.gem
5 | .jekyll-cache
6 | .jekyll-cache
7 | *~
8 | *__pycache__*
9 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning Transformer Programs
2 |
3 | This repository contains the code for our paper, [Learning Transformer Programs](https://arxiv.org/abs/2306.01128).
4 | The code can be used to train a modified Transformer to solve a task, and then convert it into a human-readable Python program.
5 | The repository also includes a number of [example programs](#Example-programs), which we learned for the tasks described in the paper.
6 | Please see [our paper](https://arxiv.org/abs/2306.01128) for more details.
7 |
8 |
9 |
10 |
11 | ## Quick links
12 | * [Setup](#Setup)
13 | * [Learning Programs](#Learning-programs)
14 | * [Training](#Training)
15 | * [Converting to code](#Converting-to-code)
16 | * [Example Programs](#Example-programs)
17 | * [Questions?](#Questions)
18 | * [Citation](#Citation)
19 |
20 | ## Setup
21 |
22 | Install [PyTorch](https://pytorch.org/get-started/locally/) and then install the remaining requirements: `pip install -r requirements.txt`.
23 | This code was tested using Python 3.8 and PyTorch version 1.13.1.
24 |
25 | In our experiments on NLP tasks, we initialize word embeddings using 300-dimensional pre-trained GloVe embeddings, which can be downloaded [here](https://github.com/stanfordnlp/GloVe) (Common Crawl, cased):
26 | ```bash
27 | mkdir data
28 | wget https://huggingface.co/stanfordnlp/glove/resolve/main/glove.840B.300d.zip -P data/
29 | unzip data/glove.840B.300d.zip -d data/
30 | ```
31 |
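To sanity-check the download, note that GloVe files are plain text with one token per line followed by its space-separated vector components. The helper below is an illustrative sketch (it is not part of this repository) for parsing such a file into a dictionary; the `vocab` filter is an assumed convenience parameter:

```python
import numpy as np

def load_glove(path, vocab=None, dim=300):
    """Parse a GloVe text file into a {token: vector} dict.

    Each line has the form: token v_1 v_2 ... v_dim.
    If `vocab` is given, only tokens in `vocab` are kept.
    """
    embeddings = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip("\n").split(" ")
            token, values = parts[0], parts[1:]
            if len(values) != dim:
                continue  # skip malformed lines
            if vocab is None or token in vocab:
                embeddings[token] = np.array(values, dtype=np.float32)
    return embeddings
```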
32 | ## Learning Programs
33 |
34 | ### Training
35 |
36 | The code to learn a Transformer Program can be found in [src/run.py](src/run.py).
37 | For example, the following command will train a Transformer Program for the `sort` task, using two layers, four categorical attention heads per layer, and one-hot input embeddings:
38 | ```bash
39 | python src/run.py \
40 | --dataset "sort" \
41 | --vocab_size 8 \
42 | --dataset_size 10000 \
43 | --min_length 1 \
44 | --max_length 8 \
45 | --n_epochs 250 \
46 | --batch_size 512 \
47 | --lr "5e-2" \
48 | --n_layers 2 \
49 | --n_heads_cat 4 \
50 | --n_heads_num 0 \
51 | --n_cat_mlps 1 \
52 | --n_num_mlps 0 \
53 | --one_hot_embed \
54 | --count_only \
55 | --seed 0 \
56 | --save \
57 | --save_code \
58 | --output_dir "output/sort";
59 | ```
60 | The following command will train a Transformer Program for the CoNLL 2003 named-entity recognition task, learning input embeddings composed of four 32-dimensional categorical variables:
61 | ```bash
62 | python src/run.py \
63 | --dataset "conll_ner" \
64 | --vocab_size 10000 \
65 | --min_length 1 \
66 | --max_length 32 \
67 | --n_epochs 50 \
68 | --batch_size 32 \
69 | --lr "5e-2" \
70 | --n_vars_cat 4 \
71 | --d_var 32 \
72 | --n_layers 2 \
73 | --n_heads_cat 4 \
74 | --n_heads_num 0 \
75 | --n_cat_mlps 1 \
76 | --n_num_mlps 0 \
77 | --mlp_vars_in 2 \
78 | --count_only \
79 | --seed 0 \
80 | --replace_numbers 1 \
81 | --glove_embeddings "data/glove.840B.300d.txt" \
82 | --do_glove 1 \
83 | --save \
84 | --save_code \
85 | --output_dir "output/conll";
86 | ```
87 | Please see [src/run.py](src/run.py) for all of the possible arguments.
88 | The training data will either be generated (for the RASP tasks) or downloaded from [Hugging Face Datasets](https://huggingface.co/datasets); see [src/utils/data_utils.py](src/utils/data_utils.py) for the supported datasets.
89 | The [scripts](scripts/) directory contains scripts for training Transformer Programs and standard Transformers with the experiment settings used in the paper.
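For intuition, a generated example for a task like `sort` is just a random token sequence paired with its sorted order, wrapped in begin/end markers. The sketch below is illustrative only (it is not the repository's actual generator in `src/utils/data_utils.py`, and the marker convention is an assumption based on the example program shown later in this README):

```python
import random

def make_sort_example(vocab_size=8, max_length=8, seed=None):
    """Sketch of one `sort` training pair: "" markers wrap a random
    digit sequence; the target is the same sequence, sorted."""
    rng = random.Random(seed)
    length = rng.randint(1, max_length - 2)  # leave room for the two markers
    seq = [str(rng.randint(0, vocab_size - 3)) for _ in range(length)]
    return [""] + seq + [""], [""] + sorted(seq) + [""]
```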
90 |
91 | ### Converting to code
92 |
93 | Run the training script with the `--save_code` flag to convert the model to a Python program at the end of training.
94 | To convert a model that has already been trained, use `src/decompile.py`.
95 | For example,
96 | ```bash
97 | python src/decompile.py --path output/sort/ --output_dir programs/sort/
98 | ```
99 | `output/sort/` should be the output directory of a training run.
100 |
101 | ## Example Programs
102 |
103 | The [programs](programs/) directory contains example programs for small-scale versions of all of the [RASP tasks](https://arxiv.org/abs/2106.06981), as well as named-entity recognition.
104 | Each program defines a function called `run` that takes a sequence of tokens as input and returns a list of predicted labels.
105 | For example:
106 | ```pycon
107 | >>> from programs.rasp.sort import sort
108 | >>> sort.run(["", "3", "1", "4", "2", "4", "0", ""])
109 | ['', '0', '1', '2', '3', '4', '4', '']
110 | ```
111 | [programs/rasp](programs/rasp) contains the best-performing programs for each task, using both categorical and numerical attention heads.
112 | [programs/rasp_categorical_only](programs/rasp_categorical_only) contains the best-performing programs using only categorical variables.
113 | [programs/conll](programs/conll) contains a program for named-entity recognition.
114 |
115 | ## Questions?
116 |
117 | If you have any questions about the code or paper, please email Dan (dfriedman@cs.princeton.edu) or open an issue.
118 |
119 | ## Citation
120 |
121 | ```bibtex
122 | @inproceedings{
123 | friedman2023learning,
124 | title={Learning Transformer Programs},
125 | author={Dan Friedman and Alexander Wettig and Danqi Chen},
126 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
127 | year={2023},
128 | url={https://openreview.net/forum?id=Pe9WxkN8Ff}
129 | }
130 | ```
131 |
--------------------------------------------------------------------------------
/figures/methodology.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-nlp/TransformerPrograms/59970de542e14406d8f3d01bbd30097276959865/figures/methodology.png
--------------------------------------------------------------------------------
/programs/rasp/double_hist/double_hist_weights.csv:
--------------------------------------------------------------------------------
1 | feature,value,1,2,3,4,5,6
2 | attn_0_0_outputs,0,1.238853,1.243148,0.68027335,-4.868072,-3.0746865,-8.173675
3 | attn_0_0_outputs,1,0.31741944,0.6675796,0.59875584,-0.74731636,-1.1577901,-1.0011843
4 | attn_0_0_outputs,2,-1.0877419,1.9231943,0.3949728,-1.539528,-1.1613191,-2.8614204
5 | attn_0_0_outputs,3,-0.11632717,0.41220337,0.60158277,-0.28780693,-0.35998842,-0.59837425
6 | attn_0_0_outputs,4,-0.19129646,0.30991477,0.4910216,-0.08420599,0.5274653,-0.19311686
7 | attn_0_0_outputs,5,-0.40474758,-0.0207731,0.56750906,-0.62817794,1.2629008,0.54661673
8 | attn_0_0_outputs,6,-1.8074787,-0.30840394,2.049226,0.3926784,-1.7478626,5.0987434
9 | attn_0_0_outputs,7,-1.731275,-0.15045756,0.36238095,0.8407283,4.7304835,-2.225291
10 | attn_0_1_outputs,0,0.5155269,-0.23179123,0.9619248,0.8720239,-4.370495,-2.8444674
11 | attn_0_1_outputs,1,-0.032038573,0.15968679,-0.11178788,-0.96792173,-1.093138,-1.0469398
12 | attn_0_1_outputs,2,-0.49667308,0.20647019,0.9112648,-1.1316853,0.0978555,-0.8900808
13 | attn_0_1_outputs,3,-0.19489852,0.30834976,0.29149553,-0.5899691,-1.2079358,-0.4888331
14 | attn_0_1_outputs,4,-0.7092352,0.40705568,1.2359505,-0.14838673,-1.4845357,-0.8633547
15 | attn_0_1_outputs,5,-1.7096843,0.57668656,1.8166345,-1.3838611,2.8332915,1.320257
16 | attn_0_1_outputs,6,-1.8643975,-0.14910048,1.1025906,-0.046011463,1.486126,3.767073
17 | attn_0_1_outputs,7,-1.6024725,-0.0244187,1.0061564,-0.68429786,3.118972,-2.3170335
18 | attn_0_2_outputs,0,1.2139428,1.788951,-0.6449291,-4.5301723,-3.1953726,-8.304671
19 | attn_0_2_outputs,1,0.2632638,0.2750803,-0.058388937,-0.96432394,-1.2596529,-0.82894176
20 | attn_0_2_outputs,2,-0.06439551,0.10690294,-0.5768559,-1.3450785,-1.2679642,-1.0744632
21 | attn_0_2_outputs,3,0.071711585,0.59252673,0.50191075,-0.10376727,-0.6866497,-0.6754526
22 | attn_0_2_outputs,4,0.10581133,0.1553707,0.038040042,0.071251996,0.29624116,0.06374497
23 | attn_0_2_outputs,5,-1.8224907,0.8820499,0.9754138,-1.1711323,2.2756393,-0.019901173
24 | attn_0_2_outputs,6,-1.8428569,0.5488585,2.5900345,1.2428666,-0.56255573,5.247683
25 | attn_0_2_outputs,7,-2.4727259,0.5789163,1.2076744,1.0896817,5.294417,-2.1652772
26 | attn_0_3_outputs,0,-0.5514277,1.2344804,0.32450405,-1.3071667,-3.758229,-5.2093534
27 | attn_0_3_outputs,1,-0.5949326,0.017211447,-0.21026301,-1.1978165,-1.1930654,-1.3590801
28 | attn_0_3_outputs,2,-0.9301655,0.37480146,1.6829457,-0.75800765,-0.65382415,-0.49428207
29 | attn_0_3_outputs,3,-0.27631456,0.27953884,0.43736613,-0.38235372,-0.5264454,-0.63245666
30 | attn_0_3_outputs,4,-0.93398,0.5528436,1.5676297,0.10561387,-0.64765644,-0.28653306
31 | attn_0_3_outputs,5,-1.6023785,0.24490198,1.8999624,-1.2599331,0.92110413,0.8147831
32 | attn_0_3_outputs,6,-1.3627234,0.13494563,1.5294584,0.40429628,1.4354382,4.1456857
33 | attn_0_3_outputs,7,-1.6638067,0.09259084,1.1263107,-0.0054527083,2.0681398,-2.6293533
34 | attn_1_0_outputs,0,-0.9760603,0.9746706,-0.4910691,-0.28577244,-0.55609435,-0.629483
35 | attn_1_0_outputs,1,0.027657213,0.029981421,-0.24031596,-0.596706,-0.17725739,-0.60143006
36 | attn_1_0_outputs,2,-1.6395041,0.9376783,-0.6826759,-0.34162158,-0.6060739,-0.91161937
37 | attn_1_0_outputs,3,-1.273714,0.48423272,1.1538424,-0.8714459,-1.4124024,-1.35856
38 | attn_1_0_outputs,4,-0.8404264,0.92396635,1.5313677,-1.2234914,-3.3400934,-4.8241587
39 | attn_1_0_outputs,5,-1.8222586,0.9126125,0.56483024,-0.6948536,2.166879,1.4243178
40 | attn_1_0_outputs,6,-0.52920604,-1.4131529,-0.1339102,2.2840686,-0.79413146,4.49058
41 | attn_1_0_outputs,7,-0.37916443,-0.98024577,0.21309751,0.6570406,3.8979933,-0.9583814
42 | attn_1_1_outputs,0,1.1848046,-0.78384286,4.3952165,-4.754357,-2.5311391,-4.304711
43 | attn_1_1_outputs,1,-0.13832416,0.7882959,-0.02595803,-1.5303969,-1.0957807,-1.3929857
44 | attn_1_1_outputs,2,-0.50316894,-0.22863217,1.5112969,-0.29604048,-0.580187,-1.3095005
45 | attn_1_1_outputs,3,-0.79329056,0.42733788,0.14356261,0.6288875,-0.19792889,-0.26626664
46 | attn_1_1_outputs,4,-1.9033594,1.7780949,1.592835,-1.8940227,-2.7704701,-0.5928191
47 | attn_1_1_outputs,5,-1.9111953,0.37122542,1.2167804,-1.6839112,2.1015897,-0.7357016
48 | attn_1_1_outputs,6,-2.3154323,-0.28756234,2.4062943,2.108501,0.11289929,4.5701685
49 | attn_1_1_outputs,7,2.4386678,-1.1878088,-2.2552593,-0.20960121,7.1543226,-1.7744904
50 | attn_1_2_outputs,0,-1.1251806,1.1169194,-0.91066366,-0.7407269,-0.81425554,-0.48066366
51 | attn_1_2_outputs,1,0.45716456,0.41061336,-0.46052435,-0.34725094,-1.1503232,-0.87109935
52 | attn_1_2_outputs,2,0.013095508,0.45874515,-0.12555777,-1.040384,-1.5763316,-0.78179336
53 | attn_1_2_outputs,3,-1.289953,0.9008724,1.2994568,-0.13545798,-0.44344234,-0.953467
54 | attn_1_2_outputs,4,-1.1722677,1.0613216,1.844418,-0.7773634,-2.8786588,-4.2749085
55 | attn_1_2_outputs,5,-0.1720321,0.119764805,0.8184181,-1.9128888,1.5016356,0.541971
56 | attn_1_2_outputs,6,-1.4663309,0.59112865,0.12126205,0.59590137,-1.0372763,3.382513
57 | attn_1_2_outputs,7,-0.9167632,-0.07292843,0.44805673,0.12057245,4.2690535,-0.84026176
58 | attn_1_3_outputs,0,-0.18947232,0.93384516,2.9093647,-4.247921,-3.8325047,-1.8869388
59 | attn_1_3_outputs,1,-0.15879959,0.20900528,0.10304354,-1.2823522,-1.4563886,-1.1255486
60 | attn_1_3_outputs,2,-2.7822475,1.1226128,3.61258,-1.2871381,3.0924113,-1.6648036
61 | attn_1_3_outputs,3,-0.18244827,0.10408772,0.1082736,-0.33289492,-1.1335009,-1.1429838
62 | attn_1_3_outputs,4,-1.3203812,-0.40582213,0.47339472,2.714348,-2.8118975,-2.4057264
63 | attn_1_3_outputs,5,-2.4856753,0.31100777,1.4052066,-0.45595795,1.055359,1.5162545
64 | attn_1_3_outputs,6,-1.5411837,-0.10812548,1.6073796,0.0880546,0.5234726,4.230394
65 | attn_1_3_outputs,7,-2.2810664,0.3371949,1.277099,1.0733384,5.143846,-1.5911686
66 | attn_2_0_outputs,0,-0.7608168,0.73692936,2.2746222,0.09226601,-3.9912004,-4.357786
67 | attn_2_0_outputs,1,0.73612237,0.4157502,0.40716407,-1.2919782,-0.75344867,-1.2904367
68 | attn_2_0_outputs,2,-0.43750256,-0.31478426,1.2700193,-1.3610334,-0.40046877,-1.7192286
69 | attn_2_0_outputs,3,-0.14183228,0.46041623,0.7322802,-0.7410107,-0.67563444,-0.98685783
70 | attn_2_0_outputs,4,-1.6794555,0.2764951,2.2788801,-1.535832,-1.5119663,-1.3937936
71 | attn_2_0_outputs,5,-1.122806,0.35427764,1.8529143,-0.08506123,0.5083381,0.70440453
72 | attn_2_0_outputs,6,-1.3130459,0.24721187,1.5720521,0.6670818,1.2480316,5.339757
73 | attn_2_0_outputs,7,-1.756592,-0.3179971,0.9707079,0.3391206,4.742656,-0.77619433
74 | attn_2_1_outputs,0,5.1843696,-0.14824952,-7.7279587,-0.18898705,-7.4058304,-5.8876233
75 | attn_2_1_outputs,1,0.07970304,0.33187732,0.47960332,-1.1841625,-1.2469735,-0.8724654
76 | attn_2_1_outputs,2,0.1316852,0.07690099,0.9332321,-1.4709224,-0.95776975,-0.77643555
77 | attn_2_1_outputs,3,0.17687394,-0.45715567,0.9835989,0.15518622,-0.95674294,-0.68985116
78 | attn_2_1_outputs,4,-0.37534592,-1.1890558,0.37393364,0.37020233,-0.8960392,-0.8327007
79 | attn_2_1_outputs,5,0.011616593,-0.3313448,0.8376248,-1.2172726,1.578385,1.6046035
80 | attn_2_1_outputs,6,-3.7970808,2.6305835,-0.4534069,0.8028219,-3.9948761,8.23428
81 | attn_2_1_outputs,7,-6.8515778,-0.8602876,8.407036,-1.1475514,11.212976,-2.1321476
82 | attn_2_2_outputs,0,-1.3947915,0.82383966,0.19460057,0.4075744,-0.50466233,0.38375315
83 | attn_2_2_outputs,1,0.36330578,0.78628147,0.20042144,0.08738863,0.14224096,-0.08429496
84 | attn_2_2_outputs,2,-1.4779562,1.399697,0.86861706,0.1598239,0.75133777,0.3097304
85 | attn_2_2_outputs,3,-1.2577441,0.014175095,1.141915,1.2425467,-0.23655094,-0.7029587
86 | attn_2_2_outputs,4,-0.98365426,0.30592975,1.760364,-3.4134092,-4.0350323,-6.1376657
87 | attn_2_2_outputs,5,-1.975907,1.060878,0.7518102,-1.0094594,1.415349,0.56762004
88 | attn_2_2_outputs,6,-0.56342816,-1.4276974,0.8616133,1.3711227,-0.664221,2.6826735
89 | attn_2_2_outputs,7,-1.6278553,-0.31998882,1.4476215,-0.50721836,2.9143424,-0.6652316
90 | attn_2_3_outputs,0,-0.10320366,1.2831827,1.0159588,-4.5068374,-3.8469913,-3.0415845
91 | attn_2_3_outputs,1,0.9415238,0.6988879,-0.58729637,-1.3218888,-1.1747098,-1.3971877
92 | attn_2_3_outputs,2,-1.1604211,0.40635908,2.708663,-3.707528,2.3209622,-3.0766304
93 | attn_2_3_outputs,3,-0.8133805,-0.66232896,0.031035835,-0.24328968,-0.7707727,-0.6527309
94 | attn_2_3_outputs,4,-1.1196601,-0.2392621,1.2020575,1.8583574,-3.9654202,-3.1436868
95 | attn_2_3_outputs,5,-1.373232,0.14063334,0.9891212,-0.32424918,1.4090708,0.9559994
96 | attn_2_3_outputs,6,-1.7961397,-0.22981627,1.3094918,0.7286974,0.2208126,4.3420715
97 | attn_2_3_outputs,7,-1.9815708,-0.68765175,1.3021115,0.6383351,4.8859954,-1.5050861
98 | mlp_0_0_outputs,0,-0.80427796,-0.5239094,0.64130455,-0.8501387,0.88480604,0.8062301
99 | mlp_0_0_outputs,1,-0.703996,-0.06041357,1.1169733,-0.21612476,-0.56217194,-0.062908545
100 | mlp_0_0_outputs,2,-0.5470064,0.0647984,0.9857732,-0.7440401,1.3193768,0.47770888
101 | mlp_0_0_outputs,3,-1.2310531,-0.66088575,1.5901628,0.08961851,0.6930918,-1.6064136
102 | mlp_0_0_outputs,4,0.6362144,0.93965715,-1.6470672,-0.77374464,-4.137099,-3.0756845
103 | mlp_0_0_outputs,5,-0.69935584,-0.5942387,0.61135715,-0.33549237,0.9449682,-0.04219789
104 | mlp_0_0_outputs,6,-0.81338984,-0.004528439,1.7929608,-1.1758713,-1.5474731,-0.6041481
105 | mlp_0_0_outputs,7,-1.3381019,-0.25485164,0.9574129,-0.5505601,0.84711146,-1.2687894
106 | mlp_1_0_outputs,0,-0.35261238,-0.488968,0.317209,-0.353129,-0.29976666,-0.22690473
107 | mlp_1_0_outputs,1,10.887252,-2.810294,-6.033218,-4.187805,-3.888413,-2.4883578
108 | mlp_1_0_outputs,2,-0.18418525,-0.3537582,0.4354651,-0.05020326,0.2935266,0.3063467
109 | mlp_1_0_outputs,3,-4.5915027,-0.70305747,7.38149,0.14982294,-0.30977443,-0.60969657
110 | mlp_1_0_outputs,4,-0.2931938,-0.42586467,0.6560199,-0.23105785,-0.093104005,0.11018199
111 | mlp_1_0_outputs,5,0.026697708,-0.14337309,0.4799277,-0.05001711,0.34555975,0.35481232
112 | mlp_1_0_outputs,6,0.028377075,-0.24432565,0.54357,-0.14310212,0.11082643,-0.18106437
113 | mlp_1_0_outputs,7,7.758,0.051592894,-1.7972865,-0.4951801,-0.56271195,-1.6264058
114 | mlp_2_0_outputs,0,-0.07149683,0.17804867,0.027318498,-0.47190544,0.2526821,0.25335145
115 | mlp_2_0_outputs,1,-0.029524894,-0.124501094,-0.280803,-0.6324781,-0.11064398,0.20567767
116 | mlp_2_0_outputs,2,-0.36568528,-0.18049608,1.1864383,2.4141302,-1.3685066,-2.4289098
117 | mlp_2_0_outputs,3,-0.08572693,0.095609844,1.3656018,2.4667256,-0.62999135,-0.9260636
118 | mlp_2_0_outputs,4,-0.25557655,-0.08844128,-0.011158431,-0.67957497,-0.13888906,0.057696126
119 | mlp_2_0_outputs,5,-0.57896894,-0.4964491,1.1442068,0.18332906,-0.40583086,-1.8021346
120 | mlp_2_0_outputs,6,-0.14598659,-0.25447085,-0.09391215,-1.0940522,-0.26641744,0.085448615
121 | mlp_2_0_outputs,7,-1.6514058,0.5661072,2.6141307,-0.58939695,0.0069024065,-0.47796342
122 | num_attn_0_0_outputs,_,1.6166081,0.65979415,0.02641497,-3.4313304,-4.040161,-2.7888904
123 | num_attn_0_1_outputs,_,-0.24869336,-0.6833789,0.9426215,0.3185697,1.193121,0.52059203
124 | num_attn_0_2_outputs,_,-2.9678483,1.6683826,1.7130572,-1.6549467,2.552361,-0.4446513
125 | num_attn_0_3_outputs,_,-2.0443754,0.47501704,0.6353057,0.45521334,-0.075898685,-0.39096996
126 | num_attn_1_0_outputs,_,1.4484639,-0.9826932,-2.5024445,-4.524883,9.575657,-1.2793733
127 | num_attn_1_1_outputs,_,1.5601239,0.16244218,-2.2383647,-2.335401,-0.31956533,1.361603
128 | num_attn_1_2_outputs,_,-0.75736904,0.11321503,-0.82287806,0.718299,0.71823144,1.0727578
129 | num_attn_1_3_outputs,_,2.7385015,1.4961662,-1.1917236,-1.0514781,-6.2250605,-9.062402
130 | num_attn_2_0_outputs,_,-15.748312,-2.3748958,2.378052,5.5488276,6.0594873,5.684393
131 | num_attn_2_1_outputs,_,0.85498565,0.80901504,-0.25188714,-0.95084816,-4.0991917,-3.6710646
132 | num_attn_2_2_outputs,_,2.5925405,0.43864623,-0.7539673,-1.6541822,-4.0720997,-8.631772
133 | num_attn_2_3_outputs,_,4.2104626,-0.2540275,-5.9844904,-1.7332941,-0.9908037,1.0950465
134 | num_mlp_0_0_outputs,0,-1.8669392,-0.2961063,0.98103577,1.4933444,2.198966,1.1375257
135 | num_mlp_0_0_outputs,1,-1.1520817,-1.1762809,0.09541475,0.8781277,1.2588459,0.8453051
136 | num_mlp_0_0_outputs,2,-0.9004866,-0.8049949,0.11854656,0.39003035,0.67636615,0.64462405
137 | num_mlp_0_0_outputs,3,8.561097,-1.1257874,-3.453622,0.69200045,1.3708919,-0.49515036
138 | num_mlp_0_0_outputs,4,-2.8854432,0.90047354,6.3888383,-5.94705,-7.024973,-3.1921215
139 | num_mlp_0_0_outputs,5,-0.6971181,-0.7816255,0.16676763,0.96413225,1.1872518,0.85966665
140 | num_mlp_0_0_outputs,6,-0.26123014,-0.059790544,0.3098392,0.6849877,0.6764854,0.49069953
141 | num_mlp_0_0_outputs,7,-0.53684187,0.012539488,-0.02403504,0.5809876,0.40013784,0.6112457
142 | num_mlp_1_0_outputs,0,-0.16077685,-1.2207857,0.23239522,1.5080422,-0.60546786,1.1818341
143 | num_mlp_1_0_outputs,1,-0.6600274,0.16164501,0.88917744,1.1062784,0.20580845,1.2902285
144 | num_mlp_1_0_outputs,2,-0.07932499,0.108578734,0.19467932,0.4703639,-0.11713662,0.91529095
145 | num_mlp_1_0_outputs,3,-0.57606816,-0.23965557,0.13980961,0.46052232,-0.27391964,0.45012388
146 | num_mlp_1_0_outputs,4,-0.7861332,-0.52138984,-0.09231762,0.64113474,0.025483737,1.0184015
147 | num_mlp_1_0_outputs,5,-0.713504,0.54386413,1.3562707,-2.5753596,-1.427283,-6.9030857
148 | num_mlp_1_0_outputs,6,-0.50802475,-1.2871871,0.537766,2.643965,-1.3457818,2.97455
149 | num_mlp_1_0_outputs,7,2.0027044,-0.7188986,-4.4494953,3.1136823,0.0020262096,0.36739042
150 | num_mlp_2_0_outputs,0,-0.13204575,-0.11307499,-0.3766279,-0.07198191,0.30587465,0.33915785
151 | num_mlp_2_0_outputs,1,-1.0364666,2.084065,-0.50771797,-3.1504862,-0.30198094,0.096702345
152 | num_mlp_2_0_outputs,2,-8.098701,4.3463387,-2.0458136,5.9026895,3.121162,3.0957198
153 | num_mlp_2_0_outputs,3,-3.605029,3.1863525,2.0602572,-2.160901,-7.234171,-0.6112598
154 | num_mlp_2_0_outputs,4,1.5672542,-0.33181366,-4.4696174,-0.4666259,4.466716,-4.639328
155 | num_mlp_2_0_outputs,5,1.5313199,-2.517952,1.0063609,0.32012826,-1.8395686,-0.6061035
156 | num_mlp_2_0_outputs,6,0.21038409,2.1400955,-0.90267885,-4.3364935,-1.1109636,0.27975422
157 | num_mlp_2_0_outputs,7,1.5466431,2.9550323,-2.1725159,-3.0149686,-2.3728058,-0.8228693
158 | ones,_,-0.63681555,0.6360205,1.2210147,-0.1677267,-1.8965924,-0.9514799
159 | positions,0,0.31322083,-0.18008727,-0.17022493,0.5562112,-0.3243463,-0.33771023
160 | positions,1,-0.164062,0.3669627,1.1244484,0.6322126,-1.4446061,-1.479529
161 | positions,2,-0.76111907,0.4353367,1.4015459,-0.34046993,-1.0691034,-0.5398579
162 | positions,3,-1.0637239,-0.030700902,1.1874292,-1.0612618,-0.95897645,-1.2181816
163 | positions,4,-0.79604656,0.28970852,1.2973528,-0.36234856,-0.75552547,0.12908255
164 | positions,5,-1.1553457,0.0699891,1.1624489,-0.49714485,1.4032562,0.39587685
165 | positions,6,-0.6900308,0.2557735,1.3243333,-0.9535796,-0.028765034,0.8872918
166 | positions,7,-0.88872874,0.07091375,1.42425,-0.60613817,-1.0351824,-2.0732577
167 | tokens,0,-0.8381742,0.11838423,1.1716859,-0.34598798,-1.2162316,-1.9089854
168 | tokens,1,-0.6593003,0.58185554,1.4920444,-0.22869995,-0.9598005,-1.7809644
169 | tokens,2,-1.02963,0.06541627,1.0584888,-0.6785367,-1.1368308,-2.131806
170 | tokens,3,-0.7562651,0.465944,1.4615022,-0.5278057,-0.5604278,-0.8920955
171 | tokens,4,-0.56911784,0.50735754,1.4724189,-0.20176388,-0.87996334,-0.4631814
172 | tokens,5,-0.85632926,0.32627252,1.2416458,-0.29448336,-0.47716865,0.26062497
173 | tokens,,0.04131505,-0.46409363,0.15720077,-0.5609889,0.16757916,0.34972325
174 | tokens,,0.10905723,-0.33392945,0.07082459,-0.31309077,0.42329115,-0.34247398
175 |
--------------------------------------------------------------------------------
/programs/rasp/dyck1/dyck1_weights.csv:
--------------------------------------------------------------------------------
1 | feature,value,F,P,T
2 | attn_0_0_outputs,(,-2.6340582,3.296491,-0.44525737
3 | attn_0_0_outputs,),2.815133,-2.8173714,-0.077561244
4 | attn_0_0_outputs,,0.31516382,0.6258608,0.51490164
5 | attn_0_0_outputs,,-0.7577929,1.2806988,0.021066613
6 | attn_0_0_outputs,_,0.1783087,-0.022254534,-0.33423853
7 | attn_0_0_outputs,_,0.10090828,-0.2582165,0.1490271
8 | attn_0_0_outputs,_,-0.15222801,-0.89355934,0.5346948
9 | attn_0_0_outputs,_,0.29934147,-0.37996712,-0.39817178
10 | attn_0_0_outputs,_,-0.16398807,0.25323474,-0.46668038
11 | attn_0_0_outputs,_,0.028858155,0.07219034,-0.06156428
12 | attn_0_0_outputs,_,0.06055783,0.2945571,-0.19819283
13 | attn_0_0_outputs,_,-0.25926435,-0.16323452,0.4009143
14 | attn_0_0_outputs,_,0.52586144,0.20267111,0.8969524
15 | attn_0_0_outputs,_,0.22868425,0.1599105,-0.37888685
16 | attn_0_0_outputs,_,0.22064249,0.11402117,0.7742326
17 | attn_0_0_outputs,_,-0.24329506,0.19571169,-0.46124965
18 | attn_0_1_outputs,(,-0.36807892,1.1088231,0.27818483
19 | attn_0_1_outputs,),-0.41888747,-0.09075923,-0.527925
20 | attn_0_1_outputs,,-0.85293067,-0.5840806,-0.40153265
21 | attn_0_1_outputs,,2.7544026,0.20201652,-2.4185522
22 | attn_0_1_outputs,_,-0.3167692,-0.082709216,0.993802
23 | attn_0_1_outputs,_,0.80444956,-1.1216061,-0.76241434
24 | attn_0_1_outputs,_,0.90439874,-0.15584756,-0.8316175
25 | attn_0_1_outputs,_,0.6542889,0.48348022,0.032383945
26 | attn_0_1_outputs,_,0.5817215,-0.064832136,0.2867886
27 | attn_0_1_outputs,_,0.49341214,-0.33040112,-0.03182304
28 | attn_0_1_outputs,_,-0.083858415,0.5024774,0.1928215
29 | attn_0_1_outputs,_,0.3615586,0.16559708,-0.07771133
30 | attn_0_1_outputs,_,1.004799,-0.7025519,-0.46842074
31 | attn_0_1_outputs,_,0.43362275,0.28161952,-0.0104415575
32 | attn_0_1_outputs,_,-0.31160653,0.23411115,-0.57747537
33 | attn_0_1_outputs,_,-0.37716666,0.5208574,-0.76324815
34 | attn_0_2_outputs,(,-0.030860722,1.7733968,-1.1154492
35 | attn_0_2_outputs,),0.8185658,0.98675424,-0.65843534
36 | attn_0_2_outputs,,-0.18944798,1.2176713,-0.42110875
37 | attn_0_2_outputs,,-0.34205785,-1.2517091,-0.5264156
38 | attn_0_2_outputs,_,-0.66757905,0.14454332,-0.24297304
39 | attn_0_2_outputs,_,0.88671744,0.62714237,0.2174341
40 | attn_0_2_outputs,_,-0.100170776,0.9241575,-0.67759424
41 | attn_0_2_outputs,_,-0.258302,1.2594484,-0.8207846
42 | attn_0_2_outputs,_,0.7526842,0.39307418,-0.562563
43 | attn_0_2_outputs,_,-0.6082082,0.36667,0.17507246
44 | attn_0_2_outputs,_,0.16467999,0.74238044,-0.7653758
45 | attn_0_2_outputs,_,-0.4535027,0.41556233,0.13512231
46 | attn_0_2_outputs,_,0.28487095,0.34868884,-0.17336306
47 | attn_0_2_outputs,_,-0.30027777,0.45009154,-0.021399405
48 | attn_0_2_outputs,_,-0.115277715,0.6226942,0.47862878
49 | attn_0_2_outputs,_,0.3414086,-0.04924957,-0.71807915
50 | attn_0_3_outputs,(,-0.52774686,1.061042,-1.3247046
51 | attn_0_3_outputs,),0.53207314,-0.51257735,0.37381876
52 | attn_0_3_outputs,,1.2845441,-0.9961886,-1.6160716
53 | attn_0_3_outputs,,-0.07281321,0.24826683,0.22082955
54 | attn_0_3_outputs,_,-0.40482992,-0.7973792,-0.39291495
55 | attn_0_3_outputs,_,0.6773935,-0.1490348,-0.9378139
56 | attn_0_3_outputs,_,-2.6586642,0.3645929,1.424275
57 | attn_0_3_outputs,_,0.5714178,-0.15675014,-0.5435661
58 | attn_0_3_outputs,_,-0.3804901,1.0324043,0.24484947
59 | attn_0_3_outputs,_,0.34369984,0.27104208,-0.39457268
60 | attn_0_3_outputs,_,0.7677406,0.23606537,-0.36674848
61 | attn_0_3_outputs,_,0.72403216,0.94206196,-0.33496436
62 | attn_0_3_outputs,_,-0.606932,0.45185584,-0.014735105
63 | attn_0_3_outputs,_,-0.47034562,0.1342081,0.53295195
64 | attn_0_3_outputs,_,0.14294454,-0.304042,0.20307107
65 | attn_0_3_outputs,_,0.3230739,-0.42828202,-1.0465811
66 | attn_1_0_outputs,(,-0.6138929,1.1492342,-0.21226932
67 | attn_1_0_outputs,),1.1421751,0.88136363,0.98169017
68 | attn_1_0_outputs,,0.4028954,-0.85677516,-0.023720358
69 | attn_1_0_outputs,,1.0315775,-0.8483394,-1.2852764
70 | attn_1_0_outputs,_,0.15765458,-0.6344901,-0.5549551
71 | attn_1_0_outputs,_,-0.05087188,0.29270756,-0.5156284
72 | attn_1_0_outputs,_,1.1122603,-0.8028311,-0.06559855
73 | attn_1_0_outputs,_,-0.48832175,-0.15441939,-0.2644985
74 | attn_1_0_outputs,_,-0.37901145,-0.03420561,-0.5548248
75 | attn_1_0_outputs,_,0.05604345,-0.97195405,1.0271589
76 | attn_1_0_outputs,_,0.55539984,0.057354666,-0.4074509
77 | attn_1_0_outputs,_,-0.77038145,-0.09844211,0.07271357
78 | attn_1_0_outputs,_,0.52502674,-0.008494564,0.5783807
79 | attn_1_0_outputs,_,-0.7112233,0.87106544,0.24254535
80 | attn_1_0_outputs,_,0.29931223,-0.7983458,-0.2591785
81 | attn_1_0_outputs,_,1.1367513,-1.1649325,-0.25845855
82 | attn_1_1_outputs,(,-0.5425992,1.5051843,-0.28839168
83 | attn_1_1_outputs,),1.5574366,-1.0693368,0.14204933
84 | attn_1_1_outputs,,0.58851004,0.6008777,-0.07114553
85 | attn_1_1_outputs,,-0.32155347,-1.0976663,-1.4935772
86 | attn_1_1_outputs,_,0.17493103,-0.38030365,0.53236455
87 | attn_1_1_outputs,_,1.0686699,-0.120278545,-0.56144714
88 | attn_1_1_outputs,_,1.3084234,0.3557633,-0.14442345
89 | attn_1_1_outputs,_,0.47586453,0.20887673,-0.15990862
90 | attn_1_1_outputs,_,0.13891101,0.68185794,-0.5525119
91 | attn_1_1_outputs,_,-0.6078569,1.2303034,-2.18228
92 | attn_1_1_outputs,_,0.91525626,-0.67010164,-0.21897195
93 | attn_1_1_outputs,_,1.4655004,-0.47915524,0.97498626
94 | attn_1_1_outputs,_,1.1796371,-1.6545213,-0.45778522
95 | attn_1_1_outputs,_,-0.7026438,0.5664589,1.0344129
96 | attn_1_1_outputs,_,0.5785904,0.5293408,0.37516353
97 | attn_1_1_outputs,_,2.3779967,-2.060698,-1.0075454
98 | attn_1_2_outputs,(,-0.056032557,0.6231306,0.24149144
99 | attn_1_2_outputs,),0.44365737,-0.98022085,-0.64319813
100 | attn_1_2_outputs,,0.08831471,-0.29584908,-0.15343662
101 | attn_1_2_outputs,,-1.3659412,3.5988708,-0.66621685
102 | attn_1_2_outputs,_,-0.1274261,0.34521273,0.027675765
103 | attn_1_2_outputs,_,-0.14180441,1.1854141,-0.8122686
104 | attn_1_2_outputs,_,0.4875018,0.45869988,0.1374755
105 | attn_1_2_outputs,_,-0.32931647,0.87830615,-0.49895287
106 | attn_1_2_outputs,_,-0.1785307,0.16045158,-0.8639972
107 | attn_1_2_outputs,_,0.5269155,1.2485906,-1.8244286
108 | attn_1_2_outputs,_,-0.383124,-0.5021012,-0.55019313
109 | attn_1_2_outputs,_,-0.092871554,0.22963502,0.12724838
110 | attn_1_2_outputs,_,-0.8239071,0.94725597,0.037750408
111 | attn_1_2_outputs,_,0.34750172,0.9910317,-0.7895062
112 | attn_1_2_outputs,_,-0.66590446,0.40987346,0.37104082
113 | attn_1_2_outputs,_,-0.052774414,-0.43269822,-0.5298429
114 | attn_1_3_outputs,(,-0.81292146,1.8696464,-0.5263783
115 | attn_1_3_outputs,),-0.19519545,-0.15165852,-0.19049717
116 | attn_1_3_outputs,,-0.3622958,0.81331193,0.30731088
117 | attn_1_3_outputs,,0.23445317,0.015669065,-1.4538736
118 | attn_1_3_outputs,_,-0.35138428,0.28486276,-0.16167411
119 | attn_1_3_outputs,_,-0.45811188,0.43144393,0.20147736
120 | attn_1_3_outputs,_,0.5497336,0.15336615,0.7609401
121 | attn_1_3_outputs,_,1.5862544,0.9128603,-1.2414508
122 | attn_1_3_outputs,_,0.37166864,0.21070029,-0.18914412
123 | attn_1_3_outputs,_,0.12535681,0.1599941,0.54591626
124 | attn_1_3_outputs,_,-0.6368585,0.37556642,-0.5828852
125 | attn_1_3_outputs,_,-0.935505,1.2570758,0.15145183
126 | attn_1_3_outputs,_,0.5280753,0.11495201,-0.31024593
127 | attn_1_3_outputs,_,-0.3986401,0.9080014,0.763646
128 | attn_1_3_outputs,_,1.2436559,-0.40404543,-1.3131849
129 | attn_1_3_outputs,_,0.686876,-0.80954635,-0.45231095
130 | attn_2_0_outputs,(,-1.5561209,1.4590123,-0.57100546
131 | attn_2_0_outputs,),0.9308131,-1.3626683,-1.3802326
132 | attn_2_0_outputs,,-0.9678551,0.26538986,-0.8596553
133 | attn_2_0_outputs,,0.3270525,1.4117384,-0.017223258
134 | attn_2_0_outputs,_,0.10557971,1.4797398,-0.36684495
135 | attn_2_0_outputs,_,0.4341196,-0.31931913,1.5029707
136 | attn_2_0_outputs,_,-0.2973821,-0.19198877,-0.1289013
137 | attn_2_0_outputs,_,-0.36797506,0.47559232,0.19091709
138 | attn_2_0_outputs,_,-0.4748216,0.026234714,0.63951665
139 | attn_2_0_outputs,_,0.94507056,0.34140712,0.31623304
140 | attn_2_0_outputs,_,-0.028984,-0.26938406,0.25619373
141 | attn_2_0_outputs,_,-0.06181845,-0.548181,0.2155078
142 | attn_2_0_outputs,_,0.3375238,0.1336795,-0.2824661
143 | attn_2_0_outputs,_,0.575439,0.9426243,0.64006364
144 | attn_2_0_outputs,_,0.074084885,0.34421614,0.43613386
145 | attn_2_0_outputs,_,0.47341222,0.08374641,0.17351331
146 | attn_2_1_outputs,(,-0.45127738,1.0018231,0.44534314
147 | attn_2_1_outputs,),-0.6839655,0.074150436,-0.2715587
148 | attn_2_1_outputs,,-0.70120454,-0.0535122,0.48966962
149 | attn_2_1_outputs,,-0.2806599,-0.20490536,-0.7978403
150 | attn_2_1_outputs,_,0.20717245,-0.7625038,-0.5953323
151 | attn_2_1_outputs,_,0.15005049,-1.1470549,-0.4587043
152 | attn_2_1_outputs,_,0.9465549,-0.71227086,-0.25911647
153 | attn_2_1_outputs,_,-0.3448204,-0.19821261,-0.5702802
154 | attn_2_1_outputs,_,0.16869305,0.3031439,-0.11745471
155 | attn_2_1_outputs,_,-0.4607347,1.3273661,-0.5856749
156 | attn_2_1_outputs,_,-0.3482659,-0.6055313,-1.2954234
157 | attn_2_1_outputs,_,0.6048929,0.2115645,0.3289516
158 | attn_2_1_outputs,_,-0.19371964,-0.4275286,-0.74410176
159 | attn_2_1_outputs,_,-0.13987288,0.3069642,0.38531253
160 | attn_2_1_outputs,_,0.039333925,0.6699897,0.034916207
161 | attn_2_1_outputs,_,1.5877175,-0.9881175,0.006162866
162 | attn_2_2_outputs,(,-2.750019,2.7980466,2.5939832
163 | attn_2_2_outputs,),3.6954346,-1.7907751,-2.5867593
164 | attn_2_2_outputs,,-0.079127155,-0.9472731,-1.1885493
165 | attn_2_2_outputs,,4.0218706,-3.1554546,-4.7507057
166 | attn_2_2_outputs,_,0.9610852,-0.08878815,0.08363015
167 | attn_2_2_outputs,_,-0.08353972,0.37381604,-0.28025633
168 | attn_2_2_outputs,_,0.35657725,0.8969886,0.072851464
169 | attn_2_2_outputs,_,-0.32833374,-0.7531067,0.3516615
170 | attn_2_2_outputs,_,0.9393161,-0.5649734,-0.5196474
171 | attn_2_2_outputs,_,0.08626939,-1.1392281,1.6193309
172 | attn_2_2_outputs,_,1.2216192,-0.21259107,-1.3641104
173 | attn_2_2_outputs,_,-0.09241636,0.27687886,-0.16864115
174 | attn_2_2_outputs,_,0.49919057,0.65654963,0.21979749
175 | attn_2_2_outputs,_,-0.7234905,-0.14355081,0.40794903
176 | attn_2_2_outputs,_,0.72649467,0.1041588,0.011536894
177 | attn_2_2_outputs,_,0.27901632,0.6072106,-0.4015586
178 | attn_2_3_outputs,(,-1.6574559,1.2850804,0.09690388
179 | attn_2_3_outputs,),-0.37473014,0.7424306,0.21433578
180 | attn_2_3_outputs,,-1.5348345,1.6227428,1.3219541
181 | attn_2_3_outputs,,0.25438994,-0.1620961,-0.8211459
182 | attn_2_3_outputs,_,-0.051882792,0.7923025,-0.18180363
183 | attn_2_3_outputs,_,-0.6806416,0.61710155,-0.14199154
184 | attn_2_3_outputs,_,0.82735395,0.37269086,0.15985933
185 | attn_2_3_outputs,_,0.5044491,0.6499633,-0.1310748
186 | attn_2_3_outputs,_,0.3183162,-0.5054114,-0.007711913
187 | attn_2_3_outputs,_,-0.33328158,0.29277143,0.18589394
188 | attn_2_3_outputs,_,0.047428608,0.4182818,0.6526039
189 | attn_2_3_outputs,_,-1.0801055,0.58600456,0.11516155
190 | attn_2_3_outputs,_,0.04333351,0.2217204,0.12683378
191 | attn_2_3_outputs,_,-1.0841376,0.9090559,-0.7465817
192 | attn_2_3_outputs,_,-0.87214464,0.8850236,0.20182282
193 | attn_2_3_outputs,_,-0.46616295,-0.37282166,0.15026246
194 | mlp_0_0_outputs,0,-0.609197,-0.027131865,-0.22918989
195 | mlp_0_0_outputs,1,0.556234,-0.03605403,-0.371482
196 | mlp_0_0_outputs,10,-1.310118,3.0756617,0.20343237
197 | mlp_0_0_outputs,11,-0.04378883,-0.092148624,-0.11734988
198 | mlp_0_0_outputs,12,-0.90647274,-0.1498387,-0.19741559
199 | mlp_0_0_outputs,13,0.28161588,0.206903,-0.98625535
200 | mlp_0_0_outputs,14,0.5062789,-0.39951432,-0.25588968
201 | mlp_0_0_outputs,15,0.15142694,0.89919674,0.6030912
202 | mlp_0_0_outputs,2,0.72293925,-0.8654614,0.5670439
203 | mlp_0_0_outputs,3,0.7204033,0.15438499,0.8819249
204 | mlp_0_0_outputs,4,-0.4914485,0.48221803,-0.026991796
205 | mlp_0_0_outputs,5,-0.2592103,-0.1592593,-0.34026244
206 | mlp_0_0_outputs,6,-0.26047266,0.2415719,0.13886715
207 | mlp_0_0_outputs,7,0.13437562,0.109038174,-0.27345678
208 | mlp_0_0_outputs,8,0.15529574,0.25750098,0.09392571
209 | mlp_0_0_outputs,9,-2.1050072,0.8252483,-0.97018075
210 | mlp_1_0_outputs,0,0.28741208,0.35307455,0.46118635
211 | mlp_1_0_outputs,1,-0.02276508,0.610253,-0.40055403
212 | mlp_1_0_outputs,10,0.19024082,1.1057127,-0.119028404
213 | mlp_1_0_outputs,11,-0.16726859,0.8911886,0.06611366
214 | mlp_1_0_outputs,12,-0.2248021,0.68136096,-0.19288228
215 | mlp_1_0_outputs,13,-0.2775437,0.17751533,-0.944292
216 | mlp_1_0_outputs,14,-0.38200572,0.5241664,-0.081385374
217 | mlp_1_0_outputs,15,-0.5838907,0.96178657,-1.1940188
218 | mlp_1_0_outputs,2,-0.398339,0.6708183,-0.21698345
219 | mlp_1_0_outputs,3,-0.3026861,0.12032325,-0.1728985
220 | mlp_1_0_outputs,4,-2.0897841,2.2824123,0.1175223
221 | mlp_1_0_outputs,5,-0.35795817,0.78761435,0.056806456
222 | mlp_1_0_outputs,6,-1.2497257,0.2776687,-0.51030177
223 | mlp_1_0_outputs,7,-0.025528096,0.9015349,0.009612377
224 | mlp_1_0_outputs,8,-1.0789367,0.115357794,-0.4832368
225 | mlp_1_0_outputs,9,0.14225361,0.18992919,0.26218814
226 | mlp_2_0_outputs,0,-0.9600631,-0.0905738,0.4139894
227 | mlp_2_0_outputs,1,-1.2516974,0.26433572,-0.093466
228 | mlp_2_0_outputs,10,-0.4429331,1.2817378,1.3716829
229 | mlp_2_0_outputs,11,-0.8686043,0.63608295,0.88492036
230 | mlp_2_0_outputs,12,0.5736482,0.27950564,-0.3706815
231 | mlp_2_0_outputs,13,-1.066069,0.9496548,0.7379837
232 | mlp_2_0_outputs,14,1.1608202,-1.783248,-1.4169054
233 | mlp_2_0_outputs,15,-0.8355681,0.7468592,0.7600724
234 | mlp_2_0_outputs,2,0.41237113,0.16725028,-0.3574132
235 | mlp_2_0_outputs,3,-0.86503994,1.2446153,0.70891947
236 | mlp_2_0_outputs,4,-0.27445322,1.5090894,1.6432297
237 | mlp_2_0_outputs,5,1.4590892,-0.91437083,-0.60318696
238 | mlp_2_0_outputs,6,-1.553053,0.7144941,0.7145207
239 | mlp_2_0_outputs,7,-1.256165,0.2772562,0.011028548
240 | mlp_2_0_outputs,8,0.7696337,-1.3811562,0.09975987
241 | mlp_2_0_outputs,9,6.067684,-2.7872226,-8.341669
242 | num_attn_0_0_outputs,_,3.7398744,-7.181635,3.5538623
243 | num_attn_0_1_outputs,_,0.49548444,0.100667655,-0.16136481
244 | num_attn_0_2_outputs,_,2.1738443,-1.07889,-1.8416303
245 | num_attn_0_3_outputs,_,0.691551,-0.5990272,0.014788681
246 | num_attn_1_0_outputs,_,2.9874318,-2.4701202,-0.3135344
247 | num_attn_1_1_outputs,_,0.12491817,-0.38467312,0.048041295
248 | num_attn_1_2_outputs,_,0.3216399,0.071806945,0.3041418
249 | num_attn_1_3_outputs,_,-1.1208845,1.7802722,0.1500264
250 | num_attn_2_0_outputs,_,1.0652643,0.16920693,0.12847312
251 | num_attn_2_1_outputs,_,0.2974014,-1.6822767,-0.6943763
252 | num_attn_2_2_outputs,_,1.2344797,-1.0337117,-0.39935145
253 | num_attn_2_3_outputs,_,0.68644404,0.025850799,0.31126392
254 | num_mlp_0_0_outputs,0,1.0641823,-0.54406065,0.1806945
255 | num_mlp_0_0_outputs,1,0.28959027,-1.0161468,-0.6264937
256 | num_mlp_0_0_outputs,10,-0.043870565,0.95602506,0.09602879
257 | num_mlp_0_0_outputs,11,0.17968851,0.8659425,0.06156473
258 | num_mlp_0_0_outputs,12,-0.48285016,0.5750547,-0.2724834
259 | num_mlp_0_0_outputs,13,0.37540218,-0.24070558,-0.11528563
260 | num_mlp_0_0_outputs,14,-0.41391063,0.54637676,-0.13351686
261 | num_mlp_0_0_outputs,15,-0.29117483,0.471855,-0.28612953
262 | num_mlp_0_0_outputs,2,-0.20709011,0.96985,0.14005361
263 | num_mlp_0_0_outputs,3,-0.7341656,-0.06500526,-0.43360725
264 | num_mlp_0_0_outputs,4,-0.7252949,0.9496459,0.97981936
265 | num_mlp_0_0_outputs,5,-0.46789286,0.19887882,-0.40079117
266 | num_mlp_0_0_outputs,6,-0.11914711,0.50922734,-0.12805091
267 | num_mlp_0_0_outputs,7,-0.5103994,-0.83839464,-0.7605896
268 | num_mlp_0_0_outputs,8,0.0003218591,0.4606981,-0.7254237
269 | num_mlp_0_0_outputs,9,0.12420503,-1.4992923,-1.0763122
270 | num_mlp_1_0_outputs,0,-0.50804985,-0.048254747,-0.50164807
271 | num_mlp_1_0_outputs,1,-0.18716975,0.314667,-0.13772789
272 | num_mlp_1_0_outputs,10,-0.018390449,0.30108446,-0.16875868
273 | num_mlp_1_0_outputs,11,-0.2499132,0.2316408,-0.27131632
274 | num_mlp_1_0_outputs,12,-0.36267832,-0.029418426,-0.3747254
275 | num_mlp_1_0_outputs,13,-0.5715566,-0.050854474,-0.74630326
276 | num_mlp_1_0_outputs,14,-0.43067545,-0.19346179,0.2828539
277 | num_mlp_1_0_outputs,15,-0.8118014,0.031878382,-0.5857855
278 | num_mlp_1_0_outputs,2,-0.94931656,0.090484895,-0.4667937
279 | num_mlp_1_0_outputs,3,-0.3243798,0.38227555,-0.16742714
280 | num_mlp_1_0_outputs,4,-0.55292064,0.0885116,-0.16673157
281 | num_mlp_1_0_outputs,5,-0.19122665,0.56547916,0.026035376
282 | num_mlp_1_0_outputs,6,-0.35779983,0.07764958,-0.5537944
283 | num_mlp_1_0_outputs,7,-0.16128884,0.30881864,-0.5096637
284 | num_mlp_1_0_outputs,8,-0.77744365,-0.02209261,-0.67441314
285 | num_mlp_1_0_outputs,9,0.1671805,0.7428009,0.7823384
286 | num_mlp_2_0_outputs,0,-0.3438476,-0.018625923,-0.8805979
287 | num_mlp_2_0_outputs,1,-0.19985998,0.26809484,-0.60468084
288 | num_mlp_2_0_outputs,10,-0.39671543,0.106505,-0.6291543
289 | num_mlp_2_0_outputs,11,-0.50791365,0.48418602,-0.25579256
290 | num_mlp_2_0_outputs,12,-3.960384,3.2354944,1.5324167
291 | num_mlp_2_0_outputs,13,-0.061338473,-0.11631737,-0.47945797
292 | num_mlp_2_0_outputs,14,-0.30011797,0.09830846,-0.6083012
293 | num_mlp_2_0_outputs,15,-0.38574433,0.19097115,-0.82048756
294 | num_mlp_2_0_outputs,2,-0.20133315,0.10674115,-0.58823496
295 | num_mlp_2_0_outputs,3,0.0117857745,0.66684407,0.3024773
296 | num_mlp_2_0_outputs,4,-0.33616486,0.30070633,-0.4564914
297 | num_mlp_2_0_outputs,5,-0.04366164,0.78352547,-0.033621307
298 | num_mlp_2_0_outputs,6,0.08861894,0.54160535,0.0855582
299 | num_mlp_2_0_outputs,7,0.039711624,0.9778597,0.20092271
300 | num_mlp_2_0_outputs,8,0.72047126,0.44210163,-0.335978
301 | num_mlp_2_0_outputs,9,-0.006253452,0.69338506,-0.3173577
302 | ones,_,-0.7997238,1.0074667,0.026579212
303 | positions,0,-0.29179272,0.23241787,0.6308148
304 | positions,1,-3.904494,5.006513,-1.8895661
305 | positions,10,2.2057185,-5.531098,5.206517
306 | positions,11,3.9026654,-2.0590394,-2.612058
307 | positions,12,4.1754217,-4.574852,5.335018
308 | positions,13,3.8170507,-2.526535,-3.9886053
309 | positions,14,0.7010598,-0.4262054,-0.4852942
310 | positions,15,4.7547708,-4.233093,-4.5563035
311 | positions,2,-2.9996147,-6.9421015,6.494166
312 | positions,3,1.2800095,2.991268,-9.872106
313 | positions,4,1.2470323,-2.7899837,2.7413142
314 | positions,5,-9.135268,14.433624,-8.043073
315 | positions,6,-7.3264008,11.024153,-0.3841814
316 | positions,7,0.30973187,2.3126788,-1.3860149
317 | positions,8,-10.491879,17.33303,-5.2549305
318 | positions,9,-1.5044698,3.8648603,-1.0246311
319 | tokens,(,2.0314052,0.9187388,-5.384055
320 | tokens,),-3.5163503,0.672222,2.1528707
321 | tokens,,0.3435328,0.89075476,-1.2196645
322 | tokens,,-0.33722147,-0.05571238,-0.4713453
323 | tokens,_,0.55962527,-0.5702431,0.23070188
324 | tokens,_,0.2838472,0.7993582,-0.99131817
325 | tokens,_,0.12776248,0.89144236,-0.23171602
326 | tokens,_,-0.13315062,0.95989436,0.9555972
327 | tokens,_,-0.6196355,-0.36522523,-0.359687
328 | tokens,_,0.10280879,0.3058465,-0.47682306
329 | tokens,_,0.35391825,-0.20935811,0.6336341
330 | tokens,_,1.0826828,-0.8673405,-0.55527455
331 | tokens,_,0.5650465,0.42321026,-0.6229729
332 | tokens,_,0.34496674,-0.7241081,-0.40609998
333 | tokens,_,-0.053234033,0.15473613,0.3231568
334 | tokens,_,0.897978,-0.39115798,-0.12659745
335 |
--------------------------------------------------------------------------------
/programs/rasp/dyck2/dyck2_weights.csv:
--------------------------------------------------------------------------------
1 | feature,value,F,P,T
2 | attn_0_0_outputs,(,-0.3893371,0.60064846,-1.0043755
3 | attn_0_0_outputs,),0.45308644,-0.33719206,-0.25703704
4 | attn_0_0_outputs,,0.7680508,-1.140804,-0.6336192
5 | attn_0_0_outputs,,1.3979992,-0.55402267,-2.9333327
6 | attn_0_0_outputs,_,-0.27462152,-0.2883787,-0.084659226
7 | attn_0_0_outputs,_,0.41020688,0.80682915,-0.36093912
8 | attn_0_0_outputs,_,0.056031883,-0.039105184,0.19863698
9 | attn_0_0_outputs,_,-0.41961256,-0.54929376,0.43671665
10 | attn_0_0_outputs,_,0.13279785,0.36707786,0.40994072
11 | attn_0_0_outputs,_,-0.2363711,0.5147661,-0.18510109
12 | attn_0_0_outputs,_,0.12620896,-0.18906875,0.9231106
13 | attn_0_0_outputs,_,-0.029982865,0.23939228,1.6595911
14 | attn_0_0_outputs,_,-0.73210377,-1.1544671,0.57133
15 | attn_0_0_outputs,_,-0.23087524,1.8027706,-1.5571856
16 | attn_0_0_outputs,{,-0.04748568,1.2735734,-0.45339546
17 | attn_0_0_outputs,},0.5547761,-0.3066463,-0.1081352
18 | attn_0_1_outputs,(,-0.06017389,0.23692839,-0.65924984
19 | attn_0_1_outputs,),1.5613009,0.05749913,-1.1849948
20 | attn_0_1_outputs,,0.30928898,0.60175115,-0.77076155
21 | attn_0_1_outputs,,-0.05464487,-6.3642993,5.9092317
22 | attn_0_1_outputs,_,0.82652664,0.4707333,0.045426637
23 | attn_0_1_outputs,_,-0.43456507,1.0889953,-0.6075947
24 | attn_0_1_outputs,_,-0.650409,0.48931292,-0.960644
25 | attn_0_1_outputs,_,-0.14928144,-0.3577456,-0.08982656
26 | attn_0_1_outputs,_,-0.4732285,-0.5932854,-0.044864617
27 | attn_0_1_outputs,_,0.17000043,-0.046643566,0.04345899
28 | attn_0_1_outputs,_,-0.049569067,0.36960238,0.43352073
29 | attn_0_1_outputs,_,-0.43735144,-0.4731728,0.18792638
30 | attn_0_1_outputs,_,-0.46125263,-0.9085558,1.1110471
31 | attn_0_1_outputs,_,0.19633614,1.5570519,0.18224098
32 | attn_0_1_outputs,{,-0.11224139,0.64842534,-0.33784837
33 | attn_0_1_outputs,},-0.6292952,1.3489552,0.3662693
34 | attn_1_0_outputs,0,0.20661825,1.3388423,-2.5574758
35 | attn_1_0_outputs,1,2.5823343,-0.7297078,-3.195439
36 | attn_1_0_outputs,10,-0.5139535,-0.078736916,1.6875057
37 | attn_1_0_outputs,11,0.041748732,-0.14228384,-0.14014556
38 | attn_1_0_outputs,12,-0.454622,-0.6740401,0.04633342
39 | attn_1_0_outputs,13,3.6046548,-1.8717623,-2.699525
40 | attn_1_0_outputs,14,-1.5789233,0.97090966,-0.3266689
41 | attn_1_0_outputs,15,-0.24816914,-0.10174749,1.3838878
42 | attn_1_0_outputs,2,0.11910092,0.22456145,1.3250326
43 | attn_1_0_outputs,3,-0.4317825,-0.28709865,-0.20844361
44 | attn_1_0_outputs,4,-3.6241753,3.323929,0.5773792
45 | attn_1_0_outputs,5,-0.11809031,0.24210535,-0.23381685
46 | attn_1_0_outputs,6,-1.0743729,1.1453243,1.608349
47 | attn_1_0_outputs,7,-3.164238,0.699024,1.0340277
48 | attn_1_0_outputs,8,-0.28543827,0.40379593,0.14130151
49 | attn_1_0_outputs,9,-0.5014043,-0.20715642,0.18966955
50 | attn_1_1_outputs,0,0.49038798,1.5800108,-1.2205557
51 | attn_1_1_outputs,1,-0.93285435,1.7327744,-1.2455193
52 | attn_1_1_outputs,10,-0.039011396,-0.65827584,0.71766496
53 | attn_1_1_outputs,11,0.23105314,-0.0301322,1.3648343
54 | attn_1_1_outputs,12,1.4345965,-2.2012353,0.9856132
55 | attn_1_1_outputs,13,-0.18832216,0.19023667,0.89298505
56 | attn_1_1_outputs,14,0.03877773,0.43423906,0.40338022
57 | attn_1_1_outputs,15,2.353314,-1.4448909,-0.21294129
58 | attn_1_1_outputs,2,1.4249635,-1.0321774,0.75964826
59 | attn_1_1_outputs,3,-1.1570958,2.1030843,1.7259277
60 | attn_1_1_outputs,4,-1.1145258,1.9922961,-2.486801
61 | attn_1_1_outputs,5,-0.40263888,0.60196733,1.4514791
62 | attn_1_1_outputs,6,2.065989,-2.143043,0.49976075
63 | attn_1_1_outputs,7,0.30126733,-0.6678852,0.65130216
64 | attn_1_1_outputs,8,-0.061816756,0.13271871,1.2922349
65 | attn_1_1_outputs,9,0.5106305,-1.2614174,1.5335835
66 | attn_2_0_outputs,0,-0.05937145,1.2803885,-0.6700919
67 | attn_2_0_outputs,1,-1.3801527,2.8249412,0.48383245
68 | attn_2_0_outputs,10,0.060948834,-0.13837986,-0.054340422
69 | attn_2_0_outputs,11,-0.37168485,-0.076273456,-0.64383096
70 | attn_2_0_outputs,12,0.13798621,-0.6402446,0.5534661
71 | attn_2_0_outputs,13,0.8956216,-0.676867,0.5628605
72 | attn_2_0_outputs,14,0.23056701,-1.5529728,0.9693785
73 | attn_2_0_outputs,15,1.1047347,-0.7788408,-0.9719154
74 | attn_2_0_outputs,2,0.7777873,-1.1186872,0.052649383
75 | attn_2_0_outputs,3,-1.7339926,2.1315053,1.7853323
76 | attn_2_0_outputs,4,-1.377781,1.8794999,-1.3036643
77 | attn_2_0_outputs,5,3.117208,-1.8162274,1.6918435
78 | attn_2_0_outputs,6,0.7117053,0.103766,0.21950138
79 | attn_2_0_outputs,7,0.6329945,-1.5896332,1.3083339
80 | attn_2_0_outputs,8,4.091406,-3.2152696,0.2076941
81 | attn_2_0_outputs,9,-0.04400771,-0.26516256,0.04262287
82 | attn_2_1_outputs,0,-1.3274802,2.2420464,-0.69805795
83 | attn_2_1_outputs,1,-0.65939707,2.6968226,-0.36676592
84 | attn_2_1_outputs,10,0.04043235,-0.5054306,0.015038767
85 | attn_2_1_outputs,11,-0.6721706,-0.31992498,-0.16208978
86 | attn_2_1_outputs,12,1.5115396,-0.9017534,-0.3154115
87 | attn_2_1_outputs,13,0.79894733,-1.3227056,-0.10762319
88 | attn_2_1_outputs,14,0.12715805,-1.4680986,0.32952312
89 | attn_2_1_outputs,15,-0.2186547,-0.48416838,0.35051808
90 | attn_2_1_outputs,2,0.37131494,-1.1129726,0.40816385
91 | attn_2_1_outputs,3,-1.511338,1.4423819,0.6649975
92 | attn_2_1_outputs,4,-0.44196594,1.0740502,-0.8123087
93 | attn_2_1_outputs,5,5.7057724,-3.2706978,-4.0638323
94 | attn_2_1_outputs,6,2.224006,-1.0677377,-1.4632593
95 | attn_2_1_outputs,7,0.38812694,-0.7842454,1.4326475
96 | attn_2_1_outputs,8,-0.65958214,0.6723407,0.006303966
97 | attn_2_1_outputs,9,0.2128453,-1.0684156,2.4653118
98 | mlp_0_0_outputs,0,1.4124762,0.3918622,0.96247226
99 | mlp_0_0_outputs,1,0.035059676,-0.78724706,-0.49286532
100 | mlp_0_0_outputs,10,0.43810797,0.6196286,-0.29613635
101 | mlp_0_0_outputs,11,0.024041645,-0.45532593,-0.24850975
102 | mlp_0_0_outputs,12,0.21156618,0.30932778,0.49199212
103 | mlp_0_0_outputs,13,0.038805455,-0.1897838,0.03590136
104 | mlp_0_0_outputs,14,0.5470265,0.35249376,-0.26997298
105 | mlp_0_0_outputs,15,0.41228998,-1.3040318,1.2210026
106 | mlp_0_0_outputs,2,-0.45929217,1.0102494,-0.034102287
107 | mlp_0_0_outputs,3,0.46021348,0.09072253,0.099708
108 | mlp_0_0_outputs,4,1.9909586,-0.93861586,-2.8465278
109 | mlp_0_0_outputs,5,1.6368363,-1.6418909,2.8530548
110 | mlp_0_0_outputs,6,-0.85980654,0.6846919,-0.6151556
111 | mlp_0_0_outputs,7,-0.18866596,0.42998883,-0.31827375
112 | mlp_0_0_outputs,8,0.032268297,-0.017944142,0.26745126
113 | mlp_0_0_outputs,9,-0.66864365,-0.14510924,-1.214156
114 | mlp_0_1_outputs,0,0.45482358,-1.2842598,-0.60993576
115 | mlp_0_1_outputs,1,0.31253016,-0.7943685,1.5131452
116 | mlp_0_1_outputs,10,0.25627092,0.22809455,1.0208089
117 | mlp_0_1_outputs,11,0.76562774,-0.18769291,-0.1974224
118 | mlp_0_1_outputs,12,0.38039443,-0.5691965,-0.091124795
119 | mlp_0_1_outputs,13,2.6179564,-0.32311362,-0.72676677
120 | mlp_0_1_outputs,14,-0.040358998,0.030396035,-0.82734126
121 | mlp_0_1_outputs,15,-0.46312347,0.32302356,0.42312586
122 | mlp_0_1_outputs,2,1.068703,-0.067160755,-0.58011967
123 | mlp_0_1_outputs,3,0.14145625,0.29858097,0.6092269
124 | mlp_0_1_outputs,4,-1.5048414,1.6003934,-0.54811466
125 | mlp_0_1_outputs,5,-0.07364487,0.103354104,0.5413819
126 | mlp_0_1_outputs,6,-0.024732672,0.080756046,0.43893534
127 | mlp_0_1_outputs,7,-1.79316,1.7965577,1.1761781
128 | mlp_0_1_outputs,8,-0.4545691,0.09123524,-2.1593776
129 | mlp_0_1_outputs,9,-0.8172794,-0.3064075,0.030816736
130 | mlp_1_0_outputs,0,0.22648785,-0.25926873,-0.506777
131 | mlp_1_0_outputs,1,-0.07746484,-0.984498,6.8709683
132 | mlp_1_0_outputs,10,0.20862451,-0.022583699,-0.790106
133 | mlp_1_0_outputs,11,1.020657,0.9678695,-0.295873
134 | mlp_1_0_outputs,12,0.3965819,0.2451575,-0.6036312
135 | mlp_1_0_outputs,13,0.47067815,0.2438208,-0.6940792
136 | mlp_1_0_outputs,14,-0.35612527,1.1211482,-1.4271184
137 | mlp_1_0_outputs,15,-0.29890123,-0.32131162,-0.707251
138 | mlp_1_0_outputs,2,0.5454376,0.30075255,-1.1875064
139 | mlp_1_0_outputs,3,0.9890708,0.9738718,0.17902154
140 | mlp_1_0_outputs,4,0.35423172,0.36298084,-1.0344858
141 | mlp_1_0_outputs,5,0.38208264,0.20733608,-1.3121635
142 | mlp_1_0_outputs,6,0.3350503,0.058604088,-0.639987
143 | mlp_1_0_outputs,7,0.48388848,0.2619127,-0.6384467
144 | mlp_1_0_outputs,8,-0.0939644,1.6520033,-2.8920875
145 | mlp_1_0_outputs,9,0.46636695,0.19207071,-0.47046715
146 | mlp_1_1_outputs,0,1.0177721,-0.4037793,-0.67931974
147 | mlp_1_1_outputs,1,-1.03061,0.70863277,-0.6241576
148 | mlp_1_1_outputs,10,0.9735542,1.4381392,-2.922016
149 | mlp_1_1_outputs,11,-0.37031677,0.04114491,-0.43833217
150 | mlp_1_1_outputs,12,0.16325824,0.08265705,-0.22923788
151 | mlp_1_1_outputs,13,0.50651205,0.04826694,-0.2412666
152 | mlp_1_1_outputs,14,-0.68649834,0.16758002,-0.8419172
153 | mlp_1_1_outputs,15,0.22871964,-0.3694542,0.036439355
154 | mlp_1_1_outputs,2,0.52248716,0.56564903,0.028707229
155 | mlp_1_1_outputs,3,-0.13902096,0.83899367,0.045766518
156 | mlp_1_1_outputs,4,0.8631842,-0.11090764,-3.633642
157 | mlp_1_1_outputs,5,0.24383672,0.4300947,-0.028894566
158 | mlp_1_1_outputs,6,0.77099085,0.40300265,-0.6123488
159 | mlp_1_1_outputs,7,0.80704874,-0.35886392,-1.2488782
160 | mlp_1_1_outputs,8,0.05687267,-0.7564281,0.10573861
161 | mlp_1_1_outputs,9,0.55384576,-0.005512286,-0.25716427
162 | mlp_2_0_outputs,0,-0.43766528,0.63004327,0.82238406
163 | mlp_2_0_outputs,1,-0.28752628,0.024231752,-0.73521745
164 | mlp_2_0_outputs,10,0.42408463,-1.1208787,-0.8666483
165 | mlp_2_0_outputs,11,2.4725761,-0.24301499,0.1757837
166 | mlp_2_0_outputs,12,0.78457654,-0.50197077,-0.31981418
167 | mlp_2_0_outputs,13,0.41764382,-0.94169605,-0.4211051
168 | mlp_2_0_outputs,14,1.2861595,-0.5318925,-0.036760274
169 | mlp_2_0_outputs,15,0.43116912,-0.94550234,-0.62638485
170 | mlp_2_0_outputs,2,0.057849146,0.12554613,-0.17656699
171 | mlp_2_0_outputs,3,0.9872132,-0.6279045,-0.24866179
172 | mlp_2_0_outputs,4,0.0049139177,1.9990245,-2.5793169
173 | mlp_2_0_outputs,5,0.52743393,-0.35086772,-0.1816786
174 | mlp_2_0_outputs,6,1.2171247,-1.0690438,-0.7097241
175 | mlp_2_0_outputs,7,1.2408279,-0.32120407,-2.283859
176 | mlp_2_0_outputs,8,1.1670102,-0.3065825,-1.6924549
177 | mlp_2_0_outputs,9,0.47862154,0.17900337,-1.1995778
178 | mlp_2_1_outputs,0,1.5722502,-1.8906522,-3.274062
179 | mlp_2_1_outputs,1,-0.14563715,0.071487516,0.4061026
180 | mlp_2_1_outputs,10,0.6692192,0.022232514,0.052336734
181 | mlp_2_1_outputs,11,-0.4981686,0.010872859,-0.120791934
182 | mlp_2_1_outputs,12,0.21853887,0.6528137,0.40027425
183 | mlp_2_1_outputs,13,-0.0539364,-0.04593685,-0.77657515
184 | mlp_2_1_outputs,14,-0.029986976,-0.16086973,0.45140362
185 | mlp_2_1_outputs,15,-0.19369616,0.82837254,0.23891859
186 | mlp_2_1_outputs,2,-1.1634233,1.885357,-3.579113
187 | mlp_2_1_outputs,3,0.46121255,0.038033824,0.0016819923
188 | mlp_2_1_outputs,4,0.09414851,0.5597031,0.2918426
189 | mlp_2_1_outputs,5,-0.41278046,0.0027908396,0.085692376
190 | mlp_2_1_outputs,6,0.1588337,0.11019328,-0.09616117
191 | mlp_2_1_outputs,7,-1.097063,-0.48358715,1.1503831
192 | mlp_2_1_outputs,8,0.4955245,0.4686215,0.23423254
193 | mlp_2_1_outputs,9,-0.049522586,0.9311799,-0.728278
194 | num_attn_0_0_outputs,_,-0.85263616,-0.036784045,-0.13747008
195 | num_attn_0_1_outputs,_,0.33958793,0.86033154,0.89450693
196 | num_attn_1_0_outputs,_,-0.78232145,0.5346261,0.2896964
197 | num_attn_1_1_outputs,_,6.0080223,-6.6434155,-3.9860556
198 | num_attn_2_0_outputs,_,-0.77475876,1.1810129,1.0096526
199 | num_attn_2_1_outputs,_,1.2496741,-1.8492256,1.0177201
200 | num_mlp_0_0_outputs,0,0.3314178,0.5325629,-0.16401817
201 | num_mlp_0_0_outputs,1,0.19410345,0.24263886,-0.3985704
202 | num_mlp_0_0_outputs,10,-0.05038471,0.23930399,-0.5071923
203 | num_mlp_0_0_outputs,11,0.14974928,0.4013859,-0.500009
204 | num_mlp_0_0_outputs,12,0.3967017,0.6515252,-0.013054512
205 | num_mlp_0_0_outputs,13,-0.3407226,-0.23862363,-1.0013121
206 | num_mlp_0_0_outputs,14,-0.0042749355,0.07379021,-0.87100405
207 | num_mlp_0_0_outputs,15,0.044042736,0.5217005,-0.44642144
208 | num_mlp_0_0_outputs,2,-0.04536464,0.35205916,-0.63331443
209 | num_mlp_0_0_outputs,3,0.30032563,0.5497403,-0.3833386
210 | num_mlp_0_0_outputs,4,0.4260003,0.63091743,-0.28802612
211 | num_mlp_0_0_outputs,5,0.25858745,0.0029270295,-0.3307586
212 | num_mlp_0_0_outputs,6,-0.068435706,0.053463157,-0.76825386
213 | num_mlp_0_0_outputs,7,0.36288938,0.19408187,-0.6395527
214 | num_mlp_0_0_outputs,8,0.6338136,0.76790816,-0.06365792
215 | num_mlp_0_0_outputs,9,0.087720215,0.43525898,-0.31613958
216 | num_mlp_0_1_outputs,0,-0.06632771,-0.070165716,-0.4955601
217 | num_mlp_0_1_outputs,1,0.2773735,0.34788114,-0.2892638
218 | num_mlp_0_1_outputs,10,0.2745447,-0.13056615,-0.8543876
219 | num_mlp_0_1_outputs,11,0.044537343,-0.32159278,-0.9105748
220 | num_mlp_0_1_outputs,12,0.18666676,0.1831349,-0.1372812
221 | num_mlp_0_1_outputs,13,0.23526604,-0.16011046,-0.73585725
222 | num_mlp_0_1_outputs,14,0.06875769,0.27533332,-0.23280683
223 | num_mlp_0_1_outputs,15,0.20177683,0.22331731,-0.3380389
224 | num_mlp_0_1_outputs,2,-0.26213843,-0.25991192,-1.0179317
225 | num_mlp_0_1_outputs,3,0.5162888,0.4858723,-0.36303845
226 | num_mlp_0_1_outputs,4,0.31894425,0.2724289,-0.30766574
227 | num_mlp_0_1_outputs,5,0.6092662,0.19296132,-0.85525966
228 | num_mlp_0_1_outputs,6,-0.07518837,0.025708076,-0.6562557
229 | num_mlp_0_1_outputs,7,0.6991878,0.6410874,0.10406117
230 | num_mlp_0_1_outputs,8,-0.3627702,-0.08890215,-0.4461756
231 | num_mlp_0_1_outputs,9,0.024958644,-0.08254894,-0.6912631
232 | num_mlp_1_0_outputs,0,0.3114148,0.55049676,-0.22676654
233 | num_mlp_1_0_outputs,1,0.07806706,-0.023488285,-0.30448
234 | num_mlp_1_0_outputs,10,-0.036106024,-0.061172247,-0.6728275
235 | num_mlp_1_0_outputs,11,0.00018137011,0.24429022,-0.113959186
236 | num_mlp_1_0_outputs,12,0.5541197,0.3748864,-0.091273285
237 | num_mlp_1_0_outputs,13,0.0910017,0.37516505,-0.11659233
238 | num_mlp_1_0_outputs,14,0.029605865,-0.40416694,-0.39333186
239 | num_mlp_1_0_outputs,15,0.4850837,0.17000733,-0.15730006
240 | num_mlp_1_0_outputs,2,-0.09316445,-0.3945504,-0.74077773
241 | num_mlp_1_0_outputs,3,0.77411944,0.45456645,-0.2434635
242 | num_mlp_1_0_outputs,4,0.45820743,0.267871,-0.08835843
243 | num_mlp_1_0_outputs,5,0.8663778,0.48250827,0.17160328
244 | num_mlp_1_0_outputs,6,0.18805933,0.16399784,-0.35467726
245 | num_mlp_1_0_outputs,7,0.4086816,0.020269368,-0.34887818
246 | num_mlp_1_0_outputs,8,-0.026719714,-0.16764963,-0.65317184
247 | num_mlp_1_0_outputs,9,0.019892184,-0.01170788,-0.5114425
248 | num_mlp_1_1_outputs,0,-0.078854226,0.08707899,-0.98892313
249 | num_mlp_1_1_outputs,1,-0.15243109,0.12799971,-0.72895944
250 | num_mlp_1_1_outputs,10,0.58201754,0.6582406,-0.16776891
251 | num_mlp_1_1_outputs,11,0.23494452,0.5244795,-0.41649953
252 | num_mlp_1_1_outputs,12,-0.01350627,-0.00017212816,-0.47231874
253 | num_mlp_1_1_outputs,13,-0.020348622,0.20011073,-0.66538125
254 | num_mlp_1_1_outputs,14,0.16774431,0.5128025,-0.28296882
255 | num_mlp_1_1_outputs,15,0.33657792,0.6268225,-0.22136451
256 | num_mlp_1_1_outputs,2,0.36348855,0.56271625,-0.2105581
257 | num_mlp_1_1_outputs,3,0.24504001,0.5178616,-0.24728571
258 | num_mlp_1_1_outputs,4,-0.05191437,0.09888679,-0.6176198
259 | num_mlp_1_1_outputs,5,-0.06691858,0.15422752,-0.6520832
260 | num_mlp_1_1_outputs,6,0.8261999,1.1822537,0.19525766
261 | num_mlp_1_1_outputs,7,-0.24775888,0.09739883,-0.73482597
262 | num_mlp_1_1_outputs,8,0.21655971,0.25961864,-0.49139246
263 | num_mlp_1_1_outputs,9,-0.23927306,-0.0052134893,-0.8000629
264 | num_mlp_2_0_outputs,0,0.7344719,0.48031777,-0.0072156666
265 | num_mlp_2_0_outputs,1,-0.08577925,-0.13313006,-0.8903967
266 | num_mlp_2_0_outputs,10,0.07255866,-0.023388484,-0.35158905
267 | num_mlp_2_0_outputs,11,0.21112183,0.28738138,-0.1871139
268 | num_mlp_2_0_outputs,12,0.3063709,0.3118506,-0.3509904
269 | num_mlp_2_0_outputs,13,0.20950411,0.26588303,-0.3123367
270 | num_mlp_2_0_outputs,14,0.02394933,0.14214058,-0.45821595
271 | num_mlp_2_0_outputs,15,0.15515622,-0.23100227,-0.7154926
272 | num_mlp_2_0_outputs,2,0.071676075,0.32760754,-0.3858776
273 | num_mlp_2_0_outputs,3,-0.23932433,-0.17936808,-0.6346995
274 | num_mlp_2_0_outputs,4,-0.07454885,-0.04197524,-0.8073381
275 | num_mlp_2_0_outputs,5,0.44984394,0.27431875,-0.26230815
276 | num_mlp_2_0_outputs,6,0.24677134,0.27385816,-0.17668292
277 | num_mlp_2_0_outputs,7,-0.4095317,-0.1621746,-0.8552454
278 | num_mlp_2_0_outputs,8,0.4552858,0.8000447,0.23228435
279 | num_mlp_2_0_outputs,9,0.5577926,0.5177477,-0.009681215
280 | num_mlp_2_1_outputs,0,1.1987035,-0.33801028,-0.50512236
281 | num_mlp_2_1_outputs,1,0.94955945,-0.6463761,-0.7593889
282 | num_mlp_2_1_outputs,10,-0.29233116,0.7033983,0.14847362
283 | num_mlp_2_1_outputs,11,0.263601,-0.39432043,-0.87329227
284 | num_mlp_2_1_outputs,12,2.5524316,-2.5761328,-0.87832254
285 | num_mlp_2_1_outputs,13,0.6176073,-0.8246647,-1.1955972
286 | num_mlp_2_1_outputs,14,0.20788954,0.76745355,0.03399014
287 | num_mlp_2_1_outputs,15,0.72697127,-0.68905944,-0.7866722
288 | num_mlp_2_1_outputs,2,-0.5621572,0.887737,0.11716976
289 | num_mlp_2_1_outputs,3,1.18867,-0.34044042,-0.44703516
290 | num_mlp_2_1_outputs,4,-0.57840586,0.45728707,-0.35807198
291 | num_mlp_2_1_outputs,5,0.6322493,0.63687426,0.113720134
292 | num_mlp_2_1_outputs,6,-0.19396284,1.0908433,0.08487423
293 | num_mlp_2_1_outputs,7,0.47693765,0.46259257,-0.7392341
294 | num_mlp_2_1_outputs,8,-0.1926289,-0.3515196,-0.8700848
295 | num_mlp_2_1_outputs,9,-0.45918483,0.6270346,-0.17804115
296 | ones,_,-0.026366323,0.395983,-0.2675112
297 | positions,0,-0.21428509,0.07516778,0.20615174
298 | positions,1,1.3148408,0.7747299,-9.5917425
299 | positions,10,-0.9434932,-0.7169733,1.7839171
300 | positions,11,2.346714,-0.6581614,-11.038458
301 | positions,12,-1.921217,-0.8985177,2.4055943
302 | positions,13,2.3171093,-0.23238912,-11.31294
303 | positions,14,-2.6990774,-1.6240467,2.7664332
304 | positions,15,3.4511793,-0.7494351,-10.082401
305 | positions,2,-3.2204971,-0.737425,6.4716024
306 | positions,3,-1.0551778,2.5078318,-2.613356
307 | positions,4,-1.4998984,0.22135884,1.4028941
308 | positions,5,0.5460883,2.1069403,-8.382127
309 | positions,6,-0.50053823,0.633908,2.4933894
310 | positions,7,0.29808688,1.4961973,-8.695588
311 | positions,8,-0.6610369,0.24578033,2.0996253
312 | positions,9,3.1191888,0.5754417,-5.0389624
313 | tokens,(,-1.1092756,2.403831,-5.864368
314 | tokens,),0.9881682,-2.3895998,3.3189006
315 | tokens,,0.4111517,-0.24375123,-0.72117674
316 | tokens,,-0.004239899,-0.017491871,-0.007405918
317 | tokens,_,1.3654574,0.29655144,0.28431866
318 | tokens,_,1.4269474,-0.7771767,-0.16797024
319 | tokens,_,0.3346523,-0.2013016,0.14942457
320 | tokens,_,1.3444127,-0.10818399,0.32242635
321 | tokens,_,0.07180428,-0.7881537,0.19018956
322 | tokens,_,-0.31608832,-0.19187881,-0.06436471
323 | tokens,_,0.9428614,-1.0588603,-0.5878867
324 | tokens,_,1.2218833,0.40168327,-0.35955858
325 | tokens,_,-0.5326404,-0.19429654,-0.4998532
326 | tokens,_,0.7153185,-0.070266396,-0.59234005
327 | tokens,{,-0.7310767,2.8042543,-4.676652
328 | tokens,},0.6038923,-3.4896483,2.4227173
329 |
--------------------------------------------------------------------------------
/programs/rasp/hist/hist.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 |
5 | def select_closest(keys, queries, predicate):
6 | scores = [[False for _ in keys] for _ in queries]
7 | for i, q in enumerate(queries):
8 | matches = [j for j, k in enumerate(keys) if predicate(q, k)]
9 |         if not matches:
10 |             scores[i][0] = True  # no match: default to attending to position 0
11 | else:
12 |             j = min(matches, key=lambda j: len(matches) if j == i else abs(i - j))  # nearest match, deprioritizing j == i
13 | scores[i][j] = True
14 | return scores
15 |
16 |
17 | def select(keys, queries, predicate):
18 | return [[predicate(q, k) for k in keys] for q in queries]
19 |
20 |
21 | def aggregate(attention, values):
22 | return [[v for a, v in zip(attn, values) if a][0] for attn in attention]
23 |
24 |
25 | def aggregate_sum(attention, values):
26 | return [sum([v for a, v in zip(attn, values) if a]) for attn in attention]
27 |
28 |
29 | def run(tokens):
30 | # classifier weights ##########################################
31 | classifier_weights = pd.read_csv(
32 | "programs/rasp/hist/hist_weights.csv", index_col=[0, 1], dtype={"feature": str}
33 | )
34 | # inputs #####################################################
35 | token_scores = classifier_weights.loc[[("tokens", str(v)) for v in tokens]]
36 |
37 | positions = list(range(len(tokens)))
38 | position_scores = classifier_weights.loc[[("positions", str(v)) for v in positions]]
39 |
40 | ones = [1 for _ in range(len(tokens))]
41 | one_scores = classifier_weights.loc[[("ones", "_") for v in ones]].mul(ones, axis=0)
42 |
43 | # attn_0_0 ####################################################
44 | def predicate_0_0(q_token, k_token):
45 | if q_token in {"1", "4", "0", "2", "3", "5"}:
46 |             return k_token == "<s>"
47 |         elif q_token in {"<s>"}:
48 |             return k_token == "<s>"
49 |
50 | attn_0_0_pattern = select_closest(tokens, tokens, predicate_0_0)
51 | attn_0_0_outputs = aggregate(attn_0_0_pattern, positions)
52 | attn_0_0_output_scores = classifier_weights.loc[
53 | [("attn_0_0_outputs", str(v)) for v in attn_0_0_outputs]
54 | ]
55 |
56 | # attn_0_1 ####################################################
57 | def predicate_0_1(q_position, k_position):
58 | if q_position in {0}:
59 | return k_position == 7
60 | elif q_position in {1, 2, 3, 5}:
61 | return k_position == 0
62 | elif q_position in {4, 7}:
63 | return k_position == 4
64 | elif q_position in {6}:
65 | return k_position == 2
66 |
67 | attn_0_1_pattern = select_closest(positions, positions, predicate_0_1)
68 | attn_0_1_outputs = aggregate(attn_0_1_pattern, positions)
69 | attn_0_1_output_scores = classifier_weights.loc[
70 | [("attn_0_1_outputs", str(v)) for v in attn_0_1_outputs]
71 | ]
72 |
73 | # num_attn_0_0 ####################################################
74 | def num_predicate_0_0(q_token, k_token):
75 | if q_token in {"0"}:
76 | return k_token == "0"
77 | elif q_token in {"1"}:
78 | return k_token == "1"
79 | elif q_token in {"2"}:
80 | return k_token == "2"
81 | elif q_token in {"3"}:
82 | return k_token == "3"
83 | elif q_token in {"4"}:
84 | return k_token == "4"
85 | elif q_token in {"5"}:
86 | return k_token == "5"
87 |         elif q_token in {"<s>"}:
88 |             return k_token == "<s>"
89 |
90 | num_attn_0_0_pattern = select(tokens, tokens, num_predicate_0_0)
91 | num_attn_0_0_outputs = aggregate_sum(num_attn_0_0_pattern, ones)
92 | num_attn_0_0_output_scores = classifier_weights.loc[
93 | [("num_attn_0_0_outputs", "_") for v in num_attn_0_0_outputs]
94 | ].mul(num_attn_0_0_outputs, axis=0)
95 |
96 | # num_attn_0_1 ####################################################
97 | def num_predicate_0_1(q_token, k_token):
98 | if q_token in {"1", "0", "3"}:
99 | return k_token == "5"
100 | elif q_token in {"2"}:
101 | return k_token == "1"
102 | elif q_token in {"4"}:
103 | return k_token == "2"
104 | elif q_token in {"5"}:
105 | return k_token == "3"
106 | elif q_token in {""}:
107 | return k_token == ""
108 |
109 | num_attn_0_1_pattern = select(tokens, tokens, num_predicate_0_1)
110 | num_attn_0_1_outputs = aggregate_sum(num_attn_0_1_pattern, ones)
111 | num_attn_0_1_output_scores = classifier_weights.loc[
112 | [("num_attn_0_1_outputs", "_") for v in num_attn_0_1_outputs]
113 | ].mul(num_attn_0_1_outputs, axis=0)
114 |
115 | # mlp_0_0 #####################################################
116 | def mlp_0_0(position, attn_0_1_output):
117 | key = (position, attn_0_1_output)
118 | return 3
119 |
120 | mlp_0_0_outputs = [mlp_0_0(k0, k1) for k0, k1 in zip(positions, attn_0_1_outputs)]
121 | mlp_0_0_output_scores = classifier_weights.loc[
122 | [("mlp_0_0_outputs", str(v)) for v in mlp_0_0_outputs]
123 | ]
124 |
125 | # num_mlp_0_0 #################################################
126 | def num_mlp_0_0(num_attn_0_0_output):
127 | key = num_attn_0_0_output
128 | if key in {0, 1}:
129 | return 6
130 | return 3
131 |
132 | num_mlp_0_0_outputs = [num_mlp_0_0(k0) for k0 in num_attn_0_0_outputs]
133 | num_mlp_0_0_output_scores = classifier_weights.loc[
134 | [("num_mlp_0_0_outputs", str(v)) for v in num_mlp_0_0_outputs]
135 | ]
136 |
137 | feature_logits = pd.concat(
138 | [
139 | df.reset_index()
140 | for df in [
141 | token_scores,
142 | position_scores,
143 | attn_0_0_output_scores,
144 | attn_0_1_output_scores,
145 | mlp_0_0_output_scores,
146 | num_mlp_0_0_output_scores,
147 | one_scores,
148 | num_attn_0_0_output_scores,
149 | num_attn_0_1_output_scores,
150 | ]
151 | ]
152 | )
153 | logits = feature_logits.groupby(level=0).sum(numeric_only=True).to_numpy()
154 | classes = classifier_weights.columns.to_numpy()
155 | predictions = classes[logits.argmax(-1)]
156 |     if tokens[0] == "<s>":
157 |         predictions[0] = "<s>"
158 |     if tokens[-1] == "</s>":
159 |         predictions[-1] = "</s>"
160 | return predictions.tolist()
161 |
162 |
163 | examples = [
164 |     (
165 |         ["<s>", "2", "5", "3", "2", "1", "5", "3"],
166 |         ["<s>", "2", "2", "2", "2", "1", "2", "2"],
167 |     ),
168 |     (["<s>", "2", "4", "3", "5", "2", "5"], ["<s>", "2", "1", "1", "2", "2", "2"]),
169 |     (["<s>", "5", "2", "0", "4", "3", "4"], ["<s>", "1", "1", "1", "2", "1", "2"]),
170 |     (
171 |         ["<s>", "5", "3", "4", "1", "2", "5", "2"],
172 |         ["<s>", "2", "1", "1", "1", "2", "2", "2"],
173 |     ),
174 |     (
175 |         ["<s>", "2", "0", "2", "0", "3", "4", "4"],
176 |         ["<s>", "2", "2", "2", "2", "1", "2", "2"],
177 |     ),
178 |     (["<s>", "5", "0", "3", "2", "5"], ["<s>", "2", "1", "1", "1", "2"]),
179 |     (["<s>", "4", "5", "0", "2", "3", "1"], ["<s>", "1", "1", "1", "1", "1", "1"]),
180 |     (["<s>", "2", "5", "5"], ["<s>", "1", "2", "2"]),
181 |     (
182 |         ["<s>", "0", "2", "3", "2", "3", "0", "3"],
183 |         ["<s>", "2", "2", "3", "2", "3", "2", "3"],
184 |     ),
185 |     (
186 |         ["<s>", "2", "1", "1", "2", "3", "3", "4"],
187 |         ["<s>", "2", "2", "2", "2", "2", "2", "1"],
188 |     ),
189 | ]
190 | for x, y in examples:
191 | print(f"x: {x}")
192 | print(f"y: {y}")
193 | y_hat = run(x)
194 | print(f"y_hat: {y_hat}")
195 | print()
196 |
--------------------------------------------------------------------------------
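The hist program above is assembled from a handful of attention primitives. As a quick standalone illustration (not part of the repository; the toy tokens "a"/"b" are hypothetical), the following sketch shows how `select` plus `aggregate_sum` compute a token histogram the way the `num_attn` heads do, and how `select_closest` plus `aggregate` implement a hard-attention readout:

```python
# Sketch of the attention primitives used by the generated programs:
# `select` builds a full boolean attention pattern from a predicate,
# `select_closest` keeps only the nearest match per query position
# (defaulting to position 0), and the aggregate functions read values
# out of a pattern.

def select(keys, queries, predicate):
    return [[predicate(q, k) for k in keys] for q in queries]

def select_closest(keys, queries, predicate):
    scores = [[False for _ in keys] for _ in queries]
    for i, q in enumerate(queries):
        matches = [j for j, k in enumerate(keys) if predicate(q, k)]
        if not matches:
            scores[i][0] = True  # no match: default to position 0
        else:
            j = min(matches, key=lambda j: len(matches) if j == i else abs(i - j))
            scores[i][j] = True
    return scores

def aggregate(attention, values):
    # read out the single value selected in each attention row
    return [[v for a, v in zip(attn, values) if a][0] for attn in attention]

def aggregate_sum(attention, values):
    # sum over all selected values (used by the counting heads)
    return [sum(v for a, v in zip(attn, values) if a) for attn in attention]

tokens = ["a", "b", "b", "a"]  # hypothetical input
positions = list(range(len(tokens)))
ones = [1] * len(tokens)

# A counting ("num") head: each token attends to every equal token,
# so summing ones yields its frequency -- the hist recipe.
pattern = select(tokens, tokens, lambda q, k: q == k)
print(aggregate_sum(pattern, ones))  # [2, 2, 2, 2]

# A hard-attention head: attend to the closest "a" and read its position.
closest = select_closest(tokens, tokens, lambda q, k: k == "a")
print(aggregate(closest, positions))  # [0, 0, 3, 3]
```

Note that `select_closest` deprioritizes a self-match by assigning it cost `len(matches)`, so a query only attends to its own position when no other match is at least as close.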
/programs/rasp/hist/hist_weights.csv:
--------------------------------------------------------------------------------
1 | feature,value,1,2,3,4,5,6
2 | attn_0_0_outputs,0,6.781757,25.452711,10.974162,-1.9199638,-10.4853525,-15.709011
3 | attn_0_0_outputs,1,1.4295937,3.2632368,-0.074068084,-2.4339116,-5.1740317,-3.098548
4 | attn_0_0_outputs,2,1.0853696,2.85033,-0.07023634,-2.7719102,-5.3003054,-3.8455026
5 | attn_0_0_outputs,3,1.5324218,3.3294716,0.415891,-2.2662394,-4.7209177,-3.3629472
6 | attn_0_0_outputs,4,1.0801277,2.88044,0.028874561,-2.579776,-4.811679,-3.6099398
7 | attn_0_0_outputs,5,1.0413198,3.1703262,-0.014626876,-2.2997146,-4.568741,-2.9599016
8 | attn_0_0_outputs,6,0.33147946,2.3486278,0.06572169,-2.0305748,-4.2188454,-2.249484
9 | attn_0_0_outputs,7,0.1623374,2.830696,-0.14196976,-1.2703604,-2.8180695,-0.34874895
10 | attn_0_1_outputs,0,4.262853,13.889445,7.3905044,-0.72937316,-8.37309,-13.166056
11 | attn_0_1_outputs,1,1.2386446,2.9108822,-0.2176869,-2.825408,-5.370257,-2.9436467
12 | attn_0_1_outputs,2,1.7371144,9.852318,4.1092057,-2.8085132,-9.59537,-14.0036745
13 | attn_0_1_outputs,3,0.71794397,3.1448169,-0.023318842,-2.7755547,-5.5583167,-3.5951939
14 | attn_0_1_outputs,4,2.2684147,10.615341,5.2562523,-2.1873116,-9.4488535,-14.030566
15 | attn_0_1_outputs,5,0.5881659,2.4700139,-0.08045172,-2.2570171,-4.1498766,-1.9420767
16 | attn_0_1_outputs,6,0.40211794,1.7297664,-0.050353132,-1.4864924,-2.592189,-0.6177777
17 | attn_0_1_outputs,7,0.4194881,2.907967,0.62582093,-1.6866827,-3.9042215,-1.8792101
18 | mlp_0_0_outputs,0,0.4938747,0.630038,0.23696063,-0.5110521,-0.7338477,-0.6524683
19 | mlp_0_0_outputs,1,0.74656296,0.19177268,-0.31644797,0.058880232,-0.66460264,-1.0663412
20 | mlp_0_0_outputs,2,0.5359965,-0.04372708,-0.7205332,-0.7282277,-0.7255958,-1.1605394
21 | mlp_0_0_outputs,3,4.233813,12.146829,4.1495814,-5.1848955,-11.641673,-16.112806
22 | mlp_0_0_outputs,4,0.20570962,-0.24562106,-0.19681151,-0.36250097,-0.3743104,-0.7396391
23 | mlp_0_0_outputs,5,0.8558908,0.5187872,0.0962603,0.30497876,-0.7910682,-0.49263448
24 | mlp_0_0_outputs,6,0.4841655,1.0118659,0.03025597,-0.36506364,-0.70360786,-0.21639934
25 | mlp_0_0_outputs,7,0.9191022,0.3703126,-0.28247193,-0.29374048,-0.6344728,-0.64356095
26 | num_attn_0_0_outputs,_,-4.868864,-27.139973,-8.78817,5.6797624,15.745415,20.68757
27 | num_attn_0_1_outputs,_,0.9234888,0.40377212,0.19646904,-0.013339977,-0.3078388,-7.9514294
28 | num_mlp_0_0_outputs,0,1.1347574,-0.2206074,-0.33900908,-0.525305,-0.8139806,-0.8801327
29 | num_mlp_0_0_outputs,1,3.9370484,-3.322479,-1.9119565,-0.60652006,-0.30610657,-1.285272
30 | num_mlp_0_0_outputs,2,1.6984854,-0.56601596,-0.68408555,-0.67591697,-0.3258158,-1.005294
31 | num_mlp_0_0_outputs,3,-7.021126,15.839726,6.340541,-3.105211,-10.3643875,-16.01029
32 | num_mlp_0_0_outputs,4,1.4550712,-1.0342308,-0.34201962,-1.1408417,-0.93610877,-0.28879276
33 | num_mlp_0_0_outputs,5,3.7206275,-2.0013118,-2.0068526,-1.2758466,0.019715928,-0.79015535
34 | num_mlp_0_0_outputs,6,18.228642,-20.21013,-3.2542965,-1.0904899,-0.45162725,-0.012080559
35 | num_mlp_0_0_outputs,7,2.5776606,-1.1733114,-1.8273598,-0.85842824,-0.93384117,-0.54284114
36 | ones,_,1.283879,6.372424,2.8256679,-1.4089221,-8.6447,-11.394332
37 | positions,0,0.5897627,0.16051057,0.7623805,0.24794468,0.1367697,0.36979273
38 | positions,1,1.0287309,4.5030293,2.3415422,-1.3173267,-6.8056574,-10.195892
39 | positions,2,1.083402,4.620112,2.3640404,-1.3567489,-6.838617,-10.438745
40 | positions,3,1.0730014,4.6261773,2.4860039,-1.2073922,-6.6122475,-10.18452
41 | positions,4,1.0189941,5.8566985,2.620293,-1.7836683,-7.4319153,-11.364978
42 | positions,5,0.96149987,4.631558,2.526294,-1.1215514,-6.515995,-10.108589
43 | positions,6,0.44539887,5.6350055,2.9353447,-1.8244816,-7.916871,-11.977934
44 | positions,7,1.1673726,6.778532,3.7333648,-0.52524173,-6.0611577,-9.702886
45 | tokens,0,1.217923,4.5158324,2.319617,-1.2275133,-4.5206556,-6.714174
46 | tokens,1,0.4282554,3.2384415,1.0357013,-2.5313835,-5.9769506,-13.27017
47 | tokens,2,1.4819847,5.0116262,2.8555121,-0.65765595,-4.2250724,-5.598877
48 | tokens,3,1.1304332,4.4012556,2.251413,-1.3463753,-4.7049932,-6.437009
49 | tokens,4,0.8615442,3.7823195,1.4800204,-2.1071358,-5.914414,-12.4524975
50 | tokens,5,1.0285466,4.464686,2.3432922,-1.1491289,-4.432571,-6.315694
51 | tokens,,0.83862156,0.89218545,-0.032585267,-0.8908109,-0.610661,-0.34020367
52 | tokens,,0.3811775,0.747888,-0.18793498,-0.6211176,-0.59679526,-0.5371528
53 |
--------------------------------------------------------------------------------
/programs/rasp/most_freq/most_freq_weights.csv:
--------------------------------------------------------------------------------
1 | feature,value,0,1,2,3,4,5,
2 | attn_0_0_outputs,0,0.3180253,0.17051244,0.17996229,0.08486652,-0.050912812,-0.055793837,0.43272197
3 | attn_0_0_outputs,1,-0.096741244,0.7338011,-0.0095024835,0.24725218,-0.05952198,-1.2952329,-4.8797684
4 | attn_0_0_outputs,2,0.17302921,0.5287753,0.077450015,0.43879598,0.12759952,0.024711223,-1.2387171
5 | attn_0_0_outputs,3,-0.020483518,-0.4090684,-0.06907431,-0.019290976,-0.4361626,-0.98895913,1.5400033
6 | attn_0_0_outputs,4,-0.028248059,0.12039765,-0.18920611,-0.07074923,-0.24273163,-0.27256623,0.3491913
7 | attn_0_0_outputs,5,0.24701072,0.64526,0.2259116,0.72207314,-0.093196996,1.9177046,-3.6679096
8 | attn_0_0_outputs,,0.061711412,0.2944613,-0.06865534,-0.5166919,-0.14260983,-0.3558589,0.32799473
9 | attn_0_0_outputs,,0.23678741,-1.63671,0.25388855,-0.14997865,0.11181579,-2.2149167,9.153281
10 | attn_0_1_outputs,0,1.0899844,-0.28106454,-0.85135394,-0.026211262,-0.5590216,0.21339954,0.19288713
11 | attn_0_1_outputs,1,0.092373446,0.24267156,-0.40928265,0.15436336,0.45326802,0.33714578,0.60150456
12 | attn_0_1_outputs,2,-0.03783019,-0.033535108,3.3072128,-1.0111456,0.71332467,-0.11776653,-6.409337
13 | attn_0_1_outputs,3,0.23245741,0.1589569,-0.85276276,1.2885551,0.2456359,-1.3701653,0.6277164
14 | attn_0_1_outputs,4,0.2582553,0.33029148,-0.46015385,0.35623923,0.8750041,-1.0261837,1.0965881
15 | attn_0_1_outputs,5,0.112118945,0.1932354,-0.35270444,0.3343012,0.33635232,1.7447497,0.049526434
16 | attn_0_1_outputs,,0.25605395,0.05547662,0.16247602,-0.14924651,0.14146863,-0.53788525,0.24325536
17 | attn_0_1_outputs,,0.29997656,-0.10898747,-0.51443803,0.098860115,-0.02304782,0.09178379,0.74192613
18 | attn_0_2_outputs,0,0.27396625,0.2443406,-0.14687176,0.19030151,0.3444485,-0.0885761,0.026525198
19 | attn_0_2_outputs,1,-0.072723806,-0.6435364,0.2422663,-0.5205403,-0.01871688,-0.12267997,-0.45468605
20 | attn_0_2_outputs,2,-0.0013228728,0.30712268,-0.5249838,0.077301025,0.1839496,-0.0076866625,-0.754529
21 | attn_0_2_outputs,3,-0.10197616,-0.22210558,-0.057100624,-0.5848149,-0.028931309,-0.16512564,-0.57091314
22 | attn_0_2_outputs,4,0.25916776,0.40352133,0.33918947,0.13571715,-0.34599137,0.010477799,-0.6685143
23 | attn_0_2_outputs,5,0.6194994,0.941407,0.93209845,0.742026,0.68645924,0.18754262,-1.0428277
24 | attn_0_2_outputs,,0.14436923,-0.95419455,0.082035325,-0.03649257,-0.070194826,-0.072033234,0.5734176
25 | attn_0_2_outputs,,-0.30024177,0.7607851,-0.07830339,0.13043396,-0.15203752,0.47434998,1.0684925
26 | attn_0_3_outputs,0,0.4359758,0.39296976,0.3107209,0.27862722,0.4337573,-0.5773376,0.78378
27 | attn_0_3_outputs,1,0.28894436,2.4578593,0.23578313,0.49545267,0.65666926,-2.2483876,-4.0735903
28 | attn_0_3_outputs,2,0.09834935,0.06901251,-0.50735116,0.30153948,0.31839263,-0.32487082,-0.114381574
29 | attn_0_3_outputs,3,0.30135068,-1.0148153,0.27192265,-0.124534786,0.7126381,-2.1536736,2.5469096
30 | attn_0_3_outputs,4,0.30178407,0.8716293,0.07430823,0.7604845,0.55651057,-1.229208,0.050405312
31 | attn_0_3_outputs,5,-0.13249332,-1.1359549,-0.13182358,0.3849195,-0.26988852,3.4882991,-7.758176
32 | attn_0_3_outputs,,0.38542554,-0.20650521,0.45955607,-0.4878401,-0.031885352,-0.50210035,0.16015093
33 | attn_0_3_outputs,,0.28155896,-0.31953847,0.32690638,0.54727995,0.70288324,-2.5177412,3.5819783
34 | attn_1_0_outputs,0,2.0141475,0.34906653,0.68659526,0.9869916,-0.13294022,1.8037857,-10.206428
35 | attn_1_0_outputs,1,-0.19686982,0.73792344,-0.12540162,-1.3587084,0.16657266,1.8406917,1.5816158
36 | attn_1_0_outputs,2,-0.226673,-0.26789877,0.18990181,0.16948664,0.52418464,0.2592063,1.2878286
37 | attn_1_0_outputs,3,-0.83169764,3.5384831,-0.4879805,-1.1000174,-0.6881626,-0.77432674,-13.3651
38 | attn_1_0_outputs,4,0.15082398,-0.14441168,-0.005273094,0.23471557,0.20179209,0.5158661,0.38174176
39 | attn_1_0_outputs,5,0.13206911,-1.4834553,0.21140684,0.68273836,0.37051022,0.6227547,-2.3825295
40 | attn_1_0_outputs,6,-0.46533993,-0.44353464,-0.72359395,-0.74037266,-0.42077038,-0.08801137,3.6209009
41 | attn_1_0_outputs,7,-0.45181456,-10.881032,-0.46400476,0.98205274,-0.11754535,0.46539927,11.495142
42 | attn_1_1_outputs,0,0.9750074,-0.33286405,-0.09389574,-0.04933356,0.94795406,0.24440528,0.02822969
43 | attn_1_1_outputs,1,0.09181721,-0.108542785,-0.004596466,0.09085039,-0.0960271,-0.2939083,-1.4847285
44 | attn_1_1_outputs,2,0.13675968,-0.14381808,-0.14170228,1.3159441,0.84011686,0.187356,-0.2180252
45 | attn_1_1_outputs,3,0.03167638,-0.2564188,0.10374082,-0.030116782,0.37519884,-0.17905286,0.4855987
46 | attn_1_1_outputs,4,0.3155073,-0.27092186,-0.030149449,0.4447328,0.26189294,0.13343854,-0.2836621
47 | attn_1_1_outputs,5,0.79908526,0.33884925,0.5204918,0.563897,0.31888938,0.24198307,-3.097063
48 | attn_1_1_outputs,,0.12040514,-0.3541742,-0.37897912,0.33264196,-0.1644499,0.09382361,0.47280157
49 | attn_1_1_outputs,,-0.21088882,-0.25195375,-0.38574246,-0.26290575,-0.38151267,-0.6612043,3.2316494
50 | attn_1_2_outputs,0,1.4660153,-0.07058339,0.15579924,0.22429287,0.015527983,-0.008877302,-0.4786283
51 | attn_1_2_outputs,1,-0.32045385,0.7251625,-0.16099809,-0.6113556,-0.20735613,-0.34426436,0.03460376
52 | attn_1_2_outputs,2,-0.05991355,-0.26121002,1.5285293,-0.1514813,0.13110724,-0.5427844,-0.37770408
53 | attn_1_2_outputs,3,-0.17418517,-0.5064767,0.020557076,1.8331826,-0.19199504,-0.5358481,-0.2174345
54 | attn_1_2_outputs,4,-0.33506688,-0.4865637,-0.19673166,-0.35718244,1.8005149,-0.7867146,-0.5104927
55 | attn_1_2_outputs,5,0.1529805,-0.11731564,0.11252774,0.09463248,0.026155643,1.275788,-0.39444077
56 | attn_1_2_outputs,,-0.03210811,-0.1851439,-0.52887833,-0.252833,-0.6852915,-0.4575861,0.27375743
57 | attn_1_2_outputs,,-0.03430497,-0.044254147,-0.15155597,0.057937745,-0.022528501,-0.033426747,0.46648577
58 | attn_1_3_outputs,0,0.35898498,-0.20034088,0.020413859,-0.22355627,-10.102141,-0.33305383,7.6976376
59 | attn_1_3_outputs,1,1.0111092,-0.062783584,0.5327143,0.87491333,1.3438202,1.1383574,-2.6126447
60 | attn_1_3_outputs,2,0.081306025,-0.57257915,-0.17199379,0.23137064,1.004636,0.7501657,-3.3402092
61 | attn_1_3_outputs,3,0.4108505,-0.08137371,0.49178204,1.0274754,-0.049497884,1.4293927,-3.6926756
62 | attn_1_3_outputs,4,0.032954328,-0.26930884,0.1631624,0.6166075,-0.80901116,0.41795555,-0.96519506
63 | attn_1_3_outputs,5,0.14050587,-0.043274768,0.389469,0.94465864,0.34980577,1.3579348,-2.1161852
64 | attn_1_3_outputs,6,0.44501925,0.13667245,0.6853415,0.5725537,2.0187433,0.6294245,-5.953397
65 | attn_1_3_outputs,7,-0.46666646,0.5990393,0.025998702,0.41633928,0.02385631,1.1114172,-1.6757435
66 | attn_2_0_outputs,0,0.3495429,0.30337927,0.49805364,0.13225693,-0.2870291,0.27936482,-0.6221486
67 | attn_2_0_outputs,1,0.27308455,-0.14491963,-0.059935313,0.3436768,-0.5626455,0.080760516,-0.67031586
68 | attn_2_0_outputs,2,-0.11247145,0.52780163,-0.031945467,0.16022672,-0.08858324,-0.13487771,1.2465036
69 | attn_2_0_outputs,3,0.4697936,0.82808137,0.5196833,0.3312025,0.35191402,-0.13352305,0.16811055
70 | attn_2_0_outputs,4,0.58132875,0.8832289,0.71233314,0.5913085,0.2582655,0.5411773,-0.13924693
71 | attn_2_0_outputs,5,-0.5324829,-0.54294145,-0.47411755,-0.559129,-0.962662,-0.80526704,0.70581704
72 | attn_2_0_outputs,6,0.29960564,0.5047627,0.42610115,0.47605318,0.055260766,0.2752538,0.2869074
73 | attn_2_0_outputs,7,0.8061134,1.0037099,0.43810177,0.50335884,0.08938471,0.5953439,-0.61285174
74 | attn_2_1_outputs,0,0.35270405,-0.050330415,0.35786822,0.8349852,0.43520316,0.37781978,-1.8966578
75 | attn_2_1_outputs,1,-0.5173535,-1.0113736,-0.77803063,-12.798715,-0.7180023,-0.73949975,13.488985
76 | attn_2_1_outputs,2,0.036005918,-0.15727578,0.045522507,0.9319064,-0.048925404,-0.14085157,-1.1990162
77 | attn_2_1_outputs,3,1.0115867,0.43455648,0.93790376,1.9610963,0.88474375,0.71575737,-6.8169203
78 | attn_2_1_outputs,4,0.77685916,0.17810385,0.14974725,0.837453,-0.38813585,-0.10558918,-0.19106685
79 | attn_2_1_outputs,5,0.5566551,-0.01609987,0.4755912,1.9255587,0.53078216,0.33371648,-6.7548733
80 | attn_2_1_outputs,6,0.12431612,-0.3542981,0.30530918,1.5669754,0.23087248,0.45061296,-1.3779991
81 | attn_2_1_outputs,7,0.22629534,-0.6761686,0.6213678,1.038082,0.306376,-0.59195256,-0.19801751
82 | attn_2_2_outputs,0,-0.13583189,-0.03280338,-0.19608569,0.08558506,-3.1654773,0.06910204,1.4312361
83 | attn_2_2_outputs,1,-0.38239425,0.05366795,0.022488022,-0.042790532,-0.0076938574,-0.35101935,2.496905
84 | attn_2_2_outputs,2,0.4532444,0.4999768,0.5216906,0.4312385,0.6598157,0.57998353,-3.2020876
85 | attn_2_2_outputs,3,0.69540215,-0.018808544,0.92008996,0.85889035,1.0399714,0.9349368,-1.2531354
86 | attn_2_2_outputs,4,0.32867137,-0.498774,0.3323383,0.422421,0.10450748,0.53979826,-0.23186198
87 | attn_2_2_outputs,5,0.30699706,-0.16054757,0.245183,0.09240096,0.3129589,-0.011169057,-1.3021853
88 | attn_2_2_outputs,6,0.4382863,0.8850737,0.5132612,0.5126867,0.6615926,0.052526668,-1.1747515
89 | attn_2_2_outputs,7,0.798873,-3.1653268,0.8079976,0.46534437,0.99906564,-0.27384913,0.41713637
90 | attn_2_3_outputs,0,-0.0595329,-0.3432393,0.30217978,0.16495714,0.19219992,0.17431322,1.4355792
91 | attn_2_3_outputs,1,0.12711096,-1.4225363,0.22368737,-0.027527437,0.09366445,-0.040307615,0.01481341
92 | attn_2_3_outputs,2,0.58099073,0.3474417,-0.41153923,0.79462695,0.9205741,0.8575549,-1.4096668
93 | attn_2_3_outputs,3,-0.042766575,-0.65066993,0.20429829,-0.9746706,-0.12820214,-0.34311306,-0.5651443
94 | attn_2_3_outputs,4,0.4263014,0.39161858,0.58194137,0.5913219,-0.66568244,-0.35783538,0.029537905
95 | attn_2_3_outputs,5,0.7230335,0.28817055,0.8745937,0.90366805,0.88130885,-0.37258768,-0.67117524
96 | attn_2_3_outputs,,-0.2462671,-0.82254106,-0.20106415,-0.16820565,0.0874213,-0.04618327,-0.017906675
97 | attn_2_3_outputs,,0.1876988,0.5234487,0.4194045,0.589009,0.3397852,-0.13758777,-2.7441623
98 | mlp_0_0_outputs,0,-0.009067974,-0.10616362,0.5120841,0.30963194,0.2905615,0.5171235,1.3900716
99 | mlp_0_0_outputs,1,-0.18765688,0.7436373,0.18984891,0.51863587,0.27437514,-2.8250957,2.2061565
100 | mlp_0_0_outputs,2,0.1514967,-0.24174538,0.40851867,0.06300549,0.45437306,1.9693145,-1.5879657
101 | mlp_0_0_outputs,3,-0.074622594,2.0452924,0.074981324,2.5786343,0.48228234,-0.98552316,-1.702283
102 | mlp_0_0_outputs,4,0.87153715,0.71818954,1.6129097,1.2832383,1.3467029,1.3728735,-4.159039
103 | mlp_0_0_outputs,5,0.25284207,-0.92589784,0.52295667,0.26822376,0.86385244,0.7218063,0.9255813
104 | mlp_0_0_outputs,6,-2.2069793,-1.6042551,-1.9572843,-1.4469824,-1.7742672,-1.4493668,5.8365097
105 | mlp_0_0_outputs,7,0.6310457,-2.2726629,0.96292347,0.3914378,0.8041782,1.9205551,-7.1911926
106 | mlp_0_1_outputs,0,0.2533135,2.9234807,0.8514278,0.9629272,0.9832052,0.7691257,-2.316475
107 | mlp_0_1_outputs,1,0.28654918,0.3766166,1.2217463,1.2102123,0.11371473,1.5548589,-3.2720823
108 | mlp_0_1_outputs,2,1.2309315,-0.3645303,0.39764106,0.63220364,0.13651253,0.668529,-3.4863636
109 | mlp_0_1_outputs,3,-1.8739396,-2.2920024,-2.590256,-1.8237208,-2.8227699,-2.6384785,6.279832
110 | mlp_0_1_outputs,4,3.6876507,1.1458737,1.7822194,2.422588,4.042542,1.0609845,-10.931403
111 | mlp_0_1_outputs,5,0.87431866,1.2460302,-0.2996694,1.574731,-0.5186607,-0.72277695,-4.7037625
112 | mlp_0_1_outputs,6,-3.0101469,-0.37742805,-3.9014912,-3.186129,-4.1218615,-3.49612,8.171139
113 | mlp_0_1_outputs,7,2.8105667,0.41083074,2.3087406,1.4794773,0.5835391,0.5765365,-7.642529
114 | mlp_1_0_outputs,0,0.10424284,0.42691907,0.26822764,0.72388,0.4306211,0.3545007,-1.45301
115 | mlp_1_0_outputs,1,-0.29419023,0.08807841,-0.19642091,0.2297624,0.11144069,0.22908655,0.012775027
116 | mlp_1_0_outputs,2,0.033752218,0.5904094,0.016924301,0.65878856,0.552538,0.40060502,-0.8275471
117 | mlp_1_0_outputs,3,0.03518476,0.5058186,0.004764182,0.65987307,0.27559388,0.34288618,0.8714135
118 | mlp_1_0_outputs,4,-0.33149403,0.07573308,-0.010276801,0.5206794,0.03746909,0.11693972,-0.54683185
119 | mlp_1_0_outputs,5,-0.02862338,0.37228417,0.2534814,0.36578193,0.43761313,0.74797416,0.6001068
120 | mlp_1_0_outputs,6,0.67046475,0.42168856,0.17522931,0.9923386,0.42843175,0.6051585,-0.6304327
121 | mlp_1_0_outputs,7,-0.6101541,0.12701088,0.45372865,0.8725,0.27473047,-0.016846292,-0.33390883
122 | mlp_1_1_outputs,0,0.013466556,-0.31899878,-0.0581283,0.12199419,-0.28192335,0.0019597644,0.32551047
123 | mlp_1_1_outputs,1,0.86517906,0.99981546,1.1384677,0.9402987,0.7827801,0.7498811,-7.286347
124 | mlp_1_1_outputs,2,-1.2548337,-1.2306365,-1.2297429,-0.7032778,-0.9226823,-0.9777365,3.206642
125 | mlp_1_1_outputs,3,0.3976817,-0.037045293,0.17781861,0.5145008,-0.060972173,-0.15903838,0.661624
126 | mlp_1_1_outputs,4,0.33862308,-0.02482277,0.22182596,0.64739317,-0.026410017,0.43479386,1.3172387
127 | mlp_1_1_outputs,5,-0.15095389,-0.48228294,-0.20532566,0.103111684,-0.37304664,-0.33961606,0.4904485
128 | mlp_1_1_outputs,6,0.77520096,1.0291071,0.8356842,0.6138105,0.508028,0.81939375,-6.3209095
129 | mlp_1_1_outputs,7,0.35298768,0.14868644,0.1929121,0.8483676,0.13015799,0.1845815,-0.6845896
130 | mlp_2_0_outputs,0,0.22749007,0.010398724,-0.7143261,0.19049524,0.40374222,-0.18089859,1.913285
131 | mlp_2_0_outputs,1,0.43453947,0.48566645,1.2921165,0.27306423,0.7190616,0.64280427,-9.392651
132 | mlp_2_0_outputs,2,-0.010374933,0.6695905,0.8829944,0.09739681,0.1910442,-0.55804473,-0.24017237
133 | mlp_2_0_outputs,3,0.21176498,-0.6259095,1.0027347,-0.2894584,0.15444903,0.39754984,-0.62771827
134 | mlp_2_0_outputs,4,0.35422748,-0.03349688,1.0252547,0.1733238,0.36093882,-0.23943187,0.3938007
135 | mlp_2_0_outputs,5,-1.4427056,0.19055358,0.7897258,0.31336093,-0.007993803,0.31204116,-1.0540504
136 | mlp_2_0_outputs,6,0.66334516,-0.1595391,1.3743109,0.043930616,0.16582729,-0.17752126,-1.0506142
137 | mlp_2_0_outputs,7,-0.81803083,-1.0996387,-1.3154311,-1.6927465,-1.0877206,-1.2953286,6.2696586
138 | mlp_2_1_outputs,0,0.20822705,0.8802053,0.08234023,0.29908177,0.044410724,0.18044509,-1.9308993
139 | mlp_2_1_outputs,1,-0.45252222,0.64726084,-0.39025173,-0.12514322,-0.5468172,-0.4661018,0.12722749
140 | mlp_2_1_outputs,2,0.3255769,0.23384853,-0.41722363,0.724245,0.47933605,0.16578805,-0.4978126
141 | mlp_2_1_outputs,3,-0.18909135,0.21740477,-0.1728526,0.08816649,0.33721173,0.0031366053,0.33242244
142 | mlp_2_1_outputs,4,0.31890598,1.205369,0.26985618,0.43814164,0.012860747,0.028930334,0.41829804
143 | mlp_2_1_outputs,5,0.26172248,1.438389,0.345994,0.60735875,0.22993265,0.39017054,-2.027972
144 | mlp_2_1_outputs,6,0.5005941,0.60231215,0.81867206,0.4571476,0.137451,-0.1562085,-0.7049604
145 | mlp_2_1_outputs,7,0.73108,2.107512,0.88618535,0.81509197,0.54469866,0.5268218,-2.7247245
146 | num_attn_0_0_outputs,_,0.49675706,0.76701194,0.8319468,1.0188137,-3.46674,1.3313375,1.2141222
147 | num_attn_0_1_outputs,_,-4.705128,0.20360513,0.05153054,0.18922624,2.8429217,0.7594758,0.08416794
148 | num_attn_0_2_outputs,_,0.32272223,0.26861843,0.27784988,-0.007126636,0.30486012,-0.51013726,0.45339686
149 | num_attn_0_3_outputs,_,0.35131818,0.15362968,-0.026241664,0.024230676,0.7753488,-0.36277744,0.23196767
150 | num_attn_1_0_outputs,_,-0.17899519,-0.22110243,1.955975,-0.33664736,-0.18993665,-0.42071265,-3.2977908
151 | num_attn_1_1_outputs,_,0.096573666,0.041719396,0.3011042,-0.7403994,0.24944292,0.35339233,0.40832838
152 | num_attn_1_2_outputs,_,0.307874,0.77904606,-3.6890242,0.57999134,0.07966587,1.0292304,2.588211
153 | num_attn_1_3_outputs,_,0.1556463,-0.88464093,-0.009464315,1.2894713,-0.006048728,0.12710084,0.042141438
154 | num_attn_2_0_outputs,_,3.2238479,-0.39403713,-0.5007616,-0.38034695,-1.5038663,-0.5157612,-0.4220676
155 | num_attn_2_1_outputs,_,0.09123767,0.06927917,0.071765244,0.06278864,0.03571881,-1.0034105,0.22208074
156 | num_attn_2_2_outputs,_,0.49392816,0.30212557,0.5877162,0.3936646,0.7463011,0.91613317,-0.27264255
157 | num_attn_2_3_outputs,_,0.59951186,1.5315337,0.45857117,-2.2965033,0.47967055,0.9584901,-0.25912756
158 | num_mlp_0_0_outputs,0,0.75295377,0.22206154,0.4430251,0.43609834,0.3002954,0.14143573,-0.9965674
159 | num_mlp_0_0_outputs,1,0.13362123,-0.27709126,0.11786564,0.0842753,0.057983357,-0.20576423,-0.17937659
160 | num_mlp_0_0_outputs,2,0.77560294,0.23398742,0.38785017,0.09531838,0.36529553,0.0049865237,-2.6353512
161 | num_mlp_0_0_outputs,3,0.67882055,0.47041214,0.49682024,0.22344151,0.27797723,0.384288,-0.78773314
162 | num_mlp_0_0_outputs,4,0.36207622,-0.12504679,0.15990801,0.038606238,-0.029018853,-0.017407462,-0.3937458
163 | num_mlp_0_0_outputs,5,-5.4784355,-0.21962734,-0.05829186,0.06191359,0.76395345,-0.15172787,5.640328
164 | num_mlp_0_0_outputs,6,0.6987904,0.19056033,0.45902565,0.0974688,0.20158903,0.2252314,-0.5025556
165 | num_mlp_0_0_outputs,7,-0.66295034,0.33240935,0.7547661,0.07789776,-1.1808933,0.11714694,-1.6093485
166 | num_mlp_0_1_outputs,0,-0.4456775,-0.3319486,0.259621,0.04547681,-3.9428256,0.30970556,3.493646
167 | num_mlp_0_1_outputs,1,-0.060211174,0.5636933,0.063009106,0.058421217,1.4696409,-0.012732882,-3.97732
168 | num_mlp_0_1_outputs,2,-0.112916395,0.7270722,0.22726461,0.24413253,0.8066552,0.011055177,-0.95724344
169 | num_mlp_0_1_outputs,3,1.633666,0.42186847,0.36930156,-0.119241916,0.40250668,-0.55435634,-3.3496842
170 | num_mlp_0_1_outputs,4,-0.4858897,0.43768635,-0.16505812,-0.17672735,0.9722587,-0.056002922,-1.3548676
171 | num_mlp_0_1_outputs,5,-0.097139664,0.3772895,0.27854103,0.25372255,0.8777077,0.24463934,-0.93909496
172 | num_mlp_0_1_outputs,6,0.33189073,0.012857659,-0.14545752,0.1099122,1.541078,0.07001639,-0.91754353
173 | num_mlp_0_1_outputs,7,0.41338253,0.3731902,-0.07001816,0.79229236,-0.9019284,0.11754089,2.5209103
174 | num_mlp_1_0_outputs,0,0.24915189,-0.13592264,0.28131127,0.31066325,0.21636605,0.18062419,-1.4040216
175 | num_mlp_1_0_outputs,1,0.1564762,0.08252395,-5.7924643,-0.192247,0.33157474,-0.009017468,2.942537
176 | num_mlp_1_0_outputs,2,0.13001019,-0.31609473,0.84513766,0.34132242,0.5279612,0.41859335,-6.285437
177 | num_mlp_1_0_outputs,3,0.4110137,-0.054942466,-0.1518215,-0.08841369,0.3400565,0.11748088,-0.1709277
178 | num_mlp_1_0_outputs,4,0.23019612,-0.12827308,-0.05825717,-0.0017411346,-0.12765515,-0.08880548,-0.9720372
179 | num_mlp_1_0_outputs,5,0.6083281,-0.21133165,-1.5179089,0.7018768,-0.3144238,0.19359513,0.51440924
180 | num_mlp_1_0_outputs,6,0.01387636,-0.01365367,-0.14135487,0.027097523,0.001307688,0.11765708,0.1378213
181 | num_mlp_1_0_outputs,7,0.067421116,0.3388032,-0.16557097,-0.12955894,-0.19303417,0.02512272,0.18632148
182 | num_mlp_1_1_outputs,0,0.3679926,0.43395373,0.39618903,0.30078262,0.29004973,0.1358351,0.06785458
183 | num_mlp_1_1_outputs,1,0.20307627,0.16571133,-0.034833442,0.056774694,-0.28404135,-0.0896154,-0.32265937
184 | num_mlp_1_1_outputs,2,0.13026774,0.061297063,0.06480035,0.068802625,-0.009410563,-0.18956332,-0.45669928
185 | num_mlp_1_1_outputs,3,0.23407696,0.14676198,0.011613104,-0.020378916,-0.018320754,-0.2158755,-0.41205496
186 | num_mlp_1_1_outputs,4,0.51550543,0.19345349,-0.29995015,-0.21714196,-0.05119475,-0.11259293,-0.39577466
187 | num_mlp_1_1_outputs,5,0.7264443,0.28373966,0.18676305,0.5019726,0.27477577,-0.22090729,-0.0033221303
188 | num_mlp_1_1_outputs,6,0.09651393,0.122992426,-0.09937113,-0.0116935,-0.07023688,-0.034174327,-0.34691346
189 | num_mlp_1_1_outputs,7,0.4628552,-0.038354713,0.18961802,0.0764211,-0.18719746,-0.06302883,-0.23135951
190 | num_mlp_2_0_outputs,0,1.3870856,-0.25443602,0.1545994,0.07868433,1.3484426,0.5581853,-1.5508221
191 | num_mlp_2_0_outputs,1,-0.6188547,-0.28784364,0.18100783,-0.7676343,-0.15114093,0.17485546,-0.16489339
192 | num_mlp_2_0_outputs,2,1.650306,0.25490463,0.73042595,0.5314448,0.4455995,0.46853003,-4.48665
193 | num_mlp_2_0_outputs,3,-0.53009546,-0.12033514,-0.21350926,-0.053965583,-0.014094976,0.3029367,0.16062967
194 | num_mlp_2_0_outputs,4,-0.35834578,-0.14751779,-0.32184893,-0.22511242,-0.2735182,0.02457167,-0.18555361
195 | num_mlp_2_0_outputs,5,-0.36402833,0.74258685,0.9334475,1.3147615,0.6026904,0.7450087,-0.85253215
196 | num_mlp_2_0_outputs,6,0.43582806,-0.7167769,-0.2027412,-0.5209631,-1.3333701,-0.1770416,-4.8763123
197 | num_mlp_2_0_outputs,7,-5.2162914,-1.7618054,-1.4699253,-1.6181387,-0.6872477,-1.2078984,8.595443
198 | num_mlp_2_1_outputs,0,0.38339245,0.74561024,1.8125025,0.34722832,0.57123804,0.35547578,-5.5821314
199 | num_mlp_2_1_outputs,1,0.15095738,0.10469516,1.0636414,0.34179947,0.79633814,0.42729956,-3.5103517
200 | num_mlp_2_1_outputs,2,-0.47540858,-0.09333152,-0.9867091,-0.23736648,-0.22757289,0.15208386,0.045130484
201 | num_mlp_2_1_outputs,3,-0.4281402,-0.18589725,-1.0558151,-0.16013376,0.048908155,-0.041572504,0.64169294
202 | num_mlp_2_1_outputs,4,-0.4590558,-0.4067188,-1.0003194,0.0051985644,0.21098104,-0.287274,0.3842943
203 | num_mlp_2_1_outputs,5,-0.07419524,0.20917611,-0.29637602,0.33867353,0.06768416,0.12848133,-0.05852556
204 | num_mlp_2_1_outputs,6,-1.1355742,-0.85389644,-3.9056017,-0.29252398,-0.52056974,-0.3573991,7.055144
205 | num_mlp_2_1_outputs,7,0.12410815,0.33784002,0.16031556,0.3937437,0.3059639,0.27686897,0.08846511
206 | ones,_,0.021529218,-0.15176989,0.2739342,0.42407322,0.85170376,0.6750486,0.039232593
207 | positions,0,-0.56147146,-0.12002305,0.2413873,0.1853156,-0.5742414,0.08497386,0.07996443
208 | positions,1,3.0675466,3.2321472,4.122518,0.2724716,3.5735428,3.0197644,-16.798847
209 | positions,2,2.4287295,1.9609035,1.7737118,3.7877054,1.7662969,1.9576112,-6.0926604
210 | positions,3,2.299956,-0.8758886,2.310295,2.8263414,1.6215533,2.39478,-6.9706354
211 | positions,4,-1.049784,1.8572471,-0.75735337,-1.4148265,-1.0481493,-1.4053416,2.2860672
212 | positions,5,-3.3229315,-1.5263214,-2.91694,-3.5551183,-4.0157175,-3.1399124,5.1658487
213 | positions,6,-7.9361725,-8.840564,-7.397052,-9.724115,-9.180236,-3.3155594,10.632461
214 | positions,7,-7.9004083,-13.131187,-9.459313,-10.873171,-9.7958555,-8.555323,11.622234
215 | tokens,0,0.7939464,0.02564311,0.10964918,0.23205447,-0.4861161,-0.0044119004,-0.5045989
216 | tokens,1,0.13817202,0.7952534,-0.108660065,0.011065895,-0.3787378,-0.07682034,-0.5505967
217 | tokens,2,0.21265996,0.10762084,0.8517744,0.086404786,0.06522274,-0.738799,-0.5161016
218 | tokens,3,0.009029636,-0.1933909,-0.23239014,1.1473926,-0.46348011,0.24348621,-0.5114875
219 | tokens,4,0.4256177,-0.043745376,-0.28037384,-0.1920835,1.4755142,-0.97692513,-0.5098483
220 | tokens,5,-0.042473715,-0.19566593,-0.2469844,0.31330782,-0.5277211,1.0684317,-0.44248924
221 | tokens,,-0.23496482,0.3631881,0.43677837,0.011330187,-0.40754265,-0.8863597,0.15948972
222 | tokens,,-0.056050863,0.29581282,-0.1711932,0.0840109,0.20864172,0.31114328,0.17907692
223 |
--------------------------------------------------------------------------------
/programs/rasp/reverse/reverse_weights.csv:
--------------------------------------------------------------------------------
1 | feature,value,0,1,2,3,4
2 | attn_0_0_outputs,0,2.742066,-1.7750658,-1.5919067,-1.4220088,-0.5447206
3 | attn_0_0_outputs,1,-1.7905326,3.2459338,-1.0251777,-0.35510015,-1.285745
4 | attn_0_0_outputs,2,-1.3913494,0.42532334,3.6383674,-1.5303861,-0.04807833
5 | attn_0_0_outputs,3,0.49622118,-0.41630313,-1.4269036,3.259316,-1.6067595
6 | attn_0_0_outputs,4,-0.6003677,-0.21356046,-2.0015326,-0.8214214,2.7384455
7 | attn_0_0_outputs,,-0.21139087,-0.061902717,2.2465034,1.0147868,-1.6599561
8 | attn_0_0_outputs,,0.92124563,0.5974984,0.154064,0.34046212,0.059019208
9 | attn_0_0_outputs,,0.24552345,0.39865023,0.20082256,-0.15461971,0.67986333
10 | attn_0_1_outputs,0,7.673302,-3.501045,-5.5389047,-4.3190165,-1.769243
11 | attn_0_1_outputs,1,-1.9890903,7.9207335,-3.289247,-3.0845697,-2.3067732
12 | attn_0_1_outputs,2,-4.705924,-3.0480518,7.011647,-0.6278108,-3.4870532
13 | attn_0_1_outputs,3,-3.1684492,-2.9271348,-2.6957874,7.3821106,-2.4872382
14 | attn_0_1_outputs,4,-1.5438712,-3.585505,-2.7780457,-3.6590195,6.381682
15 | attn_0_1_outputs,,2.4767106,-0.34512123,-2.5959194,1.4119138,-3.1488388
16 | attn_0_1_outputs,,0.1414338,-0.32987177,-0.14424743,0.121554226,0.069688104
17 | attn_0_1_outputs,,-0.67165095,-0.29987463,0.9411148,-1.1158872,0.9772637
18 | attn_0_2_outputs,0,2.8864198,0.5736399,-1.7944727,1.275154,-3.7867603
19 | attn_0_2_outputs,1,1.286558,3.8100312,0.2034576,-4.1266513,-2.4538345
20 | attn_0_2_outputs,2,-3.55232,-2.6151497,2.8218007,0.27975446,0.9032074
21 | attn_0_2_outputs,3,2.0676212,-1.8931607,-1.0893198,2.6283581,-3.3388352
22 | attn_0_2_outputs,4,-2.601808,1.4966966,-5.524055,-3.3436193,6.786246
23 | attn_0_2_outputs,,-1.2964387,-1.184649,3.191794,0.26659352,-0.07239697
24 | attn_0_2_outputs,,-0.00678445,-0.12167784,-0.28402784,-0.1283455,0.53786683
25 | attn_0_2_outputs,,0.46426013,-0.44629773,0.791248,0.34926856,-0.41215342
26 | attn_0_3_outputs,0,-0.1498222,1.2560209,0.19964494,0.6762862,-1.121879
27 | attn_0_3_outputs,1,0.64040285,1.9587319,1.4074128,-3.5538082,0.9837679
28 | attn_0_3_outputs,2,-1.5384316,-0.65306073,2.1958265,0.9510098,-3.2979362
29 | attn_0_3_outputs,3,-0.81305355,-1.5002078,0.020203998,0.63694465,0.60759926
30 | attn_0_3_outputs,4,-1.5025858,1.7346878,-1.3115505,1.4721668,1.8840551
31 | attn_0_3_outputs,,4.6110573,0.39784852,-4.7847724,0.5624679,-2.1006687
32 | attn_0_3_outputs,,-0.4865959,-0.75623137,-0.559213,-0.6034367,-0.19291799
33 | attn_0_3_outputs,,-0.6616858,-0.70004207,1.4663161,-2.5758874,0.58236104
34 | attn_1_0_outputs,0,5.14012,-2.0321715,-3.1873803,-2.161212,-2.924637
35 | attn_1_0_outputs,1,-1.1760738,6.243624,-2.292563,-2.3656707,-1.8633323
36 | attn_1_0_outputs,2,-1.8438314,-2.2223895,7.2542777,-1.7975602,-2.8767316
37 | attn_1_0_outputs,3,-1.4480008,-2.8605947,-2.552013,6.218815,-2.0353441
38 | attn_1_0_outputs,4,-1.8545668,-2.4011686,-2.4577968,-0.8337147,5.6437707
39 | attn_1_0_outputs,,1.4606192,-0.65165573,-0.96397144,1.39993,-2.030799
40 | attn_1_0_outputs,,0.22300358,-0.5646053,-0.41895154,0.40009347,-0.45633748
41 | attn_1_0_outputs,,-0.5169577,-0.574157,0.06152434,-0.45967716,0.7860707
42 | attn_1_1_outputs,0,-1.729071,0.3904098,1.063312,0.9541465,0.64347845
43 | attn_1_1_outputs,1,0.14294437,0.5665816,1.1513399,1.1456755,-0.45402047
44 | attn_1_1_outputs,2,-0.47264585,-1.2265207,0.73496914,-0.27962002,1.5110021
45 | attn_1_1_outputs,3,0.55221164,0.74901503,-0.2234499,0.46550083,0.13538252
46 | attn_1_1_outputs,4,-0.887349,-2.3832846,-0.090858944,-0.27379212,2.0423574
47 | attn_1_1_outputs,,0.44976777,0.99956155,-0.25536492,0.0694434,-0.45211685
48 | attn_1_1_outputs,,0.30245078,0.1915722,-0.34525523,0.6264211,0.28215456
49 | attn_1_1_outputs,,0.64166284,-0.9103854,0.21581404,-0.1512637,-0.53048915
50 | attn_1_2_outputs,0,-6.9120216,2.9334385,3.5168996,1.5317044,3.1637955
51 | attn_1_2_outputs,1,1.3649912,-8.357152,1.0929214,2.8909745,2.2582288
52 | attn_1_2_outputs,2,1.6044966,2.746573,-9.02028,2.4642735,2.3857315
53 | attn_1_2_outputs,3,1.3656769,3.410438,1.7813203,-7.0019917,2.6581526
54 | attn_1_2_outputs,4,1.7743725,2.7270765,2.717991,2.2948518,-7.542405
55 | attn_1_2_outputs,,-0.2602812,-0.3624655,1.7670982,-0.119142905,-1.4537865
56 | attn_1_2_outputs,,-0.4052089,-0.016108118,0.87777245,0.1692151,0.259325
57 | attn_1_2_outputs,,-0.61002606,0.6863391,1.1960802,-0.972229,-0.3090549
58 | attn_1_3_outputs,0,3.147749,-2.74755,-2.6375554,-1.7992983,-1.8083935
59 | attn_1_3_outputs,1,-0.3166353,3.135929,-2.0476818,-1.5201313,-0.9221027
60 | attn_1_3_outputs,2,-0.82175285,-1.7785048,4.262148,-1.098342,-1.320515
61 | attn_1_3_outputs,3,-0.48976257,-1.4941785,-2.4922595,3.688606,-0.7487961
62 | attn_1_3_outputs,4,-2.8958664,0.9745885,-1.7469578,0.11342506,2.1992576
63 | attn_1_3_outputs,,0.9585229,-0.021460164,-0.068051144,-0.5773519,-0.15273045
64 | attn_1_3_outputs,,-0.08786828,0.34689122,0.3743438,0.2036276,0.04196969
65 | attn_1_3_outputs,,-0.5865496,-0.08093558,1.0197412,0.5368188,-0.35131058
66 | attn_2_0_outputs,0,-6.404271,3.213514,3.7119424,3.6321664,2.465991
67 | attn_2_0_outputs,1,1.3457497,-6.91716,2.117783,3.216933,2.2687566
68 | attn_2_0_outputs,2,1.7306777,1.9339283,-6.4193964,2.4106104,1.2669448
69 | attn_2_0_outputs,3,1.3426539,4.61214,2.2282271,-6.918863,2.1501
70 | attn_2_0_outputs,4,1.4316176,2.9336243,3.1591465,1.782481,-7.2788053
71 | attn_2_0_outputs,,0.1130622,0.17945175,-0.29610068,-0.18893324,0.092504
72 | attn_2_0_outputs,,0.07159574,-0.3103722,0.110794805,-0.79871124,-0.7539021
73 | attn_2_0_outputs,,0.7954956,-0.26322004,-0.07522555,0.038551275,-0.50479794
74 | attn_2_1_outputs,0,6.998647,-3.457301,-4.5451207,-1.8305441,-1.8368263
75 | attn_2_1_outputs,1,-0.40058747,6.4474087,-1.4574099,-2.18984,-1.7002127
76 | attn_2_1_outputs,2,-1.9955103,-2.4883533,7.227968,-2.8376138,-1.9176865
77 | attn_2_1_outputs,3,-1.350527,-2.7862217,-2.4625497,6.6744094,-1.3645524
78 | attn_2_1_outputs,4,-3.3767576,-0.8977402,-2.9630077,-0.49530327,5.3836374
79 | attn_2_1_outputs,,0.20378506,-0.5800168,-1.106855,0.19926731,-0.06131589
80 | attn_2_1_outputs,,-0.43978697,-0.21296938,-0.22513205,0.6500348,0.30841365
81 | attn_2_1_outputs,,-0.547537,0.8714378,0.33708888,-2.0015378,1.5639987
82 | attn_2_2_outputs,0,4.7300496,-2.609322,-3.414926,-2.652299,-1.7123688
83 | attn_2_2_outputs,1,-1.0747069,4.541831,-1.8588012,-1.2684554,-0.17224875
84 | attn_2_2_outputs,2,-1.5746042,-1.28898,4.3767176,-1.0493795,-2.8320768
85 | attn_2_2_outputs,3,-2.344087,-1.2606856,-0.91734046,5.6075225,-2.2346644
86 | attn_2_2_outputs,4,-2.6133668,-0.18721344,-1.6325711,-1.149275,3.6483755
87 | attn_2_2_outputs,,1.2496438,-0.41678748,0.087676175,-0.6057794,-0.044405453
88 | attn_2_2_outputs,,-0.15122709,-0.2583006,0.39665553,0.003170321,-0.016708078
89 | attn_2_2_outputs,,-0.7736017,-0.43213943,0.18013555,-0.09826307,0.057805218
90 | attn_2_3_outputs,0,-2.7729218,1.4200424,1.4600606,0.34230572,0.84415036
91 | attn_2_3_outputs,1,0.4400618,-1.756905,1.4549123,1.019148,0.4850436
92 | attn_2_3_outputs,2,0.9549057,1.2824273,-2.0410612,0.9188365,1.6863561
93 | attn_2_3_outputs,3,-0.091469035,1.2047702,1.370021,-0.89607966,-0.51357347
94 | attn_2_3_outputs,4,0.466892,0.8118219,0.28800988,1.3429086,-3.2995336
95 | attn_2_3_outputs,,0.24129342,-0.86307657,0.1745238,-2.054632,2.8129969
96 | attn_2_3_outputs,,-0.11414052,0.05034334,0.09903593,0.10165983,0.5136056
97 | attn_2_3_outputs,,0.4432127,0.15918905,0.12439812,0.3750773,-0.17403647
98 | mlp_0_0_outputs,0,-5.1317487,-6.70734,6.037307,4.5852356,4.0092926
99 | mlp_0_0_outputs,1,-0.257455,0.35274887,0.24112554,-0.42148477,0.024651978
100 | mlp_0_0_outputs,2,-1.8256108,4.1555786,-1.9008249,1.3808151,-1.2630675
101 | mlp_0_0_outputs,3,4.6051683,-1.8804812,3.9409556,-5.850926,-0.88602555
102 | mlp_0_0_outputs,4,1.0506113,-3.9949007,0.13467729,0.70601964,-0.3138636
103 | mlp_0_0_outputs,5,0.5380068,-0.7315861,-1.4061457,1.1938375,-1.1255091
104 | mlp_0_0_outputs,6,-0.14030905,2.8830237,-2.683621,-0.16997392,0.47404832
105 | mlp_0_0_outputs,7,0.3876715,-0.08891349,0.8728597,1.0125917,0.11347792
106 | mlp_1_0_outputs,0,2.3266206,-6.2295556,-4.0005665,2.3041046,0.21257897
107 | mlp_1_0_outputs,1,-0.4556595,-0.805607,-1.6509348,1.7824169,1.501474
108 | mlp_1_0_outputs,2,-5.792818,9.575128,3.5130193,-7.9973984,5.002579
109 | mlp_1_0_outputs,3,-1.2260249,0.9880918,-2.378248,2.0664575,1.2702942
110 | mlp_1_0_outputs,4,4.1306114,1.7542486,5.0177255,-0.94854623,-13.307635
111 | mlp_1_0_outputs,5,-0.06441002,0.36633462,0.41651008,0.1121977,-0.19381374
112 | mlp_1_0_outputs,6,0.28442925,0.3170138,1.0178398,-0.13551465,0.05566804
113 | mlp_1_0_outputs,7,-0.30202946,0.3719944,0.7272587,0.13975187,-0.37298957
114 | mlp_2_0_outputs,0,0.45908874,-0.01754609,0.97707564,-0.9511417,0.5083081
115 | mlp_2_0_outputs,1,0.18382199,-0.19358855,0.4822291,-0.80399686,-0.1716763
116 | mlp_2_0_outputs,2,1.5142181,-0.02708238,-0.86326504,0.48660412,-0.061133042
117 | mlp_2_0_outputs,3,-1.4314252,0.31027403,1.2170343,-0.10667357,0.94278526
118 | mlp_2_0_outputs,4,-0.0568396,-0.4547981,1.945643,-1.6712406,0.698788
119 | mlp_2_0_outputs,5,0.47281817,-0.15489303,0.5081915,-0.3702506,0.28318653
120 | mlp_2_0_outputs,6,0.30037868,0.22154446,0.37623194,0.10769829,0.31810036
121 | mlp_2_0_outputs,7,0.1804857,0.27696744,-0.5805316,1.504735,-0.4786584
122 | num_attn_0_0_outputs,_,0.8465965,2.0133536,-2.2380738,-1.4228604,1.0347323
123 | num_attn_0_1_outputs,_,-0.8386918,-0.38922623,-2.324954,0.07236001,1.3036776
124 | num_attn_0_2_outputs,_,-0.117477916,-2.304579,1.9196264,-1.8928717,-0.46572107
125 | num_attn_0_3_outputs,_,0.28793186,-1.0810422,-1.1591295,0.8848445,0.37602443
126 | num_attn_1_0_outputs,_,0.67058,-0.86042327,-1.924298,-1.2025291,3.5931141
127 | num_attn_1_1_outputs,_,-1.2113587,0.8658644,0.23557991,-0.5923407,0.55709636
128 | num_attn_1_2_outputs,_,0.9216758,0.9326565,-4.265963,0.8197906,1.7644495
129 | num_attn_1_3_outputs,_,2.0054622,-1.7789265,-0.87892455,-2.7376547,2.2104425
130 | num_attn_2_0_outputs,_,-1.8372577,0.24842429,-1.3088604,0.64833593,1.3509363
131 | num_attn_2_1_outputs,_,-0.108283974,-1.7611748,3.154574,3.5453818,0.70397145
132 | num_attn_2_2_outputs,_,3.0538957,-1.0975236,2.4850554,-1.7636111,-0.890746
133 | num_attn_2_3_outputs,_,0.43507352,4.196319,3.1197977,-2.9802363,-0.2592518
134 | num_mlp_0_0_outputs,0,-0.64667314,0.17904095,0.39200422,-0.0028258313,-0.37290496
135 | num_mlp_0_0_outputs,1,0.2116071,-0.017157368,0.9101409,0.45278645,0.056753635
136 | num_mlp_0_0_outputs,2,0.061665036,0.41990674,0.6671413,0.40301383,0.24431042
137 | num_mlp_0_0_outputs,3,-0.4006326,-0.4738044,-0.11009932,-0.14682105,-0.36005288
138 | num_mlp_0_0_outputs,4,0.4470174,-0.936997,-1.4516735,0.07075252,-0.16868772
139 | num_mlp_0_0_outputs,5,-0.5206594,-0.38733086,0.040900033,-0.473343,-0.33152476
140 | num_mlp_0_0_outputs,6,-0.22264968,0.5538271,0.2416182,-0.031259954,-0.3844373
141 | num_mlp_0_0_outputs,7,-0.1432738,-0.064719036,0.12572391,-0.20695664,0.2349148
142 | num_mlp_1_0_outputs,0,0.23161256,-0.17896621,1.3701653,0.027287895,-1.3479851
143 | num_mlp_1_0_outputs,1,-0.29446986,-0.47882456,0.77442497,0.34317648,-0.5066082
144 | num_mlp_1_0_outputs,2,-0.37243122,0.5159971,-2.766263,0.03633871,2.1990325
145 | num_mlp_1_0_outputs,3,-0.2770934,-0.053324725,0.62508935,0.08229346,-1.4606029
146 | num_mlp_1_0_outputs,4,0.47045127,0.05848788,1.1008859,0.31856912,-0.87950534
147 | num_mlp_1_0_outputs,5,0.03743611,-0.21205485,0.7981379,0.031759117,-0.23362987
148 | num_mlp_1_0_outputs,6,-0.120329894,0.09074645,1.4699514,-0.08971425,-1.3956728
149 | num_mlp_1_0_outputs,7,0.41638732,0.2154947,1.4405341,0.6774053,-0.46320668
150 | num_mlp_2_0_outputs,0,0.22702669,0.38019413,0.5690051,0.97963583,0.35750404
151 | num_mlp_2_0_outputs,1,-0.8008013,0.08042168,-0.41324988,-0.26802236,-0.3302354
152 | num_mlp_2_0_outputs,2,-0.64541245,-0.080827184,-0.18740739,-0.22994758,-0.24330296
153 | num_mlp_2_0_outputs,3,0.037600122,0.50564003,-0.23816861,0.11780354,0.2502989
154 | num_mlp_2_0_outputs,4,-0.012841126,0.39657962,-0.119696006,0.11229075,-0.10252234
155 | num_mlp_2_0_outputs,5,-0.7187831,-0.115768135,-0.15021832,0.22964966,-0.20723926
156 | num_mlp_2_0_outputs,6,-0.22354771,0.23253189,0.22618711,-0.14161181,-0.4494641
157 | num_mlp_2_0_outputs,7,-0.43640983,0.28443414,-0.12486136,-0.008818966,0.2645296
158 | ones,_,0.55620617,-0.5459692,0.11906973,0.0029487282,-0.60756934
159 | positions,0,0.48431224,-0.07508403,-0.15782595,0.08621112,-0.2329988
160 | positions,1,-1.0866297,0.20545797,0.020345941,0.27752703,-0.0058798143
161 | positions,2,-0.1897344,-0.6670251,0.017169373,-0.7351196,2.036752
162 | positions,3,-0.20501482,0.57732826,1.1450588,1.1435491,-1.1068201
163 | positions,4,0.12606427,0.026706036,-0.29324046,0.31004795,0.5862827
164 | positions,5,-0.7492487,0.8921466,-1.4819856,0.54742503,0.501571
165 | positions,6,-0.6667614,1.901696,1.1710435,0.22811733,-1.1982802
166 | positions,7,0.51935697,0.3317305,-0.9227535,0.20672892,-0.5248978
167 | tokens,0,-0.47813687,0.12706348,-0.04924151,0.548112,0.14897579
168 | tokens,1,-0.30380446,0.8807152,-0.2979864,0.1732994,-0.020828225
169 | tokens,2,-0.21944927,0.18884198,0.6118489,0.252533,-1.4074645
170 | tokens,3,-0.19298886,0.11001143,-0.3632103,0.2215297,-0.3719192
171 | tokens,4,-0.076143056,0.31109792,-0.25140736,-0.4576611,0.7776612
172 | tokens,,0.24346647,0.7005277,0.22031198,-0.29249498,-0.07512681
173 | tokens,,0.11568012,-0.57711613,0.19756645,-0.48424596,0.09319665
174 | tokens,,-0.5352298,0.28023535,0.11314476,0.25227708,0.2986933
175 |
--------------------------------------------------------------------------------
/programs/rasp/sort/sort_weights.csv:
--------------------------------------------------------------------------------
1 | feature,value,0,1,2,3,4
2 | attn_0_0_outputs,0,2.058161,-0.63779896,0.352272,0.5141959,-1.740315
3 | attn_0_0_outputs,1,-2.4275339,2.824267,1.1222093,-0.29264563,-1.6031634
4 | attn_0_0_outputs,2,-0.66014487,0.4371716,0.6551343,0.1385141,-0.82386607
5 | attn_0_0_outputs,3,-1.4426589,-0.6121028,0.95834476,3.1481867,-2.513183
6 | attn_0_0_outputs,4,-5.1414647,-4.19145,-2.107596,0.40243983,5.3070917
7 | attn_0_0_outputs,,-2.1824768,-2.4896393,-1.6732111,-0.5224847,3.0039287
8 | attn_0_0_outputs,,-0.29941246,0.6749306,0.05445957,-0.16473089,0.30180916
9 | attn_0_0_outputs,,-1.52416,-1.9376818,-1.7669702,-1.5059099,3.1769867
10 | attn_0_1_outputs,0,2.1742406,-0.104620405,0.667535,0.56420654,-0.70504665
11 | attn_0_1_outputs,1,-1.0247861,2.315593,0.35188147,-0.6513114,-1.2275088
12 | attn_0_1_outputs,2,-0.4099264,0.6904153,0.5056659,-0.7560546,-0.83760774
13 | attn_0_1_outputs,3,-1.055448,-0.52389306,1.1448267,3.5426579,-2.5540905
14 | attn_0_1_outputs,4,-3.743604,-3.8234832,-2.1349356,-0.02523663,4.5795074
15 | attn_0_1_outputs,,-6.0711293,-3.95341,-1.697248,0.52152914,5.3833594
16 | attn_0_1_outputs,,0.4076765,0.08954932,-0.14387824,0.091574,0.20851962
17 | attn_0_1_outputs,,-0.38013995,0.15324089,0.36798063,0.38066244,0.54085183
18 | attn_0_2_outputs,0,5.2839427,-2.605443,-1.1492563,-1.7018955,-1.4085128
19 | attn_0_2_outputs,1,-1.1856108,0.3934325,0.09754402,1.237557,1.0921215
20 | attn_0_2_outputs,2,-1.2996812,-0.00338279,0.70976466,0.60546106,0.31050247
21 | attn_0_2_outputs,3,-1.7861035,0.39520007,0.18829803,2.3311763,-1.300178
22 | attn_0_2_outputs,4,-9.374754,-5.7489552,-2.4201286,1.2188936,6.6614056
23 | attn_0_2_outputs,,-0.24304307,0.06899596,0.06185269,0.42859817,0.07496456
24 | attn_0_2_outputs,,0.54609865,0.4358928,0.3134578,0.6048993,0.7376409
25 | attn_0_2_outputs,,-2.2961771,2.386299,1.5124247,0.1737369,-2.5670602
26 | attn_0_3_outputs,0,0.3295003,-0.1827127,-0.18133737,-0.013248737,-0.10451385
27 | attn_0_3_outputs,1,-0.45805857,0.50490904,0.15572038,-0.71143156,-0.13632512
28 | attn_0_3_outputs,2,-0.060173124,0.06350225,0.20524347,0.15398552,0.42267397
29 | attn_0_3_outputs,3,-1.0496696,0.57741636,-0.8549515,2.4007623,-1.0587981
30 | attn_0_3_outputs,4,-2.2801893,-0.99242985,-0.45985192,-0.20450349,2.3491073
31 | attn_0_3_outputs,,-5.532329,-2.3926458,-0.032455105,2.7236922,3.9317892
32 | attn_0_3_outputs,,0.12758909,-0.13754353,0.030744994,-0.15796055,0.30099612
33 | attn_0_3_outputs,,-0.41964686,-1.7669121,-1.2862828,0.43886116,2.138884
34 | attn_1_0_outputs,0,2.0348678,-1.2148954,-0.88260573,0.072366394,0.8086225
35 | attn_1_0_outputs,1,-1.2006625,0.116209,-0.0810096,0.38852194,0.95987433
36 | attn_1_0_outputs,2,-0.14999832,0.9288459,1.3716835,1.1660155,-1.7833197
37 | attn_1_0_outputs,3,-2.7026277,-1.178805,-1.0291067,1.0259144,1.4074742
38 | attn_1_0_outputs,4,-4.517803,-1.8277934,-1.0758657,0.06511916,2.5772486
39 | attn_1_0_outputs,,-0.9347721,-0.8367545,-0.77415204,-0.26920766,1.24447
40 | attn_1_0_outputs,,-0.04765274,-0.13257939,0.18410715,0.88693845,0.29562962
41 | attn_1_0_outputs,,-2.463801,0.061894733,-0.032191936,0.4246904,0.90985197
42 | attn_1_1_outputs,0,2.2647228,-1.0525606,0.11064329,-0.4824749,-2.2326334
43 | attn_1_1_outputs,1,-1.016662,2.9539187,1.0269936,-1.7816477,-2.2167149
44 | attn_1_1_outputs,2,0.19277221,1.0981342,1.3715613,0.008356502,-0.45385465
45 | attn_1_1_outputs,3,-0.21926539,-0.4179502,1.2211832,3.2872894,-2.8288746
46 | attn_1_1_outputs,4,-4.9702883,-4.228738,-1.920627,0.4357961,6.435965
47 | attn_1_1_outputs,,-2.873453,-4.4431806,-1.4632183,0.603075,5.768971
48 | attn_1_1_outputs,,-0.25213683,0.016516566,0.26208618,-0.27179545,-0.22851557
49 | attn_1_1_outputs,,-1.7156378,-0.26994628,-0.20035094,0.42288014,1.1078966
50 | attn_1_2_outputs,0,2.66944,-0.45978594,0.24579379,0.6636476,-2.0367975
51 | attn_1_2_outputs,1,-1.3979408,2.155738,0.6993282,-0.12806043,-1.3829575
52 | attn_1_2_outputs,2,-0.5397533,0.3294003,1.0066247,0.8717879,-0.453143
53 | attn_1_2_outputs,3,-0.79578793,-0.5700265,1.0958053,3.8531687,-2.876995
54 | attn_1_2_outputs,4,-4.6522145,-4.2854285,-2.4250696,0.3245887,5.7784514
55 | attn_1_2_outputs,,-1.2333149,-0.8423292,-0.6918403,0.26540172,1.3943026
56 | attn_1_2_outputs,,0.14982238,0.35288456,-0.88604194,-0.06601087,-0.15605317
57 | attn_1_2_outputs,,-0.6495551,-2.0333426,-1.4724616,-0.4954959,2.3913198
58 | attn_1_3_outputs,0,1.0067343,-0.6924186,-0.22267842,0.0022108958,-0.34514758
59 | attn_1_3_outputs,1,-1.1532832,1.2752177,0.4407405,-0.2488312,-0.34415925
60 | attn_1_3_outputs,2,0.43818164,0.1435282,0.32248864,0.17158462,-0.002002176
61 | attn_1_3_outputs,3,-0.05012911,0.06287343,0.9565397,2.1478107,-1.4736792
62 | attn_1_3_outputs,4,-3.6912036,-2.599063,-0.84031093,0.7071018,3.493349
63 | attn_1_3_outputs,,-1.0017529,-0.16416864,0.032560404,0.2880399,0.6146985
64 | attn_1_3_outputs,,0.08257049,-0.038490284,-0.1203994,0.65081865,-0.14556827
65 | attn_1_3_outputs,,-2.321713,-1.3995543,-0.75829405,-0.35096657,2.7249875
66 | attn_2_0_outputs,0,0.65123445,-0.00015970842,0.28133044,0.29919377,-0.4494978
67 | attn_2_0_outputs,1,-0.46387243,-0.22923669,-0.09010046,0.5440508,-0.58355516
68 | attn_2_0_outputs,2,-0.746604,-0.8561036,0.9778113,-0.8867997,1.2600936
69 | attn_2_0_outputs,3,-2.3676467,-0.8762227,0.047619678,2.8621771,0.106826946
70 | attn_2_0_outputs,4,-2.2080712,-2.103707,-1.5367882,-0.9091501,3.8440235
71 | attn_2_0_outputs,,-1.5794939,-0.87683463,-0.5443371,0.67485535,0.6738131
72 | attn_2_0_outputs,,0.0008923216,0.19533116,0.36518034,-0.08854164,-0.2732349
73 | attn_2_0_outputs,,-0.59456545,-0.38609,-0.37244636,-0.48838097,1.1909106
74 | attn_2_1_outputs,0,6.305336,0.77024233,-0.5253686,-3.0544596,-2.9562898
75 | attn_2_1_outputs,1,-0.506348,0.88375586,-0.5610058,-0.3551172,-0.44845787
76 | attn_2_1_outputs,2,-0.79803765,-0.26507252,0.40326494,0.20280881,-0.38371047
77 | attn_2_1_outputs,3,-2.2682018,0.34390137,0.8440804,2.4715912,-1.713408
78 | attn_2_1_outputs,4,-8.596309,-6.2820163,-3.0651898,0.6778617,7.88834
79 | attn_2_1_outputs,,-0.9434871,-0.3908067,0.28419968,1.0826464,-0.4644816
80 | attn_2_1_outputs,,-0.73219436,-0.09140464,0.21194346,0.6483371,0.62705564
81 | attn_2_1_outputs,,-3.0782797,0.004718122,1.668909,2.3308017,-2.3886883
82 | attn_2_2_outputs,0,0.71128976,-0.3320928,0.04913793,0.27017894,-0.5422503
83 | attn_2_2_outputs,1,-2.1448827,1.1844211,0.69605696,0.05205563,-0.086535886
84 | attn_2_2_outputs,2,-1.3683058,0.17096008,0.6356051,0.6679375,0.6418501
85 | attn_2_2_outputs,3,-1.5241508,-0.8560207,0.32365173,1.5988479,0.07578664
86 | attn_2_2_outputs,4,-2.2993717,-1.7863915,-0.8184309,0.3028464,2.733902
87 | attn_2_2_outputs,,-1.1926019,-0.42977494,-0.319407,1.0023048,0.43580797
88 | attn_2_2_outputs,,-0.11356022,0.018971028,-0.61723095,-0.21103133,-0.4132432
89 | attn_2_2_outputs,,-0.7201614,-1.980554,-1.0417174,0.70280063,0.9491416
90 | attn_2_3_outputs,0,-0.6578003,-1.4176185,-0.16867943,0.79162806,1.3469748
91 | attn_2_3_outputs,1,-1.2225962,0.5266581,0.31030974,0.50686336,0.88578856
92 | attn_2_3_outputs,2,-1.1110145,-0.59071904,0.040772133,0.40666303,0.49748218
93 | attn_2_3_outputs,3,-0.7991637,-0.63388026,0.025102817,1.1622124,0.9700968
94 | attn_2_3_outputs,4,-1.523163,-1.0361481,-0.06378266,0.8740367,1.220963
95 | attn_2_3_outputs,,-2.4229674,-0.92591256,0.03990261,0.5686489,1.1961548
96 | attn_2_3_outputs,,-0.7670137,0.32505798,-0.19609599,0.63817674,0.36692366
97 | attn_2_3_outputs,,-0.17675255,-0.061802976,0.41715944,1.5147864,0.637639
98 | mlp_0_0_outputs,0,-0.8589103,-0.1552966,0.61306983,0.07937584,2.6195028
99 | mlp_0_0_outputs,1,-0.089151576,0.6665328,0.64555883,0.26704308,-1.2672955
100 | mlp_0_0_outputs,2,-1.3417193,0.89044684,1.4538382,1.145795,-0.34172156
101 | mlp_0_0_outputs,3,5.009247,2.353605,-0.03374805,-2.3626945,-6.8675437
102 | mlp_0_0_outputs,4,-0.00933083,0.8026369,0.70613724,0.33960083,-0.33451435
103 | mlp_0_0_outputs,5,-0.036977522,-0.1316631,0.5099742,0.5255853,0.5357659
104 | mlp_0_0_outputs,6,0.09807571,0.65641683,-0.31813905,0.17056943,-0.71300334
105 | mlp_0_0_outputs,7,-10.803721,-6.045529,-1.7798707,3.2467043,5.0803566
106 | mlp_0_1_outputs,0,-0.5035141,-0.20353706,-0.7522745,0.3356952,-0.6081926
107 | mlp_0_1_outputs,1,-4.242264,0.032010686,1.6723924,1.2721102,1.0669887
108 | mlp_0_1_outputs,2,-1.4796605,-1.0899225,0.16888815,1.3151004,0.8872381
109 | mlp_0_1_outputs,3,-0.28629878,-0.28274626,0.22532974,-0.05161379,0.6493964
110 | mlp_0_1_outputs,4,0.044844028,0.594797,0.088082105,-0.6334384,-0.99465847
111 | mlp_0_1_outputs,5,3.6674852,2.6837974,0.23535982,-2.6139388,-6.132662
112 | mlp_0_1_outputs,6,-1.7653883,-9.047455,-5.737482,0.19138227,7.072668
113 | mlp_0_1_outputs,7,-5.4091525,-1.258377,1.3022676,2.1761138,1.0759268
114 | mlp_1_0_outputs,0,-8.089947,-3.9758954,-0.049512167,1.9524714,3.6874063
115 | mlp_1_0_outputs,1,-0.25300968,-0.18175596,0.41158956,0.34343994,0.2131009
116 | mlp_1_0_outputs,2,2.9635074,1.5911162,0.35373554,-1.8420515,-4.9741387
117 | mlp_1_0_outputs,3,-0.7055162,0.8571573,1.9057972,0.2289657,-1.0690937
118 | mlp_1_0_outputs,4,-2.4608774,-1.1034535,0.5610324,0.7546167,2.5527046
119 | mlp_1_0_outputs,5,-0.8002519,-0.37941346,0.35449395,-0.01350961,-0.08785999
120 | mlp_1_0_outputs,6,-0.27233246,1.2476993,0.5192803,-0.4762693,-0.067111954
121 | mlp_1_0_outputs,7,0.431638,-0.23128301,-0.22074634,-0.38865703,0.29866084
122 | mlp_1_1_outputs,0,2.570149,1.4482967,-0.018957324,-2.3094382,-4.180814
123 | mlp_1_1_outputs,1,-6.535501,-5.5468464,-2.4822295,0.90257144,4.8221583
124 | mlp_1_1_outputs,2,0.21253164,-0.8952474,-0.4594328,-0.5081009,0.12444797
125 | mlp_1_1_outputs,3,0.107153066,1.1495358,0.47018173,-0.3799021,-0.95472884
126 | mlp_1_1_outputs,4,-0.44310027,-0.50543004,0.1825062,0.14150716,0.23581605
127 | mlp_1_1_outputs,5,-5.052584,-2.6555338,1.050915,2.478589,2.2599907
128 | mlp_1_1_outputs,6,-0.97004384,0.11741753,0.20926222,-0.07429946,-0.20051539
129 | mlp_1_1_outputs,7,-1.5305095,0.36310226,0.9118783,0.68767303,-0.9041264
130 | mlp_2_0_outputs,0,-3.9110544,-0.9270479,0.4481431,2.1587543,2.9880595
131 | mlp_2_0_outputs,1,0.20119122,0.09129438,-0.10534093,-0.13398942,0.35579532
132 | mlp_2_0_outputs,2,-5.627938,-3.4119296,-0.54387337,2.0478127,3.7280648
133 | mlp_2_0_outputs,3,-0.8237606,-0.058409758,0.06423915,0.38022444,0.32353234
134 | mlp_2_0_outputs,4,0.523112,0.52292186,0.10970966,-1.1015787,-0.37122783
135 | mlp_2_0_outputs,5,-0.463441,0.4770951,0.31376934,0.037225943,0.74693483
136 | mlp_2_0_outputs,6,3.0876498,3.265038,0.6365583,-2.6961308,-5.7100368
137 | mlp_2_0_outputs,7,-0.03848318,0.07956341,0.12161484,-0.63174456,-0.15266788
138 | mlp_2_1_outputs,0,-1.19403,-0.013288918,0.25646126,-0.34185714,0.46653393
139 | mlp_2_1_outputs,1,-0.28805476,0.1984507,-0.054004673,0.08747607,-0.0060478314
140 | mlp_2_1_outputs,2,4.965758,3.2663405,0.7348362,-2.8354158,-7.0237536
141 | mlp_2_1_outputs,3,-0.3902142,0.14936583,0.46208203,0.5396692,-0.32771578
142 | mlp_2_1_outputs,4,-1.6855072,-0.8212709,-0.031309534,0.8434268,1.4494125
143 | mlp_2_1_outputs,5,-0.66520023,-0.12322676,0.23253511,0.8862266,0.5351467
144 | mlp_2_1_outputs,6,-0.12667158,0.17387478,0.37683076,0.30193081,0.033221602
145 | mlp_2_1_outputs,7,-6.953155,-4.2610555,-1.0238223,2.2198749,6.7818727
146 | num_attn_0_0_outputs,_,2.7008455,1.7573988,-0.5612883,-3.1591482,-0.8077632
147 | num_attn_0_1_outputs,_,1.8419529,1.7084945,1.2732,-0.6164277,-2.7078419
148 | num_attn_0_2_outputs,_,1.7593077,-0.78708524,2.0769553,1.6491287,-2.7435398
149 | num_attn_0_3_outputs,_,3.631106,-1.8356805,-2.8216524,-3.0461025,2.3921132
150 | num_attn_1_0_outputs,_,1.5844845,-3.8878868,-2.0250123,4.444674,0.40432557
151 | num_attn_1_1_outputs,_,-0.05317056,0.5621444,-0.03736721,-0.7964875,-1.0776875
152 | num_attn_1_2_outputs,_,0.37114012,2.4403827,0.115093544,-0.80042255,-2.1261592
153 | num_attn_1_3_outputs,_,2.2861667,3.0232863,0.6303658,-1.7195753,-3.5699532
154 | num_attn_2_0_outputs,_,-0.81557804,-0.12176335,8.105928,-0.29102525,-3.5535772
155 | num_attn_2_1_outputs,_,2.769653,2.2095704,0.32860684,-2.286327,-3.9597259
156 | num_attn_2_2_outputs,_,-0.4148719,0.17052153,0.75779927,1.0667933,-1.2106642
157 | num_attn_2_3_outputs,_,2.7550514,0.53845423,0.012782344,-0.42958698,-2.1298997
158 | num_mlp_0_0_outputs,0,-0.052239962,-0.9593842,-0.2768684,-0.024186546,0.13085616
159 | num_mlp_0_0_outputs,1,-0.47369894,-0.043022864,0.4291395,0.35447082,0.6350527
160 | num_mlp_0_0_outputs,2,-3.0991662,0.14400464,-0.43281138,0.75551677,2.5591822
161 | num_mlp_0_0_outputs,3,-0.6901305,-0.7482572,-0.22040156,-0.065956876,0.5631253
162 | num_mlp_0_0_outputs,4,-0.82058483,-1.0333056,-0.2300869,0.055369288,0.27217725
163 | num_mlp_0_0_outputs,5,-0.24412262,-0.77988327,-0.03587442,0.099466026,0.4355093
164 | num_mlp_0_0_outputs,6,-0.18769716,-0.87073785,-0.08458955,-0.0117840925,0.14444879
165 | num_mlp_0_0_outputs,7,0.7771105,-1.2428197,0.109384626,0.3367401,0.5219702
166 | num_mlp_0_1_outputs,0,-1.928114,-0.64596057,0.1970114,0.72684216,0.2434265
167 | num_mlp_0_1_outputs,1,1.4667885,-0.04219994,-1.7887018,-3.354516,1.6295062
168 | num_mlp_0_1_outputs,2,-0.98909557,-0.37205845,0.4755718,1.0217645,0.4730045
169 | num_mlp_0_1_outputs,3,-3.8669333,-1.8498815,0.3956896,1.7269781,1.3696659
170 | num_mlp_0_1_outputs,4,-2.0566127,-0.9900346,0.124328226,0.6929384,0.54728794
171 | num_mlp_0_1_outputs,5,-1.4753255,0.13082078,0.030201413,0.20565401,-0.11292716
172 | num_mlp_0_1_outputs,6,-0.84873676,0.19060001,0.3849743,0.5375944,0.21243429
173 | num_mlp_0_1_outputs,7,-1.5675162,-0.14484215,0.5516114,1.1091007,0.26935497
174 | num_mlp_1_0_outputs,0,-0.58943486,-0.5716185,0.09452104,0.51946837,0.8833337
175 | num_mlp_1_0_outputs,1,0.16628344,0.028681463,-0.1816516,0.4498699,-0.17988878
176 | num_mlp_1_0_outputs,2,-0.20675577,-0.3085491,-0.45421112,0.071232475,-0.4011028
177 | num_mlp_1_0_outputs,3,0.07106293,-0.22149324,-0.17892063,0.39708918,0.03731434
178 | num_mlp_1_0_outputs,4,-0.59267014,-0.2528348,-0.13206014,0.11341445,-0.21827152
179 | num_mlp_1_0_outputs,5,-3.1149821,-0.90035933,0.7789497,1.3473841,1.3481094
180 | num_mlp_1_0_outputs,6,-0.7881431,-0.3538762,0.080496565,0.33989707,0.09272613
181 | num_mlp_1_0_outputs,7,-1.9980044,-1.499147,-0.13361156,1.4995332,2.7140417
182 | num_mlp_1_1_outputs,0,-0.24536662,0.2940862,-0.34453082,0.046545297,0.20954652
183 | num_mlp_1_1_outputs,1,-1.5211539,-1.4764887,-0.24266818,0.62331903,0.839463
184 | num_mlp_1_1_outputs,2,-0.60920113,0.39721942,-0.2720907,0.6585138,1.2556117
185 | num_mlp_1_1_outputs,3,-0.7962207,-0.022392428,-0.098228484,0.12305577,0.5722526
186 | num_mlp_1_1_outputs,4,-1.1961201,-0.48624808,0.0026044075,0.9489868,1.084819
187 | num_mlp_1_1_outputs,5,-0.204229,0.41044012,-0.16175759,0.3332036,0.38448393
188 | num_mlp_1_1_outputs,6,-4.0577865,-1.1186019,-0.19234465,2.1006806,1.1207315
189 | num_mlp_1_1_outputs,7,-0.1431973,-0.2963608,0.35561675,0.5961827,-0.64972585
190 | num_mlp_2_0_outputs,0,1.8791451,-1.1463113,-0.51802814,0.93353784,-1.1545902
191 | num_mlp_2_0_outputs,1,-1.1244311,-0.1578107,0.516671,0.5967867,1.4779365
192 | num_mlp_2_0_outputs,2,-2.4088686,0.21240945,0.9339892,0.09490182,1.2009406
193 | num_mlp_2_0_outputs,3,-0.79646355,-0.035708643,-0.15684086,-0.09949614,0.48360366
194 | num_mlp_2_0_outputs,4,3.6157093,-6.813567,-1.4317391,1.9446093,2.5641017
195 | num_mlp_2_0_outputs,5,0.39062092,1.3092825,-0.19412902,0.75027215,-0.8247279
196 | num_mlp_2_0_outputs,6,-3.1036084,1.2624242,-1.3120074,-0.12865852,1.6784518
197 | num_mlp_2_0_outputs,7,-4.274369,0.27141675,0.99371576,0.6078943,1.7251453
198 | num_mlp_2_1_outputs,0,-1.2744915,-0.38451442,-0.0047318963,0.3987959,0.38946167
199 | num_mlp_2_1_outputs,1,-1.6799719,-0.7788202,-0.20249598,-0.18425512,1.1495807
200 | num_mlp_2_1_outputs,2,-0.34154946,-0.51062804,-0.3556491,-0.033050634,0.4181513
201 | num_mlp_2_1_outputs,3,-0.53709805,-1.6549367,-0.7475045,1.7235987,1.2236754
202 | num_mlp_2_1_outputs,4,-0.83570236,0.10588158,-0.06152223,0.15144911,0.08399685
203 | num_mlp_2_1_outputs,5,-0.79256934,-0.39871466,-0.007255134,0.15521254,0.46749616
204 | num_mlp_2_1_outputs,6,0.0021853913,0.06891026,0.11127618,0.06793636,-0.5632735
205 | num_mlp_2_1_outputs,7,-2.8081176,-0.20779653,-0.30498865,1.0156854,1.5306722
206 | ones,_,-1.3621695,-0.18504938,-0.2161428,0.37226388,0.5816039
207 | positions,0,-0.13063768,0.36072794,0.13273126,0.08235803,-0.78297955
208 | positions,1,3.3640394,1.4360483,-0.19464374,-3.6062639,-6.1498446
209 | positions,2,0.15126465,2.110513,0.34884313,-1.5275149,-2.7380152
210 | positions,3,-1.1328846,0.9832707,1.3702737,-1.1388667,0.24428602
211 | positions,4,-4.9033227,-5.421642,2.7506442,7.5432215,-0.07508394
212 | positions,5,-5.0503674,1.0907708,2.297857,1.425875,2.3817568
213 | positions,6,0.65689164,-6.7524247,-8.143799,-5.7251816,6.340086
214 | positions,7,-0.8686947,-0.5266267,-0.328554,0.117776334,0.3166736
215 | tokens,0,1.8981099,-0.8942574,-0.15392976,0.062256597,-1.4486569
216 | tokens,1,-1.294627,2.9094937,0.33469462,-1.4084401,-0.6528272
217 | tokens,2,-0.5369325,0.3324304,0.5662137,0.01150415,0.71778655
218 | tokens,3,-0.98839134,-0.6982198,0.7751885,3.5432122,-1.4779247
219 | tokens,4,-5.4420075,-4.702195,-2.883783,0.086826995,6.86268
220 | tokens,,-0.7684788,-0.24721867,0.30233318,-0.792211,0.23400383
221 | tokens,,-0.09745099,0.15648925,0.45449963,-0.054960843,-0.32444936
222 | tokens,,0.109125495,-0.04502393,0.25828347,0.35704604,-0.00495211
223 |
--------------------------------------------------------------------------------
/programs/rasp_categorical_only/double_hist/double_hist.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 |
5 | def select_closest(keys, queries, predicate):  # hard (one-hot) attention: each query attends to its closest matching key
6 | scores = [[False for _ in keys] for _ in queries]
7 | for i, q in enumerate(queries):
8 | matches = [j for j, k in enumerate(keys) if predicate(q, k)]
9 |         if not matches:  # no key satisfies the predicate; fall back to position 0
10 | scores[i][0] = True
11 | else:
12 | j = min(matches, key=lambda j: len(matches) if j == i else abs(i - j))
13 | scores[i][j] = True
14 | return scores
15 |
16 |
17 | def aggregate(attention, values):
18 |     return [[v for a, v in zip(attn, values) if a][0] for attn in attention]  # value at the attended position
19 |
20 |
21 | def run(tokens):
22 | # classifier weights ##########################################
23 | classifier_weights = pd.read_csv(
24 | "programs/rasp_categorical_only/double_hist/double_hist_weights.csv",
25 | index_col=[0, 1],
26 | dtype={"feature": str},
27 | )
28 | # inputs #####################################################
29 | token_scores = classifier_weights.loc[[("tokens", str(v)) for v in tokens]]
30 |
31 | positions = list(range(len(tokens)))
32 | position_scores = classifier_weights.loc[[("positions", str(v)) for v in positions]]
33 |
34 | ones = [1 for _ in range(len(tokens))]
35 | one_scores = classifier_weights.loc[[("ones", "_") for v in ones]].mul(ones, axis=0)
36 |
37 | # attn_0_0 ####################################################
38 | def predicate_0_0(q_position, k_position):
39 | if q_position in {0, 3}:
40 | return k_position == 3
41 | elif q_position in {1, 2}:
42 | return k_position == 5
43 | elif q_position in {4, 6}:
44 | return k_position == 6
45 | elif q_position in {5}:
46 | return k_position == 2
47 | elif q_position in {7}:
48 | return k_position == 7
49 |
50 | attn_0_0_pattern = select_closest(positions, positions, predicate_0_0)
51 | attn_0_0_outputs = aggregate(attn_0_0_pattern, positions)
52 | attn_0_0_output_scores = classifier_weights.loc[
53 | [("attn_0_0_outputs", str(v)) for v in attn_0_0_outputs]
54 | ]
55 |
56 | # attn_0_1 ####################################################
57 | def predicate_0_1(q_position, k_position):
58 | if q_position in {0}:
59 | return k_position == 5
60 | elif q_position in {1, 2, 3, 4, 5, 6}:
61 | return k_position == 6
62 | elif q_position in {7}:
63 | return k_position == 7
64 |
65 | attn_0_1_pattern = select_closest(positions, positions, predicate_0_1)
66 | attn_0_1_outputs = aggregate(attn_0_1_pattern, positions)
67 | attn_0_1_output_scores = classifier_weights.loc[
68 | [("attn_0_1_outputs", str(v)) for v in attn_0_1_outputs]
69 | ]
70 |
71 | # attn_0_2 ####################################################
72 | def predicate_0_2(q_token, k_token):
73 | if q_token in {"1", "4", "0", "2", "3"}:
74 | return k_token == "5"
75 | elif q_token in {"5"}:
76 | return k_token == "0"
77 |         elif q_token in {"<s>"}:
78 |             return k_token == "<s>"
79 |
80 | attn_0_2_pattern = select_closest(tokens, tokens, predicate_0_2)
81 | attn_0_2_outputs = aggregate(attn_0_2_pattern, positions)
82 | attn_0_2_output_scores = classifier_weights.loc[
83 | [("attn_0_2_outputs", str(v)) for v in attn_0_2_outputs]
84 | ]
85 |
86 | # attn_0_3 ####################################################
87 | def predicate_0_3(q_position, k_position):
88 | if q_position in {0, 3}:
89 | return k_position == 3
90 | elif q_position in {1, 2, 4, 6}:
91 | return k_position == 6
92 | elif q_position in {5}:
93 | return k_position == 1
94 | elif q_position in {7}:
95 | return k_position == 7
96 |
97 | attn_0_3_pattern = select_closest(positions, positions, predicate_0_3)
98 | attn_0_3_outputs = aggregate(attn_0_3_pattern, positions)
99 | attn_0_3_output_scores = classifier_weights.loc[
100 | [("attn_0_3_outputs", str(v)) for v in attn_0_3_outputs]
101 | ]
102 |
103 | # attn_0_4 ####################################################
104 | def predicate_0_4(q_position, k_position):
105 | if q_position in {0}:
106 | return k_position == 5
107 | elif q_position in {1, 2, 3, 4, 5, 6, 7}:
108 | return k_position == 7
109 |
110 | attn_0_4_pattern = select_closest(positions, positions, predicate_0_4)
111 | attn_0_4_outputs = aggregate(attn_0_4_pattern, positions)
112 | attn_0_4_output_scores = classifier_weights.loc[
113 | [("attn_0_4_outputs", str(v)) for v in attn_0_4_outputs]
114 | ]
115 |
116 | # attn_0_5 ####################################################
117 | def predicate_0_5(q_position, k_position):
118 | if q_position in {0}:
119 | return k_position == 3
120 | elif q_position in {1, 7}:
121 | return k_position == 7
122 | elif q_position in {2}:
123 | return k_position == 5
124 | elif q_position in {3, 4, 5, 6}:
125 | return k_position == 6
126 |
127 | attn_0_5_pattern = select_closest(positions, positions, predicate_0_5)
128 | attn_0_5_outputs = aggregate(attn_0_5_pattern, positions)
129 | attn_0_5_output_scores = classifier_weights.loc[
130 | [("attn_0_5_outputs", str(v)) for v in attn_0_5_outputs]
131 | ]
132 |
133 | # attn_0_6 ####################################################
134 | def predicate_0_6(q_token, k_token):
135 | if q_token in {"0"}:
136 | return k_token == "1"
137 | elif q_token in {"1", "2"}:
138 | return k_token == "0"
139 | elif q_token in {"5", "4", "3"}:
140 | return k_token == "2"
141 |         elif q_token in {"<s>"}:
142 |             return k_token == "5"
143 |
144 | attn_0_6_pattern = select_closest(tokens, tokens, predicate_0_6)
145 | attn_0_6_outputs = aggregate(attn_0_6_pattern, tokens)
146 | attn_0_6_output_scores = classifier_weights.loc[
147 | [("attn_0_6_outputs", str(v)) for v in attn_0_6_outputs]
148 | ]
149 |
150 | # attn_0_7 ####################################################
151 | def predicate_0_7(q_token, k_token):
152 | if q_token in {"0"}:
153 | return k_token == "0"
154 | elif q_token in {"1"}:
155 | return k_token == "1"
156 | elif q_token in {"2"}:
157 | return k_token == "2"
158 | elif q_token in {"3"}:
159 | return k_token == "3"
160 | elif q_token in {"4"}:
161 | return k_token == "4"
162 | elif q_token in {"5"}:
163 | return k_token == "5"
164 |         elif q_token in {"<s>"}:
165 |             return k_token == "<s>"
166 |
167 | attn_0_7_pattern = select_closest(tokens, tokens, predicate_0_7)
168 | attn_0_7_outputs = aggregate(attn_0_7_pattern, positions)
169 | attn_0_7_output_scores = classifier_weights.loc[
170 | [("attn_0_7_outputs", str(v)) for v in attn_0_7_outputs]
171 | ]
172 |
173 | # mlp_0_0 #####################################################
174 | def mlp_0_0(token, attn_0_6_output):
175 | key = (token, attn_0_6_output)
176 | if key in {
177 | ("0", "0"),
178 | ("0", "1"),
179 | ("0", "2"),
180 | ("0", "3"),
181 | ("0", "4"),
182 | ("0", "5"),
183 |             ("0", "<s>"),
184 | ("1", "0"),
185 | ("1", "1"),
186 | ("1", "2"),
187 | ("1", "3"),
188 | ("1", "4"),
189 | ("1", "5"),
190 |             ("1", "<s>"),
191 | ("3", "0"),
192 | ("4", "0"),
193 | ("4", "1"),
194 | }:
195 | return 3
196 | elif key in {
197 |             ("2", "5"),
198 |             ("2", "<s>"),
199 |             ("3", "4"),
200 |             ("3", "5"),
201 |             ("3", "<s>"),
202 |             ("<s>", "0"),
203 |             ("<s>", "1"),
204 |             ("<s>", "2"),
205 |             ("<s>", "3"),
206 |             ("<s>", "4"),
207 |             ("<s>", "5"),
208 |             ("<s>", "<s>"),
209 | }:
210 | return 0
211 | elif key in {("2", "0"), ("5", "0")}:
212 | return 2
213 | return 5
214 |
215 | mlp_0_0_outputs = [mlp_0_0(k0, k1) for k0, k1 in zip(tokens, attn_0_6_outputs)]
216 | mlp_0_0_output_scores = classifier_weights.loc[
217 | [("mlp_0_0_outputs", str(v)) for v in mlp_0_0_outputs]
218 | ]
219 |
220 | # mlp_0_1 #####################################################
221 | def mlp_0_1(attn_0_6_output, token):
222 | key = (attn_0_6_output, token)
223 |         if key in {("4", "1"), ("4", "3"), ("4", "4"), ("4", "<s>")}:
224 | return 6
225 | return 5
226 |
227 | mlp_0_1_outputs = [mlp_0_1(k0, k1) for k0, k1 in zip(attn_0_6_outputs, tokens)]
228 | mlp_0_1_output_scores = classifier_weights.loc[
229 | [("mlp_0_1_outputs", str(v)) for v in mlp_0_1_outputs]
230 | ]
231 |
232 | # attn_1_0 ####################################################
233 | def predicate_1_0(q_token, k_token):
234 | if q_token in {"1", "4", "0", "2", "5"}:
235 | return k_token == "3"
236 | elif q_token in {"3"}:
237 | return k_token == "0"
238 |         elif q_token in {"<s>"}:
239 |             return k_token == "<s>"
240 |
241 | attn_1_0_pattern = select_closest(tokens, tokens, predicate_1_0)
242 | attn_1_0_outputs = aggregate(attn_1_0_pattern, tokens)
243 | attn_1_0_output_scores = classifier_weights.loc[
244 | [("attn_1_0_outputs", str(v)) for v in attn_1_0_outputs]
245 | ]
246 |
247 | # attn_1_1 ####################################################
248 | def predicate_1_1(q_token, k_token):
249 | if q_token in {"0", "3"}:
250 | return k_token == "1"
251 | elif q_token in {"1", "4"}:
252 | return k_token == "2"
253 | elif q_token in {"5", "2"}:
254 | return k_token == "4"
255 |         elif q_token in {"<s>"}:
256 |             return k_token == "<s>"
257 |
258 | attn_1_1_pattern = select_closest(tokens, tokens, predicate_1_1)
259 | attn_1_1_outputs = aggregate(attn_1_1_pattern, mlp_0_0_outputs)
260 | attn_1_1_output_scores = classifier_weights.loc[
261 | [("attn_1_1_outputs", str(v)) for v in attn_1_1_outputs]
262 | ]
263 |
264 | # attn_1_2 ####################################################
265 | def predicate_1_2(q_token, k_token):
266 | if q_token in {"4", "5", "0", "2"}:
267 | return k_token == "2"
268 | elif q_token in {"1"}:
269 | return k_token == "1"
270 |         elif q_token in {"<s>", "3"}:
271 | return k_token == "5"
272 |
273 | attn_1_2_pattern = select_closest(tokens, tokens, predicate_1_2)
274 | attn_1_2_outputs = aggregate(attn_1_2_pattern, mlp_0_0_outputs)
275 | attn_1_2_output_scores = classifier_weights.loc[
276 | [("attn_1_2_outputs", str(v)) for v in attn_1_2_outputs]
277 | ]
278 |
279 | # attn_1_3 ####################################################
280 | def predicate_1_3(attn_0_5_output, position):
281 | if attn_0_5_output in {0, 1, 3, 6}:
282 | return position == 4
283 | elif attn_0_5_output in {2, 4, 7}:
284 | return position == 5
285 | elif attn_0_5_output in {5}:
286 | return position == 2
287 |
288 | attn_1_3_pattern = select_closest(positions, attn_0_5_outputs, predicate_1_3)
289 | attn_1_3_outputs = aggregate(attn_1_3_pattern, attn_0_0_outputs)
290 | attn_1_3_output_scores = classifier_weights.loc[
291 | [("attn_1_3_outputs", str(v)) for v in attn_1_3_outputs]
292 | ]
293 |
294 | # attn_1_4 ####################################################
295 | def predicate_1_4(q_token, k_token):
296 | if q_token in {"0"}:
297 | return k_token == "2"
298 | elif q_token in {"1", "4"}:
299 | return k_token == "0"
300 | elif q_token in {"2"}:
301 | return k_token == "4"
302 | elif q_token in {"3"}:
303 | return k_token == "1"
304 |         elif q_token in {"5", "<s>"}:
305 |             return k_token == "<s>"
306 |
307 | attn_1_4_pattern = select_closest(tokens, tokens, predicate_1_4)
308 | attn_1_4_outputs = aggregate(attn_1_4_pattern, attn_0_7_outputs)
309 | attn_1_4_output_scores = classifier_weights.loc[
310 | [("attn_1_4_outputs", str(v)) for v in attn_1_4_outputs]
311 | ]
312 |
313 | # attn_1_5 ####################################################
314 | def predicate_1_5(attn_0_4_output, attn_0_0_output):
315 | if attn_0_4_output in {0, 1, 3, 5}:
316 | return attn_0_0_output == 6
317 | elif attn_0_4_output in {2, 6, 7}:
318 | return attn_0_0_output == 7
319 | elif attn_0_4_output in {4}:
320 | return attn_0_0_output == 5
321 |
322 | attn_1_5_pattern = select_closest(attn_0_0_outputs, attn_0_4_outputs, predicate_1_5)
323 | attn_1_5_outputs = aggregate(attn_1_5_pattern, attn_0_1_outputs)
324 | attn_1_5_output_scores = classifier_weights.loc[
325 | [("attn_1_5_outputs", str(v)) for v in attn_1_5_outputs]
326 | ]
327 |
328 | # attn_1_6 ####################################################
329 | def predicate_1_6(attn_0_0_output, attn_0_7_output):
330 | if attn_0_0_output in {0, 6}:
331 | return attn_0_7_output == 7
332 | elif attn_0_0_output in {1, 4, 5}:
333 | return attn_0_7_output == 6
334 | elif attn_0_0_output in {2}:
335 | return attn_0_7_output == 5
336 | elif attn_0_0_output in {3}:
337 | return attn_0_7_output == 3
338 | elif attn_0_0_output in {7}:
339 | return attn_0_7_output == 1
340 |
341 | attn_1_6_pattern = select_closest(attn_0_7_outputs, attn_0_0_outputs, predicate_1_6)
342 | attn_1_6_outputs = aggregate(attn_1_6_pattern, attn_0_3_outputs)
343 | attn_1_6_output_scores = classifier_weights.loc[
344 | [("attn_1_6_outputs", str(v)) for v in attn_1_6_outputs]
345 | ]
346 |
347 | # attn_1_7 ####################################################
348 | def predicate_1_7(q_token, k_token):
349 | if q_token in {"1", "0", "3"}:
350 | return k_token == "4"
351 | elif q_token in {"4", "5", "2"}:
352 | return k_token == "1"
353 |         elif q_token in {"<s>"}:
354 |             return k_token == "<s>"
355 |
356 | attn_1_7_pattern = select_closest(tokens, tokens, predicate_1_7)
357 | attn_1_7_outputs = aggregate(attn_1_7_pattern, mlp_0_0_outputs)
358 | attn_1_7_output_scores = classifier_weights.loc[
359 | [("attn_1_7_outputs", str(v)) for v in attn_1_7_outputs]
360 | ]
361 |
362 | # mlp_1_0 #####################################################
363 | def mlp_1_0(attn_1_5_output, attn_1_3_output):
364 | key = (attn_1_5_output, attn_1_3_output)
365 | if key in {
366 | (0, 1),
367 | (0, 3),
368 | (0, 5),
369 | (1, 1),
370 | (1, 3),
371 | (1, 5),
372 | (2, 3),
373 | (3, 3),
374 | (4, 1),
375 | (4, 3),
376 | (4, 5),
377 | (5, 0),
378 | (5, 1),
379 | (5, 2),
380 | (5, 3),
381 | (5, 4),
382 | (5, 5),
383 | }:
384 | return 3
385 | return 5
386 |
387 | mlp_1_0_outputs = [
388 | mlp_1_0(k0, k1) for k0, k1 in zip(attn_1_5_outputs, attn_1_3_outputs)
389 | ]
390 | mlp_1_0_output_scores = classifier_weights.loc[
391 | [("mlp_1_0_outputs", str(v)) for v in mlp_1_0_outputs]
392 | ]
393 |
394 | # mlp_1_1 #####################################################
395 | def mlp_1_1(attn_1_7_output, attn_1_5_output):
396 | key = (attn_1_7_output, attn_1_5_output)
397 | if key in {
398 | (1, 0),
399 | (1, 1),
400 | (1, 3),
401 | (1, 4),
402 | (1, 5),
403 | (2, 0),
404 | (2, 1),
405 | (2, 3),
406 | (2, 4),
407 | (2, 5),
408 | (3, 0),
409 | (3, 1),
410 | (3, 3),
411 | (3, 4),
412 | (3, 5),
413 | (6, 0),
414 | (6, 1),
415 | (6, 3),
416 | (6, 4),
417 | (6, 5),
418 | (7, 0),
419 | (7, 1),
420 | (7, 3),
421 | (7, 4),
422 | (7, 5),
423 | }:
424 | return 1
425 | elif key in {(0, 7), (1, 7), (2, 7), (3, 7), (4, 7), (5, 7), (6, 7)}:
426 | return 5
427 | elif key in {(4, 0)}:
428 | return 2
429 | return 0
430 |
431 | mlp_1_1_outputs = [
432 | mlp_1_1(k0, k1) for k0, k1 in zip(attn_1_7_outputs, attn_1_5_outputs)
433 | ]
434 | mlp_1_1_output_scores = classifier_weights.loc[
435 | [("mlp_1_1_outputs", str(v)) for v in mlp_1_1_outputs]
436 | ]
437 |
438 | feature_logits = pd.concat(
439 | [
440 | df.reset_index()
441 | for df in [
442 | token_scores,
443 | position_scores,
444 | attn_0_0_output_scores,
445 | attn_0_1_output_scores,
446 | attn_0_2_output_scores,
447 | attn_0_3_output_scores,
448 | attn_0_4_output_scores,
449 | attn_0_5_output_scores,
450 | attn_0_6_output_scores,
451 | attn_0_7_output_scores,
452 | mlp_0_0_output_scores,
453 | mlp_0_1_output_scores,
454 | attn_1_0_output_scores,
455 | attn_1_1_output_scores,
456 | attn_1_2_output_scores,
457 | attn_1_3_output_scores,
458 | attn_1_4_output_scores,
459 | attn_1_5_output_scores,
460 | attn_1_6_output_scores,
461 | attn_1_7_output_scores,
462 | mlp_1_0_output_scores,
463 | mlp_1_1_output_scores,
464 | one_scores,
465 | ]
466 | ]
467 | )
468 | logits = feature_logits.groupby(level=0).sum(numeric_only=True).to_numpy()
469 | classes = classifier_weights.columns.to_numpy()
470 | predictions = classes[logits.argmax(-1)]
471 |     if tokens[0] == "<s>":
472 |         predictions[0] = "<s>"
473 |     if tokens[-1] == "</s>":
474 |         predictions[-1] = "</s>"
475 | return predictions.tolist()
476 |
477 |
478 | examples = [
479 |     (["<s>", "4", "5", "0", "5", "2", "2"], ["<s>", "2", "2", "2", "2", "2", "2"]),
480 |     (["<s>", "5", "2", "3", "5", "2", "3"], ["<s>", "3", "3", "3", "3", "3", "3"]),
481 |     (["<s>", "3", "1", "3", "5"], ["<s>", "1", "2", "1", "2"]),
482 |     (
483 |         ["<s>", "0", "2", "5", "0", "3", "3", "4"],
484 |         ["<s>", "2", "3", "3", "2", "2", "2", "3"],
485 |     ),
486 |     (["<s>", "1", "0", "5", "4"], ["<s>", "4", "4", "4", "4"]),
487 |     (["<s>", "1", "4", "1", "2", "0", "1"], ["<s>", "1", "3", "1", "3", "3", "1"]),
488 |     (["<s>", "2", "0", "3", "2", "3", "5"], ["<s>", "2", "2", "2", "2", "2", "2"]),
489 |     (
490 |         ["<s>", "1", "2", "1", "2", "4", "1", "5"],
491 |         ["<s>", "1", "1", "1", "1", "2", "1", "2"],
492 |     ),
493 |     (["<s>", "3", "1", "1", "0"], ["<s>", "2", "1", "1", "2"]),
494 |     (["<s>", "4", "0", "5", "5", "4"], ["<s>", "2", "1", "2", "2", "2"]),
495 | ]
496 | for x, y in examples:
497 | print(f"x: {x}")
498 | print(f"y: {y}")
499 | y_hat = run(x)
500 | print(f"y_hat: {y_hat}")
501 | print()
502 |
--------------------------------------------------------------------------------
/programs/rasp_categorical_only/double_hist/double_hist_weights.csv:
--------------------------------------------------------------------------------
1 | feature,value,1,2,3,4,5,6
2 | attn_0_0_outputs,0,0.126095,-0.09308497,-0.04164269,1.1473964,-0.65259224,-0.5759567
3 | attn_0_0_outputs,1,0.2643548,0.057178997,0.34055573,-0.13716577,-0.2368343,-0.39215782
4 | attn_0_0_outputs,2,0.09445334,-0.113155134,0.09719708,-0.3638223,-0.7040596,-1.0189345
5 | attn_0_0_outputs,3,0.3111331,0.07621635,0.29869452,-0.11679748,-0.4302198,-0.4611198
6 | attn_0_0_outputs,4,0.1457545,-0.08105723,0.16098967,0.43949932,0.038392957,-0.3110958
7 | attn_0_0_outputs,5,0.2503496,-0.0776672,0.2193873,-0.16906078,-0.25772113,-0.5475772
8 | attn_0_0_outputs,6,0.6186277,0.37508786,0.42960152,-0.44146442,-0.41667625,-0.62044233
9 | attn_0_0_outputs,7,0.55322224,0.21084103,0.76800203,-0.35938764,0.6075319,-4.7059984
10 | attn_0_1_outputs,0,-0.20470576,-0.5146743,-0.045462333,-0.7030797,8.210753,2.5101397
11 | attn_0_1_outputs,1,0.92150104,0.72936517,-1.0153459,0.9652804,2.3336089,1.3663455
12 | attn_0_1_outputs,2,0.039329447,-0.25970924,-0.53097314,0.094466195,-1.3910042,-0.7387352
13 | attn_0_1_outputs,3,0.30621508,0.059149753,0.13238326,0.42272153,0.28027993,0.14599061
14 | attn_0_1_outputs,4,0.13979882,-0.12561479,0.2264524,-0.007633883,0.46641046,-0.11170333
15 | attn_0_1_outputs,5,-0.21153452,-0.42493245,-0.2567995,-0.22684059,1.0124276,-0.56272596
16 | attn_0_1_outputs,6,0.6481529,0.54434663,0.0014183833,0.29601145,-5.983322,-1.2265881
17 | attn_0_1_outputs,7,0.5336258,0.4396987,-0.00064421905,0.07741543,-4.62246,-7.6855483
18 | attn_0_2_outputs,0,1.7004082,1.068717,0.67274964,-2.202218,-14.525051,-24.145012
19 | attn_0_2_outputs,1,-0.34902173,-0.46517774,-0.26951686,0.4179399,2.3199003,3.5011954
20 | attn_0_2_outputs,2,-0.30551437,-0.49071354,-0.20850673,0.52268445,2.352511,3.313084
21 | attn_0_2_outputs,3,-0.36689532,-0.6685318,-0.4558378,0.18013526,2.1214046,3.1762848
22 | attn_0_2_outputs,4,-0.42851058,-0.6440964,-0.38658094,0.28424677,2.1883128,2.630304
23 | attn_0_2_outputs,5,0.045050662,-0.07220718,0.047026966,0.6048806,2.6343098,2.9428246
24 | attn_0_2_outputs,6,-0.27480802,-0.30446512,-0.2069842,0.5289953,2.8373914,2.2616675
25 | attn_0_2_outputs,7,0.06093425,0.14711075,0.24991485,0.32365993,2.6028795,-4.8855653
26 | attn_0_3_outputs,0,0.09755095,0.18649152,-0.27435488,-1.9707323,-1.4276974,-2.0362265
27 | attn_0_3_outputs,1,0.35753757,0.23701514,-0.25319183,-0.6494274,-1.605695,-0.7366025
28 | attn_0_3_outputs,2,0.32133347,0.2595023,0.31592253,0.17246504,-0.1395176,-0.22350298
29 | attn_0_3_outputs,3,0.24959454,0.32810065,-0.024572447,-0.08864521,-0.25495535,-0.24189267
30 | attn_0_3_outputs,4,-0.25754684,-0.08468185,-0.3961516,-0.13146977,-0.35567138,-0.29548463
31 | attn_0_3_outputs,5,0.23781574,0.3707741,0.1524827,0.0640203,0.10271894,-0.15585618
32 | attn_0_3_outputs,6,-0.04686378,0.1890392,0.06480245,0.41855758,1.3988882,-0.042787485
33 | attn_0_3_outputs,7,-0.1659183,0.11428217,0.0062955585,0.4894925,0.07572694,-5.7053986
34 | attn_0_4_outputs,0,0.0009369574,0.07684592,0.040978912,0.5028028,0.9522137,3.6835043
35 | attn_0_4_outputs,1,-0.5538695,0.603373,-0.20237413,-0.11773054,0.77875495,-2.9072464
36 | attn_0_4_outputs,2,0.30023372,0.00922706,-0.50846523,0.59284455,-1.5345312,-0.671517
37 | attn_0_4_outputs,3,0.61231476,-0.13248341,0.6498266,0.61870444,0.7098338,-0.43163025
38 | attn_0_4_outputs,4,0.341006,0.14442582,0.41357324,0.23423965,-0.040843505,-0.04311107
39 | attn_0_4_outputs,5,0.15555005,-0.0129139135,0.038804073,-0.07452761,-0.2310702,-1.8921185
40 | attn_0_4_outputs,6,0.085187085,0.11249486,-0.30318964,0.5192505,-1.7593035,-0.5133031
41 | attn_0_4_outputs,7,0.67837936,0.4230805,0.6677434,-1.2566549,-1.3864559,-14.990864
42 | attn_0_5_outputs,0,0.117494754,-0.0076674935,0.17776193,-0.18112357,-0.27396384,-1.0354977
43 | attn_0_5_outputs,1,1.411103,-0.9264894,1.0556792,0.40703505,0.21264277,-0.08458514
44 | attn_0_5_outputs,2,0.08361238,0.5619377,0.48529363,-0.33324116,0.12216819,-0.40256438
45 | attn_0_5_outputs,3,-0.4463776,0.7458927,-0.2468474,-0.110574245,0.27681002,-0.2829393
46 | attn_0_5_outputs,4,0.27849215,0.050043125,-0.27561343,-0.05235713,-0.050278183,-0.09016176
47 | attn_0_5_outputs,5,-0.08726232,-0.21527587,0.051751416,-0.38245076,-0.7218009,-1.5243168
48 | attn_0_5_outputs,6,0.41514802,0.25593644,0.4507316,-0.057119988,-0.124129154,-0.17267601
49 | attn_0_5_outputs,7,0.5153719,0.2776314,0.6347292,-0.058477916,0.21668187,-5.153364
50 | attn_0_6_outputs,0,0.4399908,0.40121058,0.2908646,0.30896205,1.0355054,2.6121292
51 | attn_0_6_outputs,1,-0.4126085,-0.40561342,-0.2074391,0.47275737,0.93861,-0.9778263
52 | attn_0_6_outputs,2,-0.16018946,-0.020672152,0.5263022,1.8692286,3.3783615,2.0448
53 | attn_0_6_outputs,3,1.2413086,0.609498,0.27304208,-2.0135508,-3.9722664,-2.54368
54 | attn_0_6_outputs,4,-0.4699255,-0.07383059,0.8524104,0.43587294,0.46777043,0.14019349
55 | attn_0_6_outputs,5,-0.79917717,-0.056769755,0.6343696,0.45576847,-2.0306313,-0.97413427
56 | attn_0_6_outputs,</s>,-0.27148145,0.012953444,-0.09518302,0.17713453,-0.20631313,-0.18813157
57 | attn_0_6_outputs,<s>,0.68961024,0.53815615,0.44267753,-0.3528938,-4.7231135,-14.933182
58 | attn_0_7_outputs,0,0.26736894,-0.24484192,0.017989898,0.36529353,0.21895835,-0.13605899
59 | attn_0_7_outputs,1,-0.028649254,0.027535517,0.3608771,0.6933264,0.58034605,-0.25237384
60 | attn_0_7_outputs,2,-0.13851081,-0.03744818,0.35292834,0.566245,0.34746164,-0.39807728
61 | attn_0_7_outputs,3,0.207372,0.20546846,0.3811095,0.65656036,-0.3655698,-0.94992673
62 | attn_0_7_outputs,4,-0.21982807,-0.07117972,0.24072638,0.580462,0.22779587,-0.25692478
63 | attn_0_7_outputs,5,0.41687506,0.26974058,0.19956248,-0.015195189,-3.0216563,-1.1237081
64 | attn_0_7_outputs,6,-0.34657335,-0.24584907,0.0033485373,0.29958242,0.08694523,-0.8937644
65 | attn_0_7_outputs,7,0.2750328,0.3009795,0.66187775,0.84384704,0.58818835,-3.057173
66 | attn_1_0_outputs,0,-1.1977414,-0.68712807,0.21589795,1.5379475,2.7449412,6.4803658
67 | attn_1_0_outputs,1,0.43876192,0.5447946,1.4586774,-1.0479536,-2.8326502,-4.8915324
68 | attn_1_0_outputs,2,-0.096182294,0.5217153,0.98753965,-0.32770967,-1.3522687,-2.2044451
69 | attn_1_0_outputs,3,-1.6992767,-1.1615183,-0.33289766,0.98977876,6.670316,7.1105237
70 | attn_1_0_outputs,4,0.9227671,0.8429137,0.9108278,-1.8007963,-3.0463388,-6.395819
71 | attn_1_0_outputs,5,-0.37416497,0.8064592,1.6661402,-0.17301181,-2.1290236,-6.337081
72 | attn_1_0_outputs,,0.16492641,0.2724202,0.30109468,0.1270346,-0.42348635,-0.6057034
73 | attn_1_0_outputs,,1.0615019,1.1287842,1.2637048,-1.2012541,-14.706598,-19.987967
74 | attn_1_1_outputs,0,1.3178737,1.3075073,0.6171435,-1.2386639,-14.961128,-22.710047
75 | attn_1_1_outputs,1,-0.79646075,-0.048236232,-0.57135,0.8660453,5.204853,3.9006507
76 | attn_1_1_outputs,2,-0.64026845,-0.31622618,-0.80555373,-0.38438004,-3.0017076,3.1284397
77 | attn_1_1_outputs,3,-0.413255,-0.0743735,-0.44691494,0.00972029,1.0504625,1.6899987
78 | attn_1_1_outputs,4,-1.0316962,-0.29744324,-0.7646268,0.6624031,2.773765,2.0252886
79 | attn_1_1_outputs,5,-0.32514748,0.08588939,-0.24190034,1.1112455,4.891911,-4.0668387
80 | attn_1_1_outputs,6,-0.18358597,0.007782329,-0.167846,0.55784917,1.0497051,1.5342381
81 | attn_1_1_outputs,7,-0.89457923,-0.32318884,-0.5396412,0.82994485,2.7423346,3.3378825
82 | attn_1_2_outputs,0,0.6464282,0.4415729,0.34436905,0.17836824,-4.205051,-11.493903
83 | attn_1_2_outputs,1,-0.0137010105,-0.445989,-0.4315788,-0.57123286,0.7491084,2.0127614
84 | attn_1_2_outputs,2,-0.28706008,-0.40443745,-0.38865212,0.5154979,4.107742,7.3013763
85 | attn_1_2_outputs,3,-0.10718431,-0.3202067,-0.32353058,0.20667966,1.9230273,1.5108932
86 | attn_1_2_outputs,4,0.33059838,0.07220246,-0.23482965,-1.067047,0.19484031,1.3690035
87 | attn_1_2_outputs,5,0.3945407,0.329356,0.33315465,0.6143382,1.234685,-1.9626414
88 | attn_1_2_outputs,6,-0.5554587,0.40695667,0.13053398,0.9338181,0.99574876,0.43968514
89 | attn_1_2_outputs,7,0.33515787,0.09245179,0.09882208,-0.6038185,0.082006045,2.063728
90 | attn_1_3_outputs,0,-0.535496,-0.13259524,-0.71865547,-0.3909597,-0.38141873,-0.40478173
91 | attn_1_3_outputs,1,0.20157264,1.0196886,-0.502918,-0.23266602,-0.06780651,0.2839907
92 | attn_1_3_outputs,2,-0.11747025,0.37703773,-0.42134264,-0.14813195,-0.20976208,-0.034975
93 | attn_1_3_outputs,3,-0.37622637,-6.0123005,5.564005,-2.447405,0.38429993,-0.59021026
94 | attn_1_3_outputs,4,1.1969863,-0.76458126,-1.5581679,-0.42092353,0.42188513,0.6117949
95 | attn_1_3_outputs,5,-0.13388146,0.33084488,-0.3471546,-0.08674212,-0.17246169,-0.21080919
96 | attn_1_3_outputs,6,0.040989626,0.52416426,-0.098571725,0.21533841,-0.24930477,0.22166981
97 | attn_1_3_outputs,7,0.2629529,0.48815468,-0.1758514,0.5167536,-0.11469977,-7.593266
98 | attn_1_4_outputs,0,0.7242939,0.6138902,-0.2165776,-1.189821,-6.016478,-2.156461
99 | attn_1_4_outputs,1,-0.028479,0.10070866,-0.31167245,0.2912638,2.3396828,1.539698
100 | attn_1_4_outputs,2,0.03833334,0.08542576,-0.23502897,0.34766886,1.8727108,0.8869683
101 | attn_1_4_outputs,3,0.14814246,0.14517051,-0.23019965,0.4274214,2.4027905,0.5986793
102 | attn_1_4_outputs,4,-0.060111366,0.08575323,-0.3969963,0.3362898,2.0958183,1.2442379
103 | attn_1_4_outputs,5,-0.083861604,0.069971085,-0.24564557,0.27127108,1.9678422,0.49202684
104 | attn_1_4_outputs,6,-0.013286352,0.3606033,-0.18357426,0.50619376,2.4264543,0.56555873
105 | attn_1_4_outputs,7,-0.2426668,-0.2329866,-0.72179276,0.0852172,1.769792,-2.2768142
106 | attn_1_5_outputs,0,-0.38110098,0.40513593,-4.439374,4.093755,2.1234877,0.5540743
107 | attn_1_5_outputs,1,-0.23400371,-0.036937438,-0.09517478,-0.38474637,-0.2784226,-0.58958876
108 | attn_1_5_outputs,2,0.3410733,0.18570994,-0.07866506,0.34950978,-0.75937384,-0.24752817
109 | attn_1_5_outputs,3,0.021574836,-0.6249923,0.34806028,-0.11711029,3.910408,-0.14427455
110 | attn_1_5_outputs,4,-0.4571682,-0.37512428,-0.30038518,-0.41331235,0.18258524,-0.23161632
111 | attn_1_5_outputs,5,0.24728091,0.4856511,0.594313,-2.091695,2.0573926,-0.7336284
112 | attn_1_5_outputs,6,0.22816464,0.27136412,0.5195222,-0.6798148,-5.8662944,-0.39577755
113 | attn_1_5_outputs,7,0.12520222,0.22135924,0.7278236,-0.72360736,-2.3794682,-8.568124
114 | attn_1_6_outputs,0,1.9565225,1.992841,-4.4924707,-1.4354049,-3.0359225,0.4108184
115 | attn_1_6_outputs,1,-1.5355867,-0.8102765,0.39700136,0.9479792,5.72563,0.6617512
116 | attn_1_6_outputs,2,0.41150886,-0.04871347,0.3481352,-0.41566238,-1.5758594,0.090029344
117 | attn_1_6_outputs,3,-0.26570487,-0.021473275,0.59041107,0.2970626,0.9045193,-0.53055835
118 | attn_1_6_outputs,4,-0.08541798,0.34799036,0.76428026,0.69276136,0.68592554,0.5443473
119 | attn_1_6_outputs,5,-0.15378077,-0.39172465,0.40749076,0.5314768,0.8851225,0.4981952
120 | attn_1_6_outputs,6,-0.008769149,0.07061772,0.37057376,-0.65310204,-3.0223656,-1.3392087
121 | attn_1_6_outputs,7,0.05816845,0.25007242,0.451493,-0.22949804,-2.5058506,-7.4911504
122 | attn_1_7_outputs,0,1.0521657,1.0074128,0.4629282,-2.0979998,-18.004267,-24.830841
123 | attn_1_7_outputs,1,-0.69483924,-0.15778616,-0.24732345,1.2288451,3.7539797,5.79142
124 | attn_1_7_outputs,2,0.77858096,0.7278021,-0.42153692,-1.9385699,-2.943345,-4.5531664
125 | attn_1_7_outputs,3,-0.8458971,-0.20327307,-0.19063051,1.0987198,2.5467868,5.116683
126 | attn_1_7_outputs,4,-1.1233516,-0.48994657,-0.16887976,1.1522313,3.7934172,3.263776
127 | attn_1_7_outputs,5,-0.67034286,-0.18595126,-0.15823115,0.8096825,2.814898,0.057380866
128 | attn_1_7_outputs,6,-0.6525574,0.014802525,0.19252089,0.6953846,2.502593,-0.91043985
129 | attn_1_7_outputs,7,-1.4201316,-0.9218594,-1.0736845,0.35498753,3.4924703,3.2490141
130 | mlp_0_0_outputs,0,1.1609898,0.8572743,0.17748477,-1.1314943,-6.752202,-10.017833
131 | mlp_0_0_outputs,1,0.15575539,0.20244724,-0.35808483,-0.16532898,0.6535445,0.9125833
132 | mlp_0_0_outputs,2,0.47592318,0.36881614,0.17811424,1.0390111,2.0942621,-0.4145827
133 | mlp_0_0_outputs,3,0.25076088,0.31756687,-0.023319967,0.14243905,-1.2209392,0.2947385
134 | mlp_0_0_outputs,4,0.47051138,0.46116182,0.02881971,0.85385466,0.9588443,1.3177751
135 | mlp_0_0_outputs,5,0.32828408,0.40574586,-0.122866556,0.5059813,1.4154013,-0.2793758
136 | mlp_0_0_outputs,6,0.5967165,0.52169484,-0.30867448,-1.2104449,-1.8745283,-0.5121046
137 | mlp_0_0_outputs,7,-0.01277371,-0.049345464,-0.40406802,-0.19294757,0.47392055,1.439194
138 | mlp_0_1_outputs,0,-0.08253285,-0.067494564,-0.19991755,0.16750817,-0.049758844,-0.8476353
139 | mlp_0_1_outputs,1,-0.11760705,-0.03612627,-0.060250442,0.016026465,0.3172396,-0.77812344
140 | mlp_0_1_outputs,2,-0.23936105,0.00013453455,-0.25397772,0.266447,-0.7957484,-0.7953196
141 | mlp_0_1_outputs,3,-0.34299704,-0.27820325,-0.29482996,-0.2864089,-0.09556553,-1.2629421
142 | mlp_0_1_outputs,4,-0.046016607,-0.09005855,-0.03415023,0.054577596,0.10380587,-0.7631649
143 | mlp_0_1_outputs,5,0.11538427,0.1287231,0.07588654,0.18756787,0.20409665,-0.34829876
144 | mlp_0_1_outputs,6,-0.47443476,-0.08539611,1.0217084,0.75428355,-0.06510375,-1.7716645
145 | mlp_0_1_outputs,7,-0.033662606,0.12398085,0.002802741,0.3118455,0.17265546,-0.92884475
146 | mlp_1_0_outputs,0,0.4595583,0.4955876,0.6625086,-0.30223736,0.2889277,-0.061707802
147 | mlp_1_0_outputs,1,0.12989043,0.2530865,-0.45482454,0.09733099,-2.926608,-0.32248923
148 | mlp_1_0_outputs,2,0.39325613,0.04932119,-1.1195576,0.22859044,-2.9229677,-0.25452712
149 | mlp_1_0_outputs,3,0.24489257,0.1534004,0.32450718,-1.9852786,1.6052507,-1.8961568
150 | mlp_1_0_outputs,4,0.20856915,0.21389735,-0.08142016,-0.114698805,-0.5788899,-0.084191464
151 | mlp_1_0_outputs,5,0.12847619,0.1381205,-0.26468688,-0.016217208,-2.0436559,-0.60245377
152 | mlp_1_0_outputs,6,-0.028417686,0.23583707,0.23858303,0.30995926,-3.0726144,-0.10256881
153 | mlp_1_0_outputs,7,0.5175587,0.68217903,0.32498646,0.61085075,-2.3507814,0.5371939
154 | mlp_1_1_outputs,0,0.45030883,0.4058445,0.3972978,0.28020045,-1.9913944,-0.055018645
155 | mlp_1_1_outputs,1,0.06336323,-0.28932407,0.37485504,0.25855744,1.9029437,0.020531047
156 | mlp_1_1_outputs,2,0.2779919,0.2499768,0.25289297,0.027555976,-1.9493699,-0.6084653
157 | mlp_1_1_outputs,3,0.44192633,0.19696143,0.19470312,0.69076985,-0.7434745,-0.1210918
158 | mlp_1_1_outputs,4,0.17183223,0.33310947,0.48470286,-0.5919067,-0.8927419,-4.1806927
159 | mlp_1_1_outputs,5,0.28217918,0.17491329,0.3567068,-0.2545964,0.08058698,-1.1707318
160 | mlp_1_1_outputs,6,0.61476415,0.21169148,0.082067244,0.027210148,-1.1794713,-0.701124
161 | mlp_1_1_outputs,7,0.3328481,0.29813948,0.22811241,0.22678965,-1.0541234,-0.81014025
162 | ones,_,0.10541394,-0.19712536,-0.008858198,0.58838516,-0.5383584,-1.4631578
163 | positions,0,0.19978875,0.010510841,0.4988928,-0.2598821,0.5445723,0.21678215
164 | positions,1,0.3364794,0.13280986,0.08279502,-0.20376179,-0.1359483,-0.47131416
165 | positions,2,0.31971225,0.13252828,0.15829775,-0.1309726,0.040779695,-0.4894607
166 | positions,3,0.33861408,0.18808766,0.4626563,0.1264201,-0.29951236,-0.6159285
167 | positions,4,0.3720952,0.11060386,0.2411623,-0.2241606,-0.03696618,-0.99401474
168 | positions,5,0.5989128,0.5286751,0.9225798,0.64178616,-0.24038357,-0.25051555
169 | positions,6,0.106012695,-0.12985282,-0.019405967,-0.54309326,0.049963955,-1.0877988
170 | positions,7,0.29646856,0.07841838,-0.053576685,-0.04623412,-1.2492927,-4.2103653
171 | tokens,0,0.48456392,0.24045633,0.19029397,0.19912028,-1.2815764,-1.3616517
172 | tokens,1,-0.2905613,-0.39902,-0.16171901,0.9161303,4.4462442,1.6672348
173 | tokens,2,0.3137043,0.1140619,0.2546798,-0.3095032,-6.1083627,-2.055795
174 | tokens,3,-0.2670635,-0.5329313,-0.38812107,-0.55894905,2.1805744,5.809422
175 | tokens,4,0.6327516,0.46054688,0.57450086,0.070067786,-0.3341348,-7.701936
176 | tokens,5,0.37370822,0.29103732,0.50220376,0.12526518,0.14494029,-0.0737996
177 | tokens,</s>,-0.05186225,0.08119709,0.30115068,0.07776928,-0.2788572,-0.26285934
178 | tokens,<s>,-0.3594653,-0.10468927,0.24319318,0.77997345,-0.14584279,-1.0055608
179 |
--------------------------------------------------------------------------------
/programs/rasp_categorical_only/dyck2/dyck2_weights.csv:
--------------------------------------------------------------------------------
1 | feature,value,F,P,T
2 | attn_0_0_outputs,(,-0.09839721,0.072768934,-1.9487087
3 | attn_0_0_outputs,),0.08454172,-0.8850317,-0.31351435
4 | attn_0_0_outputs,</s>,0.64131826,0.51406443,-1.0410135
5 | attn_0_0_outputs,<s>,-0.6800124,0.23581685,0.8454185
6 | attn_0_0_outputs,_,-0.64708936,0.05047905,-0.22396585
7 | attn_0_0_outputs,_,-0.60640234,-0.123617664,-0.3515302
8 | attn_0_0_outputs,_,0.33327684,0.7507405,0.3113162
9 | attn_0_0_outputs,_,-0.39292628,0.22216852,0.22758712
10 | attn_0_0_outputs,_,-0.10046251,0.024066549,0.026169369
11 | attn_0_0_outputs,_,0.55228436,0.15485866,0.10148037
12 | attn_0_0_outputs,_,0.3600001,0.4491899,0.22103375
13 | attn_0_0_outputs,_,-0.52200395,0.7784305,0.113123655
14 | attn_0_0_outputs,_,-0.59014857,0.64098155,1.3751171
15 | attn_0_0_outputs,_,0.14060214,0.8747294,-0.85699344
16 | attn_0_0_outputs,{,-0.039841566,0.13473256,-2.0053372
17 | attn_0_0_outputs,},0.01718044,-0.9023884,-0.3925461
18 | attn_0_1_outputs,(,0.24663934,0.05382682,-1.4295913
19 | attn_0_1_outputs,),0.44859663,0.11123296,0.21891037
20 | attn_0_1_outputs,</s>,-0.12616456,0.80257934,0.82834697
21 | attn_0_1_outputs,<s>,-3.839294,1.9205002,4.9068
22 | attn_0_1_outputs,_,0.016972326,0.6989716,0.30263668
23 | attn_0_1_outputs,_,0.06236384,0.3147991,0.21376273
24 | attn_0_1_outputs,_,0.05235185,0.2560973,-0.39419228
25 | attn_0_1_outputs,_,-0.70403737,-0.11916539,-0.8784481
26 | attn_0_1_outputs,_,0.180807,0.21920295,0.40988365
27 | attn_0_1_outputs,_,-0.38003322,0.43255183,0.30699322
28 | attn_0_1_outputs,_,-0.17761539,-0.16689041,-0.07989258
29 | attn_0_1_outputs,_,-0.2789575,0.21954398,-0.5772647
30 | attn_0_1_outputs,_,-0.7351252,0.29486123,1.1601452
31 | attn_0_1_outputs,_,0.39496288,0.80411744,-1.136259
32 | attn_0_1_outputs,{,0.25580674,-0.020562408,-1.5278777
33 | attn_0_1_outputs,},0.16623977,-0.2624999,-0.06795137
34 | attn_0_2_outputs,(,0.61922807,-1.373114,2.2359898
35 | attn_0_2_outputs,),0.5194672,-2.4147916,0.6145038
36 | attn_0_2_outputs,</s>,0.56023663,-0.40202478,0.661288
37 | attn_0_2_outputs,<s>,0.08606467,1.1087501,-0.7814447
38 | attn_0_2_outputs,_,0.28296587,0.18189661,-0.49550337
39 | attn_0_2_outputs,_,0.57444793,0.557668,0.13670582
40 | attn_0_2_outputs,_,1.248424,-0.001178341,-0.012084891
41 | attn_0_2_outputs,_,0.35445789,-0.5528627,-0.7319502
42 | attn_0_2_outputs,_,0.14257619,-0.5120995,-0.31158966
43 | attn_0_2_outputs,_,0.41282007,0.03288246,-0.95664227
44 | attn_0_2_outputs,_,0.34463474,-0.52674174,-0.4970775
45 | attn_0_2_outputs,_,0.844124,-0.25249133,-0.31938592
46 | attn_0_2_outputs,_,0.66700053,0.30742458,0.25866547
47 | attn_0_2_outputs,_,1.1990013,0.85531867,-1.7240785
48 | attn_0_2_outputs,{,0.12513652,1.2699282,-0.7781075
49 | attn_0_2_outputs,},0.9145521,-2.0735366,1.2254958
50 | attn_0_3_outputs,(,1.4417167,-0.64466053,-1.18381
51 | attn_0_3_outputs,),-0.6638932,1.0632501,0.49708265
52 | attn_0_3_outputs,</s>,0.45009208,0.54130626,0.25542018
53 | attn_0_3_outputs,<s>,0.70722836,-0.44485033,1.8558304
54 | attn_0_3_outputs,_,-0.3365018,-0.4629079,-0.17494892
55 | attn_0_3_outputs,_,1.1892766,0.6559469,0.19583496
56 | attn_0_3_outputs,_,0.77296907,0.44271478,-0.2764062
57 | attn_0_3_outputs,_,0.29536107,0.55541724,0.8952327
58 | attn_0_3_outputs,_,-0.1426614,0.44414833,0.44125217
59 | attn_0_3_outputs,_,-0.86521447,-0.82092905,0.21177834
60 | attn_0_3_outputs,_,0.032857087,0.3982487,0.5682109
61 | attn_0_3_outputs,_,-0.33331037,0.38984817,0.6000708
62 | attn_0_3_outputs,_,-1.2023804,-0.24906081,0.23746833
63 | attn_0_3_outputs,_,0.20427324,0.14984377,-0.69274366
64 | attn_0_3_outputs,{,1.219885,-0.54653597,-1.0410267
65 | attn_0_3_outputs,},-0.7413577,0.6296828,0.09741894
66 | attn_1_0_outputs,0,0.22650789,-0.18543506,-0.58758974
67 | attn_1_0_outputs,1,0.13576485,1.0912529,-0.5784185
68 | attn_1_0_outputs,10,-0.830527,0.81066996,0.042711593
69 | attn_1_0_outputs,11,0.17748582,-0.5338113,-1.0006839
70 | attn_1_0_outputs,12,0.028687602,0.12396941,-5.01376
71 | attn_1_0_outputs,13,-0.3463128,1.461672,-0.94459385
72 | attn_1_0_outputs,14,1.6825908,-1.798565,-0.6786788
73 | attn_1_0_outputs,15,0.6389121,-0.9041575,-0.96575123
74 | attn_1_0_outputs,2,0.25002298,-1.8573852,1.413765
75 | attn_1_0_outputs,3,0.22070499,-1.2486526,0.710387
76 | attn_1_0_outputs,4,-0.6571064,0.35194886,-1.3422754
77 | attn_1_0_outputs,5,1.9105325,-1.0667194,0.36192167
78 | attn_1_0_outputs,6,0.047967866,-0.09409498,-0.6467464
79 | attn_1_0_outputs,7,-0.86700124,1.1282116,0.33472228
80 | attn_1_0_outputs,8,-0.6687376,1.0361387,-0.20685646
81 | attn_1_0_outputs,9,-0.7377063,1.0348144,-0.33758923
82 | attn_1_1_outputs,0,-0.5052586,0.5057064,0.29024526
83 | attn_1_1_outputs,1,-0.7648216,1.4227237,-0.258629
84 | attn_1_1_outputs,10,0.41181248,-0.6207077,-0.9900048
85 | attn_1_1_outputs,11,0.077770524,-0.9213925,-0.05973171
86 | attn_1_1_outputs,12,0.4023454,1.3574238,-3.6971736
87 | attn_1_1_outputs,13,-2.6055536,0.43831292,1.9638561
88 | attn_1_1_outputs,14,1.7760623,-1.4472873,-0.7355287
89 | attn_1_1_outputs,15,-1.3515613,0.17346454,0.6269158
90 | attn_1_1_outputs,2,1.4712042,-1.8864663,1.1948589
91 | attn_1_1_outputs,3,0.53816074,-0.7384589,0.7789173
92 | attn_1_1_outputs,4,-0.8784697,1.489651,-0.40728238
93 | attn_1_1_outputs,5,2.2822268,-1.2395031,0.87639683
94 | attn_1_1_outputs,6,-0.64847505,-0.52782285,-0.6055985
95 | attn_1_1_outputs,7,0.035765603,1.5342312,0.83606446
96 | attn_1_1_outputs,8,0.82426596,-0.4623573,-0.52451324
97 | attn_1_1_outputs,9,0.26120463,-0.1682919,-0.74863493
98 | attn_1_2_outputs,0,0.2741086,0.8019985,0.09451536
99 | attn_1_2_outputs,1,0.037871893,1.1881831,-0.593388
100 | attn_1_2_outputs,10,-0.10018027,-0.28660047,0.33032972
101 | attn_1_2_outputs,11,-0.15546942,0.14354461,-0.09335824
102 | attn_1_2_outputs,12,-0.75883734,1.1132323,-1.8390666
103 | attn_1_2_outputs,13,-0.91976297,0.20129623,0.7091432
104 | attn_1_2_outputs,14,2.1666486,-1.2059776,-1.1775556
105 | attn_1_2_outputs,15,1.2957777,0.11005765,0.27522662
106 | attn_1_2_outputs,2,2.6221218,-0.79460573,-0.4321904
107 | attn_1_2_outputs,3,-0.9202741,0.67877096,-1.6474462
108 | attn_1_2_outputs,4,-0.80796766,0.73023415,-0.8312531
109 | attn_1_2_outputs,5,1.5846504,-1.2096406,0.7192363
110 | attn_1_2_outputs,6,0.5472782,-1.2913166,-0.11775407
111 | attn_1_2_outputs,7,0.8987451,-1.1431016,0.027534375
112 | attn_1_2_outputs,8,-0.6447031,0.3921986,0.90193
113 | attn_1_2_outputs,9,-0.6580996,0.2183337,0.47104686
114 | attn_1_3_outputs,0,0.7974158,0.076014474,-0.12770599
115 | attn_1_3_outputs,1,0.08960959,0.38123062,-1.230019
116 | attn_1_3_outputs,10,0.46521083,0.018465715,-0.1728664
117 | attn_1_3_outputs,11,0.6438368,-0.79515433,-0.28863785
118 | attn_1_3_outputs,12,0.53023314,-1.1950611,-1.609009
119 | attn_1_3_outputs,13,0.10482195,1.6155713,0.0057584657
120 | attn_1_3_outputs,14,0.95015556,-1.4626787,0.5756642
121 | attn_1_3_outputs,15,1.1534876,0.57546484,-1.156569
122 | attn_1_3_outputs,2,0.014377812,-0.7606671,1.6436303
123 | attn_1_3_outputs,3,0.43214664,-0.30850783,0.011211452
124 | attn_1_3_outputs,4,0.73773974,1.1139972,-0.69320714
125 | attn_1_3_outputs,5,0.5149078,-0.9359474,1.22139
126 | attn_1_3_outputs,6,0.6841033,-0.39748695,-0.08539801
127 | attn_1_3_outputs,7,-0.43467954,1.2192235,-0.83311576
128 | attn_1_3_outputs,8,-0.5582946,0.3580021,0.18355377
129 | attn_1_3_outputs,9,-0.15669319,1.2627395,-0.6961721
130 | attn_2_0_outputs,0,-0.40245613,0.24030568,0.30186835
131 | attn_2_0_outputs,1,-1.3748924,0.71387535,-0.39306825
132 | attn_2_0_outputs,10,-0.59452015,-0.4597474,2.018241
133 | attn_2_0_outputs,11,0.2047717,0.003245883,-0.73424035
134 | attn_2_0_outputs,12,-0.720487,1.5443448,-1.6471566
135 | attn_2_0_outputs,13,0.37173122,-0.028756775,0.36360672
136 | attn_2_0_outputs,14,0.76262903,-0.89649415,-0.5565152
137 | attn_2_0_outputs,15,0.025871454,0.16201858,0.56903595
138 | attn_2_0_outputs,2,2.0154552,-2.0844448,-0.10086185
139 | attn_2_0_outputs,3,0.73500365,-0.98438346,0.101617634
140 | attn_2_0_outputs,4,-1.2039843,2.2337372,0.16640835
141 | attn_2_0_outputs,5,3.293801,-2.8079114,-0.77340686
142 | attn_2_0_outputs,6,-0.7499077,0.33472747,-0.1948529
143 | attn_2_0_outputs,7,-0.119230315,-0.78161407,-0.29735667
144 | attn_2_0_outputs,8,-0.21677648,-0.98522747,-0.09130101
145 | attn_2_0_outputs,9,0.52573526,-0.1827654,0.5464673
146 | attn_2_1_outputs,0,-1.387743,0.9506943,0.5479685
147 | attn_2_1_outputs,1,-0.37947392,0.5295445,0.2626229
148 | attn_2_1_outputs,10,0.13699992,-0.0180703,-1.8421631
149 | attn_2_1_outputs,11,-0.43704024,-0.2988407,-0.862835
150 | attn_2_1_outputs,12,-1.346792,1.172357,-0.39379665
151 | attn_2_1_outputs,13,2.1580062,-1.5489818,-1.5334694
152 | attn_2_1_outputs,14,0.21818756,0.87238884,0.45327023
153 | attn_2_1_outputs,15,1.447946,-0.76852626,-0.29249018
154 | attn_2_1_outputs,2,0.34594288,0.8285091,0.060779355
155 | attn_2_1_outputs,3,-0.054519173,0.14733565,0.14937484
156 | attn_2_1_outputs,4,-0.39207256,1.3152301,-0.71527904
157 | attn_2_1_outputs,5,7.0318747,-4.8902483,-5.9904876
158 | attn_2_1_outputs,6,0.96370083,0.029464895,-0.1656916
159 | attn_2_1_outputs,7,0.58060133,-0.9354646,0.65154225
160 | attn_2_1_outputs,8,0.637,-0.49399912,-0.21180116
161 | attn_2_1_outputs,9,0.19271861,-0.01355111,-1.29539
162 | attn_2_2_outputs,0,-1.387934,0.39397144,0.16978894
163 | attn_2_2_outputs,1,-0.6932238,1.0597913,0.08682805
164 | attn_2_2_outputs,10,-0.17581192,0.5762519,0.04119891
165 | attn_2_2_outputs,11,-0.8586262,-0.40480924,-0.4132628
166 | attn_2_2_outputs,12,-1.6009113,2.973149,-1.1810892
167 | attn_2_2_outputs,13,-0.34387457,-0.35372856,0.28478163
168 | attn_2_2_outputs,14,0.35553423,0.7332589,-0.1989791
169 | attn_2_2_outputs,15,0.69653744,-0.09478814,-0.54141074
170 | attn_2_2_outputs,2,1.1788943,-0.65940076,-0.20502281
171 | attn_2_2_outputs,3,0.65942603,-0.32618254,-0.107519
172 | attn_2_2_outputs,4,-1.3350819,0.6971251,0.092566006
173 | attn_2_2_outputs,5,0.8389036,0.28690335,0.43394306
174 | attn_2_2_outputs,6,0.018443262,0.20905562,0.04728933
175 | attn_2_2_outputs,7,-0.85211277,0.10939689,1.3427547
176 | attn_2_2_outputs,8,0.25118992,0.59912837,-0.42855418
177 | attn_2_2_outputs,9,0.3361104,-0.14676784,-0.7411014
178 | attn_2_3_outputs,0,-0.5618858,0.5809453,1.0021099
179 | attn_2_3_outputs,1,-0.5739619,1.7692959,-0.378794
180 | attn_2_3_outputs,10,0.035676032,-0.6341216,0.28642064
181 | attn_2_3_outputs,11,0.07306448,0.17111002,0.2374836
182 | attn_2_3_outputs,12,-1.188873,1.1038005,-1.1276703
183 | attn_2_3_outputs,13,0.54361314,-1.0640161,0.84797543
184 | attn_2_3_outputs,14,1.3631084,-0.56192636,0.4859162
185 | attn_2_3_outputs,15,0.48447692,-0.37444183,0.39286935
186 | attn_2_3_outputs,2,8.552915,-5.3725734,-6.36959
187 | attn_2_3_outputs,3,-0.11614109,0.30205145,0.12240872
188 | attn_2_3_outputs,4,-1.6696448,1.1503662,-0.46547452
189 | attn_2_3_outputs,5,3.2380612,-1.0949727,0.22095455
190 | attn_2_3_outputs,6,0.10913107,0.03129147,-1.241605
191 | attn_2_3_outputs,7,0.0360893,-1.0509039,0.4309466
192 | attn_2_3_outputs,8,2.3759365,-1.4136782,-1.8956783
193 | attn_2_3_outputs,9,2.515315,-1.9465212,-1.5227278
194 | mlp_0_0_outputs,0,-0.20890403,0.14091475,0.7584955
195 | mlp_0_0_outputs,1,0.3021631,0.1919344,0.1126692
196 | mlp_0_0_outputs,10,-2.756637,1.5794932,1.9001974
197 | mlp_0_0_outputs,11,0.35167038,0.067295164,0.41860104
198 | mlp_0_0_outputs,12,-0.013727432,0.9100293,-2.461872
199 | mlp_0_0_outputs,13,0.46883824,0.21986999,0.37993425
200 | mlp_0_0_outputs,14,-0.018150175,0.22210702,-0.48269686
201 | mlp_0_0_outputs,15,1.2375835,0.29053566,-0.856558
202 | mlp_0_0_outputs,2,1.1214304,-1.9128251,0.6156005
203 | mlp_0_0_outputs,3,0.567998,-0.43510225,0.60242385
204 | mlp_0_0_outputs,4,0.15451127,-0.5385739,-0.78529364
205 | mlp_0_0_outputs,5,1.6117405,-2.094646,-0.42933798
206 | mlp_0_0_outputs,6,0.87502843,0.53566813,-0.9268592
207 | mlp_0_0_outputs,7,0.31717533,-0.27278233,-0.0015997723
208 | mlp_0_0_outputs,8,1.5352674,0.55484116,-6.1980653
209 | mlp_0_0_outputs,9,0.97670096,-0.14376327,-6.9078407
210 | mlp_0_1_outputs,0,0.9233807,0.2854663,-0.8436871
211 | mlp_0_1_outputs,1,0.24584311,0.2944399,0.66519386
212 | mlp_0_1_outputs,10,0.22019076,0.81191224,0.5874235
213 | mlp_0_1_outputs,11,0.11129518,-0.08004127,-0.122246675
214 | mlp_0_1_outputs,12,0.32788306,0.6816243,-2.6723564
215 | mlp_0_1_outputs,13,-0.3635576,0.86900264,-0.9099431
216 | mlp_0_1_outputs,14,0.5210501,-0.018928718,-0.46168962
217 | mlp_0_1_outputs,15,0.93055755,-0.023645593,-0.7769493
218 | mlp_0_1_outputs,2,-1.2583181,1.8995489,0.5895376
219 | mlp_0_1_outputs,3,1.0925126,-1.0588683,-1.3893077
220 | mlp_0_1_outputs,4,0.53683454,0.7541305,0.17249009
221 | mlp_0_1_outputs,5,0.5254129,0.41082186,0.31906855
222 | mlp_0_1_outputs,6,-0.30498037,-1.0835053,0.2178056
223 | mlp_0_1_outputs,7,-0.6612743,0.8853988,0.6170607
224 | mlp_0_1_outputs,8,-0.46027464,9.3793526e-05,-0.13259079
225 | mlp_0_1_outputs,9,-0.7891113,1.0600111,0.6583635
226 | mlp_1_0_outputs,0,0.34241092,0.2945056,-0.1292986
227 | mlp_1_0_outputs,1,0.47135493,0.46016377,-0.43053067
228 | mlp_1_0_outputs,10,0.2581985,0.5706344,0.19086708
229 | mlp_1_0_outputs,11,0.26166466,0.90268534,-0.17562716
230 | mlp_1_0_outputs,12,-0.17449586,0.33637732,0.16877215
231 | mlp_1_0_outputs,13,-1.4794242,1.719357,0.39456385
232 | mlp_1_0_outputs,14,0.77961195,0.008924905,-0.09428507
233 | mlp_1_0_outputs,15,0.08689608,0.9701745,0.09053159
234 | mlp_1_0_outputs,2,1.1217159,-0.7634738,-1.2014244
235 | mlp_1_0_outputs,3,0.39370438,-0.41021878,-0.7921828
236 | mlp_1_0_outputs,4,-0.4663107,1.2819383,0.4163665
237 | mlp_1_0_outputs,5,0.49742854,0.008245769,-0.45130947
238 | mlp_1_0_outputs,6,0.14249237,-0.102100946,-0.5098823
239 | mlp_1_0_outputs,7,0.30476442,0.62488115,-1.3264716
240 | mlp_1_0_outputs,8,0.08670288,-0.42941815,-0.7135336
241 | mlp_1_0_outputs,9,0.26799846,-0.09152536,-1.3237554
242 | mlp_1_1_outputs,0,0.26805434,0.23167966,0.4149677
243 | mlp_1_1_outputs,1,0.53618395,-0.5244912,-2.6226473
244 | mlp_1_1_outputs,10,2.2270403,-0.6349289,-3.1257336
245 | mlp_1_1_outputs,11,-0.14030638,-0.15512456,0.7863296
246 | mlp_1_1_outputs,12,-0.5041155,0.5484189,-0.19427307
247 | mlp_1_1_outputs,13,-0.26450726,-0.33360416,0.79678315
248 | mlp_1_1_outputs,14,-0.9820768,0.5201349,0.059233483
249 | mlp_1_1_outputs,15,-0.7036932,0.4160271,-0.47682777
250 | mlp_1_1_outputs,2,-0.39756742,0.5957785,0.6356138
251 | mlp_1_1_outputs,3,-0.09534266,-0.17529783,0.72164303
252 | mlp_1_1_outputs,4,0.88351613,0.77100885,0.38341337
253 | mlp_1_1_outputs,5,-0.011537257,-0.18188691,0.7026197
254 | mlp_1_1_outputs,6,-0.6994262,0.019658938,-0.47276354
255 | mlp_1_1_outputs,7,0.57141864,-0.04391361,0.055888955
256 | mlp_1_1_outputs,8,0.043691415,0.07147635,0.50050634
257 | mlp_1_1_outputs,9,-0.35295078,0.23593843,0.7010527
258 | mlp_2_0_outputs,0,1.8654248,-0.76327336,-2.1851802
259 | mlp_2_0_outputs,1,2.4732018,2.0428228,-2.8090878
260 | mlp_2_0_outputs,10,0.50512856,-0.4752733,-0.42407128
261 | mlp_2_0_outputs,11,-0.2162885,-0.041211855,-0.13339278
262 | mlp_2_0_outputs,12,0.2361639,-0.9290504,0.4851906
263 | mlp_2_0_outputs,13,0.34775272,-0.14868683,0.1458214
264 | mlp_2_0_outputs,14,0.28508988,0.33885202,0.37766585
265 | mlp_2_0_outputs,15,-0.08315753,0.28621098,0.066261254
266 | mlp_2_0_outputs,2,0.5492595,-0.30825824,-0.66436327
267 | mlp_2_0_outputs,3,0.27374202,0.40835023,0.27150822
268 | mlp_2_0_outputs,4,1.1041646,0.44970506,-0.008170099
269 | mlp_2_0_outputs,5,0.14885116,0.48616803,-0.02961562
270 | mlp_2_0_outputs,6,-0.05234915,0.3683034,-2.0102525
271 | mlp_2_0_outputs,7,-2.398438,0.019524893,8.194051
272 | mlp_2_0_outputs,8,0.25424272,0.4785088,-0.41905257
273 | mlp_2_0_outputs,9,-0.14737429,-0.1506395,0.16152707
274 | mlp_2_1_outputs,0,3.7675667,-1.4037446,-3.882887
275 | mlp_2_1_outputs,1,0.38408276,-0.20874391,-0.79680127
276 | mlp_2_1_outputs,10,0.82168883,0.15445253,-0.35908702
277 | mlp_2_1_outputs,11,0.40715832,-0.078189924,-0.23746909
278 | mlp_2_1_outputs,12,0.19456168,0.5811429,0.19617397
279 | mlp_2_1_outputs,13,-0.18762338,-0.037409503,-0.88011366
280 | mlp_2_1_outputs,14,-0.48191726,0.7540797,0.049526248
281 | mlp_2_1_outputs,15,0.704309,0.12997411,0.05762332
282 | mlp_2_1_outputs,2,-1.4413515,1.5535452,0.43207467
283 | mlp_2_1_outputs,3,0.25547877,-0.8476021,-0.36438197
284 | mlp_2_1_outputs,4,0.67018205,-0.10163318,-0.2500151
285 | mlp_2_1_outputs,5,-0.094072334,-0.73808116,-0.67700255
286 | mlp_2_1_outputs,6,0.32068485,0.49894604,-0.4640106
287 | mlp_2_1_outputs,7,0.6843552,-0.1548457,-0.12560412
288 | mlp_2_1_outputs,8,-0.38264602,-0.3764761,-0.9299223
289 | mlp_2_1_outputs,9,0.104728,-0.70314866,-0.6734416
290 | ones,_,0.25881597,-0.46701866,0.16698436
291 | positions,0,-0.2762641,-0.17773171,-1.017329
292 | positions,1,0.26486596,0.65079963,-2.5160186
293 | positions,10,0.14841577,-0.47636732,1.5774705
294 | positions,11,1.761886,1.1001859,-10.36295
295 | positions,12,-0.68251526,-2.187309,2.290777
296 | positions,13,1.9318004,1.8703078,-8.895537
297 | positions,14,-1.3542655,-2.2336614,3.640843
298 | positions,15,1.3797264,1.0626522,-7.66056
299 | positions,2,-4.7400956,-4.6051226,10.471129
300 | positions,3,2.772535,2.4986095,-10.479663
301 | positions,4,-3.1975293,-2.1136668,8.509926
302 | positions,5,0.9007016,2.1843433,-7.800698
303 | positions,6,-2.446595,-0.2883177,4.8488107
304 | positions,7,1.4072119,2.093496,-6.3489823
305 | positions,8,-0.5884627,-0.3837479,1.0950348
306 | positions,9,1.4770374,2.9142487,-7.4769826
307 | tokens,(,1.0545344,1.422954,-4.159084
308 | tokens,),-0.44501638,-0.9425237,3.0321198
309 | tokens,</s>,0.24174905,-0.15760137,0.25735503
310 | tokens,<s>,1.4156133,-0.7207831,-0.14946787
311 | tokens,_,0.32973674,0.1480116,0.40687805
312 | tokens,_,0.4269608,0.18007647,0.0591342
313 | tokens,_,0.012144065,0.24502037,-0.37508482
314 | tokens,_,1.0130444,0.2379365,0.7822351
315 | tokens,_,0.24589725,0.05158651,-0.009494092
316 | tokens,_,0.13942418,0.16681431,0.54903275
317 | tokens,_,-0.2541933,0.61024016,0.46229398
318 | tokens,_,-0.67985225,0.38225034,-0.016558083
319 | tokens,_,1.4591968,-0.29021746,-0.55709064
320 | tokens,_,0.14226177,0.10692842,0.23002318
321 | tokens,{,0.9597864,1.4787383,-3.468003
322 | tokens,},-0.43541923,-1.1256342,2.860657
323 |
--------------------------------------------------------------------------------
/programs/rasp_categorical_only/sort/sort.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 |
5 | def select_closest(keys, queries, predicate):  # hard attention: each query attends to the closest matching key
6 | scores = [[False for _ in keys] for _ in queries]
7 | for i, q in enumerate(queries):
8 | matches = [j for j, k in enumerate(keys) if predicate(q, k)]
9 | if not matches:  # no key matches: fall back to attending to position 0
10 | scores[i][0] = True
11 | else:
12 | j = min(matches, key=lambda j: len(matches) if j == i else abs(i - j))
13 | scores[i][j] = True
14 | return scores
15 |
16 |
17 | def aggregate(attention, values):  # for each query position, read out the value at its attended position
18 | return [[v for a, v in zip(attn, values) if a][0] for attn in attention]
19 |
20 |
21 | def run(tokens):
22 | # classifier weights ##########################################
23 | classifier_weights = pd.read_csv(
24 | "programs/rasp_categorical_only/sort/sort_weights.csv",
25 | index_col=[0, 1],
26 | dtype={"feature": str},
27 | )
28 | # inputs #####################################################
29 | token_scores = classifier_weights.loc[[("tokens", str(v)) for v in tokens]]
30 |
31 | positions = list(range(len(tokens)))
32 | position_scores = classifier_weights.loc[[("positions", str(v)) for v in positions]]
33 |
34 | ones = [1 for _ in range(len(tokens))]
35 | one_scores = classifier_weights.loc[[("ones", "_") for v in ones]].mul(ones, axis=0)
36 |
37 | # attn_0_0 ####################################################
38 | def predicate_0_0(q_token, k_token):
39 | if q_token in {"", "0"}:
40 | return k_token == "0"
41 | elif q_token in {"1"}:
42 | return k_token == "1"
43 | elif q_token in {"2"}:
44 | return k_token == "2"
45 | elif q_token in {"3"}:
46 | return k_token == "3"
47 | elif q_token in {"4", ""}:
48 | return k_token == "4"
49 |
50 | attn_0_0_pattern = select_closest(tokens, tokens, predicate_0_0)
51 | attn_0_0_outputs = aggregate(attn_0_0_pattern, tokens)
52 | attn_0_0_output_scores = classifier_weights.loc[
53 | [("attn_0_0_outputs", str(v)) for v in attn_0_0_outputs]
54 | ]
55 |
56 | # attn_0_1 ####################################################
57 | def predicate_0_1(q_position, k_position):
58 | if q_position in {0}:
59 | return k_position == 1
60 | elif q_position in {1, 6}:
61 | return k_position == 2
62 | elif q_position in {2}:
63 | return k_position == 3
64 | elif q_position in {3, 5}:
65 | return k_position == 4
66 | elif q_position in {4, 7}:
67 | return k_position == 6
68 |
69 | attn_0_1_pattern = select_closest(positions, positions, predicate_0_1)
70 | attn_0_1_outputs = aggregate(attn_0_1_pattern, tokens)
71 | attn_0_1_output_scores = classifier_weights.loc[
72 | [("attn_0_1_outputs", str(v)) for v in attn_0_1_outputs]
73 | ]
74 |
75 | # attn_0_2 ####################################################
76 | def predicate_0_2(q_position, k_position):
77 | if q_position in {0}:
78 | return k_position == 1
79 | elif q_position in {1}:
80 | return k_position == 3
81 | elif q_position in {2, 6}:
82 | return k_position == 4
83 | elif q_position in {3, 4}:
84 | return k_position == 5
85 | elif q_position in {5, 7}:
86 | return k_position == 6
87 |
88 | attn_0_2_pattern = select_closest(positions, positions, predicate_0_2)
89 | attn_0_2_outputs = aggregate(attn_0_2_pattern, tokens)
90 | attn_0_2_output_scores = classifier_weights.loc[
91 | [("attn_0_2_outputs", str(v)) for v in attn_0_2_outputs]
92 | ]
93 |
94 | # attn_0_3 ####################################################
95 | def predicate_0_3(q_position, k_position):
96 | if q_position in {0, 4}:
97 | return k_position == 3
98 | elif q_position in {1, 7}:
99 | return k_position == 4
100 | elif q_position in {2}:
101 | return k_position == 1
102 | elif q_position in {3, 5}:
103 | return k_position == 2
104 | elif q_position in {6}:
105 | return k_position == 5
106 |
107 | attn_0_3_pattern = select_closest(positions, positions, predicate_0_3)
108 | attn_0_3_outputs = aggregate(attn_0_3_pattern, tokens)
109 | attn_0_3_output_scores = classifier_weights.loc[
110 | [("attn_0_3_outputs", str(v)) for v in attn_0_3_outputs]
111 | ]
112 |
113 | # mlp_0_0 #####################################################
114 | def mlp_0_0(attn_0_2_output, position):
115 | key = (attn_0_2_output, position)
116 | if key in {
117 | ("0", 1),
118 | ("0", 2),
119 | ("0", 3),
120 | ("0", 4),
121 | ("1", 1),
122 | ("1", 2),
123 | ("2", 1),
124 | ("2", 2),
125 | ("3", 1),
126 | ("4", 1),
127 | ("", 1),
128 | }:
129 | return 0
130 | elif key in {
131 | ("0", 0),
132 | ("0", 6),
133 | ("1", 0),
134 | ("1", 6),
135 | ("", 0),
136 | ("", 2),
137 | ("", 3),
138 | ("", 5),
139 | ("", 6),
140 | ("", 7),
141 | }:
142 | return 2
143 | elif key in {
144 | ("0", 5),
145 | ("0", 7),
146 | ("1", 3),
147 | ("1", 4),
148 | ("1", 5),
149 | ("1", 7),
150 | ("3", 2),
151 | ("4", 2),
152 | ("", 1),
153 | }:
154 | return 7
155 | elif key in {
156 | ("3", 4),
157 | ("3", 5),
158 | ("3", 7),
159 | ("4", 4),
160 | ("4", 5),
161 | ("4", 7),
162 | ("", 4),
163 | }:
164 | return 1
165 | elif key in {("2", 3), ("2", 4), ("2", 5), ("2", 7), ("3", 3), ("4", 3)}:
166 | return 6
167 | return 4
168 |
169 | mlp_0_0_outputs = [mlp_0_0(k0, k1) for k0, k1 in zip(attn_0_2_outputs, positions)]
170 | mlp_0_0_output_scores = classifier_weights.loc[
171 | [("mlp_0_0_outputs", str(v)) for v in mlp_0_0_outputs]
172 | ]
173 |
174 | # mlp_0_1 #####################################################
175 | def mlp_0_1(position):
176 | key = position
177 | if key in {1, 2}:
178 | return 7
179 | elif key in {5}:
180 | return 1
181 | return 6
182 |
183 | mlp_0_1_outputs = [mlp_0_1(k0) for k0 in positions]
184 | mlp_0_1_output_scores = classifier_weights.loc[
185 | [("mlp_0_1_outputs", str(v)) for v in mlp_0_1_outputs]
186 | ]
187 |
188 | # mlp_0_2 #####################################################
189 | def mlp_0_2(attn_0_1_output, position):
190 | key = (attn_0_1_output, position)
191 | if key in {
192 | ("1", 4),
193 | ("2", 2),
194 | ("2", 3),
195 | ("2", 4),
196 | ("2", 7),
197 | ("3", 2),
198 | ("3", 3),
199 | ("3", 4),
200 | ("3", 7),
201 | ("4", 2),
202 | ("4", 3),
203 | ("4", 4),
204 | ("4", 7),
205 | }:
206 | return 3
207 | elif key in {
208 | ("0", 0),
209 | ("0", 1),
210 | ("0", 2),
211 | ("0", 3),
212 | ("0", 4),
213 | ("0", 7),
214 | ("1", 1),
215 | ("2", 1),
216 | ("3", 1),
217 | ("4", 1),
218 | ("", 1),
219 | }:
220 | return 6
221 | elif key in {
222 | ("0", 5),
223 | ("1", 2),
224 | ("1", 3),
225 | ("1", 7),
226 | ("", 1),
227 | ("", 2),
228 | ("", 3),
229 | ("", 7),
230 | }:
231 | return 5
232 | return 1
233 |
234 | mlp_0_2_outputs = [mlp_0_2(k0, k1) for k0, k1 in zip(attn_0_1_outputs, positions)]
235 | mlp_0_2_output_scores = classifier_weights.loc[
236 | [("mlp_0_2_outputs", str(v)) for v in mlp_0_2_outputs]
237 | ]
238 |
239 | # mlp_0_3 #####################################################
240 | def mlp_0_3(attn_0_2_output, position):
241 | key = (attn_0_2_output, position)
242 | if key in {
243 | ("0", 4),
244 | ("0", 7),
245 | ("1", 4),
246 | ("2", 4),
247 | ("2", 7),
248 | ("3", 4),
249 | ("3", 7),
250 | ("4", 4),
251 | ("</s>", 0),
252 | ("</s>", 1),
253 | ("</s>", 2),
254 | ("</s>", 3),
255 | ("</s>", 4),
256 | ("</s>", 7),
257 | ("<s>", 0),
258 | ("<s>", 1),
259 | ("<s>", 2),
260 | ("<s>", 3),
261 | ("<s>", 4),
262 | ("<s>", 6),
263 | ("<s>", 7),
264 | }:
265 | return 0
266 | elif key in {
267 | ("0", 5),
268 | ("0", 6),
269 | ("1", 5),
270 | ("1", 6),
271 | ("1", 7),
272 | ("2", 5),
273 | ("2", 6),
274 | ("3", 5),
275 | ("3", 6),
276 | ("4", 5),
277 | ("4", 6),
278 | ("</s>", 5),
279 | ("</s>", 6),
280 | ("<s>", 5),
281 | }:
282 | return 7
283 | return 5
284 |
285 | mlp_0_3_outputs = [mlp_0_3(k0, k1) for k0, k1 in zip(attn_0_2_outputs, positions)]
286 | mlp_0_3_output_scores = classifier_weights.loc[
287 | [("mlp_0_3_outputs", str(v)) for v in mlp_0_3_outputs]
288 | ]
289 |
290 | # attn_1_0 ####################################################
291 | def predicate_1_0(position, mlp_0_1_output):
292 | if position in {0, 1, 2, 3}:
293 | return mlp_0_1_output == 1
294 | elif position in {4, 5, 7}:
295 | return mlp_0_1_output == 0
296 | elif position in {6}:
297 | return mlp_0_1_output == 3
298 |
299 | attn_1_0_pattern = select_closest(mlp_0_1_outputs, positions, predicate_1_0)
300 | attn_1_0_outputs = aggregate(attn_1_0_pattern, attn_0_2_outputs)
301 | attn_1_0_output_scores = classifier_weights.loc[
302 | [("attn_1_0_outputs", str(v)) for v in attn_1_0_outputs]
303 | ]
304 |
305 | # attn_1_1 ####################################################
306 | def predicate_1_1(q_position, k_position):
307 | if q_position in {0, 4, 5, 7}:
308 | return k_position == 1
309 | elif q_position in {1}:
310 | return k_position == 0
311 | elif q_position in {2}:
312 | return k_position == 5
313 | elif q_position in {3}:
314 | return k_position == 6
315 | elif q_position in {6}:
316 | return k_position == 7
317 |
318 | attn_1_1_pattern = select_closest(positions, positions, predicate_1_1)
319 | attn_1_1_outputs = aggregate(attn_1_1_pattern, attn_0_0_outputs)
320 | attn_1_1_output_scores = classifier_weights.loc[
321 | [("attn_1_1_outputs", str(v)) for v in attn_1_1_outputs]
322 | ]
323 |
324 | # attn_1_2 ####################################################
325 | def predicate_1_2(q_position, k_position):
326 | if q_position in {0, 3}:
327 | return k_position == 1
328 | elif q_position in {1}:
329 | return k_position == 5
330 | elif q_position in {2}:
331 | return k_position == 6
332 | elif q_position in {4, 7}:
333 | return k_position == 2
334 | elif q_position in {5, 6}:
335 | return k_position == 3
336 |
337 | attn_1_2_pattern = select_closest(positions, positions, predicate_1_2)
338 | attn_1_2_outputs = aggregate(attn_1_2_pattern, tokens)
339 | attn_1_2_output_scores = classifier_weights.loc[
340 | [("attn_1_2_outputs", str(v)) for v in attn_1_2_outputs]
341 | ]
342 |
343 | # attn_1_3 ####################################################
344 | def predicate_1_3(position, attn_0_1_output):
345 | if position in {0}:
346 | return attn_0_1_output == "1"
347 | elif position in {1, 2}:
348 | return attn_0_1_output == "0"
349 | elif position in {3}:
350 | return attn_0_1_output == ""
351 | elif position in {4, 5, 6}:
352 | return attn_0_1_output == "4"
353 | elif position in {7}:
354 | return attn_0_1_output == "3"
355 |
356 | attn_1_3_pattern = select_closest(attn_0_1_outputs, positions, predicate_1_3)
357 | attn_1_3_outputs = aggregate(attn_1_3_pattern, attn_0_1_outputs)
358 | attn_1_3_output_scores = classifier_weights.loc[
359 | [("attn_1_3_outputs", str(v)) for v in attn_1_3_outputs]
360 | ]
361 |
362 | # mlp_1_0 #####################################################
363 | def mlp_1_0(position, attn_1_2_output):
364 | key = (position, attn_1_2_output)
365 | if key in {
366 | (0, "0"),
367 | (1, "0"),
368 | (2, "0"),
369 | (2, "1"),
370 | (2, "2"),
371 | (2, "3"),
372 | (2, "4"),
373 | (3, "0"),
374 | }:
375 | return 5
376 | elif key in {(1, "1"), (1, "2"), (1, "3"), (1, "4")}:
377 | return 1
378 | elif key in {(0, "2"), (4, "2"), (5, "2")}:
379 | return 4
380 | elif key in {(0, "1"), (3, "1"), (4, "0")}:
381 | return 6
382 | elif key in {(3, "2"), (4, "1"), (5, "1")}:
383 | return 7
384 | elif key in {(5, "0")}:
385 | return 2
386 | return 0
387 |
388 | mlp_1_0_outputs = [mlp_1_0(k0, k1) for k0, k1 in zip(positions, attn_1_2_outputs)]
389 | mlp_1_0_output_scores = classifier_weights.loc[
390 | [("mlp_1_0_outputs", str(v)) for v in mlp_1_0_outputs]
391 | ]
392 |
393 | # mlp_1_1 #####################################################
394 | def mlp_1_1(attn_1_3_output, attn_1_0_output):
395 | key = (attn_1_3_output, attn_1_0_output)
396 | if key in {
397 | ("3", "2"),
398 | ("3", "3"),
399 | ("3", "4"),
400 | ("3", "</s>"),
401 | ("3", "<s>"),
402 | ("4", "2"),
403 | ("4", "3"),
404 | ("4", "4"),
405 | ("4", "</s>"),
406 | ("4", "<s>"),
407 | ("</s>", "2"),
408 | ("</s>", "3"),
409 | ("</s>", "4"),
410 | ("</s>", "</s>"),
411 | ("</s>", "<s>"),
412 | ("<s>", "3"),
413 | ("<s>", "4"),
414 | ("<s>", "</s>"),
415 | ("<s>", "<s>"),
416 | }:
417 | return 4
418 | elif key in {("0", ""), ("2", "")}:
419 | return 6
420 | return 1
421 |
422 | mlp_1_1_outputs = [
423 | mlp_1_1(k0, k1) for k0, k1 in zip(attn_1_3_outputs, attn_1_0_outputs)
424 | ]
425 | mlp_1_1_output_scores = classifier_weights.loc[
426 | [("mlp_1_1_outputs", str(v)) for v in mlp_1_1_outputs]
427 | ]
428 |
429 | # mlp_1_2 #####################################################
430 | def mlp_1_2(attn_1_0_output, position):
431 | key = (attn_1_0_output, position)
432 | if key in {
433 | ("2", 4),
434 | ("2", 5),
435 | ("3", 4),
436 | ("3", 5),
437 | ("4", 0),
438 | ("4", 4),
439 | ("4", 5),
440 | ("</s>", 4),
441 | ("</s>", 5),
442 | ("<s>", 4),
443 | ("<s>", 5),
444 | }:
445 | return 0
446 | elif key in {
447 | ("0", 3),
448 | ("1", 3),
449 | ("2", 3),
450 | ("3", 0),
451 | ("3", 3),
452 | ("4", 3),
453 | ("</s>", 3),
454 | ("<s>", 3),
455 | }:
456 | return 7
457 | elif key in {
458 | ("2", 1),
459 | ("2", 7),
460 | ("3", 1),
461 | ("3", 7),
462 | ("4", 1),
463 | ("4", 7),
464 | ("", 1),
465 | ("", 7),
466 | }:
467 | return 1
468 | elif key in {("0", 4), ("0", 5), ("1", 4), ("1", 5)}:
469 | return 4
470 | elif key in {("2", 0), ("2", 6), ("3", 6), ("4", 6)}:
471 | return 5
472 | elif key in {("1", 1), ("1", 7), ("</s>", 6), ("<s>", 6)}:
473 | return 2
474 | elif key in {("4", 2), ("", 0), ("", 2), ("", 7)}:
475 | return 6
476 | return 3
477 |
478 | mlp_1_2_outputs = [mlp_1_2(k0, k1) for k0, k1 in zip(attn_1_0_outputs, positions)]
479 | mlp_1_2_output_scores = classifier_weights.loc[
480 | [("mlp_1_2_outputs", str(v)) for v in mlp_1_2_outputs]
481 | ]
482 |
483 | # mlp_1_3 #####################################################
484 | def mlp_1_3(attn_1_0_output, attn_1_3_output):
485 | key = (attn_1_0_output, attn_1_3_output)
486 | if key in {
487 | ("0", "3"),
488 | ("0", "</s>"),
489 | ("1", "3"),
490 | ("2", "3"),
491 | ("3", "0"),
492 | ("3", "1"),
493 | ("3", "2"),
494 | ("3", "3"),
495 | ("3", "</s>"),
496 | ("</s>", "0"),
497 | ("</s>", "1"),
498 | ("</s>", "2"),
499 | ("</s>", "3"),
500 | ("</s>", "</s>"),
501 | ("<s>", "3"),
502 | }:
503 | return 5
504 | elif key in {("3", "4")}:
505 | return 2
506 | return 7
507 |
508 | mlp_1_3_outputs = [
509 | mlp_1_3(k0, k1) for k0, k1 in zip(attn_1_0_outputs, attn_1_3_outputs)
510 | ]
511 | mlp_1_3_output_scores = classifier_weights.loc[
512 | [("mlp_1_3_outputs", str(v)) for v in mlp_1_3_outputs]
513 | ]
514 |
515 | feature_logits = pd.concat(
516 | [
517 | df.reset_index()
518 | for df in [
519 | token_scores,
520 | position_scores,
521 | attn_0_0_output_scores,
522 | attn_0_1_output_scores,
523 | attn_0_2_output_scores,
524 | attn_0_3_output_scores,
525 | mlp_0_0_output_scores,
526 | mlp_0_1_output_scores,
527 | mlp_0_2_output_scores,
528 | mlp_0_3_output_scores,
529 | attn_1_0_output_scores,
530 | attn_1_1_output_scores,
531 | attn_1_2_output_scores,
532 | attn_1_3_output_scores,
533 | mlp_1_0_output_scores,
534 | mlp_1_1_output_scores,
535 | mlp_1_2_output_scores,
536 | mlp_1_3_output_scores,
537 | one_scores,
538 | ]
539 | ]
540 | )
541 | logits = feature_logits.groupby(level=0).sum(numeric_only=True).to_numpy()
542 | classes = classifier_weights.columns.to_numpy()
543 | predictions = classes[logits.argmax(-1)]
544 | if tokens[0] == "<s>":
545 | predictions[0] = "<s>"
546 | if tokens[-1] == "</s>":
547 | predictions[-1] = "</s>"
548 | return predictions.tolist()
549 |
550 |
551 | examples = [
552 | (
553 | ["<s>", "0", "4", "1", "1", "4", "2", "</s>"],
554 | ["<s>", "0", "1", "1", "2", "4", "4", "</s>"],
555 | ),
556 | (
557 | ["<s>", "4", "4", "3", "0", "2", "</s>"],
558 | ["<s>", "0", "2", "3", "4", "4", "</s>"],
559 | ),
560 | (
561 | ["<s>", "3", "0", "2", "2", "3", "3", "</s>"],
562 | ["<s>", "0", "2", "2", "3", "3", "3", "</s>"],
563 | ),
564 | (
565 | ["<s>", "2", "4", "2", "0", "0", "3", "</s>"],
566 | ["<s>", "0", "0", "2", "2", "3", "4", "</s>"],
567 | ),
568 | (
569 | ["<s>", "0", "0", "2", "0", "2", "3", "</s>"],
570 | ["<s>", "0", "0", "0", "2", "2", "3", "</s>"],
571 | ),
572 | (
573 | ["<s>", "0", "1", "0", "4", "3", "</s>"],
574 | ["<s>", "0", "0", "1", "3", "4", "</s>"],
575 | ),
576 | (
577 | ["<s>", "4", "2", "1", "2", "4", "3", "</s>"],
578 | ["<s>", "1", "2", "2", "3", "4", "4", "</s>"],
579 | ),
580 | (["<s>", "4", "3", "</s>"], ["<s>", "3", "4", "</s>"]),
581 | (
582 | ["<s>", "1", "1", "0", "4", "2", "1", "</s>"],
583 | ["<s>", "0", "1", "1", "1", "2", "4", "</s>"],
584 | ),
585 | (
586 | ["<s>", "0", "1", "1", "3", "1", "</s>"],
587 | ["<s>", "0", "1", "1", "1", "3", "</s>"],
588 | ),
589 | ]
590 | for x, y in examples:
591 | print(f"x: {x}")
592 | print(f"y: {y}")
593 | y_hat = run(x)
594 | print(f"y_hat: {y_hat}")
595 | print()
596 |
--------------------------------------------------------------------------------
/programs/rasp_categorical_only/sort/sort_weights.csv:
--------------------------------------------------------------------------------
1 | feature,value,0,1,2,3,4
2 | attn_0_0_outputs,0,6.9886675,1.9029737,-1.5632732,-4.666394,-6.6998963
3 | attn_0_0_outputs,1,2.808001,3.347147,0.15896283,-2.8454115,-5.028731
4 | attn_0_0_outputs,2,-0.9442088,0.086697936,2.8916383,0.13243037,-2.0047884
5 | attn_0_0_outputs,3,-5.0501323,-3.2974973,-0.2889569,3.5421517,1.7046565
6 | attn_0_0_outputs,4,-6.3542542,-4.6292524,-1.617208,2.236299,7.221099
7 | attn_0_0_outputs,</s>,-1.2565033,-0.85322773,0.024527336,0.8092462,0.7044034
8 | attn_0_0_outputs,<pad>,0.215054,-0.6509929,-0.43638363,0.19641109,0.49052995
9 | attn_0_0_outputs,<s>,-0.58543414,-0.4350722,-0.11593681,-0.010477837,0.035624333
10 | attn_0_1_outputs,0,8.321785,2.2734017,-1.3977479,-7.016843,-9.175802
11 | attn_0_1_outputs,1,6.1663084,8.126509,-0.13728762,-7.441755,-11.841383
12 | attn_0_1_outputs,2,-0.9028804,0.18783027,5.3251595,-2.1351826,-6.537742
13 | attn_0_1_outputs,3,-6.775278,-5.8244658,-1.1134557,7.1989717,1.5569733
14 | attn_0_1_outputs,4,-9.164681,-8.360944,-3.933387,3.6149611,13.612795
15 | attn_0_1_outputs,</s>,-3.547872,-4.1135654,-2.683632,2.05698,6.5400867
16 | attn_0_1_outputs,<pad>,-0.20318432,0.13683431,-0.14054358,0.1945901,0.143777
17 | attn_0_1_outputs,<s>,-0.2252483,-0.39968184,-1.0448881,0.33150062,0.60946065
18 | attn_0_2_outputs,0,12.662373,2.54436,-4.249066,-8.406431,-10.299077
19 | attn_0_2_outputs,1,2.874053,7.271783,-0.90393364,-5.438341,-8.241028
20 | attn_0_2_outputs,2,-2.4153283,1.249425,6.9100757,-2.0006323,-8.70784
21 | attn_0_2_outputs,3,-6.3554764,-4.6737704,0.021502413,7.1687856,-1.3787818
22 | attn_0_2_outputs,4,-8.521622,-7.321229,-3.046059,3.4124966,10.726074
23 | attn_0_2_outputs,</s>,-1.8578418,-2.6656096,-4.077629,0.20807894,6.7849145
24 | attn_0_2_outputs,<pad>,0.092216395,-0.7707133,0.16602008,0.17321993,0.7220149
25 | attn_0_2_outputs,<s>,-9.541172,-5.166671,-0.87453365,4.0391154,10.82512
26 | attn_0_3_outputs,0,16.349298,4.3640847,-3.3730683,-9.79156,-14.669813
27 | attn_0_3_outputs,1,4.139899,8.263065,-0.17833088,-6.8866434,-11.679859
28 | attn_0_3_outputs,2,-3.926204,-0.5807447,5.967192,-1.0567814,-5.560351
29 | attn_0_3_outputs,3,-9.774161,-6.630783,-0.6135139,7.905448,2.3533192
30 | attn_0_3_outputs,4,-12.426147,-9.234469,-3.574072,4.1587915,14.047162
31 | attn_0_3_outputs,</s>,-4.6650424,-3.240851,-1.7017452,2.632841,6.9260855
32 | attn_0_3_outputs,<pad>,-0.7803622,-0.16891778,0.42231166,-0.11307383,-0.3041337
33 | attn_0_3_outputs,<s>,0.2458686,-1.8348123,-1.403487,0.39682958,1.4361156
34 | attn_1_0_outputs,0,3.4249582,0.4830792,-1.4701751,-2.006143,-1.3161409
35 | attn_1_0_outputs,1,1.6426649,1.2126992,-0.69759977,-1.437486,-1.3773936
36 | attn_1_0_outputs,2,0.23226358,0.64345044,1.8823681,-1.4889188,-2.2751608
37 | attn_1_0_outputs,3,-2.1088793,-1.0431234,-0.4745396,2.6021097,0.0047708363
38 | attn_1_0_outputs,4,-0.4867569,-1.6423402,-1.241547,0.5894667,2.0828123
39 | attn_1_0_outputs,</s>,-2.0974174,-2.5358207,-0.59028447,2.7884414,3.1243649
40 | attn_1_0_outputs,<pad>,-1.4079151,0.42247453,0.79169756,1.1603345,0.31134152
41 | attn_1_0_outputs,<s>,-0.8227935,-2.4848738,-0.8918966,1.4806012,2.132991
42 | attn_1_1_outputs,0,13.7478485,2.4539003,-4.1379776,-8.794688,-13.024131
43 | attn_1_1_outputs,1,3.659074,7.3829765,-0.4129496,-6.3457537,-8.960195
44 | attn_1_1_outputs,2,-3.8716147,-0.74163604,4.9987493,0.16779378,-3.0191104
45 | attn_1_1_outputs,3,-9.338297,-5.943794,-0.07820616,6.4413576,3.6664085
46 | attn_1_1_outputs,4,-11.835959,-7.989941,-2.6649897,4.1501045,14.6659155
47 | attn_1_1_outputs,</s>,-1.7373832,-1.0473574,-0.07293691,0.88357306,3.197291
48 | attn_1_1_outputs,<pad>,-0.26156133,0.73238564,1.010969,-0.0779104,-0.3386485
49 | attn_1_1_outputs,<s>,-5.260586,-0.10670824,1.731505,2.0308712,-6.089484
50 | attn_1_2_outputs,0,11.480239,2.3556077,-2.5624404,-6.0736723,-8.765706
51 | attn_1_2_outputs,1,2.2337794,6.481738,-0.6245802,-4.446057,-7.1153054
52 | attn_1_2_outputs,2,-4.0426908,-1.4142859,4.2344704,0.5490717,-2.002491
53 | attn_1_2_outputs,3,-7.267531,-4.5090466,0.7685141,7.347656,-0.10958998
54 | attn_1_2_outputs,4,-9.947915,-7.479793,-2.613032,2.9863348,10.500013
55 | attn_1_2_outputs,</s>,0.53556585,0.11474143,-0.014953171,-1.402238,0.556593
56 | attn_1_2_outputs,<pad>,-0.25244236,0.8308479,0.7589563,-0.32094172,-1.2273515
57 | attn_1_2_outputs,<s>,-0.53400123,-2.8056118,-0.78918266,0.04643116,2.6973917
58 | attn_1_3_outputs,0,1.9439827,0.1603147,-0.28384447,-0.5827607,-1.5783767
59 | attn_1_3_outputs,1,0.16566427,0.85972536,0.93071663,0.55593437,-0.40285176
60 | attn_1_3_outputs,2,0.22454293,0.38536653,1.0067779,0.041088387,-1.6240689
61 | attn_1_3_outputs,3,-0.34689322,-0.12858017,0.11249515,0.83399034,-0.7050651
62 | attn_1_3_outputs,4,-0.8917678,-1.1450619,-0.80545264,-0.21876524,0.8630245
63 | attn_1_3_outputs,</s>,-0.030100795,-0.6054426,-0.36994955,-0.07122007,0.38008985
64 | attn_1_3_outputs,<pad>,-0.4243932,0.38434273,0.29122502,-0.122149594,0.6827977
65 | attn_1_3_outputs,<s>,-2.117496,-1.7259998,-0.7347762,0.4593257,2.2775197
66 | mlp_0_0_outputs,0,4.0188355,1.367233,-1.417323,-6.483142,-13.193972
67 | mlp_0_0_outputs,1,-3.9137676,-0.15376055,0.49536678,0.9998601,0.047882058
68 | mlp_0_0_outputs,2,-15.88143,-5.672158,0.99274987,3.5297909,4.746169
69 | mlp_0_0_outputs,3,0.70922613,0.12939063,0.303092,0.19436331,-0.15312691
70 | mlp_0_0_outputs,4,-19.113987,-14.055936,-5.2655005,1.665476,7.5324826
71 | mlp_0_0_outputs,5,0.0049047554,0.3532272,-0.61323506,-0.15334794,0.20983703
72 | mlp_0_0_outputs,6,-2.2349446,0.48716977,1.2473469,1.0492531,-0.959091
73 | mlp_0_0_outputs,7,2.0180795,2.697081,1.4667916,-2.6459293,-8.920683
74 | mlp_0_1_outputs,0,4.9354362,2.2356412,-0.82518,-6.697162,-13.193914
75 | mlp_0_1_outputs,1,-14.731216,-4.927309,-0.5305642,1.9409475,2.7460032
76 | mlp_0_1_outputs,2,0.23375379,0.11552315,0.44065633,-0.086906254,0.16232851
77 | mlp_0_1_outputs,3,-0.01565274,-0.505149,0.24257246,-0.20083542,0.091414995
78 | mlp_0_1_outputs,4,1.0003215,0.1339593,0.18748964,0.14529389,0.15201308
79 | mlp_0_1_outputs,5,-0.07818346,0.11651017,0.8559554,0.40377986,0.6004811
80 | mlp_0_1_outputs,6,-5.059556,-0.8851335,1.2983979,1.9971292,1.2971436
81 | mlp_0_1_outputs,7,2.3331697,2.7773235,1.5271827,-2.8835008,-7.8288684
82 | mlp_0_2_outputs,0,-5.812679,-8.050895,-2.0316663,1.604939,3.8419476
83 | mlp_0_2_outputs,1,-15.904148,-4.8865576,0.7665841,2.2370648,4.8236346
84 | mlp_0_2_outputs,2,0.023708992,0.23602854,-0.21941465,-0.06791145,0.14681834
85 | mlp_0_2_outputs,3,-5.155198,1.1561449,2.7422621,0.84532666,-1.3750209
86 | mlp_0_2_outputs,4,-1.6516947,0.14862537,0.27884433,0.09279299,0.0030477077
87 | mlp_0_2_outputs,5,-4.058745,1.8142332,1.0810403,-0.14894587,-0.9095843
88 | mlp_0_2_outputs,6,5.3210797,2.4041133,-1.6120065,-5.1028633,-10.729489
89 | mlp_0_2_outputs,7,-0.31382832,0.9711427,-0.14119859,-0.1960116,-0.1285425
90 | mlp_0_3_outputs,0,-15.051607,-2.8604221,2.0543356,2.499284,2.205402
91 | mlp_0_3_outputs,1,5.6800714,1.356257,-2.8838096,-6.541643,-12.12329
92 | mlp_0_3_outputs,2,0.8320392,0.06768919,0.18168978,-0.13751434,0.71345246
93 | mlp_0_3_outputs,3,-2.744153,-1.4329554,1.1388501,1.9455607,-0.36828822
94 | mlp_0_3_outputs,4,-0.47749937,1.1444833,1.4878665,0.511925,-2.8988562
95 | mlp_0_3_outputs,5,4.4155946,4.092108,1.8837255,-5.068154,-17.417696
96 | mlp_0_3_outputs,6,0.91394186,1.7178278,0.29244098,-1.9470179,-0.79406583
97 | mlp_0_3_outputs,7,-27.025166,-14.000694,-3.8636396,4.7704196,9.954426
98 | mlp_1_0_outputs,0,-8.994571,-4.491032,-1.4437572,2.639763,4.943209
99 | mlp_1_0_outputs,1,5.9503913,2.6140954,-2.56839,-10.30842,-19.072475
100 | mlp_1_0_outputs,2,-3.2524624,1.9415048,2.075256,0.1147645,-2.324298
101 | mlp_1_0_outputs,3,-3.3153234,-6.1573997,-1.7476376,1.969032,3.2494743
102 | mlp_1_0_outputs,4,-7.131646,-0.9621037,3.1830635,1.1788143,-1.1787232
103 | mlp_1_0_outputs,5,2.321128,1.8879267,0.55639076,-1.853329,-5.405838
104 | mlp_1_0_outputs,6,-0.3154466,1.7536464,1.4805571,-0.4060518,-2.380789
105 | mlp_1_0_outputs,7,-4.6405716,0.54778636,2.5990744,1.2474383,-1.03555
106 | mlp_1_1_outputs,0,0.2574987,1.1060137,0.34981915,-0.5553312,-1.3027234
107 | mlp_1_1_outputs,1,0.096980885,0.19448379,0.38600856,0.21713899,0.31291375
108 | mlp_1_1_outputs,2,0.20168386,0.24377005,-0.16930194,0.0846435,0.27269673
109 | mlp_1_1_outputs,3,-0.15246853,0.13554697,-0.14340618,-0.24426788,-0.1754791
110 | mlp_1_1_outputs,4,-1.1793035,-0.2165895,0.05511073,0.106202744,1.238185
111 | mlp_1_1_outputs,5,1.5681087,0.95414907,0.6013005,-1.0025665,-1.9541403
112 | mlp_1_1_outputs,6,-0.92157024,-0.2364761,0.2521039,1.1757741,1.7977172
113 | mlp_1_1_outputs,7,-0.8971796,0.37849936,-0.026569659,0.031324264,-0.46514994
114 | mlp_1_2_outputs,0,-1.3478705,-2.4680243,0.66674495,1.3479987,0.5765663
115 | mlp_1_2_outputs,1,0.37738907,1.0555747,1.5234038,-2.4196775,-6.760583
116 | mlp_1_2_outputs,2,2.4935262,4.591032,-4.003316,-5.4412313,-2.304786
117 | mlp_1_2_outputs,3,4.2700715,2.2302587,-0.795981,-2.4704978,-4.8020735
118 | mlp_1_2_outputs,4,-3.9012556,-1.7321999,3.3738852,2.948957,2.2248867
119 | mlp_1_2_outputs,5,-2.0524378,-5.6856995,-0.8716151,2.056629,3.9021623
120 | mlp_1_2_outputs,6,2.1821785,1.7694793,-0.9968027,-1.056598,-4.7362437
121 | mlp_1_2_outputs,7,-2.2508187,0.65054375,1.1322451,0.36528292,-1.2405411
122 | mlp_1_3_outputs,0,0.3038056,-0.40180534,0.27598795,0.07215004,-0.9269151
123 | mlp_1_3_outputs,1,-3.1313102,-2.0797613,-0.5406382,1.3487104,2.601697
124 | mlp_1_3_outputs,2,1.0282918,0.54115653,0.54766685,0.22385533,-0.9048508
125 | mlp_1_3_outputs,3,-0.046904188,0.8035955,1.1800617,0.2309441,-0.29504803
126 | mlp_1_3_outputs,4,0.7052468,-0.32243052,-0.2011492,-0.5544647,0.47795638
127 | mlp_1_3_outputs,5,1.8298012,0.72522086,0.55117387,0.12736094,-1.4325687
128 | mlp_1_3_outputs,6,0.09611695,-0.158596,-0.02878351,-0.4176433,-0.78748065
129 | mlp_1_3_outputs,7,1.6000229,0.9329063,1.1538035,0.3359634,-1.0799118
130 | ones,_,-0.055436447,-0.1179431,-1.0715027,-0.15707256,-0.06473778
131 | positions,0,-0.45721093,-0.87155324,-0.35231563,0.61707205,0.11435398
132 | positions,1,7.035323,2.7762,-4.4782853,-11.52158,-7.547409
133 | positions,2,1.7988344,4.5708203,1.894728,-4.0963306,-10.078739
134 | positions,3,-2.066558,3.0727959,3.1645296,-0.5606865,-4.997998
135 | positions,4,-8.358648,-0.029626308,1.6459913,4.489941,-0.8345926
136 | positions,5,-11.328461,-4.9711714,-0.53534645,5.983784,3.4289176
137 | positions,6,-17.799313,-13.832921,-4.907452,3.116108,9.586147
138 | positions,7,-0.025707932,-0.42475572,-0.059313126,0.047880296,0.63343513
139 | tokens,0,9.776978,1.7060088,-2.9519324,-5.849588,-7.211308
140 | tokens,1,2.336422,4.711168,-0.5835804,-3.646622,-5.077569
141 | tokens,2,-2.223295,-0.92388165,3.334594,-0.44335306,-2.2490618
142 | tokens,3,-4.5560665,-3.9469757,-0.53029966,5.159678,2.8704758
143 | tokens,4,-6.5824995,-5.69455,-2.5579076,2.4740694,9.455751
144 | tokens,</s>,0.6628496,-0.32530716,0.34280553,-0.2948921,-0.063519835
145 | tokens,<pad>,-0.010760286,-0.05715638,0.4203831,-0.12519221,0.6524513
146 | tokens,<s>,-0.5623373,0.11551832,0.005275513,0.6801121,-0.23678207
147 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | black
2 | datasets
3 | einops
4 | gensim
5 | matplotlib
6 | nltk
7 | numpy
8 | pandas
9 | scikit-learn
10 | seaborn
11 | torch
12 | tqdm
13 |
--------------------------------------------------------------------------------
/scripts/classification.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | N_EPOCHS=10
4 | N_LAYERS=2
5 | N_VARS_CAT=4
6 | N_HEADS_CAT=4
7 | N_CAT_MLPS=1
8 | SEED=0
9 |
10 | python src/run.py \
11 | --dataset "${DATASET}" \
12 | --vocab_size 10000 \
13 | --min_length 1 \
14 | --max_length 64 \
15 | --n_epochs "${N_EPOCHS}" \
16 | --batch_size 128 \
17 | --lr "5e-2" \
18 | --gumbel_samples 1 \
19 | --sample_fn "gumbel_soft" \
20 | --tau_init 3.0 \
21 | --tau_end 0.01 \
22 | --tau_schedule "geomspace" \
23 | --n_vars_cat "${N_VARS_CAT}" \
24 | --d_var 64 \
25 | --n_vars_num 1 \
26 | --n_layers "${N_LAYERS}" \
27 | --n_heads_cat "${N_HEADS_CAT}" \
28 | --n_heads_num 0 \
29 | --n_cat_mlps "${N_CAT_MLPS}" \
30 | --n_num_mlps 0 \
31 | --attention_type "cat" \
32 | --rel_pos_bias "fixed" \
33 | --dropout 0.0 \
34 | --mlp_vars_in 2 \
35 | --d_mlp 64 \
36 | --count_only \
37 | --selector_width 0 \
38 | --do_lower 0 \
39 | --replace_numbers 0 \
40 | --glove_embeddings "data/glove.840B.300d.txt" \
41 | --do_glove 1 \
42 | --pool_outputs 1 \
43 | --seed "${SEED}" \
44 | --save \
45 | --save_code \
46 | --output_dir "output/classification/${DATASET}/transformer_program/nvars${N_VARS_CAT}nheads${N_HEADS_CAT}nlayers${N_LAYERS}nmlps${N_CAT_MLPS}/epochs${N_EPOCHS}/s${SEED}";
47 |
--------------------------------------------------------------------------------
/scripts/classification_short.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | N_EPOCHS=10
4 | N_LAYERS=2
5 | N_VARS_CAT=4
6 | N_HEADS_CAT=4
7 | N_CAT_MLPS=1
8 | SEED=0
9 |
10 | python src/run.py \
11 | --dataset "${DATASET}" \
12 | --vocab_size 5000 \
13 | --min_length 1 \
14 | --max_length 25 \
15 | --n_epochs "${N_EPOCHS}" \
16 | --batch_size 128 \
17 | --lr "5e-2" \
18 | --gumbel_samples 1 \
19 | --sample_fn "gumbel_soft" \
20 | --tau_init 3.0 \
21 | --tau_end 0.01 \
22 | --tau_schedule "geomspace" \
23 | --n_vars_cat "${N_VARS_CAT}" \
24 | --d_var 25 \
25 | --n_vars_num 1 \
26 | --n_layers "${N_LAYERS}" \
27 | --n_heads_cat "${N_HEADS_CAT}" \
28 | --n_heads_num 0 \
29 | --n_cat_mlps "${N_CAT_MLPS}" \
30 | --n_num_mlps 0 \
31 | --attention_type "cat" \
32 | --rel_pos_bias "fixed" \
33 | --dropout 0.0 \
34 | --mlp_vars_in 2 \
35 | --d_mlp 64 \
36 | --count_only \
37 | --selector_width 0 \
38 | --do_lower 0 \
39 | --replace_numbers 0 \
40 | --glove_embeddings "data/glove.840B.300d.txt" \
41 | --do_glove 1 \
42 | --pool_outputs 1 \
43 | --seed "${SEED}" \
44 | --save \
45 | --save_code \
46 | --output_dir "output/classification_short/${DATASET}/transformer_program/nvars${N_VARS_CAT}nheads${N_HEADS_CAT}nlayers${N_LAYERS}nmlps${N_CAT_MLPS}/epochs${N_EPOCHS}/s${SEED}";
47 |
--------------------------------------------------------------------------------
/scripts/classification_standard.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | D_MODEL=64
4 | N_LAYERS=2
5 | N_HEADS=4
6 | SEED=0
7 |
8 | python src/run.py \
9 | --dataset "${DATASET}" \
10 | --standard \
11 | --vocab_size 10000 \
12 | --min_length 1 \
13 | --max_length 64 \
14 | --n_epochs 100 \
15 | --batch_size 128 \
16 | --lr "5e-3" \
17 | --d_model "${D_MODEL}" \
18 | --n_heads "${N_HEADS}" \
19 | --n_layers "${N_LAYERS}" \
20 | --d_mlp 64 \
21 | --dropout 0.5 \
22 | --max_grad_norm 5.0 \
23 | --do_lower 0 \
24 | --replace_numbers 0 \
25 | --glove_embeddings "data/glove.840B.300d.txt" \
26 | --do_glove "${DO_GLOVE}" \
27 | --pool_outputs 1 \
28 | --seed "${SEED}" \
29 | --output_dir "output/classification/${DATASET}/standard_transformer/dmodel${D_MODEL}nheads${N_HEADS}nlayers${N_LAYERS}/s${SEED}";
30 |
--------------------------------------------------------------------------------
/scripts/conll.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | N_EPOCHS=50
4 | N_LAYERS=2
5 | N_VARS_CAT=4
6 | N_HEADS_CAT=8
7 | N_CAT_MLPS=1
8 | SEED=0
9 |
10 | python src/run.py \
11 | --dataset "conll_ner" \
12 | --vocab_size 10000 \
13 | --min_length 1 \
14 | --max_length 32 \
15 | --n_epochs "${N_EPOCHS}" \
16 | --batch_size 32 \
17 | --lr "1e-2" \
18 | --gumbel_samples 1 \
19 | --sample_fn "gumbel_soft" \
20 | --tau_init 3.0 \
21 | --tau_end 0.01 \
22 | --tau_schedule "geomspace" \
23 | --n_vars_cat "${N_VARS_CAT}" \
24 | --d_var 32 \
25 | --n_vars_num 1 \
26 | --n_layers "${N_LAYERS}" \
27 | --n_heads_cat "${N_HEADS_CAT}" \
28 | --n_heads_num 0 \
29 | --n_cat_mlps "${N_CAT_MLPS}" \
30 | --n_num_mlps 0 \
31 | --attention_type "cat" \
32 | --rel_pos_bias "fixed" \
33 | --dropout 0.5 \
34 | --mlp_vars_in 2 \
35 | --d_mlp 64 \
36 | --count_only \
37 | --selector_width 0 \
38 | --do_lower 0 \
39 | --replace_numbers 1 \
40 | --glove_embeddings "data/glove.840B.300d.txt" \
41 | --do_glove 1 \
42 | --pool_outputs 0 \
43 | --seed "${SEED}" \
44 | --save \
45 | --save_code \
46 | --output_dir "output/conll_ner/transformer_program/nvars${N_VARS_CAT}nheads${N_HEADS_CAT}nlayers${N_LAYERS}nmlps${N_CAT_MLPS}/epochs${N_EPOCHS}/s${SEED}";
47 |
--------------------------------------------------------------------------------
/scripts/conll_standard.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | D_MODEL=128
4 | N_LAYERS=3
5 | N_HEADS=4
6 | SEED=0
7 |
8 | python src/run.py \
9 | --dataset "conll_ner" \
10 | --standard \
11 | --vocab_size 10000 \
12 | --min_length 1 \
13 | --max_length 32 \
14 | --n_epochs 100 \
15 | --batch_size 32 \
16 | --lr "5e-3" \
17 | --d_model "${D_MODEL}" \
18 | --n_heads "${N_HEADS}" \
19 | --n_layers "${N_LAYERS}" \
20 | --d_mlp 64 \
21 | --dropout 0.5 \
22 | --max_grad_norm 5.0 \
23 | --do_lower 0 \
24 | --replace_numbers 1 \
25 | --glove_embeddings "data/glove.840B.300d.txt" \
26 | --do_glove "${DO_GLOVE}" \
27 | --pool_outputs 0 \
28 | --seed "${SEED}" \
29 | --output_dir "output/conll_ner/standard_transformer/dmodel${D_MODEL}nheads${N_HEADS}nlayers${N_LAYERS}/s${SEED}";
30 |
--------------------------------------------------------------------------------
/scripts/dyck.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | MAX_LENGTH=16
4 | N_LAYERS=3
5 | N_HEADS_CAT=4
6 | N_HEADS_NUM=4
7 | N_CAT_MLPS=2
8 | N_NUM_MLPS=2
9 | SEED=0
10 |
11 | python src/run.py \
12 | --dataset "${DATASET}" \
13 | --dataset_size 20000 \
14 | --min_length "${MAX_LENGTH}" \
15 | --max_length "${MAX_LENGTH}" \
16 | --n_epochs 250 \
17 | --batch_size 512 \
18 | --lr "5e-2" \
19 | --gumbel_samples 1 \
20 | --sample_fn "gumbel_soft" \
21 | --tau_init 3.0 \
22 | --tau_end 0.01 \
23 | --tau_schedule "geomspace" \
24 | --n_vars_cat 1 \
25 | --d_var "${MAX_LENGTH}" \
26 | --n_vars_num 1 \
27 | --n_layers "${N_LAYERS}" \
28 | --n_heads_cat "${N_HEADS_CAT}" \
29 | --n_heads_num "${N_HEADS_NUM}" \
30 | --n_cat_mlps "${N_CAT_MLPS}" \
31 | --n_num_mlps "${N_NUM_MLPS}" \
32 | --attention_type "cat" \
33 | --rel_pos_bias "fixed" \
34 | --one_hot_embed \
35 | --dropout 0.0 \
36 | --mlp_vars_in 2 \
37 | --d_mlp 64 \
38 | --count_only \
39 | --selector_width 0 \
40 | --seed "${SEED}" \
41 | --unique 1 \
42 | --save \
43 | --save_code \
44 | --output_dir "output/rasp/${DATASET}/maxlen${MAX_LENGTH}/transformer_program/headsc${N_HEADS_CAT}headsn${N_HEADS_NUM}nlayers${N_LAYERS}cmlps${N_CAT_MLPS}nmlps${N_NUM_MLPS}/s${SEED}";
45 |
--------------------------------------------------------------------------------
/scripts/induction.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | VOCAB_SIZE=10
4 | MIN_LENGTH=9
5 | MAX_LENGTH=9
6 | SEED=6
7 |
8 | echo "SEED=${SEED}";
9 |
10 | python src/run.py \
11 | --dataset "induction" \
12 | --vocab_size "${VOCAB_SIZE}" \
13 | --dataset_size 20000 \
14 | --min_length "${MIN_LENGTH}" \
15 | --max_length "${MAX_LENGTH}" \
16 | --n_epochs 500 \
17 | --batch_size 512 \
18 | --lr "5e-2" \
19 | --gumbel_samples 1 \
20 | --sample_fn "gumbel_soft" \
21 | --tau_init 3.0 \
22 | --tau_end 0.01 \
23 | --tau_schedule "geomspace" \
24 | --n_vars_cat 1 \
25 | --n_vars_num 1 \
26 | --n_layers 2 \
27 | --n_heads_cat 1 \
28 | --n_heads_num 0 \
29 | --n_cat_mlps 0 \
30 | --n_num_mlps 0 \
31 | --attention_type "cat" \
32 | --rel_pos_bias "fixed" \
33 | --one_hot_embed \
34 | --count_only \
35 | --selector_width 0 \
36 | --seed "${SEED}" \
37 | --unique 1 \
38 | --unembed_mask 0 \
39 | --autoregressive \
40 | --save \
41 | --save_code \
42 | --output_dir "output/induction/s${SEED}";
43 |
--------------------------------------------------------------------------------
/scripts/rasp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | VOCAB_SIZE=8
4 | MAX_LENGTH=8
5 | N_LAYERS=3
6 | N_HEADS_CAT=4
7 | N_HEADS_NUM=4
8 | N_CAT_MLPS=2
9 | N_NUM_MLPS=2
10 | SEED=0
11 |
12 | python src/run.py \
13 | --dataset "${DATASET}" \
14 | --vocab_size "${VOCAB_SIZE}" \
15 | --dataset_size 20000 \
16 | --min_length 1 \
17 | --max_length "${MAX_LENGTH}" \
18 | --n_epochs 250 \
19 | --batch_size 512 \
20 | --lr "5e-2" \
21 | --gumbel_samples 1 \
22 | --sample_fn "gumbel_soft" \
23 | --tau_init 3.0 \
24 | --tau_end 0.01 \
25 | --tau_schedule "geomspace" \
26 | --n_vars_cat 1 \
27 | --d_var "${MAX_LENGTH}" \
28 | --n_vars_num 1 \
29 | --n_layers "${N_LAYERS}" \
30 | --n_heads_cat "${N_HEADS_CAT}" \
31 | --n_heads_num "${N_HEADS_NUM}" \
32 | --n_cat_mlps "${N_CAT_MLPS}" \
33 | --n_num_mlps "${N_NUM_MLPS}" \
34 | --attention_type "cat" \
35 | --rel_pos_bias "fixed" \
36 | --one_hot_embed \
37 | --dropout 0.0 \
38 | --mlp_vars_in 2 \
39 | --d_mlp 64 \
40 | --count_only \
41 | --selector_width 0 \
42 | --seed "${SEED}" \
43 | --unique 1 \
44 | --save \
45 | --save_code \
46 | --output_dir "output/rasp/${DATASET}/vocab${VOCAB_SIZE}maxlen${MAX_LENGTH}/transformer_program/headsc${N_HEADS_CAT}headsn${N_HEADS_NUM}nlayers${N_LAYERS}cmlps${N_CAT_MLPS}nmlps${N_NUM_MLPS}/s${SEED}";
47 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(name="src", version="0.1", packages=find_packages())
4 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-nlp/TransformerPrograms/59970de542e14406d8f3d01bbd30097276959865/src/__init__.py
--------------------------------------------------------------------------------
/src/decompile.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | from copy import deepcopy
4 | from functools import partial
5 | import itertools
6 | import json
7 | import math
8 | from pathlib import Path
9 | import random
10 |
11 | import einops
12 | import numpy as np
13 | import pandas as pd
14 | import re
15 | import torch
16 | from torch import nn
17 | from torch.nn import functional as F
18 | from torch.utils.data import DataLoader
19 | from tqdm import tqdm
20 |
21 | from src.models.programs import TransformerProgramModel, argmax
22 |
23 | from src.run import set_seed, get_sample_fn
24 | from src.utils import code_utils, data_utils, logging
25 |
26 |
27 | logger = logging.get_logger(__name__)
28 |
29 |
30 | class DotDict(dict):
31 | __getattr__ = dict.__getitem__
32 | __setattr__ = dict.__setitem__
33 | __delattr__ = dict.__delitem__
34 |
35 |
36 | def parse_args():
37 | parser = argparse.ArgumentParser()
38 | parser.add_argument("--path", type=str, default="scratch")
39 | parser.add_argument("--output_dir", type=str, default="scratch")
40 | return parser.parse_args()
41 |
42 |
43 | def load_model(path):
44 | args_fn = Path(path) / "args.json"
45 | if not args_fn.exists():
46 | raise ValueError(f"missing {args_fn}")
47 | model_fn = Path(path) / "model.pt"
48 | if not model_fn.exists():
49 | raise ValueError(f"missing {model_fn}")
50 |
51 | with open(args_fn, "r") as f:
52 | args = DotDict(json.load(f))
53 |
54 | logger.info(f"loading {args.dataset}")
55 | set_seed(args.seed)
56 | (
57 | train,
58 | test,
59 | val,
60 | idx_w,
61 | w_idx,
62 | idx_t,
63 | t_idx,
64 | X_train,
65 | Y_train,
66 | X_test,
67 | Y_test,
68 | X_val,
69 | Y_val,
70 | ) = data_utils.get_dataset(
71 | name=args.dataset,
72 | vocab_size=args.vocab_size,
73 | dataset_size=args.dataset_size,
74 | min_length=args.min_length,
75 | max_length=args.max_length,
76 | seed=args.seed,
77 | do_lower=args.do_lower,
78 | replace_numbers=args.replace_numbers,
79 | get_val=True,
80 | unique=args.unique,
81 | )
82 |
83 | logger.info(f"initializing model from {args_fn}")
84 | if args.d_var is None:
85 | d = max(len(idx_w), X_train.shape[-1])
86 | else:
87 | d = args.d_var
88 | init_emb = None
89 | if args.glove_embeddings and args.do_glove:
90 | emb = data_utils.get_glove_embeddings(
91 | idx_w, args.glove_embeddings, dim=args.n_vars_cat * d
92 | )
93 | init_emb = torch.tensor(emb, dtype=torch.float32).T
94 | unembed_mask = None
95 | if args.unembed_mask:
96 | unembed_mask = np.array([t in ("<unk>", "<pad>") for t in idx_t])
97 | set_seed(args.seed)
98 | model = TransformerProgramModel(
99 | d_vocab=len(idx_w),
100 | d_vocab_out=len(idx_t),
101 | n_vars_cat=args.n_vars_cat,
102 | n_vars_num=args.n_vars_num,
103 | d_var=d,
104 | n_heads_cat=args.n_heads_cat,
105 | n_heads_num=args.n_heads_num,
106 | d_mlp=args.d_mlp,
107 | n_cat_mlps=args.n_cat_mlps,
108 | n_num_mlps=args.n_num_mlps,
109 | mlp_vars_in=args.mlp_vars_in,
110 | n_layers=args.n_layers,
111 | n_ctx=X_train.shape[1],
112 | sample_fn=get_sample_fn(args.sample_fn),
113 | init_emb=init_emb,
114 | attention_type=args.attention_type,
115 | rel_pos_bias=args.rel_pos_bias,
116 | unembed_mask=unembed_mask,
117 | pool_outputs=args.pool_outputs,
118 | one_hot_embed=args.one_hot_embed,
119 | count_only=args.count_only,
120 | selector_width=args.selector_width,
121 | )
122 |
123 | logger.info(f"restoring weights from {model_fn}")
124 | model.load_state_dict(
125 | torch.load(str(model_fn), map_location=torch.device("cpu"))
126 | )
127 |
128 | model.set_temp(args.tau_end, argmax)
129 | return model, args, idx_w, idx_t, X_val
130 |
131 |
132 | def model_to_code(model, args, idx_w, idx_t, X=None, output_dir=""):
133 | if output_dir:
134 | Path(output_dir).mkdir(exist_ok=True, parents=True)
135 | x = None
136 | if X is not None:
137 | x = idx_w[X[0]]
138 | x = x[x != "<pad>"].tolist()
139 | m = code_utils.model_to_code(
140 | model=model,
141 | idx_w=idx_w,
142 | idx_t=idx_t,
143 | embed_csv=not args.one_hot_embed,
144 | unembed_csv=True,
145 | one_hot=args.one_hot_embed,
146 | autoregressive=args.autoregressive,
147 | var_types=True,
148 | output_dir=output_dir,
149 | name=args.dataset,
150 | example=x,
151 | save=bool(output_dir),
152 | )
153 | return m
154 |
155 |
156 | if __name__ == "__main__":
157 | args = parse_args()
158 | model, args, idx_w, idx_t, X_val = load_model(args.path)
159 | model_to_code(
160 | model=model,
161 | args=args,
162 | idx_w=idx_w,
163 | idx_t=idx_t,
164 | X=X_val,
165 | output_dir=args.output_dir,
166 | )
167 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-nlp/TransformerPrograms/59970de542e14406d8f3d01bbd30097276959865/src/models/__init__.py
--------------------------------------------------------------------------------
/src/models/transformers.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 |
8 |
9 | # Define network architecture. Adapted from
10 | # https://colab.research.google.com/drive/19gn2tavBGDqOYHLatjSROhABBD5O_JyZ
11 | class Embed(nn.Module):
12 | def __init__(self, d_vocab, d_model, init_emb=None):
13 | super().__init__()
14 | if init_emb is not None:
15 | self.W_E = nn.Parameter(init_emb)
16 | else:
17 | self.W_E = nn.Parameter(
18 | torch.randn(d_model, d_vocab) / np.sqrt(d_model)
19 | )
20 |
21 | def forward(self, x):
22 | return torch.einsum("dbp -> bpd", self.W_E[:, x])
23 |
24 |
25 | class Unembed(nn.Module):
26 | def __init__(self, d_vocab, d_model, mask=None):
27 | super().__init__()
28 | self.W_U = nn.Parameter(
29 | torch.randn(d_model, d_vocab) / np.sqrt(d_vocab)
30 | )
31 | if mask is not None:
32 | self.register_buffer(
33 | "mask", torch.tensor(mask).view(1, 1, len(mask))
34 | )
35 | else:
36 | self.mask = None
37 |
38 | def forward(self, x):
39 | logits = x @ self.W_U
40 | if self.mask is not None:
41 | logits = logits.masked_fill(self.mask, -1e30)
42 | return logits
43 |
44 |
45 | class PosEmbed(nn.Module):
46 | def __init__(self, max_ctx, d_model):
47 | super().__init__()
48 | self.W_pos = nn.Parameter(
49 | torch.randn(max_ctx, d_model) / np.sqrt(d_model)
50 | )
51 |
52 | def forward(self, x):
53 | return x + self.W_pos[: x.shape[-2]]
54 |
55 |
56 | class Transformer(nn.Module):
57 | def __init__(
58 | self,
59 | d_vocab,
60 | n_layers=2,
61 | d_model=64,
62 | d_mlp=64,
63 | n_heads=4,
64 | n_ctx=32,
65 | act_type="ReLU",
66 | d_vocab_out=None,
67 | dropout=0.0,
68 | init_emb=None,
69 | unembed_mask=None,
70 | pool_outputs=False,
71 | **kwargs,
72 | ):
73 | super().__init__()
74 | self.embed = Embed(d_vocab, d_model, init_emb=init_emb)
75 | self.pos_embed = PosEmbed(n_ctx, d_model)
76 | layer = nn.TransformerEncoderLayer(
77 | d_model=d_model,
78 | nhead=n_heads,
79 | batch_first=True,
80 | dim_feedforward=d_mlp,
81 | dropout=dropout,
82 | norm_first=True,
83 | )
84 | self.encoder = nn.TransformerEncoder(layer, n_layers)
85 | self.unembed = Unembed(
86 | d_vocab_out or d_vocab, d_model, mask=unembed_mask
87 | )
88 | self.n_heads = n_heads
89 | self.dropout = nn.Dropout(dropout)
90 | self.pool_outputs = pool_outputs
91 |
92 | def forward(self, x, mask=None):
93 | x = self.embed(x)
94 | x = x * np.sqrt(x.shape[-1])
95 | x = self.pos_embed(x)
96 | x = self.dropout(x)
97 | m = mask
98 | mask = torch.cat([mask for _ in range(self.n_heads)], 0)
99 | mask[:, :, 0] = True
100 | x = self.encoder(x, mask=~mask)
101 | if self.pool_outputs:
102 | x = torch.cat(
103 | [
104 | x.masked_fill(~m[:, 0].unsqueeze(-1), 0).mean(
105 | 1, keepdims=True
106 | ),
107 | x[:, 1:],
108 | ],
109 | 1,
110 | )
111 | x = self.unembed(x)
112 | return x
113 |
114 | @property
115 | def device(self):
116 | return self.embed.W_E.device
117 |
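A note on `Embed.forward`: indexing `W_E[:, x]` with an integer tensor of token ids gathers one embedding column per token, giving shape `(d_model, batch, pos)`, and the einsum `"dbp -> bpd"` is just a transpose to the usual `(batch, pos, d_model)` layout. A minimal NumPy sketch of that gather-and-transpose (sizes are made up for illustration):

```python
import numpy as np

# Hypothetical sizes: vocabulary of 5 tokens, model width 3.
d_model, d_vocab = 3, 5
rng = np.random.default_rng(0)
W_E = rng.standard_normal((d_model, d_vocab))  # same layout as Embed.W_E

# A batch of 2 sequences of length 4 (integer token ids).
x = np.array([[0, 1, 2, 3],
              [4, 3, 2, 1]])

# W_E[:, x] gathers one column per token id -> shape (d_model, batch, pos);
# the einsum "dbp -> bpd" is a transpose to (batch, pos, d_model).
emb = np.einsum("dbp->bpd", W_E[:, x])

assert emb.shape == (2, 4, d_model)
# Each output vector is the embedding column for the corresponding token id.
assert np.allclose(emb[1, 0], W_E[:, 4])
```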
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-nlp/TransformerPrograms/59970de542e14406d8f3d01bbd30097276959865/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/logging.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import json
3 | import logging
4 | from pathlib import Path
5 | import sys
6 |
7 |
8 | def get_handler(fn=None, stream=None):
9 | formatter = logging.Formatter(
10 | fmt="%(levelname).1s %(asctime)s %(name)s:%(lineno)d: %(message)s",
11 | datefmt="%Y-%m-%dT%H:%M:%S",
12 | )
13 | h = (
14 | logging.FileHandler(fn, mode="w")
15 | if fn
16 | else logging.StreamHandler(stream=stream)
17 | )
18 | h.setLevel(logging.INFO)
19 | h.setFormatter(formatter)
20 | return h
21 |
22 |
23 | def get_resume_name(output_dir, s="output.log"):
24 | output_logs = list(Path(output_dir).glob(f"{s}*"))
25 | if len(output_logs) == 0:
26 | return Path(output_dir) / s
27 | idx = 1
28 | for p in output_logs:
29 | if str(p)[-1].isdigit():
30 | p_idx = int(str(p).split(".")[-1])
31 | idx = max(idx, p_idx + 1)
32 | return Path(output_dir) / f"{s}.{idx}"
33 |
34 |
35 | def initialize(output_dir, resume=False):
36 | fn = (
37 | get_resume_name(output_dir)
38 | if resume
39 | else Path(output_dir) / "output.log"
40 | )
41 | if not fn.parent.exists():
42 | fn.parent.mkdir(parents=True)
43 | handlers = (
44 | get_handler(stream=sys.stdout),
45 | get_handler(fn=fn),
46 | )
47 | logging.basicConfig(handlers=handlers, force=True, level=logging.INFO)
48 |
49 |
50 | def get_logger(name):
51 | logging.basicConfig(handlers=[get_handler(stream=None)], level=logging.INFO)
52 | return logging.getLogger(name)
53 |
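The naming scheme in `get_resume_name` reuses `output.log` on a fresh run and otherwise appends the next free integer suffix (`output.log.1`, `output.log.2`, ...). A self-contained sketch of that suffix logic (re-implemented here so it runs without the package; `next_log_name` is a hypothetical stand-in for `get_resume_name`):

```python
import tempfile
from pathlib import Path

def next_log_name(output_dir, s="output.log"):
    # Mirrors get_resume_name: use the base name if no log exists yet,
    # otherwise pick one past the highest existing integer suffix.
    logs = list(Path(output_dir).glob(f"{s}*"))
    if not logs:
        return Path(output_dir) / s
    idx = 1
    for p in logs:
        if str(p)[-1].isdigit():
            idx = max(idx, int(str(p).split(".")[-1]) + 1)
    return Path(output_dir) / f"{s}.{idx}"

with tempfile.TemporaryDirectory() as d:
    assert next_log_name(d).name == "output.log"
    (Path(d) / "output.log").touch()
    assert next_log_name(d).name == "output.log.1"
    (Path(d) / "output.log.1").touch()
    assert next_log_name(d).name == "output.log.2"
```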
--------------------------------------------------------------------------------
/src/utils/metric_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from seqeval.metrics import precision_score
3 | from seqeval.metrics import recall_score
4 | from seqeval.metrics import f1_score
5 | from seqeval.scheme import IOB2
6 |
7 |
8 | def eval_induction(X, Y, Y_pred, y_pad_idx=0):
9 | rows = []
10 | for y in Y:
11 | m = []
12 | seen = set()
13 | for t in y:
14 | if t not in seen:
15 | m.append(False)
16 | seen.add(t)
17 | else:
18 | m.append(True)
19 | rows.append(m)
20 | mask = np.stack(rows, 0) & (Y != y_pad_idx)
21 | return (Y == Y_pred)[mask].mean()
22 |
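`eval_induction` scores predictions only at positions where a token has already appeared earlier in the sequence (and is not padding), since those are exactly the positions an induction mechanism can predict. A self-contained NumPy sketch of that repeated-token mask on toy data (the arrays below are invented for illustration):

```python
import numpy as np

PAD = 0  # padding token id, excluded from scoring

Y      = np.array([[5, 6, 5, 6, PAD],
                   [7, 8, 9, 7, 8]])
Y_pred = np.array([[5, 6, 5, 9, PAD],
                   [7, 8, 9, 7, 8]])

# Build a boolean mask that is True only at repeat occurrences of a token.
rows = []
for y in Y:
    seen, m = set(), []
    for t in y:
        m.append(t in seen)
        seen.add(t)
    rows.append(m)

mask = np.stack(rows, 0) & (Y != PAD)
acc = (Y == Y_pred)[mask].mean()  # accuracy over repeated, non-pad positions

assert abs(acc - 0.75) < 1e-9  # 3 of the 4 masked positions are correct
```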
23 |
24 | def __f1_score(y_true, y_pred, o_idx):
25 | precision, recall, f1 = 0, 0, 0
26 | m_p = y_pred != o_idx
27 | if m_p.sum() > 0:
28 | precision = (y_true[m_p] == y_pred[m_p]).mean()
29 | m_r = y_true != o_idx
30 | if m_r.sum() > 0:
31 | recall = (y_true[m_r] == y_pred[m_r]).mean()
32 | if precision + recall > 0:
33 | f1 = 2 * precision * recall / (precision + recall)
34 | return {"precision": precision, "recall": recall, "f1": f1}
35 |
36 |
37 | def conll_score(y_true, y_pred):
38 | return {
39 | "precision": precision_score(
40 | y_true, y_pred, mode="strict", scheme=IOB2
41 | ),
42 | "recall": recall_score(y_true, y_pred, mode="strict", scheme=IOB2),
43 | "f1": f1_score(y_true, y_pred, mode="strict", scheme=IOB2),
44 | }
45 |
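`__f1_score` above computes token-level precision (accuracy where a non-"O" tag was predicted) and recall (accuracy where the gold tag is non-"O"), unlike `conll_score`, which uses seqeval's strict span-level IOB2 matching. A self-contained NumPy sketch of the token-level computation on toy labels (tag indices are made up; `0` plays the role of `o_idx`):

```python
import numpy as np

O = 0  # index of the "O" (outside) tag; other integers are entity tags

y_true = np.array([O, 1, 1, O, 2, O])
y_pred = np.array([O, 1, O, O, 2, 2])

# Precision: accuracy over positions where a non-O tag was predicted.
m_p = y_pred != O
precision = (y_true[m_p] == y_pred[m_p]).mean()

# Recall: accuracy over positions that truly carry a non-O tag.
m_r = y_true != O
recall = (y_true[m_r] == y_pred[m_r]).mean()

f1 = 2 * precision * recall / (precision + recall)

assert abs(precision - 2 / 3) < 1e-9  # 2 of 3 predicted tags correct
assert abs(recall - 2 / 3) < 1e-9     # 2 of 3 gold tags recovered
assert abs(f1 - 2 / 3) < 1e-9
```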
--------------------------------------------------------------------------------