├── requirements.txt
├── setup.py
├── simalign
│   ├── __init__.py
│   ├── utils.py
│   └── simalign.py
├── assets
│   └── example.png
├── pyproject.toml
├── scripts
│   ├── align_example.py
│   ├── calc_align_score.py
│   ├── visualize.py
│   └── align_files.py
├── samples
│   ├── sample_eng.txt
│   ├── sample_deu.txt
│   └── sample_eng_deu.gold
├── setup.cfg
├── LICENSE
└── README.md
/requirements.txt: -------------------------------------------------------------------------------- 1 | . -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup() 4 | -------------------------------------------------------------------------------- /simalign/__init__.py: -------------------------------------------------------------------------------- 1 | from .simalign import EmbeddingLoader, SentenceAligner 2 | -------------------------------------------------------------------------------- /assets/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cisnlp/simalign/HEAD/assets/example.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42", 4 | "wheel" 5 | ] 6 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /scripts/align_example.py: -------------------------------------------------------------------------------- 1 | import simalign 2 | 3 | source_sentence = "Sir Nils Olav III. was knighted by the norwegian king ." 4 | target_sentence = "Nils Olav der Dritte wurde vom norwegischen König zum Ritter geschlagen ." 5 | 6 | model = simalign.SentenceAligner() 7 | result = model.get_word_aligns(source_sentence, target_sentence) 8 | print(result) 9 | -------------------------------------------------------------------------------- /samples/sample_eng.txt: -------------------------------------------------------------------------------- 1 | 0 We do not believe that we should cherry-pick . 2 | 1 But this is not what happens . 3 | 2 Of course , if a drug addict becomes a pusher , then it is right and necessary that he should pay and answer before the law also . 4 | 3 Commissioner , ladies and gentlemen , I should like to begin by thanking Mr Burtone for his report . 5 | 4 ' Legal drugs ' ( tranquillizers ) finding their way on to an illegal market , especially when used in combination with alcohol , are a major and serious problem particularly for young people . 6 | -------------------------------------------------------------------------------- /samples/sample_deu.txt: -------------------------------------------------------------------------------- 1 | 0 Wir glauben nicht , daß wir nur Rosinen herauspicken sollten . 2 | 1 Das stimmt nicht ! 3 | 2 Sicher - wenn ein Drogenabhängiger zum Dealer wird , dann ist es richtig und notwendig , daß er dafür auch vor dem Gesetz zur Rechenschaft gezogen wird . 4 | 3 Herr Kommissar , liebe Kolleginnen und Kollegen ! Zunächst herzlichen Dank , Herr Burtone , für Ihren Bericht . 5 | 4 Die in den illegalen Handel gelangten sogenannten legalen Drogen bzw. Beruhigungsmittel sind vor allem in Zusammenhang mit Alkohol ein gravierendes Problem , speziell für Jugendliche .
6 | -------------------------------------------------------------------------------- /samples/sample_eng_deu.gold: -------------------------------------------------------------------------------- 1 | 0 0-0 1-1 2-2 3-1 4-3 4-4 5-5 6-9 7-6 7-7 7-8 8-10 2 | 1 1-0 2-1 3-2 4-1 5-1 6-3 3 | 2 0-0 1-0 10-8 11-9 12-11 13-10 14-12 15-13 16-14 17-16 18-17 19p25 19p26 20p24 21p24 22p24 23-20 24-21 25-22 26-19 27-27 2p1 3-2 4-3 5-4 6-4 7-7 8p5 9-6 4 | 3 0-1 0p0 1-2 10-8 11-8 12-10 12-9 13-12 14-13 15-15 16-16 17-17 18-18 2-4 3-5 4-6 5p7 6-8 7-8 8-8 9-8 5 | 4 0p6 1-7 10-1 11-1 12-2 13-3 14-4 16-12 16-13 17-14 17p15 18-15 18p14 19-15 19p14 2-8 20-15 21-16 22-17 23-21 24-11 25-18 26-19 27-19 28-19 29-20 30-22 31-23 32-24 33-24 34-25 3p6 3p9 4p9 5-10 6p9 7-5 8-5 9-5 6 | -------------------------------------------------------------------------------- /simalign/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Text 3 | import os 4 | 5 | 6 | def get_logger(name: Text, filename: Text = None, level: int = logging.DEBUG) -> logging.Logger: 7 | logger = logging.getLogger(name) 8 | logger.setLevel(level) 9 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 10 | 11 | ch = logging.StreamHandler() 12 | ch.setLevel(level) 13 | ch.setFormatter(formatter) 14 | logger.addHandler(ch) 15 | 16 | if filename is not None: 17 | fh = logging.FileHandler(filename) 18 | fh.setLevel(level) 19 | fh.setFormatter(formatter) 20 | logger.addHandler(fh) 21 | 22 | return logger 23 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = simalign 3 | version = v0.4 4 | author = Masoud Jalili Sabet, Philipp Dufter 5 | author_email = philipp@cis.lmu.de,masoud@cis.lmu.de 6 | description = Word Alignments using Pretrained Language Models 7 | keywords = NLP deep learning transformer pytorch BERT Word Alignment 8 | long_description = file: README.md 9 | long_description_content_type = text/markdown 10 | url = https://github.com/cisnlp/simalign 11 | project_urls = 12 | Bug Tracker = https://github.com/cisnlp/simalign/issues 13 | classifiers = 14 | Programming Language :: Python :: 3 15 | License :: OSI Approved :: MIT License 16 | Operating System :: OS Independent 17 | [options] 18 | packages = simalign 19 | install_requires = 20 | numpy 21 | torch 22 | scipy 23 | transformers 24 | regex 25 | networkx 26 | scikit_learn 27 | python_requires = >=3.6.0 28 | zip_safe = False 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2020-2021] [Masoud Jalili Sabet, Philipp Dufter] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /scripts/calc_align_score.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import collections 4 | import os.path 5 | 6 | 7 | def load_gold(g_path): 8 | gold_f = open(g_path, "r") 9 | pros = {} 10 | surs = {} 11 | all_count = 0. 12 | surs_count = 0. 13 | 14 | for line in gold_f: 15 | line = line.strip().split("\t") 16 | line[1] = line[1].split() 17 | 18 | pros[line[0]] = set([x.replace("p", "-") for x in line[1]]) 19 | surs[line[0]] = set([x for x in line[1] if "p" not in x]) 20 | 21 | all_count += len(pros[line[0]]) 22 | surs_count += len(surs[line[0]]) 23 | gold_f.close() 24 | return pros, surs, surs_count 25 | 26 | def calc_score(input_path, probs, surs, surs_count): 27 | total_hit = 0. 28 | p_hit = 0. 29 | s_hit = 0. 30 | target_f = open(input_path, "r") 31 | 32 | for line in target_f: 33 | line = line.strip().split("\t") 34 | 35 | if line[0] not in probs: continue 36 | if len(line) < 2: continue 37 | line[1] = line[1].split() 38 | if len(line[1][0].split("-")) > 2: 39 | line[1] = ["-".join(x.split("-")[:2]) for x in line[1]] 40 | 41 | p_hit += len(set(line[1]) & set(probs[line[0]])) 42 | s_hit += len(set(line[1]) & set(surs[line[0]])) 43 | total_hit += len(set(line[1])) 44 | target_f.close() 45 | 46 | y_prec = round(p_hit / max(total_hit, 1.), 3) 47 | y_rec = round(s_hit / max(surs_count, 1.), 3) 48 | y_f1 = round(2. * y_prec * y_rec / max((y_prec + y_rec), 0.01), 3) 49 | aer = round(1 - (s_hit + p_hit) / (total_hit + surs_count), 3) 50 | 51 | return y_prec, y_rec, y_f1, aer 52 | 53 | 54 | if __name__ == "__main__": 55 | ''' 56 | Calculate alignment quality scores based on the gold standard. 57 | The output contains Precision, Recall, F1, and AER. 58 | The gold annotated file is selected with "gold_path". 59 | The generated alignment file is selected with "input_path". 60 | Both the gold file and the input file are in the fast_align format, with the sentence number at the start of each line, separated by a TAB.
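For example, a hypothetical gold line for sentence 0 could be "0<TAB>0-0 1-1 2p3": '-' marks sure edges, 'p' marks possible edges.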
61 | 62 | usage: python calc_align_score.py gold_file generated_file 63 | ''' 64 | 65 | parser = argparse.ArgumentParser(description="Calculate alignment quality scores based on the gold standard.", epilog="example: python calc_align_score.py gold_path input_path") 66 | parser.add_argument("gold_path") 67 | parser.add_argument("input_path") 68 | args = parser.parse_args() 69 | 70 | if not os.path.isfile(args.input_path): 71 | print("The input file does not exist:\n", args.input_path) 72 | exit() 73 | 74 | probs, surs, surs_count = load_gold(args.gold_path) 75 | y_prec, y_rec, y_f1, aer = calc_score(args.input_path, probs, surs, surs_count) 76 | 77 | print("Prec: {}\tRec: {}\tF1: {}\tAER: {}".format(y_prec, y_rec, y_f1, aer)) 78 | 79 | -------------------------------------------------------------------------------- /scripts/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from typing import List, Text, Tuple 4 | 5 | 6 | def line2matrix(line: Text, n: int, m: int) -> Tuple[np.ndarray, np.ndarray]: 7 | ''' 8 | converts an alignment given in the format "0-1 3p4 5-6" into alignment matrices 9 | n, m: maximum lengths of the involved sentences (i.e., dimensions of the alignment matrices) 10 | ''' 11 | def convert(i, j): 12 | i, j = int(i), int(j) 13 | if i >= n or j >= m: 14 | raise ValueError("Alignment index out of bounds; error in the gold standard?") 15 | return i, j 16 | possibles = np.zeros((n, m)) 17 | sures = np.zeros((n, m)) 18 | for elem in line.split(" "): 19 | if "p" in elem: 20 | i, j = convert(*elem.split("p")) 21 | possibles[i, j] = 1 22 | elif "-" in elem: 23 | i, j = convert(*elem.split("-")) 24 | possibles[i, j] = 1 25 | sures[i, j] = 1 26 | return sures, possibles 27 | 28 | 29 | def plot_alignments(e: List[Text], 30 | f: List[Text], 31 | sures: np.ndarray, 32 | possibles: np.ndarray, 33 | alignment1: np.ndarray, 34 | alignment2: np.ndarray = None, 35 | title: Text = None, 36 | filename: Text = None, 37 | dpi: int = 150): 38 | shorter = min(len(e), len(f)) 39 | scalefactor = min((4 / shorter), 1) 40 | 41 | groundtruth = 0.75 * sures + 0.4 * possibles 42 | 43 | fig, ax = plt.subplots() 44 | im = ax.imshow(groundtruth, cmap="Greens", vmin=0, vmax=1.5) 45 | 46 | # show all ticks... 47 | ax.set_xticks(np.arange(len(f))) 48 | ax.set_yticks(np.arange(len(e))) 49 | # ... and label them 50 | ax.set_xticklabels(f, fontsize=25 * scalefactor) 51 | ax.set_yticklabels(e, fontsize=25 * scalefactor) 52 | 53 | for edge, spine in ax.spines.items(): 54 | spine.set_visible(False) 55 | 56 | ax.tick_params(top=True, bottom=False, 57 | labeltop=True, labelbottom=False) 58 | 59 | # Rotate the tick labels and set their alignment. 60 | plt.setp(ax.get_xticklabels(), rotation=30, ha="left", 61 | rotation_mode="default") 62 | plt.setp(ax.get_yticklabels(), rotation=0, ha="right", 63 | rotation_mode="anchor") 64 | ax.set_xticks(np.arange(groundtruth.shape[1] + 1) - .5, minor=True) 65 | ax.set_yticks(np.arange(groundtruth.shape[0] + 1) - .5, minor=True) 66 | 67 | # set grid 68 | ax.grid(which="minor", color="black", linestyle='-', linewidth=1) 69 | ax.tick_params(which="minor", bottom=False, left=False) 70 | # Loop over data dimensions and create text annotations.
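# "circle" draws a hollow black circle around a cell (used for alignment1); "roundthing" draws a filled black square (used for alignment2).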
71 | circle = dict(boxstyle="circle,pad=0.3", fc=(0, 0, 0, 0.0), ec="black", lw=3) 72 | roundthing = dict(boxstyle="square,pad=0.3", fc="black", ec=(0, 0, 0, 0.0), lw=2) 73 | 74 | # plot alignments 75 | for i in range(len(e)): 76 | for j in range(len(f)): 77 | if alignment1[i, j] > 0: 78 | t = ax.text(j, i, "x", ha="center", va="center", 79 | size=25 * scalefactor, 80 | bbox=circle, color=(0, 0, 0, 0.0)) 81 | if alignment2 is not None and alignment2[i, j] > 0: 82 | t = ax.text(j, i, "x", ha="center", va="center", 83 | size=12 * scalefactor, 84 | bbox=roundthing, color=(0, 0, 0, 0.0)) 85 | if title: 86 | ax.set_title(title) 87 | fig.tight_layout() 88 | if filename: 89 | plt.savefig(filename, dpi=dpi) 90 | else: 91 | plt.show() 92 | 93 | 94 | if __name__ == '__main__': 95 | line2matrix("0-0 1p1 2-1", 3, 2) 96 | plot_alignments(["Testing", "this", "."], 97 | ["Hier", "wird", "getestet", "."], 98 | np.array([[0, 0, 1, 0], 99 | [0, 0, 0, 0], 100 | [0, 0, 0, 1]]), 101 | np.array([[0, 0, 0, 0], 102 | [1, 0, 0, 0], 103 | [0, 0, 0, 0]]), 104 | np.array([[0, 1, 0, 0], 105 | [0, 0, 0, 0], 106 | [0, 1, 0, 0]]), 107 | np.array([[0, 0, 0, 1], 108 | [0, 0, 0, 0], 109 | [0, 0, 0, 0]]), 110 | "Example") 111 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | SimAlign: Similarity Based Word Aligner 2 | ============== 3 | 4 |

<img src="assets/example.png" alt="Alignment Example">

9 | 10 | SimAlign is a high-quality word alignment tool that uses static and contextualized embeddings and **does not require parallel training data**. 11 | 12 | The following table shows how it compares to popular statistical alignment models: 13 | 14 | | | ENG-CES | ENG-DEU | ENG-FAS | ENG-FRA | ENG-HIN | ENG-RON | 15 | | ---------- | ------- | ------- | ------- | ------- | ------- | ------- | 16 | | fast-align | .78 | .71 | .46 | .84 | .38 | .68 | 17 | | eflomal | .85 | .77 | .63 | .93 | .52 | .72 | 18 | | mBERT-Argmax | .87 | .81 | .67 | .94 | .55 | .65 | 19 | 20 | Shown is F1, the maximum across subword and word level. For more details, see the [Paper](https://arxiv.org/pdf/2004.08728.pdf). 21 | 22 | 23 | Installation and Usage 24 | -------- 25 | 26 | Tested with Python 3.7, Transformers 3.1.0, and Torch 1.5.0. NetworkX 2.4 is optional (required only for the Match algorithm). 27 | For the full list of dependencies, see `setup.cfg`. 28 | For installing Transformers, see [their repo](https://github.com/huggingface/transformers#installation). 29 | 30 | Clone the repo for local use, or install from PyPI 31 | 32 | `pip install simalign` 33 | 34 | or directly with pip from GitHub 35 | 36 | `pip install --upgrade git+https://github.com/cisnlp/simalign.git#egg=simalign` 37 | 38 | 39 | An example of using our code: 40 | ```python 41 | from simalign import SentenceAligner 42 | 43 | # Make an instance of our model. 44 | # You can specify the embedding model and all alignment settings in the constructor. 45 | myaligner = SentenceAligner(model="bert", token_type="bpe", matching_methods="mai") 46 | 47 | # The source and target sentences should be tokenized into words. 48 | src_sentence = ["This", "is", "a", "test", "."] 49 | trg_sentence = ["Das", "ist", "ein", "Test", "."] 50 | 51 | # The output is a dictionary with different matching methods. 52 | # Each method has a list of pairs indicating the indices of aligned words (the alignments are zero-indexed). 53 | alignments = myaligner.get_word_aligns(src_sentence, trg_sentence) 54 | 55 | for matching_method in alignments: 56 | print(matching_method, ":", alignments[matching_method]) 57 | 58 | # Expected output: 59 | # mwmf (Match): [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)] 60 | # inter (ArgMax): [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)] 61 | # itermax (IterMax): [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)] 62 | ``` 63 | For more examples of how to use our code, see `scripts/align_example.py`. 64 | 65 | Demo 66 | -------- 67 | 68 | An online demo is available [here](https://simalign.cis.lmu.de/). 69 | 70 | 71 | Gold Standards 72 | -------- 73 | Links to the gold standards used in the paper: 74 | 75 | 76 | | Language Pair | Citation | Type | Link | 77 | | ------------- | ------------- | ------------- | ------------- | 78 | | ENG-CES | Marecek et al. 2008 | Gold Alignment | http://ufal.mff.cuni.cz/czech-english-manual-word-alignment | 79 | | ENG-DEU | EuroParl-based | Gold Alignment | www-i6.informatik.rwth-aachen.de/goldAlignment/ | 80 | | ENG-FAS | Tavakoli et al. 2014 | Gold Alignment | http://eceold.ut.ac.ir/en/node/940 | 81 | | ENG-FRA | WPT2003, Och et al. 2000 | Gold Alignment | http://web.eecs.umich.edu/~mihalcea/wpt/ | 82 | | ENG-HIN | WPT2005 | Gold Alignment | http://web.eecs.umich.edu/~mihalcea/wpt05/ | 83 | | ENG-RON | WPT2005, Mihalcea et al. 2003 | Gold Alignment | http://web.eecs.umich.edu/~mihalcea/wpt05/ | 84 | 85 | 86 | Evaluation Script 87 | -------- 88 | To evaluate the output alignments, use `scripts/calc_align_score.py`.
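For example, a hypothetical run (here `align_out.inter` is simply the default output of `scripts/align_files.py` for the ArgMax method):

```
python scripts/calc_align_score.py samples/sample_eng_deu.gold align_out.inter
```

It prints precision, recall, F1, and AER on one line.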
89 | 90 | The gold alignment file should have the same format as SimAlign's outputs. 91 | Sure alignment edges in the gold standard have a '-' between the source and target indices; possible edges have a 'p' between the indices. 92 | For sample parallel sentences and their gold alignments from ENG-DEU, see `samples`. 93 | 94 | 95 | Publication 96 | -------- 97 | 98 | If you use the code, please cite 99 | 100 | ``` 101 | @inproceedings{jalili-sabet-etal-2020-simalign, 102 | title = "{S}im{A}lign: High Quality Word Alignments without Parallel Training Data using Static and Contextualized Embeddings", 103 | author = {Jalili Sabet, Masoud and 104 | Dufter, Philipp and 105 | Yvon, Fran{\c{c}}ois and 106 | Sch{\"u}tze, Hinrich}, 107 | booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: Findings", 108 | month = nov, 109 | year = "2020", 110 | address = "Online", 111 | publisher = "Association for Computational Linguistics", 112 | url = "https://www.aclweb.org/anthology/2020.findings-emnlp.147", 113 | pages = "1627--1643", 114 | } 115 | ``` 116 | 117 | Feedback 118 | -------- 119 | 120 | Feedback and contributions are more than welcome! Just reach out to @masoudjs or @pdufter. 121 | 122 | 123 | FAQ 124 | -------- 125 | 126 | ##### Do I need parallel data to train the system? 127 | 128 | No, no parallel training data is required. 129 | 130 | ##### Which languages can be aligned? 131 | 132 | This depends on the underlying pretrained multilingual language model. For example, if mBERT is used, it covers 104 languages, as listed [here](https://github.com/google-research/bert/blob/master/multilingual.md). 133 | 134 | ##### Do I need GPUs for running this? 135 | 136 | No. Each alignment requires only a single forward pass through the pretrained language model. While this is certainly 137 | faster on a GPU, it runs fine on a CPU. On one GPU (GeForce GTX 1080 Ti), aligning 500 parallel sentences takes around 15-20 seconds. A device-selection snippet is shown at the end of this README. 138 | 139 | 140 | 141 | License 142 | ------- 143 | 144 | Copyright (C) 2020, Masoud Jalili Sabet, Philipp Dufter 145 | 146 | A full copy of the license can be found in LICENSE.
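As the FAQ above notes, a GPU is optional. A minimal sketch of selecting one (assuming a CUDA-capable machine; the `device` argument accepts any string understood by `torch.device`):

```python
from simalign import SentenceAligner

# device="cuda:0" is an assumption for illustration; the default is "cpu".
myaligner = SentenceAligner(model="bert", token_type="bpe", matching_methods="mai", device="cuda:0")
```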
147 | -------------------------------------------------------------------------------- /simalign/simalign.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import os 4 | import logging 5 | from typing import Dict, List, Tuple, Union 6 | 7 | import numpy as np 8 | from scipy.stats import entropy 9 | from scipy.sparse import csr_matrix 10 | from sklearn.preprocessing import normalize 11 | from sklearn.metrics.pairwise import cosine_similarity 12 | try: 13 | import networkx as nx 14 | from networkx.algorithms.bipartite.matrix import from_biadjacency_matrix 15 | except ImportError: 16 | nx = None 17 | import torch 18 | from transformers import BertModel, BertTokenizer, XLMModel, XLMTokenizer, RobertaModel, RobertaTokenizer, XLMRobertaModel, XLMRobertaTokenizer, AutoConfig, AutoModel, AutoTokenizer 19 | 20 | from simalign.utils import get_logger 21 | 22 | LOG = get_logger(__name__) 23 | 24 | 25 | class EmbeddingLoader(object): 26 | def __init__(self, model: str="bert-base-multilingual-cased", device=torch.device('cpu'), layer: int=8): 27 | TR_Models = { 28 | 'bert-base-uncased': (BertModel, BertTokenizer), 29 | 'bert-base-multilingual-cased': (BertModel, BertTokenizer), 30 | 'bert-base-multilingual-uncased': (BertModel, BertTokenizer), 31 | 'xlm-mlm-100-1280': (XLMModel, XLMTokenizer), 32 | 'roberta-base': (RobertaModel, RobertaTokenizer), 33 | 'xlm-roberta-base': (XLMRobertaModel, XLMRobertaTokenizer), 34 | 'xlm-roberta-large': (XLMRobertaModel, XLMRobertaTokenizer), 35 | } 36 | 37 | self.model = model 38 | self.device = device 39 | self.layer = layer 40 | self.emb_model = None 41 | self.tokenizer = None 42 | 43 | if model in TR_Models: 44 | model_class, tokenizer_class = TR_Models[model] 45 | self.emb_model = model_class.from_pretrained(model, output_hidden_states=True) 46 | self.emb_model.eval() 47 | self.emb_model.to(self.device) 48 | self.tokenizer = tokenizer_class.from_pretrained(model) 49 | else: 50 | # try to load model with auto-classes 51 | config = AutoConfig.from_pretrained(model, output_hidden_states=True) 52 | self.emb_model = AutoModel.from_pretrained(model, config=config) 53 | self.emb_model.eval() 54 | self.emb_model.to(self.device) 55 | self.tokenizer = AutoTokenizer.from_pretrained(model) 56 | LOG.info("Initialized the EmbeddingLoader with model: {}".format(self.model)) 57 | 58 | def get_embed_list(self, sent_batch: List[List[str]]) -> torch.Tensor: 59 | if self.emb_model is not None: 60 | with torch.no_grad(): 61 | if not isinstance(sent_batch[0], str): 62 | inputs = self.tokenizer(sent_batch, is_split_into_words=True, padding=True, truncation=True, return_tensors="pt") 63 | else: 64 | inputs = self.tokenizer(sent_batch, is_split_into_words=False, padding=True, truncation=True, return_tensors="pt") 65 | hidden = self.emb_model(**inputs.to(self.device))["hidden_states"] 66 | if self.layer >= len(hidden): 67 | raise ValueError(f"Specified to take embeddings from layer {self.layer}, but model has only {len(hidden)} layers.") 68 | outputs = hidden[self.layer] 69 | return outputs[:, 1:-1, :] 70 | else: 71 | return None 72 | 73 | 74 | class SentenceAligner(object): 75 | def __init__(self, model: str = "bert", token_type: str = "bpe", distortion: float = 0.0, matching_methods: str = "mai", device: str = "cpu", layer: int = 8): 76 | model_names = { 77 | "bert": "bert-base-multilingual-cased", 78 | "xlmr": "xlm-roberta-base" 79 | } 80 | all_matching_methods = {"a": "inter", "m": "mwmf", "i": "itermax", "f": "fwd", "r": 
"rev"} 81 | 82 | self.model = model 83 | if model in model_names: 84 | self.model = model_names[model] 85 | self.token_type = token_type 86 | self.distortion = distortion 87 | self.matching_methods = [all_matching_methods[m] for m in matching_methods] 88 | self.device = torch.device(device) 89 | 90 | self.embed_loader = EmbeddingLoader(model=self.model, device=self.device, layer=layer) 91 | 92 | @staticmethod 93 | def get_max_weight_match(sim: np.ndarray) -> np.ndarray: 94 | if nx is None: 95 | raise ValueError("networkx must be installed to use match algorithm.") 96 | def permute(edge): 97 | if edge[0] < sim.shape[0]: 98 | return edge[0], edge[1] - sim.shape[0] 99 | else: 100 | return edge[1], edge[0] - sim.shape[0] 101 | G = from_biadjacency_matrix(csr_matrix(sim)) 102 | matching = nx.max_weight_matching(G, maxcardinality=True) 103 | matching = [permute(x) for x in matching] 104 | matching = sorted(matching, key=lambda x: x[0]) 105 | res_matrix = np.zeros_like(sim) 106 | for edge in matching: 107 | res_matrix[edge[0], edge[1]] = 1 108 | return res_matrix 109 | 110 | @staticmethod 111 | def get_similarity(X: np.ndarray, Y: np.ndarray) -> np.ndarray: 112 | return (cosine_similarity(X, Y) + 1.0) / 2.0 113 | 114 | @staticmethod 115 | def average_embeds_over_words(bpe_vectors: np.ndarray, word_tokens_pair: List[List[str]]) -> List[np.array]: 116 | w2b_map = [] 117 | cnt = 0 118 | w2b_map.append([]) 119 | for wlist in word_tokens_pair[0]: 120 | w2b_map[0].append([]) 121 | for x in wlist: 122 | w2b_map[0][-1].append(cnt) 123 | cnt += 1 124 | cnt = 0 125 | w2b_map.append([]) 126 | for wlist in word_tokens_pair[1]: 127 | w2b_map[1].append([]) 128 | for x in wlist: 129 | w2b_map[1][-1].append(cnt) 130 | cnt += 1 131 | 132 | new_vectors = [] 133 | for l_id in range(2): 134 | w_vector = [] 135 | for word_set in w2b_map[l_id]: 136 | w_vector.append(bpe_vectors[l_id][word_set].mean(0)) 137 | new_vectors.append(np.array(w_vector)) 138 | return new_vectors 139 | 140 | @staticmethod 141 | def get_alignment_matrix(sim_matrix: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 142 | m, n = sim_matrix.shape 143 | forward = np.eye(n)[sim_matrix.argmax(axis=1)] # m x n 144 | backward = np.eye(m)[sim_matrix.argmax(axis=0)] # n x m 145 | return forward, backward.transpose() 146 | 147 | @staticmethod 148 | def apply_distortion(sim_matrix: np.ndarray, ratio: float = 0.5) -> np.ndarray: 149 | shape = sim_matrix.shape 150 | if (shape[0] < 2 or shape[1] < 2) or ratio == 0.0: 151 | return sim_matrix 152 | 153 | pos_x = np.array([[y / float(shape[1] - 1) for y in range(shape[1])] for x in range(shape[0])]) 154 | pos_y = np.array([[x / float(shape[0] - 1) for x in range(shape[0])] for y in range(shape[1])]) 155 | distortion_mask = 1.0 - ((pos_x - np.transpose(pos_y)) ** 2) * ratio 156 | 157 | return np.multiply(sim_matrix, distortion_mask) 158 | 159 | @staticmethod 160 | def iter_max(sim_matrix: np.ndarray, max_count: int=2) -> np.ndarray: 161 | alpha_ratio = 0.9 162 | m, n = sim_matrix.shape 163 | forward = np.eye(n)[sim_matrix.argmax(axis=1)] # m x n 164 | backward = np.eye(m)[sim_matrix.argmax(axis=0)] # n x m 165 | inter = forward * backward.transpose() 166 | 167 | if min(m, n) <= 2: 168 | return inter 169 | 170 | new_inter = np.zeros((m, n)) 171 | count = 1 172 | while count < max_count: 173 | mask_x = 1.0 - np.tile(inter.sum(1)[:, np.newaxis], (1, n)).clip(0.0, 1.0) 174 | mask_y = 1.0 - np.tile(inter.sum(0)[np.newaxis, :], (m, 1)).clip(0.0, 1.0) 175 | mask = ((alpha_ratio * mask_x) + (alpha_ratio * mask_y)).clip(0.0, 
1.0) 176 | mask_zeros = 1.0 - ((1.0 - mask_x) * (1.0 - mask_y)) 177 | if mask_x.sum() < 1.0 or mask_y.sum() < 1.0: 178 | mask *= 0.0 179 | mask_zeros *= 0.0 180 | 181 | new_sim = sim_matrix * mask 182 | fwd = np.eye(n)[new_sim.argmax(axis=1)] * mask_zeros 183 | bac = np.eye(m)[new_sim.argmax(axis=0)].transpose() * mask_zeros 184 | new_inter = fwd * bac 185 | 186 | if np.array_equal(inter + new_inter, inter): 187 | break 188 | inter = inter + new_inter 189 | count += 1 190 | return inter 191 | 192 | def get_word_aligns(self, src_sent: Union[str, List[str]], trg_sent: Union[str, List[str]]) -> Dict[str, List]: 193 | if isinstance(src_sent, str): 194 | src_sent = src_sent.split() 195 | if isinstance(trg_sent, str): 196 | trg_sent = trg_sent.split() 197 | l1_tokens = [self.embed_loader.tokenizer.tokenize(word) for word in src_sent] 198 | l2_tokens = [self.embed_loader.tokenizer.tokenize(word) for word in trg_sent] 199 | bpe_lists = [[bpe for w in sent for bpe in w] for sent in [l1_tokens, l2_tokens]] 200 | 201 | if self.token_type == "bpe": 202 | l1_b2w_map = [] 203 | for i, wlist in enumerate(l1_tokens): 204 | l1_b2w_map += [i for x in wlist] 205 | l2_b2w_map = [] 206 | for i, wlist in enumerate(l2_tokens): 207 | l2_b2w_map += [i for x in wlist] 208 | 209 | vectors = self.embed_loader.get_embed_list([src_sent, trg_sent]).cpu().detach().numpy() 210 | vectors = [vectors[i, :len(bpe_lists[i])] for i in [0, 1]] 211 | 212 | if self.token_type == "word": 213 | vectors = self.average_embeds_over_words(vectors, [l1_tokens, l2_tokens]) 214 | 215 | all_mats = {} 216 | sim = self.get_similarity(vectors[0], vectors[1]) 217 | sim = self.apply_distortion(sim, self.distortion) 218 | 219 | all_mats["fwd"], all_mats["rev"] = self.get_alignment_matrix(sim) 220 | all_mats["inter"] = all_mats["fwd"] * all_mats["rev"] 221 | if "mwmf" in self.matching_methods: 222 | all_mats["mwmf"] = self.get_max_weight_match(sim) 223 | if "itermax" in self.matching_methods: 224 | all_mats["itermax"] = self.iter_max(sim) 225 | 226 | aligns = {x: set() for x in self.matching_methods} 227 | for i in range(len(vectors[0])): 228 | for j in range(len(vectors[1])): 229 | for ext in self.matching_methods: 230 | if all_mats[ext][i, j] > 0: 231 | if self.token_type == "bpe": 232 | aligns[ext].add((l1_b2w_map[i], l2_b2w_map[j])) 233 | else: 234 | aligns[ext].add((i, j)) 235 | for ext in aligns: 236 | aligns[ext] = sorted(aligns[ext]) 237 | return aligns 238 | -------------------------------------------------------------------------------- /scripts/align_files.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import regex 4 | import codecs 5 | import argparse 6 | import collections 7 | import numpy as np 8 | from tqdm import tqdm 9 | import torch.nn.functional as F 10 | 11 | from simalign.simalign import * 12 | 13 | 14 | def gather_null_aligns(sim_matrix: np.ndarray, inter_matrix: np.ndarray) -> List[float]: 15 | shape = sim_matrix.shape 16 | if min(shape[0], shape[1]) <= 2: 17 | return [] 18 | norm_x = normalize(sim_matrix, axis=1, norm='l1') 19 | norm_y = normalize(sim_matrix, axis=0, norm='l1') 20 | 21 | entropy_x = np.array([entropy(norm_x[i, :]) / np.log(shape[1]) for i in range(shape[0])]) 22 | entropy_y = np.array([entropy(norm_y[:, j]) / np.log(shape[0]) for j in range(shape[1])]) 23 | 24 | mask_x = np.tile(entropy_x[:, np.newaxis], (1, shape[1])) 25 | mask_y = np.tile(entropy_y, (shape[0], 1)) 26 | 27 | all_ents = np.multiply(inter_matrix, np.minimum(mask_x, mask_y)) 28 
| return [x.item() for x in np.nditer(all_ents) if x.item() > 0] 29 | 30 | def apply_percentile_null_aligns(sim_matrix: np.ndarray, ratio: float=1.0) -> np.ndarray: 31 | shape = sim_matrix.shape 32 | if min(shape[0], shape[1]) <= 2: 33 | return np.ones(shape) 34 | norm_x = normalize(sim_matrix, axis=1, norm='l1') 35 | norm_y = normalize(sim_matrix, axis=0, norm='l1') 36 | entropy_x = np.array([entropy(norm_x[i, :]) / np.log(shape[1]) for i in range(shape[0])]) 37 | entropy_y = np.array([entropy(norm_y[:, j]) / np.log(shape[0]) for j in range(shape[1])]) 38 | mask_x = np.tile(entropy_x[:, np.newaxis], (1, shape[1])) 39 | mask_y = np.tile(entropy_y, (shape[0], 1)) 40 | 41 | ents_mask = np.where(np.minimum(mask_x, mask_y) > ratio, 0.0, 1.0) 42 | 43 | return ents_mask 44 | 45 | 46 | # -------------------------------------------------------- 47 | # -------------------------------------------------------- 48 | if __name__ == "__main__": 49 | parser = argparse.ArgumentParser(description="Extracts alignments based on different embeddings", epilog="example: python3 scripts/align_files.py path/to/L1/text path/to/L2/text [options]") 50 | parser.add_argument("L1_path", type=str, help="Lines in the file should be indexed, separated by TABs.") 51 | parser.add_argument("L2_path", type=str, help="Same format as the L1 file.") 52 | parser.add_argument("-model", type=str, default="bert", help="'bert', 'xlmr', or any HuggingFace model name") 53 | parser.add_argument("-distortion", type=float, default=0.0) 54 | parser.add_argument("--null-align", type=float, default=1.0, help="entropy percentile above which alignment edges are dropped as null alignments; 1.0 disables filtering") 55 | parser.add_argument("--token-type", type=str, choices=["bpe", "word"], default="bpe") 56 | parser.add_argument("--matching-methods", type=str, default="mai", help="m: Max Weight Matching (mwmf), a: argmax (inter), i: itermax, f: forward (fwd), r: reverse (rev)") 57 | parser.add_argument("--num-test-sents", type=int, default=None, help="None means all sentences") 58 | parser.add_argument("--batch-size", type=int, default=100) 59 | parser.add_argument("-log", action="store_true") 60 | parser.add_argument("-device", type=str, default="cpu") 61 | parser.add_argument("-output", type=str, default="align_out", help="output alignment files (without extension)") 62 | parser.add_argument("--add-probs", action="store_true") 63 | parser.add_argument("--layer", type=int, default=8, help="The layer the embeddings should be taken from.") 64 | args = parser.parse_args() 65 | 66 | if args.model == "bert": 67 | args.model = "bert-base-multilingual-cased" 68 | elif args.model == "xlmr": 69 | args.model = "xlm-roberta-base" 70 | 71 | LOG.info("Simalign parameters: " + str(args)) 72 | 73 | langs = [args.L1_path, args.L2_path] 74 | max_sent_id = args.num_test_sents 75 | convert_to_words = (args.token_type == "word") 76 | device = torch.device(args.device) 77 | 78 | # -------------------------------------------------------- 79 | embed_loader = EmbeddingLoader(model=args.model, device=device, 80 | layer=args.layer) 81 | 82 | original_paths = [lang for lang in langs] 83 | original_corpora = [] 84 | for path in original_paths: 85 | corpus = [l.strip() for l in codecs.open(path, 'r', 'utf-8').readlines()] 86 | if len(corpus[0].split("\t")) == 2: 87 | corpus = [line.split("\t")[1] for line in corpus] 88 | corpus = [regex.sub("\\p{C}+", "", regex.sub("\\p{Separator}+", " ", l)).strip() for l in corpus] 89 | original_corpora.append(corpus[:max_sent_id]) 90 | 91 | words_tokens = [] 92 | for sent_id in range(len(original_corpora[0])): 93 | l1_tokens = [embed_loader.tokenizer.tokenize(word) for word in original_corpora[0][sent_id].split()]
94 | l2_tokens = [embed_loader.tokenizer.tokenize(word) for word in original_corpora[1][sent_id].split()] 95 | words_tokens.append([l1_tokens, l2_tokens]) 96 | 97 | sentences_bpe_lists = [] 98 | sentences_b2w_map = [] 99 | for sent_id in range(len(words_tokens)): 100 | sent_pair = [[bpe for w in sent for bpe in w] for sent in words_tokens[sent_id]] 101 | b2w_map_pair = [[i for i, w in enumerate(sent) for bpe in w] for sent in words_tokens[sent_id]] 102 | sentences_bpe_lists.append(sent_pair) 103 | sentences_b2w_map.append(b2w_map_pair) 104 | 105 | corpora_lengths = [len(corpus) for corpus in original_corpora] 106 | if min(corpora_lengths) != max(corpora_lengths): 107 | LOG.warning("Mismatch in corpus lengths: " + str(corpora_lengths)) 108 | raise ValueError('Cannot load parallel corpus.') 109 | 110 | # -------------------------------------------------------- 111 | all_matching_methods = {"a": "inter", "m": "mwmf", "i": "itermax", "f": "fwd", "r": "rev"} 112 | matching_methods = [all_matching_methods[m] for m in args.matching_methods] 113 | 114 | out_f = {ext: open('{}.{}'.format(args.output, ext), 'w') for ext in matching_methods} 115 | if args.log: 116 | out_log = open('{}.log'.format(args.output), 'w') 117 | 118 | if args.null_align < 1.0:  # first pass: collect edge entropies to set the null-alignment threshold 119 | entropies = {x: [] for x in matching_methods} 120 | for sent_id in range(len(original_corpora[0])): 121 | sent_pair = [original_corpora[i][sent_id] for i in [0, 1]] 122 | vectors = embed_loader.get_embed_list(sent_pair).cpu().detach().numpy() 123 | vectors = [vectors[i][:len(sentences_bpe_lists[sent_id][i])] for i in [0, 1]] 124 | 125 | if convert_to_words: 126 | vectors = SentenceAligner.average_embeds_over_words(vectors, words_tokens[sent_id]) 127 | 128 | all_mats = {} 129 | sim = SentenceAligner.get_similarity(vectors[0], vectors[1]) 130 | sim = SentenceAligner.apply_distortion(sim, args.distortion) 131 | 132 | methods_matrix = {} 133 | methods_matrix["fwd"], methods_matrix["rev"] = SentenceAligner.get_alignment_matrix(sim)  # keyed by method name so methods_matrix[m] below also works for "fwd"/"rev" 134 | methods_matrix["inter"] = methods_matrix["fwd"] * methods_matrix["rev"] 135 | if "mwmf" in matching_methods: 136 | methods_matrix["mwmf"] = SentenceAligner.get_max_weight_match(sim) 137 | if "itermax" in matching_methods: 138 | methods_matrix["itermax"] = SentenceAligner.iter_max(sim) 139 | 140 | for m in entropies: 141 | entropies[m] += gather_null_aligns(sim, methods_matrix[m]) 142 | null_thresh = {m: sorted(entropies[m])[int(args.null_align * len(entropies[m]))] for m in entropies} 143 | 144 | ds = [(idx, original_corpora[0][idx], original_corpora[1][idx]) for idx in range(len(original_corpora[0]))] 145 | data_loader = torch.utils.data.DataLoader(ds, batch_size=args.batch_size, shuffle=False) 146 | for batch_id, batch_sentences in enumerate(tqdm(data_loader)): 147 | batch_vectors_src = embed_loader.get_embed_list(batch_sentences[1]) 148 | batch_vectors_trg = embed_loader.get_embed_list(batch_sentences[2]) 149 | batch_sim = None 150 | if not convert_to_words: 151 | batch_vectors_src = F.normalize(batch_vectors_src, dim=2) 152 | batch_vectors_trg = F.normalize(batch_vectors_trg, dim=2) 153 | 154 | batch_sim = torch.bmm(batch_vectors_src, torch.transpose(batch_vectors_trg, 1, 2)) 155 | batch_sim = ((batch_sim + 1.0) / 2.0).cpu().detach().numpy() 156 | 157 | batch_vectors_src = batch_vectors_src.cpu().detach().numpy() 158 | batch_vectors_trg = batch_vectors_trg.cpu().detach().numpy() 159 | 160 | for in_batch_id, sent_id in enumerate(batch_sentences[0].numpy()):
161 | sent_pair = sentences_bpe_lists[sent_id] 162 | vectors = [batch_vectors_src[in_batch_id, :len(sent_pair[0])], batch_vectors_trg[in_batch_id, :len(sent_pair[1])]]  # strip batch padding 163 | 164 | if not convert_to_words: 165 | sim = batch_sim[in_batch_id, :len(sent_pair[0]), :len(sent_pair[1])] 166 | else: 167 | vectors = SentenceAligner.average_embeds_over_words(vectors, words_tokens[sent_id]) 168 | sim = SentenceAligner.get_similarity(vectors[0], vectors[1]) 169 | 170 | all_mats = {} 171 | 172 | sim = SentenceAligner.apply_distortion(sim, args.distortion) 173 | if args.null_align < 1.0: 174 | mask_nulls = {mmethod: apply_percentile_null_aligns(sim, null_thresh[mmethod]) for mmethod in matching_methods} 175 | 176 | all_mats["fwd"], all_mats["rev"] = SentenceAligner.get_alignment_matrix(sim) 177 | all_mats["inter"] = all_mats["fwd"] * all_mats["rev"] 178 | if "mwmf" in matching_methods: 179 | all_mats["mwmf"] = SentenceAligner.get_max_weight_match(sim) 180 | if "itermax" in matching_methods: 181 | all_mats["itermax"] = SentenceAligner.iter_max(sim) 182 | 183 | if args.null_align < 1.0: 184 | if "inter" in matching_methods: 185 | all_mats["inter"] = np.multiply(all_mats["inter"], mask_nulls["inter"]) 186 | if "mwmf" in matching_methods: 187 | all_mats["mwmf"] = np.multiply(all_mats["mwmf"], mask_nulls["mwmf"]) 188 | if "itermax" in matching_methods: 189 | all_mats["itermax"] = np.multiply(all_mats["itermax"], mask_nulls["itermax"]) 190 | 191 | raw_aligns = {x: [] for x in matching_methods} 192 | b2w_aligns = {x: set() for x in matching_methods} 193 | raw_scores = {x: collections.defaultdict(lambda: []) for x in matching_methods} 194 | b2w_scores = {x: collections.defaultdict(lambda: []) for x in matching_methods} 195 | log_aligns = [] 196 | 197 | for i in range(len(vectors[0])): 198 | for j in range(len(vectors[1])): 199 | for ext in matching_methods: 200 | if all_mats[ext][i, j] > 0: 201 | raw_aligns[ext].append('{}-{}'.format(i, j)) 202 | raw_scores[ext]['{}-{}'.format(i, j)].append(sim[i, j]) 203 | if args.token_type == "bpe": 204 | b2w_aligns[ext].add('{}-{}'.format(sentences_b2w_map[sent_id][0][i], sentences_b2w_map[sent_id][1][j])) 205 | b2w_scores[ext]['{}-{}'.format(sentences_b2w_map[sent_id][0][i], sentences_b2w_map[sent_id][1][j])].append(sim[i, j]) 206 | if ext == "inter": 207 | log_aligns.append('{}-{}:({}, {})'.format(i, j, sent_pair[0][i], sent_pair[1][j])) 208 | else: 209 | b2w_aligns[ext].add('{}-{}'.format(i, j)) 210 | 211 | for ext in out_f: 212 | if convert_to_words: 213 | if not args.add_probs: 214 | out_f[ext].write(str(sent_id) + "\t" + ' '.join(sorted(raw_aligns[ext])) + "\n") 215 | else: 216 | out_f[ext].write(str(sent_id) + "\t" + ' '.join(sorted([F"{p}-{str(round(np.mean(vals), 3))[1:]}" for p, vals in raw_scores[ext].items()])) + "\n") 217 | else: 218 | if not args.add_probs: 219 | out_f[ext].write(str(sent_id) + "\t" + ' '.join(sorted(b2w_aligns[ext])) + "\n") 220 | else: 221 | out_f[ext].write(str(sent_id) + "\t" + ' '.join(sorted([F"{p}-{str(round(np.mean(vals), 3))[1:]}" for p, vals in b2w_scores[ext].items()])) + "\n") 222 | if args.log: 223 | out_log.write(str(sent_id) + "\t" + ' '.join(sorted(log_aligns)) + "\n") 224 | 225 | if args.log: 226 | out_log.close() 227 | for ext in out_f: 228 | out_f[ext].close() 229 | 230 | --------------------------------------------------------------------------------