├── README.md ├── .gitignore ├── feature_visualization.py ├── predict_new.py ├── data_pro └── split_train_val_test.py ├── my_data_loader.py ├── my_net.py ├── my_cat_net.py ├── predict.py ├── train_cnn_vgg11.py └── train_cnn.py /README.md: -------------------------------------------------------------------------------- 1 | # fault_diagnosis_cnn_pytorch 2 | 3 | ``` 4 | $$ step 1 5 | raw singal >> time-frequence image 6 | 7 | $$ step 2 8 | cropped time-frequence image getting training,valdiaton and test image 9 | 10 | $$ step 3 11 | split train,val and test data; generate txt file 12 | python data_pro/split_train_val_test.py 13 | 14 | $$ step 4 15 | train and validation 16 | python train_cnn.py 17 | 18 | $$ step 5 19 | test 20 | python predict.py 21 | ``` -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /feature_visualization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torchvision import transforms 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | import cv2 8 | 9 | 10 | visual_transforms = transforms.Compose([ 11 | transforms.Resize((224, 224)), 12 | # transforms.RandomResizedCrop(input_size), 13 | # transforms.RandomHorizontalFlip(), 14 | transforms.ToTensor(), 15 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 16 | ]) 17 | 18 | 19 | class FeatureVisualization(): 20 | def __init__(self, img_path, layer_num, trainsform, model): 21 | self.img_path = img_path 22 | self.layer_num = layer_num 23 | self.model = model.features 24 | self.trainsform = trainsform 25 | 26 | def process_image(self): 27 | image = Image.open(self.img_path) 28 | image = self.trainsform(image) 29 | image_tensor = image.unsqueeze_(0) 30 | image_var = Variable(image_tensor) 31 | return image_var 32 | 33 | def get_feature(self): 34 | input = self.process_image() 35 | input = input.to(device) 36 | x = input 37 | for index, layer in enumerate(self.model): 38 | x = layer(x) 39 | if index == self.layer_num: 40 | return x 41 | 42 | def save_feature(self, feature_img_save_dir): 43 | features = self.get_feature() 44 | for i in range(features.shape[1]): 45 | feature = features[:, i, :, :] 46 | # print(feature.shape) 47 | feature = feature.view(features.shape[2], features.shape[3]) 48 | feature = feature.data.numpy() 49 | # scale the feature to [0, 1] 50 | feature = 1.0 / (1 + np.exp(-1.0 * feature)) 51 | feature = np.round(feature*255) 52 | if not os.path.exists(feature_img_save_dir): 53 | os.mkdir(feature_img_save_dir) 54 | if not os.path.exists(os.path.join(feature_img_save_dir, str(self.layer_num))): 55 | os.mkdir(os.path.join(feature_img_save_dir, str(self.layer_num))) 56 | fea_img_save_path = os.path.join(feature_img_save_dir, str(self.layer_num), str(i+1)+".jpg") 57 | cv2.imwrite(fea_img_save_path, feature) 58 | # feature.save(fea_img_save_path) 59 | 60 | 61 | if __name__ == "__main__": 62 | img_path = "/workspace/mnt/group/face/zhubin/alg_code/fault_diagnosis_cnn/jiangnan_data/jiangnan_data_2500/ib_2500/ib_108.jpg" 63 | model_path = "/workspace/mnt/group/face/zhubin/alg_code/fault_diagnosis_cnn_pytorch/model/vgg16_14.pth" 64 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 65 | device = torch.device("cpu") 66 | model = torch.load(model_path) 67 | model = model.to(device) 68 | # print(model) 69 | model.eval() 70 | feature_img_save_dir = "feature_img" 71 | for i in range(30): 72 | layer_num = i 73 | fea_visual = FeatureVisualization(img_path, layer_num, visual_transforms, model) 74 | fea_visual.save_feature(feature_img_save_dir) 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /predict_new.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from torchvision import transforms 4 | from torch.autograd import Variable 5 | import os, time 6 | 7 | 8 | 9 | 10 | 11 | def load_all_image(test_txt_file, test_trainsform): 12 | img_list = [] 13 | label_list = [] 14 | img_path_list = [] 15 | 16 | with open(test_txt_file, "r") as fr: 17 | for line in fr: 18 | img_path, cls_name = line.strip().split("\t") 19 | img_path_list.append(img_path) 20 | label_list.append(int(cls_name)) 21 | 22 | for k, v in enumerate(img_path_list): 23 | img = Image.open(v) 24 | # print(img) 25 | img = test_trainsform(img) 26 | img_list.append(img) 27 | 28 | return img_list, label_list 29 | 30 | 31 | 32 | def predict_image(image, model, device): 33 | image_tensor = image.unsqueeze_(0) 34 | input = Variable(image_tensor) 35 | input = input.to(device) 36 | output = model(input) 37 | index = output.data.cpu().numpy().argmax() 38 | return index 39 | 40 | 41 | 42 | def predict(model_path, test_txt_file): 43 | test_trainsform = transforms.Compose([ 44 | transforms.Resize((224, 224)), 45 | # transforms.Resize((48, 48)), 46 | # transforms.RandomResizedCrop(input_size), 47 | # transforms.RandomHorizontalFlip(), 48 | transforms.ToTensor(), 49 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 50 | ]) 51 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 52 | # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 53 | model = torch.load(model_path) 54 | model.eval() 55 | 56 | correct_nums = 0 57 | nums = 0 58 | img_list, label_list = load_all_image(test_txt_file, test_trainsform) 59 | since = time.time() 60 | for k, v in enumerate(label_list): 61 | img = img_list[k] 62 | label = v 63 | # img = test_trainsform(img) 64 | pre = predict_image(img, model, device) 65 | if label == pre: 66 | correct_nums += 1 67 | nums += 1 68 | time_elapsed = time.time() - since 69 | print("time: ", time_elapsed) 70 | print('Testing complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 71 | print("correct_nums: ", correct_nums) 72 | print("Test nums: ", nums) 73 | print("Accuracy: ", correct_nums*1.0/nums) 74 | 75 | 76 | if __name__ == "__main__": 77 | model_path = "/workspace/mnt/group/face/zhubin/alg_code/fault_diagnosis_cnn_pytorch/model/vgg16_2.pth" 78 | model_path = "/workspace/mnt/group/face/zhubin/alg_code/fault_diagnosis_cnn_pytorch/model/resnet_2.pth" 79 | # model_path = "/workspace/mnt/group/face/zhubin/alg_code/fault_diagnosis_cnn_pytorch/model/alexnet_35.pth" 80 | # model_path = "/workspace/mnt/group/face/zhubin/alg_code/fault_diagnosis_cnn_pytorch/model/vgg11_1.pth" 81 | 82 | isCaseW = True 83 | # isCaseW = False 84 | if isCaseW: 85 | test_txt_file = "/workspace/mnt/group/face/zhubin/alg_code/fault_diagnosis_cnn_pytorch/CaseW_train_data_file_2/test.txt" 86 | predict(model_path, test_txt_file) 87 | else: 88 | test_txt_file = "/workspace/mnt/group/face/zhubin/alg_code/fault_diagnosis_cnn_pytorch/jiangnan_train_data_file_2/test.txt" 89 | predict(model_path, test_txt_file) 90 | 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /data_pro/split_train_val_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import random 4 | 5 | ########################## 6 | # return train, val, test list 7 | 8 | ''' 9 | jiangnan: 10 | n >>>>> 0 11 | ob >>>>> 1 12 | tb >>>>> 2 13 | ib >>>>> 3 14 | ''' 15 | ########################## 16 | def split_trainval_test(input_dir, isCaseW=True): 17 | all_file = [] 18 | for cls_name in os.listdir(input_dir): 19 | if cls_name == ".DS_Store": 20 | continue 21 | file_dir = os.path.join(input_dir, cls_name) 22 | if isCaseW: 23 | for img_name in os.listdir(file_dir): 24 | if img_name.endswith(".jpg"): 25 | img_path = os.path.join(file_dir, img_name) 26 | all_file.append([img_path, cls_name]) 27 | else: 28 | if cls_name.startswith("n"): 29 | label = '0' 30 | elif cls_name.startswith("ob"): 31 | label = '1' 32 | elif cls_name.startswith("tb"): 33 | label = '2' 34 | else: 35 | label = '3' 36 | for img_name in os.listdir(file_dir): 37 | if img_name.endswith(".jpg"): 38 | img_path = os.path.join(file_dir, img_name) 39 | all_file.append([img_path, label]) 40 | random.shuffle(all_file) 41 | train = all_file[: int(len(all_file)*0.6)] 42 | val = all_file[int(len(all_file)*0.6): int(len(all_file)*0.8)] 43 | test = all_file[int(len(all_file)*0.8): ] 44 | 45 | return train, val, test 46 | 47 | def generate_train_val_test_txt_file(train, val, test, save_dir): 48 | if not os.path.exists(save_dir): 49 | os.mkdir(save_dir) 50 | 51 | train_txt_path = os.path.join(save_dir, "train.txt") 52 | val_txt_path = os.path.join(save_dir, "val.txt") 53 | test_txt_path = os.path.join(save_dir, "test.txt") 54 | 55 | train_str = "" 56 | val_str = "" 57 | test_str = "" 58 | 59 | for img_path, cls_name in train: 60 | train_str += img_path + "\t" + cls_name + "\n" 61 | for img_path, cls_name in val: 62 | val_str += img_path + "\t" + cls_name + "\n" 63 | for img_path, cls_name in test: 64 | test_str += img_path + "\t" + cls_name + "\n" 65 | with open(train_txt_path, "w") as fw: 66 | fw.write(train_str) 67 | with open(val_txt_path, "w") as fw: 68 | fw.write(val_str) 69 | with open(test_txt_path, "w") as fw: 70 | fw.write(test_str) 71 | 72 | 73 | if __name__ == "__main__": 74 | # isCaseW = True 75 | isCaseW = False 76 | if isCaseW: 77 | input_dir = "/workspace/mnt/group/face1/zhubin/alg_code/fault_diagnosis_cnn/CaseW_data/CaseW_raw_data" 78 | save_dir = "/workspace/mnt/group/face1/zhubin/alg_code/fault_diagnosis_cnn_pytorch/CaseW_train_data_file_1" 79 | train, val, test = split_trainval_test(input_dir, isCaseW) 80 | generate_train_val_test_txt_file(train, val, test, save_dir) 81 | else: 82 | input_dir = "/workspace/mnt/group/face1/zhubin/alg_code/fault_diagnosis_cnn/jiangnan_data/jiangnan_data_2500" 83 | save_dir = "/workspace/mnt/group/face1/zhubin/alg_code/fault_diagnosis_cnn_pytorch/jiangnan_train_data_file_1" 84 | train, val, test = split_trainval_test(input_dir, isCaseW) 85 | generate_train_val_test_txt_file(train, val, test, save_dir) 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /my_data_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | # @Author: binzh 4 | # @Date: 2018-10-25 14:14:25 5 | # @Last Modified by: binzh 6 | # @Last Modified time: 2018-11-25 14:14:25 7 | ''' 8 | 9 | # import PIL 10 | # import cv2 11 | import os 12 | import os.path as osp 13 | from PIL import Image 14 | import torch 15 | from torch.utils import data 16 | from torchvision import datasets, transforms 17 | 18 | 19 | 20 | class MyDataLoader(data.Dataset): 21 | """ 22 | when define the son class of torch.utils.data.Dataset, len and getitem function must reload. 23 | Args: 24 | index (int): Index 25 | 26 | Returns: 27 | tuple: (sample, target) where target is class_index of the target class. 28 | """ 29 | def __init__(self, img_root, txt_file, transforms=None, isCaseW=False, train=True): 30 | 31 | self.img_list = [] 32 | self.labels = [] 33 | self.img_root = img_root 34 | self.isCaseW = isCaseW 35 | self.read_txt_file(txt_file) 36 | self.transforms = transforms 37 | 38 | 39 | def __getitem__(self, index): 40 | """ 41 | return one image and label 42 | """ 43 | img_path = osp.join(self.img_root, self.img_list[index]) 44 | img = Image.open(img_path) 45 | img = self.transforms(img) 46 | label = self.labels[index] 47 | # print label 48 | # return img, float(label) 49 | return img, label 50 | 51 | def __len__(self, ): 52 | return len(self.img_list) 53 | 54 | def read_txt_file(self, txt_file): 55 | """ 56 | Args: 57 | txt_file (str): txt file path 58 | Operation: 59 | analysis the filename to get label 60 | Case Western: 61 | 0 >>>>> 0 62 | 1 >>>>> 1 63 | 2 >>>>> 2 64 | 3 >>>>> 3 65 | 4 >>>>> 4 66 | 5 >>>>> 5 67 | 6 >>>>> 6 68 | 7 >>>>> 7 69 | 8 >>>>> 8 70 | 9 >>>>> 9 71 | 72 | jiangnan: 73 | n >>>>> 0 74 | ob >>>>> 1 75 | tb >>>>> 2 76 | ib >>>>> 3 77 | """ 78 | with open(txt_file, "r") as fr: 79 | for line in fr: 80 | img_path, cls_name = line.strip().split("\t") 81 | temp_label = int(cls_name) 82 | self.img_list.append(img_path) 83 | self.labels.append(temp_label) 84 | 85 | 86 | # if self.isCaseW: 87 | # with open(txt_file, "r") as fr: 88 | # for line in fr: 89 | # img_path, cls_name = line.strip().split("\t") 90 | # temp_label = int(cls_name) 91 | # self.img_list.append(img_path) 92 | # self.labels.append(temp_label) 93 | # else: 94 | # with open(txt_file, "r") as fr: 95 | # for line in fr: 96 | # # img_path, cls_name = line.strip().split("\t") 97 | # img_path = line.strip().split("\t") 98 | # img_path = img_path[0] 99 | # img_name = os.path.basename(img_path) 100 | # # if img_name.startswith("n"): 101 | # # temp_label = 0 102 | # # elif img_name.startswith("ob"): 103 | # # temp_label = 1 104 | # # elif img_name.startswith("tb"): 105 | # # temp_label = 2 106 | # # else: 107 | # # temp_label = 3 108 | 109 | # img_path = img_path 110 | # self.img_list.append(img_path) 111 | # self.labels.append(temp_label) 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /my_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | class MyNet(torch.nn.Module): 5 | def __init__(self, init_weights=True): 6 | super(MyNet, self).__init__() 7 | # in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True 8 | self.conv1_1 = torch.nn.Conv2d(3, 64, 3, 1, 1) 9 | self.relu1_1 = torch.nn.ReLU(inplace=True) 10 | self.conv1_2 = torch.nn.Conv2d(64, 64, 3, 1, 1) 11 | self.relu1_2 = torch.nn.ReLU(inplace=True) 12 | 13 | self.conv2_1 = torch.nn.Conv2d(64, 128, 3, 1, 1) 14 | self.relu2_1 = torch.nn.ReLU(inplace=True) 15 | self.conv2_2 = torch.nn.Conv2d(128, 128, 3, 1, 1) 16 | self.relu2_2 = torch.nn.ReLU(inplace=True) 17 | 18 | self.conv3_1 = torch.nn.Conv2d(128, 256, 3, 1, 1) 19 | self.relu3_1 = torch.nn.ReLU(inplace=True) 20 | self.conv3_2 = torch.nn.Conv2d(256, 256, 3, 1, 1) 21 | self.relu3_2 = torch.nn.ReLU(inplace=True) 22 | self.conv3_3 = torch.nn.Conv2d(256, 256, 3, 1, 1) 23 | self.relu3_3 = torch.nn.ReLU(inplace=True) 24 | 25 | self.conv4_1 = torch.nn.Conv2d(256, 512, 3, 1, 1) 26 | self.relu4_1 = torch.nn.ReLU(inplace=True) 27 | self.conv4_2 = torch.nn.Conv2d(512, 512, 3, 1, 1) 28 | self.relu4_2 = torch.nn.ReLU(inplace=True) 29 | self.conv4_3 = torch.nn.Conv2d(512, 512, 3, 1, 1) 30 | self.relu4_3 = torch.nn.ReLU(inplace=True) 31 | 32 | self.conv5_1 = torch.nn.Conv2d(512, 512, 3, 1, 1) 33 | self.relu5_1 = torch.nn.ReLU(inplace=True) 34 | self.conv5_2 = torch.nn.Conv2d(512, 512, 3, 1, 1) 35 | self.relu5_2 = torch.nn.ReLU(inplace=True) 36 | self.conv5_3 = torch.nn.Conv2d(512, 512, 3, 1, 1) 37 | self.relu5_3 = torch.nn.ReLU(inplace=True) 38 | 39 | self.fc1 = torch.nn.Linear(512*7*7, 1000) 40 | self.relu_fc1 = torch.nn.ReLU(inplace=True) 41 | self.fc2 = torch.nn.Linear(1000, 100) 42 | self.relu_fc2 = torch.nn.ReLU(inplace=True) 43 | self.fc3 = torch.nn.Linear(100, 4) 44 | 45 | if init_weights: 46 | self._initialize_weights() 47 | 48 | 49 | def forward(self, x): 50 | x = self.relu1_1(self.conv1_1(x)) 51 | x = self.relu1_2(self.conv1_2(x)) 52 | x = F.max_pool2d(x, 2) 53 | 54 | x = self.relu2_1(self.conv2_1(x)) 55 | x = self.relu2_2(self.conv2_2(x)) 56 | x = F.max_pool2d(x, 2) 57 | 58 | x = self.relu3_1(self.conv3_1(x)) 59 | x = self.relu3_2(self.conv3_2(x)) 60 | x = self.relu3_3(self.conv3_3(x)) 61 | x = F.max_pool2d(x, 2) 62 | 63 | x = self.relu4_1(self.conv4_1(x)) 64 | x = self.relu4_2(self.conv4_2(x)) 65 | x = self.relu4_3(self.conv4_3(x)) 66 | x = F.max_pool2d(x, 2) 67 | 68 | x = self.relu5_1(self.conv5_1(x)) 69 | x = self.relu5_2(self.conv5_2(x)) 70 | x = self.relu5_3(self.conv5_3(x)) 71 | x = F.max_pool2d(x, 2) 72 | 73 | 74 | x = x.view(x.size(0), -1) 75 | x = self.relu_fc1(self.fc1(x)) 76 | x = torch.nn.Dropout()(x) 77 | x = self.relu_fc2(self.fc2(x)) 78 | x = torch.nn.Dropout()(x) 79 | x = self.fc3(x) 80 | return x 81 | 82 | def _initialize_weights(self): 83 | for m in self.modules(): 84 | if isinstance(m, torch.nn.Conv2d): 85 | torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 86 | if m.bias is not None: 87 | torch.nn.init.constant_(m.bias, 0) 88 | # elif isinstance(m, nn.BatchNorm2d): 89 | # nn.init.constant_(m.weight, 1) 90 | # nn.init.constant_(m.bias, 0) 91 | elif isinstance(m, torch.nn.Linear): 92 | torch.nn.init.normal_(m.weight, 0, 0.01) 93 | torch.nn.init.constant_(m.bias, 0) 94 | 95 | # MyNet() 96 | 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /my_cat_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | class MyNet(torch.nn.Module): 5 | def __init__(self, num_calsses=4, init_weights=True): 6 | super(MyNet, self).__init__() 7 | # in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True 8 | self.conv1 = torch.nn.Conv2d(3, 64, 3, 2, 1) 9 | self.bn1 = torch.nn.BatchNorm2d(64) 10 | self.relu1 = torch.nn.ReLU(inplace=True) 11 | 12 | 13 | self.conv2_1 = torch.nn.Conv2d(64, 64, 3, 1, 1) 14 | self.bn2_1 = torch.nn.BatchNorm2d(64) 15 | self.relu2_1 = torch.nn.ReLU(inplace=True) 16 | self.conv2_2 = torch.nn.Conv2d(64, 64, 3, 1, 1) 17 | self.bn2_2 = torch.nn.BatchNorm2d(64) 18 | self.relu2_2 = torch.nn.ReLU(inplace=True) 19 | 20 | 21 | 22 | self.conv3_1 = torch.nn.Conv2d(64, 128, 1, 1) 23 | self.bn3_1 = torch.nn.BatchNorm2d(128) 24 | self.relu3_1 = torch.nn.ReLU(inplace=True) 25 | self.conv3_2 = torch.nn.Conv2d(128, 128, 3, 1, 1) 26 | self.bn3_2 = torch.nn.BatchNorm2d(128) 27 | self.relu3_2 = torch.nn.ReLU(inplace=True) 28 | self.conv3_3 = torch.nn.Conv2d(128, 128, 1, 1) 29 | self.bn3_3 = torch.nn.BatchNorm2d(128) 30 | self.relu3_3 = torch.nn.ReLU(inplace=True) 31 | 32 | self.conv4_1 = torch.nn.Conv2d(128, 256, 1, 1) 33 | self.bn4_1 = torch.nn.BatchNorm2d(256) 34 | self.relu4_1 = torch.nn.ReLU(inplace=True) 35 | self.conv4_2 = torch.nn.Conv2d(256, 256, 3, 1, 1) 36 | self.bn4_2 = torch.nn.BatchNorm2d(256) 37 | self.relu4_2 = torch.nn.ReLU(inplace=True) 38 | self.conv4_3 = torch.nn.Conv2d(256, 256, 1, 1) 39 | self.bn4_3 = torch.nn.BatchNorm2d(256) 40 | self.relu4_3 = torch.nn.ReLU(inplace=True) 41 | 42 | self.last_conv = torch.nn.Conv2d(64+128+256, 512, 1, 1) 43 | self.last_bn = torch.nn.BatchNorm2d(512) 44 | self.last_relu = torch.nn.ReLU(inplace=True) 45 | 46 | self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1)) 47 | # self.fc = torch.nn.Linear(512, num_calsses) 48 | 49 | 50 | self.fc1 = torch.nn.Linear(512, 100) 51 | self.relu_ = torch.nn.ReLU(inplace=True) 52 | self.fc2 = torch.nn.Linear(100, num_calsses) 53 | 54 | 55 | 56 | 57 | 58 | if init_weights: 59 | self._initialize_weights() 60 | 61 | 62 | def forward(self, x): 63 | x = self.relu1(self.bn1(self.conv1(x))) 64 | 65 | x = self.relu2_1(self.bn2_1(self.conv2_1(x))) 66 | x = self.relu2_2(self.bn2_2(self.conv2_2(x))) 67 | x = F.max_pool2d(x, 2) 68 | x_2 = x 69 | 70 | x = self.relu3_1(self.bn3_1(self.conv3_1(x))) 71 | x = self.relu3_2(self.bn3_2(self.conv3_2(x))) 72 | x = self.relu3_3(self.bn3_3(self.conv3_3(x))) 73 | x = F.max_pool2d(x, 2) 74 | x_3 = x 75 | 76 | x = self.relu4_1(self.bn4_1(self.conv4_1(x))) 77 | x = self.relu4_2(self.bn4_2(self.conv4_2(x))) 78 | x = self.relu4_3(self.bn4_3(self.conv4_3(x))) 79 | x = F.max_pool2d(x, 2) 80 | 81 | x_2 = F.max_pool2d(x_2, 4) 82 | x_3 = F.max_pool2d(x_3, 2) 83 | x = torch.cat((x_2, x_3, x), 1) 84 | 85 | x = self.last_relu(self.last_bn(self.last_conv(x))) 86 | x = self.avgpool(x) 87 | 88 | x = x.view(x.size(0), -1) 89 | # x = self.fc(x) 90 | x = self.relu_(self.fc1(x)) 91 | x = self.fc2(x) 92 | 93 | 94 | return x 95 | 96 | def _initialize_weights(self): 97 | for m in self.modules(): 98 | if isinstance(m, torch.nn.Conv2d): 99 | torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 100 | if m.bias is not None: 101 | torch.nn.init.constant_(m.bias, 0) 102 | elif isinstance(m, torch.nn.BatchNorm2d): 103 | torch.nn.init.constant_(m.weight, 1) 104 | torch.nn.init.constant_(m.bias, 0) 105 | elif isinstance(m, torch.nn.Linear): 106 | torch.nn.init.normal_(m.weight, 0, 0.01) 107 | torch.nn.init.constant_(m.bias, 0) 108 | 109 | 110 | 111 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from torchvision import transforms 4 | from torch.autograd import Variable 5 | import os, time 6 | 7 | 8 | 9 | 10 | 11 | def load_image(test_txt_file, test_trainsform=None, isCaseW=True): 12 | img_name_list = [] 13 | label_list = [] 14 | 15 | with open(test_txt_file, "r") as fr: 16 | for line in fr: 17 | img_path, cls_name = line.strip().split("\t") 18 | img_name_list.append(img_path) 19 | label_list.append(int(cls_name)) 20 | 21 | # if isCaseW: 22 | # with open(test_txt_file, "r") as fr: 23 | # for line in fr: 24 | # img_path, cls_name = line.strip().split("\t") 25 | # img_name_list.append(img_path) 26 | # label_list.append(int(cls_name)) 27 | # else: 28 | # with open(test_txt_file, "r") as fr: 29 | # for line in fr: 30 | # img_path, cls_name = line.strip().split("\t") 31 | # img_path = line.strip().split("\t") 32 | # img_path = img_path[0] 33 | # img_name = os.path.basename(line.strip()) 34 | # if img_name.startswith("n"): 35 | # label = 0 36 | # elif img_name.startswith("ob"): 37 | # label = 1 38 | # elif img_name.startswith("tb"): 39 | # label = 2 40 | # else: 41 | # label = 3 42 | # img_name_list.append(img_path) 43 | # label_list.append(label) 44 | 45 | for k, v in enumerate(img_name_list): 46 | img = Image.open(v) 47 | img = test_trainsform(img) 48 | label = label_list[k] 49 | yield img, label, v 50 | 51 | 52 | def predict_image(image, model, device): 53 | image_tensor = image.unsqueeze_(0) 54 | input = Variable(image_tensor) 55 | input = input.to(device) 56 | output = model(input) 57 | index = output.data.cpu().numpy().argmax() 58 | return index 59 | 60 | 61 | 62 | def predict(model_path, test_txt_file, isCaseW=True): 63 | test_trainsform = transforms.Compose([ 64 | # transforms.Resize((224, 224)), 65 | transforms.Resize((48, 48)), 66 | # transforms.RandomResizedCrop(input_size), 67 | # transforms.RandomHorizontalFlip(), 68 | transforms.ToTensor(), 69 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 70 | ]) 71 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 72 | # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 73 | model = torch.load(model_path) 74 | model.eval() 75 | 76 | correct_nums = 0 77 | nums = 0 78 | since = time.time() 79 | for val in load_image(test_txt_file, test_trainsform, isCaseW): 80 | img = val[0] 81 | label = val[1] 82 | pre = predict_image(img, model, device) 83 | if label == pre: 84 | correct_nums += 1 85 | nums += 1 86 | time_elapsed = time.time() - since 87 | print("time: ", time_elapsed) 88 | print('Testing complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 89 | print("correct_nums: ", correct_nums) 90 | print("Test nums: ", nums) 91 | print("Accuracy: ", correct_nums*1.0/nums) 92 | 93 | 94 | if __name__ == "__main__": 95 | model_path = "/workspace/mnt/group/face/zhubin/alg_code/fault_diagnosis_cnn_pytorch/model/vgg16_8.pth" 96 | model_path = "/workspace/mnt/group/face/zhubin/alg_code/fault_diagnosis_cnn_pytorch/model/resnet_35.pth" 97 | model_path = "/workspace/mnt/group/face/zhubin/alg_code/fault_diagnosis_cnn_pytorch/model/alexnet_47.pth" 98 | model_path = "/workspace/mnt/group/face/zhubin/alg_code/fault_diagnosis_cnn_pytorch/model/vgg11_17.pth" 99 | 100 | # isCaseW = True 101 | isCaseW = False 102 | if isCaseW: 103 | test_txt_file = "/workspace/mnt/group/face/zhubin/alg_code/fault_diagnosis_cnn_pytorch/CaseW_train_data_file_1/test.txt" 104 | predict(model_path, test_txt_file, isCaseW) 105 | else: 106 | test_txt_file = "/workspace/mnt/group/face/zhubin/alg_code/fault_diagnosis_cnn_pytorch/jiangnan_train_data_file_1/test.txt" 107 | predict(model_path, test_txt_file, isCaseW) 108 | 109 | 110 | 111 | -------------------------------------------------------------------------------- /train_cnn_vgg11.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.optim import lr_scheduler 5 | import numpy as np 6 | import torchvision 7 | from torchvision import transforms 8 | import time 9 | import os 10 | import copy 11 | import torch.utils.data as Data 12 | from torch.autograd import Variable 13 | from my_data_loader import MyDataLoader 14 | # from my_net import * 15 | from my_cat_net import * 16 | 17 | 18 | print("PyTorch Version: ",torch.__version__) 19 | print("Torchvision Version: ",torchvision.__version__) 20 | 21 | model_name = "vgg11" 22 | 23 | 24 | def train_model(model, dataloaders, criterion, optimizer, scheduler=None, num_epochs=25): 25 | since = time.time() 26 | 27 | train_loss_list = [] 28 | train_acc_list = [] 29 | val_loss_list = [] 30 | val_acc_list = [] 31 | 32 | val_acc_history = [] 33 | 34 | best_model_wts = copy.deepcopy(model.state_dict()) 35 | best_acc = 0.0 36 | 37 | for epoch in range(num_epochs): 38 | print('Epoch {}/{}'.format(epoch+1, num_epochs)) 39 | print('-' * 10) 40 | 41 | # Each epoch has a training and validation phase 42 | for phase in ['train', 'val']: 43 | if phase == 'train': 44 | scheduler.step() 45 | model.train() # Set model to training mode 46 | else: 47 | model.eval() # Set model to evaluate mode 48 | 49 | running_loss = 0.0 50 | running_corrects = 0 51 | 52 | # Iterate over data. 53 | for inputs, labels in dataloaders[phase]: 54 | inputs = inputs.to(device) 55 | labels = labels.to(device) 56 | inputs, labels = Variable(inputs), Variable(labels) 57 | # print(inputs.shape) 58 | # print(labels.shape) 59 | 60 | # zero the parameter gradients 61 | optimizer.zero_grad() 62 | 63 | # forward 64 | # track history if only in train 65 | with torch.set_grad_enabled(phase == 'train'): 66 | # Get model outputs and calculate loss 67 | # Special case for inception because in training it has an auxiliary output. In train 68 | # mode we calculate the loss by summing the final output and the auxiliary output 69 | # but in testing we only consider the final output. 70 | outputs = model(inputs) 71 | loss = criterion(outputs, labels) 72 | 73 | _, preds = torch.max(outputs, 1) 74 | 75 | # backward + optimize only if in training phase 76 | if phase == 'train': 77 | loss.backward() 78 | optimizer.step() 79 | 80 | # statistics 81 | running_loss += loss.item() * inputs.size(0) 82 | running_corrects += torch.sum(preds == labels.data) 83 | 84 | epoch_loss = running_loss / len(dataloaders[phase].dataset) 85 | epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset) 86 | 87 | if phase == 'train': 88 | train_loss_list.append(epoch_loss) 89 | train_acc_list.append(epoch_acc) 90 | else: 91 | val_loss_list.append(epoch_loss) 92 | val_acc_list.append(epoch_acc) 93 | # print(train_acc_list) 94 | # print(type(train_acc_list[0])) 95 | 96 | 97 | 98 | 99 | print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc)) 100 | 101 | # deep copy the model 102 | if phase == 'val' and epoch_acc > best_acc: 103 | best_acc = epoch_acc 104 | best_model_wts = copy.deepcopy(model.state_dict()) 105 | if phase == 'val': 106 | val_acc_history.append(epoch_acc) 107 | 108 | print("save model...") 109 | model_save_path = "/workspace/mnt/group/face/zhubin/alg_code/fault_diagnosis_cnn_pytorch/model/"+ model_name + "_" + str(epoch+1) + ".pth" 110 | torch.save(model, model_save_path) 111 | print() 112 | 113 | time_elapsed = time.time() - since 114 | print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 115 | print('Best val Acc: {:4f}'.format(best_acc)) 116 | with open("model/train_loss.txt", "w") as fw: 117 | for val in train_loss_list: 118 | fw.write(str(val) + "\n") 119 | 120 | with open("model/train_acc.txt", "w") as fw: 121 | for val in train_acc_list: 122 | fw.write(str(val) + "\n") 123 | 124 | with open("model/val_loss.txt", "w") as fw: 125 | for val in val_loss_list: 126 | fw.write(str(val) + "\n") 127 | 128 | with open("model/val_acc.txt", "w") as fw: 129 | for val in val_acc_list: 130 | fw.write(str(val) + "\n") 131 | 132 | 133 | # load best model weights 134 | model.load_state_dict(best_model_wts) 135 | return model, val_acc_history 136 | 137 | 138 | input_size = 48 139 | 140 | data_transforms = { 141 | 'train': transforms.Compose([ 142 | transforms.Resize((input_size, input_size)), 143 | # transforms.RandomResizedCrop(input_size), 144 | # transforms.RandomHorizontalFlip(), 145 | transforms.ToTensor(), 146 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 147 | ]), 148 | 'val': transforms.Compose([ 149 | transforms.Resize((input_size, input_size)), 150 | # transforms.CenterCrop(input_size), 151 | transforms.ToTensor(), 152 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 153 | ]), 154 | } 155 | 156 | 157 | 158 | print("Initializing Datasets and Dataloaders...") 159 | 160 | # # Create training and validation datasets 161 | # image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']} 162 | # # Create training and validation dataloaders 163 | # dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4) for x in ['train', 'val']} 164 | 165 | 166 | # isCaseW = True 167 | isCaseW = False 168 | if isCaseW: 169 | img_root_dir = "" 170 | train_txt_path = "/workspace/mnt/group/face1/zhubin/alg_code/fault_diagnosis_cnn_pytorch/CaseW_train_data_file_1/train.txt" 171 | val_txt_path = "/workspace/mnt/group/face1/zhubin/alg_code/fault_diagnosis_cnn_pytorch/CaseW_train_data_file_1/val.txt" 172 | train_batch_size = 75 173 | test_batch_size = 10 174 | num_calsses = 10 175 | # lr = 0.1 weight_decay=0.0005 176 | else: 177 | img_root_dir = "" 178 | train_txt_path = "/workspace/mnt/group/face1/zhubin/alg_code/fault_diagnosis_cnn_pytorch/jiangnan_train_data_file_1/train.txt" 179 | val_txt_path = "/workspace/mnt/group/face1/zhubin/alg_code/fault_diagnosis_cnn_pytorch/jiangnan_train_data_file_1/val.txt" 180 | train_batch_size = 72 181 | test_batch_size = 10 182 | num_calsses = 4 183 | # lr = 0.1 weight_decay=0.0005 184 | 185 | train_dataset = MyDataLoader(img_root=img_root_dir, txt_file=train_txt_path, transforms=data_transforms["train"], isCaseW=isCaseW) 186 | train_dataloader = Data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True) 187 | 188 | test_dataset = MyDataLoader(img_root=img_root_dir, txt_file=val_txt_path, transforms=data_transforms["val"], isCaseW=isCaseW) 189 | test_dataloader = Data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=True) 190 | 191 | 192 | data_loader = {"train": train_dataloader, "val": test_dataloader} 193 | 194 | 195 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 196 | net = MyNet(num_calsses=num_calsses) 197 | net = net.to(device) 198 | print(net) 199 | 200 | params_to_update = net.parameters() 201 | # 0.005 202 | optimizer_ft = optim.SGD(params_to_update, lr=0.1, momentum=0.9, weight_decay=0.0005) 203 | 204 | # Decay LR by a factor of 0.1 every 40 epochs 205 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=80, gamma=0.1) 206 | 207 | 208 | 209 | 210 | 211 | # Setup the loss fxn 212 | criterion = nn.CrossEntropyLoss() 213 | 214 | num_epochs = 80 215 | 216 | # Train and evaluate 217 | model_ft, hist = train_model(net, data_loader, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=num_epochs) 218 | 219 | 220 | -------------------------------------------------------------------------------- /train_cnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import numpy as np 7 | import torchvision 8 | from torchvision import datasets, models, transforms 9 | # import matplotlib.pyplot as plt 10 | import time 11 | import os 12 | import copy 13 | import torch.utils.data as Data 14 | from my_data_loader import MyDataLoader 15 | 16 | print("PyTorch Version: ",torch.__version__) 17 | print("Torchvision Version: ",torchvision.__version__) 18 | 19 | 20 | # Models to choose from [resnet, alexnet, vgg, squeezenet, densenet, inception] 21 | # model_name = "squeezenet" 22 | model_name = "resnet" 23 | # model_name = "vgg16" 24 | # model_name = "alexnet" 25 | 26 | num_classes = 10 27 | 28 | num_epochs = 50 29 | 30 | # Flag for feature extracting. When False, we finetune the whole model, 31 | # when True we only update the reshaped layer params 32 | feature_extract = False 33 | 34 | 35 | def train_model(model, dataloaders, criterion, optimizer, num_epochs=25): 36 | since = time.time() 37 | 38 | val_acc_history = [] 39 | 40 | best_model_wts = copy.deepcopy(model.state_dict()) 41 | best_acc = 0.0 42 | 43 | for epoch in range(num_epochs): 44 | print('Epoch {}/{}'.format(epoch+1, num_epochs)) 45 | print('-' * 10) 46 | 47 | # Each epoch has a training and validation phase 48 | for phase in ['train', 'val']: 49 | if phase == 'train': 50 | model.train() # Set model to training mode 51 | else: 52 | model.eval() # Set model to evaluate mode 53 | 54 | running_loss = 0.0 55 | running_corrects = 0 56 | 57 | # Iterate over data. 58 | for inputs, labels in dataloaders[phase]: 59 | inputs = inputs.to(device) 60 | labels = labels.to(device) 61 | # print(inputs.shape) 62 | # print(labels.shape) 63 | 64 | # zero the parameter gradients 65 | optimizer.zero_grad() 66 | 67 | # forward 68 | # track history if only in train 69 | with torch.set_grad_enabled(phase == 'train'): 70 | # Get model outputs and calculate loss 71 | # Special case for inception because in training it has an auxiliary output. In train 72 | # mode we calculate the loss by summing the final output and the auxiliary output 73 | # but in testing we only consider the final output. 74 | outputs = model(inputs) 75 | loss = criterion(outputs, labels) 76 | 77 | _, preds = torch.max(outputs, 1) 78 | 79 | # backward + optimize only if in training phase 80 | if phase == 'train': 81 | loss.backward() 82 | optimizer.step() 83 | 84 | # statistics 85 | running_loss += loss.item() * inputs.size(0) 86 | running_corrects += torch.sum(preds == labels.data) 87 | 88 | epoch_loss = running_loss / len(dataloaders[phase].dataset) 89 | epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset) 90 | 91 | print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc)) 92 | 93 | # deep copy the model 94 | if phase == 'val' and epoch_acc > best_acc: 95 | best_acc = epoch_acc 96 | best_model_wts = copy.deepcopy(model.state_dict()) 97 | if phase == 'val': 98 | val_acc_history.append(epoch_acc) 99 | 100 | print("save model...") 101 | model_save_path = "/workspace/mnt/group/face/zhubin/alg_code/fault_diagnosis_cnn_pytorch/model/"+ model_name + "_" + str(epoch+1) + ".pth" 102 | torch.save(model, model_save_path) 103 | print() 104 | 105 | time_elapsed = time.time() - since 106 | print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 107 | print('Best val Acc: {:4f}'.format(best_acc)) 108 | 109 | # load best model weights 110 | model.load_state_dict(best_model_wts) 111 | return model, val_acc_history 112 | 113 | 114 | def set_parameter_requires_grad(model, feature_extracting): 115 | if feature_extracting: 116 | for param in model.parameters(): 117 | param.requires_grad = False 118 | 119 | 120 | def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True): 121 | # Initialize these variables which will be set in this if statement. Each of these 122 | # variables is model specific. 123 | model_ft = None 124 | input_size = 0 125 | 126 | if model_name == "resnet": 127 | """ Resnet34 128 | """ 129 | model_ft = models.resnet34(pretrained=use_pretrained) 130 | set_parameter_requires_grad(model_ft, feature_extract) 131 | num_ftrs = model_ft.fc.in_features 132 | model_ft.fc = nn.Linear(num_ftrs, num_classes) 133 | input_size = 224 134 | 135 | elif model_name == "alexnet": 136 | """ Alexnet 137 | """ 138 | model_ft = models.alexnet(pretrained=use_pretrained) 139 | set_parameter_requires_grad(model_ft, feature_extract) 140 | num_ftrs = model_ft.classifier[6].in_features 141 | model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes) 142 | input_size = 224 143 | 144 | elif model_name == "vgg": 145 | """ VGG11_bn 146 | """ 147 | model_ft = models.vgg11_bn(pretrained=use_pretrained) 148 | set_parameter_requires_grad(model_ft, feature_extract) 149 | num_ftrs = model_ft.classifier[6].in_features 150 | model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes) 151 | input_size = 224 152 | elif model_name == "vgg16": 153 | """ VGG16 154 | """ 155 | model_ft = models.vgg16(pretrained=use_pretrained) 156 | set_parameter_requires_grad(model_ft, feature_extract) 157 | num_ftrs = model_ft.classifier[6].in_features 158 | model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes) 159 | input_size = 224 160 | elif model_name == "squeezenet": 161 | """ Squeezenet 162 | """ 163 | model_ft = models.squeezenet1_0(pretrained=use_pretrained) 164 | set_parameter_requires_grad(model_ft, feature_extract) 165 | model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1)) 166 | model_ft.num_classes = num_classes 167 | input_size = 224 168 | 169 | else: 170 | print("Invalid model name, exiting...") 171 | exit() 172 | 173 | return model_ft, input_size 174 | 175 | 176 | 177 | 178 | ###################################################################### 179 | # Load Data 180 | # --------- 181 | # 182 | # Now that we know what the input size must be, we can initialize the data 183 | # transforms, image datasets, and the dataloaders. Notice, the models were 184 | # pretrained with the hard-coded normalization values, as described 185 | # `here `__. 186 | # 187 | 188 | data_transforms = { 189 | 'train': transforms.Compose([ 190 | transforms.Resize((224, 224)), 191 | # transforms.RandomResizedCrop(input_size), 192 | # transforms.RandomHorizontalFlip(), 193 | transforms.ToTensor(), 194 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 195 | ]), 196 | 'val': transforms.Compose([ 197 | transforms.Resize((224, 224)), 198 | # transforms.CenterCrop(input_size), 199 | transforms.ToTensor(), 200 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 201 | ]), 202 | } 203 | 204 | print("Initializing Datasets and Dataloaders...") 205 | 206 | # # Create training and validation datasets 207 | # image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']} 208 | # # Create training and validation dataloaders 209 | # dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4) for x in ['train', 'val']} 210 | 211 | # img_root_dir = "" 212 | # train_txt_path = "/workspace/mnt/group/face/zhubin/alg_code/fault_diagnosis_cnn_pytorch/jiangnan_train_data_file/train.txt" 213 | # val_txt_path = "/workspace/mnt/group/face/zhubin/alg_code/fault_diagnosis_cnn_pytorch/jiangnan_train_data_file/val.txt" 214 | 215 | # train_batch_size = 72 216 | # test_batch_size = 10 217 | # train_dataset = MyDataLoader(img_root=img_root_dir, txt_file=train_txt_path, transforms=data_transforms["train"]) 218 | # train_dataloader = Data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True) 219 | 220 | # test_dataset = MyDataLoader(img_root=img_root_dir, txt_file=val_txt_path, transforms=data_transforms["val"]) 221 | # test_dataloader = Data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=True) 222 | 223 | 224 | # num_classes = 0 225 | 226 | isCaseW = True 227 | # isCaseW = False 228 | if isCaseW: 229 | img_root_dir = "" 230 | train_txt_path = "/workspace/mnt/group/face/zhubin/alg_code/fault_diagnosis_cnn_pytorch/CaseW_train_data_file_9/train.txt" 231 | val_txt_path = "/workspace/mnt/group/face/zhubin/alg_code/fault_diagnosis_cnn_pytorch/CaseW_train_data_file_9/val.txt" 232 | train_batch_size = 75 233 | test_batch_size = 10 234 | # num_calsses = 10 235 | # # print(num_calsses) 236 | # lr = 0.1 weight_decay=0.0005 237 | else: 238 | img_root_dir = "" 239 | train_txt_path = "/workspace/mnt/group/face/zhubin/alg_code/fault_diagnosis_cnn_pytorch/jiangnan_train_data_file_9/train.txt" 240 | val_txt_path = "/workspace/mnt/group/face/zhubin/alg_code/fault_diagnosis_cnn_pytorch/jiangnan_train_data_file_9/val.txt" 241 | train_batch_size = 72 242 | test_batch_size = 10 243 | # num_calsses = 4 244 | # print(num_calsses) 245 | # lr = 0.1 weight_decay=0.0005 246 | 247 | train_dataset = MyDataLoader(img_root=img_root_dir, txt_file=train_txt_path, transforms=data_transforms["train"], isCaseW=isCaseW) 248 | train_dataloader = Data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True) 249 | 250 | test_dataset = MyDataLoader(img_root=img_root_dir, txt_file=val_txt_path, transforms=data_transforms["val"], isCaseW=isCaseW) 251 | test_dataloader = Data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=True) 252 | 253 | 254 | data_loader = {"train": train_dataloader, "val": test_dataloader} 255 | 256 | 257 | 258 | # Initialize the model for this run 259 | model_ft, input_size = initialize_model(model_name, num_classes=num_classes, feature_extract=False, use_pretrained=True) 260 | 261 | # Print the model we just instantiated 262 | print(model_ft) 263 | 264 | 265 | # Detect if we have a GPU available 266 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 267 | 268 | # Send the model to GPU 269 | model_ft = model_ft.to(device) 270 | 271 | # Gather the parameters to be optimized/updated in this run. If we are 272 | # finetuning we will be updating all parameters. However, if we are 273 | # doing feature extract method, we will only update the parameters 274 | # that we have just initialized, i.e. the parameters with requires_grad 275 | # is True. 276 | params_to_update = model_ft.parameters() 277 | print("Params to learn:") 278 | if feature_extract: 279 | params_to_update = [] 280 | for name,param in model_ft.named_parameters(): 281 | if param.requires_grad == True: 282 | params_to_update.append(param) 283 | print("\t",name) 284 | else: 285 | for name,param in model_ft.named_parameters(): 286 | if param.requires_grad == True: 287 | print("\t",name) 288 | 289 | # Observe that all parameters are being optimized 290 | optimizer_ft = optim.SGD(params_to_update, lr=0.01, momentum=0.9) 291 | 292 | 293 | # Setup the loss fxn 294 | criterion = nn.CrossEntropyLoss() 295 | 296 | # Train and evaluate 297 | model_ft, hist = train_model(model_ft, data_loader, criterion, optimizer_ft, num_epochs=num_epochs) 298 | 299 | 300 | --------------------------------------------------------------------------------