├── README.md ├── SentEval ├── LICENSE ├── README.md ├── data │ └── downstream │ │ └── download_dataset.sh ├── examples │ ├── bow.py │ ├── bow_word_piece.py │ ├── gensen.py │ ├── googleuse.py │ ├── infersent.py │ ├── models.py │ └── skipthought.py ├── senteval │ ├── __init__.py │ ├── binary.py │ ├── engine.py │ ├── mrpc.py │ ├── probing.py │ ├── rank.py │ ├── sick.py │ ├── snli.py │ ├── sst.py │ ├── sts.py │ ├── tools │ │ ├── __init__.py │ │ ├── classifier.py │ │ ├── ranking.py │ │ ├── relatedness.py │ │ └── validation.py │ ├── trec.py │ └── utils.py └── setup.py ├── arial.ttf ├── data └── download_nli.sh ├── ds.config ├── eval_sts.py ├── figure └── e5v.png ├── ft_llm.py ├── load_llama3_hf.py ├── requirements.txt ├── retrieval.py └── run.sh /README.md: -------------------------------------------------------------------------------- 1 | # E5-V: Universal Embeddings with Multimodal Large Language Models 2 | 3 | ## Overview 4 | We propose a framework, called E5-V, to adapt MLLMs for achieving multimodal embeddings. E5-V effectively bridges the modality gap between different types of inputs, demonstrating strong performance in multimodal embeddings even without fine-tuning. We also propose a single modality training approach for E5-V, where the model is trained exclusively on text pairs, demonstrating better performance than multimodal training. 
5 | 6 | ![](figure/e5v.png) 7 | 8 | ## Example 9 | ``` python 10 | import torch 11 | import torch.nn.functional as F 12 | import requests 13 | from PIL import Image 14 | from transformers import AutoTokenizer 15 | from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration 16 | 17 | llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' 18 | 19 | processor = LlavaNextProcessor.from_pretrained('royokong/e5-v') 20 | model = LlavaNextForConditionalGeneration.from_pretrained('royokong/e5-v', torch_dtype=torch.float16).cuda() 21 | 22 | img_prompt = llama3_template.format('<image>\nSummary above image in one word: ') 23 | text_prompt = llama3_template.format('<sent>\nSummary above sentence in one word: ') 24 | 25 | urls = ['https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/American_Eskimo_Dog.jpg/360px-American_Eskimo_Dog.jpg', 26 | 'https://upload.wikimedia.org/wikipedia/commons/thumb/b/b6/Felis_catus-cat_on_snow.jpg/179px-Felis_catus-cat_on_snow.jpg'] 27 | images = [Image.open(requests.get(url, stream=True).raw) for url in urls] 28 | 29 | texts = ['A dog sitting in the grass.', 30 | 'A cat standing in the snow.'] 31 | 32 | text_inputs = processor([text_prompt.replace('<sent>', text) for text in texts], return_tensors="pt", padding=True).to('cuda') 33 | img_inputs = processor([img_prompt]*len(images), images, return_tensors="pt", padding=True).to('cuda') 34 | 35 | with torch.no_grad(): 36 | text_embs = model(**text_inputs, output_hidden_states=True, return_dict=True).hidden_states[-1][:, -1, :] 37 | img_embs = model(**img_inputs, output_hidden_states=True, return_dict=True).hidden_states[-1][:, -1, :] 38 | 39 | text_embs = F.normalize(text_embs, dim=-1) 40 | img_embs = F.normalize(img_embs, dim=-1) 41 | 42 | print(text_embs @ img_embs.t()) 43 | ``` 44 | 45 | 46 | ## Evaluate 47 | To evaluate the original results in the paper, please run the following 48 | ```sh 49 | # eval on coco, flickr30k, 
fashioniq and cirr 50 | accelerate launch --num_machines=1 --num_processes 8 --machine_rank 0 retrieval.py --use_e5v 51 | 52 | # eval on i2i-coco, i2i-flickr30k 53 | accelerate launch --num_machines=1 --num_processes 8 --machine_rank 0 retrieval.py --use_e5v --ocr_replace_text 54 | 55 | # eval on sts tasks 56 | cd SentEval/data/downstream/ 57 | bash download_dataset.sh 58 | cd - 59 | accelerate launch --num_machines=1 --num_processes 8 --machine_rank 0 eval_sts.py --model_name_or_path royokong/e5-v 60 | ``` 61 | 62 | ## Training 63 | 1. Install Dependencies 64 | 65 | ``` sh 66 | pip install -r requirements.txt 67 | ``` 68 | 69 | 2. Download Data 70 | 71 | ``` sh 72 | cd ./data 73 | bash download_nli.sh 74 | cd - 75 | ``` 76 | 77 | 3. Transfer llava-llama-3-8b model to huggingface format on each node 78 | 79 | ``` sh 80 | mkdir -p models 81 | cd models 82 | for i in 1 2 3 4; do 83 | wget https://huggingface.co/lmms-lab/llama3-llava-next-8b/resolve/main/model-0000$i-of-00004.safetensors 84 | done 85 | cd - 86 | python load_llama3_hf.py 87 | rm models/*.safetensors 88 | ``` 89 | 90 | 4. Train 91 | ``` sh 92 | bash run.sh 93 | ``` 94 | 95 | 5. Test 96 | Use `--lora_path` flag to test the results. 97 | ``` sh 98 | accelerate launch --num_machines=1 --num_processes 8 --machine_rank 0 retrieval.py \ 99 | --llava_llama3 --lora_path e5v-8b --batch_size 1 100 | ``` 101 | 102 | 103 | ## Acknowledgement 104 | Our code is based on SimCSE and alpaca-lora 105 | -------------------------------------------------------------------------------- /SentEval/LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For SentEval software 4 | 5 | Copyright (c) 2017-present, Facebook, Inc. All rights reserved. 
6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /SentEval/README.md: -------------------------------------------------------------------------------- 1 | Our modification to SentEval: 2 | 3 | 1. Add the `all` setting to all STS tasks. 4 | 2. Change STS-B and SICK-R to not use an additional regressor. 
5 | 6 | # SentEval: evaluation toolkit for sentence embeddings 7 | 8 | SentEval is a library for evaluating the quality of sentence embeddings. We assess their generalization power by using them as features on a broad and diverse set of "transfer" tasks. **SentEval currently includes 17 downstream tasks**. We also include a suite of **10 probing tasks** which evaluate what linguistic properties are encoded in sentence embeddings. Our goal is to ease the study and the development of general-purpose fixed-size sentence representations. 9 | 10 | 11 | **(04/22) SentEval new tasks: Added probing tasks for evaluating what linguistic properties are encoded in sentence embeddings** 12 | 13 | **(10/04) SentEval example scripts for three sentence encoders: [SkipThought-LN](https://github.com/ryankiros/layer-norm#skip-thoughts)/[GenSen](https://github.com/Maluuba/gensen)/[Google-USE](https://tfhub.dev/google/universal-sentence-encoder/1)** 14 | 15 | ## Dependencies 16 | 17 | This code is written in python. 
The dependencies are: 18 | 19 | * Python 2/3 with [NumPy](http://www.numpy.org/)/[SciPy](http://www.scipy.org/) 20 | * [Pytorch](http://pytorch.org/)>=0.4 21 | * [scikit-learn](http://scikit-learn.org/stable/index.html)>=0.18.0 22 | 23 | ## Transfer tasks 24 | 25 | ### Downstream tasks 26 | SentEval allows you to evaluate your sentence embeddings as features for the following *downstream* tasks: 27 | 28 | | Task | Type | #train | #test | needs_train | set_classifier | 29 | |---------- |------------------------------ |-----------:|----------:|:-----------:|:----------:| 30 | | [MR](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | movie review | 11k | 11k | 1 | 1 | 31 | | [CR](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | product review | 4k | 4k | 1 | 1 | 32 | | [SUBJ](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | subjectivity status | 10k | 10k | 1 | 1 | 33 | | [MPQA](https://nlp.stanford.edu/~sidaw/home/projects:nbsvm) | opinion-polarity | 11k | 11k | 1 | 1 | 34 | | [SST](https://nlp.stanford.edu/sentiment/index.html) | binary sentiment analysis | 67k | 1.8k | 1 | 1 | 35 | | **[SST](https://nlp.stanford.edu/sentiment/index.html)** | **fine-grained sentiment analysis** | 8.5k | 2.2k | 1 | 1 | 36 | | [TREC](http://cogcomp.cs.illinois.edu/Data/QA/QC/) | question-type classification | 6k | 0.5k | 1 | 1 | 37 | | [SICK-E](http://clic.cimec.unitn.it/composes/sick.html) | natural language inference | 4.5k | 4.9k | 1 | 1 | 38 | | [SNLI](https://nlp.stanford.edu/projects/snli/) | natural language inference | 550k | 9.8k | 1 | 1 | 39 | | [MRPC](https://aclweb.org/aclwiki/Paraphrase_Identification_(State_of_the_art)) | paraphrase detection | 4.1k | 1.7k | 1 | 1 | 40 | | [STS 2012](https://www.cs.york.ac.uk/semeval-2012/task6/) | semantic textual similarity | N/A | 3.1k | 0 | 0 | 41 | | [STS 2013](http://ixa2.si.ehu.es/sts/) | semantic textual similarity | N/A | 1.5k | 0 | 0 | 42 | | [STS 2014](http://alt.qcri.org/semeval2014/task10/) | semantic textual 
similarity | N/A | 3.7k | 0 | 0 | 43 | | [STS 2015](http://alt.qcri.org/semeval2015/task2/) | semantic textual similarity | N/A | 8.5k | 0 | 0 | 44 | | [STS 2016](http://alt.qcri.org/semeval2016/task1/) | semantic textual similarity | N/A | 9.2k | 0 | 0 | 45 | | [STS B](http://ixa2.si.ehu.es/stswiki/index.php/STSbenchmark#Results) | semantic textual similarity | 5.7k | 1.4k | 1 | 0 | 46 | | [SICK-R](http://clic.cimec.unitn.it/composes/sick.html) | semantic textual similarity | 4.5k | 4.9k | 1 | 0 | 47 | | [COCO](http://mscoco.org/) | image-caption retrieval | 567k | 5*1k | 1 | 0 | 48 | 49 | where **needs_train** means a model with parameters is learned on top of the sentence embeddings, and **set_classifier** means you can define the parameters of the classifier in the case of a classification task (see below). 50 | 51 | Note: COCO comes with ResNet-101 2048d image embeddings. [More details on the tasks.](https://arxiv.org/pdf/1705.02364.pdf) 52 | 53 | ### Probing tasks 54 | SentEval also includes a series of [*probing* tasks](https://github.com/facebookresearch/SentEval/tree/master/data/probing) to evaluate what linguistic properties are encoded in your sentence embeddings: 55 | 56 | | Task | Type | #train | #test | needs_train | set_classifier | 57 | |---------- |------------------------------ |-----------:|----------:|:-----------:|:----------:| 58 | | [SentLen](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Length prediction | 100k | 10k | 1 | 1 | 59 | | [WC](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Word Content analysis | 100k | 10k | 1 | 1 | 60 | | [TreeDepth](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Tree depth prediction | 100k | 10k | 1 | 1 | 61 | | [TopConst](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Top Constituents prediction | 100k | 10k | 1 | 1 | 62 | | [BShift](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | 
Word order analysis | 100k | 10k | 1 | 1 | 63 | | [Tense](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Verb tense prediction | 100k | 10k | 1 | 1 | 64 | | [SubjNum](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Subject number prediction | 100k | 10k | 1 | 1 | 65 | | [ObjNum](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Object number prediction | 100k | 10k | 1 | 1 | 66 | | [SOMO](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Semantic odd man out | 100k | 10k | 1 | 1 | 67 | | [CoordInv](https://github.com/facebookresearch/SentEval/tree/master/data/probing) | Coordination Inversion | 100k | 10k | 1 | 1 | 68 | 69 | ## Download datasets 70 | To get all the transfer tasks datasets, run (in data/downstream/): 71 | ```bash 72 | ./get_transfer_data.bash 73 | ``` 74 | This will automatically download and preprocess the downstream datasets, and store them in data/downstream (warning: for MacOS users, you may have to use p7zip instead of unzip). The probing tasks are already in data/probing by default. 75 | 76 | ## How to use SentEval: examples 77 | 78 | ### examples/bow.py 79 | 80 | In examples/bow.py, we evaluate the quality of the average of word embeddings. 81 | 82 | To download state-of-the-art fastText embeddings: 83 | 84 | ```bash 85 | curl -Lo glove.840B.300d.zip http://nlp.stanford.edu/data/glove.840B.300d.zip 86 | curl -Lo crawl-300d-2M.vec.zip https://dl.fbaipublicfiles.com/fasttext/vectors-english/crawl-300d-2M.vec.zip 87 | ``` 88 | 89 | To reproduce the results for bag-of-vectors, run (in examples/): 90 | ```bash 91 | python bow.py 92 | ``` 93 | 94 | As required by SentEval, this script implements two functions: **prepare** (optional) and **batcher** (required) that turn text sentences into sentence embeddings. Then SentEval takes care of the evaluation on the transfer tasks using the embeddings as features. 
95 | 96 | ### examples/infersent.py 97 | 98 | To get the **[InferSent](https://www.github.com/facebookresearch/InferSent)** model and reproduce our results, download our best models and run infersent.py (in examples/): 99 | ```bash 100 | curl -Lo examples/infersent1.pkl https://dl.fbaipublicfiles.com/senteval/infersent/infersent1.pkl 101 | curl -Lo examples/infersent2.pkl https://dl.fbaipublicfiles.com/senteval/infersent/infersent2.pkl 102 | ``` 103 | 104 | ### examples/skipthought.py - examples/gensen.py - examples/googleuse.py 105 | 106 | We also provide example scripts for three other encoders: 107 | 108 | * [SkipThought with Layer-Normalization](https://github.com/ryankiros/layer-norm#skip-thoughts) in Theano 109 | * [GenSen encoder](https://github.com/Maluuba/gensen) in Pytorch 110 | * [Google encoder](https://tfhub.dev/google/universal-sentence-encoder/1) in TensorFlow 111 | 112 | Note that for SkipThought and GenSen, following the steps of the associated githubs is necessary. 113 | The Google encoder script should work as-is. 114 | 115 | ## How to use SentEval 116 | 117 | To evaluate your sentence embeddings, SentEval requires that you implement two functions: 118 | 119 | 1. **prepare** (sees the whole dataset of each task and can thus construct the word vocabulary, the dictionary of word vectors etc) 120 | 2. **batcher** (transforms a batch of text sentences into sentence embeddings) 121 | 122 | 123 | ### 1.) prepare(params, samples) (optional) 124 | 125 | *batcher* only sees one batch at a time while the *samples* argument of *prepare* contains all the sentences of a task. 126 | 127 | ``` 128 | prepare(params, samples) 129 | ``` 130 | * *params*: senteval parameters. 131 | * *samples*: list of all sentences from the transfer task. 132 | * *output*: No output. Arguments stored in "params" can further be used by *batcher*. 
133 | 134 | *Example*: in bow.py, prepare is used to build the vocabulary of words and construct the *params.word_vec* dictionary of word vectors. 135 | 136 | 137 | ### 2.) batcher(params, batch) 138 | ``` 139 | batcher(params, batch) 140 | ``` 141 | * *params*: senteval parameters. 142 | * *batch*: numpy array of text sentences (of size params.batch_size) 143 | * *output*: numpy array of sentence embeddings (of size params.batch_size) 144 | 145 | *Example*: in bow.py, batcher is used to compute the mean of the word vectors for each sentence in the batch using params.word_vec. Use your own encoder in that function to encode sentences. 146 | 147 | ### 3.) evaluation on transfer tasks 148 | 149 | After having implemented the batcher and prepare functions for your own sentence encoder, 150 | 151 | 1) to perform the actual evaluation, first import senteval and set its parameters: 152 | ```python 153 | import senteval 154 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10} 155 | ``` 156 | 157 | 2) (optional) set the parameters of the classifier (when applicable): 158 | ```python 159 | params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64, 160 | 'tenacity': 5, 'epoch_size': 4} 161 | ``` 162 | You can choose **nhid=0** (Logistic Regression) or **nhid>0** (MLP) and define the parameters for training. 
163 | 164 | 3) Create an instance of the class SE: 165 | ```python 166 | se = senteval.engine.SE(params, batcher, prepare) 167 | ``` 168 | 169 | 4) define the set of transfer tasks and run the evaluation: 170 | ```python 171 | transfer_tasks = ['MR', 'SICKEntailment', 'STS14', 'STSBenchmark'] 172 | results = se.eval(transfer_tasks) 173 | ``` 174 | The current list of available tasks is: 175 | ```python 176 | ['CR', 'MR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 'SNLI', 177 | 'SICKEntailment', 'SICKRelatedness', 'STSBenchmark', 'ImageCaptionRetrieval', 178 | 'STS12', 'STS13', 'STS14', 'STS15', 'STS16', 179 | 'Length', 'WordContent', 'Depth', 'TopConstituents','BigramShift', 'Tense', 180 | 'SubjNumber', 'ObjNumber', 'OddManOut', 'CoordinationInversion'] 181 | ``` 182 | 183 | ## SentEval parameters 184 | Global parameters of SentEval: 185 | ```bash 186 | # senteval parameters 187 | task_path # path to SentEval datasets (required) 188 | seed # seed 189 | usepytorch # use cuda-pytorch (else scikit-learn) where possible 190 | kfold # k-fold validation for MR/CR/SUB/MPQA. 191 | ``` 192 | 193 | Parameters of the classifier: 194 | ```bash 195 | nhid: # number of hidden units (0: Logistic Regression, >0: MLP); Default nonlinearity: Tanh 196 | optim: # optimizer ("sgd,lr=0.1", "adam", "rmsprop" ..) 197 | tenacity: # how many times dev acc does not increase before training stops 198 | epoch_size: # each epoch corresponds to epoch_size pass on the train set 199 | max_epoch: # max number of epochs 200 | dropout: # dropout for MLP 201 | ``` 202 | 203 | Note that to get a proxy of the results while **dramatically reducing computation time**, 204 | we suggest the **prototyping config**: 205 | ```python 206 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5} 207 | params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128, 208 | 'tenacity': 3, 'epoch_size': 2} 209 | ``` 210 | which will result in a 5 times speedup for classification tasks. 
211 | 212 | To produce results that are **comparable to the literature**, use the **default config**: 213 | ```python 214 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10} 215 | params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64, 216 | 'tenacity': 5, 'epoch_size': 4} 217 | ``` 218 | which takes longer but will produce better and comparable results. 219 | 220 | For probing tasks, we used an MLP with a Sigmoid nonlinearity and tuned the nhid (in [50, 100, 200]) and dropout (in [0.0, 0.1, 0.2]) on the dev set. 221 | 222 | ## References 223 | 224 | Please consider citing [[1]](https://arxiv.org/abs/1803.05449) if using this code for evaluating sentence embedding methods. 225 | 226 | ### SentEval: An Evaluation Toolkit for Universal Sentence Representations 227 | 228 | [1] A. Conneau, D. Kiela, [*SentEval: An Evaluation Toolkit for Universal Sentence Representations*](https://arxiv.org/abs/1803.05449) 229 | 230 | ``` 231 | @article{conneau2018senteval, 232 | title={SentEval: An Evaluation Toolkit for Universal Sentence Representations}, 233 | author={Conneau, Alexis and Kiela, Douwe}, 234 | journal={arXiv preprint arXiv:1803.05449}, 235 | year={2018} 236 | } 237 | ``` 238 | 239 | Contact: [aconneau@fb.com](mailto:aconneau@fb.com), [dkiela@fb.com](mailto:dkiela@fb.com) 240 | 241 | ### Related work 242 | * [J. R Kiros, Y. Zhu, R. Salakhutdinov, R. S. Zemel, A. Torralba, R. Urtasun, S. Fidler - SkipThought Vectors, NIPS 2015](https://arxiv.org/abs/1506.06726) 243 | * [S. Arora, Y. Liang, T. Ma - A Simple but Tough-to-Beat Baseline for Sentence Embeddings, ICLR 2017](https://openreview.net/pdf?id=SyK00v5xx) 244 | * [Y. Adi, E. Kermany, Y. Belinkov, O. Lavi, Y. Goldberg - Fine-grained analysis of sentence embeddings using auxiliary prediction tasks, ICLR 2017](https://arxiv.org/abs/1608.04207) 245 | * [A. Conneau, D. Kiela, L. Barrault, H. Schwenk, A. 
Bordes - Supervised Learning of Universal Sentence Representations from Natural Language Inference Data, EMNLP 2017](https://arxiv.org/abs/1705.02364) 246 | * [S. Subramanian, A. Trischler, Y. Bengio, C. J Pal - Learning General Purpose Distributed Sentence Representations via Large Scale Multi-task Learning, ICLR 2018](https://arxiv.org/abs/1804.00079) 247 | * [A. Nie, E. D. Bennett, N. D. Goodman - DisSent: Sentence Representation Learning from Explicit Discourse Relations, 2018](https://arxiv.org/abs/1710.04334) 248 | * [D. Cer, Y. Yang, S. Kong, N. Hua, N. Limtiaco, R. St. John, N. Constant, M. Guajardo-Cespedes, S. Yuan, C. Tar, Y. Sung, B. Strope, R. Kurzweil - Universal Sentence Encoder, 2018](https://arxiv.org/abs/1803.11175) 249 | * [A. Conneau, G. Kruszewski, G. Lample, L. Barrault, M. Baroni - What you can cram into a single vector: Probing sentence embeddings for linguistic properties, ACL 2018](https://arxiv.org/abs/1805.01070) 250 | -------------------------------------------------------------------------------- /SentEval/data/downstream/download_dataset.sh: -------------------------------------------------------------------------------- 1 | wget --no-check-certificate https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/senteval.tar 2 | tar xvf senteval.tar 3 | -------------------------------------------------------------------------------- /SentEval/examples/bow.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | from __future__ import absolute_import, division, unicode_literals 9 | 10 | import sys 11 | import io 12 | import numpy as np 13 | import logging 14 | 15 | 16 | # Set PATHs 17 | PATH_TO_SENTEVAL = '../' 18 | PATH_TO_DATA = '../data' 19 | # PATH_TO_VEC = 'glove/glove.840B.300d.txt' 20 | PATH_TO_VEC = 'fasttext/crawl-300d-2M.vec' 21 | 22 | # import SentEval 23 | sys.path.insert(0, PATH_TO_SENTEVAL) 24 | import senteval 25 | 26 | 27 | # Create dictionary 28 | def create_dictionary(sentences, threshold=0): 29 | words = {} 30 | for s in sentences: 31 | for word in s: 32 | words[word] = words.get(word, 0) + 1 33 | 34 | if threshold > 0: 35 | newwords = {} 36 | for word in words: 37 | if words[word] >= threshold: 38 | newwords[word] = words[word] 39 | words = newwords 40 | words[''] = 1e9 + 4 41 | words[''] = 1e9 + 3 42 | words['

'] = 1e9 + 2 43 | 44 | sorted_words = sorted(words.items(), key=lambda x: -x[1]) # inverse sort 45 | id2word = [] 46 | word2id = {} 47 | for i, (w, _) in enumerate(sorted_words): 48 | id2word.append(w) 49 | word2id[w] = i 50 | 51 | return id2word, word2id 52 | 53 | # Get word vectors from vocabulary (glove, word2vec, fasttext ..) 54 | def get_wordvec(path_to_vec, word2id): 55 | word_vec = {} 56 | 57 | with io.open(path_to_vec, 'r', encoding='utf-8') as f: 58 | # if word2vec or fasttext file : skip first line "next(f)" 59 | for line in f: 60 | word, vec = line.split(' ', 1) 61 | if word in word2id: 62 | word_vec[word] = np.fromstring(vec, sep=' ') 63 | 64 | logging.info('Found {0} words with word vectors, out of \ 65 | {1} words'.format(len(word_vec), len(word2id))) 66 | return word_vec 67 | 68 | 69 | # SentEval prepare and batcher 70 | def prepare(params, samples): 71 | _, params.word2id = create_dictionary(samples) 72 | params.word_vec = get_wordvec(PATH_TO_VEC, params.word2id) 73 | params.wvec_dim = 300 74 | return 75 | 76 | def batcher(params, batch): 77 | batch = [sent if sent != [] else ['.'] for sent in batch] 78 | embeddings = [] 79 | 80 | for sent in batch: 81 | sentvec = [] 82 | for word in sent: 83 | if word in params.word_vec: 84 | sentvec.append(params.word_vec[word]) 85 | if not sentvec: 86 | vec = np.zeros(params.wvec_dim) 87 | sentvec.append(vec) 88 | sentvec = np.mean(sentvec, 0) 89 | embeddings.append(sentvec) 90 | 91 | embeddings = np.vstack(embeddings) 92 | return embeddings 93 | 94 | 95 | # Set params for SentEval 96 | params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5} 97 | params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128, 98 | 'tenacity': 3, 'epoch_size': 2} 99 | 100 | # Set up logger 101 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 102 | 103 | if __name__ == "__main__": 104 | se = senteval.engine.SE(params_senteval, batcher, prepare) 105 | 
#transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 106 | #'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 107 | #'SICKEntailment', 'SICKRelatedness', 'STSBenchmark', 108 | #'Length', 'WordContent', 'Depth', 'TopConstituents', 109 | #'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber', 110 | #'OddManOut', 'CoordinationInversion'] 111 | transfer_tasks = ['STSBenchmark'] 112 | results = se.eval(transfer_tasks) 113 | print(results) 114 | -------------------------------------------------------------------------------- /SentEval/examples/bow_word_piece.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import, division, unicode_literals 9 | 10 | import sys 11 | import io 12 | import numpy as np 13 | import logging 14 | 15 | from transformers import BertTokenizer 16 | 17 | # Set PATHs 18 | PATH_TO_SENTEVAL = '../' 19 | PATH_TO_DATA = '../data' 20 | # PATH_TO_VEC = 'glove/glove.840B.300d.txt' 21 | PATH_TO_VEC = 'fasttext/crawl-300d-2M.vec' 22 | 23 | # import SentEval 24 | sys.path.insert(0, PATH_TO_SENTEVAL) 25 | import senteval 26 | 27 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 28 | a_remove_set = {".", "a", "the", "in", ",", "is", "to", "of", "and", "'", "on", "man", "-", "s", "with", "for", "\"", "at", "##s", "woman", "are", "it", "two", "that", "you", "dog", "said", "playing", "i", "an", "as", "was", "from", ":", "by", "white"} 29 | remove_set = {'?', '*', '#', '´', '’', '=', '…', '|', '~', '/', '‚', '¿', '–', '»', '-', '€', '‘', '"', '(', '•', '`', '$', ':', '[', '”', '%', '£', '<', '[UNK]', ';', '“', '@', '_', '{', '^', ',', '.', '!', '™', '&', ']', '>', '\\', "'", ')', '+', '—'} 30 | 31 | # Create dictionary 32 | def 
create_dictionary(sentences, threshold=0): 33 | words = {} 34 | for s in sentences: 35 | for word in s: 36 | words[word] = words.get(word, 0) + 1 37 | #for word in tokenizer.convert_ids_to_tokens(tokenizer.encode(' '.join(s), add_special_tokens=False)): 38 | #if '##' in word and word in remove_set: continue 39 | #words[word] = words.get(word, 0) + 1 40 | 41 | if threshold > 0: 42 | newwords = {} 43 | for word in words: 44 | if words[word] >= threshold: 45 | newwords[word] = words[word] 46 | words = newwords 47 | words[''] = 1e9 + 4 48 | words[''] = 1e9 + 3 49 | words['

'] = 1e9 + 2 50 | 51 | sorted_words = sorted(words.items(), key=lambda x: -x[1]) # inverse sort 52 | id2word = [] 53 | word2id = {} 54 | for i, (w, _) in enumerate(sorted_words): 55 | id2word.append(w) 56 | word2id[w] = i 57 | 58 | return id2word, word2id 59 | 60 | # Get word vectors from vocabulary (glove, word2vec, fasttext ..) 61 | def get_wordvec(path_to_vec, word2id): 62 | word_vec = {} 63 | 64 | with io.open(path_to_vec, 'r', encoding='utf-8') as f: 65 | # if word2vec or fasttext file : skip first line "next(f)" 66 | for line in f: 67 | word, vec = line.split(' ', 1) 68 | if word in word2id: 69 | word_vec[word] = np.fromstring(vec, sep=' ') 70 | 71 | logging.info('Found {0} words with word vectors, out of \ 72 | {1} words'.format(len(word_vec), len(word2id))) 73 | return word_vec 74 | 75 | def get_bert_wordvec(path_to_vec, word2id): 76 | word_vec = {} 77 | from transformers import BertModel 78 | bert = BertModel.from_pretrained('bert-base-uncased') 79 | vocab = tokenizer.get_vocab() 80 | bert_word_vec = bert.embeddings.word_embeddings.weight.detach().numpy() 81 | 82 | for word in word2id: 83 | if word in ['', '', '

']: 84 | word_vec[word] = np.zeros(768) 85 | else: 86 | word_vec[word] = bert_word_vec[vocab[word]] 87 | 88 | logging.info('Found {0} words with word vectors, out of \ 89 | {1} words'.format(len(word_vec), len(word2id))) 90 | return word_vec 91 | 92 | # SentEval prepare and batcher 93 | def prepare(params, samples): 94 | _, params.word2id = create_dictionary(samples) 95 | params.word_vec = get_wordvec(PATH_TO_VEC, params.word2id) 96 | params.wvec_dim = 300 97 | #params.word_vec = get_bert_wordvec(PATH_TO_VEC, params.word2id) 98 | #params.wvec_dim = 768 99 | return 100 | 101 | def batcher(params, batch): 102 | batch = [sent if sent != [] else ['.'] for sent in batch] 103 | embeddings = [] 104 | 105 | for sent in batch: 106 | sentvec = [] 107 | # for word in tokenizer.convert_ids_to_tokens(tokenizer.encode(' '.join(sent), add_special_tokens=False)): 108 | for word in sent: 109 | if word in params.word_vec:# and word not in a_remove_set and word not in remove_set: 110 | sentvec.append(params.word_vec[word]) 111 | if not sentvec: 112 | vec = np.zeros(params.wvec_dim) 113 | sentvec.append(vec) 114 | sentvec = np.mean(sentvec, 0) 115 | embeddings.append(sentvec) 116 | 117 | embeddings = np.vstack(embeddings) 118 | return embeddings 119 | 120 | 121 | # Set params for SentEval 122 | params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5} 123 | params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128, 124 | 'tenacity': 3, 'epoch_size': 2} 125 | 126 | # Set up logger 127 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 128 | 129 | if __name__ == "__main__": 130 | se = senteval.engine.SE(params_senteval, batcher, prepare) 131 | #transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 132 | #'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 133 | #'SICKEntailment', 'SICKRelatedness', 'STSBenchmark', 134 | #'Length', 'WordContent', 'Depth', 'TopConstituents', 135 | #'BigramShift', 
'Tense', 'SubjNumber', 'ObjNumber', 136 | #'OddManOut', 'CoordinationInversion'] 137 | transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness'] 138 | results = se.eval(transfer_tasks) 139 | print(results) 140 | task_names = [] 141 | scores = [] 142 | for task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']: 143 | task_names.append(task) 144 | if task in results: 145 | if task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']: 146 | scores.append("%.2f" % (results[task]['all']['spearman']['all'] * 100)) 147 | else: 148 | scores.append("%.2f" % (results[task]['test']['spearman'].correlation * 100)) 149 | else: 150 | scores.append("0.00") 151 | task_names.append("Avg.") 152 | scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores))) 153 | 154 | from prettytable import PrettyTable 155 | tb = PrettyTable() 156 | tb.field_names = task_names 157 | tb.add_row(scores) 158 | print(tb) 159 | -------------------------------------------------------------------------------- /SentEval/examples/gensen.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
# SentEval prepare and batcher
def prepare(params, samples):
    """SentEval hook called once per task; nothing is prepared here."""
    return


def batcher(params, batch):
    """Encode a batch of tokenized sentences with the GenSen encoder.

    The encoder is read from ``params['gensen']`` (this script stores it in
    ``params_senteval['gensen']``). The previous body referenced the undefined
    names ``gensen`` and ``sentences`` and raised ``NameError`` on first call.

    Args:
        params: SentEval parameters holding the encoder under ``'gensen'``.
        batch: list of tokenized sentences (lists of words).

    Returns:
        The second value returned by ``get_representation`` for the batch.
    """
    # Empty sentences are replaced by '.' so the encoder never sees ''.
    sentences = [' '.join(sent) if sent != [] else '.' for sent in batch]
    _, reps_h_t = params['gensen'].get_representation(
        sentences, pool='last', return_numpy=True, tokenize=True
    )
    return reps_h_t
# SentEval prepare and batcher
def prepare(params, samples):
    """SentEval hook called once per task; no setup is performed."""
    return None


def batcher(params, batch):
    """Join each tokenized sentence and embed the batch.

    The callable stored under ``params['google_use']`` receives the list of
    joined sentences; an empty sentence is sent as ``'.'``.
    """
    joined = []
    for tokens in batch:
        joined.append('.' if tokens == [] else ' '.join(tokens))
    return params['google_use'](joined)


def make_embed_fn(module):
    """Return a function that runs the hub module ``module`` on its input.

    The graph, the string placeholder and the monitored session are created
    once here; the returned function feeds its argument to the placeholder
    and returns the result of ``session.run``.
    """
    with tf.Graph().as_default():
        sentences = tf.placeholder(tf.string)
        embeddings = hub.Module(module)(sentences)
        session = tf.train.MonitoredSession()

    def embed_fn(x):
        return session.run(embeddings, {sentences: x})

    return embed_fn
def prepare(params, samples):
    """Build the InferSent vocabulary from the task's tokenized samples."""
    joined = []
    for tokens in samples:
        joined.append(' '.join(tokens))
    params.infersent.build_vocab(joined, tokenize=False)


def batcher(params, batch):
    """Encode one batch of tokenized sentences with ``params.infersent``.

    Sentences are re-joined with single spaces and encoded without further
    tokenization, using ``params.batch_size`` as the encoder batch size.
    """
    return params.infersent.encode(
        [' '.join(tokens) for tokens in batch],
        bsize=params.batch_size,
        tokenize=False,
    )
class InferSent(nn.Module):
    """Bidirectional-LSTM sentence encoder with mean or max pooling.

    ``config`` must provide ``bsize``, ``word_emb_dim``, ``enc_lstm_dim``,
    ``pool_type`` (``"mean"`` or ``"max"``) and ``dpout_model``; ``version``
    (1 or 2, default 1) selects the sentence boundary tokens, the padding
    behaviour of max pooling and the tokenizer variant.
    """

    def __init__(self, config):
        super(InferSent, self).__init__()
        self.bsize = config['bsize']
        self.word_emb_dim = config['word_emb_dim']
        self.enc_lstm_dim = config['enc_lstm_dim']
        self.pool_type = config['pool_type']
        self.dpout_model = config['dpout_model']
        self.version = 1 if 'version' not in config else config['version']

        self.enc_lstm = nn.LSTM(self.word_emb_dim, self.enc_lstm_dim, 1,
                                bidirectional=True, dropout=self.dpout_model)

        assert self.version in [1, 2]
        # Boundary tokens restored to the upstream InferSent values; they had
        # been reduced to empty strings by markup stripping.
        if self.version == 1:
            self.bos = '<s>'
            self.eos = '</s>'
            self.max_pad = True
            self.moses_tok = False
        elif self.version == 2:
            self.bos = '<p>'
            self.eos = '</p>'
            self.max_pad = False
            self.moses_tok = True

    def is_cuda(self):
        """Return True when the LSTM weights live on a CUDA device."""
        # either all weights are on cpu or they are on gpu
        return self.enc_lstm.bias_hh_l0.data.is_cuda

    def forward(self, sent_tuple):
        """Encode a padded batch.

        Args:
            sent_tuple: ``(sent, sent_len)`` where ``sent`` is a tensor of
                shape (seqlen x bsize x worddim) and ``sent_len`` is a NumPy
                array with the true length of each sentence.

        Returns:
            A 2-D tensor with one pooled embedding per sentence, in the
            order of the input batch.
        """
        sent, sent_len = sent_tuple
        device = sent.device

        # Sort by length (keep idx): pack_padded_sequence expects decreasing
        # lengths. The copy removes the negative stride left by [::-1].
        sent_len_sorted, idx_sort = np.sort(sent_len)[::-1], np.argsort(-sent_len)
        sent_len_sorted = sent_len_sorted.copy()
        idx_unsort = np.argsort(idx_sort)

        sent = sent.index_select(1, torch.from_numpy(idx_sort).to(device))

        # Handling padding in Recurrent Networks
        sent_packed = nn.utils.rnn.pack_padded_sequence(sent, sent_len_sorted)
        sent_output = self.enc_lstm(sent_packed)[0]  # seqlen x batch x 2*nhid
        sent_output = nn.utils.rnn.pad_packed_sequence(sent_output)[0]

        # Un-sort by length
        sent_output = sent_output.index_select(
            1, torch.from_numpy(idx_unsort).to(device))

        # Pooling
        if self.pool_type == "mean":
            # Lengths follow the activations' device and dtype (the old code
            # called .cuda() unconditionally and failed on CPU).
            lengths = torch.as_tensor(sent_len.copy(),
                                      dtype=sent_output.dtype,
                                      device=sent_output.device).unsqueeze(1)
            # torch.sum already drops the time axis; an extra squeeze(0)
            # would collapse a batch of size 1 and break the division.
            emb = torch.sum(sent_output, 0) / lengths
        elif self.pool_type == "max":
            if not self.max_pad:
                # Push exact zeros far down so they cannot win the max.
                sent_output[sent_output == 0] = -1e9
            emb = torch.max(sent_output, 0)[0]
            if emb.ndimension() == 3:
                emb = emb.squeeze(0)
                assert emb.ndimension() == 2

        return emb

    def set_w2v_path(self, w2v_path):
        """Remember the path of the word-vector text file."""
        self.w2v_path = w2v_path

    def get_word_dict(self, sentences, tokenize=True):
        """Return a dict whose keys are all words of ``sentences`` plus BOS/EOS."""
        # create vocab of words
        word_dict = {}
        sentences = [s.split() if not tokenize else self.tokenize(s) for s in sentences]
        for sent in sentences:
            for word in sent:
                if word not in word_dict:
                    word_dict[word] = ''
        word_dict[self.bos] = ''
        word_dict[self.eos] = ''
        return word_dict

    def get_w2v(self, word_dict):
        """Load the vectors of the words in ``word_dict`` from ``w2v_path``."""
        assert hasattr(self, 'w2v_path'), 'w2v path not set'
        # create word_vec with w2v vectors
        word_vec = {}
        with open(self.w2v_path, encoding='utf-8') as f:
            for line in f:
                word, vec = line.split(' ', 1)
                if word in word_dict:
                    word_vec[word] = np.fromstring(vec, sep=' ')
        print('Found %s(/%s) words with w2v vectors' % (len(word_vec), len(word_dict)))
        return word_vec

    def get_w2v_k(self, K):
        """Load the first lines of ``w2v_path`` plus the BOS/EOS vectors.

        Lines are read while the counter ``k`` is ``<= K``; afterwards only
        the boundary tokens are still collected, and reading stops once both
        are present.
        """
        assert hasattr(self, 'w2v_path'), 'w2v path not set'
        # create word_vec with k first w2v vectors
        k = 0
        word_vec = {}
        with open(self.w2v_path, encoding='utf-8') as f:
            for line in f:
                word, vec = line.split(' ', 1)
                if k <= K:
                    word_vec[word] = np.fromstring(vec, sep=' ')
                    k += 1
                if k > K:
                    if word in [self.bos, self.eos]:
                        word_vec[word] = np.fromstring(vec, sep=' ')

                if k > K and all([w in word_vec for w in [self.bos, self.eos]]):
                    break
        return word_vec

    def build_vocab(self, sentences, tokenize=True):
        """Set ``self.word_vec`` to the vectors of the words in ``sentences``."""
        assert hasattr(self, 'w2v_path'), 'w2v path not set'
        word_dict = self.get_word_dict(sentences, tokenize)
        self.word_vec = self.get_w2v(word_dict)
        print('Vocab size : %s' % (len(self.word_vec)))

    # build w2v vocab with k most frequent words
    def build_vocab_k_words(self, K):
        """Set ``self.word_vec`` from the leading entries of the vector file."""
        assert hasattr(self, 'w2v_path'), 'w2v path not set'
        self.word_vec = self.get_w2v_k(K)
        print('Vocab size : %s' % (K))

    def update_vocab(self, sentences, tokenize=True):
        """Add vectors for words of ``sentences`` missing from ``self.word_vec``."""
        assert hasattr(self, 'w2v_path'), 'warning : w2v path not set'
        assert hasattr(self, 'word_vec'), 'build_vocab before updating it'
        word_dict = self.get_word_dict(sentences, tokenize)

        # keep only new words
        for word in self.word_vec:
            if word in word_dict:
                del word_dict[word]

        # update vocabulary
        if word_dict:
            new_word_vec = self.get_w2v(word_dict)
            self.word_vec.update(new_word_vec)
        else:
            new_word_vec = []
        print('New vocab size : %s (added %s words)' % (len(self.word_vec), len(new_word_vec)))

    def get_batch(self, batch):
        """Stack word vectors into a (max_len, bsize, word_dim) float tensor.

        ``batch`` must be ordered by decreasing length: the first sentence
        fixes the padded length, and shorter sentences are zero-padded.
        """
        embed = np.zeros((len(batch[0]), len(batch), self.word_emb_dim))

        for i in range(len(batch)):
            for j in range(len(batch[i])):
                embed[j, i, :] = self.word_vec[batch[i][j]]

        return torch.FloatTensor(embed)

    def tokenize(self, s):
        """Tokenize ``s`` with NLTK, adjusted when ``moses_tok`` is set."""
        from nltk.tokenize import word_tokenize
        if self.moses_tok:
            s = ' '.join(word_tokenize(s))
            s = s.replace(" n't ", "n 't ")  # HACK to get ~MOSES tokenization
            return s.split()
        else:
            return word_tokenize(s)

    def prepare_samples(self, sentences, bsize, tokenize, verbose):
        """Tokenize, filter and length-sort ``sentences`` for encoding.

        Returns:
            ``(sentences, lengths, idx_sort)``: a 1-D object array of token
            lists sorted by decreasing length, the matching lengths, and the
            permutation that was applied.
        """
        sentences = [[self.bos] + s.split() + [self.eos] if not tokenize else
                     [self.bos] + self.tokenize(s) + [self.eos] for s in sentences]
        n_w = np.sum([len(x) for x in sentences])

        # filters words without w2v vectors
        for i in range(len(sentences)):
            s_f = [word for word in sentences[i] if word in self.word_vec]
            if not s_f:
                import warnings
                warnings.warn('No words in "%s" (idx=%s) have w2v vectors. '
                              'Replacing by "%s"..' % (sentences[i], i, self.eos))
                s_f = [self.eos]
            sentences[i] = s_f

        lengths = np.array([len(s) for s in sentences])
        n_wk = np.sum(lengths)
        if verbose:
            print('Nb words kept : %s/%s (%.1f%s)' % (
                n_wk, n_w, 100.0 * n_wk / n_w, '%'))

        # sort by decreasing length
        lengths, idx_sort = np.sort(lengths)[::-1], np.argsort(-lengths)
        # Fill an object array by hand: np.array() on token lists of unequal
        # length raises on current NumPy releases.
        ordered = np.empty(len(sentences), dtype=object)
        for pos, src in enumerate(idx_sort):
            ordered[pos] = sentences[src]

        return ordered, lengths, idx_sort

    def encode(self, sentences, bsize=64, tokenize=True, verbose=False):
        """Return one embedding row per input sentence, in input order."""
        tic = time.time()
        sentences, lengths, idx_sort = self.prepare_samples(
            sentences, bsize, tokenize, verbose)

        embeddings = []
        for stidx in range(0, len(sentences), bsize):
            batch = self.get_batch(sentences[stidx:stidx + bsize])
            if self.is_cuda():
                batch = batch.cuda()
            with torch.no_grad():
                batch = self.forward((batch, lengths[stidx:stidx + bsize])).data.cpu().numpy()
            embeddings.append(batch)
        embeddings = np.vstack(embeddings)

        # unsort
        idx_unsort = np.argsort(idx_sort)
        embeddings = embeddings[idx_unsort]

        if verbose:
            print('Speed : %.1f sentences/s (%s mode, bsize=%s)' % (
                len(embeddings) / (time.time() - tic),
                'gpu' if self.is_cuda() else 'cpu', bsize))
        return embeddings

    def visualize(self, sent, tokenize=True):
        """Plot how often each word supplies the max-pooled LSTM feature.

        Returns:
            ``(output, idxs)``: the max over time of the LSTM output and the
            positions (as a NumPy array) where each maximum was found.
        """
        sent = sent.split() if not tokenize else self.tokenize(sent)
        sent = [[self.bos] + [word for word in sent if word in self.word_vec] + [self.eos]]

        if ' '.join(sent[0]) == '%s %s' % (self.bos, self.eos):
            import warnings
            warnings.warn('No words in "%s" have w2v vectors. Replacing '
                          'by "%s %s"..' % (sent, self.bos, self.eos))
        batch = self.get_batch(sent)

        if self.is_cuda():
            batch = batch.cuda()
        output = self.enc_lstm(batch)[0]
        output, idxs = torch.max(output, 0)
        # output, idxs = output.squeeze(), idxs.squeeze()
        idxs = idxs.data.cpu().numpy()
        argmaxs = [np.sum((idxs == k)) for k in range(len(sent[0]))]

        # visualize model
        import matplotlib.pyplot as plt
        x = range(len(sent[0]))
        y = [100.0 * n / np.sum(argmaxs) for n in argmaxs]
        plt.xticks(x, sent[0], rotation=45)
        plt.bar(x, y)
        plt.ylabel('%')
        plt.title('Visualisation of words importance')
        plt.show()

        return output, idxs
for sent in batch] 37 | embeddings = skipthoughts.encode(params['encoder'], batch, 38 | verbose=False, use_eos=True) 39 | return embeddings 40 | 41 | 42 | # Set params for SentEval 43 | params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10, 'batch_size': 512} 44 | params_senteval['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64, 45 | 'tenacity': 5, 'epoch_size': 4} 46 | # Set up logger 47 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 48 | 49 | if __name__ == "__main__": 50 | # Load SkipThought model 51 | params_senteval['encoder'] = skipthoughts.load_model() 52 | 53 | se = senteval.engine.SE(params_senteval, batcher, prepare) 54 | transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 55 | 'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 56 | 'SICKEntailment', 'SICKRelatedness', 'STSBenchmark', 57 | 'Length', 'WordContent', 'Depth', 'TopConstituents', 58 | 'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber', 59 | 'OddManOut', 'CoordinationInversion'] 60 | results = se.eval(transfer_tasks) 61 | print(results) 62 | -------------------------------------------------------------------------------- /SentEval/senteval/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import 9 | 10 | from senteval.engine import SE 11 | -------------------------------------------------------------------------------- /SentEval/senteval/binary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 
class BinaryClassifierEval(object):
    """Shared evaluation logic for the two-class tasks MR, CR, SUBJ and MPQA."""

    def __init__(self, pos, neg, seed=1111):
        """Store the samples; ``pos`` items get label 1 and ``neg`` items 0."""
        self.seed = seed
        self.samples, self.labels = pos + neg, [1] * len(pos) + [0] * len(neg)
        self.n_samples = len(self.samples)

    def do_prepare(self, params, prepare):
        """Call ``prepare`` with the whole corpus and return its result.

        ``prepare`` stores whatever it builds on ``params`` (``word2id``
        etc.); that state is later consumed by ``batcher``.
        """
        # prepare is given the whole text
        return prepare(params, self.samples)

    def loadFile(self, fpath):
        """Read ``fpath`` (latin-1) and return one list of tokens per line."""
        with io.open(fpath, 'r', encoding='latin-1') as f:
            return [line.split() for line in f.read().splitlines()]

    def run(self, params, batcher):
        """Embed all samples and score them with nested k-fold classification.

        Returns:
            dict with ``devacc``, ``acc``, ``ndev`` and ``ntest``.
        """
        # Sort to reduce padding
        sorted_corpus = sorted(zip(self.samples, self.labels),
                               key=lambda z: (len(z[0]), z[1]))
        sorted_samples = [x for (x, y) in sorted_corpus]
        sorted_labels = [y for (x, y) in sorted_corpus]

        logging.info('Generating sentence embeddings')
        enc_input = []
        for ii in range(0, self.n_samples, params.batch_size):
            batch = sorted_samples[ii:ii + params.batch_size]
            enc_input.append(batcher(params, batch))
        enc_input = np.vstack(enc_input)
        logging.info('Generated sentence embeddings')

        config = {'nclasses': 2, 'seed': self.seed,
                  'usepytorch': params.usepytorch,
                  'classifier': params.classifier,
                  'nhid': params.nhid, 'kfold': params.kfold}
        clf = InnerKFoldClassifier(enc_input, np.array(sorted_labels), config)
        devacc, testacc = clf.run()
        logging.debug('Dev acc : {0} Test acc : {1}\n'.format(devacc, testacc))
        return {'devacc': devacc, 'acc': testacc, 'ndev': self.n_samples,
                'ntest': self.n_samples}


# The subclasses name themselves in super(): super(self.__class__, self)
# resolves to the *runtime* class and recurses forever once any of them is
# subclassed.
class CREval(BinaryClassifierEval):
    """CR task: reads ``custrev.pos`` / ``custrev.neg`` from ``task_path``."""

    def __init__(self, task_path, seed=1111):
        logging.debug('***** Transfer task : CR *****\n\n')
        pos = self.loadFile(os.path.join(task_path, 'custrev.pos'))
        neg = self.loadFile(os.path.join(task_path, 'custrev.neg'))
        super(CREval, self).__init__(pos, neg, seed)


class MREval(BinaryClassifierEval):
    """MR task: reads ``rt-polarity.pos`` / ``rt-polarity.neg``."""

    def __init__(self, task_path, seed=1111):
        logging.debug('***** Transfer task : MR *****\n\n')
        pos = self.loadFile(os.path.join(task_path, 'rt-polarity.pos'))
        neg = self.loadFile(os.path.join(task_path, 'rt-polarity.neg'))
        super(MREval, self).__init__(pos, neg, seed)


class SUBJEval(BinaryClassifierEval):
    """SUBJ task: ``subj.objective`` is label 1, ``subj.subjective`` label 0."""

    def __init__(self, task_path, seed=1111):
        logging.debug('***** Transfer task : SUBJ *****\n\n')
        obj = self.loadFile(os.path.join(task_path, 'subj.objective'))
        subj = self.loadFile(os.path.join(task_path, 'subj.subjective'))
        super(SUBJEval, self).__init__(obj, subj, seed)


class MPQAEval(BinaryClassifierEval):
    """MPQA task: reads ``mpqa.pos`` / ``mpqa.neg`` from ``task_path``."""

    def __init__(self, task_path, seed=1111):
        logging.debug('***** Transfer task : MPQA *****\n\n')
        pos = self.loadFile(os.path.join(task_path, 'mpqa.pos'))
        neg = self.loadFile(os.path.join(task_path, 'mpqa.neg'))
        super(MPQAEval, self).__init__(pos, neg, seed)
class SE(object):
    """SentEval engine: builds an evaluation task by name and runs it.

    Args:
        params: dict of settings. It is wrapped in ``utils.dotdict`` and
            completed with defaults for ``usepytorch``, ``seed``,
            ``batch_size``, ``nhid``, ``kfold`` and ``classifier``.
        batcher: callable ``batcher(params, batch)`` producing embeddings.
        prepare: optional callable ``prepare(params, samples)``; defaults to
            a no-op.
    """

    def __init__(self, params, batcher, prepare=None):
        # parameters
        params = utils.dotdict(params)
        params.usepytorch = True if 'usepytorch' not in params else params.usepytorch
        params.seed = 1111 if 'seed' not in params else params.seed

        params.batch_size = 128 if 'batch_size' not in params else params.batch_size
        params.nhid = 0 if 'nhid' not in params else params.nhid
        params.kfold = 5 if 'kfold' not in params else params.kfold

        if 'classifier' not in params or not params['classifier']:
            params.classifier = {'nhid': 0}

        assert 'nhid' in params.classifier, 'Set number of hidden units in classifier config!!'

        self.params = params

        # batcher and prepare
        self.batcher = batcher
        self.prepare = prepare if prepare else lambda x, y: None

        self.list_tasks = ['CR', 'MR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
                           'SICKRelatedness', 'SICKEntailment', 'STSBenchmark',
                           'SNLI', 'ImageCaptionRetrieval', 'STS12', 'STS13',
                           'STS14', 'STS15', 'STS16',
                           'Length', 'WordContent', 'Depth', 'TopConstituents',
                           'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
                           'OddManOut', 'CoordinationInversion',
                           'SICKRelatedness-finetune', 'STSBenchmark-finetune',
                           'STSBenchmark-fix', 'STSBenchmark-dev']

    def _make_evaluation(self, name):
        """Instantiate the evaluation object registered for task ``name``."""
        # task name -> (evaluation class, data sub-path, extra keyword args).
        # An explicit table replaces the former builtin eval() lookup of
        # '<name>Eval', so no code is built from strings.
        registry = {
            # Original SentEval tasks
            'CR': (CREval, '/downstream/CR', {}),
            'MR': (MREval, '/downstream/MR', {}),
            'MPQA': (MPQAEval, '/downstream/MPQA', {}),
            'SUBJ': (SUBJEval, '/downstream/SUBJ', {}),
            'SST2': (SSTEval, '/downstream/SST/binary', {'nclasses': 2}),
            'SST5': (SSTEval, '/downstream/SST/fine', {'nclasses': 5}),
            'TREC': (TRECEval, '/downstream/TREC', {}),
            'MRPC': (MRPCEval, '/downstream/MRPC', {}),
            'SICKRelatedness': (SICKRelatednessEval, '/downstream/SICK', {}),
            'STSBenchmark': (STSBenchmarkEval, '/downstream/STS/STSBenchmark', {}),
            'STSBenchmark-dev': (STSBenchmarkEvalDev, '/downstream/STS/STSBenchmark', {}),
            'STSBenchmark-fix': (STSBenchmarkEval, '/downstream/STS/STSBenchmark-fix', {}),
            'STSBenchmark-finetune': (STSBenchmarkFinetune, '/downstream/STS/STSBenchmark', {}),
            'SICKRelatedness-finetune': (SICKEval, '/downstream/SICK', {}),
            'SICKEntailment': (SICKEntailmentEval, '/downstream/SICK', {}),
            'SNLI': (SNLIEval, '/downstream/SNLI', {}),
            'STS12': (STS12Eval, '/downstream/STS/STS12-en-test', {}),
            'STS13': (STS13Eval, '/downstream/STS/STS13-en-test', {}),
            'STS14': (STS14Eval, '/downstream/STS/STS14-en-test', {}),
            'STS15': (STS15Eval, '/downstream/STS/STS15-en-test', {}),
            'STS16': (STS16Eval, '/downstream/STS/STS16-en-test', {}),
            'ImageCaptionRetrieval': (ImageCaptionRetrievalEval, '/downstream/COCO', {}),
            # Probing Tasks
            'Length': (LengthEval, '/probing', {}),
            'WordContent': (WordContentEval, '/probing', {}),
            'Depth': (DepthEval, '/probing', {}),
            'TopConstituents': (TopConstituentsEval, '/probing', {}),
            'BigramShift': (BigramShiftEval, '/probing', {}),
            'Tense': (TenseEval, '/probing', {}),
            'SubjNumber': (SubjNumberEval, '/probing', {}),
            'ObjNumber': (ObjNumberEval, '/probing', {}),
            'OddManOut': (OddManOutEval, '/probing', {}),
            'CoordinationInversion': (CoordinationInversionEval, '/probing', {}),
        }
        eval_cls, subpath, extra = registry[name]
        return eval_cls(self.params.task_path + subpath,
                        seed=self.params.seed, **extra)

    def eval(self, name):
        """Run one task, or every task of a list/tuple of task names.

        Args:
            name: a task name from ``self.list_tasks`` or a list/tuple of
                such names.

        Returns:
            The result of the task's ``run``; for a list/tuple, a dict
            mapping each name to its result. The value is also stored in
            ``self.results``.
        """
        # evaluate on evaluation [name], either takes string or list of strings
        if isinstance(name, (list, tuple)):
            self.results = {x: self.eval(x) for x in name}
            return self.results

        assert name in self.list_tasks, str(name) + ' not in ' + str(self.list_tasks)

        self.evaluation = self._make_evaluation(name)

        self.params.current_task = name
        self.evaluation.do_prepare(self.params, self.prepare)

        self.results = self.evaluation.run(self.params, self.batcher)

        return self.results
def loadFile(self, fpath):
    """Read one MRPC split from a tab-separated file.

    The first line of the file is a header and is skipped. For every
    remaining non-blank row, column 0 is parsed as the integer label and
    columns 3 and 4 are the two sentences, split on whitespace.

    Returns:
        dict with keys 'X_A' and 'X_B' (lists of token lists) and 'y'
        (list of int labels), aligned row by row.

    Raises:
        IndexError: if a non-blank row has fewer than five columns.
        ValueError: if a label is not an integer.
    """
    mrpc_data = {'X_A': [], 'X_B': [], 'y': []}
    with io.open(fpath, 'r', encoding='utf-8') as f:
        # Drop the header before parsing instead of parsing it and
        # slicing it away afterwards.
        next(f, None)
        for line in f:
            if not line.strip():
                # A blank (e.g. trailing) line carries no sample.
                continue
            text = line.strip().split('\t')
            mrpc_data['X_A'].append(text[3].split())
            mrpc_data['X_B'].append(text[4].split())
            mrpc_data['y'].append(int(text[0]))
    return mrpc_data


def run(self, params, batcher):
    """Embed the MRPC train/test pairs and evaluate a k-fold classifier.

    Each split in ``self.mrpc_data`` is sorted by sentence length, fed to
    ``batcher(params, batch)`` in chunks of ``params.batch_size`` (all
    first sentences, then all second sentences), and the per-batch
    outputs are stacked. A pair is represented as ``[|a - b|, a * b]``.

    Returns:
        dict with 'devacc', 'acc', 'f1' (test F1 in percent, rounded to
        two decimals), 'ndev' and 'ntest'.
    """
    mrpc_embed = {'train': {}, 'test': {}}

    for key in self.mrpc_data:
        logging.info('Computing embedding for {0}'.format(key))
        split = self.mrpc_data[key]
        # Sort to reduce padding
        ordered = sorted(zip(split['X_A'], split['X_B'], split['y']),
                         key=lambda z: (len(z[0]), len(z[1]), z[2]))
        text_data = {'A': [a for (a, _, _) in ordered],
                     'B': [b for (_, b, _) in ordered],
                     'y': [label for (_, _, label) in ordered]}

        for txt_type in ['A', 'B']:
            chunks = []
            for start in range(0, len(text_data['y']), params.batch_size):
                batch = text_data[txt_type][start:start + params.batch_size]
                chunks.append(batcher(params, batch))
            mrpc_embed[key][txt_type] = np.vstack(chunks)
        mrpc_embed[key]['y'] = np.array(text_data['y'])
        logging.info('Computed {0} embeddings'.format(key))

    # Train
    trainA = mrpc_embed['train']['A']
    trainB = mrpc_embed['train']['B']
    trainF = np.c_[np.abs(trainA - trainB), trainA * trainB]
    trainY = mrpc_embed['train']['y']

    # Test
    testA = mrpc_embed['test']['A']
    testB = mrpc_embed['test']['B']
    testF = np.c_[np.abs(testA - testB), testA * testB]
    testY = mrpc_embed['test']['y']

    config = {'nclasses': 2, 'seed': self.seed,
              'usepytorch': params.usepytorch,
              'classifier': params.classifier,
              'nhid': params.nhid, 'kfold': params.kfold}
    clf = KFoldClassifier(train={'X': trainF, 'y': trainY},
                          test={'X': testF, 'y': testY}, config=config)

    devacc, testacc, yhat = clf.run()
    testf1 = round(100*f1_score(testY, yhat), 2)
    logging.debug('Dev acc : {0} Test acc {1}; Test F1 {2} for MRPC.\n'
                  .format(devacc, testacc, testf1))
    return {'devacc': devacc, 'acc': testacc, 'f1': testf1,
            'ndev': len(trainA), 'ntest': len(testA)}
class PROBINGEval(object):
    """Base evaluator shared by the probing tasks.

    A task file is tab-separated: the first column is the split tag
    ('tr', 'va' or 'te'), the second column is the label and the last
    column is the sentence. Labels are mapped to integer ids in sorted
    order of the labels seen in the training split.
    """

    def __init__(self, task, task_path, seed=1111):
        """Load the task file at ``task_path`` and index its labels."""
        self.seed = seed
        self.task = task
        logging.debug('***** (Probing) Transfer task : %s classification *****', self.task.upper())
        self.task_data = {'train': {'X': [], 'y': []},
                          'dev': {'X': [], 'y': []},
                          'test': {'X': [], 'y': []}}
        self.loadFile(task_path)
        logging.info('Loaded %s train - %s dev - %s test for %s',
                     len(self.task_data['train']['y']),
                     len(self.task_data['dev']['y']),
                     len(self.task_data['test']['y']), self.task)

    def do_prepare(self, params, prepare):
        """Call ``prepare`` with every sentence of all three splits."""
        samples = self.task_data['train']['X'] + self.task_data['dev']['X'] + \
            self.task_data['test']['X']
        return prepare(params, samples)

    def loadFile(self, fpath):
        """Fill ``self.task_data`` from ``fpath`` and build the label map.

        Sets ``self.tok2split``, ``self.tok2label`` and ``self.nclasses``
        and replaces the string labels of every split by integer ids.

        Raises:
            KeyError: on an unknown split tag, or on a dev/test label
                that never occurs in the training split.
        """
        self.tok2split = {'tr': 'train', 'va': 'dev', 'te': 'test'}
        with io.open(fpath, 'r', encoding='utf-8') as f:
            for line in f:
                stripped = line.rstrip()
                if not stripped:
                    # Blank lines used to fail with KeyError('').
                    continue
                fields = stripped.split('\t')
                split = self.task_data[self.tok2split[fields[0]]]
                split['X'].append(fields[-1].split())
                split['y'].append(fields[1])

        # Class ids follow the sorted order of the training labels.
        labels = sorted(set(self.task_data['train']['y']))
        self.tok2label = dict(zip(labels, range(len(labels))))
        self.nclasses = len(self.tok2label)

        for split in self.task_data:
            self.task_data[split]['y'] = [self.tok2label[y]
                                          for y in self.task_data[split]['y']]

    def run(self, params, batcher):
        """Embed all splits and train/evaluate a ``SplitClassifier``.

        Sentences of each split are sorted by length (the sorted order is
        written back to ``self.task_data``) and passed to
        ``batcher(params, batch)`` in chunks of ``params.batch_size``.

        Returns:
            dict with 'devacc', 'acc', 'ndev' and 'ntest'.
        """
        task_embed = {'train': {}, 'dev': {}, 'test': {}}
        bsize = params.batch_size
        logging.info('Computing embeddings for train/dev/test')
        for key in self.task_data:
            # Sort to reduce padding
            sorted_data = sorted(zip(self.task_data[key]['X'],
                                     self.task_data[key]['y']),
                                 key=lambda z: (len(z[0]), z[1]))
            self.task_data[key]['X'], self.task_data[key]['y'] = map(list, zip(*sorted_data))

            chunks = []
            for ii in range(0, len(self.task_data[key]['y']), bsize):
                batch = self.task_data[key]['X'][ii:ii + bsize]
                chunks.append(batcher(params, batch))
            task_embed[key]['X'] = np.vstack(chunks)
            task_embed[key]['y'] = np.array(self.task_data[key]['y'])
        logging.info('Computed embeddings')

        config_classifier = {'nclasses': self.nclasses, 'seed': self.seed,
                             'usepytorch': params.usepytorch,
                             'classifier': params.classifier}

        if self.task == "WordContent" and params.classifier['nhid'] > 0:
            # Deep-copy first so the caller's params.classifier keeps its
            # own nhid; only this task's config is forced to nhid = 0.
            config_classifier = copy.deepcopy(config_classifier)
            config_classifier['classifier']['nhid'] = 0
            logging.debug('WordContent: forcing classifier nhid to 0 (params nhid is %s)',
                          params.classifier['nhid'])

        clf = SplitClassifier(X={'train': task_embed['train']['X'],
                                 'valid': task_embed['dev']['X'],
                                 'test': task_embed['test']['X']},
                              y={'train': task_embed['train']['y'],
                                 'valid': task_embed['dev']['y'],
                                 'test': task_embed['test']['y']},
                              config=config_classifier)

        devacc, testacc = clf.run()
        logging.debug('\nDev acc : %.1f Test acc : %.1f for %s classification\n',
                      devacc, testacc, self.task.upper())

        return {'devacc': devacc, 'acc': testacc,
                'ndev': len(task_embed['dev']['X']),
                'ntest': len(task_embed['test']['X'])}


# Surface Information

class LengthEval(PROBINGEval):
    """Probing task read from 'sentence_length.txt'."""

    def __init__(self, task_path, seed=1111):
        task_path = os.path.join(task_path, 'sentence_length.txt')
        # labels: bins
        PROBINGEval.__init__(self, 'Length', task_path, seed)


class WordContentEval(PROBINGEval):
    """Probing task read from 'word_content.txt'."""

    def __init__(self, task_path, seed=1111):
        task_path = os.path.join(task_path, 'word_content.txt')
        # labels: 200 target words
        PROBINGEval.__init__(self, 'WordContent', task_path, seed)


# Latent Structural Information

class DepthEval(PROBINGEval):
    """Probing task read from 'tree_depth.txt'."""

    def __init__(self, task_path, seed=1111):
        task_path = os.path.join(task_path, 'tree_depth.txt')
        # labels: bins
        PROBINGEval.__init__(self, 'Depth', task_path, seed)


class TopConstituentsEval(PROBINGEval):
    """Probing task read from 'top_constituents.txt'."""

    def __init__(self, task_path, seed=1111):
        task_path = os.path.join(task_path, 'top_constituents.txt')
        # labels: 'PP_NP_VP_.' .. (20 classes)
        PROBINGEval.__init__(self, 'TopConstituents', task_path, seed)


class BigramShiftEval(PROBINGEval):
    """Probing task read from 'bigram_shift.txt'."""

    def __init__(self, task_path, seed=1111):
        task_path = os.path.join(task_path, 'bigram_shift.txt')
        # labels: 0 or 1
        PROBINGEval.__init__(self, 'BigramShift', task_path, seed)

# TODO: Voice?


# Latent Semantic Information

class TenseEval(PROBINGEval):
    """Probing task read from 'past_present.txt'."""

    def __init__(self, task_path, seed=1111):
        task_path = os.path.join(task_path, 'past_present.txt')
        # labels: 'PRES', 'PAST'
        PROBINGEval.__init__(self, 'Tense', task_path, seed)


class SubjNumberEval(PROBINGEval):
    """Probing task read from 'subj_number.txt'."""

    def __init__(self, task_path, seed=1111):
        task_path = os.path.join(task_path, 'subj_number.txt')
        # labels: 'NN', 'NNS'
        PROBINGEval.__init__(self, 'SubjNumber', task_path, seed)


class ObjNumberEval(PROBINGEval):
    """Probing task read from 'obj_number.txt'."""

    def __init__(self, task_path, seed=1111):
        task_path = os.path.join(task_path, 'obj_number.txt')
        # labels: 'NN', 'NNS'
        PROBINGEval.__init__(self, 'ObjNumber', task_path, seed)


class OddManOutEval(PROBINGEval):
    """Probing task read from 'odd_man_out.txt'."""

    def __init__(self, task_path, seed=1111):
        task_path = os.path.join(task_path, 'odd_man_out.txt')
        # labels: 'O', 'C'
        PROBINGEval.__init__(self, 'OddManOut', task_path, seed)
class ImageCaptionRetrievalEval(object):
    """Image-caption retrieval task on pickled COCO captions and features."""

    def __init__(self, task_path, seed=1111):
        """Load the train/valid/test pickles found under ``task_path``."""
        logging.debug('***** Transfer task: Image Caption Retrieval *****\n\n')

        # Get captions and image features
        self.seed = seed
        train, dev, test = self.loadFile(task_path)
        self.coco_data = {'train': train, 'dev': dev, 'test': test}

    def do_prepare(self, params, prepare):
        """Call ``prepare`` with the captions of all splits and return its result."""
        samples = self.coco_data['train']['sent'] + \
            self.coco_data['dev']['sent'] + \
            self.coco_data['test']['sent']
        # Return the result like the other task classes do (it was dropped before).
        return prepare(params, samples)

    def loadFile(self, fpath):
        """Read ``train.pkl``, ``valid.pkl`` and ``test.pkl`` from ``fpath``.

        For every image the first five captions are kept, ' .' is
        appended to each caption, and the image feature is repeated once
        per caption so captions and features stay aligned.

        Returns:
            (train, valid, test), each a dict with 'sent' (list of token
            lists) and 'imgfeat' (float32 array, one row per caption).
        """
        coco = {}

        for split in ['train', 'valid', 'test']:
            list_sent = []
            list_img_feat = []
            pkl_path = os.path.join(fpath, split + '.pkl')
            # NOTE(review): unpickling can execute arbitrary code; only
            # point task_path at trusted data files.
            if sys.version_info < (3, 0):
                with open(pkl_path) as f:
                    cocodata = pickle.load(f)
            else:
                with open(pkl_path, 'rb') as f:
                    cocodata = pickle.load(f, encoding='latin1')

            for imgkey in range(len(cocodata['features'])):
                caption_ids = cocodata['image_to_caption_ids'][imgkey]
                assert len(caption_ids) >= 5, caption_ids
                for captkey in caption_ids[0:5]:
                    sent = cocodata['captions'][captkey]['cleaned_caption']
                    sent += ' .'  # add punctuation to end of sentence in COCO
                    # NOTE(review): on Python 3 this yields bytes tokens;
                    # kept as-is, confirm what the batcher expects.
                    list_sent.append(sent.encode('utf-8').split())
                    list_img_feat.append(cocodata['features'][imgkey])
            assert len(list_sent) == len(list_img_feat) and \
                len(list_sent) % 5 == 0
            list_img_feat = np.array(list_img_feat).astype('float32')
            coco[split] = {'sent': list_sent, 'imgfeat': list_img_feat}
        return coco['train'], coco['valid'], coco['test']

    def _embed_sentences(self, sents, params, batcher):
        """Embed ``sents`` in length-sorted batches, rows in original order."""
        # Stable argsort on lengths: batches hold similar-length captions.
        order = np.argsort([len(s) for s in sents], kind='mergesort')
        restore = np.argsort(order)
        by_length = [sents[i] for i in order]
        chunks = []
        for ii in range(0, len(by_length), params.batch_size):
            chunks.append(batcher(params, by_length[ii:ii + params.batch_size]))
        # Undo the sort so row i lines up with image feature i again.
        return np.vstack(chunks)[restore]

    def run(self, params, batcher):
        """Embed the captions of every split and run the ranking model.

        The previous implementation turned the caption list into a numpy
        array and called ``np.sort``/``np.argsort`` on it, which does not
        order by length (and reorders tokens when all captions have the
        same length). Captions are now ordered by token count through an
        index permutation and ``self.coco_data`` is left untouched.

        Returns:
            dict with 'devacc', 'acc' (image-to-text and text-to-image
            score tuples), 'ndev' and 'ntest'.
        """
        coco_embed = {'train': {'sentfeat': [], 'imgfeat': []},
                      'dev': {'sentfeat': [], 'imgfeat': []},
                      'test': {'sentfeat': [], 'imgfeat': []}}

        for key in self.coco_data:
            logging.info('Computing embedding for {0}'.format(key))
            # NOTE(review): the 'X' entry is never filled here; kept in
            # case the ranking tool inspects it - confirm and drop.
            coco_embed[key]['X'] = []
            coco_embed[key]['sentfeat'] = self._embed_sentences(
                self.coco_data[key]['sent'], params, batcher)
            coco_embed[key]['imgfeat'] = np.array(self.coco_data[key]['imgfeat'])
            logging.info('Computed {0} embeddings'.format(key))

        config = {'seed': self.seed, 'projdim': 1000, 'margin': 0.2}
        clf = ImageSentenceRankingPytorch(train=coco_embed['train'],
                                          valid=coco_embed['dev'],
                                          test=coco_embed['test'],
                                          config=config)

        bestdevscore, r1_i2t, r5_i2t, r10_i2t, medr_i2t, \
            r1_t2i, r5_t2i, r10_t2i, medr_t2i = clf.run()

        # Adjacent literals instead of a backslash inside the string, so
        # the source indentation no longer ends up in the log message.
        logging.debug("\nTest scores | Image to text: "
                      "{0}, {1}, {2}, {3}".format(r1_i2t, r5_i2t, r10_i2t, medr_i2t))
        logging.debug("Test scores | Text to image: "
                      "{0}, {1}, {2}, {3}\n".format(r1_t2i, r5_t2i, r10_t2i, medr_t2i))

        return {'devacc': bestdevscore,
                'acc': [(r1_i2t, r5_i2t, r10_i2t, medr_i2t),
                        (r1_t2i, r5_t2i, r10_t2i, medr_t2i)],
                'ndev': len(coco_embed['dev']['sentfeat']),
                'ntest': len(coco_embed['test']['sentfeat'])}
class SICKEval(object):
    """SICK relatedness task: predict a similarity score for sentence pairs."""

    def __init__(self, task_path, seed=1111):
        """Load the SICK train, trial (dev) and annotated test files."""
        logging.debug('***** Transfer task : SICK-Relatedness*****\n\n')
        self.seed = seed
        train = self.loadFile(os.path.join(task_path, 'SICK_train.txt'))
        dev = self.loadFile(os.path.join(task_path, 'SICK_trial.txt'))
        test = self.loadFile(os.path.join(task_path, 'SICK_test_annotated.txt'))
        self.sick_data = {'train': train, 'dev': dev, 'test': test}

    def do_prepare(self, params, prepare):
        """Call ``prepare`` with every sentence (A and B) of all splits."""
        samples = self.sick_data['train']['X_A'] + \
            self.sick_data['train']['X_B'] + \
            self.sick_data['dev']['X_A'] + \
            self.sick_data['dev']['X_B'] + \
            self.sick_data['test']['X_A'] + self.sick_data['test']['X_B']
        return prepare(params, samples)

    def loadFile(self, fpath):
        """Read one tab-separated SICK file.

        The header line is skipped. Columns 1 and 2 are the two
        sentences (split on whitespace) and column 3 is the score.

        Returns:
            dict with 'X_A', 'X_B' (lists of token lists) and 'y' (list
            of float scores).
        """
        sick_data = {'X_A': [], 'X_B': [], 'y': []}
        with io.open(fpath, 'r', encoding='utf-8') as f:
            next(f, None)  # header row
            for line in f:
                if not line.strip():
                    # Blank lines used to raise IndexError.
                    continue
                text = line.strip().split('\t')
                sick_data['X_A'].append(text[1].split())
                sick_data['X_B'].append(text[2].split())
                sick_data['y'].append(float(text[3]))
        return sick_data

    def run(self, params, batcher):
        """Embed all splits, fit ``RelatednessPytorch`` and score the test set.

        Each split is sorted by sentence length (the sorted order is
        written back to ``self.sick_data``) and embedded in chunks of
        ``params.batch_size``. A pair is represented as
        ``[|a - b|, a * b]``.

        Returns:
            dict with 'devspearman', 'pearson', 'spearman', 'mse',
            'yhat', 'ndev' and 'ntest'.
        """
        sick_embed = {'train': {}, 'dev': {}, 'test': {}}
        bsize = params.batch_size

        for key in self.sick_data:
            logging.info('Computing embedding for {0}'.format(key))
            split = self.sick_data[key]
            # Sort to reduce padding
            sorted_corpus = sorted(zip(split['X_A'], split['X_B'], split['y']),
                                   key=lambda z: (len(z[0]), len(z[1]), z[2]))

            split['X_A'] = [x for (x, y, z) in sorted_corpus]
            split['X_B'] = [y for (x, y, z) in sorted_corpus]
            split['y'] = [z for (x, y, z) in sorted_corpus]

            for txt_type in ['X_A', 'X_B']:
                chunks = []
                for ii in range(0, len(split['y']), bsize):
                    chunks.append(batcher(params, split[txt_type][ii:ii + bsize]))
                sick_embed[key][txt_type] = np.vstack(chunks)
            sick_embed[key]['y'] = np.array(split['y'])
            logging.info('Computed {0} embeddings'.format(key))

        # Train
        trainA = sick_embed['train']['X_A']
        trainB = sick_embed['train']['X_B']
        trainF = np.c_[np.abs(trainA - trainB), trainA * trainB]
        trainY = self.encode_labels(self.sick_data['train']['y'])

        # Dev
        devA = sick_embed['dev']['X_A']
        devB = sick_embed['dev']['X_B']
        devF = np.c_[np.abs(devA - devB), devA * devB]
        devY = self.encode_labels(self.sick_data['dev']['y'])

        # Test
        testA = sick_embed['test']['X_A']
        testB = sick_embed['test']['X_B']
        testF = np.c_[np.abs(testA - testB), testA * testB]
        testY = self.encode_labels(self.sick_data['test']['y'])

        config = {'seed': self.seed, 'nclasses': 5}
        clf = RelatednessPytorch(train={'X': trainF, 'y': trainY},
                                 valid={'X': devF, 'y': devY},
                                 test={'X': testF, 'y': testY},
                                 devscores=self.sick_data['dev']['y'],
                                 config=config)

        devspr, yhat = clf.run()

        pr = pearsonr(yhat, self.sick_data['test']['y'])[0]
        sr = spearmanr(yhat, self.sick_data['test']['y'])[0]
        # x != x is only true for NaN: report 0 for undefined correlations.
        pr = 0 if pr != pr else pr
        sr = 0 if sr != sr else sr
        se = mean_squared_error(yhat, self.sick_data['test']['y'])
        logging.debug('Dev : Spearman {0}'.format(devspr))
        # Adjacent literals instead of a backslash inside the string, so
        # the source indentation no longer ends up in the log message.
        logging.debug('Test : Pearson {0} Spearman {1} MSE {2} '
                      'for SICK Relatedness\n'.format(pr, sr, se))

        return {'devspearman': devspr, 'pearson': pr, 'spearman': sr, 'mse': se,
                'yhat': yhat, 'ndev': len(devA), 'ntest': len(testA)}

    def encode_labels(self, labels, nclass=5):
        """
        Label encoding from Tree LSTM paper (Tai, Socher, Manning)

        A score y is spread over the two neighbouring integer classes:
        class floor(y) gets weight floor(y) - y + 1 and class
        floor(y) + 1 gets weight y - floor(y). Class c is stored in
        column c - 1. Returns a float32 array of shape
        (len(labels), nclass).
        """
        Y = np.zeros((len(labels), nclass)).astype('float32')
        for row, score in enumerate(labels):
            low = np.floor(score)
            for col in range(nclass):
                cls = col + 1
                if cls == low + 1:
                    Y[row, col] = score - low
                if cls == low:
                    Y[row, col] = low - score + 1
        return Y
self.sick_data: 166 | logging.info('Computing embedding for {0}'.format(key)) 167 | # Sort to reduce padding 168 | sorted_corpus = sorted(zip(self.sick_data[key]['X_A'], 169 | self.sick_data[key]['X_B'], 170 | self.sick_data[key]['y']), 171 | key=lambda z: (len(z[0]), len(z[1]), z[2])) 172 | 173 | self.sick_data[key]['X_A'] = [x for (x, y, z) in sorted_corpus] 174 | self.sick_data[key]['X_B'] = [y for (x, y, z) in sorted_corpus] 175 | self.sick_data[key]['y'] = [z for (x, y, z) in sorted_corpus] 176 | 177 | for txt_type in ['X_A', 'X_B']: 178 | sick_embed[key][txt_type] = [] 179 | for ii in range(0, len(self.sick_data[key]['y']), bsize): 180 | batch = self.sick_data[key][txt_type][ii:ii + bsize] 181 | embeddings = batcher(params, batch) 182 | sick_embed[key][txt_type].append(embeddings) 183 | sick_embed[key][txt_type] = np.vstack(sick_embed[key][txt_type]) 184 | logging.info('Computed {0} embeddings'.format(key)) 185 | 186 | # Train 187 | trainA = sick_embed['train']['X_A'] 188 | trainB = sick_embed['train']['X_B'] 189 | trainF = np.c_[np.abs(trainA - trainB), trainA * trainB] 190 | trainY = np.array(self.sick_data['train']['y']) 191 | 192 | # Dev 193 | devA = sick_embed['dev']['X_A'] 194 | devB = sick_embed['dev']['X_B'] 195 | devF = np.c_[np.abs(devA - devB), devA * devB] 196 | devY = np.array(self.sick_data['dev']['y']) 197 | 198 | # Test 199 | testA = sick_embed['test']['X_A'] 200 | testB = sick_embed['test']['X_B'] 201 | testF = np.c_[np.abs(testA - testB), testA * testB] 202 | testY = np.array(self.sick_data['test']['y']) 203 | 204 | config = {'nclasses': 3, 'seed': self.seed, 205 | 'usepytorch': params.usepytorch, 206 | 'classifier': params.classifier, 207 | 'nhid': params.nhid} 208 | clf = SplitClassifier(X={'train': trainF, 'valid': devF, 'test': testF}, 209 | y={'train': trainY, 'valid': devY, 'test': testY}, 210 | config=config) 211 | 212 | devacc, testacc = clf.run() 213 | logging.debug('\nDev acc : {0} Test acc : {1} for \ 214 | SICK 
class SNLIEval(object):
    """SNLI entailment task on pre-split sentence and label files."""

    def __init__(self, taskpath, seed=1111):
        """Load ``s1.*``, ``s2.*`` and ``labels.*`` for train, dev and test.

        Every split is sorted by (len(s2), len(s1), label) to reduce
        padding. ``self.data`` maps 'train', 'valid' and 'test' to
        ``(s1, s2, labels)``; ``self.samples`` holds all sentences.
        """
        logging.debug('***** Transfer task : SNLI Entailment*****\n\n')
        self.seed = seed
        self.samples = []
        self.data = {}

        # (file suffix, key used in self.data)
        for suffix, key in (('train', 'train'), ('dev', 'valid'), ('test', 'test')):
            sents1 = self.loadFile(os.path.join(taskpath, 's1.' + suffix))
            sents2 = self.loadFile(os.path.join(taskpath, 's2.' + suffix))
            labels = self._load_labels(os.path.join(taskpath, 'labels.' + suffix))

            # sort data (by s2 first) to reduce padding
            ordered = sorted(zip(sents2, sents1, labels),
                             key=lambda z: (len(z[0]), len(z[1]), z[2]))
            sents2, sents1, labels = map(list, zip(*ordered))

            self.samples += sents1 + sents2
            self.data[key] = (sents1, sents2, labels)

    def do_prepare(self, params, prepare):
        """Call ``prepare`` with every sentence of all splits."""
        return prepare(params, self.samples)

    def loadFile(self, fpath):
        """Return the whitespace-tokenised lines of a latin-1 sentence file."""
        with codecs.open(fpath, 'rb', 'latin-1') as f:
            return [line.split() for line in
                    f.read().splitlines()]

    def _load_labels(self, fpath):
        """Return the lines of a UTF-8 label file, closing it afterwards."""
        # The label files used to be opened inline and never closed.
        with io.open(fpath, encoding='utf-8') as f:
            return f.read().splitlines()

    def run(self, params, batcher):
        """Embed all pairs and train/evaluate a ``SplitClassifier``.

        A pair (u, v) is represented as ``[u, v, u * v, |u - v|]``.
        Results are stored in ``self.X`` / ``self.y`` per split.

        Returns:
            dict with 'devacc', 'acc', 'ndev' and 'ntest'.
        """
        self.X, self.y = {}, {}
        dico_label = {'entailment': 0, 'neutral': 1, 'contradiction': 2}
        for key in self.data:
            input1, input2, mylabels = self.data[key]
            enc_input = []
            n_labels = len(mylabels)
            for ii in range(0, n_labels, params.batch_size):
                batch1 = input1[ii:ii + params.batch_size]
                batch2 = input2[ii:ii + params.batch_size]

                if len(batch1) == len(batch2) and len(batch1) > 0:
                    enc1 = batcher(params, batch1)
                    enc2 = batcher(params, batch2)
                    enc_input.append(np.hstack((enc1, enc2, enc1 * enc2,
                                                np.abs(enc1 - enc2))))
                # Same condition as before, with the common batch_size
                # factor cancelled on both sides of the modulo.
                if ii % 20000 == 0:
                    logging.info("PROGRESS (encoding): %.2f%%",
                                 100.0 * ii / n_labels)
            self.X[key] = np.vstack(enc_input)
            self.y[key] = [dico_label[y] for y in mylabels]

        config = {'nclasses': 3, 'seed': self.seed,
                  'usepytorch': params.usepytorch,
                  'cudaEfficient': True,
                  'nhid': params.nhid, 'noreg': True}

        # Work on a copy so the caller's classifier settings are untouched.
        config_classifier = copy.deepcopy(params.classifier)
        config_classifier['max_epoch'] = 15
        config_classifier['epoch_size'] = 1
        config['classifier'] = config_classifier

        clf = SplitClassifier(self.X, self.y, config)
        devacc, testacc = clf.run()
        logging.debug('Dev acc : {0} Test acc : {1} for SNLI\n'
                      .format(devacc, testacc))
        return {'devacc': devacc, 'acc': testacc,
                'ndev': len(self.data['valid'][0]),
                'ntest': len(self.data['test'][0])}
loadFile(self, fpath): 43 | sst_data = {'X': [], 'y': []} 44 | with io.open(fpath, 'r', encoding='utf-8') as f: 45 | for line in f: 46 | if self.nclasses == 2: 47 | sample = line.strip().split('\t') 48 | sst_data['y'].append(int(sample[1])) 49 | sst_data['X'].append(sample[0].split()) 50 | elif self.nclasses == 5: 51 | sample = line.strip().split(' ', 1) 52 | sst_data['y'].append(int(sample[0])) 53 | sst_data['X'].append(sample[1].split()) 54 | assert max(sst_data['y']) == self.nclasses - 1 55 | return sst_data 56 | 57 | def run(self, params, batcher): 58 | sst_embed = {'train': {}, 'dev': {}, 'test': {}} 59 | bsize = params.batch_size 60 | 61 | for key in self.sst_data: 62 | logging.info('Computing embedding for {0}'.format(key)) 63 | # Sort to reduce padding 64 | sorted_data = sorted(zip(self.sst_data[key]['X'], 65 | self.sst_data[key]['y']), 66 | key=lambda z: (len(z[0]), z[1])) 67 | self.sst_data[key]['X'], self.sst_data[key]['y'] = map(list, zip(*sorted_data)) 68 | 69 | sst_embed[key]['X'] = [] 70 | for ii in range(0, len(self.sst_data[key]['y']), bsize): 71 | batch = self.sst_data[key]['X'][ii:ii + bsize] 72 | embeddings = batcher(params, batch) 73 | sst_embed[key]['X'].append(embeddings) 74 | sst_embed[key]['X'] = np.vstack(sst_embed[key]['X']) 75 | sst_embed[key]['y'] = np.array(self.sst_data[key]['y']) 76 | logging.info('Computed {0} embeddings'.format(key)) 77 | 78 | config_classifier = {'nclasses': self.nclasses, 'seed': self.seed, 79 | 'usepytorch': params.usepytorch, 80 | 'classifier': params.classifier} 81 | 82 | clf = SplitClassifier(X={'train': sst_embed['train']['X'], 83 | 'valid': sst_embed['dev']['X'], 84 | 'test': sst_embed['test']['X']}, 85 | y={'train': sst_embed['train']['y'], 86 | 'valid': sst_embed['dev']['y'], 87 | 'test': sst_embed['test']['y']}, 88 | config=config_classifier) 89 | 90 | devacc, testacc = clf.run() 91 | logging.debug('\nDev acc : {0} Test acc : {1} for \ 92 | SST {2} classification\n'.format(devacc, testacc, 
class STSEval(object):
    """Shared logic of the unsupervised STS tasks.

    Subclasses set ``self.datasets`` and call ``loadFile``. Sentence
    pairs are scored with ``self.similarity`` on their embeddings and
    compared with the gold scores by Pearson and Spearman correlation.
    """

    def loadFile(self, fpath):
        """Load every dataset in ``self.datasets`` from directory ``fpath``.

        Reads ``STS.input.<dataset>.txt`` (two tab-separated sentences
        per line) and ``STS.gs.<dataset>.txt`` (one gold score per line).
        Pairs whose gold score line is empty are dropped. Fills
        ``self.data[dataset] = (sent1, sent2, gs_scores)``, sorted by
        length, and ``self.samples`` with all kept sentences.

        Raises:
            ValueError: if the two files of a dataset differ in length.
        """
        self.data = {}
        self.samples = []

        for dataset in self.datasets:
            # Context managers: the files used to be left open.
            with io.open(fpath + '/STS.input.%s.txt' % dataset,
                         encoding='utf8') as f:
                pairs = [l.split("\t") for l in f.read().splitlines()]
            with io.open(fpath + '/STS.gs.%s.txt' % dataset,
                         encoding='utf8') as f:
                raw_scores = f.read().splitlines()

            sent1, sent2 = zip(*pairs)
            if len(raw_scores) != len(sent1):
                raise ValueError('%s: %d sentence pairs but %d gold scores'
                                 % (dataset, len(sent1), len(raw_scores)))

            # Plain-list filtering: a numpy object array of token lists
            # turns 2-D when all sentences have the same token count.
            kept = [(s1.split(), s2.split(), float(score))
                    for s1, s2, score in zip(sent1, sent2, raw_scores)
                    if score != '']

            # sort data by length to minimize padding in batcher
            sorted_data = sorted(kept,
                                 key=lambda z: (len(z[0]), len(z[1]), z[2]))
            sent1, sent2, gs_scores = map(list, zip(*sorted_data))

            self.data[dataset] = (sent1, sent2, gs_scores)
            self.samples += sent1 + sent2

    def do_prepare(self, params, prepare):
        """Pick the similarity function and call ``prepare`` on all samples."""
        if 'similarity' in params:
            self.similarity = params.similarity
        else:  # Default similarity is cosine
            self.similarity = lambda s1, s2: np.nan_to_num(cosine(np.nan_to_num(s1), np.nan_to_num(s2)))
        return prepare(params, self.samples)

    def run(self, params, batcher):
        """Score every dataset and aggregate the correlations.

        Returns:
            dict mapping each dataset to 'pearson', 'spearman' (the
            scipy results) and 'nsamples', plus an 'all' entry holding,
            for both correlations, the value over the pooled scores
            ('all'), the plain mean and the sample-weighted mean
            ('wmean') over datasets.
        """
        results = {}
        all_sys_scores = []
        all_gs_scores = []
        for dataset in self.datasets:
            sys_scores = []
            input1, input2, gs_scores = self.data[dataset]
            for ii in range(0, len(gs_scores), params.batch_size):
                batch1 = input1[ii:ii + params.batch_size]
                batch2 = input2[ii:ii + params.batch_size]

                # we assume get_batch already throws out the faulty ones
                if len(batch1) == len(batch2) and len(batch1) > 0:
                    enc1 = batcher(params, batch1)
                    enc2 = batcher(params, batch2)

                    for kk in range(enc2.shape[0]):
                        sys_scores.append(self.similarity(enc1[kk], enc2[kk]))
            all_sys_scores.extend(sys_scores)
            all_gs_scores.extend(gs_scores)
            results[dataset] = {'pearson': pearsonr(sys_scores, gs_scores),
                                'spearman': spearmanr(sys_scores, gs_scores),
                                'nsamples': len(sys_scores)}
            logging.debug('%s : pearson = %.4f, spearman = %.4f',
                          dataset, results[dataset]['pearson'][0],
                          results[dataset]['spearman'][0])

        weights = [results[dset]['nsamples'] for dset in results]
        list_prs = np.array([results[dset]['pearson'][0] for dset in results])
        list_spr = np.array([results[dset]['spearman'][0] for dset in results])

        avg_pearson = np.average(list_prs)
        avg_spearman = np.average(list_spr)
        wavg_pearson = np.average(list_prs, weights=weights)
        wavg_spearman = np.average(list_spr, weights=weights)
        all_pearson = pearsonr(all_sys_scores, all_gs_scores)
        all_spearman = spearmanr(all_sys_scores, all_gs_scores)
        results['all'] = {'pearson': {'all': all_pearson[0],
                                      'mean': avg_pearson,
                                      'wmean': wavg_pearson},
                          'spearman': {'all': all_spearman[0],
                                       'mean': avg_spearman,
                                       'wmean': wavg_spearman}}
        # Single-line format strings: a backslash continuation inside
        # the literal used to put the source indentation into the log.
        logging.debug('ALL : Pearson = %.4f, Spearman = %.4f',
                      all_pearson[0], all_spearman[0])
        logging.debug('ALL (weighted average) : Pearson = %.4f, Spearman = %.4f',
                      wavg_pearson, wavg_spearman)
        logging.debug('ALL (average) : Pearson = %.4f, Spearman = %.4f\n',
                      avg_pearson, avg_spearman)

        return results


class STS12Eval(STSEval):
    """STS 2012 datasets."""

    def __init__(self, taskpath, seed=1111):
        logging.debug('***** Transfer task : STS12 *****\n\n')
        self.seed = seed
        self.datasets = ['MSRpar', 'MSRvid', 'SMTeuroparl',
                         'surprise.OnWN', 'surprise.SMTnews']
        self.loadFile(taskpath)


class STS13Eval(STSEval):
    # STS13 here does not contain the "SMT" subtask due to LICENSE issue
    def __init__(self, taskpath, seed=1111):
        logging.debug('***** Transfer task : STS13 (-SMT) *****\n\n')
        self.seed = seed
        self.datasets = ['FNWN', 'headlines', 'OnWN']
        self.loadFile(taskpath)


class STS14Eval(STSEval):
    """STS 2014 datasets."""

    def __init__(self, taskpath, seed=1111):
        logging.debug('***** Transfer task : STS14 *****\n\n')
        self.seed = seed
        self.datasets = ['deft-forum', 'deft-news', 'headlines',
                         'images', 'OnWN', 'tweet-news']
        self.loadFile(taskpath)


class STS15Eval(STSEval):
    """STS 2015 datasets."""

    def __init__(self, taskpath, seed=1111):
        logging.debug('***** Transfer task : STS15 *****\n\n')
        self.seed = seed
        self.datasets = ['answers-forums', 'answers-students',
                         'belief', 'headlines', 'images']
        self.loadFile(taskpath)
['answer-answer', 'headlines', 'plagiarism', 157 | 'postediting', 'question-question'] 158 | self.loadFile(taskpath) 159 | 160 | 161 | class STSBenchmarkEval(STSEval): 162 | def __init__(self, task_path, seed=1111): 163 | logging.debug('\n\n***** Transfer task : STSBenchmark*****\n\n') 164 | self.seed = seed 165 | self.samples = [] 166 | #train = self.loadFile(os.path.join(task_path, 'sts-train.csv')) 167 | #dev = self.loadFile(os.path.join(task_path, 'sts-dev.csv')) 168 | #test = self.loadFile(os.path.join(task_path, 'sts-test.csv')) 169 | #self.datasets = ['train', 'dev', 'test'] 170 | #self.data = {'train': train, 'dev': dev, 'test': test} 171 | test = self.loadFile(os.path.join(task_path, 'sts-test.csv')) 172 | self.datasets = ['test'] 173 | self.data = {'test': test} 174 | 175 | def loadFile(self, fpath): 176 | sick_data = {'X_A': [], 'X_B': [], 'y': []} 177 | with io.open(fpath, 'r', encoding='utf-8') as f: 178 | for line in f: 179 | text = line.strip().split('\t') 180 | sick_data['X_A'].append(text[5].split()) 181 | sick_data['X_B'].append(text[6].split()) 182 | sick_data['y'].append(text[4]) 183 | 184 | sick_data['y'] = [float(s) for s in sick_data['y']] 185 | self.samples += sick_data['X_A'] + sick_data["X_B"] 186 | return (sick_data['X_A'], sick_data["X_B"], sick_data['y']) 187 | 188 | class STSBenchmarkEvalDev(STSEval): 189 | def __init__(self, task_path, seed=1111): 190 | logging.debug('\n\n***** Transfer task : STSBenchmark*****\n\n') 191 | self.seed = seed 192 | self.samples = [] 193 | #train = self.loadFile(os.path.join(task_path, 'sts-train.csv')) 194 | #dev = self.loadFile(os.path.join(task_path, 'sts-dev.csv')) 195 | #test = self.loadFile(os.path.join(task_path, 'sts-test.csv')) 196 | #self.datasets = ['train', 'dev', 'test'] 197 | #self.data = {'train': train, 'dev': dev, 'test': test} 198 | dev = self.loadFile(os.path.join(task_path, 'sts-dev.csv')) 199 | self.datasets = ['dev'] 200 | self.data = {'dev': dev} 201 | 202 | def loadFile(self, 
fpath): 203 | sick_data = {'X_A': [], 'X_B': [], 'y': []} 204 | with io.open(fpath, 'r', encoding='utf-8') as f: 205 | for line in f: 206 | text = line.strip().split('\t') 207 | sick_data['X_A'].append(text[5].split()) 208 | sick_data['X_B'].append(text[6].split()) 209 | sick_data['y'].append(text[4]) 210 | 211 | sick_data['y'] = [float(s) for s in sick_data['y']] 212 | self.samples += sick_data['X_A'] + sick_data["X_B"] 213 | return (sick_data['X_A'], sick_data["X_B"], sick_data['y']) 214 | 215 | class STSBenchmarkFinetune(SICKEval): 216 | def __init__(self, task_path, seed=1111): 217 | logging.debug('\n\n***** Transfer task : STSBenchmark*****\n\n') 218 | self.seed = seed 219 | train = self.loadFile(os.path.join(task_path, 'sts-train.csv')) 220 | dev = self.loadFile(os.path.join(task_path, 'sts-dev.csv')) 221 | test = self.loadFile(os.path.join(task_path, 'sts-test.csv')) 222 | self.sick_data = {'train': train, 'dev': dev, 'test': test} 223 | 224 | def loadFile(self, fpath): 225 | sick_data = {'X_A': [], 'X_B': [], 'y': []} 226 | with io.open(fpath, 'r', encoding='utf-8') as f: 227 | for line in f: 228 | text = line.strip().split('\t') 229 | sick_data['X_A'].append(text[5].split()) 230 | sick_data['X_B'].append(text[6].split()) 231 | sick_data['y'].append(text[4]) 232 | 233 | sick_data['y'] = [float(s) for s in sick_data['y']] 234 | return sick_data 235 | 236 | class SICKRelatednessEval(STSEval): 237 | def __init__(self, task_path, seed=1111): 238 | logging.debug('\n\n***** Transfer task : SICKRelatedness*****\n\n') 239 | self.seed = seed 240 | self.samples = [] 241 | #train = self.loadFile(os.path.join(task_path, 'SICK_train.txt')) 242 | #dev = self.loadFile(os.path.join(task_path, 'SICK_trial.txt')) 243 | #test = self.loadFile(os.path.join(task_path, 'SICK_test_annotated.txt')) 244 | #self.datasets = ['train', 'dev', 'test'] 245 | #self.data = {'train': train, 'dev': dev, 'test': test} 246 | test = self.loadFile(os.path.join(task_path, 
'SICK_test_annotated.txt')) 247 | self.datasets = ['test'] 248 | self.data = {'test': test} 249 | 250 | def loadFile(self, fpath): 251 | skipFirstLine = True 252 | sick_data = {'X_A': [], 'X_B': [], 'y': []} 253 | with io.open(fpath, 'r', encoding='utf-8') as f: 254 | for line in f: 255 | if skipFirstLine: 256 | skipFirstLine = False 257 | else: 258 | text = line.strip().split('\t') 259 | sick_data['X_A'].append(text[1].split()) 260 | sick_data['X_B'].append(text[2].split()) 261 | sick_data['y'].append(text[3]) 262 | 263 | sick_data['y'] = [float(s) for s in sick_data['y']] 264 | self.samples += sick_data['X_A'] + sick_data["X_B"] 265 | return (sick_data['X_A'], sick_data["X_B"], sick_data['y']) 266 | -------------------------------------------------------------------------------- /SentEval/senteval/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kongds/E5-V/856ea816b943df95d77df39a386302b14ea197a2/SentEval/senteval/tools/__init__.py -------------------------------------------------------------------------------- /SentEval/senteval/tools/classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
def _default_device():
    """Return the CUDA device when one is available, otherwise the CPU."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class PyTorchClassifier(object):
    """Scikit-learn style training/inference loop around a torch model.

    This base class does not build a model: a subclass must set
    ``self.model``, ``self.loss_fn``, ``self.optimizer``, ``self.tenacity``,
    ``self.epoch_size`` and ``self.max_epoch`` before ``fit`` is called.
    """

    def __init__(self, inputdim, nclasses, l2reg=0., batch_size=64, seed=1111,
                 cudaEfficient=False):
        # fix seed
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        self.inputdim = inputdim
        self.nclasses = nclasses
        self.l2reg = l2reg
        self.batch_size = batch_size
        self.cudaEfficient = cudaEfficient

    def _model_device(self):
        """Device holding the model parameters; batches are moved there."""
        return next(self.model.parameters()).device

    def prepare_split(self, X, y, validation_data=None, validation_split=None):
        """Split numpy data into train/dev tensors.

        :param X: numpy feature matrix
        :param y: numpy label vector
        :param validation_data: optional ``(devX, devy)`` numpy pair; when
            given, ``X``/``y`` are used entirely for training
        :param validation_split: fraction of ``X`` randomly held out as dev
            set when ``validation_data`` is None
        :return: ``(trainX, trainy, devX, devy)`` as float32/int64 tensors,
            kept on the CPU when ``cudaEfficient`` is set
        """
        # Preparing validation data
        assert validation_split or validation_data
        if validation_data is not None:
            trainX, trainy = X, y
            devX, devy = validation_data
        else:
            permutation = np.random.permutation(len(X))
            ndev = int(validation_split * len(X))
            devidx, trainidx = permutation[:ndev], permutation[ndev:]
            trainX, trainy = X[trainidx], y[trainidx]
            devX, devy = X[devidx], y[devidx]

        device = torch.device('cpu') if self.cudaEfficient else _default_device()

        trainX = torch.from_numpy(trainX).to(device, dtype=torch.float32)
        trainy = torch.from_numpy(trainy).to(device, dtype=torch.int64)
        devX = torch.from_numpy(devX).to(device, dtype=torch.float32)
        devy = torch.from_numpy(devy).to(device, dtype=torch.int64)

        return trainX, trainy, devX, devy

    def fit(self, X, y, validation_data=None, validation_split=None,
            early_stop=True):
        """Train with early stopping on dev accuracy.

        After training ``self.model`` is the copy with the best dev accuracy.

        :return: the best dev accuracy observed
        """
        self.nepoch = 0
        bestaccuracy = -1
        stop_train = False
        early_stop_count = 0
        # Defined up front so the final assignment is valid even when the
        # training loop does not run at all (negative ``max_epoch``).
        bestmodel = self.model

        # Preparing validation data
        trainX, trainy, devX, devy = self.prepare_split(X, y, validation_data,
                                                        validation_split)

        # Training
        while not stop_train and self.nepoch <= self.max_epoch:
            self.trainepoch(trainX, trainy, epoch_size=self.epoch_size)
            accuracy = self.score(devX, devy)
            if accuracy > bestaccuracy:
                bestaccuracy = accuracy
                bestmodel = copy.deepcopy(self.model)
            elif early_stop:
                if early_stop_count >= self.tenacity:
                    stop_train = True
                early_stop_count += 1
        self.model = bestmodel
        return bestaccuracy

    def trainepoch(self, X, y, epoch_size=1):
        """Run ``epoch_size`` shuffled passes over ``(X, y)``."""
        self.model.train()
        device = self._model_device()
        for _ in range(self.nepoch, self.nepoch + epoch_size):
            permutation = np.random.permutation(len(X))
            for i in range(0, len(X), self.batch_size):
                # forward
                idx = torch.from_numpy(
                    permutation[i:i + self.batch_size]).long().to(X.device)
                # A no-op when the data already lives on the model device;
                # with cudaEfficient only the current batch is transferred.
                Xbatch = X[idx].to(device)
                ybatch = y[idx].to(device)
                output = self.model(Xbatch)
                # loss
                loss = self.loss_fn(output, ybatch)
                # backward
                self.optimizer.zero_grad()
                loss.backward()
                # Update parameters
                self.optimizer.step()
        self.nepoch += epoch_size

    def score(self, devX, devy):
        """Return the accuracy of the model on ``(devX, devy)``.

        Inputs may be tensors on any device or array-likes; batches are moved
        to the model device one at a time.
        """
        self.model.eval()
        device = self._model_device()
        devX = torch.as_tensor(devX, dtype=torch.float32)
        devy = torch.as_tensor(devy, dtype=torch.int64)
        correct = 0
        with torch.no_grad():
            for i in range(0, len(devX), self.batch_size):
                Xbatch = devX[i:i + self.batch_size].to(device)
                ybatch = devy[i:i + self.batch_size].to(device)
                pred = self.model(Xbatch).max(1)[1]
                correct += pred.long().eq(ybatch.long()).sum().item()
        accuracy = 1.0 * correct / len(devX)
        return accuracy

    def predict(self, devX):
        """Return predicted class indices as a float column of shape (n, 1)."""
        self.model.eval()
        device = self._model_device()
        devX = torch.as_tensor(devX, dtype=torch.float32)
        preds = []
        with torch.no_grad():
            for i in range(0, len(devX), self.batch_size):
                Xbatch = devX[i:i + self.batch_size].to(device)
                preds.append(self.model(Xbatch).max(1)[1].cpu().numpy())
        if not preds:
            return np.empty((0, 1))
        return np.concatenate(preds).astype(np.float64).reshape(-1, 1)

    def predict_proba(self, devX):
        """Return class probabilities as an array of shape (n, nclasses).

        The softmax is taken over the last (class) dimension of the model
        output, on the tensor, before converting to numpy.
        """
        self.model.eval()
        device = self._model_device()
        devX = torch.as_tensor(devX, dtype=torch.float32)
        probas = []
        with torch.no_grad():
            for i in range(0, len(devX), self.batch_size):
                Xbatch = devX[i:i + self.batch_size].to(device)
                vals = F.softmax(self.model(Xbatch), dim=-1)
                probas.append(vals.cpu().numpy())
        if not probas:
            return np.empty((0, self.nclasses))
        return np.concatenate(probas, axis=0)


"""
MLP with Pytorch (nhid=0 --> Logistic Regression)
"""


class MLP(PyTorchClassifier):
    """Logistic regression (``nhid == 0``) or one-hidden-layer MLP."""

    def __init__(self, params, inputdim, nclasses, l2reg=0., batch_size=64,
                 seed=1111, cudaEfficient=False):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as MLP is subclassed.
        super(MLP, self).__init__(inputdim, nclasses, l2reg,
                                  batch_size, seed, cudaEfficient)
        """
        PARAMETERS:
        -nhid:       number of hidden units (0: Logistic Regression)
        -optim:      optimizer ("sgd,lr=0.1", "adam", "rmsprop" ..)
        -tenacity:   how many times dev acc does not increase before stopping
        -epoch_size: each epoch corresponds to epoch_size pass on the train set
        -max_epoch:  max number of epoches
        -dropout:    dropout for MLP
        -batch_size: mini-batch size (defaults to the constructor argument)
        """

        def option(key, default):
            return params[key] if key in params else default

        self.nhid = option("nhid", 0)
        self.optim = option("optim", "adam")
        self.tenacity = option("tenacity", 5)
        self.epoch_size = option("epoch_size", 4)
        self.max_epoch = option("max_epoch", 200)
        self.dropout = option("dropout", 0.)
        # The constructor argument is the fallback (its default is 64).
        self.batch_size = option("batch_size", batch_size)

        device = _default_device()
        # Use the resolved ``self.nhid`` so a missing "nhid" key falls back
        # to logistic regression instead of raising KeyError.
        if self.nhid == 0:
            self.model = nn.Sequential(
                nn.Linear(self.inputdim, self.nclasses),
            ).to(device)
        else:
            self.model = nn.Sequential(
                nn.Linear(self.inputdim, self.nhid),
                nn.Dropout(p=self.dropout),
                nn.Sigmoid(),
                nn.Linear(self.nhid, self.nclasses),
            ).to(device)

        self.loss_fn = nn.CrossEntropyLoss().to(device)
        # NOTE(review): on recent PyTorch the reduction mode is fixed by the
        # constructor's ``reduction`` argument, so this assignment presumably
        # has no effect; kept to leave the training numerics untouched.
        self.loss_fn.size_average = False

        optim_fn, optim_params = utils.get_optimizer(self.optim)
        self.optimizer = optim_fn(self.model.parameters(), **optim_params)
        self.optimizer.param_groups[0]['weight_decay'] = self.l2reg
def _default_device():
    """Return the CUDA device when one is available, otherwise the CPU."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _l2_normalize(x):
    """Divide every row of the 2-D tensor ``x`` by its Euclidean norm."""
    return x / torch.sqrt(torch.pow(x, 2).sum(1, keepdim=True)).expand_as(x)


class COCOProjNet(nn.Module):
    """Linear projections of image and sentence features to a joint space.

    ``config`` must provide ``imgdim``, ``sentdim`` and ``projdim``.
    """

    def __init__(self, config):
        super(COCOProjNet, self).__init__()
        self.imgdim = config['imgdim']
        self.sentdim = config['sentdim']
        self.projdim = config['projdim']
        self.imgproj = nn.Sequential(
            nn.Linear(self.imgdim, self.projdim),
        )
        self.sentproj = nn.Sequential(
            nn.Linear(self.sentdim, self.projdim),
        )

    def forward(self, img, sent, imgc, sentc):
        """Score matching pairs and contrastive pairs.

        :return: ``(anchor1, anchor2, img_sentc, sent_imgc)``, each of size
            ``bsize * ncontrast``: the two (identical) image/sentence anchor
            scores, image vs. contrastive sentences, and sentence vs.
            contrastive images
        """
        # imgc : (bsize, ncontrast, imgdim)
        # sentc : (bsize, ncontrast, sentdim)
        # img : (bsize, imgdim)
        # sent : (bsize, sentdim)
        img = img.unsqueeze(1).expand_as(imgc).contiguous()
        img = img.view(-1, self.imgdim)
        imgc = imgc.view(-1, self.imgdim)
        sent = sent.unsqueeze(1).expand_as(sentc).contiguous()
        sent = sent.view(-1, self.sentdim)
        sentc = sentc.view(-1, self.sentdim)

        imgproj = _l2_normalize(self.imgproj(img))
        imgcproj = _l2_normalize(self.imgproj(imgc))
        sentproj = _l2_normalize(self.sentproj(sent))
        sentcproj = _l2_normalize(self.sentproj(sentc))
        # (bsize*ncontrast, projdim)

        anchor1 = torch.sum((imgproj*sentproj), 1)
        anchor2 = torch.sum((sentproj*imgproj), 1)
        img_sentc = torch.sum((imgproj*sentcproj), 1)
        sent_imgc = torch.sum((sentproj*imgcproj), 1)

        # (bsize*ncontrast)
        return anchor1, anchor2, img_sentc, sent_imgc

    def proj_sentence(self, sent):
        """Project and L2-normalise sentence features."""
        return _l2_normalize(self.sentproj(sent))  # (bsize, projdim)

    def proj_image(self, img):
        """Project and L2-normalise image features."""
        return _l2_normalize(self.imgproj(img))  # (bsize, projdim)


class PairwiseRankingLoss(nn.Module):
    """
    Pairwise ranking loss
    """
    def __init__(self, margin):
        super(PairwiseRankingLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor1, anchor2, img_sentc, sent_imgc):
        """Sum of the two hinge costs ``max(0, margin - anchor + contrast)``."""
        cost_sent = torch.clamp(self.margin - anchor1 + img_sentc,
                                min=0.0).sum()
        cost_img = torch.clamp(self.margin - anchor2 + sent_imgc,
                               min=0.0).sum()
        loss = cost_sent + cost_img
        return loss


class ImageSentenceRankingPytorch(object):
    """Image/sentence ranking with a learned joint projection.

    ``train``, ``valid`` and ``test`` are mappings with ``'imgfeat'`` and
    ``'sentfeat'`` entries; ``config`` provides ``seed``, ``projdim`` and
    ``margin``.  The model runs on CUDA when available, otherwise on the CPU.
    """
    # Image Sentence Ranking on COCO with Pytorch
    def __init__(self, train, valid, test, config):
        # fix seed
        self.seed = config['seed']
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)

        self.train = train
        self.valid = valid
        self.test = test

        self.imgdim = len(train['imgfeat'][0])
        self.sentdim = len(train['sentfeat'][0])
        self.projdim = config['projdim']
        self.margin = config['margin']

        self.batch_size = 128
        self.ncontrast = 30
        self.maxepoch = 20
        self.early_stop = True

        device = _default_device()
        config_model = {'imgdim': self.imgdim, 'sentdim': self.sentdim,
                        'projdim': self.projdim}
        self.model = COCOProjNet(config_model).to(device)

        self.loss_fn = PairwiseRankingLoss(margin=self.margin).to(device)

        self.optimizer = optim.Adam(self.model.parameters())

    def _model_device(self):
        """Device holding the model parameters; batches are moved there."""
        return next(self.model.parameters()).device

    def prepare_data(self, trainTxt, trainImg, devTxt, devImg,
                     testTxt, testImg):
        """Convert features to float tensors.

        Training features stay on the CPU (batches are moved on demand);
        dev and test features are placed on the model device.
        """
        device = self._model_device()
        trainTxt = torch.FloatTensor(trainTxt)
        trainImg = torch.FloatTensor(trainImg)
        devTxt = torch.FloatTensor(devTxt).to(device)
        devImg = torch.FloatTensor(devImg).to(device)
        testTxt = torch.FloatTensor(testTxt).to(device)
        testImg = torch.FloatTensor(testImg).to(device)

        return trainTxt, trainImg, devTxt, devImg, testTxt, testImg

    def _evaluate_folds(self, imgs, txts, verbose=False):
        """Average i2t/t2i metrics over 5 consecutive slices of 5000 rows.

        :return: ``(results, score)`` where ``results`` holds the averaged
            ``r1``/``r5``/``r10``/``medr`` per direction and ``score`` is the
            averaged sum of the six recall values
        """
        results = {'i2t': {'r1': 0, 'r5': 0, 'r10': 0, 'medr': 0},
                   't2i': {'r1': 0, 'r5': 0, 'r10': 0, 'medr': 0}}
        score = 0
        for i in range(5):
            txt_i = txts[i*5000:(i+1)*5000]
            img_i = imgs[i*5000:(i+1)*5000]
            # Compute ranks img2txt
            r1_i2t, r5_i2t, r10_i2t, medr_i2t = self.i2t(img_i, txt_i)
            results['i2t']['r1'] += r1_i2t / 5
            results['i2t']['r5'] += r5_i2t / 5
            results['i2t']['r10'] += r10_i2t / 5
            results['i2t']['medr'] += medr_i2t / 5
            if verbose:
                logging.info("Image to text: {0}, {1}, {2}, {3}"
                             .format(r1_i2t, r5_i2t, r10_i2t, medr_i2t))
            # Compute ranks txt2img
            r1_t2i, r5_t2i, r10_t2i, medr_t2i = self.t2i(img_i, txt_i)
            results['t2i']['r1'] += r1_t2i / 5
            results['t2i']['r5'] += r5_t2i / 5
            results['t2i']['r10'] += r10_t2i / 5
            results['t2i']['medr'] += medr_t2i / 5
            if verbose:
                logging.info("Text to Image: {0}, {1}, {2}, {3}"
                             .format(r1_t2i, r5_t2i, r10_t2i, medr_t2i))
            score += (r1_i2t + r5_i2t + r10_i2t +
                      r1_t2i + r5_t2i + r10_t2i) / 5
        return results, score

    def run(self):
        """Train with early stopping, then evaluate on the test folds.

        :return: ``(bestdevscore, i2t r1, r5, r10, medr, t2i r1, r5, r10,
            medr)`` where the eight metrics are test-set averages
        """
        self.nepoch = 0
        bestdevscore = -1
        early_stop_count = 0
        stop_train = False
        # Defined up front so the assignment after the loop is always valid.
        bestmodel = self.model

        # Preparing data
        logging.info('prepare data')
        trainTxt, trainImg, devTxt, devImg, testTxt, testImg = \
            self.prepare_data(self.train['sentfeat'], self.train['imgfeat'],
                              self.valid['sentfeat'], self.valid['imgfeat'],
                              self.test['sentfeat'], self.test['imgfeat'])

        # Training
        while not stop_train and self.nepoch <= self.maxepoch:
            logging.info('start epoch')
            self.trainepoch(trainTxt, trainImg, devTxt, devImg, nepoches=1)
            logging.info('Epoch {0} finished'.format(self.nepoch))

            results, score = self._evaluate_folds(devImg, devTxt,
                                                  verbose=True)

            logging.info("Dev mean Text to Image: {0}, {1}, {2}, {3}".format(
                         results['t2i']['r1'], results['t2i']['r5'],
                         results['t2i']['r10'], results['t2i']['medr']))
            logging.info("Dev mean Image to text: {0}, {1}, {2}, {3}".format(
                         results['i2t']['r1'], results['i2t']['r5'],
                         results['i2t']['r10'], results['i2t']['medr']))

            # early stop on the summed dev recalls (R@1/5/10, both directions)
            if score > bestdevscore:
                bestdevscore = score
                bestmodel = copy.deepcopy(self.model)
            elif self.early_stop:
                if early_stop_count >= 3:
                    stop_train = True
                early_stop_count += 1
        self.model = bestmodel

        # Compute test for the 5 splits
        results, _ = self._evaluate_folds(testImg, testTxt)

        return bestdevscore, results['i2t']['r1'], results['i2t']['r5'], \
            results['i2t']['r10'], results['i2t']['medr'], \
            results['t2i']['r1'], results['t2i']['r5'], \
            results['t2i']['r10'], results['t2i']['medr']

    def trainepoch(self, trainTxt, trainImg, devTxt, devImg, nepoches=1):
        """Run ``nepoches`` shuffled passes with random contrastive samples."""
        self.model.train()
        device = self._model_device()
        for _ in range(self.nepoch, self.nepoch + nepoches):
            permutation = list(np.random.permutation(len(trainTxt)))
            for i in range(0, len(trainTxt), self.batch_size):
                # forward
                if i % (self.batch_size*500) == 0 and i > 0:
                    logging.info('samples : {0}'.format(i))
                    r1_i2t, r5_i2t, r10_i2t, medr_i2t = self.i2t(devImg,
                                                                 devTxt)
                    logging.info("Image to text: {0}, {1}, {2}, {3}".format(
                        r1_i2t, r5_i2t, r10_i2t, medr_i2t))
                    # Compute dev ranks txt2img
                    r1_t2i, r5_t2i, r10_t2i, medr_t2i = self.t2i(devImg,
                                                                 devTxt)
                    logging.info("Text to Image: {0}, {1}, {2}, {3}".format(
                        r1_t2i, r5_t2i, r10_t2i, medr_t2i))
                idx = torch.LongTensor(permutation[i:i + self.batch_size])
                imgbatch = trainImg.index_select(0, idx).to(device)
                sentbatch = trainTxt.index_select(0, idx).to(device)

                # Get indexes for contrastive images and sentences: drawn
                # from every training row outside the current batch.
                others = permutation[:i] + permutation[i + self.batch_size:]
                idximgc = torch.LongTensor(
                    np.random.choice(others, self.ncontrast*idx.size(0)))
                idxsentc = torch.LongTensor(
                    np.random.choice(others, self.ncontrast*idx.size(0)))
                imgcbatch = trainImg.index_select(0, idximgc).view(
                    -1, self.ncontrast, self.imgdim).to(device)
                sentcbatch = trainTxt.index_select(0, idxsentc).view(
                    -1, self.ncontrast, self.sentdim).to(device)

                anchor1, anchor2, img_sentc, sent_imgc = self.model(
                    imgbatch, sentbatch, imgcbatch, sentcbatch)
                # loss
                loss = self.loss_fn(anchor1, anchor2, img_sentc, sent_imgc)
                # backward
                self.optimizer.zero_grad()
                loss.backward()
                # Update parameters
                self.optimizer.step()
        self.nepoch += nepoches

    def _embed(self, images, captions):
        """Project images and captions batch by batch on the model device."""
        device = self._model_device()
        img_embed, sent_embed = [], []
        for i in range(0, len(images), self.batch_size):
            img_embed.append(self.model.proj_image(
                images[i:i + self.batch_size].to(device)))
            sent_embed.append(self.model.proj_sentence(
                captions[i:i + self.batch_size].to(device)))
        return torch.cat(img_embed, 0), torch.cat(sent_embed, 0)

    def t2i(self, images, captions):
        """
        Images: (5N, imgdim) matrix of images
        Captions: (5N, sentdim) matrix of captions

        Returns (r1, r5, r10, medr) for caption -> image retrieval; every
        fifth image row is used as the candidate for its group of 5 captions.
        """
        with torch.no_grad():
            # Project images and captions
            img_embed, sent_embed = self._embed(images, captions)

            npts = int(img_embed.size(0) / 5)
            # Index tensor created on the embeddings' own device so the
            # method also works without CUDA.
            idxs = torch.arange(0, len(img_embed), 5,
                                device=img_embed.device)
            ims = img_embed.index_select(0, idxs)

            ranks = np.zeros(5 * npts)
            for index in range(npts):

                # Get query captions
                queries = sent_embed[5*index: 5*index + 5]

                # Compute scores
                scores = torch.mm(queries, ims.transpose(0, 1)).cpu().numpy()
                for i in range(len(scores)):
                    order = np.argsort(scores[i])[::-1]
                    ranks[5 * index + i] = np.where(order == index)[0][0]

            # Compute metrics
            r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
            r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
            r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
            medr = np.floor(np.median(ranks)) + 1
            return (r1, r5, r10, medr)

    def i2t(self, images, captions):
        """
        Images: (5N, imgdim) matrix of images
        Captions: (5N, sentdim) matrix of captions

        Returns (r1, r5, r10, medr) for image -> caption retrieval; the rank
        of an image is the best rank among its 5 captions.
        """
        with torch.no_grad():
            # Project images and captions
            img_embed, sent_embed = self._embed(images, captions)

            npts = int(img_embed.size(0) / 5)

            ranks = np.zeros(npts)
            for index in range(npts):

                # Get query image
                query_img = img_embed[5 * index]

                # Compute scores
                scores = torch.mm(query_img.view(1, -1),
                                  sent_embed.transpose(0, 1)).view(-1)
                scores = scores.cpu().numpy()
                inds = np.argsort(scores)[::-1]

                # Score: best position of any of the 5 matching captions
                ranks[index] = min(np.where(inds == i)[0][0]
                                   for i in range(5*index, 5*index + 5))

            # Compute metrics
            r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
            r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
            r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
            medr = np.floor(np.median(ranks)) + 1
            return (r1, r5, r10, medr)
