├── .gitignore ├── LICENSE ├── README.md └── multiModalityFusionForClassification.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2020, woosual 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # multiModalityFusionForClassification 2 | 多模态数据融合:为了完成多模态数据融合,首先利用VGG16网络和cifar10数据集完成多输入网络的分类,在VGG16的基础之上,将前三层特征提取网络作为不同输入的特征提取网络,在中间层进行特征拼接,后面的卷积层用于提取融合特征,最后加上全连接层。该网络稍作修改就能同时提取两张对应的图片作为输入,在特征提取之后进行融合用于分类。 3 | -------------------------------------------------------------------------------- /multiModalityFusionForClassification.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 2020-09-24 3 | 作者:吴愚 4 | 研究目标: 5 | 用于学习多模态目标识别网络 6 | ----------------------------------------------------------- 7 | -首先从多分类开始做起(多分类网络): - 8 | -输入:一张原始图片用模块A提取特征,同时用模块B提取特征最后将特征拼接 - 9 | -中间层:采用特征向量叠加的方式 torch.cat - 10 | -输出:图片的类别 - 11 | ----------------------------------------------------------- 12 | ''' 13 | 14 | import torch 15 | import torch.nn as nn 16 | from torch import optim 17 | from torch.autograd import Variable 18 | from torch.utils.data import DataLoader 19 | from torchvision import transforms 20 | from torchvision import datasets 21 | from tqdm import tqdm 22 | from torch.optim import lr_scheduler 23 | 24 | '''定义超参数''' 25 | batch_size = 64 # 批的大小 26 | learning_rate = 1e-2 # 学习率 27 | num_epoches = 200 # 遍历训练集的次数 28 | 29 | ''' 30 | transform = transforms.Compose([ 31 | transforms.RandomSizedCrop(224), 32 | 
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                          std=[0.229, 0.224, 0.225]),
# ])
# '''
# End of the disabled augmentation pipeline: it is kept for reference only and
# is never executed; both data loaders below use a plain ToTensor() transform.


# (in_channels, out_channels, followed_by_2x2_max_pool) for each conv stage.
# Layers 1-3: per-branch feature extractor (one pooling step: 32x32 -> 16x16).
_BRANCH_SPEC = (
    (3, 64, False),
    (64, 64, True),
    (64, 128, False),
)

# Layers 4-13: convolutions applied to the concatenated branch features
# (four pooling steps: 16x16 -> 1x1 for a 32x32 input image).
_FUSION_SPEC = (
    (256, 256, True),
    (256, 512, False),
    (512, 512, False),
    (512, 512, True),
    (512, 1024, False),
    (1024, 1024, False),
    (1024, 1024, True),
    (1024, 1024, False),
    (1024, 1024, False),
    (1024, 1024, True),
)


def _make_layers(spec):
    """Expand a stage spec into a flat list of Conv-BN-ReLU(-MaxPool) layers.

    The layers are emitted in the order conv, batch-norm, ReLU, optional
    max-pool for every stage, so wrapping the result in ``nn.Sequential``
    yields the same sub-module indices as writing the layers out by hand.
    """
    layers = []
    for in_channels, out_channels, pool in spec:
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(True))
        if pool:
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return layers


