├── outputs
│   └── NOTE.txt
├── rfcutils
│   ├── __init__.py
│   ├── sig_utils_fn.py
│   ├── rrc_helper_fn.py
│   ├── qpsk_helper_fn.py
│   ├── qam16_helper_fn.py
│   ├── qpsk2_helper_fn.py
│   └── ofdm_helper_fn.py
├── src
│   ├── configs
│   │   └── wavenet.yml
│   ├── torchdataset.py
│   ├── config_torchwavenet.py
│   ├── unet_model.py
│   ├── torchwavenet.py
│   └── learner_torchwavenet.py
├── supervised_config.yml
├── sampletest_generatetestmixtures.sh
├── dataset_utils
│   ├── example_preprocess_npy_dataset.py
│   ├── tfds_scripts
│   │   ├── Dataset_QPSK_CommSignal2_Mixture.py
│   │   ├── Dataset_QPSK_CommSignal3_Mixture.py
│   │   ├── Dataset_QPSK_EMISignal1_Mixture.py
│   │   ├── Dataset_QPSK_CommSignal5G1_Mixture.py
│   │   ├── Dataset_OFDMQPSK_EMISignal1_Mixture.py
│   │   ├── Dataset_OFDMQPSK_CommSignal2_Mixture.py
│   │   ├── Dataset_OFDMQPSK_CommSignal3_Mixture.py
│   │   └── Dataset_OFDMQPSK_CommSignal5G1_Mixture.py
│   ├── example_generate_competition_trainmixture.py
│   └── example_generate_rfc_mixtures.py
├── .gitignore
├── train_unet_model.py
├── sampletest_evaluationscript.py
├── train_torchwavenet.py
├── sampletest_tf_unet_inference.py
├── sampletrain_gendataset_script.sh
├── sampletest_torch_wavenet_inference.py
├── sampletest_testmixture_generator.py
├── rfsionna_env.yml
├── rftorch_env.yml
├── notebook
│   └── RFC_QuickStart_Guide.ipynb
└── README.md

/outputs/NOTE.txt:
--------------------------------------------------------------------------------
Inference outputs will be saved here.
--------------------------------------------------------------------------------
/rfcutils/__init__.py:
--------------------------------------------------------------------------------
from .rrc_helper_fn import *
from .qpsk_helper_fn import *
from .qpsk2_helper_fn import *
from .qam16_helper_fn import *
from .ofdm_helper_fn import *
from .sig_utils_fn import *
--------------------------------------------------------------------------------
/rfcutils/sig_utils_fn.py:
--------------------------------------------------------------------------------
import numpy as np

get_pow = lambda s: np.mean(np.abs(s)**2)

def get_sinr(s, i, units='dB'):
    sinr = get_pow(s)/get_pow(i)
    if units == 'dB':
        return 10*np.log10(sinr)
    return sinr
--------------------------------------------------------------------------------
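A minimal usage sketch for the power/SINR helpers above, assuming NumPy only (the unit-power SOI and the 0.1-scaled interference are illustrative):

import numpy as np
from rfcutils.sig_utils_fn import get_pow, get_sinr

rng = np.random.default_rng(0)
soi = (rng.standard_normal(4096) + 1j*rng.standard_normal(4096)) / np.sqrt(2)            # ~unit power
interference = 0.1 * (rng.standard_normal(4096) + 1j*rng.standard_normal(4096)) / np.sqrt(2)

print(get_pow(soi))                                  # ~1.0
print(get_sinr(soi, interference))                   # ~20 (dB, the default units)
print(get_sinr(soi, interference, units='linear'))   # ~100, the linear ratio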
/src/configs/wavenet.yml:
--------------------------------------------------------------------------------
model_dir: ${datetime:"checkpoints/wavenet"}
model:
  residual_layers: 30
  residual_channels: 128
  dilation_cycle_length: 10
data:
  root_dir: npydataset/Dataset_QPSK_SynOFDM_Mixture
  batch_size: 32
  num_workers: 2
  train_fraction: 0.90
distributed:
  distributed: True
  world_size: 2
trainer:
  fp16: True
  learning_rate: 5e-4
  max_steps: 500_000
  log_every: 50
  save_every: 2000
  validate_every: 2000
--------------------------------------------------------------------------------
/supervised_config.yml:
--------------------------------------------------------------------------------
model_dir: ${datetime:"checkpoints/supervised"}
model:
  residual_layers: 30
  residual_channels: 128
  dilation_cycle_length: 10
data:
  augmentation: True
  target_len: 2560
  batch_size: 256
  num_workers: 2
  train_fraction: 0.9
  # coeff: -10  # This is in dB
distributed:
  distributed: true
  world_size: 2
trainer:
  learning_rate: 5e-4
  max_steps: 500_000
  log_every: 50
  save_every: 2000
  validate_every: 1000
  infer_every: 2000
  num_infer_samples: 2
--------------------------------------------------------------------------------
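The ${datetime:...} interpolation in both YAML files is resolved by the custom resolver registered in src/config_torchwavenet.py, so these configs are meant to be loaded through OmegaConf. A loading sketch, assuming it is run from the repository root:

from omegaconf import OmegaConf
from src.config_torchwavenet import Config, parse_configs  # importing registers the "datetime" resolver

cfg = Config(**parse_configs(OmegaConf.load("src/configs/wavenet.yml")))
print(cfg.model_dir)          # e.g. checkpoints/wavenet_14_03_27 (suffix is the current HH_MM_SS)
print(cfg.trainer.max_steps)  # 500000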
/sampletest_generatetestmixtures.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Create TestSet1Example in the 'dataset' folder
python sampletest_testmixture_generator.py QPSK EMISignal1
python sampletest_testmixture_generator.py QPSK CommSignal2
python sampletest_testmixture_generator.py QPSK CommSignal3
python sampletest_testmixture_generator.py QPSK CommSignal5G1

python sampletest_testmixture_generator.py OFDMQPSK EMISignal1
python sampletest_testmixture_generator.py OFDMQPSK CommSignal2
python sampletest_testmixture_generator.py OFDMQPSK CommSignal3
python sampletest_testmixture_generator.py OFDMQPSK CommSignal5G1
--------------------------------------------------------------------------------
/rfcutils/rrc_helper_fn.py:
--------------------------------------------------------------------------------
import sionna as sn
import numpy as np
import tensorflow as tf

def get_psf(samples_per_symbol, span_in_symbols, beta):
    # samples_per_symbol: Number of samples per symbol, i.e., the oversampling factor
    # beta: Roll-off factor
    # span_in_symbols: Filter span in symbols
    rrcf = sn.signal.RootRaisedCosineFilter(span_in_symbols, samples_per_symbol, beta)
    return rrcf

def matched_filter(sig, samples_per_symbol, span_in_symbols, beta):
    rrcf = get_psf(samples_per_symbol, span_in_symbols, beta)
    x_mf = rrcf(sig, padding="same")
    return x_mf
--------------------------------------------------------------------------------
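A shape-check sketch for the filter helpers above (requires the same Sionna/TensorFlow stack as the rest of rfcutils; the random input is illustrative):

import tensorflow as tf
from rfcutils.rrc_helper_fn import matched_filter

sig = tf.complex(tf.random.normal([2, 640]), tf.random.normal([2, 640]))
x_mf = matched_filter(sig, samples_per_symbol=16, span_in_symbols=8, beta=0.5)
print(x_mf.shape)  # (2, 640): padding="same" preserves the input length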
/dataset_utils/example_preprocess_npy_dataset.py:
--------------------------------------------------------------------------------
import os, sys
import glob
import h5py
import numpy as np
from tqdm import tqdm

main_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
print(main_folder)

def preprocess_dataset(root_dir: str, save_dir: str) -> None:
    save_dir = os.path.join(save_dir, os.path.basename(root_dir))
    os.makedirs(save_dir, exist_ok=True)

    count = 0
    for folder in tqdm(glob.glob(os.path.join(root_dir, "*.h5"))):
        with h5py.File(folder, "r") as f:
            mixture = np.array(f.get("mixture"))
            soi = np.array(f.get("target"))
            for i in range(mixture.shape[0]):
                data = {
                    "sample_mix": mixture[i, ...],
                    "sample_soi": soi[i, ...],
                }
                np.save(os.path.join(save_dir, f"sample_{count}.npy"), data)
                count += 1


if __name__ == "__main__":
    dataset_type = sys.argv[1]
    preprocess_dataset(root_dir=f'{main_folder}/dataset/Dataset_{dataset_type}_Mixture',
                       save_dir=f'{main_folder}/npydataset/')
--------------------------------------------------------------------------------
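Each .npy file written above stores a Python dict, so it has to be read back with allow_pickle=True and unwrapped with .item() (this mirrors src/torchdataset.py). A sketch; the path assumes the QPSK_CommSignal2 dataset has already been preprocessed:

import numpy as np

data = np.load("npydataset/Dataset_QPSK_CommSignal2_Mixture/sample_0.npy", allow_pickle=True).item()
print(data["sample_mix"].shape, data["sample_soi"].shape)  # e.g. (40960, 2) each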
/src/torchdataset.py:
--------------------------------------------------------------------------------
import os
import glob
import h5py
import torch
import numpy as np

from tqdm import tqdm
from torch.utils.data import Dataset

class RFMixtureDatasetBase(Dataset):
    def __init__(self, root_dir: str):
        super().__init__()
        self.root_dir = root_dir
        if not os.path.exists(self.root_dir):
            raise FileNotFoundError("Dataset root directory does not exist.")
        self.files = glob.glob(os.path.join(self.root_dir, "*.npy"))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        data = np.load(self.files[i], allow_pickle=True).item()
        return {
            "sample_mix": torch.tensor(data["sample_mix"]).transpose(0, 1),
            "sample_soi": torch.tensor(data["sample_soi"]).transpose(0, 1),
        }


def get_train_val_dataset(dataset: Dataset, train_fraction: float):
    # print(len(dataset))
    val_examples = int((1 - train_fraction) * len(dataset))
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [len(dataset) - val_examples, val_examples],
        generator=torch.Generator().manual_seed(42))
    return train_dataset, val_dataset


if __name__ == "__main__":
    dataset = RFMixtureDatasetBase(
        root_dir="./npydataset/Dataset_QPSK_SynOFDM_Mixture",
    )
--------------------------------------------------------------------------------
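A usage sketch for the dataset class above, assuming the npy preprocessing step has been run from the repository root (batch size and worker count are illustrative):

from torch.utils.data import DataLoader
from src.torchdataset import RFMixtureDatasetBase, get_train_val_dataset

dataset = RFMixtureDatasetBase(root_dir="npydataset/Dataset_QPSK_CommSignal2_Mixture")
train_set, val_set = get_train_val_dataset(dataset, train_fraction=0.9)
loader = DataLoader(train_set, batch_size=16, shuffle=True, num_workers=2)
batch = next(iter(loader))
print(batch["sample_mix"].shape)  # (16, 2, sig_len) — channels-first after the transpose above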
SplitGenerators.""" 43 | path = os.path.join('dataset', f'Dataset_{soi_type}_{interference_sig_type}_Mixture') 44 | 45 | return { 46 | 'train': self._generate_examples(path), 47 | } 48 | 49 | def _generate_examples(self, path): 50 | """Yields examples.""" 51 | for f in glob.glob(os.path.join(path, '*.h5')): 52 | with h5py.File(f,'r') as h5file: 53 | mixture = np.array(h5file.get('mixture')) 54 | target = np.array(h5file.get('target')) 55 | sig_type = h5file.get('sig_type')[()] 56 | if isinstance(sig_type, bytes): 57 | sig_type = sig_type.decode("utf-8") 58 | for i in range(mixture.shape[0]): 59 | yield f'data_{f}_{i}', { 60 | 'mixture': mixture[i], 61 | 'signal': target[i], 62 | } 63 | -------------------------------------------------------------------------------- /dataset_utils/tfds_scripts/Dataset_QPSK_EMISignal1_Mixture.py: -------------------------------------------------------------------------------- 1 | """Dataset.""" 2 | 3 | import os 4 | import tensorflow as tf 5 | import tensorflow_datasets as tfds 6 | 7 | import glob 8 | import h5py 9 | import numpy as np 10 | 11 | _DESCRIPTION = """ 12 | RFChallenge at MIT v0.2.0 13 | """ 14 | _CITATION = """ 15 | MIT, “RF Challenge - AI Accelerator.” https://rfchallenge.mit.edu/ 16 | """ 17 | 18 | soi_type = 'QPSK' 19 | interference_sig_type = 'EMISignal1' 20 | 21 | class DatasetQpskEmisignal1Mixture(tfds.core.GeneratorBasedBuilder): 22 | VERSION = tfds.core.Version('0.2.0') 23 | RELEASE_NOTES = { 24 | '0.2.0': 'RFChallenge 2023 release.', 25 | } 26 | 27 | def _info(self) -> tfds.core.DatasetInfo: 28 | """Returns the dataset metadata.""" 29 | return tfds.core.DatasetInfo( 30 | builder=self, 31 | description=_DESCRIPTION, 32 | features=tfds.features.FeaturesDict({ 33 | 'mixture': tfds.features.Tensor(shape=(None, 2), dtype=tf.float32), 34 | 'signal': tfds.features.Tensor(shape=(None, 2), dtype=tf.float32), 35 | }), 36 | supervised_keys=('mixture', 'signal'), 37 | homepage='https://rfchallenge.mit.edu/', 38 | citation=_CITATION, 39 | ) 40 | 41 | def _split_generators(self, dl_manager: tfds.download.DownloadManager): 42 | """Returns SplitGenerators.""" 43 | path = os.path.join('dataset', f'Dataset_{soi_type}_{interference_sig_type}_Mixture') 44 | 45 | return { 46 | 'train': self._generate_examples(path), 47 | } 48 | 49 | def _generate_examples(self, path): 50 | """Yields examples.""" 51 | for f in glob.glob(os.path.join(path, '*.h5')): 52 | with h5py.File(f,'r') as h5file: 53 | mixture = np.array(h5file.get('mixture')) 54 | target = np.array(h5file.get('target')) 55 | sig_type = h5file.get('sig_type')[()] 56 | if isinstance(sig_type, bytes): 57 | sig_type = sig_type.decode("utf-8") 58 | for i in range(mixture.shape[0]): 59 | yield f'data_{f}_{i}', { 60 | 'mixture': mixture[i], 61 | 'signal': target[i], 62 | } 63 | -------------------------------------------------------------------------------- /dataset_utils/tfds_scripts/Dataset_QPSK_CommSignal5G1_Mixture.py: -------------------------------------------------------------------------------- 1 | """Dataset.""" 2 | 3 | import os 4 | import tensorflow as tf 5 | import tensorflow_datasets as tfds 6 | 7 | import glob 8 | import h5py 9 | import numpy as np 10 | 11 | _DESCRIPTION = """ 12 | RFChallenge at MIT v0.2.0 13 | """ 14 | _CITATION = """ 15 | MIT, “RF Challenge - AI Accelerator.” https://rfchallenge.mit.edu/ 16 | """ 17 | 18 | soi_type = 'QPSK' 19 | interference_sig_type = 'CommSignal5G1' 20 | 21 | class DatasetQpskCommsignal5g1Mixture(tfds.core.GeneratorBasedBuilder): 22 | VERSION = 
/src/unet_model.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import tensorflow.keras as k
from tensorflow.keras import layers
from tensorflow.keras.models import Model

def get_unet_model(input_shape, k_sz=3, long_k_sz=101, lr=0.0003, k_neurons=32):
    n_window = input_shape[0]
    n_ch = 2

    in0 = layers.Input(shape=input_shape)
    x = in0

    x = layers.BatchNormalization()(x)

    upsamp_blocks = []

    for n_layer, k in enumerate([8, 8, 8, 8, 8]):
        if n_layer == 0:
            conv = layers.Conv1D(k_neurons * k, long_k_sz, activation="relu", padding="same")(x)
        else:
            conv = layers.Conv1D(k_neurons * k, k_sz, activation="relu", padding="same")(x)

        conv = layers.Conv1D(k_neurons * k, k_sz, activation="relu", padding="same")(conv)
        pool = layers.MaxPooling1D(2)(conv)
        if n_layer == 0:
            pool = layers.Dropout(0.25)(pool)
        else:
            pool = layers.Dropout(0.5)(pool)

        upsamp_blocks.append(conv)
        x = pool

    # Middle
    convm = layers.Conv1D(k_neurons * 8, k_sz, activation="relu", padding="same")(x)
    convm = layers.Conv1D(k_neurons * 8, k_sz, activation="relu", padding="same")(convm)

    x = convm
    for n_layer, k in enumerate([8, 8, 4, 2, 1]):
        deconv = layers.Conv1DTranspose(k_neurons * k, k_sz, strides=2, padding="same")(x)
        uconv = layers.concatenate([deconv, upsamp_blocks[-(n_layer+1)]])
        uconv = layers.Dropout(0.5)(uconv)
        uconv = layers.Conv1D(k_neurons * k, k_sz, activation="relu", padding="same")(uconv)
        uconv = layers.Conv1D(k_neurons * k, k_sz, activation="relu", padding="same")(uconv)

        x = uconv

    output_layer = layers.Conv1D(n_ch, 1, padding="same", activation=None)(x)

    x_out = output_layer
    supreg_net = Model(in0, x_out, name='supreg')

    supreg_net.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
                       loss=tf.keras.losses.MeanSquaredError(),
                       metrics=[tf.keras.losses.MeanSquaredError()])

    return supreg_net
--------------------------------------------------------------------------------
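A smoke-test sketch for the UNet above: with five MaxPooling1D(2) stages, the window length should be divisible by 2**5 = 32 so the skip concatenations line up. Requires TensorFlow; the zero input is illustrative:

import numpy as np
from src import unet_model as unet

model = unet.get_unet_model((40960, 2))
out = model.predict(np.zeros((1, 40960, 2), dtype=np.float32), verbose=0)
print(out.shape)  # (1, 40960, 2) — same length, real/imag channels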
/dataset_utils/tfds_scripts/Dataset_QPSK_CommSignal2_Mixture.py:
--------------------------------------------------------------------------------
"""Dataset."""

import os
import tensorflow as tf
import tensorflow_datasets as tfds

import glob
import h5py
import numpy as np

_DESCRIPTION = """
RFChallenge at MIT v0.2.0
"""
_CITATION = """
MIT, “RF Challenge - AI Accelerator.” https://rfchallenge.mit.edu/
"""

soi_type = 'QPSK'
interference_sig_type = 'CommSignal2'

class DatasetQpskCommsignal2Mixture(tfds.core.GeneratorBasedBuilder):
    VERSION = tfds.core.Version('0.2.0')
    RELEASE_NOTES = {
        '0.2.0': 'RFChallenge 2023 release.',
    }

    def _info(self) -> tfds.core.DatasetInfo:
        """Returns the dataset metadata."""
        return tfds.core.DatasetInfo(
            builder=self,
            description=_DESCRIPTION,
            features=tfds.features.FeaturesDict({
                'mixture': tfds.features.Tensor(shape=(None, 2), dtype=tf.float32),
                'signal': tfds.features.Tensor(shape=(None, 2), dtype=tf.float32),
            }),
            supervised_keys=('mixture', 'signal'),
            homepage='https://rfchallenge.mit.edu/',
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager: tfds.download.DownloadManager):
        """Returns SplitGenerators."""
        path = os.path.join('dataset', f'Dataset_{soi_type}_{interference_sig_type}_Mixture')

        return {
            'train': self._generate_examples(path),
        }

    def _generate_examples(self, path):
        """Yields examples."""
        for f in glob.glob(os.path.join(path, '*.h5')):
            with h5py.File(f, 'r') as h5file:
                mixture = np.array(h5file.get('mixture'))
                target = np.array(h5file.get('target'))
                sig_type = h5file.get('sig_type')[()]
                if isinstance(sig_type, bytes):
                    sig_type = sig_type.decode("utf-8")
                for i in range(mixture.shape[0]):
                    yield f'data_{f}_{i}', {
                        'mixture': mixture[i],
                        'signal': target[i],
                    }
--------------------------------------------------------------------------------
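Once this builder has been compiled with `tfds build ... --data_dir tfds/` (see sampletrain_gendataset_script.sh), it can be loaded by name. A sketch, assuming TFDS derives the registered name from the builder class name in the usual snake_case way:

import tensorflow_datasets as tfds

ds = tfds.load("dataset_qpsk_commsignal2_mixture", split="train",
               as_supervised=True, data_dir="tfds")
for mixture, signal in ds.take(1):
    print(mixture.shape, signal.shape)  # (sig_len, 2) each, real/imag stacked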
/dataset_utils/tfds_scripts/Dataset_QPSK_CommSignal3_Mixture.py:
--------------------------------------------------------------------------------
"""Dataset."""

import os
import tensorflow as tf
import tensorflow_datasets as tfds

import glob
import h5py
import numpy as np

_DESCRIPTION = """
RFChallenge at MIT v0.2.0
"""
_CITATION = """
MIT, “RF Challenge - AI Accelerator.” https://rfchallenge.mit.edu/
"""

soi_type = 'QPSK'
interference_sig_type = 'CommSignal3'

class DatasetQpskCommsignal3Mixture(tfds.core.GeneratorBasedBuilder):
    VERSION = tfds.core.Version('0.2.0')
    RELEASE_NOTES = {
        '0.2.0': 'RFChallenge 2023 release.',
    }

    def _info(self) -> tfds.core.DatasetInfo:
        """Returns the dataset metadata."""
        return tfds.core.DatasetInfo(
            builder=self,
            description=_DESCRIPTION,
            features=tfds.features.FeaturesDict({
                'mixture': tfds.features.Tensor(shape=(None, 2), dtype=tf.float32),
                'signal': tfds.features.Tensor(shape=(None, 2), dtype=tf.float32),
            }),
            supervised_keys=('mixture', 'signal'),
            homepage='https://rfchallenge.mit.edu/',
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager: tfds.download.DownloadManager):
        """Returns SplitGenerators."""
        path = os.path.join('dataset', f'Dataset_{soi_type}_{interference_sig_type}_Mixture')

        return {
            'train': self._generate_examples(path),
        }

    def _generate_examples(self, path):
        """Yields examples."""
        for f in glob.glob(os.path.join(path, '*.h5')):
            with h5py.File(f, 'r') as h5file:
                mixture = np.array(h5file.get('mixture'))
                target = np.array(h5file.get('target'))
                sig_type = h5file.get('sig_type')[()]
                if isinstance(sig_type, bytes):
                    sig_type = sig_type.decode("utf-8")
                for i in range(mixture.shape[0]):
                    yield f'data_{f}_{i}', {
                        'mixture': mixture[i],
                        'signal': target[i],
                    }
--------------------------------------------------------------------------------
/dataset_utils/tfds_scripts/Dataset_QPSK_EMISignal1_Mixture.py:
--------------------------------------------------------------------------------
"""Dataset."""

import os
import tensorflow as tf
import tensorflow_datasets as tfds

import glob
import h5py
import numpy as np

_DESCRIPTION = """
RFChallenge at MIT v0.2.0
"""
_CITATION = """
MIT, “RF Challenge - AI Accelerator.” https://rfchallenge.mit.edu/
"""

soi_type = 'QPSK'
interference_sig_type = 'EMISignal1'

class DatasetQpskEmisignal1Mixture(tfds.core.GeneratorBasedBuilder):
    VERSION = tfds.core.Version('0.2.0')
    RELEASE_NOTES = {
        '0.2.0': 'RFChallenge 2023 release.',
    }

    def _info(self) -> tfds.core.DatasetInfo:
        """Returns the dataset metadata."""
        return tfds.core.DatasetInfo(
            builder=self,
            description=_DESCRIPTION,
            features=tfds.features.FeaturesDict({
                'mixture': tfds.features.Tensor(shape=(None, 2), dtype=tf.float32),
                'signal': tfds.features.Tensor(shape=(None, 2), dtype=tf.float32),
            }),
            supervised_keys=('mixture', 'signal'),
            homepage='https://rfchallenge.mit.edu/',
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager: tfds.download.DownloadManager):
        """Returns SplitGenerators."""
        path = os.path.join('dataset', f'Dataset_{soi_type}_{interference_sig_type}_Mixture')

        return {
            'train': self._generate_examples(path),
        }

    def _generate_examples(self, path):
        """Yields examples."""
        for f in glob.glob(os.path.join(path, '*.h5')):
            with h5py.File(f, 'r') as h5file:
                mixture = np.array(h5file.get('mixture'))
                target = np.array(h5file.get('target'))
                sig_type = h5file.get('sig_type')[()]
                if isinstance(sig_type, bytes):
                    sig_type = sig_type.decode("utf-8")
                for i in range(mixture.shape[0]):
                    yield f'data_{f}_{i}', {
                        'mixture': mixture[i],
                        'signal': target[i],
                    }
--------------------------------------------------------------------------------
/dataset_utils/tfds_scripts/Dataset_QPSK_CommSignal5G1_Mixture.py:
--------------------------------------------------------------------------------
"""Dataset."""

import os
import tensorflow as tf
import tensorflow_datasets as tfds

import glob
import h5py
import numpy as np

_DESCRIPTION = """
RFChallenge at MIT v0.2.0
"""
_CITATION = """
MIT, “RF Challenge - AI Accelerator.” https://rfchallenge.mit.edu/
"""

soi_type = 'QPSK'
interference_sig_type = 'CommSignal5G1'

class DatasetQpskCommsignal5g1Mixture(tfds.core.GeneratorBasedBuilder):
    VERSION = tfds.core.Version('0.2.0')
    RELEASE_NOTES = {
        '0.2.0': 'RFChallenge 2023 release.',
    }

    def _info(self) -> tfds.core.DatasetInfo:
        """Returns the dataset metadata."""
        return tfds.core.DatasetInfo(
            builder=self,
            description=_DESCRIPTION,
            features=tfds.features.FeaturesDict({
                'mixture': tfds.features.Tensor(shape=(None, 2), dtype=tf.float32),
                'signal': tfds.features.Tensor(shape=(None, 2), dtype=tf.float32),
            }),
            supervised_keys=('mixture', 'signal'),
            homepage='https://rfchallenge.mit.edu/',
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager: tfds.download.DownloadManager):
        """Returns SplitGenerators."""
        path = os.path.join('dataset', f'Dataset_{soi_type}_{interference_sig_type}_Mixture')

        return {
            'train': self._generate_examples(path),
        }

    def _generate_examples(self, path):
        """Yields examples."""
        for f in glob.glob(os.path.join(path, '*.h5')):
            with h5py.File(f, 'r') as h5file:
                mixture = np.array(h5file.get('mixture'))
                target = np.array(h5file.get('target'))
                sig_type = h5file.get('sig_type')[()]
                if isinstance(sig_type, bytes):
                    sig_type = sig_type.decode("utf-8")
                for i in range(mixture.shape[0]):
                    yield f'data_{f}_{i}', {
                        'mixture': mixture[i],
                        'signal': target[i],
                    }
--------------------------------------------------------------------------------
/dataset_utils/tfds_scripts/Dataset_OFDMQPSK_EMISignal1_Mixture.py:
--------------------------------------------------------------------------------
"""Dataset."""

import os
import tensorflow as tf
import tensorflow_datasets as tfds

import glob
import h5py
import numpy as np

_DESCRIPTION = """
RFChallenge at MIT v0.2.0
"""
_CITATION = """
MIT, “RF Challenge - AI Accelerator.” https://rfchallenge.mit.edu/
"""

soi_type = 'OFDMQPSK'
interference_sig_type = 'EMISignal1'

class DatasetOfdmqpskEmisignal1Mixture(tfds.core.GeneratorBasedBuilder):
    VERSION = tfds.core.Version('0.2.0')
    RELEASE_NOTES = {
        '0.2.0': 'RFChallenge 2023 release.',
    }

    def _info(self) -> tfds.core.DatasetInfo:
        """Returns the dataset metadata."""
        return tfds.core.DatasetInfo(
            builder=self,
            description=_DESCRIPTION,
            features=tfds.features.FeaturesDict({
                'mixture': tfds.features.Tensor(shape=(None, 2), dtype=tf.float32),
                'signal': tfds.features.Tensor(shape=(None, 2), dtype=tf.float32),
            }),
            supervised_keys=('mixture', 'signal'),
            homepage='https://rfchallenge.mit.edu/',
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager: tfds.download.DownloadManager):
        """Returns SplitGenerators."""
        path = os.path.join('dataset', f'Dataset_{soi_type}_{interference_sig_type}_Mixture')

        return {
            'train': self._generate_examples(path),
        }

    def _generate_examples(self, path):
        """Yields examples."""
        for f in glob.glob(os.path.join(path, '*.h5')):
            with h5py.File(f, 'r') as h5file:
                mixture = np.array(h5file.get('mixture'))
                target = np.array(h5file.get('target'))
                sig_type = h5file.get('sig_type')[()]
                if isinstance(sig_type, bytes):
                    sig_type = sig_type.decode("utf-8")
                for i in range(mixture.shape[0]):
                    yield f'data_{f}_{i}', {
                        'mixture': mixture[i],
                        'signal': target[i],
                    }
--------------------------------------------------------------------------------
/rfcutils/qpsk_helper_fn.py:
--------------------------------------------------------------------------------
import sionna as sn
import numpy as np
import tensorflow as tf

from .rrc_helper_fn import get_psf, matched_filter

# Binary source to generate uniform i.i.d. bits
binary_source = sn.utils.BinarySource()

samples_per_symbol = 16
span_in_symbols = 8
beta = 0.5

# 4-QAM constellation
NUM_BITS_PER_SYMBOL = 2
constellation = sn.mapping.Constellation("qam", NUM_BITS_PER_SYMBOL, trainable=False)  # The constellation is set to be NOT trainable

# Mapper and demapper
mapper = sn.mapping.Mapper(constellation=constellation)
demapper = sn.mapping.Demapper("app", constellation=constellation)

# AWGN channel
awgn_channel = sn.channel.AWGN()

def generate_qpsk_signal(batch_size, num_symbols, ebno_db=None):
    bits = binary_source([batch_size, num_symbols*NUM_BITS_PER_SYMBOL])  # Blocklength
    return modulate_qpsk_signal(bits, ebno_db)

def qpsk_matched_filter_demod(sig, no=1e-4, soft_demod=False):
    x_mf = matched_filter(sig, samples_per_symbol, span_in_symbols, beta)
    num_symbols = sig.shape[-1]//samples_per_symbol
    ds = sn.signal.Downsampling(samples_per_symbol, samples_per_symbol//2, num_symbols)
    x_hat = ds(x_mf)
    x_hat = x_hat / tf.math.sqrt(tf.cast(samples_per_symbol, tf.complex64))
    llr = demapper([x_hat, no])
    if soft_demod:
        return llr, x_hat
    return tf.cast(llr > 0, tf.float32), x_hat

def modulate_qpsk_signal(info_bits, ebno_db=None):
    x = mapper(info_bits)
    us = sn.signal.Upsampling(samples_per_symbol)
    x_us = us(x)
    x_us = tf.pad(x_us, tf.constant([[0, 0,], [samples_per_symbol//2, 0]]), "CONSTANT")
    x_us = x_us[:, :-samples_per_symbol//2]
    x_rrcf = matched_filter(x_us, samples_per_symbol, span_in_symbols, beta)
    if ebno_db is None:
        y = x_rrcf
    else:
        no = sn.utils.ebnodb2no(ebno_db=ebno_db,
                                num_bits_per_symbol=NUM_BITS_PER_SYMBOL,
                                coderate=1.0)  # Coderate set to 1 as we do uncoded transmission here
        y = awgn_channel([x_rrcf, no])
    y = y * tf.math.sqrt(tf.cast(samples_per_symbol, tf.complex64))
    return y, x, info_bits, constellation
--------------------------------------------------------------------------------
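A noiseless round-trip sketch for the QPSK helpers above (requires Sionna + TensorFlow): 2560 symbols at 16 samples per symbol gives the 40960-sample frames used throughout the repo, and with no AWGN the demodulated bits should match the transmitted ones:

import tensorflow as tf
from rfcutils.qpsk_helper_fn import generate_qpsk_signal, qpsk_matched_filter_demod

y, x_syms, bits, _ = generate_qpsk_signal(batch_size=4, num_symbols=2560)
bits_hat, _ = qpsk_matched_filter_demod(y)
ber = tf.reduce_mean(tf.cast(bits_hat != bits, tf.float32))
print(y.shape, float(ber))  # (4, 40960), ~0.0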
/dataset_utils/tfds_scripts/Dataset_OFDMQPSK_CommSignal2_Mixture.py:
--------------------------------------------------------------------------------
"""Dataset."""

import os
import tensorflow as tf
import tensorflow_datasets as tfds

import glob
import h5py
import numpy as np

_DESCRIPTION = """
RFChallenge at MIT v0.2.0
"""
_CITATION = """
MIT, “RF Challenge - AI Accelerator.” https://rfchallenge.mit.edu/
"""

soi_type = 'OFDMQPSK'
interference_sig_type = 'CommSignal2'

class DatasetOfdmqpskCommsignal2Mixture(tfds.core.GeneratorBasedBuilder):
    VERSION = tfds.core.Version('0.2.0')
    RELEASE_NOTES = {
        '0.2.0': 'RFChallenge 2023 release.',
    }

    def _info(self) -> tfds.core.DatasetInfo:
        """Returns the dataset metadata."""
        return tfds.core.DatasetInfo(
            builder=self,
            description=_DESCRIPTION,
            features=tfds.features.FeaturesDict({
                'mixture': tfds.features.Tensor(shape=(None, 2), dtype=tf.float32),
                'signal': tfds.features.Tensor(shape=(None, 2), dtype=tf.float32),
            }),
            supervised_keys=('mixture', 'signal'),
            homepage='https://rfchallenge.mit.edu/',
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager: tfds.download.DownloadManager):
        """Returns SplitGenerators."""
        path = os.path.join('dataset', f'Dataset_{soi_type}_{interference_sig_type}_Mixture')

        return {
            'train': self._generate_examples(path),
        }

    def _generate_examples(self, path):
        """Yields examples."""
        for f in glob.glob(os.path.join(path, '*.h5')):
            with h5py.File(f, 'r') as h5file:
                mixture = np.array(h5file.get('mixture'))
                target = np.array(h5file.get('target'))
                sig_type = h5file.get('sig_type')[()]
                if isinstance(sig_type, bytes):
                    sig_type = sig_type.decode("utf-8")
                for i in range(mixture.shape[0]):
                    yield f'data_{f}_{i}', {
                        'mixture': mixture[i],
                        'signal': target[i],
                    }
--------------------------------------------------------------------------------
/dataset_utils/tfds_scripts/Dataset_OFDMQPSK_CommSignal3_Mixture.py:
--------------------------------------------------------------------------------
"""Dataset."""

import os
import tensorflow as tf
import tensorflow_datasets as tfds

import glob
import h5py
import numpy as np

_DESCRIPTION = """
RFChallenge at MIT v0.2.0
"""
_CITATION = """
MIT, “RF Challenge - AI Accelerator.” https://rfchallenge.mit.edu/
"""

soi_type = 'OFDMQPSK'
interference_sig_type = 'CommSignal3'

class DatasetOfdmqpskCommsignal3Mixture(tfds.core.GeneratorBasedBuilder):
    VERSION = tfds.core.Version('0.2.0')
    RELEASE_NOTES = {
        '0.2.0': 'RFChallenge 2023 release.',
    }

    def _info(self) -> tfds.core.DatasetInfo:
        """Returns the dataset metadata."""
        return tfds.core.DatasetInfo(
            builder=self,
            description=_DESCRIPTION,
            features=tfds.features.FeaturesDict({
                'mixture': tfds.features.Tensor(shape=(None, 2), dtype=tf.float32),
                'signal': tfds.features.Tensor(shape=(None, 2), dtype=tf.float32),
            }),
            supervised_keys=('mixture', 'signal'),
            homepage='https://rfchallenge.mit.edu/',
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager: tfds.download.DownloadManager):
        """Returns SplitGenerators."""
        path = os.path.join('dataset', f'Dataset_{soi_type}_{interference_sig_type}_Mixture')

        return {
            'train': self._generate_examples(path),
        }

    def _generate_examples(self, path):
        """Yields examples."""
        for f in glob.glob(os.path.join(path, '*.h5')):
            with h5py.File(f, 'r') as h5file:
                mixture = np.array(h5file.get('mixture'))
                target = np.array(h5file.get('target'))
                sig_type = h5file.get('sig_type')[()]
                if isinstance(sig_type, bytes):
                    sig_type = sig_type.decode("utf-8")
                for i in range(mixture.shape[0]):
                    yield f'data_{f}_{i}', {
                        'mixture': mixture[i],
                        'signal': target[i],
                    }
--------------------------------------------------------------------------------
/rfcutils/qam16_helper_fn.py:
--------------------------------------------------------------------------------
import sionna as sn
import numpy as np
import tensorflow as tf

from .rrc_helper_fn import get_psf, matched_filter

# Binary source to generate uniform i.i.d. bits
binary_source = sn.utils.BinarySource()

samples_per_symbol = 16
span_in_symbols = 8
beta = 0.5

# 16-QAM constellation
NUM_BITS_PER_SYMBOL = 4
constellation = sn.mapping.Constellation("qam", NUM_BITS_PER_SYMBOL, trainable=False)  # The constellation is set to be NOT trainable

# Mapper and demapper
mapper = sn.mapping.Mapper(constellation=constellation)
demapper = sn.mapping.Demapper("app", constellation=constellation)

# AWGN channel
awgn_channel = sn.channel.AWGN()

def generate_qam16_signal(batch_size, num_symbols, ebno_db=None):
    bits = binary_source([batch_size, num_symbols*NUM_BITS_PER_SYMBOL])  # Blocklength
    return modulate_qam16_signal(bits, ebno_db)

def qam16_matched_filter_demod(sig, no=1e-4, soft_demod=False):
    x_mf = matched_filter(sig, samples_per_symbol, span_in_symbols, beta)
    num_symbols = sig.shape[-1]//samples_per_symbol
    ds = sn.signal.Downsampling(samples_per_symbol, samples_per_symbol//2, num_symbols)
    x_hat = ds(x_mf)
    x_hat = x_hat / tf.math.sqrt(tf.cast(samples_per_symbol, tf.complex64))
    llr = demapper([x_hat, no])
    if soft_demod:
        return llr, x_hat
    return tf.cast(llr > 0, tf.float32), x_hat

def modulate_qam16_signal(info_bits, ebno_db=None):
    x = mapper(info_bits)
    us = sn.signal.Upsampling(samples_per_symbol)
    x_us = us(x)
    x_us = tf.pad(x_us, tf.constant([[0, 0,], [samples_per_symbol//2, 0]]), "CONSTANT")
    x_us = x_us[:, :-samples_per_symbol//2]
    x_rrcf = matched_filter(x_us, samples_per_symbol, span_in_symbols, beta)
    if ebno_db is None:
        y = x_rrcf
    else:
        no = sn.utils.ebnodb2no(ebno_db=ebno_db,
                                num_bits_per_symbol=NUM_BITS_PER_SYMBOL,
                                coderate=1.0)  # Coderate set to 1 as we do uncoded transmission here
        y = awgn_channel([x_rrcf, no])
    y = y * tf.math.sqrt(tf.cast(samples_per_symbol, tf.complex64))
    return y, x, info_bits, constellation
--------------------------------------------------------------------------------
/rfcutils/qpsk2_helper_fn.py:
--------------------------------------------------------------------------------
import sionna as sn
import numpy as np
import tensorflow as tf

from .rrc_helper_fn import get_psf, matched_filter

# Binary source to generate uniform i.i.d. bits
binary_source = sn.utils.BinarySource()

samples_per_symbol = 4
span_in_symbols = 8
beta = 0.5

# 4-QAM constellation
NUM_BITS_PER_SYMBOL = 2
constellation = sn.mapping.Constellation("qam", NUM_BITS_PER_SYMBOL, trainable=False)  # The constellation is set to be NOT trainable

# Mapper and demapper
mapper = sn.mapping.Mapper(constellation=constellation)
demapper = sn.mapping.Demapper("app", constellation=constellation)

# AWGN channel
awgn_channel = sn.channel.AWGN()

def generate_qpsk2_signal(batch_size, num_symbols, ebno_db=None):
    bits = binary_source([batch_size, num_symbols*NUM_BITS_PER_SYMBOL])  # Blocklength
    return modulate_qpsk2_signal(bits, ebno_db)

def qpsk2_matched_filter_demod(sig, no=1e-4, soft_demod=False):
    x_mf = matched_filter(sig, samples_per_symbol, span_in_symbols, beta)
    num_symbols = sig.shape[-1]//samples_per_symbol
    ds = sn.signal.Downsampling(samples_per_symbol, samples_per_symbol//2, num_symbols)
    x_hat = ds(x_mf)
    x_hat = x_hat / tf.math.sqrt(tf.cast(samples_per_symbol, tf.complex64))
    llr = demapper([x_hat, no])
    if soft_demod:
        return llr, x_hat
    return tf.cast(llr > 0, tf.float32), x_hat

def modulate_qpsk2_signal(info_bits, ebno_db=None):
    x = mapper(info_bits)
    us = sn.signal.Upsampling(samples_per_symbol)
    x_us = us(x)
    x_us = tf.pad(x_us, tf.constant([[0, 0,], [samples_per_symbol//2, 0]]), "CONSTANT")
    x_us = x_us[:, :-samples_per_symbol//2]
    x_rrcf = matched_filter(x_us, samples_per_symbol, span_in_symbols, beta)
    if ebno_db is None:
        y = x_rrcf
    else:
        no = sn.utils.ebnodb2no(ebno_db=ebno_db,
                                num_bits_per_symbol=NUM_BITS_PER_SYMBOL,
                                coderate=1.0)  # Coderate set to 1 as we do uncoded transmission here
        y = awgn_channel([x_rrcf, no])
    y = y * tf.math.sqrt(tf.cast(samples_per_symbol, tf.complex64))
    return y, x, info_bits, constellation
--------------------------------------------------------------------------------
/dataset_utils/tfds_scripts/Dataset_OFDMQPSK_CommSignal5G1_Mixture.py:
--------------------------------------------------------------------------------
"""Dataset."""

import os
import tensorflow as tf
import tensorflow_datasets as tfds

import glob
import h5py
import numpy as np

_DESCRIPTION = """
RFChallenge at MIT v0.2.0
"""
_CITATION = """
MIT, “RF Challenge - AI Accelerator.” https://rfchallenge.mit.edu/
"""

soi_type = 'OFDMQPSK'
interference_sig_type = 'CommSignal5G1'

class DatasetOfdmqpskCommsignal5g1Mixture(tfds.core.GeneratorBasedBuilder):
    VERSION = tfds.core.Version('0.2.0')
    RELEASE_NOTES = {
        '0.2.0': 'RFChallenge 2023 release.',
    }

    def _info(self) -> tfds.core.DatasetInfo:
        """Returns the dataset metadata."""
        return tfds.core.DatasetInfo(
            builder=self,
            description=_DESCRIPTION,
            features=tfds.features.FeaturesDict({
                'mixture': tfds.features.Tensor(shape=(None, 2), dtype=tf.float32),
                'signal': tfds.features.Tensor(shape=(None, 2), dtype=tf.float32),
            }),
            supervised_keys=('mixture', 'signal'),
            homepage='https://rfchallenge.mit.edu/',
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager: tfds.download.DownloadManager):
        """Returns SplitGenerators."""
        path = os.path.join('dataset', f'Dataset_{soi_type}_{interference_sig_type}_Mixture')

        return {
            'train': self._generate_examples(path),
        }

    def _generate_examples(self, path):
        """Yields examples."""
        for f in glob.glob(os.path.join(path, '*.h5')):
            with h5py.File(f, 'r') as h5file:
                mixture = np.array(h5file.get('mixture'))
                target = np.array(h5file.get('target'))
                sig_type = h5file.get('sig_type')[()]
                if isinstance(sig_type, bytes):
                    sig_type = sig_type.decode("utf-8")
                for i in range(mixture.shape[0]):
                    yield f'data_{f}_{i}', {
                        'mixture': mixture[i],
                        'signal': target[i],
                    }
--------------------------------------------------------------------------------
f"torchmodels/dataset_{sigtype.lower()}_mixture_wavenet" 39 | 40 | # Setup training 41 | world_size = device_count() 42 | if world_size != cfg.distributed.world_size: 43 | raise ValueError( 44 | "Requested world size is not the same as number of visible GPUs.") 45 | if cfg.distributed.distributed: 46 | if world_size < 2: 47 | raise ValueError( 48 | "Distributed training cannot be run on machine" 49 | f" with {world_size} device(s)." 50 | ) 51 | if cfg.data.batch_size % world_size != 0: 52 | raise ValueError( 53 | f"Batch size {cfg.data.batch_size} is not evenly" 54 | f" divisble by # GPUs = {world_size}." 55 | ) 56 | cfg.data.batch_size = cfg.data.batch_size // world_size 57 | port = _get_free_port() 58 | spawn( 59 | train_distributed, 60 | args=(world_size, port, cfg), 61 | nprocs=world_size, 62 | join=True 63 | ) 64 | else: 65 | train(cfg) 66 | 67 | 68 | if __name__ == "__main__": 69 | argv = sys.argv 70 | if len(sys.argv) == 3: 71 | argv = argv + [""] 72 | main(argv) 73 | -------------------------------------------------------------------------------- /sampletest_tf_unet_inference.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import random 4 | import h5py 5 | from tqdm import tqdm 6 | import pickle 7 | 8 | import rfcutils 9 | import tensorflow as tf 10 | from src import unet_model as unet 11 | 12 | get_db = lambda p: 10*np.log10(p) 13 | get_pow = lambda s: np.mean(np.abs(s)**2) 14 | get_sinr = lambda s, i: get_pow(s)/get_pow(i) 15 | get_sinr_db = lambda s, i: get_db(get_sinr(s,i)) 16 | 17 | sig_len = 40960 18 | n_per_batch = 100 19 | all_sinr = np.arange(-30, 0.1, 3) 20 | 21 | def get_soi_generation_fn(soi_sig_type): 22 | if soi_sig_type == 'QPSK': 23 | generate_soi = lambda n, s_len: rfcutils.generate_qpsk_signal(n, s_len//16) 24 | demod_soi = rfcutils.qpsk_matched_filter_demod 25 | elif soi_sig_type == 'QAM16': 26 | generate_soi = lambda n, s_len: rfcutils.generate_qam16_signal(n, s_len//16) 27 | demod_soi = rfcutils.qam16_matched_filter_demod 28 | elif soi_sig_type == 'QPSK2': 29 | generate_soi = lambda n, s_len: rfcutils.generate_qpsk2_signal(n, s_len//4) 30 | demod_soi = rfcutils.qpsk2_matched_filter_demod 31 | elif soi_sig_type == 'OFDMQPSK': 32 | generate_soi = lambda n, s_len: rfcutils.generate_ofdm_signal(n, s_len//80) 33 | _,_,_,RES_GRID = rfcutils.generate_ofdm_signal(1, sig_len//80) 34 | demod_soi = lambda s: rfcutils.ofdm_demod(s, RES_GRID) 35 | else: 36 | raise Exception("SOI Type not recognized") 37 | return generate_soi, demod_soi 38 | 39 | 40 | def run_inference(all_sig_mixture, soi_type, interference_sig_type): 41 | 42 | generate_soi, demod_soi = get_soi_generation_fn(soi_type) 43 | 44 | nn_model = unet.get_unet_model((sig_len, 2), k_sz=3, long_k_sz=101, k_neurons=32, lr=0.0003) 45 | nn_model.load_weights(os.path.join('models', f'dataset_{soi_type.lower()}_{interference_sig_type.lower()}_mixture_unet', 'checkpoint')) 46 | 47 | sig_mixture_comp = tf.stack((tf.math.real(all_sig_mixture), tf.math.imag(all_sig_mixture)), axis=-1) 48 | sig1_out = nn_model.predict(sig_mixture_comp, batch_size=100, verbose=1) 49 | sig1_est = tf.complex(sig1_out[:,:,0], sig1_out[:,:,1]) 50 | 51 | bit_est = [] 52 | for idx, sinr_db in tqdm(enumerate(all_sinr)): 53 | bit_est_batch, _ = demod_soi(sig1_est[idx*n_per_batch:(idx+1)*n_per_batch]) 54 | bit_est.append(bit_est_batch) 55 | bit_est = tf.concat(bit_est, axis=0) 56 | sig1_est, bit_est = sig1_est.numpy(), bit_est.numpy() 57 | return sig1_est, bit_est 58 | 59 
/src/torchwavenet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from math import sqrt

from .config_torchwavenet import ModelConfig


Linear = nn.Linear
ConvTranspose2d = nn.ConvTranspose2d


def Conv1d(*args, **kwargs):
    layer = nn.Conv1d(*args, **kwargs)
    nn.init.kaiming_normal_(layer.weight)
    return layer


@torch.jit.script
def silu(x):
    return x * torch.sigmoid(x)


class ResidualBlock(nn.Module):
    def __init__(self, residual_channels, dilation):
        '''
        :param residual_channels: audio conv
        :param dilation: audio conv dilation
        '''
        super().__init__()
        self.dilated_conv = Conv1d(
            residual_channels, 2 * residual_channels,
            3, padding=dilation, dilation=dilation)

        self.output_projection = Conv1d(
            residual_channels, 2 * residual_channels, 1)

    def forward(self, x):
        y = self.dilated_conv(x)

        gate, filter = torch.chunk(y, 2, dim=1)
        y = torch.sigmoid(gate) * torch.tanh(filter)

        y = self.output_projection(y)
        residual, skip = torch.chunk(y, 2, dim=1)
        return (x + residual) / sqrt(2.0), skip


class Wave(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.input_projection = Conv1d(
            cfg.input_channels, cfg.residual_channels, 1)

        self.residual_layers = nn.ModuleList([
            ResidualBlock(cfg.residual_channels, 2**(i % cfg.dilation_cycle_length))
            for i in range(cfg.residual_layers)
        ])
        self.skip_projection = Conv1d(
            cfg.residual_channels, cfg.residual_channels, 1)
        self.output_projection = Conv1d(
            cfg.residual_channels, cfg.input_channels, 1)
        nn.init.zeros_(self.output_projection.weight)

    def forward(self, input):
        x = input
        x = self.input_projection(x)
        x = F.relu(x)

        skip = None
        for layer in self.residual_layers:
            x, skip_connection = layer(x)
            skip = skip_connection if skip is None else skip_connection + skip

        x = skip / sqrt(len(self.residual_layers))
        x = self.skip_projection(x)
        x = F.relu(x)
        x = self.output_projection(x)
        return x
--------------------------------------------------------------------------------
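A smoke-test sketch for the WaveNet above (CPU is fine): the dilated convolutions are padded so the network maps a (batch, 2, T) I/Q mixture to an SOI estimate of identical shape. T = 4096 here is illustrative:

import torch
from src.config_torchwavenet import ModelConfig
from src.torchwavenet import Wave

net = Wave(ModelConfig())
with torch.no_grad():
    y = net(torch.randn(2, 2, 4096))
print(y.shape)  # torch.Size([2, 2, 4096])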
/sampletest_evaluationscript.py:
--------------------------------------------------------------------------------
import os, sys
import numpy as np
from tqdm import tqdm
import pickle

get_db = lambda p: 10*np.log10(p)
get_pow = lambda s: np.mean(np.abs(s)**2)
get_sinr = lambda s, i: get_pow(s)/get_pow(i)
get_sinr_db = lambda s, i: get_db(get_sinr(s,i))

sig_len = 40960
n_per_batch = 100
all_sinr = np.arange(-30, 0.1, 3)

def run_demod_test(sig1_est, bit1_est, soi_type, interference_sig_type, testset_identifier):
    all_sig_mixture, all_sig1, all_bits1 = pickle.load(open(os.path.join('dataset', f'GroundTruth_{testset_identifier}_Dataset_{soi_type}_{interference_sig_type}.pkl'), 'rb'))

    # Evaluation pipeline
    def eval_mse(all_sig_est, all_sig_soi):
        assert all_sig_est.shape == all_sig_soi.shape, 'Invalid SOI estimate shape'
        return np.mean(np.abs(all_sig_est - all_sig_soi)**2, axis=1)

    def eval_ber(bit_est, bit_true):
        assert bit_est.shape == bit_true.shape, 'Invalid bit estimate shape'
        ber = np.sum((bit_est != bit_true).astype(np.float32), axis=1) / bit_true.shape[1]
        return ber

    all_mse, all_ber = [], []
    for idx, sinr in tqdm(enumerate(all_sinr)):
        batch_mse = eval_mse(sig1_est[idx*n_per_batch:(idx+1)*n_per_batch], all_sig1[idx*n_per_batch:(idx+1)*n_per_batch])
        bit_true_batch = all_bits1[idx*n_per_batch:(idx+1)*n_per_batch]
        batch_ber = eval_ber(bit1_est[idx*n_per_batch:(idx+1)*n_per_batch], bit_true_batch)
        all_mse.append(batch_mse)
        all_ber.append(batch_ber)

    all_mse, all_ber = np.array(all_mse), np.array(all_ber)

    mse_mean = 10*np.log10(np.mean(all_mse, axis=-1))
    ber_mean = np.mean(all_ber, axis=-1)
    print(f'{"SINR [dB]":>12} {"MSE [dB]":>12} {"BER":>12}')
    print('==================================================')
    for sinr, mse, ber in zip(all_sinr, mse_mean, ber_mean):
        print(f'{sinr:>12} {mse:>12,.5f} {ber:>12,.5f}')
    return mse_mean, ber_mean

if __name__ == "__main__":
    soi_type, interference_sig_type = sys.argv[1], sys.argv[2]
    testset_identifier = sys.argv[3]  # 'TestSet1Example'
    id_string = sys.argv[4]  # 'Default_TF_UNet'

    sig1_est = np.load(os.path.join('outputs', f'{id_string}_{testset_identifier}_estimated_soi_{soi_type}_{interference_sig_type}.npy'))
    bit1_est = np.load(os.path.join('outputs', f'{id_string}_{testset_identifier}_estimated_bits_{soi_type}_{interference_sig_type}.npy'))
    assert np.isfinite(sig1_est).all(), 'NaN or Inf in Signal Estimate'
    assert np.isfinite(bit1_est).all(), 'NaN or Inf in Bit Estimate'

    mse_mean, ber_mean = run_demod_test(sig1_est, bit1_est, soi_type, interference_sig_type, testset_identifier)
    pickle.dump((mse_mean, ber_mean), open(os.path.join('outputs', f'{id_string}_{testset_identifier}_exports_summary_{soi_type}_{interference_sig_type}.pkl'), 'wb'))
--------------------------------------------------------------------------------
/train_torchwavenet.py:
--------------------------------------------------------------------------------
import socketserver
import sys

from argparse import ArgumentParser
from omegaconf import OmegaConf
from torch.cuda import device_count
from torch.multiprocessing import spawn
from typing import List

from src.config_torchwavenet import Config, parse_configs
from src.learner_torchwavenet import train, train_distributed


def _get_free_port():
    with socketserver.TCPServer(('localhost', 0), None) as s:
        return s.server_address[1]

all_datasets = ['QPSK_CommSignal2', 'QPSK2_CommSignal2', 'QAM16_CommSignal2', 'OFDMQPSK_CommSignal2',
                'QPSK_CommSignal3', 'QPSK2_CommSignal3', 'QAM16_CommSignal3', 'OFDMQPSK_CommSignal3', 'CommSignal2_CommSignal3',
                'QPSK_EMISignal1', 'QPSK2_EMISignal1', 'QAM16_EMISignal1', 'OFDMQPSK_EMISignal1', 'CommSignal2_EMISignal1',
                'QPSK_CommSignal5G1', 'QPSK2_CommSignal5G1', 'QAM16_CommSignal5G1', 'OFDMQPSK_CommSignal5G1', 'CommSignal2_CommSignal5G1']

def main(argv: List[str]):
    parser = ArgumentParser(description="Train a WaveNet model.")
    parser.add_argument("--sigindex", type=int, required=True,
                        help="Index for Mixture Type.")
    parser.add_argument("--config", type=str, default="src/configs/wavenet.yml",
                        help="Configuration file for model.")
    args = parser.parse_args(argv[1:-1])

    sigtype = all_datasets[args.sigindex]
    # First create the base config
    cfg = OmegaConf.load(args.config)
    cli_cfg = OmegaConf.from_cli(
        argv[-1].split("::")) if argv[-1] != "" else None
    cfg: Config = Config(**parse_configs(cfg, cli_cfg))
    cfg.data.root_dir = f"npydataset/Dataset_{sigtype}_Mixture"
    cfg.model_dir = f"torchmodels/dataset_{sigtype.lower()}_mixture_wavenet"

    # Setup training
    world_size = device_count()
    if world_size != cfg.distributed.world_size:
        raise ValueError(
            "Requested world size is not the same as number of visible GPUs.")
    if cfg.distributed.distributed:
        if world_size < 2:
            raise ValueError(
                "Distributed training cannot be run on machine"
                f" with {world_size} device(s)."
            )
        if cfg.data.batch_size % world_size != 0:
            raise ValueError(
                f"Batch size {cfg.data.batch_size} is not evenly"
                f" divisible by # GPUs = {world_size}."
            )
        cfg.data.batch_size = cfg.data.batch_size // world_size
        port = _get_free_port()
        spawn(
            train_distributed,
            args=(world_size, port, cfg),
            nprocs=world_size,
            join=True
        )
    else:
        train(cfg)


if __name__ == "__main__":
    argv = sys.argv
    if len(sys.argv) == 3:
        argv = argv + [""]
    main(argv)
--------------------------------------------------------------------------------
/sampletest_tf_unet_inference.py:
--------------------------------------------------------------------------------
import os, sys
import numpy as np
import random
import h5py
from tqdm import tqdm
import pickle

import rfcutils
import tensorflow as tf
from src import unet_model as unet

get_db = lambda p: 10*np.log10(p)
get_pow = lambda s: np.mean(np.abs(s)**2)
get_sinr = lambda s, i: get_pow(s)/get_pow(i)
get_sinr_db = lambda s, i: get_db(get_sinr(s,i))

sig_len = 40960
n_per_batch = 100
all_sinr = np.arange(-30, 0.1, 3)

def get_soi_generation_fn(soi_sig_type):
    if soi_sig_type == 'QPSK':
        generate_soi = lambda n, s_len: rfcutils.generate_qpsk_signal(n, s_len//16)
        demod_soi = rfcutils.qpsk_matched_filter_demod
    elif soi_sig_type == 'QAM16':
        generate_soi = lambda n, s_len: rfcutils.generate_qam16_signal(n, s_len//16)
        demod_soi = rfcutils.qam16_matched_filter_demod
    elif soi_sig_type == 'QPSK2':
        generate_soi = lambda n, s_len: rfcutils.generate_qpsk2_signal(n, s_len//4)
        demod_soi = rfcutils.qpsk2_matched_filter_demod
    elif soi_sig_type == 'OFDMQPSK':
        generate_soi = lambda n, s_len: rfcutils.generate_ofdm_signal(n, s_len//80)
        _, _, _, RES_GRID = rfcutils.generate_ofdm_signal(1, sig_len//80)
        demod_soi = lambda s: rfcutils.ofdm_demod(s, RES_GRID)
    else:
        raise Exception("SOI Type not recognized")
    return generate_soi, demod_soi


def run_inference(all_sig_mixture, soi_type, interference_sig_type):

    generate_soi, demod_soi = get_soi_generation_fn(soi_type)

    nn_model = unet.get_unet_model((sig_len, 2), k_sz=3, long_k_sz=101, k_neurons=32, lr=0.0003)
    nn_model.load_weights(os.path.join('models', f'dataset_{soi_type.lower()}_{interference_sig_type.lower()}_mixture_unet', 'checkpoint'))

    sig_mixture_comp = tf.stack((tf.math.real(all_sig_mixture), tf.math.imag(all_sig_mixture)), axis=-1)
    sig1_out = nn_model.predict(sig_mixture_comp, batch_size=100, verbose=1)
    sig1_est = tf.complex(sig1_out[:,:,0], sig1_out[:,:,1])

    bit_est = []
    for idx, sinr_db in tqdm(enumerate(all_sinr)):
        bit_est_batch, _ = demod_soi(sig1_est[idx*n_per_batch:(idx+1)*n_per_batch])
        bit_est.append(bit_est_batch)
    bit_est = tf.concat(bit_est, axis=0)
    sig1_est, bit_est = sig1_est.numpy(), bit_est.numpy()
    return sig1_est, bit_est

if __name__ == "__main__":
    soi_type, interference_sig_type = sys.argv[1], sys.argv[2]
    testset_identifier = sys.argv[3]  # 'TestSet1Example'
    id_string = 'Default_TF_UNet'
    all_sig_mixture = np.load(os.path.join('dataset', f'{testset_identifier}_testmixture_{soi_type}_{interference_sig_type}.npy'))
    sig1_est, bit1_est = run_inference(all_sig_mixture, soi_type, interference_sig_type)
    np.save(os.path.join('outputs', f'{id_string}_{testset_identifier}_estimated_soi_{soi_type}_{interference_sig_type}'), sig1_est)
    np.save(os.path.join('outputs', f'{id_string}_{testset_identifier}_estimated_bits_{soi_type}_{interference_sig_type}'), bit1_est)
--------------------------------------------------------------------------------
--n_per_batch 1000 42 | python dataset_utils/example_generate_competition_trainmixture.py --soi_sig_type QPSK --interference_sig_type CommSignal2 --n_per_batch 1000 43 | python dataset_utils/example_generate_competition_trainmixture.py --soi_sig_type QPSK --interference_sig_type CommSignal3 --n_per_batch 1000 44 | python dataset_utils/example_generate_competition_trainmixture.py --soi_sig_type QPSK --interference_sig_type CommSignal5G1 --n_per_batch 1000 45 | 46 | python dataset_utils/example_generate_competition_trainmixture.py --soi_sig_type OFDMQPSK --interference_sig_type EMISignal1 --n_per_batch 1000 47 | python dataset_utils/example_generate_competition_trainmixture.py --soi_sig_type OFDMQPSK --interference_sig_type CommSignal2 --n_per_batch 1000 48 | python dataset_utils/example_generate_competition_trainmixture.py --soi_sig_type OFDMQPSK --interference_sig_type CommSignal3 --n_per_batch 1000 49 | python dataset_utils/example_generate_competition_trainmixture.py --soi_sig_type OFDMQPSK --interference_sig_type CommSignal5G1 --n_per_batch 1000 50 | -------------------------------------------------------------------------------- /sampletest_torch_wavenet_inference.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import random 4 | import h5py 5 | from tqdm import tqdm 6 | import pickle 7 | 8 | import torch 9 | import tensorflow as tf 10 | tf.config.set_visible_devices([], 'GPU') 11 | import rfcutils 12 | from src.torchwavenet import Wave 13 | from omegaconf import OmegaConf 14 | from src.config_torchwavenet import Config, parse_configs 15 | 16 | get_db = lambda p: 10*np.log10(p) 17 | get_pow = lambda s: np.mean(np.abs(s)**2) 18 | get_sinr = lambda s, i: get_pow(s)/get_pow(i) 19 | get_sinr_db = lambda s, i: get_db(get_sinr(s,i)) 20 | 21 | sig_len = 40960 22 | n_per_batch = 100 23 | all_sinr = np.arange(-30, 0.1, 3) 24 | 25 | def get_soi_generation_fn(soi_sig_type): 26 | if soi_sig_type == 'QPSK': 27 | generate_soi = lambda n, s_len: rfcutils.generate_qpsk_signal(n, s_len//16) 28 | demod_soi = rfcutils.qpsk_matched_filter_demod 29 | elif soi_sig_type == 'QAM16': 30 | generate_soi = lambda n, s_len: rfcutils.generate_qam16_signal(n, s_len//16) 31 | demod_soi = rfcutils.qam16_matched_filter_demod 32 | elif soi_sig_type == 'QPSK2': 33 | generate_soi = lambda n, s_len: rfcutils.generate_qpsk2_signal(n, s_len//4) 34 | demod_soi = rfcutils.qpsk2_matched_filter_demod 35 | elif soi_sig_type == 'OFDMQPSK': 36 | generate_soi = lambda n, s_len: rfcutils.generate_ofdm_signal(n, s_len//80) 37 | _,_,_,RES_GRID = rfcutils.generate_ofdm_signal(1, sig_len//80) 38 | demod_soi = lambda s: rfcutils.ofdm_demod(s, RES_GRID) 39 | else: 40 | raise Exception("SOI Type not recognized") 41 | return generate_soi, demod_soi 42 | 43 | 44 | def run_inference(all_sig_mixture, soi_type, interference_sig_type): 45 | 46 | generate_soi, demod_soi = get_soi_generation_fn(soi_type) 47 | 48 | cfg = OmegaConf.load("src/configs/wavenet.yml") 49 | cli_cfg = None 50 | cfg: Config = Config(**parse_configs(cfg, cli_cfg)) 51 | cfg.model_dir = f"torchmodels/dataset_{soi_type.lower()}_{interference_sig_type.lower()}_mixture_wavenet" 52 | nn_model = Wave(cfg.model).cuda() 53 | nn_model.load_state_dict(torch.load(cfg.model_dir+"/weights.pt")['model']) 54 | 55 | sig_mixture_comp = tf.stack((tf.math.real(all_sig_mixture), tf.math.imag(all_sig_mixture)), axis=-1) 56 | with torch.no_grad(): 57 | nn_model.eval() 58 | all_sig1_out = [] 59 | bsz = 100 60 | for i 
in tqdm(range(sig_mixture_comp.shape[0]//bsz)): 61 | sig_input = torch.Tensor(sig_mixture_comp[i*bsz:(i+1)*bsz].numpy()).transpose(1, 2).to('cuda') 62 | sig1_out = nn_model(sig_input) 63 | all_sig1_out.append(sig1_out.transpose(1,2).detach().cpu().numpy()) 64 | sig1_out = tf.concat(all_sig1_out, axis=0) 65 | print(sig1_out.shape) 66 | sig1_est = tf.complex(sig1_out[:,:,0], sig1_out[:,:,1]) 67 | 68 | bit_est = [] 69 | for idx, sinr_db in tqdm(enumerate(all_sinr)): 70 | bit_est_batch, _ = demod_soi(sig1_est[idx*n_per_batch:(idx+1)*n_per_batch]) 71 | bit_est.append(bit_est_batch) 72 | bit_est = tf.concat(bit_est, axis=0) 73 | sig1_est, bit_est = sig1_est.numpy(), bit_est.numpy() 74 | return sig1_est, bit_est 75 | 76 | if __name__ == "__main__": 77 | soi_type, interference_sig_type = sys.argv[1], sys.argv[2] 78 | testset_identifier = sys.argv[3] # 'TestSet1Example' 79 | id_string = 'Default_Torch_WaveNet' 80 | all_sig_mixture = np.load(os.path.join('dataset', f'{testset_identifier}_testmixture_{soi_type}_{interference_sig_type}.npy')) 81 | sig1_est, bit1_est = run_inference(all_sig_mixture, soi_type, interference_sig_type) 82 | np.save(os.path.join('outputs', f'{id_string}_{testset_identifier}_estimated_soi_{soi_type}_{interference_sig_type}'), sig1_est) 83 | np.save(os.path.join('outputs', f'{id_string}_{testset_identifier}_estimated_bits_{soi_type}_{interference_sig_type}'), bit1_est) 84 | -------------------------------------------------------------------------------- /sampletest_testmixture_generator.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import random 4 | import h5py 5 | from tqdm import tqdm 6 | import pickle 7 | 8 | import rfcutils 9 | import tensorflow as tf 10 | 11 | get_db = lambda p: 10*np.log10(p) 12 | get_pow = lambda s: np.mean(np.abs(s)**2, axis=-1) 13 | get_sinr = lambda s, i: get_pow(s)/get_pow(i) 14 | get_sinr_db = lambda s, i: get_db(get_sinr(s,i)) 15 | 16 | sig_len = 40960 17 | n_per_batch = 100 18 | all_sinr = np.arange(-30, 0.1, 3) 19 | 20 | seed_number = 0 21 | 22 | def get_soi_generation_fn(soi_sig_type): 23 | if soi_sig_type == 'QPSK': 24 | generate_soi = lambda n, s_len: rfcutils.generate_qpsk_signal(n, s_len//16) 25 | demod_soi = rfcutils.qpsk_matched_filter_demod 26 | elif soi_sig_type == 'QAM16': 27 | generate_soi = lambda n, s_len: rfcutils.generate_qam16_signal(n, s_len//16) 28 | demod_soi = rfcutils.qam16_matched_filter_demod 29 | elif soi_sig_type == 'QPSK2': 30 | generate_soi = lambda n, s_len: rfcutils.generate_qpsk2_signal(n, s_len//4) 31 | demod_soi = rfcutils.qpsk2_matched_filter_demod 32 | elif soi_sig_type == 'OFDMQPSK': 33 | generate_soi = lambda n, s_len: rfcutils.generate_ofdm_signal(n, s_len//80) 34 | _,_,_,RES_GRID = rfcutils.generate_ofdm_signal(1, sig_len//80) 35 | demod_soi = lambda s: rfcutils.ofdm_demod(s, RES_GRID) 36 | else: 37 | raise Exception("SOI Type not recognized") 38 | return generate_soi, demod_soi 39 | 40 | 41 | def generate_demod_testmixture(soi_type, interference_sig_type): 42 | 43 | generate_soi, demod_soi = get_soi_generation_fn(soi_type) 44 | with h5py.File(os.path.join('dataset', 'testset1_frame', interference_sig_type+'_test1_raw_data.h5'),'r') as data_h5file: 45 | sig_data = np.array(data_h5file.get('dataset')) 46 | sig_type_info = data_h5file.get('sig_type')[()] 47 | if isinstance(sig_type_info, bytes): 48 | sig_type_info = sig_type_info.decode("utf-8") 49 | 50 | random.seed(seed_number) 51 | np.random.seed(seed_number) 52 | 
tf.random.set_seed(seed_number) 53 | 54 | all_sig_mixture, all_sig1, all_bits1, meta_data = [], [], [], [] 55 | for idx, sinr in tqdm(enumerate(all_sinr)): 56 | sig1, _, bits1, _ = generate_soi(n_per_batch, sig_len) 57 | sig2 = sig_data[np.random.randint(sig_data.shape[0], size=(n_per_batch)), :] 58 | 59 | sig_target = sig1[:, :sig_len] 60 | 61 | rand_start_idx2 = np.random.randint(sig2.shape[1]-sig_len, size=sig2.shape[0]) 62 | inds2 = tf.cast(rand_start_idx2.reshape(-1,1) + np.arange(sig_len).reshape(1,-1), tf.int32) 63 | sig_interference = tf.experimental.numpy.take_along_axis(sig2, inds2, axis=1) 64 | 65 | # Interference Coefficient 66 | rand_gain = np.sqrt(10**(-sinr/10)).astype(np.float32) 67 | rand_phase = tf.random.uniform(shape=[sig_interference.shape[0],1]) 68 | rand_gain = tf.complex(rand_gain, tf.zeros_like(rand_gain)) 69 | rand_phase = tf.complex(rand_phase, tf.zeros_like(rand_phase)) 70 | coeff = rand_gain * tf.math.exp(1j*2*np.pi*rand_phase) 71 | 72 | sig_mixture = sig_target + sig_interference * coeff 73 | 74 | all_sig_mixture.append(sig_mixture) 75 | all_sig1.append(sig_target) 76 | all_bits1.append(bits1) 77 | 78 | actual_sinr = get_sinr_db(sig_target, sig_interference * coeff) 79 | meta_data.append(np.vstack(([rand_gain.numpy().real for _ in range(n_per_batch)], [sinr for _ in range(n_per_batch)], actual_sinr, [soi_type for _ in range(n_per_batch)], [interference_sig_type for _ in range(n_per_batch)]))) 80 | 81 | with tf.device('CPU'): 82 | all_sig_mixture = tf.concat(all_sig_mixture, axis=0).numpy() 83 | all_sig1 = tf.concat(all_sig1, axis=0).numpy() 84 | all_bits1 = tf.concat(all_bits1, axis=0).numpy() 85 | 86 | pickle.dump((all_sig_mixture, all_sig1, all_bits1), open(os.path.join('dataset', f'GroundTruth_TestSet1Example_Dataset_{soi_type}_{interference_sig_type}.pkl'), 'wb'), protocol=4) 87 | np.save(os.path.join('dataset', f'TestSet1Example_testmixture_{soi_type}_{interference_sig_type}'), all_sig_mixture) 88 | 89 | meta_data = np.concatenate(meta_data, axis=1).T 90 | np.save(os.path.join('dataset', f'TestSet1Example_testmixture_{soi_type}_{interference_sig_type}_metadata'), meta_data) 91 | 92 | if __name__ == "__main__": 93 | generate_demod_testmixture(sys.argv[1], sys.argv[2]) 94 | -------------------------------------------------------------------------------- /dataset_utils/example_generate_competition_trainmixture.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),'..'))) 3 | import numpy as np 4 | import random 5 | import h5py 6 | from tqdm import tqdm 7 | import pickle 8 | import argparse 9 | 10 | import rfcutils 11 | import tensorflow as tf 12 | 13 | get_db = lambda p: 10*np.log10(p) 14 | get_pow = lambda s: np.mean(np.abs(s)**2, axis=-1) 15 | get_sinr = lambda s, i: get_pow(s)/get_pow(i) 16 | get_sinr_db = lambda s, i: get_db(get_sinr(s,i)) 17 | 18 | sig_len = 40960 19 | default_n_per_batch = 100 20 | all_sinr = np.arange(-30, 0.1, 3) 21 | 22 | seed_number = 0 23 | 24 | def get_soi_generation_fn(soi_sig_type): 25 | if soi_sig_type == 'QPSK': 26 | generate_soi = lambda n, s_len: rfcutils.generate_qpsk_signal(n, s_len//16) 27 | demod_soi = rfcutils.qpsk_matched_filter_demod 28 | elif soi_sig_type == 'QAM16': 29 | generate_soi = lambda n, s_len: rfcutils.generate_qam16_signal(n, s_len//16) 30 | demod_soi = rfcutils.qam16_matched_filter_demod 31 | elif soi_sig_type == 'QPSK2': 32 | generate_soi = lambda n, s_len: 
rfcutils.generate_qpsk2_signal(n, s_len//4) 33 | demod_soi = rfcutils.qpsk2_matched_filter_demod 34 | elif soi_sig_type == 'OFDMQPSK': 35 | generate_soi = lambda n, s_len: rfcutils.generate_ofdm_signal(n, s_len//80) 36 | _,_,_,RES_GRID = rfcutils.generate_ofdm_signal(1, sig_len//80) 37 | demod_soi = lambda s: rfcutils.ofdm_demod(s, RES_GRID) 38 | else: 39 | raise Exception("SOI Type not recognized") 40 | return generate_soi, demod_soi 41 | 42 | 43 | def generate_demod_testmixture(soi_type, interference_sig_type, n_per_batch=default_n_per_batch): 44 | 45 | generate_soi, demod_soi = get_soi_generation_fn(soi_type) 46 | with h5py.File(os.path.join('dataset', 'interferenceset_frame', interference_sig_type+'_raw_data.h5'),'r') as data_h5file: 47 | sig_data = np.array(data_h5file.get('dataset')) 48 | sig_type_info = data_h5file.get('sig_type')[()] 49 | if isinstance(sig_type_info, bytes): 50 | sig_type_info = sig_type_info.decode("utf-8") 51 | 52 | random.seed(seed_number) 53 | np.random.seed(seed_number) 54 | tf.random.set_seed(seed_number) 55 | 56 | all_sig_mixture, all_sig1, all_bits1, meta_data = [], [], [], [] 57 | for idx, sinr in tqdm(enumerate(all_sinr)): 58 | sig1, _, bits1, _ = generate_soi(n_per_batch, sig_len) 59 | sig2 = sig_data[np.random.randint(sig_data.shape[0], size=(n_per_batch)), :] 60 | 61 | sig_target = sig1[:, :sig_len] 62 | 63 | rand_start_idx2 = np.random.randint(sig2.shape[1]-sig_len, size=sig2.shape[0]) 64 | inds2 = tf.cast(rand_start_idx2.reshape(-1,1) + np.arange(sig_len).reshape(1,-1), tf.int32) 65 | sig_interference = tf.experimental.numpy.take_along_axis(sig2, inds2, axis=1) 66 | 67 | # Interference Coefficient 68 | rand_gain = np.sqrt(10**(-sinr/10)).astype(np.float32) 69 | rand_phase = tf.random.uniform(shape=[sig_interference.shape[0],1]) 70 | rand_gain = tf.complex(rand_gain, tf.zeros_like(rand_gain)) 71 | rand_phase = tf.complex(rand_phase, tf.zeros_like(rand_phase)) 72 | coeff = rand_gain * tf.math.exp(1j*2*np.pi*rand_phase) 73 | 74 | sig_mixture = sig_target + sig_interference * coeff 75 | 76 | all_sig_mixture.append(sig_mixture) 77 | all_sig1.append(sig_target) 78 | all_bits1.append(bits1) 79 | 80 | actual_sinr = get_sinr_db(sig_target, sig_interference * coeff) 81 | meta_data.append(np.vstack(([rand_gain.numpy().real for _ in range(n_per_batch)], [sinr for _ in range(n_per_batch)], actual_sinr, [soi_type for _ in range(n_per_batch)], [interference_sig_type for _ in range(n_per_batch)]))) 82 | 83 | with tf.device('CPU'): 84 | all_sig_mixture = tf.concat(all_sig_mixture, axis=0).numpy() 85 | all_sig1 = tf.concat(all_sig1, axis=0).numpy() 86 | all_bits1 = tf.concat(all_bits1, axis=0).numpy() 87 | 88 | meta_data = np.concatenate(meta_data, axis=1).T 89 | pickle.dump((all_sig_mixture, all_sig1, all_bits1, meta_data), open(os.path.join('dataset', f'Training_Dataset_{soi_type}_{interference_sig_type}.pkl'), 'wb'), protocol=4) 90 | 91 | if __name__ == "__main__": 92 | parser = argparse.ArgumentParser(description='Generate Synthetic Dataset') 93 | parser.add_argument('-b', '--n_per_batch', default=100, type=int, help='number of mixtures per SINR level') 94 | parser.add_argument('--random_seed', default=0, type=int, help='random seed for reproducible generation') 95 | parser.add_argument('--soi_sig_type', help='SOI type, e.g., QPSK or OFDMQPSK') 96 | parser.add_argument('--interference_sig_type', help='interference type, e.g., CommSignal2') 97 | 98 | args = parser.parse_args() 99 | 100 | soi_type = args.soi_sig_type 101 | interference_sig_type = args.interference_sig_type 102 | seed_number = args.random_seed # apply the requested seed before generation 103 | generate_demod_testmixture(soi_type, interference_sig_type, args.n_per_batch) 104 | 
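# A minimal usage sketch (not from the original script) for the pickle written
# above; it assumes the QPSK + CommSignal2 configuration has been generated,
# so adjust the path for other SOI/interference pairs.
import pickle

with open('dataset/Training_Dataset_QPSK_CommSignal2.pkl', 'rb') as f:
    all_sig_mixture, all_sig1, all_bits1, meta_data = pickle.load(f)

# all_sig_mixture: (11 * n_per_batch, 40960) complex mixtures, ordered by
# increasing target SINR (-30 dB to 0 dB in 3 dB steps); all_sig1 and
# all_bits1 hold the corresponding SOI ground truth, and meta_data carries
# one 5-column row per mixture.
print(all_sig_mixture.shape, all_sig1.shape, all_bits1.shape, meta_data.shape)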
-------------------------------------------------------------------------------- /rfsionna_env.yml: -------------------------------------------------------------------------------- 1 | name: rfsionna 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - argon2-cffi=21.3.0=pyhd3eb1b0_0 9 | - argon2-cffi-bindings=21.2.0=py37h7f8727e_0 10 | - attrs=21.4.0=pyhd3eb1b0_0 11 | - backcall=0.2.0=pyhd3eb1b0_0 12 | - beautifulsoup4=4.11.1=py37h06a4308_0 13 | - bleach=4.1.0=pyhd3eb1b0_0 14 | - ca-certificates=2022.07.19=h06a4308_0 15 | - certifi=2022.6.15=py37h06a4308_0 16 | - cffi=1.15.0=py37hd667e15_1 17 | - cudatoolkit=11.2.2=hbe64b41_10 18 | - cudnn=8.1.0.77=h90431f1_0 19 | - dbus=1.13.18=hb2f20db_0 20 | - debugpy=1.5.1=py37h295c915_0 21 | - decorator=5.1.1=pyhd3eb1b0_0 22 | - defusedxml=0.7.1=pyhd3eb1b0_0 23 | - entrypoints=0.4=py37h06a4308_0 24 | - expat=2.4.4=h295c915_0 25 | - fontconfig=2.13.1=h6c09931_0 26 | - freetype=2.11.0=h70c0345_0 27 | - glib=2.69.1=h4ff587b_1 28 | - gst-plugins-base=1.14.0=h8213a91_2 29 | - gstreamer=1.14.0=h28cd5cc_2 30 | - icu=58.2=he6710b0_3 31 | - importlib_metadata=4.11.3=hd3eb1b0_0 32 | - importlib_resources=5.2.0=pyhd3eb1b0_1 33 | - ipykernel=6.9.1=py37h06a4308_0 34 | - ipython=7.31.1=py37h06a4308_0 35 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 36 | - ipywidgets=7.6.5=pyhd3eb1b0_1 37 | - jedi=0.18.1=py37h06a4308_1 38 | - jinja2=3.0.3=pyhd3eb1b0_0 39 | - jpeg=9e=h7f8727e_0 40 | - jsonschema=4.4.0=py37h06a4308_0 41 | - jupyter=1.0.0=py37_7 42 | - jupyter_client=7.2.2=py37h06a4308_0 43 | - jupyter_console=6.4.3=pyhd3eb1b0_0 44 | - jupyter_core=4.10.0=py37h06a4308_0 45 | - jupyterlab_pygments=0.1.2=py_0 46 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 47 | - ld_impl_linux-64=2.38=h1181459_1 48 | - libffi=3.3=he6710b0_2 49 | - libgcc-ng=11.2.0=h1234567_1 50 | - libgomp=11.2.0=h1234567_1 51 | - libpng=1.6.37=hbc83047_0 52 | - libsodium=1.0.18=h7b6447c_0 53 | - libstdcxx-ng=11.2.0=h1234567_1 54 | - libuuid=1.0.3=h7f8727e_2 55 | - libxcb=1.15=h7f8727e_0 56 | - libxml2=2.9.14=h74e7548_0 57 | - markupsafe=2.1.1=py37h7f8727e_0 58 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2 59 | - mistune=0.8.4=py37h14c3975_1001 60 | - nbclient=0.5.13=py37h06a4308_0 61 | - nbconvert=6.4.4=py37h06a4308_0 62 | - nbformat=5.3.0=py37h06a4308_0 63 | - ncurses=6.3=h5eee18b_3 64 | - nest-asyncio=1.5.5=py37h06a4308_0 65 | - notebook=6.4.12=py37h06a4308_0 66 | - openssl=1.1.1q=h7f8727e_0 67 | - packaging=21.3=pyhd3eb1b0_0 68 | - pandocfilters=1.5.0=pyhd3eb1b0_0 69 | - parso=0.8.3=pyhd3eb1b0_0 70 | - pcre=8.45=h295c915_0 71 | - pexpect=4.8.0=pyhd3eb1b0_3 72 | - pickleshare=0.7.5=pyhd3eb1b0_1003 73 | - prometheus_client=0.13.1=pyhd3eb1b0_0 74 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0 75 | - prompt_toolkit=3.0.20=hd3eb1b0_0 76 | - ptyprocess=0.7.0=pyhd3eb1b0_2 77 | - pycparser=2.21=pyhd3eb1b0_0 78 | - pygments=2.11.2=pyhd3eb1b0_0 79 | - pyparsing=3.0.4=pyhd3eb1b0_0 80 | - pyqt=5.9.2=py37h05f1152_2 81 | - pyrsistent=0.18.0=py37heee7806_0 82 | - python=3.7.13=h12debd9_0 83 | - python-dateutil=2.8.2=pyhd3eb1b0_0 84 | - python-fastjsonschema=2.15.1=pyhd3eb1b0_0 85 | - python_abi=3.7=2_cp37m 86 | - pyzmq=23.2.0=py37h6a678d5_0 87 | - qt=5.9.7=h5867ecd_1 88 | - qtconsole=5.3.1=py37h06a4308_0 89 | - qtpy=2.0.1=pyhd3eb1b0_0 90 | - readline=8.1.2=h7f8727e_1 91 | - send2trash=1.8.0=pyhd3eb1b0_1 92 | - sip=4.19.8=py37hf484d3e_0 93 | - six=1.16.0=pyhd3eb1b0_1 94 | - soupsieve=2.3.1=pyhd3eb1b0_0 95 | - sqlite=3.38.5=hc218d9a_0 96 | - 
terminado=0.13.1=py37h06a4308_0 97 | - testpath=0.6.0=py37h06a4308_0 98 | - tk=8.6.12=h1ccaba5_0 99 | - tornado=6.1=py37h27cfd23_0 100 | - traitlets=5.1.1=pyhd3eb1b0_0 101 | - typing_extensions=4.1.1=pyh06a4308_0 102 | - wcwidth=0.2.5=pyhd3eb1b0_0 103 | - webencodings=0.5.1=py37_1 104 | - widgetsnbextension=3.5.2=py37h06a4308_0 105 | - xz=5.2.5=h7f8727e_1 106 | - zeromq=4.3.4=h2531618_0 107 | - zlib=1.2.12=h7f8727e_2 108 | - pip: 109 | - absl-py==1.2.0 110 | - astunparse==1.6.3 111 | - cachetools==5.2.0 112 | - charset-normalizer==2.1.1 113 | - cycler==0.11.0 114 | - dill==0.3.6 115 | - dm-tree==0.1.7 116 | - etils==0.9.0 117 | - flatbuffers==2.0.7 118 | - fonttools==4.34.4 119 | - gast==0.5.3 120 | - google-auth==2.11.0 121 | - google-auth-oauthlib==0.4.6 122 | - google-pasta==0.2.0 123 | - googleapis-common-protos==1.56.4 124 | - grpcio==1.47.0 125 | - h5py==3.7.0 126 | - idna==3.3 127 | - importlib-metadata==4.12.0 128 | - joblib==1.2.0 129 | - keras==2.8.0 130 | - keras-preprocessing==1.1.2 131 | - kiwisolver==1.4.4 132 | - libclang==14.0.6 133 | - markdown==3.4.1 134 | - matplotlib==3.5.2 135 | - numpy==1.21.6 136 | - oauthlib==3.2.0 137 | - opt-einsum==3.3.0 138 | - pillow==9.2.0 139 | - pip==22.2.2 140 | - promise==2.3 141 | - protobuf==3.19.4 142 | - pyasn1==0.4.8 143 | - pyasn1-modules==0.2.8 144 | - requests==2.28.1 145 | - requests-oauthlib==1.3.1 146 | - rsa==4.9 147 | - scikit-learn==1.0.2 148 | - scipy==1.7.3 149 | - setuptools==65.3.0 150 | - sigmf==1.0.0 151 | - simplejpeg==1.6.5 152 | - sionna==0.10.0 153 | - tensorboard==2.8.0 154 | - tensorboard-data-server==0.6.1 155 | - tensorboard-plugin-wit==1.8.1 156 | - tensorflow==2.8.2 157 | - tensorflow-addons==0.17.1 158 | - tensorflow-estimator==2.8.0 159 | - tensorflow-io-gcs-filesystem==0.26.0 160 | - tensorflow-metadata==1.10.0 161 | - tensorflow-probability==0.16.0 162 | - termcolor==1.1.0 163 | - tfds-nightly==4.7.0.dev202210310045 164 | - threadpoolctl==3.1.0 165 | - toml==0.10.2 166 | - tqdm==4.64.0 167 | - typeguard==2.13.3 168 | - typing-extensions==4.3.0 169 | - urllib3==1.26.12 170 | - werkzeug==2.2.2 171 | - wheel==0.37.1 172 | - wrapt==1.14.1 173 | - zipp==3.8.1 174 | prefix: /home/gridsan/glcf411/.conda/envs/rfsionna 175 | -------------------------------------------------------------------------------- /dataset_utils/example_generate_rfc_mixtures.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),'..'))) 3 | import h5py 4 | import argparse 5 | 6 | import numpy as np 7 | import random 8 | import matplotlib.pyplot as plt 9 | from tqdm import tqdm 10 | import time 11 | import pickle 12 | 13 | import rfcutils 14 | import tensorflow as tf 15 | 16 | get_db = lambda p: 10*np.log10(p) 17 | get_pow = lambda s: np.mean(np.abs(s)**2) 18 | get_sinr = lambda s, i: get_pow(s)/get_pow(i) 19 | get_sinr_db = lambda s, i: get_db(get_sinr(s,i)) 20 | 21 | 22 | random.seed(0) 23 | np.random.seed(0) 24 | tf.random.set_seed(0) 25 | 26 | def get_soi_generation_fn(soi_sig_type): 27 | if soi_sig_type == 'QPSK': 28 | generate_soi = lambda n, s_len: rfcutils.generate_qpsk_signal(n, s_len//16) 29 | elif soi_sig_type == 'QAM16': 30 | generate_soi = lambda n, s_len: rfcutils.generate_qam16_signal(n, s_len//16) 31 | elif soi_sig_type == 'QPSK2': 32 | generate_soi = lambda n, s_len: rfcutils.generate_qpsk2_signal(n, s_len//4) 33 | elif soi_sig_type == 'OFDMQPSK': 34 | generate_soi = lambda n, s_len: 
rfcutils.generate_ofdm_signal(n, s_len//80) 35 | elif soi_sig_type == 'CommSignal2': 36 | with h5py.File(os.path.join('dataset', 'interferenceset_frame', soi_sig_type+'_raw_data.h5'),'r') as data_h5file: 37 | commsignal2_data = np.array(data_h5file.get('dataset')) 38 | def generate_commsignal2_signal(n, s_len): 39 | sig1 = commsignal2_data[np.random.randint(commsignal2_data.shape[0], size=(n)), :] 40 | rand_start_idx1 = np.random.randint(sig1.shape[1]-s_len, size=sig1.shape[0]) 41 | inds1 = tf.cast(rand_start_idx1.reshape(-1,1) + np.arange(s_len).reshape(1,-1), tf.int32) 42 | sig_target = tf.experimental.numpy.take_along_axis(sig1, inds1, axis=1) 43 | return sig_target, None, None, None # returning dummy 2nd to 4th variable to be similar to rfcutils generation function output 44 | generate_soi = generate_commsignal2_signal 45 | else: 46 | raise Exception("SOI Type not recognized") 47 | return generate_soi 48 | 49 | def generate_dataset(sig_data, soi_type, interference_sig_type, sig_len, n_examples, n_per_batch, foldername, seed, verbosity): 50 | random.seed(seed) 51 | np.random.seed(seed) 52 | generate_soi = get_soi_generation_fn(soi_type) 53 | 54 | n_batches = int(np.ceil(n_examples/n_per_batch)) 55 | for idx in tqdm(range(n_batches), disable=not bool(verbosity)): 56 | sig1, _, _, _ = generate_soi(n_per_batch, sig_len) 57 | sig2 = sig_data[np.random.randint(sig_data.shape[0], size=(n_per_batch)), :] 58 | 59 | sig_target = sig1[:, :sig_len] 60 | 61 | rand_start_idx2 = np.random.randint(sig2.shape[1]-sig_len, size=sig2.shape[0]) 62 | inds2 = tf.cast(rand_start_idx2.reshape(-1,1) + np.arange(sig_len).reshape(1,-1), tf.int32) 63 | sig_interference = tf.experimental.numpy.take_along_axis(sig2, inds2, axis=1) 64 | 65 | # Interference Coefficient 66 | # rand_gain = 31*tf.random.uniform(shape=[sig_interference.shape[0],1]) + 1 67 | rand_sinr_db = -36*tf.random.uniform(shape=[sig_interference.shape[0],1]) + 3 68 | rand_gain = 10**(-0.5*rand_sinr_db/10) 69 | rand_phase = tf.random.uniform(shape=[sig_interference.shape[0],1]) 70 | rand_gain = tf.complex(rand_gain, tf.zeros_like(rand_gain)) 71 | rand_phase = tf.complex(rand_phase, tf.zeros_like(rand_phase)) 72 | coeff = rand_gain * tf.math.exp(1j*2*np.pi*rand_phase) 73 | 74 | sig_mixture = sig_target + sig_interference * coeff 75 | 76 | sig_mixture_comp = tf.stack((tf.math.real(sig_mixture), tf.math.imag(sig_mixture)), axis=-1) 77 | sig_target_comp = tf.stack((tf.math.real(sig_target), tf.math.imag(sig_target)), axis=-1) 78 | 79 | mixture_filename = f'{dataset_type}_{soi_type}_{interference_sig_type}_mixture_{idx:04}.h5' 80 | if not os.path.exists(os.path.join(foldername)): 81 | os.makedirs(os.path.join(foldername)) 82 | with h5py.File(os.path.join(foldername, mixture_filename), 'w') as h5file0: 83 | h5file0.create_dataset('mixture', data=sig_mixture_comp) 84 | h5file0.create_dataset('target', data=sig_target_comp) 85 | h5file0.create_dataset('sig_type', data=f'{soi_type}_{interference_sig_type}_mixture') 86 | 87 | del sig1, sig2, sig_mixture_comp, sig_target_comp 88 | return 0 89 | 90 | 91 | if __name__ == "__main__": 92 | parser = argparse.ArgumentParser(description='Generate Synthetic Dataset') 93 | parser.add_argument('-l', '--sig_len', default=40960, type=int) 94 | parser.add_argument('-n', '--n_examples', default=240000, type=int, help='') 95 | parser.add_argument('-b', '--n_per_batch', default=4000, type=int, help='') 96 | parser.add_argument('-d', '--dataset', default='train', help='') 97 | parser.add_argument('--random_seed', default=0, 
type=int, help='') 98 | parser.add_argument('-v', '--verbosity', default=1, help='') 99 | parser.add_argument('--soi_sig_type', help='') 100 | parser.add_argument('--interference_sig_type', help='') 101 | args = parser.parse_args() 102 | 103 | soi_type = args.soi_sig_type 104 | 105 | interference_sig_type = args.interference_sig_type 106 | with h5py.File(os.path.join('dataset', 'interferenceset_frame', interference_sig_type+'_raw_data.h5'),'r') as data_h5file: 107 | sig_data = np.array(data_h5file.get('dataset')) 108 | sig_type_info = data_h5file.get('sig_type')[()] 109 | if isinstance(sig_type_info, bytes): 110 | sig_type_info = sig_type_info.decode("utf-8") 111 | 112 | # Generate synthetic dataset based on input arguments 113 | dataset_type = args.dataset 114 | foldername = os.path.join('dataset', f'Dataset_{soi_type}_{interference_sig_type}_Mixture') 115 | 116 | generate_dataset(sig_data, soi_type, interference_sig_type, args.sig_len, args.n_examples, args.n_per_batch, foldername, args.random_seed, args.verbosity) 117 | -------------------------------------------------------------------------------- /rftorch_env.yml: -------------------------------------------------------------------------------- 1 | name: rftorch 2 | channels: 3 | - pytorch 4 | - defaults 5 | - conda-forge 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=4.5=1_gnu 9 | - argon2-cffi=20.1.0=py37h27cfd23_1 10 | - async_generator=1.10=py37h28b3542_0 11 | - attrs=21.2.0=pyhd3eb1b0_0 12 | - backcall=0.2.0=pyhd3eb1b0_0 13 | - blas=1.0=mkl 14 | - bleach=3.3.0=pyhd3eb1b0_0 15 | - bzip2=1.0.8=h7b6447c_0 16 | - ca-certificates=2022.3.29=h06a4308_0 17 | - certifi=2021.10.8=py37h06a4308_2 18 | - cffi=1.14.5=py37h261ae71_0 19 | - cudatoolkit=10.2.89=hfd86e86_1 20 | - dbus=1.13.18=hb2f20db_0 21 | - decorator=5.0.9=pyhd3eb1b0_0 22 | - defusedxml=0.7.1=pyhd3eb1b0_0 23 | - entrypoints=0.3=py37_0 24 | - expat=2.4.1=h2531618_2 25 | - ffmpeg=4.3=hf484d3e_0 26 | - fontconfig=2.13.1=h6c09931_0 27 | - freetype=2.10.4=h5ab3b9f_0 28 | - glib=2.68.2=h36276a3_0 29 | - gmp=6.2.1=h2531618_2 30 | - gnutls=3.6.15=he1e5248_0 31 | - gst-plugins-base=1.14.0=h8213a91_2 32 | - gstreamer=1.14.0=h28cd5cc_2 33 | - icu=58.2=he6710b0_3 34 | - importlib_metadata=3.10.0=hd3eb1b0_0 35 | - intel-openmp=2021.2.0=h06a4308_610 36 | - ipykernel=5.3.4=py37h5ca1d4c_0 37 | - ipython=7.24.1=py37h085eea5_0 38 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 39 | - ipywidgets=7.6.3=pyhd3eb1b0_1 40 | - jedi=0.18.0=py37h06a4308_1 41 | - jinja2=3.0.0=pyhd3eb1b0_0 42 | - jpeg=9b=h024ee3a_2 43 | - jsonschema=3.2.0=py_2 44 | - jupyter=1.0.0=py37_7 45 | - jupyter_client=6.1.12=pyhd3eb1b0_0 46 | - jupyter_console=6.4.0=pyhd3eb1b0_0 47 | - jupyter_core=4.7.1=py37h06a4308_0 48 | - jupyterlab_pygments=0.1.2=py_0 49 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 50 | - lame=3.100=h7b6447c_0 51 | - lcms2=2.12=h3be6417_0 52 | - ld_impl_linux-64=2.35.1=h7274673_9 53 | - libffi=3.3=he6710b0_2 54 | - libgcc-ng=9.3.0=h5101ec6_17 55 | - libgomp=9.3.0=h5101ec6_17 56 | - libiconv=1.15=h63c8f33_5 57 | - libidn2=2.3.1=h27cfd23_0 58 | - libpng=1.6.37=hbc83047_0 59 | - libsodium=1.0.18=h7b6447c_0 60 | - libstdcxx-ng=9.3.0=hd4cf53a_17 61 | - libtasn1=4.16.0=h27cfd23_0 62 | - libtiff=4.2.0=h85742a9_0 63 | - libunistring=0.9.10=h27cfd23_0 64 | - libuuid=1.0.3=h1bed415_2 65 | - libuv=1.40.0=h7b6447c_0 66 | - libwebp-base=1.2.0=h27cfd23_0 67 | - libxcb=1.14=h7b6447c_0 68 | - libxml2=2.9.10=hb55368b_3 69 | - lz4-c=1.9.3=h2531618_0 70 | - matplotlib-inline=0.1.2=pyhd8ed1ab_2 71 | - 
mistune=0.8.4=py37h14c3975_1001 72 | - mkl=2021.2.0=h06a4308_296 73 | - mkl-service=2.3.0=py37h27cfd23_1 74 | - mkl_fft=1.3.0=py37h42c9631_2 75 | - mkl_random=1.2.1=py37ha9443f7_2 76 | - nbclient=0.5.3=pyhd3eb1b0_0 77 | - nbconvert=6.0.7=py37_0 78 | - nbformat=5.1.3=pyhd3eb1b0_0 79 | - ncurses=6.2=he6710b0_1 80 | - nest-asyncio=1.5.1=pyhd3eb1b0_0 81 | - nettle=3.7.3=hbbd107a_1 82 | - ninja=1.10.2=hff7bd54_1 83 | - notebook=6.4.0=py37h06a4308_0 84 | - numpy-base=1.20.2=py37hfae3a4d_0 85 | - olefile=0.46=py37_0 86 | - openh264=2.1.0=hd408876_0 87 | - openssl=1.1.1n=h7f8727e_0 88 | - packaging=20.9=pyhd3eb1b0_0 89 | - pandoc=2.12=h06a4308_0 90 | - pandocfilters=1.4.3=py37h06a4308_1 91 | - parso=0.8.2=pyhd3eb1b0_0 92 | - pcre=8.44=he6710b0_0 93 | - pexpect=4.8.0=pyhd3eb1b0_3 94 | - pickleshare=0.7.5=pyhd3eb1b0_1003 95 | - pillow=8.2.0=py37he98fc37_0 96 | - pip=21.1.2=py37h06a4308_0 97 | - prometheus_client=0.11.0=pyhd3eb1b0_0 98 | - prompt-toolkit=3.0.17=pyh06a4308_0 99 | - prompt_toolkit=3.0.17=hd3eb1b0_0 100 | - ptyprocess=0.7.0=pyhd3eb1b0_2 101 | - pycparser=2.20=py_2 102 | - pygments=2.9.0=pyhd3eb1b0_0 103 | - pyparsing=2.4.7=pyhd3eb1b0_0 104 | - pyqt=5.9.2=py37h05f1152_2 105 | - pyrsistent=0.17.3=py37h7b6447c_0 106 | - python=3.7.9=h7579374_0 107 | - python-dateutil=2.8.1=pyhd3eb1b0_0 108 | - python_abi=3.7=1_cp37m 109 | - pyzmq=20.0.0=py37h2531618_1 110 | - qt=5.9.7=h5867ecd_1 111 | - qtconsole=5.0.3=pyhd3eb1b0_0 112 | - qtpy=1.9.0=py_0 113 | - readline=8.1=h27cfd23_0 114 | - send2trash=1.5.0=pyhd3eb1b0_1 115 | - setuptools=52.0.0=py37h06a4308_0 116 | - sip=4.19.8=py37hf484d3e_0 117 | - six=1.15.0=py37h06a4308_0 118 | - sqlite=3.35.4=hdfb4753_0 119 | - terminado=0.9.4=py37h06a4308_0 120 | - testpath=0.4.4=pyhd3eb1b0_0 121 | - tk=8.6.10=hbc83047_0 122 | - tornado=6.1=py37h27cfd23_0 123 | - traitlets=5.0.5=pyhd3eb1b0_0 124 | - typing_extensions=3.7.4.3=pyha847dfd_0 125 | - wcwidth=0.2.5=py_0 126 | - webencodings=0.5.1=py37_1 127 | - wheel=0.36.2=pyhd3eb1b0_0 128 | - widgetsnbextension=3.5.1=py37_0 129 | - xz=5.2.5=h7b6447c_0 130 | - zeromq=4.3.4=h2531618_0 131 | - zipp=3.4.1=pyhd3eb1b0_0 132 | - zlib=1.2.11=h7b6447c_3 133 | - zstd=1.4.9=haebb681_0 134 | - pip: 135 | - absl-py==1.4.0 136 | - antlr4-python3-runtime==4.9.3 137 | - astunparse==1.6.3 138 | - cachetools==5.3.0 139 | - charset-normalizer==3.0.1 140 | - click==8.1.3 141 | - comet-ml==3.31.6 142 | - configobj==5.0.6 143 | - cycler==0.10.0 144 | - dill==0.3.6 145 | - dm-tree==0.1.8 146 | - dotmap==1.3.30 147 | - dulwich==0.20.45 148 | - etils==0.9.0 149 | - everett==3.0.0 150 | - flatbuffers==23.1.21 151 | - gast==0.4.0 152 | - google-auth==2.16.0 153 | - google-auth-oauthlib==0.4.6 154 | - google-pasta==0.2.0 155 | - googleapis-common-protos==1.59.1 156 | - grpcio==1.44.0 157 | - h5py==3.8.0 158 | - hdf5storage==0.1.18 159 | - idna==3.4 160 | - importlib-metadata==6.0.0 161 | - importlib-resources==5.10.2 162 | - keras==2.11.0 163 | - kiwisolver==1.3.1 164 | - libclang==15.0.6.1 165 | - llvmlite==0.38.0 166 | - markdown==3.4.1 167 | - markupsafe==2.1.2 168 | - mat73==0.59 169 | - matplotlib==3.4.2 170 | - mpmath==1.2.1 171 | - numba==0.55.1 172 | - numpy==1.20.3 173 | - nvidia-ml-py3==7.352.0 174 | - oauthlib==3.2.2 175 | - omegaconf==2.3.0 176 | - opt-einsum==3.3.0 177 | - promise==2.3 178 | - protobuf==3.19.6 179 | - psutil==5.9.5 180 | - pyasn1==0.4.8 181 | - pyasn1-modules==0.2.8 182 | - pysimplegui==4.0.0 183 | - pywavelets==1.3.0 184 | - pyyaml==6.0 185 | - requests==2.28.2 186 | - requests-oauthlib==1.3.1 187 | - 
requests-toolbelt==0.9.1 188 | - rsa==4.9 189 | - scikit-commpy==0.7.0 190 | - scipy==1.7.3 191 | - semantic-version==2.10.0 192 | - sentry-sdk==1.8.0 193 | - sigmf==1.1.0 194 | - sigpy==0.1.23 195 | - sionna==0.12.1 196 | - soundfile==0.10.3.post1 197 | - sympy==1.9 198 | - tensorboard==2.11.2 199 | - tensorboard-data-server==0.6.1 200 | - tensorboard-plugin-wit==1.8.1 201 | - tensorflow==2.11.0 202 | - tensorflow-datasets==4.8.2 203 | - tensorflow-estimator==2.11.0 204 | - tensorflow-io-gcs-filesystem==0.30.0 205 | - tensorflow-metadata==1.12.0 206 | - termcolor==2.2.0 207 | - toml==0.10.2 208 | - torch==1.11.0 209 | - torchaudio==0.11.0 210 | - torchvision==0.12.0 211 | - tqdm==4.62.3 212 | - tueplots==0.0.8 213 | - urllib3==1.26.14 214 | - websocket-client==1.3.3 215 | - werkzeug==2.2.2 216 | - wrapt==1.14.1 217 | - wurlitzer==3.0.2 218 | prefix: /home/gridsan/glcf411/.conda/envs/rftorch 219 | -------------------------------------------------------------------------------- /src/learner_torchwavenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 LMNT, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import os 16 | 17 | import matplotlib.pyplot as plt 18 | import numpy as np 19 | import torch 20 | import torch.distributed as dist 21 | import torch.nn as nn 22 | 23 | from dataclasses import asdict 24 | from torch.nn.parallel import DistributedDataParallel 25 | from torch.utils.data import DataLoader, DistributedSampler 26 | from torch.utils.tensorboard import SummaryWriter 27 | from tqdm import tqdm 28 | from typing import Dict 29 | 30 | from .config_torchwavenet import Config 31 | from .torchdataset import RFMixtureDatasetBase, get_train_val_dataset 32 | from .torchwavenet import Wave 33 | 34 | 35 | def _nested_map(struct, map_fn): 36 | if isinstance(struct, tuple): 37 | return tuple(_nested_map(x, map_fn) for x in struct) 38 | if isinstance(struct, list): 39 | return [_nested_map(x, map_fn) for x in struct] 40 | if isinstance(struct, dict): 41 | return {k: _nested_map(v, map_fn) for k, v in struct.items()} 42 | return map_fn(struct) 43 | 44 | 45 | def view_as_complex(x): 46 | x = x[:, 0, ...] + 1j * x[:, 1, ...] 
47 | return x 48 | 49 | 50 | class WaveLearner: 51 | def __init__(self, cfg: Config, model: nn.Module, rank: int): 52 | self.cfg = cfg 53 | 54 | # Store some important variables 55 | self.model_dir = cfg.model_dir 56 | self.distributed = cfg.distributed.distributed 57 | self.world_size = cfg.distributed.world_size 58 | self.rank = rank 59 | self.log_every = cfg.trainer.log_every 60 | self.validate_every = cfg.trainer.validate_every 61 | self.save_every = cfg.trainer.save_every 62 | self.max_steps = cfg.trainer.max_steps 63 | self.build_dataloaders() 64 | 65 | self.model = model 66 | self.optimizer = torch.optim.Adam( 67 | self.model.parameters(), lr=cfg.trainer.learning_rate) 68 | self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 69 | self.optimizer, "min", 70 | ) 71 | self.autocast = torch.cuda.amp.autocast(enabled=cfg.trainer.fp16) 72 | self.scaler = torch.cuda.amp.GradScaler(enabled=cfg.trainer.fp16) 73 | self.step = 0 74 | 75 | self.loss_fn = nn.MSELoss() 76 | self.writer = SummaryWriter(self.model_dir) 77 | 78 | @property 79 | def is_master(self): 80 | return self.rank == 0 81 | 82 | def build_dataloaders(self): 83 | self.dataset = RFMixtureDatasetBase( 84 | root_dir=self.cfg.data.root_dir, 85 | ) 86 | self.train_dataset, self.val_dataset = get_train_val_dataset( 87 | self.dataset, self.cfg.data.train_fraction) 88 | 89 | self.train_dataloader = DataLoader( 90 | self.train_dataset, 91 | batch_size=self.cfg.data.batch_size, 92 | shuffle=not self.distributed, 93 | num_workers=self.cfg.data.num_workers if self.distributed else 0, 94 | sampler=DistributedSampler( 95 | self.train_dataset, 96 | num_replicas=self.world_size, 97 | rank=self.rank) if self.distributed else None, 98 | pin_memory=True, 99 | ) 100 | self.val_dataloader = DataLoader( 101 | self.val_dataset, 102 | batch_size=self.cfg.data.batch_size * 4, 103 | shuffle=not self.distributed, 104 | num_workers=self.cfg.data.num_workers if self.distributed else 0, 105 | sampler=DistributedSampler( 106 | self.val_dataset, 107 | num_replicas=self.world_size, 108 | rank=self.rank) if self.distributed else None, 109 | pin_memory=True, 110 | ) 111 | 112 | def state_dict(self): 113 | if hasattr(self.model, 'module') and isinstance(self.model.module, nn.Module): 114 | model_state = self.model.module.state_dict() 115 | else: 116 | model_state = self.model.state_dict() 117 | return { 118 | 'step': self.step, 119 | 'model': {k: v.cpu() if isinstance(v, torch.Tensor) 120 | else v for k, v in model_state.items()}, 121 | 'optimizer': {k: v.cpu() if isinstance(v, torch.Tensor) 122 | else v for k, v in self.optimizer.state_dict().items()}, 123 | 'cfg': asdict(self.cfg), 124 | 'scaler': self.scaler.state_dict(), 125 | } 126 | 127 | def load_state_dict(self, state_dict): 128 | if hasattr(self.model, 'module') and isinstance(self.model.module, nn.Module): 129 | self.model.module.load_state_dict(state_dict['model']) 130 | else: 131 | self.model.load_state_dict(state_dict['model']) 132 | self.optimizer.load_state_dict(state_dict['optimizer']) 133 | self.scaler.load_state_dict(state_dict['scaler']) 134 | self.step = state_dict['step'] 135 | 136 | def save_to_checkpoint(self, filename='weights'): 137 | save_basename = f'{filename}-{self.step}.pt' 138 | save_name = f'{self.model_dir}/{save_basename}' 139 | link_name = f'{self.model_dir}/{filename}.pt' 140 | torch.save(self.state_dict(), save_name) 141 | 142 | if os.path.islink(link_name): 143 | os.unlink(link_name) 144 | os.symlink(save_basename, link_name) 145 | 146 | def 
restore_from_checkpoint(self, filename='weights'): 147 | try: 148 | checkpoint = torch.load(f'{self.model_dir}/{filename}.pt') 149 | self.load_state_dict(checkpoint) 150 | return True 151 | except FileNotFoundError: 152 | return False 153 | 154 | def train(self): 155 | device = next(self.model.parameters()).device 156 | 157 | while True: 158 | for i, features in enumerate( 159 | tqdm(self.train_dataloader, 160 | desc=f"Training ({self.step} / {self.max_steps})")): 161 | features = _nested_map(features, lambda x: x.to( 162 | device) if isinstance(x, torch.Tensor) else x) 163 | loss = self.train_step(features) 164 | 165 | # Check for NaNs 166 | if torch.isnan(loss).any(): 167 | raise RuntimeError( 168 | f'Detected NaN loss at step {self.step}.') 169 | 170 | if self.is_master: 171 | if self.step % self.log_every == 0: 172 | self.writer.add_scalar('train/loss', loss, self.step) 173 | self.writer.add_scalar( 174 | 'train/grad_norm', self.grad_norm, self.step) 175 | if self.step % self.save_every == 0: 176 | self.save_to_checkpoint() 177 | 178 | if self.step % self.validate_every == 0: 179 | val_loss = self.validate() 180 | # Update the learning rate if it plateaus 181 | self.lr_scheduler.step(val_loss) 182 | 183 | if self.distributed: 184 | dist.barrier() 185 | 186 | self.step += 1 187 | 188 | if self.step == self.max_steps: 189 | if self.is_master and self.distributed: 190 | self.save_to_checkpoint() 191 | print("Ending training...") 192 | dist.barrier() 193 | exit(0) 194 | 195 | def train_step(self, features: Dict[str, torch.Tensor]): 196 | for param in self.model.parameters(): 197 | param.grad = None 198 | 199 | sample_mix = features["sample_mix"] 200 | sample_soi = features["sample_soi"] 201 | 202 | N, _, _ = sample_mix.shape 203 | 204 | with self.autocast: 205 | predicted = self.model(sample_mix) 206 | loss = self.loss_fn(predicted, sample_soi) 207 | 208 | self.scaler.scale(loss).backward() 209 | self.scaler.unscale_(self.optimizer) 210 | self.grad_norm = nn.utils.clip_grad_norm_( 211 | self.model.parameters(), self.cfg.trainer.max_grad_norm or 1e9) 212 | self.scaler.step(self.optimizer) 213 | self.scaler.update() 214 | 215 | return loss 216 | 217 | @torch.no_grad() 218 | def validate(self): 219 | device = next(self.model.parameters()).device 220 | self.model.eval() 221 | 222 | loss = 0 223 | for features in tqdm( 224 | self.val_dataloader, 225 | desc=f"Running validation after step {self.step}" 226 | ): 227 | features = _nested_map(features, lambda x: x.to( 228 | device) if isinstance(x, torch.Tensor) else x) 229 | sample_mix = features["sample_mix"] 230 | sample_soi = features["sample_soi"] 231 | N, _, _ = sample_mix.shape 232 | 233 | with self.autocast: 234 | predicted = self.model(sample_mix) 235 | loss += torch.sum( 236 | (predicted - sample_soi) ** 2, (0, 1, 2) 237 | ) / len(self.val_dataset) / np.prod(sample_soi.shape[1:]) 238 | if self.distributed: 239 | dist.all_reduce(loss, op=dist.ReduceOp.SUM) 240 | 241 | self.writer.add_scalar('val/loss', loss, self.step) 242 | self.model.train() 243 | 244 | return loss 245 | 246 | 247 | def _train_impl(rank: int, model: nn.Module, cfg: Config): 248 | torch.backends.cudnn.benchmark = True 249 | 250 | learner = WaveLearner(cfg, model, rank) 251 | learner.restore_from_checkpoint() 252 | learner.train() 253 | 254 | 255 | def train(cfg: Config): 256 | """Training on a single GPU.""" 257 | model = Wave(cfg.model).cuda() 258 | _train_impl(0, model, cfg) 259 | 260 | 261 | def init_distributed(rank: int, world_size: int, port: str): 262 | """Initialize 
distributed training on multiple GPUs.""" 263 | os.environ['MASTER_ADDR'] = 'localhost' 264 | os.environ['MASTER_PORT'] = str(port) 265 | torch.distributed.init_process_group( 266 | 'nccl', rank=rank, world_size=world_size) 267 | 268 | 269 | def train_distributed(rank: int, world_size: int, port, cfg: Config): 270 | """Training on multiple GPUs.""" 271 | init_distributed(rank, world_size, port) 272 | device = torch.device('cuda', rank) 273 | torch.cuda.set_device(device) 274 | model = Wave(cfg.model).to(device) 275 | model = DistributedDataParallel(model, device_ids=[rank]) 276 | _train_impl(rank, model, cfg) 277 | -------------------------------------------------------------------------------- /notebook/RFC_QuickStart_Guide.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "f92bbb09-2961-4aee-98e9-6080259cbb16", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2\n", 12 | "\n", 13 | "import os\n", 14 | "os.chdir(globals()['_dh'][0])\n", 15 | "os.chdir('..')\n", 16 | "# print(os.path.abspath(os.curdir))" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "id": "2abb7b1d-3191-42ff-82e7-94e9ad2ff195", 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "name": "stderr", 27 | "output_type": "stream", 28 | "text": [ 29 | "2023-08-28 15:11:29.180909: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F FMA\n", 30 | "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", 31 | "2023-08-28 15:11:29.777151: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30976 MB memory: -> device: 0, name: Tesla V100-PCIE-32GB, pci bus id: 0000:86:00.0, compute capability: 7.0\n" 32 | ] 33 | } 34 | ], 35 | "source": [ 36 | "import numpy as np\n", 37 | "import matplotlib.pyplot as plt\n", 38 | "import pickle\n", 39 | "import rfcutils\n", 40 | "import h5py\n", 41 | "\n", 42 | "import random\n", 43 | "import tensorflow as tf\n", 44 | "\n", 45 | "get_db = lambda p: 10*np.log10(p)\n", 46 | "get_pow = lambda s: np.mean(np.abs(s)**2, axis=-1)\n", 47 | "get_sinr = lambda s, i: get_pow(s)/get_pow(i)\n", 48 | "get_sinr_db = lambda s, i: get_db(get_sinr(s,i))" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 3, 54 | "id": "d72d61a6-34c5-4dda-be08-10259ba031a3", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "sig_len = 40960\n", 59 | "n_per_batch = 100\n", 60 | "all_sinr = np.arange(-30, 0.1, 3)\n", 61 | "\n", 62 | "seed_number = 0" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "id": "33f7f5af-94b0-4e26-b358-2463b1c564ff", 68 | "metadata": {}, 69 | "source": [ 70 | "## Creating Training Set (Small Example)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 4, 76 | "id": "32ae3f93-68ec-4af5-9c01-ceff00fd70fc", 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "# Function to obtain the relevant generation and demodulation function for SOI type; we will focus on QPSK and OFDMQPSK for now\n", 81 | "def get_soi_generation_fn(soi_sig_type):\n", 82 | " if soi_sig_type == 'QPSK':\n", 83 | " generate_soi = lambda n, s_len: rfcutils.generate_qpsk_signal(n, 
s_len//16)\n", 84 | " demod_soi = rfcutils.qpsk_matched_filter_demod\n", 85 | " # elif soi_sig_type == 'QAM16':\n", 86 | " # generate_soi = lambda n, s_len: rfcutils.generate_qam16_signal(n, s_len//16)\n", 87 | " # demod_soi = rfcutils.qam16_matched_filter_demod\n", 88 | " # elif soi_sig_type == 'QPSK2':\n", 89 | " # generate_soi = lambda n, s_len: rfcutils.generate_qpsk2_signal(n, s_len//4)\n", 90 | " # demod_soi = rfcutils.qpsk2_matched_filter_demod\n", 91 | " elif soi_sig_type == 'OFDMQPSK':\n", 92 | " generate_soi = lambda n, s_len: rfcutils.generate_ofdm_signal(n, s_len//80)\n", 93 | " _,_,_,RES_GRID = rfcutils.generate_ofdm_signal(1, sig_len//80)\n", 94 | " demod_soi = lambda s: rfcutils.ofdm_demod(s, RES_GRID)\n", 95 | " else:\n", 96 | " raise Exception(\"SOI Type not recognized\")\n", 97 | " return generate_soi, demod_soi" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 5, 103 | "id": "8df395fc-f727-472c-bccb-9928ae57476d", 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "sig_len = 40960\n", 108 | "n_per_batch = 100\n", 109 | "all_sinr = np.arange(-30, 0.1, 3)\n", 110 | "\n", 111 | "soi_type, interference_sig_type = 'QPSK', 'CommSignal2'\n", 112 | "\n", 113 | "seed_number = 0\n", 114 | "\n", 115 | "random.seed(seed_number)\n", 116 | "np.random.seed(seed_number)\n", 117 | "tf.random.set_seed(seed_number)\n", 118 | "\n", 119 | "generate_soi, demod_soi = get_soi_generation_fn(soi_type)\n", 120 | "\n", 121 | "with h5py.File(os.path.join('dataset', 'interferenceset_frame', interference_sig_type+'_raw_data.h5'),'r') as data_h5file:\n", 122 | " sig_data = np.array(data_h5file.get('dataset'))\n", 123 | " sig_type_info = data_h5file.get('sig_type')[()]\n", 124 | " if isinstance(sig_type_info, bytes):\n", 125 | " sig_type_info = sig_type_info.decode(\"utf-8\") " 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 6, 131 | "id": "5c381346-69be-4d37-a74c-128aa9fdef9d", 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "name": "stderr", 136 | "output_type": "stream", 137 | "text": [ 138 | "2023-08-28 15:11:38.160649: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8100\n", 139 | "2023-08-28 15:12:47.550813: I tensorflow/core/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory\n" 140 | ] 141 | } 142 | ], 143 | "source": [ 144 | "all_sig_mixture, all_sig1, all_bits1, meta_data = [], [], [], []\n", 145 | "for idx, sinr in enumerate(all_sinr):\n", 146 | " sig1, _, bits1, _ = generate_soi(n_per_batch, sig_len)\n", 147 | " sig2 = sig_data[np.random.randint(sig_data.shape[0], size=(n_per_batch)), :]\n", 148 | "\n", 149 | " sig_target = sig1[:, :sig_len]\n", 150 | "\n", 151 | " rand_start_idx2 = np.random.randint(sig2.shape[1]-sig_len, size=sig2.shape[0])\n", 152 | " inds2 = tf.cast(rand_start_idx2.reshape(-1,1) + np.arange(sig_len).reshape(1,-1), tf.int32)\n", 153 | " sig_interference = tf.experimental.numpy.take_along_axis(sig2, inds2, axis=1)\n", 154 | "\n", 155 | " # Interference Coefficient\n", 156 | " rand_gain = np.sqrt(10**(-sinr/10)).astype(np.float32)\n", 157 | " rand_phase = tf.random.uniform(shape=[sig_interference.shape[0],1])\n", 158 | " rand_gain = tf.complex(rand_gain, tf.zeros_like(rand_gain))\n", 159 | " rand_phase = tf.complex(rand_phase, tf.zeros_like(rand_phase))\n", 160 | " coeff = rand_gain * tf.math.exp(1j*2*np.pi*rand_phase)\n", 161 | "\n", 162 | " sig_mixture = sig_target + sig_interference * coeff\n", 163 | "\n", 164 | " 
all_sig_mixture.append(sig_mixture)\n", 165 | " all_sig1.append(sig_target)\n", 166 | " all_bits1.append(bits1)\n", 167 | "\n", 168 | " actual_sinr = get_sinr_db(sig_target, sig_interference * coeff)\n", 169 | " meta_data.append(np.vstack(([rand_gain.numpy().real for _ in range(n_per_batch)], [sinr for _ in range(n_per_batch)], actual_sinr, [soi_type for _ in range(n_per_batch)], [interference_sig_type for _ in range(n_per_batch)])))\n", 170 | "\n", 171 | "with tf.device('CPU'):\n", 172 | " all_sig_mixture = tf.concat(all_sig_mixture, axis=0).numpy()\n", 173 | " all_sig1 = tf.concat(all_sig1, axis=0).numpy()\n", 174 | " all_bits1 = tf.concat(all_bits1, axis=0).numpy()\n", 175 | "\n", 176 | "meta_data = np.concatenate(meta_data, axis=1).T" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "id": "eb04bf5a-1ec9-4ac5-94c4-fde839cfd8e0", 182 | "metadata": {}, 183 | "source": [ 184 | "### TODO: Train a model\n", 185 | "\n", 186 | "***Input***: `all_sig_mixture` (and optionally, `meta_data`)\n", 187 | "\n", 188 | "***Output***: `all_sig1` and `all_bits1`" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "id": "39d59539-0355-4269-a95a-88743dae46aa", 194 | "metadata": {}, 195 | "source": [ 196 | "## Example Inference/Output Submission" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 8, 202 | "id": "9d56590c-f108-410c-835c-0971b59c9cab", 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "testset_identifier = 'TestSet1Mixture'\n", 207 | "all_sig_mixture = np.load(os.path.join('dataset', f'{testset_identifier}_testmixture_{soi_type}_{interference_sig_type}.npy'))\n", 208 | "meta_data = np.load(os.path.join('dataset', f'{testset_identifier}_testmixture_{soi_type}_{interference_sig_type}_metadata.npy'))" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 9, 214 | "id": "63a2e11a-22c9-4c88-b532-622275d79696", 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "n_per_batch = 100\n", 219 | "all_sinr = np.arange(-30, 0.1, 3)\n", 220 | "\n", 221 | "id_string = 'ExampleNoMitigation'\n", 222 | "def run_inference(all_sig_mixture, meta_data, soi_type, interference_sig_type): \n", 223 | " \n", 224 | " ##################################################\n", 225 | " # Perform your inference here.\n", 226 | " ##################################################\n", 227 | " generate_soi, demod_soi = get_soi_generation_fn(soi_type)\n", 228 | " \n", 229 | " # E.g. 
No mitigation, standard matched filtering demodulation\n", 230 | "    sig1_est = all_sig_mixture\n", 231 | "    \n", 232 | "    bit_est = []\n", 233 | "    for idx, sinr_db in enumerate(all_sinr):\n", 234 | "        bit_est_batch, _ = demod_soi(sig1_est[idx*n_per_batch:(idx+1)*n_per_batch])\n", 235 | "        bit_est.append(bit_est_batch)\n", 236 | "    bit_est = tf.concat(bit_est, axis=0)\n", 237 | "    sig1_est, bit_est = sig1_est, bit_est.numpy()\n", 238 | "    ##################################################\n", 239 | "    \n", 240 | "    return sig1_est, bit_est" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 10, 246 | "id": "3857f7c5-a4e9-4001-b7e9-137b054c2871", 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "sig1_est, bit1_est = run_inference(all_sig_mixture, meta_data, soi_type, interference_sig_type)\n", 251 | "np.save(os.path.join('outputs', f'{id_string}_{testset_identifier}_estimated_soi_{soi_type}_{interference_sig_type}'), sig1_est)\n", 252 | "np.save(os.path.join('outputs', f'{id_string}_{testset_identifier}_estimated_bits_{soi_type}_{interference_sig_type}'), bit1_est)" 253 | ] 254 | } 255 | ], 256 | "metadata": { 257 | "kernelspec": { 258 | "display_name": "Python [conda env:.conda-rfsionna]", 259 | "language": "python", 260 | "name": "conda-env-.conda-rfsionna-py" 261 | }, 262 | "language_info": { 263 | "codemirror_mode": { 264 | "name": "ipython", 265 | "version": 3 266 | }, 267 | "file_extension": ".py", 268 | "mimetype": "text/x-python", 269 | "name": "python", 270 | "nbconvert_exporter": "python", 271 | "pygments_lexer": "ipython3", 272 | "version": "3.7.13" 273 | } 274 | }, 275 | "nbformat": 4, 276 | "nbformat_minor": 5 277 | } 278 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Starter Code for ICASSP 2024 SP Grand Challenge: Data-Driven Signal Separation in Radio Spectrum 2 | 3 | [Click here for details on the challenge setup](https://rfchallenge.mit.edu/wp-content/uploads/2023/09/ICASSP24_single_channel.pdf) 4 | 5 | ## November 23, 2023 UPDATE: TestSet2Mixture has been released. 6 | For participants of the ICASSP 2024 SP Grand Challenge for Data-Driven Signal Separation in Radio Spectrum, the final test set for evaluation has been released. 7 | 8 | [Click here for TestSet2Mixture files](https://www.dropbox.com/scl/fi/m36l2imiit5svqz1yz46g/TestSet2Mixture.zip?rlkey=n5mwzi11l55l2xzfw9ee5m0ye&dl=0) 9 | 10 | ## About this Repository 11 | For those eager to dive in, we have prepared a concise guide to get you started. 12 | 13 | Check out [notebook/RFC_QuickStart_Guide.ipynb](https://github.com/RFChallenge/icassp2024rfchallenge/blob/0.2.0/notebook/RFC_QuickStart_Guide.ipynb) for practical code snippets. You will find steps to create a small but representative training set and steps for inference to generate your submission outputs. 14 | For a broader understanding and other helpful resources in the starter kit integral to the competition, please see the details and references provided below. 15 | 16 | [Link to InterferenceSet](https://www.dropbox.com/scl/fi/zlvgxlhp8het8j8swchgg/dataset.zip?rlkey=4rrm2eyvjgi155ceg8gxb5fc4&dl=0) 17 | 18 | ## TestSet for Evaluation 19 | 20 | This starter kit equips you with essential resources to develop signal separation and interference rejection solutions. In this competition, the crux of the evaluation hinges on your ability to handle provided signal mixtures. Your task will be twofold: 21 | 22 | 1. Estimate the Signal of Interest (SOI) component within the two-component mixture. 23 | 24 | 2. Deduce the best possible estimate of the underlying information bits encapsulated in the SOI. 25 | 
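In code, one pass over a batch of mixtures amounts to the sketch below; `my_separation_model` is a hypothetical placeholder for your method, the input path assumes the QPSK + CommSignal2 configuration, and the demodulator is one of the starter kit's `rfcutils` helpers used throughout the example scripts (shown here for a QPSK SOI):

```python
import numpy as np
import rfcutils

# a TestSet1Mixture input file, as described under File Descriptions below
all_sig_mixture = np.load('dataset/TestSet1Mixture_testmixture_QPSK_CommSignal2.npy')

# 1. Estimate the SOI component: complex-valued, same shape as the mixture
sig1_est = my_separation_model(all_sig_mixture)  # hypothetical model

# 2. Estimate the underlying information bits from the SOI estimate
# (the example scripts apply the demodulator in batches of 100)
bit1_est, _ = rfcutils.qpsk_matched_filter_demod(sig1_est)
```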
26 | Delve into the specifics below for comprehensive details. 27 | 28 | ### TestSet1: 29 | 30 | [Click here for TestSet1Mixture files](https://www.dropbox.com/scl/fi/d2kjtfmbh3mgxddbubf80/TestSet1Mixture.zip?rlkey=lwhzt1ayn2bqwosc9o9cq9dwr&dl=0) 31 | 32 | 50 frames of each interference type have been reserved to form TestSet1 (interference frames). These will be released alongside the main dataset (InterferenceSet frames), and the mixtures from TestSet1Mixture are generated from this collection. Please note that although TestSet1 is available for examination, the final evaluation for participants will be based on a hidden, unreleased set (TestSet2 interference frames). 33 | 34 | #### File Descriptions: 35 | 36 | ***`TestSet1Mixture_testmixture_[SOI Type]_[Interference Type].npy`:*** This is a numpy array of 1,100 x 40,960 np.complex64 values; each row represents a mixture signal (40,960 complex values); the mixture signals are organized in increasing SINR, spanning 11 SINR levels with 100 mixtures per SINR level. 37 | 38 | ***`TestSet1Mixture_testmixture_[SOI Type]_[Interference Type]_metadata.npy`:*** This is a numpy array of 1,100 x 5 objects containing metadata information. The first column is the scalar value (k) by which the interference component is scaled; the second column is the corresponding target SINR level in dB (calculated as $-20 \log_{10}(k)$, since k scales the interference amplitude), the third column is the actual SINR computed based on the components present in the mixture (represented as $10 \log_{10}(P_{SOI}/P_{Interference})$ ), and the fourth and fifth columns are strings denoting the SOI type and interference type, respectively. 39 | 40 | Note that the Ground Truth of TestSet1Mixture will not be released. Participants are encouraged to send in their submissions at the respective intermediate deadlines (October 1st and November 1st) for evaluation against the ground truth. 41 | 42 | Participants are also provided with starter code to generate similar testing set mixtures (refer to "Helper Functions for Testing" and TestSet1Example) for their own testing. 43 | 44 | ### TestSet2: 45 | [Click here for TestSet2Mixture files](https://www.dropbox.com/scl/fi/m36l2imiit5svqz1yz46g/TestSet2Mixture.zip?rlkey=n5mwzi11l55l2xzfw9ee5m0ye&dl=0) 46 | 47 | 50 frames of each interference type have been designated for TestSet2. Please note that this set will not be made available during the competition. 48 | 49 | The format for test mixtures in TestSet2Mixture will be consistent with that of TestSet1Mixture. However, any changes or modifications to the format will be communicated to the participants as the competition progresses. 50 | 51 | 52 | ### Submission Specifications: 53 | 54 | For every configuration defined by a specific SOI Type and Interference Type, participants are required to provide: 55 | 56 | 1. SOI Component Estimate: 57 | - A numpy array of dimensions 1,100 x 40,960. 58 | - This should contain complex values representing the estimated SOI component present. 
43 | 44 | ### TestSet2: 45 | [Click here for TestSet2Mixture files](https://www.dropbox.com/scl/fi/m36l2imiit5svqz1yz46g/TestSet2Mixture.zip?rlkey=n5mwzi11l55l2xzfw9ee5m0ye&dl=0) 46 | 47 | 50 frames of each interference type have been designated for TestSet2. Please note that this set will not be made available during the competition. 48 | 49 | The format of the test mixtures in TestSet2Mixture will be consistent with that of TestSet1Mixture. Any changes or modifications to the format will be communicated to the participants as the competition progresses. 50 | 51 | 52 | ### Submission Specifications: 53 | 54 | For every configuration defined by a specific SOI Type and Interference Type, participants are required to provide: 55 | 56 | 1. SOI Component Estimate: 57 | - A numpy array of dimensions 1,100 x 40,960. 58 | - This should contain complex values representing the estimated SOI component present in each mixture. 59 | - Filename: `[ID String]_[TestSet Identifier]_estimated_soi_[SOI Type]_[Interference Type].npy` 60 | (where ID String will be a unique identifier, e.g., your team name) 61 | 62 | 2. Information Bits Estimate: 63 | - A numpy array of dimensions 1,100 x B. 64 | - The value of B depends on the SOI type: 65 | - B = 5,120 for QPSK SOI 66 | - B = 57,344 for OFDMQPSK SOI 67 | - The array should exclusively contain values of 1's and 0's, corresponding to the estimated information bits carried by the SOI. 68 | - Filename: `[ID String]_[TestSet Identifier]_estimated_bits_[SOI Type]_[Interference Type].npy` 69 | (where ID String will be a unique identifier, e.g., your team name) 70 | 71 | For guidance on mapping the SOI signal to the information bits, participants are advised to consult the provided demodulation helper functions (e.g., as used in [notebook/RFC_EvalSet_Demo.ipynb](https://github.com/RFChallenge/rfchallenge_singlechannel_starter_grandchallenge2023/blob/0.2.0/notebook/RFC_EvalSet_Demo.ipynb)); a sketch of assembling and saving these files follows below. 72 | 73 | Submissions should be sent to the organizers at rfchallenge@mit.edu. 74 | 75 | _The intellectual property (IP) is not transferred to the challenge organizers; in other words, if code is shared or submitted, the participants retain ownership of their code._
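A minimal sketch of assembling and saving outputs that match the specification above, mirroring the saving cell in the quickstart notebook. The `id_string` value and the zero-filled placeholder estimates are stand-ins for illustration; substitute your own identifier and model outputs.

```python
import os
import numpy as np

id_string, testset_id = 'myteam', 'TestSet1Mixture'   # stand-in identifiers
soi_type, interference_type = 'QPSK', 'CommSignal2'
B = 5120 if soi_type == 'QPSK' else 57344             # bits per row, per the spec above

# Placeholder estimates; replace with your model's outputs.
soi_est = np.zeros((1100, 40960), dtype=np.complex64)
bits_est = np.zeros((1100, B), dtype=np.uint8)        # entries must be 0 or 1

os.makedirs('outputs', exist_ok=True)
np.save(os.path.join('outputs', f'{id_string}_{testset_id}_estimated_soi_{soi_type}_{interference_type}'), soi_est)
np.save(os.path.join('outputs', f'{id_string}_{testset_id}_estimated_bits_{soi_type}_{interference_type}'), bits_est)
```

Note that `np.save` appends the `.npy` extension automatically, yielding filenames in the required format.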
76 | 77 | ## Starter Code Setup: 78 | Relevant bash commands to set up the starter code: 79 | ```bash 80 | git clone https://github.com/RFChallenge/icassp2024rfchallenge.git rfchallenge 81 | cd rfchallenge 82 | 83 | # To obtain the dataset 84 | wget -O dataset.zip "https://www.dropbox.com/scl/fi/zlvgxlhp8het8j8swchgg/dataset.zip?rlkey=4rrm2eyvjgi155ceg8gxb5fc4&dl=0" 85 | unzip dataset.zip 86 | rm dataset.zip 87 | 88 | # To obtain TestSet1Mixture 89 | wget -O TestSet1Mixture.zip "https://www.dropbox.com/scl/fi/d2kjtfmbh3mgxddbubf80/TestSet1Mixture.zip?rlkey=lwhzt1ayn2bqwosc9o9cq9dwr&dl=0" 90 | unzip TestSet1Mixture.zip -d dataset 91 | rm TestSet1Mixture.zip 92 | ``` 93 | 94 | Dependencies: The organizers have used the following libraries to generate the signal mixtures and test the relevant baseline models: 95 | * python==3.7.13 96 | * numpy==1.21.6 97 | * tensorflow==2.8.2 98 | * sionna==0.10.0 99 | * tqdm==4.64.0 100 | * h5py==3.7.0 101 | 102 | For a complete overview of the dependencies within our Anaconda environment, please see [here (rfsionna)](https://github.com/RFChallenge/icassp2024rfchallenge/blob/0.2.0/rfsionna_env.yml). Additionally, if you are interested in the PyTorch-based baseline, you can find the respective Anaconda environment dependencies that the organizers used [here (rftorch)](https://github.com/RFChallenge/icassp2024rfchallenge/blob/0.2.0/rftorch_env.yml). 103 | 104 | Since participants are tasked with running their own inference, we are currently not imposing restrictions on the libraries used for training and inference. However, submissions are expected to be in the form of numpy arrays (`.npy` files) that are compatible with our system (`numpy==1.21.6`). 105 | 106 | > Note: Diverging from the dependency versions listed above might result in different behavior of the starter code. Participants are advised to check for version compatibility in their implementations and solutions. 107 | 108 | 109 | ## Helper Functions for Testing: 110 | 111 | To assist participants during testing, we provide several example scripts for creating and testing with evaluation sets analogous to TestSet1Mixture. 112 | 113 | `python sampletest_testmixture_generator.py [SOI Type] [Interference Type]` 114 | 115 | This script generates a new evaluation set (default name: TestSet1Example) based on the raw interference dataset of TestSet1. Participants can employ this for cross-checking. The produced outputs include a mixture numpy array, a metadata numpy array (similar to what is provided in TestSet1Mixture), and a ground-truth file. Participants can also change the seed number to generate new instances of such example test sets. 116 | 117 | (An example generated this way, named TestSet1Example (using seed_number=0), can be found [here](https://drive.google.com/file/d/1D1rHwEBpDRBVWhBGalEGJ0OzYbBeb4il/view?usp=drive_link).) 118 | 119 | 120 | `python sampletest_tf_unet_inference.py [SOI Type] [Interference Type] [TestSet Identifier]` 121 | 122 | `python sampletest_torch_wavenet_inference.py [SOI Type] [Interference Type] [TestSet Identifier]` 123 | 124 | (Default: use TestSet1Example for [TestSet Identifier].) 125 | These scripts leverage the supplied baseline methods (Modified U-Net on Tensorflow or WaveNet on PyTorch) for inference. 126 | 127 | `python sampletest_evaluationscript.py [SOI Type] [Interference Type] [TestSet Identifier] [Method ID String]` 128 | 129 | [Method ID String] is your submission's unique identifier (refer to the submission specifications above). 130 | Use this script to assess the outputs generated by the inference scripts; a sketch of the underlying error measures follows below. 131 |
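For a sense of what such an assessment involves, here is a hedged sketch of the two error measures naturally associated with the twofold task: mean-squared error on the SOI waveform estimate and bit error rate on the recovered bits. This is illustrative only; `sampletest_evaluationscript.py` remains the authoritative scoring reference, and its exact conventions may differ.

```python
import numpy as np

def mse_db(soi_est, soi_true):
    # Mean-squared error of the complex SOI estimate, in dB, per mixture row.
    err = np.mean(np.abs(soi_est - soi_true) ** 2, axis=-1)
    return 10 * np.log10(err)

def ber(bits_est, bits_true):
    # Fraction of incorrectly recovered information bits, per mixture row.
    return np.mean(bits_est != bits_true, axis=-1)

# Example with random stand-in data:
rng = np.random.default_rng(0)
soi_true = rng.standard_normal((4, 256)) + 1j * rng.standard_normal((4, 256))
soi_est = soi_true + 0.1 * (rng.standard_normal((4, 256)) + 1j * rng.standard_normal((4, 256)))
bits_true = rng.integers(0, 2, size=(4, 64))
print(mse_db(soi_est, soi_true), ber(bits_true, bits_true))
```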
132 | 133 | ## Helper Functions for Training: 134 | 135 | For a grasp of the basic functionality of the communication signals (the SOI), and for code snippets showing how we load and extract interference signal windows to create signal mixtures, participants are referred to RFC_Demo.ipynb in our starter code. 136 | 137 | We also provide some of the reference code the organizers used to train the baseline methods. These files include: 138 | 139 | 1. Training Dataset Scripts: Used for creating an extensive training set. A shell script with the relevant commands is included: `sampletrain_gendataset_script.sh`. Participants can refer to and modify (comment/uncomment) the relevant commands in the shell script. The corresponding python files can be found in the `dataset_utils` directory and include: 140 | - `example_generate_competition_trainmixture.py`: A python script for generating example mixtures for training; this script creates a training set that is more closely aligned with the TestSet's specifications (e.g., focusing solely on the 11 discrete target SINR levels). It saves a pickle file `dataset/Training_Dataset_[SOI Type]_[Interference Type].pkl` containing `all_sig_mixture, all_sig1_groundtruth, all_bits1_groundtruth, meta_data` (see the loading sketch after this section). 141 | - `example_generate_rfc_mixtures.py`: Another python script that creates 240,000 sample mixtures with varying random target SINR levels (ranging between -33 dB and 3 dB). For each signal mixture configuration, the output is saved as 60 HDF5 files, each containing 4,000 mixtures. This is the approach the organizers used when generating the training set (favoring generalization, while setting aside the metadata for implementation simplicity). 142 | - `tfds_scripts/Dataset_[SOI Type]_[Interference Type]_Mixture.py`: Used in conjunction with the Tensorflow UNet training scripts; the HDF5 files are processed into Tensorflow Datasets (TFDS) for training. 143 | - `example_preprocess_npy_dataset.py`: Used in conjunction with the Torch WaveNet training scripts; the HDF5 files are processed into separate npy files (one file per mixture). An associated dataloader is supplied within the PyTorch baseline code. 144 | 145 | 2. Model Training Scripts: The competition organizers have curated two implementations: 146 | - UNet on Tensorflow: `train_unet_model.py`, accompanied by the neural network specification in `src/unet_model.py` 147 | - WaveNet on Torch: `train_torchwavenet.py`, accompanied by dependencies including `supervised_config.yml` and `src/configs`, `src/torchdataset.py`, `src/learner_torchwavenet.py`, `src/config_torchwavenet.py` and `src/torchwavenet.py` 148 | 149 | While the provided scripts serve as a starting point, participants are under no obligation to use them. These files are provided as references to aid those wishing to expand upon or employ the baseline methods. Participants are encouraged to explore other strategies for creating training sets from the corresponding InterferenceSet frames and SOI generation functions, as well as more effective ways of utilizing relevant information (e.g., the metadata). 150 | 151 | Trained model weights for the UNet and WaveNet can be obtained here: [reference_models.zip](https://www.dropbox.com/scl/fi/890vztq67krephwyr0whb/reference_models.zip?rlkey=6yct3w8rx183f0l3ok2my6rej&dl=0). 152 | 153 | Relevant bash commands: 154 | ```bash 155 | wget -O reference_models.zip "https://www.dropbox.com/scl/fi/890vztq67krephwyr0whb/reference_models.zip?rlkey=6yct3w8rx183f0l3ok2my6rej&dl=0" 156 | unzip reference_models.zip 157 | rm reference_models.zip 158 | ```
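As a convenience, here is a sketch of loading the training pickle produced by `example_generate_competition_trainmixture.py`. The four-tuple unpacking follows the field list given above; treat the exact structure as an assumption to verify against the generator script.

```python
import pickle

# Assumed path/structure, per the description of the generator script above.
with open('dataset/Training_Dataset_QPSK_CommSignal2.pkl', 'rb') as f:
    all_sig_mixture, all_sig1_groundtruth, all_bits1_groundtruth, meta_data = pickle.load(f)

print(all_sig_mixture.shape, all_bits1_groundtruth.shape)
```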
159 | 160 | --- 161 | ## Available Support Channels: 162 | *(For the Grand Challenge: September to December 2023)* 163 | 164 | As you embark on this challenge, we would like to offer avenues for assistance. 165 | Below are the channels through which you can reach us for help. Our commitment is to foster an environment that aids understanding and collaboration; your questions, feedback, and concerns are instrumental in ensuring a seamless competition. 166 | * Discord (invitation link): https://discord.gg/4thrZCVsTu 167 | 168 | * GitHub (under the Issues tab): https://github.com/RFChallenge/icassp2024rfchallenge/issues 169 | 170 | * Email: rfchallenge@mit.edu 171 | > Note: Please be aware that the organizers reserve the right to publicly share exchanges on any of the above channels. This is done to promote information dissemination and to provide clarifications to commonly asked questions. 172 | 173 | While we endeavor to offer robust support and timely communication, please understand that our assistance is provided on a "best-effort" basis. We are committed to addressing as many queries and issues as possible, but we may not have solutions to all problems. 174 | 175 | Participants are encouraged to utilize the provided channels and collaborate with peers. By participating, you acknowledge and agree that the organizers are not responsible for resolving all issues or ensuring uninterrupted functionality of any tools or platforms. Your understanding and patience are greatly appreciated. 176 | 177 | --- 178 | ### Acknowledgements 179 | The efforts of the organizers are supported by the United States Air Force Research Laboratory and the United States Air Force Artificial Intelligence Accelerator under Cooperative Agreement Number FA8750-19-2-1000. The views and conclusions contained in this document are those of the authors and should not be interpreted as representing the official policies, either expressed or implied, of the United States Air Force or the U.S. Government. 180 | 181 | The organizers acknowledge the MIT SuperCloud and Lincoln Laboratory Supercomputing Center for providing HPC resources that have contributed to the development of this work. 182 | --------------------------------------------------------------------------------