class RelatednessPytorch(object):
    """Softmax regression over relatedness classes, trained with MSE.

    ``train``, ``valid`` and ``test`` are mappings with ``'X'`` (features)
    and ``'y'`` (target class distributions); ``devscores`` are the gold dev
    scores used for early stopping; ``config`` provides ``seed`` and
    ``nclasses``.
    """
    # Can be used for SICK-Relatedness, and STS14
    def __init__(self, train, valid, test, devscores, config):
        # fix seed
        np.random.seed(config['seed'])
        torch.manual_seed(config['seed'])
        assert torch.cuda.is_available(), 'torch.cuda required for Relatedness'
        torch.cuda.manual_seed(config['seed'])

        self.train = train
        self.valid = valid
        self.test = test
        self.devscores = devscores

        self.inputdim = train['X'].shape[1]
        self.nclasses = config['nclasses']
        self.seed = config['seed']
        self.l2reg = 0.
        self.batch_size = 64
        self.maxepoch = 1000
        self.early_stop = True

        self.model = nn.Sequential(
            nn.Linear(self.inputdim, self.nclasses),
            nn.Softmax(dim=-1),
        )
        self.loss_fn = nn.MSELoss()

        if torch.cuda.is_available():
            self.model = self.model.cuda()
            self.loss_fn = self.loss_fn.cuda()

        # NOTE(review): on recent PyTorch the reduction mode is fixed by the
        # constructor's ``reduction`` argument, so this assignment presumably
        # has no effect; kept to leave the training numerics untouched.
        self.loss_fn.size_average = False
        self.optimizer = optim.Adam(self.model.parameters(),
                                    weight_decay=self.l2reg)

    def prepare_data(self, trainX, trainy, devX, devy, testX, testy):
        """Convert the six numpy arrays to float tensors on the model device.

        :return: ``(trainX, trainy, devX, devy, testX, testy)`` as tensors
        """
        device = next(self.model.parameters()).device
        # All six arrays are converted, including ``testy``.
        trainX, trainy, devX, devy, testX, testy = (
            torch.from_numpy(arr).float().to(device)
            for arr in (trainX, trainy, devX, devy, testX, testy))

        return trainX, trainy, devX, devy, testX, testy

    def run(self):
        """Train with early stopping and predict test scores.

        The predicted score is the expectation of the class distribution
        over the values 1..5.

        :return: ``(bestpr, yhat)`` -- best dev Spearman correlation and the
            predicted test scores
        """
        self.nepoch = 0
        bestpr = -1
        early_stop_count = 0
        r = np.arange(1, 6)
        stop_train = False
        # Defined up front so the assignment after the loop is always valid.
        bestmodel = self.model

        # Preparing data
        trainX, trainy, devX, devy, testX, testy = self.prepare_data(
            self.train['X'], self.train['y'],
            self.valid['X'], self.valid['y'],
            self.test['X'], self.test['y'])

        # Training
        while not stop_train and self.nepoch <= self.maxepoch:
            self.trainepoch(trainX, trainy, nepoches=50)
            yhat = np.dot(self.predict_proba(devX), r)
            pr = spearmanr(yhat, self.devscores)[0]
            pr = 0 if pr != pr else pr  # if NaN bc std=0
            # early stop on the dev Spearman correlation
            if pr > bestpr:
                bestpr = pr
                bestmodel = copy.deepcopy(self.model)
            elif self.early_stop:
                if early_stop_count >= 3:
                    stop_train = True
                early_stop_count += 1
        self.model = bestmodel

        yhat = np.dot(self.predict_proba(testX), r)

        return bestpr, yhat

    def trainepoch(self, X, y, nepoches=1):
        """Run ``nepoches`` shuffled passes over ``(X, y)``."""
        self.model.train()
        for _ in range(self.nepoch, self.nepoch + nepoches):
            permutation = np.random.permutation(len(X))
            for i in range(0, len(X), self.batch_size):
                # forward
                idx = torch.from_numpy(
                    permutation[i:i + self.batch_size]).long().to(X.device)
                Xbatch = X[idx]
                ybatch = y[idx]
                output = self.model(Xbatch)
                # loss
                loss = self.loss_fn(output, ybatch)
                # backward
                self.optimizer.zero_grad()
                loss.backward()
                # Update parameters
                self.optimizer.step()
        self.nepoch += nepoches

    def predict_proba(self, devX):
        """Return the model outputs for ``devX`` as one numpy array,
        computed batch by batch (an empty array for empty input)."""
        self.model.eval()
        chunks = []
        with torch.no_grad():
            for i in range(0, len(devX), self.batch_size):
                Xbatch = devX[i:i + self.batch_size]
                chunks.append(self.model(Xbatch).cpu().numpy())
        if not chunks:
            return np.array([])
        return np.concatenate(chunks, axis=0)
