├── data ├── test.dict ├── train.dict ├── align │ └── de-en │ │ ├── en │ │ ├── de │ │ └── alignmentDeEn.talp ├── bitext.txt └── bitext.txt.align ├── criss └── args.pt ├── align ├── eval_simalign_criss.py ├── evaluate.py ├── data.py ├── test.py ├── train.py └── models.py ├── src ├── evaluate.py ├── models.py ├── fully_unsup.py └── weakly_sup.py ├── LICENSE ├── CONTRIBUTING.md ├── .gitignore ├── CODE_OF_CONDUCT.md ├── README.md └── env └── env.yml /data/test.dict: -------------------------------------------------------------------------------- 1 | hund dog 2 | ist is 3 | -------------------------------------------------------------------------------- /data/train.dict: -------------------------------------------------------------------------------- 1 | das this 2 | katze cat 3 | . . 4 | -------------------------------------------------------------------------------- /data/align/de-en/en: -------------------------------------------------------------------------------- 1 | This is a cat . 2 | This is a dog . 3 | 4 | -------------------------------------------------------------------------------- /data/align/de-en/de: -------------------------------------------------------------------------------- 1 | Das ist eine Katze . 2 | Das ist ein Hund . 3 | 4 | -------------------------------------------------------------------------------- /criss/args.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/bitext-lexind/HEAD/criss/args.pt -------------------------------------------------------------------------------- /data/align/de-en/alignmentDeEn.talp: -------------------------------------------------------------------------------- 1 | 1-1 2-2 3-3 4-4 5-5 2 | 1-1 2-2 3-3 4-4 5-5 3 | 4 | -------------------------------------------------------------------------------- /data/bitext.txt: -------------------------------------------------------------------------------- 1 | Das ist eine Katze . 
def main():
    """Evaluate the CRISS-based SimAlign baseline on the bundled de-en data.

    Builds the aligner from the local CRISS checkpoint, aligns every sentence
    pair of the German--English set under ``data/align/`` and prints the
    evaluation result returned by ``evaluate``.
    """
    aligner = Aligner(
        'criss-align', distortion=0,
        path='criss/criss-3rd.pt',
        args_path='criss/args.pt',
        matching_method='a'
    )

    dset = AlignDataset('data/align/', 'de-en')
    aligns = aligner.align_sents(dset.sent_pairs, langcodes=('de_DE', 'en_XX'))
    # The trailing 1 is forwarded as evaluate()'s third positional argument.
    res = evaluate(dset.ground_truth, aligns, 1)
    print('de-en:', res)


# Guarded so that importing this module does not load the model.
if __name__ == '__main__':
    main()
def evaluate(pr_pairs, gt_pairs):
    """Score a predicted lexicon against a gold lexicon.

    Args:
        pr_pairs: iterable of predicted pairs; each pair is indexable and its
            first element is the source word.
        gt_pairs: iterable of gold pairs in the same format.

    Returns:
        dict with the keys ``'oov_number'`` (gold source words that never
        appear as a predicted source word), ``'oov_rate'``, ``'precision'``,
        ``'recall'`` and ``'f1'``. Every ratio whose denominator would be
        zero is reported as 0 instead of raising.
    """
    gt_set = {tuple(x) for x in gt_pairs}
    pr_set = {tuple(x) for x in pr_pairs}
    n_correct = len(gt_set & pr_set)
    prec = n_correct / float(len(pr_set)) if pr_set else 0
    rec = n_correct / float(len(gt_set)) if gt_set else 0
    # Source words are read from the materialised sets so that one-shot
    # iterables (generators) are not consumed a second time.
    gt_src_words = {x[0] for x in gt_set}
    pr_src_words = {x[0] for x in pr_set}
    oov_number = len(gt_src_words - pr_src_words)
    # An empty gold lexicon has no OOV words; avoid dividing by zero.
    oov_rate = oov_number / float(len(gt_src_words)) if gt_src_words else 0.0
    eval_result = {
        'oov_number': oov_number,
        'oov_rate': oov_rate,
        'precision': prec,
        'recall': rec,
        'f1': 2.0 * prec * rec / (prec + rec) if prec > 0 or rec > 0 else 0.0
    }
    return eval_result
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to bitext-lexind 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | ## License 26 | By contributing to bitext-lexind, you agree that your contributions will be licensed 27 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /align/evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
def evaluate(gold, silver, offset=0):
    """Compute precision, recall, F1 and AER of predicted word alignments.

    Args:
        gold: sequence of gold alignment strings, one per sentence. Items are
            whitespace separated; ``'s-t'`` marks a sure link and ``'spt'``
            a possible link.
        silver: sequence of predicted alignment strings (``'s-t'`` items),
            parallel to ``gold``.
        offset: integer added to every predicted index before comparison.

    Returns:
        dict with the keys ``'prec'``, ``'rec'``, ``'f1'`` and ``'aer'``.
        Ratios whose denominator would be zero are reported as 0.

    Raises:
        AssertionError: if ``gold`` and ``silver`` differ in length.
    """
    assert len(gold) == len(silver)
    a_size = s_size = ap_inter = as_inter = 0
    for gold_line, silver_line in zip(gold, silver):
        gold_items = gold_line.split()
        # Sure links: gold items without the 'p' marker.
        sure = {
            tuple(int(x) for x in item.split('-'))
            for item in gold_items if 'p' not in item
        }
        # Possible links: every gold item, split on either '-' or 'p'.
        possible = {
            tuple(int(x) for x in item.replace('p', '-').split('-'))
            for item in gold_items
        }
        predicted = {
            tuple(int(x) + offset for x in item.split('-'))
            for item in silver_line.split()
        }
        ap_inter += len(predicted & possible)
        as_inter += len(predicted & sure)
        a_size += len(predicted)
        s_size += len(sure)
    prec = ap_inter / a_size if a_size > 0 else 0
    rec = as_inter / s_size if s_size > 0 else 0
    # prec + rec is zero when nothing overlaps, even if both sets are non-empty.
    f1 = 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0
    # With no predicted and no sure links there is nothing to get wrong.
    total = a_size + s_size
    aer = 1 - (as_inter + ap_inter) / total if total > 0 else 0
    return {
        'prec': prec,
        'rec': rec,
        'f1': f1,
        'aer': aer
    }
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
class CRISSWrapper(object):
    """Load a CRISS checkpoint through fairseq and embed words with its encoder.

    Attributes:
        device: device string/object every model and input tensor is moved to.
        model: fairseq ``EnsembleModel`` built from the loaded checkpoint(s).
        tokenizer: Hugging Face tokenizer used to split words into sub-tokens.
    """

    def __init__(self, path='criss/criss-3rd.pt',
                 args_path='criss/args.pt',
                 tokenizer='facebook/mbart-large-cc25', device='cpu'):
        """Load the model ensemble and the tokenizer.

        Args:
            path: checkpoint path; several checkpoints may be joined by ':'.
            args_path: path of the saved fairseq argument object.
            tokenizer: name or path passed to ``AutoTokenizer.from_pretrained``.
            device: device the models are moved to.
        """
        # Import only the fairseq entry points that are used here, so that
        # loading does not depend on unrelated fairseq modules being present.
        from fairseq import checkpoint_utils, tasks
        from fairseq.sequence_generator import EnsembleModel
        self.device = device
        # NOTE: torch.load unpickles arbitrary Python objects -- only point
        # args_path at files from a trusted source.
        args = torch.load(args_path)
        task = tasks.setup_task(args)
        models, _model_args = checkpoint_utils.load_model_ensemble(
            path.split(':'),
            arg_overrides={},
            task=task
        )
        for model in models:
            model.make_generation_fast_(
                beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
                need_attn=args.print_alignment,
            )
            if args.fp16:
                model.half()
            model.to(self.device)
        self.model = EnsembleModel(models).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)

    def word_embed(self, words, langcode='en_XX'):
        """Return one mean-pooled encoder vector per word.

        Args:
            words: non-empty sequence of word strings.
            langcode: language token appended after every word.

        Returns:
            Tensor with one row per input word (mask-weighted average of the
            encoder outputs over the token positions).

        Raises:
            ValueError: if ``words`` is empty (``max()`` of an empty list).
        """
        # Use the tokenizer's own special tokens instead of hard-coded
        # literals, so they always map to valid ids.
        eos_token = self.tokenizer.eos_token
        pad_token = self.tokenizer.pad_token
        tokens = [
            self.tokenizer.tokenize(word) + [eos_token, langcode]
            for word in words
        ]
        lengths = [len(x) for x in tokens]
        max_length = max(lengths)
        # Left-pad every word to the longest token sequence in the batch.
        word_ids = [
            self.tokenizer.convert_tokens_to_ids(
                [pad_token] * (max_length - len(word_tokens)) + word_tokens
            )
            for word_tokens in tokens
        ]
        encoder_input = {
            'src_tokens': torch.tensor(word_ids).to(self.device),
            'src_lengths': torch.tensor(lengths).to(self.device)
        }
        encoder_outs = self.model.forward_encoder(encoder_input)
        np_encoder_outs = encoder_outs[0].encoder_out.float().detach()
        # 1 for real tokens, 0 for padding.
        # NOTE(review): the transpose below assumes the padding mask is
        # (batch, seq_len) and encoder_out is (seq_len, batch, dim) -- confirm
        # against the fairseq version in use.
        encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.float().detach()
        encoder_mask = encoder_mask.transpose(0, 1).unsqueeze(2)
        masked_encoder_outs = encoder_mask * np_encoder_outs
        avg_pool = (masked_encoder_outs / encoder_mask.sum(dim=0)).sum(dim=0)
        return avg_pool


