├── deepspain ├── __init__.py ├── utils.py ├── model.py ├── embeddings.py ├── search.py ├── dataset.py └── tensorboard.py ├── logs └── .gitignore ├── models ├── large_accuracy.metric ├── encoder_large_finetuned.pth.dvc ├── learner_large_finetuned.pkl.dvc ├── model_large_finetuned.pth.dvc └── .gitignore ├── pretrained ├── .gitignore ├── itos.pkl.dvc └── encoder.pth.dvc ├── pylama.ini ├── data ├── .gitignore ├── empty_data.dvc ├── lm_data.pkl.dvc ├── test.jsonlines.dvc ├── output.jsonlines.dvc └── sample_databunch.pkl.dvc ├── .dvc ├── config └── .gitignore ├── .gitignore ├── .vscode └── settings.json ├── environment.yaml ├── train_small.sh ├── train_medium.sh ├── train_large.sh ├── search.py ├── make_dataset.py ├── Makefile ├── README.md ├── train_small.dvc ├── process_dataset.py ├── train_large.dvc ├── evaluate.py ├── index_documents.py └── train.py /deepspain/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /logs/.gitignore: -------------------------------------------------------------------------------- 1 | /large 2 | -------------------------------------------------------------------------------- /models/large_accuracy.metric: -------------------------------------------------------------------------------- 1 | 0.6337456107139587 -------------------------------------------------------------------------------- /pretrained/.gitignore: -------------------------------------------------------------------------------- 1 | /encoder.pth 2 | /itos.pkl 3 | -------------------------------------------------------------------------------- /pylama.ini: -------------------------------------------------------------------------------- 1 | [pylama:pycodestyle] 2 | max_line_length = 120 3 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | /output.jsonlines 2 | /sample_databunch.pkl 3 | /lm_data.pkl 4 | /test.jsonlines 5 | /empty_data 6 | -------------------------------------------------------------------------------- /.dvc/config: -------------------------------------------------------------------------------- 1 | ['remote "s3remote"'] 2 | url = s3://datascience.codegram.com/deepspain 3 | [core] 4 | remote = s3remote 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | *.zip 6 | .cache 7 | tmp/ 8 | -------------------------------------------------------------------------------- /.dvc/.gitignore: -------------------------------------------------------------------------------- 1 | /lock 2 | /config.local 3 | /updater 4 | /updater.lock 5 | /state-journal 6 | /state-wal 7 | /state 8 | /cache 9 | -------------------------------------------------------------------------------- /data/empty_data.dvc: -------------------------------------------------------------------------------- 1 | md5: 7ccdca83442c6becfb42f1294f73f0b8 2 | outs: 3 | - md5: 01890703a4546f2ff59f30cb2717f0fc 4 | path: empty_data 5 | cache: true 6 | metric: false 7 | persist: false 8 | -------------------------------------------------------------------------------- /data/lm_data.pkl.dvc: -------------------------------------------------------------------------------- 1 | md5: 
a50d893758edbe410f87f260a49f3113 2 | outs: 3 | - md5: e80b9f35bff0c31f18359df8b251a375 4 | path: lm_data.pkl 5 | cache: true 6 | metric: false 7 | persist: false 8 | -------------------------------------------------------------------------------- /pretrained/itos.pkl.dvc: -------------------------------------------------------------------------------- 1 | md5: f06fc9c00f9f2e06bbf6baec0f2de0a6 2 | outs: 3 | - md5: 1d63af1f87ac9de8ab6a24e9c3287620 4 | path: itos.pkl 5 | cache: true 6 | metric: false 7 | persist: false 8 | -------------------------------------------------------------------------------- /data/test.jsonlines.dvc: -------------------------------------------------------------------------------- 1 | md5: b1026cce241ece1d060d01d6192f4b34 2 | outs: 3 | - md5: 8fa39771eff483bcc7c1f4ecfda560cd 4 | path: test.jsonlines 5 | cache: true 6 | metric: false 7 | persist: false 8 | -------------------------------------------------------------------------------- /pretrained/encoder.pth.dvc: -------------------------------------------------------------------------------- 1 | md5: 7035db78ca469c3d68210160b9e2b41d 2 | outs: 3 | - md5: 273f2bf368876aad4761e5bc14d8353e 4 | path: encoder.pth 5 | cache: true 6 | metric: false 7 | persist: false 8 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "/usr/local/anaconda3/envs/deepspain/bin/python", 3 | "python.dataScience.jupyterServerURI": "local", 4 | "python.formatting.provider": "black" 5 | } -------------------------------------------------------------------------------- /data/output.jsonlines.dvc: -------------------------------------------------------------------------------- 1 | md5: 518507968d1a5f10ed4c897bb81a8cea 2 | outs: 3 | - md5: 717786c5cc2940a4f9aec3adf94510be 4 | path: output.jsonlines 5 | cache: true 6 | metric: false 7 | persist: false 8 | -------------------------------------------------------------------------------- /data/sample_databunch.pkl.dvc: -------------------------------------------------------------------------------- 1 | md5: f0471363231c6f90acd0b01558d2af13 2 | outs: 3 | - md5: 3e528b683178549881f98348d31a827f 4 | path: sample_databunch.pkl 5 | cache: true 6 | metric: false 7 | persist: false 8 | -------------------------------------------------------------------------------- /models/encoder_large_finetuned.pth.dvc: -------------------------------------------------------------------------------- 1 | md5: 069855373df0f223ac9eba07c292d25c 2 | outs: 3 | - md5: 7df8a732818bdfa465cb7281d2ac0a5b 4 | path: encoder_large_finetuned.pth 5 | cache: true 6 | metric: false 7 | persist: false 8 | -------------------------------------------------------------------------------- /models/learner_large_finetuned.pkl.dvc: -------------------------------------------------------------------------------- 1 | md5: 9816f2b4c8d9487634205355257daebe 2 | outs: 3 | - md5: 7d7be3c937df550737ae28ef135ad234 4 | path: learner_large_finetuned.pkl 5 | cache: true 6 | metric: false 7 | persist: false 8 | -------------------------------------------------------------------------------- /models/model_large_finetuned.pth.dvc: -------------------------------------------------------------------------------- 1 | md5: 57209836af1d48352e597d85ebfa1712 2 | outs: 3 | - md5: 81248de751c11640865df85b75738e19 4 | path: model_large_finetuned.pth 5 | cache: true 6 | metric: false 7 | persist: false 8 | 
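The `.dvc` files above are only pointers: the actual encoder, model and learner artifacts live in the S3 remote declared in `.dvc/config` and are fetched with `dvc pull`. Once they are present locally, loading the fine-tuned encoder for inference looks roughly like this (a minimal sketch mirroring `search.py` and `index_documents.py`; running from the repository root with `Path(".")` as the models path is an assumption):

    from pathlib import Path

    from deepspain.embeddings import doc2vec
    from deepspain.model import from_encoder

    # Assumes `dvc pull` has already materialised data/empty_data and the
    # models/encoder_large_finetuned.pth weights referenced above.
    learner = from_encoder(Path("."), encoder_name="encoder_large_finetuned")

    # Turn an arbitrary document into fixed-size embedding chunks (see deepspain/embeddings.py).
    vectors = doc2vec(learner, "Real Decreto por el que se regula...", debug=True)
    print([v.shape for v in vectors])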
-------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: deepspain 2 | channels: 3 | - defaults 4 | dependencies: 5 | - pip: 6 | - click 7 | - aiohttp 8 | - black 9 | - dvc 10 | - boto3 11 | - tensorflow 12 | - tensorboardX 13 | - elasticsearch==7.0.4 14 | - fastai==1.0.57 15 | - torch==1.2.0 16 | - torchvision==0.4.0 17 | -------------------------------------------------------------------------------- /models/.gitignore: -------------------------------------------------------------------------------- 1 | /empty_data 2 | /encoder_head.pth 3 | /model_head.pth 4 | /learner_head.pkl 5 | /small_empty_data 6 | /encoder_small_head.pth 7 | /model_small_head.pth 8 | /learner_small_head.pkl 9 | /bestmodel* 10 | /large_empty_data 11 | /encoder_large_head.pth 12 | /model_large_head.pth 13 | /learner_large_head.pkl 14 | /encoder_large_finetuned.pth 15 | /learner_large_finetuned.pkl 16 | /model_large_finetuned.pth 17 | -------------------------------------------------------------------------------- /deepspain/utils.py: -------------------------------------------------------------------------------- 1 | from typing import TypeVar, Callable 2 | import time 3 | 4 | A = TypeVar("A") 5 | 6 | 7 | def measure(label: str, f: Callable[[], A], debug: bool = False) -> A: 8 | if debug: 9 | t = time.process_time() 10 | x = f() 11 | elapsed_time = time.process_time() - t 12 | print(label + " took " + str(elapsed_time)) 13 | return x 14 | else: 15 | return f() 16 | -------------------------------------------------------------------------------- /train_small.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_count=$(nvidia-smi --query-gpu=gpu_name --format=csv | grep '[^name]' | wc -l) 4 | dvc run -f train_small.dvc \ 5 | -d train.py -d data/sample_databunch.pkl \ 6 | -d pretrained/encoder.pth \ 7 | -d pretrained/itos.pkl \ 8 | -o models/small_empty_data \ 9 | -o models/encoder_small_head.pth \ 10 | -o models/model_small_head.pth \ 11 | -o models/learner_small_head.pkl \ 12 | -M models/small_accuracy.metric \ 13 | python3 -m torch.distributed.launch \ 14 | --nproc_per_node=$gpu_count \ 15 | train.py \ 16 | data/sample_databunch.pkl \ 17 | models/ \ 18 | pretrained/encoder.pth \ 19 | pretrained/itos.pkl\ 20 | --label small \ 21 | --head-epochs 2 \ 22 | --gpus $gpu_count \ 23 | --head-only 24 | -------------------------------------------------------------------------------- /train_medium.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_count=$(nvidia-smi --query-gpu=gpu_name --format=csv | grep '[^name]' | wc -l) 4 | dvc run -f train_medium.dvc \ 5 | -d train.py -d data/lm_data.pkl \ 6 | -d pretrained/encoder.pth \ 7 | -d pretrained/itos.pkl \ 8 | -o models/medium_empty_data \ 9 | -o models/encoder_medium_head.pth \ 10 | -o models/model_medium_head.pth \ 11 | -o models/learner_medium_head.pkl \ 12 | -o models/encoder_medium_finetuned.pth \ 13 | -o models/model_medium_finetuned.pth \ 14 | -o models/learner_medium_finetuned.pkl \ 15 | -M models/medium_accuracy.metric \ 16 | python3 -m torch.distributed.launch \ 17 | --nproc_per_node=$gpu_count \ 18 | train.py \ 19 | data/lm_data.pkl \ 20 | models/ \ 21 | pretrained/encoder.pth \ 22 | pretrained/itos.pkl\ 23 | --label medium \ 24 | --head-epochs 1 \ 25 | --backbone-epochs 1 \ 26 | --gpus $gpu_count 27 | 
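The `measure` helper defined in `deepspain/utils.py` above is what the loading, indexing and search code later in the repo uses to time individual steps when `--debug` is passed. Note that it relies on `time.process_time`, so it reports CPU time rather than wall-clock time. A small usage sketch:

    import time

    from deepspain.utils import measure

    # Only times (and prints) the call when debug=True; otherwise it just runs the thunk.
    result = measure("sleeping", lambda: time.sleep(1) or 42, debug=True)
    print(result)  # 42; the printed time is near zero because sleep is not CPU work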
-------------------------------------------------------------------------------- /train_large.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_count=$(nvidia-smi --query-gpu=gpu_name --format=csv | grep '[^name]' | wc -l) 4 | dvc run -f train_large.dvc \ 5 | -d train.py -d data/lm_data.pkl \ 6 | -d pretrained/encoder.pth \ 7 | -d pretrained/itos.pkl \ 8 | -o models/large_empty_data \ 9 | -o models/encoder_large_head.pth \ 10 | -o models/model_large_head.pth \ 11 | -o models/learner_large_head.pkl \ 12 | -o models/encoder_large_finetuned.pth \ 13 | -o models/model_large_finetuned.pth \ 14 | -o models/learner_large_finetuned.pkl \ 15 | -o logs/large \ 16 | -M models/large_accuracy.metric \ 17 | python3 -m torch.distributed.launch \ 18 | --nproc_per_node=$gpu_count \ 19 | train.py \ 20 | data/lm_data.pkl \ 21 | models/ \ 22 | pretrained/encoder.pth \ 23 | pretrained/itos.pkl\ 24 | --label large \ 25 | --head-epochs 4 \ 26 | --backbone-epochs 2 \ 27 | --gpus $gpu_count 28 | -------------------------------------------------------------------------------- /search.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | 4 | from elasticsearch import Elasticsearch 5 | from fire import Fire 6 | 7 | from deepspain.utils import measure 8 | from deepspain.model import from_encoder 9 | from deepspain.search import search 10 | 11 | 12 | def main( 13 | models_path: Path, 14 | query: str, 15 | index_name: str = "boe", 16 | host: str = "localhost", 17 | port: int = 9200, 18 | debug: bool = False, 19 | ): 20 | with warnings.catch_warnings(): 21 | warnings.filterwarnings("ignore", category=UserWarning) 22 | learner = measure( 23 | "model loading", 24 | lambda: from_encoder(models_path, encoder_name="encoder_large_finetuned"), 25 | debug, 26 | ) 27 | es = Elasticsearch(hosts=[{"host": host, "port": port}]) 28 | for result in search(es, learner, index_name, query, debug): 29 | print("\n") 30 | print(result) 31 | 32 | 33 | if __name__ == "__main__": 34 | Fire(main) 35 | -------------------------------------------------------------------------------- /make_dataset.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from pathlib import Path 3 | 4 | import click 5 | 6 | from deepspain.dataset import download, generate_ids 7 | 8 | 9 | @click.command() 10 | @click.argument("from_boe_id", metavar="") 11 | @click.argument( 12 | "output_file", metavar="", type=click.Path(dir_okay=False) 13 | ) 14 | def main(from_boe_id: str, output_file: str): 15 | """Fetches all BOEs starting at until yesterday's, and stores them in . 
16 | Running multiple times will append to that file.""" 17 | boe_ids = list(generate_ids(from_boe_id)) 18 | if Path(output_file).exists(): 19 | click.echo(output_file + " already exists, so we'll be adding to it.") 20 | click.echo("Fetching " + str(len(boe_ids)) + " since " + from_boe_id + "✨") 21 | loop = asyncio.get_event_loop() 22 | future = asyncio.ensure_future(download(boe_ids, output_file)) 23 | loop.run_until_complete(future) 24 | click.echo("Fetched " + str(len(boe_ids)) + " items.") 25 | 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /deepspain/model.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from fastai.text import ( 4 | LanguageLearner, 5 | TransformerXL, 6 | language_model_learner, 7 | DataBunch, 8 | TextList, 9 | ) 10 | 11 | 12 | def from_encoder(model_path: Path, encoder_name: str) -> LanguageLearner: 13 | """Loads a trained language model for inference.""" 14 | print("Loading model for inference....") 15 | data = DataBunch.load_empty(model_path, "data/empty_data") 16 | learn = language_model_learner(data, TransformerXL, pretrained=False) 17 | learn.load_encoder(encoder_name) 18 | learn.freeze() 19 | learn.model.eval() 20 | return learn 21 | 22 | 23 | def from_model(model_path: Path, model_name: str) -> LanguageLearner: 24 | """Loads a trained language model for inference.""" 25 | print("Loading model for inference....") 26 | data = DataBunch.load_empty(model_path, "data/empty_data") 27 | learn = language_model_learner(data, TransformerXL, pretrained=False) 28 | learn.load(model_name) 29 | learn.freeze() 30 | learn.model.eval() 31 | return learn 32 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | deps: environment.yaml 2 | conda env create -f environment.yaml 3 | conda activate deepspain 4 | 5 | lint: 6 | python -m pylama 7 | 8 | /storage/boe: 9 | mkdir -p /storage/boe 10 | 11 | /storage/boe/lm_data.pkl: /storage/boe 12 | curl "https://s3-eu-west-1.amazonaws.com/datascience.codegram.com/lm_data.pkl" > /storage/boe/lm_data.pkl 13 | 14 | /storage/boe/itos.pkl: /storage/boe 15 | curl "https://s3-eu-west-1.amazonaws.com/datascience.codegram.com/itos.pkl" > /storage/boe/itos.pkl 16 | 17 | /storage/boe/encM2.pth: /storage/boe 18 | curl "https://s3-eu-west-1.amazonaws.com/datascience.codegram.com/encM2.pth" > /storage/boe/encM2.pth 19 | 20 | databunch: /storage/boe/lm_data.pkl 21 | pretrained_model: /storage/boe/itos.pkl /storage/boe/encM2.pth 22 | 23 | gputrain: deps databunch pretrained_model 24 | bash scripts/gputrain.sh 25 | 26 | train: 27 | gradient experiments run singlenode \ 28 | --name boe_language_model \ 29 | --projectId przjwc38i \ 30 | --container paperspace/fastai:1.0-CUDA9.2-base-3.0-v1.0.6 \ 31 | --machineType P5000 \ 32 | --command 'make gputrain' 33 | 34 | .PHONY: lint deps 35 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepSpain 2 | 3 | A fine-tuned language model on the Spanish BOE (Official State Bulletin), using Fast.ai. 4 | 5 | It uses a [pre-trained Spanish TransformerXL language model](https://github.com/mmcctt00/SpanishTransformerXL). 
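One wrinkle worth knowing about: the pretrained checkpoint appears to use the older fastai special-token names (`_unk_`, `_pad_`, `xbos`, `xfld`), while fastai 1.0.x databunches use `xxunk`, `xxpad`, `xxbos` and `xxfld`, so `train.py` rewrites the databunch vocabulary before loading the pretrained weights. A condensed sketch of that remapping (the `remap_special_tokens` name is hypothetical; `train.py` additionally runs a `clean` pass over the remaining tokens):

    def remap_special_tokens(itos):
        """Swap fastai's current special tokens for the names the pretrained itos expects."""
        pretrained_special = ["_unk_", "_pad_", "xbos", "xfld"]
        actual_special = ["xxunk", "xxpad", "xxbos", "xxfld"]
        return [
            pretrained_special[actual_special.index(word)] if word in actual_special else word
            for word in itos
        ]

    # e.g. data.vocab.itos = remap_special_tokens(data.vocab.itos) before language_model_learner(...)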
6 | 7 | The idea is to facilitate: 8 | 9 | - Training classifiers 10 | - Similarity search between items via their word embeddings 11 | - Semantic search 12 | 13 | ## Setup 14 | 15 | Make sure you have Python 3.7 and access to a powerful GPU. Language models need a lot of memory and take a long time to train. 16 | 17 | make deps 18 | 19 | ## Creating the BOE dataset from scratch 20 | 21 | Pick a starting date from which you wish to fetch the BOE documents, up until yesterday. For example, to start at 2018-11-05: 22 | 23 | python make_dataset.py BOE-S-20181105 output.jsonlines 24 | 25 | ## Processing the data for language modeling 26 | 27 | python process_dataset.py output.jsonlines lm_data.pkl 28 | 29 | This will output a Pickle file called `lm_data.pkl`. 30 | 31 | ## Training the model 32 | 33 | bash train_small.sh 34 | bash train_large.sh 35 | 36 | `train_small.sh` runs a quick sanity check against the sampled databunch (`data/sample_databunch.pkl`), while `train_medium.sh` and `train_large.sh` fine-tune on the full `lm_data.pkl`. All of them wrap `train.py` (through `dvc run` and `torch.distributed.launch`) and save the trained model, encoder and a pickled Language Learner under `models/` (for example `models/learner_large_finetuned.pkl`), which you can use to encode strings, extract their embeddings, etcetera. -------------------------------------------------------------------------------- /deepspain/embeddings.py: -------------------------------------------------------------------------------- 1 | """Use a Language Model to extract embeddings for text documents.""" 2 | from fastai.text import LanguageLearner 3 | from deepspain.utils import measure 4 | import torch 5 | from torch import Tensor 6 | from typing import Sequence, cast, Any 7 | 8 | 9 | def word_embeddings(learner: LanguageLearner, s: str, debug: bool = False) -> Tensor: 10 | tokens, _ = measure("tokenizing", lambda: learner.data.one_item(s), debug) 11 | measure("resetting model", lambda: learner.model.reset(), debug) 12 | encoder = learner.model[0] 13 | outputs = measure("predicting", lambda: encoder(tokens), debug) 14 | embeddings = outputs[-1][-1] 15 | return embeddings 16 | 17 | 18 | def doc2vec( 19 | learner: LanguageLearner, s: str, debug: bool = False, max_dim: int = 1024 20 | ) -> Sequence[Tensor]: 21 | with torch.no_grad(): 22 | embeddings = measure( 23 | "get_full_embeddings", lambda: word_embeddings(learner, s, debug), debug 24 | ) 25 | avg_pool = embeddings.mean(dim=1) 26 | max_pool = embeddings.max(dim=1)[0] 27 | last = cast(Tensor, cast(Any, embeddings)[:, -1]) # workaround pyright issue 28 | return ( 29 | torch.cat([last, max_pool, avg_pool], 1).to("cpu").squeeze().split(max_dim) 30 | ) -------------------------------------------------------------------------------- /train_small.dvc: -------------------------------------------------------------------------------- 1 | md5: f4ca0d7d4510d54f0c89b5a72f5db4fe 2 | cmd: python3 -m torch.distributed.launch --nproc_per_node=2 train.py data/sample_databunch.pkl 3 | models/ pretrained/encoder.pth pretrained/itos.pkl --label small --head-epochs 2 4 | --gpus 2 --head-only 5 | deps: 6 | - md5: d6ad557dff657ea570a422e92464a8fb 7 | path: train.py 8 | - md5: 3e528b683178549881f98348d31a827f 9 | path: data/sample_databunch.pkl 10 | - md5: 273f2bf368876aad4761e5bc14d8353e 11 | path: pretrained/encoder.pth 12 | - md5: 1d63af1f87ac9de8ab6a24e9c3287620 13 | path: pretrained/itos.pkl 14 | outs: 15 | - md5: eb67885984a9bdd5ced985bb70b779eb 16 | path: models/small_empty_data 17 | cache: true 18 | metric: false 19 | persist: false 20 | - md5: ecf13b4f32dfd5b608d448ddf4257179 21 | path: models/encoder_small_head.pth 22 | cache: true 23 | metric: false 24 | persist: false 25 | - md5: 444498742e7d6dd3d4f36fa0803ceb01 26 | path: models/model_small_head.pth 27 | cache: true 28 | metric: false 29 |
persist: false 30 | - md5: 9c855383005a7317fb303f18a3ca507f 31 | path: models/learner_small_head.pkl 32 | cache: true 33 | metric: false 34 | persist: false 35 | - md5: 47d0fe7ad9c008045617fe43d87415bd 36 | path: models/small_accuracy.metric 37 | cache: false 38 | metric: true 39 | persist: false 40 | -------------------------------------------------------------------------------- /process_dataset.py: -------------------------------------------------------------------------------- 1 | import click 2 | import jsonlines 3 | import pandas as pd 4 | 5 | from deepspain.dataset import df_to_lm_databunch 6 | 7 | 8 | @click.command() 9 | @click.argument( 10 | "jsonlines_path", 11 | metavar="", 12 | type=click.Path(exists=True, dir_okay=False), 13 | ) 14 | @click.argument( 15 | "output_file", 16 | metavar="", 17 | type=click.Path(exists=False, dir_okay=False), 18 | ) 19 | @click.option( 20 | "--sample", is_flag=True, default=False, help="Use only 1% of the data, to test." 21 | ) 22 | def main(dataset_path: str, output_file: str, sample: bool): 23 | """Turn the data in into a DataBunch suitable for Language Modeling, 24 | saving it to .""" 25 | 26 | click.echo("Turning raw data into a Databunch suitable for language modeling") 27 | rows = [] 28 | with jsonlines.open(dataset_path) as reader: 29 | for obj in reader.iter(type=dict, skip_invalid=True): 30 | rows.append(obj) 31 | df = pd.DataFrame(rows) 32 | click.echo("Created dataframe with shape " + str(df.shape)) 33 | databunch = df_to_lm_databunch(df, columns=["title", "content"], sample=sample) 34 | databunch.save(output_file) 35 | 36 | 37 | if __name__ == "__main__": 38 | main() 39 | -------------------------------------------------------------------------------- /train_large.dvc: -------------------------------------------------------------------------------- 1 | md5: 756ad25463746a7e572a085dd318eb5a 2 | cmd: python3 -m torch.distributed.launch --nproc_per_node=8 train.py data/lm_data.pkl 3 | models/ pretrained/encoder.pth pretrained/itos.pkl --label large --head-epochs 4 4 | --backbone-epochs 2 --gpus 8 5 | deps: 6 | - md5: 652acc7954ade70cc40c47bcde519eb4 7 | path: train.py 8 | - md5: e80b9f35bff0c31f18359df8b251a375 9 | path: data/lm_data.pkl 10 | - md5: 273f2bf368876aad4761e5bc14d8353e 11 | path: pretrained/encoder.pth 12 | - md5: 1d63af1f87ac9de8ab6a24e9c3287620 13 | path: pretrained/itos.pkl 14 | outs: 15 | - md5: 9063d4822e223dc5dcee374d1413c44a 16 | path: models/large_empty_data 17 | cache: true 18 | metric: false 19 | persist: false 20 | - md5: 4506432257b23107d12825cfb5f4f176 21 | path: models/encoder_large_head.pth 22 | cache: true 23 | metric: false 24 | persist: false 25 | - md5: 444145df2ec97f6fe53fb73f106c28d5 26 | path: models/model_large_head.pth 27 | cache: true 28 | metric: false 29 | persist: false 30 | - md5: 893553ed253a3560f2149271ebd6cdc7 31 | path: models/learner_large_head.pkl 32 | cache: true 33 | metric: false 34 | persist: false 35 | - md5: 481be1c9ee04eca46f3dfd397915646c.dir 36 | path: logs/large 37 | cache: true 38 | metric: false 39 | persist: false 40 | - md5: e41c5f245b5f5822a5f9f01b9c522013 41 | path: models/large_accuracy.metric 42 | cache: false 43 | metric: true 44 | persist: false 45 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | import click 4 | import jsonlines 5 | import pandas as pd 6 | 7 | from fastai.text import TextList 
8 | 9 | from deepspain.utils import measure 10 | from deepspain.model import from_model 11 | 12 | 13 | @click.command() 14 | @click.argument( 15 | "models_path", type=click.Path(exists=True, file_okay=False), metavar="" 16 | ) 17 | @click.argument( 18 | "test_data_json", 19 | type=click.Path(exists=True, dir_okay=False), 20 | metavar="", 21 | ) 22 | @click.option("--debug", is_flag=True, default=False) 23 | def main(models_path: Path, test_data_json: Path, debug: bool): 24 | """Evaluates a language model against a test data set.""" 25 | 26 | with warnings.catch_warnings(): 27 | warnings.filterwarnings("ignore", category=UserWarning) 28 | 29 | print(f"Loading test data from {test_data_json}...") 30 | rows = [] 31 | with jsonlines.open(test_data_json) as reader: 32 | for obj in reader.iter(type=dict, skip_invalid=True): 33 | rows.append(obj) 34 | df = pd.DataFrame(rows) 35 | test_databunch = ( 36 | TextList.from_df(df, path=models_path, cols=["title", "content"]) 37 | .split_none() 38 | .label_for_lm() 39 | .databunch(bs=4) 40 | ) 41 | 42 | learner = measure( 43 | "model loading", 44 | lambda: from_model(models_path, model_name="model_large_finetuned"), 45 | debug, 46 | ) 47 | 48 | print(learner.validate(dl=test_databunch.train_dl)) 49 | 50 | 51 | if __name__ == "__main__": 52 | main() 53 | -------------------------------------------------------------------------------- /index_documents.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | import click 4 | 5 | from elasticsearch import Elasticsearch 6 | 7 | from deepspain.utils import measure 8 | from deepspain.model import from_encoder 9 | from deepspain.dataset import load_databunch 10 | from deepspain.search import recreate_index, index_document 11 | 12 | 13 | @click.command() 14 | @click.argument( 15 | "models_path", type=click.Path(exists=True, file_okay=False), metavar="" 16 | ) 17 | @click.argument( 18 | "data_path", type=click.Path(exists=True, dir_okay=False), metavar="" 19 | ) 20 | @click.option( 21 | "--drop-index", 22 | is_flag=True, 23 | default=False, 24 | help="Whether to drop and recreate the index before indexing", 25 | ) 26 | @click.option( 27 | "--index_name", 28 | type=str, 29 | default="boe", 30 | help='ElasticSearch index name (default "boe")', 31 | ) 32 | @click.option( 33 | "--host", 34 | type=str, 35 | default="localhost", 36 | help='ElasticSearch hostname (default "localhost")', 37 | ) 38 | @click.option( 39 | "--port", type=int, default=9200, help="ElasticSearch port (default 9200)" 40 | ) 41 | @click.option( 42 | "--limit-bytes", 43 | type=int, 44 | default=5000, 45 | help="The bytes to keep from the text before indexing", 46 | ) 47 | @click.option("--debug", is_flag=True, default=False) 48 | def main( 49 | models_path: Path, 50 | data_path: Path, 51 | drop_index: bool, 52 | index_name: str, 53 | host: str, 54 | port: int, 55 | limit_bytes: int, 56 | debug: bool, 57 | ): 58 | """Index all the training rows in into ElasticSearch.""" 59 | 60 | with warnings.catch_warnings(): 61 | warnings.filterwarnings("ignore", category=UserWarning) 62 | 63 | learner = measure( 64 | "encoder loading", 65 | lambda: from_encoder(models_path, encoder_name="encoder_large_finetuned"), 66 | debug, 67 | ) 68 | 69 | es = Elasticsearch(hosts=[{"host": host, "port": port}]) 70 | if drop_index: 71 | print("Recreating index...") 72 | recreate_index(es, learner, index_name, debug) 73 | print("Loading data...") 74 | df = load_databunch(Path(data_path), 
debug).train_ds.inner_df 75 | total = df.shape[0] 76 | print(f"Indexing {total} rows...") 77 | for idx, row in df.iterrows(): 78 | measure( 79 | f"{idx}/{total}", 80 | lambda: index_document( 81 | es, learner, index_name, row.to_dict(), limit_bytes, debug 82 | ), 83 | debug, 84 | ) 85 | 86 | 87 | if __name__ == "__main__": 88 | main() 89 | -------------------------------------------------------------------------------- /deepspain/search.py: -------------------------------------------------------------------------------- 1 | from elasticsearch import Elasticsearch 2 | from fastai.text import LanguageLearner 3 | 4 | from deepspain.embeddings import doc2vec 5 | from deepspain.utils import measure 6 | 7 | 8 | def recreate_index( 9 | es: Elasticsearch, learn: LanguageLearner, index_name: str, debug=False 10 | ): 11 | if es.indices.exists(index_name): 12 | print("deleting '%s' index..." % (index_name)) 13 | res = es.indices.delete(index=index_name) 14 | print(" response: '%s'" % (res)) 15 | # since we are running locally, use one shard and no replicas 16 | shapes = list(map(lambda x: x.shape[0], doc2vec(learn, "test", debug))) 17 | mappings = { 18 | "embeddings_" + str(idx): {"type": "dense_vector", "dims": dims} 19 | for idx, dims in zip(range(len(shapes)), shapes) 20 | } 21 | request_body = { 22 | "settings": {"number_of_shards": 1, "number_of_replicas": 0}, 23 | "mappings": {"properties": mappings}, 24 | } 25 | print("creating '%s' index..." % (index_name)) 26 | res = es.indices.create(index=index_name, body=request_body) 27 | print(" response: '%s'" % (res)) 28 | 29 | 30 | def index_document( 31 | es: Elasticsearch, 32 | learner: LanguageLearner, 33 | index_name: str, 34 | document: dict, 35 | limit_bytes: int, 36 | debug=False, 37 | ): 38 | content = (document["title"] + "\n" + document["content"])[:limit_bytes] 39 | embeddings = doc2vec(learner, content, debug) 40 | for idx, e in zip(range(len(embeddings)), embeddings): 41 | document["embeddings_" + str(idx)] = e.tolist() 42 | res = es.index(index=index_name, id=document["id"], body=document) 43 | return res 44 | 45 | 46 | def search( 47 | es: Elasticsearch, 48 | learner: LanguageLearner, 49 | index_name: str, 50 | query: str, 51 | debug=False, 52 | ): 53 | embeddings = doc2vec(learner, query, debug) 54 | # embeddings = [embeddings[0]] 55 | indices = range(len(embeddings)) 56 | with_index = zip(indices, embeddings) 57 | params = {"queryVector" + str(idx): e.tolist() for idx, e in with_index} 58 | queries = [ 59 | "cosineSimilarity(params.queryVector" 60 | + str(idx) 61 | + ", doc['embeddings_" 62 | + str(idx) 63 | + "'])" 64 | for idx in indices 65 | ] 66 | q = { 67 | "size": 1, 68 | "query": { 69 | "script_score": { 70 | "query": {"match_all": {}}, 71 | "script": {"source": "+".join(queries) + "+0.0", "params": params}, 72 | } 73 | }, 74 | } 75 | result = measure("search", lambda: es.search(index=index_name, body=q), debug) 76 | return result["hits"]["hits"][0]["_source"]["title"] 77 | -------------------------------------------------------------------------------- /deepspain/dataset.py: -------------------------------------------------------------------------------- 1 | """Handles creation and processing of the BOE dataset from scratch.""" 2 | 3 | import asyncio 4 | import io 5 | import json 6 | import re 7 | import xml.etree.ElementTree as ET 8 | from datetime import date, timedelta 9 | from pathlib import Path 10 | from typing import Iterator, Sequence 11 | 12 | import numpy as np 13 | from aiohttp import ClientSession 14 | from 
fastai.text.data import TextList, TextLMDataBunch, load_data 15 | from pandas import DataFrame 16 | 17 | from deepspain.utils import measure 18 | 19 | BASE_URL = "https://www.boe.es" 20 | 21 | 22 | def _parse_item(item, extra, content): 23 | item_id = item.get("id") 24 | item_control = item.get("control") 25 | title = item.findtext("titulo") 26 | url = item.findtext("urlXml") 27 | 28 | try: 29 | root = ET.fromstring(content) 30 | 31 | metadata = root.find("metadatos") 32 | date = metadata.findtext("fecha_disposicion") 33 | rank = metadata.findtext("rango") 34 | publish_date = metadata.findtext("fecha_publicacion") 35 | texto = root.find("texto") 36 | 37 | content = ( 38 | "\n".join([x.text for x in texto.findall("p") if x.text]) 39 | .replace("\xa0", " ") 40 | .replace("\u2003", " ") 41 | ) 42 | signatures = " ".join( 43 | [x.text for x in texto.findall('p[@class="firma_ministro"]') if x.text] 44 | ) 45 | result = dict( 46 | id=item_id, 47 | rank=rank, 48 | control=item_control, 49 | title=title, 50 | url=url, 51 | content=content, 52 | signatures=signatures, 53 | date=date, 54 | publish_date=publish_date, 55 | **extra 56 | ) 57 | return result 58 | except Exception: 59 | return None 60 | 61 | 62 | async def _fetch_item(item, session, metadata): 63 | url = BASE_URL + item.findtext("urlXml") 64 | res = await session.request(method="GET", url=url) 65 | text = await res.text() 66 | parsed = _parse_item(item, metadata, text) 67 | 68 | return parsed 69 | 70 | 71 | async def _fetch(url, session): 72 | async with session.get(url) as response: 73 | return await response.read() 74 | 75 | 76 | async def _fetch_boe(session, id): 77 | url = BASE_URL + "/diario_boe/xml.php?id=" + id 78 | 79 | resp = await _fetch(url, session) 80 | 81 | futures = [] 82 | try: 83 | root = ET.fromstring(resp) 84 | 85 | if not root.find("./error"): 86 | for section in root.findall("./diario/seccion"): 87 | attrs = section.attrib 88 | num = attrs["num"] 89 | name = attrs["nombre"] 90 | for department in section.findall("./departamento"): 91 | department_name = department.attrib["nombre"] 92 | for epigraph in department.findall("./epigrafe"): 93 | epigraph_name = epigraph.attrib["nombre"] 94 | metadata = dict( 95 | boe_id=id, 96 | section_number=num, 97 | section_name=name, 98 | department_name=department_name, 99 | epigraph_name=epigraph_name, 100 | ) 101 | futures = futures + [ 102 | asyncio.ensure_future(_fetch_item(item, session, metadata)) 103 | for item in epigraph.findall("./item") 104 | ] 105 | except Exception: 106 | pass 107 | 108 | return await asyncio.gather(*futures) 109 | 110 | 111 | async def download(boe_ids: Sequence[str], output_file: str): 112 | """Downloads a series of `boe ids`, saving them to `output_file` as a 113 | JSONlines file. 
Appends to an existing file if needed.""" 114 | 115 | with io.open(output_file, "a", encoding="utf8") as json_file: 116 | n = 1 117 | for boe_id in boe_ids: 118 | print(str(n) + "/" + str(len(boe_ids)) + " - " + boe_id) 119 | n = n + 1 120 | async with ClientSession() as session: 121 | items = await _fetch_boe(session, boe_id) 122 | for item in items: 123 | json.dump(item, json_file, ensure_ascii=False) 124 | json_file.write("\n") 125 | 126 | 127 | one_day = timedelta(days=1) 128 | 129 | 130 | def generate_ids(from_boe: str) -> Iterator[str]: 131 | """Returns an iterator over BOE ids from a specific starting id until 132 | yesterday's.""" 133 | 134 | year, month, day = re.search(r"(\d{4})(\d{2})(\d{2})", from_boe).groups() 135 | start = date(int(year), int(month), int(day)) 136 | yesterday = date.today() - one_day 137 | d = start - one_day 138 | while d < yesterday: 139 | d = d + one_day 140 | yield "BOE-S-" + d.strftime("%Y%m%d") 141 | 142 | 143 | def df_to_lm_databunch( 144 | df: DataFrame, 145 | columns: Sequence[str], 146 | batch_size: int = 48, 147 | seed: int = 42, 148 | sample: bool = False, 149 | ) -> TextLMDataBunch: 150 | """Extracts text from `columns` in `df` and produces a DataBunch ready for language modeling.""" 151 | np.random.seed(seed) 152 | 153 | data = TextList.from_df(df, cols=columns) 154 | if sample: 155 | data = data.filter_by_rand(0.01, seed=seed) 156 | 157 | databunch = ( 158 | data.split_by_rand_pct(0.2, seed=seed).label_for_lm().databunch(bs=batch_size) 159 | ) 160 | return databunch 161 | 162 | 163 | def load_databunch(pkl_path: Path, debug=False) -> TextLMDataBunch: 164 | p = pkl_path 165 | folder = p.parent 166 | filename = p.name 167 | return measure("loading dataframe", lambda: load_data(folder, filename), debug) 168 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | 4 | from functools import partial 5 | 6 | import click 7 | import torch 8 | from fastai.text import ( 9 | TransformerXL, 10 | language_model_learner, 11 | LanguageLearner, 12 | TextLMDataBunch, 13 | ) 14 | from fastai.distributed import setup_distrib 15 | 16 | from deepspain.dataset import load_databunch 17 | from deepspain.tensorboard import LearnerTensorboardWriter 18 | 19 | 20 | def save( 21 | data: TextLMDataBunch, 22 | learn: LanguageLearner, 23 | label: str, 24 | suffix: str, 25 | accuracy: int, 26 | ): 27 | f = open("models/" + label + "_accuracy.metric", "w") 28 | f.write(str(accuracy)) 29 | f.close() 30 | click.echo("Saving...") 31 | learn.save("model_" + label + "_" + suffix) 32 | learn.save_encoder("encoder_" + label + "_" + suffix) 33 | click.echo("Exporting...") 34 | data.export("models/" + label + "_empty_data") 35 | learn.export("models/learner_" + label + "_" + suffix + ".pkl") 36 | 37 | 38 | def clean(word: str): 39 | w = ( 40 | word.replace("\n", "(nl)") 41 | .replace("\t", "(tab)") 42 | .replace("\u2002", "(u2002)") 43 | .replace(" ", "(sp)") 44 | ) 45 | return "".join([c if ord(c) < 1000 else "_" for c in w]) 46 | 47 | 48 | def initialize_learner( 49 | data: TextLMDataBunch, 50 | pretrained_encoder: str, 51 | pretrained_itos: str, 52 | local_rank: int, 53 | label: str, 54 | kind: str, 55 | gpus: int, 56 | ) -> LanguageLearner: 57 | data.path = Path(".") 58 | click.echo("Training language model...") 59 | learn = language_model_learner( 60 | data, 61 | TransformerXL, 62 | pretrained_fnames=[ 63 | 
"./../" + pretrained_encoder.replace(".pth", ""), 64 | "./../" + pretrained_itos.replace(".pkl", ""), 65 | ], 66 | drop_mult=0.1, 67 | ) 68 | 69 | tboard_path = Path("logs/" + label) 70 | node_name = "gpu-" + str(local_rank) + "-" + kind 71 | learn.callback_fns.append( 72 | partial( 73 | LearnerTensorboardWriter, base_dir=tboard_path, gpus=gpus, name=node_name 74 | ) 75 | ) 76 | 77 | if gpus > 1: 78 | learn.to_distributed(local_rank) 79 | return learn 80 | 81 | 82 | @click.command() 83 | @click.argument( 84 | "databunch", metavar="", type=click.Path(exists=True, dir_okay=False) 85 | ) 86 | @click.argument( 87 | "output_dir", metavar="", type=click.Path(exists=True, file_okay=False) 88 | ) 89 | @click.argument( 90 | "pretrained_encoder", 91 | metavar="", 92 | type=click.Path(exists=True, dir_okay=False), 93 | ) 94 | @click.argument( 95 | "pretrained_itos", 96 | metavar="", 97 | type=click.Path(exists=True, dir_okay=False), 98 | ) 99 | @click.option( 100 | "--label", 101 | default="standard", 102 | type=str, 103 | help="Label to distinguish the trained model", 104 | ) 105 | @click.option( 106 | "--head-only", 107 | is_flag=True, 108 | default=False, 109 | help="Train only the model's head (without fine-tuning the backbone)", 110 | ) 111 | @click.option( 112 | "--head-epochs", type=int, default=4, help="Number of epochs to train the head" 113 | ) 114 | @click.option( 115 | "--backbone-epochs", 116 | type=int, 117 | default=8, 118 | help="Number of epochs to train the backbone", 119 | ) 120 | @click.option("--gpus", type=int, help="Total number of GPUs") 121 | @click.option( 122 | "--local_rank", type=int, help="Node number (set by pyTorch's distributed trainer)" 123 | ) 124 | def main( 125 | databunch: str, 126 | output_dir: str, 127 | pretrained_encoder: str, 128 | pretrained_itos: str, 129 | label: str, 130 | head_only: bool, 131 | head_epochs: int, 132 | backbone_epochs: int, 133 | gpus: int, 134 | local_rank: int, 135 | ): 136 | """Trains a Language Model, starting from pretrained weights, with data from . 137 | Saves the best model and encoder to /{model,encoder}.pth respectively. 138 | """ 139 | output_path = Path(output_dir) 140 | 141 | if gpus > 1: 142 | click.echo("Setting up distributed training...") 143 | setup_distrib(local_rank) 144 | 145 | click.echo("Loading LM databunch...") 146 | data = load_databunch(Path(databunch)) 147 | 148 | # Do a bit a hack to save time 149 | pretrained_special = ["_unk_", "_pad_", "xbos", "xfld"] 150 | actual_special = ["xxunk", "xxpad", "xxbos", "xxfld"] 151 | 152 | print(len(list(filter(lambda x: "?" 
in x, data.vocab.itos)))) 153 | itos = [ 154 | pretrained_special[actual_special.index(word)] 155 | if (word in actual_special) 156 | else clean(word) 157 | for word in data.vocab.itos 158 | ] 159 | data.vocab.itos = itos 160 | 161 | data.path = Path(".") 162 | click.echo("Training language model...") 163 | learn = initialize_learner( 164 | data, pretrained_encoder, pretrained_itos, local_rank, label, "head", gpus 165 | ) 166 | 167 | learn.freeze() 168 | click.echo("Training model head...") 169 | learn.fit_one_cycle(head_epochs, 1e-3, moms=(0.8, 0.7)) 170 | click.echo("Validating...") 171 | accuracy = learn.validate()[1].item() 172 | 173 | if local_rank == 0: 174 | save(data, learn, label, "head", accuracy) 175 | 176 | if not head_only: 177 | click.echo("Unfreezing and fine-tuning earlier layers...") 178 | 179 | learn = initialize_learner( 180 | data, 181 | pretrained_encoder, 182 | pretrained_itos, 183 | local_rank, 184 | label, 185 | "finetuned", 186 | gpus, 187 | ) 188 | 189 | learn.load("model_" + label + "_head") 190 | 191 | learn.unfreeze() 192 | learn.fit_one_cycle(backbone_epochs, 1e-3, moms=(0.8, 0.7)) 193 | 194 | click.echo("Validating...") 195 | accuracy = learn.validate()[1].item() 196 | if local_rank == 0: 197 | save(data, learn, label, "finetuned", accuracy) 198 | 199 | 200 | if __name__ == "__main__": 201 | with warnings.catch_warnings(): 202 | warnings.filterwarnings("ignore", category=UserWarning) 203 | main() # pylint: disable=no-value-for-parameter 204 | -------------------------------------------------------------------------------- /deepspain/tensorboard.py: -------------------------------------------------------------------------------- 1 | from tensorboardX import SummaryWriter 2 | from fastai.callbacks import LearnerCallback 3 | from typing import Any, Collection 4 | 5 | "Provides convenient callbacks for Learners that write model images, metrics/losses, stats and histograms to Tensorboard" 6 | from fastai.basic_train import Learner 7 | from fastai.basic_data import DatasetType, DataBunch 8 | from fastai.vision import Image 9 | from fastai.vision.gan import GANLearner 10 | from fastai.core import * 11 | from fastai.torch_core import * 12 | from threading import Thread, Event 13 | from time import sleep 14 | from queue import Queue 15 | import statistics 16 | import torchvision.utils as vutils 17 | from pathlib import Path 18 | from abc import ABC 19 | 20 | # ---Example usage (applies to any of the callbacks)--- 21 | # proj_id = 'Colorize' 22 | # tboard_path = Path('data/tensorboard/' + proj_id) 23 | # learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=tboard_path, name='GanLearner')) 24 | 25 | 26 | class LearnerTensorboardWriter(LearnerCallback): 27 | "Broadly useful callback for Learners that writes to Tensorboard. Writes model histograms, losses/metrics, and gradient stats." 
28 | 29 | def __init__( 30 | self, 31 | learn: Learner, 32 | gpus: int, 33 | base_dir: Path, 34 | name: str, 35 | loss_iters: int = 25, 36 | hist_iters: int = 500, 37 | stats_iters: int = 100, 38 | ): 39 | super().__init__(learn=learn) 40 | self.base_dir, self.name, self.loss_iters, self.hist_iters, self.stats_iters = ( 41 | base_dir, 42 | name, 43 | loss_iters, 44 | hist_iters, 45 | stats_iters, 46 | ) 47 | log_dir = base_dir / name 48 | self.gpus = gpus 49 | self.itos = self.learn.data.vocab.itos 50 | self.tbwriter = SummaryWriter(str(log_dir)) 51 | self.hist_writer = HistogramTBWriter() 52 | self.stats_writer = ModelStatsTBWriter() 53 | self.graph_writer = GraphTBWriter() 54 | self.data = None 55 | self.metrics_root = "/metrics/" 56 | self._update_batches_if_needed() 57 | 58 | def _get_model(self): 59 | if self.gpus > 1: 60 | return self.learn.model.module 61 | else: 62 | return self.learn.model 63 | 64 | def _get_new_batch(self, ds_type: DatasetType) -> Collection[Tensor]: 65 | "Retrieves new batch of DatasetType, and detaches it." 66 | return self.learn.data.one_batch( 67 | ds_type=ds_type, detach=True, denorm=False, cpu=False 68 | ) 69 | 70 | def _update_batches_if_needed(self) -> None: 71 | "one_batch function is extremely slow with large datasets. This is caching the result as an optimization." 72 | if self.learn.data.valid_dl is None: 73 | return # Running learning rate finder, so return 74 | update_batches = self.data is not self.learn.data 75 | if not update_batches: 76 | return 77 | self.data = self.learn.data 78 | self.trn_batch = self._get_new_batch(ds_type=DatasetType.Train) 79 | self.val_batch = self._get_new_batch(ds_type=DatasetType.Valid) 80 | 81 | def _write_model_stats(self, iteration: int) -> None: 82 | "Writes gradient statistics to Tensorboard." 83 | self.stats_writer.write( 84 | model=self._get_model(), iteration=iteration, tbwriter=self.tbwriter 85 | ) 86 | 87 | def _write_training_loss(self, iteration: int, last_loss: Tensor) -> None: 88 | "Writes training loss to Tensorboard." 89 | scalar_value = to_np(last_loss) 90 | tag = self.metrics_root + "train_loss" 91 | self.tbwriter.add_scalar( 92 | tag=tag, scalar_value=scalar_value, global_step=iteration 93 | ) 94 | 95 | def _write_weight_histograms(self, iteration: int) -> None: 96 | "Writes model weight histograms to Tensorboard." 97 | self.hist_writer.write( 98 | model=self._get_model(), iteration=iteration, tbwriter=self.tbwriter 99 | ) 100 | 101 | def _write_scalar(self, name: str, scalar_value, iteration: int) -> None: 102 | "Writes single scalar value to Tensorboard." 103 | tag = self.metrics_root + name 104 | self.tbwriter.add_scalar( 105 | tag=tag, scalar_value=scalar_value, global_step=iteration 106 | ) 107 | 108 | # TODO: Relying on a specific hardcoded start_idx here isn't great. Is there a better solution? 109 | def _write_metrics( 110 | self, iteration: int, last_metrics: MetricsList, start_idx: int = 2 111 | ) -> None: 112 | "Writes training metrics to Tensorboard." 113 | recorder = self.learn.recorder 114 | for i, name in enumerate(recorder.names[start_idx:]): 115 | if last_metrics is None or len(last_metrics) < i + 1: 116 | return 117 | scalar_value = last_metrics[i] 118 | self._write_scalar( 119 | name=name, scalar_value=scalar_value, iteration=iteration 120 | ) 121 | 122 | def _write_embedding(self, iteration: int) -> None: 123 | "Writes embedding to Tensorboard." 
124 | print(type(self._get_model())) 125 | print(type(self._get_model()[0])) 126 | encoder = self._get_model()[0].encoder 127 | self.tbwriter.add_embedding( 128 | list(encoder.parameters())[0], 129 | global_step=iteration, 130 | tag="encoder", 131 | metadata=self.itos, 132 | ) 133 | 134 | def on_train_begin(self, **kwargs: Any) -> None: 135 | self.graph_writer.write( 136 | model=self._get_model(), 137 | tbwriter=self.tbwriter, 138 | input_to_model=next(iter(self.learn.data.dl(DatasetType.Single)))[0], 139 | ) 140 | 141 | def on_batch_end( 142 | self, last_loss: Tensor, iteration: int, train: bool, **kwargs 143 | ) -> None: 144 | "Callback function that writes batch end appropriate data to Tensorboard." 145 | if iteration == 0 or not train: 146 | return 147 | self._update_batches_if_needed() 148 | if iteration % self.loss_iters == 0: 149 | self._write_training_loss(iteration=iteration, last_loss=last_loss) 150 | if iteration % self.hist_iters == 0: 151 | self._write_weight_histograms(iteration=iteration) 152 | 153 | # Doing stuff here that requires gradient info, because they get zeroed out afterwards in training loop 154 | def on_backward_end(self, iteration: int, train: bool, **kwargs) -> None: 155 | "Callback function that writes backward end appropriate data to Tensorboard." 156 | if iteration == 0 and not train: 157 | return 158 | self._update_batches_if_needed() 159 | if iteration % self.stats_iters == 0: 160 | self._write_model_stats(iteration=iteration) 161 | 162 | def on_epoch_end(self, last_metrics: MetricsList, iteration: int, **kwargs) -> None: 163 | "Callback function that writes epoch end appropriate data to Tensorboard." 164 | self._write_metrics(iteration=iteration, last_metrics=last_metrics) 165 | self._write_embedding(iteration=iteration) 166 | 167 | 168 | class TBWriteRequest(ABC): 169 | "A request object for Tensorboard writes. Useful for queuing up and executing asynchronous writes." 170 | 171 | def __init__(self, tbwriter: SummaryWriter, iteration: int): 172 | super().__init__() 173 | self.tbwriter = tbwriter 174 | self.iteration = iteration 175 | 176 | @abstractmethod 177 | def write(self) -> None: 178 | pass 179 | 180 | 181 | # SummaryWriter writes tend to block quite a bit. This gets around that and greatly boosts performance. 182 | # Not all tensorboard writes are using this- just the ones that take a long time. Note that the 183 | # SummaryWriter does actually use a threadsafe consumer/producer design ultimately to write to Tensorboard, 184 | # so writes done outside of this async loop should be fine. 185 | class AsyncTBWriter: 186 | "Callback for GANLearners that writes to Tensorboard. Extends LearnerTensorboardWriter and adds output image writes." 187 | 188 | def __init__(self): 189 | super().__init__() 190 | self.stop_request = Event() 191 | self.queue = Queue() 192 | self.thread = Thread(target=self._queue_processor, daemon=True) 193 | self.thread.start() 194 | 195 | def request_write(self, request: TBWriteRequest) -> None: 196 | "Queues up an asynchronous write request to Tensorboard." 197 | if self.stop_request.isSet(): 198 | return 199 | self.queue.put(request) 200 | 201 | def _queue_processor(self) -> None: 202 | "Processes queued up write requests asynchronously to Tensorboard." 
203 | while not self.stop_request.isSet(): 204 | while not self.queue.empty(): 205 | if self.stop_request.isSet(): 206 | return 207 | request = self.queue.get() 208 | request.write() 209 | sleep(0.2) 210 | 211 | # Provided this to stop thread explicitly or by context management (with statement) but thread should end on its own 212 | # upon program exit, due to being a daemon. So using this is probably unecessary. 213 | def close(self) -> None: 214 | "Stops asynchronous request queue processing thread." 215 | self.stop_request.set() 216 | self.thread.join() 217 | 218 | # Nothing to do, thread already started. Could start thread here to enforce use of context manager 219 | # (but that sounds like a pain and a bit unweildy and unecessary for actual usage) 220 | def __enter__(self): 221 | pass 222 | 223 | def __exit__(self, exc_type, exc_value, traceback): 224 | self.close() 225 | 226 | 227 | asyncTBWriter = AsyncTBWriter() 228 | 229 | 230 | class ModelImageSet: 231 | "Convenience object that holds the original, real(target) and generated versions of a single image fed to a model." 232 | 233 | @staticmethod 234 | def get_list_from_model(learn: Learner, ds_type: DatasetType, batch: Tuple) -> []: 235 | "Factory method to convert a batch of model images to a list of ModelImageSet." 236 | image_sets = [] 237 | x, y = batch[0], batch[1] 238 | preds = learn.pred_batch(ds_type=ds_type, batch=(x, y), reconstruct=True) 239 | for orig_px, real_px, gen in zip(x, y, preds): 240 | orig, real = Image(px=orig_px), Image(px=real_px) 241 | image_set = ModelImageSet(orig=orig, real=real, gen=gen) 242 | image_sets.append(image_set) 243 | return image_sets 244 | 245 | def __init__(self, orig: Image, real: Image, gen: Image): 246 | self.orig, self.real, self.gen = orig, real, gen 247 | 248 | 249 | class HistogramTBRequest(TBWriteRequest): 250 | "Request object for model histogram writes to Tensorboard." 251 | 252 | def __init__( 253 | self, model: nn.Module, iteration: int, tbwriter: SummaryWriter, name: str 254 | ): 255 | super().__init__(tbwriter=tbwriter, iteration=iteration) 256 | self.params = [ 257 | (name, values.clone().detach().cpu()) 258 | for (name, values) in model.named_parameters() 259 | ] 260 | self.name = name 261 | 262 | def _write_histogram(self, param_name: str, values) -> None: 263 | "Writes single model histogram to Tensorboard." 264 | tag = self.name + "/weights/" + param_name 265 | self.tbwriter.add_histogram(tag=tag, values=values, global_step=self.iteration) 266 | 267 | def write(self) -> None: 268 | "Writes model histograms to Tensorboard." 269 | for param_name, values in self.params: 270 | self._write_histogram(param_name=param_name, values=values) 271 | 272 | 273 | # If this isn't done async then this is sloooooow 274 | class HistogramTBWriter: 275 | "Writes model histograms to Tensorboard." 276 | 277 | def __init__(self): 278 | super().__init__() 279 | 280 | def write( 281 | self, 282 | model: nn.Module, 283 | iteration: int, 284 | tbwriter: SummaryWriter, 285 | name: str = "model", 286 | ) -> None: 287 | "Writes model histograms to Tensorboard." 288 | request = HistogramTBRequest( 289 | model=model, iteration=iteration, tbwriter=tbwriter, name=name 290 | ) 291 | asyncTBWriter.request_write(request) 292 | 293 | 294 | class ModelStatsTBRequest(TBWriteRequest): 295 | "Request object for model gradient statistics writes to Tensorboard." 
296 | 297 | def __init__( 298 | self, model: nn.Module, iteration: int, tbwriter: SummaryWriter, name: str 299 | ): 300 | super().__init__(tbwriter=tbwriter, iteration=iteration) 301 | self.gradients = [ 302 | x.grad.clone().detach().cpu() 303 | for x in model.parameters() 304 | if x.grad is not None 305 | ] 306 | self.name = name 307 | 308 | def _add_gradient_scalar(self, name: str, scalar_value) -> None: 309 | "Writes a single scalar value for a gradient statistic to Tensorboard." 310 | tag = self.name + "/gradients/" + name 311 | self.tbwriter.add_scalar( 312 | tag=tag, scalar_value=scalar_value, global_step=self.iteration 313 | ) 314 | 315 | def _write_avg_norm(self, norms: []) -> None: 316 | "Writes the average norm of the gradients to Tensorboard." 317 | avg_norm = sum(norms) / len(self.gradients) 318 | self._add_gradient_scalar("avg_norm", scalar_value=avg_norm) 319 | 320 | def _write_median_norm(self, norms: []) -> None: 321 | "Writes the median norm of the gradients to Tensorboard." 322 | median_norm = statistics.median(norms) 323 | self._add_gradient_scalar("median_norm", scalar_value=median_norm) 324 | 325 | def _write_max_norm(self, norms: []) -> None: 326 | "Writes the maximum norm of the gradients to Tensorboard." 327 | max_norm = max(norms) 328 | self._add_gradient_scalar("max_norm", scalar_value=max_norm) 329 | 330 | def _write_min_norm(self, norms: []) -> None: 331 | "Writes the minimum norm of the gradients to Tensorboard." 332 | min_norm = min(norms) 333 | self._add_gradient_scalar("min_norm", scalar_value=min_norm) 334 | 335 | def _write_num_zeros(self) -> None: 336 | "Writes the number of zeroes in the gradients to Tensorboard." 337 | gradient_nps = [to_np(x.data) for x in self.gradients] 338 | num_zeros = sum((np.asarray(x) == 0.0).sum() for x in gradient_nps) 339 | self._add_gradient_scalar("num_zeros", scalar_value=num_zeros) 340 | 341 | def _write_avg_gradient(self) -> None: 342 | "Writes the average of the gradients to Tensorboard." 343 | avg_gradient = sum(x.data.mean() for x in self.gradients) / len(self.gradients) 344 | self._add_gradient_scalar("avg_gradient", scalar_value=avg_gradient) 345 | 346 | def _write_median_gradient(self) -> None: 347 | "Writes the median of the gradients to Tensorboard." 348 | median_gradient = statistics.median(x.data.median() for x in self.gradients) 349 | self._add_gradient_scalar("median_gradient", scalar_value=median_gradient) 350 | 351 | def _write_max_gradient(self) -> None: 352 | "Writes the maximum of the gradients to Tensorboard." 353 | max_gradient = max(x.data.max() for x in self.gradients) 354 | self._add_gradient_scalar("max_gradient", scalar_value=max_gradient) 355 | 356 | def _write_min_gradient(self) -> None: 357 | "Writes the minimum of the gradients to Tensorboard." 358 | min_gradient = min(x.data.min() for x in self.gradients) 359 | self._add_gradient_scalar("min_gradient", scalar_value=min_gradient) 360 | 361 | def write(self) -> None: 362 | "Writes model gradient statistics to Tensorboard." 363 | if len(self.gradients) == 0: 364 | return 365 | norms = [x.data.norm() for x in self.gradients] 366 | self._write_avg_norm(norms=norms) 367 | self._write_median_norm(norms=norms) 368 | self._write_max_norm(norms=norms) 369 | self._write_min_norm(norms=norms) 370 | self._write_num_zeros() 371 | self._write_avg_gradient() 372 | self._write_median_gradient() 373 | self._write_max_gradient() 374 | self._write_min_gradient() 375 | 376 | 377 | class ModelStatsTBWriter: 378 | "Writes model gradient statistics to Tensorboard." 
379 | 380 | def write( 381 | self, 382 | model: nn.Module, 383 | iteration: int, 384 | tbwriter: SummaryWriter, 385 | name: str = "model_stats", 386 | ) -> None: 387 | "Writes model gradient statistics to Tensorboard." 388 | request = ModelStatsTBRequest( 389 | model=model, iteration=iteration, tbwriter=tbwriter, name=name 390 | ) 391 | asyncTBWriter.request_write(request) 392 | 393 | 394 | class ImageTBRequest(TBWriteRequest): 395 | "Request object for model image output writes to Tensorboard." 396 | 397 | def __init__( 398 | self, 399 | learn: Learner, 400 | batch: Tuple, 401 | iteration: int, 402 | tbwriter: SummaryWriter, 403 | ds_type: DatasetType, 404 | ): 405 | super().__init__(tbwriter=tbwriter, iteration=iteration) 406 | self.image_sets = ModelImageSet.get_list_from_model( 407 | learn=learn, batch=batch, ds_type=ds_type 408 | ) 409 | self.ds_type = ds_type 410 | 411 | def _write_images(self, name: str, images: [Tensor]) -> None: 412 | "Writes list of images as tensors to Tensorboard." 413 | tag = self.ds_type.name + " " + name 414 | self.tbwriter.add_image( 415 | tag=tag, 416 | img_tensor=vutils.make_grid(images, normalize=True), 417 | global_step=self.iteration, 418 | ) 419 | 420 | def _get_image_tensors(self) -> ([Tensor], [Tensor], [Tensor]): 421 | "Gets list of image tensors from lists of Image objects, as a tuple of original, generated and real(target) images." 422 | orig_images, gen_images, real_images = [], [], [] 423 | for image_set in self.image_sets: 424 | orig_images.append(image_set.orig.px) 425 | gen_images.append(image_set.gen.px) 426 | real_images.append(image_set.real.px) 427 | return orig_images, gen_images, real_images 428 | 429 | def write(self) -> None: 430 | "Writes original, generated and real(target) images to Tensorboard." 431 | orig_images, gen_images, real_images = self._get_image_tensors() 432 | self._write_images(name="orig images", images=orig_images) 433 | self._write_images(name="gen images", images=gen_images) 434 | self._write_images(name="real images", images=real_images) 435 | 436 | 437 | # If this isn't done async then this is noticeably slower 438 | class ImageTBWriter: 439 | "Writes model image output to Tensorboard." 440 | 441 | def __init__(self): 442 | super().__init__() 443 | 444 | def write( 445 | self, 446 | learn: Learner, 447 | trn_batch: Tuple, 448 | val_batch: Tuple, 449 | iteration: int, 450 | tbwriter: SummaryWriter, 451 | ) -> None: 452 | "Writes training and validation batch images to Tensorboard." 453 | self._write_for_dstype( 454 | learn=learn, 455 | batch=val_batch, 456 | iteration=iteration, 457 | tbwriter=tbwriter, 458 | ds_type=DatasetType.Valid, 459 | ) 460 | self._write_for_dstype( 461 | learn=learn, 462 | batch=trn_batch, 463 | iteration=iteration, 464 | tbwriter=tbwriter, 465 | ds_type=DatasetType.Train, 466 | ) 467 | 468 | def _write_for_dstype( 469 | self, 470 | learn: Learner, 471 | batch: Tuple, 472 | iteration: int, 473 | tbwriter: SummaryWriter, 474 | ds_type: DatasetType, 475 | ) -> None: 476 | "Writes batch images of specified DatasetType to Tensorboard." 477 | request = ImageTBRequest( 478 | learn=learn, 479 | batch=batch, 480 | iteration=iteration, 481 | tbwriter=tbwriter, 482 | ds_type=ds_type, 483 | ) 484 | asyncTBWriter.request_write(request) 485 | 486 | 487 | class GraphTBRequest(TBWriteRequest): 488 | "Request object for model histogram writes to Tensorboard." 
489 | 490 | def __init__( 491 | self, model: nn.Module, tbwriter: SummaryWriter, input_to_model: Tensor 492 | ): 493 | super().__init__(tbwriter=tbwriter, iteration=0) 494 | self.model, self.input_to_model = model, input_to_model 495 | 496 | def write(self) -> None: 497 | "Writes single model graph to Tensorboard." 498 | # self.tbwriter.add_graph(model=self.model, input_to_model=self.input_to_model) 499 | 500 | 501 | class GraphTBWriter: 502 | "Writes model network graph to Tensorboard." 503 | 504 | def write( 505 | self, model: nn.Module, tbwriter: SummaryWriter, input_to_model: Tensor 506 | ) -> None: 507 | "Writes model graph to Tensorboard." 508 | request = GraphTBRequest( 509 | model=model, tbwriter=tbwriter, input_to_model=input_to_model 510 | ) 511 | asyncTBWriter.request_write(request) 512 | --------------------------------------------------------------------------------
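`deepspain/tensorboard.py` is adapted from fastai's Tensorboard callback module (the image-writing classes appear unused in this repo). `train.py` attaches `LearnerTensorboardWriter` once per process so that every GPU writes its own event files under `logs/<label>`. A rough sketch of that wiring, wrapped in a hypothetical `attach_tensorboard` helper:

    from functools import partial
    from pathlib import Path

    from fastai.basic_train import Learner

    from deepspain.tensorboard import LearnerTensorboardWriter


    def attach_tensorboard(learn: Learner, label: str, local_rank: int, gpus: int, kind: str = "head") -> None:
        """Mirror initialize_learner() in train.py: one log directory per run label, one writer name per GPU."""
        tboard_path = Path("logs/" + label)
        node_name = "gpu-" + str(local_rank) + "-" + kind
        learn.callback_fns.append(
            partial(LearnerTensorboardWriter, base_dir=tboard_path, gpus=gpus, name=node_name)
        )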