├── .gitignore ├── requirements.txt ├── 00_download_data.py ├── 10_compile_python_scripts.sh ├── 20_convert_base256.py ├── 30_train_tokenizer.py ├── 40_tokenize_data.py ├── 50_create_jsonlines.py ├── 60_train_model.sh ├── 61_train_dummy_model.sh ├── 70_inference.sh ├── configs └── bart-tiny.json └── run_translation.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | tokenizers 3 | datasets==2.15.0 4 | # additional packages imported by the scripts in this repository 5 | transformers 6 | torch 7 | accelerate 8 | evaluate 9 | sacrebleu 10 | numpy 11 | pandas 12 | scikit-learn -------------------------------------------------------------------------------- /70_inference.sh: -------------------------------------------------------------------------------- 1 | python run_translation.py \ 2 | --model_name_or_path ../models/bart-tiny-lr1e-4 \ 3 | --tokenizer_name ../tokenizers/bpe_combined_ByteLevel_8000vocab_10000subset.json \ 4 | --output_dir ../models/bart-tiny-lr1e-4 \ 5 | --train_file ../data/jsonlines/train.json \ 6 | --test_file ../data/jsonlines/valid.json \ 7 | --source_lang bytecode \ 8 | --target_lang code \ 9 | --do_predict \ 10 | --generation_max_length 1024 \ 11 | --generation_num_beams 4 \ 12 | --predict_with_generate -------------------------------------------------------------------------------- /00_download_data.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from tqdm import tqdm 3 | import os 4 | 5 | # download only the first archive 6 | dataset = load_dataset( 7 | "codeparrot/codeparrot-clean-train", data_files="file-000000000001.json.gz" 8 | ) 9 | 10 | print(dataset) 11 | 12 | # create the output directory before writing the extracted scripts 13 | os.makedirs("../data/codeparrot-clean-train/original_code", exist_ok=True) 14 | 15 | for sample in tqdm(dataset["train"]): 16 | # todo: shard the data according to the gz archives 17 | with open( 18 | f"../data/codeparrot-clean-train/original_code/{sample['hash']}.py", "w" 19 | ) as f: 20 | f.write(sample["content"]) 21 | 22 | # py_compile.compile('data/codeparrot-clean-train/test.py', 'data/codeparrot-clean-train/test.pyc') -------------------------------------------------------------------------------- /10_compile_python_scripts.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # compileall is single-threaded by default, so parallelize with -j 4 | # 100k scripts take around 2-3 minutes 5 | # for some scripts the compilation fails; this is expected (they target other Python versions) 6 | python -m compileall -qq -j 4 -d . ../data/codeparrot-clean-train/original_code
7 | version=$(python --version | cut -c 8-) 8 | mkdir -p ../data/codeparrot-clean-train/compiled-$version/ 9 | # a plain mv fails with ~100k files (argument list too long): 10 | # mv ../data/codeparrot-clean-train/original_code/__pycache__/* ../data/codeparrot-clean-train/compiled-$version/ 11 | echo ../data/codeparrot-clean-train/original_code/__pycache__/*.pyc | xargs mv -t ../data/codeparrot-clean-train/compiled-$version/ -- -------------------------------------------------------------------------------- /61_train_dummy_model.sh: -------------------------------------------------------------------------------- 1 | rm -rf ../models/bart-tiny-dummy 2 | python run_translation.py \ 3 | --config_name ./configs/bart-tiny.json \ 4 | --tokenizer_name ../tokenizers/bpe_combined_ByteLevel_8000vocab_10000subset.json \ 5 | --output_dir ../models/bart-tiny-dummy \ 6 | --train_file ../data/jsonlines/train-dummy-repeat.json \ 7 | --validation_file ../data/jsonlines/train-dummy.json \ 8 | --source_lang bytecode \ 9 | --target_lang code \ 10 | --per_device_train_batch_size 32 \ 11 | --learning_rate 1e-3 \ 12 | --num_train_epochs 200 \ 13 | --generation_max_length 1024 \ 14 | --max_target_length 1024 \ 15 | --evaluation_strategy epoch \ 16 | --do_train \ 17 | --do_eval \ 18 | --predict_with_generate -------------------------------------------------------------------------------- /60_train_model.sh: -------------------------------------------------------------------------------- 1 | python run_translation.py \ 2 | --config_name ./configs/bart-tiny.json \ 3 | --tokenizer_name ../tokenizers/bpe_combined_ByteLevel_8000vocab_10000subset.json \ 4 | --output_dir ../models/bart-tiny-lr1e-4 \ 5 | --train_file ../data/jsonlines/train.json \ 6 | --validation_file ../data/jsonlines/valid.json \ 7 | --source_lang bytecode \ 8 | --target_lang code \ 9 | --per_device_train_batch_size 128 \ 10 | --learning_rate 1e-4 \ 11 | --warmup_steps 500 \ 12 | --num_train_epochs 100 \ 13 | --generation_max_length 1024 \ 14 | --max_target_length 1024 \ 15 | --evaluation_strategy epoch \ 16 | --do_train \ 17 | --do_eval \ 18 | --max_eval_samples 32 \ 19 | --predict_with_generate \ 20 | --generation_num_beams 4 -------------------------------------------------------------------------------- /20_convert_base256.py: -------------------------------------------------------------------------------- 1 | import re 2 | def generate_alphabet(): 3 | for _i in range(256): 4 | i = _i 5 | # map control/whitespace bytes to printable characters; the +256 offset keeps byte 0 distinct from byte 255, so the mapping stays invertible 6 | if _i >= 0 and _i <= 31: 7 | i = _i + 256 8 | # if _i == 32: 9 | # i = 0x005F 10 | if _i >= 127 and _i <= 160: 11 | i = _i + 256 12 | char = chr(i) 13 | # print(_i, char, re.match(r"\s", char)) 14 | yield (_i, char) 15 | 16 | alphabet = dict(generate_alphabet()) 17 | print(alphabet, file=open("alphabet.txt", "w")) 18 | # print(alphabet) 19 | 20 | from glob import glob 21 | import os 22 | from tqdm import tqdm 23 | 24 | path = "../data/codeparrot-clean-train/compiled-3.8.18" 25 | outdir = "../data/processed/codeparrot-clean-train/compiled-3.8.18" 26 | os.makedirs(outdir, exist_ok=True) 27 | for pyc in tqdm(glob(path+"/*.pyc")): 28 | filename = os.path.basename(pyc) 29 | out_path = os.path.join(outdir, filename+".txt") 30 | with open(pyc, "rb") as f, open(out_path, "w") as out: 31 | code = f.read() 32 | output = [] 33 | for byte in code: 34 | output.append(alphabet[byte]) 35 | 36 | out.write("".join(output)) -------------------------------------------------------------------------------- /configs/bart-tiny.json:
-------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "bart-tiny", 3 | "activation_dropout": 0.1, 4 | "activation_function": "gelu", 5 | "add_bias_logits": false, 6 | "add_final_layer_norm": false, 7 | "architectures": [ 8 | "BartModel" 9 | ], 10 | "attention_dropout": 0.1, 11 | "bos_token_id": 1, 12 | "classif_dropout": 0.1, 13 | "classifier_dropout": 0.0, 14 | "d_model": 256, 15 | "decoder_attention_heads": 8, 16 | "decoder_ffn_dim": 1024, 17 | "decoder_layerdrop": 0.0, 18 | "decoder_layers": 3, 19 | "decoder_start_token_id": 2, 20 | "forced_bos_token_id": 1, 21 | "dropout": 0.1, 22 | "early_stopping": true, 23 | "encoder_attention_heads": 8, 24 | "encoder_ffn_dim": 1024, 25 | "encoder_layerdrop": 0.0, 26 | "encoder_layers": 3, 27 | "eos_token_id": 2, 28 | "gradient_checkpointing": false, 29 | "id2label": { 30 | "0": "LABEL_0", 31 | "1": "LABEL_1", 32 | "2": "LABEL_2" 33 | }, 34 | "init_std": 0.02, 35 | "is_encoder_decoder": true, 36 | "label2id": { 37 | "LABEL_0": 0, 38 | "LABEL_1": 1, 39 | "LABEL_2": 2 40 | }, 41 | "max_position_embeddings": 1024, 42 | "model_type": "bart", 43 | "no_repeat_ngram_size": 0, 44 | "normalize_before": false, 45 | "normalize_embedding": true, 46 | "num_beams": 4, 47 | "num_hidden_layers": 6, 48 | "pad_token_id": 0, 49 | "scale_embedding": false, 50 | "torch_dtype": "float32", 51 | "transformers_version": "4.12.0.dev0", 52 | "use_cache": true, 53 | "vocab_size": 8000 54 | } 55 | -------------------------------------------------------------------------------- /40_tokenize_data.py: -------------------------------------------------------------------------------- 1 | from transformers import PreTrainedTokenizerFast 2 | tokenizer = PreTrainedTokenizerFast(tokenizer_file="../tokenizers/bpe_combined_ByteLevel_8000vocab_10000subset.json") 3 | 4 | from glob import glob 5 | from tqdm import tqdm 6 | binary_files = glob("../data/processed/codeparrot-clean-train/compiled-3.8.18/*.txt") 7 | code_files = glob("../data/codeparrot-clean-train/original_code/*.py") 8 | 9 | results = [] 10 | from collections import Counter 11 | import pandas as pd 12 | counter_bytecode = Counter() 13 | counter_code = Counter() 14 | for filename in tqdm(binary_files + code_files): 15 | # print(filename) 16 | with open(filename, "r") as f: 17 | text = f.read() 18 | # print(text) 19 | tokenized = tokenizer(text) 20 | # tokens = tokenizer.convert_ids_to_tokens(tokenized["input_ids"]) 21 | tokens = tokenized["input_ids"] 22 | 23 | type = "bytecode" if filename.endswith(".txt") else "code" 24 | if type == "bytecode": 25 | counter_bytecode.update(tokens) 26 | else: 27 | counter_code.update(tokens) 28 | 29 | results.append({ 30 | "filename": filename, 31 | "type": type, 32 | "tokens": len(tokenized["input_ids"]), 33 | }) 34 | 35 | pd.DataFrame(results).to_csv("40_tokenize_data.csv", index=False) 36 | pd.DataFrame(counter_bytecode.most_common()).to_csv("token_stats_bytecode.tsv",sep='\t',index=False,header=False) 37 | pd.DataFrame(counter_code.most_common()).to_csv("token_stats_code.tsv",sep='\t',index=False,header=False) 38 | 39 | -------------------------------------------------------------------------------- /50_create_jsonlines.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # sacrebleu format to jsonlines 4 | 5 | import io 6 | import json 7 | import re 8 | import pandas as pd 9 | import os 10 | 11 | src_lang, tgt_lang = ["bytecode", "code"] 12 | 13 | df = 
pd.read_csv("40_tokenize_data.csv") 14 | df["id"] = df["filename"].map(lambda f: os.path.basename(f).split(".")[0]) 15 | df_bytecode = df[df.type=="bytecode"].join(df[df.type=="code"].set_index("id"), on="id", lsuffix="_bytecode", rsuffix="_code") 16 | df_short = df_bytecode[(df_bytecode.tokens_bytecode+df_bytecode.tokens_code < 1024)] 17 | 18 | from sklearn.model_selection import train_test_split 19 | 20 | # Splitting into train and remaining data 21 | df_train, df_remaining = train_test_split(df_short, test_size=0.2, random_state=42) 22 | 23 | # Splitting remaining data into valid and test in equal proportions 24 | df_valid, df_test = train_test_split(df_remaining, test_size=0.5, random_state=42) 25 | 26 | # Print the lengths of the datasets to verify proportions 27 | print("Train set length:", len(df_train)) 28 | print("Validation set length:", len(df_valid)) 29 | print("Test set length:", len(df_test)) 30 | # all_ids = set(df_short.id.values) 31 | 32 | 33 | for split, df in zip(["train", "valid", "test"], [df_train, df_valid, df_test]): 34 | fout = f"../data/jsonlines/{split}.json" 35 | with open(fout, "w", encoding="utf-8") as f: 36 | for i, row in df.iterrows(): 37 | bytecode = open(row["filename_bytecode"], "r").read() 38 | code = open(row["filename_code"], "r").read() 39 | out = {"translation": { src_lang: bytecode, tgt_lang: code } } 40 | x = json.dumps(out, indent=0, ensure_ascii=False) 41 | x = re.sub(r'\n', ' ', x, 0, re.M) 42 | f.write(x + "\n") 43 | 44 | # for type in ["source", "target"]: 45 | # fin = f"{split}.{type}" 46 | # recs.append([line.strip() for line in open(fin)]) 47 | # for src, tgt in zip(*recs): 48 | # out = {"translation": { src_lang: src, tgt_lang: tgt } } 49 | # x = json.dumps(out, indent=0, ensure_ascii=False) 50 | # x = re.sub(r'\n', ' ', x, 0, re.M) 51 | # f.write(x + "\n") -------------------------------------------------------------------------------- /30_train_tokenizer.py: -------------------------------------------------------------------------------- 1 | from tokenizers.trainers import BpeTrainer 2 | from tokenizers.models import BPE 3 | from tokenizers import Tokenizer, decoders, processors, normalizers, pre_tokenizers 4 | tokenizer = Tokenizer(BPE()) 5 | processing = "ByteLevel" 6 | if processing == "ByteLevel": 7 | tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) 8 | # tokenizer.normalizer = normalizers.NFKC() 9 | tokenizer.decoder = decoders.ByteLevel(add_prefix_space=False) 10 | tokenizer.post_processor = processors.ByteLevel(add_prefix_space=False) 11 | elif processing == "Metaspace": 12 | tokenizer.pre_tokenizer = pre_tokenizers.Metaspace() 13 | tokenizer.decoder = decoders.Metaspace() 14 | 15 | tokenizer.post_processor = processors.TemplateProcessing(single="[BOS] $A [EOS]", special_tokens=[("[BOS]", 1), ("[EOS]", 2)]) 16 | # tokenizer.pre_tokenizer = Split(pattern=r'\w+|\s', behavior='isolated') 17 | vocab_size = 8000 18 | trainer = BpeTrainer( 19 | vocab_size=vocab_size, 20 | max_token_length=16, 21 | initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), 22 | # limit_alphabet=350, 23 | special_tokens=["[PAD]", "[BOS]", "[EOS]"], 24 | ) 25 | 26 | subset = 10000 27 | from glob import glob 28 | from random import shuffle 29 | binary_files = glob("../data/processed/codeparrot-clean-train/compiled-3.8.18/*.txt") 30 | shuffle(binary_files) 31 | binary_files = binary_files[:subset] 32 | code_files = glob("../data/codeparrot-clean-train/original_code/*.py") 33 | shuffle(code_files) 34 | code_files = code_files[:subset] 35 
| 36 | all_files = binary_files + code_files 37 | 38 | # convert binary to hex 39 | # def data_iterator(): 40 | # for file in glob("../data/processed/codeparrot-clean-train/compiled-3.8.16/*.txt"): 41 | # code = open(file, "rb").read() 42 | # yield code.hex() 43 | # tokenizer.train_from_iterator(iter(data_iterator()), trainer=trainer) 44 | 45 | tokenizer.train(all_files, trainer=trainer) 46 | import os 47 | os.makedirs(f"../tokenizers", exist_ok=True) 48 | tokenizer_path = f"../tokenizers/bpe_combined_{processing}_{vocab_size}vocab_{subset}subset.json" 49 | tokenizer.save(tokenizer_path) 50 | # try loading the tokenizer 51 | tokenizer.from_file(tokenizer_path) 52 | # tokenizer loading fails if the trained tokens contain whitespaces... 53 | -------------------------------------------------------------------------------- /run_translation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for sequence to sequence. 18 | """ 19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 20 | 21 | import logging 22 | import os 23 | import sys 24 | import warnings 25 | from dataclasses import dataclass, field 26 | from typing import Optional 27 | 28 | import datasets 29 | import evaluate 30 | import numpy as np 31 | from datasets import load_dataset 32 | 33 | import transformers 34 | from transformers import ( 35 | AutoConfig, 36 | AutoModelForSeq2SeqLM, 37 | AutoTokenizer, 38 | DataCollatorForSeq2Seq, 39 | HfArgumentParser, 40 | M2M100Tokenizer, 41 | MBart50Tokenizer, 42 | MBart50TokenizerFast, 43 | MBartTokenizer, 44 | MBartTokenizerFast, 45 | Seq2SeqTrainer, 46 | Seq2SeqTrainingArguments, 47 | default_data_collator, 48 | set_seed, 49 | ) 50 | from transformers import PreTrainedTokenizerFast 51 | from transformers.trainer_utils import get_last_checkpoint 52 | from transformers.utils import check_min_version, send_example_telemetry 53 | from transformers.utils.versions import require_version 54 | 55 | 56 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 57 | # check_min_version("4.39.0.dev0") 58 | 59 | # require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") 60 | 61 | logger = logging.getLogger(__name__) 62 | 63 | # A list of all multilingual tokenizer which require src_lang and tgt_lang attributes. 64 | MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer] 65 | 66 | 67 | @dataclass 68 | class ModelArguments: 69 | """ 70 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
71 | """ 72 | 73 | model_name_or_path: Optional[str] = field( 74 | default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 75 | ) 76 | config_name: Optional[str] = field( 77 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 78 | ) 79 | tokenizer_name: Optional[str] = field( 80 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 81 | ) 82 | cache_dir: Optional[str] = field( 83 | default=None, 84 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 85 | ) 86 | use_fast_tokenizer: bool = field( 87 | default=True, 88 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 89 | ) 90 | model_revision: str = field( 91 | default="main", 92 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 93 | ) 94 | token: str = field( 95 | default=None, 96 | metadata={ 97 | "help": ( 98 | "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token " 99 | "generated when running `huggingface-cli login` (stored in `~/.huggingface`)." 100 | ) 101 | }, 102 | ) 103 | use_auth_token: bool = field( 104 | default=None, 105 | metadata={ 106 | "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead." 107 | }, 108 | ) 109 | trust_remote_code: bool = field( 110 | default=False, 111 | metadata={ 112 | "help": ( 113 | "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option " 114 | "should only be set to `True` for repositories you trust and in which you have read the code, as it will " 115 | "execute code present on the Hub on your local machine." 116 | ) 117 | }, 118 | ) 119 | 120 | 121 | @dataclass 122 | class DataTrainingArguments: 123 | """ 124 | Arguments pertaining to what data we are going to input our model for training and eval. 125 | """ 126 | 127 | source_lang: str = field(default=None, metadata={"help": "Source language id for translation."}) 128 | target_lang: str = field(default=None, metadata={"help": "Target language id for translation."}) 129 | 130 | dataset_name: Optional[str] = field( 131 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 132 | ) 133 | dataset_config_name: Optional[str] = field( 134 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 135 | ) 136 | train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a jsonlines)."}) 137 | validation_file: Optional[str] = field( 138 | default=None, 139 | metadata={ 140 | "help": "An optional input evaluation data file to evaluate the metrics (sacrebleu) on a jsonlines file." 
141 | }, 142 | ) 143 | test_file: Optional[str] = field( 144 | default=None, 145 | metadata={"help": "An optional input test data file to evaluate the metrics (sacrebleu) on a jsonlines file."}, 146 | ) 147 | overwrite_cache: bool = field( 148 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 149 | ) 150 | preprocessing_num_workers: Optional[int] = field( 151 | default=None, 152 | metadata={"help": "The number of processes to use for the preprocessing."}, 153 | ) 154 | max_source_length: Optional[int] = field( 155 | default=1024, 156 | metadata={ 157 | "help": ( 158 | "The maximum total input sequence length after tokenization. Sequences longer " 159 | "than this will be truncated, sequences shorter will be padded." 160 | ) 161 | }, 162 | ) 163 | max_target_length: Optional[int] = field( 164 | default=128, 165 | metadata={ 166 | "help": ( 167 | "The maximum total sequence length for target text after tokenization. Sequences longer " 168 | "than this will be truncated, sequences shorter will be padded." 169 | ) 170 | }, 171 | ) 172 | val_max_target_length: Optional[int] = field( 173 | default=None, 174 | metadata={ 175 | "help": ( 176 | "The maximum total sequence length for validation target text after tokenization. Sequences longer " 177 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. " 178 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 179 | "during ``evaluate`` and ``predict``." 180 | ) 181 | }, 182 | ) 183 | pad_to_max_length: bool = field( 184 | default=False, 185 | metadata={ 186 | "help": ( 187 | "Whether to pad all samples to model maximum sentence length. " 188 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 189 | "efficient on GPU but very bad for TPU." 190 | ) 191 | }, 192 | ) 193 | max_train_samples: Optional[int] = field( 194 | default=None, 195 | metadata={ 196 | "help": ( 197 | "For debugging purposes or quicker training, truncate the number of training examples to this " 198 | "value if set." 199 | ) 200 | }, 201 | ) 202 | max_eval_samples: Optional[int] = field( 203 | default=None, 204 | metadata={ 205 | "help": ( 206 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 207 | "value if set." 208 | ) 209 | }, 210 | ) 211 | max_predict_samples: Optional[int] = field( 212 | default=None, 213 | metadata={ 214 | "help": ( 215 | "For debugging purposes or quicker training, truncate the number of prediction examples to this " 216 | "value if set." 217 | ) 218 | }, 219 | ) 220 | num_beams: Optional[int] = field( 221 | default=1, 222 | metadata={ 223 | "help": ( 224 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 225 | "which is used during ``evaluate`` and ``predict``." 226 | ) 227 | }, 228 | ) 229 | ignore_pad_token_for_loss: bool = field( 230 | default=True, 231 | metadata={ 232 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 
233 | }, 234 | ) 235 | source_prefix: Optional[str] = field( 236 | default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 237 | ) 238 | forced_bos_token: Optional[str] = field( 239 | default=None, 240 | metadata={ 241 | "help": ( 242 | "The token to force as the first generated token after the :obj:`decoder_start_token_id`.Useful for" 243 | " multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token needs to" 244 | " be the target language token.(Usually it is the target language token)" 245 | ) 246 | }, 247 | ) 248 | 249 | def __post_init__(self): 250 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 251 | raise ValueError("Need either a dataset name or a training/validation file.") 252 | elif self.source_lang is None or self.target_lang is None: 253 | raise ValueError("Need to specify the source language and the target language.") 254 | 255 | # accepting both json and jsonl file extensions, as 256 | # many jsonlines files actually have a .json extension 257 | valid_extensions = ["json", "jsonl"] 258 | 259 | if self.train_file is not None: 260 | extension = self.train_file.split(".")[-1] 261 | assert extension in valid_extensions, "`train_file` should be a jsonlines file." 262 | if self.validation_file is not None: 263 | extension = self.validation_file.split(".")[-1] 264 | assert extension in valid_extensions, "`validation_file` should be a jsonlines file." 265 | if self.val_max_target_length is None: 266 | self.val_max_target_length = self.max_target_length 267 | 268 | 269 | def main(): 270 | # See all possible arguments in src/transformers/training_args.py 271 | # or by passing the --help flag to this script. 272 | # We now keep distinct sets of args, for a cleaner separation of concerns. 273 | 274 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) 275 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 276 | # If we pass only one argument to the script and it's the path to a json file, 277 | # let's parse it to get our arguments. 278 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 279 | else: 280 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 281 | 282 | if model_args.use_auth_token is not None: 283 | warnings.warn( 284 | "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.", 285 | FutureWarning, 286 | ) 287 | if model_args.token is not None: 288 | raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") 289 | model_args.token = model_args.use_auth_token 290 | 291 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The 292 | # information sent is the one passed as arguments along with your Python/PyTorch versions. 293 | send_example_telemetry("run_translation", model_args, data_args) 294 | 295 | # Setup logging 296 | logging.basicConfig( 297 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 298 | datefmt="%m/%d/%Y %H:%M:%S", 299 | handlers=[logging.StreamHandler(sys.stdout)], 300 | ) 301 | 302 | if training_args.should_log: 303 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 
304 | transformers.utils.logging.set_verbosity_info() 305 | 306 | log_level = training_args.get_process_log_level() 307 | logger.setLevel(log_level) 308 | datasets.utils.logging.set_verbosity(log_level) 309 | transformers.utils.logging.set_verbosity(log_level) 310 | transformers.utils.logging.enable_default_handler() 311 | transformers.utils.logging.enable_explicit_format() 312 | 313 | # Log on each process the small summary: 314 | logger.warning( 315 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " 316 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" 317 | ) 318 | logger.info(f"Training/evaluation parameters {training_args}") 319 | 320 | if data_args.source_prefix is None and model_args.model_name_or_path in [ 321 | "google-t5/t5-small", 322 | "google-t5/t5-base", 323 | "google-t5/t5-large", 324 | "google-t5/t5-3b", 325 | "google-t5/t5-11b", 326 | ]: 327 | logger.warning( 328 | "You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with " 329 | "`--source_prefix 'translate English to German: ' `" 330 | ) 331 | 332 | # Detecting last checkpoint. 333 | last_checkpoint = None 334 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 335 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 336 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 337 | raise ValueError( 338 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 339 | "Use --overwrite_output_dir to overcome." 340 | ) 341 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 342 | logger.info( 343 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 344 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 345 | ) 346 | 347 | # Set seed before initializing model. 348 | set_seed(training_args.seed) 349 | 350 | # Get the datasets: you can either provide your own JSON training and evaluation files (see below) 351 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 352 | # (the dataset will be downloaded automatically from the datasets Hub). 353 | # 354 | # For translation, only JSON files are supported, with one field named "translation" containing two keys for the 355 | # source and target languages (unless you adapt what follows). 356 | # 357 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 358 | # download the dataset. 359 | if data_args.dataset_name is not None: 360 | # Downloading and loading a dataset from the hub. 
361 | raw_datasets = load_dataset( 362 | data_args.dataset_name, 363 | data_args.dataset_config_name, 364 | cache_dir=model_args.cache_dir, 365 | token=model_args.token, 366 | ) 367 | else: 368 | data_files = {} 369 | if data_args.train_file is not None: 370 | data_files["train"] = data_args.train_file 371 | extension = data_args.train_file.split(".")[-1] 372 | if data_args.validation_file is not None: 373 | data_files["validation"] = data_args.validation_file 374 | extension = data_args.validation_file.split(".")[-1] 375 | if data_args.test_file is not None: 376 | data_files["test"] = data_args.test_file 377 | extension = data_args.test_file.split(".")[-1] 378 | if extension == "jsonl": 379 | builder_name = "json" # the "json" builder reads both .json and .jsonl files 380 | else: 381 | builder_name = extension # e.g. "parquet" 382 | raw_datasets = load_dataset( 383 | builder_name, 384 | data_files=data_files, 385 | cache_dir=model_args.cache_dir, 386 | token=model_args.token, 387 | ) 388 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 389 | # https://huggingface.co/docs/datasets/loading. 390 | 391 | # Load pretrained model and tokenizer 392 | # 393 | # Distributed training: 394 | # The .from_pretrained methods guarantee that only one local process can concurrently 395 | # download model & vocab. 396 | config = AutoConfig.from_pretrained( 397 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 398 | cache_dir=model_args.cache_dir, 399 | revision=model_args.model_revision, 400 | token=model_args.token, 401 | trust_remote_code=model_args.trust_remote_code, 402 | ) 403 | # tokenizer = AutoTokenizer.from_pretrained( 404 | # model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 405 | # cache_dir=model_args.cache_dir, 406 | # use_fast=model_args.use_fast_tokenizer, 407 | # revision=model_args.model_revision, 408 | # token=model_args.token, 409 | # trust_remote_code=model_args.trust_remote_code, 410 | # ) 411 | 412 | tokenizer = PreTrainedTokenizerFast(tokenizer_file=model_args.tokenizer_name) 413 | tokenizer.pad_token = "[PAD]" 414 | print("tokenizer.pad_token_id",tokenizer.pad_token_id) 415 | 416 | 417 | model = AutoModelForSeq2SeqLM.from_config(config, trust_remote_code=model_args.trust_remote_code) 418 | n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values()) 419 | logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params") 420 | 421 | 422 | # model = AutoModelForSeq2SeqLM.from_pretrained( 423 | # model_args.model_name_or_path, 424 | # from_tf=bool(".ckpt" in model_args.model_name_or_path), 425 | # config=config, 426 | # cache_dir=model_args.cache_dir, 427 | # revision=model_args.model_revision, 428 | # token=model_args.token, 429 | # trust_remote_code=model_args.trust_remote_code, 430 | # ) 431 | 432 | # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch 433 | # on a small vocab and want a smaller embedding size, remove this test. 
434 | embedding_size = model.get_input_embeddings().weight.shape[0] 435 | if len(tokenizer) > embedding_size: 436 | model.resize_token_embeddings(len(tokenizer)) 437 | 438 | # Set decoder_start_token_id 439 | if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): 440 | if isinstance(tokenizer, MBartTokenizer): 441 | model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang] 442 | else: 443 | model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang) 444 | 445 | if model.config.decoder_start_token_id is None: 446 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") 447 | 448 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 449 | 450 | # Preprocessing the datasets. 451 | # We need to tokenize inputs and targets. 452 | if training_args.do_train: 453 | column_names = raw_datasets["train"].column_names 454 | elif training_args.do_eval: 455 | column_names = raw_datasets["validation"].column_names 456 | elif training_args.do_predict: 457 | column_names = raw_datasets["test"].column_names 458 | else: 459 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") 460 | return 461 | 462 | # For translation we set the codes of our source and target languages (only useful for mBART, the others will 463 | # ignore those attributes). 464 | if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)): 465 | assert data_args.target_lang is not None and data_args.source_lang is not None, ( 466 | f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --source_lang and " 467 | "--target_lang arguments." 468 | ) 469 | 470 | tokenizer.src_lang = data_args.source_lang 471 | tokenizer.tgt_lang = data_args.target_lang 472 | 473 | # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token 474 | # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument. 475 | forced_bos_token_id = ( 476 | tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None 477 | ) 478 | model.config.forced_bos_token_id = forced_bos_token_id 479 | 480 | # Get the language codes for input/target. 481 | source_lang = data_args.source_lang.split("_")[0] 482 | target_lang = data_args.target_lang.split("_")[0] 483 | 484 | # Check the whether the source target length fits in the model, if it has absolute positional embeddings 485 | if ( 486 | hasattr(model.config, "max_position_embeddings") 487 | and not hasattr(model.config, "relative_attention_max_distance") 488 | and model.config.max_position_embeddings < data_args.max_source_length 489 | ): 490 | raise ValueError( 491 | f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has" 492 | f" {model.config.max_position_embeddings} position encodings. Consider either reducing" 493 | f" `--max_source_length` to {model.config.max_position_embeddings} or using a model with larger position " 494 | "embeddings" 495 | ) 496 | 497 | # Temporarily set max_target_length for training. 
498 | max_target_length = data_args.max_target_length 499 | padding = "max_length" if data_args.pad_to_max_length else False 500 | 501 | if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): 502 | logger.warning( 503 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for " 504 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" 505 | ) 506 | 507 | def preprocess_function(examples): 508 | inputs = [ex[source_lang] for ex in examples["translation"]] 509 | targets = [ex[target_lang] for ex in examples["translation"]] 510 | inputs = [prefix + inp for inp in inputs] 511 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) 512 | 513 | # Tokenize targets with the `text_target` keyword argument 514 | labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True) 515 | 516 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore 517 | # padding in the loss. 518 | if padding == "max_length" and data_args.ignore_pad_token_for_loss: 519 | labels["input_ids"] = [ 520 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 521 | ] 522 | 523 | model_inputs["labels"] = labels["input_ids"] 524 | return model_inputs 525 | 526 | if training_args.do_train: 527 | if "train" not in raw_datasets: 528 | raise ValueError("--do_train requires a train dataset") 529 | train_dataset = raw_datasets["train"] 530 | if data_args.max_train_samples is not None: 531 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 532 | train_dataset = train_dataset.select(range(max_train_samples)) 533 | with training_args.main_process_first(desc="train dataset map pre-processing"): 534 | train_dataset = train_dataset.map( 535 | preprocess_function, 536 | batched=True, 537 | num_proc=data_args.preprocessing_num_workers, 538 | remove_columns=column_names, 539 | load_from_cache_file=not data_args.overwrite_cache, 540 | desc="Running tokenizer on train dataset", 541 | ) 542 | 543 | if training_args.do_eval: 544 | max_target_length = data_args.val_max_target_length 545 | if "validation" not in raw_datasets: 546 | raise ValueError("--do_eval requires a validation dataset") 547 | eval_dataset = raw_datasets["validation"] 548 | if data_args.max_eval_samples is not None: 549 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 550 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 551 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 552 | eval_dataset = eval_dataset.map( 553 | preprocess_function, 554 | batched=True, 555 | num_proc=data_args.preprocessing_num_workers, 556 | remove_columns=column_names, 557 | load_from_cache_file=not data_args.overwrite_cache, 558 | desc="Running tokenizer on validation dataset", 559 | ) 560 | 561 | if training_args.do_predict: 562 | max_target_length = data_args.val_max_target_length 563 | if "test" not in raw_datasets: 564 | raise ValueError("--do_predict requires a test dataset") 565 | predict_dataset = raw_datasets["test"] 566 | if data_args.max_predict_samples is not None: 567 | max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) 568 | predict_dataset = predict_dataset.select(range(max_predict_samples)) 569 | with 
training_args.main_process_first(desc="prediction dataset map pre-processing"): 570 | predict_dataset = predict_dataset.map( 571 | preprocess_function, 572 | batched=True, 573 | num_proc=data_args.preprocessing_num_workers, 574 | remove_columns=column_names, 575 | load_from_cache_file=not data_args.overwrite_cache, 576 | desc="Running tokenizer on prediction dataset", 577 | ) 578 | 579 | # Data collator 580 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 581 | if data_args.pad_to_max_length: 582 | data_collator = default_data_collator 583 | else: 584 | data_collator = DataCollatorForSeq2Seq( 585 | tokenizer, 586 | model=model, 587 | label_pad_token_id=label_pad_token_id, 588 | pad_to_multiple_of=8 if training_args.fp16 else None, 589 | ) 590 | 591 | # Metric 592 | metric = evaluate.load("sacrebleu", cache_dir=model_args.cache_dir) 593 | 594 | def postprocess_text(preds, labels): 595 | preds = [pred.strip() for pred in preds] 596 | labels = [[label.strip()] for label in labels] 597 | 598 | return preds, labels 599 | 600 | def compute_metrics(eval_preds): 601 | preds, labels = eval_preds 602 | if isinstance(preds, tuple): 603 | preds = preds[0] 604 | # Replace -100s used for padding as we can't decode them 605 | preds = np.where(preds != -100, preds, tokenizer.pad_token_id) 606 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 607 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 608 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 609 | 610 | # Some simple post-processing 611 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) 612 | 613 | result = metric.compute(predictions=decoded_preds, references=decoded_labels) 614 | result = {"bleu": result["score"]} 615 | 616 | prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] 617 | result["gen_len"] = np.mean(prediction_lens) 618 | result = {k: round(v, 4) for k, v in result.items()} 619 | return result 620 | 621 | # Initialize our Trainer 622 | trainer = Seq2SeqTrainer( 623 | model=model, 624 | args=training_args, 625 | train_dataset=train_dataset if training_args.do_train else None, 626 | eval_dataset=eval_dataset if training_args.do_eval else None, 627 | tokenizer=tokenizer, 628 | data_collator=data_collator, 629 | compute_metrics=compute_metrics if training_args.predict_with_generate else None, 630 | ) 631 | 632 | # Training 633 | if training_args.do_train: 634 | checkpoint = None 635 | if training_args.resume_from_checkpoint is not None: 636 | checkpoint = training_args.resume_from_checkpoint 637 | elif last_checkpoint is not None: 638 | checkpoint = last_checkpoint 639 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 640 | trainer.save_model() # Saves the tokenizer too for easy upload 641 | 642 | metrics = train_result.metrics 643 | max_train_samples = ( 644 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 645 | ) 646 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 647 | 648 | trainer.log_metrics("train", metrics) 649 | trainer.save_metrics("train", metrics) 650 | trainer.save_state() 651 | 652 | # Evaluation 653 | results = {} 654 | max_length = ( 655 | training_args.generation_max_length 656 | if training_args.generation_max_length is not None 657 | else data_args.val_max_target_length 658 | ) 659 | num_beams = data_args.num_beams if data_args.num_beams is not None else 
training_args.generation_num_beams 660 | if training_args.do_eval: 661 | logger.info("*** Evaluate ***") 662 | 663 | metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval") 664 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 665 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 666 | 667 | trainer.log_metrics("eval", metrics) 668 | trainer.save_metrics("eval", metrics) 669 | 670 | if training_args.do_predict: 671 | logger.info("*** Predict ***") 672 | 673 | predict_results = trainer.predict( 674 | predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams 675 | ) 676 | metrics = predict_results.metrics 677 | max_predict_samples = ( 678 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 679 | ) 680 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 681 | 682 | trainer.log_metrics("predict", metrics) 683 | trainer.save_metrics("predict", metrics) 684 | 685 | if trainer.is_world_process_zero(): 686 | if training_args.predict_with_generate: 687 | predictions = predict_results.predictions 688 | predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id) 689 | predictions = tokenizer.batch_decode( 690 | predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True 691 | ) 692 | predictions = [pred.strip() for pred in predictions] 693 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt") 694 | with open(output_prediction_file, "w", encoding="utf-8") as writer: 695 | writer.write("\n".join(predictions)) 696 | 697 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "translation"} 698 | if data_args.dataset_name is not None: 699 | kwargs["dataset_tags"] = data_args.dataset_name 700 | if data_args.dataset_config_name is not None: 701 | kwargs["dataset_args"] = data_args.dataset_config_name 702 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 703 | else: 704 | kwargs["dataset"] = data_args.dataset_name 705 | 706 | languages = [l for l in [data_args.source_lang, data_args.target_lang] if l is not None] 707 | if len(languages) > 0: 708 | kwargs["language"] = languages 709 | 710 | if training_args.push_to_hub: 711 | trainer.push_to_hub(**kwargs) 712 | else: 713 | trainer.create_model_card(**kwargs) 714 | 715 | return results 716 | 717 | 718 | def _mp_fn(index): 719 | # For xla_spawn (TPUs) 720 | main() 721 | 722 | 723 | if __name__ == "__main__": 724 | main() 725 | --------------------------------------------------------------------------------
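
The pipeline above ends with batch prediction (70_inference.sh writes generated_predictions.txt into the model directory). For a quick interactive check of a trained checkpoint, here is a minimal sketch; it is not a file of the repository, it assumes a checkpoint trained by 60_train_model.sh exists under ../models/bart-tiny-lr1e-4, and example.pyc.txt is a placeholder for any output of 20_convert_base256.py.

# sketch: decompile one base-256-encoded .pyc with the trained checkpoint
from transformers import AutoModelForSeq2SeqLM, PreTrainedTokenizerFast

# same tokenizer file and output_dir as in 60_train_model.sh / 70_inference.sh
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="../tokenizers/bpe_combined_ByteLevel_8000vocab_10000subset.json"
)
tokenizer.pad_token = "[PAD]"
model = AutoModelForSeq2SeqLM.from_pretrained("../models/bart-tiny-lr1e-4")

# read one encoded bytecode sample ("example.pyc.txt" is a placeholder name)
with open("../data/processed/codeparrot-clean-train/compiled-3.8.18/example.pyc.txt") as f:
    bytecode_text = f.read()

inputs = tokenizer(bytecode_text, return_tensors="pt", truncation=True, max_length=1024)
outputs = model.generate(**inputs, max_length=1024, num_beams=4)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])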