├── __init__.py ├── supcon ├── __init__.py ├── util.py └── losses.py ├── networks ├── __init__.py ├── classifier.py ├── transfer.py ├── simple_cnn.py ├── inception.py └── resnet_big.py ├── figures ├── tsne_ce.png ├── hyperparams.png ├── tsne_supcon.png ├── transfer_Chev.png ├── transfer_KIA.png ├── confusion_matrix_ce.png └── confusion_matrix_supcon.png ├── environment.yml ├── reports ├── KIA_CE ├── KIA_random ├── Spark_CE ├── KIA_supcon4096 ├── Spark_random ├── small_KIA_ce ├── small_Spark_ce ├── Spark_supcon512_100 ├── small_KIA_random ├── small_Spark_random ├── small_KIA_supcon_512_200 ├── small_Spark_random_20epochs ├── small_Spark_supcon_512_epoch200 ├── KIA_supcon1024 ├── KIA_supcon512 ├── Spark_supcon1024 ├── Spark_supcon4096 ├── Spark_supcon512_200 ├── small_KIA_supcon_1024_200 ├── small_Spark_supcon_1024_epoch200 ├── KIA_incep ├── Spark_incep ├── Spark_incep_20 ├── supcon2048.json ├── supcon4096.json ├── supcon512.json ├── supcon1024.json ├── small_supcon2048.json ├── small_supcon512.json ├── small_supcon1024.json ├── small_supcon4096.json ├── small_resnet.json ├── incep.json └── resnet.json ├── dataset.py ├── .gitignore ├── README.md ├── utils.py ├── test_model.py ├── notebooks ├── RecCNN.ipynb ├── Test_multiple.ipynb ├── Performance_test.ipynb └── histogram_based.ipynb ├── train_test_split.py ├── preprocessing.py ├── preprocessing_survival.py ├── train_baseline.py ├── transfer.py └── train_supcon.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /supcon/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figures/tsne_ce.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htn274/CAN-SupCon-IDS/HEAD/figures/tsne_ce.png -------------------------------------------------------------------------------- /figures/hyperparams.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htn274/CAN-SupCon-IDS/HEAD/figures/hyperparams.png -------------------------------------------------------------------------------- /figures/tsne_supcon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htn274/CAN-SupCon-IDS/HEAD/figures/tsne_supcon.png -------------------------------------------------------------------------------- /figures/transfer_Chev.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htn274/CAN-SupCon-IDS/HEAD/figures/transfer_Chev.png -------------------------------------------------------------------------------- /figures/transfer_KIA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htn274/CAN-SupCon-IDS/HEAD/figures/transfer_KIA.png -------------------------------------------------------------------------------- /figures/confusion_matrix_ce.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/htn274/CAN-SupCon-IDS/HEAD/figures/confusion_matrix_ce.png -------------------------------------------------------------------------------- /figures/confusion_matrix_supcon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htn274/CAN-SupCon-IDS/HEAD/figures/confusion_matrix_supcon.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: nu 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - anaconda 6 | - defaults 7 | dependencies: 8 | - python=3.7 9 | - jupyter 10 | - pandas 11 | - numpy 12 | - matplotlib 13 | - scikit-learn 14 | - tqdm 15 | - pytorch 16 | - torchvision 17 | - torchaudio -------------------------------------------------------------------------------- /reports/KIA_CE: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Finish : 2 3 | Finish : 1 4 | Finish : 5 5 | Finish : 3 6 | Finish : 4 7 | Submit : 2 8 | Submit : 3 9 | Submit : 4 10 | Submit : 5 11 | FINAL RESULTS 12 | Train: 13 | fnr [0.0404, 0.0552, 0.4808, 0.1219] 14 | rec [0.9996, 0.9994, 0.9952, 0.9988] 15 | pre [0.9980, 0.9992, 0.9993, 1.0000] 16 | f1 [0.9988, 0.9993, 0.9972, 0.9994] 17 | Validation: 18 | fnr [0.1221, 0.0515, 0.8807, 1.0095] 19 | rec [0.9988, 0.9995, 0.9912, 0.9899] 20 | pre [0.9954, 0.9987, 0.9973, 0.9981] 21 | f1 [0.9971, 0.9991, 0.9942, 0.9940] 22 | -------------------------------------------------------------------------------- /reports/KIA_random: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Finish : 5 3 | Finish : 4 4 | Finish : 1 5 | Finish : 2 6 | Finish : 3 7 | Submit : 2 8 | Submit : 3 9 | Submit : 4 10 | Submit : 5 11 | FINAL RESULTS 12 | Train: 13 | fnr [0.0381, 0.0883, 0.6799, 0.2302] 14 | rec [0.9996, 0.9991, 0.9932, 0.9977] 15 | pre [0.9981, 1.0000, 1.0000, 0.9922] 16 | f1 [0.9988, 0.9996, 0.9966, 0.9949] 17 | Validation: 18 | fnr [0.1221, 0.0772, 1.7934, 0.7886] 19 | rec [0.9988, 0.9992, 0.9821, 0.9921] 20 | pre [0.9942, 1.0000, 0.9992, 0.9838] 21 | f1 [0.9965, 0.9996, 0.9905, 0.9879] 22 | -------------------------------------------------------------------------------- /reports/Spark_CE: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Finish : 4 3 | Finish : 1 4 | Finish : 3 5 | Finish : 2 6 | Finish : 5 7 | Submit : 2 8 | Submit : 3 9 | Submit : 4 10 | Submit : 5 11 | FINAL RESULTS 12 | Train: 13 | fnr [0.2332, 0.0437, 4.9091, 0.0536] 14 | rec [0.9977, 0.9996, 0.9509, 0.9995] 15 | pre [0.9948, 0.9997, 0.9794, 0.9968] 16 | f1 [0.9962, 0.9996, 0.9649, 0.9981] 17 | Validation: 18 | fnr [0.3742, 0.0000, 9.7354, 0.5000] 19 | rec [0.9963, 1.0000, 0.9026, 0.9950] 20 | pre [0.9899, 0.9973, 0.9674, 0.9938] 21 | f1 [0.9931, 0.9986, 0.9337, 0.9944] 22 | -------------------------------------------------------------------------------- /reports/KIA_supcon4096: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Finish : 2 3 | Finish : 1 4 | Finish : 4 5 | Finish : 3 6 | Finish : 5 7 | Submit : 2 8 | Submit : 3 9 | Submit : 4 10 | Submit : 5 11 | FINAL RESULTS 12 | Train: 13 | fnr [0.0000, 0.0000, 0.0549, 0.0271] 14 | rec [1.0000, 1.0000, 0.9995, 0.9997] 15 | pre [0.9998, 1.0000, 1.0000, 1.0000] 16 | f1 [0.9999, 1.0000, 0.9997, 0.9999] 17 | Validation: 18 | 
fnr [0.0000, 0.0000, 0.1922, 0.1577] 19 | rec [1.0000, 1.0000, 0.9981, 0.9984] 20 | pre [0.9991, 1.0000, 1.0000, 1.0000] 21 | f1 [0.9995, 1.0000, 0.9990, 0.9992] 22 | -------------------------------------------------------------------------------- /reports/Spark_random: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Finish : 3 3 | Finish : 1 4 | Finish : 2 5 | Finish : 5 6 | Finish : 4 7 | Submit : 2 8 | Submit : 3 9 | Submit : 4 10 | Submit : 5 11 | FINAL RESULTS 12 | Train: 13 | fnr [0.1865, 0.0875, 13.1818, 1.3673] 14 | rec [0.9981, 0.9991, 0.8682, 0.9863] 15 | pre [0.9878, 1.0000, 0.9555, 0.9847] 16 | f1 [0.9930, 0.9996, 0.9097, 0.9855] 17 | Validation: 18 | fnr [0.2535, 0.0680, 20.5291, 1.1250] 19 | rec [0.9975, 0.9993, 0.7947, 0.9887] 20 | pre [0.9803, 1.0000, 0.9524, 0.9808] 21 | f1 [0.9888, 0.9997, 0.8663, 0.9847] 22 | -------------------------------------------------------------------------------- /reports/small_KIA_ce: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Finish : 5 3 | Finish : 4 4 | Finish : 1 5 | Finish : 3 6 | Finish : 2 7 | Submit : 2 8 | Submit : 3 9 | Submit : 4 10 | Submit : 5 11 | FINAL RESULTS 12 | Train: 13 | fnr [0.0000, 0.0533, 0.0284, 0.0140] 14 | rec [1.0000, 0.9995, 0.9997, 0.9999] 15 | pre [0.9998, 1.0000, 1.0000, 1.0000] 16 | f1 [0.9999, 0.9997, 0.9999, 0.9999] 17 | Validation: 18 | fnr [0.0364, 0.0533, 0.2044, 0.2070] 19 | rec [0.9996, 0.9995, 0.9980, 0.9979] 20 | pre [0.9990, 1.0000, 0.9992, 0.9984] 21 | f1 [0.9993, 0.9997, 0.9986, 0.9981] 22 | -------------------------------------------------------------------------------- /reports/small_Spark_ce: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Finish : 5 3 | Finish : 2 4 | Finish : 1 5 | Finish : 3 6 | Finish : 4 7 | Submit : 2 8 | Submit : 3 9 | Submit : 4 10 | Submit : 5 11 | FINAL RESULTS 12 | Train: 13 | fnr [0.0697, 0.0503, 0.0626, 0.0831] 14 | rec [0.9993, 0.9995, 0.9994, 0.9992] 15 | pre [0.9996, 1.0000, 0.9939, 1.0000] 16 | f1 [0.9994, 0.9997, 0.9966, 0.9996] 17 | Validation: 18 | fnr [0.1084, 0.0469, 1.4572, 0.1292] 19 | rec [0.9989, 0.9995, 0.9854, 0.9987] 20 | pre [0.9982, 1.0000, 0.9901, 0.9989] 21 | f1 [0.9985, 0.9998, 0.9878, 0.9988] 22 | -------------------------------------------------------------------------------- /reports/Spark_supcon512_100: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Finish : 4 3 | Finish : 3 4 | Finish : 1 5 | Finish : 5 6 | Finish : 2 7 | Submit : 2 8 | Submit : 3 9 | Submit : 4 10 | Submit : 5 11 | FINAL RESULTS 12 | Train: 13 | fnr [0.1969, 0.0583, 3.5909, 0.0000] 14 | rec [0.9980, 0.9994, 0.9641, 1.0000] 15 | pre [0.9964, 1.0000, 0.9824, 0.9965] 16 | f1 [0.9972, 0.9997, 0.9732, 0.9983] 17 | Validation: 18 | fnr [0.1569, 0.0340, 4.7619, 0.0000] 19 | rec [0.9984, 0.9997, 0.9524, 1.0000] 20 | pre [0.9951, 1.0000, 0.9858, 0.9969] 21 | f1 [0.9967, 0.9998, 0.9688, 0.9984] 22 | -------------------------------------------------------------------------------- /reports/small_KIA_random: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Finish : 4 3 | Finish : 1 4 | Finish : 5 5 | Finish : 2 6 | Finish : 3 7 | Submit : 2 8 | Submit : 3 9 | Submit : 4 10 | Submit : 5 11 | FINAL RESULTS 12 | Train: 13 | fnr [0.0008, 0.1028, 0.2084, 0.0280] 14 | rec [1.0000, 0.9990, 0.9979, 0.9997] 15 | pre [0.9991, 1.0000, 
1.0000, 0.9992] 16 | f1 [0.9996, 0.9995, 0.9990, 0.9995] 17 | Validation: 18 | fnr [0.0364, 0.1243, 0.4862, 0.1961] 19 | rec [0.9996, 0.9988, 0.9951, 0.9980] 20 | pre [0.9980, 1.0000, 0.9994, 0.9973] 21 | f1 [0.9988, 0.9994, 0.9973, 0.9977] 22 | -------------------------------------------------------------------------------- /reports/small_Spark_random: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Finish : 2 3 | Finish : 4 4 | Finish : 3 5 | Finish : 5 6 | Finish : 1 7 | Submit : 2 8 | Submit : 3 9 | Submit : 4 10 | Submit : 5 11 | FINAL RESULTS 12 | Train: 13 | fnr [0.3486, 0.1307, 10.8764, 1.0429] 14 | rec [0.9965, 0.9987, 0.8912, 0.9896] 15 | pre [0.9878, 0.9997, 0.9542, 0.9943] 16 | f1 [0.9921, 0.9992, 0.9215, 0.9919] 17 | Validation: 18 | fnr [0.3711, 0.1876, 10.7832, 0.9688] 19 | rec [0.9963, 0.9981, 0.8922, 0.9903] 20 | pre [0.9880, 0.9998, 0.9516, 0.9940] 21 | f1 [0.9921, 0.9989, 0.9207, 0.9921] 22 | -------------------------------------------------------------------------------- /networks/classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class LinearClassifier(nn.Module): 6 | def __init__(self, n_classes, feat_dim, init=False): 7 | super().__init__() 8 | self.n_classes = n_classes 9 | self.fc = nn.Linear(feat_dim, n_classes) 10 | if init: 11 | torch.nn.init.xavier_normal_(self.fc.weight) 12 | torch.nn.init.constant_(self.fc.bias, 0) 13 | 14 | def forward(self, x): 15 | output = self.fc(x) 16 | return output -------------------------------------------------------------------------------- /reports/small_KIA_supcon_512_200: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Finish : 1 3 | Finish : 4 4 | Finish : 3 5 | Finish : 2 6 | Finish : 5 7 | Submit : 2 8 | Submit : 3 9 | Submit : 4 10 | Submit : 5 11 | FINAL RESULTS 12 | Train: 13 | fnr [0.0008, 0.0114, 0.0237, 0.0140] 14 | rec [1.0000, 0.9999, 0.9998, 0.9999] 15 | pre [0.9999, 1.0000, 1.0000, 1.0000] 16 | f1 [0.9999, 0.9999, 0.9999, 0.9999] 17 | Validation: 18 | fnr [0.0000, 0.0089, 0.0552, 0.0545] 19 | rec [1.0000, 0.9999, 0.9994, 0.9995] 20 | pre [0.9997, 1.0000, 1.0000, 0.9999] 21 | f1 [0.9999, 1.0000, 0.9997, 0.9997] 22 | -------------------------------------------------------------------------------- /reports/small_Spark_random_20epochs: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Finish : 4 3 | Finish : 3 4 | Finish : 1 5 | Finish : 5 6 | Finish : 2 7 | Submit : 2 8 | Submit : 3 9 | Submit : 4 10 | Submit : 5 11 | FINAL RESULTS 12 | Train: 13 | fnr [0.1984, 0.1256, 1.8153, 0.2677] 14 | rec [0.9980, 0.9987, 0.9818, 0.9973] 15 | pre [0.9973, 1.0000, 0.9797, 0.9999] 16 | f1 [0.9977, 0.9994, 0.9808, 0.9986] 17 | Validation: 18 | fnr [0.2543, 0.1758, 2.9872, 0.2799] 19 | rec [0.9975, 0.9982, 0.9701, 0.9972] 20 | pre [0.9960, 1.0000, 0.9737, 0.9994] 21 | f1 [0.9967, 0.9991, 0.9719, 0.9983] 22 | -------------------------------------------------------------------------------- /networks/transfer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import copy 5 | 6 | class TransferModel(nn.Module): 7 | def __init__(self, feat_extractor, classifier): 8 | super().__init__() 9 | self.encoder = copy.deepcopy(feat_extractor) 10 | 
self.classifier = copy.deepcopy(classifier) 11 | 12 | def forward(self, x, return_feat=False): 13 | feat = self.encoder(x) 14 | output = self.classifier(feat) 15 | if return_feat: 16 | return feat, output 17 | return output -------------------------------------------------------------------------------- /reports/small_Spark_supcon_512_epoch200: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Finish : 5 3 | Finish : 2 4 | Finish : 1 5 | Finish : 4 6 | Finish : 3 7 | Submit : 2 8 | Submit : 3 9 | Submit : 4 10 | Submit : 5 11 | FINAL RESULTS 12 | Train: 13 | fnr [0.1215, 0.0201, 0.0000, 0.0554] 14 | rec [0.9988, 0.9998, 1.0000, 0.9994] 15 | pre [0.9999, 1.0000, 0.9892, 1.0000] 16 | f1 [0.9993, 0.9999, 0.9946, 0.9997] 17 | Validation: 18 | fnr [0.1126, 0.0352, 0.4736, 0.0431] 19 | rec [0.9989, 0.9996, 0.9953, 0.9996] 20 | pre [0.9992, 1.0000, 0.9902, 1.0000] 21 | f1 [0.9991, 0.9998, 0.9927, 0.9998] 22 | -------------------------------------------------------------------------------- /reports/KIA_supcon1024: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Finish : 2 3 | Finish : 3 4 | Finish : 4 5 | Finish : 5 6 | Finish : 1 7 | Submit : 2 8 | Submit : 3 9 | Submit : 4 10 | Submit : 5 11 | FINAL RESULTS 12 | Train: 13 | fnr ['0.0000', '0.0442', '0.1786', '0.0406'] 14 | rec ['1.0000', '0.9996', '0.9982', '0.9996'] 15 | pre ['0.9992', '1.0000', '1.0000', '1.0000'] 16 | f1 ['0.9996', '0.9998', '0.9991', '0.9998'] 17 | Validation: 18 | fnr ['0.0000', '0.0515', '0.3363', '0.3155'] 19 | rec ['1.0000', '0.9995', '0.9966', '0.9968'] 20 | pre ['0.9982', '1.0000', '1.0000', '1.0000'] 21 | f1 ['0.9991', '0.9997', '0.9983', '0.9984'] 22 | -------------------------------------------------------------------------------- /reports/KIA_supcon512: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Finish : 5 3 | Finish : 4 4 | Finish : 2 5 | Finish : 3 6 | Finish : 1 7 | Submit : 2 8 | Submit : 3 9 | Submit : 4 10 | Submit : 5 11 | FINAL RESULTS 12 | Train: 13 | fnr ['0.0167', '0.0442', '0.2610', '0.0812'] 14 | rec ['0.9998', '0.9996', '0.9974', '0.9992'] 15 | pre ['0.9990', '1.0000', '0.9999', '0.9985'] 16 | f1 ['0.9994', '0.9998', '0.9987', '0.9988'] 17 | Validation: 18 | fnr ['0.0333', '0.0257', '0.4163', '0.2208'] 19 | rec ['0.9997', '0.9997', '0.9958', '0.9978'] 20 | pre ['0.9982', '1.0000', '0.9994', '0.9987'] 21 | f1 ['0.9989', '0.9999', '0.9976', '0.9983'] 22 | -------------------------------------------------------------------------------- /reports/Spark_supcon1024: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Finish : 3 3 | Finish : 4 4 | Finish : 5 5 | Finish : 2 6 | Finish : 1 7 | Submit : 2 8 | Submit : 3 9 | Submit : 4 10 | Submit : 5 11 | FINAL RESULTS 12 | Train: 13 | fnr ['0.2383', '0.0583', '3.4091', '0.0268'] 14 | rec ['0.9976', '0.9994', '0.9659', '0.9997'] 15 | pre ['0.9959', '1.0000', '0.9788', '1.0000'] 16 | f1 ['0.9967', '0.9997', '0.9723', '0.9999'] 17 | Validation: 18 | fnr ['0.2293', '0.0340', '4.4444', '0.1875'] 19 | rec ['0.9977', '0.9997', '0.9556', '0.9981'] 20 | pre ['0.9945', '1.0000', '0.9795', '1.0000'] 21 | f1 ['0.9961', '0.9998', '0.9673', '0.9991'] 22 | -------------------------------------------------------------------------------- /reports/Spark_supcon4096: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Finish : 5 3 
| Finish : 2 4 | Finish : 3 5 | Finish : 1 6 | Finish : 4 7 | Submit : 2 8 | Submit : 3 9 | Submit : 4 10 | Submit : 5 11 | FINAL RESULTS 12 | Train: 13 | fnr ['0.1503', '0.0583', '5.9091', '0.5898'] 14 | rec ['0.9985', '0.9994', '0.9409', '0.9941'] 15 | pre ['0.9920', '1.0000', '0.9862', '1.0000'] 16 | f1 ['0.9952', '0.9997', '0.9629', '0.9970'] 17 | Validation: 18 | fnr ['0.0966', '0.0340', '8.4656', '0.5625'] 19 | rec ['0.9990', '0.9997', '0.9153', '0.9944'] 20 | pre ['0.9893', '1.0000', '0.9909', '1.0000'] 21 | f1 ['0.9941', '0.9998', '0.9515', '0.9972'] 22 | -------------------------------------------------------------------------------- /reports/Spark_supcon512_200: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Finish : 1 3 | Finish : 5 4 | Finish : 2 5 | Finish : 4 6 | Finish : 3 7 | Submit : 2 8 | Submit : 3 9 | Submit : 4 10 | Submit : 5 11 | FINAL RESULTS 12 | Train: 13 | fnr ['0.1606', '0.0875', '9.0909', '0.0000'] 14 | rec ['0.9984', '0.9991', '0.9091', '1.0000'] 15 | pre ['0.9917', '1.0000', '0.9847', '0.9881'] 16 | f1 ['0.9950', '0.9996', '0.9452', '0.9940'] 17 | Validation: 18 | fnr ['0.0966', '0.0680', '10.3704', '0.0000'] 19 | rec ['0.9990', '0.9993', '0.8963', '1.0000'] 20 | pre ['0.9905', '1.0000', '0.9908', '0.9871'] 21 | f1 ['0.9948', '0.9997', '0.9410', '0.9935'] 22 | -------------------------------------------------------------------------------- /reports/small_KIA_supcon_1024_200: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Finish : 3 3 | Finish : 4 4 | Finish : 1 5 | Finish : 5 6 | Finish : 2 7 | Submit : 2 8 | Submit : 3 9 | Submit : 4 10 | Submit : 5 11 | FINAL RESULTS 12 | Train: 13 | fnr ['0.0000', '0.0228', '0.0047', '0.0047'] 14 | rec ['1.0000', '0.9998', '1.0000', '1.0000'] 15 | pre ['0.9999', '1.0000', '1.0000', '1.0000'] 16 | f1 ['1.0000', '0.9999', '1.0000', '1.0000'] 17 | Validation: 18 | fnr ['0.0000', '0.0178', '0.0552', '0.0327'] 19 | rec ['1.0000', '0.9998', '0.9994', '0.9997'] 20 | pre ['0.9998', '1.0000', '1.0000', '0.9998'] 21 | f1 ['0.9999', '0.9999', '0.9997', '0.9997'] 22 | -------------------------------------------------------------------------------- /reports/small_Spark_supcon_1024_epoch200: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Finish : 3 3 | Finish : 2 4 | Finish : 1 5 | Finish : 5 6 | Finish : 4 7 | Submit : 2 8 | Submit : 3 9 | Submit : 4 10 | Submit : 5 11 | FINAL RESULTS 12 | Train: 13 | fnr ['0.0411', '0.0151', '0.0000', '0.0461'] 14 | rec ['0.9996', '0.9998', '1.0000', '0.9995'] 15 | pre ['0.9999', '1.0000', '0.9963', '1.0000'] 16 | f1 ['0.9997', '0.9999', '0.9981', '0.9998'] 17 | Validation: 18 | fnr ['0.0750', '0.0352', '0.4736', '0.0646'] 19 | rec ['0.9992', '0.9996', '0.9953', '0.9994'] 20 | pre ['0.9992', '1.0000', '0.9935', '1.0000'] 21 | f1 ['0.9992', '0.9998', '0.9944', '0.9997'] 22 | -------------------------------------------------------------------------------- /reports/KIA_incep: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Loading pretrained model 3 | Loading pretrained model 4 | Loading pretrained model 5 | Loading pretrained model 6 | Loading pretrained model 7 | Finish : 4 8 | Finish : 3 9 | Finish : 5 10 | Finish : 1 11 | Finish : 2 12 | Submit : 2 13 | Submit : 3 14 | Submit : 4 15 | Submit : 5 16 | FINAL RESULTS 17 | Train: 18 | fnr ['0.0705', '0.1142', '1.5396', '0.5184'] 19 | rec 
['0.9993', '0.9989', '0.9846', '0.9948'] 20 | pre ['0.9940', '0.9999', '0.9988', '0.9963'] 21 | f1 ['0.9966', '0.9994', '0.9916', '0.9955'] 22 | Validation: 23 | fnr ['0.1033', '0.1421', '1.6243', '0.5338'] 24 | rec ['0.9990', '0.9986', '0.9838', '0.9947'] 25 | pre ['0.9937', '0.9998', '0.9981', '0.9952'] 26 | f1 ['0.9963', '0.9992', '0.9909', '0.9949'] 27 | -------------------------------------------------------------------------------- /reports/Spark_incep: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Loading pretrained model 3 | Loading pretrained model 4 | Loading pretrained model 5 | Loading pretrained model 6 | Loading pretrained model 7 | Finish : 4 8 | Finish : 2 9 | Finish : 3 10 | Finish : 1 11 | Finish : 5 12 | Submit : 2 13 | Submit : 3 14 | Submit : 4 15 | Submit : 5 16 | FINAL RESULTS 17 | Train: 18 | fnr ['0.7579', '0.1407', '14.6479', '1.6336'] 19 | rec ['0.9924', '0.9986', '0.8535', '0.9837'] 20 | pre ['0.9834', '0.9996', '0.9239', '0.9847'] 21 | f1 ['0.9879', '0.9991', '0.8867', '0.9841'] 22 | Validation: 23 | fnr ['0.7338', '0.1641', '15.6284', '1.5285'] 24 | rec ['0.9927', '0.9984', '0.8437', '0.9847'] 25 | pre ['0.9825', '0.9999', '0.9288', '0.9819'] 26 | f1 ['0.9876', '0.9991', '0.8835', '0.9832'] 27 | -------------------------------------------------------------------------------- /reports/Spark_incep_20: -------------------------------------------------------------------------------- 1 | Submit : 1 2 | Loading pretrained model 3 | Loading pretrained model 4 | Loading pretrained model 5 | Loading pretrained model 6 | Loading pretrained model 7 | Finish : 2 8 | Finish : 1 9 | Finish : 3 10 | Finish : 4 11 | Finish : 5 12 | Submit : 2 13 | Submit : 3 14 | Submit : 4 15 | Submit : 5 16 | FINAL RESULTS 17 | Train: 18 | fnr ['0.5470', '0.1357', '6.3224', '0.8583'] 19 | rec ['0.9945', '0.9986', '0.9368', '0.9914'] 20 | pre ['0.9919', '0.9998', '0.9482', '0.9959'] 21 | f1 ['0.9932', '0.9992', '0.9422', '0.9937'] 22 | Validation: 23 | fnr ['0.5462', '0.1876', '7.3588', '0.8181'] 24 | rec ['0.9945', '0.9981', '0.9264', '0.9918'] 25 | pre ['0.9907', '0.9998', '0.9480', '0.9957'] 26 | f1 ['0.9926', '0.9989', '0.9368', '0.9937'] 27 | -------------------------------------------------------------------------------- /reports/supcon2048.json: -------------------------------------------------------------------------------- 1 | {"fnr": [[0.0, 0.03222687721559781, 0.1433997364545384, 0.16903938089326748, 0.22225443300961492], [0.0, 0.02301919801114129, 0.11627005658476088, 0.2509178310134439, 0.17635406097502054], [0.0, 0.05984991482896736, 0.10464305092628479, 0.242994110034072, 0.16185920664830653], [0.0, 0.02301919801114129, 0.11239438803193551, 0.17960434219909668, 0.1352853070493308], [0.0, 0.04143455642005432, 0.10076738237345942, 0.21129922611658436, 0.1739382519205682]], "rec": [[1.0, 0.999677731227844, 0.9985660026354546, 0.9983096061910673, 0.9977774556699038], [1.0, 0.9997698080198886, 0.9988372994341523, 0.9974908216898656, 0.9982364593902497], [1.0, 0.9994015008517103, 0.9989535694907371, 0.9975700588996593, 0.9983814079335169], [1.0, 0.9997698080198886, 0.9988760561196807, 0.9982039565780091, 0.9986471469295067], [1.0, 0.9995856544357995, 0.9989923261762654, 0.9978870077388342, 0.9982606174807943]], "pre": [[0.999023499484896, 1.0, 1.0, 1.0, 1.0], [0.9990088664948051, 1.0, 1.0, 1.0, 1.0], [0.9990283772435209, 1.0, 1.0, 1.0, 1.0], [0.9992284063661359, 1.0, 1.0, 1.0, 1.0], [0.9990869140625, 1.0, 1.0, 1.0, 1.0]], "f1": 
[[0.999511511237684, 0.9998388396454472, 0.9992824868617527, 0.9991540881334425, 0.9988874915352617], [0.9995041875392315, 0.9998848907613326, 0.9994183115523325, 0.9987438348737885, 0.9991174514900562], [0.9995139524943518, 0.9997006608487417, 0.9994765108478586, 0.9987835515007272, 0.9991900484762032], [0.9996140542864401, 0.9998848907613326, 0.9994377120698014, 0.9991011711211569, 0.9993231156021853], [0.9995432485045663, 0.9997927842884443, 0.9994959091085347, 0.9989423865048518, 0.9991295517191354]]} -------------------------------------------------------------------------------- /reports/supcon4096.json: -------------------------------------------------------------------------------- 1 | {"fnr": [[0.0, 0.013811518806684775, 0.05038369118672971, 0.08716093077309105, 0.08213750785137942], [0.0, 0.04143455642005432, 0.07363770250368189, 0.11357333403766409, 0.10629559839590279], [0.0, 0.02301919801114129, 0.06976203395085652, 0.13998573730223712, 0.1135430255592598], [0.0, 0.004603839602228259, 0.08914037671498333, 0.1241382953434933, 0.0942165531236411], [0.0, 0.02301919801114129, 0.07751337105650724, 0.08716093077309105, 0.1135430255592598]], "rec": [[1.0, 0.9998618848119332, 0.9994961630881327, 0.9991283906922691, 0.9991786249214862], [1.0, 0.9995856544357995, 0.9992636229749632, 0.9988642666596234, 0.998937044016041], [1.0, 0.9997698080198886, 0.9993023796604914, 0.9986001426269776, 0.9988645697444074], [1.0, 0.9999539616039778, 0.9991085962328502, 0.9987586170465651, 0.9990578344687636], [1.0, 0.9997698080198886, 0.999224866289435, 0.9991283906922691, 0.9988645697444074]], "pre": [[0.9995945206550201, 1.0, 1.0, 1.0, 1.0], [0.9994382790824899, 1.0, 1.0, 1.0, 1.0], [0.9993992263207252, 1.0, 1.0, 1.0, 1.0], [0.9994626886085101, 1.0, 1.0, 1.0, 1.0], [0.999487099326879, 1.0, 1.0, 1.0, 1.0]], "f1": [[0.9997972192158002, 0.999930937636686, 0.9997480180651664, 0.9995640053376317, 0.9995891437271914], [0.9997190606364864, 0.9997927842884443, 0.9996316758747698, 0.9994318106740311, 0.9994682393889587], [0.9996995229009964, 0.9998848907613326, 0.9996510681192572, 0.9992995810701873, 0.9994319623886586], [0.9997312721089765, 0.9999769802720931, 0.9995540993776778, 0.9993789230260985, 0.9995286952108183], [0.9997434838797942, 0.9998848907613326, 0.999612282878412, 0.9995640053376317, 0.9994319623886586]]} -------------------------------------------------------------------------------- /reports/supcon512.json: -------------------------------------------------------------------------------- 1 | {"fnr": [[0.0, 0.004603839602228259, 0.03488101697542826, 0.06867224848788991, 0.045900372034594385], [0.0004887275002077092, 0.004603839602228259, 0.04263235408107899, 0.05810728718206069, 0.04831618108904672], [0.0004887275002077092, 0.0, 0.03488101697542826, 0.06338976783497531, 0.07247427163357008], [0.001954910000830837, 0.013811518806684775, 0.027129679869777535, 0.05282480652914609, 0.0676426535246654], [0.0, 0.009207679204456518, 0.027129679869777535, 0.05282480652914609, 0.0676426535246654]], "rec": [[1.0, 0.9999539616039778, 0.9996511898302457, 0.9993132775151211, 0.9995409962796541], [0.999995112724998, 0.9999539616039778, 0.9995736764591893, 0.9994189271281794, 0.9995168381891095], [0.999995112724998, 1.0, 0.9996511898302457, 0.9993661023216502, 0.9992752572836643], [0.9999804508999917, 0.9998618848119332, 0.9997287032013022, 0.9994717519347085, 0.9993235734647533], [1.0, 0.9999079232079554, 0.9997287032013022, 0.9994717519347085, 0.9993235734647533]], "pre": [[0.9997312721089765, 1.0, 1.0, 
1.0, 1.0], [0.9997361554923632, 1.0, 1.0, 1.0, 0.99997583081571], [0.9996921949432026, 1.0, 1.0, 0.9999735715418363, 1.0], [0.9997166128393927, 1.0, 1.0, 0.9999207271958567, 0.9999758261416105], [0.9997215029071188, 1.0, 1.0, 1.0, 1.0]], "f1": [[0.9998656179983924, 0.9999769802720931, 0.9998255644926833, 0.9996565208201227, 0.9997704454566324], [0.9998656173417155, 0.9999769802720931, 0.9997867927819665, 0.9997093791281373, 0.9997462818204883], [0.9998436308907176, 1.0, 0.9998255644926833, 0.9996697446465608, 0.9996374972812296], [0.9998485144644254, 0.999930937636686, 0.9998643331976665, 0.9996961891552737, 0.999649593407523], [0.9998607320607014, 0.9999539594843463, 0.9998643331976665, 0.9997358061874191, 0.9996616723054615]]} -------------------------------------------------------------------------------- /reports/supcon1024.json: -------------------------------------------------------------------------------- 1 | {"fnr": [[0.0, 0.018415358408913035, 0.0620106968452058, 0.05546604685560339, 0.05073199014349906], [0.0004887275002077092, 0.009207679204456518, 0.05038369118672971, 0.05810728718206069, 0.05073199014349906], [0.001954910000830837, 0.013811518806684775, 0.027129679869777535, 0.04490108554977417, 0.0603952263613084], [0.0014661825006231275, 0.013811518806684775, 0.027129679869777535, 0.04754232587623148, 0.05797941730685607], [0.0004887275002077092, 0.004603839602228259, 0.023254011316952174, 0.07659596946726183, 0.05797941730685607]], "rec": [[1.0, 0.9998158464159109, 0.999379893031548, 0.999445339531444, 0.999492680098565], [0.999995112724998, 0.9999079232079554, 0.9994961630881327, 0.9994189271281794, 0.999492680098565], [0.9999804508999917, 0.9998618848119332, 0.9997287032013022, 0.9995509891445022, 0.9993960477363869], [0.9999853381749938, 0.9998618848119332, 0.9997287032013022, 0.9995245767412377, 0.9994202058269315], [0.999995112724998, 0.9999539616039778, 0.9997674598868305, 0.9992340403053274, 0.9994202058269315]], "pre": [[0.999697080737755, 1.0, 1.0, 1.0, 1.0], [0.9997215015463993, 1.0, 1.0, 1.0, 0.9999516616314199], [0.9997459213040101, 1.0, 1.0, 0.9998943140984993, 1.0], [0.9997459225454652, 1.0, 1.0, 1.0, 0.9999274889420636], [0.9997068480300187, 1.0, 1.0, 0.9999735680490577, 1.0]], "f1": [[0.9998485174253826, 0.9999079147290391, 0.9996898503527951, 0.9997225928323272, 0.9997462756895863], [0.9998582884172772, 0.9999539594843463, 0.9997480180651664, 0.9997093791281373, 0.9997221181843444], [0.9998631723490864, 0.999930937636686, 0.9998643331976665, 0.9997226221453196, 0.9996979326510639], [0.9998656160283428, 0.999930937636686, 0.9998643331976665, 0.9997622318503646, 0.9996737830294683], [0.9998509596002785, 0.9999769802720931, 0.9998837164231171, 0.99960366739768, 0.9997100188487749]]} -------------------------------------------------------------------------------- /reports/small_supcon2048.json: -------------------------------------------------------------------------------- 1 | {"fnr": [[0.0014661825006231275, 0.013811518806684775, 0.06976203395085652, 0.04754232587623148, 0.0603952263613084], [0.0014661825006231275, 0.03222687721559781, 0.05813502829238044, 0.07659596946726183, 0.06281103541576073], [0.0009774550004154185, 0.013811518806684775, 0.05038369118672971, 0.06338976783497531, 0.0676426535246654], [0.001954910000830837, 0.004603839602228259, 0.06976203395085652, 0.05810728718206069, 0.04106875392568971], [0.0, 0.02301919801114129, 0.06976203395085652, 0.05546604685560339, 0.045900372034594385]], "rec": [[0.9999853381749938, 0.9998618848119332, 
0.9993023796604914, 0.9995245767412377, 0.9993960477363869], [0.9999853381749938, 0.999677731227844, 0.9994186497170762, 0.9992340403053274, 0.9993718896458423], [0.9999902254499958, 0.9998618848119332, 0.9994961630881327, 0.9993661023216502, 0.9993235734647533], [0.9999804508999917, 0.9999539616039778, 0.9993023796604914, 0.9994189271281794, 0.9995893124607431], [1.0, 0.9997698080198886, 0.9993023796604914, 0.999445339531444, 0.9995409962796541]], "pre": [[0.9996873076208995, 1.0, 1.0, 0.999920731385087, 1.0], [0.9996286995690961, 1.0, 1.0, 0.9999207083388397, 0.9999758273103048], [0.9996677724632229, 1.0, 1.0, 0.9999471444805624, 1.0], [0.9997166128393927, 1.0, 1.0, 0.9999207230061836, 0.9999758325680314], [0.9996921964470675, 1.0, 1.0, 1.0, 1.0]], "f1": [[0.9998363006887587, 0.999930937636686, 0.9996510681192572, 0.9997226148176523, 0.9996979326510639], [0.9998069870681336, 0.9998388396454472, 0.9997092403419334, 0.9995772563939971, 0.9996737672631489], [0.9998289729581815, 0.999930937636686, 0.9997480180651664, 0.9996565389696169, 0.9996616723054615], [0.9998485144644254, 0.9999769802720931, 0.9996510681192572, 0.9996697620966144, 0.9997825351568164], [0.9998460745341311, 0.9998848907613326, 0.9996510681192572, 0.9997225928323272, 0.9997704454566324]]} -------------------------------------------------------------------------------- /reports/small_supcon512.json: -------------------------------------------------------------------------------- 1 | {"fnr": [[0.002932365001246255, 0.013811518806684775, 0.03488101697542826, 0.04490108554977417, 0.06281103541576073], [0.0, 0.004603839602228259, 0.05038369118672971, 0.05810728718206069, 0.053147799197951394], [0.0014661825006231275, 0.009207679204456518, 0.0310053484226029, 0.06867224848788991, 0.06281103541576073], [0.0009774550004154185, 0.018415358408913035, 0.0310053484226029, 0.05018356620268878, 0.05797941730685607], [0.0009774550004154185, 0.004603839602228259, 0.05038369118672971, 0.05282480652914609, 0.04831618108904672]], "rec": [[0.9999706763499875, 0.9998618848119332, 0.9996511898302457, 0.9995509891445022, 0.9993718896458423], [1.0, 0.9999539616039778, 0.9994961630881327, 0.9994189271281794, 0.9994685220080205], [0.9999853381749938, 0.9999079232079554, 0.9996899465157739, 0.9993132775151211, 0.9993718896458423], [0.9999902254499958, 0.9998158464159109, 0.9996899465157739, 0.9994981643379731, 0.9994202058269315], [0.9999902254499958, 0.9999539616039778, 0.9994961630881327, 0.9994717519347085, 0.9995168381891095]], "pre": [[0.9997312642307805, 1.0, 1.0, 0.9998414795244386, 1.0], [0.9997166183777868, 1.0, 1.0, 1.0, 1.0], [0.999697076297686, 1.0, 1.0, 0.9999207146255088, 1.0], [0.9997312694829625, 1.0, 1.0, 0.9999735750336919, 0.9999758284788862], [0.9997361542032102, 1.0, 1.0, 0.9999735743353946, 0.99997583081571]], "f1": [[0.9998509559587074, 0.999930937636686, 0.9998255644926833, 0.999696213231895, 0.9996858461612818], [0.9998582891097624, 0.9999769802720931, 0.9997480180651664, 0.9997093791281373, 0.9997341903680255], [0.9998411864593133, 0.9999539594843463, 0.9998449492208699, 0.9996169037900106, 0.9996858461612818], [0.9998607306995961, 0.9999079147290391, 0.9998449492208699, 0.9997358131670717, 0.9996979399504622], [0.9998631736863451, 0.9999769802720931, 0.9997480180651664, 0.9997226001611561, 0.9997462818204883]]} -------------------------------------------------------------------------------- /reports/small_supcon1024.json: -------------------------------------------------------------------------------- 1 | {"fnr": 
[[0.001954910000830837, 0.02762303761336955, 0.06588636539803117, 0.05018356620268878, 0.05556360825240373], [0.0014661825006231275, 0.013811518806684775, 0.06976203395085652, 0.060748527508518, 0.045900372034594385], [0.0004887275002077092, 0.018415358408913035, 0.05038369118672971, 0.03961860489685957, 0.053147799197951394], [0.0009774550004154185, 0.02301919801114129, 0.05038369118672971, 0.07131348881434722, 0.05556360825240373], [0.0, 0.02762303761336955, 0.05813502829238044, 0.05282480652914609, 0.04831618108904672]], "rec": [[0.9999804508999917, 0.9997237696238663, 0.9993411363460197, 0.9994981643379731, 0.9994443639174759], [0.9999853381749938, 0.9998618848119332, 0.9993023796604914, 0.9993925147249149, 0.9995409962796541], [0.999995112724998, 0.9998158464159109, 0.9994961630881327, 0.9996038139510314, 0.9994685220080205], [0.9999902254499958, 0.9997698080198886, 0.9994961630881327, 0.9992868651118565, 0.9994443639174759], [1.0, 0.9997237696238663, 0.9994186497170762, 0.9994717519347085, 0.9995168381891095]], "pre": [[0.9996824218024761, 1.0, 1.0, 0.9998943085134493, 1.0], [0.999697076297686, 1.0, 1.0, 0.999920720911181, 0.9999758313998454], [0.999741040236484, 1.0, 1.0, 0.999973577826512, 0.9999758296473545], [0.9996726565629916, 1.0, 1.0, 0.9999471402896712, 0.9999758290631345], [0.999706849462313, 1.0, 1.0, 0.9999735743353946, 1.0]], "f1": [[0.9998314141421545, 0.999861865733493, 0.9996704596119178, 0.9996961971812383, 0.9997221047544251], [0.9998411864593133, 0.999930937636686, 0.9996510681192572, 0.9996565480436448, 0.9997583665579316], [0.9998680603404043, 0.9999079147290391, 0.9997480180651664, 0.9997886617002166, 0.999722111469547], [0.9998314157897438, 0.9998848907613326, 0.9997480180651664, 0.9996168936681771, 0.999710025856028], [0.9998534032436975, 0.999861865733493, 0.9997092403419334, 0.9997226001611561, 0.9997583607191185]]} -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | from tfrecord.torch.dataset import TFRecordDataset 5 | from torch.utils.data import Dataset 6 | 7 | ID_LEN = 29 #CAN bus 2.0 has 29 bits 8 | DATA_LEN = 8 #Data field in Can message has 8 bytes 9 | HIST_LEN = 256 10 | 11 | class CANDataset(Dataset): 12 | def __init__(self, root_dir, window_size, is_train=True, include_data=False, transform=None): 13 | if is_train: 14 | self.root_dir = os.path.join(root_dir, 'train') 15 | else: 16 | self.root_dir = os.path.join(root_dir, 'val') 17 | 18 | # self.num_classes = num_classes 19 | self.include_data = include_data 20 | self.is_train = is_train 21 | self.transform = transform 22 | self.window_size = window_size 23 | self.total_size = len(os.listdir(self.root_dir)) 24 | 25 | def __getitem__(self, idx): 26 | filenames = '{}/{}.tfrec'.format(self.root_dir, idx) 27 | index_path = None 28 | description = {'id_seq': 'int', 'data_seq': 'int','label': 'int'} 29 | dataset = TFRecordDataset(filenames, index_path, description) 30 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1) 31 | data = next(iter(dataloader)) 32 | id_seq, data_seq, label = data['id_seq'], data['data_seq'], data['label'] 33 | id_seq = id_seq.to(torch.float) 34 | data_seq = data_seq.to(torch.float) 35 | 36 | id_seq[id_seq == 0] = -1 37 | id_seq = id_seq.view(-1, self.window_size, ID_LEN) 38 | data_seq = data_seq.view(-1, self.window_size, DATA_LEN) 39 | 40 | if self.include_data: 41 | return id_seq, 
data_seq, label[0][0] 42 | else: 43 | return id_seq, label[0][0] 44 | 45 | def __len__(self): 46 | return self.total_size -------------------------------------------------------------------------------- /reports/small_supcon4096.json: -------------------------------------------------------------------------------- 1 | {"fnr": [[0.001954910000830837, 0.04603839602228258, 0.10464305092628479, 0.04225984522331687, 0.07005846257911774], [0.0014661825006231275, 0.0552460752267391, 0.07363770250368189, 0.07659596946726183, 0.0676426535246654], [0.0004887275002077092, 0.04143455642005432, 0.06976203395085652, 0.04754232587623148, 0.05556360825240373], [0.0004887275002077092, 0.03683071681782607, 0.13952406790171304, 0.060748527508518, 0.053147799197951394], [0.001954910000830837, 0.02301919801114129, 0.0930160452678087, 0.060748527508518, 0.05073199014349906]], "rec": [[0.9999804508999917, 0.9995396160397771, 0.9989535694907371, 0.9995774015477669, 0.9992994153742089], [0.9999853381749938, 0.9994475392477327, 0.9992636229749632, 0.9992340403053274, 0.9993235734647533], [0.999995112724998, 0.9995856544357995, 0.9993023796604914, 0.9995245767412377, 0.9994443639174759], [0.999995112724998, 0.9996316928318217, 0.9986047593209829, 0.9993925147249149, 0.9994685220080205], [0.9999804508999917, 0.9997698080198886, 0.9990698395473219, 0.9993925147249149, 0.999492680098565]], "pre": [[0.9996091631694912, 1.0, 1.0, 0.9998943168908029, 0.9999516522832209], [0.9995798647750811, 1.0, 0.9999612162581446, 0.9999471374953746, 0.9999516534519435], [0.9996677740863787, 1.0, 1.0, 1.0, 0.9999758290631345], [0.9995652195152931, 1.0, 1.0, 0.9999735722402812, 1.0], [0.9996482331041963, 1.0, 1.0, 0.999920720911181, 0.9999516616314199]], "f1": [[0.9997947725640236, 0.9997697550193405, 0.9994765108478586, 0.9997358341038172, 0.9996254274356279], [0.9997825603639303, 0.9997236932995626, 0.9996122979102857, 0.9995904617213819, 0.9996375148014788], [0.9998314166135264, 0.9997927842884443, 0.9996510681192572, 0.9997622318503646, 0.999710025856028], [0.9997801199079436, 0.9998158124971221, 0.9993018926466026, 0.999682959048877, 0.9997341903680255], [0.9998143144048043, 0.9998848907613326, 0.9995347033734006, 0.9996565480436448, 0.9997221181843444]]} -------------------------------------------------------------------------------- /reports/small_resnet.json: -------------------------------------------------------------------------------- 1 | {"fnr": [[0.004398547501869383, 0.13351134846461948, 0.18215642198279203, 0.11621457436412139, 0.08455331690583176], [0.003909820001661674, 0.23019198011141292, 0.26354546159212466, 0.13998573730223712, 0.1328694979948785], [0.003909820001661674, 0.24860733852032596, 0.22866444461669638, 0.08980217109954834, 0.13045368894042614], [0.00635345750270022, 0.22558814050918466, 0.2480427873808232, 0.08716093077309105, 0.11112721650480746], [0.004398547501869383, 0.23019198011141292, 0.25579412448647393, 0.09508465175246296, 0.1546117794849495]], "rec": [[0.9999560145249813, 0.9986648865153538, 0.9981784357801721, 0.9988378542563587, 0.9991544668309417], [0.9999609017999834, 0.9976980801988858, 0.9973645453840787, 0.9986001426269776, 0.9986713050200512], [0.9999609017999834, 0.9975139266147968, 0.997713355553833, 0.9991019782890045, 0.9986954631105958], [0.999936465424973, 0.9977441185949082, 0.9975195721261918, 0.9991283906922691, 0.998888727834952], [0.9999560145249813, 0.9976980801988858, 0.9974420587551353, 0.9990491534824754, 0.9984538822051505]], "pre": [[0.9993064577574166, 1.0, 
0.9997670897868871, 0.9997884996695308, 0.9998066091328837], [0.9989746844713522, 1.0, 0.9997668997668998, 0.9996562756140769, 0.9998790634674922], [0.9990917569619758, 1.0, 0.999572881882426, 0.9998942666067511, 0.999806520267002], [0.9991600414119118, 0.9998615916955017, 0.9998057724429942, 0.9995772117112356, 0.9997823826679885], [0.9990576033828621, 0.9998615853095876, 0.9996892479801118, 0.9996300113639367, 0.9998306601833709]], "f1": [[0.9996311306212884, 0.9993319973279893, 0.9989721311793339, 0.9993129508759877, 0.9994804316042579], [0.9994675498500347, 0.9988477138643067, 0.998564277676458, 0.999127930022991, 0.9992748193091445], [0.9995261404382957, 0.9987554162441228, 0.9986422530840251, 0.9994979654388838, 0.9992506828454716], [0.9995481026417676, 0.9988017328786064, 0.9986613638568241, 0.9993527508090615, 0.9993353554639823], [0.9995066070687088, 0.9987786611360756, 0.9985643890893571, 0.999339498018494, 0.9991417969080515]]} -------------------------------------------------------------------------------- /reports/incep.json: -------------------------------------------------------------------------------- 1 | {"fnr": [[0.018571645007892948, 0.32687261175820637, 0.2984264785675529, 0.13470325664932253, 0.16185920664830653], [0.011240732504777312, 0.26702269692923897, 0.3410588326486319, 0.2509178310134439, 0.15944339759385417], [0.016616735007062112, 0.20717278210027162, 0.34493450120145724, 0.21922294709595627, 0.19568053341063923], [0.009285822503946474, 0.409741724598315, 0.25579412448647393, 0.20073426481075513, 0.19326472435618688], [0.01270691500540044, 0.28083421573592376, 0.41857220370513915, 0.182245582525554, 0.2270860511185196]], "rec": [[0.9998142835499211, 0.9967312738824179, 0.9970157352143245, 0.9986529674335067, 0.9983814079335169], [0.9998875926749522, 0.9973297730307076, 0.9965894116735137, 0.9974908216898656, 0.9984055660240615], [0.9998338326499294, 0.9979282721789973, 0.9965506549879855, 0.9978077705290405, 0.9980431946658936], [0.9999071417749605, 0.9959025827540169, 0.9974420587551353, 0.9979926573518925, 0.9980673527564381], [0.999872930849946, 0.9971916578426407, 0.9958142779629486, 0.9981775441747445, 0.9977291394888148]], "pre": [[0.9988038277511961, 0.9999538127569165, 0.998990330472603, 0.9996562937894932, 0.9995404634063755], [0.9985845372901211, 0.9999538404726735, 0.9993781577924602, 0.9998411521762152, 0.9995888257346717], [0.9986624555168828, 0.9997693833310272, 0.9991839589647936, 0.9994179894179894, 0.9997580040171333], [0.9985845649245426, 0.9999537743262608, 0.9993398827321088, 0.9997089639115251, 0.9997338173018754], [0.9985114108896395, 1.0, 0.9991833560178884, 0.999418204897657, 0.9997579278625031]], "f1": [[0.9993088002188393, 0.9983399428202526, 0.9980020561364034, 0.9991543787326251, 0.9989605994682136], [0.9992356401696731, 0.9986400829779878, 0.9979818365287589, 0.998664604074941, 0.9989968455021211], [0.999247800794205, 0.9988479793557901, 0.9978655696988513, 0.9986122308719154, 0.9988998633896298], [0.9992454157173312, 0.9979240669834386, 0.9983900688585006, 0.9988500733573893, 0.9988998899889988], [0.9991917070609638, 0.99859385445241, 0.9974959722033504, 0.9987974892632969, 0.9987425033855678]]} -------------------------------------------------------------------------------- /networks/simple_cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class CNNEncoder(nn.Module): 6 | output_shape = 64*3*3 7 | def 
__init__(self): 8 | super().__init__() 9 | self.convnet = nn.Sequential( 10 | nn.Conv2d(1, 16, 3, padding='same'), nn.ReLU(), 11 | nn.Conv2d(16, 16, 3, padding='same'),nn.ReLU(), 12 | nn.BatchNorm2d(16), 13 | nn.MaxPool2d(2, stride=2), 14 | nn.Conv2d(16, 32, 3, padding='same'),nn.ReLU(), 15 | nn.Conv2d(32, 32, 3, padding='same'),nn.ReLU(), 16 | nn.BatchNorm2d(32), 17 | nn.MaxPool2d(2, stride=2), 18 | nn.Conv2d(32, 64, 3, padding='same'),nn.ReLU(), 19 | nn.Conv2d(64, 64, 3, padding='same'),nn.ReLU(), 20 | nn.BatchNorm2d(64), 21 | nn.MaxPool2d(2, stride=2)) 22 | 23 | def forward(self, x): 24 | feat = self.convnet(x) 25 | feat = torch.flatten(feat, 1) 26 | return feat 27 | 28 | class SupConCNN(nn.Module): 29 | def __init__(self, feat_dim): 30 | super().__init__() 31 | self.encoder = CNNEncoder() 32 | self.feat_dim = feat_dim 33 | dim_in = self.encoder.output_shape 34 | self.head = nn.Sequential( 35 | nn.Linear(dim_in, 256), nn.ReLU(), 36 | nn.Linear(256, feat_dim)) 37 | 38 | def forward(self, x): 39 | feat = self.encoder(x) 40 | feat = self.head(feat) 41 | feat = F.normalize(feat, dim=1) 42 | return feat 43 | 44 | class BaselineCNNClassifier(nn.Module): 45 | def __init__(self, num_classes): 46 | super().__init__() 47 | self.encoder = CNNEncoder() 48 | feat_dim = CNNEncoder.output_shape 49 | self.n_classes = num_classes 50 | self.fc = nn.Linear(feat_dim, n_classes) 51 | 52 | def forward(self, x): 53 | feat = self.encoder(x) 54 | out = self.fc(feat) 55 | return out 56 | -------------------------------------------------------------------------------- /reports/resnet.json: -------------------------------------------------------------------------------- 1 | {"fnr": [[0.003909820001661674, 0.17494590488467382, 0.1240213936904116, 0.13734449697577983, 0.09663236217809344], [0.001954910000830837, 0.19336126329358685, 0.17052941632431595, 0.1320620163228652, 0.10629559839590279], [0.002443637501038546, 0.2117766217024999, 0.17052941632431595, 0.10036713240537756, 0.10146398028699811], [0.002443637501038546, 0.11509599005570646, 0.22478877606387102, 0.09244341142600565, 0.08213750785137942], [0.0004887275002077092, 0.23019198011141292, 0.20153476474691884, 0.08187845012017643, 0.09180074406918877], [0.0009774550004154185, 0.20256894249804336, 0.22866444461669638, 0.06338976783497531, 0.12320626177706914]], "rec": [[0.9999609017999834, 0.9982505409511533, 0.9987597860630959, 0.9986265550302422, 0.999033676378219], [0.9999804508999917, 0.9980663873670641, 0.9982947058367568, 0.9986793798367714, 0.998937044016041], [0.9999755636249896, 0.997882233782975, 0.9982947058367568, 0.9989963286759462, 0.99898536019713], [0.9999755636249896, 0.998849040099443, 0.9977521122393613, 0.99907556588574, 0.9991786249214862], [0.999995112724998, 0.9976980801988858, 0.9979846523525309, 0.9991812154987982, 0.9990819925593081], [0.9999902254499958, 0.9979743105750196, 0.997713355553833, 0.9993661023216502, 0.9987679373822294]], "pre": [[0.9992722974886937, 0.9999538830474082, 0.9997672253258846, 0.9997620180866255, 0.9998791073285137], [0.9991649575153824, 0.9999538745387454, 0.9998835448934436, 0.999920664304226, 0.9998549182706258], [0.9992332674040973, 0.9999538660269423, 0.9998447325518205, 0.9997885332135021, 0.9998791014822158], [0.9993357558707459, 0.9999539106788957, 0.9998834815706684, 0.9997885499814981, 0.9997824457927433], [0.9992381622128458, 0.999907719282056, 0.9998835087174309, 0.9998414208690136, 0.9998791131742463], [0.9992137596937081, 0.9999077448221781, 0.999961155997514, 0.9996829506724088, 
0.999903257793794]], "f1": [[0.999616481054897, 0.9991014860039166, 0.9992632517740122, 0.9991939639794395, 0.9994562130678042], [0.9995725378792505, 0.9990092394184467, 0.9990884936873339, 0.9992996366038983, 0.9993957703927492], [0.99960427770911, 0.9989169758278224, 0.9990691179892949, 0.9993922739523331, 0.9994320310328576], [0.9996555573741257, 0.99940117002165, 0.9988166598770102, 0.99943193077482, 0.9994804441598298], [0.999616494170467, 0.9988016776512881, 0.9989331781592474, 0.9995112091628466, 0.9994803939339012], [0.9996018417870708, 0.9989400921658986, 0.9988359911535327, 0.999524501386871, 0.9993352751356643]]} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | save/ 2 | tmp/ 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CAN-SupCon-IDS 2 | 3 | This is the implementation of the paper ["SupCon ResNet and Transfer Learning for the In-vehicle Intrusion Detection System"](https://arxiv.org/submit/4407238/view) 4 | 5 | ## Environment 6 | 7 | - tensorflow: 2.0 8 | - torch: 1.9 9 | 10 | ## How to run 11 | 12 | ### Train test split 13 | 14 | ``` 15 | python3 train_test_split.py --data_path ../Data/ --car_model None --window_size 29 --strided 15 --rid 2 16 | ``` 17 | 18 | ### Baseline train 19 | 20 | ``` 21 | python3 train_baseline.py --data_dir ../Data/TFrecord_w29_s15/ \\ 22 | --model resnet18 --save_freq 10 --window_size 29 \\ 23 | --num_workers 8 --cosine --epochs 50 \\ 24 | --batch_size 256 --learning_rate 0.0005 --rid 5 25 | ``` 26 | 27 | ### Supcon train 28 | 29 | ``` 30 | python3 train_supcon.py --data_dir ../Data/TFrecord_w29_s15/ \\ 31 | --model resnet18 --save_freq 10 --window_size 29 \\ 32 | --epochs 200 --num_workers 8 --temp 0.07 \\ 33 | --learning_rate 0.1 --learning_rate_classifier 0.01 \\ 34 | --cosine --epoch_start_classifier 170 --rid 3 --batch_size 512 35 | ``` 36 | 37 | ### Transfer 38 | 39 | Random initialization 40 | 41 | ``` 42 | python3 transfer.py --data_path ../Data/Survival/ --car_model Spark \\ 43 | --pretrained_model resnet --tf_algo tune \\ 44 | --num_classes 4 --window_size 29 --strided 10 \\ 45 | --lr_tune 0.001 --tune_epochs 20 46 | ``` 47 | 48 | Using CE ResNet as the pretrained model 49 | 50 | ``` 51 | python3 transfer.py --data_path ../Data/Survival/ --car_model Spark \\ 52 | --window_size 29 --strided 10 --num_classes 4 --lr_transfer 0.01 \\ 53 | --lr_tune 0.001 --transfer_epochs 50 --tune_epochs 10 \\ 54 | --tf_algo transfer_tune --pretrained_model resnet \\ 55 | --pretrained_path save/smallresnet18.ce1_gamma0_lr0.001_bs256_50epochs_051822_100142_cosine/models/ \\ 56 | --source_ckpt 50 57 | ``` 58 | 59 | Using SupCon ResNet as the pretrained model 60 | 61 | ``` 62 | python3 transfer.py --data_path ../Data/Survival/ --car_model Spark \\ 63 | --window_size 29 --strided 10 --num_classes 4 --lr_transfer 0.01 \\ 64 | --lr_tune 0.001 --transfer_epochs 40 --tune_epochs 20 \\ 65 | --tf_algo transfer_tune --pretrained_model supcon \\ 66 | --pretrained_path save/SupCon_resnet18.ce2_lr0.05_0.01_bs512_200epoch_temp0.07_052322_102305_cosine_warm/models/ \\ 67 | --source_ckpt 200 68 | ``` 69 | 70 | ## Acknowledgement 71 | 72 | This codebase was adapted from [SupContrast](https://github.com/HobbitLong/SupContrast). 
73 | -------------------------------------------------------------------------------- /supcon/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import math 4 | import numpy as np 5 | import torch 6 | import torch.optim as optim 7 | 8 | 9 | class TwoCropTransform: 10 | """Create two crops of the same image""" 11 | def __init__(self, transform): 12 | self.transform = transform 13 | 14 | def __call__(self, x): 15 | return [self.transform(x), self.transform(x)] 16 | 17 | 18 | class AverageMeter(object): 19 | """Computes and stores the average and current value""" 20 | def __init__(self): 21 | self.reset() 22 | 23 | def reset(self): 24 | self.val = 0 25 | self.avg = 0 26 | self.sum = 0 27 | self.count = 0 28 | 29 | def update(self, val, n=1): 30 | self.val = val 31 | self.sum += val * n 32 | self.count += n 33 | self.avg = self.sum / self.count 34 | 35 | 36 | def accuracy(output, target, topk=(1,)): 37 | """Computes the accuracy over the k top predictions for the specified values of k""" 38 | with torch.no_grad(): 39 | maxk = max(topk) 40 | batch_size = target.size(0) 41 | 42 | _, pred = output.topk(maxk, 1, True, True) 43 | pred = pred.t() 44 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 45 | 46 | res = [] 47 | for k in topk: 48 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 49 | res.append(correct_k.mul_(100.0 / batch_size)) 50 | return res 51 | 52 | 53 | def adjust_learning_rate(args, optimizer, epoch): 54 | lr = args.learning_rate 55 | if args.cosine: 56 | eta_min = lr * (args.lr_decay_rate ** 3) 57 | lr = eta_min + (lr - eta_min) * ( 58 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2 59 | else: 60 | steps = np.sum(epoch > np.asarray(args.lr_decay_epochs.split(','), dtype=int)) 61 | if steps > 0: 62 | lr = lr * (args.lr_decay_rate ** steps) 63 | 64 | for param_group in optimizer.param_groups: 65 | param_group['lr'] = lr 66 | 67 | 68 | def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer): 69 | if args.warm and epoch <= args.warm_epochs: 70 | p = (batch_id + (epoch - 1) * total_batches) / \ 71 | (args.warm_epochs * total_batches) 72 | lr = args.warmup_from + p * (args.warmup_to - args.warmup_from) 73 | 74 | for param_group in optimizer.param_groups: 75 | param_group['lr'] = lr 76 | 77 | 78 | def set_optimizer(opt, model): 79 | optimizer = optim.SGD(model.parameters(), 80 | lr=opt.learning_rate, 81 | momentum=opt.momentum, 82 | weight_decay=opt.weight_decay) 83 | return optimizer 84 | 85 | 86 | 87 | def save_model(model, optimizer, opt, epoch, save_file): 88 | print('==> Saving...') 89 | state = { 90 | 'opt': opt, 91 | 'model': model.state_dict(), 92 | 'optimizer': optimizer.state_dict(), 93 | 'epoch': epoch, 94 | } 95 | torch.save(state, save_file) 96 | del state 97 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from sklearn.metrics import confusion_matrix 4 | import pandas as pd 5 | import seaborn as sns 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | 9 | from dataset import CANDataset 10 | 11 | def load_dataset(args, trial_id=1): 12 | if args.car_model is None: 13 | data_dir = f'TFrecord_w{args.window_size}_s{args.strided}' 14 | else: 15 | data_dir = f'TFrecord_{args.car_model}_w{args.window_size}_s{args.strided}' 16 | data_dir = os.path.join(args.data_path, data_dir, 
str(trial_id)) 17 | 18 | train_dataset = CANDataset(data_dir, 19 | window_size = args.window_size) 20 | val_dataset = CANDataset(data_dir, 21 | window_size = args.window_size, 22 | is_train=False) 23 | 24 | train_loader = torch.utils.data.DataLoader( 25 | train_dataset, batch_size=args.batch_size, 26 | shuffle=True, num_workers=args.num_workers, pin_memory=True) 27 | 28 | val_loader = torch.utils.data.DataLoader( 29 | val_dataset, batch_size=args.batch_size, 30 | num_workers=args.num_workers, pin_memory=True) 31 | 32 | return train_loader, val_loader 33 | 34 | def change_new_state_dict(state_dict): 35 | new_state_dict = {} 36 | for k, v in state_dict.items(): 37 | k = k.replace("module.", "") 38 | new_state_dict[k] = v 39 | return new_state_dict 40 | 41 | def plot_embeddings(embeddings, targets, xlim=None, ylim=None, save_dir=None): 42 | classes = ['Normal', 'DoS', 'Fuzzy', 'gear', 'RPM'] 43 | colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728','#9467bd',] 44 | n_classes = len(classes) 45 | plt.figure(figsize=(10,10)) 46 | for i in range(n_classes): 47 | inds = np.where(targets==i)[0] 48 | plt.scatter(embeddings[inds,0], embeddings[inds,1], alpha=0.5, color=colors[i]) 49 | if xlim: 50 | plt.xlim(xlim[0], xlim[1]) 51 | if ylim: 52 | plt.ylim(ylim[0], ylim[1]) 53 | plt.legend(classes) 54 | 55 | if save_dir is not None: 56 | plt.savefig(save_dir, dpi=300) 57 | 58 | def cal_metric(label, pred): 59 | cm = confusion_matrix(label, pred) 60 | recall = np.diag(cm) / np.sum(cm, axis = 1) 61 | precision = np.diag(cm) / np.sum(cm, axis = 0) 62 | f1 = 2*recall*precision / (recall + precision) 63 | 64 | total_actual = np.sum(cm, axis=1) 65 | true_predicted = np.diag(cm) 66 | fnr = (total_actual - true_predicted)*100/total_actual 67 | 68 | return cm, { 69 | 'fnr': np.array(fnr), 70 | 'rec': recall, 71 | 'pre': precision, 72 | 'f1': f1 73 | } 74 | 75 | def get_prediction(model, dataloader): 76 | with torch.no_grad(): 77 | model.eval() 78 | prediction = np.zeros(len(dataloader.dataset)) 79 | labels = np.zeros(len(dataloader.dataset)) 80 | k = 0 81 | for images, target in dataloader: 82 | if cuda: 83 | images = images.cuda() 84 | prediction[k:k+len(images)] = np.argmax(model(images).data.cpu().numpy(), axis=1) 85 | labels[k:k+len(images)] = target.numpy() 86 | k += len(images) 87 | return prediction, labels 88 | 89 | def draw_confusion_matrix(cm, classes, save_dir=None): 90 | cm_df = pd.DataFrame(cm, 91 | index = classes, 92 | columns = classes) 93 | plt.figure(figsize=(10,8)) 94 | sns.heatmap(cm_df, annot=True, cmap='YlGnBu', cbar=False, linewidths=0.5) 95 | plt.title('Confusion Matrix') 96 | plt.ylabel('Actual Values') 97 | plt.xlabel('Predicted Values') 98 | if save_dir is not None: 99 | plt.savefig(save_dir, dpi=300) 100 | plt.show() 101 | 102 | def print_results(results): 103 | print('\t' + '\t'.join(map(str, results.keys()))) 104 | for idx, c in enumerate(classes): 105 | res = [round(results[k][idx], 4) for k in results.keys()] 106 | output = [c] + res 107 | print('\t'.join(map(str, output))) 108 | -------------------------------------------------------------------------------- /supcon/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Yonglong Tian (yonglong@mit.edu) 3 | Date: May 07, 2020 4 | """ 5 | from __future__ import print_function 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class SupConLoss(nn.Module): 12 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 
13 | It also supports the unsupervised contrastive loss in SimCLR""" 14 | def __init__(self, temperature=0.07, contrast_mode='all', 15 | base_temperature=0.07): 16 | super(SupConLoss, self).__init__() 17 | self.temperature = temperature 18 | self.contrast_mode = contrast_mode 19 | self.base_temperature = base_temperature 20 | 21 | def forward(self, features, labels=None, mask=None): 22 | """Compute loss for model. If both `labels` and `mask` are None, 23 | it degenerates to SimCLR unsupervised loss: 24 | https://arxiv.org/pdf/2002.05709.pdf 25 | 26 | Args: 27 | features: hidden vector of shape [bsz, n_views, ...]. 28 | labels: ground truth of shape [bsz]. 29 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 30 | has the same class as sample i. Can be asymmetric. 31 | Returns: 32 | A loss scalar. 33 | """ 34 | device = (torch.device('cuda') 35 | if features.is_cuda 36 | else torch.device('cpu')) 37 | 38 | if len(features.shape) < 3: 39 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 40 | 'at least 3 dimensions are required') 41 | if len(features.shape) > 3: 42 | features = features.view(features.shape[0], features.shape[1], -1) 43 | 44 | batch_size = features.shape[0] 45 | if labels is not None and mask is not None: 46 | raise ValueError('Cannot define both `labels` and `mask`') 47 | elif labels is None and mask is None: 48 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 49 | elif labels is not None: 50 | labels = labels.contiguous().view(-1, 1) 51 | if labels.shape[0] != batch_size: 52 | raise ValueError('Num of labels does not match num of features') 53 | mask = torch.eq(labels, labels.T).float().to(device) 54 | else: 55 | mask = mask.float().to(device) 56 | 57 | # contrast count is the number of augmented views 58 | contrast_count = features.shape[1] 59 | # contrast feature: the concatenation of all views 60 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 61 | if self.contrast_mode == 'one': 62 | anchor_feature = features[:, 0] 63 | anchor_count = 1 64 | elif self.contrast_mode == 'all': 65 | anchor_feature = contrast_feature 66 | anchor_count = contrast_count 67 | else: 68 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 69 | 70 | # compute logits 71 | anchor_dot_contrast = torch.div( 72 | torch.matmul(anchor_feature, contrast_feature.T), 73 | self.temperature) 74 | # for numerical stability 75 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 76 | logits = anchor_dot_contrast - logits_max.detach() 77 | 78 | # tile mask 79 | mask = mask.repeat(anchor_count, contrast_count) 80 | # mask-out self-contrast cases 81 | logits_mask = torch.scatter( 82 | torch.ones_like(mask), 83 | 1, 84 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 85 | 0 86 | ) 87 | mask = mask * logits_mask 88 | 89 | # compute log_prob 90 | exp_logits = torch.exp(logits) * logits_mask # this is the numerator -> positive 91 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) # this is the denom -> negative 92 | 93 | # compute mean of log-likelihood over positive 94 | # Prevent the nan loss by adding a small amount of number in the denom 95 | eps = 1e-12 96 | mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + eps) 97 | 98 | # loss 99 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 100 | loss = loss.view(anchor_count, batch_size).mean() 101 | 102 | return loss 103 | -------------------------------------------------------------------------------- 
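`SupConLoss` expects L2-normalized features of shape `[bsz, n_views, feat_dim]`; `TwoCropTransform` in `supcon/util.py` can supply two views per sample, while `train_supcon.py` passes a single view via `features.unsqueeze(1)`. A minimal call sketch with toy shapes (illustrative, not taken from the repo):

```
import torch
import torch.nn.functional as F
from supcon.losses import SupConLoss

criterion = SupConLoss(temperature=0.07)
bsz, n_views, feat_dim = 8, 2, 128
features = F.normalize(torch.randn(bsz, n_views, feat_dim), dim=-1)
labels = torch.randint(0, 5, (bsz,))

loss_supcon = criterion(features, labels)  # supervised contrastive loss
loss_simclr = criterion(features)          # no labels/mask -> SimCLR-style loss
```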
/test_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from tqdm.auto import tqdm 4 | import copy 5 | import time 6 | import numpy as np 7 | import torch 8 | import torch.optim as optim 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from utils import cal_metric, load_dataset, change_new_state_dict 12 | 13 | from networks.classifier import LinearClassifier 14 | from networks.transfer import TransferModel 15 | from networks.resnet_big import SupConResNet, SupCEResNet 16 | from networks.inception import SupIncepResnet 17 | 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser('argument for testing') 21 | parser.add_argument('--data_path', type=str, help='data path to train the target model') 22 | parser.add_argument('--car_model', type=str, ) 23 | parser.add_argument('--pretrained_model', type=str, default='supcon') 24 | parser.add_argument('--pretrained_path', type=str, help='path which stores the pretrained model weights') 25 | parser.add_argument('--window_size', type=int, default=29) 26 | parser.add_argument('--strided', type=int, default=None) 27 | parser.add_argument('--batch_size', type=int, default=256) 28 | parser.add_argument('--num_workers', type=int, default=8) 29 | parser.add_argument('--trial_id', type=int, default=1) 30 | parser.add_argument('--ckpt', type=int, help='id checkpoint for pretrained model') 31 | args = parser.parse_args() 32 | 33 | if args.strided == None: 34 | args.strided = args.window_size 35 | return args 36 | 37 | def load_models_weights(args, model, verbose=False): 38 | if verbose: 39 | print('Loading: ', model.__class__.__name__) 40 | if model.__class__.__name__ == 'LinearClassifier': 41 | model_file = f'{args.pretrained_path}/ckpt_class_epoch_{args.ckpt}.pth' 42 | else: 43 | model_file = f'{args.pretrained_path}/ckpt_epoch_{args.ckpt}.pth' 44 | ckpt = torch.load(model_file) 45 | state_dict = change_new_state_dict(ckpt['model']) 46 | model.load_state_dict(state_dict=state_dict) 47 | return model 48 | 49 | 50 | def load_model(args, verbose=False, is_cuda=True): 51 | if args.pretrained_model == 'resnet': 52 | model = SupCEResNet(num_classes=5) 53 | model = load_models_weights(args, model, verbose) 54 | elif args.pretrained_model == 'supcon': 55 | supcon_model = SupConResNet(name='resnet18') 56 | classifier = LinearClassifier(n_classes=5, feat_dim=128) 57 | supcon_model = load_models_weights(args, supcon_model, verbose) 58 | classifier = load_models_weights(args, classifier, verbose) 59 | model = TransferModel(supcon_model.encoder, classifier) 60 | elif args.pretrained_model == 'incep': 61 | model = SupIncepResnet(num_classes=5) 62 | model = load_models_weights(args, model, verbose) 63 | 64 | if is_cuda: 65 | model = model.cuda() 66 | return model 67 | 68 | def inference(model, data_loader, verbose=False): 69 | total_pred = np.empty(shape=(0), dtype=int) 70 | total_label = np.empty(shape=(0), dtype=int) 71 | 72 | model.eval() 73 | if verbose: 74 | data_loader = tqdm(data_loader) 75 | with torch.no_grad(): 76 | for samples, labels in data_loader: 77 | samples = samples.cuda(non_blocking=True) 78 | outputs = model(samples) 79 | _, pred = outputs.topk(1, 1, True, True) 80 | pred = pred.t().cpu().numpy().squeeze(0) 81 | total_pred = np.concatenate((total_pred, pred), axis=0) 82 | total_label = np.concatenate((total_label, labels), axis=0) 83 | 84 | return total_label, total_pred 85 | 86 | def evaluate(model, data_loader, verbose=False): 87 | total_label, 
total_pred = inference(model, data_loader, verbose=verbose) 88 | _, results = cal_metric(total_label, total_pred) 89 | if verbose: 90 | for key, values in results.items(): 91 | print(list("{0:0.4f}".format(i) for i in values)) 92 | return results 93 | 94 | def test(args, verbose=False, is_cuda=True): 95 | train_loader, val_loader = load_dataset(args, trial_id=args.trial_id) 96 | model = load_model(args, is_cuda=is_cuda, verbose=verbose) 97 | results = evaluate(model, val_loader, verbose=verbose) 98 | return results 99 | 100 | if __name__ == '__main__': 101 | args = parse_args() 102 | test(args, verbose=True, is_cuda=True) -------------------------------------------------------------------------------- /notebooks/RecCNN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 24, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn \n", 11 | "import torch.nn.functional as F\n", 12 | "from torchsummary import summary\n", 13 | "\n", 14 | "class RecCNN(nn.Module):\n", 15 | " def __init__(self, num_classes):\n", 16 | " super().__init__()\n", 17 | " self.convnet = nn.Sequential(\n", 18 | " nn.Conv2d(1, 64, 3, padding='same'), nn.ReLU(),\n", 19 | " nn.MaxPool2d(2, stride=2),\n", 20 | " \n", 21 | " nn.Conv2d(64, 16, 3, padding='same'), nn.ReLU(),\n", 22 | " nn.MaxPool2d(2, stride=2),\n", 23 | " )\n", 24 | " self.fc1 = nn.Linear(32*32*16, 32)\n", 25 | " self.fc2 = nn.Linear(32, num_classes)\n", 26 | " # self.softmax = nn.Softmax()\n", 27 | " def forward(self, x):\n", 28 | " x = torch.flatten(self.convnet(x), 1)\n", 29 | " x = self.fc1(x)\n", 30 | " x = self.fc2(x)\n", 31 | " # c = self.softmax(x)\n", 32 | " return x" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 25, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "x = torch.rand((1, 1, 128, 128))\n", 42 | "model = RecCNN(num_classes=5)" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 26, 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "name": "stdout", 52 | "output_type": "stream", 53 | "text": [ 54 | "----------------------------------------------------------------\n", 55 | " Layer (type) Output Shape Param #\n", 56 | "================================================================\n", 57 | " Conv2d-1 [-1, 64, 128, 128] 640\n", 58 | " ReLU-2 [-1, 64, 128, 128] 0\n", 59 | " MaxPool2d-3 [-1, 64, 64, 64] 0\n", 60 | " Conv2d-4 [-1, 16, 64, 64] 9,232\n", 61 | " ReLU-5 [-1, 16, 64, 64] 0\n", 62 | " MaxPool2d-6 [-1, 16, 32, 32] 0\n", 63 | " Linear-7 [-1, 32] 524,320\n", 64 | " Linear-8 [-1, 5] 165\n", 65 | "================================================================\n", 66 | "Total params: 534,357\n", 67 | "Trainable params: 534,357\n", 68 | "Non-trainable params: 0\n", 69 | "----------------------------------------------------------------\n", 70 | "Input size (MB): 0.06\n", 71 | "Forward/backward pass size (MB): 19.13\n", 72 | "Params size (MB): 2.04\n", 73 | "Estimated Total Size (MB): 21.23\n", 74 | "----------------------------------------------------------------\n" 75 | ] 76 | } 77 | ], 78 | "source": [ 79 | "summary(model, (1, 128, 128))" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 27, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "from thop import profile" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 29, 94 | "metadata": {}, 95 | "outputs": [ 96 | { 
97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "[INFO] Register count_convNd() for .\n", 101 | "[INFO] Register zero_ops() for .\n", 102 | "[INFO] Register zero_ops() for .\n", 103 | "[INFO] Register zero_ops() for .\n", 104 | "[INFO] Register count_linear() for .\n", 105 | "MACs (G): 47.710368\n", 106 | "Params (M): 0.534357\n" 107 | ] 108 | } 109 | ], 110 | "source": [ 111 | "x = torch.rand((1, 1, 128, 128))\n", 112 | "model = RecCNN(num_classes=5)\n", 113 | "macs, params = profile(model, inputs=(x, ))\n", 114 | "print('MACs (G): ', macs/1000**2)\n", 115 | "print('Params (M): ', params/1000**2)" 116 | ] 117 | } 118 | ], 119 | "metadata": { 120 | "kernelspec": { 121 | "display_name": "Python 3.8.13 ('torch')", 122 | "language": "python", 123 | "name": "python3" 124 | }, 125 | "language_info": { 126 | "codemirror_mode": { 127 | "name": "ipython", 128 | "version": 3 129 | }, 130 | "file_extension": ".py", 131 | "mimetype": "text/x-python", 132 | "name": "python", 133 | "nbconvert_exporter": "python", 134 | "pygments_lexer": "ipython3", 135 | "version": "3.8.13" 136 | }, 137 | "orig_nbformat": 4, 138 | "vscode": { 139 | "interpreter": { 140 | "hash": "78a29cc2c05d3ee8d935820ad86792723c958d8c7f217aee9aa88e38f878a5d1" 141 | } 142 | } 143 | }, 144 | "nbformat": 4, 145 | "nbformat_minor": 2 146 | } 147 | -------------------------------------------------------------------------------- /train_test_split.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import json 4 | import argparse 5 | from functools import partial 6 | from multiprocessing import Pool 7 | from tqdm import tqdm 8 | import os 9 | print('Tensorflow version', tf.__version__) 10 | 11 | class TFwriter: 12 | def __init__(self, outdir, start_idx = 0): 13 | print('Writing to: ', outdir) 14 | self._outdir = outdir 15 | self._start_idx = start_idx 16 | 17 | def serialize_example(self, x, y): 18 | """converts x, y to tf.train.Example and serialize""" 19 | #Need to pay attention to whether it needs to be converted to numpy() form 20 | id_seq, data_seq = x 21 | id_seq = tf.train.Int64List(value = np.array(id_seq).flatten()) 22 | data_seq = tf.train.Int64List(value = np.array(data_seq).flatten()) 23 | #data_histogram = tf.train.Int64List(value = np.array(data_histogram).flatten()) 24 | label = tf.train.Int64List(value = np.array([y])) 25 | features = tf.train.Features( 26 | feature = { 27 | "id_seq": tf.train.Feature(int64_list = id_seq), 28 | "data_seq": tf.train.Feature(int64_list = data_seq), 29 | #"data_histogram": tf.train.Feature(int64_list = data_histogram), 30 | "label" : tf.train.Feature(int64_list = label) 31 | } 32 | ) 33 | example = tf.train.Example(features = features) 34 | return example.SerializeToString() 35 | 36 | def write(self, data, label): 37 | filename = os.path.join(self._outdir, str(self._start_idx)+'.tfrec') 38 | with tf.io.TFRecordWriter(filename) as outfile: 39 | outfile.write(self.serialize_example(data, label)) 40 | self._start_idx += 1 41 | 42 | def read_tfrecord(example, window_size): 43 | # window_size = 20 44 | data_bytes = 256 45 | feature_description = { 46 | 'id_seq': tf.io.FixedLenFeature([window_size*29], tf.int64), 47 | 'data_seq': tf.io.FixedLenFeature([window_size*8], tf.int64), 48 | 'data_histogram': tf.io.FixedLenFeature([data_bytes], tf.int64), 49 | 'label': tf.io.FixedLenFeature([1], tf.int64) 50 | } 51 | return tf.io.parse_single_example(example, feature_description) 52 | 53 | 54 | 
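# Illustrative sketch (not part of the original script): read_tfrecord() above parses the
# flat int64 features written by preprocessing.py, so reshaping 'id_seq' recovers the
# window_size x 29 binary CAN-ID frame that the models consume.
def decode_example(serialized, window_size=29):
    parsed = read_tfrecord(serialized, window_size)
    id_frame = tf.reshape(parsed['id_seq'], (window_size, 29))  # one row per CAN message
    label = parsed['label'][0]
    return id_frame, label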
def write_tfrecord(dataset, tfwriter): 55 | for batch_data in iter(dataset): 56 | features = zip(batch_data['id_seq'], batch_data['data_seq']) 57 | for x, y in zip(features, batch_data['label']): 58 | tfwriter.write(x, y) 59 | 60 | def train_test_split(**args): 61 | """ 62 | """ 63 | if args['strided'] == None: 64 | args['strided'] = args['window_size'] 65 | 66 | if args['car_model'] is None: 67 | data_dir = f"{args['data_path']}/TFrecord_w{args['window_size']}_s{args['strided']}" 68 | else: 69 | data_dir = f"{args['data_path']}/TFrecord_{args['car_model']}_w{args['window_size']}_s{args['strided']}" 70 | 71 | out_dir = data_dir + '/{}'.format(args['rid']) 72 | train_dir = os.path.join(out_dir, 'train') 73 | val_dir = os.path.join(out_dir, 'val') 74 | if not os.path.exists(train_dir): 75 | os.makedirs(train_dir) 76 | if not os.path.exists(val_dir): 77 | os.makedirs(val_dir) 78 | data_info = json.load(open(data_dir + '/datainfo.txt')) 79 | train_writer = TFwriter(train_dir) 80 | val_writer = TFwriter(val_dir) 81 | 82 | train_ratio = 0.7 83 | batch_size = 1000 84 | 85 | total_train_size = 0 86 | total_val_size = 0 87 | 88 | for filename, dataset_size in data_info.items(): 89 | print('Read from {}: {} records'.format(filename, dataset_size)) 90 | dataset = tf.data.TFRecordDataset(filename) 91 | dataset = dataset.map(lambda x: read_tfrecord(x, args['window_size']), 92 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 93 | dataset = dataset.shuffle(50000) 94 | 95 | train_size = int(dataset_size * train_ratio) 96 | val_size = (dataset_size - train_size) 97 | 98 | train_dataset = dataset.take(train_size) 99 | val_dataset = dataset.skip(train_size) 100 | 101 | train_dataset = train_dataset.batch(batch_size) 102 | val_dataset = val_dataset.batch(batch_size) 103 | 104 | 105 | # inputs = ([train_dataset, train_writer], [val_dataset, val_writer]) 106 | # p = Pool(2) 107 | # p.map(write_tfrecord, inputs) 108 | write_tfrecord(train_dataset, train_writer) 109 | write_tfrecord(val_dataset, val_writer) 110 | 111 | total_train_size += train_size 112 | total_val_size += val_size 113 | 114 | print('Total training: ', total_train_size) 115 | print('Total validation: ', total_val_size) 116 | 117 | 118 | if __name__ == '__main__': 119 | # Parse argument 120 | parser = argparse.ArgumentParser() 121 | parser.add_argument('--data_path', type=str, default='../Data/') 122 | parser.add_argument('--car_model', type=str, default=None) 123 | parser.add_argument('--window_size', type=int) 124 | parser.add_argument('--strided', type=int) 125 | parser.add_argument('--rid', type=int, default=1) 126 | 127 | args = vars(parser.parse_args()) 128 | print(args) 129 | train_test_split(**args) 130 | -------------------------------------------------------------------------------- /networks/inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Stem(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | self.conv = nn.Sequential( 9 | nn.Conv2d(1, 32, 3, padding='same'), # 29x29x32 10 | nn.Conv2d(32, 32, 3), # 27x27x32 11 | nn.MaxPool2d(2, stride=2), #13x13x32 12 | nn.Conv2d(32, 64, 1), #13x13x64 13 | nn.Conv2d(64, 128, 3, padding='same'), # 13x13x128 14 | nn.Conv2d(128, 128, 3, padding='same') # 13x13x128 15 | ) 16 | def forward(self, x): 17 | return self.conv(x) 18 | 19 | class InceptionresenetA(nn.Module): 20 | def __init__(self): 21 | super().__init__() 22 | self.branch1 = nn.Conv2d(128, 32, 
1) 23 | self.branch2 = nn.Sequential( 24 | nn.Conv2d(128, 32, 1), 25 | nn.Conv2d(32, 32, 3, padding='same'), 26 | ) 27 | self.branch3 = nn.Sequential( 28 | nn.Conv2d(128, 32, 1), 29 | nn.Conv2d(32, 32, 3, padding='same'), 30 | nn.Conv2d(32, 32, 3, padding='same') 31 | ) 32 | self.linear = nn.Conv2d(96, 128, 1) 33 | self.relu = nn.ReLU() 34 | 35 | def forward(self, x): 36 | residual = [ 37 | self.branch1(x), 38 | self.branch2(x), 39 | self.branch3(x) 40 | ] 41 | #print([e.shape for e in residual]) 42 | residual = torch.cat(residual, 1) 43 | residual = self.linear(residual) 44 | output = self.relu(x + residual) 45 | return output 46 | 47 | class ReductionA(nn.Module): 48 | def __init__(self): 49 | super().__init__() 50 | self.branch1 = nn.MaxPool2d(3, stride=2) 51 | self.branch2 = nn.Conv2d(128, 192, 3, stride=2) 52 | self.branch3 = nn.Sequential( 53 | nn.Conv2d(128, 96, 1), 54 | nn.Conv2d(96, 96, 3, padding='same'), 55 | nn.Conv2d(96, 128, 3, stride=2) 56 | ) 57 | 58 | def forward(self, x): 59 | x = [ 60 | self.branch1(x), 61 | self.branch2(x), 62 | self.branch3(x) 63 | ] 64 | return torch.cat(x, 1) 65 | 66 | class InceptionresnetB(nn.Module): 67 | def __init__(self): 68 | super().__init__() 69 | self.branch1 = nn.Conv2d(448, 64, 1) 70 | self.branch2 = nn.Sequential( 71 | nn.Conv2d(448, 64, 1), 72 | nn.Conv2d(64, 64, kernel_size=(1, 3), padding='same'), 73 | nn.Conv2d(64, 64, kernel_size=(3, 1), padding='same') 74 | ) 75 | self.linear = nn.Conv2d(64*2, 448, 1) 76 | self.relu = nn.ReLU() 77 | def forward(self, x): 78 | residual = [ 79 | self.branch1(x), 80 | self.branch2(x) 81 | ] 82 | residual = torch.cat(residual, 1) 83 | residual = self.linear(residual) 84 | output = self.relu(x + residual) 85 | return output 86 | 87 | class ReductionB(nn.Module): 88 | def __init__(self): 89 | super().__init__() 90 | self.branch1 = nn.MaxPool2d(3) 91 | self.branch2 = nn.Sequential( 92 | nn.Conv2d(448, 128, 1), 93 | nn.Conv2d(128, 192, 3, stride=2) 94 | ) 95 | self.branch3 = nn.Sequential( 96 | nn.Conv2d(448, 128, 1), 97 | nn.Conv2d(128, 128, 3, stride=2) 98 | ) 99 | self.branch4 = nn.Sequential( 100 | nn.Conv2d(448, 128, 1), 101 | nn.Conv2d(128, 128, 3), 102 | nn.Conv2d(128, 128, 3) 103 | ) 104 | 105 | def forward(self, x): 106 | x = [ 107 | self.branch1(x), 108 | self.branch2(x), 109 | self.branch3(x), 110 | self.branch4(x), 111 | ] 112 | # print([e.shape for e in x]) 113 | return torch.cat(x, 1) 114 | 115 | class InceptionResnet(nn.Module): 116 | def __init__(self): 117 | super().__init__() 118 | self.stem = Stem() 119 | self.inceptionresA = InceptionresenetA() 120 | self.reductionA = ReductionA() 121 | self.inceptionresB = InceptionresnetB() 122 | self.reductionB = ReductionB() 123 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 124 | self.dropout = nn.Dropout2d(1 - 0.8) 125 | # self.linear = nn.Linear(896, n_classes) 126 | 127 | def forward(self, x): 128 | x = self.stem(x) 129 | # print('Stem: ', x.shape) 130 | x = self.inceptionresA(x) 131 | # print('Incetion Res A: ', x.shape) 132 | x = self.reductionA(x) 133 | # print('Reduction A: ', x.shape) 134 | x = self.inceptionresB(x) 135 | # print('Inception Res B: ', x.shape) 136 | x = self.reductionB(x) 137 | # print('Reduction B: ', x.shape) 138 | x = self.avgpool(x) 139 | # print('Avg pool: ', x.shape) 140 | x = self.dropout(x) 141 | x = torch.flatten(x, 1) 142 | # print('Final: ', x.shape) 143 | return x 144 | 145 | class SupIncepResnet(nn.Module): 146 | def __init__(self, num_classes): 147 | super(SupIncepResnet, self).__init__() 148 | self.encoder = 
InceptionResnet() 149 | self.fc = nn.Linear(896, num_classes) 150 | 151 | def forward(self, x): 152 | return self.fc(self.encoder(x)) -------------------------------------------------------------------------------- /preprocessing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Used to convert .csv into tfrecord format 3 | """ 4 | import os 5 | import pandas as pd 6 | import numpy as np 7 | import glob 8 | import swifter 9 | import json 10 | from sklearn.model_selection import train_test_split 11 | import tensorflow as tf 12 | from tqdm import tqdm 13 | import argparse 14 | 15 | attributes = ['Timestamp', 'canID', 'DLC', 16 | 'Data0', 'Data1', 'Data2', 17 | 'Data3', 'Data4', 'Data5', 18 | 'Data6', 'Data7', 'Flag'] 19 | def fill_flag(sample): 20 | if not isinstance(sample['Flag'], str): 21 | col = 'Data' + str(sample['DLC']) 22 | sample['Flag'], sample[col] = sample[col], sample['Flag'] 23 | return sample 24 | 25 | 26 | def serialize_example(x, y): 27 | """converts x, y to tf.train.Example and serialize""" 28 | #Need to pay attention to whether it needs to be converted to numpy() form 29 | id_seq, data_seq, data_histogram = x 30 | id_seq = tf.train.Int64List(value = np.array(id_seq).flatten()) 31 | data_seq = tf.train.Int64List(value = np.array(data_seq).flatten()) 32 | data_histogram = tf.train.Int64List(value = np.array(data_histogram).flatten()) 33 | label = tf.train.Int64List(value = np.array([y])) 34 | features = tf.train.Features( 35 | feature = { 36 | "id_seq": tf.train.Feature(int64_list = id_seq), 37 | "data_seq": tf.train.Feature(int64_list = data_seq), 38 | "data_histogram": tf.train.Feature(int64_list = data_histogram), 39 | "label" : tf.train.Feature(int64_list = label) 40 | } 41 | ) 42 | example = tf.train.Example(features = features) 43 | return example.SerializeToString() 44 | 45 | def write_tfrecord(data, filename): 46 | tfrecord_writer = tf.io.TFRecordWriter(filename) 47 | for _, row in tqdm(data.iterrows()): 48 | X = (row['id_seq'], row['data_seq'], row['data_histogram']) 49 | Y = row['label'] 50 | tfrecord_writer.write(serialize_example(X, Y)) 51 | tfrecord_writer.close() 52 | 53 | def preprocess(file_name, attack_id, window_size = 29, strided_size = 29): 54 | print("Window size = {}, strided = {}".format(window_size, strided_size)) 55 | df = pd.read_csv(file_name, header=None, names=attributes) 56 | print("Reading {}: done".format(file_name)) 57 | df = df.sort_values('Timestamp', ascending=True) 58 | df = df.swifter.apply(fill_flag, axis=1) # Paralellization is faster 59 | # Change data from hex string to int 60 | num_data_bytes = 8 61 | for x in range(num_data_bytes): 62 | df['Data'+str(x)] = df['Data'+str(x)].map(lambda x: int(x, 16), na_action='ignore') 63 | # Change can id from hex string to binary 29-bits length 64 | df['canID'] = df['canID'].apply(int, base=16).apply(bin).str[2:]\ 65 | .apply(lambda x: x.zfill(29)).apply(list)\ 66 | .apply(lambda x: list(map(int, x))) 67 | df = df.fillna(0) 68 | data_cols = ['Data{}'.format(x) for x in range(num_data_bytes)] 69 | df[data_cols] = df[data_cols].astype(int) 70 | df['Data'] = df[data_cols].values.tolist() 71 | df['Flag'] = df['Flag'].apply(lambda x: True if x=='T' else False) 72 | print("Pre-processing: Done") 73 | 74 | as_strided = np.lib.stride_tricks.as_strided 75 | output_shape = ((len(df) - window_size) // strided_size + 1, window_size) 76 | canid = as_strided(df.canID, output_shape, (8*strided_size, 8)) 77 | data = as_strided(df.Data, output_shape, (8*strided_size, 
8)) #Stride is counted by bytes 78 | label = as_strided(df.Flag, output_shape, (1*strided_size, 1)) 79 | 80 | df = pd.DataFrame({ 81 | 'id_seq': pd.Series(canid.tolist()), 82 | 'data_seq': pd.Series(data.tolist()), 83 | 'label': pd.Series(label.tolist()) 84 | }, index= range(len(canid))) 85 | df['data_histogram'] = df['data_seq'].apply(lambda x: np.histogram(np.array(x), bins=256)[0]) 86 | df['label'] = df['label'].apply(lambda x: attack_id if any(x) else 0) 87 | print("Aggregating data: Done") 88 | print('#Normal: ', df[df['label'] == 0].shape[0]) 89 | print('#Attack: ', df[df['label'] != 0].shape[0]) 90 | return df[['id_seq', 'data_seq', 'data_histogram', 'label']].reset_index().drop(['index'], axis=1) 91 | 92 | def main(indir, outdir, attacks, window_size, strided): 93 | print(outdir) 94 | if not os.path.exists(outdir): 95 | os.makedirs(outdir) 96 | data_info = {} 97 | for attack_id, attack in enumerate(attacks): 98 | print('Attack: {} ==============='.format(attack)) 99 | finput = '{}/{}_dataset.csv'.format(indir, attack) 100 | df = preprocess(finput, attack_id + 1, window_size, strided) 101 | print("Writing...................") 102 | foutput_attack = '{}/{}'.format(outdir, attack) 103 | foutput_normal = '{}/Normal_{}'.format(outdir, attack) 104 | df_attack = df[df['label'] != 0] 105 | df_normal = df[df['label'] == 0] 106 | write_tfrecord(df_attack, foutput_attack) 107 | write_tfrecord(df_normal, foutput_normal) 108 | 109 | data_info[foutput_attack] = df_attack.shape[0] 110 | data_info[foutput_normal] = df_normal.shape[0] 111 | 112 | json.dump(data_info, open('{}/datainfo.txt'.format(outdir), 'w')) 113 | print("DONE!") 114 | 115 | if __name__ == '__main__': 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument('--indir', type=str, default="../Data/Car-Hacking") 118 | parser.add_argument('--outdir', type=str, default="../Data/TFRecord") 119 | parser.add_argument('--window_size', type=int, default=None) 120 | parser.add_argument('--strided', type=int, default=None) 121 | parser.add_argument('--attack_type', type=str, default="all", nargs='+') 122 | args = parser.parse_args() 123 | 124 | if args.attack_type == 'all': 125 | attack_types = ['DoS', 'Fuzzy', 'gear', 'RPM'] 126 | else: 127 | attack_types = [args.attack_type] 128 | 129 | if args.strided == None: 130 | args.strided = args.window_size 131 | 132 | outdir = args.outdir + '_w{}_s{}'.format(args.window_size, args.strided) 133 | main(args.indir, outdir, attack_types, args.window_size, args.strided) 134 | -------------------------------------------------------------------------------- /preprocessing_survival.py: -------------------------------------------------------------------------------- 1 | """ 2 | Used to convert .csv into tfrecord format 3 | """ 4 | import os 5 | import pandas as pd 6 | import numpy as np 7 | import glob 8 | import swifter 9 | import json 10 | from sklearn.model_selection import train_test_split 11 | import tensorflow as tf 12 | from tqdm import tqdm 13 | import argparse 14 | 15 | attributes = ['Timestamp', 'canID', 'DLC', 16 | 'Data0', 'Data1', 'Data2', 17 | 'Data3', 'Data4', 'Data5', 18 | 'Data6', 'Data7', 'Flag'] 19 | def fill_flag(sample): 20 | if not isinstance(sample['Flag'], str): 21 | col = 'Data' + str(sample['DLC']) 22 | sample['Flag'], sample[col] = sample[col], sample['Flag'] 23 | return sample 24 | 25 | 26 | def serialize_example(x, y): 27 | """converts x, y to tf.train.Example and serialize""" 28 | #Need to pay attention to whether it needs to be converted to numpy() form 29 | id_seq, 
data_seq, data_histogram = x 30 | id_seq = tf.train.Int64List(value = np.array(id_seq).flatten()) 31 | data_seq = tf.train.Int64List(value = np.array(data_seq).flatten()) 32 | data_histogram = tf.train.Int64List(value = np.array(data_histogram).flatten()) 33 | label = tf.train.Int64List(value = np.array([y])) 34 | features = tf.train.Features( 35 | feature = { 36 | "id_seq": tf.train.Feature(int64_list = id_seq), 37 | "data_seq": tf.train.Feature(int64_list = data_seq), 38 | "data_histogram": tf.train.Feature(int64_list = data_histogram), 39 | "label" : tf.train.Feature(int64_list = label) 40 | } 41 | ) 42 | example = tf.train.Example(features = features) 43 | return example.SerializeToString() 44 | 45 | def write_tfrecord(data, filename): 46 | tfrecord_writer = tf.io.TFRecordWriter(filename) 47 | for _, row in tqdm(data.iterrows()): 48 | X = (row['id_seq'], row['data_seq'], row['data_histogram']) 49 | Y = row['label'] 50 | tfrecord_writer.write(serialize_example(X, Y)) 51 | tfrecord_writer.close() 52 | 53 | def preprocess(file_name, attack_id, window_size = 29, strided_size = 29): 54 | print("Window size = {}, strided = {}".format(window_size, strided_size)) 55 | df = pd.read_csv(file_name, header=None, names=attributes) 56 | print("Reading {}: done".format(file_name)) 57 | df = df.sort_values('Timestamp', ascending=True) 58 | df = df.swifter.apply(fill_flag, axis=1) # Paralellization is faster 59 | # Change data from hex string to int 60 | num_data_bytes = 8 61 | for x in range(num_data_bytes): 62 | df['Data'+str(x)] = df['Data'+str(x)].map(lambda x: int(x, 16), na_action='ignore') 63 | # Change can id from hex string to binary 29-bits length 64 | df['canID'] = df['canID'].apply(int, base=16).apply(bin).str[2:]\ 65 | .apply(lambda x: x.zfill(29)).apply(list)\ 66 | .apply(lambda x: list(map(int, x))) 67 | df = df.fillna(0) 68 | data_cols = ['Data{}'.format(x) for x in range(num_data_bytes)] 69 | df[data_cols] = df[data_cols].astype(int) 70 | df['Data'] = df[data_cols].values.tolist() 71 | df['Flag'] = df['Flag'].apply(lambda x: True if x=='T' else False) 72 | print("Pre-processing: Done") 73 | 74 | as_strided = np.lib.stride_tricks.as_strided 75 | output_shape = ((len(df) - window_size) // strided_size + 1, window_size) 76 | canid = as_strided(df.canID, output_shape, (8*strided_size, 8)) 77 | data = as_strided(df.Data, output_shape, (8*strided_size, 8)) #Stride is counted by bytes 78 | label = as_strided(df.Flag, output_shape, (1*strided_size, 1)) 79 | 80 | df = pd.DataFrame({ 81 | 'id_seq': pd.Series(canid.tolist()), 82 | 'data_seq': pd.Series(data.tolist()), 83 | 'label': pd.Series(label.tolist()) 84 | }, index= range(len(canid))) 85 | df['data_histogram'] = df['data_seq'].apply(lambda x: np.histogram(np.array(x), bins=256)[0]) 86 | df['label'] = df['label'].apply(lambda x: attack_id if any(x) else 0) 87 | print("Aggregating data: Done") 88 | print('#Normal: ', df[df['label'] == 0].shape[0]) 89 | print('#Attack: ', df[df['label'] != 0].shape[0]) 90 | return df[['id_seq', 'data_seq', 'data_histogram', 'label']].reset_index().drop(['index'], axis=1) 91 | 92 | def main(indir, outdir, car_model, attacks, window_size, strided): 93 | print(outdir) 94 | if not os.path.exists(outdir): 95 | os.makedirs(outdir) 96 | data_info = {} 97 | for attack_id, attack in enumerate(attacks): 98 | print('Attack: {} ==============='.format(attack)) 99 | finput = '{}/{}_dataset_{}.txt'.format(indir, attack, car_model) 100 | df = preprocess(finput, attack_id + 1, window_size, strided) 101 | 
print("Writing...................") 102 | foutput_attack = '{}/{}'.format(outdir, attack) 103 | foutput_normal = '{}/Normal_{}'.format(outdir, attack) 104 | df_attack = df[df['label'] != 0] 105 | df_normal = df[df['label'] == 0] 106 | write_tfrecord(df_attack, foutput_attack) 107 | write_tfrecord(df_normal, foutput_normal) 108 | 109 | data_info[foutput_attack] = df_attack.shape[0] 110 | data_info[foutput_normal] = df_normal.shape[0] 111 | 112 | json.dump(data_info, open('{}/datainfo.txt'.format(outdir), 'w')) 113 | print("DONE!") 114 | 115 | if __name__ == '__main__': 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument('--indir', type=str, default="../Data/Car-Hacking") 118 | parser.add_argument('--outdir', type=str, default="../Data/TFRecord") 119 | parser.add_argument('--car_model', type=str) 120 | parser.add_argument('--window_size', type=int, default=None) 121 | parser.add_argument('--strided', type=int, default=None) 122 | parser.add_argument('--attack_type', type=str, default="all", nargs='+') 123 | args = parser.parse_args() 124 | 125 | if args.attack_type == 'all': 126 | attack_types = ['Flooding', 'Fuzzy', 'Malfunction'] 127 | # attack_types = ['DoS', 'Fuzzy', 'gear', 'RPM'] 128 | else: 129 | attack_types = [args.attack_type] 130 | 131 | if args.strided == None: 132 | args.strided = args.window_size 133 | 134 | indir = os.path.join(args.indir, args.car_model) 135 | outdir = args.outdir + 'TFrecord_{}_w{}_s{}'.format(args.car_model, args.window_size, args.strided) 136 | main(indir, outdir, args.car_model, attack_types, args.window_size, args.strided) 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /networks/resnet_big.py: -------------------------------------------------------------------------------- 1 | """ResNet in PyTorch. 2 | ImageNet-Style ResNet 3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 4 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 5 | Adapted from: https://github.com/bearpaw/pytorch-classification 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, in_planes, planes, stride=1, is_last=False): 16 | super(BasicBlock, self).__init__() 17 | self.is_last = is_last 18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.shortcut = nn.Sequential() 24 | if stride != 1 or in_planes != self.expansion * planes: 25 | self.shortcut = nn.Sequential( 26 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 27 | nn.BatchNorm2d(self.expansion * planes) 28 | ) 29 | 30 | def forward(self, x): 31 | out = F.relu(self.bn1(self.conv1(x))) 32 | out = self.bn2(self.conv2(out)) 33 | out += self.shortcut(x) 34 | preact = out 35 | out = F.relu(out) 36 | if self.is_last: 37 | return out, preact 38 | else: 39 | return out 40 | 41 | 42 | class Bottleneck(nn.Module): 43 | expansion = 4 44 | 45 | def __init__(self, in_planes, planes, stride=1, is_last=False): 46 | super(Bottleneck, self).__init__() 47 | self.is_last = is_last 48 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 49 | self.bn1 = nn.BatchNorm2d(planes) 50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 53 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 54 | 55 | self.shortcut = nn.Sequential() 56 | if stride != 1 or in_planes != self.expansion * planes: 57 | self.shortcut = nn.Sequential( 58 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 59 | nn.BatchNorm2d(self.expansion * planes) 60 | ) 61 | 62 | def forward(self, x): 63 | out = F.relu(self.bn1(self.conv1(x))) 64 | out = F.relu(self.bn2(self.conv2(out))) 65 | out = self.bn3(self.conv3(out)) 66 | out += self.shortcut(x) 67 | preact = out 68 | out = F.relu(out) 69 | if self.is_last: 70 | return out, preact 71 | else: 72 | return out 73 | 74 | 75 | class ResNet(nn.Module): 76 | def __init__(self, block, num_blocks, in_channel=1, zero_init_residual=False): 77 | super(ResNet, self).__init__() 78 | self.in_planes = 16 79 | 80 | self.conv1 = nn.Conv2d(in_channel, 16, kernel_size=3, stride=1, padding=1, 81 | bias=False) 82 | self.bn1 = nn.BatchNorm2d(16) 83 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 84 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 85 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 86 | self.layer4 = self._make_layer(block, 128, num_blocks[3], stride=2) 87 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 88 | 89 | for m in self.modules(): 90 | if isinstance(m, nn.Conv2d): 91 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 92 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 93 | nn.init.constant_(m.weight, 1) 94 | nn.init.constant_(m.bias, 0) 95 | 96 | # Zero-initialize the last BN in each residual branch, 97 | # so that the residual branch starts with zeros, and each residual block behaves 98 | # like an identity. 
This improves the model by 0.2~0.3% according to: 99 | # https://arxiv.org/abs/1706.02677 100 | if zero_init_residual: 101 | for m in self.modules(): 102 | if isinstance(m, Bottleneck): 103 | nn.init.constant_(m.bn3.weight, 0) 104 | elif isinstance(m, BasicBlock): 105 | nn.init.constant_(m.bn2.weight, 0) 106 | 107 | def _make_layer(self, block, planes, num_blocks, stride): 108 | strides = [stride] + [1] * (num_blocks - 1) 109 | layers = [] 110 | for i in range(num_blocks): 111 | stride = strides[i] 112 | layers.append(block(self.in_planes, planes, stride)) 113 | self.in_planes = planes * block.expansion 114 | return nn.Sequential(*layers) 115 | 116 | def forward(self, x, layer=100): 117 | out = F.relu(self.bn1(self.conv1(x))) 118 | out = self.layer1(out) 119 | out = self.layer2(out) 120 | out = self.layer3(out) 121 | out = self.layer4(out) 122 | out = self.avgpool(out) 123 | out = torch.flatten(out, 1) 124 | return out 125 | 126 | 127 | def resnet18(**kwargs): 128 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 129 | 130 | 131 | def resnet34(**kwargs): 132 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 133 | 134 | 135 | def resnet50(**kwargs): 136 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 137 | 138 | 139 | def resnet101(**kwargs): 140 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 141 | 142 | 143 | model_dict = { 144 | 'resnet18': [resnet18, 128], 145 | 'resnet34': [resnet34, 512], 146 | 'resnet50': [resnet50, 2048], 147 | 'resnet101': [resnet101, 2048], 148 | } 149 | 150 | 151 | class LinearBatchNorm(nn.Module): 152 | """Implements BatchNorm1d by BatchNorm2d, for SyncBN purpose""" 153 | def __init__(self, dim, affine=True): 154 | super(LinearBatchNorm, self).__init__() 155 | self.dim = dim 156 | self.bn = nn.BatchNorm2d(dim, affine=affine) 157 | 158 | def forward(self, x): 159 | x = x.view(-1, self.dim, 1, 1) 160 | x = self.bn(x) 161 | x = x.view(-1, self.dim) 162 | return x 163 | 164 | 165 | class SupConResNet(nn.Module): 166 | """backbone + projection head""" 167 | def __init__(self, name='resnet18', head='mlp', feat_dim=128): 168 | super(SupConResNet, self).__init__() 169 | model_fun, dim_in = model_dict[name] 170 | self.encoder = model_fun() 171 | if head == 'linear': 172 | self.head = nn.Linear(dim_in, feat_dim) 173 | elif head == 'mlp': 174 | self.head = nn.Sequential( 175 | nn.Linear(dim_in, dim_in), 176 | nn.ReLU(inplace=True), 177 | nn.Linear(dim_in, feat_dim) 178 | ) 179 | else: 180 | raise NotImplementedError( 181 | 'head not supported: {}'.format(head)) 182 | 183 | def forward(self, x): 184 | feat = self.encoder(x) 185 | feat = F.normalize(self.head(feat), dim=1) 186 | return feat 187 | 188 | 189 | class SupCEResNet(nn.Module): 190 | """encoder + classifier""" 191 | def __init__(self, name='resnet18', num_classes=10): 192 | super(SupCEResNet, self).__init__() 193 | model_fun, dim_in = model_dict[name] 194 | self.encoder = model_fun() 195 | self.fc = nn.Linear(dim_in, num_classes) 196 | 197 | def forward(self, x): 198 | return self.fc(self.encoder(x)) 199 | 200 | 201 | class LinearClassifier(nn.Module): 202 | """Linear classifier""" 203 | def __init__(self, name='resnet18', num_classes=10): 204 | super(LinearClassifier, self).__init__() 205 | _, feat_dim = model_dict[name] 206 | self.fc = nn.Linear(feat_dim, num_classes) 207 | 208 | def forward(self, features): 209 | return self.fc(features) 210 | -------------------------------------------------------------------------------- /train_baseline.py: 
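The backbones defined in `networks/resnet_big.py` above take single-channel 29x29 frames, and `model_dict` maps each ResNet variant to its feature dimension (128 for resnet18). A quick shape check, as an illustrative sketch rather than repo code:

```
import torch
from networks.resnet_big import SupConResNet, SupCEResNet

x = torch.randn(4, 1, 29, 29)                            # batch of 29x29 CAN-ID frames
feats = SupConResNet(name='resnet18')(x)                 # [4, 128] L2-normalized projections
logits = SupCEResNet(name='resnet18', num_classes=5)(x)  # [4, 5] class scores
```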
-------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import math 4 | import numpy as np 5 | from tqdm import tqdm 6 | from datetime import datetime 7 | 8 | from dataset import CANDataset 9 | from utils import get_prediction, cal_metric, print_results 10 | from networks.inception import SupIncepResnet 11 | from networks.simple_cnn import BaselineCNNClassifier 12 | 13 | from networks.resnet_big import SupCEResNet 14 | from supcon.util import set_optimizer, save_model 15 | from supcon.util import AverageMeter 16 | from supcon.util import adjust_learning_rate, warmup_learning_rate, accuracy 17 | 18 | import torch 19 | from torch.utils.tensorboard import SummaryWriter 20 | import torch.optim as optim 21 | import torch.backends.cudnn as cudnn 22 | from sklearn.metrics import f1_score 23 | 24 | #from focalloss import FocalLoss 25 | 26 | NUM_CLASSES = 5 27 | MODELS = { 28 | 'inception': SupIncepResnet, 29 | 'cnn': BaselineCNNClassifier, 30 | 'resnet18': SupCEResNet 31 | } 32 | 33 | def parse_option(): 34 | parser = argparse.ArgumentParser('argument for training') 35 | parser.add_argument('--data_dir', type=str, help='directory of data for training') 36 | parser.add_argument('--model', type=str, help='choosing models in [inception, cnn]') 37 | parser.add_argument('--print_freq', type=int, default=10) 38 | parser.add_argument('--save_freq', type=int, default=1) 39 | parser.add_argument('--window_size', type=int) 40 | parser.add_argument('--batch_size', type=int) 41 | parser.add_argument('--epochs', type=int) 42 | parser.add_argument('--num_workers', type=int, default=0) 43 | parser.add_argument('--gpu_device', type=int, default=0) 44 | parser.add_argument('--rid', type=int, default=1) 45 | 46 | # optimization 47 | parser.add_argument('--gamma', type=int, default=0, 48 | help='gamma hyperparameter for focal loss') 49 | parser.add_argument('--learning_rate', type=float, default=0.2, 50 | help='learning rate') 51 | parser.add_argument('--lr_decay_epochs', type=str, default='350,400,450', 52 | help='where to decay lr, can be a list') 53 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, 54 | help='decay rate for learning rate') 55 | parser.add_argument('--weight_decay', type=float, default=1e-4, 56 | help='weight decay') 57 | parser.add_argument('--momentum', type=float, default=0.9, 58 | help='momentum') 59 | parser.add_argument('--cosine', action='store_true', 60 | help='using cosine annealing') 61 | 62 | opt = parser.parse_args() 63 | 64 | 65 | if torch.cuda.is_available(): 66 | torch.cuda.set_device(opt.gpu_device) 67 | 68 | opt.warm = False 69 | if opt.batch_size > 256: 70 | opt.warm = True 71 | if opt.warm: 72 | opt.warmup_from = 0.01 73 | opt.warm_epochs = 10 74 | if opt.cosine: 75 | eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3) 76 | opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * ( 77 | 1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2 78 | else: 79 | opt.warmup_to = opt.learning_rate 80 | 81 | 82 | opt.model_path = './save/{}/models/' 83 | opt.tb_path = './save/{}/runs/' 84 | current_time = datetime.now().strftime("%D_%H%M%S").replace('/', '') 85 | opt.model_name = f'{opt.model}{opt.rid}_gamma{opt.gamma}_lr{opt.learning_rate}_bs{opt.batch_size}_{opt.epochs}epochs_{current_time}' 86 | if opt.cosine: 87 | opt.model_name = '{}_cosine'.format(opt.model_name) 88 | if opt.warm: 89 | opt.model_name = '{}_warm'.format(opt.model_name) 90 | 91 | opt.tb_folder = opt.tb_path.format(opt.model_name) 92 
| if not os.path.isdir(opt.tb_folder): 93 | os.makedirs(opt.tb_folder, exist_ok=True) 94 | 95 | opt.save_folder = opt.model_path.format(opt.model_name) 96 | if not os.path.isdir(opt.save_folder): 97 | os.makedirs(opt.save_folder, exist_ok=True) 98 | 99 | opt.log_file = f'./save/{opt.model_name}/log' 100 | return opt 101 | 102 | 103 | def set_loader(opt): 104 | data_dir = f'{opt.data_dir}/{opt.rid}/' 105 | train_dataset = CANDataset(root_dir=data_dir, 106 | window_size=opt.window_size) 107 | val_dataset = CANDataset(root_dir=data_dir, 108 | window_size=opt.window_size, 109 | is_train=False) 110 | #train_dataset.total_size = 100000 111 | #val_dataset.total_size = 10000 112 | print('Train size: ', len(train_dataset)) 113 | print('Val size: ', len(val_dataset)) 114 | train_loader = torch.utils.data.DataLoader( 115 | train_dataset, batch_size=opt.batch_size, 116 | shuffle=True, num_workers=opt.num_workers, 117 | pin_memory=True, sampler=None) 118 | val_loader = torch.utils.data.DataLoader( 119 | val_dataset, batch_size=opt.batch_size, shuffle=False, 120 | num_workers=8, pin_memory=True, sampler=None) 121 | 122 | return train_loader, val_loader 123 | 124 | def set_model(opt): 125 | model = MODELS[opt.model] 126 | # model = SupCEResNet 127 | model = model(num_classes=NUM_CLASSES) 128 | #class_weights = [0.25, 1.0, 1.0, 1.0, 1.0] 129 | #criterion = torch.nn.CrossEntropyLoss(weight=torch.FloatTensor(class_weights).cuda()) 130 | criterion = torch.nn.CrossEntropyLoss() 131 | #criterion = FocalLoss(gamma=opt.gamma) 132 | if torch.cuda.is_available(): 133 | model = model.cuda() 134 | criterion = criterion.cuda() 135 | # Incerease runtime performance 136 | cudnn.benchmark = True 137 | return model, criterion 138 | 139 | def train(train_loader, model, criterion, optimizer, epoch, opt, logger, step): 140 | model.train() 141 | 142 | losses = AverageMeter() 143 | accs = AverageMeter() 144 | 145 | for idx, (images, labels) in tqdm(enumerate(train_loader)): 146 | step += 1 147 | images = images.cuda(non_blocking=True) 148 | labels = labels.cuda(non_blocking=True) 149 | bsz = labels.shape[0] 150 | 151 | warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer) 152 | 153 | output = model(images) 154 | loss = criterion(output, labels) 155 | losses.update(loss.item(), bsz) 156 | acc1 = accuracy(output, labels, topk=(1, )) 157 | accs.update(acc1[0], bsz) 158 | 159 | if step % opt.print_freq == 0: 160 | logger.add_scalar('loss/train', losses.avg, step) 161 | 162 | optimizer.zero_grad() 163 | loss.backward() 164 | optimizer.step() 165 | 166 | return step, losses.avg, accs.avg 167 | 168 | 169 | def get_predict(outputs): 170 | _, pred = outputs.topk(1, 1, True, True) 171 | pred = pred.t().cpu().numpy().squeeze(0) 172 | return pred 173 | 174 | def validate(val_loader, model, criterion, opt): 175 | model.eval() 176 | 177 | losses = AverageMeter() 178 | total_pred = np.array([], dtype=int) 179 | total_label = np.array([], dtype=int) 180 | 181 | with torch.no_grad(): 182 | for images, labels in tqdm(val_loader): 183 | images = images.cuda(non_blocking=True) 184 | outputs = model(images) 185 | bsz = labels.shape[0] 186 | loss = criterion(outputs, labels.cuda()) 187 | losses.update(loss.item(), bsz) 188 | 189 | pred = get_predict(outputs) 190 | total_pred = np.concatenate((total_pred, pred), axis=0) 191 | total_label = np.concatenate((total_label, labels), axis=0) 192 | 193 | 194 | f1 = f1_score(total_pred, total_label, average='weighted') 195 | return losses.avg, f1 196 | 197 | # python3 train_baseline.py 
--data_dir ../Data/TFRecord_w29_s15/1 --batch_size 1024 --window_size 29 --cosine --print_freq 100 --save_freq 5 --gpu_device 0 --model cnn --num_workers 8 --learning_rate 0.01 --epochs 100 198 | def main(): 199 | opt = parse_option() 200 | 201 | train_loader, val_loader = set_loader(opt) 202 | model, criterion = set_model(opt) 203 | optimizer = optim.SGD(model.parameters(), 204 | lr=opt.learning_rate, 205 | momentum=opt.momentum, 206 | weight_decay=opt.weight_decay) 207 | 208 | logger = SummaryWriter(log_dir=opt.tb_folder, flush_secs=2) 209 | log_writer = open(opt.log_file, 'w') 210 | step = 0 211 | for epoch in range(1, opt.epochs + 1): 212 | adjust_learning_rate(opt, optimizer, epoch) 213 | 214 | step, loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, opt, logger, step) 215 | print('Epoch: {}, Loss: {}, Acc: {}'.format(epoch, loss, train_acc)) 216 | log_writer.write('Epoch: {}, Loss: {}, Acc: {}\n'.format(epoch, loss, train_acc)) 217 | 218 | if epoch % 5 == 0: 219 | loss, val_f1 = validate(val_loader, model, criterion, opt) 220 | logger.add_scalar('loss/val', loss, step) 221 | print('Validation: Loss: {}, F1: {}'.format(loss, val_f1)) 222 | log_writer.write('Validation: Loss: {}, F1: {}\n'.format(loss, val_f1)) 223 | 224 | if epoch % opt.save_freq == 0: 225 | ckpt = 'ckpt_epoch_{}.pth'.format(epoch) 226 | save_file = os.path.join(opt.save_folder, ckpt) 227 | save_model(model, optimizer, opt, epoch, save_file) 228 | 229 | save_file = os.path.join(opt.save_folder, 'last.pth') 230 | save_model(model, optimizer, opt, opt.epochs, save_file) 231 | log_writer.close() 232 | 233 | if __name__ == '__main__': 234 | main() 235 | -------------------------------------------------------------------------------- /notebooks/Test_multiple.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "b1e9fa50-d96c-4268-b4d6-2b78dc331a5b", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2\n", 12 | "\n", 13 | "import sys\n", 14 | "sys.path.append('../')" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "id": "8ec78ee2-1909-479e-81f4-3f781f10c2f8", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "from test_model import test\n", 25 | "from argparse import Namespace\n", 26 | "import glob\n", 27 | "import numpy as np\n", 28 | "from concurrent import futures\n", 29 | "import multiprocessing\n", 30 | "import copy\n", 31 | "import json" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 3, 37 | "id": "02f827e4-a3a4-4762-933b-f952a132df28", 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "args = {\n", 42 | " 'data_path': '../../Data/',\n", 43 | " 'pretrained_model': 'incep', #supcon', #'resnet',\n", 44 | " 'pretrained_path': '',\n", 45 | " 'window_size': 29,\n", 46 | " 'strided': 15,\n", 47 | " 'batch_size': 512,\n", 48 | " 'num_workers': 8,\n", 49 | " 'trial_id': 1,\n", 50 | " 'ckpt' : 50,\n", 51 | " 'car_model': None\n", 52 | "}\n", 53 | "\n", 54 | "args = Namespace(**args)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 4, 60 | "id": "942c1c1d-1e41-4f0c-ab9a-9f52ddfad2f0", 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "data": { 65 | "text/plain": [ 66 | "['../save/inception1_gamma0_lr0.001_bs256_100epochs_092722_095310',\n", 67 | " '../save/inception2_gamma0_lr0.001_bs256_100epochs_092722_143206',\n", 68 | " 
'../save/inception3_gamma0_lr0.001_bs256_100epochs_092722_143213',\n", 69 | " '../save/inception4_gamma0_lr0.001_bs256_100epochs_092722_143222',\n", 70 | " '../save/inception5_gamma0_lr0.001_bs256_100epochs_092722_143237']" 71 | ] 72 | }, 73 | "execution_count": 4, 74 | "metadata": {}, 75 | "output_type": "execute_result" 76 | } 77 | ], 78 | "source": [ 79 | "bs = 512\n", 80 | "pretrained_paths = sorted(glob.glob(f'../save/inception*'))\n", 81 | "# pretrained_paths = sorted(glob.glob(f'../save/SupCon.resnet18*_bs{bs}*'))\n", 82 | "# pretrained_paths = sorted(glob.glob('../save/smallresnet18.ce?_*_lr0.001_*'))\n", 83 | "\n", 84 | "pretrained_paths" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 5, 90 | "id": "2d414d56-4ccf-464d-b226-46bd210a2890", 91 | "metadata": { 92 | "tags": [] 93 | }, 94 | "outputs": [ 95 | { 96 | "name": "stdout", 97 | "output_type": "stream", 98 | "text": [ 99 | "Loading: SupIncepResnet\n", 100 | "Loading: SupIncepResnetLoading: \n", 101 | "Loading: SupIncepResnet \n", 102 | "SupIncepResnet\n", 103 | "Loading: SupIncepResnet\n" 104 | ] 105 | }, 106 | { 107 | "data": { 108 | "application/vnd.jupyter.widget-view+json": { 109 | "model_id": "0418bf45e0c44e1c9d38aa3b05fde324", 110 | "version_major": 2, 111 | "version_minor": 0 112 | }, 113 | "text/plain": [ 114 | " 0%| | 0/648 [00:00 256: 80 | opt.warm = True 81 | if opt.warm: 82 | opt.warmup_from = 0.01 83 | opt.warm_epochs = 10 84 | if opt.cosine: 85 | eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3) 86 | opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * ( 87 | 1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2 88 | else: 89 | opt.warmup_to = opt.learning_rate 90 | 91 | 92 | opt.model_path = './save/{}/models/' 93 | opt.tb_path = './save/{}/runs/' 94 | current_time = datetime.now().strftime("%D_%H%M%S").replace('/', '') 95 | opt.model_name = f'SupCon_{opt.model}{opt.rid}_lr{opt.learning_rate}_{opt.learning_rate_classifier}_bs{opt.batch_size}_{opt.epochs}epoch_temp{opt.temp}_{current_time}' 96 | if opt.cosine: 97 | opt.model_name = '{}_cosine'.format(opt.model_name) 98 | if opt.warm: 99 | opt.model_name = '{}_warm'.format(opt.model_name) 100 | 101 | opt.tb_folder = opt.tb_path.format(opt.model_name) 102 | if not os.path.isdir(opt.tb_folder): 103 | os.makedirs(opt.tb_folder, exist_ok=True) 104 | 105 | opt.save_folder = opt.model_path.format(opt.model_name) 106 | if not os.path.isdir(opt.save_folder): 107 | os.makedirs(opt.save_folder, exist_ok=True) 108 | 109 | 110 | opt.log_file = f'./save/{opt.model_name}/log' 111 | 112 | return opt 113 | 114 | 115 | def set_loader(opt): 116 | data_dir = f'{opt.data_dir}/{opt.rid}/' 117 | train_dataset = CANDataset(root_dir=data_dir, 118 | window_size=opt.window_size) 119 | val_dataset = CANDataset(root_dir=data_dir, 120 | window_size=opt.window_size, 121 | is_train=False) 122 | #train_dataset.total_size = 100000 123 | #val_dataset.total_size = 10000 124 | #print('Train size: ', len(train_dataset)) 125 | #print('Val size: ', len(val_dataset)) 126 | train_loader = torch.utils.data.DataLoader( 127 | train_dataset, batch_size=opt.batch_size, 128 | shuffle=True, num_workers=opt.num_workers, 129 | pin_memory=True, sampler=None) 130 | 131 | train_classifier_loader = torch.utils.data.DataLoader( 132 | train_dataset, batch_size=256, 133 | shuffle=True, num_workers=opt.num_workers, 134 | pin_memory=True, sampler=None) 135 | 136 | val_loader = torch.utils.data.DataLoader( 137 | val_dataset, batch_size=1024, shuffle=False, 138 | num_workers=2, 
pin_memory=True) 139 | 140 | return train_loader, train_classifier_loader, val_loader 141 | 142 | def set_model(opt): 143 | #model = SupConCNN(feat_dim=128) 144 | model = SupConResNet('resnet18') 145 | criterion_model = SupConLoss(temperature=opt.temp, contrast_mode='one') 146 | classifier = LinearClassifier(n_classes=NUM_CLASSES, feat_dim=128) 147 | #class_weights = [0.25, 1.0, 1.0, 1.0, 1.0] 148 | #criterion_classifier = FocalLoss(gamma=0.0) #torch.nn.CrossEntropyLoss() 149 | criterion_classifier = torch.nn.CrossEntropyLoss() 150 | 151 | if torch.cuda.is_available(): 152 | if torch.cuda.device_count() > 1: 153 | model.encoder = torch.nn.DataParallel(model.encoder) 154 | model = model.cuda() 155 | criterion_model = criterion_model.cuda() 156 | classifier = classifier.cuda() 157 | criterion_classifier = criterion_classifier.cuda() 158 | # Incerease runtime performance 159 | cudnn.benchmark = True 160 | print('Model device: ', next(model.parameters()).device) 161 | return model, criterion_model, classifier, criterion_classifier 162 | 163 | def train_model(train_loader, model, criterion, optimizer, epoch, opt, logger, step): 164 | model.train() 165 | 166 | losses = AverageMeter() 167 | 168 | for idx, (images, labels) in tqdm(enumerate(train_loader)): 169 | step += 1 170 | if torch.cuda.is_available(): 171 | images = images.cuda(non_blocking=True) 172 | labels = labels.cuda(non_blocking=True) 173 | bsz = labels.shape[0] 174 | 175 | warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer) 176 | 177 | print('Data device: ', images.device) 178 | features = model(images) 179 | features = features.unsqueeze(1) 180 | loss = criterion(features, labels) 181 | 182 | losses.update(loss.item(), bsz) 183 | 184 | if step % opt.print_freq == 0: 185 | logger.add_scalar('loss_supcon', losses.avg, step) 186 | 187 | optimizer.zero_grad() 188 | loss.backward() 189 | optimizer.step() 190 | 191 | return step, losses.avg 192 | 193 | 194 | def train_classifier(train_loader, model, classifier, criterion, optimizer, epoch, opt, step, logger): 195 | model.eval() 196 | classifier.train() 197 | 198 | losses = AverageMeter() 199 | accs = AverageMeter() 200 | 201 | for idx, (images, labels) in tqdm(enumerate(train_loader)): 202 | step += 1 203 | if torch.cuda.is_available(): 204 | images = images.cuda(non_blocking=True) 205 | labels = labels.cuda(non_blocking=True) 206 | bsz = labels.shape[0] 207 | 208 | #warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer) 209 | 210 | with torch.no_grad(): 211 | features = model.encoder(images) 212 | 213 | output = classifier(features.detach()) 214 | loss = criterion(output, labels) 215 | losses.update(loss.item(), bsz) 216 | acc = accuracy(output, labels, topk=(1, )) 217 | accs.update(acc[0], bsz) 218 | 219 | if step % opt.print_freq == 0: 220 | logger.add_scalar('loss_ce/train', losses.avg, step) 221 | 222 | optimizer.zero_grad() 223 | loss.backward() 224 | optimizer.step() 225 | 226 | return step, losses.avg, accs.avg 227 | 228 | 229 | def get_predict(outputs): 230 | _, pred = outputs.topk(1, 1, True, True) 231 | pred = pred.t().cpu().numpy().squeeze(0) 232 | return pred 233 | 234 | def validate(val_loader, model, classifier, criterion, opt): 235 | model.eval() 236 | classifier.eval() 237 | 238 | losses = AverageMeter() 239 | total_pred = np.array([], dtype=int) 240 | total_label = np.array([], dtype=int) 241 | 242 | with torch.no_grad(): 243 | for images, labels in tqdm(val_loader): 244 | images = images.cuda(non_blocking=True) 245 | outputs = 
classifier(model.encoder(images)) 246 | bsz = labels.shape[0] 247 | loss = criterion(outputs, labels.cuda()) 248 | losses.update(loss.item(), bsz) 249 | 250 | pred = get_predict(outputs) 251 | total_pred = np.concatenate((total_pred, pred), axis=0) 252 | total_label = np.concatenate((total_label, labels), axis=0) 253 | 254 | 255 | f1 = f1_score(total_pred, total_label, average='weighted') 256 | return losses.avg, f1 257 | 258 | optimize_dict = { 259 | 'SGD' : optim.SGD, 260 | 'RMSprop': optim.RMSprop, 261 | 'Adam': optim.Adam 262 | } 263 | 264 | def set_optimizer(opt, model, class_str='', optim_choice='SGD'): 265 | dict_opt = vars(opt) 266 | optimizer = optimize_dict[optim_choice] 267 | if optim_choice == 'Adam': 268 | optimizer = optimizer(model.parameters(), 269 | lr=dict_opt['learning_rate'+class_str], 270 | weight_decay=dict_opt['weight_decay'+class_str]) 271 | else: 272 | optimizer = optimizer(model.parameters(), 273 | lr=dict_opt['learning_rate'+class_str], 274 | momentum=dict_opt['momentum'+class_str], 275 | weight_decay=dict_opt['weight_decay'+class_str]) 276 | return optimizer 277 | 278 | def adjust_learning_rate(args, optimizer, epoch, class_str=''): 279 | dict_args = vars(args) 280 | lr = dict_args['learning_rate'+class_str] 281 | if args.cosine: 282 | eta_min = lr * (dict_args['lr_decay_rate'+class_str] ** 3) 283 | lr = eta_min + (lr - eta_min) * ( 284 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2 285 | else: 286 | lr_decay_epochs_arr = list(map(int, dict_args['lr_decay_epochs'+class_str].split(','))) 287 | lr_decay_epochs_arr = np.asarray(lr_decay_epochs_arr) 288 | steps = np.sum(epoch > lr_decay_epochs_arr) 289 | if steps > 0: 290 | lr = lr * (dict_args['lr_decay_rate'+class_str] ** steps) 291 | 292 | for param_group in optimizer.param_groups: 293 | param_group['lr'] = lr 294 | 295 | # python3 train_supcon.py --data_dir ../Data/TFRecord_w29_s15/1/ --batch_size 4096 --window_size 29 --cosine --print_freq 100 --mode cnn --gpu_device 0 --learning_rate 0.1 --learning_rate_classifier 0.5 --num_workers 4 --epochs 200 --save_freq 5 296 | 297 | def main(): 298 | opt = parse_option() 299 | 300 | train_loader, train_classifier_loader, val_loader = set_loader(opt) 301 | model, criterion_model, classifier, criterion_classifier = set_model(opt) 302 | 303 | optimizer_model = set_optimizer(opt, model, optim_choice='SGD') 304 | optimizer_classifier = set_optimizer(opt, classifier, class_str='_classifier', optim_choice='SGD') 305 | 306 | logger = SummaryWriter(log_dir=opt.tb_folder, flush_secs=2) 307 | #train_classifier_freq = 2 308 | step = 0 309 | 310 | log_writer = open(opt.log_file, 'w') 311 | for epoch in range(1, opt.epochs + 1): 312 | adjust_learning_rate(opt, optimizer_model, epoch) 313 | 314 | new_step, loss = train_model(train_loader, model, criterion_model, optimizer_model, epoch, opt, logger, step) 315 | print('Epoch: {}, SupCon Loss: {:.4f}'.format(epoch, loss)) 316 | log_writer.write('Epoch: {}, SupCon Loss: {:.4f}\n'.format(epoch, loss)) 317 | # Train and validate classifier 318 | class_epoch = epoch - opt.epoch_start_classifier + 1 319 | if class_epoch > 0: 320 | adjust_learning_rate(opt, optimizer_classifier, class_epoch, '_classifier') 321 | new_step, loss_ce, train_acc = train_classifier(train_classifier_loader, model, classifier, 322 | criterion_classifier, optimizer_classifier, epoch, opt, step, logger) 323 | print('Classifier: Loss: {:.4f}, Acc: {}'.format(loss_ce, train_acc)) 324 | log_writer.write('Classifier: Loss: {:.4f}, Acc: {}\n'.format(loss_ce, train_acc)) 
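            # At this point in each epoch: the SupCon encoder has been updated with the
            # contrastive loss, and the linear classifier has been trained with
            # cross-entropy on detached (frozen) encoder features; validate() below
            # scores classifier(model.encoder(x)) and reports CE loss and weighted F1.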
325 | loss, val_f1 = validate(val_loader, model, classifier, criterion_classifier, opt) 326 | logger.add_scalar('loss_ce/val', loss, step) 327 | print('Validation: Loss: {:.6f}, F1: {:.8f}'.format(loss, val_f1)) 328 | log_writer.write('Validation: Loss: {:.6f}, F1: {:.8f}\n'.format(loss, val_f1)) 329 | 330 | step = new_step 331 | if epoch % opt.save_freq == 0: 332 | ckpt = 'ckpt_epoch_{}.pth'.format(epoch) 333 | save_file = os.path.join(opt.save_folder, ckpt) 334 | save_model(model, optimizer_model, opt, epoch, save_file) 335 | if class_epoch > 0: 336 | ckpt = 'ckpt_class_epoch_{}.pth'.format(epoch) 337 | save_file = os.path.join(opt.save_folder, ckpt) 338 | save_model(classifier, optimizer_classifier, opt, epoch, save_file) 339 | 340 | save_file = os.path.join(opt.save_folder, 'last.pth') 341 | save_model(model, optimizer_model, opt, opt.epochs, save_file) 342 | save_file = os.path.join(opt.save_folder, 'last_classifier.pth') 343 | save_model(classifier, optimizer_classifier, opt, opt.epochs, save_file) 344 | 345 | 346 | # python3 train_supcon.py --data_dir ../Data/TFrecord_w29_s15/ --model resnet18 --save_freq 10 --window_size 29 --epochs 200 --num_workers 8 --temp 0.07 --learning_rate 0.1 --learning_rate_classifier 0.01 --cosine --epoch_start_classifier 170 --batch_size 1024 --rid 5 347 | if __name__ == '__main__': 348 | main() 349 | -------------------------------------------------------------------------------- /notebooks/Performance_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "4b4eb8c0-6156-4d2b-86c7-8546daa6632a", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2\n", 12 | "\n", 13 | "import sys\n", 14 | "sys.path.append('../')" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "id": "94e9d30d-9e89-4ee8-800a-6f440228144a", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "from networks.inception import InceptionResnet, SupIncepResnet\n", 25 | "from networks.simple_cnn import BaselineCNNClassifier\n", 26 | "from networks.resnet_big import SupCEResNet, SupConResNet, LinearClassifier\n", 27 | "import torch\n", 28 | "from torchsummary import summary\n", 29 | "from thop import profile\n", 30 | "import numpy as np\n", 31 | "import time" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 3, 37 | "id": "398ddce7", 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "def measure_time_gpu(model, device, rep):\n", 42 | " model = model.to(device=device)\n", 43 | " dummy_input = torch.randn(1, 1, 29, 29, dtype=torch.float).to(device)\n", 44 | " # INIT LOGGERS\n", 45 | " starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)\n", 46 | " repetitions = rep\n", 47 | " timings=np.zeros((repetitions,1))\n", 48 | " #GPU-WARM-UP\n", 49 | " for _ in range(100):\n", 50 | " _ = model(dummy_input)\n", 51 | " # MEASURE PERFORMANCE\n", 52 | " with torch.no_grad():\n", 53 | " for rep in range(repetitions):\n", 54 | " starter.record()\n", 55 | " _ = model(dummy_input)\n", 56 | " ender.record()\n", 57 | " # WAIT FOR GPU SYNC\n", 58 | " torch.cuda.synchronize()\n", 59 | " curr_time = starter.elapsed_time(ender)\n", 60 | " timings[rep] = curr_time\n", 61 | " mean_syn = np.sum(timings) / repetitions\n", 62 | " std_syn = np.std(timings)\n", 63 | " return mean_syn, std_syn\n" 64 | ] 65 | }, 66 | { 67 | "cell_type": 
"code", 68 | "execution_count": 4, 69 | "id": "f08bd61e", 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "def measure_time_cpu(model, device, rep = 10):\n", 74 | " model = model.to(device=device)\n", 75 | " x = torch.rand((1, 1, 29, 29), device=device)\n", 76 | " timings=np.zeros((rep,1))\n", 77 | " for i in range(rep): \n", 78 | " start_time = time.time()\n", 79 | " out = model(x)\n", 80 | " timings[i] = time.time() - start_time\n", 81 | " mean_syn = np.sum(timings) / rep\n", 82 | " std_syn = np.std(timings)\n", 83 | " return mean_syn, std_syn" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 5, 89 | "id": "46511c12-302e-412d-9558-1c30d6598fd8", 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "data": { 94 | "text/plain": [ 95 | "(5.565476221561432, 0.2572142879522647)" 96 | ] 97 | }, 98 | "execution_count": 5, 99 | "metadata": {}, 100 | "output_type": "execute_result" 101 | } 102 | ], 103 | "source": [ 104 | "incep = SupIncepResnet(num_classes=5)\n", 105 | "measure_time_gpu(incep, 'cuda', rep=1000)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 6, 111 | "id": "b35e749d-bf36-4c56-a6a3-47e6732abf5e", 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "data": { 116 | "text/plain": [ 117 | "(5.958114630699158, 0.3380396613151823)" 118 | ] 119 | }, 120 | "execution_count": 6, 121 | "metadata": {}, 122 | "output_type": "execute_result" 123 | } 124 | ], 125 | "source": [ 126 | "baseline_model = SupCEResNet(name='resnet18', num_classes=5)\n", 127 | "measure_time_gpu(baseline_model, 'cuda', rep=1000)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 7, 133 | "id": "ec6b048f-724f-4617-8650-d46c4daf0575", 134 | "metadata": {}, 135 | "outputs": [ 136 | { 137 | "name": "stdout", 138 | "output_type": "stream", 139 | "text": [ 140 | "[INFO] Register count_convNd() for .\n", 141 | "[INFO] Register count_normalization() for .\n", 142 | "[INFO] Register zero_ops() for .\n", 143 | "[INFO] Register count_adap_avgpool() for .\n", 144 | "[INFO] Register count_linear() for .\n", 145 | "MACs (G): 0.032560592\n", 146 | "Params (M): 0.700533\n" 147 | ] 148 | } 149 | ], 150 | "source": [ 151 | "x = torch.rand((1, 1, 29, 29), device='cpu')\n", 152 | "baseline_model = baseline_model.to(device='cpu')\n", 153 | "macs, params = profile(baseline_model, inputs=(x, ))\n", 154 | "print('MACs (G): ', macs/1000**3)\n", 155 | "print('Params (M): ', params/1000**2)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 8, 161 | "id": "31daf0f2-281a-44e8-aede-c7d57a1205df", 162 | "metadata": {}, 163 | "outputs": [ 164 | { 165 | "name": "stdout", 166 | "output_type": "stream", 167 | "text": [ 168 | "[INFO] Register count_convNd() for .\n", 169 | "[INFO] Register zero_ops() for .\n", 170 | "[INFO] Register zero_ops() for .\n", 171 | "[INFO] Register zero_ops() for .\n", 172 | "[INFO] Register count_adap_avgpool() for .\n", 173 | "[INFO] Register count_linear() for .\n", 174 | "MACs (G): 0.097190176\n", 175 | "Params (M): 1.694181\n" 176 | ] 177 | } 178 | ], 179 | "source": [ 180 | "x = torch.rand((1, 1, 29, 29), device='cpu')\n", 181 | "incep = incep.to(device='cpu')\n", 182 | "macs, params = profile(incep, inputs=(x, ))\n", 183 | "print('MACs (G): ', macs/1000**3)\n", 184 | "print('Params (M): ', params/1000**2)" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 9, 190 | "id": "cf5bac89-5eee-4c5c-b616-c7ddd465c953", 191 | "metadata": {}, 192 | "outputs": [ 193 | { 194 | "name": 
"stdout", 195 | "output_type": "stream", 196 | "text": [ 197 | "----------------------------------------------------------------\n", 198 | " Layer (type) Output Shape Param #\n", 199 | "================================================================\n", 200 | " Conv2d-1 [-1, 16, 29, 29] 144\n", 201 | " BatchNorm2d-2 [-1, 16, 29, 29] 32\n", 202 | " Conv2d-3 [-1, 16, 29, 29] 2,304\n", 203 | " BatchNorm2d-4 [-1, 16, 29, 29] 32\n", 204 | " Conv2d-5 [-1, 16, 29, 29] 2,304\n", 205 | " BatchNorm2d-6 [-1, 16, 29, 29] 32\n", 206 | " BasicBlock-7 [-1, 16, 29, 29] 0\n", 207 | " Conv2d-8 [-1, 16, 29, 29] 2,304\n", 208 | " BatchNorm2d-9 [-1, 16, 29, 29] 32\n", 209 | " Conv2d-10 [-1, 16, 29, 29] 2,304\n", 210 | " BatchNorm2d-11 [-1, 16, 29, 29] 32\n", 211 | " BasicBlock-12 [-1, 16, 29, 29] 0\n", 212 | " Conv2d-13 [-1, 32, 15, 15] 4,608\n", 213 | " BatchNorm2d-14 [-1, 32, 15, 15] 64\n", 214 | " Conv2d-15 [-1, 32, 15, 15] 9,216\n", 215 | " BatchNorm2d-16 [-1, 32, 15, 15] 64\n", 216 | " Conv2d-17 [-1, 32, 15, 15] 512\n", 217 | " BatchNorm2d-18 [-1, 32, 15, 15] 64\n", 218 | " BasicBlock-19 [-1, 32, 15, 15] 0\n", 219 | " Conv2d-20 [-1, 32, 15, 15] 9,216\n", 220 | " BatchNorm2d-21 [-1, 32, 15, 15] 64\n", 221 | " Conv2d-22 [-1, 32, 15, 15] 9,216\n", 222 | " BatchNorm2d-23 [-1, 32, 15, 15] 64\n", 223 | " BasicBlock-24 [-1, 32, 15, 15] 0\n", 224 | " Conv2d-25 [-1, 64, 8, 8] 18,432\n", 225 | " BatchNorm2d-26 [-1, 64, 8, 8] 128\n", 226 | " Conv2d-27 [-1, 64, 8, 8] 36,864\n", 227 | " BatchNorm2d-28 [-1, 64, 8, 8] 128\n", 228 | " Conv2d-29 [-1, 64, 8, 8] 2,048\n", 229 | " BatchNorm2d-30 [-1, 64, 8, 8] 128\n", 230 | " BasicBlock-31 [-1, 64, 8, 8] 0\n", 231 | " Conv2d-32 [-1, 64, 8, 8] 36,864\n", 232 | " BatchNorm2d-33 [-1, 64, 8, 8] 128\n", 233 | " Conv2d-34 [-1, 64, 8, 8] 36,864\n", 234 | " BatchNorm2d-35 [-1, 64, 8, 8] 128\n", 235 | " BasicBlock-36 [-1, 64, 8, 8] 0\n", 236 | " Conv2d-37 [-1, 128, 4, 4] 73,728\n", 237 | " BatchNorm2d-38 [-1, 128, 4, 4] 256\n", 238 | " Conv2d-39 [-1, 128, 4, 4] 147,456\n", 239 | " BatchNorm2d-40 [-1, 128, 4, 4] 256\n", 240 | " Conv2d-41 [-1, 128, 4, 4] 8,192\n", 241 | " BatchNorm2d-42 [-1, 128, 4, 4] 256\n", 242 | " BasicBlock-43 [-1, 128, 4, 4] 0\n", 243 | " Conv2d-44 [-1, 128, 4, 4] 147,456\n", 244 | " BatchNorm2d-45 [-1, 128, 4, 4] 256\n", 245 | " Conv2d-46 [-1, 128, 4, 4] 147,456\n", 246 | " BatchNorm2d-47 [-1, 128, 4, 4] 256\n", 247 | " BasicBlock-48 [-1, 128, 4, 4] 0\n", 248 | "AdaptiveAvgPool2d-49 [-1, 128, 1, 1] 0\n", 249 | " ResNet-50 [-1, 128] 0\n", 250 | " Linear-51 [-1, 5] 645\n", 251 | "================================================================\n", 252 | "Total params: 700,533\n", 253 | "Trainable params: 700,533\n", 254 | "Non-trainable params: 0\n", 255 | "----------------------------------------------------------------\n", 256 | "Input size (MB): 0.00\n", 257 | "Forward/backward pass size (MB): 2.46\n", 258 | "Params size (MB): 2.67\n", 259 | "Estimated Total Size (MB): 5.13\n", 260 | "----------------------------------------------------------------\n" 261 | ] 262 | } 263 | ], 264 | "source": [ 265 | "summary(baseline_model.cuda(), (1, 29, 29))" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 10, 271 | "id": "56eb30d1-60ce-4944-b917-2a012b1f49fc", 272 | "metadata": {}, 273 | "outputs": [ 274 | { 275 | "name": "stdout", 276 | "output_type": "stream", 277 | "text": [ 278 | "----------------------------------------------------------------\n", 279 | " Layer (type) Output Shape Param #\n", 280 | 
"================================================================\n", 281 | " Conv2d-1 [-1, 32, 29, 29] 320\n", 282 | " Conv2d-2 [-1, 32, 27, 27] 9,248\n", 283 | " MaxPool2d-3 [-1, 32, 13, 13] 0\n", 284 | " Conv2d-4 [-1, 64, 13, 13] 2,112\n", 285 | " Conv2d-5 [-1, 128, 13, 13] 73,856\n", 286 | " Conv2d-6 [-1, 128, 13, 13] 147,584\n", 287 | " Stem-7 [-1, 128, 13, 13] 0\n", 288 | " Conv2d-8 [-1, 32, 13, 13] 4,128\n", 289 | " Conv2d-9 [-1, 32, 13, 13] 4,128\n", 290 | " Conv2d-10 [-1, 32, 13, 13] 9,248\n", 291 | " Conv2d-11 [-1, 32, 13, 13] 4,128\n", 292 | " Conv2d-12 [-1, 32, 13, 13] 9,248\n", 293 | " Conv2d-13 [-1, 32, 13, 13] 9,248\n", 294 | " Conv2d-14 [-1, 128, 13, 13] 12,416\n", 295 | " ReLU-15 [-1, 128, 13, 13] 0\n", 296 | "InceptionresenetA-16 [-1, 128, 13, 13] 0\n", 297 | " MaxPool2d-17 [-1, 128, 6, 6] 0\n", 298 | " Conv2d-18 [-1, 192, 6, 6] 221,376\n", 299 | " Conv2d-19 [-1, 96, 13, 13] 12,384\n", 300 | " Conv2d-20 [-1, 96, 13, 13] 83,040\n", 301 | " Conv2d-21 [-1, 128, 6, 6] 110,720\n", 302 | " ReductionA-22 [-1, 448, 6, 6] 0\n", 303 | " Conv2d-23 [-1, 64, 6, 6] 28,736\n", 304 | " Conv2d-24 [-1, 64, 6, 6] 28,736\n", 305 | " Conv2d-25 [-1, 64, 6, 6] 12,352\n", 306 | " Conv2d-26 [-1, 64, 6, 6] 12,352\n", 307 | " Conv2d-27 [-1, 448, 6, 6] 57,792\n", 308 | " ReLU-28 [-1, 448, 6, 6] 0\n", 309 | " InceptionresnetB-29 [-1, 448, 6, 6] 0\n", 310 | " MaxPool2d-30 [-1, 448, 2, 2] 0\n", 311 | " Conv2d-31 [-1, 128, 6, 6] 57,472\n", 312 | " Conv2d-32 [-1, 192, 2, 2] 221,376\n", 313 | " Conv2d-33 [-1, 128, 6, 6] 57,472\n", 314 | " Conv2d-34 [-1, 128, 2, 2] 147,584\n", 315 | " Conv2d-35 [-1, 128, 6, 6] 57,472\n", 316 | " Conv2d-36 [-1, 128, 4, 4] 147,584\n", 317 | " Conv2d-37 [-1, 128, 2, 2] 147,584\n", 318 | " ReductionB-38 [-1, 896, 2, 2] 0\n", 319 | "AdaptiveAvgPool2d-39 [-1, 896, 1, 1] 0\n", 320 | " Dropout2d-40 [-1, 896, 1, 1] 0\n", 321 | " InceptionResnet-41 [-1, 896] 0\n", 322 | " Linear-42 [-1, 5] 4,485\n", 323 | "================================================================\n", 324 | "Total params: 1,694,181\n", 325 | "Trainable params: 1,694,181\n", 326 | "Non-trainable params: 0\n", 327 | "----------------------------------------------------------------\n", 328 | "Input size (MB): 0.00\n", 329 | "Forward/backward pass size (MB): 2.87\n", 330 | "Params size (MB): 6.46\n", 331 | "Estimated Total Size (MB): 9.34\n", 332 | "----------------------------------------------------------------\n" 333 | ] 334 | } 335 | ], 336 | "source": [ 337 | "summary(incep.cuda(), (1, 29, 29))" 338 | ] 339 | } 340 | ], 341 | "metadata": { 342 | "kernelspec": { 343 | "display_name": "Python 3 (ipykernel)", 344 | "language": "python", 345 | "name": "python3" 346 | }, 347 | "language_info": { 348 | "codemirror_mode": { 349 | "name": "ipython", 350 | "version": 3 351 | }, 352 | "file_extension": ".py", 353 | "mimetype": "text/x-python", 354 | "name": "python", 355 | "nbconvert_exporter": "python", 356 | "pygments_lexer": "ipython3", 357 | "version": "3.8.10" 358 | }, 359 | "vscode": { 360 | "interpreter": { 361 | "hash": "ec8a7a313ab33d199c8aa698bb86bd912b8385ce4922a6e184e3f5edd5eb95f6" 362 | } 363 | } 364 | }, 365 | "nbformat": 4, 366 | "nbformat_minor": 5 367 | } 368 | -------------------------------------------------------------------------------- /notebooks/histogram_based.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext 
autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "import sys\n", 13 | "sys.path.append('../')" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [ 21 | { 22 | "name": "stdout", 23 | "output_type": "stream", 24 | "text": [ 25 | "INFO: Pandarallel will run on 16 workers.\n", 26 | "INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.\n" 27 | ] 28 | } 29 | ], 30 | "source": [ 31 | "import pandas as pd\n", 32 | "from numpy.lib.stride_tricks import as_strided\n", 33 | "from pandarallel import pandarallel\n", 34 | "import numpy as np\n", 35 | "pandarallel.initialize(progress_bar=False)" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "## Preprocessing (Run only 1 time)" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 3, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "def fill_flag(sample):\n", 52 | " if not isinstance(sample['Flag'], str):\n", 53 | " col = 'Data' + str(sample['DLC'])\n", 54 | " sample['Flag'], sample[col] = sample[col], sample['Flag']\n", 55 | " return sample" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 4, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "def sliding_window(data, win=29, s=1):\n", 65 | " itemsize = data.itemsize\n", 66 | " N = len(data)\n", 67 | " sliding_data = as_strided(data, shape=((N - win) // s + 1, win), strides=(itemsize*s, itemsize))\n", 68 | " return sliding_data" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 5, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "attacks = ['DoS', 'Fuzzy', 'gear', 'RPM']\n", 78 | "attributes = ['Timestamp', 'canID', 'DLC', \n", 79 | " 'Data0', 'Data1', 'Data2', \n", 80 | " 'Data3', 'Data4', 'Data5', \n", 81 | " 'Data6', 'Data7', 'Flag']\n", 82 | "def preprocessing(attack_id, window_size, strided):\n", 83 | " filename = f'../../Data/Car-Hacking/{attacks[attack_id]}_dataset.csv'\n", 84 | " df = pd.read_csv(filename, header=None, names=attributes)\n", 85 | " df = df.parallel_apply(fill_flag, axis=1)\n", 86 | " num_data_bytes = 8\n", 87 | " for x in range(num_data_bytes):\n", 88 | " df['Data'+str(x)] = df['Data'+str(x)].map(lambda x: int(x, 16), na_action='ignore')\n", 89 | " df = df.fillna(0)\n", 90 | " data_cols = ['Data{}'.format(x) for x in range(num_data_bytes)]\n", 91 | " df[data_cols] = df[data_cols].astype(int) \n", 92 | " df['Data'] = df[data_cols].values.tolist()\n", 93 | " df['Flag'] = (df['Flag'] == 'T')\n", 94 | " sliding_label = sliding_window(df['Flag'].to_numpy(), win=window_size, s=strided)\n", 95 | " sliding_data = sliding_window(df['Data'].to_numpy(), win=window_size, s=strided)\n", 96 | " labels = np.any(sliding_label, axis=1).astype('int8')\n", 97 | " pp_df = pd.DataFrame({\n", 98 | " 'data_seq': pd.Series(sliding_data.tolist()),\n", 99 | " 'label': pd.Series(sliding_label.tolist())\n", 100 | " }, index=range(len(sliding_label)))\n", 101 | " pp_df['data_histogram'] = pp_df['data_seq'].parallel_apply(lambda x: np.histogram(np.array(x), bins=256)[0])\n", 102 | " pp_df['label'] = pp_df['label'].parallel_apply(lambda x: (attack_id + 1) if any(x) else 0)\n", 103 | " return pp_df" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 6, 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "name": "stdout", 113 | "output_type": "stream", 114 | "text": [ 115 | "LABEL: DoS\n", 116 | "LABEL: Fuzzy\n", 117 | 
"LABEL: gear\n", 118 | "LABEL: RPM\n" 119 | ] 120 | } 121 | ], 122 | "source": [ 123 | "df_list = []\n", 124 | "for i, a in enumerate(attacks): \n", 125 | " print(f'LABEL: {a}')\n", 126 | " df = preprocessing(i, window_size=30, strided=10)\n", 127 | " df_list.append(df)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 7, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "df = pd.concat(df_list)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 8, 142 | "metadata": {}, 143 | "outputs": [ 144 | { 145 | "data": { 146 | "text/html": [ 147 | "
\n", 148 | "\n", 161 | "\n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | "
data_seqlabeldata_histogram
0[[5, 33, 104, 9, 33, 33, 0, 111], [254, 91, 0,...0[108, 0, 2, 2, 1, 4, 0, 1, 7, 2, 0, 3, 1, 0, 0...
1[[229, 127, 0, 0, 72, 127, 11, 172], [0, 0, 0,...0[110, 1, 2, 1, 1, 3, 0, 1, 8, 1, 1, 2, 3, 0, 1...
2[[64, 187, 127, 20, 17, 32, 0, 20], [0, 0, 0, ...0[114, 1, 1, 2, 0, 3, 0, 0, 4, 1, 1, 1, 4, 3, 2...
3[[11, 128, 0, 255, 69, 128, 12, 133], [14, 128...0[96, 1, 3, 1, 0, 3, 0, 0, 6, 3, 1, 1, 3, 3, 2,...
4[[5, 33, 104, 9, 33, 33, 0, 111], [64, 187, 12...0[109, 0, 2, 2, 0, 4, 0, 0, 4, 3, 0, 0, 1, 3, 3...
............
462163[[0, 32, 0, 0, 0, 0, 0, 0], [0, 64, 96, 255, 1...0[115, 3, 1, 0, 0, 3, 3, 3, 3, 4, 1, 0, 0, 0, 0...
462164[[0, 128, 0, 0, 48, 127, 6, 68], [0, 0, 0, 0, ...0[107, 2, 2, 0, 0, 3, 1, 3, 2, 4, 2, 0, 0, 0, 0...
462165[[0, 64, 96, 255, 126, 133, 9, 0], [255, 0, 0,...0[108, 3, 1, 0, 0, 4, 0, 3, 5, 4, 2, 0, 0, 0, 1...
462166[[5, 34, 28, 10, 34, 30, 0, 111], [254, 89, 0,...0[110, 2, 2, 0, 2, 3, 0, 0, 6, 5, 2, 0, 0, 0, 1...
462167[[0, 0, 0, 0, 0, 0, 0, 0], [41, 39, 39, 35, 0,...0[101, 2, 2, 0, 2, 4, 0, 0, 6, 5, 2, 0, 0, 1, 1...
\n", 239 | "

1656939 rows × 3 columns

\n", 240 | "
" 241 | ], 242 | "text/plain": [ 243 | " data_seq label \\\n", 244 | "0 [[5, 33, 104, 9, 33, 33, 0, 111], [254, 91, 0,... 0 \n", 245 | "1 [[229, 127, 0, 0, 72, 127, 11, 172], [0, 0, 0,... 0 \n", 246 | "2 [[64, 187, 127, 20, 17, 32, 0, 20], [0, 0, 0, ... 0 \n", 247 | "3 [[11, 128, 0, 255, 69, 128, 12, 133], [14, 128... 0 \n", 248 | "4 [[5, 33, 104, 9, 33, 33, 0, 111], [64, 187, 12... 0 \n", 249 | "... ... ... \n", 250 | "462163 [[0, 32, 0, 0, 0, 0, 0, 0], [0, 64, 96, 255, 1... 0 \n", 251 | "462164 [[0, 128, 0, 0, 48, 127, 6, 68], [0, 0, 0, 0, ... 0 \n", 252 | "462165 [[0, 64, 96, 255, 126, 133, 9, 0], [255, 0, 0,... 0 \n", 253 | "462166 [[5, 34, 28, 10, 34, 30, 0, 111], [254, 89, 0,... 0 \n", 254 | "462167 [[0, 0, 0, 0, 0, 0, 0, 0], [41, 39, 39, 35, 0,... 0 \n", 255 | "\n", 256 | " data_histogram \n", 257 | "0 [108, 0, 2, 2, 1, 4, 0, 1, 7, 2, 0, 3, 1, 0, 0... \n", 258 | "1 [110, 1, 2, 1, 1, 3, 0, 1, 8, 1, 1, 2, 3, 0, 1... \n", 259 | "2 [114, 1, 1, 2, 0, 3, 0, 0, 4, 1, 1, 1, 4, 3, 2... \n", 260 | "3 [96, 1, 3, 1, 0, 3, 0, 0, 6, 3, 1, 1, 3, 3, 2,... \n", 261 | "4 [109, 0, 2, 2, 0, 4, 0, 0, 4, 3, 0, 0, 1, 3, 3... \n", 262 | "... ... \n", 263 | "462163 [115, 3, 1, 0, 0, 3, 3, 3, 3, 4, 1, 0, 0, 0, 0... \n", 264 | "462164 [107, 2, 2, 0, 0, 3, 1, 3, 2, 4, 2, 0, 0, 0, 0... \n", 265 | "462165 [108, 3, 1, 0, 0, 4, 0, 3, 5, 4, 2, 0, 0, 0, 1... \n", 266 | "462166 [110, 2, 2, 0, 2, 3, 0, 0, 6, 5, 2, 0, 0, 0, 1... \n", 267 | "462167 [101, 2, 2, 0, 2, 4, 0, 0, 6, 5, 2, 0, 0, 1, 1... \n", 268 | "\n", 269 | "[1656939 rows x 3 columns]" 270 | ] 271 | }, 272 | "execution_count": 8, 273 | "metadata": {}, 274 | "output_type": "execute_result" 275 | } 276 | ], 277 | "source": [ 278 | "df" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 9, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | "new_df = pd.DataFrame([pd.Series(x) for x in df['data_histogram']])\n", 288 | "new_df.columns = ['{}'.format(x+1) for x in new_df.columns]" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": 10, 294 | "metadata": {}, 295 | "outputs": [ 296 | { 297 | "data": { 298 | "text/html": [ 299 | "
\n", 300 | "\n", 313 | "\n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 
585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | "
12345678910...247248249250251252253254255256
0108022140172...0000000028
1110121130181...0000000029
2114112030041...0000000019
396131030063...00000000210
4109022040043...0000000019
..................................................................
1656934115310033334...00000000112
1656935107220031324...0000000029
1656936108310040354...01000000112
1656937110220230065...0100000029
1656938101220240065...0210000029
\n", 607 | "

1656939 rows × 256 columns

\n", 608 | "
" 609 | ], 610 | "text/plain": [ 611 | " 1 2 3 4 5 6 7 8 9 10 ... 247 248 249 250 251 252 \\\n", 612 | "0 108 0 2 2 1 4 0 1 7 2 ... 0 0 0 0 0 0 \n", 613 | "1 110 1 2 1 1 3 0 1 8 1 ... 0 0 0 0 0 0 \n", 614 | "2 114 1 1 2 0 3 0 0 4 1 ... 0 0 0 0 0 0 \n", 615 | "3 96 1 3 1 0 3 0 0 6 3 ... 0 0 0 0 0 0 \n", 616 | "4 109 0 2 2 0 4 0 0 4 3 ... 0 0 0 0 0 0 \n", 617 | "... ... .. .. .. .. .. .. .. .. .. ... ... ... ... ... ... ... \n", 618 | "1656934 115 3 1 0 0 3 3 3 3 4 ... 0 0 0 0 0 0 \n", 619 | "1656935 107 2 2 0 0 3 1 3 2 4 ... 0 0 0 0 0 0 \n", 620 | "1656936 108 3 1 0 0 4 0 3 5 4 ... 0 1 0 0 0 0 \n", 621 | "1656937 110 2 2 0 2 3 0 0 6 5 ... 0 1 0 0 0 0 \n", 622 | "1656938 101 2 2 0 2 4 0 0 6 5 ... 0 2 1 0 0 0 \n", 623 | "\n", 624 | " 253 254 255 256 \n", 625 | "0 0 0 2 8 \n", 626 | "1 0 0 2 9 \n", 627 | "2 0 0 1 9 \n", 628 | "3 0 0 2 10 \n", 629 | "4 0 0 1 9 \n", 630 | "... ... ... ... ... \n", 631 | "1656934 0 0 1 12 \n", 632 | "1656935 0 0 2 9 \n", 633 | "1656936 0 0 1 12 \n", 634 | "1656937 0 0 2 9 \n", 635 | "1656938 0 0 2 9 \n", 636 | "\n", 637 | "[1656939 rows x 256 columns]" 638 | ] 639 | }, 640 | "execution_count": 10, 641 | "metadata": {}, 642 | "output_type": "execute_result" 643 | } 644 | ], 645 | "source": [ 646 | "new_df" 647 | ] 648 | }, 649 | { 650 | "cell_type": "code", 651 | "execution_count": 11, 652 | "metadata": {}, 653 | "outputs": [], 654 | "source": [ 655 | "X = new_df.to_numpy()\n", 656 | "y = df['label']" 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "execution_count": 21, 662 | "metadata": {}, 663 | "outputs": [], 664 | "source": [ 665 | "np.savez_compressed('../../Data/Car-Hacking/full_histogram.npz', X=X, y=y)" 666 | ] 667 | }, 668 | { 669 | "cell_type": "markdown", 670 | "metadata": {}, 671 | "source": [ 672 | "## Modeling" 673 | ] 674 | }, 675 | { 676 | "cell_type": "code", 677 | "execution_count": 3, 678 | "metadata": {}, 679 | "outputs": [], 680 | "source": [ 681 | "data = np.load('../../Data/Car-Hacking/full_histogram.npz')\n", 682 | "X, y = data['X'], data['y']" 683 | ] 684 | }, 685 | { 686 | "cell_type": "code", 687 | "execution_count": 4, 688 | "metadata": {}, 689 | "outputs": [], 690 | "source": [ 691 | "import faiss\n", 692 | "class FaissKNeighbors:\n", 693 | " def __init__(self, k=5):\n", 694 | " self.index = None\n", 695 | " self.y = None\n", 696 | " self.k = k\n", 697 | "\n", 698 | " def fit(self, X, y):\n", 699 | " self.index = faiss.IndexFlatL2(X.shape[1])\n", 700 | " self.index.add(X.astype(np.float32))\n", 701 | " self.y = y\n", 702 | "\n", 703 | " def predict(self, X):\n", 704 | " distances, indices = self.index.search(X.astype(np.float32), k=self.k)\n", 705 | " votes = self.y[indices]\n", 706 | " predictions = np.array([np.argmax(np.bincount(x)) for x in votes])\n", 707 | " return predictions" 708 | ] 709 | }, 710 | { 711 | "cell_type": "code", 712 | "execution_count": 5, 713 | "metadata": {}, 714 | "outputs": [], 715 | "source": [ 716 | "from sklearn.model_selection import train_test_split\n", 717 | "from sklearn.neighbors import KNeighborsClassifier\n", 718 | "from sklearn.model_selection import StratifiedShuffleSplit\n", 719 | "from utils import cal_metric" 720 | ] 721 | }, 722 | { 723 | "cell_type": "code", 724 | "execution_count": 6, 725 | "metadata": {}, 726 | "outputs": [ 727 | { 728 | "name": "stdout", 729 | "output_type": "stream", 730 | "text": [ 731 | "CV: 1\n", 732 | "{'fnr': array([0.00879808, 0.58000368, 1.13926994, 0.22710468, 0.16425121]), 'rec': array([0.99991202, 0.99419996, 0.9886073 , 0.99772895, 0.99835749]), 
'pre': array([0.99720199, 0.99935221, 0.99984324, 1. , 1. ]), 'f1': array([0.99855517, 0.99676943, 0.99419352, 0.99886319, 0.99917807])}\n", 733 | "CV: 2\n", 734 | "{'fnr': array([0.00488782, 0.52476524, 1.19352089, 0.21126017, 0.16908213]), 'rec': array([0.99995112, 0.99475235, 0.98806479, 0.9978874 , 0.99830918]), 'pre': array([0.99721182, 0.99990746, 0.99968635, 1. , 1. ]), 'f1': array([0.99857959, 0.99732324, 0.9938416 , 0.99894258, 0.99915387])}\n", 735 | "CV: 3\n", 736 | "{'fnr': array([0.00391026, 0.39587553, 1.12376967, 0.22710468, 0.13043478]), 'rec': array([0.9999609 , 0.99604124, 0.9887623 , 0.99772895, 0.99869565]), 'pre': array([0.99748413, 0.99972279, 0.99992162, 1. , 1. ]), 'f1': array([0.99872098, 0.99787862, 0.99431065, 0.99886319, 0.9993474 ])}\n", 737 | "CV: 4\n", 738 | "{'fnr': array([0.00488782, 0.50635242, 1.40277455, 0.23766769, 0.15458937]), 'rec': array([0.99995112, 0.99493648, 0.98597225, 0.99762332, 0.99845411]), 'pre': array([0.99694941, 0.99972248, 0.99984282, 1. , 1. ]), 'f1': array([0.99844801, 0.99732374, 0.99285909, 0.99881025, 0.99922646])}\n", 739 | "CV: 5\n", 740 | "{'fnr': array([0.00391026, 0.51555883, 1.11601953, 0.23238618, 0.17391304]), 'rec': array([0.9999609 , 0.99484441, 0.9888398 , 0.99767614, 0.99826087]), 'pre': array([0.99727018, 0.99972245, 0.99992163, 1. , 1. ]), 'f1': array([0.99861373, 0.99727747, 0.99434984, 0.99883672, 0.99912968])}\n" 741 | ] 742 | } 743 | ], 744 | "source": [ 745 | "sss = StratifiedShuffleSplit(n_splits=5, random_state=0)\n", 746 | "ys = []\n", 747 | "y_preds = []\n", 748 | "total_results = {\n", 749 | " 'fnr': np.zeros(5),\n", 750 | " 'rec': np.zeros(5),\n", 751 | " 'pre': np.zeros(5),\n", 752 | " 'f1': np.zeros(5)\n", 753 | "}\n", 754 | "for i, (train_index, test_index) in enumerate(sss.split(X, y)):\n", 755 | " print('CV: ', i + 1)\n", 756 | " X_train, X_test = X[train_index], X[test_index]\n", 757 | " y_train, y_test = y[train_index], y[test_index]\n", 758 | " ys.append(y_test)\n", 759 | " # Prediction\n", 760 | " fast_knn = FaissKNeighbors(k=5)\n", 761 | " fast_knn.fit(X_train, y_train)\n", 762 | " y_pred = fast_knn.predict(X_test)\n", 763 | " y_preds.append(y_pred)\n", 764 | " ############\n", 765 | " cm, results = cal_metric(y_test, y_pred)\n", 766 | " print(results)\n", 767 | " for k, v in total_results.items():\n", 768 | " v += results[k]" 769 | ] 770 | }, 771 | { 772 | "cell_type": "code", 773 | "execution_count": 7, 774 | "metadata": {}, 775 | "outputs": [ 776 | { 777 | "data": { 778 | "text/plain": [ 779 | "{'fnr': array([0.00527885, 0.50451114, 1.19507091, 0.22710468, 0.15845411]),\n", 780 | " 'rec': array([0.99994721, 0.99495489, 0.98804929, 0.99772895, 0.99841546]),\n", 781 | " 'pre': array([0.99722351, 0.99968548, 0.99984313, 1. , 1. 
]),\n", 782 | " 'f1': array([0.9985835 , 0.9973145 , 0.99391094, 0.99886318, 0.9992071 ])}" 783 | ] 784 | }, 785 | "execution_count": 7, 786 | "metadata": {}, 787 | "output_type": "execute_result" 788 | } 789 | ], 790 | "source": [ 791 | "total_results = {k: v/5 for k, v in total_results.items()}\n", 792 | "total_results\n" 793 | ] 794 | } 795 | ], 796 | "metadata": { 797 | "kernelspec": { 798 | "display_name": "Python 3.7.10 ('torch')", 799 | "language": "python", 800 | "name": "python3" 801 | }, 802 | "language_info": { 803 | "codemirror_mode": { 804 | "name": "ipython", 805 | "version": 3 806 | }, 807 | "file_extension": ".py", 808 | "mimetype": "text/x-python", 809 | "name": "python", 810 | "nbconvert_exporter": "python", 811 | "pygments_lexer": "ipython3", 812 | "version": "3.7.10" 813 | }, 814 | "orig_nbformat": 4, 815 | "vscode": { 816 | "interpreter": { 817 | "hash": "ec8a7a313ab33d199c8aa698bb86bd912b8385ce4922a6e184e3f5edd5eb95f6" 818 | } 819 | } 820 | }, 821 | "nbformat": 4, 822 | "nbformat_minor": 2 823 | } 824 | --------------------------------------------------------------------------------