├── examples ├── conf │ ├── wandb │ │ └── default.yaml │ ├── mlflow │ │ └── default.yaml │ ├── data │ │ ├── health_fact.yaml │ │ ├── arxiv_classification.yaml │ │ └── default.yaml │ ├── model │ │ └── default.yaml │ ├── config.yaml │ └── training_arguments │ │ └── default.yaml ├── requirements.txt ├── run_classification.py ├── utils.py ├── data.py └── run_classification_no_trainer.py ├── src └── strideformer │ ├── __init__.py │ ├── collator.py │ ├── pipeline.py │ ├── config.py │ └── model.py ├── setup.cfg ├── setup.py ├── .gitignore ├── README.md └── LICENSE /examples/conf/wandb/default.yaml: -------------------------------------------------------------------------------- 1 | project: health-fact 2 | entity: nbroad 3 | group: '' 4 | tags: null 5 | notes: > 6 | No notes 7 | job_type: train -------------------------------------------------------------------------------- /examples/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.26.1 2 | accelerate==0.16.0 3 | datasets==2.10.1 4 | hydra-core==1.3.2 5 | scikit-learn==1.2.1 6 | torch -------------------------------------------------------------------------------- /src/strideformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Strideformer 2 | from .config import StrideformerConfig 3 | from .collator import StrideformerCollator 4 | from .pipeline import Pipeline 5 | 6 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | multi_line_output = 3 3 | include_trailing_comma = True 4 | force_grid_wrap = 0 5 | use_parentheses = True 6 | ensure_newline_before_comments = True 7 | line_length = 79 8 | lines_after_imports = 2 9 | 10 | [flake8] 11 | ignore = E203, W503 12 | max-line-length = 79 -------------------------------------------------------------------------------- /examples/conf/mlflow/default.yaml: -------------------------------------------------------------------------------- 1 | log_artifact: 'FALSE' 2 | experiment_name: "mlflow_experiment" 3 | tags: 4 | test: "TRUE" 5 | model_type: transformer 6 | nested_run: "FALSE" 7 | run_id: null 8 | flatten_params: "FALSE" 9 | tracking_uri: null 10 | logging_dir: null 11 | run_name: null 12 | description: null -------------------------------------------------------------------------------- /examples/conf/data/health_fact.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: health_fact 2 | dataset_config_name: null 3 | max_seq_length: 512 4 | truncation: yes 5 | pad_multiple: 8 6 | # set to small number for debugging, -1 for all rows 7 | n_rows: -1 8 | text_col: "claim" 9 | text_pair_col: "main_text" 10 | label_col: "label" 11 | map_batch_size: 1000 -------------------------------------------------------------------------------- /examples/conf/data/arxiv_classification.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: 'ccdv/arxiv-classification' 2 | dataset_config_name: "no_ref" 3 | max_seq_length: 512 4 | truncation: yes 5 | pad_multiple: 8 6 | # set to small number for debugging, -1 for all rows 7 | n_rows: -1 8 | text_col: "text" 9 | text_pair_col: null 10 | label_col: "label" 11 | map_batch_size: 1000 -------------------------------------------------------------------------------- /examples/conf/data/default.yaml: 
-------------------------------------------------------------------------------- 1 | dataset_name: MISSING 2 | dataset_config_name: null 3 | max_seq_length: 512 4 | truncation: yes 5 | pad_multiple: 8 6 | # set to small number for debugging, -1 for all rows 7 | n_rows: -1 8 | text_col: "text" 9 | text_pair_col: null 10 | label_col: "label" 11 | map_batch_size: 1000 12 | data_files: 13 | train: [] 14 | validation: [] -------------------------------------------------------------------------------- /examples/conf/model/default.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: "allenai/longformer-base-4096" 2 | config_name: null # Pretrained config name or path if not the same as model_name 3 | tokenizer_name: null # Pretrained tokenizer name or path if not the same as model_name" 4 | model_revision: main 5 | use_auth_token: no 6 | from_tf: no # set to 'yes' if loading from tensorflow weights 7 | from_flax: no # set to 'yes' if loading from flax weights -------------------------------------------------------------------------------- /examples/conf/config.yaml: -------------------------------------------------------------------------------- 1 | # Default configuration 2 | 3 | defaults: 4 | - data: health_fact 5 | - training_arguments: default 6 | - model: default 7 | - mlflow: default 8 | - wandb: default 9 | - _self_ # attributes in this file will override defaults 10 | 11 | 12 | model: 13 | model_name_or_path: "sentence-transformers/all-MiniLM-L12-v2" 14 | freeze_first_model: no 15 | max_chunks: 128 16 | num_hidden_layers: 48 17 | num_attention_heads: 12 18 | intermediate_size: 2048 19 | hidden_act: gelu 20 | dropout: 0.1 21 | layer_norm_eps: 1e-7 22 | initializer_range: 0.02 23 | 24 | 25 | data: 26 | max_seq_length: 256 27 | pad_multiple: 8 28 | stride: 128 29 | n_rows: 2000 30 | map_batch_size: 500 # batch size when using `datasets.map` 31 | data_files: 32 | train: 33 | - 'full_train.csv' 34 | validation: 35 | - 'full_val.csv' 36 | 37 | training_arguments: 38 | do_train: yes 39 | do_eval: yes 40 | 41 | evaluation_strategy: "epoch" 42 | fp16: yes 43 | 44 | learning_rate: 2e-6 45 | per_device_train_batch_size: 2 46 | per_device_eval_batch_size: 16 47 | gradient_accumulation_steps: 1 48 | 49 | metric_for_best_model: "f1_micro" 50 | greater_is_better: yes 51 | report_to: "wandb" 52 | log_level: "warning" 53 | save_strategy: "epoch" 54 | logging_steps: 100 55 | save_total_limit: 2 56 | 57 | # general configuration 58 | num_proc: 8 59 | task: text-classification 60 | language: en 61 | project_name: stride-former 62 | 63 | hydra: 64 | output_subdir: config -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | 4 | LONG_DESCRIPTION = """ 5 | A package that includes an implementation of a two-transformer 6 | setup to classify long documents. The first stage is a 7 | sentence-transformer which embeds chunks of the text. The second 8 | stage is a generic transformer used to classify the documents 9 | based on a sequence of embeddings from the first model. 10 | Training scripts included. 
11 | """ 12 | 13 | REQUIRED_PKGS = ["transformers>=4.21.0"] 14 | 15 | QUALITY_REQUIRE = ["black", "flake8", "isort", "tabulate"] 16 | TRAIN_REQUIRE = ["hydra-core>=1.2.0"] 17 | 18 | EXTRAS_REQUIRE = {"quality": QUALITY_REQUIRE, "train": TRAIN_REQUIRE} 19 | 20 | 21 | def combine_requirements(base_keys): 22 | return list(set(k for v in base_keys for k in EXTRAS_REQUIRE[v])) 23 | 24 | 25 | EXTRAS_REQUIRE["dev"] = combine_requirements([k for k in EXTRAS_REQUIRE]) 26 | 27 | 28 | setup( 29 | name="strideformer", 30 | version="0.2.0", 31 | description="Package to use a two-stage transformer for long-document classification", 32 | long_description=LONG_DESCRIPTION, 33 | author="Nicholas Broad", 34 | author_email="nicholas@huggingface.co", 35 | url="https://github.com/nbroad1881/strideformer", 36 | download_url="https://github.com/nbroad1881/strideformer", 37 | license="Apache 2.0", 38 | package_dir={"": "src"}, 39 | packages=find_packages("src"), 40 | install_requires=REQUIRED_PKGS, 41 | extras_require=EXTRAS_REQUIRE, 42 | keywords="nlp, machine learning, transformers, long sequence", 43 | ) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .DS_Store* 132 | .vscode/ -------------------------------------------------------------------------------- /src/strideformer/collator.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | from typing import List, Dict 3 | 4 | import torch 5 | from transformers import PreTrainedTokenizerFast 6 | 7 | 8 | class StrideformerCollator: 9 | def __init__(self, tokenizer: PreTrainedTokenizerFast, max_chunks: int = 128): 10 | """ 11 | Loads a collator designed for Strideformer. 12 | 13 | Args: 14 | tokenizer (`PreTrainedTokenizerFast`): 15 | The tokenizer that corresponds with the first model in Strideformer. 16 | max_chunks (`int`, *optional*, defaults to 128): 17 | The maximum number of chunks that can be passed to the first model. 18 | This is to limit OOM errors. 19 | 20 | """ 21 | self.tokenizer = tokenizer 22 | self.max_chunks = max_chunks 23 | 24 | def __call__(self, features: List[Dict]): 25 | """ 26 | Put features in a format that the model can use. The dataloader will pass a list 27 | of dictionaries to this function. Each dictionary will have keys: input_ids, attention_mask, labels. 28 | 29 | This function will combine these dictionaries into a single dictionary with keys: input_ids, attention_mask, labels. 30 | Each value in the dictionary will be a tensor. 31 | 32 | Args: 33 | features (`List[Dict]`): 34 | The list will be as long as the batch size specified 35 | passed to the DataLoader. 36 | Each element of features will have keys: input_ids, attention_mask, labels 37 | input_ids will be of shape [num_chunks, sequence_length] 38 | attention_mask will be of shape [num_chunks, sequence_length] 39 | label will be a single value if this is single_label_classification or regression 40 | It will be a list if multi_label_classification 41 | 42 | 43 | Returns: 44 | (dict): input_ids, attention_mask, labels to be passed to the model. 
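Example (an illustrative sketch only; the token ids below are made up and `collator` is assumed to be a `StrideformerCollator` instance). A single document split into two chunks collates to tensors of shape `(num_chunks, longest_seq)`:
    >>> features = [
    ...     {"input_ids": [[101, 7592, 102], [101, 2088, 102]],
    ...      "attention_mask": [[1, 1, 1], [1, 1, 1]],
    ...      "label": 1}
    ... ]
    >>> batch = collator(features)
    >>> batch["input_ids"].shape   # (num_chunks, longest_seq), capped at max_chunks
    torch.Size([2, 3])
    >>> batch["labels"]
    tensor([1])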
45 | """ 46 | 47 | label_key = "label" if "label" in features[0] else "labels" 48 | 49 | ids = list(chain(*[x["input_ids"] for x in features])) 50 | mask = list(chain(*[x["attention_mask"] for x in features])) 51 | labels = [x[label_key] for x in features] 52 | 53 | longest_seq = max([len(x) for x in ids]) 54 | 55 | ids = [x + [self.tokenizer.pad_token_id] * (longest_seq - len(x)) for x in ids] 56 | mask = [x + [0] * (longest_seq - len(x)) for x in mask] 57 | 58 | return { 59 | "input_ids": torch.tensor(ids, dtype=torch.long)[: self.max_chunks, :], 60 | "attention_mask": torch.tensor(mask, dtype=torch.long)[ 61 | : self.max_chunks, : 62 | ], 63 | "labels": torch.tensor(labels, dtype=torch.long), 64 | } 65 | -------------------------------------------------------------------------------- /examples/conf/training_arguments/default.yaml: -------------------------------------------------------------------------------- 1 | # TrainingArguments 2 | # https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments 3 | 4 | 5 | # output 6 | output_dir: model_output 7 | overwrite_output_dir: no 8 | 9 | # training 10 | do_train: no 11 | resume_from_checkpoint: 12 | 13 | # hyperparams 14 | per_device_train_batch_size: 8 15 | per_device_eval_batch_size: 8 16 | gradient_accumulation_steps: 1 17 | gradient_checkpointing: no 18 | group_by_length: no 19 | learning_rate: 5e-5 20 | weight_decay: 0.0 21 | seed: 42 22 | 23 | # schedule + steps 24 | num_train_epochs: 3 25 | lr_scheduler_type: linear # linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup 26 | warmup_ratio: 0.0 27 | warmup_steps: 0 28 | max_steps: -1 29 | 30 | # optimizer 31 | adam_beta1: 0.9 32 | adam_beta2: 0.999 33 | adam_epsilon: 1e-8 34 | max_grad_norm: 1.0 35 | optim: 'adamw_hf' # adamw_hf, adamw_torch, adamw_apex_fused, or adafactor 36 | adafactor: no 37 | 38 | # logging 39 | log_level: "passive" 40 | log_level_replica: "passive" 41 | log_on_each_node: yes 42 | logging_dir: null 43 | logging_strategy: steps 44 | logging_first_step: no 45 | logging_steps: 500 46 | logging_nan_inf_filter: yes 47 | 48 | # saving 49 | save_strategy: "steps" 50 | save_steps: 500 51 | save_total_limit: null 52 | 53 | # dtype 54 | fp16: no 55 | bf16: no # bf16 requires Ampere GPUs or newer (A100, A6000, rtx 3080) 56 | fp16_opt_level: "O1" 57 | half_precision_backend: "auto" 58 | bf16_full_eval: no 59 | fp16_full_eval: no 60 | tf32: no 61 | 62 | # compile 63 | torch_compile: no 64 | 65 | # evaluation/prediction 66 | do_eval: no 67 | evaluation_strategy: 'no' 68 | eval_delay: 0 69 | include_inputs_for_metrics: no 70 | do_predict: no 71 | jit_mode_eval: no 72 | 73 | # hub 74 | hub_model_id: null 75 | hub_token: null 76 | 77 | # reporting 78 | report_to: null 79 | 80 | # rarely used 81 | debug: '' 82 | prediction_loss_only: no 83 | eval_accumulation_steps: null 84 | use_ipex: no 85 | save_on_each_node: no 86 | no_cuda: no 87 | # use_mps_device: no 88 | data_seed: null 89 | local_rank: -1 90 | xpu_backend: null 91 | tpu_num_cores: null 92 | tpu_metrics_debug: no 93 | dataloader_drop_last: no 94 | past_index: -1 95 | run_name: null 96 | disable_tqdm: null 97 | remove_unused_columns: yes 98 | label_names: null 99 | greater_is_better: null 100 | ignore_data_skip: no 101 | sharded_ddp: '' 102 | fsdp: '' 103 | fsdp_min_num_params: 0 104 | fsdp_transformer_layer_cls_to_wrap: null 105 | deepspeed: null 106 | label_smoothing_factor: 0.0 107 | length_column_name: length 108 | ddp_find_unused_parameters: null 109 | 
ddp_bucket_cap_mb: null 110 | skip_memory_metrics: yes 111 | use_legacy_prediction_loop: no 112 | # deprecated 113 | per_gpu_train_batch_size: null 114 | per_gpu_eval_batch_size: null 115 | fp16_backend: auto 116 | push_to_hub_model_id: null 117 | push_to_hub_organization: null 118 | push_to_hub_token: null 119 | # _n_gpu: 120 | mp_parameters: '' 121 | auto_find_batch_size: no 122 | full_determinism: no 123 | torchdynamo: null 124 | ray_scope: last -------------------------------------------------------------------------------- /src/strideformer/pipeline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoTokenizer 3 | 4 | from .model import Strideformer, StrideformerConfig 5 | 6 | 7 | class Pipeline: 8 | """ 9 | Pipeline abstraction to make it easy to use a 10 | Strideformer model. 11 | """ 12 | 13 | def __init__( 14 | self, model_name_or_path=None, model=None, tokenizer=None, device="cpu" 15 | ): 16 | """ 17 | Load pretrained model and tokenizer from path or objects. 18 | 19 | Args: 20 | model_name_or_path (`str`, *optional*, defaults to `None`): 21 | Location of model weights and configuration file. 22 | model (`Strideformer`, *optional*, defaults to `None`): 23 | Strideformer model to directly load. 24 | tokenizer (`PreTrainedTokenizer`, *optional*, defaults to `None`): 25 | Tokenizer to directly load. 26 | device (`str`, *optional*, defaults to `"cpu"`): 27 | String indicating which device the model should run on. 28 | """ 29 | if model_name_or_path is not None: 30 | config = StrideformerConfig.from_pretrained(model_name_or_path) 31 | self.model = Strideformer.from_pretrained(model_name_or_path, config=config) 32 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 33 | else: 34 | self.model = model 35 | self.tokenizer = tokenizer 36 | 37 | self.model.eval() 38 | self.device = torch.device(device) 39 | 40 | self.model.to(self.device) 41 | 42 | @torch.no_grad() 43 | def __call__(self, text, max_length=384, stride=128, return_type="list"): 44 | """ 45 | Pass the model a string and it will return predictions. Only works for 46 | one string and not a list of strings. 47 | 48 | Args: 49 | text (`str`): 50 | Text to be classified. 51 | max_length (`int`, *optional*, defaults to 384): 52 | Maximum sequence length for the first model. 53 | stride (`int`, *optional*, defaults to 128): 54 | Amount of overlap between token chunks. 55 | return_type (`str`, *optional*, defaults to `"list"`): 56 | String to determine what type of results to return. 57 | Options are `"np"`, `"pytorch"`, and `"list"`. 58 | 59 | Returns: 60 | (`np.array`, `torch.Tensor`, `list`): Classification scores. 
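Example (an illustrative sketch; the checkpoint path and the returned scores are placeholders, not real outputs):
    >>> pipe = Pipeline(model_name_or_path="path/to/trained/strideformer", device="cpu")
    >>> pipe("A very long document ...", max_length=384, stride=128)
    [[0.41, -0.21]]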
61 | """ 62 | 63 | tokens = self.tokenizer( 64 | text, 65 | return_tensors="pt", 66 | max_length=max_length, 67 | stride=stride, 68 | return_overflowing_tokens=True, 69 | padding=True, 70 | truncation=True, 71 | ) 72 | 73 | del tokens["overflow_to_sample_mapping"] 74 | 75 | output = self.model(**tokens.to(self.device)) 76 | 77 | if return_type == "np": 78 | return output.logits.detach().cpu().numpy() 79 | elif return_type == "pt": 80 | return output.logits.detach().cpu() 81 | 82 | return output.logits.detach().cpu().tolist() 83 | -------------------------------------------------------------------------------- /src/strideformer/config.py: -------------------------------------------------------------------------------- 1 | """ Strideformer model configuration """ 2 | import copy 3 | from typing import Dict 4 | 5 | from transformers import PretrainedConfig, AutoConfig 6 | 7 | 8 | class StrideformerConfig(PretrainedConfig): 9 | r""" 10 | This is the configuration class to store the configuration of a [`Strideformer`]. It is used to instantiate a BStrideformerART 11 | model according to the specified arguments, defining the model architecture. 12 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 13 | documentation from [`PretrainedConfig`] for more information. 14 | Args: 15 | first_model_config (`PretrainedConfig`, *optional*, defaults to `None`): 16 | The configuration of the first model. 17 | freeze_first_model (`bool`, *optional*, defaults to `True`): 18 | If True, freeze the weights of the first model. Otherwise, train it as well. 19 | max_chunks (`int`, *optional*, defaults to 64) 20 | The maximum number of chunks the first model can take. 21 | hidden_size (`int`, *optional*, defaults to 384): 22 | The hidden size of the second model. It must match the first model's hidden size. 23 | initializer_range (`float`, *optional*, defaults to 0.02): 24 | The standard deviation when initializing weights using a normal distribution. 25 | num_hidden_layers (`int`, *optional*, defaults to 24): 26 | Number of layers for the second model. 27 | num_attention_heads (`int`, *optional*, defaults to 12): 28 | Number of attention heads for each attention layer in the second model. 29 | intermediate_size (`int`, *optional*, defaults to 4096): 30 | Dimensionality of the "intermediate" (often named feed-forward) layer in second model. 31 | hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): 32 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, 33 | `"relu"`, `"silu"` and `"gelu_new"` are supported. 34 | dropout (`float`, *optional*, defaults to 0.1): 35 | The dropout probability for all fully connected layers in the second model. 36 | layer_norm_eps (`float`, *optional*, defaults to 1e-7): 37 | The epsilon value in LayerNorm 38 | num_labels (`int`, *optional*, defaults to 2): 39 | The number of labels for the classifier. 
40 | 41 | Example: 42 | 43 | ```python 44 | >>> from strideformer import StrideformerConfig, Strideformer 45 | >>> from transformers import AutoConfig 46 | >>> first_model_config = AutoConfig.from_pretrained("sentence-transformers/all-MiniLM-L6-v2") 47 | >>> config = StrideformerConfig.from_two_configs(first_model_config, dict(max_chunks=128, num_hidden_layers=24, num_attention_heads=12, intermediate_size=4096, hidden_act="gelu", dropout=0.1, layer_norm_eps=1e-7, initializer_range=0.02, hidden_size=first_model_config.hidden_size)) 48 | >>> model = Strideformer(config) 49 | ``` 50 | 51 | """ 52 | model_type = "strideformer" 53 | keys_to_ignore_at_inference = [ 54 | "first_model_hidden_states", 55 | "second_model_hidden_states", 56 | ] 57 | is_composition = True 58 | 59 | def __init__( 60 | self, 61 | **kwargs 62 | ): 63 | first_model_config_dict = kwargs.pop("first_model_config") 64 | first_model_type = first_model_config_dict.pop("model_type") 65 | self.first_model_config = AutoConfig.for_model(first_model_type, **first_model_config_dict) 66 | 67 | super().__init__(**kwargs) 68 | 69 | 70 | 71 | @classmethod 72 | def from_two_configs(cls, first_model_config: PretrainedConfig, second_model_config: Dict): 73 | 74 | return cls( 75 | first_model_config=first_model_config.to_dict(), 76 | **second_model_config 77 | ) 78 | 79 | def to_dict(self): 80 | """ 81 | Serializes this instance to a Python dictionary. Overrides the default *to_dict()* from *PretrainedConfig*. 82 | Returns: 83 | `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance. 84 | """ 85 | output = copy.deepcopy(self.__dict__) 86 | output["first_model_config"] = self.first_model_config.to_dict() 87 | output["model_type"] = self.__class__.model_type 88 | return output -------------------------------------------------------------------------------- /examples/run_classification.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License.
16 | """ Finetuning short model chunks for long sequence classification""" 17 | 18 | import logging 19 | import sys 20 | 21 | import hydra 22 | import numpy as np 23 | from sklearn.metrics import f1_score 24 | from omegaconf import DictConfig 25 | import datasets 26 | import transformers 27 | from transformers import ( 28 | AutoConfig, 29 | AutoModelForSequenceClassification, 30 | TrainingArguments, 31 | Trainer, 32 | set_seed, 33 | DataCollatorWithPadding, 34 | ) 35 | 36 | from strideformer import Strideformer, StrideformerConfig, StrideformerCollator 37 | from data import DataModule 38 | 39 | 40 | logger = logging.getLogger(__name__) 41 | 42 | 43 | def compute_metrics(eval_predictions): 44 | predictions, labels = eval_predictions 45 | predictions = np.argmax(predictions, axis=1) 46 | 47 | f1_micro = f1_score(labels, predictions, average="micro") 48 | f1_macro = f1_score(labels, predictions, average="macro") 49 | accuracy = (predictions == labels).mean().item() 50 | 51 | return { 52 | "f1_micro": f1_micro, 53 | "f1_macro": f1_macro, 54 | "accuracy": accuracy, 55 | } 56 | 57 | 58 | @hydra.main(version_base=None, config_path="conf", config_name="config") 59 | def main(cfg: DictConfig) -> None: 60 | 61 | train_args = dict(cfg.training_arguments) 62 | training_args = TrainingArguments(**train_args) 63 | 64 | 65 | # Setup logging 66 | logging.basicConfig( 67 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 68 | datefmt="%m/%d/%Y %H:%M:%S", 69 | handlers=[logging.StreamHandler(sys.stdout)], 70 | ) 71 | 72 | if training_args.should_log: 73 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 74 | transformers.utils.logging.set_verbosity_info() 75 | 76 | log_level = training_args.get_process_log_level() 77 | logger.setLevel(log_level) 78 | datasets.utils.logging.set_verbosity(log_level) 79 | transformers.utils.logging.set_verbosity(log_level) 80 | transformers.utils.logging.enable_default_handler() 81 | transformers.utils.logging.enable_explicit_format() 82 | 83 | # Log on each process the small summary: 84 | logger.warning( 85 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 86 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 87 | ) 88 | logger.info(f"Training/evaluation parameters {training_args}") 89 | 90 | # Set seed before initializing model. 91 | set_seed(training_args.seed) 92 | 93 | data_module = DataModule(cfg) 94 | 95 | # If running distributed, this will do something on the main process, while 96 | # blocking replicas, and when it's finished, it will release the replicas. 97 | with training_args.main_process_first(desc="Dataset loading and tokenization"): 98 | data_module.prepare_dataset() 99 | 100 | 101 | use_strideformer = cfg.data.stride is not None and cfg.data.stride > 0 102 | 103 | if use_strideformer: 104 | collator = StrideformerCollator(tokenizer=data_module.tokenizer) 105 | 106 | # batch_sizes must always be 1 when using strided approach 107 | if ( 108 | training_args.per_device_train_batch_size != 1 109 | or training_args.per_device_eval_batch_size != 1 110 | ): 111 | logger.warning( 112 | "Batch size must be 1 when using strided approach. Changing to 1 now." 
113 | ) 114 | training_args.per_device_train_batch_size = 1 115 | training_args.per_device_eval_batch_size = 1 116 | 117 | else: 118 | collator = DataCollatorWithPadding( 119 | tokenizer=data_module.tokenizer, pad_to_multiple_of=cfg.data.pad_multiple 120 | ) 121 | 122 | model_config = AutoConfig.from_pretrained( 123 | cfg.model.model_name_or_path, 124 | label2id=data_module.label2id, 125 | id2label=data_module.id2label, 126 | num_labels=len(data_module.label2id), 127 | ) 128 | if use_strideformer: 129 | second_model_config = dict( 130 | freeze_first_model=cfg.model.freeze_first_model, 131 | max_chunks=cfg.model.max_chunks, 132 | num_hidden_layers=cfg.model.num_hidden_layers, 133 | num_attention_heads=cfg.model.num_attention_heads, 134 | intermediate_size=cfg.model.intermediate_size, 135 | hidden_act=cfg.model.hidden_act, 136 | dropout=cfg.model.dropout, 137 | layer_norm_eps=cfg.model.layer_norm_eps, 138 | initializer_range=cfg.model.initializer_range, 139 | hidden_size=model_config.hidden_size, 140 | label2id=data_module.label2id, 141 | id2label=data_module.id2label, 142 | num_labels=len(data_module.label2id), 143 | ) 144 | model_config = StrideformerConfig.from_two_configs( 145 | first_model_config=model_config, 146 | second_model_config=second_model_config, 147 | ) 148 | 149 | if use_strideformer: 150 | model = Strideformer( 151 | config=model_config 152 | ) 153 | else: 154 | model = AutoModelForSequenceClassification.from_pretrained( 155 | cfg.model.model_name_or_path, config=model_config 156 | ) 157 | 158 | trainer = Trainer( 159 | model=model, 160 | args=training_args, 161 | train_dataset=data_module.get_dataset("train") if training_args.do_train else None, 162 | eval_dataset=data_module.get_dataset("validation") if training_args.do_eval else None, 163 | tokenizer=data_module.tokenizer, 164 | data_collator=collator, 165 | compute_metrics=compute_metrics, 166 | ) 167 | 168 | if training_args.do_train: 169 | 170 | train_results = trainer.train() 171 | 172 | logger.info(train_results.metrics) 173 | 174 | if training_args.do_eval: 175 | eval_results = trainer.evaluate() 176 | 177 | logger.info(eval_results) 178 | 179 | if training_args.do_predict: 180 | 181 | predict_results = trainer.predict(test_dataset=data_module.get_dataset("test")) 182 | 183 | logger.info(predict_results.metrics) 184 | 185 | 186 | if __name__ == "__main__": 187 | main() 188 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # stride-former 2 | 3 | Using short models to classify long texts. Still a work-in-progress, but the general skeleton is there. This is very similar to this: [Three-level Hierarchical Transformer Networks for Long-sequence and Multiple Clinical Documents Classification](https://arxiv.org/abs/2104.08444). 4 | 5 | [Here is a good illustration from the paper.](https://www.semanticscholar.org/paper/Three-level-Hierarchical-Transformer-Networks-for-Si-Roberts/2cc7d0a242f8358bf5e3e41defd2e93e6297f74c/figure/0) 6 | ## How it works 7 | 8 | Since attention in transformers scales quadratically with sequence length, it can become infeasible to do full self-attention on sequences longer than 512 tokens. Models like [Longformer](https://arxiv.org/abs/2004.05150) and [Big Bird](https://arxiv.org/abs/2007.14062) use different attention mechanisms to reduce the quadratic scaling to linear scaling, allowing them to work on sequences up to 4096 tokens. 
This approach chunks the text into 512 token chunks and then aggregates those embeddings by putting them through another transformer. Because there are two models involved, the first one is a pre-trained sentence transformer to reduce the number of trainable parameters. The second model is a [generic transformer encoder](https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoder.html#torch.nn.TransformerEncoder). This repo is called stride-former because there is an overlap between chunks of text - typically of 256 tokens, which is referred to as a stride. 9 | 10 | ## Is it better than Longformer or Bigbird? 11 | 12 | In a very small and un-scientific set of experiments on a small subset of one dataset ([arxiv-classification](https://huggingface.co/datasets/ccdv/arxiv-classification)), the stride-former did best. This dataset has extremely long documents: about half of the documents are over 15k tokens and 25% are over 23k. This is far more than the 4096 token limit of Longformer and Bigbird, and much more than the typical 512 token limit of standard full-attention models. The current default parameters for the stride-former are: chunks of 384 tokens, stride by 128 tokens, and limit the total number of chunks to 128, which means that documents with 32k tokens can be passed through. 13 | 14 | The two rows at the bottom of the table below are the stride-former runs which use sentence transformers to encode the chunks. 15 | 16 | |model |max_seq_length|stride|eval_f1_micro |eval_f1_macro | 17 | |---------------------------------------|-------------------|-----------------|-----------|-------------------| 18 | |microsoft/deberta-v3-base |512 |0 |0.2871 |0.16918| 19 | |microsoft/deberta-v3-base |512 |0 |0.2475|0.1494| 20 | |allenai/longformer-base-4096 |4096 |0 |0.6732 |0.6094 | 21 | |allenai/longformer-base-4096 |4096 |0 |0.6831 |0.6295 | 22 | |google/bigbird-roberta-base |4096 |0 |0.5841 |0.5118 | 23 | |google/bigbird-roberta-base |4096 |0 |0.6534 |0.6064 | 24 | |sentence-transformers/all-MiniLM-L12-v2|384 |128 |0.7227 |0.6728 | 25 | |sentence-transformers/all-MiniLM-L12-v2|384 |128 |0.6831 |0.6386 | 26 | 27 | 28 | See the [Weights and Biases report here for more details.](https://wandb.ai/nbroad/stride-former/reports/Stride-former-comparison--VmlldzoyNTUyOTEy?accessToken=p5x55isxp9thu5ktrhrrlmm98c82ckaagcam3r2re43mye8z45763mudidrb4vml) 29 | 30 | ## How many chunks? 31 | 32 | If the max sequence length is `L` and the stride is `S`, and the original sequence length without striding is `N`, then here is how to calculate the number of chunks. Chunk 1 will consist of a sequence starting at token `0` and going until, but not including, `L`. Chunk 2 will start at token `L - S` and end at `2L - S`, chunk `C` will end at `CL - (C - 1)S`. How many chunks does it take for an original sequence of `N` tokens? 
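The text block just below derives a bound on `C`. As a quick numeric sanity check, a small helper (hypothetical, not part of this package, and assuming the default chunk length of 384 and stride of 128 mentioned above) gives the same count:

```python
import math


def num_chunks(n_tokens: int, max_length: int = 384, stride: int = 128) -> int:
    """Smallest C such that C * (max_length - stride) + stride >= n_tokens."""
    if n_tokens <= max_length:
        return 1
    step = max_length - stride  # fresh (non-overlapping) tokens added by each new chunk
    return math.ceil((n_tokens - stride) / step)


print(num_chunks(32_000))  # 125 chunks, just under the default max_chunks limit of 128
```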
33 | 34 | ```text 35 | CL - (C - 1)*S > N 36 | C*(L - S) + S > N 37 | C*(L - S) > N - S 38 | C > (N - S)/(L - S) 39 | ``` 40 | 41 | ## Installation 42 | 43 | ```sh 44 | pip install strideformer 45 | ``` 46 | 47 | ## Basic usage 48 | 49 | ```python 50 | from strideformer import StrideformerConfig, Strideformer 51 | from transformers import AutoConfig, AutoTokenizer 52 | 53 | label2id = {"positive": 0, "negative": 1} 54 | id2label = {"0": "positive", "1": "negative"} 55 | 56 | first_model_name = "sentence-transformers/all-MiniLM-L6-v2" 57 | 58 | first_model_config = AutoConfig.from_pretrained( 59 | first_model_name, 60 | label2id=label2id, 61 | id2label=id2label, 62 | num_labels=len(label2id), 63 | ) 64 | second_model_config = dict( 65 | freeze_first_model=False, # Train first model? 66 | max_chunks=128, # Maximum number of chunks to consider 67 | num_hidden_layers=24, 68 | num_attention_heads=12, 69 | intermediate_size=4096, 70 | hidden_act="gelu", 71 | dropout=0.1, 72 | layer_norm_eps=1e-7, 73 | initializer_range=0.02, 74 | hidden_size=first_model_config.hidden_size, # Needs to match first model 75 | label2id=label2id, 76 | id2label=id2label, 77 | num_labels=len(label2id), 78 | ) 79 | 80 | model_config = StrideformerConfig.from_two_configs( 81 | first_model_config=first_model_config, 82 | second_model_config=second_model_config, 83 | ) 84 | 85 | model = Strideformer( 86 | config=model_config 87 | ) 88 | 89 | tokenizer = AutoTokenizer.from_pretrained(first_model_name) 90 | # Set the max_length and stride to whatever values you like (256, 128 are good starting points) 91 | inputs = tokenizer( 92 | ["Here is a bunch of text that should be strided over"], 93 | return_tensors="pt", 94 | return_overflowing_tokens=True, 95 | padding=True, 96 | truncation=True, 97 | max_length=5, 98 | stride=2 99 | ) 100 | 101 | print(inputs.input_ids.shape) 102 | # torch.Size([10, 5]) 103 | 104 | outputs = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask) 105 | print(outputs) 106 | # StrideformerOutput([('loss', None), 107 | # ('logits', 108 | # tensor([[ 0.4127, -0.2096]], grad_fn=)), 109 | # ('first_model_hidden_states', None), 110 | # ('second_model_hidden_states', None)]) 111 | 112 | output_path = "trained_strideformer" 113 | model.save_pretrained(output_path) 114 | 115 | # `from_pretrained` is only to be used after training Strideformer 116 | # Use `Strideformer(config)` before training 117 | loaded_model = Strideformer.from_pretrained(output_path) 118 | ``` 119 | 120 | 121 | ## Basic run 122 | 123 | Start by installing the required packages in the examples folder using `pip install -r requirements.txt`. Some additional packages might also be necessary depending on the tracking framework and models chosen (see [here](#optional-packages)). Then change the values inside `examples/conf/config.yaml` to your desired parameters. If you set `data.stride` to <=0 or null, then it will train with a regular approach (no strides). 124 | 125 | - `examples/run_classification.py` uses the Hugging Face Trainer and abstracts much of the training complexities away. 126 | - `examples/run_classification_no_trainer.py` uses Hugging Face Accelerate and allows for more customization. 127 | 128 | ```sh 129 | python run_classification.py 130 | ``` 131 | 132 | ## Dataset 133 | 134 | It is currently set up for two datasets with long texts: [health_fact](https://huggingface.co/datasets/health_fact) and [arxiv-classification](https://huggingface.co/datasets/ccdv/arxiv-classification). 
`data.py` contains a data module and processors to make it relatively simple to load in any dataset. 135 | 136 | ## Optional packages 137 | 138 | ```text 139 | mlflow 140 | wandb 141 | sentencepiece 142 | ``` 143 | -------------------------------------------------------------------------------- /examples/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import importlib 4 | from typing import Optional, Any, Union, Dict 5 | 6 | 7 | import torch 8 | from omegaconf import OmegaConf 9 | from transformers import TrainingArguments 10 | 11 | from accelerate.tracking import GeneralTracker 12 | from accelerate.logging import get_logger 13 | 14 | logger = get_logger(__name__) 15 | 16 | 17 | def set_wandb_env_vars(cfg): 18 | """ 19 | Set environment variables from the config dict object. 20 | The environment variables can be picked up by wandb in Trainer. 21 | """ 22 | 23 | os.environ["WANDB_ENTITY"] = getattr(cfg.wandb, "entity", "") 24 | os.environ["WANDB_PROJECT"] = getattr(cfg.wandb, "project", "") 25 | os.environ["WANDB_RUN_GROUP"] = getattr(cfg.wandb, "group", "") 26 | os.environ["WANDB_JOB_TYPE"] = getattr(cfg.wandb, "job_type", "") 27 | os.environ["WANDB_NOTES"] = getattr(cfg.wandb, "notes", "") 28 | if cfg.wandb.tags: 29 | os.environ["WANDB_TAGS"] = ",".join(cfg.wandb.tags) 30 | 31 | 32 | def set_mlflow_env_vars(cfg): 33 | """ 34 | Set environment variables from the config object. 35 | The environment variables can be picked up by mlflow in Trainer. 36 | """ 37 | 38 | os.environ["MLFLOW_NESTED_RUN"] = cfg.mlflow.nested_run or "" 39 | os.environ["MLFLOW_EXPERIMENT_NAME"] = cfg.mlflow.experiment_name or "" 40 | os.environ["MLFLOW_FLATTEN_PARAMS"] = cfg.mlflow.flatten_params or "" 41 | os.environ["MLFLOW_RUN_ID"] = cfg.mlflow.run_id or "" 42 | os.environ["MLFLOW_TAGS"] = json.dumps(OmegaConf.to_container(cfg.mlflow.tags)) 43 | 44 | 45 | training_args_to_log = { 46 | "per_device_train_batch_size", 47 | "per_device_eval_batch_size", 48 | "gradient_accumulation_steps", 49 | "evaluation_strategy", 50 | "eval_delay", 51 | "learning_rate", 52 | "weight_decay", 53 | "adam_beta1", 54 | "adam_beta2", 55 | "adam_epsilon", 56 | "max_grad_norm", 57 | "num_train_epochs", 58 | "max_steps", 59 | "lr_scheduler_type", 60 | "warmup_ratio", 61 | "warmup_steps", 62 | "log_level", 63 | "logging_dir", 64 | "logging_strategy", 65 | "logging_steps", 66 | "save_strategy", 67 | "save_steps", 68 | "save_total_limit", 69 | "seed", 70 | "bf16", 71 | "fp16", 72 | "tf32", 73 | "dataloader_drop_last", 74 | "eval_steps", 75 | "dataloader_num_workers", 76 | "metric_for_best_model", 77 | "greater_is_better", 78 | "label_smoothing_factor", 79 | "optim", 80 | "adafactor", 81 | "group_by_length", 82 | "resume_from_checkpoint", 83 | "gradient_checkpointing", 84 | } 85 | 86 | if importlib.util.find_spec("mlflow") is not None: 87 | import mlflow 88 | 89 | 90 | class MLflowTracker(GeneralTracker): 91 | """ 92 | A `Tracker` class that supports `mlflow`. Should be initialized at the start of your script. 93 | Args: 94 | experiment_name (`str`): 95 | Name of the experiment. 96 | Environment variable MLFLOW_EXPERIMENT_NAME has priority over this argument. 97 | logging_dir (`str`, `os.PathLike`): 98 | Location for mlflow logs to be stored. 99 | run_id (`str`): 100 | If specified, get the run with the specified UUID and log parameters and metrics under that run. 
101 | The run’s end time is unset and its status is set to running, but the run’s other attributes 102 | (source_version, source_type, etc.) are not changed. 103 | Environment variable MLFLOW_RUN_ID has priority over this argument. 104 | tags (`dict`, `str`): 105 | An optional `dict` of `str` keys and values, or a `str` dump from a `dict`, 106 | to set as tags on the run. If a run is being resumed, these tags are set on the resumed run. 107 | If a new run is being created, these tags are set on the new run. 108 | Environment variable MLFLOW_TAGS has priority over this argument. 109 | nested_run (`bool`): 110 | Controls whether run is nested in parent run. True creates a nested run. 111 | Environment variable MLFLOW_NESTED_RUN has priority over this argument. 112 | run_name (`str`): 113 | Name of new run (stored as a mlflow.runName tag). Used only when run_id is unspecified. 114 | description (`str`): 115 | An optional string that populates the description box of the run. 116 | If a run is being resumed, the description is set on the resumed run. 117 | If a new run is being created, the description is set on the new run. 118 | """ 119 | 120 | name = "mlflow" 121 | requires_logging_directory = True 122 | 123 | def __init__( 124 | self, 125 | experiment_name: Optional[str] = None, 126 | logging_dir: Optional[Union[str, os.PathLike]] = ".", 127 | run_id: Optional[str] = None, 128 | tags: Optional[Union[Dict[str, Any], str]] = None, 129 | nested_run: Optional[bool] = False, 130 | run_name: Optional[str] = None, 131 | description: Optional[str] = None, 132 | ): 133 | 134 | self._MAX_PARAM_VAL_LENGTH = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH 135 | self._MAX_PARAMS_TAGS_PER_BATCH = ( 136 | mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH 137 | ) 138 | 139 | experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", experiment_name) 140 | run_id = os.getenv("MLFLOW_RUN_ID", run_id) 141 | tags = os.getenv("MLFLOW_TAGS", tags) 142 | if isinstance(tags, str): 143 | tags = json.loads(tags) 144 | 145 | nested_run = os.getenv("MLFLOW_NESTED_RUN", nested_run) 146 | 147 | if mlflow.active_run() and not nested_run: 148 | raise ValueError("Detected active run. `nested_run` must be set to True.") 149 | 150 | if mlflow.get_experiment_by_name(experiment_name): 151 | experiment_id = mlflow.set_experiment( 152 | experiment_name, 153 | ).experiment_id 154 | mlflow.set_experiment_tags(tags) 155 | 156 | else: 157 | experiment_id = mlflow.create_experiment( 158 | name=experiment_name, 159 | artifact_location=logging_dir, 160 | tags=tags, 161 | ) 162 | 163 | self.active_run = mlflow.start_run( 164 | run_id=run_id, 165 | experiment_id=experiment_id, 166 | run_name=run_name, 167 | nested=nested_run, 168 | tags=tags, 169 | description=description, 170 | ) 171 | 172 | logger.debug(f"Initialized mlflow experiment {experiment_name}") 173 | logger.debug( 174 | "Make sure to log any initial configurations with `self.store_init_configuration` before training!" 175 | ) 176 | 177 | @property 178 | def tracker(self): 179 | return self.active_run 180 | 181 | def save_training_args(self, training_args: TrainingArguments): 182 | """ 183 | To be used in Azure ML. Since the log limit is 100 parameters, 184 | this method should be called first to save the training arguments to a 185 | json file. The method then returns a filtered down dict that can be 186 | logged as parameters when passed as a configuration argument to 187 | `init_trackers`. 188 | Args: 189 | training_arguments (`TrainingArguments`): 190 | TrainingArguments to save. 
191 | Returns: 192 | `dict` of the key/value pairs to be logged as parameters. 193 | """ 194 | mlflow.log_dict(training_args.to_dict(), "training_arguments.json") 195 | 196 | return { 197 | k: v 198 | for k, v in training_args.to_dict().items() 199 | if k in training_args_to_log 200 | } 201 | 202 | def store_init_configuration(self, values: dict): 203 | """ 204 | Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. 205 | Args: 206 | values (`dict`): 207 | Values to be stored as initial hyperparameters as key-value pairs. 208 | """ 209 | 210 | for name, value in list(values.items()): 211 | # internally, all values are converted to str in MLflow 212 | if len(str(value)) > self._MAX_PARAM_VAL_LENGTH: 213 | logger.warning( 214 | f'Attempting to log a value of "{value}" for key "{name}" as a parameter. MLflow\'s' 215 | f" log_param() only accepts values shorter than {self._MAX_PARAM_VAL_LENGTH} characters so we dropped this attribute." 216 | ) 217 | del values[name] 218 | 219 | values_list = list(values.items()) 220 | 221 | if os.getenv("AML_CloudName") == "AzureCloud": 222 | values_list = values_list[:100] 223 | 224 | # MLflow cannot log more than 100 values in one go, so we have to split it 225 | for i in range(0, len(values_list), self._MAX_PARAMS_TAGS_PER_BATCH): 226 | mlflow.log_params( 227 | dict(values_list[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]) 228 | ) 229 | 230 | logger.debug("Stored initial configuration hyperparameters to MLflow") 231 | 232 | def log(self, values: dict, step: Optional[int]): 233 | """ 234 | Logs `values` to the current run. 235 | Args: 236 | values (`dict`): 237 | Values to be logged as key-value pairs. 238 | step (`int`, *optional*): 239 | The run step. If included, the log will be affiliated with this step. 240 | """ 241 | metrics = {} 242 | for k, v in values.items(): 243 | if isinstance(v, (int, float)): 244 | metrics[k] = v 245 | else: 246 | logger.warning( 247 | f'MLflowTracker is attempting to log a value of "{v}" of type {type(v)} for key "{k}" as a metric. ' 248 | "MLflow's log_metric() only accepts float and int types so we dropped this attribute." 249 | ) 250 | 251 | mlflow.log_metrics(metrics, step=step) 252 | logger.debug("Successfully logged to mlflow") 253 | 254 | def finish(self): 255 | """ 256 | End the active MLflow run. 257 | """ 258 | mlflow.end_run() 259 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /src/strideformer/model.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Strideformer Model """ 2 | import os 3 | from pathlib import Path 4 | from typing import Optional, Tuple, Union, Iterator 5 | 6 | import torch 7 | from torch import nn 8 | from transformers import PreTrainedModel, AutoModel, PretrainedConfig, AutoConfig 9 | from transformers.modeling_outputs import ModelOutput 10 | 11 | from .config import StrideformerConfig 12 | 13 | 14 | class StrideformerOutput(ModelOutput): 15 | """ 16 | Base class for outputs of Strideformer model. 17 | Args: 18 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): 19 | Classification (or regression if config.num_labels==1) loss. 20 | logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): 21 | Classification (or regression if config.num_labels==1) scores (before SoftMax). 22 | first_model_hidden_states (`torch.FloatTensor`, *optional*, returned when `output_hidden_states=True` is passed): 23 | Last hidden states of the first model, of shape 24 | `(num_chunks, sequence_length, hidden_size)`, where `num_chunks` is the 25 | number of strided chunks passed through the first model. 26 | second_model_hidden_states (`torch.FloatTensor`, *optional*, returned when `output_hidden_states=True` is passed): 27 | Last hidden states of the second model, of shape 28 | `(batch_size, chunks_per_batch, hidden_size)`, with one embedding per 29 | chunk after the chunk-level transformer. 30 | """ 31 | 32 | loss: torch.FloatTensor = None 33 | logits: torch.FloatTensor = None 34 | first_model_hidden_states: Optional[torch.FloatTensor] = None 35 | second_model_hidden_states: Optional[torch.FloatTensor] = None 36 | 37 | 38 | class Strideformer(PreTrainedModel): 39 | 40 | config_class = StrideformerConfig 41 | supports_gradient_checkpointing = False 42 | 43 | def __init__( 44 | self, 45 | config: StrideformerConfig, 46 | ) -> None: 47 | """ 48 | Initializes Strideformer model with random values. 49 | Use `from_pretrained` to load pretrained weights. 50 | 51 | Args: 52 | config (StrideformerConfig): 53 | Configuration file for this model.
Holds the name of the 54 | pretrained model to use for the first model and the hyperparameters 55 | for the second model (number of layers, attention heads, etc.). 56 | """ 57 | super().__init__( 58 | config, 59 | ) 60 | self.config = config 61 | self.first_model = AutoModel.from_config(self.config.first_model_config) 62 | 63 | self.max_chunks = config.max_chunks 64 | 65 | encoder_layer = nn.TransformerEncoderLayer( 66 | d_model=self.config.first_model_config.hidden_size, 67 | nhead=config.num_attention_heads, 68 | dim_feedforward=config.intermediate_size, 69 | dropout=config.dropout, 70 | activation=config.hidden_act, 71 | layer_norm_eps=config.layer_norm_eps, 72 | batch_first=True, 73 | ) 74 | self.second_model = nn.TransformerEncoder( 75 | encoder_layer, num_layers=config.num_hidden_layers 76 | ) 77 | self.classifier = nn.Linear(self.config.first_model_config.hidden_size, config.num_labels) 78 | 79 | self.post_init() 80 | 81 | self.first_model = AutoModel.from_pretrained( 82 | self.config.first_model_config._name_or_path 83 | ) 84 | 85 | @staticmethod 86 | def mean_pooling( 87 | token_embeddings: torch.FloatTensor, 88 | attention_mask: Optional[torch.FloatTensor] = None, 89 | ) -> torch.FloatTensor: 90 | """ 91 | Mean pool across the `sequence_length` dimension. Assumes that token 92 | embeddings have shape `(batch_size, sequence_length, hidden_size)`. 93 | If batched, there can be pad tokens in the sequence. 94 | This will ignore padded outputs when doing mean pooling by using 95 | `attention_mask`. 96 | 97 | Args: 98 | token_embeddings (`torch.FloatTensor`): 99 | Embeddings to be averaged across the sequence dimension. 100 | attention_mask (`torch.LongTensor`, *optional*): 101 | Attention mask for the embeddings. Used to ignore 102 | pad tokens from the averaging. 103 | 104 | Returns: 105 | `torch.FloatTensor` of shape `(batch_size, hidden_size)` that is 106 | `token_embeddings` averaged across the sequence dimension. 107 | """ 108 | 109 | if attention_mask is not None: 110 | input_mask_expanded = ( 111 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 112 | ) 113 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( 114 | input_mask_expanded.sum(1), min=1e-9 115 | ) 116 | 117 | # No attention mask provided: every position is valid, so a plain 118 | # mean over the sequence dimension gives the pooled embedding 119 | # for each chunk. 120 | return token_embeddings.mean(dim=1) 121 | 122 | def forward( 123 | self, 124 | input_ids: Optional[torch.Tensor] = None, 125 | attention_mask: Optional[torch.Tensor] = None, 126 | token_type_ids: Optional[torch.Tensor] = None, 127 | labels: Optional[torch.Tensor] = None, 128 | batch_size: Optional[int] = 1, 129 | output_hidden_states: Optional[bool] = None, 130 | return_dict: Optional[bool] = None, 131 | ) -> Union[Tuple[torch.Tensor], StrideformerOutput]: 132 | """ 133 | Args: 134 | input_ids (`torch.Tensor`, *optional*, defaults to None): 135 | Indices of input sequence tokens in the vocabulary. These can be created 136 | using the corresponding tokenizer for the first model. 137 | Shape is `(num_chunks, sequence_length)` where `num_chunks` is `(batch_size*chunks_per_batch)`. 138 | token_type_ids (`torch.Tensor`, *optional*, defaults to None): 139 | Some models take token_type_ids. This comes from the tokenizer and gets 140 | passed to the first model. 141 | labels (`torch.FloatTensor` or `torch.Tensor`, *optional*, defaults to None): 142 | The true values. Used for loss calculation. 143 | Shape is `(batch_size, num_classes)` if multilabel, 144 | `(batch_size, 1)` for multiclass or regression.
145 | batch_size (`int`, *optional*, defaults to 1): 146 | If passing batched inputs, this specifies the shape of input for the second model. 147 | The first model will get input `(num_chunks, sequence_length)` where `num_chunks` 148 | is `(batch_size*chunks_per_batch)`. The output of the first model is `(num_chunks, hidden_size)`. 149 | This gets reshaped to `(batch_size, chunks_per_batch, hidden_size)`. This means that 150 | all document sequences must be tokenized to the same number of chunks. 151 | Returns: 152 | A `tuple` of `torch.Tensor` if `return_dict` is `False`. 153 | A `StrideformerOutput` object if `return_dict` is `None` or `True`. 154 | These containers hold values for loss, logits, and last hidden states for 155 | both models. 156 | """ 157 | return_dict = ( 158 | return_dict if return_dict is not None else self.config.use_return_dict 159 | ) 160 | 161 | token_type_ids = ( 162 | {"token_type_ids": token_type_ids} if token_type_ids is not None else {} 163 | ) 164 | 165 | if self.config.freeze_first_model: 166 | # No gradients, no training, save memory 167 | with torch.no_grad(): 168 | first_model_hidden_states = self.first_model( 169 | input_ids=input_ids, 170 | attention_mask=attention_mask, 171 | **token_type_ids, 172 | )[0] 173 | else: 174 | first_model_hidden_states = self.first_model( 175 | input_ids=input_ids, 176 | attention_mask=attention_mask, 177 | **token_type_ids, 178 | )[0] 179 | 180 | # mean pool last hidden state 181 | embeddings = self.mean_pooling( 182 | first_model_hidden_states, attention_mask=attention_mask 183 | ) 184 | 185 | second_model_hidden_states = self.second_model( 186 | embeddings.reshape(batch_size, -1, self.config.hidden_size), 187 | ) # [batch_size, chunks_per_batch, hidden_size] 188 | 189 | # The classifier scores each chunk embedding; the chunk logits are then averaged into a single prediction.
190 | second_model_chunk_logits = self.classifier( 191 | second_model_hidden_states 192 | ) # [batch_size, chunks_per_batch, num_labels] 193 | logits = second_model_chunk_logits.mean(dim=1) # [batch_size, num_labels] 194 | 195 | loss = None 196 | if labels is not None: 197 | 198 | if self.config.problem_type is None: 199 | if self.config.num_labels == 1: 200 | self.config.problem_type = "regression" 201 | elif self.config.num_labels > 1 and ( 202 | labels.dtype == torch.long or labels.dtype == torch.int 203 | ): 204 | self.config.problem_type = "single_label_classification" 205 | else: 206 | self.config.problem_type = "multi_label_classification" 207 | 208 | if self.config.problem_type == "regression": 209 | loss_fct = nn.MSELoss() 210 | if self.config.num_labels == 1: 211 | loss = loss_fct(logits.squeeze(), labels.squeeze()) 212 | else: 213 | loss = loss_fct(logits, labels) 214 | elif self.config.problem_type == "single_label_classification": 215 | loss_fct = nn.CrossEntropyLoss() 216 | loss = loss_fct( 217 | logits.view(-1, self.config.num_labels), labels.view(-1) 218 | ) 219 | elif self.config.problem_type == "multi_label_classification": 220 | loss_fct = nn.BCEWithLogitsLoss() 221 | loss = loss_fct(logits, labels) 222 | 223 | if not output_hidden_states: 224 | first_model_hidden_states = None 225 | second_model_hidden_states = None 226 | 227 | if not return_dict: 228 | output = (logits, first_model_hidden_states, second_model_hidden_states) 229 | return ((loss,) + output) if loss is not None else output 230 | 231 | return StrideformerOutput( 232 | loss=loss, 233 | logits=logits, 234 | first_model_hidden_states=first_model_hidden_states, 235 | second_model_hidden_states=second_model_hidden_states, 236 | ) 237 | 238 | 239 | # From: https://github.com/huggingface/transformers/blob/31e3c6c393964d3e470c8229741b444d7ce94941/src/transformers/models/bert/modeling_bert.py#L748 240 | def _init_weights(self, module): 241 | """Initialize the weights""" 242 | if isinstance(module, nn.Linear): 243 | # Slightly different from the TF version which uses truncated_normal for initialization 244 | # cf https://github.com/pytorch/pytorch/pull/5617 245 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 246 | if module.bias is not None: 247 | module.bias.data.zero_() 248 | elif isinstance(module, nn.Embedding): 249 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 250 | if module.padding_idx is not None: 251 | module.weight.data[module.padding_idx].zero_() 252 | elif isinstance(module, nn.LayerNorm): 253 | module.bias.data.zero_() 254 | module.weight.data.fill_(1.0) 255 | -------------------------------------------------------------------------------- /examples/data.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from functools import partial 3 | from typing import Dict, Optional, Tuple, List 4 | from collections import defaultdict 5 | from dataclasses import dataclass 6 | 7 | from datasets import load_dataset, Dataset 8 | from transformers import AutoTokenizer, PreTrainedTokenizer 9 | from omegaconf import OmegaConf, DictConfig 10 | 11 | 12 | @dataclass 13 | class DataModule: 14 | """ 15 | Responsible for loading and tokenizing the dataset. 16 | Can handle local files or datasets from the Hugging Face Hub. 17 | 18 | If no dataset is specified, it will load a local file. 19 | Otherwise, it will load a dataset from the Hub. 
20 | """ 21 | 22 | 23 | cfg: DictConfig = None 24 | 25 | def __post_init__(self): 26 | if self.cfg is None: 27 | raise ValueError("DataModule requires a Hydra/OmegaConf config (cfg).") 28 | 29 | self.tokenizer = AutoTokenizer.from_pretrained( 30 | self.cfg.model.model_name_or_path, 31 | ) 32 | 33 | def prepare_dataset(self) -> None: 34 | """ 35 | Load in dataset and tokenize. 36 | 37 | If debugging, take small subset of the full dataset. 38 | """ 39 | 40 | if self.cfg.data.dataset_name is None: 41 | processor = LocalFileProcessor( 42 | cfg=self.cfg, 43 | tokenizer=self.tokenizer, 44 | ) 45 | else: 46 | processor = name2processor[self.cfg.data.dataset_name]( 47 | cfg=self.cfg, 48 | tokenizer=self.tokenizer, 49 | ) 50 | 51 | self.raw_dataset, self.tokenized_dataset = processor.prepare_dataset() 52 | 53 | self.label2id = processor.get_label2id() 54 | self.id2label = {i: l for l, i in self.label2id.items()} 55 | 56 | def get_dataset(self, split: str, tokenized: bool = True) -> Dataset: 57 | ds_attr = "tokenized_dataset" if tokenized else "raw_dataset" 58 | return getattr(self, ds_attr)[split] 59 | 60 | 61 | class GenericDatasetProcessor: 62 | """ 63 | Can load any dataset from the Hub. 64 | """ 65 | 66 | def __init__(self, cfg: DictConfig, tokenizer: PreTrainedTokenizer) -> None: 67 | super().__init__() 68 | 69 | self.cfg = cfg 70 | self.tokenizer = tokenizer 71 | 72 | def set_label2id(self, train_dataset: Dataset) -> None: 73 | 74 | labels = train_dataset.unique(self.cfg.data.label_col) 75 | labels = sorted(labels) 76 | 77 | self.label2id = {label: i for i, label in enumerate(labels)} 78 | 79 | def get_label2id(self) -> Dict[str, int]: 80 | return self.label2id 81 | 82 | def prepare_dataset(self) -> Tuple[Dataset, Dataset]: 83 | raw_dataset = load_dataset( 84 | self.cfg.data.dataset_name, self.cfg.data.dataset_config_name 85 | ) 86 | 87 | # Limit the number of rows, if desired 88 | if self.cfg.data.n_rows is not None and self.cfg.data.n_rows > 0: 89 | for split in raw_dataset: 90 | max_split_samples = min(self.cfg.data.n_rows, len(raw_dataset[split])) 91 | raw_dataset[split] = raw_dataset[split].select(range(max_split_samples)) 92 | 93 | cols = raw_dataset["train"].column_names 94 | self.set_label2id(raw_dataset["train"]) 95 | tokenized_dataset = raw_dataset.map( 96 | partial( 97 | tokenize, 98 | tokenizer=self.tokenizer, 99 | max_length=self.cfg.data.max_seq_length, 100 | stride=self.cfg.data.stride, 101 | text_col=self.cfg.data.text_col, 102 | text_pair_col=self.cfg.data.text_pair_col, 103 | label_col=self.cfg.data.label_col, 104 | ), 105 | batched=self.cfg.data.stride in {None, 0}, 106 | batch_size=self.cfg.data.map_batch_size, 107 | num_proc=self.cfg.num_proc, 108 | remove_columns=cols, 109 | ) 110 | 111 | return raw_dataset, tokenized_dataset 112 | 113 | 114 | class PubHealthProcessor: 115 | """ 116 | Data processor for the PUBHEALTH dataset.
117 | https://huggingface.co/datasets/health_fact 118 | """ 119 | 120 | def __init__(self, cfg: DictConfig, tokenizer: PreTrainedTokenizer) -> None: 121 | super().__init__() 122 | 123 | self.cfg = cfg 124 | self.tokenizer = tokenizer 125 | 126 | @staticmethod 127 | def get_label2id() -> Dict[str, int]: 128 | return {"false": 0, "mixture": 1, 129 | "true": 2, "unproven": 3} 130 | 131 | def prepare_dataset(self) -> Tuple[Dataset, Dataset]: 132 | 133 | raw_dataset = load_dataset( 134 | self.cfg.data.dataset_name, self.cfg.data.dataset_config_name 135 | ) 136 | 137 | # Ignore examples that have bad labels 138 | raw_dataset = raw_dataset.filter(lambda x: x["label"] != -1) 139 | 140 | # Limit the number of rows, if desired 141 | if self.cfg.data.n_rows is not None and self.cfg.data.n_rows > 0: 142 | for split in raw_dataset: 143 | max_split_samples = min(self.cfg.data.n_rows, len(raw_dataset[split])) 144 | raw_dataset[split] = raw_dataset[split].select(range(max_split_samples)) 145 | 146 | cols = raw_dataset["train"].column_names 147 | 148 | tokenized_dataset = raw_dataset.map( 149 | partial( 150 | tokenize, 151 | tokenizer=self.tokenizer, 152 | max_length=self.cfg.data.max_seq_length, 153 | stride=self.cfg.data.stride, 154 | text_col=self.cfg.data.text_col, 155 | text_pair_col=self.cfg.data.text_pair_col, 156 | label_col=self.cfg.data.label_col, 157 | ), 158 | batched=self.cfg.data.stride in {None, 0}, 159 | batch_size=self.cfg.data.map_batch_size, 160 | num_proc=self.cfg.num_proc, 161 | remove_columns=cols, 162 | ) 163 | 164 | return raw_dataset, tokenized_dataset 165 | 166 | 167 | class ArxivProcessor: 168 | 169 | """ 170 | Data processor for the arxiv dataset. 171 | https://huggingface.co/datasets/ccdv/arxiv-classification 172 | 173 | Task: Given an arxiv paper, predict the subject area. 174 | """ 175 | 176 | 177 | def __init__(self, cfg: DictConfig, tokenizer: PreTrainedTokenizer) -> None: 178 | super().__init__() 179 | 180 | self.cfg = cfg 181 | self.tokenizer = tokenizer 182 | 183 | @staticmethod 184 | def get_label2id() -> Dict[str, int]: 185 | 186 | labels = [ 187 | "math.AC", 188 | "cs.CV", 189 | "cs.AI", 190 | "cs.SY", 191 | "math.GR", 192 | "cs.CE", 193 | "cs.PL", 194 | "cs.IT", 195 | "cs.DS", 196 | "cs.NE", 197 | "math.ST", 198 | ] 199 | 200 | return {label: i for i, label in enumerate(labels)} 201 | 202 | def prepare_dataset(self) -> Tuple[Dataset, Dataset]: 203 | 204 | raw_dataset = load_dataset( 205 | self.cfg.data.dataset_name, self.cfg.data.dataset_config_name 206 | ) 207 | 208 | # Limit the number of rows, if desired 209 | if self.cfg.data.n_rows is not None and self.cfg.data.n_rows > 0: 210 | for split in raw_dataset: 211 | max_split_samples = min(self.cfg.data.n_rows, len(raw_dataset[split])) 212 | raw_dataset[split] = raw_dataset[split].select(range(max_split_samples)) 213 | 214 | cols = raw_dataset["train"].column_names 215 | 216 | tokenized_dataset = raw_dataset.map( 217 | partial( 218 | tokenize, 219 | tokenizer=self.tokenizer, 220 | max_length=self.cfg.data.max_seq_length, 221 | stride=self.cfg.data.stride, 222 | text_col=self.cfg.data.text_col, 223 | label_col=self.cfg.data.label_col, 224 | ), 225 | batched=self.cfg.data.stride in {None, 0}, 226 | batch_size=self.cfg.data.map_batch_size, 227 | num_proc=self.cfg.num_proc, 228 | remove_columns=cols, 229 | ) 230 | 231 | return raw_dataset, tokenized_dataset 232 | 233 | 234 | class LocalFileProcessor: 235 | """ 236 | Can load csv, json, or parquet files that are on local storage. 
237 | """ 238 | 239 | def __init__(self, cfg: DictConfig, tokenizer: PreTrainedTokenizer) -> None: 240 | super().__init__() 241 | 242 | self.cfg = cfg 243 | self.tokenizer = tokenizer 244 | 245 | def prepare_dataset(self) -> Tuple[Dataset, Dataset]: 246 | 247 | # get ending of file (csv, json, parquet) 248 | filetype = Path(self.cfg.data.data_files["train"][0]).suffix 249 | filetype = filetype.lstrip(".") # suffix keeps period 250 | 251 | if filetype not in {"csv", "json", "parquet"}: 252 | raise ValueError(f"Files should end in 'csv', 'json', or 'parquet', not {filetype}.") 253 | 254 | data_files = OmegaConf.to_container(self.cfg.data.data_files) 255 | raw_dataset = load_dataset(filetype, data_files=data_files) 256 | 257 | # Limit the number of rows, if desired 258 | if self.cfg.data.n_rows is not None and self.cfg.data.n_rows > 0: 259 | for split in raw_dataset: 260 | max_split_samples = min(self.cfg.data.n_rows, len(raw_dataset[split])) 261 | raw_dataset[split] = raw_dataset[split].select(range(max_split_samples)) 262 | 263 | cols = raw_dataset["train"].column_names 264 | self.set_label2id(raw_dataset["train"]) 265 | 266 | raw_dataset = raw_dataset.map( 267 | partial( 268 | change_label_to_int, 269 | label2id=self.label2id, 270 | label_col=self.cfg.data.label_col, 271 | ), 272 | batched=True, 273 | num_proc=self.cfg.num_proc, 274 | ) 275 | 276 | 277 | tokenized_dataset = raw_dataset.map( 278 | partial( 279 | tokenize, 280 | tokenizer=self.tokenizer, 281 | max_length=self.cfg.data.max_seq_length, 282 | stride=self.cfg.data.stride, 283 | text_col=self.cfg.data.text_col, 284 | text_pair_col=self.cfg.data.text_pair_col, 285 | label_col=self.cfg.data.label_col, 286 | ), 287 | batched=self.cfg.data.stride in {None, 0}, 288 | batch_size=self.cfg.data.map_batch_size, 289 | num_proc=self.cfg.num_proc, 290 | remove_columns=cols, 291 | ) 292 | 293 | return raw_dataset, tokenized_dataset 294 | 295 | def set_label2id(self, train_dataset: Dataset) -> None: 296 | 297 | labels = train_dataset.unique(self.cfg.data.label_col) 298 | labels = sorted(labels) 299 | 300 | self.label2id = {label: i for i, label in enumerate(labels)} 301 | 302 | def get_label2id(self) -> Dict[str, int]: 303 | """ 304 | Must be called after `set_label2id` 305 | """ 306 | return self.label2id 307 | 308 | 309 | def tokenize( 310 | examples, 311 | tokenizer, 312 | max_length, 313 | stride=None, 314 | text_col="text", 315 | text_pair_col=None, 316 | label_col="label", 317 | ) -> Dict[str, List[int]]: 318 | 319 | tokenizer_kwargs = { 320 | "padding": False, 321 | } 322 | 323 | # If stride is not None, using sbert approach 324 | if stride is not None and stride > 0: 325 | tokenizer_kwargs.update( 326 | { 327 | "padding": True, 328 | "stride": stride, 329 | "return_overflowing_tokens": True, 330 | } 331 | ) 332 | 333 | texts = [examples[text_col]] 334 | if text_pair_col is not None: 335 | texts.append(examples[text_pair_col]) 336 | 337 | tokenized = tokenizer( 338 | *texts, 339 | truncation=True, 340 | max_length=max_length, 341 | **tokenizer_kwargs, 342 | ) 343 | 344 | tokenized["labels"] = examples[label_col] 345 | 346 | return tokenized 347 | 348 | 349 | def change_label_to_int(example: Dict[str, str], label2id: Dict[str, int], label_col: str): 350 | 351 | if isinstance(example[label_col], list): 352 | return { 353 | label_col: [label2id[l] for l in example[label_col]] 354 | } 355 | return { 356 | label_col: label2id[example[label_col]] 357 | } 358 | 359 | 360 | 361 | 362 | # map dataset name to processor 363 | name2processor = 
defaultdict(lambda: GenericDatasetProcessor) 364 | 365 | name2processor.update( 366 | { 367 | "health_fact": PubHealthProcessor, 368 | "ccdv/arxiv-classification": ArxivProcessor, 369 | } 370 | ) 371 | -------------------------------------------------------------------------------- /examples/run_classification_no_trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Finetuning short model chunks for long sequence classification""" 17 | 18 | import logging 19 | import os 20 | import math 21 | from itertools import chain 22 | 23 | from tqdm.auto import tqdm 24 | from sklearn.metrics import f1_score 25 | import torch 26 | from torch.utils.data import DataLoader 27 | import hydra 28 | from omegaconf import DictConfig, OmegaConf 29 | import datasets 30 | import transformers 31 | from transformers import ( 32 | AutoConfig, 33 | AutoModelForSequenceClassification, 34 | TrainingArguments, 35 | set_seed, 36 | get_scheduler, 37 | DataCollatorWithPadding, 38 | ) 39 | from transformers.trainer_pt_utils import IterableDatasetShard 40 | from transformers.utils import flatten_dict 41 | from accelerate import Accelerator 42 | from accelerate.logging import get_logger 43 | 44 | from utils import ( 45 | set_wandb_env_vars, 46 | set_mlflow_env_vars, 47 | ) 48 | from strideformer import Strideformer, StrideformerConfig, StrideformerCollator 49 | from data import DataModule 50 | 51 | 52 | logger = get_logger(__name__) 53 | 54 | 55 | def training_loop(accelerator, model, optimizer, lr_scheduler, dataloaders, args): 56 | 57 | progress_bar = tqdm( 58 | range(args.max_steps), 59 | disable=not accelerator.is_local_main_process, 60 | desc="Training", 61 | ) 62 | completed_steps = 0 63 | 64 | for epoch in range(args.num_train_epochs): 65 | 66 | model.train() 67 | 68 | if args.report_to is not None: 69 | epoch_loss = 0 70 | 71 | example_count = 0 72 | for step, batch in enumerate(dataloaders["train"]): 73 | 74 | with accelerator.accumulate(model): 75 | outputs = model(**batch) 76 | loss = outputs["loss"] 77 | 78 | example_count += batch["labels"].numel() 79 | 80 | # We keep track of the loss at each epoch 81 | if args.report_to is not None: 82 | epoch_loss += loss.detach().float() * batch["labels"].numel() 83 | 84 | loss = loss / args.gradient_accumulation_steps 85 | accelerator.backward(loss) 86 | 87 | optimizer.step() 88 | lr_scheduler.step() 89 | optimizer.zero_grad() 90 | 91 | if (step+1) % args.gradient_accumulation_steps == 0: 92 | progress_bar.update(1) 93 | completed_steps += 1 94 | 95 | if args.save_strategy == "steps": 96 | checkpointing_steps = args.save_steps 97 | if completed_steps % checkpointing_steps == 0: 98 | output_dir = f"step_{completed_steps }" 99 | if args.output_dir is not None: 100 | output_dir = os.path.join(args.output_dir, output_dir) 101 | 
accelerator.save_state(output_dir) 102 | 103 | if ( 104 | args.logging_strategy == "steps" 105 | and (step + 1) % args.logging_steps == 0 106 | ): 107 | details = { 108 | "train_loss": epoch_loss.item() / example_count, 109 | "step": completed_steps, 110 | } 111 | accelerator.log( 112 | details, 113 | step=completed_steps, 114 | ) 115 | logging.info(details) 116 | progress_bar.set_description( 117 | f"Training: Epoch {epoch} | Loss {details['train_loss']} | Step {details['step']}" 118 | ) 119 | 120 | if completed_steps >= args.max_steps: 121 | break 122 | 123 | if args.do_eval: 124 | 125 | eval_metrics = eval_loop( 126 | accelerator=accelerator, 127 | model=model, 128 | dataloader=dataloaders["validation"], 129 | prefix="eval", 130 | ) 131 | 132 | eval_metrics.update( 133 | { 134 | "epoch": epoch, 135 | "step": completed_steps, 136 | } 137 | ) 138 | 139 | accelerator.log(eval_metrics, step=completed_steps) 140 | logging.info(eval_metrics) 141 | 142 | if completed_steps >= args.max_steps: 143 | break 144 | 145 | return completed_steps 146 | 147 | 148 | @torch.no_grad() 149 | def eval_loop(accelerator, model, dataloader, prefix): 150 | 151 | model.eval() 152 | 153 | y_preds = [] 154 | y_true = [] 155 | eval_loss = 0 156 | count = 0 157 | 158 | num_steps = len(dataloader) 159 | 160 | progress_bar = tqdm( 161 | range(num_steps), 162 | disable=not accelerator.is_local_main_process, 163 | desc=f"{prefix}", 164 | ) 165 | 166 | for batch in dataloader: 167 | 168 | outputs = model(**batch) 169 | 170 | eval_loss += outputs["loss"].detach().float() 171 | 172 | # TODO: gather for distributed training 173 | 174 | preds = outputs["logits"].argmax(-1).detach().cpu().tolist() 175 | labels = batch["labels"].detach().cpu().tolist() 176 | count += len(labels) 177 | 178 | y_preds.append(preds) 179 | y_true.append(labels) 180 | 181 | progress_bar.update(1) 182 | 183 | if isinstance(y_preds[0], list): 184 | y_preds = list(chain(*y_preds)) 185 | y_true = list(chain(*y_true)) 186 | 187 | f1_micro = f1_score(y_true, y_preds, average="micro") 188 | f1_macro = f1_score(y_true, y_preds, average="macro") 189 | 190 | metrics = { 191 | "loss": eval_loss.item() / count, 192 | "f1_micro": f1_micro, 193 | "f1_macro": f1_macro, 194 | } 195 | 196 | metrics_desc = " | ".join( 197 | [f"{name.capitalize()} {round(score, 4)}" for name, score in metrics.items()] 198 | ) 199 | 200 | progress_bar.set_description(f"{prefix}: {metrics_desc}") 201 | 202 | for name in metrics.keys(): 203 | if not name.startswith(prefix): 204 | metrics[f"{prefix}_{name}"] = metrics.pop(name) 205 | 206 | return metrics 207 | 208 | 209 | def num_examples(dataloader: DataLoader) -> int: 210 | """ 211 | Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. 
When 212 | dataloader.dataset does not exist or has no length, estimates as best it can 213 | """ 214 | dataset = dataloader.dataset 215 | # Special case for IterableDatasetShard, we need to dig deeper 216 | if isinstance(dataset, IterableDatasetShard): 217 | return len(dataloader.dataset.dataset) 218 | return len(dataloader.dataset) 219 | 220 | 221 | def train_samples_steps_epochs(dataloader, args): 222 | # TODO: make use of num_train_samples and num_train_epochs 223 | 224 | total_train_batch_size = ( 225 | args.train_batch_size * args.gradient_accumulation_steps * args.world_size 226 | ) 227 | 228 | num_update_steps_per_epoch = math.ceil(len(dataloader) / args.gradient_accumulation_steps) 229 | 230 | if args.max_steps > 0: 231 | max_steps = args.max_steps 232 | num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( 233 | args.max_steps % num_update_steps_per_epoch > 0 234 | ) 235 | # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's 236 | # the best we can do. 237 | num_train_samples = args.max_steps * total_train_batch_size 238 | else: 239 | max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) 240 | num_train_epochs = math.ceil(args.num_train_epochs) 241 | num_train_samples = num_examples(dataloader) * args.num_train_epochs 242 | 243 | args.max_steps = max_steps 244 | 245 | return num_train_samples, max_steps, num_train_epochs 246 | 247 | 248 | def get_optimizer_and_scheduler(model, train_dataloader, args): 249 | 250 | # Split weights in two groups, one with weight decay and the other not. 251 | no_decay = ["bias", "LayerNorm.weight"] 252 | optimizer_grouped_parameters = [ 253 | { 254 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 255 | "weight_decay": args.weight_decay, 256 | }, 257 | { 258 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 259 | "weight_decay": 0.0, 260 | }, 261 | ] 262 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) 263 | 264 | num_train_samples, max_steps, num_train_epochs = train_samples_steps_epochs( 265 | train_dataloader, args 266 | ) 267 | 268 | lr_scheduler = get_scheduler( 269 | name=args.lr_scheduler_type, 270 | optimizer=optimizer, 271 | num_warmup_steps=args.get_warmup_steps(max_steps), 272 | num_training_steps=max_steps, 273 | ) 274 | 275 | return optimizer, lr_scheduler 276 | 277 | 278 | @hydra.main(version_base=None, config_path="conf", config_name="config") 279 | def main(cfg: DictConfig) -> None: 280 | 281 | train_args = dict(cfg.training_arguments) 282 | training_args = TrainingArguments(**train_args) 283 | 284 | mixed_precision = "no" 285 | if training_args.fp16: 286 | mixed_precision = "fp16" 287 | if training_args.bf16: 288 | mixed_precision = "bf16" 289 | 290 | if "wandb" in training_args.report_to: 291 | set_wandb_env_vars(cfg) 292 | if "mlflow" in training_args.report_to: 293 | set_mlflow_env_vars(cfg) 294 | 295 | accelerator = Accelerator( 296 | mixed_precision=mixed_precision, 297 | log_with=training_args.report_to, 298 | gradient_accumulation_steps=training_args.gradient_accumulation_steps, 299 | ) 300 | 301 | log_level = training_args.get_process_log_level() 302 | # Make one log on every process with the configuration for debugging. 
303 | logging.basicConfig( 304 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 305 | datefmt="%m/%d/%Y %H:%M:%S", 306 | level=log_level, 307 | ) 308 | logger.info(accelerator.state, main_process_only=False) 309 | 310 | 311 | if log_level != -1: 312 | logger.setLevel(log_level) 313 | 314 | if accelerator.is_local_main_process: 315 | if log_level == -1: 316 | datasets.utils.logging.set_verbosity_warning() 317 | transformers.utils.logging.set_verbosity_info() 318 | else: 319 | datasets.utils.logging.set_verbosity(log_level) 320 | transformers.utils.logging.set_verbosity(log_level) 321 | else: 322 | datasets.utils.logging.set_verbosity_error() 323 | transformers.utils.logging.set_verbosity_error() 324 | 325 | # Log on each process the small summary: 326 | logger.warning( 327 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " 328 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 329 | ) 330 | logger.info(f"Training/evaluation parameters {training_args}") 331 | 332 | # Set seed before initializing model. 333 | set_seed(training_args.seed) 334 | 335 | data_module = DataModule(cfg) 336 | 337 | # If running distributed, this runs dataset preparation on the main process first, 338 | # blocking the replicas; when it finishes, the replicas are released. 339 | with training_args.main_process_first(desc="Dataset loading and tokenization"): 340 | data_module.prepare_dataset() 341 | 342 | 343 | use_strideformer = cfg.data.stride is not None and cfg.data.stride > 0 344 | 345 | if use_strideformer: 346 | collator = StrideformerCollator(tokenizer=data_module.tokenizer) 347 | 348 | # batch_sizes must always be 1 when using strided approach 349 | if ( 350 | training_args.per_device_train_batch_size != 1 351 | or training_args.per_device_eval_batch_size != 1 352 | ): 353 | logger.warning( 354 | "Batch size must be 1 when using strided approach. Changing to 1 now."
355 | ) 356 | training_args.per_device_train_batch_size = 1 357 | training_args.per_device_eval_batch_size = 1 358 | 359 | else: 360 | collator = DataCollatorWithPadding( 361 | tokenizer=data_module.tokenizer, pad_to_multiple_of=cfg.data.pad_multiple 362 | ) 363 | 364 | dataloaders = {} 365 | 366 | dataloaders["train"] = DataLoader( 367 | data_module.get_dataset("train"), 368 | batch_size=training_args.per_device_train_batch_size, 369 | collate_fn=collator, 370 | drop_last=training_args.dataloader_drop_last, 371 | num_workers=training_args.dataloader_num_workers, 372 | pin_memory=training_args.dataloader_pin_memory, 373 | ) 374 | 375 | if training_args.do_eval: 376 | 377 | dataloaders["validation"] = DataLoader( 378 | data_module.get_dataset("validation"), 379 | batch_size=training_args.per_device_eval_batch_size, 380 | collate_fn=collator, 381 | num_workers=training_args.dataloader_num_workers, 382 | pin_memory=training_args.dataloader_pin_memory, 383 | ) 384 | 385 | if training_args.do_predict: 386 | 387 | dataloaders["test"] = DataLoader( 388 | data_module.get_dataset("test"), 389 | batch_size=training_args.per_device_eval_batch_size, 390 | collate_fn=collator, 391 | num_workers=training_args.dataloader_num_workers, 392 | pin_memory=training_args.dataloader_pin_memory, 393 | ) 394 | 395 | 396 | model_config = AutoConfig.from_pretrained( 397 | cfg.model.model_name_or_path, 398 | label2id=data_module.label2id, 399 | id2label=data_module.id2label, 400 | num_labels=len(data_module.label2id), 401 | ) 402 | if use_strideformer: 403 | second_model_config = dict( 404 | freeze_first_model=cfg.model.freeze_first_model, 405 | max_chunks=cfg.model.max_chunks, 406 | num_hidden_layers=cfg.model.num_hidden_layers, 407 | num_attention_heads=cfg.model.num_attention_heads, 408 | intermediate_size=cfg.model.intermediate_size, 409 | hidden_act=cfg.model.hidden_act, 410 | dropout=cfg.model.dropout, 411 | layer_norm_eps=cfg.model.layer_norm_eps, 412 | initializer_range=cfg.model.initializer_range, 413 | hidden_size=model_config.hidden_size, 414 | label2id=data_module.label2id, 415 | id2label=data_module.id2label, 416 | num_labels=len(data_module.label2id), 417 | ) 418 | model_config = StrideformerConfig.from_two_configs( 419 | first_model_config=model_config, 420 | second_model_config=second_model_config, 421 | ) 422 | 423 | if use_strideformer: 424 | model = Strideformer( 425 | config=model_config 426 | ) 427 | else: 428 | model = AutoModelForSequenceClassification.from_pretrained( 429 | cfg.model.model_name_or_path, config=model_config 430 | ) 431 | 432 | optimizer, lr_scheduler = get_optimizer_and_scheduler( 433 | model, dataloaders["train"], training_args 434 | ) 435 | 436 | items2prepare = [ 437 | model, 438 | optimizer, 439 | lr_scheduler, 440 | dataloaders["train"], 441 | ] 442 | 443 | if training_args.do_eval: 444 | items2prepare.append(dataloaders["validation"]) 445 | 446 | if training_args.do_predict: 447 | items2prepare.append(dataloaders["test"]) 448 | 449 | model, optimizer, lr_scheduler, *dataloaders_list = accelerator.prepare( 450 | *items2prepare 451 | ) 452 | 453 | dataloaders["train"] = dataloaders_list.pop(0) 454 | 455 | if training_args.do_eval: 456 | dataloaders["validation"] = dataloaders_list.pop(0) 457 | 458 | if training_args.do_predict: 459 | dataloaders["test"] = dataloaders_list.pop(0) 460 | 461 | # We need to initialize the trackers we use, and also store our configuration. 
462 | # We initialize the trackers only on main process because `accelerator.log` 463 | # only logs on main process and we don't want empty logs/runs on other processes. 464 | if training_args.report_to: 465 | if accelerator.is_main_process: 466 | experiment_config = OmegaConf.to_container(cfg) 467 | experiment_config.update(model.config.to_diff_dict()) 468 | 469 | accelerator.init_trackers(cfg.project_name, flatten_dict(experiment_config)) 470 | 471 | 472 | completed_steps = training_loop( 473 | accelerator, 474 | model, 475 | optimizer, 476 | lr_scheduler, 477 | dataloaders, 478 | training_args, 479 | ) 480 | 481 | if training_args.do_predict: 482 | 483 | test_metrics = eval_loop( 484 | accelerator=accelerator, 485 | model=model, 486 | dataloader=dataloaders["test"], 487 | prefix="test", 488 | ) 489 | 490 | accelerator.log(test_metrics, step=completed_steps) 491 | 492 | logging.info(test_metrics) 493 | 494 | accelerator.end_training() 495 | 496 | 497 | if __name__ == "__main__": 498 | main() 499 | --------------------------------------------------------------------------------
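Usage note (not a file in the repository): the sketch below shows one way the pieces above fit together for a single long document, outside of the Hydra-driven training scripts. It is a minimal, untested illustration; the chunking settings (max_length=256, stride=128), the 4-label output space, and the small second-model hyperparameters are assumptions loosely mirroring examples/conf/config.yaml, and the keyword arguments for StrideformerConfig.from_two_configs are inferred from the call in examples/run_classification_no_trainer.py.

import torch
from transformers import AutoConfig, AutoTokenizer
from strideformer import Strideformer, StrideformerConfig

first_model_name = "sentence-transformers/all-MiniLM-L12-v2"
tokenizer = AutoTokenizer.from_pretrained(first_model_name)

# Split one long document into overlapping chunks, mirroring examples/data.py:
# padding keeps the chunks the same length, and return_overflowing_tokens turns
# the overflow into extra rows of input_ids.
encoded = tokenizer(
    "a very long document " * 2000,
    max_length=256,
    stride=128,
    truncation=True,
    padding=True,
    return_overflowing_tokens=True,
    return_tensors="pt",
)

first_config = AutoConfig.from_pretrained(first_model_name)

# Hyperparameters for the chunk-level (second) transformer; num_hidden_layers is
# kept small here purely for the sketch.
second_model_config = dict(
    freeze_first_model=False,
    max_chunks=128,
    num_hidden_layers=2,
    num_attention_heads=12,
    intermediate_size=2048,
    hidden_act="gelu",
    dropout=0.1,
    layer_norm_eps=1e-7,
    initializer_range=0.02,
    hidden_size=first_config.hidden_size,
    num_labels=4,
)

config = StrideformerConfig.from_two_configs(
    first_model_config=first_config,
    second_model_config=second_model_config,
)

# Strideformer loads the pretrained first model by name and randomly initializes
# the second model and the classifier head.
model = Strideformer(config=config)
model.eval()

with torch.no_grad():
    outputs = model(
        input_ids=encoded["input_ids"],            # (num_chunks, sequence_length)
        attention_mask=encoded["attention_mask"],
        batch_size=1,                              # all chunks belong to one document
    )

print(outputs.logits.shape)  # expected: (1, num_labels)

During training, the examples instead feed tokenized chunks through StrideformerCollator and a DataLoader with per-device batch size 1, as shown in examples/run_classification_no_trainer.py.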