├── .gitignore ├── README.md ├── environment.yml ├── eval.py ├── models ├── __init__.py ├── seq2seq_sc.py ├── utils.py └── utils_test.py ├── poster_seq2seq_Asilomar2023.jpg ├── preprocess ├── __init__.py ├── allnli.py ├── europarl.py ├── flickr30k.py └── hf_data_gen.py ├── scripts ├── eval_flickr.sh ├── preprocess_allnli.sh ├── preprocess_europarl.sh ├── preprocess_flickr30k.sh ├── train_allnli.sh └── train_europarl.sh ├── train.py └── train ├── __init__.py └── args.py /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints 2 | **/__pycache__ 3 | data 4 | .vscode -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # seq2seq-SC 2 | 3 | ![poster](poster_seq2seq_Asilomar2023.jpg) 4 | 5 | ## Citation 6 | 7 | ```bibtex 8 | @misc{lee2022seq2seqSC, 9 | author = {Lee, Ju-Hyung and Lee, Dong-Ho and Sheen, Eunsoo and Choi, Thomas and Pujara, Jay and Kim, Joongheon}, 10 | title = {Seq2Seq-SC: End-to-End Semantic Communication Systems with Pre-trained Language Model}, 11 | journal = {arXiv preprint arXiv:2210.15237}, 12 | year = {2022}, 13 | } 14 | ``` 15 | 16 | ## Setup 17 | 18 | 1. Set up the conda environment and activate it (`conda activate seq2seq-sc`) 19 | 20 | ```bash 21 | conda env create -f environment.yml 22 | ``` 23 | 24 | ## Data Preprocessing 25 | 26 | ### Europarl dataset 27 | 28 | ```bash 29 | data_path=data/europarl 30 | mkdir -p $data_path 31 | # download and extract into $data_path 32 | wget -P /tmp http://www.statmt.org/europarl/v7/europarl.tgz 33 | tar zxf /tmp/europarl.tgz -C $data_path 34 | 35 | europarl_dataset="$data_path/txt/en" 36 | out_dir="$data_path/processed" 37 | njobs=4 38 | 39 | mkdir -p $out_dir 40 | python -m preprocess.europarl -j $njobs -o $out_dir $europarl_dataset 41 | ``` 42 | 43 | ### AllNLI 44 | 45 | Run `./scripts/preprocess_allnli.sh` or the following commands: 46 | 47 | ```bash 48 | data_path=data/allnli 49 | mkdir -p $data_path 50 | wget -P $data_path https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/AllNLI.jsonl.gz 51 | gunzip $data_path/AllNLI.jsonl.gz 52 | 53 | allnli_dataset="$data_path/AllNLI.jsonl" 54 | out_dir="$data_path/processed" 55 | 56 | mkdir -p $out_dir 57 | python -m preprocess.allnli -o $out_dir $allnli_dataset 58 | ``` 59 | 60 | ### Flickr30K 61 | 62 | To download the dataset, go to [Flickr30K](http://hockenmaier.cs.illinois.edu/DenotationGraph/) and fill out the form to get the download link, then place `flickr30k.tar.gz` under `data/flickr/`. 63 | 64 | ```bash 65 | data_path="data/flickr" 66 | dataset_path="${data_path}/flickr30k.tar.gz" 67 | out_dir="$data_path/processed" 68 | 69 | mkdir -p $out_dir 70 | 71 | tar xzf ${dataset_path} -C $data_path 72 | python -m preprocess.flickr30k \ 73 | -o "$out_dir/flickr30k.json" \ 74 | "${data_path}/results_20130124.token" 75 | ``` 76 | 77 | ## Train 78 | 79 | You can run `scripts/train_europarl.sh` or `scripts/train_allnli.sh`, or train directly with `train.py` as shown below.
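The `--ebno_db` (channel Eb/N0 in dB) and `--channel_type` (`AWGN` or `FlatFadingChannel`) flags configure the simulated channel that seq2seq-SC inserts between the BART encoder and decoder; the remaining flags follow the HuggingFace TensorFlow summarization example that `train.py` is based on. Once training has finished, the checkpoint can be loaded for a quick sanity check the same way `eval.py` does it. The snippet below is a minimal sketch, assuming a trained checkpoint at `checkpoints/seq2seq-sc`; the input sentence and the Eb/N0 value are placeholders.

```python
# Minimal sketch (mirrors eval.py): load a trained seq2seq-SC checkpoint and decode
# one sentence through the simulated channel. Run from the repository root; the
# checkpoint path, Eb/N0, and the input sentence are placeholders.
from transformers import BartTokenizer
from models import TFSeq2SeqSCForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = TFSeq2SeqSCForConditionalGeneration.from_pretrained(
    "checkpoints/seq2seq-sc", ebno_db=10.0)  # Eb/N0 used by the channel layer

inputs = tokenizer(["a man is playing a guitar on stage"],
                   return_tensors="tf", padding="max_length",
                   truncation=True, max_length=32)
outputs = model.generate(inputs.input_ids, max_new_tokens=32, num_beams=4)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```

The full training command is: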
80 | 81 | ```bash 82 | output_dir='checkpoints/seq2seq-sc' 83 | trainset_path='data/allnli/processed/allnli_train.csv' 84 | devset_path='data/allnli/processed/allnli_dev.csv' 85 | 86 | mkdir -p $output_dir 87 | 88 | python train.py \ 89 | --per_device_train_batch_size 4 \ 90 | --num_train_epochs 3 \ 91 | --do_train \ 92 | --do_eval \ 93 | --model_name_or_path facebook/bart-base \ 94 | --preprocessing_num_workers 4 \ 95 | --save_total_limit 1 \ 96 | --no_use_fast_tokenizer \ 97 | --num_beams 4 \ 98 | --max_source_length 64 \ 99 | --max_target_length 64 \ 100 | --train_file "$trainset_path" \ 101 | --validation_file "$devset_path" \ 102 | --test_file "$devset_path" \ 103 | --output_dir $output_dir \ 104 | --ebno_db 10 \ 105 | --channel_type AWGN \ 106 | --overwrite_output_dir \ 107 | --tokenizer_name facebook/bart-base \ 108 | --pad_to_max_length \ 109 | --dataset_config 3.0.0 110 | ``` 111 | 112 | ## Evaluation 113 | 114 | You can use the script `scripts/eval_flickr.sh` or the following commands: 115 | 116 | ```bash 117 | # BLEU score 118 | ebno_db="10" 119 | metric="bleu" # bleu, sbert 120 | testset_path='data/flickr/processed/flickr30k.json' 121 | checkpoint_path="checkpoints/seq2seq-allnli-sc" 122 | 123 | python eval.py \ 124 | --batch 4 \ 125 | --metric "${metric}" \ 126 | --ebno-db "${ebno_db}" \ 127 | --result-json-path "${checkpoint_path}/flikr_${metric}_ebno_${ebno_db}.json" \ 128 | --prediction-json-path "${checkpoint_path}/flikr_prediction_ebno_${ebno_db}.json" \ 129 | --testset-path "${testset_path}" \ 130 | $checkpoint_path 131 | ``` 132 | 133 | ```bash 134 | # SBERT 135 | ebno_db="10" 136 | metric="sbert" # bleu, sbert 137 | testset_path='data/flickr/processed/flickr30k.json' 138 | checkpoint_path="checkpoints/seq2seq-allnli-sc" 139 | 140 | python eval.py \ 141 | --batch 4 \ 142 | --metric "${metric}" \ 143 | --ebno-db "${ebno_db}" \ 144 | --result-json-path "${checkpoint_path}/flikr_${metric}_ebno_${ebno_db}.json" \ 145 | --prediction-json-path "${checkpoint_path}/flikr_prediction_ebno_${ebno_db}.json" \ 146 | --testset-path "${testset_path}" \ 147 | $checkpoint_path 148 | ``` 149 | 150 | 151 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: seq2seq-sc 2 | channels: 3 | - huggingface 4 | - anaconda 5 | - defaults 6 | - conda-forge 7 | dependencies: 8 | - pip==22.3 9 | - _libgcc_mutex=0.1=conda_forge 10 | - _openmp_mutex=4.5=2_gnu 11 | - blas=1.0=mkl 12 | - brotlipy=0.7.0=py39h27cfd23_1003 13 | - ca-certificates=2022.07.19=h06a4308_0 14 | - certifi=2022.9.14=py39h06a4308_0 15 | - cffi=1.15.1=py39h74dc2b5_0 16 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 17 | - click=8.0.4=py39h06a4308_0 18 | - cryptography=37.0.1=py39h9ce1e76_0 19 | - cudatoolkit=11.2.2=hbe64b41_10 20 | - cudnn=8.1.0.77=h90431f1_0 21 | - dataclasses=0.8=pyh6d0b6a4_7 22 | - filelock=3.6.0=pyhd3eb1b0_0 23 | - huggingface_hub=0.9.1=py_0 24 | - idna=3.3=pyhd3eb1b0_0 25 | - importlib-metadata=4.11.3=py39h06a4308_0 26 | - importlib_metadata=4.11.3=hd3eb1b0_0 27 | - intel-openmp=2021.4.0=h06a4308_3561 28 | - joblib=1.1.0=pyhd3eb1b0_0 29 | - ld_impl_linux-64=2.38=h1181459_1 30 | - libffi=3.3=he6710b0_2 31 | - libgcc-ng=12.1.0=h8d9b700_16 32 | - libgomp=12.1.0=h8d9b700_16 33 | - libprotobuf=3.20.1=h4ff587b_0 34 | - libstdcxx-ng=12.1.0=ha89aaad_16 35 | - mkl=2021.4.0=h06a4308_640 36 | - mkl-service=2.4.0=py39h7f8727e_0 37 | - mkl_fft=1.3.1=py39hd3c417c_0 38 | - 
mkl_random=1.2.2=py39h51133e4_0 39 | - ncurses=6.3=h5eee18b_3 40 | - numpy=1.23.1=py39h6c91a56_0 41 | - numpy-base=1.23.1=py39ha15fc14_0 42 | - openssl=1.1.1q=h7f8727e_0 43 | - packaging=21.3=pyhd3eb1b0_0 44 | - pycparser=2.21=pyhd3eb1b0_0 45 | - pyopenssl=22.0.0=pyhd3eb1b0_0 46 | - pyparsing=3.0.9=py39h06a4308_0 47 | - pysocks=1.7.1=py39h06a4308_0 48 | - python=3.9.13=haa1d7c7_1 49 | - pyyaml=6.0=py39h7f8727e_1 50 | - readline=8.1.2=h7f8727e_1 51 | - regex=2022.7.9=py39h5eee18b_0 52 | - requests=2.28.1=py39h06a4308_0 53 | - sacremoses=master=py_0 54 | - setuptools=63.4.1=py39h06a4308_0 55 | - six=1.16.0=pyhd3eb1b0_1 56 | - sqlite=3.39.2=h5082296_0 57 | - tk=8.6.12=h1ccaba5_0 58 | - tqdm=4.64.0=py39h06a4308_0 59 | - transformers=4.22.1=py_0 60 | - typing-extensions=4.3.0=py39h06a4308_0 61 | - typing_extensions=4.3.0=py39h06a4308_0 62 | - tzdata=2022c=h04d1e81_0 63 | - urllib3=1.26.11=py39h06a4308_0 64 | - wheel=0.37.1=pyhd3eb1b0_0 65 | - xz=5.2.5=h7f8727e_1 66 | - yaml=0.2.5=h7b6447c_0 67 | - zipp=3.8.0=py39h06a4308_0 68 | - zlib=1.2.12=h5eee18b_3 69 | - pip: 70 | - absl-py==1.2.0 71 | - accelerate==0.13 72 | - aiohttp==3.8.3 73 | - aiosignal==1.2.0 74 | - astroid==2.12.10 75 | - astunparse==1.6.3 76 | - async-timeout==4.0.2 77 | - attrs==22.1.0 78 | - bert-score==0.3.12 79 | - cachetools==5.2.0 80 | - contourpy==1.0.5 81 | - cycler==0.11.0 82 | - datasets==2.5.1 83 | - dill==0.3.5.1 84 | - evaluate==0.2.2 85 | - flatbuffers==2.0.7 86 | - fonttools==4.37.3 87 | - frozenlist==1.3.1 88 | - fsspec==2022.8.2 89 | - gast==0.4.0 90 | - google-auth==2.11.1 91 | - google-auth-oauthlib==0.4.6 92 | - google-pasta==0.2.0 93 | - grpcio==1.48.1 94 | - h5py==3.7.0 95 | - importlib-resources==5.9.0 96 | - isort==5.10.1 97 | - keras==2.10.0 98 | - keras-preprocessing==1.1.2 99 | - kiwisolver==1.4.4 100 | - lazy-object-proxy==1.7.1 101 | - libclang==14.0.6 102 | - markdown==3.4.1 103 | - markupsafe==2.1.1 104 | - matplotlib==3.6.0 105 | - mccabe==0.7.0 106 | - moverscore==1.0.3 107 | - multidict==6.0.2 108 | - multiprocess==0.70.13 109 | - nltk==3.7 110 | - oauthlib==3.2.1 111 | - opt-einsum==3.3.0 112 | - pandas==1.5.0 113 | - pillow==9.2.0 114 | - platformdirs==2.5.2 115 | - portalocker==2.6.0 116 | - protobuf==3.19.5 117 | - psutil==5.9.2 118 | - pyarrow==8.0.0 119 | - pyasn1==0.4.8 120 | - pyasn1-modules==0.2.8 121 | - pyemd==0.5.1 122 | - pylint==2.15.3 123 | - python-dateutil==2.8.2 124 | - pytz==2022.2.1 125 | - requests-oauthlib==1.3.1 126 | - responses==0.18.0 127 | - rouge-score==0.1.2 128 | - rsa==4.9 129 | - scikit-learn==1.1.2 130 | - scipy==1.9.1 131 | - sentence-transformers==2.2.2 132 | - sentencepiece==0.1.97 133 | - sionna==0.11.0 134 | - tensorboard==2.10.0 135 | - tensorboard-data-server==0.6.1 136 | - tensorboard-plugin-wit==1.8.1 137 | - tensorflow==2.10.0 138 | - tensorflow-estimator==2.10.0 139 | - tensorflow-io-gcs-filesystem==0.27.0 140 | - termcolor==2.0.1 141 | - threadpoolctl==3.1.0 142 | - tokenizers==0.12.1 143 | - tomli==2.0.1 144 | - tomlkit==0.11.4 145 | - torch==1.12.1 146 | - torchvision==0.13.1 147 | - typing==3.7.4.3 148 | - werkzeug==2.2.2 149 | - wrapt==1.14.1 150 | - xxhash==3.0.0 151 | - yapf==0.32.0 152 | - yarl==1.8.1 153 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import json 3 | import argparse 4 | import logging 5 | from transformers import BartTokenizer 6 | import evaluate 7 | from tqdm import tqdm 8 | import 
warnings 9 | 10 | def get_test_data(path): 11 | with open(path) as f: 12 | return json.load(f) 13 | 14 | def from_pretrained(path, ebno_db): 15 | from models import TFSeq2SeqSCForConditionalGeneration 16 | import transformers 17 | transformers.utils.logging.set_verbosity(logging.INFO) 18 | return TFSeq2SeqSCForConditionalGeneration.from_pretrained( 19 | path, ebno_db=ebno_db) 20 | 21 | def predict(path, ebno_db, tokenizer, batch_size, test_data_path, num_beams): 22 | import tensorflow as tf 23 | max_len = 32 24 | 25 | # load model 26 | model = from_pretrained(path, ebno_db) 27 | 28 | # # load testset 29 | test_data = get_test_data(test_data_path) 30 | input_sentences = [d['input'] for d in test_data] 31 | input_ids = tokenizer(input_sentences, return_tensors="tf", 32 | padding='max_length', truncation=True, max_length=max_len).input_ids 33 | testset = tf.data.Dataset.from_tensor_slices(input_ids) 34 | 35 | # inference 36 | pred_sentences = [] 37 | for input_ids in tqdm(testset.batch(batch_size).prefetch(tf.data.AUTOTUNE)): 38 | pred_batch = model.generate(input_ids, max_new_tokens=max_len, num_beams=num_beams) 39 | output_strs = tokenizer.batch_decode(pred_batch, 40 | skip_special_tokens=True, 41 | clean_up_tokenization_spaces=False) 42 | pred_sentences.extend(output_strs) 43 | 44 | 45 | res = { 46 | 'input': input_sentences, 47 | 'pred': pred_sentences, 48 | 'refs': [d['refs'] for d in test_data], 49 | } 50 | return res 51 | 52 | def get_predictions(path, ebno_db, test_data_path, prediction_json_path, batch_size, tokenizer, num_beams): 53 | path = pathlib.Path(path) 54 | if not prediction_json_path.exists(): 55 | print('Missing predictions.json') 56 | res = predict( 57 | path=path, 58 | ebno_db=ebno_db, 59 | tokenizer=tokenizer, 60 | batch_size=batch_size, 61 | test_data_path=test_data_path, 62 | num_beams=num_beams, 63 | ) 64 | 65 | # save result 66 | with open(prediction_json_path, 'w') as f: 67 | json.dump(res, f, indent=4) 68 | else: 69 | with open(prediction_json_path, 'r') as f: 70 | res = json.load(f) 71 | return res 72 | 73 | def calc_bleu(predictions, tokenizer, multi_ref, **kwargs): 74 | bleu = evaluate.load('bleu') 75 | if multi_ref: 76 | warnings.warn('BLEU does not support multiple references') 77 | tokenize = lambda l: tokenizer(l, add_special_tokens=False).input_ids 78 | results = bleu.compute( 79 | references=predictions['input'], 80 | predictions=predictions['pred'], 81 | tokenizer=tokenize, 82 | max_order=4) 83 | return results 84 | 85 | def calc_sbert(predictions, batch_size, multi_ref, **kwargs): 86 | from sentence_transformers import SentenceTransformer, util 87 | import torch 88 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 89 | model = SentenceTransformer( 90 | model_name_or_path='all-MiniLM-L6-v2', 91 | device=device) 92 | 93 | sentences1 = predictions['pred'] 94 | if not multi_ref: 95 | refs = [[s] for s in predictions['input']] 96 | else: 97 | refs = predictions['refs'] 98 | 99 | def calc_cos_score(model, hyp_embedding, ref_sentences): 100 | hyp = hyp_embedding.reshape((1, -1)) 101 | refs = model.encode(ref_sentences, convert_to_tensor=True) 102 | scores = util.cos_sim(hyp, refs) 103 | scores = scores.reshape((-1)).tolist() 104 | return { 105 | 'scores': scores, 106 | 'max_score': max(scores), 107 | 'mean_score': sum(scores) / len(scores), 108 | } 109 | 110 | 111 | # compute embedding 112 | pred_embed = model.encode(sentences1, batch_size=batch_size, convert_to_tensor=True) 113 | N = pred_embed.shape[0] 114 | scores = [ 115 | calc_cos_score(model, 
pred_embed[i], refs[i]) for i in range(N) 116 | ] 117 | max_scores = [s['max_score'] for s in scores] 118 | mean_score = sum(max_scores)/len(max_scores) 119 | return { 120 | 'metric': 'sentence textual similarity', 121 | 'mean_score': mean_score, 122 | 'scores': scores, 123 | } 124 | 125 | METRIC_TO_SCORER = { 126 | 'bleu': calc_bleu, 127 | 'sbert': calc_sbert, 128 | } 129 | 130 | def calc(args): 131 | tokenizer = BartTokenizer.from_pretrained(args.tokenizer) 132 | 133 | path = args.path 134 | metric = args.metric 135 | batch_size = args.batch_size 136 | 137 | predictions = get_predictions( 138 | path, 139 | ebno_db=args.ebno_db, 140 | prediction_json_path=args.prediction_json_path, 141 | test_data_path=args.testset_path, 142 | batch_size=batch_size, 143 | tokenizer=tokenizer, 144 | num_beams=args.num_beams) 145 | scorer = METRIC_TO_SCORER[metric] 146 | results = scorer( 147 | predictions=predictions, 148 | tokenizer=tokenizer, 149 | batch_size=batch_size, 150 | multi_ref=args.multi_ref, 151 | ) 152 | # dump result 153 | with open(args.result_json_path, 'w') as f: 154 | json.dump(results, f, indent=4) 155 | 156 | 157 | def main(): 158 | parser = argparse.ArgumentParser() 159 | parser.add_argument(dest='path', metavar='checkpoint_path', type=pathlib.Path) 160 | parser.add_argument('-m', '--metric', choices = list(METRIC_TO_SCORER.keys()), dest='metric') 161 | parser.add_argument('-b', '--batch-size', default=4, type=int, dest='batch_size') 162 | parser.add_argument('-e', '--ebno-db', required=True, type=float, dest='ebno_db') 163 | parser.add_argument('--testset-path', 164 | required=True, type=pathlib.Path, dest='testset_path') 165 | parser.add_argument('--prediction-json-path', 166 | required=True, 167 | type=pathlib.Path, 168 | dest='prediction_json_path', 169 | help='Required. Output path of prediction result cache json file. 
\ 170 | If the file exists, the prediction result will be reused') 171 | parser.add_argument('--result-json-path', 172 | default=pathlib.Path('./result.json'), 173 | type= pathlib.Path, 174 | dest='result_json_path') 175 | parser.add_argument('--tokenizer', 176 | default='facebook/bart-base', 177 | dest='tokenizer') 178 | parser.add_argument('--num-beams', 179 | default=1, 180 | type=int, 181 | dest='num_beams') 182 | parser.add_argument('--multi-ref', 183 | action='store_true', 184 | dest='multi_ref') 185 | args = parser.parse_args() 186 | calc(args) 187 | 188 | if __name__ == '__main__': 189 | main() 190 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.seq2seq_sc import TFSeq2SeqSCMainLayer, TFSeq2SeqSCForConditionalGeneration, TFSeq2SeqSCModel -------------------------------------------------------------------------------- /models/seq2seq_sc.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | from transformers import TFBartPretrainedModel, TFBartForConditionalGeneration 3 | from transformers.models.bart.modeling_tf_bart import TFBartMainLayer, BartConfig, shift_tokens_right, TFBartEncoder 4 | from transformers.modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqModelOutput 5 | from transformers.modeling_tf_utils import unpack_inputs, TFModelInputType, DUMMY_INPUTS 6 | import tensorflow as tf 7 | 8 | import sionna 9 | sionna.Config.xla_compat=True 10 | 11 | from sionna.channel import AWGN, FlatFadingChannel 12 | from sionna.fec.polar import Polar5GEncoder, Polar5GDecoder 13 | from sionna.mapping import Mapper, Demapper, Constellation 14 | from sionna.mimo import mf_equalizer 15 | from sionna.utils import ebnodb2no, expand_to_rank 16 | from .utils import tensor_to_binary, binary_to_tensor 17 | import numpy as np 18 | from transformers.utils import logging 19 | 20 | class TFSeq2SeqSCEncoderChannel(tf.keras.layers.Layer): 21 | 22 | def __init__(self, 23 | encoder: TFBartEncoder, 24 | ebno_db, 25 | polar_k=512, 26 | polar_n=1024, 27 | polar_decoder_type='SC', 28 | polar_decoder_list_size=8, 29 | num_bits_per_symbol=4, 30 | channel_type = 'AWGN', 31 | channel_num_tx_ant=1, 32 | channel_num_rx_ant=1): 33 | # NOTE: setting layer name as follows seems strange, 34 | # but it allows HuggingFace to load pretrained weight properly 35 | super().__init__(name='model/model/') 36 | logger = logging.get_logger("transformers") 37 | self.config = encoder.config 38 | self.bart_encoder = encoder 39 | 40 | # channel encoder/decoder, channel noise 41 | assert ebno_db is not None 42 | self.ebno_db = float(ebno_db) 43 | self.k = polar_k 44 | self.n = polar_n 45 | logger.info(f'{self.ebno_db=}') 46 | logger.info(f'{self.k=}') 47 | logger.info(f'{self.n=}') 48 | self.channel_encoder = Polar5GEncoder(k=self.k, n=self.n) 49 | 50 | constellation = Constellation("qam", 51 | num_bits_per_symbol, 52 | trainable=False) 53 | logger.info(f'Constellation: type={constellation._constellation_type} {num_bits_per_symbol=} trainable={constellation._trainable}') 54 | self.num_bits_per_symbol = num_bits_per_symbol 55 | self.mapper = Mapper(constellation=constellation) 56 | if channel_type == 'AWGN': 57 | self.channel = AWGN() 58 | self.channel_num_tx_ant = 1 59 | self.channel_num_rx_ant = 1 60 | elif channel_type == 'FlatFadingChannel': 61 | self.channel = FlatFadingChannel(channel_num_tx_ant, 
channel_num_rx_ant, add_awgn=True, return_channel=True) 62 | self.channel_num_tx_ant = channel_num_tx_ant 63 | self.channel_num_rx_ant = channel_num_rx_ant 64 | else: 65 | raise ValueError(f"Invalid channel type: {channel_type}") 66 | logger.info(f'{channel_type=}') 67 | self.demapper = Demapper("app", constellation=constellation) 68 | self.channel_decoder = Polar5GDecoder( 69 | self.channel_encoder, 70 | dec_type=polar_decoder_type, 71 | list_size=polar_decoder_list_size) 72 | self.coderate = self.k / self.n 73 | logger.info(f'{self.coderate=}') 74 | 75 | @unpack_inputs 76 | def call( 77 | self, 78 | input_ids: Optional[TFModelInputType] = None, 79 | inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, 80 | attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, 81 | head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, 82 | output_attentions: Optional[bool] = None, 83 | output_hidden_states: Optional[bool] = None, 84 | return_dict: Optional[bool] = None, 85 | training: Optional[bool] = False, 86 | ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]: 87 | encoder_outputs = self.bart_encoder( 88 | input_ids=input_ids, 89 | inputs_embeds=inputs_embeds, 90 | attention_mask=attention_mask, 91 | head_mask=head_mask, 92 | output_attentions=output_attentions, 93 | output_hidden_states=output_hidden_states, 94 | return_dict=return_dict, 95 | training=training, 96 | ) 97 | # add channel noise 98 | encoder_outputs.last_hidden_state = \ 99 | self._add_channel_noise(encoder_outputs.last_hidden_state) 100 | 101 | # denoise tensor 102 | encoder_outputs.last_hidden_state = \ 103 | tf.math.tanh(encoder_outputs.last_hidden_state) 104 | tf.debugging.assert_all_finite( 105 | encoder_outputs.last_hidden_state, 106 | 'should not have nan/inf/-inf after tanh') 107 | return encoder_outputs 108 | 109 | @tf.function 110 | def _add_channel_noise(self, input): 111 | encoder_output_shape = tf.shape(input) 112 | 113 | # Channel encoder 114 | encoder_output_binary = tensor_to_binary(input) 115 | encoder_output_binary = tf.reshape(encoder_output_binary, (-1, self.k)) 116 | codewords = self.channel_encoder(encoder_output_binary) 117 | 118 | # Modulation 119 | x = self.mapper(codewords) 120 | 121 | ##################### 122 | # Channel 123 | ##################### 124 | no = ebnodb2no(self.ebno_db, self.num_bits_per_symbol, self.coderate) 125 | no = expand_to_rank(no, 2) 126 | if isinstance(self.channel, FlatFadingChannel): 127 | shape = tf.shape(x) 128 | x = tf.reshape(x, (-1, self.channel_num_tx_ant)) 129 | y, h = self.channel([x, no]) 130 | s = tf.complex(no*tf.eye(self.channel_num_rx_ant, self.channel_num_rx_ant), 0.0) 131 | 132 | x_hat, no_eff = mf_equalizer(y, h, s) 133 | 134 | x_hat = tf.reshape(x_hat, shape) 135 | no_eff = tf.reshape(no_eff, shape) 136 | 137 | y = x_hat 138 | no = no_eff 139 | else: 140 | y = self.channel([x, no]) 141 | 142 | ##################### 143 | # Receiver 144 | ##################### 145 | # Demodulation 146 | llr = self.demapper([y, no]) 147 | llr = tf.reshape(llr, (-1, self.n)) 148 | 149 | # Channel decoder 150 | received_codewords = self.channel_decoder(llr) 151 | 152 | received_encoder_output = binary_to_tensor(received_codewords) 153 | received_encoder_output = tf.reshape(received_encoder_output, 154 | encoder_output_shape) 155 | return received_encoder_output 156 | 157 | class TFSeq2SeqSCMainLayer(tf.keras.layers.Layer): 158 | 159 | def __init__(self, 160 | config: BartConfig, 161 | bart_main_layer: TFBartMainLayer, 162 | ebno_db, 163 | polar_k=512, 164 | 
polar_n=1024, 165 | polar_decoder_type='SC', 166 | polar_decoder_list_size=8, 167 | num_bits_per_symbol=4, 168 | channel_type = 'AWGN', 169 | channel_num_tx_ant=1, 170 | channel_num_rx_ant=1, 171 | **kwargs): 172 | super().__init__(**kwargs) 173 | 174 | self.config = config 175 | self.shared = bart_main_layer.get_input_embeddings() 176 | 177 | # semantic encoders 178 | self.encoder = TFSeq2SeqSCEncoderChannel( 179 | encoder=bart_main_layer.encoder, 180 | ebno_db=ebno_db, 181 | polar_k=polar_k, 182 | polar_n=polar_n, 183 | polar_decoder_type=polar_decoder_type, 184 | polar_decoder_list_size=polar_decoder_list_size, 185 | num_bits_per_symbol=num_bits_per_symbol, 186 | channel_type=channel_type, 187 | channel_num_tx_ant=channel_num_tx_ant, 188 | channel_num_rx_ant=channel_num_rx_ant) 189 | self.decoder = bart_main_layer.decoder 190 | 191 | def get_input_embeddings(self): 192 | return self.shared 193 | 194 | def set_input_embeddings(self, new_embeddings): 195 | self.shared = new_embeddings 196 | self.encoder.encoder.embed_tokens = self.shared 197 | self.decoder.embed_tokens = self.shared 198 | 199 | 200 | 201 | @unpack_inputs 202 | def call(self, 203 | input_ids: Optional[TFModelInputType] = None, 204 | attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, 205 | decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, 206 | decoder_attention_mask: Optional[Union[np.ndarray, 207 | tf.Tensor]] = None, 208 | decoder_position_ids: Optional[Union[np.ndarray, 209 | tf.Tensor]] = None, 210 | head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, 211 | decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, 212 | cross_attn_head_mask: Optional[Union[np.ndarray, 213 | tf.Tensor]] = None, 214 | encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, 215 | past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, 216 | tf.Tensor]]]] = None, 217 | inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, 218 | decoder_inputs_embeds: Optional[Union[np.ndarray, 219 | tf.Tensor]] = None, 220 | use_cache: Optional[bool] = None, 221 | output_attentions: Optional[bool] = None, 222 | output_hidden_states: Optional[bool] = None, 223 | return_dict: Optional[bool] = None, 224 | training: Optional[bool] = False, 225 | **kwargs) -> Union[TFSeq2SeqModelOutput, Tuple[tf.Tensor]]: 226 | # different to other models, Bart automatically creates decoder_input_ids from 227 | # input_ids if no decoder_input_ids are provided 228 | if decoder_input_ids is None and decoder_inputs_embeds is None: 229 | if input_ids is None: 230 | raise ValueError( 231 | "If no `decoder_input_ids` or `decoder_inputs_embeds` are " 232 | "passed, `input_ids` cannot be `None`. Please pass either " 233 | "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." 
234 | ) 235 | 236 | decoder_input_ids = shift_tokens_right( 237 | input_ids, self.config.pad_token_id, 238 | self.config.decoder_start_token_id) 239 | 240 | if encoder_outputs is None: 241 | encoder_outputs = self.encoder( 242 | input_ids=input_ids, 243 | attention_mask=attention_mask, 244 | head_mask=head_mask, 245 | inputs_embeds=inputs_embeds, 246 | output_attentions=output_attentions, 247 | output_hidden_states=output_hidden_states, 248 | return_dict=return_dict, 249 | training=training, 250 | ) 251 | 252 | # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True 253 | elif return_dict and not isinstance(encoder_outputs, 254 | TFBaseModelOutput): 255 | encoder_outputs = TFBaseModelOutput( 256 | last_hidden_state=encoder_outputs[0], 257 | hidden_states=encoder_outputs[1] 258 | if len(encoder_outputs) > 1 else None, 259 | attentions=encoder_outputs[2] 260 | if len(encoder_outputs) > 2 else None, 261 | ) 262 | 263 | # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False 264 | elif not return_dict and not isinstance(encoder_outputs, tuple): 265 | encoder_outputs = encoder_outputs.to_tuple() 266 | 267 | decoder_outputs = self.decoder( 268 | decoder_input_ids, 269 | attention_mask=decoder_attention_mask, 270 | position_ids=decoder_position_ids, 271 | encoder_hidden_states=encoder_outputs[0], 272 | encoder_attention_mask=attention_mask, 273 | head_mask=decoder_head_mask, 274 | cross_attn_head_mask=cross_attn_head_mask, 275 | past_key_values=past_key_values, 276 | inputs_embeds=decoder_inputs_embeds, 277 | use_cache=use_cache, 278 | output_attentions=output_attentions, 279 | output_hidden_states=output_hidden_states, 280 | return_dict=return_dict, 281 | training=training, 282 | ) 283 | 284 | if not return_dict: 285 | return decoder_outputs + encoder_outputs 286 | 287 | return TFSeq2SeqModelOutput( 288 | last_hidden_state=decoder_outputs.last_hidden_state, 289 | past_key_values=decoder_outputs.past_key_values, 290 | decoder_hidden_states=decoder_outputs.hidden_states, 291 | decoder_attentions=decoder_outputs.attentions, 292 | cross_attentions=decoder_outputs.cross_attentions, 293 | encoder_last_hidden_state=encoder_outputs.last_hidden_state, 294 | encoder_hidden_states=encoder_outputs.hidden_states, 295 | encoder_attentions=encoder_outputs.attentions, 296 | ) 297 | 298 | 299 | class TFSeq2SeqSCModel(TFBartPretrainedModel): 300 | 301 | def __init__(self, 302 | config: BartConfig, 303 | load_weight_prefix=None, 304 | ebno_db=None, 305 | polar_k=512, 306 | polar_n=1024, 307 | polar_decoder_type='SC', 308 | polar_decoder_list_size=8, 309 | num_bits_per_symbol=4, 310 | channel_type = 'AWGN', 311 | channel_num_tx_ant = 1, 312 | channel_num_rx_ant = 1, 313 | *inputs, 314 | **kwargs): 315 | super().__init__(config, *inputs, **kwargs) 316 | self.bart_layer = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model") 317 | # self.bart_layer(DUMMY_INPUTS) 318 | self.model = TFSeq2SeqSCMainLayer( 319 | config, 320 | ebno_db=ebno_db, 321 | bart_main_layer=self.bart_layer, 322 | polar_k=polar_k, 323 | polar_n=polar_n, 324 | polar_decoder_type=polar_decoder_type, 325 | polar_decoder_list_size=polar_decoder_list_size, 326 | num_bits_per_symbol=num_bits_per_symbol, 327 | channel_type=channel_type, 328 | channel_num_tx_ant=channel_num_tx_ant, 329 | channel_num_rx_ant=channel_num_rx_ant 330 | ) 331 | 332 | @unpack_inputs 333 | def call(self, 334 | input_ids: Optional[TFModelInputType] = None, 335 
| attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, 336 | decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, 337 | decoder_attention_mask: Optional[Union[np.ndarray, 338 | tf.Tensor]] = None, 339 | decoder_position_ids: Optional[Union[np.ndarray, 340 | tf.Tensor]] = None, 341 | head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, 342 | decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, 343 | cross_attn_head_mask: Optional[Union[np.ndarray, 344 | tf.Tensor]] = None, 345 | encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, 346 | past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, 347 | tf.Tensor]]]] = None, 348 | inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, 349 | decoder_inputs_embeds: Optional[Union[np.ndarray, 350 | tf.Tensor]] = None, 351 | use_cache: Optional[bool] = None, 352 | output_attentions: Optional[bool] = None, 353 | output_hidden_states: Optional[bool] = None, 354 | return_dict: Optional[bool] = None, 355 | training: Optional[bool] = False, 356 | **kwargs) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]: 357 | 358 | outputs = self.model( 359 | input_ids=input_ids, 360 | attention_mask=attention_mask, 361 | decoder_input_ids=decoder_input_ids, 362 | decoder_attention_mask=decoder_attention_mask, 363 | decoder_position_ids=decoder_position_ids, 364 | head_mask=head_mask, 365 | decoder_head_mask=decoder_head_mask, 366 | cross_attn_head_mask=cross_attn_head_mask, 367 | encoder_outputs=encoder_outputs, 368 | past_key_values=past_key_values, 369 | inputs_embeds=inputs_embeds, 370 | decoder_inputs_embeds=decoder_inputs_embeds, 371 | use_cache=use_cache, 372 | output_attentions=output_attentions, 373 | output_hidden_states=output_hidden_states, 374 | return_dict=return_dict, 375 | training=training, 376 | ) 377 | 378 | return outputs 379 | 380 | 381 | class TFSeq2SeqSCForConditionalGeneration(TFBartForConditionalGeneration): 382 | def __init__(self, 383 | config, 384 | ebno_db=None, 385 | polar_k=512, 386 | polar_n=1024, 387 | polar_decoder_type='SC', 388 | polar_decoder_list_size=8, 389 | num_bits_per_symbol=4, 390 | channel_type = 'AWGN', 391 | channel_num_tx_ant = 1, 392 | channel_num_rx_ant = 1, 393 | *inputs, 394 | **kwargs): 395 | super().__init__(config, *inputs, **kwargs) 396 | self.model = TFSeq2SeqSCMainLayer( 397 | config, 398 | bart_main_layer=self.model, 399 | ebno_db=ebno_db, 400 | polar_k=polar_k, 401 | polar_n=polar_n, 402 | polar_decoder_type=polar_decoder_type, 403 | polar_decoder_list_size=polar_decoder_list_size, 404 | num_bits_per_symbol=num_bits_per_symbol, 405 | channel_type=channel_type, 406 | channel_num_tx_ant=channel_num_tx_ant, 407 | channel_num_rx_ant=channel_num_rx_ant, 408 | name="model") 409 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | @tf.function 5 | def tensor_to_binary(x): 6 | LOG_BASE2 = tf.math.log(tf.constant([2.0], tf.float32)) 7 | TO_MANTISSA = tf.constant([1<<23], tf.float32) 8 | 9 | # sign 10 | sign = tf.cast(x < 0.0, tf.float32) 11 | x = tf.math.abs(x) 12 | 13 | # exponent 14 | log_x = tf.math.floor(tf.math.log(x) / LOG_BASE2) 15 | exponent = tf.cast(log_x + 127.0, tf.uint8) 16 | 17 | # mantissa 18 | mantissa = x / tf.math.exp(log_x*LOG_BASE2) - tf.math.sign(x) 19 | mantissa = tf.math.floor(mantissa * TO_MANTISSA) 20 | mantissa = tf.cast(mantissa, tf.int32) 21 | 22 | # 
convert to bits 23 | bits = [None for i in range(32)] 24 | for i in range(23): 25 | bits[i] = tf.bitwise.bitwise_and(mantissa, 1) 26 | mantissa = tf.bitwise.right_shift(mantissa, 1) 27 | for i in range(23, 31): 28 | bits[i] = tf.bitwise.bitwise_and(exponent, 1) 29 | exponent = tf.bitwise.right_shift(exponent, 1) 30 | bits[31] = sign 31 | 32 | for i in range(32): 33 | bits[i] = tf.cast(bits[i], tf.float32) 34 | res = tf.stack(bits, axis=-1) 35 | return res 36 | 37 | @tf.function 38 | def binary_to_tensor(x: tf.Tensor): 39 | LOG_BASE2 = tf.math.log(tf.constant([2.0], tf.float32)) 40 | EXPONENTS = tf.constant([ float(1 << i) for i in range(8)], tf.float32) 41 | FROM_MANTISSA = tf.constant([ 0.5**(23-i) for i in range(23)], tf.float32) 42 | 43 | x = tf.reshape(x, (-1, 32)) 44 | sign = -x[:, 31] * 2 + 1 45 | 46 | exponent = tf.math.reduce_sum(x[:, 23:31] * EXPONENTS, axis=-1) 47 | mantissa = tf.math.reduce_sum(x[:, :23] * FROM_MANTISSA, axis=-1) 48 | mantissa += tf.cast(exponent > 0.0, tf.float32) 49 | return sign * tf.math.exp((exponent - 127.0) * LOG_BASE2) * mantissa 50 | 51 | 52 | 53 | @tf.function 54 | def tensor_to_binary(x: tf.Tensor): 55 | x = tf.bitcast(x, tf.uint32) 56 | mask = tf.ones_like(x) 57 | bit0 = tf.cast(tf.reshape(tf.bitwise.bitwise_and(x, mask), (1, -1)), 58 | tf.float32) 59 | bits = [bit0] 60 | 61 | for _ in range(31): 62 | x = tf.bitwise.right_shift(x, 1) 63 | bitn = tf.cast(tf.reshape(tf.bitwise.bitwise_and(x, mask), (1, -1)), 64 | tf.float32) 65 | bits.append(bitn) 66 | 67 | return tf.concat(bits, axis=0) 68 | 69 | @tf.function 70 | def replace_nan(input, new_value = 0.0): 71 | new_value = float(new_value) 72 | indices = tf.where(tf.math.is_nan(input)) 73 | res = tf.tensor_scatter_nd_update( 74 | input, 75 | indices, 76 | tf.fill((tf.shape(indices)[0], ), new_value) 77 | ) 78 | return res 79 | 80 | INF = tf.constant(np.array([np.inf]), dtype=tf.float32) 81 | 82 | @tf.function 83 | def replace_nan_to_inf(x): 84 | EXPONENT_MASK = tf.constant([0x7F800000], dtype=tf.uint32) 85 | INF_MASK = tf.constant([0xFF800000], dtype=tf.uint32) 86 | IDENTITY_MASK = tf.constant([0xFFFFFFFF], dtype=tf.uint32) 87 | x = tf.bitcast(x, tf.uint32) 88 | mask = tf.where( 89 | tf.equal(tf.bitwise.bitwise_and(x, EXPONENT_MASK), EXPONENT_MASK), 90 | INF_MASK, IDENTITY_MASK 91 | ) 92 | return tf.bitcast(tf.bitwise.bitwise_and(x, mask), tf.float32) 93 | -------------------------------------------------------------------------------- /models/utils_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import tensorflow as tf 3 | import numpy as np 4 | from .utils import * 5 | 6 | class TestBinaryConversion(unittest.TestCase): 7 | 8 | def test_conversion(self): 9 | x = tf.random.uniform((1, 64, 128)) 10 | tmp = tensor_to_binary(x) 11 | y = binary_to_tensor(tmp) 12 | y = tf.reshape(y, (1, 64, 128)) 13 | 14 | x = tf.bitcast(x, tf.uint32).numpy() 15 | y = tf.bitcast(y, tf.uint32).numpy() 16 | 17 | # bit-level exact match 18 | np.testing.assert_array_equal(x, y) 19 | 20 | def test_conversion_v2(self): 21 | shape = (1, 64, 128) 22 | x = tf.random.uniform(shape) 23 | tmp = tensor_to_binary_v2(x) 24 | y = binary_to_tensor_v2(tmp) 25 | y = tf.reshape(y, shape) 26 | 27 | x = tf.bitcast(x, tf.uint32).numpy() 28 | y = tf.bitcast(y, tf.uint32).numpy() 29 | 30 | # bit-level exact match 31 | # np.testing.assert_almost_equal(x, y) 32 | np.testing.assert_array_equal(x, y) 33 | 34 | def test_tensor_to_binary(self): 35 | # IEEE754 36 | # f32: 6027202.5 37 | # 
hex: 0x4ab7ef85 38 | # bin: 01001010101101111110111110000101 39 | x = tf.constant([6027202.5]) 40 | bin = tensor_to_binary(x) 41 | actual = bin.numpy().astype(np.int32).flatten() 42 | expected = np.array([ 43 | 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 44 | 1, 1, 1, 0, 0, 0, 0, 1, 0, 1 45 | ], dtype=np.int32)[::-1] 46 | self.assertTrue((actual == expected).all()) 47 | 48 | def test_tensor_to_binary_v2(self): 49 | # IEEE754 50 | # f32: 6027202.5 51 | # hex: 0x4ab7ef85 52 | # bin: 01001010101101111110111110000101 53 | x = tf.constant([6027202.5]) 54 | bin = tensor_to_binary_v2(x) 55 | actual = bin.numpy().astype(np.int32).flatten() 56 | expected = np.array([ 57 | 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 58 | 1, 1, 1, 0, 0, 0, 0, 1, 0, 1 59 | ], dtype=np.int32)[::-1] 60 | np.testing.assert_equal(actual, expected) 61 | 62 | 63 | def test_binary_to_tensor(self): 64 | # IEEE754 65 | # f32: 6027202.5 66 | # hex: 0x4ab7ef85 67 | # bin: 01001010101101111110111110000101 68 | x = np.array([ 69 | 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 70 | 1, 1, 1, 0, 0, 0, 0, 1, 0, 1 71 | ], 72 | dtype=np.float32)[::-1] 73 | x = tf.constant(x) 74 | bin = binary_to_tensor(x) 75 | actual = bin.numpy()[0] 76 | expected = 6027202.5 77 | 78 | self.assertAlmostEqual(actual, expected) 79 | 80 | def test_binary_to_tensor_v2(self): 81 | # IEEE754 82 | # f32: 6027202.5 83 | # hex: 0x4ab7ef85 84 | # bin: 01001010101101111110111110000101 85 | x = np.array([ 86 | 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 87 | 1, 1, 1, 0, 0, 0, 0, 1, 0, 1 88 | ], 89 | dtype=np.float32)[::-1] 90 | x = tf.constant(x) 91 | bin = binary_to_tensor_v2(x) 92 | actual = bin.numpy()[0] 93 | expected = 6027202.5 94 | self.assertAlmostEqual(actual, expected) 95 | 96 | def test_binary_to_tensor_with_tanh_v2(self): 97 | # f32: inf 98 | # hex: 0x7f800000 99 | # bin: 01111111100000000000000000000000 100 | x = np.array([0] + [1 for i in range(8)] + [0 for i in range(23)], 101 | dtype=np.float32)[::-1] 102 | x = tf.constant(x) 103 | y = tf.math.tanh(binary_to_tensor_v2(x)) 104 | actual = y.numpy()[0] 105 | expected = 1.0 106 | self.assertAlmostEqual(actual, expected) 107 | # f32: -inf 108 | # hex: 0xff800000 109 | # bin: 11111111100000000000000000000000 110 | x = np.array([1] + [1 for i in range(8)] + [0 for i in range(23)], 111 | dtype=np.float32)[::-1] 112 | x = tf.constant(x) 113 | y = tf.math.tanh(binary_to_tensor_v2(x)) 114 | actual = y.numpy()[0] 115 | expected = -1.0 116 | self.assertAlmostEqual(actual, expected) 117 | # f32: nan 118 | # hex: 0x7fc00000 119 | # bin: 01111111110000000000000000000000 120 | x = np.array([0] + [1 for i in range(8)] + [1] + [0 for i in range(22)], 121 | dtype=np.float32)[::-1] 122 | x = tf.constant(x) 123 | y = tf.math.tanh(binary_to_tensor_v2(x)) 124 | actual = y.numpy()[0] 125 | expected = 1.0 126 | self.assertAlmostEqual(actual, expected) 127 | # f32: -nan 128 | # hex: 0xffc00000 129 | # bin: 11111111110000000000000000000000 130 | x = np.array([1] + [1 for i in range(8)] + [1] + [0 for i in range(22)], 131 | dtype=np.float32)[::-1] 132 | x = tf.constant(x) 133 | y = tf.math.tanh(binary_to_tensor_v2(x)) 134 | actual = y.numpy()[0] 135 | expected = -1.0 136 | self.assertAlmostEqual(actual, expected) 137 | 138 | 139 | 140 | def test_replace_nan(self): 141 | x = np.array([ 142 | [[1, 2, np.nan], 143 | [4, np.nan, 6],], 144 | [[7, 8, np.nan], 145 | [10, np.nan, 12],], 146 | ], dtype=np.float32) 147 | expected = np.array([ 148 | [[1, 
2, 0.0], 149 | [4, 0, 6],], 150 | [[7, 8, 0], 151 | [10, 0, 12],], 152 | ], dtype=np.float32) 153 | 154 | actual = replace_nan(tf.constant(x)) 155 | actual = actual.numpy() 156 | 157 | np.testing.assert_almost_equal(actual, expected) 158 | 159 | def test_replace_nan_to_inf(self): 160 | x = np.array([ 161 | [[1, 2, np.nan], 162 | [4, -np.nan, np.inf],], 163 | [[7, 8, -np.nan], 164 | [10, np.nan, -np.inf],], 165 | ], dtype=np.float32) 166 | expected = np.array([ 167 | [[1, 2, np.inf], 168 | [4, -np.inf, np.inf],], 169 | [[7, 8, -np.inf], 170 | [10, np.inf, -np.inf],], 171 | ], dtype=np.float32) 172 | 173 | actual = replace_nan_to_inf(tf.constant(x)) 174 | actual = actual.numpy() 175 | 176 | np.testing.assert_almost_equal(actual, expected) 177 | 178 | 179 | if __name__ == '__main__': 180 | unittest.main() 181 | -------------------------------------------------------------------------------- /poster_seq2seq_Asilomar2023.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abman23/seq2seq-sc/e900f637ed5e89300bb2fb9e98f0e215cc508ed0/poster_seq2seq_Asilomar2023.jpg -------------------------------------------------------------------------------- /preprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abman23/seq2seq-sc/e900f637ed5e89300bb2fb9e98f0e215cc508ed0/preprocess/__init__.py -------------------------------------------------------------------------------- /preprocess/allnli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import pathlib 4 | import random 5 | 6 | from .hf_data_gen import HFDataGenerator 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser("Preprocess Eurlparl data. It generates the same dataset used in DeepSC") 10 | parser.add_argument( 11 | '-o', '--out-path', 12 | dest='out_path', 13 | required=True, 14 | type=pathlib.Path, 15 | help='Required. 
Path of output directory') 16 | parser.add_argument( 17 | '--train-dev-split', 18 | dest='train_dev_split', 19 | default=0.9, 20 | type=float, 21 | help='Trainset/ Devset split ratio') 22 | parser.add_argument( 23 | '--seed', 24 | dest='seed', 25 | default=1234, 26 | type=int, 27 | help='Random seed') 28 | parser.add_argument( 29 | dest='path', 30 | type=pathlib.Path, 31 | help="Path of AllNLI.jsonl") 32 | args = parser.parse_args() 33 | 34 | random.seed(args.seed) 35 | 36 | with open(args.path) as f: 37 | data = [json.loads(line) for line in f] 38 | 39 | data = map(lambda l: (l[0], l[1]), data) 40 | sentences = filter(lambda l: 'n/a' not in l, data) 41 | sentences = list(sentences) 42 | 43 | N = len(sentences) 44 | devset_size = int(N*(1-args.train_dev_split)) 45 | devset_indices = random.sample(range(N), devset_size) 46 | devset_indices = set(devset_indices) 47 | 48 | trainset_gen = HFDataGenerator() 49 | devset_gen = HFDataGenerator() 50 | for i, (s1, s2) in enumerate(sentences): 51 | if i in devset_indices: 52 | devset_gen.add(s1, s2) 53 | else: 54 | trainset_gen.add(s1, s2) 55 | 56 | trainset_gen.dump(args.out_path / 'allnli_train.csv') 57 | devset_gen.dump(args.out_path / 'allnli_dev.csv') -------------------------------------------------------------------------------- /preprocess/europarl.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import json 3 | import argparse 4 | from .hf_data_gen import HFDataGenerator 5 | 6 | # The following code is copied (and slightly modified) from DeepSC 7 | # (https://github.com/zyy598/DeepSC/blob/master/preprocess_text.py) 8 | import unicodedata 9 | import re 10 | from w3lib.html import remove_tags 11 | import pickle 12 | import os 13 | import json 14 | from tqdm import tqdm 15 | 16 | def unicode_to_ascii(s): 17 | return ''.join(c for c in unicodedata.normalize('NFD', s) 18 | if unicodedata.category(c) != 'Mn') 19 | 20 | def normalize_string(s): 21 | # normalize unicode characters 22 | s = unicode_to_ascii(s) 23 | # remove the XML-tags 24 | s = remove_tags(s) 25 | # add white space before !.? 26 | s = re.sub(r'([!.?])', r' \1', s) 27 | s = re.sub(r'[^a-zA-Z.!?]+', r' ', s) 28 | s = re.sub(r'\s+', r' ', s) 29 | # change to lower letter 30 | s = s.lower() 31 | return s 32 | 33 | def cutted_data(cleaned, MIN_LENGTH=4, MAX_LENGTH=30): 34 | cutted_lines = list() 35 | for line in cleaned: 36 | length = len(line.split()) 37 | if length > MIN_LENGTH and length < MAX_LENGTH: 38 | line = [word for word in line.split()] 39 | cutted_lines.append(' '.join(line)) 40 | return cutted_lines 41 | 42 | def save_clean_sentences(sentence, save_path): 43 | pickle.dump(sentence, open(save_path, 'wb')) 44 | print('Saved: %s' % save_path) 45 | 46 | def process_text_file(file_path): 47 | with open(file_path, 'r') as f: 48 | raw_data = f.read() 49 | sentences = raw_data.strip().split('\n') 50 | raw_data_input = [normalize_string(data) for data in sentences] 51 | raw_data_input = cutted_data(raw_data_input) 52 | return raw_data_input 53 | 54 | def tokenize(s, delim=' ', add_start_token=True, add_end_token=True, 55 | punct_to_keep=None, punct_to_remove=None): 56 | """ 57 | Tokenize a sequence, converting a string s into a list of (string) tokens by 58 | splitting on the specified delimiter. Optionally keep or remove certain 59 | punctuation marks and add start and end tokens. 
60 | """ 61 | if punct_to_keep is not None: 62 | for p in punct_to_keep: 63 | s = s.replace(p, '%s%s' % (delim, p)) 64 | 65 | if punct_to_remove is not None: 66 | for p in punct_to_remove: 67 | s = s.replace(p, '') 68 | 69 | tokens = s.split(delim) 70 | if add_start_token: 71 | tokens.insert(0, '') 72 | if add_end_token: 73 | tokens.append('') 74 | return tokens 75 | 76 | def build_vocab(sequences, token_to_idx = { }, min_token_count=1, delim=' ', 77 | punct_to_keep=None, punct_to_remove=None, ): 78 | token_to_count = {} 79 | 80 | for seq in sequences: 81 | seq_tokens = tokenize(seq, delim=delim, punct_to_keep=punct_to_keep, 82 | punct_to_remove=punct_to_remove, 83 | add_start_token=False, add_end_token=False) 84 | for token in seq_tokens: 85 | if token not in token_to_count: 86 | token_to_count[token] = 0 87 | token_to_count[token] += 1 88 | 89 | for token, count in sorted(token_to_count.items()): 90 | if count >= min_token_count: 91 | token_to_idx[token] = len(token_to_idx) 92 | 93 | return token_to_idx 94 | 95 | def encode(seq_tokens, token_to_idx, allow_unk=False): 96 | seq_idx = [] 97 | for token in seq_tokens: 98 | if token not in token_to_idx: 99 | if allow_unk: 100 | token = '' 101 | else: 102 | raise KeyError('Token "%s" not in vocab' % token) 103 | seq_idx.append(token_to_idx[token]) 104 | return seq_idx 105 | 106 | def decode(seq_idx, idx_to_token, delim=None, stop_at_end=True): 107 | tokens = [] 108 | for idx in seq_idx: 109 | tokens.append(idx_to_token[idx]) 110 | if stop_at_end and tokens[-1] == '': 111 | break 112 | if delim is None: 113 | return tokens 114 | else: 115 | return delim.join(tokens) 116 | 117 | SPECIAL_TOKENS = { 118 | '': 0, 119 | '': 1, 120 | '': 2, 121 | '': 3, 122 | } 123 | 124 | def process_europarl(input_data_dir, train_test_split=0.9, njobs=1): 125 | sentences = [] 126 | print('Preprocess Raw Text') 127 | from joblib import Parallel, delayed 128 | sentences = Parallel(n_jobs=njobs, verbose=1)( 129 | delayed(process_text_file)(fn) 130 | for fn in pathlib.Path(input_data_dir).glob('*.txt')) 131 | sentences = [s for s_list in sentences for s in s_list ] 132 | 133 | # remove the same sentences 134 | a = {} 135 | for set in sentences: 136 | if set not in a: 137 | a[set] = 0 138 | a[set] += 1 139 | sentences = list(a.keys()) 140 | print('Number of sentences: {}'.format(len(sentences))) 141 | 142 | print('Build Vocab') 143 | token_to_idx = build_vocab( 144 | sentences, SPECIAL_TOKENS, 145 | punct_to_keep=[';', ','], punct_to_remove=['?', '.'] 146 | ) 147 | 148 | vocab = {'token_to_idx': token_to_idx} 149 | print('Number of words in Vocab: {}'.format(len(token_to_idx))) 150 | 151 | print('Start encoding txt') 152 | results = [] 153 | for seq in tqdm(sentences): 154 | words = tokenize(seq, punct_to_keep=[';', ','], punct_to_remove=['?', '.']) 155 | tokens = [token_to_idx[word] for word in words] 156 | results.append(tokens) 157 | 158 | train_data = results[: round(len(results) * train_test_split)] 159 | test_data = results[round(len(results) * train_test_split):] 160 | 161 | return train_data, test_data, vocab 162 | # End of the copied code 163 | 164 | class Tokenizer: 165 | 166 | TOKENS_FILTERED = set([ 167 | '', '' 168 | ]) 169 | 170 | def __init__(self, vocab): 171 | idx_to_token = [None for _ in range(1 + max(vocab['token_to_idx'].values()))] 172 | for token, idx in vocab['token_to_idx'].items(): 173 | idx_to_token[idx] = token 174 | self.idx_to_token = idx_to_token 175 | self.token_to_idx = vocab['token_to_idx'] 176 | 177 | def decode(self, token_ids): 178 | 
tokens = map(lambda i: self.idx_to_token[i], token_ids) 179 | tokens = filter(lambda t: t not in self.TOKENS_FILTERED, tokens) 180 | return ' '.join(tokens) 181 | 182 | def batch_decode(self, token_ids_list): 183 | return list(map(lambda token_ids: self.decode(token_ids), token_ids_list)) 184 | 185 | def gen_hf_dataset(path: pathlib.Path, output_path=None, train_test_split=0.9, njobs=1): 186 | path = pathlib.Path(path) 187 | if output_path is None: 188 | output_path = path 189 | 190 | train_data, test_data, vocab = process_europarl(path, train_test_split, njobs) 191 | 192 | # save processed sentences 193 | with open(output_path / 'train_data.pkl', 'wb') as f: 194 | pickle.dump(train_data, f) 195 | with open(output_path / 'test_data.pkl', 'wb') as f: 196 | pickle.dump(test_data, f) 197 | with open(output_path / 'vocab.json', 'w') as f: 198 | json.dump(vocab, f) 199 | 200 | tokenizer = Tokenizer(vocab) 201 | 202 | # train set 203 | train_data = tokenizer.batch_decode(train_data) 204 | train_gen = HFDataGenerator() 205 | train_gen.add(train_data, train_data) 206 | train_gen.dump(output_path / 'train.csv') 207 | 208 | # test set 209 | test_data = tokenizer.batch_decode(test_data) 210 | test_gen = HFDataGenerator() 211 | test_gen.add(test_data, test_data) 212 | test_gen.dump(output_path / 'test.csv') 213 | 214 | 215 | if __name__ == '__main__': 216 | parser = argparse.ArgumentParser("Preprocess Eurlparl data. It generates the same dataset used in DeepSC") 217 | parser.add_argument( 218 | '-o', '--out-path', 219 | dest='out_path', 220 | required=True, 221 | type=pathlib.Path, 222 | help='Required. Path of output files.') 223 | parser.add_argument( 224 | '--train-test-split', 225 | dest='train_test_split', 226 | default=0.9, 227 | type=float, 228 | help='Trainset/ Testset split ratio') 229 | parser.add_argument( 230 | '-j' 231 | '--njobs', 232 | dest='njobs', 233 | default=1, 234 | type=int, 235 | help='Number of threads to be used for preprocessing') 236 | parser.add_argument( 237 | dest='path', 238 | type=pathlib.Path, 239 | help="Path of europarl dataset. 
It should be '/txt/en'") 240 | args = parser.parse_args() 241 | gen_hf_dataset(args.path, args.out_path, args.train_test_split, args.njobs) 242 | 243 | -------------------------------------------------------------------------------- /preprocess/flickr30k.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import json 3 | import random 4 | import argparse 5 | import pathlib 6 | 7 | def parse_line(line: str): 8 | key = line.split('#')[0] 9 | caption = line.split('\t')[-1].strip() 10 | return key, caption 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('-o', '--out-path', 15 | dest='out_path', 16 | type=pathlib.Path, 17 | help='output json file path') 18 | parser.add_argument('-n', '--num-samples', 19 | dest='N', 20 | default=1000, 21 | type=int, 22 | help='number of samples (default: 1000)') 23 | parser.add_argument('--send', 24 | dest='seed', 25 | default=20221017, 26 | type=int, 27 | help='seed for random module') 28 | 29 | parser.add_argument( 30 | dest='token_path', 31 | help="path of 'results_20130124.token'") 32 | args = parser.parse_args() 33 | 34 | # read token file 35 | data = defaultdict(list) 36 | with open(args.token_path) as f: 37 | for k, caption in map(parse_line, f): 38 | data[k].append(caption) 39 | data = list(data.values()) 40 | 41 | # set seed 42 | random.seed(args.seed) 43 | 44 | # sample dataset 45 | samples = random.sample(range(len(data)), k=args.N) 46 | out_data = [] 47 | for i in samples: 48 | captions = data[i] 49 | input_idx = random.sample(range(len(captions)), k=1)[0] 50 | input_sentence = captions[input_idx] 51 | ref_sentences = captions[:input_idx] + captions[(input_idx+1):] 52 | out_data.append({ 53 | 'input': input_sentence, 54 | 'refs': ref_sentences, 55 | }) 56 | 57 | with open(args.out_path, 'w') as f: 58 | json.dump(out_data, f, indent=4) -------------------------------------------------------------------------------- /preprocess/hf_data_gen.py: -------------------------------------------------------------------------------- 1 | # exampel file format 2 | # -------------------- 3 | # text,summary 4 | # "I'm sitting here in a boring room. It's just another rainy Sunday afternoon. I'm wasting my time I got nothing to do. I'm hanging around I'm waiting for you. But nothing ever happens. And I wonder","I'm sitting in a room where I'm waiting for something to happen" 5 | # "I see trees so green, red roses too. I see them bloom for me and you. And I think to myself what a wonderful world. I see skies so blue and clouds so white. The bright blessed day, the dark sacred night. And I think to myself what a wonderful world.","I'm a gardener and I'm a big fan of flowers." 6 | # "Christmas time is here. Happiness and cheer. Fun for all that children call. Their favorite time of the year. Snowflakes in the air. Carols everywhere. Olden times and ancient rhymes. Of love and dreams to share","It's that time of year again." 
7 | import csv 8 | 9 | class HFDataGenerator: 10 | 11 | FIELDNAMES = ['text', 'summary'] 12 | 13 | def __init__(self): 14 | self.rows = [] 15 | 16 | def add(self, text, summary): 17 | if isinstance(text, str) and isinstance(summary, str): 18 | self.rows.append({ 19 | 'text': f'{text}', 20 | 'summary': f'{summary}', 21 | }) 22 | elif isinstance(text, list) and isinstance(summary, list): 23 | assert len(text) == len(summary), "Different text and summary list" 24 | rows = map(lambda ts: {'text': ts[0], 'summary': ts[1]}, zip(text, summary)) 25 | self.rows.extend(rows) 26 | else: 27 | assert False, "No such case" 28 | 29 | 30 | def dump(self, filename): 31 | with open(filename, 'w', newline='') as f: 32 | writer = csv.DictWriter(f, fieldnames=self.FIELDNAMES, quoting=csv.QUOTE_ALL) 33 | writer.writeheader() 34 | writer.writerows(self.rows) 35 | 36 | def test(): 37 | data_path = "/tmp/hf_data_test.csv" 38 | gen = HFDataGenerator() 39 | gen.add("text1", "summary1") 40 | gen.add("text2", "summary2") 41 | gen.dump(data_path) 42 | 43 | expected = """"text","summary" 44 | "text1","summary1" 45 | "text2","summary2" 46 | """ 47 | with open( data_path) as f: 48 | actual = f.read() 49 | assert actual == expected, f"""'{actual}'""" 50 | 51 | if __name__ == '__main__': 52 | test() 53 | 54 | -------------------------------------------------------------------------------- /scripts/eval_flickr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ebno_db="10" 3 | metric="sbert" # bleu, sbert 4 | testset_path='data/flickr/processed/flickr30k.json' 5 | checkpoint_path="checkpoints/seq2seq-allnli-sc" 6 | 7 | python eval.py \ 8 | --batch 4 \ 9 | --metric "${metric}" \ 10 | --ebno-db "${ebno_db}" \ 11 | --result-json-path "${checkpoint_path}/flikr_${metric}_ebno_${ebno_db}.json" \ 12 | --prediction-json-path "${checkpoint_path}/flikr_prediction_ebno_${ebno_db}.json" \ 13 | --testset-path "${testset_path}" \ 14 | $checkpoint_path -------------------------------------------------------------------------------- /scripts/preprocess_allnli.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | data_path=data/allnli 3 | mkdir -p $data_path 4 | wget -P $data_path https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/AllNLI.jsonl.gz 5 | gunzip $data_path/AllNLI.jsonl.gz 6 | 7 | allnli_dataset="$data_path/AllNLI.jsonl" 8 | out_dir="$data_path/processed" 9 | 10 | mkdir -p $out_dir 11 | python -m preprocess.allnli -o $out_dir $allnli_dataset -------------------------------------------------------------------------------- /scripts/preprocess_europarl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | europarl_dataset=data/europarl/txt/en 3 | out_dir=data/europarl/processed 4 | njobs=4 5 | 6 | mkdir -p $out_dir 7 | python -m preprocess.europarl -j $njobs -o $out_dir $europarl_dataset 8 | -------------------------------------------------------------------------------- /scripts/preprocess_flickr30k.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | data_path="data/flickr" 3 | dataset_path="${data_path}/flickr30k.tar.gz" 4 | out_dir="$data_path/processed" 5 | 6 | mkdir -p $out_dir 7 | 8 | tar xzf ${dataset_path} -C $data_path 9 | python -m preprocess.flickr30k \ 10 | -o "$out_dir/flickr30k.json" \ 11 | "${data_path}/results_20130124.token" 
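The AllNLI and Europarl preprocessors above both emit the two-column `text,summary` CSV that `train.py` reads through `--train_file`/`--validation_file`, while the Flickr30K script instead writes the reference JSON consumed by `eval.py`. To train on a different parallel corpus, the same CSV can be produced directly with `HFDataGenerator`; the following is a minimal sketch in which the output path and the sentence pairs are purely illustrative.

```python
# Minimal sketch: write a custom text,summary CSV in the format train.py expects.
# The output directory and the example sentence pairs are placeholders.
import pathlib
from preprocess.hf_data_gen import HFDataGenerator

out_dir = pathlib.Path("data/custom/processed")
out_dir.mkdir(parents=True, exist_ok=True)

gen = HFDataGenerator()
gen.add("a man is playing a guitar on stage", "a man plays guitar")  # one pair at a time
gen.add(["two dogs run across a field"], ["dogs running outside"])   # or two equal-length lists
gen.dump(out_dir / "train.csv")  # fully quoted CSV with a text,summary header row
```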
-------------------------------------------------------------------------------- /scripts/train_allnli.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | output_dir='checkpoints/seq2seq-allnli-sc' 3 | trainset_path='data/allnli/processed/allnli_train.csv' 4 | devset_path='data/allnli/processed/allnli_dev.csv' 5 | 6 | mkdir -p $output_dir 7 | 8 | python train.py \ 9 | --per_device_train_batch_size 4 \ 10 | --num_train_epochs 3 \ 11 | --do_train \ 12 | --do_eval \ 13 | --model_name_or_path facebook/bart-base \ 14 | --preprocessing_num_workers 4 \ 15 | --save_total_limit 1 \ 16 | --no_use_fast_tokenizer \ 17 | --num_beams 4 \ 18 | --max_source_length 64 \ 19 | --max_target_length 64 \ 20 | --train_file "$trainset_path" \ 21 | --validation_file "$devset_path" \ 22 | --test_file "$devset_path" \ 23 | --output_dir $output_dir \ 24 | --ebno_db 10 \ 25 | --channel_type AWGN \ 26 | --overwrite_output_dir \ 27 | --tokenizer_name facebook/bart-base \ 28 | --pad_to_max_length \ 29 | --dataset_config 3.0.0 -------------------------------------------------------------------------------- /scripts/train_europarl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | output_dir='checkpoints/seq2seq-europarl-sc' 3 | trainset_path='data/europarl/processed/train.csv' 4 | devset_path='data/europarl/processed/test.csv' 5 | 6 | mkdir -p $output_dir 7 | 8 | python train.py \ 9 | --per_device_train_batch_size 4 \ 10 | --num_train_epochs 3 \ 11 | --do_train \ 12 | --do_eval \ 13 | --model_name_or_path facebook/bart-base \ 14 | --preprocessing_num_workers 4 \ 15 | --save_total_limit 1 \ 16 | --no_use_fast_tokenizer \ 17 | --num_beams 4 \ 18 | --max_source_length 64 \ 19 | --max_target_length 64 \ 20 | --train_file "$trainset_path" \ 21 | --validation_file "$devset_path" \ 22 | --test_file "$devset_path" \ 23 | --output_dir $output_dir \ 24 | --ebno_db 10 \ 25 | --channel_type AWGN \ 26 | --overwrite_output_dir \ 27 | --tokenizer_name facebook/bart-base \ 28 | --pad_to_max_length \ 29 | --dataset_config 3.0.0 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
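# (seq2seq-SC adaptation: the model is loaded as TFSeq2SeqSCForConditionalGeneration,
# which takes its channel / polar-code settings from Seq2SeqSCArguments,
# e.g. --ebno_db and --channel_type as used in scripts/train_allnli.sh.)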
16 | 17 | # This file is based on 18 | # https://github.com/huggingface/transformers/blob/main/examples/tensorflow/summarization/run_summarization.py 19 | import json 20 | import logging 21 | import os 22 | import sys 23 | from dataclasses import dataclass, field 24 | from typing import Optional, List 25 | 26 | import datasets 27 | import nltk 28 | import numpy as np 29 | import tensorflow as tf 30 | from datasets import load_dataset 31 | 32 | import evaluate 33 | import transformers 34 | from filelock import FileLock 35 | from transformers import ( 36 | AutoConfig, 37 | AutoTokenizer, 38 | DataCollatorForSeq2Seq, 39 | HfArgumentParser, 40 | KerasMetricCallback, 41 | TFTrainingArguments, 42 | set_seed, 43 | ) 44 | from transformers.trainer_utils import get_last_checkpoint 45 | from transformers.utils import is_offline_mode 46 | from transformers.optimization_tf import create_optimizer 47 | from train.args import ModelArguments, DataTrainingArguments, summarization_name_mapping, Seq2SeqSCArguments 48 | from models import TFSeq2SeqSCForConditionalGeneration 49 | 50 | logger = logging.getLogger(__name__) 51 | 52 | try: 53 | nltk.data.find("tokenizers/punkt") 54 | except (LookupError, OSError): 55 | if is_offline_mode(): 56 | raise LookupError( 57 | "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" 58 | ) 59 | with FileLock(".lock") as lock: 60 | nltk.download("punkt", quiet=True) 61 | # endregion 62 | 63 | def main(): 64 | # region Argument parsing 65 | # See all possible arguments in src/transformers/training_args.py 66 | # or by passing the --help flag to this script. 67 | # We now keep distinct sets of args, for a cleaner separation of concerns. 68 | 69 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments, Seq2SeqSCArguments)) 70 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 71 | # If we pass only one argument to the script and it's the path to a json file, 72 | # let's parse it to get our arguments. 73 | model_args, data_args, training_args, seq2seq_sc_args = parser.parse_json_file( 74 | json_file=os.path.abspath(sys.argv[1])) 75 | else: 76 | model_args, data_args, training_args, seq2seq_sc_args = parser.parse_args_into_dataclasses() 77 | 78 | if training_args.fp16: 79 | policy = tf.keras.mixed_precision.Policy('mixed_float16') 80 | tf.keras.mixed_precision.set_global_policy(policy) 81 | 82 | # region Logging 83 | logging.basicConfig( 84 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 85 | datefmt="%m/%d/%Y %H:%M:%S", 86 | handlers=[logging.StreamHandler(sys.stdout)], 87 | ) 88 | logger.setLevel(logging.INFO) 89 | datasets.utils.logging.set_verbosity(logging.INFO) 90 | transformers.utils.logging.set_verbosity(logging.INFO) 91 | 92 | # Log on each process the small summary: 93 | logger.info(f"Training/evaluation parameters {training_args}") 94 | # endregion 95 | 96 | # region Detecting last checkpoint 97 | last_checkpoint = None 98 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 99 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 100 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 101 | raise ValueError( 102 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 103 | "Use --overwrite_output_dir to overcome." 
104 | ) 105 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 106 | logger.info( 107 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 108 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 109 | ) 110 | # endregion 111 | 112 | # Set seed before initializing model. 113 | set_seed(training_args.seed) 114 | 115 | # region Load datasets 116 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 117 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 118 | # (the dataset will be downloaded automatically from the datasets Hub). 119 | # 120 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the 121 | # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments). 122 | # 123 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 124 | # download the dataset. 125 | if data_args.dataset_name is not None: 126 | # Downloading and loading a dataset from the hub. 127 | raw_datasets = load_dataset( 128 | data_args.dataset_name, 129 | data_args.dataset_config_name, 130 | cache_dir=model_args.cache_dir, 131 | use_auth_token=True if model_args.use_auth_token else None, 132 | ) 133 | else: 134 | data_files = {} 135 | if data_args.train_file is not None: 136 | data_files["train"] = data_args.train_file 137 | extension = data_args.train_file.split(".")[-1] 138 | if data_args.validation_file is not None: 139 | data_files["validation"] = data_args.validation_file 140 | extension = data_args.validation_file.split(".")[-1] 141 | if data_args.test_file is not None: 142 | data_files["test"] = data_args.test_file 143 | extension = data_args.test_file.split(".")[-1] 144 | raw_datasets = load_dataset( 145 | extension, 146 | data_files=data_files, 147 | cache_dir=model_args.cache_dir, 148 | use_auth_token=True if model_args.use_auth_token else None, 149 | ) 150 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 151 | # https://huggingface.co/docs/datasets/loading_datasets.html. 152 | # endregion 153 | 154 | # region Load model config and tokenizer 155 | # 156 | # Distributed training: 157 | # The .from_pretrained methods guarantee that only one local process can concurrently 158 | # download model & vocab. 159 | 160 | config = AutoConfig.from_pretrained( 161 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 162 | cache_dir=model_args.cache_dir, 163 | revision=model_args.model_revision, 164 | use_auth_token=True if model_args.use_auth_token else None, 165 | ) 166 | tokenizer = AutoTokenizer.from_pretrained( 167 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 168 | cache_dir=model_args.cache_dir, 169 | use_fast=model_args.use_fast_tokenizer, 170 | revision=model_args.model_revision, 171 | use_auth_token=True if model_args.use_auth_token else None, 172 | ) 173 | 174 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 175 | # endregion 176 | 177 | # region Dataset preprocessing 178 | # We need to tokenize inputs and targets. 
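    # With the CSVs written by preprocess/hf_data_gen.py the columns are ["text", "summary"],
    # so text_column / summary_column below fall back to the first and second columns.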
179 | if training_args.do_train: 180 | column_names = raw_datasets["train"].column_names 181 | elif training_args.do_eval: 182 | column_names = raw_datasets["validation"].column_names 183 | else: 184 | logger.info("There is nothing to do. Please pass `do_train`, and/or `do_eval`.") 185 | return 186 | 187 | # Get the column names for input/target. 188 | dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None) 189 | if data_args.text_column is None: 190 | text_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 191 | else: 192 | text_column = data_args.text_column 193 | if text_column not in column_names: 194 | raise ValueError( 195 | f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}" 196 | ) 197 | if data_args.summary_column is None: 198 | summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 199 | else: 200 | summary_column = data_args.summary_column 201 | if summary_column not in column_names: 202 | raise ValueError( 203 | f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}" 204 | ) 205 | 206 | # Temporarily set max_target_length for training. 207 | max_target_length = data_args.max_target_length 208 | padding = "max_length" if data_args.pad_to_max_length else False 209 | 210 | def preprocess_function(examples): 211 | inputs = examples[text_column] 212 | assert prefix is not None 213 | for i, inp in enumerate(inputs): 214 | if inp is None: 215 | print(i, inputs[i], inputs[i-1], inputs[i+1]) 216 | targets = examples[summary_column] 217 | inputs = [prefix + inp for inp in inputs] 218 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) 219 | 220 | # Tokenize targets with the `text_target` keyword argument 221 | labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True) 222 | 223 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore 224 | # padding in the loss. 
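        # (-100 is the label index ignored by the loss, so padded label positions
        # do not contribute to training.)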
225 | if padding == "max_length" and data_args.ignore_pad_token_for_loss: 226 | labels["input_ids"] = [ 227 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 228 | ] 229 | 230 | model_inputs["labels"] = labels["input_ids"] 231 | return model_inputs 232 | 233 | if training_args.do_train: 234 | if "train" not in raw_datasets: 235 | raise ValueError("--do_train requires a train dataset") 236 | train_dataset = raw_datasets["train"] 237 | if data_args.max_train_samples is not None: 238 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 239 | train_dataset = train_dataset.select(range(max_train_samples)) 240 | with training_args.main_process_first(desc="train dataset map pre-processing"): 241 | train_dataset = train_dataset.map( 242 | preprocess_function, 243 | batched=True, 244 | num_proc=data_args.preprocessing_num_workers, 245 | remove_columns=column_names, 246 | load_from_cache_file=not data_args.overwrite_cache, 247 | desc="Running tokenizer on train dataset", 248 | ) 249 | else: 250 | train_dataset = None 251 | 252 | if training_args.do_eval: 253 | max_target_length = data_args.val_max_target_length 254 | if "validation" not in raw_datasets: 255 | raise ValueError("--do_eval requires a validation dataset") 256 | eval_dataset = raw_datasets["validation"] 257 | if data_args.max_eval_samples is not None: 258 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 259 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 260 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 261 | eval_dataset = eval_dataset.map( 262 | preprocess_function, 263 | batched=True, 264 | num_proc=data_args.preprocessing_num_workers, 265 | remove_columns=column_names, 266 | load_from_cache_file=not data_args.overwrite_cache, 267 | desc="Running tokenizer on validation dataset", 268 | ) 269 | else: 270 | eval_dataset = None 271 | # endregion 272 | 273 | # region Text preprocessing 274 | def postprocess_text(preds, labels): 275 | preds = [pred.strip() for pred in preds] 276 | labels = [label.strip() for label in labels] 277 | 278 | # rougeLSum expects newline after each sentence 279 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] 280 | labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] 281 | 282 | return preds, labels 283 | 284 | # endregion 285 | with training_args.strategy.scope(): 286 | # region Prepare model 287 | model_cls = TFSeq2SeqSCForConditionalGeneration 288 | 289 | model = model_cls.from_pretrained( 290 | model_args.model_name_or_path, 291 | ebno_db=seq2seq_sc_args.ebno_db, 292 | polar_k=seq2seq_sc_args.k, 293 | polar_n=seq2seq_sc_args.n, 294 | polar_decoder_type=seq2seq_sc_args.polar_decoder_type, 295 | polar_decoder_list_size=seq2seq_sc_args.polar_decoder_list_size, 296 | num_bits_per_symbol=seq2seq_sc_args.num_bits_per_symbol, 297 | channel_type=seq2seq_sc_args.channel_type, 298 | channel_num_tx_ant=seq2seq_sc_args.channel_num_tx_ant, 299 | channel_num_rx_ant=seq2seq_sc_args.channel_num_rx_ant, 300 | config=config, 301 | cache_dir=model_args.cache_dir, 302 | revision=model_args.model_revision, 303 | use_auth_token=True if model_args.use_auth_token else None, 304 | ) 305 | 306 | model.resize_token_embeddings(len(tokenizer)) 307 | # endregion 308 | 309 | # region Prepare TF Dataset objects 310 | if model.config.decoder_start_token_id is None: 311 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") 312 | 
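        # Because `model` is passed to the collator below, it can also prepare
        # decoder_input_ids from the (shifted) labels when batching.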
313 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 314 | data_collator = DataCollatorForSeq2Seq( 315 | tokenizer, 316 | model=model, 317 | label_pad_token_id=label_pad_token_id, 318 | pad_to_multiple_of=128, # Reduce the number of unique shapes for XLA, especially for generation 319 | return_tensors="tf", 320 | ) 321 | 322 | dataset_options = tf.data.Options() 323 | dataset_options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF 324 | 325 | num_replicas = training_args.strategy.num_replicas_in_sync 326 | total_train_batch_size = training_args.per_device_train_batch_size * num_replicas 327 | total_eval_batch_size = training_args.per_device_eval_batch_size * num_replicas 328 | 329 | # model.prepare_tf_dataset() wraps a Hugging Face dataset in a tf.data.Dataset which is ready to use in 330 | # training. This is the recommended way to use a Hugging Face dataset when training with Keras. You can also 331 | # use the lower-level dataset.to_tf_dataset() method, but you will have to specify things like column names 332 | # yourself if you use this method, whereas they are automatically inferred from the model input names when 333 | # using model.prepare_tf_dataset() 334 | # For more info see the docs: 335 | # https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.TFPreTrainedModel.prepare_tf_dataset 336 | # https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.to_tf_dataset 337 | 338 | tf_train_dataset = model.prepare_tf_dataset( 339 | train_dataset, 340 | collate_fn=data_collator, 341 | batch_size=total_train_batch_size, 342 | shuffle=True, 343 | ).with_options(dataset_options) 344 | tf_eval_dataset = model.prepare_tf_dataset( 345 | eval_dataset, 346 | collate_fn=data_collator, 347 | batch_size=total_eval_batch_size, 348 | shuffle=False, 349 | ).with_options(dataset_options) 350 | # endregion 351 | 352 | # region Optimizer, loss and LR scheduling 353 | num_train_steps = int(len(tf_train_dataset) * training_args.num_train_epochs) 354 | if training_args.warmup_steps > 0: 355 | num_warmup_steps = training_args.warmup_steps 356 | elif training_args.warmup_ratio > 0: 357 | num_warmup_steps = int(num_train_steps * training_args.warmup_ratio) 358 | else: 359 | num_warmup_steps = 0 360 | if training_args.do_train: 361 | optimizer, lr_schedule = create_optimizer( 362 | init_lr=training_args.learning_rate, 363 | num_train_steps=num_train_steps, 364 | num_warmup_steps=num_warmup_steps, 365 | adam_beta1=training_args.adam_beta1, 366 | adam_beta2=training_args.adam_beta2, 367 | adam_epsilon=training_args.adam_epsilon, 368 | weight_decay_rate=training_args.weight_decay, 369 | adam_global_clipnorm=training_args.max_grad_norm, 370 | ) 371 | else: 372 | optimizer = None 373 | 374 | # endregion 375 | 376 | # region Metric and KerasMetricCallback 377 | if training_args.do_eval: 378 | metric = evaluate.load("rouge") 379 | 380 | if data_args.val_max_target_length is None: 381 | data_args.val_max_target_length = data_args.max_target_length 382 | 383 | gen_kwargs = { 384 | "max_length": data_args.val_max_target_length if data_args is not None else config.max_length, 385 | "num_beams": data_args.num_beams, 386 | "no_repeat_ngram_size": 0, # Not supported under XLA right now, and some models set it by default 387 | } 388 | 389 | def compute_metrics(preds): 390 | predictions, labels = preds 391 | if isinstance(predictions, tuple): 392 | predictions = predictions[0] 393 | 
decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True) 394 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 395 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 396 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) 397 | metrics = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True) 398 | # Only print the mid f-measures, but there are a lot of other statistics in there too! 399 | metrics = {key: round(val * 100, 4) for key, val in metrics.items()} 400 | return metrics 401 | 402 | # The KerasMetricCallback allows metrics that are too complex to write as standard Keras metrics 403 | # to be computed each epoch. Any Python code can be included in the metric_fn. This is especially 404 | # useful for metrics like BLEU and ROUGE that perform string comparisons on decoded model outputs. 405 | # For more information, see the docs at 406 | # https://huggingface.co/docs/transformers/main_classes/keras_callbacks#transformers.KerasMetricCallback 407 | 408 | metric_callback = KerasMetricCallback( 409 | metric_fn=compute_metrics, 410 | eval_dataset=tf_eval_dataset, 411 | predict_with_generate=True, 412 | use_xla_generation=True, 413 | generate_kwargs=gen_kwargs, 414 | ) 415 | callbacks = [metric_callback] 416 | else: 417 | callbacks = [] 418 | # endregion 419 | 420 | # region Training 421 | model.compile(optimizer=optimizer, jit_compile=training_args.xla) 422 | eval_metrics = None 423 | if training_args.do_train: 424 | logger.info("***** Running training *****") 425 | logger.info(f" Num examples = {len(train_dataset)}") 426 | logger.info(f" Num Epochs = {training_args.num_train_epochs}") 427 | logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}") 428 | logger.info(f" Total train batch size = {total_train_batch_size}") 429 | logger.info(f" Total optimization steps = {num_train_steps}") 430 | 431 | if training_args.xla and not data_args.pad_to_max_length: 432 | logger.warning( 433 | "XLA training may be slow at first when --pad_to_max_length is not set " 434 | "until all possible shapes have been compiled." 
435 |                 )
436 |             history = model.fit(tf_train_dataset, epochs=int(training_args.num_train_epochs), callbacks=callbacks)
437 |             eval_metrics = {key: val[-1] for key, val in history.history.items()}
438 |         # endregion
439 | 
440 |         # region Validation
441 | 
442 |         if training_args.do_eval and not training_args.do_train:
443 |             # Do a standalone evaluation run
444 |             logger.info("Evaluation...")
445 | 
446 |             # Compiling generation with XLA yields enormous speedups, see https://huggingface.co/blog/tf-xla-generate
447 |             @tf.function(jit_compile=True)
448 |             def generate(**kwargs):
449 |                 return model.generate(**kwargs)
450 | 
451 |             for batch, labels in tf_eval_dataset:
452 |                 batch.update(gen_kwargs)
453 |                 generated_tokens = generate(**batch)
454 |                 if isinstance(generated_tokens, tuple):
455 |                     generated_tokens = generated_tokens[0]
456 |                 decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
457 |                 labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
458 |                 decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
459 |                 decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
460 | 
461 |                 metric.add_batch(predictions=decoded_preds, references=decoded_labels)
462 | 
463 |             eval_metrics = metric.compute(use_stemmer=True)
464 | 
465 |             result = {key: round(val * 100, 4) for key, val in eval_metrics.items()}
466 |             logger.info(result)
467 |         # endregion
468 | 
469 |     if training_args.output_dir is not None and eval_metrics is not None:
470 |         output_eval_file = os.path.join(training_args.output_dir, "all_results.json")
471 |         with open(output_eval_file, "w") as writer:
472 |             writer.write(json.dumps(eval_metrics))
473 | 
474 |     if training_args.output_dir is not None and not training_args.push_to_hub:
475 |         # If we're not pushing to hub, at least save a local copy when we're done
476 |         model.save_pretrained(training_args.output_dir)
477 | 
478 | 
479 | if __name__ == "__main__":
480 |     main()
481 | 
--------------------------------------------------------------------------------
/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abman23/seq2seq-sc/e900f637ed5e89300bb2fb9e98f0e215cc508ed0/train/__init__.py
--------------------------------------------------------------------------------
/train/args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 | 
4 | @dataclass
5 | class Seq2SeqSCArguments:
6 |     ebno_db: Optional[float] = field(default=None, metadata={"help": "Eb/N0 of the channel in dB"})
7 |     k: Optional[int] = field(default=512, metadata={"help": "K (number of information bits) for the polar code"})
8 |     n: Optional[int] = field(default=1024, metadata={"help": "N (codeword length) for the polar code"})
9 |     polar_decoder_type: Optional[str] = field(
10 |         default='SC',
11 |         metadata={"help": "Polar decoder type (e.g. 'SC')"})
12 |     polar_decoder_list_size: Optional[int] = field(default=8, metadata={"help": "Polar decoder list size"})
13 |     num_bits_per_symbol: Optional[int] = field(default=4, metadata={"help": "number of bits per symbol"})
14 |     channel_type: Optional[str] = field(default='AWGN', metadata={"help": "AWGN or FlatFadingChannel"})
15 |     channel_num_tx_ant: Optional[int] = field(default=1, metadata={"help": "number of tx antennas for FlatFadingChannel"})
16 |     channel_num_rx_ant: Optional[int] = field(default=1, metadata={"help": "number of rx antennas for FlatFadingChannel"})
17 | 
18 | 
19 | @dataclass
20 | class
ModelArguments: 21 | model_name_or_path: str = field( 22 | metadata={ 23 | "help": "Path to pretrained model or model identifier from huggingface.co/models" 24 | } 25 | ) 26 | config_name: Optional[str] = field( 27 | default=None, 28 | metadata={ 29 | "help": "Pretrained config name or path if not the same as model_name" 30 | } 31 | ) 32 | tokenizer_name: Optional[str] = field( 33 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 34 | ) 35 | cache_dir: Optional[str] = field( 36 | default=None, 37 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 38 | ) 39 | use_fast_tokenizer: bool = field( 40 | default=True, 41 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 42 | ) 43 | model_revision: str = field( 44 | default="main", 45 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 46 | ) 47 | use_auth_token: bool = field( 48 | default=False, 49 | metadata={ 50 | "help": ( 51 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 52 | "with private models)." 53 | ) 54 | }, 55 | ) 56 | 57 | 58 | @dataclass 59 | class DataTrainingArguments: 60 | """ 61 | Arguments pertaining to what data we are going to input our model for training and eval. 62 | """ 63 | 64 | dataset_name: Optional[str] = field( 65 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 66 | ) 67 | dataset_config_name: Optional[str] = field( 68 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 69 | ) 70 | text_column: Optional[str] = field( 71 | default=None, 72 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, 73 | ) 74 | summary_column: Optional[str] = field( 75 | default=None, 76 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, 77 | ) 78 | train_file: Optional[str] = field( 79 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 80 | ) 81 | validation_file: Optional[str] = field( 82 | default=None, 83 | metadata={ 84 | "help": ( 85 | "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 86 | ) 87 | }, 88 | ) 89 | test_file: Optional[str] = field( 90 | default=None, 91 | metadata={ 92 | "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 93 | }, 94 | ) 95 | overwrite_cache: bool = field( 96 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 97 | ) 98 | preprocessing_num_workers: Optional[int] = field( 99 | default=None, 100 | metadata={"help": "The number of processes to use for the preprocessing."}, 101 | ) 102 | max_source_length: Optional[int] = field( 103 | default=1024, 104 | metadata={ 105 | "help": ( 106 | "The maximum total input sequence length after tokenization. Sequences longer " 107 | "than this will be truncated, sequences shorter will be padded." 108 | ) 109 | }, 110 | ) 111 | max_target_length: Optional[int] = field( 112 | default=128, 113 | metadata={ 114 | "help": ( 115 | "The maximum total sequence length for target text after tokenization. Sequences longer " 116 | "than this will be truncated, sequences shorter will be padded." 
117 | ) 118 | }, 119 | ) 120 | val_max_target_length: Optional[int] = field( 121 | default=None, 122 | metadata={ 123 | "help": ( 124 | "The maximum total sequence length for validation target text after tokenization. Sequences longer " 125 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 126 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 127 | "during ``evaluate`` and ``predict``." 128 | ) 129 | }, 130 | ) 131 | pad_to_max_length: bool = field( 132 | default=False, 133 | metadata={ 134 | "help": ( 135 | "Whether to pad all samples to model maximum sentence length. " 136 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 137 | "efficient on GPU but very bad for TPU." 138 | ) 139 | }, 140 | ) 141 | max_train_samples: Optional[int] = field( 142 | default=None, 143 | metadata={ 144 | "help": ( 145 | "For debugging purposes or quicker training, truncate the number of training examples to this " 146 | "value if set." 147 | ) 148 | }, 149 | ) 150 | max_eval_samples: Optional[int] = field( 151 | default=None, 152 | metadata={ 153 | "help": ( 154 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 155 | "value if set." 156 | ) 157 | }, 158 | ) 159 | max_predict_samples: Optional[int] = field( 160 | default=None, 161 | metadata={ 162 | "help": ( 163 | "For debugging purposes or quicker training, truncate the number of prediction examples to this " 164 | "value if set." 165 | ) 166 | }, 167 | ) 168 | num_beams: Optional[int] = field( 169 | default=None, 170 | metadata={ 171 | "help": ( 172 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 173 | "which is used during ``evaluate`` and ``predict``." 174 | ) 175 | }, 176 | ) 177 | ignore_pad_token_for_loss: bool = field( 178 | default=True, 179 | metadata={ 180 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 181 | }, 182 | ) 183 | source_prefix: Optional[str] = field( 184 | default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 185 | ) 186 | 187 | def __post_init__(self): 188 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 189 | raise ValueError("Need either a dataset name or a training/validation file.") 190 | else: 191 | if self.train_file is not None: 192 | extension = self.train_file.split(".")[-1] 193 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 194 | if self.validation_file is not None: 195 | extension = self.validation_file.split(".")[-1] 196 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 
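        # Fall back to max_target_length when no separate validation target length is given.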
197 | if self.val_max_target_length is None: 198 | self.val_max_target_length = self.max_target_length 199 | 200 | 201 | # endregion 202 | 203 | # region Dataset name mappings 204 | summarization_name_mapping = { 205 | "amazon_reviews_multi": ("review_body", "review_title"), 206 | "big_patent": ("description", "abstract"), 207 | "cnn_dailymail": ("article", "highlights"), 208 | "orange_sum": ("text", "summary"), 209 | "pn_summary": ("article", "summary"), 210 | "psc": ("extract_text", "summary_text"), 211 | "samsum": ("dialogue", "summary"), 212 | "thaisum": ("body", "summary"), 213 | "xglue": ("news_body", "news_title"), 214 | "xsum": ("document", "summary"), 215 | "wiki_summary": ("article", "highlights"), 216 | "multi_news": ("document", "summary"), 217 | } 218 | # endregion 219 | 220 | 221 | --------------------------------------------------------------------------------