├── .gitignore
├── CHANGELOG.md
├── README.md
├── build_data.sh
├── convert_data_format.py
├── model_card.md
├── requirements.txt
├── semsearch.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
data/
model/
sbert-base-ja/
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
# Changelog

## [Unreleased]

## [v1.0] - 2021-08-01

### Added

- models and model card
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Sentence BERT Japanese

This repository contains the training script for Sentence BERT Japanese models.

## Prepare environment

```sh
$ docker container run --gpus all --ipc=host --rm -it -v $(pwd):/work -w /work nvidia/cuda:11.1-devel-ubuntu20.04 bash
(container)$ apt update && apt install -y python3 python3-pip git wget zip
(container)$ pip3 install torch==1.8.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
(container)$ pip3 install -r requirements.txt
```

## Data

This model uses the [Japanese SNLI (JSNLI) dataset](https://nlp.ist.i.kyoto-u.ac.jp/index.php?%E6%97%A5%E6%9C%AC%E8%AA%9ESNLI%28JSNLI%29%E3%83%87%E3%83%BC%E3%82%BF%E3%82%BB%E3%83%83%E3%83%88) released under CC BY-SA 4.0.

```sh
$ bash build_data.sh
```

Check the sha1sum of the downloaded archive.

```sh
$ sha1sum data/JSNLI.zip
d6c9b45e8e6df03959f38cfbb58c31a747d6d12f data/JSNLI.zip
```

The {train,val,test}.jsonl files are prepared under the `data` directory.

## Train

```sh
$ python3 train.py --base_model colorfulscoop/bert-base-ja --output_model model --train_data data/train.jsonl --valid_data data/val.jsonl --test_data data/test.jsonl --epochs 1 --evaluation_steps=5000 --batch_size 8 --seed 1000 --use_amp
```

## Example usage

```sh
$ python3 semsearch.py --model model
======
Query: 走るのが趣味です
0.9029 外をランニングするのが好きです
0.7534 運動はそこそこです
0.5894 走るのは嫌いです
0.5451 天ぷらが食べたい
0.5335 りんごが食べたい
0.4970 海外旅行に行きたい
0.4268 揚げ物は食べたくない
======
Query: 外国を旅したい
0.9073 海外旅行に行きたい
0.7153 運動はそこそこです
0.6544 外をランニングするのが好きです
0.5313 天ぷらが食べたい
0.4653 りんごが食べたい
0.4413 揚げ物は食べたくない
0.4154 走るのは嫌いです
======
Query: 揚げ物が食べたい
0.9118 天ぷらが食べたい
0.7990 りんごが食べたい
0.6382 運動はそこそこです
0.5176 海外旅行に行きたい
0.5028 揚げ物は食べたくない
0.4898 外をランニングするのが好きです
0.4168 走るのは嫌いです
```
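Each score in the output above is the cosine similarity between the query embedding and a sample embedding, as computed by `semsearch.py` (included below in this repository). A minimal standalone sketch of the same computation, assuming `model` is the output directory produced by the training step:

```py
from sentence_transformers import SentenceTransformer, util

# Assumes "model" is the directory written by train.py above.
model = SentenceTransformer("model")
query = model.encode("走るのが趣味です", convert_to_tensor=True)
sample = model.encode("外をランニングするのが好きです", convert_to_tensor=True)
print(util.pytorch_cos_sim(query, sample).item())  # ~0.90 in the run shown above
```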
## Upload to Hugging Face Model Hub

Finally, upload the trained model to the Hugging Face Model Hub by following the official documentation.

First, create a repository named "sbert-base-ja" on the Hugging Face website.

Then, set up Git LFS. On macOS, it can be installed as follows.

```sh
$ brew install git-lfs
$ git lfs install
Updated git hooks.
Git LFS initialized.
```

Then clone the repository locally.

```sh
$ git clone https://huggingface.co/colorfulscoop/sbert-base-ja
```

Copy the model, excluding the evaluation results.

```sh
$ cp -r model/* sbert-base-ja
$ rm -r sbert-base-ja/eval
```

Copy the model card and changelog files.

```sh
$ cp model_card.md sbert-base-ja/README.md
$ cp CHANGELOG.md sbert-base-ja
```

Finally, commit the files and push them to the Model Hub.

```sh
$ cd sbert-base-ja
$ git add .
$ git commit -m "Add models and model card"
$ git push origin
```
--------------------------------------------------------------------------------
/build_data.sh:
--------------------------------------------------------------------------------
set -eu

if [ ! -e data ]; then
    mkdir data
fi

cd data
if [ ! -e JSNLI.zip ]; then
    wget -O JSNLI.zip https://nlp.ist.i.kyoto-u.ac.jp/DLcounter/lime.cgi\?down\=https://nlp.ist.i.kyoto-u.ac.jp/nl-resource/JSNLI/jsnli_1.1.zip\&name\=JSNLI.zip
    unzip JSNLI.zip
fi

# Convert to jsonl format
cat jsnli_1.1/train_w_filtering.tsv | python3 ../convert_data_format.py >train_orig.jsonl
cat jsnli_1.1/dev.tsv | python3 ../convert_data_format.py >test.jsonl

# Split train/val data (523,005 train lines + 10,000 val lines)
cat train_orig.jsonl | head -n 523005 >train.jsonl
cat train_orig.jsonl | tail -n 10000 >val.jsonl
--------------------------------------------------------------------------------
/convert_data_format.py:
--------------------------------------------------------------------------------
import sys
import json


def format_sentence(xs):
    # JSNLI sentences are whitespace-tokenized; remove the spaces
    # to restore plain Japanese text.
    return xs.replace(" ", "")


def main():
    # Read TSV lines of (label, premise, hypothesis) from stdin and
    # print one JSON object per line.
    for line in sys.stdin:
        label, premise, hypothesis = line.strip("\n").split("\t")
        json_out = json.dumps(
            {"label": label,
             "premise": format_sentence(premise),
             "hypothesis": format_sentence(hypothesis)},
            ensure_ascii=False
        )
        print(json_out)


if __name__ == "__main__":
    main()
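For illustration, the converter turns one whitespace-tokenized TSV line into one JSON line as follows. The input line here is made up; real input comes from the JSNLI files downloaded by `build_data.sh`.

```sh
$ printf 'entailment\t外 を 走る\tランニング する\n' | python3 convert_data_format.py
{"label": "entailment", "premise": "外を走る", "hypothesis": "ランニングする"}
```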
--------------------------------------------------------------------------------
/model_card.md:
--------------------------------------------------------------------------------
---
language: ja
pipeline_tag: sentence-similarity
tags:
- sentence-transformers
- feature-extraction
- sentence-similarity
widget:
  source_sentence: "走るのが趣味です"
  sentences:
  - 外をランニングするのが好きです
  - 運動はそこそこです
  - 走るのは嫌いです
license: cc-by-sa-4.0
---

# Sentence BERT base Japanese model

This repository contains a Sentence BERT base model for Japanese.

## Pretrained model

This model uses the Japanese BERT model [colorfulscoop/bert-base-ja](https://huggingface.co/colorfulscoop/bert-base-ja) v1.0, released under [Creative Commons Attribution-ShareAlike 3.0](https://creativecommons.org/licenses/by-sa/3.0/), as its pretrained model.

## Training data

The [Japanese SNLI dataset](https://nlp.ist.i.kyoto-u.ac.jp/index.php?%E6%97%A5%E6%9C%AC%E8%AA%9ESNLI%28JSNLI%29%E3%83%87%E3%83%BC%E3%82%BF%E3%82%BB%E3%83%83%E3%83%88) released under [Creative Commons Attribution-ShareAlike 4.0](https://creativecommons.org/licenses/by-sa/4.0/) is used for training.

The original training set is split into train and validation sets, so the following splits are prepared.

* Train data: 523,005 samples
* Valid data: 10,000 samples
* Test data: 3,916 samples

## Model description

This model uses the `SentenceTransformer` model from the [sentence-transformers](https://github.com/UKPLab/sentence-transformers) library. The model detail is as follows.

```py
>>> from sentence_transformers import SentenceTransformer
>>> SentenceTransformer("colorfulscoop/sbert-base-ja")
SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
)
```
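As the `pooling_mode_mean_tokens` flag shows, the `Pooling` module mean-pools token embeddings into a single 768-dimensional sentence vector. A minimal sketch of that computation (the `mean_pool` helper and the random tensors are illustrative, not part of this repository):

```py
import torch


def mean_pool(token_embeddings, attention_mask):
    # Average token embeddings over non-padding positions.
    # token_embeddings: (batch, seq_len, 768) from the BERT encoder
    # attention_mask:   (batch, seq_len), 1 for real tokens, 0 for padding
    mask = attention_mask.unsqueeze(-1).float()  # (batch, seq_len, 1)
    return (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)


# Random tensors standing in for encoder output.
embeddings = torch.randn(2, 16, 768)
mask = torch.ones(2, 16, dtype=torch.long)
print(mean_pool(embeddings, mask).shape)  # torch.Size([2, 768])
```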
## Training

This model fine-tunes [colorfulscoop/bert-base-ja](https://huggingface.co/colorfulscoop/bert-base-ja) with a softmax classifier over the three SNLI labels. The AdamW optimizer was used with a learning rate of 2e-05, linearly warmed up over the first 10% of the training data. The model was trained for 1 epoch with a batch size of 8.

Note: in the original [Sentence BERT](https://arxiv.org/abs/1908.10084) paper, the model trained on SNLI and Multi-Genre NLI used a batch size of 16. The dataset used here is around half the size of the original one, so the batch size was set to 8, half of the original 16.

Training was conducted on Ubuntu 18.04.5 LTS with one RTX 2080 Ti.

After training, the test set accuracy reached 0.8529.

The training code is available in [a GitHub repository](https://github.com/colorfulscoop/sbert-ja).

## Usage

First, install the dependencies.

```sh
$ pip install sentence-transformers==2.0.0
```

Then initialize a `SentenceTransformer` model and use the `encode` method to convert sentences to vectors.

```py
>>> from sentence_transformers import SentenceTransformer
>>> model = SentenceTransformer("colorfulscoop/sbert-base-ja")
>>> sentences = ["外をランニングするのが好きです", "海外旅行に行くのが趣味です"]
>>> model.encode(sentences)
```
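The returned vectors can be compared with cosine similarity. A short follow-up sketch continuing the session above (`util.pytorch_cos_sim` is the similarity helper shipped with sentence-transformers 2.0.0):

```py
>>> from sentence_transformers import util
>>> embeddings = model.encode(sentences, convert_to_tensor=True)
>>> util.pytorch_cos_sim(embeddings[0], embeddings[1])
```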
## License

Copyright (c) 2021 Colorful Scoop

All the models included in this repository are licensed under [Creative Commons Attribution-ShareAlike 4.0](https://creativecommons.org/licenses/by-sa/4.0/).

**Disclaimer:** Use of this model is at your sole risk. Colorful Scoop makes no warranty or guarantee of any outputs from the model. Colorful Scoop is not liable for any trouble, loss, or damage arising from the model output.

---

This model uses the following pretrained model.

* **Name:** bert-base-ja
* **Credit:** (c) 2021 Colorful Scoop
* **License:** [Creative Commons Attribution-ShareAlike 3.0](https://creativecommons.org/licenses/by-sa/3.0/)
* **Disclaimer:** The model may generate text similar to its training data, untrue text, or biased text. Use of the model is at your sole risk. Colorful Scoop makes no warranty or guarantee of any outputs from the model. Colorful Scoop is not liable for any trouble, loss, or damage arising from the model output.
* **Link:** https://huggingface.co/colorfulscoop/bert-base-ja

---

This model uses the following data for fine-tuning.

* **Name:** 日本語SNLI(JSNLI)データセット
* **Credit:** [https://nlp.ist.i.kyoto-u.ac.jp/index.php?日本語SNLI(JSNLI)データセット](https://nlp.ist.i.kyoto-u.ac.jp/index.php?%E6%97%A5%E6%9C%AC%E8%AA%9ESNLI%28JSNLI%29%E3%83%87%E3%83%BC%E3%82%BF%E3%82%BB%E3%83%83%E3%83%88)
* **License:** [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/)
* **Link:** [https://nlp.ist.i.kyoto-u.ac.jp/index.php?日本語SNLI(JSNLI)データセット](https://nlp.ist.i.kyoto-u.ac.jp/index.php?%E6%97%A5%E6%9C%AC%E8%AA%9ESNLI%28JSNLI%29%E3%83%87%E3%83%BC%E3%82%BF%E3%82%BB%E3%83%83%E3%83%88)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
sentence-transformers==2.0.0
fire==0.4.0
--------------------------------------------------------------------------------
/semsearch.py:
--------------------------------------------------------------------------------
from sentence_transformers import SentenceTransformer, util
import torch


def main(model, top_k=None):
    samples = [
        "外をランニングするのが好きです",
        "走るのは嫌いです",
        "運動はそこそこです",
        "海外旅行に行きたい",
        "天ぷらが食べたい",
        "りんごが食べたい",
        "揚げ物は食べたくない",
    ]
    model = SentenceTransformer(model)
    samples_embedding = model.encode(samples, convert_to_tensor=True)

    # Query sentences
    queries = [
        "走るのが趣味です",
        "外国を旅したい",
        "揚げ物が食べたい",
    ]

    # Show all samples unless top_k is given.
    if top_k:
        top_k = min(top_k, len(samples))
    else:
        top_k = len(samples)
    for query in queries:
        query_embedding = model.encode(query, convert_to_tensor=True)

        # Rank samples by cosine similarity to the query.
        scores = util.pytorch_cos_sim(query_embedding, samples_embedding)[0]
        result = torch.topk(scores, k=top_k)

        print("======")
        print(f"Query: {query}")
        for score, idx in zip(*result):
            print(f"{score.item():.4f} {samples[idx]}")


if __name__ == "__main__":
    import fire

    fire.Fire(main)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
"""
This code trains a Sentence BERT model on an NLI dataset.
"""

from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation
from torch.utils.data import DataLoader
import math
import json
import os
import transformers
import logging


def load_samples(jsonl_file, label_mapper):
    # Load premise/hypothesis pairs with integer labels from a jsonl file.
    samples = []
    for line in open(jsonl_file):
        item = json.loads(line)
        sample = InputExample(texts=[item["premise"], item["hypothesis"]], label=label_mapper[item["label"]])
        samples.append(sample)
    return samples


def main(
    base_model, output_model, train_data, valid_data, test_data,
    epochs=1, evaluation_steps=1000, batch_size=8, seed=None,
    use_amp=False,
):
    logging.basicConfig(level=logging.INFO)

    if seed:
        transformers.trainer_utils.set_seed(seed)

    # Prepare model
    model = SentenceTransformer(base_model)

    # Prepare data
    label_mapper = {
        "contradiction": 0,
        "entailment": 1,
        "neutral": 2,
    }

    train_samples = load_samples(train_data, label_mapper)
    valid_samples = load_samples(valid_data, label_mapper)
    test_samples = load_samples(test_data, label_mapper)

    train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=batch_size)
    valid_dataloader = DataLoader(valid_samples, shuffle=False, batch_size=batch_size)
    test_dataloader = DataLoader(test_samples, shuffle=False, batch_size=batch_size)
    loss = losses.SoftmaxLoss(
        model=model,
        sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
        num_labels=len(label_mapper)
    )
    # See https://github.com/UKPLab/sentence-transformers/issues/27 about how to use LabelAccuracyEvaluator
    evaluator = evaluation.LabelAccuracyEvaluator(valid_dataloader, softmax_model=loss, name="val")
    # Linear warmup over the first 10% of training steps.
    warmup_steps = math.ceil(len(train_dataloader) * 0.1)

    model.fit(
        train_objectives=[(train_dataloader, loss)],
        evaluator=evaluator,
        epochs=epochs,
        evaluation_steps=evaluation_steps,
        warmup_steps=warmup_steps,
        output_path=output_model,
        use_amp=use_amp,
    )

    # Evaluate the saved model on the test set.
    test_model = SentenceTransformer(output_model)
    test_model.to(model.device)
    loss.model = test_model
    test_evaluator = evaluation.LabelAccuracyEvaluator(test_dataloader, softmax_model=loss, name="test")
    test_evaluator(test_model, output_path=os.path.join(output_model, "eval"))


if __name__ == "__main__":
    import fire

    fire.Fire(main)
--------------------------------------------------------------------------------