├── SentEval
├── senteval
│ ├── tools
│ │ ├── __init__.py
│ │ ├── relatedness.py
│ │ ├── classifier.py
│ │ ├── validation.py
│ │ └── ranking.py
│ ├── __init__.py
│ ├── utils.py
│ ├── trec.py
│ ├── binary.py
│ ├── sst.py
│ ├── mrpc.py
│ ├── snli.py
│ ├── rank.py
│ ├── engine.py
│ ├── probing.py
│ ├── sick.py
│ └── sts.py
├── .gitignore
├── setup.py
├── LICENSE
├── examples
│ ├── skipthought.py
│ ├── googleuse.py
│ ├── gensen.py
│ ├── infersent.py
│ ├── bow.py
│ └── models.py
└── README.md
├── paper
├── paper.pdf
└── appendix.pdf
├── .gitignore
├── requirements.txt
├── config.py
├── args.py
├── eval_mteb.py
├── README.md
├── trainer.py
├── train.py
├── eval_senteval.py
└── model.py
/SentEval/senteval/tools/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/paper/paper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xinghaow99/DenoSent/HEAD/paper/paper.pdf
--------------------------------------------------------------------------------
/paper/appendix.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xinghaow99/DenoSent/HEAD/paper/appendix.pdf
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | results
3 | wandb
4 | mteb_results
5 | SentEval/data/downstream/*
6 | *.sh
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.30.2
2 | mteb==1.0.0
3 | numpy==1.23.4
4 | accelerate
5 | torch==2.0.1
6 | scipy==1.9.2
7 | scikit-learn==1.1.2
8 | datasets
9 | prettytable
10 | wandb
--------------------------------------------------------------------------------
/SentEval/.gitignore:
--------------------------------------------------------------------------------
1 | # SentEval data and .pyc files
2 |
3 |
4 |
5 | # python
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # log files
11 | *.log
12 | *.txt
13 |
14 | # data files
15 | data/senteval_data*
16 | data/downstream/
17 |
--------------------------------------------------------------------------------
/SentEval/senteval/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from __future__ import absolute_import
9 |
10 | from senteval.engine import SE
11 |
--------------------------------------------------------------------------------
/SentEval/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import io
9 | from setuptools import setup, find_packages
10 |
# Read the package README; its text is passed to setup() as long_description.
with io.open('./README.md', encoding='utf-8') as readme_file:
    readme = readme_file.read()

_metadata = dict(
    name='SentEval',
    version='0.1.0',
    url='https://github.com/facebookresearch/SentEval',
    # Install library code only; the example scripts are excluded.
    packages=find_packages(exclude=['examples']),
    # NOTE(review): the LICENSE file shipped next to this setup.py contains a
    # BSD license text, which does not match this string — confirm which
    # license metadata is intended.
    license='Attribution-NonCommercial 4.0 International',
)

setup(long_description=readme, **_metadata)
22 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | from transformers import PretrainedConfig
2 | from typing import Optional
3 |
class DenoSentConfig(PretrainedConfig):
    """Hugging Face configuration object for a DenoSent model.

    Every named keyword is stored unchanged as an attribute of the same name.
    Any other keyword arguments are forwarded to ``PretrainedConfig``.
    """

    def __init__(self,
                 encoder_name_or_path: Optional[str] = None,
                 hidden_size: Optional[int] = 768,
                 max_length: Optional[int] = 32,
                 decoder_num_heads: Optional[int] = 1,
                 decoder_num_layers: Optional[int] = 16,
                 decoder_noise_dropout: Optional[float] = 0.825,
                 pooler: Optional[str] = 'mask',
                 do_contrastive: Optional[bool] = False,
                 do_generative: Optional[bool] = False,
                 prompt_format: Optional[str] = '[X] means [MASK]',
                 contrastive_weight: Optional[float] = 1.0,
                 generative_weight: Optional[float] = 1.0,
                 contrastive_temp: Optional[float] = 0.05,
                 **kwargs):
        # The base class handles its own keywords before the DenoSent-specific
        # ones are attached.
        super().__init__(**kwargs)
        own_settings = {
            'encoder_name_or_path': encoder_name_or_path,
            'hidden_size': hidden_size,
            'max_length': max_length,
            'decoder_num_heads': decoder_num_heads,
            'decoder_num_layers': decoder_num_layers,
            'decoder_noise_dropout': decoder_noise_dropout,
            'pooler': pooler,
            'do_contrastive': do_contrastive,
            'do_generative': do_generative,
            'prompt_format': prompt_format,
            'contrastive_weight': contrastive_weight,
            'generative_weight': generative_weight,
            'contrastive_temp': contrastive_temp,
        }
        # setattr goes through the same attribute-assignment path as
        # `self.x = ...`; the dict keeps the original assignment order.
        for attr_name, attr_value in own_settings.items():
            setattr(self, attr_name, attr_value)
--------------------------------------------------------------------------------
/SentEval/LICENSE:
--------------------------------------------------------------------------------
1 | BSD License
2 |
3 | For SentEval software
4 |
5 | Copyright (c) 2017-present, Facebook, Inc. All rights reserved.
6 |
7 | Redistribution and use in source and binary forms, with or without modification,
8 | are permitted provided that the following conditions are met:
9 |
10 | * Redistributions of source code must retain the above copyright notice, this
11 | list of conditions and the following disclaimer.
12 |
13 | * Redistributions in binary form must reproduce the above copyright notice,
14 | this list of conditions and the following disclaimer in the documentation
15 | and/or other materials provided with the distribution.
16 |
17 | * Neither the name Facebook nor the names of its contributors may be used to
18 | endorse or promote products derived from this software without specific
19 | prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 |
--------------------------------------------------------------------------------
/args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import List, Optional, Tuple, Union, Dict
3 |
@dataclass
class ModelArguments:
    """Command-line options describing the encoder, decoder and training losses."""

    # --- encoder / input ---
    model_name_or_path: Optional[str] = 'bert-base-uncased'
    max_length: Optional[int] = 32
    pooler: Optional[str] = 'cls'
    prompt_format: Optional[str] = '"[X]" means [MASK].'

    # --- decoder ---
    decoder_num_layers: Optional[int] = 16
    decoder_num_heads: Optional[int] = 1
    # NOTE(review): DenoSentConfig names the matching setting
    # `decoder_noise_dropout`; confirm how train.py maps one onto the other.
    decoder_target_dropout: Optional[float] = 0.825

    # --- training objectives ---
    do_contrastive: Optional[bool] = False
    do_generative: Optional[bool] = False
    contrastive_temp: Optional[float] = 0.05
    contrastive_weight: Optional[float] = 1.0
    generative_weight: Optional[float] = 1.0
43 |
44 |
@dataclass
class DatasetArguments:
    """Command-line options selecting the training corpus."""

    # Other values seen in this project: Singhoo/stssickr,
    # princeton-nlp/datasets-for-simcse, bookcorpus
    train_dataset: Optional[str] = field(
        default='Singhoo/denosent_data',
        metadata={'help': 'Can be princeton-nlp/datasets-for-simcse, wiki1m-aug, wiki1m-aug-cleaned, Singhoo/wiki1m_translated, Singhoo/stssickr, bookcorpus.'},
    )
    split: Optional[str] = 'train'
    use_auth_token: Optional[bool] = False
    group: Optional[str] = None
63 |
--------------------------------------------------------------------------------
/eval_mteb.py:
--------------------------------------------------------------------------------
1 |
2 | from mteb import MTEB
3 | import argparse
4 | import logging
5 | from model import DenoSentModel
6 | from config import DenoSentConfig
7 |
logging.basicConfig(level=logging.INFO)

# MTEB task names evaluated by this script, grouped by task type.
TASK_CLASSIFICATION = [
    "AmazonCounterfactualClassification",
    "AmazonReviewsClassification",
    "Banking77Classification",
    "EmotionClassification",
    "MassiveIntentClassification",
    "MassiveScenarioClassification",
    "MTOPDomainClassification",
    "MTOPIntentClassification",
    "ToxicConversationsClassification",
    "TweetSentimentExtractionClassification",
]

TASK_RERANKING = [
    "AskUbuntuDupQuestions",
    "MindSmallReranking",
    "SciDocsRR",
    "StackOverflowDupQuestions",
]

TASK_RETRIEVAL = ["QuoraRetrieval"]

TASK_STS = [
    "SICK-R",
    "STS12",
    "STS13",
    "STS14",
    "STS15",
    "STS16",
    "STSBenchmark",
]

# All tasks, in the order classification, reranking, retrieval, STS.
TASK_LIST = [*TASK_CLASSIFICATION, *TASK_RERANKING, *TASK_RETRIEVAL, *TASK_STS]
47 |
48 |
49 | def main():
50 | parser = argparse.ArgumentParser()
51 | parser.add_argument("--model_name_or_path", type=str,
52 | help="Transformers' model name or path")
53 |
54 | args = parser.parse_args()
55 |
56 | config = DenoSentConfig.from_pretrained(args.model_name_or_path)
57 | model = DenoSentModel.from_pretrained(args.model_name_or_path, config=config)
58 | model = model.to("cuda")
59 | model.eval()
60 |
61 | eval_splits = ["test"]
62 | evaluation = MTEB(tasks=TASK_LIST, task_langs=["en"], task_categories=['S2S'])
63 | evaluation.run(model, overwrite_results=True, batch_size=64, eval_splits=eval_splits, output_folder='mteb_results/'+args.model_name_or_path.split('/')[-1])
64 |
65 | if __name__ == '__main__':
66 | main()
--------------------------------------------------------------------------------
/SentEval/examples/skipthought.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from __future__ import absolute_import, division, unicode_literals
9 |
10 | """
11 | Example of file for SkipThought in SentEval
12 | """
13 | import logging
14 | import sys
15 | sys.setdefaultencoding('utf8')
16 |
17 |
18 | # Set PATHs
19 | PATH_TO_SENTEVAL = '../'
20 | PATH_TO_DATA = '../data/senteval_data/'
21 | PATH_TO_SKIPTHOUGHT = ''
22 |
23 | assert PATH_TO_SKIPTHOUGHT != '', 'Download skipthought and set correct PATH'
24 |
25 | # import skipthought and Senteval
26 | sys.path.insert(0, PATH_TO_SKIPTHOUGHT)
27 | import skipthoughts
28 | sys.path.insert(0, PATH_TO_SENTEVAL)
29 | import senteval
30 |
31 |
def prepare(params, samples):
    """SentEval hook: SkipThought needs no task-specific preparation."""
    return

def batcher(params, batch):
    """Encode a batch of tokenized sentences with the SkipThought model.

    Each sentence (a token list) is joined with spaces. Empty sentences become
    '.'. The model is read from ``params['encoder']``.
    """
    # Tokens are already text, so they are joined directly. Wrapping the result
    # in str(..., errors="ignore") raised TypeError because str() cannot
    # decode a str.
    batch = [' '.join(sent) if sent != [] else '.' for sent in batch]
    embeddings = skipthoughts.encode(params['encoder'], batch,
                                     verbose=False, use_eos=True)
    return embeddings
40 |
41 |
# Set params for SentEval: 10 folds, batcher batches of 512 sentences, and the
# settings for SentEval's downstream classifier.
params_senteval = {
    'task_path': PATH_TO_DATA,
    'usepytorch': True,
    'kfold': 10,
    'batch_size': 512,
    'classifier': {'nhid': 0, 'optim': 'adam', 'batch_size': 64,
                   'tenacity': 5, 'epoch_size': 4},
}
# Set up logger
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)

if __name__ == "__main__":
    # The SkipThought model is loaded only when run as a script.
    params_senteval['encoder'] = skipthoughts.load_model()

    se = senteval.engine.SE(params_senteval, batcher, prepare)
    transfer_tasks = (
        ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']
        + ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC']
        + ['SICKEntailment', 'SICKRelatedness', 'STSBenchmark']
        + ['Length', 'WordContent', 'Depth', 'TopConstituents',
           'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
           'OddManOut', 'CoordinationInversion']
    )
    results = se.eval(transfer_tasks)
    print(results)
62 |
--------------------------------------------------------------------------------
/SentEval/examples/googleuse.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from __future__ import absolute_import, division
9 |
10 | import os
11 | import sys
12 | import logging
13 | import tensorflow as tf
14 | import tensorflow_hub as hub
# Set the verbosity of TensorFlow's Python logger (TF1 `tf.logging` API).
tf.logging.set_verbosity(0)

# Set PATHs
PATH_TO_SENTEVAL = '../'
PATH_TO_DATA = '../data'

# import SentEval
sys.path.insert(0, PATH_TO_SENTEVAL)
import senteval

# No module-level tf.Session is created: make_embed_fn builds its own
# MonitoredSession inside a private graph, so a global session would only
# hold resources without ever being used.
# NOTE(review): TensorFlow is imported before this line, so this may be too
# late to silence its C++ start-up logs; consider exporting the variable in
# the shell or setting it before `import tensorflow`.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
28 |
# SentEval prepare and batcher
def prepare(params, samples):
    """SentEval hook: the Universal Sentence Encoder needs no task-specific setup."""
    return None

def batcher(params, batch):
    """Embed tokenized sentences with the callable stored in params['google_use'].

    Token lists are joined with spaces; an empty sentence is replaced by '.'.
    """
    texts = []
    for tokens in batch:
        texts.append(' '.join(tokens) if tokens != [] else '.')
    return params['google_use'](texts)

def make_embed_fn(module):
    """Load TF-Hub ``module`` in a fresh graph and return a strings -> embeddings callable."""
    with tf.Graph().as_default():
        text_input = tf.placeholder(tf.string)
        hub_encoder = hub.Module(module)
        vectors = hub_encoder(text_input)
        sess = tf.train.MonitoredSession()

    def _embed(texts):
        # Runs the graph built above; `sess` stays alive via this closure.
        return sess.run(vectors, {text_input: texts})

    return _embed
45 |
# Start TF session and load Google Universal Sentence Encoder
encoder = make_embed_fn("https://tfhub.dev/google/universal-sentence-encoder-large/2")

# Set params for SentEval; batcher reads the encoder from params['google_use'].
params_senteval = {
    'task_path': PATH_TO_DATA,
    'usepytorch': True,
    'kfold': 5,
    'classifier': {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
                   'tenacity': 3, 'epoch_size': 2},
    'google_use': encoder,
}

# Set up logger
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)

if __name__ == "__main__":
    se = senteval.engine.SE(params_senteval, batcher, prepare)
    transfer_tasks = (
        ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']
        + ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC']
        + ['SICKEntailment', 'SICKRelatedness', 'STSBenchmark']
        + ['Length', 'WordContent', 'Depth', 'TopConstituents',
           'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
           'OddManOut', 'CoordinationInversion']
    )
    results = se.eval(transfer_tasks)
    print(results)
68 |
--------------------------------------------------------------------------------
/SentEval/examples/gensen.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | """
9 | Clone GenSen repo here: https://github.com/Maluuba/gensen.git
10 | And follow instructions for loading the model used in batcher
11 | """
12 |
13 | from __future__ import absolute_import, division, unicode_literals
14 |
15 | import sys
16 | import logging
17 | # import GenSen package
18 | from gensen import GenSen, GenSenSingle
19 |
# Set PATHs
PATH_TO_SENTEVAL = '../'
PATH_TO_DATA = '../data'

# Put the local SentEval checkout first on the import path, then import it.
sys.path[0:0] = [PATH_TO_SENTEVAL]
import senteval
27 |
# SentEval prepare and batcher
def prepare(params, samples):
    """SentEval hook: GenSen needs no task-specific preparation."""
    return

def batcher(params, batch):
    """Embed a batch of tokenized sentences with the GenSen encoder.

    The encoder is read from ``params['gensen']``. Token lists are joined with
    spaces and empty sentences become '.'. The second value returned by
    ``get_representation`` (``pool='last'``) is used as the embedding.
    """
    batch = [' '.join(sent) if sent != [] else '.' for sent in batch]
    # Encode the prepared batch with the encoder passed in through params.
    # Module-level `gensen` / `sentences` names do not exist in this script.
    _, reps_h_t = params['gensen'].get_representation(
        batch, pool='last', return_numpy=True, tokenize=True
    )
    return reps_h_t
39 |
# Load the two GenSenSingle encoders and wrap them in one GenSen object.
gensen_1 = GenSenSingle(
    model_folder='../data/models',
    filename_prefix='nli_large_bothskip',
    pretrained_emb='../data/embedding/glove.840B.300d.h5'
)
gensen_2 = GenSenSingle(
    model_folder='../data/models',
    filename_prefix='nli_large_bothskip_parse',
    pretrained_emb='../data/embedding/glove.840B.300d.h5'
)
gensen_encoder = GenSen(gensen_1, gensen_2)
# Sentences are encoded only inside `batcher`; nothing is encoded at import
# time (there are no sentences to encode here).

# Set params for SentEval; batcher reads the encoder from params['gensen'].
params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}
params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
                                 'tenacity': 3, 'epoch_size': 2}
params_senteval['gensen'] = gensen_encoder

# Set up logger
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)

if __name__ == "__main__":
    se = senteval.engine.SE(params_senteval, batcher, prepare)
    transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16',
                      'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
                      'SICKEntailment', 'SICKRelatedness', 'STSBenchmark',
                      'Length', 'WordContent', 'Depth', 'TopConstituents',
                      'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
                      'OddManOut', 'CoordinationInversion']
    results = se.eval(transfer_tasks)
    print(results)
75 |
--------------------------------------------------------------------------------
/SentEval/examples/infersent.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | """
9 | InferSent models. See https://github.com/facebookresearch/InferSent.
10 | """
11 |
12 | from __future__ import absolute_import, division, unicode_literals
13 |
14 | import sys
15 | import os
16 | import torch
17 | import logging
18 |
19 | # get models.py from InferSent repo
20 | from models import InferSent
21 |
# Set PATHs
PATH_SENTEVAL = '../'
PATH_TO_DATA = '../data'
PATH_TO_W2V = 'PATH/TO/glove.840B.300d.txt'  # or crawl-300d-2M.vec for V2
MODEL_PATH = 'infersent1.pkl'
V = 1  # version of InferSent

# Both the InferSent checkpoint and the word-vector file must exist locally.
assert all(os.path.isfile(path) for path in (MODEL_PATH, PATH_TO_W2V)), \
    'Set MODEL and GloVe PATHs'

# import senteval from the local checkout
sys.path[0:0] = [PATH_SENTEVAL]
import senteval
35 |
36 |
def prepare(params, samples):
    """Build the InferSent vocabulary from all sentences of the current task."""
    joined = [' '.join(tokens) for tokens in samples]
    params.infersent.build_vocab(joined, tokenize=False)


def batcher(params, batch):
    """Encode tokenized sentences with ``params.infersent``, batch size ``params.batch_size``."""
    joined = [' '.join(tokens) for tokens in batch]
    return params.infersent.encode(joined, bsize=params.batch_size, tokenize=False)
45 |
46 |
47 | """
48 | Evaluation of trained model on Transfer Tasks (SentEval)
49 | """
50 |
51 | # define senteval params
52 | params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}
53 | params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
54 | 'tenacity': 3, 'epoch_size': 2}
55 | # Set up logger
56 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)
57 |
58 | if __name__ == "__main__":
59 | # Load InferSent model
60 | params_model = {'bsize': 64, 'word_emb_dim': 300, 'enc_lstm_dim': 2048,
61 | 'pool_type': 'max', 'dpout_model': 0.0, 'version': V}
62 | model = InferSent(params_model)
63 | model.load_state_dict(torch.load(MODEL_PATH))
64 | model.set_w2v_path(PATH_TO_W2V)
65 |
66 | params_senteval['infersent'] = model.cuda()
67 |
68 | se = senteval.engine.SE(params_senteval, batcher, prepare)
69 | transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16',
70 | 'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
71 | 'SICKEntailment', 'SICKRelatedness', 'STSBenchmark',
72 | 'Length', 'WordContent', 'Depth', 'TopConstituents',
73 | 'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
74 | 'OddManOut', 'CoordinationInversion']
75 | results = se.eval(transfer_tasks)
76 | print(results)
77 |
--------------------------------------------------------------------------------
/SentEval/senteval/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | from __future__ import absolute_import, division, unicode_literals
9 |
10 | import numpy as np
11 | import re
12 | import inspect
13 | from torch import optim
14 |
15 |
def create_dictionary(sentences):
    """Build a frequency-ordered vocabulary from tokenized sentences.

    Args:
        sentences: iterable of token lists.

    Returns:
        (id2word, word2id): a list mapping id -> word and a dict mapping
        word -> id, ordered by decreasing count (ties keep first-seen order).
        Three special tokens get pseudo-counts far above any real frequency,
        so they always take ids 0, 1 and 2.
    """
    words = {}
    for s in sentences:
        for word in s:
            words[word] = words.get(word, 0) + 1
    # Special tokens (sentence start, sentence end, padding). In this copy the
    # literals were garbled into '' (twice) and a bare newline; the values here
    # are assumed to match upstream SentEval — TODO confirm.
    words['<s>'] = 1e9 + 4
    words['</s>'] = 1e9 + 3
    words['<p>'] = 1e9 + 2

    sorted_words = sorted(words.items(), key=lambda x: -x[1])  # inverse sort
    id2word = [w for w, _ in sorted_words]
    word2id = {w: i for i, w in enumerate(id2word)}

    return id2word, word2id
52 |
# Get word vectors from vocabulary (glove, word2vec, fasttext ..)
def get_wordvec(path_to_vec, word2id):
    """Load vectors for the words of ``word2id`` from a text embedding file.

    Each line is parsed as ``word v1 v2 ...``, with the word separated from
    its values by a single space. Lines without a value part are skipped.

    Args:
        path_to_vec: path to a UTF-8 text embedding file.
        word2id: vocabulary; only words present in it are kept.

    Returns:
        dict mapping word -> 1-D float numpy array.
    """
    # io and logging are not imported in the visible header of this module.
    import io
    import logging

    word_vec = {}
    with io.open(path_to_vec, 'r', encoding='utf-8') as f:
        # if word2vec or fasttext file : skip first line "next(f)"
        for line in f:
            parts = line.rstrip('\n').split(' ', 1)
            if len(parts) != 2:
                # Blank or malformed line: nothing to parse.
                continue
            word, vec = parts
            if word in word2id:
                word_vec[word] = np.array(vec.split(), dtype=float)

    logging.info('Found %d words with word vectors, out of %d words',
                 len(word_vec), len(word2id))
    return word_vec
67 |
68 |
# SentEval prepare and batcher
def prepare(params, samples):
    """Build the task vocabulary and attach word vectors to ``params``.

    Sets ``params.word2id``, ``params.word_vec`` and ``params.wvec_dim`` (300).
    """
    vocabulary = create_dictionary(samples)
    params.word2id = vocabulary[1]
    # NOTE(review): PATH_TO_VEC is not defined in the visible part of this module.
    params.word_vec = get_wordvec(PATH_TO_VEC, params.word2id)
    params.wvec_dim = 300
75 |
def batcher(params, batch):
    """Embed each sentence as the mean of its known word vectors.

    An empty sentence is treated as ['.']. A sentence with no word in
    ``params.word_vec`` becomes a zero vector of length ``params.wvec_dim``.
    Returns an array with one row per sentence (``np.vstack`` of the rows).
    """
    rows = []
    for tokens in batch:
        if tokens == []:
            tokens = ['.']
        known = [params.word_vec[w] for w in tokens if w in params.word_vec]
        if not known:
            known = [np.zeros(params.wvec_dim)]
        rows.append(np.mean(known, 0))
    return np.vstack(rows)
93 |
94 |
# Set params for SentEval
# NOTE(review): PATH_TO_DATA, logging and senteval are not defined or imported
# in the visible part of this module, so this section raises NameError on
# import. It looks like example-script code that ended up in utils.py —
# confirm and move it.
params_senteval = {
    'task_path': PATH_TO_DATA,
    'usepytorch': True,
    'kfold': 5,
    'classifier': {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
                   'tenacity': 3, 'epoch_size': 2},
}

# Set up logger
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)

if __name__ == "__main__":
    se = senteval.engine.SE(params_senteval, batcher, prepare)
    transfer_tasks = (
        ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']
        + ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC']
        + ['SICKEntailment', 'SICKRelatedness', 'STSBenchmark']
        + ['Length', 'WordContent', 'Depth', 'TopConstituents',
           'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
           'OddManOut', 'CoordinationInversion']
    )
    results = se.eval(transfer_tasks)
    print(results)
113 |
--------------------------------------------------------------------------------
/SentEval/senteval/binary.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | '''
9 | Binary classifier and corresponding datasets : MR, CR, SUBJ, MPQA
10 | '''
11 | from __future__ import absolute_import, division, unicode_literals
12 |
13 | import io
14 | import os
15 | import numpy as np
16 | import logging
17 |
18 | from senteval.tools.validation import InnerKFoldClassifier
19 |
20 |
class BinaryClassifierEval(object):
    """Shared driver for the binary sentence-classification tasks.

    ``pos`` samples are labelled 1 and ``neg`` samples 0. All embeddings are
    scored with ``InnerKFoldClassifier`` using ``params.kfold`` folds.
    """

    def __init__(self, pos, neg, seed=1111):
        self.seed = seed
        self.samples, self.labels = pos + neg, [1] * len(pos) + [0] * len(neg)
        self.n_samples = len(self.samples)

    def do_prepare(self, params, prepare):
        """Call the user ``prepare(params, samples)`` hook with every sample.

        ``prepare`` may store what it builds (e.g. ``params.word2id``) on
        ``params``; ``batcher`` can use it later.
        """
        return prepare(params, self.samples)

    def loadFile(self, fpath):
        """Return one whitespace-split token list per line of ``fpath`` (latin-1)."""
        with io.open(fpath, 'r', encoding='latin-1') as f:
            return [line.split() for line in f.read().splitlines()]

    def run(self, params, batcher):
        """Embed all samples with ``batcher`` and return k-fold dev/test accuracy.

        Returns a dict with ``devacc``, ``acc`` (test accuracy) and
        ``ndev``/``ntest``, both set to the full sample count.
        """
        enc_input = []
        # Sort by length (then label) to reduce padding within each batch.
        sorted_corpus = sorted(zip(self.samples, self.labels),
                               key=lambda z: (len(z[0]), z[1]))
        sorted_samples = [x for (x, y) in sorted_corpus]
        sorted_labels = [y for (x, y) in sorted_corpus]
        logging.info('Generating sentence embeddings')
        for ii in range(0, self.n_samples, params.batch_size):
            batch = sorted_samples[ii:ii + params.batch_size]
            embeddings = batcher(params, batch)
            enc_input.append(embeddings)
        enc_input = np.vstack(enc_input)
        logging.info('Generated sentence embeddings')

        config = {'nclasses': 2, 'seed': self.seed,
                  'usepytorch': params.usepytorch,
                  'classifier': params.classifier,
                  'nhid': params.nhid, 'kfold': params.kfold}
        clf = InnerKFoldClassifier(enc_input, np.array(sorted_labels), config)
        devacc, testacc = clf.run()
        logging.debug('Dev acc : {0} Test acc : {1}\n'.format(devacc, testacc))
        return {'devacc': devacc, 'acc': testacc, 'ndev': self.n_samples,
                'ntest': self.n_samples}


# The subclasses name themselves explicitly in super(): super(self.__class__,
# self) recurses forever as soon as any of them is subclassed again.

class CREval(BinaryClassifierEval):
    """Customer reviews (CR): custrev.pos -> 1, custrev.neg -> 0."""

    def __init__(self, task_path, seed=1111):
        logging.debug('***** Transfer task : CR *****\n\n')
        pos = self.loadFile(os.path.join(task_path, 'custrev.pos'))
        neg = self.loadFile(os.path.join(task_path, 'custrev.neg'))
        super(CREval, self).__init__(pos, neg, seed)


class MREval(BinaryClassifierEval):
    """Movie reviews (MR): rt-polarity.pos -> 1, rt-polarity.neg -> 0."""

    def __init__(self, task_path, seed=1111):
        logging.debug('***** Transfer task : MR *****\n\n')
        pos = self.loadFile(os.path.join(task_path, 'rt-polarity.pos'))
        neg = self.loadFile(os.path.join(task_path, 'rt-polarity.neg'))
        super(MREval, self).__init__(pos, neg, seed)


class SUBJEval(BinaryClassifierEval):
    """Subjectivity (SUBJ): objective sentences -> 1, subjective -> 0."""

    def __init__(self, task_path, seed=1111):
        logging.debug('***** Transfer task : SUBJ *****\n\n')
        obj = self.loadFile(os.path.join(task_path, 'subj.objective'))
        subj = self.loadFile(os.path.join(task_path, 'subj.subjective'))
        super(SUBJEval, self).__init__(obj, subj, seed)


class MPQAEval(BinaryClassifierEval):
    """Opinion polarity (MPQA): mpqa.pos -> 1, mpqa.neg -> 0."""

    def __init__(self, task_path, seed=1111):
        logging.debug('***** Transfer task : MPQA *****\n\n')
        pos = self.loadFile(os.path.join(task_path, 'mpqa.pos'))
        neg = self.loadFile(os.path.join(task_path, 'mpqa.neg'))
        super(MPQAEval, self).__init__(pos, neg, seed)
93 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DenoSent: A Denoising Objective for Self-Supervised Sentence Representation Learning
2 |
3 | Official repo for our AAAI 2024 paper: [DenoSent: A Denoising Objective for Self-Supervised Sentence Representation Learning](https://arxiv.org/abs/2401.13621).
4 |
5 | ## Getting Started
6 |
7 | Run `pip install -r requirements.txt` to prepare the environment.
8 |
9 | Use the script from the [SimCSE repo](https://github.com/princeton-nlp/SimCSE) to download the datasets for SentEval evaluation:
10 | ```
11 | cd SentEval/data/downstream/
12 | bash download_dataset.sh
13 | ```
14 |
15 | ## Access Our Model and Dataset from Huggingface🤗
16 | Both our [model checkpoint](https://huggingface.co/Singhoo/denosent-bert-base) and [dataset](https://huggingface.co/datasets/Singhoo/denosent_data) are available on 🤗.
17 |
18 | Generate embeddings with DenoSent:
19 | ```
20 | from transformers import AutoModel
21 |
22 | model = AutoModel.from_pretrained("Singhoo/denosent-bert-base", trust_remote_code=True)
23 |
24 | sentences = [
25 | "The curious cat tiptoed across the creaky wooden floor, pausing to inspect a fluttering curtain.",
26 | "A lone hiker stood atop the misty mountain, marveling at the tapestry of stars unfolding above."
27 | ]
28 |
29 | embeddings = model.encode(sentences)
30 | print(embeddings)
31 |
32 | # Expected output
33 | # tensor([[ 0.3314, -0.2520, 0.4150, ..., 0.1575, -0.1235, -0.1226],
34 | # [ 0.5128, -0.0051, 0.2179, ..., 0.1010, 0.1654, -0.3872]])
35 | ```
36 |
37 | ## Evaluation
38 |
39 | ### Run Evaluation with SentEval
40 | ```
41 | python eval_senteval.py \
42 | --model_name_or_path Singhoo/denosent-bert-base \
43 | --task_set sts \
44 | --mode test
45 | ```
46 | This checkpoint has slightly higher STS results than those reported in the paper.
47 | ```
48 | ------ test ------
49 | +-------+-------+-------+-------+-------+--------------+-----------------+-------+
50 | | STS12 | STS13 | STS14 | STS15 | STS16 | STSBenchmark | SICKRelatedness | Avg. |
51 | +-------+-------+-------+-------+-------+--------------+-----------------+-------+
52 | | 75.48 | 83.82 | 77.54 | 84.76 | 80.16 | 81.20 | 73.97 | 79.56 |
53 | +-------+-------+-------+-------+-------+--------------+-----------------+-------+
54 | ```
55 |
56 | ### Run evaluation with MTEB
57 | ```
58 | python eval_mteb.py \
59 | --model_name_or_path Singhoo/denosent-bert-base
60 | ```
61 | Evaluation results for MTEB will appear in a separate directory `mteb_results`.
62 |
63 | ## Train Your Own DenoSent Models
64 | Run the following command to train your own models. Try out different hyperparameters as you like. The dataset will be automatically downloaded from Huggingface.
65 | ```
66 | python \
67 | train.py \
68 | --train_dataset Singhoo/denosent_data \
69 | --torch_compile True \
70 | --model_name_or_path bert-base-uncased \
71 | --max_length 32 \
72 | --decoder_num_layers 16 \
73 | --decoder_num_heads 1 \
74 | --decoder_target_dropout 0.825 \
75 | --pooler mask \
76 | --output_dir results \
77 | --overwrite_output_dir \
78 | --per_device_train_batch_size 64 \
79 | --per_device_eval_batch_size 256 \
80 | --learning_rate 4e-5 \
81 | --lr_scheduler_type constant_with_warmup \
82 | --do_train \
83 | --do_eval \
84 | --evaluation_strategy steps \
85 | --eval_steps 50 \
86 | --save_strategy steps \
87 | --save_steps 50 \
88 | --num_train_epochs 1 \
89 | --metric_for_best_model eval_avg_sts \
90 | --prompt_format '"[X]" means [MASK].' \
91 | --do_contrastive \
92 | --do_generative \
93 | --save_total_limit 1 \
94 | --contrastive_temp 0.05 \
95 | --warmup_steps 500 \
96 | --contrastive_weight 5 \
97 | --generative_weight 7 \
98 | --max_steps 5000 \
99 | --load_best_model_at_end
100 | ```
101 |
102 | ## Acknowledgements
103 |
104 | We use the [SentEval toolkit](https://github.com/facebookresearch/SentEval) and the [MTEB toolkit](https://github.com/embeddings-benchmark/mteb) for evaluations, and we adopt the modified version of SentEval from the [SimCSE repository](https://github.com/princeton-nlp/SimCSE).
105 |
--------------------------------------------------------------------------------
/SentEval/senteval/sst.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | '''
9 | SST - binary classification
10 | '''
11 |
12 | from __future__ import absolute_import, division, unicode_literals
13 |
14 | import os
15 | import io
16 | import logging
17 | import numpy as np
18 |
19 | from senteval.tools.validation import SplitClassifier
20 |
21 |
class SSTEval(object):
    """SentEval task for SST sentiment classification (binary or 5-class).

    Sentences are embedded with the user-supplied ``batcher`` and a
    ``SplitClassifier`` is trained on train, tuned on dev and scored on test.
    """

    def __init__(self, task_path, nclasses=2, seed=1111):
        self.seed = seed

        # binary or fine-grained
        assert nclasses in [2, 5]
        self.nclasses = nclasses
        self.task_name = 'Binary' if self.nclasses == 2 else 'Fine-Grained'
        logging.debug('***** Transfer task : SST %s classification *****\n\n', self.task_name)

        train = self.loadFile(os.path.join(task_path, 'sentiment-train'))
        dev = self.loadFile(os.path.join(task_path, 'sentiment-dev'))
        test = self.loadFile(os.path.join(task_path, 'sentiment-test'))
        self.sst_data = {'train': train, 'dev': dev, 'test': test}

    def do_prepare(self, params, prepare):
        """Pass every tokenized sentence (train, dev, test) to ``prepare``."""
        samples = self.sst_data['train']['X'] + self.sst_data['dev']['X'] + \
            self.sst_data['test']['X']
        return prepare(params, samples)

    def loadFile(self, fpath):
        """Read one SST split file.

        Binary files hold a tab-separated ``sentence, label`` per line;
        fine-grained files hold ``label sentence`` separated by the first
        space. Blank lines are skipped.

        Returns:
            ``{'X': [token lists], 'y': [int labels]}``.
        """
        sst_data = {'X': [], 'y': []}
        with io.open(fpath, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline) instead
                    # of failing with an IndexError below.
                    continue
                if self.nclasses == 2:
                    sample = line.split('\t')
                    sst_data['y'].append(int(sample[1]))
                    sst_data['X'].append(sample[0].split())
                elif self.nclasses == 5:
                    sample = line.split(' ', 1)
                    sst_data['y'].append(int(sample[0]))
                    sst_data['X'].append(sample[1].split())
        # Sanity check: the highest label must be the last class index.
        assert max(sst_data['y']) == self.nclasses - 1
        return sst_data

    def run(self, params, batcher):
        """Embed all splits, train the classifier and return dev/test accuracy."""
        sst_embed = {'train': {}, 'dev': {}, 'test': {}}
        bsize = params.batch_size

        for key in self.sst_data:
            logging.info('Computing embedding for %s', key)
            # Sort by (length, label) so batches hold similar-length
            # sentences, which reduces padding.
            sorted_data = sorted(zip(self.sst_data[key]['X'],
                                     self.sst_data[key]['y']),
                                 key=lambda z: (len(z[0]), z[1]))
            self.sst_data[key]['X'], self.sst_data[key]['y'] = map(list, zip(*sorted_data))

            sentences = self.sst_data[key]['X']
            batches = [batcher(params, sentences[ii:ii + bsize])
                       for ii in range(0, len(self.sst_data[key]['y']), bsize)]
            sst_embed[key]['X'] = np.vstack(batches)
            sst_embed[key]['y'] = np.array(self.sst_data[key]['y'])
            logging.info('Computed %s embeddings', key)

        config_classifier = {'nclasses': self.nclasses, 'seed': self.seed,
                             'usepytorch': params.usepytorch,
                             'classifier': params.classifier}

        clf = SplitClassifier(X={'train': sst_embed['train']['X'],
                                 'valid': sst_embed['dev']['X'],
                                 'test': sst_embed['test']['X']},
                              y={'train': sst_embed['train']['y'],
                                 'valid': sst_embed['dev']['y'],
                                 'test': sst_embed['test']['y']},
                              config=config_classifier)

        devacc, testacc = clf.run()
        # Single-line literal: the old backslash continuation inside the
        # string embedded the source indentation into the log message.
        logging.debug('\nDev acc : {0} Test acc : {1} for SST {2} classification\n'
                      .format(devacc, testacc, self.task_name))

        return {'devacc': devacc, 'acc': testacc,
                'ndev': len(sst_embed['dev']['X']),
                'ntest': len(sst_embed['test']['X'])}
97 |
--------------------------------------------------------------------------------
/SentEval/senteval/mrpc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | '''
9 | MRPC : Microsoft Research Paraphrase (detection) Corpus
10 | '''
11 | from __future__ import absolute_import, division, unicode_literals
12 |
13 | import os
14 | import logging
15 | import numpy as np
16 | import io
17 |
18 | from senteval.tools.validation import KFoldClassifier
19 |
20 | from sklearn.metrics import f1_score
21 |
22 |
class MRPCEval(object):
    """SentEval task for MRPC paraphrase detection.

    Both sentences of every pair are embedded with the user-supplied
    ``batcher``; a pair becomes the feature vector ``[|a - b|, a * b]`` and a
    k-fold classifier is trained on the train split and scored (accuracy and
    F1) on the test split.
    """

    def __init__(self, task_path, seed=1111):
        logging.info('***** Transfer task : MRPC *****\n\n')
        self.seed = seed
        split_files = (('train', 'msr_paraphrase_train.txt'),
                       ('test', 'msr_paraphrase_test.txt'))
        loaded = {}
        for split_name, file_name in split_files:
            loaded[split_name] = self.loadFile(os.path.join(task_path, file_name))
        self.mrpc_data = loaded

    def do_prepare(self, params, prepare):
        """Hand every sentence (both sides, train then test) to ``prepare``."""
        # TODO : Should we separate samples in "train, test"?
        samples = []
        for split_name in ('train', 'test'):
            split = self.mrpc_data[split_name]
            samples += split['X_A'] + split['X_B']
        return prepare(params, samples)

    def loadFile(self, fpath):
        """Parse a tab-separated MRPC file into token lists and int labels.

        Column 0 is the label and columns 3/4 are the two sentences; the first
        row is a header and is discarded.
        """
        first_side, second_side, labels = [], [], []
        with io.open(fpath, 'r', encoding='utf-8') as handle:
            for fields in (row.strip().split('\t') for row in handle):
                first_side.append(fields[3].split())
                second_side.append(fields[4].split())
                labels.append(fields[0])
        # Row 0 is the column header, so skip it in every column.
        return {'X_A': first_side[1:],
                'X_B': second_side[1:],
                'y': [int(value) for value in labels[1:]]}

    @staticmethod
    def _pair_features(emb_a, emb_b):
        """Build the row-wise pair representation ``[|a - b|, a * b]``."""
        return np.hstack((np.abs(emb_a - emb_b), emb_a * emb_b))

    def run(self, params, batcher):
        """Embed both splits, run k-fold classification and report acc/F1."""
        mrpc_embed = {'train': {}, 'test': {}}

        for key in self.mrpc_data:
            logging.info('Computing embedding for {0}'.format(key))
            split = self.mrpc_data[key]
            # Sort to reduce padding
            ordered = sorted(zip(split['X_A'], split['X_B'], split['y']),
                             key=lambda triple: (len(triple[0]), len(triple[1]), triple[2]))
            text_data = {
                'A': [triple[0] for triple in ordered],
                'B': [triple[1] for triple in ordered],
                'y': [triple[2] for triple in ordered],
            }

            n_pairs = len(text_data['y'])
            step = params.batch_size
            for side in ('A', 'B'):
                chunks = [batcher(params, text_data[side][start:start + step])
                          for start in range(0, n_pairs, step)]
                mrpc_embed[key][side] = np.vstack(chunks)
            mrpc_embed[key]['y'] = np.array(text_data['y'])
            logging.info('Computed {0} embeddings'.format(key))

        trainA, trainB = mrpc_embed['train']['A'], mrpc_embed['train']['B']
        testA, testB = mrpc_embed['test']['A'], mrpc_embed['test']['B']
        trainF = self._pair_features(trainA, trainB)
        testF = self._pair_features(testA, testB)
        trainY = mrpc_embed['train']['y']
        testY = mrpc_embed['test']['y']

        config = {'nclasses': 2, 'seed': self.seed,
                  'usepytorch': params.usepytorch,
                  'classifier': params.classifier,
                  'nhid': params.nhid, 'kfold': params.kfold}
        clf = KFoldClassifier(train={'X': trainF, 'y': trainY},
                              test={'X': testF, 'y': testY}, config=config)

        devacc, testacc, yhat = clf.run()
        testf1 = round(100 * f1_score(testY, yhat), 2)
        logging.debug('Dev acc : {0} Test acc {1}; Test F1 {2} for MRPC.\n'
                      .format(devacc, testacc, testf1))
        return {'devacc': devacc, 'acc': testacc, 'f1': testf1,
                'ndev': len(trainA), 'ntest': len(testA)}
105 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | from transformers import Trainer
2 | from transformers.trainer import unwrap_model
3 | from typing import List, Optional, Dict
4 | import wandb
5 |
6 | import sys
7 |
8 | from torch.utils.data.dataset import Dataset
9 |
10 | from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
11 |
12 |
13 |
14 | # Set path to SentEval
15 | PATH_TO_SENTEVAL = './SentEval'
16 | PATH_TO_DATA = './SentEval/data'
17 |
18 | # Import SentEval
19 | sys.path.insert(0, PATH_TO_SENTEVAL)
20 | import senteval
21 |
22 | from mteb import MTEB
23 |
class MyTrainer(Trainer):
    """``Trainer`` whose evaluation runs SentEval STS dev tasks.

    ``evaluate`` does not use the eval dataset. It encodes the STSBenchmark and
    SICKRelatedness dev sets through ``self.model.encode`` and reports their
    Spearman correlations plus their mean. Running best values are kept on the
    instance and mirrored to the wandb run summary when a run is active.
    ``compute_loss`` additionally forwards the global/max step to the model.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Best dev scores seen so far. Start at -inf so the first evaluation
        # is always recorded, even when a correlation comes out negative.
        self.best_stsb_spearman = float('-inf')
        self.best_sickr_spearman = float('-inf')
        self.best_avg_sts = float('-inf')

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """Score the model on SentEval STS dev sets and log the metrics.

        Args:
            eval_dataset: Ignored; accepted for compatibility with
                ``Trainer.evaluate``.
            ignore_keys: Ignored; accepted for compatibility with
                ``Trainer.evaluate``.
            metric_key_prefix: Prefix for the returned metric names. With the
                default ``"eval"``, the keys are ``eval_stsb_spearman``,
                ``eval_sickr_spearman`` and ``eval_avg_sts``.

        Returns:
            Dict mapping metric names to dev-set Spearman correlations.
        """
        # SentEval prepare and batcher
        def prepare(params, samples):
            return

        def batcher(params, batch):
            # SentEval passes token lists; re-join them into plain sentences.
            sentences = [' '.join(s) for s in batch]
            return self.model.encode(sentences, len(sentences))

        # Set params for SentEval (fastmode)
        params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}
        params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
                                'tenacity': 3, 'epoch_size': 2}

        se = senteval.engine.SE(params, batcher, prepare)
        self.model.eval()
        results = se.eval(['STSBenchmark', 'SICKRelatedness'])

        stsb_spearman = results['STSBenchmark']['dev']['spearman'][0]
        sickr_spearman = results['SICKRelatedness']['dev']['spearman'][0]
        avg_sts = (stsb_spearman + sickr_spearman) / 2
        metrics = {
            f"{metric_key_prefix}_stsb_spearman": stsb_spearman,
            f"{metric_key_prefix}_sickr_spearman": sickr_spearman,
            f"{metric_key_prefix}_avg_sts": avg_sts,
        }

        # max(best, new) keeps the current best when `new` is NaN, matching
        # the previous `if new > best` updates.
        self.best_stsb_spearman = max(self.best_stsb_spearman, stsb_spearman)
        self.best_sickr_spearman = max(self.best_sickr_spearman, sickr_spearman)
        self.best_avg_sts = max(self.best_avg_sts, avg_sts)
        # Only touch the wandb summary when a run exists; wandb.run is None
        # before wandb.init().
        if wandb.run is not None:
            wandb.run.summary["best_stsb_spearman"] = self.best_stsb_spearman
            wandb.run.summary["best_sickr_spearman"] = self.best_sickr_spearman
            wandb.run.summary["best_avg_sts"] = self.best_avg_sts
        self.log(metrics)
        return metrics

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Differs from the stock implementation only in that the model is also
        called with ``global_step`` and ``max_steps`` from the trainer state.

        Subclass and override for custom behavior.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs, global_step=self.state.global_step, max_steps=self.state.max_steps)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss
107 |
--------------------------------------------------------------------------------
/SentEval/senteval/snli.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | '''
9 | SNLI - Entailment
10 | '''
11 | from __future__ import absolute_import, division, unicode_literals
12 |
13 | import codecs
14 | import os
15 | import io
16 | import copy
17 | import logging
18 | import numpy as np
19 |
20 | from senteval.tools.validation import SplitClassifier
21 |
22 |
class SNLIEval(object):
    """SentEval task for SNLI 3-way entailment classification.

    Premise/hypothesis pairs are embedded with ``batcher``; the features
    ``[u, v, u * v, |u - v|]`` feed a ``SplitClassifier``.
    """

    def __init__(self, taskpath, seed=1111):
        logging.debug('***** Transfer task : SNLI Entailment*****\n\n')
        self.seed = seed
        train1 = self.loadFile(os.path.join(taskpath, 's1.train'))
        train2 = self.loadFile(os.path.join(taskpath, 's2.train'))
        trainlabels = self._load_labels(os.path.join(taskpath, 'labels.train'))

        valid1 = self.loadFile(os.path.join(taskpath, 's1.dev'))
        valid2 = self.loadFile(os.path.join(taskpath, 's2.dev'))
        validlabels = self._load_labels(os.path.join(taskpath, 'labels.dev'))

        test1 = self.loadFile(os.path.join(taskpath, 's1.test'))
        test2 = self.loadFile(os.path.join(taskpath, 's2.test'))
        testlabels = self._load_labels(os.path.join(taskpath, 'labels.test'))

        # sort data (by s2 first) to reduce padding
        train2, train1, trainlabels = self._sort_by_length(train2, train1, trainlabels)
        valid2, valid1, validlabels = self._sort_by_length(valid2, valid1, validlabels)
        test2, test1, testlabels = self._sort_by_length(test2, test1, testlabels)

        self.samples = train1 + train2 + valid1 + valid2 + test1 + test2
        self.data = {'train': (train1, train2, trainlabels),
                     'valid': (valid1, valid2, validlabels),
                     'test': (test1, test2, testlabels)
                     }

    @staticmethod
    def _load_labels(fpath):
        """Return the lines of a UTF-8 label file, closing the file afterwards."""
        with io.open(fpath, encoding='utf-8') as f:
            return f.read().splitlines()

    @staticmethod
    def _sort_by_length(first, second, labels):
        """Jointly sort three aligned lists by (len(first), len(second), label)."""
        ordered = sorted(zip(first, second, labels),
                         key=lambda z: (len(z[0]), len(z[1]), z[2]))
        return tuple(map(list, zip(*ordered)))

    def do_prepare(self, params, prepare):
        """Pass every sentence of every split to ``prepare``."""
        return prepare(params, self.samples)

    def loadFile(self, fpath):
        """Read a latin-1 text file and return one token list per line."""
        with codecs.open(fpath, 'rb', 'latin-1') as f:
            return [line.split() for line in
                    f.read().splitlines()]

    def run(self, params, batcher):
        """Embed all pairs, train the classifier and return dev/test accuracy."""
        self.X, self.y = {}, {}
        dico_label = {'entailment': 0, 'neutral': 1, 'contradiction': 2}
        for key in self.data:
            input1, input2, mylabels = self.data[key]
            enc_input = []
            n_labels = len(mylabels)
            for ii in range(0, n_labels, params.batch_size):
                batch1 = input1[ii:ii + params.batch_size]
                batch2 = input2[ii:ii + params.batch_size]

                if len(batch1) == len(batch2) and len(batch1) > 0:
                    enc1 = batcher(params, batch1)
                    enc2 = batcher(params, batch2)
                    enc_input.append(np.hstack((enc1, enc2, enc1 * enc2,
                                                np.abs(enc1 - enc2))))
                # Same condition as the former
                # (ii*batch_size) % (20000*batch_size) == 0.
                if ii % 20000 == 0:
                    logging.info("PROGRESS (encoding): %.2f%%",
                                 100 * ii / n_labels)
            self.X[key] = np.vstack(enc_input)
            self.y[key] = [dico_label[y] for y in mylabels]

        config = {'nclasses': 3, 'seed': self.seed,
                  'usepytorch': params.usepytorch,
                  'cudaEfficient': True,
                  'nhid': params.nhid, 'noreg': True}

        config_classifier = copy.deepcopy(params.classifier)
        config_classifier['max_epoch'] = 15
        config_classifier['epoch_size'] = 1
        config['classifier'] = config_classifier

        clf = SplitClassifier(self.X, self.y, config)
        devacc, testacc = clf.run()
        logging.debug('Dev acc : {0} Test acc : {1} for SNLI\n'
                      .format(devacc, testacc))
        return {'devacc': devacc, 'acc': testacc,
                'ndev': len(self.data['valid'][0]),
                'ntest': len(self.data['test'][0])}
114 |
--------------------------------------------------------------------------------
/SentEval/senteval/rank.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | '''
9 | Image-Caption Retrieval with COCO dataset
10 | '''
11 | from __future__ import absolute_import, division, unicode_literals
12 |
13 | import os
14 | import sys
15 | import logging
16 | import numpy as np
17 |
18 | try:
19 | import cPickle as pickle
20 | except ImportError:
21 | import pickle
22 |
23 | from senteval.tools.ranking import ImageSentenceRankingPytorch
24 |
25 |
class ImageCaptionRetrievalEval(object):
    """SentEval image-caption retrieval task on COCO.

    Each image contributes its first five captions. Caption embeddings from
    ``batcher`` are paired with the stored image features and passed to
    ``ImageSentenceRankingPytorch`` for image<->text retrieval.
    """

    def __init__(self, task_path, seed=1111):
        logging.debug('***** Transfer task: Image Caption Retrieval *****\n\n')

        # Get captions and image features
        self.seed = seed
        train, dev, test = self.loadFile(task_path)
        self.coco_data = {'train': train, 'dev': dev, 'test': test}

    def do_prepare(self, params, prepare):
        """Pass every caption (train, dev, test) to ``prepare`` and return its result."""
        samples = self.coco_data['train']['sent'] + \
            self.coco_data['dev']['sent'] + \
            self.coco_data['test']['sent']
        return prepare(params, samples)

    def loadFile(self, fpath):
        """Load ``{train,valid,test}.pkl`` from ``fpath``.

        Returns:
            Three dicts (train, valid, test), each with ``'sent'`` (token lists,
            five per image) and ``'imgfeat'`` (float32 array, one row per
            caption, repeating the image's features).
        """
        coco = {}

        for split in ['train', 'valid', 'test']:
            list_sent = []
            list_img_feat = []
            pkl_path = os.path.join(fpath, split + '.pkl')
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # trusted task data.
            if sys.version_info < (3, 0):
                with open(pkl_path) as f:
                    cocodata = pickle.load(f)
            else:
                with open(pkl_path, 'rb') as f:
                    cocodata = pickle.load(f, encoding='latin1')

            for imgkey in range(len(cocodata['features'])):
                caption_ids = cocodata['image_to_caption_ids'][imgkey]
                assert len(caption_ids) >= 5, caption_ids
                for captkey in caption_ids[0:5]:
                    sent = cocodata['captions'][captkey]['cleaned_caption']
                    sent += ' .'  # add punctuation to end of sentence in COCO
                    if sys.version_info < (3, 0):
                        # Python 2 only: keep captions as UTF-8 byte strings.
                        sent = sent.encode('utf-8')
                    # On Python 3 the tokens stay `str`; encoding first would
                    # produce `bytes` tokens that a batcher doing
                    # ' '.join(tokens) cannot handle.
                    list_sent.append(sent.split())
                    list_img_feat.append(cocodata['features'][imgkey])
            assert len(list_sent) == len(list_img_feat) and \
                len(list_sent) % 5 == 0
            list_img_feat = np.array(list_img_feat).astype('float32')
            coco[split] = {'sent': list_sent, 'imgfeat': list_img_feat}
        return coco['train'], coco['valid'], coco['test']

    def _embed_sentences(self, params, batcher, sentences):
        """Embed token lists in length-sorted batches; return rows in input order."""
        # Sort indices by caption length so each batch holds similar-length
        # captions (less padding). A plain Python sort avoids building a
        # ragged numpy array, and sorts by length rather than lexicographically.
        order = sorted(range(len(sentences)), key=lambda i: len(sentences[i]))
        sorted_sents = [sentences[i] for i in order]
        step = params.batch_size
        chunks = [batcher(params, sorted_sents[ii:ii + step])
                  for ii in range(0, len(sorted_sents), step)]
        # argsort of a permutation is its inverse: it maps rows back to the
        # original caption order.
        return np.vstack(chunks)[np.argsort(order)]

    def run(self, params, batcher):
        """Embed captions, train the ranking model and return retrieval scores."""
        coco_embed = {'train': {'sentfeat': [], 'imgfeat': []},
                      'dev': {'sentfeat': [], 'imgfeat': []},
                      'test': {'sentfeat': [], 'imgfeat': []}}

        for key in self.coco_data:
            logging.info('Computing embedding for {0}'.format(key))
            coco_embed[key]['sentfeat'] = self._embed_sentences(
                params, batcher, self.coco_data[key]['sent'])
            coco_embed[key]['imgfeat'] = np.array(self.coco_data[key]['imgfeat'])
            logging.info('Computed {0} embeddings'.format(key))

        config = {'seed': self.seed, 'projdim': 1000, 'margin': 0.2}
        clf = ImageSentenceRankingPytorch(train=coco_embed['train'],
                                          valid=coco_embed['dev'],
                                          test=coco_embed['test'],
                                          config=config)

        bestdevscore, r1_i2t, r5_i2t, r10_i2t, medr_i2t, \
            r1_t2i, r5_t2i, r10_t2i, medr_t2i = clf.run()

        logging.debug("\nTest scores | Image to text: {0}, {1}, {2}, {3}"
                      .format(r1_i2t, r5_i2t, r10_i2t, medr_i2t))
        logging.debug("Test scores | Text to image: {0}, {1}, {2}, {3}\n"
                      .format(r1_t2i, r5_t2i, r10_t2i, medr_t2i))

        return {'devacc': bestdevscore,
                'acc': [(r1_i2t, r5_i2t, r10_i2t, medr_i2t),
                        (r1_t2i, r5_t2i, r10_t2i, medr_t2i)],
                'ndev': len(coco_embed['dev']['sentfeat']),
                'ntest': len(coco_embed['test']['sentfeat'])}
109 |
--------------------------------------------------------------------------------
/SentEval/senteval/tools/relatedness.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | """
9 | Semantic Relatedness (supervised) with Pytorch
10 | """
11 | from __future__ import absolute_import, division, unicode_literals
12 |
13 | import copy
14 | import numpy as np
15 |
16 | import torch
17 | from torch import nn
18 | import torch.optim as optim
19 |
20 | from scipy.stats import pearsonr, spearmanr
21 |
22 |
class RelatednessPytorch(object):
    """Supervised relatedness regressor (can be used for SICK-Relatedness and STS14).

    A linear layer followed by a softmax predicts ``nclasses`` bin
    probabilities and is trained with MSE against ``train['y']`` (presumably
    per-bin target distributions; verify against the caller). A score is read
    out as the expectation ``sum_k p_k * k`` over bins ``1..nclasses``. Model
    selection uses Spearman correlation against ``devscores``.
    """

    def __init__(self, train, valid, test, devscores, config):
        # fix seed
        np.random.seed(config['seed'])
        torch.manual_seed(config['seed'])
        # Use the GPU when present, otherwise fall back to CPU (CUDA used to
        # be mandatory).
        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')
        if use_cuda:
            torch.cuda.manual_seed(config['seed'])

        self.train = train
        self.valid = valid
        self.test = test
        self.devscores = devscores

        self.inputdim = train['X'].shape[1]
        self.nclasses = config['nclasses']
        self.seed = config['seed']
        self.l2reg = 0.
        self.batch_size = 64
        self.maxepoch = 1000
        self.early_stop = True

        self.model = nn.Sequential(
            nn.Linear(self.inputdim, self.nclasses),
            nn.Softmax(dim=-1),
        ).to(self.device)
        # Mean-reduced MSE. Assigning `loss_fn.size_average = False` after
        # construction (as older code did) has no effect in current PyTorch,
        # because the reduction is fixed in the constructor; state it explicitly.
        self.loss_fn = nn.MSELoss(reduction='mean').to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    weight_decay=self.l2reg)

    def prepare_data(self, trainX, trainy, devX, devy, testX, testy):
        """Convert the six numpy arrays to float32 tensors on ``self.device``."""
        def to_tensor(array):
            return torch.from_numpy(array).float().to(self.device)

        return tuple(to_tensor(a) for a in (trainX, trainy, devX, devy, testX, testy))

    def run(self):
        """Train with early stopping on dev Spearman.

        Returns:
            ``(best dev Spearman, test-set predicted scores)``.
        """
        self.nepoch = 0
        # -inf (not -1) guarantees the first evaluation is stored as best, so
        # `bestmodel` is always a trained snapshot.
        bestpr = float('-inf')
        bestmodel = self.model
        early_stop_count = 0
        r = np.arange(1, self.nclasses + 1)  # score value of each bin
        stop_train = False

        # Preparing data
        trainX, trainy, devX, devy, testX, testy = self.prepare_data(
            self.train['X'], self.train['y'],
            self.valid['X'], self.valid['y'],
            self.test['X'], self.test['y'])

        # Training
        while not stop_train and self.nepoch <= self.maxepoch:
            self.trainepoch(trainX, trainy, nepoches=50)
            yhat = np.dot(self.predict_proba(devX), r)
            pr = spearmanr(yhat, self.devscores)[0]
            pr = 0 if pr != pr else pr  # if NaN bc std=0
            # early stop on Spearman
            if pr > bestpr:
                bestpr = pr
                bestmodel = copy.deepcopy(self.model)
            elif self.early_stop:
                if early_stop_count >= 3:
                    stop_train = True
                early_stop_count += 1
        self.model = bestmodel

        yhat = np.dot(self.predict_proba(testX), r)

        return bestpr, yhat

    def trainepoch(self, X, y, nepoches=1):
        """Run ``nepoches`` epochs of shuffled mini-batch training on ``(X, y)``."""
        self.model.train()
        for _ in range(nepoches):
            permutation = np.random.permutation(len(X))
            for i in range(0, len(X), self.batch_size):
                # forward
                idx = torch.from_numpy(permutation[i:i + self.batch_size]).long().to(X.device)
                Xbatch = X[idx]
                ybatch = y[idx]
                output = self.model(Xbatch)
                # loss
                loss = self.loss_fn(output, ybatch)
                # backward
                self.optimizer.zero_grad()
                loss.backward()
                # Update parameters
                self.optimizer.step()
        self.nepoch += nepoches

    def predict_proba(self, devX):
        """Return softmax bin probabilities for ``devX`` as a numpy array."""
        self.model.eval()
        chunks = []
        with torch.no_grad():
            for i in range(0, len(devX), self.batch_size):
                Xbatch = devX[i:i + self.batch_size]
                chunks.append(self.model(Xbatch).cpu().numpy())
        if not chunks:
            return np.zeros((0, self.nclasses), dtype=np.float32)
        # Single concatenate instead of re-concatenating on every batch.
        return np.concatenate(chunks, axis=0)
135 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from transformers import set_seed, TrainingArguments, HfArgumentParser, PretrainedConfig
2 | from transformers import AutoTokenizer
3 | from datasets import load_dataset
4 | import torch
5 | import wandb
6 | from args import ModelArguments, DatasetArguments
7 | from model import DenoSentModel
8 | from trainer import MyTrainer
9 | from mteb import MTEB
10 | from prettytable import PrettyTable
11 | from config import DenoSentConfig
12 |
def preprocess_logits_for_metrics(logits, labels):
    """Reduce model logits to predicted ids before Trainer accumulates them.

    Only the first entry of ``logits`` is used and ``labels`` is ignored.
    Returning argmax ids instead of full logits works around the original
    Trainer possibly keeping far more tensor memory than needed.
    """
    first_logits = logits[0]
    return first_logits.argmax(dim=-1)
20 |
21 |
def eval_mteb(model, batch_size):
    """Evaluate ``model`` on MTEB's English STS tasks and log scores to wandb.

    Runs the test split of STS12-16, STSBenchmark and SICK-R, writing raw
    results under ``mteb_results/<wandb run name>``. Stores each task's
    cosine-similarity Spearman correlation, plus their mean as
    ``mteb_avg_sts``, in the wandb run summary. Returns the raw results
    dict from ``MTEB.run``.
    """
    task_names = [
        "STS12",
        "STS13",
        "STS14",
        "STS15",
        "STS16",
        "STSBenchmark",
        "SICK-R",
    ]
    benchmark = MTEB(tasks=task_names, task_langs=["en"], task_categories=['S2S'])
    results = benchmark.run(
        model,
        overwrite_results=True,
        batch_size=batch_size,
        eval_splits=['test'],
        output_folder='mteb_results/' + wandb.run.name,
    )
    # This order fixes both the wandb summary write order and the
    # left-to-right summation order of the average.
    reported = ["STS12", "STS13", "STS14", "STS15", "STS16", "SICK-R", "STSBenchmark"]
    spearman = {name: results[name]['test']['cos_sim']['spearman'] for name in reported}
    avg_sts = sum(spearman[name] for name in reported) / len(reported)
    for name in reported:
        wandb.summary[name] = spearman[name]
    wandb.summary['mteb_avg_sts'] = avg_sts
    return results
51 |
52 |
53 |
if __name__ == "__main__":
    parser = HfArgumentParser((ModelArguments, TrainingArguments, DatasetArguments))
    model_args, training_args, dataset_args = parser.parse_args_into_dataclasses()
    wandb.init(project='DenoSent')
    set_seed(training_args.seed)
    wandb.config.update(model_args)
    wandb.config.update(training_args)
    wandb.config.update(dataset_args)
    # Write checkpoints to a per-run directory named after the wandb run.
    training_args.output_dir = 'results/' + wandb.run.name
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    config = DenoSentConfig(
        encoder_name_or_path=model_args.model_name_or_path,
        max_length=model_args.max_length,
        decoder_num_heads=model_args.decoder_num_heads,
        decoder_num_layers=model_args.decoder_num_layers,
        decoder_noise_dropout=model_args.decoder_target_dropout,
        pooler=model_args.pooler,
        do_contrastive=model_args.do_contrastive,
        do_generative=model_args.do_generative,
        prompt_format=model_args.prompt_format,
        contrastive_weight=model_args.contrastive_weight,
        generative_weight=model_args.generative_weight,
        contrastive_temp=model_args.contrastive_temp,
    )
    print(config)

    model = DenoSentModel(config)

    # Token count of the prompt template text (no special tokens). It is the
    # same for every example, so compute it once rather than inside map_fn.
    prompt_len = 0
    if config.pooler == 'mask':
        prompt_len = len(tokenizer(config.prompt_format, add_special_tokens=False)['input_ids'])

    def apply_prompt(sentence):
        """Truncate ``sentence`` to ``config.max_length`` tokens and insert it into the prompt."""
        truncated = tokenizer.decode(
            tokenizer(sentence, padding=True, truncation=True, max_length=config.max_length)['input_ids'],
            skip_special_tokens=True,
        )
        return config.prompt_format.replace('[X]', truncated).replace('[MASK]', tokenizer.mask_token)

    def map_fn(example):
        """Tokenize one (sent0, sent1) pair into fixed-length model inputs."""
        max_length = model_args.max_length
        if config.pooler == 'mask':
            example['sent0'] = apply_prompt(example['sent0'])
            example['sent1'] = apply_prompt(example['sent1'])
            # Leave room for the prompt template around the sentence.
            max_length = max_length + prompt_len
        original_inputs = tokenizer(example['sent0'], padding='max_length', truncation=True, max_length=max_length)
        example['input_ids'] = original_inputs['input_ids']
        example['attention_mask'] = original_inputs['attention_mask']

        positive_inputs = tokenizer(example['sent1'], padding='max_length', truncation=True, max_length=max_length)
        example['positive_input_ids'] = positive_inputs['input_ids']
        example['positive_attention_mask'] = positive_inputs['attention_mask']
        return example

    if dataset_args.train_dataset == "Singhoo/denosent_data":
        dataset = load_dataset(dataset_args.train_dataset, split='train')
    else:
        raise NotImplementedError(
            f"Unsupported train_dataset: {dataset_args.train_dataset!r}")
    dataset = dataset.map(map_fn, batched=False, num_proc=12).train_test_split(0.1, seed=training_args.seed, shuffle=True)
    # Seed the second split too, so the held-out eval subset is reproducible.
    test_valid = dataset['test'].train_test_split(0.01, seed=training_args.seed)

    trainer = MyTrainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        train_dataset=dataset['train'],
        eval_dataset=test_valid['test'],
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
    trainer.train()
    # eval_mteb stores every score in wandb.summary; the table reads them back.
    eval_mteb(model, batch_size=training_args.eval_batch_size)

    table = PrettyTable(["Name", "Value"])
    for name in ["STS12", "STS13", "STS14", "STS15", "STS16", "SICK-R", "STSBenchmark"]:
        table.add_row([name, wandb.summary[name]])
    table.add_row(["Avg.", wandb.summary['mteb_avg_sts']])
    print(table)

    wandb.finish()
--------------------------------------------------------------------------------
/eval_senteval.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import logging
3 | import argparse
4 | from prettytable import PrettyTable
5 | import torch
6 | from model import DenoSentModel
7 | from config import DenoSentConfig
8 | # Set up logger
9 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)
10 |
11 | # Set PATHs
12 | PATH_TO_SENTEVAL = './SentEval'
13 | PATH_TO_DATA = './SentEval/data'
14 |
15 | # Import SentEval
16 | sys.path.insert(0, PATH_TO_SENTEVAL)
17 | import senteval
18 |
def print_table(task_names, scores):
    """Print ``scores`` as a single-row table whose header is ``task_names``."""
    table = PrettyTable()
    table.field_names = task_names
    table.add_row(scores)
    rendered = str(table)
    print(rendered)
24 |
def main():
    """Evaluate a DenoSent checkpoint with SentEval and print result tables.

    Command-line options select the checkpoint, an optional pooler override,
    the evaluation mode (dev / test / fasttest) and the task set. STS tasks
    are reported as Spearman correlation x100 and transfer tasks as accuracy.
    Tasks that were not evaluated are shown as 0.00 and still count toward
    the "Avg." column.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str,
                        help="Transformers' model name or path")
    parser.add_argument("--pooler", type=str,
                        choices=['cls', 'mean', 'mask'],
                        default=None,
                        help="Which pooler to use (default: the pooler stored in the checkpoint's config)")
    parser.add_argument("--mode", type=str,
                        choices=['dev', 'test', 'fasttest'],
                        default='test',
                        help="What evaluation mode to use (dev: fast mode, dev results; test: full mode, test results); fasttest: fast mode, test results")
    parser.add_argument("--task_set", type=str,
                        choices=['sts', 'transfer', 'full', 'na'],
                        default='sts',
                        help="What set of tasks to evaluate on. If not 'na', this will override '--tasks'")
    parser.add_argument("--tasks", type=str, nargs='+',
                        default=['STS12', 'STS13', 'STS14', 'STS15', 'STS16',
                                 'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC',
                                 'SICKRelatedness', 'STSBenchmark'],
                        help="Tasks to evaluate on. If '--task_set' is specified, this will be overridden")
    args = parser.parse_args()

    # Load the checkpoint. A --pooler override must be written into the config
    # before the model is built: DenoSentModel copies config.pooler in its
    # constructor, and encode() reads config.pooler. Previously the flag was
    # parsed but never applied.
    config = DenoSentConfig.from_pretrained(args.model_name_or_path)
    if args.pooler is not None:
        config.pooler = args.pooler
    model = DenoSentModel.from_pretrained(args.model_name_or_path, config=config)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Set up the tasks ('na' keeps whatever --tasks specified).
    sts_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']
    transfer_tasks = ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'TREC', 'MRPC']
    if args.task_set == 'sts':
        args.tasks = sts_tasks
    elif args.task_set == 'transfer':
        args.tasks = transfer_tasks
    elif args.task_set == 'full':
        args.tasks = sts_tasks + transfer_tasks

    # Set params for SentEval
    if args.mode == 'dev' or args.mode == 'fasttest':
        # Fast mode
        params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5, 'cudaEfficient': True}
        params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
                                'tenacity': 3, 'epoch_size': 2}
    elif args.mode == 'test':
        # Full mode
        params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10, 'cudaEfficient': True}
        params['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64,
                                'tenacity': 5, 'epoch_size': 4}
    else:
        raise NotImplementedError

    # SentEval prepare and batcher
    def prepare(params, samples):
        return

    def batcher(params, batch):
        # SentEval hands over token lists; DenoSentModel.encode expects strings.
        sentences = [' '.join(s) for s in batch]
        return model.encode(sentences, len(sentences))

    results = {}
    for task in args.tasks:
        se = senteval.engine.SE(params, batcher, prepare)
        results[task] = se.eval(task)

    def add_average(task_names, scores):
        # Append an "Avg." column; missing tasks were recorded as "0.00" and
        # are included in the mean.
        task_names.append("Avg.")
        scores.append("%.2f" % (sum(float(score) for score in scores) / len(scores)))

    # Print evaluation results
    if args.mode == 'dev':
        print("------ %s ------" % (args.mode))

        task_names = []
        scores = []
        for task in ['STSBenchmark', 'SICKRelatedness']:
            task_names.append(task)
            if task in results:
                scores.append("%.2f" % (results[task]['dev']['spearman'][0] * 100))
            else:
                scores.append("0.00")
        print_table(task_names, scores)

        task_names = []
        scores = []
        for task in ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC']:
            task_names.append(task)
            if task in results:
                scores.append("%.2f" % (results[task]['devacc']))
            else:
                scores.append("0.00")
        add_average(task_names, scores)
        print_table(task_names, scores)

    elif args.mode == 'test' or args.mode == 'fasttest':
        print("------ %s ------" % (args.mode))

        task_names = []
        scores = []
        for task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']:
            task_names.append(task)
            if task in results:
                if task in ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']:
                    scores.append("%.2f" % (results[task]['all']['spearman']['all'] * 100))
                else:
                    scores.append("%.2f" % (results[task]['test']['spearman'].correlation * 100))
            else:
                scores.append("0.00")
        add_average(task_names, scores)
        print_table(task_names, scores)

        task_names = []
        scores = []
        for task in ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC']:
            task_names.append(task)
            if task in results:
                scores.append("%.2f" % (results[task]['acc']))
            else:
                scores.append("0.00")
        add_average(task_names, scores)
        print_table(task_names, scores)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/SentEval/senteval/engine.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | '''
9 |
10 | Generic sentence evaluation scripts wrapper
11 |
12 | '''
13 | from __future__ import absolute_import, division, unicode_literals
14 |
15 | from senteval import utils
16 | from senteval.binary import CREval, MREval, MPQAEval, SUBJEval
17 | from senteval.snli import SNLIEval
18 | from senteval.trec import TRECEval
19 | from senteval.sick import SICKEntailmentEval, SICKEval
20 | from senteval.mrpc import MRPCEval
21 | from senteval.sts import STS12Eval, STS13Eval, STS14Eval, STS15Eval, STS16Eval, STSBenchmarkEval, SICKRelatednessEval, STSBenchmarkFinetune
22 | from senteval.sst import SSTEval
23 | from senteval.rank import ImageCaptionRetrievalEval
24 | from senteval.probing import *
25 |
class SE(object):
    """SentEval evaluation engine.

    Stores the evaluation parameters with the user's ``batcher`` (maps a batch
    of tokenized sentences to embeddings) and optional ``prepare`` callback,
    and runs one task, or each task of a list, through :meth:`eval`.
    """

    def __init__(self, params, batcher, prepare=None):
        """Fill in parameter defaults and store the callbacks.

        Args:
            params: dict of SentEval parameters; ``task_path`` is required by
                :meth:`eval`. Missing ``usepytorch``, ``seed``, ``batch_size``,
                ``nhid``, ``kfold`` and ``classifier`` entries get defaults.
            batcher: callable ``batcher(params, batch)`` returning embeddings.
            prepare: optional callable ``prepare(params, samples)``; a no-op
                is used when it is not given.

        Raises:
            AssertionError: if the classifier config has no ``nhid`` entry.
        """
        # parameters
        params = utils.dotdict(params)
        params.usepytorch = True if 'usepytorch' not in params else params.usepytorch
        params.seed = 1111 if 'seed' not in params else params.seed

        params.batch_size = 128 if 'batch_size' not in params else params.batch_size
        params.nhid = 0 if 'nhid' not in params else params.nhid
        params.kfold = 5 if 'kfold' not in params else params.kfold

        if 'classifier' not in params or not params['classifier']:
            params.classifier = {'nhid': 0}

        assert 'nhid' in params.classifier, 'Set number of hidden units in classifier config!!'

        self.params = params

        # batcher and prepare
        self.batcher = batcher
        self.prepare = prepare if prepare else lambda x, y: None

        self.list_tasks = ['CR', 'MR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
                           'SICKRelatedness', 'SICKEntailment', 'STSBenchmark',
                           'SNLI', 'ImageCaptionRetrieval', 'STS12', 'STS13',
                           'STS14', 'STS15', 'STS16',
                           'Length', 'WordContent', 'Depth', 'TopConstituents',
                           'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
                           'OddManOut', 'CoordinationInversion', 'SICKRelatedness-finetune', 'STSBenchmark-finetune', 'STSBenchmark-fix']

    def eval(self, name):
        """Evaluate on task ``name`` (a string) or on every task of a list.

        Returns:
            The result of the task's ``run`` method, or, when ``name`` is a
            list, a dict mapping each task name to its result.

        Raises:
            AssertionError: if ``name`` is not in ``self.list_tasks``.
        """
        # evaluate on evaluation [name], either takes string or list of strings
        if (isinstance(name, list)):
            self.results = {x: self.eval(x) for x in name}
            return self.results

        tpath = self.params.task_path
        assert name in self.list_tasks, str(name) + ' not in ' + str(self.list_tasks)

        seed = self.params.seed
        downstream = tpath + '/downstream'
        probing = tpath + '/probing'
        # Explicit task -> constructor table (replaces builtin eval() on a
        # constructed class name). Entries are zero-argument factories so that
        # only the requested task's data is loaded; a listed name without an
        # entry now fails loudly instead of reusing a stale self.evaluation.
        factories = {
            # Original SentEval tasks
            'CR': lambda: CREval(downstream + '/CR', seed=seed),
            'MR': lambda: MREval(downstream + '/MR', seed=seed),
            'MPQA': lambda: MPQAEval(downstream + '/MPQA', seed=seed),
            'SUBJ': lambda: SUBJEval(downstream + '/SUBJ', seed=seed),
            'SST2': lambda: SSTEval(downstream + '/SST/binary', nclasses=2, seed=seed),
            'SST5': lambda: SSTEval(downstream + '/SST/fine', nclasses=5, seed=seed),
            'TREC': lambda: TRECEval(downstream + '/TREC', seed=seed),
            'MRPC': lambda: MRPCEval(downstream + '/MRPC', seed=seed),
            'SICKRelatedness': lambda: SICKRelatednessEval(downstream + '/SICK', seed=seed),
            'STSBenchmark': lambda: STSBenchmarkEval(downstream + '/STS/STSBenchmark', seed=seed),
            'STSBenchmark-fix': lambda: STSBenchmarkEval(downstream + '/STS/STSBenchmark-fix', seed=seed),
            'STSBenchmark-finetune': lambda: STSBenchmarkFinetune(downstream + '/STS/STSBenchmark', seed=seed),
            'SICKRelatedness-finetune': lambda: SICKEval(downstream + '/SICK', seed=seed),
            'SICKEntailment': lambda: SICKEntailmentEval(downstream + '/SICK', seed=seed),
            'SNLI': lambda: SNLIEval(downstream + '/SNLI', seed=seed),
            'STS12': lambda: STS12Eval(downstream + '/STS/STS12-en-test', seed=seed),
            'STS13': lambda: STS13Eval(downstream + '/STS/STS13-en-test', seed=seed),
            'STS14': lambda: STS14Eval(downstream + '/STS/STS14-en-test', seed=seed),
            'STS15': lambda: STS15Eval(downstream + '/STS/STS15-en-test', seed=seed),
            'STS16': lambda: STS16Eval(downstream + '/STS/STS16-en-test', seed=seed),
            'ImageCaptionRetrieval': lambda: ImageCaptionRetrievalEval(downstream + '/COCO', seed=seed),
            # Probing tasks
            'Length': lambda: LengthEval(probing, seed=seed),
            'WordContent': lambda: WordContentEval(probing, seed=seed),
            'Depth': lambda: DepthEval(probing, seed=seed),
            'TopConstituents': lambda: TopConstituentsEval(probing, seed=seed),
            'BigramShift': lambda: BigramShiftEval(probing, seed=seed),
            'Tense': lambda: TenseEval(probing, seed=seed),
            'SubjNumber': lambda: SubjNumberEval(probing, seed=seed),
            'ObjNumber': lambda: ObjNumberEval(probing, seed=seed),
            'OddManOut': lambda: OddManOutEval(probing, seed=seed),
            'CoordinationInversion': lambda: CoordinationInversionEval(probing, seed=seed),
        }
        self.evaluation = factories[name]()

        self.params.current_task = name
        self.evaluation.do_prepare(self.params, self.prepare)

        self.results = self.evaluation.run(self.params, self.batcher)

        return self.results
130 |
--------------------------------------------------------------------------------
/SentEval/senteval/probing.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | '''
9 | probing tasks
10 | '''
11 |
12 | from __future__ import absolute_import, division, unicode_literals
13 |
14 | import os
15 | import io
16 | import copy
17 | import logging
18 | import numpy as np
19 |
20 | from senteval.tools.validation import SplitClassifier
21 |
22 |
class PROBINGEval(object):
    """Base class for the SentEval probing tasks.

    The data file is tab-separated: the split tag ('tr', 'va' or 'te') is in
    the first column, the label in the second, and the sentence (split on
    whitespace) in the last. Labels are mapped to integer ids in sorted order
    of the labels found in the training split.
    """

    def __init__(self, task, task_path, seed=1111):
        """Load ``task_path`` for probing task ``task``.

        Args:
            task: task name, used for logging and task-specific settings.
            task_path: path of the tab-separated data file.
            seed: seed forwarded to the classifier.
        """
        self.seed = seed
        self.task = task
        logging.debug('***** (Probing) Transfer task : %s classification *****', self.task.upper())
        self.task_data = {'train': {'X': [], 'y': []},
                          'dev': {'X': [], 'y': []},
                          'test': {'X': [], 'y': []}}
        self.loadFile(task_path)
        logging.info('Loaded %s train - %s dev - %s test for %s',
                     len(self.task_data['train']['y']), len(self.task_data['dev']['y']),
                     len(self.task_data['test']['y']), self.task)

    def do_prepare(self, params, prepare):
        """Call ``prepare`` on all sentences of the train, dev and test splits."""
        samples = self.task_data['train']['X'] + self.task_data['dev']['X'] + \
            self.task_data['test']['X']
        return prepare(params, samples)

    def loadFile(self, fpath):
        """Fill ``self.task_data`` from ``fpath`` and build the label map.

        Raises:
            KeyError: for an unknown split tag, or for a dev/test label that
                never occurs in the training split.
        """
        self.tok2split = {'tr': 'train', 'va': 'dev', 'te': 'test'}
        with io.open(fpath, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.rstrip().split('\t')
                self.task_data[self.tok2split[line[0]]]['X'].append(line[-1].split())
                self.task_data[self.tok2split[line[0]]]['y'].append(line[1])

        # Label ids come from the training split only.
        labels = sorted(np.unique(self.task_data['train']['y']))
        self.tok2label = dict(zip(labels, range(len(labels))))
        self.nclasses = len(self.tok2label)

        for split in self.task_data:
            for i, y in enumerate(self.task_data[split]['y']):
                self.task_data[split]['y'][i] = self.tok2label[y]

    def run(self, params, batcher):
        """Embed every split with ``batcher`` and train/evaluate a classifier.

        Returns:
            dict with ``devacc``, ``acc`` (test accuracy), ``ndev`` and ``ntest``.
        """
        task_embed = {'train': {}, 'dev': {}, 'test': {}}
        bsize = params.batch_size
        logging.info('Computing embeddings for train/dev/test')
        for key in self.task_data:
            # Sort to reduce padding
            sorted_data = sorted(zip(self.task_data[key]['X'],
                                     self.task_data[key]['y']),
                                 key=lambda z: (len(z[0]), z[1]))
            self.task_data[key]['X'], self.task_data[key]['y'] = map(list, zip(*sorted_data))

            task_embed[key]['X'] = []
            for ii in range(0, len(self.task_data[key]['y']), bsize):
                batch = self.task_data[key]['X'][ii:ii + bsize]
                embeddings = batcher(params, batch)
                task_embed[key]['X'].append(embeddings)
            task_embed[key]['X'] = np.vstack(task_embed[key]['X'])
            task_embed[key]['y'] = np.array(self.task_data[key]['y'])
        logging.info('Computed embeddings')

        config_classifier = {'nclasses': self.nclasses, 'seed': self.seed,
                             'usepytorch': params.usepytorch,
                             'classifier': params.classifier}

        # WordContent always uses logistic regression (nhid = 0). Deep-copy so
        # the caller's params.classifier is not mutated.
        if self.task == "WordContent" and params.classifier['nhid'] > 0:
            config_classifier = copy.deepcopy(config_classifier)
            config_classifier['classifier']['nhid'] = 0
            logging.debug('WordContent: overriding classifier nhid=%s with 0',
                          params.classifier['nhid'])

        clf = SplitClassifier(X={'train': task_embed['train']['X'],
                                 'valid': task_embed['dev']['X'],
                                 'test': task_embed['test']['X']},
                              y={'train': task_embed['train']['y'],
                                 'valid': task_embed['dev']['y'],
                                 'test': task_embed['test']['y']},
                              config=config_classifier)

        devacc, testacc = clf.run()
        logging.debug('\nDev acc : %.1f Test acc : %.1f for %s classification\n' % (devacc, testacc, self.task.upper()))

        return {'devacc': devacc, 'acc': testacc,
                'ndev': len(task_embed['dev']['X']),
                'ntest': len(task_embed['test']['X'])}
100 |
101 | """
102 | Surface Information
103 | """
class LengthEval(PROBINGEval):
    """Probing task 'Length': reads sentence_length.txt; labels are bins."""

    def __init__(self, task_path, seed=1111):
        data_file = os.path.join(task_path, 'sentence_length.txt')
        super(LengthEval, self).__init__('Length', data_file, seed)
109 |
class WordContentEval(PROBINGEval):
    """Probing task 'WordContent': reads word_content.txt; labels are 200 target words."""

    def __init__(self, task_path, seed=1111):
        data_file = os.path.join(task_path, 'word_content.txt')
        super(WordContentEval, self).__init__('WordContent', data_file, seed)
115 |
116 | """
117 | Latent Structural Information
118 | """
class DepthEval(PROBINGEval):
    """Probing task 'Depth': reads tree_depth.txt; labels are bins."""

    def __init__(self, task_path, seed=1111):
        data_file = os.path.join(task_path, 'tree_depth.txt')
        super(DepthEval, self).__init__('Depth', data_file, seed)
124 |
class TopConstituentsEval(PROBINGEval):
    """Probing task 'TopConstituents': reads top_constituents.txt.

    Labels are constituent sequences such as 'PP_NP_VP_.' (20 classes).
    """

    def __init__(self, task_path, seed=1111):
        data_file = os.path.join(task_path, 'top_constituents.txt')
        super(TopConstituentsEval, self).__init__('TopConstituents', data_file, seed)
130 |
class BigramShiftEval(PROBINGEval):
    """Probing task 'BigramShift': reads bigram_shift.txt; labels are 0 or 1."""

    def __init__(self, task_path, seed=1111):
        data_file = os.path.join(task_path, 'bigram_shift.txt')
        super(BigramShiftEval, self).__init__('BigramShift', data_file, seed)
136 |
137 | # TODO: Voice?
138 |
139 | """
140 | Latent Semantic Information
141 | """
142 |
class TenseEval(PROBINGEval):
    """Probing task 'Tense': reads past_present.txt; labels are 'PRES' and 'PAST'."""

    def __init__(self, task_path, seed=1111):
        data_file = os.path.join(task_path, 'past_present.txt')
        super(TenseEval, self).__init__('Tense', data_file, seed)
148 |
class SubjNumberEval(PROBINGEval):
    """Probing task 'SubjNumber': reads subj_number.txt; labels are 'NN' and 'NNS'."""

    def __init__(self, task_path, seed=1111):
        data_file = os.path.join(task_path, 'subj_number.txt')
        super(SubjNumberEval, self).__init__('SubjNumber', data_file, seed)
154 |
class ObjNumberEval(PROBINGEval):
    """Probing task 'ObjNumber': reads obj_number.txt; labels are 'NN' and 'NNS'."""

    def __init__(self, task_path, seed=1111):
        data_file = os.path.join(task_path, 'obj_number.txt')
        super(ObjNumberEval, self).__init__('ObjNumber', data_file, seed)
160 |
class OddManOutEval(PROBINGEval):
    """Probing task 'OddManOut': reads odd_man_out.txt; labels are 'O' and 'C'."""

    def __init__(self, task_path, seed=1111):
        data_file = os.path.join(task_path, 'odd_man_out.txt')
        super(OddManOutEval, self).__init__('OddManOut', data_file, seed)
166 |
class CoordinationInversionEval(PROBINGEval):
    """Probing task 'CoordinationInversion': reads coordination_inversion.txt; labels are 'O' and 'I'."""

    def __init__(self, task_path, seed=1111):
        data_file = os.path.join(task_path, 'coordination_inversion.txt')
        super(CoordinationInversionEval, self).__init__('CoordinationInversion', data_file, seed)
172 |
--------------------------------------------------------------------------------
/SentEval/senteval/tools/classifier.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | """
9 | Pytorch Classifier class in the style of scikit-learn
10 | Classifiers include Logistic Regression and MLP
11 | """
12 |
13 | from __future__ import absolute_import, division, unicode_literals
14 |
15 | import numpy as np
16 | import copy
17 | from senteval import utils
18 |
19 | import torch
20 | from torch import nn
21 | import torch.nn.functional as F
22 |
23 |
class PyTorchClassifier(object):
    """Base class for SentEval's PyTorch classifiers (scikit-learn style API).

    Subclasses must set ``self.model``, ``self.loss_fn``, ``self.optimizer``,
    ``self.max_epoch``, ``self.epoch_size`` and ``self.tenacity`` before
    :meth:`fit` is called.
    """

    def __init__(self, inputdim, nclasses, l2reg=0., batch_size=64, seed=1111,
                 cudaEfficient=False):
        """Seed numpy/torch RNGs and store the basic classifier settings.

        With ``cudaEfficient`` the data tensors stay on the CPU and only the
        mini-batches are moved to CUDA.
        """
        # fix seed
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        self.inputdim = inputdim
        self.nclasses = nclasses
        self.l2reg = l2reg
        self.batch_size = batch_size
        self.cudaEfficient = cudaEfficient

    def prepare_split(self, X, y, validation_data=None, validation_split=None):
        """Split numpy data into train/dev tensors.

        Uses ``validation_data`` when given, otherwise holds out a random
        ``validation_split`` fraction of ``X``/``y``.
        """
        # Preparing validation data
        assert validation_split or validation_data
        if validation_data is not None:
            trainX, trainy = X, y
            devX, devy = validation_data
        else:
            permutation = np.random.permutation(len(X))
            trainidx = permutation[int(validation_split * len(X)):]
            devidx = permutation[0:int(validation_split * len(X))]
            trainX, trainy = X[trainidx], y[trainidx]
            devX, devy = X[devidx], y[devidx]

        device = torch.device('cpu') if self.cudaEfficient else torch.device('cuda')

        trainX = torch.from_numpy(trainX).to(device, dtype=torch.float32)
        trainy = torch.from_numpy(trainy).to(device, dtype=torch.int64)
        devX = torch.from_numpy(devX).to(device, dtype=torch.float32)
        devy = torch.from_numpy(devy).to(device, dtype=torch.int64)

        return trainX, trainy, devX, devy

    def fit(self, X, y, validation_data=None, validation_split=None,
            early_stop=True):
        """Train, keep the model with the best dev accuracy, and return that accuracy."""
        self.nepoch = 0
        bestaccuracy = -1
        stop_train = False
        early_stop_count = 0

        # Preparing validation data
        trainX, trainy, devX, devy = self.prepare_split(X, y, validation_data,
                                                        validation_split)

        # Training
        while not stop_train and self.nepoch <= self.max_epoch:
            self.trainepoch(trainX, trainy, epoch_size=self.epoch_size)
            accuracy = self.score(devX, devy)
            if accuracy > bestaccuracy:
                bestaccuracy = accuracy
                bestmodel = copy.deepcopy(self.model)
            elif early_stop:
                # early_stop_count counts non-improving evaluations in total
                # (it is never reset on improvement).
                if early_stop_count >= self.tenacity:
                    stop_train = True
                early_stop_count += 1
        self.model = bestmodel
        return bestaccuracy

    def trainepoch(self, X, y, epoch_size=1):
        """Run ``epoch_size`` shuffled passes over (X, y) and advance ``self.nepoch``."""
        self.model.train()
        for _ in range(self.nepoch, self.nepoch + epoch_size):
            permutation = np.random.permutation(len(X))
            all_costs = []
            for i in range(0, len(X), self.batch_size):
                # forward
                idx = torch.from_numpy(permutation[i:i + self.batch_size]).long().to(X.device)

                Xbatch = X[idx]
                ybatch = y[idx]

                if self.cudaEfficient:
                    Xbatch = Xbatch.cuda()
                    ybatch = ybatch.cuda()
                output = self.model(Xbatch)
                # loss
                loss = self.loss_fn(output, ybatch)
                all_costs.append(loss.data.item())
                # backward
                self.optimizer.zero_grad()
                loss.backward()
                # Update parameters
                self.optimizer.step()
        self.nepoch += epoch_size

    def score(self, devX, devy):
        """Return the classification accuracy on (devX, devy) (runs on CUDA)."""
        self.model.eval()
        correct = 0
        if not isinstance(devX, torch.cuda.FloatTensor) or self.cudaEfficient:
            devX = torch.FloatTensor(devX).cuda()
            devy = torch.LongTensor(devy).cuda()
        with torch.no_grad():
            for i in range(0, len(devX), self.batch_size):
                Xbatch = devX[i:i + self.batch_size]
                ybatch = devy[i:i + self.batch_size]
                if self.cudaEfficient:
                    Xbatch = Xbatch.cuda()
                    ybatch = ybatch.cuda()
                output = self.model(Xbatch)
                pred = output.data.max(1)[1]
                correct += pred.long().eq(ybatch.data.long()).sum().item()
            accuracy = 1.0 * correct / len(devX)
        return accuracy

    def predict(self, devX):
        """Return argmax predictions as a float column array of shape (n, 1)."""
        self.model.eval()
        if not isinstance(devX, torch.cuda.FloatTensor):
            devX = torch.FloatTensor(devX).cuda()
        yhat = np.array([])
        with torch.no_grad():
            for i in range(0, len(devX), self.batch_size):
                Xbatch = devX[i:i + self.batch_size]
                output = self.model(Xbatch)
                yhat = np.append(yhat,
                                 output.data.max(1)[1].cpu().numpy())
        yhat = np.vstack(yhat)
        return yhat

    def predict_proba(self, devX):
        """Return softmax probabilities of the model outputs.

        Args:
            devX: 2-D array-like or tensor of input features.

        Returns:
            np.ndarray with one row per input and one column per model output;
            each row sums to 1. An empty input yields shape (0, nclasses).
        """
        self.model.eval()
        # Run on the device the model's parameters live on.
        device = next(self.model.parameters()).device
        devX = torch.as_tensor(devX, dtype=torch.float32)
        probas = []
        with torch.no_grad():
            for i in range(0, len(devX), self.batch_size):
                Xbatch = devX[i:i + self.batch_size].to(device)
                # softmax must be applied to the tensor, not to a numpy copy.
                probas.append(F.softmax(self.model(Xbatch), dim=-1).cpu().numpy())
        if not probas:
            return np.zeros((0, self.nclasses), dtype=np.float32)
        return np.concatenate(probas, axis=0)
156 |
157 |
158 | """
159 | MLP with Pytorch (nhid=0 --> Logistic Regression)
160 | """
161 |
class MLP(PyTorchClassifier):
    """MLP classifier with PyTorch (nhid=0 --> Logistic Regression).

    PARAMETERS (keys of ``params``; defaults in parentheses):
        -nhid: number of hidden units, 0 means Logistic Regression (0)
        -optim: optimizer ("sgd,lr=0.1", "adam", "rmsprop" ..) ("adam")
        -tenacity: how many times dev acc does not increase before stopping (5)
        -epoch_size: each epoch corresponds to epoch_size pass on the train set (4)
        -max_epoch: max number of epochs (200)
        -dropout: dropout for MLP (0.)
        -batch_size: mini-batch size (64); note this value, not the
         constructor's ``batch_size`` argument, is what ends up being used
    """

    def __init__(self, params, inputdim, nclasses, l2reg=0., batch_size=64,
                 seed=1111, cudaEfficient=False):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as MLP is subclassed.
        super(MLP, self).__init__(inputdim, nclasses, l2reg,
                                  batch_size, seed, cudaEfficient)

        self.nhid = 0 if "nhid" not in params else params["nhid"]
        self.optim = "adam" if "optim" not in params else params["optim"]
        self.tenacity = 5 if "tenacity" not in params else params["tenacity"]
        self.epoch_size = 4 if "epoch_size" not in params else params["epoch_size"]
        self.max_epoch = 200 if "max_epoch" not in params else params["max_epoch"]
        self.dropout = 0. if "dropout" not in params else params["dropout"]
        self.batch_size = 64 if "batch_size" not in params else params["batch_size"]

        # Use the defaulted self.nhid so a config without "nhid" builds a
        # logistic regression instead of raising KeyError.
        if self.nhid == 0:
            self.model = nn.Sequential(
                nn.Linear(self.inputdim, self.nclasses),
            ).cuda()
        else:
            self.model = nn.Sequential(
                nn.Linear(self.inputdim, self.nhid),
                nn.Dropout(p=self.dropout),
                nn.Sigmoid(),
                nn.Linear(self.nhid, self.nclasses),
            ).cuda()

        self.loss_fn = nn.CrossEntropyLoss().cuda()
        # NOTE: CrossEntropyLoss.forward uses `self.reduction`, fixed to 'mean'
        # at construction; this legacy attribute is not consulted, so the loss
        # stays mean-reduced. Left unchanged to keep results comparable.
        self.loss_fn.size_average = False

        optim_fn, optim_params = utils.get_optimizer(self.optim)
        self.optimizer = optim_fn(self.model.parameters(), **optim_params)
        # L2 regularisation via the optimizer's weight decay.
        self.optimizer.param_groups[0]['weight_decay'] = self.l2reg
203 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer, BertForMaskedLM
2 | from transformers.models.bert.modeling_bert import BertForMaskedLM
3 | from transformers.modeling_outputs import TokenClassifierOutput
4 | from transformers import PreTrainedModel
5 | import torch
6 | from torch import nn
7 | from torch.nn import TransformerDecoder, TransformerDecoderLayer
8 |
9 | from typing import Optional
10 |
11 | import wandb
12 | import numpy as np
13 |
14 | class DenoSentModel(PreTrainedModel):
    def __init__(self, config):
        """Build the DenoSent model from its config.

        Creates the sentence-embedding projector, the Transformer decoder, the
        decoder-input dropout and the cosine-similarity module, then loads the
        tokenizer and a BertForMaskedLM from ``config.encoder_name_or_path``,
        keeping its MLM head as ``prediction_head`` and its BERT body as
        ``encoder``.
        """
        super().__init__(config)
        # Pooling strategy; forward() branches on 'cls', 'mean' or 'mask'.
        self.pooler = config.pooler
        self.sent_embedding_projector = nn.Linear(config.hidden_size, config.hidden_size)
        # In forward() the decoder reads the projected sentence embedding as
        # `memory` and its outputs are scored by prediction_head.
        self.decoder = TransformerDecoder(TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.decoder_num_heads, batch_first=True, dropout=0.1), num_layers=config.decoder_num_layers)
        # Applied to the decoder's target inputs (`tgt`) in forward().
        self.decoder_noise_dropout = nn.Dropout(config.decoder_noise_dropout)
        self.sim = nn.CosineSimilarity(dim=-1)
        # Initialise the modules created so far, before the pretrained encoder
        # is attached.
        self.init_weights()
        self.tokenizer = AutoTokenizer.from_pretrained(config.encoder_name_or_path)
        self.encoder = BertForMaskedLM.from_pretrained(config.encoder_name_or_path)
        # Keep the MLM head (maps decoder outputs to vocabulary logits), then
        # replace the full MLM model with its bare BERT body.
        self.prediction_head = self.encoder.cls
        self.encoder = self.encoder.bert
        # NOTE(review): post_init() runs weight initialisation again;
        # presumably transformers skips modules already initialised or loaded
        # from the checkpoint — confirm for the installed version.
        self.post_init()
28 |
29 | def _init_weights(self, module):
30 | """Initialize the weights"""
31 | if isinstance(module, nn.Linear):
32 | # Slightly different from the TF version which uses truncated_normal for initialization
33 | # cf https://github.com/pytorch/pytorch/pull/5617
34 | module.weight.data.normal_(mean=0.0, std=0.02)
35 | if module.bias is not None:
36 | module.bias.data.zero_()
37 | elif isinstance(module, nn.Embedding):
38 | module.weight.data.normal_(mean=0.0, std=0.02)
39 | if module.padding_idx is not None:
40 | module.weight.data[module.padding_idx].zero_()
41 | elif isinstance(module, nn.LayerNorm):
42 | module.bias.data.zero_()
43 | module.weight.data.fill_(1.0)
44 |
45 | def encode(self, sentences, batch_size=32, **kwargs):
46 | """ Returns a list of embeddings for the given sentences.
47 | Args:
48 | sentences (`List[str]`): List of sentences to encode
49 | batch_size (`int`): Batch size for the encoding
50 |
51 | Returns:
52 | `List[np.ndarray]` or `List[tensor]`: List of embeddings for the given sentences
53 | """
54 | self.eval()
55 | all_embeddings = []
56 | length_sorted_idx = np.argsort([len(sen) for sen in sentences])
57 | sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
58 | if self.config.pooler == 'mask':
59 | prompt_length = len(self.tokenizer(self.config.prompt_format, add_special_tokens=False)['input_ids'])
60 | sentences_sorted = self.tokenizer.batch_decode(self.tokenizer(sentences_sorted, padding=True, truncation=True, max_length=self.config.max_length, return_tensors='pt').input_ids, skip_special_tokens=True)
61 | sentences_sorted = [self.config.prompt_format.replace('[X]', s).replace('[MASK]', self.tokenizer.mask_token) for s in sentences_sorted]
62 | for start_index in range(0, len(sentences), batch_size):
63 | sentences_batch = sentences_sorted[start_index:start_index+batch_size]
64 | inputs = self.tokenizer(sentences_batch, padding='max_length', truncation=True, return_tensors="pt", max_length=self.config.max_length+prompt_length)
65 | inputs = {k: v.to(self.device) for k,v in inputs.items()}
66 | with torch.no_grad():
67 | encoder_outputs = self.encoder(**inputs, output_hidden_states=True, output_attentions=True, return_dict=True)
68 | last_hidden_state = encoder_outputs.last_hidden_state
69 | if self.config.pooler == 'cls':
70 | embeddings = last_hidden_state[:, 0, :]
71 | elif self.config.pooler == 'mean':
72 | embeddings = (last_hidden_state * inputs['attention_mask'].unsqueeze(-1)).sum(1) / inputs['attention_mask'].sum(-1).unsqueeze(-1)
73 | elif self.pooler == 'mask':
74 | embeddings = last_hidden_state[inputs['input_ids'] == self.tokenizer.mask_token_id]
75 | else:
76 | raise NotImplementedError()
77 | all_embeddings.extend(embeddings.cpu().numpy())
78 | all_embeddings = torch.tensor(np.array([all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]))
79 | return all_embeddings
80 |
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    positive_input_ids: Optional[torch.LongTensor] = None,
    positive_attention_mask: Optional[torch.LongTensor] = None,
    negative_input_ids: Optional[torch.LongTensor] = None,
    negative_attention_mask: Optional[torch.LongTensor] = None,
    global_step: Optional[int] = None,
    max_steps: Optional[int] = None,
):
    """Compute the weighted contrastive and/or generative training loss.

    Anchors, positives and negatives are encoded in ONE encoder pass by
    concatenating them along the batch dimension. With B = batch size, rows
    [0, B) are anchors and rows [B, 2B) are positives. When no positives are
    given but contrastive training is on, rows [B, 2B) are a second copy of
    the anchors. Rows from 2B onward are negatives. Negatives are
    concatenated after positives, so positives must be provided with them.

    Args:
        input_ids, attention_mask: anchor batch.
        positive_input_ids, positive_attention_mask: optional positives.
        negative_input_ids, negative_attention_mask: optional negatives.
        global_step, max_steps: accepted for the trainer interface; not used
            in this method.

    Returns:
        TokenClassifierOutput whose ``loss`` is the weighted sum of the
        enabled objectives, with ``logits=None`` and the encoder's hidden
        states and attentions.

    Raises:
        NotImplementedError: if no input/objective combination applies, or
            ``self.pooler`` is not one of 'cls', 'mean', 'mask'.
    """
    batch_size = input_ids.size(0)
    if negative_input_ids is not None:
        encoder_input_ids = torch.cat([input_ids, positive_input_ids, negative_input_ids], dim=0).to(self.device)
        encoder_attention_mask = torch.cat([attention_mask, positive_attention_mask, negative_attention_mask], dim=0).to(self.device)
    elif positive_input_ids is not None:
        encoder_input_ids = torch.cat([input_ids, positive_input_ids], dim=0).to(self.device)
        encoder_attention_mask = torch.cat([attention_mask, positive_attention_mask], dim=0).to(self.device)
    elif self.config.do_contrastive:
        # No explicit positives: encode the anchors twice; the second copy
        # serves as the positive (presumably differing only through dropout).
        encoder_input_ids = torch.cat([input_ids, input_ids], dim=0).to(self.device)
        encoder_attention_mask = torch.cat([attention_mask, attention_mask], dim=0).to(self.device)
    elif self.config.do_generative:
        # Reached only when do_contrastive is False.
        encoder_input_ids = input_ids.to(self.device)
        encoder_attention_mask = attention_mask.to(self.device)
    else:
        raise NotImplementedError()
    encoder_outputs = self.encoder(input_ids=encoder_input_ids, attention_mask=encoder_attention_mask, return_dict=True, output_hidden_states=True, output_attentions=True)
    if self.pooler == 'cls':
        sent_embedding = encoder_outputs.last_hidden_state[:, 0, :]
    elif self.pooler == 'mean':
        # Average token states over non-padding positions only.
        sent_embedding = ((encoder_outputs.last_hidden_state * encoder_attention_mask.unsqueeze(-1)).sum(1) / encoder_attention_mask.sum(-1).unsqueeze(-1))
    elif self.pooler == 'mask':
        # NOTE(review): assumes exactly one mask token per row, otherwise the
        # row count no longer matches the batch slicing below.
        sent_embedding = encoder_outputs.last_hidden_state[encoder_input_ids == self.tokenizer.mask_token_id]
    else:
        raise NotImplementedError()
    sent_embedding = sent_embedding.unsqueeze(1)
    sent_embedding = self.sent_embedding_projector(sent_embedding)

    if self.config.do_generative:
        # The decoder reconstructs the target token ids from the detached
        # first hidden state (presumably the embedding-layer output), using
        # the anchor sentence embeddings as memory.
        if positive_input_ids is not None:
            tgt = encoder_outputs.hidden_states[0][batch_size:2 * batch_size].detach()
            labels = positive_input_ids.to(self.device)
        else:
            tgt = encoder_outputs.hidden_states[0][:batch_size].detach()
            labels = input_ids.to(self.device)
        # Built from the device-resident labels so the mask matches tgt/memory;
        # the raw inputs are only moved to self.device inside the concatenation above.
        tgt_key_padding_mask = (labels == self.tokenizer.pad_token_id)
        tgt = self.decoder_noise_dropout(tgt)
        decoder_outputs = self.decoder(tgt=tgt, memory=sent_embedding[:batch_size], tgt_mask=None, tgt_key_padding_mask=tgt_key_padding_mask)
        logits = self.prediction_head(decoder_outputs)
        loss_fct = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
        generative_loss = loss_fct(logits.view(-1, self.encoder.config.vocab_size), labels.view(-1))
        wandb.log({'train/generative_loss': generative_loss.item()})

    if self.config.do_contrastive:
        # Anchor-vs-positive similarities; row i's positive is column i.
        positive_sim = self.sim(sent_embedding[:batch_size], sent_embedding[batch_size:2 * batch_size].transpose(0, 1))
        cos_sim = positive_sim
        # Negatives are in the encoder batch only when negative_input_ids was
        # given, so key this branch on the same condition as the concatenation.
        if negative_input_ids is not None:
            negative_sim = self.sim(sent_embedding[:batch_size], sent_embedding[2 * batch_size:].transpose(0, 1))
            cos_sim = torch.cat([positive_sim, negative_sim], dim=1)
        cos_sim = cos_sim / self.config.contrastive_temp
        contrastive_labels = torch.arange(batch_size, dtype=torch.long, device=self.device)
        contrastive_loss = nn.CrossEntropyLoss()(cos_sim, contrastive_labels)
        wandb.log({'train/contrastive_loss': contrastive_loss.item()})
    # Generative logits are intentionally not returned.
    logits = None
    loss = 0
    if self.config.do_contrastive:
        loss += self.config.contrastive_weight * contrastive_loss
    if self.config.do_generative:
        loss += self.config.generative_weight * generative_loss
    # float() handles both a scalar tensor and the plain int 0 left when no
    # objective is enabled, and logs a plain number like the losses above.
    wandb.log({'train/loss': float(loss)})
    return TokenClassifierOutput(
        loss=loss,
        logits=logits,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
    )
158 |
--------------------------------------------------------------------------------
/SentEval/senteval/sick.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | '''
9 | SICK Relatedness and Entailment
10 | '''
11 | from __future__ import absolute_import, division, unicode_literals
12 |
13 | import os
14 | import io
15 | import logging
16 | import numpy as np
17 |
18 | from sklearn.metrics import mean_squared_error
19 | from scipy.stats import pearsonr, spearmanr
20 |
21 | from senteval.tools.relatedness import RelatednessPytorch
22 | from senteval.tools.validation import SplitClassifier
23 |
class SICKEval(object):
    """SICK-Relatedness transfer task.

    Loads the SICK train / trial / test files and evaluates sentence
    embeddings by fitting ``RelatednessPytorch`` on the pair features
    ``[|u - v|, u * v]`` to predict the relatedness score.
    """

    def __init__(self, task_path, seed=1111):
        logging.debug('***** Transfer task : SICK-Relatedness*****\n\n')
        self.seed = seed
        train = self.loadFile(os.path.join(task_path, 'SICK_train.txt'))
        dev = self.loadFile(os.path.join(task_path, 'SICK_trial.txt'))
        test = self.loadFile(os.path.join(task_path, 'SICK_test_annotated.txt'))
        self.sick_data = {'train': train, 'dev': dev, 'test': test}

    def do_prepare(self, params, prepare):
        """Call ``prepare(params, samples)`` with every tokenized sentence.

        Samples are ordered train A, train B, dev A, dev B, test A, test B.
        """
        samples = []
        for split in ('train', 'dev', 'test'):
            samples += self.sick_data[split]['X_A'] + self.sick_data[split]['X_B']
        return prepare(params, samples)

    def loadFile(self, fpath):
        """Read a tab-separated SICK file.

        The first line is a header and is skipped; blank lines are ignored.
        Columns 1 and 2 are the two sentences (whitespace-tokenized) and
        column 3 is the relatedness score, parsed as float.

        Returns:
            dict with keys 'X_A', 'X_B' (lists of token lists) and 'y'
            (list of floats).
        """
        sick_data = {'X_A': [], 'X_B': [], 'y': []}
        with io.open(fpath, 'r', encoding='utf-8') as f:
            next(f, None)  # skip the header line
            for line in f:
                line = line.strip()
                if not line:
                    # A trailing/blank line would otherwise raise IndexError.
                    continue
                text = line.split('\t')
                sick_data['X_A'].append(text[1].split())
                sick_data['X_B'].append(text[2].split())
                sick_data['y'].append(float(text[3]))
        return sick_data

    def run(self, params, batcher):
        """Embed all splits with ``batcher``, train the regressor, and score.

        Note: each split in ``self.sick_data`` is re-sorted in place (by
        sentence lengths, then score) so that labels stay aligned with the
        embeddings computed in that order.

        Returns:
            dict with dev Spearman, test Pearson/Spearman/MSE, the test
            predictions ``yhat`` and the dev/test sizes.
        """
        sick_embed = {'train': {}, 'dev': {}, 'test': {}}
        bsize = params.batch_size

        for key in self.sick_data:
            logging.info('Computing embedding for {0}'.format(key))
            # Sort to reduce padding
            sorted_corpus = sorted(zip(self.sick_data[key]['X_A'],
                                       self.sick_data[key]['X_B'],
                                       self.sick_data[key]['y']),
                                   key=lambda z: (len(z[0]), len(z[1]), z[2]))

            self.sick_data[key]['X_A'] = [x for (x, y, z) in sorted_corpus]
            self.sick_data[key]['X_B'] = [y for (x, y, z) in sorted_corpus]
            self.sick_data[key]['y'] = [z for (x, y, z) in sorted_corpus]

            n_pairs = len(self.sick_data[key]['y'])
            for txt_type in ['X_A', 'X_B']:
                sents = self.sick_data[key][txt_type]
                sick_embed[key][txt_type] = np.vstack(
                    [batcher(params, sents[ii:ii + bsize])
                     for ii in range(0, n_pairs, bsize)])
            logging.info('Computed {0} embeddings'.format(key))

        # Train
        trainA = sick_embed['train']['X_A']
        trainB = sick_embed['train']['X_B']
        trainF = np.c_[np.abs(trainA - trainB), trainA * trainB]
        trainY = self.encode_labels(self.sick_data['train']['y'])

        # Dev
        devA = sick_embed['dev']['X_A']
        devB = sick_embed['dev']['X_B']
        devF = np.c_[np.abs(devA - devB), devA * devB]
        devY = self.encode_labels(self.sick_data['dev']['y'])

        # Test
        testA = sick_embed['test']['X_A']
        testB = sick_embed['test']['X_B']
        testF = np.c_[np.abs(testA - testB), testA * testB]
        testY = self.encode_labels(self.sick_data['test']['y'])

        config = {'seed': self.seed, 'nclasses': 5}
        clf = RelatednessPytorch(train={'X': trainF, 'y': trainY},
                                 valid={'X': devF, 'y': devY},
                                 test={'X': testF, 'y': testY},
                                 devscores=self.sick_data['dev']['y'],
                                 config=config)

        devspr, yhat = clf.run()

        pr = pearsonr(yhat, self.sick_data['test']['y'])[0]
        sr = spearmanr(yhat, self.sick_data['test']['y'])[0]
        # Correlations are NaN for constant predictions; report 0 instead.
        if np.isnan(pr):
            pr = 0
        if np.isnan(sr):
            sr = 0
        se = mean_squared_error(yhat, self.sick_data['test']['y'])
        logging.debug('Dev : Spearman {0}'.format(devspr))
        # Adjacent literals (not a backslash continuation) so the source
        # indentation is not embedded in the message.
        logging.debug('Test : Pearson {0} Spearman {1} MSE {2} '
                      'for SICK Relatedness\n'.format(pr, sr, se))

        return {'devspearman': devspr, 'pearson': pr, 'spearman': sr, 'mse': se,
                'yhat': yhat, 'ndev': len(devA), 'ntest': len(testA)}

    def encode_labels(self, labels, nclass=5):
        """
        Label encoding from Tree LSTM paper (Tai, Socher, Manning)

        Class index i stands for score i + 1. A score y is spread over its
        floor and the next class: index floor(y) - 1 gets 1 - frac(y), and
        index floor(y) gets frac(y). Returns a float32 array of shape
        (len(labels), nclass).
        """
        Y = np.zeros((len(labels), nclass), dtype='float32')
        for j, y in enumerate(labels):
            floor_y = np.floor(y)
            for i in range(nclass):
                if i + 1 == floor_y + 1:
                    Y[j, i] = y - floor_y
                if i + 1 == floor_y:
                    Y[j, i] = floor_y - y + 1
        return Y
134 |
135 |
class SICKEntailmentEval(SICKEval):
    """SICK-Entailment transfer task.

    Same files as SICK-Relatedness, but the label is the entailment
    judgment (column 4), mapped to CONTRADICTION=0, NEUTRAL=1, ENTAILMENT=2.
    A 3-way ``SplitClassifier`` is trained on pair features ``[|u - v|, u * v]``.
    """

    def __init__(self, task_path, seed=1111):
        logging.debug('***** Transfer task : SICK-Entailment*****\n\n')
        self.seed = seed
        train = self.loadFile(os.path.join(task_path, 'SICK_train.txt'))
        dev = self.loadFile(os.path.join(task_path, 'SICK_trial.txt'))
        test = self.loadFile(os.path.join(task_path, 'SICK_test_annotated.txt'))
        self.sick_data = {'train': train, 'dev': dev, 'test': test}

    def loadFile(self, fpath):
        """Read a tab-separated SICK file with entailment labels.

        The first line is a header and is skipped; blank lines are ignored.
        Columns 1 and 2 are the sentences (whitespace-tokenized) and column 4
        the entailment judgment. An unknown judgment raises KeyError.

        Returns:
            dict with keys 'X_A', 'X_B' (lists of token lists) and 'y'
            (list of int class ids).
        """
        label2id = {'CONTRADICTION': 0, 'NEUTRAL': 1, 'ENTAILMENT': 2}
        sick_data = {'X_A': [], 'X_B': [], 'y': []}
        with io.open(fpath, 'r', encoding='utf-8') as f:
            next(f, None)  # skip the header line
            for line in f:
                line = line.strip()
                if not line:
                    # A trailing/blank line would otherwise raise IndexError.
                    continue
                text = line.split('\t')
                sick_data['X_A'].append(text[1].split())
                sick_data['X_B'].append(text[2].split())
                sick_data['y'].append(label2id[text[4]])
        return sick_data

    def run(self, params, batcher):
        """Embed all splits with ``batcher``, train the classifier, and score.

        Note: each split in ``self.sick_data`` is re-sorted in place so
        labels stay aligned with the embeddings computed in that order.

        Returns:
            dict with dev/test accuracy and the dev/test sizes.
        """
        sick_embed = {'train': {}, 'dev': {}, 'test': {}}
        bsize = params.batch_size

        for key in self.sick_data:
            logging.info('Computing embedding for {0}'.format(key))
            # Sort to reduce padding
            sorted_corpus = sorted(zip(self.sick_data[key]['X_A'],
                                       self.sick_data[key]['X_B'],
                                       self.sick_data[key]['y']),
                                   key=lambda z: (len(z[0]), len(z[1]), z[2]))

            self.sick_data[key]['X_A'] = [x for (x, y, z) in sorted_corpus]
            self.sick_data[key]['X_B'] = [y for (x, y, z) in sorted_corpus]
            self.sick_data[key]['y'] = [z for (x, y, z) in sorted_corpus]

            n_pairs = len(self.sick_data[key]['y'])
            for txt_type in ['X_A', 'X_B']:
                sents = self.sick_data[key][txt_type]
                sick_embed[key][txt_type] = np.vstack(
                    [batcher(params, sents[ii:ii + bsize])
                     for ii in range(0, n_pairs, bsize)])
            logging.info('Computed {0} embeddings'.format(key))

        # Train
        trainA = sick_embed['train']['X_A']
        trainB = sick_embed['train']['X_B']
        trainF = np.c_[np.abs(trainA - trainB), trainA * trainB]
        trainY = np.array(self.sick_data['train']['y'])

        # Dev
        devA = sick_embed['dev']['X_A']
        devB = sick_embed['dev']['X_B']
        devF = np.c_[np.abs(devA - devB), devA * devB]
        devY = np.array(self.sick_data['dev']['y'])

        # Test
        testA = sick_embed['test']['X_A']
        testB = sick_embed['test']['X_B']
        testF = np.c_[np.abs(testA - testB), testA * testB]
        testY = np.array(self.sick_data['test']['y'])

        config = {'nclasses': 3, 'seed': self.seed,
                  'usepytorch': params.usepytorch,
                  'classifier': params.classifier,
                  'nhid': params.nhid}
        clf = SplitClassifier(X={'train': trainF, 'valid': devF, 'test': testF},
                              y={'train': trainY, 'valid': devY, 'test': testY},
                              config=config)

        devacc, testacc = clf.run()
        # Adjacent literals (not a backslash continuation) so the source
        # indentation is not embedded in the message.
        logging.debug('\nDev acc : {0} Test acc : {1} for '
                      'SICK entailment\n'.format(devacc, testacc))
        return {'devacc': devacc, 'acc': testacc,
                'ndev': len(devA), 'ntest': len(testA)}
217 |
--------------------------------------------------------------------------------
/SentEval/examples/models.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | """
9 | This file contains the definition of encoders used in https://arxiv.org/pdf/1705.02364.pdf
10 | """
11 |
12 | import numpy as np
13 | import time
14 |
15 | import torch
16 | import torch.nn as nn
17 |
18 |
19 | class InferSent(nn.Module):
20 |
21 | def __init__(self, config):
22 | super(InferSent, self).__init__()
23 | self.bsize = config['bsize']
24 | self.word_emb_dim = config['word_emb_dim']
25 | self.enc_lstm_dim = config['enc_lstm_dim']
26 | self.pool_type = config['pool_type']
27 | self.dpout_model = config['dpout_model']
28 | self.version = 1 if 'version' not in config else config['version']
29 |
30 | self.enc_lstm = nn.LSTM(self.word_emb_dim, self.enc_lstm_dim, 1,
31 | bidirectional=True, dropout=self.dpout_model)
32 |
33 | assert self.version in [1, 2]
34 | if self.version == 1:
35 | self.bos = ' '
41 | self.eos = ''] = 1e9 + 4
41 | words[''] = 1e9 + 3
42 | words[''
36 | self.eos = ''
37 | self.max_pad = True
38 | self.moses_tok = False
39 | elif self.version == 2:
40 | self.bos = '