class fusion(nn.Module):
    """Two-branch VGG-style classifier with mid-level feature fusion.

    The same image is fed to two independent three-layer convolutional
    branches (``featuresA`` and ``featuresB``). Their outputs are concatenated
    along the channel dimension (128 + 128 = 256 channels), passed through ten
    further convolutional layers (``fusionFeature``), flattened, and classified
    by a single linear layer (``classifier_test``).

    The network contains five 2x2 max-pooling steps in total and the linear
    head expects exactly 1024 features, so inputs must be 3-channel 32x32
    images (the CIFAR-10 format).

    Args:
        num_classes: Number of output classes (size of the logit vector).
    """

    def __init__(self, num_classes=10):
        super(fusion, self).__init__()
        self.featuresA = nn.Sequential(*_make_layers(_BRANCH_SPEC))
        self.featuresB = nn.Sequential(*_make_layers(_BRANCH_SPEC))
        self.fusionFeature = nn.Sequential(
            *_make_layers(_FUSION_SPEC),
            # Kernel size 1 / stride 1: leaves the tensor unchanged.
            nn.AvgPool2d(kernel_size=1, stride=1),
        )
        # Head actually used by forward(). It now honours ``num_classes``
        # instead of a hard-coded 10.
        self.classifier_test = nn.Linear(1024, num_classes)
        # VGG-style three-layer head. forward() does not use it; it is kept so
        # the module's parameter set (and therefore its state_dict keys) stays
        # the same as before.
        self.classifier = nn.Sequential(
            # 14
            nn.Linear(1024, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            # 15
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            # 16
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images ``x``."""
        outA = self.featuresA(x)
        outB = self.featuresB(x)
        # Fuse the two branches by stacking their feature maps channel-wise.
        out_fusion = torch.cat((outA, outB), dim=1)

        out_fusion = self.fusionFeature(out_fusion)
        out_fusion = out_fusion.view(out_fusion.size(0), -1)
        return self.classifier_test(out_fusion)


def train_one_epoch(model, batches, criterion, optimizer, use_gpu=False):
    """Run one optimisation pass over ``batches``.

    Args:
        model: Network to train; it is switched to training mode first.
        batches: Iterable yielding ``(images, labels)`` tensor pairs.
        criterion: Loss function called as ``criterion(logits, labels)``.
        optimizer: Optimizer that owns the parameters of ``model``.
        use_gpu: When true, every batch is moved to the GPU with ``.cuda()``.

    Returns:
        A ``(loss_sum, num_correct)`` tuple: the per-batch mean loss weighted
        by batch size and summed over all batches, and the number of correctly
        classified samples.
    """
    # evaluate() leaves the model in eval mode, so training mode has to be
    # restored at the start of every epoch (BatchNorm behaves differently).
    model.train()
    loss_sum = 0.0
    num_correct = 0
    for img, label in batches:
        if use_gpu:
            img = img.cuda()
            label = label.cuda()
        # Forward pass.
        out = model(img)
        loss = criterion(out, label)
        # Backward pass followed by the parameter update. Without step() the
        # gradients are computed but the weights never change.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item() * label.size(0)
        _, pred = torch.max(out, 1)  # index of the largest logit = predicted class
        num_correct += (pred == label).sum().item()
    return loss_sum, num_correct


def evaluate(model, batches, criterion, use_gpu=False):
    """Measure loss and accuracy of ``model`` on ``batches`` without training.

    Args:
        model: Network to evaluate; it is switched to eval mode and left there.
        batches: Iterable yielding ``(images, labels)`` tensor pairs.
        criterion: Loss function called as ``criterion(logits, labels)``.
        use_gpu: When true, every batch is moved to the GPU with ``.cuda()``.

    Returns:
        A ``(loss_sum, num_correct)`` tuple with the same meaning as in
        ``train_one_epoch``.
    """
    model.eval()
    loss_sum = 0.0
    num_correct = 0
    # no_grad() replaces the removed ``volatile=True`` flag: no autograd graph
    # is recorded while scoring the test set.
    with torch.no_grad():
        for img, label in batches:
            if use_gpu:
                img = img.cuda()
                label = label.cuda()
            out = model(img)
            loss = criterion(out, label)
            loss_sum += loss.item() * label.size(0)
            _, pred = torch.max(out, 1)
            num_correct += (pred == label).sum().item()
    return loss_sum, num_correct


def main():
    """Train ``fusion`` on CIFAR-10, report metrics per epoch, save the weights."""
    # Download (if necessary) and load the CIFAR-10 training and test splits.
    to_tensor = transforms.ToTensor()
    train_dataset = datasets.CIFAR10('./data', train=True, transform=to_tensor, download=True)
    # Uses the ``batch_size`` hyperparameter defined at the top of the file.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataset = datasets.CIFAR10('./data', train=False, transform=to_tensor, download=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Build the model and move it to the GPU when one is available.
    model = fusion()
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        model = model.cuda()

    # Loss and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    for epoch in range(num_epoches):
        print('*' * 25, 'epoch {}'.format(epoch + 1), '*' * 25)
        running_loss, running_acc = train_one_epoch(
            model, tqdm(train_loader), criterion, optimizer, use_gpu)
        print('Finish {} epoch, Loss: {:.6f}, Acc: {:.6f}'.format(
            epoch + 1, running_loss / (len(train_dataset)), running_acc / (len(train_dataset))))

        eval_loss, eval_acc = evaluate(model, test_loader, criterion, use_gpu)
        print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
            test_dataset)), eval_acc / (len(test_dataset))))
        print()

    # Save the trained weights.
    torch.save(model.state_dict(), './cnn.pth')


if __name__ == '__main__':
    main()