├── requirements.txt ├── .idea └── .gitignore ├── bert-rna-model.json ├── data └── family_to_index.txt ├── family_name_to_id.py ├── rna_k_mer_tokenizer.py ├── readme.txt ├── plot_metrics.py ├── plot_dataset.py ├── make_k_mers.py ├── custom_data_collator.py ├── finetune.py ├── custom_classification_head.py └── run_mlm.py /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | torch >= 1.3 3 | datasets >= 1.8.0 4 | sentencepiece != 0.1.92 5 | protobuf 6 | evaluate 7 | scikit-learn 8 | git+https://github.com/huggingface/transformers 9 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /bert-rna-model.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "bert-rna", 3 | "architectures": [ 4 | "BertForMaskedLM" 5 | ], 6 | "attention_probs_dropout_prob": 0.1, 7 | "hidden_act": "gelu", 8 | "hidden_dropout_prob": 0.1, 9 | "hidden_size": 768, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 3072, 12 | "max_position_embeddings": 1024, 13 | "num_attention_heads": 12, 14 | "num_hidden_layers": 6, 15 | "layer_norm_eps": 1e-12, 16 | "pad_token_id": 3, 17 | "type_vocab_size": 2, 18 | "use_cache": true, 19 | "vocab_size": 8278, 20 | "num_labels": 31 21 | 22 | } -------------------------------------------------------------------------------- /data/family_to_index.txt: -------------------------------------------------------------------------------- 1 | {"RF00001.fa.csv": "0", "RF00003.fa.csv": "1", "RF00004.fa.csv": "2", "RF00005.fa.csv": "3", "RF00007.fa.csv": "4", "RF00017.fa.csv": "5", "RF00019.fa.csv": "6", "RF00026.fa.csv": "7", "RF00029.fa.csv": "8", "RF00032.fa.csv": "9", "RF00059.fa.csv": "10", "RF00097.fa.csv": "11", "RF00100.fa.csv": "12", "RF00163.fa.csv": "13", "RF00174.fa.csv": "14", "RF00177.fa.csv": "15", "RF00230.fa.csv": "16", "RF00436.fa.csv": "17", "RF00906.fa.csv": "18", "RF00994.fa.csv": "19", "RF01315.fa.csv": "20", "RF01317.fa.csv": "21", "RF01787.fa.csv": "22", "RF01960.fa.csv": "23", "RF02271.fa.csv": "24", "RF02541.fa.csv": "25", "RF02543.fa.csv": "26", "RF04021.fa.csv": "27", "RF04088.fa.csv": "28", "human_mrna.csv": "29", "virus_mrna.fasta.csv": "30"} -------------------------------------------------------------------------------- /family_name_to_id.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import ast 5 | 6 | 7 | df_train = pd.read_csv("data/corrected_sequences_X_train.csv") 8 | df_train = df_train.dropna() 9 | df_validation = pd.read_csv("data/corrected_sequences_X_val.csv") 10 | df_validation = df_validation.dropna() 11 | df_test = pd.read_csv("data/corrected_sequences_X_test.csv") 12 | df_test = df_test.dropna() 13 | 14 | df_all = pd.concat([df_train, df_validation, df_test]) 15 | 16 | family_to_index = {label:str(i) for i,label in enumerate(np.unique(df_all['label'].to_list()))} 17 | index_to_family = {str(i):label for i,label in enumerate(np.unique(df_all['label'].to_list()))} 18 | num_labels = 
len(family_to_index.keys()) 19 | print("Number of labels:", num_labels) 20 | 21 | print(family_to_index) 22 | print(index_to_family) 23 | 24 | -------------------------------------------------------------------------------- /rna_k_mer_tokenizer.py: -------------------------------------------------------------------------------- 1 | from tokenizers import Tokenizer 2 | from tokenizers.models import WordLevel 3 | from tokenizers import normalizers 4 | from tokenizers.normalizers import Lowercase, NFD, StripAccents 5 | from tokenizers.pre_tokenizers import Whitespace 6 | from tokenizers.trainers import WordLevelTrainer 7 | from tokenizers.processors import TemplateProcessing 8 | 9 | import glob 10 | import os 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | 14 | 15 | bert_tokenizer = Tokenizer(WordLevel(unk_token="[UNK]")) 16 | bert_tokenizer.pre_tokenizer = Whitespace() 17 | bert_tokenizer.post_processor = TemplateProcessing( 18 | single="[CLS] $A [SEP]", 19 | special_tokens=[("[CLS]", 1), ("[SEP]", 2)], 20 | ) 21 | trainer = WordLevelTrainer(special_tokens=["[PAD]", "[CLS]", "[SEP]","[UNK]", "[MASK]"]) 22 | bert_tokenizer.train(["./data/pd_6_mer_pretrain.txt"], trainer) 23 | bert_tokenizer.save("./bert-rna-6-mer-tokenizer.json") 24 | 25 | 26 | -------------------------------------------------------------------------------- /readme.txt: -------------------------------------------------------------------------------- 1 | rna_k_mer_tokenizer.py: creates the tokenizer .json file by reading the k-mer pretraining data 2 | 3 | bert-rna-model.json: BERT configuration adapted from an online example; reduced the number of layers and the vocabulary size, and added num_labels 4 | 5 | bert-rna-6-mer-tokenizer.json: output of rna_k_mer_tokenizer.py 6 | 7 | make_k_mers.py: turns nucleotide sequences into overlapping k-mer sequences for a given k 8 | 9 | run_mlm.py: masked language model pretraining. Modified to pretrain from scratch and to read RNA sequence data; default values are updated for this project 10 | 11 | finetune.py: fine-tunes the pretrained model on the family classification task 12 | 13 | plot_metrics.py: reads a checkpoint directory and plots training/evaluation loss and accuracy 14 | 15 | plot_dataset.py: plots the dataset sequence length distribution and family sizes. 
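For orientation, a minimal sketch (not part of the repo) of what the k-mer preprocessing and the trained tokenizer produce; the example sequence is made up, and the sliding window mirrors the range(len(seq) - k_mer) loop in make_k_mers.py:

    from tokenizers import Tokenizer

    def to_k_mers(seq, k=6):
        # overlapping k-mers: lower-cased, stride 1, space-separated (the final window is dropped, as in make_k_mers.py)
        seq = seq.lower()
        return " ".join(seq[i:i + k] for i in range(len(seq) - k))

    text = to_k_mers("AUGGCUACGUAG")
    # -> "auggcu uggcua ggcuac gcuacg cuacgu uacgua"

    tok = Tokenizer.from_file("bert-rna-6-mer-tokenizer.json")
    print(tok.encode(text).tokens)
    # -> ['[CLS]', 'auggcu', ..., '[SEP]'], assuming each 6-mer occurs in the pretraining data; unseen 6-mers map to [UNK]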
16 | 17 | 18 | 19 | 20 | conda create -n CS230 python=3.10 21 | pip install -r requirements.txt 22 | 23 | python run_mlm.py --output_dir ./out_mlm 24 | python run_mlm.py --output_dir ./out_mlm --resume_from_checkpoint ./out_mlm/checkpoint-XXXX 25 | 26 | python run_cls.py --output_dir ./out_cls --model_name_or_path ./out_mlm/ 27 | python run_cls.py --output_dir ./out_cls --resume ./out_cls/checkpoint-XXXX 28 | -------------------------------------------------------------------------------- /plot_metrics.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import matplotlib.mlab as mlab 4 | 5 | 6 | df = pd.read_json("out_finetune_harvar_var_len_250/checkpoint-800/trainer_state.json") 7 | 8 | dict_1 = df.log_history.to_dict() 9 | 10 | df = pd.DataFrame.from_dict(dict_1) 11 | print(df.loc['loss']) 12 | print(df.loc['epoch']) 13 | 14 | df_1 = pd.DataFrame({'loss':df.loc['loss'], 'epoch':df.loc['epoch']}) 15 | df_1 = df_1.dropna() 16 | 17 | df_2 = pd.DataFrame({'eval_loss':df.loc['eval_loss'], 'epoch':df.loc['epoch']}) 18 | df_2 = df_2.dropna() 19 | 20 | 21 | #plt.ylim([0,1]) 22 | #plt.xlim([0,20]) 23 | plt.xlabel("epochs") 24 | plt.plot(df_1["epoch"], df_1['loss'], color='red', label='train loss') 25 | plt.plot(df_2["epoch"], df_2['eval_loss'], color = 'green', 26 | linestyle = 'solid', marker = 'o', 27 | markerfacecolor = 'green', markersize = 4, label='eval loss') 28 | plt.plot(df.loc["epoch"].values, df.loc["eval_accuracy"].values, color = 'blue', 29 | linestyle = 'dashed', marker = '*', 30 | markerfacecolor = 'green', markersize = 4, label='Eval accuracy') 31 | plt.legend(loc='best') 32 | 33 | plt.savefig("loss.png", dpi=300) 34 | plt.show() 35 | 36 | 37 | -------------------------------------------------------------------------------- /plot_dataset.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import matplotlib as mpl 4 | import matplotlib.pyplot as plt 5 | import matplotlib.mlab as mlab 6 | 7 | df_train = pd.read_csv("data/harvar_var_len/len_250_X_train_1.csv") 8 | #df_validation = pd.read_csv("data/corrected_sequences_X_val.csv") 9 | #df_test = pd.read_csv("data/corrected_sequences_X_test.csv") 10 | 11 | #print(df_train.head()) 12 | 13 | #labelss = df_train['label'].unique() 14 | labelss = np.unique(df_train['label'].to_list()) 15 | datas = [] 16 | data_lens = [] 17 | for lbl in labelss: 18 | data = df_train[df_train['label'] == lbl]['len'].values 19 | datas.append(data) 20 | 21 | data_lens.append(data.size) 22 | 23 | 24 | print(labelss) 25 | 26 | 27 | labelss = [str.replace(x, "fa.csv", "") for x in labelss] 28 | labelss = [str.replace(x, "fasta.csv", "") for x in labelss] 29 | labelss = [str.replace(x, "RF", "") for x in labelss] 30 | fig, ax = plt.subplots() 31 | fig.set_figwidth(15) 32 | fig.set_figheight(10) 33 | 34 | fig.subplots_adjust(bottom=0.2) 35 | mpl.rcParams['boxplot.boxprops.color'] = 'blue' 36 | mpl.rcParams['boxplot.flierprops.color'] = 'blue' 37 | mpl.rcParams['boxplot.whiskerprops.color'] = 'blue' 38 | box = ax.boxplot(datas) 39 | 40 | 41 | 42 | pos = np.arange(len(labelss)) + 1 43 | ax.set_xticks(pos, labels=labelss, rotation=90, fontsize=18) 44 | plt.yticks(fontsize=20) 45 | 46 | ax2 = ax.twinx() 47 | 48 | #ax2.plot(pos, data_lens, linestyle = 'dashed', marker = '*', 49 | # markerfacecolor = 'red', markersize = 8, label='family size') 50 | ax2.plot(pos, data_lens, 'ro', markersize = 8, label='family 
size') 51 | ax.set_ylabel('sequence length distribution', color='blue', fontsize=20) 52 | ax2.set_ylabel('family size', color='red', fontsize=20) 53 | ax.set_xlabel('families', fontsize=20) 54 | 55 | 56 | 57 | plt.show() 58 | plt.savefig("dataset_len_distribution.png") 59 | -------------------------------------------------------------------------------- /make_k_mers.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import ast 5 | import json 6 | 7 | k_mer = 6 8 | split = "_1" 9 | seq_len = 50 10 | 11 | train_file_name = f"data/harvar_var_len/len_{seq_len}_X_train{split}.csv" 12 | k_mer_train_file_name = f"data/harvar_var_len/{k_mer}_mer_{seq_len}_train{split}.csv" 13 | 14 | val_file_name = f"data/harvar_var_len/len_{seq_len}_X_val{split}.csv" 15 | k_mer_val_file_name = f"data/harvar_var_len/{k_mer}_mer_{seq_len}_val{split}.csv" 16 | 17 | test_file_name = f"data/harvar_var_len/len_{seq_len}_X_test{split}.csv" 18 | k_mer_test_file_name = f"data/harvar_var_len/{k_mer}_mer_{seq_len}_test{split}.csv" 19 | 20 | #f_pretrain = open("data/pd_k_mer_pretrain.txt", "w") 21 | ############################################################################################ 22 | # 23 | # 24 | ########################################################################################### 25 | df_train = pd.read_csv(train_file_name) 26 | df_train = df_train[pd.notna(df_train['corrected_sequences'])] 27 | 28 | labels = [] 29 | seqs = [] 30 | for seq, family in zip(df_train['corrected_sequences'].to_list(), df_train['label'].to_list()): 31 | 32 | line2 = "" 33 | line = [] 34 | #print(seq) 35 | for i in range(len(seq)-k_mer): 36 | if seq[i].lower() not in ['c', 'a', 'g', 't', 'u', 'n', 'w', 'r', 'k', 'm', 'y', 's', 'v', 'h', 'd', 'b', '\n']: 37 | print(seq[i], "NOT IN VOCAB:", line) 38 | k_mer_word = "" 39 | for j in range(k_mer): 40 | k_mer_word = k_mer_word + seq[i+j].lower() 41 | line2 = line2 + str(k_mer_word) + " " 42 | line.append(k_mer_word) 43 | 44 | labels.append(family) 45 | seqs.append(line) 46 | #f_pretrain.write(line2+'\n') 47 | 48 | #print(seqs) 49 | df_train_out = pd.DataFrame({"label":labels, "text":seqs}) 50 | df_train_out.to_csv(k_mer_train_file_name, index=False) 51 | 52 | 53 | ############################################################################################ 54 | # 55 | # 56 | ########################################################################################### 57 | df_validation = pd.read_csv(val_file_name) 58 | df_validation = df_validation[pd.notna(df_validation['corrected_sequences'])] 59 | 60 | labels = [] 61 | seqs = [] 62 | 63 | for seq, family in zip(df_validation['corrected_sequences'].to_list(), df_validation['label'].to_list()): 64 | 65 | line2 = "" 66 | line = [] 67 | #print(seq) 68 | for i in range(len(seq)-k_mer): 69 | if seq[i].lower() not in ['c', 'a', 'g', 't', 'u', 'n', 'w', 'r', 'k', 'm', 'y', 's', 'v', 'h', 'd', 'b', '\n']: 70 | print(seq[i], "NOT IN VOCAB:", line) 71 | k_mer_word = "" 72 | for j in range(k_mer): 73 | k_mer_word = k_mer_word + seq[i+j].lower() 74 | line2 = line2 + str(k_mer_word) + " " 75 | line.append(k_mer_word) 76 | 77 | labels.append(family) 78 | seqs.append(line) 79 | #f_pretrain.write(line2+'\n') 80 | 81 | df_val_out = pd.DataFrame({"label":labels, "text":seqs}) 82 | df_val_out.to_csv(k_mer_val_file_name, index=False) 83 | ############################################################################################ 84 | # 
85 | # 86 | ########################################################################################### 87 | df_test = pd.read_csv(test_file_name) 88 | df_test = df_test[pd.notna(df_test['corrected_sequences'])] 89 | #df_test = df_test.dropna() 90 | labels = [] 91 | seqs = [] 92 | 93 | for seq, family in zip(df_test['corrected_sequences'].to_list(), df_test['label'].to_list()): 94 | 95 | line2 = "" 96 | line = [] 97 | #print(seq) 98 | for i in range(len(seq)-k_mer): 99 | if seq[i].lower() not in ['c', 'a', 'g', 't', 'u', 'n', 'w', 'r', 'k', 'm', 'y', 's', 'v', 'h', 'd', 'b', '\n']: 100 | print(seq[i], "NOT IN VOCAB:", line) 101 | k_mer_word = "" 102 | for j in range(k_mer): 103 | k_mer_word = k_mer_word + seq[i+j].lower() 104 | line2 = line2 + str(k_mer_word) + " " 105 | line.append(k_mer_word) 106 | 107 | labels.append(family) 108 | seqs.append(line) 109 | #f_pretrain.write(line2+'\n') 110 | 111 | df_test_out = pd.DataFrame({"label":labels, "text":seqs}) 112 | df_test_out.to_csv(k_mer_test_file_name, index=False) 113 | ############################################################################################ 114 | # 115 | # 116 | ########################################################################################### 117 | 118 | #f_pretrain.close() 119 | 120 | df_train = pd.read_csv(k_mer_train_file_name) 121 | df_validation = pd.read_csv(k_mer_val_file_name) 122 | df_test = pd.read_csv(k_mer_test_file_name) 123 | 124 | df_all = pd.concat([df_train, df_validation, df_test]) 125 | #df_all = pd.concat([df_train, df_test]) 126 | 127 | with open('data/family_to_index.txt') as f: 128 | data = f.read() 129 | family_to_index = json.loads(data) 130 | 131 | 132 | df_train = df_train.rename(columns={"label": "label_text"}) 133 | df_validation = df_validation.rename(columns={"label": "label_text"}) 134 | df_test = df_test.rename(columns={"label": "label_text"}) 135 | 136 | df_train['label'] = df_train['label_text'].apply(lambda x: family_to_index[x]) 137 | df_validation['label'] = df_validation['label_text'].apply(lambda x: family_to_index[x]) 138 | df_test['label'] = df_test['label_text'].apply(lambda x: family_to_index[x]) 139 | 140 | 141 | df_train.to_csv(k_mer_train_file_name, index=False) 142 | df_validation.to_csv(k_mer_val_file_name, index=False) 143 | df_test.to_csv(k_mer_test_file_name, index=False) 144 | -------------------------------------------------------------------------------- /custom_data_collator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import warnings 3 | import random 4 | import torch 5 | from collections.abc import Mapping 6 | from transformers import DataCollatorForLanguageModeling 7 | from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union 8 | from transformers import BertModel, BertConfig, BertTokenizer, BertTokenizerFast, BertForMaskedLM 9 | 10 | def tolist(x): 11 | if isinstance(x, list): 12 | return x 13 | elif hasattr(x, "numpy"): # Checks for TF tensors without needing the import 14 | x = x.numpy() 15 | return x.tolist() 16 | 17 | def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None): 18 | """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary.""" 19 | import torch 20 | 21 | # Tensorize if necessary. 
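    # (This helper appears to mirror the private transformers.data.data_collator._torch_collate_batch utility, presumably copied here so the custom collator below is self-contained.)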
22 | if isinstance(examples[0], (list, tuple, np.ndarray)): 23 | examples = [torch.tensor(e, dtype=torch.long) for e in examples] 24 | 25 | length_of_first = examples[0].size(0) 26 | 27 | # Check if padding is necessary. 28 | 29 | are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) 30 | if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0): 31 | return torch.stack(examples, dim=0) 32 | 33 | # If yes, check if we have a `pad_token`. 34 | if tokenizer._pad_token is None: 35 | raise ValueError( 36 | "You are attempting to pad samples but the tokenizer you are using" 37 | f" ({tokenizer.__class__.__name__}) does not have a pad token." 38 | ) 39 | 40 | # Creating the full tensor and filling it with our data. 41 | max_length = max(x.size(0) for x in examples) 42 | if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): 43 | max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of 44 | result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id) 45 | for i, example in enumerate(examples): 46 | if tokenizer.padding_side == "right": 47 | result[i, : example.shape[0]] = example 48 | else: 49 | result[i, -example.shape[0] :] = example 50 | return result 51 | 52 | 53 | class RnaDataCollator(DataCollatorForLanguageModeling): 54 | 55 | def __call__(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: 56 | # Handle dict or lists with proper padding and conversion to tensor. 57 | if isinstance(examples[0], Mapping): 58 | batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of) 59 | else: 60 | batch = { 61 | "input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of) 62 | } 63 | 64 | # If special token mask has been preprocessed, pop it from the dict. 65 | special_tokens_mask = batch.pop("special_tokens_mask", None) 66 | if self.mlm: 67 | batch["input_ids"], batch["labels"] = self.torch_mask_tokens( 68 | batch["input_ids"], special_tokens_mask=special_tokens_mask 69 | ) 70 | else: 71 | labels = batch["input_ids"].clone() 72 | if self.tokenizer.pad_token_id is not None: 73 | labels[labels == self.tokenizer.pad_token_id] = -100 74 | batch["labels"] = labels 75 | return batch 76 | 77 | 78 | def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]: 79 | """ 80 | Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. 
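        Unlike the stock DataCollatorForLanguageModeling, each sampled mask position is then expanded by
        torch_mask_k_tokens(6, ...) to also cover the following k-1 positions, so whole runs of overlapping
        6-mer tokens are hidden together, which makes it much harder to reconstruct a masked k-mer from its
        unmasked overlapping neighbours. The effective masking rate is therefore roughly k times
        mlm_probability (0.06 by default in run_mlm.py).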
81 | """ 82 | import torch 83 | 84 | labels = inputs.clone() 85 | # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`) 86 | probability_matrix = torch.full(labels.shape, self.mlm_probability) 87 | if special_tokens_mask is None: 88 | special_tokens_mask = [ 89 | self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() 90 | ] 91 | special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool) 92 | else: 93 | special_tokens_mask = special_tokens_mask.bool() 94 | 95 | probability_matrix.masked_fill_(special_tokens_mask, value=0.0) 96 | masked_indices = torch.bernoulli(probability_matrix).bool() 97 | masked_indices = self.torch_mask_k_tokens(6, masked_indices) 98 | labels[~masked_indices] = -100 # We only compute loss on masked tokens 99 | 100 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 101 | indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices 102 | inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) 103 | 104 | # 10% of the time, we replace masked input tokens with random word 105 | indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced 106 | random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long) 107 | inputs[indices_random] = random_words[indices_random] 108 | 109 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 110 | return inputs, labels 111 | 112 | def torch_mask_k_tokens(self, k, masked_indices): 113 | 114 | shifted_masks = masked_indices 115 | for i in range(k-1): 116 | temp = torch.roll(masked_indices, i+1, 1) 117 | temp[:,0] = 0 118 | shifted_masks = torch.add(shifted_masks, temp) 119 | 120 | return shifted_masks 121 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | import sys 5 | import ast 6 | import json 7 | from dataclasses import dataclass, field 8 | from typing import Optional 9 | from torchmetrics.classification import MulticlassAccuracy 10 | import datasets 11 | import numpy as np 12 | from datasets import load_dataset, Dataset 13 | import sklearn 14 | import evaluate 15 | import transformers 16 | import pandas as pd 17 | from transformers import ( 18 | AutoConfig, 19 | AutoModelForSequenceClassification, 20 | AutoTokenizer, 21 | PreTrainedTokenizerFast, 22 | DataCollatorWithPadding, 23 | EvalPrediction, 24 | HfArgumentParser, 25 | PretrainedConfig, 26 | Trainer, 27 | TrainingArguments, 28 | default_data_collator, 29 | set_seed, 30 | BertForSequenceClassification, 31 | ) 32 | from transformers.trainer_utils import get_last_checkpoint 33 | from transformers.utils import check_min_version, send_example_telemetry 34 | from transformers.utils.versions import require_version 35 | from custom_classification_head import BertForCustomClassification 36 | 37 | 38 | #df_test = pd.read_csv("data/pd_6_mer_test_1.csv") 39 | #df_test = pd.DataFrame(df_test) 40 | #raw_datasets_test = Dataset.from_pandas(df_test) 41 | 42 | data_files = {"train": "data/harvar_var_len/6_mer_100_train_1.csv", "validation": "data/harvar_var_len/6_mer_100_val_1.csv", "test": "data/harvar_var_len/6_mer_100_test_1.csv"} 43 | raw_datasets = load_dataset("csv", data_files=data_files) 44 | 45 | 46 | 
tokenizer = PreTrainedTokenizerFast(tokenizer_file="bert-rna-6-mer-tokenizer.json") 47 | tokenizer.pad_token = "[PAD]" 48 | tokenizer.cls_token = "[CLS]" 49 | 50 | def preprocess_function(examples): 51 | #print(examples['text'][0]) 52 | #print(str.replace(examples['text'][0]," ", "")) 53 | #print(ast.literal_eval(examples['text'][0])) 54 | examples['text'] = [str.replace(examples['text'][i],"'", "") for i in range(len(examples['text']))] 55 | examples['text'] = [str.replace(examples['text'][i],",", "") for i in range(len(examples['text']))] 56 | examples['text'] = [str.replace(examples['text'][i],"[", "") for i in range(len(examples['text']))] 57 | examples['text'] = [str.replace(examples['text'][i],"]", "") for i in range(len(examples['text']))] 58 | #text_len = len(examples['text']) 59 | #x_repeat = 512/text_len+1 60 | #tmp_text = [examples['text'] + ' ' for i in range(x_repeat)] 61 | 62 | #print(examples['text'][0]) 63 | return tokenizer( 64 | examples['text'], 65 | padding=True, 66 | truncation=True, 67 | max_length=512, 68 | # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it 69 | # receives the `special_tokens_mask`. 70 | return_special_tokens_mask=True, 71 | ) 72 | 73 | 74 | tokenized_seqs = raw_datasets.map(preprocess_function, batched=True) 75 | 76 | data_collator = DataCollatorWithPadding(tokenizer=tokenizer) 77 | 78 | #model = BertForSequenceClassification.from_pretrained("./out_mlm_6layer_3/checkpoint-360000", num_labels=31) 79 | #model = BertForSequenceClassification.from_pretrained("./out_mlm/checkpoint-10", num_labels=31) 80 | #model = BertForSequenceClassification.from_pretrained("./out_finetune_with_pretrain/checkpoint-12500", num_labels=31) 81 | #model = AutoModelForSequenceClassification.from_pretrained("./out_mlm/checkpoint-360000", num_labels=30) 82 | model = BertForSequenceClassification.from_pretrained("./out_finetune_equal_length_1/checkpoint-5500", num_labels=31) 83 | 84 | 85 | training_args = TrainingArguments( 86 | output_dir="./out_finetune", 87 | learning_rate=5e-5, 88 | per_device_train_batch_size=50, 89 | per_device_eval_batch_size=50, 90 | num_train_epochs=10, 91 | #weight_decay=0.01, 92 | metric_for_best_model="accuracy", 93 | eval_steps= 200, 94 | evaluation_strategy= 'steps', 95 | save_total_limit = 3, 96 | log_level = 'info', 97 | logging_steps = 100, 98 | #dataloader_num_workers = 10, 99 | save_strategy='steps', 100 | save_steps=200, 101 | load_best_model_at_end = True, 102 | ) 103 | 104 | 105 | def compute_metrics2(p: EvalPrediction): 106 | 107 | cls = sklearn.metrics.classification_report(p.label_ids, np.argmax(p.predictions, axis=1), output_dict=True) 108 | return cls 109 | preds = np.argmax(p.predictions, axis=1) 110 | return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} 111 | 112 | metric = MulticlassAccuracy(num_classes=30, average=None) 113 | result = metric(p.predictions, p.label_ids) 114 | return {"result":result} 115 | 116 | trainer = Trainer( 117 | model=model, 118 | args=training_args, 119 | train_dataset=tokenized_seqs["train"], 120 | eval_dataset=tokenized_seqs["validation"], 121 | tokenizer=tokenizer, 122 | compute_metrics=compute_metrics2, 123 | data_collator=data_collator, 124 | ) 125 | 126 | for batch in trainer.get_train_dataloader(): 127 | print(batch) 128 | break 129 | 130 | #trainer.train(resume_from_checkpoint = False) 131 | #result = trainer.train() 132 | #print(json.dumps(result)) 133 | 134 | #df_test_2 = 
pd.DataFrame({'label':tokenized_seqs_test['label'][:]}) 135 | #print(df_test_2['label'].value_counts()) 136 | #print(df_test_2['label'].dtype) 137 | 138 | p = trainer.predict(tokenized_seqs['test']) 139 | print(p) 140 | 141 | 142 | cls = sklearn.metrics.classification_report( p.label_ids, np.argmax(p.predictions, axis=1), 143 | labels = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30], 144 | target_names=["RF00001.fa.csv", "RF00003.fa.csv", "RF00004.fa.csv", "RF00005.fa.csv", 145 | "RF00007.fa.csv", "RF00017.fa.csv", "RF00019.fa.csv", "RF00026.fa.csv", 146 | "RF00029.fa.csv", "RF00032.fa.csv", "RF00059.fa.csv", "RF00097.fa.csv", 147 | "RF00100.fa.csv", "RF00163.fa.csv", "RF00174.fa.csv", "RF00177.fa.csv", 148 | "RF00230.fa.csv", "RF00436.fa.csv", "RF00906.fa.csv", "RF00994.fa.csv", 149 | "RF01315.fa.csv", "RF01317.fa.csv", "RF01787.fa.csv", "RF01960.fa.csv", 150 | "RF02271.fa.csv", "RF02541.fa.csv", "RF02543.fa.csv", "RF04021.fa.csv", 151 | "RF04088.fa.csv", "human_mrna.csv", "virus_mrna.fasta.csv"] 152 | ) 153 | print(cls) 154 | 155 | 156 | #cm = sklearn.metrics.confusion_matrix(p.label_ids, np.argmax(p.predictions, axis=1)) 157 | #print(cm) 158 | -------------------------------------------------------------------------------- /custom_classification_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers import BertPreTrainedModel, BertModel 4 | import logging 5 | from typing import List, Optional, Tuple, Union 6 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 7 | 8 | from transformers.modeling_outputs import ( 9 | BaseModelOutputWithPastAndCrossAttentions, 10 | BaseModelOutputWithPoolingAndCrossAttentions, 11 | CausalLMOutputWithCrossAttentions, 12 | MaskedLMOutput, 13 | MultipleChoiceModelOutput, 14 | NextSentencePredictorOutput, 15 | QuestionAnsweringModelOutput, 16 | SequenceClassifierOutput, 17 | TokenClassifierOutput, 18 | ) 19 | 20 | from transformers.utils import ( 21 | ModelOutput, 22 | add_code_sample_docstrings, 23 | add_start_docstrings, 24 | add_start_docstrings_to_model_forward, 25 | logging, 26 | replace_return_docstrings, 27 | ) 28 | 29 | logger = logging.get_logger(__name__) 30 | 31 | _CHECKPOINT_FOR_DOC = "bert-base-uncased" 32 | _CONFIG_FOR_DOC = "BertConfig" 33 | _TOKENIZER_FOR_DOC = "BertTokenizer" 34 | 35 | # SequenceClassification docstring 36 | _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity" 37 | _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'" 38 | _SEQ_CLASS_EXPECTED_LOSS = 0.01 39 | 40 | 41 | BERT_INPUTS_DOCSTRING = r""" 42 | Args: 43 | input_ids (`torch.LongTensor` of shape `({0})`): 44 | Indices of input sequence tokens in the vocabulary. 45 | Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and 46 | [`PreTrainedTokenizer.__call__`] for details. 47 | [What are input IDs?](../glossary#input-ids) 48 | attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): 49 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 50 | - 1 for tokens that are **not masked**, 51 | - 0 for tokens that are **masked**. 52 | [What are attention masks?](../glossary#attention-mask) 53 | token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): 54 | Segment token indices to indicate first and second portions of the inputs. 
Indices are selected in `[0, 55 | 1]`: 56 | - 0 corresponds to a *sentence A* token, 57 | - 1 corresponds to a *sentence B* token. 58 | [What are token type IDs?](../glossary#token-type-ids) 59 | position_ids (`torch.LongTensor` of shape `({0})`, *optional*): 60 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 61 | config.max_position_embeddings - 1]`. 62 | [What are position IDs?](../glossary#position-ids) 63 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 64 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: 65 | - 1 indicates the head is **not masked**, 66 | - 0 indicates the head is **masked**. 67 | inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): 68 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 69 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 70 | model's internal embedding lookup matrix. 71 | output_attentions (`bool`, *optional*): 72 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 73 | tensors for more detail. 74 | output_hidden_states (`bool`, *optional*): 75 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 76 | more detail. 77 | return_dict (`bool`, *optional*): 78 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 79 | """ 80 | 81 | 82 | class BertForCustomClassification(BertPreTrainedModel): 83 | def __init__(self, config): 84 | super().__init__(config) 85 | self.num_labels = config.num_labels 86 | self.config = config 87 | 88 | self.bert = BertModel(config) 89 | classifier_dropout = ( 90 | config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob 91 | ) 92 | self.dropout = nn.Dropout(classifier_dropout) 93 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 94 | 95 | # Initialize weights and apply final processing 96 | self.post_init() 97 | 98 | 99 | def forward( 100 | self, 101 | input_ids: Optional[torch.Tensor] = None, 102 | attention_mask: Optional[torch.Tensor] = None, 103 | token_type_ids: Optional[torch.Tensor] = None, 104 | position_ids: Optional[torch.Tensor] = None, 105 | head_mask: Optional[torch.Tensor] = None, 106 | inputs_embeds: Optional[torch.Tensor] = None, 107 | labels: Optional[torch.Tensor] = None, 108 | output_attentions: Optional[bool] = None, 109 | output_hidden_states: Optional[bool] = None, 110 | return_dict: Optional[bool] = None, 111 | ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: 112 | r""" 113 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 114 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 115 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 116 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
117 | """ 118 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 119 | 120 | outputs = self.bert( 121 | input_ids, 122 | attention_mask=attention_mask, 123 | token_type_ids=token_type_ids, 124 | position_ids=position_ids, 125 | head_mask=head_mask, 126 | inputs_embeds=inputs_embeds, 127 | output_attentions=output_attentions, 128 | output_hidden_states=output_hidden_states, 129 | return_dict=return_dict, 130 | ) 131 | 132 | last_hidden_state = outputs[0] 133 | cls_representation = last_hidden_state[:,0,:] 134 | pooled_output = self.dropout(cls_representation) 135 | logits = self.classifier(pooled_output) 136 | 137 | loss = None 138 | if labels is not None: 139 | if self.config.problem_type is None: 140 | if self.num_labels == 1: 141 | self.config.problem_type = "regression" 142 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 143 | self.config.problem_type = "single_label_classification" 144 | else: 145 | self.config.problem_type = "multi_label_classification" 146 | 147 | if self.config.problem_type == "regression": 148 | loss_fct = MSELoss() 149 | if self.num_labels == 1: 150 | loss = loss_fct(logits.squeeze(), labels.squeeze()) 151 | else: 152 | loss = loss_fct(logits, labels) 153 | elif self.config.problem_type == "single_label_classification": 154 | loss_fct = CrossEntropyLoss() 155 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 156 | elif self.config.problem_type == "multi_label_classification": 157 | loss_fct = BCEWithLogitsLoss() 158 | loss = loss_fct(logits, labels) 159 | if not return_dict: 160 | output = (logits,) + outputs[2:] 161 | return ((loss,) + output) if loss is not None else output 162 | 163 | return SequenceClassifierOutput( 164 | loss=loss, 165 | logits=logits, 166 | hidden_states=outputs.hidden_states, 167 | attentions=outputs.attentions, 168 | ) 169 | -------------------------------------------------------------------------------- /run_mlm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Team All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) on a text file or a dataset. 18 | 19 | Here is the full list of checkpoints on the hub that can be fine-tuned by this script: 20 | https://huggingface.co/models?filter=fill-mask 21 | """ 22 | # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments. 
23 | 24 | import logging 25 | import math 26 | import os 27 | import sys 28 | from dataclasses import dataclass, field 29 | from itertools import chain 30 | from typing import Optional 31 | from transformers import BertModel, BertConfig, BertTokenizer, BertTokenizerFast, BertForMaskedLM 32 | import datasets 33 | from datasets import load_dataset 34 | from custom_data_collator import RnaDataCollator 35 | 36 | import evaluate 37 | import transformers 38 | from transformers import ( 39 | CONFIG_MAPPING, 40 | MODEL_FOR_MASKED_LM_MAPPING, 41 | AutoConfig, 42 | AutoModelForMaskedLM, 43 | AutoTokenizer, 44 | DataCollatorForLanguageModeling, 45 | HfArgumentParser, 46 | Trainer, 47 | TrainingArguments, 48 | is_torch_tpu_available, 49 | set_seed, 50 | ) 51 | from transformers.trainer_utils import get_last_checkpoint 52 | from transformers.utils import check_min_version, send_example_telemetry 53 | from transformers.utils.versions import require_version 54 | from tokenizers import Tokenizer 55 | 56 | 57 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 58 | check_min_version("4.24.0.dev0") 59 | 60 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") 61 | 62 | logger = logging.getLogger(__name__) 63 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) 64 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 65 | 66 | 67 | @dataclass 68 | class ModelArguments: 69 | """ 70 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 71 | """ 72 | 73 | model_name_or_path: Optional[str] = field( 74 | default=None, 75 | metadata={ 76 | "help": ( 77 | "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch." 78 | ) 79 | }, 80 | ) 81 | model_type: Optional[str] = field( 82 | default=None, 83 | metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, 84 | ) 85 | config_overrides: Optional[str] = field( 86 | default=None, 87 | metadata={ 88 | "help": ( 89 | "Override some existing default config settings when a model is trained from scratch. Example: " 90 | "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" 91 | ) 92 | }, 93 | ) 94 | config_name: Optional[str] = field( 95 | default="./bert-rna-model.json", metadata={"help": "Pretrained config name or path if not the same as model_name"} 96 | ) 97 | tokenizer_name: Optional[str] = field( 98 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 99 | ) 100 | cache_dir: Optional[str] = field( 101 | default=None, 102 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 103 | ) 104 | use_fast_tokenizer: bool = field( 105 | default=True, 106 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 107 | ) 108 | model_revision: str = field( 109 | default="main", 110 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 111 | ) 112 | use_auth_token: bool = field( 113 | default=False, 114 | metadata={ 115 | "help": ( 116 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 117 | "with private models)." 
118 | ) 119 | }, 120 | ) 121 | 122 | def __post_init__(self): 123 | if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None): 124 | raise ValueError( 125 | "--config_overrides can't be used in combination with --config_name or --model_name_or_path" 126 | ) 127 | 128 | 129 | @dataclass 130 | class DataTrainingArguments: 131 | """ 132 | Arguments pertaining to what data we are going to input our model for training and eval. 133 | """ 134 | 135 | dataset_name: Optional[str] = field( 136 | metadata={"help": "The name of the dataset to use (via the datasets library)."}, default=None 137 | ) 138 | dataset_config_name: Optional[str] = field( 139 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 140 | ) 141 | train_file: Optional[str] = field(default='data/pd_k_mer_pretrain.txt', metadata={"help": "The input training data file (a text file)."}) 142 | validation_file: Optional[str] = field( 143 | default=None, 144 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, 145 | ) 146 | overwrite_cache: bool = field( 147 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 148 | ) 149 | validation_split_percentage: Optional[int] = field( 150 | default=2, 151 | metadata={ 152 | "help": "The percentage of the train set used as validation set in case there's no validation split" 153 | }, 154 | ) 155 | max_seq_length: Optional[int] = field( 156 | default=512, 157 | metadata={ 158 | "help": ( 159 | "The maximum total input sequence length after tokenization. Sequences longer " 160 | "than this will be truncated." 161 | ) 162 | }, 163 | ) 164 | preprocessing_num_workers: Optional[int] = field( 165 | default=None, 166 | metadata={"help": "The number of processes to use for the preprocessing."}, 167 | ) 168 | mlm_probability: float = field( 169 | default=0.060, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"} 170 | ) 171 | line_by_line: bool = field( 172 | default=True, 173 | metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."}, 174 | ) 175 | pad_to_max_length: bool = field( 176 | default=True, 177 | metadata={ 178 | "help": ( 179 | "Whether to pad all samples to `max_seq_length`. " 180 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 181 | ) 182 | }, 183 | ) 184 | max_train_samples: Optional[int] = field( 185 | default=None, 186 | metadata={ 187 | "help": ( 188 | "For debugging purposes or quicker training, truncate the number of training examples to this " 189 | "value if set." 190 | ) 191 | }, 192 | ) 193 | max_eval_samples: Optional[int] = field( 194 | default=None, 195 | metadata={ 196 | "help": ( 197 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 198 | "value if set." 
199 | ) 200 | }, 201 | ) 202 | 203 | def __post_init__(self): 204 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 205 | raise ValueError("Need either a dataset name or a training/validation file.") 206 | else: 207 | if self.train_file is not None: 208 | extension = self.train_file.split(".")[-1] 209 | if extension not in ["csv", "json", "txt"]: 210 | raise ValueError("`train_file` should be a csv, a json or a txt file.") 211 | if self.validation_file is not None: 212 | extension = self.validation_file.split(".")[-1] 213 | if extension not in ["csv", "json", "txt"]: 214 | raise ValueError("`validation_file` should be a csv, a json or a txt file.") 215 | 216 | 217 | def main(): 218 | # See all possible argumNoneents in src/transformers/training_args.py 219 | # or by passing the --help flag to this script. 220 | # We now keep distinct sets of args, for a cleaner separation of concerns. 221 | 222 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 223 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 224 | # If we pass only one argutment to the script and it's the path to a json file, 225 | # let's parse it to get our arguments. 226 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 227 | else: 228 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 229 | 230 | training_args.do_train=True 231 | training_args.do_eval=True 232 | training_args.do_predict=False 233 | training_args.eval_steps=2500 234 | training_args.evaluation_strategy="steps" 235 | training_args.num_train_epochs = 100 236 | training_args.per_device_train_batch_size = 50 237 | training_args.per_device_eval_batch_size = 50 238 | training_args.gradient_accumulation_steps = 2 239 | training_args.overwrite_output_dir = True 240 | #training_args.resume_from_checkpoint = True 241 | #training_args.model_name_or_path = '/home/desin/CS230/RNABERT/out_mlm/checkpoint-9500' 242 | training_args.save_total_limit = 3 243 | training_args.log_level = 'info' 244 | training_args.logging_steps = 500 245 | training_args.learning_rate=8e-5 246 | #training_args.save_strategy='steps' 247 | #training_args.save_steps=10 248 | 249 | #print("TRAINING ARGS::::::::::::::::::\n",training_args) 250 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The 251 | # information sent is the one passed as arguments along with your Python/PyTorch versions. 
252 | send_example_telemetry("run_mlm", model_args, data_args) 253 | 254 | # Setup logging 255 | logging.basicConfig( 256 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 257 | datefmt="%m/%d/%Y %H:%M:%S", 258 | handlers=[logging.StreamHandler(sys.stdout)], 259 | ) 260 | 261 | log_level = training_args.get_process_log_level() 262 | logger.setLevel(log_level) 263 | datasets.utils.logging.set_verbosity(log_level) 264 | transformers.utils.logging.set_verbosity(log_level) 265 | transformers.utils.logging.enable_default_handler() 266 | transformers.utils.logging.enable_explicit_format() 267 | 268 | # Log on each process the small summary: 269 | logger.warning( 270 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 271 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 272 | ) 273 | # Set the verbosity to info of the Transformers logger (on main process only): 274 | logger.info(f"Training/evaluation parameters {training_args}") 275 | 276 | # Detecting last checkpoint. 277 | last_checkpoint = None 278 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 279 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 280 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 281 | raise ValueError( 282 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 283 | "Use --overwrite_output_dir to overcome." 284 | ) 285 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 286 | logger.info( 287 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 288 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 289 | ) 290 | 291 | # Set seed before initializing model. 292 | set_seed(training_args.seed) 293 | 294 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) 295 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 296 | # (the dataset will be downloaded automatically from the datasets Hub 297 | # 298 | # For CSV/JSON files, this script will use the column called 'text' or the first column. You can easily tweak this 299 | # behavior (see below) 300 | # 301 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 302 | # download the dataset. 303 | if data_args.dataset_name is not None: 304 | # Downloading and loading a dataset from the hub. 
305 | raw_datasets = load_dataset( 306 | data_args.dataset_name, 307 | data_args.dataset_config_name, 308 | cache_dir=model_args.cache_dir, 309 | use_auth_token=True if model_args.use_auth_token else None, 310 | ) 311 | if "validation" not in raw_datasets.keys(): 312 | raw_datasets["validation"] = load_dataset( 313 | data_args.dataset_name, 314 | data_args.dataset_config_name, 315 | split=f"train[:{data_args.validation_split_percentage}%]", 316 | cache_dir=model_args.cache_dir, 317 | use_auth_token=True if model_args.use_auth_token else None, 318 | ) 319 | raw_datasets["train"] = load_dataset( 320 | data_args.dataset_name, 321 | data_args.dataset_config_name, 322 | split=f"train[{data_args.validation_split_percentage}%:]", 323 | cache_dir=model_args.cache_dir, 324 | use_auth_token=True if model_args.use_auth_token else None, 325 | ) 326 | else: 327 | data_files = {} 328 | if data_args.train_file is not None: 329 | data_files["train"] = data_args.train_file 330 | extension = data_args.train_file.split(".")[-1] 331 | if data_args.validation_file is not None: 332 | data_files["validation"] = data_args.validation_file 333 | extension = data_args.validation_file.split(".")[-1] 334 | if extension == "txt": 335 | extension = "text" 336 | raw_datasets = load_dataset( 337 | extension, 338 | data_files=data_files, 339 | cache_dir=model_args.cache_dir, 340 | use_auth_token=True if model_args.use_auth_token else None, 341 | ) 342 | 343 | 344 | # If no validation data is there, validation_split_percentage will be used to divide the dataset. 345 | if "validation" not in raw_datasets.keys(): 346 | raw_datasets["validation"] = load_dataset( 347 | extension, 348 | data_files=data_files, 349 | split=f"train[:{data_args.validation_split_percentage}%]", 350 | cache_dir=model_args.cache_dir, 351 | use_auth_token=True if model_args.use_auth_token else None, 352 | ) 353 | raw_datasets["train"] = load_dataset( 354 | extension, 355 | data_files=data_files, 356 | split=f"train[{data_args.validation_split_percentage}%:]", 357 | cache_dir=model_args.cache_dir, 358 | use_auth_token=True if model_args.use_auth_token else None, 359 | ) 360 | 361 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 362 | # https://huggingface.co/docs/datasets/loading_datasets.html. 363 | 364 | # Load pretrained model and tokenizer 365 | # 366 | # Distributed training: 367 | # The .from_pretrained methods guarantee that only one local process can concurrently 368 | # download model & vocab. 
369 | config_kwargs = { 370 | "cache_dir": model_args.cache_dir, 371 | "revision": model_args.model_revision, 372 | "use_auth_token": True if model_args.use_auth_token else None, 373 | } 374 | if model_args.config_name: 375 | config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) 376 | elif model_args.model_name_or_path: 377 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) 378 | else: 379 | config = CONFIG_MAPPING[model_args.model_type]() 380 | #config = BertConfig() 381 | logger.warning("You are instantiating a new config instance from scratch.") 382 | if model_args.config_overrides is not None: 383 | logger.info(f"Overriding config: {model_args.config_overrides}") 384 | config.update_from_string(model_args.config_overrides) 385 | logger.info(f"New config: {config}") 386 | 387 | tokenizer_kwargs = { 388 | "cache_dir": model_args.cache_dir, 389 | "use_fast": model_args.use_fast_tokenizer, 390 | "revision": model_args.model_revision, 391 | "use_auth_token": True if model_args.use_auth_token else None, 392 | } 393 | if model_args.tokenizer_name: 394 | tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs) 395 | elif model_args.model_name_or_path: 396 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs) 397 | else: 398 | from transformers import PreTrainedTokenizerFast 399 | tokenizer_kwargs = { 400 | "mask_token": '[MASK]', 401 | "pad_token": '[PAD]' 402 | } 403 | #tokenizer = PreTrainedTokenizerFast(cls_token='[CLS]', mask_token="[MASK]", tokenizer_file="bert-rna-tokenizer-1.json", pad_token='[PAD]') 404 | tokenizer = PreTrainedTokenizerFast(tokenizer_file="bert-rna-6-mer-tokenizer.json", mask_token="[MASK]", pad_token="[PAD]", cls_token='[CLS]', sep_token='[SEP]') 405 | #tokenizer.mask_token = '[MASK]' 406 | #tokenizer.pad_token = '[PAD]' 407 | print("**************MASK TOKEN:", tokenizer.mask_token) 408 | 409 | #raise ValueError( 410 | # "You are instantiating a new tokenizer from scratch. This is not supported by this script." 411 | # "You can do it from another script, save it, and load it from here, using --tokenizer_name." 412 | #) 413 | 414 | if model_args.model_name_or_path: 415 | model = AutoModelForMaskedLM.from_pretrained( 416 | model_args.model_name_or_path, 417 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 418 | config=config, 419 | cache_dir=model_args.cache_dir, 420 | revision=model_args.model_revision, 421 | use_auth_token=True if model_args.use_auth_token else None, 422 | ) 423 | else: 424 | logger.info("Training new model from scratch") 425 | #config = AutoConfig.from_pretrained("bert-base-cased") 426 | #model = AutoModelForMaskedLM.from_config(config) 427 | model = AutoModelForMaskedLM.from_config(config) 428 | 429 | model.resize_token_embeddings(len(tokenizer)) 430 | 431 | # Preprocessing the datasets. 432 | # First we tokenize all the texts. 433 | if training_args.do_train: 434 | column_names = raw_datasets["train"].column_names 435 | else: 436 | column_names = raw_datasets["validation"].column_names 437 | text_column_name = "text" if "text" in column_names else column_names[0] 438 | 439 | print("****RAW DATASETS COLUMN NAMES:", raw_datasets["train"][0]) 440 | 441 | 442 | """if data_args.max_seq_length is None: 443 | max_seq_length = tokenizer.model_max_length 444 | if max_seq_length > 1024: 445 | logger.warning( 446 | f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). 
" 447 | "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx." 448 | ) 449 | max_seq_length = 1024 450 | else: 451 | if data_args.max_seq_length > tokenizer.model_max_length: 452 | logger.warning( 453 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" 454 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 455 | ) 456 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 457 | """ 458 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 459 | if data_args.line_by_line: 460 | # When using line_by_line, we just tokenize each nonempty line. 461 | padding = "max_length" if data_args.pad_to_max_length else False 462 | 463 | def tokenize_function(examples): 464 | # Remove empty lines 465 | examples[text_column_name] = [ 466 | line for line in examples[text_column_name] if len(line) > 0 and not line.isspace() 467 | ] 468 | return tokenizer( 469 | examples[text_column_name], 470 | padding=padding, 471 | truncation=True, 472 | max_length=max_seq_length, 473 | # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it 474 | # receives the `special_tokens_mask`. 475 | return_special_tokens_mask=True, 476 | ) 477 | 478 | with training_args.main_process_first(desc="dataset map tokenization"): 479 | tokenized_datasets = raw_datasets.map( 480 | tokenize_function, 481 | batched=True, 482 | num_proc=data_args.preprocessing_num_workers, 483 | remove_columns=[text_column_name], 484 | load_from_cache_file=not data_args.overwrite_cache, 485 | desc="Running tokenizer on dataset line_by_line", 486 | ) 487 | #for data in tokenized_datasets['train']: 488 | # print(data) 489 | else: 490 | # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts. 491 | # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more 492 | # efficient when it receives the `special_tokens_mask`. 493 | def tokenize_function(examples): 494 | return tokenizer(examples[text_column_name], return_special_tokens_mask=True) 495 | 496 | with training_args.main_process_first(desc="dataset map tokenization"): 497 | tokenized_datasets = raw_datasets.map( 498 | tokenize_function, 499 | batched=True, 500 | num_proc=data_args.preprocessing_num_workers, 501 | remove_columns=column_names, 502 | load_from_cache_file=not data_args.overwrite_cache, 503 | desc="Running tokenizer on every text in dataset", 504 | ) 505 | 506 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of 507 | # max_seq_length. 508 | def group_texts(examples): 509 | # Concatenate all texts. 510 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 511 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 512 | # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can 513 | # customize this part to your needs. 514 | if total_length >= max_seq_length: 515 | total_length = (total_length // max_seq_length) * max_seq_length 516 | # Split by chunks of max_len. 
517 | result = { 518 | k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)] 519 | for k, t in concatenated_examples.items() 520 | } 521 | return result 522 | 523 | # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a 524 | # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value 525 | # might be slower to preprocess. 526 | # 527 | # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: 528 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map 529 | 530 | with training_args.main_process_first(desc="grouping texts together"): 531 | tokenized_datasets = tokenized_datasets.map( 532 | group_texts, 533 | batched=True, 534 | num_proc=data_args.preprocessing_num_workers, 535 | load_from_cache_file=not data_args.overwrite_cache, 536 | desc=f"Grouping texts in chunks of {max_seq_length}", 537 | ) 538 | 539 | 540 | 541 | if training_args.do_train: 542 | if "train" not in tokenized_datasets: 543 | raise ValueError("--do_train requires a train dataset") 544 | train_dataset = tokenized_datasets["train"] 545 | if data_args.max_train_samples is not None: 546 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 547 | train_dataset = train_dataset.select(range(max_train_samples)) 548 | 549 | if training_args.do_eval: 550 | if "validation" not in tokenized_datasets: 551 | raise ValueError("--do_eval requires a validation dataset") 552 | eval_dataset = tokenized_datasets["validation"] 553 | if data_args.max_eval_samples is not None: 554 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 555 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 556 | 557 | def preprocess_logits_for_metrics(logits, labels): 558 | if isinstance(logits, tuple): 559 | # Depending on the model and config, logits may contain extra tensors, 560 | # like past_key_values, but logits always come first 561 | logits = logits[0] 562 | return logits.argmax(dim=-1) 563 | 564 | metric = evaluate.load("accuracy") 565 | 566 | def compute_metrics(eval_preds): 567 | preds, labels = eval_preds 568 | # preds have the same shape as the labels, after the argmax(-1) has been calculated 569 | # by preprocess_logits_for_metrics 570 | labels = labels.reshape(-1) 571 | preds = preds.reshape(-1) 572 | mask = labels != -100 573 | labels = labels[mask] 574 | preds = preds[mask] 575 | return metric.compute(predictions=preds, references=labels) 576 | 577 | # Data collator 578 | # This one will take care of randomly masking the tokens. 
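    # In this copy the collator below is the custom RnaDataCollator (see custom_data_collator.py), which
    # masks spans of k consecutive k-mer tokens rather than isolated tokens.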
579 | pad_to_multiple_of_8 = data_args.line_by_line and training_args.fp16 and not data_args.pad_to_max_length 580 | #data_collator = DataCollatorForLanguageModeling( 581 | data_collator = RnaDataCollator( 582 | mlm=True, 583 | tokenizer=tokenizer, 584 | mlm_probability=data_args.mlm_probability, 585 | pad_to_multiple_of=8 if pad_to_multiple_of_8 else None, 586 | ) 587 | 588 | 589 | 590 | # Initialize our Trainer 591 | trainer = Trainer( 592 | model=model, 593 | args=training_args, 594 | train_dataset=train_dataset if training_args.do_train else None, 595 | eval_dataset=eval_dataset if training_args.do_eval else None, 596 | tokenizer=tokenizer, 597 | data_collator=data_collator, 598 | compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None, 599 | preprocess_logits_for_metrics=preprocess_logits_for_metrics 600 | if training_args.do_eval and not is_torch_tpu_available() 601 | else None, 602 | ) 603 | 604 | 605 | 606 | for batch in trainer.get_train_dataloader(): 607 | print(batch['input_ids'][0]) 608 | print(batch['labels'][0]) 609 | break 610 | 611 | 612 | 613 | # Training 614 | if training_args.do_train: 615 | checkpoint = None 616 | if training_args.resume_from_checkpoint is not None: 617 | checkpoint = training_args.resume_from_checkpoint 618 | elif last_checkpoint is not None: 619 | checkpoint = last_checkpoint 620 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 621 | trainer.save_model() # Saves the tokenizer too for easy upload 622 | metrics = train_result.metrics 623 | 624 | max_train_samples = ( 625 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 626 | ) 627 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 628 | 629 | trainer.log_metrics("train", metrics) 630 | trainer.save_metrics("train", metrics) 631 | trainer.save_state() 632 | 633 | # Evaluation 634 | if training_args.do_eval: 635 | logger.info("*** Evaluate ***") 636 | 637 | metrics = trainer.evaluate() 638 | 639 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 640 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 641 | try: 642 | perplexity = math.exp(metrics["eval_loss"]) 643 | except OverflowError: 644 | perplexity = float("inf") 645 | metrics["perplexity"] = perplexity 646 | 647 | trainer.log_metrics("eval", metrics) 648 | trainer.save_metrics("eval", metrics) 649 | 650 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "fill-mask"} 651 | if data_args.dataset_name is not None: 652 | kwargs["dataset_tags"] = data_args.dataset_name 653 | if data_args.dataset_config_name is not None: 654 | kwargs["dataset_args"] = data_args.dataset_config_name 655 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 656 | else: 657 | kwargs["dataset"] = data_args.dataset_name 658 | 659 | if training_args.push_to_hub: 660 | trainer.push_to_hub(**kwargs) 661 | else: 662 | trainer.create_model_card(**kwargs) 663 | 664 | 665 | def _mp_fn(index): 666 | # For xla_spawn (TPUs) 667 | main() 668 | 669 | 670 | if __name__ == "__main__": 671 | main() 672 | --------------------------------------------------------------------------------
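A minimal end-to-end inference sketch (not part of the repo) that ties the pieces together: it assumes a fine-tuned checkpoint directory exists (the path below is a placeholder), converts a raw sequence into overlapping 6-mers the way make_k_mers.py does, tokenizes it with the trained word-level tokenizer as in finetune.py, and maps the predicted index back to a family name via data/family_to_index.txt. The example sequence is made up.

    import json
    import torch
    from transformers import BertForSequenceClassification, PreTrainedTokenizerFast

    CHECKPOINT = "./out_finetune/checkpoint-XXXX"   # placeholder: point this at a real fine-tuned checkpoint

    tokenizer = PreTrainedTokenizerFast(tokenizer_file="bert-rna-6-mer-tokenizer.json")
    tokenizer.pad_token = "[PAD]"
    tokenizer.cls_token = "[CLS]"

    model = BertForSequenceClassification.from_pretrained(CHECKPOINT, num_labels=31)
    model.eval()

    # Invert data/family_to_index.txt so predicted indices map back to family names.
    with open("data/family_to_index.txt") as f:
        index_to_family = {int(v): k for k, v in json.load(f).items()}

    def predict_family(seq, k=6):
        seq = seq.lower()
        text = " ".join(seq[i:i + k] for i in range(len(seq) - k))  # overlapping 6-mers, as in make_k_mers.py
        inputs = tokenizer(text, truncation=True, max_length=512, return_tensors="pt")
        with torch.no_grad():
            logits = model(**inputs).logits
        return index_to_family[int(logits.argmax(dim=-1))]

    print(predict_family("AUGGCUACGUAGCUAGCAUCGAUCGGAUCCGUA"))   # made-up sequence, illustration only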