├── RMDM.jpeg
├── fig
│   ├── intro.jpg
│   └── comperation.png
├── lib
│   ├── __init__.py
│   ├── modules.py
│   └── loaders.py
├── requirement.txt
├── utils
│   ├── logger.py
│   ├── __init__.py
│   ├── losses.py
│   ├── utils.py
│   ├── model_wrapper.py
│   ├── nn.py
│   └── fp16_util.py
├── __init__.py
├── .gitignore
├── README.md
├── train.py
└── sample_test.py

/RMDM.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hxxxz0/RMDM/HEAD/RMDM.jpeg
--------------------------------------------------------------------------------
/fig/intro.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hxxxz0/RMDM/HEAD/fig/intro.jpg
--------------------------------------------------------------------------------
/fig/comperation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hxxxz0/RMDM/HEAD/fig/comperation.png
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Data loading and module components package
3 | """
4 | 
5 | __all__ = [
6 |     'loaders',
7 |     'modules',
8 | ]
9 | 
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | 
2 | torch>=1.7.0
3 | torchvision>=0.8.0
4 | numpy>=1.19.0
5 | pandas>=1.2.0
6 | scipy>=1.6.0
7 | scikit-image>=0.18.0
8 | matplotlib>=3.3.0
9 | opencv-python>=4.5.0
10 | nibabel>=3.2.0
11 | batchgenerators>=0.21
12 | visdom>=0.1.8
13 | torchsummary>=1.5.0
14 | blobfile>=2.0.0
15 | accelerate>=0.20.0
16 | diffusers>=0.20.0
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | _ACC = {}
4 | _DIR = None
5 | 
6 | 
7 | def configure(dir: str):
8 |     global _DIR
9 |     _DIR = dir
10 |     os.makedirs(_DIR, exist_ok=True)
11 | 
12 | 
13 | def logkv_mean(key, val):
14 |     _ACC.setdefault(key, []).append(float(val))  # accumulate all values so dumpkvs can report a true mean
15 | 
16 | 
17 | def get_dir():
18 |     return _DIR or '.'
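# Illustrative usage (a sketch; this logger is not exercised elsewhere in the dump):
#   configure('./logs')          # choose the output directory
#   logkv_mean('loss', 0.53)     # accumulate values under a key
#   logkv_mean('loss', 0.47)
#   dumpkvs()                    # prints {'loss': 0.5} and resets the accumulator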
19 | 
20 | 
21 | def log(msg):
22 |     print(msg)
23 | 
24 | 
25 | def dumpkvs():
26 |     if not _ACC:
27 |         return
28 |     print({k: round(sum(v) / len(v), 6) for k, v in _ACC.items()})
29 |     _ACC.clear()
30 | 
31 | 
32 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions and model wrappers
3 | """
4 | 
5 | __all__ = [
6 |     'losses',
7 |     'model_wrapper',
8 |     'utils',
9 |     'logger',
10 |     'fp16_util',
11 |     'nn',
12 |     'cal_pinn',
13 |     'build_unet_from_config',
14 |     'UNetWithTimeWrapper',
15 | ]
16 | 
17 | 
18 | def __getattr__(name):
19 |     if name == 'cal_pinn':
20 |         from .losses import cal_pinn
21 |         return cal_pinn
22 |     elif name == 'build_unet_from_config':
23 |         from .model_wrapper import build_unet_from_config
24 |         return build_unet_from_config
25 |     elif name == 'UNetWithTimeWrapper':
26 |         from .model_wrapper import UNetWithTimeWrapper
27 |         return UNetWithTimeWrapper
28 |     else:
29 |         raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
30 | 
31 | 
32 | 
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | PhyRMDM: Physics-informed Radio Map Diffusion Model
3 | 
4 | A physics-informed radio map diffusion model project
5 | """
6 | 
7 | __version__ = "1.0.0"
8 | __author__ = "PhyRMDM Team"
9 | 
10 | __all__ = [
11 |     'utils',
12 |     'lib',
13 |     'build_unet_from_config',
14 |     'cal_pinn',
15 | ]
16 | 
17 | # Lazy import to avoid dependency issues
18 | def __getattr__(name):
19 |     if name == 'utils':
20 |         from . import utils
21 |         return utils
22 |     elif name == 'lib':
23 |         from . import lib
24 |         return lib
25 |     elif name == 'build_unet_from_config':
26 |         from .utils import build_unet_from_config
27 |         return build_unet_from_config
28 |     elif name == 'cal_pinn':
29 |         from .utils import cal_pinn
30 |         return cal_pinn
31 |     else:
32 |         raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
--------------------------------------------------------------------------------
/utils/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | 
4 | 
5 | def cal_pinn(cal, buildings, shooter, k=1.0, k_building=1.0):
6 |     """
7 |     Migrated from guided_diffusion.gaussian_diffusion.cal_pinn, behavior unchanged:
8 |     - cal, buildings, shooter: (B, H, W)
9 |     - Returns: (B,) per-sample PINN loss
10 |     """
11 |     cal_t = cal.unsqueeze(1)
12 |     buildings_t = buildings.unsqueeze(1)
13 |     shooter_t = shooter.unsqueeze(1)
14 | 
15 |     device = cal_t.device
16 |     dtype = cal_t.dtype
17 | 
18 |     lap_kernel = torch.tensor([[0.0, 1.0, 0.0],
19 |                                [1.0,-4.0, 1.0],
20 |                                [0.0, 1.0, 0.0]], device=device, dtype=dtype).view(1,1,3,3)
21 |     lap = F.conv2d(cal_t, lap_kernel, padding=1)
22 | 
23 |     buildings_mask = (buildings_t > 0.5)
24 |     shooter_mask = (shooter_t > 0.5)
25 | 
26 |     k_tensor = torch.tensor(float(k), device=device, dtype=dtype)
27 |     k_building_tensor = torch.tensor(float(k_building), device=device, dtype=dtype)
28 |     k_map = torch.where(buildings_mask, k_building_tensor, k_tensor)
29 | 
30 |     residual = lap + (k_map ** 2) * cal_t
31 |     L_pde = residual.pow(2).flatten(1).mean(1)
32 | 
33 |     bc_num = buildings_mask.sum(dim=(1,2,3)).clamp_min(1)
34 |     L_bc = (cal_t.pow(2) * buildings_mask).sum(dim=(1,2,3)) / bc_num
35 | 
36 |     src_num = shooter_mask.sum(dim=(1,2,3)).clamp_min(1)
37 | 
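    # The three ingredients of the PINN loss: L_pde penalizes the Helmholtz residual
    # lap + k^2 * u over the whole map, L_bc drives u toward 0 inside buildings, and
    # L_source (next line) pins u to 1 at transmitter ("shooter") pixels.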
L_source = ((cal_t - 1.0).pow(2) * shooter_mask).sum(dim=(1,2,3)) / src_num 38 | 39 | loss = L_pde + L_bc + L_source 40 | return loss 41 | 42 | 43 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | softmax_helper = lambda x: F.softmax(x, 1) 7 | sigmoid_helper = lambda x: torch.sigmoid(x) 8 | 9 | 10 | class InitWeights_He(object): 11 | def __init__(self, neg_slope=1e-2): 12 | self.neg_slope = neg_slope 13 | 14 | def __call__(self, module): 15 | if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d): 16 | module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) 17 | if module.bias is not None: 18 | module.bias = nn.init.constant_(module.bias, 0) 19 | 20 | 21 | def maybe_to_torch(d): 22 | if isinstance(d, list): 23 | d = [maybe_to_torch(i) if not isinstance(i, torch.Tensor) else i for i in d] 24 | elif not isinstance(d, torch.Tensor): 25 | d = torch.from_numpy(d).float() 26 | return d 27 | 28 | 29 | def to_cuda(data, non_blocking=True, gpu_id=0): 30 | if isinstance(data, list): 31 | data = [i.cuda(gpu_id, non_blocking=non_blocking) for i in data] 32 | else: 33 | data = data.cuda(gpu_id, non_blocking=non_blocking) 34 | return data 35 | 36 | 37 | class no_op(object): 38 | def __enter__(self): 39 | pass 40 | 41 | def __exit__(self, *args): 42 | pass 43 | 44 | 45 | def dice_score(pred, targs): 46 | pred = (pred>0).float() 47 | return 2. * (pred*targs).sum() / (pred+targs).sum() 48 | 49 | 50 | def norm(t): 51 | m, s = torch.mean(t), torch.std(t) 52 | return (t - m) / (s + 1e-8) 53 | 54 | 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib64/ 14 | parts/ 15 | sdist/ 16 | var/ 17 | wheels/ 18 | *.egg-info/ 19 | .installed.cfg 20 | *.egg 21 | MANIFEST 22 | 23 | # PyInstaller 24 | *.manifest 25 | *.spec 26 | 27 | # Installer logs 28 | pip-log.txt 29 | pip-delete-this-directory.txt 30 | 31 | # Unit test / coverage reports 32 | htmlcov/ 33 | .tox/ 34 | .nox/ 35 | .coverage 36 | .coverage.* 37 | .cache 38 | nosetests.xml 39 | coverage.xml 40 | *.cover 41 | .hypothesis/ 42 | .pytest_cache/ 43 | 44 | # Jupyter Notebook 45 | .ipynb_checkpoints 46 | 47 | # IPython 48 | profile_default/ 49 | ipython_config.py 50 | 51 | # pyenv 52 | .python-version 53 | 54 | # Environments 55 | .env 56 | .venv 57 | env/ 58 | venv/ 59 | ENV/ 60 | env.bak/ 61 | venv.bak/ 62 | 63 | # IDE 64 | .vscode/ 65 | .idea/ 66 | *.swp 67 | *.swo 68 | *~ 69 | 70 | # OS 71 | .DS_Store 72 | .DS_Store? 
73 | ._* 74 | .Spotlight-V100 75 | .Trashes 76 | ehthumbs.db 77 | Thumbs.db 78 | 79 | # Large files (adjust size as needed) 80 | *.zip 81 | *.tar.gz 82 | *.rar 83 | *.7z 84 | 85 | # Model files (common ML model extensions) 86 | *.pkl 87 | *.pickle 88 | *.h5 89 | *.hdf5 90 | *.pb 91 | *.pth 92 | *.pt 93 | *.ckpt 94 | *.safetensors 95 | 96 | # Data files 97 | *.csv 98 | *.json 99 | *.xml 100 | *.parquet 101 | *.feather 102 | 103 | # Images (uncomment if you want to exclude all images) 104 | # *.jpg 105 | # *.jpeg 106 | # *.png 107 | # *.gif 108 | # *.bmp 109 | # *.tiff 110 | # *.svg 111 | 112 | # Logs 113 | *.log 114 | logs/ 115 | 116 | # Temporary files 117 | tmp/ 118 | temp/ 119 | *.tmp 120 | *.temp -------------------------------------------------------------------------------- /utils/model_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import sys 5 | 6 | # Set up path to import UNet model 7 | _THIS_DIR = os.path.dirname(os.path.abspath(__file__)) 8 | _PARENT_DIR = os.path.dirname(_THIS_DIR) 9 | if _PARENT_DIR not in sys.path: 10 | sys.path.append(_PARENT_DIR) 11 | 12 | from unet import UNetModel_newpreview 13 | 14 | 15 | class UNetWithTimeWrapper(nn.Module): 16 | """ 17 | Wrap existing UNetModel_newpreview as standard diffusion model: 18 | - Input: noisy_input (B, C, H, W) containing conditions and noisy target; timesteps (B,) 19 | - Output: noise_pred (B, out_ch, H, W) predicted noise 20 | """ 21 | def __init__(self, unet: UNetModel_newpreview): 22 | super().__init__() 23 | self.unet = unet 24 | 25 | def forward(self, sample: torch.Tensor, timesteps: torch.Tensor): 26 | # Return underlying UNet output directly; preserve both branches (noise and cal) if tuple 27 | out = self.unet(sample, timesteps) 28 | return out 29 | 30 | 31 | def build_unet_from_config(cfg: dict) -> UNetWithTimeWrapper: 32 | # Handle empty channel_mult, set default based on image_size 33 | if not cfg['channel_mult']: 34 | if cfg['image_size'] == 512: 35 | channel_mult = (1, 1, 2, 2, 4, 4) 36 | elif cfg['image_size'] == 256: 37 | channel_mult = (1, 1, 2, 2, 4, 4) 38 | elif cfg['image_size'] == 128: 39 | channel_mult = (1, 1, 2, 3, 4) 40 | elif cfg['image_size'] == 64: 41 | channel_mult = (1, 2, 3, 4) 42 | else: 43 | raise ValueError(f"unsupported image size: {cfg['image_size']}") 44 | else: 45 | channel_mult = tuple(int(x) for x in cfg['channel_mult'].split(',')) 46 | 47 | unet = UNetModel_newpreview( 48 | image_size=cfg['image_size'], 49 | in_channels=cfg['in_ch'], 50 | model_channels=cfg['num_channels'], 51 | out_channels=cfg.get('out_ch', 1), # Use configured output channels 52 | num_res_blocks=cfg['num_res_blocks'], 53 | attention_resolutions=tuple(int(cfg['image_size']) // int(r) for r in cfg['attention_resolutions'].split(',')), 54 | dropout=cfg['dropout'], 55 | channel_mult=channel_mult, 56 | num_classes=(2 if cfg.get('class_cond', False) else None), 57 | use_checkpoint=cfg['use_checkpoint'], 58 | use_fp16=cfg['use_fp16'], 59 | num_heads=cfg['num_heads'], 60 | num_head_channels=cfg['num_head_channels'], 61 | num_heads_upsample=cfg['num_heads_upsample'], 62 | use_scale_shift_norm=cfg['use_scale_shift_norm'], 63 | resblock_updown=cfg['resblock_updown'], 64 | use_new_attention_order=cfg['use_new_attention_order'], 65 | high_way=True, 66 | ) 67 | return UNetWithTimeWrapper(unet) 68 | 69 | 70 | -------------------------------------------------------------------------------- /utils/nn.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | def layer_norm(shape, *args, **kwargs): 35 | 36 | return nn.LayerNorm(shape, *args, **kwargs) 37 | 38 | def linear(*args, **kwargs): 39 | """ 40 | Create a linear module. 41 | """ 42 | return nn.Linear(*args, **kwargs) 43 | 44 | 45 | def avg_pool_nd(dims, *args, **kwargs): 46 | """ 47 | Create a 1D, 2D, or 3D average pooling module. 48 | """ 49 | if dims == 1: 50 | return nn.AvgPool1d(*args, **kwargs) 51 | elif dims == 2: 52 | return nn.AvgPool2d(*args, **kwargs) 53 | elif dims == 3: 54 | return nn.AvgPool3d(*args, **kwargs) 55 | raise ValueError(f"unsupported dimensions: {dims}") 56 | 57 | 58 | def update_ema(target_params, source_params, rate=0.99): 59 | """ 60 | Update target parameters to be closer to those of source parameters using 61 | an exponential moving average. 62 | 63 | :param target_params: the target parameter sequence. 64 | :param source_params: the source parameter sequence. 65 | :param rate: the EMA rate (closer to 1 means slower). 66 | """ 67 | for targ, src in zip(target_params, source_params): 68 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 69 | 70 | 71 | def zero_module(module): 72 | """ 73 | Zero out the parameters of a module and return it. 74 | """ 75 | for p in module.parameters(): 76 | p.detach().zero_() 77 | return module 78 | 79 | 80 | def scale_module(module, scale): 81 | """ 82 | Scale the parameters of a module and return it. 83 | """ 84 | for p in module.parameters(): 85 | p.detach().mul_(scale) 86 | return module 87 | 88 | 89 | def mean_flat(tensor): 90 | """ 91 | Take the mean over all non-batch dimensions. 92 | """ 93 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 94 | 95 | 96 | def normalization(channels): 97 | """ 98 | Make a standard normalization layer. 99 | 100 | :param channels: number of input channels. 101 | :return: an nn.Module for normalization. 102 | """ 103 | return GroupNorm32(32, channels) 104 | 105 | 106 | def timestep_embedding(timesteps, dim, max_period=10000): 107 | """ 108 | Create sinusoidal timestep embeddings. 109 | 110 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 111 | These may be fractional. 112 | :param dim: the dimension of the output. 113 | :param max_period: controls the minimum frequency of the embeddings. 114 | :return: an [N x dim] Tensor of positional embeddings. 
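    Example (shape check): ``timestep_embedding(th.tensor([0, 7]), dim=128)``
    returns a (2, 128) tensor whose first 64 columns are cosines and last 64
    columns are sines of the timestep scaled by geometrically spaced frequencies.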
115 | """ 116 | half = dim // 2 117 | freqs = th.exp( 118 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 119 | ).to(device=timesteps.device) 120 | args = timesteps[:, None].float() * freqs[None] 121 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 122 | if dim % 2: 123 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 124 | return embedding 125 | 126 | 127 | def checkpoint(func, inputs, params, flag): 128 | """ 129 | Evaluate a function without caching intermediate activations, allowing for 130 | reduced memory at the expense of extra compute in the backward pass. 131 | 132 | :param func: the function to evaluate. 133 | :param inputs: the argument sequence to pass to `func`. 134 | :param params: a sequence of parameters `func` depends on but does not 135 | explicitly take as arguments. 136 | :param flag: if False, disable gradient checkpointing. 137 | """ 138 | if flag: 139 | args = tuple(inputs) + tuple(params) 140 | return CheckpointFunction.apply(func, len(inputs), *args) 141 | else: 142 | return func(*inputs) 143 | 144 | 145 | class CheckpointFunction(th.autograd.Function): 146 | @staticmethod 147 | def forward(ctx, run_function, length, *args): 148 | ctx.run_function = run_function 149 | ctx.input_tensors = list(args[:length]) 150 | ctx.input_params = list(args[length:]) 151 | with th.no_grad(): 152 | output_tensors = ctx.run_function(*ctx.input_tensors) 153 | return output_tensors 154 | 155 | @staticmethod 156 | def backward(ctx, *output_grads): 157 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 158 | with th.enable_grad(): 159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | 172 | 173 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 📡 PhyRMDM: Physics-Informed Representation Alignment for Sparse Radio-Map Reconstruction 🚀 2 | 3 | > **"When Einstein Meets Deep Learning"** — We make radio signals elegantly dance to the laws of physics in virtual cities. 💃 4 | 5 | ## 📢 NEWS 6 | 7 | 🎉 **[NEW!]** Our paper has been **accepted** by **ACM MM BNI 2025** as an **Oral Presentation**! 🏆 8 | 🔓 **All code is now available** — Ready for researchers and practitioners to explore and build upon our work! 💻 9 | 🎯 **Pre-trained model weights are now available** — Download from Baidu Netdisk: [Link](https://pan.baidu.com/s/14p0aofKzp0jhreg-9NMKdQ) (Code: dnd4) 📦 10 | 11 | ## 🌟 Project Highlights 12 | 13 | ![RMDM Model Structure](RMDM.jpeg) 14 | 15 | - 🧠 **Physics-Informed AI**: Equipping neural networks with electromagnetic wisdom, enabling AI to think using Helmholtz equations. 16 | - 🎭 **Dual U-Net Architecture**: Two neural nets—one handling physical laws, the other refining details—working seamlessly to reconstruct radio maps. 17 | - 📉 **Record-Breaking Accuracy**: Achieved an unprecedented 0.0031 NMSE error in static scenarios, 2× better than state-of-the-art methods! 
18 | - 🌪️ **Dynamic Scene Mastery**: Robust reconstruction in dynamic, interference-rich environments (vehicles, moving obstacles) with an impressive 0.0047 NMSE.
19 | - 🕵️ **Sparse Data Champion**: Capable of accurately reconstructing complete radio maps even from a mere 1% sampling—like Sherlock deducing from minimal clues.
20 | 
21 | ## 🎯 Problems Solved
22 | 
23 | - 🧩 **Signal Reconstruction Puzzle**: Restoring complete electromagnetic fields from fragmented measurements.
24 | - 🌆 **Urban Maze Complexity**: Seamlessly handling complex obstructions from buildings, moving vehicles, and urban environments.
25 | - ⚡ **Real-Time Performance**: Achieving inference speeds up to 10× faster than traditional methods—ideal for real-time 5G/6G applications.
26 | 
27 | ## 🧠 Core Technical Innovations
28 | 
29 | ### 🎵 **Dual U-Net Symphony**
30 | 
31 | 1. **Physics-Conductor U-Net**: Embeds physical laws (Helmholtz equations) through Physics-Informed Neural Networks (PINNs).
32 | 2. **Detail-Sculptor U-Net**: Uses advanced diffusion models for ultra-fine precision in radio map reconstruction.
33 | 
34 | ### 🔥 **Three Innovative Modules**
35 | 
36 | - 🎯 **Anchor Conditional Mechanism**: Precisely locking onto critical physical landmarks (like GPS for radio signals).
37 | - 🌐 **RF-Space Attention**: Models "frequency symphonies" enabling focused learning of electromagnetic signal characteristics.
38 | - ⚖️ **Multi-Objective Loss**: Harmonizing physics-based constraints and data-driven fitting to achieve optimal results.
39 | 
40 | ## 📂 Benchmark Dataset
41 | 
42 | We leverage the authoritative **RadioMapSeer** dataset:
43 | 
44 | - 700+ real-world urban scenarios (London, Berlin, Tel Aviv, etc.)
45 | - 80 base stations per map with high-resolution 256×256 grids
46 | - Incorporates static and dynamic challenges (buildings, vehicles)
47 | 
48 | ## 🚀 Quick Start
49 | 
50 | ### 📋 Prerequisites
51 | 
52 | - **Python**: 3.8+
53 | - **PyTorch**: 2.0+ with CUDA support
54 | - **GPU**: NVIDIA GPU with 8GB+ VRAM (recommended)
55 | - **Dataset**: RadioMapSeer dataset
56 | 
57 | ### 🛠️ Installation
58 | 
59 | 1. **Clone the repository**:
60 | ```bash
61 | git clone git@github.com:Hxxxz0/RMDM.git
62 | cd RMDM
63 | ```
64 | 
65 | 2. **Create and activate conda environment**:
66 | ```bash
67 | conda create -n RMDM python=3.8
68 | conda activate RMDM
69 | ```
70 | 
71 | 3. **Install dependencies**:
72 | ```bash
73 | # Install PyTorch with CUDA support
74 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
75 | 
76 | # Install all other dependencies
77 | pip install -r requirement.txt
78 | ```
79 | 
80 | 4.
**Verify installation**: 81 | ```bash 82 | python -c "import torch; print(f'PyTorch {torch.__version__}, CUDA: {torch.cuda.is_available()}')" 83 | ``` 84 | 85 | ### 📁 Dataset Setup 86 | 87 | Download and organize the RadioMapSeer dataset: 88 | ``` 89 | RadioMapSeer/ 90 | ├── gain/ 91 | │ ├── DPM/ 92 | │ ├── IRT2/ 93 | │ └── cars*/ 94 | ├── png/ 95 | │ ├── buildings_complete/ 96 | │ ├── antennas/ 97 | │ └── cars/ 98 | └── dataset.csv 99 | ``` 100 | 101 | ### 🎯 Training 102 | 103 | **Single GPU Training (SRM)**: 104 | ```bash 105 | conda activate RMDM 106 | python train.py \ 107 | --data_name Radio \ 108 | --data_dir /path/to/RadioMapSeer \ 109 | --batch_size 16 \ 110 | --mixed_precision no \ 111 | --use_checkpoint True \ 112 | --num_channels 96 \ 113 | --attention_resolutions 16 \ 114 | --log_interval 50 \ 115 | --max_steps 100000 \ 116 | --save_interval 10000 \ 117 | --save_dir ./checkpoints_phy 118 | ``` 119 | 120 | **Multi-GPU Training (SRM)**: 121 | ```bash 122 | accelerate launch --num_processes=2 --multi_gpu --mixed_precision=no \ 123 | train.py \ 124 | --data_name Radio \ 125 | --data_dir /path/to/RadioMapSeer \ 126 | --batch_size 32 \ 127 | --mixed_precision no \ 128 | --use_checkpoint True \ 129 | --num_channels 96 \ 130 | --attention_resolutions 16 \ 131 | --log_interval 50 \ 132 | --max_steps 100000 \ 133 | --save_interval 10000 \ 134 | --save_dir ./checkpoints_phy 135 | ``` 136 | 137 | **Resume Training**: 138 | ```bash 139 | python train.py \ 140 | --resume_from ./checkpoints_phy/model_phy_step5000.pth \ 141 | --data_name Radio \ 142 | --data_dir /path/to/RadioMapSeer \ 143 | --batch_size 16 \ 144 | --mixed_precision no \ 145 | --save_dir ./checkpoints_phy 146 | ``` 147 | 148 | ### 🔮 Inference & Evaluation 149 | 150 | **Quick Inference Test (SRM)**: 151 | ```bash 152 | python sample_test.py \ 153 | --scheduler_type ddpm \ 154 | --data_dir /path/to/RadioMapSeer \ 155 | --checkpoint_path ./checkpoints_phy/model_phy_step100000.pth \ 156 | --output_dir ./inference_results \ 157 | --ddpm_steps 1000 \ 158 | --batch_size 4 \ 159 | --num_samples 100 160 | ``` 161 | 162 | **Full Test Set Evaluation (SRM)**: 163 | ```bash 164 | python sample_test.py \ 165 | --scheduler_type ddpm \ 166 | --data_dir /path/to/RadioMapSeer \ 167 | --checkpoint_path ./checkpoints_phy/model_phy_step10000.pth \ 168 | --output_dir ./full_evaluation \ 169 | --ddpm_steps 1000 \ 170 | --batch_size 4 \ 171 | --num_samples -1 172 | ``` 173 | 174 | **Inference with Image Saving** 🖼️: 175 | ```bash 176 | python sample_test.py \ 177 | --scheduler_type ddpm \ 178 | --data_dir /path/to/RadioMapSeer \ 179 | --checkpoint_path ./checkpoints_phy/model_phy_step10000.pth \ 180 | --output_dir ./results_with_images \ 181 | --ddpm_steps 1000 \ 182 | --batch_size 4 \ 183 | --num_samples 50 \ 184 | --save_images 185 | ``` 186 | 187 | This will generate: 188 | - 📊 **Generated radio maps**: `generated/` folder 189 | - 🎯 **Ground truth maps**: `ground_truth/` folder 190 | - 🏗️ **Input conditions**: `conditions/` folder (buildings + transmitters) 191 | - 🔍 **Comparison plots**: `comparison/` folder (generated vs. ground truth vs. 
difference) 192 | 193 | ## 📜 Academic Citation 194 | 195 | ```bibtex 196 | @misc{jia2025rmdmradiomapdiffusion, 197 | title={RMDM: Radio Map Diffusion Model with Physics Informed}, 198 | author={Haozhe Jia and Wenshuo Chen and Zhihui Huang and Hongru Xiao and Nanqian Jia and Keming Wu and Songning Lai and Yutao Yue}, 199 | year={2025}, 200 | eprint={2501.19160}, 201 | archivePrefix={arXiv}, 202 | primaryClass={cs.CV}, 203 | url={https://arxiv.org/abs/2501.19160}, 204 | } 205 | ``` 206 | 207 | ## 🙌 Acknowledgments 208 | 209 | Special thanks to: 210 | - 🏫 Joint Laboratory of Hong Kong University of Science and Technology (Guangzhou) & Shandong University 211 | - 🌉 Guangzhou Education Bureau's Key Research Project 212 | - 🤖 DIILab for generous computational support 213 | 214 | 215 | 216 | --- 217 | 218 | **License**: This project is distributed under the **Academic Free License v3.0**. Please cite accordingly for academic use. For commercial applications, contact the authors directly. 219 | 220 | -------------------------------------------------------------------------------- /lib/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | 5 | def convrelu(in_channels, out_channels, kernel, padding, pool): 6 | return nn.Sequential( 7 | nn.Conv2d(in_channels, out_channels, kernel, padding=padding), 8 | 9 | nn.ReLU(inplace=True), 10 | nn.MaxPool2d(pool, stride=pool, padding=0, dilation=1, return_indices=False, ceil_mode=False) 11 | 12 | ) 13 | 14 | def convreluT(in_channels, out_channels, kernel, padding): 15 | return nn.Sequential( 16 | nn.ConvTranspose2d(in_channels, out_channels, kernel, stride=2, padding=padding), 17 | nn.ReLU(inplace=True) 18 | 19 | ) 20 | 21 | 22 | 23 | class RadioWNet(nn.Module): 24 | 25 | def __init__(self,inputs=2,phase="firstU"): 26 | super().__init__() 27 | 28 | self.inputs=inputs 29 | self.phase=phase 30 | 31 | if inputs<=3: 32 | self.layer00 = convrelu(inputs, 6, 3, 1,1) 33 | self.layer0 = convrelu(6, 40, 5, 2,2) 34 | else: 35 | self.layer00 = convrelu(inputs, 10, 3, 1,1) 36 | self.layer0 = convrelu(10, 40, 5, 2,2) 37 | 38 | self.layer1 = convrelu(40, 50, 5, 2,2) 39 | self.layer10 = convrelu(50, 60, 5, 2,1) 40 | self.layer2 = convrelu(60, 100, 5, 2,2) 41 | self.layer20 = convrelu(100, 100, 3, 1,1) 42 | self.layer3 = convrelu(100, 150, 5, 2,2) 43 | self.layer4 =convrelu(150, 300, 5, 2,2) 44 | self.layer5 =convrelu(300, 500, 5, 2,2) 45 | 46 | self.conv_up5 =convreluT(500, 300, 4, 1) 47 | self.conv_up4 = convreluT(300+300, 150, 4, 1) 48 | self.conv_up3 = convreluT(150 + 150, 100, 4, 1) 49 | self.conv_up20 = convrelu(100 + 100, 100, 3, 1, 1) 50 | self.conv_up2 = convreluT(100 + 100, 60, 6, 2) 51 | self.conv_up10 = convrelu(60 + 60, 50, 5, 2, 1) 52 | self.conv_up1 = convreluT(50 + 50, 40, 6, 2) 53 | self.conv_up0 = convreluT(40 + 40, 20, 6, 2) 54 | if inputs<=3: 55 | self.conv_up00 = convrelu(20+6+inputs, 20, 5, 2,1) 56 | 57 | else: 58 | self.conv_up00 = convrelu(20+10+inputs, 20, 5, 2,1) 59 | 60 | self.conv_up000 = convrelu(20+inputs, 1, 5, 2,1) 61 | 62 | self.Wlayer00 = convrelu(inputs+1, 20, 3, 1,1) 63 | self.Wlayer0 = convrelu(20, 30, 5, 2,2) 64 | self.Wlayer1 = convrelu(30, 40, 5, 2,2) 65 | self.Wlayer10 = convrelu(40, 50, 5, 2,1) 66 | self.Wlayer2 = convrelu(50, 60, 5, 2,2) 67 | self.Wlayer20 = convrelu(60, 70, 3, 1,1) 68 | self.Wlayer3 = convrelu(70, 90, 5, 2,2) 69 | self.Wlayer4 =convrelu(90, 110, 5, 2,2) 70 | self.Wlayer5 =convrelu(110, 150, 5, 2,2) 
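        # Decoder of the second ("W") U-Net: transposed convolutions mirror the
        # encoder above; forward() concatenates the skip connection at each
        # matching resolution before every upsampling stage.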
71 | 72 | self.Wconv_up5 =convreluT(150, 110, 4, 1) 73 | self.Wconv_up4 = convreluT(110+110, 90, 4, 1) 74 | self.Wconv_up3 = convreluT(90 + 90, 70, 4, 1) 75 | self.Wconv_up20 = convrelu(70 + 70, 60, 3, 1, 1) 76 | self.Wconv_up2 = convreluT(60 + 60, 50, 6, 2) 77 | self.Wconv_up10 = convrelu(50 + 50, 40, 5, 2, 1) 78 | self.Wconv_up1 = convreluT(40 + 40, 30, 6, 2) 79 | self.Wconv_up0 = convreluT(30 + 30, 20, 6, 2) 80 | self.Wconv_up00 = convrelu(20+20+inputs+1, 20, 5, 2,1) 81 | self.Wconv_up000 = convrelu(20+inputs+1, 1, 5, 2,1) 82 | 83 | def forward(self, input): 84 | 85 | input0=input[:,0:self.inputs,:,:] 86 | 87 | if self.phase=="firstU": 88 | layer00 = self.layer00(input0) 89 | layer0 = self.layer0(layer00) 90 | layer1 = self.layer1(layer0) 91 | layer10 = self.layer10(layer1) 92 | layer2 = self.layer2(layer10) 93 | layer20 = self.layer20(layer2) 94 | layer3 = self.layer3(layer20) 95 | layer4 = self.layer4(layer3) 96 | layer5 = self.layer5(layer4) 97 | 98 | layer4u = self.conv_up5(layer5) 99 | layer4u = torch.cat([layer4u, layer4], dim=1) 100 | layer3u = self.conv_up4(layer4u) 101 | layer3u = torch.cat([layer3u, layer3], dim=1) 102 | layer20u = self.conv_up3(layer3u) 103 | layer20u = torch.cat([layer20u, layer20], dim=1) 104 | layer2u = self.conv_up20(layer20u) 105 | layer2u = torch.cat([layer2u, layer2], dim=1) 106 | layer10u = self.conv_up2(layer2u) 107 | layer10u = torch.cat([layer10u, layer10], dim=1) 108 | layer1u = self.conv_up10(layer10u) 109 | layer1u = torch.cat([layer1u, layer1], dim=1) 110 | layer0u = self.conv_up1(layer1u) 111 | layer0u = torch.cat([layer0u, layer0], dim=1) 112 | layer00u = self.conv_up0(layer0u) 113 | layer00u = torch.cat([layer00u, layer00], dim=1) 114 | layer00u = torch.cat([layer00u,input0], dim=1) 115 | layer000u = self.conv_up00(layer00u) 116 | layer000u = torch.cat([layer000u,input0], dim=1) 117 | output1 = self.conv_up000(layer000u) 118 | 119 | Winput=torch.cat([output1, input], dim=1).detach() 120 | 121 | Wlayer00 = self.Wlayer00(Winput).detach() 122 | Wlayer0 = self.Wlayer0(Wlayer00).detach() 123 | Wlayer1 = self.Wlayer1(Wlayer0).detach() 124 | Wlayer10 = self.Wlayer10(Wlayer1).detach() 125 | Wlayer2 = self.Wlayer2(Wlayer10).detach() 126 | Wlayer20 = self.Wlayer20(Wlayer2).detach() 127 | Wlayer3 = self.Wlayer3(Wlayer20).detach() 128 | Wlayer4 = self.Wlayer4(Wlayer3).detach() 129 | Wlayer5 = self.Wlayer5(Wlayer4).detach() 130 | 131 | Wlayer4u = self.Wconv_up5(Wlayer5).detach() 132 | Wlayer4u = torch.cat([Wlayer4u, Wlayer4], dim=1).detach() 133 | Wlayer3u = self.Wconv_up4(Wlayer4u).detach() 134 | Wlayer3u = torch.cat([Wlayer3u, Wlayer3], dim=1).detach() 135 | Wlayer20u = self.Wconv_up3(Wlayer3u).detach() 136 | Wlayer20u = torch.cat([Wlayer20u, Wlayer20], dim=1).detach() 137 | Wlayer2u = self.Wconv_up20(Wlayer20u).detach() 138 | Wlayer2u = torch.cat([Wlayer2u, Wlayer2], dim=1).detach() 139 | Wlayer10u = self.Wconv_up2(Wlayer2u).detach() 140 | Wlayer10u = torch.cat([Wlayer10u, Wlayer10], dim=1).detach() 141 | Wlayer1u = self.Wconv_up10(Wlayer10u).detach() 142 | Wlayer1u = torch.cat([Wlayer1u, Wlayer1], dim=1).detach() 143 | Wlayer0u = self.Wconv_up1(Wlayer1u).detach() 144 | Wlayer0u = torch.cat([Wlayer0u, Wlayer0], dim=1).detach() 145 | Wlayer00u = self.Wconv_up0(Wlayer0u).detach() 146 | Wlayer00u = torch.cat([Wlayer00u, Wlayer00], dim=1).detach() 147 | Wlayer00u = torch.cat([Wlayer00u,Winput], dim=1).detach() 148 | Wlayer000u = self.Wconv_up00(Wlayer00u).detach() 149 | Wlayer000u = torch.cat([Wlayer000u,Winput], dim=1).detach() 150 | output2 = 
self.Wconv_up000(Wlayer000u).detach() 151 | 152 | else: 153 | layer00 = self.layer00(input0).detach() 154 | layer0 = self.layer0(layer00).detach() 155 | layer1 = self.layer1(layer0).detach() 156 | layer10 = self.layer10(layer1).detach() 157 | layer2 = self.layer2(layer10).detach() 158 | layer20 = self.layer20(layer2).detach() 159 | layer3 = self.layer3(layer20).detach() 160 | layer4 = self.layer4(layer3).detach() 161 | layer5 = self.layer5(layer4).detach() 162 | 163 | layer4u = self.conv_up5(layer5).detach() 164 | layer4u = torch.cat([layer4u, layer4], dim=1).detach() 165 | layer3u = self.conv_up4(layer4u).detach() 166 | layer3u = torch.cat([layer3u, layer3], dim=1).detach() 167 | layer20u = self.conv_up3(layer3u).detach() 168 | layer20u = torch.cat([layer20u, layer20], dim=1).detach() 169 | layer2u = self.conv_up20(layer20u).detach() 170 | layer2u = torch.cat([layer2u, layer2], dim=1).detach() 171 | layer10u = self.conv_up2(layer2u).detach() 172 | layer10u = torch.cat([layer10u, layer10], dim=1).detach() 173 | layer1u = self.conv_up10(layer10u).detach() 174 | layer1u = torch.cat([layer1u, layer1], dim=1).detach() 175 | layer0u = self.conv_up1(layer1u).detach() 176 | layer0u = torch.cat([layer0u, layer0], dim=1).detach() 177 | layer00u = self.conv_up0(layer0u).detach() 178 | layer00u = torch.cat([layer00u, layer00], dim=1).detach() 179 | layer00u = torch.cat([layer00u,input0], dim=1).detach() 180 | layer000u = self.conv_up00(layer00u).detach() 181 | layer000u = torch.cat([layer000u,input0], dim=1).detach() 182 | output1 = self.conv_up000(layer000u).detach() 183 | 184 | Winput=torch.cat([output1, input], dim=1).detach() 185 | 186 | Wlayer00 = self.Wlayer00(Winput) 187 | Wlayer0 = self.Wlayer0(Wlayer00) 188 | Wlayer1 = self.Wlayer1(Wlayer0) 189 | Wlayer10 = self.Wlayer10(Wlayer1) 190 | Wlayer2 = self.Wlayer2(Wlayer10) 191 | Wlayer20 = self.Wlayer20(Wlayer2) 192 | Wlayer3 = self.Wlayer3(Wlayer20) 193 | Wlayer4 = self.Wlayer4(Wlayer3) 194 | Wlayer5 = self.Wlayer5(Wlayer4) 195 | 196 | Wlayer4u = self.Wconv_up5(Wlayer5) 197 | Wlayer4u = torch.cat([Wlayer4u, Wlayer4], dim=1) 198 | Wlayer3u = self.Wconv_up4(Wlayer4u) 199 | Wlayer3u = torch.cat([Wlayer3u, Wlayer3], dim=1) 200 | Wlayer20u = self.Wconv_up3(Wlayer3u) 201 | Wlayer20u = torch.cat([Wlayer20u, Wlayer20], dim=1) 202 | Wlayer2u = self.Wconv_up20(Wlayer20u) 203 | Wlayer2u = torch.cat([Wlayer2u, Wlayer2], dim=1) 204 | Wlayer10u = self.Wconv_up2(Wlayer2u) 205 | Wlayer10u = torch.cat([Wlayer10u, Wlayer10], dim=1) 206 | Wlayer1u = self.Wconv_up10(Wlayer10u) 207 | Wlayer1u = torch.cat([Wlayer1u, Wlayer1], dim=1) 208 | Wlayer0u = self.Wconv_up1(Wlayer1u) 209 | Wlayer0u = torch.cat([Wlayer0u, Wlayer0], dim=1) 210 | Wlayer00u = self.Wconv_up0(Wlayer0u) 211 | Wlayer00u = torch.cat([Wlayer00u, Wlayer00], dim=1) 212 | Wlayer00u = torch.cat([Wlayer00u,Winput], dim=1) 213 | Wlayer000u = self.Wconv_up00(Wlayer00u) 214 | Wlayer000u = torch.cat([Wlayer000u,Winput], dim=1) 215 | output2 = self.Wconv_up000(Wlayer000u) 216 | 217 | return [output1,output2] 218 | 219 | 220 | 221 | 222 | 223 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import argparse 4 | from typing import Dict 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.utils.data import DataLoader 9 | from accelerate import Accelerator 10 | from accelerate.utils import DistributedDataParallelKwargs 11 | 12 | from diffusers 
import DDPMScheduler
13 | 
14 | # Internal module imports
15 | from utils import build_unet_from_config, cal_pinn
16 | from lib import loaders as radio_loaders
17 | 
18 | 
19 | def build_dataloader(data_name, data_dir, image_size, batch_size, workers):
20 |     if data_name == 'Radio':
21 |         ds = radio_loaders.RadioUNet_c(phase="train", dir_dataset=data_dir)
22 |         in_ch = 2  # [buildings, Tx] input channels
23 |         out_ch = 1  # target output channels
24 |     elif data_name == 'Radio_2':
25 |         ds = radio_loaders.RadioUNet_s(phase="train", carsSimul="yes", carsInput="yes", dir_dataset=data_dir)
26 |         in_ch = 4  # [buildings, Tx, samples, cars] input channels
27 |         out_ch = 1  # target output channels
28 |     elif data_name == 'Radio_3':
29 |         ds = radio_loaders.RadioUNet_s(phase="train", simulation="rand", cityMap="missing", missing=4, dir_dataset=data_dir)
30 |         in_ch = 3  # [buildings, Tx, samples] input channels
31 |         out_ch = 1  # target output channels
32 |     else:
33 |         raise ValueError("data_name must be 'Radio' | 'Radio_2' | 'Radio_3'")
34 |     dl = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True, drop_last=True)
35 |     return dl, in_ch, out_ch
36 | 
37 | 
38 | def parse_args():
39 |     parser = argparse.ArgumentParser()
40 |     # Data arguments
41 |     parser.add_argument('--data_name', type=str, default='Radio', choices=['Radio','Radio_2','Radio_3'])
42 |     parser.add_argument('--data_dir', type=str, required=True, help='RadioMapSeer root directory, e.g., /path/to/RadioMapSeer/')
43 |     parser.add_argument('--image_size', type=int, default=256)
44 |     parser.add_argument('--batch_size', type=int, default=8)
45 |     parser.add_argument('--workers', type=int, default=8)
46 |     # Model arguments
47 |     parser.add_argument('--num_channels', type=int, default=128)
48 |     parser.add_argument('--num_res_blocks', type=int, default=2)
49 |     parser.add_argument('--attention_resolutions', type=str, default='16,8')
50 |     parser.add_argument('--channel_mult', type=str, default='')
51 |     parser.add_argument('--dropout', type=float, default=0.0)
52 |     parser.add_argument('--use_checkpoint', type=lambda s: str(s).lower() in ('true', '1', 'yes'), default=False)  # type=bool would treat any non-empty string (even 'False') as True
53 |     parser.add_argument('--use_scale_shift_norm', type=lambda s: str(s).lower() in ('true', '1', 'yes'), default=True)
54 |     parser.add_argument('--resblock_updown', type=lambda s: str(s).lower() in ('true', '1', 'yes'), default=False)
55 |     parser.add_argument('--use_fp16', type=lambda s: str(s).lower() in ('true', '1', 'yes'), default=False)
56 |     parser.add_argument('--num_heads', type=int, default=4)
57 |     parser.add_argument('--num_head_channels', type=int, default=-1)
58 |     parser.add_argument('--num_heads_upsample', type=int, default=-1)
59 |     parser.add_argument('--use_new_attention_order', type=lambda s: str(s).lower() in ('true', '1', 'yes'), default=False)
60 |     # Diffusion/Training arguments
61 |     parser.add_argument('--diffusion_steps', type=int, default=1000)
62 |     parser.add_argument('--noise_schedule', type=str, default='linear', choices=['linear','cosine'])
63 |     parser.add_argument('--lr', type=float, default=1e-4)
64 |     parser.add_argument('--mixed_precision', type=str, default='no', choices=['no','fp16','bf16'])
65 |     parser.add_argument('--max_steps', type=int, default=10000)
66 |     parser.add_argument('--log_interval', type=int, default=100)
67 |     parser.add_argument('--save_dir', type=str, default='./checkpoints_phy')
68 |     parser.add_argument('--save_interval', type=int, default=5000, help='Save every N steps, 0 means save only at the end')
69 |     # Resume training
70 |     parser.add_argument('--resume_from', type=str, default='', help='Model checkpoint path, e.g., /path/to/model_phy_step10000.pth')
71 |     parser.add_argument('--resume_step', type=int, default=0, help='Step to
resume from (0 to infer from filename)') 72 | return parser.parse_args() 73 | 74 | 75 | def main(): 76 | args = parse_args() 77 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) 78 | accelerator = Accelerator(mixed_precision=args.mixed_precision, kwargs_handlers=[ddp_kwargs]) 79 | device = accelerator.device 80 | 81 | dl, in_ch, out_ch = build_dataloader(args.data_name, args.data_dir, args.image_size, args.batch_size, args.workers) 82 | 83 | # Build UNet configuration - input: conditions + noisy target, output: denoised result 84 | cfg = { 85 | 'image_size': args.image_size, 86 | 'in_ch': in_ch + out_ch, 87 | 'out_ch': out_ch, 88 | 'num_channels': args.num_channels, 89 | 'num_res_blocks': args.num_res_blocks, 90 | 'channel_mult': args.channel_mult, 91 | 'num_heads': args.num_heads, 92 | 'num_head_channels': args.num_head_channels, 93 | 'num_heads_upsample': args.num_heads_upsample, 94 | 'attention_resolutions': args.attention_resolutions, 95 | 'dropout': args.dropout, 96 | 'class_cond': False, 97 | 'use_checkpoint': args.use_checkpoint, 98 | 'use_scale_shift_norm': args.use_scale_shift_norm, 99 | 'resblock_updown': args.resblock_updown, 100 | 'use_fp16': args.use_fp16, 101 | 'use_new_attention_order': args.use_new_attention_order, 102 | 'learn_sigma': False, 103 | } 104 | model = build_unet_from_config(cfg) 105 | 106 | # Load checkpoint if resume path is specified (before prepare for single/multi-GPU compatibility) 107 | start_step = 0 108 | if getattr(args, 'resume_from', ''): 109 | ckpt_path = args.resume_from 110 | if os.path.isfile(ckpt_path): 111 | state = torch.load(ckpt_path, map_location='cpu') 112 | try: 113 | model.load_state_dict(state, strict=True) 114 | except Exception as e: 115 | raise RuntimeError(f"Failed to load checkpoint: {ckpt_path}\n{e}") 116 | # Infer starting step: use --resume_step if specified, otherwise extract from filename 117 | if getattr(args, 'resume_step', 0) and args.resume_step > 0: 118 | start_step = args.resume_step 119 | else: 120 | m = re.search(r"model_phy_step(\d+)\.pth", os.path.basename(ckpt_path)) 121 | if m: 122 | start_step = int(m.group(1)) 123 | print(f"[Resume] Loaded checkpoint from {ckpt_path}, starting step={start_step}") 124 | else: 125 | raise FileNotFoundError(f"Checkpoint file not found: {ckpt_path}") 126 | 127 | # Use diffusers scheduler for noise scheduling 128 | beta_schedule = 'linear' if args.noise_schedule == 'linear' else 'squaredcos_cap_v2' 129 | scheduler = DDPMScheduler(num_train_timesteps=args.diffusion_steps, 130 | beta_schedule=beta_schedule, 131 | prediction_type='epsilon') 132 | 133 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4) 134 | 135 | model, optimizer, dl = accelerator.prepare(model, optimizer, dl) 136 | 137 | # Set find_unused_parameters=True for distributed training 138 | if hasattr(model, 'module'): 139 | # This is a DDP wrapped model 140 | model.find_unused_parameters = True 141 | 142 | model.train() 143 | 144 | step = start_step 145 | while step < args.max_steps: 146 | for inputs, image_gain, _ in dl: 147 | # Separate condition inputs and target 148 | if image_gain.dim() == 3: 149 | image_gain = image_gain.unsqueeze(1) 150 | 151 | # Condition inputs (buildings, antennas, etc.) 
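            # Per-iteration recipe: (1) assemble condition channels, (2) sample t and
            # noise the target via the DDPM forward process, (3) predict (noise, cal),
            # (4) sum the diffusion, cal-reconstruction, and PINN losses.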
152 | # Cache original masks for PINN (unmodified) 153 | raw_buildings = inputs[:, 0, ...].to(device) 154 | if inputs.size(1) >= 2: 155 | raw_antenna = inputs[:, 1, ...].to(device) 156 | else: 157 | raw_antenna = torch.zeros_like(raw_buildings) 158 | 159 | conditions = inputs.to(device) # (B, in_ch, H, W) 160 | # Consistent with RMDM: linear combination on channel 0: ch0 += 10 * ch1 (when ch1 exists) 161 | if conditions.size(1) >= 2: 162 | conditions[:, 0, ...] = conditions[:, 0, ...] + 10.0 * conditions[:, 1, ...] 163 | # Target (signal strength map) 164 | target_clean = image_gain.to(device) # (B, 1, H, W) 165 | 166 | b = target_clean.shape[0] 167 | t = torch.randint(0, scheduler.config.num_train_timesteps, (b,), device=device).long() 168 | 169 | # Standard diffusion: add noise to target 170 | noise = torch.randn_like(target_clean) 171 | target_noisy = scheduler.add_noise(target_clean, noise, t) 172 | 173 | # Concatenate conditions and noisy target as model input 174 | model_input = torch.cat([conditions, target_noisy], dim=1) # (B, in_ch+1, H, W) 175 | 176 | # Forward pass: enforce RMDM consistency, model must return (pred_noise, cal) 177 | out = model(model_input, t) 178 | if not isinstance(out, tuple) or len(out) < 2: 179 | raise RuntimeError("Model must output (pred_noise, cal) for RMDM consistency. Please enable cal branch output.") 180 | pred_noise, cal = out[0], out[1] 181 | 182 | # Diffusion loss: predicted noise vs ground truth noise 183 | loss_diff = F.mse_loss(pred_noise, noise) 184 | 185 | # Cal reconstruction supervision (RMDM consistent) 186 | loss_cal_recon = F.mse_loss(cal, target_clean) 187 | 188 | # PINN loss (applied to cal, RMDM consistent) 189 | # Use unmodified original masks to avoid linear combination affecting PINN masks 190 | buildings = raw_buildings 191 | antenna = raw_antenna 192 | loss_pinn_vec = cal_pinn(cal[:, 0, :, :], buildings, antenna, k=0.2) 193 | loss_pinn = loss_pinn_vec.mean() 194 | 195 | # Total loss (RMDM structure aligned): diffusion + reconstruction + PINN 196 | loss = loss_diff + loss_cal_recon + loss_pinn 197 | 198 | optimizer.zero_grad() 199 | accelerator.backward(loss) 200 | optimizer.step() 201 | 202 | if accelerator.is_main_process and step % args.log_interval == 0: 203 | print(f"step {step} loss {loss.item():.4f} diff {loss_diff.item():.4f} cal {loss_cal_recon.item():.4f} pinn {loss_pinn.item():.4f}") 204 | 205 | # Periodic saving: main process only, avoid saving at step=0 206 | if accelerator.is_main_process and args.save_interval and args.save_interval > 0: 207 | if step > 0 and (step % args.save_interval == 0): 208 | os.makedirs(args.save_dir, exist_ok=True) 209 | unwrapped_model = accelerator.unwrap_model(model) 210 | ckpt_path = os.path.join(args.save_dir, f'model_phy_step{step}.pth') 211 | torch.save(unwrapped_model.state_dict(), ckpt_path) 212 | 213 | step += 1 214 | if step >= args.max_steps: 215 | break 216 | 217 | if accelerator.is_main_process: 218 | os.makedirs(args.save_dir, exist_ok=True) 219 | unwrapped_model = accelerator.unwrap_model(model) 220 | torch.save(unwrapped_model.state_dict(), os.path.join(args.save_dir, 'model_phy.pth')) 221 | 222 | 223 | if __name__ == '__main__': 224 | main() 225 | 226 | 227 | -------------------------------------------------------------------------------- /utils/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | Copied from guided_diffusion.fp16_util with minimal dependencies. 
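Loss scaling is dynamic: the loss is multiplied by 2**lg_loss_scale before
backward(); on gradient overflow the scale is lowered (lg_loss_scale -= 1) and
the step is skipped, otherwise it grows by fp16_scale_growth after each step.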
4 | """ 5 | 6 | import numpy as np 7 | import torch as th 8 | import torch.nn as nn 9 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 10 | 11 | INITIAL_LOG_LOSS_SCALE = 20.0 12 | 13 | 14 | def convert_module_to_f16(l): 15 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)): 16 | l.weight.data = l.weight.data.half() 17 | if l.bias is not None: 18 | l.bias.data = l.bias.data.half() 19 | elif isinstance(l, (nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm)): 20 | if hasattr(l, 'weight') and l.weight is not None: 21 | l.weight.data = l.weight.data.half() 22 | if hasattr(l, 'bias') and l.bias is not None: 23 | l.bias.data = l.bias.data.half() 24 | 25 | 26 | def convert_module_to_f32(l): 27 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 28 | l.weight.data = l.weight.data.float() 29 | if l.bias is not None: 30 | l.bias.data = l.bias.data.float() 31 | 32 | 33 | def make_master_params(param_groups_and_shapes): 34 | master_params = [] 35 | for param_group, shape in param_groups_and_shapes: 36 | master_param = nn.Parameter( 37 | _flatten_dense_tensors( 38 | [param.detach().float() for (_, param) in param_group] 39 | ).view(shape) 40 | ) 41 | master_param.requires_grad = True 42 | master_params.append(master_param) 43 | return master_params 44 | 45 | 46 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 47 | for master_param, (param_group, shape) in zip( 48 | master_params, param_groups_and_shapes 49 | ): 50 | master_param.grad = _flatten_dense_tensors( 51 | [param_grad_or_zeros(param) for (_, param) in param_group] 52 | ).view(shape) 53 | 54 | 55 | def master_params_to_model_params(param_groups_and_shapes, master_params): 56 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 57 | for (_, param), unflat_master_param in zip( 58 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 59 | ): 60 | param.detach().copy_(unflat_master_param) 61 | 62 | 63 | def unflatten_master_params(param_group, master_param): 64 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 65 | 66 | 67 | def get_param_groups_and_shapes(named_model_params): 68 | named_model_params = list(named_model_params) 69 | scalar_vector_named_params = ( 70 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 71 | (-1), 72 | ) 73 | matrix_named_params = ( 74 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 75 | (1, -1), 76 | ) 77 | return [scalar_vector_named_params, matrix_named_params] 78 | 79 | 80 | def zero_master_grads(master_params): 81 | for param in master_params: 82 | param.grad = None 83 | 84 | 85 | def zero_grad(model_params): 86 | for param in model_params: 87 | if param.grad is not None: 88 | param.grad.detach_() 89 | param.grad.zero_() 90 | 91 | 92 | def param_grad_or_zeros(param): 93 | if param.grad is not None: 94 | return param.grad.data.detach() 95 | else: 96 | return th.zeros_like(param) 97 | 98 | 99 | class MixedPrecisionTrainer: 100 | def __init__( 101 | self, 102 | *, 103 | model, 104 | use_fp16=False, 105 | fp16_scale_growth=1e-3, 106 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 107 | ): 108 | self.model = model 109 | self.use_fp16 = use_fp16 110 | self.fp16_scale_growth = fp16_scale_growth 111 | 112 | self.model_params = list(self.model.parameters()) 113 | self.master_params = self.model_params 114 | self.param_groups_and_shapes = None 115 | self.lg_loss_scale = initial_lg_loss_scale 116 | 117 | if self.use_fp16: 118 
| self.param_groups_and_shapes = get_param_groups_and_shapes(
119 |                 self.model.named_parameters()
120 |             )
121 |             self.master_params = make_master_params(self.param_groups_and_shapes)
122 |             self.model.convert_to_fp16()
123 | 
124 |     def zero_grad(self):
125 |         zero_grad(self.model_params)
126 | 
127 |     def backward(self, loss: th.Tensor):
128 |         if self.use_fp16:
129 |             loss_scale = 2 ** self.lg_loss_scale
130 |             (loss * loss_scale).backward()
131 |         else:
132 |             loss.backward()
133 | 
134 |     def optimize(self, opt: th.optim.Optimizer):
135 |         if self.use_fp16:
136 |             return self._optimize_fp16(opt)
137 |         else:
138 |             return self._optimize_normal(opt)
139 | 
140 |     def _optimize_fp16(self, opt: th.optim.Optimizer):
141 |         model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
142 |         grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
143 |         if check_overflow(grad_norm):
144 |             self.lg_loss_scale -= 1
145 |             zero_master_grads(self.master_params)
146 |             return False
147 | 
148 |         for master_param in self.master_params:  # unscale every param group, not just the first
149 |             master_param.grad.mul_(1.0 / (2 ** self.lg_loss_scale))
150 |         opt.step()
151 |         zero_master_grads(self.master_params)
152 |         master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
153 |         self.lg_loss_scale += self.fp16_scale_growth
154 |         return True
155 | 
156 |     def _optimize_normal(self, opt: th.optim.Optimizer):
157 |         opt.step()
158 |         return True
159 | 
160 |     def _compute_norms(self, grad_scale=1.0):
161 |         grad_norm = 0.0
162 |         param_norm = 0.0
163 |         for p in self.master_params:
164 |             with th.no_grad():
165 |                 param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
166 |                 if p.grad is not None:
167 |                     grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
168 |         return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
169 | 
170 | 
171 | def check_overflow(value):
172 |     return (value == float("inf")) or (value == -float("inf")) or (value != value)
173 | 
174 | 
175 | def master_params_to_state_dict(model, param_groups_and_shapes, master_params, use_fp16):
176 |     if use_fp16:
177 |         state_dict = model.state_dict()
178 |         for master_param, (param_group, _) in zip(
179 |             master_params, param_groups_and_shapes
180 |         ):
181 |             for (name, _), unflat_master_param in zip(
182 |                 param_group, unflatten_master_params(param_group, master_param.view(-1))
183 |             ):
184 |                 assert name in state_dict
185 |                 state_dict[name] = unflat_master_param
186 |     else:
187 |         state_dict = model.state_dict()
188 |         for i, (name, _value) in enumerate(model.named_parameters()):
189 |             assert name in state_dict
190 |             state_dict[name] = master_params[i]
191 |     return state_dict
192 | 
193 | 
194 | def state_dict_to_master_params(model, state_dict, use_fp16):
195 |     if use_fp16:
196 |         named_model_params = [
197 |             (name, state_dict[name]) for name, _ in model.named_parameters()
198 |         ]
199 |         param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
200 |         master_params = make_master_params(param_groups_and_shapes)
201 |     else:
202 |         master_params = [state_dict[name] for name, _ in model.named_parameters()]
203 |     return master_params
204 | 
205 | 
--------------------------------------------------------------------------------
/sample_test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Full test set inference and evaluation script - supports DDPM, DDIM, and DPM-Solver
4 | Computes metrics:
MSE, NMSE, SSIM, PSNR 5 | """ 6 | 7 | import os 8 | import argparse 9 | import csv 10 | import time 11 | from typing import Dict, List, Tuple, Optional 12 | import json 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | from torch.utils.data import DataLoader 17 | import numpy as np 18 | from skimage.metrics import structural_similarity as ssim 19 | import matplotlib.pyplot as plt 20 | 21 | from diffusers import DDPMScheduler, DDIMScheduler, DPMSolverMultistepScheduler 22 | 23 | # Internal module imports 24 | from utils import build_unet_from_config 25 | from lib import loaders as radio_loaders 26 | 27 | 28 | # ============================================================================ 29 | # Data loading functions 30 | # ============================================================================ 31 | 32 | def build_dataloader(data_name: str, data_dir: str, image_size: int, 33 | batch_size: int, workers: int, phase: str = "test") -> Tuple[DataLoader, int, int]: 34 | """Build data loader""" 35 | dataset_configs = { 36 | 'Radio': { 37 | 'loader': lambda: radio_loaders.RadioUNet_c(phase=phase, dir_dataset=data_dir), 38 | 'in_ch': 2, # [buildings, Tx] 39 | 'out_ch': 1 # target 40 | }, 41 | 'Radio_2': { 42 | 'loader': lambda: radio_loaders.RadioUNet_s( 43 | phase=phase, carsSimul="yes", carsInput="yes", dir_dataset=data_dir 44 | ), 45 | 'in_ch': 4, # [buildings, Tx, samples, cars] 46 | 'out_ch': 1 # target 47 | }, 48 | 'Radio_3': { 49 | 'loader': lambda: radio_loaders.RadioUNet_s( 50 | phase=phase, simulation="rand", cityMap="missing", missing=4, dir_dataset=data_dir 51 | ), 52 | 'in_ch': 3, # [buildings, Tx, samples] 53 | 'out_ch': 1 # target 54 | } 55 | } 56 | 57 | if data_name not in dataset_configs: 58 | raise ValueError(f"data_name must be one of {list(dataset_configs.keys())}") 59 | 60 | config = dataset_configs[data_name] 61 | ds = config['loader']() 62 | dl = DataLoader(ds, batch_size=batch_size, shuffle=False, 63 | num_workers=workers, pin_memory=True) 64 | 65 | return dl, config['in_ch'], config['out_ch'] 66 | 67 | 68 | # ============================================================================ 69 | # Scheduler functions 70 | # ============================================================================ 71 | 72 | def create_scheduler(scheduler_type: str, num_train_timesteps: int = 1000, 73 | noise_schedule: str = 'linear'): 74 | """Create scheduler""" 75 | beta_schedule = 'linear' if noise_schedule == 'linear' else 'squaredcos_cap_v2' 76 | 77 | if scheduler_type == 'ddpm': 78 | return DDPMScheduler( 79 | num_train_timesteps=num_train_timesteps, 80 | beta_schedule=beta_schedule, 81 | prediction_type='epsilon' 82 | ) 83 | elif scheduler_type == 'ddim': 84 | # Enable clipping and set_alpha_to_one for more stable pixel-space sampling 85 | return DDIMScheduler( 86 | num_train_timesteps=num_train_timesteps, 87 | beta_schedule=beta_schedule, 88 | prediction_type='epsilon', 89 | clip_sample=True, 90 | set_alpha_to_one=True, 91 | steps_offset=0 92 | ) 93 | elif scheduler_type == 'dpm': 94 | return DPMSolverMultistepScheduler( 95 | num_train_timesteps=num_train_timesteps, 96 | beta_schedule=beta_schedule, 97 | prediction_type='epsilon' 98 | ) 99 | else: 100 | raise ValueError(f"Unsupported scheduler type: {scheduler_type}") 101 | 102 | 103 | # ============================================================================ 104 | # Inference functions 105 | # ============================================================================ 106 | 107 | def sample_ddpm(model, scheduler, 
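                # conditions: (B, C_cond, H, W) conditioning maps; the current noisy image is concatenated to them along dim=1 inside the loop below.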
conditions: torch.Tensor, 108 | num_inference_steps: int = 1000, device: str = 'cuda') -> Tuple[torch.Tensor, float]: 109 | """DDPM sampling""" 110 | batch_size, _, height, width = conditions.shape 111 | 112 | # Initial noise 113 | image = torch.randn((batch_size, 1, height, width), device=device) 114 | 115 | # Set inference steps 116 | scheduler.set_timesteps(num_inference_steps, device=device) 117 | 118 | model.eval() 119 | start_time = time.time() 120 | 121 | with torch.no_grad(): 122 | for timestep in scheduler.timesteps: 123 | # Concatenate conditions and current noisy image 124 | model_input = torch.cat([conditions, image], dim=1) 125 | 126 | # Predict noise 127 | timestep_batch = torch.tensor([timestep] * batch_size, device=device).long() 128 | out = model(model_input, timestep_batch) 129 | noise_pred = out[0] if isinstance(out, tuple) else out 130 | 131 | # DDPM denoising step 132 | batch_size_curr = image.shape[0] 133 | image_list = [] 134 | for j in range(batch_size_curr): 135 | single_step = scheduler.step( 136 | noise_pred[j:j+1], timestep, image[j:j+1], return_dict=False 137 | )[0] 138 | image_list.append(single_step) 139 | image = torch.cat(image_list, dim=0) 140 | 141 | end_time = time.time() 142 | sampling_time = end_time - start_time 143 | 144 | return image, sampling_time 145 | 146 | 147 | def sample_ddim(model, scheduler, conditions: torch.Tensor, 148 | num_inference_steps: int = 50, device: str = 'cuda', 149 | ddim_eta: float = 1.0) -> Tuple[torch.Tensor, float]: 150 | """DDIM sampling""" 151 | batch_size, _, height, width = conditions.shape 152 | 153 | 154 | image = torch.randn((batch_size, 1, height, width), device=device) 155 | 156 | # Set inference steps 157 | scheduler.set_timesteps(num_inference_steps, device=device) 158 | 159 | 160 | model.eval() 161 | start_time = time.time() 162 | 163 | with torch.no_grad(): 164 | for timestep in scheduler.timesteps: 165 | 166 | if hasattr(scheduler, 'scale_model_input'): 167 | scaled_image = scheduler.scale_model_input(image, timestep) 168 | else: 169 | scaled_image = image 170 | 171 | 172 | model_input = torch.cat([conditions, scaled_image], dim=1) 173 | 174 | 175 | timestep_batch = torch.tensor([timestep] * batch_size, device=device).long() 176 | out = model(model_input, timestep_batch) 177 | noise_pred = out[0] if isinstance(out, tuple) else out 178 | 179 | 180 | image = scheduler.step( 181 | noise_pred, timestep, image, eta=ddim_eta, 182 | use_clipped_model_output=False, 183 | return_dict=False 184 | )[0] 185 | 186 | end_time = time.time() 187 | sampling_time = end_time - start_time 188 | 189 | return image, sampling_time 190 | 191 | 192 | def sample_dpm(model, scheduler, conditions: torch.Tensor, 193 | num_inference_steps: int = 50, device: str = 'cuda') -> Tuple[torch.Tensor, float]: 194 | """DPM-Solver (multistep) sampling""" 195 | batch_size, _, height, width = conditions.shape 196 | image = torch.randn((batch_size, 1, height, width), device=device) 197 | 198 | scheduler.set_timesteps(num_inference_steps, device=device) 199 | 200 | model.eval() 201 | start_time = time.time() 202 | 203 | with torch.no_grad(): 204 | for timestep in scheduler.timesteps: 205 | if hasattr(scheduler, 'scale_model_input'): 206 | scaled_image = scheduler.scale_model_input(image, timestep) 207 | else: 208 | scaled_image = image 209 | 210 | model_input = torch.cat([conditions, scaled_image], dim=1) 211 | 212 | timestep_batch = torch.tensor([timestep] * batch_size, device=device).long() 213 | out = model(model_input, timestep_batch) 214 | 
noise_pred = out[0] if isinstance(out, tuple) else out 215 | 216 | image = scheduler.step( 217 | noise_pred, timestep, image, return_dict=False 218 | )[0] 219 | 220 | end_time = time.time() 221 | sampling_time = end_time - start_time 222 | 223 | return image, sampling_time 224 | 225 | def preprocess_conditions(conditions: torch.Tensor) -> torch.Tensor: 226 | """Preprocess condition inputs, consistent with training""" 227 | if conditions.size(1) >= 2: 228 | conditions[:, 0, ...] = conditions[:, 0, ...] + 10.0 * conditions[:, 1, ...] 229 | return conditions 230 | 231 | 232 | # ============================================================================ 233 | # Metrics calculation functions 234 | # ============================================================================ 235 | 236 | def calculate_ssim_batch(generated: torch.Tensor, ground_truth: torch.Tensor) -> List[float]: 237 | """Calculate batch SSIM""" 238 | batch_size = generated.shape[0] 239 | ssim_values = [] 240 | 241 | # Convert to numpy arrays 242 | gen_np = generated.detach().cpu().numpy() 243 | gt_np = ground_truth.detach().cpu().numpy() 244 | 245 | for i in range(batch_size): 246 | # Get single sample (1, H, W) 247 | gen_sample = gen_np[i, 0] # (H, W) 248 | gt_sample = gt_np[i, 0] # (H, W) 249 | 250 | # Calculate data range 251 | data_range = gt_sample.max() - gt_sample.min() 252 | if data_range == 0: 253 | data_range = 1.0 254 | 255 | # Calculate SSIM 256 | ssim_val = ssim(gt_sample, gen_sample, data_range=data_range) 257 | ssim_values.append(ssim_val) 258 | 259 | return ssim_values 260 | 261 | 262 | def calculate_metrics(generated: torch.Tensor, ground_truth: torch.Tensor) -> Dict[str, List[float]]: 263 | """Calculate all metrics""" 264 | with torch.no_grad(): 265 | # Handle NaN and infinite values 266 | gen = torch.nan_to_num(generated, nan=0.0, posinf=0.0, neginf=0.0) 267 | gt = torch.nan_to_num(ground_truth, nan=0.0, posinf=0.0, neginf=0.0) 268 | 269 | batch_size = gen.shape[0] 270 | metrics = { 271 | 'MSE': [], 272 | 'NMSE': [], 273 | 'PSNR': [], 274 | 'SSIM': [] 275 | } 276 | 277 | # Calculate SSIM 278 | ssim_values = calculate_ssim_batch(gen, gt) 279 | metrics['SSIM'] = ssim_values 280 | 281 | # Calculate other metrics per sample 282 | for i in range(batch_size): 283 | gen_sample = gen[i:i+1] 284 | gt_sample = gt[i:i+1] 285 | 286 | # MSE 287 | mse_val = F.mse_loss(gen_sample, gt_sample).item() 288 | metrics['MSE'].append(mse_val) 289 | 290 | # NMSE (Normalized MSE) 291 | gt_power = torch.mean(gt_sample ** 2).item() 292 | nmse_val = mse_val / gt_power if gt_power > 0 else float('inf') 293 | metrics['NMSE'].append(nmse_val) 294 | 295 | # PSNR 296 | data_range = (gt_sample.max() - gt_sample.min()).item() 297 | if data_range <= 1e-12: 298 | data_range = 1.0 299 | psnr_val = 20.0 * np.log10(data_range) - 10.0 * np.log10(mse_val) if mse_val > 0 else float('inf') 300 | metrics['PSNR'].append(psnr_val) 301 | 302 | return metrics 303 | 304 | 305 | # ============================================================================ 306 | # Main utility functions 307 | # ============================================================================ 308 | 309 | def build_model_config(args: argparse.Namespace, in_ch: int, out_ch: int) -> Dict: 310 | """Build model configuration""" 311 | return { 312 | 'image_size': args.image_size, 313 | 'in_ch': in_ch + out_ch, # condition channels + noisy target channels 314 | 'out_ch': out_ch, # output denoised target 315 | 'num_channels': args.num_channels, 316 | 'num_res_blocks': args.num_res_blocks, 
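        # The string-valued fields below (channel_mult, attention_resolutions) are presumably parsed inside build_unet_from_config (assumption: the builder mirrors guided_diffusion's argument handling).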
317 | 'channel_mult': args.channel_mult, 318 | 'num_heads': args.num_heads, 319 | 'num_head_channels': args.num_head_channels, 320 | 'num_heads_upsample': args.num_heads_upsample, 321 | 'attention_resolutions': args.attention_resolutions, 322 | 'dropout': args.dropout, 323 | 'class_cond': False, 324 | 'use_checkpoint': args.use_checkpoint, 325 | 'use_scale_shift_norm': args.use_scale_shift_norm, 326 | 'resblock_updown': args.resblock_updown, 327 | 'use_fp16': args.use_fp16, 328 | 'use_new_attention_order': args.use_new_attention_order, 329 | 'learn_sigma': False, 330 | } 331 | 332 | 333 | def _sanitize_filename_component(name: str) -> str: 334 | """Sanitize a string to be safe for filenames.""" 335 | # Replace problematic characters with underscore 336 | safe = ''.join(c if c.isalnum() or c in ('-', '_') else '_' for c in name) 337 | # Collapse consecutive underscores 338 | while '__' in safe: 339 | safe = safe.replace('__', '_') 340 | return safe.strip('_') or 'sample' 341 | 342 | 343 | def save_images(generated: torch.Tensor, ground_truth: torch.Tensor, conditions: torch.Tensor, 344 | names: List[str], batch_idx: int, output_dir: str, global_index_base: int): 345 | """Save generated images, ground truth, and conditions with globally unique filenames.""" 346 | import os 347 | 348 | # Create subdirectories for different types of images 349 | gen_dir = os.path.join(output_dir, 'generated') 350 | gt_dir = os.path.join(output_dir, 'ground_truth') 351 | cond_dir = os.path.join(output_dir, 'conditions') 352 | comp_dir = os.path.join(output_dir, 'comparison') 353 | 354 | os.makedirs(gen_dir, exist_ok=True) 355 | os.makedirs(gt_dir, exist_ok=True) 356 | os.makedirs(cond_dir, exist_ok=True) 357 | os.makedirs(comp_dir, exist_ok=True) 358 | 359 | batch_size = generated.shape[0] 360 | 361 | for i in range(batch_size): 362 | # Compute a global index to guarantee uniqueness across the whole run 363 | global_index = global_index_base + i 364 | # Derive a human-friendly name, then sanitize 365 | raw_name = None 366 | if names and i < len(names): 367 | try: 368 | raw_name = names[i] 369 | if isinstance(raw_name, bytes): 370 | raw_name = raw_name.decode('utf-8', errors='ignore') 371 | else: 372 | raw_name = str(raw_name) 373 | except Exception: 374 | raw_name = None 375 | sample_human = _sanitize_filename_component(raw_name) if raw_name is not None else 'sample' 376 | # Always include a unique prefix to avoid any overwrite 377 | base_name = f"g{global_index:07d}_b{batch_idx:04d}_s{i:02d}_{sample_human}" 378 | 379 | # Convert tensors to numpy arrays 380 | gen_img = generated[i, 0].detach().cpu().numpy() 381 | gt_img = ground_truth[i, 0].detach().cpu().numpy() 382 | 383 | # Save individual images 384 | plt.figure(figsize=(6, 6)) 385 | plt.imshow(gen_img, cmap='viridis') 386 | plt.colorbar() 387 | plt.title(f'Generated - {base_name}') 388 | plt.axis('off') 389 | plt.tight_layout() 390 | plt.savefig(os.path.join(gen_dir, f'{base_name}_generated.png'), dpi=150, bbox_inches='tight') 391 | plt.close() 392 | 393 | plt.figure(figsize=(6, 6)) 394 | plt.imshow(gt_img, cmap='viridis') 395 | plt.colorbar() 396 | plt.title(f'Ground Truth - {base_name}') 397 | plt.axis('off') 398 | plt.tight_layout() 399 | plt.savefig(os.path.join(gt_dir, f'{base_name}_ground_truth.png'), dpi=150, bbox_inches='tight') 400 | plt.close() 401 | 402 | # Save conditions (buildings and transmitters) 403 | if conditions.shape[1] >= 2: 404 | buildings = conditions[i, 0].detach().cpu().numpy() 405 | tx = conditions[i, 1].detach().cpu().numpy() 
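            # NOTE: preprocess_conditions() has already added 10*Tx onto channel 0 by the time save_images() runs, so this "Buildings" panel also shows the transmitter spike.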
406 | 407 | fig, axes = plt.subplots(1, 2, figsize=(12, 5)) 408 | 409 | im1 = axes[0].imshow(buildings, cmap='gray') 410 | axes[0].set_title(f'Buildings - {base_name}') 411 | axes[0].axis('off') 412 | plt.colorbar(im1, ax=axes[0]) 413 | 414 | im2 = axes[1].imshow(tx, cmap='hot') 415 | axes[1].set_title(f'Transmitters - {base_name}') 416 | axes[1].axis('off') 417 | plt.colorbar(im2, ax=axes[1]) 418 | 419 | plt.tight_layout() 420 | plt.savefig(os.path.join(cond_dir, f'{base_name}_conditions.png'), dpi=150, bbox_inches='tight') 421 | plt.close() 422 | 423 | # Save comparison plot 424 | fig, axes = plt.subplots(1, 3, figsize=(18, 5)) 425 | 426 | # Generated 427 | im1 = axes[0].imshow(gen_img, cmap='viridis') 428 | axes[0].set_title(f'Generated - {base_name}') 429 | axes[0].axis('off') 430 | plt.colorbar(im1, ax=axes[0]) 431 | 432 | # Ground Truth 433 | im2 = axes[1].imshow(gt_img, cmap='viridis') 434 | axes[1].set_title(f'Ground Truth - {base_name}') 435 | axes[1].axis('off') 436 | plt.colorbar(im2, ax=axes[1]) 437 | 438 | # Difference 439 | diff = np.abs(gen_img - gt_img) 440 | im3 = axes[2].imshow(diff, cmap='Reds') 441 | axes[2].set_title(f'Absolute Difference - {base_name}') 442 | axes[2].axis('off') 443 | plt.colorbar(im3, ax=axes[2]) 444 | 445 | plt.tight_layout() 446 | plt.savefig(os.path.join(comp_dir, f'{base_name}_comparison.png'), dpi=150, bbox_inches='tight') 447 | plt.close() 448 | 449 | 450 | def save_results(metrics_all: Dict[str, List[float]], args: argparse.Namespace, 451 | total_time: float, avg_time_per_batch: float): 452 | """Save results to files""" 453 | # Calculate average metrics 454 | avg_metrics = {} 455 | for metric_name, values in metrics_all.items(): 456 | valid_values = [v for v in values if np.isfinite(v)] 457 | if valid_values: 458 | avg_metrics[metric_name] = { 459 | 'mean': np.mean(valid_values), 460 | 'std': np.std(valid_values), 461 | 'count': len(valid_values) 462 | } 463 | else: 464 | avg_metrics[metric_name] = { 465 | 'mean': float('nan'), 466 | 'std': float('nan'), 467 | 'count': 0 468 | } 469 | 470 | # Save detailed results 471 | results = { 472 | 'config': { 473 | 'scheduler_type': args.scheduler_type, 474 | 'num_inference_steps': args.ddpm_steps if args.scheduler_type == 'ddpm' else args.ddim_steps, 475 | 'ddim_eta': args.ddim_eta if args.scheduler_type == 'ddim' else None, 476 | 'data_name': args.data_name, 477 | 'batch_size': args.batch_size, 478 | 'num_samples': args.num_samples if args.num_samples > 0 else 'all' 479 | }, 480 | 'timing': { 481 | 'total_time_seconds': total_time, 482 | 'avg_time_per_batch_seconds': avg_time_per_batch, 483 | 'total_samples': len(metrics_all['MSE']) 484 | }, 485 | 'metrics': avg_metrics 486 | } 487 | 488 | # Save JSON results 489 | json_path = os.path.join(args.output_dir, f'results_{args.scheduler_type}.json') 490 | with open(json_path, 'w') as f: 491 | json.dump(results, f, indent=2) 492 | 493 | # Save CSV detailed results 494 | csv_path = os.path.join(args.output_dir, f'detailed_metrics_{args.scheduler_type}.csv') 495 | with open(csv_path, 'w', newline='') as f: 496 | writer = csv.writer(f) 497 | writer.writerow(['Sample_ID', 'MSE', 'NMSE', 'PSNR', 'SSIM']) 498 | for i in range(len(metrics_all['MSE'])): 499 | writer.writerow([ 500 | i, 501 | metrics_all['MSE'][i], 502 | metrics_all['NMSE'][i], 503 | metrics_all['PSNR'][i], 504 | metrics_all['SSIM'][i] 505 | ]) 506 | 507 | print(f"Results saved to:") 508 | print(f" - {json_path}") 509 | print(f" - {csv_path}") 510 | 511 | 512 | def print_results(metrics_all: 
Dict[str, List[float]], scheduler_type: str, 513 | total_time: float, total_samples: int): 514 | """Print results summary""" 515 | print("\n" + "="*60) 516 | print(f"{scheduler_type.upper()} Test Results Summary") 517 | print("="*60) 518 | 519 | for metric_name, values in metrics_all.items(): 520 | valid_values = [v for v in values if np.isfinite(v)] 521 | if valid_values: 522 | mean_val = np.mean(valid_values) 523 | std_val = np.std(valid_values) 524 | print(f"{metric_name:6s}: {mean_val:8.6f} ± {std_val:8.6f} (n={len(valid_values)})") 525 | else: 526 | print(f"{metric_name:6s}: No valid values") 527 | 528 | print(f"\nTotal time: {total_time:.2f}s") 529 | print(f"Total samples: {total_samples}") 530 | print(f"Average time per sample: {total_time/total_samples:.3f}s") 531 | 532 | 533 | # ============================================================================ 534 | # Command line argument parsing 535 | # ============================================================================ 536 | 537 | def parse_args() -> argparse.Namespace: 538 | """Parse command line arguments""" 539 | parser = argparse.ArgumentParser( 540 | description='Full test set inference and evaluation script', 541 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 542 | ) 543 | 544 | # Basic arguments 545 | parser.add_argument('--scheduler_type', type=str, default='ddim', 546 | choices=['ddpm', 'ddim', 'dpm'], help='Scheduler type') 547 | parser.add_argument('--data_name', type=str, default='Radio', 548 | choices=['Radio', 'Radio_2', 'Radio_3'], help='Dataset name') 549 | parser.add_argument('--data_dir', type=str, required=True, 550 | help='RadioMapSeer dataset root directory') 551 | parser.add_argument('--checkpoint_path', type=str, required=True, 552 | help='Trained model checkpoint file path') 553 | parser.add_argument('--output_dir', type=str, default='./sample_test_results', 554 | help='Results output directory') 555 | parser.add_argument('--save_images', action='store_true', 556 | help='Save generated images, ground truth, and comparisons') 557 | 558 | # Inference arguments 559 | parser.add_argument('--ddpm_steps', type=int, default=1000, help='DDPM inference steps') 560 | parser.add_argument('--ddim_steps', type=int, default=50, help='DDIM inference steps') 561 | parser.add_argument('--dpm_steps', type=int, default=50, help='DPM-Solver inference steps') 562 | parser.add_argument('--ddim_eta', type=float, default=1.0, help='DDIM eta parameter') 563 | parser.add_argument('--diffusion_steps', type=int, default=1000, help='Training diffusion steps') 564 | parser.add_argument('--noise_schedule', type=str, default='linear', 565 | choices=['linear', 'cosine'], help='Noise schedule type') 566 | 567 | # Data arguments 568 | parser.add_argument('--image_size', type=int, default=256, help='Image size') 569 | parser.add_argument('--batch_size', type=int, default=4, help='Batch size') 570 | parser.add_argument('--workers', type=int, default=4, help='Data loading worker processes') 571 | parser.add_argument('--num_samples', type=int, default=-1, 572 | help='Number of test samples, <=0 means full test set') 573 | 574 | # Model arguments 575 | parser.add_argument('--num_channels', type=int, default=96, help='UNet base channels') 576 | parser.add_argument('--num_res_blocks', type=int, default=2, help='Number of ResNet blocks per layer') 577 | parser.add_argument('--attention_resolutions', type=str, default='16', 578 | help='Attention mechanism resolution') 579 | parser.add_argument('--channel_mult', type=str, default='', 
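                        # an empty default presumably lets build_unet_from_config pick a multiplier for the given image_size (assumption; the builder is not shown in this file)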
                        help='Channel multiplier settings')
580 |     parser.add_argument('--dropout', type=float, default=0.0, help='Dropout probability')
581 |     _bool = lambda s: str(s).strip().lower() in ('1', 'true', 'yes', 'y')  # argparse's type=bool would treat any non-empty string (even 'False') as True
582 |     parser.add_argument('--use_checkpoint', type=_bool, default=False, help='Whether to use gradient checkpointing')
583 |     parser.add_argument('--use_scale_shift_norm', type=_bool, default=True, help='Whether to use scale-shift normalization')
584 |     parser.add_argument('--resblock_updown', type=_bool, default=False,
585 |                         help='Whether to use up/downsampling in ResNet blocks')
586 |     parser.add_argument('--use_fp16', type=_bool, default=False, help='Whether to use half precision')
587 |     parser.add_argument('--num_heads', type=int, default=4, help='Number of attention heads')
588 |     parser.add_argument('--num_head_channels', type=int, default=-1, help='Channels per attention head')
589 |     parser.add_argument('--num_heads_upsample', type=int, default=-1, help='Attention heads for upsampling')
590 |     parser.add_argument('--use_new_attention_order', type=_bool, default=False,
591 |                         help='Whether to use new attention order')
592 | 
593 |     return parser.parse_args()
594 | 
595 | 
596 | # ============================================================================
597 | # Main function
598 | # ============================================================================
599 | 
600 | def main():
601 |     args = parse_args()
602 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
603 |     print(f"Using device: {device}")
604 |     print(f"Scheduler type: {args.scheduler_type.upper()}")
605 | 
606 |     # Create output directory
607 |     os.makedirs(args.output_dir, exist_ok=True)
608 | 
609 |     # Build data loader
610 |     print("Building data loader...")
611 |     dl, in_ch, out_ch = build_dataloader(
612 |         args.data_name, args.data_dir, args.image_size,
613 |         args.batch_size, args.workers, phase="test"
614 |     )
615 | 
616 |     # Build model
617 |     print("Building model...")
618 |     model_config = build_model_config(args, in_ch, out_ch)
619 |     model = build_unet_from_config(model_config)
620 | 
621 |     # Load model weights
622 |     print(f"Loading model weights: {args.checkpoint_path}")
623 |     checkpoint = torch.load(args.checkpoint_path, map_location=device)
624 |     model.load_state_dict(checkpoint)
625 |     model = model.to(device)
626 | 
627 |     # Create scheduler
628 |     scheduler = create_scheduler(
629 |         args.scheduler_type,
630 |         num_train_timesteps=args.diffusion_steps,
631 |         noise_schedule=args.noise_schedule
632 |     )
633 | 
634 |     # Determine inference steps
635 |     if args.scheduler_type == 'ddpm':
636 |         num_inference_steps = args.ddpm_steps
637 |         print(f"DDPM inference steps: {num_inference_steps}")
638 |     elif args.scheduler_type == 'ddim':
639 |         num_inference_steps = args.ddim_steps
640 |         print(f"DDIM inference steps: {num_inference_steps}, eta: {args.ddim_eta}")
641 |     elif args.scheduler_type == 'dpm':
642 |         num_inference_steps = args.dpm_steps
643 |         print(f"DPM-Solver inference steps: {num_inference_steps}")
644 | 
645 | 
646 |     # Start inference testing
647 |     print("Starting inference testing...")
648 |     all_metrics = {'MSE': [], 'NMSE': [], 'PSNR': [], 'SSIM': []}
649 |     total_start_time = time.time()
650 |     batch_times = []
651 |     sample_count = 0
652 | 
653 |     for batch_idx, (inputs, image_gain, names) in enumerate(dl):
654 |         if args.num_samples > 0 and sample_count >= args.num_samples:
655 |             break
656 | 
657 |         print(f"Processing batch {batch_idx + 1}/{len(dl)}")
658 | 
659 |         # Prepare data
660 |         if image_gain.dim() == 3:
661 |             image_gain = image_gain.unsqueeze(1)
662 | 
663 |         conditions = 
preprocess_conditions(inputs.to(device)) 664 | ground_truth = image_gain.to(device) 665 | 666 | # Inference 667 | if args.scheduler_type == 'ddpm': 668 | generated, sampling_time = sample_ddpm( 669 | model, scheduler, conditions, num_inference_steps, device 670 | ) 671 | else: 672 | if args.scheduler_type == 'ddim': 673 | generated, sampling_time = sample_ddim( 674 | model, scheduler, conditions, num_inference_steps, device, args.ddim_eta 675 | ) 676 | elif args.scheduler_type == 'dpm': 677 | generated, sampling_time = sample_dpm( 678 | model, scheduler, conditions, num_inference_steps, device 679 | ) 680 | 681 | batch_times.append(sampling_time) 682 | sample_count += conditions.shape[0] 683 | 684 | # Calculate metrics 685 | batch_metrics = calculate_metrics(generated, ground_truth) 686 | for metric_name, values in batch_metrics.items(): 687 | all_metrics[metric_name].extend(values) 688 | 689 | # Save images if requested 690 | if args.save_images: 691 | # Use a global base index so filenames are unique across the whole run 692 | base_index = sample_count - conditions.shape[0] 693 | save_images( 694 | generated, 695 | ground_truth, 696 | conditions, 697 | names, 698 | batch_idx, 699 | args.output_dir, 700 | base_index, 701 | ) 702 | 703 | # Print batch results 704 | avg_psnr = np.mean([v for v in batch_metrics['PSNR'] if np.isfinite(v)]) 705 | avg_ssim = np.mean([v for v in batch_metrics['SSIM'] if np.isfinite(v)]) 706 | print(f" Batch {batch_idx + 1} - PSNR: {avg_psnr:.3f}, SSIM: {avg_ssim:.4f}, Time: {sampling_time:.2f}s") 707 | 708 | total_end_time = time.time() 709 | total_time = total_end_time - total_start_time 710 | avg_time_per_batch = np.mean(batch_times) 711 | 712 | # Print and save results 713 | print_results(all_metrics, args.scheduler_type, total_time, sample_count) 714 | save_results(all_metrics, args, total_time, avg_time_per_batch) 715 | 716 | print(f"\nTesting completed! Processed {sample_count} samples") 717 | 718 | 719 | if __name__ == '__main__': 720 | main() 721 | -------------------------------------------------------------------------------- /lib/loaders.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import torch 4 | import pandas as pd 5 | from skimage import io, transform 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | from torch.utils.data import Dataset, DataLoader 9 | from torchvision import transforms, utils, datasets, models 10 | import warnings 11 | warnings.filterwarnings("ignore") 12 | 13 | 14 | #dir_gainDPM="gain/DPM/", 15 | #dir_gainDPMcars="gain/carsDPM/", 16 | #dir_gainIRT2="gain/IRT2/", 17 | #dir_gainIRT2cars="gain/carsIRT2/", 18 | #dir_buildings="png/", 19 | #dir_antenna= , 20 | 21 | 22 | class RadioUNet_c(Dataset): 23 | """RadioMapSeer Loader for accurate buildings and no measurements (RadioUNet_c)""" 24 | def __init__(self,maps_inds=np.zeros(1), phase="train", 25 | ind1=0,ind2=0, 26 | dir_dataset="RadioUNet/RadioMapSeer/", 27 | numTx=80, 28 | thresh=0.2, 29 | simulation="DPM", 30 | carsSimul="no", 31 | carsInput="no", 32 | IRT2maxW=1, 33 | cityMap="complete", 34 | missing=1, 35 | transform= transforms.ToTensor()): 36 | """ 37 | Args: 38 | maps_inds: optional shuffled sequence of the maps. Leave it as maps_inds=0 (default) for the standart split. 39 | phase:"train", "val", "test", "custom". If "train", "val" or "test", uses a standard split. 40 | "custom" means that the loader will read maps ind1 to ind2 from the list maps_inds. 
41 | ind1,ind2: First and last indices from maps_inds to define the maps of the loader, in case phase="custom". 42 | dir_dataset: directory of the RadioMapSeer dataset. 43 | numTx: Number of transmitters per map. Default and maximal value of numTx = 80. 44 | thresh: Pathlos threshold between 0 and 1. Defaoult is the noise floor 0.2. 45 | simulation:"DPM", "IRT2", "rand". Default= "DPM" 46 | carsSimul:"no", "yes". Use simulation with or without cars. Default="no". 47 | carsInput:"no", "yes". Take inputs with or without cars channel. Default="no". 48 | IRT2maxW: in case of "rand" simulation, the maximal weight IRT2 can take. Default=1. 49 | cityMap: "complete", "missing", "rand". Use the full city, or input map with missing buildings "rand" means that there is 50 | a random number of missing buildings. 51 | missing: 1 to 4. in case of input map with missing buildings, and not "rand", the number of missing buildings. Default=1. 52 | transform: Transform to apply on the images of the loader. Default= transforms.ToTensor()) 53 | 54 | Output: 55 | inputs: The RadioUNet inputs. 56 | image_gain 57 | 58 | """ 59 | 60 | 61 | 62 | #self.phase=phase 63 | 64 | if maps_inds.size==1: 65 | self.maps_inds=np.arange(0,700,1,dtype=np.int16) 66 | #Determenistic "random" shuffle of the maps: 67 | np.random.seed(42) 68 | np.random.shuffle(self.maps_inds) 69 | else: 70 | self.maps_inds=maps_inds 71 | 72 | if phase=="train": 73 | self.ind1=0 74 | self.ind2=500 75 | elif phase=="val": 76 | self.ind1=501 77 | self.ind2=600 78 | elif phase=="test": 79 | self.ind1=601 80 | self.ind2=699 81 | else: # custom range 82 | self.ind1=ind1 83 | self.ind2=ind2 84 | 85 | # Normalize dataset root to ensure trailing slash 86 | self.dir_dataset = dir_dataset if dir_dataset.endswith('/') else (dir_dataset + '/') 87 | self.numTx = numTx 88 | self.thresh=thresh 89 | 90 | self.simulation=simulation 91 | self.carsSimul=carsSimul 92 | self.carsInput=carsInput 93 | if simulation=="DPM" : 94 | if carsSimul=="no": 95 | self.dir_gain=self.dir_dataset+"gain/DPM/" 96 | else: 97 | self.dir_gain=self.dir_dataset+"gain/carsDPM/" 98 | elif simulation=="IRT2": 99 | if carsSimul=="no": 100 | self.dir_gain=self.dir_dataset+"gain/IRT2/" 101 | else: 102 | self.dir_gain=self.dir_dataset+"gain/carsIRT2/" 103 | elif simulation=="rand": 104 | if carsSimul=="no": 105 | self.dir_gainDPM=self.dir_dataset+"gain/DPM/" 106 | self.dir_gainIRT2=self.dir_dataset+"gain/IRT2/" 107 | else: 108 | self.dir_gainDPM=self.dir_dataset+"gain/carsDPM/" 109 | self.dir_gainIRT2=self.dir_dataset+"gain/carsIRT2/" 110 | 111 | self.IRT2maxW=IRT2maxW 112 | 113 | self.cityMap=cityMap 114 | self.missing=missing 115 | if cityMap=="complete": 116 | self.dir_buildings=self.dir_dataset+"png/buildings_complete/" 117 | print(self.dir_buildings) 118 | else: 119 | self.dir_buildings = self.dir_dataset+"png/buildings_missing" # a random index will be concatenated in the code 120 | #else: #missing==number 121 | # self.dir_buildings = self.dir_dataset+ "png/buildings_missing"+str(missing)+"/" 122 | #print(self.dir_buildings) 123 | 124 | self.transform= transform 125 | 126 | self.dir_Tx = self.dir_dataset+ "png/antennas/" 127 | #later check if reading the JSON file and creating antenna images on the fly is faster 128 | if carsInput!="no": 129 | self.dir_cars = self.dir_dataset+ "png/cars/" 130 | 131 | self.height = 256 132 | self.width = 256 133 | 134 | 135 | def __len__(self): 136 | return (self.ind2-self.ind1+1)*self.numTx 137 | 138 | def __getitem__(self, idx): 139 | 140 | 
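        # Each map contributes numTx consecutive samples: idxr selects the map, idxc the transmitter index.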
        idxr=np.floor(idx/self.numTx).astype(int)
141 |         idxc=idx-idxr*self.numTx
142 |         dataset_map_ind=self.maps_inds[idxr+self.ind1]+1
143 |         #names of files that depend only on the map:
144 |         name1 = str(dataset_map_ind) + ".png"
145 |         #names of files that depend on the map and the Tx:
146 |         name2 = str(dataset_map_ind) + "_" + str(idxc) + ".png"
147 | 
148 |         #Load buildings:
149 |         if self.cityMap == "complete":
150 |             img_name_buildings = os.path.join(self.dir_buildings, name1)
151 |         else:
152 |             if self.cityMap == "rand":
153 |                 self.missing=np.random.randint(low=1, high=5)
154 |             version=np.random.randint(low=1, high=7)
155 |             img_name_buildings = os.path.join(self.dir_buildings+str(self.missing)+"/"+str(version)+"/", name1)
156 | 
157 | 
158 |         image_buildings = np.asarray(io.imread(img_name_buildings))
159 | 
160 |         #Load Tx (transmitter):
161 |         img_name_Tx = os.path.join(self.dir_Tx, name2)
162 |         image_Tx = np.asarray(io.imread(img_name_Tx))
163 | 
164 |         #Load radio map:
165 |         if self.simulation!="rand":
166 |             img_name_gain = os.path.join(self.dir_gain, name2)
167 |             image_gain = np.expand_dims(np.asarray(io.imread(img_name_gain)),axis=2)/255  # note: this branch divides by 255, while the "rand" branch and the other loaders use 256
168 |         else: #random weighted average of DPM and IRT2
169 |             img_name_gainDPM = os.path.join(self.dir_gainDPM, name2)
170 |             img_name_gainIRT2 = os.path.join(self.dir_gainIRT2, name2)
171 |             #image_gainDPM = np.expand_dims(np.asarray(io.imread(img_name_gainDPM)),axis=2)/255
172 |             #image_gainIRT2 = np.expand_dims(np.asarray(io.imread(img_name_gainIRT2)),axis=2)/255
173 |             w=np.random.uniform(0,self.IRT2maxW) # IRT2 weight of random average
174 |             image_gain= w*np.expand_dims(np.asarray(io.imread(img_name_gainIRT2)),axis=2)/256 \
175 |                         + (1-w)*np.expand_dims(np.asarray(io.imread(img_name_gainDPM)),axis=2)/256
176 | 
177 |         #pathloss threshold transform
178 |         if self.thresh>0:
179 |             mask = image_gain < self.thresh
180 |             image_gain[mask]=self.thresh
181 |             image_gain=image_gain-self.thresh*np.ones(np.shape(image_gain))
182 |             image_gain=image_gain/(1-self.thresh)
183 | 
184 | 
185 |         #inputs to radioUNet
186 |         if self.carsInput=="no":
187 |             inputs=np.stack([image_buildings, image_Tx], axis=2)
188 |             #Buildings and antenna are kept on the 0..255 scale (not normalized to 1), which promotes convergence,
189 |             #so we can use the same learning rate as RadioUNets
190 |         else: #cars
191 |             #Normalization, so all settings can have the same learning rate
192 |             image_buildings=image_buildings/256
193 |             image_Tx=image_Tx/256
194 |             img_name_cars = os.path.join(self.dir_cars, name1)
195 |             image_cars = np.asarray(io.imread(img_name_cars))/256
196 |             inputs=np.stack([image_buildings, image_Tx, image_cars], axis=2)
197 |             #note that ToTensor moves the channel from the last axis to the first!
198 | 
199 | 
200 |         if self.transform:
201 |             inputs = self.transform(inputs).type(torch.float32)
202 |             image_gain = self.transform(image_gain).type(torch.float32)
203 |             #note that ToTensor moves the channel from the last axis to the first!
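        # After ToTensor, inputs is (2 or 3, 256, 256) float32 and image_gain is (1, 256, 256); name1 identifies the map file.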
204 | 205 | 206 | return (inputs, image_gain, name1) 207 | 208 | 209 | 210 | 211 | 212 | class RadioUNet_c_sprseIRT4(Dataset): 213 | """RadioMapSeer Loader for accurate buildings and no measurements (RadioUNet_c)""" 214 | def __init__(self,maps_inds=np.zeros(1), phase="train", 215 | ind1=0,ind2=0, 216 | dir_dataset="RadioMapSeer/", 217 | numTx=2, 218 | thresh=0.2, 219 | simulation="IRT4", 220 | carsSimul="no", 221 | carsInput="no", 222 | cityMap="complete", 223 | missing=1, 224 | num_samples=300, 225 | transform= transforms.ToTensor()): 226 | """ 227 | Args: 228 | maps_inds: optional shuffled sequence of the maps. Leave it as maps_inds=0 (default) for the standart split. 229 | phase:"train", "val", "test", "custom". If "train", "val" or "test", uses a standard split. 230 | "custom" means that the loader will read maps ind1 to ind2 from the list maps_inds. 231 | ind1,ind2: First and last indices from maps_inds to define the maps of the loader, in case phase="custom". 232 | dir_dataset: directory of the RadioMapSeer dataset. 233 | numTx: Number of transmitters per map. Default = 2. Note that IRT4 works only with numTx<=2. 234 | thresh: Pathlos threshold between 0 and 1. Defaoult is the noise floor 0.2. 235 | simulation: default="IRT4", with an option to "DPM", "IRT2". 236 | carsSimul:"no", "yes". Use simulation with or without cars. Default="no". 237 | carsInput:"no", "yes". Take inputs with or without cars channel. Default="no". 238 | cityMap: "complete", "missing", "rand". Use the full city, or input map with missing buildings "rand" means that there is 239 | a random number of missing buildings. 240 | missing: 1 to 4. in case of input map with missing buildings, and not "rand", the number of missing buildings. Default=1. 241 | num_samples: number of samples in the sparse IRT4 radio map. Default=300. 242 | transform: Transform to apply on the images of the loader. 
Default= transforms.ToTensor()) 243 | 244 | Output: 245 | 246 | """ 247 | if maps_inds.size==1: 248 | self.maps_inds=np.arange(0,700,1,dtype=np.int16) 249 | #Determenistic "random" shuffle of the maps: 250 | np.random.seed(42) 251 | np.random.shuffle(self.maps_inds) 252 | else: 253 | self.maps_inds=maps_inds 254 | 255 | if phase=="train": 256 | self.ind1=0 257 | self.ind2=500 258 | elif phase=="val": 259 | self.ind1=501 260 | self.ind2=600 261 | elif phase=="test": 262 | self.ind1=601 263 | self.ind2=699 264 | else: # custom range 265 | self.ind1=ind1 266 | self.ind2=ind2 267 | 268 | self.dir_dataset = dir_dataset 269 | self.numTx= numTx 270 | self.thresh=thresh 271 | 272 | self.simulation=simulation 273 | self.carsSimul=carsSimul 274 | self.carsInput=carsInput 275 | if simulation=="IRT4": 276 | if carsSimul=="no": 277 | self.dir_gain=self.dir_dataset+"gain/IRT4/" 278 | else: 279 | self.dir_gain=self.dir_dataset+"gain/carsIRT4/" 280 | 281 | elif simulation=="DPM" : 282 | if carsSimul=="no": 283 | self.dir_gain=self.dir_dataset+"gain/DPM/" 284 | else: 285 | self.dir_gain=self.dir_dataset+"gain/carsDPM/" 286 | elif simulation=="IRT2": 287 | if carsSimul=="no": 288 | self.dir_gain=self.dir_dataset+"gain/IRT2/" 289 | else: 290 | self.dir_gain=self.dir_dataset+"gain/carsIRT2/" 291 | 292 | 293 | self.cityMap=cityMap 294 | self.missing=missing 295 | if cityMap=="complete": 296 | self.dir_buildings=self.dir_dataset+"png/buildings_complete/" 297 | else: 298 | self.dir_buildings = self.dir_dataset+"png/buildings_missing" # a random index will be concatenated in the code 299 | #else: #missing==number 300 | # self.dir_buildings = self.dir_dataset+ "png/buildings_missing"+str(missing)+"/" 301 | 302 | 303 | self.transform= transform 304 | 305 | self.num_samples=num_samples 306 | 307 | self.dir_Tx = self.dir_dataset+ "png/antennas/" 308 | #later check if reading the JSON file and creating antenna images on the fly is faster 309 | if carsInput!="no": 310 | self.dir_cars = self.dir_dataset+ "png/cars/" 311 | 312 | self.height = 256 313 | self.width = 256 314 | 315 | 316 | 317 | 318 | 319 | def __len__(self): 320 | return (self.ind2-self.ind1+1)*self.numTx 321 | 322 | def __getitem__(self, idx): 323 | 324 | idxr=np.floor(idx/self.numTx).astype(int) 325 | idxc=idx-idxr*self.numTx 326 | dataset_map_ind=self.maps_inds[idxr+self.ind1]+1 327 | #names of files that depend only on the map: 328 | name1 = str(dataset_map_ind) + ".png" 329 | #names of files that depend on the map and the Tx: 330 | name2 = str(dataset_map_ind) + "_" + str(idxc) + ".png" 331 | 332 | #Load buildings: 333 | if self.cityMap == "complete": 334 | img_name_buildings = os.path.join(self.dir_buildings, name1) 335 | else: 336 | if self.cityMap == "rand": 337 | self.missing=np.random.randint(low=1, high=5) 338 | version=np.random.randint(low=1, high=7) 339 | img_name_buildings = os.path.join(self.dir_buildings+str(self.missing)+"/"+str(version)+"/", name1) 340 | str(self.missing) 341 | image_buildings = np.asarray(io.imread(img_name_buildings)) 342 | 343 | #Load Tx (transmitter): 344 | img_name_Tx = os.path.join(self.dir_Tx, name2) 345 | image_Tx = np.asarray(io.imread(img_name_Tx)) 346 | 347 | #Load radio map: 348 | if self.simulation!="rand": 349 | img_name_gain = os.path.join(self.dir_gain, name2) 350 | image_gain = np.expand_dims(np.asarray(io.imread(img_name_gain)),axis=2)/256 351 | else: #random weighted average of DPM and IRT2 352 | img_name_gainDPM = os.path.join(self.dir_gainDPM, name2) 353 | img_name_gainIRT2 = 
os.path.join(self.dir_gainIRT2, name2) 354 | #image_gainDPM = np.expand_dims(np.asarray(io.imread(img_name_gainDPM)),axis=2)/255 355 | #image_gainIRT2 = np.expand_dims(np.asarray(io.imread(img_name_gainIRT2)),axis=2)/255 356 | w=np.random.uniform(0,self.IRT2maxW) # IRT2 weight of random average 357 | image_gain= w*np.expand_dims(np.asarray(io.imread(img_name_gainIRT2)),axis=2)/256 \ 358 | + (1-w)*np.expand_dims(np.asarray(io.imread(img_name_gainDPM)),axis=2)/256 359 | 360 | #pathloss threshold transform 361 | if self.thresh>0: 362 | mask = image_gain < self.thresh 363 | image_gain[mask]=self.thresh 364 | image_gain=image_gain-self.thresh*np.ones(np.shape(image_gain)) 365 | image_gain=image_gain/(1-self.thresh) 366 | 367 | #Saprse IRT4 samples, determenistic and fixed samples per map 368 | image_samples = np.zeros((self.width,self.height)) 369 | seed_map=np.sum(image_buildings) # Each map has its fixed samples, independent of the transmitter location. 370 | np.random.seed(seed_map) 371 | x_samples=np.random.randint(0, 255, size=self.num_samples) 372 | y_samples=np.random.randint(0, 255, size=self.num_samples) 373 | image_samples[x_samples,y_samples]= 1 374 | 375 | #inputs to radioUNet 376 | if self.carsInput=="no": 377 | inputs=np.stack([image_buildings, image_Tx], axis=2) 378 | #The fact that the buildings and antenna are normalized 256 and not 1 promotes convergence, 379 | #so we can use the same learning rate as RadioUNets 380 | else: #cars 381 | #Normalization, so all settings can have the same learning rate 382 | image_buildings=image_buildings/256 383 | image_Tx=image_Tx/256 384 | img_name_cars = os.path.join(self.dir_cars, name1) 385 | image_cars = np.asarray(io.imread(img_name_cars))/256 386 | inputs=np.stack([image_buildings, image_Tx, image_cars], axis=2) 387 | #note that ToTensor moves the channel from the last asix to the first! 388 | 389 | 390 | 391 | 392 | if self.transform: 393 | inputs = self.transform(inputs).type(torch.float32) 394 | image_gain = self.transform(image_gain).type(torch.float32) 395 | image_samples = self.transform(image_samples).type(torch.float32) 396 | 397 | 398 | return [inputs, image_gain, image_samples] 399 | 400 | 401 | 402 | 403 | 404 | 405 | 406 | class RadioUNet_s(Dataset): 407 | """RadioMapSeer Loader for accurate buildings and no measurements (RadioUNet_c)""" 408 | def __init__(self,maps_inds=np.zeros(1), phase="train", 409 | ind1=0,ind2=0, 410 | dir_dataset="RadioUNet/RadioMapSeer/", 411 | numTx=80, 412 | thresh=0.2, 413 | simulation="DPM", 414 | carsSimul="no", 415 | carsInput="no", 416 | IRT2maxW=1, 417 | cityMap="complete", 418 | missing=1, 419 | fix_samples=0, 420 | num_samples_low= 10, 421 | num_samples_high= 300, 422 | transform= transforms.ToTensor()): 423 | """ 424 | Args: 425 | maps_inds: optional shuffled sequence of the maps. Leave it as maps_inds=0 (default) for the standart split. 426 | phase:"train", "val", "test", "custom". If "train", "val" or "test", uses a standard split. 427 | "custom" means that the loader will read maps ind1 to ind2 from the list maps_inds. 428 | ind1,ind2: First and last indices from maps_inds to define the maps of the loader, in case phase="custom". 429 | dir_dataset: directory of the RadioMapSeer dataset. 430 | numTx: Number of transmitters per map. Default and maximal value of numTx = 80. 431 | thresh: Pathlos threshold between 0 and 1. Defaoult is the noise floor 0.2. 432 | simulation:"DPM", "IRT2", "rand". Default= "DPM" 433 | carsSimul:"no", "yes". Use simulation with or without cars. Default="no". 
434 | carsInput:"no", "yes". Take inputs with or without cars channel. Default="no". 435 | IRT2maxW: in case of "rand" simulation, the maximal weight IRT2 can take. Default=1. 436 | cityMap: "complete", "missing", "rand". Use the full city, or input map with missing buildings "rand" means that there is 437 | a random number of missing buildings. 438 | missing: 1 to 4. in case of input map with missing buildings, and not "rand", the number of missing buildings. Default=1. 439 | fix_samples: fixed or a random number of samples. If zero, fixed, else, fix_samples is the number of samples. Default = 0. 440 | num_samples_low: if random number of samples, this is the minimum number of samples. Default = 10. 441 | num_samples_high: if random number of samples, this is the maximal number of samples. Default = 300. 442 | transform: Transform to apply on the images of the loader. Default= transforms.ToTensor()) 443 | 444 | Output: 445 | inputs: The RadioUNet inputs. 446 | image_gain 447 | 448 | """ 449 | 450 | 451 | 452 | #self.phase=phase 453 | 454 | if maps_inds.size==1: 455 | self.maps_inds=np.arange(0,700,1,dtype=np.int16) 456 | #Determenistic "random" shuffle of the maps: 457 | np.random.seed(42) 458 | np.random.shuffle(self.maps_inds) 459 | else: 460 | self.maps_inds=maps_inds 461 | 462 | if phase=="train": 463 | self.ind1=0 464 | self.ind2=500 465 | elif phase=="val": 466 | self.ind1=501 467 | self.ind2=600 468 | elif phase=="test": 469 | self.ind1=601 470 | self.ind2=699 471 | else: # custom range 472 | self.ind1=ind1 473 | self.ind2=ind2 474 | 475 | self.dir_dataset = dir_dataset 476 | self.numTx= numTx 477 | self.thresh=thresh 478 | 479 | self.simulation=simulation 480 | self.carsSimul=carsSimul 481 | self.carsInput=carsInput 482 | if simulation=="DPM" : 483 | if carsSimul=="no": 484 | self.dir_gain=self.dir_dataset+"gain/DPM/" 485 | else: 486 | self.dir_gain=self.dir_dataset+"gain/carsDPM/" 487 | elif simulation=="IRT2": 488 | if carsSimul=="no": 489 | self.dir_gain=self.dir_dataset+"gain/IRT2/" 490 | else: 491 | self.dir_gain=self.dir_dataset+"gain/carsIRT2/" 492 | elif simulation=="rand": 493 | if carsSimul=="no": 494 | self.dir_gainDPM=self.dir_dataset+"gain/DPM/" 495 | self.dir_gainIRT2=self.dir_dataset+"gain/IRT2/" 496 | else: 497 | self.dir_gainDPM=self.dir_dataset+"gain/carsDPM/" 498 | self.dir_gainIRT2=self.dir_dataset+"gain/carsIRT2/" 499 | 500 | self.IRT2maxW=IRT2maxW 501 | 502 | self.cityMap=cityMap 503 | self.missing=missing 504 | if cityMap=="complete": 505 | self.dir_buildings=self.dir_dataset+"png/buildings_complete/" 506 | else: 507 | self.dir_buildings = self.dir_dataset+"png/buildings_missing" # a random index will be concatenated in the code 508 | #else: #missing==number 509 | # self.dir_buildings = self.dir_dataset+ "png/buildings_missing"+str(missing)+"/" 510 | 511 | 512 | self.fix_samples= fix_samples 513 | self.num_samples_low= num_samples_low 514 | self.num_samples_high= num_samples_high 515 | 516 | self.transform= transform 517 | 518 | self.dir_Tx = self.dir_dataset+ "png/antennas/" 519 | #later check if reading the JSON file and creating antenna images on the fly is faster 520 | if carsInput!="no": 521 | self.dir_cars = self.dir_dataset+ "png/cars/" 522 | 523 | self.height = 256 524 | self.width = 256 525 | 526 | 527 | def __len__(self): 528 | return (self.ind2-self.ind1+1)*self.numTx 529 | 530 | def __getitem__(self, idx): 531 | 532 | idxr=np.floor(idx/self.numTx).astype(int) 533 | idxc=idx-idxr*self.numTx 534 | dataset_map_ind=self.maps_inds[idxr+self.ind1]+1 
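        # maps_inds entries are 0-based while the file names on disk are 1-based, hence the +1.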
535 | #names of files that depend only on the map: 536 | name1 = str(dataset_map_ind) + ".png" 537 | #names of files that depend on the map and the Tx: 538 | name2 = str(dataset_map_ind) + "_" + str(idxc) + ".png" 539 | 540 | #Load buildings: 541 | if self.cityMap == "complete": 542 | img_name_buildings = os.path.join(self.dir_buildings, name1) 543 | else: 544 | if self.cityMap == "rand": 545 | self.missing=np.random.randint(low=1, high=5) 546 | version=np.random.randint(low=1, high=7) 547 | img_name_buildings = os.path.join(self.dir_buildings+str(self.missing)+"/"+str(version)+"/", name1) 548 | str(self.missing) 549 | image_buildings = np.asarray(io.imread(img_name_buildings))/256 550 | 551 | #Load Tx (transmitter): 552 | img_name_Tx = os.path.join(self.dir_Tx, name2) 553 | image_Tx = np.asarray(io.imread(img_name_Tx))/256 554 | 555 | #Load radio map: 556 | if self.simulation!="rand": 557 | img_name_gain = os.path.join(self.dir_gain, name2) 558 | image_gain = np.expand_dims(np.asarray(io.imread(img_name_gain)),axis=2)/256 559 | else: #random weighted average of DPM and IRT2 560 | img_name_gainDPM = os.path.join(self.dir_gainDPM, name2) 561 | img_name_gainIRT2 = os.path.join(self.dir_gainIRT2, name2) 562 | #image_gainDPM = np.expand_dims(np.asarray(io.imread(img_name_gainDPM)),axis=2)/255 563 | #image_gainIRT2 = np.expand_dims(np.asarray(io.imread(img_name_gainIRT2)),axis=2)/255 564 | w=np.random.uniform(0,self.IRT2maxW) # IRT2 weight of random average 565 | image_gain= w*np.expand_dims(np.asarray(io.imread(img_name_gainIRT2)),axis=2)/256 \ 566 | + (1-w)*np.expand_dims(np.asarray(io.imread(img_name_gainDPM)),axis=2)/256 567 | 568 | #pathloss threshold transform 569 | if self.thresh>0: 570 | mask = image_gain < self.thresh 571 | image_gain[mask]=self.thresh 572 | image_gain=image_gain-self.thresh*np.ones(np.shape(image_gain)) 573 | image_gain=image_gain/(1-self.thresh) 574 | 575 | #image_gain=image_gain*256 # we use this normalization so all RadioUNet methods can have the same learning rate. 576 | # Namely, the loss of RadioUNet_s is 256 the loss of RadioUNet_c 577 | # Important: when evaluating the accuracy, remember to devide the errors by 256! 578 | 579 | #input measurements 580 | image_samples = np.zeros((256,256)) 581 | if self.fix_samples==0: 582 | num_samples=np.random.randint(self.num_samples_low, self.num_samples_high, size=1) 583 | else: 584 | num_samples=np.floor(self.fix_samples).astype(int) 585 | x_samples=np.random.randint(0, 255, size=num_samples) 586 | y_samples=np.random.randint(0, 255, size=num_samples) 587 | image_samples[x_samples,y_samples]= image_gain[x_samples,y_samples,0] 588 | 589 | #inputs to radioUNet 590 | if self.carsInput=="no": 591 | inputs=np.stack([image_buildings, image_Tx, image_samples], axis=2) 592 | #The fact that the buildings and antenna are normalized 256 and not 1 promotes convergence, 593 | #so we can use the same learning rate as RadioUNets 594 | else: #cars 595 | #Normalization, so all settings can have the same learning rate 596 | img_name_cars = os.path.join(self.dir_cars, name1) 597 | image_cars = np.asarray(io.imread(img_name_cars))/256 598 | inputs=np.stack([image_buildings, image_Tx, image_samples, image_cars], axis=2) 599 | #note that ToTensor moves the channel from the last asix to the first! 600 | 601 | 602 | 603 | if self.transform: 604 | inputs = self.transform(inputs).type(torch.float32) 605 | image_gain = self.transform(image_gain).type(torch.float32) 606 | #note that ToTensor moves the channel from the last asix to the first! 
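        # Note: the third returned element is num_samples (a small int array), not a file name; sample_test.py unpacks it as "names" and stringifies it for image labels.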
607 | 608 | 609 | return (inputs, image_gain, num_samples) 610 | 611 | 612 | 613 | 614 | 615 | class RadioUNet_s_sprseIRT4(Dataset): 616 | """RadioMapSeer Loader for accurate buildings and no measurements (RadioUNet_c)""" 617 | def __init__(self,maps_inds=np.zeros(1), phase="train", 618 | ind1=0,ind2=0, 619 | dir_dataset="RadioMapSeer/", 620 | numTx=2, 621 | thresh=0.2, 622 | simulation="IRT4", 623 | carsSimul="no", 624 | carsInput="no", 625 | cityMap="complete", 626 | missing=1, 627 | data_samples=300, 628 | fix_samples=0, 629 | num_samples_low= 10, 630 | num_samples_high= 299, 631 | transform= transforms.ToTensor()): 632 | """ 633 | Args: 634 | maps_inds: optional shuffled sequence of the maps. Leave it as maps_inds=0 (default) for the standart split. 635 | phase:"train", "val", "test", "custom". If "train", "val" or "test", uses a standard split. 636 | "custom" means that the loader will read maps ind1 to ind2 from the list maps_inds. 637 | ind1,ind2: First and last indices from maps_inds to define the maps of the loader, in case phase="custom". 638 | dir_dataset: directory of the RadioMapSeer dataset. 639 | numTx: Number of transmitters per map. Default = 2. Note that IRT4 works only with numTx<=2. 640 | thresh: Pathlos threshold between 0 and 1. Defaoult is the noise floor 0.2. 641 | simulation: default="IRT4", with an option to "DPM", "IRT2". 642 | carsSimul:"no", "yes". Use simulation with or without cars. Default="no". 643 | carsInput:"no", "yes". Take inputs with or without cars channel. Default="no". 644 | cityMap: "complete", "missing", "rand". Use the full city, or input map with missing buildings "rand" means that there is 645 | a random number of missing buildings. 646 | missing: 1 to 4. in case of input map with missing buildings, and not "rand", the number of missing buildings. Default=1. 647 | data_samples: number of samples in the sparse IRT4 radio map. Default=300. All input samples are taken from the data_samples 648 | fix_samples: fixed or a random number of samples. If zero, fixed, else, fix_samples is the number of samples. Default = 0. 649 | num_samples_low: if random number of samples, this is the minimum number of samples. Default = 10. 650 | num_samples_high: if random number of samples, this is the maximal number of samples. Default = 300. 651 | transform: Transform to apply on the images of the loader. 
Default= transforms.ToTensor()) 652 | 653 | Output: 654 | 655 | """ 656 | if maps_inds.size==1: 657 | self.maps_inds=np.arange(0,700,1,dtype=np.int16) 658 | #Determenistic "random" shuffle of the maps: 659 | np.random.seed(42) 660 | np.random.shuffle(self.maps_inds) 661 | else: 662 | self.maps_inds=maps_inds 663 | 664 | if phase=="train": 665 | self.ind1=0 666 | self.ind2=500 667 | elif phase=="val": 668 | self.ind1=501 669 | self.ind2=600 670 | elif phase=="test": 671 | self.ind1=601 672 | self.ind2=699 673 | else: # custom range 674 | self.ind1=ind1 675 | self.ind2=ind2 676 | 677 | self.dir_dataset = dir_dataset 678 | self.numTx= numTx 679 | self.thresh=thresh 680 | 681 | self.simulation=simulation 682 | self.carsSimul=carsSimul 683 | self.carsInput=carsInput 684 | if simulation=="IRT4": 685 | if carsSimul=="no": 686 | self.dir_gain=self.dir_dataset+"gain/IRT4/" 687 | else: 688 | self.dir_gain=self.dir_dataset+"gain/carsIRT4/" 689 | 690 | elif simulation=="DPM" : 691 | if carsSimul=="no": 692 | self.dir_gain=self.dir_dataset+"gain/DPM/" 693 | else: 694 | self.dir_gain=self.dir_dataset+"gain/carsDPM/" 695 | elif simulation=="IRT2": 696 | if carsSimul=="no": 697 | self.dir_gain=self.dir_dataset+"gain/IRT2/" 698 | else: 699 | self.dir_gain=self.dir_dataset+"gain/carsIRT2/" 700 | 701 | 702 | self.cityMap=cityMap 703 | self.missing=missing 704 | if cityMap=="complete": 705 | self.dir_buildings=self.dir_dataset+"png/buildings_complete/" 706 | else: 707 | self.dir_buildings = self.dir_dataset+"png/buildings_missing" # a random index will be concatenated in the code 708 | #else: #missing==number 709 | # self.dir_buildings = self.dir_dataset+ "png/buildings_missing"+str(missing)+"/" 710 | 711 | 712 | self.data_samples=data_samples 713 | self.fix_samples= fix_samples 714 | self.num_samples_low= num_samples_low 715 | self.num_samples_high= num_samples_high 716 | 717 | self.transform= transform 718 | 719 | 720 | self.dir_Tx = self.dir_dataset+ "png/antennas/" 721 | #later check if reading the JSON file and creating antenna images on the fly is faster 722 | if carsInput!="no": 723 | self.dir_cars = self.dir_dataset+ "png/cars/" 724 | 725 | self.height = 256 726 | self.width = 256 727 | 728 | 729 | 730 | 731 | 732 | def __len__(self): 733 | return (self.ind2-self.ind1+1)*self.numTx 734 | 735 | def __getitem__(self, idx): 736 | 737 | idxr=np.floor(idx/self.numTx).astype(int) 738 | idxc=idx-idxr*self.numTx 739 | dataset_map_ind=self.maps_inds[idxr+self.ind1]+1 740 | #names of files that depend only on the map: 741 | name1 = str(dataset_map_ind) + ".png" 742 | #names of files that depend on the map and the Tx: 743 | name2 = str(dataset_map_ind) + "_" + str(idxc) + ".png" 744 | 745 | #Load buildings: 746 | if self.cityMap == "complete": 747 | img_name_buildings = os.path.join(self.dir_buildings, name1) 748 | else: 749 | if self.cityMap == "rand": 750 | self.missing=np.random.randint(low=1, high=5) 751 | version=np.random.randint(low=1, high=7) 752 | img_name_buildings = os.path.join(self.dir_buildings+str(self.missing)+"/"+str(version)+"/", name1) 753 | str(self.missing) 754 | image_buildings = np.asarray(io.imread(img_name_buildings)) #Will be normalized later, after random seed is computed from it 755 | 756 | #Load Tx (transmitter): 757 | img_name_Tx = os.path.join(self.dir_Tx, name2) 758 | image_Tx = np.asarray(io.imread(img_name_Tx))/256 759 | 760 | #Load radio map: 761 | if self.simulation!="rand": 762 | img_name_gain = os.path.join(self.dir_gain, name2) 763 | image_gain = 
np.expand_dims(np.asarray(io.imread(img_name_gain)),axis=2)/256 764 | else: #random weighted average of DPM and IRT2 765 | img_name_gainDPM = os.path.join(self.dir_gainDPM, name2) 766 | img_name_gainIRT2 = os.path.join(self.dir_gainIRT2, name2) 767 | #image_gainDPM = np.expand_dims(np.asarray(io.imread(img_name_gainDPM)),axis=2)/255 768 | #image_gainIRT2 = np.expand_dims(np.asarray(io.imread(img_name_gainIRT2)),axis=2)/255 769 | w=np.random.uniform(0,self.IRT2maxW) # IRT2 weight of random average 770 | image_gain= w*np.expand_dims(np.asarray(io.imread(img_name_gainIRT2)),axis=2)/256 \ 771 | + (1-w)*np.expand_dims(np.asarray(io.imread(img_name_gainDPM)),axis=2)/256 772 | 773 | #pathloss threshold transform 774 | if self.thresh>0: 775 | mask = image_gain < self.thresh 776 | image_gain[mask]=self.thresh 777 | image_gain=image_gain-self.thresh*np.ones(np.shape(image_gain)) 778 | image_gain=image_gain/(1-self.thresh) 779 | 780 | image_gain=image_gain*256 # we use this normalization so all RadioUNet methods can have the same learning rate. 781 | # Namely, the loss of RadioUNet_s is 256 the loss of RadioUNet_c 782 | # Important: when evaluating the accuracy, remember to devide the errors by 256! 783 | 784 | #Saprse IRT4 samples, determenistic and fixed samples per map 785 | sparse_samples = np.zeros((self.width,self.height)) 786 | seed_map=np.sum(image_buildings) # Each map has its fixed samples, independent of the transmitter location. 787 | np.random.seed(seed_map) 788 | x_samples=np.random.randint(0, 255, size=self.data_samples) 789 | y_samples=np.random.randint(0, 255, size=self.data_samples) 790 | sparse_samples[x_samples,y_samples]= 1 791 | 792 | #input samples from the sparse gain samples 793 | input_samples = np.zeros((256,256)) 794 | if self.fix_samples==0: 795 | num_in_samples=np.random.randint(self.num_samples_low, self.num_samples_high, size=1) 796 | else: 797 | num_in_samples=np.floor(self.fix_samples).astype(int) 798 | 799 | data_inds=range(self.data_samples) 800 | input_inds=np.random.permutation(data_inds)[0:num_in_samples[0]] 801 | x_samples_in=x_samples[input_inds] 802 | y_samples_in=y_samples[input_inds] 803 | input_samples[x_samples_in,y_samples_in]= image_gain[x_samples_in,y_samples_in,0] 804 | 805 | #normalize image_buildings, after random seed computed from it as an int 806 | image_buildings=image_buildings/256 807 | 808 | #inputs to radioUNet 809 | if self.carsInput=="no": 810 | inputs=np.stack([image_buildings, image_Tx, input_samples], axis=2) 811 | #The fact that the buildings and antenna are normalized 256 and not 1 promotes convergence, 812 | #so we can use the same learning rate as RadioUNets 813 | else: #cars 814 | #Normalization, so all settings can have the same learning rate 815 | img_name_cars = os.path.join(self.dir_cars, name1) 816 | image_cars = np.asarray(io.imread(img_name_cars))/256 817 | inputs=np.stack([image_buildings, image_Tx, input_samples, image_cars], axis=2) 818 | #note that ToTensor moves the channel from the last asix to the first! 819 | 820 | 821 | 822 | 823 | if self.transform: 824 | inputs = self.transform(inputs).type(torch.float32) 825 | image_gain = self.transform(image_gain).type(torch.float32) 826 | sparse_samples = self.transform(sparse_samples).type(torch.float32) 827 | 828 | 829 | 830 | return [inputs, image_gain, sparse_samples] 831 | 832 | 833 | 834 | 835 | 836 | 837 | 838 | 839 | --------------------------------------------------------------------------------
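A minimal smoke-test sketch (not part of the repository files above; the dataset path is a placeholder): it loads one RadioMapSeer sample through lib/loaders.py and runs the metric helpers from sample_test.py on a random stand-in prediction, without touching the diffusion model.

import torch
from lib import loaders
from sample_test import calculate_metrics

# RadioUNet_c yields (inputs, image_gain, name): inputs (2, 256, 256), image_gain (1, 256, 256).
ds = loaders.RadioUNet_c(phase="test", dir_dataset="/path/to/RadioMapSeer/")
inputs, image_gain, name = ds[0]

# Random stand-in for a model sample, only to demonstrate the metric computation.
fake_pred = torch.rand_like(image_gain)
metrics = calculate_metrics(fake_pred.unsqueeze(0), image_gain.unsqueeze(0))
print({k: round(float(v[0]), 4) for k, v in metrics.items()})

Running the full evaluation instead goes through the sample_test.py CLI, e.g. python sample_test.py --data_dir <RadioMapSeer root> --checkpoint_path <model.pt> --scheduler_type ddim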