├── .gitignore ├── README.md ├── figures └── methodology.png ├── programs ├── conll │ ├── conll_ner.py │ ├── conll_ner_embeddings.csv │ ├── conll_ner_embeddings.py │ ├── conll_ner_weights.csv │ └── conll_notebook.ipynb ├── rasp │ ├── double_hist │ │ ├── double_hist.py │ │ └── double_hist_weights.csv │ ├── dyck1 │ │ ├── dyck1.py │ │ └── dyck1_weights.csv │ ├── dyck2 │ │ ├── dyck2.py │ │ └── dyck2_weights.csv │ ├── hist │ │ ├── hist.py │ │ └── hist_weights.csv │ ├── most_freq │ │ ├── most_freq.py │ │ └── most_freq_weights.csv │ ├── reverse │ │ ├── reverse.py │ │ └── reverse_weights.csv │ └── sort │ │ ├── sort.py │ │ └── sort_weights.csv ├── rasp_categorical_only │ ├── double_hist │ │ ├── double_hist.py │ │ └── double_hist_weights.csv │ ├── dyck1 │ │ ├── dyck1.py │ │ └── dyck1_weights.csv │ ├── dyck2 │ │ ├── dyck2.py │ │ └── dyck2_weights.csv │ ├── hist │ │ ├── hist.py │ │ └── hist_weights.csv │ ├── most_freq │ │ ├── most_freq.py │ │ └── most_freq_weights.csv │ ├── reverse │ │ ├── reverse.py │ │ └── reverse_weights.csv │ └── sort │ │ ├── sort.py │ │ └── sort_weights.csv └── trec │ ├── trec.py │ ├── trec_embeddings.csv │ ├── trec_embeddings.py │ └── trec_weights.csv ├── requirements.txt ├── scripts ├── classification.sh ├── classification_short.sh ├── classification_standard.sh ├── conll.sh ├── conll_standard.sh ├── dyck.sh ├── induction.sh └── rasp.sh ├── setup.py └── src ├── __init__.py ├── decompile.py ├── models ├── __init__.py ├── programs.py └── transformers.py ├── run.py └── utils ├── __init__.py ├── analysis_utils.py ├── code_utils.py ├── data_utils.py ├── logging.py └── metric_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | _site 2 | .sass-cache 3 | Gemfile.lock 4 | *.gem 5 | .jekyll-cache 6 | .jekyll-cache 7 | *~ 8 | *__pycache__* 9 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # Learning Transformer Programs 2 | 3 | This repository contains the code for our paper, [Learning Transformer Programs](https://arxiv.org/abs/2306.01128). 4 | The code can be used to train a modified Transformer to solve a task, and then convert it into a human-readable Python program. 5 | The repository also includes a number of [example programs](#Example-programs), which we learned for the tasks described in the paper. 6 | Please see [our paper](https://arxiv.org/abs/2306.01128) for more details. 7 | 8 | 9 | 10 | 11 | ## Quick links 12 | * [Setup](#Setup) 13 | * [Learning Programs](#Learning-programs) 14 | * [Training](#Training) 15 | * [Converting to code](#Converting-to-code) 16 | * [Example Programs](#Example-programs) 17 | * [Questions?](#Questions) 18 | * [Citation](#Citation) 19 | 20 | ## Setup 21 | 22 | Install [PyTorch](https://pytorch.org/get-started/locally/) and then install the remaining requirements: `pip install -r requirements.txt`. 23 | This code was tested using Python 3.8 and PyTorch version 1.13.1. 24 | 25 | In our experiments on NLP tasks, we initialize word embeddings using 300-dimensional pre-trained GloVe embeddings, which can be downloaded [here](https://github.com/stanfordnlp/GloVe) (Common Crawl, cased): 26 | ```bash 27 | mkdir data 28 | wget https://huggingface.co/stanfordnlp/glove/resolve/main/glove.840B.300d.zip -P data/ 29 | unzip data/glove.840B.300d.zip -d data/ 30 | ``` 31 | 32 | ## Learning Programs 33 | 34 | ### Training 35 | 36 | The code to learn a Transformer Program can be found in [src/run.py](src/run.py). 
37 | For example, the following command will train a Transformer Program for the `sort` task, using two layers, four categorical attention heads per-layer, and one-hot input embeddings: 38 | ```bash 39 | python src/run.py \ 40 | --dataset "sort" \ 41 | --vocab_size 8 \ 42 | --dataset_size 10000 \ 43 | --min_length 1 \ 44 | --max_length 8 \ 45 | --n_epochs 250 \ 46 | --batch_size 512 \ 47 | --lr "5e-2" \ 48 | --n_layers 2 \ 49 | --n_heads_cat 4 \ 50 | --n_heads_num 0 \ 51 | --n_cat_mlps 1 \ 52 | --n_num_mlps 0 \ 53 | --one_hot_embed \ 54 | --count_only \ 55 | --seed 0 \ 56 | --save \ 57 | --save_code \ 58 | --output_dir "output/sort"; 59 | ``` 60 | This command will train a Transformer Program for the CoNLL 2003 named-entity recognition task, learning input embeddings composed of four 32-dimensional categorical variables: 61 | ```bash 62 | python src/run.py \ 63 | --dataset "conll_ner" \ 64 | --vocab_size 10000 \ 65 | --min_length 1 \ 66 | --max_length 32 \ 67 | --n_epochs 50 \ 68 | --batch_size 32 \ 69 | --lr "5e-2" \ 70 | --n_vars_cat 4 \ 71 | --d_var 32 \ 72 | --n_layers 2 \ 73 | --n_heads_cat 4 \ 74 | --n_heads_num 0 \ 75 | --n_cat_mlps 1 \ 76 | --n_num_mlps 0 \ 77 | --mlp_vars_in 2 \ 78 | --count_only \ 79 | --seed 0 \ 80 | --replace_numbers 1 \ 81 | --glove_embeddings "data/glove.840B.300d.txt" \ 82 | --do_glove 1 \ 83 | --save \ 84 | --save_code \ 85 | --output_dir "output/conll"; 86 | ``` 87 | Please see [src/run.py](src/run.py) for all of the possible arguments. 88 | The training data will either be generated (for the RASP tasks) or downloaded from [Hugging Face Datasets](https://huggingface.co/datasets); see [src/utils/data_utils.py](src/utils/data_utils.py) for the supported datasets. 89 | The [scripts](scripts/) directory contains scripts for training Transformer Programs and standard Transformers with the experiment settings used in the paper. 
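To illustrate the format of the generated RASP data, here is a minimal self-contained sketch of one training example for the `sort` task, matching the `--vocab_size`, `--min_length`, and `--max_length` flags used above. This is a hypothetical helper for illustration only, not the repository's actual generator (which lives in [src/utils/data_utils.py](src/utils/data_utils.py)); the `<s>`/`</s>` marker names are assumptions:

```python
import random

# Hypothetical sketch (not the repository's actual data generator):
# build one (source, target) pair for the `sort` task. The input is a
# random token sequence wrapped in assumed <s>/</s> markers, and the
# target is the same tokens in sorted order.
def make_sort_example(vocab_size=8, min_length=1, max_length=8, rng=random):
    length = rng.randint(min_length, max_length)
    tokens = [str(rng.randrange(vocab_size)) for _ in range(length)]
    source = ["<s>"] + tokens + ["</s>"]
    target = ["<s>"] + sorted(tokens) + ["</s>"]
    return source, target
```

A learned program for this task should map each `source` sequence to its corresponding `target`.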
90 | 91 | ### Converting to code 92 | 93 | Run the training script with the `--save_code` flag to convert the model to a Python program at the end of training. 94 | To convert a model that has already been trained, use `src/decompile.py`. 95 | For example, 96 | ```bash 97 | python src/decompile.py --path output/sort/ --output_dir programs/sort/ 98 | ``` 99 | `output/sort/` should be the output directory of a training run. 100 | 101 | # Example Programs 102 | 103 | The [programs](programs/) directory contains example programs for small-scale versions of all of the [RASP tasks](https://arxiv.org/abs/2106.06981), as well as named-entity recognition. 104 | Each program defines a function called `run` that takes a sequence of tokens as input and returns a list of predicted labels. 105 | For example: 106 | ```pycon 107 | >>> from programs.rasp.sort import sort 108 | >>> sort.run(["<s>", "3", "1", "4", "2", "4", "0", "</s>"]) 109 | ['<s>', '0', '1', '2', '3', '4', '4', '</s>'] 110 | ``` 111 | [programs/rasp](programs/rasp) contains the best-performing programs for each task, using both categorical and numerical attention heads. 112 | [programs/rasp_categorical_only](programs/rasp_categorical_only) contains the best-performing programs using only categorical variables. 113 | [programs/conll](programs/conll) contains a program for named-entity recognition. 114 | 115 | # Questions? 116 | 117 | If you have any questions about the code or paper, please email Dan (dfriedman@cs.princeton.edu) or open an issue. 
118 | 119 | # Citation 120 | 121 | ```bibtex 122 | @inproceedings{ 123 | friedman2023learning, 124 | title={Learning Transformer Programs}, 125 | author={Dan Friedman and Alexander Wettig and Danqi Chen}, 126 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems}, 127 | year={2023}, 128 | url={https://openreview.net/forum?id=Pe9WxkN8Ff} 129 | } 130 | ``` 131 | -------------------------------------------------------------------------------- /figures/methodology.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/TransformerPrograms/59970de542e14406d8f3d01bbd30097276959865/figures/methodology.png -------------------------------------------------------------------------------- /programs/rasp/double_hist/double_hist_weights.csv: -------------------------------------------------------------------------------- 1 | feature,value,1,2,3,4,5,6 2 | attn_0_0_outputs,0,1.238853,1.243148,0.68027335,-4.868072,-3.0746865,-8.173675 3 | attn_0_0_outputs,1,0.31741944,0.6675796,0.59875584,-0.74731636,-1.1577901,-1.0011843 4 | attn_0_0_outputs,2,-1.0877419,1.9231943,0.3949728,-1.539528,-1.1613191,-2.8614204 5 | attn_0_0_outputs,3,-0.11632717,0.41220337,0.60158277,-0.28780693,-0.35998842,-0.59837425 6 | attn_0_0_outputs,4,-0.19129646,0.30991477,0.4910216,-0.08420599,0.5274653,-0.19311686 7 | attn_0_0_outputs,5,-0.40474758,-0.0207731,0.56750906,-0.62817794,1.2629008,0.54661673 8 | attn_0_0_outputs,6,-1.8074787,-0.30840394,2.049226,0.3926784,-1.7478626,5.0987434 9 | attn_0_0_outputs,7,-1.731275,-0.15045756,0.36238095,0.8407283,4.7304835,-2.225291 10 | attn_0_1_outputs,0,0.5155269,-0.23179123,0.9619248,0.8720239,-4.370495,-2.8444674 11 | attn_0_1_outputs,1,-0.032038573,0.15968679,-0.11178788,-0.96792173,-1.093138,-1.0469398 12 | attn_0_1_outputs,2,-0.49667308,0.20647019,0.9112648,-1.1316853,0.0978555,-0.8900808 13 | 
attn_0_1_outputs,3,-0.19489852,0.30834976,0.29149553,-0.5899691,-1.2079358,-0.4888331 14 | attn_0_1_outputs,4,-0.7092352,0.40705568,1.2359505,-0.14838673,-1.4845357,-0.8633547 15 | attn_0_1_outputs,5,-1.7096843,0.57668656,1.8166345,-1.3838611,2.8332915,1.320257 16 | attn_0_1_outputs,6,-1.8643975,-0.14910048,1.1025906,-0.046011463,1.486126,3.767073 17 | attn_0_1_outputs,7,-1.6024725,-0.0244187,1.0061564,-0.68429786,3.118972,-2.3170335 18 | attn_0_2_outputs,0,1.2139428,1.788951,-0.6449291,-4.5301723,-3.1953726,-8.304671 19 | attn_0_2_outputs,1,0.2632638,0.2750803,-0.058388937,-0.96432394,-1.2596529,-0.82894176 20 | attn_0_2_outputs,2,-0.06439551,0.10690294,-0.5768559,-1.3450785,-1.2679642,-1.0744632 21 | attn_0_2_outputs,3,0.071711585,0.59252673,0.50191075,-0.10376727,-0.6866497,-0.6754526 22 | attn_0_2_outputs,4,0.10581133,0.1553707,0.038040042,0.071251996,0.29624116,0.06374497 23 | attn_0_2_outputs,5,-1.8224907,0.8820499,0.9754138,-1.1711323,2.2756393,-0.019901173 24 | attn_0_2_outputs,6,-1.8428569,0.5488585,2.5900345,1.2428666,-0.56255573,5.247683 25 | attn_0_2_outputs,7,-2.4727259,0.5789163,1.2076744,1.0896817,5.294417,-2.1652772 26 | attn_0_3_outputs,0,-0.5514277,1.2344804,0.32450405,-1.3071667,-3.758229,-5.2093534 27 | attn_0_3_outputs,1,-0.5949326,0.017211447,-0.21026301,-1.1978165,-1.1930654,-1.3590801 28 | attn_0_3_outputs,2,-0.9301655,0.37480146,1.6829457,-0.75800765,-0.65382415,-0.49428207 29 | attn_0_3_outputs,3,-0.27631456,0.27953884,0.43736613,-0.38235372,-0.5264454,-0.63245666 30 | attn_0_3_outputs,4,-0.93398,0.5528436,1.5676297,0.10561387,-0.64765644,-0.28653306 31 | attn_0_3_outputs,5,-1.6023785,0.24490198,1.8999624,-1.2599331,0.92110413,0.8147831 32 | attn_0_3_outputs,6,-1.3627234,0.13494563,1.5294584,0.40429628,1.4354382,4.1456857 33 | attn_0_3_outputs,7,-1.6638067,0.09259084,1.1263107,-0.0054527083,2.0681398,-2.6293533 34 | attn_1_0_outputs,0,-0.9760603,0.9746706,-0.4910691,-0.28577244,-0.55609435,-0.629483 35 | 
attn_1_0_outputs,1,0.027657213,0.029981421,-0.24031596,-0.596706,-0.17725739,-0.60143006 36 | attn_1_0_outputs,2,-1.6395041,0.9376783,-0.6826759,-0.34162158,-0.6060739,-0.91161937 37 | attn_1_0_outputs,3,-1.273714,0.48423272,1.1538424,-0.8714459,-1.4124024,-1.35856 38 | attn_1_0_outputs,4,-0.8404264,0.92396635,1.5313677,-1.2234914,-3.3400934,-4.8241587 39 | attn_1_0_outputs,5,-1.8222586,0.9126125,0.56483024,-0.6948536,2.166879,1.4243178 40 | attn_1_0_outputs,6,-0.52920604,-1.4131529,-0.1339102,2.2840686,-0.79413146,4.49058 41 | attn_1_0_outputs,7,-0.37916443,-0.98024577,0.21309751,0.6570406,3.8979933,-0.9583814 42 | attn_1_1_outputs,0,1.1848046,-0.78384286,4.3952165,-4.754357,-2.5311391,-4.304711 43 | attn_1_1_outputs,1,-0.13832416,0.7882959,-0.02595803,-1.5303969,-1.0957807,-1.3929857 44 | attn_1_1_outputs,2,-0.50316894,-0.22863217,1.5112969,-0.29604048,-0.580187,-1.3095005 45 | attn_1_1_outputs,3,-0.79329056,0.42733788,0.14356261,0.6288875,-0.19792889,-0.26626664 46 | attn_1_1_outputs,4,-1.9033594,1.7780949,1.592835,-1.8940227,-2.7704701,-0.5928191 47 | attn_1_1_outputs,5,-1.9111953,0.37122542,1.2167804,-1.6839112,2.1015897,-0.7357016 48 | attn_1_1_outputs,6,-2.3154323,-0.28756234,2.4062943,2.108501,0.11289929,4.5701685 49 | attn_1_1_outputs,7,2.4386678,-1.1878088,-2.2552593,-0.20960121,7.1543226,-1.7744904 50 | attn_1_2_outputs,0,-1.1251806,1.1169194,-0.91066366,-0.7407269,-0.81425554,-0.48066366 51 | attn_1_2_outputs,1,0.45716456,0.41061336,-0.46052435,-0.34725094,-1.1503232,-0.87109935 52 | attn_1_2_outputs,2,0.013095508,0.45874515,-0.12555777,-1.040384,-1.5763316,-0.78179336 53 | attn_1_2_outputs,3,-1.289953,0.9008724,1.2994568,-0.13545798,-0.44344234,-0.953467 54 | attn_1_2_outputs,4,-1.1722677,1.0613216,1.844418,-0.7773634,-2.8786588,-4.2749085 55 | attn_1_2_outputs,5,-0.1720321,0.119764805,0.8184181,-1.9128888,1.5016356,0.541971 56 | attn_1_2_outputs,6,-1.4663309,0.59112865,0.12126205,0.59590137,-1.0372763,3.382513 57 | 
attn_1_2_outputs,7,-0.9167632,-0.07292843,0.44805673,0.12057245,4.2690535,-0.84026176 58 | attn_1_3_outputs,0,-0.18947232,0.93384516,2.9093647,-4.247921,-3.8325047,-1.8869388 59 | attn_1_3_outputs,1,-0.15879959,0.20900528,0.10304354,-1.2823522,-1.4563886,-1.1255486 60 | attn_1_3_outputs,2,-2.7822475,1.1226128,3.61258,-1.2871381,3.0924113,-1.6648036 61 | attn_1_3_outputs,3,-0.18244827,0.10408772,0.1082736,-0.33289492,-1.1335009,-1.1429838 62 | attn_1_3_outputs,4,-1.3203812,-0.40582213,0.47339472,2.714348,-2.8118975,-2.4057264 63 | attn_1_3_outputs,5,-2.4856753,0.31100777,1.4052066,-0.45595795,1.055359,1.5162545 64 | attn_1_3_outputs,6,-1.5411837,-0.10812548,1.6073796,0.0880546,0.5234726,4.230394 65 | attn_1_3_outputs,7,-2.2810664,0.3371949,1.277099,1.0733384,5.143846,-1.5911686 66 | attn_2_0_outputs,0,-0.7608168,0.73692936,2.2746222,0.09226601,-3.9912004,-4.357786 67 | attn_2_0_outputs,1,0.73612237,0.4157502,0.40716407,-1.2919782,-0.75344867,-1.2904367 68 | attn_2_0_outputs,2,-0.43750256,-0.31478426,1.2700193,-1.3610334,-0.40046877,-1.7192286 69 | attn_2_0_outputs,3,-0.14183228,0.46041623,0.7322802,-0.7410107,-0.67563444,-0.98685783 70 | attn_2_0_outputs,4,-1.6794555,0.2764951,2.2788801,-1.535832,-1.5119663,-1.3937936 71 | attn_2_0_outputs,5,-1.122806,0.35427764,1.8529143,-0.08506123,0.5083381,0.70440453 72 | attn_2_0_outputs,6,-1.3130459,0.24721187,1.5720521,0.6670818,1.2480316,5.339757 73 | attn_2_0_outputs,7,-1.756592,-0.3179971,0.9707079,0.3391206,4.742656,-0.77619433 74 | attn_2_1_outputs,0,5.1843696,-0.14824952,-7.7279587,-0.18898705,-7.4058304,-5.8876233 75 | attn_2_1_outputs,1,0.07970304,0.33187732,0.47960332,-1.1841625,-1.2469735,-0.8724654 76 | attn_2_1_outputs,2,0.1316852,0.07690099,0.9332321,-1.4709224,-0.95776975,-0.77643555 77 | attn_2_1_outputs,3,0.17687394,-0.45715567,0.9835989,0.15518622,-0.95674294,-0.68985116 78 | attn_2_1_outputs,4,-0.37534592,-1.1890558,0.37393364,0.37020233,-0.8960392,-0.8327007 79 | 
attn_2_1_outputs,5,0.011616593,-0.3313448,0.8376248,-1.2172726,1.578385,1.6046035 80 | attn_2_1_outputs,6,-3.7970808,2.6305835,-0.4534069,0.8028219,-3.9948761,8.23428 81 | attn_2_1_outputs,7,-6.8515778,-0.8602876,8.407036,-1.1475514,11.212976,-2.1321476 82 | attn_2_2_outputs,0,-1.3947915,0.82383966,0.19460057,0.4075744,-0.50466233,0.38375315 83 | attn_2_2_outputs,1,0.36330578,0.78628147,0.20042144,0.08738863,0.14224096,-0.08429496 84 | attn_2_2_outputs,2,-1.4779562,1.399697,0.86861706,0.1598239,0.75133777,0.3097304 85 | attn_2_2_outputs,3,-1.2577441,0.014175095,1.141915,1.2425467,-0.23655094,-0.7029587 86 | attn_2_2_outputs,4,-0.98365426,0.30592975,1.760364,-3.4134092,-4.0350323,-6.1376657 87 | attn_2_2_outputs,5,-1.975907,1.060878,0.7518102,-1.0094594,1.415349,0.56762004 88 | attn_2_2_outputs,6,-0.56342816,-1.4276974,0.8616133,1.3711227,-0.664221,2.6826735 89 | attn_2_2_outputs,7,-1.6278553,-0.31998882,1.4476215,-0.50721836,2.9143424,-0.6652316 90 | attn_2_3_outputs,0,-0.10320366,1.2831827,1.0159588,-4.5068374,-3.8469913,-3.0415845 91 | attn_2_3_outputs,1,0.9415238,0.6988879,-0.58729637,-1.3218888,-1.1747098,-1.3971877 92 | attn_2_3_outputs,2,-1.1604211,0.40635908,2.708663,-3.707528,2.3209622,-3.0766304 93 | attn_2_3_outputs,3,-0.8133805,-0.66232896,0.031035835,-0.24328968,-0.7707727,-0.6527309 94 | attn_2_3_outputs,4,-1.1196601,-0.2392621,1.2020575,1.8583574,-3.9654202,-3.1436868 95 | attn_2_3_outputs,5,-1.373232,0.14063334,0.9891212,-0.32424918,1.4090708,0.9559994 96 | attn_2_3_outputs,6,-1.7961397,-0.22981627,1.3094918,0.7286974,0.2208126,4.3420715 97 | attn_2_3_outputs,7,-1.9815708,-0.68765175,1.3021115,0.6383351,4.8859954,-1.5050861 98 | mlp_0_0_outputs,0,-0.80427796,-0.5239094,0.64130455,-0.8501387,0.88480604,0.8062301 99 | mlp_0_0_outputs,1,-0.703996,-0.06041357,1.1169733,-0.21612476,-0.56217194,-0.062908545 100 | mlp_0_0_outputs,2,-0.5470064,0.0647984,0.9857732,-0.7440401,1.3193768,0.47770888 101 | 
mlp_0_0_outputs,3,-1.2310531,-0.66088575,1.5901628,0.08961851,0.6930918,-1.6064136 102 | mlp_0_0_outputs,4,0.6362144,0.93965715,-1.6470672,-0.77374464,-4.137099,-3.0756845 103 | mlp_0_0_outputs,5,-0.69935584,-0.5942387,0.61135715,-0.33549237,0.9449682,-0.04219789 104 | mlp_0_0_outputs,6,-0.81338984,-0.004528439,1.7929608,-1.1758713,-1.5474731,-0.6041481 105 | mlp_0_0_outputs,7,-1.3381019,-0.25485164,0.9574129,-0.5505601,0.84711146,-1.2687894 106 | mlp_1_0_outputs,0,-0.35261238,-0.488968,0.317209,-0.353129,-0.29976666,-0.22690473 107 | mlp_1_0_outputs,1,10.887252,-2.810294,-6.033218,-4.187805,-3.888413,-2.4883578 108 | mlp_1_0_outputs,2,-0.18418525,-0.3537582,0.4354651,-0.05020326,0.2935266,0.3063467 109 | mlp_1_0_outputs,3,-4.5915027,-0.70305747,7.38149,0.14982294,-0.30977443,-0.60969657 110 | mlp_1_0_outputs,4,-0.2931938,-0.42586467,0.6560199,-0.23105785,-0.093104005,0.11018199 111 | mlp_1_0_outputs,5,0.026697708,-0.14337309,0.4799277,-0.05001711,0.34555975,0.35481232 112 | mlp_1_0_outputs,6,0.028377075,-0.24432565,0.54357,-0.14310212,0.11082643,-0.18106437 113 | mlp_1_0_outputs,7,7.758,0.051592894,-1.7972865,-0.4951801,-0.56271195,-1.6264058 114 | mlp_2_0_outputs,0,-0.07149683,0.17804867,0.027318498,-0.47190544,0.2526821,0.25335145 115 | mlp_2_0_outputs,1,-0.029524894,-0.124501094,-0.280803,-0.6324781,-0.11064398,0.20567767 116 | mlp_2_0_outputs,2,-0.36568528,-0.18049608,1.1864383,2.4141302,-1.3685066,-2.4289098 117 | mlp_2_0_outputs,3,-0.08572693,0.095609844,1.3656018,2.4667256,-0.62999135,-0.9260636 118 | mlp_2_0_outputs,4,-0.25557655,-0.08844128,-0.011158431,-0.67957497,-0.13888906,0.057696126 119 | mlp_2_0_outputs,5,-0.57896894,-0.4964491,1.1442068,0.18332906,-0.40583086,-1.8021346 120 | mlp_2_0_outputs,6,-0.14598659,-0.25447085,-0.09391215,-1.0940522,-0.26641744,0.085448615 121 | mlp_2_0_outputs,7,-1.6514058,0.5661072,2.6141307,-0.58939695,0.0069024065,-0.47796342 122 | num_attn_0_0_outputs,_,1.6166081,0.65979415,0.02641497,-3.4313304,-4.040161,-2.7888904 
123 | num_attn_0_1_outputs,_,-0.24869336,-0.6833789,0.9426215,0.3185697,1.193121,0.52059203 124 | num_attn_0_2_outputs,_,-2.9678483,1.6683826,1.7130572,-1.6549467,2.552361,-0.4446513 125 | num_attn_0_3_outputs,_,-2.0443754,0.47501704,0.6353057,0.45521334,-0.075898685,-0.39096996 126 | num_attn_1_0_outputs,_,1.4484639,-0.9826932,-2.5024445,-4.524883,9.575657,-1.2793733 127 | num_attn_1_1_outputs,_,1.5601239,0.16244218,-2.2383647,-2.335401,-0.31956533,1.361603 128 | num_attn_1_2_outputs,_,-0.75736904,0.11321503,-0.82287806,0.718299,0.71823144,1.0727578 129 | num_attn_1_3_outputs,_,2.7385015,1.4961662,-1.1917236,-1.0514781,-6.2250605,-9.062402 130 | num_attn_2_0_outputs,_,-15.748312,-2.3748958,2.378052,5.5488276,6.0594873,5.684393 131 | num_attn_2_1_outputs,_,0.85498565,0.80901504,-0.25188714,-0.95084816,-4.0991917,-3.6710646 132 | num_attn_2_2_outputs,_,2.5925405,0.43864623,-0.7539673,-1.6541822,-4.0720997,-8.631772 133 | num_attn_2_3_outputs,_,4.2104626,-0.2540275,-5.9844904,-1.7332941,-0.9908037,1.0950465 134 | num_mlp_0_0_outputs,0,-1.8669392,-0.2961063,0.98103577,1.4933444,2.198966,1.1375257 135 | num_mlp_0_0_outputs,1,-1.1520817,-1.1762809,0.09541475,0.8781277,1.2588459,0.8453051 136 | num_mlp_0_0_outputs,2,-0.9004866,-0.8049949,0.11854656,0.39003035,0.67636615,0.64462405 137 | num_mlp_0_0_outputs,3,8.561097,-1.1257874,-3.453622,0.69200045,1.3708919,-0.49515036 138 | num_mlp_0_0_outputs,4,-2.8854432,0.90047354,6.3888383,-5.94705,-7.024973,-3.1921215 139 | num_mlp_0_0_outputs,5,-0.6971181,-0.7816255,0.16676763,0.96413225,1.1872518,0.85966665 140 | num_mlp_0_0_outputs,6,-0.26123014,-0.059790544,0.3098392,0.6849877,0.6764854,0.49069953 141 | num_mlp_0_0_outputs,7,-0.53684187,0.012539488,-0.02403504,0.5809876,0.40013784,0.6112457 142 | num_mlp_1_0_outputs,0,-0.16077685,-1.2207857,0.23239522,1.5080422,-0.60546786,1.1818341 143 | num_mlp_1_0_outputs,1,-0.6600274,0.16164501,0.88917744,1.1062784,0.20580845,1.2902285 144 | 
num_mlp_1_0_outputs,2,-0.07932499,0.108578734,0.19467932,0.4703639,-0.11713662,0.91529095 145 | num_mlp_1_0_outputs,3,-0.57606816,-0.23965557,0.13980961,0.46052232,-0.27391964,0.45012388 146 | num_mlp_1_0_outputs,4,-0.7861332,-0.52138984,-0.09231762,0.64113474,0.025483737,1.0184015 147 | num_mlp_1_0_outputs,5,-0.713504,0.54386413,1.3562707,-2.5753596,-1.427283,-6.9030857 148 | num_mlp_1_0_outputs,6,-0.50802475,-1.2871871,0.537766,2.643965,-1.3457818,2.97455 149 | num_mlp_1_0_outputs,7,2.0027044,-0.7188986,-4.4494953,3.1136823,0.0020262096,0.36739042 150 | num_mlp_2_0_outputs,0,-0.13204575,-0.11307499,-0.3766279,-0.07198191,0.30587465,0.33915785 151 | num_mlp_2_0_outputs,1,-1.0364666,2.084065,-0.50771797,-3.1504862,-0.30198094,0.096702345 152 | num_mlp_2_0_outputs,2,-8.098701,4.3463387,-2.0458136,5.9026895,3.121162,3.0957198 153 | num_mlp_2_0_outputs,3,-3.605029,3.1863525,2.0602572,-2.160901,-7.234171,-0.6112598 154 | num_mlp_2_0_outputs,4,1.5672542,-0.33181366,-4.4696174,-0.4666259,4.466716,-4.639328 155 | num_mlp_2_0_outputs,5,1.5313199,-2.517952,1.0063609,0.32012826,-1.8395686,-0.6061035 156 | num_mlp_2_0_outputs,6,0.21038409,2.1400955,-0.90267885,-4.3364935,-1.1109636,0.27975422 157 | num_mlp_2_0_outputs,7,1.5466431,2.9550323,-2.1725159,-3.0149686,-2.3728058,-0.8228693 158 | ones,_,-0.63681555,0.6360205,1.2210147,-0.1677267,-1.8965924,-0.9514799 159 | positions,0,0.31322083,-0.18008727,-0.17022493,0.5562112,-0.3243463,-0.33771023 160 | positions,1,-0.164062,0.3669627,1.1244484,0.6322126,-1.4446061,-1.479529 161 | positions,2,-0.76111907,0.4353367,1.4015459,-0.34046993,-1.0691034,-0.5398579 162 | positions,3,-1.0637239,-0.030700902,1.1874292,-1.0612618,-0.95897645,-1.2181816 163 | positions,4,-0.79604656,0.28970852,1.2973528,-0.36234856,-0.75552547,0.12908255 164 | positions,5,-1.1553457,0.0699891,1.1624489,-0.49714485,1.4032562,0.39587685 165 | positions,6,-0.6900308,0.2557735,1.3243333,-0.9535796,-0.028765034,0.8872918 166 | 
positions,7,-0.88872874,0.07091375,1.42425,-0.60613817,-1.0351824,-2.0732577 167 | tokens,0,-0.8381742,0.11838423,1.1716859,-0.34598798,-1.2162316,-1.9089854 168 | tokens,1,-0.6593003,0.58185554,1.4920444,-0.22869995,-0.9598005,-1.7809644 169 | tokens,2,-1.02963,0.06541627,1.0584888,-0.6785367,-1.1368308,-2.131806 170 | tokens,3,-0.7562651,0.465944,1.4615022,-0.5278057,-0.5604278,-0.8920955 171 | tokens,4,-0.56911784,0.50735754,1.4724189,-0.20176388,-0.87996334,-0.4631814 172 | tokens,5,-0.85632926,0.32627252,1.2416458,-0.29448336,-0.47716865,0.26062497 173 | tokens,,0.04131505,-0.46409363,0.15720077,-0.5609889,0.16757916,0.34972325 174 | tokens,,0.10905723,-0.33392945,0.07082459,-0.31309077,0.42329115,-0.34247398 175 | -------------------------------------------------------------------------------- /programs/rasp/dyck1/dyck1_weights.csv: -------------------------------------------------------------------------------- 1 | feature,value,F,P,T 2 | attn_0_0_outputs,(,-2.6340582,3.296491,-0.44525737 3 | attn_0_0_outputs,),2.815133,-2.8173714,-0.077561244 4 | attn_0_0_outputs,,0.31516382,0.6258608,0.51490164 5 | attn_0_0_outputs,,-0.7577929,1.2806988,0.021066613 6 | attn_0_0_outputs,_,0.1783087,-0.022254534,-0.33423853 7 | attn_0_0_outputs,_,0.10090828,-0.2582165,0.1490271 8 | attn_0_0_outputs,_,-0.15222801,-0.89355934,0.5346948 9 | attn_0_0_outputs,_,0.29934147,-0.37996712,-0.39817178 10 | attn_0_0_outputs,_,-0.16398807,0.25323474,-0.46668038 11 | attn_0_0_outputs,_,0.028858155,0.07219034,-0.06156428 12 | attn_0_0_outputs,_,0.06055783,0.2945571,-0.19819283 13 | attn_0_0_outputs,_,-0.25926435,-0.16323452,0.4009143 14 | attn_0_0_outputs,_,0.52586144,0.20267111,0.8969524 15 | attn_0_0_outputs,_,0.22868425,0.1599105,-0.37888685 16 | attn_0_0_outputs,_,0.22064249,0.11402117,0.7742326 17 | attn_0_0_outputs,_,-0.24329506,0.19571169,-0.46124965 18 | attn_0_1_outputs,(,-0.36807892,1.1088231,0.27818483 19 | attn_0_1_outputs,),-0.41888747,-0.09075923,-0.527925 20 | 
attn_0_1_outputs,,-0.85293067,-0.5840806,-0.40153265 21 | attn_0_1_outputs,,2.7544026,0.20201652,-2.4185522 22 | attn_0_1_outputs,_,-0.3167692,-0.082709216,0.993802 23 | attn_0_1_outputs,_,0.80444956,-1.1216061,-0.76241434 24 | attn_0_1_outputs,_,0.90439874,-0.15584756,-0.8316175 25 | attn_0_1_outputs,_,0.6542889,0.48348022,0.032383945 26 | attn_0_1_outputs,_,0.5817215,-0.064832136,0.2867886 27 | attn_0_1_outputs,_,0.49341214,-0.33040112,-0.03182304 28 | attn_0_1_outputs,_,-0.083858415,0.5024774,0.1928215 29 | attn_0_1_outputs,_,0.3615586,0.16559708,-0.07771133 30 | attn_0_1_outputs,_,1.004799,-0.7025519,-0.46842074 31 | attn_0_1_outputs,_,0.43362275,0.28161952,-0.0104415575 32 | attn_0_1_outputs,_,-0.31160653,0.23411115,-0.57747537 33 | attn_0_1_outputs,_,-0.37716666,0.5208574,-0.76324815 34 | attn_0_2_outputs,(,-0.030860722,1.7733968,-1.1154492 35 | attn_0_2_outputs,),0.8185658,0.98675424,-0.65843534 36 | attn_0_2_outputs,,-0.18944798,1.2176713,-0.42110875 37 | attn_0_2_outputs,,-0.34205785,-1.2517091,-0.5264156 38 | attn_0_2_outputs,_,-0.66757905,0.14454332,-0.24297304 39 | attn_0_2_outputs,_,0.88671744,0.62714237,0.2174341 40 | attn_0_2_outputs,_,-0.100170776,0.9241575,-0.67759424 41 | attn_0_2_outputs,_,-0.258302,1.2594484,-0.8207846 42 | attn_0_2_outputs,_,0.7526842,0.39307418,-0.562563 43 | attn_0_2_outputs,_,-0.6082082,0.36667,0.17507246 44 | attn_0_2_outputs,_,0.16467999,0.74238044,-0.7653758 45 | attn_0_2_outputs,_,-0.4535027,0.41556233,0.13512231 46 | attn_0_2_outputs,_,0.28487095,0.34868884,-0.17336306 47 | attn_0_2_outputs,_,-0.30027777,0.45009154,-0.021399405 48 | attn_0_2_outputs,_,-0.115277715,0.6226942,0.47862878 49 | attn_0_2_outputs,_,0.3414086,-0.04924957,-0.71807915 50 | attn_0_3_outputs,(,-0.52774686,1.061042,-1.3247046 51 | attn_0_3_outputs,),0.53207314,-0.51257735,0.37381876 52 | attn_0_3_outputs,,1.2845441,-0.9961886,-1.6160716 53 | attn_0_3_outputs,,-0.07281321,0.24826683,0.22082955 54 | 
attn_0_3_outputs,_,-0.40482992,-0.7973792,-0.39291495 55 | attn_0_3_outputs,_,0.6773935,-0.1490348,-0.9378139 56 | attn_0_3_outputs,_,-2.6586642,0.3645929,1.424275 57 | attn_0_3_outputs,_,0.5714178,-0.15675014,-0.5435661 58 | attn_0_3_outputs,_,-0.3804901,1.0324043,0.24484947 59 | attn_0_3_outputs,_,0.34369984,0.27104208,-0.39457268 60 | attn_0_3_outputs,_,0.7677406,0.23606537,-0.36674848 61 | attn_0_3_outputs,_,0.72403216,0.94206196,-0.33496436 62 | attn_0_3_outputs,_,-0.606932,0.45185584,-0.014735105 63 | attn_0_3_outputs,_,-0.47034562,0.1342081,0.53295195 64 | attn_0_3_outputs,_,0.14294454,-0.304042,0.20307107 65 | attn_0_3_outputs,_,0.3230739,-0.42828202,-1.0465811 66 | attn_1_0_outputs,(,-0.6138929,1.1492342,-0.21226932 67 | attn_1_0_outputs,),1.1421751,0.88136363,0.98169017 68 | attn_1_0_outputs,,0.4028954,-0.85677516,-0.023720358 69 | attn_1_0_outputs,,1.0315775,-0.8483394,-1.2852764 70 | attn_1_0_outputs,_,0.15765458,-0.6344901,-0.5549551 71 | attn_1_0_outputs,_,-0.05087188,0.29270756,-0.5156284 72 | attn_1_0_outputs,_,1.1122603,-0.8028311,-0.06559855 73 | attn_1_0_outputs,_,-0.48832175,-0.15441939,-0.2644985 74 | attn_1_0_outputs,_,-0.37901145,-0.03420561,-0.5548248 75 | attn_1_0_outputs,_,0.05604345,-0.97195405,1.0271589 76 | attn_1_0_outputs,_,0.55539984,0.057354666,-0.4074509 77 | attn_1_0_outputs,_,-0.77038145,-0.09844211,0.07271357 78 | attn_1_0_outputs,_,0.52502674,-0.008494564,0.5783807 79 | attn_1_0_outputs,_,-0.7112233,0.87106544,0.24254535 80 | attn_1_0_outputs,_,0.29931223,-0.7983458,-0.2591785 81 | attn_1_0_outputs,_,1.1367513,-1.1649325,-0.25845855 82 | attn_1_1_outputs,(,-0.5425992,1.5051843,-0.28839168 83 | attn_1_1_outputs,),1.5574366,-1.0693368,0.14204933 84 | attn_1_1_outputs,,0.58851004,0.6008777,-0.07114553 85 | attn_1_1_outputs,,-0.32155347,-1.0976663,-1.4935772 86 | attn_1_1_outputs,_,0.17493103,-0.38030365,0.53236455 87 | attn_1_1_outputs,_,1.0686699,-0.120278545,-0.56144714 88 | attn_1_1_outputs,_,1.3084234,0.3557633,-0.14442345 89 
| attn_1_1_outputs,_,0.47586453,0.20887673,-0.15990862 90 | attn_1_1_outputs,_,0.13891101,0.68185794,-0.5525119 91 | attn_1_1_outputs,_,-0.6078569,1.2303034,-2.18228 92 | attn_1_1_outputs,_,0.91525626,-0.67010164,-0.21897195 93 | attn_1_1_outputs,_,1.4655004,-0.47915524,0.97498626 94 | attn_1_1_outputs,_,1.1796371,-1.6545213,-0.45778522 95 | attn_1_1_outputs,_,-0.7026438,0.5664589,1.0344129 96 | attn_1_1_outputs,_,0.5785904,0.5293408,0.37516353 97 | attn_1_1_outputs,_,2.3779967,-2.060698,-1.0075454 98 | attn_1_2_outputs,(,-0.056032557,0.6231306,0.24149144 99 | attn_1_2_outputs,),0.44365737,-0.98022085,-0.64319813 100 | attn_1_2_outputs,,0.08831471,-0.29584908,-0.15343662 101 | attn_1_2_outputs,,-1.3659412,3.5988708,-0.66621685 102 | attn_1_2_outputs,_,-0.1274261,0.34521273,0.027675765 103 | attn_1_2_outputs,_,-0.14180441,1.1854141,-0.8122686 104 | attn_1_2_outputs,_,0.4875018,0.45869988,0.1374755 105 | attn_1_2_outputs,_,-0.32931647,0.87830615,-0.49895287 106 | attn_1_2_outputs,_,-0.1785307,0.16045158,-0.8639972 107 | attn_1_2_outputs,_,0.5269155,1.2485906,-1.8244286 108 | attn_1_2_outputs,_,-0.383124,-0.5021012,-0.55019313 109 | attn_1_2_outputs,_,-0.092871554,0.22963502,0.12724838 110 | attn_1_2_outputs,_,-0.8239071,0.94725597,0.037750408 111 | attn_1_2_outputs,_,0.34750172,0.9910317,-0.7895062 112 | attn_1_2_outputs,_,-0.66590446,0.40987346,0.37104082 113 | attn_1_2_outputs,_,-0.052774414,-0.43269822,-0.5298429 114 | attn_1_3_outputs,(,-0.81292146,1.8696464,-0.5263783 115 | attn_1_3_outputs,),-0.19519545,-0.15165852,-0.19049717 116 | attn_1_3_outputs,,-0.3622958,0.81331193,0.30731088 117 | attn_1_3_outputs,,0.23445317,0.015669065,-1.4538736 118 | attn_1_3_outputs,_,-0.35138428,0.28486276,-0.16167411 119 | attn_1_3_outputs,_,-0.45811188,0.43144393,0.20147736 120 | attn_1_3_outputs,_,0.5497336,0.15336615,0.7609401 121 | attn_1_3_outputs,_,1.5862544,0.9128603,-1.2414508 122 | attn_1_3_outputs,_,0.37166864,0.21070029,-0.18914412 123 | 
attn_1_3_outputs,_,0.12535681,0.1599941,0.54591626 124 | attn_1_3_outputs,_,-0.6368585,0.37556642,-0.5828852 125 | attn_1_3_outputs,_,-0.935505,1.2570758,0.15145183 126 | attn_1_3_outputs,_,0.5280753,0.11495201,-0.31024593 127 | attn_1_3_outputs,_,-0.3986401,0.9080014,0.763646 128 | attn_1_3_outputs,_,1.2436559,-0.40404543,-1.3131849 129 | attn_1_3_outputs,_,0.686876,-0.80954635,-0.45231095 130 | attn_2_0_outputs,(,-1.5561209,1.4590123,-0.57100546 131 | attn_2_0_outputs,),0.9308131,-1.3626683,-1.3802326 132 | attn_2_0_outputs,,-0.9678551,0.26538986,-0.8596553 133 | attn_2_0_outputs,,0.3270525,1.4117384,-0.017223258 134 | attn_2_0_outputs,_,0.10557971,1.4797398,-0.36684495 135 | attn_2_0_outputs,_,0.4341196,-0.31931913,1.5029707 136 | attn_2_0_outputs,_,-0.2973821,-0.19198877,-0.1289013 137 | attn_2_0_outputs,_,-0.36797506,0.47559232,0.19091709 138 | attn_2_0_outputs,_,-0.4748216,0.026234714,0.63951665 139 | attn_2_0_outputs,_,0.94507056,0.34140712,0.31623304 140 | attn_2_0_outputs,_,-0.028984,-0.26938406,0.25619373 141 | attn_2_0_outputs,_,-0.06181845,-0.548181,0.2155078 142 | attn_2_0_outputs,_,0.3375238,0.1336795,-0.2824661 143 | attn_2_0_outputs,_,0.575439,0.9426243,0.64006364 144 | attn_2_0_outputs,_,0.074084885,0.34421614,0.43613386 145 | attn_2_0_outputs,_,0.47341222,0.08374641,0.17351331 146 | attn_2_1_outputs,(,-0.45127738,1.0018231,0.44534314 147 | attn_2_1_outputs,),-0.6839655,0.074150436,-0.2715587 148 | attn_2_1_outputs,,-0.70120454,-0.0535122,0.48966962 149 | attn_2_1_outputs,,-0.2806599,-0.20490536,-0.7978403 150 | attn_2_1_outputs,_,0.20717245,-0.7625038,-0.5953323 151 | attn_2_1_outputs,_,0.15005049,-1.1470549,-0.4587043 152 | attn_2_1_outputs,_,0.9465549,-0.71227086,-0.25911647 153 | attn_2_1_outputs,_,-0.3448204,-0.19821261,-0.5702802 154 | attn_2_1_outputs,_,0.16869305,0.3031439,-0.11745471 155 | attn_2_1_outputs,_,-0.4607347,1.3273661,-0.5856749 156 | attn_2_1_outputs,_,-0.3482659,-0.6055313,-1.2954234 157 | 
attn_2_1_outputs,_,0.6048929,0.2115645,0.3289516 158 | attn_2_1_outputs,_,-0.19371964,-0.4275286,-0.74410176 159 | attn_2_1_outputs,_,-0.13987288,0.3069642,0.38531253 160 | attn_2_1_outputs,_,0.039333925,0.6699897,0.034916207 161 | attn_2_1_outputs,_,1.5877175,-0.9881175,0.006162866 162 | attn_2_2_outputs,(,-2.750019,2.7980466,2.5939832 163 | attn_2_2_outputs,),3.6954346,-1.7907751,-2.5867593 164 | attn_2_2_outputs,,-0.079127155,-0.9472731,-1.1885493 165 | attn_2_2_outputs,,4.0218706,-3.1554546,-4.7507057 166 | attn_2_2_outputs,_,0.9610852,-0.08878815,0.08363015 167 | attn_2_2_outputs,_,-0.08353972,0.37381604,-0.28025633 168 | attn_2_2_outputs,_,0.35657725,0.8969886,0.072851464 169 | attn_2_2_outputs,_,-0.32833374,-0.7531067,0.3516615 170 | attn_2_2_outputs,_,0.9393161,-0.5649734,-0.5196474 171 | attn_2_2_outputs,_,0.08626939,-1.1392281,1.6193309 172 | attn_2_2_outputs,_,1.2216192,-0.21259107,-1.3641104 173 | attn_2_2_outputs,_,-0.09241636,0.27687886,-0.16864115 174 | attn_2_2_outputs,_,0.49919057,0.65654963,0.21979749 175 | attn_2_2_outputs,_,-0.7234905,-0.14355081,0.40794903 176 | attn_2_2_outputs,_,0.72649467,0.1041588,0.011536894 177 | attn_2_2_outputs,_,0.27901632,0.6072106,-0.4015586 178 | attn_2_3_outputs,(,-1.6574559,1.2850804,0.09690388 179 | attn_2_3_outputs,),-0.37473014,0.7424306,0.21433578 180 | attn_2_3_outputs,,-1.5348345,1.6227428,1.3219541 181 | attn_2_3_outputs,,0.25438994,-0.1620961,-0.8211459 182 | attn_2_3_outputs,_,-0.051882792,0.7923025,-0.18180363 183 | attn_2_3_outputs,_,-0.6806416,0.61710155,-0.14199154 184 | attn_2_3_outputs,_,0.82735395,0.37269086,0.15985933 185 | attn_2_3_outputs,_,0.5044491,0.6499633,-0.1310748 186 | attn_2_3_outputs,_,0.3183162,-0.5054114,-0.007711913 187 | attn_2_3_outputs,_,-0.33328158,0.29277143,0.18589394 188 | attn_2_3_outputs,_,0.047428608,0.4182818,0.6526039 189 | attn_2_3_outputs,_,-1.0801055,0.58600456,0.11516155 190 | attn_2_3_outputs,_,0.04333351,0.2217204,0.12683378 191 | 
attn_2_3_outputs,_,-1.0841376,0.9090559,-0.7465817 192 | attn_2_3_outputs,_,-0.87214464,0.8850236,0.20182282 193 | attn_2_3_outputs,_,-0.46616295,-0.37282166,0.15026246 194 | mlp_0_0_outputs,0,-0.609197,-0.027131865,-0.22918989 195 | mlp_0_0_outputs,1,0.556234,-0.03605403,-0.371482 196 | mlp_0_0_outputs,10,-1.310118,3.0756617,0.20343237 197 | mlp_0_0_outputs,11,-0.04378883,-0.092148624,-0.11734988 198 | mlp_0_0_outputs,12,-0.90647274,-0.1498387,-0.19741559 199 | mlp_0_0_outputs,13,0.28161588,0.206903,-0.98625535 200 | mlp_0_0_outputs,14,0.5062789,-0.39951432,-0.25588968 201 | mlp_0_0_outputs,15,0.15142694,0.89919674,0.6030912 202 | mlp_0_0_outputs,2,0.72293925,-0.8654614,0.5670439 203 | mlp_0_0_outputs,3,0.7204033,0.15438499,0.8819249 204 | mlp_0_0_outputs,4,-0.4914485,0.48221803,-0.026991796 205 | mlp_0_0_outputs,5,-0.2592103,-0.1592593,-0.34026244 206 | mlp_0_0_outputs,6,-0.26047266,0.2415719,0.13886715 207 | mlp_0_0_outputs,7,0.13437562,0.109038174,-0.27345678 208 | mlp_0_0_outputs,8,0.15529574,0.25750098,0.09392571 209 | mlp_0_0_outputs,9,-2.1050072,0.8252483,-0.97018075 210 | mlp_1_0_outputs,0,0.28741208,0.35307455,0.46118635 211 | mlp_1_0_outputs,1,-0.02276508,0.610253,-0.40055403 212 | mlp_1_0_outputs,10,0.19024082,1.1057127,-0.119028404 213 | mlp_1_0_outputs,11,-0.16726859,0.8911886,0.06611366 214 | mlp_1_0_outputs,12,-0.2248021,0.68136096,-0.19288228 215 | mlp_1_0_outputs,13,-0.2775437,0.17751533,-0.944292 216 | mlp_1_0_outputs,14,-0.38200572,0.5241664,-0.081385374 217 | mlp_1_0_outputs,15,-0.5838907,0.96178657,-1.1940188 218 | mlp_1_0_outputs,2,-0.398339,0.6708183,-0.21698345 219 | mlp_1_0_outputs,3,-0.3026861,0.12032325,-0.1728985 220 | mlp_1_0_outputs,4,-2.0897841,2.2824123,0.1175223 221 | mlp_1_0_outputs,5,-0.35795817,0.78761435,0.056806456 222 | mlp_1_0_outputs,6,-1.2497257,0.2776687,-0.51030177 223 | mlp_1_0_outputs,7,-0.025528096,0.9015349,0.009612377 224 | mlp_1_0_outputs,8,-1.0789367,0.115357794,-0.4832368 225 | 
mlp_1_0_outputs,9,0.14225361,0.18992919,0.26218814 226 | mlp_2_0_outputs,0,-0.9600631,-0.0905738,0.4139894 227 | mlp_2_0_outputs,1,-1.2516974,0.26433572,-0.093466 228 | mlp_2_0_outputs,10,-0.4429331,1.2817378,1.3716829 229 | mlp_2_0_outputs,11,-0.8686043,0.63608295,0.88492036 230 | mlp_2_0_outputs,12,0.5736482,0.27950564,-0.3706815 231 | mlp_2_0_outputs,13,-1.066069,0.9496548,0.7379837 232 | mlp_2_0_outputs,14,1.1608202,-1.783248,-1.4169054 233 | mlp_2_0_outputs,15,-0.8355681,0.7468592,0.7600724 234 | mlp_2_0_outputs,2,0.41237113,0.16725028,-0.3574132 235 | mlp_2_0_outputs,3,-0.86503994,1.2446153,0.70891947 236 | mlp_2_0_outputs,4,-0.27445322,1.5090894,1.6432297 237 | mlp_2_0_outputs,5,1.4590892,-0.91437083,-0.60318696 238 | mlp_2_0_outputs,6,-1.553053,0.7144941,0.7145207 239 | mlp_2_0_outputs,7,-1.256165,0.2772562,0.011028548 240 | mlp_2_0_outputs,8,0.7696337,-1.3811562,0.09975987 241 | mlp_2_0_outputs,9,6.067684,-2.7872226,-8.341669 242 | num_attn_0_0_outputs,_,3.7398744,-7.181635,3.5538623 243 | num_attn_0_1_outputs,_,0.49548444,0.100667655,-0.16136481 244 | num_attn_0_2_outputs,_,2.1738443,-1.07889,-1.8416303 245 | num_attn_0_3_outputs,_,0.691551,-0.5990272,0.014788681 246 | num_attn_1_0_outputs,_,2.9874318,-2.4701202,-0.3135344 247 | num_attn_1_1_outputs,_,0.12491817,-0.38467312,0.048041295 248 | num_attn_1_2_outputs,_,0.3216399,0.071806945,0.3041418 249 | num_attn_1_3_outputs,_,-1.1208845,1.7802722,0.1500264 250 | num_attn_2_0_outputs,_,1.0652643,0.16920693,0.12847312 251 | num_attn_2_1_outputs,_,0.2974014,-1.6822767,-0.6943763 252 | num_attn_2_2_outputs,_,1.2344797,-1.0337117,-0.39935145 253 | num_attn_2_3_outputs,_,0.68644404,0.025850799,0.31126392 254 | num_mlp_0_0_outputs,0,1.0641823,-0.54406065,0.1806945 255 | num_mlp_0_0_outputs,1,0.28959027,-1.0161468,-0.6264937 256 | num_mlp_0_0_outputs,10,-0.043870565,0.95602506,0.09602879 257 | num_mlp_0_0_outputs,11,0.17968851,0.8659425,0.06156473 258 | num_mlp_0_0_outputs,12,-0.48285016,0.5750547,-0.2724834 259 | 
num_mlp_0_0_outputs,13,0.37540218,-0.24070558,-0.11528563 260 | num_mlp_0_0_outputs,14,-0.41391063,0.54637676,-0.13351686 261 | num_mlp_0_0_outputs,15,-0.29117483,0.471855,-0.28612953 262 | num_mlp_0_0_outputs,2,-0.20709011,0.96985,0.14005361 263 | num_mlp_0_0_outputs,3,-0.7341656,-0.06500526,-0.43360725 264 | num_mlp_0_0_outputs,4,-0.7252949,0.9496459,0.97981936 265 | num_mlp_0_0_outputs,5,-0.46789286,0.19887882,-0.40079117 266 | num_mlp_0_0_outputs,6,-0.11914711,0.50922734,-0.12805091 267 | num_mlp_0_0_outputs,7,-0.5103994,-0.83839464,-0.7605896 268 | num_mlp_0_0_outputs,8,0.0003218591,0.4606981,-0.7254237 269 | num_mlp_0_0_outputs,9,0.12420503,-1.4992923,-1.0763122 270 | num_mlp_1_0_outputs,0,-0.50804985,-0.048254747,-0.50164807 271 | num_mlp_1_0_outputs,1,-0.18716975,0.314667,-0.13772789 272 | num_mlp_1_0_outputs,10,-0.018390449,0.30108446,-0.16875868 273 | num_mlp_1_0_outputs,11,-0.2499132,0.2316408,-0.27131632 274 | num_mlp_1_0_outputs,12,-0.36267832,-0.029418426,-0.3747254 275 | num_mlp_1_0_outputs,13,-0.5715566,-0.050854474,-0.74630326 276 | num_mlp_1_0_outputs,14,-0.43067545,-0.19346179,0.2828539 277 | num_mlp_1_0_outputs,15,-0.8118014,0.031878382,-0.5857855 278 | num_mlp_1_0_outputs,2,-0.94931656,0.090484895,-0.4667937 279 | num_mlp_1_0_outputs,3,-0.3243798,0.38227555,-0.16742714 280 | num_mlp_1_0_outputs,4,-0.55292064,0.0885116,-0.16673157 281 | num_mlp_1_0_outputs,5,-0.19122665,0.56547916,0.026035376 282 | num_mlp_1_0_outputs,6,-0.35779983,0.07764958,-0.5537944 283 | num_mlp_1_0_outputs,7,-0.16128884,0.30881864,-0.5096637 284 | num_mlp_1_0_outputs,8,-0.77744365,-0.02209261,-0.67441314 285 | num_mlp_1_0_outputs,9,0.1671805,0.7428009,0.7823384 286 | num_mlp_2_0_outputs,0,-0.3438476,-0.018625923,-0.8805979 287 | num_mlp_2_0_outputs,1,-0.19985998,0.26809484,-0.60468084 288 | num_mlp_2_0_outputs,10,-0.39671543,0.106505,-0.6291543 289 | num_mlp_2_0_outputs,11,-0.50791365,0.48418602,-0.25579256 290 | num_mlp_2_0_outputs,12,-3.960384,3.2354944,1.5324167 291 | 
num_mlp_2_0_outputs,13,-0.061338473,-0.11631737,-0.47945797 292 | num_mlp_2_0_outputs,14,-0.30011797,0.09830846,-0.6083012 293 | num_mlp_2_0_outputs,15,-0.38574433,0.19097115,-0.82048756 294 | num_mlp_2_0_outputs,2,-0.20133315,0.10674115,-0.58823496 295 | num_mlp_2_0_outputs,3,0.0117857745,0.66684407,0.3024773 296 | num_mlp_2_0_outputs,4,-0.33616486,0.30070633,-0.4564914 297 | num_mlp_2_0_outputs,5,-0.04366164,0.78352547,-0.033621307 298 | num_mlp_2_0_outputs,6,0.08861894,0.54160535,0.0855582 299 | num_mlp_2_0_outputs,7,0.039711624,0.9778597,0.20092271 300 | num_mlp_2_0_outputs,8,0.72047126,0.44210163,-0.335978 301 | num_mlp_2_0_outputs,9,-0.006253452,0.69338506,-0.3173577 302 | ones,_,-0.7997238,1.0074667,0.026579212 303 | positions,0,-0.29179272,0.23241787,0.6308148 304 | positions,1,-3.904494,5.006513,-1.8895661 305 | positions,10,2.2057185,-5.531098,5.206517 306 | positions,11,3.9026654,-2.0590394,-2.612058 307 | positions,12,4.1754217,-4.574852,5.335018 308 | positions,13,3.8170507,-2.526535,-3.9886053 309 | positions,14,0.7010598,-0.4262054,-0.4852942 310 | positions,15,4.7547708,-4.233093,-4.5563035 311 | positions,2,-2.9996147,-6.9421015,6.494166 312 | positions,3,1.2800095,2.991268,-9.872106 313 | positions,4,1.2470323,-2.7899837,2.7413142 314 | positions,5,-9.135268,14.433624,-8.043073 315 | positions,6,-7.3264008,11.024153,-0.3841814 316 | positions,7,0.30973187,2.3126788,-1.3860149 317 | positions,8,-10.491879,17.33303,-5.2549305 318 | positions,9,-1.5044698,3.8648603,-1.0246311 319 | tokens,(,2.0314052,0.9187388,-5.384055 320 | tokens,),-3.5163503,0.672222,2.1528707 321 | tokens,,0.3435328,0.89075476,-1.2196645 322 | tokens,,-0.33722147,-0.05571238,-0.4713453 323 | tokens,_,0.55962527,-0.5702431,0.23070188 324 | tokens,_,0.2838472,0.7993582,-0.99131817 325 | tokens,_,0.12776248,0.89144236,-0.23171602 326 | tokens,_,-0.13315062,0.95989436,0.9555972 327 | tokens,_,-0.6196355,-0.36522523,-0.359687 328 | tokens,_,0.10280879,0.3058465,-0.47682306 329 | 
tokens,_,0.35391825,-0.20935811,0.6336341 330 | tokens,_,1.0826828,-0.8673405,-0.55527455 331 | tokens,_,0.5650465,0.42321026,-0.6229729 332 | tokens,_,0.34496674,-0.7241081,-0.40609998 333 | tokens,_,-0.053234033,0.15473613,0.3231568 334 | tokens,_,0.897978,-0.39115798,-0.12659745 335 | -------------------------------------------------------------------------------- /programs/rasp/dyck2/dyck2_weights.csv: -------------------------------------------------------------------------------- 1 | feature,value,F,P,T 2 | attn_0_0_outputs,(,-0.3893371,0.60064846,-1.0043755 3 | attn_0_0_outputs,),0.45308644,-0.33719206,-0.25703704 4 | attn_0_0_outputs,,0.7680508,-1.140804,-0.6336192 5 | attn_0_0_outputs,,1.3979992,-0.55402267,-2.9333327 6 | attn_0_0_outputs,_,-0.27462152,-0.2883787,-0.084659226 7 | attn_0_0_outputs,_,0.41020688,0.80682915,-0.36093912 8 | attn_0_0_outputs,_,0.056031883,-0.039105184,0.19863698 9 | attn_0_0_outputs,_,-0.41961256,-0.54929376,0.43671665 10 | attn_0_0_outputs,_,0.13279785,0.36707786,0.40994072 11 | attn_0_0_outputs,_,-0.2363711,0.5147661,-0.18510109 12 | attn_0_0_outputs,_,0.12620896,-0.18906875,0.9231106 13 | attn_0_0_outputs,_,-0.029982865,0.23939228,1.6595911 14 | attn_0_0_outputs,_,-0.73210377,-1.1544671,0.57133 15 | attn_0_0_outputs,_,-0.23087524,1.8027706,-1.5571856 16 | attn_0_0_outputs,{,-0.04748568,1.2735734,-0.45339546 17 | attn_0_0_outputs,},0.5547761,-0.3066463,-0.1081352 18 | attn_0_1_outputs,(,-0.06017389,0.23692839,-0.65924984 19 | attn_0_1_outputs,),1.5613009,0.05749913,-1.1849948 20 | attn_0_1_outputs,,0.30928898,0.60175115,-0.77076155 21 | attn_0_1_outputs,,-0.05464487,-6.3642993,5.9092317 22 | attn_0_1_outputs,_,0.82652664,0.4707333,0.045426637 23 | attn_0_1_outputs,_,-0.43456507,1.0889953,-0.6075947 24 | attn_0_1_outputs,_,-0.650409,0.48931292,-0.960644 25 | attn_0_1_outputs,_,-0.14928144,-0.3577456,-0.08982656 26 | attn_0_1_outputs,_,-0.4732285,-0.5932854,-0.044864617 27 | attn_0_1_outputs,_,0.17000043,-0.046643566,0.04345899 
28 | attn_0_1_outputs,_,-0.049569067,0.36960238,0.43352073 29 | attn_0_1_outputs,_,-0.43735144,-0.4731728,0.18792638 30 | attn_0_1_outputs,_,-0.46125263,-0.9085558,1.1110471 31 | attn_0_1_outputs,_,0.19633614,1.5570519,0.18224098 32 | attn_0_1_outputs,{,-0.11224139,0.64842534,-0.33784837 33 | attn_0_1_outputs,},-0.6292952,1.3489552,0.3662693 34 | attn_1_0_outputs,0,0.20661825,1.3388423,-2.5574758 35 | attn_1_0_outputs,1,2.5823343,-0.7297078,-3.195439 36 | attn_1_0_outputs,10,-0.5139535,-0.078736916,1.6875057 37 | attn_1_0_outputs,11,0.041748732,-0.14228384,-0.14014556 38 | attn_1_0_outputs,12,-0.454622,-0.6740401,0.04633342 39 | attn_1_0_outputs,13,3.6046548,-1.8717623,-2.699525 40 | attn_1_0_outputs,14,-1.5789233,0.97090966,-0.3266689 41 | attn_1_0_outputs,15,-0.24816914,-0.10174749,1.3838878 42 | attn_1_0_outputs,2,0.11910092,0.22456145,1.3250326 43 | attn_1_0_outputs,3,-0.4317825,-0.28709865,-0.20844361 44 | attn_1_0_outputs,4,-3.6241753,3.323929,0.5773792 45 | attn_1_0_outputs,5,-0.11809031,0.24210535,-0.23381685 46 | attn_1_0_outputs,6,-1.0743729,1.1453243,1.608349 47 | attn_1_0_outputs,7,-3.164238,0.699024,1.0340277 48 | attn_1_0_outputs,8,-0.28543827,0.40379593,0.14130151 49 | attn_1_0_outputs,9,-0.5014043,-0.20715642,0.18966955 50 | attn_1_1_outputs,0,0.49038798,1.5800108,-1.2205557 51 | attn_1_1_outputs,1,-0.93285435,1.7327744,-1.2455193 52 | attn_1_1_outputs,10,-0.039011396,-0.65827584,0.71766496 53 | attn_1_1_outputs,11,0.23105314,-0.0301322,1.3648343 54 | attn_1_1_outputs,12,1.4345965,-2.2012353,0.9856132 55 | attn_1_1_outputs,13,-0.18832216,0.19023667,0.89298505 56 | attn_1_1_outputs,14,0.03877773,0.43423906,0.40338022 57 | attn_1_1_outputs,15,2.353314,-1.4448909,-0.21294129 58 | attn_1_1_outputs,2,1.4249635,-1.0321774,0.75964826 59 | attn_1_1_outputs,3,-1.1570958,2.1030843,1.7259277 60 | attn_1_1_outputs,4,-1.1145258,1.9922961,-2.486801 61 | attn_1_1_outputs,5,-0.40263888,0.60196733,1.4514791 62 | attn_1_1_outputs,6,2.065989,-2.143043,0.49976075 63 | 
attn_1_1_outputs,7,0.30126733,-0.6678852,0.65130216 64 | attn_1_1_outputs,8,-0.061816756,0.13271871,1.2922349 65 | attn_1_1_outputs,9,0.5106305,-1.2614174,1.5335835 66 | attn_2_0_outputs,0,-0.05937145,1.2803885,-0.6700919 67 | attn_2_0_outputs,1,-1.3801527,2.8249412,0.48383245 68 | attn_2_0_outputs,10,0.060948834,-0.13837986,-0.054340422 69 | attn_2_0_outputs,11,-0.37168485,-0.076273456,-0.64383096 70 | attn_2_0_outputs,12,0.13798621,-0.6402446,0.5534661 71 | attn_2_0_outputs,13,0.8956216,-0.676867,0.5628605 72 | attn_2_0_outputs,14,0.23056701,-1.5529728,0.9693785 73 | attn_2_0_outputs,15,1.1047347,-0.7788408,-0.9719154 74 | attn_2_0_outputs,2,0.7777873,-1.1186872,0.052649383 75 | attn_2_0_outputs,3,-1.7339926,2.1315053,1.7853323 76 | attn_2_0_outputs,4,-1.377781,1.8794999,-1.3036643 77 | attn_2_0_outputs,5,3.117208,-1.8162274,1.6918435 78 | attn_2_0_outputs,6,0.7117053,0.103766,0.21950138 79 | attn_2_0_outputs,7,0.6329945,-1.5896332,1.3083339 80 | attn_2_0_outputs,8,4.091406,-3.2152696,0.2076941 81 | attn_2_0_outputs,9,-0.04400771,-0.26516256,0.04262287 82 | attn_2_1_outputs,0,-1.3274802,2.2420464,-0.69805795 83 | attn_2_1_outputs,1,-0.65939707,2.6968226,-0.36676592 84 | attn_2_1_outputs,10,0.04043235,-0.5054306,0.015038767 85 | attn_2_1_outputs,11,-0.6721706,-0.31992498,-0.16208978 86 | attn_2_1_outputs,12,1.5115396,-0.9017534,-0.3154115 87 | attn_2_1_outputs,13,0.79894733,-1.3227056,-0.10762319 88 | attn_2_1_outputs,14,0.12715805,-1.4680986,0.32952312 89 | attn_2_1_outputs,15,-0.2186547,-0.48416838,0.35051808 90 | attn_2_1_outputs,2,0.37131494,-1.1129726,0.40816385 91 | attn_2_1_outputs,3,-1.511338,1.4423819,0.6649975 92 | attn_2_1_outputs,4,-0.44196594,1.0740502,-0.8123087 93 | attn_2_1_outputs,5,5.7057724,-3.2706978,-4.0638323 94 | attn_2_1_outputs,6,2.224006,-1.0677377,-1.4632593 95 | attn_2_1_outputs,7,0.38812694,-0.7842454,1.4326475 96 | attn_2_1_outputs,8,-0.65958214,0.6723407,0.006303966 97 | attn_2_1_outputs,9,0.2128453,-1.0684156,2.4653118 98 | 
mlp_0_0_outputs,0,1.4124762,0.3918622,0.96247226 99 | mlp_0_0_outputs,1,0.035059676,-0.78724706,-0.49286532 100 | mlp_0_0_outputs,10,0.43810797,0.6196286,-0.29613635 101 | mlp_0_0_outputs,11,0.024041645,-0.45532593,-0.24850975 102 | mlp_0_0_outputs,12,0.21156618,0.30932778,0.49199212 103 | mlp_0_0_outputs,13,0.038805455,-0.1897838,0.03590136 104 | mlp_0_0_outputs,14,0.5470265,0.35249376,-0.26997298 105 | mlp_0_0_outputs,15,0.41228998,-1.3040318,1.2210026 106 | mlp_0_0_outputs,2,-0.45929217,1.0102494,-0.034102287 107 | mlp_0_0_outputs,3,0.46021348,0.09072253,0.099708 108 | mlp_0_0_outputs,4,1.9909586,-0.93861586,-2.8465278 109 | mlp_0_0_outputs,5,1.6368363,-1.6418909,2.8530548 110 | mlp_0_0_outputs,6,-0.85980654,0.6846919,-0.6151556 111 | mlp_0_0_outputs,7,-0.18866596,0.42998883,-0.31827375 112 | mlp_0_0_outputs,8,0.032268297,-0.017944142,0.26745126 113 | mlp_0_0_outputs,9,-0.66864365,-0.14510924,-1.214156 114 | mlp_0_1_outputs,0,0.45482358,-1.2842598,-0.60993576 115 | mlp_0_1_outputs,1,0.31253016,-0.7943685,1.5131452 116 | mlp_0_1_outputs,10,0.25627092,0.22809455,1.0208089 117 | mlp_0_1_outputs,11,0.76562774,-0.18769291,-0.1974224 118 | mlp_0_1_outputs,12,0.38039443,-0.5691965,-0.091124795 119 | mlp_0_1_outputs,13,2.6179564,-0.32311362,-0.72676677 120 | mlp_0_1_outputs,14,-0.040358998,0.030396035,-0.82734126 121 | mlp_0_1_outputs,15,-0.46312347,0.32302356,0.42312586 122 | mlp_0_1_outputs,2,1.068703,-0.067160755,-0.58011967 123 | mlp_0_1_outputs,3,0.14145625,0.29858097,0.6092269 124 | mlp_0_1_outputs,4,-1.5048414,1.6003934,-0.54811466 125 | mlp_0_1_outputs,5,-0.07364487,0.103354104,0.5413819 126 | mlp_0_1_outputs,6,-0.024732672,0.080756046,0.43893534 127 | mlp_0_1_outputs,7,-1.79316,1.7965577,1.1761781 128 | mlp_0_1_outputs,8,-0.4545691,0.09123524,-2.1593776 129 | mlp_0_1_outputs,9,-0.8172794,-0.3064075,0.030816736 130 | mlp_1_0_outputs,0,0.22648785,-0.25926873,-0.506777 131 | mlp_1_0_outputs,1,-0.07746484,-0.984498,6.8709683 132 | 
mlp_1_0_outputs,10,0.20862451,-0.022583699,-0.790106 133 | mlp_1_0_outputs,11,1.020657,0.9678695,-0.295873 134 | mlp_1_0_outputs,12,0.3965819,0.2451575,-0.6036312 135 | mlp_1_0_outputs,13,0.47067815,0.2438208,-0.6940792 136 | mlp_1_0_outputs,14,-0.35612527,1.1211482,-1.4271184 137 | mlp_1_0_outputs,15,-0.29890123,-0.32131162,-0.707251 138 | mlp_1_0_outputs,2,0.5454376,0.30075255,-1.1875064 139 | mlp_1_0_outputs,3,0.9890708,0.9738718,0.17902154 140 | mlp_1_0_outputs,4,0.35423172,0.36298084,-1.0344858 141 | mlp_1_0_outputs,5,0.38208264,0.20733608,-1.3121635 142 | mlp_1_0_outputs,6,0.3350503,0.058604088,-0.639987 143 | mlp_1_0_outputs,7,0.48388848,0.2619127,-0.6384467 144 | mlp_1_0_outputs,8,-0.0939644,1.6520033,-2.8920875 145 | mlp_1_0_outputs,9,0.46636695,0.19207071,-0.47046715 146 | mlp_1_1_outputs,0,1.0177721,-0.4037793,-0.67931974 147 | mlp_1_1_outputs,1,-1.03061,0.70863277,-0.6241576 148 | mlp_1_1_outputs,10,0.9735542,1.4381392,-2.922016 149 | mlp_1_1_outputs,11,-0.37031677,0.04114491,-0.43833217 150 | mlp_1_1_outputs,12,0.16325824,0.08265705,-0.22923788 151 | mlp_1_1_outputs,13,0.50651205,0.04826694,-0.2412666 152 | mlp_1_1_outputs,14,-0.68649834,0.16758002,-0.8419172 153 | mlp_1_1_outputs,15,0.22871964,-0.3694542,0.036439355 154 | mlp_1_1_outputs,2,0.52248716,0.56564903,0.028707229 155 | mlp_1_1_outputs,3,-0.13902096,0.83899367,0.045766518 156 | mlp_1_1_outputs,4,0.8631842,-0.11090764,-3.633642 157 | mlp_1_1_outputs,5,0.24383672,0.4300947,-0.028894566 158 | mlp_1_1_outputs,6,0.77099085,0.40300265,-0.6123488 159 | mlp_1_1_outputs,7,0.80704874,-0.35886392,-1.2488782 160 | mlp_1_1_outputs,8,0.05687267,-0.7564281,0.10573861 161 | mlp_1_1_outputs,9,0.55384576,-0.005512286,-0.25716427 162 | mlp_2_0_outputs,0,-0.43766528,0.63004327,0.82238406 163 | mlp_2_0_outputs,1,-0.28752628,0.024231752,-0.73521745 164 | mlp_2_0_outputs,10,0.42408463,-1.1208787,-0.8666483 165 | mlp_2_0_outputs,11,2.4725761,-0.24301499,0.1757837 166 | 
mlp_2_0_outputs,12,0.78457654,-0.50197077,-0.31981418 167 | mlp_2_0_outputs,13,0.41764382,-0.94169605,-0.4211051 168 | mlp_2_0_outputs,14,1.2861595,-0.5318925,-0.036760274 169 | mlp_2_0_outputs,15,0.43116912,-0.94550234,-0.62638485 170 | mlp_2_0_outputs,2,0.057849146,0.12554613,-0.17656699 171 | mlp_2_0_outputs,3,0.9872132,-0.6279045,-0.24866179 172 | mlp_2_0_outputs,4,0.0049139177,1.9990245,-2.5793169 173 | mlp_2_0_outputs,5,0.52743393,-0.35086772,-0.1816786 174 | mlp_2_0_outputs,6,1.2171247,-1.0690438,-0.7097241 175 | mlp_2_0_outputs,7,1.2408279,-0.32120407,-2.283859 176 | mlp_2_0_outputs,8,1.1670102,-0.3065825,-1.6924549 177 | mlp_2_0_outputs,9,0.47862154,0.17900337,-1.1995778 178 | mlp_2_1_outputs,0,1.5722502,-1.8906522,-3.274062 179 | mlp_2_1_outputs,1,-0.14563715,0.071487516,0.4061026 180 | mlp_2_1_outputs,10,0.6692192,0.022232514,0.052336734 181 | mlp_2_1_outputs,11,-0.4981686,0.010872859,-0.120791934 182 | mlp_2_1_outputs,12,0.21853887,0.6528137,0.40027425 183 | mlp_2_1_outputs,13,-0.0539364,-0.04593685,-0.77657515 184 | mlp_2_1_outputs,14,-0.029986976,-0.16086973,0.45140362 185 | mlp_2_1_outputs,15,-0.19369616,0.82837254,0.23891859 186 | mlp_2_1_outputs,2,-1.1634233,1.885357,-3.579113 187 | mlp_2_1_outputs,3,0.46121255,0.038033824,0.0016819923 188 | mlp_2_1_outputs,4,0.09414851,0.5597031,0.2918426 189 | mlp_2_1_outputs,5,-0.41278046,0.0027908396,0.085692376 190 | mlp_2_1_outputs,6,0.1588337,0.11019328,-0.09616117 191 | mlp_2_1_outputs,7,-1.097063,-0.48358715,1.1503831 192 | mlp_2_1_outputs,8,0.4955245,0.4686215,0.23423254 193 | mlp_2_1_outputs,9,-0.049522586,0.9311799,-0.728278 194 | num_attn_0_0_outputs,_,-0.85263616,-0.036784045,-0.13747008 195 | num_attn_0_1_outputs,_,0.33958793,0.86033154,0.89450693 196 | num_attn_1_0_outputs,_,-0.78232145,0.5346261,0.2896964 197 | num_attn_1_1_outputs,_,6.0080223,-6.6434155,-3.9860556 198 | num_attn_2_0_outputs,_,-0.77475876,1.1810129,1.0096526 199 | num_attn_2_1_outputs,_,1.2496741,-1.8492256,1.0177201 200 | 
num_mlp_0_0_outputs,0,0.3314178,0.5325629,-0.16401817 201 | num_mlp_0_0_outputs,1,0.19410345,0.24263886,-0.3985704 202 | num_mlp_0_0_outputs,10,-0.05038471,0.23930399,-0.5071923 203 | num_mlp_0_0_outputs,11,0.14974928,0.4013859,-0.500009 204 | num_mlp_0_0_outputs,12,0.3967017,0.6515252,-0.013054512 205 | num_mlp_0_0_outputs,13,-0.3407226,-0.23862363,-1.0013121 206 | num_mlp_0_0_outputs,14,-0.0042749355,0.07379021,-0.87100405 207 | num_mlp_0_0_outputs,15,0.044042736,0.5217005,-0.44642144 208 | num_mlp_0_0_outputs,2,-0.04536464,0.35205916,-0.63331443 209 | num_mlp_0_0_outputs,3,0.30032563,0.5497403,-0.3833386 210 | num_mlp_0_0_outputs,4,0.4260003,0.63091743,-0.28802612 211 | num_mlp_0_0_outputs,5,0.25858745,0.0029270295,-0.3307586 212 | num_mlp_0_0_outputs,6,-0.068435706,0.053463157,-0.76825386 213 | num_mlp_0_0_outputs,7,0.36288938,0.19408187,-0.6395527 214 | num_mlp_0_0_outputs,8,0.6338136,0.76790816,-0.06365792 215 | num_mlp_0_0_outputs,9,0.087720215,0.43525898,-0.31613958 216 | num_mlp_0_1_outputs,0,-0.06632771,-0.070165716,-0.4955601 217 | num_mlp_0_1_outputs,1,0.2773735,0.34788114,-0.2892638 218 | num_mlp_0_1_outputs,10,0.2745447,-0.13056615,-0.8543876 219 | num_mlp_0_1_outputs,11,0.044537343,-0.32159278,-0.9105748 220 | num_mlp_0_1_outputs,12,0.18666676,0.1831349,-0.1372812 221 | num_mlp_0_1_outputs,13,0.23526604,-0.16011046,-0.73585725 222 | num_mlp_0_1_outputs,14,0.06875769,0.27533332,-0.23280683 223 | num_mlp_0_1_outputs,15,0.20177683,0.22331731,-0.3380389 224 | num_mlp_0_1_outputs,2,-0.26213843,-0.25991192,-1.0179317 225 | num_mlp_0_1_outputs,3,0.5162888,0.4858723,-0.36303845 226 | num_mlp_0_1_outputs,4,0.31894425,0.2724289,-0.30766574 227 | num_mlp_0_1_outputs,5,0.6092662,0.19296132,-0.85525966 228 | num_mlp_0_1_outputs,6,-0.07518837,0.025708076,-0.6562557 229 | num_mlp_0_1_outputs,7,0.6991878,0.6410874,0.10406117 230 | num_mlp_0_1_outputs,8,-0.3627702,-0.08890215,-0.4461756 231 | num_mlp_0_1_outputs,9,0.024958644,-0.08254894,-0.6912631 232 | 
num_mlp_1_0_outputs,0,0.3114148,0.55049676,-0.22676654 233 | num_mlp_1_0_outputs,1,0.07806706,-0.023488285,-0.30448 234 | num_mlp_1_0_outputs,10,-0.036106024,-0.061172247,-0.6728275 235 | num_mlp_1_0_outputs,11,0.00018137011,0.24429022,-0.113959186 236 | num_mlp_1_0_outputs,12,0.5541197,0.3748864,-0.091273285 237 | num_mlp_1_0_outputs,13,0.0910017,0.37516505,-0.11659233 238 | num_mlp_1_0_outputs,14,0.029605865,-0.40416694,-0.39333186 239 | num_mlp_1_0_outputs,15,0.4850837,0.17000733,-0.15730006 240 | num_mlp_1_0_outputs,2,-0.09316445,-0.3945504,-0.74077773 241 | num_mlp_1_0_outputs,3,0.77411944,0.45456645,-0.2434635 242 | num_mlp_1_0_outputs,4,0.45820743,0.267871,-0.08835843 243 | num_mlp_1_0_outputs,5,0.8663778,0.48250827,0.17160328 244 | num_mlp_1_0_outputs,6,0.18805933,0.16399784,-0.35467726 245 | num_mlp_1_0_outputs,7,0.4086816,0.020269368,-0.34887818 246 | num_mlp_1_0_outputs,8,-0.026719714,-0.16764963,-0.65317184 247 | num_mlp_1_0_outputs,9,0.019892184,-0.01170788,-0.5114425 248 | num_mlp_1_1_outputs,0,-0.078854226,0.08707899,-0.98892313 249 | num_mlp_1_1_outputs,1,-0.15243109,0.12799971,-0.72895944 250 | num_mlp_1_1_outputs,10,0.58201754,0.6582406,-0.16776891 251 | num_mlp_1_1_outputs,11,0.23494452,0.5244795,-0.41649953 252 | num_mlp_1_1_outputs,12,-0.01350627,-0.00017212816,-0.47231874 253 | num_mlp_1_1_outputs,13,-0.020348622,0.20011073,-0.66538125 254 | num_mlp_1_1_outputs,14,0.16774431,0.5128025,-0.28296882 255 | num_mlp_1_1_outputs,15,0.33657792,0.6268225,-0.22136451 256 | num_mlp_1_1_outputs,2,0.36348855,0.56271625,-0.2105581 257 | num_mlp_1_1_outputs,3,0.24504001,0.5178616,-0.24728571 258 | num_mlp_1_1_outputs,4,-0.05191437,0.09888679,-0.6176198 259 | num_mlp_1_1_outputs,5,-0.06691858,0.15422752,-0.6520832 260 | num_mlp_1_1_outputs,6,0.8261999,1.1822537,0.19525766 261 | num_mlp_1_1_outputs,7,-0.24775888,0.09739883,-0.73482597 262 | num_mlp_1_1_outputs,8,0.21655971,0.25961864,-0.49139246 263 | num_mlp_1_1_outputs,9,-0.23927306,-0.0052134893,-0.8000629 
264 | num_mlp_2_0_outputs,0,0.7344719,0.48031777,-0.0072156666 265 | num_mlp_2_0_outputs,1,-0.08577925,-0.13313006,-0.8903967 266 | num_mlp_2_0_outputs,10,0.07255866,-0.023388484,-0.35158905 267 | num_mlp_2_0_outputs,11,0.21112183,0.28738138,-0.1871139 268 | num_mlp_2_0_outputs,12,0.3063709,0.3118506,-0.3509904 269 | num_mlp_2_0_outputs,13,0.20950411,0.26588303,-0.3123367 270 | num_mlp_2_0_outputs,14,0.02394933,0.14214058,-0.45821595 271 | num_mlp_2_0_outputs,15,0.15515622,-0.23100227,-0.7154926 272 | num_mlp_2_0_outputs,2,0.071676075,0.32760754,-0.3858776 273 | num_mlp_2_0_outputs,3,-0.23932433,-0.17936808,-0.6346995 274 | num_mlp_2_0_outputs,4,-0.07454885,-0.04197524,-0.8073381 275 | num_mlp_2_0_outputs,5,0.44984394,0.27431875,-0.26230815 276 | num_mlp_2_0_outputs,6,0.24677134,0.27385816,-0.17668292 277 | num_mlp_2_0_outputs,7,-0.4095317,-0.1621746,-0.8552454 278 | num_mlp_2_0_outputs,8,0.4552858,0.8000447,0.23228435 279 | num_mlp_2_0_outputs,9,0.5577926,0.5177477,-0.009681215 280 | num_mlp_2_1_outputs,0,1.1987035,-0.33801028,-0.50512236 281 | num_mlp_2_1_outputs,1,0.94955945,-0.6463761,-0.7593889 282 | num_mlp_2_1_outputs,10,-0.29233116,0.7033983,0.14847362 283 | num_mlp_2_1_outputs,11,0.263601,-0.39432043,-0.87329227 284 | num_mlp_2_1_outputs,12,2.5524316,-2.5761328,-0.87832254 285 | num_mlp_2_1_outputs,13,0.6176073,-0.8246647,-1.1955972 286 | num_mlp_2_1_outputs,14,0.20788954,0.76745355,0.03399014 287 | num_mlp_2_1_outputs,15,0.72697127,-0.68905944,-0.7866722 288 | num_mlp_2_1_outputs,2,-0.5621572,0.887737,0.11716976 289 | num_mlp_2_1_outputs,3,1.18867,-0.34044042,-0.44703516 290 | num_mlp_2_1_outputs,4,-0.57840586,0.45728707,-0.35807198 291 | num_mlp_2_1_outputs,5,0.6322493,0.63687426,0.113720134 292 | num_mlp_2_1_outputs,6,-0.19396284,1.0908433,0.08487423 293 | num_mlp_2_1_outputs,7,0.47693765,0.46259257,-0.7392341 294 | num_mlp_2_1_outputs,8,-0.1926289,-0.3515196,-0.8700848 295 | num_mlp_2_1_outputs,9,-0.45918483,0.6270346,-0.17804115 296 | 
ones,_,-0.026366323,0.395983,-0.2675112 297 | positions,0,-0.21428509,0.07516778,0.20615174 298 | positions,1,1.3148408,0.7747299,-9.5917425 299 | positions,10,-0.9434932,-0.7169733,1.7839171 300 | positions,11,2.346714,-0.6581614,-11.038458 301 | positions,12,-1.921217,-0.8985177,2.4055943 302 | positions,13,2.3171093,-0.23238912,-11.31294 303 | positions,14,-2.6990774,-1.6240467,2.7664332 304 | positions,15,3.4511793,-0.7494351,-10.082401 305 | positions,2,-3.2204971,-0.737425,6.4716024 306 | positions,3,-1.0551778,2.5078318,-2.613356 307 | positions,4,-1.4998984,0.22135884,1.4028941 308 | positions,5,0.5460883,2.1069403,-8.382127 309 | positions,6,-0.50053823,0.633908,2.4933894 310 | positions,7,0.29808688,1.4961973,-8.695588 311 | positions,8,-0.6610369,0.24578033,2.0996253 312 | positions,9,3.1191888,0.5754417,-5.0389624 313 | tokens,(,-1.1092756,2.403831,-5.864368 314 | tokens,),0.9881682,-2.3895998,3.3189006 315 | tokens,,0.4111517,-0.24375123,-0.72117674 316 | tokens,,-0.004239899,-0.017491871,-0.007405918 317 | tokens,_,1.3654574,0.29655144,0.28431866 318 | tokens,_,1.4269474,-0.7771767,-0.16797024 319 | tokens,_,0.3346523,-0.2013016,0.14942457 320 | tokens,_,1.3444127,-0.10818399,0.32242635 321 | tokens,_,0.07180428,-0.7881537,0.19018956 322 | tokens,_,-0.31608832,-0.19187881,-0.06436471 323 | tokens,_,0.9428614,-1.0588603,-0.5878867 324 | tokens,_,1.2218833,0.40168327,-0.35955858 325 | tokens,_,-0.5326404,-0.19429654,-0.4998532 326 | tokens,_,0.7153185,-0.070266396,-0.59234005 327 | tokens,{,-0.7310767,2.8042543,-4.676652 328 | tokens,},0.6038923,-3.4896483,2.4227173 329 | -------------------------------------------------------------------------------- /programs/rasp/hist/hist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | def select_closest(keys, queries, predicate): 6 | scores = [[False for _ in keys] for _ in queries] 7 | for i, q in enumerate(queries): 8 
| matches = [j for j, k in enumerate(keys) if predicate(q, k)] 9 | if not matches: 10 | scores[i][0] = True 11 | else: 12 | j = min(matches, key=lambda j: len(matches) if j == i else abs(i - j)) 13 | scores[i][j] = True 14 | return scores 15 | 16 | 17 | def select(keys, queries, predicate): 18 | return [[predicate(q, k) for k in keys] for q in queries] 19 | 20 | 21 | def aggregate(attention, values): 22 | return [[v for a, v in zip(attn, values) if a][0] for attn in attention] 23 | 24 | 25 | def aggregate_sum(attention, values): 26 | return [sum([v for a, v in zip(attn, values) if a]) for attn in attention] 27 | 28 | 29 | def run(tokens): 30 | # classifier weights ########################################## 31 | classifier_weights = pd.read_csv( 32 | "programs/rasp/hist/hist_weights.csv", index_col=[0, 1], dtype={"feature": str} 33 | ) 34 | # inputs ##################################################### 35 | token_scores = classifier_weights.loc[[("tokens", str(v)) for v in tokens]] 36 | 37 | positions = list(range(len(tokens))) 38 | position_scores = classifier_weights.loc[[("positions", str(v)) for v in positions]] 39 | 40 | ones = [1 for _ in range(len(tokens))] 41 | one_scores = classifier_weights.loc[[("ones", "_") for v in ones]].mul(ones, axis=0) 42 | 43 | # attn_0_0 #################################################### 44 | def predicate_0_0(q_token, k_token): 45 | if q_token in {"1", "4", "0", "2", "3", "5"}: 46 | return k_token == "" 47 | elif q_token in {""}: 48 | return k_token == "" 49 | 50 | attn_0_0_pattern = select_closest(tokens, tokens, predicate_0_0) 51 | attn_0_0_outputs = aggregate(attn_0_0_pattern, positions) 52 | attn_0_0_output_scores = classifier_weights.loc[ 53 | [("attn_0_0_outputs", str(v)) for v in attn_0_0_outputs] 54 | ] 55 | 56 | # attn_0_1 #################################################### 57 | def predicate_0_1(q_position, k_position): 58 | if q_position in {0}: 59 | return k_position == 7 60 | elif q_position in {1, 2, 3,
5}: 61 | return k_position == 0 62 | elif q_position in {4, 7}: 63 | return k_position == 4 64 | elif q_position in {6}: 65 | return k_position == 2 66 | 67 | attn_0_1_pattern = select_closest(positions, positions, predicate_0_1) 68 | attn_0_1_outputs = aggregate(attn_0_1_pattern, positions) 69 | attn_0_1_output_scores = classifier_weights.loc[ 70 | [("attn_0_1_outputs", str(v)) for v in attn_0_1_outputs] 71 | ] 72 | 73 | # num_attn_0_0 #################################################### 74 | def num_predicate_0_0(q_token, k_token): 75 | if q_token in {"0"}: 76 | return k_token == "0" 77 | elif q_token in {"1"}: 78 | return k_token == "1" 79 | elif q_token in {"2"}: 80 | return k_token == "2" 81 | elif q_token in {"3"}: 82 | return k_token == "3" 83 | elif q_token in {"4"}: 84 | return k_token == "4" 85 | elif q_token in {"5"}: 86 | return k_token == "5" 87 | elif q_token in {""}: 88 | return k_token == "" 89 | 90 | num_attn_0_0_pattern = select(tokens, tokens, num_predicate_0_0) 91 | num_attn_0_0_outputs = aggregate_sum(num_attn_0_0_pattern, ones) 92 | num_attn_0_0_output_scores = classifier_weights.loc[ 93 | [("num_attn_0_0_outputs", "_") for v in num_attn_0_0_outputs] 94 | ].mul(num_attn_0_0_outputs, axis=0) 95 | 96 | # num_attn_0_1 #################################################### 97 | def num_predicate_0_1(q_token, k_token): 98 | if q_token in {"1", "0", "3"}: 99 | return k_token == "5" 100 | elif q_token in {"2"}: 101 | return k_token == "1" 102 | elif q_token in {"4"}: 103 | return k_token == "2" 104 | elif q_token in {"5"}: 105 | return k_token == "3" 106 | elif q_token in {""}: 107 | return k_token == "" 108 | 109 | num_attn_0_1_pattern = select(tokens, tokens, num_predicate_0_1) 110 | num_attn_0_1_outputs = aggregate_sum(num_attn_0_1_pattern, ones) 111 | num_attn_0_1_output_scores = classifier_weights.loc[ 112 | [("num_attn_0_1_outputs", "_") for v in num_attn_0_1_outputs] 113 | ].mul(num_attn_0_1_outputs, axis=0) 114 | 115 | # mlp_0_0 
##################################################### 116 | def mlp_0_0(position, attn_0_1_output): 117 | key = (position, attn_0_1_output) 118 | return 3 119 | 120 | mlp_0_0_outputs = [mlp_0_0(k0, k1) for k0, k1 in zip(positions, attn_0_1_outputs)] 121 | mlp_0_0_output_scores = classifier_weights.loc[ 122 | [("mlp_0_0_outputs", str(v)) for v in mlp_0_0_outputs] 123 | ] 124 | 125 | # num_mlp_0_0 ################################################# 126 | def num_mlp_0_0(num_attn_0_0_output): 127 | key = num_attn_0_0_output 128 | if key in {0, 1}: 129 | return 6 130 | return 3 131 | 132 | num_mlp_0_0_outputs = [num_mlp_0_0(k0) for k0 in num_attn_0_0_outputs] 133 | num_mlp_0_0_output_scores = classifier_weights.loc[ 134 | [("num_mlp_0_0_outputs", str(v)) for v in num_mlp_0_0_outputs] 135 | ] 136 | 137 | feature_logits = pd.concat( 138 | [ 139 | df.reset_index() 140 | for df in [ 141 | token_scores, 142 | position_scores, 143 | attn_0_0_output_scores, 144 | attn_0_1_output_scores, 145 | mlp_0_0_output_scores, 146 | num_mlp_0_0_output_scores, 147 | one_scores, 148 | num_attn_0_0_output_scores, 149 | num_attn_0_1_output_scores, 150 | ] 151 | ] 152 | ) 153 | logits = feature_logits.groupby(level=0).sum(numeric_only=True).to_numpy() 154 | classes = classifier_weights.columns.to_numpy() 155 | predictions = classes[logits.argmax(-1)] 156 | if tokens[0] == "": 157 | predictions[0] = "" 158 | if tokens[-1] == "": 159 | predictions[-1] = "" 160 | return predictions.tolist() 161 | 162 | 163 | examples = [ 164 | ( 165 | ["", "2", "5", "3", "2", "1", "5", "3"], 166 | ["", "2", "2", "2", "2", "1", "2", "2"], 167 | ), 168 | (["", "2", "4", "3", "5", "2", "5"], ["", "2", "1", "1", "2", "2", "2"]), 169 | (["", "5", "2", "0", "4", "3", "4"], ["", "1", "1", "1", "2", "1", "2"]), 170 | ( 171 | ["", "5", "3", "4", "1", "2", "5", "2"], 172 | ["", "2", "1", "1", "1", "2", "2", "2"], 173 | ), 174 | ( 175 | ["", "2", "0", "2", "0", "3", "4", "4"], 176 | ["", "2", "2", "2", "2", "1", "2", "2"], 
177 | ), 178 | (["", "5", "0", "3", "2", "5"], ["", "2", "1", "1", "1", "2"]), 179 | (["", "4", "5", "0", "2", "3", "1"], ["", "1", "1", "1", "1", "1", "1"]), 180 | (["", "2", "5", "5"], ["", "1", "2", "2"]), 181 | ( 182 | ["", "0", "2", "3", "2", "3", "0", "3"], 183 | ["", "2", "2", "3", "2", "3", "2", "3"], 184 | ), 185 | ( 186 | ["", "2", "1", "1", "2", "3", "3", "4"], 187 | ["", "2", "2", "2", "2", "2", "2", "1"], 188 | ), 189 | ] 190 | for x, y in examples: 191 | print(f"x: {x}") 192 | print(f"y: {y}") 193 | y_hat = run(x) 194 | print(f"y_hat: {y_hat}") 195 | print() 196 | -------------------------------------------------------------------------------- /programs/rasp/hist/hist_weights.csv: -------------------------------------------------------------------------------- 1 | feature,value,1,2,3,4,5,6 2 | attn_0_0_outputs,0,6.781757,25.452711,10.974162,-1.9199638,-10.4853525,-15.709011 3 | attn_0_0_outputs,1,1.4295937,3.2632368,-0.074068084,-2.4339116,-5.1740317,-3.098548 4 | attn_0_0_outputs,2,1.0853696,2.85033,-0.07023634,-2.7719102,-5.3003054,-3.8455026 5 | attn_0_0_outputs,3,1.5324218,3.3294716,0.415891,-2.2662394,-4.7209177,-3.3629472 6 | attn_0_0_outputs,4,1.0801277,2.88044,0.028874561,-2.579776,-4.811679,-3.6099398 7 | attn_0_0_outputs,5,1.0413198,3.1703262,-0.014626876,-2.2997146,-4.568741,-2.9599016 8 | attn_0_0_outputs,6,0.33147946,2.3486278,0.06572169,-2.0305748,-4.2188454,-2.249484 9 | attn_0_0_outputs,7,0.1623374,2.830696,-0.14196976,-1.2703604,-2.8180695,-0.34874895 10 | attn_0_1_outputs,0,4.262853,13.889445,7.3905044,-0.72937316,-8.37309,-13.166056 11 | attn_0_1_outputs,1,1.2386446,2.9108822,-0.2176869,-2.825408,-5.370257,-2.9436467 12 | attn_0_1_outputs,2,1.7371144,9.852318,4.1092057,-2.8085132,-9.59537,-14.0036745 13 | attn_0_1_outputs,3,0.71794397,3.1448169,-0.023318842,-2.7755547,-5.5583167,-3.5951939 14 | attn_0_1_outputs,4,2.2684147,10.615341,5.2562523,-2.1873116,-9.4488535,-14.030566 15 | 
attn_0_1_outputs,5,0.5881659,2.4700139,-0.08045172,-2.2570171,-4.1498766,-1.9420767 16 | attn_0_1_outputs,6,0.40211794,1.7297664,-0.050353132,-1.4864924,-2.592189,-0.6177777 17 | attn_0_1_outputs,7,0.4194881,2.907967,0.62582093,-1.6866827,-3.9042215,-1.8792101 18 | mlp_0_0_outputs,0,0.4938747,0.630038,0.23696063,-0.5110521,-0.7338477,-0.6524683 19 | mlp_0_0_outputs,1,0.74656296,0.19177268,-0.31644797,0.058880232,-0.66460264,-1.0663412 20 | mlp_0_0_outputs,2,0.5359965,-0.04372708,-0.7205332,-0.7282277,-0.7255958,-1.1605394 21 | mlp_0_0_outputs,3,4.233813,12.146829,4.1495814,-5.1848955,-11.641673,-16.112806 22 | mlp_0_0_outputs,4,0.20570962,-0.24562106,-0.19681151,-0.36250097,-0.3743104,-0.7396391 23 | mlp_0_0_outputs,5,0.8558908,0.5187872,0.0962603,0.30497876,-0.7910682,-0.49263448 24 | mlp_0_0_outputs,6,0.4841655,1.0118659,0.03025597,-0.36506364,-0.70360786,-0.21639934 25 | mlp_0_0_outputs,7,0.9191022,0.3703126,-0.28247193,-0.29374048,-0.6344728,-0.64356095 26 | num_attn_0_0_outputs,_,-4.868864,-27.139973,-8.78817,5.6797624,15.745415,20.68757 27 | num_attn_0_1_outputs,_,0.9234888,0.40377212,0.19646904,-0.013339977,-0.3078388,-7.9514294 28 | num_mlp_0_0_outputs,0,1.1347574,-0.2206074,-0.33900908,-0.525305,-0.8139806,-0.8801327 29 | num_mlp_0_0_outputs,1,3.9370484,-3.322479,-1.9119565,-0.60652006,-0.30610657,-1.285272 30 | num_mlp_0_0_outputs,2,1.6984854,-0.56601596,-0.68408555,-0.67591697,-0.3258158,-1.005294 31 | num_mlp_0_0_outputs,3,-7.021126,15.839726,6.340541,-3.105211,-10.3643875,-16.01029 32 | num_mlp_0_0_outputs,4,1.4550712,-1.0342308,-0.34201962,-1.1408417,-0.93610877,-0.28879276 33 | num_mlp_0_0_outputs,5,3.7206275,-2.0013118,-2.0068526,-1.2758466,0.019715928,-0.79015535 34 | num_mlp_0_0_outputs,6,18.228642,-20.21013,-3.2542965,-1.0904899,-0.45162725,-0.012080559 35 | num_mlp_0_0_outputs,7,2.5776606,-1.1733114,-1.8273598,-0.85842824,-0.93384117,-0.54284114 36 | ones,_,1.283879,6.372424,2.8256679,-1.4089221,-8.6447,-11.394332 37 | 
positions,0,0.5897627,0.16051057,0.7623805,0.24794468,0.1367697,0.36979273 38 | positions,1,1.0287309,4.5030293,2.3415422,-1.3173267,-6.8056574,-10.195892 39 | positions,2,1.083402,4.620112,2.3640404,-1.3567489,-6.838617,-10.438745 40 | positions,3,1.0730014,4.6261773,2.4860039,-1.2073922,-6.6122475,-10.18452 41 | positions,4,1.0189941,5.8566985,2.620293,-1.7836683,-7.4319153,-11.364978 42 | positions,5,0.96149987,4.631558,2.526294,-1.1215514,-6.515995,-10.108589 43 | positions,6,0.44539887,5.6350055,2.9353447,-1.8244816,-7.916871,-11.977934 44 | positions,7,1.1673726,6.778532,3.7333648,-0.52524173,-6.0611577,-9.702886 45 | tokens,0,1.217923,4.5158324,2.319617,-1.2275133,-4.5206556,-6.714174 46 | tokens,1,0.4282554,3.2384415,1.0357013,-2.5313835,-5.9769506,-13.27017 47 | tokens,2,1.4819847,5.0116262,2.8555121,-0.65765595,-4.2250724,-5.598877 48 | tokens,3,1.1304332,4.4012556,2.251413,-1.3463753,-4.7049932,-6.437009 49 | tokens,4,0.8615442,3.7823195,1.4800204,-2.1071358,-5.914414,-12.4524975 50 | tokens,5,1.0285466,4.464686,2.3432922,-1.1491289,-4.432571,-6.315694 51 | tokens,,0.83862156,0.89218545,-0.032585267,-0.8908109,-0.610661,-0.34020367 52 | tokens,,0.3811775,0.747888,-0.18793498,-0.6211176,-0.59679526,-0.5371528 53 | -------------------------------------------------------------------------------- /programs/rasp/most_freq/most_freq_weights.csv: -------------------------------------------------------------------------------- 1 | feature,value,0,1,2,3,4,5, 2 | attn_0_0_outputs,0,0.3180253,0.17051244,0.17996229,0.08486652,-0.050912812,-0.055793837,0.43272197 3 | attn_0_0_outputs,1,-0.096741244,0.7338011,-0.0095024835,0.24725218,-0.05952198,-1.2952329,-4.8797684 4 | attn_0_0_outputs,2,0.17302921,0.5287753,0.077450015,0.43879598,0.12759952,0.024711223,-1.2387171 5 | attn_0_0_outputs,3,-0.020483518,-0.4090684,-0.06907431,-0.019290976,-0.4361626,-0.98895913,1.5400033 6 | 
attn_0_0_outputs,4,-0.028248059,0.12039765,-0.18920611,-0.07074923,-0.24273163,-0.27256623,0.3491913 7 | attn_0_0_outputs,5,0.24701072,0.64526,0.2259116,0.72207314,-0.093196996,1.9177046,-3.6679096 8 | attn_0_0_outputs,,0.061711412,0.2944613,-0.06865534,-0.5166919,-0.14260983,-0.3558589,0.32799473 9 | attn_0_0_outputs,,0.23678741,-1.63671,0.25388855,-0.14997865,0.11181579,-2.2149167,9.153281 10 | attn_0_1_outputs,0,1.0899844,-0.28106454,-0.85135394,-0.026211262,-0.5590216,0.21339954,0.19288713 11 | attn_0_1_outputs,1,0.092373446,0.24267156,-0.40928265,0.15436336,0.45326802,0.33714578,0.60150456 12 | attn_0_1_outputs,2,-0.03783019,-0.033535108,3.3072128,-1.0111456,0.71332467,-0.11776653,-6.409337 13 | attn_0_1_outputs,3,0.23245741,0.1589569,-0.85276276,1.2885551,0.2456359,-1.3701653,0.6277164 14 | attn_0_1_outputs,4,0.2582553,0.33029148,-0.46015385,0.35623923,0.8750041,-1.0261837,1.0965881 15 | attn_0_1_outputs,5,0.112118945,0.1932354,-0.35270444,0.3343012,0.33635232,1.7447497,0.049526434 16 | attn_0_1_outputs,,0.25605395,0.05547662,0.16247602,-0.14924651,0.14146863,-0.53788525,0.24325536 17 | attn_0_1_outputs,,0.29997656,-0.10898747,-0.51443803,0.098860115,-0.02304782,0.09178379,0.74192613 18 | attn_0_2_outputs,0,0.27396625,0.2443406,-0.14687176,0.19030151,0.3444485,-0.0885761,0.026525198 19 | attn_0_2_outputs,1,-0.072723806,-0.6435364,0.2422663,-0.5205403,-0.01871688,-0.12267997,-0.45468605 20 | attn_0_2_outputs,2,-0.0013228728,0.30712268,-0.5249838,0.077301025,0.1839496,-0.0076866625,-0.754529 21 | attn_0_2_outputs,3,-0.10197616,-0.22210558,-0.057100624,-0.5848149,-0.028931309,-0.16512564,-0.57091314 22 | attn_0_2_outputs,4,0.25916776,0.40352133,0.33918947,0.13571715,-0.34599137,0.010477799,-0.6685143 23 | attn_0_2_outputs,5,0.6194994,0.941407,0.93209845,0.742026,0.68645924,0.18754262,-1.0428277 24 | attn_0_2_outputs,,0.14436923,-0.95419455,0.082035325,-0.03649257,-0.070194826,-0.072033234,0.5734176 25 | 
attn_0_2_outputs,,-0.30024177,0.7607851,-0.07830339,0.13043396,-0.15203752,0.47434998,1.0684925 26 | attn_0_3_outputs,0,0.4359758,0.39296976,0.3107209,0.27862722,0.4337573,-0.5773376,0.78378 27 | attn_0_3_outputs,1,0.28894436,2.4578593,0.23578313,0.49545267,0.65666926,-2.2483876,-4.0735903 28 | attn_0_3_outputs,2,0.09834935,0.06901251,-0.50735116,0.30153948,0.31839263,-0.32487082,-0.114381574 29 | attn_0_3_outputs,3,0.30135068,-1.0148153,0.27192265,-0.124534786,0.7126381,-2.1536736,2.5469096 30 | attn_0_3_outputs,4,0.30178407,0.8716293,0.07430823,0.7604845,0.55651057,-1.229208,0.050405312 31 | attn_0_3_outputs,5,-0.13249332,-1.1359549,-0.13182358,0.3849195,-0.26988852,3.4882991,-7.758176 32 | attn_0_3_outputs,,0.38542554,-0.20650521,0.45955607,-0.4878401,-0.031885352,-0.50210035,0.16015093 33 | attn_0_3_outputs,,0.28155896,-0.31953847,0.32690638,0.54727995,0.70288324,-2.5177412,3.5819783 34 | attn_1_0_outputs,0,2.0141475,0.34906653,0.68659526,0.9869916,-0.13294022,1.8037857,-10.206428 35 | attn_1_0_outputs,1,-0.19686982,0.73792344,-0.12540162,-1.3587084,0.16657266,1.8406917,1.5816158 36 | attn_1_0_outputs,2,-0.226673,-0.26789877,0.18990181,0.16948664,0.52418464,0.2592063,1.2878286 37 | attn_1_0_outputs,3,-0.83169764,3.5384831,-0.4879805,-1.1000174,-0.6881626,-0.77432674,-13.3651 38 | attn_1_0_outputs,4,0.15082398,-0.14441168,-0.005273094,0.23471557,0.20179209,0.5158661,0.38174176 39 | attn_1_0_outputs,5,0.13206911,-1.4834553,0.21140684,0.68273836,0.37051022,0.6227547,-2.3825295 40 | attn_1_0_outputs,6,-0.46533993,-0.44353464,-0.72359395,-0.74037266,-0.42077038,-0.08801137,3.6209009 41 | attn_1_0_outputs,7,-0.45181456,-10.881032,-0.46400476,0.98205274,-0.11754535,0.46539927,11.495142 42 | attn_1_1_outputs,0,0.9750074,-0.33286405,-0.09389574,-0.04933356,0.94795406,0.24440528,0.02822969 43 | attn_1_1_outputs,1,0.09181721,-0.108542785,-0.004596466,0.09085039,-0.0960271,-0.2939083,-1.4847285 44 | 
attn_1_1_outputs,2,0.13675968,-0.14381808,-0.14170228,1.3159441,0.84011686,0.187356,-0.2180252 45 | attn_1_1_outputs,3,0.03167638,-0.2564188,0.10374082,-0.030116782,0.37519884,-0.17905286,0.4855987 46 | attn_1_1_outputs,4,0.3155073,-0.27092186,-0.030149449,0.4447328,0.26189294,0.13343854,-0.2836621 47 | attn_1_1_outputs,5,0.79908526,0.33884925,0.5204918,0.563897,0.31888938,0.24198307,-3.097063 48 | attn_1_1_outputs,,0.12040514,-0.3541742,-0.37897912,0.33264196,-0.1644499,0.09382361,0.47280157 49 | attn_1_1_outputs,,-0.21088882,-0.25195375,-0.38574246,-0.26290575,-0.38151267,-0.6612043,3.2316494 50 | attn_1_2_outputs,0,1.4660153,-0.07058339,0.15579924,0.22429287,0.015527983,-0.008877302,-0.4786283 51 | attn_1_2_outputs,1,-0.32045385,0.7251625,-0.16099809,-0.6113556,-0.20735613,-0.34426436,0.03460376 52 | attn_1_2_outputs,2,-0.05991355,-0.26121002,1.5285293,-0.1514813,0.13110724,-0.5427844,-0.37770408 53 | attn_1_2_outputs,3,-0.17418517,-0.5064767,0.020557076,1.8331826,-0.19199504,-0.5358481,-0.2174345 54 | attn_1_2_outputs,4,-0.33506688,-0.4865637,-0.19673166,-0.35718244,1.8005149,-0.7867146,-0.5104927 55 | attn_1_2_outputs,5,0.1529805,-0.11731564,0.11252774,0.09463248,0.026155643,1.275788,-0.39444077 56 | attn_1_2_outputs,,-0.03210811,-0.1851439,-0.52887833,-0.252833,-0.6852915,-0.4575861,0.27375743 57 | attn_1_2_outputs,,-0.03430497,-0.044254147,-0.15155597,0.057937745,-0.022528501,-0.033426747,0.46648577 58 | attn_1_3_outputs,0,0.35898498,-0.20034088,0.020413859,-0.22355627,-10.102141,-0.33305383,7.6976376 59 | attn_1_3_outputs,1,1.0111092,-0.062783584,0.5327143,0.87491333,1.3438202,1.1383574,-2.6126447 60 | attn_1_3_outputs,2,0.081306025,-0.57257915,-0.17199379,0.23137064,1.004636,0.7501657,-3.3402092 61 | attn_1_3_outputs,3,0.4108505,-0.08137371,0.49178204,1.0274754,-0.049497884,1.4293927,-3.6926756 62 | attn_1_3_outputs,4,0.032954328,-0.26930884,0.1631624,0.6166075,-0.80901116,0.41795555,-0.96519506 63 | 
attn_1_3_outputs,5,0.14050587,-0.043274768,0.389469,0.94465864,0.34980577,1.3579348,-2.1161852 64 | attn_1_3_outputs,6,0.44501925,0.13667245,0.6853415,0.5725537,2.0187433,0.6294245,-5.953397 65 | attn_1_3_outputs,7,-0.46666646,0.5990393,0.025998702,0.41633928,0.02385631,1.1114172,-1.6757435 66 | attn_2_0_outputs,0,0.3495429,0.30337927,0.49805364,0.13225693,-0.2870291,0.27936482,-0.6221486 67 | attn_2_0_outputs,1,0.27308455,-0.14491963,-0.059935313,0.3436768,-0.5626455,0.080760516,-0.67031586 68 | attn_2_0_outputs,2,-0.11247145,0.52780163,-0.031945467,0.16022672,-0.08858324,-0.13487771,1.2465036 69 | attn_2_0_outputs,3,0.4697936,0.82808137,0.5196833,0.3312025,0.35191402,-0.13352305,0.16811055 70 | attn_2_0_outputs,4,0.58132875,0.8832289,0.71233314,0.5913085,0.2582655,0.5411773,-0.13924693 71 | attn_2_0_outputs,5,-0.5324829,-0.54294145,-0.47411755,-0.559129,-0.962662,-0.80526704,0.70581704 72 | attn_2_0_outputs,6,0.29960564,0.5047627,0.42610115,0.47605318,0.055260766,0.2752538,0.2869074 73 | attn_2_0_outputs,7,0.8061134,1.0037099,0.43810177,0.50335884,0.08938471,0.5953439,-0.61285174 74 | attn_2_1_outputs,0,0.35270405,-0.050330415,0.35786822,0.8349852,0.43520316,0.37781978,-1.8966578 75 | attn_2_1_outputs,1,-0.5173535,-1.0113736,-0.77803063,-12.798715,-0.7180023,-0.73949975,13.488985 76 | attn_2_1_outputs,2,0.036005918,-0.15727578,0.045522507,0.9319064,-0.048925404,-0.14085157,-1.1990162 77 | attn_2_1_outputs,3,1.0115867,0.43455648,0.93790376,1.9610963,0.88474375,0.71575737,-6.8169203 78 | attn_2_1_outputs,4,0.77685916,0.17810385,0.14974725,0.837453,-0.38813585,-0.10558918,-0.19106685 79 | attn_2_1_outputs,5,0.5566551,-0.01609987,0.4755912,1.9255587,0.53078216,0.33371648,-6.7548733 80 | attn_2_1_outputs,6,0.12431612,-0.3542981,0.30530918,1.5669754,0.23087248,0.45061296,-1.3779991 81 | attn_2_1_outputs,7,0.22629534,-0.6761686,0.6213678,1.038082,0.306376,-0.59195256,-0.19801751 82 | 
attn_2_2_outputs,0,-0.13583189,-0.03280338,-0.19608569,0.08558506,-3.1654773,0.06910204,1.4312361 83 | attn_2_2_outputs,1,-0.38239425,0.05366795,0.022488022,-0.042790532,-0.0076938574,-0.35101935,2.496905 84 | attn_2_2_outputs,2,0.4532444,0.4999768,0.5216906,0.4312385,0.6598157,0.57998353,-3.2020876 85 | attn_2_2_outputs,3,0.69540215,-0.018808544,0.92008996,0.85889035,1.0399714,0.9349368,-1.2531354 86 | attn_2_2_outputs,4,0.32867137,-0.498774,0.3323383,0.422421,0.10450748,0.53979826,-0.23186198 87 | attn_2_2_outputs,5,0.30699706,-0.16054757,0.245183,0.09240096,0.3129589,-0.011169057,-1.3021853 88 | attn_2_2_outputs,6,0.4382863,0.8850737,0.5132612,0.5126867,0.6615926,0.052526668,-1.1747515 89 | attn_2_2_outputs,7,0.798873,-3.1653268,0.8079976,0.46534437,0.99906564,-0.27384913,0.41713637 90 | attn_2_3_outputs,0,-0.0595329,-0.3432393,0.30217978,0.16495714,0.19219992,0.17431322,1.4355792 91 | attn_2_3_outputs,1,0.12711096,-1.4225363,0.22368737,-0.027527437,0.09366445,-0.040307615,0.01481341 92 | attn_2_3_outputs,2,0.58099073,0.3474417,-0.41153923,0.79462695,0.9205741,0.8575549,-1.4096668 93 | attn_2_3_outputs,3,-0.042766575,-0.65066993,0.20429829,-0.9746706,-0.12820214,-0.34311306,-0.5651443 94 | attn_2_3_outputs,4,0.4263014,0.39161858,0.58194137,0.5913219,-0.66568244,-0.35783538,0.029537905 95 | attn_2_3_outputs,5,0.7230335,0.28817055,0.8745937,0.90366805,0.88130885,-0.37258768,-0.67117524 96 | attn_2_3_outputs,,-0.2462671,-0.82254106,-0.20106415,-0.16820565,0.0874213,-0.04618327,-0.017906675 97 | attn_2_3_outputs,,0.1876988,0.5234487,0.4194045,0.589009,0.3397852,-0.13758777,-2.7441623 98 | mlp_0_0_outputs,0,-0.009067974,-0.10616362,0.5120841,0.30963194,0.2905615,0.5171235,1.3900716 99 | mlp_0_0_outputs,1,-0.18765688,0.7436373,0.18984891,0.51863587,0.27437514,-2.8250957,2.2061565 100 | mlp_0_0_outputs,2,0.1514967,-0.24174538,0.40851867,0.06300549,0.45437306,1.9693145,-1.5879657 101 | 
mlp_0_0_outputs,3,-0.074622594,2.0452924,0.074981324,2.5786343,0.48228234,-0.98552316,-1.702283 102 | mlp_0_0_outputs,4,0.87153715,0.71818954,1.6129097,1.2832383,1.3467029,1.3728735,-4.159039 103 | mlp_0_0_outputs,5,0.25284207,-0.92589784,0.52295667,0.26822376,0.86385244,0.7218063,0.9255813 104 | mlp_0_0_outputs,6,-2.2069793,-1.6042551,-1.9572843,-1.4469824,-1.7742672,-1.4493668,5.8365097 105 | mlp_0_0_outputs,7,0.6310457,-2.2726629,0.96292347,0.3914378,0.8041782,1.9205551,-7.1911926 106 | mlp_0_1_outputs,0,0.2533135,2.9234807,0.8514278,0.9629272,0.9832052,0.7691257,-2.316475 107 | mlp_0_1_outputs,1,0.28654918,0.3766166,1.2217463,1.2102123,0.11371473,1.5548589,-3.2720823 108 | mlp_0_1_outputs,2,1.2309315,-0.3645303,0.39764106,0.63220364,0.13651253,0.668529,-3.4863636 109 | mlp_0_1_outputs,3,-1.8739396,-2.2920024,-2.590256,-1.8237208,-2.8227699,-2.6384785,6.279832 110 | mlp_0_1_outputs,4,3.6876507,1.1458737,1.7822194,2.422588,4.042542,1.0609845,-10.931403 111 | mlp_0_1_outputs,5,0.87431866,1.2460302,-0.2996694,1.574731,-0.5186607,-0.72277695,-4.7037625 112 | mlp_0_1_outputs,6,-3.0101469,-0.37742805,-3.9014912,-3.186129,-4.1218615,-3.49612,8.171139 113 | mlp_0_1_outputs,7,2.8105667,0.41083074,2.3087406,1.4794773,0.5835391,0.5765365,-7.642529 114 | mlp_1_0_outputs,0,0.10424284,0.42691907,0.26822764,0.72388,0.4306211,0.3545007,-1.45301 115 | mlp_1_0_outputs,1,-0.29419023,0.08807841,-0.19642091,0.2297624,0.11144069,0.22908655,0.012775027 116 | mlp_1_0_outputs,2,0.033752218,0.5904094,0.016924301,0.65878856,0.552538,0.40060502,-0.8275471 117 | mlp_1_0_outputs,3,0.03518476,0.5058186,0.004764182,0.65987307,0.27559388,0.34288618,0.8714135 118 | mlp_1_0_outputs,4,-0.33149403,0.07573308,-0.010276801,0.5206794,0.03746909,0.11693972,-0.54683185 119 | mlp_1_0_outputs,5,-0.02862338,0.37228417,0.2534814,0.36578193,0.43761313,0.74797416,0.6001068 120 | mlp_1_0_outputs,6,0.67046475,0.42168856,0.17522931,0.9923386,0.42843175,0.6051585,-0.6304327 121 | 
mlp_1_0_outputs,7,-0.6101541,0.12701088,0.45372865,0.8725,0.27473047,-0.016846292,-0.33390883 122 | mlp_1_1_outputs,0,0.013466556,-0.31899878,-0.0581283,0.12199419,-0.28192335,0.0019597644,0.32551047 123 | mlp_1_1_outputs,1,0.86517906,0.99981546,1.1384677,0.9402987,0.7827801,0.7498811,-7.286347 124 | mlp_1_1_outputs,2,-1.2548337,-1.2306365,-1.2297429,-0.7032778,-0.9226823,-0.9777365,3.206642 125 | mlp_1_1_outputs,3,0.3976817,-0.037045293,0.17781861,0.5145008,-0.060972173,-0.15903838,0.661624 126 | mlp_1_1_outputs,4,0.33862308,-0.02482277,0.22182596,0.64739317,-0.026410017,0.43479386,1.3172387 127 | mlp_1_1_outputs,5,-0.15095389,-0.48228294,-0.20532566,0.103111684,-0.37304664,-0.33961606,0.4904485 128 | mlp_1_1_outputs,6,0.77520096,1.0291071,0.8356842,0.6138105,0.508028,0.81939375,-6.3209095 129 | mlp_1_1_outputs,7,0.35298768,0.14868644,0.1929121,0.8483676,0.13015799,0.1845815,-0.6845896 130 | mlp_2_0_outputs,0,0.22749007,0.010398724,-0.7143261,0.19049524,0.40374222,-0.18089859,1.913285 131 | mlp_2_0_outputs,1,0.43453947,0.48566645,1.2921165,0.27306423,0.7190616,0.64280427,-9.392651 132 | mlp_2_0_outputs,2,-0.010374933,0.6695905,0.8829944,0.09739681,0.1910442,-0.55804473,-0.24017237 133 | mlp_2_0_outputs,3,0.21176498,-0.6259095,1.0027347,-0.2894584,0.15444903,0.39754984,-0.62771827 134 | mlp_2_0_outputs,4,0.35422748,-0.03349688,1.0252547,0.1733238,0.36093882,-0.23943187,0.3938007 135 | mlp_2_0_outputs,5,-1.4427056,0.19055358,0.7897258,0.31336093,-0.007993803,0.31204116,-1.0540504 136 | mlp_2_0_outputs,6,0.66334516,-0.1595391,1.3743109,0.043930616,0.16582729,-0.17752126,-1.0506142 137 | mlp_2_0_outputs,7,-0.81803083,-1.0996387,-1.3154311,-1.6927465,-1.0877206,-1.2953286,6.2696586 138 | mlp_2_1_outputs,0,0.20822705,0.8802053,0.08234023,0.29908177,0.044410724,0.18044509,-1.9308993 139 | mlp_2_1_outputs,1,-0.45252222,0.64726084,-0.39025173,-0.12514322,-0.5468172,-0.4661018,0.12722749 140 | 
mlp_2_1_outputs,2,0.3255769,0.23384853,-0.41722363,0.724245,0.47933605,0.16578805,-0.4978126 141 | mlp_2_1_outputs,3,-0.18909135,0.21740477,-0.1728526,0.08816649,0.33721173,0.0031366053,0.33242244 142 | mlp_2_1_outputs,4,0.31890598,1.205369,0.26985618,0.43814164,0.012860747,0.028930334,0.41829804 143 | mlp_2_1_outputs,5,0.26172248,1.438389,0.345994,0.60735875,0.22993265,0.39017054,-2.027972 144 | mlp_2_1_outputs,6,0.5005941,0.60231215,0.81867206,0.4571476,0.137451,-0.1562085,-0.7049604 145 | mlp_2_1_outputs,7,0.73108,2.107512,0.88618535,0.81509197,0.54469866,0.5268218,-2.7247245 146 | num_attn_0_0_outputs,_,0.49675706,0.76701194,0.8319468,1.0188137,-3.46674,1.3313375,1.2141222 147 | num_attn_0_1_outputs,_,-4.705128,0.20360513,0.05153054,0.18922624,2.8429217,0.7594758,0.08416794 148 | num_attn_0_2_outputs,_,0.32272223,0.26861843,0.27784988,-0.007126636,0.30486012,-0.51013726,0.45339686 149 | num_attn_0_3_outputs,_,0.35131818,0.15362968,-0.026241664,0.024230676,0.7753488,-0.36277744,0.23196767 150 | num_attn_1_0_outputs,_,-0.17899519,-0.22110243,1.955975,-0.33664736,-0.18993665,-0.42071265,-3.2977908 151 | num_attn_1_1_outputs,_,0.096573666,0.041719396,0.3011042,-0.7403994,0.24944292,0.35339233,0.40832838 152 | num_attn_1_2_outputs,_,0.307874,0.77904606,-3.6890242,0.57999134,0.07966587,1.0292304,2.588211 153 | num_attn_1_3_outputs,_,0.1556463,-0.88464093,-0.009464315,1.2894713,-0.006048728,0.12710084,0.042141438 154 | num_attn_2_0_outputs,_,3.2238479,-0.39403713,-0.5007616,-0.38034695,-1.5038663,-0.5157612,-0.4220676 155 | num_attn_2_1_outputs,_,0.09123767,0.06927917,0.071765244,0.06278864,0.03571881,-1.0034105,0.22208074 156 | num_attn_2_2_outputs,_,0.49392816,0.30212557,0.5877162,0.3936646,0.7463011,0.91613317,-0.27264255 157 | num_attn_2_3_outputs,_,0.59951186,1.5315337,0.45857117,-2.2965033,0.47967055,0.9584901,-0.25912756 158 | num_mlp_0_0_outputs,0,0.75295377,0.22206154,0.4430251,0.43609834,0.3002954,0.14143573,-0.9965674 159 | 
num_mlp_0_0_outputs,1,0.13362123,-0.27709126,0.11786564,0.0842753,0.057983357,-0.20576423,-0.17937659 160 | num_mlp_0_0_outputs,2,0.77560294,0.23398742,0.38785017,0.09531838,0.36529553,0.0049865237,-2.6353512 161 | num_mlp_0_0_outputs,3,0.67882055,0.47041214,0.49682024,0.22344151,0.27797723,0.384288,-0.78773314 162 | num_mlp_0_0_outputs,4,0.36207622,-0.12504679,0.15990801,0.038606238,-0.029018853,-0.017407462,-0.3937458 163 | num_mlp_0_0_outputs,5,-5.4784355,-0.21962734,-0.05829186,0.06191359,0.76395345,-0.15172787,5.640328 164 | num_mlp_0_0_outputs,6,0.6987904,0.19056033,0.45902565,0.0974688,0.20158903,0.2252314,-0.5025556 165 | num_mlp_0_0_outputs,7,-0.66295034,0.33240935,0.7547661,0.07789776,-1.1808933,0.11714694,-1.6093485 166 | num_mlp_0_1_outputs,0,-0.4456775,-0.3319486,0.259621,0.04547681,-3.9428256,0.30970556,3.493646 167 | num_mlp_0_1_outputs,1,-0.060211174,0.5636933,0.063009106,0.058421217,1.4696409,-0.012732882,-3.97732 168 | num_mlp_0_1_outputs,2,-0.112916395,0.7270722,0.22726461,0.24413253,0.8066552,0.011055177,-0.95724344 169 | num_mlp_0_1_outputs,3,1.633666,0.42186847,0.36930156,-0.119241916,0.40250668,-0.55435634,-3.3496842 170 | num_mlp_0_1_outputs,4,-0.4858897,0.43768635,-0.16505812,-0.17672735,0.9722587,-0.056002922,-1.3548676 171 | num_mlp_0_1_outputs,5,-0.097139664,0.3772895,0.27854103,0.25372255,0.8777077,0.24463934,-0.93909496 172 | num_mlp_0_1_outputs,6,0.33189073,0.012857659,-0.14545752,0.1099122,1.541078,0.07001639,-0.91754353 173 | num_mlp_0_1_outputs,7,0.41338253,0.3731902,-0.07001816,0.79229236,-0.9019284,0.11754089,2.5209103 174 | num_mlp_1_0_outputs,0,0.24915189,-0.13592264,0.28131127,0.31066325,0.21636605,0.18062419,-1.4040216 175 | num_mlp_1_0_outputs,1,0.1564762,0.08252395,-5.7924643,-0.192247,0.33157474,-0.009017468,2.942537 176 | num_mlp_1_0_outputs,2,0.13001019,-0.31609473,0.84513766,0.34132242,0.5279612,0.41859335,-6.285437 177 | num_mlp_1_0_outputs,3,0.4110137,-0.054942466,-0.1518215,-0.08841369,0.3400565,0.11748088,-0.1709277 
178 | num_mlp_1_0_outputs,4,0.23019612,-0.12827308,-0.05825717,-0.0017411346,-0.12765515,-0.08880548,-0.9720372 179 | num_mlp_1_0_outputs,5,0.6083281,-0.21133165,-1.5179089,0.7018768,-0.3144238,0.19359513,0.51440924 180 | num_mlp_1_0_outputs,6,0.01387636,-0.01365367,-0.14135487,0.027097523,0.001307688,0.11765708,0.1378213 181 | num_mlp_1_0_outputs,7,0.067421116,0.3388032,-0.16557097,-0.12955894,-0.19303417,0.02512272,0.18632148 182 | num_mlp_1_1_outputs,0,0.3679926,0.43395373,0.39618903,0.30078262,0.29004973,0.1358351,0.06785458 183 | num_mlp_1_1_outputs,1,0.20307627,0.16571133,-0.034833442,0.056774694,-0.28404135,-0.0896154,-0.32265937 184 | num_mlp_1_1_outputs,2,0.13026774,0.061297063,0.06480035,0.068802625,-0.009410563,-0.18956332,-0.45669928 185 | num_mlp_1_1_outputs,3,0.23407696,0.14676198,0.011613104,-0.020378916,-0.018320754,-0.2158755,-0.41205496 186 | num_mlp_1_1_outputs,4,0.51550543,0.19345349,-0.29995015,-0.21714196,-0.05119475,-0.11259293,-0.39577466 187 | num_mlp_1_1_outputs,5,0.7264443,0.28373966,0.18676305,0.5019726,0.27477577,-0.22090729,-0.0033221303 188 | num_mlp_1_1_outputs,6,0.09651393,0.122992426,-0.09937113,-0.0116935,-0.07023688,-0.034174327,-0.34691346 189 | num_mlp_1_1_outputs,7,0.4628552,-0.038354713,0.18961802,0.0764211,-0.18719746,-0.06302883,-0.23135951 190 | num_mlp_2_0_outputs,0,1.3870856,-0.25443602,0.1545994,0.07868433,1.3484426,0.5581853,-1.5508221 191 | num_mlp_2_0_outputs,1,-0.6188547,-0.28784364,0.18100783,-0.7676343,-0.15114093,0.17485546,-0.16489339 192 | num_mlp_2_0_outputs,2,1.650306,0.25490463,0.73042595,0.5314448,0.4455995,0.46853003,-4.48665 193 | num_mlp_2_0_outputs,3,-0.53009546,-0.12033514,-0.21350926,-0.053965583,-0.014094976,0.3029367,0.16062967 194 | num_mlp_2_0_outputs,4,-0.35834578,-0.14751779,-0.32184893,-0.22511242,-0.2735182,0.02457167,-0.18555361 195 | num_mlp_2_0_outputs,5,-0.36402833,0.74258685,0.9334475,1.3147615,0.6026904,0.7450087,-0.85253215 196 | 
num_mlp_2_0_outputs,6,0.43582806,-0.7167769,-0.2027412,-0.5209631,-1.3333701,-0.1770416,-4.8763123 197 | num_mlp_2_0_outputs,7,-5.2162914,-1.7618054,-1.4699253,-1.6181387,-0.6872477,-1.2078984,8.595443 198 | num_mlp_2_1_outputs,0,0.38339245,0.74561024,1.8125025,0.34722832,0.57123804,0.35547578,-5.5821314 199 | num_mlp_2_1_outputs,1,0.15095738,0.10469516,1.0636414,0.34179947,0.79633814,0.42729956,-3.5103517 200 | num_mlp_2_1_outputs,2,-0.47540858,-0.09333152,-0.9867091,-0.23736648,-0.22757289,0.15208386,0.045130484 201 | num_mlp_2_1_outputs,3,-0.4281402,-0.18589725,-1.0558151,-0.16013376,0.048908155,-0.041572504,0.64169294 202 | num_mlp_2_1_outputs,4,-0.4590558,-0.4067188,-1.0003194,0.0051985644,0.21098104,-0.287274,0.3842943 203 | num_mlp_2_1_outputs,5,-0.07419524,0.20917611,-0.29637602,0.33867353,0.06768416,0.12848133,-0.05852556 204 | num_mlp_2_1_outputs,6,-1.1355742,-0.85389644,-3.9056017,-0.29252398,-0.52056974,-0.3573991,7.055144 205 | num_mlp_2_1_outputs,7,0.12410815,0.33784002,0.16031556,0.3937437,0.3059639,0.27686897,0.08846511 206 | ones,_,0.021529218,-0.15176989,0.2739342,0.42407322,0.85170376,0.6750486,0.039232593 207 | positions,0,-0.56147146,-0.12002305,0.2413873,0.1853156,-0.5742414,0.08497386,0.07996443 208 | positions,1,3.0675466,3.2321472,4.122518,0.2724716,3.5735428,3.0197644,-16.798847 209 | positions,2,2.4287295,1.9609035,1.7737118,3.7877054,1.7662969,1.9576112,-6.0926604 210 | positions,3,2.299956,-0.8758886,2.310295,2.8263414,1.6215533,2.39478,-6.9706354 211 | positions,4,-1.049784,1.8572471,-0.75735337,-1.4148265,-1.0481493,-1.4053416,2.2860672 212 | positions,5,-3.3229315,-1.5263214,-2.91694,-3.5551183,-4.0157175,-3.1399124,5.1658487 213 | positions,6,-7.9361725,-8.840564,-7.397052,-9.724115,-9.180236,-3.3155594,10.632461 214 | positions,7,-7.9004083,-13.131187,-9.459313,-10.873171,-9.7958555,-8.555323,11.622234 215 | tokens,0,0.7939464,0.02564311,0.10964918,0.23205447,-0.4861161,-0.0044119004,-0.5045989 216 | 
tokens,1,0.13817202,0.7952534,-0.108660065,0.011065895,-0.3787378,-0.07682034,-0.5505967 217 | tokens,2,0.21265996,0.10762084,0.8517744,0.086404786,0.06522274,-0.738799,-0.5161016 218 | tokens,3,0.009029636,-0.1933909,-0.23239014,1.1473926,-0.46348011,0.24348621,-0.5114875 219 | tokens,4,0.4256177,-0.043745376,-0.28037384,-0.1920835,1.4755142,-0.97692513,-0.5098483 220 | tokens,5,-0.042473715,-0.19566593,-0.2469844,0.31330782,-0.5277211,1.0684317,-0.44248924 221 | tokens,,-0.23496482,0.3631881,0.43677837,0.011330187,-0.40754265,-0.8863597,0.15948972 222 | tokens,,-0.056050863,0.29581282,-0.1711932,0.0840109,0.20864172,0.31114328,0.17907692 223 | -------------------------------------------------------------------------------- /programs/rasp/reverse/reverse_weights.csv: -------------------------------------------------------------------------------- 1 | feature,value,0,1,2,3,4 2 | attn_0_0_outputs,0,2.742066,-1.7750658,-1.5919067,-1.4220088,-0.5447206 3 | attn_0_0_outputs,1,-1.7905326,3.2459338,-1.0251777,-0.35510015,-1.285745 4 | attn_0_0_outputs,2,-1.3913494,0.42532334,3.6383674,-1.5303861,-0.04807833 5 | attn_0_0_outputs,3,0.49622118,-0.41630313,-1.4269036,3.259316,-1.6067595 6 | attn_0_0_outputs,4,-0.6003677,-0.21356046,-2.0015326,-0.8214214,2.7384455 7 | attn_0_0_outputs,,-0.21139087,-0.061902717,2.2465034,1.0147868,-1.6599561 8 | attn_0_0_outputs,,0.92124563,0.5974984,0.154064,0.34046212,0.059019208 9 | attn_0_0_outputs,,0.24552345,0.39865023,0.20082256,-0.15461971,0.67986333 10 | attn_0_1_outputs,0,7.673302,-3.501045,-5.5389047,-4.3190165,-1.769243 11 | attn_0_1_outputs,1,-1.9890903,7.9207335,-3.289247,-3.0845697,-2.3067732 12 | attn_0_1_outputs,2,-4.705924,-3.0480518,7.011647,-0.6278108,-3.4870532 13 | attn_0_1_outputs,3,-3.1684492,-2.9271348,-2.6957874,7.3821106,-2.4872382 14 | attn_0_1_outputs,4,-1.5438712,-3.585505,-2.7780457,-3.6590195,6.381682 15 | attn_0_1_outputs,,2.4767106,-0.34512123,-2.5959194,1.4119138,-3.1488388 16 | 
attn_0_1_outputs,,0.1414338,-0.32987177,-0.14424743,0.121554226,0.069688104 17 | attn_0_1_outputs,,-0.67165095,-0.29987463,0.9411148,-1.1158872,0.9772637 18 | attn_0_2_outputs,0,2.8864198,0.5736399,-1.7944727,1.275154,-3.7867603 19 | attn_0_2_outputs,1,1.286558,3.8100312,0.2034576,-4.1266513,-2.4538345 20 | attn_0_2_outputs,2,-3.55232,-2.6151497,2.8218007,0.27975446,0.9032074 21 | attn_0_2_outputs,3,2.0676212,-1.8931607,-1.0893198,2.6283581,-3.3388352 22 | attn_0_2_outputs,4,-2.601808,1.4966966,-5.524055,-3.3436193,6.786246 23 | attn_0_2_outputs,,-1.2964387,-1.184649,3.191794,0.26659352,-0.07239697 24 | attn_0_2_outputs,,-0.00678445,-0.12167784,-0.28402784,-0.1283455,0.53786683 25 | attn_0_2_outputs,,0.46426013,-0.44629773,0.791248,0.34926856,-0.41215342 26 | attn_0_3_outputs,0,-0.1498222,1.2560209,0.19964494,0.6762862,-1.121879 27 | attn_0_3_outputs,1,0.64040285,1.9587319,1.4074128,-3.5538082,0.9837679 28 | attn_0_3_outputs,2,-1.5384316,-0.65306073,2.1958265,0.9510098,-3.2979362 29 | attn_0_3_outputs,3,-0.81305355,-1.5002078,0.020203998,0.63694465,0.60759926 30 | attn_0_3_outputs,4,-1.5025858,1.7346878,-1.3115505,1.4721668,1.8840551 31 | attn_0_3_outputs,,4.6110573,0.39784852,-4.7847724,0.5624679,-2.1006687 32 | attn_0_3_outputs,,-0.4865959,-0.75623137,-0.559213,-0.6034367,-0.19291799 33 | attn_0_3_outputs,,-0.6616858,-0.70004207,1.4663161,-2.5758874,0.58236104 34 | attn_1_0_outputs,0,5.14012,-2.0321715,-3.1873803,-2.161212,-2.924637 35 | attn_1_0_outputs,1,-1.1760738,6.243624,-2.292563,-2.3656707,-1.8633323 36 | attn_1_0_outputs,2,-1.8438314,-2.2223895,7.2542777,-1.7975602,-2.8767316 37 | attn_1_0_outputs,3,-1.4480008,-2.8605947,-2.552013,6.218815,-2.0353441 38 | attn_1_0_outputs,4,-1.8545668,-2.4011686,-2.4577968,-0.8337147,5.6437707 39 | attn_1_0_outputs,,1.4606192,-0.65165573,-0.96397144,1.39993,-2.030799 40 | attn_1_0_outputs,,0.22300358,-0.5646053,-0.41895154,0.40009347,-0.45633748 41 | attn_1_0_outputs,,-0.5169577,-0.574157,0.06152434,-0.45967716,0.7860707 
42 | attn_1_1_outputs,0,-1.729071,0.3904098,1.063312,0.9541465,0.64347845 43 | attn_1_1_outputs,1,0.14294437,0.5665816,1.1513399,1.1456755,-0.45402047 44 | attn_1_1_outputs,2,-0.47264585,-1.2265207,0.73496914,-0.27962002,1.5110021 45 | attn_1_1_outputs,3,0.55221164,0.74901503,-0.2234499,0.46550083,0.13538252 46 | attn_1_1_outputs,4,-0.887349,-2.3832846,-0.090858944,-0.27379212,2.0423574 47 | attn_1_1_outputs,,0.44976777,0.99956155,-0.25536492,0.0694434,-0.45211685 48 | attn_1_1_outputs,,0.30245078,0.1915722,-0.34525523,0.6264211,0.28215456 49 | attn_1_1_outputs,,0.64166284,-0.9103854,0.21581404,-0.1512637,-0.53048915 50 | attn_1_2_outputs,0,-6.9120216,2.9334385,3.5168996,1.5317044,3.1637955 51 | attn_1_2_outputs,1,1.3649912,-8.357152,1.0929214,2.8909745,2.2582288 52 | attn_1_2_outputs,2,1.6044966,2.746573,-9.02028,2.4642735,2.3857315 53 | attn_1_2_outputs,3,1.3656769,3.410438,1.7813203,-7.0019917,2.6581526 54 | attn_1_2_outputs,4,1.7743725,2.7270765,2.717991,2.2948518,-7.542405 55 | attn_1_2_outputs,,-0.2602812,-0.3624655,1.7670982,-0.119142905,-1.4537865 56 | attn_1_2_outputs,,-0.4052089,-0.016108118,0.87777245,0.1692151,0.259325 57 | attn_1_2_outputs,,-0.61002606,0.6863391,1.1960802,-0.972229,-0.3090549 58 | attn_1_3_outputs,0,3.147749,-2.74755,-2.6375554,-1.7992983,-1.8083935 59 | attn_1_3_outputs,1,-0.3166353,3.135929,-2.0476818,-1.5201313,-0.9221027 60 | attn_1_3_outputs,2,-0.82175285,-1.7785048,4.262148,-1.098342,-1.320515 61 | attn_1_3_outputs,3,-0.48976257,-1.4941785,-2.4922595,3.688606,-0.7487961 62 | attn_1_3_outputs,4,-2.8958664,0.9745885,-1.7469578,0.11342506,2.1992576 63 | attn_1_3_outputs,,0.9585229,-0.021460164,-0.068051144,-0.5773519,-0.15273045 64 | attn_1_3_outputs,,-0.08786828,0.34689122,0.3743438,0.2036276,0.04196969 65 | attn_1_3_outputs,,-0.5865496,-0.08093558,1.0197412,0.5368188,-0.35131058 66 | attn_2_0_outputs,0,-6.404271,3.213514,3.7119424,3.6321664,2.465991 67 | attn_2_0_outputs,1,1.3457497,-6.91716,2.117783,3.216933,2.2687566 68 | 
attn_2_0_outputs,2,1.7306777,1.9339283,-6.4193964,2.4106104,1.2669448 69 | attn_2_0_outputs,3,1.3426539,4.61214,2.2282271,-6.918863,2.1501 70 | attn_2_0_outputs,4,1.4316176,2.9336243,3.1591465,1.782481,-7.2788053 71 | attn_2_0_outputs,,0.1130622,0.17945175,-0.29610068,-0.18893324,0.092504 72 | attn_2_0_outputs,,0.07159574,-0.3103722,0.110794805,-0.79871124,-0.7539021 73 | attn_2_0_outputs,,0.7954956,-0.26322004,-0.07522555,0.038551275,-0.50479794 74 | attn_2_1_outputs,0,6.998647,-3.457301,-4.5451207,-1.8305441,-1.8368263 75 | attn_2_1_outputs,1,-0.40058747,6.4474087,-1.4574099,-2.18984,-1.7002127 76 | attn_2_1_outputs,2,-1.9955103,-2.4883533,7.227968,-2.8376138,-1.9176865 77 | attn_2_1_outputs,3,-1.350527,-2.7862217,-2.4625497,6.6744094,-1.3645524 78 | attn_2_1_outputs,4,-3.3767576,-0.8977402,-2.9630077,-0.49530327,5.3836374 79 | attn_2_1_outputs,,0.20378506,-0.5800168,-1.106855,0.19926731,-0.06131589 80 | attn_2_1_outputs,,-0.43978697,-0.21296938,-0.22513205,0.6500348,0.30841365 81 | attn_2_1_outputs,,-0.547537,0.8714378,0.33708888,-2.0015378,1.5639987 82 | attn_2_2_outputs,0,4.7300496,-2.609322,-3.414926,-2.652299,-1.7123688 83 | attn_2_2_outputs,1,-1.0747069,4.541831,-1.8588012,-1.2684554,-0.17224875 84 | attn_2_2_outputs,2,-1.5746042,-1.28898,4.3767176,-1.0493795,-2.8320768 85 | attn_2_2_outputs,3,-2.344087,-1.2606856,-0.91734046,5.6075225,-2.2346644 86 | attn_2_2_outputs,4,-2.6133668,-0.18721344,-1.6325711,-1.149275,3.6483755 87 | attn_2_2_outputs,,1.2496438,-0.41678748,0.087676175,-0.6057794,-0.044405453 88 | attn_2_2_outputs,,-0.15122709,-0.2583006,0.39665553,0.003170321,-0.016708078 89 | attn_2_2_outputs,,-0.7736017,-0.43213943,0.18013555,-0.09826307,0.057805218 90 | attn_2_3_outputs,0,-2.7729218,1.4200424,1.4600606,0.34230572,0.84415036 91 | attn_2_3_outputs,1,0.4400618,-1.756905,1.4549123,1.019148,0.4850436 92 | attn_2_3_outputs,2,0.9549057,1.2824273,-2.0410612,0.9188365,1.6863561 93 | 
attn_2_3_outputs,3,-0.091469035,1.2047702,1.370021,-0.89607966,-0.51357347 94 | attn_2_3_outputs,4,0.466892,0.8118219,0.28800988,1.3429086,-3.2995336 95 | attn_2_3_outputs,,0.24129342,-0.86307657,0.1745238,-2.054632,2.8129969 96 | attn_2_3_outputs,,-0.11414052,0.05034334,0.09903593,0.10165983,0.5136056 97 | attn_2_3_outputs,,0.4432127,0.15918905,0.12439812,0.3750773,-0.17403647 98 | mlp_0_0_outputs,0,-5.1317487,-6.70734,6.037307,4.5852356,4.0092926 99 | mlp_0_0_outputs,1,-0.257455,0.35274887,0.24112554,-0.42148477,0.024651978 100 | mlp_0_0_outputs,2,-1.8256108,4.1555786,-1.9008249,1.3808151,-1.2630675 101 | mlp_0_0_outputs,3,4.6051683,-1.8804812,3.9409556,-5.850926,-0.88602555 102 | mlp_0_0_outputs,4,1.0506113,-3.9949007,0.13467729,0.70601964,-0.3138636 103 | mlp_0_0_outputs,5,0.5380068,-0.7315861,-1.4061457,1.1938375,-1.1255091 104 | mlp_0_0_outputs,6,-0.14030905,2.8830237,-2.683621,-0.16997392,0.47404832 105 | mlp_0_0_outputs,7,0.3876715,-0.08891349,0.8728597,1.0125917,0.11347792 106 | mlp_1_0_outputs,0,2.3266206,-6.2295556,-4.0005665,2.3041046,0.21257897 107 | mlp_1_0_outputs,1,-0.4556595,-0.805607,-1.6509348,1.7824169,1.501474 108 | mlp_1_0_outputs,2,-5.792818,9.575128,3.5130193,-7.9973984,5.002579 109 | mlp_1_0_outputs,3,-1.2260249,0.9880918,-2.378248,2.0664575,1.2702942 110 | mlp_1_0_outputs,4,4.1306114,1.7542486,5.0177255,-0.94854623,-13.307635 111 | mlp_1_0_outputs,5,-0.06441002,0.36633462,0.41651008,0.1121977,-0.19381374 112 | mlp_1_0_outputs,6,0.28442925,0.3170138,1.0178398,-0.13551465,0.05566804 113 | mlp_1_0_outputs,7,-0.30202946,0.3719944,0.7272587,0.13975187,-0.37298957 114 | mlp_2_0_outputs,0,0.45908874,-0.01754609,0.97707564,-0.9511417,0.5083081 115 | mlp_2_0_outputs,1,0.18382199,-0.19358855,0.4822291,-0.80399686,-0.1716763 116 | mlp_2_0_outputs,2,1.5142181,-0.02708238,-0.86326504,0.48660412,-0.061133042 117 | mlp_2_0_outputs,3,-1.4314252,0.31027403,1.2170343,-0.10667357,0.94278526 118 | 
mlp_2_0_outputs,4,-0.0568396,-0.4547981,1.945643,-1.6712406,0.698788 119 | mlp_2_0_outputs,5,0.47281817,-0.15489303,0.5081915,-0.3702506,0.28318653 120 | mlp_2_0_outputs,6,0.30037868,0.22154446,0.37623194,0.10769829,0.31810036 121 | mlp_2_0_outputs,7,0.1804857,0.27696744,-0.5805316,1.504735,-0.4786584 122 | num_attn_0_0_outputs,_,0.8465965,2.0133536,-2.2380738,-1.4228604,1.0347323 123 | num_attn_0_1_outputs,_,-0.8386918,-0.38922623,-2.324954,0.07236001,1.3036776 124 | num_attn_0_2_outputs,_,-0.117477916,-2.304579,1.9196264,-1.8928717,-0.46572107 125 | num_attn_0_3_outputs,_,0.28793186,-1.0810422,-1.1591295,0.8848445,0.37602443 126 | num_attn_1_0_outputs,_,0.67058,-0.86042327,-1.924298,-1.2025291,3.5931141 127 | num_attn_1_1_outputs,_,-1.2113587,0.8658644,0.23557991,-0.5923407,0.55709636 128 | num_attn_1_2_outputs,_,0.9216758,0.9326565,-4.265963,0.8197906,1.7644495 129 | num_attn_1_3_outputs,_,2.0054622,-1.7789265,-0.87892455,-2.7376547,2.2104425 130 | num_attn_2_0_outputs,_,-1.8372577,0.24842429,-1.3088604,0.64833593,1.3509363 131 | num_attn_2_1_outputs,_,-0.108283974,-1.7611748,3.154574,3.5453818,0.70397145 132 | num_attn_2_2_outputs,_,3.0538957,-1.0975236,2.4850554,-1.7636111,-0.890746 133 | num_attn_2_3_outputs,_,0.43507352,4.196319,3.1197977,-2.9802363,-0.2592518 134 | num_mlp_0_0_outputs,0,-0.64667314,0.17904095,0.39200422,-0.0028258313,-0.37290496 135 | num_mlp_0_0_outputs,1,0.2116071,-0.017157368,0.9101409,0.45278645,0.056753635 136 | num_mlp_0_0_outputs,2,0.061665036,0.41990674,0.6671413,0.40301383,0.24431042 137 | num_mlp_0_0_outputs,3,-0.4006326,-0.4738044,-0.11009932,-0.14682105,-0.36005288 138 | num_mlp_0_0_outputs,4,0.4470174,-0.936997,-1.4516735,0.07075252,-0.16868772 139 | num_mlp_0_0_outputs,5,-0.5206594,-0.38733086,0.040900033,-0.473343,-0.33152476 140 | num_mlp_0_0_outputs,6,-0.22264968,0.5538271,0.2416182,-0.031259954,-0.3844373 141 | num_mlp_0_0_outputs,7,-0.1432738,-0.064719036,0.12572391,-0.20695664,0.2349148 142 | 
num_mlp_1_0_outputs,0,0.23161256,-0.17896621,1.3701653,0.027287895,-1.3479851 143 | num_mlp_1_0_outputs,1,-0.29446986,-0.47882456,0.77442497,0.34317648,-0.5066082 144 | num_mlp_1_0_outputs,2,-0.37243122,0.5159971,-2.766263,0.03633871,2.1990325 145 | num_mlp_1_0_outputs,3,-0.2770934,-0.053324725,0.62508935,0.08229346,-1.4606029 146 | num_mlp_1_0_outputs,4,0.47045127,0.05848788,1.1008859,0.31856912,-0.87950534 147 | num_mlp_1_0_outputs,5,0.03743611,-0.21205485,0.7981379,0.031759117,-0.23362987 148 | num_mlp_1_0_outputs,6,-0.120329894,0.09074645,1.4699514,-0.08971425,-1.3956728 149 | num_mlp_1_0_outputs,7,0.41638732,0.2154947,1.4405341,0.6774053,-0.46320668 150 | num_mlp_2_0_outputs,0,0.22702669,0.38019413,0.5690051,0.97963583,0.35750404 151 | num_mlp_2_0_outputs,1,-0.8008013,0.08042168,-0.41324988,-0.26802236,-0.3302354 152 | num_mlp_2_0_outputs,2,-0.64541245,-0.080827184,-0.18740739,-0.22994758,-0.24330296 153 | num_mlp_2_0_outputs,3,0.037600122,0.50564003,-0.23816861,0.11780354,0.2502989 154 | num_mlp_2_0_outputs,4,-0.012841126,0.39657962,-0.119696006,0.11229075,-0.10252234 155 | num_mlp_2_0_outputs,5,-0.7187831,-0.115768135,-0.15021832,0.22964966,-0.20723926 156 | num_mlp_2_0_outputs,6,-0.22354771,0.23253189,0.22618711,-0.14161181,-0.4494641 157 | num_mlp_2_0_outputs,7,-0.43640983,0.28443414,-0.12486136,-0.008818966,0.2645296 158 | ones,_,0.55620617,-0.5459692,0.11906973,0.0029487282,-0.60756934 159 | positions,0,0.48431224,-0.07508403,-0.15782595,0.08621112,-0.2329988 160 | positions,1,-1.0866297,0.20545797,0.020345941,0.27752703,-0.0058798143 161 | positions,2,-0.1897344,-0.6670251,0.017169373,-0.7351196,2.036752 162 | positions,3,-0.20501482,0.57732826,1.1450588,1.1435491,-1.1068201 163 | positions,4,0.12606427,0.026706036,-0.29324046,0.31004795,0.5862827 164 | positions,5,-0.7492487,0.8921466,-1.4819856,0.54742503,0.501571 165 | positions,6,-0.6667614,1.901696,1.1710435,0.22811733,-1.1982802 166 | 
positions,7,0.51935697,0.3317305,-0.9227535,0.20672892,-0.5248978 167 | tokens,0,-0.47813687,0.12706348,-0.04924151,0.548112,0.14897579 168 | tokens,1,-0.30380446,0.8807152,-0.2979864,0.1732994,-0.020828225 169 | tokens,2,-0.21944927,0.18884198,0.6118489,0.252533,-1.4074645 170 | tokens,3,-0.19298886,0.11001143,-0.3632103,0.2215297,-0.3719192 171 | tokens,4,-0.076143056,0.31109792,-0.25140736,-0.4576611,0.7776612 172 | tokens,,0.24346647,0.7005277,0.22031198,-0.29249498,-0.07512681 173 | tokens,,0.11568012,-0.57711613,0.19756645,-0.48424596,0.09319665 174 | tokens,,-0.5352298,0.28023535,0.11314476,0.25227708,0.2986933 175 | -------------------------------------------------------------------------------- /programs/rasp/sort/sort_weights.csv: -------------------------------------------------------------------------------- 1 | feature,value,0,1,2,3,4 2 | attn_0_0_outputs,0,2.058161,-0.63779896,0.352272,0.5141959,-1.740315 3 | attn_0_0_outputs,1,-2.4275339,2.824267,1.1222093,-0.29264563,-1.6031634 4 | attn_0_0_outputs,2,-0.66014487,0.4371716,0.6551343,0.1385141,-0.82386607 5 | attn_0_0_outputs,3,-1.4426589,-0.6121028,0.95834476,3.1481867,-2.513183 6 | attn_0_0_outputs,4,-5.1414647,-4.19145,-2.107596,0.40243983,5.3070917 7 | attn_0_0_outputs,,-2.1824768,-2.4896393,-1.6732111,-0.5224847,3.0039287 8 | attn_0_0_outputs,,-0.29941246,0.6749306,0.05445957,-0.16473089,0.30180916 9 | attn_0_0_outputs,,-1.52416,-1.9376818,-1.7669702,-1.5059099,3.1769867 10 | attn_0_1_outputs,0,2.1742406,-0.104620405,0.667535,0.56420654,-0.70504665 11 | attn_0_1_outputs,1,-1.0247861,2.315593,0.35188147,-0.6513114,-1.2275088 12 | attn_0_1_outputs,2,-0.4099264,0.6904153,0.5056659,-0.7560546,-0.83760774 13 | attn_0_1_outputs,3,-1.055448,-0.52389306,1.1448267,3.5426579,-2.5540905 14 | attn_0_1_outputs,4,-3.743604,-3.8234832,-2.1349356,-0.02523663,4.5795074 15 | attn_0_1_outputs,,-6.0711293,-3.95341,-1.697248,0.52152914,5.3833594 16 | 
attn_0_1_outputs,,0.4076765,0.08954932,-0.14387824,0.091574,0.20851962 17 | attn_0_1_outputs,,-0.38013995,0.15324089,0.36798063,0.38066244,0.54085183 18 | attn_0_2_outputs,0,5.2839427,-2.605443,-1.1492563,-1.7018955,-1.4085128 19 | attn_0_2_outputs,1,-1.1856108,0.3934325,0.09754402,1.237557,1.0921215 20 | attn_0_2_outputs,2,-1.2996812,-0.00338279,0.70976466,0.60546106,0.31050247 21 | attn_0_2_outputs,3,-1.7861035,0.39520007,0.18829803,2.3311763,-1.300178 22 | attn_0_2_outputs,4,-9.374754,-5.7489552,-2.4201286,1.2188936,6.6614056 23 | attn_0_2_outputs,,-0.24304307,0.06899596,0.06185269,0.42859817,0.07496456 24 | attn_0_2_outputs,,0.54609865,0.4358928,0.3134578,0.6048993,0.7376409 25 | attn_0_2_outputs,,-2.2961771,2.386299,1.5124247,0.1737369,-2.5670602 26 | attn_0_3_outputs,0,0.3295003,-0.1827127,-0.18133737,-0.013248737,-0.10451385 27 | attn_0_3_outputs,1,-0.45805857,0.50490904,0.15572038,-0.71143156,-0.13632512 28 | attn_0_3_outputs,2,-0.060173124,0.06350225,0.20524347,0.15398552,0.42267397 29 | attn_0_3_outputs,3,-1.0496696,0.57741636,-0.8549515,2.4007623,-1.0587981 30 | attn_0_3_outputs,4,-2.2801893,-0.99242985,-0.45985192,-0.20450349,2.3491073 31 | attn_0_3_outputs,,-5.532329,-2.3926458,-0.032455105,2.7236922,3.9317892 32 | attn_0_3_outputs,,0.12758909,-0.13754353,0.030744994,-0.15796055,0.30099612 33 | attn_0_3_outputs,,-0.41964686,-1.7669121,-1.2862828,0.43886116,2.138884 34 | attn_1_0_outputs,0,2.0348678,-1.2148954,-0.88260573,0.072366394,0.8086225 35 | attn_1_0_outputs,1,-1.2006625,0.116209,-0.0810096,0.38852194,0.95987433 36 | attn_1_0_outputs,2,-0.14999832,0.9288459,1.3716835,1.1660155,-1.7833197 37 | attn_1_0_outputs,3,-2.7026277,-1.178805,-1.0291067,1.0259144,1.4074742 38 | attn_1_0_outputs,4,-4.517803,-1.8277934,-1.0758657,0.06511916,2.5772486 39 | attn_1_0_outputs,,-0.9347721,-0.8367545,-0.77415204,-0.26920766,1.24447 40 | attn_1_0_outputs,,-0.04765274,-0.13257939,0.18410715,0.88693845,0.29562962 41 | 
attn_1_0_outputs,,-2.463801,0.061894733,-0.032191936,0.4246904,0.90985197 42 | attn_1_1_outputs,0,2.2647228,-1.0525606,0.11064329,-0.4824749,-2.2326334 43 | attn_1_1_outputs,1,-1.016662,2.9539187,1.0269936,-1.7816477,-2.2167149 44 | attn_1_1_outputs,2,0.19277221,1.0981342,1.3715613,0.008356502,-0.45385465 45 | attn_1_1_outputs,3,-0.21926539,-0.4179502,1.2211832,3.2872894,-2.8288746 46 | attn_1_1_outputs,4,-4.9702883,-4.228738,-1.920627,0.4357961,6.435965 47 | attn_1_1_outputs,,-2.873453,-4.4431806,-1.4632183,0.603075,5.768971 48 | attn_1_1_outputs,,-0.25213683,0.016516566,0.26208618,-0.27179545,-0.22851557 49 | attn_1_1_outputs,,-1.7156378,-0.26994628,-0.20035094,0.42288014,1.1078966 50 | attn_1_2_outputs,0,2.66944,-0.45978594,0.24579379,0.6636476,-2.0367975 51 | attn_1_2_outputs,1,-1.3979408,2.155738,0.6993282,-0.12806043,-1.3829575 52 | attn_1_2_outputs,2,-0.5397533,0.3294003,1.0066247,0.8717879,-0.453143 53 | attn_1_2_outputs,3,-0.79578793,-0.5700265,1.0958053,3.8531687,-2.876995 54 | attn_1_2_outputs,4,-4.6522145,-4.2854285,-2.4250696,0.3245887,5.7784514 55 | attn_1_2_outputs,,-1.2333149,-0.8423292,-0.6918403,0.26540172,1.3943026 56 | attn_1_2_outputs,,0.14982238,0.35288456,-0.88604194,-0.06601087,-0.15605317 57 | attn_1_2_outputs,,-0.6495551,-2.0333426,-1.4724616,-0.4954959,2.3913198 58 | attn_1_3_outputs,0,1.0067343,-0.6924186,-0.22267842,0.0022108958,-0.34514758 59 | attn_1_3_outputs,1,-1.1532832,1.2752177,0.4407405,-0.2488312,-0.34415925 60 | attn_1_3_outputs,2,0.43818164,0.1435282,0.32248864,0.17158462,-0.002002176 61 | attn_1_3_outputs,3,-0.05012911,0.06287343,0.9565397,2.1478107,-1.4736792 62 | attn_1_3_outputs,4,-3.6912036,-2.599063,-0.84031093,0.7071018,3.493349 63 | attn_1_3_outputs,,-1.0017529,-0.16416864,0.032560404,0.2880399,0.6146985 64 | attn_1_3_outputs,,0.08257049,-0.038490284,-0.1203994,0.65081865,-0.14556827 65 | attn_1_3_outputs,,-2.321713,-1.3995543,-0.75829405,-0.35096657,2.7249875 66 | 
attn_2_0_outputs,0,0.65123445,-0.00015970842,0.28133044,0.29919377,-0.4494978 67 | attn_2_0_outputs,1,-0.46387243,-0.22923669,-0.09010046,0.5440508,-0.58355516 68 | attn_2_0_outputs,2,-0.746604,-0.8561036,0.9778113,-0.8867997,1.2600936 69 | attn_2_0_outputs,3,-2.3676467,-0.8762227,0.047619678,2.8621771,0.106826946 70 | attn_2_0_outputs,4,-2.2080712,-2.103707,-1.5367882,-0.9091501,3.8440235 71 | attn_2_0_outputs,,-1.5794939,-0.87683463,-0.5443371,0.67485535,0.6738131 72 | attn_2_0_outputs,,0.0008923216,0.19533116,0.36518034,-0.08854164,-0.2732349 73 | attn_2_0_outputs,,-0.59456545,-0.38609,-0.37244636,-0.48838097,1.1909106 74 | attn_2_1_outputs,0,6.305336,0.77024233,-0.5253686,-3.0544596,-2.9562898 75 | attn_2_1_outputs,1,-0.506348,0.88375586,-0.5610058,-0.3551172,-0.44845787 76 | attn_2_1_outputs,2,-0.79803765,-0.26507252,0.40326494,0.20280881,-0.38371047 77 | attn_2_1_outputs,3,-2.2682018,0.34390137,0.8440804,2.4715912,-1.713408 78 | attn_2_1_outputs,4,-8.596309,-6.2820163,-3.0651898,0.6778617,7.88834 79 | attn_2_1_outputs,,-0.9434871,-0.3908067,0.28419968,1.0826464,-0.4644816 80 | attn_2_1_outputs,,-0.73219436,-0.09140464,0.21194346,0.6483371,0.62705564 81 | attn_2_1_outputs,,-3.0782797,0.004718122,1.668909,2.3308017,-2.3886883 82 | attn_2_2_outputs,0,0.71128976,-0.3320928,0.04913793,0.27017894,-0.5422503 83 | attn_2_2_outputs,1,-2.1448827,1.1844211,0.69605696,0.05205563,-0.086535886 84 | attn_2_2_outputs,2,-1.3683058,0.17096008,0.6356051,0.6679375,0.6418501 85 | attn_2_2_outputs,3,-1.5241508,-0.8560207,0.32365173,1.5988479,0.07578664 86 | attn_2_2_outputs,4,-2.2993717,-1.7863915,-0.8184309,0.3028464,2.733902 87 | attn_2_2_outputs,,-1.1926019,-0.42977494,-0.319407,1.0023048,0.43580797 88 | attn_2_2_outputs,,-0.11356022,0.018971028,-0.61723095,-0.21103133,-0.4132432 89 | attn_2_2_outputs,,-0.7201614,-1.980554,-1.0417174,0.70280063,0.9491416 90 | attn_2_3_outputs,0,-0.6578003,-1.4176185,-0.16867943,0.79162806,1.3469748 91 | 
attn_2_3_outputs,1,-1.2225962,0.5266581,0.31030974,0.50686336,0.88578856 92 | attn_2_3_outputs,2,-1.1110145,-0.59071904,0.040772133,0.40666303,0.49748218 93 | attn_2_3_outputs,3,-0.7991637,-0.63388026,0.025102817,1.1622124,0.9700968 94 | attn_2_3_outputs,4,-1.523163,-1.0361481,-0.06378266,0.8740367,1.220963 95 | attn_2_3_outputs,,-2.4229674,-0.92591256,0.03990261,0.5686489,1.1961548 96 | attn_2_3_outputs,,-0.7670137,0.32505798,-0.19609599,0.63817674,0.36692366 97 | attn_2_3_outputs,,-0.17675255,-0.061802976,0.41715944,1.5147864,0.637639 98 | mlp_0_0_outputs,0,-0.8589103,-0.1552966,0.61306983,0.07937584,2.6195028 99 | mlp_0_0_outputs,1,-0.089151576,0.6665328,0.64555883,0.26704308,-1.2672955 100 | mlp_0_0_outputs,2,-1.3417193,0.89044684,1.4538382,1.145795,-0.34172156 101 | mlp_0_0_outputs,3,5.009247,2.353605,-0.03374805,-2.3626945,-6.8675437 102 | mlp_0_0_outputs,4,-0.00933083,0.8026369,0.70613724,0.33960083,-0.33451435 103 | mlp_0_0_outputs,5,-0.036977522,-0.1316631,0.5099742,0.5255853,0.5357659 104 | mlp_0_0_outputs,6,0.09807571,0.65641683,-0.31813905,0.17056943,-0.71300334 105 | mlp_0_0_outputs,7,-10.803721,-6.045529,-1.7798707,3.2467043,5.0803566 106 | mlp_0_1_outputs,0,-0.5035141,-0.20353706,-0.7522745,0.3356952,-0.6081926 107 | mlp_0_1_outputs,1,-4.242264,0.032010686,1.6723924,1.2721102,1.0669887 108 | mlp_0_1_outputs,2,-1.4796605,-1.0899225,0.16888815,1.3151004,0.8872381 109 | mlp_0_1_outputs,3,-0.28629878,-0.28274626,0.22532974,-0.05161379,0.6493964 110 | mlp_0_1_outputs,4,0.044844028,0.594797,0.088082105,-0.6334384,-0.99465847 111 | mlp_0_1_outputs,5,3.6674852,2.6837974,0.23535982,-2.6139388,-6.132662 112 | mlp_0_1_outputs,6,-1.7653883,-9.047455,-5.737482,0.19138227,7.072668 113 | mlp_0_1_outputs,7,-5.4091525,-1.258377,1.3022676,2.1761138,1.0759268 114 | mlp_1_0_outputs,0,-8.089947,-3.9758954,-0.049512167,1.9524714,3.6874063 115 | mlp_1_0_outputs,1,-0.25300968,-0.18175596,0.41158956,0.34343994,0.2131009 116 | 
mlp_1_0_outputs,2,2.9635074,1.5911162,0.35373554,-1.8420515,-4.9741387 117 | mlp_1_0_outputs,3,-0.7055162,0.8571573,1.9057972,0.2289657,-1.0690937 118 | mlp_1_0_outputs,4,-2.4608774,-1.1034535,0.5610324,0.7546167,2.5527046 119 | mlp_1_0_outputs,5,-0.8002519,-0.37941346,0.35449395,-0.01350961,-0.08785999 120 | mlp_1_0_outputs,6,-0.27233246,1.2476993,0.5192803,-0.4762693,-0.067111954 121 | mlp_1_0_outputs,7,0.431638,-0.23128301,-0.22074634,-0.38865703,0.29866084 122 | mlp_1_1_outputs,0,2.570149,1.4482967,-0.018957324,-2.3094382,-4.180814 123 | mlp_1_1_outputs,1,-6.535501,-5.5468464,-2.4822295,0.90257144,4.8221583 124 | mlp_1_1_outputs,2,0.21253164,-0.8952474,-0.4594328,-0.5081009,0.12444797 125 | mlp_1_1_outputs,3,0.107153066,1.1495358,0.47018173,-0.3799021,-0.95472884 126 | mlp_1_1_outputs,4,-0.44310027,-0.50543004,0.1825062,0.14150716,0.23581605 127 | mlp_1_1_outputs,5,-5.052584,-2.6555338,1.050915,2.478589,2.2599907 128 | mlp_1_1_outputs,6,-0.97004384,0.11741753,0.20926222,-0.07429946,-0.20051539 129 | mlp_1_1_outputs,7,-1.5305095,0.36310226,0.9118783,0.68767303,-0.9041264 130 | mlp_2_0_outputs,0,-3.9110544,-0.9270479,0.4481431,2.1587543,2.9880595 131 | mlp_2_0_outputs,1,0.20119122,0.09129438,-0.10534093,-0.13398942,0.35579532 132 | mlp_2_0_outputs,2,-5.627938,-3.4119296,-0.54387337,2.0478127,3.7280648 133 | mlp_2_0_outputs,3,-0.8237606,-0.058409758,0.06423915,0.38022444,0.32353234 134 | mlp_2_0_outputs,4,0.523112,0.52292186,0.10970966,-1.1015787,-0.37122783 135 | mlp_2_0_outputs,5,-0.463441,0.4770951,0.31376934,0.037225943,0.74693483 136 | mlp_2_0_outputs,6,3.0876498,3.265038,0.6365583,-2.6961308,-5.7100368 137 | mlp_2_0_outputs,7,-0.03848318,0.07956341,0.12161484,-0.63174456,-0.15266788 138 | mlp_2_1_outputs,0,-1.19403,-0.013288918,0.25646126,-0.34185714,0.46653393 139 | mlp_2_1_outputs,1,-0.28805476,0.1984507,-0.054004673,0.08747607,-0.0060478314 140 | mlp_2_1_outputs,2,4.965758,3.2663405,0.7348362,-2.8354158,-7.0237536 141 | 
mlp_2_1_outputs,3,-0.3902142,0.14936583,0.46208203,0.5396692,-0.32771578 142 | mlp_2_1_outputs,4,-1.6855072,-0.8212709,-0.031309534,0.8434268,1.4494125 143 | mlp_2_1_outputs,5,-0.66520023,-0.12322676,0.23253511,0.8862266,0.5351467 144 | mlp_2_1_outputs,6,-0.12667158,0.17387478,0.37683076,0.30193081,0.033221602 145 | mlp_2_1_outputs,7,-6.953155,-4.2610555,-1.0238223,2.2198749,6.7818727 146 | num_attn_0_0_outputs,_,2.7008455,1.7573988,-0.5612883,-3.1591482,-0.8077632 147 | num_attn_0_1_outputs,_,1.8419529,1.7084945,1.2732,-0.6164277,-2.7078419 148 | num_attn_0_2_outputs,_,1.7593077,-0.78708524,2.0769553,1.6491287,-2.7435398 149 | num_attn_0_3_outputs,_,3.631106,-1.8356805,-2.8216524,-3.0461025,2.3921132 150 | num_attn_1_0_outputs,_,1.5844845,-3.8878868,-2.0250123,4.444674,0.40432557 151 | num_attn_1_1_outputs,_,-0.05317056,0.5621444,-0.03736721,-0.7964875,-1.0776875 152 | num_attn_1_2_outputs,_,0.37114012,2.4403827,0.115093544,-0.80042255,-2.1261592 153 | num_attn_1_3_outputs,_,2.2861667,3.0232863,0.6303658,-1.7195753,-3.5699532 154 | num_attn_2_0_outputs,_,-0.81557804,-0.12176335,8.105928,-0.29102525,-3.5535772 155 | num_attn_2_1_outputs,_,2.769653,2.2095704,0.32860684,-2.286327,-3.9597259 156 | num_attn_2_2_outputs,_,-0.4148719,0.17052153,0.75779927,1.0667933,-1.2106642 157 | num_attn_2_3_outputs,_,2.7550514,0.53845423,0.012782344,-0.42958698,-2.1298997 158 | num_mlp_0_0_outputs,0,-0.052239962,-0.9593842,-0.2768684,-0.024186546,0.13085616 159 | num_mlp_0_0_outputs,1,-0.47369894,-0.043022864,0.4291395,0.35447082,0.6350527 160 | num_mlp_0_0_outputs,2,-3.0991662,0.14400464,-0.43281138,0.75551677,2.5591822 161 | num_mlp_0_0_outputs,3,-0.6901305,-0.7482572,-0.22040156,-0.065956876,0.5631253 162 | num_mlp_0_0_outputs,4,-0.82058483,-1.0333056,-0.2300869,0.055369288,0.27217725 163 | num_mlp_0_0_outputs,5,-0.24412262,-0.77988327,-0.03587442,0.099466026,0.4355093 164 | num_mlp_0_0_outputs,6,-0.18769716,-0.87073785,-0.08458955,-0.0117840925,0.14444879 165 | 
num_mlp_0_0_outputs,7,0.7771105,-1.2428197,0.109384626,0.3367401,0.5219702 166 | num_mlp_0_1_outputs,0,-1.928114,-0.64596057,0.1970114,0.72684216,0.2434265 167 | num_mlp_0_1_outputs,1,1.4667885,-0.04219994,-1.7887018,-3.354516,1.6295062 168 | num_mlp_0_1_outputs,2,-0.98909557,-0.37205845,0.4755718,1.0217645,0.4730045 169 | num_mlp_0_1_outputs,3,-3.8669333,-1.8498815,0.3956896,1.7269781,1.3696659 170 | num_mlp_0_1_outputs,4,-2.0566127,-0.9900346,0.124328226,0.6929384,0.54728794 171 | num_mlp_0_1_outputs,5,-1.4753255,0.13082078,0.030201413,0.20565401,-0.11292716 172 | num_mlp_0_1_outputs,6,-0.84873676,0.19060001,0.3849743,0.5375944,0.21243429 173 | num_mlp_0_1_outputs,7,-1.5675162,-0.14484215,0.5516114,1.1091007,0.26935497 174 | num_mlp_1_0_outputs,0,-0.58943486,-0.5716185,0.09452104,0.51946837,0.8833337 175 | num_mlp_1_0_outputs,1,0.16628344,0.028681463,-0.1816516,0.4498699,-0.17988878 176 | num_mlp_1_0_outputs,2,-0.20675577,-0.3085491,-0.45421112,0.071232475,-0.4011028 177 | num_mlp_1_0_outputs,3,0.07106293,-0.22149324,-0.17892063,0.39708918,0.03731434 178 | num_mlp_1_0_outputs,4,-0.59267014,-0.2528348,-0.13206014,0.11341445,-0.21827152 179 | num_mlp_1_0_outputs,5,-3.1149821,-0.90035933,0.7789497,1.3473841,1.3481094 180 | num_mlp_1_0_outputs,6,-0.7881431,-0.3538762,0.080496565,0.33989707,0.09272613 181 | num_mlp_1_0_outputs,7,-1.9980044,-1.499147,-0.13361156,1.4995332,2.7140417 182 | num_mlp_1_1_outputs,0,-0.24536662,0.2940862,-0.34453082,0.046545297,0.20954652 183 | num_mlp_1_1_outputs,1,-1.5211539,-1.4764887,-0.24266818,0.62331903,0.839463 184 | num_mlp_1_1_outputs,2,-0.60920113,0.39721942,-0.2720907,0.6585138,1.2556117 185 | num_mlp_1_1_outputs,3,-0.7962207,-0.022392428,-0.098228484,0.12305577,0.5722526 186 | num_mlp_1_1_outputs,4,-1.1961201,-0.48624808,0.0026044075,0.9489868,1.084819 187 | num_mlp_1_1_outputs,5,-0.204229,0.41044012,-0.16175759,0.3332036,0.38448393 188 | num_mlp_1_1_outputs,6,-4.0577865,-1.1186019,-0.19234465,2.1006806,1.1207315 189 | 
num_mlp_1_1_outputs,7,-0.1431973,-0.2963608,0.35561675,0.5961827,-0.64972585 190 | num_mlp_2_0_outputs,0,1.8791451,-1.1463113,-0.51802814,0.93353784,-1.1545902 191 | num_mlp_2_0_outputs,1,-1.1244311,-0.1578107,0.516671,0.5967867,1.4779365 192 | num_mlp_2_0_outputs,2,-2.4088686,0.21240945,0.9339892,0.09490182,1.2009406 193 | num_mlp_2_0_outputs,3,-0.79646355,-0.035708643,-0.15684086,-0.09949614,0.48360366 194 | num_mlp_2_0_outputs,4,3.6157093,-6.813567,-1.4317391,1.9446093,2.5641017 195 | num_mlp_2_0_outputs,5,0.39062092,1.3092825,-0.19412902,0.75027215,-0.8247279 196 | num_mlp_2_0_outputs,6,-3.1036084,1.2624242,-1.3120074,-0.12865852,1.6784518 197 | num_mlp_2_0_outputs,7,-4.274369,0.27141675,0.99371576,0.6078943,1.7251453 198 | num_mlp_2_1_outputs,0,-1.2744915,-0.38451442,-0.0047318963,0.3987959,0.38946167 199 | num_mlp_2_1_outputs,1,-1.6799719,-0.7788202,-0.20249598,-0.18425512,1.1495807 200 | num_mlp_2_1_outputs,2,-0.34154946,-0.51062804,-0.3556491,-0.033050634,0.4181513 201 | num_mlp_2_1_outputs,3,-0.53709805,-1.6549367,-0.7475045,1.7235987,1.2236754 202 | num_mlp_2_1_outputs,4,-0.83570236,0.10588158,-0.06152223,0.15144911,0.08399685 203 | num_mlp_2_1_outputs,5,-0.79256934,-0.39871466,-0.007255134,0.15521254,0.46749616 204 | num_mlp_2_1_outputs,6,0.0021853913,0.06891026,0.11127618,0.06793636,-0.5632735 205 | num_mlp_2_1_outputs,7,-2.8081176,-0.20779653,-0.30498865,1.0156854,1.5306722 206 | ones,_,-1.3621695,-0.18504938,-0.2161428,0.37226388,0.5816039 207 | positions,0,-0.13063768,0.36072794,0.13273126,0.08235803,-0.78297955 208 | positions,1,3.3640394,1.4360483,-0.19464374,-3.6062639,-6.1498446 209 | positions,2,0.15126465,2.110513,0.34884313,-1.5275149,-2.7380152 210 | positions,3,-1.1328846,0.9832707,1.3702737,-1.1388667,0.24428602 211 | positions,4,-4.9033227,-5.421642,2.7506442,7.5432215,-0.07508394 212 | positions,5,-5.0503674,1.0907708,2.297857,1.425875,2.3817568 213 | positions,6,0.65689164,-6.7524247,-8.143799,-5.7251816,6.340086 214 | 
positions,7,-0.8686947,-0.5266267,-0.328554,0.117776334,0.3166736 215 | tokens,0,1.8981099,-0.8942574,-0.15392976,0.062256597,-1.4486569 216 | tokens,1,-1.294627,2.9094937,0.33469462,-1.4084401,-0.6528272 217 | tokens,2,-0.5369325,0.3324304,0.5662137,0.01150415,0.71778655 218 | tokens,3,-0.98839134,-0.6982198,0.7751885,3.5432122,-1.4779247 219 | tokens,4,-5.4420075,-4.702195,-2.883783,0.086826995,6.86268 220 | tokens,,-0.7684788,-0.24721867,0.30233318,-0.792211,0.23400383 221 | tokens,,-0.09745099,0.15648925,0.45449963,-0.054960843,-0.32444936 222 | tokens,,0.109125495,-0.04502393,0.25828347,0.35704604,-0.00495211 223 | -------------------------------------------------------------------------------- /programs/rasp_categorical_only/double_hist/double_hist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | def select_closest(keys, queries, predicate): 6 | scores = [[False for _ in keys] for _ in queries] 7 | for i, q in enumerate(queries): 8 | matches = [j for j, k in enumerate(keys) if predicate(q, k)] 9 | if not (any(matches)): 10 | scores[i][0] = True 11 | else: 12 | j = min(matches, key=lambda j: len(matches) if j == i else abs(i - j)) 13 | scores[i][j] = True 14 | return scores 15 | 16 | 17 | def aggregate(attention, values): 18 | return [[v for a, v in zip(attn, values) if a][0] for attn in attention] 19 | 20 | 21 | def run(tokens): 22 | # classifier weights ########################################## 23 | classifier_weights = pd.read_csv( 24 | "programs/rasp_categorical_only/double_hist/double_hist_weights.csv", 25 | index_col=[0, 1], 26 | dtype={"feature": str}, 27 | ) 28 | # inputs ##################################################### 29 | token_scores = classifier_weights.loc[[("tokens", str(v)) for v in tokens]] 30 | 31 | positions = list(range(len(tokens))) 32 | position_scores = classifier_weights.loc[[("positions", str(v)) for v in positions]] 33 | 34 | ones = [1 
for _ in range(len(tokens))]
    one_scores = classifier_weights.loc[[("ones", "_") for v in ones]].mul(ones, axis=0)

    # attn_0_0 ####################################################
    def predicate_0_0(q_position, k_position):
        if q_position in {0, 3}:
            return k_position == 3
        elif q_position in {1, 2}:
            return k_position == 5
        elif q_position in {4, 6}:
            return k_position == 6
        elif q_position in {5}:
            return k_position == 2
        elif q_position in {7}:
            return k_position == 7

    attn_0_0_pattern = select_closest(positions, positions, predicate_0_0)
    attn_0_0_outputs = aggregate(attn_0_0_pattern, positions)
    attn_0_0_output_scores = classifier_weights.loc[
        [("attn_0_0_outputs", str(v)) for v in attn_0_0_outputs]
    ]

    # attn_0_1 ####################################################
    def predicate_0_1(q_position, k_position):
        if q_position in {0}:
            return k_position == 5
        elif q_position in {1, 2, 3, 4, 5, 6}:
            return k_position == 6
        elif q_position in {7}:
            return k_position == 7

    attn_0_1_pattern = select_closest(positions, positions, predicate_0_1)
    attn_0_1_outputs = aggregate(attn_0_1_pattern, positions)
    attn_0_1_output_scores = classifier_weights.loc[
        [("attn_0_1_outputs", str(v)) for v in attn_0_1_outputs]
    ]

    # attn_0_2 ####################################################
    def predicate_0_2(q_token, k_token):
        if q_token in {"1", "4", "0", "2", "3"}:
            return k_token == "5"
        elif q_token in {"5"}:
            return k_token == "0"
        elif q_token in {""}:
            return k_token == ""

    attn_0_2_pattern = select_closest(tokens, tokens, predicate_0_2)
    attn_0_2_outputs = aggregate(attn_0_2_pattern, positions)
    attn_0_2_output_scores = classifier_weights.loc[
        [("attn_0_2_outputs", str(v)) for v in attn_0_2_outputs]
    ]

    # attn_0_3 ####################################################
    def predicate_0_3(q_position, k_position):
        if q_position in {0, 3}:
            return k_position == 3
        elif q_position in {1, 2, 4, 6}:
            return k_position == 6
        elif q_position in {5}:
            return k_position == 1
        elif q_position in {7}:
            return k_position == 7

    attn_0_3_pattern = select_closest(positions, positions, predicate_0_3)
    attn_0_3_outputs = aggregate(attn_0_3_pattern, positions)
    attn_0_3_output_scores = classifier_weights.loc[
        [("attn_0_3_outputs", str(v)) for v in attn_0_3_outputs]
    ]

    # attn_0_4 ####################################################
    def predicate_0_4(q_position, k_position):
        if q_position in {0}:
            return k_position == 5
        elif q_position in {1, 2, 3, 4, 5, 6, 7}:
            return k_position == 7

    attn_0_4_pattern = select_closest(positions, positions, predicate_0_4)
    attn_0_4_outputs = aggregate(attn_0_4_pattern, positions)
    attn_0_4_output_scores = classifier_weights.loc[
        [("attn_0_4_outputs", str(v)) for v in attn_0_4_outputs]
    ]

    # attn_0_5 ####################################################
    def predicate_0_5(q_position, k_position):
        if q_position in {0}:
            return k_position == 3
        elif q_position in {1, 7}:
            return k_position == 7
        elif q_position in {2}:
            return k_position == 5
        elif q_position in {3, 4, 5, 6}:
            return k_position == 6

    attn_0_5_pattern = select_closest(positions, positions, predicate_0_5)
    attn_0_5_outputs = aggregate(attn_0_5_pattern, positions)
    attn_0_5_output_scores = classifier_weights.loc[
        [("attn_0_5_outputs", str(v)) for v in attn_0_5_outputs]
    ]

    # attn_0_6 ####################################################
    def predicate_0_6(q_token, k_token):
        if q_token in {"0"}:
            return k_token == "1"
        elif q_token in {"1", "2"}:
            return k_token == "0"
        elif q_token in {"5", "4", "3"}:
            return k_token == "2"
        elif q_token in {""}:
            return k_token == "5"

    attn_0_6_pattern = select_closest(tokens, tokens, predicate_0_6)
    attn_0_6_outputs = aggregate(attn_0_6_pattern, tokens)
    attn_0_6_output_scores = classifier_weights.loc[
        [("attn_0_6_outputs", str(v)) for v in attn_0_6_outputs]
    ]

    # attn_0_7 ####################################################
    def predicate_0_7(q_token, k_token):
        if q_token in {"0"}:
            return k_token == "0"
        elif q_token in {"1"}:
            return k_token == "1"
        elif q_token in {"2"}:
            return k_token == "2"
        elif q_token in {"3"}:
            return k_token == "3"
        elif q_token in {"4"}:
            return k_token == "4"
        elif q_token in {"5"}:
            return k_token == "5"
        elif q_token in {""}:
            return k_token == ""

    attn_0_7_pattern = select_closest(tokens, tokens, predicate_0_7)
    attn_0_7_outputs = aggregate(attn_0_7_pattern, positions)
    attn_0_7_output_scores = classifier_weights.loc[
        [("attn_0_7_outputs", str(v)) for v in attn_0_7_outputs]
    ]

    # mlp_0_0 #####################################################
    def mlp_0_0(token, attn_0_6_output):
        key = (token, attn_0_6_output)
        if key in {
            ("0", "0"),
            ("0", "1"),
            ("0", "2"),
            ("0", "3"),
            ("0", "4"),
            ("0", "5"),
            ("0", ""),
            ("1", "0"),
            ("1", "1"),
            ("1", "2"),
            ("1", "3"),
            ("1", "4"),
            ("1", "5"),
            ("1", ""),
            ("3", "0"),
            ("4", "0"),
            ("4", "1"),
        }:
            return 3
        elif key in {
            ("2", "5"),
            ("2", ""),
            ("3", "4"),
            ("3", "5"),
            ("3", ""),
            ("", "0"),
            ("", "1"),
            ("", "2"),
            ("", "3"),
            ("", "4"),
            ("", "5"),
            ("", ""),
        }:
            return 0
        elif key in {("2", "0"), ("5", "0")}:
            return 2
        return 5

    mlp_0_0_outputs = [mlp_0_0(k0, k1) for k0, k1 in zip(tokens, attn_0_6_outputs)]
    mlp_0_0_output_scores = classifier_weights.loc[
        [("mlp_0_0_outputs", str(v)) for v in mlp_0_0_outputs]
    ]

    # mlp_0_1 #####################################################
    def mlp_0_1(attn_0_6_output, token):
        key = (attn_0_6_output, token)
        if key in {("4", "1"), ("4", "3"), ("4", "4"), ("4", "")}:
            return 6
        return 5

    mlp_0_1_outputs = [mlp_0_1(k0, k1) for k0, k1 in zip(attn_0_6_outputs, tokens)]
    mlp_0_1_output_scores = classifier_weights.loc[
        [("mlp_0_1_outputs", str(v)) for v in mlp_0_1_outputs]
    ]

    # attn_1_0 ####################################################
    def predicate_1_0(q_token, k_token):
        if q_token in {"1", "4", "0", "2", "5"}:
            return k_token == "3"
        elif q_token in {"3"}:
            return k_token == "0"
        elif q_token in {""}:
            return k_token == ""

    attn_1_0_pattern = select_closest(tokens, tokens, predicate_1_0)
    attn_1_0_outputs = aggregate(attn_1_0_pattern, tokens)
    attn_1_0_output_scores = classifier_weights.loc[
        [("attn_1_0_outputs", str(v)) for v in attn_1_0_outputs]
    ]

    # attn_1_1 ####################################################
    def predicate_1_1(q_token, k_token):
        if q_token in {"0", "3"}:
            return k_token == "1"
        elif q_token in {"1", "4"}:
            return k_token == "2"
        elif q_token in {"5", "2"}:
            return k_token == "4"
        elif q_token in {""}:
            return k_token == ""

    attn_1_1_pattern = select_closest(tokens, tokens, predicate_1_1)
    attn_1_1_outputs = aggregate(attn_1_1_pattern, mlp_0_0_outputs)
    attn_1_1_output_scores = classifier_weights.loc[
        [("attn_1_1_outputs", str(v)) for v in attn_1_1_outputs]
    ]

    # attn_1_2 ####################################################
    def predicate_1_2(q_token, k_token):
        if q_token in {"4", "5", "0", "2"}:
            return k_token == "2"
        elif q_token in {"1"}:
            return k_token == "1"
        elif q_token in {"", "3"}:
            return k_token == "5"

    attn_1_2_pattern = select_closest(tokens, tokens, predicate_1_2)
    attn_1_2_outputs = aggregate(attn_1_2_pattern, mlp_0_0_outputs)
    attn_1_2_output_scores = classifier_weights.loc[
        [("attn_1_2_outputs", str(v)) for v in attn_1_2_outputs]
    ]

    # attn_1_3 ####################################################
    def predicate_1_3(attn_0_5_output, position):
        if attn_0_5_output in {0, 1, 3, 6}:
            return position == 4
        elif attn_0_5_output in {2, 4, 7}:
            return position == 5
        elif attn_0_5_output in {5}:
            return position == 2

    attn_1_3_pattern = select_closest(positions, attn_0_5_outputs, predicate_1_3)
    attn_1_3_outputs = aggregate(attn_1_3_pattern, attn_0_0_outputs)
    attn_1_3_output_scores = classifier_weights.loc[
        [("attn_1_3_outputs", str(v)) for v in attn_1_3_outputs]
    ]

    # attn_1_4 ####################################################
    def predicate_1_4(q_token, k_token):
        if q_token in {"0"}:
            return k_token == "2"
        elif q_token in {"1", "4"}:
            return k_token == "0"
        elif q_token in {"2"}:
            return k_token == "4"
        elif q_token in {"3"}:
            return k_token == "1"
        elif q_token in {"5", ""}:
            return k_token == ""

    attn_1_4_pattern = select_closest(tokens, tokens, predicate_1_4)
    attn_1_4_outputs = aggregate(attn_1_4_pattern, attn_0_7_outputs)
    attn_1_4_output_scores = classifier_weights.loc[
        [("attn_1_4_outputs", str(v)) for v in attn_1_4_outputs]
    ]

    # attn_1_5 ####################################################
    def predicate_1_5(attn_0_4_output, attn_0_0_output):
        if attn_0_4_output in {0, 1, 3, 5}:
            return attn_0_0_output == 6
        elif attn_0_4_output in {2, 6, 7}:
            return attn_0_0_output == 7
        elif attn_0_4_output in {4}:
            return attn_0_0_output == 5

    attn_1_5_pattern = select_closest(attn_0_0_outputs, attn_0_4_outputs, predicate_1_5)
    attn_1_5_outputs = aggregate(attn_1_5_pattern, attn_0_1_outputs)
    attn_1_5_output_scores = classifier_weights.loc[
        [("attn_1_5_outputs", str(v)) for v in attn_1_5_outputs]
    ]

    # attn_1_6 ####################################################
    def predicate_1_6(attn_0_0_output, attn_0_7_output):
        if attn_0_0_output in {0, 6}:
            return attn_0_7_output == 7
        elif attn_0_0_output in {1, 4, 5}:
            return attn_0_7_output == 6
        elif attn_0_0_output in {2}:
            return attn_0_7_output == 5
        elif attn_0_0_output in {3}:
            return attn_0_7_output == 3
        elif attn_0_0_output in {7}:
            return attn_0_7_output == 1

    attn_1_6_pattern = select_closest(attn_0_7_outputs, attn_0_0_outputs, predicate_1_6)
    attn_1_6_outputs = aggregate(attn_1_6_pattern, attn_0_3_outputs)
    attn_1_6_output_scores = classifier_weights.loc[
        [("attn_1_6_outputs", str(v)) for v in attn_1_6_outputs]
    ]

    # attn_1_7 ####################################################
    def predicate_1_7(q_token, k_token):
        if q_token in {"1", "0", "3"}:
            return k_token == "4"
        elif q_token in {"4", "5", "2"}:
            return k_token == "1"
        elif q_token in {""}:
            return k_token == ""

    attn_1_7_pattern = select_closest(tokens, tokens, predicate_1_7)
    attn_1_7_outputs = aggregate(attn_1_7_pattern, mlp_0_0_outputs)
    attn_1_7_output_scores = classifier_weights.loc[
        [("attn_1_7_outputs", str(v)) for v in attn_1_7_outputs]
    ]

    # mlp_1_0 #####################################################
    def mlp_1_0(attn_1_5_output, attn_1_3_output):
        key = (attn_1_5_output, attn_1_3_output)
        if key in {
            (0, 1),
            (0, 3),
            (0, 5),
            (1, 1),
            (1, 3),
            (1, 5),
            (2, 3),
            (3, 3),
            (4, 1),
            (4, 3),
            (4, 5),
            (5, 0),
            (5, 1),
            (5, 2),
            (5, 3),
            (5, 4),
            (5, 5),
        }:
            return 3
        return 5

    mlp_1_0_outputs = [
        mlp_1_0(k0, k1) for k0, k1 in zip(attn_1_5_outputs, attn_1_3_outputs)
    ]
    mlp_1_0_output_scores = classifier_weights.loc[
        [("mlp_1_0_outputs", str(v)) for v in mlp_1_0_outputs]
    ]

    # mlp_1_1 #####################################################
    def mlp_1_1(attn_1_7_output, attn_1_5_output):
        key = (attn_1_7_output, attn_1_5_output)
        if key in {
            (1, 0),
            (1, 1),
            (1, 3),
            (1, 4),
            (1, 5),
            (2, 0),
            (2, 1),
            (2, 3),
            (2, 4),
            (2, 5),
            (3, 0),
            (3, 1),
            (3, 3),
            (3, 4),
            (3, 5),
            (6, 0),
            (6, 1),
            (6, 3),
            (6, 4),
            (6, 5),
            (7, 0),
            (7, 1),
            (7, 3),
            (7, 4),
            (7, 5),
        }:
            return 1
        elif key in {(0, 7), (1, 7), (2, 7), (3, 7), (4, 7), (5, 7), (6, 7)}:
            return 5
        elif key in {(4, 0)}:
            return 2
        return 0

    mlp_1_1_outputs = [
        mlp_1_1(k0, k1) for k0, k1 in zip(attn_1_7_outputs, attn_1_5_outputs)
    ]
    mlp_1_1_output_scores = classifier_weights.loc[
        [("mlp_1_1_outputs", str(v)) for v in mlp_1_1_outputs]
    ]

    feature_logits = pd.concat(
        [
            df.reset_index()
            for df in [
                token_scores,
                position_scores,
                attn_0_0_output_scores,
                attn_0_1_output_scores,
                attn_0_2_output_scores,
                attn_0_3_output_scores,
                attn_0_4_output_scores,
                attn_0_5_output_scores,
                attn_0_6_output_scores,
                attn_0_7_output_scores,
                mlp_0_0_output_scores,
                mlp_0_1_output_scores,
                attn_1_0_output_scores,
                attn_1_1_output_scores,
                attn_1_2_output_scores,
                attn_1_3_output_scores,
                attn_1_4_output_scores,
                attn_1_5_output_scores,
                attn_1_6_output_scores,
                attn_1_7_output_scores,
                mlp_1_0_output_scores,
                mlp_1_1_output_scores,
                one_scores,
            ]
        ]
    )
    logits = feature_logits.groupby(level=0).sum(numeric_only=True).to_numpy()
    classes = classifier_weights.columns.to_numpy()
    predictions = classes[logits.argmax(-1)]
    if tokens[0] == "":
        predictions[0] = ""
    if tokens[-1] == "":
        predictions[-1] = ""
    return predictions.tolist()


examples = [
    (["", "4", "5", "0", "5", "2", "2"], ["", "2", "2", "2", "2", "2", "2"]),
    (["", "5", "2", "3", "5", "2", "3"], ["", "3", "3", "3", "3", "3", "3"]),
    (["", "3", "1", "3", "5"], ["", "1", "2", "1", "2"]),
    (
        ["", "0", "2", "5", "0", "3", "3", "4"],
        ["", "2", "3", "3", "2", "2", "2", "3"],
    ),
    (["", "1", "0", "5", "4"], ["", "4", "4", "4", "4"]),
    (["", "1", "4", "1", "2", "0", "1"], ["", "1", "3", "1", "3", "3", "1"]),
    (["", "2", "0", "3", "2", "3", "5"], ["", "2", "2", "2", "2", "2", "2"]),
    (
        ["", "1", "2", "1", "2", "4", "1", "5"],
        ["", "1", "1", "1", "1", "2", "1", "2"],
    ),
    (["", "3", "1", "1", "0"], ["", "2", "1", "1", "2"]),
    (["", "4", "0", "5", "5", "4"], ["", "2", "1", "2", "2", "2"]),
]
for x, y in examples:
    print(f"x: {x}")
    print(f"y: {y}")
    y_hat = run(x)
    print(f"y_hat: {y_hat}")
    print()
--------------------------------------------------------------------------------
/programs/rasp_categorical_only/double_hist/double_hist_weights.csv:
--------------------------------------------------------------------------------
feature,value,1,2,3,4,5,6
attn_0_0_outputs,0,0.126095,-0.09308497,-0.04164269,1.1473964,-0.65259224,-0.5759567
attn_0_0_outputs,1,0.2643548,0.057178997,0.34055573,-0.13716577,-0.2368343,-0.39215782
attn_0_0_outputs,2,0.09445334,-0.113155134,0.09719708,-0.3638223,-0.7040596,-1.0189345
attn_0_0_outputs,3,0.3111331,0.07621635,0.29869452,-0.11679748,-0.4302198,-0.4611198
attn_0_0_outputs,4,0.1457545,-0.08105723,0.16098967,0.43949932,0.038392957,-0.3110958
attn_0_0_outputs,5,0.2503496,-0.0776672,0.2193873,-0.16906078,-0.25772113,-0.5475772
attn_0_0_outputs,6,0.6186277,0.37508786,0.42960152,-0.44146442,-0.41667625,-0.62044233
attn_0_0_outputs,7,0.55322224,0.21084103,0.76800203,-0.35938764,0.6075319,-4.7059984
attn_0_1_outputs,0,-0.20470576,-0.5146743,-0.045462333,-0.7030797,8.210753,2.5101397
attn_0_1_outputs,1,0.92150104,0.72936517,-1.0153459,0.9652804,2.3336089,1.3663455
attn_0_1_outputs,2,0.039329447,-0.25970924,-0.53097314,0.094466195,-1.3910042,-0.7387352
attn_0_1_outputs,3,0.30621508,0.059149753,0.13238326,0.42272153,0.28027993,0.14599061
attn_0_1_outputs,4,0.13979882,-0.12561479,0.2264524,-0.007633883,0.46641046,-0.11170333
attn_0_1_outputs,5,-0.21153452,-0.42493245,-0.2567995,-0.22684059,1.0124276,-0.56272596
attn_0_1_outputs,6,0.6481529,0.54434663,0.0014183833,0.29601145,-5.983322,-1.2265881
attn_0_1_outputs,7,0.5336258,0.4396987,-0.00064421905,0.07741543,-4.62246,-7.6855483
attn_0_2_outputs,0,1.7004082,1.068717,0.67274964,-2.202218,-14.525051,-24.145012
attn_0_2_outputs,1,-0.34902173,-0.46517774,-0.26951686,0.4179399,2.3199003,3.5011954
attn_0_2_outputs,2,-0.30551437,-0.49071354,-0.20850673,0.52268445,2.352511,3.313084
attn_0_2_outputs,3,-0.36689532,-0.6685318,-0.4558378,0.18013526,2.1214046,3.1762848
attn_0_2_outputs,4,-0.42851058,-0.6440964,-0.38658094,0.28424677,2.1883128,2.630304
attn_0_2_outputs,5,0.045050662,-0.07220718,0.047026966,0.6048806,2.6343098,2.9428246
attn_0_2_outputs,6,-0.27480802,-0.30446512,-0.2069842,0.5289953,2.8373914,2.2616675
attn_0_2_outputs,7,0.06093425,0.14711075,0.24991485,0.32365993,2.6028795,-4.8855653
attn_0_3_outputs,0,0.09755095,0.18649152,-0.27435488,-1.9707323,-1.4276974,-2.0362265
attn_0_3_outputs,1,0.35753757,0.23701514,-0.25319183,-0.6494274,-1.605695,-0.7366025
attn_0_3_outputs,2,0.32133347,0.2595023,0.31592253,0.17246504,-0.1395176,-0.22350298
attn_0_3_outputs,3,0.24959454,0.32810065,-0.024572447,-0.08864521,-0.25495535,-0.24189267
attn_0_3_outputs,4,-0.25754684,-0.08468185,-0.3961516,-0.13146977,-0.35567138,-0.29548463
attn_0_3_outputs,5,0.23781574,0.3707741,0.1524827,0.0640203,0.10271894,-0.15585618
attn_0_3_outputs,6,-0.04686378,0.1890392,0.06480245,0.41855758,1.3988882,-0.042787485
attn_0_3_outputs,7,-0.1659183,0.11428217,0.0062955585,0.4894925,0.07572694,-5.7053986
attn_0_4_outputs,0,0.0009369574,0.07684592,0.040978912,0.5028028,0.9522137,3.6835043
attn_0_4_outputs,1,-0.5538695,0.603373,-0.20237413,-0.11773054,0.77875495,-2.9072464
attn_0_4_outputs,2,0.30023372,0.00922706,-0.50846523,0.59284455,-1.5345312,-0.671517
attn_0_4_outputs,3,0.61231476,-0.13248341,0.6498266,0.61870444,0.7098338,-0.43163025
attn_0_4_outputs,4,0.341006,0.14442582,0.41357324,0.23423965,-0.040843505,-0.04311107
attn_0_4_outputs,5,0.15555005,-0.0129139135,0.038804073,-0.07452761,-0.2310702,-1.8921185
attn_0_4_outputs,6,0.085187085,0.11249486,-0.30318964,0.5192505,-1.7593035,-0.5133031
attn_0_4_outputs,7,0.67837936,0.4230805,0.6677434,-1.2566549,-1.3864559,-14.990864
attn_0_5_outputs,0,0.117494754,-0.0076674935,0.17776193,-0.18112357,-0.27396384,-1.0354977
attn_0_5_outputs,1,1.411103,-0.9264894,1.0556792,0.40703505,0.21264277,-0.08458514
attn_0_5_outputs,2,0.08361238,0.5619377,0.48529363,-0.33324116,0.12216819,-0.40256438
attn_0_5_outputs,3,-0.4463776,0.7458927,-0.2468474,-0.110574245,0.27681002,-0.2829393
attn_0_5_outputs,4,0.27849215,0.050043125,-0.27561343,-0.05235713,-0.050278183,-0.09016176
attn_0_5_outputs,5,-0.08726232,-0.21527587,0.051751416,-0.38245076,-0.7218009,-1.5243168
attn_0_5_outputs,6,0.41514802,0.25593644,0.4507316,-0.057119988,-0.124129154,-0.17267601
attn_0_5_outputs,7,0.5153719,0.2776314,0.6347292,-0.058477916,0.21668187,-5.153364
attn_0_6_outputs,0,0.4399908,0.40121058,0.2908646,0.30896205,1.0355054,2.6121292
attn_0_6_outputs,1,-0.4126085,-0.40561342,-0.2074391,0.47275737,0.93861,-0.9778263
attn_0_6_outputs,2,-0.16018946,-0.020672152,0.5263022,1.8692286,3.3783615,2.0448
attn_0_6_outputs,3,1.2413086,0.609498,0.27304208,-2.0135508,-3.9722664,-2.54368
attn_0_6_outputs,4,-0.4699255,-0.07383059,0.8524104,0.43587294,0.46777043,0.14019349
attn_0_6_outputs,5,-0.79917717,-0.056769755,0.6343696,0.45576847,-2.0306313,-0.97413427
attn_0_6_outputs,,-0.27148145,0.012953444,-0.09518302,0.17713453,-0.20631313,-0.18813157
attn_0_6_outputs,,0.68961024,0.53815615,0.44267753,-0.3528938,-4.7231135,-14.933182
attn_0_7_outputs,0,0.26736894,-0.24484192,0.017989898,0.36529353,0.21895835,-0.13605899
attn_0_7_outputs,1,-0.028649254,0.027535517,0.3608771,0.6933264,0.58034605,-0.25237384
attn_0_7_outputs,2,-0.13851081,-0.03744818,0.35292834,0.566245,0.34746164,-0.39807728
attn_0_7_outputs,3,0.207372,0.20546846,0.3811095,0.65656036,-0.3655698,-0.94992673
attn_0_7_outputs,4,-0.21982807,-0.07117972,0.24072638,0.580462,0.22779587,-0.25692478
attn_0_7_outputs,5,0.41687506,0.26974058,0.19956248,-0.015195189,-3.0216563,-1.1237081
attn_0_7_outputs,6,-0.34657335,-0.24584907,0.0033485373,0.29958242,0.08694523,-0.8937644
attn_0_7_outputs,7,0.2750328,0.3009795,0.66187775,0.84384704,0.58818835,-3.057173
attn_1_0_outputs,0,-1.1977414,-0.68712807,0.21589795,1.5379475,2.7449412,6.4803658
attn_1_0_outputs,1,0.43876192,0.5447946,1.4586774,-1.0479536,-2.8326502,-4.8915324
attn_1_0_outputs,2,-0.096182294,0.5217153,0.98753965,-0.32770967,-1.3522687,-2.2044451
attn_1_0_outputs,3,-1.6992767,-1.1615183,-0.33289766,0.98977876,6.670316,7.1105237
attn_1_0_outputs,4,0.9227671,0.8429137,0.9108278,-1.8007963,-3.0463388,-6.395819
attn_1_0_outputs,5,-0.37416497,0.8064592,1.6661402,-0.17301181,-2.1290236,-6.337081
attn_1_0_outputs,,0.16492641,0.2724202,0.30109468,0.1270346,-0.42348635,-0.6057034
attn_1_0_outputs,,1.0615019,1.1287842,1.2637048,-1.2012541,-14.706598,-19.987967
attn_1_1_outputs,0,1.3178737,1.3075073,0.6171435,-1.2386639,-14.961128,-22.710047
attn_1_1_outputs,1,-0.79646075,-0.048236232,-0.57135,0.8660453,5.204853,3.9006507
attn_1_1_outputs,2,-0.64026845,-0.31622618,-0.80555373,-0.38438004,-3.0017076,3.1284397
attn_1_1_outputs,3,-0.413255,-0.0743735,-0.44691494,0.00972029,1.0504625,1.6899987
attn_1_1_outputs,4,-1.0316962,-0.29744324,-0.7646268,0.6624031,2.773765,2.0252886
attn_1_1_outputs,5,-0.32514748,0.08588939,-0.24190034,1.1112455,4.891911,-4.0668387
attn_1_1_outputs,6,-0.18358597,0.007782329,-0.167846,0.55784917,1.0497051,1.5342381
attn_1_1_outputs,7,-0.89457923,-0.32318884,-0.5396412,0.82994485,2.7423346,3.3378825
attn_1_2_outputs,0,0.6464282,0.4415729,0.34436905,0.17836824,-4.205051,-11.493903
attn_1_2_outputs,1,-0.0137010105,-0.445989,-0.4315788,-0.57123286,0.7491084,2.0127614
attn_1_2_outputs,2,-0.28706008,-0.40443745,-0.38865212,0.5154979,4.107742,7.3013763
attn_1_2_outputs,3,-0.10718431,-0.3202067,-0.32353058,0.20667966,1.9230273,1.5108932
attn_1_2_outputs,4,0.33059838,0.07220246,-0.23482965,-1.067047,0.19484031,1.3690035
attn_1_2_outputs,5,0.3945407,0.329356,0.33315465,0.6143382,1.234685,-1.9626414
attn_1_2_outputs,6,-0.5554587,0.40695667,0.13053398,0.9338181,0.99574876,0.43968514
attn_1_2_outputs,7,0.33515787,0.09245179,0.09882208,-0.6038185,0.082006045,2.063728
attn_1_3_outputs,0,-0.535496,-0.13259524,-0.71865547,-0.3909597,-0.38141873,-0.40478173
attn_1_3_outputs,1,0.20157264,1.0196886,-0.502918,-0.23266602,-0.06780651,0.2839907
attn_1_3_outputs,2,-0.11747025,0.37703773,-0.42134264,-0.14813195,-0.20976208,-0.034975
attn_1_3_outputs,3,-0.37622637,-6.0123005,5.564005,-2.447405,0.38429993,-0.59021026
attn_1_3_outputs,4,1.1969863,-0.76458126,-1.5581679,-0.42092353,0.42188513,0.6117949
attn_1_3_outputs,5,-0.13388146,0.33084488,-0.3471546,-0.08674212,-0.17246169,-0.21080919
attn_1_3_outputs,6,0.040989626,0.52416426,-0.098571725,0.21533841,-0.24930477,0.22166981
attn_1_3_outputs,7,0.2629529,0.48815468,-0.1758514,0.5167536,-0.11469977,-7.593266
attn_1_4_outputs,0,0.7242939,0.6138902,-0.2165776,-1.189821,-6.016478,-2.156461
attn_1_4_outputs,1,-0.028479,0.10070866,-0.31167245,0.2912638,2.3396828,1.539698
attn_1_4_outputs,2,0.03833334,0.08542576,-0.23502897,0.34766886,1.8727108,0.8869683
attn_1_4_outputs,3,0.14814246,0.14517051,-0.23019965,0.4274214,2.4027905,0.5986793
attn_1_4_outputs,4,-0.060111366,0.08575323,-0.3969963,0.3362898,2.0958183,1.2442379
attn_1_4_outputs,5,-0.083861604,0.069971085,-0.24564557,0.27127108,1.9678422,0.49202684
attn_1_4_outputs,6,-0.013286352,0.3606033,-0.18357426,0.50619376,2.4264543,0.56555873
attn_1_4_outputs,7,-0.2426668,-0.2329866,-0.72179276,0.0852172,1.769792,-2.2768142
attn_1_5_outputs,0,-0.38110098,0.40513593,-4.439374,4.093755,2.1234877,0.5540743
attn_1_5_outputs,1,-0.23400371,-0.036937438,-0.09517478,-0.38474637,-0.2784226,-0.58958876
attn_1_5_outputs,2,0.3410733,0.18570994,-0.07866506,0.34950978,-0.75937384,-0.24752817
attn_1_5_outputs,3,0.021574836,-0.6249923,0.34806028,-0.11711029,3.910408,-0.14427455
attn_1_5_outputs,4,-0.4571682,-0.37512428,-0.30038518,-0.41331235,0.18258524,-0.23161632
attn_1_5_outputs,5,0.24728091,0.4856511,0.594313,-2.091695,2.0573926,-0.7336284
attn_1_5_outputs,6,0.22816464,0.27136412,0.5195222,-0.6798148,-5.8662944,-0.39577755
attn_1_5_outputs,7,0.12520222,0.22135924,0.7278236,-0.72360736,-2.3794682,-8.568124
attn_1_6_outputs,0,1.9565225,1.992841,-4.4924707,-1.4354049,-3.0359225,0.4108184
attn_1_6_outputs,1,-1.5355867,-0.8102765,0.39700136,0.9479792,5.72563,0.6617512
attn_1_6_outputs,2,0.41150886,-0.04871347,0.3481352,-0.41566238,-1.5758594,0.090029344
attn_1_6_outputs,3,-0.26570487,-0.021473275,0.59041107,0.2970626,0.9045193,-0.53055835
attn_1_6_outputs,4,-0.08541798,0.34799036,0.76428026,0.69276136,0.68592554,0.5443473
attn_1_6_outputs,5,-0.15378077,-0.39172465,0.40749076,0.5314768,0.8851225,0.4981952
attn_1_6_outputs,6,-0.008769149,0.07061772,0.37057376,-0.65310204,-3.0223656,-1.3392087
attn_1_6_outputs,7,0.05816845,0.25007242,0.451493,-0.22949804,-2.5058506,-7.4911504
attn_1_7_outputs,0,1.0521657,1.0074128,0.4629282,-2.0979998,-18.004267,-24.830841
attn_1_7_outputs,1,-0.69483924,-0.15778616,-0.24732345,1.2288451,3.7539797,5.79142
attn_1_7_outputs,2,0.77858096,0.7278021,-0.42153692,-1.9385699,-2.943345,-4.5531664
attn_1_7_outputs,3,-0.8458971,-0.20327307,-0.19063051,1.0987198,2.5467868,5.116683
attn_1_7_outputs,4,-1.1233516,-0.48994657,-0.16887976,1.1522313,3.7934172,3.263776
attn_1_7_outputs,5,-0.67034286,-0.18595126,-0.15823115,0.8096825,2.814898,0.057380866
attn_1_7_outputs,6,-0.6525574,0.014802525,0.19252089,0.6953846,2.502593,-0.91043985
attn_1_7_outputs,7,-1.4201316,-0.9218594,-1.0736845,0.35498753,3.4924703,3.2490141
mlp_0_0_outputs,0,1.1609898,0.8572743,0.17748477,-1.1314943,-6.752202,-10.017833
mlp_0_0_outputs,1,0.15575539,0.20244724,-0.35808483,-0.16532898,0.6535445,0.9125833
mlp_0_0_outputs,2,0.47592318,0.36881614,0.17811424,1.0390111,2.0942621,-0.4145827
mlp_0_0_outputs,3,0.25076088,0.31756687,-0.023319967,0.14243905,-1.2209392,0.2947385
mlp_0_0_outputs,4,0.47051138,0.46116182,0.02881971,0.85385466,0.9588443,1.3177751
mlp_0_0_outputs,5,0.32828408,0.40574586,-0.122866556,0.5059813,1.4154013,-0.2793758
mlp_0_0_outputs,6,0.5967165,0.52169484,-0.30867448,-1.2104449,-1.8745283,-0.5121046
mlp_0_0_outputs,7,-0.01277371,-0.049345464,-0.40406802,-0.19294757,0.47392055,1.439194
mlp_0_1_outputs,0,-0.08253285,-0.067494564,-0.19991755,0.16750817,-0.049758844,-0.8476353
mlp_0_1_outputs,1,-0.11760705,-0.03612627,-0.060250442,0.016026465,0.3172396,-0.77812344
mlp_0_1_outputs,2,-0.23936105,0.00013453455,-0.25397772,0.266447,-0.7957484,-0.7953196
mlp_0_1_outputs,3,-0.34299704,-0.27820325,-0.29482996,-0.2864089,-0.09556553,-1.2629421
mlp_0_1_outputs,4,-0.046016607,-0.09005855,-0.03415023,0.054577596,0.10380587,-0.7631649
mlp_0_1_outputs,5,0.11538427,0.1287231,0.07588654,0.18756787,0.20409665,-0.34829876
mlp_0_1_outputs,6,-0.47443476,-0.08539611,1.0217084,0.75428355,-0.06510375,-1.7716645
mlp_0_1_outputs,7,-0.033662606,0.12398085,0.002802741,0.3118455,0.17265546,-0.92884475
mlp_1_0_outputs,0,0.4595583,0.4955876,0.6625086,-0.30223736,0.2889277,-0.061707802
mlp_1_0_outputs,1,0.12989043,0.2530865,-0.45482454,0.09733099,-2.926608,-0.32248923
mlp_1_0_outputs,2,0.39325613,0.04932119,-1.1195576,0.22859044,-2.9229677,-0.25452712
mlp_1_0_outputs,3,0.24489257,0.1534004,0.32450718,-1.9852786,1.6052507,-1.8961568
mlp_1_0_outputs,4,0.20856915,0.21389735,-0.08142016,-0.114698805,-0.5788899,-0.084191464
mlp_1_0_outputs,5,0.12847619,0.1381205,-0.26468688,-0.016217208,-2.0436559,-0.60245377
mlp_1_0_outputs,6,-0.028417686,0.23583707,0.23858303,0.30995926,-3.0726144,-0.10256881
mlp_1_0_outputs,7,0.5175587,0.68217903,0.32498646,0.61085075,-2.3507814,0.5371939
mlp_1_1_outputs,0,0.45030883,0.4058445,0.3972978,0.28020045,-1.9913944,-0.055018645
mlp_1_1_outputs,1,0.06336323,-0.28932407,0.37485504,0.25855744,1.9029437,0.020531047
mlp_1_1_outputs,2,0.2779919,0.2499768,0.25289297,0.027555976,-1.9493699,-0.6084653
mlp_1_1_outputs,3,0.44192633,0.19696143,0.19470312,0.69076985,-0.7434745,-0.1210918
mlp_1_1_outputs,4,0.17183223,0.33310947,0.48470286,-0.5919067,-0.8927419,-4.1806927
mlp_1_1_outputs,5,0.28217918,0.17491329,0.3567068,-0.2545964,0.08058698,-1.1707318
mlp_1_1_outputs,6,0.61476415,0.21169148,0.082067244,0.027210148,-1.1794713,-0.701124
mlp_1_1_outputs,7,0.3328481,0.29813948,0.22811241,0.22678965,-1.0541234,-0.81014025
ones,_,0.10541394,-0.19712536,-0.008858198,0.58838516,-0.5383584,-1.4631578
positions,0,0.19978875,0.010510841,0.4988928,-0.2598821,0.5445723,0.21678215
positions,1,0.3364794,0.13280986,0.08279502,-0.20376179,-0.1359483,-0.47131416
positions,2,0.31971225,0.13252828,0.15829775,-0.1309726,0.040779695,-0.4894607
positions,3,0.33861408,0.18808766,0.4626563,0.1264201,-0.29951236,-0.6159285
positions,4,0.3720952,0.11060386,0.2411623,-0.2241606,-0.03696618,-0.99401474
positions,5,0.5989128,0.5286751,0.9225798,0.64178616,-0.24038357,-0.25051555
positions,6,0.106012695,-0.12985282,-0.019405967,-0.54309326,0.049963955,-1.0877988
positions,7,0.29646856,0.07841838,-0.053576685,-0.04623412,-1.2492927,-4.2103653
tokens,0,0.48456392,0.24045633,0.19029397,0.19912028,-1.2815764,-1.3616517
tokens,1,-0.2905613,-0.39902,-0.16171901,0.9161303,4.4462442,1.6672348
tokens,2,0.3137043,0.1140619,0.2546798,-0.3095032,-6.1083627,-2.055795
tokens,3,-0.2670635,-0.5329313,-0.38812107,-0.55894905,2.1805744,5.809422
tokens,4,0.6327516,0.46054688,0.57450086,0.070067786,-0.3341348,-7.701936
tokens,5,0.37370822,0.29103732,0.50220376,0.12526518,0.14494029,-0.0737996
tokens,,-0.05186225,0.08119709,0.30115068,0.07776928,-0.2788572,-0.26285934
tokens,,-0.3594653,-0.10468927,0.24319318,0.77997345,-0.14584279,-1.0055608
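The generated programs above call two helper primitives, `select_closest` and `aggregate`, and score each feature by looking up rows of a `(feature, value)`-indexed `classifier_weights` table loaded from CSV files like the one above. Those helpers are defined elsewhere in the repository; the sketch below only illustrates the semantics the call sites appear to assume. The fall-back-to-position-0 behaviour and the toy weight values are assumptions made for illustration, not the repository's exact implementation.

```python
import pandas as pd


def select_closest(keys, queries, predicate):
    # One-hot attention pattern: query i attends to the closest position j
    # whose key satisfies predicate(queries[i], keys[j]).  With no matching
    # key it falls back to position 0 (an assumed default).
    pattern = []
    for i, q in enumerate(queries):
        matches = [j for j, k in enumerate(keys) if predicate(q, k)]
        j = min(matches, key=lambda m: abs(i - m)) if matches else 0
        pattern.append([col == j for col in range(len(keys))])
    return pattern


def aggregate(pattern, values):
    # Each position copies the single value it attends to.
    return [[v for sel, v in zip(row, values) if sel][0] for row in pattern]


tokens = ["a", "b", "a", "b"]
positions = list(range(len(tokens)))

# Attend to the previous position; position 0 has no match and falls back to 0.
pattern = select_closest(positions, positions, lambda q, k: k == q - 1)
shifted = aggregate(pattern, tokens)
print(shifted)  # ['a', 'a', 'b', 'a']

# Feature scores are rows of a (feature, value)-indexed weights table,
# analogous to the *_weights.csv files above (toy numbers here).
weights = pd.DataFrame(
    [[0.5, -0.5], [-1.0, 2.0]],
    index=pd.MultiIndex.from_tuples([("tokens", "a"), ("tokens", "b")]),
    columns=["1", "2"],
)
token_scores = weights.loc[[("tokens", t) for t in shifted]]
print(token_scores.to_numpy().argmax(-1))  # predicted class index per position
```

In the full programs, the per-feature score tables are summed position-wise (`feature_logits.groupby(level=0).sum(...)`) before taking the argmax over the class columns.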
--------------------------------------------------------------------------------
/programs/rasp_categorical_only/dyck2/dyck2_weights.csv:
--------------------------------------------------------------------------------
feature,value,F,P,T
attn_0_0_outputs,(,-0.09839721,0.072768934,-1.9487087
attn_0_0_outputs,),0.08454172,-0.8850317,-0.31351435
attn_0_0_outputs,,0.64131826,0.51406443,-1.0410135
attn_0_0_outputs,,-0.6800124,0.23581685,0.8454185
attn_0_0_outputs,_,-0.64708936,0.05047905,-0.22396585
attn_0_0_outputs,_,-0.60640234,-0.123617664,-0.3515302
attn_0_0_outputs,_,0.33327684,0.7507405,0.3113162
attn_0_0_outputs,_,-0.39292628,0.22216852,0.22758712
attn_0_0_outputs,_,-0.10046251,0.024066549,0.026169369
attn_0_0_outputs,_,0.55228436,0.15485866,0.10148037
attn_0_0_outputs,_,0.3600001,0.4491899,0.22103375
attn_0_0_outputs,_,-0.52200395,0.7784305,0.113123655
attn_0_0_outputs,_,-0.59014857,0.64098155,1.3751171
attn_0_0_outputs,_,0.14060214,0.8747294,-0.85699344
attn_0_0_outputs,{,-0.039841566,0.13473256,-2.0053372
attn_0_0_outputs,},0.01718044,-0.9023884,-0.3925461
attn_0_1_outputs,(,0.24663934,0.05382682,-1.4295913
attn_0_1_outputs,),0.44859663,0.11123296,0.21891037
attn_0_1_outputs,,-0.12616456,0.80257934,0.82834697
attn_0_1_outputs,,-3.839294,1.9205002,4.9068
attn_0_1_outputs,_,0.016972326,0.6989716,0.30263668
attn_0_1_outputs,_,0.06236384,0.3147991,0.21376273
attn_0_1_outputs,_,0.05235185,0.2560973,-0.39419228
attn_0_1_outputs,_,-0.70403737,-0.11916539,-0.8784481
attn_0_1_outputs,_,0.180807,0.21920295,0.40988365
attn_0_1_outputs,_,-0.38003322,0.43255183,0.30699322
attn_0_1_outputs,_,-0.17761539,-0.16689041,-0.07989258
attn_0_1_outputs,_,-0.2789575,0.21954398,-0.5772647
attn_0_1_outputs,_,-0.7351252,0.29486123,1.1601452
attn_0_1_outputs,_,0.39496288,0.80411744,-1.136259
attn_0_1_outputs,{,0.25580674,-0.020562408,-1.5278777
attn_0_1_outputs,},0.16623977,-0.2624999,-0.06795137
attn_0_2_outputs,(,0.61922807,-1.373114,2.2359898
attn_0_2_outputs,),0.5194672,-2.4147916,0.6145038
attn_0_2_outputs,,0.56023663,-0.40202478,0.661288
attn_0_2_outputs,,0.08606467,1.1087501,-0.7814447
attn_0_2_outputs,_,0.28296587,0.18189661,-0.49550337
attn_0_2_outputs,_,0.57444793,0.557668,0.13670582
attn_0_2_outputs,_,1.248424,-0.001178341,-0.012084891
attn_0_2_outputs,_,0.35445789,-0.5528627,-0.7319502
attn_0_2_outputs,_,0.14257619,-0.5120995,-0.31158966
attn_0_2_outputs,_,0.41282007,0.03288246,-0.95664227
attn_0_2_outputs,_,0.34463474,-0.52674174,-0.4970775
attn_0_2_outputs,_,0.844124,-0.25249133,-0.31938592
attn_0_2_outputs,_,0.66700053,0.30742458,0.25866547
attn_0_2_outputs,_,1.1990013,0.85531867,-1.7240785
attn_0_2_outputs,{,0.12513652,1.2699282,-0.7781075
attn_0_2_outputs,},0.9145521,-2.0735366,1.2254958
attn_0_3_outputs,(,1.4417167,-0.64466053,-1.18381
attn_0_3_outputs,),-0.6638932,1.0632501,0.49708265
attn_0_3_outputs,,0.45009208,0.54130626,0.25542018
attn_0_3_outputs,,0.70722836,-0.44485033,1.8558304
attn_0_3_outputs,_,-0.3365018,-0.4629079,-0.17494892
attn_0_3_outputs,_,1.1892766,0.6559469,0.19583496
attn_0_3_outputs,_,0.77296907,0.44271478,-0.2764062
attn_0_3_outputs,_,0.29536107,0.55541724,0.8952327
attn_0_3_outputs,_,-0.1426614,0.44414833,0.44125217
attn_0_3_outputs,_,-0.86521447,-0.82092905,0.21177834
attn_0_3_outputs,_,0.032857087,0.3982487,0.5682109
attn_0_3_outputs,_,-0.33331037,0.38984817,0.6000708
attn_0_3_outputs,_,-1.2023804,-0.24906081,0.23746833
attn_0_3_outputs,_,0.20427324,0.14984377,-0.69274366
attn_0_3_outputs,{,1.219885,-0.54653597,-1.0410267
attn_0_3_outputs,},-0.7413577,0.6296828,0.09741894
attn_1_0_outputs,0,0.22650789,-0.18543506,-0.58758974
attn_1_0_outputs,1,0.13576485,1.0912529,-0.5784185
attn_1_0_outputs,10,-0.830527,0.81066996,0.042711593
attn_1_0_outputs,11,0.17748582,-0.5338113,-1.0006839
attn_1_0_outputs,12,0.028687602,0.12396941,-5.01376
attn_1_0_outputs,13,-0.3463128,1.461672,-0.94459385
attn_1_0_outputs,14,1.6825908,-1.798565,-0.6786788
attn_1_0_outputs,15,0.6389121,-0.9041575,-0.96575123
attn_1_0_outputs,2,0.25002298,-1.8573852,1.413765
attn_1_0_outputs,3,0.22070499,-1.2486526,0.710387
attn_1_0_outputs,4,-0.6571064,0.35194886,-1.3422754
attn_1_0_outputs,5,1.9105325,-1.0667194,0.36192167
attn_1_0_outputs,6,0.047967866,-0.09409498,-0.6467464
attn_1_0_outputs,7,-0.86700124,1.1282116,0.33472228
attn_1_0_outputs,8,-0.6687376,1.0361387,-0.20685646
attn_1_0_outputs,9,-0.7377063,1.0348144,-0.33758923
attn_1_1_outputs,0,-0.5052586,0.5057064,0.29024526
attn_1_1_outputs,1,-0.7648216,1.4227237,-0.258629
attn_1_1_outputs,10,0.41181248,-0.6207077,-0.9900048
attn_1_1_outputs,11,0.077770524,-0.9213925,-0.05973171
attn_1_1_outputs,12,0.4023454,1.3574238,-3.6971736
attn_1_1_outputs,13,-2.6055536,0.43831292,1.9638561
attn_1_1_outputs,14,1.7760623,-1.4472873,-0.7355287
attn_1_1_outputs,15,-1.3515613,0.17346454,0.6269158
attn_1_1_outputs,2,1.4712042,-1.8864663,1.1948589
attn_1_1_outputs,3,0.53816074,-0.7384589,0.7789173
attn_1_1_outputs,4,-0.8784697,1.489651,-0.40728238
attn_1_1_outputs,5,2.2822268,-1.2395031,0.87639683
attn_1_1_outputs,6,-0.64847505,-0.52782285,-0.6055985
attn_1_1_outputs,7,0.035765603,1.5342312,0.83606446
attn_1_1_outputs,8,0.82426596,-0.4623573,-0.52451324
attn_1_1_outputs,9,0.26120463,-0.1682919,-0.74863493
attn_1_2_outputs,0,0.2741086,0.8019985,0.09451536
attn_1_2_outputs,1,0.037871893,1.1881831,-0.593388
attn_1_2_outputs,10,-0.10018027,-0.28660047,0.33032972
attn_1_2_outputs,11,-0.15546942,0.14354461,-0.09335824
attn_1_2_outputs,12,-0.75883734,1.1132323,-1.8390666 103 | attn_1_2_outputs,13,-0.91976297,0.20129623,0.7091432 104 | attn_1_2_outputs,14,2.1666486,-1.2059776,-1.1775556 105 | attn_1_2_outputs,15,1.2957777,0.11005765,0.27522662 106 | attn_1_2_outputs,2,2.6221218,-0.79460573,-0.4321904 107 | attn_1_2_outputs,3,-0.9202741,0.67877096,-1.6474462 108 | attn_1_2_outputs,4,-0.80796766,0.73023415,-0.8312531 109 | attn_1_2_outputs,5,1.5846504,-1.2096406,0.7192363 110 | attn_1_2_outputs,6,0.5472782,-1.2913166,-0.11775407 111 | attn_1_2_outputs,7,0.8987451,-1.1431016,0.027534375 112 | attn_1_2_outputs,8,-0.6447031,0.3921986,0.90193 113 | attn_1_2_outputs,9,-0.6580996,0.2183337,0.47104686 114 | attn_1_3_outputs,0,0.7974158,0.076014474,-0.12770599 115 | attn_1_3_outputs,1,0.08960959,0.38123062,-1.230019 116 | attn_1_3_outputs,10,0.46521083,0.018465715,-0.1728664 117 | attn_1_3_outputs,11,0.6438368,-0.79515433,-0.28863785 118 | attn_1_3_outputs,12,0.53023314,-1.1950611,-1.609009 119 | attn_1_3_outputs,13,0.10482195,1.6155713,0.0057584657 120 | attn_1_3_outputs,14,0.95015556,-1.4626787,0.5756642 121 | attn_1_3_outputs,15,1.1534876,0.57546484,-1.156569 122 | attn_1_3_outputs,2,0.014377812,-0.7606671,1.6436303 123 | attn_1_3_outputs,3,0.43214664,-0.30850783,0.011211452 124 | attn_1_3_outputs,4,0.73773974,1.1139972,-0.69320714 125 | attn_1_3_outputs,5,0.5149078,-0.9359474,1.22139 126 | attn_1_3_outputs,6,0.6841033,-0.39748695,-0.08539801 127 | attn_1_3_outputs,7,-0.43467954,1.2192235,-0.83311576 128 | attn_1_3_outputs,8,-0.5582946,0.3580021,0.18355377 129 | attn_1_3_outputs,9,-0.15669319,1.2627395,-0.6961721 130 | attn_2_0_outputs,0,-0.40245613,0.24030568,0.30186835 131 | attn_2_0_outputs,1,-1.3748924,0.71387535,-0.39306825 132 | attn_2_0_outputs,10,-0.59452015,-0.4597474,2.018241 133 | attn_2_0_outputs,11,0.2047717,0.003245883,-0.73424035 134 | attn_2_0_outputs,12,-0.720487,1.5443448,-1.6471566 135 | attn_2_0_outputs,13,0.37173122,-0.028756775,0.36360672 136 | 
attn_2_0_outputs,14,0.76262903,-0.89649415,-0.5565152 137 | attn_2_0_outputs,15,0.025871454,0.16201858,0.56903595 138 | attn_2_0_outputs,2,2.0154552,-2.0844448,-0.10086185 139 | attn_2_0_outputs,3,0.73500365,-0.98438346,0.101617634 140 | attn_2_0_outputs,4,-1.2039843,2.2337372,0.16640835 141 | attn_2_0_outputs,5,3.293801,-2.8079114,-0.77340686 142 | attn_2_0_outputs,6,-0.7499077,0.33472747,-0.1948529 143 | attn_2_0_outputs,7,-0.119230315,-0.78161407,-0.29735667 144 | attn_2_0_outputs,8,-0.21677648,-0.98522747,-0.09130101 145 | attn_2_0_outputs,9,0.52573526,-0.1827654,0.5464673 146 | attn_2_1_outputs,0,-1.387743,0.9506943,0.5479685 147 | attn_2_1_outputs,1,-0.37947392,0.5295445,0.2626229 148 | attn_2_1_outputs,10,0.13699992,-0.0180703,-1.8421631 149 | attn_2_1_outputs,11,-0.43704024,-0.2988407,-0.862835 150 | attn_2_1_outputs,12,-1.346792,1.172357,-0.39379665 151 | attn_2_1_outputs,13,2.1580062,-1.5489818,-1.5334694 152 | attn_2_1_outputs,14,0.21818756,0.87238884,0.45327023 153 | attn_2_1_outputs,15,1.447946,-0.76852626,-0.29249018 154 | attn_2_1_outputs,2,0.34594288,0.8285091,0.060779355 155 | attn_2_1_outputs,3,-0.054519173,0.14733565,0.14937484 156 | attn_2_1_outputs,4,-0.39207256,1.3152301,-0.71527904 157 | attn_2_1_outputs,5,7.0318747,-4.8902483,-5.9904876 158 | attn_2_1_outputs,6,0.96370083,0.029464895,-0.1656916 159 | attn_2_1_outputs,7,0.58060133,-0.9354646,0.65154225 160 | attn_2_1_outputs,8,0.637,-0.49399912,-0.21180116 161 | attn_2_1_outputs,9,0.19271861,-0.01355111,-1.29539 162 | attn_2_2_outputs,0,-1.387934,0.39397144,0.16978894 163 | attn_2_2_outputs,1,-0.6932238,1.0597913,0.08682805 164 | attn_2_2_outputs,10,-0.17581192,0.5762519,0.04119891 165 | attn_2_2_outputs,11,-0.8586262,-0.40480924,-0.4132628 166 | attn_2_2_outputs,12,-1.6009113,2.973149,-1.1810892 167 | attn_2_2_outputs,13,-0.34387457,-0.35372856,0.28478163 168 | attn_2_2_outputs,14,0.35553423,0.7332589,-0.1989791 169 | attn_2_2_outputs,15,0.69653744,-0.09478814,-0.54141074 170 | 
attn_2_2_outputs,2,1.1788943,-0.65940076,-0.20502281 171 | attn_2_2_outputs,3,0.65942603,-0.32618254,-0.107519 172 | attn_2_2_outputs,4,-1.3350819,0.6971251,0.092566006 173 | attn_2_2_outputs,5,0.8389036,0.28690335,0.43394306 174 | attn_2_2_outputs,6,0.018443262,0.20905562,0.04728933 175 | attn_2_2_outputs,7,-0.85211277,0.10939689,1.3427547 176 | attn_2_2_outputs,8,0.25118992,0.59912837,-0.42855418 177 | attn_2_2_outputs,9,0.3361104,-0.14676784,-0.7411014 178 | attn_2_3_outputs,0,-0.5618858,0.5809453,1.0021099 179 | attn_2_3_outputs,1,-0.5739619,1.7692959,-0.378794 180 | attn_2_3_outputs,10,0.035676032,-0.6341216,0.28642064 181 | attn_2_3_outputs,11,0.07306448,0.17111002,0.2374836 182 | attn_2_3_outputs,12,-1.188873,1.1038005,-1.1276703 183 | attn_2_3_outputs,13,0.54361314,-1.0640161,0.84797543 184 | attn_2_3_outputs,14,1.3631084,-0.56192636,0.4859162 185 | attn_2_3_outputs,15,0.48447692,-0.37444183,0.39286935 186 | attn_2_3_outputs,2,8.552915,-5.3725734,-6.36959 187 | attn_2_3_outputs,3,-0.11614109,0.30205145,0.12240872 188 | attn_2_3_outputs,4,-1.6696448,1.1503662,-0.46547452 189 | attn_2_3_outputs,5,3.2380612,-1.0949727,0.22095455 190 | attn_2_3_outputs,6,0.10913107,0.03129147,-1.241605 191 | attn_2_3_outputs,7,0.0360893,-1.0509039,0.4309466 192 | attn_2_3_outputs,8,2.3759365,-1.4136782,-1.8956783 193 | attn_2_3_outputs,9,2.515315,-1.9465212,-1.5227278 194 | mlp_0_0_outputs,0,-0.20890403,0.14091475,0.7584955 195 | mlp_0_0_outputs,1,0.3021631,0.1919344,0.1126692 196 | mlp_0_0_outputs,10,-2.756637,1.5794932,1.9001974 197 | mlp_0_0_outputs,11,0.35167038,0.067295164,0.41860104 198 | mlp_0_0_outputs,12,-0.013727432,0.9100293,-2.461872 199 | mlp_0_0_outputs,13,0.46883824,0.21986999,0.37993425 200 | mlp_0_0_outputs,14,-0.018150175,0.22210702,-0.48269686 201 | mlp_0_0_outputs,15,1.2375835,0.29053566,-0.856558 202 | mlp_0_0_outputs,2,1.1214304,-1.9128251,0.6156005 203 | mlp_0_0_outputs,3,0.567998,-0.43510225,0.60242385 204 | 
mlp_0_0_outputs,4,0.15451127,-0.5385739,-0.78529364 205 | mlp_0_0_outputs,5,1.6117405,-2.094646,-0.42933798 206 | mlp_0_0_outputs,6,0.87502843,0.53566813,-0.9268592 207 | mlp_0_0_outputs,7,0.31717533,-0.27278233,-0.0015997723 208 | mlp_0_0_outputs,8,1.5352674,0.55484116,-6.1980653 209 | mlp_0_0_outputs,9,0.97670096,-0.14376327,-6.9078407 210 | mlp_0_1_outputs,0,0.9233807,0.2854663,-0.8436871 211 | mlp_0_1_outputs,1,0.24584311,0.2944399,0.66519386 212 | mlp_0_1_outputs,10,0.22019076,0.81191224,0.5874235 213 | mlp_0_1_outputs,11,0.11129518,-0.08004127,-0.122246675 214 | mlp_0_1_outputs,12,0.32788306,0.6816243,-2.6723564 215 | mlp_0_1_outputs,13,-0.3635576,0.86900264,-0.9099431 216 | mlp_0_1_outputs,14,0.5210501,-0.018928718,-0.46168962 217 | mlp_0_1_outputs,15,0.93055755,-0.023645593,-0.7769493 218 | mlp_0_1_outputs,2,-1.2583181,1.8995489,0.5895376 219 | mlp_0_1_outputs,3,1.0925126,-1.0588683,-1.3893077 220 | mlp_0_1_outputs,4,0.53683454,0.7541305,0.17249009 221 | mlp_0_1_outputs,5,0.5254129,0.41082186,0.31906855 222 | mlp_0_1_outputs,6,-0.30498037,-1.0835053,0.2178056 223 | mlp_0_1_outputs,7,-0.6612743,0.8853988,0.6170607 224 | mlp_0_1_outputs,8,-0.46027464,9.3793526e-05,-0.13259079 225 | mlp_0_1_outputs,9,-0.7891113,1.0600111,0.6583635 226 | mlp_1_0_outputs,0,0.34241092,0.2945056,-0.1292986 227 | mlp_1_0_outputs,1,0.47135493,0.46016377,-0.43053067 228 | mlp_1_0_outputs,10,0.2581985,0.5706344,0.19086708 229 | mlp_1_0_outputs,11,0.26166466,0.90268534,-0.17562716 230 | mlp_1_0_outputs,12,-0.17449586,0.33637732,0.16877215 231 | mlp_1_0_outputs,13,-1.4794242,1.719357,0.39456385 232 | mlp_1_0_outputs,14,0.77961195,0.008924905,-0.09428507 233 | mlp_1_0_outputs,15,0.08689608,0.9701745,0.09053159 234 | mlp_1_0_outputs,2,1.1217159,-0.7634738,-1.2014244 235 | mlp_1_0_outputs,3,0.39370438,-0.41021878,-0.7921828 236 | mlp_1_0_outputs,4,-0.4663107,1.2819383,0.4163665 237 | mlp_1_0_outputs,5,0.49742854,0.008245769,-0.45130947 238 | 
mlp_1_0_outputs,6,0.14249237,-0.102100946,-0.5098823 239 | mlp_1_0_outputs,7,0.30476442,0.62488115,-1.3264716 240 | mlp_1_0_outputs,8,0.08670288,-0.42941815,-0.7135336 241 | mlp_1_0_outputs,9,0.26799846,-0.09152536,-1.3237554 242 | mlp_1_1_outputs,0,0.26805434,0.23167966,0.4149677 243 | mlp_1_1_outputs,1,0.53618395,-0.5244912,-2.6226473 244 | mlp_1_1_outputs,10,2.2270403,-0.6349289,-3.1257336 245 | mlp_1_1_outputs,11,-0.14030638,-0.15512456,0.7863296 246 | mlp_1_1_outputs,12,-0.5041155,0.5484189,-0.19427307 247 | mlp_1_1_outputs,13,-0.26450726,-0.33360416,0.79678315 248 | mlp_1_1_outputs,14,-0.9820768,0.5201349,0.059233483 249 | mlp_1_1_outputs,15,-0.7036932,0.4160271,-0.47682777 250 | mlp_1_1_outputs,2,-0.39756742,0.5957785,0.6356138 251 | mlp_1_1_outputs,3,-0.09534266,-0.17529783,0.72164303 252 | mlp_1_1_outputs,4,0.88351613,0.77100885,0.38341337 253 | mlp_1_1_outputs,5,-0.011537257,-0.18188691,0.7026197 254 | mlp_1_1_outputs,6,-0.6994262,0.019658938,-0.47276354 255 | mlp_1_1_outputs,7,0.57141864,-0.04391361,0.055888955 256 | mlp_1_1_outputs,8,0.043691415,0.07147635,0.50050634 257 | mlp_1_1_outputs,9,-0.35295078,0.23593843,0.7010527 258 | mlp_2_0_outputs,0,1.8654248,-0.76327336,-2.1851802 259 | mlp_2_0_outputs,1,2.4732018,2.0428228,-2.8090878 260 | mlp_2_0_outputs,10,0.50512856,-0.4752733,-0.42407128 261 | mlp_2_0_outputs,11,-0.2162885,-0.041211855,-0.13339278 262 | mlp_2_0_outputs,12,0.2361639,-0.9290504,0.4851906 263 | mlp_2_0_outputs,13,0.34775272,-0.14868683,0.1458214 264 | mlp_2_0_outputs,14,0.28508988,0.33885202,0.37766585 265 | mlp_2_0_outputs,15,-0.08315753,0.28621098,0.066261254 266 | mlp_2_0_outputs,2,0.5492595,-0.30825824,-0.66436327 267 | mlp_2_0_outputs,3,0.27374202,0.40835023,0.27150822 268 | mlp_2_0_outputs,4,1.1041646,0.44970506,-0.008170099 269 | mlp_2_0_outputs,5,0.14885116,0.48616803,-0.02961562 270 | mlp_2_0_outputs,6,-0.05234915,0.3683034,-2.0102525 271 | mlp_2_0_outputs,7,-2.398438,0.019524893,8.194051 272 | 
mlp_2_0_outputs,8,0.25424272,0.4785088,-0.41905257 273 | mlp_2_0_outputs,9,-0.14737429,-0.1506395,0.16152707 274 | mlp_2_1_outputs,0,3.7675667,-1.4037446,-3.882887 275 | mlp_2_1_outputs,1,0.38408276,-0.20874391,-0.79680127 276 | mlp_2_1_outputs,10,0.82168883,0.15445253,-0.35908702 277 | mlp_2_1_outputs,11,0.40715832,-0.078189924,-0.23746909 278 | mlp_2_1_outputs,12,0.19456168,0.5811429,0.19617397 279 | mlp_2_1_outputs,13,-0.18762338,-0.037409503,-0.88011366 280 | mlp_2_1_outputs,14,-0.48191726,0.7540797,0.049526248 281 | mlp_2_1_outputs,15,0.704309,0.12997411,0.05762332 282 | mlp_2_1_outputs,2,-1.4413515,1.5535452,0.43207467 283 | mlp_2_1_outputs,3,0.25547877,-0.8476021,-0.36438197 284 | mlp_2_1_outputs,4,0.67018205,-0.10163318,-0.2500151 285 | mlp_2_1_outputs,5,-0.094072334,-0.73808116,-0.67700255 286 | mlp_2_1_outputs,6,0.32068485,0.49894604,-0.4640106 287 | mlp_2_1_outputs,7,0.6843552,-0.1548457,-0.12560412 288 | mlp_2_1_outputs,8,-0.38264602,-0.3764761,-0.9299223 289 | mlp_2_1_outputs,9,0.104728,-0.70314866,-0.6734416 290 | ones,_,0.25881597,-0.46701866,0.16698436 291 | positions,0,-0.2762641,-0.17773171,-1.017329 292 | positions,1,0.26486596,0.65079963,-2.5160186 293 | positions,10,0.14841577,-0.47636732,1.5774705 294 | positions,11,1.761886,1.1001859,-10.36295 295 | positions,12,-0.68251526,-2.187309,2.290777 296 | positions,13,1.9318004,1.8703078,-8.895537 297 | positions,14,-1.3542655,-2.2336614,3.640843 298 | positions,15,1.3797264,1.0626522,-7.66056 299 | positions,2,-4.7400956,-4.6051226,10.471129 300 | positions,3,2.772535,2.4986095,-10.479663 301 | positions,4,-3.1975293,-2.1136668,8.509926 302 | positions,5,0.9007016,2.1843433,-7.800698 303 | positions,6,-2.446595,-0.2883177,4.8488107 304 | positions,7,1.4072119,2.093496,-6.3489823 305 | positions,8,-0.5884627,-0.3837479,1.0950348 306 | positions,9,1.4770374,2.9142487,-7.4769826 307 | tokens,(,1.0545344,1.422954,-4.159084 308 | tokens,),-0.44501638,-0.9425237,3.0321198 309 | 
tokens,,0.24174905,-0.15760137,0.25735503 310 | tokens,,1.4156133,-0.7207831,-0.14946787 311 | tokens,_,0.32973674,0.1480116,0.40687805 312 | tokens,_,0.4269608,0.18007647,0.0591342 313 | tokens,_,0.012144065,0.24502037,-0.37508482 314 | tokens,_,1.0130444,0.2379365,0.7822351 315 | tokens,_,0.24589725,0.05158651,-0.009494092 316 | tokens,_,0.13942418,0.16681431,0.54903275 317 | tokens,_,-0.2541933,0.61024016,0.46229398 318 | tokens,_,-0.67985225,0.38225034,-0.016558083 319 | tokens,_,1.4591968,-0.29021746,-0.55709064 320 | tokens,_,0.14226177,0.10692842,0.23002318 321 | tokens,{,0.9597864,1.4787383,-3.468003 322 | tokens,},-0.43541923,-1.1256342,2.860657 323 | -------------------------------------------------------------------------------- /programs/rasp_categorical_only/sort/sort.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | def select_closest(keys, queries, predicate):  # hard attention: each query attends to exactly one key 6 | scores = [[False for _ in keys] for _ in queries] 7 | for i, q in enumerate(queries): 8 | matches = [j for j, k in enumerate(keys) if predicate(q, k)] 9 | if not matches:  # no key satisfies the predicate: fall back to position 0 10 | scores[i][0] = True 11 | else: 12 | j = min(matches, key=lambda j: len(matches) if j == i else abs(i - j)) 13 | scores[i][j] = True 14 | return scores 15 | 16 | 17 | def aggregate(attention, values):  # copy the value at the attended position for each query 18 | return [[v for a, v in zip(attn, values) if a][0] for attn in attention] 19 | 20 | 21 | def run(tokens): 22 | # classifier weights ########################################## 23 | classifier_weights = pd.read_csv( 24 | "programs/rasp_categorical_only/sort/sort_weights.csv", 25 | index_col=[0, 1], 26 | dtype={"feature": str}, 27 | ) 28 | # inputs ##################################################### 29 | token_scores = classifier_weights.loc[[("tokens", str(v)) for v in tokens]] 30 | 31 | positions = list(range(len(tokens))) 32 | position_scores = classifier_weights.loc[[("positions", str(v)) for v in positions]] 33 | 
34 | ones = [1 for _ in range(len(tokens))] 35 | one_scores = classifier_weights.loc[[("ones", "_") for v in ones]].mul(ones, axis=0) 36 | 37 | # attn_0_0 #################################################### 38 | def predicate_0_0(q_token, k_token): 39 | if q_token in {"", "0"}: 40 | return k_token == "0" 41 | elif q_token in {"1"}: 42 | return k_token == "1" 43 | elif q_token in {"2"}: 44 | return k_token == "2" 45 | elif q_token in {"3"}: 46 | return k_token == "3" 47 | elif q_token in {"4", ""}: 48 | return k_token == "4" 49 | 50 | attn_0_0_pattern = select_closest(tokens, tokens, predicate_0_0) 51 | attn_0_0_outputs = aggregate(attn_0_0_pattern, tokens) 52 | attn_0_0_output_scores = classifier_weights.loc[ 53 | [("attn_0_0_outputs", str(v)) for v in attn_0_0_outputs] 54 | ] 55 | 56 | # attn_0_1 #################################################### 57 | def predicate_0_1(q_position, k_position): 58 | if q_position in {0}: 59 | return k_position == 1 60 | elif q_position in {1, 6}: 61 | return k_position == 2 62 | elif q_position in {2}: 63 | return k_position == 3 64 | elif q_position in {3, 5}: 65 | return k_position == 4 66 | elif q_position in {4, 7}: 67 | return k_position == 6 68 | 69 | attn_0_1_pattern = select_closest(positions, positions, predicate_0_1) 70 | attn_0_1_outputs = aggregate(attn_0_1_pattern, tokens) 71 | attn_0_1_output_scores = classifier_weights.loc[ 72 | [("attn_0_1_outputs", str(v)) for v in attn_0_1_outputs] 73 | ] 74 | 75 | # attn_0_2 #################################################### 76 | def predicate_0_2(q_position, k_position): 77 | if q_position in {0}: 78 | return k_position == 1 79 | elif q_position in {1}: 80 | return k_position == 3 81 | elif q_position in {2, 6}: 82 | return k_position == 4 83 | elif q_position in {3, 4}: 84 | return k_position == 5 85 | elif q_position in {5, 7}: 86 | return k_position == 6 87 | 88 | attn_0_2_pattern = select_closest(positions, positions, predicate_0_2) 89 | attn_0_2_outputs = 
aggregate(attn_0_2_pattern, tokens) 90 | attn_0_2_output_scores = classifier_weights.loc[ 91 | [("attn_0_2_outputs", str(v)) for v in attn_0_2_outputs] 92 | ] 93 | 94 | # attn_0_3 #################################################### 95 | def predicate_0_3(q_position, k_position): 96 | if q_position in {0, 4}: 97 | return k_position == 3 98 | elif q_position in {1, 7}: 99 | return k_position == 4 100 | elif q_position in {2}: 101 | return k_position == 1 102 | elif q_position in {3, 5}: 103 | return k_position == 2 104 | elif q_position in {6}: 105 | return k_position == 5 106 | 107 | attn_0_3_pattern = select_closest(positions, positions, predicate_0_3) 108 | attn_0_3_outputs = aggregate(attn_0_3_pattern, tokens) 109 | attn_0_3_output_scores = classifier_weights.loc[ 110 | [("attn_0_3_outputs", str(v)) for v in attn_0_3_outputs] 111 | ] 112 | 113 | # mlp_0_0 ##################################################### 114 | def mlp_0_0(attn_0_2_output, position): 115 | key = (attn_0_2_output, position) 116 | if key in { 117 | ("0", 1), 118 | ("0", 2), 119 | ("0", 3), 120 | ("0", 4), 121 | ("1", 1), 122 | ("1", 2), 123 | ("2", 1), 124 | ("2", 2), 125 | ("3", 1), 126 | ("4", 1), 127 | ("", 1), 128 | }: 129 | return 0 130 | elif key in { 131 | ("0", 0), 132 | ("0", 6), 133 | ("1", 0), 134 | ("1", 6), 135 | ("", 0), 136 | ("", 2), 137 | ("", 3), 138 | ("", 5), 139 | ("", 6), 140 | ("", 7), 141 | }: 142 | return 2 143 | elif key in { 144 | ("0", 5), 145 | ("0", 7), 146 | ("1", 3), 147 | ("1", 4), 148 | ("1", 5), 149 | ("1", 7), 150 | ("3", 2), 151 | ("4", 2), 152 | ("", 1), 153 | }: 154 | return 7 155 | elif key in { 156 | ("3", 4), 157 | ("3", 5), 158 | ("3", 7), 159 | ("4", 4), 160 | ("4", 5), 161 | ("4", 7), 162 | ("", 4), 163 | }: 164 | return 1 165 | elif key in {("2", 3), ("2", 4), ("2", 5), ("2", 7), ("3", 3), ("4", 3)}: 166 | return 6 167 | return 4 168 | 169 | mlp_0_0_outputs = [mlp_0_0(k0, k1) for k0, k1 in zip(attn_0_2_outputs, positions)] 170 | 
mlp_0_0_output_scores = classifier_weights.loc[ 171 | [("mlp_0_0_outputs", str(v)) for v in mlp_0_0_outputs] 172 | ] 173 | 174 | # mlp_0_1 ##################################################### 175 | def mlp_0_1(position): 176 | key = position 177 | if key in {1, 2}: 178 | return 7 179 | elif key in {5}: 180 | return 1 181 | return 6 182 | 183 | mlp_0_1_outputs = [mlp_0_1(k0) for k0 in positions] 184 | mlp_0_1_output_scores = classifier_weights.loc[ 185 | [("mlp_0_1_outputs", str(v)) for v in mlp_0_1_outputs] 186 | ] 187 | 188 | # mlp_0_2 ##################################################### 189 | def mlp_0_2(attn_0_1_output, position): 190 | key = (attn_0_1_output, position) 191 | if key in { 192 | ("1", 4), 193 | ("2", 2), 194 | ("2", 3), 195 | ("2", 4), 196 | ("2", 7), 197 | ("3", 2), 198 | ("3", 3), 199 | ("3", 4), 200 | ("3", 7), 201 | ("4", 2), 202 | ("4", 3), 203 | ("4", 4), 204 | ("4", 7), 205 | }: 206 | return 3 207 | elif key in { 208 | ("0", 0), 209 | ("0", 1), 210 | ("0", 2), 211 | ("0", 3), 212 | ("0", 4), 213 | ("0", 7), 214 | ("1", 1), 215 | ("2", 1), 216 | ("3", 1), 217 | ("4", 1), 218 | ("", 1), 219 | }: 220 | return 6 221 | elif key in { 222 | ("0", 5), 223 | ("1", 2), 224 | ("1", 3), 225 | ("1", 7), 226 | ("", 1), 227 | ("", 2), 228 | ("", 3), 229 | ("", 7), 230 | }: 231 | return 5 232 | return 1 233 | 234 | mlp_0_2_outputs = [mlp_0_2(k0, k1) for k0, k1 in zip(attn_0_1_outputs, positions)] 235 | mlp_0_2_output_scores = classifier_weights.loc[ 236 | [("mlp_0_2_outputs", str(v)) for v in mlp_0_2_outputs] 237 | ] 238 | 239 | # mlp_0_3 ##################################################### 240 | def mlp_0_3(attn_0_2_output, position): 241 | key = (attn_0_2_output, position) 242 | if key in { 243 | ("0", 4), 244 | ("0", 7), 245 | ("1", 4), 246 | ("2", 4), 247 | ("2", 7), 248 | ("3", 4), 249 | ("3", 7), 250 | ("4", 4), 251 | ("", 0), 252 | ("", 1), 253 | ("", 2), 254 | ("", 3), 255 | ("", 4), 256 | ("", 7), 257 | ("", 0), 258 | ("", 1), 259 | ("", 2), 
260 | ("", 3), 261 | ("", 4), 262 | ("", 6), 263 | ("", 7), 264 | }: 265 | return 0 266 | elif key in { 267 | ("0", 5), 268 | ("0", 6), 269 | ("1", 5), 270 | ("1", 6), 271 | ("1", 7), 272 | ("2", 5), 273 | ("2", 6), 274 | ("3", 5), 275 | ("3", 6), 276 | ("4", 5), 277 | ("4", 6), 278 | ("", 5), 279 | ("", 6), 280 | ("", 5), 281 | }: 282 | return 7 283 | return 5 284 | 285 | mlp_0_3_outputs = [mlp_0_3(k0, k1) for k0, k1 in zip(attn_0_2_outputs, positions)] 286 | mlp_0_3_output_scores = classifier_weights.loc[ 287 | [("mlp_0_3_outputs", str(v)) for v in mlp_0_3_outputs] 288 | ] 289 | 290 | # attn_1_0 #################################################### 291 | def predicate_1_0(position, mlp_0_1_output): 292 | if position in {0, 1, 2, 3}: 293 | return mlp_0_1_output == 1 294 | elif position in {4, 5, 7}: 295 | return mlp_0_1_output == 0 296 | elif position in {6}: 297 | return mlp_0_1_output == 3 298 | 299 | attn_1_0_pattern = select_closest(mlp_0_1_outputs, positions, predicate_1_0) 300 | attn_1_0_outputs = aggregate(attn_1_0_pattern, attn_0_2_outputs) 301 | attn_1_0_output_scores = classifier_weights.loc[ 302 | [("attn_1_0_outputs", str(v)) for v in attn_1_0_outputs] 303 | ] 304 | 305 | # attn_1_1 #################################################### 306 | def predicate_1_1(q_position, k_position): 307 | if q_position in {0, 4, 5, 7}: 308 | return k_position == 1 309 | elif q_position in {1}: 310 | return k_position == 0 311 | elif q_position in {2}: 312 | return k_position == 5 313 | elif q_position in {3}: 314 | return k_position == 6 315 | elif q_position in {6}: 316 | return k_position == 7 317 | 318 | attn_1_1_pattern = select_closest(positions, positions, predicate_1_1) 319 | attn_1_1_outputs = aggregate(attn_1_1_pattern, attn_0_0_outputs) 320 | attn_1_1_output_scores = classifier_weights.loc[ 321 | [("attn_1_1_outputs", str(v)) for v in attn_1_1_outputs] 322 | ] 323 | 324 | # attn_1_2 #################################################### 325 | def 
predicate_1_2(q_position, k_position): 326 | if q_position in {0, 3}: 327 | return k_position == 1 328 | elif q_position in {1}: 329 | return k_position == 5 330 | elif q_position in {2}: 331 | return k_position == 6 332 | elif q_position in {4, 7}: 333 | return k_position == 2 334 | elif q_position in {5, 6}: 335 | return k_position == 3 336 | 337 | attn_1_2_pattern = select_closest(positions, positions, predicate_1_2) 338 | attn_1_2_outputs = aggregate(attn_1_2_pattern, tokens) 339 | attn_1_2_output_scores = classifier_weights.loc[ 340 | [("attn_1_2_outputs", str(v)) for v in attn_1_2_outputs] 341 | ] 342 | 343 | # attn_1_3 #################################################### 344 | def predicate_1_3(position, attn_0_1_output): 345 | if position in {0}: 346 | return attn_0_1_output == "1" 347 | elif position in {1, 2}: 348 | return attn_0_1_output == "0" 349 | elif position in {3}: 350 | return attn_0_1_output == "" 351 | elif position in {4, 5, 6}: 352 | return attn_0_1_output == "4" 353 | elif position in {7}: 354 | return attn_0_1_output == "3" 355 | 356 | attn_1_3_pattern = select_closest(attn_0_1_outputs, positions, predicate_1_3) 357 | attn_1_3_outputs = aggregate(attn_1_3_pattern, attn_0_1_outputs) 358 | attn_1_3_output_scores = classifier_weights.loc[ 359 | [("attn_1_3_outputs", str(v)) for v in attn_1_3_outputs] 360 | ] 361 | 362 | # mlp_1_0 ##################################################### 363 | def mlp_1_0(position, attn_1_2_output): 364 | key = (position, attn_1_2_output) 365 | if key in { 366 | (0, "0"), 367 | (1, "0"), 368 | (2, "0"), 369 | (2, "1"), 370 | (2, "2"), 371 | (2, "3"), 372 | (2, "4"), 373 | (3, "0"), 374 | }: 375 | return 5 376 | elif key in {(1, "1"), (1, "2"), (1, "3"), (1, "4")}: 377 | return 1 378 | elif key in {(0, "2"), (4, "2"), (5, "2")}: 379 | return 4 380 | elif key in {(0, "1"), (3, "1"), (4, "0")}: 381 | return 6 382 | elif key in {(3, "2"), (4, "1"), (5, "1")}: 383 | return 7 384 | elif key in {(5, "0")}: 385 | return 2 
386 | return 0 387 | 388 | mlp_1_0_outputs = [mlp_1_0(k0, k1) for k0, k1 in zip(positions, attn_1_2_outputs)] 389 | mlp_1_0_output_scores = classifier_weights.loc[ 390 | [("mlp_1_0_outputs", str(v)) for v in mlp_1_0_outputs] 391 | ] 392 | 393 | # mlp_1_1 ##################################################### 394 | def mlp_1_1(attn_1_3_output, attn_1_0_output): 395 | key = (attn_1_3_output, attn_1_0_output) 396 | if key in { 397 | ("3", "2"), 398 | ("3", "3"), 399 | ("3", "4"), 400 | ("3", ""), 401 | ("3", ""), 402 | ("4", "2"), 403 | ("4", "3"), 404 | ("4", "4"), 405 | ("4", ""), 406 | ("4", ""), 407 | ("", "2"), 408 | ("", "3"), 409 | ("", "4"), 410 | ("", ""), 411 | ("", ""), 412 | ("", "3"), 413 | ("", "4"), 414 | ("", ""), 415 | ("", ""), 416 | }: 417 | return 4 418 | elif key in {("0", ""), ("2", "")}: 419 | return 6 420 | return 1 421 | 422 | mlp_1_1_outputs = [ 423 | mlp_1_1(k0, k1) for k0, k1 in zip(attn_1_3_outputs, attn_1_0_outputs) 424 | ] 425 | mlp_1_1_output_scores = classifier_weights.loc[ 426 | [("mlp_1_1_outputs", str(v)) for v in mlp_1_1_outputs] 427 | ] 428 | 429 | # mlp_1_2 ##################################################### 430 | def mlp_1_2(attn_1_0_output, position): 431 | key = (attn_1_0_output, position) 432 | if key in { 433 | ("2", 4), 434 | ("2", 5), 435 | ("3", 4), 436 | ("3", 5), 437 | ("4", 0), 438 | ("4", 4), 439 | ("4", 5), 440 | ("", 4), 441 | ("", 5), 442 | ("", 4), 443 | ("", 5), 444 | }: 445 | return 0 446 | elif key in { 447 | ("0", 3), 448 | ("1", 3), 449 | ("2", 3), 450 | ("3", 0), 451 | ("3", 3), 452 | ("4", 3), 453 | ("", 3), 454 | ("", 3), 455 | }: 456 | return 7 457 | elif key in { 458 | ("2", 1), 459 | ("2", 7), 460 | ("3", 1), 461 | ("3", 7), 462 | ("4", 1), 463 | ("4", 7), 464 | ("", 1), 465 | ("", 7), 466 | }: 467 | return 1 468 | elif key in {("0", 4), ("0", 5), ("1", 4), ("1", 5)}: 469 | return 4 470 | elif key in {("2", 0), ("2", 6), ("3", 6), ("4", 6)}: 471 | return 5 472 | elif key in {("1", 1), ("1", 7), ("", 
6), ("", 6)}: 473 | return 2 474 | elif key in {("4", 2), ("", 0), ("", 2), ("", 7)}: 475 | return 6 476 | return 3 477 | 478 | mlp_1_2_outputs = [mlp_1_2(k0, k1) for k0, k1 in zip(attn_1_0_outputs, positions)] 479 | mlp_1_2_output_scores = classifier_weights.loc[ 480 | [("mlp_1_2_outputs", str(v)) for v in mlp_1_2_outputs] 481 | ] 482 | 483 | # mlp_1_3 ##################################################### 484 | def mlp_1_3(attn_1_0_output, attn_1_3_output): 485 | key = (attn_1_0_output, attn_1_3_output) 486 | if key in { 487 | ("0", "3"), 488 | ("0", ""), 489 | ("1", "3"), 490 | ("2", "3"), 491 | ("3", "0"), 492 | ("3", "1"), 493 | ("3", "2"), 494 | ("3", "3"), 495 | ("3", ""), 496 | ("", "0"), 497 | ("", "1"), 498 | ("", "2"), 499 | ("", "3"), 500 | ("", ""), 501 | ("", "3"), 502 | }: 503 | return 5 504 | elif key in {("3", "4")}: 505 | return 2 506 | return 7 507 | 508 | mlp_1_3_outputs = [ 509 | mlp_1_3(k0, k1) for k0, k1 in zip(attn_1_0_outputs, attn_1_3_outputs) 510 | ] 511 | mlp_1_3_output_scores = classifier_weights.loc[ 512 | [("mlp_1_3_outputs", str(v)) for v in mlp_1_3_outputs] 513 | ] 514 | 515 | feature_logits = pd.concat( 516 | [ 517 | df.reset_index() 518 | for df in [ 519 | token_scores, 520 | position_scores, 521 | attn_0_0_output_scores, 522 | attn_0_1_output_scores, 523 | attn_0_2_output_scores, 524 | attn_0_3_output_scores, 525 | mlp_0_0_output_scores, 526 | mlp_0_1_output_scores, 527 | mlp_0_2_output_scores, 528 | mlp_0_3_output_scores, 529 | attn_1_0_output_scores, 530 | attn_1_1_output_scores, 531 | attn_1_2_output_scores, 532 | attn_1_3_output_scores, 533 | mlp_1_0_output_scores, 534 | mlp_1_1_output_scores, 535 | mlp_1_2_output_scores, 536 | mlp_1_3_output_scores, 537 | one_scores, 538 | ] 539 | ] 540 | ) 541 | logits = feature_logits.groupby(level=0).sum(numeric_only=True).to_numpy() 542 | classes = classifier_weights.columns.to_numpy() 543 | predictions = classes[logits.argmax(-1)] 544 | if tokens[0] == "": 545 | predictions[0] = "" 546 | 
if tokens[-1] == "": 547 | predictions[-1] = "" 548 | return predictions.tolist() 549 | 550 | 551 | examples = [ 552 | ( 553 | ["", "0", "4", "1", "1", "4", "2", ""], 554 | ["", "0", "1", "1", "2", "4", "4", ""], 555 | ), 556 | ( 557 | ["", "4", "4", "3", "0", "2", ""], 558 | ["", "0", "2", "3", "4", "4", ""], 559 | ), 560 | ( 561 | ["", "3", "0", "2", "2", "3", "3", ""], 562 | ["", "0", "2", "2", "3", "3", "3", ""], 563 | ), 564 | ( 565 | ["", "2", "4", "2", "0", "0", "3", ""], 566 | ["", "0", "0", "2", "2", "3", "4", ""], 567 | ), 568 | ( 569 | ["", "0", "0", "2", "0", "2", "3", ""], 570 | ["", "0", "0", "0", "2", "2", "3", ""], 571 | ), 572 | ( 573 | ["", "0", "1", "0", "4", "3", ""], 574 | ["", "0", "0", "1", "3", "4", ""], 575 | ), 576 | ( 577 | ["", "4", "2", "1", "2", "4", "3", ""], 578 | ["", "1", "2", "2", "3", "4", "4", ""], 579 | ), 580 | (["", "4", "3", ""], ["", "3", "4", ""]), 581 | ( 582 | ["", "1", "1", "0", "4", "2", "1", ""], 583 | ["", "0", "1", "1", "1", "2", "4", ""], 584 | ), 585 | ( 586 | ["", "0", "1", "1", "3", "1", ""], 587 | ["", "0", "1", "1", "1", "3", ""], 588 | ), 589 | ] 590 | for x, y in examples: 591 | print(f"x: {x}") 592 | print(f"y: {y}") 593 | y_hat = run(x) 594 | print(f"y_hat: {y_hat}") 595 | print() 596 | -------------------------------------------------------------------------------- /programs/rasp_categorical_only/sort/sort_weights.csv: -------------------------------------------------------------------------------- 1 | feature,value,0,1,2,3,4 2 | attn_0_0_outputs,0,6.9886675,1.9029737,-1.5632732,-4.666394,-6.6998963 3 | attn_0_0_outputs,1,2.808001,3.347147,0.15896283,-2.8454115,-5.028731 4 | attn_0_0_outputs,2,-0.9442088,0.086697936,2.8916383,0.13243037,-2.0047884 5 | attn_0_0_outputs,3,-5.0501323,-3.2974973,-0.2889569,3.5421517,1.7046565 6 | attn_0_0_outputs,4,-6.3542542,-4.6292524,-1.617208,2.236299,7.221099 7 | attn_0_0_outputs,,-1.2565033,-0.85322773,0.024527336,0.8092462,0.7044034 8 | 
attn_0_0_outputs,,0.215054,-0.6509929,-0.43638363,0.19641109,0.49052995 9 | attn_0_0_outputs,,-0.58543414,-0.4350722,-0.11593681,-0.010477837,0.035624333 10 | attn_0_1_outputs,0,8.321785,2.2734017,-1.3977479,-7.016843,-9.175802 11 | attn_0_1_outputs,1,6.1663084,8.126509,-0.13728762,-7.441755,-11.841383 12 | attn_0_1_outputs,2,-0.9028804,0.18783027,5.3251595,-2.1351826,-6.537742 13 | attn_0_1_outputs,3,-6.775278,-5.8244658,-1.1134557,7.1989717,1.5569733 14 | attn_0_1_outputs,4,-9.164681,-8.360944,-3.933387,3.6149611,13.612795 15 | attn_0_1_outputs,,-3.547872,-4.1135654,-2.683632,2.05698,6.5400867 16 | attn_0_1_outputs,,-0.20318432,0.13683431,-0.14054358,0.1945901,0.143777 17 | attn_0_1_outputs,,-0.2252483,-0.39968184,-1.0448881,0.33150062,0.60946065 18 | attn_0_2_outputs,0,12.662373,2.54436,-4.249066,-8.406431,-10.299077 19 | attn_0_2_outputs,1,2.874053,7.271783,-0.90393364,-5.438341,-8.241028 20 | attn_0_2_outputs,2,-2.4153283,1.249425,6.9100757,-2.0006323,-8.70784 21 | attn_0_2_outputs,3,-6.3554764,-4.6737704,0.021502413,7.1687856,-1.3787818 22 | attn_0_2_outputs,4,-8.521622,-7.321229,-3.046059,3.4124966,10.726074 23 | attn_0_2_outputs,,-1.8578418,-2.6656096,-4.077629,0.20807894,6.7849145 24 | attn_0_2_outputs,,0.092216395,-0.7707133,0.16602008,0.17321993,0.7220149 25 | attn_0_2_outputs,,-9.541172,-5.166671,-0.87453365,4.0391154,10.82512 26 | attn_0_3_outputs,0,16.349298,4.3640847,-3.3730683,-9.79156,-14.669813 27 | attn_0_3_outputs,1,4.139899,8.263065,-0.17833088,-6.8866434,-11.679859 28 | attn_0_3_outputs,2,-3.926204,-0.5807447,5.967192,-1.0567814,-5.560351 29 | attn_0_3_outputs,3,-9.774161,-6.630783,-0.6135139,7.905448,2.3533192 30 | attn_0_3_outputs,4,-12.426147,-9.234469,-3.574072,4.1587915,14.047162 31 | attn_0_3_outputs,,-4.6650424,-3.240851,-1.7017452,2.632841,6.9260855 32 | attn_0_3_outputs,,-0.7803622,-0.16891778,0.42231166,-0.11307383,-0.3041337 33 | attn_0_3_outputs,,0.2458686,-1.8348123,-1.403487,0.39682958,1.4361156 34 | 
attn_1_0_outputs,0,3.4249582,0.4830792,-1.4701751,-2.006143,-1.3161409 35 | attn_1_0_outputs,1,1.6426649,1.2126992,-0.69759977,-1.437486,-1.3773936 36 | attn_1_0_outputs,2,0.23226358,0.64345044,1.8823681,-1.4889188,-2.2751608 37 | attn_1_0_outputs,3,-2.1088793,-1.0431234,-0.4745396,2.6021097,0.0047708363 38 | attn_1_0_outputs,4,-0.4867569,-1.6423402,-1.241547,0.5894667,2.0828123 39 | attn_1_0_outputs,,-2.0974174,-2.5358207,-0.59028447,2.7884414,3.1243649 40 | attn_1_0_outputs,,-1.4079151,0.42247453,0.79169756,1.1603345,0.31134152 41 | attn_1_0_outputs,,-0.8227935,-2.4848738,-0.8918966,1.4806012,2.132991 42 | attn_1_1_outputs,0,13.7478485,2.4539003,-4.1379776,-8.794688,-13.024131 43 | attn_1_1_outputs,1,3.659074,7.3829765,-0.4129496,-6.3457537,-8.960195 44 | attn_1_1_outputs,2,-3.8716147,-0.74163604,4.9987493,0.16779378,-3.0191104 45 | attn_1_1_outputs,3,-9.338297,-5.943794,-0.07820616,6.4413576,3.6664085 46 | attn_1_1_outputs,4,-11.835959,-7.989941,-2.6649897,4.1501045,14.6659155 47 | attn_1_1_outputs,,-1.7373832,-1.0473574,-0.07293691,0.88357306,3.197291 48 | attn_1_1_outputs,,-0.26156133,0.73238564,1.010969,-0.0779104,-0.3386485 49 | attn_1_1_outputs,,-5.260586,-0.10670824,1.731505,2.0308712,-6.089484 50 | attn_1_2_outputs,0,11.480239,2.3556077,-2.5624404,-6.0736723,-8.765706 51 | attn_1_2_outputs,1,2.2337794,6.481738,-0.6245802,-4.446057,-7.1153054 52 | attn_1_2_outputs,2,-4.0426908,-1.4142859,4.2344704,0.5490717,-2.002491 53 | attn_1_2_outputs,3,-7.267531,-4.5090466,0.7685141,7.347656,-0.10958998 54 | attn_1_2_outputs,4,-9.947915,-7.479793,-2.613032,2.9863348,10.500013 55 | attn_1_2_outputs,,0.53556585,0.11474143,-0.014953171,-1.402238,0.556593 56 | attn_1_2_outputs,,-0.25244236,0.8308479,0.7589563,-0.32094172,-1.2273515 57 | attn_1_2_outputs,,-0.53400123,-2.8056118,-0.78918266,0.04643116,2.6973917 58 | attn_1_3_outputs,0,1.9439827,0.1603147,-0.28384447,-0.5827607,-1.5783767 59 | attn_1_3_outputs,1,0.16566427,0.85972536,0.93071663,0.55593437,-0.40285176 60 | 
attn_1_3_outputs,2,0.22454293,0.38536653,1.0067779,0.041088387,-1.6240689 61 | attn_1_3_outputs,3,-0.34689322,-0.12858017,0.11249515,0.83399034,-0.7050651 62 | attn_1_3_outputs,4,-0.8917678,-1.1450619,-0.80545264,-0.21876524,0.8630245 63 | attn_1_3_outputs,,-0.030100795,-0.6054426,-0.36994955,-0.07122007,0.38008985 64 | attn_1_3_outputs,,-0.4243932,0.38434273,0.29122502,-0.122149594,0.6827977 65 | attn_1_3_outputs,,-2.117496,-1.7259998,-0.7347762,0.4593257,2.2775197 66 | mlp_0_0_outputs,0,4.0188355,1.367233,-1.417323,-6.483142,-13.193972 67 | mlp_0_0_outputs,1,-3.9137676,-0.15376055,0.49536678,0.9998601,0.047882058 68 | mlp_0_0_outputs,2,-15.88143,-5.672158,0.99274987,3.5297909,4.746169 69 | mlp_0_0_outputs,3,0.70922613,0.12939063,0.303092,0.19436331,-0.15312691 70 | mlp_0_0_outputs,4,-19.113987,-14.055936,-5.2655005,1.665476,7.5324826 71 | mlp_0_0_outputs,5,0.0049047554,0.3532272,-0.61323506,-0.15334794,0.20983703 72 | mlp_0_0_outputs,6,-2.2349446,0.48716977,1.2473469,1.0492531,-0.959091 73 | mlp_0_0_outputs,7,2.0180795,2.697081,1.4667916,-2.6459293,-8.920683 74 | mlp_0_1_outputs,0,4.9354362,2.2356412,-0.82518,-6.697162,-13.193914 75 | mlp_0_1_outputs,1,-14.731216,-4.927309,-0.5305642,1.9409475,2.7460032 76 | mlp_0_1_outputs,2,0.23375379,0.11552315,0.44065633,-0.086906254,0.16232851 77 | mlp_0_1_outputs,3,-0.01565274,-0.505149,0.24257246,-0.20083542,0.091414995 78 | mlp_0_1_outputs,4,1.0003215,0.1339593,0.18748964,0.14529389,0.15201308 79 | mlp_0_1_outputs,5,-0.07818346,0.11651017,0.8559554,0.40377986,0.6004811 80 | mlp_0_1_outputs,6,-5.059556,-0.8851335,1.2983979,1.9971292,1.2971436 81 | mlp_0_1_outputs,7,2.3331697,2.7773235,1.5271827,-2.8835008,-7.8288684 82 | mlp_0_2_outputs,0,-5.812679,-8.050895,-2.0316663,1.604939,3.8419476 83 | mlp_0_2_outputs,1,-15.904148,-4.8865576,0.7665841,2.2370648,4.8236346 84 | mlp_0_2_outputs,2,0.023708992,0.23602854,-0.21941465,-0.06791145,0.14681834 85 | mlp_0_2_outputs,3,-5.155198,1.1561449,2.7422621,0.84532666,-1.3750209 86 | 
mlp_0_2_outputs,4,-1.6516947,0.14862537,0.27884433,0.09279299,0.0030477077 87 | mlp_0_2_outputs,5,-4.058745,1.8142332,1.0810403,-0.14894587,-0.9095843 88 | mlp_0_2_outputs,6,5.3210797,2.4041133,-1.6120065,-5.1028633,-10.729489 89 | mlp_0_2_outputs,7,-0.31382832,0.9711427,-0.14119859,-0.1960116,-0.1285425 90 | mlp_0_3_outputs,0,-15.051607,-2.8604221,2.0543356,2.499284,2.205402 91 | mlp_0_3_outputs,1,5.6800714,1.356257,-2.8838096,-6.541643,-12.12329 92 | mlp_0_3_outputs,2,0.8320392,0.06768919,0.18168978,-0.13751434,0.71345246 93 | mlp_0_3_outputs,3,-2.744153,-1.4329554,1.1388501,1.9455607,-0.36828822 94 | mlp_0_3_outputs,4,-0.47749937,1.1444833,1.4878665,0.511925,-2.8988562 95 | mlp_0_3_outputs,5,4.4155946,4.092108,1.8837255,-5.068154,-17.417696 96 | mlp_0_3_outputs,6,0.91394186,1.7178278,0.29244098,-1.9470179,-0.79406583 97 | mlp_0_3_outputs,7,-27.025166,-14.000694,-3.8636396,4.7704196,9.954426 98 | mlp_1_0_outputs,0,-8.994571,-4.491032,-1.4437572,2.639763,4.943209 99 | mlp_1_0_outputs,1,5.9503913,2.6140954,-2.56839,-10.30842,-19.072475 100 | mlp_1_0_outputs,2,-3.2524624,1.9415048,2.075256,0.1147645,-2.324298 101 | mlp_1_0_outputs,3,-3.3153234,-6.1573997,-1.7476376,1.969032,3.2494743 102 | mlp_1_0_outputs,4,-7.131646,-0.9621037,3.1830635,1.1788143,-1.1787232 103 | mlp_1_0_outputs,5,2.321128,1.8879267,0.55639076,-1.853329,-5.405838 104 | mlp_1_0_outputs,6,-0.3154466,1.7536464,1.4805571,-0.4060518,-2.380789 105 | mlp_1_0_outputs,7,-4.6405716,0.54778636,2.5990744,1.2474383,-1.03555 106 | mlp_1_1_outputs,0,0.2574987,1.1060137,0.34981915,-0.5553312,-1.3027234 107 | mlp_1_1_outputs,1,0.096980885,0.19448379,0.38600856,0.21713899,0.31291375 108 | mlp_1_1_outputs,2,0.20168386,0.24377005,-0.16930194,0.0846435,0.27269673 109 | mlp_1_1_outputs,3,-0.15246853,0.13554697,-0.14340618,-0.24426788,-0.1754791 110 | mlp_1_1_outputs,4,-1.1793035,-0.2165895,0.05511073,0.106202744,1.238185 111 | mlp_1_1_outputs,5,1.5681087,0.95414907,0.6013005,-1.0025665,-1.9541403 112 | 
mlp_1_1_outputs,6,-0.92157024,-0.2364761,0.2521039,1.1757741,1.7977172 113 | mlp_1_1_outputs,7,-0.8971796,0.37849936,-0.026569659,0.031324264,-0.46514994 114 | mlp_1_2_outputs,0,-1.3478705,-2.4680243,0.66674495,1.3479987,0.5765663 115 | mlp_1_2_outputs,1,0.37738907,1.0555747,1.5234038,-2.4196775,-6.760583 116 | mlp_1_2_outputs,2,2.4935262,4.591032,-4.003316,-5.4412313,-2.304786 117 | mlp_1_2_outputs,3,4.2700715,2.2302587,-0.795981,-2.4704978,-4.8020735 118 | mlp_1_2_outputs,4,-3.9012556,-1.7321999,3.3738852,2.948957,2.2248867 119 | mlp_1_2_outputs,5,-2.0524378,-5.6856995,-0.8716151,2.056629,3.9021623 120 | mlp_1_2_outputs,6,2.1821785,1.7694793,-0.9968027,-1.056598,-4.7362437 121 | mlp_1_2_outputs,7,-2.2508187,0.65054375,1.1322451,0.36528292,-1.2405411 122 | mlp_1_3_outputs,0,0.3038056,-0.40180534,0.27598795,0.07215004,-0.9269151 123 | mlp_1_3_outputs,1,-3.1313102,-2.0797613,-0.5406382,1.3487104,2.601697 124 | mlp_1_3_outputs,2,1.0282918,0.54115653,0.54766685,0.22385533,-0.9048508 125 | mlp_1_3_outputs,3,-0.046904188,0.8035955,1.1800617,0.2309441,-0.29504803 126 | mlp_1_3_outputs,4,0.7052468,-0.32243052,-0.2011492,-0.5544647,0.47795638 127 | mlp_1_3_outputs,5,1.8298012,0.72522086,0.55117387,0.12736094,-1.4325687 128 | mlp_1_3_outputs,6,0.09611695,-0.158596,-0.02878351,-0.4176433,-0.78748065 129 | mlp_1_3_outputs,7,1.6000229,0.9329063,1.1538035,0.3359634,-1.0799118 130 | ones,_,-0.055436447,-0.1179431,-1.0715027,-0.15707256,-0.06473778 131 | positions,0,-0.45721093,-0.87155324,-0.35231563,0.61707205,0.11435398 132 | positions,1,7.035323,2.7762,-4.4782853,-11.52158,-7.547409 133 | positions,2,1.7988344,4.5708203,1.894728,-4.0963306,-10.078739 134 | positions,3,-2.066558,3.0727959,3.1645296,-0.5606865,-4.997998 135 | positions,4,-8.358648,-0.029626308,1.6459913,4.489941,-0.8345926 136 | positions,5,-11.328461,-4.9711714,-0.53534645,5.983784,3.4289176 137 | positions,6,-17.799313,-13.832921,-4.907452,3.116108,9.586147 138 | 
positions,7,-0.025707932,-0.42475572,-0.059313126,0.047880296,0.63343513 139 | tokens,0,9.776978,1.7060088,-2.9519324,-5.849588,-7.211308 140 | tokens,1,2.336422,4.711168,-0.5835804,-3.646622,-5.077569 141 | tokens,2,-2.223295,-0.92388165,3.334594,-0.44335306,-2.2490618 142 | tokens,3,-4.5560665,-3.9469757,-0.53029966,5.159678,2.8704758 143 | tokens,4,-6.5824995,-5.69455,-2.5579076,2.4740694,9.455751 144 | tokens,,0.6628496,-0.32530716,0.34280553,-0.2948921,-0.063519835 145 | tokens,,-0.010760286,-0.05715638,0.4203831,-0.12519221,0.6524513 146 | tokens,,-0.5623373,0.11551832,0.005275513,0.6801121,-0.23678207 147 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | black 2 | datasets 3 | einops 4 | gensim 5 | matplotlib 6 | nltk 7 | numpy 8 | pandas 9 | scikit-learn 10 | seaborn 11 | seqeval 12 | torch 13 | tqdm 14 | -------------------------------------------------------------------------------- /scripts/classification.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | N_EPOCHS=10 4 | N_LAYERS=2 5 | N_VARS_CAT=4 6 | N_HEADS_CAT=4 7 | N_CAT_MLPS=1 8 | SEED=0 9 | 10 | python src/run.py \ 11 | --dataset "${DATASET}" \ 12 | --vocab_size 10000 \ 13 | --min_length 1 \ 14 | --max_length 64 \ 15 | --n_epochs "${N_EPOCHS}" \ 16 | --batch_size 128 \ 17 | --lr "5e-2" \ 18 | --gumbel_samples 1 \ 19 | --sample_fn "gumbel_soft" \ 20 | --tau_init 3.0 \ 21 | --tau_end 0.01 \ 22 | --tau_schedule "geomspace" \ 23 | --n_vars_cat "${N_VARS_CAT}" \ 24 | --d_var 64 \ 25 | --n_vars_num 1 \ 26 | --n_layers "${N_LAYERS}" \ 27 | --n_heads_cat "${N_HEADS_CAT}" \ 28 | --n_heads_num 0 \ 29 | --n_cat_mlps "${N_CAT_MLPS}" \ 30 | --n_num_mlps 0 \ 31 | --attention_type "cat" \ 32 | --rel_pos_bias "fixed" \ 33 | --dropout 0.0 \ 34 | --mlp_vars_in 2 \ 35 | --d_mlp 64 \ 36 | --count_only \ 37 | --selector_width 0 \ 38 | 
--do_lower 0 \ 39 | --replace_numbers 0 \ 40 | --glove_embeddings "data/glove.840B.300d.txt" \ 41 | --do_glove 1 \ 42 | --pool_outputs 1 \ 43 | --seed "${SEED}" \ 44 | --save \ 45 | --save_code \ 46 | --output_dir "output/classification/${DATASET}/transformer_program/nvars${N_VARS_CAT}nheads${N_HEADS_CAT}nlayers${N_LAYERS}nmlps${N_CAT_MLPS}/epochs${N_EPOCHS}/s${SEED}"; 47 | -------------------------------------------------------------------------------- /scripts/classification_short.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | N_EPOCHS=10 4 | N_LAYERS=2 5 | N_VARS_CAT=4 6 | N_HEADS_CAT=4 7 | N_CAT_MLPS=1 8 | SEED=0 9 | 10 | python src/run.py \ 11 | --dataset "${DATASET}" \ 12 | --vocab_size 5000 \ 13 | --min_length 1 \ 14 | --max_length 25 \ 15 | --n_epochs "${N_EPOCHS}" \ 16 | --batch_size 128 \ 17 | --lr "5e-2" \ 18 | --gumbel_samples 1 \ 19 | --sample_fn "gumbel_soft" \ 20 | --tau_init 3.0 \ 21 | --tau_end 0.01 \ 22 | --tau_schedule "geomspace" \ 23 | --n_vars_cat "${N_VARS_CAT}" \ 24 | --d_var 25 \ 25 | --n_vars_num 1 \ 26 | --n_layers "${N_LAYERS}" \ 27 | --n_heads_cat "${N_HEADS_CAT}" \ 28 | --n_heads_num 0 \ 29 | --n_cat_mlps "${N_CAT_MLPS}" \ 30 | --n_num_mlps 0 \ 31 | --attention_type "cat" \ 32 | --rel_pos_bias "fixed" \ 33 | --dropout 0.0 \ 34 | --mlp_vars_in 2 \ 35 | --d_mlp 64 \ 36 | --count_only \ 37 | --selector_width 0 \ 38 | --do_lower 0 \ 39 | --replace_numbers 0 \ 40 | --glove_embeddings "data/glove.840B.300d.txt" \ 41 | --do_glove 1 \ 42 | --pool_outputs 1 \ 43 | --seed "${SEED}" \ 44 | --save \ 45 | --save_code \ 46 | --output_dir "output/classification_short/${DATASET}/transformer_program/nvars${N_VARS_CAT}nheads${N_HEADS_CAT}nlayers${N_LAYERS}nmlps${N_CAT_MLPS}/epochs${N_EPOCHS}/s${SEED}"; 47 | -------------------------------------------------------------------------------- /scripts/classification_standard.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | D_MODEL=64 4 | N_LAYERS=2 5 | N_HEADS=4 6 | SEED=0 7 | 8 | python src/run.py \ 9 | --dataset "${DATASET}" \ 10 | --standard \ 11 | --vocab_size 10000 \ 12 | --min_length 1 \ 13 | --max_length 64 \ 14 | --n_epochs 100 \ 15 | --batch_size 128 \ 16 | --lr "5e-3" \ 17 | --d_model "${D_MODEL}" \ 18 | --n_heads "${N_HEADS}" \ 19 | --n_layers "${N_LAYERS}" \ 20 | --d_mlp 64 \ 21 | --dropout 0.5 \ 22 | --max_grad_norm 5.0 \ 23 | --do_lower 0 \ 24 | --replace_numbers 0 \ 25 | --glove_embeddings "data/glove.840B.300d.txt" \ 26 | --do_glove "${DO_GLOVE}" \ 27 | --pool_outputs 1 \ 28 | --seed "${SEED}" \ 29 | --output_dir "output/classification/${DATASET}/standard_transformer/dmodel${D_MODEL}nheads${N_HEADS}nlayers${N_LAYERS}/s${SEED}"; 30 | -------------------------------------------------------------------------------- /scripts/conll.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | N_EPOCHS=50 4 | N_LAYERS=2 5 | N_VARS_CAT=4 6 | N_HEADS_CAT=8 7 | N_CAT_MLPS=1 8 | SEED=0 9 | 10 | python src/run.py \ 11 | --dataset "conll_ner" \ 12 | --vocab_size 10000 \ 13 | --min_length 1 \ 14 | --max_length 32 \ 15 | --n_epochs "${N_EPOCHS}" \ 16 | --batch_size 32 \ 17 | --lr "1e-2" \ 18 | --gumbel_samples 1 \ 19 | --sample_fn "gumbel_soft" \ 20 | --tau_init 3.0 \ 21 | --tau_end 0.01 \ 22 | --tau_schedule "geomspace" \ 23 | --n_vars_cat "${N_VARS_CAT}" \ 24 | --d_var 32 \ 25 | --n_vars_num 1 \ 26 | --n_layers "${N_LAYERS}" \ 27 | --n_heads_cat "${N_HEADS_CAT}" \ 28 | --n_heads_num 0 \ 29 | --n_cat_mlps "${N_CAT_MLPS}" \ 30 | --n_num_mlps 0 \ 31 | --attention_type "cat" \ 32 | --rel_pos_bias "fixed" \ 33 | --dropout 0.5 \ 34 | --mlp_vars_in 2 \ 35 | --d_mlp 64 \ 36 | --count_only \ 37 | --selector_width 0 \ 38 | --do_lower 0 \ 39 | --replace_numbers 1 \ 40 | --glove_embeddings "data/glove.840B.300d.txt" \ 41 | --do_glove 1 \ 42 | 
--pool_outputs 0 \ 43 | --seed "${SEED}" \ 44 | --save \ 45 | --save_code \ 46 | --output_dir "output/conll_ner/transformer_program/nvars${N_VARS_CAT}nheads${N_HEADS_CAT}nlayers${N_LAYERS}nmlps${N_CAT_MLPS}/epochs${N_EPOCHS}/s${SEED}"; 47 | -------------------------------------------------------------------------------- /scripts/conll_standard.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | D_MODEL=128 4 | N_LAYERS=3 5 | N_HEADS=4 6 | SEED=0 7 | 8 | python src/run.py \ 9 | --dataset "conll_ner" \ 10 | --standard \ 11 | --vocab_size 10000 \ 12 | --min_length 1 \ 13 | --max_length 32 \ 14 | --n_epochs 100 \ 15 | --batch_size 32 \ 16 | --lr "5e-3" \ 17 | --d_model "${D_MODEL}" \ 18 | --n_heads "${N_HEADS}" \ 19 | --n_layers "${N_LAYERS}" \ 20 | --d_mlp 64 \ 21 | --dropout 0.5 \ 22 | --max_grad_norm 5.0 \ 23 | --do_lower 0 \ 24 | --replace_numbers 1 \ 25 | --glove_embeddings "data/glove.840B.300d.txt" \ 26 | --do_glove "${DO_GLOVE}" \ 27 | --pool_outputs 0 \ 28 | --seed "${SEED}" \ 29 | --output_dir "output/conll_ner/standard_transformer/dmodel${D_MODEL}nheads${N_HEADS}nlayers${N_LAYERS}/s${SEED}"; 30 | -------------------------------------------------------------------------------- /scripts/dyck.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MAX_LENGTH=16 4 | N_LAYERS=3 5 | N_HEADS_CAT=4 6 | N_HEADS_NUM=4 7 | N_CAT_MLPS=2 8 | N_NUM_MLPS=2 9 | SEED=0 10 | 11 | python src/run.py \ 12 | --dataset "${DATASET}" \ 13 | --dataset_size 20000 \ 14 | --min_length "${MAX_LENGTH}" \ 15 | --max_length "${MAX_LENGTH}" \ 16 | --n_epochs 250 \ 17 | --batch_size 512 \ 18 | --lr "5e-2" \ 19 | --gumbel_samples 1 \ 20 | --sample_fn "gumbel_soft" \ 21 | --tau_init 3.0 \ 22 | --tau_end 0.01 \ 23 | --tau_schedule "geomspace" \ 24 | --n_vars_cat 1 \ 25 | --d_var "${MAX_LENGTH}" \ 26 | --n_vars_num 1 \ 27 | --n_layers "${N_LAYERS}" \ 28 | --n_heads_cat "${N_HEADS_CAT}" \ 29 
| --n_heads_num "${N_HEADS_NUM}" \ 30 | --n_cat_mlps "${N_CAT_MLPS}" \ 31 | --n_num_mlps "${N_NUM_MLPS}" \ 32 | --attention_type "cat" \ 33 | --rel_pos_bias "fixed" \ 34 | --one_hot_embed \ 35 | --dropout 0.0 \ 36 | --mlp_vars_in 2 \ 37 | --d_mlp 64 \ 38 | --count_only \ 39 | --selector_width 0 \ 40 | --seed "${SEED}" \ 41 | --unique 1 \ 42 | --save \ 43 | --save_code \ 44 | --output_dir "output/rasp/${DATASET}/maxlen${MAX_LENGTH}/transformer_program/headsc${N_HEADS_CAT}headsn${N_HEADS_NUM}nlayers${N_LAYERS}cmlps${N_CAT_MLPS}nmlps${N_NUM_MLPS}/s${SEED}"; 45 | -------------------------------------------------------------------------------- /scripts/induction.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | VOCAB_SIZE=10 4 | MIN_LENGTH=9 5 | MAX_LENGTH=9 6 | SEED=6 7 | 8 | echo "SEED=${SEED}"; 9 | 10 | python src/run.py \ 11 | --dataset "induction" \ 12 | --vocab_size "${VOCAB_SIZE}" \ 13 | --dataset_size 20000 \ 14 | --min_length "${MIN_LENGTH}" \ 15 | --max_length "${MAX_LENGTH}" \ 16 | --n_epochs 500 \ 17 | --batch_size 512 \ 18 | --lr "5e-2" \ 19 | --gumbel_samples 1 \ 20 | --sample_fn "gumbel_soft" \ 21 | --tau_init 3.0 \ 22 | --tau_end 0.01 \ 23 | --tau_schedule "geomspace" \ 24 | --n_vars_cat 1 \ 25 | --n_vars_num 1 \ 26 | --n_layers 2 \ 27 | --n_heads_cat 1 \ 28 | --n_heads_num 0 \ 29 | --n_cat_mlps 0 \ 30 | --n_num_mlps 0 \ 31 | --attention_type "cat" \ 32 | --rel_pos_bias "fixed" \ 33 | --one_hot_embed \ 34 | --count_only \ 35 | --selector_width 0 \ 36 | --seed "${SEED}" \ 37 | --unique 1 \ 38 | --unembed_mask 0 \ 39 | --autoregressive \ 40 | --save \ 41 | --save_code \ 42 | --output_dir "output/induction/s${SEED}"; 43 | -------------------------------------------------------------------------------- /scripts/rasp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | VOCAB_SIZE=8 4 | MAX_LENGTH=8 5 | N_LAYERS=3 6 | N_HEADS_CAT=4 7 | 
N_HEADS_NUM=4 8 | N_CAT_MLPS=2 9 | N_NUM_MLPS=2 10 | SEED=0 11 | 12 | python src/run.py \ 13 | --dataset "${DATASET}" \ 14 | --vocab_size "${VOCAB_SIZE}" \ 15 | --dataset_size 20000 \ 16 | --min_length 1 \ 17 | --max_length "${MAX_LENGTH}" \ 18 | --n_epochs 250 \ 19 | --batch_size 512 \ 20 | --lr "5e-2" \ 21 | --gumbel_samples 1 \ 22 | --sample_fn "gumbel_soft" \ 23 | --tau_init 3.0 \ 24 | --tau_end 0.01 \ 25 | --tau_schedule "geomspace" \ 26 | --n_vars_cat 1 \ 27 | --d_var "${MAX_LENGTH}" \ 28 | --n_vars_num 1 \ 29 | --n_layers "${N_LAYERS}" \ 30 | --n_heads_cat "${N_HEADS_CAT}" \ 31 | --n_heads_num "${N_HEADS_NUM}" \ 32 | --n_cat_mlps "${N_CAT_MLPS}" \ 33 | --n_num_mlps "${N_NUM_MLPS}" \ 34 | --attention_type "cat" \ 35 | --rel_pos_bias "fixed" \ 36 | --one_hot_embed \ 37 | --dropout 0.0 \ 38 | --mlp_vars_in 2 \ 39 | --d_mlp 64 \ 40 | --count_only \ 41 | --selector_width 0 \ 42 | --seed "${SEED}" \ 43 | --unique 1 \ 44 | --save \ 45 | --save_code \ 46 | --output_dir "output/rasp/${DATASET}/vocab${VOCAB_SIZE}maxlen${MAX_LENGTH}/transformer_program/headsc${N_HEADS_CAT}headsn${N_HEADS_NUM}nlayers${N_LAYERS}cmlps${N_CAT_MLPS}nmlps${N_NUM_MLPS}/s${SEED}"; 47 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name="src", version="0.1", packages=find_packages()) 4 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/TransformerPrograms/59970de542e14406d8f3d01bbd30097276959865/src/__init__.py -------------------------------------------------------------------------------- /src/decompile.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | from
copy import deepcopy 4 | from functools import partial 5 | import itertools 6 | import json 7 | import math 8 | from pathlib import Path 9 | import random 10 | 11 | import einops 12 | import numpy as np 13 | import pandas as pd 14 | import re 15 | import torch 16 | from torch import nn 17 | from torch.nn import functional as F 18 | from torch.utils.data import DataLoader 19 | from tqdm import tqdm 20 | 21 | from src.models.programs import TransformerProgramModel, argmax 22 | 23 | from src.run import set_seed, get_sample_fn 24 | from src.utils import code_utils, data_utils, logging 25 | 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | 30 | class DotDict(dict): 31 | __getattr__ = dict.__getitem__ 32 | __setattr__ = dict.__setitem__ 33 | __delattr__ = dict.__delitem__ 34 | 35 | 36 | def parse_args(): 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument("--path", type=str, default="scratch") 39 | parser.add_argument("--output_dir", type=str, default="scratch") 40 | return parser.parse_args() 41 | 42 | 43 | def load_model(path): 44 | args_fn = Path(path) / "args.json" 45 | if not args_fn.exists(): 46 | raise ValueError(f"missing {args_fn}") 47 | model_fn = Path(path) / "model.pt" 48 | if not model_fn.exists(): 49 | raise ValueError(f"missing {model_fn}") 50 | 51 | with open(args_fn, "r") as f: 52 | args = DotDict(json.load(f)) 53 | 54 | logger.info(f"loading {args.dataset}") 55 | set_seed(args.seed) 56 | ( 57 | train, 58 | test, 59 | val, 60 | idx_w, 61 | w_idx, 62 | idx_t, 63 | t_idx, 64 | X_train, 65 | Y_train, 66 | X_test, 67 | Y_test, 68 | X_val, 69 | Y_val, 70 | ) = data_utils.get_dataset( 71 | name=args.dataset, 72 | vocab_size=args.vocab_size, 73 | dataset_size=args.dataset_size, 74 | min_length=args.min_length, 75 | max_length=args.max_length, 76 | seed=args.seed, 77 | do_lower=args.do_lower, 78 | replace_numbers=args.replace_numbers, 79 | get_val=True, 80 | unique=args.unique, 81 | ) 82 | 83 | logger.info(f"initializing model from 
{args_fn}") 84 | if args.d_var is None: 85 | d = max(len(idx_w), X_train.shape[-1]) 86 | else: 87 | d = args.d_var 88 | init_emb = None 89 | if args.glove_embeddings and args.do_glove: 90 | emb = data_utils.get_glove_embeddings( 91 | idx_w, args.glove_embeddings, dim=args.n_vars_cat * d 92 | ) 93 | init_emb = torch.tensor(emb, dtype=torch.float32).T 94 | unembed_mask = None 95 | if args.unembed_mask: 96 | unembed_mask = np.array([t in ("<unk>", "<pad>") for t in idx_t]) 97 | set_seed(args.seed) 98 | model = TransformerProgramModel( 99 | d_vocab=len(idx_w), 100 | d_vocab_out=len(idx_t), 101 | n_vars_cat=args.n_vars_cat, 102 | n_vars_num=args.n_vars_num, 103 | d_var=d, 104 | n_heads_cat=args.n_heads_cat, 105 | n_heads_num=args.n_heads_num, 106 | d_mlp=args.d_mlp, 107 | n_cat_mlps=args.n_cat_mlps, 108 | n_num_mlps=args.n_num_mlps, 109 | mlp_vars_in=args.mlp_vars_in, 110 | n_layers=args.n_layers, 111 | n_ctx=X_train.shape[1], 112 | sample_fn=get_sample_fn(args.sample_fn), 113 | init_emb=init_emb, 114 | attention_type=args.attention_type, 115 | rel_pos_bias=args.rel_pos_bias, 116 | unembed_mask=unembed_mask, 117 | pool_outputs=args.pool_outputs, 118 | one_hot_embed=args.one_hot_embed, 119 | count_only=args.count_only, 120 | selector_width=args.selector_width, 121 | ) 122 | 123 | logger.info(f"restoring weights from {model_fn}") 124 | model.load_state_dict( 125 | torch.load(str(model_fn), map_location=torch.device("cpu")) 126 | ) 127 | 128 | model.set_temp(args.tau_end, argmax) 129 | return model, args, idx_w, idx_t, X_val 130 | 131 | 132 | def model_to_code(model, args, idx_w, idx_t, X=None, output_dir=""): 133 | if output_dir: 134 | Path(output_dir).mkdir(exist_ok=True, parents=True) 135 | x = None 136 | if X is not None: 137 | x = idx_w[X[0]] 138 | x = x[x != "<pad>"].tolist() 139 | m = code_utils.model_to_code( 140 | model=model, 141 | idx_w=idx_w, 142 | idx_t=idx_t, 143 | embed_csv=not args.one_hot_embed, 144 | unembed_csv=True, 145 | one_hot=args.one_hot_embed, 146 | 
autoregressive=args.autoregressive, 147 | var_types=True, 148 | output_dir=output_dir, 149 | name=args.dataset, 150 | example=x, 151 | save=bool(output_dir), 152 | ) 153 | return m 154 | 155 | 156 | if __name__ == "__main__": 157 | args = parse_args() 158 | model, args, idx_w, idx_t, X_val = load_model(args.path) 159 | model_to_code( 160 | model=model, 161 | args=args, 162 | idx_w=idx_w, 163 | idx_t=idx_t, 164 | X=X_val, 165 | output_dir=args.output_dir, 166 | ) 167 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/TransformerPrograms/59970de542e14406d8f3d01bbd30097276959865/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/transformers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | 9 | # Define network architecture. 
Adapted from 10 | # https://colab.research.google.com/drive/19gn2tavBGDqOYHLatjSROhABBD5O_JyZ 11 | class Embed(nn.Module): 12 | def __init__(self, d_vocab, d_model, init_emb=None): 13 | super().__init__() 14 | if init_emb is not None: 15 | self.W_E = nn.Parameter(init_emb) 16 | else: 17 | self.W_E = nn.Parameter( 18 | torch.randn(d_model, d_vocab) / np.sqrt(d_model) 19 | ) 20 | 21 | def forward(self, x): 22 | return torch.einsum("dbp -> bpd", self.W_E[:, x]) 23 | 24 | 25 | class Unembed(nn.Module): 26 | def __init__(self, d_vocab, d_model, mask=None): 27 | super().__init__() 28 | self.W_U = nn.Parameter( 29 | torch.randn(d_model, d_vocab) / np.sqrt(d_vocab) 30 | ) 31 | if mask is not None: 32 | self.register_buffer( 33 | "mask", torch.tensor(mask).view(1, 1, len(mask)) 34 | ) 35 | else: 36 | self.mask = None 37 | 38 | def forward(self, x): 39 | logits = x @ self.W_U 40 | if self.mask is not None: 41 | logits = logits.masked_fill(self.mask, -1e30) 42 | return logits 43 | 44 | 45 | class PosEmbed(nn.Module): 46 | def __init__(self, max_ctx, d_model): 47 | super().__init__() 48 | self.W_pos = nn.Parameter( 49 | torch.randn(max_ctx, d_model) / np.sqrt(d_model) 50 | ) 51 | 52 | def forward(self, x): 53 | return x + self.W_pos[: x.shape[-2]] 54 | 55 | 56 | class Transformer(nn.Module): 57 | def __init__( 58 | self, 59 | d_vocab, 60 | n_layers=2, 61 | d_model=64, 62 | d_mlp=64, 63 | n_heads=4, 64 | n_ctx=32, 65 | act_type="ReLU", 66 | d_vocab_out=None, 67 | dropout=0.0, 68 | init_emb=None, 69 | unembed_mask=None, 70 | pool_outputs=False, 71 | **kwargs, 72 | ): 73 | super().__init__() 74 | self.embed = Embed(d_vocab, d_model, init_emb=init_emb) 75 | self.pos_embed = PosEmbed(n_ctx, d_model) 76 | layer = nn.TransformerEncoderLayer( 77 | d_model=d_model, 78 | nhead=n_heads, 79 | batch_first=True, 80 | dim_feedforward=d_mlp, 81 | dropout=dropout, 82 | norm_first=True, 83 | ) 84 | self.encoder = nn.TransformerEncoder(layer, n_layers) 85 | self.unembed = Unembed( 86 | 
d_vocab_out or d_vocab, d_model, mask=unembed_mask 87 | ) 88 | self.n_heads = n_heads 89 | self.dropout = nn.Dropout(dropout) 90 | self.pool_outputs = pool_outputs 91 | 92 | def forward(self, x, mask=None): 93 | x = self.embed(x) 94 | x = x * np.sqrt(x.shape[-1]) 95 | x = self.pos_embed(x) 96 | x = self.dropout(x) 97 | m = mask 98 | mask = torch.cat([mask for _ in range(self.n_heads)], 0) 99 | mask[:, :, 0] = True 100 | x = self.encoder(x, mask=~mask) 101 | if self.pool_outputs: 102 | x = torch.cat( 103 | [ 104 | x.masked_fill(~m[:, 0].unsqueeze(-1), 0).mean( 105 | 1, keepdims=True 106 | ), 107 | x[:, 1:], 108 | ], 109 | 1, 110 | ) 111 | x = self.unembed(x) 112 | return x 113 | 114 | @property 115 | def device(self): 116 | return self.embed.W_E.device 117 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/TransformerPrograms/59970de542e14406d8f3d01bbd30097276959865/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/logging.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | import logging 4 | from pathlib import Path 5 | import sys 6 | 7 | 8 | def get_handler(fn=None, stream=None): 9 | formatter = logging.Formatter( 10 | fmt="%(levelname).1s %(asctime)s %(name)s:%(lineno)d: %(message)s", 11 | datefmt="%Y-%m-%dT%H:%M:%S", 12 | ) 13 | h = ( 14 | logging.FileHandler(fn, mode="w") 15 | if fn 16 | else logging.StreamHandler(stream=stream) 17 | ) 18 | h.setLevel(logging.INFO) 19 | h.setFormatter(formatter) 20 | return h 21 | 22 | 23 | def get_resume_name(output_dir, s="output.log"): 24 | output_logs = list(Path(output_dir).glob(f"{s}*")) 25 | if len(output_logs) == 0: 26 | return Path(output_dir) / s 27 | idx = 1 28 | for p in output_logs: 29 | if 
str(p)[-1].isdigit(): 30 | p_idx = int(str(p).split(".")[-1]) 31 | idx = max(idx, p_idx + 1) 32 | return Path(output_dir) / f"{s}.{idx}" 33 | 34 | 35 | def initialize(output_dir, resume=False): 36 | fn = ( 37 | get_resume_name(output_dir) 38 | if resume 39 | else Path(output_dir) / "output.log" 40 | ) 41 | if not fn.parent.exists(): 42 | fn.parent.mkdir(parents=True) 43 | handlers = ( 44 | get_handler(stream=sys.stdout), 45 | get_handler(fn=fn), 46 | ) 47 | logging.basicConfig(handlers=handlers, force=True, level=logging.INFO) 48 | 49 | 50 | def get_logger(name): 51 | logging.basicConfig(handlers=[get_handler(stream=None)], level=logging.INFO) 52 | return logging.getLogger(name) 53 | -------------------------------------------------------------------------------- /src/utils/metric_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from seqeval.metrics import precision_score 3 | from seqeval.metrics import recall_score 4 | from seqeval.metrics import f1_score 5 | from seqeval.scheme import IOB2 6 | 7 | 8 | def eval_induction(X, Y, Y_pred, y_pad_idx=0): 9 | rows = [] 10 | for y in Y: 11 | m = [] 12 | seen = set() 13 | for t in y: 14 | if t not in seen: 15 | m.append(False) 16 | seen.add(t) 17 | else: 18 | m.append(True) 19 | rows.append(m) 20 | mask = np.stack(rows, 0) & (Y != y_pad_idx) 21 | return (Y == Y_pred)[mask].mean() 22 | 23 | 24 | def __f1_score(y_true, y_pred, o_idx): 25 | precision, recall, f1 = 0, 0, 0 26 | m_p = y_pred != o_idx 27 | if m_p.sum() > 0: 28 | precision = (y_true[m_p] == y_pred[m_p]).mean() 29 | m_r = y_true != o_idx 30 | if m_r.sum() > 0: 31 | recall = (y_true[m_r] == y_pred[m_r]).mean() 32 | if precision + recall > 0: 33 | f1 = 2 * precision * recall / (precision + recall) 34 | return {"precision": precision, "recall": recall, "f1": f1} 35 | 36 | 37 | def conll_score(y_true, y_pred): 38 | return { 39 | "precision": precision_score( 40 | y_true, y_pred, mode="strict",
scheme=IOB2 41 | ), 42 | "recall": recall_score(y_true, y_pred, mode="strict", scheme=IOB2), 43 | "f1": f1_score(y_true, y_pred, mode="strict", scheme=IOB2), 44 | } 45 | --------------------------------------------------------------------------------
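The `__f1_score` helper in `src/utils/metric_utils.py` computes a token-level F1 that ignores the "outside" tag: precision is accuracy over positions where an entity tag was predicted, and recall is accuracy over positions where the gold tag is an entity tag. As a sanity check, here is a minimal standalone sketch of that masking logic; the function name `tag_f1` and the tag ids (0 = `O`, 1 = `B-PER`, 2 = `B-LOC`) are hypothetical, chosen only for illustration:

```python
import numpy as np

def tag_f1(y_true, y_pred, o_idx=0):
    """Token-level precision/recall/F1 that ignores the 'outside' tag o_idx.

    Mirrors the masking in __f1_score (src/utils/metric_utils.py):
    precision is accuracy over positions where an entity tag was predicted;
    recall is accuracy over positions where the gold tag is an entity tag.
    """
    precision, recall, f1 = 0.0, 0.0, 0.0
    m_p = y_pred != o_idx  # predicted-entity positions
    if m_p.sum() > 0:
        precision = float((y_true[m_p] == y_pred[m_p]).mean())
    m_r = y_true != o_idx  # gold-entity positions
    if m_r.sum() > 0:
        recall = float((y_true[m_r] == y_pred[m_r]).mean())
    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)
    return {"precision": precision, "recall": recall, "f1": f1}

# Hypothetical tag ids: 0 = O, 1 = B-PER, 2 = B-LOC.
y_true = np.array([0, 1, 2, 0, 1])
y_pred = np.array([0, 1, 0, 2, 1])
scores = tag_f1(y_true, y_pred, o_idx=0)  # precision = recall = f1 = 2/3
```

Note that this token-level score is more lenient than `conll_score`, which uses seqeval's strict IOB2 span matching: there a span counts as correct only if both its boundaries and its type are predicted exactly.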