├── .gitignore
├── LICENSE
├── README.md
├── environment.yml
├── eval.py
├── figures
│   ├── On-device AI comm.png
│   └── file_structure.png
├── models
│   ├── __init__.py
│   ├── channels.py
│   ├── on_device_ai_comm.py
│   ├── utils.py
│   └── vq_vae.py
├── preprocess
│   ├── __init__.py
│   ├── europarl.py
│   ├── flickr30k.py
│   └── hf_data_gen.py
├── scripts
│   ├── eval.sh
│   └── train.sh
├── train.py
└── train
    ├── __init__.py
    └── args.py
/.gitignore:
--------------------------------------------------------------------------------
1 | checkpoints
2 | **/__pycache__
3 | data
4 | .vscode
5 | **/backup
6 | **/log
7 | **/weights
8 | visualization
9 | jobs
10 |
11 | # ignore log files
12 | *.log
13 |
14 | models/on_device_ai_comm_v2.py
15 | preprocess/allnli.py
16 | srun_gpu.sh
17 | srun_cpu.sh
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Juhyung Lee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # On-Device AI/LLM Communication
2 |
3 |
4 | This repo contains the implementation of our paper ["Integrating Pre-Trained Language Model with Physical Layer Communications"](https://arxiv.org/abs/2402.11656).
5 |
6 | ## Highlights
7 | - Fine-tuned a pre-trained LLM (BART) under noisy conditions (3GPP CDL-family channel models) in an end-to-end Link-Level Simulation (LLS), integrated with 5G-NR PHY-layer functions.
8 | - Developed a compression and quantization method for AI-to-AI communication, reducing transmission size by 50% without compromising performance.
9 | - Verified the framework in the NVIDIA Sionna LLS within a 5G-NR PHY setup.
10 |
11 |
12 |
13 |
15 |
16 | ## Citation
17 |
18 | ```bibtex
19 | @misc{lee2024integrating,
20 | title={Integrating Pre-Trained Language Model with Physical Layer Communications},
21 | author={Ju-Hyung Lee and Dong-Ho Lee and Joohan Lee and Jay Pujara},
22 | year={2024},
23 | eprint={2402.11656},
24 | archivePrefix={arXiv},
25 | primaryClass={cs.IT}
26 | }
27 | ```
28 |
29 | ## Model Architecture
30 | 
31 | - Each channel model(e.g., ChannelAWGN) includes channel En/Decoder, mapper, or channel(AWGN, CDL, etc.).
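
Each channel model is a `tf.keras.Model`, so the channel stack can be exercised on its own. A minimal sketch, assuming the `ChannelAWGN` constructor in `models/channels.py` (batch size and bit shapes are illustrative):

```python
import tensorflow as tf
from models.channels import ChannelAWGN

# Polar-coded AWGN link: k=512 information bits per n=1024-bit codeword, 16-QAM.
channel = ChannelAWGN(fec_type="Polar5G", num_bits_per_symbol=4,
                      fec_n=1024, fec_k=512, ebno_db=10.0)

# A batch of random information bits (float32, as Sionna's FEC layers expect);
# ChannelAWGN zero-pads and reshapes internally as needed.
bits = tf.cast(tf.random.uniform((4, 512), maxval=2, dtype=tf.int32), tf.float32)
bits_hat = channel(bits)  # decoded bits with the same shape as the input
```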
32 |
33 | ## Available checkpoints for On-Device AI Communication
34 |
35 |
36 | | # | Transmission | Embedding dimension | # of embeddings | Download Link |
37 | | :---: | :------------: | :-----------------: | :--------------: | :-------------------------------------------------------------------------------------------------: |
38 | | 1 | tanh | - | - | [tf_model.h5](https://drive.google.com/file/d/156PpJPNYzHAlXGv1M_y9H9eRnUXrnFTt/view?usp=sharing)|
39 | | 2 | VectorQuantizer | 2 | 1024 | [tf_model.h5](https://drive.google.com/file/d/13gBtLnKo8wwJV6_ZdGHB3AR8WlAyEsJN/view?usp=sharing)|
40 | | 3 | VectorQuantizer | 4 | 1024 | [tf_model.h5](https://drive.google.com/file/d/1OwQ69NGi6INKAExjwVNqr2pe1l3fs2tr/view?usp=sharing)|
41 | | 4 | VectorQuantizer | 8 | 1024 | [tf_model.h5](https://drive.google.com/file/d/12qrKD-q7habrlrm-5BSS9dnUebYEPdF3/view?usp=sharing)|
42 | | 5 | VectorQuantizer | 16 | 1024 | [tf_model.h5](https://drive.google.com/file/d/1DQCapmhGIeFmP66Y11bDzHbsyWJ-MYBC/view?usp=sharing)|
43 |
44 |
45 | Example commands showing how to use these checkpoints are given in the Train and Evaluation sections below; the snippet below shows how to load one directly.
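
A checkpoint can also be loaded programmatically, mirroring the `from_pretrained` call in `eval.py` (a sketch; the local directory name is hypothetical, and the downloaded `tf_model.h5` must sit inside it):

```python
from models.on_device_ai_comm import TFOnDeviceAICForConditionalGeneration

# Hypothetical local directory holding tf_model.h5 (e.g., checkpoint #2 above).
model = TFOnDeviceAICForConditionalGeneration.from_pretrained(
    "checkpoints/vq_dim2",
    ebno_db=10.0,
    bin_conv_method="vector_quantization",
    channel_type="CDL", cdl_model="A",
    fec_type="Polar5G",
    scenario="umi", perfect_csi=True,
    channel_num_tx_ant=2, channel_num_rx_ant=2,
    num_bits_per_symbol=4,
    embedding_dim=2, num_embeddings=1024)
```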
46 |
47 | ## Setup
48 |
49 | Clone the repository and set up the environment:
50 |
51 | ```bash
52 | git clone https://github.com/abman23/on-device-ai-comm.git
53 | cd on-device-ai-comm
54 | conda env create -f environment.yml
55 | conda activate on-device-ai-comm
56 | ```
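
A quick import check after activating the environment (versions are pinned in `environment.yml`):

```python
# Sanity check for the core dependencies.
import sionna
import tensorflow as tf
import transformers

print(sionna.__version__, tf.__version__, transformers.__version__)
```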
57 |
58 | ## Data Preprocessing
59 |
60 | ### Europarl dataset
61 |
62 | ```bash
63 | data_path=data/europarl
64 | mkdir -p $data_path
65 | wget -P /tmp http://www.statmt.org/europarl/v7/europarl.tgz
66 | # Extract into $data_path so the relative paths below resolve from the repo root.
67 | tar zxf /tmp/europarl.tgz -C $data_path
68 |
69 | europarl_dataset="$data_path/txt/en"
70 | out_dir="$data_path/processed"
71 | njobs=4
72 |
73 | mkdir -p $out_dir
74 | python -m preprocess.europarl -j $njobs -o $out_dir $europarl_dataset
75 | ```
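
Before training, it can help to peek at the processed output; the CSV path below matches the one used in the Train section (the column layout depends on `preprocess.europarl`):

```python
import pandas as pd

# train.csv / test.csv are written into $out_dir by preprocess.europarl.
df = pd.read_csv("data/europarl/processed/train.csv")
print(df.shape)
print(df.head())  # inspect the columns before launching training
```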
76 |
77 |
93 |
94 | ### Flickr30K
95 |
96 | To download the dataset, go to [Flickr30K](http://hockenmaier.cs.illinois.edu/DenotationGraph/) and fill out the form to obtain the download link.
97 |
98 | ```bash
99 | data_path="data/flickr"
100 | dataset_path="${data_path}/flickr30k.tar.gz"
101 | out_dir="$data_path/processed"
102 |
103 | mkdir -p $out_dir
104 |
105 | tar xzf ${dataset_path} -C $data_path
106 | python -m preprocess.flickr30k \
107 | -o "$out_dir/flickr30k.json" \
108 | "${data_path}/results_20130124.token"
109 | ```
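
The resulting `flickr30k.json` is what `eval.py` reads: a list of records, each with an `input` sentence and its reference captions under `refs` (the values below are illustrative):

```python
# Schematic testset entry, matching get_test_data()/predict() in eval.py.
example_testset = [
    {
        "input": "a man rides a bicycle down the street",  # sentence sent over the channel
        "refs": [                                          # used for scoring with --multi-ref
            "a man rides a bicycle down the street",
            "a cyclist travels along a city road",
        ],
    },
]
```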
110 |
111 | ## Train
112 |
113 | You can run `scripts/train.sh`, or train with the following commands. Below is an example of training the on-device AI communication system over a CDL-A channel at 5–15 dB Eb/No with the vector quantizer.
114 |
115 | ```bash
116 | output_dir='checkpoints/on-device-ai-comm/train_CDL-A_ebnodb_5_15'
117 | trainset_path='data/europarl/processed/train.csv'
118 | devset_path='data/europarl/processed/test.csv'
119 |
120 | mkdir -p $output_dir
121 |
122 | python train.py \
123 | --model_name_or_path facebook/bart-base \
124 | --config_name facebook/bart-base \
125 | --tokenizer_name facebook/bart-base \
126 | --train_file "$trainset_path" \
127 | --validation_file "$devset_path" \
128 | --test_file "$devset_path" \
129 | --preprocessing_num_workers 4 \
130 | --per_device_train_batch_size 4 \
131 | --per_device_eval_batch_size 4 \
132 | --num_train_epochs 3 \
133 | --do_train \
134 | --do_eval \
135 | --save_total_limit 1 \
136 | --no_use_fast_tokenizer \
137 | --num_beams 1 \
138 | --pad_to_max_length \
139 | --overwrite_output_dir \
140 | --max_source_length 64 \
141 | --max_target_length 64 \
142 | --output_dir $output_dir \
143 | --ebno_db_min 5 \
144 | --ebno_db_max 15 \
145 | --channel_type "CDL" \
146 | --fec_type "Polar5G" \
147 | --fec_num_iter 20 \
148 | --cdl_model "A" \
149 | --channel_num_tx_ant "2" \
150 | --channel_num_rx_ant "2" \
151 | --num_bits_per_symbol "4" \
152 | --bin_conv_method "vector_quantization" \
153 | --embedding_dim 2 \
154 | --num_embeddings 1024 \
155 | --dataset_config 3.0.0
156 | ```
157 |
158 | - For the full list of training arguments, see [train/args.py](./train/args.py). When an Eb/No range is given, the SNR is sampled per batch, as sketched below.
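
With `--ebno_db_min`/`--ebno_db_max` set, each training batch draws its SNR uniformly from the range before converting it to a noise power (see the channel models in `models/channels.py`); schematically:

```python
import tensorflow as tf
from sionna.utils import ebnodb2no

# Per-batch SNR sampling, as done inside ChannelAWGN/ChannelCDL.
ebno_db = tf.random.uniform(shape=[4], minval=5.0, maxval=15.0)  # batch_size=4
no = ebnodb2no(ebno_db, num_bits_per_symbol=4, coderate=0.5)     # noise variance
```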
159 |
160 | ## Evaluation
161 |
162 | You can use the script `scripts/eval.sh` or the following commands:
163 |
164 | ```bash
165 | # BLEU score
166 | eval_ebno_db="4"
167 | metric="bleu" # bleu, sbert
168 | testset_path='data/flickr/processed/flickr30k.json'
169 |
170 | checkpoint_dir='checkpoints/on-device-ai-comm/train_CDL-A_ebnodb_5_15'
171 | output_dir='checkpoints/on-device-ai-comm/train_CDL-A_ebnodb_5_15/CDL-A'
172 |
173 | mkdir -p $output_dir
174 |
175 | fec_type="Polar5G" # Polar5G, LDPC5G
176 | fec_num_iter=20
177 | channel_num_tx_ant="2"
178 | channel_num_rx_ant="2"
179 | num_bits_per_symbol="4"
180 | EVAL_NUM_BEAMS="1"
181 |
182 | python eval.py \
183 | -m "${metric}" \
184 | -b 8 \
185 | -e "${eval_ebno_db}" \
186 | --result-json-path "${output_dir}/flickr_${metric}_${eval_ebno_db}dB_${fec_type}_${channel_num_tx_ant}_${channel_num_rx_ant}_${num_bits_per_symbol}.json" \
187 | --prediction-json-path "${output_dir}/flickr_prediction_${eval_ebno_db}dB_${fec_type}_${channel_num_tx_ant}_${channel_num_rx_ant}_${num_bits_per_symbol}.json" \
188 | --fec-type "${fec_type}" \
189 | --fec-num-iter "${fec_num_iter}" \
190 | --channel-type "CDL" \
191 | --cdl-model "A" \
192 | --channel-num-tx-ant "${channel_num_tx_ant}" \
193 | --channel-num-rx-ant "${channel_num_rx_ant}" \
194 | --num-bits-per-symbol "${num_bits_per_symbol}" \
195 | --bin-conv-method "vector_quantization" \
196 | --embedding-dim 2 \
197 | --num-embeddings 1024 \
198 | --num-beams "${EVAL_NUM_BEAMS}" \
199 | --testset-path "${testset_path}" \
200 | $checkpoint_dir
201 | ```
202 | - Note that the model checkpoint file inside `checkpoint_dir` must be named `tf_model.h5`.
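
The written result JSON contains the metric output plus the `mean_ber` added in `calc()`; a small sketch for inspecting it (the path is assembled from the variables above):

```python
import json

result_path = ("checkpoints/on-device-ai-comm/train_CDL-A_ebnodb_5_15/CDL-A/"
               "flickr_bleu_4dB_Polar5G_2_2_4.json")
with open(result_path) as f:
    results = json.load(f)

print(results["mean_ber"])  # mean bit error rate over the test batches
```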
203 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: on-device-ai-comm
2 | channels:
3 | - huggingface
4 | - anaconda
5 | - defaults
6 | - conda-forge
7 | dependencies:
8 | - _libgcc_mutex=0.1=conda_forge
9 | - _openmp_mutex=4.5=2_gnu
10 | - autopep8=1.6.0=pyhd3eb1b0_1
11 | - blas=1.0=mkl
12 | - brotlipy=0.7.0=py39h27cfd23_1003
13 | - ca-certificates=2022.12.7=ha878542_0
14 | - cffi=1.15.1=py39h74dc2b5_0
15 | - click=8.0.4=py39h06a4308_0
16 | - cryptography=37.0.1=py39h9ce1e76_0
17 | - cudatoolkit=11.2.2=hbe64b41_10
18 | - cudnn=8.1.0.77=h90431f1_0
19 | - dataclasses=0.8=pyh6d0b6a4_7
20 | - filelock=3.6.0=pyhd3eb1b0_0
21 | - huggingface_hub=0.9.1=py_0
22 | - importlib_metadata=4.11.3=hd3eb1b0_0
23 | - intel-openmp=2021.4.0=h06a4308_3561
24 | - joblib=1.1.0=pyhd3eb1b0_0
25 | - ld_impl_linux-64=2.38=h1181459_1
26 | - libffi=3.3=he6710b0_2
27 | - libgcc-ng=12.1.0=h8d9b700_16
28 | - libgomp=12.1.0=h8d9b700_16
29 | - libprotobuf=3.20.1=h4ff587b_0
30 | - libstdcxx-ng=12.1.0=ha89aaad_16
31 | - mkl=2021.4.0=h06a4308_640
32 | - mkl-service=2.4.0=py39h7f8727e_0
33 | - mkl_fft=1.3.1=py39hd3c417c_0
34 | - mkl_random=1.2.2=py39h51133e4_0
35 | - ncurses=6.3=h5eee18b_3
36 | - openssl=1.1.1s=h0b41bf4_1
37 | - pycodestyle=2.8.0=pyhd3eb1b0_0
38 | - pycparser=2.21=pyhd3eb1b0_0
39 | - pyopenssl=22.0.0=pyhd3eb1b0_0
40 | - pyparsing=3.0.9=py39h06a4308_0
41 | - pysocks=1.7.1=py39h06a4308_0
42 | - python=3.9.13=haa1d7c7_1
43 | - pyyaml=6.0=py39h7f8727e_1
44 | - readline=8.1.2=h7f8727e_1
45 | - regex=2022.7.9=py39h5eee18b_0
46 | - requests=2.28.1=py39h06a4308_0
47 | - sacremoses=master=py_0
48 | - six=1.16.0=pyhd3eb1b0_1
49 | - sqlite=3.39.2=h5082296_0
50 | - tk=8.6.12=h1ccaba5_0
51 | - toml=0.10.2=pyhd3eb1b0_0
52 | - tqdm=4.64.0=py39h06a4308_0
53 | - tzdata=2022c=h04d1e81_0
54 | - xz=5.2.5=h7f8727e_1
55 | - yaml=0.2.5=h7b6447c_0
56 | - zlib=1.2.12=h5eee18b_3
57 | - pip:
58 | - absl-py==1.3.0
59 | - accelerate==0.13.0
60 | - aiohttp==3.8.3
61 | - aiosignal==1.2.0
62 | - astroid==2.12.10
63 | - astunparse==1.6.3
64 | - async-timeout==4.0.2
65 | - attrs==22.1.0
66 | - bert-score==0.3.12
67 | - cachetools==5.2.0
68 | - certifi==2022.12.7
69 | - charset-normalizer==2.1.1
70 | - cloudpickle==2.2.1
71 | - contourpy==1.0.5
72 | - cycler==0.11.0
73 | - datasets==2.5.1
74 | - decorator==5.1.1
75 | - dill==0.3.5.1
76 | - dm-tree==0.1.8
77 | - evaluate==0.2.2
78 | - flatbuffers==22.12.6
79 | - fonttools==4.37.3
80 | - frozenlist==1.3.1
81 | - fsspec==2022.8.2
82 | - gast==0.4.0
83 | - google-auth==2.15.0
84 | - google-auth-oauthlib==0.4.6
85 | - google-pasta==0.2.0
86 | - grpcio==1.51.1
87 | - h5py==3.7.0
88 | - huggingface-hub==0.11.1
89 | - idna==3.4
90 | - importlib-metadata==5.1.0
91 | - importlib-resources==5.9.0
92 | - isort==5.10.1
93 | - keras==2.10.0
94 | - keras-preprocessing==1.1.2
95 | - kiwisolver==1.4.4
96 | - lazy-object-proxy==1.7.1
97 | - libclang==14.0.6
98 | - markdown==3.4.1
99 | - markupsafe==2.1.1
100 | - matplotlib==3.6.0
101 | - mccabe==0.7.0
102 | - moverscore==1.0.3
103 | - multidict==6.0.2
104 | - multiprocess==0.70.13
105 | - nltk==3.7
106 | - numpy==1.23.5
107 | - oauthlib==3.2.2
108 | - opt-einsum==3.3.0
109 | - packaging==22.0
110 | - pandas==1.5.0
111 | - pillow==9.2.0
112 | - pip==22.3.1
113 | - platformdirs==2.5.2
114 | - portalocker==2.6.0
115 | - protobuf==3.19.6
116 | - psutil==5.9.2
117 | - pyarrow==8.0.0
118 | - pyasn1==0.4.8
119 | - pyasn1-modules==0.2.8
120 | - pydot==1.4.2
121 | - pyemd==0.5.1
122 | - pylint==2.15.3
123 | - python-dateutil==2.8.2
124 | - pytz==2022.2.1
125 | - requests-oauthlib==1.3.1
126 | - responses==0.18.0
127 | - rouge-score==0.1.2
128 | - rsa==4.9
129 | - scikit-learn==1.1.2
130 | - scipy==1.9.1
131 | - sentence-transformers==2.2.2
132 | - sentencepiece==0.1.97
133 | - setuptools==65.6.3
134 | - sionna==0.11.0
135 | - tensorboard==2.10.1
136 | - tensorboard-data-server==0.6.1
137 | - tensorboard-plugin-wit==1.8.1
138 | - tensorflow==2.10.1
139 | - tensorflow-estimator==2.10.0
140 | - tensorflow-io-gcs-filesystem==0.28.0
141 | - tensorflow-probability==0.18.0
142 | - termcolor==2.1.1
143 | - threadpoolctl==3.1.0
144 | - tokenizers==0.12.1
145 | - tomli==2.0.1
146 | - tomlkit==0.11.4
147 | - torch==1.12.1
148 | - torchvision==0.13.1
149 | - transformers==4.25.1
150 | - typing==3.7.4.3
151 | - typing-extensions==4.4.0
152 | - urllib3==1.26.13
153 | - w3lib==2.0.1
154 | - werkzeug==2.2.2
155 | - wheel==0.38.4
156 | - wrapt==1.14.1
157 | - xxhash==3.0.0
158 | - yapf==0.32.0
159 | - yarl==1.8.1
160 | - zipp==3.11.0
161 | prefix: /home/danny911kr/miniconda3/envs/on-device-ai-comm
162 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import json
3 | import argparse
4 | from transformers import BartTokenizer
5 | import evaluate
6 | from tqdm import tqdm
7 | import warnings
8 | # from h5_utils import rename_weight
9 | from torch.profiler import profile, record_function, ProfilerActivity
10 |
11 | def get_test_data(path):
12 | with open(path) as f:
13 | return json.load(f)
14 |
15 | def from_pretrained(path, ebno_db, bin_conv_method, channel_type,
16 | embedding_dim, num_embeddings,
17 | fec_type, cdl_model,
18 | scenario, perfect_csi,
19 | channel_num_tx_ant=1, channel_num_rx_ant=1,
20 | num_bits_per_symbol=4):
21 | from models.on_device_ai_comm import TFOnDeviceAICForConditionalGeneration
22 | import transformers
23 |
24 | model = TFOnDeviceAICForConditionalGeneration.from_pretrained(
25 | path, ebno_db=ebno_db,
26 | bin_conv_method=bin_conv_method, channel_type=channel_type,
27 | fec_type=fec_type, cdl_model=cdl_model,
28 | scenario=scenario, perfect_csi=perfect_csi,
29 | channel_num_tx_ant=channel_num_tx_ant, channel_num_rx_ant=channel_num_rx_ant,
30 | num_bits_per_symbol=num_bits_per_symbol,
31 | embedding_dim=embedding_dim, num_embeddings=num_embeddings)
32 |
33 | return model
34 |
35 | def predict(path, ebno_db,
36 | tokenizer, batch_size, test_data_path,
37 | num_beams, bin_conv_method, channel_type,
38 | fec_type, cdl_model,
39 | scenario, perfect_csi,
40 | channel_num_tx_ant, channel_num_rx_ant,
41 | num_bits_per_symbol,
42 | embedding_dim, num_embeddings):
43 | import tensorflow as tf
44 | max_len = 32
45 |
46 | # load model
47 | model = from_pretrained(path, ebno_db,
48 | bin_conv_method=bin_conv_method,
49 | channel_type=channel_type,
50 | fec_type=fec_type, cdl_model=cdl_model,
51 | scenario=scenario, perfect_csi=perfect_csi,
52 | channel_num_tx_ant=channel_num_tx_ant, channel_num_rx_ant=channel_num_rx_ant,
53 | num_bits_per_symbol=num_bits_per_symbol,
54 | embedding_dim=embedding_dim,
55 | num_embeddings=num_embeddings)
56 |
57 |     # load testset
58 | test_data = get_test_data(test_data_path)
59 | input_sentences = [d['input'] for d in test_data]
60 | input_ids = tokenizer(input_sentences, return_tensors="tf",
61 | padding='max_length', truncation=True, max_length=max_len).input_ids
62 | testset = tf.data.Dataset.from_tensor_slices(input_ids)
63 |
64 | # inference
65 | pred_sentences = []
66 | bers=[]
67 | for input_ids in tqdm(testset.batch(batch_size).prefetch(tf.data.AUTOTUNE)):
68 | output = model(input_ids) # To get BER
69 | pred_batch = model.generate(input_ids, max_new_tokens=max_len, num_beams=num_beams,
70 | top_k=4, penalty_alpha=0.6, do_sample=False)
71 |
72 | output_strs = tokenizer.batch_decode(pred_batch,
73 | skip_special_tokens=True,
74 | clean_up_tokenization_spaces=False)
75 | pred_sentences.extend(output_strs)
76 | bers.append(output.ber.numpy())
77 |
78 | mean_ber = sum(bers) / len(bers)
79 |
80 | res = {
81 | 'input': input_sentences,
82 | 'pred': pred_sentences,
83 | 'refs': [d['refs'] for d in test_data],
84 |         'mean_ber': float(mean_ber),  # cast the numpy scalar so json.dump can serialize it
85 | }
86 | return res
87 |
88 | def get_predictions(path, ebno_db, test_data_path,
89 | prediction_json_path, batch_size,
90 | tokenizer, num_beams, bin_conv_method,
91 | channel_type, fec_type, cdl_model,
92 | scenario, perfect_csi,
93 | channel_num_tx_ant,channel_num_rx_ant, num_bits_per_symbol,
94 | embedding_dim, num_embeddings,
95 | calc_flops):
96 | path = pathlib.Path(path)
97 |     if not prediction_json_path.exists():
98 |         print(f'No cached predictions at {prediction_json_path}; running inference')
99 |
100 | with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], with_flops=True,record_shapes=True) as prof:
101 | with record_function("model_inference"):
102 | res = predict(
103 | path=path,
104 | ebno_db=ebno_db,
105 | tokenizer=tokenizer,
106 | batch_size=batch_size,
107 | test_data_path=test_data_path,
108 | num_beams=num_beams,
109 | bin_conv_method=bin_conv_method,
110 | channel_type=channel_type,
111 | fec_type=fec_type,
112 | cdl_model=cdl_model,
113 | scenario=scenario,
114 | perfect_csi=perfect_csi,
115 | channel_num_tx_ant=channel_num_tx_ant,
116 | channel_num_rx_ant=channel_num_rx_ant,
117 | num_bits_per_symbol=num_bits_per_symbol,
118 | embedding_dim=embedding_dim,
119 | num_embeddings=num_embeddings
120 | )
121 | if calc_flops:
122 | print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
123 |
124 | # save result
125 | with open(prediction_json_path, 'w') as f:
126 | json.dump(res, f, indent=4)
127 | else:
128 | with open(prediction_json_path, 'r') as f:
129 | res = json.load(f)
130 | return res
131 |
132 | def calc_bleu(predictions, tokenizer, multi_ref, **kwargs):
133 | bleu = evaluate.load('bleu')
134 | if multi_ref:
135 | warnings.warn('BLEU does not support multiple references')
136 | tokenize = lambda l: tokenizer(l, add_special_tokens=False).input_ids
137 | results = bleu.compute(
138 | references=predictions['input'],
139 | predictions=predictions['pred'],
140 | tokenizer=tokenize,
141 | max_order=4)
142 | return results
143 |
144 | def calc_sbert(predictions, batch_size, multi_ref, **kwargs):
145 | from sentence_transformers import SentenceTransformer, util
146 | import torch
147 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
148 | model = SentenceTransformer(
149 | model_name_or_path='all-MiniLM-L6-v2',
150 | device=device)
151 |
152 | sentences1 = predictions['pred']
153 |
154 | if not multi_ref:
155 | refs = [[s] for s in predictions['input']]
156 | else:
157 | refs = predictions['refs']
158 |
159 | def calc_cos_score(model, hyp_embedding, ref_sentences):
160 | hyp = hyp_embedding.reshape((1, -1))
161 | refs = model.encode(ref_sentences, convert_to_tensor=True)
162 | scores = util.cos_sim(hyp, refs)
163 | scores = scores.reshape((-1)).tolist()
164 | return {
165 | 'scores': scores,
166 | 'max_score': max(scores),
167 | 'mean_score': sum(scores) / len(scores),
168 | }
169 |
170 |
171 | # compute embedding
172 | pred_embed = model.encode(sentences1, batch_size=batch_size, convert_to_tensor=True)
173 | N = pred_embed.shape[0]
174 | scores = [
175 | calc_cos_score(model, pred_embed[i], refs[i]) for i in range(N)
176 | ]
177 | max_scores = [s['max_score'] for s in scores]
178 | mean_score = sum(max_scores)/len(max_scores)
179 | return {
180 | 'metric': 'sentence textual similarity',
181 | 'mean_score': mean_score,
182 | 'scores': scores,
183 | }
184 |
185 | METRIC_TO_SCORER = {
186 | 'bleu': calc_bleu,
187 | 'sbert': calc_sbert,
188 | }
189 |
190 | def calc(args):
191 | tokenizer = BartTokenizer.from_pretrained(args.tokenizer)
192 |
193 | path = args.path
194 | metric = args.metric
195 | batch_size = args.batch_size
196 |
197 | # rename weight name in tf_model.h5 for BinConv
198 | # rename_weight(path, RENAME_MAP)
199 |
200 | # VQ-VAE arguments
201 | if args.bin_conv_method=='vector_quantization':
202 | assert (args.embedding_dim is not None and args.num_embeddings is not None), 'Set embedding_dim and num_embeddings.'
203 |
204 | predictions = get_predictions(
205 | path,
206 | ebno_db=args.ebno_db,
207 | prediction_json_path=args.prediction_json_path,
208 | test_data_path=args.testset_path,
209 | batch_size=batch_size,
210 | tokenizer=tokenizer,
211 | num_beams=args.num_beams,
212 | bin_conv_method=args.bin_conv_method,
213 | channel_type=args.channel_type,
214 | fec_type=args.fec_type,
215 | cdl_model=args.cdl_model,
216 | scenario=args.scenario,
217 | perfect_csi=args.perfect_csi,
218 | channel_num_tx_ant=args.channel_num_tx_ant,
219 | channel_num_rx_ant=args.channel_num_rx_ant,
220 | embedding_dim=int(args.embedding_dim),
221 | num_embeddings=int(args.num_embeddings),
222 | num_bits_per_symbol=args.num_bits_per_symbol,
223 | calc_flops=args.calc_flops)
224 | scorer = METRIC_TO_SCORER[metric]
225 | results = scorer(
226 | predictions=predictions,
227 | tokenizer=tokenizer,
228 | batch_size=batch_size,
229 | multi_ref=args.multi_ref,
230 | )
231 |
232 | # Add mean_ber
233 | results['mean_ber'] = predictions['mean_ber']
234 |
235 | # dump result
236 | with open(args.result_json_path, 'w') as f:
237 | json.dump(results, f, indent=4)
238 |
239 |
240 | def main():
241 | parser = argparse.ArgumentParser()
242 | parser.add_argument(dest='path', metavar='checkpoint_path', type=pathlib.Path)
243 | parser.add_argument('-m', '--metric', choices = list(METRIC_TO_SCORER.keys()), dest='metric')
244 | parser.add_argument('-b', '--batch-size', default=4, type=int, dest='batch_size')
245 | parser.add_argument('-e', '--ebno-db', required=True, type=float, dest='ebno_db')
246 | parser.add_argument('--testset-path',
247 | required=True, type=pathlib.Path, dest='testset_path')
248 | parser.add_argument('--prediction-json-path',
249 | required=True,
250 | type=pathlib.Path,
251 | dest='prediction_json_path',
252 | help='Required. Output path of prediction result cache json file. \
253 | If the file exists, the prediction result will be reused')
254 | parser.add_argument('--result-json-path',
255 | default=pathlib.Path('./result.json'),
256 | type= pathlib.Path,
257 | dest='result_json_path')
258 | parser.add_argument('--tokenizer',
259 | default='facebook/bart-base',
260 | dest='tokenizer')
261 | parser.add_argument('--num-beams',
262 | default=1,
263 | type=int,
264 | dest='num_beams')
265 | parser.add_argument('--multi-ref',
266 | action='store_true',
267 | dest='multi_ref')
268 | parser.add_argument('--bin-conv-method', default='tanh')
269 | parser.add_argument('--channel-type', default='AWGN')
270 | parser.add_argument('--fec-type', default='Polar5G')
271 |     parser.add_argument('--fec-num-iter', default=6, type=int)
272 |     parser.add_argument('--cdl-model', default='A') # CDL
273 |     parser.add_argument('--scenario', default='umi') # 3GPP-38.901
274 |     parser.add_argument('--perfect-csi', default=True, type=lambda s: str(s).lower() in ('true', '1')) # 3GPP-38.901
275 |     parser.add_argument('--channel-num-tx-ant', default=2, type=int)
276 |     parser.add_argument('--channel-num-rx-ant', default=2, type=int)
277 |     parser.add_argument('--num-bits-per-symbol', default=4, type=int)
278 |     parser.add_argument('--embedding-dim', default=2, type=int) # vector quantization
279 |     parser.add_argument('--num-embeddings', default=1024, type=int) # vector quantization
280 |     parser.add_argument('--calc-flops', action='store_true') # type=bool would treat any non-empty string as True
281 | args = parser.parse_args()
282 | print(f'{args=}')
283 |
284 | calc(args)
285 |
286 | if __name__ == '__main__':
287 | main()
288 |
--------------------------------------------------------------------------------
/figures/On-device AI comm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abman23/on-device-ai-comm/c7b579ddfd88c5423ce9f9cb2e3e190ac89f8fdc/figures/On-device AI comm.png
--------------------------------------------------------------------------------
/figures/file_structure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abman23/on-device-ai-comm/c7b579ddfd88c5423ce9f9cb2e3e190ac89f8fdc/figures/file_structure.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.on_device_ai_comm import TFOnDeviceAICMainLayer, TFOnDeviceAICForConditionalGeneration
--------------------------------------------------------------------------------
/models/channels.py:
--------------------------------------------------------------------------------
1 | from sionna.channel import AWGN, FlatFadingChannel
2 | from sionna.fec.polar import Polar5GEncoder, Polar5GDecoder
3 | from sionna.fec.ldpc.encoding import LDPC5GEncoder
4 | from sionna.fec.ldpc.decoding import LDPC5GDecoder
5 | from sionna.mapping import Mapper, Demapper, Constellation
6 |
7 | from sionna.mimo import mf_equalizer, StreamManagement, lmmse_equalizer
8 | from sionna.ofdm import ResourceGrid, ResourceGridMapper, LSChannelEstimator, LMMSEEqualizer
9 | from sionna.ofdm import OFDMModulator, OFDMDemodulator, ZFPrecoder, RemoveNulledSubcarriers
10 | from sionna.channel.tr38901 import AntennaArray, CDL, Antenna, UMi, UMa, RMa
11 | from sionna.channel import subcarrier_frequencies, cir_to_ofdm_channel, cir_to_time_channel, time_lag_discrete_time_channel
12 | from sionna.channel import gen_single_sector_topology as gen_topology
13 | from sionna.channel import ApplyOFDMChannel, ApplyTimeChannel, OFDMChannel, TimeChannel
14 |
15 | from sionna.utils import ebnodb2no, expand_to_rank, QAMSource
16 |
17 | import numpy as np
18 | import tensorflow as tf
19 |
20 | from transformers.utils import (
21 | logging,
22 | )
23 | logger = logging.get_logger("transformers")
24 |
25 | class ChannelAWGN(tf.keras.Model):
26 | """
27 | Configure AWGN Channel components.
28 | ref: https://nvlabs.github.io/sionna/api/channel.wireless.html?highlight=awgn#sionna.channel.AWGN
29 |
30 | Parameters
31 | ----------
32 | :param fec_type: str, One of ["Polar5G", "LDPC5G"]
33 | :param num_bits_per_symbol: int
34 | :param fec_n: int
35 | :param fec_k: int
36 | :param ebno_db: float
37 | :param ebno_db_min: float
38 | :param ebno_db_max: float
39 | :param fec_num_iter: int
40 |
41 | """
42 | def __init__(self,
43 | fec_type,
44 | num_bits_per_symbol,
45 | fec_n,
46 | fec_k,
47 | ebno_db=None,
48 | ebno_db_min=None,
49 | ebno_db_max=None,
50 | fec_num_iter=6
51 | ):
52 | super().__init__()
53 | self.fec_type = fec_type
54 |
55 | self._n = fec_n
56 | self._k = fec_k
57 | self._coderate = self._k / self._n
58 |
59 | print(f'{self._k=}')
60 | print(f'{self._n=}')
61 | print(f'{self._coderate=}')
62 |
63 | constellation = Constellation("qam",
64 | num_bits_per_symbol,
65 | trainable=False)
66 | logger.info(f'Constellation: type={constellation._constellation_type} ' + \
67 | f'{num_bits_per_symbol=} trainable={constellation._trainable}')
68 | self.num_bits_per_symbol = num_bits_per_symbol
69 | self.mapper = Mapper(constellation=constellation)
70 |
71 | self.channel = AWGN()
72 |
73 | # channel noise
74 | assert ebno_db is not None or (ebno_db_min is not None and ebno_db_max is not None), "Set a single ebno_db or (ebno_db_min and ebno_db_max)"
75 | if ebno_db is not None:
76 | self.ebno_db = float(ebno_db)
77 | else:
78 | self.ebno_db = ebno_db # None
79 | self.ebno_db_min = ebno_db_min
80 | self.ebno_db_max = ebno_db_max
81 |
82 | print(f'{self.ebno_db=}')
83 | print(f'{self.ebno_db_min=}')
84 | print(f'{self.ebno_db_max=}')
85 |
86 | self.demapper = Demapper("app", constellation=constellation)
87 |
88 | # FEC
89 | self.fec_num_iter = fec_num_iter
90 | if self.fec_type == 'Polar5G':
91 | self._encoder = Polar5GEncoder(self._k, self._n)
92 | self._decoder = Polar5GDecoder(
93 | self._encoder,
94 | dec_type='SC',
95 | list_size=8,
96 | num_iter=self.fec_num_iter
97 | )
98 | elif self.fec_type == 'LDPC5G':
99 | self._encoder = LDPC5GEncoder(self._k, self._n)
100 | self._decoder = LDPC5GDecoder(self._encoder, hard_out=True, num_iter=self.fec_num_iter)
101 | else:
102 | raise ValueError(f"Invalid channel coding type: {fec_type}")
103 |
104 |
105 | @tf.function
106 | def call(self, input):
107 | '''
108 | Input
109 | -----
110 | :param input:
111 |
112 | Output
113 | ------
114 | :return b_hat:
115 | '''
116 | # reshape input
117 | input_shape = input.shape
118 |
119 |         # Zero-pad the flattened input so its length is a multiple of the divisor.
120 |         divisor = self._k
121 |         if np.prod(input_shape) % divisor != 0:
122 |             flatten_input = tf.reshape(input, [-1])
123 |             flatten_input_len = len(flatten_input)
124 | 
125 |             dummy_cnt = ((flatten_input_len // divisor) + 1) * divisor - flatten_input_len
126 |             flatten_input = tf.concat([flatten_input, tf.zeros(dummy_cnt, dtype=flatten_input.dtype)], 0)
127 | else:
128 | flatten_input = input
129 |
130 | # Channel encoder
131 | b = tf.reshape(flatten_input, (-1, self._k))
132 | codewords = self._encoder(b)
133 |
134 | # Modulation
135 | x = self.mapper(codewords)
136 |
137 | #####################
138 | # Channel
139 | #####################
140 | # Sampling a batch of SNRs
141 | batch_size=b.shape[0]
142 | if self.ebno_db_min is not None and self.ebno_db_max is not None:
143 | ebno_db_tf = tf.random.uniform(shape=[batch_size], minval=self.ebno_db_min, maxval=self.ebno_db_max)
144 | no = ebnodb2no(ebno_db_tf, self.num_bits_per_symbol, self._coderate)
145 | else:
146 | no = ebnodb2no(self.ebno_db, self.num_bits_per_symbol, self._coderate)
147 |
148 | no = expand_to_rank(no, 2)
149 |
150 | y = self.channel([x, no])
151 |
152 | #####################
153 | # Receiver
154 | #####################
155 | # Demodulation
156 | llr = self.demapper([y, no])
157 | llr = tf.reshape(llr, (-1, self._n))
158 |
159 | # Channel decoder
160 | b_hat = self._decoder(llr)
161 |
162 | if np.prod(input_shape) % divisor != 0:
163 |             # Restore the original shape by trimming the zero padding
164 | flatten_b_hat = tf.reshape(b_hat, [-1])
165 | sliced_b_hat = flatten_b_hat[:-dummy_cnt]
166 | b_hat=tf.reshape(sliced_b_hat, input_shape)
167 | else:
168 | b_hat=tf.reshape(b_hat, input_shape)
169 |
170 | return b_hat
171 |
172 | class ChannelCDL(tf.keras.Model):
173 | """
174 | Configure CDL Channel components.
175 |
176 | Parameters
177 | ----------
178 | :param fec_type: str, One of ["Polar5G", "LDPC5G"]
179 | :param cdl_model: str, One of ["A", "B", "C", "D", "E"]
180 | :param channel_num_tx_ant: int
181 | :param channel_num_rx_ant: int
182 | :param num_bits_per_symbol: int
183 | :param ebno_db: float
184 | :param ebno_db_min: float
185 | :param ebno_db_max: float
186 | :param fec_num_iter: int
187 | """
188 | def __init__(self,
189 | fec_type,
190 | cdl_model,
191 | channel_num_tx_ant,
192 | channel_num_rx_ant,
193 | num_bits_per_symbol,
194 | ebno_db=None,
195 | ebno_db_min=None,
196 | ebno_db_max=None,
197 | fec_num_iter=6
198 | ):
199 | super().__init__()
200 |
201 | # Provided parameters
202 | DL_CONFIG={
203 | "cdl_model" : cdl_model,
204 | "delay_spread" : 100e-9,
205 | "domain" : "time",
206 | "direction" : "downlink",
207 | "perfect_csi" : True,
208 | "speed" : 0.0,
209 | "cyclic_prefix_length" : 6,
210 | "pilot_ofdm_symbol_indices" : [2, 11],
211 | "duration" : None
212 | }
213 | self._domain = DL_CONFIG["domain"]
214 | self._direction = DL_CONFIG["direction"]
215 | self._cdl_model = DL_CONFIG["cdl_model"]
216 | self._delay_spread = DL_CONFIG["delay_spread"]
217 | self._perfect_csi = DL_CONFIG["perfect_csi"]
218 | self._speed = DL_CONFIG["speed"]
219 | self._cyclic_prefix_length = DL_CONFIG["cyclic_prefix_length"]
220 | self._pilot_ofdm_symbol_indices = DL_CONFIG["pilot_ofdm_symbol_indices"]
221 |
222 | logger.info(f'{DL_CONFIG=}')
223 |
224 | # System parameters
225 | self._carrier_frequency = 2.6e9
226 | self._subcarrier_spacing = 15e3 #subcarrier_spacing
227 | self._fft_size = 36
228 | self._num_ofdm_symbols = 12
229 | self._num_ut_ant = int(channel_num_tx_ant) # Must be a multiple of two as dual-polarized antennas are used
230 | self._num_bs_ant = int(channel_num_rx_ant) # Must be a multiple of two as dual-polarized antennas are used
231 | self._num_streams_per_tx = self._num_ut_ant
232 | self._dc_null = True
233 | self._num_guard_carriers = [5, 6]
234 | self._pilot_pattern = "kronecker"
235 | self._pilot_ofdm_symbol_indices = DL_CONFIG["pilot_ofdm_symbol_indices"]
236 | self._num_bits_per_symbol = int(num_bits_per_symbol)
237 | self._coderate = 0.5
238 |
239 | # Required system components
240 | self._sm = StreamManagement(np.array([[1]]), self._num_streams_per_tx)
241 |
242 | self._rg = ResourceGrid(num_ofdm_symbols=self._num_ofdm_symbols,
243 | fft_size=self._fft_size,
244 | subcarrier_spacing = self._subcarrier_spacing,
245 | num_tx=1,
246 | num_streams_per_tx=self._num_streams_per_tx,
247 | cyclic_prefix_length=self._cyclic_prefix_length,
248 | num_guard_carriers=self._num_guard_carriers,
249 | dc_null=self._dc_null,
250 | pilot_pattern=self._pilot_pattern,
251 | pilot_ofdm_symbol_indices=self._pilot_ofdm_symbol_indices)
252 |
253 | self._n = int(self._rg.num_data_symbols * self._num_bits_per_symbol)
254 | self._k = int(self._n * self._coderate)
255 |
256 | self._ut_array = AntennaArray(num_rows=1,
257 | num_cols=int(self._num_ut_ant/2),
258 | polarization="dual",
259 | polarization_type="cross",
260 | antenna_pattern="38.901",
261 | carrier_frequency=self._carrier_frequency)
262 |
263 | self._bs_array = AntennaArray(num_rows=1,
264 | num_cols=int(self._num_bs_ant/2),
265 | polarization="dual",
266 | polarization_type="cross",
267 | antenna_pattern="38.901",
268 | carrier_frequency=self._carrier_frequency)
269 |
270 | self._cdl = CDL(model=self._cdl_model,
271 | delay_spread=self._delay_spread,
272 | carrier_frequency=self._carrier_frequency,
273 | ut_array=self._ut_array,
274 | bs_array=self._bs_array,
275 | direction=self._direction,
276 | min_speed=self._speed)
277 |
278 | self._frequencies = subcarrier_frequencies(self._rg.fft_size, self._rg.subcarrier_spacing)
279 |
280 | if self._domain == "freq":
281 | self._channel_freq = ApplyOFDMChannel(add_awgn=True)
282 |
283 | elif self._domain == "time":
284 | self._l_min, self._l_max = time_lag_discrete_time_channel(self._rg.bandwidth)
285 | self._l_tot = self._l_max - self._l_min + 1
286 | self._channel_time = ApplyTimeChannel(self._rg.num_time_samples,
287 | l_tot=self._l_tot,
288 | add_awgn=True)
289 | self._modulator = OFDMModulator(self._cyclic_prefix_length)
290 | self._demodulator = OFDMDemodulator(self._fft_size, self._l_min, self._cyclic_prefix_length)
291 |
292 | self.fec_type = fec_type
293 |
294 | # channel noise
295 | assert ebno_db is not None or (ebno_db_min is not None and ebno_db_max is not None), "Set a single ebno_db or (ebno_db_min and ebno_db_max)"
296 | if ebno_db is not None:
297 | self.ebno_db = float(ebno_db)
298 | else:
299 | self.ebno_db = ebno_db # None
300 | self.ebno_db_min = ebno_db_min
301 | self.ebno_db_max = ebno_db_max
302 |
303 | # FEC
304 | self.fec_num_iter = fec_num_iter
305 | if self.fec_type == 'Polar5G':
306 | self._encoder = Polar5GEncoder(self._k, self._n)
307 | self._decoder = Polar5GDecoder(
308 | self._encoder,
309 | dec_type='SC',
310 | list_size=8,
311 | num_iter=self.fec_num_iter
312 | )
313 | elif self.fec_type == 'LDPC5G':
314 | self._encoder = LDPC5GEncoder(self._k, self._n)
315 | self._decoder = LDPC5GDecoder(self._encoder, hard_out=True, num_iter=self.fec_num_iter)
316 | else:
317 | raise ValueError(f"Invalid channel coding type: {fec_type}")
318 |
319 | self._mapper = Mapper("qam", self._num_bits_per_symbol)
320 | self._rg_mapper = ResourceGridMapper(self._rg)
321 |
322 | if self._direction == "downlink":
323 | self._zf_precoder = ZFPrecoder(self._rg, self._sm, return_effective_channel=True)
324 |
325 | self._ls_est = LSChannelEstimator(self._rg, interpolation_type="nn")
326 | self._lmmse_equ = LMMSEEqualizer(self._rg, self._sm)
327 | self._demapper = Demapper("app", "qam", self._num_bits_per_symbol)
328 |
329 | self._remove_nulled_scs = RemoveNulledSubcarriers(self._rg)
330 |
331 | @tf.function(jit_compile=True)
332 | def call(self, input):
333 | """
334 | Input
335 | -----
336 | :param input:
337 |
338 | Output
339 | ------
340 | :return b_hat:
341 | """
342 | # reshape input
343 | input_shape = input.shape
344 |
345 |         # Zero-pad the flattened input so its length is a multiple of the divisor.
346 |         divisor = self._num_streams_per_tx * self._k
347 |         if np.prod(input_shape) % divisor != 0:
348 |             flatten_input = tf.reshape(input, [-1])
349 |             flatten_input_len = len(flatten_input)
350 | 
351 |             dummy_cnt = ((flatten_input_len // divisor) + 1) * divisor - flatten_input_len
352 |             flatten_input = tf.concat([flatten_input, tf.zeros(dummy_cnt, dtype=flatten_input.dtype)], 0)
353 | else:
354 | flatten_input = input
355 |
356 | b = tf.reshape(flatten_input, (-1, 1, self._num_streams_per_tx, self._k))
357 | batch_size = b.shape[0]
358 |
359 | if self.ebno_db_min is not None and self.ebno_db_max is not None:
360 | ebno_db_tf = tf.random.uniform(shape=[batch_size], minval=self.ebno_db_min, maxval=self.ebno_db_max)
361 | no = ebnodb2no(ebno_db_tf, self._num_bits_per_symbol, self._coderate, self._rg)
362 | else:
363 | no = ebnodb2no(self.ebno_db, self._num_bits_per_symbol, self._coderate, self._rg)
364 |
365 | c = self._encoder(b)
366 | x = self._mapper(c)
367 | x_rg = self._rg_mapper(x)
368 |
369 | if self._domain == "time":
370 | # Time-domain simulations
371 | a, tau = self._cdl(batch_size, self._rg.num_time_samples+self._l_tot-1, self._rg.bandwidth)
372 | h_time = cir_to_time_channel(self._rg.bandwidth, a, tau,
373 | l_min=self._l_min, l_max=self._l_max, normalize=True)
374 |
375 | # As precoding is done in the frequency domain, we need to downsample
376 | # the path gains `a` to the OFDM symbol rate prior to converting the CIR
377 | # to the channel frequency response.
378 | a_freq = a[...,self._rg.cyclic_prefix_length:-1:(self._rg.fft_size+self._rg.cyclic_prefix_length)]
379 | a_freq = a_freq[...,:self._rg.num_ofdm_symbols]
380 | h_freq = cir_to_ofdm_channel(self._frequencies, a_freq, tau, normalize=True)
381 |
382 | if self._direction == "downlink":
383 | x_rg, g = self._zf_precoder([x_rg, h_freq])
384 |
385 | x_time = self._modulator(x_rg)
386 | y_time = self._channel_time([x_time, h_time, no])
387 |
388 | y = self._demodulator(y_time)
389 |
390 | elif self._domain == "freq":
391 | # Frequency-domain simulations
392 |
393 | cir = self._cdl(batch_size, self._rg.num_ofdm_symbols, 1/self._rg.ofdm_symbol_duration)
394 | h_freq = cir_to_ofdm_channel(self._frequencies, *cir, normalize=True)
395 |
396 | if self._direction == "downlink":
397 | x_rg, g = self._zf_precoder([x_rg, h_freq])
398 |
399 | y = self._channel_freq([x_rg, h_freq, no])
400 |
401 | if self._perfect_csi:
402 | if self._direction == "uplink":
403 | h_hat = self._remove_nulled_scs(h_freq)
404 | elif self._direction =="downlink":
405 | h_hat = g
406 | err_var = 0.0
407 | else:
408 | h_hat, err_var = self._ls_est ([y, no])
409 |
410 | x_hat, no_eff = self._lmmse_equ([y, h_hat, err_var, no])
411 | llr = self._demapper([x_hat, no_eff])
412 | b_hat = self._decoder(llr)
413 |
414 | if np.prod(input_shape) % divisor != 0:
415 |             # Restore the original shape by trimming the zero padding
416 | flatten_b_hat = tf.reshape(b_hat, [-1])
417 | sliced_b_hat = flatten_b_hat[:-dummy_cnt]
418 | b_hat=tf.reshape(sliced_b_hat, input_shape)
419 | else:
420 | b_hat=tf.reshape(b_hat, input_shape)
421 |
422 | return b_hat
423 |
424 | class ChannelSL(tf.keras.Model):
425 | """
426 | OFDM MIMO transmissions over a 3GPP 38.901 model.
427 | (Realistic Multiuser MIMO OFDM)
428 |
429 | Parameters
430 | ----------
431 | :param fec_type: str, One of ["Polar5G", "LDPC5G"]
432 | :param scenario: str, One of ["umi", "uma", "rma"]
433 | :param perfect_csi: boolean
434 | :param channel_num_tx_ant: int
435 | :param channel_num_rx_ant: int
436 | :param num_bits_per_symbol: int
437 | :param ebno_db: float
438 | :param ebno_db_min: float
439 | :param ebno_db_max: float
440 | :param fec_num_iter: int
441 | """
442 | def __init__(self,
443 | fec_type,
444 | scenario,
445 | perfect_csi,
446 | channel_num_tx_ant,
447 | channel_num_rx_ant,
448 | num_bits_per_symbol,
449 | ebno_db=None,
450 | ebno_db_min=None,
451 | ebno_db_max=None,
452 | fec_num_iter=6
453 | ):
454 | super().__init__()
455 | self.fec_type = fec_type
456 | self._scenario = scenario
457 | self._perfect_csi = perfect_csi
458 |
459 | # Internally set parameters
460 | self._carrier_frequency = 2.6e9
461 | self._fft_size = 36
462 | self._subcarrier_spacing = 15e3
463 | self._num_ofdm_symbols = 12
464 | self._cyclic_prefix_length = 6
465 | self._pilot_ofdm_symbol_indices = [2, 11]
466 | self._num_bs_ant = int(channel_num_rx_ant)
467 | self._num_ut = 1 # number of users communicating with a base station.
468 | self._num_ut_ant = int(channel_num_tx_ant) # number of antennas in a user terminal.
469 | self._num_bits_per_symbol = int(num_bits_per_symbol)
470 | self._coderate = 0.5
471 |
472 | self._dc_null = True
473 | self._num_guard_carriers = [5, 6]
474 |
475 | # Create an RX-TX association matrix
476 | # rx_tx_association[i,j]=1 means that receiver i gets at least one stream
477 | # from transmitter j. Depending on the transmission direction (uplink or downlink),
478 | # the role of UT and BS can change.
479 | bs_ut_association = np.zeros([1, self._num_ut])
480 | bs_ut_association[0, :] = 1
481 | self._rx_tx_association = bs_ut_association
482 | self._num_tx = self._num_ut
483 | self._num_streams_per_tx = self._num_ut_ant
484 |
485 |
486 | # Setup an OFDM Resource Grid
487 | self._rg = ResourceGrid(num_ofdm_symbols=self._num_ofdm_symbols,
488 | fft_size=self._fft_size,
489 | subcarrier_spacing=self._subcarrier_spacing,
490 | num_tx=self._num_tx,
491 | num_streams_per_tx=self._num_streams_per_tx,
492 | cyclic_prefix_length=self._cyclic_prefix_length,
493 | num_guard_carriers=self._num_guard_carriers,
494 | dc_null=self._dc_null,
495 | pilot_pattern="kronecker",
496 | pilot_ofdm_symbol_indices=self._pilot_ofdm_symbol_indices)
497 |
498 | # Setup StreamManagement
499 | self._sm = StreamManagement(self._rx_tx_association, self._num_streams_per_tx)
500 |
501 | # Configure antenna arrays
502 | self._ut_array = AntennaArray(
503 | num_rows=1,
504 | num_cols=1,
505 | polarization="single",
506 | polarization_type="V",
507 | antenna_pattern="omni",
508 | carrier_frequency=self._carrier_frequency)
509 |
510 | self._bs_array = AntennaArray(num_rows=1,
511 | num_cols=int(self._num_bs_ant/2),
512 | polarization="dual",
513 | polarization_type="cross",
514 | antenna_pattern="38.901",
515 | carrier_frequency=self._carrier_frequency)
516 |
517 | # Configure the channel model
518 | if self._scenario == "umi":
519 | self._channel_model = UMi(carrier_frequency=self._carrier_frequency,
520 | o2i_model="low",
521 | ut_array=self._ut_array,
522 | bs_array=self._bs_array,
523 | direction="uplink",
524 | enable_pathloss=False,
525 | enable_shadow_fading=False)
526 | elif self._scenario == "uma":
527 | self._channel_model = UMa(carrier_frequency=self._carrier_frequency,
528 | o2i_model="low",
529 | ut_array=self._ut_array,
530 | bs_array=self._bs_array,
531 | direction="uplink",
532 | enable_pathloss=False,
533 | enable_shadow_fading=False)
534 | elif self._scenario == "rma":
535 | self._channel_model = RMa(carrier_frequency=self._carrier_frequency,
536 | ut_array=self._ut_array,
537 | bs_array=self._bs_array,
538 | direction="uplink",
539 | enable_pathloss=False,
540 | enable_shadow_fading=False)
541 |
542 | # Instantiate other building blocks
543 | self._qam_source = QAMSource(self._num_bits_per_symbol)
544 |
545 | self._n = int(self._rg.num_data_symbols*self._num_bits_per_symbol) # Number of coded bits
546 | self._k = int(self._n*self._coderate) # Number of information bits
547 |
548 | # FEC
549 | self.fec_num_iter = fec_num_iter
550 | if self.fec_type == 'Polar5G':
551 | self._encoder = Polar5GEncoder(self._k, self._n)
552 | self._decoder = Polar5GDecoder(
553 | self._encoder,
554 | dec_type='SC',
555 | list_size=8,
556 | num_iter=self.fec_num_iter
557 | )
558 | elif self.fec_type == 'LDPC5G':
559 | self._encoder = LDPC5GEncoder(self._k, self._n)
560 | self._decoder = LDPC5GDecoder(self._encoder, hard_out=True, num_iter=self.fec_num_iter)
561 | else:
562 | raise ValueError(f"Invalid channel coding type: {fec_type}")
563 |
564 | self._mapper = Mapper("qam", self._num_bits_per_symbol)
565 | self._rg_mapper = ResourceGridMapper(self._rg)
566 |
567 | self._ofdm_channel = OFDMChannel(self._channel_model, self._rg, add_awgn=True,
568 | normalize_channel=True, return_channel=True)
569 |
570 | self._remove_nulled_subcarriers = RemoveNulledSubcarriers(self._rg)
571 | self._ls_est = LSChannelEstimator(self._rg, interpolation_type="nn")
572 | self._lmmse_equ = LMMSEEqualizer(self._rg, self._sm)
573 | self._demapper = Demapper("app", "qam", self._num_bits_per_symbol)
574 |
575 | # channel noise
576 | assert ebno_db is not None or (ebno_db_min is not None and ebno_db_max is not None), "Set a single ebno_db or (ebno_db_min and ebno_db_max)"
577 | if ebno_db is not None:
578 | self.ebno_db = float(ebno_db)
579 | else:
580 | self.ebno_db = ebno_db # None
581 | self.ebno_db_min = ebno_db_min
582 | self.ebno_db_max = ebno_db_max
583 |
584 | print(f'{self.ebno_db=}')
585 | print(f'{self.ebno_db_min=}')
586 | print(f'{self.ebno_db_max=}')
587 |
588 | def new_topology(self, batch_size):
589 | """Set new topology"""
590 | topology = gen_topology(batch_size,
591 | self._num_ut,
592 | self._scenario,
593 | min_ut_velocity=0.0,
594 | max_ut_velocity=0.0)
595 |
596 | self._channel_model.set_topology(*topology)
597 |
598 |
599 | @tf.function(jit_compile=True)
600 | def call(self, input):
601 | """
602 | Input
603 | -----
604 | :param input:
605 |
606 | Output
607 | ------
608 | :return b_hat:
609 | """
610 | # reshape input
611 | input_shape = input.shape
612 |
613 |         # Zero-pad the flattened input so its length is a multiple of the divisor.
614 |         divisor = self._num_tx * self._num_streams_per_tx * self._k
615 |         if np.prod(input_shape) % divisor != 0:
616 |             flatten_input = tf.reshape(input, [-1])
617 |             flatten_input_len = len(flatten_input)
618 | 
619 |             dummy_cnt = ((flatten_input_len // divisor) + 1) * divisor - flatten_input_len
620 |             flatten_input = tf.concat([flatten_input, tf.zeros(dummy_cnt, dtype=flatten_input.dtype)], 0)
621 | else:
622 | flatten_input = input
623 |
624 | b = tf.reshape(flatten_input, (-1, self._num_tx, self._num_streams_per_tx, self._k))
625 | batch_size = b.shape[0]
626 |
627 | self.new_topology(batch_size)
628 | if self.ebno_db_min is not None and self.ebno_db_max is not None:
629 | ebno_db_tf = tf.random.uniform(shape=[batch_size], minval=self.ebno_db_min, maxval=self.ebno_db_max)
630 | no = ebnodb2no(ebno_db_tf, self._num_bits_per_symbol, self._coderate, self._rg)
631 | else:
632 | no = ebnodb2no(self.ebno_db, self._num_bits_per_symbol, self._coderate, self._rg)
633 |
634 | c = self._encoder(b)
635 | x = self._mapper(c)
636 | x_rg = self._rg_mapper(x)
637 | y, h = self._ofdm_channel([x_rg, no])
638 | if self._perfect_csi:
639 | h_hat = self._remove_nulled_subcarriers(h)
640 | err_var = 0.0
641 | else:
642 | h_hat, err_var = self._ls_est ([y, no])
643 | x_hat, no_eff = self._lmmse_equ([y, h_hat, err_var, no])
644 | llr = self._demapper([x_hat, no_eff])
645 | b_hat = self._decoder(llr)
646 |
647 | if np.prod(input_shape) % divisor != 0:
648 |             # Restore the original shape by trimming the zero padding
649 | flatten_b_hat = tf.reshape(b_hat, [-1])
650 | sliced_b_hat = flatten_b_hat[:-dummy_cnt]
651 | b_hat=tf.reshape(sliced_b_hat, input_shape)
652 | else:
653 | b_hat=tf.reshape(b_hat, input_shape)
654 |
655 | return b_hat
656 |
657 | class ChannelFlatFading(tf.keras.Model):
658 | """
659 |     Configure FlatFading channel components (for a simulation over a Rayleigh channel).
660 | ref: https://nvlabs.github.io/sionna/examples/Simple_MIMO_Simulation.html?highlight=rayleigh
661 |
662 | Parameters
663 | ----------
664 | :param fec_type: str, One of ["Polar5G", "LDPC5G"]
665 | :param channel_num_tx_ant: int
666 | :param channel_num_rx_ant: int
667 | :param num_bits_per_symbol: int
668 | :param fec_n: int
669 | :param fec_k: int
670 | :param ebno_db: float
671 | :param ebno_db_min: float
672 | :param ebno_db_max: float
673 | :param fec_num_iter: int
674 |
675 | """
676 | def __init__(self,
677 | fec_type,
678 | channel_num_tx_ant,
679 | channel_num_rx_ant,
680 | num_bits_per_symbol,
681 | fec_n,
682 | fec_k,
683 | ebno_db=None,
684 | ebno_db_min=None,
685 | ebno_db_max=None,
686 | fec_num_iter=6
687 | ):
688 | super().__init__()
689 | self.fec_type = fec_type
690 |
691 | self._n = fec_n
692 | self._k = fec_k
693 | self._coderate = self._k / self._n
694 |
695 | constellation = Constellation("qam",
696 | num_bits_per_symbol,
697 | trainable=False)
698 | logger.info(f'Constellation: type={constellation._constellation_type} ' + \
699 | f'{num_bits_per_symbol=} trainable={constellation._trainable}')
700 | self.num_bits_per_symbol = num_bits_per_symbol
701 | self.mapper = Mapper(constellation=constellation)
702 |
703 | self.channel_num_tx_ant = int(channel_num_tx_ant)
704 | self.channel_num_rx_ant = int(channel_num_rx_ant)
705 | self.channel = FlatFadingChannel(self.channel_num_tx_ant, self.channel_num_rx_ant, add_awgn=True, return_channel=True)
706 |
707 | # channel noise
708 | assert ebno_db is not None or (ebno_db_min is not None and ebno_db_max is not None), "Set a single ebno_db or (ebno_db_min and ebno_db_max)"
709 | if ebno_db is not None:
710 | self.ebno_db = float(ebno_db)
711 | else:
712 | self.ebno_db = ebno_db # None
713 | self.ebno_db_min = ebno_db_min
714 | self.ebno_db_max = ebno_db_max
715 |
716 | print(f'{self.ebno_db=}')
717 | print(f'{self.ebno_db_min=}')
718 | print(f'{self.ebno_db_max=}')
719 |
720 | self.demapper = Demapper("app", constellation=constellation)
721 |
722 | # FEC
723 | self.fec_num_iter = fec_num_iter
724 | if self.fec_type == 'Polar5G':
725 | self._encoder = Polar5GEncoder(self._k, self._n)
726 | self._decoder = Polar5GDecoder(
727 | self._encoder,
728 | dec_type='SC',
729 | list_size=8,
730 | num_iter=self.fec_num_iter
731 | )
732 | elif self.fec_type == 'LDPC5G':
733 | self._encoder = LDPC5GEncoder(self._k, self._n)
734 | self._decoder = LDPC5GDecoder(self._encoder, hard_out=True, num_iter=self.fec_num_iter)
735 | else:
736 | raise ValueError(f"Invalid channel coding type: {fec_type}")
737 |
738 |
739 | @tf.function(jit_compile=True)
740 | def call(self, input):
741 | '''
742 | Input
743 | -----
744 | :param input:
745 |
746 | Output
747 | ------
748 | :return b_hat:
749 | '''
750 | # reshape input
751 | input_shape = input.shape
752 |
753 |         # Zero-pad the flattened input so its length is a multiple of the divisor.
754 |         divisor = self._k
755 |         if np.prod(input_shape) % divisor != 0:
756 |             flatten_input = tf.reshape(input, [-1])
757 |             flatten_input_len = len(flatten_input)
758 | 
759 |             dummy_cnt = ((flatten_input_len // divisor) + 1) * divisor - flatten_input_len
760 |             flatten_input = tf.concat([flatten_input, tf.zeros(dummy_cnt, dtype=flatten_input.dtype)], 0)
761 | else:
762 | flatten_input = input
763 |
764 | # Channel encoder
765 | b = tf.reshape(flatten_input, (-1, self.channel_num_tx_ant, self._k))
766 | codewords = self._encoder(b)
767 |
768 | # Modulation
769 | x = self.mapper(codewords)
770 | shape = tf.shape(x)
771 | x = tf.reshape(x, (-1, self.channel_num_tx_ant))
772 |
773 | #####################
774 | # Channel
775 | #####################
776 | # Sampling a batch of SNRs
777 | batch_size=b.shape[0]
778 | if self.ebno_db_min is not None and self.ebno_db_max is not None:
779 | ebno_db_tf = tf.random.uniform(shape=[batch_size], minval=self.ebno_db_min, maxval=self.ebno_db_max)
780 | no = ebnodb2no(ebno_db_tf, self.num_bits_per_symbol, self._coderate)
781 | else:
782 | no = ebnodb2no(self.ebno_db, self.num_bits_per_symbol, self._coderate)
783 |
784 | no *= np.sqrt(self.channel_num_rx_ant)
785 |
786 | y, h = self.channel([x, no])
787 | s = tf.complex(no*tf.eye(self.channel_num_rx_ant, self.channel_num_rx_ant), 0.0)
788 |
789 | # x_hat, no_eff = mf_equalizer(y, h, s)
790 | x_hat, no_eff = lmmse_equalizer(y, h, s)
791 |
792 | x_hat = tf.reshape(x_hat, shape)
793 | no_eff = tf.reshape(no_eff, shape)
794 |
795 | #####################
796 | # Receiver
797 | #####################
798 | # Demodulation
799 | llr = self.demapper([x_hat, no_eff])
800 | # llr = tf.reshape(llr, (-1, self._n))
801 |
802 | # Channel decoder
803 | b_hat = self._decoder(llr)
804 |
805 | if np.prod(input_shape) % divisor != 0:
806 |             # Restore the original shape by trimming the zero padding
807 | flatten_b_hat = tf.reshape(b_hat, [-1])
808 | sliced_b_hat = flatten_b_hat[:-dummy_cnt]
809 | b_hat=tf.reshape(sliced_b_hat, input_shape)
810 | else:
811 | b_hat=tf.reshape(b_hat, input_shape)
812 |
813 | return b_hat
814 |
--------------------------------------------------------------------------------
/models/on_device_ai_comm.py:
--------------------------------------------------------------------------------
1 | '''
2 | MIMO OFDM Transmissions over the CDL Channel Model
3 | https://nvlabs.github.io/sionna/examples/MIMO_OFDM_Transmissions_over_CDL.html
4 | '''
5 | from dataclasses import dataclass
6 | from typing import Optional, Tuple, Union
7 |
8 | from transformers import TFBartPretrainedModel, TFBartForConditionalGeneration
9 | from transformers.models.bart.modeling_tf_bart import TFBartMainLayer, BartConfig, shift_tokens_right, TFBartEncoder
10 | from transformers.modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqModelOutput
11 | from transformers.modeling_tf_utils import unpack_inputs, TFModelInputType
12 | import tensorflow as tf
13 | import numpy as np
14 |
15 | import sionna
16 | sionna.Config.xla_compat=True
17 |
18 | from transformers.utils import (
19 | logging,
20 | )
21 |
22 | from .channels import ChannelAWGN, ChannelCDL, ChannelSL, ChannelFlatFading
23 | from .utils import tensor_to_binary_v2, binary_to_tensor_v2, replace_nan, get_ber
24 | from .vq_vae import VectorQuantizer
25 |
26 | from transformers.tf_utils import shape_list
27 | from transformers.modeling_tf_outputs import TFSeq2SeqLMOutput, TFBaseModelOutput
28 |
29 | logger = logging.get_logger("transformers")
30 |
31 | @dataclass
32 | class TFEncoderChannelModelOutput(TFBaseModelOutput):
33 | """Output of TFAISrcEncoderAndChannel that includes
34 | - AI-Src encoder
35 | - channel encoder
36 | - mapper
37 | - channel
38 | - demapper
39 | - channel decoder
40 | """
41 | ber: Optional[tf.Tensor] = None
42 |
43 | @dataclass
44 | class TFOnDeviceAICMainLayerOutput(TFSeq2SeqModelOutput):
45 | """Output of TFOnDeviceAICMainLayer"""
46 | ber: Optional[tf.Tensor] = None
47 |
48 | @dataclass
49 | class TFOnDeviceAICOutput(TFSeq2SeqLMOutput):
50 | """Output of TFOnDeviceAICForConditionalGeneration"""
51 | ber: Optional[tf.Tensor] = None
52 |
53 | class TFAISrcEncoderAndChannel(tf.keras.layers.Layer):
54 | """
55 |     This class combines the AI source encoder with the channel components (channel encoder/decoder, mapper, channel, etc.)
56 | """
57 |
58 | def __init__(self,
59 | ai_src_encoder: TFBartEncoder,
60 | vq_layer: VectorQuantizer,
61 | ebno_db,
62 | ebno_db_min,
63 | ebno_db_max,
64 | channel_type,
65 | cdl_model,
66 | scenario,
67 | perfect_csi,
68 | fec_type,
69 | channel_num_tx_ant,
70 | channel_num_rx_ant,
71 | num_bits_per_symbol=4,
72 | fec_k=512,
73 | fec_n=1024,
74 | fec_num_iter=6,
75 | bin_conv_method='tanh',
76 | do_train=False
77 | ):
78 | # NOTE: setting layer name as follows seems strange,
79 | # but it allows HuggingFace to load pretrained weight properly
80 | super().__init__(name='model/model/')
81 | self.config = ai_src_encoder.config
82 | self.ai_src_encoder = ai_src_encoder
83 | self.ai_src_encoder.trainable = False
84 |
85 | # If Training TFOnDeviceAICForConditionalGeneration or not
86 | self.do_train = do_train
87 |
88 | # make sure data types are proper.
89 | num_bits_per_symbol = int(num_bits_per_symbol)
90 |
91 | # Configure Channel Model
92 | if channel_type == 'AWGN':
93 | ch_config = {
94 | 'fec_type': fec_type,
95 | 'num_bits_per_symbol': num_bits_per_symbol,
96 | 'fec_n': fec_n,
97 | 'fec_k': fec_k,
98 | 'ebno_db' : ebno_db,
99 | 'ebno_db_min': ebno_db_min,
100 | 'ebno_db_max': ebno_db_max,
101 | 'fec_num_iter': fec_num_iter,
102 | }
103 | elif channel_type == 'CDL':
104 | ch_config = {
105 | 'fec_type': fec_type,
106 | 'cdl_model': cdl_model,
107 | 'channel_num_tx_ant': channel_num_tx_ant,
108 | 'channel_num_rx_ant': channel_num_rx_ant,
109 | 'num_bits_per_symbol': num_bits_per_symbol,
110 | 'ebno_db' : ebno_db,
111 | 'ebno_db_min': ebno_db_min,
112 | 'ebno_db_max': ebno_db_max,
113 | 'fec_num_iter': fec_num_iter,
114 | }
115 | elif channel_type == '3GPP-38.901':
116 | ch_config = {
117 | 'fec_type': fec_type,
118 | 'scenario': scenario,
119 | 'perfect_csi': perfect_csi,
120 | 'channel_num_tx_ant': channel_num_tx_ant,
121 | 'channel_num_rx_ant': channel_num_rx_ant,
122 | 'num_bits_per_symbol': num_bits_per_symbol,
123 | 'ebno_db' : ebno_db,
124 | 'ebno_db_min': ebno_db_min,
125 | 'ebno_db_max': ebno_db_max,
126 | 'fec_num_iter': fec_num_iter,
127 | }
128 | elif channel_type == 'FlatFading':
129 | ch_config = {
130 | 'fec_type': fec_type,
131 | 'channel_num_tx_ant': channel_num_tx_ant,
132 | 'channel_num_rx_ant': channel_num_rx_ant,
133 | 'num_bits_per_symbol': num_bits_per_symbol,
134 | 'fec_n': fec_n,
135 | 'fec_k': fec_k,
136 | 'ebno_db' : ebno_db,
137 | 'ebno_db_min': ebno_db_min,
138 | 'ebno_db_max': ebno_db_max,
139 | 'fec_num_iter': fec_num_iter,
140 | }
141 | else:
142 | raise ValueError("Invalid channel_type. Expected one of: 'AWGN', 'CDL', '3GPP-38.901', 'FlatFading'.")
143 |
144 | # define vq
145 | if bin_conv_method == 'vector_quantization':
146 | self.vq_layer = vq_layer
147 | self.num_embeddings = vq_layer.num_embeddings
148 | self.embedding_dim = vq_layer.embedding_dim
149 |
150 | # setup
151 | self._setup_channel(channel_type, ch_config)
152 | self._setup_bin_conv(bin_conv_method)
153 |
154 | def _setup_channel(self, channel_type, ch_config):
155 | if channel_type == 'AWGN':
156 | self.channel_model = ChannelAWGN(
157 | fec_type=ch_config['fec_type'],
158 | num_bits_per_symbol=ch_config['num_bits_per_symbol'],
159 | fec_n=ch_config['fec_n'],
160 | fec_k=ch_config['fec_k'],
161 | ebno_db=ch_config['ebno_db'],
162 | ebno_db_min=ch_config['ebno_db_min'],
163 | ebno_db_max=ch_config['ebno_db_max'],
164 | fec_num_iter=ch_config['fec_num_iter']
165 | )
166 | elif channel_type == 'CDL':
167 | self.channel_model = ChannelCDL(
168 | fec_type=ch_config['fec_type'],
169 | cdl_model=ch_config['cdl_model'],
170 | channel_num_tx_ant=ch_config['channel_num_tx_ant'],
171 | channel_num_rx_ant=ch_config['channel_num_rx_ant'],
172 | num_bits_per_symbol=ch_config['num_bits_per_symbol'],
173 | ebno_db=ch_config['ebno_db'],
174 | ebno_db_min=ch_config['ebno_db_min'],
175 | ebno_db_max=ch_config['ebno_db_max'],
176 | fec_num_iter=ch_config['fec_num_iter']
177 | )
178 | elif channel_type == '3GPP-38.901':
179 | self.channel_model = ChannelSL(
180 | fec_type=ch_config['fec_type'],
181 | scenario=ch_config['scenario'],
182 | perfect_csi=ch_config['perfect_csi'],
183 | channel_num_tx_ant=ch_config['channel_num_tx_ant'],
184 | channel_num_rx_ant=ch_config['channel_num_rx_ant'],
185 | num_bits_per_symbol=ch_config['num_bits_per_symbol'],
186 | ebno_db=ch_config['ebno_db'],
187 | ebno_db_min=ch_config['ebno_db_min'],
188 | ebno_db_max=ch_config['ebno_db_max'],
189 | fec_num_iter=ch_config['fec_num_iter']
190 | )
191 | elif channel_type == 'FlatFading':
192 | self.channel_model = ChannelFlatFading(
193 | fec_type=ch_config['fec_type'],
194 | channel_num_tx_ant=ch_config['channel_num_tx_ant'],
195 | channel_num_rx_ant=ch_config['channel_num_rx_ant'],
196 | num_bits_per_symbol=ch_config['num_bits_per_symbol'],
197 | fec_n=ch_config['fec_n'],
198 | fec_k=ch_config['fec_k'],
199 | ebno_db=ch_config['ebno_db'],
200 | ebno_db_min=ch_config['ebno_db_min'],
201 | ebno_db_max=ch_config['ebno_db_max'],
202 | fec_num_iter=ch_config['fec_num_iter']
203 | )
204 | else:
205 | raise ValueError("Invalid channel_type. Expected one of: 'AWGN', 'CDL', '3GPP-38.901', 'FlatFading'.")
206 |
207 | def _setup_bin_conv(self, bin_conv_method):
208 | self.bin_conv_method = bin_conv_method
209 | logger.info(f'{bin_conv_method=}')
210 | if bin_conv_method == 'naive':
211 | self._tensor_to_binary = [tensor_to_binary_v2]
212 | self._binary_to_tensor = [binary_to_tensor_v2]
213 | elif bin_conv_method == 'tanh':
214 | self._tensor_to_binary = [tensor_to_binary_v2]
215 | self._binary_to_tensor = [binary_to_tensor_v2, tf.math.tanh]
216 | elif bin_conv_method == 'vector_quantization':
217 | self._tensor_to_binary = [self.vq_layer.get_code_indices, tensor_to_binary_v2]
218 | self._binary_to_tensor = [binary_to_tensor_v2,
219 | self.vq_layer.handle_invalid_values,
220 | self.vq_layer.reconstruct_with_indices]
221 | else:
222 | raise ValueError(f'Invalid bin_conv_method: {bin_conv_method}')
223 |
224 |
225 | @unpack_inputs
226 | def call(
227 | self,
228 | input_ids: Optional[TFModelInputType] = None,
229 | inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
230 | attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
231 | head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
232 | output_attentions: Optional[bool] = None,
233 | output_hidden_states: Optional[bool] = None,
234 | return_dict: Optional[bool] = None,
235 | training: Optional[bool] = False,
236 | ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
237 | # run the AI source encoder (BART)
238 | ai_src_encoder_outputs = self.ai_src_encoder(
239 | input_ids=input_ids,
240 | inputs_embeds=inputs_embeds,
241 | attention_mask=attention_mask,
242 | head_mask=head_mask,
243 | output_attentions=output_attentions,
244 | output_hidden_states=output_hidden_states,
245 | return_dict=return_dict,
246 | training=training,
247 | )
248 |
249 | shape = tf.shape(ai_src_encoder_outputs.last_hidden_state)
250 | ai_src_encoder_output = ai_src_encoder_outputs.last_hidden_state
251 |
252 | # binarize (and get codebook indices if vector quantizing)
253 | for f in self._tensor_to_binary:
254 | ai_src_encoder_output = f(ai_src_encoder_output)
255 | ai_src_encoder_output_binary = ai_src_encoder_output
256 |
257 | # add channel noise to binarized ai_src_encoder output
258 | last_hidden_state_binary = self.channel_model(ai_src_encoder_output_binary)
259 |
260 | # convert back to tensors and denoise (if vector quantizing, reconstruct the tensors fed to the semantic decoder from the codebook)
261 | last_hidden_state = last_hidden_state_binary
262 | for f in self._binary_to_tensor:
263 | last_hidden_state = f(last_hidden_state)
264 | last_hidden_state_pred = tf.reshape(last_hidden_state, shape)
265 |
266 | last_hidden_state_pred = replace_nan(last_hidden_state_pred, 0) # convert all NaN values to zero
267 | if self.bin_conv_method != 'naive':
268 | tf.debugging.assert_all_finite(last_hidden_state_pred, 'should not have nan/inf/-inf')
269 |
270 | # calculate BER only during evaluation.
271 | if not self.do_train:
272 | last_hidden_state_binary = tf.reshape(last_hidden_state_binary, tf.shape(ai_src_encoder_output_binary))
273 | ber = get_ber(ai_src_encoder_output_binary, last_hidden_state_binary)
274 | else:
275 | # BER is not calculated during training.
276 | ber = tf.constant(-1.0, dtype=tf.float32)
277 |
278 | return TFEncoderChannelModelOutput(
279 | last_hidden_state=last_hidden_state_pred,
280 | hidden_states=ai_src_encoder_outputs.hidden_states,
281 | attentions=ai_src_encoder_outputs.attentions,
282 | ber=ber,
283 | )
284 |
285 |
286 | class TFOnDeviceAICMainLayer(tf.keras.layers.Layer):
287 |
288 | def __init__(self,
289 | config: BartConfig,
290 | bart_main_layer: TFBartMainLayer,
291 | ebno_db,
292 | ebno_db_min,
293 | ebno_db_max,
294 | fec_k=512,
295 | fec_n=1024,
296 | fec_num_iter=6,
297 | num_bits_per_symbol=4,
298 | channel_type = 'AWGN',
299 | cdl_model='A',
300 | scenario='umi',
301 | perfect_csi=True,
302 | channel_num_tx_ant=1,
303 | channel_num_rx_ant=1,
304 | fec_type=None,
305 | bin_conv_method='tanh',
306 | embedding_dim=2,
307 | num_embeddings=1024,
308 | do_train=False,
309 | **kwargs):
310 | super().__init__(**kwargs)
311 |
312 | self.config = config
313 | self.shared = bart_main_layer.get_input_embeddings()
314 | self.shared.trainable = False
315 |
316 | self.bin_conv_method = bin_conv_method
317 | # VectorQuantizer layer
318 | if self.bin_conv_method == 'vector_quantization':
319 | print(f'{embedding_dim=}')
320 | print(f'{num_embeddings=}') # number of codebooks.
321 | assert (embedding_dim is not None) and (num_embeddings is not None), \
322 | "For vector_quantization, set the embedding_dim and num_embeddings arguments."
323 | self.vq_layer = VectorQuantizer(num_embeddings, embedding_dim, name="vector_quantizer")
324 | else:
325 | self.vq_layer = None
326 |
327 | # ai source encoder
328 | self.encoder = TFAISrcEncoderAndChannel(
329 | ai_src_encoder=bart_main_layer.encoder,
330 | vq_layer=self.vq_layer,
331 | ebno_db=ebno_db,
332 | ebno_db_min=ebno_db_min,
333 | ebno_db_max=ebno_db_max,
334 | fec_k=fec_k,
335 | fec_n=fec_n,
336 | fec_num_iter=fec_num_iter,
337 | num_bits_per_symbol=num_bits_per_symbol,
338 | channel_type=channel_type,
339 | cdl_model=cdl_model,
340 | scenario=scenario,
341 | perfect_csi=perfect_csi,
342 | channel_num_tx_ant=channel_num_tx_ant,
343 | channel_num_rx_ant=channel_num_rx_ant,
344 | fec_type=fec_type,
345 | bin_conv_method=bin_conv_method,
346 | do_train=do_train)
347 | self.decoder = bart_main_layer.decoder
348 |
349 | def get_input_embeddings(self):
350 | return self.shared
351 |
352 | def set_input_embeddings(self, new_embeddings):
353 | self.shared = new_embeddings
354 | self.encoder.ai_src_encoder.embed_tokens = self.shared  # TFAISrcEncoderAndChannel wraps the BART encoder as ai_src_encoder
355 | self.decoder.embed_tokens = self.shared
356 |
357 | @unpack_inputs
358 | def call(self,
359 | input_ids: Optional[TFModelInputType] = None,
360 | attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
361 | decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
362 | decoder_attention_mask: Optional[Union[np.ndarray,
363 | tf.Tensor]] = None,
364 | decoder_position_ids: Optional[Union[np.ndarray,
365 | tf.Tensor]] = None,
366 | head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
367 | decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
368 | cross_attn_head_mask: Optional[Union[np.ndarray,
369 | tf.Tensor]] = None,
370 | encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
371 | past_key_values: Optional[Tuple[Tuple[Union[np.ndarray,
372 | tf.Tensor]]]] = None,
373 | inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
374 | decoder_inputs_embeds: Optional[Union[np.ndarray,
375 | tf.Tensor]] = None,
376 | use_cache: Optional[bool] = None,
377 | output_attentions: Optional[bool] = None,
378 | output_hidden_states: Optional[bool] = None,
379 | return_dict: Optional[bool] = None,
380 | training: Optional[bool] = False,
381 | **kwargs) -> Union[TFOnDeviceAICMainLayerOutput, Tuple[tf.Tensor]]:
382 | # unlike other models, BART automatically creates decoder_input_ids from
383 | # input_ids if no decoder_input_ids are provided
384 | if decoder_input_ids is None and decoder_inputs_embeds is None:
385 | if input_ids is None:
386 | raise ValueError(
387 | "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
388 | "passed, `input_ids` cannot be `None`. Please pass either "
389 | "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
390 | )
391 |
392 | decoder_input_ids = shift_tokens_right(
393 | input_ids, self.config.pad_token_id,
394 | self.config.decoder_start_token_id)
395 |
396 | if encoder_outputs is None:
397 | encoder_outputs = self.encoder(
398 | input_ids=input_ids,
399 | attention_mask=attention_mask,
400 | head_mask=head_mask,
401 | inputs_embeds=inputs_embeds,
402 | output_attentions=output_attentions,
403 | output_hidden_states=output_hidden_states,
404 | return_dict=return_dict,
405 | training=training,
406 | )
407 |
408 | # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
409 | elif return_dict and not isinstance(encoder_outputs,
410 | TFBaseModelOutput):
411 | encoder_outputs = TFBaseModelOutput(
412 | last_hidden_state=encoder_outputs[0],
413 | hidden_states=encoder_outputs[1]
414 | if len(encoder_outputs) > 1 else None,
415 | attentions=encoder_outputs[2]
416 | if len(encoder_outputs) > 2 else None,
417 | )
418 |
419 | # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
420 | elif not return_dict and not isinstance(encoder_outputs, tuple):
421 | encoder_outputs = encoder_outputs.to_tuple()
422 |
423 | # call the VectorQuantizer so its commitment/codebook losses are registered during training
424 | if self.bin_conv_method == 'vector_quantization':
425 | # encoder_outputs.last_hidden_state = self.vq_layer(encoder_outputs.last_hidden_state)
426 | self.vq_layer(encoder_outputs.last_hidden_state)
427 |
428 | decoder_outputs = self.decoder(
429 | decoder_input_ids,
430 | attention_mask=decoder_attention_mask,
431 | position_ids=decoder_position_ids,
432 | encoder_hidden_states=encoder_outputs[0],
433 | encoder_attention_mask=attention_mask,
434 | head_mask=decoder_head_mask,
435 | cross_attn_head_mask=cross_attn_head_mask,
436 | past_key_values=past_key_values,
437 | inputs_embeds=decoder_inputs_embeds,
438 | use_cache=use_cache,
439 | output_attentions=output_attentions,
440 | output_hidden_states=output_hidden_states,
441 | return_dict=return_dict,
442 | training=training,
443 | )
444 |
445 | if not return_dict:
446 | return decoder_outputs + encoder_outputs
447 |
448 | return TFOnDeviceAICMainLayerOutput(
449 | last_hidden_state=decoder_outputs.last_hidden_state,
450 | past_key_values=decoder_outputs.past_key_values,
451 | decoder_hidden_states=decoder_outputs.hidden_states,
452 | decoder_attentions=decoder_outputs.attentions,
453 | cross_attentions=decoder_outputs.cross_attentions,
454 | encoder_last_hidden_state=encoder_outputs.last_hidden_state,
455 | encoder_hidden_states=encoder_outputs.hidden_states,
456 | encoder_attentions=encoder_outputs.attentions,
457 | ber=encoder_outputs.ber,
458 | )
459 |
460 |
461 | class TFOnDeviceAICForConditionalGeneration(TFBartForConditionalGeneration):
462 | def __init__(self,
463 | config,
464 | ebno_db=None,
465 | ebno_db_min=None,
466 | ebno_db_max=None,
467 | fec_k=512,
468 | fec_n=1024,
469 | fec_num_iter=6,
470 | num_bits_per_symbol=4,
471 | channel_type = 'AWGN',
472 | cdl_model='A',
473 | scenario='umi',
474 | perfect_csi=True,
475 | channel_num_tx_ant = 1,
476 | channel_num_rx_ant = 1,
477 | fec_type = None,
478 | bin_conv_method='tanh',
479 | embedding_dim=None,
480 | num_embeddings=None,
481 | do_train=False,
482 | *inputs,
483 | **kwargs):
484 | super().__init__(config, *inputs, **kwargs)
485 | self.model = TFOnDeviceAICMainLayer(
486 | config,
487 | bart_main_layer=self.model,
488 | ebno_db=ebno_db,
489 | ebno_db_min=ebno_db_min,
490 | ebno_db_max=ebno_db_max,
491 | fec_k=fec_k,
492 | fec_n=fec_n,
493 | fec_num_iter=fec_num_iter,
494 | num_bits_per_symbol=num_bits_per_symbol,
495 | channel_type=channel_type,
496 | cdl_model=cdl_model,
497 | scenario=scenario,
498 | perfect_csi=perfect_csi,
499 | channel_num_tx_ant=channel_num_tx_ant,
500 | channel_num_rx_ant=channel_num_rx_ant,
501 | fec_type=fec_type,
502 | bin_conv_method=bin_conv_method,
503 | embedding_dim=embedding_dim,
504 | num_embeddings=num_embeddings,
505 | do_train=do_train,
506 | name="model")
507 |
508 | self.bin_conv_method = bin_conv_method
509 | self.VQ_LOSS_WEIGHT = 0.01
510 |
511 | @unpack_inputs
512 | def call(
513 | self,
514 | input_ids: Optional[TFModelInputType] = None,
515 | attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
516 | decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
517 | decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
518 | decoder_position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
519 | head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
520 | decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
521 | cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
522 | encoder_outputs: Optional[TFBaseModelOutput] = None,
523 | past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
524 | inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
525 | decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
526 | use_cache: Optional[bool] = None,
527 | output_attentions: Optional[bool] = None,
528 | output_hidden_states: Optional[bool] = None,
529 | return_dict: Optional[bool] = None,
530 | labels: Optional[tf.Tensor] = None,
531 | training: Optional[bool] = False,
532 | ) -> Union[TFSeq2SeqLMOutput, Tuple[tf.Tensor]]:
533 |
534 | if labels is not None:
535 | labels = tf.where(
536 | labels == self.config.pad_token_id,
537 | tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
538 | labels,
539 | )
540 | use_cache = False
541 | if decoder_input_ids is None and decoder_inputs_embeds is None:
542 | decoder_input_ids = shift_tokens_right(
543 | labels, self.config.pad_token_id, self.config.decoder_start_token_id
544 | )
545 |
546 | outputs = self.model(
547 | input_ids,
548 | attention_mask=attention_mask,
549 | decoder_input_ids=decoder_input_ids,
550 | encoder_outputs=encoder_outputs,
551 | decoder_attention_mask=decoder_attention_mask,
552 | decoder_position_ids=decoder_position_ids,
553 | head_mask=head_mask,
554 | decoder_head_mask=decoder_head_mask,
555 | cross_attn_head_mask=cross_attn_head_mask,
556 | past_key_values=past_key_values,
557 | inputs_embeds=inputs_embeds,
558 | decoder_inputs_embeds=decoder_inputs_embeds,
559 | use_cache=use_cache,
560 | output_attentions=output_attentions,
561 | output_hidden_states=output_hidden_states,
562 | return_dict=return_dict,
563 | training=training,
564 | )
565 | lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
566 | lm_logits = self.bias_layer(lm_logits)
567 | masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
568 |
569 | # add weighted vq loss into masked_lm_loss
570 | if self.bin_conv_method == 'vector_quantization' and masked_lm_loss is not None:
571 | vq_loss = self.VQ_LOSS_WEIGHT * sum(self.model.vq_layer.losses)
577 | masked_lm_loss += vq_loss  # vq_loss is already weighted by VQ_LOSS_WEIGHT above
578 |
579 | if not return_dict:
580 | output = (lm_logits,) + outputs[1:]
581 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
582 |
583 | return TFOnDeviceAICOutput(
584 | loss=masked_lm_loss,
585 | logits=lm_logits,
586 | past_key_values=outputs.past_key_values,  # index 1 of decoder outputs
587 | decoder_hidden_states=outputs.decoder_hidden_states,  # index 2 of decoder outputs
588 | decoder_attentions=outputs.decoder_attentions,  # index 3 of decoder outputs
589 | cross_attentions=outputs.cross_attentions,  # index 4 of decoder outputs
590 | encoder_last_hidden_state=outputs.encoder_last_hidden_state,  # index 0 of encoder outputs
591 | encoder_hidden_states=outputs.encoder_hidden_states,  # index 1 of encoder outputs
592 | encoder_attentions=outputs.encoder_attentions,  # index 2 of encoder outputs
593 | ber=outputs.ber,
594 | )
595 |
--------------------------------------------------------------------------------
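A minimal usage sketch (not part of the repository): loading the model above for evaluation over an AWGN channel. The checkpoint name and parameter values are illustrative; train.py and scripts/eval.sh show the argument sets actually used.

    from models.on_device_ai_comm import TFOnDeviceAICForConditionalGeneration

    # Hypothetical evaluation setup: fixed Eb/No, AWGN channel, tanh denoising.
    model = TFOnDeviceAICForConditionalGeneration.from_pretrained(
        "facebook/bart-base",   # or a fine-tuned checkpoint directory
        ebno_db=4.0,
        channel_type="AWGN",
        fec_type="Polar5G",
        bin_conv_method="tanh",
        do_train=False,
    )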
/models/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | @tf.function
5 | def logit_to_binary(x):
6 | return (tf.math.sign(x) + 1.0) / 2.0  # negatives -> 0.0, positives -> 1.0 (exact zeros -> 0.5)
7 |
8 | @tf.function
9 | def tensor_to_binary_u32(x):
10 | # convert to bits
11 | bits = [None for i in range(32)]
12 | for i in range(32):
13 | bits[i] = tf.bitwise.bitwise_and(x, 1)
14 | x = tf.bitwise.right_shift(x, 1)
15 | res = tf.stack(bits, axis=-1)
16 | return tf.cast(res, tf.float32)
17 |
18 | @tf.function
19 | def tensor_to_binary_v2(x):
20 | LOG_BASE2 = tf.math.log(tf.constant([2.0], tf.float32))
21 | TO_MANTISSA = tf.constant([1<<23], tf.float32)
22 |
23 | # sign
24 | sign = tf.cast(x < 0.0, tf.float32)
25 | x = tf.math.abs(x)
26 |
27 | # exponent
28 | log_x = tf.math.floor(tf.math.log(x) / LOG_BASE2)
29 | exponent = tf.cast(log_x + 127.0, tf.uint8)  # IEEE-754 biased exponent
30 |
31 | # mantissa
32 | mantissa = x / tf.math.exp(log_x*LOG_BASE2) - tf.math.sign(x)  # fractional mantissa in [0, 1)
33 | mantissa = tf.math.floor(mantissa * TO_MANTISSA)
34 | mantissa = tf.cast(mantissa, tf.int32)
35 |
36 | # convert to bits
37 | bits = [None for i in range(32)]
38 | for i in range(23):
39 | bits[i] = tf.bitwise.bitwise_and(mantissa, 1)
40 | mantissa = tf.bitwise.right_shift(mantissa, 1)
41 | for i in range(23, 31):
42 | bits[i] = tf.bitwise.bitwise_and(exponent, 1)
43 | exponent = tf.bitwise.right_shift(exponent, 1)
44 | bits[31] = sign
45 |
46 | for i in range(32):
47 | bits[i] = tf.cast(bits[i], tf.float32)
48 | res = tf.stack(bits, axis=-1)
49 | return res
50 |
51 | @tf.function
52 | def binary_to_tensor_u32(x: tf.Tensor):
53 | x = tf.reshape(x, (-1, 32))
54 | x = tf.cast(x, tf.uint32)
55 | out = x[:, 0]
56 | for i in range(1, 32):
57 | out += tf.bitwise.left_shift(x[:, i], i)
58 | return out
59 |
60 | @tf.function
61 | def binary_to_tensor_v2(x: tf.Tensor):
62 | LOG_BASE2 = tf.math.log(tf.constant([2.0], tf.float32))
63 | EXPONENTS = tf.constant([ float(1 << i) for i in range(8)], tf.float32)
64 | FROM_MANTISSA = tf.constant([ 0.5**(23-i) for i in range(23)], tf.float32)
65 |
66 | x = tf.reshape(x, (-1, 32))
67 | sign = -x[:, 31] * 2 + 1
68 |
69 | exponent = tf.math.reduce_sum(x[:, 23:31] * EXPONENTS, axis=-1)
70 | mantissa = tf.math.reduce_sum(x[:, :23] * FROM_MANTISSA, axis=-1)
71 | mantissa += tf.cast(exponent > 0.0, tf.float32)  # restore the implicit leading 1 of normalized floats
72 | return sign * tf.math.exp((exponent - 127.0) * LOG_BASE2) * mantissa
73 |
74 | @tf.function
75 | def tensor_to_binary(x: tf.Tensor):
76 | x = tf.bitcast(x, tf.uint32)
77 | mask = tf.ones_like(x)
78 | bit0 = tf.cast(tf.reshape(tf.bitwise.bitwise_and(x, mask), (1, -1)),
79 | tf.float32)
80 | bits = [bit0]
81 |
82 | for _ in range(31):
83 | x = tf.bitwise.right_shift(x, 1)
84 | bitn = tf.cast(tf.reshape(tf.bitwise.bitwise_and(x, mask), (1, -1)),
85 | tf.float32)
86 | bits.append(bitn)
87 |
88 | return tf.concat(bits, axis=0)
89 |
90 | @tf.function
91 | def binary_to_tensor(x: tf.Tensor):
92 | x = tf.cast(x, tf.uint32)
93 | x = tf.reshape(x, (32, -1))
94 |
95 | shape = tf.shape(x)
96 | out = tf.zeros((shape[1],), dtype=tf.uint32)
97 | for i in range(32):
98 | bitn = tf.bitwise.left_shift(x[i, :], i)
99 | out = tf.bitwise.bitwise_xor(out, bitn)
100 |
101 | return tf.bitcast(out, tf.float32)
102 |
103 |
104 | @tf.function
105 | def replace_nan(input, new_value = 0.0):
106 | new_value = float(new_value)
107 | indices = tf.where(tf.math.is_nan(input))
108 | res = tf.tensor_scatter_nd_update(
109 | input,
110 | indices,
111 | tf.fill((tf.shape(indices)[0], ), new_value)
112 | )
113 | return res
114 |
115 | INF = tf.constant(np.array([np.inf]), dtype=tf.float32)
116 |
117 | @tf.function
118 | def replace_nan_to_inf(x):
119 | EXPONENT_MASK = tf.constant([0x7F800000], dtype=tf.uint32)
120 | INF_MASK = tf.constant([0xFF800000], dtype=tf.uint32)
121 | IDENTITY_MASK = tf.constant([0xFFFFFFFF], dtype=tf.uint32)
122 | x = tf.bitcast(x, tf.uint32)
123 | mask = tf.where(
124 | tf.equal(tf.bitwise.bitwise_and(x, EXPONENT_MASK), EXPONENT_MASK),
125 | INF_MASK, IDENTITY_MASK
126 | )
127 | return tf.bitcast(tf.bitwise.bitwise_and(x, mask), tf.float32)
128 |
129 | @tf.function
130 | def get_ber(b, b_hat):
131 | bit_errors = tf.math.count_nonzero(b != b_hat, dtype=tf.float32)
132 | total_bits = tf.cast(tf.size(b), dtype=tf.float32)  # tf.size also handles dynamically shaped tensors
133 | ber = bit_errors / total_bits
134 |
135 | return ber
136 |
137 | def test():
138 | # f32
139 | x = tf.random.normal(shape=(64,))
140 | b = tensor_to_binary_v2(x)
141 | y = binary_to_tensor_v2(b)
142 | tf.debugging.assert_near(x, y)
143 |
144 | # u32
145 | x = tf.random.uniform(shape=(64,), maxval=tf.int32.max, dtype=tf.int32)
146 | x = tf.cast(x, tf.uint32)
147 | b = tensor_to_binary_u32(x)
148 | y = binary_to_tensor_u32(b)
149 | tf.debugging.assert_equal(x, y)
150 |
151 | if __name__ == '__main__':
152 | test()
--------------------------------------------------------------------------------
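A small sanity sketch (assumed usage, extending test() in the file above): round-trip a tensor through the IEEE-754-style bit decomposition, then flip a single exponent bit to see why raw float bits are fragile under channel noise, which is the motivation for the 'tanh' and vector-quantization recovery paths in models/on_device_ai_comm.py.

    import tensorflow as tf
    from models.utils import tensor_to_binary_v2, binary_to_tensor_v2, replace_nan

    x = tf.random.normal(shape=(8,))
    bits = tensor_to_binary_v2(x)     # shape (8, 32): mantissa, exponent, sign bits
    y = binary_to_tensor_v2(bits)     # reshapes to (-1, 32) internally
    tf.debugging.assert_near(x, y)    # the round trip is near-exact

    # Flip bit 25 (an exponent bit) of the first value: the decoded value
    # is scaled by a factor of 2**4, a large error from one flipped bit.
    flipped = tf.tensor_scatter_nd_update(bits, [[0, 25]], [1.0 - bits[0, 25]])
    y_bad = replace_nan(binary_to_tensor_v2(flipped), 0.0)
    print(float(x[0]), float(y_bad[0]))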
/models/vq_vae.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 | from tensorflow import keras
5 | from tensorflow.keras import layers
6 | import tensorflow_probability as tfp
7 | import tensorflow as tf
8 |
9 | '''https://keras.io/examples/generative/vq_vae/'''
10 |
11 | class VectorQuantizer(layers.Layer):
12 | def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
13 | super().__init__(**kwargs)
14 | self.num_embeddings = num_embeddings # the number of codebooks.
15 | self.embedding_dim = embedding_dim # the dimension of a codebook.
16 |
17 | # The `beta` parameter is best kept between [0.25, 2] as per the paper.
18 | self.beta = beta
19 |
20 | # Initialize the embeddings which we will quantize.
21 | w_init = tf.random_uniform_initializer()
22 | self.embeddings = tf.Variable(
23 | initial_value=w_init(
24 | shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
25 | ),
26 | trainable=True,
27 | name="embeddings_vqvae",
28 | )
29 |
30 | def call(self, x):
31 | # Record the input shape so the quantized values can be reshaped back below.
32 | input_shape = tf.shape(x)
33 |
34 | # Quantization.
35 | encoding_indices = self.get_code_indices(x)
36 | encoding_indices = tf.cast(encoding_indices, tf.int64)
37 | encodings = tf.one_hot(encoding_indices, self.num_embeddings)
38 | quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
39 |
40 | # Reshape the quantized values back to the original input shape
41 | quantized = tf.reshape(quantized, input_shape)
42 |
43 | # Calculate vector quantization loss and add that to the layer. You can learn more
44 | # about adding losses to different layers here:
45 | # https://keras.io/guides/making_new_layers_and_models_via_subclassing/. Check
46 | # the original paper to get a handle on the formulation of the loss function.
47 | commitment_loss = tf.reduce_mean((tf.stop_gradient(quantized) - x) ** 2)
48 | codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
49 | self.add_loss(self.beta * commitment_loss + codebook_loss)
50 |
51 | # Straight-through estimator.
52 | quantized = x + tf.stop_gradient(quantized - x)
53 | return quantized
54 |
55 | def get_code_indices(self, input):
56 | # get codebook indices for each semantic encoder output
57 |
58 | # flatten the inputs keeping `embedding_dim` intact.
59 | num_of_input_elems = np.prod(input.shape)
60 | assert num_of_input_elems % self.embedding_dim == 0, f"'embedding_dim' must evenly divide the total number of input elements ({num_of_input_elems})."
61 | flattened_inputs = tf.reshape(input, [-1, self.embedding_dim])
62 |
63 | # Calculate squared L2 distances between the inputs and the codebook vectors.
64 | similarity = tf.matmul(flattened_inputs, self.embeddings)
65 | distances = (
66 | tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
67 | + tf.reduce_sum(self.embeddings ** 2, axis=0)
68 | - 2 * similarity
69 | )
70 |
71 | # Derive the indices for minimum distances.
72 | encoding_indices = tf.argmin(distances, axis=1)
73 | encoding_indices = tf.cast(encoding_indices, tf.float32)  # float dtype so the indices can be binarized for transmission
74 |
75 | return encoding_indices
76 |
77 | def reconstruct_with_indices(self, indices):
78 | indices = tf.cast(indices, tf.int64)
79 |
80 | encodings = tf.one_hot(indices, self.num_embeddings)
81 | quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
82 |
83 | return quantized
84 |
85 | def handle_invalid_values(self, output):
86 | # Map invalid received values onto the closest valid codebook indices.
87 | # clamp negative values to 0
88 | output = tf.maximum(output, 0)
89 | # clamp values above num_embeddings-1 (indices are 0-based)
90 | output = tf.minimum(output, self.num_embeddings-1)
91 | # round to the nearest integer index
92 | output = tf.round(output)
93 |
94 | return output
--------------------------------------------------------------------------------
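A minimal sketch (assumed shapes; the channel itself is omitted for brevity) of how these methods compose in the 'vector_quantization' branch of TFAISrcEncoderAndChannel._setup_bin_conv: only codebook indices are binarized and transmitted, and corrupted indices are repaired before the embedding lookup.

    import tensorflow as tf
    from models.vq_vae import VectorQuantizer
    from models.utils import tensor_to_binary_v2, binary_to_tensor_v2

    vq = VectorQuantizer(num_embeddings=1024, embedding_dim=2)
    x = tf.random.normal(shape=(4, 8, 2))          # stand-in for encoder hidden states

    indices = vq.get_code_indices(x)               # one float index per embedding_dim block
    bits = tensor_to_binary_v2(indices)            # this is what crosses the channel
    received = binary_to_tensor_v2(bits)           # noiseless here for brevity
    received = vq.handle_invalid_values(received)  # clamp and round corrupted indices
    x_hat = tf.reshape(vq.reconstruct_with_indices(received), tf.shape(x))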
/preprocess/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abman23/on-device-ai-comm/c7b579ddfd88c5423ce9f9cb2e3e190ac89f8fdc/preprocess/__init__.py
--------------------------------------------------------------------------------
/preprocess/europarl.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import json
3 | import argparse
4 | from .hf_data_gen import HFDataGenerator
5 |
6 | # The following code is copied (and slightly modified) from DeepSC
7 | # (https://github.com/zyy598/DeepSC/blob/master/preprocess_text.py)
8 | import unicodedata
9 | import re
10 | from w3lib.html import remove_tags
11 | import pickle
12 | import os
13 | import json
14 | from tqdm import tqdm
15 |
16 | def unicode_to_ascii(s):
17 | return ''.join(c for c in unicodedata.normalize('NFD', s)
18 | if unicodedata.category(c) != 'Mn')
19 |
20 | def normalize_string(s):
21 | # normalize unicode characters
22 | s = unicode_to_ascii(s)
23 | # remove the XML-tags
24 | s = remove_tags(s)
25 | # add white space before !.?
26 | s = re.sub(r'([!.?])', r' \1', s)
27 | s = re.sub(r'[^a-zA-Z.!?]+', r' ', s)
28 | s = re.sub(r'\s+', r' ', s)
29 | # change to lower letter
30 | s = s.lower()
31 | return s
32 |
33 | def cutted_data(cleaned, MIN_LENGTH=4, MAX_LENGTH=30):
34 | cutted_lines = list()
35 | for line in cleaned:
36 | length = len(line.split())
37 | if length > MIN_LENGTH and length < MAX_LENGTH:
38 | line = [word for word in line.split()]
39 | cutted_lines.append(' '.join(line))
40 | return cutted_lines
41 |
42 | def save_clean_sentences(sentence, save_path):
43 | pickle.dump(sentence, open(save_path, 'wb'))
44 | print('Saved: %s' % save_path)
45 |
46 | def process_text_file(file_path):
47 | with open(file_path, 'r') as f:
48 | raw_data = f.read()
49 | sentences = raw_data.strip().split('\n')
50 | raw_data_input = [normalize_string(data) for data in sentences]
51 | raw_data_input = cutted_data(raw_data_input)
52 | return raw_data_input
53 |
54 | def tokenize(s, delim=' ', add_start_token=True, add_end_token=True,
55 | punct_to_keep=None, punct_to_remove=None):
56 | """
57 | Tokenize a sequence, converting a string s into a list of (string) tokens by
58 | splitting on the specified delimiter. Optionally keep or remove certain
59 | punctuation marks and add start and end tokens.
60 | """
61 | if punct_to_keep is not None:
62 | for p in punct_to_keep:
63 | s = s.replace(p, '%s%s' % (delim, p))
64 |
65 | if punct_to_remove is not None:
66 | for p in punct_to_remove:
67 | s = s.replace(p, '')
68 |
69 | tokens = s.split(delim)
70 | if add_start_token:
71 | tokens.insert(0, '<START>')
72 | if add_end_token:
73 | tokens.append('<END>')
74 | return tokens
75 |
76 | def build_vocab(sequences, token_to_idx = { }, min_token_count=1, delim=' ',
77 | punct_to_keep=None, punct_to_remove=None, ):
78 | token_to_count = {}
79 |
80 | for seq in sequences:
81 | seq_tokens = tokenize(seq, delim=delim, punct_to_keep=punct_to_keep,
82 | punct_to_remove=punct_to_remove,
83 | add_start_token=False, add_end_token=False)
84 | for token in seq_tokens:
85 | if token not in token_to_count:
86 | token_to_count[token] = 0
87 | token_to_count[token] += 1
88 |
89 | for token, count in sorted(token_to_count.items()):
90 | if count >= min_token_count:
91 | token_to_idx[token] = len(token_to_idx)
92 |
93 | return token_to_idx
94 |
95 | def encode(seq_tokens, token_to_idx, allow_unk=False):
96 | seq_idx = []
97 | for token in seq_tokens:
98 | if token not in token_to_idx:
99 | if allow_unk:
100 | token = '<UNK>'
101 | else:
102 | raise KeyError('Token "%s" not in vocab' % token)
103 | seq_idx.append(token_to_idx[token])
104 | return seq_idx
105 |
106 | def decode(seq_idx, idx_to_token, delim=None, stop_at_end=True):
107 | tokens = []
108 | for idx in seq_idx:
109 | tokens.append(idx_to_token[idx])
110 | if stop_at_end and tokens[-1] == '<END>':
111 | break
112 | if delim is None:
113 | return tokens
114 | else:
115 | return delim.join(tokens)
116 |
117 | SPECIAL_TOKENS = {
118 | '<PAD>': 0,
119 | '<START>': 1,
120 | '<END>': 2,
121 | '<UNK>': 3,
122 | }
123 |
124 | def process_europarl(input_data_dir, train_test_split=0.9, njobs=1):
125 | sentences = []
126 | print('Preprocess Raw Text')
127 | from joblib import Parallel, delayed
128 | sentences = Parallel(n_jobs=njobs, verbose=1)(
129 | delayed(process_text_file)(fn)
130 | for fn in pathlib.Path(input_data_dir).glob('*.txt'))
131 | sentences = [s for s_list in sentences for s in s_list ]
132 |
133 | # remove duplicate sentences (preserving order)
134 | seen = {}
135 | for s in sentences:
136 | if s not in seen:
137 | seen[s] = 0
138 | seen[s] += 1
139 | sentences = list(seen.keys())
140 | print('Number of sentences: {}'.format(len(sentences)))
141 |
142 | print('Build Vocab')
143 | token_to_idx = build_vocab(
144 | sentences, SPECIAL_TOKENS,
145 | punct_to_keep=[';', ','], punct_to_remove=['?', '.']
146 | )
147 |
148 | vocab = {'token_to_idx': token_to_idx}
149 | print('Number of words in Vocab: {}'.format(len(token_to_idx)))
150 |
151 | print('Start encoding txt')
152 | results = []
153 | for seq in tqdm(sentences):
154 | words = tokenize(seq, punct_to_keep=[';', ','], punct_to_remove=['?', '.'])
155 | tokens = [token_to_idx[word] for word in words]
156 | results.append(tokens)
157 |
158 | train_data = results[: round(len(results) * train_test_split)]
159 | test_data = results[round(len(results) * train_test_split):]
160 |
161 | return train_data, test_data, vocab
162 | # End of the copied code
163 |
164 | class Tokenizer:
165 |
166 | TOKENS_FILTERED = set([
167 | '<START>', '<END>'
168 | ])
169 |
170 | def __init__(self, vocab):
171 | idx_to_token = [None for _ in range(1 + max(vocab['token_to_idx'].values()))]
172 | for token, idx in vocab['token_to_idx'].items():
173 | idx_to_token[idx] = token
174 | self.idx_to_token = idx_to_token
175 | self.token_to_idx = vocab['token_to_idx']
176 |
177 | def decode(self, token_ids):
178 | tokens = map(lambda i: self.idx_to_token[i], token_ids)
179 | tokens = filter(lambda t: t not in self.TOKENS_FILTERED, tokens)
180 | return ' '.join(tokens)
181 |
182 | def batch_decode(self, token_ids_list):
183 | return list(map(lambda token_ids: self.decode(token_ids), token_ids_list))
184 |
185 | def gen_hf_dataset(path: pathlib.Path, output_path=None, train_test_split=0.9, njobs=1):
186 | path = pathlib.Path(path)
187 | if output_path is None:
188 | output_path = path
189 |
190 | train_data, test_data, vocab = process_europarl(path, train_test_split, njobs)
191 |
192 | # save processed sentences
193 | with open(output_path / 'train_data.pkl', 'wb') as f:
194 | pickle.dump(train_data, f)
195 | with open(output_path / 'test_data.pkl', 'wb') as f:
196 | pickle.dump(test_data, f)
197 | with open(output_path / 'vocab.json', 'w') as f:
198 | json.dump(vocab, f)
199 |
200 | tokenizer = Tokenizer(vocab)
201 |
202 | # train set
203 | # import random
204 | # num_train = 3000
205 | # num_test = 1000
206 | train_data = tokenizer.batch_decode(train_data)
207 | train_gen = HFDataGenerator()
208 | # train_data = random.sample(train_data, num_train)
209 | train_gen.add(train_data, train_data)
210 | train_gen.dump(output_path / 'train.csv')
211 |
212 | # train json data
213 | with open(output_path / 'train.json', 'w') as f:
214 | json_data = [{ 'input': s, 'refs': [s] } for s in train_data]
215 | json.dump(json_data, f, indent=4)
216 |
217 | # test set
218 | test_data = tokenizer.batch_decode(test_data)
219 | test_gen = HFDataGenerator()
220 | # test_data = random.sample(test_data, num_test)
221 | test_gen.add(test_data, test_data)
222 | test_gen.dump(output_path / 'test.csv')
223 |
224 | # test json data
225 | with open(output_path / 'test.json', 'w') as f:
226 | json_data = [{ 'input': s, 'refs': [s] } for s in test_data]
227 | json.dump(json_data, f, indent=4)
228 |
229 |
230 | if __name__ == '__main__':
231 | parser = argparse.ArgumentParser("Preprocess Europarl data. It generates the same dataset used in DeepSC")
232 | parser.add_argument(
233 | '-o', '--out-path',
234 | dest='out_path',
235 | required=True,
236 | type=pathlib.Path,
237 | help='Required. Path of output files.')
238 | parser.add_argument(
239 | '--train-test-split',
240 | dest='train_test_split',
241 | default=0.9,
242 | type=float,
243 | help='Train/test split ratio')
244 | parser.add_argument(
245 | '-j',
246 | '--njobs',
247 | dest='njobs',
248 | default=1,
249 | type=int,
250 | help='Number of threads to be used for preprocessing')
251 | parser.add_argument(
252 | dest='path',
253 | type=pathlib.Path,
254 | help="Path of the Europarl dataset; it should point to the 'txt/en' directory")
255 | args = parser.parse_args()
256 | gen_hf_dataset(args.path, args.out_path, args.train_test_split, args.njobs)
257 |
258 |
--------------------------------------------------------------------------------
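For reference, an invocation consistent with the relative import of hf_data_gen and the processed-data paths in scripts/train.sh would be run from the repository root as 'python -m preprocess.europarl -j 4 -o data/europarl/processed data/europarl/txt/en'; the raw-text input directory is an assumption, since only the processed output path appears in the training script, and the output directory must exist beforehand.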
/preprocess/flickr30k.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import json
3 | import random
4 | import argparse
5 | import pathlib
6 |
7 | def parse_line(line: str):
8 | key = line.split('#')[0]
9 | caption = line.split('\t')[-1].strip()
10 | return key, caption
11 |
12 | if __name__ == '__main__':
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('-o', '--out-path',
15 | dest='out_path',
16 | type=pathlib.Path,
17 | help='output json file path')
18 | parser.add_argument('-n', '--num-samples',
19 | dest='N',
20 | default=1000,
21 | type=int,
22 | help='number of samples (default: 1000)')
23 | parser.add_argument('--seed',
24 | dest='seed',
25 | default=20221017,
26 | type=int,
27 | help='seed for random module')
28 |
29 | parser.add_argument(
30 | dest='token_path',
31 | help="path of 'results_20130124.token'")
32 | args = parser.parse_args()
33 |
34 | # read token file
35 | data = defaultdict(list)
36 | with open(args.token_path) as f:
37 | for k, caption in map(parse_line, f):
38 | data[k].append(caption)
39 | data = list(data.values())
40 |
41 | # set seed
42 | random.seed(args.seed)
43 |
44 | # sample dataset
45 | samples = random.sample(range(len(data)), k=args.N)
46 | out_data = []
47 | for i in samples:
48 | captions = data[i]
49 | input_idx = random.sample(range(len(captions)), k=1)[0]
50 | input_sentence = captions[input_idx]
51 | ref_sentences = captions[:input_idx] + captions[(input_idx+1):]
52 | out_data.append({
53 | 'input': input_sentence,
54 | 'refs': ref_sentences,
55 | })
56 |
57 | with open(args.out_path, 'w') as f:
58 | json.dump(out_data, f, indent=4)
--------------------------------------------------------------------------------
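For reference, an invocation matching the test-set path consumed by scripts/eval.sh would be 'python preprocess/flickr30k.py -o data/flickr/processed/flickr30k.json -n 1000 results_20130124.token', where the location of the Flickr30k token file is an assumption.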
/preprocess/hf_data_gen.py:
--------------------------------------------------------------------------------
1 | # example file format
2 | # --------------------
3 | # text,summary
4 | # "I'm sitting here in a boring room. It's just another rainy Sunday afternoon. I'm wasting my time I got nothing to do. I'm hanging around I'm waiting for you. But nothing ever happens. And I wonder","I'm sitting in a room where I'm waiting for something to happen"
5 | # "I see trees so green, red roses too. I see them bloom for me and you. And I think to myself what a wonderful world. I see skies so blue and clouds so white. The bright blessed day, the dark sacred night. And I think to myself what a wonderful world.","I'm a gardener and I'm a big fan of flowers."
6 | # "Christmas time is here. Happiness and cheer. Fun for all that children call. Their favorite time of the year. Snowflakes in the air. Carols everywhere. Olden times and ancient rhymes. Of love and dreams to share","It's that time of year again."
7 | import csv
8 |
9 | class HFDataGenerator:
10 |
11 | FIELDNAMES = ['text', 'summary']
12 |
13 | def __init__(self):
14 | self.rows = []
15 |
16 | def add(self, text, summary):
17 | if isinstance(text, str) and isinstance(summary, str):
18 | self.rows.append({
19 | 'text': f'{text}',
20 | 'summary': f'{summary}',
21 | })
22 | elif isinstance(text, list) and isinstance(summary, list):
23 | assert len(text) == len(summary), "Different text and summary list"
24 | rows = map(lambda ts: {'text': ts[0], 'summary': ts[1]}, zip(text, summary))
25 | self.rows.extend(rows)
26 | else:
27 | raise TypeError("text and summary must both be str or both be list")
28 |
29 |
30 | def dump(self, filename):
31 | with open(filename, 'w', newline='') as f:
32 | writer = csv.DictWriter(f, fieldnames=self.FIELDNAMES, quoting=csv.QUOTE_ALL)
33 | writer.writeheader()
34 | writer.writerows(self.rows)
35 |
36 | def test():
37 | data_path = "/tmp/hf_data_test.csv"
38 | gen = HFDataGenerator()
39 | gen.add("text1", "summary1")
40 | gen.add("text2", "summary2")
41 | gen.dump(data_path)
42 |
43 | expected = """"text","summary"
44 | "text1","summary1"
45 | "text2","summary2"
46 | """
47 | with open(data_path) as f:
48 | actual = f.read()
49 | assert actual == expected, f"""'{actual}'"""
50 |
51 | if __name__ == '__main__':
52 | test()
53 |
54 |
--------------------------------------------------------------------------------
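A small sketch (assumed usage) of how these CSVs are consumed downstream: train.py passes the file extension to datasets.load_dataset, so the output of HFDataGenerator.dump() loads directly.

    from datasets import load_dataset

    # Uses the file written by test() above.
    ds = load_dataset("csv", data_files={"train": "/tmp/hf_data_test.csv"})
    print(ds["train"][0])   # {'text': 'text1', 'summary': 'summary1'}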
/scripts/eval.sh:
--------------------------------------------------------------------------------
1 | eval_ebno_db="4"
2 | metric="sbert" # bleu, sbert
3 | testset_path='data/flickr/processed/flickr30k.json'
4 |
5 | checkpoint_dir='checkpoints/on-device-ai-comm/train_CDL-A_ebnodb_5_15'
6 | output_dir='checkpoints/on-device-ai-comm/train_CDL-A_ebnodb_5_15/CDL-A'
7 |
8 | mkdir -p $output_dir
9 |
10 | fec_type="Polar5G" # Polar5G, LDPC5G
11 | fec_num_iter=20
12 | channel_num_tx_ant="2"
13 | channel_num_rx_ant="2"
14 | num_bits_per_symbol="4"
15 | EVAL_NUM_BEAMS="1"
16 |
17 | python eval.py \
18 | -m "${metric}" \
19 | -b 8 \
20 | -e "${eval_ebno_db}" \
21 | --result-json-path "${output_dir}/flickr_${metric}_${eval_ebno_db}dB_${fec_type}_${channel_num_tx_ant}_${channel_num_rx_ant}_${num_bits_per_symbol}.json" \
22 | --prediction-json-path "${output_dir}/flickr_prediction_${eval_ebno_db}dB_${fec_type}_${channel_num_tx_ant}_${channel_num_rx_ant}_${num_bits_per_symbol}.json" \
23 | --fec-type "${fec_type}" \
24 | --fec-num-iter "${fec_num_iter}" \
25 | --channel-type "CDL" \
26 | --cdl-model "A" \
27 | --channel-num-tx-ant "${channel_num_tx_ant}" \
28 | --channel-num-rx-ant "${channel_num_rx_ant}" \
29 | --num-bits-per-symbol "${num_bits_per_symbol}" \
30 | --bin-conv-method "vector_quantization" \
31 | --embedding-dim 2 \
32 | --num-embeddings 1024 \
33 | --num-beams "${EVAL_NUM_BEAMS}" \
34 | --testset-path "${testset_path}" \
35 | $checkpoint_dir
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | output_dir='checkpoints/on-device-ai-comm/train_CDL-A_ebnodb_5_15'
2 | trainset_path='data/europarl/processed/train.csv'
3 | devset_path='data/europarl/processed/test.csv'
4 |
5 | mkdir -p $output_dir
6 |
7 | python train.py \
8 | --model_name_or_path facebook/bart-base \
9 | --config_name facebook/bart-base \
10 | --tokenizer_name facebook/bart-base \
11 | --train_file "$trainset_path" \
12 | --validation_file "$devset_path" \
13 | --test_file "$devset_path" \
14 | --preprocessing_num_workers 4 \
15 | --per_device_train_batch_size 4 \
16 | --per_device_eval_batch_size 4 \
17 | --num_train_epochs 3 \
18 | --do_train \
19 | --do_eval \
20 | --save_total_limit 1 \
21 | --no_use_fast_tokenizer \
22 | --num_beams 1 \
23 | --pad_to_max_length \
24 | --overwrite_output_dir \
25 | --max_source_length 64 \
26 | --max_target_length 64 \
27 | --output_dir $output_dir \
28 | --ebno_db_min 5 \
29 | --ebno_db_max 15 \
30 | --channel_type "CDL" \
31 | --fec_type "Polar5G" \
32 | --fec_num_iter 20 \
33 | --cdl_model "A" \
34 | --channel_num_tx_ant "2" \
35 | --channel_num_rx_ant "2" \
36 | --num_bits_per_symbol "4" \
37 | --bin_conv_method "vector_quantization" \
38 | --embedding_dim 2 \
39 | --num_embeddings 1024 \
40 | --dataset_config 3.0.0
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2021 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # This file is based on
18 | # https://github.com/huggingface/transformers/blob/main/examples/tensorflow/summarization/run_summarization.py
19 | import json
20 | import logging
21 | import os
22 | import sys
23 | from dataclasses import dataclass, field
24 | from typing import Optional, List
25 |
26 | import datasets
27 | import nltk
28 | import numpy as np
29 | import tensorflow as tf
30 | from datasets import load_dataset
31 | import matplotlib.pyplot as plt
32 |
33 | import evaluate
34 | import transformers
35 | from filelock import FileLock
36 | from transformers import (
37 | AutoConfig,
38 | AutoTokenizer,
39 | DataCollatorForSeq2Seq,
40 | HfArgumentParser,
41 | KerasMetricCallback,
42 | TFTrainingArguments,
43 | set_seed,
44 | )
45 | from tensorflow.keras.callbacks import ModelCheckpoint
46 | from transformers.trainer_utils import get_last_checkpoint
47 | from transformers.utils import is_offline_mode
48 | from transformers.optimization_tf import create_optimizer
49 | from train.args import ModelArguments, DataTrainingArguments, summarization_name_mapping, Seq2SeqSCArguments
50 | from models.on_device_ai_comm import TFOnDeviceAICForConditionalGeneration
51 |
52 | logger = logging.getLogger(__name__)
53 |
54 | from datetime import datetime
55 |
56 | try:
57 | nltk.data.find("tokenizers/punkt")
58 | except (LookupError, OSError):
59 | if is_offline_mode():
60 | raise LookupError(
61 | "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
62 | )
63 | with FileLock(".lock") as lock:
64 | nltk.download("punkt", quiet=True)
65 |
66 | def main():
67 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments, Seq2SeqSCArguments))
68 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
69 | # If we pass only one argument to the script and it's the path to a json file,
70 | # let's parse it to get our arguments.
71 | model_args, data_args, training_args, seq2seq_sc_args = parser.parse_json_file(
72 | json_file=os.path.abspath(sys.argv[1]))
73 | else:
74 | model_args, data_args, training_args, seq2seq_sc_args = parser.parse_args_into_dataclasses()
75 |
76 | if training_args.fp16:
77 | policy = tf.keras.mixed_precision.Policy('mixed_float16')
78 | tf.keras.mixed_precision.set_global_policy(policy)
79 |
80 | current_time = datetime.now()
81 | current_time = current_time.strftime("%Y-%m-%d_%H-%M-%S")
82 | sys.stdout = open(f'./{training_args.output_dir}/train_stdout-{current_time}.log','a')
83 | sys.stderr = open(f'./{training_args.output_dir}/train_stderr-{current_time}.log','a')
84 |
85 | # region Logging
86 | logging.basicConfig(
87 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
88 | datefmt="%m/%d/%Y %H:%M:%S",
89 | handlers=[logging.StreamHandler(sys.stdout)],
90 | )
91 | logger.setLevel(logging.INFO)
92 | datasets.utils.logging.set_verbosity(logging.INFO)
93 | transformers.utils.logging.set_verbosity(logging.INFO)
94 |
95 | # Log on each process the small summary:
96 | logger.info(f"Training/evaluation parameters {training_args}")
97 | # endregion
98 |
99 | # region Detecting last checkpoint
100 | last_checkpoint = None
101 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
102 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
103 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
104 | raise ValueError(
105 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
106 | "Use --overwrite_output_dir to overcome."
107 | )
108 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
109 | logger.info(
110 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
111 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
112 | )
113 | # endregion
114 |
115 | # Set seed before initializing model.
116 | set_seed(training_args.seed)
117 |
118 | # region Load datasets
119 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
120 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
121 | # (the dataset will be downloaded automatically from the datasets Hub).
122 | #
123 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the
124 | # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
125 | #
126 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently
127 | # download the dataset.
128 | if data_args.dataset_name is not None:
129 | # Downloading and loading a dataset from the hub.
130 | raw_datasets = load_dataset(
131 | data_args.dataset_name,
132 | data_args.dataset_config_name,
133 | cache_dir=model_args.cache_dir,
134 | use_auth_token=True if model_args.use_auth_token else None,
135 | )
136 | else:
137 | data_files = {}
138 | if data_args.train_file is not None:
139 | data_files["train"] = data_args.train_file
140 | extension = data_args.train_file.split(".")[-1]
141 | if data_args.validation_file is not None:
142 | data_files["validation"] = data_args.validation_file
143 | extension = data_args.validation_file.split(".")[-1]
144 | if data_args.test_file is not None:
145 | data_files["test"] = data_args.test_file
146 | extension = data_args.test_file.split(".")[-1]
147 | raw_datasets = load_dataset(
148 | extension,
149 | data_files=data_files,
150 | cache_dir=model_args.cache_dir,
151 | use_auth_token=True if model_args.use_auth_token else None,
152 | )
153 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
154 | # https://huggingface.co/docs/datasets/loading_datasets.html.
155 | # endregion
156 |
157 | # region Load model config and tokenizer
158 | #
159 | # Distributed training:
160 | # The .from_pretrained methods guarantee that only one local process can concurrently
161 | # download model & vocab.
162 |
163 | config = AutoConfig.from_pretrained(
164 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
165 | cache_dir=model_args.cache_dir,
166 | revision=model_args.model_revision,
167 | use_auth_token=True if model_args.use_auth_token else None,
168 | )
169 | tokenizer = AutoTokenizer.from_pretrained(
170 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
171 | cache_dir=model_args.cache_dir,
172 | use_fast=model_args.use_fast_tokenizer,
173 | revision=model_args.model_revision,
174 | use_auth_token=True if model_args.use_auth_token else None,
175 | )
176 |
177 | prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
178 | # endregion
179 |
180 | # region Dataset preprocessing
181 | # We need to tokenize inputs and targets.
182 | if training_args.do_train:
183 | column_names = raw_datasets["train"].column_names
184 | elif training_args.do_eval:
185 | column_names = raw_datasets["validation"].column_names
186 | else:
187 | logger.info("There is nothing to do. Please pass `do_train`, and/or `do_eval`.")
188 | return
189 |
190 | # Get the column names for input/target.
191 | dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
192 | if data_args.text_column is None:
193 | text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
194 | else:
195 | text_column = data_args.text_column
196 | if text_column not in column_names:
197 | raise ValueError(
198 | f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
199 | )
200 | if data_args.summary_column is None:
201 | summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
202 | else:
203 | summary_column = data_args.summary_column
204 | if summary_column not in column_names:
205 | raise ValueError(
206 | f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
207 | )
208 |
209 | # Temporarily set max_target_length for training.
210 | max_target_length = data_args.max_target_length
211 | padding = "max_length" if data_args.pad_to_max_length else False
212 |
213 | def preprocess_function(examples):
214 | inputs = examples[text_column]
215 | assert prefix is not None
216 | for i, inp in enumerate(inputs):
217 | if inp is None:
218 | print(i, inputs[i], inputs[i-1], inputs[i+1])
219 | targets = examples[summary_column]
220 | inputs = [prefix + inp for inp in inputs]
221 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
222 |
223 | # Tokenize targets with the `text_target` keyword argument
224 | labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
225 |
226 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
227 | # padding in the loss.
228 | if padding == "max_length" and data_args.ignore_pad_token_for_loss:
229 | labels["input_ids"] = [
230 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
231 | ]
232 |
233 | model_inputs["labels"] = labels["input_ids"]
234 | return model_inputs
235 |
236 | if training_args.do_train:
237 | if "train" not in raw_datasets:
238 | raise ValueError("--do_train requires a train dataset")
239 | train_dataset = raw_datasets["train"]
240 | if data_args.max_train_samples is not None:
241 | max_train_samples = min(len(train_dataset), data_args.max_train_samples)
242 | train_dataset = train_dataset.select(range(max_train_samples))
243 | with training_args.main_process_first(desc="train dataset map pre-processing"):
244 | train_dataset = train_dataset.map(
245 | preprocess_function,
246 | batched=True,
247 | num_proc=data_args.preprocessing_num_workers,
248 | remove_columns=column_names,
249 | load_from_cache_file=not data_args.overwrite_cache,
250 | desc="Running tokenizer on train dataset",
251 | )
252 | else:
253 | train_dataset = None
254 |
255 | if training_args.do_eval:
256 | max_target_length = data_args.val_max_target_length
257 | if "validation" not in raw_datasets:
258 | raise ValueError("--do_eval requires a validation dataset")
259 | eval_dataset = raw_datasets["validation"]
260 | if data_args.max_eval_samples is not None:
261 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
262 | eval_dataset = eval_dataset.select(range(max_eval_samples))
263 | with training_args.main_process_first(desc="validation dataset map pre-processing"):
264 | eval_dataset = eval_dataset.map(
265 | preprocess_function,
266 | batched=True,
267 | num_proc=data_args.preprocessing_num_workers,
268 | remove_columns=column_names,
269 | load_from_cache_file=not data_args.overwrite_cache,
270 | desc="Running tokenizer on validation dataset",
271 | )
272 | else:
273 | eval_dataset = None
274 | # endregion
275 |
276 | # region Text preprocessing
277 | def postprocess_text(preds, labels):
278 | preds = [pred.strip() for pred in preds]
279 | labels = [label.strip() for label in labels]
280 |
281 | # rougeLSum expects newline after each sentence
282 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
283 | labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
284 |
285 | return preds, labels
286 |
287 |
288 | # endregion
289 | with training_args.strategy.scope():
290 | # region Prepare model
291 |         if not model_args.model_name_or_path:  # covers both None and the empty string
292 | # Create a BART model with random initialization
293 | model = TFOnDeviceAICForConditionalGeneration(
294 | ebno_db=seq2seq_sc_args.ebno_db,
295 | ebno_db_min=seq2seq_sc_args.ebno_db_min,
296 | ebno_db_max=seq2seq_sc_args.ebno_db_max,
297 | fec_k=seq2seq_sc_args.k,
298 | fec_n=seq2seq_sc_args.n,
299 | num_bits_per_symbol=seq2seq_sc_args.num_bits_per_symbol,
300 | channel_type=seq2seq_sc_args.channel_type,
301 | cdl_model=seq2seq_sc_args.cdl_model,
302 | channel_num_tx_ant=seq2seq_sc_args.channel_num_tx_ant,
303 | channel_num_rx_ant=seq2seq_sc_args.channel_num_rx_ant,
304 | fec_type=seq2seq_sc_args.fec_type,
305 | bin_conv_method=seq2seq_sc_args.bin_conv_method,
306 | embedding_dim=seq2seq_sc_args.embedding_dim,
307 | num_embeddings=seq2seq_sc_args.num_embeddings,
308 | do_train=training_args.do_train,
309 | config=config,
310 | )
311 |             print('Randomly initialized without pre-trained weights.')
312 | else:
313 | model_cls = TFOnDeviceAICForConditionalGeneration
314 | # https://huggingface.co/docs/transformers/v4.34.0/en/main_classes/model#transformers.TFPreTrainedModel.from_pretrained
315 | model = model_cls.from_pretrained(
316 | model_args.model_name_or_path,
317 | ebno_db=seq2seq_sc_args.ebno_db,
318 | ebno_db_min=seq2seq_sc_args.ebno_db_min,
319 | ebno_db_max=seq2seq_sc_args.ebno_db_max,
320 | fec_k=seq2seq_sc_args.k,
321 | fec_n=seq2seq_sc_args.n,
322 | num_bits_per_symbol=seq2seq_sc_args.num_bits_per_symbol,
323 | channel_type=seq2seq_sc_args.channel_type,
324 | cdl_model=seq2seq_sc_args.cdl_model,
325 | channel_num_tx_ant=seq2seq_sc_args.channel_num_tx_ant,
326 | channel_num_rx_ant=seq2seq_sc_args.channel_num_rx_ant,
327 | fec_type=seq2seq_sc_args.fec_type,
328 | bin_conv_method=seq2seq_sc_args.bin_conv_method,
329 | embedding_dim=seq2seq_sc_args.embedding_dim,
330 | num_embeddings=seq2seq_sc_args.num_embeddings,
331 | do_train=training_args.do_train,
332 | config=config,
333 | cache_dir=model_args.cache_dir,
334 | revision=model_args.model_revision,
335 | use_auth_token=True if model_args.use_auth_token else None,
336 | )
337 | print(f'Loaded {model_args.model_name_or_path}')
338 |
339 | model.resize_token_embeddings(len(tokenizer))
340 |
341 | # endregion
342 |
343 | # region Prepare TF Dataset objects
344 | if model.config.decoder_start_token_id is None:
345 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
346 |
347 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
348 | data_collator = DataCollatorForSeq2Seq(
349 | tokenizer,
350 | model=model,
351 | label_pad_token_id=label_pad_token_id,
352 | pad_to_multiple_of=128, # Reduce the number of unique shapes for XLA, especially for generation
353 | return_tensors="tf",
354 | )
355 |
356 | dataset_options = tf.data.Options()
357 | dataset_options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
358 |
359 | num_replicas = training_args.strategy.num_replicas_in_sync
360 | total_train_batch_size = training_args.per_device_train_batch_size * num_replicas
361 | total_eval_batch_size = training_args.per_device_eval_batch_size * num_replicas
362 |
363 | # model.prepare_tf_dataset() wraps a Hugging Face dataset in a tf.data.Dataset which is ready to use in
364 | # training. This is the recommended way to use a Hugging Face dataset when training with Keras. You can also
365 | # use the lower-level dataset.to_tf_dataset() method, but you will have to specify things like column names
366 | # yourself if you use this method, whereas they are automatically inferred from the model input names when
367 | # using model.prepare_tf_dataset()
368 | # For more info see the docs:
369 | # https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.TFPreTrainedModel.prepare_tf_dataset
370 | # https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.to_tf_dataset
371 |
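        # A minimal sketch of the lower-level alternative (illustrative only; the
        # column names below are assumptions, not values taken from this script):
        #
        #   tf_train_dataset = train_dataset.to_tf_dataset(
        #       columns=["input_ids", "attention_mask"],
        #       label_cols=["labels"],
        #       batch_size=total_train_batch_size,
        #       shuffle=True,
        #       collate_fn=data_collator,
        #   ).with_options(dataset_options)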
372 |         tf_train_dataset = model.prepare_tf_dataset(
373 |             train_dataset,
374 |             collate_fn=data_collator,
375 |             batch_size=total_train_batch_size,
376 |             shuffle=True,
377 |         ).with_options(dataset_options) if train_dataset is not None else None
378 |         tf_eval_dataset = model.prepare_tf_dataset(
379 |             eval_dataset,
380 |             collate_fn=data_collator,
381 |             batch_size=total_eval_batch_size,
382 |             shuffle=False,
383 |         ).with_options(dataset_options) if eval_dataset is not None else None
384 | # endregion
385 |
386 | # region Optimizer, loss and LR scheduling
387 |         num_train_steps = int(len(tf_train_dataset) * training_args.num_train_epochs) if tf_train_dataset is not None else 0
388 | if training_args.warmup_steps > 0:
389 | num_warmup_steps = training_args.warmup_steps
390 | elif training_args.warmup_ratio > 0:
391 | num_warmup_steps = int(num_train_steps * training_args.warmup_ratio)
392 | else:
393 | num_warmup_steps = 0
394 | if training_args.do_train:
395 | optimizer, lr_schedule = create_optimizer(
396 | init_lr=training_args.learning_rate,
397 | num_train_steps=num_train_steps,
398 | num_warmup_steps=num_warmup_steps,
399 | adam_beta1=training_args.adam_beta1,
400 | adam_beta2=training_args.adam_beta2,
401 | adam_epsilon=training_args.adam_epsilon,
402 | weight_decay_rate=training_args.weight_decay,
403 | adam_global_clipnorm=training_args.max_grad_norm,
404 | )
405 | else:
406 | optimizer = None
407 |
408 | # endregion
409 |
410 | # region Metric and KerasMetricCallback
411 | if training_args.do_eval:
412 | metric = evaluate.load("rouge")
413 |
414 | if data_args.val_max_target_length is None:
415 | data_args.val_max_target_length = data_args.max_target_length
416 |
417 | gen_kwargs = {
418 | "max_length": data_args.val_max_target_length if data_args is not None else config.max_length,
419 | "num_beams": data_args.num_beams,
420 | "no_repeat_ngram_size": 0, # Not supported under XLA right now, and some models set it by default
421 | }
422 |
423 | def compute_metrics(preds):
424 | predictions, labels = preds
425 | if isinstance(predictions, tuple):
426 | predictions = predictions[0]
427 | decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
428 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
429 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
430 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
431 | metrics = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
432 |                 # `metric.compute` returns the aggregate (mid) F-measures in [0, 1]; scale them to percentages.
433 | metrics = {key: round(val * 100, 4) for key, val in metrics.items()}
434 | return metrics
435 |
436 | # The KerasMetricCallback allows metrics that are too complex to write as standard Keras metrics
437 | # to be computed each epoch. Any Python code can be included in the metric_fn. This is especially
438 | # useful for metrics like BLEU and ROUGE that perform string comparisons on decoded model outputs.
439 | # For more information, see the docs at
440 | # https://huggingface.co/docs/transformers/main_classes/keras_callbacks#transformers.KerasMetricCallback
441 |
442 | metric_callback = KerasMetricCallback(
443 | metric_fn=compute_metrics,
444 | eval_dataset=tf_eval_dataset,
445 | predict_with_generate=True,
446 | use_xla_generation=True,
447 | generate_kwargs=gen_kwargs,
448 | )
449 | callbacks = [metric_callback]
450 |
451 | # Define a callback to save the model
452 | checkpoint_callback = ModelCheckpoint(
453 |                 filepath=training_args.output_dir + '/model_epoch_{epoch:02d}_loss_{loss:.4f}.h5',
454 | monitor='loss', # metric name
455 | save_best_only=True, # Save only if the monitored quantity improves
456 | save_weights_only=True, # Only save the model weights
457 | verbose=1 # Print more information
458 | )
459 |
460 | # Add the checkpoint callback to list of callbacks
461 | callbacks.append(checkpoint_callback)
462 | else:
463 | callbacks = []
464 | # endregion
465 |
466 | # region Training
467 | model.compile(optimizer=optimizer, jit_compile=training_args.xla)
468 |         eval_metrics = history = None  # `history` stays None unless training runs below
469 | if training_args.do_train:
470 | logger.info("***** Running training *****")
471 | logger.info(f" Num examples = {len(train_dataset)}")
472 | logger.info(f" Num Epochs = {training_args.num_train_epochs}")
473 | logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
474 | logger.info(f" Total train batch size = {total_train_batch_size}")
475 | logger.info(f" Total optimization steps = {num_train_steps}")
476 |
477 | if training_args.xla and not data_args.pad_to_max_length:
478 | logger.warning(
479 | "XLA training may be slow at first when --pad_to_max_length is not set "
480 | "until all possible shapes have been compiled."
481 | )
482 | history = model.fit(tf_train_dataset, epochs=int(training_args.num_train_epochs), callbacks=callbacks)
483 |             # e.g., after 3 epochs, history.history == {'loss': [3.3186378479003906, 1.8291417360305786, 1.5430701971054077], 'rouge1': [49.4591, 57.3391, 66.4159], 'rouge2': [36.3391, 50.6601, 60.5605], 'rougeL': [49.1986, 58.1554, 65.7636], 'rougeLsum': [49.112, 58.124, 65.7636]}
484 | eval_metrics = {key: val[-1] for key, val in history.history.items()}
485 | # endregion
486 |
487 | # region Validation
488 |
489 | if training_args.do_eval and not training_args.do_train:
490 | # Do a standalone evaluation run
491 | logger.info("Evaluation...")
492 |
493 | # Compiling generation with XLA yields enormous speedups, see https://huggingface.co/blog/tf-xla-generate
494 | @tf.function(jit_compile=True)
495 | def generate(**kwargs):
496 | return model.generate(**kwargs)
497 |
498 | for batch, labels in tf_eval_dataset:
499 | batch.update(gen_kwargs)
500 | generated_tokens = generate(**batch)
501 | if isinstance(generated_tokens, tuple):
502 | generated_tokens = generated_tokens[0]
503 | decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
504 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
505 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
506 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
507 |
508 | metric.add_batch(predictions=decoded_preds, references=decoded_labels)
509 |
510 | eval_metrics = metric.compute(use_stemmer=True)
511 |
512 | result = {key: round(val * 100, 4) for key, val in eval_metrics.items()}
513 | logger.info(result)
514 | # endregion
515 |
516 | if training_args.output_dir is not None and eval_metrics is not None:
517 | output_eval_file = os.path.join(training_args.output_dir, "all_results.json")
518 | with open(output_eval_file, "w") as writer:
519 | writer.write(json.dumps(eval_metrics))
520 |
521 |         if history is not None and 'rouge1' in history.history:  # ROUGE curves exist only when the metric callback ran
522 | ### --- Draw plot --- ###
523 | loss = history.history['loss']
524 | rouge1 = history.history['rouge1']
525 | rouge2 = history.history['rouge2']
526 | rougeL = history.history['rougeL']
527 | rougeLsum = history.history['rougeLsum']
528 |
529 | # Create an array of epoch numbers
530 | epochs = np.arange(1, len(loss) + 1)
531 |
532 | # Plotting the validation loss
533 | fig, ax1 = plt.subplots(figsize=(10, 6))
534 | color = 'tab:blue'
535 | ax1.set_xlabel('Epochs')
536 | ax1.set_ylabel('Loss', color=color)
537 | ax1.plot(epochs, loss, label='Loss', color=color)
538 | ax1.tick_params(axis='y', labelcolor=color)
539 |
540 | # Set y-axis range for validation loss
541 | ax1.set_ylim([0, max(loss)+0.2])
542 |
543 | # Creating a secondary y-axis for ROUGE scores
544 | ax2 = ax1.twinx()
545 |
546 | color = 'tab:green'
547 | ax2.set_ylabel('ROUGE Scores', color=color)
548 | ax2.plot(epochs, rouge1, label='ROUGE-1', color=color)
549 | ax2.plot(epochs, rouge2, label='ROUGE-2', color='tab:red')
550 | ax2.plot(epochs, rougeL, label='ROUGE-L', color='tab:purple')
551 | ax2.plot(epochs, rougeLsum, label='ROUGE-Lsum', color='tab:orange')
552 | ax2.tick_params(axis='y', labelcolor=color)
553 | # Set y-axis range for ROUGE scores
554 | ax2.set_ylim([min(rouge1) - 10, 100])
555 |
556 | # Add legend and grid
557 | fig.tight_layout()
558 | fig.legend(loc='lower right')
559 | plt.grid(True)
560 |
561 | # Save the plot
562 | output_plot_file = os.path.join(training_args.output_dir, "validation_loss_rouge_scores.png")
563 | plt.savefig(output_plot_file)
564 |
565 |
566 | if training_args.output_dir is not None and not training_args.push_to_hub:
567 | # If we're not pushing to hub, at least save a local copy when we're done
568 | model.save_pretrained(training_args.output_dir)
569 |
570 |
571 | if __name__ == "__main__":
572 | main()
--------------------------------------------------------------------------------
/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abman23/on-device-ai-comm/c7b579ddfd88c5423ce9f9cb2e3e190ac89f8fdc/train/__init__.py
--------------------------------------------------------------------------------
/train/args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 |
4 | @dataclass
5 | class Seq2SeqSCArguments:
6 |     ebno_db: Optional[float] = field(default=None, metadata={"help": "a single SNR value (Eb/No in dB)"})
7 |     ebno_db_min: Optional[float] = field(default=None, metadata={"help": "lower bound of the Eb/No (dB) range used in training"})
8 |     ebno_db_max: Optional[float] = field(default=None, metadata={"help": "upper bound of the Eb/No (dB) range used in training"})
9 |     k: Optional[int] = field(default=512, metadata={"help": "K (number of information bits) for the channel encoder/decoder"})
10 |     n: Optional[int] = field(default=1024, metadata={"help": "N (codeword length) for the channel encoder/decoder"})
11 |     fec_num_iter: Optional[int] = field(default=6, metadata={"help": "number of decoding iterations for FEC"})
12 |     num_bits_per_symbol: Optional[int] = field(default=4, metadata={"help": "number of bits per modulation symbol"})
13 |     channel_type: Optional[str] = field(default='AWGN', metadata={"help": "AWGN, CDL, or 3GPP-38.901"})
14 |     cdl_model: Optional[str] = field(default='A', metadata={"help": "CDL channel model variant: A, B, C, D, or E"})
15 |     channel_num_tx_ant: Optional[int] = field(default=1, metadata={"help": "number of tx antennas for FlatFadingChannel"})
16 |     channel_num_rx_ant: Optional[int] = field(default=1, metadata={"help": "number of rx antennas for FlatFadingChannel"})
17 |     fec_type: Optional[str] = field(default='Polar5G', metadata={"help": "Polar5G or LDPC5G; type of channel en/decoder"})
18 |     bin_conv_method: Optional[str] = field(default='tanh', metadata={"help": "tensor-to-binary conversion method: tanh or vector_quantization"})
19 |     embedding_dim: Optional[int] = field(default=2, metadata={"help": "dimension of each codebook vector in the VQ layer"})
20 |     num_embeddings: Optional[int] = field(default=1024, metadata={"help": "number of codebook vectors in the VQ layer"})
21 |
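# For reference, an Eb/No value in dB is typically converted to a noise power
# before channel simulation. A minimal sketch using Sionna's `ebnodb2no` utility
# (values chosen to match the defaults above; not part of this module):
#
#   from sionna.utils import ebnodb2no
#   no = ebnodb2no(ebno_db=10.0, num_bits_per_symbol=4, coderate=512 / 1024)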
22 |
23 | @dataclass
24 | class ModelArguments:
25 | model_name_or_path: Optional[str] = field(
26 | default=None,
27 | metadata={
28 | "help": "Path to pretrained model or model identifier from huggingface.co/models. For random initialization, it should be None or ''."
29 | }
30 | )
31 | config_name: Optional[str] = field(
32 | default=None,
33 | metadata={
34 | "help": "Pretrained config name or path if not the same as model_name"
35 | }
36 | )
37 | tokenizer_name: Optional[str] = field(
38 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
39 | )
40 | cache_dir: Optional[str] = field(
41 | default=None,
42 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
43 | )
44 | use_fast_tokenizer: bool = field(
45 | default=True,
46 |         metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
47 | )
48 | model_revision: str = field(
49 | default="main",
50 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
51 | )
52 | use_auth_token: bool = field(
53 | default=False,
54 | metadata={
55 | "help": (
56 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
57 | "with private models)."
58 | )
59 | },
60 | )
61 | ignore_mismatched_sizes: Optional[bool] = field(
62 | default=False,
63 | metadata={
64 |             "help": "Whether or not to raise an error if some of the weights from the checkpoint do not have the same size as the weights of the model (for instance, instantiating a model with 10 labels from a checkpoint with 3 labels)."
65 | }
66 | )
67 |
68 |
69 | @dataclass
70 | class DataTrainingArguments:
71 | """
72 | Arguments pertaining to what data we are going to input our model for training and eval.
73 | """
74 |
75 | dataset_name: Optional[str] = field(
76 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
77 | )
78 | dataset_config_name: Optional[str] = field(
79 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
80 | )
81 | text_column: Optional[str] = field(
82 | default=None,
83 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
84 | )
85 | summary_column: Optional[str] = field(
86 | default=None,
87 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
88 | )
89 | train_file: Optional[str] = field(
90 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
91 | )
92 | validation_file: Optional[str] = field(
93 | default=None,
94 | metadata={
95 | "help": (
96 | "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
97 | )
98 | },
99 | )
100 | test_file: Optional[str] = field(
101 | default=None,
102 | metadata={
103 | "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
104 | },
105 | )
106 | overwrite_cache: bool = field(
107 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
108 | )
109 | preprocessing_num_workers: Optional[int] = field(
110 | default=None,
111 | metadata={"help": "The number of processes to use for the preprocessing."},
112 | )
113 | max_source_length: Optional[int] = field(
114 | default=1024,
115 | metadata={
116 | "help": (
117 | "The maximum total input sequence length after tokenization. Sequences longer "
118 | "than this will be truncated, sequences shorter will be padded."
119 | )
120 | },
121 | )
122 | max_target_length: Optional[int] = field(
123 | default=128,
124 | metadata={
125 | "help": (
126 | "The maximum total sequence length for target text after tokenization. Sequences longer "
127 | "than this will be truncated, sequences shorter will be padded."
128 | )
129 | },
130 | )
131 | val_max_target_length: Optional[int] = field(
132 | default=None,
133 | metadata={
134 | "help": (
135 | "The maximum total sequence length for validation target text after tokenization. Sequences longer "
136 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
137 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
138 | "during ``evaluate`` and ``predict``."
139 | )
140 | },
141 | )
142 | pad_to_max_length: bool = field(
143 | default=False,
144 | metadata={
145 | "help": (
146 | "Whether to pad all samples to model maximum sentence length. "
147 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
148 | "efficient on GPU but very bad for TPU."
149 | )
150 | },
151 | )
152 | max_train_samples: Optional[int] = field(
153 | default=None,
154 | metadata={
155 | "help": (
156 | "For debugging purposes or quicker training, truncate the number of training examples to this "
157 | "value if set."
158 | )
159 | },
160 | )
161 | max_eval_samples: Optional[int] = field(
162 | default=None,
163 | metadata={
164 | "help": (
165 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
166 | "value if set."
167 | )
168 | },
169 | )
170 | max_predict_samples: Optional[int] = field(
171 | default=None,
172 | metadata={
173 | "help": (
174 | "For debugging purposes or quicker training, truncate the number of prediction examples to this "
175 | "value if set."
176 | )
177 | },
178 | )
179 | num_beams: Optional[int] = field(
180 | default=None,
181 | metadata={
182 | "help": (
183 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
184 | "which is used during ``evaluate`` and ``predict``."
185 | )
186 | },
187 | )
188 | ignore_pad_token_for_loss: bool = field(
189 | default=True,
190 | metadata={
191 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
192 | },
193 | )
194 | source_prefix: Optional[str] = field(
195 | default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
196 | )
197 |
198 | def __post_init__(self):
199 | if self.dataset_name is None and self.train_file is None and self.validation_file is None:
200 | raise ValueError("Need either a dataset name or a training/validation file.")
201 | else:
202 | if self.train_file is not None:
203 | extension = self.train_file.split(".")[-1]
204 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
205 | if self.validation_file is not None:
206 | extension = self.validation_file.split(".")[-1]
207 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
208 | if self.val_max_target_length is None:
209 | self.val_max_target_length = self.max_target_length
210 |
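# A minimal usage sketch (illustrative; train.py is assumed to combine these
# dataclasses with transformers' TFTrainingArguments):
#
#   from transformers import HfArgumentParser, TFTrainingArguments
#   parser = HfArgumentParser(
#       (ModelArguments, DataTrainingArguments, Seq2SeqSCArguments, TFTrainingArguments)
#   )
#   model_args, data_args, seq2seq_sc_args, training_args = parser.parse_args_into_dataclasses()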
211 |
212 | # region Dataset name mappings
213 | summarization_name_mapping = {
214 | "amazon_reviews_multi": ("review_body", "review_title"),
215 | "big_patent": ("description", "abstract"),
216 | "cnn_dailymail": ("article", "highlights"),
217 | "orange_sum": ("text", "summary"),
218 | "pn_summary": ("article", "summary"),
219 | "psc": ("extract_text", "summary_text"),
220 | "samsum": ("dialogue", "summary"),
221 | "thaisum": ("body", "summary"),
222 | "xglue": ("news_body", "news_title"),
223 | "xsum": ("document", "summary"),
224 | "wiki_summary": ("article", "highlights"),
225 | "multi_news": ("document", "summary"),
226 | }
227 | # endregion
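# A minimal sketch of how this mapping is typically consumed (illustrative;
# mirrors the Hugging Face summarization examples this script follows):
#
#   dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
#   text_column = data_args.text_column or (dataset_columns[0] if dataset_columns else None)
#   summary_column = data_args.summary_column or (dataset_columns[1] if dataset_columns else None)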
228 |
229 |
230 |
--------------------------------------------------------------------------------