├── .idea
│   ├── .name
│   ├── .gitignore
│   ├── vcs.xml
│   ├── inspectionProfiles
│   │   ├── profiles_settings.xml
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── HUST-OBS.iml
│   └── deployment.xml
├── OCR
│   ├── use.json
│   ├── result.json
│   ├── test_pic1.jpg
│   ├── test_pic2.png
│   ├── Dataset establishment.py
│   ├── test.py
│   └── train.py
├── MoCo
│   ├── Dataset establishment.py
│   ├── test.py
│   └── train.py
├── requirements.txt
├── Validation
│   ├── standard deviation.py
│   ├── Dataset establishment.py
│   ├── test.py
│   ├── train.py
│   └── Validation_label.json
└── README.md
/.idea/.name:
--------------------------------------------------------------------------------
train.py
--------------------------------------------------------------------------------
/OCR/use.json:
--------------------------------------------------------------------------------
1 | ["test_pic1.jpg", "test_pic2.png"]
--------------------------------------------------------------------------------
/OCR/result.json:
--------------------------------------------------------------------------------
1 | {"test_pic1.jpg": ["五"], "test_pic2.png": ["璧"]}
--------------------------------------------------------------------------------
/OCR/test_pic1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pengjie-W/HUST-OBC/HEAD/OCR/test_pic1.jpg
--------------------------------------------------------------------------------
/OCR/test_pic2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pengjie-W/HUST-OBC/HEAD/OCR/test_pic2.png
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
--------------------------------------------------------------------------------
/OCR/Dataset establishment.py:
--------------------------------------------------------------------------------
import copy
import json
import os
from tqdm import tqdm

# Index every image in the OCR dataset; each file name encodes its integer class label.
folder_path = './OCR_Dataset'
dataset = []
for root, directories, files in tqdm(os.walk(folder_path)):
    for file in files:
        data = {}
        file_path = os.path.join(root, file)
        data['label'] = int(file.replace('.png', ''))
        data['path'] = file_path
        dataset.append(copy.deepcopy(data))

print(len(dataset))
with open('OCR_train.json', 'w', encoding='utf-8') as f:
    json.dump(dataset, f, ensure_ascii=False)
--------------------------------------------------------------------------------
/MoCo/Dataset establishment.py:
--------------------------------------------------------------------------------
import copy
import json
import os
from tqdm import tqdm

# Index the deciphered subset of HUST-OBC, skipping the ID-mapping JSON files.
folder_path = '../HUST-OBC/deciphered'
dataset = []
for root, directories, files in tqdm(os.walk(folder_path)):
    for file in files:
        if 'ID' in file:
            continue
        data = {}
        file_path = os.path.join(root, file)
        data['label'] = int(file[2:6])  # 'Source_ID_Filename': characters 2-5 hold the 4-digit category ID
        data['path'] = file_path
        dataset.append(copy.deepcopy(data))

print(len(dataset))
with open('MOCO_train.json', 'w', encoding='utf-8') as f:
    json.dump(dataset, f, ensure_ascii=False)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
filelock==3.14.0
fsspec==2024.5.0
Jinja2==3.1.4
joblib==1.4.2
MarkupSafe==2.1.5
mpmath==1.3.0
networkx==3.3
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.5.40
nvidia-nvtx-cu12==12.1.105
opencv-python==4.9.0.80
pandas==2.2.2
pillow==10.3.0
python-dateutil==2.9.0.post0
pytz==2024.1
scikit-learn==1.5.0
scipy==1.13.1
six==1.16.0
sympy==1.12
threadpoolctl==3.5.0
torch==2.3.0
torchaudio==2.3.0
torchvision==0.18.0
tqdm==4.66.4
triton==2.3.0
typing_extensions==4.12.0
tzdata==2024.1
--------------------------------------------------------------------------------
/Validation/standard deviation.py:
--------------------------------------------------------------------------------
import json
from tqdm import tqdm
import torch
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

class TrainData(Dataset):
    def __init__(self, transform=None):
        super(TrainData, self).__init__()
        with open('Validation_train.json', 'r', encoding='utf8') as f:
            images = json.load(f)
        labels = images
        self.images, self.labels = images, labels
        self.transform = transform

    def __getitem__(self, item):
        # Load the image
        image = Image.open(self.images[item]['path'].replace('\\', '/'))
        if image.mode == 'L':
            image = image.convert('RGB')
        # Pad the shorter side with white so the image becomes square
        width, height = image.size
        if width > height:
            dy = width - height
            yl = round(dy / 2)
            yr = dy - yl
            train_transform = transforms.Compose([
                transforms.Pad([0, yl, 0, yr], fill=(255, 255, 255), padding_mode='constant'),
            ])
        else:
            dx = height - width
            xl = round(dx / 2)
            xr = dx - xl
            train_transform = transforms.Compose([
                transforms.Pad([xl, 0, xr, 0], fill=(255, 255, 255), padding_mode='constant'),
            ])

        image = train_transform(image)
        train_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        image = train_transform(image)
        return image, self.images[item]['label']

    def __len__(self):
        return len(self.images)

def getStat(train_data):
    '''
    Compute the per-channel mean and standard deviation of the training data.
    :param train_data: a custom Dataset (an ImageFolder also works)
    :return: (mean, std)
    '''
    print('Compute mean and std for training data.')
    print(len(train_data))
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=4,
        pin_memory=True)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for X, _ in tqdm(train_loader):
        for d in range(3):
            mean[d] += X[:, d, :, :].mean()
            std[d] += X[:, d, :, :].std()
    mean.div_(len(train_data))
    std.div_(len(train_data))
    return list(mean.numpy()), list(std.numpy())


if __name__ == '__main__':
    train_dataset = TrainData()
    print(getStat(train_dataset))
--------------------------------------------------------------------------------
/Validation/Dataset establishment.py:
--------------------------------------------------------------------------------
import json
import os
from pathlib import Path
import random
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Set the random seed
seed = 42
np.random.seed(seed)
random.seed(seed)

dataset = []
X = []
y = []
for root, directories, files in os.walk('../HUST-OBC/deciphered/'):
    for file in files:
        if 'json' in file:
            continue
        file_path = str(Path(os.path.join(root, file)))
        folders = os.path.split(file_path)[0].split(os.sep)
        folder_name = folders[3]  # the class ID directory under ../HUST-OBC/deciphered/
        y.append(folder_name)
        X.append(file_path)

# Count samples per class
unique_classes, class_counts = np.unique(y, return_counts=True)
single_sample_classes = unique_classes[class_counts == 1]
multiple_sample_classes = unique_classes[class_counts > 1]

# Classes with only one sample go straight into the training set
train_indices = [idx for idx, label in enumerate(y) if label in single_sample_classes]
remaining_indices = [idx for idx in range(len(y)) if idx not in train_indices]

# Split the remaining samples into a training set and a val/test pool (8:2)
X_train = [X[idx] for idx in train_indices]
y_train = [y[idx] for idx in train_indices]

X_remaining = [X[idx] for idx in remaining_indices]
y_remaining = [y[idx] for idx in remaining_indices]

# Stratified split
stratified_splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_idx, test_val_idx in stratified_splitter.split(X_remaining, y_remaining):
    X_train_remaining = [X_remaining[idx] for idx in train_idx]
    y_train_remaining = [y_remaining[idx] for idx in train_idx]
    X_test_val = [X_remaining[idx] for idx in test_val_idx]
    y_test_val = [y_remaining[idx] for idx in test_val_idx]

X_train.extend(X_train_remaining)
y_train.extend(y_train_remaining)

unique_classes, class_counts = np.unique(y_test_val, return_counts=True)
single_sample_classes = unique_classes[class_counts == 1]
multiple_sample_classes = unique_classes[class_counts > 1]

# Randomly assign classes with only one sample in the pool to either the validation or the test set
train_indices = [idx for idx, label in enumerate(y_test_val) if label in single_sample_classes]
remaining_indices = [idx for idx in range(len(y_test_val)) if idx not in train_indices]

X_val = []
y_val = []
X_test = []
y_test = []

for idx in train_indices:
    if np.random.rand() < 0.5:
        X_val.append(X_test_val[idx])
        y_val.append(y_test_val[idx])
    else:
        X_test.append(X_test_val[idx])
        y_test.append(y_test_val[idx])

# Stratified split of the remaining pool into validation and test sets (1:1)
X_test_val_remaining = [X_test_val[idx] for idx in remaining_indices]
y_test_val_remaining = [y_test_val[idx] for idx in remaining_indices]

stratified_splitter_test_val = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
for val_idx, test_idx in stratified_splitter_test_val.split(X_test_val_remaining, y_test_val_remaining):
    X_val.extend([X_test_val_remaining[idx] for idx in val_idx])
    y_val.extend([y_test_val_remaining[idx] for idx in val_idx])
    X_test.extend([X_test_val_remaining[idx] for idx in test_idx])
    y_test.extend([y_test_val_remaining[idx] for idx in test_idx])

unique_classes, class_counts = np.unique(y, return_counts=True)
# Map each dataset category ID to a contiguous label index
dataset = {label: idx for idx, label in enumerate(unique_classes)}

# Save the training, validation and test sets
train_data = [{'path': path, 'label': dataset[label]} for path, label in zip(X_train, y_train)]
val_data = [{'path': path, 'label': dataset[label]} for path, label in zip(X_val, y_val)]
test_data = [{'path': path, 'label': dataset[label]} for path, label in zip(X_test, y_test)]

with open('Validation_train.json', 'w', encoding='utf8') as f:
    json.dump(train_data, f, ensure_ascii=False)
with open('Validation_label.json', 'w', encoding='utf8') as f:
    json.dump(dataset, f, ensure_ascii=False)
with open('Validation_val.json', 'w', encoding='utf8') as f:
    json.dump(val_data, f, ensure_ascii=False)
with open('Validation_test.json', 'w', encoding='utf8') as f:
    json.dump(test_data, f, ensure_ascii=False)

print(f'Training set size: {len(train_data)}')
print(f'Validation set size: {len(val_data)}')
print(f'Test set size: {len(test_data)}')
--------------------------------------------------------------------------------
/Validation/test.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import json
import os
import random
from datetime import datetime
import cv2
import numpy as np
import torch
import torchvision
from sklearn.metrics import f1_score
from torch.nn import functional as F
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import math
import argparse
from tqdm import tqdm
import pandas as pd

"""### Set arguments"""
parser = argparse.ArgumentParser(description='Test on HUST-OBC')

parser.add_argument('--lr', '--learning-rate', default=0.015, type=float, metavar='LR', help='initial learning rate',
                    dest='lr')
parser.add_argument('--epochs', default=1000, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--batch_size', default=128, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--num_workers', default=0, type=int)
parser.add_argument('--wd', default=5e-4, type=float, metavar='W', help='weight decay')

# utils
parser.add_argument('--resume', default='./max_val_acc.pth', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--results-dir', default='test', type=str, metavar='PATH', help='path to cache (default: none)')
args = parser.parse_args()  # running in command line
if args.results_dir == '':
    args.results_dir = './cache-' + datetime.now().strftime("%Y-%m-%d-%H-%M-%S-moco")
print(args)

class TestData(Dataset):
    def __init__(self, transform=None):
        super(TestData, self).__init__()
        with open('Validation_test.json', 'r', encoding='utf8') as f:
            images = json.load(f)
        labels = images
        self.images, self.labels = images, labels
        self.transform = transform

    def __getitem__(self, item):
        # Load the image
        image = Image.open(self.images[item]['path'].replace('\\', '/'))
        # Convert grayscale images to RGB
        if image.mode == 'L':
            image = image.convert('RGB')
        # Pad the shorter side with white so the image becomes square
        width, height = image.size
        if width > height:
            dy = width - height
            yl = round(dy / 2)
            yr = dy - yl
            train_transform = transforms.Compose([
                transforms.Pad([0, yl, 0, yr], fill=(255, 255, 255), padding_mode='constant'),
            ])
        else:
            dx = height - width
            xl = round(dx / 2)
            xr = dx - xl
            train_transform = transforms.Compose([
                transforms.Pad([xl, 0, xr, 0], fill=(255, 255, 255), padding_mode='constant'),
            ])

        image = train_transform(image)
        train_transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.85233593, 0.85246795, 0.8517555], [0.31232414, 0.3122127, 0.31273854])])
        image = train_transform(image)
        label = torch.from_numpy(np.array(self.images[item]['label']))
        return image, label, self.images[item]['path'].replace('\\', '/')

    def __len__(self):
        return len(self.images)


test_dataset = TestData()
test_loader = DataLoader(test_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)

# ResNet-50 classifier over the 1,588 deciphered classes
net = torchvision.models.resnet50(pretrained=False)
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, 1588)
net = net.cuda(0)


def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)


optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, weight_decay=0.0005, momentum=0.9)
loss = nn.CrossEntropyLoss()


def accuracy(y_hat, y):
    """Compute the number of correct predictions.

    Defined in :numref:`sec_softmax_scratch`"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = torch.argmax(y_hat, dim=1)
    if len(y.shape) > 1 and y.shape[1] > 1:
        y = torch.argmax(y, dim=1)
    cmp = torch.eq(y_hat, y)
    return float(torch.sum(cmp).item())


def test(net, test_data_loader, epoch, args):
    net.eval()
    all_labels = []
    all_preds = []
    testacc, total_top5, total_num, test_bar = 0.0, 0.0, 0, tqdm(test_data_loader)
    with torch.no_grad():
        for image, label, path in test_bar:
            image, label = image.cuda(0), label.cuda(0)
            y_hat = net(image)
            _, preds = torch.max(y_hat, 1)
            all_labels.extend(label.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            total_num += image.shape[0]
            testacc += accuracy(y_hat, label)
            test_bar.set_description(
                'Test Epoch: [{}/{}], testacc: {:.6f}'.format(epoch, args.epochs, testacc / total_num))
    f1_macro = f1_score(all_labels, all_preds, average='macro')
    f1_micro = f1_score(all_labels, all_preds, average='micro')
    print(f'Macro-averaged F1 score: {f1_macro}')
    print(f'Micro-averaged F1 score: {f1_micro}')
    return testacc / total_num

results = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'lr': []}
epoch_start = 1
if args.resume != '':
    checkpoint = torch.load(args.resume)
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    epoch_start = checkpoint['epoch']
    print('Loaded from: {}'.format(args.resume))
else:
    net.apply(init_weights)

test_acc = test(net, test_loader, epoch_start, args)
print(test_acc)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# HUST-OBC
[arXiv](https://arxiv.org/abs/2401.15365)
[Dataset](https://doi.org/10.6084/m9.figshare.25040543.v3)
[hyper.ai](https://hyper.ai/datasets/33506)

Oracle Bone Character data collected by the VLRLab of HUST.
We have open-sourced the HUST-OBC dataset together with the models used to build and validate it: the Chinese OCR model, MoCo, and the ResNet-50 used for validation.

## HUST-OBC Dataset
[HUST-OBC Download](https://figshare.com/s/8a9c0420312d94fc01e3)
### Tree of our dataset
- HUST-OBC **(We have renamed HUST-OBS to HUST-OBC)**
  - deciphered
    - ID1
      - Source_ID1_Filename
      - Source_ID1_Filename
      - .....
    - ID2
      - Source_ID2_Filename
      - .....
    - ID3
      - .....
    - chinese_to_ID.json
    - ID_to_chinese.json
  - undeciphered
    - L
      - L_?_Filename
      - L_?_Filename
      - .....
    - X
      - X_?_Filename
      - .....
    - Y+H
      - Y_?_Filename
      - H_?_Filename
      - .....
  - GuoXueDaShi_1390
    - ID1
      - Source_ID1_Filename
      - Source_ID1_Filename
      - .....
    - ID2
      - Source_ID2_Filename
      - .....
    - ID3
      - .....
    - chinese_to_ID.json
    - ID_to_chinese.json

Source: 'X' stands for the *New Compilation of Oracle Bone Scripts*, 'L' for the *Oracle Bone Script: Six-Digit Numerical Code*, 'G' for the GuoXueDaShi website, 'Y' for the YinQiWenYuan website, and 'H' for the HWOBC dataset; these letters identify where each sample comes from.
## Environment
```bash
conda create -n HUST-OBC python=3.10
conda activate HUST-OBC
git clone https://github.com/Pengjie-W/HUST-OBC.git
cd HUST-OBC
pip install -r requirements.txt
```
## Instructions for use
To use MoCo or Validation, you need to download HUST-OBC; you can then use the trained models directly for prediction. If you want to use the Chinese OCR model, please download the OCR dataset and the corresponding model. After downloading, organize the data as follows.

- Your_dataroot
  - [HUST-OBC](https://figshare.com/s/8a9c0420312d94fc01e3)
    - deciphered
    - ...
  - MoCo
    - [model_last.pth](https://figshare.com/s/30c206b1d1f1870ae76f)
    - ...
  - OCR
    - [OCR_Dataset](https://figshare.com/s/b03be2bccdd867b73e5f)
    - [model_last.pth](https://figshare.com/s/7ec755b4ba77c6994ed2)
    - ...
  - Validation
    - [max_val_acc.pth](https://figshare.com/s/4149c5c7f52e0f99e366)
    - ...

## Chinese OCR
The code for training and testing (usage) is provided in the OCR folder. It recognizes 88,899 classes of Chinese characters. [Model download](https://figshare.com/s/7ec755b4ba77c6994ed2). Category numbers and their corresponding Chinese characters are stored in OCR/label.json. We provide the model and code with α set to 0.
[OCR Dataset download](https://figshare.com/s/b03be2bccdd867b73e5f).

You can use [train.py](OCR/train.py) for fine-tuning or retraining. [Chinese_to_ID.json](OCR/Chinese_to_ID.json) and [ID_to_Chinese.json](OCR/ID_to_Chinese.json) store the mappings between OCR dataset category IDs and Chinese characters. [Dataset establishment.py](OCR/Dataset%20establishment.py) generates the training dataset [OCR_train.json](OCR/OCR_train.json). Once the model is downloaded, you can use [test.py](OCR/test.py) directly for testing; it comes with two example test images, Chinese character images cropped from PDFs. Images with a white background work best. [use.json](OCR/use.json) lists the paths of the images to recognize, and the recognized characters are written to [result.json](OCR/result.json).
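A minimal sketch of that round trip, using the two bundled example images (run `python test.py` inside the OCR folder between the two steps):

```python
import json

# Queue up the images to recognize; white backgrounds work best.
with open('use.json', 'w', encoding='utf8') as f:
    json.dump(['test_pic1.jpg', 'test_pic2.png'], f, ensure_ascii=False)

# ... run `python test.py` here ...

# result.json maps each image path to its top-k recognized characters.
with open('result.json', 'r', encoding='utf8') as f:
    result = json.load(f)
for path, chars in result.items():
    print(path, chars)  # e.g. test_pic1.jpg ['五']
```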

## MoCo
The code for training and testing (usage) is provided in the MoCo folder. [Model download](https://figshare.com/s/30c206b1d1f1870ae76f).

You can use [train.py](MoCo/train.py) for fine-tuning or retraining; [Dataset establishment.py](MoCo/Dataset%20establishment.py) generates the training dataset [MOCO_train.json](MoCo/MOCO_train.json). After downloading the MoCo model, [test.py](MoCo/test.py) runs MoCo over the 1,781 unmerged oracle bone categories, looking for the first sample from another category whose similarity to a query exceeds args.w, in order to surface similarities between different categories. The results are saved in [result.json](MoCo/result.json).
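As a rough illustration of that criterion, a hedged sketch (the helper and the `feats`/`labels` names are illustrative, not the actual interface of MoCo/test.py):

```python
import torch
import torch.nn.functional as F

def first_similar_other_category(feats, labels, query_idx, w):
    """Return the first sample from another category whose cosine similarity to the query exceeds w."""
    feats = F.normalize(feats, dim=1)   # unit-normalize the MoCo embeddings
    sims = feats @ feats[query_idx]     # cosine similarities to the query sample
    for idx in torch.argsort(sims, descending=True).tolist():
        if labels[idx] != labels[query_idx] and sims[idx].item() > w:
            return idx, sims[idx].item()
    return None
```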

## Validation
The code for training and testing (usage) is provided in the Validation folder. [Model download](https://figshare.com/s/4149c5c7f52e0f99e366).

[Dataset establishment.py](Validation/Dataset%20establishment.py) splits the dataset. Since the classification model cannot recognize unseen categories, every category with only one sample is allocated to the training set. [Validation_test.json](Validation/Validation_test.json), [Validation_val.json](Validation/Validation_val.json) and [Validation_train.json](Validation/Validation_train.json) are the test, validation and training sets, split in a 1:1:8 ratio. [standard deviation.py](Validation/standard%20deviation.py) computes the per-channel mean and standard deviation of the training set.

You can use [train.py](Validation/train.py) for fine-tuning or retraining. Once the model is downloaded, you can use [test.py](Validation/test.py) to evaluate the test set, reaching an accuracy of 94.6%. [log.csv](Validation/log.csv) records the training and test accuracy for each epoch.
[Validation_label.json](Validation/Validation_label.json) stores the mapping between dataset category IDs and classification label indices.
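Validation_label.json maps each dataset category ID to the classifier's output index, so inverting it recovers the category for a prediction (a small illustrative snippet; `pred` stands in for an argmax over the ResNet-50 logits):

```python
import json

with open('Validation_label.json', 'r', encoding='utf8') as f:
    category_to_index = json.load(f)   # {dataset category ID: classifier index}
index_to_category = {v: k for k, v in category_to_index.items()}

pred = 0                               # e.g. the argmax of the model output
print(index_to_category[pred])         # the corresponding dataset category ID
```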
--------------------------------------------------------------------------------
/OCR/test.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import json
import os
import random
from datetime import datetime

import numpy as np
import torch
from torch.nn import functional as F
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import math
import argparse
from tqdm import tqdm
import pandas as pd

"""### Set arguments"""

parser = argparse.ArgumentParser(description='Test on Chinese OCR Dataset')
# utils
parser.add_argument('--resume', default='model_last.pth', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--results-dir', default='test', type=str, metavar='PATH', help='path to cache (default: none)')
parser.add_argument('--k', default=1, type=int)  # display the top-k Chinese characters by probability
parser.add_argument('--batch_size', default=64, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--num_workers', default=0, type=int)
args = parser.parse_args()  # running in command line

if args.results_dir == '':
    args.results_dir = './cache-' + datetime.now().strftime("%Y-%m-%d-%H-%M-%S-moco")
print(args)

with open('ID_to_Chinese.json', 'r', encoding='utf8') as f:
    data = json.load(f)

class TestData(Dataset):
    def __init__(self, transform=None):
        super(TestData, self).__init__()
        with open('use.json', 'r') as f:
            images = json.load(f)
        labels = images
        self.images, self.labels = images, labels
        self.transform = transform

    def __getitem__(self, item):
        # Load the image
        image = Image.open(self.images[item])
        # Convert grayscale images to RGB
        if image.mode == 'L':
            image = image.convert('RGB')
        # Pad the shorter side with white so the image becomes square
        width, height = image.size
        if width > height:
            dy = width - height
            yl = round(dy / 2)
            yr = dy - yl
            train_transform = transforms.Compose([
                transforms.Pad([0, yl, 0, yr], fill=(255, 255, 255), padding_mode='constant'),
            ])
        else:
            dx = height - width
            xl = round(dx / 2)
            xr = dx - xl
            train_transform = transforms.Compose([
                transforms.Pad([xl, 0, xr, 0], fill=(255, 255, 255), padding_mode='constant'),
            ])

        image = train_transform(image)
        train_transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.7760929, 0.7760929, 0.7760929], [0.39767382, 0.39767382, 0.39767382])])
        image = train_transform(image)
        return image, self.images[item]

    def __len__(self):
        return len(self.images)

test_dataset = TestData()
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)


class Residual(nn.Module):
    # Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions
    def __init__(self, input_channels, min_channels, num_channels,
                 use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, min_channels, kernel_size=1)
        self.conv2 = nn.Conv2d(min_channels, min_channels,
                               kernel_size=3, padding=1, stride=strides)
        self.conv3 = nn.Conv2d(min_channels, num_channels, kernel_size=1)
        if use_1x1conv:
            self.conv4 = nn.Conv2d(input_channels, num_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv4 = None
        self.bn1 = nn.BatchNorm2d(min_channels)
        self.bn2 = nn.BatchNorm2d(min_channels)
        self.bn3 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        Y = self.bn3(self.conv3(Y))
        if self.conv4:
            X = self.conv4(X)
        Y += X
        return F.relu(Y)

b1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64), nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))


def resnet_block(input_channels, min_channels, num_channels, num_residuals, stride,
                 first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(input_channels, min_channels, num_channels,
                                use_1x1conv=True, strides=stride))
        elif first_block and i == 0:
            blk.append(Residual(input_channels, min_channels, num_channels, use_1x1conv=True))
        else:
            blk.append(Residual(num_channels, min_channels, num_channels))
    return blk


# ResNet-50-style backbone with an 88,899-way classification head
b2 = nn.Sequential(*resnet_block(64, 64, 256, 3, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(256, 128, 512, 4, 2))
b4 = nn.Sequential(*resnet_block(512, 256, 1024, 6, 2))
b5 = nn.Sequential(*resnet_block(1024, 512, 2048, 2, 2))
net = nn.Sequential(b1, b2, b3, b4, b5,
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(), nn.Linear(2048, 88899))

net = net.cuda(0)

def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)


def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate based on schedule"""
    lr = args.lr
    lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(y_hat, y):
    """Compute the number of correct predictions.

    Defined in :numref:`sec_softmax_scratch`"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = torch.argmax(y_hat, dim=1)
    if len(y.shape) > 1 and y.shape[1] > 1:
        y = torch.argmax(y, dim=1)
    cmp = torch.eq(y_hat, y)
    return float(torch.sum(cmp).item())


def test(net, test_data_loader):
    net.eval()
    testacc, total_top5, total_num, test_bar = 0.0, 0.0, 0, tqdm(test_data_loader)
    with torch.no_grad():
        pathlist = []
        labellist = []
        for image, path in test_bar:
            image = image.cuda(0)
            y_hat = net(image)
            # Keep the indices of the top-k most probable classes
            y_hat = torch.topk(y_hat, args.k, dim=1)[1]
            label = y_hat.tolist()
            labellist = labellist + label
            path = list(path)
            pathlist = pathlist + path

    # Map class indices to Chinese characters via ID_to_Chinese.json
    dataset = {}
    for i in range(len(pathlist)):
        path_label = []
        path = pathlist[i]
        label = labellist[i]
        for j in label:
            j = str(j).zfill(5)
            path_label.append(data[j])
        dataset[path] = path_label
    with open('result.json', 'w', encoding='utf8') as f:
        json.dump(dataset, f, ensure_ascii=False)
    return


results = {'train_loss': [], 'train_acc': [], 'lr': []}
epoch_start = 1
if args.resume != '':
    checkpoint = torch.load(args.resume)
    net.load_state_dict(checkpoint['state_dict'])
    # optimizer.load_state_dict(checkpoint['optimizer'])
    epoch_start = checkpoint['epoch'] + 1
    print('Loaded from: {}'.format(args.resume))
else:
    net.apply(init_weights)

test_acc = test(net, test_loader)
--------------------------------------------------------------------------------
/OCR/train.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import json
import os
import random
from datetime import datetime
import cv2
import numpy as np
import torch
from torch.nn import functional as F
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import math
import argparse
from tqdm import tqdm
import pandas as pd
# nohup python train.py > output.log 2>&1 &
"""### Set arguments"""
parser = argparse.ArgumentParser(description='Train on Chinese OCR Dataset')

parser.add_argument('--lr', '--learning-rate', default=0.00015, type=float, metavar='LR', help='initial learning rate',
                    dest='lr')
parser.add_argument('--epochs', default=1800, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--batch_size', default=128, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--num_workers', default=9, type=int)
parser.add_argument('--wd', default=5e-4, type=float, metavar='W', help='weight decay')
# utils
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--results-dir', default='model', type=str, metavar='PATH', help='path to cache (default: none)')
args = parser.parse_args()  # running in command line
if args.results_dir == '':
    args.results_dir = './cache-' + datetime.now().strftime("%Y-%m-%d-%H-%M-%S-moco")
print(args)


class RandomGaussianBlur(object):
    def __init__(self, p=0.5, min_kernel_size=3, max_kernel_size=15, min_sigma=0.1, max_sigma=1.0):
        self.p = p
        self.min_kernel_size = min_kernel_size
        self.max_kernel_size = max_kernel_size
        self.min_sigma = min_sigma
        self.max_sigma = max_sigma

    def __call__(self, img):
        # With probability p, blur with a random odd kernel size and sigma
        if random.random() < self.p and self.min_kernel_size < self.max_kernel_size:
            kernel_size = random.randrange(self.min_kernel_size, self.max_kernel_size + 1, 2)
            sigma = random.uniform(self.min_sigma, self.max_sigma)
            return transforms.functional.gaussian_blur(img, kernel_size, sigma)
        else:
            return img

def jioayan(image):
    # With probability 0.5, add salt-and-pepper noise
    if np.random.random() < 0.5:
        image1 = np.array(image)
        salt_vs_pepper_ratio = np.random.uniform(0, 0.4)
        amount = np.random.uniform(0, 0.006)
        num_salt = np.ceil(amount * image1.size / 3 * salt_vs_pepper_ratio)
        num_pepper = np.ceil(amount * image1.size / 3 * (1.0 - salt_vs_pepper_ratio))

        # Place the salt and pepper noise at random positions
        coords_salt = [np.random.randint(0, i - 1, int(num_salt)) for i in image1.shape]
        coords_pepper = [np.random.randint(0, i - 1, int(num_pepper)) for i in image1.shape]
        image1[coords_salt[0], coords_salt[1], :] = 255
        image1[coords_pepper[0], coords_pepper[1], :] = 0
        image = Image.fromarray(image1)
    return image

def pengzhang(image):
    # Randomly erode, dilate, or leave the glyph unchanged (1/3 probability each)
    random_value = random.random() * 3
    if random_value < 1:
        he = random.randint(1, 3)
        kernel = np.ones((he, he), np.uint8)
        image = cv2.erode(image, kernel, iterations=1)
    elif random_value < 2:
        he = random.randint(1, 3)
        kernel = np.ones((he, he), np.uint8)
        image = cv2.dilate(image, kernel, iterations=1)
    return image

class TrainData(Dataset):
    def __init__(self, transform=None):
        super(TrainData, self).__init__()
        with open('OCR_train.json', 'r') as f:
            images = json.load(f)
        labels = images
        self.images, self.labels = images, labels
        self.transform = transform

    def __getitem__(self, item):
        # Load the image
        image = Image.open(self.images[item]['path'].replace('\\', '/'))
        # Convert grayscale images to RGB
        if image.mode == 'L':
            image = image.convert('RGB')
        # Randomly jitter the glyph size, then pad it to 128x128 with white
        x, y = 72, 72
        sizey, sizex = 129, 129
        if y < 128:
            while sizey > 128 or sizey < 16:
                sizey = round(random.gauss(y, 30))
        if x < 128:
            while sizex > 128 or sizex < 16:
                sizex = round(random.gauss(x, 30))
        dx = 128 - sizex  # remaining width to pad
        dy = 128 - sizey
        if dx > 0:
            xl = -1
            while xl > dx or xl < 0:
                xl = round(dx / 2)
                xl = round(random.gauss(xl, 10))
        else:
            xl = 0
        if dy > 0:
            yl = -1
            while yl > dy or yl < 0:
                yl = round(dy / 2)
                yl = round(random.gauss(yl, 10))
        else:
            yl = 0
        yr = dy - yl
        xr = dx - xl
        image = jioayan(image)
        image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        image = pengzhang(image)
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        random_gaussian_blur = RandomGaussianBlur()
        image = random_gaussian_blur(image)
        train_transform = transforms.Compose([
            transforms.Resize((sizey, sizex)),
            transforms.Pad([xl, yl, xr, yr], fill=(255, 255, 255), padding_mode='constant'),
            transforms.RandomRotation(degrees=(-15, 15), center=(round(64), round(64)), fill=(255, 255, 255)),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize([0.7760929, 0.7760929, 0.7760929], [0.39767382, 0.39767382, 0.39767382])])
        image = train_transform(image)
        label = torch.from_numpy(np.array(self.images[item]['label']))
        return image, label

    def __len__(self):
        return len(self.images)


train_dataset = TrainData()
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)


class Residual(nn.Module):
    # Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions
    def __init__(self, input_channels, min_channels, num_channels,
                 use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, min_channels, kernel_size=1)
        self.conv2 = nn.Conv2d(min_channels, min_channels,
                               kernel_size=3, padding=1, stride=strides)
        self.conv3 = nn.Conv2d(min_channels, num_channels, kernel_size=1)
        if use_1x1conv:
            self.conv4 = nn.Conv2d(input_channels, num_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv4 = None
        self.bn1 = nn.BatchNorm2d(min_channels)
        self.bn2 = nn.BatchNorm2d(min_channels)
        self.bn3 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        Y = self.bn3(self.conv3(Y))
        if self.conv4:
            X = self.conv4(X)
        Y += X
        return F.relu(Y)

b1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64), nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))


def resnet_block(input_channels, min_channels, num_channels, num_residuals, stride,
                 first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(input_channels, min_channels, num_channels,
                                use_1x1conv=True, strides=stride))
        elif first_block and i == 0:
            blk.append(Residual(input_channels, min_channels, num_channels, use_1x1conv=True))
        else:
            blk.append(Residual(num_channels, min_channels, num_channels))
    return blk


# ResNet-50-style backbone with an 88,899-way classification head
b2 = nn.Sequential(*resnet_block(64, 64, 256, 3, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(256, 128, 512, 4, 2))
b4 = nn.Sequential(*resnet_block(512, 256, 1024, 6, 2))
b5 = nn.Sequential(*resnet_block(1024, 512, 2048, 2, 2))
net = nn.Sequential(b1, b2, b3, b4, b5,
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(), nn.Linear(2048, 88899))

net = net.cuda(0)

def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)


optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, weight_decay=0.0005, momentum=0.9)
loss = nn.CrossEntropyLoss()


def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate based on schedule"""
    lr = args.lr
    lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(y_hat, y):
    """Compute the number of correct predictions.

    Defined in :numref:`sec_softmax_scratch`"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = torch.argmax(y_hat, dim=1)
    if len(y.shape) > 1 and y.shape[1] > 1:
        y = torch.argmax(y, dim=1)
    cmp = torch.eq(y_hat, y)
    return float(torch.sum(cmp).item())

def train(net, data_loader, train_optimizer, epoch, args):
    net.train()
    adjust_learning_rate(optimizer, epoch, args)
    total_loss, total_num, trainacc, train_bar = 0.0, 0, 0.0, tqdm(data_loader)
    for image, label in train_bar:
        image, label = image.cuda(0), label.cuda(0)
        label = label.long()
        y_hat = net(image)

        train_optimizer.zero_grad()
        l = loss(y_hat, label)
        l.backward()
        train_optimizer.step()
        trainacc += accuracy(y_hat, label)
        total_num += image.shape[0]
        total_loss += l.item() * image.shape[0]
        train_bar.set_description(
            'Train Epoch: [{}/{}], lr: {:.6f}, Loss: {:.4f}, trainacc: {:.6f}'.format(epoch, args.epochs,
                                                                                      optimizer.param_groups[0]['lr'],
                                                                                      total_loss / total_num,
                                                                                      trainacc / total_num))

    return total_loss / total_num, trainacc / total_num

def test(net, test_data_loader, epoch, args):
    net.eval()
    testacc, total_top5, total_num, test_bar = 0.0, 0.0, 0, tqdm(test_data_loader)
    with torch.no_grad():
        for image, label in test_bar:
            image, label = image.cuda(0), label.cuda(0)
            y_hat = net(image)
            total_num += image.shape[0]
            testacc += accuracy(y_hat, label)
            test_bar.set_description(
                'Test Epoch: [{}/{}], testacc: {:.6f}'.format(epoch, args.epochs, testacc / total_num))
    return testacc / total_num

results = {'train_loss': [], 'train_acc': [], 'lr': []}
epoch_start = 1
if args.resume != '':
    checkpoint = torch.load(args.resume)
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    epoch_start = checkpoint['epoch'] + 1
    print('Loaded from: {}'.format(args.resume))
else:
    net.apply(init_weights)

if not os.path.exists(args.results_dir):
    os.mkdir(args.results_dir)
with open(args.results_dir + '/args.json', 'w') as fid:
    json.dump(args.__dict__, fid, indent=2)
for epoch in range(epoch_start, args.epochs + 1):
    train_loss, train_acc = train(net, train_loader, optimizer, epoch, args)
    results['train_loss'].append(train_loss)
    results['train_acc'].append(train_acc)
    results['lr'].append(args.lr * 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)))
    data_frame = pd.DataFrame(data=results, index=range(epoch_start, epoch + 1))
    data_frame.to_csv(args.results_dir + '/log.csv', index_label='epoch')
    # save model
    torch.save({'epoch': epoch, 'state_dict': net.state_dict(), 'optimizer': optimizer.state_dict(), },
               args.results_dir + '/model_last.pth')
--------------------------------------------------------------------------------
/Validation/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import json
3 | import os
4 | import random
5 | from datetime import datetime
6 | import cv2
7 | import numpy as np
8 | import torch
9 | import torchvision
10 | from sklearn.metrics import f1_score
11 | from torch.nn import functional as F
12 | from PIL import Image
13 | from torch import nn
14 | from torch.utils.data import DataLoader, Dataset
15 | from torchvision import transforms
16 | import math
17 | import argparse
18 | from tqdm import tqdm
19 | import pandas as pd
20 |
21 |
22 |
23 | """### Set arguments"""
24 | parser = argparse.ArgumentParser(description='Train on HUST-OBC')
25 |
26 | parser.add_argument('--lr', '--learning-rate', default=0.015, type=float, metavar='LR', help='initi'
27 | 'al learning rate',
28 | dest='lr')
29 | parser.add_argument('--epochs', default=600, type=int, metavar='N', help='number of total epochs to run')
30 | parser.add_argument('--batch_size', default=128, type=int, metavar='N', help='mini-batch size')
31 | parser.add_argument('--num_workers', default=24, type=int)
32 | parser.add_argument('--wd', default=5e-4, type=float, metavar='W', help='weight decay')
33 |
34 | # utils
35 | parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
36 | parser.add_argument('--results-dir', default='output', type=str, metavar='PATH', help='path to cache (default: none)')
37 | parser.add_argument('--checkpoint_freq', type=int, default=100)
38 | parser.add_argument('--seed', type=int, default=42)
39 | args = parser.parse_args() # running in command line
40 | if args.results_dir == '':
41 | args.results_dir = './cache-' + datetime.now().strftime("%Y-%m-%d-%H-%M-%S-moco")
42 | print(args)
43 | args = parser.parse_args() # running in command line
44 | seed=args.seed
45 | torch.manual_seed(seed)
46 | np.random.seed(seed)
47 | random.seed(seed)
48 | class RandomGaussianBlur(object):
49 | def __init__(self, p=0.5, min_kernel_size=3, max_kernel_size=15, min_sigma=0.1, max_sigma=1.0):
50 | self.p = p
51 | self.min_kernel_size = min_kernel_size
52 | self.max_kernel_size = max_kernel_size
53 | self.min_sigma = min_sigma
54 | self.max_sigma = max_sigma
55 |
56 | def __call__(self, img):
57 | if random.random() < self.p and self.min_kernel_size < self.max_kernel_size:
58 | kernel_size = random.randrange(self.min_kernel_size, self.max_kernel_size + 1, 2)
59 | sigma = random.uniform(self.min_sigma, self.max_sigma)
60 | return transforms.functional.gaussian_blur(img, kernel_size, sigma)
61 | else:
62 | return img
63 |
64 | # nohup python train_new.py > output.log 2>&1 &^C
65 | def jioayan(image):
66 | if np.random.random() < 0.5:
67 | image1 = np.array(image)
68 | # 添加椒盐噪声
69 | salt_vs_pepper_ratio = np.random.uniform(0, 0.4)
70 | amount = np.random.uniform(0, 0.006)
71 | num_salt = np.ceil(amount * image1.size / 3 * salt_vs_pepper_ratio)
72 | num_pepper = np.ceil(amount * image1.size / 3 * (1.0 - salt_vs_pepper_ratio))
73 |
74 | # 在随机位置生成椒盐噪声
75 | coords_salt = [np.random.randint(0, i - 1, int(num_salt)) for i in image1.shape]
76 | coords_pepper = [np.random.randint(0, i - 1, int(num_pepper)) for i in image1.shape]
77 | # image1[coords_salt] = 255
78 | image1[coords_salt[0], coords_salt[1], :] = 255
79 | image1[coords_pepper[0], coords_pepper[1], :] = 0
80 | image = Image.fromarray(image1)
81 | return image
82 |
83 |
84 | def pengzhang(image):
85 | # 生成一个0到2之间的随机数
86 | random_value = random.random() * 3
87 |
88 | if random_value < 1: # 1/3的概率进行加法操作
89 | he = random.randint(1, 3)
90 | kernel = np.ones((he, he), np.uint8)
91 | image = cv2.erode(image, kernel, iterations=1)
92 | elif random_value < 2: # 1/3的概率进行除法操作
93 | he = random.randint(1, 3) # 生成一个1到10之间的随机整数作为除数
94 | kernel = np.ones((he, he), np.uint8)
95 | image = cv2.dilate(image, kernel, iterations=1)
96 | return image
97 |
98 |
99 | class TrainData(Dataset):
100 | def __init__(self, transform=None):
101 | super(TrainData, self).__init__()
102 | with open('Validation_train.json', 'r',encoding='utf8') as f:
103 | images = json.load(f)
104 | labels = images
105 | self.images, self.labels = images, labels
106 | self.transform = transform
107 |
108 | def __getitem__(self, item):
109 | # 读取图片
110 | image = Image.open(self.images[item]['path'].replace('\\','/'))
111 | # 转换
112 | if image.mode == 'L':
113 | image = image.convert('RGB')
114 | image_width, image_height = image.size
115 | if image_width > image_height:
116 | x = 72
117 | y = round(image_height / image_width * 72)
118 | # x, y = 72,72
119 | else:
120 | y = 72
121 | x = round(image_width / image_height * 72)
122 | sizey, sizex = 129, 129
123 | if y < 128:
124 | while sizey > 128 or sizey < 16:
125 | sizey = round(random.gauss(y, 30))
126 | if x < 128:
127 | while sizex > 128 or sizex < 16:
128 | sizex = round(random.gauss(x, 30))
129 | dx = 128 - sizex # 差值
130 | dy = 128 - sizey
131 | if dx > 0:
132 | xl = -1
133 | while xl > dx or xl < 0:
134 | xl = round(dx / 2)
135 | xl = round(random.gauss(xl, 10))
136 | else:
137 | xl = 0
138 | if dy > 0:
139 | yl = -1
140 | while yl > dy or yl < 0:
141 | yl = round(dy / 2)
142 | yl = round(random.gauss(yl, 10))
143 | else:
144 | yl = 0
145 | yr = dy - yl
146 | xr = dx - xl
147 | image = jioayan(image)
148 | image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
149 | image = pengzhang(image)
150 | image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
151 | random_gaussian_blur = RandomGaussianBlur()
152 | image = random_gaussian_blur(image)
153 | train_transform = transforms.Compose([
154 | transforms.Resize((sizey,sizex)),
155 | transforms.RandomHorizontalFlip(p=0.5),
156 | transforms.Pad([xl, yl, xr, yr], fill=(255, 255, 255), padding_mode='constant'),
157 | transforms.RandomRotation(degrees=(-15, 15), center=(round(64), round(64)), fill=(255, 255, 255)),
158 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
159 | transforms.RandomGrayscale(p=0.2),
160 | transforms.ToTensor(),
161 | transforms.Normalize([0.85233593, 0.85246795, 0.8517555], [0.31232414, 0.3122127, 0.31273854])])
162 | image = train_transform(image)
163 | label = torch.from_numpy(np.array(self.images[item]['label']))
164 | return image, label
165 |
166 | def __len__(self):
167 | return len(self.images)
168 |
169 |
170 | class ValData(Dataset):
171 | def __init__(self, transform=None):
172 | super(ValData, self).__init__()
173 | with open('Validation_val.json', 'r',encoding='utf8') as f:
174 | images = json.load(f)
175 | labels = images
176 | self.images, self.labels = images, labels
177 | self.transform = transform
178 |
179 | def __getitem__(self, item):
180 | # 读取图片
181 | image = Image.open(self.images[item]['path'].replace('\\','/'))
182 | # 转换
183 | if image.mode == 'L':
184 | image = image.convert('RGB')
185 | width, height = image.size
186 | if width>height:
187 | dy = width - height
188 |
189 | yl = round(dy / 2)
190 | yr = dy - yl
191 | train_transform = transforms.Compose([
192 | transforms.Pad([0, yl, 0, yr], fill=(255, 255, 255), padding_mode='constant'),
193 | ])
194 | else:
195 | dx = height - width
196 | xl = round(dx / 2)
197 | xr = dx - xl
198 | train_transform = transforms.Compose([
199 | transforms.Pad([xl, 0, xr, 0], fill=(255, 255, 255), padding_mode='constant'),
200 | ])
201 |
202 | image = train_transform(image)
203 | train_transform = transforms.Compose([
204 | transforms.Resize((128, 128)),
205 | # transforms.CenterCrop(224),
206 | transforms.ToTensor(),
207 | transforms.Normalize([0.85233593, 0.85246795, 0.8517555], [0.31232414, 0.3122127, 0.31273854])])
208 | image = train_transform(image)
209 | label = torch.from_numpy(np.array(self.images[item]['label']))
210 | return image, label
211 |
212 | def __len__(self):
213 | return len(self.images)
214 |
215 |
216 | train_dataset = TrainData()
217 | train_loader = DataLoader(train_dataset, shuffle=True, batch_size = args.batch_size, num_workers=args.num_workers, pin_memory=True)
218 |
219 |
220 | val_dataset = ValData()
221 | val_loader = DataLoader(val_dataset, shuffle=True, batch_size = args.batch_size, num_workers=args.num_workers, pin_memory=True)
222 |
223 |
224 | net = torchvision.models.resnet50(pretrained=False)
225 | num_ftrs = net.fc.in_features
226 | net.fc = nn.Linear(num_ftrs, 1588)
227 | net = net.cuda(0)
228 |
229 | def init_weights(m):
230 | if type(m) == nn.Linear or type(m) == nn.Conv2d:
231 | nn.init.xavier_uniform_(m.weight)
232 |
233 |
234 | optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, weight_decay=0.0005, momentum=0.9)
235 | loss = nn.CrossEntropyLoss()
236 |
237 |
238 | def adjust_learning_rate(optimizer, epoch, args):
239 | """Decay the learning rate based on schedule"""
240 | lr = args.lr
241 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
242 | for param_group in optimizer.param_groups:
243 | param_group['lr'] = lr
244 |
245 |
246 | def accuracy(y_hat, y):
247 | """Compute the number of correct predictions.
248 |
249 | Defined in :numref:`sec_softmax_scratch`"""
250 | if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
251 | y_hat = torch.argmax(y_hat, dim=1)
252 | if len(y.shape) > 1 and y.shape[1] > 1:
253 | y = torch.argmax(y, dim=1)
254 | cmp = torch.eq(y_hat, y)
255 | return float(torch.sum(cmp).item())
256 |
257 |
258 | def train(net, data_loader, train_optimizer, epoch, args):
259 | net.train()
260 | adjust_learning_rate(optimizer, epoch, args)
261 | total_loss, total_num, trainacc, train_bar = 0.0, 0, 0.0, tqdm(data_loader,ncols=100)
262 | all_labels = []
263 | all_preds = []
264 | for image, label in train_bar:
265 | image, label = image.cuda(0), label.cuda(0)
266 | label = label.long()
267 | y_hat = net(image)
268 | _, preds = torch.max(y_hat, 1)
269 | all_labels.extend(label.cpu().numpy())
270 | all_preds.extend(preds.cpu().numpy())
271 | train_optimizer.zero_grad()
272 | l = loss(y_hat, label)
273 | l.backward()
274 | train_optimizer.step()
275 | trainacc += accuracy(y_hat, label)
276 | # total_num += data_loader.abatch_size
277 | total_num += image.shape[0]
278 | total_loss += l.item() * image.shape[0]
279 | train_bar.set_description(
280 | 'Train Epoch: [{}/{}], lr: {:.6f}, Loss: {:.4f}, trainacc: {:.6f}'.format(epoch, args.epochs,
281 | optimizer.param_groups[0]['lr'],
282 | total_loss / total_num,
283 | trainacc / total_num))
284 | # 计算 F1 分数
285 | f1_macro = f1_score(all_labels, all_preds, average='macro')
286 | f1_micro = f1_score(all_labels, all_preds, average='micro')
287 | # print(f'Macro-averaged F1 score: {f1_macro}')
288 | # print(f'Micro-averaged F1 score: {f1_micro}')
289 |
290 | return total_loss / total_num, trainacc / total_num,f1_macro,f1_micro
291 |
292 |
293 | def val(net, val_data_loader, epoch, args):
294 | net.eval()
295 | valacc, total_top5, total_num, val_bar = 0.0, 0.0, 0, tqdm(val_data_loader)
296 | all_labels = []
297 | all_preds = []
298 | with torch.no_grad():
299 | for image, label in val_bar:
300 | image, label = image.cuda(0), label.cuda(0)
301 | y_hat = net(image)
302 | _, preds = torch.max(y_hat, 1)
303 | all_labels.extend(label.cpu().numpy())
304 | all_preds.extend(preds.cpu().numpy())
305 | total_num+=image.shape[0]
306 | valacc += accuracy(y_hat, label)
307 | val_bar.set_description(
308 | 'Val Epoch: [{}/{}], valacc: {:.6f}'.format(epoch, args.epochs, valacc / total_num))
309 | # 计算 F1 分数
310 | f1_macro = f1_score(all_labels, all_preds, average='macro')
311 | f1_micro = f1_score(all_labels, all_preds, average='micro')
312 | # print(f'Macro-averaged F1 score: {f1_macro}')
313 | # print(f'Micro-averaged F1 score: {f1_micro}')
314 | return valacc / total_num,f1_macro,f1_micro
315 |
316 |
317 | # results = {'train_loss': [], 'train_acc': [], 'val_acc': []}
318 | results = {'train_loss': [], 'train_acc': [],'train_f1_macro':[],'train_f1_micro':[],'val_acc': [],'val_f1_macro':[],'val_f1_micro':[], 'lr': []}
319 | epoch_start = 1
320 | if args.resume != '':
321 | checkpoint = torch.load(args.resume)
322 | net.load_state_dict(checkpoint['state_dict'])
323 | optimizer.load_state_dict(checkpoint['optimizer'])
324 | epoch_start = checkpoint['epoch'] + 1
325 | print('Loaded from: {}'.format(args.resume))
326 | else:
327 | net.apply(init_weights)
328 |
329 | if not os.path.exists(args.results_dir):
330 | os.mkdir(args.results_dir)
331 | with open(args.results_dir + '/args.json', 'w') as fid:
332 | json.dump(args.__dict__, fid, indent=2)
333 | max_val_acc=0
334 | max_val_f1_macro=0
335 | max_val_f1_micro=0
336 | for epoch in range(epoch_start, args.epochs + 1):
337 | train_loss, train_acc,train_f1_macro,train_f1_micro = train(net, train_loader, optimizer, epoch, args)
338 | results['train_loss'].append(train_loss)
339 | results['train_acc'].append(train_acc)
340 | results['train_f1_macro'].append(train_f1_macro)
341 | results['train_f1_micro'].append(train_f1_micro)
342 | val_acc,val_f1_macro,val_f1_micro = val(net, val_loader, epoch, args)
343 | results['val_acc'].append(val_acc)
344 | results['val_f1_macro'].append(val_f1_macro)
345 | results['val_f1_micro'].append(val_f1_micro)
346 | results['lr'].append(args.lr * 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)))
347 | # save statistics
348 | data_frame = pd.DataFrame(data=results, index=range(epoch_start, epoch + 1))
349 | data_frame.to_csv(args.results_dir + '/log.csv', index_label='epoch')
350 | if epoch % args.checkpoint_freq == 0:
351 | checkpoint_name = f'checkpoint_ep{epoch:04}.pth'
352 | # save model
353 | torch.save({'epoch': epoch, 'state_dict': net.state_dict(), 'optimizer': optimizer.state_dict(), },
354 | args.results_dir + '/'+checkpoint_name)
355 | if epoch > 300 and max_val_acc < val_acc:
356 | max_val_acc = val_acc
357 | if epoch > 300 and max_val_f1_macro < val_f1_macro:
358 | max_val_f1_macro = val_f1_macro
359 | if epoch > 300 and max_val_f1_micro < val_f1_micro:
360 | max_val_f1_micro = val_f1_micro
--------------------------------------------------------------------------------
/MoCo/test.py:
--------------------------------------------------------------------------------
145 | norm_layer = partial(SplitBatchNorm, num_splits=bn_splits) if bn_splits > 1 else nn.BatchNorm2d
146 | resnet_arch = getattr(resnet, arch)
147 | net = resnet_arch(num_classes=feature_dim, norm_layer=norm_layer)
148 |
149 | self.net = []
150 | for name, module in net.named_children():
151 | if name == 'conv1':
152 | module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
153 | if isinstance(module, nn.MaxPool2d):
154 | continue
155 | if isinstance(module, nn.Linear):
156 | self.net.append(nn.Flatten(1))
157 | self.net.append(module)
158 |
159 | self.net = nn.Sequential(*self.net)
160 |
161 | def forward(self, x):
162 | x = self.net(x)
163 | # note: not normalized here
164 | return x
165 |
166 | """### Define MoCo wrapper"""
167 |
168 | class ModelMoCo(nn.Module):
169 | def __init__(self, dim=128, K=4096, m=0.99, T=0.1, arch='resnet18', bn_splits=8, symmetric=True):
170 | super(ModelMoCo, self).__init__()
171 |
172 | self.K = K
173 | self.m = m
174 | self.T = T
175 | self.symmetric = symmetric
176 |
177 | # create the encoders
178 | self.encoder_q = ModelBase(feature_dim=dim, arch=arch, bn_splits=bn_splits)
179 | self.encoder_k = ModelBase(feature_dim=dim, arch=arch, bn_splits=bn_splits)
180 |
181 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
182 | param_k.data.copy_(param_q.data) # initialize
183 | param_k.requires_grad = False # not update by gradient
184 |
185 | # create the queue
186 | self.register_buffer("queue", torch.randn(dim, K))
187 | self.queue = nn.functional.normalize(self.queue, dim=0)
188 |
189 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
190 |
191 | @torch.no_grad()
192 | def _momentum_update_key_encoder(self):
193 | """
194 | Momentum update of the key encoder
195 | """
196 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
197 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
198 |
199 | @torch.no_grad()
200 | def _dequeue_and_enqueue(self, keys):
201 | batch_size = keys.shape[0]
202 |
203 | ptr = int(self.queue_ptr)
204 | assert self.K % batch_size == 0 # for simplicity
205 |
206 | # replace the keys at ptr (dequeue and enqueue)
207 | self.queue[:, ptr:ptr + batch_size] = keys.t() # transpose
208 | ptr = (ptr + batch_size) % self.K # move pointer
209 |
210 | self.queue_ptr[0] = ptr
211 |
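# How the queue advances, with illustrative numbers (not the values this
# repo necessarily runs with): suppose K = 8 and batch_size = 4. At ptr = 0
# the 4 new keys overwrite queue[:, 0:4] and ptr moves to 4; the next batch
# fills queue[:, 4:8] and ptr wraps back to 0, so the oldest keys are always
# the ones replaced. The assert above (K divisible by batch size) is what
# keeps this pointer arithmetic so simple.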
212 | @torch.no_grad()
213 | def _batch_shuffle_single_gpu(self, x):
214 | """
215 | Batch shuffle, for making use of BatchNorm.
216 | """
217 | # random shuffle index
218 | idx_shuffle = torch.randperm(x.shape[0]).cuda()
219 |
220 | # index for restoring
221 | idx_unshuffle = torch.argsort(idx_shuffle)
222 |
223 | return x[idx_shuffle], idx_unshuffle
224 |
225 | @torch.no_grad()
226 | def _batch_unshuffle_single_gpu(self, x, idx_unshuffle):
227 | """
228 | Undo batch shuffle.
229 | """
230 | return x[idx_unshuffle]
231 |
232 | def contrastive_loss(self, im_q, im_k):
233 | # compute query features
234 | q = self.encoder_q(im_q) # queries: NxC
235 | q = nn.functional.normalize(q, dim=1) # already normalized
236 |
237 | # compute key features
238 | with torch.no_grad(): # no gradient to keys
239 | # shuffle for making use of BN
240 | im_k_, idx_unshuffle = self._batch_shuffle_single_gpu(im_k)
241 |
242 | k = self.encoder_k(im_k_) # keys: NxC
243 | k = nn.functional.normalize(k, dim=1) # already normalized
244 |
245 | # undo shuffle
246 | k = self._batch_unshuffle_single_gpu(k, idx_unshuffle)
247 |
248 | # compute logits
249 | # Einstein sum is more intuitive
250 | # positive logits: Nx1
251 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
252 | # negative logits: NxK
253 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
254 |
255 | # logits: Nx(1+K)
256 | logits = torch.cat([l_pos, l_neg], dim=1)
257 |
258 | # apply temperature
259 | logits /= self.T
260 |
261 | # labels: positive key indicators
262 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
263 |
264 | loss = nn.CrossEntropyLoss().cuda()(logits, labels)
265 |
266 | return loss, q, k
267 |
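# Reading the logits above: column 0 holds q.k (the positive pair) and
# columns 1..K hold q.queue (negatives), so with labels fixed at 0 the
# cross-entropy is the InfoNCE objective: it pushes each query's similarity
# to its own key above its similarity to every queued key. Dividing by
# T = 0.1 sharpens the softmax; e.g. a cosine similarity of 0.8 becomes a
# logit of 8.0.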
268 | def forward(self, im1, im2):
269 | """
270 | Input:
271 | im_q: a batch of query images
272 | im_k: a batch of key images
273 | Output:
274 | loss
275 | """
276 |
277 | # update the key encoder
278 | with torch.no_grad(): # no gradient to keys
279 | self._momentum_update_key_encoder()
280 |
281 | # compute loss
282 | if self.symmetric: # symmetric loss
283 | loss_12, q1, k2 = self.contrastive_loss(im1, im2)
284 | loss_21, q2, k1 = self.contrastive_loss(im2, im1)
285 | loss = loss_12 + loss_21
286 | k = torch.cat([k1, k2], dim=0)
287 | else: # asymmetric loss
288 | loss, q, k = self.contrastive_loss(im1, im2)
289 |
290 | self._dequeue_and_enqueue(k)
291 |
292 | return loss
293 |
294 | # create model
295 | model = ModelMoCo(
296 | dim=args.moco_dim,
297 | K=args.moco_k,
298 | m=args.moco_m,
299 | T=args.moco_t,
300 | arch=args.arch,
301 | bn_splits=args.bn_splits,
302 | symmetric=args.symmetric,
303 | ).cuda()
304 | print(model.encoder_q)
305 |
306 | """### Define train/test
307 |
308 |
309 | """
310 |
311 | # train for one epoch
312 | def train(net, data_loader, train_optimizer, epoch, args, trainlist):
313 | net.train()
314 | adjust_learning_rate(train_optimizer, epoch, args, trainlist)
315 |
316 | total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader)
317 | for im_1, im_2 in train_bar:
318 |
319 | im_1, im_2 = im_1.cuda(non_blocking=True), im_2.cuda(non_blocking=True)
320 |
321 | loss = net(im_1, im_2)
322 |
323 | train_optimizer.zero_grad()
324 | loss.backward()
325 | train_optimizer.step()
326 |
327 | total_num += data_loader.batch_size
328 | total_loss += loss.item() * data_loader.batch_size
329 | train_bar.set_description('Train Epoch: [{}/{}], lr: {:.6f}, Loss: {:.4f}'.format(epoch, args.epochs, train_optimizer.param_groups[0]['lr'], total_loss / total_num))
330 |
331 | return total_loss / total_num
332 |
333 | # lr scheduler for training
334 | def adjust_learning_rate(optimizer, epoch, args,trainlist):
335 | """Decay the learning rate based on schedule"""
336 | lr = args.lr
337 | if args.cos: # cosine lr schedule
338 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
339 | else: # stepwise lr schedule
340 | for milestone in args.schedule:
341 | lr *= 0.1 if epoch >= milestone else 1.
342 | for param_group in optimizer.param_groups:
343 | param_group['lr'] = lr
344 | # test using a knn monitor
345 | def test(net, memory_data_loader, args):
346 | net.eval()
347 | total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, []
348 | target_bank = []
349 | pathbank = []
350 | with torch.no_grad():
351 | # generate feature bank
352 | for data, target, path in tqdm(memory_data_loader, desc='Feature extracting'):
353 | feature = net(data.cuda(non_blocking=True))
354 | feature = F.normalize(feature, dim=1)
355 | feature_bank.append(feature)
356 | target = list(target)
357 | path = list(path)
358 | pathbank = pathbank + path
359 | target_bank = target_bank + target
360 | # [D, N]
361 | feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
362 |
363 |
364 | knn_predict(feature_bank, target_bank, pathbank, args.knn_k, args.knn_t)
365 |
366 |
367 | return
368 |
369 |
370 | def knn_predict(feature_bank, target_bank, pathbank, knn_k, knn_t):
371 | dataset=[]
372 | for ii in tqdm(range(feature_bank.size(1))):
373 | baselabel = target_bank[ii]
374 | basepath = pathbank[ii]
375 | temp = feature_bank.permute(1, 0)[ii]
376 | temp = temp.unsqueeze(0)
377 | sim_matrix = torch.mm(temp, feature_bank)
378 | sim_weight, sim_indices = sim_matrix.topk(k=feature_bank.size(1), dim=-1)
379 | sim_weight = (sim_weight / knn_t).exp()
380 | sim_weight = sim_weight / float(math.e ** 10)
381 | for nnn in range(feature_bank.size(1)):
382 | index = int(sim_indices[0][nnn])
383 | label = target_bank[index]
384 | w = float(sim_weight[0][nnn])
385 | if w <= args.w:
386 | break
387 | if label != baselabel:
388 |
389 | path = pathbank[index]
390 | if w > args.w:
391 |
392 | data = {}
393 | data['path1'] = basepath
394 | data['label1'] = str(int(baselabel.item())).zfill(4)
395 | data['path2'] = path
396 | data['label2'] = str(int(label.item())).zfill(4)
397 | data['w'] = w
398 | print(data)
399 | dataset.append(copy.deepcopy(data))
400 | break
401 | with open('result.json', 'w', encoding='utf-8') as f:
402 | json.dump(dataset, f, ensure_ascii=False)
403 |
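# What knn_predict does in this file (it repurposes the kNN-monitor name):
# for every sample it ranks all N features by cosine similarity, converts
# each similarity s to a weight w = exp(s / knn_t) / e**10 -- with
# knn_t = 0.1 and s <= 1 this caps w at 1.0 -- and walks down the ranking.
# The first neighbour whose weight still exceeds args.w but whose label
# differs is recorded in result.json as a candidate pair of label folders
# that may hold the same character; args.w is presumably a similarity
# threshold supplied via argparse. The top-ranked neighbour is the sample
# itself (s = 1, same label), so the label check skips it.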
404 | """### Start training"""
405 |
406 | # define optimizer
407 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd, momentum=0.9)
408 |
409 | # load model if resume
410 | epoch_start = 1
411 | if args.resume != '':
412 | checkpoint = torch.load(args.resume)
413 | model.load_state_dict(checkpoint['state_dict'])
414 | optimizer.load_state_dict(checkpoint['optimizer'])
415 | epoch_start = checkpoint['epoch'] + 1
416 | print('Loaded from: {}'.format(args.resume))
417 |
418 | # logging
419 | results = {'train_loss': []}
420 | # run the feature-extraction and duplicate-mining pass (writes result.json)
421 | trainlist = []
422 | test(model.encoder_q, train_loader, args)
--------------------------------------------------------------------------------
/Validation/Validation_label.json:
--------------------------------------------------------------------------------
1 | {"0001": 0, "0002": 1, "0003": 2, "0004": 3, "0005": 4, "0006": 5, "0007": 6, "0008": 7, "0009": 8, "0010": 9, "0011_0012_0013": 10, "0014": 11, "0015_0016": 12, "0017": 13, "0018": 14, "0019": 15, "0020": 16, "0021": 17, "0022": 18, "0023_0024": 19, "0025_0026": 20, "0027": 21, "0028": 22, "0029": 23, "0030": 24, "0031": 25, "0032": 26, "0033": 27, "0034": 28, "0035": 29, "0036": 30, "0037": 31, "0038": 32, "0039": 33, "0040": 34, "0041_0042": 35, "0043": 36, "0044": 37, "0045": 38, "0046": 39, "0047": 40, "0048_0049": 41, "0050": 42, "0051": 43, "0052_0053": 44, "0054": 45, "0055": 46, "0056": 47, "0057": 48, "0058": 49, "0059": 50, "0060": 51, "0061": 52, "0062_0063": 53, "0064": 54, "0065": 55, "0066": 56, "0067": 57, "0068": 58, "0069": 59, "0070_0071": 60, "0072": 61, "0073": 62, "0074": 63, "0075": 64, "0076": 65, "0077": 66, "0078": 67, "0079": 68, "0080": 69, "0081": 70, "0082": 71, "0083": 72, "0084": 73, "0085": 74, "0086": 75, "0087": 76, "0088": 77, "0089": 78, "0090": 79, "0091": 80, "0092": 81, "0093": 82, "0094": 83, "0095": 84, "0096": 85, "0097": 86, "0098": 87, "0099": 88, "0100": 89, "0101": 90, "0102": 91, "0103": 92, "0104": 93, "0105": 94, "0106_0107": 95, "0108": 96, "0109": 97, "0110": 98, "0111": 99, "0112": 100, "0113_0114": 101, "0115": 102, "0116": 103, "0117_0118": 104, "0119_0120": 105, "0121": 106, "0122": 107, "0123": 108, "0124": 109, "0125": 110, "0126": 111, "0127": 112, "0128": 113, "0129": 114, "0130": 115, "0131": 116, "0132": 117, "0133": 118, "0134": 119, "0135": 120, "0136": 121, "0137": 122, "0138": 123, "0139": 124, "0140": 125, "0141": 126, "0142": 127, "0143": 128, "0144": 129, "0145_0146": 130, "0147": 131, "0148": 132, "0149": 133, "0150": 134, "0151": 135, "0152": 136, "0153": 137, "0154": 138, "0155": 139, "0156": 140, "0157": 141, "0158": 142, "0159": 143, "0160": 144, "0161_0162": 145, "0163": 146, "0164": 147, "0165": 148, "0166": 149, "0167": 150, "0168": 151, "0169": 152, "0170": 153, "0171": 154, "0172": 155, "0173": 156, "0174": 157, "0175": 158, "0176": 159, "0177": 160, "0178": 161, "0179": 162, "0180": 163, "0181": 164, "0182": 165, "0183": 166, "0184": 167, "0185": 168, "0186": 169, "0187": 170, "0188": 171, "0189_0190": 172, "0191": 173, "0192": 174, "0193_0194": 175, "0195": 176, "0196_0197": 177, "0198": 178, "0199": 179, "0200": 180, "0201": 181, "0202": 182, "0203": 183, "0204": 184, "0205_0206": 185, "0207": 186, "0208": 187, "0209": 188, "0210_0211": 189, "0212_0213": 190, "0214": 191, "0215": 192, "0216": 193, "0217": 194, "0218": 195, "0219": 196, "0220_0221": 197, "0222": 198, "0223": 199, "0224": 200, "0225": 201, "0226": 202, "0227": 203, "0228": 204, "0229": 205, "0230": 206, "0231": 207, "0232": 208, "0233": 209, "0234": 210, "0235": 211, "0236_0237": 212, "0238": 213, "0239": 214, "0240": 215, "0241": 216, "0242": 217, "0243": 218, "0244": 219, "0245": 220, "0246_0247": 221, "0248": 222, "0249_0250": 223, "0251": 224, "0252": 225, "0253": 226, "0254": 227, "0255": 228, "0256": 229, "0257": 230, "0258_0259": 231, "0260": 232, "0261": 233, "0262": 234, "0263": 235, "0264": 236, "0265": 237, "0266": 238, "0267": 239, "0268": 240, "0269": 241, "0270_0271": 242, "0272": 243, "0273": 244, "0274": 245, "0275": 246, "0276": 247, "0277": 248, "0278": 249, "0279": 250, "0280": 251, "0281_0282_0283": 252, "0284": 253, "0285": 254, "0286_0287_0288": 255, "0289": 256, "0290": 257, "0291": 258, "0292": 259, "0293": 260, "0294": 261, "0295": 262, "0296": 263, "0297_0298": 264, "0299": 265, "0300": 266, "0301": 267, 
"0302_0303": 268, "0304_0305": 269, "0306": 270, "0307": 271, "0308": 272, "0309": 273, "0310": 274, "0311": 275, "0312_0313": 276, "0314": 277, "0315": 278, "0316": 279, "0317": 280, "0318": 281, "0319": 282, "0320": 283, "0321_0322_0323": 284, "0324": 285, "0325": 286, "0326": 287, "0327_0328": 288, "0329": 289, "0330": 290, "0331": 291, "0332": 292, "0333": 293, "0334": 294, "0335": 295, "0336_0337": 296, "0338": 297, "0339": 298, "0340": 299, "0341": 300, "0342": 301, "0343": 302, "0344": 303, "0345": 304, "0346": 305, "0347": 306, "0348": 307, "0349": 308, "0350": 309, "0351": 310, "0352": 311, "0353": 312, "0354": 313, "0355": 314, "0356": 315, "0357": 316, "0358": 317, "0359": 318, "0360_0361": 319, "0362_0363": 320, "0364": 321, "0365": 322, "0366": 323, "0367": 324, "0368": 325, "0369": 326, "0370": 327, "0371": 328, "0372": 329, "0373": 330, "0374": 331, "0375_0376": 332, "0377": 333, "0378": 334, "0379": 335, "0380": 336, "0381": 337, "0382_0383": 338, "0384": 339, "0385": 340, "0386_0387": 341, "0388": 342, "0389": 343, "0390": 344, "0391": 345, "0392": 346, "0393": 347, "0394_0395": 348, "0396_0397": 349, "0398": 350, "0399": 351, "0400": 352, "0401_0402": 353, "0403": 354, "0404": 355, "0405": 356, "0406": 357, "0407_0408": 358, "0409_0410": 359, "0411": 360, "0412": 361, "0413": 362, "0414": 363, "0415": 364, "0416": 365, "0417": 366, "0418": 367, "0419": 368, "0420": 369, "0421_0422": 370, "0423": 371, "0424": 372, "0425_0426": 373, "0427": 374, "0428": 375, "0429": 376, "0430": 377, "0431": 378, "0432": 379, "0433": 380, "0434": 381, "0435": 382, "0436": 383, "0437": 384, "0438": 385, "0439": 386, "0440": 387, "0441": 388, "0442": 389, "0443": 390, "0444": 391, "0445_0446": 392, "0447": 393, "0448": 394, "0449": 395, "0450": 396, "0451": 397, "0452": 398, "0453": 399, "0454": 400, "0455_0456": 401, "0457": 402, "0458_0459": 403, "0460": 404, "0461_0462": 405, "0463": 406, "0464": 407, "0465": 408, "0466": 409, "0467": 410, "0468": 411, "0469": 412, "0470": 413, "0471": 414, "0472": 415, "0473": 416, "0474": 417, "0475": 418, "0476": 419, "0477": 420, "0478": 421, "0479_0480": 422, "0481": 423, "0482": 424, "0483_0484": 425, "0485": 426, "0486": 427, "0487": 428, "0488": 429, "0489_0490_0491": 430, "0492": 431, "0493": 432, "0494": 433, "0495": 434, "0496": 435, "0497": 436, "0498": 437, "0499": 438, "0500": 439, "0501": 440, "0502": 441, "0503": 442, "0504": 443, "0505": 444, "0506": 445, "0507": 446, "0508": 447, "0509": 448, "0510": 449, "0511": 450, "0512": 451, "0513": 452, "0514": 453, "0515": 454, "0516_0517": 455, "0518_0519": 456, "0520": 457, "0521": 458, "0522": 459, "0523": 460, "0524": 461, "0525": 462, "0526": 463, "0527": 464, "0528": 465, "0529": 466, "0530": 467, "0531": 468, "0532": 469, "0533": 470, "0534": 471, "0535_0536": 472, "0537": 473, "0538": 474, "0539": 475, "0540_0541": 476, "0542": 477, "0543": 478, "0544_0545": 479, "0546": 480, "0547": 481, "0548": 482, "0549": 483, "0550": 484, "0551": 485, "0552": 486, "0553": 487, "0554": 488, "0555": 489, "0556": 490, "0557": 491, "0558_0559_0560": 492, "0561": 493, "0562_0563_0564": 494, "0565": 495, "0566": 496, "0567_0568": 497, "0569_0570": 498, "0571": 499, "0572": 500, "0573": 501, "0574": 502, "0575": 503, "0576": 504, "0577": 505, "0578": 506, "0579": 507, "0580": 508, "0581": 509, "0582": 510, "0583": 511, "0584": 512, "0585": 513, "0586": 514, "0587_0588_0589": 515, "0590": 516, "0591": 517, "0592": 518, "0593": 519, "0594": 520, "0595": 521, "0596": 522, "0597_0598": 523, "0599": 524, "0600": 
525, "0601": 526, "0602": 527, "0603": 528, "0604": 529, "0605": 530, "0606": 531, "0607": 532, "0608": 533, "0609": 534, "0610": 535, "0611": 536, "0612": 537, "0613": 538, "0614": 539, "0615": 540, "0616": 541, "0617": 542, "0618": 543, "0619": 544, "0620": 545, "0621": 546, "0622": 547, "0623": 548, "0624": 549, "0625": 550, "0626": 551, "0627": 552, "0628": 553, "0629": 554, "0630_0631": 555, "0632": 556, "0633": 557, "0634": 558, "0635": 559, "0636": 560, "0637": 561, "0638_0639": 562, "0640": 563, "0641": 564, "0642": 565, "0643": 566, "0644": 567, "0645": 568, "0646": 569, "0647": 570, "0648": 571, "0649": 572, "0650": 573, "0651": 574, "0652": 575, "0653_0654": 576, "0655_0656": 577, "0657": 578, "0658": 579, "0659": 580, "0660": 581, "0661_0662": 582, "0663": 583, "0664": 584, "0665": 585, "0666": 586, "0667": 587, "0668": 588, "0669": 589, "0670_0671": 590, "0672_0673_0674": 591, "0675": 592, "0676": 593, "0677": 594, "0678_0679": 595, "0680_0681": 596, "0682": 597, "0683": 598, "0684": 599, "0685": 600, "0686": 601, "0687": 602, "0688": 603, "0689": 604, "0690_0691": 605, "0692": 606, "0693": 607, "0694": 608, "0695": 609, "0696": 610, "0697": 611, "0698": 612, "0699": 613, "0700": 614, "0701": 615, "0702": 616, "0703": 617, "0704": 618, "0705": 619, "0706": 620, "0707": 621, "0708": 622, "0709": 623, "0710": 624, "0711": 625, "0712": 626, "0713": 627, "0714": 628, "0715_0716": 629, "0717": 630, "0718_0719": 631, "0720": 632, "0721": 633, "0722": 634, "0723": 635, "0724": 636, "0725": 637, "0726_0727": 638, "0728": 639, "0729": 640, "0730": 641, "0731": 642, "0732": 643, "0733": 644, "0734": 645, "0735": 646, "0736": 647, "0737": 648, "0738": 649, "0739": 650, "0740": 651, "0741": 652, "0742": 653, "0743": 654, "0744": 655, "0745": 656, "0746": 657, "0747": 658, "0748": 659, "0749": 660, "0750_0751": 661, "0752_0753": 662, "0754": 663, "0755": 664, "0756": 665, "0757": 666, "0758": 667, "0759": 668, "0760": 669, "0761_0762": 670, "0763": 671, "0764": 672, "0765": 673, "0766": 674, "0767": 675, "0768": 676, "0769": 677, "0770": 678, "0771": 679, "0772": 680, "0773": 681, "0774": 682, "0775": 683, "0776": 684, "0777": 685, "0778": 686, "0779": 687, "0780": 688, "0781": 689, "0782": 690, "0783": 691, "0784": 692, "0785_0786": 693, "0787_0788_0789": 694, "0790": 695, "0791": 696, "0792": 697, "0793": 698, "0794": 699, "0795": 700, "0796": 701, "0797": 702, "0798": 703, "0799": 704, "0800_0801": 705, "0802": 706, "0803": 707, "0804": 708, "0805": 709, "0806": 710, "0807_0808": 711, "0809": 712, "0810": 713, "0811_0812": 714, "0813": 715, "0814": 716, "0815_0816": 717, "0817": 718, "0818": 719, "0819": 720, "0820": 721, "0821": 722, "0822": 723, "0823_0824": 724, "0825_0826": 725, "0827_0828": 726, "0829": 727, "0830": 728, "0831": 729, "0832": 730, "0833": 731, "0834": 732, "0835_0836": 733, "0837": 734, "0838": 735, "0839": 736, "0840": 737, "0841": 738, "0842": 739, "0843": 740, "0844": 741, "0845": 742, "0846": 743, "0847": 744, "0848": 745, "0849": 746, "0850": 747, "0851": 748, "0852": 749, "0853": 750, "0854": 751, "0855": 752, "0856": 753, "0857": 754, "0858": 755, "0859": 756, "0860": 757, "0861_0862": 758, "0863": 759, "0864": 760, "0865": 761, "0866": 762, "0867": 763, "0868": 764, "0869": 765, "0870_0871": 766, "0872": 767, "0873": 768, "0874": 769, "0875": 770, "0876": 771, "0877": 772, "0878": 773, "0879": 774, "0880": 775, "0881": 776, "0882": 777, "0883": 778, "0884": 779, "0885": 780, "0886": 781, "0887": 782, "0888": 783, "0889": 784, "0890": 785, "0891": 786, 
"0892": 787, "0893": 788, "0894": 789, "0895": 790, "0896": 791, "0897": 792, "0898": 793, "0899": 794, "0900_0901": 795, "0902_0903": 796, "0904": 797, "0905": 798, "0906": 799, "0907": 800, "0908_0909": 801, "0910": 802, "0911": 803, "0912": 804, "0913": 805, "0914": 806, "0915": 807, "0916": 808, "0917_0918": 809, "0919": 810, "0920": 811, "0921": 812, "0922": 813, "0923": 814, "0924": 815, "0925": 816, "0926": 817, "0927": 818, "0928": 819, "0929": 820, "0930": 821, "0931": 822, "0932": 823, "0933": 824, "0934": 825, "0935": 826, "0936": 827, "0937": 828, "0938": 829, "0939_0940": 830, "0941": 831, "0942": 832, "0943": 833, "0944": 834, "0945": 835, "0946": 836, "0947": 837, "0948": 838, "0949": 839, "0950": 840, "0951": 841, "0952": 842, "0953": 843, "0954": 844, "0955": 845, "0956": 846, "0957": 847, "0958": 848, "0959": 849, "0960": 850, "0961": 851, "0962": 852, "0963": 853, "0964": 854, "0965": 855, "0966": 856, "0967": 857, "0968_0969": 858, "0970": 859, "0971": 860, "0972": 861, "0973": 862, "0974": 863, "0975": 864, "0976_0977": 865, "0978": 866, "0979": 867, "0980": 868, "0981": 869, "0982": 870, "0983": 871, "0984_0985": 872, "0986": 873, "0987": 874, "0988": 875, "0989": 876, "0990": 877, "0991": 878, "0992": 879, "0993": 880, "0994": 881, "0995": 882, "0996": 883, "0997": 884, "0998": 885, "0999": 886, "1000": 887, "1001": 888, "1002": 889, "1003": 890, "1004": 891, "1005": 892, "1006": 893, "1007": 894, "1008_1009": 895, "1010_1011": 896, "1012": 897, "1013": 898, "1014": 899, "1015": 900, "1016": 901, "1017": 902, "1018_1019": 903, "1020": 904, "1021": 905, "1022": 906, "1023": 907, "1024": 908, "1025": 909, "1026": 910, "1027": 911, "1028": 912, "1029_1030": 913, "1031": 914, "1032": 915, "1033": 916, "1034_1035": 917, "1036": 918, "1037": 919, "1038": 920, "1039": 921, "1040": 922, "1041": 923, "1042": 924, "1043": 925, "1044": 926, "1045": 927, "1046": 928, "1047": 929, "1048": 930, "1049": 931, "1050": 932, "1051": 933, "1052": 934, "1053": 935, "1054": 936, "1055": 937, "1056": 938, "1057": 939, "1058": 940, "1059": 941, "1060": 942, "1061": 943, "1062": 944, "1063": 945, "1064": 946, "1065_1066": 947, "1067": 948, "1068": 949, "1069": 950, "1070": 951, "1071_1072": 952, "1073": 953, "1074": 954, "1075": 955, "1076": 956, "1077_1078": 957, "1079": 958, "1080": 959, "1081": 960, "1082": 961, "1083": 962, "1084": 963, "1085": 964, "1086": 965, "1087": 966, "1088": 967, "1089": 968, "1090": 969, "1091": 970, "1092": 971, "1093": 972, "1094": 973, "1095": 974, "1096": 975, "1097": 976, "1098": 977, "1099": 978, "1100": 979, "1101": 980, "1102_1103": 981, "1104": 982, "1105": 983, "1106_1107": 984, "1108": 985, "1109": 986, "1110": 987, "1111": 988, "1112": 989, "1113": 990, "1114": 991, "1115_1116": 992, "1117": 993, "1118": 994, "1119": 995, "1120_1121": 996, "1122": 997, "1123": 998, "1124": 999, "1125": 1000, "1126": 1001, "1127": 1002, "1128": 1003, "1129": 1004, "1130": 1005, "1131": 1006, "1132": 1007, "1133_1134": 1008, "1135": 1009, "1136": 1010, "1137": 1011, "1138": 1012, "1139": 1013, "1140": 1014, "1141": 1015, "1142": 1016, "1143": 1017, "1144": 1018, "1145": 1019, "1146": 1020, "1147": 1021, "1148": 1022, "1149": 1023, "1150": 1024, "1151": 1025, "1152": 1026, "1153": 1027, "1154": 1028, "1155": 1029, "1156": 1030, "1157": 1031, "1158": 1032, "1159": 1033, "1160": 1034, "1161": 1035, "1162": 1036, "1163_1164": 1037, "1165": 1038, "1166": 1039, "1167": 1040, "1168": 1041, "1169": 1042, "1170": 1043, "1171_1172": 1044, "1173": 1045, "1174": 1046, "1175": 
1047, "1176": 1048, "1177": 1049, "1178": 1050, "1179": 1051, "1180": 1052, "1181": 1053, "1182": 1054, "1183": 1055, "1184": 1056, "1185": 1057, "1186": 1058, "1187": 1059, "1188": 1060, "1189": 1061, "1190": 1062, "1191": 1063, "1192": 1064, "1193": 1065, "1194": 1066, "1195": 1067, "1196": 1068, "1197": 1069, "1198": 1070, "1199": 1071, "1200": 1072, "1201": 1073, "1202": 1074, "1203": 1075, "1204_1205": 1076, "1206": 1077, "1207_1208": 1078, "1209": 1079, "1210": 1080, "1211": 1081, "1212": 1082, "1213": 1083, "1214": 1084, "1215": 1085, "1216": 1086, "1217": 1087, "1218": 1088, "1219": 1089, "1220": 1090, "1221": 1091, "1222": 1092, "1223": 1093, "1224": 1094, "1225": 1095, "1226": 1096, "1227": 1097, "1228": 1098, "1229": 1099, "1230": 1100, "1231_1232": 1101, "1233": 1102, "1234": 1103, "1235": 1104, "1236": 1105, "1237": 1106, "1238": 1107, "1239_1240": 1108, "1241": 1109, "1242": 1110, "1243": 1111, "1244": 1112, "1245": 1113, "1246": 1114, "1247": 1115, "1248": 1116, "1249": 1117, "1250": 1118, "1251": 1119, "1252": 1120, "1253": 1121, "1254": 1122, "1255": 1123, "1256": 1124, "1257": 1125, "1258": 1126, "1259": 1127, "1260": 1128, "1261_1262_1263": 1129, "1264": 1130, "1265": 1131, "1266": 1132, "1267": 1133, "1268": 1134, "1269": 1135, "1270": 1136, "1271_1272": 1137, "1273": 1138, "1274": 1139, "1275": 1140, "1276": 1141, "1277": 1142, "1278": 1143, "1279": 1144, "1280_1281": 1145, "1282": 1146, "1283": 1147, "1284_1285_1286": 1148, "1287": 1149, "1288": 1150, "1289": 1151, "1290_1291_1292": 1152, "1293": 1153, "1294": 1154, "1295": 1155, "1296": 1156, "1297": 1157, "1298": 1158, "1299": 1159, "1300": 1160, "1301": 1161, "1302": 1162, "1303": 1163, "1304_1305": 1164, "1306": 1165, "1307": 1166, "1308": 1167, "1309": 1168, "1310": 1169, "1311": 1170, "1312": 1171, "1313": 1172, "1314": 1173, "1315": 1174, "1316": 1175, "1317": 1176, "1318": 1177, "1319": 1178, "1320": 1179, "1321": 1180, "1322": 1181, "1323": 1182, "1324": 1183, "1325": 1184, "1326": 1185, "1327_1328": 1186, "1329": 1187, "1330": 1188, "1331": 1189, "1332": 1190, "1333": 1191, "1334": 1192, "1335": 1193, "1336": 1194, "1337": 1195, "1338": 1196, "1339": 1197, "1340": 1198, "1341_1342": 1199, "1343": 1200, "1344_1345": 1201, "1346": 1202, "1347_1348": 1203, "1349_1350": 1204, "1351": 1205, "1352": 1206, "1353": 1207, "1354": 1208, "1355": 1209, "1356": 1210, "1357": 1211, "1358": 1212, "1359": 1213, "1360": 1214, "1361": 1215, "1362": 1216, "1363": 1217, "1364": 1218, "1365": 1219, "1366": 1220, "1367": 1221, "1368": 1222, "1369": 1223, "1370": 1224, "1371": 1225, "1372": 1226, "1373": 1227, "1374_1375": 1228, "1376": 1229, "1377": 1230, "1378_1379": 1231, "1380": 1232, "1381": 1233, "1382": 1234, "1383": 1235, "1384": 1236, "1385": 1237, "1386": 1238, "1387": 1239, "1388": 1240, "1389_1390": 1241, "1391": 1242, "1392": 1243, "1393": 1244, "1394": 1245, "1395": 1246, "1396": 1247, "1397": 1248, "1398": 1249, "1399": 1250, "1400": 1251, "1401": 1252, "1402": 1253, "1403": 1254, "1404": 1255, "1405": 1256, "1406_1407": 1257, "1408": 1258, "1409": 1259, "1410": 1260, "1411": 1261, "1412": 1262, "1413": 1263, "1414": 1264, "1415_1416": 1265, "1417": 1266, "1418": 1267, "1419": 1268, "1420": 1269, "1421_1422": 1270, "1423": 1271, "1424": 1272, "1425": 1273, "1426": 1274, "1427": 1275, "1428": 1276, "1429_1430": 1277, "1431": 1278, "1432": 1279, "1433": 1280, "1434": 1281, "1435": 1282, "1436": 1283, "1437": 1284, "1438": 1285, "1439": 1286, "1440": 1287, "1441": 1288, "1442": 1289, "1443": 1290, "1444": 1291, "1445": 
1292, "1446_1447": 1293, "1448": 1294, "1449": 1295, "1450": 1296, "1451": 1297, "1452": 1298, "1453": 1299, "1454": 1300, "1455": 1301, "1456": 1302, "1457": 1303, "1458": 1304, "1459": 1305, "1460": 1306, "1461": 1307, "1462": 1308, "1463": 1309, "1464_1465_1466": 1310, "1467": 1311, "1468": 1312, "1469": 1313, "1470": 1314, "1471_1472": 1315, "1473": 1316, "1474": 1317, "1475": 1318, "1476": 1319, "1477": 1320, "1478": 1321, "1479": 1322, "1480": 1323, "1481": 1324, "1482": 1325, "1483": 1326, "1484": 1327, "1485": 1328, "1486": 1329, "1487": 1330, "1488": 1331, "1489": 1332, "1490": 1333, "1491": 1334, "1492": 1335, "1493": 1336, "1494": 1337, "1495": 1338, "1496": 1339, "1497": 1340, "1498": 1341, "1499_1500": 1342, "1501": 1343, "1502": 1344, "1503": 1345, "1504": 1346, "1505": 1347, "1506": 1348, "1507": 1349, "1508": 1350, "1509": 1351, "1510": 1352, "1511": 1353, "1512": 1354, "1513": 1355, "1514": 1356, "1515": 1357, "1516": 1358, "1517": 1359, "1518": 1360, "1519": 1361, "1520_1521": 1362, "1522": 1363, "1523": 1364, "1524": 1365, "1525_1526": 1366, "1527": 1367, "1528": 1368, "1529": 1369, "1530": 1370, "1531": 1371, "1532": 1372, "1533_1534_1535": 1373, "1536": 1374, "1537": 1375, "1538": 1376, "1539": 1377, "1540": 1378, "1541_1542": 1379, "1543": 1380, "1544": 1381, "1545_1546": 1382, "1547": 1383, "1548": 1384, "1549_1550": 1385, "1551": 1386, "1552_1553": 1387, "1554": 1388, "1555": 1389, "1556": 1390, "1557": 1391, "1558": 1392, "1559": 1393, "1560": 1394, "1561": 1395, "1562_1563": 1396, "1564_1565": 1397, "1566_1567": 1398, "1568": 1399, "1569": 1400, "1570": 1401, "1571_1572": 1402, "1573": 1403, "1574": 1404, "1575": 1405, "1576": 1406, "1577": 1407, "1578": 1408, "1579": 1409, "1580": 1410, "1581": 1411, "1582": 1412, "1583": 1413, "1584": 1414, "1585": 1415, "1586": 1416, "1587": 1417, "1588": 1418, "1589_1590_1591": 1419, "1592": 1420, "1593": 1421, "1594_1595": 1422, "1596": 1423, "1597": 1424, "1598": 1425, "1599": 1426, "1600": 1427, "1601": 1428, "1602": 1429, "1603": 1430, "1604": 1431, "1605_1606": 1432, "1607": 1433, "1608": 1434, "1609": 1435, "1610": 1436, "1611": 1437, "1612": 1438, "1613": 1439, "1614_1615": 1440, "1616": 1441, "1617": 1442, "1618_1619": 1443, "1620": 1444, "1621": 1445, "1622": 1446, "1623_1624": 1447, "1625": 1448, "1626": 1449, "1627": 1450, "1628": 1451, "1629": 1452, "1630": 1453, "1631": 1454, "1632": 1455, "1633": 1456, "1634": 1457, "1635_1636": 1458, "1637": 1459, "1638": 1460, "1639_1640_1641": 1461, "1642_1643": 1462, "1644": 1463, "1645": 1464, "1646": 1465, "1647": 1466, "1648": 1467, "1649": 1468, "1650": 1469, "1651": 1470, "1652_1653": 1471, "1654": 1472, "1655": 1473, "1656": 1474, "1657": 1475, "1658_1659": 1476, "1660": 1477, "1661": 1478, "1662": 1479, "1663": 1480, "1664": 1481, "1665": 1482, "1666": 1483, "1667": 1484, "1668": 1485, "1669": 1486, "1670": 1487, "1671": 1488, "1672": 1489, "1673": 1490, "1674": 1491, "1675": 1492, "1676": 1493, "1677": 1494, "1678": 1495, "1679": 1496, "1680": 1497, "1681": 1498, "1682_1683_1684": 1499, "1685": 1500, "1686": 1501, "1687_1688": 1502, "1689": 1503, "1690": 1504, "1691": 1505, "1692": 1506, "1693": 1507, "1694": 1508, "1695": 1509, "1696": 1510, "1697": 1511, "1698": 1512, "1699": 1513, "1700": 1514, "1701": 1515, "1702": 1516, "1703": 1517, "1704": 1518, "1705_1706": 1519, "1707": 1520, "1708": 1521, "1709": 1522, "1710": 1523, "1711": 1524, "1712": 1525, "1713": 1526, "1714": 1527, "1715": 1528, "1716": 1529, "1717": 1530, "1718": 1531, "1719": 1532, "1720": 1533, 
"1721": 1534, "1722": 1535, "1723_1724_1725": 1536, "1726": 1537, "1727_1728": 1538, "1729": 1539, "1730": 1540, "1731": 1541, "1732": 1542, "1733": 1543, "1734": 1544, "1735": 1545, "1736_1737_1738": 1546, "1739": 1547, "1740": 1548, "1741": 1549, "1742": 1550, "1743": 1551, "1744": 1552, "1745": 1553, "1746": 1554, "1747": 1555, "1748": 1556, "1749": 1557, "1750": 1558, "1751": 1559, "1752": 1560, "1753": 1561, "1754": 1562, "1755": 1563, "1756_1757": 1564, "1758": 1565, "1759": 1566, "1760": 1567, "1761": 1568, "1762": 1569, "1763": 1570, "1764": 1571, "1765": 1572, "1766": 1573, "1767": 1574, "1768": 1575, "1769": 1576, "1770": 1577, "1771": 1578, "1772": 1579, "1773": 1580, "1774": 1581, "1775": 1582, "1776": 1583, "1777": 1584, "1778": 1585, "1779_1780": 1586, "1781": 1587}
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/MoCo/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # gpu_info = !nvidia-smi -i 0
3 | # gpu_info = '\n'.join(gpu_info)
4 | # print(gpu_info)
5 | # nohup python -u mocodataset.py > mocodataset.log 2>&1 &
6 | from datetime import datetime
7 | from functools import partial
8 | import cv2
9 | import numpy as np
10 | import torchvision
11 | from PIL import Image
12 | from torch.utils.data import DataLoader
13 | from torchvision import transforms
14 | from torchvision.datasets import CIFAR10, FashionMNIST,MNIST
15 | from torchvision.models import resnet
16 | from tqdm import tqdm
17 | import argparse
18 | import json
19 | import math
20 | import os
21 | import random
22 | import pandas as pd
23 | import torch
24 | import torch.nn as nn
25 | import torch.nn.functional as F
26 |
27 | """### Set arguments"""
28 |
29 | parser = argparse.ArgumentParser(description='Train MoCo on HUST-OBC')
30 |
31 | parser.add_argument('-a', '--arch', default='resnet18')
32 |
33 | # lr: 0.06 for batch 512 (or 0.03 for batch 256)
34 | parser.add_argument('--lr', '--learning-rate', default=0.006, type=float, metavar='LR', help='initial learning rate', dest='lr')
35 | parser.add_argument('--epochs', default=150, type=int, metavar='N', help='number of total epochs to run')
36 | parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int, help='learning rate schedule (when to drop lr by 10x); does not take effect if --cos is on')
37 | parser.add_argument('--cos', action='store_true', help='use cosine lr schedule')
38 |
39 | parser.add_argument('--batch_size', default=128, type=int, metavar='N', help='mini-batch size')
40 | parser.add_argument('--num_workers', default=9, type=int)
41 | parser.add_argument('--wd', default=5e-4, type=float, metavar='W', help='weight decay')
42 |
43 | # moco specific configs:
44 | parser.add_argument('--moco-dim', default=128, type=int, help='feature dimension')
45 | parser.add_argument('--moco-k', default=24576, type=int, help='queue size; number of negative keys')
46 | parser.add_argument('--moco-m', default=0.99, type=float, help='moco momentum of updating key encoder')
47 | parser.add_argument('--moco-t', default=0.1, type=float, help='softmax temperature')
48 |
49 | parser.add_argument('--bn-splits', default=8, type=int, help='simulate multi-gpu behavior of BatchNorm in one gpu; 1 is SyncBatchNorm in multi-gpu')
50 |
51 | parser.add_argument('--symmetric', action='store_true', help='use a symmetric loss function that backprops to both crops')
52 |
53 | # knn monitor
54 | parser.add_argument('--knn-k', default=200, type=int, help='k in kNN monitor')
55 | parser.add_argument('--knn-t', default=0.1, type=float, help='softmax temperature in kNN monitor; can differ from moco-t')
56 |
57 | # utils
58 | parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
59 | parser.add_argument('--results-dir', default='moco', type=str, metavar='PATH', help='path to cache (default: none)')
60 | '''
61 | args = parser.parse_args() # running in command line
62 | '''
63 | # args = parser.parse_args('') # running in ipynb
64 | args = parser.parse_args() # running in command line
65 | # set command line arguments here when running in ipynb
66 | args.cos = True
67 | args.schedule = [] # cos in use
68 | args.symmetric = False
69 | if args.results_dir == '':
70 | args.results_dir = './cache-' + datetime.now().strftime("%Y-%m-%d-%H-%M-%S-moco")
71 |
72 | print(args)
73 |
74 | """### Define data loaders"""
75 |
76 | class CIFAR10Pair(CIFAR10):
77 | """CIFAR10 Dataset.
78 | """
79 | def __getitem__(self, index):
80 | img = self.data[index]
81 | img = Image.fromarray(img)
82 |
83 | if self.transform is not None:
84 | im_1 = self.transform(img)
85 | im_2 = self.transform(img)
86 |
87 | return im_1, im_2
88 | class MiniPair(MNIST):
89 | """CIFAR10 Dataset.
90 | """
91 | def __getitem__(self, index):
92 | img = self.data[index]
93 | img=img.numpy()
94 | img = Image.fromarray(img)
95 |
96 | if self.transform is not None:
97 | im_1 = self.transform(img)
98 | im_2 = self.transform(img)
99 |
100 | return im_1, im_2
101 |
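# Both pair datasets above implement the standard MoCo trick: __getitem__
# applies the same stochastic transform twice to one image and returns the
# two views, which become the query/key pair for the contrastive loss. No
# class label is needed for pretraining, so none is returned.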
102 | from torch.utils.data import TensorDataset, DataLoader
103 | class RandomGaussianBlur(object):
104 | def __init__(self, p=0.5, min_kernel_size=3, max_kernel_size=15, min_sigma=0.1, max_sigma=1.0):
105 | self.p = p
106 | self.min_kernel_size = min_kernel_size
107 | self.max_kernel_size = max_kernel_size
108 | self.min_sigma = min_sigma
109 | self.max_sigma = max_sigma
110 |
111 | def __call__(self, img):
112 | if random.random() < self.p and self.min_kernel_size <= self.max_kernel_size:
175 | if y < 128:
176 | while sizey > 128 or sizey < 32:
177 | sizey = round(random.gauss(y, 30))
178 | if x < 128:
179 | while sizex > 128 or sizex < 32:
180 | sizex = round(random.gauss(x, 30))
181 | dx = 128 - sizex  # difference to pad up to 128
182 | dy = 128 - sizey
183 | if dx > 0:
184 | xl = -1
185 | while xl > dx or xl < 0:
186 | xl = round(dx / 2)
187 | xl = round(random.gauss(xl, 10))
188 | else:
189 | xl = 0
190 | if dy > 0:
191 | yl = -1
192 | while yl > dy or yl < 0:
193 | yl = round(dy / 2)
194 | yl = round(random.gauss(yl, 10))
195 | else:
196 | yl = 0
197 | yr = dy - yl
198 | xr = dx - xl
199 | image1 = jioayan(image)
200 | image2 = jioayan(image)
201 | image1 = cv2.cvtColor(np.array(image1), cv2.COLOR_RGB2BGR)
202 | image1 = pengzhang(image1)
203 | image1 = Image.fromarray(cv2.cvtColor(image1, cv2.COLOR_BGR2RGB))
204 | image2 = cv2.cvtColor(np.array(image2), cv2.COLOR_RGB2BGR)
205 | image2 = pengzhang(image2)
206 | image2 = Image.fromarray(cv2.cvtColor(image2, cv2.COLOR_BGR2RGB))
207 | image1 = random_gaussian_blur(image1)
208 | image2 = random_gaussian_blur(image2)
209 | train_transform1 = transforms.Compose([
210 | transforms.Resize((sizey, sizex)),
211 | transforms.Pad([xl, yl, xr, yr], fill=(255, 255, 255), padding_mode='constant'),
212 | transforms.RandomRotation(degrees=(-20, 20), center=(round(64), round(64)), fill=(255, 255, 255)),
213 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
214 | transforms.RandomGrayscale(p=0.2),
215 | transforms.ToTensor(),
216 | transforms.Normalize([0.84959, 0.84959, 0.84959], [0.30949923, 0.30949923, 0.30949923])])
217 | im_1 = train_transform1(image1)
218 | x, y = 80, 80
219 | sizey, sizex = 129, 129
220 | if y < 128:
221 | while sizey > 128 or sizey < 32:
222 | sizey = round(random.gauss(y, 30))
223 | if x < 128:
224 | while sizex > 128 or sizex < 32:
225 | sizex = round(random.gauss(x, 30))
226 | dx = 128 - sizex  # difference to pad up to 128
227 | dy = 128 - sizey
228 | if dx > 0:
229 | xl = -1
230 | while xl > dx or xl < 0:
231 | xl = round(dx / 2)
232 | xl = round(random.gauss(xl, 10))
233 | else:
234 | xl = 0
235 | if dy > 0:
236 | yl = -1
237 | while yl > dy or yl < 0:
238 | yl = round(dy / 2)
239 | yl = round(random.gauss(yl, 10))
240 | else:
241 | yl = 0
242 | yr = dy - yl
243 | xr = dx - xl
244 | train_transform2 = transforms.Compose([
245 | transforms.Resize((sizey, sizex)),
246 | transforms.Pad([xl, yl, xr, yr], fill=(255, 255, 255), padding_mode='constant'),
247 | transforms.RandomRotation(degrees=(-20, 20), center=(round(64), round(64)), fill=(255, 255, 255)),
248 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
249 | transforms.RandomGrayscale(p=0.2),
250 | transforms.ToTensor(),
251 | transforms.Normalize([0.84959, 0.84959, 0.84959], [0.30949923, 0.30949923, 0.30949923])])
252 | im_2 = train_transform2(image2)
253 | return im_1, im_2
254 |
255 | def __len__(self):
256 | return len(self.images)
257 |
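# Summary of the two-view augmentation built in __getitem__ above. The
# helpers jioayan and pengzhang are defined earlier in this file; from
# their usage and pinyin names they are presumably salt-and-pepper noise
# (jiaoyan) and morphological dilation (pengzhang). Each view gets noise,
# dilation and a random Gaussian blur, is resized to a Gaussian-sampled
# size in [32, 128], padded with white to 128x128 at a jittered offset,
# rotated within +/-20 degrees, colour-jittered, occasionally grayscaled,
# and finally normalised with the dataset mean/std (~0.8496 / ~0.3095).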
258 | train_dataset = Mydata()
259 | train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.num_workers, drop_last=True, pin_memory=True)
260 | class SplitBatchNorm(nn.BatchNorm2d):
261 | def __init__(self, num_features, num_splits, **kw):
262 | super().__init__(num_features, **kw)
263 | self.num_splits = num_splits
264 |
265 | def forward(self, input):
266 | N, C, H, W = input.shape
267 | if self.training or not self.track_running_stats:
268 | running_mean_split = self.running_mean.repeat(self.num_splits)
269 | running_var_split = self.running_var.repeat(self.num_splits)
270 | outcome = nn.functional.batch_norm(
271 | input.view(-1, C * self.num_splits, H, W), running_mean_split, running_var_split,
272 | self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
273 | True, self.momentum, self.eps).view(N, C, H, W)
274 | self.running_mean.data.copy_(running_mean_split.view(self.num_splits, C).mean(dim=0))
275 | self.running_var.data.copy_(running_var_split.view(self.num_splits, C).mean(dim=0))
276 | return outcome
277 | else:
278 | return nn.functional.batch_norm(
279 | input, self.running_mean, self.running_var,
280 | self.weight, self.bias, False, self.momentum, self.eps)
281 |
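# What the view/repeat dance above buys: with bn_splits = 8 and a batch of
# 128, the input is reshaped so normalisation statistics are computed over
# eight independent groups of 16 samples instead of the full batch,
# mimicking the per-GPU (non-synchronized) BatchNorm of an 8-GPU MoCo run
# on a single GPU; this is what keeps the single-GPU batch-shuffle trick in
# ModelMoCo meaningful. The running mean/var are averaged back over the
# splits so eval-mode BN behaves like ordinary BatchNorm2d.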
282 | class ModelBase(nn.Module):
283 | """
284 | Common CIFAR ResNet recipe.
285 | Comparing with ImageNet ResNet recipe, it:
286 | (i) replaces conv1 with kernel=3, str=1
287 | (ii) removes pool1
288 | """
289 | def __init__(self, feature_dim=128, arch=None, bn_splits=16):
290 | super(ModelBase, self).__init__()
291 |
292 | # use split batchnorm
293 | norm_layer = partial(SplitBatchNorm, num_splits=bn_splits) if bn_splits > 1 else nn.BatchNorm2d
294 | resnet_arch = getattr(resnet, arch)
295 | net = resnet_arch(num_classes=feature_dim, norm_layer=norm_layer)
296 |
297 | self.net = []
298 | for name, module in net.named_children():
299 | if name == 'conv1':
300 | module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
301 | if isinstance(module, nn.MaxPool2d):
302 | continue
303 | if isinstance(module, nn.Linear):
304 | self.net.append(nn.Flatten(1))
305 | self.net.append(module)
306 |
307 | self.net = nn.Sequential(*self.net)
308 |
309 | def forward(self, x):
310 | x = self.net(x)
311 | # note: not normalized here
312 | return x
313 |
314 | """### Define MoCo wrapper"""
315 |
316 | class ModelMoCo(nn.Module):
317 | def __init__(self, dim=128, K=4096, m=0.99, T=0.1, arch='resnet18', bn_splits=8, symmetric=True):
318 | super(ModelMoCo, self).__init__()
319 |
320 | self.K = K
321 | self.m = m
322 | self.T = T
323 | self.symmetric = symmetric
324 |
325 | # create the encoders
326 | self.encoder_q = ModelBase(feature_dim=dim, arch=arch, bn_splits=bn_splits)
327 | self.encoder_k = ModelBase(feature_dim=dim, arch=arch, bn_splits=bn_splits)
328 |
329 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
330 | param_k.data.copy_(param_q.data) # initialize
331 | param_k.requires_grad = False # not update by gradient
332 |
333 | # create the queue
334 | self.register_buffer("queue", torch.randn(dim, K))
335 | self.queue = nn.functional.normalize(self.queue, dim=0)
336 |
337 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
338 |
339 | @torch.no_grad()
340 | def _momentum_update_key_encoder(self):
341 | """
342 | Momentum update of the key encoder
343 | """
344 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
345 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
346 |
347 | @torch.no_grad()
348 | def _dequeue_and_enqueue(self, keys):
349 | batch_size = keys.shape[0]
350 |
351 | ptr = int(self.queue_ptr)
352 | assert self.K % batch_size == 0 # for simplicity
353 |
354 | # replace the keys at ptr (dequeue and enqueue)
355 | self.queue[:, ptr:ptr + batch_size] = keys.t() # transpose
356 | ptr = (ptr + batch_size) % self.K # move pointer
357 |
358 | self.queue_ptr[0] = ptr
359 |
360 | @torch.no_grad()
361 | def _batch_shuffle_single_gpu(self, x):
362 | """
363 | Batch shuffle, for making use of BatchNorm.
364 | """
365 | # random shuffle index
366 | idx_shuffle = torch.randperm(x.shape[0]).cuda()
367 |
368 | # index for restoring
369 | idx_unshuffle = torch.argsort(idx_shuffle)
370 |
371 | return x[idx_shuffle], idx_unshuffle
372 |
373 | @torch.no_grad()
374 | def _batch_unshuffle_single_gpu(self, x, idx_unshuffle):
375 | """
376 | Undo batch shuffle.
377 | """
378 | return x[idx_unshuffle]
379 |
380 | def contrastive_loss(self, im_q, im_k):
381 | # compute query features
382 | q = self.encoder_q(im_q) # queries: NxC
383 | q = nn.functional.normalize(q, dim=1) # already normalized
384 |
385 | # compute key features
386 | with torch.no_grad(): # no gradient to keys
387 | # shuffle for making use of BN
388 | im_k_, idx_unshuffle = self._batch_shuffle_single_gpu(im_k)
389 |
390 | k = self.encoder_k(im_k_) # keys: NxC
391 | k = nn.functional.normalize(k, dim=1) # already normalized
392 |
393 | # undo shuffle
394 | k = self._batch_unshuffle_single_gpu(k, idx_unshuffle)
395 |
396 | # compute logits
397 | # Einstein sum is more intuitive
398 | # positive logits: Nx1
399 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
400 | # negative logits: NxK
401 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
402 |
403 | # logits: Nx(1+K)
404 | logits = torch.cat([l_pos, l_neg], dim=1)
405 |
406 | # apply temperature
407 | logits /= self.T
408 |
409 | # labels: positive key indicators
410 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
411 |
412 | loss = nn.CrossEntropyLoss().cuda()(logits, labels)
413 |
414 | return loss, q, k
415 |
416 | def forward(self, im1, im2):
417 | """
418 | Input:
419 | im_q: a batch of query images
420 | im_k: a batch of key images
421 | Output:
422 | loss
423 | """
424 |
425 | # update the key encoder
426 | with torch.no_grad(): # no gradient to keys
427 | self._momentum_update_key_encoder()
428 |
429 | # compute loss
430 | if self.symmetric: # symmetric loss
431 | loss_12, q1, k2 = self.contrastive_loss(im1, im2)
432 | loss_21, q2, k1 = self.contrastive_loss(im2, im1)
433 | loss = loss_12 + loss_21
434 | k = torch.cat([k1, k2], dim=0)
435 | else: # asymmetric loss
436 | loss, q, k = self.contrastive_loss(im1, im2)
437 |
438 | self._dequeue_and_enqueue(k)
439 |
440 | return loss
441 |
442 | # create model
443 | model = ModelMoCo(
444 | dim=args.moco_dim,
445 | K=args.moco_k,
446 | m=args.moco_m,
447 | T=args.moco_t,
448 | arch=args.arch,
449 | bn_splits=args.bn_splits,
450 | symmetric=args.symmetric,
451 | ).cuda()
452 | print(model.encoder_q)
453 |
454 | """### Define train/test
455 |
456 |
457 | """
458 |
459 | # train for one epoch
460 | def train(net, data_loader, train_optimizer, epoch, args, trainlist):
461 | net.train()
462 | adjust_learning_rate(train_optimizer, epoch, args, trainlist)
463 |
464 | total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader)
465 | for im_1, im_2 in train_bar:
466 |
467 | im_1, im_2 = im_1.cuda(non_blocking=True), im_2.cuda(non_blocking=True)
468 |
469 | loss = net(im_1, im_2)
470 |
471 | train_optimizer.zero_grad()
472 | loss.backward()
473 | train_optimizer.step()
474 |
475 | total_num += data_loader.batch_size
476 | total_loss += loss.item() * data_loader.batch_size
477 | train_bar.set_description('Train Epoch: [{}/{}], lr: {:.6f}, Loss: {:.10f}'.format(epoch, args.epochs, train_optimizer.param_groups[0]['lr'], total_loss / total_num))
478 |
479 | return total_loss / total_num
480 |
481 | # lr scheduler for training
482 | def adjust_learning_rate(optimizer, epoch, args,trainlist):
483 | """Decay the learning rate based on schedule"""
484 | lr = args.lr
485 | if args.cos: # cosine lr schedule
486 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
487 | else: # stepwise lr schedule
488 | for milestone in args.schedule:
489 | lr *= 0.1 if epoch >= milestone else 1.
490 | for param_group in optimizer.param_groups:
491 | param_group['lr'] = lr
492 |
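# Worked numbers for the cosine branch with this script's defaults
# (lr = 0.006, epochs = 150): lr(epoch) = 0.006 * 0.5 * (1 + cos(pi*epoch/150)),
# so epoch 0 -> 0.006, epoch 75 -> 0.003, epoch 150 -> 0. With --cos set the
# milestone list is ignored, which is why args.schedule is cleared above.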
493 | # test using a knn monitor
494 | def test(net, memory_data_loader, test_data_loader, epoch, args):
495 | net.eval()
496 | classes = len(memory_data_loader.dataset.classes)
497 | total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, []
498 | with torch.no_grad():
499 | # generate feature bank
500 | for data, target in tqdm(memory_data_loader, desc='Feature extracting'):
501 | feature = net(data.cuda(non_blocking=True))
502 | feature = F.normalize(feature, dim=1)
503 | feature_bank.append(feature)
504 | # [D, N]
505 | feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
506 | # [N]
507 | feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device)
508 | # loop test data to predict the label by weighted knn search
509 | test_bar = tqdm(test_data_loader)
510 | for data, target in test_bar:
511 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
512 | feature = net(data)
513 | feature = F.normalize(feature, dim=1)
514 |
515 | pred_labels = knn_predict(feature, feature_bank, feature_labels, classes, args.knn_k, args.knn_t)
516 |
517 | total_num += data.size(0)
518 | total_top1 += (pred_labels[:, 0] == target).float().sum().item()
519 | test_bar.set_description('Test Epoch: [{}/{}] Acc@1:{:.2f}%'.format(epoch, args.epochs, total_top1 / total_num * 100))
520 |
521 | return total_top1 / total_num * 100
522 |
523 | def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
524 | # compute cos similarity between each feature vector and feature bank ---> [B, N]
525 | sim_matrix = torch.mm(feature, feature_bank)
526 | # [B, K]
527 | sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
528 | # [B, K]
529 | sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices)
530 | sim_weight = (sim_weight / knn_t).exp()
531 |
532 | # counts for each class
533 | one_hot_label = torch.zeros(feature.size(0) * knn_k, classes, device=sim_labels.device)
534 | # [B*K, C]
535 | one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
536 | # weighted score ---> [B, C]
537 | pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1)
538 |
539 | pred_labels = pred_scores.argsort(dim=-1, descending=True)
540 | return pred_labels
541 |
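# A tiny worked example of the voting above (illustrative numbers): with
# knn_k = 3, suppose a query's top-3 neighbours carry labels [2, 0, 2] and
# temperature-scaled weights [0.9, 0.5, 0.4]. The one-hot scatter and
# weighted sum give class scores {0: 0.5, 2: 1.3}, so argsort puts class 2
# first; total_top1 in test() then checks that first column against the
# ground truth.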
542 | """### Start training"""
543 |
544 | # define optimizer
545 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd, momentum=0.9)
546 |
547 | # load model if resume
548 | epoch_start = 1
549 | if args.resume != '':
550 | checkpoint = torch.load(args.resume, map_location=torch.device('cuda:0'))
551 | model.load_state_dict(checkpoint['state_dict'])
552 | optimizer.load_state_dict(checkpoint['optimizer'])
553 | epoch_start = checkpoint['epoch'] + 1
554 | print('Loaded from: {}'.format(args.resume))
555 |
556 | # logging
557 | results = {'train_loss': []}
558 | if not os.path.exists(args.results_dir):
559 | os.mkdir(args.results_dir)
560 | # dump args
561 | with open(args.results_dir + '/args.json', 'w') as fid:
562 | json.dump(args.__dict__, fid, indent=2)
563 |
564 | # training loop
565 | trainlist=[]
566 | for epoch in range(epoch_start, args.epochs + 1):
567 | train_loss = train(model, train_loader, optimizer, epoch, args,trainlist)
568 | results['train_loss'].append(train_loss)
569 | data_frame = pd.DataFrame(data=results, index=range(epoch_start, epoch + 1))
570 | data_frame.to_csv(args.results_dir + '/log.csv', index_label='epoch')
571 | # save model
572 | torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer' : optimizer.state_dict(),}, args.results_dir + '/model_last.pth')
--------------------------------------------------------------------------------