├── README.md
├── CsiNet.py
└── network.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

## Description

A PyTorch implementation of the network architecture from the paper [Convolutional neural network based multiple-rate compressive sensing for massive MIMO CSI feedback: Design, simulation, and analysis](https://arxiv.org/abs/1906.06007). Only the CsiNet+ part is implemented; SM-CsiNet+ and PM-CsiNet+, the main contributions of the paper, may be implemented later.

## Comparison with CsiNet

* Larger convolution kernels are used; the main goal is a larger receptive field (especially in the outdoor scenario and at high compression ratios, where more global information is needed).
* The convolution layer after the decoder is removed: the RefineNet output is already sufficient to recover the CSI, and adding an extra convolution layer actually makes the result worse (this is the authors' explanation; no ablation experiment was run).

## References

[1] C. Wen, W. Shih and S. Jin, "Deep Learning for Massive MIMO CSI Feedback," IEEE Wireless Communications Letters, vol. 7, no. 5, pp. 748-751, Oct. 2018.

[2] J. Guo, C.-K. Wen, S. Jin, and G. Y. Li, "Convolutional neural network based multiple-rate compressive sensing for massive MIMO CSI feedback: Design, simulation, and analysis," arXiv preprint arXiv:1906.06007, 2019.

--------------------------------------------------------------------------------
/CsiNet.py:
--------------------------------------------------------------------------------

import torch
import torch.nn as nn
from collections import OrderedDict


# PyTorch implementation of CsiNet


class ConvBN(nn.Sequential):  # convolution + batch normalization + activation
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, groups=1):
        if not isinstance(kernel_size, int):
            padding = [(i - 1) // 2 for i in kernel_size]
        else:
            padding = (kernel_size - 1) // 2  # "same" padding keeps the spatial size of the feature map
        super(ConvBN, self).__init__(OrderedDict([
            ('conv', nn.Conv2d(in_planes, out_planes, kernel_size, stride,
                               padding=padding, groups=groups, bias=False)),  # bias is redundant: the following BatchNorm cancels it
            ('bn', nn.BatchNorm2d(out_planes)),  # BatchNorm2d takes the number of output channels
            ('relu', nn.LeakyReLU(negative_slope=0.3, inplace=True))
        ]))
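
# Shape sanity check (my addition, not part of the original repo): with the
# "same" padding computed above, a ConvBN block preserves the 32x32 spatial
# size of the CSI matrix and only changes the channel count.
def _convbn_shape_check():
    block = ConvBN(2, 8, kernel_size=3)
    out = block(torch.randn(1, 2, 32, 32))
    assert out.shape == (1, 8, 32, 32)  # channels change, H and W do not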

class ResBlock(nn.Module):
    def __init__(self):
        super(ResBlock, self).__init__()
        self.direct_path = nn.Sequential(OrderedDict([
            ("conv_1", ConvBN(2, 8, kernel_size=3)),
            ("conv_2", ConvBN(8, 16, kernel_size=3)),
            ("conv_3", nn.Conv2d(16, 2, kernel_size=3, stride=1, padding=1)),
            ("bn", nn.BatchNorm2d(2))
        ]))
        self.identity = nn.Identity()
        self.relu = nn.LeakyReLU(negative_slope=0.3, inplace=True)

    def forward(self, x):
        identity = self.identity(x)
        out = self.direct_path(x)
        out = self.relu(out + identity)

        return out


class CsiNet(nn.Module):
    def __init__(self, reduction=4):
        super(CsiNet, self).__init__()
        total_size, in_channel, w, h = 2048, 2, 32, 32
        dim_out = total_size // reduction

        self.encoder_convbn = ConvBN(in_channel, 2, kernel_size=3)
        self.encoder_fc = nn.Linear(total_size, dim_out)

        self.decoder_fc = nn.Linear(dim_out, total_size)
        self.decoder_RefineNet1 = ResBlock()
        self.decoder_RefineNet2 = ResBlock()
        self.decoder_conv = nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1)
        self.decoder_bn = nn.BatchNorm2d(2)
        self.decoder_sigmoid = nn.Sigmoid()

    def forward(self, x):
        n, c, h, w = x.size()  # .size() does not track gradients, so no detach() is needed
        x = self.encoder_convbn(x)
        x = x.view(n, -1)  # flatten
        x = self.encoder_fc(x)
        # x is now the codeword that would be fed back to the transmitter

        x = self.decoder_fc(x)
        x = x.view(n, c, h, w)
        x = self.decoder_RefineNet1(x)
        x = self.decoder_RefineNet2(x)
        x = self.decoder_conv(x)
        x = self.decoder_bn(x)
        x = self.decoder_sigmoid(x)

        return x


if __name__ == "__main__":
    x = torch.ones(10, 2, 32, 32)
    net = CsiNet()
    out = net(x)
    print(out.shape)

--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------

import torch
import torch.nn as nn
from collections import OrderedDict


class ConvLayer(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, activation="LeakyReLu"):
        padding = (kernel_size - 1) // 2  # "same" padding
        dict_activation = {"LeakyReLu": nn.LeakyReLU(negative_slope=0.3, inplace=True),
                           "Sigmoid": nn.Sigmoid(),
                           "Tanh": nn.Tanh()}
        activation_layer = dict_activation[activation]
        super(ConvLayer, self).__init__(OrderedDict([
            ("conv", nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding=padding, bias=False)),
            ("bn", nn.BatchNorm2d(out_planes)),
            ("activation", activation_layer)
        ]))


class RefineNetBlock(nn.Module):
    def __init__(self):
        super(RefineNetBlock, self).__init__()
        # a 7x7 conv + a 5x5 conv + a 3x3 conv, plus a skip connection
        self.direct = nn.Sequential(OrderedDict([
            ("conv_7x7", ConvLayer(2, 8, 7, activation="LeakyReLu")),
            ("conv_5x5", ConvLayer(8, 16, 5, activation="LeakyReLu")),
            ("conv_3x3", ConvLayer(16, 2, 3, activation="Tanh"))
        ]))
        self.identity = nn.Identity()
        self.relu = nn.ReLU()

    def forward(self, x):
        identity = self.identity(x)
        out = self.direct(x)
        out = self.relu(out + identity)

        return out
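
# NMSE helper (my addition, not in the original repo): normalized MSE in dB is
# the metric the CsiNet papers report. Assumes H is stored as real/imaginary
# channels of shape (N, 2, 32, 32); evaluation on the real dataset would first
# undo the normalization applied during preprocessing.
def nmse_db(h_hat, h):
    diff = (h_hat - h).reshape(h.size(0), -1)
    ref = h.reshape(h.size(0), -1)
    mse = diff.pow(2).sum(dim=1) / ref.pow(2).sum(dim=1)
    return 10 * torch.log10(mse.mean())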

class CsiNetPlus(nn.Module):
    def __init__(self, reduction=4):
        super(CsiNetPlus, self).__init__()
        total_size, in_channel, w, h = 2048, 2, 32, 32
        self.encoder_conv = nn.Sequential(OrderedDict([
            ("conv1_7x7", ConvLayer(2, 2, 7, activation="LeakyReLu")),
            ("conv2_7x7", ConvLayer(2, 2, 7, activation="LeakyReLu"))
        ]))
        self.encoder_fc = nn.Linear(total_size, total_size // reduction)

        self.decoder_fc = nn.Linear(total_size // reduction, total_size)
        self.decoder_conv = ConvLayer(2, 2, 7, activation="Sigmoid")
        self.decoder_refine = nn.Sequential(OrderedDict([
            (f"RefineNet{i + 1}", RefineNetBlock()) for i in range(5)
        ]))

    def forward(self, x):
        n, c, h, w = x.size()
        out = self.encoder_conv(x)
        out = self.encoder_fc(out.view(n, -1))

        out = self.decoder_fc(out).view(n, c, h, w)
        out = self.decoder_conv(out)
        out = self.decoder_refine(out)

        return out


# PyTorch now ships nn.Flatten(); DeFlatten below is its inverse.
class DeFlatten(nn.Module):
    def __init__(self, size):
        super(DeFlatten, self).__init__()
        self.size = size

    def forward(self, x):
        x = x.view(self.size)
        return x


class Encoder(nn.Module):
    # Minimal completion of the original `pass` stub (an assumption, not the
    # paper's verified design): the CsiNet+ encoder convolutions followed by a
    # fully connected layer down to a 512-dimensional codeword.
    def __init__(self, total_size=2048, dim_out=512):
        super(Encoder, self).__init__()
        self.conv = nn.Sequential(OrderedDict([
            ("conv1_7x7", ConvLayer(2, 2, 7, activation="LeakyReLu")),
            ("conv2_7x7", ConvLayer(2, 2, 7, activation="LeakyReLu"))
        ]))
        self.fc = nn.Linear(total_size, dim_out)

    def forward(self, x):
        n = x.size(0)
        return self.fc(self.conv(x).view(n, -1))


class Decoder(nn.Module):
    # Minimal completion of the original `pass` stub (an assumption): mirrors
    # the CsiNetPlus decoder for one compression ratio.
    def __init__(self, reduction=4, total_size=2048):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(total_size // reduction, total_size)
        self.conv = ConvLayer(2, 2, 7, activation="Sigmoid")
        self.refine = nn.Sequential(OrderedDict([
            (f"RefineNet{i + 1}", RefineNetBlock()) for i in range(5)
        ]))

    def forward(self, x):
        n = x.size(0)
        out = self.fc(x).view(n, 2, 32, 32)
        out = self.conv(out)
        return self.refine(out)


class SM_CsiNet(nn.Module):
    def __init__(self):
        super(SM_CsiNet, self).__init__()
        total_size, in_channel, w, h = 2048, 2, 32, 32
        self.encoder_conv = nn.Sequential(OrderedDict([
            ("conv1_7x7", ConvLayer(2, 2, 7, activation="LeakyReLu")),
            ("conv2_7x7", ConvLayer(2, 2, 7, activation="LeakyReLu"))
        ]))
        # serial compression: 2048 -> 512 -> 256 -> 128 -> 64 (CR 4, 8, 16, 32)
        self.fc_cr4 = nn.Linear(total_size, total_size // 4)
        self.fc_cr8 = nn.Linear(total_size // 4, total_size // 8)
        self.fc_cr16 = nn.Linear(total_size // 8, total_size // 16)
        self.fc_cr32 = nn.Linear(total_size // 16, total_size // 32)
        self.decoder_cr4 = Decoder(reduction=4)
        self.decoder_cr8 = Decoder(reduction=8)
        self.decoder_cr16 = Decoder(reduction=16)
        self.decoder_cr32 = Decoder(reduction=32)

    def forward(self, x, idx):
        n = x.size(0)
        x = self.encoder_conv(x).view(n, -1)
        # Flow control during training: the paper trains end to end; with an
        # alternating scheme, the layers of the other rates would need freezing.

        dim_cr4 = self.fc_cr4(x)
        dim_cr8 = self.fc_cr8(dim_cr4)
        dim_cr16 = self.fc_cr16(dim_cr8)
        dim_cr32 = self.fc_cr32(dim_cr16)

        out_cr4 = self.decoder_cr4(dim_cr4)
        out_cr8 = self.decoder_cr8(dim_cr8)
        out_cr16 = self.decoder_cr16(dim_cr16)
        out_cr32 = self.decoder_cr32(dim_cr32)

        out_list = [out_cr4, out_cr8, out_cr16, out_cr32]
        # idx is a list of indices: idx = [0, 1, 2, 3] returns all outputs,
        # idx = [0, 1] returns only a subset
        out = [out_list[i] for i in idx]

        return out


class PM_CsiNet(nn.Module):
    def __init__(self):
        super(PM_CsiNet, self).__init__()
        self.encoder = Encoder()
        # eight parallel 512 -> 64 heads; concatenating 1, 2, 4 or 8 of them
        # yields codewords of 64, 128, 256 and 512 (CR 32, 16, 8, 4)
        self.fc_list = nn.ModuleList(nn.Linear(512, 64) for i in range(8))
        self.decoder_list = nn.ModuleList(Decoder(reduction=i) for i in [32, 16, 8, 4])

    def forward(self, x):
        dim_512 = self.encoder(x)
        dim_64_list = [self.fc_list[i](dim_512) for i in range(8)]
        cr_out_list = [torch.cat(dim_64_list[:2 ** i], dim=1) for i in range(4)]
        out_list = [self.decoder_list[i](cr_out_list[i]) for i in range(4)]

        return out_list


if __name__ == "__main__":
    x = torch.ones(10, 2, 32, 32)
    net = CsiNetPlus()
    print(net)
    out = net(x)
    print(out.shape)
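
# Minimal training sketch (my addition, not from the original repo): random
# tensors stand in for the COST2100 CSI dataset, and MSE loss with Adam are
# assumptions rather than settings taken from this code base.
def _train_sketch(steps=10):
    net = CsiNetPlus(reduction=4)
    opt = torch.optim.Adam(net.parameters(), lr=1e-3)
    criterion = nn.MSELoss()
    h = torch.rand(32, 2, 32, 32)  # stand-in batch of truncated CSI matrices
    for _ in range(steps):
        opt.zero_grad()
        loss = criterion(net(h), h)  # autoencoder: reconstruct the input
        loss.backward()
        opt.step()
    return loss.item()
--------------------------------------------------------------------------------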