-------------------------------------------------------------------------------- /SentEval/senteval/tools/validation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | """ 9 | Validation and classification 10 | (train) : inner-kfold classifier 11 | (train, test) : kfold classifier 12 | (train, dev, test) : split classifier 13 | 14 | """ 15 | from __future__ import absolute_import, division, unicode_literals 16 | 17 | import logging 18 | import numpy as np 19 | from senteval.tools.classifier import MLP 20 | 21 | import sklearn 22 | assert(sklearn.__version__ >= "0.18.0"), \ 23 | "need to update sklearn to version >= 0.18.0" 24 | from sklearn.linear_model import LogisticRegression 25 | from sklearn.model_selection import StratifiedKFold 26 | 27 | 28 | def get_classif_name(classifier_config, usepytorch): 29 | if not usepytorch: 30 | modelname = 'sklearn-LogReg' 31 | else: 32 | nhid = classifier_config['nhid'] 33 | optim = 'adam' if 'optim' not in classifier_config else classifier_config['optim'] 34 | bs = 64 if 'batch_size' not in classifier_config else classifier_config['batch_size'] 35 | modelname = 'pytorch-MLP-nhid%s-%s-bs%s' % (nhid, optim, bs) 36 | return modelname 37 | 38 | # Pytorch version 39 | class InnerKFoldClassifier(object): 40 | """ 41 | (train) split classifier : InnerKfold. 
42 | """ 43 | def __init__(self, X, y, config): 44 | self.X = X 45 | self.y = y 46 | self.featdim = X.shape[1] 47 | self.nclasses = config['nclasses'] 48 | self.seed = config['seed'] 49 | self.devresults = [] 50 | self.testresults = [] 51 | self.usepytorch = config['usepytorch'] 52 | self.classifier_config = config['classifier'] 53 | self.modelname = get_classif_name(self.classifier_config, self.usepytorch) 54 | 55 | self.k = 5 if 'kfold' not in config else config['kfold'] 56 | 57 | def run(self): 58 | logging.info('Training {0} with (inner) {1}-fold cross-validation' 59 | .format(self.modelname, self.k)) 60 | 61 | regs = [10**t for t in range(-5, -1)] if self.usepytorch else \ 62 | [2**t for t in range(-2, 4, 1)] 63 | skf = StratifiedKFold(n_splits=self.k, shuffle=True, random_state=1111) 64 | innerskf = StratifiedKFold(n_splits=self.k, shuffle=True, 65 | random_state=1111) 66 | count = 0 67 | for train_idx, test_idx in skf.split(self.X, self.y): 68 | count += 1 69 | X_train, X_test = self.X[train_idx], self.X[test_idx] 70 | y_train, y_test = self.y[train_idx], self.y[test_idx] 71 | scores = [] 72 | for reg in regs: 73 | regscores = [] 74 | for inner_train_idx, inner_test_idx in innerskf.split(X_train, y_train): 75 | X_in_train, X_in_test = X_train[inner_train_idx], X_train[inner_test_idx] 76 | y_in_train, y_in_test = y_train[inner_train_idx], y_train[inner_test_idx] 77 | if self.usepytorch: 78 | clf = MLP(self.classifier_config, inputdim=self.featdim, 79 | nclasses=self.nclasses, l2reg=reg, 80 | seed=self.seed) 81 | clf.fit(X_in_train, y_in_train, 82 | validation_data=(X_in_test, y_in_test)) 83 | else: 84 | clf = LogisticRegression(C=reg, random_state=self.seed) 85 | clf.fit(X_in_train, y_in_train) 86 | regscores.append(clf.score(X_in_test, y_in_test)) 87 | scores.append(round(100*np.mean(regscores), 2)) 88 | optreg = regs[np.argmax(scores)] 89 | logging.info('Best param found at split {0}: l2reg = {1} \ 90 | with score {2}'.format(count, optreg, np.max(scores))) 
91 | self.devresults.append(np.max(scores)) 92 | 93 | if self.usepytorch: 94 | clf = MLP(self.classifier_config, inputdim=self.featdim, 95 | nclasses=self.nclasses, l2reg=optreg, 96 | seed=self.seed) 97 | 98 | clf.fit(X_train, y_train, validation_split=0.05) 99 | else: 100 | clf = LogisticRegression(C=optreg, random_state=self.seed) 101 | clf.fit(X_train, y_train) 102 | 103 | self.testresults.append(round(100*clf.score(X_test, y_test), 2)) 104 | 105 | devaccuracy = round(np.mean(self.devresults), 2) 106 | testaccuracy = round(np.mean(self.testresults), 2) 107 | return devaccuracy, testaccuracy 108 | 109 | 110 | class KFoldClassifier(object): 111 | """ 112 | (train, test) split classifier : cross-validation on train. 113 | """ 114 | def __init__(self, train, test, config): 115 | self.train = train 116 | self.test = test 117 | self.featdim = self.train['X'].shape[1] 118 | self.nclasses = config['nclasses'] 119 | self.seed = config['seed'] 120 | self.usepytorch = config['usepytorch'] 121 | self.classifier_config = config['classifier'] 122 | self.modelname = get_classif_name(self.classifier_config, self.usepytorch) 123 | 124 | self.k = 5 if 'kfold' not in config else config['kfold'] 125 | 126 | def run(self): 127 | # cross-validation 128 | logging.info('Training {0} with {1}-fold cross-validation' 129 | .format(self.modelname, self.k)) 130 | regs = [10**t for t in range(-5, -1)] if self.usepytorch else \ 131 | [2**t for t in range(-1, 6, 1)] 132 | skf = StratifiedKFold(n_splits=self.k, shuffle=True, 133 | random_state=self.seed) 134 | scores = [] 135 | 136 | for reg in regs: 137 | scanscores = [] 138 | for train_idx, test_idx in skf.split(self.train['X'], 139 | self.train['y']): 140 | # Split data 141 | X_train, y_train = self.train['X'][train_idx], self.train['y'][train_idx] 142 | 143 | X_test, y_test = self.train['X'][test_idx], self.train['y'][test_idx] 144 | 145 | # Train classifier 146 | if self.usepytorch: 147 | clf = MLP(self.classifier_config, 
inputdim=self.featdim, 148 | nclasses=self.nclasses, l2reg=reg, 149 | seed=self.seed) 150 | clf.fit(X_train, y_train, validation_data=(X_test, y_test)) 151 | else: 152 | clf = LogisticRegression(C=reg, random_state=self.seed) 153 | clf.fit(X_train, y_train) 154 | score = clf.score(X_test, y_test) 155 | scanscores.append(score) 156 | # Append mean score 157 | scores.append(round(100*np.mean(scanscores), 2)) 158 | 159 | # evaluation 160 | logging.info([('reg:' + str(regs[idx]), scores[idx]) 161 | for idx in range(len(scores))]) 162 | optreg = regs[np.argmax(scores)] 163 | devaccuracy = np.max(scores) 164 | logging.info('Cross-validation : best param found is reg = {0} \ 165 | with score {1}'.format(optreg, devaccuracy)) 166 | 167 | logging.info('Evaluating...') 168 | if self.usepytorch: 169 | clf = MLP(self.classifier_config, inputdim=self.featdim, 170 | nclasses=self.nclasses, l2reg=optreg, 171 | seed=self.seed) 172 | clf.fit(self.train['X'], self.train['y'], validation_split=0.05) 173 | else: 174 | clf = LogisticRegression(C=optreg, random_state=self.seed) 175 | clf.fit(self.train['X'], self.train['y']) 176 | yhat = clf.predict(self.test['X']) 177 | 178 | testaccuracy = clf.score(self.test['X'], self.test['y']) 179 | testaccuracy = round(100*testaccuracy, 2) 180 | 181 | return devaccuracy, testaccuracy, yhat 182 | 183 | 184 | class SplitClassifier(object): 185 | """ 186 | (train, valid, test) split classifier. 
187 | """ 188 | def __init__(self, X, y, config): 189 | self.X = X 190 | self.y = y 191 | self.nclasses = config['nclasses'] 192 | self.featdim = self.X['train'].shape[1] 193 | self.seed = config['seed'] 194 | self.usepytorch = config['usepytorch'] 195 | self.classifier_config = config['classifier'] 196 | self.cudaEfficient = False if 'cudaEfficient' not in config else \ 197 | config['cudaEfficient'] 198 | self.modelname = get_classif_name(self.classifier_config, self.usepytorch) 199 | self.noreg = False if 'noreg' not in config else config['noreg'] 200 | self.config = config 201 | 202 | def run(self): 203 | logging.info('Training {0} with standard validation..' 204 | .format(self.modelname)) 205 | regs = [10**t for t in range(-5, -1)] if self.usepytorch else \ 206 | [2**t for t in range(-2, 4, 1)] 207 | if self.noreg: 208 | regs = [1e-9 if self.usepytorch else 1e9] 209 | scores = [] 210 | for reg in regs: 211 | if self.usepytorch: 212 | clf = MLP(self.classifier_config, inputdim=self.featdim, 213 | nclasses=self.nclasses, l2reg=reg, 214 | seed=self.seed, cudaEfficient=self.cudaEfficient) 215 | 216 | # TODO: Find a hack for reducing nb epoches in SNLI 217 | clf.fit(self.X['train'], self.y['train'], 218 | validation_data=(self.X['valid'], self.y['valid'])) 219 | else: 220 | clf = LogisticRegression(C=reg, random_state=self.seed) 221 | clf.fit(self.X['train'], self.y['train']) 222 | scores.append(round(100*clf.score(self.X['valid'], 223 | self.y['valid']), 2)) 224 | logging.info([('reg:'+str(regs[idx]), scores[idx]) 225 | for idx in range(len(scores))]) 226 | optreg = regs[np.argmax(scores)] 227 | devaccuracy = np.max(scores) 228 | logging.info('Validation : best param found is reg = {0} with score \ 229 | {1}'.format(optreg, devaccuracy)) 230 | clf = LogisticRegression(C=optreg, random_state=self.seed) 231 | logging.info('Evaluating...') 232 | if self.usepytorch: 233 | clf = MLP(self.classifier_config, inputdim=self.featdim, 234 | nclasses=self.nclasses, 
l2reg=optreg, 235 | seed=self.seed, cudaEfficient=self.cudaEfficient) 236 | 237 | # TODO: Find a hack for reducing nb epoches in SNLI 238 | clf.fit(self.X['train'], self.y['train'], 239 | validation_data=(self.X['valid'], self.y['valid'])) 240 | else: 241 | clf = LogisticRegression(C=optreg, random_state=self.seed) 242 | clf.fit(self.X['train'], self.y['train']) 243 | 244 | testaccuracy = clf.score(self.X['test'], self.y['test']) 245 | testaccuracy = round(100*testaccuracy, 2) 246 | return devaccuracy, testaccuracy 247 | -------------------------------------------------------------------------------- /SentEval/senteval/trec.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | ''' 9 | TREC question-type classification 10 | ''' 11 | 12 | from __future__ import absolute_import, division, unicode_literals 13 | 14 | import os 15 | import io 16 | import logging 17 | import numpy as np 18 | 19 | from senteval.tools.validation import KFoldClassifier 20 | 21 | 22 | class TRECEval(object): 23 | def __init__(self, task_path, seed=1111): 24 | logging.info('***** Transfer task : TREC *****\n\n') 25 | self.seed = seed 26 | self.train = self.loadFile(os.path.join(task_path, 'train_5500.label')) 27 | self.test = self.loadFile(os.path.join(task_path, 'TREC_10.label')) 28 | 29 | def do_prepare(self, params, prepare): 30 | samples = self.train['X'] + self.test['X'] 31 | return prepare(params, samples) 32 | 33 | def loadFile(self, fpath): 34 | trec_data = {'X': [], 'y': []} 35 | tgt2idx = {'ABBR': 0, 'DESC': 1, 'ENTY': 2, 36 | 'HUM': 3, 'LOC': 4, 'NUM': 5} 37 | with io.open(fpath, 'r', encoding='latin-1') as f: 38 | for line in f: 39 | target, sample = line.strip().split(':', 1) 40 | sample = sample.split(' ', 1)[1].split() 41 | 
assert target in tgt2idx, target 42 | trec_data['X'].append(sample) 43 | trec_data['y'].append(tgt2idx[target]) 44 | return trec_data 45 | 46 | def run(self, params, batcher): 47 | train_embeddings, test_embeddings = [], [] 48 | 49 | # Sort to reduce padding 50 | sorted_corpus_train = sorted(zip(self.train['X'], self.train['y']), 51 | key=lambda z: (len(z[0]), z[1])) 52 | train_samples = [x for (x, y) in sorted_corpus_train] 53 | train_labels = [y for (x, y) in sorted_corpus_train] 54 | 55 | sorted_corpus_test = sorted(zip(self.test['X'], self.test['y']), 56 | key=lambda z: (len(z[0]), z[1])) 57 | test_samples = [x for (x, y) in sorted_corpus_test] 58 | test_labels = [y for (x, y) in sorted_corpus_test] 59 | 60 | # Get train embeddings 61 | for ii in range(0, len(train_labels), params.batch_size): 62 | batch = train_samples[ii:ii + params.batch_size] 63 | embeddings = batcher(params, batch) 64 | train_embeddings.append(embeddings) 65 | train_embeddings = np.vstack(train_embeddings) 66 | logging.info('Computed train embeddings') 67 | 68 | # Get test embeddings 69 | for ii in range(0, len(test_labels), params.batch_size): 70 | batch = test_samples[ii:ii + params.batch_size] 71 | embeddings = batcher(params, batch) 72 | test_embeddings.append(embeddings) 73 | test_embeddings = np.vstack(test_embeddings) 74 | logging.info('Computed test embeddings') 75 | 76 | config_classifier = {'nclasses': 6, 'seed': self.seed, 77 | 'usepytorch': params.usepytorch, 78 | 'classifier': params.classifier, 79 | 'kfold': params.kfold} 80 | clf = KFoldClassifier({'X': train_embeddings, 81 | 'y': np.array(train_labels)}, 82 | {'X': test_embeddings, 83 | 'y': np.array(test_labels)}, 84 | config_classifier) 85 | devacc, testacc, _ = clf.run() 86 | logging.debug('\nDev acc : {0} Test acc : {1} \ 87 | for TREC\n'.format(devacc, testacc)) 88 | return {'devacc': devacc, 'acc': testacc, 89 | 'ndev': len(self.train['X']), 'ntest': len(self.test['X'])} 90 | 
-------------------------------------------------------------------------------- /SentEval/senteval/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from __future__ import absolute_import, division, unicode_literals 9 | 10 | import numpy as np 11 | import re 12 | import inspect 13 | from torch import optim 14 | 15 | 16 | def create_dictionary(sentences): 17 | words = {} 18 | for s in sentences: 19 | for word in s: 20 | if word in words: 21 | words[word] += 1 22 | else: 23 | words[word] = 1 24 | words[''] = 1e9 + 4 25 | words[''] = 1e9 + 3 26 | words['

'] = 1e9 + 2 27 | # words[''] = 1e9 + 1 28 | sorted_words = sorted(words.items(), key=lambda x: -x[1]) # inverse sort 29 | id2word = [] 30 | word2id = {} 31 | for i, (w, _) in enumerate(sorted_words): 32 | id2word.append(w) 33 | word2id[w] = i 34 | 35 | return id2word, word2id 36 | 37 | 38 | def cosine(u, v): 39 | return np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)) 40 | 41 | 42 | class dotdict(dict): 43 | """ dot.notation access to dictionary attributes """ 44 | __getattr__ = dict.get 45 | __setattr__ = dict.__setitem__ 46 | __delattr__ = dict.__delitem__ 47 | 48 | 49 | def get_optimizer(s): 50 | """ 51 | Parse optimizer parameters. 52 | Input should be of the form: 53 | - "sgd,lr=0.01" 54 | - "adagrad,lr=0.1,lr_decay=0.05" 55 | """ 56 | if "," in s: 57 | method = s[:s.find(',')] 58 | optim_params = {} 59 | for x in s[s.find(',') + 1:].split(','): 60 | split = x.split('=') 61 | assert len(split) == 2 62 | assert re.match("^[+-]?(\d+(\.\d*)?|\.\d+)$", split[1]) is not None 63 | optim_params[split[0]] = float(split[1]) 64 | else: 65 | method = s 66 | optim_params = {} 67 | 68 | if method == 'adadelta': 69 | optim_fn = optim.Adadelta 70 | elif method == 'adagrad': 71 | optim_fn = optim.Adagrad 72 | elif method == 'adam': 73 | optim_fn = optim.Adam 74 | elif method == 'adamax': 75 | optim_fn = optim.Adamax 76 | elif method == 'asgd': 77 | optim_fn = optim.ASGD 78 | elif method == 'rmsprop': 79 | optim_fn = optim.RMSprop 80 | elif method == 'rprop': 81 | optim_fn = optim.Rprop 82 | elif method == 'sgd': 83 | optim_fn = optim.SGD 84 | assert 'lr' in optim_params 85 | else: 86 | raise Exception('Unknown optimization method: "%s"' % method) 87 | 88 | # check that we give good parameters to the optimizer 89 | try: 90 | expected_args = inspect.getargspec(optim_fn.__init__)[0] 91 | except ValueError: 92 | expected_args = inspect.getfullargspec(optim_fn.__init__)[0] 93 | assert expected_args[:2] == ['self', 'params'] 94 | if not all(k in expected_args[2:] for k in 
optim_params.keys()): 95 | raise Exception('Unexpected parameters: expected "%s", got "%s"' % ( 96 | str(expected_args[2:]), str(optim_params.keys()))) 97 | 98 | return optim_fn, optim_params 99 | -------------------------------------------------------------------------------- /SentEval/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import io 9 | from setuptools import setup, find_packages 10 | 11 | with io.open('./README.md', encoding='utf-8') as f: 12 | readme = f.read() 13 | 14 | setup( 15 | name='SentEval', 16 | version='0.1.0', 17 | url='https://github.com/facebookresearch/SentEval', 18 | packages=find_packages(exclude=['examples']), 19 | license='Attribution-NonCommercial 4.0 International', 20 | long_description=readme, 21 | ) 22 | -------------------------------------------------------------------------------- /arial.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kongds/E5-V/856ea816b943df95d77df39a386302b14ea197a2/arial.ttf -------------------------------------------------------------------------------- /data/download_nli.sh: -------------------------------------------------------------------------------- 1 | wget https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/nli_for_simcse.csv 2 | -------------------------------------------------------------------------------- /ds.config: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 
| "loss_scale": 0, 10 | "initial_scale_power": 16, 11 | "loss_scale_window": 1000, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "zero_optimization": { 16 | "stage": 2, 17 | "allgather_partitions": true, 18 | "allgather_bucket_size": 2e8, 19 | "overlap_comm": true, 20 | "reduce_scatter": true, 21 | "reduce_bucket_size": 2e8, 22 | "contiguous_gradients": true 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /eval_sts.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | import io, os 4 | import torch 5 | import numpy as np 6 | import logging 7 | import tqdm 8 | import fcntl 9 | import time 10 | import argparse 11 | from prettytable import PrettyTable 12 | import transformers 13 | from transformers import LlamaTokenizer 14 | from transformers import AutoTokenizer, AutoModelForCausalLM 15 | # Set up logger 16 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 17 | 18 | # Set PATHs 19 | PATH_TO_SENTEVAL = './SentEval' 20 | PATH_TO_DATA = './SentEval/data' 21 | 22 | # Import SentEval 23 | sys.path.insert(0, PATH_TO_SENTEVAL) 24 | import senteval 25 | 26 | 27 | def print_table(task_names, scores): 28 | tb = PrettyTable() 29 | tb.field_names = task_names 30 | tb.add_row(scores) 31 | print(tb) 32 | 33 | def lock_and_write_file(file_path, content): 34 | with open(file_path, 'a') as file: 35 | while True: 36 | try: 37 | # Acquire an exclusive lock (non-blocking) 38 | fcntl.flock(file, fcntl.LOCK_EX | fcntl.LOCK_NB) 39 | 40 | # Perform your write operations here 41 | file.write(content + '\n') 42 | file.flush() 43 | 44 | except IOError as e: 45 | print("File is locked by another process. 
Can't write.") 46 | time.sleep(1) 47 | finally: 48 | # Release the lock 49 | fcntl.flock(file, fcntl.LOCK_UN) 50 | break 51 | 52 | def main(): 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument('--mask_embedding_sentence_template', type=str, default='*sent_0*\nSummary_above_sentence_in_one_word:') 55 | parser.add_argument("--tokenizer_name", type=str, default='') 56 | parser.add_argument("--model_name_or_path", type=str, 57 | help="Transformers' model name or path") 58 | parser.add_argument("--mode", type=str, 59 | choices=['dev', 'test', 'fasttest'], 60 | default='test', 61 | help="What evaluation mode to use (dev: fast mode, dev results; test: full mode, test results); fasttest: fast mode, test results") 62 | parser.add_argument("--task_set", type=str, 63 | choices=['sts', 'transfer', 'full', 'na'], 64 | default='sts', 65 | help="What set of tasks to evaluate on. If not 'na', this will override '--tasks'") 66 | parser.add_argument('--load_kbit', type=int, 67 | choices=[4,8,16], 68 | default=16, 69 | help="Load model in kbit") 70 | 71 | parser.add_argument('--avg', action='store_true') 72 | 73 | 74 | args = parser.parse_args() 75 | from accelerate import Accelerator 76 | accelerator = Accelerator() 77 | device = accelerator.device 78 | 79 | if args.load_kbit == 4: 80 | from transformers import BitsAndBytesConfig 81 | model = AutoModelForCausalLM.from_pretrained( 82 | args.model_name_or_path, 83 | load_in_4bit=True, 84 | quantization_config=BitsAndBytesConfig( 85 | load_in_4bit=True, 86 | llm_int8_threshold=6.0, 87 | llm_int8_has_fp16_weight=False, 88 | bnb_4bit_compute_dtype=torch.float16, 89 | bnb_4bit_use_double_quant=True, 90 | bnb_4bit_quant_type='nf4', 91 | ), 92 | torch_dtype=torch.float16, 93 | device_map=device, 94 | ) 95 | elif 'llava' in args.model_name_or_path or 'e5-v' in args.model_name_or_path: 96 | from transformers import LlavaNextForConditionalGeneration 97 | model = LlavaNextForConditionalGeneration.from_pretrained( 98 | 
args.model_name_or_path, 99 | load_in_8bit=args.load_kbit == 8, 100 | torch_dtype=torch.float16, 101 | device_map=device, 102 | ) 103 | model = model.language_model 104 | else: 105 | model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, 106 | device_map=device, 107 | output_hidden_states=True, 108 | _attn_implementation='eager', 109 | trust_remote_code=True, 110 | load_in_8bit=args.load_kbit == 8,) 111 | 112 | 113 | if 'Phi' in args.model_name_or_path or 'phi' in args.model_name_or_path: 114 | from transformers import AutoProcessor 115 | transform = AutoProcessor.from_pretrained("microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True) 116 | tokenizer = transform.tokenizer 117 | tokenizer.padding = True 118 | elif 'llama-3' in args.model_name_or_path or 'e5-v' in args.model_name_or_path: 119 | tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3-8b-Instruct") 120 | tokenizer.pad_token_id = tokenizer.eos_token_id 121 | tokenizer.padding = True 122 | elif 'llava' in args.model_name_or_path: 123 | from transformers import LlavaNextProcessor 124 | transform = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", 125 | revision='a1d521368f8d353afa4da2ed2bb1bf646ef1ff5f') 126 | tokenizer = transform.tokenizer 127 | tokenizer.padding = True 128 | else: 129 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) 130 | 131 | tokenizer.pad_token_id = 0 # unk. 
we want this to be different from the eos token 132 | tokenizer.padding_side = "left" # Allow batched inference 133 | 134 | # Set up the tasks 135 | #args.tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness'] 136 | #args.tasks = ['MR'] 137 | if args.task_set == 'sts': 138 | args.tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness'] 139 | if args.mode == 'dev': 140 | args.tasks = ['STSBenchmark-dev'] 141 | elif args.task_set == 'transfer': 142 | args.tasks = ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC'] 143 | elif args.task_set == 'full': 144 | args.tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness'] 145 | args.tasks += ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC'] 146 | 147 | # Set params for SentEval 148 | if args.mode == 'dev' or args.mode == 'fasttest': 149 | # Fast mode 150 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5, 'batch_size': 32} 151 | params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 32, 152 | 'tenacity': 3, 'epoch_size': 2} 153 | elif args.mode == 'test': 154 | # Full mode 155 | params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10, 'batch_size':16} 156 | params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64, 157 | 'tenacity': 5, 'epoch_size': 4} 158 | else: 159 | raise NotImplementedError 160 | 161 | import torch.nn.functional as F 162 | 163 | import torch.distributed as dist 164 | local_rank = dist.get_rank() 165 | world_size = dist.get_world_size() 166 | 167 | 168 | # SentEval prepare and batcher 169 | def prepare(params, samples): 170 | return 171 | 172 | params['batch_size'] = 4*world_size 173 | def batcher(params, batch, max_length=None): 174 | # Handle rare token encoding issues in the dataset 175 | if len(batch) >= 1 and len(batch[0]) >= 1 and isinstance(batch[0][0], bytes): 176 | batch = [[word.decode('utf-8') for word in s] for s in batch] 177 | 178 | 
sentences = [' '.join(s) for s in batch] 179 | if max_length == 500: 180 | sentences = [tokenizer.decode(tokenizer.encode(s, add_special_tokens=False)[:max_length]) for s in sentences] 181 | max_length = 512 182 | 183 | if args.mask_embedding_sentence_template is not None: 184 | # *cls*_This_sentence_of_"*sent_0*"_means*mask*.*sep+* 185 | if 'llama-3' in args.model_name_or_path or 'e5-v' in args.model_name_or_path: 186 | mllm_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' 187 | elif 'llava' in args.model_name_or_path: 188 | mllm_template = '[INST] {} [/INST]' 189 | elif 'Phi' in args.model_name_or_path or 'phi' in args.model_name_or_path: 190 | mllm_template = '<|user|>\n{}<|end|>\n<|assistant|>\n' 191 | 192 | template = args.mask_embedding_sentence_template 193 | template = template.replace('_', ' ').replace('*sep+*', '')\ 194 | .replace('*cls*', '') 195 | 196 | for i, s in enumerate(sentences): 197 | if len(s) > 0 and s[-1] not in '.?"\'': s += '.' 198 | s = s.replace('"', '\'') 199 | if len(s) > 0 and '?' == s[-1]: s = s[:-1] + '.' 
200 | sentences[i] = mllm_template.format(' ' + template.replace('*sent 0*', s).strip()) 201 | real_bsz = len(sentences) 202 | if real_bsz % world_size != 0: 203 | sentences += [sentences[0]] * (world_size - real_bsz % world_size) 204 | 205 | bsz = len(sentences) 206 | sub_sentences = sentences[local_rank * bsz // world_size: (local_rank + 1) * bsz // world_size] 207 | 208 | batch = tokenizer.batch_encode_plus( 209 | sub_sentences, 210 | return_tensors='pt', 211 | padding=True, 212 | max_length=max_length, 213 | truncation=max_length is not None 214 | ) 215 | 216 | # Move to the correct device 217 | for k in batch: 218 | batch[k] = batch[k].to(device) if batch[k] is not None else None 219 | 220 | # Get raw embeddings 221 | with torch.no_grad(): 222 | hidden_states = model(output_hidden_states=True, return_dict=True, **batch).hidden_states 223 | if args.avg: 224 | last_layer = hidden_states[-1] 225 | attention_mask = batch['attention_mask'].unsqueeze(-1).expand(last_layer.shape) 226 | outputs = (last_layer * attention_mask).mean(1) 227 | else: 228 | outputs = hidden_states[-1][:, -1, :] 229 | 230 | if outputs.dtype == torch.bfloat16: 231 | # bfloat16 not support for .numpy() 232 | outputs = outputs.float() 233 | 234 | emb = outputs 235 | emb = accelerator.gather(emb)[:real_bsz] 236 | return emb.cpu() 237 | 238 | 239 | 240 | results = {} 241 | 242 | 243 | args.mask_embedding_sentence_template = args.mask_embedding_sentence_template.replace('\\n', '\n') 244 | print(args.mask_embedding_sentence_template) 245 | 246 | for task in args.tasks: 247 | se = senteval.engine.SE(params, batcher, prepare) 248 | result = se.eval(task) 249 | results[task] = result 250 | 251 | # Print evaluation results 252 | if args.mode == 'test' or args.mode == 'fasttest': 253 | print("------ %s ------" % (args.mode)) 254 | 255 | task_names = [] 256 | scores = [] 257 | for task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']: 258 | task_names.append(task) 259 | 
if task in results: 260 | if task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']: 261 | scores.append("%.2f" % (results[task]['all']['spearman']['all'] * 100)) 262 | else: 263 | scores.append("%.2f" % (results[task]['test']['spearman'].correlation * 100)) 264 | else: 265 | scores.append("0.00") 266 | task_names.append("Avg.") 267 | scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores))) 268 | if accelerator.is_main_process: 269 | print_table(task_names, scores) 270 | # 271 | # write results and template to file 272 | if args.mask_embedding_sentence_template is not None and args.task_set != 'transfer': 273 | with open('./sts-org-results', 'a') as f: 274 | bits = f'{args.load_kbit}bit' 275 | model_name = args.model_name_or_path.split('/')[-1] + f'({bits})' 276 | f.write(args.mask_embedding_sentence_template + ' ' + model_name + ' ' + ' '.join([str(s) for s in scores]) + '\n') 277 | 278 | task_names = [] 279 | scores = [] 280 | for task in ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC']: 281 | task_names.append(task) 282 | if task in results: 283 | scores.append("%.2f" % (results[task]['acc'])) 284 | else: 285 | scores.append("0.00") 286 | task_names.append("Avg.") 287 | scores.append("%.2f" % (sum([float(score) for score in scores]) / len(scores))) 288 | if accelerator.is_main_process: 289 | print_table(task_names, scores) 290 | 291 | if __name__ == "__main__": 292 | main() 293 | -------------------------------------------------------------------------------- /figure/e5v.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kongds/E5-V/856ea816b943df95d77df39a386302b14ea197a2/figure/e5v.png -------------------------------------------------------------------------------- /load_llama3_hf.py: -------------------------------------------------------------------------------- 1 | from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, LlavaNextConfig, 
"""Repackage raw LLaVA/Llama-3 safetensors shards as a LlavaNext checkpoint."""
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, LlavaNextConfig, AutoConfig
import torch
from PIL import Image
import requests
from safetensors import safe_open

# LlavaNext configuration whose text backbone config is replaced by Llama-3.
config = LlavaNextConfig.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
config.text_config = AutoConfig.from_pretrained("unsloth/llama-3-8b-Instruct")

# Collect every tensor of the four shards into one flat state dict.
sd = {}
for shard in range(1, 5):
    shard_path = f"models/model-0000{shard}-of-00004.safetensors"
    with safe_open(shard_path, framework="pt", device="cpu") as reader:
        sd.update({name: reader.get_tensor(name) for name in reader.keys()})

model = LlavaNextForConditionalGeneration(config)

# Keys that belong neither to the projector nor to the vision tower are
# moved under the 'language_model.' prefix.
for name in list(sd):
    if 'mm_projector' in name or 'vision_tower' in name:
        continue
    sd['language_model.' + name] = sd[name]
    del sd[name]

# Vision-tower keys lose their 'model.vision_tower.' segment.
for name in list(sd):
    if 'vision_tower' not in name:
        continue
    sd[name.replace('model.vision_tower.', '')] = sd[name]
    del sd[name]

# Remaining one-to-one renames: projector layers and the image-newline tensor.
for old_name, new_name in (
    ('model.mm_projector.0.weight', 'multi_modal_projector.linear_1.weight'),
    ('model.mm_projector.2.weight', 'multi_modal_projector.linear_2.weight'),
    ('model.mm_projector.0.bias', 'multi_modal_projector.linear_1.bias'),
    ('model.mm_projector.2.bias', 'multi_modal_projector.linear_2.bias'),
    ('language_model.model.image_newline', 'image_newline'),
):
    sd[new_name] = sd[old_name]
    del sd[old_name]

model.load_state_dict(sd)
model.save_pretrained('models/llava-llama-3-8b')
# save language model for training
model.language_model.save_pretrained('models/llava-llama-3-8b-llm')
scipy 7 | datasets==2.20.0 8 | safetensors==0.4.3 9 | pandas==1.1.5 10 | scikit-learn 11 | prettytable 12 | -------------------------------------------------------------------------------- /retrieval.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | import os 4 | import math 5 | import glob 6 | import random 7 | import sys 8 | import torch 9 | from contextlib import suppress 10 | from tqdm import tqdm 11 | 12 | import argparse 13 | from accelerate import Accelerator 14 | import transformers 15 | from copy import copy 16 | from itertools import product 17 | 18 | 19 | import logging 20 | 21 | from collections import defaultdict 22 | from html.parser import HTMLParser 23 | from typing import Any, Callable, Dict, List, Optional, Tuple 24 | 25 | from PIL import Image 26 | from datasets import load_from_disk, load_dataset 27 | 28 | 29 | from torch import nn 30 | import torch.nn.functional as F 31 | 32 | from transformers import AutoTokenizer 33 | from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration 34 | from transformers import LlavaNextConfig, AutoModel 35 | from transformers.models.llava_next.modeling_llava_next import LlavaNextMultiModalProjector 36 | from transformers import AutoModelForCausalLM 37 | from transformers import AutoProcessor 38 | 39 | from peft import PeftModel 40 | 41 | DEBUG = False 42 | MODEL_TYPE = 'llava' 43 | 44 | accelerator = Accelerator() 45 | 46 | llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' 47 | 48 | def recall_at_k(scores, positive_pairs, k): 49 | """ 50 | Compute the recall at k for each sample 51 | :param scores: compability score between text and image embeddings (nb texts, nb images) 52 | :param k: number of images to consider per text, for retrieval 53 | :param positive_pairs: boolean matrix of positive pairs (nb texts, nb images) 54 | :return: recall at k 
def recall_at_k(scores, positive_pairs, k):
    """
    Compute the recall at k for each query row.

    :param scores: similarity matrix, one row per query and one column per
        candidate (nb queries, nb candidates)
    :param positive_pairs: boolean matrix of the same shape marking the
        candidates that are correct for each query
    :param k: number of top-scoring candidates inspected per query; values
        larger than the number of candidates are clamped
    :return: 1-D tensor with, for every query, the fraction of its positives
        found among its top-k candidates (NaN for a query without positives,
        because of the 0/0 division)
    """
    nb_candidates = scores.shape[1]
    # torch.topk raises when k exceeds the dimension size.
    k = min(k, nb_candidates)
    topk_indices = torch.topk(scores, k, dim=1)[1]
    nb_positive = positive_pairs.sum(dim=1)
    # Look the k retrieved columns up directly instead of materialising a
    # (queries, k, candidates) one-hot tensor; gather also accepts
    # non-contiguous inputs such as a sliced transpose.
    nb_true_positive = positive_pairs.gather(1, topk_indices).sum(dim=1)
    return nb_true_positive / nb_positive


