├── Pretrain ├── main.py ├── models.py ├── patch_encoder_and_projection_head_and_classifier.py ├── pytorch_pretrained_vit │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── configs.cpython-38.pyc │ │ ├── model.cpython-38.pyc │ │ ├── transformer.cpython-38.pyc │ │ └── utils.cpython-38.pyc │ ├── configs.py │ ├── model.py │ ├── transformer.py │ └── utils.py ├── signal_data_add.py └── utils.py ├── Readme.md ├── Test and Comparison ├── main.py ├── models.py ├── patch_encoder_and_projection_head_and_classifier.py ├── pytorch_pretrained_vit │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── configs.cpython-38.pyc │ │ ├── model.cpython-38.pyc │ │ ├── transformer.cpython-38.pyc │ │ └── utils.cpython-38.pyc │ ├── configs.py │ ├── model.py │ ├── transformer.py │ └── utils.py ├── signal_data_add.py └── utils.py └── generate_CWTdataset.ipynb /Pretrain/main.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import pickle as pk 6 | import numpy as np 7 | from utils import * 8 | from tqdm import tqdm 9 | from signal_data_add import get_signal 10 | from torch.utils.data import DataLoader 11 | import pywt 12 | import matplotlib.pyplot as plt 13 | from patch_encoder_and_projection_head_and_classifier import * 14 | from pytorch_pretrained_vit import ViT 15 | import os 16 | from models import * 17 | import pandas as pd 18 | 19 | 20 | def pretrain(model,projection_head,epoches,converter,optimizer,data_loader,save_path): 21 | for i in range(epoches): 22 | with tqdm(total=len(data_loader)) as p_bar: 23 | for batch_idx, (singals, CWTs, labels) in enumerate(data_loader): 24 | singals = singals.cuda() 25 | videos = converter(singals) 26 | CWTs = CWTs.cuda() 27 | labels = labels.cuda() 28 | 29 | singals_feature,CWTs_feature,videos_feature = model(singals,CWTs,videos) 30 | singals_projection,CWTs_projection,videos_projection = projection_head(singals_feature,CWTs_feature,videos_feature) 31 | 32 | loss1 = compute_loss(singals_projection,videos_projection)#去噪,兼顾信息增益(projection head不同) 33 | loss2 = compute_loss(singals_projection,CWTs_projection)#去噪,兼顾信息增益(projection head不同) 34 | loss4 = compute_loss(videos_projection,CWTs_projection) 35 | loss = loss1 + loss2 + loss4 36 | optimizer.zero_grad() 37 | loss.backward() 38 | optimizer.step() 39 | p_bar.set_description("pretrain: epoch:{} batch_idx:{} loss:{:02f}".format(i,batch_idx,loss)) 40 | p_bar.update() 41 | 42 | torch.save(model.state_dict(),os.path.join(save_path,'model_state.pth')) 43 | torch.save(projection_head.state_dict(),os.path.join(save_path,'projection_head_state.pth')) 44 | 45 | def classifier_train(model,classifier,epoches,converter,optimizer,train_loader,val_loader,save_path): 46 | max_acc = 0 47 | for epoch in range(epoches): 48 | with tqdm(total=len(train_loader)) as p_bar: 49 | for batch_idx, (singals, CWTs, labels) in enumerate(train_loader): 50 | singals = singals.cuda() 51 | videos = converter(singals) 52 | CWTs = CWTs.cuda() 53 | labels = labels.cuda() 54 | singals_feature,CWTs_feature,videos_feature = model(singals,CWTs,videos) 55 | feature = torch.cat([singals_feature,CWTs_feature,videos_feature],dim = -1) 56 | pre = classifier(feature) 57 | loss = F.cross_entropy(pre,labels) 58 | optimizer.zero_grad() 59 | loss.backward() 60 | optimizer.step() 61 | p_bar.set_description("classifier_train: epoch:{} batch_idx:{} loss:{:02f}".format(epoch, batch_idx, loss)) 62 | p_bar.update() 63 | 
if epoch % 4 == 0 and epoch != 0: 64 | acc = test(model,classifier,val_loader) 65 | if acc>max_acc: 66 | print("max_acc:",acc) 67 | torch.save(model.state_dict(),os.path.join(save_path,'model_state.pth')) 68 | torch.save(classifier.state_dict(),os.path.join(save_path,'classifier_state.pth')) 69 | max_acc = acc 70 | 71 | acc = test(model,classifier,val_loader) 72 | if acc>max_acc: 73 | print("max_acc:",acc) 74 | torch.save(model.state_dict(),os.path.join(save_path,'model_state.pth')) 75 | torch.save(classifier.state_dict(),os.path.join(save_path,'classifier_state.pth')) 76 | max_acc = acc 77 | 78 | 79 | 80 | def test(model,classifier,data_loader,snrs = None,snr_indexs = None,save_path = None,final = False): 81 | prediction = [] 82 | true = [] 83 | if final: 84 | model.load_state_dict(torch.load(os.path.join(save_path,'model_state.pth'))) 85 | classifier.load_state_dict(torch.load(os.path.join(save_path,'classifier_state.pth'))) 86 | for batch_idx, (singals, CWTs, labels) in enumerate(data_loader): 87 | singals = singals.cuda() 88 | videos = converter(singals) 89 | CWTs = CWTs.cuda() 90 | labels = labels.cuda() 91 | singals_feature,CWTs_feature,videos_feature = model(singals,CWTs,videos) 92 | feature = torch.cat([singals_feature,CWTs_feature,videos_feature],dim = -1) 93 | result = torch.argmax(classifier(feature),dim = 1) 94 | prediction.extend(result.cpu().numpy()) 95 | true.extend(labels.cpu().numpy()) 96 | if not final: 97 | return sum(np.array(prediction) == np.array(true))/len(true) 98 | else: 99 | prediction = np.array(prediction) 100 | true = np.array(true) 101 | acc = {} 102 | for i in range(len(snrs)): 103 | true_label = true[snr_indexs[i]] 104 | #print(true_label.shape) 105 | pre_label = prediction[snr_indexs[i]] 106 | cor = np.sum(true_label == pre_label) 107 | acc[snrs[i]] = 1.0 * cor / true_label.shape[0] 108 | total_acc = sum(np.array(prediction) == np.array(true))/len(true) 109 | acc['total'] = total_acc 110 | ACC = pd.DataFrame(acc.items()) 111 | ACC.to_csv(os.path.join(save_path,'result.csv')) 112 | return total_acc 113 | 114 | 115 | 116 | bsz = 100 117 | traindataset,valset,testset,classes,snrs,snr_indexs = get_signal('/root/autodl-tmp/RML2016.10a_dict.pkl','/root/autodl-tmp/CWTdata.pkl',1,L=1,snrs_index=0) 118 | ctraindataset,cvalset,ctestset,cclasses,csnrs,csnr_indexs = get_signal('/root/autodl-tmp/RML2016.10a_dict.pkl','/root/autodl-tmp/CWTdata.pkl',0.3,L=1,snrs_index=0) 119 | 120 | 121 | converter = video_Converter(bsz,14) 122 | train_loader = torch.utils.data.DataLoader(traindataset, batch_size=bsz, shuffle=True) 123 | val_loader = torch.utils.data.DataLoader(valset, batch_size=bsz, shuffle=False) 124 | test_loader = torch.utils.data.DataLoader(testset, batch_size=bsz, shuffle=False) 125 | 126 | ctrain_loader = torch.utils.data.DataLoader(ctraindataset, batch_size=bsz, shuffle=True) 127 | cval_loader = torch.utils.data.DataLoader(cvalset, batch_size=bsz, shuffle=False) 128 | ctest_loader = torch.utils.data.DataLoader(ctestset, batch_size=bsz, shuffle=False) 129 | 130 | root_path = '/root/autodl-tmp/不同标签量实验/不微调' 131 | 132 | 133 | 134 | for epoch in [5,10,15,20,50,100]: 135 | model = Model(160,16,2).cuda() 136 | projection_head = Proj(160,160).cuda()#160或者128 137 | model.train() 138 | projection_head.train() 139 | classifier = classifier_head(3*160,11).cuda() 140 | classifier.train() 141 | 142 | save_path = os.path.join(root_path,str(epoch)) 143 | if not os.path.exists(save_path): 144 | os.mkdir(save_path) 145 | 146 | optimizer = 
torch.optim.Adam(list(model.parameters())+list(projection_head.parameters()),lr = 0.0005) 147 | pretrain(model,projection_head,epoch,converter,optimizer,train_loader,save_path) 148 | optimizer = torch.optim.Adam(classifier.parameters(), lr = 0.001) 149 | classifier_train(model,classifier,20,converter,optimizer,ctrain_loader,cval_loader,save_path) 150 | test(model,classifier,ctest_loader,csnrs,csnr_indexs,save_path,final = True) 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | -------------------------------------------------------------------------------- /Pretrain/models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import pickle as pk 6 | import numpy as np 7 | from utils import * 8 | from tqdm import tqdm 9 | from signal_data_add import get_signal 10 | from torch.utils.data import DataLoader 11 | import pywt 12 | import matplotlib.pyplot as plt 13 | from patch_encoder_and_projection_head_and_classifier import * 14 | from pytorch_pretrained_vit import ViT 15 | import os 16 | 17 | class Model(nn.Module): 18 | def __init__(self,feature_dim,nh,nl): 19 | super().__init__() 20 | self.CWTs_emb = patch_embedding_for_CWTs(feature_dim) 21 | self.singals_emb = patch_embedding_for_singals(feature_dim) 22 | self.videos_emb = patch_embedding_for_videos(feature_dim) 23 | #特征提取器 24 | self.extractor = ViT(feature_dim,128,num_heads=nh,num_layers=nl) 25 | 26 | #layer_norm层(输入) 27 | self.singals_norm_layer = nn.LayerNorm([2,128])#个体初始化一致性 28 | self.CWTs_norm_layer = nn.LayerNorm([2,99,128]) 29 | self.video_norm_layer = nn.LayerNorm([128,1,14,14]) 30 | #layer_norm层(特征) 31 | self.feature_norm_layer = nn.LayerNorm([128,feature_dim]).cuda()#以特征为归一化标准,不能以个体吧?如果以个体,那所有个体都一样了 32 | 33 | def forward(self,singals,CWTs,videos): 34 | #norm 35 | singals = self.singals_norm_layer(singals) 36 | CWTs = self.CWTs_norm_layer(CWTs) 37 | videos = self.video_norm_layer(videos) 38 | 39 | singals_embed = self.feature_norm_layer(self.singals_emb(singals)) 40 | CWTs_embed = self.feature_norm_layer(self.CWTs_emb(CWTs))#每一种模态的特征在进入transformer前,都会进行layer norm,以确保特征间的一致性 41 | videos_embed = self.feature_norm_layer(self.videos_emb(videos)) 42 | #layer_norm以每个模态的一个“个体”为单位(可以更换为以一个特征为单位) 43 | singals_feature = self.extractor(singals_embed) 44 | CWTs_feature = self.extractor(CWTs_embed)#这里应该不用加layer norm,因为transformer的最后有norm 45 | videos_feature = self.extractor(videos_embed) 46 | return singals_feature,CWTs_feature,videos_feature 47 | class Proj(nn.Module): 48 | def __init__(self,in_feature_dim,out_feature_dim): 49 | super().__init__() 50 | #projection head 51 | self.singals_proj = prejection_head(in_feature_dim,out_feature_dim) 52 | self.CWTs_proj = prejection_head(in_feature_dim,out_feature_dim) 53 | self.videos_proj = prejection_head(in_feature_dim,out_feature_dim) 54 | 55 | def forward(self,singals_feature,CWTs_feature,videos_feature): 56 | singals_projection = self.singals_proj(singals_feature) 57 | CWTs_projection = self.CWTs_proj(CWTs_feature)#这里应该不用加layer norm,因为transformer的最后有norm 58 | videos_projection = self.videos_proj(videos_feature) 59 | 60 | return singals_projection,CWTs_projection,videos_projection 61 | 62 | 63 | -------------------------------------------------------------------------------- /Pretrain/patch_encoder_and_projection_head_and_classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 
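# ---------------------------------------------------------------------------
# Shape reference for the modules below (inferred from how they are used in
# models.py and Pretrain/main.py, where feature_dim = 160; B = batch size):
#   patch_embedding_for_singals : (B, 2, 128) I/Q signal
#       -> unsqueeze(1) -> Conv2d(1, dim, (2, 3), padding=(0, 1)) -> (B, 128, dim)
#   patch_embedding_for_CWTs    : (B, 2, 99, 128) CWT scalogram
#       -> Conv2d(2, dim, (99, 3), padding=(0, 1)) -> (B, 128, dim)
#   patch_embedding_for_videos  : (B, 128, 1, 14, 14) constellation "video"
#       -> permute -> Conv3d(1, dim, (3, 14, 14), padding=(1, 0, 0)) -> (B, 128, dim)
# Every view is therefore mapped to the same 128-token sequence of dim-sized
# embeddings, which is what ViT(feature_dim, 128, ...) consumes downstream.
# patch_embedding_for_images appears unused by the provided models.py, and
# prejection_head / classifier_head are plain two-layer (Linear, ReLU, Linear) MLPs.
# ---------------------------------------------------------------------------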
3 | 4 | class patch_embedding_for_CWTs(nn.Module): 5 | def __init__(self,dim): 6 | super().__init__() 7 | self.conv = nn.Conv2d(2,dim,kernel_size = (99,3),padding = (0,1)) 8 | 9 | def forward(self,x): 10 | x = self.conv(x) 11 | x = x.flatten(2).transpose(1, 2) 12 | return x 13 | 14 | class patch_embedding_for_singals(nn.Module): 15 | def __init__(self,dim): 16 | super().__init__() 17 | self.conv = nn.Conv2d(1,dim,kernel_size = (2,3),padding = (0,1)) 18 | 19 | def forward(self,x): 20 | x = x.unsqueeze(1) 21 | x = self.conv(x) 22 | x = x.flatten(2).transpose(1, 2) 23 | return x 24 | 25 | class patch_embedding_for_videos(nn.Module): 26 | def __init__(self,dim): 27 | super().__init__() 28 | self.conv = nn.Conv3d(1,dim,kernel_size = (3,14,14),padding = (1,0,0)) 29 | 30 | def forward(self,x): 31 | x = x.permute(0,2,1,3,4) 32 | x = self.conv(x) 33 | x = x.flatten(2).transpose(1, 2) 34 | return x 35 | 36 | class patch_embedding_for_images(nn.Module): 37 | def __init__(self,dim): 38 | super().__init__() 39 | self.conv = nn.Conv2d(1,dim,kernel_size = (14,14),padding = 0) 40 | 41 | def forward(self,x): 42 | x = self.conv(x) 43 | x = x.flatten(2).transpose(1, 2).repeat(1, 128, 1) 44 | return x 45 | 46 | class classifier_head(nn.Module): 47 | def __init__(self,in_dim,out_dim): 48 | super().__init__() 49 | self.fc1 = nn.Linear(in_dim,in_dim) 50 | self.fc2 = nn.Linear(in_dim,out_dim) 51 | 52 | def forward(self,x): 53 | return self.fc2(nn.functional.relu(self.fc1(x))) 54 | 55 | 56 | class prejection_head(nn.Module): 57 | def __init__(self,in_dim,out_dim): 58 | super().__init__() 59 | self.fc1 = nn.Linear(in_dim,in_dim) 60 | self.fc2 = nn.Linear(in_dim,out_dim) 61 | 62 | def forward(self,x): 63 | return self.fc2(nn.functional.relu(self.fc1(x))) -------------------------------------------------------------------------------- /Pretrain/pytorch_pretrained_vit/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.7" 2 | 3 | from .model import ViT 4 | from .configs import * 5 | from .utils import load_pretrained_weights 6 | -------------------------------------------------------------------------------- /Pretrain/pytorch_pretrained_vit/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SIM-xidian/Hybrid-View-Self-supervised-Framework-for-Automatic-Modulation-Recognition/e6c84b1fd55c0be4f1d4531761eab1750d7a2182/Pretrain/pytorch_pretrained_vit/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /Pretrain/pytorch_pretrained_vit/__pycache__/configs.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SIM-xidian/Hybrid-View-Self-supervised-Framework-for-Automatic-Modulation-Recognition/e6c84b1fd55c0be4f1d4531761eab1750d7a2182/Pretrain/pytorch_pretrained_vit/__pycache__/configs.cpython-38.pyc -------------------------------------------------------------------------------- /Pretrain/pytorch_pretrained_vit/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SIM-xidian/Hybrid-View-Self-supervised-Framework-for-Automatic-Modulation-Recognition/e6c84b1fd55c0be4f1d4531761eab1750d7a2182/Pretrain/pytorch_pretrained_vit/__pycache__/model.cpython-38.pyc 
-------------------------------------------------------------------------------- /Pretrain/pytorch_pretrained_vit/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SIM-xidian/Hybrid-View-Self-supervised-Framework-for-Automatic-Modulation-Recognition/e6c84b1fd55c0be4f1d4531761eab1750d7a2182/Pretrain/pytorch_pretrained_vit/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /Pretrain/pytorch_pretrained_vit/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SIM-xidian/Hybrid-View-Self-supervised-Framework-for-Automatic-Modulation-Recognition/e6c84b1fd55c0be4f1d4531761eab1750d7a2182/Pretrain/pytorch_pretrained_vit/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /Pretrain/pytorch_pretrained_vit/configs.py: -------------------------------------------------------------------------------- 1 | """configs.py - ViT model configurations, based on: 2 | https://github.com/google-research/vision_transformer/blob/master/vit_jax/configs.py 3 | """ 4 | 5 | def get_base_config(): 6 | """Base ViT config ViT""" 7 | return dict( 8 | dim=768, 9 | ff_dim=3072, 10 | num_heads=12, 11 | num_layers=12, 12 | attention_dropout_rate=0.0, 13 | dropout_rate=0.1, 14 | representation_size=768, 15 | classifier='token' 16 | ) 17 | 18 | def get_b16_config(): 19 | """Returns the ViT-B/16 configuration.""" 20 | config = get_base_config() 21 | config.update(dict(patches=(16, 16))) 22 | return config 23 | 24 | def get_b32_config(): 25 | """Returns the ViT-B/32 configuration.""" 26 | config = get_b16_config() 27 | config.update(dict(patches=(32, 32))) 28 | return config 29 | 30 | def get_l16_config(): 31 | """Returns the ViT-L/16 configuration.""" 32 | config = get_base_config() 33 | config.update(dict( 34 | patches=(16, 16), 35 | dim=1024, 36 | ff_dim=4096, 37 | num_heads=16, 38 | num_layers=24, 39 | attention_dropout_rate=0.0, 40 | dropout_rate=0.1, 41 | representation_size=1024 42 | )) 43 | return config 44 | 45 | def get_l32_config(): 46 | """Returns the ViT-L/32 configuration.""" 47 | config = get_l16_config() 48 | config.update(dict(patches=(32, 32))) 49 | return config 50 | 51 | def drop_head_variant(config): 52 | config.update(dict(representation_size=None)) 53 | return config 54 | 55 | 56 | PRETRAINED_MODELS = { 57 | 'B_16': { 58 | 'config': get_b16_config(), 59 | 'num_classes': 21843, 60 | 'image_size': (224, 224), 61 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_16.pth" 62 | }, 63 | 'B_32': { 64 | 'config': get_b32_config(), 65 | 'num_classes': 21843, 66 | 'image_size': (224, 224), 67 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_32.pth" 68 | }, 69 | 'L_16': { 70 | 'config': get_l16_config(), 71 | 'num_classes': 21843, 72 | 'image_size': (224, 224), 73 | 'url': None 74 | }, 75 | 'L_32': { 76 | 'config': get_l32_config(), 77 | 'num_classes': 21843, 78 | 'image_size': (224, 224), 79 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/L_32.pth" 80 | }, 81 | 'B_16_imagenet1k': { 82 | 'config': drop_head_variant(get_b16_config()), 83 | 'num_classes': 1000, 84 | 'image_size': (384, 384), 85 | 'url': 
"https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_16_imagenet1k.pth" 86 | }, 87 | 'B_32_imagenet1k': { 88 | 'config': drop_head_variant(get_b32_config()), 89 | 'num_classes': 1000, 90 | 'image_size': (384, 384), 91 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_32_imagenet1k.pth" 92 | }, 93 | 'L_16_imagenet1k': { 94 | 'config': drop_head_variant(get_l16_config()), 95 | 'num_classes': 1000, 96 | 'image_size': (384, 384), 97 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/L_16_imagenet1k.pth" 98 | }, 99 | 'L_32_imagenet1k': { 100 | 'config': drop_head_variant(get_l32_config()), 101 | 'num_classes': 1000, 102 | 'image_size': (384, 384), 103 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/L_32_imagenet1k.pth" 104 | }, 105 | } 106 | -------------------------------------------------------------------------------- /Pretrain/pytorch_pretrained_vit/model.py: -------------------------------------------------------------------------------- 1 | """model.py - Model and module class for ViT. 2 | They are built to mirror those in the official Jax implementation. 3 | """ 4 | 5 | from typing import Optional 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | from .transformer import Transformer 11 | from .utils import load_pretrained_weights, as_tuple 12 | from .configs import PRETRAINED_MODELS 13 | 14 | 15 | class PositionalEmbedding1D(nn.Module): 16 | """Adds (optionally learned) positional embeddings to the inputs.""" 17 | 18 | def __init__(self, seq_len, dim): 19 | super().__init__() 20 | self.pos_embedding = nn.Parameter(torch.zeros(1, seq_len, dim)) 21 | 22 | def forward(self, x): 23 | """Input has shape `(batch_size, seq_len, emb_dim)`""" 24 | return x + self.pos_embedding 25 | 26 | 27 | class ViT(nn.Module): 28 | def __init__( 29 | self, 30 | dim, 31 | seq_len, 32 | num_heads, 33 | num_layers, 34 | dropout_rate = 0.1, 35 | ): 36 | super().__init__() 37 | self.class_token = nn.Parameter(torch.zeros(1, 1, dim)) 38 | seq_len += 1 39 | self.positional_embedding = PositionalEmbedding1D(seq_len, dim) 40 | # Transformer 41 | self.transformer = Transformer(num_layers=num_layers, dim=dim, num_heads=num_heads, 42 | ff_dim=4*dim, dropout=dropout_rate) 43 | # Classifier head 44 | self.norm = nn.LayerNorm(dim, eps=1e-6) 45 | # Initialize weights 46 | self.init_weights() 47 | 48 | @torch.no_grad() 49 | def init_weights(self): 50 | def _init(m): 51 | if isinstance(m, nn.Linear): 52 | nn.init.xavier_uniform_(m.weight) # _trunc_normal(m.weight, std=0.02) # from .initialization import _trunc_normal 53 | if hasattr(m, 'bias') and m.bias is not None: 54 | nn.init.normal_(m.bias, std=1e-6) # nn.init.constant(m.bias, 0) 55 | self.apply(_init) 56 | #nn.init.constant_(self.fc.weight, 0) 57 | #nn.init.constant_(self.fc.bias, 0) 58 | nn.init.normal_(self.positional_embedding.pos_embedding, std=0.02) # _trunc_normal(self.positional_embedding.pos_embedding, std=0.02) 59 | nn.init.constant_(self.class_token, 0) 60 | 61 | def forward(self, x): 62 | b = x.shape[0] 63 | x = torch.cat((self.class_token.expand(b, -1, -1), x), dim=1) # b,gh*gw+1,d 64 | x = self.positional_embedding(x) # b,gh*gw+1,d (+1:Patch + Position Embedding 65 | x = self.transformer(x) # b,gh*gw+1,d 有数层transformer block 66 | x = self.norm(x)[:, 0] # b,d:每个图片仅使用class embedding作为特征 67 | return x 68 | 69 | 
-------------------------------------------------------------------------------- /Pretrain/pytorch_pretrained_vit/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/lukemelas/simple-bert 3 | """ 4 | 5 | import numpy as np 6 | from torch import nn 7 | from torch import Tensor 8 | from torch.nn import functional as F 9 | 10 | 11 | def split_last(x, shape): 12 | "split the last dimension to given shape" 13 | shape = list(shape) 14 | assert shape.count(-1) <= 1 15 | if -1 in shape: 16 | shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape)) 17 | return x.view(*x.size()[:-1], *shape) 18 | 19 | 20 | def merge_last(x, n_dims): 21 | "merge the last n_dims to a dimension" 22 | s = x.size() 23 | assert n_dims > 1 and n_dims < len(s) 24 | return x.view(*s[:-n_dims], -1) 25 | 26 | 27 | class MultiHeadedSelfAttention(nn.Module): 28 | """Multi-Headed Dot Product Attention""" 29 | def __init__(self, dim, num_heads, dropout): 30 | super().__init__() 31 | self.proj_q = nn.Linear(dim, dim) 32 | self.proj_k = nn.Linear(dim, dim) 33 | self.proj_v = nn.Linear(dim, dim) 34 | self.drop = nn.Dropout(dropout) 35 | self.n_heads = num_heads 36 | self.scores = None # for visualization 37 | 38 | def forward(self, x, mask): 39 | """ 40 | x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim)) 41 | mask : (B(batch_size) x S(seq_len)) 42 | * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W 43 | """ 44 | # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W) 45 | q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x) 46 | q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v]) 47 | # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S) 48 | scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1)) 49 | if mask is not None: 50 | mask = mask[:, None, None, :].float() 51 | scores -= 10000.0 * (1.0 - mask) 52 | scores = self.drop(F.softmax(scores, dim=-1)) 53 | # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W) 54 | h = (scores @ v).transpose(1, 2).contiguous() 55 | # -merge-> (B, S, D) 56 | h = merge_last(h, 2) 57 | self.scores = scores 58 | return h 59 | 60 | 61 | class PositionWiseFeedForward(nn.Module): 62 | """FeedForward Neural Networks for each position""" 63 | def __init__(self, dim, ff_dim): 64 | super().__init__() 65 | self.fc1 = nn.Linear(dim, ff_dim) 66 | self.fc2 = nn.Linear(ff_dim, dim) 67 | 68 | def forward(self, x): 69 | # (B, S, D) -> (B, S, D_ff) -> (B, S, D) 70 | return self.fc2(F.gelu(self.fc1(x))) 71 | 72 | 73 | class Block(nn.Module): 74 | """Transformer Block""" 75 | def __init__(self, dim, num_heads, ff_dim, dropout): 76 | super().__init__() 77 | self.attn = MultiHeadedSelfAttention(dim, num_heads, dropout) 78 | self.proj = nn.Linear(dim, dim) 79 | self.norm1 = nn.LayerNorm(dim, eps=1e-6) 80 | self.pwff = PositionWiseFeedForward(dim, ff_dim) 81 | self.norm2 = nn.LayerNorm(dim, eps=1e-6) 82 | self.drop = nn.Dropout(dropout) 83 | 84 | def forward(self, x, mask): 85 | h = self.drop(self.proj(self.attn(self.norm1(x), mask))) 86 | x = x + h 87 | h = self.drop(self.pwff(self.norm2(x))) 88 | x = x + h 89 | return x 90 | 91 | 92 | class Transformer(nn.Module): 93 | """Transformer with Self-Attentive Blocks""" 94 | def __init__(self, num_layers, dim, num_heads, ff_dim, dropout): 95 | super().__init__() 96 | self.blocks = nn.ModuleList([ 97 | Block(dim, num_heads, ff_dim, dropout) for _ in range(num_layers)]) 98 | 99 | def 
forward(self, x, mask=None): 100 | for block in self.blocks: 101 | x = block(x, mask) 102 | return x 103 | -------------------------------------------------------------------------------- /Pretrain/pytorch_pretrained_vit/utils.py: -------------------------------------------------------------------------------- 1 | """utils.py - Helper functions 2 | """ 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils import model_zoo 7 | 8 | from .configs import PRETRAINED_MODELS 9 | 10 | 11 | def load_pretrained_weights( 12 | model, 13 | model_name=None, 14 | weights_path=None, 15 | load_first_conv=True, 16 | load_fc=True, 17 | load_repr_layer=False, 18 | resize_positional_embedding=False, 19 | verbose=True, 20 | strict=True, 21 | ): 22 | """Loads pretrained weights from weights path or download using url. 23 | Args: 24 | model (Module): Full model (a nn.Module) 25 | model_name (str): Model name (e.g. B_16) 26 | weights_path (None or str): 27 | str: path to pretrained weights file on the local disk. 28 | None: use pretrained weights downloaded from the Internet. 29 | load_first_conv (bool): Whether to load patch embedding. 30 | load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model. 31 | resize_positional_embedding=False, 32 | verbose (bool): Whether to print on completion 33 | """ 34 | assert bool(model_name) ^ bool(weights_path), 'Expected exactly one of model_name or weights_path' 35 | 36 | # Load or download weights 37 | if weights_path is None: 38 | url = PRETRAINED_MODELS[model_name]['url'] 39 | if url: 40 | state_dict = model_zoo.load_url(url) 41 | else: 42 | raise ValueError(f'Pretrained model for {model_name} has not yet been released') 43 | else: 44 | state_dict = torch.load(weights_path) 45 | 46 | # Modifications to load partial state dict 47 | expected_missing_keys = [] 48 | if not load_first_conv and 'patch_embedding.weight' in state_dict: 49 | expected_missing_keys += ['patch_embedding.weight', 'patch_embedding.bias'] 50 | if not load_fc and 'fc.weight' in state_dict: 51 | expected_missing_keys += ['fc.weight', 'fc.bias'] 52 | if not load_repr_layer and 'pre_logits.weight' in state_dict: 53 | expected_missing_keys += ['pre_logits.weight', 'pre_logits.bias'] 54 | for key in expected_missing_keys: 55 | state_dict.pop(key) 56 | 57 | # Change size of positional embeddings 58 | if resize_positional_embedding: 59 | posemb = state_dict['positional_embedding.pos_embedding'] 60 | posemb_new = model.state_dict()['positional_embedding.pos_embedding'] 61 | state_dict['positional_embedding.pos_embedding'] = \ 62 | resize_positional_embedding_(posemb=posemb, posemb_new=posemb_new, 63 | has_class_token=hasattr(model, 'class_token')) 64 | maybe_print('Resized positional embeddings from {} to {}'.format( 65 | posemb.shape, posemb_new.shape), verbose) 66 | 67 | # Load state dict 68 | ret = model.load_state_dict(state_dict, strict=False) 69 | if strict: 70 | assert set(ret.missing_keys) == set(expected_missing_keys), \ 71 | 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys) 72 | assert not ret.unexpected_keys, \ 73 | 'Missing keys when loading pretrained weights: {}'.format(ret.unexpected_keys) 74 | maybe_print('Loaded pretrained weights.', verbose) 75 | else: 76 | maybe_print('Missing keys when loading pretrained weights: {}'.format(ret.missing_keys), verbose) 77 | maybe_print('Unexpected keys when loading pretrained weights: {}'.format(ret.unexpected_keys), verbose) 78 | return ret 79 | 80 | 81 | def maybe_print(s: str, flag: bool): 82 
| if flag: 83 | print(s) 84 | 85 | 86 | def as_tuple(x): 87 | return x if isinstance(x, tuple) else (x, x) 88 | 89 | 90 | def resize_positional_embedding_(posemb, posemb_new, has_class_token=True): 91 | """Rescale the grid of position embeddings in a sensible manner""" 92 | from scipy.ndimage import zoom 93 | 94 | # Deal with class token 95 | ntok_new = posemb_new.shape[1] 96 | if has_class_token: # this means classifier == 'token' 97 | posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] 98 | ntok_new -= 1 99 | else: 100 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 101 | 102 | # Get old and new grid sizes 103 | gs_old = int(np.sqrt(len(posemb_grid))) 104 | gs_new = int(np.sqrt(ntok_new)) 105 | posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) 106 | 107 | # Rescale grid 108 | zoom_factor = (gs_new / gs_old, gs_new / gs_old, 1) 109 | posemb_grid = zoom(posemb_grid, zoom_factor, order=1) 110 | posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) 111 | posemb_grid = torch.from_numpy(posemb_grid) 112 | 113 | # Deal with class token and return 114 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 115 | return posemb 116 | 117 | -------------------------------------------------------------------------------- /Pretrain/signal_data_add.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import torch 4 | import pickle as pk 5 | import torch 6 | 7 | def get_signal(dir1,dir2,rate,L=30,snrs_index=0): 8 | #dir1:原信号 9 | #dir2,dir3:模态间信号,dir2为小波变换能量图,dir3为星座图 10 | #dir4:序列模态对比信号,是原信号经time warping变换后的信号 11 | 12 | f1 = open(dir1, 'rb') 13 | f2 = open(dir2, 'rb') 14 | 15 | data = pk.load(f1, encoding='latin1') 16 | all_snrs, mods = map(lambda j: sorted(list(set(map(lambda x: x[j], data.keys())))), [1, 0]) 17 | 18 | snrs = all_snrs[snrs_index:] 19 | not_used_snrs = all_snrs[:snrs_index] 20 | print("使用的snrs: ",snrs) 21 | print("未使用的snr:",not_used_snrs) 22 | for mod in mods: 23 | for snr in not_used_snrs: 24 | del data[(mod, snr)] 25 | print("原始信号读取完成") 26 | CWTdata = pk.load(f2, encoding='latin1') 27 | for mod in mods: 28 | for snr in not_used_snrs: 29 | del CWTdata[(mod, snr)] 30 | print("CWT读取完成") 31 | 32 | 33 | 34 | snr_choise = [10] 35 | X = [] 36 | CWT = [] 37 | lbl = [] 38 | train_idx = [] 39 | lbl_idx = [] 40 | val_idx = [] 41 | data_size = 0 42 | test_idx=[] 43 | 44 | for mod in mods: 45 | for snr in snrs: 46 | length = data[(mod, snr)].shape[0] 47 | X.append(data.pop((mod, snr))) 48 | CWT.append(CWTdata.pop((mod, snr))) 49 | for i in range(length): lbl.append((mod, snr)) 50 | train_choise = np.random.choice(range(data_size, data_size + length), size=int(length * 0.6 * rate), replace=False) 51 | train_idx += list(train_choise) 52 | if snr in snr_choise: 53 | lbl_idx += list(np.random.choice(train_choise, size=L, replace=False)) 54 | 55 | val_idx += list( 56 | np.random.choice(list(set(range(data_size, data_size + length)) - set(train_idx)), 57 | size=int(length * 0.2), replace=False)) 58 | 59 | test_idx += list( 60 | np.random.choice(list(set(range(data_size, data_size + length)) - set(train_idx) - set(val_idx)), 61 | size=int(length * 0.2), replace=False)) 62 | 63 | data_size += length 64 | 65 | 66 | print("每一类中有{}个训练样本".format(length * 0.6 * rate)) 67 | 68 | X = np.vstack(X) 69 | CWT = np.vstack(CWT) 70 | 71 | #X = np.expand_dims(X, axis=1) 72 | print("X.shape",X.shape) 73 | print("CWT.shape",CWT.shape) 74 | 75 | X_train = X[train_idx] 76 | X_val = X[val_idx] 77 | X_test = X[test_idx] 78 | del X 79 | 
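    # NOTE (reading aid): the loop above draws, per (mod, snr) block, int(0.6 * rate * length)
    # training indices plus 20% validation and 20% test indices; the same index lists slice
    # both the raw I/Q signals (X, above) and the CWT scalograms (below), so each sample keeps
    # its label and SNR aligned across the two views.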
80 | CWT_train = CWT[train_idx] 81 | CWT_val = CWT[val_idx] 82 | CWT_test = CWT[test_idx] 83 | del CWT 84 | 85 | Y_train = np.array(list(map(lambda x: mods.index(lbl[x][0]), train_idx))) 86 | Y_val = np.array(list(map(lambda x: mods.index(lbl[x][0]), val_idx))) 87 | Y_test = np.array(list(map(lambda x: mods.index(lbl[x][0]), test_idx))) 88 | 89 | 90 | 91 | traindataset = arr_to_dataset(X_train, CWT_train, Y_train) 92 | 93 | valdataset = arr_to_dataset(X_val, CWT_val, Y_val) 94 | 95 | testdataset = arr_to_dataset(X_test, CWT_test, Y_test) 96 | 97 | 98 | snr_index = np.array(list(map(lambda x: lbl[x][1], test_idx))) 99 | 100 | snr_indexs = [] 101 | for snr in snrs: 102 | snr_indexs.extend(np.where(snr_index == snr)) 103 | 104 | return traindataset,valdataset,testdataset, mods,snrs,snr_indexs 105 | 106 | 107 | def arr_to_dataset(data1, data2, label): 108 | data1 = torch.from_numpy(data1) 109 | data2 = torch.from_numpy(data2) 110 | 111 | 112 | label = torch.from_numpy(label) 113 | dataset = torch.utils.data.TensorDataset(data1,data2,label) 114 | return dataset 115 | 116 | -------------------------------------------------------------------------------- /Pretrain/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import torch 4 | from torch.utils.data import Dataset 5 | import os 6 | import numpy as np 7 | from sklearn.metrics import confusion_matrix 8 | import matplotlib.pyplot as plt 9 | 10 | def get_exp(len=4000): 11 | singals = [] 12 | labels = [] 13 | for i in range(len): 14 | k = 0 15 | l = random.randint(0,7) 16 | if l == 0: 17 | phi = math.pi/12 18 | labels.append(0) 19 | elif l == 1: 20 | phi = math.pi/10 21 | labels.append(1) 22 | 23 | elif l == 2: 24 | phi = math.pi/8 25 | labels.append(2) 26 | 27 | elif l ==3: 28 | phi = math.pi/6 29 | labels.append(3) 30 | 31 | elif l ==4: 32 | phi = math.pi/4 33 | labels.append(4) 34 | 35 | 36 | elif l ==5: 37 | phi = math.pi/14 38 | labels.append(5) 39 | 40 | elif l ==6: 41 | phi = math.pi/16 42 | labels.append(6) 43 | 44 | 45 | 46 | else: 47 | phi = math.pi/2 48 | labels.append(7) 49 | 50 | singal_x = [] 51 | singal_y = [] 52 | for j in range(128): 53 | k = k + random.randint(0,5) 54 | singal_x.append(math.cos(k*phi)) 55 | singal_y.append(math.sin(k*phi)) 56 | singals.append([singal_x,singal_y]) 57 | 58 | 59 | Singal = torch.tensor([]) 60 | for i in range(len): 61 | if i == 0: 62 | Singal = torch.tensor([singals[i]]) 63 | else: 64 | Singal = torch.cat([Singal,torch.tensor([singals[i]])],dim=0) 65 | Labels = torch.tensor(labels) 66 | 67 | return Singal,Labels 68 | 69 | 70 | 71 | def compute_loss(pre,target): 72 | l2 = torch.mm(torch.norm(target,dim=1).unsqueeze(1),torch.norm(pre,dim=1).unsqueeze(0)) 73 | #pre = pre + 1e-5 74 | bsz = target.shape[0] 75 | feature_dim = target.shape[1] 76 | target = target.unsqueeze(1).expand(bsz, bsz, feature_dim) 77 | pre = pre.unsqueeze(0).expand(bsz, bsz, feature_dim) 78 | # 对 A 中每个向量与 B 中每个向量进行点积 79 | dot_product = torch.matmul(target, pre.transpose(1, 2)) 80 | # 将点积结果保存为矩阵形式 81 | result = dot_product.squeeze() 82 | result = result[:,0,:] 83 | result = torch.div(result,l2) 84 | result = torch.div(result,0.07) 85 | #print("result",result) 86 | result = torch.exp(result) 87 | #print(result) 88 | diag = torch.diag(result) 89 | #print(diag) 90 | total_lic = torch.sum(result,dim=0) 91 | #print(total_lic) 92 | lic = torch.div(diag,total_lic) 93 | lic = -torch.log(lic) 94 | #print("lic",lic) 95 | return torch.sum(lic)/lic.shape[0] 96 | 97 | 
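# ---------------------------------------------------------------------------
# Reading aid for compute_loss above: it is an InfoNCE-style contrastive loss
# between two views of the same batch. result[i, j] ends up as
# cos(target_i, pre_j) / 0.07; the matching pair (j, j) is the positive for
# column j and the remaining batch entries act as negatives:
#     loss = mean_j [ -log( exp(s_jj) / sum_i exp(s_ij) ) ],  s_ij = cos(t_i, p_j) / tau
# A compact equivalent sketch (assuming torch.nn.functional as F) would be:
#     S = (F.normalize(target, dim=1) @ F.normalize(pre, dim=1).t()) / 0.07
#     loss = F.cross_entropy(S.t(), torch.arange(target.shape[0], device=S.device))
# ---------------------------------------------------------------------------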
98 | class MyDataset(Dataset): 99 | def __init__(self, data, labels): 100 | self.data = data 101 | self.labels = labels 102 | 103 | def __len__(self): 104 | return len(self.labels) 105 | 106 | def __getitem__(self, index): 107 | x = self.data[index] 108 | y = self.labels[index] 109 | return x, y 110 | 111 | 112 | def accuracy(output, target, topk=(1,)): 113 | """Computes the precision@k for the specified values of k""" 114 | maxk = max(topk) 115 | batch_size = target.size(0) 116 | 117 | _, pred = output.topk(maxk, 1, True, True) 118 | pred = pred.t() 119 | 120 | 121 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 122 | 123 | res = [] 124 | for k in topk: 125 | correct_k = correct[:k].contiguous().view(-1).float().sum(0) 126 | res.append(correct_k.mul_(100.0 / batch_size)) 127 | return res 128 | 129 | 130 | 131 | 132 | def figure_plot(true_labels,pre_labels,classes,snrs,snr_indexs,figure_path=None): 133 | acc = {} 134 | acc_mod_snr = np.zeros((len(classes), len(snrs))) 135 | for i in range(len(snrs)): 136 | true_label = true_labels[snr_indexs[i]] 137 | #print(true_label.shape) 138 | pre_label = pre_labels[snr_indexs[i]] 139 | cor = np.sum(true_label == pre_label) 140 | acc[snrs[i]] = 1.0 * cor / true_label.shape[0] 141 | 142 | plot_confusion_matrix(true_label,pre_label,classes, 143 | title="Confusion Matrix (SNR=%d)(ACC=%2f)" % (snrs[i], 100.0 * acc[snrs[i]]), 144 | save_filename =os.path.join(figure_path,'Confusion(SNR=%d)(ACC=%2f).png' % (snrs[i], 100.0 * acc[snrs[i]]))) 145 | confnorm_i, _, _ = calculate_confusion_matrix(true_label, pre_label, classes) 146 | acc_mod_snr[:, i] = np.round(np.diag(confnorm_i) / np.sum(confnorm_i, axis=1), 3) 147 | 148 | 149 | plt.plot(snrs, list(map(lambda x: acc[x], snrs)),'.-') 150 | 151 | 152 | plt.xlabel("Signal to Noise Ratio") 153 | plt.ylabel("Classification Accuracy") 154 | plt.title("CNN Classification Accuracy on dataset RadioML 2''016.10 Alpha") 155 | plt.savefig(os.path.join(figure_path,'dB to Noise Ratio')) 156 | plt.close() 157 | 158 | # plot acc of each mod in one picture 159 | dis_num = len(classes) 160 | for g in range(int(np.ceil(acc_mod_snr.shape[0] / dis_num))): 161 | assert (0 <= dis_num <= acc_mod_snr.shape[0]) 162 | beg_index = g * dis_num 163 | end_index = np.min([(g + 1) * dis_num, acc_mod_snr.shape[0]]) 164 | 165 | plt.figure(figsize=(12, 10)) 166 | plt.xlabel("Signal to Noise Ratio") 167 | plt.ylabel("Classification Accuracy") 168 | plt.title("Classification Accuracy for Each Mod") 169 | 170 | for i in range(beg_index, end_index): 171 | plt.plot(snrs, acc_mod_snr[i],'.-', label=classes[i]) 172 | # 设置数字标签 173 | for x, y in zip(snrs, acc_mod_snr[i]): 174 | plt.text(x, y, y, ha='center', va='bottom', fontsize=8) 175 | 176 | plt.legend() 177 | plt.grid() 178 | plt.savefig(os.path.join(figure_path,'acc_with_mod.png')) 179 | plt.close() 180 | return acc,acc_mod_snr 181 | 182 | 183 | def calculate_confusion_matrix(Y,Y_hat,classes): 184 | n_classes = len(classes) 185 | conf = np.zeros([n_classes,n_classes]) 186 | confnorm = np.zeros([n_classes,n_classes]) 187 | 188 | for k in range(0,Y.shape[0]): 189 | i = Y[k] 190 | j = Y_hat[k] 191 | conf[i,j] = conf[i,j] + 1 192 | 193 | for i in range(0,n_classes): 194 | confnorm[i,:] = conf[i,:] / np.sum(conf[i,:]) 195 | # print(confnorm) 196 | 197 | right = np.sum(np.diag(conf)) 198 | wrong = np.sum(conf) - right 199 | return confnorm,right,wrong 200 | 201 | 202 | 203 | 204 | def plot_confusion_matrix(y_true, y_pred, labels, save_filename=None, title='Confusion matrix'): 205 | 206 | cmap = 
plt.cm.binary 207 | cm = confusion_matrix(y_true, y_pred) 208 | tick_marks = np.array(range(len(labels))) + 0.5 209 | np.set_printoptions(precision=2) 210 | cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 211 | plt.figure(figsize=(10, 8), dpi=120) 212 | ind_array = np.arange(len(labels)) 213 | x, y = np.meshgrid(ind_array, ind_array) 214 | intFlag = 0 215 | for x_test, y_test in zip(x.flatten(), y.flatten()): 216 | 217 | if (intFlag): 218 | c = cm[y_test][x_test] 219 | plt.text(x_test, y_test, "%d" % (c,), color='red', fontsize=8, va='center', ha='center') 220 | 221 | else: 222 | c = cm_normalized[y_test][x_test] 223 | if (c > 0.01): 224 | #这里是绘制数字,可以对数字大小和颜色进行修改 225 | plt.text(x_test, y_test, "%0.2f" % (c,), color='red', fontsize=10, va='center', ha='center') 226 | else: 227 | plt.text(x_test, y_test, "%d" % (0,), color='red', fontsize=10, va='center', ha='center') 228 | if(intFlag): 229 | plt.imshow(cm, interpolation='nearest', cmap=cmap) 230 | else: 231 | plt.imshow(cm_normalized, interpolation='nearest', cmap=cmap) 232 | plt.gca().set_xticks(tick_marks, minor=True) 233 | plt.gca().set_yticks(tick_marks, minor=True) 234 | plt.gca().xaxis.set_ticks_position('none') 235 | plt.gca().yaxis.set_ticks_position('none') 236 | plt.grid(True, which='minor', linestyle='-') 237 | plt.gcf().subplots_adjust(bottom=0.15) 238 | plt.title(title) 239 | plt.colorbar() 240 | xlocations = np.array(range(len(labels))) 241 | plt.xticks(xlocations, labels, rotation=90) 242 | plt.yticks(xlocations, labels) 243 | plt.ylabel('Index of True Classes') 244 | plt.xlabel('Index of Predict Classes') 245 | plt.savefig(save_filename) 246 | plt.close() 247 | 248 | 249 | class video_Converter: 250 | def __init__(self,batch_bsz,frame_legth): 251 | self.batch_bsz = batch_bsz 252 | self.frame_legth = frame_legth 253 | self.sample_idx = torch.arange(0, batch_bsz).repeat(128, 1).t().reshape(-1).cuda() 254 | self.frame_idx = torch.arange(0, 128).repeat(1, batch_bsz).squeeze().cuda() 255 | self.Fundation = torch.zeros(batch_bsz, 128, self.frame_legth, self.frame_legth).cuda() 256 | self.converter = singal_to_video() 257 | def __call__(self, singal): 258 | x = self.converter(singal, self.batch_bsz, self.frame_legth, self.sample_idx, self.frame_idx, self.Fundation) 259 | x = (x + torch.roll(x, shifts=-1, dims=1) * 0.5 + torch.roll(x, shifts=1, dims=1) * 0.5) 260 | x = x.unsqueeze(2) 261 | return x 262 | 263 | class singal_to_video(object): 264 | def __init__(self): 265 | pass 266 | 267 | def __call__(self, singal, bsz, frame_legth, sample_idx, frame_idx, Fundation): 268 | lists_for_image = torch.transpose(singal, 1, 2) 269 | lists_for_image = torch.stack( 270 | [(a-torch.min(a).item() )/(torch.max(a).item()-torch.min(a).item()) for a in lists_for_image]) # 这里该成了torch.cat 范围是:[-0.5,0.5] 271 | lists_for_image = torch.round(torch.mul(lists_for_image, frame_legth-1)).to(torch.int) 272 | lists_for_image = lists_for_image.reshape(128 * bsz, 2) 273 | lists_for_image = lists_for_image.long() 274 | result = Fundation.zero_() 275 | result[sample_idx, frame_idx, lists_for_image[:, 0], lists_for_image[:, 1]] += 255 276 | 277 | return result -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | 1. Prepare the data: 2 | 1. Download the RML2016a dataset (or other datasets). 3 | 2. Run the generate_CWTdataset.ipynb file to generate the CWTdata.pkl. 4 | 2. Pre-trained Model: 5 | 1. 
Get into the Pretrain folder. 6 | 2. Run main.py. 7 | 3. Comparison Experiments of Different Models: 8 | 1. Get into the Test and Comparison folder. 9 | 2. Run main.py. 10 | -------------------------------------------------------------------------------- /Test and Comparison/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import pickle as pk 5 | import numpy as np 6 | from utils import * 7 | from tqdm import tqdm 8 | from signal_data_add import get_signal 9 | from torch.utils.data import DataLoader 10 | import pywt 11 | import matplotlib.pyplot as plt 12 | from patch_encoder_and_projection_head_and_classifier import * 13 | from pytorch_pretrained_vit import ViT 14 | import os 15 | from models import * 16 | import pandas as pd 17 | import time 18 | 19 | from fvcore.nn import FlopCountAnalysis 20 | 21 | def train(model_name,classifier,epoches,train_loader,val_loader,save_path,model=None,finetunning = False,converter = None): 22 | 23 | if model_name == 'HVSF' or model_name == 'HVSF_wo_finetunning': 24 | if finetunning: 25 | optimizer = torch.optim.Adam([{'params': model.parameters(), 'lr': 0.0003},{'params': classifier.parameters(), 'lr': 0.001}]) 26 | else: 27 | optimizer = torch.optim.Adam(classifier.parameters(), lr= 0.001) 28 | 29 | max_acc = 0 30 | for epoch in range(epoches): 31 | time1 = time.time() 32 | with tqdm(total=len(train_loader)) as p_bar: 33 | for batch_idx, (singals, CWTs, labels) in enumerate(train_loader): 34 | singals = singals.cuda() 35 | videos = converter(singals) 36 | CWTs = CWTs.cuda() 37 | labels = labels.cuda() 38 | singals_feature,CWTs_feature,videos_feature = model(singals,CWTs,videos) 39 | feature = torch.cat([singals_feature,CWTs_feature,videos_feature],dim = -1) 40 | pre = classifier(feature) 41 | loss = F.cross_entropy(pre,labels) 42 | optimizer.zero_grad() 43 | loss.backward() 44 | optimizer.step() 45 | p_bar.set_description("classifier_train: epoch:{} batch_idx:{} loss:{:02f}".format(epoch, batch_idx, loss)) 46 | p_bar.update() 47 | time2 = time.time() 48 | if epoch % 10 == 0 and epoch != 0: 49 | acc = test(model_name,classifier,val_loader,model = model,converter = converter) 50 | if acc>max_acc: 51 | print("max_acc:",acc) 52 | torch.save(model.state_dict(),os.path.join(save_path,'model_state.pth')) 53 | torch.save(classifier.state_dict(),os.path.join(save_path,'classifier_state.pth')) 54 | max_acc = acc 55 | allocated_memory = torch.cuda.memory_allocated() / 1024 ** 2 # MB 56 | cached_memory = torch.cuda.memory_reserved() / 1024 ** 2 # MB 57 | acc = test(model_name,classifier,val_loader,model = model,converter = converter) 58 | if acc>max_acc: 59 | print("max_acc:",acc) 60 | torch.save(model.state_dict(),os.path.join(save_path,'model_state.pth')) 61 | torch.save(classifier.state_dict(),os.path.join(save_path,'classifier_state.pth')) 62 | max_acc = acc 63 | epoch_time = time2-time1 64 | 65 | singals, CWTs, labels = next(iter(train_loader)) 66 | singals = singals[0:1] 67 | CWTs = CWTs[0:1] 68 | singals = singals.cuda() 69 | CWTs = CWTs.cuda() 70 | converter_tmp = video_Converter(1,14) 71 | videos = converter_tmp(singals) 72 | per_sample_start_time = time.time() 73 | singals_feature,CWTs_feature,videos_feature = model(singals,CWTs,videos) 74 | feature = torch.cat([singals_feature,CWTs_feature,videos_feature],dim = -1) 75 | pre = classifier(feature) 76 | per_sample_end_time = time.time() 77 | infer_one_sample_time = per_sample_end_time - 
per_sample_start_time 78 | 79 | flops1 = FlopCountAnalysis(model,(singals,CWTs,videos)) 80 | flops2 = FlopCountAnalysis(classifier, feature) 81 | flops = flops1.total() + flops2.total() 82 | 83 | num_params1 = sum(p.numel() for p in model.parameters()) 84 | num_params2 = sum(p.numel() for p in classifier.parameters()) 85 | num_params = num_params1+num_params2 86 | 87 | info = {} 88 | info['train_epoch_time(bsz = 100sample)'] = epoch_time 89 | info['infer_one_sample_time'] = infer_one_sample_time 90 | info['flops(infer_1_sample)'] = flops 91 | info['num_params(infer)'] = num_params 92 | info['allocated_memory'] = allocated_memory 93 | info['cached_memory'] = cached_memory 94 | info = pd.DataFrame(info.items()) 95 | info.to_csv(os.path.join(save_path,'info.csv')) 96 | 97 | 98 | 99 | else: 100 | optimizer = torch.optim.Adam(classifier.parameters(), lr = 0.001) 101 | max_acc = 0 102 | for epoch in range(epoches): 103 | time1 = time.time() 104 | with tqdm(total=len(train_loader)) as p_bar: 105 | for batch_idx, (data,_, labels) in enumerate(train_loader): 106 | data = data.cuda() 107 | labels = labels.cuda() 108 | pre = classifier(data) 109 | loss = F.cross_entropy(pre,labels) 110 | optimizer.zero_grad() 111 | loss.backward() 112 | optimizer.step() 113 | p_bar.set_description("classifier_train: epoch:{} batch_idx:{} loss:{:02f}".format(epoch, batch_idx, loss)) 114 | p_bar.update() 115 | time2 = time.time() 116 | if epoch % 100 == 0 and epoch != 0: 117 | acc = test(model_name,classifier,val_loader) 118 | if acc>max_acc: 119 | print("max_acc:",acc) 120 | torch.save(classifier.state_dict(),os.path.join(save_path,'classifier_state.pth')) 121 | max_acc = acc 122 | allocated_memory = torch.cuda.memory_allocated() / 1024 ** 2 # MB 123 | cached_memory = torch.cuda.memory_reserved() / 1024 ** 2 # MB 124 | acc = test(model_name,classifier,val_loader) 125 | if acc>max_acc: 126 | print("max_acc:",acc) 127 | torch.save(classifier.state_dict(),os.path.join(save_path,'classifier_state.pth')) 128 | max_acc = acc 129 | 130 | epoch_time = time2-time1 131 | 132 | data,_, labels = next(iter(train_loader)) 133 | data = data[0:1] 134 | data = data.cuda() 135 | per_sample_start_time = time.time() 136 | pre = classifier(data) 137 | per_sample_end_time = time.time() 138 | infer_one_sample_time = per_sample_end_time - per_sample_start_time 139 | 140 | flops = FlopCountAnalysis(classifier,data).total 141 | 142 | num_params = sum(p.numel() for p in classifier.parameters()) 143 | 144 | info = {} 145 | info['train_epoch_time(bsz = 100sample)'] = epoch_time 146 | info['infer_one_sample_time'] = infer_one_sample_time 147 | info['flops(infer_1_sample)'] = flops 148 | info['num_params(infer)'] = num_params 149 | info['allocated_memory'] = allocated_memory 150 | info['cached_memory'] = cached_memory 151 | 152 | info = pd.DataFrame(info.items()) 153 | info.to_csv(os.path.join(save_path,'info.csv')) 154 | 155 | 156 | 157 | def test(model_name,classifier,data_loader,snrs = None,snr_indexs = None,save_path = None,final = False,model=None,converter = None): 158 | prediction = [] 159 | true = [] 160 | if model_name == 'HVSF' or model_name == 'HVSF_wo_finetunning': 161 | if final: 162 | model.load_state_dict(torch.load(os.path.join(save_path,'model_state.pth'))) 163 | classifier.load_state_dict(torch.load(os.path.join(save_path,'classifier_state.pth'))) 164 | for batch_idx, (singals, CWTs, labels) in enumerate(data_loader): 165 | singals = singals.cuda() 166 | videos = converter(singals) 167 | CWTs = CWTs.cuda() 168 | labels = 
labels.cuda() 169 | singals_feature,CWTs_feature,videos_feature = model(singals,CWTs,videos) 170 | feature = torch.cat([singals_feature,CWTs_feature,videos_feature],dim = -1) 171 | result = torch.argmax(classifier(feature),dim = 1) 172 | prediction.extend(result.cpu().numpy()) 173 | true.extend(labels.cpu().numpy()) 174 | 175 | 176 | if not final: 177 | return sum(np.array(prediction) == np.array(true))/len(true) 178 | else: 179 | prediction = np.array(prediction) 180 | true = np.array(true) 181 | acc = {} 182 | for i in range(len(snrs)): 183 | true_label = true[snr_indexs[i]] 184 | #print(true_label.shape) 185 | pre_label = prediction[snr_indexs[i]] 186 | cor = np.sum(true_label == pre_label) 187 | acc[snrs[i]] = 1.0 * cor / true_label.shape[0] 188 | total_acc = sum(np.array(prediction) == np.array(true))/len(true) 189 | acc['total'] = total_acc 190 | ACC = pd.DataFrame(acc.items()) 191 | ACC.to_csv(os.path.join(save_path,'result.csv')) 192 | return total_acc 193 | 194 | else: 195 | if final: 196 | classifier.load_state_dict(torch.load(os.path.join(save_path,'classifier_state.pth'))) 197 | for batch_idx, (data,_, labels) in enumerate(data_loader): 198 | data = data.cuda() 199 | labels = labels.cuda() 200 | result = torch.argmax(classifier(data),dim = 1) 201 | prediction.extend(result.cpu().numpy()) 202 | true.extend(labels.cpu().numpy()) 203 | if not final: 204 | return sum(np.array(prediction) == np.array(true))/len(true) 205 | else: 206 | prediction = np.array(prediction) 207 | true = np.array(true) 208 | acc = {} 209 | for i in range(len(snrs)): 210 | true_label = true[snr_indexs[i]] 211 | #print(true_label.shape) 212 | pre_label = prediction[snr_indexs[i]] 213 | cor = np.sum(true_label == pre_label) 214 | acc[snrs[i]] = 1.0 * cor / true_label.shape[0] 215 | total_acc = sum(np.array(prediction) == np.array(true))/len(true) 216 | acc['total'] = total_acc 217 | ACC = pd.DataFrame(acc.items()) 218 | ACC.to_csv(os.path.join(save_path,'result.csv')) 219 | return total_acc 220 | 221 | 222 | 223 | for rate in [0.01,0.05,0.1,0.3,0.5,0.6]: 224 | #for rate in [0.6]: 225 | for Time in range(3): 226 | root_path = os.path.join('/root/autodl-tmp/所有模型比较RML2016a',str(rate),str(Time)) 227 | if not os.path.exists(root_path): 228 | os.mkdir(root_path) 229 | 230 | bsz = 100 231 | traindataset,valset,testset,classes,snrs,snr_indexs = get_signal('/root/autodl-tmp/RML2016.10a_dict.pkl','/root/autodl-tmp/CWTdata.pkl',rate,L=1,snrs_index=0) 232 | train_loader = torch.utils.data.DataLoader(traindataset, batch_size=bsz, shuffle=True, drop_last = True) 233 | val_loader = torch.utils.data.DataLoader(valset, batch_size=bsz, shuffle=False, drop_last = True) 234 | test_loader = torch.utils.data.DataLoader(testset, batch_size=bsz, shuffle=False, drop_last = True) 235 | 236 | #for model_name in ['CNN','MCLDNN','PET','CLDNN','HVSF_wo_finetunning','HVSF']: 237 | for model_name in ['HVSF']: 238 | save_path = os.path.join(root_path,model_name) 239 | if not os.path.exists(save_path): 240 | os.mkdir(save_path) 241 | 242 | if model_name == 'HVSF' or model_name == 'HVSF_wo_finetunning': 243 | model = Model(160,16,2).cuda() 244 | classifier = classifier_head(3*160,11).cuda() 245 | model.load_state_dict(torch.load(os.path.join('/root/autodl-tmp/所有模型比较RML2016a','model_state.pth'))) 246 | model.train() 247 | classifier.train() 248 | converter = video_Converter(bsz,14) 249 | if model_name == 'HVSF': 250 | train(model_name,classifier,50,train_loader,val_loader,save_path,model=model,finetunning = True,converter = converter) 251 | 
else: 252 | train(model_name,classifier,50,train_loader,val_loader,save_path,model=model,finetunning = False,converter = converter) 253 | test(model_name, classifier, test_loader,snrs, snr_indexs, save_path, final = True, model=model, converter = converter) 254 | else: 255 | classifier = get_class(model_name,(2,128),11).cuda() 256 | train(model_name,classifier,50,train_loader,val_loader,save_path) 257 | test(model_name, classifier, test_loader,snrs, snr_indexs, save_path, final = True) 258 | -------------------------------------------------------------------------------- /Test and Comparison/models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import pickle as pk 6 | import numpy as np 7 | from utils import * 8 | from tqdm import tqdm 9 | from signal_data_add import get_signal 10 | from torch.utils.data import DataLoader 11 | import pywt 12 | import matplotlib.pyplot as plt 13 | from patch_encoder_and_projection_head_and_classifier import * 14 | from pytorch_pretrained_vit import ViT 15 | import os 16 | 17 | def get_class(class_name,input_shape,num_classes): 18 | if class_name == 'CNN': 19 | return CNN(input_shape,num_classes) 20 | if class_name == 'MCLDNN': 21 | return MCLDNN(num_classes) 22 | if class_name == 'PET': 23 | return PET(input_shape,num_classes) 24 | if class_name == 'CLDNN': 25 | return CLDNN(input_shape,num_classes) 26 | 27 | class Model(nn.Module): 28 | def __init__(self,feature_dim,nh,nl): 29 | super().__init__() 30 | self.CWTs_emb = patch_embedding_for_CWTs(feature_dim) 31 | self.singals_emb = patch_embedding_for_singals(feature_dim) 32 | self.videos_emb = patch_embedding_for_videos(feature_dim) 33 | #特征提取器 34 | self.extractor = ViT(feature_dim,128,num_heads=nh,num_layers=nl) 35 | 36 | #layer_norm层(输入) 37 | self.singals_norm_layer = nn.LayerNorm([2,128])#个体初始化一致性 38 | self.CWTs_norm_layer = nn.LayerNorm([2,99,128]) 39 | self.video_norm_layer = nn.LayerNorm([128,1,14,14]) 40 | #layer_norm层(特征) 41 | self.feature_norm_layer = nn.LayerNorm([128,feature_dim]).cuda()#以特征为归一化标准,不能以个体吧?如果以个体,那所有个体都一样了 42 | 43 | def forward(self,singals,CWTs,videos): 44 | #norm 45 | singals = self.singals_norm_layer(singals) 46 | CWTs = self.CWTs_norm_layer(CWTs) 47 | videos = self.video_norm_layer(videos) 48 | 49 | singals_embed = self.feature_norm_layer(self.singals_emb(singals)) 50 | CWTs_embed = self.feature_norm_layer(self.CWTs_emb(CWTs))#每一种模态的特征在进入transformer前,都会进行layer norm,以确保特征间的一致性 51 | videos_embed = self.feature_norm_layer(self.videos_emb(videos)) 52 | #layer_norm以每个模态的一个“个体”为单位(可以更换为以一个特征为单位) 53 | singals_feature = self.extractor(singals_embed) 54 | CWTs_feature = self.extractor(CWTs_embed)#这里应该不用加layer norm,因为transformer的最后有norm 55 | videos_feature = self.extractor(videos_embed) 56 | return singals_feature,CWTs_feature,videos_feature 57 | class Proj(nn.Module): 58 | def __init__(self,in_feature_dim,out_feature_dim): 59 | super().__init__() 60 | #projection head 61 | self.singals_proj = prejection_head(in_feature_dim,out_feature_dim) 62 | self.CWTs_proj = prejection_head(in_feature_dim,out_feature_dim) 63 | self.videos_proj = prejection_head(in_feature_dim,out_feature_dim) 64 | 65 | def forward(self,singals_feature,CWTs_feature,videos_feature): 66 | singals_projection = self.singals_proj(singals_feature) 67 | CWTs_projection = self.CWTs_proj(CWTs_feature)#这里应该不用加layer norm,因为transformer的最后有norm 68 | videos_projection = 
self.videos_proj(videos_feature) 69 | 70 | return singals_projection,CWTs_projection,videos_projection 71 | 72 | 73 | class CNN(nn.Module): 74 | def __init__(self,input_shape,num_classes): 75 | super(CNN, self).__init__() 76 | self.conv1 = nn.Sequential( 77 | nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(1,5)),nn.BatchNorm2d(32),nn.LeakyReLU(),nn.MaxPool2d(1,2) 78 | ) 79 | 80 | self.conv2 = nn.Sequential( 81 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(1,3)),nn.BatchNorm2d(32),nn.LeakyReLU(),nn.MaxPool2d(1,2) 82 | ) 83 | 84 | self.conv3 = nn.Sequential( 85 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(1,3)),nn.BatchNorm2d(32),nn.LeakyReLU(),nn.MaxPool2d(1,2) 86 | ) 87 | dim1 = (((input_shape[1]-4)//2-1)//2-1)//2*32 88 | self.fc1 = nn.Sequential( 89 | nn.Linear(dim1, 100), 90 | nn.Dropout(0.5)) 91 | self.fc2 = nn.Sequential( 92 | nn.Linear(100, num_classes)) 93 | 94 | def forward(self, x): 95 | x = torch.unsqueeze(x,dim=1) 96 | out = x 97 | out = self.conv1(out) 98 | out = self.conv2(out) 99 | out = self.conv3(out) 100 | out = out.view(out.size(0), -1) 101 | #print(out.shape) 102 | out = self.fc1(out) 103 | out = self.fc2(out) 104 | 105 | return out 106 | 107 | class MCLDNN(nn.Module): 108 | def __init__(self,num_classes): 109 | super(MCLDNN, self).__init__() 110 | self.conv1 = nn.Sequential(nn.Conv2d(1,50,(2,8)),nn.ReLU()) 111 | self.conv2 = nn.Sequential(nn.Conv1d(1, 50, 8), nn.ReLU()) 112 | self.conv3 = nn.Sequential(nn.Conv1d(1, 50, 8), nn.ReLU()) 113 | self.conv4 = nn.Sequential(nn.Conv2d(50, 50, (1,8)), nn.ReLU()) 114 | self.conv5 = nn.Sequential(nn.Conv2d(100, 100, (2,5)), nn.ReLU()) 115 | 116 | self.lstm = nn.LSTM(100,128,num_layers=3) 117 | 118 | self.fc1 = nn.Sequential(nn.Linear(128,128),nn.SELU(),nn.Dropout(0.5)) 119 | self.fc2 = nn.Sequential(nn.Linear(128,num_classes),nn.SELU(), nn.Dropout(0.5)) 120 | 121 | def forward(self,x): 122 | input_iq = x.unsqueeze(1) 123 | input_i = x[:,0,:].unsqueeze(1) 124 | input_q = x[:,1,:].unsqueeze(1) 125 | input_iq = self.conv1(input_iq) 126 | input_iq = F.pad(input_iq, [3,4,0,1], "constant", 0) 127 | input_i = self.conv2(input_i) 128 | input_i = F.pad(input_i, [7, 0], "constant", 0) 129 | input_q = self.conv3(input_q) 130 | input_q = F.pad(input_q, [7, 0], "constant", 0) 131 | input_i = input_i.unsqueeze(2) 132 | input_q = input_q.unsqueeze(2) 133 | inputicq = torch.cat([input_i, input_q], 2) 134 | inputicq = self.conv4(inputicq) 135 | inputicq = F.pad(inputicq, [3, 4, 0, 0], "constant", 0) 136 | input = torch.cat([input_iq, inputicq], 1) 137 | input = self.conv5(input) 138 | input = input.reshape(input.shape) 139 | input = torch.squeeze(input,dim=2).permute(2, 0, 1) 140 | input,_ = self.lstm(input) 141 | input = input[-1, :, :] 142 | input = self.fc1(input) 143 | input = self.fc2(input) 144 | 145 | return input 146 | 147 | class classifier_head_layer_2(nn.Module): 148 | def __init__(self,in_dim,out_dim): 149 | super().__init__() 150 | self.fc1 = nn.Linear(in_dim,in_dim) 151 | self.fc2 = nn.Linear(in_dim,out_dim) 152 | 153 | def forward(self,x): 154 | return self.fc2(nn.functional.relu(self.fc1(x))) 155 | 156 | class classifier_head_layer_3(nn.Module): 157 | def __init__(self,in_dim,out_dim): 158 | super().__init__() 159 | self.fc1 = nn.Linear(in_dim,in_dim) 160 | self.fc2 = nn.Linear(in_dim,in_dim) 161 | self.fc3 = nn.Linear(in_dim,out_dim) 162 | 163 | 164 | def forward(self,x): 165 | return self.fc3(nn.functional.relu(self.fc2(nn.functional.relu(self.fc1(x))))) 166 | 167 | class 
classifier_head_layer_4(nn.Module): 168 | def __init__(self,in_dim,out_dim): 169 | super().__init__() 170 | self.fc1 = nn.Linear(in_dim,in_dim) 171 | self.fc2 = nn.Linear(in_dim,in_dim) 172 | self.fc3 = nn.Linear(in_dim,in_dim) 173 | self.fc4 = nn.Linear(in_dim,out_dim) 174 | 175 | 176 | def forward(self,x): 177 | return self.fc4(nn.functional.relu(self.fc3(nn.functional.relu(self.fc2(nn.functional.relu(self.fc1(x))))))) 178 | 179 | 180 | 181 | class PET(nn.Module): 182 | def __init__(self, input_shape = (2,1024), num_classes=11): 183 | super(PET, self).__init__() 184 | # Define layers 185 | self.input_shape = input_shape 186 | self.fc1 = nn.Linear(input_shape[0] * input_shape[1], 1) # Linear layer for the first input 187 | self.conv1_1 = nn.Conv2d(1, 75, kernel_size=(2, 8), padding='valid') 188 | self.conv1_2 = nn.Conv2d(75, 25, kernel_size=(1, 5), padding='valid') 189 | self.gru = nn.GRU(input_size=25, hidden_size=128, batch_first=True) 190 | self.fc = nn.Linear(128, num_classes) 191 | 192 | def forward(self, input1): 193 | real = input1[:,0,:] 194 | image = input1[:,1,:] 195 | # Flatten and dense layer 196 | x1 = input1.view(input1.shape[0], -1) # Flatten to (batch_size, 256) 197 | x1 = self.fc1(x1) 198 | cos_value = torch.cos(x1) # Shape: (batch_size, 1) 199 | sin_value = torch.sin(x1) # Shape: (batch_size, 1) 200 | sig1 = real * cos_value + image * sin_value 201 | sig2 = image * cos_value + real * sin_value 202 | sig1 = sig1.unsqueeze(1) 203 | sig2 = sig2.unsqueeze(1) 204 | signal = torch.cat([sig1,sig2],dim = 1) 205 | signal = signal.unsqueeze(1) 206 | x3 = F.relu(self.conv1_1(signal)) 207 | x3 = F.relu(self.conv1_2(x3)) 208 | 209 | # Temporal feature extraction 210 | x4 = x3.view(x3.size(0), self.input_shape[1]-11, 25) # Reshape for GRU 211 | x4, _ = self.gru(x4) 212 | 213 | # Final classification 214 | x = self.fc(x4[:, -1, :]) # Use the last time step 215 | return x 216 | 217 | 218 | class CLDNN(nn.Module): 219 | def __init__(self, input_shape, num_classes, dropout_rate=0.4): 220 | super(CLDNN, self).__init__() 221 | self.input_shape = input_shape 222 | self.conv1d = nn.Conv1d(in_channels=2, out_channels=64, kernel_size=8) # Conv1D layer 223 | self.pool = nn.MaxPool1d(kernel_size=2) # Max pooling layer 224 | self.lstm1 = nn.LSTM(input_size=64, hidden_size=64, batch_first=True) # First LSTM layer 225 | self.dropout1 = nn.Dropout(dropout_rate) # Dropout layer 226 | self.lstm2 = nn.LSTM(input_size=64, hidden_size=64, batch_first=True) # Second LSTM layer 227 | self.dropout2 = nn.Dropout(dropout_rate) # Dropout layer 228 | self.flatten = nn.Flatten() # Flatten layer 229 | dim1 = (self.input_shape[1]-8)*32 230 | self.fc = nn.Linear( dim1, num_classes) # Fully connected layer 231 | 232 | def forward(self, x): 233 | x = self.conv1d(x) # Shape: (batch_size, 64, 121) 234 | x = self.pool(x) # Shape: (batch_size, 64, 60) 235 | x = x.permute(0, 2, 1) # Change shape to (batch_size, 60, 64) for LSTM 236 | x, _ = self.lstm1(x) # LSTM layer 237 | x = self.dropout1(x) # Dropout 238 | x, _ = self.lstm2(x) # Second LSTM layer 239 | x = self.dropout2(x) # Dropout 240 | x = self.flatten(x) # Flatten 241 | 242 | x = self.fc(x) # Output layer 243 | return x -------------------------------------------------------------------------------- /Test and Comparison/patch_encoder_and_projection_head_and_classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class patch_embedding_for_CWTs(nn.Module): 5 | def 
__init__(self,dim): 6 | super().__init__() 7 | self.conv = nn.Conv2d(2,dim,kernel_size = (99,3),padding = (0,1)) 8 | 9 | def forward(self,x): 10 | x = self.conv(x) 11 | x = x.flatten(2).transpose(1, 2) 12 | return x 13 | 14 | class patch_embedding_for_singals(nn.Module): 15 | def __init__(self,dim): 16 | super().__init__() 17 | self.conv = nn.Conv2d(1,dim,kernel_size = (2,3),padding = (0,1)) 18 | 19 | def forward(self,x): 20 | x = x.unsqueeze(1) 21 | x = self.conv(x) 22 | x = x.flatten(2).transpose(1, 2) 23 | return x 24 | 25 | class patch_embedding_for_videos(nn.Module): 26 | def __init__(self,dim): 27 | super().__init__() 28 | self.conv = nn.Conv3d(1,dim,kernel_size = (3,14,14),padding = (1,0,0)) 29 | 30 | def forward(self,x): 31 | x = x.permute(0,2,1,3,4) 32 | x = self.conv(x) 33 | x = x.flatten(2).transpose(1, 2) 34 | return x 35 | 36 | class patch_embedding_for_images(nn.Module): 37 | def __init__(self,dim): 38 | super().__init__() 39 | self.conv = nn.Conv2d(1,dim,kernel_size = (14,14),padding = 0) 40 | 41 | def forward(self,x): 42 | x = self.conv(x) 43 | x = x.flatten(2).transpose(1, 2).repeat(1, 128, 1) 44 | return x 45 | 46 | class classifier_head(nn.Module): 47 | def __init__(self,in_dim,out_dim): 48 | super().__init__() 49 | self.fc1 = nn.Linear(in_dim,in_dim) 50 | self.fc2 = nn.Linear(in_dim,out_dim) 51 | 52 | def forward(self,x): 53 | return self.fc2(nn.functional.relu(self.fc1(x))) 54 | 55 | 56 | class prejection_head(nn.Module): 57 | def __init__(self,in_dim,out_dim): 58 | super().__init__() 59 | self.fc1 = nn.Linear(in_dim,in_dim) 60 | self.fc2 = nn.Linear(in_dim,out_dim) 61 | 62 | def forward(self,x): 63 | return self.fc2(nn.functional.relu(self.fc1(x))) -------------------------------------------------------------------------------- /Test and Comparison/pytorch_pretrained_vit/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.7" 2 | 3 | from .model import ViT 4 | from .configs import * 5 | from .utils import load_pretrained_weights 6 | -------------------------------------------------------------------------------- /Test and Comparison/pytorch_pretrained_vit/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SIM-xidian/Hybrid-View-Self-supervised-Framework-for-Automatic-Modulation-Recognition/e6c84b1fd55c0be4f1d4531761eab1750d7a2182/Test and Comparison/pytorch_pretrained_vit/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /Test and Comparison/pytorch_pretrained_vit/__pycache__/configs.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SIM-xidian/Hybrid-View-Self-supervised-Framework-for-Automatic-Modulation-Recognition/e6c84b1fd55c0be4f1d4531761eab1750d7a2182/Test and Comparison/pytorch_pretrained_vit/__pycache__/configs.cpython-38.pyc -------------------------------------------------------------------------------- /Test and Comparison/pytorch_pretrained_vit/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SIM-xidian/Hybrid-View-Self-supervised-Framework-for-Automatic-Modulation-Recognition/e6c84b1fd55c0be4f1d4531761eab1750d7a2182/Test and Comparison/pytorch_pretrained_vit/__pycache__/model.cpython-38.pyc 
-------------------------------------------------------------------------------- /Test and Comparison/pytorch_pretrained_vit/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SIM-xidian/Hybrid-View-Self-supervised-Framework-for-Automatic-Modulation-Recognition/e6c84b1fd55c0be4f1d4531761eab1750d7a2182/Test and Comparison/pytorch_pretrained_vit/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /Test and Comparison/pytorch_pretrained_vit/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SIM-xidian/Hybrid-View-Self-supervised-Framework-for-Automatic-Modulation-Recognition/e6c84b1fd55c0be4f1d4531761eab1750d7a2182/Test and Comparison/pytorch_pretrained_vit/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /Test and Comparison/pytorch_pretrained_vit/configs.py: -------------------------------------------------------------------------------- 1 | """configs.py - ViT model configurations, based on: 2 | https://github.com/google-research/vision_transformer/blob/master/vit_jax/configs.py 3 | """ 4 | 5 | def get_base_config(): 6 | """Base ViT config ViT""" 7 | return dict( 8 | dim=768, 9 | ff_dim=3072, 10 | num_heads=12, 11 | num_layers=12, 12 | attention_dropout_rate=0.0, 13 | dropout_rate=0.1, 14 | representation_size=768, 15 | classifier='token' 16 | ) 17 | 18 | def get_b16_config(): 19 | """Returns the ViT-B/16 configuration.""" 20 | config = get_base_config() 21 | config.update(dict(patches=(16, 16))) 22 | return config 23 | 24 | def get_b32_config(): 25 | """Returns the ViT-B/32 configuration.""" 26 | config = get_b16_config() 27 | config.update(dict(patches=(32, 32))) 28 | return config 29 | 30 | def get_l16_config(): 31 | """Returns the ViT-L/16 configuration.""" 32 | config = get_base_config() 33 | config.update(dict( 34 | patches=(16, 16), 35 | dim=1024, 36 | ff_dim=4096, 37 | num_heads=16, 38 | num_layers=24, 39 | attention_dropout_rate=0.0, 40 | dropout_rate=0.1, 41 | representation_size=1024 42 | )) 43 | return config 44 | 45 | def get_l32_config(): 46 | """Returns the ViT-L/32 configuration.""" 47 | config = get_l16_config() 48 | config.update(dict(patches=(32, 32))) 49 | return config 50 | 51 | def drop_head_variant(config): 52 | config.update(dict(representation_size=None)) 53 | return config 54 | 55 | 56 | PRETRAINED_MODELS = { 57 | 'B_16': { 58 | 'config': get_b16_config(), 59 | 'num_classes': 21843, 60 | 'image_size': (224, 224), 61 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_16.pth" 62 | }, 63 | 'B_32': { 64 | 'config': get_b32_config(), 65 | 'num_classes': 21843, 66 | 'image_size': (224, 224), 67 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_32.pth" 68 | }, 69 | 'L_16': { 70 | 'config': get_l16_config(), 71 | 'num_classes': 21843, 72 | 'image_size': (224, 224), 73 | 'url': None 74 | }, 75 | 'L_32': { 76 | 'config': get_l32_config(), 77 | 'num_classes': 21843, 78 | 'image_size': (224, 224), 79 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/L_32.pth" 80 | }, 81 | 'B_16_imagenet1k': { 82 | 'config': drop_head_variant(get_b16_config()), 83 | 'num_classes': 1000, 84 | 'image_size': (384, 384), 85 | 'url': 
"https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_16_imagenet1k.pth" 86 | }, 87 | 'B_32_imagenet1k': { 88 | 'config': drop_head_variant(get_b32_config()), 89 | 'num_classes': 1000, 90 | 'image_size': (384, 384), 91 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_32_imagenet1k.pth" 92 | }, 93 | 'L_16_imagenet1k': { 94 | 'config': drop_head_variant(get_l16_config()), 95 | 'num_classes': 1000, 96 | 'image_size': (384, 384), 97 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/L_16_imagenet1k.pth" 98 | }, 99 | 'L_32_imagenet1k': { 100 | 'config': drop_head_variant(get_l32_config()), 101 | 'num_classes': 1000, 102 | 'image_size': (384, 384), 103 | 'url': "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/L_32_imagenet1k.pth" 104 | }, 105 | } 106 | -------------------------------------------------------------------------------- /Test and Comparison/pytorch_pretrained_vit/model.py: -------------------------------------------------------------------------------- 1 | """model.py - Model and module class for ViT. 2 | They are built to mirror those in the official Jax implementation. 3 | """ 4 | 5 | from typing import Optional 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | from .transformer import Transformer 11 | from .utils import load_pretrained_weights, as_tuple 12 | from .configs import PRETRAINED_MODELS 13 | 14 | 15 | class PositionalEmbedding1D(nn.Module): 16 | """Adds (optionally learned) positional embeddings to the inputs.""" 17 | 18 | def __init__(self, seq_len, dim): 19 | super().__init__() 20 | self.pos_embedding = nn.Parameter(torch.zeros(1, seq_len, dim)) 21 | 22 | def forward(self, x): 23 | """Input has shape `(batch_size, seq_len, emb_dim)`""" 24 | return x + self.pos_embedding 25 | 26 | 27 | class ViT(nn.Module): 28 | def __init__( 29 | self, 30 | dim, 31 | seq_len, 32 | num_heads, 33 | num_layers, 34 | dropout_rate = 0.1, 35 | ): 36 | super().__init__() 37 | self.class_token = nn.Parameter(torch.zeros(1, 1, dim)) 38 | seq_len += 1 39 | self.positional_embedding = PositionalEmbedding1D(seq_len, dim) 40 | # Transformer 41 | self.transformer = Transformer(num_layers=num_layers, dim=dim, num_heads=num_heads, 42 | ff_dim=4*dim, dropout=dropout_rate) 43 | # Classifier head 44 | self.norm = nn.LayerNorm(dim, eps=1e-6) 45 | # Initialize weights 46 | self.init_weights() 47 | 48 | @torch.no_grad() 49 | def init_weights(self): 50 | def _init(m): 51 | if isinstance(m, nn.Linear): 52 | nn.init.xavier_uniform_(m.weight) # _trunc_normal(m.weight, std=0.02) # from .initialization import _trunc_normal 53 | if hasattr(m, 'bias') and m.bias is not None: 54 | nn.init.normal_(m.bias, std=1e-6) # nn.init.constant(m.bias, 0) 55 | self.apply(_init) 56 | #nn.init.constant_(self.fc.weight, 0) 57 | #nn.init.constant_(self.fc.bias, 0) 58 | nn.init.normal_(self.positional_embedding.pos_embedding, std=0.02) # _trunc_normal(self.positional_embedding.pos_embedding, std=0.02) 59 | nn.init.constant_(self.class_token, 0) 60 | 61 | def forward(self, x): 62 | b = x.shape[0] 63 | x = torch.cat((self.class_token.expand(b, -1, -1), x), dim=1) # b,gh*gw+1,d 64 | x = self.positional_embedding(x) # b,gh*gw+1,d (+1:Patch + Position Embedding 65 | x = self.transformer(x) # b,gh*gw+1,d 有数层transformer block 66 | x = self.norm(x)[:, 0] # b,d:每个图片仅使用class embedding作为特征 67 | return x 68 | 69 | 
-------------------------------------------------------------------------------- /Test and Comparison/pytorch_pretrained_vit/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/lukemelas/simple-bert 3 | """ 4 | 5 | import numpy as np 6 | from torch import nn 7 | from torch import Tensor 8 | from torch.nn import functional as F 9 | 10 | 11 | def split_last(x, shape): 12 | "split the last dimension to given shape" 13 | shape = list(shape) 14 | assert shape.count(-1) <= 1 15 | if -1 in shape: 16 | shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape)) 17 | return x.view(*x.size()[:-1], *shape) 18 | 19 | 20 | def merge_last(x, n_dims): 21 | "merge the last n_dims to a dimension" 22 | s = x.size() 23 | assert n_dims > 1 and n_dims < len(s) 24 | return x.view(*s[:-n_dims], -1) 25 | 26 | 27 | class MultiHeadedSelfAttention(nn.Module): 28 | """Multi-Headed Dot Product Attention""" 29 | def __init__(self, dim, num_heads, dropout): 30 | super().__init__() 31 | self.proj_q = nn.Linear(dim, dim) 32 | self.proj_k = nn.Linear(dim, dim) 33 | self.proj_v = nn.Linear(dim, dim) 34 | self.drop = nn.Dropout(dropout) 35 | self.n_heads = num_heads 36 | self.scores = None # for visualization 37 | 38 | def forward(self, x, mask): 39 | """ 40 | x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim)) 41 | mask : (B(batch_size) x S(seq_len)) 42 | * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W 43 | """ 44 | # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W) 45 | q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x) 46 | q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v]) 47 | # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S) 48 | scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1)) 49 | if mask is not None: 50 | mask = mask[:, None, None, :].float() 51 | scores -= 10000.0 * (1.0 - mask) 52 | scores = self.drop(F.softmax(scores, dim=-1)) 53 | # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W) 54 | h = (scores @ v).transpose(1, 2).contiguous() 55 | # -merge-> (B, S, D) 56 | h = merge_last(h, 2) 57 | self.scores = scores 58 | return h 59 | 60 | 61 | class PositionWiseFeedForward(nn.Module): 62 | """FeedForward Neural Networks for each position""" 63 | def __init__(self, dim, ff_dim): 64 | super().__init__() 65 | self.fc1 = nn.Linear(dim, ff_dim) 66 | self.fc2 = nn.Linear(ff_dim, dim) 67 | 68 | def forward(self, x): 69 | # (B, S, D) -> (B, S, D_ff) -> (B, S, D) 70 | return self.fc2(F.gelu(self.fc1(x))) 71 | 72 | 73 | class Block(nn.Module): 74 | """Transformer Block""" 75 | def __init__(self, dim, num_heads, ff_dim, dropout): 76 | super().__init__() 77 | self.attn = MultiHeadedSelfAttention(dim, num_heads, dropout) 78 | self.proj = nn.Linear(dim, dim) 79 | self.norm1 = nn.LayerNorm(dim, eps=1e-6) 80 | self.pwff = PositionWiseFeedForward(dim, ff_dim) 81 | self.norm2 = nn.LayerNorm(dim, eps=1e-6) 82 | self.drop = nn.Dropout(dropout) 83 | 84 | def forward(self, x, mask): 85 | h = self.drop(self.proj(self.attn(self.norm1(x), mask))) 86 | x = x + h 87 | h = self.drop(self.pwff(self.norm2(x))) 88 | x = x + h 89 | return x 90 | 91 | 92 | class Transformer(nn.Module): 93 | """Transformer with Self-Attentive Blocks""" 94 | def __init__(self, num_layers, dim, num_heads, ff_dim, dropout): 95 | super().__init__() 96 | self.blocks = nn.ModuleList([ 97 | Block(dim, num_heads, ff_dim, dropout) for _ in range(num_layers)]) 98 
| 99 | def forward(self, x, mask=None): 100 | for block in self.blocks: 101 | x = block(x, mask) 102 | return x 103 | -------------------------------------------------------------------------------- /Test and Comparison/pytorch_pretrained_vit/utils.py: -------------------------------------------------------------------------------- 1 | """utils.py - Helper functions 2 | """ 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils import model_zoo 7 | 8 | from .configs import PRETRAINED_MODELS 9 | 10 | 11 | def load_pretrained_weights( 12 | model, 13 | model_name=None, 14 | weights_path=None, 15 | load_first_conv=True, 16 | load_fc=True, 17 | load_repr_layer=False, 18 | resize_positional_embedding=False, 19 | verbose=True, 20 | strict=True, 21 | ): 22 | """Loads pretrained weights from weights path or download using url. 23 | Args: 24 | model (Module): Full model (a nn.Module) 25 | model_name (str): Model name (e.g. B_16) 26 | weights_path (None or str): 27 | str: path to pretrained weights file on the local disk. 28 | None: use pretrained weights downloaded from the Internet. 29 | load_first_conv (bool): Whether to load patch embedding. 30 | load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model. 31 | resize_positional_embedding=False, 32 | verbose (bool): Whether to print on completion 33 | """ 34 | assert bool(model_name) ^ bool(weights_path), 'Expected exactly one of model_name or weights_path' 35 | 36 | # Load or download weights 37 | if weights_path is None: 38 | url = PRETRAINED_MODELS[model_name]['url'] 39 | if url: 40 | state_dict = model_zoo.load_url(url) 41 | else: 42 | raise ValueError(f'Pretrained model for {model_name} has not yet been released') 43 | else: 44 | state_dict = torch.load(weights_path) 45 | 46 | # Modifications to load partial state dict 47 | expected_missing_keys = [] 48 | if not load_first_conv and 'patch_embedding.weight' in state_dict: 49 | expected_missing_keys += ['patch_embedding.weight', 'patch_embedding.bias'] 50 | if not load_fc and 'fc.weight' in state_dict: 51 | expected_missing_keys += ['fc.weight', 'fc.bias'] 52 | if not load_repr_layer and 'pre_logits.weight' in state_dict: 53 | expected_missing_keys += ['pre_logits.weight', 'pre_logits.bias'] 54 | for key in expected_missing_keys: 55 | state_dict.pop(key) 56 | 57 | # Change size of positional embeddings 58 | if resize_positional_embedding: 59 | posemb = state_dict['positional_embedding.pos_embedding'] 60 | posemb_new = model.state_dict()['positional_embedding.pos_embedding'] 61 | state_dict['positional_embedding.pos_embedding'] = \ 62 | resize_positional_embedding_(posemb=posemb, posemb_new=posemb_new, 63 | has_class_token=hasattr(model, 'class_token')) 64 | maybe_print('Resized positional embeddings from {} to {}'.format( 65 | posemb.shape, posemb_new.shape), verbose) 66 | 67 | # Load state dict 68 | ret = model.load_state_dict(state_dict, strict=False) 69 | if strict: 70 | assert set(ret.missing_keys) == set(expected_missing_keys), \ 71 | 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys) 72 | assert not ret.unexpected_keys, \ 73 | 'Missing keys when loading pretrained weights: {}'.format(ret.unexpected_keys) 74 | maybe_print('Loaded pretrained weights.', verbose) 75 | else: 76 | maybe_print('Missing keys when loading pretrained weights: {}'.format(ret.missing_keys), verbose) 77 | maybe_print('Unexpected keys when loading pretrained weights: {}'.format(ret.unexpected_keys), verbose) 78 | return ret 79 | 80 | 81 | def 
maybe_print(s: str, flag: bool): 82 | if flag: 83 | print(s) 84 | 85 | 86 | def as_tuple(x): 87 | return x if isinstance(x, tuple) else (x, x) 88 | 89 | 90 | def resize_positional_embedding_(posemb, posemb_new, has_class_token=True): 91 | """Rescale the grid of position embeddings in a sensible manner""" 92 | from scipy.ndimage import zoom 93 | 94 | # Deal with class token 95 | ntok_new = posemb_new.shape[1] 96 | if has_class_token: # this means classifier == 'token' 97 | posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] 98 | ntok_new -= 1 99 | else: 100 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 101 | 102 | # Get old and new grid sizes 103 | gs_old = int(np.sqrt(len(posemb_grid))) 104 | gs_new = int(np.sqrt(ntok_new)) 105 | posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) 106 | 107 | # Rescale grid 108 | zoom_factor = (gs_new / gs_old, gs_new / gs_old, 1) 109 | posemb_grid = zoom(posemb_grid, zoom_factor, order=1) 110 | posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) 111 | posemb_grid = torch.from_numpy(posemb_grid) 112 | 113 | # Deal with class token and return 114 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 115 | return posemb 116 | 117 | -------------------------------------------------------------------------------- /Test and Comparison/signal_data_add.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import torch 4 | import pickle as pk 5 | import torch 6 | 7 | def get_signal(dir1,dir2,rate,L=30,snrs_index=0): 8 | #dir1:原信号 9 | #dir2,dir3:模态间信号,dir2为小波变换能量图,dir3为星座图 10 | #dir4:序列模态对比信号,是原信号经time warping变换后的信号 11 | 12 | f1 = open(dir1, 'rb') 13 | f2 = open(dir2, 'rb') 14 | 15 | data = pk.load(f1, encoding='latin1') 16 | all_snrs, mods = map(lambda j: sorted(list(set(map(lambda x: x[j], data.keys())))), [1, 0]) 17 | 18 | snrs = all_snrs[snrs_index:] 19 | not_used_snrs = all_snrs[:snrs_index] 20 | print("使用的snrs: ",snrs) 21 | print("未使用的snr:",not_used_snrs) 22 | for mod in mods: 23 | for snr in not_used_snrs: 24 | del data[(mod, snr)] 25 | print("原始信号读取完成") 26 | CWTdata = pk.load(f2, encoding='latin1') 27 | for mod in mods: 28 | for snr in not_used_snrs: 29 | del CWTdata[(mod, snr)] 30 | print("CWT读取完成") 31 | 32 | 33 | 34 | snr_choise = [10] 35 | X = [] 36 | CWT = [] 37 | lbl = [] 38 | train_idx = [] 39 | lbl_idx = [] 40 | val_idx = [] 41 | data_size = 0 42 | test_idx=[] 43 | 44 | for mod in mods: 45 | for snr in snrs: 46 | length = data[(mod, snr)].shape[0] 47 | X.append(data.pop((mod, snr))) 48 | CWT.append(CWTdata.pop((mod, snr))) 49 | for i in range(length): lbl.append((mod, snr)) 50 | train_choise = np.random.choice(range(data_size, data_size + length), size=int(length * 0.6 * rate), replace=False) 51 | train_idx += list(train_choise) 52 | if snr in snr_choise: 53 | lbl_idx += list(np.random.choice(train_choise, size=L, replace=False)) 54 | 55 | val_idx += list( 56 | np.random.choice(list(set(range(data_size, data_size + length)) - set(train_idx)), 57 | size=int(length * 0.2), replace=False)) 58 | 59 | test_idx += list( 60 | np.random.choice(list(set(range(data_size, data_size + length)) - set(train_idx) - set(val_idx)), 61 | size=int(length * 0.2), replace=False)) 62 | 63 | data_size += length 64 | 65 | 66 | print("每一类中有{}个训练样本".format(length * 0.6 * rate)) 67 | 68 | X = np.vstack(X) 69 | CWT = np.vstack(CWT) 70 | 71 | #X = np.expand_dims(X, axis=1) 72 | print("X.shape",X.shape) 73 | print("CWT.shape",CWT.shape) 74 | 75 | X_train = X[train_idx] 76 | X_val = 
X[val_idx] 77 | X_test = X[test_idx] 78 | del X 79 | 80 | CWT_train = CWT[train_idx] 81 | CWT_val = CWT[val_idx] 82 | CWT_test = CWT[test_idx] 83 | del CWT 84 | 85 | Y_train = np.array(list(map(lambda x: mods.index(lbl[x][0]), train_idx))) 86 | Y_val = np.array(list(map(lambda x: mods.index(lbl[x][0]), val_idx))) 87 | Y_test = np.array(list(map(lambda x: mods.index(lbl[x][0]), test_idx))) 88 | 89 | 90 | 91 | traindataset = arr_to_dataset(X_train, CWT_train, Y_train) 92 | 93 | valdataset = arr_to_dataset(X_val, CWT_val, Y_val) 94 | 95 | testdataset = arr_to_dataset(X_test, CWT_test, Y_test) 96 | 97 | 98 | snr_index = np.array(list(map(lambda x: lbl[x][1], test_idx))) 99 | 100 | snr_indexs = [] 101 | for snr in snrs: 102 | snr_indexs.extend(np.where(snr_index == snr)) 103 | 104 | return traindataset,valdataset,testdataset, mods,snrs,snr_indexs 105 | 106 | 107 | def arr_to_dataset(data1, data2, label): 108 | data1 = torch.from_numpy(data1) 109 | data2 = torch.from_numpy(data2) 110 | 111 | 112 | label = torch.from_numpy(label) 113 | dataset = torch.utils.data.TensorDataset(data1,data2,label) 114 | return dataset 115 | 116 | -------------------------------------------------------------------------------- /Test and Comparison/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import torch 4 | from torch.utils.data import Dataset 5 | import os 6 | import numpy as np 7 | from sklearn.metrics import confusion_matrix 8 | import matplotlib.pyplot as plt 9 | 10 | def get_exp(len=4000): 11 | singals = [] 12 | labels = [] 13 | for i in range(len): 14 | k = 0 15 | l = random.randint(0,7) 16 | if l == 0: 17 | phi = math.pi/12 18 | labels.append(0) 19 | elif l == 1: 20 | phi = math.pi/10 21 | labels.append(1) 22 | 23 | elif l == 2: 24 | phi = math.pi/8 25 | labels.append(2) 26 | 27 | elif l ==3: 28 | phi = math.pi/6 29 | labels.append(3) 30 | 31 | elif l ==4: 32 | phi = math.pi/4 33 | labels.append(4) 34 | 35 | 36 | elif l ==5: 37 | phi = math.pi/14 38 | labels.append(5) 39 | 40 | elif l ==6: 41 | phi = math.pi/16 42 | labels.append(6) 43 | 44 | 45 | 46 | else: 47 | phi = math.pi/2 48 | labels.append(7) 49 | 50 | singal_x = [] 51 | singal_y = [] 52 | for j in range(128): 53 | k = k + random.randint(0,5) 54 | singal_x.append(math.cos(k*phi)) 55 | singal_y.append(math.sin(k*phi)) 56 | singals.append([singal_x,singal_y]) 57 | 58 | 59 | Singal = torch.tensor([]) 60 | for i in range(len): 61 | if i == 0: 62 | Singal = torch.tensor([singals[i]]) 63 | else: 64 | Singal = torch.cat([Singal,torch.tensor([singals[i]])],dim=0) 65 | Labels = torch.tensor(labels) 66 | 67 | return Singal,Labels 68 | 69 | 70 | 71 | def compute_loss(pre,target): 72 | l2 = torch.mm(torch.norm(target,dim=1).unsqueeze(1),torch.norm(pre,dim=1).unsqueeze(0)) 73 | #pre = pre + 1e-5 74 | bsz = target.shape[0] 75 | feature_dim = target.shape[1] 76 | target = target.unsqueeze(1).expand(bsz, bsz, feature_dim) 77 | pre = pre.unsqueeze(0).expand(bsz, bsz, feature_dim) 78 | # 对 A 中每个向量与 B 中每个向量进行点积 79 | dot_product = torch.matmul(target, pre.transpose(1, 2)) 80 | # 将点积结果保存为矩阵形式 81 | result = dot_product.squeeze() 82 | result = result[:,0,:] 83 | result = torch.div(result,l2) 84 | result = torch.div(result,0.07) 85 | #print("result",result) 86 | result = torch.exp(result) 87 | #print(result) 88 | diag = torch.diag(result) 89 | #print(diag) 90 | total_lic = torch.sum(result,dim=0) 91 | #print(total_lic) 92 | lic = torch.div(diag,total_lic) 93 | lic = -torch.log(lic) 94 | 
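    # At this point result[i][k] = exp( cos_sim(target_i, pre_k) / 0.07 ):
    # a temperature-scaled similarity matrix between the two views of the batch.
    # The diagonal holds the matched (positive) cross-view pairs; dividing it by the
    # column sums and taking -log below yields an InfoNCE-style contrastive loss in
    # which the other samples of the batch act as negatives.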
#print("lic",lic) 95 | return torch.sum(lic)/lic.shape[0] 96 | 97 | 98 | class MyDataset(Dataset): 99 | def __init__(self, data, labels): 100 | self.data = data 101 | self.labels = labels 102 | 103 | def __len__(self): 104 | return len(self.labels) 105 | 106 | def __getitem__(self, index): 107 | x = self.data[index] 108 | y = self.labels[index] 109 | return x, y 110 | 111 | 112 | def accuracy(output, target, topk=(1,)): 113 | """Computes the precision@k for the specified values of k""" 114 | maxk = max(topk) 115 | batch_size = target.size(0) 116 | 117 | _, pred = output.topk(maxk, 1, True, True) 118 | pred = pred.t() 119 | 120 | 121 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 122 | 123 | res = [] 124 | for k in topk: 125 | correct_k = correct[:k].contiguous().view(-1).float().sum(0) 126 | res.append(correct_k.mul_(100.0 / batch_size)) 127 | return res 128 | 129 | 130 | 131 | 132 | def figure_plot(true_labels,pre_labels,classes,snrs,snr_indexs,figure_path=None): 133 | acc = {} 134 | acc_mod_snr = np.zeros((len(classes), len(snrs))) 135 | for i in range(len(snrs)): 136 | true_label = true_labels[snr_indexs[i]] 137 | #print(true_label.shape) 138 | pre_label = pre_labels[snr_indexs[i]] 139 | cor = np.sum(true_label == pre_label) 140 | acc[snrs[i]] = 1.0 * cor / true_label.shape[0] 141 | 142 | plot_confusion_matrix(true_label,pre_label,classes, 143 | title="Confusion Matrix (SNR=%d)(ACC=%2f)" % (snrs[i], 100.0 * acc[snrs[i]]), 144 | save_filename =os.path.join(figure_path,'Confusion(SNR=%d)(ACC=%2f).png' % (snrs[i], 100.0 * acc[snrs[i]]))) 145 | confnorm_i, _, _ = calculate_confusion_matrix(true_label, pre_label, classes) 146 | acc_mod_snr[:, i] = np.round(np.diag(confnorm_i) / np.sum(confnorm_i, axis=1), 3) 147 | 148 | 149 | plt.plot(snrs, list(map(lambda x: acc[x], snrs)),'.-') 150 | 151 | 152 | plt.xlabel("Signal to Noise Ratio") 153 | plt.ylabel("Classification Accuracy") 154 | plt.title("CNN Classification Accuracy on dataset RadioML 2''016.10 Alpha") 155 | plt.savefig(os.path.join(figure_path,'dB to Noise Ratio')) 156 | plt.close() 157 | 158 | # plot acc of each mod in one picture 159 | dis_num = len(classes) 160 | for g in range(int(np.ceil(acc_mod_snr.shape[0] / dis_num))): 161 | assert (0 <= dis_num <= acc_mod_snr.shape[0]) 162 | beg_index = g * dis_num 163 | end_index = np.min([(g + 1) * dis_num, acc_mod_snr.shape[0]]) 164 | 165 | plt.figure(figsize=(12, 10)) 166 | plt.xlabel("Signal to Noise Ratio") 167 | plt.ylabel("Classification Accuracy") 168 | plt.title("Classification Accuracy for Each Mod") 169 | 170 | for i in range(beg_index, end_index): 171 | plt.plot(snrs, acc_mod_snr[i],'.-', label=classes[i]) 172 | # 设置数字标签 173 | for x, y in zip(snrs, acc_mod_snr[i]): 174 | plt.text(x, y, y, ha='center', va='bottom', fontsize=8) 175 | 176 | plt.legend() 177 | plt.grid() 178 | plt.savefig(os.path.join(figure_path,'acc_with_mod.png')) 179 | plt.close() 180 | return acc,acc_mod_snr 181 | 182 | 183 | def calculate_confusion_matrix(Y,Y_hat,classes): 184 | n_classes = len(classes) 185 | conf = np.zeros([n_classes,n_classes]) 186 | confnorm = np.zeros([n_classes,n_classes]) 187 | 188 | for k in range(0,Y.shape[0]): 189 | i = Y[k] 190 | j = Y_hat[k] 191 | conf[i,j] = conf[i,j] + 1 192 | 193 | for i in range(0,n_classes): 194 | confnorm[i,:] = conf[i,:] / np.sum(conf[i,:]) 195 | # print(confnorm) 196 | 197 | right = np.sum(np.diag(conf)) 198 | wrong = np.sum(conf) - right 199 | return confnorm,right,wrong 200 | 201 | 202 | 203 | 204 | def plot_confusion_matrix(y_true, y_pred, labels, 
save_filename=None, title='Confusion matrix'): 205 | 206 | cmap = plt.cm.binary 207 | cm = confusion_matrix(y_true, y_pred) 208 | tick_marks = np.array(range(len(labels))) + 0.5 209 | np.set_printoptions(precision=2) 210 | cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 211 | plt.figure(figsize=(10, 8), dpi=120) 212 | ind_array = np.arange(len(labels)) 213 | x, y = np.meshgrid(ind_array, ind_array) 214 | intFlag = 0 215 | for x_test, y_test in zip(x.flatten(), y.flatten()): 216 | 217 | if (intFlag): 218 | c = cm[y_test][x_test] 219 | plt.text(x_test, y_test, "%d" % (c,), color='red', fontsize=8, va='center', ha='center') 220 | 221 | else: 222 | c = cm_normalized[y_test][x_test] 223 | if (c > 0.01): 224 | #这里是绘制数字,可以对数字大小和颜色进行修改 225 | plt.text(x_test, y_test, "%0.2f" % (c,), color='red', fontsize=10, va='center', ha='center') 226 | else: 227 | plt.text(x_test, y_test, "%d" % (0,), color='red', fontsize=10, va='center', ha='center') 228 | if(intFlag): 229 | plt.imshow(cm, interpolation='nearest', cmap=cmap) 230 | else: 231 | plt.imshow(cm_normalized, interpolation='nearest', cmap=cmap) 232 | plt.gca().set_xticks(tick_marks, minor=True) 233 | plt.gca().set_yticks(tick_marks, minor=True) 234 | plt.gca().xaxis.set_ticks_position('none') 235 | plt.gca().yaxis.set_ticks_position('none') 236 | plt.grid(True, which='minor', linestyle='-') 237 | plt.gcf().subplots_adjust(bottom=0.15) 238 | plt.title(title) 239 | plt.colorbar() 240 | xlocations = np.array(range(len(labels))) 241 | plt.xticks(xlocations, labels, rotation=90) 242 | plt.yticks(xlocations, labels) 243 | plt.ylabel('Index of True Classes') 244 | plt.xlabel('Index of Predict Classes') 245 | plt.savefig(save_filename) 246 | plt.close() 247 | 248 | 249 | 250 | class video_Converter: 251 | def __init__(self,batch_bsz,frame_legth): 252 | self.batch_bsz = batch_bsz 253 | self.frame_legth = frame_legth 254 | self.sample_idx = torch.arange(0, batch_bsz).repeat(128, 1).t().reshape(-1).cuda() 255 | self.frame_idx = torch.arange(0, 128).repeat(1, batch_bsz).squeeze().cuda() 256 | self.Fundation = torch.zeros(batch_bsz, 128, self.frame_legth, self.frame_legth).cuda() 257 | self.converter = singal_to_video() 258 | def __call__(self, singal): 259 | x = self.converter(singal, self.batch_bsz, self.frame_legth, self.sample_idx, self.frame_idx, self.Fundation) 260 | x = (x + torch.roll(x, shifts=-1, dims=1) * 0.5 + torch.roll(x, shifts=1, dims=1) * 0.5) 261 | x = x.unsqueeze(2) 262 | return x 263 | 264 | class singal_to_video(object): 265 | def __init__(self): 266 | pass 267 | 268 | def __call__(self, singal, bsz, frame_legth, sample_idx, frame_idx, Fundation): 269 | lists_for_image = torch.transpose(singal, 1, 2) 270 | lists_for_image = torch.stack( 271 | [(a-torch.min(a).item() )/(torch.max(a).item()-torch.min(a).item()) for a in lists_for_image]) # 这里该成了torch.cat 范围是:[-0.5,0.5] 272 | lists_for_image = torch.round(torch.mul(lists_for_image, frame_legth-1)).to(torch.int) 273 | lists_for_image = lists_for_image.reshape(128 * bsz, 2) 274 | lists_for_image = lists_for_image.long() 275 | result = Fundation.zero_() 276 | result[sample_idx, frame_idx, lists_for_image[:, 0], lists_for_image[:, 1]] += 255 277 | 278 | return result -------------------------------------------------------------------------------- /generate_CWTdataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": {}, 7 | "outputs": [], 8 | 
"source": [ 9 | "import pickle\n", 10 | "import pywt\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "import numpy as np\n", 13 | "# from signal_data_add import *" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 4, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import pickle as pk\n", 23 | "f = open('RML2016.10a_dict.pkl','rb')\n", 24 | "data = pk.load(f, encoding = 'latin1')" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 5, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "name": "stdout", 34 | "output_type": "stream", 35 | "text": [ 36 | "('QPSK', 2) finished\n", 37 | "('PAM4', 8) finished\n", 38 | "('AM-DSB', -4) finished\n", 39 | "('GFSK', 6) finished\n", 40 | "('QAM64', 8) finished\n", 41 | "('AM-SSB', 12) finished\n", 42 | "('8PSK', 8) finished\n", 43 | "('8PSK', 12) finished\n", 44 | "('QAM64', -6) finished\n", 45 | "('QAM16', 2) finished\n", 46 | "('QAM16', -20) finished\n", 47 | "('PAM4', -6) finished\n", 48 | "('WBFM', -18) finished\n", 49 | "('AM-DSB', 16) finished\n", 50 | "('CPFSK', 10) finished\n", 51 | "('WBFM', 6) finished\n", 52 | "('BPSK', 4) finished\n", 53 | "('BPSK', -2) finished\n", 54 | "('QPSK', -20) finished\n", 55 | "('CPFSK', -8) finished\n", 56 | "('AM-SSB', 6) finished\n", 57 | "('QAM64', -20) finished\n", 58 | "('QAM16', 12) finished\n", 59 | "('GFSK', -20) finished\n", 60 | "('AM-SSB', -12) finished\n", 61 | "('CPFSK', 0) finished\n", 62 | "('AM-DSB', 6) finished\n", 63 | "('BPSK', -16) finished\n", 64 | "('QPSK', -6) finished\n", 65 | "('8PSK', -12) finished\n", 66 | "('CPFSK', -18) finished\n", 67 | "('BPSK', -10) finished\n", 68 | "('QPSK', 8) finished\n", 69 | "('PAM4', 14) finished\n", 70 | "('AM-DSB', -10) finished\n", 71 | "('GFSK', 12) finished\n", 72 | "('QAM64', 2) finished\n", 73 | "('WBFM', -4) finished\n", 74 | "('AM-SSB', -18) finished\n", 75 | "('QAM64', -4) finished\n", 76 | "('GFSK', -4) finished\n", 77 | "('AM-DSB', 8) finished\n", 78 | "('PAM4', -16) finished\n", 79 | "('QPSK', -16) finished\n", 80 | "('BPSK', 16) finished\n", 81 | "('8PSK', -8) finished\n", 82 | "('CPFSK', 16) finished\n", 83 | "('WBFM', 0) finished\n", 84 | "('QPSK', 6) finished\n", 85 | "('BPSK', 14) finished\n", 86 | "('AM-DSB', -8) finished\n", 87 | "('GFSK', -10) finished\n", 88 | "('CPFSK', -2) finished\n", 89 | "('AM-SSB', 8) finished\n", 90 | "('GFSK', 18) finished\n", 91 | "('QAM16', 6) finished\n", 92 | "('QAM16', -16) finished\n", 93 | "('QAM64', 18) finished\n", 94 | "('AM-SSB', -2) finished\n", 95 | "('CPFSK', 6) finished\n", 96 | "('BPSK', 0) finished\n", 97 | "('BPSK', -6) finished\n", 98 | "('8PSK', -14) finished\n", 99 | "('CPFSK', -12) finished\n", 100 | "('AM-SSB', 2) finished\n", 101 | "('WBFM', 10) finished\n", 102 | "('AM-DSB', -12) finished\n", 103 | "('PAM4', 4) finished\n", 104 | "('GFSK', 10) finished\n", 105 | "('QAM16', -6) finished\n", 106 | "('QAM64', 4) finished\n", 107 | "('PAM4', -20) finished\n", 108 | "('8PSK', 10) finished\n", 109 | "('AM-SSB', -16) finished\n", 110 | "('QAM64', -10) finished\n", 111 | "('GFSK', -6) finished\n", 112 | "('AM-DSB', 2) finished\n", 113 | "('PAM4', -10) finished\n", 114 | "('QPSK', -2) finished\n", 115 | "('WBFM', -14) finished\n", 116 | "('WBFM', 12) finished\n", 117 | "('8PSK', 0) finished\n", 118 | "('QPSK', 12) finished\n", 119 | "('PAM4', 10) finished\n", 120 | "('AM-DSB', -14) finished\n", 121 | "('GFSK', 0) finished\n", 122 | "('QAM64', 14) finished\n", 123 | "('AM-SSB', 18) finished\n", 124 | "('QAM64', -8) finished\n", 125 | 
"('QAM16', 0) finished\n", 126 | "('GFSK', -16) finished\n", 127 | "('PAM4', -4) finished\n", 128 | "('QPSK', -12) finished\n", 129 | "('WBFM', -20) finished\n", 130 | "('CPFSK', 12) finished\n", 131 | "('WBFM', 4) finished\n", 132 | "('PAM4', 18) finished\n", 133 | "('BPSK', 10) finished\n", 134 | "('BPSK', -4) finished\n", 135 | "('QPSK', -18) finished\n", 136 | "('PAM4', -2) finished\n", 137 | "('CPFSK', -6) finished\n", 138 | "('AM-SSB', 4) finished\n", 139 | "('AM-DSB', -20) finished\n", 140 | "('8PSK', 16) finished\n", 141 | "('WBFM', 18) finished\n", 142 | "('QAM16', 10) finished\n", 143 | "('QAM16', -12) finished\n", 144 | "('CPFSK', 8) finished\n", 145 | "('8PSK', -16) finished\n", 146 | "('8PSK', -20) finished\n", 147 | "('AM-SSB', -6) finished\n", 148 | "('CPFSK', 2) finished\n", 149 | "('QPSK', 16) finished\n", 150 | "('AM-DSB', 4) finished\n", 151 | "('AM-DSB', -18) finished\n", 152 | "('8PSK', -10) finished\n", 153 | "('CPFSK', -16) finished\n", 154 | "('8PSK', -6) finished\n", 155 | "('QPSK', 10) finished\n", 156 | "('PAM4', 0) finished\n", 157 | "('BPSK', -20) finished\n", 158 | "('GFSK', 14) finished\n", 159 | "('QAM16', -2) finished\n", 160 | "('QAM64', 0) finished\n", 161 | "('8PSK', -4) finished\n", 162 | "('AM-SSB', -20) finished\n", 163 | "('QAM64', -14) finished\n", 164 | "('GFSK', -2) finished\n", 165 | "('AM-DSB', 14) finished\n", 166 | "('PAM4', -14) finished\n", 167 | "('QPSK', -14) finished\n", 168 | "('WBFM', -10) finished\n", 169 | "('CPFSK', 18) finished\n", 170 | "('8PSK', 4) finished\n", 171 | "('QPSK', 0) finished\n", 172 | "('BPSK', 12) finished\n", 173 | "('AM-DSB', -2) finished\n", 174 | "('GFSK', 4) finished\n", 175 | "('QAM64', 10) finished\n", 176 | "('AM-SSB', 14) finished\n", 177 | "('WBFM', 8) finished\n", 178 | "('QAM16', -10) finished\n", 179 | "('PAM4', 16) finished\n", 180 | "('QAM16', 4) finished\n", 181 | "('QAM16', 18) finished\n", 182 | "('QAM16', -18) finished\n", 183 | "('QAM64', 16) finished\n", 184 | "('PAM4', -8) finished\n", 185 | "('WBFM', 16) finished\n", 186 | "('WBFM', 14) finished\n", 187 | "('AM-SSB', -4) finished\n", 188 | "('QAM16', -4) finished\n", 189 | "('BPSK', 6) finished\n", 190 | "('BPSK', -8) finished\n", 191 | "('BPSK', 18) finished\n", 192 | "('CPFSK', -10) finished\n", 193 | "('AM-SSB', 0) finished\n", 194 | "('PAM4', 6) finished\n", 195 | "('QAM64', -18) finished\n", 196 | "('QAM16', 14) finished\n", 197 | "('QAM16', -8) finished\n", 198 | "('PAM4', -18) finished\n", 199 | "('AM-DSB', 18) finished\n", 200 | "('AM-SSB', -10) finished\n", 201 | "('QAM64', -12) finished\n", 202 | "('AM-DSB', 0) finished\n", 203 | "('BPSK', -14) finished\n", 204 | "('QPSK', -8) finished\n", 205 | "('WBFM', -16) finished\n", 206 | "('CPFSK', -20) finished\n", 207 | "('8PSK', 2) finished\n", 208 | "('QPSK', 14) finished\n", 209 | "('PAM4', 12) finished\n", 210 | "('AM-DSB', -16) finished\n", 211 | "('GFSK', 2) finished\n", 212 | "('QAM64', 12) finished\n", 213 | "('AM-SSB', 16) finished\n", 214 | "('QAM64', -2) finished\n", 215 | "('8PSK', 14) finished\n", 216 | "('GFSK', -14) finished\n", 217 | "('AM-DSB', 10) finished\n", 218 | "('WBFM', -8) finished\n", 219 | "('QPSK', -10) finished\n", 220 | "('CPFSK', 14) finished\n", 221 | "('WBFM', 2) finished\n", 222 | "('QPSK', 4) finished\n", 223 | "('BPSK', 8) finished\n", 224 | "('AM-DSB', -6) finished\n", 225 | "('CPFSK', -4) finished\n", 226 | "('AM-SSB', 10) finished\n", 227 | "('WBFM', -2) finished\n", 228 | "('8PSK', 18) finished\n", 229 | "('QAM16', 8) finished\n", 230 | "('QAM16', 
-14) finished\n", 231 | "('8PSK', -18) finished\n", 232 | "('8PSK', -2) finished\n", 233 | "('AM-SSB', -8) finished\n", 234 | "('CPFSK', 4) finished\n", 235 | "('QPSK', 18) finished\n", 236 | "('BPSK', 2) finished\n", 237 | "('BPSK', -12) finished\n", 238 | "('WBFM', -6) finished\n", 239 | "('CPFSK', -14) finished\n", 240 | "('GFSK', 16) finished\n", 241 | "('PAM4', 2) finished\n", 242 | "('GFSK', 8) finished\n", 243 | "('GFSK', -12) finished\n", 244 | "('QAM64', 6) finished\n", 245 | "('GFSK', -18) finished\n", 246 | "('AM-SSB', -14) finished\n", 247 | "('QAM64', -16) finished\n", 248 | "('QAM16', 16) finished\n", 249 | "('GFSK', -8) finished\n", 250 | "('AM-DSB', 12) finished\n", 251 | "('PAM4', -12) finished\n", 252 | "('QPSK', -4) finished\n", 253 | "('WBFM', -12) finished\n", 254 | "('8PSK', 6) finished\n", 255 | "('BPSK', -18) finished\n" 256 | ] 257 | } 258 | ], 259 | "source": [ 260 | "CWTdata = dict()\n", 261 | "for i in data.keys():\n", 262 | " pp = []\n", 263 | " for j in range(len(data[i])):\n", 264 | " coefs1, freqs = pywt.cwt(data[i][j][0], np.arange(1, 100), 'morl')\n", 265 | " coefs2, freqs = pywt.cwt(data[i][j][1], np.arange(1, 100), 'morl')\n", 266 | " coef = np.expand_dims(np.stack([coefs1,coefs2]),0)\n", 267 | " pp.append(coef)\n", 268 | " print(i,\"finished\")\n", 269 | " pp = np.vstack(pp)\n", 270 | " CWTdata[i] = pp\n", 271 | " \n", 272 | "\n", 273 | "with open('CWTdata.pkl', 'wb') as f:\n", 274 | " pickle.dump(CWTdata, f)\n", 275 | "\n", 276 | "\n" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "# traindataset,valset,testset,classes,snrs,snr_indexs = get_signal('RML2016.10a_dict.pkl','CWTdata.pkl',1,L=1,snrs_index = 0)\n", 286 | "# train_loader = torch.utils.data.DataLoader(traindataset, batch_size=1, shuffle=True)" 287 | ] 288 | } 289 | ], 290 | "metadata": { 291 | "kernelspec": { 292 | "display_name": "base", 293 | "language": "python", 294 | "name": "python3" 295 | }, 296 | "language_info": { 297 | "codemirror_mode": { 298 | "name": "ipython", 299 | "version": 3 300 | }, 301 | "file_extension": ".py", 302 | "mimetype": "text/x-python", 303 | "name": "python", 304 | "nbconvert_exporter": "python", 305 | "pygments_lexer": "ipython3", 306 | "version": "3.8.10" 307 | } 308 | }, 309 | "nbformat": 4, 310 | "nbformat_minor": 4 311 | } 312 | --------------------------------------------------------------------------------