├── .gitignore
├── README.md
├── environment.yml
├── eval.py
├── models
│   ├── __init__.py
│   ├── seq2seq_sc.py
│   ├── utils.py
│   └── utils_test.py
├── poster_seq2seq_Asilomar2023.jpg
├── preprocess
│   ├── __init__.py
│   ├── allnli.py
│   ├── europarl.py
│   ├── flickr30k.py
│   └── hf_data_gen.py
├── scripts
│   ├── eval_flickr.sh
│   ├── preprocess_allnli.sh
│   ├── preprocess_europarl.sh
│   ├── preprocess_flickr30k.sh
│   ├── train_allnli.sh
│   └── train_europarl.sh
├── train.py
└── train
    ├── __init__.py
    └── args.py
/.gitignore:
--------------------------------------------------------------------------------
1 | checkpoints
2 | **/__pycache__
3 | data
4 | .vscode
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # seq2seq-SC
2 |
3 |
4 |
5 | ## Citation
6 |
7 | ```bibtex
8 | @misc{lee2022seq2seqSC,
9 | author = {Lee, Ju-Hyung and Lee, Dong-Ho and Sheen, Eunsoo and Choi, Thomas and Pujara, Jay and Kim, Joongheon},
10 | title = {Seq2Seq-SC: End-to-End Semantic Communication Systems with Pre-trained Language Model},
11 | journal={arXiv preprint arXiv:2210.15237},
12 | year = {2022},
13 | }
14 | ```
15 |
16 | ## Setup
17 |
18 | 1. Set up the conda environment and activate it:
19 |
20 | ```bash
21 | conda env create -f environment.yml && conda activate seq2seq-sc
22 | ```
23 |
24 | ## Data Preprocessing
25 |
26 | ### Europarl dataset
27 |
28 | ```bash
29 | data_path=data/europarl
30 | mkdir -p $data_path
31 | wget -P /tmp http://www.statmt.org/europarl/v7/europarl.tgz
32 | tar zxf /tmp/europarl.tgz -C $data_path
33 | # the archive unpacks its txt/<language> directories into $data_path
34 |
35 | europarl_dataset="$data_path/txt/en"
36 | out_dir="$data_path/processed"
37 | njobs=4
38 |
39 | mkdir -p $out_dir
40 | python -m preprocess.europarl -j $njobs -o $out_dir $europarl_dataset
41 | ```
42 |
43 | ### AllNLI
44 |
45 | Run `./scripts/preprocess_allnli.sh` or the following commands:
46 |
47 | ```bash
48 | data_path=data/allnli
49 | mkdir -p $data_path
50 | wget -P $data_path https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/AllNLI.jsonl.gz
51 | gunzip $data_path/AllNLI.jsonl.gz
52 |
53 | allnli_dataset="$data_path/AllNLI.jsonl"
54 | out_dir="$data_path/processed"
55 |
56 | mkdir -p $out_dir
57 | python -m preprocess.allnli -o $out_dir $allnli_dataset
58 | ```
59 |
60 | ### Flickr30K
61 |
62 | To download the dataset, go to [Flickr30K](http://hockenmaier.cs.illinois.edu/DenotationGraph/) and fill out the form to get the download link. Place the downloaded archive at `data/flickr/flickr30k.tar.gz`, then run:
63 |
64 | ```bash
65 | data_path="data/flickr"
66 | dataset_path="${data_path}/flickr30k.tar.gz"
67 | out_dir="$data_path/processed"
68 |
69 | mkdir -p $out_dir
70 |
71 | tar xzf ${dataset_path} -C $data_path
72 | python -m preprocess.flickr30k \
73 | -o "$out_dir/flickr30k.json" \
74 | "${data_path}/results_20130124.token"
75 | ```
76 |
77 | ## Train
78 |
79 | You can run `scripts/train_europarl.sh` or `scripts/train_allnli.sh`. Alternatively, train by running the following commands:
80 |
81 | ```bash
82 | output_dir='checkpoints/seq2seq-sc'
83 | trainset_path='data/allnli/processed/allnli_train.csv'
84 | devset_path='data/allnli/processed/allnli_dev.csv'
85 |
86 | mkdir -p $output_dir
87 |
88 | python train.py \
89 | --per_device_train_batch_size 4 \
90 | --num_train_epochs 3 \
91 | --do_train \
92 | --do_eval \
93 | --model_name_or_path facebook/bart-base \
94 | --preprocessing_num_workers 4 \
95 | --save_total_limit 1 \
96 | --no_use_fast_tokenizer \
97 | --num_beams 4 \
98 | --max_source_length 64 \
99 | --max_target_length 64 \
100 | --train_file "$trainset_path" \
101 | --validation_file "$devset_path" \
102 | --test_file "$devset_path" \
103 | --output_dir $output_dir \
104 | --ebno_db 10 \
105 | --channel_type AWGN \
106 | --overwrite_output_dir \
107 | --tokenizer_name facebook/bart-base \
108 | --pad_to_max_length \
109 | --dataset_config 3.0.0
110 | ```
111 |
112 | ## Evaluation
113 |
114 | You can use the script `scripts/eval_flickr.sh` or the following commands:
115 |
116 | ```bash
117 | # BLEU score
118 | ebno_db="10"
119 | metric="bleu" # bleu, sbert
120 | testset_path='data/flickr/processed/flickr30k.json'
121 | checkpoint_path="checkpoints/seq2seq-allnli-sc"
122 |
123 | python eval.py \
124 | --batch 4 \
125 | --metric "${metric}" \
126 | --ebno-db "${ebno_db}" \
127 | --result-json-path "${checkpoint_path}/flikr_${metric}_ebno_${ebno_db}.json" \
128 | --prediction-json-path "${checkpoint_path}/flikr_prediction_ebno_${ebno_db}.json" \
129 | --testset-path "${testset_path}" \
130 | $checkpoint_path
131 | ```
132 |
133 | ```bash
134 | # SBERT
135 | ebno_db="10"
136 | metric="sbert" # bleu, sbert
137 | testset_path='data/flickr/processed/flickr30k.json'
138 | checkpoint_path="checkpoints/seq2seq-allnli-sc"
139 |
140 | python eval.py \
141 | --batch 4 \
142 | --metric "${metric}" \
143 | --ebno-db "${ebno_db}" \
144 | --result-json-path "${checkpoint_path}/flikr_${metric}_ebno_${ebno_db}.json" \
145 | --prediction-json-path "${checkpoint_path}/flikr_prediction_ebno_${ebno_db}.json" \
146 | --testset-path "${testset_path}" \
147 | $checkpoint_path
148 | ```
149 |
150 |
151 |
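152 | ## Loading a Trained Model in Python
153 |
154 | The snippet below is a minimal sketch of programmatic use, mirroring what `eval.py` does: load a trained checkpoint with a chosen Eb/N0, push a sentence through the noisy channel, and decode the prediction. Run it from the repository root so the `models` package is importable; the checkpoint path and input sentence are placeholders.
155 |
156 | ```python
157 | from transformers import BartTokenizer
158 | from models import TFSeq2SeqSCForConditionalGeneration
159 |
160 | checkpoint_path = "checkpoints/seq2seq-allnli-sc"  # placeholder: any trained checkpoint
161 | tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
162 | model = TFSeq2SeqSCForConditionalGeneration.from_pretrained(checkpoint_path, ebno_db=10.0)
163 |
164 | # tokenize a placeholder sentence and generate its reconstruction after the channel
165 | input_ids = tokenizer(["a man is riding a bicycle down the street"],
166 |                       return_tensors="tf", padding="max_length",
167 |                       truncation=True, max_length=32).input_ids
168 | pred = model.generate(input_ids, max_new_tokens=32, num_beams=4)
169 | print(tokenizer.batch_decode(pred, skip_special_tokens=True))
170 | ```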
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: seq2seq-sc
2 | channels:
3 | - huggingface
4 | - anaconda
5 | - defaults
6 | - conda-forge
7 | dependencies:
8 | - pip==22.3
9 | - _libgcc_mutex=0.1=conda_forge
10 | - _openmp_mutex=4.5=2_gnu
11 | - blas=1.0=mkl
12 | - brotlipy=0.7.0=py39h27cfd23_1003
13 | - ca-certificates=2022.07.19=h06a4308_0
14 | - certifi=2022.9.14=py39h06a4308_0
15 | - cffi=1.15.1=py39h74dc2b5_0
16 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
17 | - click=8.0.4=py39h06a4308_0
18 | - cryptography=37.0.1=py39h9ce1e76_0
19 | - cudatoolkit=11.2.2=hbe64b41_10
20 | - cudnn=8.1.0.77=h90431f1_0
21 | - dataclasses=0.8=pyh6d0b6a4_7
22 | - filelock=3.6.0=pyhd3eb1b0_0
23 | - huggingface_hub=0.9.1=py_0
24 | - idna=3.3=pyhd3eb1b0_0
25 | - importlib-metadata=4.11.3=py39h06a4308_0
26 | - importlib_metadata=4.11.3=hd3eb1b0_0
27 | - intel-openmp=2021.4.0=h06a4308_3561
28 | - joblib=1.1.0=pyhd3eb1b0_0
29 | - ld_impl_linux-64=2.38=h1181459_1
30 | - libffi=3.3=he6710b0_2
31 | - libgcc-ng=12.1.0=h8d9b700_16
32 | - libgomp=12.1.0=h8d9b700_16
33 | - libprotobuf=3.20.1=h4ff587b_0
34 | - libstdcxx-ng=12.1.0=ha89aaad_16
35 | - mkl=2021.4.0=h06a4308_640
36 | - mkl-service=2.4.0=py39h7f8727e_0
37 | - mkl_fft=1.3.1=py39hd3c417c_0
38 | - mkl_random=1.2.2=py39h51133e4_0
39 | - ncurses=6.3=h5eee18b_3
40 | - numpy=1.23.1=py39h6c91a56_0
41 | - numpy-base=1.23.1=py39ha15fc14_0
42 | - openssl=1.1.1q=h7f8727e_0
43 | - packaging=21.3=pyhd3eb1b0_0
44 | - pycparser=2.21=pyhd3eb1b0_0
45 | - pyopenssl=22.0.0=pyhd3eb1b0_0
46 | - pyparsing=3.0.9=py39h06a4308_0
47 | - pysocks=1.7.1=py39h06a4308_0
48 | - python=3.9.13=haa1d7c7_1
49 | - pyyaml=6.0=py39h7f8727e_1
50 | - readline=8.1.2=h7f8727e_1
51 | - regex=2022.7.9=py39h5eee18b_0
52 | - requests=2.28.1=py39h06a4308_0
53 | - sacremoses=master=py_0
54 | - setuptools=63.4.1=py39h06a4308_0
55 | - six=1.16.0=pyhd3eb1b0_1
56 | - sqlite=3.39.2=h5082296_0
57 | - tk=8.6.12=h1ccaba5_0
58 | - tqdm=4.64.0=py39h06a4308_0
59 | - transformers=4.22.1=py_0
60 | - typing-extensions=4.3.0=py39h06a4308_0
61 | - typing_extensions=4.3.0=py39h06a4308_0
62 | - tzdata=2022c=h04d1e81_0
63 | - urllib3=1.26.11=py39h06a4308_0
64 | - wheel=0.37.1=pyhd3eb1b0_0
65 | - xz=5.2.5=h7f8727e_1
66 | - yaml=0.2.5=h7b6447c_0
67 | - zipp=3.8.0=py39h06a4308_0
68 | - zlib=1.2.12=h5eee18b_3
69 | - pip:
70 | - absl-py==1.2.0
71 | - accelerate==0.13
72 | - aiohttp==3.8.3
73 | - aiosignal==1.2.0
74 | - astroid==2.12.10
75 | - astunparse==1.6.3
76 | - async-timeout==4.0.2
77 | - attrs==22.1.0
78 | - bert-score==0.3.12
79 | - cachetools==5.2.0
80 | - contourpy==1.0.5
81 | - cycler==0.11.0
82 | - datasets==2.5.1
83 | - dill==0.3.5.1
84 | - evaluate==0.2.2
85 | - flatbuffers==2.0.7
86 | - fonttools==4.37.3
87 | - frozenlist==1.3.1
88 | - fsspec==2022.8.2
89 | - gast==0.4.0
90 | - google-auth==2.11.1
91 | - google-auth-oauthlib==0.4.6
92 | - google-pasta==0.2.0
93 | - grpcio==1.48.1
94 | - h5py==3.7.0
95 | - importlib-resources==5.9.0
96 | - isort==5.10.1
97 | - keras==2.10.0
98 | - keras-preprocessing==1.1.2
99 | - kiwisolver==1.4.4
100 | - lazy-object-proxy==1.7.1
101 | - libclang==14.0.6
102 | - markdown==3.4.1
103 | - markupsafe==2.1.1
104 | - matplotlib==3.6.0
105 | - mccabe==0.7.0
106 | - moverscore==1.0.3
107 | - multidict==6.0.2
108 | - multiprocess==0.70.13
109 | - nltk==3.7
110 | - oauthlib==3.2.1
111 | - opt-einsum==3.3.0
112 | - pandas==1.5.0
113 | - pillow==9.2.0
114 | - platformdirs==2.5.2
115 | - portalocker==2.6.0
116 | - protobuf==3.19.5
117 | - psutil==5.9.2
118 | - pyarrow==8.0.0
119 | - pyasn1==0.4.8
120 | - pyasn1-modules==0.2.8
121 | - pyemd==0.5.1
122 | - pylint==2.15.3
123 | - python-dateutil==2.8.2
124 | - pytz==2022.2.1
125 | - requests-oauthlib==1.3.1
126 | - responses==0.18.0
127 | - rouge-score==0.1.2
128 | - rsa==4.9
129 | - scikit-learn==1.1.2
130 | - scipy==1.9.1
131 | - sentence-transformers==2.2.2
132 | - sentencepiece==0.1.97
133 | - sionna==0.11.0
134 | - tensorboard==2.10.0
135 | - tensorboard-data-server==0.6.1
136 | - tensorboard-plugin-wit==1.8.1
137 | - tensorflow==2.10.0
138 | - tensorflow-estimator==2.10.0
139 | - tensorflow-io-gcs-filesystem==0.27.0
140 | - termcolor==2.0.1
141 | - threadpoolctl==3.1.0
142 | - tokenizers==0.12.1
143 | - tomli==2.0.1
144 | - tomlkit==0.11.4
145 | - torch==1.12.1
146 | - torchvision==0.13.1
147 | - typing==3.7.4.3
148 | - werkzeug==2.2.2
149 | - wrapt==1.14.1
150 | - xxhash==3.0.0
151 | - yapf==0.32.0
152 | - yarl==1.8.1
153 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import json
3 | import argparse
4 | import logging
5 | from transformers import BartTokenizer
6 | import evaluate
7 | from tqdm import tqdm
8 | import warnings
9 |
10 | def get_test_data(path):
11 | with open(path) as f:
12 | return json.load(f)
13 |
14 | def from_pretrained(path, ebno_db):
15 | from models import TFSeq2SeqSCForConditionalGeneration
16 | import transformers
17 | transformers.utils.logging.set_verbosity(logging.INFO)
18 | return TFSeq2SeqSCForConditionalGeneration.from_pretrained(
19 | path, ebno_db=ebno_db)
20 |
21 | def predict(path, ebno_db, tokenizer, batch_size, test_data_path, num_beams):
22 | import tensorflow as tf
23 | max_len = 32
24 |
25 | # load model
26 | model = from_pretrained(path, ebno_db)
27 |
28 | # load testset
29 | test_data = get_test_data(test_data_path)
30 | input_sentences = [d['input'] for d in test_data]
31 | input_ids = tokenizer(input_sentences, return_tensors="tf",
32 | padding='max_length', truncation=True, max_length=max_len).input_ids
33 | testset = tf.data.Dataset.from_tensor_slices(input_ids)
34 |
35 | # inference
36 | pred_sentences = []
37 | for input_ids in tqdm(testset.batch(batch_size).prefetch(tf.data.AUTOTUNE)):
38 | pred_batch = model.generate(input_ids, max_new_tokens=max_len, num_beams=num_beams)
39 | output_strs = tokenizer.batch_decode(pred_batch,
40 | skip_special_tokens=True,
41 | clean_up_tokenization_spaces=False)
42 | pred_sentences.extend(output_strs)
43 |
44 |
45 | res = {
46 | 'input': input_sentences,
47 | 'pred': pred_sentences,
48 | 'refs': [d['refs'] for d in test_data],
49 | }
50 | return res
51 |
52 | def get_predictions(path, ebno_db, test_data_path, prediction_json_path, batch_size, tokenizer, num_beams):
53 | path = pathlib.Path(path)
54 | if not prediction_json_path.exists():
55 | print('Missing predictions.json')
56 | res = predict(
57 | path=path,
58 | ebno_db=ebno_db,
59 | tokenizer=tokenizer,
60 | batch_size=batch_size,
61 | test_data_path=test_data_path,
62 | num_beams=num_beams,
63 | )
64 |
65 | # save result
66 | with open(prediction_json_path, 'w') as f:
67 | json.dump(res, f, indent=4)
68 | else:
69 | with open(prediction_json_path, 'r') as f:
70 | res = json.load(f)
71 | return res
72 |
73 | def calc_bleu(predictions, tokenizer, multi_ref, **kwargs):
74 | bleu = evaluate.load('bleu')
75 | if multi_ref:
76 | warnings.warn('BLEU does not support multiple references')
77 | tokenize = lambda l: tokenizer(l, add_special_tokens=False).input_ids
78 | results = bleu.compute(
79 | references=predictions['input'],
80 | predictions=predictions['pred'],
81 | tokenizer=tokenize,
82 | max_order=4)
83 | return results
84 |
85 | def calc_sbert(predictions, batch_size, multi_ref, **kwargs):
86 | from sentence_transformers import SentenceTransformer, util
87 | import torch
88 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
89 | model = SentenceTransformer(
90 | model_name_or_path='all-MiniLM-L6-v2',
91 | device=device)
92 |
93 | sentences1 = predictions['pred']
94 | if not multi_ref:
95 | refs = [[s] for s in predictions['input']]
96 | else:
97 | refs = predictions['refs']
98 |
99 | def calc_cos_score(model, hyp_embedding, ref_sentences):
100 | hyp = hyp_embedding.reshape((1, -1))
101 | refs = model.encode(ref_sentences, convert_to_tensor=True)
102 | scores = util.cos_sim(hyp, refs)
103 | scores = scores.reshape((-1)).tolist()
104 | return {
105 | 'scores': scores,
106 | 'max_score': max(scores),
107 | 'mean_score': sum(scores) / len(scores),
108 | }
109 |
110 |
111 | # compute embedding
112 | pred_embed = model.encode(sentences1, batch_size=batch_size, convert_to_tensor=True)
113 | N = pred_embed.shape[0]
114 | scores = [
115 | calc_cos_score(model, pred_embed[i], refs[i]) for i in range(N)
116 | ]
117 | max_scores = [s['max_score'] for s in scores]
118 | mean_score = sum(max_scores)/len(max_scores)
119 | return {
120 | 'metric': 'sentence textual similarity',
121 | 'mean_score': mean_score,
122 | 'scores': scores,
123 | }
124 |
125 | METRIC_TO_SCORER = {
126 | 'bleu': calc_bleu,
127 | 'sbert': calc_sbert,
128 | }
129 |
130 | def calc(args):
131 | tokenizer = BartTokenizer.from_pretrained(args.tokenizer)
132 |
133 | path = args.path
134 | metric = args.metric
135 | batch_size = args.batch_size
136 |
137 | predictions = get_predictions(
138 | path,
139 | ebno_db=args.ebno_db,
140 | prediction_json_path=args.prediction_json_path,
141 | test_data_path=args.testset_path,
142 | batch_size=batch_size,
143 | tokenizer=tokenizer,
144 | num_beams=args.num_beams)
145 | scorer = METRIC_TO_SCORER[metric]
146 | results = scorer(
147 | predictions=predictions,
148 | tokenizer=tokenizer,
149 | batch_size=batch_size,
150 | multi_ref=args.multi_ref,
151 | )
152 | # dump result
153 | with open(args.result_json_path, 'w') as f:
154 | json.dump(results, f, indent=4)
155 |
156 |
157 | def main():
158 | parser = argparse.ArgumentParser()
159 | parser.add_argument(dest='path', metavar='checkpoint_path', type=pathlib.Path)
160 | parser.add_argument('-m', '--metric', choices = list(METRIC_TO_SCORER.keys()), dest='metric')
161 | parser.add_argument('-b', '--batch-size', default=4, type=int, dest='batch_size')
162 | parser.add_argument('-e', '--ebno-db', required=True, type=float, dest='ebno_db')
163 | parser.add_argument('--testset-path',
164 | required=True, type=pathlib.Path, dest='testset_path')
165 | parser.add_argument('--prediction-json-path',
166 | required=True,
167 | type=pathlib.Path,
168 | dest='prediction_json_path',
169 | help='Required. Output path of prediction result cache json file. \
170 | If the file exists, the prediction result will be reused')
171 | parser.add_argument('--result-json-path',
172 | default=pathlib.Path('./result.json'),
173 | type= pathlib.Path,
174 | dest='result_json_path')
175 | parser.add_argument('--tokenizer',
176 | default='facebook/bart-base',
177 | dest='tokenizer')
178 | parser.add_argument('--num-beams',
179 | default=1,
180 | type=int,
181 | dest='num_beams')
182 | parser.add_argument('--multi-ref',
183 | action='store_true',
184 | dest='multi_ref')
185 | args = parser.parse_args()
186 | calc(args)
187 |
188 | if __name__ == '__main__':
189 | main()
190 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.seq2seq_sc import TFSeq2SeqSCMainLayer, TFSeq2SeqSCForConditionalGeneration, TFSeq2SeqSCModel
--------------------------------------------------------------------------------
/models/seq2seq_sc.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Union
2 | from transformers import TFBartPretrainedModel, TFBartForConditionalGeneration
3 | from transformers.models.bart.modeling_tf_bart import TFBartMainLayer, BartConfig, shift_tokens_right, TFBartEncoder
4 | from transformers.modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqModelOutput
5 | from transformers.modeling_tf_utils import unpack_inputs, TFModelInputType, DUMMY_INPUTS
6 | import tensorflow as tf
7 |
8 | import sionna
9 | sionna.Config.xla_compat=True
10 |
11 | from sionna.channel import AWGN, FlatFadingChannel
12 | from sionna.fec.polar import Polar5GEncoder, Polar5GDecoder
13 | from sionna.mapping import Mapper, Demapper, Constellation
14 | from sionna.mimo import mf_equalizer
15 | from sionna.utils import ebnodb2no, expand_to_rank
16 | from .utils import tensor_to_binary, binary_to_tensor
17 | import numpy as np
18 | from transformers.utils import logging
19 |
20 | class TFSeq2SeqSCEncoderChannel(tf.keras.layers.Layer):
21 |
22 | def __init__(self,
23 | encoder: TFBartEncoder,
24 | ebno_db,
25 | polar_k=512,
26 | polar_n=1024,
27 | polar_decoder_type='SC',
28 | polar_decoder_list_size=8,
29 | num_bits_per_symbol=4,
30 | channel_type = 'AWGN',
31 | channel_num_tx_ant=1,
32 | channel_num_rx_ant=1):
33 | # NOTE: setting layer name as follows seems strange,
34 | # but it allows HuggingFace to load pretrained weights properly
35 | super().__init__(name='model/model/')
36 | logger = logging.get_logger("transformers")
37 | self.config = encoder.config
38 | self.bart_encoder = encoder
39 |
40 | # channel encoder/decoder, channel noise
41 | assert ebno_db is not None
42 | self.ebno_db = float(ebno_db)
43 | self.k = polar_k
44 | self.n = polar_n
45 | logger.info(f'{self.ebno_db=}')
46 | logger.info(f'{self.k=}')
47 | logger.info(f'{self.n=}')
48 | self.channel_encoder = Polar5GEncoder(k=self.k, n=self.n)
49 |
50 | constellation = Constellation("qam",
51 | num_bits_per_symbol,
52 | trainable=False)
53 | logger.info(f'Constellation: type={constellation._constellation_type} {num_bits_per_symbol=} trainable={constellation._trainable}')
54 | self.num_bits_per_symbol = num_bits_per_symbol
55 | self.mapper = Mapper(constellation=constellation)
56 | if channel_type == 'AWGN':
57 | self.channel = AWGN()
58 | self.channel_num_tx_ant = 1
59 | self.channel_num_rx_ant = 1
60 | elif channel_type == 'FlatFadingChannel':
61 | self.channel = FlatFadingChannel(channel_num_tx_ant, channel_num_rx_ant, add_awgn=True, return_channel=True)
62 | self.channel_num_tx_ant = channel_num_tx_ant
63 | self.channel_num_rx_ant = channel_num_rx_ant
64 | else:
65 | raise ValueError(f"Invalid channel type: {channel_type}")
66 | logger.info(f'{channel_type=}')
67 | self.demapper = Demapper("app", constellation=constellation)
68 | self.channel_decoder = Polar5GDecoder(
69 | self.channel_encoder,
70 | dec_type=polar_decoder_type,
71 | list_size=polar_decoder_list_size)
72 | self.coderate = self.k / self.n
73 | logger.info(f'{self.coderate=}')
74 |
75 | @unpack_inputs
76 | def call(
77 | self,
78 | input_ids: Optional[TFModelInputType] = None,
79 | inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
80 | attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
81 | head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
82 | output_attentions: Optional[bool] = None,
83 | output_hidden_states: Optional[bool] = None,
84 | return_dict: Optional[bool] = None,
85 | training: Optional[bool] = False,
86 | ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
87 | encoder_outputs = self.bart_encoder(
88 | input_ids=input_ids,
89 | inputs_embeds=inputs_embeds,
90 | attention_mask=attention_mask,
91 | head_mask=head_mask,
92 | output_attentions=output_attentions,
93 | output_hidden_states=output_hidden_states,
94 | return_dict=return_dict,
95 | training=training,
96 | )
97 | # add channel noise
98 | encoder_outputs.last_hidden_state = \
99 | self._add_channel_noise(encoder_outputs.last_hidden_state)
100 |
101 | # denoise tensor
102 | encoder_outputs.last_hidden_state = \
103 | tf.math.tanh(encoder_outputs.last_hidden_state)
104 | tf.debugging.assert_all_finite(
105 | encoder_outputs.last_hidden_state,
106 | 'should not have nan/inf/-inf after tanh')
107 | return encoder_outputs
108 |
109 | @tf.function
110 | def _add_channel_noise(self, input):
111 | encoder_output_shape = tf.shape(input)
112 |
113 | # Channel encoder
114 | encoder_output_binary = tensor_to_binary(input)
115 | encoder_output_binary = tf.reshape(encoder_output_binary, (-1, self.k))
116 | codewords = self.channel_encoder(encoder_output_binary)
117 |
118 | # Modulation
119 | x = self.mapper(codewords)
120 |
121 | #####################
122 | # Channel
123 | #####################
124 | no = ebnodb2no(self.ebno_db, self.num_bits_per_symbol, self.coderate)
125 | no = expand_to_rank(no, 2)
126 | if isinstance(self.channel, FlatFadingChannel):
127 | shape = tf.shape(x)
128 | x = tf.reshape(x, (-1, self.channel_num_tx_ant))
129 | y, h = self.channel([x, no])
130 | s = tf.complex(no*tf.eye(self.channel_num_rx_ant, self.channel_num_rx_ant), 0.0)
131 |
132 | x_hat, no_eff = mf_equalizer(y, h, s)
133 |
134 | x_hat = tf.reshape(x_hat, shape)
135 | no_eff = tf.reshape(no_eff, shape)
136 |
137 | y = x_hat
138 | no = no_eff
139 | else:
140 | y = self.channel([x, no])
141 |
142 | #####################
143 | # Receiver
144 | #####################
145 | # Demodulation
146 | llr = self.demapper([y, no])
147 | llr = tf.reshape(llr, (-1, self.n))
148 |
149 | # Channel decoder
150 | received_codewords = self.channel_decoder(llr)
151 |
152 | received_encoder_output = binary_to_tensor(received_codewords)
153 | received_encoder_output = tf.reshape(received_encoder_output,
154 | encoder_output_shape)
155 | return received_encoder_output
156 |
157 | class TFSeq2SeqSCMainLayer(tf.keras.layers.Layer):
158 |
159 | def __init__(self,
160 | config: BartConfig,
161 | bart_main_layer: TFBartMainLayer,
162 | ebno_db,
163 | polar_k=512,
164 | polar_n=1024,
165 | polar_decoder_type='SC',
166 | polar_decoder_list_size=8,
167 | num_bits_per_symbol=4,
168 | channel_type = 'AWGN',
169 | channel_num_tx_ant=1,
170 | channel_num_rx_ant=1,
171 | **kwargs):
172 | super().__init__(**kwargs)
173 |
174 | self.config = config
175 | self.shared = bart_main_layer.get_input_embeddings()
176 |
177 | # semantic encoders
178 | self.encoder = TFSeq2SeqSCEncoderChannel(
179 | encoder=bart_main_layer.encoder,
180 | ebno_db=ebno_db,
181 | polar_k=polar_k,
182 | polar_n=polar_n,
183 | polar_decoder_type=polar_decoder_type,
184 | polar_decoder_list_size=polar_decoder_list_size,
185 | num_bits_per_symbol=num_bits_per_symbol,
186 | channel_type=channel_type,
187 | channel_num_tx_ant=channel_num_tx_ant,
188 | channel_num_rx_ant=channel_num_rx_ant)
189 | self.decoder = bart_main_layer.decoder
190 |
191 | def get_input_embeddings(self):
192 | return self.shared
193 |
194 | def set_input_embeddings(self, new_embeddings):
195 | self.shared = new_embeddings
196 | self.encoder.encoder.embed_tokens = self.shared
197 | self.decoder.embed_tokens = self.shared
198 |
199 |
200 |
201 | @unpack_inputs
202 | def call(self,
203 | input_ids: Optional[TFModelInputType] = None,
204 | attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
205 | decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
206 | decoder_attention_mask: Optional[Union[np.ndarray,
207 | tf.Tensor]] = None,
208 | decoder_position_ids: Optional[Union[np.ndarray,
209 | tf.Tensor]] = None,
210 | head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
211 | decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
212 | cross_attn_head_mask: Optional[Union[np.ndarray,
213 | tf.Tensor]] = None,
214 | encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
215 | past_key_values: Optional[Tuple[Tuple[Union[np.ndarray,
216 | tf.Tensor]]]] = None,
217 | inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
218 | decoder_inputs_embeds: Optional[Union[np.ndarray,
219 | tf.Tensor]] = None,
220 | use_cache: Optional[bool] = None,
221 | output_attentions: Optional[bool] = None,
222 | output_hidden_states: Optional[bool] = None,
223 | return_dict: Optional[bool] = None,
224 | training: Optional[bool] = False,
225 | **kwargs) -> Union[TFSeq2SeqModelOutput, Tuple[tf.Tensor]]:
226 | # different to other models, Bart automatically creates decoder_input_ids from
227 | # input_ids if no decoder_input_ids are provided
228 | if decoder_input_ids is None and decoder_inputs_embeds is None:
229 | if input_ids is None:
230 | raise ValueError(
231 | "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
232 | "passed, `input_ids` cannot be `None`. Please pass either "
233 | "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
234 | )
235 |
236 | decoder_input_ids = shift_tokens_right(
237 | input_ids, self.config.pad_token_id,
238 | self.config.decoder_start_token_id)
239 |
240 | if encoder_outputs is None:
241 | encoder_outputs = self.encoder(
242 | input_ids=input_ids,
243 | attention_mask=attention_mask,
244 | head_mask=head_mask,
245 | inputs_embeds=inputs_embeds,
246 | output_attentions=output_attentions,
247 | output_hidden_states=output_hidden_states,
248 | return_dict=return_dict,
249 | training=training,
250 | )
251 |
252 | # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
253 | elif return_dict and not isinstance(encoder_outputs,
254 | TFBaseModelOutput):
255 | encoder_outputs = TFBaseModelOutput(
256 | last_hidden_state=encoder_outputs[0],
257 | hidden_states=encoder_outputs[1]
258 | if len(encoder_outputs) > 1 else None,
259 | attentions=encoder_outputs[2]
260 | if len(encoder_outputs) > 2 else None,
261 | )
262 |
263 | # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
264 | elif not return_dict and not isinstance(encoder_outputs, tuple):
265 | encoder_outputs = encoder_outputs.to_tuple()
266 |
267 | decoder_outputs = self.decoder(
268 | decoder_input_ids,
269 | attention_mask=decoder_attention_mask,
270 | position_ids=decoder_position_ids,
271 | encoder_hidden_states=encoder_outputs[0],
272 | encoder_attention_mask=attention_mask,
273 | head_mask=decoder_head_mask,
274 | cross_attn_head_mask=cross_attn_head_mask,
275 | past_key_values=past_key_values,
276 | inputs_embeds=decoder_inputs_embeds,
277 | use_cache=use_cache,
278 | output_attentions=output_attentions,
279 | output_hidden_states=output_hidden_states,
280 | return_dict=return_dict,
281 | training=training,
282 | )
283 |
284 | if not return_dict:
285 | return decoder_outputs + encoder_outputs
286 |
287 | return TFSeq2SeqModelOutput(
288 | last_hidden_state=decoder_outputs.last_hidden_state,
289 | past_key_values=decoder_outputs.past_key_values,
290 | decoder_hidden_states=decoder_outputs.hidden_states,
291 | decoder_attentions=decoder_outputs.attentions,
292 | cross_attentions=decoder_outputs.cross_attentions,
293 | encoder_last_hidden_state=encoder_outputs.last_hidden_state,
294 | encoder_hidden_states=encoder_outputs.hidden_states,
295 | encoder_attentions=encoder_outputs.attentions,
296 | )
297 |
298 |
299 | class TFSeq2SeqSCModel(TFBartPretrainedModel):
300 |
301 | def __init__(self,
302 | config: BartConfig,
303 | load_weight_prefix=None,
304 | ebno_db=None,
305 | polar_k=512,
306 | polar_n=1024,
307 | polar_decoder_type='SC',
308 | polar_decoder_list_size=8,
309 | num_bits_per_symbol=4,
310 | channel_type = 'AWGN',
311 | channel_num_tx_ant = 1,
312 | channel_num_rx_ant = 1,
313 | *inputs,
314 | **kwargs):
315 | super().__init__(config, *inputs, **kwargs)
316 | self.bart_layer = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model")
317 | # self.bart_layer(DUMMY_INPUTS)
318 | self.model = TFSeq2SeqSCMainLayer(
319 | config,
320 | ebno_db=ebno_db,
321 | bart_main_layer=self.bart_layer,
322 | polar_k=polar_k,
323 | polar_n=polar_n,
324 | polar_decoder_type=polar_decoder_type,
325 | polar_decoder_list_size=polar_decoder_list_size,
326 | num_bits_per_symbol=num_bits_per_symbol,
327 | channel_type=channel_type,
328 | channel_num_tx_ant=channel_num_tx_ant,
329 | channel_num_rx_ant=channel_num_rx_ant
330 | )
331 |
332 | @unpack_inputs
333 | def call(self,
334 | input_ids: Optional[TFModelInputType] = None,
335 | attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
336 | decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
337 | decoder_attention_mask: Optional[Union[np.ndarray,
338 | tf.Tensor]] = None,
339 | decoder_position_ids: Optional[Union[np.ndarray,
340 | tf.Tensor]] = None,
341 | head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
342 | decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
343 | cross_attn_head_mask: Optional[Union[np.ndarray,
344 | tf.Tensor]] = None,
345 | encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
346 | past_key_values: Optional[Tuple[Tuple[Union[np.ndarray,
347 | tf.Tensor]]]] = None,
348 | inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
349 | decoder_inputs_embeds: Optional[Union[np.ndarray,
350 | tf.Tensor]] = None,
351 | use_cache: Optional[bool] = None,
352 | output_attentions: Optional[bool] = None,
353 | output_hidden_states: Optional[bool] = None,
354 | return_dict: Optional[bool] = None,
355 | training: Optional[bool] = False,
356 | **kwargs) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
357 |
358 | outputs = self.model(
359 | input_ids=input_ids,
360 | attention_mask=attention_mask,
361 | decoder_input_ids=decoder_input_ids,
362 | decoder_attention_mask=decoder_attention_mask,
363 | decoder_position_ids=decoder_position_ids,
364 | head_mask=head_mask,
365 | decoder_head_mask=decoder_head_mask,
366 | cross_attn_head_mask=cross_attn_head_mask,
367 | encoder_outputs=encoder_outputs,
368 | past_key_values=past_key_values,
369 | inputs_embeds=inputs_embeds,
370 | decoder_inputs_embeds=decoder_inputs_embeds,
371 | use_cache=use_cache,
372 | output_attentions=output_attentions,
373 | output_hidden_states=output_hidden_states,
374 | return_dict=return_dict,
375 | training=training,
376 | )
377 |
378 | return outputs
379 |
380 |
381 | class TFSeq2SeqSCForConditionalGeneration(TFBartForConditionalGeneration):
382 | def __init__(self,
383 | config,
384 | ebno_db=None,
385 | polar_k=512,
386 | polar_n=1024,
387 | polar_decoder_type='SC',
388 | polar_decoder_list_size=8,
389 | num_bits_per_symbol=4,
390 | channel_type = 'AWGN',
391 | channel_num_tx_ant = 1,
392 | channel_num_rx_ant = 1,
393 | *inputs,
394 | **kwargs):
395 | super().__init__(config, *inputs, **kwargs)
396 | self.model = TFSeq2SeqSCMainLayer(
397 | config,
398 | bart_main_layer=self.model,
399 | ebno_db=ebno_db,
400 | polar_k=polar_k,
401 | polar_n=polar_n,
402 | polar_decoder_type=polar_decoder_type,
403 | polar_decoder_list_size=polar_decoder_list_size,
404 | num_bits_per_symbol=num_bits_per_symbol,
405 | channel_type=channel_type,
406 | channel_num_tx_ant=channel_num_tx_ant,
407 | channel_num_rx_ant=channel_num_rx_ant,
408 | name="model")
409 |
--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | @tf.function
5 | def tensor_to_binary(x):
6 | LOG_BASE2 = tf.math.log(tf.constant([2.0], tf.float32))
7 | TO_MANTISSA = tf.constant([1<<23], tf.float32)
8 |
9 | # sign
10 | sign = tf.cast(x < 0.0, tf.float32)
11 | x = tf.math.abs(x)
12 |
13 | # exponent
14 | log_x = tf.math.floor(tf.math.log(x) / LOG_BASE2)
15 | exponent = tf.cast(log_x + 127.0, tf.uint8)
16 |
17 | # mantissa
18 | mantissa = x / tf.math.exp(log_x*LOG_BASE2) - tf.math.sign(x)
19 | mantissa = tf.math.floor(mantissa * TO_MANTISSA)
20 | mantissa = tf.cast(mantissa, tf.int32)
21 |
22 | # convert to bits
23 | bits = [None for i in range(32)]
24 | for i in range(23):
25 | bits[i] = tf.bitwise.bitwise_and(mantissa, 1)
26 | mantissa = tf.bitwise.right_shift(mantissa, 1)
27 | for i in range(23, 31):
28 | bits[i] = tf.bitwise.bitwise_and(exponent, 1)
29 | exponent = tf.bitwise.right_shift(exponent, 1)
30 | bits[31] = sign
31 |
32 | for i in range(32):
33 | bits[i] = tf.cast(bits[i], tf.float32)
34 | res = tf.stack(bits, axis=-1)
35 | return res
36 |
37 | @tf.function
38 | def binary_to_tensor(x: tf.Tensor):
39 | LOG_BASE2 = tf.math.log(tf.constant([2.0], tf.float32))
40 | EXPONENTS = tf.constant([ float(1 << i) for i in range(8)], tf.float32)
41 | FROM_MANTISSA = tf.constant([ 0.5**(23-i) for i in range(23)], tf.float32)
42 |
43 | x = tf.reshape(x, (-1, 32))
44 | sign = -x[:, 31] * 2 + 1
45 |
46 | exponent = tf.math.reduce_sum(x[:, 23:31] * EXPONENTS, axis=-1)
47 | mantissa = tf.math.reduce_sum(x[:, :23] * FROM_MANTISSA, axis=-1)
48 | mantissa += tf.cast(exponent > 0.0, tf.float32)
49 | return sign * tf.math.exp((exponent - 127.0) * LOG_BASE2) * mantissa
50 |
51 |
52 |
53 | @tf.function
54 | def tensor_to_binary_v2(x: tf.Tensor):  # bitcast-based variant of tensor_to_binary
55 | x = tf.bitcast(x, tf.uint32)
56 | mask = tf.ones_like(x)
57 | bit0 = tf.cast(tf.reshape(tf.bitwise.bitwise_and(x, mask), (1, -1)),
58 | tf.float32)
59 | bits = [bit0]
60 |
61 | for _ in range(31):
62 | x = tf.bitwise.right_shift(x, 1)
63 | bitn = tf.cast(tf.reshape(tf.bitwise.bitwise_and(x, mask), (1, -1)),
64 | tf.float32)
65 | bits.append(bitn)
66 |
67 | return tf.concat(bits, axis=0)
68 |
69 | @tf.function
70 | def replace_nan(input, new_value = 0.0):
71 | new_value = float(new_value)
72 | indices = tf.where(tf.math.is_nan(input))
73 | res = tf.tensor_scatter_nd_update(
74 | input,
75 | indices,
76 | tf.fill((tf.shape(indices)[0], ), new_value)
77 | )
78 | return res
79 |
80 | INF = tf.constant(np.array([np.inf]), dtype=tf.float32)
81 |
82 | @tf.function
83 | def replace_nan_to_inf(x):
84 | EXPONENT_MASK = tf.constant([0x7F800000], dtype=tf.uint32)
85 | INF_MASK = tf.constant([0xFF800000], dtype=tf.uint32)
86 | IDENTITY_MASK = tf.constant([0xFFFFFFFF], dtype=tf.uint32)
87 | x = tf.bitcast(x, tf.uint32)
88 | mask = tf.where(
89 | tf.equal(tf.bitwise.bitwise_and(x, EXPONENT_MASK), EXPONENT_MASK),
90 | INF_MASK, IDENTITY_MASK
91 | )
92 | return tf.bitcast(tf.bitwise.bitwise_and(x, mask), tf.float32)
93 |
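94 | # NOTE (assumed reconstruction): utils_test.py exercises `binary_to_tensor_v2`,
95 | # which is not defined in this module. The sketch below inverts
96 | # `tensor_to_binary_v2` (rows of LSB-first bits of the bitcast uint32 word) and
97 | # maps NaN bit patterns to +/-inf via `replace_nan_to_inf`, so a subsequent
98 | # tanh saturates to +/-1 as the tests expect.
99 | @tf.function
100 | def binary_to_tensor_v2(x: tf.Tensor):
101 |     bits = tf.cast(tf.reshape(x, (32, -1)), tf.uint32)
102 |     # rebuild each 32-bit word, most significant bit first
103 |     value = tf.zeros_like(bits[0])
104 |     for i in range(31, -1, -1):
105 |         value = tf.bitwise.left_shift(value, 1)
106 |         value = tf.bitwise.bitwise_or(value, bits[i])
107 |     return replace_nan_to_inf(tf.bitcast(value, tf.float32))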
--------------------------------------------------------------------------------
/models/utils_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import tensorflow as tf
3 | import numpy as np
4 | from .utils import *
5 |
6 | class TestBinaryConversion(unittest.TestCase):
7 |
8 | def test_conversion(self):
9 | x = tf.random.uniform((1, 64, 128))
10 | tmp = tensor_to_binary(x)
11 | y = binary_to_tensor(tmp)
12 | y = tf.reshape(y, (1, 64, 128))
13 |
14 | x = tf.bitcast(x, tf.uint32).numpy()
15 | y = tf.bitcast(y, tf.uint32).numpy()
16 |
17 | # bit-level exact match
18 | np.testing.assert_array_equal(x, y)
19 |
20 | def test_conversion_v2(self):
21 | shape = (1, 64, 128)
22 | x = tf.random.uniform(shape)
23 | tmp = tensor_to_binary_v2(x)
24 | y = binary_to_tensor_v2(tmp)
25 | y = tf.reshape(y, shape)
26 |
27 | x = tf.bitcast(x, tf.uint32).numpy()
28 | y = tf.bitcast(y, tf.uint32).numpy()
29 |
30 | # bit-level exact match
31 | # np.testing.assert_almost_equal(x, y)
32 | np.testing.assert_array_equal(x, y)
33 |
34 | def test_tensor_to_binary(self):
35 | # IEEE754
36 | # f32: 6027202.5
37 | # hex: 0x4ab7ef85
38 | # bin: 01001010101101111110111110000101
39 | x = tf.constant([6027202.5])
40 | bin = tensor_to_binary(x)
41 | actual = bin.numpy().astype(np.int32).flatten()
42 | expected = np.array([
43 | 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1,
44 | 1, 1, 1, 0, 0, 0, 0, 1, 0, 1
45 | ], dtype=np.int32)[::-1]
46 | self.assertTrue((actual == expected).all())
47 |
48 | def test_tensor_to_binary_v2(self):
49 | # IEEE754
50 | # f32: 6027202.5
51 | # hex: 0x4ab7ef85
52 | # bin: 01001010101101111110111110000101
53 | x = tf.constant([6027202.5])
54 | bin = tensor_to_binary_v2(x)
55 | actual = bin.numpy().astype(np.int32).flatten()
56 | expected = np.array([
57 | 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1,
58 | 1, 1, 1, 0, 0, 0, 0, 1, 0, 1
59 | ], dtype=np.int32)[::-1]
60 | np.testing.assert_equal(actual, expected)
61 |
62 |
63 | def test_binary_to_tensor(self):
64 | # IEEE754
65 | # f32: 6027202.5
66 | # hex: 0x4ab7ef85
67 | # bin: 01001010101101111110111110000101
68 | x = np.array([
69 | 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1,
70 | 1, 1, 1, 0, 0, 0, 0, 1, 0, 1
71 | ],
72 | dtype=np.float32)[::-1]
73 | x = tf.constant(x)
74 | bin = binary_to_tensor(x)
75 | actual = bin.numpy()[0]
76 | expected = 6027202.5
77 |
78 | self.assertAlmostEqual(actual, expected)
79 |
80 | def test_binary_to_tensor_v2(self):
81 | # IEEE754
82 | # f32: 6027202.5
83 | # hex: 0x4ab7ef85
84 | # bin: 01001010101101111110111110000101
85 | x = np.array([
86 | 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1,
87 | 1, 1, 1, 0, 0, 0, 0, 1, 0, 1
88 | ],
89 | dtype=np.float32)[::-1]
90 | x = tf.constant(x)
91 | bin = binary_to_tensor_v2(x)
92 | actual = bin.numpy()[0]
93 | expected = 6027202.5
94 | self.assertAlmostEqual(actual, expected)
95 |
96 | def test_binary_to_tensor_with_tanh_v2(self):
97 | # f32: inf
98 | # hex: 0x7f800000
99 | # bin: 01111111100000000000000000000000
100 | x = np.array([0] + [1 for i in range(8)] + [0 for i in range(23)],
101 | dtype=np.float32)[::-1]
102 | x = tf.constant(x)
103 | y = tf.math.tanh(binary_to_tensor_v2(x))
104 | actual = y.numpy()[0]
105 | expected = 1.0
106 | self.assertAlmostEqual(actual, expected)
107 | # f32: -inf
108 | # hex: 0xff800000
109 | # bin: 11111111100000000000000000000000
110 | x = np.array([1] + [1 for i in range(8)] + [0 for i in range(23)],
111 | dtype=np.float32)[::-1]
112 | x = tf.constant(x)
113 | y = tf.math.tanh(binary_to_tensor_v2(x))
114 | actual = y.numpy()[0]
115 | expected = -1.0
116 | self.assertAlmostEqual(actual, expected)
117 | # f32: nan
118 | # hex: 0x7fc00000
119 | # bin: 01111111110000000000000000000000
120 | x = np.array([0] + [1 for i in range(8)] + [1] + [0 for i in range(22)],
121 | dtype=np.float32)[::-1]
122 | x = tf.constant(x)
123 | y = tf.math.tanh(binary_to_tensor_v2(x))
124 | actual = y.numpy()[0]
125 | expected = 1.0
126 | self.assertAlmostEqual(actual, expected)
127 | # f32: -nan
128 | # hex: 0xffc00000
129 | # bin: 11111111110000000000000000000000
130 | x = np.array([1] + [1 for i in range(8)] + [1] + [0 for i in range(22)],
131 | dtype=np.float32)[::-1]
132 | x = tf.constant(x)
133 | y = tf.math.tanh(binary_to_tensor_v2(x))
134 | actual = y.numpy()[0]
135 | expected = -1.0
136 | self.assertAlmostEqual(actual, expected)
137 |
138 |
139 |
140 | def test_replace_nan(self):
141 | x = np.array([
142 | [[1, 2, np.nan],
143 | [4, np.nan, 6],],
144 | [[7, 8, np.nan],
145 | [10, np.nan, 12],],
146 | ], dtype=np.float32)
147 | expected = np.array([
148 | [[1, 2, 0.0],
149 | [4, 0, 6],],
150 | [[7, 8, 0],
151 | [10, 0, 12],],
152 | ], dtype=np.float32)
153 |
154 | actual = replace_nan(tf.constant(x))
155 | actual = actual.numpy()
156 |
157 | np.testing.assert_almost_equal(actual, expected)
158 |
159 | def test_replace_nan_to_inf(self):
160 | x = np.array([
161 | [[1, 2, np.nan],
162 | [4, -np.nan, np.inf],],
163 | [[7, 8, -np.nan],
164 | [10, np.nan, -np.inf],],
165 | ], dtype=np.float32)
166 | expected = np.array([
167 | [[1, 2, np.inf],
168 | [4, -np.inf, np.inf],],
169 | [[7, 8, -np.inf],
170 | [10, np.inf, -np.inf],],
171 | ], dtype=np.float32)
172 |
173 | actual = replace_nan_to_inf(tf.constant(x))
174 | actual = actual.numpy()
175 |
176 | np.testing.assert_almost_equal(actual, expected)
177 |
178 |
179 | if __name__ == '__main__':
180 | unittest.main()
181 |
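182 | # Note: because of the relative import above, run these tests as a module from
183 | # the repository root, e.g. `python -m unittest models.utils_test`.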
--------------------------------------------------------------------------------
/poster_seq2seq_Asilomar2023.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abman23/seq2seq-sc/e900f637ed5e89300bb2fb9e98f0e215cc508ed0/poster_seq2seq_Asilomar2023.jpg
--------------------------------------------------------------------------------
/preprocess/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abman23/seq2seq-sc/e900f637ed5e89300bb2fb9e98f0e215cc508ed0/preprocess/__init__.py
--------------------------------------------------------------------------------
/preprocess/allnli.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import pathlib
4 | import random
5 |
6 | from .hf_data_gen import HFDataGenerator
7 |
8 | if __name__ == '__main__':
9 | parser = argparse.ArgumentParser("Preprocess AllNLI data. It generates paraphrase train/dev CSV files")
10 | parser.add_argument(
11 | '-o', '--out-path',
12 | dest='out_path',
13 | required=True,
14 | type=pathlib.Path,
15 | help='Required. Path of output directory')
16 | parser.add_argument(
17 | '--train-dev-split',
18 | dest='train_dev_split',
19 | default=0.9,
20 | type=float,
21 | help='Trainset/devset split ratio')
22 | parser.add_argument(
23 | '--seed',
24 | dest='seed',
25 | default=1234,
26 | type=int,
27 | help='Random seed')
28 | parser.add_argument(
29 | dest='path',
30 | type=pathlib.Path,
31 | help="Path of AllNLI.jsonl")
32 | args = parser.parse_args()
33 |
34 | random.seed(args.seed)
35 |
36 | with open(args.path) as f:
37 | data = [json.loads(line) for line in f]
38 |
39 | data = map(lambda l: (l[0], l[1]), data)
40 | sentences = filter(lambda l: 'n/a' not in l, data)
41 | sentences = list(sentences)
42 |
43 | N = len(sentences)
44 | devset_size = int(N*(1-args.train_dev_split))
45 | devset_indices = random.sample(range(N), devset_size)
46 | devset_indices = set(devset_indices)
47 |
48 | trainset_gen = HFDataGenerator()
49 | devset_gen = HFDataGenerator()
50 | for i, (s1, s2) in enumerate(sentences):
51 | if i in devset_indices:
52 | devset_gen.add(s1, s2)
53 | else:
54 | trainset_gen.add(s1, s2)
55 |
56 | trainset_gen.dump(args.out_path / 'allnli_train.csv')
57 | devset_gen.dump(args.out_path / 'allnli_dev.csv')
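58 |
59 | # Assumed input format: each line of AllNLI.jsonl is a JSON array of related
60 | # sentences, e.g. ["sentence 1", "sentence 2", ...]; the first two entries form
61 | # the (text, summary) pair above, and pairs containing 'n/a' are skipped.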
--------------------------------------------------------------------------------
/preprocess/europarl.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import json
3 | import argparse
4 | from .hf_data_gen import HFDataGenerator
5 |
6 | # The following code is copied (and slightly modified) from DeepSC
7 | # (https://github.com/zyy598/DeepSC/blob/master/preprocess_text.py)
8 | import unicodedata
9 | import re
10 | from w3lib.html import remove_tags
11 | import pickle
12 | import os
13 | import json
14 | from tqdm import tqdm
15 |
16 | def unicode_to_ascii(s):
17 | return ''.join(c for c in unicodedata.normalize('NFD', s)
18 | if unicodedata.category(c) != 'Mn')
19 |
20 | def normalize_string(s):
21 | # normalize unicode characters
22 | s = unicode_to_ascii(s)
23 | # remove the XML-tags
24 | s = remove_tags(s)
25 | # add white space before !.?
26 | s = re.sub(r'([!.?])', r' \1', s)
27 | s = re.sub(r'[^a-zA-Z.!?]+', r' ', s)
28 | s = re.sub(r'\s+', r' ', s)
29 | # change to lower letter
30 | s = s.lower()
31 | return s
32 |
33 | def cutted_data(cleaned, MIN_LENGTH=4, MAX_LENGTH=30):
34 | cutted_lines = list()
35 | for line in cleaned:
36 | length = len(line.split())
37 | if length > MIN_LENGTH and length < MAX_LENGTH:
38 | line = [word for word in line.split()]
39 | cutted_lines.append(' '.join(line))
40 | return cutted_lines
41 |
42 | def save_clean_sentences(sentence, save_path):
43 | pickle.dump(sentence, open(save_path, 'wb'))
44 | print('Saved: %s' % save_path)
45 |
46 | def process_text_file(file_path):
47 | with open(file_path, 'r') as f:
48 | raw_data = f.read()
49 | sentences = raw_data.strip().split('\n')
50 | raw_data_input = [normalize_string(data) for data in sentences]
51 | raw_data_input = cutted_data(raw_data_input)
52 | return raw_data_input
53 |
54 | def tokenize(s, delim=' ', add_start_token=True, add_end_token=True,
55 | punct_to_keep=None, punct_to_remove=None):
56 | """
57 | Tokenize a sequence, converting a string s into a list of (string) tokens by
58 | splitting on the specified delimiter. Optionally keep or remove certain
59 | punctuation marks and add start and end tokens.
60 | """
61 | if punct_to_keep is not None:
62 | for p in punct_to_keep:
63 | s = s.replace(p, '%s%s' % (delim, p))
64 |
65 | if punct_to_remove is not None:
66 | for p in punct_to_remove:
67 | s = s.replace(p, '')
68 |
69 | tokens = s.split(delim)
70 | if add_start_token:
71 | tokens.insert(0, '<START>')
72 | if add_end_token:
73 | tokens.append('<END>')
74 | return tokens
75 |
76 | def build_vocab(sequences, token_to_idx = { }, min_token_count=1, delim=' ',
77 | punct_to_keep=None, punct_to_remove=None, ):
78 | token_to_count = {}
79 |
80 | for seq in sequences:
81 | seq_tokens = tokenize(seq, delim=delim, punct_to_keep=punct_to_keep,
82 | punct_to_remove=punct_to_remove,
83 | add_start_token=False, add_end_token=False)
84 | for token in seq_tokens:
85 | if token not in token_to_count:
86 | token_to_count[token] = 0
87 | token_to_count[token] += 1
88 |
89 | for token, count in sorted(token_to_count.items()):
90 | if count >= min_token_count:
91 | token_to_idx[token] = len(token_to_idx)
92 |
93 | return token_to_idx
94 |
95 | def encode(seq_tokens, token_to_idx, allow_unk=False):
96 | seq_idx = []
97 | for token in seq_tokens:
98 | if token not in token_to_idx:
99 | if allow_unk:
100 | token = '<UNK>'
101 | else:
102 | raise KeyError('Token "%s" not in vocab' % token)
103 | seq_idx.append(token_to_idx[token])
104 | return seq_idx
105 |
106 | def decode(seq_idx, idx_to_token, delim=None, stop_at_end=True):
107 | tokens = []
108 | for idx in seq_idx:
109 | tokens.append(idx_to_token[idx])
110 | if stop_at_end and tokens[-1] == '<END>':
111 | break
112 | if delim is None:
113 | return tokens
114 | else:
115 | return delim.join(tokens)
116 |
117 | SPECIAL_TOKENS = {
118 | '<PAD>': 0,
119 | '<START>': 1,
120 | '<END>': 2,
121 | '<UNK>': 3,
122 | }
123 |
124 | def process_europarl(input_data_dir, train_test_split=0.9, njobs=1):
125 | sentences = []
126 | print('Preprocess Raw Text')
127 | from joblib import Parallel, delayed
128 | sentences = Parallel(n_jobs=njobs, verbose=1)(
129 | delayed(process_text_file)(fn)
130 | for fn in pathlib.Path(input_data_dir).glob('*.txt'))
131 | sentences = [s for s_list in sentences for s in s_list ]
132 |
133 | # remove the same sentences
134 | a = {}
135 | for set in sentences:
136 | if set not in a:
137 | a[set] = 0
138 | a[set] += 1
139 | sentences = list(a.keys())
140 | print('Number of sentences: {}'.format(len(sentences)))
141 |
142 | print('Build Vocab')
143 | token_to_idx = build_vocab(
144 | sentences, SPECIAL_TOKENS,
145 | punct_to_keep=[';', ','], punct_to_remove=['?', '.']
146 | )
147 |
148 | vocab = {'token_to_idx': token_to_idx}
149 | print('Number of words in Vocab: {}'.format(len(token_to_idx)))
150 |
151 | print('Start encoding txt')
152 | results = []
153 | for seq in tqdm(sentences):
154 | words = tokenize(seq, punct_to_keep=[';', ','], punct_to_remove=['?', '.'])
155 | tokens = [token_to_idx[word] for word in words]
156 | results.append(tokens)
157 |
158 | train_data = results[: round(len(results) * train_test_split)]
159 | test_data = results[round(len(results) * train_test_split):]
160 |
161 | return train_data, test_data, vocab
162 | # End of the copied code
163 |
164 | class Tokenizer:
165 |
166 | TOKENS_FILTERED = set([
167 | '<START>', '<END>'
168 | ])
169 |
170 | def __init__(self, vocab):
171 | idx_to_token = [None for _ in range(1 + max(vocab['token_to_idx'].values()))]
172 | for token, idx in vocab['token_to_idx'].items():
173 | idx_to_token[idx] = token
174 | self.idx_to_token = idx_to_token
175 | self.token_to_idx = vocab['token_to_idx']
176 |
177 | def decode(self, token_ids):
178 | tokens = map(lambda i: self.idx_to_token[i], token_ids)
179 | tokens = filter(lambda t: t not in self.TOKENS_FILTERED, tokens)
180 | return ' '.join(tokens)
181 |
182 | def batch_decode(self, token_ids_list):
183 | return list(map(lambda token_ids: self.decode(token_ids), token_ids_list))
184 |
185 | def gen_hf_dataset(path: pathlib.Path, output_path=None, train_test_split=0.9, njobs=1):
186 | path = pathlib.Path(path)
187 | if output_path is None:
188 | output_path = path
189 |
190 | train_data, test_data, vocab = process_europarl(path, train_test_split, njobs)
191 |
192 | # save processed sentences
193 | with open(output_path / 'train_data.pkl', 'wb') as f:
194 | pickle.dump(train_data, f)
195 | with open(output_path / 'test_data.pkl', 'wb') as f:
196 | pickle.dump(test_data, f)
197 | with open(output_path / 'vocab.json', 'w') as f:
198 | json.dump(vocab, f)
199 |
200 | tokenizer = Tokenizer(vocab)
201 |
202 | # train set
203 | train_data = tokenizer.batch_decode(train_data)
204 | train_gen = HFDataGenerator()
205 | train_gen.add(train_data, train_data)
206 | train_gen.dump(output_path / 'train.csv')
207 |
208 | # test set
209 | test_data = tokenizer.batch_decode(test_data)
210 | test_gen = HFDataGenerator()
211 | test_gen.add(test_data, test_data)
212 | test_gen.dump(output_path / 'test.csv')
213 |
214 |
215 | if __name__ == '__main__':
216 | parser = argparse.ArgumentParser("Preprocess Europarl data. It generates the same dataset used in DeepSC")
217 | parser.add_argument(
218 | '-o', '--out-path',
219 | dest='out_path',
220 | required=True,
221 | type=pathlib.Path,
222 | help='Required. Path of output files.')
223 | parser.add_argument(
224 | '--train-test-split',
225 | dest='train_test_split',
226 | default=0.9,
227 | type=float,
228 | help='Trainset/testset split ratio')
229 | parser.add_argument(
230 | '-j',
231 | '--njobs',
232 | dest='njobs',
233 | default=1,
234 | type=int,
235 | help='Number of threads to be used for preprocessing')
236 | parser.add_argument(
237 | dest='path',
238 | type=pathlib.Path,
239 | help="Path of the Europarl dataset; it should point to the extracted 'txt/en' directory")
240 | args = parser.parse_args()
241 | gen_hf_dataset(args.path, args.out_path, args.train_test_split, args.njobs)
242 |
243 |
--------------------------------------------------------------------------------
/preprocess/flickr30k.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import json
3 | import random
4 | import argparse
5 | import pathlib
6 |
7 | def parse_line(line: str):
8 | key = line.split('#')[0]
9 | caption = line.split('\t')[-1].strip()
10 | return key, caption
11 |
12 | if __name__ == '__main__':
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('-o', '--out-path',
15 | dest='out_path',
16 | type=pathlib.Path,
17 | help='output json file path')
18 | parser.add_argument('-n', '--num-samples',
19 | dest='N',
20 | default=1000,
21 | type=int,
22 | help='number of samples (default: 1000)')
23 | parser.add_argument('--seed',
24 | dest='seed',
25 | default=20221017,
26 | type=int,
27 | help='seed for random module')
28 |
29 | parser.add_argument(
30 | dest='token_path',
31 | help="path of 'results_20130124.token'")
32 | args = parser.parse_args()
33 |
34 | # read token file
35 | data = defaultdict(list)
36 | with open(args.token_path) as f:
37 | for k, caption in map(parse_line, f):
38 | data[k].append(caption)
39 | data = list(data.values())
40 |
41 | # set seed
42 | random.seed(args.seed)
43 |
44 | # sample dataset
45 | samples = random.sample(range(len(data)), k=args.N)
46 | out_data = []
47 | for i in samples:
48 | captions = data[i]
49 | input_idx = random.sample(range(len(captions)), k=1)[0]
50 | input_sentence = captions[input_idx]
51 | ref_sentences = captions[:input_idx] + captions[(input_idx+1):]
52 | out_data.append({
53 | 'input': input_sentence,
54 | 'refs': ref_sentences,
55 | })
56 |
57 | with open(args.out_path, 'w') as f:
58 | json.dump(out_data, f, indent=4)
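59 |
60 | # Output format (consumed by eval.py via --testset-path): a JSON list of objects,
61 | # each of the form {"input": <one randomly chosen caption>,
62 | #                   "refs": [<the remaining captions of the same image>]}.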
--------------------------------------------------------------------------------
/preprocess/hf_data_gen.py:
--------------------------------------------------------------------------------
1 | # example file format
2 | # --------------------
3 | # text,summary
4 | # "I'm sitting here in a boring room. It's just another rainy Sunday afternoon. I'm wasting my time I got nothing to do. I'm hanging around I'm waiting for you. But nothing ever happens. And I wonder","I'm sitting in a room where I'm waiting for something to happen"
5 | # "I see trees so green, red roses too. I see them bloom for me and you. And I think to myself what a wonderful world. I see skies so blue and clouds so white. The bright blessed day, the dark sacred night. And I think to myself what a wonderful world.","I'm a gardener and I'm a big fan of flowers."
6 | # "Christmas time is here. Happiness and cheer. Fun for all that children call. Their favorite time of the year. Snowflakes in the air. Carols everywhere. Olden times and ancient rhymes. Of love and dreams to share","It's that time of year again."
7 | import csv
8 |
9 | class HFDataGenerator:
10 |
11 | FIELDNAMES = ['text', 'summary']
12 |
13 | def __init__(self):
14 | self.rows = []
15 |
16 | def add(self, text, summary):
17 | if isinstance(text, str) and isinstance(summary, str):
18 | self.rows.append({
19 | 'text': f'{text}',
20 | 'summary': f'{summary}',
21 | })
22 | elif isinstance(text, list) and isinstance(summary, list):
23 | assert len(text) == len(summary), "Different text and summary list"
24 | rows = map(lambda ts: {'text': ts[0], 'summary': ts[1]}, zip(text, summary))
25 | self.rows.extend(rows)
26 | else:
27 | assert False, "No such case"
28 |
29 |
30 | def dump(self, filename):
31 | with open(filename, 'w', newline='') as f:
32 | writer = csv.DictWriter(f, fieldnames=self.FIELDNAMES, quoting=csv.QUOTE_ALL)
33 | writer.writeheader()
34 | writer.writerows(self.rows)
35 |
36 | def test():
37 | data_path = "/tmp/hf_data_test.csv"
38 | gen = HFDataGenerator()
39 | gen.add("text1", "summary1")
40 | gen.add("text2", "summary2")
41 | gen.dump(data_path)
42 |
43 | expected = """"text","summary"
44 | "text1","summary1"
45 | "text2","summary2"
46 | """
47 | with open( data_path) as f:
48 | actual = f.read()
49 | assert actual == expected, f"""'{actual}'"""
50 |
51 | if __name__ == '__main__':
52 | test()
53 |
54 |
--------------------------------------------------------------------------------
/scripts/eval_flickr.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | ebno_db="10"
3 | metric="sbert" # bleu, sbert
4 | testset_path='data/flickr/processed/flickr30k.json'
5 | checkpoint_path="checkpoints/seq2seq-allnli-sc"
6 |
7 | python eval.py \
8 | --batch 4 \
9 | --metric "${metric}" \
10 | --ebno-db "${ebno_db}" \
11 | --result-json-path "${checkpoint_path}/flikr_${metric}_ebno_${ebno_db}.json" \
12 | --prediction-json-path "${checkpoint_path}/flikr_prediction_ebno_${ebno_db}.json" \
13 | --testset-path "${testset_path}" \
14 | $checkpoint_path
--------------------------------------------------------------------------------
/scripts/preprocess_allnli.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | data_path=data/allnli
3 | mkdir -p $data_path
4 | wget -P $data_path https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/datasets/paraphrases/AllNLI.jsonl.gz
5 | gunzip $data_path/AllNLI.jsonl.gz
6 |
7 | allnli_dataset="$data_path/AllNLI.jsonl"
8 | out_dir="$data_path/processed"
9 |
10 | mkdir -p $out_dir
11 | python -m preprocess.allnli -o $out_dir $allnli_dataset
--------------------------------------------------------------------------------
/scripts/preprocess_europarl.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | europarl_dataset=data/europarl/txt/en
3 | out_dir=data/europarl/processed
4 | njobs=4
5 |
6 | mkdir -p $out_dir
7 | python -m preprocess.europarl -j $njobs -o $out_dir $europarl_dataset
8 |
--------------------------------------------------------------------------------
/scripts/preprocess_flickr30k.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | data_path="data/flickr"
3 | dataset_path="${data_path}/flickr30k.tar.gz"
4 | out_dir="$data_path/processed"
5 |
6 | mkdir -p $out_dir
7 |
8 | tar xzf ${dataset_path} -C $data_path
9 | python -m preprocess.flickr30k \
10 | -o "$out_dir/flickr30k.json" \
11 | "${data_path}/results_20130124.token"
--------------------------------------------------------------------------------
/scripts/train_allnli.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | output_dir='checkpoints/seq2seq-allnli-sc'
3 | trainset_path='data/allnli/processed/allnli_train.csv'
4 | devset_path='data/allnli/processed/allnli_dev.csv'
5 |
6 | mkdir -p $output_dir
7 |
8 | python train.py \
9 | --per_device_train_batch_size 4 \
10 | --num_train_epochs 3 \
11 | --do_train \
12 | --do_eval \
13 | --model_name_or_path facebook/bart-base \
14 | --preprocessing_num_workers 4 \
15 | --save_total_limit 1 \
16 | --no_use_fast_tokenizer \
17 | --num_beams 4 \
18 | --max_source_length 64 \
19 | --max_target_length 64 \
20 | --train_file "$trainset_path" \
21 | --validation_file "$devset_path" \
22 | --test_file "$devset_path" \
23 | --output_dir $output_dir \
24 | --ebno_db 10 \
25 | --channel_type AWGN \
26 | --overwrite_output_dir \
27 | --tokenizer_name facebook/bart-base \
28 | --pad_to_max_length \
29 | --dataset_config 3.0.0
--------------------------------------------------------------------------------
/scripts/train_europarl.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | output_dir='checkpoints/seq2seq-europarl-sc'
3 | trainset_path='data/europarl/processed/train.csv'
4 | devset_path='data/europarl/processed/test.csv'
5 |
6 | mkdir -p $output_dir
7 |
8 | python train.py \
9 | --per_device_train_batch_size 4 \
10 | --num_train_epochs 3 \
11 | --do_train \
12 | --do_eval \
13 | --model_name_or_path facebook/bart-base \
14 | --preprocessing_num_workers 4 \
15 | --save_total_limit 1 \
16 | --no_use_fast_tokenizer \
17 | --num_beams 4 \
18 | --max_source_length 64 \
19 | --max_target_length 64 \
20 | --train_file "$trainset_path" \
21 | --validation_file "$devset_path" \
22 | --test_file "$devset_path" \
23 | --output_dir $output_dir \
24 | --ebno_db 10 \
25 | --channel_type AWGN \
26 | --overwrite_output_dir \
27 | --tokenizer_name facebook/bart-base \
28 | --pad_to_max_length \
29 | --dataset_config 3.0.0
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2021 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # This file is based on
18 | # https://github.com/huggingface/transformers/blob/main/examples/tensorflow/summarization/run_summarization.py
19 | import json
20 | import logging
21 | import os
22 | import sys
23 | from dataclasses import dataclass, field
24 | from typing import Optional, List
25 |
26 | import datasets
27 | import nltk
28 | import numpy as np
29 | import tensorflow as tf
30 | from datasets import load_dataset
31 |
32 | import evaluate
33 | import transformers
34 | from filelock import FileLock
35 | from transformers import (
36 | AutoConfig,
37 | AutoTokenizer,
38 | DataCollatorForSeq2Seq,
39 | HfArgumentParser,
40 | KerasMetricCallback,
41 | TFTrainingArguments,
42 | set_seed,
43 | )
44 | from transformers.trainer_utils import get_last_checkpoint
45 | from transformers.utils import is_offline_mode
46 | from transformers.optimization_tf import create_optimizer
47 | from train.args import ModelArguments, DataTrainingArguments, summarization_name_mapping, Seq2SeqSCArguments
48 | from models import TFSeq2SeqSCForConditionalGeneration
49 |
50 | logger = logging.getLogger(__name__)
51 |
52 | try:
53 | nltk.data.find("tokenizers/punkt")
54 | except (LookupError, OSError):
55 | if is_offline_mode():
56 | raise LookupError(
57 | "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
58 | )
59 | with FileLock(".lock") as lock:
60 | nltk.download("punkt", quiet=True)
61 | # endregion
62 |
63 | def main():
64 | # region Argument parsing
65 | # See all possible arguments in src/transformers/training_args.py
66 | # or by passing the --help flag to this script.
67 | # We now keep distinct sets of args, for a cleaner separation of concerns.
68 |
69 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments, Seq2SeqSCArguments))
70 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
71 | # If we pass only one argument to the script and it's the path to a json file,
72 | # let's parse it to get our arguments.
73 | model_args, data_args, training_args, seq2seq_sc_args = parser.parse_json_file(
74 | json_file=os.path.abspath(sys.argv[1]))
75 | else:
76 | model_args, data_args, training_args, seq2seq_sc_args = parser.parse_args_into_dataclasses()
77 |
78 | if training_args.fp16:
79 | policy = tf.keras.mixed_precision.Policy('mixed_float16')
80 | tf.keras.mixed_precision.set_global_policy(policy)
81 |
82 | # region Logging
83 | logging.basicConfig(
84 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
85 | datefmt="%m/%d/%Y %H:%M:%S",
86 | handlers=[logging.StreamHandler(sys.stdout)],
87 | )
88 | logger.setLevel(logging.INFO)
89 | datasets.utils.logging.set_verbosity(logging.INFO)
90 | transformers.utils.logging.set_verbosity(logging.INFO)
91 |
92 | # Log on each process the small summary:
93 | logger.info(f"Training/evaluation parameters {training_args}")
94 | # endregion
95 |
96 | # region Detecting last checkpoint
97 | last_checkpoint = None
98 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
99 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
100 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
101 | raise ValueError(
102 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
103 | "Use --overwrite_output_dir to overcome."
104 | )
105 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
106 | logger.info(
107 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
108 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
109 | )
110 | # endregion
111 |
112 | # Set seed before initializing model.
113 | set_seed(training_args.seed)
114 |
115 | # region Load datasets
116 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
117 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
118 | # (the dataset will be downloaded automatically from the datasets Hub).
119 | #
120 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the
121 | # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
122 | #
123 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently
124 | # download the dataset.
125 | if data_args.dataset_name is not None:
126 | # Downloading and loading a dataset from the hub.
127 | raw_datasets = load_dataset(
128 | data_args.dataset_name,
129 | data_args.dataset_config_name,
130 | cache_dir=model_args.cache_dir,
131 | use_auth_token=True if model_args.use_auth_token else None,
132 | )
133 | else:
134 | data_files = {}
135 | if data_args.train_file is not None:
136 | data_files["train"] = data_args.train_file
137 | extension = data_args.train_file.split(".")[-1]
138 | if data_args.validation_file is not None:
139 | data_files["validation"] = data_args.validation_file
140 | extension = data_args.validation_file.split(".")[-1]
141 | if data_args.test_file is not None:
142 | data_files["test"] = data_args.test_file
143 | extension = data_args.test_file.split(".")[-1]
144 | raw_datasets = load_dataset(
145 | extension,
146 | data_files=data_files,
147 | cache_dir=model_args.cache_dir,
148 | use_auth_token=True if model_args.use_auth_token else None,
149 | )
150 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
151 | # https://huggingface.co/docs/datasets/loading_datasets.html.
152 | # endregion
153 |
154 | # region Load model config and tokenizer
155 | #
156 | # Distributed training:
157 | # The .from_pretrained methods guarantee that only one local process can concurrently
158 | # download model & vocab.
159 |
160 | config = AutoConfig.from_pretrained(
161 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
162 | cache_dir=model_args.cache_dir,
163 | revision=model_args.model_revision,
164 | use_auth_token=True if model_args.use_auth_token else None,
165 | )
166 | tokenizer = AutoTokenizer.from_pretrained(
167 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
168 | cache_dir=model_args.cache_dir,
169 | use_fast=model_args.use_fast_tokenizer,
170 | revision=model_args.model_revision,
171 | use_auth_token=True if model_args.use_auth_token else None,
172 | )
173 |
174 | prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
175 | # endregion
176 |
177 | # region Dataset preprocessing
178 | # We need to tokenize inputs and targets.
179 | if training_args.do_train:
180 | column_names = raw_datasets["train"].column_names
181 | elif training_args.do_eval:
182 | column_names = raw_datasets["validation"].column_names
183 | else:
184 | logger.info("There is nothing to do. Please pass `do_train`, and/or `do_eval`.")
185 | return
186 |
187 | # Get the column names for input/target.
188 | dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
189 | if data_args.text_column is None:
190 | text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
191 | else:
192 | text_column = data_args.text_column
193 | if text_column not in column_names:
194 | raise ValueError(
195 | f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
196 | )
197 | if data_args.summary_column is None:
198 | summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
199 | else:
200 | summary_column = data_args.summary_column
201 | if summary_column not in column_names:
202 | raise ValueError(
203 | f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
204 | )
205 |
206 | # Temporarily set max_target_length for training.
207 | max_target_length = data_args.max_target_length
208 | padding = "max_length" if data_args.pad_to_max_length else False
209 |
210 | def preprocess_function(examples):
211 | inputs = examples[text_column]
212 | assert prefix is not None
213 |         for i, inp in enumerate(inputs):
214 |             if inp is None:
215 |                 logger.warning("None input at index %d (neighbors: %r, %r)", i, inputs[max(i - 1, 0)], inputs[min(i + 1, len(inputs) - 1)])
216 | targets = examples[summary_column]
217 | inputs = [prefix + inp for inp in inputs]
218 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
219 |
220 | # Tokenize targets with the `text_target` keyword argument
221 | labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
222 |
223 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
224 | # padding in the loss.
225 | if padding == "max_length" and data_args.ignore_pad_token_for_loss:
226 | labels["input_ids"] = [
227 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
228 | ]
229 |
230 | model_inputs["labels"] = labels["input_ids"]
231 | return model_inputs
232 |
233 | if training_args.do_train:
234 | if "train" not in raw_datasets:
235 | raise ValueError("--do_train requires a train dataset")
236 | train_dataset = raw_datasets["train"]
237 | if data_args.max_train_samples is not None:
238 | max_train_samples = min(len(train_dataset), data_args.max_train_samples)
239 | train_dataset = train_dataset.select(range(max_train_samples))
240 | with training_args.main_process_first(desc="train dataset map pre-processing"):
241 | train_dataset = train_dataset.map(
242 | preprocess_function,
243 | batched=True,
244 | num_proc=data_args.preprocessing_num_workers,
245 | remove_columns=column_names,
246 | load_from_cache_file=not data_args.overwrite_cache,
247 | desc="Running tokenizer on train dataset",
248 | )
249 | else:
250 | train_dataset = None
251 |
252 | if training_args.do_eval:
253 | max_target_length = data_args.val_max_target_length
254 | if "validation" not in raw_datasets:
255 | raise ValueError("--do_eval requires a validation dataset")
256 | eval_dataset = raw_datasets["validation"]
257 | if data_args.max_eval_samples is not None:
258 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
259 | eval_dataset = eval_dataset.select(range(max_eval_samples))
260 | with training_args.main_process_first(desc="validation dataset map pre-processing"):
261 | eval_dataset = eval_dataset.map(
262 | preprocess_function,
263 | batched=True,
264 | num_proc=data_args.preprocessing_num_workers,
265 | remove_columns=column_names,
266 | load_from_cache_file=not data_args.overwrite_cache,
267 | desc="Running tokenizer on validation dataset",
268 | )
269 | else:
270 | eval_dataset = None
271 | # endregion
272 |
273 | # region Text preprocessing
274 | def postprocess_text(preds, labels):
275 | preds = [pred.strip() for pred in preds]
276 | labels = [label.strip() for label in labels]
277 |
278 | # rougeLSum expects newline after each sentence
279 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
280 | labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
281 |
282 | return preds, labels
283 |
284 | # endregion
285 | with training_args.strategy.scope():
286 | # region Prepare model
287 | model_cls = TFSeq2SeqSCForConditionalGeneration
288 |
289 | model = model_cls.from_pretrained(
290 | model_args.model_name_or_path,
291 | ebno_db=seq2seq_sc_args.ebno_db,
292 | polar_k=seq2seq_sc_args.k,
293 | polar_n=seq2seq_sc_args.n,
294 | polar_decoder_type=seq2seq_sc_args.polar_decoder_type,
295 | polar_decoder_list_size=seq2seq_sc_args.polar_decoder_list_size,
296 | num_bits_per_symbol=seq2seq_sc_args.num_bits_per_symbol,
297 | channel_type=seq2seq_sc_args.channel_type,
298 | channel_num_tx_ant=seq2seq_sc_args.channel_num_tx_ant,
299 | channel_num_rx_ant=seq2seq_sc_args.channel_num_rx_ant,
300 | config=config,
301 | cache_dir=model_args.cache_dir,
302 | revision=model_args.model_revision,
303 | use_auth_token=True if model_args.use_auth_token else None,
304 | )
305 |
306 | model.resize_token_embeddings(len(tokenizer))
307 | # endregion
308 |
309 | # region Prepare TF Dataset objects
310 | if model.config.decoder_start_token_id is None:
311 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
312 |
313 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
314 | data_collator = DataCollatorForSeq2Seq(
315 | tokenizer,
316 | model=model,
317 | label_pad_token_id=label_pad_token_id,
318 | pad_to_multiple_of=128, # Reduce the number of unique shapes for XLA, especially for generation
319 | return_tensors="tf",
320 | )
321 |
322 | dataset_options = tf.data.Options()
323 | dataset_options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
324 |
325 | num_replicas = training_args.strategy.num_replicas_in_sync
326 | total_train_batch_size = training_args.per_device_train_batch_size * num_replicas
327 | total_eval_batch_size = training_args.per_device_eval_batch_size * num_replicas
328 |
329 | # model.prepare_tf_dataset() wraps a Hugging Face dataset in a tf.data.Dataset which is ready to use in
330 | # training. This is the recommended way to use a Hugging Face dataset when training with Keras. You can also
331 | # use the lower-level dataset.to_tf_dataset() method, but you will have to specify things like column names
332 | # yourself if you use this method, whereas they are automatically inferred from the model input names when
333 | # using model.prepare_tf_dataset()
334 | # For more info see the docs:
335 | # https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.TFPreTrainedModel.prepare_tf_dataset
336 | # https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.to_tf_dataset
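        # For reference only, a rough sketch of the lower-level to_tf_dataset() path mentioned
        # above; the column names are assumptions, since prepare_tf_dataset() infers them from
        # the model's input signature:
        #
        #   tf_train_dataset = train_dataset.to_tf_dataset(
        #       columns=["input_ids", "attention_mask"],
        #       label_cols=["labels"],
        #       batch_size=total_train_batch_size,
        #       shuffle=True,
        #       collate_fn=data_collator,
        #   ).with_options(dataset_options)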
337 |
338 | tf_train_dataset = model.prepare_tf_dataset(
339 | train_dataset,
340 | collate_fn=data_collator,
341 | batch_size=total_train_batch_size,
342 | shuffle=True,
343 | ).with_options(dataset_options)
344 | tf_eval_dataset = model.prepare_tf_dataset(
345 | eval_dataset,
346 | collate_fn=data_collator,
347 | batch_size=total_eval_batch_size,
348 | shuffle=False,
349 | ).with_options(dataset_options)
350 | # endregion
351 |
352 | # region Optimizer, loss and LR scheduling
353 | num_train_steps = int(len(tf_train_dataset) * training_args.num_train_epochs)
354 | if training_args.warmup_steps > 0:
355 | num_warmup_steps = training_args.warmup_steps
356 | elif training_args.warmup_ratio > 0:
357 | num_warmup_steps = int(num_train_steps * training_args.warmup_ratio)
358 | else:
359 | num_warmup_steps = 0
360 | if training_args.do_train:
361 | optimizer, lr_schedule = create_optimizer(
362 | init_lr=training_args.learning_rate,
363 | num_train_steps=num_train_steps,
364 | num_warmup_steps=num_warmup_steps,
365 | adam_beta1=training_args.adam_beta1,
366 | adam_beta2=training_args.adam_beta2,
367 | adam_epsilon=training_args.adam_epsilon,
368 | weight_decay_rate=training_args.weight_decay,
369 | adam_global_clipnorm=training_args.max_grad_norm,
370 | )
371 | else:
372 | optimizer = None
373 |
374 | # endregion
375 |
376 | # region Metric and KerasMetricCallback
377 | if training_args.do_eval:
378 | metric = evaluate.load("rouge")
379 |
380 | if data_args.val_max_target_length is None:
381 | data_args.val_max_target_length = data_args.max_target_length
382 |
383 | gen_kwargs = {
384 | "max_length": data_args.val_max_target_length if data_args is not None else config.max_length,
385 | "num_beams": data_args.num_beams,
386 | "no_repeat_ngram_size": 0, # Not supported under XLA right now, and some models set it by default
387 | }
388 |
389 | def compute_metrics(preds):
390 | predictions, labels = preds
391 | if isinstance(predictions, tuple):
392 | predictions = predictions[0]
393 | decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
394 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
395 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
396 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
397 | metrics = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
398 |             # evaluate's rouge returns aggregated F-measure scores; report them as percentages.
399 | metrics = {key: round(val * 100, 4) for key, val in metrics.items()}
400 | return metrics
401 |
402 | # The KerasMetricCallback allows metrics that are too complex to write as standard Keras metrics
403 | # to be computed each epoch. Any Python code can be included in the metric_fn. This is especially
404 | # useful for metrics like BLEU and ROUGE that perform string comparisons on decoded model outputs.
405 | # For more information, see the docs at
406 | # https://huggingface.co/docs/transformers/main_classes/keras_callbacks#transformers.KerasMetricCallback
407 |
408 | metric_callback = KerasMetricCallback(
409 | metric_fn=compute_metrics,
410 | eval_dataset=tf_eval_dataset,
411 | predict_with_generate=True,
412 | use_xla_generation=True,
413 | generate_kwargs=gen_kwargs,
414 | )
415 | callbacks = [metric_callback]
416 | else:
417 | callbacks = []
418 | # endregion
419 |
420 | # region Training
421 | model.compile(optimizer=optimizer, jit_compile=training_args.xla)
422 | eval_metrics = None
423 | if training_args.do_train:
424 | logger.info("***** Running training *****")
425 | logger.info(f" Num examples = {len(train_dataset)}")
426 | logger.info(f" Num Epochs = {training_args.num_train_epochs}")
427 | logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
428 | logger.info(f" Total train batch size = {total_train_batch_size}")
429 | logger.info(f" Total optimization steps = {num_train_steps}")
430 |
431 | if training_args.xla and not data_args.pad_to_max_length:
432 | logger.warning(
433 | "XLA training may be slow at first when --pad_to_max_length is not set "
434 | "until all possible shapes have been compiled."
435 | )
436 | history = model.fit(tf_train_dataset, epochs=int(training_args.num_train_epochs), callbacks=callbacks)
437 | eval_metrics = {key: val[-1] for key, val in history.history.items()}
438 | # endregion
439 |
440 | # region Validation
441 |
442 | if training_args.do_eval and not training_args.do_train:
443 | # Do a standalone evaluation run
444 | logger.info("Evaluation...")
445 |
446 | # Compiling generation with XLA yields enormous speedups, see https://huggingface.co/blog/tf-xla-generate
447 | @tf.function(jit_compile=True)
448 | def generate(**kwargs):
449 | return model.generate(**kwargs)
450 |
451 | for batch, labels in tf_eval_dataset:
452 | batch.update(gen_kwargs)
453 | generated_tokens = generate(**batch)
454 | if isinstance(generated_tokens, tuple):
455 | generated_tokens = generated_tokens[0]
456 | decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
457 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
458 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
459 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
460 |
461 | metric.add_batch(predictions=decoded_preds, references=decoded_labels)
462 |
463 | eval_metrics = metric.compute(use_stemmer=True)
464 |
465 | result = {key: round(val * 100, 4) for key, val in eval_metrics.items()}
466 | logger.info(result)
467 | # endregion
468 |
469 | if training_args.output_dir is not None and eval_metrics is not None:
470 | output_eval_file = os.path.join(training_args.output_dir, "all_results.json")
471 | with open(output_eval_file, "w") as writer:
472 | writer.write(json.dumps(eval_metrics))
473 |
474 | if training_args.output_dir is not None and not training_args.push_to_hub:
475 | # If we're not pushing to hub, at least save a local copy when we're done
476 | model.save_pretrained(training_args.output_dir)
477 |
478 |
479 | if __name__ == "__main__":
480 | main()
481 |
--------------------------------------------------------------------------------
/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abman23/seq2seq-sc/e900f637ed5e89300bb2fb9e98f0e215cc508ed0/train/__init__.py
--------------------------------------------------------------------------------
/train/args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 |
4 | @dataclass
5 | class Seq2SeqSCArguments:
6 |     ebno_db: Optional[float] = field(default=None, metadata={"help": "Eb/N0 of the channel in dB"})
7 |     k: Optional[int] = field(default=512, metadata={"help": "K (information bits) for the polar code"})
8 |     n: Optional[int] = field(default=1024, metadata={"help": "N (codeword length) for the polar code"})
9 |     polar_decoder_type: Optional[str] = field(
10 |         default='SC',
11 |         metadata={"help": "Polar decoder type"})
12 |     polar_decoder_list_size: Optional[int] = field(default=8, metadata={"help": "Polar decoder list size"})
13 |     num_bits_per_symbol: Optional[int] = field(default=4, metadata={"help": "Number of bits per symbol"})
14 |     channel_type: Optional[str] = field(default='AWGN', metadata={"help": "AWGN or FlatFadingChannel"})
15 |     channel_num_tx_ant: Optional[int] = field(default=1, metadata={"help": "Number of tx antennas for FlatFadingChannel"})
16 |     channel_num_rx_ant: Optional[int] = field(default=1, metadata={"help": "Number of rx antennas for FlatFadingChannel"})
17 |
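# For reference: scripts/train_allnli.sh and scripts/train_europarl.sh pass
# `--ebno_db 10 --channel_type AWGN` and otherwise rely on the defaults above,
# i.e. an (N=1024, K=512) rate-1/2 polar code with 4 bits per symbol.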
18 |
19 | @dataclass
20 | class ModelArguments:
21 | model_name_or_path: str = field(
22 | metadata={
23 | "help": "Path to pretrained model or model identifier from huggingface.co/models"
24 | }
25 | )
26 | config_name: Optional[str] = field(
27 | default=None,
28 | metadata={
29 | "help": "Pretrained config name or path if not the same as model_name"
30 | }
31 | )
32 | tokenizer_name: Optional[str] = field(
33 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
34 | )
35 | cache_dir: Optional[str] = field(
36 | default=None,
37 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
38 | )
39 | use_fast_tokenizer: bool = field(
40 | default=True,
41 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
42 | )
43 | model_revision: str = field(
44 | default="main",
45 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
46 | )
47 | use_auth_token: bool = field(
48 | default=False,
49 | metadata={
50 | "help": (
51 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
52 | "with private models)."
53 | )
54 | },
55 | )
56 |
57 |
58 | @dataclass
59 | class DataTrainingArguments:
60 | """
61 | Arguments pertaining to what data we are going to input our model for training and eval.
62 | """
63 |
64 | dataset_name: Optional[str] = field(
65 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
66 | )
67 | dataset_config_name: Optional[str] = field(
68 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
69 | )
70 | text_column: Optional[str] = field(
71 | default=None,
72 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
73 | )
74 | summary_column: Optional[str] = field(
75 | default=None,
76 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
77 | )
78 | train_file: Optional[str] = field(
79 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
80 | )
81 | validation_file: Optional[str] = field(
82 | default=None,
83 | metadata={
84 | "help": (
85 | "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
86 | )
87 | },
88 | )
89 | test_file: Optional[str] = field(
90 | default=None,
91 | metadata={
92 | "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
93 | },
94 | )
95 | overwrite_cache: bool = field(
96 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
97 | )
98 | preprocessing_num_workers: Optional[int] = field(
99 | default=None,
100 | metadata={"help": "The number of processes to use for the preprocessing."},
101 | )
102 | max_source_length: Optional[int] = field(
103 | default=1024,
104 | metadata={
105 | "help": (
106 | "The maximum total input sequence length after tokenization. Sequences longer "
107 | "than this will be truncated, sequences shorter will be padded."
108 | )
109 | },
110 | )
111 | max_target_length: Optional[int] = field(
112 | default=128,
113 | metadata={
114 | "help": (
115 | "The maximum total sequence length for target text after tokenization. Sequences longer "
116 | "than this will be truncated, sequences shorter will be padded."
117 | )
118 | },
119 | )
120 | val_max_target_length: Optional[int] = field(
121 | default=None,
122 | metadata={
123 | "help": (
124 | "The maximum total sequence length for validation target text after tokenization. Sequences longer "
125 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
126 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
127 | "during ``evaluate`` and ``predict``."
128 | )
129 | },
130 | )
131 | pad_to_max_length: bool = field(
132 | default=False,
133 | metadata={
134 | "help": (
135 | "Whether to pad all samples to model maximum sentence length. "
136 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
137 | "efficient on GPU but very bad for TPU."
138 | )
139 | },
140 | )
141 | max_train_samples: Optional[int] = field(
142 | default=None,
143 | metadata={
144 | "help": (
145 | "For debugging purposes or quicker training, truncate the number of training examples to this "
146 | "value if set."
147 | )
148 | },
149 | )
150 | max_eval_samples: Optional[int] = field(
151 | default=None,
152 | metadata={
153 | "help": (
154 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
155 | "value if set."
156 | )
157 | },
158 | )
159 | max_predict_samples: Optional[int] = field(
160 | default=None,
161 | metadata={
162 | "help": (
163 | "For debugging purposes or quicker training, truncate the number of prediction examples to this "
164 | "value if set."
165 | )
166 | },
167 | )
168 | num_beams: Optional[int] = field(
169 | default=None,
170 | metadata={
171 | "help": (
172 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
173 | "which is used during ``evaluate`` and ``predict``."
174 | )
175 | },
176 | )
177 | ignore_pad_token_for_loss: bool = field(
178 | default=True,
179 | metadata={
180 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
181 | },
182 | )
183 | source_prefix: Optional[str] = field(
184 | default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
185 | )
186 |
187 | def __post_init__(self):
188 | if self.dataset_name is None and self.train_file is None and self.validation_file is None:
189 | raise ValueError("Need either a dataset name or a training/validation file.")
190 | else:
191 | if self.train_file is not None:
192 | extension = self.train_file.split(".")[-1]
193 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
194 | if self.validation_file is not None:
195 | extension = self.validation_file.split(".")[-1]
196 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
197 | if self.val_max_target_length is None:
198 | self.val_max_target_length = self.max_target_length
199 |
200 |
201 | # endregion
202 |
203 | # region Dataset name mappings
204 | summarization_name_mapping = {
205 | "amazon_reviews_multi": ("review_body", "review_title"),
206 | "big_patent": ("description", "abstract"),
207 | "cnn_dailymail": ("article", "highlights"),
208 | "orange_sum": ("text", "summary"),
209 | "pn_summary": ("article", "summary"),
210 | "psc": ("extract_text", "summary_text"),
211 | "samsum": ("dialogue", "summary"),
212 | "thaisum": ("body", "summary"),
213 | "xglue": ("news_body", "news_title"),
214 | "xsum": ("document", "summary"),
215 | "wiki_summary": ("article", "highlights"),
216 | "multi_news": ("document", "summary"),
217 | }
218 | # endregion
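# Note: the CSVs produced by preprocess/hf_data_gen.py already name their first two columns
# 'text' and 'summary', so this mapping is only consulted when --dataset_name selects one of
# the hub datasets above.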
219 |
220 |
221 |
--------------------------------------------------------------------------------