def batchify(func, X, Y, batch_size, device, *args, **kwargs):
    """
    Apply ``func`` to aligned slices of ``X`` and ``Y`` and concatenate.

    :param func: callable invoked as ``func(x, y, *args, **kwargs)``; must
        return a tensor
    :param X: first tensor, sliced along dim 0
    :param Y: second tensor, sliced with the same bounds as ``X``
    :param batch_size: number of rows per slice
    :param device: device each slice is moved to before calling ``func``
    :return: per-slice results moved to CPU and concatenated along dim 0; an
        empty tensor when ``X`` has no rows
    """
    results = []
    for start in range(0, len(X), batch_size):
        end = start + batch_size
        x = X[start:end].to(device)
        y = Y[start:end].to(device)
        results.append(func(x, y, *args, **kwargs).cpu())
    if not results:
        # torch.cat refuses an empty list.
        return torch.empty(0)
    return torch.cat(results)
def emb_data(model, transform, dataset, device,
             emb_type='text', prompt=None, bsz=4,
             text_column='caption', img_column='img'):
    """
    Embed every row of ``dataset`` as text or as image.

    The embedding of an input is the last layer's hidden state at the last
    sequence position, L2-normalised.

    :param model: model called with the processor output plus
        ``output_hidden_states=True``
    :param transform: processor/tokenizer turning prompts (and images) into
        model inputs
    :param dataset: sized, iterable dataset of dict rows
    :param device: device the model inputs are moved to
    :param emb_type: ``'text'`` embeds the captions in ``text_column``; any
        other value embeds the images in ``img_column``
    :param prompt: prompt template; for text the caption is substituted for
        the sentence placeholder, for images it is used unchanged
    :param bsz: image batch size (text uses ``3 * bsz``)
    :return: float32 CPU tensor with one row per caption (text) or per
        dataset row (image)
    """
    def collate_as_lists(rows):
        # Keep every column as a plain python list (no tensor stacking).
        return {key: [row[key] for row in rows] for key in rows[0].keys()}

    is_text = emb_type == 'text'
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=3 * bsz if is_text else bsz,
        shuffle=False, num_workers=1,
        collate_fn=collate_as_lists
    )
    dataloader = accelerator.prepare(dataloader)

    embs = []
    with tqdm(total=len(dataloader)) as bar:
        for batch in dataloader:
            if is_text:
                # A row holds either a list of captions or a single caption.
                captions = []
                for entry in batch[text_column]:
                    if isinstance(entry, list):
                        captions.extend(entry)
                    else:
                        captions.append(entry)
                # NOTE(review): the placeholder literal was empty in the
                # previous revision, which makes str.replace insert the
                # caption between every character of the prompt. '<sent>' is
                # assumed to be the marker carried by the text prompts --
                # confirm against the prompts built by the caller.
                input_texts = [prompt.replace('<sent>', text) for text in captions]
                inputs = transform(input_texts,
                                   return_tensors="pt", padding=True)
                for key in inputs:
                    if inputs[key] is not None:
                        inputs[key] = inputs[key].to(device)
            else:
                input_texts = [prompt] * len(batch[img_column])
                if MODEL_TYPE == 'phi3':
                    # phi3 only support 1 bsz for image
                    assert len(input_texts) == 1
                    input_texts = input_texts[0]
                inputs = transform(input_texts,
                                   batch[img_column], return_tensors="pt", padding=True).to(device)

            with torch.no_grad():
                emb = model(**inputs, output_hidden_states=True, return_dict=True).hidden_states[-1][:, -1, :]
                emb = F.normalize(emb, dim=-1)
                emb = accelerator.gather(emb)
            embs.append(emb.cpu().float())
            bar.update(1)
    embs = torch.cat(embs)

    # The gathered result can hold more rows than real inputs, so it is cut
    # back to the true number of embedded items.
    if is_text:
        total = sum(
            len(row[text_column]) if isinstance(row[text_column], list) else 1
            for row in dataset
        )
    else:
        total = len(dataset)
    return embs[:total]


