├── .gitignore
├── LICENSE
├── README.md
├── environment.yml
├── eval.py
├── figures
│   ├── On-device AI comm.png
│   └── file_structure.png
├── models
│   ├── __init__.py
│   ├── channels.py
│   ├── on_device_ai_comm.py
│   ├── utils.py
│   └── vq_vae.py
├── preprocess
│   ├── __init__.py
│   ├── europarl.py
│   ├── flickr30k.py
│   └── hf_data_gen.py
├── scripts
│   ├── eval.sh
│   └── train.sh
├── train.py
└── train
    ├── __init__.py
    └── args.py

/.gitignore:
--------------------------------------------------------------------------------
1 | checkpoints
2 | **/__pycache__
3 | data
4 | .vscode
5 | **/backup
6 | **/log
7 | **/weights
8 | visualization
9 | jobs
10 |
11 | # ignore log files
12 | *.log
13 |
14 | models/on_device_ai_comm_v2.py
15 | preprocess/allnli.py
16 | srun_gpu.sh
17 | srun_cpu.sh
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Juhyung Lee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # On-Device AI/LLM Communication
2 |
3 |
4 | This repo is the implementation of our paper ["Integrating Pre-Trained Language Model with Physical Layer Communications"](https://arxiv.org/abs/2402.11656).
5 |
6 | ## Highlights
7 | - Fine-tuned a pre-trained LLM (BART) under noisy conditions (3GPP CDL-family channel model) in an end-to-end Link-Level Simulation (LLS), integrating with 5G-NR PHY-layer functions.
8 | - Developed a compression & quantization method for AI-to-AI comm, reducing transmission size by 50% without compromising performance.
9 | - Verified the framework in NVIDIA Sionna LLS within a 5G-NR PHY setup.
10 |
11 |
12 |
13 |
15 |
16 | ## Citation
17 |
18 | ```bibtex
19 | @misc{lee2024integrating,
20 |       title={Integrating Pre-Trained Language Model with Physical Layer Communications},
21 |       author={Ju-Hyung Lee and Dong-Ho Lee and Joohan Lee and Jay Pujara},
22 |       year={2024},
23 |       eprint={2402.11656},
24 |       archivePrefix={arXiv},
25 |       primaryClass={cs.IT}
26 | }
27 | ```
28 |
29 | ## Model Architecture
30 | ![Model architecture](<./figures/On-device AI comm.png>)
31 | - Each channel model (e.g., `ChannelAWGN`) includes the channel encoder/decoder, the mapper/demapper, and the channel itself (AWGN, CDL, etc.).
32 |
33 | ## Available checkpoints for On-Device AI Communication
34 |
35 |
36 | | # | Transmission | embedding dimension | # of embeddings | Download Link |
37 | | :---: | :------------: | :-----------------: | :--------------: | :-------------------------------------------------------------------------------------------------: |
38 | | 1 | tanh | - | - | [tf_model.h5](https://drive.google.com/file/d/156PpJPNYzHAlXGv1M_y9H9eRnUXrnFTt/view?usp=sharing)|
39 | | 2 | VectorQuantizer | 2 | 1024 | [tf_model.h5](https://drive.google.com/file/d/13gBtLnKo8wwJV6_ZdGHB3AR8WlAyEsJN/view?usp=sharing)|
40 | | 3 | VectorQuantizer | 4 | 1024 | [tf_model.h5](https://drive.google.com/file/d/1OwQ69NGi6INKAExjwVNqr2pe1l3fs2tr/view?usp=sharing)|
41 | | 4 | VectorQuantizer | 8 | 1024 | [tf_model.h5](https://drive.google.com/file/d/12qrKD-q7habrlrm-5BSS9dnUebYEPdF3/view?usp=sharing)|
42 | | 5 | VectorQuantizer | 16 | 1024 | [tf_model.h5](https://drive.google.com/file/d/1DQCapmhGIeFmP66Y11bDzHbsyWJ-MYBC/view?usp=sharing)|
43 |
44 |
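To load a downloaded checkpoint programmatically, you can mirror the `from_pretrained` call used in `eval.py`. A minimal sketch (the local checkpoint directory name here is hypothetical; place the downloaded `tf_model.h5` inside it):

```python
from models.on_device_ai_comm import TFOnDeviceAICForConditionalGeneration

# Hypothetical local directory holding the downloaded tf_model.h5 (checkpoint #2 above).
checkpoint_dir = 'checkpoints/on-device-ai-comm/vq_dim2'

model = TFOnDeviceAICForConditionalGeneration.from_pretrained(
    checkpoint_dir,
    ebno_db=10.0,                           # evaluation Eb/N0 in dB
    bin_conv_method='vector_quantization',  # checkpoint #1 uses 'tanh' instead
    channel_type='CDL', cdl_model='A',
    fec_type='Polar5G',
    scenario='umi', perfect_csi=True,       # only used by the 3GPP-38.901 channel
    channel_num_tx_ant=2, channel_num_rx_ant=2,
    num_bits_per_symbol=4,
    embedding_dim=2, num_embeddings=1024)   # VQ codebook settings of checkpoint #2
```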
45 | We provide example scripts showing how to use these checkpoints in the Train and Evaluation sections below.
46 |
47 | ## Setup
48 |
49 | Clone the repository and set up the environment:
50 |
51 | ```bash
52 | git clone https://github.com/abman23/on-device-ai-comm.git
53 | cd on-device-ai-comm
54 | conda env create -f environment.yml
55 | conda activate on-device-ai-comm
56 | ```
57 |
58 | ## Data Preprocessing
59 |
60 | ### Europarl dataset
61 |
62 | ```bash
63 | data_path=data/europarl
64 | mkdir -p $data_path
65 | cd $data_path
66 | wget -P /tmp http://www.statmt.org/europarl/v7/europarl.tgz
67 | tar zxf /tmp/europarl.tgz
68 |
69 | europarl_dataset="$data_path/txt/en"
70 | out_dir="$data_path/processed"
71 | njobs=4
72 |
73 | mkdir -p $out_dir
74 | python -m preprocess.europarl -j $njobs -o $out_dir $europarl_dataset
75 | ```
76 |
77 |
93 |
94 | ### Flickr30K
95 |
96 | To download the dataset, go to [Flickr30K](http://hockenmaier.cs.illinois.edu/DenotationGraph/) and fill out the form to get the download link.
97 |
98 | ```bash
99 | data_path="data/flickr"
100 | dataset_path="${data_path}/flickr30k.tar.gz"
101 | out_dir="$data_path/processed"
102 |
103 | mkdir -p $out_dir
104 |
105 | tar xzf ${dataset_path} -C $data_path
106 | python -m preprocess.flickr30k \
107 |     -o "$out_dir/flickr30k.json" \
108 |     "${data_path}/results_20130124.token"
109 | ```
110 |
111 | ## Train
112 |
113 | You can run `scripts/train.sh`, or train by running the following commands. Below is an example of training the on-device AI communication system over a CDL-A channel at 5–15 dB Eb/N0 with the vector quantizer.
114 |
115 | ```bash
116 | output_dir='checkpoints/on-device-ai-comm/train_CDL-A_ebnodb_5_15'
117 | trainset_path='data/europarl/processed/train.csv'
118 | devset_path='data/europarl/processed/test.csv'
119 |
120 | mkdir -p $output_dir
121 |
122 | python train.py \
123 |     --model_name_or_path facebook/bart-base \
124 |     --config_name facebook/bart-base \
125 |     --tokenizer_name facebook/bart-base \
126 |     --train_file "$trainset_path" \
127 |     --validation_file "$devset_path" \
128 |     --test_file "$devset_path" \
129 |     --preprocessing_num_workers 4 \
130 |     --per_device_train_batch_size 4 \
131 |     --per_device_eval_batch_size 4 \
132 |     --num_train_epochs 3 \
133 |     --do_train \
134 |     --do_eval \
135 |     --save_total_limit 1 \
136 |     --no_use_fast_tokenizer \
137 |     --num_beams 1 \
138 |     --pad_to_max_length \
139 |     --overwrite_output_dir \
140 |     --max_source_length 64 \
141 |     --max_target_length 64 \
142 |     --output_dir $output_dir \
143 |     --ebno_db_min 5 \
144 |     --ebno_db_max 15 \
145 |     --channel_type "CDL" \
146 |     --fec_type "Polar5G" \
147 |     --fec_num_iter 20 \
148 |     --cdl_model "A" \
149 |     --channel_num_tx_ant "2" \
150 |     --channel_num_rx_ant "2" \
151 |     --num_bits_per_symbol "4" \
152 |     --bin_conv_method "vector_quantization" \
153 |     --embedding_dim 2 \
154 |     --num_embeddings 1024 \
155 |     --dataset_config 3.0.0
156 | ```
157 |
158 | - For more training arguments, see [train/args.py](./train/args.py).
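When `--ebno_db_min`/`--ebno_db_max` are given, the channel models in `models/channels.py` draw a fresh Eb/N0 per batch element instead of using a fixed SNR. A minimal sketch of that sampling step (values match the command above; `ebnodb2no` is the Sionna helper the channel models already use):

```python
import tensorflow as tf
from sionna.utils import ebnodb2no

batch_size = 4
# One Eb/N0 (dB) per example, uniform over the training range [5, 15).
ebno_db = tf.random.uniform(shape=[batch_size], minval=5.0, maxval=15.0)
# Convert Eb/N0 to noise power No for 16-QAM (4 bits/symbol) at code rate 1/2.
no = ebnodb2no(ebno_db, num_bits_per_symbol=4, coderate=0.5)
```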
159 |
160 | ## Evaluation
161 |
162 | You can use the script `scripts/eval.sh` or the following commands:
163 |
164 | ```bash
165 | # BLEU score
166 | eval_ebno_db="4"
167 | metric="bleu" # bleu, sbert
168 | testset_path='data/flickr/processed/flickr30k.json'
169 |
170 | checkpoint_dir='checkpoints/on-device-ai-comm/train_CDL-A_ebnodb_5_15'
171 | output_dir='checkpoints/on-device-ai-comm/train_CDL-A_ebnodb_5_15/CDL-A'
172 |
173 | mkdir -p $output_dir
174 |
175 | fec_type="Polar5G" # Polar5G, LDPC5G
176 | fec_num_iter=20
177 | channel_num_tx_ant="2"
178 | channel_num_rx_ant="2"
179 | num_bits_per_symbol="4"
180 | EVAL_NUM_BEAMS="1"
181 |
182 | python eval.py \
183 |     -m "${metric}" \
184 |     -b 8 \
185 |     -e "${eval_ebno_db}" \
186 |     --result-json-path "${output_dir}/flickr_${metric}_${eval_ebno_db}dB_${fec_type}_${channel_num_tx_ant}_${channel_num_rx_ant}_${num_bits_per_symbol}.json" \
187 |     --prediction-json-path "${output_dir}/flickr_prediction_${eval_ebno_db}dB_${fec_type}_${channel_num_tx_ant}_${channel_num_rx_ant}_${num_bits_per_symbol}.json" \
188 |     --fec-type "${fec_type}" \
189 |     --fec-num-iter "${fec_num_iter}" \
190 |     --channel-type "CDL" \
191 |     --cdl-model "A" \
192 |     --channel-num-tx-ant "${channel_num_tx_ant}" \
193 |     --channel-num-rx-ant "${channel_num_rx_ant}" \
194 |     --num-bits-per-symbol "${num_bits_per_symbol}" \
195 |     --bin-conv-method "vector_quantization" \
196 |     --embedding-dim 2 \
197 |     --num-embeddings 1024 \
198 |     --num-beams "${EVAL_NUM_BEAMS}" \
199 |     --testset-path "${testset_path}" \
200 |     $checkpoint_dir
201 | ```
202 | - Note that the model checkpoint inside `checkpoint_dir` must be named `tf_model.h5`.
203 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: on-device-ai-comm
2 | channels:
3 | - huggingface
4 | - anaconda
5 | - defaults
6 | - conda-forge
7 | dependencies:
8 | - _libgcc_mutex=0.1=conda_forge
9 | - _openmp_mutex=4.5=2_gnu
10 | - autopep8=1.6.0=pyhd3eb1b0_1
11 | - blas=1.0=mkl
12 | - brotlipy=0.7.0=py39h27cfd23_1003
13 | - ca-certificates=2022.12.7=ha878542_0
14 | - cffi=1.15.1=py39h74dc2b5_0
15 | - click=8.0.4=py39h06a4308_0
16 | - cryptography=37.0.1=py39h9ce1e76_0
17 | - cudatoolkit=11.2.2=hbe64b41_10
18 | - cudnn=8.1.0.77=h90431f1_0
19 | - dataclasses=0.8=pyh6d0b6a4_7
20 | - filelock=3.6.0=pyhd3eb1b0_0
21 | - huggingface_hub=0.9.1=py_0
22 | - importlib_metadata=4.11.3=hd3eb1b0_0
23 | - intel-openmp=2021.4.0=h06a4308_3561
24 | - joblib=1.1.0=pyhd3eb1b0_0
25 | - ld_impl_linux-64=2.38=h1181459_1
26 | - libffi=3.3=he6710b0_2
27 | - libgcc-ng=12.1.0=h8d9b700_16
28 | - libgomp=12.1.0=h8d9b700_16
29 | - libprotobuf=3.20.1=h4ff587b_0
30 | - libstdcxx-ng=12.1.0=ha89aaad_16
31 | - mkl=2021.4.0=h06a4308_640
32 | - mkl-service=2.4.0=py39h7f8727e_0
33 | - mkl_fft=1.3.1=py39hd3c417c_0
34 | - mkl_random=1.2.2=py39h51133e4_0
35 | - ncurses=6.3=h5eee18b_3
36 | - openssl=1.1.1s=h0b41bf4_1
37 | - pycodestyle=2.8.0=pyhd3eb1b0_0
38 | - pycparser=2.21=pyhd3eb1b0_0
39 | - pyopenssl=22.0.0=pyhd3eb1b0_0
40 | - pyparsing=3.0.9=py39h06a4308_0
41 | - pysocks=1.7.1=py39h06a4308_0
42 | - python=3.9.13=haa1d7c7_1
43 | - pyyaml=6.0=py39h7f8727e_1
44 | - readline=8.1.2=h7f8727e_1
45 | - regex=2022.7.9=py39h5eee18b_0
46 | - requests=2.28.1=py39h06a4308_0
47 | - sacremoses=master=py_0
48 | - six=1.16.0=pyhd3eb1b0_1
49 | - sqlite=3.39.2=h5082296_0
50 | - tk=8.6.12=h1ccaba5_0
51 | - toml=0.10.2=pyhd3eb1b0_0
52 | - tqdm=4.64.0=py39h06a4308_0
53 |
- tzdata=2022c=h04d1e81_0 54 | - xz=5.2.5=h7f8727e_1 55 | - yaml=0.2.5=h7b6447c_0 56 | - zlib=1.2.12=h5eee18b_3 57 | - pip: 58 | - absl-py==1.3.0 59 | - accelerate==0.13.0 60 | - aiohttp==3.8.3 61 | - aiosignal==1.2.0 62 | - astroid==2.12.10 63 | - astunparse==1.6.3 64 | - async-timeout==4.0.2 65 | - attrs==22.1.0 66 | - bert-score==0.3.12 67 | - cachetools==5.2.0 68 | - certifi==2022.12.7 69 | - charset-normalizer==2.1.1 70 | - cloudpickle==2.2.1 71 | - contourpy==1.0.5 72 | - cycler==0.11.0 73 | - datasets==2.5.1 74 | - decorator==5.1.1 75 | - dill==0.3.5.1 76 | - dm-tree==0.1.8 77 | - evaluate==0.2.2 78 | - flatbuffers==22.12.6 79 | - fonttools==4.37.3 80 | - frozenlist==1.3.1 81 | - fsspec==2022.8.2 82 | - gast==0.4.0 83 | - google-auth==2.15.0 84 | - google-auth-oauthlib==0.4.6 85 | - google-pasta==0.2.0 86 | - grpcio==1.51.1 87 | - h5py==3.7.0 88 | - huggingface-hub==0.11.1 89 | - idna==3.4 90 | - importlib-metadata==5.1.0 91 | - importlib-resources==5.9.0 92 | - isort==5.10.1 93 | - keras==2.10.0 94 | - keras-preprocessing==1.1.2 95 | - kiwisolver==1.4.4 96 | - lazy-object-proxy==1.7.1 97 | - libclang==14.0.6 98 | - markdown==3.4.1 99 | - markupsafe==2.1.1 100 | - matplotlib==3.6.0 101 | - mccabe==0.7.0 102 | - moverscore==1.0.3 103 | - multidict==6.0.2 104 | - multiprocess==0.70.13 105 | - nltk==3.7 106 | - numpy==1.23.5 107 | - oauthlib==3.2.2 108 | - opt-einsum==3.3.0 109 | - packaging==22.0 110 | - pandas==1.5.0 111 | - pillow==9.2.0 112 | - pip==22.3.1 113 | - platformdirs==2.5.2 114 | - portalocker==2.6.0 115 | - protobuf==3.19.6 116 | - psutil==5.9.2 117 | - pyarrow==8.0.0 118 | - pyasn1==0.4.8 119 | - pyasn1-modules==0.2.8 120 | - pydot==1.4.2 121 | - pyemd==0.5.1 122 | - pylint==2.15.3 123 | - python-dateutil==2.8.2 124 | - pytz==2022.2.1 125 | - requests-oauthlib==1.3.1 126 | - responses==0.18.0 127 | - rouge-score==0.1.2 128 | - rsa==4.9 129 | - scikit-learn==1.1.2 130 | - scipy==1.9.1 131 | - sentence-transformers==2.2.2 132 | - sentencepiece==0.1.97 133 | - setuptools==65.6.3 134 | - sionna==0.11.0 135 | - tensorboard==2.10.1 136 | - tensorboard-data-server==0.6.1 137 | - tensorboard-plugin-wit==1.8.1 138 | - tensorflow==2.10.1 139 | - tensorflow-estimator==2.10.0 140 | - tensorflow-io-gcs-filesystem==0.28.0 141 | - tensorflow-probability==0.18.0 142 | - termcolor==2.1.1 143 | - threadpoolctl==3.1.0 144 | - tokenizers==0.12.1 145 | - tomli==2.0.1 146 | - tomlkit==0.11.4 147 | - torch==1.12.1 148 | - torchvision==0.13.1 149 | - transformers==4.25.1 150 | - typing==3.7.4.3 151 | - typing-extensions==4.4.0 152 | - urllib3==1.26.13 153 | - w3lib==2.0.1 154 | - werkzeug==2.2.2 155 | - wheel==0.38.4 156 | - wrapt==1.14.1 157 | - xxhash==3.0.0 158 | - yapf==0.32.0 159 | - yarl==1.8.1 160 | - zipp==3.11.0 161 | prefix: /home/danny911kr/miniconda3/envs/on-device-ai-comm 162 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import json 3 | import argparse 4 | from transformers import BartTokenizer 5 | import evaluate 6 | from tqdm import tqdm 7 | import warnings 8 | # from h5_utils import rename_weight 9 | from torch.profiler import profile, record_function, ProfilerActivity 10 | 11 | def get_test_data(path): 12 | with open(path) as f: 13 | return json.load(f) 14 | 15 | def from_pretrained(path, ebno_db, bin_conv_method, channel_type, 16 | embedding_dim, num_embeddings, 17 | fec_type, cdl_model, 18 | scenario, perfect_csi, 19 | 
channel_num_tx_ant=1, channel_num_rx_ant=1, 20 | num_bits_per_symbol=4): 21 | from models.on_device_ai_comm import TFOnDeviceAICForConditionalGeneration 22 | import transformers 23 | 24 | model = TFOnDeviceAICForConditionalGeneration.from_pretrained( 25 | path, ebno_db=ebno_db, 26 | bin_conv_method=bin_conv_method, channel_type=channel_type, 27 | fec_type=fec_type, cdl_model=cdl_model, 28 | scenario=scenario, perfect_csi=perfect_csi, 29 | channel_num_tx_ant=channel_num_tx_ant, channel_num_rx_ant=channel_num_rx_ant, 30 | num_bits_per_symbol=num_bits_per_symbol, 31 | embedding_dim=embedding_dim, num_embeddings=num_embeddings) 32 | 33 | return model 34 | 35 | def predict(path, ebno_db, 36 | tokenizer, batch_size, test_data_path, 37 | num_beams, bin_conv_method, channel_type, 38 | fec_type, cdl_model, 39 | scenario, perfect_csi, 40 | channel_num_tx_ant, channel_num_rx_ant, 41 | num_bits_per_symbol, 42 | embedding_dim, num_embeddings): 43 | import tensorflow as tf 44 | max_len = 32 45 | 46 | # load model 47 | model = from_pretrained(path, ebno_db, 48 | bin_conv_method=bin_conv_method, 49 | channel_type=channel_type, 50 | fec_type=fec_type, cdl_model=cdl_model, 51 | scenario=scenario, perfect_csi=perfect_csi, 52 | channel_num_tx_ant=channel_num_tx_ant, channel_num_rx_ant=channel_num_rx_ant, 53 | num_bits_per_symbol=num_bits_per_symbol, 54 | embedding_dim=embedding_dim, 55 | num_embeddings=num_embeddings) 56 | 57 | # # load testset 58 | test_data = get_test_data(test_data_path) 59 | input_sentences = [d['input'] for d in test_data] 60 | input_ids = tokenizer(input_sentences, return_tensors="tf", 61 | padding='max_length', truncation=True, max_length=max_len).input_ids 62 | testset = tf.data.Dataset.from_tensor_slices(input_ids) 63 | 64 | # inference 65 | pred_sentences = [] 66 | bers=[] 67 | for input_ids in tqdm(testset.batch(batch_size).prefetch(tf.data.AUTOTUNE)): 68 | output = model(input_ids) # To get BER 69 | pred_batch = model.generate(input_ids, max_new_tokens=max_len, num_beams=num_beams, 70 | top_k=4, penalty_alpha=0.6, do_sample=False) 71 | 72 | output_strs = tokenizer.batch_decode(pred_batch, 73 | skip_special_tokens=True, 74 | clean_up_tokenization_spaces=False) 75 | pred_sentences.extend(output_strs) 76 | bers.append(output.ber.numpy()) 77 | 78 | mean_ber = sum(bers) / len(bers) 79 | 80 | res = { 81 | 'input': input_sentences, 82 | 'pred': pred_sentences, 83 | 'refs': [d['refs'] for d in test_data], 84 | 'mean_ber': mean_ber, 85 | } 86 | return res 87 | 88 | def get_predictions(path, ebno_db, test_data_path, 89 | prediction_json_path, batch_size, 90 | tokenizer, num_beams, bin_conv_method, 91 | channel_type, fec_type, cdl_model, 92 | scenario, perfect_csi, 93 | channel_num_tx_ant,channel_num_rx_ant, num_bits_per_symbol, 94 | embedding_dim, num_embeddings, 95 | calc_flops): 96 | path = pathlib.Path(path) 97 | if not prediction_json_path.exists(): 98 | print('Missing predictions.json') 99 | 100 | with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], with_flops=True,record_shapes=True) as prof: 101 | with record_function("model_inference"): 102 | res = predict( 103 | path=path, 104 | ebno_db=ebno_db, 105 | tokenizer=tokenizer, 106 | batch_size=batch_size, 107 | test_data_path=test_data_path, 108 | num_beams=num_beams, 109 | bin_conv_method=bin_conv_method, 110 | channel_type=channel_type, 111 | fec_type=fec_type, 112 | cdl_model=cdl_model, 113 | scenario=scenario, 114 | perfect_csi=perfect_csi, 115 | channel_num_tx_ant=channel_num_tx_ant, 116 | 
channel_num_rx_ant=channel_num_rx_ant, 117 | num_bits_per_symbol=num_bits_per_symbol, 118 | embedding_dim=embedding_dim, 119 | num_embeddings=num_embeddings 120 | ) 121 | if calc_flops: 122 | print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)) 123 | 124 | # save result 125 | with open(prediction_json_path, 'w') as f: 126 | json.dump(res, f, indent=4) 127 | else: 128 | with open(prediction_json_path, 'r') as f: 129 | res = json.load(f) 130 | return res 131 | 132 | def calc_bleu(predictions, tokenizer, multi_ref, **kwargs): 133 | bleu = evaluate.load('bleu') 134 | if multi_ref: 135 | warnings.warn('BLEU does not support multiple references') 136 | tokenize = lambda l: tokenizer(l, add_special_tokens=False).input_ids 137 | results = bleu.compute( 138 | references=predictions['input'], 139 | predictions=predictions['pred'], 140 | tokenizer=tokenize, 141 | max_order=4) 142 | return results 143 | 144 | def calc_sbert(predictions, batch_size, multi_ref, **kwargs): 145 | from sentence_transformers import SentenceTransformer, util 146 | import torch 147 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 148 | model = SentenceTransformer( 149 | model_name_or_path='all-MiniLM-L6-v2', 150 | device=device) 151 | 152 | sentences1 = predictions['pred'] 153 | 154 | if not multi_ref: 155 | refs = [[s] for s in predictions['input']] 156 | else: 157 | refs = predictions['refs'] 158 | 159 | def calc_cos_score(model, hyp_embedding, ref_sentences): 160 | hyp = hyp_embedding.reshape((1, -1)) 161 | refs = model.encode(ref_sentences, convert_to_tensor=True) 162 | scores = util.cos_sim(hyp, refs) 163 | scores = scores.reshape((-1)).tolist() 164 | return { 165 | 'scores': scores, 166 | 'max_score': max(scores), 167 | 'mean_score': sum(scores) / len(scores), 168 | } 169 | 170 | 171 | # compute embedding 172 | pred_embed = model.encode(sentences1, batch_size=batch_size, convert_to_tensor=True) 173 | N = pred_embed.shape[0] 174 | scores = [ 175 | calc_cos_score(model, pred_embed[i], refs[i]) for i in range(N) 176 | ] 177 | max_scores = [s['max_score'] for s in scores] 178 | mean_score = sum(max_scores)/len(max_scores) 179 | return { 180 | 'metric': 'sentence textual similarity', 181 | 'mean_score': mean_score, 182 | 'scores': scores, 183 | } 184 | 185 | METRIC_TO_SCORER = { 186 | 'bleu': calc_bleu, 187 | 'sbert': calc_sbert, 188 | } 189 | 190 | def calc(args): 191 | tokenizer = BartTokenizer.from_pretrained(args.tokenizer) 192 | 193 | path = args.path 194 | metric = args.metric 195 | batch_size = args.batch_size 196 | 197 | # rename weight name in tf_model.h5 for BinConv 198 | # rename_weight(path, RENAME_MAP) 199 | 200 | # VQ-VAE arguments 201 | if args.bin_conv_method=='vector_quantization': 202 | assert (args.embedding_dim is not None and args.num_embeddings is not None), 'Set embedding_dim and num_embeddings.' 
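    # NOTE: when set on the command line, embedding-dim / num-embeddings arrive as
    # strings (their argparse options declare no type), hence the int(...) casts below.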
203 | 204 | predictions = get_predictions( 205 | path, 206 | ebno_db=args.ebno_db, 207 | prediction_json_path=args.prediction_json_path, 208 | test_data_path=args.testset_path, 209 | batch_size=batch_size, 210 | tokenizer=tokenizer, 211 | num_beams=args.num_beams, 212 | bin_conv_method=args.bin_conv_method, 213 | channel_type=args.channel_type, 214 | fec_type=args.fec_type, 215 | cdl_model=args.cdl_model, 216 | scenario=args.scenario, 217 | perfect_csi=args.perfect_csi, 218 | channel_num_tx_ant=args.channel_num_tx_ant, 219 | channel_num_rx_ant=args.channel_num_rx_ant, 220 | embedding_dim=int(args.embedding_dim), 221 | num_embeddings=int(args.num_embeddings), 222 | num_bits_per_symbol=args.num_bits_per_symbol, 223 | calc_flops=args.calc_flops) 224 | scorer = METRIC_TO_SCORER[metric] 225 | results = scorer( 226 | predictions=predictions, 227 | tokenizer=tokenizer, 228 | batch_size=batch_size, 229 | multi_ref=args.multi_ref, 230 | ) 231 | 232 | # Add mean_ber 233 | results['mean_ber'] = predictions['mean_ber'] 234 | 235 | # dump result 236 | with open(args.result_json_path, 'w') as f: 237 | json.dump(results, f, indent=4) 238 | 239 | 240 | def main(): 241 | parser = argparse.ArgumentParser() 242 | parser.add_argument(dest='path', metavar='checkpoint_path', type=pathlib.Path) 243 | parser.add_argument('-m', '--metric', choices = list(METRIC_TO_SCORER.keys()), dest='metric') 244 | parser.add_argument('-b', '--batch-size', default=4, type=int, dest='batch_size') 245 | parser.add_argument('-e', '--ebno-db', required=True, type=float, dest='ebno_db') 246 | parser.add_argument('--testset-path', 247 | required=True, type=pathlib.Path, dest='testset_path') 248 | parser.add_argument('--prediction-json-path', 249 | required=True, 250 | type=pathlib.Path, 251 | dest='prediction_json_path', 252 | help='Required. Output path of prediction result cache json file. 
\
253 | If the file exists, the prediction result will be reused')
254 |     parser.add_argument('--result-json-path',
255 |                         default=pathlib.Path('./result.json'),
256 |                         type=pathlib.Path,
257 |                         dest='result_json_path')
258 |     parser.add_argument('--tokenizer',
259 |                         default='facebook/bart-base',
260 |                         dest='tokenizer')
261 |     parser.add_argument('--num-beams',
262 |                         default=1,
263 |                         type=int,
264 |                         dest='num_beams')
265 |     parser.add_argument('--multi-ref',
266 |                         action='store_true',
267 |                         dest='multi_ref')
268 |     parser.add_argument('--bin-conv-method', default='tanh')
269 |     parser.add_argument('--channel-type', default='AWGN')
270 |     parser.add_argument('--fec-type', default='Polar5G')
271 |     parser.add_argument('--fec-num-iter', default=6, type=int)
272 |     parser.add_argument('--cdl-model', default='A') # CDL
273 |     parser.add_argument('--scenario', default='umi') # 3GPP-38.901
274 |     parser.add_argument('--perfect-csi', default=True) # 3GPP-38.901; NOTE: string values passed on the CLI (e.g. 'False') are truthy
275 |     parser.add_argument('--channel-num-tx-ant', default=2)
276 |     parser.add_argument('--channel-num-rx-ant', default=2)
277 |     parser.add_argument('--num-bits-per-symbol', default=4)
278 |     parser.add_argument('--embedding-dim', default=2) # vector quantization
279 |     parser.add_argument('--num-embeddings', default=1024) # vector quantization
280 |     parser.add_argument('--calc-flops', action='store_true') # was type=bool, which treats any non-empty string as True
281 |     args = parser.parse_args()
282 |     print(f'{args=}')
283 |
284 |     calc(args)
285 |
286 | if __name__ == '__main__':
287 |     main()
288 |
--------------------------------------------------------------------------------
/figures/On-device AI comm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abman23/on-device-ai-comm/c7b579ddfd88c5423ce9f9cb2e3e190ac89f8fdc/figures/On-device AI comm.png
--------------------------------------------------------------------------------
/figures/file_structure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abman23/on-device-ai-comm/c7b579ddfd88c5423ce9f9cb2e3e190ac89f8fdc/figures/file_structure.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.on_device_ai_comm import TFOnDeviceAICMainLayer, TFOnDeviceAICForConditionalGeneration
--------------------------------------------------------------------------------
/models/channels.py:
--------------------------------------------------------------------------------
1 | from sionna.channel import AWGN, FlatFadingChannel
2 | from sionna.fec.polar import Polar5GEncoder, Polar5GDecoder
3 | from sionna.fec.ldpc.encoding import LDPC5GEncoder
4 | from sionna.fec.ldpc.decoding import LDPC5GDecoder
5 | from sionna.mapping import Mapper, Demapper, Constellation
6 |
7 | from sionna.mimo import mf_equalizer, StreamManagement, lmmse_equalizer
8 | from sionna.ofdm import ResourceGrid, ResourceGridMapper, LSChannelEstimator, LMMSEEqualizer
9 | from sionna.ofdm import OFDMModulator, OFDMDemodulator, ZFPrecoder, RemoveNulledSubcarriers
10 | from sionna.channel.tr38901 import AntennaArray, CDL, Antenna, UMi, UMa, RMa
11 | from sionna.channel import subcarrier_frequencies, cir_to_ofdm_channel, cir_to_time_channel, time_lag_discrete_time_channel
12 | from sionna.channel import gen_single_sector_topology as gen_topology
13 | from sionna.channel import ApplyOFDMChannel, ApplyTimeChannel, OFDMChannel, TimeChannel
14
| 15 | from sionna.utils import ebnodb2no, expand_to_rank, QAMSource 16 | 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | from transformers.utils import ( 21 | logging, 22 | ) 23 | logger = logging.get_logger("transformers") 24 | 25 | class ChannelAWGN(tf.keras.Model): 26 | """ 27 | Configure AWGN Channel components. 28 | ref: https://nvlabs.github.io/sionna/api/channel.wireless.html?highlight=awgn#sionna.channel.AWGN 29 | 30 | Parameters 31 | ---------- 32 | :param fec_type: str, One of ["Polar5G", "LDPC5G"] 33 | :param num_bits_per_symbol: int 34 | :param fec_n: int 35 | :param fec_k: int 36 | :param ebno_db: float 37 | :param ebno_db_min: float 38 | :param ebno_db_max: float 39 | :param fec_num_iter: int 40 | 41 | """ 42 | def __init__(self, 43 | fec_type, 44 | num_bits_per_symbol, 45 | fec_n, 46 | fec_k, 47 | ebno_db=None, 48 | ebno_db_min=None, 49 | ebno_db_max=None, 50 | fec_num_iter=6 51 | ): 52 | super().__init__() 53 | self.fec_type = fec_type 54 | 55 | self._n = fec_n 56 | self._k = fec_k 57 | self._coderate = self._k / self._n 58 | 59 | print(f'{self._k=}') 60 | print(f'{self._n=}') 61 | print(f'{self._coderate=}') 62 | 63 | constellation = Constellation("qam", 64 | num_bits_per_symbol, 65 | trainable=False) 66 | logger.info(f'Constellation: type={constellation._constellation_type} ' + \ 67 | f'{num_bits_per_symbol=} trainable={constellation._trainable}') 68 | self.num_bits_per_symbol = num_bits_per_symbol 69 | self.mapper = Mapper(constellation=constellation) 70 | 71 | self.channel = AWGN() 72 | 73 | # channel noise 74 | assert ebno_db is not None or (ebno_db_min is not None and ebno_db_max is not None), "Set a single ebno_db or (ebno_db_min and ebno_db_max)" 75 | if ebno_db is not None: 76 | self.ebno_db = float(ebno_db) 77 | else: 78 | self.ebno_db = ebno_db # None 79 | self.ebno_db_min = ebno_db_min 80 | self.ebno_db_max = ebno_db_max 81 | 82 | print(f'{self.ebno_db=}') 83 | print(f'{self.ebno_db_min=}') 84 | print(f'{self.ebno_db_max=}') 85 | 86 | self.demapper = Demapper("app", constellation=constellation) 87 | 88 | # FEC 89 | self.fec_num_iter = fec_num_iter 90 | if self.fec_type == 'Polar5G': 91 | self._encoder = Polar5GEncoder(self._k, self._n) 92 | self._decoder = Polar5GDecoder( 93 | self._encoder, 94 | dec_type='SC', 95 | list_size=8, 96 | num_iter=self.fec_num_iter 97 | ) 98 | elif self.fec_type == 'LDPC5G': 99 | self._encoder = LDPC5GEncoder(self._k, self._n) 100 | self._decoder = LDPC5GDecoder(self._encoder, hard_out=True, num_iter=self.fec_num_iter) 101 | else: 102 | raise ValueError(f"Invalid channel coding type: {fec_type}") 103 | 104 | 105 | @tf.function 106 | def call(self, input): 107 | ''' 108 | Input 109 | ----- 110 | :param input: 111 | 112 | Output 113 | ------ 114 | :return b_hat: 115 | ''' 116 | # reshape input 117 | input_shape = input.shape 118 | 119 | # Add dummy zeros in the end of input to fit input shape to a multiple of divisor. 
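        # Example: with self._k = 512 and a flattened input of 1300 bits,
        # dummy_cnt = ((1300 // 512) + 1) * 512 - 1300 = 236 zeros are appended here
        # and stripped from b_hat again after decoding.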
120 | divisor=self._k 121 | if np.prod(input_shape) % divisor != 0: 122 | flatten_input = tf.reshape(input, [-1]) 123 | flatten_input_len = len(flatten_input) 124 | 125 | dummy_cnt = ((flatten_input_len // divisor)+1) * divisor - flatten_input_len 126 | flatten_input = tf.concat([flatten_input, [0 for _ in range(dummy_cnt)]],0) 127 | else: 128 | flatten_input = input 129 | 130 | # Channel encoder 131 | b = tf.reshape(flatten_input, (-1, self._k)) 132 | codewords = self._encoder(b) 133 | 134 | # Modulation 135 | x = self.mapper(codewords) 136 | 137 | ##################### 138 | # Channel 139 | ##################### 140 | # Sampling a batch of SNRs 141 | batch_size=b.shape[0] 142 | if self.ebno_db_min is not None and self.ebno_db_max is not None: 143 | ebno_db_tf = tf.random.uniform(shape=[batch_size], minval=self.ebno_db_min, maxval=self.ebno_db_max) 144 | no = ebnodb2no(ebno_db_tf, self.num_bits_per_symbol, self._coderate) 145 | else: 146 | no = ebnodb2no(self.ebno_db, self.num_bits_per_symbol, self._coderate) 147 | 148 | no = expand_to_rank(no, 2) 149 | 150 | y = self.channel([x, no]) 151 | 152 | ##################### 153 | # Receiver 154 | ##################### 155 | # Demodulation 156 | llr = self.demapper([y, no]) 157 | llr = tf.reshape(llr, (-1, self._n)) 158 | 159 | # Channel decoder 160 | b_hat = self._decoder(llr) 161 | 162 | if np.prod(input_shape) % divisor != 0: 163 | #Reshape b_hat to the original shape by cutting the arbitrarily appended elements 164 | flatten_b_hat = tf.reshape(b_hat, [-1]) 165 | sliced_b_hat = flatten_b_hat[:-dummy_cnt] 166 | b_hat=tf.reshape(sliced_b_hat, input_shape) 167 | else: 168 | b_hat=tf.reshape(b_hat, input_shape) 169 | 170 | return b_hat 171 | 172 | class ChannelCDL(tf.keras.Model): 173 | """ 174 | Configure CDL Channel components. 
175 | 176 | Parameters 177 | ---------- 178 | :param fec_type: str, One of ["Polar5G", "LDPC5G"] 179 | :param cdl_model: str, One of ["A", "B", "C", "D", "E"] 180 | :param channel_num_tx_ant: int 181 | :param channel_num_rx_ant: int 182 | :param num_bits_per_symbol: int 183 | :param ebno_db: float 184 | :param ebno_db_min: float 185 | :param ebno_db_max: float 186 | :param fec_num_iter: int 187 | """ 188 | def __init__(self, 189 | fec_type, 190 | cdl_model, 191 | channel_num_tx_ant, 192 | channel_num_rx_ant, 193 | num_bits_per_symbol, 194 | ebno_db=None, 195 | ebno_db_min=None, 196 | ebno_db_max=None, 197 | fec_num_iter=6 198 | ): 199 | super().__init__() 200 | 201 | # Provided parameters 202 | DL_CONFIG={ 203 | "cdl_model" : cdl_model, 204 | "delay_spread" : 100e-9, 205 | "domain" : "time", 206 | "direction" : "downlink", 207 | "perfect_csi" : True, 208 | "speed" : 0.0, 209 | "cyclic_prefix_length" : 6, 210 | "pilot_ofdm_symbol_indices" : [2, 11], 211 | "duration" : None 212 | } 213 | self._domain = DL_CONFIG["domain"] 214 | self._direction = DL_CONFIG["direction"] 215 | self._cdl_model = DL_CONFIG["cdl_model"] 216 | self._delay_spread = DL_CONFIG["delay_spread"] 217 | self._perfect_csi = DL_CONFIG["perfect_csi"] 218 | self._speed = DL_CONFIG["speed"] 219 | self._cyclic_prefix_length = DL_CONFIG["cyclic_prefix_length"] 220 | self._pilot_ofdm_symbol_indices = DL_CONFIG["pilot_ofdm_symbol_indices"] 221 | 222 | logger.info(f'{DL_CONFIG=}') 223 | 224 | # System parameters 225 | self._carrier_frequency = 2.6e9 226 | self._subcarrier_spacing = 15e3 #subcarrier_spacing 227 | self._fft_size = 36 228 | self._num_ofdm_symbols = 12 229 | self._num_ut_ant = int(channel_num_tx_ant) # Must be a multiple of two as dual-polarized antennas are used 230 | self._num_bs_ant = int(channel_num_rx_ant) # Must be a multiple of two as dual-polarized antennas are used 231 | self._num_streams_per_tx = self._num_ut_ant 232 | self._dc_null = True 233 | self._num_guard_carriers = [5, 6] 234 | self._pilot_pattern = "kronecker" 235 | self._pilot_ofdm_symbol_indices = DL_CONFIG["pilot_ofdm_symbol_indices"] 236 | self._num_bits_per_symbol = int(num_bits_per_symbol) 237 | self._coderate = 0.5 238 | 239 | # Required system components 240 | self._sm = StreamManagement(np.array([[1]]), self._num_streams_per_tx) 241 | 242 | self._rg = ResourceGrid(num_ofdm_symbols=self._num_ofdm_symbols, 243 | fft_size=self._fft_size, 244 | subcarrier_spacing = self._subcarrier_spacing, 245 | num_tx=1, 246 | num_streams_per_tx=self._num_streams_per_tx, 247 | cyclic_prefix_length=self._cyclic_prefix_length, 248 | num_guard_carriers=self._num_guard_carriers, 249 | dc_null=self._dc_null, 250 | pilot_pattern=self._pilot_pattern, 251 | pilot_ofdm_symbol_indices=self._pilot_ofdm_symbol_indices) 252 | 253 | self._n = int(self._rg.num_data_symbols * self._num_bits_per_symbol) 254 | self._k = int(self._n * self._coderate) 255 | 256 | self._ut_array = AntennaArray(num_rows=1, 257 | num_cols=int(self._num_ut_ant/2), 258 | polarization="dual", 259 | polarization_type="cross", 260 | antenna_pattern="38.901", 261 | carrier_frequency=self._carrier_frequency) 262 | 263 | self._bs_array = AntennaArray(num_rows=1, 264 | num_cols=int(self._num_bs_ant/2), 265 | polarization="dual", 266 | polarization_type="cross", 267 | antenna_pattern="38.901", 268 | carrier_frequency=self._carrier_frequency) 269 | 270 | self._cdl = CDL(model=self._cdl_model, 271 | delay_spread=self._delay_spread, 272 | carrier_frequency=self._carrier_frequency, 273 | ut_array=self._ut_array, 
274 | bs_array=self._bs_array, 275 | direction=self._direction, 276 | min_speed=self._speed) 277 | 278 | self._frequencies = subcarrier_frequencies(self._rg.fft_size, self._rg.subcarrier_spacing) 279 | 280 | if self._domain == "freq": 281 | self._channel_freq = ApplyOFDMChannel(add_awgn=True) 282 | 283 | elif self._domain == "time": 284 | self._l_min, self._l_max = time_lag_discrete_time_channel(self._rg.bandwidth) 285 | self._l_tot = self._l_max - self._l_min + 1 286 | self._channel_time = ApplyTimeChannel(self._rg.num_time_samples, 287 | l_tot=self._l_tot, 288 | add_awgn=True) 289 | self._modulator = OFDMModulator(self._cyclic_prefix_length) 290 | self._demodulator = OFDMDemodulator(self._fft_size, self._l_min, self._cyclic_prefix_length) 291 | 292 | self.fec_type = fec_type 293 | 294 | # channel noise 295 | assert ebno_db is not None or (ebno_db_min is not None and ebno_db_max is not None), "Set a single ebno_db or (ebno_db_min and ebno_db_max)" 296 | if ebno_db is not None: 297 | self.ebno_db = float(ebno_db) 298 | else: 299 | self.ebno_db = ebno_db # None 300 | self.ebno_db_min = ebno_db_min 301 | self.ebno_db_max = ebno_db_max 302 | 303 | # FEC 304 | self.fec_num_iter = fec_num_iter 305 | if self.fec_type == 'Polar5G': 306 | self._encoder = Polar5GEncoder(self._k, self._n) 307 | self._decoder = Polar5GDecoder( 308 | self._encoder, 309 | dec_type='SC', 310 | list_size=8, 311 | num_iter=self.fec_num_iter 312 | ) 313 | elif self.fec_type == 'LDPC5G': 314 | self._encoder = LDPC5GEncoder(self._k, self._n) 315 | self._decoder = LDPC5GDecoder(self._encoder, hard_out=True, num_iter=self.fec_num_iter) 316 | else: 317 | raise ValueError(f"Invalid channel coding type: {fec_type}") 318 | 319 | self._mapper = Mapper("qam", self._num_bits_per_symbol) 320 | self._rg_mapper = ResourceGridMapper(self._rg) 321 | 322 | if self._direction == "downlink": 323 | self._zf_precoder = ZFPrecoder(self._rg, self._sm, return_effective_channel=True) 324 | 325 | self._ls_est = LSChannelEstimator(self._rg, interpolation_type="nn") 326 | self._lmmse_equ = LMMSEEqualizer(self._rg, self._sm) 327 | self._demapper = Demapper("app", "qam", self._num_bits_per_symbol) 328 | 329 | self._remove_nulled_scs = RemoveNulledSubcarriers(self._rg) 330 | 331 | @tf.function(jit_compile=True) 332 | def call(self, input): 333 | """ 334 | Input 335 | ----- 336 | :param input: 337 | 338 | Output 339 | ------ 340 | :return b_hat: 341 | """ 342 | # reshape input 343 | input_shape = input.shape 344 | 345 | # Add dummy zeros in the end of input to fit input shape to a multiple of divisor. 
346 | divisor=self._num_streams_per_tx * self._k 347 | if np.prod(input_shape) % divisor != 0: 348 | flatten_input = tf.reshape(input, [-1]) 349 | flatten_input_len = len(flatten_input) 350 | 351 | dummy_cnt = ((flatten_input_len // divisor)+1) * divisor - flatten_input_len 352 | flatten_input = tf.concat([flatten_input, [0 for _ in range(dummy_cnt)]],0) 353 | else: 354 | flatten_input = input 355 | 356 | b = tf.reshape(flatten_input, (-1, 1, self._num_streams_per_tx, self._k)) 357 | batch_size = b.shape[0] 358 | 359 | if self.ebno_db_min is not None and self.ebno_db_max is not None: 360 | ebno_db_tf = tf.random.uniform(shape=[batch_size], minval=self.ebno_db_min, maxval=self.ebno_db_max) 361 | no = ebnodb2no(ebno_db_tf, self._num_bits_per_symbol, self._coderate, self._rg) 362 | else: 363 | no = ebnodb2no(self.ebno_db, self._num_bits_per_symbol, self._coderate, self._rg) 364 | 365 | c = self._encoder(b) 366 | x = self._mapper(c) 367 | x_rg = self._rg_mapper(x) 368 | 369 | if self._domain == "time": 370 | # Time-domain simulations 371 | a, tau = self._cdl(batch_size, self._rg.num_time_samples+self._l_tot-1, self._rg.bandwidth) 372 | h_time = cir_to_time_channel(self._rg.bandwidth, a, tau, 373 | l_min=self._l_min, l_max=self._l_max, normalize=True) 374 | 375 | # As precoding is done in the frequency domain, we need to downsample 376 | # the path gains `a` to the OFDM symbol rate prior to converting the CIR 377 | # to the channel frequency response. 378 | a_freq = a[...,self._rg.cyclic_prefix_length:-1:(self._rg.fft_size+self._rg.cyclic_prefix_length)] 379 | a_freq = a_freq[...,:self._rg.num_ofdm_symbols] 380 | h_freq = cir_to_ofdm_channel(self._frequencies, a_freq, tau, normalize=True) 381 | 382 | if self._direction == "downlink": 383 | x_rg, g = self._zf_precoder([x_rg, h_freq]) 384 | 385 | x_time = self._modulator(x_rg) 386 | y_time = self._channel_time([x_time, h_time, no]) 387 | 388 | y = self._demodulator(y_time) 389 | 390 | elif self._domain == "freq": 391 | # Frequency-domain simulations 392 | 393 | cir = self._cdl(batch_size, self._rg.num_ofdm_symbols, 1/self._rg.ofdm_symbol_duration) 394 | h_freq = cir_to_ofdm_channel(self._frequencies, *cir, normalize=True) 395 | 396 | if self._direction == "downlink": 397 | x_rg, g = self._zf_precoder([x_rg, h_freq]) 398 | 399 | y = self._channel_freq([x_rg, h_freq, no]) 400 | 401 | if self._perfect_csi: 402 | if self._direction == "uplink": 403 | h_hat = self._remove_nulled_scs(h_freq) 404 | elif self._direction =="downlink": 405 | h_hat = g 406 | err_var = 0.0 407 | else: 408 | h_hat, err_var = self._ls_est ([y, no]) 409 | 410 | x_hat, no_eff = self._lmmse_equ([y, h_hat, err_var, no]) 411 | llr = self._demapper([x_hat, no_eff]) 412 | b_hat = self._decoder(llr) 413 | 414 | if np.prod(input_shape) % divisor != 0: 415 | #Reshape b_hat to the original shape by cutting the arbitrarily appended elements 416 | flatten_b_hat = tf.reshape(b_hat, [-1]) 417 | sliced_b_hat = flatten_b_hat[:-dummy_cnt] 418 | b_hat=tf.reshape(sliced_b_hat, input_shape) 419 | else: 420 | b_hat=tf.reshape(b_hat, input_shape) 421 | 422 | return b_hat 423 | 424 | class ChannelSL(tf.keras.Model): 425 | """ 426 | OFDM MIMO transmissions over a 3GPP 38.901 model. 
427 | (Realistic Multiuser MIMO OFDM) 428 | 429 | Parameters 430 | ---------- 431 | :param fec_type: str, One of ["Polar5G", "LDPC5G"] 432 | :param scenario: str, One of ["umi", "uma", "rma"] 433 | :param perfect_csi: boolean 434 | :param channel_num_tx_ant: int 435 | :param channel_num_rx_ant: int 436 | :param num_bits_per_symbol: int 437 | :param ebno_db: float 438 | :param ebno_db_min: float 439 | :param ebno_db_max: float 440 | :param fec_num_iter: int 441 | """ 442 | def __init__(self, 443 | fec_type, 444 | scenario, 445 | perfect_csi, 446 | channel_num_tx_ant, 447 | channel_num_rx_ant, 448 | num_bits_per_symbol, 449 | ebno_db=None, 450 | ebno_db_min=None, 451 | ebno_db_max=None, 452 | fec_num_iter=6 453 | ): 454 | super().__init__() 455 | self.fec_type = fec_type 456 | self._scenario = scenario 457 | self._perfect_csi = perfect_csi 458 | 459 | # Internally set parameters 460 | self._carrier_frequency = 2.6e9 461 | self._fft_size = 36 462 | self._subcarrier_spacing = 15e3 463 | self._num_ofdm_symbols = 12 464 | self._cyclic_prefix_length = 6 465 | self._pilot_ofdm_symbol_indices = [2, 11] 466 | self._num_bs_ant = int(channel_num_rx_ant) 467 | self._num_ut = 1 # number of users communicating with a base station. 468 | self._num_ut_ant = int(channel_num_tx_ant) # number of antennas in a user terminal. 469 | self._num_bits_per_symbol = int(num_bits_per_symbol) 470 | self._coderate = 0.5 471 | 472 | self._dc_null = True 473 | self._num_guard_carriers = [5, 6] 474 | 475 | # Create an RX-TX association matrix 476 | # rx_tx_association[i,j]=1 means that receiver i gets at least one stream 477 | # from transmitter j. Depending on the transmission direction (uplink or downlink), 478 | # the role of UT and BS can change. 479 | bs_ut_association = np.zeros([1, self._num_ut]) 480 | bs_ut_association[0, :] = 1 481 | self._rx_tx_association = bs_ut_association 482 | self._num_tx = self._num_ut 483 | self._num_streams_per_tx = self._num_ut_ant 484 | 485 | 486 | # Setup an OFDM Resource Grid 487 | self._rg = ResourceGrid(num_ofdm_symbols=self._num_ofdm_symbols, 488 | fft_size=self._fft_size, 489 | subcarrier_spacing=self._subcarrier_spacing, 490 | num_tx=self._num_tx, 491 | num_streams_per_tx=self._num_streams_per_tx, 492 | cyclic_prefix_length=self._cyclic_prefix_length, 493 | num_guard_carriers=self._num_guard_carriers, 494 | dc_null=self._dc_null, 495 | pilot_pattern="kronecker", 496 | pilot_ofdm_symbol_indices=self._pilot_ofdm_symbol_indices) 497 | 498 | # Setup StreamManagement 499 | self._sm = StreamManagement(self._rx_tx_association, self._num_streams_per_tx) 500 | 501 | # Configure antenna arrays 502 | self._ut_array = AntennaArray( 503 | num_rows=1, 504 | num_cols=1, 505 | polarization="single", 506 | polarization_type="V", 507 | antenna_pattern="omni", 508 | carrier_frequency=self._carrier_frequency) 509 | 510 | self._bs_array = AntennaArray(num_rows=1, 511 | num_cols=int(self._num_bs_ant/2), 512 | polarization="dual", 513 | polarization_type="cross", 514 | antenna_pattern="38.901", 515 | carrier_frequency=self._carrier_frequency) 516 | 517 | # Configure the channel model 518 | if self._scenario == "umi": 519 | self._channel_model = UMi(carrier_frequency=self._carrier_frequency, 520 | o2i_model="low", 521 | ut_array=self._ut_array, 522 | bs_array=self._bs_array, 523 | direction="uplink", 524 | enable_pathloss=False, 525 | enable_shadow_fading=False) 526 | elif self._scenario == "uma": 527 | self._channel_model = UMa(carrier_frequency=self._carrier_frequency, 528 | o2i_model="low", 529 | 
ut_array=self._ut_array, 530 | bs_array=self._bs_array, 531 | direction="uplink", 532 | enable_pathloss=False, 533 | enable_shadow_fading=False) 534 | elif self._scenario == "rma": 535 | self._channel_model = RMa(carrier_frequency=self._carrier_frequency, 536 | ut_array=self._ut_array, 537 | bs_array=self._bs_array, 538 | direction="uplink", 539 | enable_pathloss=False, 540 | enable_shadow_fading=False) 541 | 542 | # Instantiate other building blocks 543 | self._qam_source = QAMSource(self._num_bits_per_symbol) 544 | 545 | self._n = int(self._rg.num_data_symbols*self._num_bits_per_symbol) # Number of coded bits 546 | self._k = int(self._n*self._coderate) # Number of information bits 547 | 548 | # FEC 549 | self.fec_num_iter = fec_num_iter 550 | if self.fec_type == 'Polar5G': 551 | self._encoder = Polar5GEncoder(self._k, self._n) 552 | self._decoder = Polar5GDecoder( 553 | self._encoder, 554 | dec_type='SC', 555 | list_size=8, 556 | num_iter=self.fec_num_iter 557 | ) 558 | elif self.fec_type == 'LDPC5G': 559 | self._encoder = LDPC5GEncoder(self._k, self._n) 560 | self._decoder = LDPC5GDecoder(self._encoder, hard_out=True, num_iter=self.fec_num_iter) 561 | else: 562 | raise ValueError(f"Invalid channel coding type: {fec_type}") 563 | 564 | self._mapper = Mapper("qam", self._num_bits_per_symbol) 565 | self._rg_mapper = ResourceGridMapper(self._rg) 566 | 567 | self._ofdm_channel = OFDMChannel(self._channel_model, self._rg, add_awgn=True, 568 | normalize_channel=True, return_channel=True) 569 | 570 | self._remove_nulled_subcarriers = RemoveNulledSubcarriers(self._rg) 571 | self._ls_est = LSChannelEstimator(self._rg, interpolation_type="nn") 572 | self._lmmse_equ = LMMSEEqualizer(self._rg, self._sm) 573 | self._demapper = Demapper("app", "qam", self._num_bits_per_symbol) 574 | 575 | # channel noise 576 | assert ebno_db is not None or (ebno_db_min is not None and ebno_db_max is not None), "Set a single ebno_db or (ebno_db_min and ebno_db_max)" 577 | if ebno_db is not None: 578 | self.ebno_db = float(ebno_db) 579 | else: 580 | self.ebno_db = ebno_db # None 581 | self.ebno_db_min = ebno_db_min 582 | self.ebno_db_max = ebno_db_max 583 | 584 | print(f'{self.ebno_db=}') 585 | print(f'{self.ebno_db_min=}') 586 | print(f'{self.ebno_db_max=}') 587 | 588 | def new_topology(self, batch_size): 589 | """Set new topology""" 590 | topology = gen_topology(batch_size, 591 | self._num_ut, 592 | self._scenario, 593 | min_ut_velocity=0.0, 594 | max_ut_velocity=0.0) 595 | 596 | self._channel_model.set_topology(*topology) 597 | 598 | 599 | @tf.function(jit_compile=True) 600 | def call(self, input): 601 | """ 602 | Input 603 | ----- 604 | :param input: 605 | 606 | Output 607 | ------ 608 | :return b_hat: 609 | """ 610 | # reshape input 611 | input_shape = input.shape 612 | 613 | # Add dummy zeros in the end of input to fit input shape to a multiple of divisor. 
614 |         divisor=self._num_tx * self._num_streams_per_tx * self._k
615 |         if np.prod(input_shape) % divisor != 0:
616 |             flatten_input = tf.reshape(input, [-1])
617 |             flatten_input_len = len(flatten_input)
618 |
619 |             dummy_cnt = ((flatten_input_len // divisor)+1) * divisor - flatten_input_len
620 |             flatten_input = tf.concat([flatten_input, [0 for _ in range(dummy_cnt)]],0)
621 |         else:
622 |             flatten_input = input
623 |
624 |         b = tf.reshape(flatten_input, (-1, self._num_tx, self._num_streams_per_tx, self._k))
625 |         batch_size = b.shape[0]
626 |
627 |         self.new_topology(batch_size)
628 |         if self.ebno_db_min is not None and self.ebno_db_max is not None:
629 |             ebno_db_tf = tf.random.uniform(shape=[batch_size], minval=self.ebno_db_min, maxval=self.ebno_db_max)
630 |             no = ebnodb2no(ebno_db_tf, self._num_bits_per_symbol, self._coderate, self._rg)
631 |         else:
632 |             no = ebnodb2no(self.ebno_db, self._num_bits_per_symbol, self._coderate, self._rg)
633 |
634 |         c = self._encoder(b)
635 |         x = self._mapper(c)
636 |         x_rg = self._rg_mapper(x)
637 |         y, h = self._ofdm_channel([x_rg, no])
638 |         if self._perfect_csi:
639 |             h_hat = self._remove_nulled_subcarriers(h)
640 |             err_var = 0.0
641 |         else:
642 |             h_hat, err_var = self._ls_est([y, no])
643 |         x_hat, no_eff = self._lmmse_equ([y, h_hat, err_var, no])
644 |         llr = self._demapper([x_hat, no_eff])
645 |         b_hat = self._decoder(llr)
646 |
647 |         if np.prod(input_shape) % divisor != 0:
648 |             # Reshape b_hat to the original shape by cutting the arbitrarily appended elements
649 |             flatten_b_hat = tf.reshape(b_hat, [-1])
650 |             sliced_b_hat = flatten_b_hat[:-dummy_cnt]
651 |             b_hat=tf.reshape(sliced_b_hat, input_shape)
652 |         else:
653 |             b_hat=tf.reshape(b_hat, input_shape)
654 |
655 |         return b_hat
656 |
657 | class ChannelFlatFading(tf.keras.Model):
658 |     """
659 |     Configure FlatFading Channel components (for a simulation over a Rayleigh fading channel).
660 | ref: https://nvlabs.github.io/sionna/examples/Simple_MIMO_Simulation.html?highlight=rayleigh 661 | 662 | Parameters 663 | ---------- 664 | :param fec_type: str, One of ["Polar5G", "LDPC5G"] 665 | :param channel_num_tx_ant: int 666 | :param channel_num_rx_ant: int 667 | :param num_bits_per_symbol: int 668 | :param fec_n: int 669 | :param fec_k: int 670 | :param ebno_db: float 671 | :param ebno_db_min: float 672 | :param ebno_db_max: float 673 | :param fec_num_iter: int 674 | 675 | """ 676 | def __init__(self, 677 | fec_type, 678 | channel_num_tx_ant, 679 | channel_num_rx_ant, 680 | num_bits_per_symbol, 681 | fec_n, 682 | fec_k, 683 | ebno_db=None, 684 | ebno_db_min=None, 685 | ebno_db_max=None, 686 | fec_num_iter=6 687 | ): 688 | super().__init__() 689 | self.fec_type = fec_type 690 | 691 | self._n = fec_n 692 | self._k = fec_k 693 | self._coderate = self._k / self._n 694 | 695 | constellation = Constellation("qam", 696 | num_bits_per_symbol, 697 | trainable=False) 698 | logger.info(f'Constellation: type={constellation._constellation_type} ' + \ 699 | f'{num_bits_per_symbol=} trainable={constellation._trainable}') 700 | self.num_bits_per_symbol = num_bits_per_symbol 701 | self.mapper = Mapper(constellation=constellation) 702 | 703 | self.channel_num_tx_ant = int(channel_num_tx_ant) 704 | self.channel_num_rx_ant = int(channel_num_rx_ant) 705 | self.channel = FlatFadingChannel(self.channel_num_tx_ant, self.channel_num_rx_ant, add_awgn=True, return_channel=True) 706 | 707 | # channel noise 708 | assert ebno_db is not None or (ebno_db_min is not None and ebno_db_max is not None), "Set a single ebno_db or (ebno_db_min and ebno_db_max)" 709 | if ebno_db is not None: 710 | self.ebno_db = float(ebno_db) 711 | else: 712 | self.ebno_db = ebno_db # None 713 | self.ebno_db_min = ebno_db_min 714 | self.ebno_db_max = ebno_db_max 715 | 716 | print(f'{self.ebno_db=}') 717 | print(f'{self.ebno_db_min=}') 718 | print(f'{self.ebno_db_max=}') 719 | 720 | self.demapper = Demapper("app", constellation=constellation) 721 | 722 | # FEC 723 | self.fec_num_iter = fec_num_iter 724 | if self.fec_type == 'Polar5G': 725 | self._encoder = Polar5GEncoder(self._k, self._n) 726 | self._decoder = Polar5GDecoder( 727 | self._encoder, 728 | dec_type='SC', 729 | list_size=8, 730 | num_iter=self.fec_num_iter 731 | ) 732 | elif self.fec_type == 'LDPC5G': 733 | self._encoder = LDPC5GEncoder(self._k, self._n) 734 | self._decoder = LDPC5GDecoder(self._encoder, hard_out=True, num_iter=self.fec_num_iter) 735 | else: 736 | raise ValueError(f"Invalid channel coding type: {fec_type}") 737 | 738 | 739 | @tf.function(jit_compile=True) 740 | def call(self, input): 741 | ''' 742 | Input 743 | ----- 744 | :param input: 745 | 746 | Output 747 | ------ 748 | :return b_hat: 749 | ''' 750 | # reshape input 751 | input_shape = input.shape 752 | 753 | # Add dummy zeros in the end of input to fit input shape to a multiple of divisor. 
754 | divisor=self._k 755 | if np.prod(input_shape) % divisor != 0: 756 | flatten_input = tf.reshape(input, [-1]) 757 | flatten_input_len = len(flatten_input) 758 | 759 | dummy_cnt = ((flatten_input_len // divisor)+1) * divisor - flatten_input_len 760 | flatten_input = tf.concat([flatten_input, [0 for _ in range(dummy_cnt)]],0) 761 | else: 762 | flatten_input = input 763 | 764 | # Channel encoder 765 | b = tf.reshape(flatten_input, (-1, self.channel_num_tx_ant, self._k)) 766 | codewords = self._encoder(b) 767 | 768 | # Modulation 769 | x = self.mapper(codewords) 770 | shape = tf.shape(x) 771 | x = tf.reshape(x, (-1, self.channel_num_tx_ant)) 772 | 773 | ##################### 774 | # Channel 775 | ##################### 776 | # Sampling a batch of SNRs 777 | batch_size=b.shape[0] 778 | if self.ebno_db_min is not None and self.ebno_db_max is not None: 779 | ebno_db_tf = tf.random.uniform(shape=[batch_size], minval=self.ebno_db_min, maxval=self.ebno_db_max) 780 | no = ebnodb2no(ebno_db_tf, self.num_bits_per_symbol, self._coderate) 781 | else: 782 | no = ebnodb2no(self.ebno_db, self.num_bits_per_symbol, self._coderate) 783 | 784 | no *= np.sqrt(self.channel_num_rx_ant) 785 | 786 | y, h = self.channel([x, no]) 787 | s = tf.complex(no*tf.eye(self.channel_num_rx_ant, self.channel_num_rx_ant), 0.0) 788 | 789 | # x_hat, no_eff = mf_equalizer(y, h, s) 790 | x_hat, no_eff = lmmse_equalizer(y, h, s) 791 | 792 | x_hat = tf.reshape(x_hat, shape) 793 | no_eff = tf.reshape(no_eff, shape) 794 | 795 | ##################### 796 | # Receiver 797 | ##################### 798 | # Demodulation 799 | llr = self.demapper([x_hat, no_eff]) 800 | # llr = tf.reshape(llr, (-1, self._n)) 801 | 802 | # Channel decoder 803 | b_hat = self._decoder(llr) 804 | 805 | if np.prod(input_shape) % divisor != 0: 806 | #Reshape b_hat to the original shape by cutting the arbitrarily appended elements 807 | flatten_b_hat = tf.reshape(b_hat, [-1]) 808 | sliced_b_hat = flatten_b_hat[:-dummy_cnt] 809 | b_hat=tf.reshape(sliced_b_hat, input_shape) 810 | else: 811 | b_hat=tf.reshape(b_hat, input_shape) 812 | 813 | return b_hat 814 | -------------------------------------------------------------------------------- /models/on_device_ai_comm.py: -------------------------------------------------------------------------------- 1 | ''' 2 | MIMO OFDM Transmissions over the CDL Channel Model 3 | https://nvlabs.github.io/sionna/examples/MIMO_OFDM_Transmissions_over_CDL.html 4 | ''' 5 | from dataclasses import dataclass 6 | from typing import Optional, Tuple, Union 7 | 8 | from transformers import TFBartPretrainedModel, TFBartForConditionalGeneration 9 | from transformers.models.bart.modeling_tf_bart import TFBartMainLayer, BartConfig, shift_tokens_right, TFBartEncoder 10 | from transformers.modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqModelOutput 11 | from transformers.modeling_tf_utils import unpack_inputs, TFModelInputType 12 | import tensorflow as tf 13 | import numpy as np 14 | 15 | import sionna 16 | sionna.Config.xla_compat=True 17 | 18 | from transformers.utils import ( 19 | logging, 20 | ) 21 | 22 | from .channels import ChannelAWGN, ChannelCDL, ChannelSL, ChannelFlatFading 23 | from .utils import tensor_to_binary_v2, binary_to_tensor_v2, replace_nan, get_ber 24 | from .vq_vae import VectorQuantizer 25 | 26 | from transformers.tf_utils import shape_list 27 | from transformers.modeling_tf_outputs import TFSeq2SeqLMOutput, TFBaseModelOutput 28 | 29 | logger = logging.get_logger("transformers") 30 | 31 | @dataclass 32 | class 
TFEncoderChannelModelOutput(TFBaseModelOutput): 33 | """Output of TFAISrcEncoderAndChannel that includes 34 | - AI-Src encoder 35 | - channel encoder 36 | - mapper 37 | - channel 38 | - demapper 39 | - channel decoder 40 | """ 41 | ber: Optional[tf.Tensor] = None 42 | 43 | @dataclass 44 | class TFOnDeviceAICMainLayerOutput(TFSeq2SeqModelOutput): 45 | """Output of TFOnDeviceAICMainLayer""" 46 | ber: Optional[tf.Tensor] = None 47 | 48 | @dataclass 49 | class TFOnDeviceAICOutput(TFSeq2SeqLMOutput): 50 | """Output of TFOnDeviceAICForConditionalGeneration""" 51 | ber: Optional[tf.Tensor] = None 52 | 53 | class TFAISrcEncoderAndChannel(tf.keras.layers.Layer): 54 | """ 55 | This class includes AI-Src-Encoder and Channel(Channel En/Decoder, Channel, Mapper, etc.) 56 | """ 57 | 58 | def __init__(self, 59 | ai_src_encoder: TFBartEncoder, 60 | vq_layer: VectorQuantizer, 61 | ebno_db, 62 | ebno_db_min, 63 | ebno_db_max, 64 | channel_type, 65 | cdl_model, 66 | scenario, 67 | perfect_csi, 68 | fec_type, 69 | channel_num_tx_ant, 70 | channel_num_rx_ant, 71 | num_bits_per_symbol=4, 72 | fec_k=512, 73 | fec_n=1024, 74 | fec_num_iter=6, 75 | bin_conv_method='tanh', 76 | do_train=False 77 | ): 78 | # NOTE: setting layer name as follows seems strange, 79 | # but it allows HuggingFace to load pretrained weight properly 80 | super().__init__(name='model/model/') 81 | self.config = ai_src_encoder.config 82 | self.ai_src_encoder = ai_src_encoder 83 | self.ai_src_encoder.trainable = False 84 | 85 | # If Training TFOnDeviceAICForConditionalGeneration or not 86 | self.do_train = do_train 87 | 88 | # make sure data types are proper. 89 | num_bits_per_symbol = int(num_bits_per_symbol) 90 | 91 | # Configure Channel Model 92 | if channel_type == 'AWGN': 93 | ch_config = { 94 | 'fec_type': fec_type, 95 | 'num_bits_per_symbol': num_bits_per_symbol, 96 | 'fec_n': fec_n, 97 | 'fec_k': fec_k, 98 | 'ebno_db' : ebno_db, 99 | 'ebno_db_min': ebno_db_min, 100 | 'ebno_db_max': ebno_db_max, 101 | 'fec_num_iter': fec_num_iter, 102 | } 103 | elif channel_type == 'CDL': 104 | ch_config = { 105 | 'fec_type': fec_type, 106 | 'cdl_model': cdl_model, 107 | 'channel_num_tx_ant': channel_num_tx_ant, 108 | 'channel_num_rx_ant': channel_num_rx_ant, 109 | 'num_bits_per_symbol': num_bits_per_symbol, 110 | 'ebno_db' : ebno_db, 111 | 'ebno_db_min': ebno_db_min, 112 | 'ebno_db_max': ebno_db_max, 113 | 'fec_num_iter': fec_num_iter, 114 | } 115 | elif channel_type == '3GPP-38.901': 116 | ch_config = { 117 | 'fec_type': fec_type, 118 | 'scenario': scenario, 119 | 'perfect_csi': perfect_csi, 120 | 'channel_num_tx_ant': channel_num_tx_ant, 121 | 'channel_num_rx_ant': channel_num_rx_ant, 122 | 'num_bits_per_symbol': num_bits_per_symbol, 123 | 'ebno_db' : ebno_db, 124 | 'ebno_db_min': ebno_db_min, 125 | 'ebno_db_max': ebno_db_max, 126 | 'fec_num_iter': fec_num_iter, 127 | } 128 | elif channel_type == 'FlatFading': 129 | ch_config = { 130 | 'fec_type': fec_type, 131 | 'channel_num_tx_ant': channel_num_tx_ant, 132 | 'channel_num_rx_ant': channel_num_rx_ant, 133 | 'num_bits_per_symbol': num_bits_per_symbol, 134 | 'fec_n': fec_n, 135 | 'fec_k': fec_k, 136 | 'ebno_db' : ebno_db, 137 | 'ebno_db_min': ebno_db_min, 138 | 'ebno_db_max': ebno_db_max, 139 | 'fec_num_iter': fec_num_iter, 140 | } 141 | else: 142 | raise ValueError('Invalid Channel type. 
Channel type should be one of AWGN, CDL, 3GPP-38.901, or FlatFading')
143 | 
144 |         # define vq
145 |         if bin_conv_method == 'vector_quantization':
146 |             self.vq_layer = vq_layer
147 |             self.num_embeddings = vq_layer.num_embeddings
148 |             self.embedding_dim = vq_layer.embedding_dim
149 | 
150 |         # setup
151 |         self._setup_channel(channel_type, ch_config)
152 |         self._setup_bin_conv(bin_conv_method)
153 | 
154 |     def _setup_channel(self, channel_type, ch_config):
155 |         if channel_type == 'AWGN':
156 |             self.channel_model = ChannelAWGN(
157 |                 fec_type=ch_config['fec_type'],
158 |                 num_bits_per_symbol=ch_config['num_bits_per_symbol'],
159 |                 fec_n=ch_config['fec_n'],
160 |                 fec_k=ch_config['fec_k'],
161 |                 ebno_db=ch_config['ebno_db'],
162 |                 ebno_db_min=ch_config['ebno_db_min'],
163 |                 ebno_db_max=ch_config['ebno_db_max'],
164 |                 fec_num_iter=ch_config['fec_num_iter']
165 |             )
166 |         elif channel_type == 'CDL':
167 |             self.channel_model = ChannelCDL(
168 |                 fec_type=ch_config['fec_type'],
169 |                 cdl_model=ch_config['cdl_model'],
170 |                 channel_num_tx_ant=ch_config['channel_num_tx_ant'],
171 |                 channel_num_rx_ant=ch_config['channel_num_rx_ant'],
172 |                 num_bits_per_symbol=ch_config['num_bits_per_symbol'],
173 |                 ebno_db=ch_config['ebno_db'],
174 |                 ebno_db_min=ch_config['ebno_db_min'],
175 |                 ebno_db_max=ch_config['ebno_db_max'],
176 |                 fec_num_iter=ch_config['fec_num_iter']
177 |             )
178 |         elif channel_type == '3GPP-38.901':
179 |             self.channel_model = ChannelSL(
180 |                 fec_type=ch_config['fec_type'],
181 |                 scenario=ch_config['scenario'],
182 |                 perfect_csi=ch_config['perfect_csi'],
183 |                 channel_num_tx_ant=ch_config['channel_num_tx_ant'],
184 |                 channel_num_rx_ant=ch_config['channel_num_rx_ant'],
185 |                 num_bits_per_symbol=ch_config['num_bits_per_symbol'],
186 |                 ebno_db=ch_config['ebno_db'],
187 |                 ebno_db_min=ch_config['ebno_db_min'],
188 |                 ebno_db_max=ch_config['ebno_db_max'],
189 |                 fec_num_iter=ch_config['fec_num_iter']
190 |             )
191 |         elif channel_type == 'FlatFading':
192 |             self.channel_model = ChannelFlatFading(
193 |                 fec_type=ch_config['fec_type'],
194 |                 channel_num_tx_ant=ch_config['channel_num_tx_ant'],
195 |                 channel_num_rx_ant=ch_config['channel_num_rx_ant'],
196 |                 num_bits_per_symbol=ch_config['num_bits_per_symbol'],
197 |                 fec_n=ch_config['fec_n'],
198 |                 fec_k=ch_config['fec_k'],
199 |                 ebno_db=ch_config['ebno_db'],
200 |                 ebno_db_min=ch_config['ebno_db_min'],
201 |                 ebno_db_max=ch_config['ebno_db_max'],
202 |                 fec_num_iter=ch_config['fec_num_iter']
203 |             )
204 |         else:
205 |             raise ValueError('Invalid Channel type. Channel type should be one of AWGN, CDL, 3GPP-38.901, or FlatFading')
206 | 
207 |     def _setup_bin_conv(self, bin_conv_method):
208 |         self.bin_conv_method = bin_conv_method
209 |         logger.info(f'{bin_conv_method=}')
210 |         if bin_conv_method == 'naive':
211 |             self._tensor_to_binary = [tensor_to_binary_v2]
212 |             self._binary_to_tensor = [binary_to_tensor_v2]
213 |         elif bin_conv_method == 'tanh':
214 |             self._tensor_to_binary = [tensor_to_binary_v2]
215 |             self._binary_to_tensor = [binary_to_tensor_v2, tf.math.tanh]
216 |         elif bin_conv_method == 'vector_quantization':
217 |             self._tensor_to_binary = [self.vq_layer.get_code_indices, tensor_to_binary_v2]
218 |             self._binary_to_tensor = [binary_to_tensor_v2,
219 |                                       self.vq_layer.handle_invalid_values,
220 |                                       self.vq_layer.reconstruct_with_indices]
221 |         else:
222 |             raise ValueError(f'Invalid bin_conv_method: {bin_conv_method}')
223 | 
224 | 
225 |     @unpack_inputs
226 |     def call(
227 |         self,
228 |         input_ids: Optional[TFModelInputType] = None,
229 |         inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
230 |         attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
231 |         head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
232 |         output_attentions: Optional[bool] = None,
233 |         output_hidden_states: Optional[bool] = None,
234 |         return_dict: Optional[bool] = None,
235 |         training: Optional[bool] = False,
236 |     ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
237 |         # run the AI-Src encoder (BART)
238 |         ai_src_encoder_outputs = self.ai_src_encoder(
239 |             input_ids=input_ids,
240 |             inputs_embeds=inputs_embeds,
241 |             attention_mask=attention_mask,
242 |             head_mask=head_mask,
243 |             output_attentions=output_attentions,
244 |             output_hidden_states=output_hidden_states,
245 |             return_dict=return_dict,
246 |             training=training,
247 |         )
248 | 
249 |         shape = tf.shape(ai_src_encoder_outputs.last_hidden_state)
250 |         ai_src_encoder_output = ai_src_encoder_outputs.last_hidden_state
251 | 
252 |         # binarize (and get codebook indices first if vector quantizing)
253 |         for f in self._tensor_to_binary:
254 |             ai_src_encoder_output = f(ai_src_encoder_output)
255 |         ai_src_encoder_output_binary = ai_src_encoder_output
256 | 
257 |         # add channel noise to the binarized ai_src_encoder output
258 |         last_hidden_state_binary = self.channel_model(ai_src_encoder_output_binary)
259 | 
260 |         # convert back to tensors and denoise (if vector quantizing, reconstruct the tensors fed to the semantic decoder from the codebook)
261 |         last_hidden_state = last_hidden_state_binary
262 |         for f in self._binary_to_tensor:
263 |             last_hidden_state = f(last_hidden_state)
264 |         last_hidden_state_pred = tf.reshape(last_hidden_state, shape)
265 | 
266 |         last_hidden_state_pred = replace_nan(last_hidden_state_pred, 0)  # convert all NaN values to zero
267 |         if self.bin_conv_method != 'naive':
268 |             tf.debugging.assert_all_finite(last_hidden_state_pred, 'should not have nan/inf/-inf')
269 | 
270 |         # calculate BER in eval mode only.
271 |         if not self.do_train:
272 |             last_hidden_state_binary = tf.reshape(last_hidden_state_binary, tf.shape(ai_src_encoder_output_binary))
273 |             ber = get_ber(ai_src_encoder_output_binary, last_hidden_state_binary)
274 |         else:
275 |             # BER is not calculated during training.
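            # (-1.0 acts as a "not measured" sentinel here, so downstream
            #  metrics code can simply drop negative BER values, e.g.
            #  `bers = [b for b in bers if b >= 0.0]`)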
276 | ber = tf.constant(-1.0, dtype=tf.float32) 277 | 278 | return TFEncoderChannelModelOutput( 279 | last_hidden_state=last_hidden_state_pred, 280 | hidden_states=ai_src_encoder_outputs.hidden_states, 281 | attentions=ai_src_encoder_outputs.attentions, 282 | ber=ber, 283 | ) 284 | 285 | 286 | class TFOnDeviceAICMainLayer(tf.keras.layers.Layer): 287 | 288 | def __init__(self, 289 | config: BartConfig, 290 | bart_main_layer: TFBartMainLayer, 291 | ebno_db, 292 | ebno_db_min, 293 | ebno_db_max, 294 | fec_k=512, 295 | fec_n=1024, 296 | fec_num_iter=6, 297 | num_bits_per_symbol=4, 298 | channel_type = 'AWGN', 299 | cdl_model='A', 300 | scenario='umi', 301 | perfect_csi=True, 302 | channel_num_tx_ant=1, 303 | channel_num_rx_ant=1, 304 | fec_type=None, 305 | bin_conv_method='tanh', 306 | embedding_dim=2, 307 | num_embeddings=1024, 308 | do_train=False, 309 | **kwargs): 310 | super().__init__(**kwargs) 311 | 312 | self.config = config 313 | self.shared = bart_main_layer.get_input_embeddings() 314 | self.shared.trainable = False 315 | 316 | self.bin_conv_method = bin_conv_method 317 | # VectorQuantizer layer 318 | if self.bin_conv_method == 'vector_quantization': 319 | print(f'{embedding_dim=}') 320 | print(f'{num_embeddings=}') # number of codebooks. 321 | assert (embedding_dim is not None) and (num_embeddings is not None) \ 322 | , "For vector_quantization, set embedding_dim and num_embeddings arguments." 323 | self.vq_layer = VectorQuantizer(num_embeddings, embedding_dim, name="vector_quantizer") 324 | else: 325 | self.vq_layer = None 326 | 327 | # ai source encoder 328 | self.encoder = TFAISrcEncoderAndChannel( 329 | ai_src_encoder=bart_main_layer.encoder, 330 | vq_layer = self.vq_layer, 331 | ebno_db=ebno_db, 332 | ebno_db_min=ebno_db_min, 333 | ebno_db_max=ebno_db_max, 334 | fec_k=fec_k, 335 | fec_n=fec_n, 336 | fec_num_iter=fec_num_iter, 337 | num_bits_per_symbol=num_bits_per_symbol, 338 | channel_type=channel_type, 339 | cdl_model=cdl_model, 340 | scenario=scenario, 341 | perfect_csi=perfect_csi, 342 | channel_num_tx_ant=channel_num_tx_ant, 343 | channel_num_rx_ant=channel_num_rx_ant, 344 | fec_type=fec_type, 345 | bin_conv_method=bin_conv_method, 346 | do_train=do_train) 347 | self.decoder = bart_main_layer.decoder 348 | 349 | def get_input_embeddings(self): 350 | return self.shared 351 | 352 | def set_input_embeddings(self, new_embeddings): 353 | self.shared = new_embeddings 354 | self.encoder.encoder.embed_tokens = self.shared 355 | self.decoder.embed_tokens = self.shared 356 | 357 | @unpack_inputs 358 | def call(self, 359 | input_ids: Optional[TFModelInputType] = None, 360 | attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, 361 | decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, 362 | decoder_attention_mask: Optional[Union[np.ndarray, 363 | tf.Tensor]] = None, 364 | decoder_position_ids: Optional[Union[np.ndarray, 365 | tf.Tensor]] = None, 366 | head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, 367 | decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, 368 | cross_attn_head_mask: Optional[Union[np.ndarray, 369 | tf.Tensor]] = None, 370 | encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, 371 | past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, 372 | tf.Tensor]]]] = None, 373 | inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, 374 | decoder_inputs_embeds: Optional[Union[np.ndarray, 375 | tf.Tensor]] = None, 376 | use_cache: Optional[bool] = None, 377 | output_attentions: Optional[bool] = None, 378 | 
output_hidden_states: Optional[bool] = None, 379 | return_dict: Optional[bool] = None, 380 | training: Optional[bool] = False, 381 | **kwargs) -> Union[TFOnDeviceAICMainLayerOutput, Tuple[tf.Tensor]]: 382 | # different to other models, Bart automatically creates decoder_input_ids from 383 | # input_ids if no decoder_input_ids are provided 384 | if decoder_input_ids is None and decoder_inputs_embeds is None: 385 | if input_ids is None: 386 | raise ValueError( 387 | "If no `decoder_input_ids` or `decoder_inputs_embeds` are " 388 | "passed, `input_ids` cannot be `None`. Please pass either " 389 | "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." 390 | ) 391 | 392 | decoder_input_ids = shift_tokens_right( 393 | input_ids, self.config.pad_token_id, 394 | self.config.decoder_start_token_id) 395 | 396 | if encoder_outputs is None: 397 | encoder_outputs = self.encoder( 398 | input_ids=input_ids, 399 | attention_mask=attention_mask, 400 | head_mask=head_mask, 401 | inputs_embeds=inputs_embeds, 402 | output_attentions=output_attentions, 403 | output_hidden_states=output_hidden_states, 404 | return_dict=return_dict, 405 | training=training, 406 | ) 407 | 408 | # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True 409 | elif return_dict and not isinstance(encoder_outputs, 410 | TFBaseModelOutput): 411 | encoder_outputs = TFBaseModelOutput( 412 | last_hidden_state=encoder_outputs[0], 413 | hidden_states=encoder_outputs[1] 414 | if len(encoder_outputs) > 1 else None, 415 | attentions=encoder_outputs[2] 416 | if len(encoder_outputs) > 2 else None, 417 | ) 418 | 419 | # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False 420 | elif not return_dict and not isinstance(encoder_outputs, tuple): 421 | encoder_outputs = encoder_outputs.to_tuple() 422 | 423 | # call VectorQuantizer to train VectorQuantizer 424 | if self.bin_conv_method == 'vector_quantization': 425 | # encoder_outputs.last_hidden_state = self.vq_layer(encoder_outputs.last_hidden_state) 426 | self.vq_layer(encoder_outputs.last_hidden_state) 427 | 428 | decoder_outputs = self.decoder( 429 | decoder_input_ids, 430 | attention_mask=decoder_attention_mask, 431 | position_ids=decoder_position_ids, 432 | encoder_hidden_states=encoder_outputs[0], 433 | encoder_attention_mask=attention_mask, 434 | head_mask=decoder_head_mask, 435 | cross_attn_head_mask=cross_attn_head_mask, 436 | past_key_values=past_key_values, 437 | inputs_embeds=decoder_inputs_embeds, 438 | use_cache=use_cache, 439 | output_attentions=output_attentions, 440 | output_hidden_states=output_hidden_states, 441 | return_dict=return_dict, 442 | training=training, 443 | ) 444 | 445 | if not return_dict: 446 | return decoder_outputs + encoder_outputs 447 | 448 | return TFOnDeviceAICMainLayerOutput( 449 | last_hidden_state=decoder_outputs.last_hidden_state, 450 | past_key_values=decoder_outputs.past_key_values, 451 | decoder_hidden_states=decoder_outputs.hidden_states, 452 | decoder_attentions=decoder_outputs.attentions, 453 | cross_attentions=decoder_outputs.cross_attentions, 454 | encoder_last_hidden_state=encoder_outputs.last_hidden_state, 455 | encoder_hidden_states=encoder_outputs.hidden_states, 456 | encoder_attentions=encoder_outputs.attentions, 457 | ber=encoder_outputs.ber, 458 | ) 459 | 460 | 461 | class TFOnDeviceAICForConditionalGeneration(TFBartForConditionalGeneration): 462 | def __init__(self, 463 | config, 464 | ebno_db=None, 465 | ebno_db_min=None, 466 | 
ebno_db_max=None, 467 | fec_k=512, 468 | fec_n=1024, 469 | fec_num_iter=6, 470 | num_bits_per_symbol=4, 471 | channel_type = 'AWGN', 472 | cdl_model='A', 473 | scenario='umi', 474 | perfect_csi=True, 475 | channel_num_tx_ant = 1, 476 | channel_num_rx_ant = 1, 477 | fec_type = None, 478 | bin_conv_method='tanh', 479 | embedding_dim=None, 480 | num_embeddings=None, 481 | do_train=False, 482 | *inputs, 483 | **kwargs): 484 | super().__init__(config, *inputs, **kwargs) 485 | self.model = TFOnDeviceAICMainLayer( 486 | config, 487 | bart_main_layer=self.model, 488 | ebno_db=ebno_db, 489 | ebno_db_min=ebno_db_min, 490 | ebno_db_max=ebno_db_max, 491 | fec_k=fec_k, 492 | fec_n=fec_n, 493 | fec_num_iter=fec_num_iter, 494 | num_bits_per_symbol=num_bits_per_symbol, 495 | channel_type=channel_type, 496 | cdl_model=cdl_model, 497 | scenario=scenario, 498 | perfect_csi=perfect_csi, 499 | channel_num_tx_ant=channel_num_tx_ant, 500 | channel_num_rx_ant=channel_num_rx_ant, 501 | fec_type=fec_type, 502 | bin_conv_method=bin_conv_method, 503 | embedding_dim=embedding_dim, 504 | num_embeddings=num_embeddings, 505 | do_train=do_train, 506 | name="model") 507 | 508 | self.bin_conv_method = bin_conv_method 509 | self.VQ_LOSS_WEIGHT = 0.01 510 | 511 | @unpack_inputs 512 | def call( 513 | self, 514 | input_ids: Optional[TFModelInputType] = None, 515 | attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, 516 | decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, 517 | decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, 518 | decoder_position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, 519 | head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, 520 | decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, 521 | cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, 522 | encoder_outputs: Optional[TFBaseModelOutput] = None, 523 | past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, 524 | inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, 525 | decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, 526 | use_cache: Optional[bool] = None, 527 | output_attentions: Optional[bool] = None, 528 | output_hidden_states: Optional[bool] = None, 529 | return_dict: Optional[bool] = None, 530 | labels: Optional[tf.Tensor] = None, 531 | training: Optional[bool] = False, 532 | ) -> Union[TFSeq2SeqLMOutput, Tuple[tf.Tensor]]: 533 | 534 | if labels is not None: 535 | labels = tf.where( 536 | labels == self.config.pad_token_id, 537 | tf.cast(tf.fill(shape_list(labels), -100), labels.dtype), 538 | labels, 539 | ) 540 | use_cache = False 541 | if decoder_input_ids is None and decoder_inputs_embeds is None: 542 | decoder_input_ids = shift_tokens_right( 543 | labels, self.config.pad_token_id, self.config.decoder_start_token_id 544 | ) 545 | 546 | outputs = self.model( 547 | input_ids, 548 | attention_mask=attention_mask, 549 | decoder_input_ids=decoder_input_ids, 550 | encoder_outputs=encoder_outputs, 551 | decoder_attention_mask=decoder_attention_mask, 552 | decoder_position_ids=decoder_position_ids, 553 | head_mask=head_mask, 554 | decoder_head_mask=decoder_head_mask, 555 | cross_attn_head_mask=cross_attn_head_mask, 556 | past_key_values=past_key_values, 557 | inputs_embeds=inputs_embeds, 558 | decoder_inputs_embeds=decoder_inputs_embeds, 559 | use_cache=use_cache, 560 | output_attentions=output_attentions, 561 | output_hidden_states=output_hidden_states, 562 | return_dict=return_dict, 563 | training=training, 
564 | ) 565 | lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True) 566 | lm_logits = self.bias_layer(lm_logits) 567 | masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits) 568 | 569 | # add weighted vq loss into masked_lm_loss 570 | if self.bin_conv_method == 'vector_quantization' and masked_lm_loss is not None: 571 | vq_loss = self.VQ_LOSS_WEIGHT * sum(self.model.vq_layer.losses) 572 | # tf.print( 573 | # masked_lm_loss, 574 | # vq_loss, 575 | # sep=',', 576 | # output_stream='./joohan/seq2seq-sc2/loss.log') 577 | masked_lm_loss += vq_loss # multiply weight to vq_loss 578 | 579 | if not return_dict: 580 | output = (lm_logits,) + outputs[1:] 581 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 582 | 583 | return TFOnDeviceAICOutput( 584 | loss=masked_lm_loss, 585 | logits=lm_logits, 586 | past_key_values=outputs.past_key_values, # index 1 of d outputs 587 | decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs 588 | decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs 589 | cross_attentions=outputs.cross_attentions, # index 4 of d outputs 590 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs 591 | encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out 592 | encoder_attentions=outputs.encoder_attentions, # 2 of e out 593 | ber=outputs.ber, 594 | ) 595 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | @tf.function 5 | def logit_to_binary(x): 6 | return (tf.math.sign(x) + 1.0) / 2.0 7 | 8 | @tf.function 9 | def tensor_to_binary_u32(x): 10 | # convert to bits 11 | bits = [None for i in range(32)] 12 | for i in range(32): 13 | bits[i] = tf.bitwise.bitwise_and(x, 1) 14 | x = tf.bitwise.right_shift(x, 1) 15 | res = tf.stack(bits, axis=-1) 16 | return tf.cast(res, tf.float32) 17 | 18 | @tf.function 19 | def tensor_to_binary_v2(x): 20 | LOG_BASE2 = tf.math.log(tf.constant([2.0], tf.float32)) 21 | TO_MANTISSA = tf.constant([1<<23], tf.float32) 22 | 23 | # sign 24 | sign = tf.cast(x < 0.0, tf.float32) 25 | x = tf.math.abs(x) 26 | 27 | # exponent 28 | log_x = tf.math.floor(tf.math.log(x) / LOG_BASE2) 29 | exponent = tf.cast(log_x + 127.0, tf.uint8) 30 | 31 | # mantissa 32 | mantissa = x / tf.math.exp(log_x*LOG_BASE2) - tf.math.sign(x) 33 | mantissa = tf.math.floor(mantissa * TO_MANTISSA) 34 | mantissa = tf.cast(mantissa, tf.int32) 35 | 36 | # convert to bits 37 | bits = [None for i in range(32)] 38 | for i in range(23): 39 | bits[i] = tf.bitwise.bitwise_and(mantissa, 1) 40 | mantissa = tf.bitwise.right_shift(mantissa, 1) 41 | for i in range(23, 31): 42 | bits[i] = tf.bitwise.bitwise_and(exponent, 1) 43 | exponent = tf.bitwise.right_shift(exponent, 1) 44 | bits[31] = sign 45 | 46 | for i in range(32): 47 | bits[i] = tf.cast(bits[i], tf.float32) 48 | res = tf.stack(bits, axis=-1) 49 | return res 50 | 51 | @tf.function 52 | def binary_to_tensor_u32(x: tf.Tensor): 53 | x = tf.reshape(x, (-1, 32)) 54 | x = tf.cast(x, tf.uint32) 55 | out = x[:, 0] 56 | for i in range(1, 32): 57 | out += tf.bitwise.left_shift(x[:, i], i) 58 | return out 59 | 60 | @tf.function 61 | def binary_to_tensor_v2(x: tf.Tensor): 62 | LOG_BASE2 = tf.math.log(tf.constant([2.0], tf.float32)) 63 | EXPONENTS = tf.constant([ float(1 << i) for i in range(8)], tf.float32) 64 | 
FROM_MANTISSA = tf.constant([ 0.5**(23-i) for i in range(23)], tf.float32) 65 | 66 | x = tf.reshape(x, (-1, 32)) 67 | sign = -x[:, 31] * 2 + 1 68 | 69 | exponent = tf.math.reduce_sum(x[:, 23:31] * EXPONENTS, axis=-1) 70 | mantissa = tf.math.reduce_sum(x[:, :23] * FROM_MANTISSA, axis=-1) 71 | mantissa += tf.cast(exponent > 0.0, tf.float32) 72 | return sign * tf.math.exp((exponent - 127.0) * LOG_BASE2) * mantissa 73 | 74 | @tf.function 75 | def tensor_to_binary(x: tf.Tensor): 76 | x = tf.bitcast(x, tf.uint32) 77 | mask = tf.ones_like(x) 78 | bit0 = tf.cast(tf.reshape(tf.bitwise.bitwise_and(x, mask), (1, -1)), 79 | tf.float32) 80 | bits = [bit0] 81 | 82 | for _ in range(31): 83 | x = tf.bitwise.right_shift(x, 1) 84 | bitn = tf.cast(tf.reshape(tf.bitwise.bitwise_and(x, mask), (1, -1)), 85 | tf.float32) 86 | bits.append(bitn) 87 | 88 | return tf.concat(bits, axis=0) 89 | 90 | @tf.function 91 | def binary_to_tensor(x: tf.Tensor): 92 | x = tf.cast(x, tf.uint32) 93 | x = tf.reshape(x, (32, -1)) 94 | 95 | shape = tf.shape(x) 96 | out = tf.zeros((shape[1],), dtype=tf.uint32) 97 | for i in range(32): 98 | bitn = tf.bitwise.left_shift(x[i, :], i) 99 | out = tf.bitwise.bitwise_xor(out, bitn) 100 | 101 | return tf.bitcast(out, tf.float32) 102 | 103 | 104 | @tf.function 105 | def replace_nan(input, new_value = 0.0): 106 | new_value = float(new_value) 107 | indices = tf.where(tf.math.is_nan(input)) 108 | res = tf.tensor_scatter_nd_update( 109 | input, 110 | indices, 111 | tf.fill((tf.shape(indices)[0], ), new_value) 112 | ) 113 | return res 114 | 115 | INF = tf.constant(np.array([np.inf]), dtype=tf.float32) 116 | 117 | @tf.function 118 | def replace_nan_to_inf(x): 119 | EXPONENT_MASK = tf.constant([0x7F800000], dtype=tf.uint32) 120 | INF_MASK = tf.constant([0xFF800000], dtype=tf.uint32) 121 | IDENTITY_MASK = tf.constant([0xFFFFFFFF], dtype=tf.uint32) 122 | x = tf.bitcast(x, tf.uint32) 123 | mask = tf.where( 124 | tf.equal(tf.bitwise.bitwise_and(x, EXPONENT_MASK), EXPONENT_MASK), 125 | INF_MASK, IDENTITY_MASK 126 | ) 127 | return tf.bitcast(tf.bitwise.bitwise_and(x, mask), tf.float32) 128 | 129 | @tf.function 130 | def get_ber(b, b_hat): 131 | bit_errors = tf.math.count_nonzero(b != b_hat, dtype=tf.float32) 132 | total_bits = tf.cast(tf.reduce_prod(b.shape), dtype=tf.float32) 133 | ber = bit_errors / total_bits 134 | 135 | return ber 136 | 137 | def test(): 138 | # f32 139 | x = tf.random.normal(shape=(64,)) 140 | b = tensor_to_binary_v2(x) 141 | y = binary_to_tensor_v2(b) 142 | tf.debugging.assert_near(x, y) 143 | 144 | # 145 | x = tf.random.uniform(shape=(64,), maxval=tf.int32.max, dtype=tf.int32) 146 | x = tf.cast(x, tf.uint32) 147 | b = tensor_to_binary_u32(x) 148 | y = binary_to_tensor_u32(b) 149 | tf.debugging.assert_equal(x, y) 150 | 151 | if __name__ == '__main__': 152 | test() -------------------------------------------------------------------------------- /models/vq_vae.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | from tensorflow import keras 5 | from tensorflow.keras import layers 6 | import tensorflow_probability as tfp 7 | import tensorflow as tf 8 | 9 | '''https://keras.io/examples/generative/vq_vae/''' 10 | 11 | class VectorQuantizer(layers.Layer): 12 | def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs): 13 | super().__init__(**kwargs) 14 | self.num_embeddings = num_embeddings # the number of codebooks. 
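        # (for example, with the settings in scripts/train.sh — embedding_dim=2,
        #  num_embeddings=1024 — each pair of float32 encoder-output values is
        #  replaced by a single codebook index before transmission, roughly
        #  halving the payload; the index itself is still serialized through the
        #  same 32-bit float converter in models/utils.py)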
15 |         self.embedding_dim = embedding_dim # the dimension of a codebook.
16 | 
17 |         # The `beta` parameter is best kept between [0.25, 2] as per the paper.
18 |         self.beta = beta
19 | 
20 |         # Initialize the embeddings which we will quantize.
21 |         w_init = tf.random_uniform_initializer()
22 |         self.embeddings = tf.Variable(
23 |             initial_value=w_init(
24 |                 shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
25 |             ),
26 |             trainable=True,
27 |             name="embeddings_vqvae",
28 |         )
29 | 
30 |     def call(self, x):
31 |         # Record the input shape so the quantized output can be reshaped back below.
32 |         input_shape = tf.shape(x)
33 | 
34 |         # Quantization.
35 |         encoding_indices = self.get_code_indices(x)
36 |         encoding_indices = tf.cast(encoding_indices, tf.int64)
37 |         encodings = tf.one_hot(encoding_indices, self.num_embeddings)
38 |         quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
39 | 
40 |         # Reshape the quantized values back to the original input shape.
41 |         quantized = tf.reshape(quantized, input_shape)
42 | 
43 |         # Calculate vector quantization loss and add that to the layer. You can learn more
44 |         # about adding losses to different layers here:
45 |         # https://keras.io/guides/making_new_layers_and_models_via_subclassing/. Check
46 |         # the original paper to get a handle on the formulation of the loss function.
47 |         commitment_loss = tf.reduce_mean((tf.stop_gradient(quantized) - x) ** 2)
48 |         codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
49 |         self.add_loss(self.beta * commitment_loss + codebook_loss)
50 | 
51 |         # Straight-through estimator.
52 |         quantized = x + tf.stop_gradient(quantized - x)
53 |         return quantized
54 | 
55 |     def get_code_indices(self, input):
56 |         # get codebook indices for each semantic encoder output
57 | 
58 |         # flatten the inputs keeping `embedding_dim` intact.
59 |         num_of_input_elems = np.prod(input.shape)
60 |         assert num_of_input_elems % self.embedding_dim == 0, f"Argument 'embedding_dim' should be a factor of the total number of input elements, {num_of_input_elems}."
61 |         flattened_inputs = tf.reshape(input, [-1, self.embedding_dim])
62 | 
63 |         # Calculate the squared L2 distance between the inputs and the codes.
64 |         similarity = tf.matmul(flattened_inputs, self.embeddings)
65 |         distances = (
66 |             tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
67 |             + tf.reduce_sum(self.embeddings ** 2, axis=0)
68 |             - 2 * similarity
69 |         )
70 | 
71 |         # Derive the indices for minimum distances.
72 |         encoding_indices = tf.argmin(distances, axis=1)
73 |         encoding_indices = tf.cast(tf.convert_to_tensor(encoding_indices), tf.float32)
74 | 
75 |         return encoding_indices
76 | 
77 |     def reconstruct_with_indices(self, indices):
78 |         indices = tf.cast(indices, tf.int64)
79 | 
80 |         encodings = tf.one_hot(indices, self.num_embeddings)
81 |         quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
82 | 
83 |         return quantized
84 | 
85 |     def handle_invalid_values(self, output):
86 |         ### If invalid outputs exist, convert them into the closest valid outputs.
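        # (e.g. with num_embeddings=1024, noisy index values are clamped and
        #  rounded: -2.3 -> 0.0, 512.4 -> 512.0, 1030.6 -> 1023.0)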
87 | # change all negative values to 0 88 | output = tf.maximum(output, 0) 89 | # change bigger than num_embeddings-1 values to num_embeddings-1 (They are indices, so it should be -1) 90 | output = tf.minimum(output, self.num_embeddings-1) 91 | # round values 92 | output = tf.round(output) 93 | 94 | return output -------------------------------------------------------------------------------- /preprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abman23/on-device-ai-comm/c7b579ddfd88c5423ce9f9cb2e3e190ac89f8fdc/preprocess/__init__.py -------------------------------------------------------------------------------- /preprocess/europarl.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import json 3 | import argparse 4 | from .hf_data_gen import HFDataGenerator 5 | 6 | # The following code is copied (and slightly modified) from DeepSC 7 | # (https://github.com/zyy598/DeepSC/blob/master/preprocess_text.py) 8 | import unicodedata 9 | import re 10 | from w3lib.html import remove_tags 11 | import pickle 12 | import os 13 | import json 14 | from tqdm import tqdm 15 | 16 | def unicode_to_ascii(s): 17 | return ''.join(c for c in unicodedata.normalize('NFD', s) 18 | if unicodedata.category(c) != 'Mn') 19 | 20 | def normalize_string(s): 21 | # normalize unicode characters 22 | s = unicode_to_ascii(s) 23 | # remove the XML-tags 24 | s = remove_tags(s) 25 | # add white space before !.? 26 | s = re.sub(r'([!.?])', r' \1', s) 27 | s = re.sub(r'[^a-zA-Z.!?]+', r' ', s) 28 | s = re.sub(r'\s+', r' ', s) 29 | # change to lower letter 30 | s = s.lower() 31 | return s 32 | 33 | def cutted_data(cleaned, MIN_LENGTH=4, MAX_LENGTH=30): 34 | cutted_lines = list() 35 | for line in cleaned: 36 | length = len(line.split()) 37 | if length > MIN_LENGTH and length < MAX_LENGTH: 38 | line = [word for word in line.split()] 39 | cutted_lines.append(' '.join(line)) 40 | return cutted_lines 41 | 42 | def save_clean_sentences(sentence, save_path): 43 | pickle.dump(sentence, open(save_path, 'wb')) 44 | print('Saved: %s' % save_path) 45 | 46 | def process_text_file(file_path): 47 | with open(file_path, 'r') as f: 48 | raw_data = f.read() 49 | sentences = raw_data.strip().split('\n') 50 | raw_data_input = [normalize_string(data) for data in sentences] 51 | raw_data_input = cutted_data(raw_data_input) 52 | return raw_data_input 53 | 54 | def tokenize(s, delim=' ', add_start_token=True, add_end_token=True, 55 | punct_to_keep=None, punct_to_remove=None): 56 | """ 57 | Tokenize a sequence, converting a string s into a list of (string) tokens by 58 | splitting on the specified delimiter. Optionally keep or remove certain 59 | punctuation marks and add start and end tokens. 
60 |     """
61 |     if punct_to_keep is not None:
62 |         for p in punct_to_keep:
63 |             s = s.replace(p, '%s%s' % (delim, p))
64 | 
65 |     if punct_to_remove is not None:
66 |         for p in punct_to_remove:
67 |             s = s.replace(p, '')
68 | 
69 |     tokens = s.split(delim)
70 |     if add_start_token:
71 |         tokens.insert(0, '<START>')
72 |     if add_end_token:
73 |         tokens.append('<END>')
74 |     return tokens
75 | 
76 | def build_vocab(sequences, token_to_idx = { }, min_token_count=1, delim=' ',
77 |                 punct_to_keep=None, punct_to_remove=None, ):
78 |     token_to_count = {}
79 | 
80 |     for seq in sequences:
81 |         seq_tokens = tokenize(seq, delim=delim, punct_to_keep=punct_to_keep,
82 |                               punct_to_remove=punct_to_remove,
83 |                               add_start_token=False, add_end_token=False)
84 |         for token in seq_tokens:
85 |             if token not in token_to_count:
86 |                 token_to_count[token] = 0
87 |             token_to_count[token] += 1
88 | 
89 |     for token, count in sorted(token_to_count.items()):
90 |         if count >= min_token_count:
91 |             token_to_idx[token] = len(token_to_idx)
92 | 
93 |     return token_to_idx
94 | 
95 | def encode(seq_tokens, token_to_idx, allow_unk=False):
96 |     seq_idx = []
97 |     for token in seq_tokens:
98 |         if token not in token_to_idx:
99 |             if allow_unk:
100 |                 token = '<UNK>'
101 |             else:
102 |                 raise KeyError('Token "%s" not in vocab' % token)
103 |         seq_idx.append(token_to_idx[token])
104 |     return seq_idx
105 | 
106 | def decode(seq_idx, idx_to_token, delim=None, stop_at_end=True):
107 |     tokens = []
108 |     for idx in seq_idx:
109 |         tokens.append(idx_to_token[idx])
110 |         if stop_at_end and tokens[-1] == '<END>':
111 |             break
112 |     if delim is None:
113 |         return tokens
114 |     else:
115 |         return delim.join(tokens)
116 | 
117 | SPECIAL_TOKENS = {
118 |     '<PAD>': 0,
119 |     '<START>': 1,
120 |     '<END>': 2,
121 |     '<UNK>': 3,
122 | }
123 | 
124 | def process_europarl(input_data_dir, train_test_split=0.9, njobs=1):
125 |     sentences = []
126 |     print('Preprocess Raw Text')
127 |     from joblib import Parallel, delayed
128 |     sentences = Parallel(n_jobs=njobs, verbose=1)(
129 |         delayed(process_text_file)(fn)
130 |         for fn in pathlib.Path(input_data_dir).glob('*.txt'))
131 |     sentences = [s for s_list in sentences for s in s_list]
132 | 
133 |     # remove duplicate sentences
134 |     a = {}
135 |     for sent in sentences:
136 |         if sent not in a:
137 |             a[sent] = 0
138 |         a[sent] += 1
139 |     sentences = list(a.keys())
140 |     print('Number of sentences: {}'.format(len(sentences)))
141 | 
142 |     print('Build Vocab')
143 |     token_to_idx = build_vocab(
144 |         sentences, SPECIAL_TOKENS,
145 |         punct_to_keep=[';', ','], punct_to_remove=['?', '.']
146 |     )
147 | 
148 |     vocab = {'token_to_idx': token_to_idx}
149 |     print('Number of words in Vocab: {}'.format(len(token_to_idx)))
150 | 
151 |     print('Start encoding txt')
152 |     results = []
153 |     for seq in tqdm(sentences):
154 |         words = tokenize(seq, punct_to_keep=[';', ','], punct_to_remove=['?', '.'])
155 |         tokens = [token_to_idx[word] for word in words]
156 |         results.append(tokens)
157 | 
158 |     train_data = results[: round(len(results) * train_test_split)]
159 |     test_data = results[round(len(results) * train_test_split):]
160 | 
161 |     return train_data, test_data, vocab
162 | # End of the copied code
163 | 
164 | class Tokenizer:
165 | 
166 |     TOKENS_FILTERED = set([
167 |         '<START>', '<END>'
168 |     ])
169 | 
170 |     def __init__(self, vocab):
171 |         idx_to_token = [None for _ in range(1 + max(vocab['token_to_idx'].values()))]
172 |         for token, idx in vocab['token_to_idx'].items():
173 |             idx_to_token[idx] = token
174 |         self.idx_to_token = idx_to_token
175 |         self.token_to_idx = vocab['token_to_idx']
176 | 
177 |     def decode(self, token_ids):
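        # map ids back to tokens and drop the <START>/<END> markers, e.g.
        #   decode([1, 17, 42, 2]) -> "hello world"
        #   (assuming idx_to_token maps 17 -> "hello" and 42 -> "world")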
178 |         tokens = map(lambda i: self.idx_to_token[i], token_ids)
179 |         tokens = filter(lambda t: t not in self.TOKENS_FILTERED, tokens)
180 |         return ' '.join(tokens)
181 | 
182 |     def batch_decode(self, token_ids_list):
183 |         return list(map(lambda token_ids: self.decode(token_ids), token_ids_list))
184 | 
185 | def gen_hf_dataset(path: pathlib.Path, output_path=None, train_test_split=0.9, njobs=1):
186 |     path = pathlib.Path(path)
187 |     if output_path is None:
188 |         output_path = path
189 | 
190 |     train_data, test_data, vocab = process_europarl(path, train_test_split, njobs)
191 | 
192 |     # save processed sentences
193 |     with open(output_path / 'train_data.pkl', 'wb') as f:
194 |         pickle.dump(train_data, f)
195 |     with open(output_path / 'test_data.pkl', 'wb') as f:
196 |         pickle.dump(test_data, f)
197 |     with open(output_path / 'vocab.json', 'w') as f:
198 |         json.dump(vocab, f)
199 | 
200 |     tokenizer = Tokenizer(vocab)
201 | 
202 |     # train set
203 |     # import random
204 |     # num_train = 3000
205 |     # num_test = 1000
206 |     train_data = tokenizer.batch_decode(train_data)
207 |     train_gen = HFDataGenerator()
208 |     # train_data = random.sample(train_data, num_train)
209 |     train_gen.add(train_data, train_data)
210 |     train_gen.dump(output_path / 'train.csv')
211 | 
212 |     # train json data
213 |     with open(output_path / 'train.json', 'w') as f:
214 |         json_data = [{ 'input': s, 'refs': [s] } for s in train_data]
215 |         json.dump(json_data, f, indent=4)
216 | 
217 |     # test set
218 |     test_data = tokenizer.batch_decode(test_data)
219 |     test_gen = HFDataGenerator()
220 |     # test_data = random.sample(test_data, num_test)
221 |     test_gen.add(test_data, test_data)
222 |     test_gen.dump(output_path / 'test.csv')
223 | 
224 |     # test json data
225 |     with open(output_path / 'test.json', 'w') as f:
226 |         json_data = [{ 'input': s, 'refs': [s] } for s in test_data]
227 |         json.dump(json_data, f, indent=4)
228 | 
229 | 
230 | if __name__ == '__main__':
231 |     parser = argparse.ArgumentParser("Preprocess Europarl data. It generates the same dataset used in DeepSC")
232 |     parser.add_argument(
233 |         '-o', '--out-path',
234 |         dest='out_path',
235 |         required=True,
236 |         type=pathlib.Path,
237 |         help='Required. Path of output files.')
238 |     parser.add_argument(
239 |         '--train-test-split',
240 |         dest='train_test_split',
241 |         default=0.9,
242 |         type=float,
243 |         help='Trainset/testset split ratio')
244 |     parser.add_argument(
245 |         '-j',
246 |         '--njobs',
247 |         dest='njobs',
248 |         default=1,
249 |         type=int,
250 |         help='Number of parallel jobs to be used for preprocessing')
251 |     parser.add_argument(
252 |         dest='path',
253 |         type=pathlib.Path,
254 |         help="Path of europarl dataset. It should be '/txt/en'")
255 |     args = parser.parse_args()
256 |     gen_hf_dataset(args.path, args.out_path, args.train_test_split, args.njobs)
257 | 
258 | 
--------------------------------------------------------------------------------
/preprocess/flickr30k.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import json
3 | import random
4 | import argparse
5 | import pathlib
6 | 
7 | def parse_line(line: str):
8 |     key = line.split('#')[0]
9 |     caption = line.split('\t')[-1].strip()
10 |     return key, caption
11 | 
12 | if __name__ == '__main__':
13 |     parser = argparse.ArgumentParser()
14 |     parser.add_argument('-o', '--out-path',
15 |                         dest='out_path',
16 |                         type=pathlib.Path,
17 |                         help='output json file path')
18 |     parser.add_argument('-n', '--num-samples',
19 |                         dest='N',
20 |                         default=1000,
21 |                         type=int,
22 |                         help='number of samples (default: 1000)')
23 |     parser.add_argument('--seed',
24 |                         dest='seed',
25 |                         default=20221017,
26 |                         type=int,
27 |                         help='seed for random module')
28 | 
29 |     parser.add_argument(
30 |         dest='token_path',
31 |         help="path of 'results_20130124.token'")
32 |     args = parser.parse_args()
33 | 
34 |     # read token file
35 |     data = defaultdict(list)
36 |     with open(args.token_path) as f:
37 |         for k, caption in map(parse_line, f):
38 |             data[k].append(caption)
39 |     data = list(data.values())
40 | 
41 |     # set seed
42 |     random.seed(args.seed)
43 | 
44 |     # sample dataset
45 |     samples = random.sample(range(len(data)), k=args.N)
46 |     out_data = []
47 |     for i in samples:
48 |         captions = data[i]
49 |         input_idx = random.sample(range(len(captions)), k=1)[0]
50 |         input_sentence = captions[input_idx]
51 |         ref_sentences = captions[:input_idx] + captions[(input_idx+1):]
52 |         out_data.append({
53 |             'input': input_sentence,
54 |             'refs': ref_sentences,
55 |         })
56 | 
57 |     with open(args.out_path, 'w') as f:
58 |         json.dump(out_data, f, indent=4)
--------------------------------------------------------------------------------
/preprocess/hf_data_gen.py:
--------------------------------------------------------------------------------
1 | # example file format
2 | # --------------------
3 | # text,summary
4 | # "I'm sitting here in a boring room. It's just another rainy Sunday afternoon. I'm wasting my time I got nothing to do. I'm hanging around I'm waiting for you. But nothing ever happens. And I wonder","I'm sitting in a room where I'm waiting for something to happen"
5 | # "I see trees so green, red roses too. I see them bloom for me and you. And I think to myself what a wonderful world. I see skies so blue and clouds so white. The bright blessed day, the dark sacred night. And I think to myself what a wonderful world.","I'm a gardener and I'm a big fan of flowers."
6 | # "Christmas time is here. Happiness and cheer. Fun for all that children call. Their favorite time of the year. Snowflakes in the air. Carols everywhere. Olden times and ancient rhymes. Of love and dreams to share","It's that time of year again."
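# Note: in this repo the "summary" column simply repeats the input sentence
# (see preprocess/europarl.py, which calls `gen.add(train_data, train_data)`),
# since the model is trained to reconstruct the transmitted text rather than
# to summarize it.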
7 | import csv 8 | 9 | class HFDataGenerator: 10 | 11 | FIELDNAMES = ['text', 'summary'] 12 | 13 | def __init__(self): 14 | self.rows = [] 15 | 16 | def add(self, text, summary): 17 | if isinstance(text, str) and isinstance(summary, str): 18 | self.rows.append({ 19 | 'text': f'{text}', 20 | 'summary': f'{summary}', 21 | }) 22 | elif isinstance(text, list) and isinstance(summary, list): 23 | assert len(text) == len(summary), "Different text and summary list" 24 | rows = map(lambda ts: {'text': ts[0], 'summary': ts[1]}, zip(text, summary)) 25 | self.rows.extend(rows) 26 | else: 27 | assert False, "No such case" 28 | 29 | 30 | def dump(self, filename): 31 | with open(filename, 'w', newline='') as f: 32 | writer = csv.DictWriter(f, fieldnames=self.FIELDNAMES, quoting=csv.QUOTE_ALL) 33 | writer.writeheader() 34 | writer.writerows(self.rows) 35 | 36 | def test(): 37 | data_path = "/tmp/hf_data_test.csv" 38 | gen = HFDataGenerator() 39 | gen.add("text1", "summary1") 40 | gen.add("text2", "summary2") 41 | gen.dump(data_path) 42 | 43 | expected = """"text","summary" 44 | "text1","summary1" 45 | "text2","summary2" 46 | """ 47 | with open( data_path) as f: 48 | actual = f.read() 49 | assert actual == expected, f"""'{actual}'""" 50 | 51 | if __name__ == '__main__': 52 | test() 53 | 54 | -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | eval_ebno_db="4" 2 | metric="sbert" # bleu, sbert 3 | testset_path='data/flickr/processed/flickr30k.json' 4 | 5 | checkpoint_dir='checkpoints/on-device-ai-comm/train_CDL-A_ebnodb_5_15' 6 | output_dir='checkpoints/on-device-ai-comm/train_CDL-A_ebnodb_5_15/CDL-A' 7 | 8 | mkdir -p $output_dir 9 | 10 | fec_type="Polar5G" # Polar5G, LDPC5G 11 | fec_num_iter=20 12 | channel_num_tx_ant="2" 13 | channel_num_rx_ant="2" 14 | num_bits_per_symbol="4" 15 | EVAL_NUM_BEAMS="1" 16 | 17 | python eval.py \ 18 | -m "${metric}" \ 19 | -b 8 \ 20 | -e "${eval_ebno_db}" \ 21 | --result-json-path "${output_dir}/flickr_${metric}_${eval_ebno_db}dB_${fec_type}_${channel_num_tx_ant}_${channel_num_rx_ant}_${num_bits_per_symbol}.json" \ 22 | --prediction-json-path "${output_dir}/flickr_prediction_${eval_ebno_db}dB_${fec_type}_${channel_num_tx_ant}_${channel_num_rx_ant}_${num_bits_per_symbol}.json" \ 23 | --fec-type "${fec_type}" \ 24 | --fec-num-iter "${fec_num_iter}" \ 25 | --channel-type "CDL" \ 26 | --cdl-model "A" \ 27 | --channel-num-tx-ant "${channel_num_tx_ant}" \ 28 | --channel-num-rx-ant "${channel_num_rx_ant}" \ 29 | --num-bits-per-symbol "${num_bits_per_symbol}" \ 30 | --bin-conv-method "vector_quantization" \ 31 | --embedding-dim 2 \ 32 | --num-embeddings 1024 \ 33 | --num-beams "${EVAL_NUM_BEAMS}" \ 34 | --testset-path "${testset_path}" \ 35 | $checkpoint_dir -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | output_dir='checkpoints/on-device-ai-comm/train_CDL-A_ebnodb_5_15' 2 | trainset_path='data/europarl/processed/train.csv' 3 | devset_path='data/europarl/processed/test.csv' 4 | 5 | mkdir -p $output_dir 6 | 7 | python train.py \ 8 | --model_name_or_path facebook/bart-base \ 9 | --config_name facebook/bart-base \ 10 | --tokenizer_name facebook/bart-base \ 11 | --train_file "$trainset_path" \ 12 | --validation_file "$devset_path" \ 13 | --test_file "$devset_path" \ 14 | --preprocessing_num_workers 4 \ 15 | 
--per_device_train_batch_size 4 \ 16 | --per_device_eval_batch_size 4 \ 17 | --num_train_epochs 3 \ 18 | --do_train \ 19 | --do_eval \ 20 | --save_total_limit 1 \ 21 | --no_use_fast_tokenizer \ 22 | --num_beams 1 \ 23 | --pad_to_max_length \ 24 | --overwrite_output_dir \ 25 | --max_source_length 64 \ 26 | --max_target_length 64 \ 27 | --output_dir $output_dir \ 28 | --ebno_db_min 5 \ 29 | --ebno_db_max 15 \ 30 | --channel_type "CDL" \ 31 | --fec_type "Polar5G" \ 32 | --fec_num_iter 20 \ 33 | --cdl_model "A" \ 34 | --channel_num_tx_ant "2" \ 35 | --channel_num_rx_ant "2" \ 36 | --num_bits_per_symbol "4" \ 37 | --bin_conv_method "vector_quantization" \ 38 | --embedding_dim 2 \ 39 | --num_embeddings 1024 \ 40 | --dataset_config 3.0.0 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # This file is based on 18 | # https://github.com/huggingface/transformers/blob/main/examples/tensorflow/summarization/run_summarization.py 19 | import json 20 | import logging 21 | import os 22 | import sys 23 | from dataclasses import dataclass, field 24 | from typing import Optional, List 25 | 26 | import datasets 27 | import nltk 28 | import numpy as np 29 | import tensorflow as tf 30 | from datasets import load_dataset 31 | import matplotlib.pyplot as plt 32 | 33 | import evaluate 34 | import transformers 35 | from filelock import FileLock 36 | from transformers import ( 37 | AutoConfig, 38 | AutoTokenizer, 39 | DataCollatorForSeq2Seq, 40 | HfArgumentParser, 41 | KerasMetricCallback, 42 | TFTrainingArguments, 43 | set_seed, 44 | ) 45 | from tensorflow.keras.callbacks import ModelCheckpoint 46 | from transformers.trainer_utils import get_last_checkpoint 47 | from transformers.utils import is_offline_mode 48 | from transformers.optimization_tf import create_optimizer 49 | from train.args import ModelArguments, DataTrainingArguments, summarization_name_mapping, Seq2SeqSCArguments 50 | from models.on_device_ai_comm import TFOnDeviceAICForConditionalGeneration 51 | 52 | logger = logging.getLogger(__name__) 53 | 54 | from datetime import datetime 55 | 56 | try: 57 | nltk.data.find("tokenizers/punkt") 58 | except (LookupError, OSError): 59 | if is_offline_mode(): 60 | raise LookupError( 61 | "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" 62 | ) 63 | with FileLock(".lock") as lock: 64 | nltk.download("punkt", quiet=True) 65 | 66 | def main(): 67 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments, Seq2SeqSCArguments)) 68 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 69 | # If we pass only one argument to the script and it's the path to a json file, 70 | # let's parse it to get our arguments. 
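        # e.g.  python train.py my_train_args.json
        # where my_train_args.json (an illustrative name) holds the same keys as
        # the CLI flags used in scripts/train.sh, roughly:
        #   {"model_name_or_path": "facebook/bart-base", "output_dir": "checkpoints/...",
        #    "channel_type": "CDL", "ebno_db_min": 5, "ebno_db_max": 15, ...}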
71 | model_args, data_args, training_args, seq2seq_sc_args = parser.parse_json_file( 72 | json_file=os.path.abspath(sys.argv[1])) 73 | else: 74 | model_args, data_args, training_args, seq2seq_sc_args = parser.parse_args_into_dataclasses() 75 | 76 | if training_args.fp16: 77 | policy = tf.keras.mixed_precision.Policy('mixed_float16') 78 | tf.keras.mixed_precision.set_global_policy(policy) 79 | 80 | current_time = datetime.now() 81 | current_time = current_time.strftime("%Y-%m-%d_%H-%M-%S") 82 | sys.stdout = open(f'./{training_args.output_dir}/train_stdout-{current_time}.log','a') 83 | sys.stderr = open(f'./{training_args.output_dir}/train_stderr-{current_time}.log','a') 84 | 85 | # region Logging 86 | logging.basicConfig( 87 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 88 | datefmt="%m/%d/%Y %H:%M:%S", 89 | handlers=[logging.StreamHandler(sys.stdout)], 90 | ) 91 | logger.setLevel(logging.INFO) 92 | datasets.utils.logging.set_verbosity(logging.INFO) 93 | transformers.utils.logging.set_verbosity(logging.INFO) 94 | 95 | # Log on each process the small summary: 96 | logger.info(f"Training/evaluation parameters {training_args}") 97 | # endregion 98 | 99 | # region Detecting last checkpoint 100 | last_checkpoint = None 101 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 102 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 103 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 104 | raise ValueError( 105 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 106 | "Use --overwrite_output_dir to overcome." 107 | ) 108 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 109 | logger.info( 110 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 111 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 112 | ) 113 | # endregion 114 | 115 | # Set seed before initializing model. 116 | set_seed(training_args.seed) 117 | 118 | # region Load datasets 119 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 120 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 121 | # (the dataset will be downloaded automatically from the datasets Hub). 122 | # 123 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the 124 | # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments). 125 | # 126 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 127 | # download the dataset. 128 | if data_args.dataset_name is not None: 129 | # Downloading and loading a dataset from the hub. 
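        # (e.g. `--dataset_name cnn_dailymail --dataset_config_name 3.0.0`; the
        #  scripts in this repo pass local CSV files instead, which takes the
        #  `else` branch below)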
130 | raw_datasets = load_dataset( 131 | data_args.dataset_name, 132 | data_args.dataset_config_name, 133 | cache_dir=model_args.cache_dir, 134 | use_auth_token=True if model_args.use_auth_token else None, 135 | ) 136 | else: 137 | data_files = {} 138 | if data_args.train_file is not None: 139 | data_files["train"] = data_args.train_file 140 | extension = data_args.train_file.split(".")[-1] 141 | if data_args.validation_file is not None: 142 | data_files["validation"] = data_args.validation_file 143 | extension = data_args.validation_file.split(".")[-1] 144 | if data_args.test_file is not None: 145 | data_files["test"] = data_args.test_file 146 | extension = data_args.test_file.split(".")[-1] 147 | raw_datasets = load_dataset( 148 | extension, 149 | data_files=data_files, 150 | cache_dir=model_args.cache_dir, 151 | use_auth_token=True if model_args.use_auth_token else None, 152 | ) 153 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 154 | # https://huggingface.co/docs/datasets/loading_datasets.html. 155 | # endregion 156 | 157 | # region Load model config and tokenizer 158 | # 159 | # Distributed training: 160 | # The .from_pretrained methods guarantee that only one local process can concurrently 161 | # download model & vocab. 162 | 163 | config = AutoConfig.from_pretrained( 164 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 165 | cache_dir=model_args.cache_dir, 166 | revision=model_args.model_revision, 167 | use_auth_token=True if model_args.use_auth_token else None, 168 | ) 169 | tokenizer = AutoTokenizer.from_pretrained( 170 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 171 | cache_dir=model_args.cache_dir, 172 | use_fast=model_args.use_fast_tokenizer, 173 | revision=model_args.model_revision, 174 | use_auth_token=True if model_args.use_auth_token else None, 175 | ) 176 | 177 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 178 | # endregion 179 | 180 | # region Dataset preprocessing 181 | # We need to tokenize inputs and targets. 182 | if training_args.do_train: 183 | column_names = raw_datasets["train"].column_names 184 | elif training_args.do_eval: 185 | column_names = raw_datasets["validation"].column_names 186 | else: 187 | logger.info("There is nothing to do. Please pass `do_train`, and/or `do_eval`.") 188 | return 189 | 190 | # Get the column names for input/target. 191 | dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None) 192 | if data_args.text_column is None: 193 | text_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 194 | else: 195 | text_column = data_args.text_column 196 | if text_column not in column_names: 197 | raise ValueError( 198 | f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}" 199 | ) 200 | if data_args.summary_column is None: 201 | summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 202 | else: 203 | summary_column = data_args.summary_column 204 | if summary_column not in column_names: 205 | raise ValueError( 206 | f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}" 207 | ) 208 | 209 | # Temporarily set max_target_length for training. 
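    # (it is switched to `data_args.val_max_target_length` further below when
    #  the validation set is preprocessed)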
210 | max_target_length = data_args.max_target_length 211 | padding = "max_length" if data_args.pad_to_max_length else False 212 | 213 | def preprocess_function(examples): 214 | inputs = examples[text_column] 215 | assert prefix is not None 216 | for i, inp in enumerate(inputs): 217 | if inp is None: 218 | print(i, inputs[i], inputs[i-1], inputs[i+1]) 219 | targets = examples[summary_column] 220 | inputs = [prefix + inp for inp in inputs] 221 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) 222 | 223 | # Tokenize targets with the `text_target` keyword argument 224 | labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True) 225 | 226 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore 227 | # padding in the loss. 228 | if padding == "max_length" and data_args.ignore_pad_token_for_loss: 229 | labels["input_ids"] = [ 230 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 231 | ] 232 | 233 | model_inputs["labels"] = labels["input_ids"] 234 | return model_inputs 235 | 236 | if training_args.do_train: 237 | if "train" not in raw_datasets: 238 | raise ValueError("--do_train requires a train dataset") 239 | train_dataset = raw_datasets["train"] 240 | if data_args.max_train_samples is not None: 241 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 242 | train_dataset = train_dataset.select(range(max_train_samples)) 243 | with training_args.main_process_first(desc="train dataset map pre-processing"): 244 | train_dataset = train_dataset.map( 245 | preprocess_function, 246 | batched=True, 247 | num_proc=data_args.preprocessing_num_workers, 248 | remove_columns=column_names, 249 | load_from_cache_file=not data_args.overwrite_cache, 250 | desc="Running tokenizer on train dataset", 251 | ) 252 | else: 253 | train_dataset = None 254 | 255 | if training_args.do_eval: 256 | max_target_length = data_args.val_max_target_length 257 | if "validation" not in raw_datasets: 258 | raise ValueError("--do_eval requires a validation dataset") 259 | eval_dataset = raw_datasets["validation"] 260 | if data_args.max_eval_samples is not None: 261 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 262 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 263 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 264 | eval_dataset = eval_dataset.map( 265 | preprocess_function, 266 | batched=True, 267 | num_proc=data_args.preprocessing_num_workers, 268 | remove_columns=column_names, 269 | load_from_cache_file=not data_args.overwrite_cache, 270 | desc="Running tokenizer on validation dataset", 271 | ) 272 | else: 273 | eval_dataset = None 274 | # endregion 275 | 276 | # region Text preprocessing 277 | def postprocess_text(preds, labels): 278 | preds = [pred.strip() for pred in preds] 279 | labels = [label.strip() for label in labels] 280 | 281 | # rougeLSum expects newline after each sentence 282 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] 283 | labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] 284 | 285 | return preds, labels 286 | 287 | 288 | # endregion 289 | with training_args.strategy.scope(): 290 | # region Prepare model 291 | if model_args.model_name_or_path == '' or model_args.model_name_or_path is None: 292 | # Create a BART model with random initialization 293 | model = 
TFOnDeviceAICForConditionalGeneration( 294 | ebno_db=seq2seq_sc_args.ebno_db, 295 | ebno_db_min=seq2seq_sc_args.ebno_db_min, 296 | ebno_db_max=seq2seq_sc_args.ebno_db_max, 297 | fec_k=seq2seq_sc_args.k, 298 | fec_n=seq2seq_sc_args.n, 299 | num_bits_per_symbol=seq2seq_sc_args.num_bits_per_symbol, 300 | channel_type=seq2seq_sc_args.channel_type, 301 | cdl_model=seq2seq_sc_args.cdl_model, 302 | channel_num_tx_ant=seq2seq_sc_args.channel_num_tx_ant, 303 | channel_num_rx_ant=seq2seq_sc_args.channel_num_rx_ant, 304 | fec_type=seq2seq_sc_args.fec_type, 305 | bin_conv_method=seq2seq_sc_args.bin_conv_method, 306 | embedding_dim=seq2seq_sc_args.embedding_dim, 307 | num_embeddings=seq2seq_sc_args.num_embeddings, 308 | do_train=training_args.do_train, 309 | config=config, 310 | ) 311 | print(f'Random Initialized without pre-trained weights.') 312 | else: 313 | model_cls = TFOnDeviceAICForConditionalGeneration 314 | # https://huggingface.co/docs/transformers/v4.34.0/en/main_classes/model#transformers.TFPreTrainedModel.from_pretrained 315 | model = model_cls.from_pretrained( 316 | model_args.model_name_or_path, 317 | ebno_db=seq2seq_sc_args.ebno_db, 318 | ebno_db_min=seq2seq_sc_args.ebno_db_min, 319 | ebno_db_max=seq2seq_sc_args.ebno_db_max, 320 | fec_k=seq2seq_sc_args.k, 321 | fec_n=seq2seq_sc_args.n, 322 | num_bits_per_symbol=seq2seq_sc_args.num_bits_per_symbol, 323 | channel_type=seq2seq_sc_args.channel_type, 324 | cdl_model=seq2seq_sc_args.cdl_model, 325 | channel_num_tx_ant=seq2seq_sc_args.channel_num_tx_ant, 326 | channel_num_rx_ant=seq2seq_sc_args.channel_num_rx_ant, 327 | fec_type=seq2seq_sc_args.fec_type, 328 | bin_conv_method=seq2seq_sc_args.bin_conv_method, 329 | embedding_dim=seq2seq_sc_args.embedding_dim, 330 | num_embeddings=seq2seq_sc_args.num_embeddings, 331 | do_train=training_args.do_train, 332 | config=config, 333 | cache_dir=model_args.cache_dir, 334 | revision=model_args.model_revision, 335 | use_auth_token=True if model_args.use_auth_token else None, 336 | ) 337 | print(f'Loaded {model_args.model_name_or_path}') 338 | 339 | model.resize_token_embeddings(len(tokenizer)) 340 | 341 | # endregion 342 | 343 | # region Prepare TF Dataset objects 344 | if model.config.decoder_start_token_id is None: 345 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") 346 | 347 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 348 | data_collator = DataCollatorForSeq2Seq( 349 | tokenizer, 350 | model=model, 351 | label_pad_token_id=label_pad_token_id, 352 | pad_to_multiple_of=128, # Reduce the number of unique shapes for XLA, especially for generation 353 | return_tensors="tf", 354 | ) 355 | 356 | dataset_options = tf.data.Options() 357 | dataset_options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF 358 | 359 | num_replicas = training_args.strategy.num_replicas_in_sync 360 | total_train_batch_size = training_args.per_device_train_batch_size * num_replicas 361 | total_eval_batch_size = training_args.per_device_eval_batch_size * num_replicas 362 | 363 | # model.prepare_tf_dataset() wraps a Hugging Face dataset in a tf.data.Dataset which is ready to use in 364 | # training. This is the recommended way to use a Hugging Face dataset when training with Keras. 
365 |         # use the lower-level dataset.to_tf_dataset() method, but you will have to specify things like column names
366 |         # yourself if you use this method, whereas they are automatically inferred from the model input names when
367 |         # using model.prepare_tf_dataset()
368 |         # For more info see the docs:
369 |         # https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.TFPreTrainedModel.prepare_tf_dataset
370 |         # https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.to_tf_dataset
371 | 
372 |         tf_train_dataset = model.prepare_tf_dataset(
373 |             train_dataset,
374 |             collate_fn=data_collator,
375 |             batch_size=total_train_batch_size,
376 |             shuffle=True,
377 |         ).with_options(dataset_options)
378 |         tf_eval_dataset = model.prepare_tf_dataset(
379 |             eval_dataset,
380 |             collate_fn=data_collator,
381 |             batch_size=total_eval_batch_size,
382 |             shuffle=False,
383 |         ).with_options(dataset_options)
384 |         # endregion
385 | 
386 |         # region Optimizer, loss and LR scheduling
387 |         num_train_steps = int(len(tf_train_dataset) * training_args.num_train_epochs)
388 |         if training_args.warmup_steps > 0:
389 |             num_warmup_steps = training_args.warmup_steps
390 |         elif training_args.warmup_ratio > 0:
391 |             num_warmup_steps = int(num_train_steps * training_args.warmup_ratio)
392 |         else:
393 |             num_warmup_steps = 0
394 |         if training_args.do_train:
395 |             optimizer, lr_schedule = create_optimizer(
396 |                 init_lr=training_args.learning_rate,
397 |                 num_train_steps=num_train_steps,
398 |                 num_warmup_steps=num_warmup_steps,
399 |                 adam_beta1=training_args.adam_beta1,
400 |                 adam_beta2=training_args.adam_beta2,
401 |                 adam_epsilon=training_args.adam_epsilon,
402 |                 weight_decay_rate=training_args.weight_decay,
403 |                 adam_global_clipnorm=training_args.max_grad_norm,
404 |             )
405 |         else:
406 |             optimizer = None
407 | 
408 |         # endregion
409 | 
410 |         # region Metric and KerasMetricCallback
411 |         if training_args.do_eval:
412 |             metric = evaluate.load("rouge")
413 | 
414 |             if data_args.val_max_target_length is None:
415 |                 data_args.val_max_target_length = data_args.max_target_length
416 | 
417 |             gen_kwargs = {
418 |                 "max_length": data_args.val_max_target_length,
419 |                 "num_beams": data_args.num_beams,
420 |                 "no_repeat_ngram_size": 0,  # Not supported under XLA right now, and some models set it by default
421 |             }
422 | 
423 |             def compute_metrics(preds):
424 |                 predictions, labels = preds
425 |                 if isinstance(predictions, tuple):
426 |                     predictions = predictions[0]
427 |                 decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
428 |                 labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
429 |                 decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
430 |                 decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
431 |                 metrics = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
432 |                 # Scale the ROUGE F-measures from fractions to percentages for readability
433 |                 metrics = {key: round(val * 100, 4) for key, val in metrics.items()}
434 |                 return metrics
435 | 
436 |             # The KerasMetricCallback allows metrics that are too complex to write as standard Keras metrics
437 |             # to be computed each epoch. Any Python code can be included in the metric_fn. This is especially
438 |             # useful for metrics like BLEU and ROUGE that perform string comparisons on decoded model outputs.
439 |             # For more information, see the docs at
440 |             # https://huggingface.co/docs/transformers/main_classes/keras_callbacks#transformers.KerasMetricCallback
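# (Illustrative, values taken from the sample 3-epoch run commented further
# below) compute_metrics returns percentages, e.g.
#   {"rouge1": 66.4159, "rouge2": 60.5605, "rougeL": 65.7636, "rougeLsum": 65.7636}
# which KerasMetricCallback merges into history.history at the end of each epoch.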
441 | 
442 |             metric_callback = KerasMetricCallback(
443 |                 metric_fn=compute_metrics,
444 |                 eval_dataset=tf_eval_dataset,
445 |                 predict_with_generate=True,
446 |                 use_xla_generation=True,
447 |                 generate_kwargs=gen_kwargs,
448 |             )
449 |             callbacks = [metric_callback]
450 | 
451 |             # Define a callback to save the model
452 |             checkpoint_callback = ModelCheckpoint(
453 |                 filepath=training_args.output_dir + '/model_epoch_{epoch:02d}_loss_{loss:.4f}.h5',
454 |                 monitor='loss',  # Monitor the training loss reported by Keras
455 |                 save_best_only=True,  # Save only if the monitored quantity improves
56 |                 save_weights_only=True,  # Only save the model weights
457 |                 verbose=1,  # Print a message whenever a checkpoint is written
458 |             )
459 | 
460 |             # Add the checkpoint callback to the list of callbacks
461 |             callbacks.append(checkpoint_callback)
462 |         else:
463 |             callbacks = []
464 |         # endregion
465 | 
466 |         # region Training
467 |         model.compile(optimizer=optimizer, jit_compile=training_args.xla)
468 |         eval_metrics = history = None  # history stays None unless training actually runs
469 |         if training_args.do_train:
470 |             logger.info("***** Running training *****")
471 |             logger.info(f"  Num examples = {len(train_dataset)}")
472 |             logger.info(f"  Num Epochs = {training_args.num_train_epochs}")
473 |             logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
474 |             logger.info(f"  Total train batch size = {total_train_batch_size}")
475 |             logger.info(f"  Total optimization steps = {num_train_steps}")
476 | 
477 |             if training_args.xla and not data_args.pad_to_max_length:
478 |                 logger.warning(
479 |                     "XLA training may be slow at first when --pad_to_max_length is not set "
480 |                     "until all possible shapes have been compiled."
481 |                 )
482 |             history = model.fit(tf_train_dataset, epochs=int(training_args.num_train_epochs), callbacks=callbacks)
483 |             # e.g., after 3 epochs: history.history={'loss': [3.3186378479003906, 1.8291417360305786, 1.5430701971054077], 'rouge1': [49.4591, 57.3391, 66.4159], 'rouge2': [36.3391, 50.6601, 60.5605], 'rougeL': [49.1986, 58.1554, 65.7636], 'rougeLsum': [49.112, 58.124, 65.7636]}
484 |             eval_metrics = {key: val[-1] for key, val in history.history.items()}
485 |         # endregion
486 | 
487 |         # region Validation
488 | 
489 |         if training_args.do_eval and not training_args.do_train:
490 |             # Do a standalone evaluation run
491 |             logger.info("Evaluation...")
492 | 
493 |             # Compiling generation with XLA yields enormous speedups, see https://huggingface.co/blog/tf-xla-generate
494 |             @tf.function(jit_compile=True)
495 |             def generate(**kwargs):
496 |                 return model.generate(**kwargs)
497 | 
498 |             for batch, labels in tf_eval_dataset:
499 |                 batch.update(gen_kwargs)
500 |                 generated_tokens = generate(**batch)
501 |                 if isinstance(generated_tokens, tuple):
502 |                     generated_tokens = generated_tokens[0]
503 |                 decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
504 |                 labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
505 |                 decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
506 |                 decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
507 | 
508 |                 metric.add_batch(predictions=decoded_preds, references=decoded_labels)
509 | 
510 |             eval_metrics = metric.compute(use_stemmer=True)
511 | 
512 |             result = {key: round(val * 100, 4) for key, val in eval_metrics.items()}
513 |             logger.info(result)
514 |         # endregion
515 | 
516 |         if training_args.output_dir is not None and eval_metrics is not None:
517 |             output_eval_file = os.path.join(training_args.output_dir, "all_results.json")
518 |             with open(output_eval_file, "w") as writer:
519 |                 writer.write(json.dumps(eval_metrics))
520 | 
521 |         if history is not None and 'rouge1' in history.history:
522 |             ### --- Draw plot --- ###
523 |             loss = history.history['loss']
524 |             rouge1 = history.history['rouge1']
525 |             rouge2 = history.history['rouge2']
526 |             rougeL = history.history['rougeL']
527 |             rougeLsum = history.history['rougeLsum']
528 | 
529 |             # Create an array of epoch numbers
530 |             epochs = np.arange(1, len(loss) + 1)
531 | 
532 |             # Plotting the training loss
533 |             fig, ax1 = plt.subplots(figsize=(10, 6))
534 |             color = 'tab:blue'
535 |             ax1.set_xlabel('Epochs')
536 |             ax1.set_ylabel('Loss', color=color)
537 |             ax1.plot(epochs, loss, label='Loss', color=color)
538 |             ax1.tick_params(axis='y', labelcolor=color)
539 | 
540 |             # Set y-axis range for the loss curve
541 |             ax1.set_ylim([0, max(loss) + 0.2])
542 | 
543 |             # Creating a secondary y-axis for ROUGE scores
544 |             ax2 = ax1.twinx()
545 | 
546 |             color = 'tab:green'
547 |             ax2.set_ylabel('ROUGE Scores', color=color)
548 |             ax2.plot(epochs, rouge1, label='ROUGE-1', color=color)
549 |             ax2.plot(epochs, rouge2, label='ROUGE-2', color='tab:red')
550 |             ax2.plot(epochs, rougeL, label='ROUGE-L', color='tab:purple')
551 |             ax2.plot(epochs, rougeLsum, label='ROUGE-Lsum', color='tab:orange')
552 |             ax2.tick_params(axis='y', labelcolor=color)
553 |             # Set y-axis range for ROUGE scores
554 |             ax2.set_ylim([min(rouge1) - 10, 100])
555 | 
556 |             # Add legend and grid
557 |             fig.tight_layout()
558 |             fig.legend(loc='lower right')
559 |             plt.grid(True)
560 | 
561 |             # Save the plot
562 |             output_plot_file = os.path.join(training_args.output_dir, "validation_loss_rouge_scores.png")
563 |             plt.savefig(output_plot_file)
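# Restore sketch (the .h5 filename below is hypothetical, following the
# ModelCheckpoint filepath pattern above): the callback saves weights only
# (save_weights_only=True), so the model must first be rebuilt with the same
# constructor arguments, and then:
#   model.load_weights(training_args.output_dir + "/model_epoch_03_loss_1.5431.h5")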
564 | 
565 | 
566 |         if training_args.output_dir is not None and not training_args.push_to_hub:
567 |             # If we're not pushing to the Hub, at least save a local copy when we're done
568 |             model.save_pretrained(training_args.output_dir)
569 | 
570 | 
571 | if __name__ == "__main__":
572 |     main()
--------------------------------------------------------------------------------
/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abman23/on-device-ai-comm/c7b579ddfd88c5423ce9f9cb2e3e190ac89f8fdc/train/__init__.py
--------------------------------------------------------------------------------
/train/args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 | 
4 | @dataclass
5 | class Seq2SeqSCArguments:
6 |     ebno_db: Optional[float] = field(default=None, metadata={"help": "A single Eb/No value (SNR) in dB."})
7 |     ebno_db_min: Optional[float] = field(default=None, metadata={"help": "Lower bound of the Eb/No range in dB."})
8 |     ebno_db_max: Optional[float] = field(default=None, metadata={"help": "Upper bound of the Eb/No range in dB."})
9 |     k: Optional[int] = field(default=512, metadata={"help": "K: number of information bits per codeword for the channel code."})
10 |     n: Optional[int] = field(default=1024, metadata={"help": "N: codeword length for the channel code."})
11 |     fec_num_iter: Optional[int] = field(default=6, metadata={"help": "Number of decoding iterations for FEC."})
12 |     num_bits_per_symbol: Optional[int] = field(default=4, metadata={"help": "Number of bits per modulation symbol."})
13 |     channel_type: Optional[str] = field(default='AWGN', metadata={"help": "Channel model: AWGN, CDL, or 3GPP-38.901."})
14 |     cdl_model: Optional[str] = field(default='A', metadata={"help": "CDL channel profile: A, B, C, D, or E."})
15 |     channel_num_tx_ant: Optional[int] = field(default=1, metadata={"help": "Number of TX antennas for the FlatFadingChannel."})
16 |     channel_num_rx_ant: Optional[int] = field(default=1, metadata={"help": "Number of RX antennas for the FlatFadingChannel."})
17 |     fec_type: Optional[str] = field(default='Polar5G', metadata={"help": "Type of channel en/decoder: Polar5G or LDPC5G."})
18 |     bin_conv_method: Optional[str] = field(default='tanh', metadata={"help": "Tensor-to-binary conversion method: tanh or vector_quantization."})
19 |     embedding_dim: Optional[int] = field(default=2, metadata={"help": "Dimension of each codebook vector in the VQ layer."})
20 |     num_embeddings: Optional[int] = field(default=1024, metadata={"help": "Number of codebook vectors in the VQ layer."})
21 | 
22 | 
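# Note: each field above surfaces as a CLI flag of the same name, presumably via
# transformers.HfArgumentParser (see the wiring sketch after this listing).
# Illustrative invocation, with made-up values:
#   python train.py --ebno_db 10 --channel_type CDL --cdl_model A \
#       --fec_type Polar5G --bin_conv_method vector_quantization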
23 | @dataclass
24 | class ModelArguments:
25 |     model_name_or_path: Optional[str] = field(
26 |         default=None,
27 |         metadata={
28 |             "help": "Path to pretrained model or model identifier from huggingface.co/models. For random initialization, it should be None or ''."
29 |         }
30 |     )
31 |     config_name: Optional[str] = field(
32 |         default=None,
33 |         metadata={
34 |             "help": "Pretrained config name or path if not the same as model_name"
35 |         }
36 |     )
37 |     tokenizer_name: Optional[str] = field(
38 |         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
39 |     )
40 |     cache_dir: Optional[str] = field(
41 |         default=None,
42 |         metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
43 |     )
44 |     use_fast_tokenizer: bool = field(
45 |         default=True,
46 |         metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
47 |     )
48 |     model_revision: str = field(
49 |         default="main",
50 |         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
51 |     )
52 |     use_auth_token: bool = field(
53 |         default=False,
54 |         metadata={
55 |             "help": (
56 |                 "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
57 |                 "with private models)."
58 |             )
59 |         },
60 |     )
61 |     ignore_mismatched_sizes: Optional[bool] = field(
62 |         default=False,
63 |         metadata={
64 |             "help": "Whether to load checkpoint weights whose sizes do not match the model's (e.g., instantiating a model with 10 labels from a checkpoint with 3 labels) instead of raising an error."
65 |         }
66 |     )
67 | 
68 | 
69 | @dataclass
70 | class DataTrainingArguments:
71 |     """
72 |     Arguments pertaining to what data we are going to input our model for training and eval.
73 |     """
74 | 
75 |     dataset_name: Optional[str] = field(
76 |         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
77 |     )
78 |     dataset_config_name: Optional[str] = field(
79 |         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
80 |     )
81 |     text_column: Optional[str] = field(
82 |         default=None,
83 |         metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
84 |     )
85 |     summary_column: Optional[str] = field(
86 |         default=None,
87 |         metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
88 |     )
89 |     train_file: Optional[str] = field(
90 |         default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
91 |     )
92 |     validation_file: Optional[str] = field(
93 |         default=None,
94 |         metadata={
95 |             "help": (
96 |                 "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
97 |             )
98 |         },
99 |     )
100 |     test_file: Optional[str] = field(
101 |         default=None,
102 |         metadata={
103 |             "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
104 |         },
105 |     )
106 |     overwrite_cache: bool = field(
107 |         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
108 |     )
109 |     preprocessing_num_workers: Optional[int] = field(
110 |         default=None,
111 |         metadata={"help": "The number of processes to use for the preprocessing."},
112 |     )
113 |     max_source_length: Optional[int] = field(
114 |         default=1024,
115 |         metadata={
116 |             "help": (
117 |                 "The maximum total input sequence length after tokenization. Sequences longer "
118 |                 "than this will be truncated, sequences shorter will be padded."
119 |             )
120 |         },
121 |     )
122 |     max_target_length: Optional[int] = field(
123 |         default=128,
124 |         metadata={
125 |             "help": (
126 |                 "The maximum total sequence length for target text after tokenization. Sequences longer "
127 |                 "than this will be truncated, sequences shorter will be padded."
128 |             )
129 |         },
130 |     )
131 |     val_max_target_length: Optional[int] = field(
132 |         default=None,
133 |         metadata={
134 |             "help": (
135 |                 "The maximum total sequence length for validation target text after tokenization. Sequences longer "
136 |                 "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
137 |                 "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
138 |                 "during ``evaluate`` and ``predict``."
139 |             )
140 |         },
141 |     )
142 |     pad_to_max_length: bool = field(
143 |         default=False,
144 |         metadata={
145 |             "help": (
146 |                 "Whether to pad all samples to the model's maximum sentence length. "
147 |                 "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
148 |                 "efficient on GPU but very bad for TPU."
149 |             )
150 |         },
151 |     )
152 |     max_train_samples: Optional[int] = field(
153 |         default=None,
154 |         metadata={
155 |             "help": (
156 |                 "For debugging purposes or quicker training, truncate the number of training examples to this "
157 |                 "value if set."
158 |             )
159 |         },
160 |     )
161 |     max_eval_samples: Optional[int] = field(
162 |         default=None,
163 |         metadata={
164 |             "help": (
165 |                 "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
166 |                 "value if set."
167 |             )
168 |         },
169 |     )
170 |     max_predict_samples: Optional[int] = field(
171 |         default=None,
172 |         metadata={
173 |             "help": (
174 |                 "For debugging purposes or quicker training, truncate the number of prediction examples to this "
175 |                 "value if set."
176 |             )
177 |         },
178 |     )
179 |     num_beams: Optional[int] = field(
180 |         default=None,
181 |         metadata={
182 |             "help": (
183 |                 "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
184 |                 "which is used during ``evaluate`` and ``predict``."
185 |             )
186 |         },
187 |     )
188 |     ignore_pad_token_for_loss: bool = field(
189 |         default=True,
190 |         metadata={
191 |             "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
192 |         },
193 |     )
194 |     source_prefix: Optional[str] = field(
195 |         default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
196 |     )
197 | 
198 |     def __post_init__(self):
199 |         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
200 |             raise ValueError("Need either a dataset name or a training/validation file.")
201 |         else:
202 |             if self.train_file is not None:
203 |                 extension = self.train_file.split(".")[-1]
204 |                 assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
205 |             if self.validation_file is not None:
206 |                 extension = self.validation_file.split(".")[-1]
207 |                 assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
208 |         if self.val_max_target_length is None:
209 |             self.val_max_target_length = self.max_target_length
210 | 
211 | 
212 | # region Dataset name mappings
213 | summarization_name_mapping = {
214 |     "amazon_reviews_multi": ("review_body", "review_title"),
215 |     "big_patent": ("description", "abstract"),
216 |     "cnn_dailymail": ("article", "highlights"),
217 |     "orange_sum": ("text", "summary"),
218 |     "pn_summary": ("article", "summary"),
219 |     "psc": ("extract_text", "summary_text"),
220 |     "samsum": ("dialogue", "summary"),
221 |     "thaisum": ("body", "summary"),
222 |     "xglue": ("news_body", "news_title"),
223 |     "xsum": ("document", "summary"),
224 |     "wiki_summary": ("article", "highlights"),
225 |     "multi_news": ("document", "summary"),
226 | }
227 | # endregion
228 | 
229 | 
230 | 
--------------------------------------------------------------------------------
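How `train.py` consumes these dataclasses is not shown in this listing. The following is a minimal sketch assuming the standard `transformers.HfArgumentParser` pattern used by Hugging Face example scripts; the ordering of the dataclasses and the variable names are assumptions, not repository code:

```python
# Sketch under the stated assumptions; not part of the repository.
from transformers import HfArgumentParser, TFTrainingArguments

from train.args import DataTrainingArguments, ModelArguments, Seq2SeqSCArguments

# Every dataclass field becomes a --flag of the same name on the command line.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments, Seq2SeqSCArguments))
model_args, data_args, training_args, seq2seq_sc_args = parser.parse_args_into_dataclasses()
```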