├── datasets ├── word_freq │ ├── __init__.py │ ├── unigram │ │ └── __init__.py │ └── polyglot │ │ └── __init__.py ├── __init__.py ├── target_vectors │ ├── __init__.py │ ├── wiki2vec │ │ └── __init__.py │ ├── polyglot │ │ └── __init__.py │ ├── google │ │ └── __init__.py │ ├── glove │ │ └── __init__.py │ └── utils.py ├── affix │ └── __init__.py ├── ud │ ├── __init__.py │ └── make_dataset.py └── word_similarity │ └── __init__.py ├── .gitmodules ├── .gitignore ├── requirements.txt ├── pos_compute_oov.py ├── ws_multilingual_eval_target.py ├── pos_gather_result.py ├── pos_eval_target.py ├── pos_tune_reg_plot.py ├── nshortest.py ├── ws_eval_target.py ├── utils ├── args.py └── __init__.py ├── README.md ├── ws_plot_loss.py ├── load.py ├── pos_exp_sasaki.py ├── pos_tune_reg.py ├── ws_exp_sasaki.py ├── ws_report_nparam.py ├── ws_multilingual_exp_sasaki.py ├── pbos_pred.py ├── affix_exp.py ├── pos_eval.py ├── sasaki_utils.py ├── sasaki_codecs.py ├── ws_exp_pbos.py ├── ws_eval.py ├── ws_multilingual_exp_pbos.py ├── pbos_segment.py ├── pbos_train.py ├── pbos.py ├── pos_exp.py └── subwords.py /datasets/word_freq/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "compact_reconstruction"] 2 | path = compact_reconstruction 3 | url = git@github.com:losyer/compact_reconstruction.git 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea 2 | /datasets/*.* 3 | /datasets/**/*.* 4 | !/datasets/**/*.py 5 | 6 | /results/ 7 | .DS_Store 8 | __pycache__ 9 | .vscode 10 | /venv 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | numpy==1.18.1 3 | progressbar 4 | scipy~=1.4.1 5 | scikit-learn~=0.22.1 6 | chainer~=7.2.0 7 | gensim==3.8.3 8 | cupy-cuda102 # or other version of cuda 9 | -------------------------------------------------------------------------------- /pos_compute_oov.py: -------------------------------------------------------------------------------- 1 | from datasets import prepare_target_vector_paths, polyglot_languages, prepare_ud_paths 2 | from load import load_embedding 3 | 4 | for language in polyglot_languages: 5 | polyglot_path = prepare_target_vector_paths(language).pkl_emb_path 6 | polyglot_vocab, _ = load_embedding(polyglot_path) 7 | polyglot_vocab = set(polyglot_vocab) 8 | 9 | _, ud_vocab_path = prepare_ud_paths(language) 10 | with open(ud_vocab_path) as f: 11 | ud_vocab = [w.strip() for w in f] 12 | 13 | oov = sum(w not in polyglot_vocab for w in ud_vocab) / len(ud_vocab) 14 | print(language, oov) 15 | -------------------------------------------------------------------------------- /ws_multilingual_eval_target.py: -------------------------------------------------------------------------------- 1 | """ 2 | Used to generate target vector statistics 3 | """ 4 | from datasets import prepare_target_vector_paths, prepare_ws_dataset_paths, get_ws_dataset_names 5 | from ws_eval import eval_ws 6 | 7 | for lang in ("de", "en", "it", "ru",): 8 | target_vector_path = prepare_target_vector_paths(f"wiki2vec-{lang}").txt_emb_path 9 | for dataset in get_ws_dataset_names(lang): 10 | data_path = 
prepare_ws_dataset_paths(dataset).txt_path 11 | for oov_handling in ("drop", "zero"): 12 | print(eval_ws(target_vector_path, data_path, lower=True, oov_handling=oov_handling)) -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .target_vectors import prepare_target_vector_paths, polyglot_languages 2 | from .ud import prepare_ud_paths 3 | from .word_freq.polyglot import prepare_polyglot_freq_paths 4 | from .word_freq.unigram import prepare_unigram_freq_paths 5 | from .word_similarity import prepare_ws_combined_query_path, prepare_ws_dataset_paths, get_ws_dataset_names 6 | 7 | __all__ = [ 8 | "polyglot_languages", 9 | 10 | "prepare_target_vector_paths", 11 | 12 | "prepare_ws_combined_query_path", 13 | "prepare_ws_dataset_paths", 14 | "get_ws_dataset_names", 15 | 16 | "prepare_ud_paths", 17 | 18 | "prepare_polyglot_freq_paths", 19 | "prepare_unigram_freq_paths" 20 | 21 | ] 22 | -------------------------------------------------------------------------------- /pos_gather_result.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple script to gather the results for POS 3 | """ 4 | 5 | from pathlib import Path 6 | 7 | pos_result_dir = Path("results") / "pos" 8 | 9 | 10 | def get_acc(lang, model_type): 11 | out_path = lang / model_type / "ud.out" 12 | if not out_path.exists(): 13 | return -1 14 | with open(out_path, 'r') as f: 15 | lines = f.read().splitlines() 16 | if len(lines) == 0: 17 | return -1 18 | last_line = lines[-1] 19 | _, acc = last_line.split(":") 20 | return acc.strip() 21 | 22 | 23 | if __name__ == "__main__": 24 | model_types = ("sasaki", "bos", "pbos", ) 25 | 26 | print("lang", *model_types, sep="\t") 27 | for lang in sorted(pos_result_dir.iterdir()): 28 | print(lang.name, *(get_acc(lang, m) for m in model_types), sep="\t") 29 | -------------------------------------------------------------------------------- /pos_eval_target.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple script used to evaluate the raw Polyglot vectors for POS 3 | 4 | One can redirect the standard output of this file to get rid of the training log 5 | 6 | python pos_eval_target.py 2>train.log 1>eval.log 7 | """ 8 | 9 | import subprocess as sp 10 | 11 | from datasets import prepare_target_vector_paths, polyglot_languages, prepare_ud_paths 12 | 13 | for language_code in polyglot_languages: 14 | ud_vocab_embedding_path = prepare_target_vector_paths(language_code).pkl_emb_path 15 | ud_data_path, ud_vocab_path = prepare_ud_paths(language_code) 16 | 17 | cmd = f""" 18 | python pos_eval.py \ 19 | --dataset {ud_data_path} \ 20 | --embeddings {ud_vocab_embedding_path} \ 21 | """.split() 22 | output = sp.check_output(cmd) 23 | print(f"{language_code}: {output.decode('utf-8')}") 24 | -------------------------------------------------------------------------------- /pos_tune_reg_plot.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script used to plot C vs score for POS results 3 | """ 4 | import argparse 5 | from pathlib import Path 6 | 7 | import matplotlib.pyplot as plt 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--results_dir', help="path to the results directory", default="results/pos_reg_search") 11 | args = parser.parse_args() 12 | 13 | xs = [] 14 | ys = [] 15 | 16 | for path in
Path(args.results_dir).iterdir(): 17 | with open(path, 'r') as f: 18 | line = f.readline() 19 | if len(line) == 0: 20 | continue 21 | _, score_str = line.split(":") 22 | xs.append(float(path.name)) 23 | ys.append(float(score_str.strip())) 24 | 25 | plt.plot(xs, ys, 'o') 26 | plt.xscale('log', basex=10) 27 | plt.show() 28 | 29 | for x, y in sorted(zip(xs, ys)): 30 | print(f"{x}: {y:.6f}") 31 | -------------------------------------------------------------------------------- /nshortest.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | import operator 3 | 4 | def nshortest(adjmat, n): 5 | """N shortest paths given a DAG as adjacency matrix, assuming (i, j) for all i < j. 6 | 7 | Returns: 8 | A list of (score, path), sorted by score. Path is a tuple of node ids. 9 | 10 | Examples: 11 | >>> adj = [[0, 1, 1], [0, 0, 1], [0, 0, 0]] 12 | >>> nshortest(adj, 2) 13 | [(1, (0, 2)), (2, (0, 1, 2))] 14 | """ 15 | candss = [[(0, (0, ))]] 16 | for j in range(1, len(adjmat)): 17 | cands = [] 18 | for i in range(j): 19 | for icand in candss[i]: 20 | iscore, ipath = icand 21 | score = iscore + adjmat[i][j] 22 | path = ipath + (j, ) 23 | cand = score, path 24 | cands.append(cand) 25 | candss.append(heapq.nsmallest(n, cands, key=operator.itemgetter(0))) 26 | return candss[-1] 27 | -------------------------------------------------------------------------------- /datasets/target_vectors/__init__.py: -------------------------------------------------------------------------------- 1 | from .glove import prepare_glove_paths 2 | from .google import prepare_google_paths 3 | from .polyglot import prepare_polyglot_emb_paths, polyglot_languages 4 | from .wiki2vec import prepare_wiki2vec_emb_paths 5 | 6 | 7 | def prepare_target_vector_paths(target_vector_name): 8 | target_vector_name = target_vector_name.lower() 9 | 10 | if target_vector_name.startswith("polyglot-"): 11 | return prepare_polyglot_emb_paths(target_vector_name[-2:]) 12 | if target_vector_name.startswith("wiki2vec-"): 13 | return prepare_wiki2vec_emb_paths(target_vector_name[-2:]) 14 | if target_vector_name == "google": 15 | return prepare_google_paths() 16 | if target_vector_name == "polyglot": 17 | return prepare_polyglot_emb_paths("en") 18 | if target_vector_name == "glove": 19 | return prepare_glove_paths() 20 | if target_vector_name in polyglot_languages: 21 | return prepare_polyglot_emb_paths(target_vector_name) 22 | raise NotImplementedError 23 | -------------------------------------------------------------------------------- /ws_eval_target.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from datasets import prepare_target_vector_paths, get_ws_dataset_names, prepare_ws_dataset_paths 4 | from ws_eval import eval_ws 5 | 6 | 7 | def main(targets): 8 | for target in targets: 9 | target_vector_path = target if target == "EditSim" else prepare_target_vector_paths(target).txt_emb_path 10 | 11 | for dataset in get_ws_dataset_names(): 12 | data_path = prepare_ws_dataset_paths(dataset).txt_path 13 | for oov_handling in ("drop", "zero"): 14 | result = eval_ws(target_vector_path, data_path, lower=True, oov_handling=oov_handling) 15 | print(target, result) 16 | 17 | 18 | if __name__ == '__main__': 19 | parser = argparse.ArgumentParser("Show target vector statistics for the word similarity task") 20 | parser.add_argument( 21 | '--targets', 22 | '-t', 23 | nargs='+', 24 | choices=["EditSim", "polyglot", "google", "glove"], 25 | 
default=["polyglot", "google"] 26 | ) 27 | 28 | args = parser.parse_args() 29 | 30 | main(args.targets) 31 | -------------------------------------------------------------------------------- /utils/args.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def add_logging_args(parser): 5 | group = parser.add_argument_group('logging arguments') 6 | group.add_argument('--log_level', '-ll', default='INFO', 7 | help='log level used by logging module') 8 | return group 9 | 10 | 11 | def set_logging_config(args): 12 | """Set log level using args.log_level""" 13 | numeric_level = getattr(logging, args.log_level.upper(), None) 14 | if not isinstance(numeric_level, int): 15 | raise ValueError('Invalid log level: %s' % args.log_level) 16 | logging.basicConfig(level=numeric_level) 17 | 18 | 19 | def dump_args(args, logger=None, save_path=None): 20 | """ 21 | log args 22 | save the args if `save_path` is not None 23 | """ 24 | import json 25 | 26 | if not isinstance(args, dict): 27 | args = vars(args) 28 | 29 | if logger: 30 | logger.info(json.dumps(args, indent=2)) 31 | else: 32 | logging.info(json.dumps(args, indent=2)) 33 | 34 | if save_path: 35 | with open(save_path, 'w') as fout: 36 | json.dump(args, fout, indent=2) 37 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Setup 2 | 3 | - Clone compact_reconstruction 4 | from `https://github.com/losyer/compact_reconstruction/tree/a55627c99a7b17d556cc96275a4f41b6b93f8782` into the 5 | folder `compact_reconstruction/`. 6 | 7 | ```sh 8 | git submodule update --init 9 | ``` 10 | 11 | 12 | - Install dependencies 13 | 14 | ```sh 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | # Usage 19 | 20 | - To reproduce the word similarity results: 21 | 22 | ```sh 23 | python ws_exp_pbos.py 24 | python ws_exp_sasaki.py 25 | ``` 26 | 27 | The results will be available at `results/ws/{target_vector_name}_{model_type}/result.txt` 28 | 29 | 30 | - To reproduce the multilingual word similarity results: 31 | 32 | ```sh 33 | python ws_multilingual_exp_pbos.py 34 | python ws_multilingual_exp_sasaki.py 35 | ``` 36 | 37 | The results will be available at `results/ws_multi/{lang}_{model_type}/result.txt` 38 | 39 | 40 | - To reproduce the POS tagging results: 41 | 42 | ```sh 43 | python pos_exp.py 44 | python pos_exp_sasaki.py 45 | ``` 46 | 47 | The results will be available in the `results/pos` folder.
48 | 49 | You can print out all the results with `python pos_gather_result.py` 50 | 51 | -------------------------------------------------------------------------------- /datasets/word_freq/unigram/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import subprocess as sp 4 | import os 5 | 6 | from utils import dotdict 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | dir_path = os.path.dirname(os.path.realpath(__file__)) 11 | csv_path = f"{dir_path}/unigram_freq.csv" 12 | word_freq_path = f"{dir_path}/word_freq.jsonl" 13 | 14 | 15 | def prepare_unigram_freq_paths( 16 | dir_path=dir_path, 17 | csv_path=csv_path, 18 | word_freq_path=word_freq_path, 19 | ): 20 | if not os.path.exists(csv_path): 21 | url = "https://raw.githubusercontent.com/jai-dewani/Word-completion/master/unigram_freq.csv" 22 | sp.run(f"wget -O {csv_path} {url}".split()) 23 | 24 | if not os.path.exists(word_freq_path): 25 | with open(csv_path) as fin, open(word_freq_path, "w") as fout: 26 | for i, line in enumerate(fin, start=1): 27 | if i > 1: # skip the column names 28 | word, count = line.split(',') 29 | count = int(count) # make sure count is an int 30 | print(json.dumps((word, count)), file=fout) 31 | 32 | return dotdict( 33 | dir_path=dir_path, 34 | csv_path=csv_path, 35 | word_freq_path=word_freq_path, 36 | ) 37 | -------------------------------------------------------------------------------- /ws_plot_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script used to plot loss 3 | """ 4 | import argparse 5 | from pathlib import Path 6 | 7 | import matplotlib.pyplot as plt 8 | 9 | 10 | def parse_log(log_path): 11 | epochs, loss, epoch_times, total_times = [], [], [], [] 12 | with open(log_path, 'r') as f: 13 | for line in f: 14 | if not line.startswith("INFO:pbos_train:epoch") or "time" not in line: 15 | continue 16 | parts = line.split() 17 | epochs.append(int(parts[1])) 18 | loss.append(float(parts[6])) 19 | epoch_times.append(float(parts[9][:-1])) 20 | total_times.append(float(parts[11][:-1])) 21 | return epochs, loss, epoch_times, total_times 22 | 23 | 24 | def plot_loss(result_paths): 25 | for model_path in Path(result_paths).iterdir(): 26 | log_path = model_path / "info.log" 27 | if not log_path.exists(): 28 | continue 29 | epochs, loss, epoch_times, total_times = parse_log(log_path) 30 | plt.plot(epochs, loss, '.') 31 | plt.title(str(model_path)) 32 | plt.show() 33 | 34 | print(f"average epoch time for {str(model_path):<60}: {total_times[-1] / epochs[-1]:.3f}") 35 | 36 | 37 | if __name__ == "__main__": 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('--results_dir', help="path to the results directory", default="results/ws") 40 | args = parser.parse_args() 41 | 42 | plot_loss(args.results_dir) -------------------------------------------------------------------------------- /load.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List 3 | 4 | import numpy as np 5 | 6 | 7 | def load_embedding(filename: str, show_progress=False) -> (List[str], np.ndarray): 8 | """ 9 | :param filename: a .txt/.w2v file or a .pkl/.pickle file 10 | :return: tuple (words, embeddings) 11 | """ 12 | import os 13 | if show_progress: 14 | from utils import file_tqdm 15 | else: 16 | from utils import dummy_tqdm as file_tqdm 17 | 18 | _, ext = os.path.splitext(filename) 19 | if ext in (".txt", ".w2v"): 20 | vocab, emb
= [], [] 21 | with open(filename, "r") as fin: 22 | if ext == ".w2v": 23 | next(fin) 24 | for line in file_tqdm(fin): 25 | ss = line.split() 26 | try: 27 | emb.append([float(x) for x in ss[1:]]) 28 | vocab.append(ss[0]) 29 | except ValueError: 30 | print(f"Error loading the line: {line[:30]} ...") 31 | emb = np.array(emb) 32 | elif ext in (".pickle", ".pkl"): 33 | import pickle 34 | try: 35 | with open(filename, 'rb') as bfin: 36 | vocab, emb = pickle.load(bfin) 37 | except UnicodeDecodeError: 38 | with open(filename, 'rb') as bfin: 39 | vocab, emb = pickle.load(bfin, encoding='bytes') 40 | else: 41 | raise ValueError(f'Unsupported target vector file extent: {filename}') 42 | 43 | return vocab, emb 44 | -------------------------------------------------------------------------------- /pos_exp_sasaki.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | import subprocess as sp 3 | from pathlib import Path 4 | 5 | from datasets import prepare_target_vector_paths, polyglot_languages, prepare_ud_paths, prepare_polyglot_freq_paths 6 | from sasaki_utils import inference, train, prepare_codecs_path 7 | 8 | 9 | def exp(language): 10 | result_path = Path("results") / "pos" / language / "sasaki" 11 | 12 | emb_path = prepare_target_vector_paths(language).w2v_emb_path 13 | freq_path = prepare_polyglot_freq_paths(language).raw_count_path 14 | codecs_path = prepare_codecs_path(emb_path, result_path) 15 | ud_data_path, ud_vocab_path = prepare_ud_paths(language) 16 | 17 | model_info = train( 18 | emb_path, 19 | result_path, 20 | freq_path=freq_path, 21 | codecs_path=codecs_path, 22 | epoch=300, 23 | H=40_000, 24 | F=500_000 25 | ) 26 | 27 | result_emb_path = inference(model_info, ud_vocab_path) 28 | 29 | with open(result_path / "ud.out", "w") as fout, open(result_path / "ud.log", "w") as ferr: 30 | cmd = f""" 31 | python pos_eval.py \ 32 | --dataset {ud_data_path} \ 33 | --embeddings {result_emb_path} \ 34 | --C {70} \ 35 | """.split() 36 | sp.call(cmd, stdout=fout, stderr=ferr) 37 | 38 | 39 | if __name__ == "__main__": 40 | with mp.Pool() as pool: 41 | for lang in polyglot_languages: 42 | pool.apply_async(exp, (lang,)) 43 | 44 | pool.close() 45 | pool.join() 46 | -------------------------------------------------------------------------------- /pos_tune_reg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script used to tune regression C for POS 3 | """ 4 | import argparse 5 | import os 6 | import subprocess as sp 7 | import multiprocessing as mp 8 | from itertools import product 9 | 10 | C_interval = sorted(x * 10 ** b for x, b in product(range(1, 10), range(-1, 5))) 11 | 12 | 13 | def evaluate(results_dir, embeddings, dataset, C): 14 | with open(f"{results_dir}/{C:.1f}", "w+") as f: 15 | sp.call( 16 | f"python pos_eval.py \ 17 | --embeddings {embeddings} \ 18 | --C {C} \ 19 | --dataset {dataset} \ 20 | ".split(), 21 | stdout=f 22 | ) 23 | 24 | 25 | def main(results_dir, embeddings, dataset): 26 | os.makedirs(results_dir, exist_ok=True) 27 | with mp.Pool() as pool: 28 | results = [ 29 | pool.apply_async(evaluate, (results_dir, embeddings, dataset, C)) 30 | for C in C_interval 31 | ] 32 | 33 | for r in results: 34 | r.get() 35 | 36 | 37 | if __name__ == '__main__': 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('--dataset', required=True, help="path to dataset") 40 | parser.add_argument( 41 | '--embeddings', 42 | required=True, 43 | help="path to word embeddings" 44 | ) 45 | 
parser.add_argument( 46 | '--results_dir', 47 | help="path to the results directory", 48 | default="results/pos_reg_search" 49 | ) 50 | args = parser.parse_args() 51 | main(args.results_dir, args.embeddings, args.dataset) 52 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | 3 | from tqdm import tqdm 4 | 5 | 6 | def get_substrings(s: str, min_len=1, max_len=None) -> List[str]: 7 | """ 8 | :param s: string 9 | :return: a list of contiguous substrings 10 | """ 11 | max_len = max_len or len(s) 12 | for j in range(min_len, len(s) + 1): 13 | for i in range(max(0, j - max_len), max(0, j - min_len + 1)): 14 | yield s[i:j] 15 | 16 | 17 | def normalize_prob(subword_count: Dict[str, int]) -> Dict[str, float]: 18 | """ 19 | :param subword_count: dictionary {word: count} 20 | :return: normalized probability {word: probability}, the length of word is also normalized 21 | """ 22 | total = sum(subword_count.values()) 23 | if total == 0: 24 | return {} 25 | return {k: (v / total) for k, v in subword_count.items()} 26 | 27 | 28 | def dummy_tqdm(x, *args, **kwargs): 29 | return x 30 | 31 | 32 | def get_number_of_lines(fobj): 33 | pos = fobj.tell() 34 | nol = sum(1 for _ in fobj) 35 | fobj.seek(pos) 36 | return nol 37 | 38 | 39 | def file_tqdm(fobj): 40 | return tqdm(fobj, total=get_number_of_lines(fobj)) 41 | 42 | 43 | class dotdict(dict): 44 | __getattr__ = dict.__getitem__ 45 | __setattr__ = dict.__setitem__ 46 | __delattr__ = dict.__delitem__ 47 | 48 | @classmethod 49 | def nested(cls, dct): 50 | d = cls() 51 | for key, value in dct.items(): 52 | if hasattr(value, 'keys'): 53 | value = cls(value) 54 | d[key] = value 55 | return d 56 | -------------------------------------------------------------------------------- /ws_exp_sasaki.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import multiprocessing as mp 3 | from pathlib import Path 4 | 5 | from datasets import prepare_ws_combined_query_path, prepare_target_vector_paths 6 | from sasaki_utils import inference, prepare_codecs_path, train 7 | from utils import dotdict 8 | from ws_exp_pbos import evaluate 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def exp(ref_vec_name): 14 | result_path = Path("results") / "ws" / f"{ref_vec_name}_sasaki" 15 | ref_vec_path = prepare_target_vector_paths(ref_vec_name).w2v_emb_path 16 | codecs_path = prepare_codecs_path(ref_vec_path, result_path) 17 | 18 | log_file = open(result_path / "log.txt", "w+") 19 | logging.basicConfig(level=logging.DEBUG, stream=log_file) 20 | 21 | logger.info("Training...") 22 | model_info = train( 23 | ref_vec_path, 24 | result_path, 25 | codecs_path=codecs_path, 26 | H=40_000, 27 | F=500_000, 28 | epoch=300, 29 | ) 30 | 31 | logger.info("Inferencing...") 32 | combined_query_path = prepare_ws_combined_query_path() 33 | result_emb_path = inference(model_info, combined_query_path) 34 | 35 | logger.info("Evaluating...") 36 | evaluate(dotdict( 37 | eval_result_path=result_path / "result.txt", 38 | pred_path=result_emb_path 39 | )) 40 | 41 | 42 | if __name__ == '__main__': 43 | with mp.Pool() as pool: 44 | target_vector_names = ("polyglot", "google") 45 | 46 | results = [ 47 | pool.apply_async(exp, (ref_vec_name,)) 48 | for ref_vec_name in target_vector_names 49 | ] 50 | 51 | for r in results: 52 | r.get() 53 | 
-------------------------------------------------------------------------------- /ws_report_nparam.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple script to print the number of parameters for all models in a directory. 3 | """ 4 | 5 | import argparse 6 | import pickle 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | 11 | 12 | def count_nparam_sasaki(result_dir): 13 | from sasaki_utils import get_info_from_result_path 14 | 15 | model_info = get_info_from_result_path(result_dir / "sep_kvq") 16 | model_path = model_info["model_path"] 17 | model = np.load(model_path) 18 | 19 | nparam = 0 20 | for filename in model.files: 21 | # print(filename, model[filename].size) 22 | nparam += model[filename].size 23 | 24 | return nparam 25 | 26 | 27 | def count_nparam_pbos(result_dir): 28 | model_path = result_dir / "model.pkl" 29 | if not model_path.exists(): 30 | return -1 31 | with open(model_path, "rb") as f: 32 | model = pickle.load(f, encoding='bytes') 33 | nsubwords = len(model) 34 | embed_dim = len(next(iter(model.values()))) 35 | return nsubwords * embed_dim 36 | 37 | 38 | def main(results_dir): 39 | for result_dir in sorted(Path(results_dir).iterdir()): 40 | if "bos" in result_dir.name: 41 | nparam = count_nparam_pbos(result_dir) 42 | elif "sasaki" in result_dir.name: 43 | nparam = count_nparam_sasaki(result_dir) 44 | else: 45 | nparam = -1 46 | print(f"{result_dir.name:<20}{nparam:,}") 47 | 48 | 49 | if __name__ == '__main__': 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument( 52 | '--results_dir', '-d', 53 | help="path to the results directory", 54 | default="results/ws" 55 | ) 56 | args = parser.parse_args() 57 | main(args.results_dir) 58 | -------------------------------------------------------------------------------- /datasets/affix/__init__.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import logging 3 | import os 4 | import shutil 5 | import subprocess as sp 6 | from itertools import islice 7 | from pathlib import Path 8 | 9 | from utils import dotdict 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | dir_path = os.path.dirname(os.path.realpath(__file__)) 14 | 15 | 16 | def prepare_affix_paths( 17 | *, 18 | dir_path=dir_path, 19 | ): 20 | dir_path = Path(dir_path) 21 | gz_path = dir_path / "affix_complete_set.txt.gz" 22 | raw_path = dir_path / "affix_complete_set.txt" 23 | queries_path = dir_path / "queries.txt" 24 | 25 | if not os.path.exists(gz_path): 26 | logger.info(f"Downloading {gz_path}") 27 | url = "http://marcobaroni.org/PublicData/affix_complete_set.txt.gz" 28 | sp.run(f"wget -O {gz_path} {url}".split()) 29 | 30 | if not os.path.exists(raw_path): 31 | logger.info(f"Unzipping {raw_path}") 32 | with gzip.open(gz_path, 'rb') as fin, open(raw_path, 'wb') as fout: 33 | shutil.copyfileobj(fin, fout) 34 | 35 | if not os.path.exists(queries_path): 36 | logger.info(f"Making {queries_path}") 37 | with open(raw_path) as fin, open(queries_path, 'w') as fout: 38 | for line in islice(fin, 1, None): ## skip the title row 39 | ## row fmt: affix stem stemPOS derived derivedPOS type ... 
40 | affix, stem, _, derived, _, split = line.split()[:6] 41 | print(derived, file=fout) 42 | if derived.lower() != derived: 43 | print(derived.lower(), file=fout) 44 | 45 | return dotdict( 46 | dir_path = dir_path, 47 | gz_path = gz_path, 48 | raw_path = raw_path, 49 | queries_path = queries_path, 50 | ) 51 | 52 | 53 | if __name__ == '__main__': 54 | logging.basicConfig(level=logging.INFO) 55 | prepare_affix_paths() 56 | -------------------------------------------------------------------------------- /ws_multilingual_exp_sasaki.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import multiprocessing as mp 3 | from pathlib import Path 4 | 5 | from datasets import prepare_target_vector_paths, prepare_ws_combined_query_path 6 | from sasaki_utils import inference, prepare_codecs_path, train, get_info_from_result_path 7 | from utils import dotdict 8 | from ws_multilingual_exp_pbos import evaluate 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def exp(ref_vec_name): 14 | result_path = Path("results") / "ws_multi" / f"{ref_vec_name}_sasaki" 15 | ref_vec_path = prepare_target_vector_paths(f"wiki2vec-{ref_vec_name}").w2v_emb_path 16 | codecs_path = prepare_codecs_path(ref_vec_path, result_path) 17 | 18 | log_file = open(result_path / "log.txt", "w+") 19 | logging.basicConfig(level=logging.DEBUG, stream=log_file) 20 | 21 | logger.info("Training...") 22 | train( 23 | ref_vec_path, 24 | result_path, 25 | codecs_path=codecs_path, 26 | H=40_000, 27 | F=500_000, 28 | epoch=300, 29 | ) 30 | 31 | model_info = get_info_from_result_path(result_path / "sep_kvq") 32 | 33 | logger.info("Inferencing...") 34 | combined_query_path = prepare_ws_combined_query_path(ref_vec_name) 35 | result_emb_path = inference(model_info, combined_query_path) 36 | 37 | logger.info("Evaluating...") 38 | evaluate(dotdict( 39 | model_type="sasaki", 40 | eval_result_path=result_path / "result.txt", 41 | pred_path=result_emb_path, 42 | target_vector_name=ref_vec_name, 43 | results_dir=result_path, 44 | )) 45 | 46 | 47 | if __name__ == '__main__': 48 | with mp.Pool() as pool: 49 | target_vector_names = ("en", "de", "it", "ru") 50 | 51 | results = [ 52 | pool.apply_async(exp, (ref_vec_name,)) 53 | for ref_vec_name in target_vector_names 54 | ] 55 | 56 | for r in results: 57 | r.get() 58 | -------------------------------------------------------------------------------- /datasets/target_vectors/wiki2vec/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import subprocess as sp 4 | 5 | from datasets.target_vectors.utils import convert_target_dataset 6 | from utils import dotdict 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | dir_path = os.path.dirname(os.path.realpath(__file__)) 11 | 12 | 13 | def prepare_wiki2vec_emb_paths(lang, *, dir_path=dir_path): 14 | language_dir_path = os.path.join(dir_path, lang) 15 | download_path = os.path.join(language_dir_path, "raw_embeddings.w2v.bz2") 16 | pkl_emb_path = os.path.join(language_dir_path, "embeddings.pkl") 17 | w2v_emb_path = os.path.join(language_dir_path, "embeddings.w2v") 18 | word_freq_path = os.path.join(language_dir_path, "word_freq.jsonl") 19 | txt_emb_path = os.path.join(language_dir_path, "embeddings.txt") 20 | 21 | os.makedirs(language_dir_path, exist_ok=True) 22 | 23 | if not os.path.exists(download_path): 24 | url = f"http://wikipedia2vec.s3.amazonaws.com/models/{lang}/2018-04-20/{lang}wiki_20180420_300d.txt.bz2" 25 | 
logger.info(f"Downloading {url} to {download_path}") 26 | sp.run(f"wget -O {download_path} {url}".split()) 27 | 28 | if not os.path.exists(w2v_emb_path): 29 | logger.info(f"Unzipping {download_path}") 30 | sp.run(f"bzip2 -dk {download_path}".split()) 31 | os.system(f"head -n 100001 {download_path[:-4]} > {w2v_emb_path}") # keep 100k tokens and one line of header 32 | 33 | convert_target_dataset( 34 | input_emb_path=w2v_emb_path, 35 | 36 | txt_emb_path=txt_emb_path, 37 | pkl_emb_path=pkl_emb_path, 38 | 39 | word_freq_path=word_freq_path, 40 | ) 41 | 42 | return dotdict( 43 | dir_path=dir_path, 44 | language_dir_path=language_dir_path, 45 | download_path=download_path, 46 | 47 | pkl_emb_path=pkl_emb_path, 48 | w2v_emb_path=w2v_emb_path, 49 | txt_emb_path=txt_emb_path, 50 | word_freq_path=word_freq_path, 51 | ) 52 | 53 | 54 | languages = ["en", "it", "ru", "de"] 55 | 56 | if __name__ == '__main__': 57 | logging.basicConfig(level=logging.INFO) 58 | for language_code in languages: 59 | prepare_wiki2vec_emb_paths(language_code) 60 | -------------------------------------------------------------------------------- /datasets/word_freq/polyglot/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import shutil 5 | import subprocess as sp 6 | import tarfile 7 | 8 | from utils import dotdict 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | dir_path = os.path.dirname(os.path.realpath(__file__)) 13 | 14 | 15 | def prepare_polyglot_freq_paths( 16 | language_code, 17 | *, 18 | dir_path=dir_path, 19 | word_freq_max_size=1000000, 20 | ): 21 | language_dir_path = os.path.join(dir_path, language_code) 22 | tar_path = os.path.join(language_dir_path, f"{language_code}.voc.tar.bz2") 23 | raw_count_path = os.path.join(language_dir_path, "count.txt") 24 | word_freq_path = os.path.join(language_dir_path, "word_freq.jsonl") 25 | 26 | os.makedirs(language_dir_path, exist_ok=True) 27 | 28 | if not os.path.exists(tar_path): 29 | logger.info(f"Downloading {tar_path}") 30 | url = f"http://polyglot.cs.stonybrook.edu/~polyglot/counts2/{language_code}/{language_code}.voc.tar.bz2" 31 | sp.run(f"wget -O {tar_path} {url}".split()) 32 | 33 | if not os.path.exists(raw_count_path): 34 | logger.info(f"Unzipping {raw_count_path}") 35 | with tarfile.open(tar_path) as tar, open(raw_count_path, 'wb+') as dst_file: 36 | src_file = tar.extractfile(f"counts/{language_code}.docs.txt.voc") 37 | shutil.copyfileobj(src_file, dst_file) 38 | 39 | if not os.path.exists(word_freq_path): 40 | with open(raw_count_path) as fin, open(word_freq_path, "w") as fout: 41 | for i_line, line in enumerate(fin): 42 | if i_line >= word_freq_max_size: 43 | break 44 | word, count = line.split() 45 | count = int(count) 46 | print(json.dumps((word, count)), file=fout) 47 | 48 | return dotdict( 49 | dir_path = dir_path, 50 | language_dir_path = language_dir_path, 51 | tar_path = tar_path, 52 | raw_count_path = raw_count_path, 53 | word_freq_path = word_freq_path, 54 | ) 55 | 56 | 57 | if __name__ == '__main__': 58 | from datasets import polyglot_languages 59 | 60 | for language_code in polyglot_languages: 61 | prepare_polyglot_freq_paths(language_code) 62 | -------------------------------------------------------------------------------- /pbos_pred.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import sys 5 | import time 6 | 7 | from pbos import PBoS 8 | from subwords 
import add_word_args, bound_word 9 | from utils.args import add_logging_args, set_logging_config 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def predict(model, word_boundary, queries=None, save=None, pre_trained=None): 15 | """ 16 | :return: The total time used in prediction 17 | """ 18 | if pre_trained: 19 | from load import load_embedding 20 | 21 | pre_trained_vocab, pre_trained_emb = load_embedding(pre_trained) 22 | pre_trained_emb_lookup = dict(zip(pre_trained_vocab, pre_trained_emb)) 23 | 24 | logger.info("loading...") 25 | model = PBoS.load(model) 26 | logging.info("generating...") 27 | if queries: 28 | fin = open(queries, "r", encoding="utf-8") 29 | else: 30 | fin = sys.stdin 31 | if save: 32 | save_dir = os.path.dirname(save) 33 | try: 34 | os.makedirs(save_dir) 35 | except FileExistsError: 36 | logger.warning("Things will get overwritten for directory {}".format(save_dir)) 37 | fout = open(save, "w", encoding="utf-8") 38 | else: 39 | fout = sys.stdout 40 | 41 | start = time.time() 42 | for line in fin: 43 | ori_query = line.strip() 44 | query = bound_word(ori_query) if word_boundary else ori_query 45 | if pre_trained: 46 | vector = ( 47 | pre_trained_emb_lookup[query] 48 | if query in pre_trained_emb_lookup 49 | else model.embed(query) 50 | ) 51 | else: 52 | vector = model.embed(query) 53 | print(ori_query, *vector, file=fout) 54 | 55 | return time.time() - start 56 | 57 | 58 | if __name__ == '__main__': 59 | parser = argparse.ArgumentParser(description="Bag of substrings: prediction") 60 | parser.add_argument("--pre_trained", 61 | help="If this variable is specified, only use the model for OOV, " 62 | "and use the pre_trained vectors for query") 63 | parser.add_argument("--model", required=True) 64 | parser.add_argument("--save", help="If not specified, use stdout.") 65 | parser.add_argument("--queries", help="If not specified, use stdin.") 66 | add_logging_args(parser) 67 | add_word_args(parser) 68 | args = parser.parse_args() 69 | 70 | set_logging_config(args) 71 | 72 | predict(args.model, args.word_boundary, args.queries, args.save, args.pre_trained) 73 | -------------------------------------------------------------------------------- /datasets/target_vectors/polyglot/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shutil 4 | import subprocess as sp 5 | import tarfile 6 | 7 | from datasets.target_vectors.utils import convert_target_dataset 8 | from utils import dotdict 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | dir_path = os.path.dirname(os.path.realpath(__file__)) 13 | 14 | 15 | def prepare_polyglot_emb_paths(lang, *, dir_path=dir_path): 16 | if lang not in polyglot_languages: 17 | raise NotImplementedError 18 | 19 | language_dir_path = os.path.join(dir_path, lang) 20 | tar_path = os.path.join(language_dir_path, "embeddings.tar.bz2") 21 | raw_pkl_emb_path = os.path.join(language_dir_path, "raw_embeddings.pkl") 22 | pkl_emb_path = os.path.join(language_dir_path, "embeddings.pkl") 23 | w2v_emb_path = os.path.join(language_dir_path, "embeddings.w2v") 24 | word_freq_path = os.path.join(language_dir_path, "word_freq.jsonl") 25 | txt_emb_path = os.path.join(language_dir_path, "embeddings.txt") 26 | 27 | os.makedirs(language_dir_path, exist_ok=True) 28 | 29 | if not os.path.exists(tar_path): 30 | logger.info(f"Downloading {tar_path}") 31 | url = f"http://polyglot.cs.stonybrook.edu/~polyglot/embeddings2/{lang}/embeddings_pkl.tar.bz2" 32 | sp.run(f"wget -O
{tar_path} {url}".split()) 33 | 34 | if not os.path.exists(raw_pkl_emb_path): 35 | logger.info(f"Unzipping {tar_path}") 36 | with tarfile.open(tar_path) as tar, open(raw_pkl_emb_path, 'wb+') as dst_file: 37 | src_file = tar.extractfile("./words_embeddings_32.pkl") 38 | shutil.copyfileobj(src_file, dst_file) 39 | 40 | convert_target_dataset( 41 | input_emb_path=raw_pkl_emb_path, 42 | 43 | txt_emb_path=txt_emb_path, 44 | w2v_emb_path=w2v_emb_path, 45 | pkl_emb_path=pkl_emb_path, 46 | word_freq_path=word_freq_path, 47 | 48 | # vocab_size_limit=10_000 49 | ) 50 | 51 | return dotdict( 52 | dir_path=dir_path, 53 | language_dir_path=language_dir_path, 54 | tar_path=tar_path, 55 | 56 | pkl_emb_path=pkl_emb_path, 57 | w2v_emb_path=w2v_emb_path, 58 | txt_emb_path=txt_emb_path, 59 | word_freq_path=word_freq_path, 60 | ) 61 | 62 | 63 | polyglot_languages = [ 64 | 'ar', 'bg', 'cs', 'da', 'el', 'en', 'es', 'eu', 'fa', 'he', 'hi', 'hu', 65 | 'id', 'it', 'kk', 'lv', 'ro', 'ru', 'sv', 'ta', 'tr', 'vi', 'zh', 66 | ] 67 | 68 | if __name__ == '__main__': 69 | logging.basicConfig(level=logging.INFO) 70 | for language_code in polyglot_languages: 71 | prepare_polyglot_emb_paths(language_code) 72 | -------------------------------------------------------------------------------- /datasets/target_vectors/google/__init__.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import logging 3 | import os 4 | import shutil 5 | import subprocess as sp 6 | 7 | import gensim 8 | 9 | from datasets.target_vectors.utils import save_target_dataset, clean_target_emb, convert_target_dataset 10 | from utils import dotdict 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | dir_path = os.path.dirname(os.path.realpath(__file__)) 15 | gz_path = f"{dir_path}/embedding.bin.gz" 16 | bin_emb_path = f"{dir_path}/embedding.bin" 17 | txt_emb_path = f"{dir_path}/embedding.txt" 18 | pkl_emb_path = f"{dir_path}/embedding.pkl" 19 | w2v_emb_path = f"{dir_path}/embedding.w2v" 20 | word_list_path = f"{dir_path}/word_list.txt" 21 | word_freq_path = f"{dir_path}/word_freq.jsonl" 22 | raw_count_path = f"{dir_path}/word_freq.txt" 23 | 24 | 25 | def prepare_google_paths( 26 | dir_path=dir_path, 27 | gz_path=gz_path, 28 | bin_emb_path=bin_emb_path, 29 | txt_emb_path=txt_emb_path, 30 | pkl_emb_path=pkl_emb_path, 31 | word_list_path=word_list_path, 32 | word_freq_path=word_freq_path, 33 | w2v_emb_path=w2v_emb_path, 34 | raw_count_path=raw_count_path, 35 | ): 36 | if not os.path.exists(gz_path): 37 | url = "https://s3.amazonaws.com/dl4j-distribution/GoogleNews-vectors-negative300.bin.gz" 38 | sp.run(f"wget -O {gz_path} {url}".split()) 39 | 40 | if not os.path.exists(bin_emb_path): 41 | with gzip.open(gz_path, "rb") as fin, open(bin_emb_path, "wb") as fout: 42 | shutil.copyfileobj(fin, fout) 43 | 44 | if not os.path.exists(pkl_emb_path): 45 | logger.info("loading pre-trained google news vectors...") 46 | model = gensim.models.KeyedVectors.load_word2vec_format(bin_emb_path, binary=True) 47 | vocab, emb = clean_target_emb(raw_vocab=list(model.vocab), raw_emb=model.vectors) 48 | save_target_dataset(vocab, emb, pkl_emb_path=pkl_emb_path) 49 | 50 | convert_target_dataset( 51 | input_emb_path=pkl_emb_path, 52 | 53 | txt_emb_path=txt_emb_path, 54 | w2v_emb_path=w2v_emb_path, 55 | 56 | word_list_path=word_list_path, 57 | word_freq_path=word_freq_path, 58 | raw_count_path=raw_count_path, 59 | ) 60 | 61 | return dotdict( 62 | dir_path=dir_path, 63 | gz_path=gz_path, 64 | 65 | bin_emb_path=bin_emb_path, 66 | 
txt_emb_path=txt_emb_path, 67 | pkl_emb_path=pkl_emb_path, 68 | w2v_emb_path=w2v_emb_path, 69 | 70 | word_list_path=word_list_path, 71 | word_freq_path=word_freq_path, 72 | raw_count_path=raw_count_path, 73 | ) 74 | 75 | 76 | if __name__ == '__main__': 77 | logging.basicConfig(level=logging.INFO) 78 | prepare_google_paths() 79 | -------------------------------------------------------------------------------- /datasets/ud/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import subprocess as sp 4 | import tarfile 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | dir_path = os.path.dirname(os.path.realpath(__file__)) 9 | 10 | lang_folder_dic = { 11 | "cu": "UD_Old_Church_Slavonic", 12 | "id": "UD_Indonesian", 13 | "ug": "UD_Uyghur", 14 | "ru": "UD_Russian", 15 | "la": "UD_Latin", 16 | "tr": "UD_Turkish", 17 | "hr": "UD_Croatian", 18 | "uk": "UD_Ukrainian", 19 | "hi": "UD_Hindi", 20 | "sk": "UD_Slovak", 21 | "sl": "UD_Slovenian", 22 | "cs": "UD_Czech", 23 | "kk": "UD_Kazakh", 24 | "ga": "UD_Irish", 25 | "de": "UD_German", 26 | "lv": "UD_Latvian", 27 | "co": "UD_Coptic", 28 | "pt": "UD_Portuguese", 29 | "ca": "UD_Catalan", 30 | "no": "UD_Norwegian", 31 | "nl": "UD_Dutch", 32 | "he": "UD_Hebrew", 33 | "da": "UD_Danish", 34 | "fr": "UD_French", 35 | "pl": "UD_Polish", 36 | "zh": "UD_Chinese", 37 | "fa": "UD_Persian", 38 | "ta": "UD_Tamil", 39 | "hu": "UD_Hungarian", 40 | "ja": "UD_Japanese", 41 | "et": "UD_Estonian", 42 | "go": "UD_Gothic", 43 | "eu": "UD_Basque", 44 | "en": "UD_English", 45 | "it": "UD_Italian", 46 | "gl": "UD_Galician", 47 | "vi": "UD_Vietnamese", 48 | "ro": "UD_Romanian", 49 | "el": "UD_Greek", 50 | "es": "UD_Spanish", 51 | "bg": "UD_Bulgarian", 52 | "sa": "UD_Sanskrit", 53 | "sv": "UD_Swedish", 54 | "ar": "UD_Arabic", 55 | "fi": "UD_Finnish", 56 | } 57 | 58 | 59 | def prepare_ud_paths(language): 60 | tgz_path = f"{dir_path}/ud-treebanks-v1.4.tgz" 61 | language_folder_path = ( 62 | f"{dir_path}/ud-treebanks-v1.4/{lang_folder_dic[language]}" 63 | ) 64 | vocab_path = f"{language_folder_path}/vocab.txt" 65 | data_path = f"{language_folder_path}/combined.pkl" 66 | 67 | if not os.path.exists(tgz_path): 68 | url = "https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-1827/ud-treebanks-v1.4.tgz?sequence=4&isAllowed=y" 69 | sp.run(f"wget -O {tgz_path} {url}".split()) 70 | 71 | if not os.path.exists(f"{dir_path}/ud-treebanks-v1.4"): 72 | with tarfile.open(tgz_path) as tar: 73 | tar.extractall(dir_path) 74 | 75 | if not os.path.exists(vocab_path) or not os.path.exists(data_path): 76 | sp.run( 77 | f""" 78 | python {dir_path}/make_dataset.py \ 79 | --training-data {language_folder_path}/{language}-ud-train.conllu \ 80 | --dev-data {language_folder_path}/{language}-ud-dev.conllu \ 81 | --test-data {language_folder_path}/{language}-ud-test.conllu \ 82 | --output {data_path} \ 83 | --vocab {vocab_path} \ 84 | --ud-tags 85 | """.split() 86 | ) 87 | 88 | return data_path, vocab_path 89 | 90 | 91 | if __name__ == "__main__": 92 | from datasets import polyglot_languages 93 | 94 | for language in polyglot_languages: 95 | prepare_ud_paths(language) 96 | -------------------------------------------------------------------------------- /affix_exp.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from itertools import islice 3 | 4 | import numpy as np 5 | import sklearn.metrics 6 | 7 | from datasets.affix import 
prepare_affix_paths 8 | from pbos import PBoS 9 | from utils import get_substrings 10 | 11 | 12 | def bos_predictor(w, rng, possible_affixes): 13 | affixes = list(sorted(possible_affixes & set(get_substrings(w)))) ## sort to ensure reproducibility 14 | return rng.choice(affixes) 15 | 16 | 17 | def pbos_predictor(w, model, possible_affixes): 18 | subword_weights = model._calc_subword_weights(w) 19 | score, affix = max((subword_weights[af], af) 20 | for af in possible_affixes if af in subword_weights) 21 | return affix 22 | 23 | 24 | def print_metrics(true_y, pred_y): 25 | for average_scheme in ('micro', 'macro'): 26 | for score_name in ('precision', 'recall', 'f1'): 27 | print("{} {}:\t{}".format( 28 | average_scheme, 29 | score_name, 30 | getattr(sklearn.metrics, score_name + "_score")( 31 | true_y, pred_y, 32 | average=average_scheme, 33 | ), 34 | )) 35 | 36 | 37 | def main(args): 38 | word_affix_pairs = [] 39 | with open(prepare_affix_paths().raw_path) as fin: 40 | for line in islice(fin, 1, None): ## skip the title row 41 | ## row fmt: affix stem stemPOS derived derivedPOS type ... 42 | affix, stem, _, derived, _, split = line.split()[:6] 43 | affix = affix.strip('-') 44 | if affix != 'y': 45 | word_affix_pairs.append((derived, affix)) 46 | 47 | possible_affixes = set(af for w, af in word_affix_pairs) 48 | print(f"# interesting possible affixes: {len(possible_affixes)}") 49 | 50 | interesting_word_affix_pairs = [ 51 | (w, af) 52 | for w, af in word_affix_pairs 53 | if len(possible_affixes & set(get_substrings(w))) > 1 54 | ] 55 | print(f"# interesting words: {len(interesting_word_affix_pairs)}") 56 | 57 | 58 | true_affixes = [af for w, af in interesting_word_affix_pairs] 59 | print("bos affix prediction:") 60 | bos_predict = partial( 61 | bos_predictor, 62 | rng=np.random.RandomState(args.seed), 63 | possible_affixes=possible_affixes, 64 | ) 65 | bos_affixes = [bos_predict(w) for w, af in interesting_word_affix_pairs] 66 | print_metrics(true_affixes, bos_affixes) 67 | print("pbos affix prediction:") 68 | pbos_predict = partial( 69 | pbos_predictor, 70 | model=PBoS.load(args.pbos), 71 | possible_affixes=possible_affixes, 72 | ) 73 | pbos_affixes = [pbos_predict(w) for w, af in interesting_word_affix_pairs] 74 | print_metrics(true_affixes, pbos_affixes) 75 | 76 | 77 | if __name__ == '__main__': 78 | import argparse 79 | parser = argparse.ArgumentParser(description='PBoS trainer', 80 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 81 | parser.add_argument('--pbos', 82 | help="path to pbos model") 83 | parser.add_argument('--seed', type=int, default=1337, 84 | help="random seed") 85 | args = parser.parse_args() 86 | main(args) 87 | -------------------------------------------------------------------------------- /pos_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import pickle 4 | from collections import namedtuple 5 | from itertools import chain, repeat 6 | 7 | import numpy as np 8 | from sklearn.linear_model import LogisticRegression 9 | from tqdm import tqdm 10 | 11 | from load import load_embedding 12 | from utils.args import add_logging_args, set_logging_config 13 | 14 | parser = argparse.ArgumentParser("Evaluate embedding on POS tagging", 15 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 16 | parser.add_argument('--dataset', 17 | help="path to processed UD dataset") 18 | parser.add_argument('--embeddings', 19 | help="path to word embeddings") 20 | 
parser.add_argument('--random_seed', type=int, default=42, 21 | help="random seed for training the classifier") 22 | parser.add_argument('--C', type=float, default=1.0, help="Inverse of regularization strength") 23 | add_logging_args(parser) 24 | args = parser.parse_args() 25 | 26 | set_logging_config(args) 27 | 28 | Instance = namedtuple("Instance", ["sentence", "tags"]) 29 | 30 | ## Load Mimick-style UD data 31 | with open(args.dataset, 'rb') as f: 32 | dataset = pickle.load(f) 33 | w2i = dataset["w2i"] 34 | t2is = dataset["t2is"] 35 | c2i = dataset["c2i"] 36 | i2w = {i: w for w, i in list(w2i.items())} 37 | i2ts = { 38 | att: {i: t 39 | for t, i in list(t2i.items())} 40 | for att, t2i in list(t2is.items()) 41 | } 42 | i2c = {i: c for c, i in list(c2i.items())} 43 | 44 | training_instances = dataset["training_instances"] 45 | training_vocab = dataset["training_vocab"] 46 | dev_instances = dataset["dev_instances"] 47 | dev_vocab = dataset["dev_vocab"] 48 | test_instances = dataset["test_instances"] 49 | 50 | ## Load embeddings 51 | vocab, emb = load_embedding(args.embeddings) 52 | emb_w2i = {w : i for i, w in enumerate(vocab)} 53 | 54 | emb_unk_i = vocab.index('') 55 | # assert '' == vocab[0], vocab[0] 56 | assert '' == i2w[0], i2w[0] 57 | 58 | ## Prepare training and testing data arrays 59 | def make_X(instance, ipad=0, hws=2): 60 | i_seq = chain(repeat(ipad, hws), instance.sentence, repeat(ipad, hws)) 61 | emb_i_seq = [emb_w2i.get(i2w[i], emb_unk_i) for i in i_seq] 62 | len_sen = len(instance.sentence) 63 | ws = 2 * hws + 1 64 | emb_i_X = [emb_i_seq[i : i + ws] for i in range(len_sen)] 65 | X = emb.take(emb_i_X, axis=0) # shape: (len, ws, emb_dim) 66 | X = X.reshape((len_sen, -1)) # shape: (len, ws * emb_dim) 67 | return X 68 | 69 | def make_y(instance, tag_type='POS'): 70 | return np.array(instance.tags[tag_type]) 71 | 72 | def make_X_y(instances): 73 | X = np.concatenate(list(make_X(ins) for ins in tqdm(instances))) 74 | y = np.concatenate(list(make_y(ins) for ins in instances)) 75 | logging.info(f"X.shape = {X.shape}") 76 | logging.info(f"y.shape = {y.shape}") 77 | return X, y 78 | 79 | logging.info("building training instances...") 80 | train_X, train_y = make_X_y(training_instances) 81 | logging.info("building test instances...") 82 | test_X, test_y = make_X_y(test_instances) 83 | 84 | ## Train a logistic regression classifier and report scores 85 | logging.info("training...") 86 | clsfr = LogisticRegression(random_state=args.random_seed, verbose=False, C=args.C) 87 | clsfr.fit(train_X, train_y) 88 | # print("Train acc: {}".format(clsfr.score(train_X, train_y))) 89 | print("Test acc: {}".format(clsfr.score( test_X, test_y))) 90 | -------------------------------------------------------------------------------- /datasets/target_vectors/glove/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import subprocess as sp 5 | import unicodedata 6 | import zipfile 7 | 8 | from utils import dotdict, file_tqdm 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | dir_path = os.path.dirname(os.path.realpath(__file__)) 13 | zip_path = f"{dir_path}/glove.840B.300d.zip" 14 | raw_emb_path = f"{dir_path}/glove.840B.300d.txt" 15 | txt_emb_path = f"{dir_path}/glove.840B.300d.processed.txt" 16 | w2v_emb_path = f"{dir_path}/glove.840B.300d.processed.w2v" 17 | word_freq_path = f"{dir_path}/word_freq.jsonl" 18 | raw_count_path = f"{dir_path}/word_freq.txt" 19 | 20 | emb_dim = 300 21 | 22 | 23 | def 
prepare_glove_paths( 24 | dir_path=dir_path, 25 | zip_path=zip_path, 26 | raw_emb_path=raw_emb_path, 27 | txt_emb_path=txt_emb_path, 28 | word_freq_path=word_freq_path, 29 | w2v_emb_path=w2v_emb_path, 30 | raw_count_path=raw_count_path, 31 | ): 32 | if not os.path.exists(zip_path): 33 | logger.info("downloading zip file...") 34 | url = "http://nlp.stanford.edu/data/glove.840B.300d.zip" 35 | sp.run(f"wget -O {zip_path} {url}".split()) 36 | 37 | if not os.path.exists(raw_emb_path): 38 | logger.info("unzipping...") 39 | with zipfile.ZipFile(zip_path, "r") as zip_ref: 40 | zip_ref.extractall(dir_path) 41 | 42 | if not os.path.exists(txt_emb_path): 43 | logger.info("generating txt emb file...") 44 | with open(raw_emb_path, "r") as fin, open(txt_emb_path, "w") as fout: 45 | vocab_len = 0 46 | for line in file_tqdm(fin): 47 | ss = line.split() 48 | if len(ss) != emb_dim + 1: 49 | logging.critical(f'line "{line[:30]}"... might include word with space, skipped') 50 | continue 51 | 52 | w = ss[0] 53 | 54 | # copied from `datasets/google/converter.py` 55 | aw = unicodedata.normalize("NFKD", w).encode("ASCII", "ignore") 56 | if 20 > len(aw) > 1 and not any(c in w for c in " _./") and aw.islower(): 57 | vocab_len += 1 58 | fout.write(line) 59 | 60 | if not os.path.exists(w2v_emb_path): 61 | logger.info("generating w2v emb file...") 62 | with open(txt_emb_path) as fin, open(w2v_emb_path, "w") as fout: 63 | print(vocab_len, emb_dim, file=fout) 64 | for line in file_tqdm(fin): 65 | fout.write(line) 66 | 67 | if not os.path.exists(word_freq_path): 68 | logger.info("generating word freq jsonl file...") 69 | with open(txt_emb_path) as fin, open(word_freq_path, "w") as fout: 70 | for line in fin: 71 | print(json.dumps((line.split()[0], 1)), file=fout) 72 | 73 | if not os.path.exists(raw_count_path): 74 | logger.info("generating word freq txt file...") 75 | with open(txt_emb_path) as fin, open(raw_count_path, "w") as fout: 76 | for line in fin: 77 | print(line.split()[0], 1, file=fout, sep='\t') 78 | 79 | return dotdict( 80 | dir_path=dir_path, 81 | raw_emb_path=raw_emb_path, 82 | txt_emb_path=txt_emb_path, 83 | w2v_emb_path=w2v_emb_path, 84 | word_freq_path=word_freq_path, 85 | raw_count_path=raw_count_path, 86 | ) 87 | 88 | 89 | if __name__ == '__main__': 90 | prepare_glove_paths() 91 | -------------------------------------------------------------------------------- /sasaki_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import subprocess as sp 4 | from pathlib import Path 5 | 6 | from utils import dotdict 7 | 8 | 9 | def train( 10 | emb_path, 11 | result_path, 12 | epoch, 13 | H, 14 | F, 15 | freq_path=None, 16 | codecs_path=None, 17 | ): 18 | """ 19 | :return: a model info object 20 | """ 21 | with open(emb_path) as f: 22 | _, embed_dim = f.readline().strip().split() 23 | 24 | cmd = f""" 25 | python compact_reconstruction/src/train.py 26 | --gpu 0 27 | --ref_vec_path {emb_path} 28 | --embed_dim {embed_dim} 29 | --maxlen 200 30 | --network_type 3 31 | --limit_size {F} 32 | --result_dir {result_path} 33 | --unique_false 34 | --epoch {epoch} 35 | --snapshot_interval 10 36 | --subword_type 4 37 | --multi_hash two 38 | --hashed_idx 39 | --bucket_size {H} 40 | """ 41 | 42 | if freq_path: 43 | cmd += f" --freq_path {freq_path} " 44 | 45 | if codecs_path: 46 | cmd += f" --codecs_path {codecs_path} " 47 | 48 | sp.call(cmd.split()) 49 | 50 | return get_info_from_result_path(result_path / "sep_kvq") 51 | 52 | 53 | def inference(model_info, 
query_path): 54 | """ 55 | :return: resulting embedding path 56 | """ 57 | result_path = Path(model_info["result_path"]) 58 | model_path = model_info["model_path"] 59 | codecs_path = model_info["codecs_path"] 60 | epoch = model_info["epoch"] 61 | 62 | cmd = f""" 63 | python compact_reconstruction/src/inference.py 64 | --gpu 0 65 | --model_path {model_path} 66 | --codecs_path {codecs_path} 67 | --oov_word_path {query_path} 68 | """ 69 | sp.call(cmd.split()) 70 | 71 | return result_path / f"inference_embedding_epoch{epoch}" / "embedding.txt" 72 | 73 | 74 | def prepare_codecs_path(ref_vec_path, result_path, n_min=3, n_max=30): 75 | """ 76 | See https://github.com/losyer/compact_reconstruction/tree/master/src/preprocess 77 | """ 78 | os.makedirs(result_path, exist_ok=True) 79 | unsorted_codecs_path = os.path.join(result_path, f"codecs-min{n_min}max{n_max}.unsorted") 80 | sorted_codecs_path = os.path.join(result_path, f"codecs-min{n_min}max{n_max}.sorted") 81 | 82 | if not os.path.exists(unsorted_codecs_path): 83 | from sasaki_codecs import main as make_codecs 84 | 85 | make_codecs(dotdict( 86 | ref_vec_path=ref_vec_path, 87 | output=unsorted_codecs_path, 88 | n_max=n_max, 89 | n_min=n_min, 90 | test=False, 91 | )) 92 | 93 | if not os.path.exists(sorted_codecs_path): 94 | with open(sorted_codecs_path, 'w') as fout: 95 | sp.run(f"sort -k 2,2 -n -r {unsorted_codecs_path}".split(), stdout=fout) 96 | 97 | return sorted_codecs_path 98 | 99 | 100 | def _get_latest_in_dir(dir_path): 101 | return max(dir_path.iterdir(), key=lambda x: x.stat().st_mtime) 102 | 103 | 104 | def get_info_from_result_path(result_path): 105 | result_path = _get_latest_in_dir(result_path) 106 | settings_path = result_path / "settings.json" 107 | data = json.load(open(settings_path, "r")) 108 | epoch = data['epoch'] 109 | data["model_path"] = result_path / f"model_epoch_{epoch}" 110 | data["result_path"] = result_path 111 | return data 112 | -------------------------------------------------------------------------------- /sasaki_codecs.py: -------------------------------------------------------------------------------- 1 | """ 2 | modified from https://github.com/losyer/compact_reconstruction/blob/master/src/preprocess/make_ngram_dic.py 3 | Usage: 4 | 5 | $ python make_ngram_dic.py \ 6 | --ref_vec_path [reference_vector_path (word2vec format)] \ 7 | --output codecs-min3max30 \ 8 | --n_max 30 \ 9 | --n_min 3 10 | 11 | $ sort -k 2,2 -n -r codecs-min3max30 > codecs-min3max30.sorted 12 | """ 13 | 14 | import json, sys, argparse, os, codecs 15 | from datetime import datetime 16 | from collections import defaultdict 17 | 18 | def get_total_line(path, test): 19 | if not(test): 20 | total_line = 0 21 | print('get # of lines', flush=True) 22 | with codecs.open(path, "r", 'utf-8', errors='replace') as input_data: 23 | for _ in input_data: 24 | total_line += 1 25 | print('done', flush=True) 26 | print('# of lines = {}'.format(total_line), flush=True) 27 | else: 28 | total_line = 1000 29 | print('# of lines = {}'.format(total_line), flush=True) 30 | 31 | return total_line 32 | 33 | 34 | def get_ngram(word, nmax, nmin): 35 | all_ngrams =[] 36 | word = '^' + word + '@' 37 | N = len(word) 38 | f =lambda x,n :[x[i:i+n] for i in range(len(x)-n+1)] 39 | for n in range(N): 40 | if n+1 < nmin or n+1 > nmax: 41 | continue 42 | ngram_list = f(word, n+1) 43 | all_ngrams += ngram_list 44 | return all_ngrams 45 | 46 | def main(args): 47 | 48 | total_line = get_total_line(path=args.ref_vec_path, test=args.test) 49 | 50 | print('create ngram frequency 
dictionary ...', flush=True) 51 | idx_freq_dic = defaultdict(int) 52 | with codecs.open(args.ref_vec_path, "r", 'utf-8', errors='replace') as input_data: 53 | for i, line in enumerate(input_data): 54 | 55 | if i % int(total_line/10) == 0: 56 | print('{} % done'.format(round(i / (total_line/100))), flush=True) 57 | 58 | if i == 0: 59 | col = line.strip('\n').split() 60 | vocab_size, dim = int(col[0]), int(col[1]) 61 | else: 62 | col = line.strip(' \n').rsplit(' ', dim) 63 | assert len(col) == dim+1 64 | 65 | word = col[0] 66 | # if ' ' in word: 67 | # from IPython.core.debugger import Pdb; Pdb().set_trace() 68 | if len(word) > 30: 69 | continue 70 | ngrams = get_ngram(word, args.n_max, args.n_min) 71 | 72 | for ngram in ngrams: 73 | idx_freq_dic[ngram] += 1 74 | 75 | if args.test and i > 1000: 76 | break 77 | 78 | print('create ngram frequency dictionary ... done', flush=True) 79 | 80 | # save 81 | print('save ... ', flush=True) 82 | fo = open(args.output, 'w') 83 | for ngram, freq in idx_freq_dic.items(): 84 | fo.write('{} {}\n'.format(ngram, freq)) 85 | fo.close() 86 | print('save ... done', flush=True) 87 | 88 | 89 | if __name__ == '__main__': 90 | parser=argparse.ArgumentParser() 91 | 92 | parser.add_argument('--test', action='store_true', help='use tiny dataset') 93 | parser.add_argument('--n_max', type=int, default=30, help='') 94 | parser.add_argument('--n_min', type=int, default=3, help='') 95 | 96 | # data path 97 | parser.add_argument('--ref_vec_path', type=str, default="") 98 | parser.add_argument('--output', type=str, default="") 99 | 100 | args = parser.parse_args() 101 | main(args) 102 | 103 | -------------------------------------------------------------------------------- /datasets/target_vectors/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import pickle 5 | import unicodedata 6 | 7 | import numpy as np 8 | 9 | from load import load_embedding 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def convert_target_dataset( 15 | input_emb_path, 16 | *, 17 | w2v_emb_path=None, 18 | txt_emb_path=None, 19 | pkl_emb_path=None, 20 | word_list_path=None, 21 | word_freq_path=None, 22 | raw_count_path=None, 23 | vocab_size_limit=None, 24 | ): 25 | if all(path is None or os.path.exists(path) for path in 26 | (w2v_emb_path, txt_emb_path, pkl_emb_path, word_list_path, word_freq_path, raw_count_path)): 27 | return 28 | 29 | vocab, emb = load_embedding(input_emb_path) 30 | 31 | vocab = vocab[:vocab_size_limit] 32 | emb = emb[:vocab_size_limit] 33 | 34 | return save_target_dataset( 35 | vocab, 36 | emb, 37 | w2v_emb_path=w2v_emb_path, 38 | txt_emb_path=txt_emb_path, 39 | pkl_emb_path=pkl_emb_path, 40 | word_list_path=word_list_path, 41 | word_freq_path=word_freq_path, 42 | raw_count_path=raw_count_path, 43 | ) 44 | 45 | 46 | def save_target_dataset( 47 | vocab, 48 | emb, 49 | *, 50 | w2v_emb_path=None, 51 | txt_emb_path=None, 52 | pkl_emb_path=None, 53 | word_list_path=None, 54 | word_freq_path=None, 55 | raw_count_path=None, 56 | ): 57 | if w2v_emb_path and not os.path.exists(w2v_emb_path): 58 | logger.info("generating w2v emb file...") 59 | with open(w2v_emb_path, "w") as fout: 60 | print(len(vocab), len(emb[0]), file=fout) 61 | for v, e in zip(vocab, emb): 62 | print(v, *e, file=fout) 63 | 64 | if txt_emb_path and not os.path.exists(txt_emb_path): 65 | logger.info("generating txt emb file...") 66 | with open(txt_emb_path, "w") as fout: 67 | for v, e in zip(vocab, emb): 68 | print(v, 
*e, file=fout) 69 | 70 | if pkl_emb_path and not os.path.exists(pkl_emb_path): 71 | logger.info("generating pkl emb file...") 72 | emb = emb if isinstance(emb, np.ndarray) else np.array(emb) 73 | with open(pkl_emb_path, "bw") as fout: 74 | pickle.dump((vocab, emb), fout) 75 | 76 | if word_list_path and not os.path.exists(word_list_path): 77 | logger.info("generating word list file...") 78 | with open(word_list_path, "w") as fout: 79 | for word in vocab: 80 | print(word, file=fout) 81 | 82 | if word_freq_path and not os.path.exists(word_freq_path): 83 | logger.info("generating word freq jsonl file...") 84 | with open(word_freq_path, "w") as fout: 85 | for word in vocab: 86 | print(json.dumps((word, 1)), file=fout) 87 | 88 | if raw_count_path and not os.path.exists(raw_count_path): 89 | logger.info("generating word freq txt file...") 90 | with open(raw_count_path, "w") as fout: 91 | for word in vocab: 92 | print(word, 1, file=fout, sep='\t') 93 | 94 | 95 | def _is_word(w): 96 | aw = unicodedata.normalize("NFKD", w).encode("ASCII", "ignore") 97 | return 20 > len(aw) > 1 and not any(c in w for c in " _./") and aw.islower() 98 | 99 | 100 | def clean_target_emb(raw_vocab, raw_emb): 101 | logger.info("normalizing...") 102 | 103 | vocab, emb = [], [] 104 | for w, e in zip(raw_vocab, raw_emb): 105 | if _is_word(w): 106 | vocab.append(w) 107 | emb.append(e) 108 | return vocab, emb 109 | -------------------------------------------------------------------------------- /ws_exp_pbos.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import logging 3 | import multiprocessing as mp 4 | import os 5 | from collections import ChainMap 6 | 7 | import pbos_train 8 | import subwords 9 | from datasets import prepare_ws_combined_query_path, prepare_target_vector_paths, prepare_unigram_freq_paths, \ 10 | prepare_ws_dataset_paths, get_ws_dataset_names 11 | from pbos_pred import predict 12 | from utils import dotdict 13 | from utils.args import dump_args 14 | from ws_eval import eval_ws 15 | 16 | 17 | def train(args): 18 | subwords.build_subword_vocab_cli(dotdict(ChainMap( 19 | dict(word_freq=args.subword_vocab_word_freq, output=args.subword_vocab), args, 20 | ))) 21 | 22 | if args.subword_prob: 23 | subwords.build_subword_prob_cli(dotdict(ChainMap( 24 | dict(word_freq=args.subword_prob_word_freq, output=args.subword_prob), args, 25 | ))) 26 | 27 | pbos_train.main(args) 28 | 29 | 30 | def evaluate(args): 31 | with open(args.eval_result_path, "w") as fout: 32 | for bname in get_ws_dataset_names(): 33 | bench_path = prepare_ws_dataset_paths(bname).txt_path 34 | for lower in (True, False): 35 | print(eval_ws(args.pred_path, bench_path, lower=lower, oov_handling='zero'), file=fout) 36 | 37 | 38 | def exp(model_type, target_vector_name): 39 | target_vector_paths = prepare_target_vector_paths(target_vector_name) 40 | args = dotdict() 41 | 42 | # misc 43 | args.results_dir = f"results/ws/{target_vector_name}_{model_type}" 44 | args.model_type = model_type 45 | args.log_level = "INFO" 46 | 47 | # subword 48 | if model_type == "bos": 49 | args.word_boundary = True 50 | elif model_type in ('pbos', 'pbosn'): 51 | args.word_boundary = False 52 | args.subword_min_count = None 53 | args.subword_uniq_factor = None 54 | if model_type == 'bos': 55 | args.subword_min_len = 3 56 | args.subword_max_len = 6 57 | elif model_type in ('pbos', 'pbosn'): 58 | args.subword_min_len = 1 59 | args.subword_max_len = None 60 | 61 | # subword vocab 62 | args.subword_vocab_max_size = None 
63 | args.subword_vocab_word_freq = target_vector_paths.word_freq_path 64 | args.subword_vocab = f"{args.results_dir}/subword_vocab.jsonl" 65 | 66 | # subword prob 67 | args.subword_prob_take_root = False 68 | if model_type == 'bos': 69 | args.subword_prob = None 70 | elif model_type in ('pbos', 'pbosn'): 71 | args.subword_prob_min_prob = 0 72 | args.subword_prob_word_freq = prepare_unigram_freq_paths().word_freq_path 73 | args.subword_prob = f"{args.results_dir}/subword_prob.jsonl" 74 | 75 | # training 76 | args.target_vectors = target_vector_paths.pkl_emb_path 77 | args.model_path = f"{args.results_dir}/model.pkl" 78 | args.epochs = 50 79 | args.lr = 1 80 | args.lr_decay = True 81 | args.random_seed = 42 82 | args.subword_prob_eps = 0.01 83 | args.subword_weight_threshold = None 84 | if args.model_type == 'pbosn': 85 | args.normalize_semb = True 86 | else: 87 | args.normalize_semb = False 88 | 89 | # prediction & evaluation 90 | args.pred_path = f"{args.results_dir}/vectors.txt" 91 | args.query_path = prepare_ws_combined_query_path() 92 | args.eval_result_path = f"{args.results_dir}/result.txt" 93 | os.makedirs(args.results_dir, exist_ok=True) 94 | 95 | # redirect log output 96 | log_file = open(f"{args.results_dir}/info.log", "w+") 97 | logging.basicConfig(level=logging.INFO, stream=log_file) 98 | dump_args(args) 99 | 100 | with contextlib.redirect_stdout(log_file), contextlib.redirect_stderr(log_file): 101 | train(args) 102 | 103 | # prediction 104 | time_used = predict( 105 | model=args.model_path, 106 | queries=args.query_path, 107 | save=args.pred_path, 108 | word_boundary=args.word_boundary, 109 | ) 110 | print(f"time used: {time_used:.3f}") 111 | 112 | # evaluate 113 | evaluate(args) 114 | 115 | 116 | if __name__ == '__main__': 117 | model_types = ("pbos", "bos") 118 | target_vector_names = ("google", "polyglot") 119 | 120 | for target_vector_name in target_vector_names: # avoid race condition 121 | prepare_target_vector_paths(target_vector_name) 122 | 123 | with mp.Pool() as pool: 124 | results = [ 125 | pool.apply_async(exp, (model_type, target_vector_name)) 126 | for model_type in model_types 127 | for target_vector_name in target_vector_names 128 | ] 129 | 130 | for r in results: 131 | r.get() -------------------------------------------------------------------------------- /ws_eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | """Evaluate word similarity. 5 | Adapted from: `https://github.com/facebookresearch/fastText/blob/316b4c9f499669f0cacc989c32bf2cef23a8f9ac/eval.py`. 
6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | from __future__ import unicode_literals 12 | 13 | import logging 14 | import math 15 | import os 16 | 17 | import numpy as np 18 | from scipy import stats 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | def compat_splitting(line): 24 | return line.decode('utf8').split() 25 | 26 | 27 | def similarity(v1, v2): 28 | n1 = np.linalg.norm(v1) 29 | n2 = np.linalg.norm(v2) 30 | return np.dot(v1, v2) / n1 / n2 31 | 32 | def edit_distance(s1, s2): 33 | if len(s1) > len(s2): 34 | s1, s2 = s2, s1 35 | distances = range(len(s1) + 1) 36 | for i2, c2 in enumerate(s2): 37 | distances_ = [i2+1] 38 | for i1, c1 in enumerate(s1): 39 | if c1 == c2: 40 | distances_.append(distances[i1]) 41 | else: 42 | distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) 43 | distances = distances_ 44 | return distances[-1] 45 | 46 | def editsim(w1, w2): 47 | return -edit_distance(w1, w2) / max(len(w1), len(w2)) 48 | 49 | 50 | def load_vectors(modelPath): 51 | vectors = {} 52 | fin = open(modelPath, 'rb') 53 | for _, line in enumerate(fin): 54 | try: 55 | tab = compat_splitting(line) 56 | vec = np.array([float(x) for x in tab[1:]], dtype=float) 57 | word = tab[0] 58 | if np.linalg.norm(vec) < 1e-6: 59 | continue 60 | if word not in vectors: 61 | vectors[word] = vec 62 | except ValueError: 63 | continue 64 | except UnicodeDecodeError: 65 | continue 66 | fin.close() 67 | return vectors 68 | 69 | 70 | def eval_ws(modelPath, dataPath, lower, oov_handling="drop"): 71 | mysim = [] 72 | gold = [] 73 | words = [] 74 | drop = 0.0 75 | nwords = 0.0 76 | 77 | if modelPath != "EditSim": 78 | vectors = load_vectors(modelPath) 79 | 80 | fin = open(dataPath, 'rb') 81 | for line in fin: 82 | tline = compat_splitting(line) 83 | word1 = tline[0] 84 | word2 = tline[1] 85 | golden_score = float(tline[2]) 86 | 87 | if lower: 88 | word1, word2 = word1.lower(), word2.lower() 89 | nwords = nwords + 1.0 90 | 91 | words.append((word1, word2)) 92 | 93 | if modelPath == "EditSim": 94 | d = editsim(word1, word2) 95 | else: 96 | if (word1 in vectors) and (word2 in vectors): 97 | v1 = vectors[word1] 98 | v2 = vectors[word2] 99 | d = similarity(v1, v2) 100 | else: 101 | drop = drop + 1.0 102 | if oov_handling == "zero": 103 | d = 0 104 | else: 105 | continue 106 | 107 | mysim.append(d) 108 | gold.append(golden_score) 109 | fin.close() 110 | 111 | corr = stats.spearmanr(mysim, gold) 112 | dataset = os.path.basename(dataPath) 113 | 114 | logger.info(f"eval info for: {dataset}") 115 | for _, g, m, (w1, w2) in sorted(zip(stats.zscore(mysim) - stats.zscore(gold), gold, mysim, words)): 116 | logger.info(f"{g:.2f} {m: .2f} {w1} {w2}") 117 | 118 | return "{:15s}: {:2.0f} (OOV: {:2.0f}%, {}, l={})".format( 119 | dataset, 120 | corr[0] * 100, 121 | math.ceil(drop / nwords * 100.0), 122 | oov_handling[0], 123 | "T" if lower else "F" 124 | ) 125 | 126 | 127 | if __name__ == "__main__": 128 | import argparse 129 | 130 | parser = argparse.ArgumentParser(description='Evaluate word similarity.') 131 | parser.add_argument( 132 | '--model', 133 | '-m', 134 | dest='modelPath', 135 | action='store', 136 | required=True, 137 | help='path to model' 138 | ) 139 | parser.add_argument( 140 | '--data', 141 | '-d', 142 | dest='dataPath', 143 | action='store', 144 | required=True, 145 | help='path to data' 146 | ) 147 | parser.add_argument('--lower', action='store_true', default=True) 148 | 
parser.add_argument('--no_lower', dest='lower', action='store_false') 149 | parser.add_argument('--oov_handling', default='drop', choices=['drop', 'zero']) 150 | args = parser.parse_args() 151 | 152 | print(eval_ws(args.modelPath, args.dataPath, args.lower, args.oov_handling)) 153 | -------------------------------------------------------------------------------- /datasets/word_similarity/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess as sp 3 | 4 | from utils import dotdict 5 | 6 | datasets_dir = os.path.dirname(os.path.realpath(__file__)) 7 | 8 | _ws_datasets = { 9 | # default word similarity datasets for English 10 | None: { 11 | "wordsim353": { 12 | "url": "https://leviants.com/wp-content/uploads/2020/01/wordsim353.zip", 13 | "raw_txt_rel_path": "combined.tab", 14 | "skip_lines": 1, 15 | }, 16 | "rw": { 17 | "url": "https://nlp.stanford.edu/~lmthang/morphoNLM/rw.zip", 18 | "raw_txt_rel_path": "rw/rw.txt", 19 | }, 20 | "card660": { 21 | "url": "https://pilehvar.github.io/card-660/dataset.tsv", 22 | "no_zip": True, 23 | "raw_txt_rel_path": "dataset.tsv", 24 | }, 25 | }, 26 | # multilingual word similarity datasets 27 | "en": {}, 28 | "it": {}, 29 | "ru": {}, 30 | "de": {}, 31 | } 32 | 33 | for lang, full_name in [('en', 'english'), ('it', 'italian'), ('ru', 'russian'), ('de', 'german')]: 34 | for ws_suffix in ("-rel", "-sim", ""): 35 | _ws_datasets[lang][f"ws353-{lang}{ws_suffix}"] = { 36 | "url": f"https://raw.githubusercontent.com/iraleviant/eval-multilingual-simlex/master/evaluation/ws-353/wordsim353-{full_name}{ws_suffix}.txt", 37 | "raw_txt_rel_path": f"wordsim353-{full_name}{ws_suffix}.txt", 38 | "no_zip": True, 39 | "skip_lines": 1, 40 | } 41 | 42 | _ws_datasets[lang][f"simlex999-{lang}"] = { 43 | "url": f"https://raw.githubusercontent.com/nmrksic/eval-multilingual-simlex/master/evaluation/simlex-{full_name}.txt", 44 | "raw_txt_rel_path": f"simlex-{full_name}.txt", 45 | "no_zip": True, 46 | "skip_lines": 1, 47 | } 48 | 49 | 50 | def get_ws_dataset_names(lang=None): 51 | return list(_ws_datasets[lang]) 52 | 53 | 54 | def _get_ws_dataset_info(name): 55 | for datasets in _ws_datasets.values(): 56 | if name in datasets: 57 | return datasets[name] 58 | 59 | raise NotImplementedError 60 | 61 | 62 | def prepare_ws_dataset_paths(name): 63 | binfo = _get_ws_dataset_info(name) 64 | 65 | raw_txt_path = f"{datasets_dir}/{name}/{binfo['raw_txt_rel_path']}" 66 | txt_path = f"{datasets_dir}/{name}/{name}.txt" 67 | query_path = f"{datasets_dir}/{name}/queries.txt" 68 | 69 | if not os.path.exists(raw_txt_path): 70 | sp.call( 71 | f""" 72 | wget -c {binfo['url']} -P {datasets_dir}/{name} 73 | """.split() 74 | ) 75 | if not binfo.get("no_zip"): 76 | sp.call( 77 | f""" 78 | unzip {datasets_dir}/{name}/{name}.zip -d {datasets_dir}/{name} 79 | """.split() 80 | ) 81 | 82 | if not os.path.exists(txt_path): 83 | with open(raw_txt_path) as f, open(txt_path, "w") as fout: 84 | for i, line in enumerate(f): 85 | # discard head lines 86 | if i < binfo.get("skip_lines", 0): 87 | continue 88 | # NOTE: in `fastText/eval.py`, golden words get lowercased anyways, 89 | # but predicted words remain as they are. 
90 | print(line, end="", file=fout) 91 | 92 | if not os.path.exists(query_path): 93 | words = set() 94 | with open(txt_path) as f: 95 | for line in f: 96 | w1, w2 = line.split()[:2] 97 | words.add(w1) 98 | words.add(w2) 99 | with open(query_path, "w") as fout: 100 | for w in words: 101 | print(w, file=fout) 102 | 103 | return dotdict( 104 | txt_path=txt_path, 105 | query_path=query_path, 106 | ) 107 | 108 | 109 | def prepare_ws_combined_query_path(lang=None): 110 | """ 111 | Prepare the combined query path for word similarity datasets dataset 112 | """ 113 | 114 | combined_query_path = f"{datasets_dir}/combined_query_{lang}.txt" 115 | 116 | if not os.path.exists(combined_query_path): 117 | all_words = set() 118 | for bname in get_ws_dataset_names(lang): 119 | query_path = prepare_ws_dataset_paths(bname).query_path 120 | with open(query_path) as fin: 121 | for line in fin: 122 | all_words.add(line.strip()) 123 | all_words.add(line.strip().lower()) 124 | with open(combined_query_path, 'w') as fout: 125 | for w in all_words: 126 | print(w, file=fout) 127 | 128 | return combined_query_path 129 | 130 | 131 | if __name__ == '__main__': 132 | for lang in [None, "en", "it", "ru", "de"]: 133 | for bname in get_ws_dataset_names(lang): 134 | prepare_ws_dataset_paths(bname) 135 | prepare_ws_combined_query_path(lang) 136 | -------------------------------------------------------------------------------- /ws_multilingual_exp_pbos.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import logging 3 | import multiprocessing as mp 4 | import os 5 | from collections import ChainMap 6 | 7 | import pbos_train 8 | import subwords 9 | from datasets import (prepare_target_vector_paths, prepare_polyglot_freq_paths, prepare_ws_dataset_paths, 10 | get_ws_dataset_names, prepare_ws_combined_query_path) 11 | from pbos_pred import predict 12 | from utils import dotdict 13 | from utils.args import dump_args 14 | from ws_eval import eval_ws 15 | 16 | 17 | def train(args): 18 | subwords.build_subword_vocab_cli(dotdict(ChainMap( 19 | dict(word_freq=args.subword_vocab_word_freq, output=args.subword_vocab), args, 20 | ))) 21 | 22 | if args.subword_prob: 23 | subwords.build_subword_prob_cli(dotdict(ChainMap( 24 | dict(word_freq=args.subword_prob_word_freq, output=args.subword_prob), args, 25 | ))) 26 | 27 | pbos_train.main(args) 28 | 29 | 30 | def evaluate(args): 31 | with open(args.eval_result_path, "w") as fout: 32 | for bname in get_ws_dataset_names(args.target_vector_name): 33 | bench_path = prepare_ws_dataset_paths(bname).txt_path 34 | for lower in (True, False): 35 | print(args.model_type.ljust(10), eval_ws(args.pred_path, bench_path, lower=lower, oov_handling='zero'), 36 | file=fout) 37 | 38 | 39 | def exp(model_type, target_vector_name): 40 | target_vector_paths = prepare_target_vector_paths(f"wiki2vec-{target_vector_name}") 41 | args = dotdict() 42 | 43 | # misc 44 | args.results_dir = f"results/ws_multi/{target_vector_name}_{model_type}" 45 | args.model_type = model_type 46 | args.log_level = "INFO" 47 | args.target_vector_name = target_vector_name 48 | 49 | # subword 50 | if model_type == "bos": 51 | args.word_boundary = True 52 | elif model_type in ('pbos', 'pbosn'): 53 | args.word_boundary = False 54 | args.subword_min_count = None 55 | args.subword_uniq_factor = None 56 | if model_type == 'bos': 57 | args.subword_min_len = 3 58 | args.subword_max_len = 6 59 | elif model_type in ('pbos', 'pbosn'): 60 | args.subword_min_len = 1 61 | args.subword_max_len = 
None 62 | 63 | # subword vocab 64 | args.subword_vocab_max_size = None 65 | args.subword_vocab_word_freq = target_vector_paths.word_freq_path 66 | args.subword_vocab = f"{args.results_dir}/subword_vocab.jsonl" 67 | 68 | # subword prob 69 | args.subword_prob_take_root = False 70 | if model_type == 'bos': 71 | args.subword_prob = None 72 | elif model_type in ('pbos', 'pbosn'): 73 | args.subword_prob_min_prob = 0 74 | args.subword_prob_word_freq = prepare_polyglot_freq_paths(target_vector_name).word_freq_path 75 | args.subword_prob = f"{args.results_dir}/subword_prob.jsonl" 76 | 77 | # training 78 | args.target_vectors = target_vector_paths.pkl_emb_path 79 | args.model_path = f"{args.results_dir}/model.pkl" 80 | args.epochs = 50 81 | args.lr = 1 82 | args.lr_decay = True 83 | args.random_seed = 42 84 | args.subword_prob_eps = 0.01 85 | args.subword_weight_threshold = None 86 | if args.model_type == 'pbosn': 87 | args.normalize_semb = True 88 | else: 89 | args.normalize_semb = False 90 | 91 | # prediction & evaluation 92 | args.eval_result_path = f"{args.results_dir}/result.txt" 93 | args.pred_path = f"{args.results_dir}/vectors.txt" 94 | os.makedirs(args.results_dir, exist_ok=True) 95 | 96 | # redirect log output 97 | log_file = open(f"{args.results_dir}/info.log", "w+") 98 | logging.basicConfig(level=logging.INFO, stream=log_file) 99 | dump_args(args) 100 | 101 | with contextlib.redirect_stdout(log_file), contextlib.redirect_stderr(log_file): 102 | train(args) 103 | 104 | combined_query_path = prepare_ws_combined_query_path(args.target_vector_name) 105 | 106 | predict( 107 | model=args.model_path, 108 | queries=combined_query_path, 109 | save=args.pred_path, 110 | word_boundary=args.word_boundary, 111 | ) 112 | 113 | evaluate(args) 114 | 115 | 116 | if __name__ == '__main__': 117 | model_types = ("bos", "pbos") 118 | target_vector_names = ("en", "de", "it", "ru", ) 119 | 120 | for target_vector_name in target_vector_names: # avoid race condition 121 | prepare_target_vector_paths(f"wiki2vec-{target_vector_name}") 122 | prepare_polyglot_freq_paths(target_vector_name) 123 | 124 | with mp.Pool() as pool: 125 | results = [ 126 | pool.apply_async(exp, (model_type, target_vector_name)) 127 | for target_vector_name in target_vector_names 128 | for model_type in model_types 129 | ] 130 | 131 | for r in results: 132 | r.get() 133 | -------------------------------------------------------------------------------- /pbos_segment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import math 4 | from importlib import import_module 5 | from itertools import islice 6 | 7 | from datasets import prepare_target_vector_paths 8 | from nshortest import nshortest 9 | from pbos import * 10 | from subwords import ( 11 | add_subword_args, 12 | add_subword_prob_args, 13 | add_subword_vocab_args, 14 | build_subword_counter, 15 | build_subword_prob, 16 | ) 17 | from utils import file_tqdm, normalize_prob 18 | from utils.args import add_logging_args, set_logging_config, dump_args 19 | 20 | parser = argparse.ArgumentParser("PboS segmenter and subword weigher.", 21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 22 | parser.add_argument('--prob_word_freq', default="unigram_freq", 23 | choices=["unigram_freq"], 24 | help="list of words to create subword prob") 25 | parser.add_argument('--vocab_word_freq', 26 | choices=("google", "polyglot", "glove"), 27 | help="list of words to create subword vocab") 28 | parser.add_argument('--n_largest', 
'-n', type=int, default=20, 29 | help="the number of segmentations to show") 30 | parser.add_argument('--subword_prob_eps', '-spe', type=float, default=1e-2, 31 | help="the infinitesimal prob for unseen subwords") 32 | parser.add_argument('--subword_weight_threshold', '-swt', type=float, 33 | help="the minimum weight of a subword to be considered") 34 | parser.add_argument('--interactive', '-i', action='store_true', 35 | help="interactive mode") 36 | parser.add_argument('--latex', action='store_true', 37 | help="output latex") 38 | add_subword_args(parser) 39 | add_subword_prob_args(parser) 40 | add_subword_vocab_args(parser) 41 | add_logging_args(parser) 42 | args = parser.parse_args() 43 | 44 | set_logging_config(args) 45 | dump_args(args) 46 | 47 | logger.info(f"building subword prob from `{args.prob_word_freq}`...") 48 | if args.prob_word_freq.lower().startswith("unigram_freq"): 49 | word_freq_path = import_module("datasets.word_freq.unigram")\ 50 | .prepare_unigram_freq_paths().word_freq_path  # the unigram frequency preparer lives under datasets/word_freq/unigram 51 | else: 52 | raise ValueError(f"args.prob_word_freq=`{args.prob_word_freq}` not supported.") 53 | with open(word_freq_path) as fin: 54 | word_count_iter = (json.loads(line) for line in file_tqdm(fin)) 55 | subword_counter = build_subword_counter( 56 | word_count_iter, 57 | min_count=args.subword_min_count, 58 | min_len=args.subword_min_len, 59 | max_len=args.subword_max_len, 60 | word_boundary=args.word_boundary, 61 | uniq_factor=args.subword_uniq_factor, 62 | ) 63 | subword_prob = build_subword_prob( 64 | subword_counter, 65 | normalize_prob=normalize_prob, 66 | min_prob=args.subword_prob_min_prob, 67 | take_root=args.subword_prob_take_root, 68 | ) 69 | logger.info(f"subword prob size: {len(subword_prob)}") 70 | 71 | logger.info(f"building subword vocab from `{args.vocab_word_freq}`...") 72 | if args.vocab_word_freq is None: 73 | subword_vocab = set(subword_prob) 74 | else: 75 | word_freq_path = prepare_target_vector_paths(args.vocab_word_freq).word_freq_path 76 | with open(word_freq_path) as fin: 77 | word_count_iter = (json.loads(line) for line in file_tqdm(fin)) 78 | subword_counter = build_subword_counter( 79 | word_count_iter, 80 | max_size=args.subword_vocab_max_size, 81 | min_count=args.subword_min_count, 82 | min_len=args.subword_min_len, 83 | max_len=args.subword_max_len, 84 | word_boundary=args.word_boundary, 85 | uniq_factor=args.subword_uniq_factor, 86 | ) 87 | subword_vocab = set(subword_counter) 88 | subword_vocab -= set('<>') 89 | logger.info(f"subword vocab size: {len(subword_vocab)}") 90 | 91 | 92 | 93 | test_words = [ 94 | "farmland", 95 | "higher", 96 | "penpineapplepie", 97 | "paradichlorobenzene", 98 | "bisimulation", 99 | ] 100 | 101 | get_subword_prob=partial( 102 | get_subword_prob, 103 | subword_prob=subword_prob, 104 | take_root=args.subword_prob_take_root, 105 | eps=args.subword_prob_eps, 106 | ) 107 | 108 | 109 | def word_segs(w): 110 | if args.word_boundary: 111 | w = '<' + w + '>' 112 | 113 | p_prefix = calc_prefix_prob(w, get_subword_prob) 114 | p_suffix = calc_prefix_prob(w, get_subword_prob, backward=True) 115 | 116 | adjmat = [[None for __ in range(len(w) + 1)] for _ in range(len(w) + 1)] 117 | for i in range(len(w)): 118 | for j in range(i + 1, len(w) + 1): 119 | adjmat[i][j] = - math.log(max(1e-100, get_subword_prob(w[i:j]))) 120 | segs = nshortest(adjmat, args.n_largest) 121 | 122 | seg_score_dict = { 123 | '/'.join(w[i:j] for i, j in zip(seg, seg[1:])): math.exp(-score) / p_prefix[-1] 124 | for score, seg in segs 125 | } 126 | 127 | subword_weights = 
calc_subword_weights( 128 | w, 129 | subword_vocab=subword_vocab, 130 | get_subword_prob=get_subword_prob, 131 | weight_threshold=args.subword_weight_threshold, 132 | ) 133 | 134 | sub_weight_dict = { 135 | sub : weight 136 | for sub, weight in islice(sorted(subword_weights.items(), key=lambda t: t[1], reverse=True), args.n_largest) 137 | } 138 | 139 | return p_prefix, p_suffix, seg_score_dict, sub_weight_dict 140 | 141 | 142 | def test_word(w): 143 | p_prefix, p_suffix, seg_score_dict, sub_weight_dict = word_segs(w) 144 | 145 | if args.latex: 146 | top_seg_str = ", ".join(f"{seg} ({score:.3f})" for seg, score in seg_score_dict.items()) 147 | sub_weight_str = ", ".join(f"{sub} ({weight:.3f})" for sub, weight in sub_weight_dict.items()) 148 | print(f"{w} \n& {top_seg_str} \n& {sub_weight_str} \n\\\\\n\n".translate( 149 | str.maketrans({ 150 | "<": r"{\textless}", 151 | ">": r"{\textgreater}", 152 | }) 153 | ) 154 | ) 155 | 156 | else: 157 | 158 | print("Word:", w) 159 | 160 | logging.info("p_prefix: " + '\t'.join(f"{x:.5e}" for x in p_prefix)) 161 | logging.info("p_suffix: " + '\t'.join(f"{x:.5e}" for x in p_suffix)) 162 | 163 | print("top segmentations:") 164 | for seg, score in seg_score_dict.items(): 165 | print("{:.5e} : {}".format(score, seg)) 166 | 167 | print("top subword weights:") 168 | for sub, weight in sub_weight_dict.items(): 169 | print("{:.5e} : {}".format(weight, sub)) 170 | 171 | 172 | for w in test_words: 173 | test_word(w) 174 | 175 | if args.interactive: 176 | while True: 177 | w = input().strip() 178 | if not w: 179 | break 180 | test_word(w) 181 | -------------------------------------------------------------------------------- /pbos_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import logging 5 | import os 6 | from time import time 7 | 8 | import numpy as np 9 | 10 | from load import load_embedding 11 | from pbos import PBoS 12 | from subwords import ( 13 | add_subword_prob_args, 14 | add_word_args, 15 | bound_word, 16 | subword_prob_post_process, 17 | ) 18 | from utils import file_tqdm 19 | from utils.args import add_logging_args, set_logging_config, dump_args 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser(description='PBoS trainer', 25 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 26 | add_args(parser) 27 | return parser.parse_args() 28 | 29 | def add_args(parser): 30 | parser.add_argument('--target_vectors', required=True, 31 | help='pretrained target word vectors') 32 | parser.add_argument('--model_path', required=True, 33 | default="./results/run_{timestamp}/model.pbos", 34 | help='save path') 35 | add_training_args(parser) 36 | add_model_args(parser) 37 | add_word_args(parser) 38 | add_subword_prob_args(parser) 39 | add_logging_args(parser) 40 | return parser 41 | 42 | def add_training_args(parser): 43 | group = parser.add_argument_group('training hyperparameters') 44 | group.add_argument('--epochs', type=int, default=20, 45 | help='number of training epochs') 46 | group.add_argument('--lr', type=float, default=1.0, 47 | help='learning rate') 48 | group.add_argument('--random_seed', type=int, default=42, 49 | help='random seed used in training') 50 | group.add_argument('--lr_decay', action='store_true', default=True, 51 | help='reduce learning learning rate between epochs') 52 | group.add_argument('--no_lr_decay', dest='lr_decay', action='store_false') 53 | return group 54 | 55 | def 
add_model_args(parser): 56 | group = parser.add_argument_group('PBoS model arguments') 57 | group.add_argument('--subword_vocab', required=True, 58 | help="list of subwords to maintain subword embeddings") 59 | group.add_argument('--subword_prob', 60 | help="dict of subwords and their likelihood of presence. " 61 | "If not specified, assume uniform likelihood, aka fall back to BoS.") 62 | group.add_argument('--subword_weight_threshold', type=float, 63 | help="minimum weight of a subword within a word for it to contribute " 64 | "to the word embedding") 65 | group.add_argument('--subword_prob_eps', type=float, default=1e-2, 66 | help="default likelihood of a subword if it is not present in " 67 | "the given `subword_prob`") 68 | group.add_argument( 69 | '--normalize_semb', 70 | action='store_true', default=False, 71 | help='if set, normalize subword embeddings during training' 72 | ) 73 | return group 74 | 75 | 76 | def main(args): 77 | set_logging_config(args) 78 | 79 | save_path = args.model_path.format( 80 | timestamp=datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')) 81 | save_dir, _ = os.path.split(save_path) 82 | try : 83 | os.makedirs(save_dir) 84 | except FileExistsError : 85 | logger.warning( 86 | "Things will get overwritten for directory {}".format(save_dir)) 87 | 88 | dump_args(args, logger, os.path.join(save_dir, 'args.json')) 89 | 90 | logger.info(f'loading target vectors from `{args.target_vectors}`...') 91 | target_words, target_embs = \ 92 | load_embedding(args.target_vectors, show_progress=True) 93 | logger.info(f'embeddings loaded with {len(target_words)} words') 94 | 95 | logger.info(f"loading subword vocab from `{args.subword_vocab}`...") 96 | with open(args.subword_vocab) as fin: 97 | subword_vocab = dict(json.loads(line) for line in file_tqdm(fin)) 98 | logger.info(f"subword vocab size: {len(subword_vocab)}") 99 | 100 | if args.subword_prob: 101 | logger.info(f"loading subword prob from `{args.subword_prob}`...") 102 | with open(args.subword_prob) as fin: 103 | subword_prob = dict(json.loads(line) for line in file_tqdm(fin)) 104 | subword_prob = subword_prob_post_process( 105 | subword_prob, 106 | min_prob=args.subword_prob_min_prob, 107 | # take_root=args.subword_prob_take_root, 108 | ) 109 | else: 110 | subword_prob = None 111 | 112 | np.random.seed(args.random_seed) 113 | 114 | def MSE(pred, target) : 115 | return sum((pred - target) ** 2) / 2 116 | def MSE_backward(pred, target) : 117 | return (pred - target) 118 | 119 | model = PBoS( 120 | embedding_dim=len(target_embs[0]), 121 | subword_vocab=subword_vocab, 122 | subword_prob=subword_prob, 123 | weight_threshold=args.subword_weight_threshold, 124 | eps=args.subword_prob_eps, 125 | take_root=args.subword_prob_take_root, 126 | normalize_semb=args.normalize_semb, 127 | ) 128 | start_time = time() 129 | for i_epoch in range(args.epochs) : 130 | h = [] 131 | h_epoch = [] 132 | lr = args.lr / (1 + i_epoch) ** 0.5 if args.lr_decay else args.lr 133 | logger.info('epoch {:>2} / {} | lr {:.5f}'.format(1 + i_epoch, args.epochs, lr)) 134 | epoch_start_time = time() 135 | for i_inst, wi in enumerate( 136 | np.random.choice(len(target_words), len(target_words), replace=False), 137 | start=1, 138 | ) : 139 | target_emb = target_embs[wi] 140 | word = target_words[wi] 141 | model_word = bound_word(word) if args.word_boundary else word 142 | model_emb = model.embed(model_word) 143 | grad = MSE_backward(model_emb, target_emb) 144 | 145 | if i_inst % 20 == 0 : 146 | loss = MSE(model_emb, target_emb) / len(target_emb) # average 
over dimension for easy reading 147 | h.append(loss) 148 | if i_inst % 10000 == 0 : 149 | width = len(f"{len(target_words)}") 150 | fmt = 'processed {:%d}/{:%d} | loss {:.5f}' % (width, width) 151 | logger.info(fmt.format(i_inst, len(target_words), np.average(h))) 152 | h_epoch.extend(h) 153 | h = [] 154 | 155 | d = - lr * grad 156 | model.step(model_word, d) 157 | now_time = time() 158 | logger.info('epoch {i_epoch:>2} / {n_epoch} | loss {loss:.5f} | time {epoch_time:.2f}s / {training_time:.2f}s'.format( 159 | i_epoch = 1 + i_epoch, n_epoch = args.epochs, 160 | loss = np.average(h_epoch), 161 | epoch_time = now_time - epoch_start_time, 162 | training_time = now_time - start_time, 163 | )) 164 | 165 | logger.info('saving model...') 166 | model.dump(save_path) 167 | 168 | 169 | if __name__ == '__main__': 170 | main(parse_args()) 171 | -------------------------------------------------------------------------------- /pbos.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import Counter, defaultdict 3 | from functools import lru_cache, partial 4 | 5 | import numpy as np 6 | 7 | from utils import normalize_prob 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | def get_subword_prob(sub, subword_prob, eps=None, take_root=False): 12 | prob = subword_prob.get(sub, eps if len(sub) == 1 else 0) 13 | if take_root: 14 | prob = prob ** (1 / len(sub)) 15 | return prob 16 | 17 | def calc_prefix_prob(w, get_subword_prob, backward=False): 18 | w = w[::-1] if backward else w 19 | p = [1] 20 | for i in range(1, len(w) + 1): 21 | p.append(sum( 22 | p[j] * get_subword_prob(w[j:i][::-1] if backward else w[j:i]) 23 | for j in range(i))) 24 | return p[::-1] if backward else p 25 | 26 | def calc_subword_weights( 27 | w, 28 | *, 29 | subword_vocab, 30 | get_subword_prob=None, 31 | weight_threshold=None, 32 | ): 33 | subword_weights = {} 34 | if get_subword_prob: 35 | p_prefix = calc_prefix_prob(w, get_subword_prob) 36 | p_suffix = calc_prefix_prob(w, get_subword_prob, backward=True) 37 | for j in range(1, len(w) + 1): 38 | for i in range(j): 39 | sub = w[i:j] 40 | if sub in subword_vocab: 41 | p_sub = get_subword_prob(sub) * p_prefix[i] * p_suffix[j] 42 | subword_weights.setdefault(sub, 0) 43 | subword_weights[sub] += p_sub 44 | subword_weights = normalize_prob(subword_weights) 45 | if weight_threshold: 46 | subword_weights = {k : v for k, v in subword_weights.items() if v > weight_threshold} 47 | else: 48 | for j in range(1, len(w) + 1): 49 | for i in range(j): 50 | sub = w[i:j] 51 | if sub in subword_vocab: 52 | subword_weights.setdefault(sub, 0) 53 | subword_weights[sub] += 1 54 | subword_weights = normalize_prob(subword_weights) 55 | 56 | if len(subword_weights) == 0: 57 | logger.warning(f"no qualified subwords for '{w}'") 58 | return {} 59 | 60 | return subword_weights 61 | 62 | 63 | class PBoS: 64 | def __init__( 65 | self, 66 | subword_embedding=None, 67 | *, 68 | subword_vocab, 69 | embedding_dim=None, 70 | subword_prob=None, 71 | weight_threshold=None, 72 | eps=1e-2, 73 | take_root=False, 74 | normalize_semb=False, 75 | ): 76 | """ 77 | Params: 78 | subword_embedding (default: None) - existing subword embeddings. 79 | If None, initialize an empty set of embeddings. 80 | 81 | embedding_dim (default: None) - embedding dimensions. 82 | If None, infer from `subword_embedding`. 83 | 84 | subword_prob (default: None) - subword probabilities. 85 | Used by probabilistic segmentation to calculate subword weights. 
86 | If None, assume uniform probability, i.e. = BoS. 87 | 88 | subword_vocab - subword vocabulary. 89 | The set of subwords to maintain subword embeddings. 90 | OOV subwords will be regarded as having zero vector embedding. 91 | 92 | weight_threshold (default: None) - minimum subword weight to consider. 93 | Extremely low-weighted subwords will be discarded for efficiency. 94 | If None, consider subwords with any weights. 95 | 96 | eps (default: 1e-2) - the default subword probability if it is not 97 | present in `subword_prob`. This is needed to keep the segmentation 98 | graph connected. 99 | Only effective when `subword_prob` is present. 100 | 101 | take_root (default: False) - whether to take `** (1 / len(sub))` when 102 | getting subword prob. 103 | """ 104 | self.semb = subword_embedding or defaultdict(float) 105 | if embedding_dim is None: 106 | subword_embedding_entry = next(iter(subword_embedding.values())) 107 | embedding_dim = len(subword_embedding_entry) 108 | for w in '<>': 109 | if w in subword_vocab: 110 | del subword_vocab[w] 111 | self._calc_subword_weights = lru_cache(maxsize=32)(partial( 112 | calc_subword_weights, 113 | subword_vocab=subword_vocab, 114 | get_subword_prob=partial( 115 | get_subword_prob, 116 | subword_prob=subword_prob, 117 | eps=eps, 118 | take_root=take_root, 119 | ) if subword_prob else None, 120 | weight_threshold=weight_threshold, 121 | )) 122 | self.config = dict( 123 | embedding_dim=embedding_dim, 124 | weight_threshold=weight_threshold, 125 | eps=eps, 126 | take_root=take_root, 127 | subword_vocab=subword_vocab, 128 | subword_prob=subword_prob, 129 | normalize_semb=normalize_semb, 130 | ) 131 | self._zero_emb = np.zeros(self.config['embedding_dim']) 132 | 133 | def dump(self, filename) : 134 | import json, pickle 135 | with open(filename + '.config.json', 'w') as fout: 136 | json.dump(self.config, fout) 137 | with open(filename, 'bw') as bfout : 138 | pickle.dump(self.semb, bfout) 139 | 140 | @classmethod 141 | def load(cls, filename) : 142 | import json, pickle 143 | try: 144 | # backward compatibility 145 | with open(filename, 'rb') as bfin: 146 | config, semb = pickle.load(bfin) 147 | except ValueError: 148 | with open(filename, 'rb') as bfin: 149 | semb = pickle.load(bfin) 150 | with open(filename + '.config.json') as fin: 151 | config = json.load(fin) 152 | bos = cls(**config) 153 | bos.semb = semb 154 | return bos 155 | 156 | @staticmethod 157 | def _semb_normalized_contrib(w, emb): 158 | norm = np.linalg.norm(emb) 159 | return w * emb / norm if norm > 1e-4 else 0 160 | 161 | def embed(self, w): 162 | subword_weights = self._calc_subword_weights(w) 163 | logger.debug(Counter(subword_weights).most_common()) 164 | # Will we have a performance issue if we put the if check inside sum? 
165 | if self.config['normalize_semb']: 166 | wemb = sum( 167 | self._semb_normalized_contrib(w, self.semb[sub]) 168 | for sub, w in subword_weights.items() 169 | ) 170 | else: 171 | wemb = sum( 172 | w * self.semb[sub] 173 | for sub, w in subword_weights.items() 174 | ) 175 | return wemb if isinstance(wemb, np.ndarray) else self._zero_emb 176 | 177 | def step(self, w, d): 178 | subword_weights = self._calc_subword_weights(w) 179 | for sub, weight in subword_weights.items(): 180 | self.semb[sub] += weight * d 181 | -------------------------------------------------------------------------------- /datasets/ud/make_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reads in CONLL files to make the dataset 3 | Output a textual vocabulary file, and a pickle file of a dict with the following elements 4 | training_instances: List of (sentence, tags) for training data 5 | dev_instances 6 | test_instances 7 | w2i: Dict mapping words to indices 8 | t2is: Dict mapping attribute types (POS / morpho) to dicts from tags to indices 9 | c2i: Dict mapping characters to indices 10 | """ 11 | 12 | import argparse 13 | import codecs 14 | import collections 15 | import pickle 16 | from collections import defaultdict 17 | 18 | __author__ = "Yuval Pinter and Robert Guthrie, 2017" 19 | 20 | Instance = collections.namedtuple("Instance", ["sentence", "tags"]) 21 | 22 | UNK_TAG = "<UNK>" 23 | NONE_TAG = "<NONE>" 24 | PADDING_CHAR = "<*>" 25 | POS_KEY = "POS" 26 | 27 | 28 | def split_tagstring(s, uni_key=False, has_pos=False): 29 | ''' 30 | Returns attribute-value mapping from UD-type CONLL field 31 | :param uni_key: if toggled, returns attribute-value pairs as joined strings (with the '=') 32 | :param has_pos: input line segment includes POS tag label 33 | ''' 34 | if has_pos: 35 | s = s.split("\t")[1] 36 | ret = [] if uni_key else {} 37 | if "=" not in s: # incorrect format 38 | return ret 39 | for attval in s.split('|'): 40 | attval = attval.strip() 41 | if not uni_key: 42 | a, v = attval.split('=') 43 | ret[a] = v 44 | else: 45 | ret.append(attval) 46 | return ret 47 | 48 | 49 | def read_file(filename, w2i, t2is, c2i, options): 50 | """ 51 | Read in a dataset and turn it into a list of instances. 52 | Modifies the w2i, t2is and c2i dicts, adding new words/attributes/tags/chars 53 | as it sees them. 
54 | """ 55 | 56 | # populate mandatory t2i tables 57 | if POS_KEY not in t2is: 58 | t2is[POS_KEY] = {} 59 | 60 | # build dataset 61 | instances = [] 62 | vocab_counter = collections.Counter() 63 | with codecs.open(filename, "r", "utf-8") as f: 64 | 65 | # running sentence buffers (lines are tokens) 66 | sentence = [] 67 | tags = defaultdict(list) 68 | 69 | # main file reading loop 70 | for i, line in enumerate(f): 71 | 72 | # discard comments 73 | if line.startswith("#"): 74 | continue 75 | 76 | # parse sentence end 77 | elif line.isspace(): 78 | 79 | # pad tag lists to sentence end 80 | slen = len(sentence) 81 | for seq in list(tags.values()): 82 | if len(seq) < slen: 83 | seq.extend([0] * (slen - len(seq))) # 0 guaranteed below to represent NONE_TAG 84 | 85 | # add sentence to dataset 86 | instances.append(Instance(sentence, tags)) 87 | sentence = [] 88 | tags = defaultdict(list) 89 | 90 | else: 91 | 92 | # parse token information in line 93 | data = line.split("\t") 94 | if '-' in data[0]: 95 | # Some UD languages have contractions on a separate line, we don't want to include them also 96 | continue 97 | try: 98 | idx = int(data[0]) 99 | except: 100 | continue 101 | word = data[1] 102 | postag = data[3] if options.ud_tags else data[4] 103 | morphotags = {} if options.no_morphotags else split_tagstring(data[5], uni_key=False) 104 | 105 | # ensure counts and dictionary population 106 | vocab_counter[word] += 1 107 | if word not in w2i: 108 | w2i[word] = len(w2i) 109 | pt2i = t2is[POS_KEY] 110 | if postag not in pt2i: 111 | pt2i[postag] = len(pt2i) 112 | for c in word: 113 | if c not in c2i: 114 | c2i[c] = len(c2i) 115 | for key, val in list(morphotags.items()): 116 | if key not in t2is: 117 | t2is[key] = {NONE_TAG: 0} 118 | mt2i = t2is[key] 119 | if val not in mt2i: 120 | mt2i[val] = len(mt2i) 121 | 122 | # add data to sentence buffer 123 | sentence.append(w2i[word]) 124 | tags[POS_KEY].append(t2is[POS_KEY][postag]) 125 | for k, v in list(morphotags.items()): 126 | mtags = tags[k] 127 | # pad backwards to latest seen 128 | missing_tags = idx - len(mtags) - 1 129 | mtags.extend([0] * missing_tags) # 0 guaranteed above to represent NONE_TAG 130 | mtags.append(t2is[k][v]) 131 | 132 | return instances, vocab_counter 133 | 134 | 135 | if __name__ == "__main__": 136 | # parse command line arguments 137 | parser = argparse.ArgumentParser() 138 | parser.add_argument("--training-data", required=True, dest="training_data", help="Training data .txt file") 139 | parser.add_argument("--dev-data", required=True, dest="dev_data", help="Development data .txt file") 140 | parser.add_argument("--test-data", required=True, dest="test_data", help="Test data .txt file") 141 | parser.add_argument("--ud-tags", dest="ud_tags", action="store_true", 142 | help="Extract UD tags instead of original tags") 143 | parser.add_argument("--no-morphotags", dest="no_morphotags", action="store_true", 144 | help="Don't add morphosyntactic tags to dataset") 145 | parser.add_argument("--output", required=True, dest="output", help="Output filename (.pkl)") 146 | parser.add_argument("--vocab", dest="vocab_file", default="vocab.txt", help="Text file containing all of the words in \ 147 | the train/dev/test data to use in outputting embeddings") 148 | options = parser.parse_args() 149 | 150 | w2i = {} # mapping from word to index 151 | t2is = {} # mapping from attribute name to mapping from tag to index 152 | c2i = {} # mapping from character to index, for char-RNN concatenations 153 | output = {} 154 | 155 | # Add special tokens 
/ tags / chars to dicts 156 | w2i[UNK_TAG] = len(w2i) 157 | c2i[PADDING_CHAR] = len(c2i) 158 | 159 | # read data from UD files 160 | output["training_instances"], output["training_vocab"] = read_file(options.training_data, w2i, t2is, c2i, options) 161 | output["dev_instances"], output["dev_vocab"] = read_file(options.dev_data, w2i, t2is, c2i, options) 162 | output["test_instances"], output["test_vocab"] = read_file(options.test_data, w2i, t2is, c2i, options) 163 | 164 | output["w2i"] = w2i 165 | output["t2is"] = t2is 166 | output["c2i"] = c2i 167 | 168 | # write outputs to files 169 | with open(options.output, "wb") as outfile: 170 | pickle.dump(output, outfile) 171 | with codecs.open(options.vocab_file, "w", "utf-8") as vocabfile: 172 | for word in list(w2i.keys()): 173 | vocabfile.write(word + "\n") 174 | -------------------------------------------------------------------------------- /pos_exp.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import multiprocessing as mp 3 | import os 4 | import subprocess as sp 5 | 6 | from datasets import prepare_target_vector_paths, polyglot_languages, prepare_ud_paths, prepare_polyglot_freq_paths 7 | from utils.args import add_logging_args, set_logging_config 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def tee_open(path): 13 | return sp.Popen(['/usr/bin/tee', '-a', path], stdin=sp.PIPE) 14 | 15 | 16 | def pos_eval(ud_data_path, ud_vocab_embedding_path, result_path): 17 | """train pos tagging and report scores""" 18 | ud_log_path = os.path.join(result_path , "ud.log") 19 | ud_out_path = os.path.join(result_path , "ud.out") 20 | cmd = f""" 21 | python pos_eval.py \ 22 | --dataset {ud_data_path} \ 23 | --embeddings {ud_vocab_embedding_path} \ 24 | """.split() 25 | with \ 26 | tee_open(ud_log_path) as log_tee, \ 27 | tee_open(ud_out_path) as out_tee: 28 | sp.call(cmd, stdout=out_tee.stdin, stderr=log_tee.stdin) 29 | 30 | 31 | def evaluate_pbos(language_code, model_type): 32 | logger.info(f"[evaluate_pbos({language_code}, model_type={model_type})] start...") 33 | 34 | # Input files 35 | polyglot_embeddings_path = prepare_target_vector_paths(language_code) 36 | polyglot_frequency_path = prepare_polyglot_freq_paths(language_code) 37 | 38 | # Output/result files 39 | result_path = os.path.join("results", "pos", language_code, model_type) 40 | os.makedirs(result_path, exist_ok=True) 41 | subword_vocab_path = os.path.join(result_path, "subword_vocab.jsonl") 42 | subword_prob_path = os.path.join(result_path, "subword_prob.jsonl") 43 | subword_embedding_model_path = os.path.join(result_path , "model.pbos") 44 | training_log_path = subword_embedding_model_path + ".log" 45 | logger.info(f"[evaluate_pbos({language_code}, model_type={model_type})]" 46 | f" result_path=`{result_path}`") 47 | 48 | # train subword embedding model using target embeddings and word freq 49 | if not os.path.exists(subword_embedding_model_path): 50 | # build subword vocab from target words 51 | logger.info(f"[evaluate_pbos({language_code}, model_type={model_type})]" 52 | f" building subword vocab...") 53 | cmd = f""" 54 | python subwords.py build_vocab \ 55 | --word_freq {polyglot_embeddings_path.word_freq_path} \ 56 | --output {subword_vocab_path} \ 57 | """ 58 | if model_type == 'bos': 59 | cmd += f" --subword_min_len 3" 60 | cmd += f" --subword_max_len 6" 61 | sp.call(cmd.split()) 62 | 63 | if model_type in ('pbos', 'pbosn'): 64 | # build subword prob from word freqs 65 | 
logger.info(f"[evaluate_pbos({language_code}, model_type={model_type})]" 66 | f" building subword prob...") 67 | cmd = f""" 68 | python subwords.py build_prob \ 69 | --word_freq {polyglot_frequency_path.word_freq_path} \ 70 | --output {subword_prob_path} \ 71 | """ 72 | sp.call(cmd.split()) 73 | else: 74 | logger.info(f"[evaluate_pbos({language_code}, model_type={model_type})]" 75 | f" skipped building subword prob.") 76 | 77 | # invoke training of subword model 78 | logger.info(f"[evaluate_pbos({language_code}, model_type={model_type})]" 79 | f" training subword model...") 80 | cmd = f""" 81 | python pbos_train.py \ 82 | --target_vectors {polyglot_embeddings_path.pkl_emb_path} \ 83 | --model_path {subword_embedding_model_path} \ 84 | --subword_vocab {subword_vocab_path} \ 85 | """ 86 | if model_type == "pbos": 87 | cmd += f" --subword_prob {subword_prob_path}" 88 | elif model_type == 'pbosn': 89 | cmd += f" --subword_prob {subword_prob_path}" 90 | cmd += f" --normalize_semb" 91 | cmd = cmd.split() 92 | with open(training_log_path, "w+") as log: 93 | sp.call(cmd, stdout=log, stderr=log) 94 | # with tee_open(training_log_path) as log_tee: 95 | # sp.call(cmd, stdout=log_tee.stdin, stderr=log_tee.stdin) 96 | else: 97 | logger.info(f"[evaluate_pbos({language_code}, model_type={model_type})]" 98 | f" skipped training subword model.") 99 | 100 | ud_data_path, ud_vocab_path = prepare_ud_paths(language_code) 101 | ud_vocab_embedding_path = os.path.join(result_path, "ud_vocab_embedding.txt") 102 | 103 | # predict embeddings for ud vocabs 104 | if not os.path.exists(ud_vocab_embedding_path): 105 | logger.info(f"[evaluate_pbos({language_code}, model_type={model_type})]" 106 | f" predicting word embeddings...") 107 | cmd = f""" 108 | python pbos_pred.py \ 109 | --queries {ud_vocab_path} \ 110 | --save {ud_vocab_embedding_path} \ 111 | --model {subword_embedding_model_path} \ 112 | """ 113 | # --pre_trained {polyglot_embeddings_path.pkl_emb_path} \ 114 | sp.call(cmd.split()) 115 | else: 116 | logger.info(f"[evaluate_pbos({language_code}, model_type={model_type})]" 117 | f" skipped predicting word embeddings.") 118 | 119 | logger.info(f"[evaluate_pbos({language_code}, model_type={model_type})]" 120 | f" evaluating on POS tagging...") 121 | pos_eval(ud_data_path, ud_vocab_embedding_path, result_path) 122 | 123 | logger.info(f"[evaluate_pbos({language_code}, model_type={model_type})]" 124 | f" done.") 125 | 126 | def main(): 127 | import argparse 128 | 129 | parser = argparse.ArgumentParser("Run POS tagging experiments on PolyGlot and UD", 130 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 131 | parser.add_argument('--languages', '-langs', nargs='+', metavar="LANG_CODE", 132 | choices=polyglot_languages + ["ALL"], 133 | default="ALL", 134 | help="languages to evaluate over") 135 | parser.add_argument('--num_processes', '-nproc', type=int, 136 | help="number of processers to use") 137 | add_logging_args(parser) 138 | args = parser.parse_args() 139 | 140 | set_logging_config(args) 141 | 142 | language_codes = polyglot_languages if "ALL" in args.languages else args.languages 143 | logger.debug(f"language_codes: {language_codes}") 144 | 145 | model_types = ("pbos", "bos") 146 | 147 | def job(apply): 148 | for language_code in language_codes: 149 | # prepare raw data without multiprocessing, 150 | # otherwise trouble comes with race conditions of file write 151 | print(language_code) 152 | prepare_target_vector_paths(language_code) 153 | prepare_polyglot_freq_paths(language_code) 154 | 
prepare_ud_paths(language_code) 155 | for model_type in model_types: 156 | apply(evaluate_pbos, (language_code, model_type,)) 157 | if args.num_processes == 1: 158 | def apply(func, args): 159 | return func(*args) 160 | job(apply) 161 | else: 162 | with mp.Pool(args.num_processes) as pool: 163 | results = [] 164 | def apply(func, args): 165 | return results.append(pool.apply_async(func, args)) 166 | job(apply) 167 | for r in results: 168 | r.get() 169 | logger.debug("done.") 170 | 171 | 172 | if __name__ == "__main__": 173 | main() 174 | -------------------------------------------------------------------------------- /subwords.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from collections import Counter 4 | import json 5 | import logging 6 | 7 | from tqdm import tqdm 8 | 9 | from utils import file_tqdm, get_substrings, normalize_prob 10 | from utils.args import add_logging_args, set_logging_config, dump_args 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def bound_word(word): 16 | return '<' + word + '>' 17 | 18 | 19 | def build_subword_counter( 20 | word_count_iter, 21 | max_size=None, 22 | min_count=None, 23 | min_len=1, 24 | max_len=None, 25 | word_boundary=False, 26 | uniq_factor=None, 27 | ): 28 | subword_counter = Counter() 29 | for word, count in iter(word_count_iter): 30 | if word_boundary: 31 | word = bound_word(word) 32 | for subword in get_substrings(word, min_len=min_len, max_len=max_len): 33 | subword_counter[subword] += count 34 | 35 | if min_count: 36 | subword_counter = Counter({k : v for k, v in subword_counter.items() if v >= min_count}) 37 | if uniq_factor is not None: 38 | for sub in tqdm(list(subword_counter)): 39 | for subsub in get_substrings(sub, min_len=min_len, max_len=max_len): 40 | if subsub != sub and subsub in subword_counter and subword_counter[subsub] * uniq_factor <= subword_counter[sub]: 41 | del subword_counter[subsub] 42 | if max_size: 43 | subword_count_pairs = subword_counter.most_common(max_size) 44 | else: 45 | subword_count_pairs = subword_counter.items() 46 | return Counter(dict(subword_count_pairs))  # honor max_size by keeping only the most common subwords 47 | 48 | 49 | 50 | def subword_prob_post_process(subword_prob, min_prob=None, take_root=False): 51 | if min_prob: 52 | subword_prob = {k : v for k, v in subword_prob.items() if v >= min_prob} 53 | if take_root: 54 | subword_prob = {k : (v ** (1 / len(k))) for k, v in subword_prob.items()} 55 | return subword_prob 56 | 57 | def build_subword_prob( 58 | subword_counter, 59 | normalize_prob=normalize_prob, 60 | min_prob=None, 61 | take_root=False, 62 | ): 63 | subword_prob = normalize_prob(subword_counter) 64 | subword_prob = subword_prob_post_process( 65 | subword_prob, 66 | min_prob=min_prob, 67 | take_root=take_root, 68 | ) 69 | return Counter(subword_prob) 70 | 71 | 72 | def parse_args(): 73 | parser = argparse.ArgumentParser(description='Subword processing', 74 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 75 | add_args(parser) 76 | return parser.parse_args() 77 | 78 | 79 | def add_args(parser): 80 | parser.add_argument('command', choices=['build_vocab', 'build_prob']) 81 | parser.add_argument('--word_freq', required=True, 82 | help='word frequencies (.jsonl). ' 83 | 'Each line is a pair of word and its count.') 84 | parser.add_argument('--output', default='subword.jsonl', 85 | help='output file (.jsonl). 
86 |                              'Each line is a pair of word and count (build_vocab) '
87 |                              'or a pair of word and score (build_prob).')
88 |     add_logging_args(parser)
89 |     add_subword_args(parser)
90 |     add_subword_vocab_args(parser)
91 |     add_subword_prob_args(parser)
92 |     return parser
93 | 
94 | 
95 | def add_word_args(parser):
96 |     group = parser.add_argument_group('word arguments')
97 |     group.add_argument('--word_boundary', '-wb', action='store_true',
98 |                        help="annotate word boundary with '<' and '>'")
99 |     group.add_argument('--no_word_boundary', '-Nwb',
100 |                        dest='word_boundary', action='store_false')
101 |     return group
102 | 
103 | 
104 | def add_subword_args(parser):
105 |     add_word_args(parser)
106 |     group = parser.add_argument_group('subword arguments')
107 |     group.add_argument('--subword_min_count', type=int,
108 |                        help="subword min count for it to be included in vocab")
109 |     group.add_argument('--subword_min_len', type=int, default=1,
110 |                        help="subword min length for it to be included in vocab")
111 |     group.add_argument('--subword_max_len', type=int,
112 |                        help="subword max length for it to be included in vocab")
113 |     group.add_argument('--subword_uniq_factor', '-suf', type=float,
114 |                        help="subword uniqueness factor")
115 |     return group
116 | 
117 | 
118 | def add_subword_vocab_args(parser):
119 |     group = parser.add_argument_group('subword vocab arguments')
120 |     group.add_argument('--subword_vocab_max_size', type=int,
121 |                        help="maximum size of subword vocab")
122 |     return group
123 | 
124 | 
125 | def add_subword_prob_args(parser):
126 |     group = parser.add_argument_group('subword prob arguments')
127 |     group.add_argument('--subword_prob_min_prob', '-spmp', type=float,
128 |                        help="minimum prob score of subword vocab")
129 |     group.add_argument('--subword_prob_take_root', '-sptr', action='store_true',
130 |                        help="take `** (1 / len(subword))` for prob score")
131 |     group.add_argument('--no_subword_prob_take_root', '-Nsptr',
132 |                        dest='subword_prob_take_root', action='store_false')
133 |     return group
134 | 
135 | 
136 | def build_subword_vocab_cli(args):
137 |     if os.path.exists(args.output):
138 |         logger.warning(f"{args.output} already exists!")
139 | 
140 |     logger.info("loading...")
141 |     with open(args.word_freq) as fin:
142 |         word_count_iter = (json.loads(line) for line in file_tqdm(fin))
143 |         subword_counter = build_subword_counter(
144 |             word_count_iter,
145 |             max_size=args.subword_vocab_max_size,
146 |             min_count=args.subword_min_count,
147 |             min_len=args.subword_min_len,
148 |             max_len=args.subword_max_len,
149 |             word_boundary=args.word_boundary,
150 |             uniq_factor=args.subword_uniq_factor,
151 |         )
152 |     logger.info("processing...")
153 |     subword_vocab = subword_counter
154 |     logger.info("saving...")
155 |     with open(args.output, 'w') as fout:
156 |         for (subword, count) in tqdm(subword_vocab.most_common()):
157 |             print(json.dumps((subword, count)), file=fout)
158 | 
159 | 
160 | def build_subword_prob_cli(args):
161 |     if os.path.exists(args.output):
162 |         logger.warning(f"{args.output} already exists!")
163 | 
164 |     logger.info("loading...")
165 |     with open(args.word_freq) as fin:
166 |         word_count_iter = (json.loads(line) for line in file_tqdm(fin))
167 |         subword_counter = build_subword_counter(
168 |             word_count_iter,
169 |             min_count=args.subword_min_count,
170 |             min_len=args.subword_min_len,
171 |             max_len=args.subword_max_len,
172 |             word_boundary=args.word_boundary,
173 |             uniq_factor=args.subword_uniq_factor,
174 |         )
175 |     logger.info("processing...")
176 |     if args.subword_prob_take_root:
logger.warning("`args.subword_prob_take_root = True` ignored at this step.") 178 | subword_prob = build_subword_prob( 179 | subword_counter, 180 | normalize_prob=normalize_prob, 181 | min_prob=args.subword_prob_min_prob, 182 | # take_root=args.subword_prob_take_root, 183 | ) 184 | logger.info("saving...") 185 | with open(args.output, 'w') as fout: 186 | for (subword, prob) in tqdm(subword_prob.most_common()): 187 | print(json.dumps((subword, prob)), file=fout) 188 | 189 | 190 | def main_cli(args): 191 | set_logging_config(args) 192 | if args.command == 'build_vocab': 193 | build_subword_vocab_cli(args) 194 | elif args.command == 'build_prob': 195 | build_subword_prob_cli(args) 196 | else: 197 | raise ValueError(f"Unknown command `{args.command}`") 198 | 199 | 200 | if __name__ == '__main__': 201 | main_cli(parse_args()) 202 | --------------------------------------------------------------------------------