├── dataset.py
├── .gitignore
├── push_to_hub.py
├── LICENSE
├── inference_single.py
├── config.yaml
├── README.md
├── inference.py
├── metric.py
├── train.py
└── CC3M_translate_inference.py

/dataset.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__
runs
wandb
_clones
logs
ke-t5-base-finetuned-en-to-ko
results
--------------------------------------------------------------------------------
/push_to_hub.py:
--------------------------------------------------------------------------------
from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM
from easydict import EasyDict
import yaml

# Read config.yaml file
with open("config.yaml") as infile:
    SAVED_CFG = yaml.load(infile, Loader=yaml.FullLoader)
CFG = EasyDict(SAVED_CFG["CFG"])

model_name = "/home/ubuntu/En_to_Ko/ke-t5-base-finetuned-en-to-ko/checkpoint-17850"
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

config.push_to_hub("QuoQA-NLP/KE-T5-En2Ko-Base", private=True, use_temp_dir=True)
tokenizer.push_to_hub("QuoQA-NLP/KE-T5-En2Ko-Base", private=True, use_temp_dir=True)
model.push_to_hub("QuoQA-NLP/KE-T5-En2Ko-Base", private=True, use_temp_dir=True)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 QuoQA-NLP

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/inference_single.py:
--------------------------------------------------------------------------------
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from easydict import EasyDict
import yaml

# Read config.yaml file
with open("config.yaml") as infile:
    SAVED_CFG = yaml.load(infile, Loader=yaml.FullLoader)
CFG = EasyDict(SAVED_CFG["CFG"])

# https://huggingface.co/datasets/conceptual_captions
src_text = [
    "sierra looked stunning in this top and this skirt while performing with person at their former university"
]

# model_name = "/home/ubuntu/En_to_Ko/ke-t5-base-finetuned-en-to-ko/checkpoint-17850"
model_name = CFG.inference_model_name
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=True)

translated = model.generate(
    **tokenizer(
        src_text,
        return_tensors="pt",
        padding=True,
        truncation=True,  # required for max_length to actually cap the input
        max_length=CFG.max_token_length,
    ),
    max_length=CFG.max_token_length,
    num_beams=CFG.num_beams,
    repetition_penalty=CFG.repetition_penalty,
    no_repeat_ngram_size=CFG.no_repeat_ngram_size,
    num_return_sequences=CFG.num_return_sequences,
)
print([tokenizer.decode(t, skip_special_tokens=True) for t in translated])
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
# Set DEBUG to true in order to debug high-level code.
# CFG Configuration
# https://wandb.ai/poolc/huggingface/runs/r0pyyxyg/overview?workspace=user-snoop2head
CFG:
  DEBUG: false
  train_batch_size: 64
  valid_batch_size: 128

  # Train configuration
  num_epochs: 1 # validation loss starts increasing after 5 epochs
  num_checkpoints: 3
  max_token_length: 64
  stopwords: []
  learning_rate: 0.0005 # has to be set explicitly as a float: https://stackoverflow.com/questions/30458977/yaml-loads-5e-6-as-string-and-not-a-number
  weight_decay: 0.01 # https://paperswithcode.com/method/weight-decay
  adam_beta_1: 0.9
  adam_beta_2: 0.98
  epsilon: 0.000000001
  fp16: false
  gradient_accumulation_steps: 2
  save_steps: 150
  logging_steps: 150
  evaluation_strategy: "epoch"

  # Evaluation configuration
  inference_model_name: "QuoQA-NLP/KE-T5-En2Ko-Base"
  no_inference_sentences: 100
  num_beams: 5
  repetition_penalty: 1.3
  no_repeat_ngram_size: 3
  num_return_sequences: 1

  # Translation settings
  dset_name: "LeverageX/AIHUB-all-parallel-ko-en" # or LeverageX/AIHUB-socio-parallel-ko-en
  src_language: "en"
  tgt_language: "ko"
  model_name: "KETI-AIR/ke-t5-base"
  num_inference_sample: 5000
  dropout: 0.1

  # wandb settings
  entity_name: "quoqa-nlp"
  project_name: "EN-TO-KO-Translation"

  # root path
  ROOT_PATH: "."
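  # directory where train.py's trainer.save_model() writes the final model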
  save_path: "./results"
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# T5 Machine Translation: English ↔️ Korean

[![Streamlit App](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://huggingface.co/spaces/QuoQA-NLP/QuoQaGo)

### Result

|                   | BLEU Score | Translation Result                                                                                               |
| :---------------: | :--------: | :--------------------------------------------------------------------------------------------------------------: |
| Korean ➡️ English |   45.148   | [KE-T5-Ko2En-Base Inference Result](https://huggingface.co/datasets/QuoQA-NLP/KE-T5-Ko2En-Base-Inference-Result) |
| English ➡️ Korean |     -      |                                                                                                                  |

- The evaluation script is in [metric.py](./metric.py).
- The Korean ➡️ English result was evaluated on 553,500 sentence pairs that are disjoint from the train set.

### How to Use

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Korean -> English Machine Translation
tokenizer = AutoTokenizer.from_pretrained("QuoQA-NLP/KE-T5-Ko2En-Base")
model = AutoModelForSeq2SeqLM.from_pretrained("QuoQA-NLP/KE-T5-Ko2En-Base")

# English -> Korean Machine Translation
tokenizer = AutoTokenizer.from_pretrained("QuoQA-NLP/KE-T5-En2Ko-Base")
model = AutoModelForSeq2SeqLM.from_pretrained("QuoQA-NLP/KE-T5-En2Ko-Base")
```

- For batch translation, please refer to [inference.py](./inference.py).
  - A P100 (16 GB) supports inference of 250 pairs per batch on device.
  - An A100 (40 GB) supports inference of 600 pairs per batch on device.
- For single sentence translation, please refer to [inference_single.py](./inference_single.py); a full generation example is sketched below.
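
Loading a checkpoint is only half the pipeline; translation itself goes through `model.generate`. Below is a minimal end-to-end sketch using the Ko2En checkpoint, assuming the decoding parameters from [config.yaml](./config.yaml); the Korean sample sentence is only an illustrative input.

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("QuoQA-NLP/KE-T5-Ko2En-Base")
model = AutoModelForSeq2SeqLM.from_pretrained("QuoQA-NLP/KE-T5-Ko2En-Base")

src = ["오늘 날씨가 정말 좋네요."]  # illustrative input: "The weather is really nice today."
encoding = tokenizer(src, return_tensors="pt", padding=True, truncation=True, max_length=64)

# beam search settings mirror config.yaml
generated = model.generate(
    **encoding,
    max_length=64,
    num_beams=5,
    repetition_penalty=1.3,
    no_repeat_ngram_size=3,
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```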

### References

- [🔗 Dataset specification](https://github.com/snoop2head/Deep-Encoder-Shallow-Decoder#dataset)
- [Translation Example](https://github.com/huggingface/notebooks/blob/main/examples/translation.ipynb)
- [Summarization Example](https://github.com/huggingface/notebooks/blob/main/examples/summarization.ipynb)
- [Deep Encoder Shallow Decoder](https://github.com/snoop2head/Deep-Encoder-Shallow-Decoder/blob/main/trainer.py)
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
import torch
import pandas as pd
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from easydict import EasyDict
import yaml


# Read config.yaml file
with open("config.yaml") as infile:
    SAVED_CFG = yaml.load(infile, Loader=yaml.FullLoader)
CFG = EasyDict(SAVED_CFG["CFG"])
device = "cuda:0" if torch.cuda.is_available() else "cpu"

model_name = CFG.inference_model_name
valid_dataset = load_dataset(CFG.dset_name, split="train", use_auth_token=True)
print(valid_dataset)


tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=True)
model.to(device)


start = 0
batch_size = 250  # P100: batch_size 250 / A100: batch_size 700
length = len(valid_dataset)
cnt = length // batch_size + 1
df = pd.DataFrame(columns=["src", "gen", "image_url"])  # list, not set: column order matters

csv_start = 0
save_start = csv_start
save_count = 0
for i in tqdm(range(start, cnt)):
    save_count += 1
    if i == cnt - 1:
        end = len(valid_dataset)
    else:
        end = csv_start + batch_size

    src_sentences = valid_dataset["caption"][csv_start:end]
    urls = valid_dataset["image_url"][csv_start:end]

    encoding = tokenizer(
        src_sentences,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=CFG.max_token_length,
    ).to(device)

    # https://huggingface.co/docs/transformers/internal/generation_utils
    with torch.no_grad():
        translated = model.generate(
            **encoding,
            max_length=CFG.max_token_length,
            num_beams=CFG.num_beams,
            repetition_penalty=CFG.repetition_penalty,
            no_repeat_ngram_size=CFG.no_repeat_ngram_size,
            num_return_sequences=CFG.num_return_sequences,
        )
    del encoding

    # https://github.com/huggingface/transformers/issues/10704
    generated_texts = tokenizer.batch_decode(translated, skip_special_tokens=True)
    del translated
    print(generated_texts[0:2])

    df1 = pd.DataFrame({"src": src_sentences, "gen": generated_texts, "image_url": urls})
    df = pd.concat([df, df1], ignore_index=True)  # DataFrame.append was removed in pandas 2.0
    if save_count == 30 or i == cnt - 1:  # also flush the final partial chunk
        save_count = 0
        df.to_csv(f"./results/tmp_translated-{save_start}-{end}-sentences.csv", index=False)
    csv_start = end

# df = pd.DataFrame({"src": src_sentences, "tgt": tgt_sentences, "gen": generated_texts})
df.to_csv(f"./results/translated-{CFG.no_inference_sentences}-sentences.csv", index=False) 87 | -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | from tqdm import tqdm 4 | import time 5 | import numpy as np 6 | from datasets import load_dataset, load_metric, Dataset 7 | from transformers import ( 8 | AutoTokenizer, 9 | MarianMTModel, 10 | AutoTokenizer, 11 | AutoModelForSeq2SeqLM, 12 | T5Tokenizer, 13 | ) 14 | from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq 15 | import multiprocessing 16 | from easydict import EasyDict 17 | import yaml 18 | 19 | 20 | # Read config.yaml file 21 | with open("config.yaml") as infile: 22 | SAVED_CFG = yaml.load(infile, Loader=yaml.FullLoader) 23 | CFG = EasyDict(SAVED_CFG["CFG"]) 24 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 25 | 26 | training_args = Seq2SeqTrainingArguments 27 | 28 | model_name = CFG.inference_model_name 29 | valid_dataset = load_dataset(CFG.dset_name, split="valid", use_auth_token=True) 30 | print(valid_dataset) 31 | 32 | 33 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True) 34 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=True) 35 | model.to(device) 36 | 37 | 38 | start = 0 39 | batch_size = 150 # P100:batch_size 250 / A100:batch_size 700 40 | length = len(valid_dataset) 41 | cnt = length // batch_size + 1 42 | df = pd.DataFrame(columns={"src", "gen", "label"}) 43 | 44 | csv_start = 0 45 | save_start = csv_start 46 | save_count = 0 47 | for i in tqdm(range(start, cnt)): 48 | save_count += 1 49 | if i == cnt - 1: 50 | end = len(valid_dataset) 51 | else: 52 | end = csv_start + batch_size 53 | 54 | src_sentences = valid_dataset["ko"][csv_start:end] 55 | label = valid_dataset["en"][csv_start:end] 56 | 57 | encoding = tokenizer( 58 | src_sentences, padding=True, return_tensors="pt", max_length=CFG.max_token_length 59 | ).to(device) 60 | 61 | # https://huggingface.co/docs/transformers/internal/generation_utils 62 | with torch.no_grad(): 63 | translated = model.generate( 64 | **encoding, 65 | max_length=CFG.max_token_length, 66 | num_beams=CFG.num_beams, 67 | repetition_penalty=CFG.repetition_penalty, 68 | no_repeat_ngram_size=CFG.no_repeat_ngram_size, 69 | num_return_sequences=CFG.num_return_sequences, 70 | ) 71 | del encoding 72 | 73 | # https://github.com/huggingface/transformers/issues/10704 74 | generated_texts = tokenizer.batch_decode(translated, skip_special_tokens=True) 75 | del translated 76 | print(generated_texts[0:2]) 77 | 78 | df1 = pd.DataFrame({"src": src_sentences, "gen": generated_texts, "label": label}) 79 | df = df.append(df1, ignore_index=True) 80 | if save_count == 30: 81 | save_count = 0 82 | df.to_csv(f"./results/tmp_translated-{save_start}-{end}-sentences.csv", index=False) 83 | csv_start = end 84 | 85 | # load sacrebleu 86 | # https://huggingface.co/spaces/evaluate-metric/sacrebleu | https://github.com/mjpost/sacreBLEU 87 | metric = load_metric("sacrebleu") 88 | 89 | preds = df["gen"] 90 | labels = np.expand_dims(df["label"], axis=1) 91 | 92 | score = metric.compute(predictions=preds, references=labels) # takes 3 minutes for 550K pairs 93 | print(score) 94 | 95 | """ 96 | # Result of Korean to English Translation 97 | { 98 | "score": 45.14821527744787, 99 | "counts": [10287887, 6969037, 5035938, 3719578], 100 | "totals": [14100267, 13546767, 
    "precisions": [72.96235596106088, 51.44428187182964, 38.75805830819916, 29.90070473184908],
    "bp": 0.9886003016662179,
    "sys_len": 14100267,
    "ref_len": 14261929,
}
"""
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from transformers import (
    T5Tokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
import wandb
import numpy as np
from datasets import load_dataset, load_metric
import multiprocessing
from easydict import EasyDict
import yaml

# Read config.yaml file
with open("config.yaml") as infile:
    SAVED_CFG = yaml.load(infile, Loader=yaml.FullLoader)
CFG = EasyDict(SAVED_CFG["CFG"])

metric = load_metric("sacrebleu")

# all dataset
dset = load_dataset(CFG.dset_name, use_auth_token=True)
tokenizer = T5Tokenizer.from_pretrained(CFG.model_name)  # https://github.com/AIRC-KETI/ke-t5#models


def preprocess_function(examples):
    inputs = examples[CFG.src_language]
    targets = examples[CFG.tgt_language]
    model_inputs = tokenizer(inputs, max_length=CFG.max_token_length, truncation=True)
    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=CFG.max_token_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


# print(preprocess_function(dset["train"].select(range(0, 2))))

CPU_COUNT = multiprocessing.cpu_count() // 2

tokenized_datasets = dset.map(preprocess_function, batched=True, num_proc=CPU_COUNT)


model = AutoModelForSeq2SeqLM.from_pretrained(CFG.model_name)

str_model_name = CFG.model_name.split("/")[-1]
run_name = f"{str_model_name}-finetuned-{CFG.src_language}-to-{CFG.tgt_language}"
wandb.init(entity=CFG.entity_name, project=CFG.project_name, name=run_name)

training_args = Seq2SeqTrainingArguments(
    run_name,
    learning_rate=CFG.learning_rate,
    weight_decay=CFG.weight_decay,
    per_device_train_batch_size=CFG.train_batch_size,
    per_device_eval_batch_size=CFG.valid_batch_size,
    evaluation_strategy=CFG.evaluation_strategy,
    # eval_steps=CFG.eval_steps,
    save_steps=CFG.save_steps,
    num_train_epochs=CFG.num_epochs,
    save_total_limit=CFG.num_checkpoints,
    predict_with_generate=True,
    fp16=CFG.fp16,
    gradient_accumulation_steps=CFG.gradient_accumulation_steps,
    logging_steps=CFG.logging_steps,
)

wandb.config.update(training_args)


def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]
    return preds, labels


def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
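    # (-100 is the ignore_index that DataCollatorForSeq2Seq uses for padded label positions.)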
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result


data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["valid"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()

trainer.evaluate()
trainer.save_model(CFG.save_path)
wandb.finish()
--------------------------------------------------------------------------------
/CC3M_translate_inference.py:
--------------------------------------------------------------------------------
import torch
import pandas as pd
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from easydict import EasyDict
import yaml


# Read config.yaml file
with open("config.yaml") as infile:
    SAVED_CFG = yaml.load(infile, Loader=yaml.FullLoader)
CFG = EasyDict(SAVED_CFG["CFG"])
device = "cuda:0" if torch.cuda.is_available() else "cpu"

model_name = CFG.inference_model_name
CFG.dset_name = "conceptual_captions"
train_dataset = load_dataset(CFG.dset_name, split="train")
valid_dataset = load_dataset(CFG.dset_name, split="validation")
print(train_dataset)
print(valid_dataset)


tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=True)
model.to(device)


start = 0
batch_size = 600  # P100: batch_size 250 / A100: batch_size 600
length = len(train_dataset)
cnt = length // batch_size + 1
df = pd.DataFrame(columns=["english_caption", "korean_caption", "image_url"])

# start translating the train dataset

csv_start = 0
save_start = csv_start
save_count = 0
for i in tqdm(range(start, cnt)):
    save_count += 1
    check = False

    end = csv_start + batch_size
    if end > len(train_dataset):
        check = True
        end = len(train_dataset)

    src_sentences = train_dataset["caption"][csv_start:end]
    urls = train_dataset["image_url"][csv_start:end]

    encoding = tokenizer(
        src_sentences,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=CFG.max_token_length,
    ).to(device)

    # https://huggingface.co/docs/transformers/internal/generation_utils
    with torch.no_grad():
        translated = model.generate(
            **encoding,
            max_length=CFG.max_token_length,
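            # beam-search and repetition settings mirror config.yaml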
            num_beams=CFG.num_beams,
            repetition_penalty=CFG.repetition_penalty,
            no_repeat_ngram_size=CFG.no_repeat_ngram_size,
            num_return_sequences=CFG.num_return_sequences,
        )
    del encoding

    # https://github.com/huggingface/transformers/issues/10704
    generated_texts = tokenizer.batch_decode(translated, skip_special_tokens=True)
    del translated
    print(generated_texts[0:2])

    df1 = pd.DataFrame({"english_caption": src_sentences, "korean_caption": generated_texts, "image_url": urls})
    df = pd.concat([df, df1], ignore_index=True)  # DataFrame.append was removed in pandas 2.0
    if save_count == 30 or check:
        save_count = 0
        df.to_csv(f"./results/train_translated-{save_start}-{end}-sentences.csv", index=False)
    csv_start = end


start = 0
batch_size = 600  # P100: batch_size 250 / A100: batch_size 600
length = len(valid_dataset)
cnt = length // batch_size + 1
df = pd.DataFrame(columns=["english_caption", "korean_caption", "image_url"])

# start translating the validation dataset

csv_start = 0
save_start = csv_start
save_count = 0
for i in tqdm(range(start, cnt)):
    save_count += 1
    check = False

    end = csv_start + batch_size
    if end > len(valid_dataset):
        check = True
        end = len(valid_dataset)

    src_sentences = valid_dataset["caption"][csv_start:end]
    urls = valid_dataset["image_url"][csv_start:end]

    encoding = tokenizer(
        src_sentences,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=CFG.max_token_length,
    ).to(device)

    # https://huggingface.co/docs/transformers/internal/generation_utils
    with torch.no_grad():
        translated = model.generate(
            **encoding,
            max_length=CFG.max_token_length,
            num_beams=CFG.num_beams,
            repetition_penalty=CFG.repetition_penalty,
            no_repeat_ngram_size=CFG.no_repeat_ngram_size,
            num_return_sequences=CFG.num_return_sequences,
        )
    del encoding

    # https://github.com/huggingface/transformers/issues/10704
    generated_texts = tokenizer.batch_decode(translated, skip_special_tokens=True)
    del translated
    print(generated_texts[0:2])

    df1 = pd.DataFrame({"english_caption": src_sentences, "korean_caption": generated_texts, "image_url": urls})
    df = pd.concat([df, df1], ignore_index=True)
    if save_count == 30 or check:
        save_count = 0
        df.to_csv(f"./results/valid_translated-{save_start}-{end}-sentences.csv", index=False)
    csv_start = end


# df = pd.DataFrame({"src": src_sentences, "tgt": tgt_sentences, "gen": generated_texts})
# df.to_csv(f"./results/translated-{CFG.no_inference_sentences}-sentences.csv", index=False)
--------------------------------------------------------------------------------