├── requirements.txt ├── environment.yml ├── bash ├── evaluate_folder_miso_synt.sh ├── evaluate_folder_miso-ita-synt.sh ├── test_folder.sh ├── evaluate_folder_madlibs_pattern.sh ├── train_model_10_seeds.sh └── train_model_EAR_10_seeds.sh ├── config.py ├── setup.py ├── LICENSE.md ├── custom_callbacks.py ├── ear └── __init__.py ├── .gitignore ├── README.md ├── evaluate_model.py ├── dataset.py ├── metrics.py ├── train_bert.py ├── utils.py ├── term_extraction.ipynb └── data └── mlma_dev.tsv /requirements.txt: -------------------------------------------------------------------------------- 1 | -e . 2 | transformers 3 | datasets 4 | accelerate 5 | pytorch-lightning 6 | ipython 7 | comet_ml 8 | scikit-learn -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: unbias 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | dependencies: 6 | - pytorch=1.7.1 7 | - cudatoolkit=11.0.* 8 | - pandas 9 | - ipython 10 | - pip 11 | - pip: 12 | - transformers 13 | - pytorch-lightning==1.2.* 14 | - click 15 | - comet_ml 16 | - gdown 17 | - scikit-learn 18 | - tensorflow 19 | - seaborn 20 | -------------------------------------------------------------------------------- /bash/evaluate_folder_miso_synt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # To make everything work as expected, run from the parent directory. 
4 | # args: [model_directory] [out_folder] 5 | 6 | if test $# -ne 2; then 7 | echo "Usage: ./evaluate_folder_miso_synt.sh MODEL_DIR OUT_FOLDER" 8 | exit 1 9 | fi 10 | 11 | modeldir=$1 12 | 13 | for m in "$modeldir"/*; do 14 | python evaluate_model.py --dataset miso_synt_test --model_path "$m" \ 15 | --subgroups_path ./data/miso_it.txt \ 16 | --out_folder "$2" 17 | done 18 | -------------------------------------------------------------------------------- /bash/evaluate_folder_miso-ita-synt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # To make everything work as expected, run from the parent directory. 4 | 5 | if test $# -ne 3; then 6 | echo "Usage: ./evaluate_folder_miso-ita-synt.sh MODEL_DIR OUT_FOLDER SRC_TOKENIZER" 7 | exit 1 8 | fi 9 | 10 | modeldir=$1 11 | out_folder=$2 12 | src_tokenizer=$3 13 | 14 | for m in "$modeldir"/*; do 15 | python evaluate_model.py \ 16 | --dataset miso-ita-synt \ 17 | --model_path "$m" \ 18 | --subgroups_path ./data/AMI2020_test_identityterms.txt \ 19 | --out_folder "$out_folder" \ 20 | --src_tokenizer "$src_tokenizer" \ 21 | --n_jobs 8 22 | done 23 | -------------------------------------------------------------------------------- /bash/test_folder.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # To make everything work as expected, run from the parent directory.
4 | # args: in_dir out_dir dataset src_tokenizer ckpt_pattern 5 | 6 | if test $# -ne 5; then 7 | echo "Usage: ./test_folder.sh IN_DIR OUT_DIR DATASET SRC_TOKENIZER CKPT_PATTERN" 8 | exit 1 9 | fi 10 | 11 | 12 | in_dir=$1 13 | out_dir=$2 14 | dataset=$3 15 | src_tokenizer=$4 16 | ckpt_pattern=$5 17 | 18 | 19 | for m in "$in_dir"/*; do 20 | echo "--------------------" 21 | echo "Evaluating $m" 22 | echo "--------------------" 23 | 24 | python evaluate_model.py \ 25 | --dataset "${dataset}" \ 26 | --out_folder "${out_dir}" \ 27 | --model_path "${m}" \ 28 | --no_bias_metrics \ 29 | --ckpt_pattern "${ckpt_pattern}" \ 30 | --src_tokenizer "${src_tokenizer}" 31 | done 32 | -------------------------------------------------------------------------------- /bash/evaluate_folder_madlibs_pattern.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # To make everything work as expected, run from the parent directory. 4 | 5 | if test $# -ne 3; then 6 | echo "Usage: ./evaluate_folder_madlibs_pattern.sh MODEL_DIR OUT_FOLDER CKPT_PATTERN" 7 | exit 1 8 | fi 9 | 10 | modeldir=$1 11 | out_folder=$2 12 | ckpt_pattern=$3 13 | 14 | # Test on Madlibs 15 | # add madlibs89k in the next instruction if needed 16 | for d in madlibs77k; do 17 | for m in "$modeldir"/*; do 18 | python evaluate_model.py \ 19 | --dataset "$d" \ 20 | --model_path "$m" \ 21 | --subgroups_path ./data/bias_madlibs_data/adjectives_people.txt \ 22 | --out_folder "$out_folder" \ 23 | --n_jobs 8 \ 24 | --src_tokenizer bert-base-uncased \ 25 | --ckpt_pattern "$ckpt_pattern" 26 | 27 | done 28 | done 29 | -------------------------------------------------------------------------------- /bash/train_model_10_seeds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if test $# -ne 3; then 4 | echo "Usage: train_model_10_seeds.sh SRC_MODEL OUTPUT_DIR TRAINING_DATASET" 5 | exit 1 6 | fi 7 | 8 | src_model=$1 9 | output_dir=$2 10 | training_dataset=$3 11 | 12 | echo "Running with: ${src_model} ${output_dir} ${training_dataset}" 13 | 14 | for ((i = 0 ; i < 10 ; i++)); do 15 | 16 | echo
"Seed: $i" 17 | 18 | python train_bert.py \ 19 | --src_model "${src_model}" \ 20 | --output_dir "${output_dir}" \ 21 | --training_dataset "${training_dataset}" \ 22 | --max_epochs 20 \ 23 | --batch_size 64 \ 24 | --max_seq_length 120 \ 25 | --gpus 1 \ 26 | --num_workers 8 \ 27 | --learning_rate 2e-5 \ 28 | --early_stop_epochs 5 \ 29 | --seed $i \ 30 | --warmup_train_perc 0.1 \ 31 | --weight_decay 0.01 32 | 33 | done 34 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | """Main module.""" 2 | 3 | # Define here main default values 4 | 5 | DEFAULT_OUT_DIR = "./runs" 6 | DEFAULT_EMBEDDINGS_PATH = "./data/glove.6B/glove.6B.100d.txt" 7 | DEFAULT_SAVE_DIR = "./dumps" 8 | DEFAULT_OUT_DIR = "./runs" 9 | DEFAULT_HPARAMS_CNN = { 10 | "max_sequence_length": 120, #  original 250 11 | "max_num_words": 32000, #  original 10000 12 | "embedding_dim": 100, 13 | "embedding_trainable": False, 14 | "learning_rate": 0.00005, 15 | "stop_early": True, 16 | "es_patience": 5, # Only relevant if STOP_EARLY = True, original: 1 17 | "es_min_delta": 0, # Only relevant if STOP_EARLY = True 18 | "batch_size": 64, #  original 128 19 | "epochs": 30, #  original 20 20 | "dropout_rate": 0.3, 21 | "cnn_filter_sizes": [128, 128, 128], 22 | "cnn_kernel_sizes": [5, 5, 5], 23 | "cnn_pooling_sizes": [5, 5, 40], 24 | "verbose": True, 25 | } 26 | -------------------------------------------------------------------------------- /bash/train_model_EAR_10_seeds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if test $# -ne 3; then 4 | echo "Usage: train_model_EAR_10_seeds.sh SRC_MODEL OUTPUT_DIR TRAINING_DATASET" 5 | exit 1 6 | fi 7 | 8 | src_model=$1 9 | output_dir=$2 10 | training_dataset=$3 11 | 12 | echo "Running with: ${src_model} ${output_dir} ${training_dataset}" 13 | 14 | for ((i = 0 ; i < 10 ; i++)); do 15 | 16 | echo "Seed: $i" 17 | 18 | python train_bert.py \
19 | --src_model ${src_model} \ 20 | --output_dir ${output_dir} \ 21 | --training_dataset ${training_dataset} \ 22 | --max_epochs 20 \ 23 | --batch_size 64 \ 24 | --max_seq_length 120 \ 25 | --gpus 1 \ 26 | --num_workers 8 \ 27 | --learning_rate 2e-5 \ 28 | --early_stop_epochs -1 \ 29 | --seed $i \ 30 | --regularization entropy \ 31 | --reg_strength 0.01 \ 32 | --warmup_train_perc 0.1 \ 33 | --weight_decay 0.01 34 | 35 | done 36 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from setuptools import setup 3 | 4 | # The directory containing this file 5 | HERE = pathlib.Path(__file__).parent 6 | 7 | # The text of the README file 8 | README = (HERE / "README.md").read_text() 9 | 10 | DEPENDENCIES = ["transformers"] 11 | 12 | # This call to setup() does all the work 13 | setup( 14 | name="ear-transformers", 15 | version="1.0.0", 16 | description="Entrop-based Attention Regularization", 17 | long_description=README, 18 | long_description_content_type="text/markdown", 19 | url="https://github.com/g8a9/ear", 20 | author="Giuseppe Attanasio", 21 | author_email="giuseppeattanasio6@gmail.com", 22 | license="MIT", 23 | classifiers=[ 24 | "License :: OSI Approved :: MIT License", 25 | "Programming Language :: Python :: 3", 26 | "Programming Language :: Python :: 3.7", 27 | ], 28 | packages=["ear"], 29 | include_package_data=True, 30 | install_requires=DEPENDENCIES, 31 | ) -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Giuseppe Attanasio 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without 
limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /custom_callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytorch_lightning as pl 3 | import logging 4 | 5 | logging.basicConfig(format="%(asctime)s:%(module)s:%(message)s", level=logging.INFO) 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class CheckpointEveryNEpochs(pl.Callback): 10 | """ 11 | Save a checkpoint every N epochs, instead of Lightning's default that checkpoints 12 | based on validation loss. 
13 | """ 14 | 15 | def __init__(self, save_epoch_frequency: int, skip_first: bool = True): 16 | self.save_epoch_frequency = save_epoch_frequency 17 | self.skip_first = skip_first 18 | 19 | def on_validation_epoch_end(self, trainer, pl_module): 20 | """ Check if we should save a checkpoint after every train batch """ 21 | epoch = trainer.current_epoch 22 | if epoch == 0 and self.skip_first: 23 | return 24 | 25 | # global_step = trainer.global_step 26 | if (epoch + 1) % self.save_epoch_frequency == 0: 27 | logger.info(f"Dumping checkpoint at epoch {epoch}") 28 | metrics = trainer.logged_metrics 29 | filename = f"PL-epoch={epoch}-val_loss={metrics['val_loss']:.3f}-val_reg_loss={metrics['val_reg_loss']:.3f}.ckpt" 30 | ckpt_path = os.path.join(trainer.checkpoint_callback.dirpath, filename) 31 | trainer.save_checkpoint(ckpt_path) 32 | -------------------------------------------------------------------------------- /ear/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForSequenceClassification 3 | from collections import namedtuple 4 | 5 | 6 | def compute_negative_entropy( 7 | inputs: tuple, attention_mask: torch.torch, return_values: bool = False 8 | ): 9 | """Compute the negative entropy across layers of a network for given inputs. 10 | 11 | Args: 12 | - input: tuple. Tuple of length num_layers. Each item should be in the form: BHSS 13 | - attention_mask. 
Tensor with dim: BS 14 | """ 15 | inputs = torch.stack(inputs) #  LayersBatchHeadsSeqlenSeqlen 16 | assert inputs.ndim == 5, "Here we expect 5 dimensions in the form LBHSS" 17 | 18 | #  average over attention heads 19 | pool_heads = inputs.mean(2) 20 | 21 | batch_size = pool_heads.shape[1] 22 | samples_entropy = list() 23 | neg_entropies = list() 24 | for b in range(batch_size): 25 | #  get inputs from non-padded tokens of the current sample 26 | mask = attention_mask[b] 27 | sample = pool_heads[:, b, mask.bool(), :] 28 | sample = sample[:, :, mask.bool()] 29 | 30 | #  get the negative entropy for each non-padded token 31 | neg_entropy = (sample.softmax(-1) * sample.log_softmax(-1)).sum(-1) 32 | if return_values: 33 | neg_entropies.append(neg_entropy.detach()) 34 | 35 | #  get the "average entropy" that traverses the layer 36 | mean_entropy = neg_entropy.mean(-1) 37 | 38 | #  store the sum across all the layers 39 | samples_entropy.append(mean_entropy.sum(0)) 40 | 41 | # average over the batch 42 | final_entropy = torch.stack(samples_entropy).mean() 43 | if return_values: 44 | return final_entropy, neg_entropies 45 | else: 46 | return final_entropy 47 | 48 | 49 | EARClassificationOutput = namedtuple( 50 | "EARClassificationOutput", 51 | ["model_output", "negative_entropy", "reg_loss", "loss"] 52 | ) 53 | 54 | 55 | 56 | class EARModelForSequenceClassification(torch.nn.Module): 57 | 58 | def __init__(self, model_name_or_path, ear_reg_strength: float = 0.01, model_kwargs={}): 59 | super().__init__() 60 | 61 | self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, **model_kwargs) 62 | self.ear_reg_strength = ear_reg_strength 63 | 64 | def forward(self, **model_kwargs): 65 | output = self.model(**model_kwargs, output_attentions=True) 66 | 67 | negative_entropy = compute_negative_entropy( 68 | output.attentions, model_kwargs["attention_mask"] 69 | ) 70 | reg_loss = self.ear_reg_strength * negative_entropy 71 | loss = reg_loss + output.loss 72 
| 73 | return EARClassificationOutput( 74 | model_output=output, 75 | negative_entropy=negative_entropy, 76 | reg_loss=reg_loss, 77 | loss=loss 78 | ) 79 | 80 | def save_pretrained(self, *args, **kwargs): 81 | self.model.save_pretrained(*args, **kwargs) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Entropy-based Attention Regularization 👂 2 | 3 | EAR is a regularization technique to mitigate unintended bias while reducing lexical overfitting. It is based on attention entropy maximization. In practice, EAR adds a regularization term at training time to learn tokens with maximal self-attention entropy. 4 | 5 | See the paper for additional details: 6 | 7 | *Attanasio, G., Nozza, D., Hovy, D., & Baralis, E.
"Entropy-based Attention Regularization Frees Unintended Bias Mitigation from Lists". In Findings of the Association for Computational Linguistics: ACL 2022. Association for Computational Linguistics, 2022.* 8 | 9 | ### Quick links 10 | 11 | - ACL Anthology bibkey: `attanasio-etal-2022-entropy` 12 | - ACL Anthology: https://aclanthology.org/2022.findings-acl.88/ 13 | - Preprint: https://arxiv.org/abs/2203.09192 14 | 15 | ## Project structure 16 | 17 | The data used in this study is in `data`. Please note that we are not allowed to distribute all the data sets. For some of them, you will need to download the data yourself (instructions below). 18 | The code is organized in Python scripts (training and evaluation of models), bash scripts to run experiments, and Jupyter notebooks. 19 | 20 | The main files are the following: 21 | - `train_bert.py`: use this script to train any BERT-based model starting from HuggingFace checkpoints. 22 | - `evaluate_model.py`: use this script to evaluate a model either on a test set or a synthetic evaluation set. 23 | 24 | You can list all the accepted parameters of each script by running `python SCRIPT_NAME.py --help`. 25 | 26 | ## Getting started 27 | 28 | The following are the basic steps to set up our environment and replicate our results. 29 | 30 | ## Getting the data sets 31 | 32 | Please follow these instructions to retrieve the data sets used in the paper: 33 | 34 | - Misogyny (EN): the dataset is not publicly available. Please fill in [this form](https://docs.google.com/forms/d/e/1FAIpQLSevs4Ji3dNmK5CxyulYG-PxX3U10-RgDrPpMKPRjtI81f0yaQ/viewform) to submit a request to the authors. 35 | - Misogyny (IT): the dataset is not publicly available. Please fill in [this form](https://forms.gle/uFF3sAtMMqayiDiz9) to submit a request to the authors. 36 | - Multilingual and Multi-Aspect (MlMA): the dataset is available online. In `data`, we provide our split files with the additional binary "hate" column used in our experiments.
37 | 38 | For the sake of simplicity, we have assigned short names to each data set. Please find them and how to use them in [dataset.py](./dataset.py). 39 | 40 | ## Dependencies 41 | 42 | You'll need a working Python environment to run the code. 43 | The required dependencies are specified in the file `environment.yml`. 44 | We use `conda` virtual environments to manage the project dependencies in 45 | isolation. 46 | 47 | Run the following commands in the repository folder to create a separate environment 48 | and install all required dependencies in it: 49 | 50 | conda create -n ear python==3.8 51 | conda activate ear 52 | pip install -r requirements.txt 53 | 54 | ## Example 55 | 56 | EAR can be easily plugged into HuggingFace models. 57 | 58 | ```python 59 | import torch 60 | from transformers import AutoTokenizer, AutoModelForSequenceClassification 61 | import ear 62 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 63 | model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased") 64 | 65 | item = tokenizer("Today it's a good day!", return_tensors="pt") 66 | outputs = model(**item, labels=torch.tensor([0]), output_attentions=True) 67 | 68 | reg_strength = 0.01 69 | neg_entropy = ear.compute_negative_entropy( 70 | inputs=outputs.attentions, 71 | attention_mask=item["attention_mask"] 72 | ) 73 | reg_loss = reg_strength * neg_entropy 74 | loss = reg_loss + outputs.loss 75 | 76 | ``` 77 | 78 | ## Reproducing Hate Speech Detection results 79 | 80 | The [`bash`](bash) folder contains some utility bash scripts useful to run multiple experiments sequentially. They cover the training and evaluation pipeline of all the models tested in the paper. For everything to work as expected, please run them from the parent directory. 81 | 82 | ### Training 83 | 84 | Please check your available disk space: these scripts save two model checkpoints (the best and the last one) for every seed.
85 | 86 | Train **BERT** on the Misogyny (EN) dataset: 87 | 88 | ```bash 89 | ./bash/train_model_10_seeds.sh bert-base-uncased OUTPUT_DIR TRAINING_DATASET 90 | ``` 91 | 92 | e.g., `./bash/train_model_10_seeds.sh bert-base-uncased . miso` 93 | 94 | 95 | Train **BERT+EAR** on the Multilingual and Multi-Aspect dataset: 96 | 97 | ```bash 98 | ./bash/train_model_EAR_10_seeds.sh bert-base-uncased OUTPUT_DIR TRAINING_DATASET 99 | ``` 100 | 101 | e.g., `./bash/train_model_EAR_10_seeds.sh bert-base-uncased . mlma` 102 | 103 | 104 | Note that: 105 | - if you want to take class imbalance into account, add the `--balanced_loss` flag to the arguments that the bash script passes to `train_bert.py`; 106 | - for [BERT+SOC](https://github.com/BrendanKennedy/contextualizing-hate-speech-models-with-explanations) (Kennedy et al. 2020), we re-use the authors' implementation. Therefore, no 107 | training scripts are provided here. 108 | 109 | ## Testing 110 | 111 | To evaluate a model, or a folder with several models (different seeds), you have to: 112 | 1. run the evaluation on synthetic data; 113 | 2. run the evaluation on test data. 114 | 115 | ### Evaluation of bias metrics on synthetic data 116 | 117 | Here we provide an example to run the evaluation on Madlibs77K synthetic data using a specific checkpoint name (`last.ckpt` in this case). 118 | 119 | ```bash 120 | ./bash/evaluate_folder_madlibs_pattern.sh MODEL_DIR OUT_FOLDER last.ckpt 121 | ``` 122 | 123 | Analogous scripts for the other synthetic sets are stored in the folder `./bash`. Namely: 124 | - `evaluate_folder_miso_synt.sh`: runs the evaluation of all the models within a specified parent directory on the Misogyny (EN) synthetic set. 125 | - `evaluate_folder_miso-ita-synt.sh`: runs the evaluation of all the models within a specified parent directory on the Misogyny (IT) synthetic set. 126 | 127 | ### Evaluation on test data 128 | 129 | Here we provide an example to run the evaluation on the test set of MlMA.
130 | 131 | ```bash 132 | ./bash/test_folder.sh mlma 133 | ``` 134 | Note that evaluation on Misogyny (IT) requires the parameter `--src_tokenizer dbmdz/bert-base-italian-uncased` 135 | 136 | ## EAR for Biased Term Extraction 137 | 138 | We provide a Jupyter Notebook where we show how to extract terms with the lowest contextualization, which 139 | may induce most of the bias in the model. 140 | 141 | After having trained at least one model (i.e., you have a model checkpoint), the notebook [`term_extraction.ipynb`](term_extraction.ipynb) will guide you through the discovery of biased terms. 142 | 143 | ## References 144 | 145 | Please use the following bibtex entry if you use this model in your project: 146 | 147 | ```bib 148 | @inproceedings{attanasio-etal-2022-entropy, 149 | title = "Entropy-based Attention Regularization Frees Unintended Bias Mitigation from Lists", 150 | author = "Attanasio, Giuseppe and 151 | Nozza, Debora and 152 | Hovy, Dirk and 153 | Baralis, Elena", 154 | booktitle = "Findings of the Association for Computational Linguistics: ACL 2022", 155 | month = may, 156 | year = "2022", 157 | address = "Dublin, Ireland", 158 | publisher = "Association for Computational Linguistics", 159 | url = "https://aclanthology.org/2022.findings-acl.88", 160 | pages = "1105--1119", 161 | abstract = "Natural Language Processing (NLP) models risk overfitting to specific terms in the training data, thereby reducing their performance, fairness, and generalizability. E.g., neural hate speech detection models are strongly influenced by identity terms like gay, or women, resulting in false positives, severe unintended bias, and lower performance.Most mitigation techniques use lists of identity terms or samples from the target domain during training. 
However, this approach requires a-priori knowledge and introduces further bias if important terms are neglected.Instead, we propose a knowledge-free Entropy-based Attention Regularization (EAR) to discourage overfitting to training-specific terms. An additional objective function penalizes tokens with low self-attention entropy.We fine-tune BERT via EAR: the resulting model matches or exceeds state-of-the-art performance for hate speech classification and bias metrics on three benchmark corpora in English and Italian.EAR also reveals overfitting terms, i.e., terms most likely to induce bias, to help identify their effect on the model, task, and predictions.", 162 | } 163 | 164 | ``` 165 | 166 | ### 🚨 Ethical considerations 167 | 168 | The process of building the list remains a data-driven approach, which is strongly dependent on the task, collected corpus, term frequencies, and the chosen model. 169 | Therefore, the list might either lack specific terms that instead need to be attentioned, or include some that do not strictly perpetrate harm. 170 | Because of these twin issues, the resulting lists should not be read as complete or absolute. We would therefore discourage users from simply building and developing models based solely on the extracted terms. We want, instead, the terms to stand as a starting point for debugging and searching for potential bias issues in the task at hand. 171 | 172 | ## License 173 | 174 | All source code is made available under a MIT license. See `LICENSE.md` for the full license text. 175 | 176 | The manuscript text is not open source. The authors reserve the rights to the article content, which is currently submitted for publication. 177 | -------------------------------------------------------------------------------- /evaluate_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluate a given model on a given dataset. 
3 | """ 4 | import click 5 | 6 | from comet_ml import Experiment 7 | 8 | from config import DEFAULT_OUT_DIR 9 | from dataset import get_dataset_by_name, AVAIL_DATASETS, MLMA_RAW_DATASETS 10 | import metrics 11 | import pandas as pd 12 | import logging 13 | import numpy as np 14 | import os 15 | from tqdm import tqdm 16 | from joblib import Parallel, delayed 17 | import torch 18 | import glob 19 | 20 | logging.basicConfig( 21 | format="%(levelname)s:%(asctime)s:%(module)s:%(message)s", level=logging.INFO 22 | ) 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | @click.command() 27 | @click.option("--dataset", type=click.Choice(AVAIL_DATASETS), required=True) 28 | @click.option("--model_path", type=str) 29 | @click.option( 30 | "--subgroups_path", type=str, help="Path to the subgroups file.", default=None 31 | ) 32 | @click.option( 33 | "--n_jobs", 34 | type=int, 35 | help="Used to parallelize the evaluation of bias metrics", 36 | default=4, 37 | ) 38 | @click.option("--cpu_only", is_flag=True) 39 | @click.option("--no_bias_metrics", is_flag=True) 40 | @click.option("--model_suffix", type=str) 41 | @click.option("--out_folder", type=str, default=DEFAULT_OUT_DIR) 42 | @click.option("--log_comet", is_flag=True) 43 | @click.option("--ckpt_pattern", type=str, default=None) 44 | @click.option("--src_tokenizer", type=str, default=None) 45 | @click.option("--src_model", type=str, default=None) 46 | def evaluate( 47 | dataset, 48 | model_path, 49 | subgroups_path, 50 | n_jobs, 51 | cpu_only, 52 | no_bias_metrics, 53 | model_suffix, 54 | out_folder, 55 | log_comet, 56 | ckpt_pattern, 57 | src_tokenizer, 58 | src_model, 59 | ): 60 | os.makedirs(out_folder, exist_ok=True) 61 | 62 | hparams = locals() 63 | 64 | if src_model: 65 | logger.info(f"Using model {src_model}") 66 | model_name = src_model.split("/")[1] 67 | model_path = src_model 68 | else: 69 | model_name = os.path.basename(model_path) 70 | 71 | # lm_ws is created by Kennedy's simulation but shouldn't be used 72 | 
if model_name.startswith("lm_ws"): 73 | logger.info(f"Skipping the model {model_name}...") 74 | return 75 | 76 | if model_suffix: 77 | model_name = f"{model_name}-{model_suffix}" 78 | 79 | hparams["model"] = model_name 80 | 81 | if log_comet: 82 | experiment = Experiment( 83 | api_key=os.environ["COMET_API_KEY"], 84 | project_name="unbias-text-classifiers", 85 | log_code=False, 86 | log_graph=False, 87 | ) 88 | experiment.set_name(f"evaluate_{dataset}") 89 | experiment.log_parameters(hparams) 90 | experiment.add_tag("evaluation") 91 | 92 | logger.info(f"BEGIN evaluating {model_name} on {dataset}") 93 | 94 | #  Get dataset splits. Discard train and dev 95 | _, _, test = get_dataset_by_name(dataset) 96 | if log_comet: 97 | experiment.log_other("test_size", len(test)) 98 | y_true = test.get_labels() 99 | 100 | scores_file = os.path.join(out_folder, f"scores_{model_name}_{dataset}.pt") 101 | if os.path.exists(os.path.join(out_folder, scores_file)): 102 | logger.info( 103 | "Scores already exist. Loading them and continuing the evaluation..." 
104 | ) 105 | scores = torch.load(os.path.join(out_folder, scores_file)) 106 | else: 107 | scores = evaluate_bert(test, model_path, cpu_only, ckpt_pattern, src_tokenizer) 108 | 109 | #  Compute classification metrics based on scores 110 | logger.info("Evaluating standard performance metrics...") 111 | perf, y_pred = metrics.evaluate_metrics(y_true=y_true, y_score=scores, th=0.5) 112 | if log_comet: 113 | experiment.log_metrics(perf) 114 | 115 | # Save scores and classification metrics locally and on Comet 116 | torch.save(scores, scores_file) 117 | pd.Series(perf).to_frame().to_csv( 118 | os.path.join(out_folder, f"class_metrics_{model_name}_{dataset}.csv") 119 | ) 120 | if log_comet: 121 | experiment.log_asset(scores_file) 122 | experiment.log_metrics(perf) 123 | experiment.log_confusion_matrix( 124 | y_true=y_true, y_predicted=y_pred.astype(int).tolist() 125 | ) 126 | 127 | # run the evaluation on MLMA 128 | if dataset in MLMA_RAW_DATASETS: 129 | logger.info("Processing MLMA per-target performance") 130 | mlma_results = compute_metrics_on_mlma(test, y_true, scores) 131 | 132 | mlma_df = pd.DataFrame( 133 | [r[3] for r in mlma_results], index=[r[0] for r in mlma_results] 134 | ) 135 | mlma_df.to_csv( 136 | os.path.join( 137 | out_folder, f"class_metrics_by_target_{model_name}_{dataset}.csv" 138 | ) 139 | ) 140 | 141 | if no_bias_metrics: 142 | if log_comet: 143 | experiment.add_tag("no_bias_metrics") 144 | logger.info(f"END {model_name} (skipped bias metrics)") 145 | return 146 | 147 | # --- Evaluation of bias metrics --- 148 | 149 | #  Read subgroups and add a dummy column indicating its presence 150 | with open(subgroups_path) as fp: 151 | subgroups = [line.strip().split("\t")[0] for line in fp.readlines()] 152 | 153 | logging.info(f"Found subgroups: {subgroups}") 154 | if log_comet: 155 | experiment.log_other("subgroups", subgroups) 156 | experiment.log_other("subgroups_count", len(subgroups)) 157 | 158 | #  this df is required by the Jigsaw's code for bias 
metrics 159 | data_df = pd.DataFrame( 160 | {"text": test.get_texts(), "label": y_true, model_name: scores} 161 | ) 162 | data_df = metrics.add_subgroup_columns_from_text(data_df, "text", subgroups) 163 | 164 | logger.info("Evaluating bias metrics (parallel)...") 165 | bias_records = Parallel(n_jobs=n_jobs)( 166 | delayed(metrics.compute_bias_metrics_for_subgroup_and_model)( 167 | dataset=data_df, 168 | subgroup=subg, 169 | model=model_name, 170 | label_col="label", 171 | include_asegs=True, 172 | ) 173 | for subg in tqdm(subgroups) 174 | ) 175 | 176 | bias_terms_file = os.path.join(out_folder, f"bias_terms_{model_name}_{dataset}.csv") 177 | per_term_df = pd.DataFrame(bias_records) 178 | per_term_df.to_csv(bias_terms_file, index=False) 179 | if log_comet: 180 | experiment.log_table(bias_terms_file) 181 | 182 | # Average bias metrics across subgroups 183 | records_df = per_term_df.drop(columns=["test_size", "subgroup"]) 184 | 185 | # TODO: ignore nans? 186 | #  compute the mean value of each bias metric across subgroups. Here we use 187 | #  1. power mean (Jigsaw's Kaggle competition). It weights more subgroups where the metric is low. 188 | #  2. 
arithmetic mean 189 | power_mean_values = metrics.power_mean(records_df.values, -5, ignore_nans=True) 190 | mean_values = metrics.power_mean(records_df.values, 1, ignore_nans=True) 191 | 192 | power_mean_dict = { 193 | f"{name}_power_mean": v 194 | for name, v in zip(records_df.columns, power_mean_values) 195 | } 196 | mean_dict = {f"{name}_mean": v for name, v in zip(records_df.columns, mean_values)} 197 | 198 | # The final summary metric is the average between: 199 | # overall AUC, subgroup_auc, bpsn_auc, bnsp_auc 200 | summary_metric_pm = np.nanmean( 201 | np.array( 202 | [ 203 | perf["AUC"], 204 | power_mean_dict["subgroup_auc_power_mean"], 205 | power_mean_dict["bpsn_auc_power_mean"], 206 | power_mean_dict["bnsp_auc_power_mean"], 207 | ] 208 | ) 209 | ) 210 | 211 | summary_metric = np.nanmean( 212 | np.array( 213 | [ 214 | perf["AUC"], 215 | mean_dict["subgroup_auc_mean"], 216 | mean_dict["bpsn_auc_mean"], 217 | mean_dict["bnsp_auc_mean"], 218 | ] 219 | ) 220 | ) 221 | 222 | bias_metrics = { 223 | **power_mean_dict, 224 | **mean_dict, 225 | "summary_power_mean": summary_metric_pm, 226 | "summary_mean": summary_metric, 227 | } 228 | 229 | #  Add False Positive and False Negative Equality Difference (Equality of Odds) 230 | bias_metrics["fped"] = per_term_df[metrics.FPR_GAP].abs().sum() 231 | bias_metrics["fped_mean"] = per_term_df[metrics.FPR_GAP].abs().mean() 232 | bias_metrics["fped_std"] = per_term_df[metrics.FPR_GAP].abs().std() 233 | bias_metrics["fned"] = per_term_df[metrics.FNR_GAP].abs().sum() 234 | bias_metrics["fned_mean"] = per_term_df[metrics.FNR_GAP].abs().mean() 235 | bias_metrics["fned_std"] = per_term_df[metrics.FNR_GAP].abs().std() 236 | 237 | if log_comet: 238 | experiment.log_metrics(bias_metrics) 239 | pd.Series(bias_metrics).to_frame().to_csv( 240 | os.path.join(out_folder, f"bias_metrics_{model_name}_{dataset}.csv") 241 | ) 242 | logger.info(f"END {model_name}") 243 | 244 | 245 | def evaluate_bert( 246 | dataset, 247 | model_dir, 248 | 
cpu_only: bool, 249 | ckpt_pattern, 250 | src_tokenizer, 251 | batch_size=64, 252 | max_sequence_length=120, 253 | ): 254 | """Run evaluation on Kennedy's BERT.""" 255 | from transformers import AutoTokenizer, AutoModelForSequenceClassification 256 | from torch.utils.data import DataLoader 257 | 258 | device = "cuda:0" if (torch.cuda.is_available() and not cpu_only) else "cpu" 259 | logger.info(f"Device: {device}") 260 | 261 | if ckpt_pattern: 262 | from train_bert import LMForSequenceClassification 263 | 264 | ckpt_path = glob.glob(os.path.join(model_dir, f"*{ckpt_pattern}*"))[0] 265 | logger.info(f"Loading ckpt {ckpt_path}") 266 | model = LMForSequenceClassification.load_from_checkpoint(ckpt_path).to(device) 267 | else: 268 | model = AutoModelForSequenceClassification.from_pretrained(model_dir).to(device) 269 | 270 | model.eval() 271 | 272 | if src_tokenizer: 273 | tokenizer = AutoTokenizer.from_pretrained(src_tokenizer) 274 | else: 275 | logger.info(f"Src tokenizer not specified, using {model_dir}") 276 | tokenizer = AutoTokenizer.from_pretrained(model_dir) 277 | 278 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 279 | 280 | test_loader = DataLoader(dataset, batch_size=batch_size, num_workers=0) 281 | 282 | probs = list() 283 | with torch.no_grad(): 284 | for batch in tqdm(test_loader): 285 | encodings = tokenizer( 286 | batch["text"], 287 | add_special_tokens=True, #  they use BERT's special tokens 288 | padding=True, 289 | truncation=True, 290 | max_length=max_sequence_length, 291 | return_tensors="pt", 292 | ).to(device) 293 | 294 | output = model(**encodings) 295 | batch_probs = output["logits"].softmax(-1) # batch_size x 2 296 | probs.append(batch_probs) 297 | 298 | probs = torch.cat(probs, dim=0) 299 | 300 | #  return probabilities for the positive label only 301 | return probs[:, 1].cpu() 302 | 303 | 304 | def compute_metrics_on_mlma(mlma_data, y_true, scores): 305 | targets = mlma_data.data.target.unique() 306 | logger.info(f"Targets found {targets}") 
307 | 308 | target_mask = pd.get_dummies(mlma_data.data["target"]).astype(bool) 309 | 310 | # y_true is a list, y_pred a np.array, scores a torch.tensor 311 | y_true = np.array(y_true) 312 | 313 | results = list() 314 | for target in targets: 315 | mask = target_mask[target].values 316 | perf, y_pred = metrics.evaluate_metrics( 317 | y_true=y_true[mask], y_score=scores[mask], th=0.5 318 | ) 319 | perf["size"] = y_true[mask].size 320 | results.append((target, y_true[mask], scores[mask], perf)) 321 | return results 322 | 323 | 324 | if __name__ == "__main__": 325 | evaluate() 326 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | """Collect and expose datasets for experiments.""" 2 | from torch.utils.data import Dataset, DataLoader 3 | import pytorch_lightning as pl 4 | import torch 5 | import pandas as pd 6 | from operator import itemgetter 7 | import logging 8 | import os 9 | 10 | 11 | logging.basicConfig( 12 | format="%(levelname)s:%(asctime)s:%(module)s:%(message)s", level=logging.INFO 13 | ) 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | MADLIBS_DATASETS = ["madlibs77k", "madlibs89k"] 18 | TOX_DATASETS = ["tox_nonfuzz", "tox_fuzz"] 19 | 20 | MISO_DATASETS = ["miso", "miso-ita-raw", "miso-ita-synt"] 21 | MISOSYNT_DATASETS = ["miso_synt_test"] 22 | 23 | MLMA_DATASETS = ["mlma"] 24 | MLMA_RAW_DATASETS = ["mlma_en", "mlma_fr", "mlma_ar"] 25 | 26 | 27 | AVAIL_DATASETS = ( 28 | MADLIBS_DATASETS 29 | + TOX_DATASETS 30 | + MISO_DATASETS 31 | + MISOSYNT_DATASETS 32 | + MLMA_DATASETS 33 | ) 34 | 35 | 36 | def get_dataset_by_name(name: str, base_dir=None): 37 | path = os.path.join(base_dir, name) if base_dir else name 38 | 39 | train, dev, test = None, None, None 40 | if name in MADLIBS_DATASETS: 41 | test = Madlibs.build_dataset(path) 42 | elif name in TOX_DATASETS: 43 | test = Toxicity.build_dataset(path) 44 | elif name in 
MISO_DATASETS: 45 | if name == "miso-ita-synt": 46 | test = MisoDataset.build_dataset(name, "test") 47 | else: 48 | train = MisoDataset.build_dataset(name, "train") 49 | dev = MisoDataset.build_dataset(name, "dev") 50 | test = MisoDataset.build_dataset(name, "test") 51 | elif name in MISOSYNT_DATASETS: 52 | test = MisoSyntDataset.build_dataset(name) 53 | elif name in MLMA_RAW_DATASETS: 54 | test = MLMARawDataset.build_dataset(name) 55 | elif name in MLMA_DATASETS: 56 | train = MLMADataset.build_dataset(split="train") 57 | dev = MLMADataset.build_dataset(split="dev") 58 | test = MLMADataset.build_dataset(split="test") 59 | else: 60 | raise ValueError(f"Can't recognize dataset name {name}") 61 | return train, dev, test 62 | 63 | 64 | def get_tokenized_path(path: str): 65 | base_dir, filename = os.path.dirname(path), os.path.basename(path) 66 | return os.path.join(base_dir, f"{os.path.splitext(filename)[0]}.pt") 67 | 68 | 69 | class MLMARawDataset(Dataset): 70 | #  DEPRECATED 71 | """Multilingual and Multi-Aspect Hate Speech Analysis""" 72 | 73 | def __init__(self, path: str): 74 | self.path = path 75 | data = pd.read_csv(path) 76 | 77 | # define the hate binary label 78 | data["hate"] = 1 79 | data.loc[data.sentiment == "normal", "hate"] = 0 80 | data = data.loc[ 81 | data.sentiment.apply(lambda x: "normal" not in x or x == "normal") 82 | ] 83 | 84 | self.data = data 85 | self.texts = data["tweet"].tolist() 86 | self.labels = data["hate"].astype(int).tolist() 87 | self.tokenized_path = get_tokenized_path(path) 88 | 89 | def __getitem__(self, idx): 90 | return {"text": self.texts[idx], "label": self.labels[idx]} 91 | 92 | def __len__(self): 93 | return len(self.labels) 94 | 95 | def get_texts(self): 96 | return self.texts 97 | 98 | def get_labels(self): 99 | return self.labels 100 | 101 | @classmethod 102 | def build_dataset(cls, name: str): 103 | if name == "mlma_en": 104 | return cls(os.path.join("data", "hate_speech_mlma", f"en_dataset.csv")) 105 | elif name == 
"mlma_fr": 106 | return cls(os.path.join("data", "hate_speech_mlma", f"fr_dataset.csv")) 107 | elif name == "mlma_ar": 108 | return cls(os.path.join("data", "hate_speech_mlma", f"ar_dataset.csv")) 109 | else: 110 | raise ValueError("Name not recognized.") 111 | 112 | 113 | class MLMADataset(Dataset): 114 | def __init__(self, path: str): 115 | self.path = path 116 | data = pd.read_csv(path, sep="\t") 117 | self.texts = data["tweet"].tolist() 118 | self.labels = data["hate"].astype(int).tolist() 119 | self.tokenized_path = get_tokenized_path(path) 120 | 121 | def __getitem__(self, idx): 122 | return {"text": self.texts[idx], "label": self.labels[idx]} 123 | 124 | def __len__(self): 125 | return len(self.labels) 126 | 127 | def get_texts(self): 128 | return self.texts 129 | 130 | def get_labels(self): 131 | return self.labels 132 | 133 | @classmethod 134 | def build_dataset(cls, split: str): 135 | return cls(f"./data/mlma_{split}.tsv") 136 | 137 | 138 | class MisoDataset(Dataset): 139 | def __init__(self, path: str): 140 | self.path = path 141 | data = pd.read_csv(path, sep="\t") 142 | self.texts = data["text"].tolist() 143 | self.labels = data["misogynous"].astype(int).tolist() 144 | self.tokenized_path = get_tokenized_path(path) 145 | 146 | def __getitem__(self, idx): 147 | return {"text": self.texts[idx], "label": self.labels[idx]} 148 | 149 | def __len__(self): 150 | return len(self.labels) 151 | 152 | def get_texts(self): 153 | return self.texts 154 | 155 | def get_labels(self): 156 | return self.labels 157 | 158 | @classmethod 159 | def build_dataset(cls, name: str, split: str): 160 | 161 | if name == "miso": 162 | return cls(f"./data/miso_{split}.tsv") 163 | elif name == "miso-ita-raw": 164 | return cls(f"./data/AMI2020_{split}_raw.tsv") 165 | elif name == "miso-ita-synt": 166 | return cls(f"./data/AMI2020_{split}_synt.tsv") 167 | else: 168 | raise ValueError("Type not recognized.") 169 | 170 | 171 | class MisoSyntDataset(Dataset): 172 | def __init__(self, 
path: str): 173 | self.path = path 174 | data = pd.read_csv(path, sep="\t", header=None, names=["Text", "Label"]) 175 | self.texts = data["Text"].tolist() 176 | self.labels = data["Label"].astype(int).tolist() 177 | self.tokenized_path = get_tokenized_path(path) 178 | 179 | def __getitem__(self, idx): 180 | return {"text": self.texts[idx], "label": self.labels[idx]} 181 | 182 | def __len__(self): 183 | return len(self.labels) 184 | 185 | def get_texts(self): 186 | return self.texts 187 | 188 | def get_labels(self): 189 | return self.labels 190 | 191 | @classmethod 192 | def build_dataset(cls, type: str): 193 | if type not in MISOSYNT_DATASETS: 194 | raise ValueError("Type not recognized.") 195 | else: 196 | return cls(f"./data/miso_synt_test.tsv") 197 | 198 | 199 | class Madlibs(Dataset): 200 | def __init__(self, path: str): 201 | self.path = path 202 | data = pd.read_csv(path) 203 | # Use the same convention for binary labels: 0 (NOT_BAD/FALSE), 1 (BAD/TRUE) 204 | self.texts = data["Text"].tolist() 205 | self.labels = pd.get_dummies(data.Label)["BAD"].tolist() 206 | self.tokenized_path = get_tokenized_path(path) 207 | 208 | def __getitem__(self, idx): 209 | return {"text": self.texts[idx], "label": self.labels[idx]} 210 | 211 | def __len__(self): 212 | return len(self.labels) 213 | 214 | def get_texts(self): 215 | return self.texts 216 | 217 | def get_labels(self): 218 | return self.labels 219 | 220 | @classmethod 221 | def build_dataset(cls, type: str): 222 | if type not in MADLIBS_DATASETS: 223 | raise ValueError("Type not recognized.") 224 | if type == "madlibs77k": 225 | return cls(f"./data/bias_madlibs_77k.csv") 226 | else: 227 | return cls(f"./data/bias_madlibs_89k.csv") 228 | 229 | 230 | class TokenizerDataModule(pl.LightningDataModule): 231 | def __init__( 232 | self, 233 | dataset_name, 234 | tokenizer, 235 | batch_size, 236 | max_seq_length, 237 | num_workers, 238 | pin_memory, 239 | load_pre_tokenized=False, 240 | store_pre_tokenized=False, 241 | ): 242 
| super().__init__() 243 | self.dataset_name = dataset_name 244 | self.tokenizer = tokenizer 245 | self.batch_size = batch_size 246 | self.max_seq_length = max_seq_length 247 | self.num_workers = num_workers 248 | self.pin_memory = pin_memory 249 | self.load_pre_tokenized = load_pre_tokenized 250 | self.store_pre_tokenized = store_pre_tokenized 251 | 252 | self.train, self.val, self.test = get_dataset_by_name(dataset_name) 253 | self.train_steps = int(len(self.train) / batch_size) 254 | 255 | def prepare_data(self): 256 | train, val, test = self.train, self.val, self.test 257 | 258 | for split in [train, val, test]: 259 | if self.load_pre_tokenized and os.path.exists(split.tokenized_path): 260 | logging.info( 261 | """ 262 | Loading pre-tokenized dataset. 263 | Beware! Using pre-tokenized embeddings could not match you choice for max_length 264 | """ 265 | ) 266 | continue 267 | 268 | if self.load_pre_tokenized: 269 | logging.info(f"Load tokenized but {split.tokenized_path} is not found") 270 | 271 | logger.info("Tokenizing...") 272 | encodings = self.tokenizer( 273 | split.get_texts(), 274 | truncation=True, 275 | padding="max_length", 276 | max_length=self.max_seq_length, 277 | return_tensors="pt", 278 | ) 279 | 280 | if self.store_pre_tokenized: 281 | logger.info(f"Saving to {split.tokenized_path}") 282 | torch.save(encodings, split.tokenized_path) 283 | 284 | def setup(self, stage=None): 285 | if stage == "fit": 286 | train, val = self.train, self.val 287 | 288 | logging.info(f"TRAIN len: {len(train)}") 289 | logging.info(f"VAL len: {len(val)}") 290 | 291 | train_encodings = torch.load(train.tokenized_path) 292 | train_labels = torch.LongTensor([r["label"] for r in train]) 293 | self.train_data = EncodedDataset(train_encodings, train_labels) 294 | 295 | val_encodings = torch.load(val.tokenized_path) 296 | val_labels = torch.LongTensor([r["label"] for r in val]) 297 | self.val_data = EncodedDataset(val_encodings, val_labels) 298 | 299 | elif stage == "test": 300 
| test = self.test 301 | logging.info(f"TEST len: {len(test)}") 302 | 303 | test_encodings = torch.load(test.tokenized_path) 304 | test_labels = torch.LongTensor([r["label"] for r in test]) 305 | self.test_data = EncodedDataset(test_encodings, test_labels) 306 | 307 | else: 308 | raise ValueError(f"Stage {stage} not known") 309 | 310 | def train_dataloader(self): 311 | return DataLoader( 312 | self.train_data, 313 | batch_size=self.batch_size, 314 | shuffle=True, 315 | num_workers=self.num_workers, 316 | pin_memory=self.pin_memory, 317 | ) 318 | 319 | def val_dataloader(self): 320 | return DataLoader( 321 | self.val_data, 322 | batch_size=self.batch_size, 323 | num_workers=self.num_workers, 324 | pin_memory=self.pin_memory, 325 | ) 326 | 327 | def test_dataloader(self): 328 | return DataLoader( 329 | self.test_data, 330 | batch_size=self.batch_size, 331 | num_workers=self.num_workers, 332 | pin_memory=self.pin_memory, 333 | ) 334 | 335 | 336 | class EncodedDataset(Dataset): 337 | def __init__(self, encodings, labels): 338 | self.encodings = encodings 339 | self.labels = labels 340 | 341 | def __getitem__(self, idx): 342 | item = {k: v[idx] for k, v in self.encodings.items()} 343 | item["labels"] = self.labels[idx] 344 | return item 345 | 346 | def __len__(self): 347 | return self.labels.shape[0] 348 | 349 | 350 | class PlainDataset(Dataset): 351 | def __init__(self, texts, labels): 352 | self.texts = texts 353 | self.labels = labels 354 | 355 | def __getitem__(self, index): 356 | return {"text": self.texts[index], "label": self.labels[index]} 357 | 358 | def __len__(self): 359 | return len(self.labels) 360 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Part of the code is adapted from: 3 | - https://github.com/conversationai/unintended-ml-bias-analysis 4 | """ 5 | from sklearn.metrics import ( 6 | accuracy_score, 7 | 
roc_auc_score, 8 | f1_score, 9 | precision_score, 10 | recall_score, 11 | confusion_matrix, 12 | ) 13 | import numpy as np 14 | import pandas as pd 15 | import re 16 | import scipy.stats as stats 17 | import logging 18 | import numpy as np 19 | 20 | 21 | logging.basicConfig( 22 | format="%(levelname)s:%(asctime)s:%(module)s:%(message)s", level=logging.INFO 23 | ) 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def threshold_scores(scores, th: float = 0.5): 28 | scores = np.array(scores) 29 | s = np.zeros(scores.shape[0]) 30 | s[scores >= th] = 1 31 | return s 32 | 33 | 34 | def AUC(y_true, y_pred): 35 | return roc_auc_score(y_true, y_pred) 36 | 37 | 38 | def accuracy(y_true, y_pred): 39 | return accuracy_score(y_true, y_pred) 40 | 41 | 42 | def F1(y_true, y_pred, **kwargs): 43 | """Note: by default F1 is computed on the positive class.""" 44 | return f1_score(y_true, y_pred, **kwargs) 45 | 46 | 47 | def evaluate_metrics(y_true, y_score, th=None): 48 | """Evaluate multiple metrics of interest with default parameters at once.""" 49 | perf = dict() 50 | 51 | # compute metrics based on scores 52 | perf["AUC"] = AUC(y_true, y_score) 53 | 54 | # compute metrics based on predictions 55 | y_pred = None 56 | if th: 57 | y_pred = threshold_scores(y_score, th) 58 | perf["acc"] = accuracy(y_true, y_pred) 59 | perf["F1_weighted"] = f1_score(y_true, y_pred, average="weighted") 60 | perf["F1_macro"] = f1_score(y_true, y_pred, average="macro") 61 | perf["F1_binary"] = f1_score(y_true, y_pred, average="binary") 62 | perf["precision_1"] = precision_score(y_true, y_pred, pos_label=1) 63 | perf["precision_0"] = precision_score(y_true, y_pred, pos_label=0) 64 | perf["recall_1"] = recall_score(y_true, y_pred, pos_label=1) 65 | perf["recall_0"] = recall_score(y_true, y_pred, pos_label=0) 66 | 67 | tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel() 68 | perf["FPR"] = fp / (fp + tn) 69 | perf["FNR"] = fn / (fn + tp) 70 | 71 | return perf, y_pred 72 | 73 | 74 | # 
https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification/overview/evaluation 75 | def power_mean(x, p: int, ignore_nans: bool = False): 76 | """Evaluate the power mean. 77 | 78 | If x.ndim == 1: 79 | x : array_like (n_rows,) 80 | return: float 81 | If x.ndim == 2: 82 | x : array_like (n_rows, n_cols) 83 | return: array_like (n_cols, ) 84 | """ 85 | x = np.array(x) 86 | mean_f = np.nanmean if ignore_nans else np.mean 87 | 88 | if x.ndim == 1: 89 | return mean_f(x ** p) ** (1 / p) 90 | elif x.ndim == 2: 91 | return mean_f(x ** p, axis=0) ** (1 / p) 92 | else: 93 | raise ValueError("The input array must be either 1D or 2D.") 94 | 95 | 96 | # Code from: 97 | # https://github.com/conversationai/unintended-ml-bias-analysis/model_bias_analysis.py 98 | 99 | # Bias metrics computed for each subgroup. 100 | SUBGROUP_SIZE = "test_size" 101 | SUBGROUP = "subgroup" 102 | SUBGROUP_AUC = "subgroup_auc" 103 | NEGATIVE_CROSS_AUC = "bpsn_auc" 104 | POSITIVE_CROSS_AUC = "bnsp_auc" 105 | NEGATIVE_AEG = "negative_aeg" 106 | POSITIVE_AEG = "positive_aeg" 107 | NEGATIVE_ASEG = "negative_aseg" 108 | POSITIVE_ASEG = "positive_aseg" 109 | FPR = "fpr" 110 | FPR_GAP = "fpr_gap" 111 | FNR = "fnr" 112 | FNR_GAP = "fnr_gap" 113 | 114 | 115 | def add_subgroup_columns_from_text( 116 | df, text_column, subgroups, expect_spaces_around_words=True 117 | ): 118 | """Adds a boolean column for each subgroup to the data frame. 119 | 120 | New column contains True if the text contains that subgroup term. 121 | 122 | Args: 123 | df: Pandas dataframe to process. 124 | text_column: Column in df containing the text. 125 | subgroups: List of subgroups to search text_column for. 126 | expect_spaces_around_words: Whether to expect subgroup to be surrounded by 127 | spaces in the text_column. Set to False to for languages which do not 128 | use spaces. 
129 | """ 130 | ndf = df.copy() 131 | for term in subgroups: 132 | if expect_spaces_around_words: 133 | # pylint: disable=cell-var-from-loop 134 | ndf[term] = ndf[text_column].apply( 135 | lambda x: bool( 136 | re.search("\\b" + term + "\\b", x, flags=re.UNICODE | re.IGNORECASE) 137 | ) 138 | ) 139 | else: 140 | ndf[term] = ndf[text_column].str.contains(term, case=False) 141 | return ndf 142 | 143 | 144 | def compute_bias_metrics_for_subgroup_and_model( 145 | dataset: pd.DataFrame, 146 | subgroup: str, 147 | model: str, 148 | label_col: str, 149 | threshold: float = 0.5, 150 | include_asegs=False, 151 | ): 152 | """Computes per-subgroup metrics for one model and subgroup. 153 | 154 | This the general method to extend if new metrics are included/excluded. 155 | """ 156 | record = {SUBGROUP: subgroup, SUBGROUP_SIZE: len(dataset[dataset[subgroup]])} 157 | record[SUBGROUP_AUC] = compute_subgroup_auc(dataset, subgroup, label_col, model) 158 | record[NEGATIVE_CROSS_AUC] = compute_negative_cross_auc( 159 | dataset, subgroup, label_col, model 160 | ) 161 | record[POSITIVE_CROSS_AUC] = compute_positive_cross_auc( 162 | dataset, subgroup, label_col, model 163 | ) 164 | record[NEGATIVE_AEG] = compute_negative_aeg(dataset, subgroup, label_col, model) 165 | record[POSITIVE_AEG] = compute_positive_aeg(dataset, subgroup, label_col, model) 166 | 167 | record[FPR] = compute_fpr(dataset, label_col, model, threshold, subgroup) 168 | record[FPR_GAP] = compute_fpr(dataset, label_col, model, threshold) - record[FPR] 169 | record[FNR] = compute_fnr(dataset, label_col, model, threshold, subgroup) 170 | record[FNR_GAP] = compute_fnr(dataset, label_col, model, threshold) - record[FNR] 171 | 172 | if include_asegs: 173 | ( 174 | record[POSITIVE_ASEG], 175 | record[NEGATIVE_ASEG], 176 | ) = compute_average_squared_equality_gap(dataset, subgroup, label_col, model) 177 | return record 178 | 179 | 180 | def column_name(model, metric): 181 | return f"{model}_{metric}" 182 | 183 | 184 | 
################################### 185 | #  AUC-based metrics (Borkan et al., 2019) 186 | ################################### 187 | 188 | 189 | def compute_subgroup_auc(df, subgroup, label, model_name): 190 | subgroup_examples = df[df[subgroup]] 191 | try: 192 | return AUC(subgroup_examples[label], subgroup_examples[model_name]) 193 | except ValueError as e: 194 | logger.error( 195 | f"Trying to compute AUC on subgroup {subgroup}: {e}. Returning np.nan" 196 | ) 197 | return np.nan 198 | 199 | 200 | def compute_negative_cross_auc(df, subgroup, label, model_name): 201 | """Computes the AUC of the within-subgroup negative examples and the background positive examples.""" 202 | subgroup_negative_examples = df[df[subgroup] & ~df[label]] 203 | non_subgroup_positive_examples = df[~df[subgroup] & df[label]] 204 | examples = subgroup_negative_examples.append(non_subgroup_positive_examples) 205 | try: 206 | return AUC(examples[label], examples[model_name]) 207 | except ValueError as e: 208 | logger.error( 209 | f"Trying to compute AUC on subgroup {subgroup}: {e}. Returning np.nan" 210 | ) 211 | return np.nan 212 | 213 | 214 | def compute_positive_cross_auc(df, subgroup, label, model_name): 215 | """Computes the AUC of the within-subgroup positive examples and the background negative examples.""" 216 | subgroup_positive_examples = df[df[subgroup] & df[label]] 217 | non_subgroup_negative_examples = df[~df[subgroup] & ~df[label]] 218 | examples = subgroup_positive_examples.append(non_subgroup_negative_examples) 219 | try: 220 | return AUC(examples[label], examples[model_name]) 221 | except ValueError as e: 222 | logger.error( 223 | f"Trying to compute AUC on subgroup {subgroup}: {e}. 
Returning np.nan" 224 | ) 225 | return np.nan 226 | 227 | 228 | ################################### 229 | #  Threshold independent metrics (Borkan et al., 2019) 230 | ################################### 231 | 232 | 233 | def normalized_mwu(data1, data2, model_name): 234 | """Calculate number of datapoints with a higher score in data1 than data2.""" 235 | scores_1 = data1[model_name] 236 | scores_2 = data2[model_name] 237 | n1 = len(scores_1) 238 | n2 = len(scores_2) 239 | if n1 == 0 or n2 == 0: 240 | return None 241 | u, _ = stats.mannwhitneyu(scores_1, scores_2, alternative="less") 242 | return u / (n1 * n2) 243 | 244 | 245 | def compute_negative_aeg(df, subgroup, label, model_name): 246 | """Compute the Negative Average Equality Gap""" 247 | mwu = normalized_mwu( 248 | df[~df[subgroup] & ~df[label]], df[df[subgroup] & ~df[label]], model_name 249 | ) 250 | if mwu is None: 251 | return np.nan 252 | return 0.5 - mwu 253 | 254 | 255 | def compute_positive_aeg(df, subgroup, label, model_name): 256 | """Compute the Positive Average Equality Gap""" 257 | mwu = normalized_mwu( 258 | df[~df[subgroup] & df[label]], df[df[subgroup] & df[label]], model_name 259 | ) 260 | if mwu is None: 261 | return np.nan 262 | return 0.5 - mwu 263 | 264 | 265 | def compute_average_squared_equality_gap(df, subgroup, label, model_name): 266 | """Returns the positive and negative ASEG metrics. 
267 | 268 | TODO: extend this definition 269 | """ 270 | subgroup_df = df[df[subgroup]] 271 | background_df = df[~df[subgroup]] 272 | if subgroup_df.empty or background_df.empty: 273 | return np.nan, np.nan 274 | 275 | thresholds = np.linspace(1.0, 0.0, num=1000) 276 | s_fpr, s_tpr = positive_rates(subgroup_df, model_name, label, thresholds) 277 | b_fpr, b_tpr = positive_rates(background_df, model_name, label, thresholds) 278 | 279 | def squared_diff_integral(y, x): 280 | return np.trapz(np.square(np.subtract(y, x)), x) 281 | 282 | if s_fpr and s_tpr and b_fpr and b_tpr: 283 | return squared_diff_integral(s_tpr, b_tpr), squared_diff_integral(s_fpr, b_fpr) 284 | return np.nan, np.nan 285 | 286 | 287 | def positive_rates(df, score_col, label_col, thresholds): 288 | """Compute false positive and true positive rates.""" 289 | tpr = [] 290 | fpr = [] 291 | for threshold in thresholds: 292 | confusion = confusion_matrix_counts(df, score_col, label_col, threshold) 293 | if ( 294 | confusion["tp"] + confusion["fn"] == 0 295 | or confusion["fp"] + confusion["tn"] == 0 296 | ): 297 | return None, None 298 | tpr.append(confusion["tp"] / (confusion["tp"] + confusion["fn"])) 299 | fpr.append(confusion["fp"] / (confusion["fp"] + confusion["tn"])) 300 | return fpr, tpr 301 | 302 | 303 | def confusion_matrix_counts(df, score_col, label_col, threshold): 304 | return { 305 | "tp": len(df[(df[score_col] >= threshold) & df[label_col]]), 306 | "tn": len(df[(df[score_col] < threshold) & ~df[label_col]]), 307 | "fp": len(df[(df[score_col] >= threshold) & ~df[label_col]]), 308 | "fn": len(df[(df[score_col] < threshold) & df[label_col]]), 309 | } 310 | 311 | 312 | ################################### 313 | # Error Rate Equality Difference (or Equality of Odds, Hardt, 2016) 314 | ################################### 315 | 316 | 317 | def false_positive_equality_difference( 318 | df: pd.DataFrame, label_col: str, scores_col: str, threshold: float, subgroups: list 319 | ): 320 | """Compute 
False Positive Equality Difference.""" 321 | fpr = compute_fpr(df, label_col, scores_col, threshold) 322 | subg_fprs = np.array( 323 | [compute_fpr(df, label_col, scores_col, threshold, subg) for subg in subgroups] 324 | ) 325 | return (fpr - subg_fprs).abs().sum() 326 | 327 | 328 | def false_negative_equality_difference( 329 | df: pd.DataFrame, label_col: str, scores_col: str, threshold: float, subgroups: list 330 | ): 331 | """Compute False Negative Equality Difference.""" 332 | fnr = compute_fnr(df, label_col, scores_col, threshold) 333 | subg_fnrs = np.array( 334 | [compute_fnr(df, label_col, scores_col, threshold, subg) for subg in subgroups] 335 | ) 336 | return (fnr - subg_fnrs).abs().sum() 337 | 338 | 339 | def compute_fpr(df, label, model_name, threshold, subgroup: str = None): 340 | """Compute FPR (optionally on a subgroup).""" 341 | if subgroup: 342 | df = df[df[subgroup]] 343 | cm = confusion_matrix_counts(df, model_name, label, threshold) 344 | return cm["fp"] / (cm["fp"] + cm["tn"]) if (cm["fp"] + cm["tn"] != 0) else np.nan 345 | 346 | 347 | def compute_fnr(df, label, model_name, threshold, subgroup: str = None): 348 | """Compute FNR (optionally on a subgroup).""" 349 | if subgroup: 350 | df = df[df[subgroup]] 351 | cm = confusion_matrix_counts(df, model_name, label, threshold) 352 | return cm["fn"] / (cm["fn"] + cm["tp"]) if (cm["fn"] + cm["tp"] != 0) else np.nan 353 | -------------------------------------------------------------------------------- /train_bert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import click 4 | import logging 5 | 6 | import comet_ml 7 | 8 | from dataset import get_dataset_by_name, TokenizerDataModule 9 | import IPython 10 | import pdb 11 | 12 | import numpy as np 13 | import torch 14 | from torch import nn 15 | from torch.nn import functional as F 16 | from torch.utils.data import DataLoader, random_split 17 | import pytorch_lightning as pl 18 | 
import pytorch_lightning.metrics.functional as plf
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    AdamW,
    get_linear_schedule_with_warmup,
)
import pandas as pd

# from aim.pytorch_lightning import AimLogger

logging.basicConfig(
    format="%(levelname)s:%(asctime)s:%(module)s:%(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)

# this hides a warning thrown by huggingface transformers
# https://github.com/huggingface/transformers/issues/5486
# https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
os.environ["TOKENIZERS_PARALLELISM"] = "true"  # set to "false" if processes get stuck


class LMForSequenceClassification(pl.LightningModule):
    """Lightning wrapper around a HuggingFace sequence classifier.

    Optionally adds an attention-entropy regularization term to the
    classification loss (see ``compute_negative_entropy``) and supports a
    class-weighted cross-entropy loss.

    Args:
        src_model: HuggingFace model name or path of the backbone.
        learning_rate: AdamW learning rate.
        regularization: None for plain fine-tuning; any truthy value enables
            the entropy term. "norm" loads ``BertForSequenceClassification``
            and asserts transformers 3.0.0 (custom norm-analysis build).
        reg_strength: multiplier applied to the negative entropy.
        weight_decay: weight decay for every parameter except biases and
            LayerNorm weights.
        warmup_train_perc: fraction of ``train_steps_count`` used as linear LR
            warmup; the scheduler is built only when both values are set.
        train_steps_count: total number of optimizer steps.
        class_weights: optional per-class weights for the cross-entropy loss.
    """

    def __init__(
        self,
        src_model: str,
        learning_rate: float,
        regularization: str = None,
        reg_strength: float = 0.01,
        weight_decay: float = 0.0,
        warmup_train_perc: float = None,
        train_steps_count: int = None,
        class_weights: torch.Tensor = None,
    ):
        super().__init__()

        if regularization and regularization == "norm":
            # use custom transformers from:
            # https://github.com/gorokoba560/norm-analysis-of-transformer
            # the norm evaluation is currently supported on Bert only
            import transformers
            from transformers import BertForSequenceClassification

            assert transformers.__version__ == "3.0.0"
            self.model = BertForSequenceClassification.from_pretrained(src_model)
        else:
            self.model = AutoModelForSequenceClassification.from_pretrained(src_model)

        self.save_hyperparameters()

        if class_weights is not None:
            # registered as a buffer so it moves with the module across devices
            self.register_buffer("class_weights", class_weights)

        # metrics
        self.train_acc = pl.metrics.Accuracy()
        self.train_F1 = pl.metrics.F1(num_classes=2, average="macro")
        self.val_acc = pl.metrics.Accuracy()
        self.val_F1 = pl.metrics.F1(num_classes=2, average="macro")
        self.test_acc = pl.metrics.Accuracy()
        self.test_F1 = pl.metrics.F1(num_classes=2, average="macro")
        self.test_prec = pl.metrics.Precision(num_classes=2, average="macro")
        self.test_rec = pl.metrics.Recall(num_classes=2, average="macro")

    def forward(self, **inputs):
        return self.model(**inputs)

    def _class_weighted_loss(self, loss, logits, batch):
        """Return the class-weighted cross-entropy if class weights are set, else ``loss``."""
        if self.hparams.class_weights is None:
            return loss
        loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
        labels = batch["labels"]
        return loss_fct(logits.view(-1, self.model.num_labels), labels.view(-1))

    def forward_pass(self, batch):
        """Run the model on ``batch`` and compute the classification loss.

        Returns:
            ``(loss, logits)`` without regularization, otherwise
            ``(loss, logits, negative_entropy, reg_loss)`` with
            ``reg_loss = reg_strength * negative_entropy``. ``loss`` never
            includes ``reg_loss``; the step methods add it themselves.
        """
        if self.hparams.regularization:
            out = self(**batch, output_attentions=True, return_dict=True)
            loss, logits, attentions = out["loss"], out["logits"], out["attentions"]
            loss = self._class_weighted_loss(loss, logits, batch)

            negative_entropy = compute_negative_entropy(
                attentions, batch["attention_mask"]
            )
            reg_loss = self.hparams.reg_strength * negative_entropy
            return loss, logits, negative_entropy, reg_loss

        out = self(**batch, return_dict=True)
        loss, logits = out["loss"], out["logits"]
        loss = self._class_weighted_loss(loss, logits, batch)
        return loss, logits

    def training_step(self, batch, batch_idx):
        if self.hparams.regularization:
            loss, logits, negative_entropy, reg_loss = self.forward_pass(batch)
            self.log("train_class_loss", loss, prog_bar=True)
            self.log("train_reg_loss", reg_loss, prog_bar=True)
            self.log("entropy", -negative_entropy)
            # out-of-place on purpose: ``loss`` was just passed to self.log and
            # an in-place ``+=`` would mutate that very tensor object
            loss = loss + reg_loss
        else:
            loss, logits = self.forward_pass(batch)
            self.log("train_class_loss", loss, prog_bar=True)

        y_true = batch["labels"]
        y_pred = logits.argmax(-1)

        self.train_acc(y_pred, y_true)
        self.train_F1(y_pred, y_true)

        self.log("train_loss", loss, prog_bar=True, sync_dist=True)
        self.log("train_acc", self.train_acc, on_step=False, on_epoch=True)
        self.log("train_F1", self.train_F1, on_step=False, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        if self.hparams.regularization:
            loss, logits, negative_entropy, reg_loss = self.forward_pass(batch)
            self.log("val_class_loss", loss, sync_dist=True)
            self.log("entropy", -negative_entropy, on_step=False, on_epoch=True)
            loss = loss + reg_loss  # out-of-place, see training_step
        else:
            loss, logits = self.forward_pass(batch)
            self.log("val_class_loss", loss, sync_dist=True)

        y_true = batch["labels"]
        y_pred = logits.argmax(-1)

        self.val_acc(y_pred, y_true)
        self.val_F1(y_pred, y_true)

        # val_loss is logged once per epoch in validation_epoch_end
        self.log("val_acc", self.val_acc, on_step=False, on_epoch=True)
        self.log("val_F1", self.val_F1, on_step=False, on_epoch=True)

        if self.hparams.regularization:
            return {"val_loss": loss, "val_reg_loss": reg_loss}
        return {"val_loss": loss}

    def validation_epoch_end(self, outputs):
        """Log the epoch mean of the per-batch validation (and reg) losses."""
        btc_losses = torch.stack([x["val_loss"] for x in outputs])
        if self.hparams.regularization:
            reg_losses = torch.stack([x["val_reg_loss"] for x in outputs])

        if self.trainer.use_ddp:
            btc_losses = self.all_gather(btc_losses)

            if self.hparams.regularization:
                reg_losses = self.all_gather(reg_losses)

        self.log("val_loss", btc_losses.mean(), on_step=False, sync_dist=True)
        if self.hparams.regularization:
            self.log("val_reg_loss", reg_losses.mean(), on_step=False, sync_dist=True)

    def test_step(self, batch, batch_idx):
        if self.hparams.regularization:
            loss, logits, negative_entropy, reg_loss = self.forward_pass(batch)
            loss = loss + reg_loss
        else:
            loss, logits = self.forward_pass(batch)

        y_true = batch["labels"]
        y_pred = logits.argmax(-1)

        self.log("test_loss", loss, sync_dist=True)

        self.test_acc(y_pred, y_true)
        self.test_F1(y_pred, y_true)
        self.test_prec(y_pred, y_true)
        self.test_rec(y_pred, y_true)

        self.log("test_acc", self.test_acc)
        self.log("test_F1", self.test_F1)
        self.log("test_prec", self.test_prec)
        self.log("test_rec", self.test_rec)

    def configure_optimizers(self):
        """Build AdamW (plus an optional linear warmup schedule).

        Biases and LayerNorm weights are excluded from weight decay. When both
        ``warmup_train_perc`` and ``train_steps_count`` are set, a linear
        warmup/decay scheduler stepped after every optimizer step is returned.
        """
        # Adapted from:
        # https://github.com/huggingface/transformers/blob/5bfcd0485ece086ebcbed2d008813037968a9e58/examples/run_glue.py#L102

        # Don't apply weight decay to any parameters whose names include these tokens.
        no_decay = ["bias", "LayerNorm.weight"]

        # AdamW reads the per-group key "weight_decay"; any other key (such as
        # "weight_decay_rate") is ignored and the optimizer default (0.0) is used.
        grouped_parameters = [
            {
                "params": [
                    p
                    for n, p in self.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [
                    p
                    for n, p in self.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]

        optimizer = AdamW(grouped_parameters, lr=self.hparams.learning_rate)

        if self.hparams.warmup_train_perc and self.hparams.train_steps_count:
            ws = int(self.hparams.warmup_train_perc * self.hparams.train_steps_count)
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=ws,
                num_training_steps=self.hparams.train_steps_count,
            )

            return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

        return optimizer

    def get_backbone(self):
        """Return the wrapped HuggingFace model."""
        return self.model

    def get_progress_bar_dict(self):
        """Hide the version number ("v_num") from the progress bar."""
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)
        return items


def compute_negative_entropy(
    inputs: tuple, attention_mask: torch.Tensor, return_values=False
):
    """Compute the negative entropy across layers of a network for given inputs.

    For every sample, the per-layer maps are averaged over heads and restricted
    to the non-padded positions (rows and columns) given by ``attention_mask``.
    A softmax is then applied over the last dimension and the negative entropy
    of each row is computed. Row values are averaged per layer and summed
    across layers; the result is averaged over the batch.

    Args:
        inputs: tuple of length num_layers; each item has shape
            (batch, heads, seq, seq), so the stacked tensor is LBHSS.
        attention_mask: tensor of shape (batch, seq); non-zero marks tokens.
        return_values: also return the per-sample (layers, tokens) negative
            entropies (detached).

    Returns:
        The scalar negative entropy, or ``(negative_entropy, per_sample_list)``
        when ``return_values`` is True.
    """
    inputs = torch.stack(inputs)  # Layers x Batch x Heads x Seqlen x Seqlen
    assert inputs.ndim == 5, "Here we expect 5 dimensions in the form LBHSS"

    # average over attention heads
    pool_heads = inputs.mean(2)

    batch_size = pool_heads.shape[1]
    samples_entropy = list()
    neg_entropies = list()
    for b in range(batch_size):
        # get inputs from non-padded tokens of the current sample
        mask = attention_mask[b]
        sample = pool_heads[:, b, mask.bool(), :]
        sample = sample[:, :, mask.bool()]

        # get the negative entropy for each non-padded token
        neg_entropy = (sample.softmax(-1) * sample.log_softmax(-1)).sum(-1)
        if return_values:
            neg_entropies.append(neg_entropy.detach())

        # get the "average entropy" that traverses the layer
        mean_entropy = neg_entropy.mean(-1)

        # store the sum across all the layers
        samples_entropy.append(mean_entropy.sum(0))

    # average over the batch
    final_entropy = torch.stack(samples_entropy).mean()
    if return_values:
        return final_entropy, neg_entropies
    return final_entropy


SUPPORTED_MODELS = [
    "bert-base-uncased",
    "bert-base-multilingual-uncased",
    "dbmdz/bert-base-italian-uncased",
]


@click.command()
@click.option("--src_model", type=str, required=True)
@click.option("--output_dir", type=str, default="./dumps")
@click.option("--training_dataset", type=str, default="wiki")
@click.option("--batch_size", type=int, default=32)
@click.option("--num_workers", type=int, default=0)
@click.option("--seed", type=int, default=42)
@click.option("--max_epochs", type=int, default=20)
@click.option("--gpus", type=int, default=0)
@click.option("--accelerator", type=str, default=None)
@click.option("--max_seq_length", type=int, default=None)
@click.option("--learning_rate", type=float, default=2e-5)
@click.option("--early_stop_epochs", type=int, default=5)
@click.option("--regularization", type=str, default=None)
@click.option("--reg_strength", type=float, default=0.01)
@click.option("--weight_decay", type=float, default=0.0)
@click.option("--warmup_train_perc", type=float, default=None, help="Value [0,1]")
@click.option("--accumulate_grad_batches", type=int, default=1)
@click.option("--precision", type=int, default=32)
@click.option("--run_test", is_flag=True)
@click.option("--pin_memory", is_flag=True)
@click.option("--log_every_n_steps", type=int, default=50)
@click.option("--monitor", type=str, default="val_loss")
@click.option("--checkpoint_every_n_epochs", type=int, default=None)
@click.option("--save_transformers_model", is_flag=True)
@click.option("--ckpt_save_top_k", type=int, default=1)
@click.option("--resume_from_checkpoint", type=str, default=None)
@click.option("--balanced_loss", is_flag=True)
def main(
    src_model,
    output_dir,
    training_dataset,
    batch_size,
    num_workers,
    seed,
    max_epochs,
    gpus,
    accelerator,
    max_seq_length,
    learning_rate,
    early_stop_epochs,
    regularization,
    reg_strength,
    weight_decay,
    warmup_train_perc,
    accumulate_grad_batches,
    precision,
    run_test,
    pin_memory,
    log_every_n_steps,
    monitor,
    checkpoint_every_n_epochs,
    save_transformers_model,
    ckpt_save_top_k,
    resume_from_checkpoint,
    balanced_loss,
):
    """Fine-tune ``src_model`` on ``training_dataset`` (CLI entry point).

    Checkpoints are written to ``<output_dir>/<model_name>``. If that folder
    already exists the run is skipped unless ``--resume_from_checkpoint`` names
    a checkpoint file inside it.

    Raises:
        ValueError: for an unsupported ``src_model`` or ``regularization``.
    """
    hparams = locals()  # must stay first: captures only the CLI arguments
    pl.seed_everything(seed)

    if src_model not in SUPPORTED_MODELS:
        raise ValueError(f"src_model is not supported {src_model}")

    if not regularization:
        model_name = f"vanillabert-{training_dataset}-{seed}"
        experiment_name = f"vanillabert-{training_dataset}"
    elif regularization == "entropy":
        model_name = f"entropybert-{training_dataset}-{seed}-{reg_strength}"
        experiment_name = f"entropybert-{training_dataset}"
    elif regularization == "norm":
        model_name = f"normbert-{training_dataset}-{seed}-{reg_strength}"
        experiment_name = f"normbert-{training_dataset}"
    else:
        # fail early: otherwise model_name is unset and os.path.join breaks below
        raise ValueError(f"regularization is not supported {regularization}")

    os.makedirs(output_dir, exist_ok=True)

    model_dir = os.path.join(output_dir, model_name)

    # logic to resume from checkpoint
    if os.path.exists(model_dir):
        if not resume_from_checkpoint:
            logger.info(
                f"The model {model_name} already exists and training was completed. Skipping..."
            )
            return
        ckpt_path = os.path.join(model_dir, resume_from_checkpoint)
        if not os.path.exists(ckpt_path):
            logger.error(f"{ckpt_path} doesn't exist. Aborting.")
            return
        logger.info(
            f"The model {model_name} already exists but training was not completed. Resuming from {resume_from_checkpoint}..."
        )
        resume_from_checkpoint = ckpt_path

    tokenizer = AutoTokenizer.from_pretrained(src_model)

    # Tokenization and data loading are handled by TokenizerDataModule.
    dataset_module = TokenizerDataModule(
        dataset_name=training_dataset,
        tokenizer=tokenizer,
        batch_size=batch_size,
        max_seq_length=max_seq_length,
        num_workers=num_workers,
        pin_memory=pin_memory,
        load_pre_tokenized=True,
        store_pre_tokenized=True,
    )

    # check if linear lr warmup is required
    train_steps_count = None
    if warmup_train_perc:
        logger.info(f"Warmup linear LR requested with {warmup_train_perc}")
        train_steps_count = (
            int(dataset_module.train_steps / accumulate_grad_batches) * max_epochs
        )
        logger.info(f"Total training steps: {train_steps_count}")
        if gpus and gpus > 0:
            train_steps_count = train_steps_count // gpus
            logger.info(f"Total training steps (gpu-normalized): {train_steps_count}")

    if balanced_loss:
        # weight of class c = 1 - (frequency of c in the training labels)
        train, val, test = get_dataset_by_name(training_dataset)
        labels_count = pd.Series(train.labels).value_counts()
        labels_count = labels_count / len(train.labels)
        labels_count = 1 - labels_count
        labels_count = labels_count.sort_index()
        class_weights = torch.Tensor(labels_count)
        logger.info(f"Class weights: {class_weights}")
    else:
        class_weights = None

    # Instantiate a LM and create the experiment accordingly
    model = LMForSequenceClassification(
        src_model,
        learning_rate,
        regularization,
        reg_strength,
        weight_decay=weight_decay,
        warmup_train_perc=warmup_train_perc,
        train_steps_count=train_steps_count,
        class_weights=class_weights,
    )

    # set some training stuff (loggers, callback)
    loggers = list()
    if "COMET_API_KEY" in os.environ:
        comet_logger = pl.loggers.CometLogger(
            api_key=os.environ["COMET_API_KEY"],
            project_name="unbias-text-classifiers",  # Optional
            experiment_name=experiment_name,  # Optional
            log_code=False,
            log_graph=False,
        )
        comet_logger.experiment.add_tag("training")
        comet_logger.log_hyperparams(hparams)
        loggers.append(comet_logger)

    # define training callbacks
    callbacks = list()
    if early_stop_epochs > 0:
        early_stopping = pl.callbacks.EarlyStopping(monitor, patience=early_stop_epochs)
        callbacks.append(early_stopping)

    model_checkpoint = pl.callbacks.ModelCheckpoint(
        monitor=monitor,
        dirpath=model_dir,
        save_last=True,
        save_top_k=ckpt_save_top_k,
        filename="PL-{epoch}-{val_loss:.3f}-{train_loss:.3f}",
    )

    if checkpoint_every_n_epochs:
        from custom_callbacks import CheckpointEveryNEpochs

        ckpt_n_epochs = CheckpointEveryNEpochs(checkpoint_every_n_epochs)
        callbacks.append(ckpt_n_epochs)

    lr_monitor = pl.callbacks.LearningRateMonitor()
    callbacks.append(model_checkpoint)
    callbacks.append(lr_monitor)

    trainer = pl.Trainer(
        gpus=gpus,
        accelerator=accelerator,
        max_epochs=max_epochs,
        logger=loggers,
        callbacks=callbacks,
        accumulate_grad_batches=accumulate_grad_batches,
        precision=precision,
        resume_from_checkpoint=resume_from_checkpoint,
        log_every_n_steps=log_every_n_steps,
        gradient_clip_val=1,
        # plugins=pl.plugins.DDPPlugin(find_unused_parameters=True),
    )

    trainer.fit(model, datamodule=dataset_module)

    logger.info(f"Best model path: {model_checkpoint.best_model_path}")
    logger.info(f"Best model val_loss: {model_checkpoint.best_model_score}")

    if run_test:
        if "COMET_API_KEY" in os.environ:
            trainer.logger = None
        # test on the dataset in-distribution
        trainer.test(datamodule=dataset_module, ckpt_path="best")

    if save_transformers_model:
        # Save the tokenizer and the backbone LM with HuggingFace's serialization.
        # To avoid mixing PL's and HuggingFace's serialization:
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/3096#issuecomment-686877242
        best_PL = LMForSequenceClassification.load_from_checkpoint(
            model_checkpoint.best_model_path
        )
        best_PL.get_backbone().save_pretrained(model_dir)
        tokenizer.save_pretrained(model_dir)

    # TODO resume_from_checkpoint logic
Removing last.ckpt...") 546 | #  if early_stop_epochs > 0: 547 | #   if os.path.exists(os.path.join(model_dir, "last.ckpt")): 548 | #   os.remove(os.path.join(model_dir, "last.ckpt")) 549 | #   logger.info("Last checkpoint removed.") 550 | 551 | 552 | if __name__ == "__main__": 553 | main() 554 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from os.path import exists, join 2 | import pandas as pd 3 | import torch 4 | import logging 5 | 6 | from transformers import AutoModelForSequenceClassification 7 | from train_bert import compute_negative_entropy, LMForSequenceClassification 8 | from dataset import get_dataset_by_name, TokenizerDataModule 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | 12 | import pandas as pd 13 | import matplotlib.pyplot as plt 14 | import seaborn as sns 15 | 16 | sns.set_theme() 17 | import glob 18 | import numpy as np 19 | from IPython.display import display 20 | import os 21 | from os.path import join 22 | import re 23 | import torch 24 | from collections import namedtuple 25 | import pdb 26 | 27 | 28 | logging.basicConfig( 29 | format="%(levelname)s:%(asctime)s:%(module)s:%(message)s", level=logging.INFO 30 | ) 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | class ScoreHandler: 35 | """Standardize how scores are saved and loaded for a given model & dataset.""" 36 | 37 | def __init__(self, dataset: torch.utils.data.Dataset): 38 | self.dataset = dataset 39 | 40 | def save_scores(self, scores, root_dir: str, column_name: str, dataset: str): 41 | """Save the scores for a model on a dataset. 42 | 43 | It uses a single csv file per dataset. Each column refers to the scores of a 44 | single dataset. 
45 | 46 | Return: (datafram with scores, epath of the file containing the scores) 47 | """ 48 | file_name = f"scores_{dataset}.csv" 49 | file_path = join(root_dir, file_name) 50 | df = pd.read_csv(file_path) if exists(file_path) else self.dataset.data.copy() 51 | 52 | if column_name in df.columns: 53 | logging.info(f"Scores for {column_name} are present. Overriding them...") 54 | df[column_name] = scores 55 | df.to_csv(file_path, index=False) 56 | 57 | return df, file_path 58 | 59 | 60 | def load_model_from_folder(model_dir, pattern=None): 61 | if pattern: 62 | ckpt = glob.glob(join(model_dir, f"*{pattern}*"))[0] 63 | else: 64 | ckpt = glob.glob(f"{model_dir}/*.ckpt")[0] 65 | 66 | print("Loading", ckpt) 67 | 68 | if pattern: 69 | model = LMForSequenceClassification.load_from_checkpoint(ckpt) 70 | else: 71 | model = AutoModelForSequenceClassification.from_pretrained(model_dir) 72 | return model 73 | 74 | 75 | def join_subwords(tokens): 76 | span_start_idx = -1 77 | spans = list() 78 | for i, t in enumerate(tokens): 79 | if t.startswith("#") and span_start_idx == -1: 80 | span_start_idx = i - 1 81 | continue 82 | if not t.startswith("#") and span_start_idx != -1: 83 | spans.append((span_start_idx, i)) 84 | span_start_idx = -1 85 | 86 | #  span open at the end 87 | if span_start_idx != -1: 88 | spans.append((span_start_idx, len(tokens))) 89 | 90 | merged_tkns = list() 91 | pop_idxs = list() 92 | for span in spans: 93 | merged = "".join([t.strip("#") for t in tokens[span[0] : span[1]]]) 94 | merged_tkns.append(merged) 95 | 96 | #  indexes to remove in the final sequence 97 | for pop_idx in range(span[0] + 1, span[1]): 98 | pop_idxs.append(pop_idx) 99 | 100 | new_tokens = tokens.copy() 101 | for i, (span, merged) in enumerate(zip(spans, merged_tkns)): 102 | new_tokens[span[0]] = merged #  substitue with whole word 103 | 104 | mask = np.ones(len(tokens)) 105 | mask[pop_idxs] = 0 106 | new_tokens = np.array(new_tokens)[mask == 1] 107 | 108 | assert len(new_tokens) == 
len(tokens) - len(pop_idxs) 109 | return new_tokens, pop_idxs, spans 110 | 111 | 112 | def average_2d_over_spans(tensor, spans, reduce_fn="mean"): 113 | #  print("Spans #", spans) 114 | slices = list() 115 | 116 | last_span = None 117 | for span in spans: 118 | 119 | # first slice 120 | if last_span is None: 121 | slices.append(tensor[:, : span[0]]) 122 | else: 123 | slices.append(tensor[:, last_span[1] : span[0]]) 124 | 125 | # average over the subwords 126 | if reduce_fn == "mean": 127 | slices.append(tensor[:, span[0] : span[1]].mean(-1).unsqueeze(-1)) 128 | else: 129 | slices.append(tensor[:, span[0] : span[1]].sum(-1).unsqueeze(-1)) 130 | 131 | last_span = span 132 | 133 | #  last slice 134 | if spans[-1][1] != tensor.shape[1]: 135 | slices.append(tensor[:, last_span[1] :]) 136 | 137 | res = torch.cat(slices, dim=1) 138 | #  print("After average:", res.shape) 139 | return res 140 | 141 | 142 | def get_scores(y_true, scores_path): 143 | scores = torch.load(scores_path) 144 | y_pred = torch.zeros(scores.shape[0]).masked_fill(scores >= 0.5, 1) 145 | 146 | fp_mask = (y_true == 0) & (y_pred == 1) 147 | fp = torch.zeros(scores.shape[0]).masked_fill(fp_mask, 1) 148 | fp_indexes = torch.nonzero(fp).squeeze(-1) 149 | 150 | print(f"Found {fp_indexes.shape[0]} FPs") 151 | return {"scores": scores, "y_pred": y_pred, "fp_indexes": fp_indexes} 152 | 153 | 154 | #### VISUALIZATION: ENTROPY #### 155 | 156 | 157 | def show_entropy( 158 | models, 159 | tokenizer, 160 | max_sequence_length, 161 | data, 162 | names, 163 | n_samples=2, 164 | idxs=None, 165 | regularization="entropy", 166 | join=False, 167 | layers_mean=False, 168 | prompt=None, 169 | exp=False, 170 | remove_special=False, 171 | labelsize=15, 172 | titlesize=15, 173 | set_figsize=True, 174 | set_tightlayout=True, 175 | ): 176 | def process_text(idx, text): 177 | with torch.no_grad(): 178 | print(text) 179 | encoding = tokenizer( 180 | text, 181 | add_special_tokens=True, 182 | padding=True, 183 | truncation=True, 
184 | max_length=max_sequence_length, 185 | return_tensors="pt", 186 | ) 187 | 188 | tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0]) 189 | 190 | if remove_special: 191 | tokens = tokens[1:-1] 192 | 193 | #  print("Len:", len(tokens), "tokens:", tokens) 194 | 195 | if join: 196 | # join subwords for better visualization 197 | new_tokens, pop_idxs, spans = join_subwords(tokens) 198 | #  print("Len new tokens", len(new_tokens)) 199 | tokens = new_tokens 200 | 201 | heatmap_list = list() 202 | final_entropies = list() 203 | y_scores = list() 204 | for i, (model, name) in enumerate(zip(models, names)): 205 | if regularization == "entropy": 206 | output = model(**encoding, output_attentions=True) 207 | reg_target = output["attentions"] 208 | else: 209 | output = model(**encoding, output_norms=True) 210 | norms = output["norms"] 211 | afx_norms = [t[1] for t in norms] 212 | reg_target = afx_norms 213 | 214 | logits = output["logits"] 215 | y_score = logits.softmax(-1)[0, 1] 216 | print(y_score) 217 | 218 | neg_entropy, entropies = compute_negative_entropy( 219 | reg_target, encoding["attention_mask"], return_values=True 220 | ) 221 | #  print("Entropies shape:", entropies[0].shape) 222 | 223 | #  join_subwords(entropies, tokens) 224 | #  print(name, "Final entropy: ", -neg_entropy.item()) 225 | entropies = -entropies[0] # take positive entropy 226 | entropies = torch.flipud(entropies) #  top layers are placed to the top 227 | 228 | # average subwords 229 | if join and len(spans) > 0: 230 | entropies = average_2d_over_spans(entropies, spans) 231 | 232 | if layers_mean: 233 | entropies = entropies.mean(0).unsqueeze(0) 234 | 235 | if exp: 236 | entropies = (1 / entropies).log() 237 | 238 | if remove_special: 239 | entropies = entropies[:, 1:-1] 240 | 241 | heatmap_list.append(entropies) 242 | final_entropies.append(-neg_entropy.item()) 243 | y_scores.append(y_score) 244 | 245 | #### VISUALIZATION #### 246 | 247 | if layers_mean: 248 | figsize = (12, 2 * 
len(models)) 249 | else: 250 | figsize = (6 * len(models), 6) 251 | 252 | if set_figsize: 253 | fig = plt.figure(constrained_layout=False, figsize=figsize) 254 | else: 255 | fig = plt.figure(constrained_layout=False) 256 | 257 | if regularization == "entropy": 258 | fig.suptitle( 259 | f"H: Entropy on Attention (a), ID:{idx}" 260 | ) # , {data[idx]}") 261 | else: 262 | fig.suptitle( 263 | f"Entropy on Norm (||a*f(zx)||), ID:{idx}" 264 | ) # , {data[idx]}") 265 | 266 | if set_tightlayout: 267 | fig.tight_layout() 268 | 269 | # compute global min and global max 270 | heatmap_tensor = torch.stack(heatmap_list) 271 | glob_min = heatmap_tensor.min().item() 272 | glob_max = heatmap_tensor.max().item() 273 | #  print("Glob max:", glob_max, "Glob min", glob_min) 274 | 275 | for i, name in enumerate(names): 276 | if layers_mean: 277 | gspec = fig.add_gridspec( 278 | len(models), 2, width_ratios=[20, 1], wspace=0.1, hspace=0.1 279 | ) 280 | splot = fig.add_subplot(gspec[i, 0]) 281 | 282 | if i == (len(names) - 1): 283 | cbar_ax = fig.add_subplot(gspec[:, 1]) 284 | sns.heatmap( 285 | heatmap_list[i], 286 | ax=splot, 287 | cbar=True, 288 | cbar_ax=cbar_ax, 289 | square=True, 290 | vmin=glob_min, 291 | vmax=glob_max, 292 | ) 293 | splot.set_xticks(np.arange(heatmap_list[i].shape[-1]) + 0.5) 294 | splot.set_xticklabels(tokens, rotation=90, fontsize=labelsize) 295 | [t.set_fontsize(labelsize) for t in cbar_ax.get_yticklabels()] 296 | 297 | # title to colorbar 298 | cbar_ax.set_title( 299 | "log(1/H)", fontsize=titlesize 300 | ) if exp else cbar_ax.set_title("H", fontsize=titlesize) 301 | 302 | else: 303 | sns.heatmap( 304 | heatmap_list[i], 305 | ax=splot, 306 | cbar=False, 307 | square=True, 308 | vmin=glob_min, 309 | vmax=glob_max, 310 | ) 311 | splot.set_xticklabels([]) 312 | 313 | splot.set_yticklabels([]) 314 | splot.set_title( 315 | f"{name}, p(1|x)={y_scores[i]:.3f}, H={final_entropies[i]:.3f}", 316 | fontsize=titlesize, 317 | ) 318 | 319 | else: 320 | width_ratios = [10] 
* len(models) 321 | width_ratios += [1] 322 | gspec = fig.add_gridspec( 323 | 1, len(models) + 1, width_ratios=width_ratios, wspace=0.2 324 | ) 325 | splot = fig.add_subplot(gspec[0, i]) 326 | 327 | if i == (len(names) - 1): 328 | cbar_ax = fig.add_subplot(gspec[0, -1]) 329 | sns.heatmap( 330 | heatmap_list[i], 331 | ax=splot, 332 | cbar=True, 333 | cbar_ax=cbar_ax, 334 | square=True, 335 | vmin=glob_min, 336 | vmax=glob_max, 337 | ) 338 | [t.set_fontsize(labelsize) for t in cbar_ax.get_yticklabels()] 339 | 340 | # title to colorbar 341 | cbar_ax.set_title( 342 | "log(1/H)", fontsize=titlesize 343 | ) if exp else cbar_ax.set_title("H", fontsize=titlesize) 344 | else: 345 | sns.heatmap(heatmap_list[i], ax=splot, cbar=False, square=True) 346 | 347 | if i == 0: 348 | splot.set_ylabel("Layer", fontsize=labelsize) 349 | splot.set_yticklabels(np.arange(11, -1, -1), fontsize=labelsize) 350 | else: 351 | splot.set_yticklabels([]) 352 | 353 | splot.set_xticks(np.arange(heatmap_list[i].shape[-1]) + 0.5) 354 | splot.set_xticklabels(tokens, rotation=90, fontsize=labelsize) 355 | splot.set_title( 356 | f"{name}, p(1|x)={y_scores[i]:.3f}, H={final_entropies[i]:.3f}", 357 | fontsize=titlesize, 358 | ) 359 | 360 | #  print(len(tokens), len(axes[i].get_xticklabels())) 361 | #  print(entropies.shape) 362 | #  axes[i].set_xticks(np.arange(heatmap_list[i].shape[-1])) 363 | # axes[i].set_xticklabels(tokens, rotation=90) 364 | # axes[i].set_title(f"{name}, p(1|x)={y_scores[i]:.3f}, e={final_entropies[i]:.3f}") 365 | # axes[i].set_yticklabels([]) 366 | return fig 367 | 368 | if prompt: 369 | idx = "custom" 370 | text = prompt 371 | print("ID: ", idx, text) 372 | return process_text(idx, text) 373 | 374 | if idxs is None: 375 | #  pick random samples to show 376 | idxs = np.random.randint(len(data), size=n_samples) 377 | 378 | print(idxs) 379 | for idx in idxs: 380 | print("ID: ", idx, data[idx]) 381 | process_text(idx, data[idx]["text"]) 382 | 383 | 384 | def compare_sentences( 385 | 
model, 386 | tokenizer, 387 | sentences, 388 | max_sequence_length=120, 389 | remove_special=True, 390 | join=True, 391 | show_log=True, 392 | labelsize=15, 393 | titlesize=15, 394 | figsize=(12, 12), 395 | ): 396 | processed = list() 397 | 398 | with torch.no_grad(): 399 | 400 | for text in sentences: 401 | 402 | encoding = tokenizer( 403 | text, 404 | add_special_tokens=True, 405 | padding=True, 406 | truncation=True, 407 | max_length=max_sequence_length, 408 | return_tensors="pt", 409 | ) 410 | 411 | tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0]) 412 | 413 | if remove_special: 414 | tokens = tokens[1:-1] 415 | 416 | if join: 417 | # join subwords for better visualization 418 | new_tokens, pop_idxs, spans = join_subwords(tokens) 419 | #  print("Len new tokens", len(new_tokens)) 420 | tokens = new_tokens 421 | 422 | output = model(**encoding, output_attentions=True) 423 | logits = output["logits"] 424 | y_score = logits.softmax(-1)[0, 1] 425 | 426 | neg_entropy, entropies = compute_negative_entropy( 427 | output["attentions"], encoding["attention_mask"], return_values=True 428 | ) 429 | #  print("Entropies shape:", entropies[0].shape) 430 | 431 | #  print(name, "Final entropy: ", -neg_entropy.item()) 432 | entropies = -entropies[0] # take positive entropy 433 | 434 | # average subwords 435 | if join and len(spans) > 0: 436 | entropies = average_2d_over_spans(entropies, spans) 437 | 438 | entropies = entropies.mean(0).unsqueeze(0) 439 | 440 | if show_log: 441 | entropies = (1 / entropies).log() 442 | 443 | if remove_special: 444 | entropies = entropies[:, 1:-1] 445 | 446 | processed.append((tokens, y_score, entropies)) 447 | 448 | # print(processed) 449 | fig = plt.figure(constrained_layout=False, figsize=figsize) 450 | gspec = fig.add_gridspec(len(sentences) * 2, 1, hspace=2, wspace=5) 451 | 452 | vmin = torch.stack([p[2] for p in processed]).min().item() 453 | vmax = torch.stack([p[2] for p in processed]).max().item() 454 | print(vmin, vmax) 
455 | 456 | for i, (tokens, y_score, entropies) in enumerate(processed): 457 | splot = fig.add_subplot(gspec[i, 0]) 458 | 459 | #  cbar_ax = fig.add_subplot(gspec[:, 1]) 460 | sns.heatmap( 461 | entropies, 462 | ax=splot, 463 | cbar=False, 464 | # cbar_ax=cbar_ax, 465 | square=True, 466 | # cmap="Reds", 467 | annot=False, 468 | vmin=vmin, 469 | vmax=vmax, 470 | ) 471 | splot.set_xticks(np.arange(entropies.shape[-1]) + 0.5) 472 | splot.set_xticklabels(tokens, rotation=90, fontsize=labelsize) 473 | splot.set_yticklabels([]) 474 | splot.set_title( 475 | f"p(1|x)={y_score:.3f}", 476 | fontsize=titlesize, 477 | ) 478 | # [t.set_fontsize(labelsize) for t in cbar_ax.get_yticklabels()] 479 | 480 | # title to colorbar 481 | # cbar_ax.set_title( 482 | # "log(1/H)", fontsize=titlesize 483 | # ) if exp else cbar_ax.set_title("H", fontsize=titlesize) 484 | # fig.tight_layout() 485 | 486 | 487 | #### BIAS_ANALYSIS: parsing results and bias analysis 488 | 489 | 490 | def match_pattern_concat(main_dir, pattern, verbose=True): 491 | """Find all files that match a patter in main_dir. 
Then concatenate their content into a pandas df.""" 492 | versions = glob.glob(join(main_dir, pattern)) 493 | if verbose: 494 | print(f"Found {len(versions)} versions") 495 | 496 | res = list() 497 | for version in versions: 498 | df = pd.read_csv(version) 499 | filename = os.path.basename(version) 500 | seed = re.search(r"([0-9]{1})", filename).group(1) 501 | # print(filename, seed) 502 | df["seed"] = seed 503 | res.append(df) 504 | 505 | return pd.concat(res) 506 | 507 | 508 | def mean_std_across_subgroups(data: pd.DataFrame, metrics): 509 | print("Found the following models:", data.model.unique()) 510 | 511 | model_groups = data.groupby("model") 512 | means = list() 513 | stds = list() 514 | for model, group_df in model_groups: 515 | subgroup_groups = group_df.groupby("subgroup").mean() # across seeds 516 | for metric in metrics: 517 | means.append( 518 | { 519 | "metric": metric, 520 | "model_name": model, 521 | "mean_across_subgroups": subgroup_groups[metric].mean(), 522 | } 523 | ) 524 | stds.append( 525 | { 526 | "metric": metric, 527 | "model_name": model, 528 | "std_across_subgroups": subgroup_groups[metric].std(), 529 | } 530 | ) 531 | 532 | return pd.DataFrame(means), pd.DataFrame(stds) 533 | 534 | 535 | def bias_metrics_comparison_table(metrics, models): 536 | all_df = pd.concat(models) 537 | means, stds = mean_std_across_subgroups(all_df, metrics) 538 | return means.pivot_table(index="metric", columns="model_name").round(5) 539 | 540 | 541 | def read_scores(main_dir, model_name, dataset, reg_strength=None): 542 | if reg_strength: 543 | score_files = glob.glob( 544 | os.path.join(main_dir, f"scores_{dataset}_{model_name}-*-{reg_strength}.pt") 545 | ) 546 | else: 547 | score_files = glob.glob( 548 | os.path.join(main_dir, f"scores_{dataset}_{model_name}-*.pt") 549 | ) 550 | return [torch.load(f).numpy() for f in score_files] 551 | 552 | 553 | def compute_classification_metrics(main_dir, model_name, dataset, reg_strength=None): 554 | """Read scores and 
get classifcation metrics""" 555 | 556 | _, _, test = get_dataset_by_name(dataset) 557 | y_true = test.get_labels() 558 | 559 | scores = read_scores(main_dir, model_name, dataset, reg_strength) 560 | print(f"Found {len(scores)} scores files.") 561 | 562 | class_metrics = list() 563 | for y_pred in scores: 564 | class_metrics.append(evaluate_metrics(y_true, y_pred, th=0.5)) 565 | 566 | return pd.DataFrame(class_metrics) 567 | 568 | 569 | Results = namedtuple("Results", ["bmpi", "bm", "cm", "tm"]) 570 | 571 | 572 | def get_results( 573 | main_dir, model_name, bias_metrics_on=None, class_metrics_on=None, reg_strength=None 574 | ): 575 | """Gather all results available for a given model""" 576 | 577 | def attach_info(df): 578 | df["model_name"] = model_name 579 | df["bias_metrics_on"] = bias_metrics_on 580 | df["class_metrics_on"] = class_metrics_on 581 | df["reg_strength"] = reg_strength 582 | return df 583 | 584 | bias_terms_p, bias_metrics_p, class_metrics_p, test_metrics_p = ( 585 | None, 586 | None, 587 | None, 588 | None, 589 | ) 590 | if bias_metrics_on and reg_strength: 591 | bias_terms_p = f"bias_terms_{model_name}-*-{reg_strength}_{bias_metrics_on}.csv" 592 | bias_metrics_p = ( 593 | f"bias_metrics_{model_name}-*-{reg_strength}_{bias_metrics_on}.csv" 594 | ) 595 | class_metrics_p = ( 596 | f"class_metrics_{model_name}-*-{reg_strength}_{bias_metrics_on}.csv" 597 | ) 598 | 599 | if bias_metrics_on and not reg_strength: 600 | bias_terms_p = f"bias_terms_{model_name}-*_{bias_metrics_on}.csv" 601 | bias_metrics_p = f"bias_metrics_{model_name}-*_{bias_metrics_on}.csv" 602 | class_metrics_p = f"class_metrics_{model_name}-*_{bias_metrics_on}.csv" 603 | 604 | if class_metrics_on and reg_strength: 605 | test_metrics_p = ( 606 | f"class_metrics_{model_name}-*-{reg_strength}_{class_metrics_on}.csv" 607 | ) 608 | 609 | if class_metrics_on and not reg_strength: 610 | test_metrics_p = f"class_metrics_{model_name}-*_{class_metrics_on}.csv" 611 | 612 | bias_metrics_per_it, 
bias_metrics, class_metrics, test_metrics = ( 613 | None, 614 | None, 615 | None, 616 | None, 617 | ) 618 | 619 | # get bias metrics per identity term (x #seeds) 620 | if bias_terms_p: 621 | print("Get bias metrics per identity term") 622 | print(bias_terms_p) 623 | bias_metrics_per_it = match_pattern_concat(main_dir, bias_terms_p) 624 | bias_metrics_per_it = attach_info(bias_metrics_per_it) 625 | 626 | # get bias metrics 627 | if bias_metrics_p: 628 | try: 629 | print("Get bias metrics averaged") 630 | bias_metrics = match_pattern_concat(main_dir, bias_metrics_p) 631 | bias_metrics.columns = ["metric", "value", "seed"] 632 | bias_metrics = attach_info(bias_metrics) 633 | except: 634 | print(f"Files 'bias_metrics_{model_name}...' not found. Skipping...") 635 | 636 | # get classification metrics 637 | if class_metrics_p: 638 | try: 639 | print("Get classification metrics on 'bias_metrics_on' dataset") 640 | class_metrics = match_pattern_concat(main_dir, class_metrics_p) 641 | class_metrics.columns = ["metric", "value", "seed"] 642 | class_metrics = attach_info(class_metrics) 643 | except: 644 | print(f"Files 'class_metrics_{model_name}...' not found. 
Skipping...") 645 | 646 | if test_metrics_p: 647 | try: 648 | print("Get classification metrics on 'class_metrics_on' dataset") 649 | test_metrics = match_pattern_concat(main_dir, test_metrics_p) 650 | test_metrics.columns = ["metric", "value", "seed"] 651 | test_metrics = attach_info(test_metrics) 652 | test_metrics["metric"] = test_metrics.metric.apply(lambda x: f"test_{x}") 653 | 654 | #  Add summary_AUC_test 655 | bnsp = bias_metrics.loc[bias_metrics.metric == "bnsp_auc_mean"] 656 | bpsn = bias_metrics.loc[bias_metrics.metric == "bpsn_auc_mean"] 657 | subgroup = bias_metrics.loc[bias_metrics.metric == "subgroup_auc_mean"] 658 | test_AUC = test_metrics.loc[test_metrics.metric == "test_AUC"] 659 | 660 | #  import IPython 661 | #  IPython.embed() 662 | #  exit(-1) 663 | 664 | summary_AUC_test = ( 665 | bnsp.value.values 666 | + bpsn.value.values 667 | + subgroup.value.values 668 | + test_AUC.value.values 669 | ) / 4 670 | bias_metrics = bias_metrics.append( 671 | pd.DataFrame( 672 | { 673 | "metric": ["summary_AUC_test"] * test_AUC.shape[0], 674 | "value": summary_AUC_test, 675 | } 676 | ) 677 | ) 678 | 679 | except Exception as e: 680 | print( 681 | f"Files 'class_metrics_{model_name}-*_{class_metrics_on}...' not found. Skipping...", 682 | e, 683 | ) 684 | #  raise(e) 685 | 686 | return Results(bias_metrics_per_it, bias_metrics, class_metrics, test_metrics) 687 | 688 | 689 | def show_scatter_on_metric(data: list, metrics, style="box", h_pad=2, dpi=80): 690 | """Create one scatter plot per dataframe in data. 691 | 692 | Each dataframe should contain the per-IT bias metrics of several seeds for a single model. 
693 | """ 694 | if not isinstance(metrics, list): 695 | metrics = list(metrics) 696 | 697 | print(f"Comparing {len(data)} model(s) on {len(metrics)} metric(s)") 698 | 699 | fig, axes = plt.subplots( 700 | nrows=len(metrics), 701 | ncols=len(data), 702 | figsize=(18, 6 * len(metrics)), 703 | sharey=True, 704 | dpi=dpi, 705 | ) 706 | 707 | for i, metric in enumerate(metrics): 708 | for j, bias_df in enumerate(data): 709 | #  bias_df = bias_df.sort_values(metric) 710 | if style == "box": 711 | sns.boxplot(x="subgroup", y=metric, data=bias_df, ax=axes[i, j]) 712 | elif style == "scatter": 713 | sns.stripplot( 714 | x="subgroup", y=metric, data=bias_df, ax=axes[i, j], jitter=0, s=10 715 | ) 716 | axes[i, j].set_title(f"{bias_df.model_name.iloc[0]}") 717 | axes[i, j].set_xticklabels(axes[i, j].get_xticklabels(), rotation=90) 718 | 719 | fig.tight_layout(h_pad=h_pad) 720 | return fig 721 | 722 | 723 | def compare_metrics(data: list): 724 | """Create a single dataframe to compare the classification/bias metrics in data, averaged over seeds.""" 725 | metrics_by_model = dict() 726 | for class_df, name in data: 727 | metrics_by_model[name] = class_df.groupby("metric").mean().value 728 | return pd.DataFrame(metrics_by_model) 729 | 730 | 731 | def get_metrics_table( 732 | models, 733 | include_bias=True, 734 | include_class_eval=True, 735 | include_class_test=True, 736 | hide_power_mean=False, 737 | ): 738 | results = list() 739 | if include_bias: 740 | results.append(compare_metrics([(m[1].bm, m[0]) for m in models])) 741 | if include_class_eval: 742 | results.append(compare_metrics([(m[1].cm, m[0]) for m in models])) 743 | if include_class_test: 744 | results.append( 745 | compare_metrics([(m[1].tm, m[0]) for m in models if m[1].tm is not None]) 746 | ) 747 | 748 | print(len(results)) 749 | cat = pd.concat(results) 750 | if hide_power_mean: 751 | print("hiding results with 'power_mean'") 752 | cat = cat.loc[[v for v in cat.index if not v.endswith("power_mean")]] 753 | return 
cat 754 | 755 | 756 | def get_latex_tables(metric_table: pd.DataFrame): 757 | bias_metrics = { 758 | "subgroup_auc_mean": "subgroup_auc", 759 | "bnsp_auc_mean": "bnsp_auc", 760 | "bpsn_auc_mean": "bpsn_auc", 761 | # "positive_aeg_mean": "positive_aeg", 762 | # "negative_aeg_mean": "negative_aeg", 763 | "fped": "fped", 764 | "fned": "fned", 765 | } 766 | 767 | class_metrics_eval = { 768 | "F1_macro": "F1_macro (synt)", 769 | "F1_weighted": "F1_weighted (synt)", 770 | "F1_binary": "F1_binary (synt)", 771 | #  "acc": "Accuracy", 772 | #  "AUC": "AUC", 773 | } 774 | class_metrics_test = { 775 | "test_F1_macro": "F1_macro (test)", 776 | "test_F1_weighted": "F1_weighted (test)", 777 | "test_F1_binary": "F1_binary (test)", 778 | } 779 | 780 | models = { 781 | "vanilla": "BERT", 782 | "kebert_kITs": "KeBERT", 783 | "kebert_madITs": "KeBERT (madITs)", 784 | "kebert_kITsNW": "KeBERT (noW)", 785 | "kebert_kITsITA": "KeBERT", 786 | "JigCNN": "CNN", 787 | "JigCNN_deb": "CNN (debiased)", 788 | "Entropy_0.01": "EmBERT (early stop)", 789 | "Entropy_epoch19_0.01": "BERT+EAR", 790 | "BERT": "BERT", 791 | "BERT_EAR": "BERT+EAR", 792 | "BERT_bal": "BERT (class balance)", 793 | "BERT_EAR_bal": "BERT+EAR (class balance)", 794 | "BERT_SOC": "BERT_SOC" 795 | } 796 | 797 | # filter by models 798 | metric_table = metric_table[[m for m in models.keys() if m in metric_table.columns]] 799 | metric_table = metric_table.rename(columns=models) 800 | 801 | bias_df = metric_table.loc[bias_metrics.keys()].rename(index=bias_metrics) 802 | class_eval_df = metric_table.loc[class_metrics_eval.keys()].rename( 803 | index=class_metrics_eval 804 | ) 805 | 806 | #  bias and classification performances on the evaluation set (Madlibs, Miso synt, etc.) 807 | eval_set_df = pd.concat([bias_df, class_eval_df], axis=0).T 808 | #  class. performance on the test portion (Wiki, Miso, Miso raw, etc.) 
809 | 810 | test_df = ( 811 | metric_table.loc[class_metrics_test.keys()].rename(index=class_metrics_test).T 812 | ) 813 | 814 | return eval_set_df, test_df -------------------------------------------------------------------------------- /term_extraction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 5, 6 | "id": "fundamental-geneva", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from dataset import get_dataset_by_name\n", 11 | "import torch\n", 12 | "from torch.utils.data import DataLoader\n", 13 | "import utils\n", 14 | "from tqdm import tqdm\n", 15 | "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n", 16 | "from train_bert import compute_negative_entropy, LMForSequenceClassification\n", 17 | "from collections import defaultdict\n", 18 | "from typing import Dict\n", 19 | "import numpy as np\n", 20 | "import pandas as pd\n", 21 | "import seaborn as sns\n", 22 | "sns.set_theme(\"notebook\")\n", 23 | "import matplotlib.pyplot as plt\n", 24 | "from collections import Counter\n", 25 | "\n", 26 | "from string import punctuation\n", 27 | "from sklearn.feature_extraction.text import CountVectorizer" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "id": "swedish-harvest", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "def get_tokens_entropy(model, tokenizer, dataset, device=\"cpu\", join=True, batch_size=32):\n", 38 | " loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)\n", 39 | " entropy_scores = defaultdict(list)\n", 40 | " word_occ = defaultdict(int)\n", 41 | " fps = defaultdict(int)\n", 42 | " fns = defaultdict(int)\n", 43 | " \n", 44 | " entropy_fps = defaultdict(list)\n", 45 | " entropy_fns = defaultdict(list)\n", 46 | " \n", 47 | " num_positives = defaultdict(int)\n", 48 | " num_negatives = defaultdict(int)\n", 49 | " \n", 50 | " with 
torch.no_grad():\n", 51 | " for idx, batch in tqdm(enumerate(loader), total=len(loader)):\n", 52 | " encoding = tokenizer(\n", 53 | " batch[\"text\"],\n", 54 | " add_special_tokens=True,\n", 55 | " padding=True,\n", 56 | " truncation=True,\n", 57 | " max_length=120,\n", 58 | " return_tensors=\"pt\",\n", 59 | " ).to(device)\n", 60 | " \n", 61 | " output = model(**encoding, output_attentions=True)\n", 62 | " y_preds = output[\"logits\"].argmax(-1)\n", 63 | " y_trues = batch[\"label\"]\n", 64 | " \n", 65 | " neg_entropy, entropies = compute_negative_entropy(\n", 66 | " output[\"attentions\"], encoding[\"attention_mask\"], return_values=True\n", 67 | " )\n", 68 | " \n", 69 | " # process each batch\n", 70 | " for i_batch in range(y_preds.shape[0]):\n", 71 | " y_pred = y_preds[i_batch]\n", 72 | " y_true = y_trues[i_batch]\n", 73 | " curr_e = -entropies[i_batch]\n", 74 | " curr_e = torch.flipud(curr_e)\n", 75 | "\n", 76 | " input_ids = encoding[\"input_ids\"][i_batch]\n", 77 | " input_ids = input_ids[input_ids != 0]\n", 78 | " tokens = tokenizer.convert_ids_to_tokens(input_ids)\n", 79 | "\n", 80 | " # if remove_special:\n", 81 | " #  tokens = tokens[1:-1]\n", 82 | "\n", 83 | " if join:\n", 84 | " # join subwords for better visualization\n", 85 | " new_tokens, pop_idxs, spans = utils.join_subwords(tokens)\n", 86 | " #  print(\"Len new tokens\", len(new_tokens))\n", 87 | " tokens = new_tokens\n", 88 | " \n", 89 | " # average subwords\n", 90 | " if join and len(spans) > 0:\n", 91 | " curr_e = utils.average_2d_over_spans(curr_e, spans)\n", 92 | " \n", 93 | " curr_e = curr_e.mean(0).unsqueeze(0) \n", 94 | " assert curr_e.shape[1] == len(tokens)\n", 95 | "\n", 96 | " for i, t in enumerate(tokens):\n", 97 | " entr = curr_e[0, i].cpu().item()\n", 98 | " entropy_scores[t].append(entr)\n", 99 | " \n", 100 | " word_occ[t] += 1\n", 101 | " if y_true == 1:\n", 102 | " num_positives[t] += 1\n", 103 | " num_negatives[t] += 0\n", 104 | " else:\n", 105 | " num_negatives[t] += 1\n", 106 | 
" num_positives[t] += 0\n", 107 | " \n", 108 | " # false positives\n", 109 | " if y_true == 0 and y_pred == 1:\n", 110 | " fps[t] += 1\n", 111 | " fns[t] += 0\n", 112 | " entropy_fps[t].append(entr)\n", 113 | " \n", 114 | " # false negatives\n", 115 | " elif y_true == 1 and y_pred == 0:\n", 116 | " fns[t] += 1\n", 117 | " fps[t] += 0\n", 118 | " entropy_fns[t].append(entr)\n", 119 | " \n", 120 | " else:\n", 121 | " fns[t] += 0\n", 122 | " fps[t] += 0\n", 123 | "\n", 124 | " # return the average\n", 125 | " entropy_scores = {k: np.mean(v) for k, v in entropy_scores.items()}\n", 126 | " entropy_fps = {k: np.mean(v) for k, v in entropy_fps.items()}\n", 127 | " entropy_fns = {k: np.mean(v) for k, v in entropy_fns.items()}\n", 128 | " return entropy_scores, entropy_fps, entropy_fns, word_occ, fps, fns, num_positives, num_negatives\n", 129 | "\n", 130 | "\n", 131 | "def filter_stats(stats):\n", 132 | " len_m = stats[\"token\"].apply(len) > 3\n", 133 | " count_min = stats[\"count\"] > 10\n", 134 | " count_max = stats[\"count\"] < 3600\n", 135 | " punct = stats[\"token\"].isin(list(punctuation))\n", 136 | " \n", 137 | " return stats.loc[\n", 138 | " len_m &\n", 139 | " count_min &\n", 140 | " count_max &\n", 141 | " ~punct\n", 142 | " ]" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 3, 148 | "id": "detected-africa", 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "from nltk.tokenize import TweetTokenizer\n", 153 | " \n", 154 | "def twitter_tokenizer(doc): \n", 155 | " tokens = TweetTokenizer().tokenize(doc)\n", 156 | " \n", 157 | " tokens_new = list()\n", 158 | " for t in tokens:\n", 159 | " if t.startswith(\"@\") and len(t) > 1:\n", 160 | " tokens_new.append(\"USER\")\n", 161 | " \n", 162 | " elif len(t) < 3:\n", 163 | " continue\n", 164 | " \n", 165 | " else:\n", 166 | " tokens_new.append(t)\n", 167 | " \n", 168 | " return tokens_new\n", 169 | "\n", 170 | "\n", 171 | "def preprocess_collection(documents, min_df=0.05, 
max_df=0.95):\n", 172 | " cv = CountVectorizer(min_df=min_df, max_df=max_df, tokenizer=twitter_tokenizer)\n", 173 | " new_docs = cv.fit_transform(documents)\n", 174 | " new_docs = cv.inverse_transform(new_docs)\n", 175 | " new_corpus = [\" \".join(doc) for doc in new_docs]\n", 176 | " return cv, new_corpus" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "id": "interesting-elements", 182 | "metadata": {}, 183 | "source": [ 184 | "# Misogyny (EN)" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 18, 190 | "id": "confident-observation", 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "model = AutoModelForSequenceClassification.from_pretrained(\n", 195 | " \"BERT-0/\"\n", 196 | ").to(device)\n", 197 | "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n", 198 | "train, dev, test = get_dataset_by_name(\"miso\")\n", 199 | "cv, docs = preprocess_collection(train.get_texts(), 0.01, 0.95)\n", 200 | "train.texts = docs" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": 19, 206 | "id": "pressed-lesbian", 207 | "metadata": {}, 208 | "outputs": [ 209 | { 210 | "name": "stderr", 211 | "output_type": "stream", 212 | "text": [ 213 | "100%|██████████| 113/113 [00:05<00:00, 20.20it/s]\n" 214 | ] 215 | } 216 | ], 217 | "source": [ 218 | "entropy_dict, entropy_fps, entropy_fns, word_occ, fps, fns, num_positives, num_negatives = get_tokens_entropy(model, tokenizer, train, device)" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 20, 224 | "id": "exterior-lending", 225 | "metadata": {}, 226 | "outputs": [ 227 | { 228 | "data": { 229 | "text/html": [ 230 | "
\n", 231 | "\n", 244 | "\n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | "
entropyentropy_fpsentropy_fnscountfpsfnsnum_posnum_neg
[CLS]2.2377262.2714392.2300803600.0614.0256.01606.01994.0
user2.2507432.2798402.2626251235.0204.0115.0549.0686.0
every2.5086822.6430612.15322146.06.03.028.018.0
time2.5339322.5963562.37719282.015.04.036.046.0
there2.4302382.1009902.62123456.06.04.024.032.0
...........................
fuckin2.5206652.4446352.69934539.05.01.022.017.0
take2.4169072.3143402.28414265.07.03.032.033.0
did2.4473852.5616862.26970653.09.06.025.028.0
keep2.3896842.3035362.52092240.05.04.021.019.0
big2.3817302.4907981.97737939.02.02.030.09.0
\n", 382 | "

177 rows × 8 columns

\n", 383 | "
" 384 | ], 385 | "text/plain": [ 386 | " entropy entropy_fps entropy_fns count fps fns num_pos \\\n", 387 | "[CLS] 2.237726 2.271439 2.230080 3600.0 614.0 256.0 1606.0 \n", 388 | "user 2.250743 2.279840 2.262625 1235.0 204.0 115.0 549.0 \n", 389 | "every 2.508682 2.643061 2.153221 46.0 6.0 3.0 28.0 \n", 390 | "time 2.533932 2.596356 2.377192 82.0 15.0 4.0 36.0 \n", 391 | "there 2.430238 2.100990 2.621234 56.0 6.0 4.0 24.0 \n", 392 | "... ... ... ... ... ... ... ... \n", 393 | "fuckin 2.520665 2.444635 2.699345 39.0 5.0 1.0 22.0 \n", 394 | "take 2.416907 2.314340 2.284142 65.0 7.0 3.0 32.0 \n", 395 | "did 2.447385 2.561686 2.269706 53.0 9.0 6.0 25.0 \n", 396 | "keep 2.389684 2.303536 2.520922 40.0 5.0 4.0 21.0 \n", 397 | "big 2.381730 2.490798 1.977379 39.0 2.0 2.0 30.0 \n", 398 | "\n", 399 | " num_neg \n", 400 | "[CLS] 1994.0 \n", 401 | "user 686.0 \n", 402 | "every 18.0 \n", 403 | "time 46.0 \n", 404 | "there 32.0 \n", 405 | "... ... \n", 406 | "fuckin 17.0 \n", 407 | "take 33.0 \n", 408 | "did 28.0 \n", 409 | "keep 19.0 \n", 410 | "big 9.0 \n", 411 | "\n", 412 | "[177 rows x 8 columns]" 413 | ] 414 | }, 415 | "execution_count": 20, 416 | "metadata": {}, 417 | "output_type": "execute_result" 418 | } 419 | ], 420 | "source": [ 421 | "entropy_df = pd.DataFrame(\n", 422 | " [\n", 423 | " entropy_dict, entropy_fps, entropy_fns, word_occ, fps, fns, num_positives, num_negatives\n", 424 | " ], index=[\"entropy\", \"entropy_fps\", \"entropy_fns\", \"count\", \"fps\", \"fns\", \"num_pos\", \"num_neg\"]\n", 425 | ").T\n", 426 | "entropy_df" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": 21, 432 | "id": "contemporary-poison", 433 | "metadata": {}, 434 | "outputs": [], 435 | "source": [ 436 | "entropy_df.sort_values(\"entropy\", ascending=True).to_csv(\"latex/term_extraction/miso_eng.csv\")" 437 | ] 438 | }, 439 | { 440 | "cell_type": "markdown", 441 | "id": "alone-hepatitis", 442 | "metadata": {}, 443 | "source": [ 444 | "# Misogyny (IT)" 445 | 
] 446 | }, 447 | { 448 | "cell_type": "code", 449 | "execution_count": 22, 450 | "id": "violent-enough", 451 | "metadata": {}, 452 | "outputs": [], 453 | "source": [ 454 | "model = AutoModelForSequenceClassification.from_pretrained(\n", 455 | " \"BERT-0/\"\n", 456 | ").to(device)\n", 457 | "tokenizer = AutoTokenizer.from_pretrained(\"dbmdz/bert-base-italian-uncased\")\n", 458 | "train, dev, test = get_dataset_by_name(\"miso-ita-raw\")\n", 459 | "cv, docs = preprocess_collection(train.get_texts(), 0.01, 0.95)\n", 460 | "train.texts = docs" 461 | ] 462 | }, 463 | { 464 | "cell_type": "code", 465 | "execution_count": 23, 466 | "id": "equal-italian", 467 | "metadata": {}, 468 | "outputs": [ 469 | { 470 | "name": "stderr", 471 | "output_type": "stream", 472 | "text": [ 473 | "100%|██████████| 141/141 [00:06<00:00, 22.15it/s]\n" 474 | ] 475 | }, 476 | { 477 | "data": { 478 | "text/html": [ 479 | "
\n", 480 | "\n", 493 | "\n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | "
entropyentropy_fpsentropy_fnscountfpsfnsnum_posnum_neg
[CLS]2.0677162.1608252.0295604500.0344.0274.02103.02397.0
ora2.3789252.3544282.482992126.010.010.038.088.0
alle2.4770052.4510332.61949656.07.03.015.041.0
che2.2607262.2994742.2668601862.0157.0135.0777.01085.0
siete2.4286882.7108072.51671660.02.08.012.048.0
...........................
ciao2.2247122.3869892.06985357.02.01.046.011.0
tette2.2732211.9806422.68706245.07.05.028.017.0
quindi2.5222482.6513242.32179548.08.05.016.032.0
voi2.4842472.4749702.46301377.09.09.021.056.0
donne2.4569212.2958802.51731146.02.014.024.022.0
\n", 631 | "

162 rows × 8 columns

\n", 632 | "
" 633 | ], 634 | "text/plain": [ 635 | " entropy entropy_fps entropy_fns count fps fns num_pos \\\n", 636 | "[CLS] 2.067716 2.160825 2.029560 4500.0 344.0 274.0 2103.0 \n", 637 | "ora 2.378925 2.354428 2.482992 126.0 10.0 10.0 38.0 \n", 638 | "alle 2.477005 2.451033 2.619496 56.0 7.0 3.0 15.0 \n", 639 | "che 2.260726 2.299474 2.266860 1862.0 157.0 135.0 777.0 \n", 640 | "siete 2.428688 2.710807 2.516716 60.0 2.0 8.0 12.0 \n", 641 | "... ... ... ... ... ... ... ... \n", 642 | "ciao 2.224712 2.386989 2.069853 57.0 2.0 1.0 46.0 \n", 643 | "tette 2.273221 1.980642 2.687062 45.0 7.0 5.0 28.0 \n", 644 | "quindi 2.522248 2.651324 2.321795 48.0 8.0 5.0 16.0 \n", 645 | "voi 2.484247 2.474970 2.463013 77.0 9.0 9.0 21.0 \n", 646 | "donne 2.456921 2.295880 2.517311 46.0 2.0 14.0 24.0 \n", 647 | "\n", 648 | " num_neg \n", 649 | "[CLS] 2397.0 \n", 650 | "ora 88.0 \n", 651 | "alle 41.0 \n", 652 | "che 1085.0 \n", 653 | "siete 48.0 \n", 654 | "... ... \n", 655 | "ciao 11.0 \n", 656 | "tette 17.0 \n", 657 | "quindi 32.0 \n", 658 | "voi 56.0 \n", 659 | "donne 22.0 \n", 660 | "\n", 661 | "[162 rows x 8 columns]" 662 | ] 663 | }, 664 | "execution_count": 23, 665 | "metadata": {}, 666 | "output_type": "execute_result" 667 | } 668 | ], 669 | "source": [ 670 | "entropy_dict, entropy_fps, entropy_fns, word_occ, fps, fns, num_positives, num_negatives = get_tokens_entropy(model, tokenizer, train, device)\n", 671 | "entropy_df = pd.DataFrame(\n", 672 | " [\n", 673 | " entropy_dict, entropy_fps, entropy_fns, word_occ, fps, fns, num_positives, num_negatives\n", 674 | " ], index=[\"entropy\", \"entropy_fps\", \"entropy_fns\", \"count\", \"fps\", \"fns\", \"num_pos\", \"num_neg\"]\n", 675 | ").T\n", 676 | "entropy_df" 677 | ] 678 | }, 679 | { 680 | "cell_type": "code", 681 | "execution_count": 24, 682 | "id": "dirty-reservoir", 683 | "metadata": {}, 684 | "outputs": [], 685 | "source": [ 686 | "entropy_df.sort_values(\"entropy\", ascending=True).to_csv(\"latex/term_extraction/miso_ita.csv\")" 
687 | ] 688 | }, 689 | { 690 | "cell_type": "markdown", 691 | "id": "dirty-summer", 692 | "metadata": {}, 693 | "source": [ 694 | "# MlMA" 695 | ] 696 | }, 697 | { 698 | "cell_type": "code", 699 | "execution_count": 7, 700 | "id": "moderate-designation", 701 | "metadata": {}, 702 | "outputs": [ 703 | { 704 | "name": "stderr", 705 | "output_type": "stream", 706 | "text": [ 707 | "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']\n", 708 | "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 709 | "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 710 | "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n", 711 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 712 | ] 713 | } 714 | ], 715 | "source": [ 716 | "%capture\n", 717 | "\n", 718 | "device = \"cuda\"\n", 719 | "\n", 720 | "model = LMForSequenceClassification.load_from_checkpoint(\n", 721 | " \"BERT-0/\"\n", 722 | ").to(device)\n", 723 | "\n", 724 | "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n", 725 | "train, dev, test = 
get_dataset_by_name(\"mlma\")\n", 726 | "cv, docs = preprocess_collection(test.get_texts(), 0.01, 0.95)\n", 727 | "test.texts = docs" 728 | ] 729 | }, 730 | { 731 | "cell_type": "code", 732 | "execution_count": 8, 733 | "id": "enabling-macintosh", 734 | "metadata": {}, 735 | "outputs": [ 736 | { 737 | "name": "stderr", 738 | "output_type": "stream", 739 | "text": [ 740 | " 0%| | 0/18 [00:00\n", 749 | "\n", 762 | "\n", 763 | " \n", 764 | " \n", 765 | " \n", 766 | " \n", 767 | " \n", 768 | " \n", 769 | " \n", 770 | " \n", 771 | " \n", 772 | " \n", 773 | " \n", 774 | " \n", 775 | " \n", 776 | " \n", 777 | " \n", 778 | " \n", 779 | " \n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " \n", 791 | " \n", 792 | " \n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " \n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | "
entropyentropy_fpsentropy_fnscountfpsfnsnum_posnum_neg
[CLS]1.8170481.798705NaN565.066.00.0499.066.0
stupid2.1174141.918335NaN11.01.00.010.01.0
cunt1.7714951.970147NaN51.02.00.049.02.0
user1.8542051.821117NaN431.049.00.0382.049.0
[SEP]1.7880421.768903NaN565.066.00.0499.066.0
\n", 834 | "" 835 | ], 836 | "text/plain": [ 837 | " entropy entropy_fps entropy_fns count fps fns num_pos num_neg\n", 838 | "[CLS] 1.817048 1.798705 NaN 565.0 66.0 0.0 499.0 66.0\n", 839 | "stupid 2.117414 1.918335 NaN 11.0 1.0 0.0 10.0 1.0\n", 840 | "cunt 1.771495 1.970147 NaN 51.0 2.0 0.0 49.0 2.0\n", 841 | "user 1.854205 1.821117 NaN 431.0 49.0 0.0 382.0 49.0\n", 842 | "[SEP] 1.788042 1.768903 NaN 565.0 66.0 0.0 499.0 66.0" 843 | ] 844 | }, 845 | "execution_count": 8, 846 | "metadata": {}, 847 | "output_type": "execute_result" 848 | } 849 | ], 850 | "source": [ 851 | "entropy_dict, entropy_fps, entropy_fns, word_occ, fps, fns, num_positives, num_negatives = get_tokens_entropy(model, tokenizer, test, device)\n", 852 | "entropy_df = pd.DataFrame(\n", 853 | " [\n", 854 | " entropy_dict, entropy_fps, entropy_fns, word_occ, fps, fns, num_positives, num_negatives\n", 855 | " ], index=[\"entropy\", \"entropy_fps\", \"entropy_fns\", \"count\", \"fps\", \"fns\", \"num_pos\", \"num_neg\"]\n", 856 | ").T\n", 857 | "entropy_df.head()" 858 | ] 859 | }, 860 | { 861 | "cell_type": "code", 862 | "execution_count": 9, 863 | "id": "musical-chance", 864 | "metadata": {}, 865 | "outputs": [], 866 | "source": [ 867 | "entropy_df.sort_values(\"entropy\", ascending=True).to_csv(\"latex/term_extraction/mlma.csv\")" 868 | ] 869 | } 870 | ], 871 | "metadata": { 872 | "interpreter": { 873 | "hash": "be9233e59ab1f58386d6c7ed9ac3c1cc25c801ceabc31d17529d54705ed180c1" 874 | }, 875 | "kernelspec": { 876 | "display_name": "Python 3.8.5 64-bit ('unbias': conda)", 877 | "name": "python3" 878 | }, 879 | "language_info": { 880 | "codemirror_mode": { 881 | "name": "ipython", 882 | "version": 3 883 | }, 884 | "file_extension": ".py", 885 | "mimetype": "text/x-python", 886 | "name": "python", 887 | "nbconvert_exporter": "python", 888 | "pygments_lexer": "ipython3", 889 | "version": "3.8.5" 890 | } 891 | }, 892 | "nbformat": 4, 893 | "nbformat_minor": 5 894 | } 895 | 
-------------------------------------------------------------------------------- /data/mlma_dev.tsv: -------------------------------------------------------------------------------- 1 | HITId tweet sentiment directness annotator_sentiment target group hate 2 | 2278 america turned best land civilized people negro people less' earthqua @url offensive_normal indirect shock_disgust origin special_needs True 3 | 5352 @user i'm glad mongoloid face able bring momentary joy. offensive indirect anger_shock origin refugees True 4 | 1564 nigger stage?! offensive direct confusion_shock_disgust_anger_fear_indifference origin indian/hindu True 5 | 1233 stop retweeting telling people stupid guys name made retarded statement twitch con. idk @url offensive direct confusion_fear_shock_sadness disability special_needs True 6 | 4069 feel edgy saying retard every 2 seconds integrated deeply vocabulary can't help normal indirect confusion_disgust disability special_needs False 7 | 5612 post video sweating pub killing someone who's tryna play casual faggot sorry sorry offensive indirect anger_confusion_sadness_indifference_disgust sexual_orientation other True 8 | 984 invite someone home without knowing are. lettin illegal aliens vote ? @url normal indirect confusion_sadness_shock_indifference_disgust origin immigrants False 9 | 4785 worst flight life rough fuck cunt smelt rotten egg bastard behind snoring like pig @url abusive_offensive indirect anger_confusion_sadness_shock_disgust gender other True 10 | 2661 comedian & retarded @url offensive indirect confusion_sadness other other True 11 | 3135 @user retard abusive_offensive indirect shock_indifference disability women True 12 | 1644 i'm proud big brudda lil kevin. nigga came outta jail put pressure. features. 
whole album go retarded offensive indirect sadness_indifference disability other True 13 | 172 got head like dyke got head like @url offensive indirect sadness_disgust sexual_orientation women True 14 | 547 los quise fucking retarded offensive indirect anger_confusion_sadness_shock_disgust disability special_needs True 15 | 3812 okay aku mengaku diri aku hensem - random kiddo otw became faggot.... hahahaha ... wish level @url offensive indirect confusion_sadness_shock_indifference_disgust sexual_orientation individual True 16 | 2058 got banned stream reported writing ching chong. live society. @url offensive indirect anger_confusion_sadness_indifference_disgust origin left_wing_people True 17 | 168 selling dvds hahahaha shop gonna hate woman counter mongy cunt anyway deser @url fearful_abusive_offensive_hateful_normal indirect anger_fear_indifference_sadness_disgust disability women True 18 | 2330 irrespective people's views kris commons bit soup taker shud b hof front @url offensive indirect fear_shock_indifference_sadness_disgust origin special_needs True 19 | 5420 pruitt retarded. aight? fearful_abusive_hateful_disrespectful_normal indirect anger_fear origin special_needs True 20 | 4749 nigger @url hateful_normal indirect anger_confusion_sadness_indifference_disgust origin african_descent True 21 | 2733 send africa... soon @url offensive indirect anger_sadness disability jews True 22 | 3583 fuck dutch retard (armoo) collabed 6ix9ine... literally everyone hates offensive indirect anger_sadness other special_needs True 23 | 3000 @user hey riotgames got permabanned yesterday saying ching chong im chinese game @url offensive indirect confusion_disgust origin muslims True 24 | 5345 @user @user @user matter raghead killed ragheads. 
offensive direct fear_shock_indifference_sadness_disgust origin special_needs True 25 | 3854 would leave shithole country hate yet wave fucking flag country march awa @url offensive indirect sadness_shock origin immigrants True 26 | 2653 drinking morning tea song ... spic jumped @url offensive indirect anger_confusion_sadness_shock_indifference origin other True 27 | 3714 @user @user 5ft tunnel rat jungle gook shrimp dick broken english irl anime lookin chinaman offensive indirect anger_sadness_shock_indifference_disgust origin indian/hindu True 28 | 225 shut nigger @url offensive indirect anger_fear_shock_indifference_sadness origin individual True 29 | 4792 @user @user @user @user @user @user 'shithole countries' commen @url offensive indirect anger_confusion_sadness_shock_disgust origin other True 30 | 3370 hes another nigger smh @url normal indirect anger_confusion_sadness_shock_indifference origin refugees False 31 | 3487 hmmm...doesn't fit trump's profiling description mexican muslim black foreign shithole cou @url fearful_abusive indirect anger_confusion_fear_shock_indifference other other True 32 | 1439 @user @user @user @user @user actions speak louder mexican rapists shithole countries @url offensive indirect anger_disgust origin indian/hindu True 33 | 2410 @user oh god sound retarded thought offensive indirect sadness_disgust disability special_needs True 34 | 3174 @user satanic sellout faggot . offensive indirect anger_confusion_fear_sadness_disgust sexual_orientation gay True 35 | 5485 @user odd cunt offensive direct anger_confusion_shock_indifference_disgust gender women True 36 | 1408 white folks fights public last dangerously long. like 10 minutes. nigger like see fights break ou @url offensive indirect shock_disgust origin other True 37 | 1263 @user @user claimed leftist? asked vox would retract @url normal indirect indifference_disgust other other False 38 | 4660 @user reading article end i'm troubled choice let twat on. 
lord @url normal indirect anger_indifference other other False 39 | 5291 @user care illegal immigrants americans? offensive indirect anger_sadness origin other True 40 | 3222 @user @user ok retard offensive indirect sadness_indifference disability special_needs True 41 | 4272 @user i'd watch i'm school probably paying attention sorry doorbell ding dong pink fong wing wong ching chong offensive indirect anger_confusion_sadness_shock_disgust origin special_needs True 42 | 5570 i'm racist but... notice always naggers spic's pushing drugs every @url offensive_hateful indirect anger_disgust origin individual True 43 | 2975 he's even pretending talk specific shithole countries anymore. he's talking radical islam @url offensive indirect anger_sadness religion muslims True 44 | 5412 someone take pricks head off. cunt interviewed too. @url hateful direct anger_confusion_shock_indifference_disgust gender individual True 45 | 2488 dont want white people using word nigger maybe stop using time yourself??? hateful indirect anger_disgust origin women True 46 | 5315 feminazi female counterpart surely condone that's bes @url offensive indirect sadness_indifference other women True 47 | 4086 #dagga #marijuana #weed #cannabis #southafrica #africa @url offensive indirect anger_confusion_sadness_shock_indifference origin other True 48 | 3480 kinda gutted get hours extra sleep. yes know i'm hong kong moment. yes know i'm greedy twat. disrespectful_normal indirect confusion_disgust other other True 49 | 5540 @user longer united states. trumpistan. shithole country. nov 6 later.... 
women @url fearful_abusive_offensive_hateful_disrespectful indirect confusion_fear_shock_indifference_disgust origin other True 50 | 2864 #democrats defend illegal aliens support sanctuary cities create racial divide back fund pp tissue @url normal indirect anger_fear_shock_indifference_sadness other other False 51 | 4974 someone tolerant left calls hateful racist basically calls spic @url normal indirect sadness_disgust origin other False 52 | 5256 goes dumb ass negro @url hateful_disrespectful indirect anger_confusion_sadness_indifference_fear origin african_descent True 53 | 4757 @user @user today twat. champions league advertising boards behind it. delete offensive_disrespectful indirect indifference_disgust gender women True 54 | 1921 @user u dont negro offensive indirect shock_disgust origin special_needs True 55 | 1983 shithole countries @url offensive indirect sadness_indifference origin other True 56 | 558 soros owned company used american tax dollars build app help illegal aliens evade authorities. @url hateful indirect sadness_indifference origin immigrants True 57 | 464 sarah silverman gone total retard smh @url offensive indirect anger_confusion_fear_shock_disgust disability individual True 58 | 1192 @user ole fiend lame ass negro lol abusive_offensive indirect sadness_indifference origin other True 59 | 5047 @user @user mongoloid = idiot? think clever enough know offensive @url offensive indirect confusion_fear_shock_sadness_disgust disability other True 60 | 753 get wrong good music give retard money and... @url fearful_abusive_offensive_hateful_disrespectful indirect confusion_fear_shock_indifference_sadness disability other True 61 | 2512 @user see ur retarded ass doin shit like @url offensive indirect anger_sadness_shock_indifference_disgust disability special_needs True 62 | 5439 @user cupcake code retard retweeted alternate account commented main. 
offensive indirect anger_shock disability special_needs True 63 | 185 photographer - humana - 4 #portrait #rio_negro #argentina #strkng #photography #contemporary #bestof @url normal indirect anger_indifference other special_needs False 64 | 567 fucking faggot @url offensive direct anger_disgust sexual_orientation other True 65 | 630 @user calling faggot slut unnnhola go earn money sucking offensive indirect anger_disgust sexual_orientation gay True 66 | 232 @user @user en retard look ton insta offensive indirect anger_confusion_sadness_shock_indifference disability special_needs True 67 | 2809 @user boy child slave looking ass unfunny spic ass beaner ass ass titty sucking unfunny gook ass @url abusive_hateful indirect shock_disgust origin individual True 68 | 1137 @user game looks absolutely retarded even offensive indirect sadness_shock disability special_needs True 69 | 1392 @user camerroonians celebrities people surely want shout oh wanna look retarded hateful indirect anger_disgust origin special_needs True 70 | 1344 @user @user @user @user @user call chinaman there. see long last offensive indirect anger_fear_shock_indifference_sadness origin other True 71 | 1502 @user retard offensive indirect sadness_indifference disability special_needs True 72 | 5486 @user good job bantu ass bet retarded even know that. offensive indirect sadness_shock disability special_needs True 73 | 5191 twat better change mind.nn...and stay fuck away!!! @url hateful_disrespectful direct fear_disgust origin individual True 74 | 3638 @user @user tell #maga supporters stop mailing bombs makes shithole country. offensive indirect shock_indifference_disgust origin other True 75 | 5233 @user he's right retard. offensive indirect shock_disgust disability special_needs True 76 | 3522 lhhhh wtf i'm weak. girl retarded. 
lmfao offensive indirect anger_disgust disability special_needs True 77 | 704 @user i'm glad read silly spic like ever attain hold entry level jobs swing ham @url fearful_offensive indirect anger_disgust origin special_needs True 78 | 3672 @user @user overwatch dead retard offensive indirect fear_shock_indifference_sadness_disgust disability special_needs True 79 | 124 @user @user mongy american probably offensive indirect sadness_shock_indifference_disgust origin special_needs True 80 | 1690 @user shut face slap ass call dirty jew! im stoked! offensive indirect anger_confusion_shock_indifference_disgust religion jews True 81 | 4908 @user @user tattie howker though. aw soup taker patter absolutely tragic. offensive_normal indirect anger_confusion_sadness_shock_disgust other other True 82 | 406 @user @user shut retard offensive indirect confusion_shock_disgust_anger_fear_indifference disability special_needs True 83 | 748 @user u take aggression lady behind me? giant twat hateful direct anger_shock gender women True 84 | 3557 tell something... kind faggot runs around christmas sweater? offensive indirect sadness_indifference gender women True 85 | 1409 @user violent variety faggot offensive indirect confusion_indifference sexual_orientation gay True 86 | 2313 @user @user @user country run like zimbabwe done shithole african state normal indirect sadness_indifference origin other False 87 | 30 @user @user btw. allowed say shithole countries? also allowed say wes @url normal indirect anger_sadness origin other False 88 | 3385 @user please start fire luke campaign cause kinda retarded. rondo finishing @url offensive indirect anger_sadness_shock_indifference_disgust disability other True 89 | 4879 put phone down. nif noticed country mourning ; yet again. orange twat. 
n@url offensive indirect sadness_disgust origin individual True 90 | 1757 @user sorry hear experience mr punia :( try best keep aircraft spic span @url offensive indirect anger_indifference origin other True 91 | 3664 mongoloid_devo_saw vid(+heard song fink) 4the 1st time days ago @url mongol @url offensive_normal indirect indifference_disgust origin special_needs True 92 | 2990 @user shame born shithole country. barrister law studied hard protect r @url offensive indirect indifference_disgust origin other True 93 | 4918 @user @user @user da boss' da fucking cunt like. mean world relevan @url offensive_hateful indirect anger_sadness gender women True 94 | 3786 lmfao like dude y'all network. man retarded @url normal indirect anger_confusion_sadness_indifference_disgust disability special_needs False 95 | 2755 ive seriously never seen bigger faggot pewwi abusive_disrespectful indirect anger_disgust disability gay True 96 | 3312 retarded. know make people stupid. might want kill drop collec @url hateful indirect anger_fear_indifference_sadness_disgust disability special_needs True 97 | 1323 les migrants sont une richesse pour leurope! typhus abdominal offensive_normal indirect confusion_sadness_shock_indifference_disgust origin special_needs True 98 | 5265 pace game fn retarded #ps4sharenn@url @url offensive indirect indifference_disgust disability special_needs True 99 | 1054 drunk cunt offensive_hateful direct anger_confusion_sadness_indifference_fear origin other True 100 | 982 watching charles twitch like taking care retarded child. ever let go hand gonn @url fearful_abusive_offensive_hateful_normal indirect anger_confusion_sadness_shock_indifference origin special_needs True 101 | 149 proper cunt blokes likely lost friend colleague fuckers already diving absolute @url offensive indirect anger_confusion_fear_indifference_disgust other other True 102 | 957 @user @user @user @user @user called raghead first. 
@url offensive direct anger_disgust origin arabs True 103 | 1012 dirty stupid nigger typical @url offensive indirect anger_disgust origin african_descent True 104 | 4044 @user see funny side tbh lads twat offensive_disrespectful indirect sadness_disgust gender women True 105 | 1626 ta said wanted chinese food girl kept repating ching chong told racist said @url offensive indirect confusion_indifference origin special_needs True 106 | 1817 actually love mongy mackem face @url offensive indirect confusion_sadness_shock_indifference_disgust other other True 107 | 4305 @user abysmal.n hope finds way one trump'sshithole countries. oopps forgot loves saudi. hateful_disrespectful indirect sadness_indifference origin other True 108 | 936 @user cheap cunt offensive indirect anger_disgust gender women True 109 | 4560 @user okay white okay. offensive indirect confusion_indifference origin other True 110 | 1867 stream men asked opinion cunt @url fearful_hateful indirect fear_indifference gender other True 111 | 1620 @user @user oh cool. love dragon ball man! favorite sasuke ching chong shit. abusive_offensive indirect confusion_indifference origin individual True 112 | 698 iron lung blimp cat soup @url fearful_disrespectful indirect confusion_fear_indifference_sadness_disgust disability other True 113 | 1893 @user @user saying nigger fine said 2017 right? lmao offensive indirect shock_disgust origin indian/hindu True 114 | 2893 really hates club twat! @url offensive_hateful indirect anger_sadness other other True 115 | 3096 @user stop reading break ching chong guy start showing im like fuck offensive indirect anger_disgust origin muslims True 116 | 4912 brisbane roar fucking score cunt normal indirect sadness_indifference gender other False 117 | 1592 @user holy shit people retarded. halloween folks blood look like fucking cherry syrup. 
offensive indirect confusion_fear_indifference_sadness_disgust disability special_needs True 118 | 1687 @user @user @user @user @user @user i'm 26 retard normal indirect anger_disgust disability special_needs False 119 | 2049 @user @user retarded skism foh ty cash stoopid offensive_normal indirect sadness_disgust origin special_needs True 120 | 5020 @user hey faggot suck balls normal direct confusion_indifference origin women False 121 | 1769 @user lord wants us make world better place. can't country becomes shithole. offensive indirect anger_confusion_sadness_indifference_disgust origin other True 122 | 5091 always one cunt fucking hate fearful_abusive_offensive_hateful_disrespectful indirect fear_indifference gender women True 123 | 1167 @user stupid spic offensive direct anger_disgust origin individual True 124 | 4243 absolute waste oxygen twat is!! blame media hate @user @url offensive indirect sadness_disgust gender women True 125 | 9 @user @user even called country shithole people shithole people. sad. offensive indirect anger_fear_indifference_sadness_disgust origin special_needs True 126 | 5509 pls stop putting ur cancer twitter u retard offensive indirect anger_sadness disability special_needs True 127 | 1840 @user faggot needs get knocked hateful_normal direct fear_shock_indifference_sadness_disgust sexual_orientation gay True 128 | 1171 @user went full retard never go full retard @url normal indirect anger_disgust origin other False 129 | 4391 @user uk shithole. canada actually real nice country. uk bout fuck super hard brexit @url fearful_abusive_offensive_disrespectful_normal indirect anger_shock origin special_needs True 130 | 4837 @user multiple immigrant american citizens shithole countries too? offensive indirect confusion_indifference origin other True 131 | 4964 tots godwin _ hilarious mongoloid @url hateful direct anger_fear_shock_sadness_disgust origin individual True 132 | 2141 @user like make known n shit.. 
i'm better playing spic pot head who's getting hateful indirect shock_disgust origin individual True 133 | 4374 ref cunt offensive indirect anger_disgust other women True 134 | 602 i'm pussy boynwomen toyyysssni'm spasticnit's cybertasticnyou suck dicknonly aqw spicni @url offensive indirect anger_sadness origin other True 135 | 4005 @user :/ dorf im gonna lie sweden pretty fucked hell im shithole third world country disrespectful indirect anger_disgust origin other True 136 | 1155 rich hill...twat hateful_normal indirect confusion_disgust other other True 137 | 4168 @user incredible even resolution condemn illegal aliens voting? isnt al @url offensive indirect confusion_disgust origin other True 138 | 1794 mung dumb retard offensive_disrespectful indirect sadness_indifference disability special_needs True 139 | 3428 @user faggot still @url hateful direct confusion_disgust sexual_orientation individual True 140 | 5296 @user fuck want literal retard abusive_offensive_hateful_normal indirect anger_confusion_sadness_shock disability special_needs True 141 | 1974 @user @user @user another soup taker. disrespectful_normal indirect confusion_indifference other other True 142 | 4091 retarded since messi individually miles ahead ronaldo. watch game players c @url normal indirect anger_fear_indifference_sadness_disgust disability special_needs False 143 | 2611 @user @user they're parasites came shithole countries need stay shithole countries hateful indirect shock_disgust origin other True 144 | 3117 favorite colors pumpkin spice orange waffle house yellow vomit green nigger black offensive indirect confusion_sadness origin hispanics True 145 | 902 @user one shithole countries right? offensive_disrespectful indirect confusion_indifference origin other True 146 | 3415 shithole countries i'd never visit even paid millions. 
savage backward inhumane bastards @url abusive_offensive indirect confusion_indifference origin other True 147 | 4015 @user @user say nigger validated offensive indirect anger_disgust origin special_needs True 148 | 3832 nigga dang retarded! @url fearful_abusive indirect anger_shock origin special_needs True 149 | 1976 steve little cunt #thehauntingofhillhouse fearful_abusive_offensive_hateful_disrespectful indirect anger_sadness gender individual True 150 | 294 @user @user say ching chong time? offensive indirect confusion_sadness origin special_needs True 151 | 1745 @user @user faggot abusive_normal indirect anger_confusion_fear_shock_disgust sexual_orientation individual True 152 | 4710 @user @user well western nations stopped helping shithole countries people would starve death. problem solved. hateful indirect confusion_disgust origin other True 153 | 1469 blah blah blah call retard different different viewpoints. that's easy way cen @url offensive indirect confusion_indifference disability special_needs True 154 | 2483 think genuine consensus today's game fucking smash steven naismith's twat fearful_abusive_hateful_normal indirect sadness_disgust gender women True 155 | 5262 @user @user right nintendo japanese brand they're going full retard li @url offensive indirect anger_fear_indifference_sadness disability special_needs True 156 | 5104 kelly rowland really called freddie kruger faggot lmaoooooooooooo @url offensive indirect confusion_disgust gender gay True 157 | 4449 here's problem word banter. 
banter exchange remarks.nnif call twat call tw @url normal indirect indifference_disgust other other False 158 | 924 @user tories turned country third world shithole rich getting richer & poor @url offensive indirect anger_disgust other other True 159 | 1070 bts fans mongoloid @url offensive_normal indirect confusion_indifference disability other True 160 | 1208 damn retarded offensive indirect confusion_shock other other True 161 | 3708 @user really horrible cunt times abusive_offensive indirect anger_fear_shock_indifference_disgust gender women True 162 | 479 @user @user popcorn selfies stick i'd snap cunt offensive indirect confusion_fear_shock_indifference_disgust gender women True 163 | 3663 triggered chinaman beat gf trying mend fences michael avenatti shaun kingnnhow w @url offensive indirect fear_shock_indifference_sadness_disgust origin other True 164 | 3821 retard getting blocked lol offensive_hateful direct anger_confusion_sadness_shock_indifference disability special_needs True 165 | 191 @user skele boi far superior toothy mongoloid offensive indirect anger_confusion_shock_indifference_disgust origin individual True 166 | 4433 @user @user i'll second proper mongy fearful_offensive_disrespectful_hateful_normal indirect sadness_disgust disability special_needs True 167 | 2618 put pause sermon n locked church doors beat ur ass pastor tj boxes jesus nigger @url offensive indirect fear_disgust gender other True 168 | 1613 @user ching chong ching chong? 
offensive indirect anger_confusion_fear_indifference_disgust origin hispanics True 169 | 968 women twitter really retard pack @url offensive indirect anger_fear_shock_sadness_disgust gender women True 170 | 1425 @user that's saying u retard offensive indirect confusion_indifference disability special_needs True 171 | 2698 @user fucking twat offensive indirect anger_fear gender women True 172 | 1384 @user yeah mongy farmers offensive indirect anger_fear_shock_sadness_disgust gender other True 173 | 1495 @user yet hungry children homeless vets illegal aliens. offensive indirect confusion_indifference sexual_orientation special_needs True 174 | 3943 @user fucking retarded agreed helicopters dangerous twat. that's totally point. @url normal indirect anger_confusion_fear_sadness_disgust disability special_needs False 175 | 1737 today asylum six iranian refugees. they're free live lives peace r @url fearful_disrespectful indirect fear_disgust origin refugees True 176 | 3028 hush lil nigger @url offensive indirect anger_indifference origin other True 177 | 392 @user hi still thinking mongoloid ingrate yesterday dirty snp bastard called @url offensive indirect indifference_disgust disability other True 178 | 2073 trump wants much nigger navy: heres much itll cost offensive indirect sadness_disgust origin special_needs True 179 | 429 @user obviously cunt even know teams history offensive indirect confusion_indifference gender women True 180 | 4727 y'all still use retarded describe something ridiculous stupid want know deserve @url normal indirect anger_indifference other other False 181 | 3733 niggas mad cause u retarded offensive indirect confusion_fear_indifference_sadness_disgust disability special_needs True 182 | 1956 @user @user much country either! soon another third world islamic shithole! 
offensive indirect anger_fear_indifference_sadness_disgust origin other True 183 | 2309 someone remove mongy jude corrie makes wanna smash telly normal indirect sadness_indifference disability other False 184 | 3791 name spic now? lol copy let one many names society got you. wanna cry racism. @url normal indirect anger_sadness_shock_indifference_disgust other other False 185 | 4351 'humanitarian emergency': doctors demand refugees moved nauru | sbs news @url via @user fearful_normal indirect anger_fear_shock_indifference_disgust origin refugees True 186 | 4296 brain morning: potato potato ching chong tomato hateful_disrespectful indirect anger_confusion_fear_shock_sadness origin other True 187 | 3568 point possibly retarded . offensive indirect sadness_indifference disability special_needs True 188 | 2204 low iq: spic?nnavg iq: false flagnnhigh iq: paint hair head? @url fearful_abusive_offensive_hateful_disrespectful indirect confusion_sadness_shock_indifference_disgust origin special_needs True 189 | 823 thread called spic kissed cousin @url offensive indirect sadness_disgust origin special_needs True 190 | 389 ..why illegal aliens allowed mobile phones? @url hateful indirect anger_indifference origin other True 191 | 2333 @user venmo wolffis $10 sucking dick faggot @user #2birds1stone abusive_hateful direct anger_disgust sexual_orientation other True 192 | 1749 @user look macrons eurafrica union mediterranean. wants millions more. offensive_normal indirect sadness_indifference origin other True 193 | 50 @user wtf gym king phone little twat offensive direct anger_disgust origin women True 194 | 3191 oh wow definitely know tried fuck bitch outed me. happy tea time cunt. (; normal indirect anger_confusion_shock_indifference_disgust other women False 195 | 304 stupid greasy spic ayoze starts next week i'll commit. 
offensive indirect shock_indifference origin hispanics True 196 | 2381 @user yeah probably acts like right retard basically wanted become meme offensive indirect sadness_disgust disability special_needs True 197 | 1062 friend's boyfriend asked ching chong real word...no dipshit stereotype wtf ????? normal indirect anger_sadness_shock_disgust origin women False 198 | 1256 @user con play ching chong cha determine surname child gets. wa @url offensive indirect indifference_disgust origin other True 199 | 1606 @user @user almost mueller time shithole (term used referring 3rd world countries)!!! offensive_normal indirect confusion_indifference origin women True 200 | 2574 banter excuse 14-year old boy uses twat class. @url offensive indirect anger_confusion_fear_shock_sadness gender special_needs True 201 | 1808 @user @user @user @user @user loved called faggot every second funny abusive_offensive indirect confusion_fear_indifference_sadness_disgust sexual_orientation left_wing_people True 202 | 2380 @user @user @user @user @user @user saw night retarded haha rul @url offensive indirect anger_sadness_shock_indifference_disgust disability special_needs True 203 | 2872 @user coming thru normal indirect fear_indifference other other False 204 | 665 @url hear sick leftists aka democrats want bring migrants/illegal @url offensive indirect anger_sadness other special_needs True 205 | 3319 could go one #shithole countries right?nyeah alabama. rikers middle fina @url normal indirect anger_disgust origin special_needs False 206 | 3346 retarded make past brain censors fearful_abusive_offensive_hateful_disrespectful indirect sadness_shock origin special_needs True 207 | 4270 @user made complete twat today you? 
fearful_abusive_disrespectful_hateful_normal indirect anger_disgust gender individual True 208 | 2610 @user @user russia shithole country offensive indirect shock_disgust origin other True 209 | 3014 get points faggot ahahah abusive_disrespectful indirect anger_sadness sexual_orientation refugees True 210 | 3979 @user @user mongoloid celebrate win whiskey. someone tell allah doesnt @url offensive indirect anger_sadness_shock_indifference_disgust religion special_needs True 211 | 3524 okay white @url normal indirect sadness_disgust origin other False 212 | 2191 ya sure - africa's full 'shithole countries'... @url offensive indirect anger_fear_shock_sadness_disgust origin left_wing_people True 213 | 1124 public dinosaurs private dinosaurs. faggot. hateful direct indifference_disgust sexual_orientation gay True 214 | 2126 @user people rather called spic hungry prosecuted. offensive indirect confusion_sadness_shock_indifference_disgust origin other True 215 | 3406 @user oooh noooo high pitch ching chong songs offensive indirect shock_indifference origin asians True 216 | 4273 @user except ur massive twat cut said soju u ruined whole vibe offensive indirect anger_confusion_sadness_shock_disgust origin women True 217 | 2734 @user it;s different saying dur dur durrrrr mock mentally challenged people 'ching chong ch @url offensive indirect anger_confusion_fear_shock_sadness disability special_needs True 218 | 2303 @user @user ching chong wong wing bong? offensive indirect anger_confusion_sadness_indifference_disgust origin other True 219 | 4830 annoying twat face? k thanks. abusive_normal indirect confusion_indifference origin women True 220 | 241 i'm looking relationship means i'm trying hoe around ? 
retard juice strong offensive indirect anger_confusion_sadness_shock_indifference disability special_needs True 221 | 3972 alls said faggot @url fearful_offensive indirect sadness_disgust sexual_orientation women True 222 | 642 trump incites hate daily blames #pittsburgh victims guns. shithole! ...trump reacting @url offensive indirect shock_disgust other other True 223 | 412 ryan cunt normal indirect sadness_shock other individual False 224 | 407 men fucking retarded offensive indirect anger_indifference gender other True 225 | 4105 fuck that. feminazis outraged ...bring say! @url hateful indirect fear_shock_indifference_sadness_disgust gender other True 226 | 3120 @user c bomb put thrown around lot in: cunt!...what cunts @url offensive indirect confusion_sadness gender women True 227 | 2265 @user literally faggot offensive_hateful indirect anger_disgust sexual_orientation women True 228 | 4863 fucking cunt!!!!! @url hateful indirect shock_disgust gender women True 229 | 381 @user @user @user @user @user definately receiving twitter top twat th @url offensive indirect anger_indifference other other True 230 | 2518 aliens actually retarded offensive indirect sadness_disgust disability special_needs True 231 | 3304 #staranurag @user @url normal indirect sadness_disgust other other False 232 | 1123 @user prayers @user nigger time offensive indirect anger_confusion_sadness_indifference_disgust origin other True 233 | 774 means el salvador can't create well-run society. 
course people want leave shithole countrie @url offensive indirect indifference_disgust origin other True 234 | 5427 def used mumble nigger breath dak missed (so prob 10x day) @url offensive indirect anger_confusion_fear_shock_sadness origin other True 235 | 5107 @user @user @user @user dont ask faggot offensive indirect fear_shock sexual_orientation individual True 236 | 346 @user @user @user follow dirty jew offensive direct sadness_disgust religion jews True 237 | 5103 @user ching chong ching chang offensive_disrespectful indirect confusion_disgust origin other True 238 | 877 #nyc #newyork #newyorkcity #art #artist #gay #faggot #pouredpaint #acrylicpainting #acrylicpouring #queer #dark @user n @url offensive indirect indifference_disgust sexual_orientation other True 239 | 2988 sabrina claudio would call nigger face stop listening music offensive indirect anger_indifference other left_wing_people True 240 | 5154 retard sat next started playing eminem's new album fearful_abusive_disrespectful_hateful_normal indirect confusion_disgust disability special_needs True 241 | 4741 @user @user @user @user many prems busquets leicester?nnya plastic mongy offensive indirect sadness_disgust origin special_needs True 242 | 1283 @user @user ching chong ding dong abusive_normal indirect anger_confusion_fear_indifference_disgust origin muslims True 243 | 492 okay white celebrate australia day #auspol normal indirect shock_indifference origin other False 244 | 1665 put gun cunt. fucking brave eh tiger? that's get @url offensive direct confusion_fear_shock_indifference_disgust other women True 245 | 2292 going try come faggot? ged education helpful. i'll read quick funny offensive_hateful indirect anger_disgust sexual_orientation individual True 246 | 2127 @user glasses steal frames n real noho twat offensive indirect confusion_sadness_shock_indifference_disgust gender women True 247 | 4804 @user @user shes feminazi linebacker like you? 
offensive direct anger_confusion_indifference_disgust gender women True 248 | 5495 #lemotdujour retard abusive_offensive_hateful_disrespectful_normal indirect sadness_disgust disability special_needs True 249 | 4088 sober truth ideology see retard defend it. offensive direct sadness_indifference disability special_needs True 250 | 23 @user que ce ne soit pas des burundais refugies alerte car leur corps sont visibles loeil nu hateful_normal indirect shock_disgust origin other True 251 | 3643 clearly he's retarded @url offensive indirect sadness_indifference disability special_needs True 252 | 4907 @user @user @user also kennewick man washington state. time period. definitely mongoloid. offensive indirect anger_confusion_sadness_shock_disgust disability special_needs True 253 | 676 @user hey thought hated called retard??? offensive_hateful direct indifference_disgust disability special_needs True 254 | 3554 @user chicoms love brainwash minorities. soros chicom mongol. offensive indirect confusion_sadness_shock_indifference_disgust origin individual True 255 | 49 doesnt right? negro default fearful_offensive_disrespectful_hateful_normal indirect confusion_sadness_shock_indifference_disgust origin african_descent True 256 | 5290 @user idiots realize mass migrations bring disease. offensive indirect anger_sadness origin individual True 257 | 735 discernible skill give visa here. sorry want people shithole countries @url offensive indirect anger_confusion_fear_shock_disgust origin other True 258 | 3180 @user epitome cunt. 
offensive direct indifference_disgust gender women True 259 | 2687 @user one days one illegal aliens r going bring deadly disease kill half c @url offensive indirect anger_confusion_sadness_shock_disgust origin immigrants True 260 | 1440 anyways intellectual creative nigger jumped joon today goodnight offensive direct anger_shock origin african_descent True 261 | 4137 ching chong pyro agree @url normal indirect anger_confusion_sadness_indifference other other False 262 | 1308 @user get one person care shithole country please thanks. wanna l @url normal indirect confusion_disgust origin other False 263 | 2497 wanna know twat shot #synagogue? read #antisemitic social media pages like @user r @url offensive indirect anger_confusion_sadness_shock_disgust origin left_wing_people True 264 | 2511 @user dick van dyke yes thought hateful_disrespectful indirect sadness_disgust sexual_orientation other True 265 | 1312 good news i'm afraid feminazi man hater f @url offensive indirect confusion_disgust gender women True 266 | 1353 faggot ass dodgers fearful_abusive_hateful_disrespectful_normal indirect fear_shock_indifference_sadness_disgust sexual_orientation individual True 267 | 5014 @user jen reached full blown retard. stop it. offensive indirect sadness_indifference disability special_needs True 268 | 5474 @user oh poor deluded one. matter much drink kool aid @user always spic. offensive direct fear_disgust origin individual True 269 | 3658 @user stayed in. dont relate feminazi useful idiots. enjoy living free country gr @url fearful_abusive_offensive_hateful_normal indirect fear_shock_indifference_sadness_disgust gender women True 270 | 4787 actual fuck people still like fucking mongoloid??? @url offensive indirect anger_confusion_sadness_shock_disgust origin special_needs True 271 | 4452 @user @user sound retarded?? 
offensive indirect anger_disgust disability special_needs True 272 | 1570 @user @user @user yes feel like right wingers would would leftist @url offensive indirect confusion_indifference other special_needs True 273 | 1544 @user lmaooo cunt abusive_hateful indirect anger_sadness_shock_indifference_disgust other women True 274 | 1242 retarded bastard actually using bomb threats sent people oppose suggest conspiracy @url fearful_offensive indirect confusion_fear_shock_sadness disability special_needs True 275 | 3511 ching chong day long offensive indirect sadness_indifference origin special_needs True 276 | 5148 sounds faggot me. @url offensive_hateful indirect indifference_disgust sexual_orientation gay True 277 | 383 @user @user @user @user sound retard democrat. literally files moron. par @url offensive indirect anger_indifference disability special_needs True 278 | 3400 @user < often difficult remember parody account! trump really much cunt abusive_offensive direct shock_indifference other individual True 279 | 4300 added video @user playlist @url kid goes full retard drops coin sewer offensive indirect sadness_indifference disability special_needs True 280 | 1849 hope twat ends jail. pal kenwright tried ruin everton... n @url offensive indirect indifference_disgust gender individual True 281 | 3797 @user think retarded draw drew fun asked opinion missy way past bed time ma'am offensive indirect confusion_fear_shock_sadness origin special_needs True 282 | 3841 @user @user dealing 3rd world shithole countries? 
offensive indirect anger_sadness_shock_indifference_disgust origin immigrants True 283 | 801 @user never zyzz bruh fuckin sick cunt bruh offensive indirect shock_indifference origin women True 284 | 3871 @user cunt friend offensive indirect anger_confusion_sadness_shock_disgust origin women True 285 | 5379 get onit last night - correct ndid make cunt - yesndo care - fucknfuck aberdeen abusive_offensive indirect sadness_disgust other other True 286 | 5450 hahayes everyone delete post apologize jason said so. dem leftist pos. oh fol @url offensive_disrespectful indirect sadness_indifference other other True 287 | 354 @user fuckin retard abusive_hateful direct anger_sadness_shock_indifference_disgust disability special_needs True 288 | 3504 @user @user yelling ching chong racist. period. offensive indirect fear_sadness origin women True 289 | 2586 @user '' putted '' typical mongy hobbo spelling that. offensive indirect indifference_disgust disability special_needs True 290 | 2339 @user @user @user @user sorry hear that. go back country use lear @url abusive_offensive_disrespectful_normal indirect sadness_shock_indifference_disgust origin individual True 291 | 4692 @user @user he's progressive/neo-marxist. radical leftist wants open borders & mass-migratio @url hateful indirect sadness_disgust other left_wing_people True 292 | 1004 @user whatever spic offensive indirect confusion_sadness_shock_indifference_disgust origin refugees True 293 | 5035 thats negro right offensive indirect anger_indifference origin women True 294 | 1651 @user yes should. doubt she'll it. another self absorbed feminazi attention whore. 
@url offensive indirect anger_confusion_sadness_shock_indifference gender women True 295 | 3077 shut nigger @url hateful direct anger_fear_shock_sadness_disgust origin other True 296 | 1558 verstappen bad start crashes twat ricciardo 1st vettel 2nd @url offensive indirect sadness_indifference gender special_needs True 297 | 3998 @user africa shithole countries!nthat someone's opinion.nni think people thoughtful @url offensive indirect anger_disgust origin special_needs True 298 | 2851 mid 40's smelly psy blk mom trying send retard blk kid goons car nights sleepni camera yet nnreport offensive_hateful indirect sadness_shock origin special_needs True 299 | 1873 picture called faggot said homo @url abusive_hateful indirect anger_fear_shock_sadness_disgust origin individual True 300 | 4229 ever call us we'll come. never taking account little cunt are. we'll save getting @url offensive indirect sadness_disgust origin women True 301 | 861 dis dyke finna get brutalized fearful_abusive_hateful_normal direct indifference_disgust other other True 302 | 5514 @user @user @user @user @user @user @user half breed mongoloid @url offensive indirect sadness_disgust origin indian/hindu True 303 | 5458 negra novinha cavalgando negro pauzudo fearful_abusive_offensive_hateful_disrespectful indirect fear_disgust origin other True 304 | 974 @user i'm surprised dems suggested make shithole countries us colonies @url offensive indirect fear_shock_indifference_sadness_disgust origin other True 305 | 2669 @user @user @user @user @user twat invest though. interest free @url offensive indirect confusion_sadness gender women True 306 | 4103 @user one retard frank ocean raf offensive_disrespectful indirect shock_indifference disability special_needs True 307 | 4781 @user @user raghead problem england. hows terrorist attacks? 
offensive_hateful indirect anger_confusion_sadness_shock_disgust origin other True 308 | 4855 bruh negro chief keef really recorded pump fisting track album hahahahaha offensive indirect anger_confusion_shock_indifference_disgust origin individual True 309 | 387 msnbc 7 bombs sent democrats conspiracy mongoloid right dirt bags. bombs tack therefore fb @url offensive indirect anger_sadness other special_needs True 310 | 1503 fucking retarded think ok flirt guys girlfriends mat @url fearful_offensive_hateful_disrespectful_normal indirect confusion_disgust disability special_needs True 311 | 255 keep mind.. call faggot spanish mean can't still understand you. offensive indirect confusion_fear_shock_indifference_sadness origin special_needs True 312 | 3447 thank god lost virginity nigger offensive indirect anger_disgust origin immigrants True 313 | 5501 @user can't call singing! twat would get four no's performance. got likeability. normal indirect anger_confusion_shock_indifference_disgust gender women False 314 | 832 @user said chinaman hateful direct anger_shock origin muslims True 315 | 5126 i'm glad talked acting retarded offensive indirect indifference_disgust disability special_needs True 316 | 1067 @user paris. november nigger hateful_normal direct indifference_disgust origin other True 317 | 3070 @user asians ching chong eyes fearful_hateful indirect anger_confusion_fear_indifference_disgust origin muslims True 318 | 4817 ching chong 100 offensive_normal indirect indifference_disgust origin muslims True 319 | 1567 antis: bts look like girls ching chong (plus racist shit) offensive direct confusion_sadness_shock_indifference_disgust origin other True 320 | 3350 @user @user unlike previous regime ran punjab like mongol halaku khan.... offensive indirect anger_fear_indifference_sadness_disgust origin other True 321 | 3493 terrorist week still: right-wing. white-guy. 
ever was.nnso take restaurant whining leftist mob @url fearful_abusive_offensive_hateful_disrespectful indirect sadness_indifference origin other True 322 | 1113 @user @user @user well say wants immigrants blue eyes call us shithole countr @url normal indirect confusion_sadness_shock_indifference_disgust origin other False 323 | 3506 wifi acting retarded .. see y'all tomorrow comcast offensive indirect anger_sadness_shock_indifference_disgust other special_needs True 324 | 2288 f-you dipshit. atleast apologizes. isnt enough low self esteem cunt like younnagain. i'm sorry cat. @url hateful indirect anger_disgust other other True 325 | 4620 cut hair short far families called dyke pelona soldier walter mercado normal indirect anger_sadness_shock_indifference_disgust origin gay False 326 | 5610 @user afrikaans dumb cunt speak it? abusive_disrespectful indirect anger_indifference origin women True 327 | 1912 @user nyou must always think country way swedish shithole. offensive indirect indifference_disgust origin other True 328 | 2832 yeah thats going help! talk retarded. want swarm violent morons storm j @url offensive indirect shock_indifference disability special_needs True 329 | 4526 @user @user @user @user @user @user aqua favourite spic @url offensive indirect anger_shock origin other True 330 | 155 @user fuck u spic offensive indirect anger_disgust origin hispanics True 331 | 840 golden retrievers smpr ta parse mongol @url offensive indirect confusion_fear_shock_indifference_disgust gender special_needs True 332 | 385 @user yup pretty much. i'm leaving shithole country. offensive indirect confusion_shock_disgust_anger_fear_indifference origin other True 333 | 4490 delay delaying must retarded. 
_ di pasir berbisik gunung bromo @url offensive indirect anger_confusion_sadness_shock_indifference origin special_needs True 334 | 2151 @user dude still twat waffle offensive indirect shock_indifference gender women True 335 | 1898 @user retard offensive indirect shock_disgust disability special_needs True 336 | 3528 @user @user refused entry countries made stay shithole helped dig. offensive indirect anger_confusion_sadness_indifference_disgust origin other True 337 | 86 everything better. cunt alive.. offensive indirect anger_sadness gender women True 338 | 3987 @user names say fucking retard dont waste time looser instead praying leicester city helicopter crash offensive indirect anger_confusion_sadness_shock_disgust disability special_needs True 339 | 308 @user delete feature retarded anyways... would someone delete something payed for? like @url offensive_hateful indirect anger_confusion_shock_indifference_disgust disability special_needs True 340 | 3381 yeah can't careful. know people shithole countries mi @url offensive indirect anger_sadness origin special_needs True 341 | 1959 dev hynes really snapped negro swan offensive indirect anger_confusion_fear_indifference_disgust origin special_needs True 342 | 4906 @user long retard offensive indirect anger_confusion_sadness_shock_disgust disability special_needs True 343 | 3390 @user @user @user grown ups show faces use thea names.mongoloid offensive indirect indifference_disgust disability special_needs True 344 | 895 @user @user jesus teacher must mongy offensive indirect fear_shock_indifference_sadness_disgust disability special_needs True 345 | 3093 time set twit twat normal indirect indifference_disgust other other False 346 | 1310 oliver's got full camel toe poor fucker offensive indirect confusion_fear_shock_sadness_disgust origin other True 347 | 2701 holy shit. models actually tell forget breath? 
hot chicks really retarded #luxurycard disrespectful_normal indirect confusion_disgust disability women True 348 | 2752 gotta mongoloid ask question. trolling. either way. ain't chief @url offensive indirect anger_sadness origin special_needs True 349 | 1783 @user yep. got place spic span thought well last naptime ends. nnand right. normal indirect anger_confusion_fear_shock_indifference origin other False 350 | 2253 @user @user @user donald trump retard-.. people study down's syndrome saying yes @url offensive_hateful indirect anger_confusion_fear_shock_disgust other special_needs True 351 | 2846 @user nigger talk me. offensive direct fear_indifference origin individual True 352 | 5308 @user many boards bored mongol could hoard??... nwait rhetorical? offensive indirect indifference_disgust origin other True 353 | 5588 can't believe ching chong 10 points! offensive indirect indifference_disgust origin muslims True 354 | 1092 @user ill dress black face chinaman savage injun whatever want to! america horseface! normal indirect sadness_disgust origin special_needs False 355 | 3656 hi everyone letting know okay white #auspol normal indirect anger_confusion_sadness_shock_indifference origin other False 356 | 173 literally feel retarded anytime use left hand...i looking kinda stupid offensive indirect anger_sadness disability special_needs True 357 | 746 wow... becoming shithole country. @url offensive indirect confusion_fear_shock_sadness_disgust origin other True 358 | 3116 hate @user #spic #beaner normal indirect indifference_disgust origin hispanics False 359 | 3158 sissyslutfaggotbrendab: expose sissy #faggot #sissyexposure free to... @url abusive_normal indirect anger_confusion_fear_sadness_disgust other other True 360 | 5194 walk around attacking ppl ableist call someone mongoloid. 
ofc choose word w @url offensive indirect fear_disgust origin other True 361 | 1903 christians retarded @url normal indirect indifference_disgust religion other False 362 | 4700 @user usually dress cia agent halloween wanted dress chinaman would @url offensive indirect anger_fear_shock_indifference_sadness origin other True 363 | 2816 fuck turn weird cunt tell daughter go make friends mongy bitch @url offensive indirect shock_indifference other women True 364 | 836 @user thanks retard offensive indirect indifference_disgust disability special_needs True 365 | 481 @user shit sooo retarded offensive indirect shock_disgust disability special_needs True 366 | 3704 @user they're gonna tell shut spic hateful indirect anger_fear_shock_indifference_sadness origin other True 367 | 4204 @user dumbass nigger stupid ass nigger calling cracker well better nigger bett @url offensive direct shock_disgust origin special_needs True 368 | 612 talk according cristiano challenge:nkoke:you're faggotncristiano:i'm billionaire @url offensive_disrespectful indirect anger_confusion_shock_indifference_disgust sexual_orientation other True 369 | 3148 @user @user @user still need go back shithole countries helped @url offensive indirect confusion_sadness origin other True 370 | 1655 @user @user know english retard. appreciate heads offensive_disrespectful indirect sadness_disgust disability special_needs True 371 | 5070 fucking cunt @url abusive_disrespectful indirect indifference_disgust gender women True 372 | 2528 rapes increase due influx primitive immigrants. #capitalism kicks save day. @url offensive indirect anger_disgust origin women True 373 | 574 @user @user silly little cunt think people affected it. 
plus it' @url offensive direct anger_indifference gender women True 374 | 1588 im racist much despise word nigger hilarious offensive indirect anger_shock origin left_wing_people True 375 | 77 @user @user fuck suicide male prostitute rubbish chinaman lee hsien loong imagining h @url offensive indirect confusion_sadness_shock_indifference_disgust origin muslims True 376 | 2748 ted canadian cuban father wanted hunt 12 million immigrants deport solution lol fuck ted cruz offensive indirect anger_confusion_fear_shock_sadness origin individual True 377 | 265 twat? gary twat offensive indirect confusion_disgust gender women True 378 | 136 @user 6/28/2018. four illegal aliens accused kidnapping raping two teenage sisters 13 1 @url normal indirect anger_confusion_fear_shock_sadness origin left_wing_people False 379 | 964 dis dyke finna get humped offensive indirect confusion_disgust sexual_orientation special_needs True 380 | 5098 @user tyrone retarded offensive indirect anger_indifference disability special_needs True 381 | 4048 @user @user nope. shithole countries. (would live there?) i'm fence th @url offensive indirect anger_disgust origin other True 382 | 1088 want punch face fucking much disrespectful ching chong offensive indirect anger_confusion_fear_sadness_disgust origin left_wing_people True 383 | 5530 @user @user know poorly irish immigrants treated right came over? 
normal indirect sadness_disgust origin other False 384 | 201 feel retarded cant offensive indirect sadness_indifference disability special_needs True 385 | 5101 nigera fucker nigger pooper fearful_hateful indirect anger_disgust origin other True 386 | 3784 shit retarded offensive_normal indirect indifference_disgust disability special_needs True 387 | 2992 imagine much mongoloid @url offensive indirect anger_indifference disability special_needs True 388 | 3688 @user @user @user saying would rape mum cunt literally sad thing @url offensive direct anger_fear_shock_indifference_disgust gender women True 389 | 2592 @user dirty fucking cunti'd like coat jam bury ants nest fearful_abusive indirect indifference_disgust gender special_needs True 390 | 2208 @user @user @user @user grow daft twat fearful_abusive_offensive_hateful_normal indirect anger_shock other women True 391 | 4884 whereas aaron straight-up rancid cunt @url hateful direct anger_disgust other women True 392 | 1003 @user brand new van that! shiny spic span. offensive indirect confusion_indifference other other True 393 | 2232 @user @user come change america shithole country fleeing from? may @url disrespectful_hateful indirect shock_indifference origin immigrants True 394 | 3792 @user conway literally fabled feminazi limbaugh yore case means pretend feminist actual nazi. 
fearful_abusive_offensive_disrespectful_normal indirect anger_confusion_sadness_indifference_disgust gender women True 395 | 2187 @user ching chong sing along offensive indirect anger_disgust origin individual True 396 | 4123 leftist/marxist/dems went terrorist illegal aliens reduce tax burden americans bad trade deals exce @url fearful_hateful indirect shock_indifference origin immigrants True 397 | 2 @user @user legit nilas retarded idk offensive_disrespectful indirect anger_confusion disability special_needs True 398 | 3896 call faggot tranny want cameron kai we're trans gay @url hateful indirect fear_shock_indifference_sadness_disgust origin gay True 399 | 2898 try freaking hard get know people hang stream actually care shit twat @url abusive_normal indirect anger_sadness other women True 400 | 5149 @user ask us. ask twat mate white hoisr ! fearful_abusive_offensive_hateful_disrespectful indirect indifference_disgust gender women True 401 | 3549 @user family guy holds bars stuff say racist sexist lot never change ugly cunt complains normal indirect confusion_sadness_shock_indifference_disgust origin women False 402 | 5260 negro bar two hours done called/ft 3 times cuz people watching remembered @url offensive_disrespectful indirect anger_confusion_sadness_indifference_fear origin left_wing_people True 403 | 3162 kid would've hated moved different part country adult can't wait get away shithole normal indirect anger_sadness_shock_indifference_disgust other other False 404 | 1188 yet mongy fans busy booing 18 year old booing experienced pros jones millerben @url normal indirect anger_confusion_sadness_shock_disgust gender special_needs False 405 | 3421 @user put nigger @url hateful_disrespectful indirect anger_disgust origin other True 406 | 3587 what's big spotty cunt half way back? @url abusive_normal indirect indifference_disgust gender special_needs True 407 | 237 @user violence?you got mad year ago saying shithole countries. make minds. 
@url offensive indirect anger_confusion_sadness_shock_disgust origin jews True 408 | 5247 @user @user @user they're really bypassing one country indonesia shithole @url hateful_disrespectful indirect sadness_disgust origin other True 409 | 4122 @user thank senator graham. i've critical support illegal aliens defen @url offensive_normal indirect indifference_disgust origin other True 410 | 5578 @user @user funny use word terrorist way leftist extremism responsible murders? offensive indirect anger_disgust other other True 411 | 3349 machados fucking twat. offensive indirect anger_fear_indifference_sadness_disgust gender women True 412 | 1258 know i'm cooked start acting retarded fearful_normal indirect anger_confusion_fear_shock_disgust other special_needs True 413 | 5346 thanks historic global warming signed law shithole countries going daca g @url offensive_disrespectful indirect anger_disgust origin other True 414 | 5493 someone: turn ching chong shitnme friends: @url offensive indirect anger_confusion_shock_indifference_disgust origin immigrants True 415 | 2171 @user benitez useless cunt fearful_abusive_offensive_hateful_disrespectful direct fear_disgust gender women True 416 | 4502 @user @user dumb? 
i'm connor fucking retard abusive_hateful indirect anger_disgust disability special_needs True 417 | 4965 @user good job faggot offensive indirect confusion_shock gender gay True 418 | 2545 theme herencolornhaiti/africa-shithole countriesnmexico-drug dealers & rapistsnpuerto rico-didn't g @url offensive indirect anger_disgust origin hispanics True 419 | 5278 @user get nuke though retard offensive indirect fear_shock_indifference_sadness_disgust disability special_needs True 420 | 2916 inbreed mongoloid americanized unamerican brazil military police government good shots got @url normal indirect anger_indifference disability special_needs False 421 | 2001 @user call retarded cat thing alien thing ill fuck cat front gold fish faggot offensive indirect anger_disgust disability individual True 422 | 4311 remember xbox live 2012 getting called spic cause spoke spanish 2 seconds y'all @url fearful_abusive_disrespectful_hateful_normal indirect sadness_indifference origin hispanics True 423 | 1210 @user @user saying basic human right people shithole countries flood @url fearful_offensive_hateful_disrespectful_normal indirect confusion_disgust other other True 424 | 4196 @user @user fucking spic offensive direct anger_indifference origin other True 425 | 5569 @user @user @user thinks #lnp policies #refugees #climatechange @url offensive indirect anger_disgust other refugees True 426 | 118 @user behave yurself today match day stop twat fearful_offensive indirect indifference_disgust origin individual True 427 | 3861 @user @user @user illegal immigrants! legal immigrants welcome! quit lying! offensive indirect sadness_disgust origin immigrants True 428 | 1266 @user @user @user criticise australia shithole. 
like @url normal indirect indifference_disgust origin other False 429 | 3772 @user feel i've got eyes nose that's midway typically arab looking chinese/mongol @url normal indirect confusion_fear_indifference_sadness_disgust origin individual False 430 | 580 white says wassup nigger made video calling nigger. @url abusive_hateful indirect anger_disgust origin other True 431 | 2928 yeah ive said word nigga & spic pissed ive also called gringo & cracka plenty times sometimes @url offensive_normal indirect indifference_disgust origin other True 432 | 5122 @user get mate. they're sickening. especially ching chong bastards @url offensive indirect anger_fear_shock_indifference_disgust origin special_needs True 433 | 5489 @user get pregerit im gonna punch stomach retarded abusive_disrespectful indirect anger_fear_shock_indifference_sadness origin special_needs True 434 | 3253 @user @user @user @user talking tye mongoloid offensive indirect anger_confusion_sadness_shock_disgust disability special_needs True 435 | 1380 @user thanks cunt offensive indirect anger_sadness other women True 436 | 5131 @user retard fighting court pic.. offensive indirect indifference_disgust disability special_needs True 437 | 1743 stuck leicester retard train driver got mistaken identity fuck sake wanker offensive indirect anger_sadness_shock_indifference_disgust origin special_needs True 438 | 5236 fucking retard! @url offensive indirect anger_confusion_fear_indifference_disgust other other True 439 | 4851 great example though disagree immigrants leave anything behind. thats whats great america - c @url fearful_abusive indirect anger_shock origin refugees True 440 | 91 @user unaccounted illegal aliens happy. tufail one offensive indirect anger_sadness origin immigrants True 441 | 4000 call spic something man least calling offensive_disrespectful indirect indifference_disgust origin other True 442 | 3021 crazy half us 90's babies got old souls half retarded af. 
normal indirect indifference_disgust gender women False 443 | 1328 @user dads twat look like him. break homes well prick. offensive indirect anger_disgust other other True 444 | 772 pretty much would let gemma know fat cunt shoes gave something @url offensive_hateful direct indifference_disgust other special_needs True 445 | 1928 @user @user dirtiest chinaman criminal lee hsien loong inciting racial hatred malaysian @url offensive_hateful indirect anger_disgust origin other True 446 | 1445 @user @user stop tweeting go back country. tourist kenya controlling economy. hateful_disrespectful indirect confusion_disgust origin special_needs True 447 | 4371 people fucking retarded i'm done wasting time fearful_offensive indirect sadness_disgust disability special_needs True 448 | 2874 @user @user guy sounds like twat offensive indirect fear_indifference other other True 449 | 5529 @user raghead camel shagger? normal indirect sadness_disgust religion other False 450 | 5631 @user @user leaving shithole countries turn europe continent full shithole countries.ngo home!!! hateful indirect anger_fear_shock_sadness_disgust origin other True 451 | 4090 @user issue refugees must resolved 7 years ago way wise policies countries lib @url normal indirect sadness_indifference origin refugees False 452 | 639 flux mean faggot fucks cant find boyfriend! #fluxisamotherfuckerandabitch offensive indirect fear_disgust sexual_orientation special_needs True 453 | 635 @user twat say wrong day praised emri normal indirect shock_indifference gender other False 454 | 4977 @user @user @user @user know weinstein dirty jew right? know way. offensive indirect sadness_disgust religion refugees True 455 | 4515 @user lewis faggot normal direct shock_disgust sexual_orientation other False 456 | 4752 hi ellis like use word retard employers tweet them? @url fearful_offensive_hateful_disrespectful_normal indirect indifference_disgust disability other True 457 | 296 @user @user i'm spic say build wall!!! 
offensive indirect confusion_sadness origin immigrants True 458 | 2033 smell like cigarettesnyeah fuckin wonder dumb fuckin cunt leave tf alone abusive_offensive direct anger_confusion_sadness_indifference_disgust gender women True 459 | 4336 elmer fudd voice faggot season offensive indirect sadness_indifference sexual_orientation individual True 460 | 5210 dem0kkkr4tz nd r3publ14canz r 2 sidez duh sam3 t0k3n. - politically retarded n0g knows nothing po @url offensive indirect sadness_indifference origin special_needs True 461 | 395 negro fumbles.... offensive indirect sadness_indifference origin african_descent True 462 | 5579 @user @user @user @user @user @user @user @user @url normal indirect anger_disgust other other False 463 | 2428 @user @user gone pregnant dyke!!! offensive indirect sadness_indifference sexual_orientation women True 464 | 5268 @user definitely will. country shithole offensive indirect anger_confusion_sadness_shock_indifference other other True 465 | 5318 @user @user who's oswald dyke offensive indirect anger_disgust sexual_orientation gay True 466 | 1962 say faggot offensive_hateful indirect confusion_shock other other True 467 | 1664 @user @user i'm saying calling shithole countries shitholes shithole bad make @url offensive_disrespectful indirect confusion_disgust origin other True 468 | 3101 calling people beaner spic funniest shit ever offensive indirect anger_disgust origin special_needs True 469 | 5615 think i'm finally 3 years. fucking cunt. offensive indirect anger_confusion_sadness_indifference_disgust gender individual True 470 | 3154 ever bought entire costume one costume true faggot offensive indirect indifference_disgust sexual_orientation other True 471 | 75 @user reread shit illiterate twat offensive indirect confusion_indifference other women True 472 | 2242 they're shithole countries. @url offensive_disrespectful indirect anger_disgust origin other True 473 | 123 @user @user @user retarded. 
offensive direct anger_indifference disability special_needs True 474 | 1935 mourinho self-centred twat fearful indirect anger_disgust other individual True 475 | 2742 me: hnoomf: retweet last tweet faggot @url offensive indirect anger_sadness_shock_indifference_disgust sexual_orientation other True 476 | 4589 trump culture. who's shithole country again? #despicable @url normal indirect sadness_indifference origin other False 477 | 291 cunt. @user report blasphemy? @user @url offensive indirect anger_sadness_shock_indifference_disgust gender women True 478 | 2271 @user jackson dont tweet weed faggot ruins lives terrible gay drinking cool offensive direct anger_fear origin gay True 479 | 31 yepnnplus spic & span waxed & buffed: fitting suspects dishevelled profile; new posters (not faded th @url offensive indirect confusion_disgust origin special_needs True 480 | 518 @user @user ching chong offensive indirect anger_confusion_sadness_indifference_fear origin other True 481 | 928 @user @user @user leftist fanatics ... antifa another terrorist organization. offensive_normal indirect anger_shock other other True 482 | 1021 @user republicans: keep hungry thirsty dark-skinned strangers walking shithole @url offensive indirect shock_disgust origin other True 483 | 1455 @user @user @user push back leftist professors like #billayers writing @url offensive indirect confusion_indifference other left_wing_people True 484 | 261 @user treating like treat dengists i'd probably say read book retard post b @url offensive indirect confusion_fear_shock_indifference_sadness other special_needs True 485 | 4255 i'm drunk i'm retarded normal indirect anger_indifference disability special_needs False 486 | 2343 song happy shiny people u3 renamed awesome retarded pups. abusive_disrespectful indirect anger_sadness disability special_needs True 487 | 3137 @user @user kind retarded comment that? 
people helicopter ground @url fearful_offensive indirect confusion_fear_shock_indifference_disgust disability special_needs True 488 | 854 libye: plus de 100 morts dans le naufrage de deux bateaux de migrants @url offensive indirect anger_fear_indifference_sadness_disgust origin immigrants True 489 | 5562 know calls kevin faggot dorms hes mad @url offensive indirect anger_sadness sexual_orientation left_wing_people True 490 | 4172 one person calls chink chinky ching chong i'm going kick balls normal indirect anger_confusion_shock_indifference_disgust origin left_wing_people False 491 | 3330 shithole shithole other's countries question @url fearful_offensive_hateful_disrespectful_normal indirect confusion_fear_shock_sadness_disgust origin other True 492 | 4856 close borders wall built. never assimilate & bring disease that's eradicated. @url normal indirect anger_confusion_shock_indifference_disgust origin individual False 493 | 4941 @user @user told ching chong face? have.nhave someone come @url offensive_normal indirect indifference_disgust origin other True 494 | 2964 @user @user thats 80's ode 'shithole' countries yes? offensive indirect anger_disgust origin other True 495 | 4343 @user knew! mean even doubt lol nthe man spic spam. nnthe saying spic span. spam @url fearful_offensive indirect anger_fear_shock_indifference_disgust origin special_needs True 496 | 1693 'daniel o'donnell bono.' bono's fucking insufferable robbie twat. would anybody wanna aspire lik @url offensive indirect anger_confusion_shock_indifference_disgust other individual True 497 | 3430 countries shit? think want live shithole @url fearful_disrespectful indirect anger_confusion_sadness_fear_disgust origin other True 498 | 1525 @user watergate guys punched faggot managers turned away causing problems offensive indirect sadness_shock sexual_orientation gay True 499 | 3707 words overhear two guys walking home.nnman 1: look faggot.nman 2: eat. 
@url offensive indirect anger_confusion_shock_indifference_disgust sexual_orientation special_needs True 500 | 3494 @user u retarded sumething libtard? offensive indirect anger_disgust disability special_needs True 501 | 132 @user @user cant tell bonnie actually retarded shes pretending. abusive_offensive_hateful_disrespectful_normal indirect anger_fear_indifference_sadness_disgust disability special_needs True 502 | 816 get see retard tonight normal indirect anger_confusion_fear_indifference_sadness disability special_needs False 503 | 3481 @user ok retard offensive indirect fear_shock disability special_needs True 504 | 2842 never go full retard. @url offensive indirect anger_shock origin special_needs True 505 | 2790 son amazing l can't believe l purpose. l thought retarded get strap @url offensive indirect sadness_disgust disability special_needs True 506 | 4980 least skip funny.nncolin awful. disgusting twat.nn@url offensive indirect sadness_shock other special_needs True 507 | 5034 audacity men. owe shit ya fuckin twat offensive indirect anger_disgust other other True 508 | 1950 @user fucking cheating cunt. ball hand? lying dirty bastard offensive direct anger_fear_shock_indifference_disgust other other True 509 | 103 gimme break gimme break. break piece fat retard @url disrespectful indirect indifference_disgust other other True 510 | 4953 i'm actually twat offensive_normal direct anger_sadness_shock_indifference_disgust other women True 511 | 1940 @user @user @user people want stay shithole countries live in.. @url offensive indirect anger_confusion_fear_indifference_disgust origin other True 512 | 2555 question: walking thousands miles get away shithole country shithol @url offensive indirect anger_indifference origin other True 513 | 440 s.o.b retard dangerous asshole planet! shame brasil! 
@url offensive indirect anger_sadness_shock_indifference_disgust disability special_needs True 514 | 1032 like remembernn nothing usnn retarded n @url offensive_disrespectful indirect anger_confusion_sadness_shock_disgust disability special_needs True 515 | 5216 @user @user many cunts would cultcunt cult cultcunt could cunt cults? offensive_disrespectful indirect sadness_indifference gender women True 516 | 2977 gfs sister high offered tight pussy amazing tight fresh cunt @url offensive_disrespectful indirect sadness_disgust gender women True 517 | 468 chance defeat vile leftie pussy hat feminazi man-hating bitter bitches brigade. god help us @url fearful_abusive_disrespectful_hateful_normal indirect anger_confusion_fear_shock_disgust origin women True 518 | 5460 @user @user begs question....what obama 8 years help shithole count @url offensive_hateful indirect fear_disgust origin other True 519 | 2662 keep mind im negro open mind got screen door lajshshsnnsnsndndnnxncnxnxnxnx nah birdman mightve point hateful_normal indirect confusion_indifference origin african_descent True 520 | 4459 sound like lil hoes beg'n another man smoke weed. faggot ass mf'z. normal indirect anger_sadness sexual_orientation other False 521 | 2700 @user @user he's dropped he's poorly illness mongy normal indirect anger_fear disability special_needs False 522 | 473 @user @user ur demo believing retard! fearful_abusive_offensive_hateful_normal indirect sadness_disgust other special_needs True 523 | 4491 tweet normal indirect anger_disgust other other False 524 | 114 @user early gab awesome maga sad see whole thing wither die one giant retard moshpit cringe. offensive indirect anger_sadness_shock_indifference_disgust disability other True 525 | 4256 blue way. sick shithole republicans country. 
@url offensive direct confusion_sadness_shock_indifference_disgust origin other True 526 | 2534 @user yeah nowadays much make eye contact feminazi gonna accuse rape offensive_normal indirect anger_disgust gender women True 527 | 717 @user that's twitter retarded remember bush obama years. we've @url offensive indirect shock_indifference other special_needs True 528 | 4522 @user @user @user disrespectful cunt. offensive indirect anger_shock other other True 529 | 1065 pied piper hamelin nnwhat twat offensive indirect anger_confusion_sadness_indifference_fear other special_needs True 530 | 5512 parents ban screen time' unaware raising twat - latest daily mash @url offensive indirect indifference_disgust other special_needs True 531 | 5359 rulers fucking retarded? mean uk tories. @url offensive indirect anger_confusion_fear_indifference_disgust origin special_needs True 532 | 1722 lobatan! waec can't reissue another one? country indeed shithole. @url normal indirect sadness_fear origin individual False 533 | 250 since cishets feel need butt head everything lgbt+ related: gay people say faggot.ntransgende @url offensive indirect sadness_shock sexual_orientation other True 534 | 2085 nigger behind boats directing imagine that. abusive_offensive indirect indifference_disgust origin indian/hindu True 535 | 3288 mpu general studies retarded module ever exist like even compulsory normal indirect confusion_sadness other other False 536 | 3239 roberts retarded offensive indirect anger_disgust origin other True 537 | 910 @user similarly: give straight people unilateral permission say traumatizing words like faggot @url offensive indirect indifference_disgust sexual_orientation gay True 538 | 2168 @user who's mongoloid? 
offensive indirect sadness_disgust disability special_needs True 539 | 3275 look like negro league pitcher @url offensive indirect sadness_indifference origin hispanics True 540 | 1059 @user @user that's actually retarded offensive indirect sadness_disgust disability special_needs True 541 | 2322 hating race bc looks / skin. racist?nn(all races caucasian negroid mongoloid capoid australoid etc.) offensive indirect anger_disgust origin women True 542 | 2442 hogan banned *all* syrian refugees maryland fearful_offensive indirect sadness_indifference origin refugees True 543 | 1083 give guy enough ketchup packets today work called wetback spic normal indirect shock_indifference origin special_needs False 544 | 3268 remember ching chong hoe flexin years ago found fuckin smh offensive indirect anger_confusion origin asians True 545 | 3709 @user u look like retard abusive_hateful direct anger_fear_shock_indifference_disgust disability special_needs True 546 | 2959 @user song retard offensive indirect anger_disgust disability special_needs True 547 | 4984 hisoka faggot shit hands real @url hateful indirect sadness_disgust sexual_orientation left_wing_people True 548 | 3674 @user luke retarded rotations well. kept josh long time. played zo play stretch. normal indirect anger_fear_shock_indifference_disgust disability special_needs False 549 | 4725 i'm #metoo turns people fake empowerment. like guys trade horrific stories @url offensive indirect anger_shock other women True 550 | 5468 things gained (lost) retarded free speech @user users like robert bowers: possibly thing @url offensive indirect sadness_disgust disability special_needs True 551 | 489 truely little mongy! . @url offensive indirect anger_confusion_sadness_indifference_disgust origin individual True 552 | 1750 kill cunt!! 
@url hateful direct confusion_fear_indifference_sadness_disgust gender women True 553 | 2527 @user @user listened earlier hahaha even tho hes cunt half mongy fans offensive indirect confusion_sadness gender women True 554 | 1882 @user hey hey hey. watch it. got spic kids. . nn(no offense @user @user dads racist bastard). abusive_offensive direct anger_shock origin special_needs True 555 | 1104 @user watch fall happy daddy claptrap. forgetting vile vicious cruelty refugees @url offensive_normal indirect shock_indifference origin refugees True 556 | 3984 aesthetic dumb funny dyke curly hair daddy issues offensive indirect anger_sadness_shock_indifference_disgust sexual_orientation gay True 557 | 4739 @user @user @user lol corbyn twat two kind. offensive indirect anger_shock gender women True 558 | 994 #indianarmy #hinduforces.if hindus forces #kashmiripandit wouldnt refugees & @url offensive_normal indirect confusion_fear_shock_indifference_disgust other refugees True 559 | 2643 @user he's retard offensive indirect anger_indifference origin special_needs True 560 | 1883 time moving slow. i'm already impatient person naturally..this retarded offensive_normal indirect anger_confusion_sadness_shock_indifference disability special_needs True 561 | 45 @user what's spic_ offensive_disrespectful indirect shock_disgust origin other True 562 | 3535 @user john flower faggot weighs 700 pounds lives college dorm refrigerator hateful indirect anger_confusion_sadness_indifference_disgust sexual_orientation other True 563 | 1910 people still using ching chong insult asians lmfao y'all look stupid hateful indirect indifference_disgust origin individual True 564 | 2402 trumpkin makes sense twat white house! 
@url offensive indirect anger_indifference origin left_wing_people True 565 | 3822 @user @user he's faggot ass band wagon bama fan offensive indirect anger_shock sexual_orientation individual True 566 | 3290 @user mom wont say spic say nigga love problematic afrolatinx queen abusive_hateful indirect shock_disgust origin individual True 567 | --------------------------------------------------------------------------------