class LexiconInducer(nn.Module):
    """MLP with a sigmoid output that scores feature vectors.

    The last ``feature_transform`` input features are passed through
    ``log(x + |bias|)`` with a learned per-feature bias before entering the
    MLP; the remaining features are used unchanged.
    """

    def __init__(self, input_dim, hidden_dims, output_dim=1, feature_transform=3):
        """Build the network.

        Args:
            input_dim: number of input features.
            hidden_dims: sequence (list or tuple) of hidden layer widths.
            output_dim: number of sigmoid outputs.
            feature_transform: how many trailing features get the log
                transform; 0 disables it.
        """
        super(LexiconInducer, self).__init__()
        layers = list()
        dims = [input_dim, *hidden_dims]
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(dims[-1], output_dim))
        layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)
        self.bias = nn.Parameter(torch.ones(feature_transform))
        self.feature_transform = feature_transform

    def forward(self, x):
        """Score ``x``, a 2-D tensor whose second dimension holds the features."""
        # Slice with an explicit index: ``x[:, :-k]`` would select nothing
        # when k == 0 because -0 == 0.
        split = x.size(-1) - self.feature_transform
        transformed_features = torch.cat(
            [x[:, :split], torch.log(x[:, split:] + self.bias.abs())],
            dim=-1
        )
        return self.model(transformed_features)

    # __call__ is intentionally not overridden: nn.Module.__call__ already
    # dispatches to forward() and additionally runs registered hooks.
Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bilingual Lexicon Induction via Unsupervised Bitext Construction and Word Alignment 2 | 3 | [Haoyue Shi](https://ttic.uchicago.edu/~freda), [Luke Zettlemoyer](https://www.cs.washington.edu/people/faculty/lsz) and [Sida I. Wang](http://www.sidaw.xyz/) 4 | 5 | ## Requirements 6 | PyTorch >= 1.7
7 | transformers == 4.0.0
8 | fairseq (to run CRISS and extract CRISS-based features)
9 | chinese_converter (to convert between simplified and traditional Chinese, fitting the different settings of CRISS and [MUSE](https://github.com/facebookresearch/MUSE))
10 | 11 | See also [env/env.yml](./env/env.yml) for sufficient environment setup. 12 | 13 | ## A Quick Example for the Pipeline of Lexicon Induction 14 | 15 | ### Step 0: Download [CRISS](https://github.com/pytorch/fairseq/tree/master/examples/criss) 16 | The default setting assumes that the CRISS (3rd iteration) model is saved in `criss/criss-3rd.pt`. 17 | 18 | ### Step 1: Unsupervised Bitext Construction with [CRISS](https://github.com/pytorch/fairseq/tree/master/examples/criss) 19 | Let's assume that we have the following [bitext](./data/bitext.txt) (sentences separated by " ||| ", one pair per line): 20 | ``` 21 | Das ist eine Katze . ||| This is a cat . 22 | Das ist ein Hund . ||| This is a dog . 23 | ``` 24 | 25 | ### Step 2: Word Alignment with [SimAlign](https://github.com/cisnlp/simalign) 26 | Note: we use CRISS as the backbone of SimAlign and use [our own implmentation](./align/), you can also use other aligners---just make sure that the results are stored in [a json file](./data/bitext.txt.align) like follows: 27 | ``` 28 | {"inter": [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]], "itermax": [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]]} 29 | {"inter": [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]], "itermax": [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]]} 30 | ``` 31 | where "inter" and "itermax" denote the *argmax* and *itermax* algorithm in SimAlign respectively. 32 | The output is in the same format as the json output of [SimAlign](https://github.com/cisnlp/simalign). 33 | See the code of SimAlign for more details. 
34 | 35 | ### Step 3: Training and Testing Lexicon Inducer 36 | #### Fully Unsupervised 37 | ``` 38 | python src/fully_unsup.py \ 39 | -b ./data/bitext.txt \ 40 | -a ./data/bitext.txt.align \ 41 | -te ./data/test.dict 42 | ``` 43 | 44 | #### Weakly Supervised 45 | ``` 46 | python src/weakly_sup.py \ 47 | -b ./data/bitext.txt \ 48 | -a ./data/bitext.txt.align \ 49 | -tr ./data/train.dict \ 50 | -te ./data/test.dict \ 51 | -src de_DE \ 52 | -trg en_XX 53 | ``` 54 | 55 | You would probably also like to specify a model folder by `-o $model_FOLDER` to save the statistices of bitext and alignment (default `./model`). 56 | 57 | `-src` and `-trg` specify the source and target language, where for the languages and corresponding codes that CRISS supports, check the language pairs in [this file](https://github.com/pytorch/fairseq/blob/master/examples/criss/unsupervised_mt/eval.sh). 58 | 59 | You will see the final model (`model.pt`, lexicon inducer) and the induced lexicon (`induced.weaklysup.dict`/`induced.fullyunsup.dict`) in the model folder, as well as a line of evaluation result (on the test set) like follows: 60 | ``` 61 | {'oov_number': 0, 'oov_rate': 0.0, 'precision': 1.0, 'recall': 1.0, 'f1': 1.0} 62 | ``` 63 | 64 | ## A Quick Example for the MLP-Based Aligner 65 | 66 | #### Training 67 | Training an MLP-based aligner using the bitext and alignment shown above. 68 | ``` 69 | python align/train.py \ 70 | -b ./data/bitext.txt \ 71 | -a ./data/bitext.txt.align \ 72 | -src de_DE \ 73 | -trg en_XX \ 74 | -o model/ 75 | ``` 76 | 77 | #### Testing 78 | Testing the saved aligner on the same set (note: this is only used to show how the code works, and in real scenarios we test on a different dataset from the training set). 79 | 80 | The `-b` and `-a` should be the same as those used for training, to avoid potential error (in fact, if you did not delete anything after training, the `-b` and `-a` parameters will never be actually used). 
81 | ``` 82 | python align/test.py \ 83 | -b ./data/bitext.txt \ 84 | -a ./data/bitext.txt.align \ 85 | -src de_DE \ 86 | -trg en_XX \ 87 | -m model/ 88 | ``` 89 | 90 | 91 | For CRISS-SimAlign baseline, you can run a quick evaluation of CRISS-based SimAlign the above examples for German--English alignment, using the *argmax* inference algorithm 92 | ``` 93 | python align/eval_simalign_criss.py 94 | ``` 95 | 96 | ## License 97 | MIT -------------------------------------------------------------------------------- /env/env.yml: -------------------------------------------------------------------------------- 1 | name: fairseq-1.7 2 | channels: 3 | - pytorch 4 | - bioconda 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - argon2-cffi=20.1.0=py36h7b6447c_1 10 | - astroid=2.4.2=py36_0 11 | - async_generator=1.10=py36h28b3542_0 12 | - attrs=20.3.0=pyhd3eb1b0_0 13 | - backcall=0.2.0=py_0 14 | - blas=1.0=mkl 15 | - bleach=3.2.1=py_0 16 | - ca-certificates=2020.12.8=h06a4308_0 17 | - certifi=2020.12.5=py36h06a4308_0 18 | - cffi=1.14.4=py36h261ae71_0 19 | - cudatoolkit=10.1.243=h6bb024c_0 20 | - cycler=0.10.0=py36_0 21 | - dbus=1.13.18=hb2f20db_0 22 | - decorator=4.4.2=py_0 23 | - defusedxml=0.6.0=py_0 24 | - entrypoints=0.3=py36_0 25 | - expat=2.2.10=he6710b0_2 26 | - fontconfig=2.13.0=h9420a91_0 27 | - freetype=2.10.4=h5ab3b9f_0 28 | - glib=2.66.1=h92f7085_0 29 | - gst-plugins-base=1.14.0=hbbd80ab_1 30 | - gstreamer=1.14.0=hb31296c_0 31 | - icu=58.2=he6710b0_3 32 | - importlib-metadata=2.0.0=py_1 33 | - importlib_metadata=2.0.0=1 34 | - intel-openmp=2020.2=254 35 | - ipykernel=5.3.4=py36h5ca1d4c_0 36 | - ipython=7.16.1=py36h5ca1d4c_0 37 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 38 | - ipywidgets=7.5.1=py_1 39 | - isort=5.6.4=py_0 40 | - jedi=0.17.0=py36_0 41 | - jinja2=2.11.2=py_0 42 | - jpeg=9b=h024ee3a_2 43 | - jsonschema=3.2.0=py_2 44 | - jupyter=1.0.0=py36_7 45 | - jupyter_client=6.1.7=py_0 46 | - jupyter_console=6.2.0=py_0 47 | - 
jupyter_core=4.7.0=py36h06a4308_0 48 | - jupyterlab_pygments=0.1.2=py_0 49 | - kiwisolver=1.3.0=py36h2531618_0 50 | - lazy-object-proxy=1.4.3=py36h27cfd23_2 51 | - lcms2=2.11=h396b838_0 52 | - ld_impl_linux-64=2.33.1=h53a641e_7 53 | - libedit=3.1.20191231=h14c3975_1 54 | - libffi=3.3=he6710b0_2 55 | - libgcc-ng=9.1.0=hdf63c60_0 56 | - libgfortran-ng=7.3.0=hdf63c60_0 57 | - libpng=1.6.37=hbc83047_0 58 | - libsodium=1.0.18=h7b6447c_0 59 | - libstdcxx-ng=9.1.0=hdf63c60_0 60 | - libtiff=4.1.0=h2733197_1 61 | - libuuid=1.0.3=h1bed415_2 62 | - libuv=1.40.0=h7b6447c_0 63 | - libxcb=1.14=h7b6447c_0 64 | - libxml2=2.9.10=hb55368b_3 65 | - lz4-c=1.9.2=heb0550a_3 66 | - markupsafe=1.1.1=py36h7b6447c_0 67 | - matplotlib=3.3.2=h06a4308_0 68 | - matplotlib-base=3.3.2=py36h817c723_0 69 | - mccabe=0.6.1=py36_1 70 | - mistune=0.8.4=py36h7b6447c_0 71 | - mkl=2020.2=256 72 | - mkl-service=2.3.0=py36he8ac12f_0 73 | - mkl_fft=1.2.0=py36h23d657b_0 74 | - mkl_random=1.1.1=py36h0573a6f_0 75 | - mscorefonts=0.0.1=3 76 | - nbclient=0.5.1=py_0 77 | - nbconvert=6.0.7=py36_0 78 | - nbformat=5.0.8=py_0 79 | - ncurses=6.2=he6710b0_1 80 | - nest-asyncio=1.4.3=pyhd3eb1b0_0 81 | - ninja=1.10.2=py36hff7bd54_0 82 | - notebook=6.1.5=py36h06a4308_0 83 | - numpy-base=1.19.2=py36hfa32c7d_0 84 | - olefile=0.46=py36_0 85 | - open-fonts=0.7.0=1 86 | - openssl=1.1.1i=h27cfd23_0 87 | - packaging=20.8=pyhd3eb1b0_0 88 | - pandoc=2.11=hb0f4dca_0 89 | - pandocfilters=1.4.3=py36h06a4308_1 90 | - parso=0.8.1=pyhd3eb1b0_0 91 | - pcre=8.44=he6710b0_0 92 | - perl=5.26.2=h14c3975_0 93 | - perl-font-ttf=1.06=pl526_0 94 | - perl-io-string=1.08=pl526_3 95 | - perl-xml-parser=2.44=pl526h4e0c4b3_7 96 | - pexpect=4.8.0=pyhd3eb1b0_3 97 | - pickleshare=0.7.5=pyhd3eb1b0_1003 98 | - pillow=8.0.1=py36he98fc37_0 99 | - pip=20.3.1=py36h06a4308_0 100 | - prometheus_client=0.9.0=pyhd3eb1b0_0 101 | - prompt-toolkit=3.0.8=py_0 102 | - prompt_toolkit=3.0.8=0 103 | - ptyprocess=0.6.0=pyhd3eb1b0_2 104 | - pycparser=2.20=py_2 105 | - 
pygments=2.7.3=pyhd3eb1b0_0 106 | - pylint=2.6.0=py36_0 107 | - pyparsing=2.4.7=py_0 108 | - pyqt=5.9.2=py36h05f1152_2 109 | - pyrsistent=0.17.3=py36h7b6447c_0 110 | - python=3.6.12=hcff3b4d_2 111 | - python-dateutil=2.8.1=py_0 112 | - pytorch=1.7.0=py3.6_cuda10.1.243_cudnn7.6.3_0 113 | - pyzmq=20.0.0=py36h2531618_1 114 | - qt=5.9.7=h5867ecd_1 115 | - qtconsole=4.7.7=py_0 116 | - qtpy=1.9.0=py_0 117 | - readline=8.0=h7b6447c_0 118 | - scipy=1.5.2=py36h0b6359f_0 119 | - send2trash=1.5.0=pyhd3eb1b0_1 120 | - setuptools=51.0.0=py36h06a4308_2 121 | - sip=4.19.8=py36hf484d3e_0 122 | - six=1.15.0=py36h06a4308_0 123 | - sqlite=3.33.0=h62c20be_0 124 | - terminado=0.9.1=py36_0 125 | - testpath=0.4.4=py_0 126 | - tk=8.6.10=hbc83047_0 127 | - toml=0.10.1=py_0 128 | - torchaudio=0.7.0=py36 129 | - torchvision=0.8.1=py36_cu101 130 | - tornado=6.1=py36h27cfd23_0 131 | - traitlets=4.3.3=py36_0 132 | - typed-ast=1.4.1=py36h7b6447c_0 133 | - typing_extensions=3.7.4.3=py_0 134 | - wcwidth=0.2.5=py_0 135 | - webencodings=0.5.1=py36_1 136 | - wheel=0.36.2=pyhd3eb1b0_0 137 | - widgetsnbextension=3.5.1=py36_0 138 | - wrapt=1.11.2=py36h7b6447c_0 139 | - xz=5.2.5=h7b6447c_0 140 | - zeromq=4.3.3=he6710b0_3 141 | - zipp=3.4.0=pyhd3eb1b0_0 142 | - zlib=1.2.11=h7b6447c_3 143 | - zstd=1.4.5=h9ceee32_0 144 | - pip: 145 | - chardet==3.0.4 146 | - chinese-converter==1.0.2 147 | - click==7.1.2 148 | - cloudpickle==1.6.0 149 | - cython==0.29.21 150 | - dataclasses==0.8 151 | - filelock==3.0.12 152 | - flexible-dotdict==0.2.1 153 | - future==0.18.2 154 | - idna==2.10 155 | - jieba==0.42.1 156 | - joblib==0.17.0 157 | - networkx==2.5 158 | - nltk==3.5 159 | - numpy==1.19.4 160 | - portalocker==2.0.0 161 | - protobuf==3.14.0 162 | - pytz==2020.4 163 | - regex==2020.11.13 164 | - requests==2.25.0 165 | - sacrebleu==1.4.14 166 | - sacremoses==0.0.43 167 | - scikit-learn==0.23.2 168 | - sentencepiece==0.1.94 169 | - sklearn==0.0 170 | - submitit==1.1.5 171 | - threadpoolctl==2.1.0 172 | - 
tokenizers==0.9.4 173 | - torch==1.7.0 174 | - tqdm==4.54.0 175 | - transformers==4.0.0 176 | - urllib3==1.26.2 177 | prefix: /private/home/fhs/packages/anaconda3/envs/fairseq-1.7 178 | -------------------------------------------------------------------------------- /align/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | from torch.utils.data import DataLoader, Dataset 4 | import regex 5 | import json 6 | import numpy as np 7 | import os 8 | 9 | 10 | class BitextAlignmentDataset(Dataset): 11 | def __init__(self, bitext_path, alignment_path): 12 | super(BitextAlignmentDataset, self).__init__() 13 | self.bitext_path = bitext_path 14 | self.alignment_path = alignment_path 15 | bitext = [regex.split(r'\|\|\|', x.strip()) for x in open(bitext_path)] 16 | align = open(alignment_path).readlines() 17 | self.bitext, self.edges = self.filter(bitext, align) 18 | assert len(self.bitext) == len(self.edges) 19 | 20 | @staticmethod 21 | def filter(bitext, align): 22 | real_bitext = list() 23 | edges = list() 24 | for i, a in enumerate(align): 25 | try: 26 | a = json.loads(a) 27 | if len(bitext[i]) == 2: 28 | bitext[i][0] = bitext[i][0].split() 29 | bitext[i][1] = bitext[i][1].split() 30 | real_bitext.append(bitext[i]) 31 | edge_info = np.zeros((len(bitext[i][0]), len(bitext[i][1]))) 32 | for x, y in a['inter']: 33 | edge_info[x, y] = 2 34 | for x, y in a['itermax']: 35 | if edge_info[x, y] == 0: 36 | edge_info[x, y] = 1 37 | edges.append(edge_info) 38 | except: 39 | continue 40 | return real_bitext, edges 41 | 42 | def __getitem__(self, index): 43 | return self.bitext[index], self.edges[index] 44 | 45 | def __len__(self): 46 | return len(self.bitext) 47 | 48 | @staticmethod 49 | def collate_fn(batch): 50 | return batch 51 | 52 | 53 | class AlignDataset(object): 54 | def __init__(self, path, langs, split='test'): 55 | if langs == 'de-en': 56 | src_sents = [x.strip() for x in 
open(os.path.join(path, langs, 'de'), encoding='iso-8859-1').readlines()][:-1] 57 | trg_sents = [x.strip() for x in open(os.path.join(path, langs, 'en'), encoding='iso-8859-1').readlines()][:-1] 58 | self.ground_truth = self.load_std_file(os.path.join(path, langs, 'alignmentDeEn.talp'))[:-1] 59 | elif langs == 'ro-en' or langs == 'en-fr': 60 | src_id2s = dict() 61 | trg_id2s = dict() 62 | for fpair in open(os.path.join(path, langs, split, f'FilePairs.{split}')): 63 | sf, tf = fpair.strip().split() 64 | for line in open(os.path.join(path, langs, split, sf), encoding='iso-8859-1'): 65 | matching = regex.match(r'(.*)', line.strip()) 66 | assert matching is not None 67 | idx = matching.group(1) 68 | sent = matching.group(2).strip() 69 | src_id2s[idx] = sent 70 | for line in open(os.path.join(path, langs, split, tf), encoding='iso-8859-1'): 71 | matching = regex.match(r'(.*)', line.strip()) 72 | assert matching is not None 73 | idx = matching.group(1) 74 | sent = matching.group(2).strip() 75 | trg_id2s[idx] = sent 76 | src_sents = [src_id2s[key] for key in sorted(src_id2s.keys())] 77 | trg_sents = [trg_id2s[key] for key in sorted(trg_id2s.keys())] 78 | snum2idx = dict([(key, i) for i, key in enumerate(sorted(trg_id2s.keys()))]) 79 | assert len(src_id2s) == len(trg_id2s) 80 | ground_truth = [list() for _ in src_id2s] 81 | raw_gt = open(os.path.join(path, langs, split, f'{split}.wa.nonullalign')).readlines() 82 | for line in raw_gt: 83 | sid, s, t, sure = line.strip().split() 84 | idx = snum2idx[sid] 85 | if sure == 'S': 86 | align = '-'.join([s, t]) 87 | else: 88 | assert sure == 'P' 89 | align = 'p'.join([s, t]) 90 | ground_truth[idx].append(align) 91 | for i, item in enumerate(ground_truth): 92 | ground_truth[i] = ' '.join(item) 93 | self.ground_truth = ground_truth 94 | elif langs == 'en-hi': 95 | src_id2s = dict() 96 | trg_id2s = dict() 97 | sf = f'{split}.e' 98 | tf = f'{split}.h' 99 | for line in open(os.path.join(path, langs, split, sf), encoding='us-ascii'): 100 
@staticmethod
def load_std_file(path):
    """Return the lines of the text file at ``path`` with surrounding whitespace stripped.

    The file is opened with the platform default encoding and is closed
    before returning.
    """
    with open(path) as fin:
        return [line.strip() for line in fin]
def eval_align(gold, silver, adjust=0):
    """Score predicted word alignments against gold alignments.

    Args:
        gold: sequence of strings, one per sentence pair. Each string holds
            whitespace-separated links; ``i-j`` is a sure link and ``ipj`` is
            a possible link. Sure links also count as possible links.
        silver: sequence of strings of whitespace-separated ``i-j`` links,
            aligned one-to-one with ``gold``.
        adjust: integer added to both indices of every predicted link before
            comparison (e.g. ``1`` when predictions are 0-based and the gold
            links are 1-based).

    Returns:
        dict with keys ``'prec'``, ``'rec'``, ``'f1'`` and ``'aer'``. Every
        ratio whose denominator is zero is reported as 0 instead of raising
        ``ZeroDivisionError``.
    """
    assert len(gold) == len(silver)
    a_size = s_size = ap_inter = as_inter = 0
    for gold_line, silver_line in zip(gold, silver):
        links = gold_line.split()
        # Sure links are the ones written with '-' only.
        sure = {
            tuple(int(v) for v in link.split('-'))
            for link in links if 'p' not in link
        }
        # Possible links: every gold link, whichever separator it uses.
        possible = {
            tuple(int(v) for v in link.replace('p', '-').split('-'))
            for link in links
        }
        predicted = {
            tuple(int(v) + adjust for v in link.split('-'))
            for link in silver_line.split()
        }
        ap_inter += len(predicted & possible)
        as_inter += len(predicted & sure)
        a_size += len(predicted)
        s_size += len(sure)
    prec = ap_inter / a_size if a_size > 0 else 0
    rec = as_inter / s_size if s_size > 0 else 0
    # prec + rec can be 0 even when both sizes are positive (no overlap at all).
    f1 = 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0
    total = a_size + s_size
    # With nothing predicted and nothing to find there is no alignment error.
    aer = 1 - (as_inter + ap_inter) / total if total > 0 else 0.0
    return {
        'prec': prec,
        'rec': rec,
        'f1': f1,
        'aer': aer
    }
def test(configs, criss, dataset, simaligns, threshold=0.5):
    """Refine SimAlign alignments with a trained ``WordAligner`` and return them.

    The aligner weights are loaded from ``<configs.save_path>/model.pt``. For
    every sentence pair in ``dataset.sent_pairs`` all source/target word pairs
    are scored, the 3-way output is turned into an expected label value via a
    softmax, and ``inference`` merges those scores with the matching entry of
    ``simaligns``.

    Args:
        configs: configuration object; reads ``save_path``, ``bitext_path``,
            ``align_path``, ``src_lang``, ``trg_lang``, ``reversed``,
            ``use_criss``, ``hiddens`` and ``device``.
        criss: object exposing ``embed(words, langcode=...)``, or ``None`` to
            skip the embedding features (as ``train`` does).
        dataset: object whose ``sent_pairs`` yields ``(source, target)`` strings.
        simaligns: per-sentence alignment strings, indexed like ``sent_pairs``.
        threshold: minimum score passed on to ``inference``.

    Returns:
        list of alignment strings, one per sentence pair.
    """
    setup_configs(configs)
    # os.makedirs copes with spaces/special characters that the shell would split.
    os.makedirs(configs.save_path, exist_ok=True)
    torch.save(configs, configs.save_path + '/configs.pt')
    info = collect_bitext_stats(
        configs.bitext_path, configs.align_path, configs.save_path,
        configs.src_lang, configs.trg_lang, configs.reversed
    )
    aligner = WordAligner(5 + (2 if configs.use_criss else 0), configs.hiddens, 3, 5).to(configs.device)
    model_path = configs.save_path + '/model.pt'
    # map_location lets a checkpoint saved on another device load on configs.device.
    aligner.load_state_dict(torch.load(model_path, map_location=configs.device))
    label_values = torch.arange(3).to(configs.device).view(1, 1, -1)
    results = list()
    for idx, (src_sent, trg_sent) in enumerate(tqdm(dataset.sent_pairs)):
        ss = src_sent.split()
        ts = trg_sent.split()
        feat_matrix = None
        if criss is not None:
            semb = criss.embed(ss, langcode=configs.src_lang)
            temb = criss.embed(ts, langcode=configs.trg_lang)
            cos_matrix = cos(semb.unsqueeze(1), temb.unsqueeze(0)).unsqueeze(-1).unsqueeze(-1)
            ip_matrix = (semb.unsqueeze(1) * temb.unsqueeze(0)).sum(-1).unsqueeze(-1).unsqueeze(-1)
            feat_matrix = torch.cat((cos_matrix, ip_matrix), dim=-1)
        # Source-major order; the reshape below relies on it.
        word_pairs = [(sw, tw) for sw in ss for tw in ts]
        criss_features = list()
        if feat_matrix is not None:
            # Only index the feature matrix when it exists; without criss the
            # original code raised NameError here.
            criss_features = [feat_matrix[i, j] for i in range(len(ss)) for j in range(len(ts))]
        scores = extract_scores(word_pairs, criss_features, aligner, info, configs).reshape(len(ss), len(ts), -1)
        # Expected label value under the softmax over the 3 outputs.
        scores = (scores.softmax(-1) * label_values).sum(-1)
        results.append(inference(simaligns[idx], scores, threshold))
    return results
def _cat_lines(path_spec):
    """Return the lines written by the shell command ``cat <path_spec>``.

    NOTE(review): ``path_spec`` is interpolated into a shell command, exactly
    as before; presumably this is what lets callers pass wildcards or several
    files. Do not pass untrusted values.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        merged = os.path.join(tmpdir, 'merged.txt')
        os.system(f'cat {path_spec} > {merged}')
        with open(merged) as fin:
            return fin.readlines()


def collect_bitext_stats(bitext_path, align_path, save_path, src_lang, trg_lang, is_reversed=False):
    """Collect word co-occurrence and frequency statistics from an aligned bitext.

    Results are cached with ``torch.save`` in ``<save_path>/stats.pt`` and
    ``<save_path>/freqs.pt``; when a cache file exists it is loaded instead of
    being recomputed.

    Args:
        bitext_path: bitext file(s); each line is ``source ||| target``.
        align_path: alignment file(s); each line is a JSON object whose
            ``'inter'`` entry lists ``[source_index, target_index]`` pairs.
        save_path: directory holding the cache files (must already exist).
        src_lang: source language code; ``'zh_CN'`` triggers ``to_simplified``.
        trg_lang: target language code; ``'zh_CN'`` triggers ``to_simplified``.
        is_reversed: swap the two sides of every sentence pair and alignment.

    Returns:
        ``(coocc, semi_matched_coocc, matched_coocc, freq_src, freq_trg)``.
        ``coocc[s][t]`` counts sentence pairs positions where ``s`` and ``t``
        co-occur, ``semi_matched_coocc`` counts aligned pairs, and
        ``matched_coocc`` counts aligned pairs in which both words have exactly
        one link. ``freq_src``/``freq_trg`` are token counters.

    Lines that cannot be parsed are skipped.
    """
    stats_path = save_path + '/stats.pt'
    freq_path = save_path + '/freqs.pt'
    if os.path.exists(stats_path):
        coocc, semi_matched_coocc, matched_coocc = torch.load(stats_path)
    else:
        coocc = collections.defaultdict(collections.Counter)
        semi_matched_coocc = collections.defaultdict(collections.Counter)
        matched_coocc = collections.defaultdict(collections.Counter)
        bitext = _cat_lines(bitext_path)
        aligns = _cat_lines(align_path)
        assert len(bitext) == len(aligns)
        for item, raw_align in zip(tqdm(bitext), aligns):
            try:
                src_sent, trg_sent = regex.split(r'\|\|\|', item.strip())
                if is_reversed:
                    src_sent, trg_sent = trg_sent, src_sent
                # only focus on inter based alignment
                align = [
                    tuple(reversed(x)) if is_reversed else tuple(x)
                    for x in json.loads(raw_align)['inter']
                ]
            except (ValueError, KeyError, TypeError):
                # Malformed bitext line or alignment record: best-effort skip,
                # without swallowing KeyboardInterrupt/SystemExit.
                continue
            if src_lang == 'zh_CN':
                src_sent = to_simplified(src_sent)
            if trg_lang == 'zh_CN':
                trg_sent = to_simplified(trg_sent)
            src_words = src_sent.lower().split()
            trg_words = trg_sent.lower().split()
            src_cnt = collections.Counter([x[0] for x in align])
            trg_cnt = collections.Counter([x[1] for x in align])
            # Set membership keeps the double loop below O(len(src) * len(trg)).
            aligned = set(align)
            for x, sw in enumerate(src_words):
                for y, tw in enumerate(trg_words):
                    if (x, y) in aligned:
                        semi_matched_coocc[sw][tw] += 1
                        if src_cnt[x] == 1 and trg_cnt[y] == 1:
                            matched_coocc[sw][tw] += 1
                    coocc[sw][tw] += 1
        torch.save((coocc, semi_matched_coocc, matched_coocc), stats_path)
    if os.path.exists(freq_path):
        freq_src, freq_trg = torch.load(freq_path)
    else:
        freq_src = collections.Counter()
        freq_trg = collections.Counter()
        for item in tqdm(_cat_lines(bitext_path)):
            try:
                src_sent, trg_sent = regex.split(r'\|\|\|', item.strip())
            except ValueError:
                # Line does not split into exactly two sides.
                continue
            if is_reversed:
                src_sent, trg_sent = trg_sent, src_sent
            if src_lang == 'zh_CN':
                src_sent = to_simplified(src_sent)
            if trg_lang == 'zh_CN':
                trg_sent = to_simplified(trg_sent)
            # NOTE(review): frequencies are counted on the original casing while
            # the co-occurrence tables above are lower-cased - confirm intended.
            for w in src_sent.split():
                freq_src[w] += 1
            for w in trg_sent.split():
                freq_trg[w] += 1
        torch.save((freq_src, freq_trg), freq_path)
    return coocc, semi_matched_coocc, matched_coocc, freq_src, freq_trg
def test(configs, logging_steps=50000):
    """Induce a lexicon for the test set from bitext statistics and save it.

    The induced entries are written tab-separated, one per line, to
    ``<configs.save_path>/induced.fullyunsup.dict``.

    Args:
        configs: configuration object; reads ``save_path``, ``bitext_path``,
            ``align_path``, ``src_lang``, ``trg_lang``, ``reversed`` and
            ``test_set``.
        logging_steps: unused here; kept so the signature stays unchanged.

    Returns:
        ``(induced_test_lexicon, test_eval)`` as produced by
        ``get_test_lexicon``.
    """
    setup_configs(configs)
    # The statistics cache and the induced dictionary are both written into
    # save_path, so make sure it exists before anything is saved.
    os.makedirs(configs.save_path, exist_ok=True)
    # prepare feature extractor
    info = collect_bitext_stats(
        configs.bitext_path, configs.align_path, configs.save_path, configs.src_lang, configs.trg_lang, configs.reversed
    )
    # dataset
    test_lexicon = load_lexicon(configs.test_set)
    induced_test_lexicon, test_eval = get_test_lexicon(test_lexicon, info)
    # The with-block closes the file; no explicit close() is needed.
    with open(configs.save_path + '/induced.fullyunsup.dict', 'w') as fout:
        for item in induced_test_lexicon:
            fout.write('\t'.join([str(x) for x in item]) + '\n')
    return induced_test_lexicon, test_eval
def setup_configs(configs):
    """Store the statistics cache location on ``configs``.

    Sets ``configs.stats_path`` to ``configs.save_path`` followed by
    ``'/stats.pt'``. Returns nothing.
    """
    stats_file = configs.save_path + '/stats.pt'
    configs.stats_path = stats_file
def extract_scores(batch, criss_features, aligner, info, configs):
    """Score word pairs with ``aligner``, ``configs.batch_size`` pairs at a time.

    Args:
        batch: non-empty list of ``(source_word, target_word)`` pairs.
        criss_features: list of tensors aligned with ``batch``; only read when
            ``configs.use_criss`` is true. Each slice is concatenated along
            dim 0 and placed in front of the count features.
        aligner: callable applied to the feature tensor of each sub-batch.
        info: ``(coocc, semi_matched_coocc, matched_coocc, freq_src,
            freq_trg)`` lookup tables.
        configs: reads ``batch_size``, ``device`` and ``use_criss``.

    Returns:
        The aligner outputs of all sub-batches concatenated along dim 0.
    """
    coocc, semi_matched_coocc, matched_coocc, freq_src, freq_trg = info
    all_scores = list()
    for start in range(0, len(batch), configs.batch_size):
        stop = start + configs.batch_size
        subbatch = batch[start:stop]
        # Five count features per pair, in the order the aligner was trained on.
        features = torch.tensor(
            [
                [
                    matched_coocc[sw][tw],
                    semi_matched_coocc[sw][tw],
                    coocc[sw][tw],
                    freq_src[sw],
                    freq_trg[tw]
                ] for sw, tw in subbatch
            ]
        ).float().to(configs.device).reshape(-1, 5)
        if configs.use_criss:
            subbatch_crissfeat = torch.cat(criss_features[start:stop], dim=0)
            features = torch.cat((subbatch_crissfeat, features), dim=-1).detach()
        all_scores.append(aligner(features).squeeze(-1))
    return torch.cat(all_scores, dim=0)
cos_matrix = cos(semb.unsqueeze(1), temb.unsqueeze(0)).unsqueeze(-1).unsqueeze(-1) 150 | ip_matrix = (semb.unsqueeze(1) * temb.unsqueeze(0)).sum(-1).unsqueeze(-1).unsqueeze(-1) 151 | feat_matrix = torch.cat((cos_matrix, ip_matrix), dim=-1) 152 | # adding contexualized embeddings here 153 | training_sets = collections.defaultdict(list) 154 | criss_features = collections.defaultdict(list) 155 | for i, sw in enumerate(ss): 156 | for j, tw in enumerate(ts): 157 | label = edges[i, j] 158 | training_sets[label].append((sw, tw)) 159 | if criss is not None: 160 | criss_features[label].append(feat_matrix[i, j]) 161 | max_len = max(len(training_sets[k]) for k in training_sets) 162 | training_set = list() 163 | criss_feats = list() 164 | targets = list() 165 | for key in training_sets: 166 | training_set += training_sets[key] * (max_len // len(training_sets[key])) 167 | criss_feats += criss_features[key] * (max_len // len(training_sets[key])) 168 | targets += [key] * len(training_sets[key]) * (max_len // len(training_sets[key])) 169 | targets = torch.tensor(targets).long().to(configs.device) 170 | scores = extract_scores(training_set, criss_feats, aligner, info, configs) 171 | optimizer.zero_grad() 172 | loss = nn.CrossEntropyLoss()(scores, targets) 173 | loss.backward() 174 | optimizer.step() 175 | total_loss += loss.item() * len(batch) 176 | total_cnt += len(batch) 177 | bar.set_description(f'loss={total_loss / total_cnt:.5f}') 178 | if (idx + 1) % logging_steps == 0: 179 | print(f'Epoch {epoch}, step {idx+1}, loss = {total_loss / total_cnt:.5f}', flush=True) 180 | torch.save(aligner.state_dict(), configs.save_path + f'/model.pt') 181 | 182 | 183 | if __name__ == '__main__': 184 | parser = argparse.ArgumentParser() 185 | parser.add_argument('-a', '--align', type=str, help='path to word alignment') 186 | parser.add_argument('-b', '--bitext', type=str, help='path to bitext') 187 | parser.add_argument('-src', '--source', type=str, help='source language code') 188 | 
def setup_configs(configs):
    """Resolve the output directory template and derive the stats cache path.

    ``configs.save_path`` is formatted with ``src=configs.src_lang`` and
    ``trg=configs.trg_lang`` and written back; ``configs.stats_path`` is then
    set to the resolved directory followed by ``'/stats.pt'``.
    """
    template = configs.save_path
    resolved = template.format(src=configs.src_lang, trg=configs.trg_lang)
    configs.save_path = resolved
    configs.stats_path = configs.save_path + '/stats.pt'
semi_matched_coocc = collections.defaultdict(collections.Counter) 38 | matched_coocc = collections.defaultdict(collections.Counter) 39 | tmpdir = tempfile.TemporaryDirectory() 40 | os.system(f'cat {bitext_path} > {tmpdir.name}/bitext.txt') 41 | os.system(f'cat {align_path} > {tmpdir.name}/aligns.txt') 42 | bitext = open(f'{tmpdir.name}/bitext.txt').readlines() 43 | aligns = open(f'{tmpdir.name}/aligns.txt').readlines() 44 | tmpdir.cleanup() 45 | assert len(bitext) == len(aligns) 46 | bar = tqdm(bitext) 47 | for i, item in enumerate(bar): 48 | try: 49 | src_sent, trg_sent = regex.split(r'\|\|\|', item.strip()) 50 | if is_reversed: 51 | src_sent, trg_sent = trg_sent, src_sent 52 | align = [tuple(x if not is_reversed else reversed(x)) for x in json.loads(aligns[i])['inter']] 53 | except: 54 | continue 55 | if src_lang == 'zh_CN': 56 | src_sent = to_simplified(src_sent) 57 | if trg_lang == 'zh_CN': 58 | trg_sent = to_simplified(trg_sent) 59 | src_words = src_sent.lower().split() 60 | trg_words = trg_sent.lower().split() 61 | src_cnt = collections.Counter([x[0] for x in align]) 62 | trg_cnt = collections.Counter([x[1] for x in align]) 63 | for x, sw in enumerate(src_words): 64 | for y, tw in enumerate(trg_words): 65 | if (x, y) in align: 66 | semi_matched_coocc[sw][tw] += 1 67 | if src_cnt[x] == 1 and trg_cnt[y] == 1: 68 | matched_coocc[sw][tw] += 1 69 | coocc[sw][tw] += 1 70 | torch.save((coocc, semi_matched_coocc, matched_coocc), stats_path) 71 | if os.path.exists(freq_path): 72 | freq_src, freq_trg = torch.load(freq_path) 73 | else: 74 | freq_src = collections.Counter() 75 | freq_trg = collections.Counter() 76 | tmpdir = tempfile.TemporaryDirectory() 77 | os.system(f'cat {bitext_path} > {tmpdir.name}/bitext.txt') 78 | bitext = open(f'{tmpdir.name}/bitext.txt').readlines() 79 | tmpdir.cleanup() 80 | bar = tqdm(bitext) 81 | for i, item in enumerate(bar): 82 | try: 83 | src_sent, trg_sent = regex.split(r'\|\|\|', item.strip()) 84 | if is_reversed: 85 | src_sent, 
def extract_dataset(train_lexicon, test_lexicon, coocc, configs):
    """Build candidate word pairs for training and testing from co-occurrences.

    For every source word of ``train_lexicon`` each co-occurring target word
    becomes a positive pair when the (lexicon-form) pair is in
    ``train_lexicon`` and a negative pair otherwise; the identity pair
    ``(word, word)`` is classified the same way. For every source word of
    ``test_lexicon`` all co-occurring pairs plus the identity pair are
    collected as test candidates.

    Args:
        train_lexicon: collection of ``(source, target)`` tuples.
        test_lexicon: collection of ``(source, target)`` tuples.
        coocc: mapping ``source_word -> iterable of target words``; indexed
            with ``coocc[word]`` for every lexicon source word.
        configs: reads ``src_lang`` and ``trg_lang``; ``'zh_CN'`` switches on
            simplified/traditional conversion for that side.

    Returns:
        ``(pos_training_set, neg_training_set, test_set)`` as lists of
        ``(source_word, target_word)`` tuples in unspecified order.
    """
    src_is_zh = configs.src_lang == 'zh_CN'
    trg_is_zh = configs.trg_lang == 'zh_CN'
    pos_training_set = set()
    neg_training_set = set()
    for tsw in {entry[0] for entry in train_lexicon}:
        # Statistics are keyed by the simplified form; the lexicon is not.
        ssw = to_simplified(tsw) if src_is_zh else tsw
        for stw in coocc[ssw]:
            ttw = to_traditional(stw) if trg_is_zh else stw
            if (tsw, ttw) in train_lexicon:
                pos_training_set.add((ssw, stw))
            else:
                neg_training_set.add((ssw, stw))
        # Always offer the identity pair as a candidate.
        if (ssw, ssw) in train_lexicon:
            pos_training_set.add((ssw, ssw))
        else:
            neg_training_set.add((ssw, ssw))
    test_set = set()
    for tsw in {entry[0] for entry in test_lexicon}:
        ssw = to_simplified(tsw) if src_is_zh else tsw
        for stw in coocc[ssw]:
            test_set.add((ssw, stw))
        test_set.add((ssw, ssw))
    return list(pos_training_set), list(neg_training_set), list(test_set)
def get_test_lexicon(
    test_set, test_lexicon, criss, lexicon_inducer, info, configs, best_threshold, best_n_cand
):
    """Induce the test lexicon from model probabilities and evaluate it.

    Candidate pairs of ``test_set`` are scored with ``extract_probs``. They are
    visited from the highest to the lowest probability; a pair is kept while
    its probability is at least ``best_threshold`` and its source word has
    fewer than ``best_n_cand`` accepted translations. When the first
    below-threshold candidate is reached a diagnostic ``Test F1`` is printed
    and the scan stops.

    Args:
        test_set: list of ``(source_word, target_word)`` candidate pairs.
        test_lexicon: collection of gold ``(source, target)`` tuples.
        criss, lexicon_inducer, info, configs: forwarded to ``extract_probs``.
        best_threshold: minimum probability for a pair to be accepted.
        best_n_cand: maximum number of accepted targets per source word.

    Returns:
        ``(induced_lexicon, eval_result)`` where ``induced_lexicon`` is a list
        of ``[source, target]`` entries and ``eval_result`` is whatever
        ``evaluate(induced_lexicon, test_lexicon)`` returns.
    """
    induced_lexicon = list()
    pred_test_lexicon = collections.defaultdict(collections.Counter)
    probs = extract_probs(
        test_set, criss, lexicon_inducer, info, configs
    )
    # Keep the highest probability seen for each (source, target) pair.
    for i, (x, y) in enumerate(test_set):
        pred_test_lexicon[x][y] = max(pred_test_lexicon[x][y], probs[i].item())
    possible_predictions = list()
    for tsw in set([x[0] for x in test_lexicon]):
        # NOTE(review): the conversion is applied for every language here,
        # while extract_dataset only converts for 'zh_CN' - confirm intended.
        ssw = to_simplified(tsw)
        for stw in pred_test_lexicon[ssw]:
            ttw = to_traditional(stw)
            pos = 1 if (tsw, ttw) in test_lexicon else 0
            possible_predictions.append([tsw, ttw, pred_test_lexicon[ssw][stw], pos])
    possible_predictions = sorted(possible_predictions, key=lambda x: -x[-2])
    word_cnt = collections.Counter()
    correct_predictions = 0
    for item in possible_predictions:
        if item[-2] < best_threshold:
            # The "+ 1" in the precision denominator is kept from the original formula.
            prec = correct_predictions / (sum(word_cnt.values()) + 1) * 100.0
            rec = correct_predictions / len(test_lexicon) * 100.0
            # Without a single correct prediction prec + rec is 0; report F1 = 0
            # instead of dividing by zero.
            f1 = 2 * prec * rec / (rec + prec) if rec + prec > 0 else 0.0
            print(f'Test F1: {f1:.2f}')
            break
        if word_cnt[item[0]] == best_n_cand:
            continue
        word_cnt[item[0]] += 1
        if item[-1] == 1:
            correct_predictions += 1
        induced_lexicon.append(item[:2])
    eval_result = evaluate(induced_lexicon, test_lexicon)
    return induced_lexicon, eval_result
def train_test(configs, logging_steps=50000):
    """Train the lexicon inducer, tune its decision parameters, induce the test lexicon.

    Steps: collect bitext statistics, build positive/negative training pairs
    from the tuning lexicon, train ``LexiconInducer`` with binary cross
    entropy, pick the probability threshold and per-word candidate limit with
    ``get_optimal_parameters``, then write the induced test lexicon
    (tab-separated) to ``<configs.save_path>/induced.weaklysup.dict``.

    Args:
        configs: configuration object; reads ``save_path``, ``bitext_path``,
            ``align_path``, ``src_lang``, ``trg_lang``, ``reversed``,
            ``tuning_set``, ``test_set``, ``device``, ``hiddens``, ``epochs``
            and ``batch_size``.
        logging_steps: number of batches between progress prints and
            intermediate checkpoints.

    Returns:
        ``(induced_test_lexicon, test_eval)`` from ``get_test_lexicon``.
    """
    setup_configs(configs)
    # os.makedirs copes with spaces/special characters that the shell would split.
    os.makedirs(configs.save_path, exist_ok=True)
    torch.save(configs, configs.save_path + '/configs.pt')
    # prepare feature extractor
    info = collect_bitext_stats(
        configs.bitext_path, configs.align_path, configs.save_path, configs.src_lang, configs.trg_lang, configs.reversed)
    # dataset
    train_lexicon = load_lexicon(configs.tuning_set)
    sim_train_lexicon = {(to_simplified(x[0]), to_simplified(x[1])) for x in train_lexicon}
    all_train_lexicon = train_lexicon.union(sim_train_lexicon)
    test_lexicon = load_lexicon(configs.test_set)
    pos_training_set, neg_training_set, test_set = extract_dataset(
        train_lexicon, test_lexicon, info[2], configs
    )
    # Repeat the positives so both classes carry comparable weight.
    training_set_modifier = max(1, len(neg_training_set) // len(pos_training_set))
    training_set = pos_training_set * training_set_modifier + neg_training_set
    print(f'Positive training set is repeated {training_set_modifier} times due to data imbalance.')
    # model and optimizers
    criss = CRISSWrapper(device=configs.device)
    lexicon_inducer = LexiconInducer(7, configs.hiddens, 1, 5).to(configs.device)
    optimizer = torch.optim.Adam(lexicon_inducer.parameters(), lr=.0005)
    criterion = nn.BCELoss()
    # train model
    for epoch in range(configs.epochs):
        model_path = configs.save_path + f'/{epoch}.model.pt'
        if os.path.exists(model_path):
            # Resume from an existing per-epoch checkpoint; map_location lets a
            # checkpoint saved on another device load on configs.device.
            lexicon_inducer.load_state_dict(torch.load(model_path, map_location=configs.device))
            continue
        random.shuffle(training_set)
        bar = tqdm(range(0, len(training_set), configs.batch_size))
        total_loss = total_cnt = 0
        for i, sid in enumerate(bar):
            batch = training_set[sid:sid+configs.batch_size]
            probs = extract_probs(batch, criss, lexicon_inducer, info, configs)
            targets = torch.tensor(
                [1 if tuple(x) in all_train_lexicon else 0 for x in batch]).float().to(configs.device)
            optimizer.zero_grad()
            loss = criterion(probs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * len(batch)
            total_cnt += len(batch)
            bar.set_description(f'loss={total_loss / total_cnt:.5f}')
            if (i + 1) % logging_steps == 0:
                print(f'Epoch {epoch}, step {i+1}, loss = {total_loss / total_cnt:.5f}', flush=True)
                torch.save(lexicon_inducer.state_dict(), configs.save_path + f'/{epoch}.{i+1}.model.pt')
        print(f'Epoch {epoch}, loss = {total_loss / total_cnt:.5f}', flush=True)
        torch.save(lexicon_inducer.state_dict(), configs.save_path + '/model.pt')
    best_threshold, best_n_cand = get_optimal_parameters(
        pos_training_set, neg_training_set, train_lexicon, criss,
        lexicon_inducer, info, configs,
    )
    induced_test_lexicon, test_eval = get_test_lexicon(
        test_set, test_lexicon, criss, lexicon_inducer, info, configs, best_threshold, best_n_cand
    )
    # The with-block closes the file; no explicit close() is needed.
    with open(configs.save_path + '/induced.weaklysup.dict', 'w') as fout:
        for item in induced_test_lexicon:
            fout.write('\t'.join([str(x) for x in item]) + '\n')
    return induced_test_lexicon, test_eval
default='./model/', help='path to output folder') 318 | parser.add_argument('-d', '--device', type=str, default='cuda', help='device for training [cuda|cpu]') 319 | args = parser.parse_args() 320 | 321 | configs = dotdict.DotDict( 322 | { 323 | 'test_set': args.test, 324 | 'tuning_set': args.train, 325 | 'align_path': args.align, 326 | 'bitext_path': args.bitext, 327 | 'save_path': args.output, 328 | 'batch_size': 128, 329 | 'epochs': 50, 330 | 'device': args.device, 331 | 'hiddens': [8], 332 | 'src_lang': args.source, 333 | 'trg_lang': args.target 334 | } 335 | ) 336 | 337 | res = train_test(configs) 338 | print(res[-1]) 339 | -------------------------------------------------------------------------------- /align/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | import networkx as nx 4 | import numpy as np 5 | import os 6 | import tempfile 7 | import torch 8 | import torch.nn as nn 9 | from networkx.algorithms.bipartite.matrix import from_biadjacency_matrix 10 | from scipy.sparse import csr_matrix 11 | from sklearn.metrics.pairwise import cosine_similarity 12 | from tqdm import tqdm 13 | from transformers import AutoTokenizer 14 | import regex 15 | import collections 16 | from glob import glob 17 | 18 | 19 | class CRISSAligner(object): 20 | def __init__(self, path='criss/criss-3rd.pt', 21 | args_path='criss/args.pt', 22 | tokenizer='facebook/mbart-large-cc25', device='cpu', distortion=0, 23 | matching_method='a' 24 | ): 25 | from fairseq import bleu, checkpoint_utils, options, progress_bar, tasks, utils 26 | from fairseq.sequence_generator import EnsembleModel 27 | self.device = device 28 | args = torch.load(args_path) 29 | task = tasks.setup_task(args) 30 | models, _model_args = checkpoint_utils.load_model_ensemble( 31 | path.split(':'), 32 | arg_overrides=eval('{}'), 33 | task=task 34 | ) 35 | for model in models: 36 | model.make_generation_fast_( 37 | 
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, 38 | need_attn=args.print_alignment, 39 | ) 40 | if args.fp16: 41 | model.half() 42 | model = model.to(self.device) 43 | self.model = EnsembleModel(models).to(self.device) 44 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) 45 | self.distortion = distortion 46 | self.matching_method = matching_method 47 | 48 | def get_embed(self, bpe_lists, langcodes=('en_XX', 'en_XX')): 49 | vectors = list() 50 | for i, bpe_list in enumerate(bpe_lists): 51 | input_ids = self.tokenizer.convert_tokens_to_ids(bpe_list + ['', langcodes[i]]) 52 | encoder_input = { 53 | 'src_tokens': torch.tensor(input_ids).view(1, -1).to(self.device), 54 | 'src_lengths': torch.tensor([len(input_ids)]).to(self.device) 55 | } 56 | encoder_outs = self.model.forward_encoder(encoder_input) 57 | np_encoder_outs = encoder_outs[0].encoder_out.cpu().squeeze(1).numpy().astype(np.float32) 58 | vectors.append(np_encoder_outs[:-2, :]) 59 | return vectors 60 | 61 | def get_word_aligns(self, src_sent, trg_sent, langcodes=None, fwd_dict=None, bwd_dict=None, debug=False): 62 | l1_tokens = [self.tokenizer.tokenize(word) for word in src_sent] 63 | l2_tokens = [self.tokenizer.tokenize(word) for word in trg_sent] 64 | bpe_lists = [[bpe for w in sent for bpe in w] for sent in [l1_tokens, l2_tokens]] 65 | l1_b2w_map = list() 66 | for i, wlist in enumerate(l1_tokens): 67 | l1_b2w_map += [i for _ in wlist] 68 | l2_b2w_map = list() 69 | for i, wlist in enumerate(l2_tokens): 70 | l2_b2w_map += [i for _ in wlist] 71 | vectors = self.get_embed(list(bpe_lists), langcodes) 72 | sim = (cosine_similarity(vectors[0], vectors[1]) + 1.0) / 2.0 73 | sim = self.apply_distortion(sim, self.distortion) 74 | all_mats = dict() 75 | fwd, bwd = self.get_alignment_matrix(sim) 76 | if self.matching_method.find('a') != -1: 77 | all_mats['inter'] = fwd * bwd 78 | if self.matching_method.find('i') != -1: 79 | all_mats['itermax'] = self.iter_max(sim) 80 | if 
self.matching_method.find('m') != -1: 81 | all_mats['mwmf'] = self.get_max_weight_match(sim) 82 | if self.matching_method.find('f') != -1: 83 | all_mats['fixed'] = fwd * bwd 84 | aligns = {k: set() for k in all_mats} 85 | for key in aligns: 86 | for i in range(vectors[0].shape[0]): 87 | for j in range(vectors[1].shape[0]): 88 | if all_mats[key][i, j] > 1e-10: 89 | aligns[key].add((l1_b2w_map[i], l2_b2w_map[j])) 90 | if 'fixed' in aligns: 91 | src_aligned = set([x[0] for x in aligns['fixed']]) 92 | trg_aligned = set([x[1] for x in aligns['fixed']]) 93 | candidate_alignment = list() 94 | for i, sw in enumerate(src_sent): 95 | sw = sw.lower() 96 | if i not in src_aligned: 97 | for j, tw in enumerate(trg_sent): 98 | tw = tw.lower() 99 | if tw in fwd_dict[sw]: 100 | ri = i / len(src_sent) 101 | rj = j / len(trg_sent) 102 | if -0.2 < ri - rj < 0.2: 103 | candidate_alignment.append((sw, tw, i, j, fwd_dict[sw][tw], 0)) 104 | for j, tw in enumerate(trg_sent): 105 | tw = tw.lower() 106 | if j not in trg_aligned: 107 | for i, sw in enumerate(src_sent): 108 | sw = sw.lower() 109 | if sw in bwd_dict[tw]: 110 | ri = i / len(src_sent) 111 | rj = j / len(trg_sent) 112 | if -0.2 < ri - rj < 0.2: 113 | candidate_alignment.append((sw, tw, i, j, bwd_dict[tw][sw], 1)) 114 | candidate_alignment = sorted(candidate_alignment, key=lambda x: -x[-2]) 115 | for sw, tw, i, j, val, d in candidate_alignment: 116 | if regex.match(r'\p{P}', sw) or regex.match(r'\p{P}', tw): 117 | continue 118 | if val < 0.05: 119 | break 120 | if d == 0: 121 | if i in src_aligned: 122 | continue 123 | if (j not in trg_aligned) or ((i-1, j) in aligns['fixed']) or ((i+1, j) in aligns['fixed']): 124 | aligns['fixed'].add((i, j)) 125 | src_aligned.add(i) 126 | trg_aligned.add(j) 127 | if debug: 128 | print(sw, tw, i, j, val, d) 129 | else: 130 | if j in trg_aligned: 131 | continue 132 | if (i not in src_aligned) or ((i, j+1) in aligns['fixed']) or ((i, j-1) in aligns['fixed']): 133 | aligns['fixed'].add((i, j)) 134 | 
src_aligned.add(i) 135 | trg_aligned.add(j) 136 | if debug: 137 | print(sw, tw, i, j, val, d) 138 | for ext in aligns: 139 | aligns[ext] = sorted(aligns[ext]) 140 | return aligns 141 | 142 | @staticmethod 143 | def get_max_weight_match(sim): 144 | if nx is None: 145 | raise ValueError("networkx must be installed to use match algorithm.") 146 | 147 | def permute(edge): 148 | if edge[0] < sim.shape[0]: 149 | return edge[0], edge[1] - sim.shape[0] 150 | else: 151 | return edge[1], edge[0] - sim.shape[0] 152 | 153 | G = from_biadjacency_matrix(csr_matrix(sim)) 154 | matching = nx.max_weight_matching(G, maxcardinality=True) 155 | matching = [permute(x) for x in matching] 156 | matching = sorted(matching, key=lambda x: x[0]) 157 | res_matrix = np.zeros_like(sim) 158 | for edge in matching: 159 | res_matrix[edge[0], edge[1]] = 1 160 | return res_matrix 161 | 162 | @staticmethod 163 | def iter_max(sim_matrix, max_count=2): 164 | alpha_ratio = 0.9 165 | m, n = sim_matrix.shape 166 | forward = np.eye(n)[sim_matrix.argmax(axis=1)] # m x n 167 | backward = np.eye(m)[sim_matrix.argmax(axis=0)] # n x m 168 | inter = forward * backward.transpose() 169 | 170 | if min(m, n) <= 2: 171 | return inter 172 | 173 | new_inter = np.zeros((m, n)) 174 | count = 1 175 | while count < max_count: 176 | mask_x = 1.0 - np.tile(inter.sum(1)[:, np.newaxis], (1, n)).clip(0.0, 1.0) 177 | mask_y = 1.0 - np.tile(inter.sum(0)[np.newaxis, :], (m, 1)).clip(0.0, 1.0) 178 | mask = ((alpha_ratio * mask_x) + (alpha_ratio * mask_y)).clip(0.0, 1.0) 179 | mask_zeros = 1.0 - ((1.0 - mask_x) * (1.0 - mask_y)) 180 | if mask_x.sum() < 1.0 or mask_y.sum() < 1.0: 181 | mask *= 0.0 182 | mask_zeros *= 0.0 183 | 184 | new_sim = sim_matrix * mask 185 | fwd = np.eye(n)[new_sim.argmax(axis=1)] * mask_zeros 186 | bac = np.eye(m)[new_sim.argmax(axis=0)].transpose() * mask_zeros 187 | new_inter = fwd * bac 188 | 189 | if np.array_equal(inter + new_inter, inter): 190 | break 191 | inter = inter + new_inter 192 | count += 1 
193 | return inter 194 | 195 | @staticmethod 196 | def get_alignment_matrix(sim_matrix): 197 | m, n = sim_matrix.shape 198 | forward = np.eye(n)[sim_matrix.argmax(axis=1)] # m x n 199 | backward = np.eye(m)[sim_matrix.argmax(axis=0)] # n x m 200 | return forward, backward.transpose() 201 | 202 | @staticmethod 203 | def apply_distortion(sim_matrix, ratio=0.5): 204 | shape = sim_matrix.shape 205 | if (shape[0] < 2 or shape[1] < 2) or ratio == 0.0: 206 | return sim_matrix 207 | 208 | pos_x = np.array([[y / float(shape[1] - 1) for y in range(shape[1])] for x in range(shape[0])]) 209 | pos_y = np.array([[x / float(shape[0] - 1) for x in range(shape[0])] for y in range(shape[1])]) 210 | distortion_mask = 1.0 - ((pos_x - np.transpose(pos_y)) ** 2) * ratio 211 | 212 | return np.multiply(sim_matrix, distortion_mask) 213 | 214 | 215 | class Aligner(object): 216 | def __init__(self, aligner_type, **kwargs): 217 | self.aligner_type = aligner_type 218 | if aligner_type == 'simalign': 219 | from simalign import SentenceAligner 220 | d = 'cuda' if torch.cuda.is_available() else 'cpu' 221 | self.aligner = SentenceAligner('xlm-roberta-base', device=d, **kwargs) 222 | elif aligner_type in ['fastalign', 'giza++']: 223 | pass 224 | elif aligner_type == 'criss-align': 225 | self.aligner = CRISSAligner(**kwargs) 226 | else: 227 | raise Exception('Aligner type not supported.') 228 | 229 | def align_sents(self, sent_pairs, train_file=None, **kwargs): 230 | aligns = list() 231 | if self.aligner_type in ['simalign', 'criss-align']: 232 | for src, trg in tqdm(sent_pairs): 233 | src = src.strip().split() 234 | trg = trg.strip().split() 235 | align_info = self.aligner.get_word_aligns(src, trg, **kwargs) 236 | result = None 237 | for key in align_info: 238 | if result is None: 239 | result = set(align_info[key]) 240 | else: 241 | result = result.intersection(align_info[key]) 242 | aligns.append(' '.join(['-'.join([str(x) for x in item]) for item in sorted(result)])) 243 | elif self.aligner_type 
== 'fastalign': 244 | temp_dir = tempfile.TemporaryDirectory(prefix='fast-align') 245 | with open(os.path.join(temp_dir.name, 'bitext.txt'), 'w') as fout: 246 | for ss, ts in sent_pairs: 247 | fout.write(ss + ' ||| ' + ts + '\n') 248 | fout.close() 249 | if train_file is not None: 250 | assert os.path.exists(train_file) 251 | os.system(f'cat {train_file} >> {temp_dir.name}/bitext.txt') 252 | os.system(f'fast_align -d -o -v -i {temp_dir.name}/bitext.txt > {temp_dir.name}/fwd.align') 253 | os.system(f'fast_align -d -o -v -r -i {temp_dir.name}/bitext.txt > {temp_dir.name}/bwd.align') 254 | os.system(f'atools -i {temp_dir.name}/fwd.align -j {temp_dir.name}/bwd.align -c grow-diag-final-and > {temp_dir.name}/final.align') 255 | aligns = [x.strip() for x in open(f'{temp_dir.name}/final.align').readlines()][:len(sent_pairs)] 256 | elif self.aligner_type == 'giza++': 257 | assert train_file is not None 258 | giza_path = '/private/home/fhs/codebase/lexind/fairseq/2-word-align-final/giza-pp/GIZA++-v2/GIZA++' 259 | temp_dir = tempfile.TemporaryDirectory(prefix='giza++') 260 | d_src = collections.Counter() 261 | d_trg = collections.Counter() 262 | w2id_src = collections.defaultdict() 263 | w2id_trg = collections.defaultdict() 264 | for sent_pair in open(train_file): 265 | ss, ts = regex.split(r'\|\|\|', sent_pair.lower()) 266 | for w in ss.strip().split(): 267 | d_src[w] += 1 268 | for w in ts.strip().split(): 269 | d_trg[w] += 1 270 | for ss, ts in sent_pairs: 271 | ss = ss.lower() 272 | ts = ts.lower() 273 | for w in ss.strip().split(): 274 | d_src[w] += 1 275 | for w in ts.strip().split(): 276 | d_trg[w] += 1 277 | with open(os.path.join(temp_dir.name, 's.vcb'), 'w') as fout: 278 | for i, w in enumerate(sorted(d_src.keys())): 279 | print(i + 1, w, d_src[w], file=fout) 280 | w2id_src[w] = i + 1 281 | fout.close() 282 | with open(os.path.join(temp_dir.name, 't.vcb'), 'w') as fout: 283 | for i, w in enumerate(sorted(d_trg.keys())): 284 | print(i + 1, w, d_trg[w], file=fout) 285 
| w2id_trg[w] = i + 1 286 | fout.close() 287 | with open(os.path.join(temp_dir.name, 'bitext.train'), 'w') as fout: 288 | for sent_pair in open(train_file): 289 | ss, ts = regex.split(r'\|\|\|', sent_pair.lower()) 290 | print(1, file=fout) 291 | print(' '.join([str(w2id_src[x]) for x in ss.strip().split()]), file=fout) 292 | print(' '.join([str(w2id_trg[x]) for x in ts.strip().split()]), file=fout) 293 | fout.close() 294 | with open(os.path.join(temp_dir.name, 'bitext.test'), 'w') as fout: 295 | for ss, ts in sent_pairs: 296 | ss = ss.lower() 297 | ts = ts.lower() 298 | print(1, file=fout) 299 | print(' '.join([str(w2id_src[x]) for x in ss.strip().split()]), file=fout) 300 | print(' '.join([str(w2id_trg[x]) for x in ts.strip().split()]), file=fout) 301 | fout.close() 302 | os.chdir(f'{temp_dir.name}') 303 | os.system(f'{giza_path} -S {temp_dir.name}/s.vcb -T {temp_dir.name}/t.vcb -C {temp_dir.name}/bitext.train -tc {temp_dir.name}/bitext.test') 304 | # read giza++ results 305 | for i, line in enumerate(open(glob(f'{temp_dir.name}/*tst.A3*')[0])): 306 | if i % 3 == 2: 307 | align = list() 308 | is_trg = False 309 | is_null = False 310 | src_idx = 0 311 | for item in line.strip().split(): 312 | if item == '({': 313 | is_trg = True 314 | elif item == '})': 315 | is_trg = False 316 | elif is_trg: 317 | if not is_null: 318 | trg_idx = int(item) 319 | align.append(f'{src_idx}-{trg_idx}') 320 | elif item != 'NULL': 321 | src_idx += 1 322 | is_null = False 323 | else: 324 | is_null = True 325 | aligns.append(' '.join(align)) 326 | temp_dir.cleanup() 327 | return aligns 328 | 329 | 330 | class CRISSWrapper(object): 331 | 332 | def __init__(self, path='criss/criss-3rd.pt', args_path='criss/args.pt', 333 | tokenizer='facebook/mbart-large-cc25', device='cpu'): 334 | from fairseq import bleu, checkpoint_utils, options, progress_bar, tasks, utils 335 | from fairseq.sequence_generator import EnsembleModel 336 | self.device = device 337 | args = torch.load(args_path) 338 | task = 
class WordAligner(nn.Module):
    """Feed-forward scorer with a learned log transform on the trailing features.

    The network is a stack of ``Linear`` + ``ReLU`` layers followed by a final
    ``Linear`` + ``Sigmoid``.  Before entering the network, the last
    ``feature_transform`` columns of the input are replaced by
    ``log(value + |bias|)`` where ``bias`` is a learned vector initialised to
    ones; the remaining columns are passed through unchanged.
    """

    def __init__(self, input_dim, hidden_dims, output_dim=1, feature_transform=3):
        """Build the layers.

        Args:
            input_dim: number of input features.
            hidden_dims: sizes of the hidden layers (any sequence of ints).
            output_dim: number of outputs.
            feature_transform: how many trailing features get the log transform.
        """
        super().__init__()
        # Accept tuples as well as lists for hidden_dims.
        dims = [input_dim, *hidden_dims]
        layers = []
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(dims[-1], output_dim))
        layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)
        self.bias = nn.Parameter(torch.ones(feature_transform))
        self.feature_transform = feature_transform

    def forward(self, x):
        """Score a batch ``x`` of shape ``(batch, input_dim)``; outputs lie in (0, 1)."""
        k = self.feature_transform
        if k == 0:
            # ``x[:, :-0]`` would be empty, so there is nothing to transform.
            return self.model(x)
        transformed_features = torch.cat(
            [x[:, :-k], torch.log(x[:, -k:] + self.bias.abs())], dim=-1
        )
        return self.model(transformed_features)

    # ``__call__`` is deliberately not overridden: ``nn.Module.__call__``
    # dispatches to ``forward`` and also runs any registered hooks.