def log_to_file(data, metrics, checkpoint_name, fiq_data_type=None, orc_replace_text=False):
    """
    Format the metrics of one benchmark as a single line and optionally log it.

    :param data: ``'flickr30k'``, ``'coco'``, ``'fashioniq'`` or ``'cirr'``
    :param metrics: for flickr30k/coco a dict holding
        ``image_retrieval_recall@5`` and ``text_retrieval_recall@5``; for
        fashioniq two values (reported as R@10, R@50); for cirr three values
        (reported as R@1, R@5, R@10)
    :param checkpoint_name: file the line is appended to; ``None`` skips
        writing
    :param fiq_data_type: FashionIQ category shown in the fashioniq line
    :param orc_replace_text: prefixes the flickr30k/coco line with ``orc``
    :return: the formatted line
    :raises ValueError: for an unknown ``data`` name or a wrong number of
        metric values
    """
    if data == 'flickr30k' or data == 'coco':
        prefix = 'orc ' if orc_replace_text else ''
        output = f"{prefix}{data}: {metrics['image_retrieval_recall@5']:.4f} {metrics['text_retrieval_recall@5']:.4f}"
    elif data == 'fashioniq':
        if len(metrics) != 2:
            raise ValueError(f'fashioniq expects 2 metric values, got {len(metrics)}')
        r_at_10, r_at_50 = metrics
        output = f"{data} {fiq_data_type}: R@10: {r_at_10:.4f} R@50: {r_at_50:.4f}"
    elif data == 'cirr':
        if len(metrics) != 3:
            raise ValueError(f'cirr expects 3 metric values, got {len(metrics)}')
        r_at_1, r_at_5, r_at_10 = metrics
        output = f"{data}: R@1: {r_at_1:.4f} R@5: {r_at_5:.4f} R@10: {r_at_10:.4f}"
    else:
        # Previously fell through and failed later with UnboundLocalError.
        raise ValueError(f'no log format for dataset {data!r}')

    if checkpoint_name is not None:
        with open(checkpoint_name, 'a') as f:
            print(output, file=f)
    return output
