├── .gitignore ├── README.md ├── baselines ├── README.md ├── difference │ ├── demo.ipynb │ └── diff_explainer.py ├── integrated_gradients │ ├── demo.ipynb │ ├── ig.py │ ├── ig_utils.py │ └── visualize.py ├── integrated_hessians │ ├── demo_bert_sentiment.ipynb │ ├── embedding_explainer_bert.py │ └── path_explain │ │ ├── __init__.py │ │ ├── explainers │ │ ├── __init__.py │ │ ├── embedding_explainer_tf.py │ │ ├── explainer.py │ │ ├── path_explainer_tf.py │ │ └── path_explainer_torch.py │ │ ├── plot │ │ ├── __init__.py │ │ ├── colors.py │ │ ├── scatter.py │ │ ├── summary.py │ │ └── text.py │ │ └── utils.py ├── mahe_madex │ ├── madex │ │ ├── madex_example_dna.ipynb │ │ ├── madex_example_graph.ipynb │ │ ├── madex_example_image.ipynb │ │ ├── madex_example_text.ipynb │ │ ├── neural_interaction_detection.py │ │ ├── sampling_and_inference.py │ │ └── utils │ │ │ ├── data │ │ │ ├── cora │ │ │ │ ├── README │ │ │ │ ├── cora.cites │ │ │ │ └── cora.content │ │ │ └── sample_images │ │ │ │ ├── bus.jpg │ │ │ │ ├── dog.jpg │ │ │ │ ├── shark.jpg │ │ │ │ └── viaduct.jpg │ │ │ ├── dna_utils.py │ │ │ ├── general_utils.py │ │ │ ├── graph_utils.py │ │ │ ├── image_utils.py │ │ │ ├── lime │ │ │ ├── lime_base.py │ │ │ └── lime_text.py │ │ │ ├── linear_cross_utils.py │ │ │ ├── pretrained │ │ │ ├── dna_cnn.pt │ │ │ ├── gcn_cora.pt │ │ │ └── model_gcn.py │ │ │ └── text_utils.py │ └── mahe │ │ ├── deps │ │ ├── interaction_explainer.py │ │ └── lime_scores.py │ │ ├── pipeline_mod.py │ │ └── run_experiment.py ├── scd_soc │ ├── demo.ipynb │ ├── helper.py │ └── hiexpl │ │ ├── __init__.py │ │ ├── algo │ │ ├── __init__.py │ │ ├── cd_func.py │ │ ├── scd_func.py │ │ ├── scd_lstm.py │ │ ├── scd_transformer.py │ │ ├── soc_lstm.py │ │ └── soc_transformer.py │ │ ├── bert │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── convert_gpt2_checkpoint_to_pytorch.py │ │ ├── convert_openai_checkpoint_to_pytorch.py │ │ ├── convert_tf_checkpoint_to_pytorch.py │ │ ├── convert_transfo_xl_checkpoint_to_pytorch.py │ │ ├── decomp_util.py │ │ ├── file_utils.py │ │ ├── filter_sentence.py │ │ ├── global_state.py │ │ ├── modeling.py │ │ ├── modeling_gpt2.py │ │ ├── modeling_openai.py │ │ ├── modeling_transfo_xl.py │ │ ├── modeling_transfo_xl_utilities.py │ │ ├── optimization.py │ │ ├── optimization_openai.py │ │ ├── run_classifier.py │ │ ├── run_lm_finetuning.py │ │ ├── tacred_f1.py │ │ ├── tokenization.py │ │ ├── tokenization_gpt2.py │ │ ├── tokenization_openai.py │ │ └── tokenization_transfo_xl.py │ │ ├── eval_explanations.py │ │ ├── explain.py │ │ ├── lm │ │ └── train.sh │ │ ├── lm_arch.py │ │ ├── lm_train.py │ │ ├── nns │ │ ├── Untitled.ipynb │ │ ├── __init__.py │ │ ├── hiexpl_vocab_atts.pickle │ │ ├── layers.py │ │ ├── linear.sh │ │ ├── linear_model.py │ │ ├── linear_models │ │ │ ├── best_snapshot_devacc_82.11009174311926_devloss_0.6408365964889526_iter_14500_model.pt │ │ │ └── snapshot_acc_0.0000_loss_0.233998_iter_39500_model.pt │ │ ├── model.py │ │ └── vocab │ │ │ └── vocab_sst.pkl │ │ ├── outputs │ │ └── sst │ │ │ ├── scd_bert_results │ │ │ └── scdbert2020.3.txt │ │ │ ├── soc_bert_results │ │ │ ├── socbert2020.3.txt │ │ │ └── socbertsoc2020.3.txt │ │ │ └── soc_results │ │ │ ├── soc.test.txt │ │ │ ├── soc2019.3.txt │ │ │ ├── soctest2020.3.txt │ │ │ └── soctest2020.3_2.txt │ │ ├── scripts │ │ ├── explanations │ │ │ └── explain_sst_lstm.sh │ │ └── train_model │ │ │ └── train_sst_lstm.sh │ │ ├── train.py │ │ ├── utils │ │ ├── __init__.py │ │ ├── agglomeration.py │ │ ├── args.py │ │ ├── parser.py │ │ ├── reader.py │ │ └── tacred_f1.py │ │ ├── 
visualize.py │ │ └── vocab │ │ ├── vocab_sst.pkl │ │ └── vocab_sst_bert.pkl ├── shapley_interaction_index │ ├── demo.ipynb │ └── si_explainer.py └── shapley_taylor_interaction_index │ ├── demo.ipynb │ └── sti_explainer.py ├── demos ├── 1. text analysis │ ├── demo_bert_tf.ipynb │ ├── demo_bert_torch.ipynb │ ├── demo_bert_torch_interactive.ipynb │ ├── demo_bert_torch_submission_figure.ipynb │ └── demo_bert_torch_word_level.ipynb ├── 2. image classification │ ├── demo_dog.ipynb │ ├── demo_imagenet.ipynb │ └── dog.jpg ├── 3. recommendation │ ├── autoint │ │ ├── README.md │ │ ├── model.py │ │ └── train.py │ └── demo.ipynb ├── 4. covid classification │ ├── covid_xray.jpg │ └── demo.ipynb ├── figures │ ├── covid.png │ ├── interactive.gif │ ├── recommendation.png │ └── sentiment.png └── requirements.txt ├── download.py ├── experiments ├── 1. archdetect │ ├── 1. synthetic_performance.ipynb │ ├── 2. redundancy_bert.ipynb │ ├── 2. redundancy_resnet.ipynb │ ├── 2.1. redundancy_analysis_plotting.ipynb │ ├── context_explainer.py │ ├── redundancy.png │ └── synthetic_utils.py ├── 2. archattribute │ ├── analysis │ │ ├── analyze_segment_auc.ipynb │ │ └── analyze_word_phrase_correlation.ipynb │ ├── experiment_utils.py │ ├── parallel_mahe │ │ ├── mahe_segment_auc.py │ │ └── mahe_text_correlation.py │ ├── processed_data │ │ ├── image_data │ │ │ └── coco_to_i1k_map.pickle │ │ ├── prepare_text_ground_truth.ipynb │ │ └── text_data │ │ │ ├── subtree_allphrase_nosentencelabel.pickle │ │ │ └── subtree_single_token.pickle │ ├── segment_auc.ipynb │ └── text_correlation.ipynb ├── README.md └── requirements.txt ├── setup_demos.sh ├── setup_experiments.sh ├── setup_interactive_viz.sh └── src ├── application_utils ├── common_utils.py ├── image_utils.py ├── rec_utils.py ├── text_utils.py ├── text_utils_tf.py ├── text_utils_torch.py └── utils_torch.py ├── explainer.py └── viz ├── colors.py ├── rec.py └── text.py /.gitignore: -------------------------------------------------------------------------------- 1 | downloads/* 2 | __pycache__ 3 | cache 4 | *.bak 5 | *.pyc 6 | .data 7 | .ipynb_checkpoints 8 | -------------------------------------------------------------------------------- /baselines/README.md: -------------------------------------------------------------------------------- 1 | # Baseline Methods 2 | 3 | This directory provides the baseline methods used primarily by the ArchAttribute experiments. Our ArchDetect experiments also use some of these baselines. 4 | 5 | For specialized baseline methods, we used the original implementations as much as possible. For the other baseline methods, we either implemented them ourselves or used public re-implementations. 6 | 7 | ## References 8 | 9 | The following are references for the methods in this directory. 10 | 11 | - Kedar Dhamdhere, Ashish Agarwal, and Mukund Sundararajan. The Shapley Taylor interaction index. arXiv preprint arXiv:1902.05622, 2019. 12 | - Michel Grabisch and Marc Roubens. An axiomatic approach to the concept of interaction among players in cooperative games. International Journal of Game Theory, 28(4):547–565, 1999. 13 | - Joseph D. Janizek, Pascal Sturmfels, and Su-In Lee. Explaining explanations: Axiomatic feature interactions for deep networks. arXiv preprint arXiv:2002.04138, 2020. 14 | - Xisen Jin, Zhongyu Wei, Junyi Du, Xiangyang Xue, and Xiang Ren. Towards hierarchical importance attribution: Explaining compositional semantics for neural sequence models. In ICLR, 2020. 15 | - Mukund Sundararajan, Ankur Taly, and Qiqi Yan. 
Axiomatic attribution for deep networks. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pages 3319–3328. JMLR.org, 2017. 16 | - Michael Tsang, Youbang Sun, Dongxu Ren, and Yan Liu. Can I trust you more? Model-agnostic hierarchical explanations. arXiv preprint arXiv:1812.04801, 2018. 17 | -------------------------------------------------------------------------------- /baselines/difference/demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import sys\n", 11 | "\n", 12 | "sys.path.append(\"../../src\")\n", 13 | "from diff_explainer import DiffExplainer\n", 14 | "sys.path.append(\"../../experiments/1. archdetect\")\n", 15 | "from synthetic_utils import *\n", 16 | "\n", 17 | "%load_ext autoreload\n", 18 | "%autoreload 2" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "## Parameters" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "function_id = 4\n", 35 | "\n", 36 | "p = 40 # num features\n", 37 | "input_value, base_value = 1, -1" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "## Get Data and Synthetic Function" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 3, 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "name": "stdout", 54 | "output_type": "stream", 55 | "text": [ 56 | "function id: 4\n" 57 | ] 58 | } 59 | ], 60 | "source": [ 61 | "input = np.array([input_value]*p)\n", 62 | "baseline = np.array([base_value]*p)\n", 63 | "\n", 64 | "print(\"function id:\", function_id)\n", 65 | "model = synth_model(function_id, input_value, base_value)\n", 66 | "gts = model.get_gts(p)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "## Get Explanation" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 4, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "df = DiffExplainer(model, input=input, baseline=baseline, output_indices=0, batch_size=20) \n", 83 | "explanation = df.explain()" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "## Check Completeness" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 5, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "82.0 82.0\n" 103 | ] 104 | } 105 | ], 106 | "source": [ 107 | "# sum of attributions\n", 108 | "att_sum = sum(list(explanation.values()))\n", 109 | "# f(input) - f(baseline)\n", 110 | "f_diff = (model(input) - model(baseline)).item() \n", 111 | "\n", 112 | "# they should be equal\n", 113 | "assert(att_sum == f_diff)\n", 114 | "print(att_sum, f_diff)" 115 | ] 116 | } 117 | ], 118 | "metadata": { 119 | "kernelspec": { 120 | "display_name": "Python 3", 121 | "language": "python", 122 | "name": "python3" 123 | }, 124 | "language_info": { 125 | "codemirror_mode": { 126 | "name": "ipython", 127 | "version": 3 128 | }, 129 | "file_extension": ".py", 130 | "mimetype": "text/x-python", 131 | "name": "python", 132 | "nbconvert_exporter": "python", 133 | "pygments_lexer": "ipython3", 134 | "version": "3.6.12" 135 | } 136 | }, 137 | "nbformat": 4, 138 | "nbformat_minor": 4 139 | } 140 | 
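A note on the completeness check above: DiffExplainer (defined next in diff_explainer.py) scores each feature set I as f(x*) - f(x'_{I} + x*_{\I}), i.e., the drop in output when only the features in I are reset to their baseline values. When the scored sets partition the features and the function has no interactions across those sets, the per-set scores sum to exactly f(input) - f(baseline), which is what the 82.0 == 82.0 assertion verifies. A minimal self-contained sketch of that identity on a toy additive function (hypothetical names; this is not the repo's synth_model):

import numpy as np

def f(x):
    # Additive toy function: no interactions across features, so singleton
    # difference scores satisfy completeness exactly.
    return float(np.sum(x ** 2 + 3 * x))

x_input = np.ones(5)    # analogous to "input" in the demo
x_base = -np.ones(5)    # analogous to "baseline" in the demo

scores = {}
for i in range(len(x_input)):
    x_ditch = x_input.copy()
    x_ditch[i] = x_base[i]  # reset feature i to its baseline value
    scores[(i,)] = f(x_input) - f(x_ditch)

# Completeness: the attributions sum to f(input) - f(baseline).
assert np.isclose(sum(scores.values()), f(x_input) - f(x_base))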
-------------------------------------------------------------------------------- /baselines/difference/diff_explainer.py: -------------------------------------------------------------------------------- 1 | from explainer import Archipelago 2 | 3 | 4 | class DiffExplainer(Archipelago): 5 | def __init__( 6 | self, 7 | model, 8 | input=None, 9 | baseline=None, 10 | data_xformer=None, 11 | output_indices=0, 12 | batch_size=2, 13 | ): 14 | Archipelago.__init__( 15 | self, model, input, baseline, data_xformer, output_indices, batch_size 16 | ) 17 | 18 | def difference_attribution(self, set_indices): 19 | """ 20 | Gets attributions of index sets by f(x*) - f(x'_{I} + x*_{\I}) 21 | """ 22 | if not set_indices: 23 | return dict() 24 | scores = self.batch_set_inference( 25 | set_indices, self.input, self.baseline, include_context=True 26 | ) 27 | ditch_scores = scores["scores"] 28 | input_score = scores["context_score"] 29 | set_scores = {} 30 | for index_tuple in ditch_scores: 31 | set_scores[index_tuple] = input_score - ditch_scores[index_tuple] 32 | return set_scores 33 | -------------------------------------------------------------------------------- /baselines/integrated_gradients/ig.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def integrated_gradients( 5 | inputs, 6 | model, 7 | target_label_idx, 8 | get_gradients, 9 | baseline, 10 | device, 11 | steps=50, 12 | softmax=False, 13 | ): 14 | if baseline is None: 15 | baseline = 0 * inputs 16 | # scale inputs and compute gradients 17 | scaled_inputs = [ 18 | baseline + (float(i) / steps) * (inputs - baseline) for i in range(0, steps + 1) 19 | ] 20 | grads = get_gradients( 21 | scaled_inputs, model, target_label_idx, device, softmax=softmax 22 | ) 23 | avg_grads = np.average(grads[:-1], axis=0) 24 | integrated_grad = (inputs - baseline) * avg_grads 25 | return integrated_grad 26 | -------------------------------------------------------------------------------- /baselines/integrated_gradients/ig_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | # import cv2 5 | import numpy as np 6 | 7 | 8 | def get_gradients(scaled_inputs, model, target_label_idx, device, softmax=False): 9 | grads = [] 10 | for i, input in enumerate(scaled_inputs): 11 | torch_input = torch.FloatTensor(input).unsqueeze(0).to(device) 12 | torch_input.requires_grad = True 13 | pred = model(torch_input) 14 | if softmax: 15 | pred = F.softmax(pred, dim=1) 16 | output = pred[:, target_label_idx] 17 | model.zero_grad() 18 | output.backward() 19 | grad = torch_input.grad.detach().cpu().numpy() 20 | grads.append(grad) 21 | grads = np.concatenate(grads) 22 | return grads 23 | -------------------------------------------------------------------------------- /baselines/integrated_gradients/visualize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | G = [0, 255, 0] 5 | R = [255, 0, 0] 6 | 7 | 8 | def convert_to_gray_scale(attributions): 9 | return np.average(attributions, axis=2) 10 | 11 | 12 | def linear_transform( 13 | attributions, 14 | clip_above_percentile=99.9, 15 | clip_below_percentile=70.0, 16 | low=0.2, 17 | plot_distribution=False, 18 | ): 19 | m = compute_threshold_by_top_percentage( 20 | attributions, 21 | percentage=100 - clip_above_percentile, 22 | plot_distribution=plot_distribution, 23 | ) 24 | e = 
compute_threshold_by_top_percentage( 25 | attributions, 26 | percentage=100 - clip_below_percentile, 27 | plot_distribution=plot_distribution, 28 | ) 29 | transformed = (1 - low) * (np.abs(attributions) - e) / (m - e) + low 30 | transformed *= np.sign(attributions) 31 | transformed *= transformed >= low 32 | transformed = np.clip(transformed, 0.0, 1.0) 33 | return transformed 34 | 35 | 36 | def compute_threshold_by_top_percentage( 37 | attributions, percentage=60, plot_distribution=True 38 | ): 39 | if percentage < 0 or percentage > 100: 40 | raise ValueError("percentage must be in [0, 100]") 41 | if percentage == 100: 42 | return np.min(attributions) 43 | flat_attributions = attributions.flatten() 44 | attribution_sum = np.sum(flat_attributions) 45 | sorted_attributions = np.sort(np.abs(flat_attributions))[::-1] 46 | cum_sum = 100.0 * np.cumsum(sorted_attributions) / attribution_sum 47 | threshold_idx = np.where(cum_sum >= percentage)[0][0] 48 | threshold = sorted_attributions[threshold_idx] 49 | if plot_distribution: 50 | raise NotImplementedError 51 | return threshold 52 | 53 | 54 | def polarity_function(attributions, polarity): 55 | if polarity == "positive": 56 | return np.clip(attributions, 0, 1) 57 | elif polarity == "negative": 58 | return np.clip(attributions, -1, 0) 59 | else: 60 | raise NotImplementedError 61 | 62 | 63 | def overlay_function(attributions, image): 64 | return np.clip(0.7 * image + 0.5 * attributions, 0, 255) 65 | 66 | 67 | def visualize( 68 | attributions, 69 | image, 70 | positive_channel=G, 71 | negative_channel=R, 72 | polarity="positive", 73 | clip_above_percentile=99.9, 74 | clip_below_percentile=0, 75 | morphological_cleanup=False, 76 | structure=np.ones((3, 3)), 77 | outlines=False, 78 | outlines_component_percentage=90, 79 | overlay=True, 80 | mask_mode=False, 81 | plot_distribution=False, 82 | ): 83 | if polarity == "both": 84 | raise NotImplementedError 85 | 86 | elif polarity == "positive": 87 | attributions = polarity_function(attributions, polarity=polarity) 88 | channel = positive_channel 89 | 90 | # convert the attributions to the gray scale 91 | attributions = convert_to_gray_scale(attributions) 92 | attributions = linear_transform( 93 | attributions, 94 | clip_above_percentile, 95 | clip_below_percentile, 96 | 0.0, 97 | plot_distribution=plot_distribution, 98 | ) 99 | attributions_mask = attributions.copy() 100 | if morphological_cleanup: 101 | raise NotImplementedError 102 | if outlines: 103 | raise NotImplementedError 104 | attributions = np.expand_dims(attributions, 2) * channel 105 | if overlay: 106 | if mask_mode == False: 107 | attributions = overlay_function(attributions, image) 108 | else: 109 | attributions = np.expand_dims(attributions_mask, 2) 110 | attributions = np.clip(attributions * image, 0, 255) 111 | attributions = attributions[:, :, (2, 1, 0)] 112 | return attributions 113 | -------------------------------------------------------------------------------- /baselines/integrated_hessians/path_explain/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | A module for explaining the output of gradient-based 3 | models using path attributions. 
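
Typical usage follows a two-step pattern (a minimal sketch; it assumes a
trained tf.keras model plus matching inputs/baseline arrays, and the exact
constructor arguments live in explainers/path_explainer_tf.py):

    from path_explain import PathExplainerTF
    explainer = PathExplainerTF(model)
    attributions = explainer.attributions(
        inputs, baseline, batch_size=50, num_samples=100,
        use_expectation=True, output_indices=0,
    )

The keyword arguments mirror the Explainer base class defined in
explainers/explainer.py.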
4 | """ 5 | __version__ = "1.0" 6 | 7 | from .explainers.path_explainer_tf import PathExplainerTF 8 | from .utils import set_up_environment, softplus_activation 9 | from .plot.scatter import scatter_plot 10 | from .plot.summary import summary_plot 11 | from .plot.text import text_plot, matrix_interaction_plot, bar_interaction_plot 12 | -------------------------------------------------------------------------------- /baselines/integrated_hessians/path_explain/explainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/baselines/integrated_hessians/path_explain/explainers/__init__.py -------------------------------------------------------------------------------- /baselines/integrated_hessians/path_explain/explainers/explainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains the base class for the two explainer objects. 3 | """ 4 | 5 | 6 | class Explainer: 7 | """ 8 | A superclass for all explainer objects. 9 | This (somewhat) matches the SHAP 10 | package in terms of API. 11 | """ 12 | 13 | def attributions( 14 | self, 15 | inputs, 16 | baseline, 17 | batch_size=50, 18 | num_samples=100, 19 | use_expectation=True, 20 | output_indices=None, 21 | verbose=False, 22 | ): 23 | """ 24 | A function that returns the path attributions for the 25 | given inputs. 26 | """ 27 | raise Exception( 28 | "Attributions have not been implemented " 29 | + "for this class. Likely, you have imported " 30 | + "the wrong class from this package." 31 | ) 32 | 33 | def interactions( 34 | self, 35 | inputs, 36 | baseline, 37 | batch_size=50, 38 | num_samples=100, 39 | use_expectation=True, 40 | output_indices=None, 41 | verbose=False, 42 | interaction_index=None, 43 | ): 44 | """ 45 | A function that returns the path interactions for the 46 | given inputs. 47 | """ 48 | raise Exception( 49 | "Interactions have not been implemented " 50 | + "for this class. Likely, you have imported " 51 | + "the wrong class from this package." 52 | ) 53 | -------------------------------------------------------------------------------- /baselines/integrated_hessians/path_explain/plot/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | A module for plotting path attributions. Built 3 | on top of altair, and occasionally 4 | matplotlib. 5 | """ 6 | -------------------------------------------------------------------------------- /baselines/integrated_hessians/path_explain/plot/colors.py: -------------------------------------------------------------------------------- 1 | """ 2 | A place for defining the default color scheme. 3 | """ 4 | 5 | import numpy as np 6 | import matplotlib as mpl 7 | 8 | 9 | def green_gold(): 10 | """ 11 | Returns the green and gold colormap we use as the 12 | default color scheme for this repository. 13 | """ 14 | color_map_size = 256 15 | vals = np.ones((color_map_size, 4)) 16 | vals[:, 0] = np.linspace(20 / 256, 250 / 256, color_map_size) 17 | vals[:, 1] = np.linspace(125 / 256, 230 / 256, color_map_size) 18 | vals[:, 2] = np.linspace(0 / 256, 0 / 256, color_map_size) 19 | cmap = mpl.colors.ListedColormap(vals) 20 | return cmap 21 | 22 | 23 | def maroon_white_aqua(): 24 | """ 25 | Returns the green and gold colormap we use as the 26 | default color scheme for plotting text. 
27 | """ 28 | color_map_size = 256 29 | vals = np.ones((color_map_size, 4)) 30 | vals[: int(color_map_size / 2), 0] = np.linspace( 31 | 140 / 256, 1.0, int(color_map_size / 2) 32 | ) 33 | vals[: int(color_map_size / 2), 1] = np.linspace( 34 | 15 / 256, 1.0, int(color_map_size / 2) 35 | ) 36 | vals[: int(color_map_size / 2), 2] = np.linspace( 37 | 15 / 256, 1.0, int(color_map_size / 2) 38 | ) 39 | 40 | vals[int(color_map_size / 2) :, 0] = np.linspace( 41 | 1.0, 0 / 256, int(color_map_size / 2) 42 | ) 43 | vals[int(color_map_size / 2) :, 1] = np.linspace( 44 | 1.0, 220 / 256, int(color_map_size / 2) 45 | ) 46 | vals[int(color_map_size / 2) :, 2] = np.linspace( 47 | 1.0, 170 / 256, int(color_map_size / 2) 48 | ) 49 | cmap = mpl.colors.ListedColormap(vals) 50 | return cmap 51 | -------------------------------------------------------------------------------- /baselines/mahe_madex/madex/madex_example_dna.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from torchtext import datasets, data\n", 10 | "import numpy as np\n", 11 | "import os, sys\n", 12 | "from time import time\n", 13 | "\n", 14 | "sys.path.append(\"../1. madex\")\n", 15 | "\n", 16 | "from neural_interaction_detection import *\n", 17 | "from sampling_and_inference import *\n", 18 | "from utils.dna_utils import *\n", 19 | "\n", 20 | "%matplotlib inline\n", 21 | "\n", 22 | "import warnings\n", 23 | "warnings.filterwarnings(\"ignore\")\n", 24 | "\n", 25 | "%load_ext autoreload\n", 26 | "%autoreload 2\n", 27 | "\n", 28 | "device = torch.device(\"cuda:0\")" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "## Load Model" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "model = load_dna_model(\"utils/pretrained/dna_cnn.pt\").to(device)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "## Get DNA Sequence" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 3, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "name": "stdout", 61 | "output_type": "stream", 62 | "text": [ 63 | "GTAGGTAAGCGCACGTGTTGCACTTCCCTTAATCCA True\n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "np.random.seed(42)\n", 69 | "seq_instance = generate_random_dna_sequence_with_CACGTG()\n", 70 | "print(seq_instance, \"CACGTG\" in seq_instance)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "## Run MADEX" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 4, 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "name": "stderr", 87 | "output_type": "stream", 88 | "text": [ 89 | "100%|██████████| 60/60 [00:02<00:00, 29.46it/s]\n" 90 | ] 91 | } 92 | ], 93 | "source": [ 94 | "data_inst = {\"orig\": seq_instance, \"vectorizer\": encode_dna_onehot}\n", 95 | "Xs, Ys = generate_perturbation_dataset_dna(data_inst, model, device, seed=42)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 5, 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "name": "stdout", 105 | "output_type": "stream", 106 | "text": [ 107 | "0.0046 test loss, 16.0 seconds elapsed\n" 108 | ] 109 | } 110 | ], 111 | "source": [ 112 | "t0 = time()\n", 113 | "interactions, mlp_loss = detect_interactions(Xs, Ys, weight_samples=False, seed=42, 
verbose=False, add_linear=False)\n", 114 | "print(\"{} test loss, {} seconds elapsed\".format(round(mlp_loss, 4), round(time() - t0, 1)))" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 6, 120 | "metadata": {}, 121 | "outputs": [ 122 | { 123 | "name": "stdout", 124 | "output_type": "stream", 125 | "text": [ 126 | "interaction ranking \n", 127 | "\n", 128 | "1 found CACGTG >> ('C_11', 'A_12', 'C_13', 'G_14', 'T_15', 'G_16')\n", 129 | "2 ('A_21', 'C_25')\n", 130 | "3 ('C_11', 'A_12', 'C_13', 'G_14', 'T_15', 'G_16', 'A_21')\n", 131 | "4 ('C_11', 'A_12', 'C_13', 'G_14', 'T_15', 'G_16', 'T_18')\n", 132 | "5 ('A_21', 'C_25', 'C_26')\n", 133 | "6 ('A_21', 'T_23', 'C_25', 'C_26')\n", 134 | "7 ('A_21', 'T_23', 'C_25', 'C_26', 'T_28')\n", 135 | "8 ('A_2', 'C_11', 'A_12', 'C_13', 'G_14', 'T_15', 'G_16', 'T_18')\n", 136 | "9 ('A_2', 'A_6', 'C_11', 'A_12', 'C_13', 'G_14', 'T_15', 'G_16', 'T_18')\n", 137 | "10 ('A_2', 'A_6', 'C_11', 'A_12', 'C_13', 'G_14', 'T_15', 'G_16', 'T_18', 'C_20')\n" 138 | ] 139 | } 140 | ], 141 | "source": [ 142 | "print(\"interaction ranking\", \"\\n\")\n", 143 | "for rank, inter in enumerate(interactions[:10]):\n", 144 | " inter_indices, _ = inter\n", 145 | " inter_verbose = tuple((seq_instance[s], s) for s in inter_indices)\n", 146 | "\n", 147 | " inter_nucleotides, _ = zip(*inter_verbose)\n", 148 | " if \"\".join(inter_nucleotides) == \"CACGTG\" and all(np.diff(inter_indices) == 1):\n", 149 | " postfix = \"found CACGTG >>\"\n", 150 | " else:\n", 151 | " postfix = \"\"\n", 152 | " print(rank+1, postfix, tuple(a + \"_\" + str(b) for a,b in inter_verbose))\n", 153 | "\n" 154 | ] 155 | } 156 | ], 157 | "metadata": { 158 | "kernelspec": { 159 | "display_name": "Python [conda env:torch]", 160 | "language": "python", 161 | "name": "conda-env-torch-py" 162 | }, 163 | "language_info": { 164 | "codemirror_mode": { 165 | "name": "ipython", 166 | "version": 3 167 | }, 168 | "file_extension": ".py", 169 | "mimetype": "text/x-python", 170 | "name": "python", 171 | "nbconvert_exporter": "python", 172 | "pygments_lexer": "ipython3", 173 | "version": "3.6.2" 174 | } 175 | }, 176 | "nbformat": 4, 177 | "nbformat_minor": 4 178 | } 179 | -------------------------------------------------------------------------------- /baselines/mahe_madex/madex/madex_example_graph.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from torchtext import datasets, data\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import numpy as np\n", 12 | "import os, sys\n", 13 | "from time import time\n", 14 | "\n", 15 | "from neural_interaction_detection import *\n", 16 | "from sampling_and_inference import *\n", 17 | "from utils.general_utils import *\n", 18 | "from utils.graph_utils import *\n", 19 | "\n", 20 | "%matplotlib inline\n", 21 | "\n", 22 | "import warnings\n", 23 | "warnings.filterwarnings(\"ignore\")\n", 24 | "\n", 25 | "%load_ext autoreload\n", 26 | "%autoreload 2\n", 27 | "\n", 28 | "device = torch.device(\"cuda:0\")" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "## Load Model" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "model_folder = \"utils/pretrained\"\n", 45 | "\n", 46 | "model, n_nodes, n_hops, test_idxs = get_graph_model(model_folder)\n", 47 | "model = 
model.to(device)" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": {}, 53 | "source": [ 54 | "## Classify Graph" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "data_folder = \"utils/data/cora\"\n", 64 | "\n", 65 | "node_feats, adj_mat, labels = load_cora(data_folder, device)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 4, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "target node classification: 6\n" 78 | ] 79 | } 80 | ], 81 | "source": [ 82 | "target_idx = test_idxs[0]\n", 83 | "\n", 84 | "preds = model(node_feats, convert_adj_to_da(adj_mat))\n", 85 | "classification = torch.argmax(preds, 1).cpu().numpy()[target_idx] \n", 86 | "print(\"target node classification:\", classification)" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "## Run MADEX" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 5, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "name": "stderr", 103 | "output_type": "stream", 104 | "text": [ 105 | "100%|██████████| 6000/6000 [01:40<00:00, 59.72it/s]\n" 106 | ] 107 | } 108 | ], 109 | "source": [ 110 | "data_inst = {\"nodes\": node_feats, \"edges\": adj_mat, \"test_idxs\": test_idxs}\n", 111 | "Xs, Ys = generate_perturbation_dataset_graph(data_inst, model, target_idx, n_hops+1, device, seed=42, std_scale=False)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 6, 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "name": "stdout", 121 | "output_type": "stream", 122 | "text": [ 123 | "19.4754 test loss, 94.2 seconds elapsed\n" 124 | ] 125 | } 126 | ], 127 | "source": [ 128 | "t0 = time()\n", 129 | "interactions, mlp_loss = detect_interactions(Xs, Ys, weight_samples=True, seed=42, verbose=False)\n", 130 | "print(\"{} test loss, {} seconds elapsed\".format(round(mlp_loss, 4), round(time() - t0, 1)))" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": {}, 136 | "source": [ 137 | "## Show Main Effects and Interaction Interpretations" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 7, 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "name": "stdout", 147 | "output_type": "stream", 148 | "text": [ 149 | "legend: (hops from target node, node idx). All hops should be within n_hops: 3\n", 150 | "\n", 151 | "target (0, 1808)\n", 152 | "\n", 153 | "main effects\n", 154 | "(2, 722)\n", 155 | "(2, 2465)\n", 156 | "(2, 264)\n", 157 | "(2, 1189)\n", 158 | "(2, 2146)\n", 159 | "\n", 160 | "interactions\n", 161 | "2\n", 162 | "inter 0: ((1, 638), (2, 722))\n", 163 | "4\n", 164 | "inter 1: ((2, 264), (1, 638), (2, 722), (2, 2465))\n", 165 | "5\n", 166 | "inter 2: ((2, 264), (1, 638), (2, 722), (2, 1189), (2, 2465))\n", 167 | "6\n", 168 | "inter 3: ((2, 264), (1, 638), (2, 722), (2, 1189), (2, 2146), (2, 2465))\n", 169 | "9\n", 170 | "inter 4: ((2, 264), (2, 294), (2, 296), (1, 638), (2, 722), (2, 1189), (2, 1327), (2, 2146), (2, 2465))\n" 171 | ] 172 | } 173 | ], 174 | "source": [ 175 | "node_to_hop = get_hops_to_target(target_idx, adj_mat, n_hops)\n", 176 | "local_map = data_inst[\"local_idx_map\"]\n", 177 | "\n", 178 | "print(\"legend: (hops from target node, node idx). 
All hops should be within n_hops:\", n_hops)\n", 179 | "\n", 180 | "print(\"\\ntarget\", (0, target_idx))\n", 181 | "print(\"\\nmain effects\")\n", 182 | "for uni, att in get_lime_attributions(Xs, Ys)[:5]:\n", 183 | " if att > 0:\n", 184 | " print((node_to_hop[local_map[uni]],local_map[uni]))\n", 185 | "print(\"\\ninteractions\")\n", 186 | "for i, inter in enumerate(interactions[:5]):\n", 187 | " print(len(inter[0]))\n", 188 | " print(\"inter {}:\".format(i), tuple((node_to_hop[local_map[n]],local_map[n]) for n in inter[0]))\n" 189 | ] 190 | } 191 | ], 192 | "metadata": { 193 | "kernelspec": { 194 | "display_name": "Python [conda env:torch]", 195 | "language": "python", 196 | "name": "conda-env-torch-py" 197 | }, 198 | "language_info": { 199 | "codemirror_mode": { 200 | "name": "ipython", 201 | "version": 3 202 | }, 203 | "file_extension": ".py", 204 | "mimetype": "text/x-python", 205 | "name": "python", 206 | "nbconvert_exporter": "python", 207 | "pygments_lexer": "ipython3", 208 | "version": "3.6.2" 209 | } 210 | }, 211 | "nbformat": 4, 212 | "nbformat_minor": 4 213 | } 214 | -------------------------------------------------------------------------------- /baselines/mahe_madex/madex/madex_example_text.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from torchtext import datasets, data\n", 10 | "import numpy as np\n", 11 | "import os, sys\n", 12 | "from time import time\n", 13 | "\n", 14 | "from neural_interaction_detection import *\n", 15 | "from sampling_and_inference import *\n", 16 | "from utils.general_utils import *\n", 17 | "from utils.text_utils import *\n", 18 | "\n", 19 | "import warnings\n", 20 | "warnings.filterwarnings(\"ignore\")\n", 21 | "%load_ext autoreload\n", 22 | "%autoreload 2\n", 23 | "\n", 24 | "device = torch.device(\"cuda:0\")" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "## Load Model" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 2, 37 | "metadata": {}, 38 | "outputs": [ 39 | { 40 | "name": "stderr", 41 | "output_type": "stream", 42 | "text": [ 43 | "Widget Javascript not detected. 
It may not be installed or enabled properly.\n" 44 | ] 45 | }, 46 | { 47 | "data": { 48 | "application/vnd.jupyter.widget-view+json": { 49 | "model_id": "f1b3bc5eb16f46edb5fcd643307db412" 50 | } 51 | }, 52 | "metadata": {}, 53 | "output_type": "display_data" 54 | }, 55 | { 56 | "name": "stdout", 57 | "output_type": "stream", 58 | "text": [ 59 | "\n" 60 | ] 61 | } 62 | ], 63 | "source": [ 64 | "model = get_bert_model(device)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "## Classify Sentence" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 3, 77 | "metadata": {}, 78 | "outputs": [ 79 | { 80 | "name": "stdout", 81 | "output_type": "stream", 82 | "text": [ 83 | "positive sentiment\n" 84 | ] 85 | } 86 | ], 87 | "source": [ 88 | "sentence = \"this was not a great movie, but a good movie nevertheless\"\n", 89 | "\n", 90 | "out = model(sentence)\n", 91 | "pred = np.argmax(out[0])\n", 92 | "print((\"positive\" if pred== 1 else \"negative\") + \" sentiment\")" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "## Run MADEX" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 4, 105 | "metadata": {}, 106 | "outputs": [ 107 | { 108 | "name": "stderr", 109 | "output_type": "stream", 110 | "text": [ 111 | "100%|██████████| 12/12 [00:10<00:00, 1.15it/s]\n" 112 | ] 113 | } 114 | ], 115 | "source": [ 116 | "data_inst = {\"orig\": sentence}\n", 117 | "Xs, Ys = generate_perturbation_dataset_text(data_inst, model, 1, device, model_id=\"bert\", batch_size=500, seed=42, std_scale=True)" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 5, 123 | "metadata": {}, 124 | "outputs": [ 125 | { 126 | "name": "stdout", 127 | "output_type": "stream", 128 | "text": [ 129 | "0.0142 test loss, 29.9 seconds elapsed\n" 130 | ] 131 | } 132 | ], 133 | "source": [ 134 | "t0 = time()\n", 135 | "interactions, mlp_loss = detect_interactions(Xs, Ys, detector=\"GradientNID\", add_linear=True, device=device, weight_samples=True, seed=42, verbose=False)\n", 136 | "print(\"{} test loss, {} seconds elapsed\".format(round(mlp_loss, 4), round(time() - t0, 1)))" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "## Show Main Effects and Interaction Interpretations" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 6, 149 | "metadata": {}, 150 | "outputs": [ 151 | { 152 | "name": "stdout", 153 | "output_type": "stream", 154 | "text": [ 155 | "this was not a great movie, but a good movie nevertheless\n", 156 | "\n", 157 | "main effects: ('but', 'a', 'good', 'movie', 'nevertheless')\n", 158 | "\n", 159 | "top-5 interactions\n", 160 | "inter 1: ('not', 'but') 2.7557428\n", 161 | "inter 2: ('but', 'good') 1.9747727\n", 162 | "inter 3: ('not', 'good') 1.8207084\n", 163 | "inter 4: ('great', 'good') 1.3452238\n", 164 | "inter 5: ('not', 'great') 1.2503706\n" 165 | ] 166 | } 167 | ], 168 | "source": [ 169 | "print(sentence + \"\\n\")\n", 170 | "\n", 171 | "dom_map = data_inst[\"domain_mapper\"]\n", 172 | "\n", 173 | "lime_atts = get_lime_attributions(Xs, Ys)\n", 174 | "print(\"main effects:\", map_words([i for i, a in lime_atts if a*(pred*2-1) > 0], dom_map))\n", 175 | "\n", 176 | "print(\"\\ntop-5 interactions\")\n", 177 | "for i, inter_tuple in enumerate(interactions[:5]):\n", 178 | " inter, strength = inter_tuple\n", 179 | " word_inter = map_words(inter, dom_map)\n", 180 | " print(\"inter 
{}:\".format(i+1), word_inter, strength)" 181 | ] 182 | } 183 | ], 184 | "metadata": { 185 | "kernelspec": { 186 | "display_name": "Python [conda env:torch]", 187 | "language": "python", 188 | "name": "conda-env-torch-py" 189 | }, 190 | "language_info": { 191 | "codemirror_mode": { 192 | "name": "ipython", 193 | "version": 3 194 | }, 195 | "file_extension": ".py", 196 | "mimetype": "text/x-python", 197 | "name": "python", 198 | "nbconvert_exporter": "python", 199 | "pygments_lexer": "ipython3", 200 | "version": "3.6.2" 201 | } 202 | }, 203 | "nbformat": 4, 204 | "nbformat_minor": 4 205 | } 206 | -------------------------------------------------------------------------------- /baselines/mahe_madex/madex/utils/data/cora/README: -------------------------------------------------------------------------------- 1 | This directory contains the a selection of the Cora dataset (www.research.whizbang.com/data). 2 | 3 | The Cora dataset consists of Machine Learning papers. These papers are classified into one of the following seven classes: 4 | Case_Based 5 | Genetic_Algorithms 6 | Neural_Networks 7 | Probabilistic_Methods 8 | Reinforcement_Learning 9 | Rule_Learning 10 | Theory 11 | 12 | The papers were selected in a way such that in the final corpus every paper cites or is cited by atleast one other paper. There are 2708 papers in the whole corpus. 13 | 14 | After stemming and removing stopwords we were left with a vocabulary of size 1433 unique words. All words with document frequency less than 10 were removed. 15 | 16 | 17 | THE DIRECTORY CONTAINS TWO FILES: 18 | 19 | The .content file contains descriptions of the papers in the following format: 20 | 21 | + 22 | 23 | The first entry in each line contains the unique string ID of the paper followed by binary values indicating whether each word in the vocabulary is present (indicated by 1) or absent (indicated by 0) in the paper. Finally, the last entry in the line contains the class label of the paper. 24 | 25 | The .cites file contains the citation graph of the corpus. Each line describes a link in the following format: 26 | 27 | 28 | 29 | Each line contains two paper IDs. The first entry is the ID of the paper being cited and the second ID stands for the paper which contains the citation. The direction of the link is from right to left. If a line is represented by "paper1 paper2" then the link is "paper2->paper1". 
-------------------------------------------------------------------------------- /baselines/mahe_madex/madex/utils/data/sample_images/bus.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/baselines/mahe_madex/madex/utils/data/sample_images/bus.jpg -------------------------------------------------------------------------------- /baselines/mahe_madex/madex/utils/data/sample_images/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/baselines/mahe_madex/madex/utils/data/sample_images/dog.jpg -------------------------------------------------------------------------------- /baselines/mahe_madex/madex/utils/data/sample_images/shark.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/baselines/mahe_madex/madex/utils/data/sample_images/shark.jpg -------------------------------------------------------------------------------- /baselines/mahe_madex/madex/utils/data/sample_images/viaduct.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/baselines/mahe_madex/madex/utils/data/sample_images/viaduct.jpg -------------------------------------------------------------------------------- /baselines/mahe_madex/madex/utils/dna_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import h5py as h5 4 | import numpy as np 5 | from utils.general_utils import * 6 | 7 | # from sampling_and_inference import * 8 | 9 | 10 | class Flatten(nn.Module): 11 | def forward(self, input): 12 | return input.view(input.size(0), -1) 13 | 14 | 15 | def createConv1D(n_inp, n_out, hidden_units, kernel_size, seq_len, activation=nn.ReLU): 16 | 17 | layers = [] 18 | layers_size = [n_inp] + hidden_units 19 | for i in range(len(layers_size) - 1): 20 | layers.append(nn.Conv1d(layers_size[i], layers_size[i + 1], kernel_size)) 21 | if activation is not None: 22 | layers.append(activation()) 23 | layers.append(Flatten()) 24 | seq_len = seq_len - (kernel_size - 1) * len(hidden_units) 25 | linear_dim = layers_size[-1] * seq_len 26 | layers.append(nn.Linear(linear_dim, n_out)) 27 | 28 | return nn.Sequential(*layers) 29 | 30 | 31 | class conv1D(nn.Module): 32 | def __init__(self, n_inp, n_out, hidden_units, kernel_size, seq_len, **kwargs): 33 | super(conv1D, self).__init__() 34 | self.conv1D = createConv1D(n_inp, n_out, hidden_units, kernel_size, seq_len) 35 | 36 | def forward(self, x): 37 | return self.conv1D(x) 38 | 39 | 40 | def load_dna_model(path): 41 | model = conv1D(4, 1, [64, 64], 5, 36) 42 | model.load_state_dict(torch.load(path)) 43 | return model 44 | 45 | 46 | def generate_random_dna_sequence_with_CACGTG(length=36, seed=None): 47 | if seed is not None: 48 | set_seed(seed) 49 | 50 | nucleotides = ["A", "C", "G", "T"] 51 | seq = "" 52 | ebox = "CACGTG" 53 | for i in np.random.randint(0, 4, (length)): 54 | seq += nucleotides[i] 55 | i = np.random.randint(0, length - len(ebox)) 56 | seq = seq[:i] + ebox + seq[i + len(ebox) :] 57 | return seq 58 | 59 | 60 | def encode_dna_onehot(seq): 61 | seq_as_list = list(seq) 62 | 63 | for i, c in 
enumerate(seq_as_list): 64 | if c == "A": 65 | seq_as_list[i] = [1, 0, 0, 0] 66 | elif c == "T": 67 | seq_as_list[i] = [0, 1, 0, 0] 68 | elif c == "C": 69 | seq_as_list[i] = [0, 0, 1, 0] 70 | elif c == "G": 71 | seq_as_list[i] = [0, 0, 0, 1] 72 | else: 73 | seq_as_list[i] = [0, 0, 0, 0] 74 | 75 | return np.array(seq_as_list) 76 | 77 | 78 | class IndexedNucleotides(object): 79 | """String with various indexes.""" 80 | 81 | """Based on LIME official Repo""" 82 | 83 | def __init__(self, raw_string): 84 | """Initializer. 85 | 86 | Args: 87 | raw_string: string with raw text in it 88 | """ 89 | self.raw = raw_string 90 | self.as_list = list(self.raw) 91 | self.as_np = np.array(self.as_list) 92 | self.string_start = np.arange(len(self.raw)) 93 | vocab = {} 94 | self.inverse_vocab = [] 95 | self.positions = [] 96 | non_vocab = set() 97 | for i, char in enumerate(self.as_np): 98 | if char in non_vocab: 99 | continue 100 | self.inverse_vocab.append(char) 101 | self.positions.append(i) 102 | self.positions = np.array(self.positions) 103 | 104 | def raw_string(self): 105 | """Returns the original raw string""" 106 | return self.raw 107 | 108 | def num_nucleotides(self): 109 | """Returns the number of tokens in the vocabulary for this document.""" 110 | return len(self.inverse_vocab) 111 | 112 | def choose_alt(self, existing): 113 | nucleotides = ["A", "T", "G", "C"] 114 | nucleotides.remove(existing) 115 | return nucleotides[np.random.randint(0, 3)] 116 | 117 | def perturb_nucleotide(self, chars_to_remove): 118 | mask = np.ones(self.as_np.shape[0], dtype="bool") 119 | mask[self.__get_idxs(chars_to_remove)] = False 120 | return "".join( 121 | [ 122 | self.as_list[i] if mask[i] else self.choose_alt(self.as_list[i]) 123 | for i in range(mask.shape[0]) 124 | ] 125 | ) 126 | 127 | def __get_idxs(self, chars): 128 | """Returns indexes to appropriate words.""" 129 | return self.positions[chars] 130 | -------------------------------------------------------------------------------- /baselines/mahe_madex/madex/utils/graph_utils.py: -------------------------------------------------------------------------------- 1 | from utils.pretrained.model_gcn import * 2 | from collections import defaultdict 3 | import numpy as np 4 | import copy 5 | 6 | 7 | def get_graph_model(model_folder): 8 | 9 | meta = torch.load(model_folder + "/gcn_cora.pt") 10 | 11 | n_hops = meta["n_hops"] 12 | n_nodes = meta["n_nodes"] 13 | test_idxs = meta["test_idxs"] 14 | n_samples = meta["n_samples"] 15 | dim_inp = meta["dim_inp"] 16 | dim_hid = meta["dim_hid"] 17 | dim_out = meta["dim_out"] 18 | 19 | model = create_model(dim_inp, dim_hid, dim_out, n_samples, n_hops) 20 | model.load_state_dict(meta["state_dict"]) 21 | 22 | return model, n_nodes, n_hops, test_idxs 23 | 24 | 25 | def convert_adj_to_da(adj_mat, make_undirected=False): 26 | # Converts adjacency to laplacian matrix 27 | if isinstance(adj_mat, np.ndarray): 28 | adj_mat = torch.from_numpy(adj_mat).float() 29 | if make_undirected: 30 | diag = torch.diag(torch.diag(adj_mat)) 31 | x = adj_mat - diag 32 | adj_mat = x + x.t() + adj_mat 33 | 34 | da_mat = torch.eye(len(adj_mat)).to(adj_mat.device) - adj_mat 35 | return da_mat 36 | 37 | 38 | def load_cora(data_folder, device): 39 | num_nodes = 2708 40 | num_feats = 1433 41 | feat_data = np.zeros((num_nodes, num_feats)) 42 | labels = np.empty((num_nodes, 1), dtype=np.int64) 43 | node_map = {} 44 | label_map = {} 45 | with open(data_folder + "/cora.content") as fp: 46 | for i, line in enumerate(fp): 47 | info = line.strip().split() 48 | 
feat_data[i, :] = [float(_) for _ in info[1:-1]] 49 | node_map[info[0]] = i 50 | if not info[-1] in label_map: 51 | label_map[info[-1]] = len(label_map) 52 | labels[i] = label_map[info[-1]] 53 | 54 | adj_lists = defaultdict(set) 55 | with open(data_folder + "/cora.cites") as fp: 56 | for i, line in enumerate(fp): 57 | info = line.strip().split() 58 | n1 = node_map[info[0]] 59 | n2 = node_map[info[1]] 60 | adj_lists[n1].add(n2) 61 | adj_lists[n2].add(n1) 62 | 63 | adj_mat = np.zeros((num_nodes, num_nodes)) 64 | for u in adj_lists: 65 | for v in adj_lists[u]: 66 | adj_mat[u, v] = 1 67 | 68 | feat_data = torch.FloatTensor(feat_data).to(device) 69 | adj_mat = torch.FloatTensor(adj_mat).to(device) 70 | return feat_data, adj_mat, labels 71 | 72 | 73 | def get_hops_to_target(target_idx, adj_mat, n_hops): 74 | # Create a map from node to the number of hops from the target test index 75 | node_to_hop = {target_idx: 0} 76 | seen_points = {target_idx} 77 | for j in range(1, n_hops + 2): 78 | adj_cum = copy.deepcopy(adj_mat) 79 | for i in range(j - 1): 80 | adj_cum = torch.matmul(adj_cum, adj_mat) 81 | collect = {i for i, v in enumerate(adj_cum[target_idx]) if v != 0} 82 | ex_collect = collect - seen_points 83 | seen_points |= collect 84 | for e in ex_collect: 85 | node_to_hop[e] = j 86 | 87 | return node_to_hop 88 | -------------------------------------------------------------------------------- /baselines/mahe_madex/madex/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | import requests 3 | from PIL import Image 4 | from skimage.segmentation import mark_boundaries 5 | import numpy as np 6 | import matplotlib 7 | import matplotlib.pyplot as plt 8 | from matplotlib.gridspec import GridSpec 9 | 10 | matplotlib.rcParams["mathtext.fontset"] = "cm" 11 | matplotlib.rcParams["font.family"] = "STIXGeneral" 12 | 13 | 14 | # image pre-processing needed for ResNet 15 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 16 | 17 | preprocess = transforms.Compose( 18 | [ 19 | transforms.Resize((224, 224)), 20 | transforms.ToTensor(), 21 | normalize, 22 | ] 23 | ) 24 | 25 | 26 | def get_image_and_labels( 27 | image_path, 28 | device, 29 | labels_url="https://s3.amazonaws.com/outcome-blog/imagenet/labels.json", 30 | ): 31 | """ 32 | Loads image instance and labels 33 | 34 | Args: 35 | image_path: path to image instance 36 | labels_url: url to json labels 37 | 38 | Returns: 39 | image, labels 40 | """ 41 | image = Image.open(image_path) 42 | if image.mode != "RGB": 43 | image = image.convert("RGB") 44 | image_tensor = preprocess(image) 45 | image = ( 46 | image_tensor.cpu().numpy().transpose(1, 2, 0) / image_tensor.abs().max().item() 47 | ) 48 | image_tensor = ( 49 | image_tensor.unsqueeze_(0).to(device) / image_tensor.abs().max().item() 50 | ) 51 | labels = { 52 | int(key): value for (key, value) in requests.get(labels_url).json().items() 53 | } 54 | return image, image_tensor, labels 55 | 56 | 57 | def show_segmented_image(image, segments): 58 | plt.imshow(mark_boundaries(image / 2 + 0.5, segments)) 59 | 60 | 61 | def plot_explanations(img_arrays, figsize=0.4, spacing=0.15, savepath=""): 62 | w_spacing = (2 / 3) * spacing 63 | left = 0 64 | ax_arays = [] 65 | fig = plt.figure() 66 | for img_array in img_arrays: 67 | num_imgs = len(img_array) 68 | right = left + figsize * (num_imgs) + (num_imgs - 1) * 0.4 * w_spacing 69 | ax_arays.append( 70 | fig.subplots( 71 | 1, num_imgs, 
gridspec_kw=dict(left=left, right=right, wspace=w_spacing) 72 | ) 73 | ) 74 | left = right + spacing 75 | 76 | for i, ax_array in enumerate(ax_arays): 77 | if hasattr(ax_array, "flat"): 78 | for j, ax in enumerate(ax_array.flat): 79 | img, title = img_arrays[i][j] 80 | ax.imshow(img / 2 + 0.5) 81 | ax.set_title(title, fontsize=55 * figsize) 82 | ax.axis("off") 83 | else: 84 | img, title = img_arrays[i][0] 85 | 86 | ax_array.imshow(img / 2 + 0.5) 87 | ax_array.set_title(title, fontsize=55 * figsize) 88 | ax_array.axis("off") 89 | 90 | if savepath: 91 | plt.savefig(savepath, bbox_inches="tight") 92 | plt.show() 93 | 94 | 95 | def show_explanations( 96 | inter_sets, image, segments, figsize=0.4, spacing=0.15, lime_atts=None, savepath="" 97 | ): 98 | def get_interaction_img(inter): 99 | temp = (np.ones(image.shape, image.dtype) - 0.5) * 1 100 | for n in inter: 101 | temp[segments == n] = image[segments == n].copy() 102 | return temp 103 | 104 | img_arrays = [] 105 | img_arrays.append([(image, "Original image")]) 106 | 107 | ## main effects 108 | if lime_atts is not None: 109 | temp = (np.ones(image.shape, image.dtype) - 0.5) * 1 110 | for n, _ in lime_atts[:5]: 111 | temp[segments == n] = image[segments == n].copy() 112 | img_arrays.append([(temp, "Main effects")]) 113 | 114 | inter_img_arrays = [] 115 | for i, inter_set in enumerate(inter_sets): 116 | inter_img_arrays.append( 117 | ( 118 | get_interaction_img(inter_set), 119 | "Interaction $\mathcal{I}_" + str(i + 1) + "$", 120 | ) 121 | ) 122 | img_arrays.append(inter_img_arrays) 123 | 124 | plot_explanations(img_arrays, figsize, spacing, savepath) 125 | -------------------------------------------------------------------------------- /baselines/mahe_madex/madex/utils/pretrained/dna_cnn.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/baselines/mahe_madex/madex/utils/pretrained/dna_cnn.pt -------------------------------------------------------------------------------- /baselines/mahe_madex/madex/utils/pretrained/gcn_cora.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/baselines/mahe_madex/madex/utils/pretrained/gcn_cora.pt -------------------------------------------------------------------------------- /baselines/mahe_madex/madex/utils/pretrained/model_gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.functional import relu 4 | 5 | 6 | class InductiveGCN(nn.Module): 7 | def __init__(self, dim_inp, dim_hid, dim_out, n_samples, n_hops): 8 | super().__init__() 9 | 10 | self.dim_inp = dim_inp 11 | self.dim_hid = dim_hid 12 | self.dim_out = dim_out 13 | self.n_samples = n_samples 14 | 15 | dim_hiddens = [dim_inp] + [dim_hid] * n_hops 16 | self.layers = [ 17 | nn.Linear(dim_hiddens[i], dim_hiddens[i + 1]) 18 | for i in range(len(dim_hiddens) - 1) 19 | ] 20 | self.final_fc = nn.Linear(dim_hiddens[-1], dim_out) 21 | for layer in self.layers + [self.final_fc]: 22 | nn.init.xavier_normal_(layer.weight) 23 | nn.init.zeros_(layer.bias) 24 | self.layers = nn.ModuleList(self.layers) 25 | 26 | def forward(self, x, adj_mat): 27 | """ 28 | 29 | :param x: (n_nodes, dim_inp) 30 | :param adj_mat: (n_nodes, n_nodes) 31 | :return: (n_nodes, dim_out) 32 | """ 33 | for layer in self.layers: 34 | x = 
torch.matmul(adj_mat, x) 35 | x = relu(layer(x)) 36 | x = torch.matmul(adj_mat, x) 37 | x = self.final_fc(x) 38 | return x 39 | 40 | 41 | def create_model(dim_inp, dim_hid, dim_out, n_samples, n_hops): 42 | return InductiveGCN(dim_inp, dim_hid, dim_out, n_samples, n_hops) 43 | -------------------------------------------------------------------------------- /baselines/mahe_madex/mahe/deps/lime_scores.py: -------------------------------------------------------------------------------- 1 | from utils.lime import lime_base 2 | import numpy as np 3 | from utils.general_utils import * 4 | from sklearn.metrics import mean_squared_error 5 | 6 | 7 | def get_lime_mse( 8 | Xd, 9 | Yd, 10 | max_features=10000, 11 | kernel_width=0.25, 12 | weight_samples=True, 13 | sort=True, 14 | **kwargs 15 | ): 16 | def kernel(d): 17 | return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2)) 18 | 19 | distances = get_sample_distances(Xd)["train"] 20 | if not weight_samples: 21 | distances = np.ones_like(distances) 22 | 23 | lb = lime_base.LimeBase(kernel_fn=kernel) 24 | lb_out = lb.explain_instance_with_data( 25 | Xd["train"], Yd["train"], distances, 0, max_features 26 | ) 27 | easy_model = lb_out[-1] 28 | 29 | weights = lb_out[-3] 30 | all_pred = easy_model.predict(Xd["test"]) 31 | 32 | Wd = get_sample_weights(Xd, enable=weight_samples, **kwargs) 33 | 34 | mse = mean_squared_error(Yd["test"], all_pred, sample_weight=Wd["test"]) 35 | 36 | mse_train = lb_out[2] 37 | # print(mse_train, mse) 38 | # assert(False) 39 | 40 | return mse 41 | -------------------------------------------------------------------------------- /baselines/mahe_madex/mahe/pipeline_mod.py: -------------------------------------------------------------------------------- 1 | from transformers import Pipeline 2 | from typing import Dict, List, Optional, Tuple, Union 3 | from transformers.configuration_utils import PretrainedConfig 4 | from transformers.tokenization_utils import PreTrainedTokenizer 5 | from transformers.modelcard import ModelCard 6 | from transformers.tokenization_auto import AutoTokenizer 7 | from transformers.configuration_auto import ( 8 | ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, 9 | AutoConfig, 10 | ) 11 | import torch 12 | 13 | from transformers.modeling_auto import ( 14 | AutoModel, 15 | AutoModelForSequenceClassification, 16 | AutoModelForQuestionAnswering, 17 | AutoModelForTokenClassification, 18 | AutoModelWithLMHead, 19 | ) 20 | 21 | 22 | class TextClassificationPipeline(Pipeline): 23 | """ 24 | Text classification pipeline using ModelForTextClassification head. 
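
Example (a sketch; the default checkpoint comes from this file's own
SUPPORTED_TASKS registry, and __call__ below intentionally returns the
raw model outputs because the softmax post-processing is commented out):

    nlp = pipeline("sentiment-analysis")
    raw_outputs = nlp("a good movie nevertheless")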
25 | """ 26 | 27 | def __call__(self, *args, **kwargs): 28 | outputs = super().__call__(*args, **kwargs) 29 | # scores = np.exp(outputs) / np.exp(outputs).sum(-1) 30 | return outputs # [{"label": self.model.config.id2label[item.argmax()], "score": item.max()} for item in scores] 31 | 32 | 33 | # Register all the supported task here 34 | SUPPORTED_TASKS = { 35 | "sentiment-analysis": { 36 | "impl": TextClassificationPipeline, 37 | "pt": AutoModelForSequenceClassification, # if is_torch_available() else None, 38 | "default": { 39 | "model": { 40 | "pt": "distilbert-base-uncased-finetuned-sst-2-english", 41 | }, 42 | "config": "distilbert-base-uncased-finetuned-sst-2-english", 43 | "tokenizer": "distilbert-base-uncased", 44 | }, 45 | }, 46 | } 47 | 48 | 49 | def pipeline( 50 | task: str, 51 | model: Optional = None, 52 | config: Optional[Union[str, PretrainedConfig]] = None, 53 | tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, 54 | modelcard: Optional[Union[str, ModelCard]] = None, 55 | device=torch.device("cpu"), 56 | **kwargs 57 | ) -> Pipeline: 58 | """ 59 | Utility factory method to build a pipeline. 60 | Pipeline are made of: 61 | A Tokenizer instance in charge of mapping raw textual input to token 62 | A Model instance 63 | Some (optional) post processing for enhancing model's output 64 | Examples: 65 | pipeline('sentiment-analysis') 66 | pipeline('question-answering', model='distilbert-base-uncased-distilled-squad', tokenizer='bert-base-cased') 67 | pipeline('ner', model=AutoModel.from_pretrained(...), tokenizer=AutoTokenizer.from_pretrained(...) 68 | pipeline('ner', model='dbmdz/bert-large-cased-finetuned-conll03-english', tokenizer='bert-base-cased') 69 | pipeline('ner', model='https://...pytorch-model.bin', config='https://...config.json', tokenizer='bert-base-cased') 70 | """ 71 | # Retrieve the task 72 | if task not in SUPPORTED_TASKS: 73 | raise KeyError( 74 | "Unknown task {}, available tasks are {}".format( 75 | task, list(SUPPORTED_TASKS.keys()) 76 | ) 77 | ) 78 | 79 | framework = "pt" # get_framework(model) 80 | 81 | targeted_task = SUPPORTED_TASKS[task] 82 | task, model_class = targeted_task["impl"], targeted_task[framework] 83 | 84 | # Use default model/config/tokenizer for the task if no model is provided 85 | if model is None: 86 | models, config, tokenizer = tuple(targeted_task["default"].values()) 87 | model = models[framework] 88 | 89 | # Try to infer tokenizer from model or config name (if provided as str) 90 | if tokenizer is None: 91 | if isinstance(model, str) and model in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP: 92 | tokenizer = model 93 | elif isinstance(config, str) and config in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP: 94 | tokenizer = config 95 | else: 96 | # Impossible to guest what is the right tokenizer here 97 | raise Exception( 98 | "Impossible to guess which tokenizer to use. " 99 | "Please provided a PretrainedTokenizer class or a path/url/shortcut name to a pretrained tokenizer." 
--------------------------------------------------------------------------------
/baselines/scd_soc/demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "%load_ext autoreload\n",
10 |     "%autoreload 2"
11 |    ]
12 |   },
13 |   {
14 |    "cell_type": "code",
15 |    "execution_count": 2,
16 |    "metadata": {},
17 |    "outputs": [],
18 |    "source": [
19 |     "import sys\n",
20 |     "import torch\n",
21 |     "import numpy as np\n",
22 |     "import random\n",
23 |     "# from bert.run_classifier import BertTokenizer\n",
24 |     "from transformers import BertTokenizer\n",
25 |     "\n",
26 |     "sys.path.append(\"./hiexpl\")\n",
27 |     "from helper import *\n",
28 |     "\n",
29 |     "import logging\n",
30 |     "logger = logging.getLogger()\n",
31 |     "logger.setLevel(logging.CRITICAL)\n",
32 |     "\n",
33 |     "device = torch.device(\"cuda:0\")"
34 |    ]
35 |   },
36 |   {
37 |    "cell_type": "markdown",
38 |    "metadata": {},
39 |    "source": [
40 |     "## Get Model"
41 |    ]
42 |   },
43 |   {
44 |    "cell_type": "code",
45 |    "execution_count": 3,
46 |    "metadata": {},
47 |    "outputs": [],
48 |    "source": [
49 |     "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')#, do_lower_case=True, cache_dir='bert/cache')\n",
50 |     "bert_path = \"../../downloads/pretrained_bert\"\n",
51 |     "model = get_bert(bert_path, device)"
52 |    ]
53 |   },
54 |   {
55 |    "cell_type": "markdown",
56 |    "metadata": {},
57 |    "source": [
58 |     "## Get Data"
59 |    ]
60 |   },
61 |   {
62 |    "cell_type": "code",
63 |    "execution_count": 4,
64 |    "metadata": {},
65 |    "outputs": [],
66 |    "source": [
67 |     "sentence = \"though a bit of a patch ##work in script and production , a glossy , rich green , environment almost makes the picture work .\""
68 |    ]
69 |   },
70 |   {
71 |    "cell_type": "markdown",
72 |    "metadata": {},
73 |    "source": [
74 |     "## Get Prediction"
75 |    ]
76 |   },
77 |   {
78 |    "cell_type": "code",
79 |    "execution_count": 5,
80 |    "metadata": {},
81 |    "outputs": [
82 |     {
83 |      "name": "stdout",
84 |      "output_type": "stream",
85 |      "text": [
86 |       "tensor([[-1.9403, 2.7804]], device='cuda:0', grad_fn=<AddmmBackward>)\n"
87 |      ]
88 |     }
89 |    ],
90 |    "source": [
91 |     "preds = get_prediction(model, [sentence], tokenizer, device)\n",
92 |     "print(preds)"
93 |    ]
94 |   },
95 |   {
96 |    "cell_type": "markdown",
97 |    "metadata": {},
98 |    "source": [
99 |     "## Get Explanations"
100 |    ]
101 |   },
102 |   {
103 |    "cell_type": "code",
104 |    "execution_count": 6,
105 |    "metadata": {},
106 |    "outputs": [],
107 |    "source": [
108 |     "method = \"soc\"\n",
109 |     "# method = \"scd\"\n",
110 |     "# method = \"cd\"  # doesn't work at the moment"
111 |    ]
112 |   },
113 |   {
114 |    "cell_type": "code",
115 |    "execution_count": 7,
116 |    "metadata": {},
117 |    "outputs": [
118 |     {
119 |      "name": "stdout",
120 |      "output_type": "stream",
121 |      "text": [
122 |       "loading vocab from vocab/vocab_sst.pkl\n"
123 |      ]
124 |     }
125 |    ],
126 |    "source": [
127 |     "lm_path = \"../../downloads/pretrained_hiexpl_lm/best_snapshot_devloss_11.708949835404105_iter_2000_model.pt\"\n",
128 |     "algo = get_hiexpl(method, model, lm_path, tokenizer, device, sample_num=20)"
129 |    ]
130 |   },
131 |   {
132 |    "cell_type": "code",
133 |    "execution_count": 8,
134 |    "metadata": {},
135 |    "outputs": [
136 |     {
137 |      "name": "stdout",
138 |      "output_type": "stream",
139 |      "text": [
140 |       "tokens\n",
141 |       " ['though', 'a', 'bit', 'of', 'a', 'patch', '#', '#', 'work', 'in', 'script', 'and', 'production', ',', 'a', 'glossy', ',', 'rich', 'green', ',', 'environment', 'almost', 'makes', 'the', 'picture', 'work', '.']\n",
142 |       "\n",
143 |       "scores\n",
144 |       " {(1, 1): 0.146915465593338, (2, 2): 0.3505658805370331, (3, 3): 0.11184519529342651, (4, 4): 0.23161092400550842, (5, 5): 0.38228902220726013, (6, 6): 0.08281208574771881, (7, 7): 0.24915504455566406, (8, 8): 0.10236231982707977, (9, 9): 0.2930784225463867, (10, 10): 0.41771554946899414, (11, 11): 0.17606183886528015, (12, 12): 0.46679267287254333, (13, 13): -0.0669143944978714, (14, 14): -0.08309005200862885, (15, 15): 0.2622664272785187, (16, 16): -0.14915215969085693, (17, 17): 0.0672696977853775, (18, 18): 1.3571451902389526, (19, 19): 0.06354774534702301, (20, 20): 0.17312249541282654, (21, 21): 0.6043580770492554, (22, 22): 0.4030352234840393, (23, 23): 0.9028497934341431, (24, 24): -0.015416771173477173, (25, 25): 0.4994756579399109, (26, 26): 0.9656969904899597, (27, 27): 0.2780976891517639}\n"
145 |      ]
146 |     }
147 |    ],
148 |    "source": [
149 |     "scores, tokens = explain_sentence(sentence, algo, tokenizer)\n",
150 |     "print(\"tokens\\n\",tokens)\n",
151 |     "print(\"\\nscores\\n\",scores)"
152 |    ]
153 |   }
154 |  ],
155 |  "metadata": {
156 |   "kernelspec": {
157 |    "display_name": "Python [conda env:test]",
158 |    "language": "python",
159 |    "name": "conda-env-test-py"
160 |   },
161 |   "language_info": {
162 |    "codemirror_mode": {
163 |     "name": "ipython",
164 |     "version": 3
165 |    },
166 |    "file_extension": ".py",
167 |    "mimetype": "text/x-python",
168 |    "name": "python",
169 |    "nbconvert_exporter": "python",
170 |    "pygments_lexer": "ipython3",
171 |    "version": "3.6.7"
172 |   }
173 |  },
174 |  "nbformat": 4,
175 |  "nbformat_minor": 2
176 | }
--------------------------------------------------------------------------------
/baselines/scd_soc/helper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from bert.run_classifier import BertConfig, BertForSequenceClassification
5 | from algo.soc_transformer import SOCForTransformer
6 | from algo.scd_transformer import CDForTransformer, SCDForTransformer
7 | from utils.reader import (
8 |     get_data_iterators_sst_flatten,
9 |     get_data_iterators_yelp,
10 |     get_data_iterators_tacred,
11 | )
12 | import random
13 | import sys
14 | sys.path.append("../../src")
15 | from application_utils.text_utils import prepare_huggingface_data
16 | 
17 | 
18 | def set_seed(seed=0):
19 |     random.seed(seed)
20 |     np.random.seed(seed)
21 |     torch.manual_seed(seed)
22 |     if torch.cuda.is_available():
23 |         torch.cuda.manual_seed(seed)
24 | 
25 | 
26 | def get_bert(bert_path, device):
27 |     CONFIG_NAME = "config.json"
28 |     WEIGHTS_NAME = "pytorch_model.bin"
29 | 
30 |     output_model_file = os.path.join(bert_path, WEIGHTS_NAME)
31 |     output_config_file = os.path.join(bert_path, CONFIG_NAME)
32 | 
33 |     config = BertConfig(output_config_file)
34 |     model = BertForSequenceClassification(config, num_labels=2)
35 |     model.load_state_dict(torch.load(output_model_file))
36 |     model.eval()
37 | 
38 |     if device.type == "cuda":  # device.index is None for plain CPU devices
39 |         model = model.to(device)
40 |     return model
41 | 
42 | 
43 | def get_lm_model(lm_path, gpu):
44 |     lm_model = torch.load(
45 |         lm_path, map_location=lambda storage, location: storage.cuda(gpu)
46 |     )
47 |     lm_model.gpu = gpu
48 |     lm_model.encoder.gpu = gpu
49 |     return lm_model
50 | 
51 | 
52 | def get_hiexpl(
53 |     method, bert_model, lm_path, tokenizer, device, sample_num=20, lm_model=None
54 | ):
55 |     assert method in {"soc", "scd", "cd"}
56 | 
57 |     gpu = device.index if device.index is not None else -1  # -1 means CPU below
58 | 
59 |     class args:
60 |         pass
61 | 
62 |     args.gpu = gpu
63 |     args.task = "sst"
64 |     args.dataset = "test"
65 |     args.lm_path = lm_path
66 |     args.nb_method = "ngram"
67 |     args.d_out = 2
68 |     # args.n_cells = args.n_layers
69 |     args.use_gpu = args.gpu >= 0
70 |     args.method = method
71 |     args.nb_range = 10
72 |     args.start = 0
73 |     args.stop = 10
74 |     args.batch_size = 1
75 |     args.sample_n = sample_num
76 |     args.use_bert_lm = True
77 | 
78 |     (
79 |         text_field,
80 |         length_field,
81 |         train_iter,
82 |         dev_iter,
83 |         test_iter,
84 |         train,
85 |         dev,
86 |     ) = get_data_iterators_sst_flatten(map_cpu=False)
87 | 
88 |     iter_map = {"train": train_iter, "dev": dev_iter, "test": test_iter}
89 |     if args.task == "sst":
90 |         tree_path = ".data/sst/trees/%s.txt"
91 |     else:
92 |         raise ValueError
93 | 
94 |     args.n_embed = len(text_field.vocab)
95 | 
96 |     if args.method == "soc":
97 |         if lm_model is None:
98 |             lm_model = get_lm_model(args.lm_path, args.gpu)
99 |         algo = SOCForTransformer(
100 |             bert_model,
101 |             lm_model,
102 |             tree_path=tree_path % args.dataset,
103 |             output_path=None,
104 |             # output_path='outputs/' + args.task + '/soc_bert_results/soc%s.txt' % args.exp_name,
105 |             config=args,
106 |             vocab=text_field.vocab,
107 |             tokenizer=tokenizer,
108 |         )
109 |     elif args.method == "scd":
110 |         if lm_model is None:
111 |             lm_model = get_lm_model(args.lm_path, args.gpu)
112 |         algo = SCDForTransformer(
113 |             bert_model,
114 |             lm_model,
115 |             tree_path=tree_path % args.dataset,
116 |             # output_path='outputs/' + args.task + '/scd_bert_results/scd%s.txt' % args.exp_name,
117 |             output_path=None,
118 |             config=args,
119 |             vocab=text_field.vocab,
120 |             tokenizer=tokenizer,
121 |         )
122 | 
123 |     elif args.method == "cd":
124 |         algo = CDForTransformer(
125 |             bert_model,
126 |             tree_path=tree_path % args.dataset,
127 |             # output_path='outputs/' + args.task + '/scd_bert_results/scd%s.txt' % args.exp_name,
128 |             output_path=None,
129 |             config=args,
130 |             tokenizer=tokenizer,
131 |         )
132 |     else:
133 |         raise ValueError
134 | 
135 |     return algo
136 | 
137 | 
138 | def get_prediction(model, sentences, tokenizer, device):
139 |     X = prepare_huggingface_data(sentences, tokenizer)
140 |     for key in X:
141 |         X[key] = torch.from_numpy(X[key]).to(device)
142 |     preds = model(
143 |         X["input_ids"], X["token_type_ids"].long(), X["attention_mask"].long()
144 |     )
145 |     return preds
146 | 
147 | 
148 | def explain_sentence(sentence, algo, tokenizer, spans=None):
149 |     X = prepare_huggingface_data([sentence], tokenizer)
150 | 
151 |     for key in X:
152 |         X[key] = torch.from_numpy(X[key]).to(torch.device("cuda:" + str(algo.gpu)))
153 | 
154 |     inp = X["input_ids"]
155 |     segment_ids = X["token_type_ids"].long()
156 |     input_mask = X["attention_mask"].long()
157 | 
158 |     if spans is None:
159 |         spans = [(x, x) for x in range(0, inp.shape[1] - 2)]
160 | 
161 |     contribs = {}
162 |     for span in spans:
163 |         # shift by one to skip the [CLS] token
164 |         span = (span[0] + 1, span[1] + 1)
165 |         contrib = algo.explain_single_transformer(inp, input_mask, segment_ids, span)
166 |         contribs[span] = contrib
167 | 
168 |     tokens = tokenizer.convert_ids_to_tokens(inp.view(-1).cpu().numpy())
169 | 
170 |     return contribs, tokens[1:-1]
171 | 
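172 | # Illustrative end-to-end usage (a sketch; the checkpoint paths follow the
173 | # demo notebook and are assumptions, as is the pretrained `tokenizer`):
174 | #
175 | #   device = torch.device("cuda:0")
176 | #   model = get_bert("../../downloads/pretrained_bert", device)
177 | #   algo = get_hiexpl("soc", model, lm_path, tokenizer, device, sample_num=20)
178 | #   scores, tokens = explain_sentence(sentence, algo, tokenizer)
179 | #
180 | # `scores` maps (start, end) token spans, shifted by one for [CLS], to
181 | # contribution scores; `tokens` is the wordpiece list without [CLS]/[SEP].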
--------------------------------------------------------------------------------
/baselines/scd_soc/hiexpl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/baselines/scd_soc/hiexpl/__init__.py
--------------------------------------------------------------------------------
/baselines/scd_soc/hiexpl/algo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/baselines/scd_soc/hiexpl/algo/__init__.py
--------------------------------------------------------------------------------
/baselines/scd_soc/hiexpl/bert/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.6.1"
2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
3 | from .tokenization_openai import OpenAIGPTTokenizer
4 | from .tokenization_transfo_xl import TransfoXLTokenizer, TransfoXLCorpus
5 | from .tokenization_gpt2 import GPT2Tokenizer
6 | 
7 | from .modeling import (
8 |     BertConfig,
9 |     BertModel,
10 |     BertForPreTraining,
11 |     BertForMaskedLM,
12 |     BertForNextSentencePrediction,
13 |     BertForSequenceClassification,
14 |     BertForMultipleChoice,
15 |     BertForTokenClassification,
16 |     BertForQuestionAnswering,
17 |     load_tf_weights_in_bert,
18 | )
19 | from .modeling_openai import (
20 |     OpenAIGPTConfig,
21 |     OpenAIGPTModel,
22 |     OpenAIGPTLMHeadModel,
23 |     OpenAIGPTDoubleHeadsModel,
24 |     load_tf_weights_in_openai_gpt,
25 | )
26 | from .modeling_transfo_xl import (
27 |     TransfoXLConfig,
28 |     TransfoXLModel,
29 |     TransfoXLLMHeadModel,
30 |     load_tf_weights_in_transfo_xl,
31 | )
32 | from .modeling_gpt2 import (
33 |     GPT2Config,
34 |     GPT2Model,
35 |     GPT2LMHeadModel,
36 |     GPT2DoubleHeadsModel,
37 |     load_tf_weights_in_gpt2,
38 | )
39 | 
40 | from .optimization import BertAdam
41 | from .optimization_openai import OpenAIAdam
42 | 
43 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path
44 | 
--------------------------------------------------------------------------------
/baselines/scd_soc/hiexpl/bert/__main__.py:
--------------------------------------------------------------------------------
1 | # coding: utf8
2 | def main():
3 |     import sys
4 | 
5 |     if (len(sys.argv) != 4 and len(sys.argv) 
!= 5) or sys.argv[1] not in [ 6 | "convert_tf_checkpoint_to_pytorch", 7 | "convert_openai_checkpoint", 8 | "convert_transfo_xl_checkpoint", 9 | "convert_gpt2_checkpoint", 10 | ]: 11 | print( 12 | "Should be used as one of: \n" 13 | ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n" 14 | ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n" 15 | ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n" 16 | ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`" 17 | ) 18 | else: 19 | if sys.argv[1] == "convert_tf_checkpoint_to_pytorch": 20 | try: 21 | from .convert_tf_checkpoint_to_pytorch import ( 22 | convert_tf_checkpoint_to_pytorch, 23 | ) 24 | except ImportError: 25 | print( 26 | "pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 27 | "In that case, it requires TensorFlow to be installed. Please see " 28 | "https://www.tensorflow.org/install/ for installation instructions." 29 | ) 30 | raise 31 | 32 | if len(sys.argv) != 5: 33 | # pylint: disable=line-too-long 34 | print( 35 | "Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`" 36 | ) 37 | else: 38 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 39 | TF_CONFIG = sys.argv.pop() 40 | TF_CHECKPOINT = sys.argv.pop() 41 | convert_tf_checkpoint_to_pytorch( 42 | TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT 43 | ) 44 | elif sys.argv[1] == "convert_openai_checkpoint": 45 | from .convert_openai_checkpoint_to_pytorch import ( 46 | convert_openai_checkpoint_to_pytorch, 47 | ) 48 | 49 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2] 50 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 51 | if len(sys.argv) == 5: 52 | OPENAI_GPT_CONFIG = sys.argv[4] 53 | else: 54 | OPENAI_GPT_CONFIG = "" 55 | convert_openai_checkpoint_to_pytorch( 56 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH, 57 | OPENAI_GPT_CONFIG, 58 | PYTORCH_DUMP_OUTPUT, 59 | ) 60 | elif sys.argv[1] == "convert_transfo_xl_checkpoint": 61 | try: 62 | from .convert_transfo_xl_checkpoint_to_pytorch import ( 63 | convert_transfo_xl_checkpoint_to_pytorch, 64 | ) 65 | except ImportError: 66 | print( 67 | "pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 68 | "In that case, it requires TensorFlow to be installed. Please see " 69 | "https://www.tensorflow.org/install/ for installation instructions." 70 | ) 71 | raise 72 | 73 | if "ckpt" in sys.argv[2].lower(): 74 | TF_CHECKPOINT = sys.argv[2] 75 | TF_DATASET_FILE = "" 76 | else: 77 | TF_DATASET_FILE = sys.argv[2] 78 | TF_CHECKPOINT = "" 79 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 80 | if len(sys.argv) == 5: 81 | TF_CONFIG = sys.argv[4] 82 | else: 83 | TF_CONFIG = "" 84 | convert_transfo_xl_checkpoint_to_pytorch( 85 | TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE 86 | ) 87 | else: 88 | try: 89 | from .convert_gpt2_checkpoint_to_pytorch import ( 90 | convert_gpt2_checkpoint_to_pytorch, 91 | ) 92 | except ImportError: 93 | print( 94 | "pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 95 | "In that case, it requires TensorFlow to be installed. Please see " 96 | "https://www.tensorflow.org/install/ for installation instructions." 
97 | ) 98 | raise 99 | 100 | TF_CHECKPOINT = sys.argv[2] 101 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 102 | if len(sys.argv) == 5: 103 | TF_CONFIG = sys.argv[4] 104 | else: 105 | TF_CONFIG = "" 106 | convert_gpt2_checkpoint_to_pytorch( 107 | TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT 108 | ) 109 | 110 | 111 | if __name__ == "__main__": 112 | main() 113 | -------------------------------------------------------------------------------- /baselines/scd_soc/hiexpl/bert/convert_gpt2_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_gpt2 import ( 25 | CONFIG_NAME, 26 | WEIGHTS_NAME, 27 | GPT2Config, 28 | GPT2Model, 29 | load_tf_weights_in_gpt2, 30 | ) 31 | 32 | 33 | def convert_gpt2_checkpoint_to_pytorch( 34 | gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path 35 | ): 36 | # Construct model 37 | if gpt2_config_file == "": 38 | config = GPT2Config() 39 | else: 40 | config = GPT2Config(gpt2_config_file) 41 | model = GPT2Model(config) 42 | 43 | # Load weights from numpy 44 | load_tf_weights_in_gpt2(model, gpt2_checkpoint_path) 45 | 46 | # Save pytorch-model 47 | pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME 48 | pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME 49 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 50 | torch.save(model.state_dict(), pytorch_weights_dump_path) 51 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 52 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 53 | f.write(config.to_json_string()) 54 | 55 | 56 | if __name__ == "__main__": 57 | parser = argparse.ArgumentParser() 58 | ## Required parameters 59 | parser.add_argument( 60 | "--gpt2_checkpoint_path", 61 | default=None, 62 | type=str, 63 | required=True, 64 | help="Path the TensorFlow checkpoint path.", 65 | ) 66 | parser.add_argument( 67 | "--pytorch_dump_folder_path", 68 | default=None, 69 | type=str, 70 | required=True, 71 | help="Path to the output PyTorch model.", 72 | ) 73 | parser.add_argument( 74 | "--gpt2_config_file", 75 | default="", 76 | type=str, 77 | help="An optional config json file corresponding to the pre-trained OpenAI model. 
\n" 78 | "This specifies the model architecture.", 79 | ) 80 | args = parser.parse_args() 81 | convert_gpt2_checkpoint_to_pytorch( 82 | args.gpt2_checkpoint_path, args.gpt2_config_file, args.pytorch_dump_folder_path 83 | ) 84 | -------------------------------------------------------------------------------- /baselines/scd_soc/hiexpl/bert/convert_openai_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_openai import ( 25 | CONFIG_NAME, 26 | WEIGHTS_NAME, 27 | OpenAIGPTConfig, 28 | OpenAIGPTModel, 29 | load_tf_weights_in_openai_gpt, 30 | ) 31 | 32 | 33 | def convert_openai_checkpoint_to_pytorch( 34 | openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path 35 | ): 36 | # Construct model 37 | if openai_config_file == "": 38 | config = OpenAIGPTConfig() 39 | else: 40 | config = OpenAIGPTConfig(openai_config_file) 41 | model = OpenAIGPTModel(config) 42 | 43 | # Load weights from numpy 44 | load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path) 45 | 46 | # Save pytorch-model 47 | pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME 48 | pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME 49 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 50 | torch.save(model.state_dict(), pytorch_weights_dump_path) 51 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 52 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 53 | f.write(config.to_json_string()) 54 | 55 | 56 | if __name__ == "__main__": 57 | parser = argparse.ArgumentParser() 58 | ## Required parameters 59 | parser.add_argument( 60 | "--openai_checkpoint_folder_path", 61 | default=None, 62 | type=str, 63 | required=True, 64 | help="Path the TensorFlow checkpoint path.", 65 | ) 66 | parser.add_argument( 67 | "--pytorch_dump_folder_path", 68 | default=None, 69 | type=str, 70 | required=True, 71 | help="Path to the output PyTorch model.", 72 | ) 73 | parser.add_argument( 74 | "--openai_config_file", 75 | default="", 76 | type=str, 77 | help="An optional config json file corresponding to the pre-trained OpenAI model. 
\n" 78 | "This specifies the model architecture.", 79 | ) 80 | args = parser.parse_args() 81 | convert_openai_checkpoint_to_pytorch( 82 | args.openai_checkpoint_folder_path, 83 | args.openai_config_file, 84 | args.pytorch_dump_folder_path, 85 | ) 86 | -------------------------------------------------------------------------------- /baselines/scd_soc/hiexpl/bert/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import re 23 | import argparse 24 | import tensorflow as tf 25 | import torch 26 | import numpy as np 27 | 28 | from pytorch_pretrained_bert.modeling import ( 29 | BertConfig, 30 | BertForPreTraining, 31 | load_tf_weights_in_bert, 32 | ) 33 | 34 | 35 | def convert_tf_checkpoint_to_pytorch( 36 | tf_checkpoint_path, bert_config_file, pytorch_dump_path 37 | ): 38 | # Initialise PyTorch model 39 | config = BertConfig.from_json_file(bert_config_file) 40 | print("Building PyTorch model from configuration: {}".format(str(config))) 41 | model = BertForPreTraining(config) 42 | 43 | # Load weights from tf checkpoint 44 | load_tf_weights_in_bert(model, tf_checkpoint_path) 45 | 46 | # Save pytorch-model 47 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 48 | torch.save(model.state_dict(), pytorch_dump_path) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument( 55 | "--tf_checkpoint_path", 56 | default=None, 57 | type=str, 58 | required=True, 59 | help="Path the TensorFlow checkpoint path.", 60 | ) 61 | parser.add_argument( 62 | "--bert_config_file", 63 | default=None, 64 | type=str, 65 | required=True, 66 | help="The config json file corresponding to the pre-trained BERT model. \n" 67 | "This specifies the model architecture.", 68 | ) 69 | parser.add_argument( 70 | "--pytorch_dump_path", 71 | default=None, 72 | type=str, 73 | required=True, 74 | help="Path to the output PyTorch model.", 75 | ) 76 | args = parser.parse_args() 77 | convert_tf_checkpoint_to_pytorch( 78 | args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path 79 | ) 80 | -------------------------------------------------------------------------------- /baselines/scd_soc/hiexpl/bert/convert_transfo_xl_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert Transformer XL checkpoint and datasets.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import os 21 | import sys 22 | from io import open 23 | 24 | import torch 25 | 26 | import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils 27 | from pytorch_pretrained_bert.modeling_transfo_xl import ( 28 | CONFIG_NAME, 29 | WEIGHTS_NAME, 30 | TransfoXLConfig, 31 | TransfoXLLMHeadModel, 32 | load_tf_weights_in_transfo_xl, 33 | ) 34 | from pytorch_pretrained_bert.tokenization_transfo_xl import CORPUS_NAME, VOCAB_NAME 35 | 36 | if sys.version_info[0] == 2: 37 | import cPickle as pickle 38 | else: 39 | import pickle 40 | 41 | # We do this to be able to load python 2 datasets pickles 42 | # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 43 | data_utils.Vocab = data_utils.TransfoXLTokenizer 44 | data_utils.Corpus = data_utils.TransfoXLCorpus 45 | sys.modules["data_utils"] = data_utils 46 | sys.modules["vocabulary"] = data_utils 47 | 48 | 49 | def convert_transfo_xl_checkpoint_to_pytorch( 50 | tf_checkpoint_path, 51 | transfo_xl_config_file, 52 | pytorch_dump_folder_path, 53 | transfo_xl_dataset_file, 54 | ): 55 | if transfo_xl_dataset_file: 56 | # Convert a pre-processed corpus (see original TensorFlow repo) 57 | with open(transfo_xl_dataset_file, "rb") as fp: 58 | corpus = pickle.load(fp, encoding="latin1") 59 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) 60 | pytorch_vocab_dump_path = pytorch_dump_folder_path + "/" + VOCAB_NAME 61 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) 62 | corpus_vocab_dict = corpus.vocab.__dict__ 63 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) 64 | 65 | corpus_dict_no_vocab = corpus.__dict__ 66 | corpus_dict_no_vocab.pop("vocab", None) 67 | pytorch_dataset_dump_path = pytorch_dump_folder_path + "/" + CORPUS_NAME 68 | print("Save dataset to {}".format(pytorch_dataset_dump_path)) 69 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) 70 | 71 | if tf_checkpoint_path: 72 | # Convert a pre-trained TensorFlow model 73 | config_path = os.path.abspath(transfo_xl_config_file) 74 | tf_path = os.path.abspath(tf_checkpoint_path) 75 | 76 | print( 77 | "Converting Transformer XL checkpoint from {} with config at {}".format( 78 | tf_path, config_path 79 | ) 80 | ) 81 | # Initialise PyTorch model 82 | if transfo_xl_config_file == "": 83 | config = TransfoXLConfig() 84 | else: 85 | config = TransfoXLConfig(transfo_xl_config_file) 86 | print("Building PyTorch model from configuration: {}".format(str(config))) 87 | model = TransfoXLLMHeadModel(config) 88 | 89 | model = load_tf_weights_in_transfo_xl(model, config, tf_path) 90 | # Save pytorch-model 91 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 92 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 93 | print( 94 | "Save PyTorch model to {}".format( 95 | 
os.path.abspath(pytorch_weights_dump_path) 96 | ) 97 | ) 98 | torch.save(model.state_dict(), pytorch_weights_dump_path) 99 | print( 100 | "Save configuration file to {}".format( 101 | os.path.abspath(pytorch_config_dump_path) 102 | ) 103 | ) 104 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 105 | f.write(config.to_json_string()) 106 | 107 | 108 | if __name__ == "__main__": 109 | parser = argparse.ArgumentParser() 110 | parser.add_argument( 111 | "--pytorch_dump_folder_path", 112 | default=None, 113 | type=str, 114 | required=True, 115 | help="Path to the folder to store the PyTorch model or dataset/vocab.", 116 | ) 117 | parser.add_argument( 118 | "--tf_checkpoint_path", 119 | default="", 120 | type=str, 121 | help="An optional path to a TensorFlow checkpoint path to be converted.", 122 | ) 123 | parser.add_argument( 124 | "--transfo_xl_config_file", 125 | default="", 126 | type=str, 127 | help="An optional config json file corresponding to the pre-trained BERT model. \n" 128 | "This specifies the model architecture.", 129 | ) 130 | parser.add_argument( 131 | "--transfo_xl_dataset_file", 132 | default="", 133 | type=str, 134 | help="An optional dataset file to be converted in a vocabulary.", 135 | ) 136 | args = parser.parse_args() 137 | convert_transfo_xl_checkpoint_to_pytorch( 138 | args.tf_checkpoint_path, 139 | args.transfo_xl_config_file, 140 | args.pytorch_dump_folder_path, 141 | args.transfo_xl_dataset_file, 142 | ) 143 | -------------------------------------------------------------------------------- /baselines/scd_soc/hiexpl/bert/decomp_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | -------------------------------------------------------------------------------- /baselines/scd_soc/hiexpl/bert/filter_sentence.py: -------------------------------------------------------------------------------- 1 | all_sent_file = "glue_data/SST-2/original/datasetSentences.txt" 2 | train_file = "glue_data/SST-2/train_all.tsv" 3 | 4 | f1 = open(all_sent_file) 5 | f2 = open(train_file) 6 | fw = open("glue_data/SST-2/train.tsv", "w") 7 | 8 | lines1, lines2 = f1.readlines(), f2.readlines() 9 | fw.write(lines1[0]) 10 | 11 | hash_set = set() 12 | for line in lines1: 13 | sent = line.split("\t")[-1] 14 | hash_set.add(sent.strip().lower()) 15 | 16 | for line in lines2: 17 | sent = line.split("\t")[0] 18 | if sent.strip() in hash_set: 19 | fw.write(line) 20 | -------------------------------------------------------------------------------- /baselines/scd_soc/hiexpl/bert/global_state.py: -------------------------------------------------------------------------------- 1 | class _GlobalStateDict: 2 | def __init__(self): 3 | self.states = [] 4 | self.current_layer_id = 0 5 | self.store_flag = True 6 | self.activated = False 7 | 8 | self.rel_span_len = 0 9 | self.total_span_len = 1 10 | 11 | def store_state(self, value): 12 | if self.activated: 13 | self.states.append(value) 14 | self.current_layer_id += 1 15 | 16 | def get_states(self): 17 | states = self.states[self.current_layer_id] # [B * H] 18 | states = states.split(1, 0) 19 | self.current_layer_id += 1 20 | return states 21 | 22 | def init_store_states(self): 23 | self.activated = True 24 | self.states = [] 25 | self.current_layer_id = 0 26 | self.store_flag = True 27 | 28 | def init_fetch_states(self): 29 | self.current_layer_id = 0 30 | self.store_flag = False 31 | 32 | 33 | global_state_dict = _GlobalStateDict() 34 | 
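35 | # Illustrative protocol (a sketch of how decomposition code is expected to
36 | # drive this singleton; the method names are the ones defined above):
37 | #
38 | #   global_state_dict.init_store_states()  # pass 1: record per-layer states
39 | #   ...run the model; each layer calls global_state_dict.store_state(h)...
40 | #   global_state_dict.init_fetch_states()  # pass 2: replay recorded states
41 | #   ...run again; each layer calls global_state_dict.get_states()...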
-------------------------------------------------------------------------------- /baselines/scd_soc/hiexpl/bert/tacred_f1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The Board of Trustees of The Leland Stanford Junior University 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from collections import Counter 16 | import sys 17 | 18 | NO_RELATION = 0 19 | 20 | 21 | def score(key, prediction, verbose=False): 22 | correct_by_relation = Counter() 23 | guessed_by_relation = Counter() 24 | gold_by_relation = Counter() 25 | 26 | # Loop over the data to compute a score 27 | for row in range(len(key)): 28 | gold = key[row] 29 | guess = prediction[row] 30 | 31 | if gold == NO_RELATION and guess == NO_RELATION: 32 | pass 33 | elif gold == NO_RELATION and guess != NO_RELATION: 34 | guessed_by_relation[guess] += 1 35 | elif gold != NO_RELATION and guess == NO_RELATION: 36 | gold_by_relation[gold] += 1 37 | elif gold != NO_RELATION and guess != NO_RELATION: 38 | guessed_by_relation[guess] += 1 39 | gold_by_relation[gold] += 1 40 | if gold == guess: 41 | correct_by_relation[guess] += 1 42 | 43 | # Print verbose information 44 | if verbose: 45 | print("Per-relation statistics:") 46 | relations = gold_by_relation.keys() 47 | longest_relation = 0 48 | for relation in sorted(relations): 49 | longest_relation = max(len(relation), longest_relation) 50 | for relation in sorted(relations): 51 | # (compute the score) 52 | correct = correct_by_relation[relation] 53 | guessed = guessed_by_relation[relation] 54 | gold = gold_by_relation[relation] 55 | prec = 1.0 56 | if guessed > 0: 57 | prec = float(correct) / float(guessed) 58 | recall = 0.0 59 | if gold > 0: 60 | recall = float(correct) / float(gold) 61 | f1 = 0.0 62 | if prec + recall > 0: 63 | f1 = 2.0 * prec * recall / (prec + recall) 64 | # (print the score) 65 | sys.stdout.write(("{:<" + str(longest_relation) + "}").format(relation)) 66 | sys.stdout.write(" P: ") 67 | if prec < 0.1: 68 | sys.stdout.write(" ") 69 | if prec < 1.0: 70 | sys.stdout.write(" ") 71 | sys.stdout.write("{:.2%}".format(prec)) 72 | sys.stdout.write(" R: ") 73 | if recall < 0.1: 74 | sys.stdout.write(" ") 75 | if recall < 1.0: 76 | sys.stdout.write(" ") 77 | sys.stdout.write("{:.2%}".format(recall)) 78 | sys.stdout.write(" F1: ") 79 | if f1 < 0.1: 80 | sys.stdout.write(" ") 81 | if f1 < 1.0: 82 | sys.stdout.write(" ") 83 | sys.stdout.write("{:.2%}".format(f1)) 84 | sys.stdout.write(" #: %d" % gold) 85 | sys.stdout.write("\n") 86 | print("") 87 | 88 | # Print the aggregate score 89 | if verbose: 90 | print("Final Score:") 91 | prec_micro = 1.0 92 | if sum(guessed_by_relation.values()) > 0: 93 | prec_micro = float(sum(correct_by_relation.values())) / float( 94 | sum(guessed_by_relation.values()) 95 | ) 96 | recall_micro = 0.0 97 | if sum(gold_by_relation.values()) > 0: 98 | recall_micro = float(sum(correct_by_relation.values())) / float( 99 | sum(gold_by_relation.values()) 100 | ) 
101 | f1_micro = 0.0 102 | if prec_micro + recall_micro > 0.0: 103 | f1_micro = 2.0 * prec_micro * recall_micro / (prec_micro + recall_micro) 104 | # print("Precision (micro): {:.3%}".format(prec_micro)) 105 | # print(" Recall (micro): {:.3%}".format(recall_micro)) 106 | # print(" F1 (micro): {:.3%}".format(f1_micro)) 107 | return prec_micro, recall_micro, f1_micro 108 | -------------------------------------------------------------------------------- /baselines/scd_soc/hiexpl/lm/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python lm_train.py --task sst --gpu 0 --save_path model -------------------------------------------------------------------------------- /baselines/scd_soc/hiexpl/lm_arch.py: -------------------------------------------------------------------------------- 1 | from torch.distributions import Categorical 2 | from nns.layers import * 3 | from utils.args import get_args 4 | from torch.nn import functional as F 5 | 6 | args = get_args() 7 | 8 | 9 | class LSTMLanguageModel(nn.Module): 10 | def __init__(self, config, vocab): 11 | super().__init__() 12 | self.hidden_size = config.lm_d_hidden 13 | self.embed_size = config.lm_d_embed 14 | self.n_vocab = config.n_embed 15 | self.gpu = args.gpu 16 | 17 | self.encoder = DynamicEncoder( 18 | self.n_vocab, self.embed_size, self.hidden_size, self.gpu 19 | ) 20 | self.fw_proj = nn.Linear(self.hidden_size, self.n_vocab) 21 | self.bw_proj = nn.Linear(self.hidden_size, self.n_vocab) 22 | 23 | self.loss = nn.CrossEntropyLoss(ignore_index=1) 24 | self.vocab = vocab 25 | 26 | self.warning_flag = False 27 | 28 | def forward(self, batch): 29 | inp = batch.text 30 | inp_len_np = batch.length.cpu().numpy() 31 | if self.gpu >= 0: 32 | inp = inp.to(self.gpu) 33 | output = self.encoder(inp, inp_len_np) 34 | fw_output, bw_output = ( 35 | output[:, :, : self.hidden_size], 36 | output[:, :, self.hidden_size :], 37 | ) 38 | fw_proj, bw_proj = self.fw_proj(fw_output), self.bw_proj(bw_output) 39 | 40 | fw_loss = self.loss( 41 | fw_proj[:-1].view(-1, fw_proj.size(2)).contiguous(), 42 | inp[1:].view(-1).contiguous(), 43 | ) 44 | bw_loss = self.loss( 45 | bw_proj[1:].view(-1, bw_proj.size(2)).contiguous(), 46 | inp[:-1].view(-1).contiguous(), 47 | ) 48 | return fw_loss, bw_loss 49 | 50 | def _sample_n_sequences( 51 | self, method, direction, token_inp, hidden, length, sample_num 52 | ): 53 | outputs = [] 54 | token_inp = token_inp.repeat(1, sample_num) # [1, N] 55 | hidden = hidden[0].repeat(1, sample_num, 1), hidden[1].repeat( 56 | 1, sample_num, 1 57 | ) # [x, N, H] 58 | for t in range(length): 59 | output, hidden = self.encoder.rollout( 60 | token_inp, hidden, direction=direction 61 | ) 62 | if direction == "fw": 63 | proj = self.fw_proj(output[:, :, : self.hidden_size]) 64 | elif direction == "bw": 65 | proj = self.bw_proj(output[:, :, self.hidden_size :]) 66 | proj = proj.squeeze(0) 67 | if method == "max": 68 | _, token_inp = torch.max(proj, -1) 69 | outputs.append(token_inp.view(-1)) 70 | elif method == "random": 71 | dist = Categorical(F.softmax(proj, -1)) 72 | token_inp = dist.sample() 73 | outputs.append(token_inp) 74 | token_inp = token_inp.view(1, -1) 75 | if direction == "bw": 76 | outputs = list(reversed(outputs)) 77 | outputs = torch.stack(outputs) 78 | return outputs 79 | 80 | def sample_n(self, method, batch, max_sample_length, sample_num): 81 | inp = batch.text 82 | inp_len_np = batch.length.cpu().numpy() 83 | batch_size = inp.size(1) 84 | assert batch_size == 1 85 | 86 | 
        pad_inp1 = torch.LongTensor([self.vocab.stoi["<s>"]] * inp.size(1)).view(1, -1)
87 |         pad_inp2 = torch.LongTensor([self.vocab.stoi["</s>"]] * inp.size(1)).view(1, -1)
88 | 
89 |         if self.gpu >= 0:
90 |             inp = inp.to(self.gpu)
91 |             pad_inp1 = pad_inp1.to(self.gpu)
92 |             pad_inp2 = pad_inp2.to(self.gpu)
93 | 
94 |         padded_inp = torch.cat([pad_inp1, inp, pad_inp2], 0)
95 |         assert padded_inp.max().item() < self.n_vocab
96 |         assert inp_len_np[0] + 2 <= padded_inp.size(0)
97 |         padded_enc_out, (padded_hidden_states, padded_cell_states) = self.encoder(
98 |             padded_inp, inp_len_np + 2, return_all_states=True
99 |         )  # [T+2,B,H]
100 | 
101 |         # extract forward hidden state
102 |         assert 0 <= batch.fw_pos.item() - 1 <= padded_enc_out.size(0) - 1
103 |         assert 0 <= batch.fw_pos.item() <= padded_enc_out.size(0) - 1
104 | 
105 |         fw_hidden_state = padded_hidden_states.index_select(0, batch.fw_pos - 1)[0]
106 |         fw_cell_state = padded_cell_states.index_select(0, batch.fw_pos - 1)[0]
107 |         fw_next_token = padded_inp.index_select(0, batch.fw_pos).view(1, -1)
108 | 
109 |         # extract backward hidden state
110 |         assert 0 <= batch.bw_pos.item() + 3 <= padded_enc_out.size(0) - 1
111 |         assert 0 <= batch.bw_pos.item() + 2 <= padded_enc_out.size(0) - 1
112 |         # (positions are offset by the two padding tokens added above)
113 |         bw_hidden_state = padded_hidden_states.index_select(0, batch.bw_pos + 3)[0]
114 |         bw_cell_state = padded_cell_states.index_select(0, batch.bw_pos + 3)[0]
115 |         # torch.cat([bw_hidden[:,:,:self.hidden_size], bw_hidden[:,:,self.hidden_size:]], 0)
116 |         bw_next_token = padded_inp.index_select(0, batch.bw_pos + 2).view(1, -1)
117 | 
118 |         fw_sample_outputs = self._sample_n_sequences(
119 |             method,
120 |             "fw",
121 |             fw_next_token,
122 |             (fw_hidden_state, fw_cell_state),
123 |             max_sample_length,
124 |             sample_num,
125 |         )
126 |         bw_sample_outputs = self._sample_n_sequences(
127 |             method,
128 |             "bw",
129 |             bw_next_token,
130 |             (bw_hidden_state, bw_cell_state),
131 |             max_sample_length,
132 |             sample_num,
133 |         )
134 | 
135 |         self.filter_special_tokens(fw_sample_outputs)
136 |         self.filter_special_tokens(bw_sample_outputs)
137 | 
138 |         return fw_sample_outputs, bw_sample_outputs
139 | 
140 |     def filter_special_tokens(self, m):
141 |         for i in range(m.size(0)):
142 |             for j in range(m.size(1)):
143 |                 if m[i, j] >= self.n_vocab - 2:
144 |                     m[i, j] = 0
145 | 
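146 | # Illustrative usage (a sketch; `batch` is assumed to expose `text`, `length`,
147 | # and the span endpoints `fw_pos`/`bw_pos` used by sample_n above):
148 | #
149 | #   lm_model = get_lm_model(lm_path, gpu)  # as in baselines/scd_soc/helper.py
150 | #   fw_samples, bw_samples = lm_model.sample_n(
151 | #       "random", batch, max_sample_length=10, sample_num=20
152 | #   )
153 | #
154 | # SOC uses such sampled left/right contexts to average the model's output over
155 | # replacements of the words surrounding the phrase being scored.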
--------------------------------------------------------------------------------
/baselines/scd_soc/hiexpl/nns/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/baselines/scd_soc/hiexpl/nns/__init__.py
--------------------------------------------------------------------------------
/baselines/scd_soc/hiexpl/nns/hiexpl_vocab_atts.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/baselines/scd_soc/hiexpl/nns/hiexpl_vocab_atts.pickle
--------------------------------------------------------------------------------
/baselines/scd_soc/hiexpl/nns/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 | 
5 | 
6 | class DynamicEncoder(nn.Module):
7 |     def __init__(
8 |         self, input_size, embed_size, hidden_size, gpu, n_layers=1, dropout=0.1
9 |     ):
10 |         super().__init__()
11 |         self.input_size = input_size
12 |         self.hidden_size = hidden_size
13 |         self.embed_size = embed_size
14 |         self.n_layers = n_layers
15 |         self.dropout = dropout
16 |         self.embedding = nn.Embedding(input_size, embed_size)
17 |         self.lstm = nn.LSTM(embed_size, hidden_size, n_layers, bidirectional=True)
18 |         self.gpu = gpu
19 | 
20 |     def forward(self, input_seqs, input_lens, hidden=None, return_all_states=False):
21 |         batch_size = input_seqs.size(1)
22 |         embedded = self.embedding(input_seqs)
23 |         if not return_all_states:
24 |             embedded = embedded.transpose(0, 1)  # [B,T,E]
25 |             sort_idx = np.argsort(-input_lens)
26 |             unsort_idx = torch.LongTensor(np.argsort(sort_idx))
27 |             if self.gpu >= 0:
28 |                 unsort_idx = unsort_idx.to(self.gpu)
29 |             input_lens = input_lens[sort_idx]
30 |             sort_idx = torch.LongTensor(sort_idx)
31 |             if self.gpu >= 0:
32 |                 sort_idx = sort_idx.to(self.gpu)
33 |             embedded = embedded[sort_idx].transpose(0, 1)  # [T,B,E]
34 |             packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lens)
35 |             outputs, hidden = self.lstm(packed, hidden)
36 |             outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
37 |             # outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
38 |             outputs = outputs.transpose(0, 1)[unsort_idx].transpose(0, 1).contiguous()
39 |             return outputs
40 |         else:
41 |             hidden = None
42 |             hidden_states, cell_states = [], []
43 |             outputs = []
44 |             for t in range(input_seqs.size(0)):
45 |                 output, hidden = self.lstm(embedded[t].unsqueeze(0), hidden)
46 |                 hidden_states.append(hidden[0])
47 |                 cell_states.append(hidden[1])
48 |                 outputs.append(output)
49 |             outputs = torch.cat(outputs, 0)
50 |             hidden_states = torch.stack(hidden_states, 0)
51 |             cell_states = torch.stack(cell_states, 0)
52 |             return outputs, (hidden_states, cell_states)
53 | 
54 |     def rollout(self, input_word, prev_hidden, direction):
55 |         embed = self.embedding(input_word)
56 |         output, hidden = self.lstm(embed, prev_hidden)
57 |         return output, hidden
58 | 
--------------------------------------------------------------------------------
/baselines/scd_soc/hiexpl/nns/linear.sh:
--------------------------------------------------------------------------------
1 | python linear_model.py --task sst --gpu 1 --save_path "./linear_models/" --epochs 20
--------------------------------------------------------------------------------
/baselines/scd_soc/hiexpl/nns/linear_models/best_snapshot_devacc_82.11009174311926_devloss_0.6408365964889526_iter_14500_model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/baselines/scd_soc/hiexpl/nns/linear_models/best_snapshot_devacc_82.11009174311926_devloss_0.6408365964889526_iter_14500_model.pt
--------------------------------------------------------------------------------
/baselines/scd_soc/hiexpl/nns/linear_models/snapshot_acc_0.0000_loss_0.233998_iter_39500_model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/baselines/scd_soc/hiexpl/nns/linear_models/snapshot_acc_0.0000_loss_0.233998_iter_39500_model.pt
--------------------------------------------------------------------------------
/baselines/scd_soc/hiexpl/nns/vocab/vocab_sst.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/baselines/scd_soc/hiexpl/nns/vocab/vocab_sst.pkl
-------------------------------------------------------------------------------- /baselines/scd_soc/hiexpl/outputs/sst/scd_bert_results/scdbert2020.3.txt: -------------------------------------------------------------------------------- 1 | effective 3.506518 but 0.285621 effective but 2.824195 too - te ##pid -2.738906 bio ##pic 0.394085 too - te ##pid bio ##pic -2.932729 effective but too - te ##pid bio ##pic -3.581248 2 | if -0.105014 you 0.321082 sometimes 0.595583 like 0.477573 to 0.033809 go 0.043812 to 0.043598 the 0.052670 movies 0.069463 the movies 0.027218 to the movies 0.165071 go to the movies 0.366723 to 0.032711 have 0.033589 fun 0.593609 have fun 0.992044 to have fun 0.923421 go to the movies to have fun 0.539228 to go to the movies to have fun 0.462752 like to go to the movies to have fun 1.632296 sometimes like to go to the movies to have fun 1.306539 you sometimes like to go to the movies to have fun 2.127275 if you sometimes like to go to the movies to have fun 1.825670 , 0.021351 was ##abi 0.024231 is 0.128994 a 0.108641 good 1.790137 place 0.139623 to 0.132926 start 1.122074 to start 1.162639 place to start 1.994683 good place to start 5.531188 a good place to start 5.364074 is a good place to start 6.137796 . 0.044383 is a good place to start . 4.254536 was ##abi is a good place to start . 4.564784 , was ##abi is a good place to start . 5.114339 if you sometimes like to go to the movies to have fun , was ##abi is a good place to start . 11.554394 3 | -------------------------------------------------------------------------------- /baselines/scd_soc/hiexpl/outputs/sst/soc_bert_results/socbert2020.3.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/baselines/scd_soc/hiexpl/outputs/sst/soc_bert_results/socbert2020.3.txt -------------------------------------------------------------------------------- /baselines/scd_soc/hiexpl/outputs/sst/soc_results/soctest2020.3.txt: -------------------------------------------------------------------------------- 1 | effective 0.770350 but -0.012891 effective but 0.398268 too-tepid -0.519229 biopic 0.077375 too-tepid biopic 0.199293 effective but too-tepid biopic 0.833620 2 | if 0.009862 you 0.213750 sometimes -0.215167 like -0.182198 to -0.112900 go -0.102774 to -0.045183 the 0.036247 movies 0.417075 the movies 0.274618 to the movies 0.264284 go to the movies 0.094036 to -0.034154 have -0.117422 fun 0.848628 have fun 0.535340 to have fun 0.623575 go to the movies to have fun 0.606097 to go to the movies to have fun 0.335351 like to go to the movies to have fun 0.248454 sometimes like to go to the movies to have fun 0.207922 you sometimes like to go to the movies to have fun 0.572097 if you sometimes like to go to the movies to have fun 0.839144 , 0.176022 wasabi 0.516602 is 0.094853 a -0.040336 good 0.540690 place -0.314529 to -0.067079 start -0.067583 to start 0.027760 place to start -0.477552 good place to start 0.438865 a good place to start 0.664726 is a good place to start 1.253866 . -0.051005 is a good place to start . 1.456468 wasabi is a good place to start . 2.278472 , wasabi is a good place to start . 2.371592 if you sometimes like to go to the movies to have fun , wasabi is a good place to start . 
2.606780
3 | 
--------------------------------------------------------------------------------
/baselines/scd_soc/hiexpl/outputs/sst/soc_results/soctest2020.3_2.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/baselines/scd_soc/hiexpl/outputs/sst/soc_results/soctest2020.3_2.txt
--------------------------------------------------------------------------------
/baselines/scd_soc/hiexpl/scripts/explanations/explain_sst_lstm.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | model_path=${1}
4 | lm_path=${2}
5 | python explain.py --resume_snapshot ${model_path} --task sst --method cd --batch_size 1 --exp_name .cd2019.3 --nb_range 10 --lm_path ${lm_path:-models/sst_lm_2/best_snapshot_devloss_11.634532430897588_iter_1300_model.pt} --nb_method ngram --gpu 0 --sample_n 20 --start 0 --stop 100 --dataset test
6 | 
7 | 
8 | python explain.py --resume_snapshot results/best_snapshot_devacc_84.05963134765625_devloss_0.42193958163261414_iter_3700_model.pt --task sst --method soc --batch_size 1 --exp_name .cd2019.3 --nb_range 10 --lm_path model/best_snapshot_devloss_11.708949835404105_iter_2000_model.pt --nb_method ngram --gpu 0 --sample_n 20 --start 0 --stop 100 --dataset test
9 | 
10 | 
11 | python explain.py --resume_snapshot results/best_snapshot_devacc_84.05963134765625_devloss_0.42193958163261414_iter_3700_model.pt --task sst --method soc --batch_size 1 --exp_name test2020.3 --nb_range 10 --lm_path model/best_snapshot_devloss_11.708949835404105_iter_2000_model.pt --nb_method ngram --gpu 0 --sample_n 20 --start 0 --stop 100 --dataset test
12 | 
13 | python explain.py --explain_model bert --resume_snapshot models_sst --task sst --method soc --batch_size 1 --exp_name bertsoc2020.3 --nb_range 10 --lm_path model/best_snapshot_devloss_11.708949835404105_iter_2000_model.pt --nb_method ngram --gpu 0 --sample_n 20 --start 0 --stop 100 --dataset test --use_bert_tokenizer
14 | 
15 | python explain.py --explain_model bert --resume_snapshot models_sst --task sst --method scd --batch_size 1 --exp_name bert2020.3 --nb_range 10 --lm_path model/best_snapshot_devloss_11.708949835404105_iter_2000_model.pt --nb_method ngram --gpu 0 --sample_n 20 --start 0 --stop 100 --dataset test --use_bert_tokenizer
16 | 
--------------------------------------------------------------------------------
/baselines/scd_soc/hiexpl/scripts/train_model/train_sst_lstm.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | python train.py --task sst --save_path sst_lstm
--------------------------------------------------------------------------------
/baselines/scd_soc/hiexpl/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/baselines/scd_soc/hiexpl/utils/__init__.py
--------------------------------------------------------------------------------
/baselines/scd_soc/hiexpl/utils/args.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import os
3 | 
4 | 
5 | def makedirs(name):
6 |     """helper function for python 2 and 3 to call os.makedirs()
7 |     avoiding an error if the directory to be created already exists"""
8 | 
9 |     import os, errno
10 | 
11 |     try:
12 |         os.makedirs(name)
13 |     except OSError as ex:
14 |         if ex.errno == 
errno.EEXIST and os.path.isdir(name): 15 | # ignore existing directory 16 | pass 17 | else: 18 | # a different error happened 19 | raise 20 | 21 | 22 | def get_best_snapshot(dir): 23 | if os.path.isdir(dir): 24 | files = os.listdir(dir) 25 | for file in files: 26 | if file.startswith("best_"): 27 | return os.path.join(dir, file) 28 | return None 29 | 30 | 31 | def get_args(): 32 | parser = ArgumentParser(description="PyTorch/torchtext SST") 33 | 34 | # model parameters 35 | parser.add_argument("--optim", type=str, default="adam", choices=["adam", "sgd"]) 36 | parser.add_argument( 37 | "--metrics", default="accuracy", choices=["accuracy", "tacred_f1"] 38 | ) 39 | parser.add_argument("--epochs", type=int, default=20) 40 | parser.add_argument("--task", type=str, default="sst") 41 | parser.add_argument("--batch_size", type=int, default=50) 42 | parser.add_argument("--d_embed", type=int, default=300) 43 | parser.add_argument("--d_proj", type=int, default=300) 44 | parser.add_argument("--d_hidden", type=int, default=128) 45 | parser.add_argument("--n_layers", type=int, default=1) 46 | parser.add_argument("--dropout", type=float, default=0.0) 47 | parser.add_argument("--log_every", type=int, default=10000) 48 | parser.add_argument("--lr", type=float, default=0.0005) 49 | parser.add_argument("--weight_decay", type=float, default=1e-6) 50 | parser.add_argument("--dev_every", type=int, default=100) 51 | parser.add_argument("--save_every", type=int, default=100) 52 | parser.add_argument("--no-bidirectional", action="store_false", dest="birnn") 53 | parser.add_argument("--preserve-case", action="store_false", dest="lower") 54 | parser.add_argument("--no-projection", action="store_false", dest="projection") 55 | parser.add_argument("--fix_emb", action="store_true") 56 | parser.add_argument("--gpu", default=0) 57 | parser.add_argument("--save_path", type=str, default="results") 58 | parser.add_argument( 59 | "--vector_cache", 60 | type=str, 61 | default=os.path.join(os.getcwd(), ".vector_cache/input_vectors.pt"), 62 | ) 63 | parser.add_argument("--word_vectors", type=str, default="glove.6B.300d") 64 | parser.add_argument("--resume_snapshot", type=str, default="") 65 | parser.add_argument("--word_dropout", action="store_true") 66 | 67 | parser.add_argument("--lm_d_embed", type=int, default=300) 68 | parser.add_argument("--lm_d_hidden", type=int, default=128) 69 | 70 | parser.add_argument("--method", nargs="?") 71 | parser.add_argument("--nb_method", default="ngram") 72 | parser.add_argument("--nb_range", type=int, default=3) 73 | parser.add_argument("--exp_name", default="") 74 | parser.add_argument("--lm_dir", nargs="?", default="") 75 | parser.add_argument("--lm_path", nargs="?", default="") 76 | parser.add_argument("--start", type=int, default=0) 77 | parser.add_argument("--stop", type=int, default=10000000000) 78 | parser.add_argument("--sample_n", type=int, default=5) 79 | 80 | parser.add_argument("--explain_model", default="lstm") 81 | parser.add_argument("--demo", action="store_true") 82 | 83 | parser.add_argument("--dataset", default="dev") 84 | parser.add_argument("--use_bert_tokenizer", action="store_true") 85 | parser.add_argument("--no_subtrees", action="store_true") 86 | 87 | parser.add_argument("--use_bert_lm", action="store_true") 88 | parser.add_argument("--fix_test_vocab", action="store_true") 89 | 90 | parser.add_argument("--include_noise_labels", action="store_true") 91 | parser.add_argument("--filter_length_gt", type=int, default=-1) 92 | parser.add_argument("--add_itself", 
action="store_true") 93 | 94 | parser.add_argument("--mean_hidden", action="store_true") 95 | parser.add_argument("--agg", action="store_true") 96 | parser.add_argument("--class_score", action="store_true") 97 | 98 | parser.add_argument("--cd_pad", action="store_true") 99 | parser.add_argument("--eval_file", default="") 100 | parser.add_argument("-f", default="") 101 | 102 | args = parser.parse_args() 103 | 104 | try: 105 | args.gpu = int(args.gpu) 106 | except ValueError: 107 | args.gpu = "cpu" 108 | 109 | if os.path.isdir(args.resume_snapshot): 110 | args.resume_snapshot = get_best_snapshot(args.resume_snapshot) 111 | if os.path.isdir(args.lm_path): 112 | args.lm_path = get_best_snapshot(args.lm_path) 113 | return args 114 | 115 | 116 | args = get_args() 117 | -------------------------------------------------------------------------------- /baselines/scd_soc/hiexpl/utils/parser.py: -------------------------------------------------------------------------------- 1 | from nltk import ParentedTree 2 | 3 | 4 | def parse_tree(s): 5 | tree = ParentedTree.fromstring(s) 6 | return tree 7 | 8 | 9 | def read_trees_from_corpus(path): 10 | f = open(path) 11 | rows = f.readlines() 12 | trees = [] 13 | for row in rows: 14 | row = row.lower() 15 | tree = parse_tree(row) 16 | trees.append(tree) 17 | return trees 18 | 19 | 20 | def is_leaf(node): 21 | if type(node[0]) == str and len(node) != 1: 22 | print(1) 23 | return type(node[0]) == str 24 | 25 | 26 | def get_span_to_node_mapping(tree): 27 | def dfs(node, span_to_node, node_to_span, idx): 28 | if is_leaf(node): 29 | span_to_node[idx] = node 30 | node_to_span[id(node)] = idx 31 | return idx + 1 32 | prev_idx = idx 33 | for child in node: 34 | idx = dfs(child, span_to_node, node_to_span, idx) 35 | span_to_node[(prev_idx, idx - 1)] = node 36 | node_to_span[id(node)] = (prev_idx, idx - 1) 37 | return idx 38 | 39 | span2node, node2span = {}, {} 40 | dfs(tree, span2node, node2span, 0) 41 | return span2node, node2span 42 | 43 | 44 | def get_siblings_idx(node, node2span): 45 | parent = node.parent() 46 | if parent is None: # root 47 | return node2span[id(node)] 48 | return node2span[id(parent)] 49 | 50 | 51 | def find_region_neighbourhood(s_or_tree, region): 52 | if type(s_or_tree) is str: 53 | tree = parse_tree(s_or_tree) 54 | else: 55 | tree = s_or_tree 56 | 57 | if type(region) is tuple and region[0] == region[1]: 58 | region = region[0] 59 | 60 | span2node, node2span = get_span_to_node_mapping(tree) 61 | node = span2node[region] 62 | sibling_idx = get_siblings_idx(node, node2span) 63 | return sibling_idx 64 | -------------------------------------------------------------------------------- /baselines/scd_soc/hiexpl/utils/tacred_f1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The Board of Trustees of The Leland Stanford Junior University 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from collections import Counter 17 | import sys 18 | 19 | NO_RELATION = 0 20 | 21 | 22 | def score(key, prediction, verbose=False): 23 | correct_by_relation = Counter() 24 | guessed_by_relation = Counter() 25 | gold_by_relation = Counter() 26 | 27 | # Loop over the data to compute a score 28 | for row in range(len(key)): 29 | gold = key[row] 30 | guess = prediction[row] 31 | 32 | if gold == NO_RELATION and guess == NO_RELATION: 33 | pass 34 | elif gold == NO_RELATION and guess != NO_RELATION: 35 | guessed_by_relation[guess] += 1 36 | elif gold != NO_RELATION and guess == NO_RELATION: 37 | gold_by_relation[gold] += 1 38 | elif gold != NO_RELATION and guess != NO_RELATION: 39 | guessed_by_relation[guess] += 1 40 | gold_by_relation[gold] += 1 41 | if gold == guess: 42 | correct_by_relation[guess] += 1 43 | 44 | # Print verbose information 45 | if verbose: 46 | print("Per-relation statistics:") 47 | relations = gold_by_relation.keys() 48 | longest_relation = 0 49 | for relation in sorted(relations): 50 | longest_relation = max(len(relation), longest_relation) 51 | for relation in sorted(relations): 52 | # (compute the score) 53 | correct = correct_by_relation[relation] 54 | guessed = guessed_by_relation[relation] 55 | gold = gold_by_relation[relation] 56 | prec = 1.0 57 | if guessed > 0: 58 | prec = float(correct) / float(guessed) 59 | recall = 0.0 60 | if gold > 0: 61 | recall = float(correct) / float(gold) 62 | f1 = 0.0 63 | if prec + recall > 0: 64 | f1 = 2.0 * prec * recall / (prec + recall) 65 | # (print the score) 66 | sys.stdout.write(("{:<" + str(longest_relation) + "}").format(relation)) 67 | sys.stdout.write(" P: ") 68 | if prec < 0.1: 69 | sys.stdout.write(" ") 70 | if prec < 1.0: 71 | sys.stdout.write(" ") 72 | sys.stdout.write("{:.2%}".format(prec)) 73 | sys.stdout.write(" R: ") 74 | if recall < 0.1: 75 | sys.stdout.write(" ") 76 | if recall < 1.0: 77 | sys.stdout.write(" ") 78 | sys.stdout.write("{:.2%}".format(recall)) 79 | sys.stdout.write(" F1: ") 80 | if f1 < 0.1: 81 | sys.stdout.write(" ") 82 | if f1 < 1.0: 83 | sys.stdout.write(" ") 84 | sys.stdout.write("{:.2%}".format(f1)) 85 | sys.stdout.write(" #: %d" % gold) 86 | sys.stdout.write("\n") 87 | print("") 88 | 89 | # Print the aggregate score 90 | if verbose: 91 | print("Final Score:") 92 | prec_micro = 1.0 93 | if sum(guessed_by_relation.values()) > 0: 94 | prec_micro = float(sum(correct_by_relation.values())) / float( 95 | sum(guessed_by_relation.values()) 96 | ) 97 | recall_micro = 0.0 98 | if sum(gold_by_relation.values()) > 0: 99 | recall_micro = float(sum(correct_by_relation.values())) / float( 100 | sum(gold_by_relation.values()) 101 | ) 102 | f1_micro = 0.0 103 | if prec_micro + recall_micro > 0.0: 104 | f1_micro = 2.0 * prec_micro * recall_micro / (prec_micro + recall_micro) 105 | # print("Precision (micro): {:.3%}".format(prec_micro)) 106 | # print(" Recall (micro): {:.3%}".format(recall_micro)) 107 | # print(" F1 (micro): {:.3%}".format(f1_micro)) 108 | return prec_micro, recall_micro, f1_micro 109 | -------------------------------------------------------------------------------- /baselines/scd_soc/hiexpl/vocab/vocab_sst.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/baselines/scd_soc/hiexpl/vocab/vocab_sst.pkl -------------------------------------------------------------------------------- /baselines/scd_soc/hiexpl/vocab/vocab_sst_bert.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/baselines/scd_soc/hiexpl/vocab/vocab_sst_bert.pkl -------------------------------------------------------------------------------- /baselines/shapley_interaction_index/demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import sys\n", 11 | "\n", 12 | "sys.path.append(\"../../src\")\n", 13 | "from si_explainer import SiExplainer\n", 14 | "sys.path.append(\"../../experiments/1. archdetect\")\n", 15 | "from synthetic_utils import *\n", 16 | "\n", 17 | "%load_ext autoreload\n", 18 | "%autoreload 2" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "## Parameters" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "function_id = 1\n", 35 | "\n", 36 | "p = 40 # num features\n", 37 | "input_value, base_value = 1, -1" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "## Get Data and Synthetic Function" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 3, 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "name": "stdout", 54 | "output_type": "stream", 55 | "text": [ 56 | "function id: 1\n" 57 | ] 58 | } 59 | ], 60 | "source": [ 61 | "input = np.array([input_value]*p)\n", 62 | "baseline = np.array([base_value]*p)\n", 63 | "\n", 64 | "print(\"function id:\", function_id)\n", 65 | "model = synth_model(function_id, input_value, base_value)\n", 66 | "gts = model.get_gts(p)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "## Get Explanation" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 4, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "si_method = SiExplainer(model, input=input, baseline=baseline, output_indices=0, batch_size=20)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "### Individual Tests" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 5, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "num_T = 1\n", 99 | "\n", 100 | "inters = dict()\n", 101 | "for i in range(p):\n", 102 | " for j in range(i+1, p):\n", 103 | " if i == j: continue\n", 104 | " S = (i,j)\n", 105 | " \n", 106 | " att = si_method.attribution(S, num_T)\n", 107 | " inters[S] = att**2\n", 108 | " \n", 109 | "for i in range(p):\n", 110 | " att = si_method.attribution([i], num_T)**2" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 6, 116 | "metadata": {}, 117 | "outputs": [ 118 | { 119 | "name": "stdout", 120 | "output_type": "stream", 121 | "text": [ 122 | "auc 1.0\n" 123 | ] 124 | } 125 | ], 126 | "source": [ 127 | "print(\"auc\", get_auc(inters.items(), gts))" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": {}, 133 | "source": [ 134 | "### Batch Version" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 7, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "mat = si_method.batch_attribution(num_T, pairwise=True)\n", 144 | "arr = si_method.batch_attribution(num_T, main_effects=True, 
pairwise=False)\n", 145 | "\n", 146 | "inters = {}\n", 147 | "for i in range(p):\n", 148 | "    for j in range(i+1, p):\n", 149 | "        inters[(i,j)] = mat[i, j]**2" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 8, 155 | "metadata": {}, 156 | "outputs": [ 157 | { 158 | "name": "stdout", 159 | "output_type": "stream", 160 | "text": [ 161 | "auc 1.0\n" 162 | ] 163 | } 164 | ], 165 | "source": [ 166 | "print(\"auc\", get_auc(inters.items(), gts))" 167 | ] 168 | } 169 | ], 170 | "metadata": { 171 | "kernelspec": { 172 | "display_name": "Python [conda env:test]", 173 | "language": "python", 174 | "name": "conda-env-test-py" 175 | }, 176 | "language_info": { 177 | "codemirror_mode": { 178 | "name": "ipython", 179 | "version": 3 180 | }, 181 | "file_extension": ".py", 182 | "mimetype": "text/x-python", 183 | "name": "python", 184 | "nbconvert_exporter": "python", 185 | "pygments_lexer": "ipython3", 186 | "version": "3.6.7" 187 | } 188 | }, 189 | "nbformat": 4, 190 | "nbformat_minor": 4 191 | } 192 | -------------------------------------------------------------------------------- /baselines/shapley_interaction_index/si_explainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | from itertools import chain, combinations 4 | import math, random 5 | from explainer import Explainer 6 | 7 | 8 | def powerset(iterable): 9 | "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)" 10 | s = list(iterable) 11 | return list(chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))) 12 | 13 | 14 | def random_subset(s): 15 | out = [] 16 | for el in s: 17 | # random coin flip 18 | if random.randint(0, 1) == 0: 19 | out.append(el) 20 | return tuple(out) 21 | 22 | 23 | class SiExplainer(Explainer): 24 | def __init__( 25 | self, 26 | model, 27 | input=None, 28 | baseline=None, 29 | data_xformer=None, 30 | output_indices=0, 31 | batch_size=20, 32 | verbose=False, 33 | seed=None, 34 | ): 35 | Explainer.__init__( 36 | self, 37 | model, 38 | input, 39 | baseline, 40 | data_xformer, 41 | output_indices, 42 | batch_size, 43 | verbose, 44 | ) 45 | if seed is not None: 46 | random.seed(seed) 47 | 48 | def attribution(self, S, num_T): 49 | """ 50 | S: the interaction index set to get attributions for 51 | num_T: the number of random subsets T of the remaining features to sample 52 | """ 53 | 54 | s = len(S) 55 | n = len(self.input) 56 | 57 | N_excl_S = [i for i in range(n) if i not in S] 58 | 59 | num_T = min(num_T, 2 ** len(N_excl_S)) 60 | 61 | random_T_set = set() 62 | for _ in range(num_T): 63 | T = random_subset(N_excl_S) 64 | while T in random_T_set: 65 | T = random_subset(N_excl_S) 66 | random_T_set.add(T) 67 | 68 | total_att = 0 69 | 70 | for T in random_T_set: 71 | t = len(T) 72 | 73 | n1 = math.factorial(n - t - s) 74 | n2 = math.factorial(t) 75 | d1 = math.factorial(n - s + 1) 76 | 77 | coef = (n1 * n2) / d1 78 | 79 | subsetsW = powerset(S) 80 | 81 | set_indices = [] 82 | for W in subsetsW: 83 | set_indices.append(tuple(set(W) | set(T))) 84 | 85 | scores_dict = self.batch_set_inference( 86 | set_indices, self.baseline, self.input, include_context=False 87 | ) 88 | scores = scores_dict["scores"] 89 | 90 | att = 0 91 | for i, W in enumerate(subsetsW): 92 | w = len(W) 93 | att += (-1) ** (w - s) * scores[set_indices[i]] 94 | 95 | total_att += coef * att 96 | 97 | return total_att 98 | 99 | def batch_attribution(self, num_T, main_effects=False, pairwise=True): 100 | """ 101 | Batch version: samples num_T subsets T per candidate index set S and 102 | scores all of the resulting feature sets in shared inference batches 103 | """ 104 | 105 | def collect_att(S, S_T_Z_dict, Z_score_dict, n): 106 | s = len(S) 107 | 108 | subsetsW = powerset(S) 109 | 110 | total_att = 0 111 | 112 | for T in S_T_Z_dict[S]: 113 | 114 | att = 0 115 | for i, W in enumerate(subsetsW): 116 | w = len(W) 117 | att += (-1) ** (w - s) * Z_score_dict[S_T_Z_dict[S][T][i]] 118 | 119 | t = len(T) 120 | n1 = math.factorial(n - t - s) 121 | n2 = math.factorial(t) 122 | d1 = math.factorial(n - s + 1) 123 | 124 | coef = (n1 * n2) / d1 125 | total_att += coef * att 126 | 127 | return total_att 128 | 129 | n = len(self.input) 130 | num_features = n 131 | 132 | if main_effects == False and pairwise == False: 133 | raise ValueError("one of main_effects or pairwise must be True") 134 | if main_effects == True and pairwise == True: 135 | raise ValueError("main_effects and pairwise cannot both be True") 136 | 137 | Ss = [] 138 | if pairwise: 139 | for i in range(num_features): 140 | for j in range(i + 1, num_features): 141 | S = (i, j) 142 | Ss.append(S) 143 | elif main_effects: 144 | for i in range(num_features): 145 | Ss.append(tuple([i])) 146 | 147 | Z_set = set() 148 | S_T_Z_dict = {} 149 | 150 | for S in Ss: 151 | s = len(S) 152 | 153 | N_excl_S = [i for i in range(n) if i not in S] 154 | num_T = min(num_T, 2 ** len(N_excl_S)) 155 | 156 | random_T_set = set() 157 | for _ in range(num_T): 158 | T = random_subset(N_excl_S) 159 | while T in random_T_set: 160 | T = random_subset(N_excl_S) 161 | random_T_set.add(tuple(T)) 162 | 163 | S_T_Z_dict[S] = {} 164 | 165 | subsetsW = powerset(S) 166 | 167 | for T in random_T_set: 168 | S_T_Z_dict[S][T] = [] 169 | 170 | for W in subsetsW: 171 | Z = tuple(set(W) | set(T)) 172 | Z_set.add(Z) 173 | S_T_Z_dict[S][T].append(Z) 174 | 175 | Z_list = list(Z_set) 176 | scores_dict = self.batch_set_inference( 177 | Z_list, self.baseline, self.input, include_context=False 178 | ) 179 | scores = scores_dict["scores"] 180 | Z_score_dict = scores 181 | 182 | if pairwise: 183 | res = np.zeros((num_features, num_features)) 184 | for i in range(num_features): 185 | for j in range(i + 1, num_features): 186 | S = (i, j) 187 | att = collect_att(S, S_T_Z_dict, Z_score_dict, n) 188 | res[i, j] = att 189 | return res 190 | 191 | elif main_effects: 192 | res = [] 193 | for i in range(num_features): 194 | S = tuple([i]) 195 | att = collect_att(S, S_T_Z_dict, Z_score_dict, n) 196 | res.append(att) 197 | 198 | return np.array(res) 199 | 
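The `attribution` method in `si_explainer.py` above estimates the Shapley interaction index of Grabisch and Roubens (1999), SII(S) = sum over T subset of N\S of ((n - t - s)! t! / (n - s + 1)!) * delta_S f(T), where delta_S f(T) = sum over W subset of S of (-1)^(|S| - |W|) * f(T union W), except that it sums over only `num_T` sampled subsets T instead of all 2^(n - s) of them, so for small `num_T` the result is a partial sum used as a relative interaction score. Below is a minimal brute-force sketch of the exact index (an editor's illustration, not part of the repository; `f` stands for any function from a set of "on" feature indices to a real number):

```python
from itertools import chain, combinations
import math


def all_subsets(items):
    # every subset of items, from the empty set up to the full set
    items = list(items)
    return chain.from_iterable(combinations(items, r) for r in range(len(items) + 1))


def exact_sii(f, n, S):
    # exact Shapley interaction index of S, enumerating every T in N \ S
    s = len(S)
    rest = [i for i in range(n) if i not in S]
    total = 0.0
    for T in all_subsets(rest):
        t = len(T)
        coef = math.factorial(n - t - s) * math.factorial(t) / math.factorial(n - s + 1)
        delta = sum((-1) ** (s - len(W)) * f(set(W) | set(T)) for W in all_subsets(S))
        total += coef * delta
    return total


# a pure pairwise AND: only features 0 and 1 interact
f = lambda Z: 1.0 if {0, 1} <= Z else 0.0
print(exact_sii(f, 3, (0, 1)))  # 1.0
print(exact_sii(f, 3, (0, 2)))  # 0.0
```

Because `exact_sii` enumerates every subset, it is only practical as a sanity check on a handful of features; the sampled version above is what keeps the 40-feature demos tractable.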
-------------------------------------------------------------------------------- /baselines/shapley_taylor_interaction_index/demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import sys\n", 11 | "\n", 12 | "sys.path.append(\"../../src\")\n", 13 | "from sti_explainer import StiExplainer\n", 14 | "sys.path.append(\"../../experiments/1. 
archdetect\")\n", 15 | "from synthetic_utils import *\n", 16 | "\n", 17 | "%load_ext autoreload\n", 18 | "%autoreload 2" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "## Parameters" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "function_id = 4\n", 35 | "\n", 36 | "p = 40 # num features\n", 37 | "input_value, base_value = 1, -1" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "## Get Data and Synthetic Function" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 3, 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "name": "stdout", 54 | "output_type": "stream", 55 | "text": [ 56 | "function id: 4\n" 57 | ] 58 | } 59 | ], 60 | "source": [ 61 | "input = np.array([input_value]*p)\n", 62 | "baseline = np.array([base_value]*p)\n", 63 | "\n", 64 | "print(\"function id:\", function_id)\n", 65 | "model = synth_model(function_id, input_value, base_value)\n", 66 | "gts = model.get_gts(p)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 4, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "f_diff = (model(input)-model(baseline)).item()" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": {}, 81 | "source": [ 82 | "## Get Explanation" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 5, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "sti_method = StiExplainer(model, input=input, baseline=baseline, output_indices=0, batch_size=20)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "### Individual Tests" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 6, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "np.random.seed(42)\n", 108 | "\n", 109 | "def subset_before(i, j, ordering, ordering_dict):\n", 110 | " end_idx = min(ordering_dict[i], ordering_dict[j])\n", 111 | " return ordering[:end_idx]\n", 112 | "\n", 113 | "ordering = np.random.permutation(list(range(p)))\n", 114 | "ordering_dict = {ordering[i]: i for i in range(len(ordering))}\n", 115 | "\n", 116 | "att_sum = 0\n", 117 | "inters = {}\n", 118 | "for i in range(p):\n", 119 | " for j in range(0, p):\n", 120 | " if i >= j: continue\n", 121 | " T = subset_before(i, j, ordering, ordering_dict)\n", 122 | " S = (i,j)\n", 123 | " \n", 124 | " att = sti_method.attribution(S, T)\n", 125 | " att_sum+=att\n", 126 | " inters[S] = att\n", 127 | " \n", 128 | "for i in range(p):\n", 129 | " att = sti_method.attribution([i], [])\n", 130 | " att_sum += att" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": {}, 136 | "source": [ 137 | "### Check Completeness" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 7, 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "name": "stdout", 147 | "output_type": "stream", 148 | "text": [ 149 | "82.0 82.0\n" 150 | ] 151 | } 152 | ], 153 | "source": [ 154 | "assert(att_sum == f_diff)\n", 155 | "print(att_sum, f_diff)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "metadata": {}, 161 | "source": [ 162 | "### Batch Version" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 8, 168 | "metadata": {}, 169 | "outputs": [ 170 | { 171 | "name": "stdout", 172 | "output_type": "stream", 173 | "text": [ 174 | "CPU times: user 13 s, sys: 93.9 ms, total: 13.1 
s\n", 175 | "Wall time: 22.8 s\n" 176 | ] 177 | } 178 | ], 179 | "source": [ 180 | "%%time\n", 181 | "num_orderings = 50\n", 182 | "mat = sti_method.batch_attribution(num_orderings, pairwise=True, seed=4)\n", 183 | "arr = sti_method.batch_attribution(num_orderings, main_effects=True, pairwise=False, seed=4)\n", 184 | "\n", 185 | "att_sum = mat.sum() + arr.sum()" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": {}, 191 | "source": [ 192 | "### Check Completeness" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 9, 198 | "metadata": {}, 199 | "outputs": [ 200 | { 201 | "name": "stdout", 202 | "output_type": "stream", 203 | "text": [ 204 | "81.97242471249145 82.0\n" 205 | ] 206 | } 207 | ], 208 | "source": [ 209 | "assert(round(att_sum) == f_diff)\n", 210 | "print(att_sum, f_diff)" 211 | ] 212 | } 213 | ], 214 | "metadata": { 215 | "kernelspec": { 216 | "display_name": "Python [conda env:test]", 217 | "language": "python", 218 | "name": "conda-env-test-py" 219 | }, 220 | "language_info": { 221 | "codemirror_mode": { 222 | "name": "ipython", 223 | "version": 3 224 | }, 225 | "file_extension": ".py", 226 | "mimetype": "text/x-python", 227 | "name": "python", 228 | "nbconvert_exporter": "python", 229 | "pygments_lexer": "ipython3", 230 | "version": "3.6.7" 231 | } 232 | }, 233 | "nbformat": 4, 234 | "nbformat_minor": 4 235 | } 236 | -------------------------------------------------------------------------------- /baselines/shapley_taylor_interaction_index/sti_explainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | from itertools import chain, combinations 4 | from explainer import Explainer 5 | 6 | 7 | def powerset(iterable): 8 | "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)" 9 | s = list(iterable) 10 | return list(chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))) 11 | 12 | 13 | class StiExplainer(Explainer): 14 | def __init__( 15 | self, 16 | model, 17 | input=None, 18 | baseline=None, 19 | data_xformer=None, 20 | output_indices=0, 21 | batch_size=20, 22 | verbose=False, 23 | ): 24 | Explainer.__init__( 25 | self, 26 | model, 27 | input, 28 | baseline, 29 | data_xformer, 30 | output_indices, 31 | batch_size, 32 | verbose, 33 | ) 34 | 35 | def attribution(self, S, T): 36 | """ 37 | S: the interaction index set to get attributions for 38 | T: the input index set 39 | """ 40 | subsetsW = powerset(S) 41 | 42 | set_indices = [] 43 | for W in subsetsW: 44 | set_indices.append(tuple(set(W) | set(T))) 45 | 46 | scores_dict = self.batch_set_inference( 47 | set_indices, self.baseline, self.input, include_context=False 48 | ) 49 | scores = scores_dict["scores"] 50 | 51 | att = 0 52 | for i, W in enumerate(subsetsW): 53 | w = len(W) 54 | s = len(S) 55 | att += (-1) ** (w - s) * scores[set_indices[i]] 56 | 57 | return att 58 | 59 | def batch_attribution( 60 | self, num_orderings, main_effects=False, pairwise=True, seed=None, max_order=2 61 | ): 62 | def collect_att(S, S_T_Z_dict, Z_score_dict, n): 63 | s = len(S) 64 | subsetsW = powerset(S) 65 | 66 | total_att = 0 67 | 68 | for T in S_T_Z_dict[S]: 69 | 70 | att = 0 71 | for i, W in enumerate(subsetsW): 72 | w = len(W) 73 | att += (-1) ** (w - s) * Z_score_dict[S_T_Z_dict[S][T][i]] 74 | 75 | total_att += att 76 | 77 | num_orderings = len(S_T_Z_dict[S]) 78 | return total_att / num_orderings 79 | 80 | num_features = len(self.input) 81 | 82 | if main_effects == False and pairwise 
== False: 83 | raise ValueError("one of main_effects or pairwise must be True") 84 | if main_effects == True and pairwise == True: 85 | raise ValueError("main_effects and pairwise cannot both be True") 86 | 87 | Ss = [] 88 | if pairwise: 89 | for i in range(num_features): 90 | for j in range(i + 1, num_features): 91 | S = (i, j) 92 | Ss.append(S) 93 | elif main_effects: 94 | for i in range(num_features): 95 | Ss.append(tuple([i])) 96 | 97 | Z_set = set() 98 | S_T_Z_dict = dict() 99 | for S in Ss: 100 | subsetsW = powerset(S) 101 | S_T_Z_dict[S] = {} 102 | 103 | if seed is not None: 104 | np.random.seed(seed) 105 | for _ in range(num_orderings): 106 | ordering = np.random.permutation(list(range(num_features))) 107 | ordering_dict = {ordering[i]: i for i in range(len(ordering))} 108 | 109 | if len(S) == max_order: 110 | T = subset_before(S, ordering, ordering_dict) 111 | else: 112 | T = [] 113 | T = tuple(T) 114 | 115 | S_T_Z_dict[S][T] = [] 116 | 117 | 118 | for W in subsetsW: 119 | Z = tuple(set(W) | set(T)) 120 | Z_set.add(Z) 121 | S_T_Z_dict[S][T].append(Z) 122 | 123 | Z_list = list(Z_set) 124 | 125 | scores_dict = self.batch_set_inference( 126 | Z_list, self.baseline, self.input, include_context=False 127 | ) 128 | scores = scores_dict["scores"] 129 | Z_score_dict = scores 130 | 131 | 132 | 133 | if pairwise: 134 | res = np.zeros((num_features, num_features)) 135 | for i in range(num_features): 136 | for j in range(i + 1, num_features): 137 | S = (i, j) 138 | att = collect_att(S, S_T_Z_dict, Z_score_dict, num_features) 139 | res[i, j] = att 140 | return res 141 | 142 | elif main_effects: 143 | res = [] 144 | for i in range(num_features): 145 | S = tuple([i]) 146 | att = collect_att(S, S_T_Z_dict, Z_score_dict, num_features) 147 | res.append(att) 148 | return np.array(res) 149 | 150 | 151 | def subset_before(S, ordering, ordering_dict): 152 | end_idx = min(ordering_dict[s] for s in S) 153 | return ordering[:end_idx] 154 | -------------------------------------------------------------------------------- /demos/1. 
text analysis/demo_bert_torch_interactive.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "from transformers import *\n", 11 | "import numpy as np\n", 12 | "import sys\n", 13 | "\n", 14 | "sys.path.append(\"../../src\")\n", 15 | "from explainer import Archipelago\n", 16 | "from application_utils.text_utils import *\n", 17 | "from application_utils.text_utils_torch import BertWrapperTorch\n", 18 | "from viz.text import interactive_viz_text\n", 19 | "\n", 20 | "%load_ext autoreload\n", 21 | "%autoreload 2" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "## Get Model" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "device = torch.device(\"cuda:0\")\n", 38 | "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n", 39 | "\n", 40 | "model_path = \"../../downloads/pretrained_bert\"\n", 41 | "model = BertForSequenceClassification.from_pretrained(model_path)\n", 42 | "model_wrapper = BertWrapperTorch(model, device)" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "## Define Text" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 3, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "text = \" It 's predictable , but it jumps through the expected hoops with style and even some depth .\"\n", 59 | "baseline_token = \"_\"" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "## Get Sentiment" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 4, 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "name": "stdout", 76 | "output_type": "stream", 77 | "text": [ 78 | "positive 4.163356781005859\n" 79 | ] 80 | } 81 | ], 82 | "source": [ 83 | "text_ids, baseline_ids = get_input_baseline_ids(text, baseline_token, tokenizer)\n", 84 | "\n", 85 | "class_idx = 1\n", 86 | "logit = model_wrapper([text_ids])[0,class_idx].item()\n", 87 | "polarity = \"positive\" if logit > 0 else \"negative\"\n", 88 | "print(polarity, logit)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "metadata": {}, 94 | "source": [ 95 | "## Explain Prediction" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 5, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "xf = TextXformer(text_ids, baseline_ids) \n", 105 | "apgo = Archipelago(model_wrapper, data_xformer=xf, output_indices=class_idx, batch_size=20, interactive=True)\n", 106 | "exps, max_magn = apgo.get_interactive_explanations()" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "## Interactive Visualization" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 6, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "data": { 123 | "application/vnd.jupyter.widget-view+json": { 124 | "model_id": "1ddf1eaa51c4409491c6bea260c0ffcc", 125 | "version_major": 2, 126 | "version_minor": 0 127 | }, 128 | "text/plain": [ 129 | "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …" 130 | ] 131 | }, 132 | "metadata": {}, 133 | "output_type": "display_data" 134 | }, 135 | { 136 | "data": { 137 | 
"application/vnd.jupyter.widget-view+json": { 138 | "model_id": "668624f549b74fb2b08061c5e7999d9a", 139 | "version_major": 2, 140 | "version_minor": 0 141 | }, 142 | "text/plain": [ 143 | "interactive(children=(IntSlider(value=3, description='k', max=19), Output()), _dom_classes=('widget-interact',…" 144 | ] 145 | }, 146 | "metadata": {}, 147 | "output_type": "display_data" 148 | } 149 | ], 150 | "source": [ 151 | "%matplotlib widget\n", 152 | "tokens = get_token_list(text_ids, tokenizer) \n", 153 | "interactive_viz_text(exps, tokens, process_stop_words, init_k=3, max_magn=max_magn, fontsize=10)" 154 | ] 155 | } 156 | ], 157 | "metadata": { 158 | "kernelspec": { 159 | "display_name": "Python 3", 160 | "language": "python", 161 | "name": "python3" 162 | }, 163 | "language_info": { 164 | "codemirror_mode": { 165 | "name": "ipython", 166 | "version": 3 167 | }, 168 | "file_extension": ".py", 169 | "mimetype": "text/x-python", 170 | "name": "python", 171 | "nbconvert_exporter": "python", 172 | "pygments_lexer": "ipython3", 173 | "version": "3.6.12" 174 | } 175 | }, 176 | "nbformat": 4, 177 | "nbformat_minor": 4 178 | } 179 | -------------------------------------------------------------------------------- /demos/2. image classification/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/demos/2. image classification/dog.jpg -------------------------------------------------------------------------------- /demos/3. recommendation/autoint/README.md: -------------------------------------------------------------------------------- 1 | # AutoInt 2 | 3 | This is a TenforFlow implementation of ***AutoInt*** for CTR prediction task, as described in our paper: 4 | 5 | Weiping Song, Chence Shi, Zhiping Xiao, Zhijian Duan, Yewen Xu, Ming Zhang and Jian Tang. [AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks](https://arxiv.org/pdf/1810.11921.pdf). arXiv preprint arXiv:1810.11921, 2018. 6 | 7 | ## Requirements: 8 | * **Tensorflow 1.4.0-rc1** 9 | * Python 3 10 | * CUDA 8.0+ (For GPU) 11 | 12 | ## Introduction 13 | 14 | AutoInt:An effective and efficient algorithm to 15 | automatically learn high-order feature interactions for (sparse) categorical and numerical features. 16 | 17 |
18 | <!-- figure: illustration of the AutoInt architecture --> 19 | 
20 | The illustration of AutoInt. We first project all sparse features 21 | (both categorical and numerical) into a low-dimensional space. Next, we feed the embeddings of all fields into multiple stacked interacting layers implemented by a self-attentive neural network. The output of the final interacting layer is a low-dimensional representation of the learnt combinatorial features, which is then used to estimate the CTR via a sigmoid function. 22 | 23 | ## Usage 24 | ### Input Format 25 | AutoInt requires the input data in the following format: 26 | * train_x: matrix with shape *(num_sample, num_field)*. train_x[s][t] is the feature value of feature field t of sample s in the dataset. The default value for categorical features is 1. 27 | * train_i: matrix with shape *(num_sample, num_field)*. train_i[s][t] is the feature index of feature field t of sample s in the dataset. The maximal value of train_i is the feature size. 28 | * train_y: label of each sample in the dataset. A toy example of these three matrices appears at the end of this section. 29 | 30 | If you want to know how to preprocess the data, please refer to `./Dataprocess/Criteo/preprocess.py` 31 | 32 | ### Example 33 | We use four public real-world datasets (Avazu, Criteo, KDD12, MovieLens-1M) in our experiments. Since the first three datasets are very large, they cannot fit into memory as a whole. In our implementation, we split the whole dataset into 10 parts and use the first file as the test set and the second file as the validation set. We provide the code for preprocessing these three datasets in `./Dataprocess`. If you want to reuse this code, you should first run `preprocess.py` to generate `train_x.txt, train_i.txt, train_y.txt` as described in `Input Format`. Then you should run `./Dataprocess/Kfold_split/stratifiedKfold.py` to split the whole dataset into ten folds. Finally you can run `scale.py` to scale the numerical values (optional). 34 | 35 | To help test the correctness of the code and familiarize yourself with it, we upload the first `10000` samples of the `Criteo` dataset in `train_examples.txt`. We also provide scripts for preprocessing and training. (Please refer to `sample_preprocess.sh` and `test_code.sh`; you may need to modify the paths in `config.py` and `test_code.sh`.) 36 | 37 | After you run `test_code.sh`, you should get a folder named `Criteo` which contains `part*, feature_size.npy, fold_index.npy, train_*.txt`. `feature_size.npy` contains the number of total features, which is used to initialize the model. `train_*.txt` is the whole dataset. If you use another small dataset, say `MovieLens-1M`, you only need to modify the function `_run_` in `train.py`. 38 | 39 | Here's how to run the preprocessing. 40 | ``` 41 | mkdir Criteo 42 | python ./Dataprocess/Criteo/preprocess.py 43 | python ./Dataprocess/Kfold_split/stratifiedKfold.py 44 | python ./Dataprocess/Criteo/scale.py 45 | ``` 46 | 47 | Here's how to run the training. 48 | ``` 49 | python -u train.py \ 50 | --data "Criteo" --blocks 3 --heads 2 --block_shape "[64, 64, 64]" \ 51 | --is_save "True" --save_path "./test_code/Criteo/b3h2_64x64x64/" \ 52 | --field_size 39 --run_times 1 --data_path "./" \ 53 | --epoch 3 --has_residual "True" --has_wide "False" \ 54 | --batch_size 1024 \ 55 | > test_code_single.out & 56 | ``` 57 | 58 | You should see output like this: 59 | 60 | ``` 61 | ... 62 | train logs 63 | ... 64 | start testing!... 65 | restored from ./test_code/Criteo/b3h2_dnn_dropkeep1_400x2/1/ 66 | test-result = 0.8088, test-logloss = 0.4430 67 | test_auc [0.8088305055534442] 68 | test_log_loss [0.44297631300399626] 69 | avg_auc 0.8088305055534442 70 | avg_log_loss 0.44297631300399626 71 | ``` 72 | 
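As a toy illustration of the input format above (an editor's sketch with synthetic values, not taken from any of the real datasets), `train_x`, `train_i`, and `train_y` for `num_sample = 2` and `num_field = 3` (two categorical fields followed by one numerical field) could look like:

```python
import numpy as np

# feature values: categorical fields take the default value 1,
# the numerical field keeps its (optionally scaled) value
train_x = np.array([[1.0, 1.0, 0.72],
                    [1.0, 1.0, 0.05]])

# global feature indices: each distinct categorical value gets its own index,
# while a numerical field reuses one fixed index across samples
train_i = np.array([[3, 17, 29],
                    [4, 17, 29]])

# click / no-click labels
train_y = np.array([1, 0])
```

Here the two samples take different values of field 0 (indices 3 and 4), share the same value of field 1 (index 17), and the numerical field 2 always maps to index 29.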
73 | ## Citation 74 | If you find AutoInt useful for your research, please consider citing the following paper: 75 | ``` 76 | @article{weiping2018autoint, 77 | title={AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks}, 78 | author={Song, Weiping and Shi, Chence and Xiao, Zhiping and Duan, Zhijian and Xu, Yewen and Zhang, Ming and Tang, Jian}, 79 | journal={arXiv preprint arXiv:1810.11921}, 80 | year={2018} 81 | } 82 | ``` 83 | 84 | 85 | ## Contact information 86 | If you have questions related to the code, feel free to contact Weiping Song (`songweiping@pku.edu.cn`), Chence Shi (`chenceshi@pku.edu.cn`) and Zhijian Duan (`zjduan@pku.edu.cn`). 87 | 88 | ## License 89 | MIT 90 | 91 | ## Acknowledgement 92 | This implementation draws inspiration from Kyubyong Park's [transformer](https://github.com/Kyubyong/transformer) and Chenglong Chen's [DeepFM](https://github.com/ChenglongChen/tensorflow-DeepFM). 93 | -------------------------------------------------------------------------------- /demos/4. covid classification/covid_xray.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/demos/4. covid classification/covid_xray.jpg -------------------------------------------------------------------------------- /demos/figures/covid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/demos/figures/covid.png -------------------------------------------------------------------------------- /demos/figures/interactive.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/demos/figures/interactive.gif -------------------------------------------------------------------------------- /demos/figures/recommendation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/demos/figures/recommendation.png -------------------------------------------------------------------------------- /demos/figures/sentiment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/demos/figures/sentiment.png -------------------------------------------------------------------------------- /demos/requirements.txt: -------------------------------------------------------------------------------- 1 | ipympl==0.5.8 2 | ipywidgets==7.5.1 3 | jupyterlab==2.2.8 4 | matplotlib==3.3.2 5 | numpy==1.18.5 6 | opencv_python==4.4.0.44 7 | Pillow==6.2.2 8 | requests==2.22.0 9 | scikit-image==0.16.2 10 | scikit-learn==0.21.3 11 | scipy==1.5.2 12 | tensorflow-gpu==2.3.1 13 | torch==1.6.0 14 | torchvision==0.7.0 15 | tqdm==4.32.2 16 | transformers==2.9.0 -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | 
import argparse 2 | import requests 3 | import os 4 | import shutil 5 | 6 | 7 | def str2bool(v): 8 | if isinstance(v, bool): 9 | return v 10 | if v.lower() in ("yes", "true", "t", "y", "1"): 11 | return True 12 | elif v.lower() in ("no", "false", "f", "n", "0"): 13 | return False 14 | else: 15 | raise argparse.ArgumentTypeError("Boolean value expected.") 16 | 17 | 18 | # from this StackOverflow answer: https://stackoverflow.com/a/39225039 19 | def download_file_from_google_drive(id, destination): 20 | URL = "https://docs.google.com/uc?export=download" 21 | 22 | session = requests.Session() 23 | 24 | response = session.get(URL, params={"id": id}, stream=True) 25 | token = get_confirm_token(response) 26 | 27 | if token: 28 | params = {"id": id, "confirm": token} 29 | response = session.get(URL, params=params, stream=True) 30 | 31 | save_response_content(response, destination) 32 | 33 | # the download_warning cookie carries the token that confirms Google Drive's large-file virus-scan page 34 | def get_confirm_token(response): 35 | for key, value in response.cookies.items(): 36 | if key.startswith("download_warning"): 37 | return value 38 | 39 | return None 40 | 41 | 42 | def save_response_content(response, destination): 43 | CHUNK_SIZE = 32768 44 | 45 | with open(destination, "wb") as f: 46 | for chunk in response.iter_content(CHUNK_SIZE): 47 | if chunk: # filter out keep-alive new chunks 48 | f.write(chunk) 49 | 50 | 
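# Editor's sketch (annotation, not part of the original file): a typical call
# to the helper above; the file id is the BERT checkpoint id from
# `pretrained_model_ids` defined further down, and the destination path is
# illustrative:
#
#     download_file_from_google_drive(
#         "1sUqMqCqoZEjEuNEt6MQZQ2obJVm_r4Vt",
#         "downloads/pretrained_bert.zip",
#     )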
51 | parser = argparse.ArgumentParser(description="Download data and pretrained models.") 52 | parser.add_argument( 53 | "--quick_demo", 54 | type=str2bool, 55 | nargs="?", 56 | const=True, 57 | default=False, 58 | help="Downloads all data and pretrained models for the quick demo.", 59 | ) 60 | parser.add_argument( 61 | "--demos", 62 | type=str2bool, 63 | nargs="?", 64 | const=True, 65 | default=False, 66 | help="Downloads all data and pretrained models for the demos.", 67 | ) 68 | parser.add_argument( 69 | "--experiments", 70 | type=str2bool, 71 | nargs="?", 72 | const=True, 73 | default=False, 74 | help="Downloads all data and pretrained models for the experiments.", 75 | ) 76 | parser.add_argument( 77 | "--all", 78 | type=str2bool, 79 | nargs="?", 80 | const=True, 81 | default=False, 82 | help="Downloads all data and pretrained models.", 83 | ) 84 | parser.add_argument( 85 | "--downloads_folder", 86 | type=str, 87 | default="downloads", 88 | help="Name of the downloads folder. All default code paths point to 'downloads'.", 89 | ) 90 | 91 | 92 | args = parser.parse_args() 93 | quick_demo = args.quick_demo 94 | demos = args.demos 95 | experiments = args.experiments 96 | download_all = args.all 97 | downloads_folder = args.downloads_folder 98 | 99 | all_options = [quick_demo, demos, experiments, download_all] 100 | 101 | if not any(all_options): 102 | download_all = True 103 | 104 | if sum(all_options) > 1: 105 | raise ValueError("Cannot enable multiple options.") 106 | 107 | 108 | pretrained_model_ids = { 109 | "bert": "1sUqMqCqoZEjEuNEt6MQZQ2obJVm_r4Vt", 110 | "covid_net": "1aoZ9RTJeuAxPEMYo1ytYbmyQAJkaYIHo", 111 | "hiexpl_lm": "1hU9EmzdtL8s21PgnYxH6tSLfKSndggKf", 112 | "autoint": "1I2jzt_zLlmUB2RO3XjbHj1evmXN3CDuu", 113 | } 114 | 115 | data_ids = { 116 | "sst": "1iBrbVQrFzDfjWl-1Pv05tp-2IPeaGUFK", 117 | "avazu": "1NXzDvOFxAPMj4oAr_KgkEsIoURg0JCwH", 118 | } 119 | 120 | 121 | 122 | 123 | if quick_demo: 124 | keys = ["bert", "sst"] 125 | 126 | elif demos: 127 | keys = ["bert", "covid_net", "autoint", "sst", "avazu"] 128 | 129 | elif experiments: 130 | keys = ["bert", "hiexpl_lm", "sst"] 131 | 132 | elif download_all: 133 | keys = ["bert", "covid_net", "hiexpl_lm", "autoint", "sst", "avazu"] 134 | else: 135 | raise ValueError("no download option selected") 136 | 137 | 138 | if not os.path.exists(downloads_folder): 139 | os.makedirs(downloads_folder) 140 | 141 | subfolders = next(os.walk(downloads_folder))[1] 142 | 143 | for key in keys: 144 | if key in pretrained_model_ids: 145 | file_id = pretrained_model_ids[key] 146 | dest_id = "pretrained_" + key 147 | elif key in data_ids: 148 | file_id = data_ids[key] 149 | dest_id = key + "_data" 150 | else: 151 | raise ValueError("unrecognized download key: " + key) 152 | 153 | if dest_id in subfolders: 154 | print( 155 | "Found " 156 | + dest_id 157 | + "/ already in " 158 | + downloads_folder 159 | + "/. Skipping download for it." 160 | ) 161 | continue 162 | 163 | destination = downloads_folder + "/" + dest_id + ".zip" 164 | download_file_from_google_drive(file_id, destination) 165 | 166 | shutil.unpack_archive(destination, downloads_folder) 167 | os.remove(destination) 168 | -------------------------------------------------------------------------------- /experiments/1. archdetect/2. 
redundancy_bert.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "from transformers import *\n", 11 | "import scipy\n", 12 | "import numpy as np\n", 13 | "import sys, os\n", 14 | "from tqdm import tqdm\n", 15 | "\n", 16 | "sys.path.append(\"../../src\")\n", 17 | "from context_explainer import ContextExplainer\n", 18 | "from application_utils.text_utils import *\n", 19 | "from application_utils.text_utils_torch import BertWrapperTorch\n", 20 | "\n", 21 | "%load_ext autoreload\n", 22 | "%autoreload 2" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "save_path = \"results/bert_random_context_only.pickle\"\n", 32 | "random_context_only = True" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "## Get Model" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "task = 'sst-2'\n", 49 | "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n", 50 | "\n", 51 | "model_path = \"../../downloads/pretrained_bert\"\n", 52 | "model = BertForSequenceClassification.from_pretrained(model_path);" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 4, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "device = torch.device(\"cuda:0\")\n", 62 | "class_idx = 1\n", 63 | "model_wrapper = BertWrapperTorch(model, device)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "## Get Sentences" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 5, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "sentences = get_sst_sentences(split=\"test\", path=\"../../downloads/sst_data/sst_trees.pickle\")\n", 80 | "baseline_token = \"_\"" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 6, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "if os.path.exists(save_path):\n", 90 | " with open(save_path, 'rb') as handle:\n", 91 | " all_res = pickle.load(handle)\n", 92 | "else:\n", 93 | " all_res = []" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "## Run Experiment" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 7, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "np.random.seed(42)\n", 110 | "\n", 111 | "for s_idx, text in enumerate(tqdm(sentences)):\n", 112 | " \n", 113 | " if s_idx < len(all_res):\n", 114 | " # if an experiment is already done, skip it\n", 115 | " assert(all_res[s_idx][\"text\"] == text)\n", 116 | " print(\"skip\", s_idx)\n", 117 | " continue\n", 118 | "\n", 119 | " text_ids, baseline_ids = get_input_baseline_ids(text, baseline_token, tokenizer)\n", 120 | "\n", 121 | " xf = TextXformer(text_ids, baseline_ids) \n", 122 | " ctx = ContextExplainer(model_wrapper, data_xformer=xf, output_indices=class_idx, batch_size=20, verbose=False)\n", 123 | "\n", 124 | " context1 = ctx.input\n", 125 | " context2 = ctx.baseline\n", 126 | "\n", 127 | " n_samples = 9\n", 128 | "\n", 129 | " new_contexts = []\n", 130 | " if random_context_only:\n", 131 | " seen_contexts_tuples = []\n", 132 | " n_samples += 2\n", 133 | " else:\n", 134 | " seen_contexts_tuples = 
[tuple(context1), tuple(context2)]\n", 135 | "\n", 136 | " for n in range(n_samples):\n", 137 | " while True:\n", 138 | " context = np.random.randint(0, high=2, size=len(context1)).astype(bool)\n", 139 | " context_tuple = tuple(context)\n", 140 | " if context_tuple not in seen_contexts_tuples:\n", 141 | " break\n", 142 | " new_contexts.append(context)\n", 143 | " seen_contexts_tuples.append(context_tuple)\n", 144 | "\n", 145 | " if random_context_only:\n", 146 | " all_contexts = new_contexts\n", 147 | " else:\n", 148 | " all_contexts = [context1, context2] + new_contexts\n", 149 | "\n", 150 | " res = ctx.detect_with_running_contexts(all_contexts)\n", 151 | " \n", 152 | " all_res.append({\"text\": text, \"result\": res})\n", 153 | " \n", 154 | " if (s_idx+1) % 3 == 0: \n", 155 | " with open(save_path, 'wb') as handle:\n", 156 | " pickle.dump(all_res, handle, protocol=pickle.HIGHEST_PROTOCOL)\n", 157 | "\n", 158 | "with open(save_path, 'wb') as handle:\n", 159 | " pickle.dump(all_res, handle, protocol=pickle.HIGHEST_PROTOCOL)" 160 | ] 161 | } 162 | ], 163 | "metadata": { 164 | "kernelspec": { 165 | "display_name": "Python [conda env:test]", 166 | "language": "python", 167 | "name": "conda-env-test-py" 168 | }, 169 | "language_info": { 170 | "codemirror_mode": { 171 | "name": "ipython", 172 | "version": 3 173 | }, 174 | "file_extension": ".py", 175 | "mimetype": "text/x-python", 176 | "name": "python", 177 | "nbconvert_exporter": "python", 178 | "pygments_lexer": "ipython3", 179 | "version": "3.6.7" 180 | } 181 | }, 182 | "nbformat": 4, 183 | "nbformat_minor": 4 184 | } 185 | -------------------------------------------------------------------------------- /experiments/1. archdetect/context_explainer.py: -------------------------------------------------------------------------------- 1 | from explainer import Archipelago 2 | import copy 3 | 4 | 5 | class ContextExplainer(Archipelago): 6 | def __init__( 7 | self, 8 | model, 9 | input=None, 10 | baseline=None, 11 | data_xformer=None, 12 | output_indices=0, 13 | batch_size=20, 14 | verbose=True, 15 | ): 16 | super().__init__( 17 | model, input, baseline, data_xformer, output_indices, batch_size, verbose 18 | ) 19 | 20 | def detect_with_running_contexts(self, contexts): 21 | """ 22 | Detects interactions and sorts them 23 | Optional: gets archipelago main effects and/or pairwise effects from function reuse 24 | """ 25 | inter_scores_each_context = [] 26 | inter_scores_running = {} 27 | for n, context in enumerate(contexts): 28 | insertion_target = [] 29 | for i, c in enumerate(context): 30 | assert c in {self.baseline[i], self.input[i]} 31 | if c == self.baseline[i]: 32 | insertion_target.append(self.input[i]) 33 | else: 34 | insertion_target.append(self.baseline[i]) 35 | 36 | search = self.search_feature_sets(context, insertion_target) 37 | 38 | context_inters = search["interactions"] 39 | 40 | for pair in context_inters: 41 | if pair not in inter_scores_running: 42 | inter_scores_running[pair] = 0 43 | inter_scores_running[pair] += context_inters[pair] ** 2 44 | 45 | inter_scores = copy.deepcopy(inter_scores_running) 46 | for key in inter_scores: 47 | inter_scores[key] = inter_scores[key] / (n + 1) 48 | sorted_scores = sorted(inter_scores.items(), key=lambda kv: -kv[1]) 49 | 50 | inter_scores_each_context.append(sorted_scores) 51 | 52 | return inter_scores_each_context 53 | 54 | def detect_with_contexts(self, contexts): 55 | """ 56 | Detects interactions and sorts them 57 | Optional: gets archipelago main effects and/or pairwise effects from 
function reuse 58 | """ 59 | inters = [] 60 | for context in contexts: 61 | insertion_target = [] 62 | for i, c in enumerate(context): 63 | assert c in {self.baseline[i], self.input[i]} 64 | if c == self.baseline[i]: 65 | insertion_target.append(self.input[i]) 66 | else: 67 | insertion_target.append(self.baseline[i]) 68 | 69 | search = self.search_feature_sets(context, insertion_target) 70 | inters.append(search["interactions"]) 71 | 72 | inter_scores = {} 73 | for pair in inters[0]: 74 | avg_score = 0 75 | for inter in inters: 76 | avg_score += inter[pair] ** 2 77 | inter_scores[pair] = avg_score / len(inters) 78 | sorted_scores = sorted(inter_scores.items(), key=lambda kv: -kv[1]) 79 | 80 | output = {"interactions": sorted_scores} 81 | 82 | return output 83 | -------------------------------------------------------------------------------- /experiments/1. archdetect/redundancy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/experiments/1. archdetect/redundancy.png -------------------------------------------------------------------------------- /experiments/1. archdetect/synthetic_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import numpy as np 5 | import itertools 6 | from sklearn import metrics 7 | 8 | 9 | class synth_model: 10 | def __init__(self, test_id, input_value, base_value): 11 | self.test_id = test_id 12 | self.input_value = input_value 13 | self.base_value = base_value 14 | 15 | def and_func(self, X, inter): 16 | bool_cols = [] 17 | indices = [] 18 | for i, val in inter: 19 | bool_cols.append(X[:, i] == val) 20 | indices.append(i) 21 | bool_out = np.all(np.array(bool_cols), axis=0) 22 | gt = list(itertools.combinations(indices, 2)) 23 | return np.where(bool_out, 1, -1), gt 24 | 25 | def preprocess(self, X): 26 | Y = np.zeros(X.shape[0]) 27 | Y += X.sum(1) # main effects shared by every synthetic function 28 | p = X.shape[1] 29 | q = p // 4 30 | gts = [] 31 | return Y, p, q, gts 32 | 33 | def synth0(self, X): 34 | Y, p, q, gts = self.preprocess(X) # simple sum, no interactions 35 | return Y, gts 36 | 37 | def synth1(self, X): 38 | Y, p, q, gts = self.preprocess(X) 39 | 40 | gts = [] 41 | for i in range(q): 42 | for j in range(q): 43 | Y += X[:, i] * X[:, j] 44 | gts.append((i, j)) 45 | for i in range(q, q * 2): 46 | for j in range(q * 2, q * 3): 47 | Y += X[:, i] * X[:, j] 48 | gts.append((i, j)) 49 | return Y, gts 50 | 51 | def synth2(self, X): 52 | Y, p, q, gts = self.preprocess(X) 53 | 54 | Y1, gt1 = self.and_func(X, [(i, self.input_value) for i in range(q * 2)]) 55 | Y2, gt2 = self.and_func(X, [(i, self.input_value) for i in range(q, q * 3)]) 56 | Y += Y1 + Y2 57 | gts += gt1 + gt2 58 | return Y, gts 59 | 60 | def synth3(self, X): 61 | Y, p, q, gts = self.preprocess(X) 62 | 63 | 64 | Y1, gt1 = self.and_func(X, [(i, self.base_value) for i in range(q * 2)]) 65 | Y2, gt2 = self.and_func(X, [(i, self.input_value) for i in range(q, q * 3)]) 66 | Y += Y1 + Y2 67 | gts += gt1 + gt2 68 | return Y, gts 69 | 70 | def synth4(self, X): 71 | Y, p, q, gts = self.preprocess(X) 72 | 73 | 74 | range1 = [(i, self.input_value) for i in range(2)] 75 | range2 = [(i, self.base_value) for i in range(2, 3)] 76 | Y1, gt1 = self.and_func(X, range1 + range2) 77 | Y2, gt2 = self.and_func(X, [(i, self.input_value) for i in range(q, q * 3)]) 78 | Y += Y1 + Y2 79 | gts += gt1 + gt2 80 | return Y, gts 81 | 82 | def synth(self, X): 83 | if self.test_id == 0: 84 | Y, gts = self.synth0(X) 85 | elif self.test_id == 1: 86 | Y, gts = self.synth1(X) 87 | elif self.test_id == 2: 88 | Y, gts = self.synth2(X) 89 | elif self.test_id == 3: 90 | Y, gts = self.synth3(X) 91 | elif self.test_id == 4: 92 | Y, gts = self.synth4(X) 93 | else: 94 | raise ValueError("unknown test_id: " + str(self.test_id)) 95 | return Y, gts 96 | 97 | def get_gts(self, num_features): 98 | X = np.ones((1, num_features)) * self.input_value 99 | _, gts = self.synth(X) 100 | return gts 101 | 102 | def __call__(self, X): 103 | X = np.array(X) 104 | if len(X.shape) == 1: 105 | X = np.expand_dims(X, 0) 106 | Y, _ = self.synth(X) 107 | return np.expand_dims(Y, 1) 108 | 
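# Editor's sketch (annotation, not part of the original file): typical use of
# `synth_model`, mirroring the baseline demo notebooks; `gen_data_samples` is
# defined at the bottom of this file:
#
#     model = synth_model(test_id=1, input_value=1, base_value=-1)
#     gts = model.get_gts(40)            # ground-truth interacting pairs
#     scores = model(np.ones((5, 40)))   # outputs with shape (5, 1)
#     X, Y = gen_data_samples(model, 1, -1, p=40, n=1000, seed=0)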
109 | 110 | def get_auc(inter_scores, gts): 111 | gt_vec = [] 112 | pred_vec = [] 113 | for inter in inter_scores: 114 | 115 | 116 | pred_vec.append(inter[1]) 117 | if inter[0] in gts: 118 | gt_vec.append(1) 119 | else: 120 | gt_vec.append(0) 121 | 122 | fpr, tpr, thresholds = metrics.roc_curve(gt_vec, pred_vec, pos_label=1) 123 | auc = metrics.auc(fpr, tpr) 124 | return auc 125 | 126 | 127 | def gen_data_samples(model, input_value, base_value, p, n=30000, seed=None): 128 | if seed is not None: 129 | np.random.seed(seed) 130 | X = [] 131 | for i in range(n): 132 | X.append(np.random.choice([input_value, base_value], p)) 133 | X = np.stack(X) 134 | 135 | Y = model(X).squeeze() 136 | return X, Y 137 | -------------------------------------------------------------------------------- /experiments/2. archattribute/experiment_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import pickle 4 | import requests 5 | from PIL import Image 6 | import torch 7 | from torchvision import transforms 8 | from tqdm import tqdm 9 | 10 | 11 | def prep_imagenet_coco_conversion( 12 | coco, 13 | i1k_labels_url="https://s3.amazonaws.com/outcome-blog/imagenet/labels.json", 14 | data_dir="/meladyfs/newyork/datasets/mscoco", 15 | data_type="val2017", 16 | coco_to_i1k_path="processed_data/image_data/coco_to_i1k_map.pickle", 17 | ): 18 | # get imagenet labels 19 | i1k_labels = { 20 | int(key): value for (key, value) in requests.get(i1k_labels_url).json().items() 21 | } 22 | i1k_labels_rev = {v: k for k, v in i1k_labels.items()} 23 | 24 | # maps coco labels to imagenet labels 25 | with open(coco_to_i1k_path, "rb") as handle: 26 | coco_to_i1k_map = pickle.load(handle) 27 | 28 | # get ms coco data 29 | # display COCO categories and supercategories 30 | cats = coco.loadCats(coco.getCatIds()) 31 | cat_nms = [cat["name"] for cat in cats] 32 | # print('COCO categories: \n{}\n'.format(' '.join(cat_nms))) 33 | 34 | supercat_nms = set([cat["supercategory"] for cat in cats]) 35 | # print('COCO supercategories: \n{}'.format(' '.join(supercat_nms))) 36 | 37 | # maps a category id to its name and the name of its supercategory 38 | cat_map = {cat["id"]: (cat["name"], cat["supercategory"]) for cat in cats} 39 | 40 | # get category ids that intersect with imagenet labels 41 | valid_cats = list(coco_to_i1k_map.keys()) 42 | valid_cat_ids = coco.getCatIds(catNms=[k[0] for k in valid_cats]) 43 | 44 | # maps imagenet label indices to coco categories 45 | i1k_idx_to_cat = {} 46 | for nm in valid_cats: 47 | for i1k_label in coco_to_i1k_map[nm]: 48 | i1k_idx = i1k_labels_rev[i1k_label[0]] 49 | if i1k_idx not in i1k_idx_to_cat: 50 | i1k_idx_to_cat[i1k_idx] = set() 51 | 
i1k_idx_to_cat[i1k_idx].add(nm) 52 | 53 | return i1k_idx_to_cat, valid_cat_ids, cat_map 54 | 55 | 56 | # pytorch transformation functions for preprocessing images 57 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 58 | 59 | preprocess = transforms.Compose( 60 | [ 61 | transforms.Resize((224, 224)), 62 | transforms.ToTensor(), 63 | normalize, 64 | ] 65 | ) 66 | preprocess_mask = transforms.Compose( 67 | [ 68 | transforms.Resize((224, 224)), 69 | transforms.ToTensor(), 70 | ] 71 | ) 72 | 73 | 74 | def transform_img(image, preprocess, device): 75 | image_tensor = preprocess(Image.fromarray(image)) 76 | denom = image_tensor.abs().max().item() 77 | image = image_tensor.cpu().numpy().transpose(1, 2, 0) / denom 78 | image_tensor = image_tensor.unsqueeze_(0).to(device) / denom 79 | return image, image_tensor 80 | 81 | 82 | def match_segments_and_mask(segments, mask_orig, ratio_threshold=0.5): 83 | inter = [] 84 | for seg in np.unique(segments): 85 | seg_mask = 1 * (segments == seg) 86 | # if the original mask overlaps with > 50% of a superpixel segment, count that segment as 87 | # part of the mask 88 | ratio = (seg_mask * mask_orig[:, :, 0]).sum() / seg_mask.sum() 89 | if ratio > ratio_threshold: 90 | inter.append(seg) 91 | return inter 92 | 93 | 94 | def generate_perturbation_dataset_bert( 95 | data_inst, 96 | model, 97 | class_idx, 98 | device, 99 | num_samples=6000, 100 | batch_size=100, 101 | seed=None, 102 | model_id=None, 103 | **kwargs 104 | ): 105 | sys.path.append("../../../baselines/mahe_madex/madex/") 106 | from utils.general_utils import set_seed, proprocess_data 107 | from sampling_and_inference import generate_binary_perturbations 108 | 109 | if seed is not None: 110 | set_seed(seed) 111 | 112 | target_ids = data_inst["target"] 113 | baseline_ids = data_inst["baseline"] 114 | samples_binary = generate_binary_perturbations(len(target_ids), num_samples, True) 115 | n_batches = int(np.ceil(num_samples / batch_size)) 116 | samples_labels = [] 117 | for i in tqdm(range(n_batches)): 118 | samples_binary_batch = samples_binary[i * batch_size : (i + 1) * batch_size] 119 | perturbed_text = [] 120 | for sample_binary in samples_binary_batch: 121 | vec = target_ids.copy() 122 | vec[sample_binary == 0] = baseline_ids[sample_binary == 0] 123 | 124 | perturbed_text.append(vec) 125 | preds = ( 126 | model(torch.LongTensor(np.stack(perturbed_text)).to(device))[0] 127 | .data.cpu() 128 | .numpy() 129 | ) 130 | samples_labels.append(preds) 131 | samples_labels = np.concatenate(samples_labels) 132 | Xs, Ys = proprocess_data(samples_binary, samples_labels[:, class_idx], **kwargs) 133 | 134 | return Xs, Ys 135 | 136 | 137 | def convert_spans_to_interactions(spans): 138 | inters = [] 139 | for span in spans: 140 | inter = tuple(range(span[0], span[1] + 1)) 141 | inters.append(inter) 142 | return inters 143 | -------------------------------------------------------------------------------- /experiments/2. 
archattribute/parallel_mahe/mahe_text_correlation.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import pickle 3 | import torch 4 | from transformers import * 5 | import torch.multiprocessing as multiprocessing 6 | from itertools import repeat 7 | from tqdm import tqdm 8 | 9 | sys.path.append("..") 10 | from experiment_utils import * 11 | 12 | sys.path.append("../../../src") 13 | from application_utils.text_utils import * 14 | 15 | sys.path.append("../../../baselines/mahe_madex/madex/") 16 | sys.path.append("../../../baselines/mahe_madex/mahe/") 17 | from deps.interaction_explainer import learn_hierarchical_gam 18 | 19 | 20 | gt_file = "../processed_data/text_data/subtree_allphrase_nosentencelabel.pickle" 21 | save_path = "../analysis/results/phrase_corr_mahe.pickle" 22 | 23 | # gt_file = '../processed_data/text_data/subtree_single_token.pickle' 24 | # save_path = "../analysis/results/word_corr_mahe.pickle" 25 | 26 | model_path = "../../../downloads/pretrained_bert" 27 | num_processes = 10 28 | 29 | device = torch.device("cuda:0") 30 | mlp_device = torch.device("cuda:0") 31 | 32 | 33 | def par_experiment(index, Xd, Yd, interaction, mlp_device): 34 | # modify mlp_device to distribute device load, e.g. mlp_device = index % 2 35 | 36 | if interaction == None: 37 | interactions = [] 38 | else: 39 | interactions = [(interaction, 0)] 40 | 41 | ( 42 | interaction_contributions, 43 | univariate_contributions, 44 | prediction_scores, 45 | ) = learn_hierarchical_gam( 46 | Xd, 47 | Yd, 48 | interactions, 49 | mlp_device, 50 | weight_samples=True, 51 | hierarchy_stepsize=4, 52 | num_steps=100, 53 | hierarchical_patience=2, 54 | nepochs=100, 55 | verbose=False, 56 | early_stopping=True, 57 | stopping=False, 58 | seed=index, 59 | ) 60 | 61 | trial_results = { 62 | "inter_contribs": interaction_contributions, 63 | "uni_contribs": univariate_contributions, 64 | "pred_scores": prediction_scores, 65 | } 66 | return index, trial_results 67 | 68 | 69 | def run(): 70 | 71 | multiprocessing.set_start_method("spawn", force=True) 72 | 73 | with open(gt_file, "rb") as handle: 74 | phrase_gt_splits = pickle.load(handle) 75 | 76 | phrase_gt = phrase_gt_splits["test"] 77 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 78 | model = BertForSequenceClassification.from_pretrained(model_path).to(device) 79 | 80 | class_idx = 1 81 | 82 | if os.path.exists(save_path): 83 | with open(save_path, "rb") as handle: 84 | p_dict = pickle.load(handle) 85 | ref = p_dict["ref"] 86 | est_methods = p_dict["est"] 87 | else: 88 | ref = {} 89 | est_methods = {"mahe": {}} 90 | 91 | for s_idx, phrase_dict in enumerate(tqdm(phrase_gt)): 92 | 93 | if all((s_idx in est_methods[m]) for m in ["mahe"]) and s_idx in ref: 94 | print("skip", s_idx) 95 | continue 96 | 97 | sentence = phrase_dict["sentence"] 98 | tokens = phrase_dict["tokens"] 99 | subtrees = phrase_dict["subtrees"] 100 | att_len = len(tokens) 101 | 102 | span_to_label = {} 103 | for subtree in subtrees: 104 | span_to_label[subtree["span"]] = subtree["label"] 105 | 106 | spans = list(span_to_label.keys()) 107 | 108 | baseline_token = "_" 109 | text_ids, baseline_ids = get_input_baseline_ids( 110 | sentence, baseline_token, tokenizer 111 | ) 112 | 113 | data_inst = {"target": text_ids, "baseline": baseline_ids} 114 | 115 | inters = convert_spans_to_interactions(spans) 116 | 117 | Xs, Ys = generate_perturbation_dataset_bert( 118 | data_inst, 119 | model, 120 | class_idx, 121 | device, 122 | batch_size=15, 123 
| # num_samples = 100, 124 | # test_size=10, 125 | # valid_size=10, 126 | seed=s_idx, 127 | ) 128 | 129 | for k in Xs: 130 | Xs[k] = Xs[k][:, 1:-1] 131 | 132 | inters2 = [] 133 | i1_to_i2 = {} 134 | for i, inter in enumerate(inters): 135 | if len(inter) == 1: 136 | if None not in inters2: 137 | i1_to_i2[i] = len(inters2) 138 | non_idx = len(inters2) 139 | inters2.append(None) 140 | else: 141 | i1_to_i2[i] = non_idx 142 | else: 143 | i1_to_i2[i] = len(inters2) 144 | inters2.append(inter) 145 | 146 | with multiprocessing.Pool(processes=num_processes) as pool: 147 | results_batch = pool.starmap( 148 | par_experiment, 149 | zip( 150 | list(range(len(inters2))), 151 | repeat(Xs), 152 | repeat(Ys), 153 | inters2, 154 | repeat(mlp_device), 155 | ), 156 | ) 157 | results_dict = dict(results_batch) 158 | 159 | est_vec = [] 160 | ref_vec = [] 161 | 162 | for i, inter in enumerate(inters): 163 | label = span_to_label[spans[i]] 164 | ref_vec.append(label) 165 | rd = results_dict[i1_to_i2[i]] 166 | 167 | if len(inter) >= 2: 168 | ires = rd["inter_contribs"] 169 | est = ires[0][0][1] 170 | 171 | for j in inter: 172 | ures = rd["uni_contribs"] 173 | est += ures[1][j] 174 | else: 175 | 176 | ures = rd["uni_contribs"] 177 | est = ures[1][inter[0]] 178 | 179 | est_vec.append(est) 180 | 181 | est_methods["mahe"][s_idx] = est_vec 182 | 183 | ref[s_idx] = ref_vec 184 | 185 | with open(save_path, "wb") as handle: 186 | pickle.dump( 187 | {"est": est_methods, "ref": ref}, 188 | handle, 189 | protocol=pickle.HIGHEST_PROTOCOL, 190 | ) 191 | 192 | 193 | if __name__ == "__main__": 194 | run() 195 | -------------------------------------------------------------------------------- /experiments/2. archattribute/processed_data/image_data/coco_to_i1k_map.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/experiments/2. archattribute/processed_data/image_data/coco_to_i1k_map.pickle -------------------------------------------------------------------------------- /experiments/2. 
archattribute/processed_data/prepare_text_ground_truth.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from transformers import BertTokenizer\n", 10 | "import pickle\n", 11 | "from tqdm import tqdm\n", 12 | "import sys\n", 13 | "import unidecode\n", 14 | "\n", 15 | "sys.path.append(\"../../../src\")\n", 16 | "from application_utils.text_utils import get_token_list" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 3, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "with open('../../../downloads/sst_data/sst_trees.pickle', 'rb') as handle:\n", 35 | " sst_trees = pickle.load(handle)" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 4, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "def map_word_to_token_index(words, tokens):\n", 45 | " t=0\n", 46 | " token = tokens[t].replace(\"#\", \"\")\n", 47 | "\n", 48 | " token_to_word_map = {}\n", 49 | " word_to_token_map = {}\n", 50 | " for w, word in enumerate(words):\n", 51 | " tmp_word = str(word)\n", 52 | "\n", 53 | " i = 0\n", 54 | " while(tmp_word):\n", 55 | " tmp_word = \"\".join(list(word)[i:])\n", 56 | "\n", 57 | " if tmp_word.startswith(token): \n", 58 | " token_to_word_map[t] = w \n", 59 | " if w not in word_to_token_map:\n", 60 | " word_to_token_map[w] = []\n", 61 | " word_to_token_map[w].append(t)\n", 62 | " \n", 63 | " i += len(token)\n", 64 | " t += 1\n", 65 | " if t >= len(tokens):\n", 66 | " break\n", 67 | "\n", 68 | " token = tokens[t].replace(\"##\", \"\")\n", 69 | " else:\n", 70 | " i += 1\n", 71 | "\n", 72 | " assert(t == len(tokens))\n", 73 | " assert(w == len(words)-1)\n", 74 | " return token_to_word_map, word_to_token_map" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 5, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "name": "stderr", 84 | "output_type": "stream", 85 | "text": [ 86 | "100%|██████████| 2210/2210 [00:02<00:00, 775.08it/s]\n" 87 | ] 88 | } 89 | ], 90 | "source": [ 91 | "\n", 92 | "index = 0\n", 93 | "batch_size = 20\n", 94 | "\n", 95 | "splits = {}\n", 96 | "count = 0\n", 97 | "for split in [\"test\"]:\n", 98 | " token_trees = []\n", 99 | "\n", 100 | " for index in tqdm(range(len(sst_trees[split]))):\n", 101 | "\n", 102 | " sentence = sst_trees[split][index][0]\n", 103 | " subtrees = sst_trees[split][index][2]\n", 104 | " sen_len = len(sentence.split())\n", 105 | "\n", 106 | " tokens = get_token_list(sentence, tokenizer)[1:-1]\n", 107 | "\n", 108 | " words = unidecode.unidecode(sentence.lower()).split()\n", 109 | " try:\n", 110 | " token_to_word_map, word_to_token_map = map_word_to_token_index(words, list(tokens))\n", 111 | " except:\n", 112 | " print(tokens)\n", 113 | " print(words)\n", 114 | " assert(False)\n", 115 | "\n", 116 | " filtered_subtrees = []\n", 117 | "\n", 118 | " for subtree in subtrees:\n", 119 | " if subtree[\"phrase\"] == sentence: continue #excludes a phrase and phrase label if that phrase is the original sentence\n", 120 | " phrase_list = subtree[\"phrase\"].split()\n", 121 | " #if len(phrase_list) == 1: continue #excludes phrases that only consist of a single word before tokenization\n", 122 | " \n", 123 | " pos = 
subtree[\"position\"]\n", 124 | " \n", 125 | " phrase_span_tokenspace = ( min(word_to_token_map[pos]), max(word_to_token_map[pos + len(phrase_list)-1]))\n", 126 | " first_token_index, last_token_index = phrase_span_tokenspace\n", 127 | " \n", 128 | " #if last_token_index - first_token_index == 0: continue #excludes phrases that only consist of a single token\n", 129 | " \n", 130 | " count +=1\n", 131 | " \n", 132 | " filtered_subtrees.append({\"span\": phrase_span_tokenspace, \"label\": subtree[\"label\"], \"phrase\": subtree[\"phrase\"], \"position\": pos })\n", 133 | "\n", 134 | " token_trees.append({\"sentence\": sentence, \"tokens\": tokens, \"subtrees\": filtered_subtrees })\n", 135 | " splits[split] = token_trees\n", 136 | " \n", 137 | "splits[\"note\"] = \"the phrase spans need to be shifted right based on which methods index SEP and CLS. the spans mean (first token index, last token index)\"" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 6, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "with open('text_data/subtree_token_pairphrase_and_greater.pickle', 'wb') as handle:\n", 147 | " pickle.dump(splits, handle, protocol=pickle.HIGHEST_PROTOCOL)" 148 | ] 149 | } 150 | ], 151 | "metadata": { 152 | "kernelspec": { 153 | "display_name": "Python [conda env:test]", 154 | "language": "python", 155 | "name": "conda-env-test-py" 156 | }, 157 | "language_info": { 158 | "codemirror_mode": { 159 | "name": "ipython", 160 | "version": 3 161 | }, 162 | "file_extension": ".py", 163 | "mimetype": "text/x-python", 164 | "name": "python", 165 | "nbconvert_exporter": "python", 166 | "pygments_lexer": "ipython3", 167 | "version": "3.6.7" 168 | } 169 | }, 170 | "nbformat": 4, 171 | "nbformat_minor": 2 172 | } 173 | -------------------------------------------------------------------------------- /experiments/2. archattribute/processed_data/text_data/subtree_allphrase_nosentencelabel.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/experiments/2. archattribute/processed_data/text_data/subtree_allphrase_nosentencelabel.pickle -------------------------------------------------------------------------------- /experiments/2. archattribute/processed_data/text_data/subtree_single_token.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtsang/archipelago/8ff437e5672809827d7daa6a5656aeedbc0e1094/experiments/2. archattribute/processed_data/text_data/subtree_single_token.pickle -------------------------------------------------------------------------------- /experiments/README.md: -------------------------------------------------------------------------------- 1 | # Reproducibility 2 | 3 | Almost all experiments are run through Jupyter notebooks. 4 | 5 | If you haven't done so already, please setup experiments with the following steps: 6 | 1. Run the following command to setup experiments 7 | ```bash 8 | sh ../setup_experiments.sh 9 | ``` 10 | 2. Download the ImageNet '14 test set from [here](http://www.image-net.org/challenges/LSVRC/2014/) and place it in ```../downloads/imagenet14/```. 11 | 12 | ## 1. 
ArchDetect
13 | 
14 | Navigate to the ArchDetect experiment folder:
15 | ```bash
16 | cd 1.\ archdetect/
17 | ```
18 | 
19 | ### Table 1b
20 | 
21 | Method | F1 | F2 | F3 | F4
22 | ------------- | ------------- | ------------- | ------------- | -------------
23 | Two-way ANOVA | 1.0 | 0.51 | 0.51 | 0.55
24 | Integrated Hessians | 1.0 | N/A | N/A | N/A
25 | Neural Interaction Detection | 0.94 | 0.54 | 0.54 | 0.56
26 | Shapley Interaction Index | 1.0 | 0.50 | 0.50 | 0.51
27 | Shapley Taylor Interaction Index | 1.0 | 0.55 | 0.78 | 0.55
28 | ArchDetect | 1.0 | 1.0 | 1.0 | 1.0
29 | 
30 | 
31 | To reproduce these experiments, please use ```1. synthetic_performance.ipynb```.
32 | 
33 | ### Figure 3
34 |

35 | 36 |

37 | 
38 | To run redundancy experiments, please use ```2. redundancy_bert.ipynb``` and ```2. redundancy_resnet.ipynb```. Plotting can be done in ```2.1. redundancy_analysis_plotting.ipynb```.
39 | 
40 | ## 2. ArchAttribute
41 | 
42 | Navigate to the ArchAttribute experiment folder:
43 | ```bash
44 | cd 2.\ archattribute/
45 | ```
46 | 
47 | ### Table 2
48 | 
49 | 
50 | Method | Word Correlation | Phrase Correlation | Segment AUC
51 | ------------- | ------------- | ------------- | -------------
52 | Difference | 0.333 | 0.639 | 0.705
53 | Integrated Gradients (IG) | 0.473 | 0.737 | 0.786
54 | Integrated Hessians (IH) | N/A | 0.128 | N/A
55 | Model-Agnostic Hierarchical Explanations (MAHE) | 0.570 | 0.702 | 0.712
56 | Shapley Interaction Index (SI) | 0.160 | -0.018 | 0.530
57 | Shapley Taylor Interaction Index (STI) | 0.657 | 0.286 | 0.626
58 | Sampling Contextual Decomposition (SCD) | 0.622 | 0.742 | N/A
59 | Sampling Occlusion (SOC) | 0.670 | 0.794 | N/A
60 | ArchAttribute | 0.745 | 0.836 | 0.919
61 | 
62 | To reproduce these experiments, please use ```text_correlation.ipynb``` and ```segment_auc.ipynb```. Both Word Correlation and Phrase Correlation are evaluated in ```text_correlation.ipynb```.
63 | 
64 | To run MAHE, use the corresponding Python scripts in ```parallel_mahe/```.
65 | 
66 | To compute the final correlation and AUC scores, use the notebooks in ```analysis/```.
67 | 
--------------------------------------------------------------------------------
/experiments/requirements.txt:
--------------------------------------------------------------------------------
1 | ipympl==0.5.8
2 | ipywidgets==7.5.1
3 | jupyterlab==2.2.8
4 | matplotlib==3.3.2
5 | nltk==3.4.5
6 | numpy==1.18.5
7 | opencv_python==4.4.0.44
8 | pandas==1.1.3
9 | Pillow==6.2.2
10 | pycocotools==2.0
11 | requests==2.22.0
12 | scikit-image==0.16.2
13 | scikit-learn==0.21.3
14 | scipy==1.5.2
15 | statsmodels==0.10.0rc2
16 | tensorflow-gpu==2.3.1
17 | torch==1.2.0
18 | torchtext==0.3.1
19 | torchvision==0.4.0
20 | tqdm==4.32.2
21 | transformers==2.9.0
--------------------------------------------------------------------------------
/setup_demos.sh:
--------------------------------------------------------------------------------
1 | pip install --upgrade pip
2 | pip install -r demos/requirements.txt
3 | python download.py --demos
--------------------------------------------------------------------------------
/setup_experiments.sh:
--------------------------------------------------------------------------------
1 | pip install --upgrade pip
2 | pip install -r experiments/requirements.txt
3 | python download.py --experiments
--------------------------------------------------------------------------------
/setup_interactive_viz.sh:
--------------------------------------------------------------------------------
1 | pip install --upgrade pip
2 | pip install -r demos/requirements.txt
3 | jupyter labextension install @jupyter-widgets/jupyterlab-manager
4 | jupyter labextension install jupyter-matplotlib
5 | jupyter nbextension enable --py widgetsnbextension
6 | python download.py --quick_demo
--------------------------------------------------------------------------------
/src/application_utils/common_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | def get_efficient_mask_indices(inst, baseline, input):
5 |     invert = np.sum(1 * inst) >= len(inst) // 2  # True when most features come from the input
6 |     if invert:
7 |         context = input.copy()
8 |         insertion_target = baseline
9 |         mask_indices = np.argwhere(inst == False).flatten()
10 |     else:
11 |         context = baseline.copy()
12 |         insertion_target = input
13 |         mask_indices = np.argwhere(inst == True).flatten()
14 |     return mask_indices, context, insertion_target
15 | 
--------------------------------------------------------------------------------
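The inversion trick in `get_efficient_mask_indices` is easiest to see on a concrete mask: the function starts from whichever endpoint (input or baseline) needs fewer edits and returns only the positions to flip. The sketch below is not from the repository; the toy arrays are assumptions for illustration:

```python
# Usage sketch for get_efficient_mask_indices (defined in common_utils.py above).
import numpy as np

inst = np.array([True, True, True, False])  # keep features 0-2 from the input
baseline = np.zeros(4, dtype=int)
input_ids = np.arange(1, 5)

idx, context, target = get_efficient_mask_indices(inst, baseline, input_ids)
print(idx)      # [3] -> mostly-True mask, so start from the input and flip one slot
print(context)  # [1 2 3 4] (a copy of the input)
for i in idx:
    context[i] = target[i]
print(context)  # [1 2 3 0] -> the materialized perturbation
```

With fewer than half of the features kept, the roles reverse: the copy starts from the baseline and only the True positions are overwritten with input values, so the loop over `idx` stays short in both regimes.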
/src/application_utils/rec_utils.py:
--------------------------------------------------------------------------------
1 | from autoint.model import AutoInt
2 | from application_utils.common_utils import get_efficient_mask_indices
3 | import os
4 | import numpy as np
5 | from sklearn.metrics import roc_auc_score
6 | from tqdm import tqdm
7 | import pickle
8 | 
9 | 
10 | class AutoIntWrapper:
11 |     def __init__(self, model, Xi_inst, inv_sigmoid=True):
12 |         self.model = model
13 |         self.Xi_inst = Xi_inst
14 |         self.use_inv_sigmoid = inv_sigmoid
15 | 
16 |     def inv_sigmoid(self, y):
17 |         return np.log(y / (1 - y))
18 | 
19 |     def __call__(self, Xv):
20 |         Xi = np.repeat(self.Xi_inst, Xv.shape[0], axis=0)
21 |         pred = self.model.predict(Xi, Xv)
22 |         if self.use_inv_sigmoid:
23 |             pred = self.inv_sigmoid(pred)  # map probabilities back to logit space
24 |         return np.expand_dims(pred, 1)
25 | 
26 | 
27 | class IdXformer:
28 |     def __init__(self, input_ids, baseline_ids):
29 |         self.input = input_ids.flatten()
30 |         self.baseline = baseline_ids.flatten()
31 |         self.num_features = len(self.input)
32 | 
33 |     def efficient_xform(self, inst):
34 |         mask_indices, base, change = get_efficient_mask_indices(
35 |             inst, self.baseline, self.input
36 |         )
37 |         for i in mask_indices:
38 |             base[i] = change[i]
39 |         return base
40 | 
41 |     def __call__(self, inst):
42 |         id_list = self.efficient_xform(inst)
43 |         return id_list
44 | 
45 | 
46 | def evaluate(model, data, batch_size=1000):
47 |     num_samples = data["Xi"].shape[0]
48 |     num_batches = int(np.ceil(num_samples / batch_size))
49 | 
50 |     preds = []
51 |     for i in tqdm(range(num_batches)):
52 |         Xi_batch = data["Xi"][i * batch_size : (i + 1) * batch_size]
53 |         Xv_batch = data["Xv"][i * batch_size : (i + 1) * batch_size]
54 |         pred_batch = model.predict(Xi_batch, Xv_batch)
55 |         preds.append(pred_batch)
56 | 
57 |     y_pred = np.concatenate(preds)
58 |     y_gt = data["y"][:num_samples]
59 | 
60 |     return roc_auc_score(y_gt, y_pred)
61 | 
62 | 
63 | def get_example(data, index):
64 |     Xv_inst = data["Xv"][index : index + 1]
65 |     Xi_inst = data["Xi"][index : index + 1]
66 |     return Xv_inst, Xi_inst
67 | 
68 | 
69 | def get_autoint_and_data(
70 |     dataset=None,
71 |     data_path=None,
72 |     save_path=None,
73 |     feature_size=1544489,
74 | ):
75 |     args = parse_args(dataset, data_path, save_path)
76 | 
77 |     run_cnt = 0
78 |     model = AutoInt(args=args, feature_size=feature_size, run_cnt=run_cnt)
79 |     model.restore(args.save_path)
80 | 
81 |     with open(data_path, "rb") as handle:
82 |         data_batch = pickle.load(handle)
83 |     return model, data_batch
84 | 
85 | 
86 | def get_avazu_dict():
87 |     avazu_dict = {
88 |         0: "id: ad identifier",
89 |         1: "hour",
90 |         2: "C1",
91 |         3: "banner_pos",
92 |         4: "site_id",
93 |         5: "site_domain",
94 |         6: "site_category",
95 |         7: "app_id",
96 |         8: "app_domain",
97 |         9: "app_category",
98 |         10: "device_id",
99 |         11: "device_ip",
100 |         12: "device_model",
101 |         13: "device_type",
102 |         14: "device_conn_type",
103 |     }
104 |     for i in range(15, 23):
105 |         avazu_dict[i] = "C" + str(i - 1)
106 |     return avazu_dict
107 | 
108 | 
109 | def parse_args(dataset, data_path, save_path):
110 |     dataset = dataset.lower()
111 |     if "avazu" in dataset:
112 |         field_size = 23
113 |     elif "criteo" in dataset:
114 |         field_size = 39
115 |     else:
116 |         raise
ValueError("Invalid dataset") 117 | 118 | return get_args(save_path, field_size, dataset, data_path) 119 | 120 | 121 | def get_data_info(args): 122 | data = args.data.split("/")[-1].lower() 123 | if any([data.startswith(d) for d in ["avazu"]]): 124 | file_name = ["train_i.npy", "train_x.npy", "train_y.npy"] 125 | elif any([data.startswith(d) for d in ["criteo"]]): 126 | file_name = ["train_i.npy", "train_x2.npy", "train_y.npy"] 127 | else: 128 | raise ValueError("invalid data arg") 129 | 130 | path_prefix = os.path.join(args.data_path, args.data) 131 | return file_name, path_prefix 132 | 133 | 134 | class get_args: 135 | # the original parameter configuration of AutoInt 136 | blocks = 3 137 | block_shape = [64, 64, 64] 138 | heads = 2 139 | embedding_size = 16 140 | dropout_keep_prob = [1, 1, 1] 141 | epoch = 3 142 | batch_size = 1024 143 | learning_rate = 0.001 144 | learning_rate_wide = 0.001 145 | optimizer_type = "adam" 146 | l2_reg = 0.0 147 | random_seed = 2018 # used in the official autoint code 148 | loss_type = "logloss" 149 | verbose = 1 150 | run_times = 1 151 | is_save = False 152 | greater_is_better = False 153 | has_residual = True 154 | has_wide = False 155 | deep_layers = [400, 400] 156 | batch_norm = 0 157 | batch_norm_decay = 0.995 158 | 159 | def __init__(self, save_path, field_size, dataset, data_path): 160 | self.save_path = save_path 161 | self.field_size = field_size 162 | self.data = dataset 163 | self.data_path = data_path 164 | -------------------------------------------------------------------------------- /src/application_utils/text_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from application_utils.common_utils import get_efficient_mask_indices 3 | import pickle 4 | import copy 5 | 6 | 7 | class TextXformer: 8 | # note: this xformer is not the transformer from Vaswani et al., 2017 9 | 10 | def __init__(self, input_ids, baseline_ids): 11 | self.input = input_ids 12 | self.baseline = baseline_ids 13 | self.num_features = len(self.input) 14 | 15 | def simple_xform(self, inst): 16 | mask_indices = np.argwhere(inst == True).flatten() 17 | id_list = list(self.baseline) 18 | for i in mask_indices: 19 | id_list[i] = self.input[i] 20 | return id_list 21 | 22 | def efficient_xform(self, inst): 23 | mask_indices, base, change = get_efficient_mask_indices( 24 | inst, self.baseline, self.input 25 | ) 26 | for i in mask_indices: 27 | base[i] = change[i] 28 | return base 29 | 30 | def __call__(self, inst): 31 | id_list = self.efficient_xform(inst) 32 | return id_list 33 | 34 | 35 | def process_stop_words(explanation, tokens, strip_first_last=True): 36 | explanation = copy.deepcopy(explanation) 37 | tokens = copy.deepcopy(tokens) 38 | stop_words = set( 39 | [ 40 | "a", 41 | "an", 42 | "and", 43 | "are", 44 | "as", 45 | "at", 46 | "be", 47 | "by", 48 | "for", 49 | "from", 50 | "has", 51 | "he", 52 | "in", 53 | "is", 54 | "it", 55 | "its", 56 | "of", 57 | "on", 58 | "that", 59 | "the", 60 | "to", 61 | "was", 62 | "were", 63 | "will", 64 | "with", 65 | "s", 66 | "ll", 67 | ] 68 | ) 69 | for i, token in enumerate(tokens): 70 | if token in stop_words: 71 | if (i,) in explanation: 72 | explanation[(i,)] = 0.0 73 | 74 | if strip_first_last: 75 | explanation.pop((0,)) 76 | explanation.pop((len(tokens) - 1,)) 77 | tokens = tokens[1:-1] 78 | return explanation, tokens 79 | 80 | 81 | def get_input_baseline_ids(text, baseline_token, tokenizer): 82 | input_ids = prepare_huggingface_data([text, baseline_token], 
tokenizer)["input_ids"] 83 | text_ids = input_ids[0] 84 | baseline_ids = np.array( 85 | [input_ids[0][0]] + [input_ids[1, 1]] * (len(text_ids) - 2) + [input_ids[0][-1]] 86 | ) 87 | return text_ids, baseline_ids 88 | 89 | 90 | def get_token_list(sentence, tokenizer): 91 | if isinstance(sentence, str): 92 | X = prepare_huggingface_data([sentence], tokenizer) 93 | batch_ids = X["input_ids"] 94 | else: 95 | batch_ids = np.expand_dims(sentence, 0) 96 | token_list = [] 97 | for i in range(batch_ids.shape[0]): 98 | ids = batch_ids[i] 99 | tokens = tokenizer.convert_ids_to_tokens(ids) 100 | token_list.append(tokens) 101 | return token_list[0] 102 | 103 | 104 | def get_sst_sentences(split="test", path="../../downloads/sst_data/sst_trees.pickle"): 105 | with open(path, "rb") as handle: 106 | sst_trees = pickle.load(handle) 107 | 108 | data = [] 109 | for s in range(len(sst_trees[split])): 110 | sst_item = sst_trees[split][s] 111 | 112 | sst_item = sst_trees[split][s] 113 | sentence = sst_item[0] 114 | data.append(sentence) 115 | return data 116 | 117 | 118 | def prepare_huggingface_data(sentences, tokenizer): 119 | X = {"input_ids": [], "token_type_ids": [], "attention_mask": []} 120 | for sentence in sentences: 121 | encoded_sentence = tokenizer.encode_plus(sentence, add_special_tokens=True) 122 | for key in encoded_sentence: 123 | X[key].append(encoded_sentence[key]) 124 | 125 | assert not any(encoded_sentence["token_type_ids"]) 126 | 127 | # pad to the batch max length (auto-identified from encode_plus) 128 | batch_ids = X["input_ids"] 129 | max_len = np.max([len(ids) for ids in batch_ids]) 130 | X_pad = {} 131 | for i, ids in enumerate(batch_ids): 132 | diff = max_len - len(ids) 133 | for key in X: 134 | if key not in X_pad: 135 | X_pad[key] = [] 136 | X_pad[key].append(X[key][i] + [0] * diff) 137 | 138 | for key in X_pad: 139 | X_pad[key] = np.array(X_pad[key]) 140 | return X_pad 141 | -------------------------------------------------------------------------------- /src/application_utils/text_utils_tf.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from application_utils.text_utils import ( 3 | prepare_huggingface_data, 4 | get_input_baseline_ids, 5 | ) 6 | from transformers import glue_convert_examples_to_features 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | 11 | class BertWrapper: 12 | def __init__(self, model): 13 | self.model = model 14 | 15 | def get_predictions(self, batch_ids): 16 | X = {"input_ids": np.array(batch_ids)} 17 | batch_conf = self.model(X)[0] 18 | return batch_conf 19 | 20 | def __call__(self, batch_ids): 21 | batch_predictions = self.get_predictions(batch_ids) 22 | return batch_predictions.numpy() 23 | 24 | 25 | class BertWrapperIH: 26 | def __init__(self, model): 27 | self.model = model 28 | 29 | def embedding_model(self, batch_ids): 30 | batch_embedding = self.model.bert.embeddings((batch_ids, None, None, None)) 31 | # batch_embedding = self.model.bert.embeddings(batch_ids, None, None, None) 32 | return batch_embedding 33 | 34 | def prediction_model(self, batch_embedding, attention_mask): 35 | extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :] 36 | extended_attention_mask = tf.cast(extended_attention_mask, tf.float32) 37 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 38 | head_mask = [None] * self.model.bert.num_hidden_layers 39 | 40 | encoder_outputs = self.model.bert.encoder( 41 | [batch_embedding, extended_attention_mask, head_mask], 
/src/application_utils/text_utils_tf.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from application_utils.text_utils import (
3 |     prepare_huggingface_data,
4 |     get_input_baseline_ids,
5 | )
6 | from transformers import glue_convert_examples_to_features
7 | import numpy as np
8 | from tqdm import tqdm
9 | 
10 | 
11 | class BertWrapper:
12 |     def __init__(self, model):
13 |         self.model = model
14 | 
15 |     def get_predictions(self, batch_ids):
16 |         X = {"input_ids": np.array(batch_ids)}
17 |         batch_conf = self.model(X)[0]
18 |         return batch_conf
19 | 
20 |     def __call__(self, batch_ids):
21 |         batch_predictions = self.get_predictions(batch_ids)
22 |         return batch_predictions.numpy()
23 | 
24 | 
25 | class BertWrapperIH:
26 |     def __init__(self, model):
27 |         self.model = model
28 | 
29 |     def embedding_model(self, batch_ids):
30 |         batch_embedding = self.model.bert.embeddings((batch_ids, None, None, None))
31 |         # batch_embedding = self.model.bert.embeddings(batch_ids, None, None, None)
32 |         return batch_embedding
33 | 
34 |     def prediction_model(self, batch_embedding, attention_mask):
35 |         extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
36 |         extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
37 |         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
38 |         head_mask = [None] * self.model.bert.num_hidden_layers
39 | 
40 |         encoder_outputs = self.model.bert.encoder(
41 |             [batch_embedding, extended_attention_mask, head_mask], training=False
42 |         )
43 | 
44 |         sequence_output = encoder_outputs[0]
45 |         pooled_output = self.model.bert.pooler(sequence_output)
46 |         logits = self.model.classifier(pooled_output)
47 |         return logits
48 | 
49 |     def get_predictions(self, batch_ids):
50 |         X = {"input_ids": np.array(batch_ids)}
51 |         batch_conf = self.model(X)[0]
52 |         return batch_conf
53 | 
54 |     def get_predictions_extra(self, sentences, tokenizer, baseline_token=None):
55 |         X = prepare_huggingface_data(sentences, tokenizer)
56 | 
57 |         assert len(sentences) == 1
58 |         if baseline_token is not None:
59 |             _, baseline_ids = get_input_baseline_ids(
60 |                 sentences[0], baseline_token, tokenizer
61 |             )
62 | 
63 |         for key in X:
64 |             X[key] = tf.convert_to_tensor(X[key])
65 |         batch_ids = X["input_ids"]
66 |         attention_mask = X["attention_mask"]
67 | 
68 |         batch_conf = self.model(X)[0]
69 | 
70 |         batch_embedding = self.embedding_model(batch_ids)
71 |         batch_predictions = self.prediction_model(batch_embedding, attention_mask)
72 | 
73 |         if baseline_token is None:
74 |             batch_baseline = np.zeros((1, batch_ids.shape[1]), dtype=np.int64)
75 |         else:
76 |             batch_baseline = np.expand_dims(baseline_ids, 0)
77 |         baseline_embedding = self.embedding_model(batch_baseline)
78 | 
79 |         orig_token_list = []
80 |         for i in range(batch_ids.shape[0]):
81 |             ids = batch_ids[i].numpy()
82 |             tokens = tokenizer.convert_ids_to_tokens(ids)
83 |             orig_token_list.append(tokens)
84 | 
85 |         return (
86 |             batch_predictions,
87 |             orig_token_list,
88 |             batch_embedding,
89 |             baseline_embedding,
90 |             attention_mask,
91 |         )
92 | 
93 |     def __call__(self, batch_ids):
94 |         batch_predictions = self.get_predictions(batch_ids)
95 |         return batch_predictions.numpy()
96 | 
97 | 
98 | class DistilbertWrapperIH(BertWrapperIH):
99 |     def __init__(self, model):
100 |         super().__init__(model)
101 |         self.model = model
102 | 
103 |     def embedding_model(self, batch_ids):
104 |         batch_embedding = self.model.distilbert.embeddings(batch_ids)
105 |         return batch_embedding
106 | 
107 |     def prediction_model(self, batch_embedding, attention_mask):
108 |         # attention_mask = tf.ones(batch_embedding.shape[:2])
109 |         attention_mask = tf.cast(attention_mask, dtype=tf.float32)
110 |         head_mask = [None] * self.model.distilbert.num_hidden_layers
111 | 
112 |         transformer_output = self.model.distilbert.transformer(
113 |             [batch_embedding, attention_mask, head_mask], training=False
114 |         )[0]
115 |         pooled_output = transformer_output[:, 0]
116 |         pooled_output = self.model.pre_classifier(pooled_output)
117 |         logits = self.model.classifier(pooled_output)
118 |         return logits
119 | 
--------------------------------------------------------------------------------
/src/application_utils/text_utils_torch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | 
4 | 
5 | class BertWrapperTorch:
6 |     def __init__(self, model, device, merge_logits=False):
7 |         self.model = model.to(device)
8 |         self.device = device
9 |         self.merge_logits = merge_logits
10 | 
11 |     def get_predictions(self, batch_ids):
12 |         batch_ids = torch.LongTensor(batch_ids).to(self.device)
13 |         batch_conf = self.model(batch_ids, None, None)
14 |         # some checkpoints return a tuple whose first element holds the logits
15 |         if isinstance(batch_conf, tuple):
16 |             return batch_conf[0].data.cpu()
17 |         else:
18 |             return batch_conf.data.cpu()
19 | 
20 |     def __call__(self, batch_ids):
21 |         batch_predictions = self.get_predictions(batch_ids)
22 |         if self.merge_logits:
23 |             batch_predictions2 = (
24 |                 (batch_predictions[:, 1] - batch_predictions[:, 0]).unsqueeze(1).numpy()
25
| ) 26 | return batch_predictions2 27 | else: 28 | return batch_predictions.numpy() 29 | -------------------------------------------------------------------------------- /src/application_utils/utils_torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils import data 4 | import numpy as np 5 | import torch.optim as optim 6 | 7 | 8 | class ModelWrapperTorch: 9 | def __init__(self, model, device, input_type="image"): 10 | self.device = device 11 | self.model = model.to(device) 12 | self.input_type = input_type 13 | 14 | def __call__(self, X): 15 | if self.input_type == "text": 16 | X = torch.LongTensor(X).to(self.device) 17 | preds = self.model(X)[0].data.cpu().numpy() 18 | else: 19 | X = torch.FloatTensor(X).to(self.device) 20 | if self.input_type == "image": 21 | X = X.permute(0, 3, 1, 2) 22 | preds = self.model(X).data.cpu().numpy() 23 | return preds 24 | -------------------------------------------------------------------------------- /src/viz/colors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib as mpl 3 | 4 | 5 | def pos_neg_colors(get_rgb=False): 6 | """ 7 | Based on the Integrated Hessians Code Repo 8 | """ 9 | color_map_size = 256 10 | vals = np.ones((color_map_size, 4)) 11 | 12 | # colors chosen for color-blind 13 | rgb_a = (210, 110, 105) 14 | rgb_b = (103, 169, 207) 15 | 16 | vals[: int(color_map_size / 2), 0] = np.linspace( 17 | rgb_a[0] / 256, 1.0, int(color_map_size / 2) 18 | ) 19 | vals[: int(color_map_size / 2), 1] = np.linspace( 20 | rgb_a[1] / 256, 1.0, int(color_map_size / 2) 21 | ) 22 | vals[: int(color_map_size / 2), 2] = np.linspace( 23 | rgb_a[2] / 256, 1.0, int(color_map_size / 2) 24 | ) 25 | 26 | vals[int(color_map_size / 2) :, 0] = np.linspace( 27 | 1.0, rgb_b[0] / 256, int(color_map_size / 2) 28 | ) 29 | vals[int(color_map_size / 2) :, 1] = np.linspace( 30 | 1.0, rgb_b[1] / 256, int(color_map_size / 2) 31 | ) 32 | vals[int(color_map_size / 2) :, 2] = np.linspace( 33 | 1.0, rgb_b[2] / 256, int(color_map_size / 2) 34 | ) 35 | cmap = mpl.colors.ListedColormap(vals) 36 | if get_rgb: 37 | return rgb_a, rgb_b 38 | else: 39 | return cmap 40 | -------------------------------------------------------------------------------- /src/viz/rec.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import matplotlib as mpl 4 | import viz.colors as colors 5 | from textwrap import wrap 6 | 7 | 8 | def viz_bar_chart( 9 | data, 10 | top_k=5, 11 | figsize=(10, 4), 12 | save_file=None, 13 | max_label_size=100, 14 | remove_space=False, 15 | y_label="Feature Set", 16 | sort_again=True, 17 | bounds=None, 18 | **kwargs 19 | ): 20 | 21 | feature_labels, attributions = zip(*data) 22 | if top_k > len(attributions): 23 | top_k = len(attributions) 24 | 25 | args = np.argsort(-1 * np.abs(attributions)) 26 | args = args[:top_k] 27 | 28 | args2 = np.argsort(np.array(attributions)[args]) 29 | feature_labels = np.array(feature_labels)[args] 30 | attributions = np.array(attributions)[args] 31 | 32 | if sort_again: 33 | feature_labels = feature_labels[args2] 34 | attributions = attributions[args2] 35 | 36 | fig, axis = plt.subplots(figsize=figsize) 37 | 38 | if bounds is None: 39 | bounds = np.max(np.abs(attributions)) 40 | normalizer = mpl.colors.Normalize(vmin=-bounds, vmax=bounds) 41 | 42 | if "cmap" in kwargs: 43 | cmap = 
kwargs["cmap"] 44 | else: 45 | cmap = colors.pos_neg_colors() 46 | 47 | axis.barh( 48 | np.arange(top_k), 49 | attributions, 50 | color=[cmap(normalizer(c)) for c in attributions], 51 | align="center", 52 | zorder=10, 53 | **kwargs 54 | ) 55 | 56 | if not sort_again: 57 | axis.invert_yaxis() 58 | 59 | axis.set_xlabel("Attribution", fontsize=18) 60 | axis.set_ylabel(y_label, fontsize=18) 61 | axis.set_yticks(np.arange(top_k)) 62 | axis.tick_params(axis="y", which="both", left=False, labelsize=14) 63 | axis.tick_params(axis="x", which="both", left=False, labelsize=14) 64 | 65 | if remove_space: 66 | token = " " 67 | else: 68 | token = "" 69 | 70 | axis.set_yticklabels( 71 | ["\n".join(wrap(y, max_label_size)).replace(token, "") for y in feature_labels] 72 | ) 73 | 74 | axis.grid(axis="x", zorder=0, linewidth=0.2) 75 | axis.grid(axis="y", zorder=0, linestyle="--", linewidth=1.0) 76 | _set_axis_config(axis, linewidths=(0.0, 0.0, 0.0, 1.0)) 77 | if save_file is not None: 78 | plt.savefig(save_file, bbox_inches="tight") 79 | 80 | 81 | def _set_axis_config( 82 | axis, linewidths=(0.0, 0.0, 0.0, 0.0), clear_y_ticks=False, clear_x_ticks=False 83 | ): 84 | """ 85 | Source: Integrated Hessians Code Repo 86 | """ 87 | axis.spines["right"].set_linewidth(linewidths[0]) 88 | axis.spines["top"].set_linewidth(linewidths[1]) 89 | axis.spines["left"].set_linewidth(linewidths[2]) 90 | axis.spines["bottom"].set_linewidth(linewidths[3]) 91 | if clear_x_ticks: 92 | axis.set_xticks([]) 93 | if clear_y_ticks: 94 | axis.set_yticks([]) 95 | --------------------------------------------------------------------------------