├── .DS_Store ├── pretrained ├── khpos.gat.pt └── khpos.gcn.pt ├── requirements.txt ├── configs ├── config.gat.json └── config.gcn.json ├── segment.py ├── train.py ├── evaluate.py ├── .gitignore ├── datasets.py ├── README.md ├── utils.py └── networks.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Socret360/jaws/HEAD/.DS_Store -------------------------------------------------------------------------------- /pretrained/khpos.gat.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Socret360/jaws/HEAD/pretrained/khpos.gat.pt -------------------------------------------------------------------------------- /pretrained/khpos.gcn.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Socret360/jaws/HEAD/pretrained/khpos.gcn.pt -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -f https://data.pyg.org/whl/torch-2.2.0+cpu.html 2 | torch 3 | kornia 4 | torchmetrics 5 | torch_geometric -------------------------------------------------------------------------------- /configs/config.gat.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "gat", 3 | "hidden_dims": [ 4 | 294, 5 | 148, 6 | 74, 7 | 38, 8 | 16 9 | ] 10 | } -------------------------------------------------------------------------------- /configs/config.gcn.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "gcn", 3 | "hidden_dims": [ 4 | 294, 5 | 148, 6 | 74, 7 | 38, 8 | 16 9 | ] 10 | } -------------------------------------------------------------------------------- /segment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | # 4 | from utils import preprocess_phylypo_sample 5 | from networks import JAWSModel 6 | 7 | parser = argparse.ArgumentParser( 8 | description='Run the segmentation on one sample text file.') 9 | parser.add_argument('config', type=str, help='Path to config file.') 10 | parser.add_argument('model_path', type=str, help='Path to model weight file.') 11 | parser.add_argument('sample', type=str, 12 | help='The text to segment, or a path to a text file containing it when --file_mode is set.') 13 | parser.add_argument('--file_mode', action='store_true', 14 | help='Whether sample is a path to a text file') 15 | parser.add_argument('--console_output', action='store_true', 16 | help='Whether to output the result to console') 17 | 18 | args = parser.parse_args() 19 | 20 | model = JAWSModel(args.config) 21 | model.load(args.model_path) 22 | 23 | if args.file_mode: 24 | with open(args.sample, 'r') as input_file: 25 | input_str = input_file.read() 26 | else: 27 | input_str = args.sample 28 | 29 | preprocessed, _ = preprocess_phylypo_sample(input_str) 30 | out = model.segment(preprocessed) 31 | 32 | if not args.console_output: 33 | with open(args.sample.replace(".txt", ".seg.txt"), "w") as outfile: 34 | outfile.write(out) 35 | else: 36 | print(out) 37 | --------------------------------------------------------------------------------
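For quick reference, below is a hypothetical pair of `segment.py` invocations using the bundled GAT config and pretrained weights; `input.txt` and the quoted text are placeholders. In file mode the script writes the result next to the input as `input.seg.txt`; otherwise `--console_output` prints it:

```bash
# Segment the contents of a text file; the result is written to input.seg.txt
python segment.py configs/config.gat.json pretrained/khpos.gat.pt input.txt --file_mode

# Segment raw text passed on the command line and print the result
python segment.py configs/config.gat.json pretrained/khpos.gat.pt "<khmer text to segment>" --console_output
```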
/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | # 4 | from utils import VOCAB 5 | from networks import JAWSModel 6 | from datasets import build_dataset 7 | 8 | 9 | parser = argparse.ArgumentParser(description='Run the training loop.') 10 | parser.add_argument('config', type=str, help='Path to config file.') 11 | parser.add_argument('dataset_path', type=str, 12 | help='Path to dataset. (text file if of type khpos or directory if of type phylypo)') 13 | parser.add_argument('--output_dir', type=str, 14 | help='Path to output directory.', default=".") 15 | parser.add_argument('--dataset_type', type=str, 16 | help='The type of dataset to use', 17 | default="khpos", 18 | choices=['khpos', 'phylypo']) 19 | parser.add_argument('--lr', type=float, 20 | help='The learning rate', 21 | default=0.01) 22 | parser.add_argument('--early_stopping_patience', type=int, 23 | help='The number of epochs without improvement before training stops early', 24 | default=10) 25 | parser.add_argument('--epochs', type=int, 26 | help='The number of epochs to train', 27 | default=None) 28 | parser.add_argument('--batch_size', type=int, 29 | help='The number of samples per batch', 30 | default=32) 31 | parser.add_argument('--pretrained_weights', type=str, 32 | help='The path to the pretrained weights') 33 | args = parser.parse_args() 34 | 35 | 36 | dataset = build_dataset( 37 | args.dataset_path, 38 | args.dataset_type 39 | ) 40 | 41 | 42 | model = JAWSModel(args.config) 43 | 44 | model.fit( 45 | data=dataset, 46 | epochs=args.epochs, 47 | learning_rate=args.lr, 48 | batch_size=args.batch_size, 49 | model_temp_dir=args.output_dir, 50 | pretrained_weights_path=args.pretrained_weights, 51 | early_stopping_patience=args.early_stopping_patience, 52 | ) 53 | 54 | output_filename = os.path.basename(args.config).replace(".json", ".pt") 55 | model.save(os.path.join(args.output_dir, output_filename)) 56 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from tqdm import tqdm 4 | from torchmetrics.text import WordErrorRate 5 | from torch_geometric.loader import DataLoader 6 | from torchmetrics.classification import BinaryF1Score 7 | # 8 | from utils import post_process 9 | from networks import JAWSModel 10 | from datasets import build_dataset 11 | 12 | parser = argparse.ArgumentParser(description='Run the evaluation loop.') 13 | parser.add_argument('config', type=str, help='Path to config file.') 14 | parser.add_argument('model_path', type=str, help='Path to model weight file.') 15 | parser.add_argument('dataset_path', type=str, 16 | help='Path to dataset. 
(text file if of type khpos or directory if of type phylypo)') 17 | parser.add_argument('--output_dir', type=str, 18 | help='Path to output directory.', default=".") 19 | parser.add_argument('--dataset_type', type=str, 20 | help='The type of dataset to use', 21 | default="khpos", 22 | choices=['khpos', 'phylypo']) 23 | args = parser.parse_args() 24 | 25 | model = JAWSModel(args.config) 26 | model.load(args.model_path) 27 | 28 | dataset = build_dataset(args.dataset_path, args.dataset_type) 29 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False) 30 | 31 | wer, f1_score = WordErrorRate(), BinaryF1Score() 32 | 33 | output_filename = os.path.basename(args.dataset_path)\ 34 | .replace(".txt", ".result.txt") 35 | 36 | with open(os.path.join(args.output_dir, output_filename), 'w') as outfile: 37 | total_wer, total_f1, n = 0, 0, 0 38 | for i, sample in enumerate(tqdm(dataloader)): 39 | characters, original = sample.characters[0], sample.original[0] 40 | 41 | y_pred = model.predict_sample(sample) 42 | out = post_process(y_pred.numpy(), characters) 43 | 44 | swer = wer([out], [original]) 45 | f1 = f1_score(y_pred, sample.y.argmax(dim=-1)) 46 | 47 | outfile.write(f"{swer}\t\t{original}\t\t{out}\n") 48 | 49 | total_wer += swer 50 | total_f1 += f1 51 | n += 1 52 | 53 | metrics = f"\nf1: {(total_f1/n)*100}, wer: {(total_wer/n)*100}" 54 | outfile.write(metrics) 55 | print(metrics) 56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | data 162 | out 163 | temp.py -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | from glob import glob 4 | from torch_geometric.data import Dataset 5 | # 6 | from utils import text_to_graph, preprocess_khpos_sample, preprocess_phylypo_sample 7 | 8 | def build_dataset(path: str, type: str) -> Dataset: 9 | """ Construct the appropriate dataset object from `path` and `type`. 10 | 11 | Args: 12 | --- 13 | - `path`: str 14 | The path to the dataset. (A text file if `type` is `khpos` and a directory if it is `phylypo`) 15 | - `type`: str 16 | The type of dataset. (One of `khpos` or `phylypo`) 17 | 18 | Returns: The constructed `Dataset` object. 19 | """ 20 | dataset = { 21 | "khpos": KhPOSDataset, 22 | "phylypo": PhylypoDataset, 23 | }[type] 24 | return dataset(path) 25 | 26 | 27 | class KhPOSDataset(Dataset): 28 | """ PyTorch dataset generator for processing the KhPOS dataset. The dataset can be found here: 29 | https://github.com/ye-kyaw-thu/khPOS. The samples in the dataset are listed as single lines 30 | in a text file. 31 | """ 32 | 33 | def __init__(self, filepath: str) -> None: 34 | """ Initialize the KhPOS dataset generator. 35 | 36 | Args: 37 | --- 38 | - `filepath`: str 39 | Path pointing to the .txt file containing a list of sample texts. 40 | """ 41 | with open(filepath, "r") as file: 42 | self.samples = file.readlines() 43 | 44 | def __len__(self): 45 | """ Returns the number of samples. 46 | 47 | Returns: 48 | --- 49 | An integer representing the number of samples of the dataset. 50 | """ 51 | return len(self.samples) 52 | 53 | def __getitem__(self, idx): 54 | """ Returns a graph representation of the sample with index `idx`. 55 | 56 | Args: 57 | --- 58 | `idx`: int - An integer representing the index of the sample. 59 | 60 | Returns: 61 | --- 62 | A preprocessed sample as a graph in the form of a PyTorch Geometric `Data` object. 63 | """ 64 | text = self.samples[idx] 65 | preprocessed, original = preprocess_khpos_sample(text.strip()) 66 | graph = text_to_graph(preprocessed, original) 67 | return graph 68 | 69 | 70 | class PhylypoDataset(Dataset): 71 | """ PyTorch dataset generator for preprocessing the dataset scraped by Phylypo. The dataset can be found here: 72 | https://github.com/phylypo/segmentation-crf-khmer. The samples in the dataset are listed as text files in a directory. 73 | This generator assumes that each file contains Khmer text separated by spaces and that each file name is suffixed with _seg (e.g. sample1_seg.txt). 74 | """ 75 | 76 | def __init__(self, data_dir: str) -> None: 77 | """ Initialize Phylypo dataset generator. 78 | 79 | Args: 80 | --- 81 | - `data_dir`: str 82 | Path to the dataset samples folder. 83 | """ 84 | self.samples = sorted(list(glob(os.path.join(data_dir, "*_seg.txt")))) 85 | 86 | def __len__(self): 87 | """ Returns the number of samples. 88 | 89 | Returns: 90 | --- 91 | An integer representing the number of samples of the dataset. 92 | """ 93 | return len(self.samples) 94 | 95 | def __getitem__(self, idx): 96 | """ Returns a graph representation of the sample with index `idx`. 97 | 98 | Args: 99 | --- 100 | `idx`: int 101 | An integer representing the index of the sample. 102 | 103 | Returns: 104 | --- 105 | A preprocessed sample as a graph in the form of a PyTorch Geometric `Data` object. 106 | """ 107 | with open(self.samples[idx], 'r', encoding='utf-8') as sample_file: 108 | preprocessed, original = preprocess_phylypo_sample( 109 | sample_file.read()) 110 | graph = text_to_graph(preprocessed, original) 111 | return graph 112 | --------------------------------------------------------------------------------
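As a rough usage sketch (the `data/train.txt` path is a placeholder), a dataset built by `build_dataset` is consumed through `torch_geometric`'s `DataLoader`, the same way `train.py` and `evaluate.py` use it:

```python
# Minimal sketch: build a KhPOS-style dataset and iterate over it in graph batches.
from torch_geometric.loader import DataLoader

from datasets import build_dataset
from utils import FEATURE_VECTOR_LENGTH

dataset = build_dataset("data/train.txt", "khpos")  # placeholder path: one sentence per line
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch in loader:
    # Each node represents a pair of adjacent characters, so its feature vector is
    # two one-hot character encodings concatenated together.
    assert batch.x.shape[1] == FEATURE_VECTOR_LENGTH * 2
    break
```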
/README.md: -------------------------------------------------------------------------------- 1 | # Just Another Word Segmentor (JAWS) 2 | 3 | A Khmer word segmentation model based on Graph Neural Networks. It represents pairs of adjacent characters as nodes and performs node classification to determine whether there is a space between the two characters. 4 | 5 | ### Samples taken from Facebook posts 6 | 7 | ``` 8 | Source: តែបើអត់លុយសងផង មានលុយប្តឹងដែល ? 9 | Output: តែ បើ អត់លុយ សង ផង មាន លុយ ប្តឹង ដែល ? 
10 | 11 | Source: កូនជាប់រៀនទាំងអស់គ្នាមិនបានជូនដើរលេងទេ តែគ្មានអ្វីសប្បាយជាងការជួបជុំទេ 12 | Output: កូន ជាប់ រៀន ទាំង អស់ គ្នា មិន បានជូន ដើរ លេង ទេ តែ គ្មាន អ្វី សប្បាយ ជាង ការជួប ជុំទេ 13 | 14 | Source: ស្អាតណាស់ក្មេងតូច 15 | Output: ស្អាត ណាស់ ក្មេង តូច 16 | 17 | Source: ចង់ចូលផ្តល់កម្លាំងចិត្តនិងជួយស្រង់ក្លិននិងណាបង 18 | Output: ចង់ ចូល ផ្តល់ កម្លាំង ចិត្ត និង ជួយ ស្រង់ ក្លិន និង ណាបង 19 | 20 | Source: កាន់តស្រស់ស្អាតប្លែកតែម្ដងហើយ BA លៀសហាលយើងហ្នឹង 21 | Output: កាន់ តស្រស់ ស្អាត ប្លែក តែ ម្ដង ហើយ BA លៀស ហាល យើង ហ្នឹង 22 | 23 | Source: ខ្ញុំកំពុងនៅជាមួយខាងCMG CCTV ដើម្បីថតផ្សព្វផ្សាយវប្បធម៌ខ្មែរទៅកាន់ពិភពលោក ឥឡូវនេះខ្ញុំកំពុងមើលព្រះអាទិត្យរះ នៅប្រាសាទអង្គរវត្ត🇰🇭✨🌅 24 | Output: ខ្ញុំ កំពុង នៅ ជាមួយ ខាង CMG CCTV ដើម្បី ថត ផ្សព្វផ្សាយ វប្បធម៌ ខ្មែរ ទៅ កាន់ ពិភព លោក ឥឡូវ នេះ ខ្ញុំ កំពុង មើល ព្រះ អាទិត្យ រះ នៅ ប្រាសាទ អង្គរ វត្ត 🇰🇭✨🌅 25 | 26 | Source: សូម្បីតែសម្លេងក៏ឈ្លោះគ្នាដែរ ឈ្លោះគ្រប់យ៉ាង mak Lin 😭 27 | Output: សូម្បី តែ សម្លេង ក៏ ឈ្លោះ គ្នា ដែរ ឈ្លោះ គ្រប់ យ៉ាង mak Lin 😭 28 | 29 | Source: Global Green Growth Week បានត្រលប់មកវិញហើយ នៅសប្តាហ៍ក្រោយនេះ! 30 | Output: Global Green Growth Week បាន ត្រលប់ មក វិញ ហើយ នៅ សប្តាហ៍ ក្រោយ នេះ ! 31 | 32 | Source: ញាត្តិគាំទ្រ Cambodian Pageant និងអ្នកគាំទ្រវិស័យបវរកញ្ញាកម្ពុជាចំពោះប្រសាសន៍លោកស្រីអ៉ឹម សុគន្ធា Sokunthea Im ក្នុងការប្តេជ្ញាចិត្ត សំណូមពរឳ្យស្ថាប័នពាក់ព័ន្ធ បិទការប្រកួត និងការបញ្ជូនបេក្ខភាពតំណាងកម្ពុជាចូលរួមប្រកួត Miss Grand International 33 | Output: ញា ត្តិ គាំទ្រ Cambodian Pageant និង អ្នក គាំទ្រ វិស័យ បវរ កញ្ញា កម្ពុជា ចំពោះ ប្រសាសន៍ លោក ស្រី អ៉ឹម សុគន្ធា Sokunthea Im ក្នុង ការ ប្តេជ្ញា ចិត្ត សំណូម ពរឳ្យ ស្ថាប័ន ពាក់ ព័ន្ធ បិទ ការ ប្រកួត និង ការ បញ្ជូន បេក្ខ ភាព តំណាង កម្ពុជា ចូល រួម ប្រកួត Miss Grand International 34 | 35 | Source: ហ្វេនកម្មវិធី MasterChef បញ្ចេញសមត្ថភាពតាមរយះការខំមិនបន្តិចទៅមើល! គួរចម្អិនមុខម្ហូបអ្វីទៅណ៎?​😍 36 | Output: ហ្វេន កម្មវិធី MasterChef បញ្ចេញ សមត្ថ ភាព តាម រយះការ ខំ មិន បន្តិច ទៅ មើល ! គួរ ចម្អិន មុខ ម្ហូប អ្វី ទៅ ណ៎ ? ​😍 37 | ``` 38 | 39 | ## How to run 40 | 41 | ### Segment 42 | 43 | You can perform word segmentation using the pretrained model by running the below command. 44 | 45 | ```bash 46 | usage: segment.py [-h] [--file_mode] [--console_output] config model_path sample 47 | 48 | Run the segmentation on one sample text file. 49 | 50 | positional arguments: 51 | config Path to config file. 52 | model_path Path to model weight file. 53 | sample Path to the input text file containing the text to segment. 54 | 55 | optional arguments: 56 | -h, --help show this help message and exit 57 | --file_mode Wether sample is a file 58 | --console_output Whether to output the result to console 59 | ``` 60 | 61 | ### Training 62 | 63 | The model can be retrained on your custom dataset using the below command. 64 | 65 | ```bash 66 | usage: train.py [-h] [--output_dir OUTPUT_DIR] [--dataset_type {khpos,phylypo}] [--lr LR] [--early_stopping_patience EARLY_STOPPING_PATIENCE] [--epochs EPOCHS] config dataset_path 67 | 68 | Run the training loop. 69 | 70 | positional arguments: 71 | config Path to config file. 72 | dataset_path Path to dataset. (text file if of type khpos or directory if of type phylypo) 73 | 74 | optional arguments: 75 | -h, --help show this help message and exit 76 | --output_dir OUTPUT_DIR 77 | Path to output directory. 
78 | --dataset_type {khpos,phylypo} 79 | The type of dataset to use 80 | --lr LR The learning rate 81 | --early_stopping_patience EARLY_STOPPING_PATIENCE 82 | The learning rate 83 | --epochs EPOCHS The number of epochs to train 84 | ``` 85 | 86 | ### Evaluation 87 | 88 | To evaluate a trained model, use the below command. 89 | 90 | ```bash 91 | usage: evaluate.py [-h] [--output_dir OUTPUT_DIR] [--dataset_type {khpos,phylypo}] config model_path dataset_path 92 | 93 | Run the training loop. 94 | 95 | positional arguments: 96 | config Path to config file. 97 | model_path Path to model weight file. 98 | dataset_path Path to dataset. (text file if of type khpos or directory if of type phylypo) 99 | 100 | optional arguments: 101 | -h, --help show this help message and exit 102 | --output_dir OUTPUT_DIR 103 | Path to output directory. 104 | --dataset_type {khpos,phylypo} 105 | The type of dataset to use 106 | ``` 107 | 108 | ## About Pretrained Weights 109 | 110 | The pretrained weights are available [here](pretrained). Each weights were produced by training using the dataset from KhPOS available [here](https://github.com/ye-kyaw-thu/khPOS). 111 | Below is the results on the test sets from the set repository. 112 | 113 | **Open Set**: [here](https://github.com/ye-kyaw-thu/khPOS/blob/master/corpus-draft-ver-1.0/data/OPEN-TEST.word) 114 | | Model | F1 Score | Word Error Rate | 115 | | ------- | -------- | --------------- | 116 | | JAWSGCN | 98.035% | 5.915% | 117 | | JAWSGAT | 98.462% | 4.911% | 118 | 119 | **Close Set** [here](https://github.com/ye-kyaw-thu/khPOS/blob/master/corpus-draft-ver-1.0/data/CLOSE-TEST.word) 120 | | Model | F1 Score | Word Error Rate | 121 | | ------- | -------- | --------------- | 122 | | JAWSGCN | 98.814% | 3.855% | 123 | | JAWSGAT | 98.941% | 3.487% | 124 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import torch 4 | import itertools 5 | import numpy as np 6 | from typing import Dict, Any 7 | from torch_geometric.data import Data 8 | import torch_geometric.transforms as T 9 | from torch_geometric.data import Dataset 10 | 11 | NON_SPACE_SEP = "" 12 | SPACE_SEP = "" 13 | CONSTS = list(set(u'កខគឃងចឆជឈញដឋឌឍណតថទធនបផពភមយរលវឝឞសហឡអឣឤឥឦឧឨឩឪឫឬឭឮឯឰឱឲឳ')) 14 | VOWELS = list(set(u'឴឵ាិីឹឺុូួើឿៀេែៃោៅ\u17c6\u17c7\u17c8')) 15 | SUB = list(set(u'្')) 16 | DIAC = list(set(u"\u17c9\u17ca\u17cb\u17cc\u17cd\u17ce\u17cf\u17d0")) 17 | SYMS = list(set('៕។៛ៗ៚៙៘,.? ')) 18 | NUMBERS = list(set(u'០១២៣៤៥៦៧៨៩0123456789')) 19 | LUNAR = list(set('᧠᧡᧢᧣᧤᧥᧦᧧᧨᧩᧪᧫᧬᧭᧮᧯᧰᧱᧲᧳᧴᧵᧶᧷᧸᧹᧺᧻᧼᧽᧾᧿')) 20 | VOCAB = sorted(CONSTS + VOWELS + SUB + DIAC + NUMBERS + SYMS + LUNAR) 21 | REGEX_MATCHES_ALL_IGNORE_SPACES = f"([^ {''.join(re.escape(char) for char in VOCAB)}]+)(\u0020)|(\u0020)([^ {''.join(re.escape(char) for char in VOCAB)}]+)" 22 | # including unk characters 23 | FEATURE_VECTOR_LENGTH = len(VOCAB)+2 24 | 25 | CHAR_TO_INT = {v: i for i, v in enumerate(VOCAB)} 26 | CHAR_TO_INT["[SP]"] = len(VOCAB) 27 | 28 | LABEL_MAP = [NON_SPACE_SEP, SPACE_SEP] 29 | 30 | 31 | def preprocess_khpos_sample(text: str): 32 | """ Convert `text` in each line of the KhPos dataset to 33 | the appropriate format. (e.g. charchar) 34 | 35 | Args: 36 | --- 37 | - `text`: str 38 | The input text to preprocess. 39 | 40 | Returns: 41 | --- 42 | (str, str) 43 | A tuple containing the preprocessed text and the raw clean input text. 
44 | """ 45 | text = text.strip('\n') 46 | # set compound words, prefix, and suffix marker as space 47 | text = re.sub('[\^~_]', '\u0020', text) 48 | preprocesed, text = preprocess_phylypo_sample(text) 49 | return preprocesed, text 50 | 51 | 52 | def preprocess_phylypo_sample(text: str) -> str: 53 | """ Convert `text` in each sample of the Phylypo dataset to 54 | the appropriate format. (e.g. charchar) 55 | 56 | Args: 57 | --- 58 | - `text`: str 59 | The input text to preprocess. 60 | 61 | Returns: 62 | --- 63 | (str, str) 64 | A tuple containing the preprocessed text and the raw clean input text. 65 | """ 66 | text = text.strip('\n') 67 | text = text.strip() 68 | text = text.replace('️', '\u200b') 69 | text = text.replace('\n', '\u0020') 70 | text = re.sub('\u0020+', '\u0020', text) 71 | text = re.sub( 72 | rf"([^{re.escape(''.join(VOCAB))}])\u0020", 73 | lambda match: f"{match.group(1)}[SP]", 74 | text 75 | ) 76 | text = re.sub( 77 | rf"\u0020([^{re.escape(''.join(VOCAB))}])", 78 | lambda match: f"[SP]{match.group(1)}", 79 | text 80 | ) 81 | text = re.sub( 82 | rf"([{re.escape(''.join(VOCAB))}])([^{re.escape(''.join(VOCAB))}])", 83 | lambda match: f"{match.group(1)}[SP]{match.group(2)}", 84 | text 85 | ) 86 | text = re.sub( 87 | rf"([^{re.escape(''.join(VOCAB))}])([{re.escape(''.join(VOCAB))}])", 88 | lambda match: f"{match.group(1)}[SP]{match.group(2)}", 89 | text 90 | ) 91 | text = re.sub(r'(\[SP\])+', "[SP]", text) 92 | preprocessed = NON_SPACE_SEP.join(re.findall(r'\[SP\]|.', text)) 93 | preprocessed = preprocessed.replace( 94 | f"{NON_SPACE_SEP}\u0020{NON_SPACE_SEP}", 95 | SPACE_SEP 96 | ) 97 | return preprocessed, text.replace("[SP]", "\u0020") 98 | 99 | 100 | def text_to_graph(text: str, original: str = None): 101 | """ Convert `text` in a graph representation suitable for training. 102 | 103 | Args: 104 | --- 105 | - `text`: str 106 | The preprocessed input text. (e.g. charchar 107 | - `original`: str 108 | The original clean representation of the text. 109 | 110 | Returns: 111 | --- 112 | str 113 | The graph representation of `text`. 
114 | """ 115 | delimiter = f"{NON_SPACE_SEP}|{SPACE_SEP}" 116 | characters = re.split(delimiter, text) 117 | separators = re.findall(delimiter, text) 118 | 119 | features = np.array([CHAR_TO_INT[char] if char in CHAR_TO_INT.keys() else FEATURE_VECTOR_LENGTH-1 120 | for char in characters]) 121 | 122 | characters = "".join(characters) 123 | characters = characters.replace("[SP]", '\u0020') 124 | characters = list(characters) 125 | 126 | labels = np.array([LABEL_MAP.index(sep) for sep in separators]) 127 | 128 | x = np.zeros((len(characters), FEATURE_VECTOR_LENGTH))  # one-hot encode each character 129 | x[np.arange(0, len(characters)), features] = 1 130 | x = np.repeat(x, 2, axis=0)[1:-1]  # duplicate rows so adjacent characters can be paired 131 | x = x.flatten().reshape((len(separators), FEATURE_VECTOR_LENGTH*2))  # one node per adjacent character pair 132 | 133 | y = np.zeros((len(separators), len(LABEL_MAP))) 134 | y[np.arange(0, len(separators)), labels] = 1 135 | 136 | edge_index = np.expand_dims(np.arange(0, x.shape[0]-1), axis=-1)  # connect consecutive nodes in a chain 137 | edge_index = np.concatenate([ 138 | edge_index, 139 | edge_index+1, 140 | ], axis=1) 141 | 142 | x = torch.from_numpy(x).float() 143 | y = torch.from_numpy(y).float() 144 | edge_index = torch.from_numpy(edge_index).t().contiguous() 145 | 146 | graph = Data( 147 | x=x, 148 | y=y, 149 | edge_index=edge_index, 150 | characters=characters, 151 | separators=separators, 152 | original=original, 153 | ) 154 | graph = T.ToUndirected()(graph) 155 | 156 | return graph 157 | 158 | 159 | def post_process( 160 | y_pred: torch.Tensor, 161 | characters: np.ndarray, 162 | ) -> str: 163 | """ Convert a single model prediction to output text. 164 | 165 | Args: 166 | --- 167 | - `y_pred`: torch.Tensor 168 | The predicted tensor of shape (num_nodes,) produced by the model. 169 | - `characters`: np.ndarray 170 | A numpy array containing all the characters in the input text. 171 | 172 | Returns: 173 | str 174 | The output segmented text. 175 | """ 176 | y_pred = list(np.array(LABEL_MAP)[y_pred])+[None] 177 | out = "".join(list(itertools.chain( 178 | *zip(characters, y_pred)))[:-1]) 179 | out = out.replace(NON_SPACE_SEP, '') 180 | out = out.replace(SPACE_SEP, '\u0020') 181 | return out 182 | 183 | 184 | def pretty_training_log(log: Dict[str, Any]): 185 | """ Produce a pretty training log string. 186 | 187 | Args: 188 | --- 189 | - `log`: Dict[str, Any] 190 | The training log produced during each epoch of training. 191 | 192 | Returns: 193 | --- 194 | str 195 | String of `log` separated by commas. 196 | """ 197 | output = [] 198 | for k, v in log.items(): 199 | output.append(f"{k}: {v}") 200 | 201 | return ", ".join(output) 202 | 203 | 204 | def read_config_file(filepath: str) -> Dict[str, Any]: 205 | """ Read the config file. 206 | 207 | Args: 208 | --- 209 | - `filepath`: str 210 | Path to the config file. 211 | 212 | Returns: 213 | --- 214 | Dict[str, Any] 215 | Dictionary representing the configuration file. 
216 | """ 217 | with open(filepath, "r") as file: 218 | return json.load(file) 219 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import torch 5 | import tempfile 6 | from tqdm import tqdm 7 | from typing import List, Dict 8 | from torch.nn import ModuleList 9 | import torch.nn.functional as F 10 | from torch_geometric.data import Data 11 | from torchmetrics.text import WordErrorRate 12 | from torch_geometric.loader import DataLoader 13 | from torch_geometric.nn import GATConv, GCNConv 14 | from torchmetrics.classification import BinaryF1Score 15 | from kornia.losses import binary_focal_loss_with_logits 16 | # 17 | from utils import pretty_training_log, text_to_graph, post_process, read_config_file, FEATURE_VECTOR_LENGTH 18 | 19 | 20 | class JAWSNetwork(torch.nn.Module): 21 | def __init__( 22 | self, 23 | num_features: int, 24 | hidden_dims: List[int], 25 | type: str = "gcn", 26 | ): 27 | super().__init__() 28 | 29 | conv = { 30 | "gcn": GCNConv, 31 | "gat": GATConv, 32 | }[type] 33 | all_dims = [num_features] + hidden_dims 34 | self.layers = ModuleList([ 35 | conv(all_dims[i], all_dims[i+1]) 36 | for i in range(len(all_dims)-1) 37 | ]) 38 | self.predictor = conv(sum(hidden_dims), 2) 39 | 40 | def forward(self, data): 41 | h0, edge_index = data.x, data.edge_index 42 | # 43 | out = [h0] 44 | for i, layer in enumerate(self.layers): 45 | h = layer(out[i], edge_index).relu() 46 | out.append(h) 47 | 48 | y = self.predictor(torch.concat(out[1:], dim=-1), edge_index) 49 | out = F.softmax(y, dim=1) 50 | return out 51 | 52 | 53 | class JAWSModel: 54 | def __init__(self, config_file: str) -> None: 55 | config = read_config_file(config_file) 56 | self.device = torch.device( 57 | "cuda" if torch.cuda.is_available() else "cpu") 58 | self.model = JAWSNetwork( 59 | type=config["type"], 60 | num_features=FEATURE_VECTOR_LENGTH*2, 61 | hidden_dims=config["hidden_dims"], 62 | ).to(self.device) 63 | 64 | def fit( 65 | self, 66 | data, 67 | val_data=None, 68 | batch_size=32, 69 | learning_rate=0.01, 70 | mode: str = "min", 71 | epochs: int = None, 72 | monitor: str = "loss", 73 | logging_interval: int = 1, 74 | model_temp_dir: str = None, 75 | restore_best_weights: int = True, 76 | early_stopping_patience: int = 10, 77 | pretrained_weights_path=None, 78 | ): 79 | self.temp_weights_path = os.path.join( 80 | tempfile.gettempdir() if model_temp_dir is None else model_temp_dir, 81 | "best_weights.pt" 82 | ) 83 | self.f1_score = BinaryF1Score().to(self.device) 84 | self.wer = WordErrorRate().to(self.device) 85 | 86 | self.optimizer = torch.optim.Adam( 87 | self.model.parameters(), 88 | lr=learning_rate 89 | ) 90 | 91 | has_val_set = val_data is not None 92 | train_loader = DataLoader(data, batch_size=batch_size, shuffle=True) 93 | val_loader = None 94 | if has_val_set: 95 | val_loader = DataLoader( 96 | val_data, 97 | shuffle=False, 98 | batch_size=batch_size, 99 | ) 100 | 101 | self.best_epoch, best_metric, epoch = 0, None, 0 102 | 103 | if pretrained_weights_path is not None: 104 | print( 105 | f"Loading pretrained weights from: {pretrained_weights_path}") 106 | self.model.load_state_dict(torch.load( 107 | pretrained_weights_path, map_location=self.device)) 108 | 109 | while True: 110 | 111 | # check if maximum epoch reached or early stopping criterion reached 112 | reached_max_epochs = epochs is not None and epoch >= epochs 113 | 
exceeded_early_stop_patience = ( 114 | epoch - self.best_epoch) > early_stopping_patience if early_stopping_patience is not None else False 115 | 116 | if reached_max_epochs or exceeded_early_stop_patience: 117 | break 118 | 119 | self.model.train(True) 120 | 121 | epoch_result = self.__train(train_loader, use_grad=True) 122 | 123 | if has_val_set: 124 | self.model.train(False) 125 | epoch_result_val = self.__train(val_loader, use_grad=False) 126 | epoch_result = { 127 | **epoch_result, 128 | **{f"val_{k}": v for k, v in epoch_result_val.items()} 129 | } 130 | 131 | curr_metric = epoch_result[monitor] 132 | 133 | has_improved = True 134 | if best_metric is not None: 135 | if mode == "min": 136 | has_improved = curr_metric < best_metric 137 | elif mode == "max": 138 | has_improved = curr_metric > best_metric 139 | 140 | if has_improved: 141 | best_metric = curr_metric 142 | self.best_epoch = epoch 143 | torch.save(self.model.state_dict(), self.temp_weights_path) 144 | 145 | training_log = { 146 | "epoch": epoch, 147 | "best_epoch": self.best_epoch, 148 | **epoch_result 149 | } 150 | 151 | if epoch % logging_interval == 0: 152 | pretty_log = pretty_training_log(training_log) 153 | print(pretty_log) 154 | 155 | epoch += 1 156 | 157 | if restore_best_weights: 158 | self.model.load_state_dict(torch.load(self.temp_weights_path)) 159 | 160 | def save(self, output_path: str): 161 | """ Save the model weights to `output_path`. 162 | 163 | Args: 164 | --- 165 | - output_path: str - The path to the output file. 166 | """ 167 | torch.save(self.model.state_dict(), output_path) 168 | 169 | def load(self, model_path: str): 170 | """ Load the weights from `model_path`. 171 | 172 | Args: 173 | --- 174 | - model_path: str 175 | The path to the model file. 176 | """ 177 | self.model.load_state_dict(torch.load( 178 | model_path, map_location=torch.device('cpu'))) 179 | 180 | def segment(self, text: str) -> str: 181 | """ Segment the `text` into words. 182 | 183 | Args: 184 | --- 185 | - text: str - The input text. 186 | 187 | Returns: 188 | --- 189 | output segmented text. 
190 | """ 191 | data = text_to_graph(text) 192 | o = self.model(data) 193 | y_pred = o.argmax(dim=1).numpy() 194 | out = post_process(y_pred, data.characters) 195 | out = re.sub('\u0020+', '\u0020', out) 196 | return out 197 | 198 | def predict_sample(self, sample: Data) -> torch.Tensor: 199 | o = self.model(sample) 200 | return o.argmax(dim=1) 201 | 202 | def __train(self, dataloader: DataLoader, use_grad=False) -> Dict[str, float]: 203 | total_loss, total_f1, total_wer, n = 0, 0, 0, 0 204 | 205 | for data in tqdm(dataloader): 206 | data = data.to(self.device) 207 | if use_grad: 208 | self.optimizer.zero_grad() 209 | 210 | out = self.model(data) 211 | text = self.__prediction_to_string( 212 | out.argmax(dim=-1), 213 | data.batch, 214 | data.characters, 215 | ) 216 | loss = binary_focal_loss_with_logits( 217 | out, data.y, alpha=0.75, gamma=2.0, reduction='mean') 218 | 219 | f1 = self.f1_score(out.argmax(dim=-1), data.y.argmax(dim=-1)) 220 | wer = self.wer(text, data.original) 221 | 222 | total_loss += loss 223 | total_f1 += f1 224 | total_wer += wer 225 | n += 1 226 | 227 | if use_grad: 228 | loss.backward() 229 | self.optimizer.step() 230 | 231 | return { 232 | "loss": float(total_loss/n), 233 | "f1": float(total_f1/n), 234 | "wer": float(total_wer/n), 235 | } 236 | 237 | def __prediction_to_string(self, y_preds, batch, characters): 238 | out = [] 239 | for i in range(len(characters)): 240 | text = post_process( 241 | y_preds[batch == i].cpu(), 242 | characters[i] 243 | ) 244 | out.append(text) 245 | return out 246 | --------------------------------------------------------------------------------
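Finally, the model can also be driven directly from Python. The snippet below is a minimal sketch that mirrors what `segment.py` does internally, assuming the bundled GAT config and pretrained weights; `raw` is a placeholder for the unsegmented Khmer text:

```python
# Minimal sketch: load the pretrained GAT model and segment one string in-process.
from utils import preprocess_phylypo_sample
from networks import JAWSModel

model = JAWSModel("configs/config.gat.json")
model.load("pretrained/khpos.gat.pt")

raw = "..."  # placeholder: replace with the Khmer text to segment
preprocessed, _ = preprocess_phylypo_sample(raw)
print(model.segment(preprocessed))
```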