def init_model_and_transform(lora_path, bf16, fp32, use_e5v=False):
    """
    Load the model and its processor for the globally selected MODEL_TYPE.

    When ``lora_path`` is given, the adapter is merged into the base model
    once and cached in a ``merged-...`` directory derived from the path; later
    runs load that directory directly.

    :param lora_path: path of a LoRA adapter, or ``None`` for the base model
    :param bf16: load weights as bfloat16 (float16 otherwise)
    :param fp32: load weights as float32; takes precedence over ``bf16``
    :param use_e5v: (LLaVA branch only) load model and processor from
        ``royokong/e5-v``, overriding any merged checkpoint
    :return: ``(model, transform)``
    """
    if fp32:
        dtype = torch.float32
    elif bf16:
        dtype = torch.bfloat16
    else:
        dtype = torch.float16

    if lora_path is not None:
        merge_path = 'merged-' + lora_path.replace('/', '-').replace('.', '')
    else:
        merge_path = None

    if MODEL_TYPE == 'phi3':
        phi3_name = "microsoft/Phi-3-vision-128k-instruct"
        transform = AutoProcessor.from_pretrained(phi3_name, trust_remote_code=True)

        model_name = phi3_name
        if merge_path is not None:
            # Only the main process merges; the others wait and then find
            # the directory on disk.
            with accelerator.main_process_first():
                if not os.path.exists(merge_path):
                    model = AutoModelForCausalLM.from_pretrained(phi3_name,
                                                                 device_map="cuda", trust_remote_code=True,
                                                                 torch_dtype=dtype, _attn_implementation='eager')
                    model = PeftModel.from_pretrained(model, lora_path).merge_and_unload()
                    model.save_pretrained(merge_path, safe_serialization=False)
            model_name = merge_path
        model = AutoModelForCausalLM.from_pretrained(model_name,
                                                     device_map="cuda", trust_remote_code=True,
                                                     torch_dtype=dtype, _attn_implementation='eager')
        transform.tokenizer.padding_side = "left"
        transform.tokenizer.padding = True
        return model, transform

    MODEL_CLASS = LlavaNextForConditionalGeneration
    transform = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf",
                                                   revision='a1d521368f8d353afa4da2ed2bb1bf646ef1ff5f')
    if MODEL_TYPE == 'llava_llama3':
        tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3-8b-Instruct")
        transform.tokenizer = tokenizer
        # NOTE(review): the token literal was empty in the previous revision.
        # '<image>' is assumed: it has to be the image placeholder whose id
        # matches the image_token_index assigned below -- confirm.
        transform.tokenizer.add_tokens('<image>')
        transform.tokenizer.pad_token_id = transform.tokenizer.eos_token_id
    # NOTE(review): the nesting of the next two assignments could not be
    # recovered from the flattened source; they are applied to every LLaVA
    # variant here -- confirm.
    transform.tokenizer.padding_side = "left"
    transform.tokenizer.padding = True

    model_name = "llava-hf/llava-v1.6-mistral-7b-hf"
    if MODEL_TYPE == 'llava_llama3':
        model_name = "./llava-llama-3-8b"

    if merge_path is not None:
        with accelerator.main_process_first():
            if not os.path.exists(merge_path):
                model = MODEL_CLASS.from_pretrained(model_name,
                                                    device_map='cpu')
                # Only the language model carries the adapter.
                model.language_model = PeftModel.from_pretrained(model.language_model, lora_path).merge_and_unload()
                model.save_pretrained(merge_path)
        model_name = merge_path

    if use_e5v:
        model_name = 'royokong/e5-v'
        transform = LlavaNextProcessor.from_pretrained('royokong/e5-v')

    model = MODEL_CLASS.from_pretrained(model_name,
                                        torch_dtype=dtype, low_cpu_mem_usage=True)
    if MODEL_TYPE == 'llava_llama3':
        model.config.image_token_index = 128256

    return model, transform
