├── requirements.txt
├── setup.py
├── simalign
│   ├── __init__.py
│   ├── utils.py
│   └── simalign.py
├── assets
│   └── example.png
├── pyproject.toml
├── scripts
│   ├── align_example.py
│   ├── calc_align_score.py
│   ├── visualize.py
│   └── align_files.py
├── samples
│   ├── sample_eng.txt
│   ├── sample_deu.txt
│   └── sample_eng_deu.gold
├── setup.cfg
├── LICENSE
└── README.md

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | .
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | 
3 | setup()
4 | 
--------------------------------------------------------------------------------
/simalign/__init__.py:
--------------------------------------------------------------------------------
1 | from .simalign import EmbeddingLoader, SentenceAligner
2 | 
--------------------------------------------------------------------------------
/assets/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cisnlp/simalign/HEAD/assets/example.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 |     "setuptools>=42",
4 |     "wheel"
5 | ]
6 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/scripts/align_example.py:
--------------------------------------------------------------------------------
1 | import simalign
2 | 
3 | source_sentence = "Sir Nils Olav III. was knighted by the norwegian king ."
4 | target_sentence = "Nils Olav der Dritte wurde vom norwegischen König zum Ritter geschlagen ."
5 | 
6 | model = simalign.SentenceAligner()
7 | result = model.get_word_aligns(source_sentence, target_sentence)
8 | print(result)
9 | 
--------------------------------------------------------------------------------
/samples/sample_eng.txt:
--------------------------------------------------------------------------------
1 | 0	We do not believe that we should cherry-pick .
2 | 1	But this is not what happens .
3 | 2	Of course , if a drug addict becomes a pusher , then it is right and necessary that he should pay and answer before the law also .
4 | 3	Commissioner , ladies and gentlemen , I should like to begin by thanking Mr Burtone for his report .
5 | 4	' Legal drugs ' ( tranquillizers ) finding their way on to an illegal market , especially when used in combination with alcohol , are a major and serious problem particularly for young people .
6 | 
--------------------------------------------------------------------------------
/samples/sample_deu.txt:
--------------------------------------------------------------------------------
1 | 0	Wir glauben nicht , daß wir nur Rosinen herauspicken sollten .
2 | 1	Das stimmt nicht !
3 | 2	Sicher - wenn ein Drogenabhängiger zum Dealer wird , dann ist es richtig und notwendig , daß er dafür auch vor dem Gesetz zur Rechenschaft gezogen wird .
4 | 3	Herr Kommissar , liebe Kolleginnen und Kollegen ! Zunächst herzlichen Dank , Herr Burtone , für Ihren Bericht .
5 | 4	Die in den illegalen Handel gelangten sogenannten legalen Drogen bzw. Beruhigungsmittel sind vor allem in Zusammenhang mit Alkohol ein gravierendes Problem , speziell für Jugendliche .
6 | 
--------------------------------------------------------------------------------
/samples/sample_eng_deu.gold:
--------------------------------------------------------------------------------
1 | 0	0-0 1-1 2-2 3-1 4-3 4-4 5-5 6-9 7-6 7-7 7-8 8-10
2 | 1	1-0 2-1 3-2 4-1 5-1 6-3
3 | 2	0-0 1-0 10-8 11-9 12-11 13-10 14-12 15-13 16-14 17-16 18-17 19p25 19p26 20p24 21p24 22p24 23-20 24-21 25-22 26-19 27-27 2p1 3-2 4-3 5-4 6-4 7-7 8p5 9-6
4 | 3	0-1 0p0 1-2 10-8 11-8 12-10 12-9 13-12 14-13 15-15 16-16 17-17 18-18 2-4 3-5 4-6 5p7 6-8 7-8 8-8 9-8
5 | 4	0p6 1-7 10-1 11-1 12-2 13-3 14-4 16-12 16-13 17-14 17p15 18-15 18p14 19-15 19p14 2-8 20-15 21-16 22-17 23-21 24-11 25-18 26-19 27-19 28-19 29-20 30-22 31-23 32-24 33-24 34-25 3p6 3p9 4p9 5-10 6p9 7-5 8-5 9-5
6 | 
--------------------------------------------------------------------------------
/simalign/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Text
3 | import os
4 | 
5 | 
6 | def get_logger(name: Text, filename: Text = None, level: int = logging.DEBUG) -> logging.Logger:
7 |     logger = logging.getLogger(name)
8 |     logger.setLevel(level)
9 |     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
10 | 
11 |     ch = logging.StreamHandler()
12 |     ch.setLevel(level)
13 |     ch.setFormatter(formatter)
14 |     logger.addHandler(ch)
15 | 
16 |     if filename is not None:
17 |         fh = logging.FileHandler(filename)
18 |         fh.setLevel(level)
19 |         fh.setFormatter(formatter)
20 |         logger.addHandler(fh)
21 | 
22 |     return logger
23 | 
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = simalign
3 | version = v0.4
4 | author = Masoud Jalili Sabet, Philipp Dufter
5 | author_email = philipp@cis.lmu.de,masoud@cis.lmu.de
6 | description = Word Alignments using Pretrained Language Models
7 | keywords = NLP deep learning transformer pytorch BERT Word Alignment
8 | long_description = file: README.md
9 | long_description_content_type = text/markdown
10 | url = https://github.com/cisnlp/simalign
11 | project_urls =
12 |     Bug Tracker = https://github.com/cisnlp/simalign/issues
13 | classifiers =
14 |     Programming Language :: Python :: 3
15 |     License :: OSI Approved :: MIT License
16 |     Operating System :: OS Independent
17 | [options]
18 | packages = simalign
19 | install_requires =
20 |     numpy
21 |     torch
22 |     scipy
23 |     transformers
24 |     regex
25 |     networkx
26 |     scikit_learn
27 | python_requires = >=3.6.0
28 | zip_safe = False
29 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) [2020-2021] [Masoud Jalili Sabet, Philipp Dufter]
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/scripts/calc_align_score.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import collections
4 | import os.path
5 | 
6 | 
7 | def load_gold(g_path):
8 |     gold_f = open(g_path, "r")
9 |     probs = {}
10 |     surs = {}
11 |     all_count = 0.
12 |     surs_count = 0.
13 | 
14 |     for line in gold_f:
15 |         line = line.strip().split("\t")
16 |         line[1] = line[1].split()
17 | 
18 |         probs[line[0]] = set([x.replace("p", "-") for x in line[1]])
19 |         surs[line[0]] = set([x for x in line[1] if "p" not in x])
20 | 
21 |         all_count += len(probs[line[0]])
22 |         surs_count += len(surs[line[0]])
23 |     gold_f.close()
24 |     return probs, surs, surs_count
25 | 
26 | def calc_score(input_path, probs, surs, surs_count):
27 |     total_hit = 0.
28 |     p_hit = 0.
29 |     s_hit = 0.
30 |     target_f = open(input_path, "r")
31 | 
32 |     for line in target_f:
33 |         line = line.strip().split("\t")
34 | 
35 |         if line[0] not in probs: continue
36 |         if len(line) < 2: continue
37 |         line[1] = line[1].split()
38 |         if len(line[1][0].split("-")) > 2:
39 |             line[1] = ["-".join(x.split("-")[:2]) for x in line[1]]
40 | 
41 |         p_hit += len(set(line[1]) & set(probs[line[0]]))
42 |         s_hit += len(set(line[1]) & set(surs[line[0]]))
43 |         total_hit += len(set(line[1]))
44 |     target_f.close()
45 | 
46 |     y_prec = round(p_hit / max(total_hit, 1.), 3)
47 |     y_rec = round(s_hit / max(surs_count, 1.), 3)
48 |     y_f1 = round(2. * y_prec * y_rec / max((y_prec + y_rec), 0.01), 3)
49 |     aer = round(1 - (s_hit + p_hit) / (total_hit + surs_count), 3)
50 | 
51 |     return y_prec, y_rec, y_f1, aer
52 | 
53 | 
54 | if __name__ == "__main__":
55 |     '''
56 |     Calculate alignment quality scores based on the gold standard.
57 |     The output contains Precision, Recall, F1, and AER.
58 |     The gold annotated file should be selected by "gold_path".
59 |     The generated alignment file should be selected by "input_path".
60 |     Both the gold file and the input file are in the FastAlign format, with the sentence number at the start of each line, separated by a TAB.
61 | 
62 |     usage: python calc_align_score.py gold_file generated_file
63 |     '''
64 | 
65 |     parser = argparse.ArgumentParser(description="Calculate alignment quality scores based on the gold standard.", epilog="example: python calc_align_score.py gold_path input_path")
66 |     parser.add_argument("gold_path")
67 |     parser.add_argument("input_path")
68 |     args = parser.parse_args()
69 | 
70 |     if not os.path.isfile(args.input_path):
71 |         print("The input file does not exist:\n", args.input_path)
72 |         exit()
73 | 
74 |     probs, surs, surs_count = load_gold(args.gold_path)
75 |     y_prec, y_rec, y_f1, aer = calc_score(args.input_path, probs, surs, surs_count)
76 | 
77 |     print("Prec: {}\tRec: {}\tF1: {}\tAER: {}".format(y_prec, y_rec, y_f1, aer))
78 | 
--------------------------------------------------------------------------------
/scripts/visualize.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | from typing import List, Text, Tuple
4 | 
5 | 
6 | def line2matrix(line: Text, n: int, m: int) -> Tuple[np.ndarray, np.ndarray]:
7 |     '''
8 |     converts an alignment given in the format "0-1 3p4 5-6" to alignment matrices
9 |     n, m: maximum length of the involved sentences (i.e., dimensions of the alignment matrices)
10 |     '''
11 |     def convert(i, j):
12 |         i, j = int(i), int(j)
13 |         if i >= n or j >= m:
14 |             raise ValueError("Error in Gold Standard?")
15 |         return i, j
16 |     possibles = np.zeros((n, m))
17 |     sures = np.zeros((n, m))
18 |     for elem in line.split(" "):
19 |         if "p" in elem:
20 |             i, j = convert(*elem.split("p"))
21 |             possibles[i, j] = 1
22 |         elif "-" in elem:
23 |             i, j = convert(*elem.split("-"))
24 |             possibles[i, j] = 1
25 |             sures[i, j] = 1
26 |     return sures, possibles
27 | 
28 | 
29 | def plot_alignments(e: List[Text],
30 |                     f: List[Text],
31 |                     sures: np.ndarray,
32 |                     possibles: np.ndarray,
33 |                     alignment1: np.ndarray,
34 |                     alignment2: np.ndarray = None,
35 |                     title: Text = None,
36 |                     filename: Text = None,
37 |                     dpi: int = 150):
38 |     shorter = min(len(e), len(f))
39 |     scalefactor = min((4 / shorter), 1)
40 | 
41 |     groundtruth = 0.75 * sures + 0.4 * possibles
42 | 
43 |     fig, ax = plt.subplots()
44 |     im = ax.imshow(groundtruth, cmap="Greens", vmin=0, vmax=1.5)
45 | 
46 |     # show all ticks...
47 |     ax.set_xticks(np.arange(len(f)))
48 |     ax.set_yticks(np.arange(len(e)))
49 |     # ... and label them
50 |     ax.set_xticklabels(f, fontsize=25 * scalefactor)
51 |     ax.set_yticklabels(e, fontsize=25 * scalefactor)
52 | 
53 |     for edge, spine in ax.spines.items():
54 |         spine.set_visible(False)
55 | 
56 |     ax.tick_params(top=True, bottom=False,
57 |                    labeltop=True, labelbottom=False)
58 | 
59 |     # Rotate the tick labels and set their alignment.
60 |     plt.setp(ax.get_xticklabels(), rotation=30, ha="left",
61 |              rotation_mode="default")
62 |     plt.setp(ax.get_yticklabels(), rotation=0, ha="right",
63 |              rotation_mode="anchor")
64 |     ax.set_xticks(np.arange(groundtruth.shape[1] + 1) - .5, minor=True)
65 |     ax.set_yticks(np.arange(groundtruth.shape[0] + 1) - .5, minor=True)
66 | 
67 |     # set grid
68 |     ax.grid(which="minor", color="black", linestyle='-', linewidth=1)
69 |     ax.tick_params(which="minor", bottom=False, left=False)
70 |     # Loop over data dimensions and create text annotations.
71 |     circle = dict(boxstyle="circle,pad=0.3", fc=(0, 0, 0, 0.0), ec="black", lw=3)
72 |     roundthing = dict(boxstyle="square,pad=0.3", fc="black", ec=(0, 0, 0, 0.0), lw=2)
73 | 
74 |     # plot alignments
75 |     for i in range(len(e)):
76 |         for j in range(len(f)):
77 |             if alignment1[i, j] > 0:
78 |                 t = ax.text(j, i, "x", ha="center", va="center",
79 |                             size=25 * scalefactor,
80 |                             bbox=circle, color=(0, 0, 0, 0.0))
81 |             if alignment2 is not None and alignment2[i, j] > 0:
82 |                 t = ax.text(j, i, "x", ha="center", va="center",
83 |                             size=12 * scalefactor,
84 |                             bbox=roundthing, color=(0, 0, 0, 0.0))
85 |     if title:
86 |         ax.set_title(title)
87 |     fig.tight_layout()
88 |     if filename:
89 |         plt.savefig(filename, dpi=dpi)
90 |     else:
91 |         plt.show()
92 | 
93 | 
94 | if __name__ == '__main__':
95 |     line2matrix("0-0 1p1 2-1", 3, 2)
96 |     plot_alignments(["Testing", "this", "."],
97 |                     ["Hier", "wird", "getestet", "."],
98 |                     np.array([[0, 0, 1, 0],
99 |                               [0, 0, 0, 0],
100 |                               [0, 0, 0, 1]]),
101 |                     np.array([[0, 0, 0, 0],
102 |                               [1, 0, 0, 0],
103 |                               [0, 0, 0, 0]]),
104 |                     np.array([[0, 1, 0, 0],
105 |                               [0, 0, 0, 0],
106 |                               [0, 1, 0, 0]]),
107 |                     np.array([[0, 0, 0, 1],
108 |                               [0, 0, 0, 0],
109 |                               [0, 0, 0, 0]]),
110 |                     "Example")
111 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | SimAlign: Similarity Based Word Aligner
2 | ==============
3 | 
4 | 
5 |
6 |
7 |
8 |
9 |
10 | SimAlign is a high-quality word alignment tool that uses static and contextualized embeddings and **does not require parallel training data**.
11 |
12 | The following table shows how it compares to popular statistical alignment models:
13 |
14 | | | ENG-CES | ENG-DEU | ENG-FAS | ENG-FRA | ENG-HIN | ENG-RON |
15 | | ---------- | ------- | ------- | ------- | ------- | ------- | ------- |
16 | | fast-align | .78 | .71 | .46 | .84 | .38 | .68 |
17 | | eflomal | .85 | .77 | .63 | .93 | .52 | .72 |
18 | | mBERT-Argmax | .87 | .81 | .67 | .94 | .55 | .65 |
19 |
20 | Shown is the F1 score, taking the maximum across the subword and word levels. For more details, see the [paper](https://arxiv.org/pdf/2004.08728.pdf).
21 |
22 |
23 | Installation and Usage
24 | --------
25 |
26 | Tested with Python 3.7, Transformers 3.1.0, and Torch 1.5.0. Networkx 2.4 is optional (required only for the Match algorithm).
27 | For the full list of dependencies, see `setup.cfg`.
28 | For installation of transformers see [their repo](https://github.com/huggingface/transformers#installation).
29 |
30 | Download the repo for use, or alternatively install it from PyPI:
31 |
32 | `pip install simalign`
33 |
34 | or install it directly from GitHub with pip:
35 |
36 | `pip install --upgrade git+https://github.com/cisnlp/simalign.git#egg=simalign`
37 |
38 |
39 | An example of using our code:
40 | ```python
41 | from simalign import SentenceAligner
42 |
43 | # Make an instance of our model.
44 | # You can specify the embedding model and all alignment settings in the constructor.
45 | myaligner = SentenceAligner(model="bert", token_type="bpe", matching_methods="mai")
46 |
47 | # The source and target sentences should be tokenized into words.
48 | src_sentence = ["This", "is", "a", "test", "."]
49 | trg_sentence = ["Das", "ist", "ein", "Test", "."]
50 |
51 | # The output is a dictionary with different matching methods.
52 | # Each method has a list of pairs indicating the indices of aligned words (the alignments are zero-indexed).
53 | alignments = myaligner.get_word_aligns(src_sentence, trg_sentence)
54 |
55 | for matching_method in alignments:
56 | print(matching_method, ":", alignments[matching_method])
57 |
58 | # Expected output:
59 | # mwmf (Match): [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
60 | # inter (ArgMax): [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
61 | # itermax (IterMax): [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
62 | ```
63 | For more examples of how to use our code, see `scripts/align_example.py`.
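
The constructor exposes the embedding model and all alignment settings. A sketch of some non-default options (the values here are illustrations, not tuned recommendations):

```python
from simalign import SentenceAligner

# Word-level alignments from XLM-R, with a positional distortion prior,
# computing only the ArgMax ("a") and IterMax ("i") methods.
word_aligner = SentenceAligner(
    model="xlmr",           # shorthand for xlm-roberta-base
    token_type="word",      # average subword embeddings for each word
    distortion=0.5,         # penalize alignments between distant positions
    matching_methods="ai",  # a: inter (ArgMax), i: itermax (IterMax)
)
```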
64 |
65 | Demo
66 | --------
67 |
68 | An online demo is available [here](https://simalign.cis.lmu.de/).
69 |
70 |
71 | Gold Standards
72 | --------
73 | Links to the gold standards used in the paper:
74 |
75 |
76 | | Language Pair | Citation | Type |Link |
77 | | ------------- | ------------- | ------------- | ------------- |
78 | | ENG-CES | Mareček et al. 2008 | Gold Alignment | http://ufal.mff.cuni.cz/czech-english-manual-word-alignment |
79 | | ENG-DEU | EuroParl-based | Gold Alignment | www-i6.informatik.rwth-aachen.de/goldAlignment/ |
80 | | ENG-FAS | Tavakoli et al. 2014 | Gold Alignment | http://eceold.ut.ac.ir/en/node/940 |
81 | | ENG-FRA | WPT2003, Och et al. 2000 | Gold Alignment | http://web.eecs.umich.edu/~mihalcea/wpt/ |
82 | | ENG-HIN | WPT2005 | Gold Alignment | http://web.eecs.umich.edu/~mihalcea/wpt05/ |
83 | | ENG-RON | WPT2005, Mihalcea et al. 2003 | Gold Alignment | http://web.eecs.umich.edu/~mihalcea/wpt05/ |
84 |
85 |
86 | Evaluation Script
87 | --------
88 | To evaluate the output alignments, use `scripts/calc_align_score.py`.
89 |
90 | The gold alignment file should have the same format as SimAlign outputs.
91 | Sure alignment edges in the gold standard have a '-' between the source and target indices; possible edges have a 'p' between the indices.
92 | For sample parallel sentences and their gold alignments from ENG-DEU, see `samples`.
93 |
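Each line of a gold file pairs a sentence index with its edges; for example, this line is taken from `samples/sample_eng_deu.gold`:

```
1	1-0 2-1 3-2 4-1 5-1 6-3
```

With A the generated edges, S the sure edges, and P the possible edges (S ⊆ P), the script reports precision = |A ∩ P| / |A|, recall = |A ∩ S| / |S|, F1 as their harmonic mean, and AER = 1 − (|A ∩ S| + |A ∩ P|) / (|A| + |S|). A typical call (the output file name is just a placeholder):

```
python scripts/calc_align_score.py samples/sample_eng_deu.gold output.align
```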
94 |
95 | Publication
96 | --------
97 |
98 | If you use the code, please cite:
99 |
100 | ```
101 | @inproceedings{jalili-sabet-etal-2020-simalign,
102 | title = "{S}im{A}lign: High Quality Word Alignments without Parallel Training Data using Static and Contextualized Embeddings",
103 | author = {Jalili Sabet, Masoud and
104 | Dufter, Philipp and
105 | Yvon, Fran{\c{c}}ois and
106 | Sch{\"u}tze, Hinrich},
107 | booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: Findings",
108 | month = nov,
109 | year = "2020",
110 | address = "Online",
111 | publisher = "Association for Computational Linguistics",
112 | url = "https://www.aclweb.org/anthology/2020.findings-emnlp.147",
113 | pages = "1627--1643",
114 | }
115 | ```
116 |
117 | Feedback
118 | --------
119 |
120 | Feedback and contributions are more than welcome! Just reach out to @masoudjs or @pdufter.
121 |
122 |
123 | FAQ
124 | --------
125 |
126 | ##### Do I need parallel data to train the system?
127 |
128 | No, no parallel training data is required.
129 |
130 | ##### Which languages can be aligned?
131 |
132 | This depends on the underlying pretrained multilingual language model used. For example, if mBERT is used, it covers 104 languages as listed [here](https://github.com/google-research/bert/blob/master/multilingual.md).
133 |
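Model names that are not in the built-in list are loaded through the Hugging Face `Auto*` classes, so other models from the hub can be tried directly. A sketch (an untested illustration; smaller models have fewer layers, so pass a valid `layer`):

```python
from simalign import SentenceAligner

# distilbert-base-multilingual-cased has only 6 layers,
# so the default layer=8 would be out of range.
aligner = SentenceAligner(model="distilbert-base-multilingual-cased", layer=4)
```
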
134 | ##### Do I need GPUs for running this?
135 |
136 | Each alignment requires only a single forward pass through the pretrained language model. While this is certainly
137 | faster on a GPU, it runs fine on a CPU. On one GPU (GeForce GTX 1080 Ti) it takes around 15-20 seconds to align 500 parallel sentences.
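
The device is chosen in the constructor; the string is handed to `torch.device` and the model is moved there. A minimal sketch (the device index depends on your machine):

```python
from simalign import SentenceAligner

aligner = SentenceAligner(model="bert", device="cuda:0")
```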
138 |
139 |
140 |
141 | License
142 | -------
143 |
144 | Copyright (C) 2020, Masoud Jalili Sabet, Philipp Dufter
145 |
146 | A full copy of the license can be found in LICENSE.
147 |
--------------------------------------------------------------------------------
/simalign/simalign.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | import os
4 | import logging
5 | from typing import Dict, List, Tuple, Union
6 |
7 | import numpy as np
8 | from scipy.stats import entropy
9 | from scipy.sparse import csr_matrix
10 | from sklearn.preprocessing import normalize
11 | from sklearn.metrics.pairwise import cosine_similarity
12 | try:
13 | import networkx as nx
14 | from networkx.algorithms.bipartite.matrix import from_biadjacency_matrix
15 | except ImportError:
16 | nx = None
17 | import torch
18 | from transformers import BertModel, BertTokenizer, XLMModel, XLMTokenizer, RobertaModel, RobertaTokenizer, XLMRobertaModel, XLMRobertaTokenizer, AutoConfig, AutoModel, AutoTokenizer
19 |
20 | from simalign.utils import get_logger
21 |
22 | LOG = get_logger(__name__)
23 |
24 |
25 | class EmbeddingLoader(object):
26 | def __init__(self, model: str="bert-base-multilingual-cased", device=torch.device('cpu'), layer: int=8):
27 | TR_Models = {
28 | 'bert-base-uncased': (BertModel, BertTokenizer),
29 | 'bert-base-multilingual-cased': (BertModel, BertTokenizer),
30 | 'bert-base-multilingual-uncased': (BertModel, BertTokenizer),
31 | 'xlm-mlm-100-1280': (XLMModel, XLMTokenizer),
32 | 'roberta-base': (RobertaModel, RobertaTokenizer),
33 | 'xlm-roberta-base': (XLMRobertaModel, XLMRobertaTokenizer),
34 | 'xlm-roberta-large': (XLMRobertaModel, XLMRobertaTokenizer),
35 | }
36 |
37 | self.model = model
38 | self.device = device
39 | self.layer = layer
40 | self.emb_model = None
41 | self.tokenizer = None
42 |
43 | if model in TR_Models:
44 | model_class, tokenizer_class = TR_Models[model]
45 | self.emb_model = model_class.from_pretrained(model, output_hidden_states=True)
46 | self.emb_model.eval()
47 | self.emb_model.to(self.device)
48 | self.tokenizer = tokenizer_class.from_pretrained(model)
49 | else:
50 | # try to load model with auto-classes
51 | config = AutoConfig.from_pretrained(model, output_hidden_states=True)
52 | self.emb_model = AutoModel.from_pretrained(model, config=config)
53 | self.emb_model.eval()
54 | self.emb_model.to(self.device)
55 | self.tokenizer = AutoTokenizer.from_pretrained(model)
56 | LOG.info("Initialized the EmbeddingLoader with model: {}".format(self.model))
57 |
58 | def get_embed_list(self, sent_batch: List[List[str]]) -> torch.Tensor:
59 | if self.emb_model is not None:
60 | with torch.no_grad():
61 | if not isinstance(sent_batch[0], str):
62 | inputs = self.tokenizer(sent_batch, is_split_into_words=True, padding=True, truncation=True, return_tensors="pt")
63 | else:
64 | inputs = self.tokenizer(sent_batch, is_split_into_words=False, padding=True, truncation=True, return_tensors="pt")
65 | hidden = self.emb_model(**inputs.to(self.device))["hidden_states"]
66 | if self.layer >= len(hidden):
67 | raise ValueError(f"Specified to take embeddings from layer {self.layer}, but model has only {len(hidden)} layers.")
68 | outputs = hidden[self.layer]
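                # Drop the first and last positions (special tokens such as [CLS]/[SEP])
                # so the embeddings line up with the plain subword tokens.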
69 | return outputs[:, 1:-1, :]
70 | else:
71 | return None
72 |
73 |
74 | class SentenceAligner(object):
75 | def __init__(self, model: str = "bert", token_type: str = "bpe", distortion: float = 0.0, matching_methods: str = "mai", device: str = "cpu", layer: int = 8):
76 | model_names = {
77 | "bert": "bert-base-multilingual-cased",
78 | "xlmr": "xlm-roberta-base"
79 | }
80 | all_matching_methods = {"a": "inter", "m": "mwmf", "i": "itermax", "f": "fwd", "r": "rev"}
81 |
82 | self.model = model
83 | if model in model_names:
84 | self.model = model_names[model]
85 | self.token_type = token_type
86 | self.distortion = distortion
87 | self.matching_methods = [all_matching_methods[m] for m in matching_methods]
88 | self.device = torch.device(device)
89 |
90 | self.embed_loader = EmbeddingLoader(model=self.model, device=self.device, layer=layer)
91 |
92 | @staticmethod
93 | def get_max_weight_match(sim: np.ndarray) -> np.ndarray:
94 | if nx is None:
95 |             raise ValueError("networkx must be installed to use the match algorithm.")
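        # Nodes in the bipartite graph are numbered source rows first, then target
        # columns; permute maps a matched edge back to (source index, target index).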
96 | def permute(edge):
97 | if edge[0] < sim.shape[0]:
98 | return edge[0], edge[1] - sim.shape[0]
99 | else:
100 | return edge[1], edge[0] - sim.shape[0]
101 | G = from_biadjacency_matrix(csr_matrix(sim))
102 | matching = nx.max_weight_matching(G, maxcardinality=True)
103 | matching = [permute(x) for x in matching]
104 | matching = sorted(matching, key=lambda x: x[0])
105 | res_matrix = np.zeros_like(sim)
106 | for edge in matching:
107 | res_matrix[edge[0], edge[1]] = 1
108 | return res_matrix
109 |
110 | @staticmethod
111 | def get_similarity(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
112 | return (cosine_similarity(X, Y) + 1.0) / 2.0
113 |
114 | @staticmethod
115 |     def average_embeds_over_words(bpe_vectors: np.ndarray, word_tokens_pair: List[List[str]]) -> List[np.ndarray]:
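        # w2b_map[s][w] lists the positions of the subword (BPE) tokens that make up
        # word w of sentence s; each word vector is the mean of those subword vectors.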
116 | w2b_map = []
117 | cnt = 0
118 | w2b_map.append([])
119 | for wlist in word_tokens_pair[0]:
120 | w2b_map[0].append([])
121 | for x in wlist:
122 | w2b_map[0][-1].append(cnt)
123 | cnt += 1
124 | cnt = 0
125 | w2b_map.append([])
126 | for wlist in word_tokens_pair[1]:
127 | w2b_map[1].append([])
128 | for x in wlist:
129 | w2b_map[1][-1].append(cnt)
130 | cnt += 1
131 |
132 | new_vectors = []
133 | for l_id in range(2):
134 | w_vector = []
135 | for word_set in w2b_map[l_id]:
136 | w_vector.append(bpe_vectors[l_id][word_set].mean(0))
137 | new_vectors.append(np.array(w_vector))
138 | return new_vectors
139 |
140 | @staticmethod
141 | def get_alignment_matrix(sim_matrix: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
142 | m, n = sim_matrix.shape
143 | forward = np.eye(n)[sim_matrix.argmax(axis=1)] # m x n
144 | backward = np.eye(m)[sim_matrix.argmax(axis=0)] # n x m
145 | return forward, backward.transpose()
146 |
147 | @staticmethod
148 | def apply_distortion(sim_matrix: np.ndarray, ratio: float = 0.5) -> np.ndarray:
149 | shape = sim_matrix.shape
150 | if (shape[0] < 2 or shape[1] < 2) or ratio == 0.0:
151 | return sim_matrix
152 |
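        # Down-weight each entry by the squared difference of the tokens' relative
        # positions, scaled by ratio: nearby positions keep their similarity,
        # distant ones are penalized.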
153 | pos_x = np.array([[y / float(shape[1] - 1) for y in range(shape[1])] for x in range(shape[0])])
154 | pos_y = np.array([[x / float(shape[0] - 1) for x in range(shape[0])] for y in range(shape[1])])
155 | distortion_mask = 1.0 - ((pos_x - np.transpose(pos_y)) ** 2) * ratio
156 |
157 | return np.multiply(sim_matrix, distortion_mask)
158 |
159 | @staticmethod
160 | def iter_max(sim_matrix: np.ndarray, max_count: int=2) -> np.ndarray:
161 | alpha_ratio = 0.9
162 | m, n = sim_matrix.shape
163 | forward = np.eye(n)[sim_matrix.argmax(axis=1)] # m x n
164 | backward = np.eye(m)[sim_matrix.argmax(axis=0)] # n x m
165 | inter = forward * backward.transpose()
166 |
167 | if min(m, n) <= 2:
168 | return inter
169 |
170 | new_inter = np.zeros((m, n))
171 | count = 1
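        # Repeat the argmax alignment on the still-unaligned rows and columns
        # (discounted by alpha_ratio) until nothing new is added or max_count
        # iterations are reached.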
172 | while count < max_count:
173 | mask_x = 1.0 - np.tile(inter.sum(1)[:, np.newaxis], (1, n)).clip(0.0, 1.0)
174 | mask_y = 1.0 - np.tile(inter.sum(0)[np.newaxis, :], (m, 1)).clip(0.0, 1.0)
175 | mask = ((alpha_ratio * mask_x) + (alpha_ratio * mask_y)).clip(0.0, 1.0)
176 | mask_zeros = 1.0 - ((1.0 - mask_x) * (1.0 - mask_y))
177 | if mask_x.sum() < 1.0 or mask_y.sum() < 1.0:
178 | mask *= 0.0
179 | mask_zeros *= 0.0
180 |
181 | new_sim = sim_matrix * mask
182 | fwd = np.eye(n)[new_sim.argmax(axis=1)] * mask_zeros
183 | bac = np.eye(m)[new_sim.argmax(axis=0)].transpose() * mask_zeros
184 | new_inter = fwd * bac
185 |
186 | if np.array_equal(inter + new_inter, inter):
187 | break
188 | inter = inter + new_inter
189 | count += 1
190 | return inter
191 |
192 | def get_word_aligns(self, src_sent: Union[str, List[str]], trg_sent: Union[str, List[str]]) -> Dict[str, List]:
193 | if isinstance(src_sent, str):
194 | src_sent = src_sent.split()
195 | if isinstance(trg_sent, str):
196 | trg_sent = trg_sent.split()
197 | l1_tokens = [self.embed_loader.tokenizer.tokenize(word) for word in src_sent]
198 | l2_tokens = [self.embed_loader.tokenizer.tokenize(word) for word in trg_sent]
199 | bpe_lists = [[bpe for w in sent for bpe in w] for sent in [l1_tokens, l2_tokens]]
200 |
201 | if self.token_type == "bpe":
202 | l1_b2w_map = []
203 | for i, wlist in enumerate(l1_tokens):
204 | l1_b2w_map += [i for x in wlist]
205 | l2_b2w_map = []
206 | for i, wlist in enumerate(l2_tokens):
207 | l2_b2w_map += [i for x in wlist]
208 |
209 | vectors = self.embed_loader.get_embed_list([src_sent, trg_sent]).cpu().detach().numpy()
210 | vectors = [vectors[i, :len(bpe_lists[i])] for i in [0, 1]]
211 |
212 | if self.token_type == "word":
213 | vectors = self.average_embeds_over_words(vectors, [l1_tokens, l2_tokens])
214 |
215 | all_mats = {}
216 | sim = self.get_similarity(vectors[0], vectors[1])
217 | sim = self.apply_distortion(sim, self.distortion)
218 |
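        # fwd/rev are one-hot argmax alignments per direction; their elementwise
        # product is their intersection, i.e. the ArgMax ("inter") method.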
219 | all_mats["fwd"], all_mats["rev"] = self.get_alignment_matrix(sim)
220 | all_mats["inter"] = all_mats["fwd"] * all_mats["rev"]
221 | if "mwmf" in self.matching_methods:
222 | all_mats["mwmf"] = self.get_max_weight_match(sim)
223 | if "itermax" in self.matching_methods:
224 | all_mats["itermax"] = self.iter_max(sim)
225 |
226 | aligns = {x: set() for x in self.matching_methods}
227 | for i in range(len(vectors[0])):
228 | for j in range(len(vectors[1])):
229 | for ext in self.matching_methods:
230 | if all_mats[ext][i, j] > 0:
231 | if self.token_type == "bpe":
232 | aligns[ext].add((l1_b2w_map[i], l2_b2w_map[j]))
233 | else:
234 | aligns[ext].add((i, j))
235 | for ext in aligns:
236 | aligns[ext] = sorted(aligns[ext])
237 | return aligns
238 |
--------------------------------------------------------------------------------
/scripts/align_files.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | import regex
4 | import codecs
5 | import argparse
6 | import collections
7 | import numpy as np
8 | from tqdm import tqdm
9 | import torch.nn.functional as F
10 |
11 | from simalign.simalign import *
12 |
13 |
14 | def gather_null_aligns(sim_matrix: np.ndarray, inter_matrix: np.ndarray) -> List[float]:
15 | shape = sim_matrix.shape
16 | if min(shape[0], shape[1]) <= 2:
17 | return []
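    # Entropies of the row/column similarity distributions, normalized to [0, 1]
    # by the log of the dimension; a low value means a token aligns confidently.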
18 | norm_x = normalize(sim_matrix, axis=1, norm='l1')
19 | norm_y = normalize(sim_matrix, axis=0, norm='l1')
20 |
21 | entropy_x = np.array([entropy(norm_x[i, :]) / np.log(shape[1]) for i in range(shape[0])])
22 | entropy_y = np.array([entropy(norm_y[:, j]) / np.log(shape[0]) for j in range(shape[1])])
23 |
24 | mask_x = np.tile(entropy_x[:, np.newaxis], (1, shape[1]))
25 | mask_y = np.tile(entropy_y, (shape[0], 1))
26 |
27 | all_ents = np.multiply(inter_matrix, np.minimum(mask_x, mask_y))
28 | return [x.item() for x in np.nditer(all_ents) if x.item() > 0]
29 |
30 | def apply_percentile_null_aligns(sim_matrix: np.ndarray, ratio: float=1.0) -> np.ndarray:
31 | shape = sim_matrix.shape
32 | if min(shape[0], shape[1]) <= 2:
33 | return np.ones(shape)
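    # Build the same normalized-entropy masks as above and zero out pairs whose
    # minimum entropy exceeds the threshold, treating them as null alignments.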
34 | norm_x = normalize(sim_matrix, axis=1, norm='l1')
35 | norm_y = normalize(sim_matrix, axis=0, norm='l1')
36 | entropy_x = np.array([entropy(norm_x[i, :]) / np.log(shape[1]) for i in range(shape[0])])
37 | entropy_y = np.array([entropy(norm_y[:, j]) / np.log(shape[0]) for j in range(shape[1])])
38 | mask_x = np.tile(entropy_x[:, np.newaxis], (1, shape[1]))
39 | mask_y = np.tile(entropy_y, (shape[0], 1))
40 |
41 | ents_mask = np.where(np.minimum(mask_x, mask_y) > ratio, 0.0, 1.0)
42 |
43 | return ents_mask
44 |
45 |
46 | # --------------------------------------------------------
47 | # --------------------------------------------------------
48 | if __name__ == "__main__":
49 | parser = argparse.ArgumentParser(description="Extracts alignments based on different embeddings", epilog="example: python3 main.py path/to/L1/text path/to/L2/text [options]")
50 | parser.add_argument("L1_path", type=str, help="Lines in the file should be indexed separated by TABs.")
51 | parser.add_argument("L2_path", type=str, help="Same format as L1 file.")
52 | parser.add_argument("-model", type=str, default="bert", help="choices: ['bert', 'xlmr', '