def wrap_words_to_width(words, measure, max_width):
    """
    Greedily pack ``words`` into lines no wider than ``max_width``.

    :param words: sequence of words (no whitespace inside)
    :param measure: callable returning the rendered width of a string
    :param max_width: maximum width of ``line + next_word``
    :return: list of lines; every word is followed by one space. A word that
        is wider than ``max_width`` on its own still gets a line of its own
        instead of stalling the wrapping forever.
    """
    lines = []
    line = ''
    for word in words:
        # Only break when the current line already holds something; an empty
        # line always accepts the next word.
        if line and measure(line + word) > max_width:
            lines.append(line)
            line = ''
        line += word + ' '
    if line:
        lines.append(line)
    return lines


def create_text_image(text, image_width=800, image_height=400, font_path="arial.ttf",
                      font_size=40, background_color=(255, 255, 255), text_color=(0, 0, 0)):
    """
    Render ``text`` word-wrapped and vertically centred on a blank image.

    :param text: text to draw; split on whitespace before wrapping
    :param image_width: image width in pixels (20 px margin on each side)
    :param image_height: image height in pixels
    :param font_path: TrueType font file handed to ``ImageFont.truetype``
    :param font_size: font size
    :param background_color: RGB fill of the canvas
    :param text_color: RGB colour of the text
    :return: the PIL image
    """
    from PIL import Image, ImageDraw, ImageFont
    image = Image.new('RGB', (image_width, image_height), color=background_color)
    draw = ImageDraw.Draw(image)
    font = ImageFont.truetype(font_path, font_size)

    # Leave 20 px of padding on both sides.
    max_text_width = image_width - 40
    lines = wrap_words_to_width(
        text.split(),
        lambda candidate: draw.textlength(candidate, font=font),
        max_text_width,
    )

    # Height of each line from its bounding box, measured once per line.
    line_heights = []
    for line in lines:
        bbox = draw.textbbox((0, 0), line, font=font)
        line_heights.append(bbox[3] - bbox[1])

    text_x = 20
    text_y = (image_height - sum(line_heights)) // 2
    for line, height in zip(lines, line_heights):
        draw.text((text_x, text_y), line, font=font, fill=text_color)
        text_y += height

    return image
def ir(model, transform,
       img_prompt, text_prompt,
       data, device,
       ocr_replace_text=False,
       batch_size=None):
    """
    Image-text retrieval on the ``royokong/{data}_test`` dataset.

    Captions and images are embedded with ``emb_data`` and scored by dot
    product. When ``ocr_replace_text`` is set, every caption is rendered to
    an image first (cached in the ``{data}_ocr`` directory) and embedded with
    the image prompt instead.

    :param data: dataset name; ``'coco'`` keeps only the first five captions
        of every image
    :param batch_size: embedding batch size, 4 when ``None``
    :return: dict with ``image_retrieval_recall@k`` and
        ``text_retrieval_recall@k`` for k in 1, 5, 10
    """
    dataset = load_dataset(f'royokong/{data}_test', split='test')
    dataset = dataset.rename_column('text', 'caption').rename_column('image', 'img')
    if data == 'coco':
        dataset = dataset.map(lambda x: {'caption': x['caption'][:5]}, num_proc=4)

    bsz = batch_size if batch_size is not None else 4

    if not ocr_replace_text:
        text_embs = emb_data(model, transform, dataset, device, emb_type='text', prompt=text_prompt, bsz=bsz)
    else:
        ocr_path = f'{data}_ocr'
        with accelerator.main_process_first():
            if os.path.exists(ocr_path):
                ocr_dataset = load_from_disk(ocr_path)
            else:
                from datasets import Dataset
                ocrs = []
                for row in dataset:
                    ocrs.extend(row['caption'])
                ocr_dataset = Dataset.from_dict({'ocr': ocrs})
                ocr_dataset = ocr_dataset.map(lambda x: {'img': create_text_image(x['ocr'])}, num_proc=40)
                ocr_dataset.save_to_disk(ocr_path)
        orc_prompt = img_prompt
        print(orc_prompt)
        text_embs = emb_data(model, transform, ocr_dataset, device, emb_type='image', prompt=orc_prompt, bsz=bsz)
    img_embs = emb_data(model, transform, dataset, device, emb_type='image', prompt=img_prompt, bsz=bsz)

    # Caption j belongs to image j // 5: five consecutive captions per image.
    nb_images = img_embs.shape[0]
    texts_image_index = [i // 5 for i in range(nb_images * 5)]
    assert len(texts_image_index) == len(text_embs)

    assert text_embs.isnan().sum().item() == 0, 'nan in retrieve emb'
    assert img_embs.isnan().sum().item() == 0, 'nan in images emb'

    # Similarity of every caption (rows) with every image (columns).
    scores = text_embs @ img_embs.t()

    positive_pairs = torch.zeros_like(scores, dtype=bool)
    positive_pairs[torch.arange(len(scores)), texts_image_index] = True

    metrics = {}
    eval_batch_size = 64
    for recall_k in (1, 5, 10):
        # recall_at_k returns the true recall (retrieved positives divided by
        # all positives). The CLIP-style recall@k reported here is binary per
        # query: a hit as soon as at least one positive is in the top k,
        # i.e. true recall > 0, averaged over all queries.
        image_hits = batchify(recall_at_k, scores, positive_pairs, eval_batch_size, device, k=recall_k) > 0
        text_hits = batchify(recall_at_k, scores.T, positive_pairs.T, eval_batch_size, device, k=recall_k) > 0
        metrics[f"image_retrieval_recall@{recall_k}"] = image_hits.float().mean().item()
        metrics[f"text_retrieval_recall@{recall_k}"] = text_hits.float().mean().item()

    return metrics
def cir(model, transform,
        img_prompt, text_img_prompt,
        data, fiq_data_type,
        device,
        fiq_two=False,
        fusion_cir=False,
        img_only=False,
        batch_size=None):
    """
    Composed image retrieval on FashionIQ (val), CIRR (val) or CIRR (test).

    Candidate images are embedded with ``img_prompt``; every query (reference
    image plus modification text) is embedded with ``text_img_prompt`` and
    scored against all candidates by dot product.

    :param data: ``'fashioniq'``, ``'cirrtest'``; any other value evaluates
        the CIRR validation split
    :param fiq_data_type: FashionIQ category (``dress``/``shirt``/``toptee``)
    :param fiq_two: FashionIQ only: embed each query with both caption orders
        and add the two embeddings before normalising
    :param fusion_cir: accepted but not read in this function
    :param img_only: accepted but not read in this function
    :param batch_size: batch size; default 4 (2 with ``fiq_two``), always 1
        for phi3
    :return: ``[R@10, R@50]`` for fashioniq, ``[R@1, R@5, R@10]`` for the
        CIRR validation split, or the submission dict (pair id -> top-50
        image ids) for ``'cirrtest'``
    """
    print(img_prompt)
    print(text_img_prompt)
    phi3 = MODEL_TYPE == 'phi3'

    if data == 'fashioniq':
        assert fiq_data_type in ['dress', 'shirt', 'toptee']
        dataset = load_dataset('royokong/fashioniq_val')
        img_dataset = load_dataset('royokong/fashioniq_val_imgs')

        dataset = dataset['val'].filter(lambda x: x['category'] == fiq_data_type, num_proc=4)
        img_dataset = img_dataset['val'].filter(lambda x: x['category'] == fiq_data_type, num_proc=4)
    elif data == 'cirrtest':
        dataset = load_dataset('royokong/cirr_test')
        img_dataset = load_dataset('royokong/cirr_imgs')

        dataset = dataset['test']
        img_dataset = img_dataset['test']
        # The test split has no targets; fill in a dummy id so the shared
        # code below does not fail. Read once instead of once per row.
        dummy_target = img_dataset[0]['id']
        dataset = dataset.add_column('target_id', [dummy_target] * len(dataset))
    else:
        dataset = load_dataset('royokong/cirr_val')
        img_dataset = load_dataset('royokong/cirr_imgs')

        dataset = dataset['val']
        img_dataset = img_dataset['val']

    if DEBUG:
        dataset = dataset.select(range(50))
        img_dataset = img_dataset.select(range(50))

    assert len(set(dataset['target_id']) - set(img_dataset['id'])) == 0

    bsz = 4
    if fiq_two:
        bsz //= 2
    if batch_size is not None:
        bsz = batch_size
    if phi3:
        bsz = 1

    def custom_collate_fn(batch):
        # Keep every column as a plain python list.
        return {key: [b[key] for b in batch] for key in batch[0].keys()}

    # ---- candidate image embeddings ----
    img_dataloader = torch.utils.data.DataLoader(
        img_dataset, batch_size=bsz,
        shuffle=False, num_workers=4,
        collate_fn=custom_collate_fn
    )
    img_dataloader = accelerator.prepare(img_dataloader)
    images_embs = []
    bar = tqdm(total=len(img_dataloader))
    for batch in img_dataloader:
        input_texts = [img_prompt] * len(batch['img'])
        if phi3:
            assert len(input_texts) == 1
            input_texts = input_texts[0]
        inputs = transform(input_texts,
                           batch['img'], return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            embs = model(**inputs, output_hidden_states=True, return_dict=True).hidden_states[-1][:, -1, :]
            embs = F.normalize(embs, dim=-1)
            assert embs.isnan().sum() == 0, 'nan in emb after norm'
            embs = accelerator.gather(embs)
        images_embs.append(embs.cpu().float())
        bar.update(1)
    images_ids = img_dataset['id']
    images_emb = torch.cat(images_embs)[:len(images_ids)]
    bar.close()

    # ---- query (image + text) embeddings ----
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=bsz,
        shuffle=False, num_workers=4,
        collate_fn=custom_collate_fn
    )
    retrieve_emb = []
    dataloader = accelerator.prepare(dataloader)
    bar = tqdm(total=len(dataloader))
    for batch in dataloader:
        images = batch['candidate']
        nb_queries = len(batch['caption'])
        # NOTE(review): the placeholder literal was empty in the previous
        # revision, which makes str.replace insert the caption between every
        # character of the prompt. '<sent>' is assumed to be the marker
        # carried by text_img_prompt -- confirm against the caller.
        if data == 'fashioniq':
            caption = batch['caption']
            if fiq_two:
                # Second half of the batch: same queries, caption order reversed.
                caption = caption + [i[::-1] for i in caption]
                images = images + images
            input_texts = [text_img_prompt.replace('<sent>', ', '.join([cc.strip('.?, ') for cc in c])) for c in caption]
        else:
            input_texts = [text_img_prompt.replace('<sent>', c) for c in batch['caption']]

        if phi3:
            # phi3 takes one image per call.
            with torch.no_grad():
                _embs = []
                for i in range(len(input_texts)):
                    inputs = transform(input_texts[i],
                                       [images[i],], return_tensors="pt", padding=True).to(device)
                    _embs.append(model(**inputs, output_hidden_states=True, return_dict=True).hidden_states[-1][:, -1, :])
                embs = torch.cat(_embs, dim=0)
                if fiq_two:
                    embs = embs[:nb_queries] + embs[nb_queries:]
                embs = F.normalize(embs, dim=-1)
        else:
            inputs = transform(input_texts,
                               images, return_tensors="pt", padding=True).to(device)
            with torch.no_grad():
                embs = model(**inputs, output_hidden_states=True, return_dict=True).hidden_states[-1][:, -1, :]
                if fiq_two:
                    embs = embs[:nb_queries] + embs[nb_queries:]
                embs = F.normalize(embs, dim=-1)
        embs = accelerator.gather(embs)
        retrieve_emb.append(embs.cpu().float())
        bar.update(1)
    target_ids = dataset['target_id']
    retrieve_emb = torch.cat(retrieve_emb)[:len(target_ids)]
    bar.close()

    assert retrieve_emb.isnan().sum().item() == 0, 'nan in retrieve emb'
    assert images_emb.isnan().sum().item() == 0, 'nan in images emb'

    scores = retrieve_emb @ images_emb.t()

    # id -> first position in images_ids (same result as list.index, without
    # a linear scan per query).
    index_of = {}
    for position, image_id in enumerate(images_ids):
        index_of.setdefault(image_id, position)

    labels = [index_of[target_id] for target_id in target_ids]

    if data == 'cirr' or data == 'cirrtest':
        # remove reference itself like SEARLE
        if not DEBUG:
            for i, candidate_id in enumerate(dataset['candidate_id']):
                scores[i][index_of[candidate_id]] = -1

    nb_candidates = scores.size(1)

    if data == 'cirrtest':
        submission = {
            'version': 'rc2',
            'metric': 'recall'
        }
        pairids = dataset['pairid']
        for i, pairid in enumerate(pairids):
            top_k_indices = torch.topk(scores[i], k=min(50, nb_candidates), largest=True).indices
            submission[str(pairid)] = [images_ids[j] for j in top_k_indices]
        return submission

    def cir_recall_at_k(scores, labels, k):
        """
        Fraction of queries whose target index is among the k best scores.
        """
        num_queries = scores.size(0)
        k = min(k, scores.size(1))
        hits = 0
        for i in range(num_queries):
            top_k_indices = torch.topk(scores[i], k=k, largest=True).indices
            hits += int(labels[i] in top_k_indices)
        return hits / num_queries

    if data == 'fashioniq':
        # FashionIQ reports R@10 and R@50.
        metrics = [cir_recall_at_k(scores, labels, k) for k in (10, 50)]
    else:
        # CIRR reports R@1, R@5 and R@10.
        metrics = [cir_recall_at_k(scores, labels, k) for k in (1, 5, 10)]

    return metrics
def main(
    llava: bool = False,
    llava_llama3: bool = False,
    lora_path: str = None,
    img_only: bool = False,
    eol2: bool = False,
    name: str = None,
    use_icl: bool = False,
    fiq_two: bool = False,
    batch_size: int = 1,
    bf16: bool = False,
    fp32: bool = False,
    use_4bit: bool = False,
    data: str = None,
    not_save_fp32: bool = False,
    e5_project: str = None,
    debug: bool = False,
    ocr_replace_text: bool = False,
    phi3: bool = False,
    use_e5v: bool = False,
):
    """
    Evaluate a model on the retrieval benchmarks and print/log the results.

    :param llava_llama3: use the LLaVA model with the Llama-3 backbone
    :param phi3: use Phi-3-vision (takes precedence over the LLaVA options)
    :param use_e5v: load ``royokong/e5-v``; implies the Llama-3 prompts
    :param lora_path: LoRA adapter to merge; also names the log file
    :param name: log file used when no ``lora_path`` is given
    :param batch_size: embedding batch size
    :param bf16: load weights as bfloat16
    :param fp32: load weights as float32
    :param data: comma-separated dataset list, e.g.
        ``'coco,fashioniq dress,cirr'``; all benchmarks when empty
    :param debug: restrict the composed-retrieval datasets to 50 rows
    :param ocr_replace_text: render captions as images; forces the dataset
        list to flickr30k and coco
    :raises ValueError: for a dataset name that is not supported

    ``llava``, ``img_only``, ``eol2``, ``use_icl``, ``use_4bit``,
    ``not_save_fp32`` and ``e5_project`` are accepted but not read here; the
    ``fiq_two`` argument is overwritten per dataset.
    """
    global DEBUG, MODEL_TYPE
    DEBUG = debug

    if phi3:
        MODEL_TYPE = 'phi3'
    elif llava_llama3:
        MODEL_TYPE = 'llava_llama3'
    elif use_e5v:
        llava_llama3 = True
        MODEL_TYPE = 'llava_llama3'

    assert MODEL_TYPE in ['llava', 'llava_llama3', 'phi3']

    # set NCCL_DEBUG
    if os.environ.get("NCCL_DEBUG", None) is None:
        os.environ["NCCL_DEBUG"] = "ERROR"

    device = accelerator.device

    model, transform = init_model_and_transform(lora_path, bf16, fp32, use_e5v=use_e5v)
    model.to(device)

    from datasets import disable_caching
    disable_caching()

    datasets = ['flickr30k', 'coco', 'fashioniq dress', 'fashioniq shirt', 'fashioniq toptee', 'cirr']
    if data:
        datasets = data.split(',')

    if ocr_replace_text:
        datasets = ['flickr30k', 'coco']

    # The log target does not depend on the dataset.
    if lora_path is not None:
        checkpoint_name = lora_path.replace('/', '_') + '.txt'
    elif name is not None:
        checkpoint_name = name
    elif use_e5v:
        checkpoint_name = 'e5v.txt'
    else:
        checkpoint_name = None

    # NOTE(review): the '<image>' and '<sent>' markers below were missing from
    # the previous revision of these literals (tag-like tokens had been
    # stripped, leaving empty-string replaces). They are restored on the
    # assumption that '<image>' is the processor's image token and '<sent>'
    # the sentence placeholder substituted when embedding -- confirm.
    all_results = []
    for data in datasets:
        if 'fashioniq' in data:
            data, fiq_data_type = data.split(' ')
            fiq_two = True
        else:
            fiq_data_type = None
            fiq_two = False

        if data == 'flickr30k' or data == 'coco':
            if phi3:
                img_prompt = '<|user|>\n<|image_1|>\nSummary above image in one word:<|end|>\n<|assistant|>\n'
                text_prompt = '<|user|>\n<sent>\nSummary above sentence in one word:<|end|>\n<|assistant|>\n'
            elif llava_llama3:
                img_prompt = llama3_template.format('<image>\nSummary above image in one word: ')
                text_prompt = llama3_template.format('<sent>\nSummary above sentence in one word: ')
            else:
                img_prompt = "[INST] <image>\nSummary above image in one word: [/INST]"
                text_prompt = "[INST] <sent>\nSummary above sentence in one word: [/INST]"

            if accelerator.is_main_process:
                print(img_prompt)
                print(text_prompt)

            metrics = ir(model, transform, img_prompt, text_prompt,
                         data, device, ocr_replace_text, batch_size)
        elif data == 'fashioniq' or data == 'cirr' or data == 'cirrtest':
            if data == 'fashioniq':
                # 'toptee' items are described as shirts in the prompt.
                fiq_data_name = 'shirt' if fiq_data_type == 'toptee' else fiq_data_type
                img_prompt = f"[INST] <image>\n Describe this {fiq_data_name} in one word based on its style: [/INST]"
                text_img_prompt = f"[INST] <image> change the style of this {fiq_data_name} to <sent>\n Desribe this modified {fiq_data_name} in one word based on its style: [/INST]"
            else:
                img_prompt = "[INST] <image>\n Describe this image in one word: [/INST]"
                text_img_prompt = "[INST] <image> Modify this image with \"<sent>\", desribe modified image in one word: [/INST]"

            if llava_llama3:
                img_prompt = img_prompt.replace('[INST] ', '').replace(' [/INST]', '')
                text_img_prompt = text_img_prompt.replace('[INST] ', '').replace(' [/INST]', '')
                img_prompt = llama3_template.format(img_prompt)
                text_img_prompt = llama3_template.format(text_img_prompt)

            if phi3:
                img_prompt = img_prompt.replace('[INST] ', '').replace(' [/INST]', '').replace('<image>', '<|image_1|>')
                text_img_prompt = text_img_prompt.replace('[INST] ', '').replace(' [/INST]', '').replace('<image>', '<|image_1|>')

                img_prompt = '<|user|>\n{} <|end|>\n<|assistant|>\n'.format(img_prompt)
                text_img_prompt = '<|user|>\n{} <|end|>\n<|assistant|>\n'.format(text_img_prompt)

            if accelerator.is_main_process:
                print(img_prompt)
                print(text_img_prompt)

            metrics = cir(model, transform, img_prompt, text_img_prompt, data, fiq_data_type,
                          device,
                          fiq_two=fiq_two,
                          batch_size=batch_size)
        else:
            # Previously an unknown name reused the metrics of the preceding
            # dataset (or failed with NameError on the first one).
            raise ValueError(f'unknown dataset: {data!r}')

        if accelerator.is_main_process:
            print(metrics)
            if data == 'cirrtest':
                # For cirrtest, cir() returns the submission dict, not recalls.
                prefix = checkpoint_name.replace('.txt', '') if checkpoint_name is not None else ''
                with open(prefix + 'cirr_sub.json', 'w') as f:
                    json.dump(metrics, f)
            else:
                all_results.append(log_to_file(data, metrics, checkpoint_name, fiq_data_type=fiq_data_type, orc_replace_text=ocr_replace_text))

    if accelerator.is_main_process:
        print('\n'.join(all_results))


if __name__ == '__main__':
    from fire import Fire
    Fire(main)
#!/usr/bin/env bash
# Launch LoRA fine-tuning (ft_llm.py) through deepspeed.

RUN=e5v-8b

# Extra ft_llm.py flags can be appended to this array.
args=()

BASE_MODEL="models/llava-llama-3-8b"
# Contains '*' and a literal backslash-n, so it must stay quoted wherever it
# is expanded. NOTE(review): the underscores presumably stand in for spaces
# and are converted by ft_llm.py -- confirm.
TEMPLATE='*sent_0*\nSummary_above_sentence_in_one_word:'

BIT=4

R=64
ALPHA=16
BATCH_SIZE=768
MICRO_BATCH_SIZE=24
EPOCH=2
LR=4e-4

echo "$BASE_MODEL"
echo "$TEMPLATE"

echo "$MICRO_BATCH_SIZE" "$BATCH_SIZE"

GPUS=8
NUM_NODES=4

wandb online

NCCL_DEBUG=ERROR deepspeed --num_gpus="$GPUS" --num_nodes="$NUM_NODES" ft_llm.py \
    --base_model "$BASE_MODEL" \
    --data_path 'data/nli_for_simcse.csv' \
    --batch_size "$BATCH_SIZE" \
    --micro_batch_size "$MICRO_BATCH_SIZE" \
    --num_epochs "$EPOCH" \
    --learning_rate "$LR" \
    --cutoff_len 32 \
    --lora_r "$R" \
    --lora_alpha "$ALPHA" \
    --lora_dropout 0.05 \
    --output_dir "$RUN" --is_sentemb \
    --mask_embedding_sentence_template "$TEMPLATE" --use_neg_sentence --save_steps 50 \
    --deepspeed ds.config \
    --lora_target_modules q_proj,k_proj,v_proj,o_proj,gate_proj,down_proj,up_proj --logging_steps 1 --grad_checkpoint \
    --load_kbit "$BIT" \
    "${args[@]}"