├── .idea
│   ├── .name
│   ├── .gitignore
│   ├── vcs.xml
│   ├── inspectionProfiles
│   │   ├── profiles_settings.xml
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── HUST-OBS.iml
│   └── deployment.xml
├── OCR
│   ├── use.json
│   ├── result.json
│   ├── test_pic1.jpg
│   ├── test_pic2.png
│   ├── Dataset establishment.py
│   ├── test.py
│   └── train.py
├── MoCo
│   ├── Dataset establishment.py
│   ├── test.py
│   └── train.py
├── requirements.txt
├── Validation
│   ├── standard deviation.py
│   ├── Dataset establishment.py
│   ├── test.py
│   ├── train.py
│   └── Validation_label.json
└── README.md

--------------------------------------------------------------------------------
/.idea/.name:
--------------------------------------------------------------------------------
train.py

--------------------------------------------------------------------------------
/OCR/use.json:
--------------------------------------------------------------------------------
["test_pic1.jpg", "test_pic2.png"]

--------------------------------------------------------------------------------
/OCR/result.json:
--------------------------------------------------------------------------------
{"test_pic1.jpg": ["五"], "test_pic2.png": ["璧"]}

--------------------------------------------------------------------------------
/OCR/test_pic1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pengjie-W/HUST-OBC/HEAD/OCR/test_pic1.jpg

--------------------------------------------------------------------------------
/OCR/test_pic2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pengjie-W/HUST-OBC/HEAD/OCR/test_pic2.png

--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
# Files ignored by default
/shelf/
/workspace.xml
# Editor-based HTTP client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

--------------------------------------------------------------------------------
/OCR/Dataset establishment.py:
--------------------------------------------------------------------------------
import copy
import json
import os
from tqdm import tqdm

# Index every image in the OCR dataset; the file name (without extension)
# is the integer class label.
folder_path = './OCR_Dataset'
dataset = []
for root, directories, files in tqdm(os.walk(folder_path)):
    for file in files:
        data = {}
        file_path = os.path.join(root, file)
        data['label'] = int(file.replace('.png', ''))
        data['path'] = file_path
        dataset.append(copy.deepcopy(data))

print(len(dataset))
with open('OCR_train.json', 'w', encoding='utf-8') as f:
    json.dump(dataset, f, ensure_ascii=False)
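
# Illustrative check (not part of the original script): each entry pairs an
# integer class label with an image path, e.g.
# {'label': 42, 'path': './OCR_Dataset/.../42.png'} (the path shape is assumed).
print(dataset[0] if dataset else 'empty dataset')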
--------------------------------------------------------------------------------
/MoCo/Dataset establishment.py:
--------------------------------------------------------------------------------
import copy
import json
import os
from tqdm import tqdm

# Index the deciphered images for MoCo training. Files whose names contain
# 'ID' (the chinese_to_ID.json / ID_to_chinese.json mapping files) are
# skipped; for images named Source_ID_Filename, characters 2..5 are the
# 4-digit class ID.
folder_path = '../HUST-OBC/deciphered'
dataset = []
for root, directories, files in tqdm(os.walk(folder_path)):
    for file in files:
        if 'ID' in file:
            continue
        data = {}
        file_path = os.path.join(root, file)
        data['label'] = int(file[2:6])
        data['path'] = file_path
        dataset.append(copy.deepcopy(data))

print(len(dataset))
with open('MOCO_train.json', 'w', encoding='utf-8') as f:
    json.dump(dataset, f, ensure_ascii=False)

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
filelock==3.14.0
fsspec==2024.5.0
Jinja2==3.1.4
joblib==1.4.2
MarkupSafe==2.1.5
mpmath==1.3.0
networkx==3.3
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.5.40
nvidia-nvtx-cu12==12.1.105
opencv-python==4.9.0.80
pandas==2.2.2
pillow==10.3.0
python-dateutil==2.9.0.post0
pytz==2024.1
scikit-learn==1.5.0
scipy==1.13.1
six==1.16.0
sympy==1.12
threadpoolctl==3.5.0
torch==2.3.0
torchaudio==2.3.0
torchvision==0.18.0
tqdm==4.66.4
triton==2.3.0
typing_extensions==4.12.0
tzdata==2024.1

--------------------------------------------------------------------------------
/Validation/standard deviation.py:
--------------------------------------------------------------------------------
import json
from tqdm import tqdm
import torch
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

class TrainData(Dataset):
    def __init__(self, transform=None):
        super(TrainData, self).__init__()
        with open('Validation_train.json', 'r', encoding='utf8') as f:
            images = json.load(f)
        labels = images
        self.images, self.labels = images, labels
        self.transform = transform

    def __getitem__(self, item):
        # Load the image
        image = Image.open(self.images[item]['path'].replace('\\', '/'))
        if image.mode == 'L':
            image = image.convert('RGB')
        # Pad the shorter side with white so the image becomes square
        width, height = image.size
        if width > height:
            dy = width - height
            yl = round(dy / 2)
            yr = dy - yl
            train_transform = transforms.Compose([
                transforms.Pad([0, yl, 0, yr], fill=(255, 255, 255), padding_mode='constant'),
            ])
        else:
            dx = height - width
            xl = round(dx / 2)
            xr = dx - xl
            train_transform = transforms.Compose([
                transforms.Pad([xl, 0, xr, 0], fill=(255, 255, 255), padding_mode='constant'),
            ])

        image = train_transform(image)
        train_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        image = train_transform(image)
        return image, self.images[item]['label']

    def __len__(self):
        return len(self.images)

def getStat(train_data):
    '''
    Compute per-channel mean and standard deviation of the training data.
    :param train_data: a custom Dataset (an ImageFolder also works)
    :return: (mean, std)
    '''
    print('Compute mean and std for training data.')
    print(len(train_data))
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=4,
        pin_memory=True)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    # Note: this averages per-image means and stds over the dataset, which
    # approximates (but is not identical to) the global statistics.
    for X, _ in tqdm(train_loader):
        for d in range(3):
            mean[d] += X[:, d, :, :].mean()
            std[d] += X[:, d, :, :].std()
    mean.div_(len(train_data))
    std.div_(len(train_data))
    return list(mean.numpy()), list(std.numpy())


if __name__ == '__main__':
    train_dataset = TrainData()
    print(getStat(train_dataset))
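
# For reference: the Normalize() constants used in Validation/train.py and
# Validation/test.py ([0.85233593, 0.85246795, 0.8517555] /
# [0.31232414, 0.3122127, 0.31273854]) were obtained with this script.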
--------------------------------------------------------------------------------
/Validation/Dataset establishment.py:
--------------------------------------------------------------------------------
import copy
import json
import os
from pathlib import Path
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
# Fix the random seeds
seed = 42
np.random.seed(seed)
import random
random.seed(seed)

dataset = []
X = []
y = []
for root, directories, files in os.walk('../HUST-OBC/deciphered/'):
    for file in files:
        data = {}
        if 'json' in file:
            continue
        file_path = str(Path(os.path.join(root, file)))
        folders = os.path.split(file_path)[0].split(os.sep)
        folder_name = folders[3]
        y.append(folder_name)
        X.append(file_path)

# Count samples per class
unique_classes, class_counts = np.unique(y, return_counts=True)
single_sample_classes = unique_classes[class_counts == 1]
multiple_sample_classes = unique_classes[class_counts > 1]

# Classes with a single sample go straight into the training set
train_indices = [idx for idx, label in enumerate(y) if label in single_sample_classes]
remaining_indices = [idx for idx in range(len(y)) if idx not in train_indices]

# Split the remaining samples into training and validation/test sets (8:2)
X_train = [X[idx] for idx in train_indices]
y_train = [y[idx] for idx in train_indices]

X_remaining = [X[idx] for idx in remaining_indices]
y_remaining = [y[idx] for idx in remaining_indices]

# Stratified split
stratified_splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_idx, test_val_idx in stratified_splitter.split(X_remaining, y_remaining):
    X_train_remaining = [X_remaining[idx] for idx in train_idx]
    y_train_remaining = [y_remaining[idx] for idx in train_idx]
    X_test_val = [X_remaining[idx] for idx in test_val_idx]
    y_test_val = [y_remaining[idx] for idx in test_val_idx]

X_train.extend(X_train_remaining)
y_train.extend(y_train_remaining)

unique_classes, class_counts = np.unique(y_test_val, return_counts=True)
single_sample_classes = unique_classes[class_counts == 1]
multiple_sample_classes = unique_classes[class_counts > 1]

# Classes left with a single sample are assigned at random to the validation or test set
train_indices = [idx for idx, label in enumerate(y_test_val) if label in single_sample_classes]
remaining_indices = [idx for idx in range(len(y_test_val)) if idx not in train_indices]

X_val = []
y_val = []
X_test = []
y_test = []

for idx in train_indices:
    if np.random.rand() < 0.5:
        X_val.append(X_test_val[idx])
        y_val.append(y_test_val[idx])
    else:
        X_test.append(X_test_val[idx])
        y_test.append(y_test_val[idx])

# Stratified split of the remaining samples into validation and test sets (1:1)
X_test_val_remaining = [X_test_val[idx] for idx in remaining_indices]
y_test_val_remaining = [y_test_val[idx] for idx in remaining_indices]

stratified_splitter_test_val = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
for val_idx, test_idx in stratified_splitter_test_val.split(X_test_val_remaining, y_test_val_remaining):
    X_val.extend([X_test_val_remaining[idx] for idx in val_idx])
    y_val.extend([y_test_val_remaining[idx] for idx in val_idx])
    X_test.extend([X_test_val_remaining[idx] for idx in test_idx])
    y_test.extend([y_test_val_remaining[idx] for idx in test_idx])

unique_classes, class_counts = np.unique(y, return_counts=True)
# Build the label dictionary (class name -> integer ID)
dataset = {label: idx for idx, label in enumerate(unique_classes)}

# Save the training, validation and test sets
train_data = [{'path': path, 'label': dataset[label]} for path, label in zip(X_train, y_train)]
val_data = [{'path': path, 'label': dataset[label]} for path, label in zip(X_val, y_val)]
test_data = [{'path': path, 'label': dataset[label]} for path, label in zip(X_test, y_test)]

with open('Validation_train.json', 'w', encoding='utf8') as f:
    json.dump(train_data, f, ensure_ascii=False)
with open('Validation_label.json', 'w', encoding='utf8') as f:
    json.dump(dataset, f, ensure_ascii=False)
with open('Validation_val.json', 'w', encoding='utf8') as f:
    json.dump(val_data, f, ensure_ascii=False)
with open('Validation_test.json', 'w', encoding='utf8') as f:
    json.dump(test_data, f, ensure_ascii=False)

print(f'Training set size: {len(train_data)}')
print(f'Validation set size: {len(val_data)}')
print(f'Test set size: {len(test_data)}')
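
# Illustrative post-hoc check (not in the original script): report classes
# that end up in val/test without any training example; ideally this is zero.
train_classes = {d['label'] for d in train_data}
held_out = {d['label'] for d in val_data + test_data} - train_classes
print(f'Classes absent from the training set: {len(held_out)}')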
--------------------------------------------------------------------------------
/Validation/test.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import json
import os
import random
from datetime import datetime
import cv2
import numpy as np
import torch
import torchvision
from sklearn.metrics import f1_score
from torch.nn import functional as F
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import math
import argparse
from tqdm import tqdm
import pandas as pd

"""### Set arguments"""
parser = argparse.ArgumentParser(description='Test on HUST-OBC')

parser.add_argument('--lr', '--learning-rate', default=0.015, type=float, metavar='LR', help='initial learning rate',
                    dest='lr')
parser.add_argument('--epochs', default=1000, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--batch_size', default=128, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--num_workers', default=0, type=int)
parser.add_argument('--wd', default=5e-4, type=float, metavar='W', help='weight decay')

# utils
parser.add_argument('--resume', default='./max_val_acc.pth', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--results-dir', default='test', type=str, metavar='PATH', help='path to cache (default: none)')
args = parser.parse_args()  # running in command line
if args.results_dir == '':
    args.results_dir = './cache-' + datetime.now().strftime("%Y-%m-%d-%H-%M-%S-moco")
print(args)

class TestData(Dataset):
    def __init__(self, transform=None):
        super(TestData, self).__init__()
        with open('Validation_test.json', 'r', encoding='utf8') as f:
            images = json.load(f)
        labels = images
        self.images, self.labels = images, labels
        self.transform = transform

    def __getitem__(self, item):
        # Load the image
        image = Image.open(self.images[item]['path'].replace('\\', '/'))
        # Convert grayscale to RGB
        if image.mode == 'L':
            image = image.convert('RGB')
        # Pad the shorter side with white so the image becomes square
        width, height = image.size
        if width > height:
            dy = width - height
            yl = round(dy / 2)
            yr = dy - yl
            train_transform = transforms.Compose([
                transforms.Pad([0, yl, 0, yr], fill=(255, 255, 255), padding_mode='constant'),
            ])
        else:
            dx = height - width
            xl = round(dx / 2)
            xr = dx - xl
            train_transform = transforms.Compose([
                transforms.Pad([xl, 0, xr, 0], fill=(255, 255, 255), padding_mode='constant'),
            ])

        image = train_transform(image)
        train_transform = transforms.Compose([
            transforms.Resize((128, 128)),
            # transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.85233593, 0.85246795, 0.8517555], [0.31232414, 0.3122127, 0.31273854])])
        image = train_transform(image)
        label = torch.from_numpy(np.array(self.images[item]['label']))
        return image, label, self.images[item]['path'].replace('\\', '/')

    def __len__(self):
        return len(self.images)


test_dataset = TestData()
test_loader = DataLoader(test_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)


net = torchvision.models.resnet50(pretrained=False)
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, 1588)
net = net.cuda(0)


def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)


optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, weight_decay=0.0005, momentum=0.9)
loss = nn.CrossEntropyLoss()


def accuracy(y_hat, y):
    """Compute the number of correct predictions.

    Defined in :numref:`sec_softmax_scratch`"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = torch.argmax(y_hat, dim=1)
    if len(y.shape) > 1 and y.shape[1] > 1:
        y = torch.argmax(y, dim=1)
    cmp = torch.eq(y_hat, y)
    return float(torch.sum(cmp).item())


def test(net, test_data_loader, epoch, args):
    net.eval()
    all_labels = []
    all_preds = []
    testacc, total_top5, total_num, test_bar = 0.0, 0.0, 0, tqdm(test_data_loader)
    with torch.no_grad():
        for image, label, path in test_bar:
            image, label = image.cuda(0), label.cuda(0)
            y_hat = net(image)
            _, preds = torch.max(y_hat, 1)
            all_labels.extend(label.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            total_num += image.shape[0]
            testacc += accuracy(y_hat, label)
            test_bar.set_description(
                'Test Epoch: [{}/{}], testacc: {:.6f}'.format(epoch, args.epochs, testacc / total_num))
    f1_macro = f1_score(all_labels, all_preds, average='macro')
    f1_micro = f1_score(all_labels, all_preds, average='micro')
    print(f'Macro-averaged F1 score: {f1_macro}')
    print(f'Micro-averaged F1 score: {f1_micro}')
    return testacc / total_num

results = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'lr': []}
epoch_start = 1
if args.resume != '':
    checkpoint = torch.load(args.resume)
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    epoch_start = checkpoint['epoch']
    print('Loaded from: {}'.format(args.resume))
else:
    net.apply(init_weights)

test_acc = test(net, test_loader, epoch_start, args)
print(test_acc)
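
# Example invocation (illustrative; max_val_acc.pth is the released checkpoint
# referenced in the README):
#   python test.py --resume ./max_val_acc.pth --batch_size 128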
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# HUST-OBC
[![Paper](https://img.shields.io/badge/Paper-white)](https://arxiv.org/abs/2401.15365)
[![figshare](https://img.shields.io/badge/figshare-blue)](https://doi.org/10.6084/m9.figshare.25040543.v3)
[![Download Dataset](https://img.shields.io/badge/hyper.ai-pink)](https://hyper.ai/datasets/33506)

Oracle Bone Character data collected by the VLRLab of HUST.
We have open-sourced the HUST-OBC dataset and the models used with it: the Chinese OCR model, MoCo, and the ResNet-50 used for validation.

## HUST-OBC Dataset
[HUST-OBC Download](https://figshare.com/s/8a9c0420312d94fc01e3)
### Tree of our dataset
- HUST-OBC **(We have renamed HUST-OBS to HUST-OBC)**
  - deciphered
    - ID1
      - Source_ID1_Filename
      - Source_ID1_Filename
      - .....
    - ID2
      - Source_ID2_Filename
      - .....
    - ID3
      - .....
    - chinese_to_ID.json
    - ID_to_chinese.json
  - undeciphered
    - L
      - L_?_Filename
      - L_?_Filename
      - .....
    - X
      - X_?_Filename
      - .....
    - Y+H
      - Y_?_Filename
      - H_?_Filename
      - .....
  - GuoXueDaShi_1390
    - ID1
      - Source_ID1_Filename
      - Source_ID1_Filename
      - .....
    - ID2
      - Source_ID2_Filename
      - .....
    - ID3
      - .....
    - chinese_to_ID.json
    - ID_to_chinese.json

Source: the leading letter of each file name records where the sample comes from. 'X' represents the "New Compilation of Oracle Bone Scripts", 'L' the "Oracle Bone Script: Six Digit Numerical Code", 'G' the "GuoXueDaShi" website, 'Y' the "YinQiWenYuan" website, and 'H' the HWOBC dataset.
## Environment
```bash
conda create -n HUST-OBC python=3.10
conda activate HUST-OBC
git clone https://github.com/Pengjie-W/HUST-OBC.git
cd HUST-OBC
pip install -r requirements.txt
```
## Instructions for use
To use MoCo or Validation, you need to download HUST-OBC; you can then run their trained models directly for prediction. If you want to use the Chinese OCR model, please download the OCR dataset and the corresponding model. After downloading, organize the data as follows (a quick layout check is sketched after the list).

- Your_dataroot
  - [HUST-OBC](https://figshare.com/s/8a9c0420312d94fc01e3)
    - deciphered
    - ...
  - MoCo
    - [model_last.pth](https://figshare.com/s/30c206b1d1f1870ae76f)
    - ...
  - OCR
    - [OCR_Dataset](https://figshare.com/s/b03be2bccdd867b73e5f)
    - [model_last.pth](https://figshare.com/s/7ec755b4ba77c6994ed2)
    - ...
  - Validation
    - [max_val_acc.pth](https://figshare.com/s/4149c5c7f52e0f99e366)
    - ...
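
A minimal sketch for confirming the downloads landed in the right places (run from `Your_dataroot`; the relative paths are taken from the list above):

```python
from pathlib import Path

# Prints which of the expected downloads are present.
expected = [
    'HUST-OBC/deciphered',
    'MoCo/model_last.pth',
    'OCR/OCR_Dataset',
    'OCR/model_last.pth',
    'Validation/max_val_acc.pth',
]
for rel in expected:
    print(f'{rel}: {"found" if Path(rel).exists() else "missing"}')
```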
## Chinese OCR
The code for training and testing (usage) is provided in the OCR folder. The model covers 88,899 classes of Chinese characters. [Model download](https://figshare.com/s/7ec755b4ba77c6994ed2). Category numbers and their corresponding Chinese characters are stored in OCR/label.json. We have provided models and code with α set to 0.
[OCR Dataset download](https://figshare.com/s/b03be2bccdd867b73e5f).

You can use [train.py](OCR/train.py) for fine-tuning or retraining. [Chinese_to_ID.json](OCR/Chinese_to_ID.json) and [ID_to_Chinese.json](OCR/ID_to_Chinese.json) store the mappings between OCR dataset category IDs and Chinese characters. [Dataset establishment.py](OCR/Dataset%20establishment.py) is used to generate the training dataset [OCR_train.json](OCR/OCR_train.json). Once the model is downloaded, you can directly use [test.py](OCR/test.py) for testing; it ships with two example test images, Chinese character images cropped from PDFs. Images with a white background work best. [use.json](OCR/use.json) contains the paths of the images to recognize, saved as a list. The recognized characters are written to [result.json](OCR/result.json), as shown in the sketch below.
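
For example, recognizing the two bundled test images (a minimal sketch using the files shipped in this folder; run inside `OCR/`):

```python
import json

# Point the recognizer at some images: use.json is simply a list of paths.
with open('use.json', 'w', encoding='utf8') as f:
    json.dump(['test_pic1.jpg', 'test_pic2.png'], f, ensure_ascii=False)

# Then run: python test.py --resume model_last.pth --k 1
# result.json maps each path to the top-k predicted characters, e.g.
# {"test_pic1.jpg": ["五"], "test_pic2.png": ["璧"]}
with open('result.json', 'r', encoding='utf8') as f:
    print(json.load(f))
```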
## MoCo
The code for training and testing (usage) is provided in the MoCo folder. [Model download](https://figshare.com/s/30c206b1d1f1870ae76f).

You can use [train.py](MoCo/train.py) for fine-tuning or retraining; [Dataset establishment.py](MoCo/Dataset%20establishment.py) generates the training dataset [MOCO_train.json](MoCo/MOCO_train.json). After downloading the MoCo model, [test.py](MoCo/test.py) runs MoCo over the 1,781 unmerged oracle-bone categories: for each sample it looks for the first sample from another category whose similarity exceeds args.w, thereby measuring the similarity between different categories of oracle bone characters. The results are saved in [result.json](MoCo/result.json).
## Validation
The code for training and testing (usage) is provided in the Validation folder. [Model download](https://figshare.com/s/4149c5c7f52e0f99e366).

[Dataset establishment.py](Validation/Dataset%20establishment.py) is used for splitting the dataset. Since the classification model cannot recognize unseen categories, all categories with only one sample are allocated to the training set. [Validation_test.json](Validation/Validation_test.json), [Validation_val.json](Validation/Validation_val.json) and [Validation_train.json](Validation/Validation_train.json) are the test, validation and training sets, respectively, split in a 1:1:8 ratio. [standard deviation.py](Validation/standard%20deviation.py) is used to obtain the mean and standard deviation of the training set.

You can use [train.py](Validation/train.py) for fine-tuning or retraining. Once the model is downloaded, you can use [test.py](Validation/test.py) to evaluate the test set; accuracy is 94.6%. [log.csv](Validation/log.csv) records the training-set and validation-set accuracy for each epoch.
[Validation_label.json](Validation/Validation_label.json) stores the mapping between classification IDs and dataset category IDs, as in the sketch below.
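
For instance, to go from a predicted class index back to a dataset category ID (a minimal sketch; Validation_label.json maps category ID → classification index, so we invert it):

```python
import json

with open('Validation/Validation_label.json', 'r', encoding='utf8') as f:
    label_map = json.load(f)  # {dataset category ID: classification index}

index_to_category = {v: k for k, v in label_map.items()}
print(index_to_category[0])  # category ID behind classifier output 0
```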
--------------------------------------------------------------------------------
/OCR/test.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import json
import os
import random
from datetime import datetime

import numpy as np
import torch
from torch.nn import functional as F
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import math
import argparse
from tqdm import tqdm
import pandas as pd

"""### Set arguments"""

parser = argparse.ArgumentParser(description='Test on Chinese OCR Dataset')
# utils
parser.add_argument('--resume', default='model_last.pth', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--results-dir', default='test', type=str, metavar='PATH', help='path to cache (default: none)')
parser.add_argument('--k', default=1, type=int)  # Display the top k Chinese characters by probability.
parser.add_argument('--batch_size', default=64, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--num_workers', default=0, type=int)
args = parser.parse_args()  # running in command line

if args.results_dir == '':
    args.results_dir = './cache-' + datetime.now().strftime("%Y-%m-%d-%H-%M-%S-moco")
print(args)

with open('ID_to_Chinese.json', 'r', encoding='utf8') as f:
    data = json.load(f)

class TestData(Dataset):
    def __init__(self, transform=None):
        super(TestData, self).__init__()
        with open('use.json', 'r') as f:
            images = json.load(f)
        labels = images
        self.images, self.labels = images, labels
        self.transform = transform

    def __getitem__(self, item):
        # Load the image
        image = Image.open(self.images[item])
        # Convert grayscale to RGB
        if image.mode == 'L':
            image = image.convert('RGB')
        # Get the current image size and pad the shorter side with white
        width, height = image.size
        if width > height:
            dy = width - height
            yl = round(dy / 2)
            yr = dy - yl
            train_transform = transforms.Compose([
                transforms.Pad([0, yl, 0, yr], fill=(255, 255, 255), padding_mode='constant'),
            ])
        else:
            dx = height - width
            xl = round(dx / 2)
            xr = dx - xl
            train_transform = transforms.Compose([
                transforms.Pad([xl, 0, xr, 0], fill=(255, 255, 255), padding_mode='constant'),
            ])

        image = train_transform(image)
        train_transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.7760929, 0.7760929, 0.7760929], [0.39767382, 0.39767382, 0.39767382])])
        image = train_transform(image)
        return image, self.images[item]

    def __len__(self):
        return len(self.images)

test_dataset = TestData()
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)


class Residual(nn.Module):
    def __init__(self, input_channels, min_channels, num_channels,
                 use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, min_channels,
                               kernel_size=1)
        self.conv2 = nn.Conv2d(min_channels, min_channels,
                               kernel_size=3, padding=1, stride=strides)
        self.conv3 = nn.Conv2d(min_channels, num_channels,
                               kernel_size=1)
        if use_1x1conv:
            self.conv4 = nn.Conv2d(input_channels, num_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv4 = None
        self.bn1 = nn.BatchNorm2d(min_channels)
        self.bn2 = nn.BatchNorm2d(min_channels)
        self.bn3 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        Y = self.bn3(self.conv3(Y))
        if self.conv4:
            X = self.conv4(X)
        Y += X
        return F.relu(Y)

b1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64), nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))


def resnet_block(input_channels, min_channels, num_channels, num_residuals, stride,
                 first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(input_channels, min_channels, num_channels,
                                use_1x1conv=True, strides=stride))
        elif first_block and i == 0:
            blk.append(Residual(input_channels, min_channels, num_channels, use_1x1conv=True))
        else:
            blk.append(Residual(num_channels, min_channels, num_channels))
    return blk


b2 = nn.Sequential(*resnet_block(64, 64, 256, 3, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(256, 128, 512, 4, 2))
b4 = nn.Sequential(*resnet_block(512, 256, 1024, 6, 2))
b5 = nn.Sequential(*resnet_block(1024, 512, 2048, 2, 2))
net = nn.Sequential(b1, b2, b3, b4, b5,
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(), nn.Linear(2048, 88899))

net = net.cuda(0)

def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)


def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate based on schedule"""
    lr = args.lr
    lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(y_hat, y):
    """Compute the number of correct predictions.

    Defined in :numref:`sec_softmax_scratch`"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = torch.argmax(y_hat, dim=1)
    if len(y.shape) > 1 and y.shape[1] > 1:
        y = torch.argmax(y, dim=1)
    cmp = torch.eq(y_hat, y)
    return float(torch.sum(cmp).item())


def test(net, test_data_loader):
    net.eval()
    testacc, total_top5, total_num, test_bar = 0.0, 0.0, 0, tqdm(test_data_loader)
    with torch.no_grad():
        pathlist = []
        labellist = []
        for image, path in test_bar:
            image = image.cuda(0)
            y_hat = net(image)
            # Keep the indices of the k highest-scoring classes
            y_hat = torch.topk(y_hat, args.k, dim=1)[1]
            label = y_hat.tolist()
            labellist = labellist + label
            path = list(path)
            pathlist = pathlist + path

    dataset = {}
    for i in range(len(pathlist)):
        path_label = []
        path = pathlist[i]
        label = labellist[i]
        for j in label:
            j = str(j).zfill(5)
            path_label.append(data[j])
        dataset[path] = path_label
    with open('result.json', 'w', encoding='utf8') as f:
        json.dump(dataset, f, ensure_ascii=False)
    return


results = {'train_loss': [], 'train_acc': [], 'lr': []}
epoch_start = 1
if args.resume != '':
    checkpoint = torch.load(args.resume)
    net.load_state_dict(checkpoint['state_dict'])
    # optimizer.load_state_dict(checkpoint['optimizer'])
    epoch_start = checkpoint['epoch'] + 1
    print('Loaded from: {}'.format(args.resume))
else:
    net.apply(init_weights)

test(net, test_loader)
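
# Note (added for clarity): keys in ID_to_Chinese.json are zero-padded 5-digit
# strings, hence str(j).zfill(5) above when mapping class indices back to
# characters.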
--------------------------------------------------------------------------------
/OCR/train.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import json
import os
import random
from datetime import datetime
import cv2
import numpy as np
import torch
from torch.nn import functional as F
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import math
import argparse
from tqdm import tqdm
import pandas as pd
# nohup python train.py > output.log 2>&1 &
"""### Set arguments"""
parser = argparse.ArgumentParser(description='Train on Chinese OCR Dataset')

parser.add_argument('--lr', '--learning-rate', default=0.00015, type=float, metavar='LR', help='initial learning rate',
                    dest='lr')
parser.add_argument('--epochs', default=1800, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--batch_size', default=128, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--num_workers', default=9, type=int)
parser.add_argument('--wd', default=5e-4, type=float, metavar='W', help='weight decay')
# utils
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--results-dir', default='model', type=str, metavar='PATH', help='path to cache (default: none)')
args = parser.parse_args()  # running in command line
if args.results_dir == '':
    args.results_dir = './cache-' + datetime.now().strftime("%Y-%m-%d-%H-%M-%S-moco")
print(args)


class RandomGaussianBlur(object):
    def __init__(self, p=0.5, min_kernel_size=3, max_kernel_size=15, min_sigma=0.1, max_sigma=1.0):
        self.p = p
        self.min_kernel_size = min_kernel_size
        self.max_kernel_size = max_kernel_size
        self.min_sigma = min_sigma
        self.max_sigma = max_sigma

    def __call__(self, img):
        if random.random() < self.p and self.min_kernel_size < self.max_kernel_size:
            kernel_size = random.randrange(self.min_kernel_size, self.max_kernel_size + 1, 2)
            sigma = random.uniform(self.min_sigma, self.max_sigma)
            return transforms.functional.gaussian_blur(img, kernel_size, sigma)
        else:
            return img

def jioayan(image):
    # With probability 0.5, add salt-and-pepper noise
    if np.random.random() < 0.5:
        image1 = np.array(image)
        salt_vs_pepper_ratio = np.random.uniform(0, 0.4)
        amount = np.random.uniform(0, 0.006)
        num_salt = np.ceil(amount * image1.size / 3 * salt_vs_pepper_ratio)
        num_pepper = np.ceil(amount * image1.size / 3 * (1.0 - salt_vs_pepper_ratio))

        # Place the salt and pepper noise at random positions
        coords_salt = [np.random.randint(0, i - 1, int(num_salt)) for i in image1.shape]
        coords_pepper = [np.random.randint(0, i - 1, int(num_pepper)) for i in image1.shape]
        image1[coords_salt[0], coords_salt[1], :] = 255
        image1[coords_pepper[0], coords_pepper[1], :] = 0
        image = Image.fromarray(image1)
    return image

def pengzhang(image):
    # Draw a random number in [0, 3)
    random_value = random.random() * 3

    if random_value < 1:  # with probability 1/3, erode the image
        he = random.randint(1, 3)
        kernel = np.ones((he, he), np.uint8)
        image = cv2.erode(image, kernel, iterations=1)
    elif random_value < 2:  # with probability 1/3, dilate the image
        he = random.randint(1, 3)  # random kernel size between 1 and 3
        kernel = np.ones((he, he), np.uint8)
        image = cv2.dilate(image, kernel, iterations=1)
    return image

class TrainData(Dataset):
    def __init__(self, transform=None):
        super(TrainData, self).__init__()
        with open('OCR_train.json', 'r') as f:
            images = json.load(f)
        labels = images
        self.images, self.labels = images, labels
        self.transform = transform

    def __getitem__(self, item):
        # Load the image
        image = Image.open(self.images[item]['path'].replace('\\', '/'))
        # Convert grayscale to RGB
        if image.mode == 'L':
            image = image.convert('RGB')
        # Sample a random target size around 72px (clipped to [16, 128]) and a
        # random placement inside the 128x128 canvas
        x, y = 72, 72
        sizey, sizex = 129, 129
        if y < 128:
            while sizey > 128 or sizey < 16:
                sizey = round(random.gauss(y, 30))
        if x < 128:
            while sizex > 128 or sizex < 16:
                sizex = round(random.gauss(x, 30))
        dx = 128 - sizex  # remaining width to fill with padding
        dy = 128 - sizey
        if dx > 0:
            xl = -1
            while xl > dx or xl < 0:
                xl = round(dx / 2)
                xl = round(random.gauss(xl, 10))
        else:
            xl = 0
        if dy > 0:
            yl = -1
            while yl > dy or yl < 0:
                yl = round(dy / 2)
                yl = round(random.gauss(yl, 10))
        else:
            yl = 0
        yr = dy - yl
        xr = dx - xl
        image = jioayan(image)
        image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        image = pengzhang(image)
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        random_gaussian_blur = RandomGaussianBlur()
        image = random_gaussian_blur(image)
        train_transform = transforms.Compose([
            transforms.Resize((sizey, sizex)),
            transforms.Pad([xl, yl, xr, yr], fill=(255, 255, 255), padding_mode='constant'),
            transforms.RandomRotation(degrees=(-15, 15), center=(round(64), round(64)), fill=(255, 255, 255)),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize([0.7760929, 0.7760929, 0.7760929], [0.39767382, 0.39767382, 0.39767382])])
        image = train_transform(image)
        label = torch.from_numpy(np.array(self.images[item]['label']))
        return image, label

    def __len__(self):
        return len(self.images)


train_dataset = TrainData()
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)


class Residual(nn.Module):
    def __init__(self, input_channels, min_channels, num_channels,
                 use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, min_channels,
                               kernel_size=1)
        self.conv2 = nn.Conv2d(min_channels, min_channels,
                               kernel_size=3, padding=1, stride=strides)
        self.conv3 = nn.Conv2d(min_channels, num_channels,
                               kernel_size=1)
        if use_1x1conv:
            self.conv4 = nn.Conv2d(input_channels, num_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv4 = None
        self.bn1 = nn.BatchNorm2d(min_channels)
        self.bn2 = nn.BatchNorm2d(min_channels)
        self.bn3 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        Y = self.bn3(self.conv3(Y))
        if self.conv4:
            X = self.conv4(X)
        Y += X
        return F.relu(Y)

b1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
                   nn.BatchNorm2d(64), nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))


def resnet_block(input_channels, min_channels, num_channels, num_residuals, stride,
                 first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(input_channels, min_channels, num_channels,
                                use_1x1conv=True, strides=stride))
        elif first_block and i == 0:
            blk.append(Residual(input_channels, min_channels, num_channels, use_1x1conv=True))
        else:
            blk.append(Residual(num_channels, min_channels, num_channels))
    return blk


b2 = nn.Sequential(*resnet_block(64, 64, 256, 3, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(256, 128, 512, 4, 2))
b4 = nn.Sequential(*resnet_block(512, 256, 1024, 6, 2))
b5 = nn.Sequential(*resnet_block(1024, 512, 2048, 2, 2))
net = nn.Sequential(b1, b2, b3, b4, b5,
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(), nn.Linear(2048, 88899))

net = net.cuda(0)
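
# Architecture note (added for clarity): b2-b5 stack bottleneck residual
# blocks of depths [3, 4, 6, 2]. With 128x128 inputs, b1 halves the resolution
# twice (stride-2 conv + max-pool) and b3-b5 halve it three more times, so a
# 2048x4x4 feature map reaches the adaptive average pool before the
# 88,899-way classifier head.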
def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)


optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, weight_decay=0.0005, momentum=0.9)
loss = nn.CrossEntropyLoss()


def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate based on schedule"""
    lr = args.lr
    lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(y_hat, y):
    """Compute the number of correct predictions.

    Defined in :numref:`sec_softmax_scratch`"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = torch.argmax(y_hat, dim=1)
    if len(y.shape) > 1 and y.shape[1] > 1:
        y = torch.argmax(y, dim=1)
    cmp = torch.eq(y_hat, y)
    return float(torch.sum(cmp).item())

def train(net, data_loader, train_optimizer, epoch, args):
    net.train()
    adjust_learning_rate(optimizer, epoch, args)
    total_loss, total_num, trainacc, train_bar = 0.0, 0, 0.0, tqdm(data_loader)
    for image, label in train_bar:
        image, label = image.cuda(0), label.cuda(0)
        label = label.long()
        y_hat = net(image)

        train_optimizer.zero_grad()
        l = loss(y_hat, label)
        l.backward()
        train_optimizer.step()
        trainacc += accuracy(y_hat, label)
        # use the actual batch size so the last (possibly smaller) batch is counted correctly
        total_num += image.shape[0]
        total_loss += l.item() * image.shape[0]
        train_bar.set_description(
            'Train Epoch: [{}/{}], lr: {:.6f}, Loss: {:.4f}, trainacc: {:.6f}'.format(epoch, args.epochs,
                                                                                      optimizer.param_groups[0]['lr'],
                                                                                      total_loss / total_num,
                                                                                      trainacc / total_num))

    return total_loss / total_num, trainacc / total_num

def test(net, test_data_loader, epoch, args):
    net.eval()
    testacc, total_top5, total_num, test_bar = 0.0, 0.0, 0, tqdm(test_data_loader)
    with torch.no_grad():
        for image, label in test_bar:
            image, label = image.cuda(0), label.cuda(0)
            y_hat = net(image)
            total_num += image.shape[0]
            testacc += accuracy(y_hat, label)
            test_bar.set_description(
                'Test Epoch: [{}/{}], testacc: {:.6f}'.format(epoch, args.epochs, testacc / total_num))
    return testacc / total_num

results = {'train_loss': [], 'train_acc': [], 'lr': []}
epoch_start = 1
if args.resume != '':
    checkpoint = torch.load(args.resume)
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    epoch_start = checkpoint['epoch'] + 1
    print('Loaded from: {}'.format(args.resume))
else:
    net.apply(init_weights)

if not os.path.exists(args.results_dir):
    os.mkdir(args.results_dir)
with open(args.results_dir + '/args.json', 'w') as fid:
    json.dump(args.__dict__, fid, indent=2)
for epoch in range(epoch_start, args.epochs + 1):
    train_loss, train_acc = train(net, train_loader, optimizer, epoch, args)
    results['train_loss'].append(train_loss)
    results['train_acc'].append(train_acc)
    results['lr'].append(args.lr * 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)))
    data_frame = pd.DataFrame(data=results, index=range(epoch_start, epoch + 1))
    data_frame.to_csv(args.results_dir + '/log.csv', index_label='epoch')
    # save model
    torch.save({'epoch': epoch, 'state_dict': net.state_dict(), 'optimizer': optimizer.state_dict(), },
               args.results_dir + '/model_last.pth')

--------------------------------------------------------------------------------
/Validation/train.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import json
import os
import random
from datetime import datetime
import cv2
import numpy as np
import torch
import torchvision
from sklearn.metrics import f1_score
from torch.nn import functional as F
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import math
import argparse
from tqdm import tqdm
import pandas as pd


"""### Set arguments"""
parser = argparse.ArgumentParser(description='Train on HUST-OBC')

parser.add_argument('--lr', '--learning-rate', default=0.015, type=float, metavar='LR', help='initial learning rate',
                    dest='lr')
parser.add_argument('--epochs', default=600, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--batch_size', default=128, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--num_workers', default=24, type=int)
parser.add_argument('--wd', default=5e-4, type=float, metavar='W', help='weight decay')

# utils
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--results-dir', default='output', type=str, metavar='PATH', help='path to cache (default: none)')
parser.add_argument('--checkpoint_freq', type=int, default=100)
parser.add_argument('--seed', type=int, default=42)
args = parser.parse_args()  # running in command line
if args.results_dir == '':
    args.results_dir = './cache-' + datetime.now().strftime("%Y-%m-%d-%H-%M-%S-moco")
print(args)
seed = args.seed
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

class RandomGaussianBlur(object):
    def __init__(self, p=0.5, min_kernel_size=3, max_kernel_size=15, min_sigma=0.1, max_sigma=1.0):
        self.p = p
        self.min_kernel_size = min_kernel_size
        self.max_kernel_size = max_kernel_size
        self.min_sigma = min_sigma
        self.max_sigma = max_sigma

    def __call__(self, img):
        if random.random() < self.p and self.min_kernel_size < self.max_kernel_size:
            kernel_size = random.randrange(self.min_kernel_size, self.max_kernel_size + 1, 2)
            sigma = random.uniform(self.min_sigma, self.max_sigma)
            return transforms.functional.gaussian_blur(img, kernel_size, sigma)
        else:
            return img

# nohup python train_new.py > output.log 2>&1 &
def jioayan(image):
    # With probability 0.5, add salt-and-pepper noise
    if np.random.random() < 0.5:
        image1 = np.array(image)
        salt_vs_pepper_ratio = np.random.uniform(0, 0.4)
        amount = np.random.uniform(0, 0.006)
        num_salt = np.ceil(amount * image1.size / 3 * salt_vs_pepper_ratio)
        num_pepper = np.ceil(amount * image1.size / 3 * (1.0 - salt_vs_pepper_ratio))

        # Place the salt and pepper noise at random positions
        coords_salt = [np.random.randint(0, i - 1, int(num_salt)) for i in image1.shape]
        coords_pepper = [np.random.randint(0, i - 1, int(num_pepper)) for i in image1.shape]
        image1[coords_salt[0], coords_salt[1], :] = 255
        image1[coords_pepper[0], coords_pepper[1], :] = 0
        image = Image.fromarray(image1)
    return image


def pengzhang(image):
    # Draw a random number in [0, 3)
    random_value = random.random() * 3

    if random_value < 1:  # with probability 1/3, erode the image
        he = random.randint(1, 3)
        kernel = np.ones((he, he), np.uint8)
        image = cv2.erode(image, kernel, iterations=1)
    elif random_value < 2:  # with probability 1/3, dilate the image
        he = random.randint(1, 3)  # random kernel size between 1 and 3
        kernel = np.ones((he, he), np.uint8)
        image = cv2.dilate(image, kernel, iterations=1)
    return image
class TrainData(Dataset):
    def __init__(self, transform=None):
        super(TrainData, self).__init__()
        with open('Validation_train.json', 'r', encoding='utf8') as f:
            images = json.load(f)
        labels = images
        self.images, self.labels = images, labels
        self.transform = transform

    def __getitem__(self, item):
        # Load the image
        image = Image.open(self.images[item]['path'].replace('\\', '/'))
        # Convert grayscale to RGB
        if image.mode == 'L':
            image = image.convert('RGB')
        # Rescale so the longer side is 72px, keeping the aspect ratio
        image_width, image_height = image.size
        if image_width > image_height:
            x = 72
            y = round(image_height / image_width * 72)
        else:
            y = 72
            x = round(image_width / image_height * 72)
        # Sample a random target size around (x, y) (clipped to [16, 128]) and
        # a random placement inside the 128x128 canvas
        sizey, sizex = 129, 129
        if y < 128:
            while sizey > 128 or sizey < 16:
                sizey = round(random.gauss(y, 30))
        if x < 128:
            while sizex > 128 or sizex < 16:
                sizex = round(random.gauss(x, 30))
        dx = 128 - sizex  # remaining width to fill with padding
        dy = 128 - sizey
        if dx > 0:
            xl = -1
            while xl > dx or xl < 0:
                xl = round(dx / 2)
                xl = round(random.gauss(xl, 10))
        else:
            xl = 0
        if dy > 0:
            yl = -1
            while yl > dy or yl < 0:
                yl = round(dy / 2)
                yl = round(random.gauss(yl, 10))
        else:
            yl = 0
        yr = dy - yl
        xr = dx - xl
        image = jioayan(image)
        image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        image = pengzhang(image)
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        random_gaussian_blur = RandomGaussianBlur()
        image = random_gaussian_blur(image)
        train_transform = transforms.Compose([
            transforms.Resize((sizey, sizex)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Pad([xl, yl, xr, yr], fill=(255, 255, 255), padding_mode='constant'),
            transforms.RandomRotation(degrees=(-15, 15), center=(round(64), round(64)), fill=(255, 255, 255)),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize([0.85233593, 0.85246795, 0.8517555], [0.31232414, 0.3122127, 0.31273854])])
        image = train_transform(image)
        label = torch.from_numpy(np.array(self.images[item]['label']))
        return image, label

    def __len__(self):
        return len(self.images)


class ValData(Dataset):
    def __init__(self, transform=None):
        super(ValData, self).__init__()
        with open('Validation_val.json', 'r', encoding='utf8') as f:
            images = json.load(f)
        labels = images
        self.images, self.labels = images, labels
        self.transform = transform

    def __getitem__(self, item):
        # Load the image
        image = Image.open(self.images[item]['path'].replace('\\', '/'))
        # Convert grayscale to RGB
        if image.mode == 'L':
            image = image.convert('RGB')
        width, height = image.size
        if width > height:
            dy = width - height
            yl = round(dy / 2)
            yr = dy - yl
            train_transform = transforms.Compose([
                transforms.Pad([0, yl, 0, yr], fill=(255, 255, 255), padding_mode='constant'),
            ])
        else:
            dx = height - width
            xl = round(dx / 2)
            xr = dx - xl
            train_transform = transforms.Compose([
                transforms.Pad([xl, 0, xr, 0], fill=(255, 255, 255), padding_mode='constant'),
            ])

        image = train_transform(image)
        train_transform = transforms.Compose([
            transforms.Resize((128, 128)),
            # transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.85233593, 0.85246795, 0.8517555], [0.31232414, 0.3122127, 0.31273854])])
        image = train_transform(image)
        label = torch.from_numpy(np.array(self.images[item]['label']))
        return image, label

    def __len__(self):
        return len(self.images)


train_dataset = TrainData()
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)


val_dataset = ValData()
val_loader = DataLoader(val_dataset, shuffle=True, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)


net = torchvision.models.resnet50(pretrained=False)
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, 1588)
net = net.cuda(0)
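
# Note (added for clarity): the 1,588 output units correspond to the
# deciphered character classes enumerated in Validation_label.json by
# Dataset establishment.py.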
248 | 249 | Defined in :numref:`sec_softmax_scratch`""" 250 | if len(y_hat.shape) > 1 and y_hat.shape[1] > 1: 251 | y_hat = torch.argmax(y_hat, dim=1) 252 | if len(y.shape) > 1 and y.shape[1] > 1: 253 | y = torch.argmax(y, dim=1) 254 | cmp = torch.eq(y_hat, y) 255 | return float(torch.sum(cmp).item()) 256 | 257 | 258 | def train(net, data_loader, train_optimizer, epoch, args): 259 | net.train() 260 | adjust_learning_rate(optimizer, epoch, args) 261 | total_loss, total_num, trainacc, train_bar = 0.0, 0, 0.0, tqdm(data_loader,ncols=100) 262 | all_labels = [] 263 | all_preds = [] 264 | for image, label in train_bar: 265 | image, label = image.cuda(0), label.cuda(0) 266 | label = label.long() 267 | y_hat = net(image) 268 | _, preds = torch.max(y_hat, 1) 269 | all_labels.extend(label.cpu().numpy()) 270 | all_preds.extend(preds.cpu().numpy()) 271 | train_optimizer.zero_grad() 272 | l = loss(y_hat, label) 273 | l.backward() 274 | train_optimizer.step() 275 | trainacc += accuracy(y_hat, label) 276 | # total_num += data_loader.abatch_size 277 | total_num += image.shape[0] 278 | total_loss += l.item() * image.shape[0] 279 | train_bar.set_description( 280 | 'Train Epoch: [{}/{}], lr: {:.6f}, Loss: {:.4f}, trainacc: {:.6f}'.format(epoch, args.epochs, 281 | optimizer.param_groups[0]['lr'], 282 | total_loss / total_num, 283 | trainacc / total_num)) 284 | # 计算 F1 分数 285 | f1_macro = f1_score(all_labels, all_preds, average='macro') 286 | f1_micro = f1_score(all_labels, all_preds, average='micro') 287 | # print(f'Macro-averaged F1 score: {f1_macro}') 288 | # print(f'Micro-averaged F1 score: {f1_micro}') 289 | 290 | return total_loss / total_num, trainacc / total_num,f1_macro,f1_micro 291 | 292 | 293 | def val(net, val_data_loader, epoch, args): 294 | net.eval() 295 | valacc, total_top5, total_num, val_bar = 0.0, 0.0, 0, tqdm(val_data_loader) 296 | all_labels = [] 297 | all_preds = [] 298 | with torch.no_grad(): 299 | for image, label in val_bar: 300 | image, label = image.cuda(0), label.cuda(0) 301 | y_hat = net(image) 302 | _, preds = torch.max(y_hat, 1) 303 | all_labels.extend(label.cpu().numpy()) 304 | all_preds.extend(preds.cpu().numpy()) 305 | total_num+=image.shape[0] 306 | valacc += accuracy(y_hat, label) 307 | val_bar.set_description( 308 | 'Val Epoch: [{}/{}], valacc: {:.6f}'.format(epoch, args.epochs, valacc / total_num)) 309 | # 计算 F1 分数 310 | f1_macro = f1_score(all_labels, all_preds, average='macro') 311 | f1_micro = f1_score(all_labels, all_preds, average='micro') 312 | # print(f'Macro-averaged F1 score: {f1_macro}') 313 | # print(f'Micro-averaged F1 score: {f1_micro}') 314 | return valacc / total_num,f1_macro,f1_micro 315 | 316 | 317 | # results = {'train_loss': [], 'train_acc': [], 'val_acc': []} 318 | results = {'train_loss': [], 'train_acc': [],'train_f1_macro':[],'train_f1_micro':[],'val_acc': [],'val_f1_macro':[],'val_f1_micro':[], 'lr': []} 319 | epoch_start = 1 320 | if args.resume != '': 321 | checkpoint = torch.load(args.resume) 322 | net.load_state_dict(checkpoint['state_dict']) 323 | optimizer.load_state_dict(checkpoint['optimizer']) 324 | epoch_start = checkpoint['epoch'] + 1 325 | print('Loaded from: {}'.format(args.resume)) 326 | else: 327 | net.apply(init_weights) 328 | 329 | if not os.path.exists(args.results_dir): 330 | os.mkdir(args.results_dir) 331 | with open(args.results_dir + '/args.json', 'w') as fid: 332 | json.dump(args.__dict__, fid, indent=2) 333 | max_val_acc=0 334 | max_val_f1_macro=0 335 | max_val_f1_micro=0 336 | for epoch in range(epoch_start, args.epochs + 
1): 337 | train_loss, train_acc,train_f1_macro,train_f1_micro = train(net, train_loader, optimizer, epoch, args) 338 | results['train_loss'].append(train_loss) 339 | results['train_acc'].append(train_acc) 340 | results['train_f1_macro'].append(train_f1_macro) 341 | results['train_f1_micro'].append(train_f1_micro) 342 | val_acc,val_f1_macro,val_f1_micro = val(net, val_loader, epoch, args) 343 | results['val_acc'].append(val_acc) 344 | results['val_f1_macro'].append(val_f1_macro) 345 | results['val_f1_micro'].append(val_f1_micro) 346 | results['lr'].append(args.lr * 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))) 347 | # save statistics 348 | data_frame = pd.DataFrame(data=results, index=range(epoch_start, epoch + 1)) 349 | data_frame.to_csv(args.results_dir + '/log.csv', index_label='epoch') 350 | if (epoch) % args.checkpoint_freq == 0: 351 | checkpoint_name = f'checkpoint_ep{epoch:04}.pth' 352 | # save model 353 | torch.save({'epoch': epoch, 'state_dict': net.state_dict(), 'optimizer': optimizer.state_dict(), }, 354 | args.results_dir + '/'+checkpoint_name) 355 | if epoch>300 and max_val_acc300 and max_val_f1_macro300 and max_val_f1_micro 1 else nn.BatchNorm2d 146 | resnet_arch = getattr(resnet, arch) 147 | net = resnet_arch(num_classes=feature_dim, norm_layer=norm_layer) 148 | 149 | self.net = [] 150 | for name, module in net.named_children(): 151 | if name == 'conv1': 152 | module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 153 | if isinstance(module, nn.MaxPool2d): 154 | continue 155 | if isinstance(module, nn.Linear): 156 | self.net.append(nn.Flatten(1)) 157 | self.net.append(module) 158 | 159 | self.net = nn.Sequential(*self.net) 160 | 161 | def forward(self, x): 162 | x = self.net(x) 163 | # note: not normalized here 164 | return x 165 | 166 | """### Define MoCo wrapper""" 167 | 168 | class ModelMoCo(nn.Module): 169 | def __init__(self, dim=128, K=4096, m=0.99, T=0.1, arch='resnet18', bn_splits=8, symmetric=True): 170 | super(ModelMoCo, self).__init__() 171 | 172 | self.K = K 173 | self.m = m 174 | self.T = T 175 | self.symmetric = symmetric 176 | 177 | # create the encoders 178 | self.encoder_q = ModelBase(feature_dim=dim, arch=arch, bn_splits=bn_splits) 179 | self.encoder_k = ModelBase(feature_dim=dim, arch=arch, bn_splits=bn_splits) 180 | 181 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 182 | param_k.data.copy_(param_q.data) # initialize 183 | param_k.requires_grad = False # not update by gradient 184 | 185 | # create the queue 186 | self.register_buffer("queue", torch.randn(dim, K)) 187 | self.queue = nn.functional.normalize(self.queue, dim=0) 188 | 189 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 190 | 191 | @torch.no_grad() 192 | def _momentum_update_key_encoder(self): 193 | """ 194 | Momentum update of the key encoder 195 | """ 196 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 197 | param_k.data = param_k.data * self.m + param_q.data * (1. 
198 | 199 | @torch.no_grad() 200 | def _dequeue_and_enqueue(self, keys): 201 | batch_size = keys.shape[0] 202 | 203 | ptr = int(self.queue_ptr) 204 | assert self.K % batch_size == 0 # for simplicity 205 | 206 | # replace the keys at ptr (dequeue and enqueue) 207 | self.queue[:, ptr:ptr + batch_size] = keys.t() # transpose 208 | ptr = (ptr + batch_size) % self.K # move pointer 209 | 210 | self.queue_ptr[0] = ptr 211 | 212 | @torch.no_grad() 213 | def _batch_shuffle_single_gpu(self, x): 214 | """ 215 | Batch shuffle, for making use of BatchNorm. 216 | """ 217 | # random shuffle index 218 | idx_shuffle = torch.randperm(x.shape[0]).cuda() 219 | 220 | # index for restoring 221 | idx_unshuffle = torch.argsort(idx_shuffle) 222 | 223 | return x[idx_shuffle], idx_unshuffle 224 | 225 | @torch.no_grad() 226 | def _batch_unshuffle_single_gpu(self, x, idx_unshuffle): 227 | """ 228 | Undo batch shuffle. 229 | """ 230 | return x[idx_unshuffle] 231 | 232 | def contrastive_loss(self, im_q, im_k): 233 | # compute query features 234 | q = self.encoder_q(im_q) # queries: NxC 235 | q = nn.functional.normalize(q, dim=1) # already normalized 236 | 237 | # compute key features 238 | with torch.no_grad(): # no gradient to keys 239 | # shuffle for making use of BN 240 | im_k_, idx_unshuffle = self._batch_shuffle_single_gpu(im_k) 241 | 242 | k = self.encoder_k(im_k_) # keys: NxC 243 | k = nn.functional.normalize(k, dim=1) # already normalized 244 | 245 | # undo shuffle 246 | k = self._batch_unshuffle_single_gpu(k, idx_unshuffle) 247 | 248 | # compute logits 249 | # Einstein sum is more intuitive 250 | # positive logits: Nx1 251 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) 252 | # negative logits: NxK 253 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) 254 | 255 | # logits: Nx(1+K) 256 | logits = torch.cat([l_pos, l_neg], dim=1) 257 | 258 | # apply temperature 259 | logits /= self.T 260 | 261 | # labels: positive key indicators 262 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() 263 | 264 | loss = nn.CrossEntropyLoss().cuda()(logits, labels) 265 | 266 | return loss, q, k 267 | 268 | def forward(self, im1, im2): 269 | """ 270 | Input: 271 | im_q: a batch of query images 272 | im_k: a batch of key images 273 | Output: 274 | loss 275 | """ 276 | 277 | # update the key encoder 278 | with torch.no_grad(): # no gradient to keys 279 | self._momentum_update_key_encoder() 280 | 281 | # compute loss 282 | if self.symmetric: # symmetric loss 283 | loss_12, q1, k2 = self.contrastive_loss(im1, im2) 284 | loss_21, q2, k1 = self.contrastive_loss(im2, im1) 285 | loss = loss_12 + loss_21 286 | k = torch.cat([k1, k2], dim=0) 287 | else: # asymmetric loss 288 | loss, q, k = self.contrastive_loss(im1, im2) 289 | 290 | self._dequeue_and_enqueue(k) 291 | 292 | return loss 293 | 294 | # create model 295 | model = ModelMoCo( 296 | dim=args.moco_dim, 297 | K=args.moco_k, 298 | m=args.moco_m, 299 | T=args.moco_t, 300 | arch=args.arch, 301 | bn_splits=args.bn_splits, 302 | symmetric=args.symmetric, 303 | ).cuda() 304 | print(model.encoder_q) 305 | 306 | """### Define train/test 307 | 308 | 309 | """ 310 | 311 | # train for one epoch 312 | def train(net, data_loader, train_optimizer, epoch, args, trainlist): 313 | net.train() 314 | adjust_learning_rate(train_optimizer, epoch, args, trainlist) 315 | 316 | total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader) 317 | for im_1, im_2 in train_bar: 318 | 319 | im_1, im_2 = im_1.cuda(non_blocking=True), im_2.cuda(non_blocking=True) 320 | 321 |
loss = net(im_1, im_2) 322 | 323 | train_optimizer.zero_grad() 324 | loss.backward() 325 | train_optimizer.step() 326 | 327 | total_num += data_loader.batch_size 328 | total_loss += loss.item() * data_loader.batch_size 329 | train_bar.set_description('Train Epoch: [{}/{}], lr: {:.6f}, Loss: {:.4f}'.format(epoch, args.epochs, train_optimizer.param_groups[0]['lr'], total_loss / total_num)) 330 | 331 | return total_loss / total_num 332 | 333 | # lr scheduler for training 334 | def adjust_learning_rate(optimizer, epoch, args, trainlist): 335 | """Decay the learning rate based on schedule""" 336 | lr = args.lr 337 | if args.cos: # cosine lr schedule 338 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) 339 | else: # stepwise lr schedule 340 | for milestone in args.schedule: 341 | lr *= 0.1 if epoch >= milestone else 1. 342 | for param_group in optimizer.param_groups: 343 | param_group['lr'] = lr 344 | # test using a knn monitor 345 | def test(net, memory_data_loader, args): 346 | net.eval() 347 | total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, [] 348 | target_bank = [] 349 | pathbank = [] 350 | with torch.no_grad(): 351 | # generate feature bank 352 | for data, target, path in tqdm(memory_data_loader, desc='Feature extracting'): 353 | feature = net(data.cuda(non_blocking=True)) 354 | feature = F.normalize(feature, dim=1) 355 | feature_bank.append(feature) 356 | target = list(target) 357 | path = list(path) 358 | pathbank = pathbank + path 359 | target_bank = target_bank + target 360 | # [D, N] 361 | feature_bank = torch.cat(feature_bank, dim=0).t().contiguous() 362 | 363 | 364 | knn_predict(feature_bank, target_bank, pathbank, args.knn_k, args.knn_t) # writes result.json; returns nothing 365 | 366 | 367 | return 368 | 369 | 370 | def knn_predict(feature_bank, target_bank, pathbank, knn_k, knn_t): # knn_k is unused here; neighbors are scanned until the weight drops below args.w 371 | dataset = [] 372 | for ii in tqdm(range(feature_bank.size(1))): 373 | baselabel = target_bank[ii] 374 | basepath = pathbank[ii] 375 | temp = feature_bank.permute(1, 0)[ii] 376 | temp = temp.unsqueeze(0) 377 | sim_matrix = torch.mm(temp, feature_bank) 378 | sim_weight, sim_indices = sim_matrix.topk(k=feature_bank.size(1), dim=-1) 379 | sim_weight = (sim_weight / knn_t).exp() 380 | sim_weight = sim_weight / float(math.e ** 10) # equivalent to exp(sim / knn_t - 10) 381 | for nnn in range(feature_bank.size(1)): 382 | index = int(sim_indices[0][nnn]) 383 | label = target_bank[index] 384 | w = float(sim_weight[0][nnn]) 385 | if w <= args.w: 386 | break 387 | if label != baselabel: 388 | 389 | path = pathbank[index] 390 | if w > args.w: 391 | 392 | data = {} 393 | data['path1'] = basepath 394 | data['label1'] = str(int(baselabel.item())).zfill(4) 395 | data['path2'] = path 396 | data['label2'] = str(int(label.item())).zfill(4) 397 | data['w'] = w 398 | print(data) 399 | dataset.append(copy.deepcopy(data)) 400 | break 401 | with open('result.json', 'w', encoding='utf-8') as f: 402 | json.dump(dataset, f, ensure_ascii=False) 403 | 404 | """### Start training""" 405 | 406 | # define optimizer 407 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd, momentum=0.9) 408 | 409 | # load model if resume 410 | epoch_start = 1 411 | if args.resume != '': 412 | checkpoint = torch.load(args.resume) 413 | model.load_state_dict(checkpoint['state_dict']) 414 | optimizer.load_state_dict(checkpoint['optimizer']) 415 | epoch_start = checkpoint['epoch'] + 1 416 | print('Loaded from: {}'.format(args.resume)) 417 | 418 | # logging 419 | results = {'train_loss': []} 420 | # training loop 421 | trainlist = [] 422 | test(model.encoder_q, train_loader, args) # mine cross-label duplicate candidates into result.json
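A minimal sketch (not part of the repository) of one way to consume the result.json written by the script above; the pair-counting heuristic is an assumption, not the authors' merging procedure:

import json
from collections import Counter

# each entry links two images from different label folders whose feature
# similarity stayed above the w threshold (keys: path1/label1/path2/label2/w)
with open('result.json', encoding='utf-8') as f:
    pairs = json.load(f)

# hypothetical first pass: rank label pairs by how often they are flagged,
# to prioritize manual review of categories that may be the same character
votes = Counter(tuple(sorted((p['label1'], p['label2']))) for p in pairs)
for (a, b), n in votes.most_common(10):
    print(a, b, n)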
-------------------------------------------------------------------------------- /Validation/Validation_label.json: -------------------------------------------------------------------------------- 1 | {"0001": 0, "0002": 1, "0003": 2, "0004": 3, "0005": 4, "0006": 5, "0007": 6, "0008": 7, "0009": 8, "0010": 9, "0011_0012_0013": 10, "0014": 11, "0015_0016": 12, "0017": 13, "0018": 14, "0019": 15, "0020": 16, "0021": 17, "0022": 18, "0023_0024": 19, "0025_0026": 20, "0027": 21, "0028": 22, "0029": 23, "0030": 24, "0031": 25, "0032": 26, "0033": 27, "0034": 28, "0035": 29, "0036": 30, "0037": 31, "0038": 32, "0039": 33, "0040": 34, "0041_0042": 35, "0043": 36, "0044": 37, "0045": 38, "0046": 39, "0047": 40, "0048_0049": 41, "0050": 42, "0051": 43, "0052_0053": 44, "0054": 45, "0055": 46, "0056": 47, "0057": 48, "0058": 49, "0059": 50, "0060": 51, "0061": 52, "0062_0063": 53, "0064": 54, "0065": 55, "0066": 56, "0067": 57, "0068": 58, "0069": 59, "0070_0071": 60, "0072": 61, "0073": 62, "0074": 63, "0075": 64, "0076": 65, "0077": 66, "0078": 67, "0079": 68, "0080": 69, "0081": 70, "0082": 71, "0083": 72, "0084": 73, "0085": 74, "0086": 75, "0087": 76, "0088": 77, "0089": 78, "0090": 79, "0091": 80, "0092": 81, "0093": 82, "0094": 83, "0095": 84, "0096": 85, "0097": 86, "0098": 87, "0099": 88, "0100": 89, "0101": 90, "0102": 91, "0103": 92, "0104": 93, "0105": 94, "0106_0107": 95, "0108": 96, "0109": 97, "0110": 98, "0111": 99, "0112": 100, "0113_0114": 101, "0115": 102, "0116": 103, "0117_0118": 104, "0119_0120": 105, "0121": 106, "0122": 107, "0123": 108, "0124": 109, "0125": 110, "0126": 111, "0127": 112, "0128": 113, "0129": 114, "0130": 115, "0131": 116, "0132": 117, "0133": 118, "0134": 119, "0135": 120, "0136": 121, "0137": 122, "0138": 123, "0139": 124, "0140": 125, "0141": 126, "0142": 127, "0143": 128, "0144": 129, "0145_0146": 130, "0147": 131, "0148": 132, "0149": 133, "0150": 134, "0151": 135, "0152": 136, "0153": 137, "0154": 138, "0155": 139, "0156": 140, "0157": 141, "0158": 142, "0159": 143, "0160": 144, "0161_0162": 145, "0163": 146, "0164": 147, "0165": 148, "0166": 149, "0167": 150, "0168": 151, "0169": 152, "0170": 153, "0171": 154, "0172": 155, "0173": 156, "0174": 157, "0175": 158, "0176": 159, "0177": 160, "0178": 161, "0179": 162, "0180": 163, "0181": 164, "0182": 165, "0183": 166, "0184": 167, "0185": 168, "0186": 169, "0187": 170, "0188": 171, "0189_0190": 172, "0191": 173, "0192": 174, "0193_0194": 175, "0195": 176, "0196_0197": 177, "0198": 178, "0199": 179, "0200": 180, "0201": 181, "0202": 182, "0203": 183, "0204": 184, "0205_0206": 185, "0207": 186, "0208": 187, "0209": 188, "0210_0211": 189, "0212_0213": 190, "0214": 191, "0215": 192, "0216": 193, "0217": 194, "0218": 195, "0219": 196, "0220_0221": 197, "0222": 198, "0223": 199, "0224": 200, "0225": 201, "0226": 202, "0227": 203, "0228": 204, "0229": 205, "0230": 206, "0231": 207, "0232": 208, "0233": 209, "0234": 210, "0235": 211, "0236_0237": 212, "0238": 213, "0239": 214, "0240": 215, "0241": 216, "0242": 217, "0243": 218, "0244": 219, "0245": 220, "0246_0247": 221, "0248": 222, "0249_0250": 223, "0251": 224, "0252": 225, "0253": 226, "0254": 227, "0255": 228, "0256": 229, "0257": 230, "0258_0259": 231, "0260": 232, "0261": 233, "0262": 234, "0263": 235, "0264": 236, "0265": 237, "0266": 238, "0267": 239, "0268": 240, "0269": 241, "0270_0271": 242, "0272": 243, "0273": 244, "0274": 245, "0275": 246, "0276": 247, "0277": 248, "0278": 249, "0279": 250, "0280": 251, "0281_0282_0283": 252, "0284": 253, "0285": 
254, "0286_0287_0288": 255, "0289": 256, "0290": 257, "0291": 258, "0292": 259, "0293": 260, "0294": 261, "0295": 262, "0296": 263, "0297_0298": 264, "0299": 265, "0300": 266, "0301": 267, "0302_0303": 268, "0304_0305": 269, "0306": 270, "0307": 271, "0308": 272, "0309": 273, "0310": 274, "0311": 275, "0312_0313": 276, "0314": 277, "0315": 278, "0316": 279, "0317": 280, "0318": 281, "0319": 282, "0320": 283, "0321_0322_0323": 284, "0324": 285, "0325": 286, "0326": 287, "0327_0328": 288, "0329": 289, "0330": 290, "0331": 291, "0332": 292, "0333": 293, "0334": 294, "0335": 295, "0336_0337": 296, "0338": 297, "0339": 298, "0340": 299, "0341": 300, "0342": 301, "0343": 302, "0344": 303, "0345": 304, "0346": 305, "0347": 306, "0348": 307, "0349": 308, "0350": 309, "0351": 310, "0352": 311, "0353": 312, "0354": 313, "0355": 314, "0356": 315, "0357": 316, "0358": 317, "0359": 318, "0360_0361": 319, "0362_0363": 320, "0364": 321, "0365": 322, "0366": 323, "0367": 324, "0368": 325, "0369": 326, "0370": 327, "0371": 328, "0372": 329, "0373": 330, "0374": 331, "0375_0376": 332, "0377": 333, "0378": 334, "0379": 335, "0380": 336, "0381": 337, "0382_0383": 338, "0384": 339, "0385": 340, "0386_0387": 341, "0388": 342, "0389": 343, "0390": 344, "0391": 345, "0392": 346, "0393": 347, "0394_0395": 348, "0396_0397": 349, "0398": 350, "0399": 351, "0400": 352, "0401_0402": 353, "0403": 354, "0404": 355, "0405": 356, "0406": 357, "0407_0408": 358, "0409_0410": 359, "0411": 360, "0412": 361, "0413": 362, "0414": 363, "0415": 364, "0416": 365, "0417": 366, "0418": 367, "0419": 368, "0420": 369, "0421_0422": 370, "0423": 371, "0424": 372, "0425_0426": 373, "0427": 374, "0428": 375, "0429": 376, "0430": 377, "0431": 378, "0432": 379, "0433": 380, "0434": 381, "0435": 382, "0436": 383, "0437": 384, "0438": 385, "0439": 386, "0440": 387, "0441": 388, "0442": 389, "0443": 390, "0444": 391, "0445_0446": 392, "0447": 393, "0448": 394, "0449": 395, "0450": 396, "0451": 397, "0452": 398, "0453": 399, "0454": 400, "0455_0456": 401, "0457": 402, "0458_0459": 403, "0460": 404, "0461_0462": 405, "0463": 406, "0464": 407, "0465": 408, "0466": 409, "0467": 410, "0468": 411, "0469": 412, "0470": 413, "0471": 414, "0472": 415, "0473": 416, "0474": 417, "0475": 418, "0476": 419, "0477": 420, "0478": 421, "0479_0480": 422, "0481": 423, "0482": 424, "0483_0484": 425, "0485": 426, "0486": 427, "0487": 428, "0488": 429, "0489_0490_0491": 430, "0492": 431, "0493": 432, "0494": 433, "0495": 434, "0496": 435, "0497": 436, "0498": 437, "0499": 438, "0500": 439, "0501": 440, "0502": 441, "0503": 442, "0504": 443, "0505": 444, "0506": 445, "0507": 446, "0508": 447, "0509": 448, "0510": 449, "0511": 450, "0512": 451, "0513": 452, "0514": 453, "0515": 454, "0516_0517": 455, "0518_0519": 456, "0520": 457, "0521": 458, "0522": 459, "0523": 460, "0524": 461, "0525": 462, "0526": 463, "0527": 464, "0528": 465, "0529": 466, "0530": 467, "0531": 468, "0532": 469, "0533": 470, "0534": 471, "0535_0536": 472, "0537": 473, "0538": 474, "0539": 475, "0540_0541": 476, "0542": 477, "0543": 478, "0544_0545": 479, "0546": 480, "0547": 481, "0548": 482, "0549": 483, "0550": 484, "0551": 485, "0552": 486, "0553": 487, "0554": 488, "0555": 489, "0556": 490, "0557": 491, "0558_0559_0560": 492, "0561": 493, "0562_0563_0564": 494, "0565": 495, "0566": 496, "0567_0568": 497, "0569_0570": 498, "0571": 499, "0572": 500, "0573": 501, "0574": 502, "0575": 503, "0576": 504, "0577": 505, "0578": 506, "0579": 507, "0580": 508, "0581": 509, "0582": 510, "0583": 511, 
"0584": 512, "0585": 513, "0586": 514, "0587_0588_0589": 515, "0590": 516, "0591": 517, "0592": 518, "0593": 519, "0594": 520, "0595": 521, "0596": 522, "0597_0598": 523, "0599": 524, "0600": 525, "0601": 526, "0602": 527, "0603": 528, "0604": 529, "0605": 530, "0606": 531, "0607": 532, "0608": 533, "0609": 534, "0610": 535, "0611": 536, "0612": 537, "0613": 538, "0614": 539, "0615": 540, "0616": 541, "0617": 542, "0618": 543, "0619": 544, "0620": 545, "0621": 546, "0622": 547, "0623": 548, "0624": 549, "0625": 550, "0626": 551, "0627": 552, "0628": 553, "0629": 554, "0630_0631": 555, "0632": 556, "0633": 557, "0634": 558, "0635": 559, "0636": 560, "0637": 561, "0638_0639": 562, "0640": 563, "0641": 564, "0642": 565, "0643": 566, "0644": 567, "0645": 568, "0646": 569, "0647": 570, "0648": 571, "0649": 572, "0650": 573, "0651": 574, "0652": 575, "0653_0654": 576, "0655_0656": 577, "0657": 578, "0658": 579, "0659": 580, "0660": 581, "0661_0662": 582, "0663": 583, "0664": 584, "0665": 585, "0666": 586, "0667": 587, "0668": 588, "0669": 589, "0670_0671": 590, "0672_0673_0674": 591, "0675": 592, "0676": 593, "0677": 594, "0678_0679": 595, "0680_0681": 596, "0682": 597, "0683": 598, "0684": 599, "0685": 600, "0686": 601, "0687": 602, "0688": 603, "0689": 604, "0690_0691": 605, "0692": 606, "0693": 607, "0694": 608, "0695": 609, "0696": 610, "0697": 611, "0698": 612, "0699": 613, "0700": 614, "0701": 615, "0702": 616, "0703": 617, "0704": 618, "0705": 619, "0706": 620, "0707": 621, "0708": 622, "0709": 623, "0710": 624, "0711": 625, "0712": 626, "0713": 627, "0714": 628, "0715_0716": 629, "0717": 630, "0718_0719": 631, "0720": 632, "0721": 633, "0722": 634, "0723": 635, "0724": 636, "0725": 637, "0726_0727": 638, "0728": 639, "0729": 640, "0730": 641, "0731": 642, "0732": 643, "0733": 644, "0734": 645, "0735": 646, "0736": 647, "0737": 648, "0738": 649, "0739": 650, "0740": 651, "0741": 652, "0742": 653, "0743": 654, "0744": 655, "0745": 656, "0746": 657, "0747": 658, "0748": 659, "0749": 660, "0750_0751": 661, "0752_0753": 662, "0754": 663, "0755": 664, "0756": 665, "0757": 666, "0758": 667, "0759": 668, "0760": 669, "0761_0762": 670, "0763": 671, "0764": 672, "0765": 673, "0766": 674, "0767": 675, "0768": 676, "0769": 677, "0770": 678, "0771": 679, "0772": 680, "0773": 681, "0774": 682, "0775": 683, "0776": 684, "0777": 685, "0778": 686, "0779": 687, "0780": 688, "0781": 689, "0782": 690, "0783": 691, "0784": 692, "0785_0786": 693, "0787_0788_0789": 694, "0790": 695, "0791": 696, "0792": 697, "0793": 698, "0794": 699, "0795": 700, "0796": 701, "0797": 702, "0798": 703, "0799": 704, "0800_0801": 705, "0802": 706, "0803": 707, "0804": 708, "0805": 709, "0806": 710, "0807_0808": 711, "0809": 712, "0810": 713, "0811_0812": 714, "0813": 715, "0814": 716, "0815_0816": 717, "0817": 718, "0818": 719, "0819": 720, "0820": 721, "0821": 722, "0822": 723, "0823_0824": 724, "0825_0826": 725, "0827_0828": 726, "0829": 727, "0830": 728, "0831": 729, "0832": 730, "0833": 731, "0834": 732, "0835_0836": 733, "0837": 734, "0838": 735, "0839": 736, "0840": 737, "0841": 738, "0842": 739, "0843": 740, "0844": 741, "0845": 742, "0846": 743, "0847": 744, "0848": 745, "0849": 746, "0850": 747, "0851": 748, "0852": 749, "0853": 750, "0854": 751, "0855": 752, "0856": 753, "0857": 754, "0858": 755, "0859": 756, "0860": 757, "0861_0862": 758, "0863": 759, "0864": 760, "0865": 761, "0866": 762, "0867": 763, "0868": 764, "0869": 765, "0870_0871": 766, "0872": 767, "0873": 768, "0874": 769, "0875": 770, "0876": 771, "0877": 
772, "0878": 773, "0879": 774, "0880": 775, "0881": 776, "0882": 777, "0883": 778, "0884": 779, "0885": 780, "0886": 781, "0887": 782, "0888": 783, "0889": 784, "0890": 785, "0891": 786, "0892": 787, "0893": 788, "0894": 789, "0895": 790, "0896": 791, "0897": 792, "0898": 793, "0899": 794, "0900_0901": 795, "0902_0903": 796, "0904": 797, "0905": 798, "0906": 799, "0907": 800, "0908_0909": 801, "0910": 802, "0911": 803, "0912": 804, "0913": 805, "0914": 806, "0915": 807, "0916": 808, "0917_0918": 809, "0919": 810, "0920": 811, "0921": 812, "0922": 813, "0923": 814, "0924": 815, "0925": 816, "0926": 817, "0927": 818, "0928": 819, "0929": 820, "0930": 821, "0931": 822, "0932": 823, "0933": 824, "0934": 825, "0935": 826, "0936": 827, "0937": 828, "0938": 829, "0939_0940": 830, "0941": 831, "0942": 832, "0943": 833, "0944": 834, "0945": 835, "0946": 836, "0947": 837, "0948": 838, "0949": 839, "0950": 840, "0951": 841, "0952": 842, "0953": 843, "0954": 844, "0955": 845, "0956": 846, "0957": 847, "0958": 848, "0959": 849, "0960": 850, "0961": 851, "0962": 852, "0963": 853, "0964": 854, "0965": 855, "0966": 856, "0967": 857, "0968_0969": 858, "0970": 859, "0971": 860, "0972": 861, "0973": 862, "0974": 863, "0975": 864, "0976_0977": 865, "0978": 866, "0979": 867, "0980": 868, "0981": 869, "0982": 870, "0983": 871, "0984_0985": 872, "0986": 873, "0987": 874, "0988": 875, "0989": 876, "0990": 877, "0991": 878, "0992": 879, "0993": 880, "0994": 881, "0995": 882, "0996": 883, "0997": 884, "0998": 885, "0999": 886, "1000": 887, "1001": 888, "1002": 889, "1003": 890, "1004": 891, "1005": 892, "1006": 893, "1007": 894, "1008_1009": 895, "1010_1011": 896, "1012": 897, "1013": 898, "1014": 899, "1015": 900, "1016": 901, "1017": 902, "1018_1019": 903, "1020": 904, "1021": 905, "1022": 906, "1023": 907, "1024": 908, "1025": 909, "1026": 910, "1027": 911, "1028": 912, "1029_1030": 913, "1031": 914, "1032": 915, "1033": 916, "1034_1035": 917, "1036": 918, "1037": 919, "1038": 920, "1039": 921, "1040": 922, "1041": 923, "1042": 924, "1043": 925, "1044": 926, "1045": 927, "1046": 928, "1047": 929, "1048": 930, "1049": 931, "1050": 932, "1051": 933, "1052": 934, "1053": 935, "1054": 936, "1055": 937, "1056": 938, "1057": 939, "1058": 940, "1059": 941, "1060": 942, "1061": 943, "1062": 944, "1063": 945, "1064": 946, "1065_1066": 947, "1067": 948, "1068": 949, "1069": 950, "1070": 951, "1071_1072": 952, "1073": 953, "1074": 954, "1075": 955, "1076": 956, "1077_1078": 957, "1079": 958, "1080": 959, "1081": 960, "1082": 961, "1083": 962, "1084": 963, "1085": 964, "1086": 965, "1087": 966, "1088": 967, "1089": 968, "1090": 969, "1091": 970, "1092": 971, "1093": 972, "1094": 973, "1095": 974, "1096": 975, "1097": 976, "1098": 977, "1099": 978, "1100": 979, "1101": 980, "1102_1103": 981, "1104": 982, "1105": 983, "1106_1107": 984, "1108": 985, "1109": 986, "1110": 987, "1111": 988, "1112": 989, "1113": 990, "1114": 991, "1115_1116": 992, "1117": 993, "1118": 994, "1119": 995, "1120_1121": 996, "1122": 997, "1123": 998, "1124": 999, "1125": 1000, "1126": 1001, "1127": 1002, "1128": 1003, "1129": 1004, "1130": 1005, "1131": 1006, "1132": 1007, "1133_1134": 1008, "1135": 1009, "1136": 1010, "1137": 1011, "1138": 1012, "1139": 1013, "1140": 1014, "1141": 1015, "1142": 1016, "1143": 1017, "1144": 1018, "1145": 1019, "1146": 1020, "1147": 1021, "1148": 1022, "1149": 1023, "1150": 1024, "1151": 1025, "1152": 1026, "1153": 1027, "1154": 1028, "1155": 1029, "1156": 1030, "1157": 1031, "1158": 1032, "1159": 1033, "1160": 1034, 
"1161": 1035, "1162": 1036, "1163_1164": 1037, "1165": 1038, "1166": 1039, "1167": 1040, "1168": 1041, "1169": 1042, "1170": 1043, "1171_1172": 1044, "1173": 1045, "1174": 1046, "1175": 1047, "1176": 1048, "1177": 1049, "1178": 1050, "1179": 1051, "1180": 1052, "1181": 1053, "1182": 1054, "1183": 1055, "1184": 1056, "1185": 1057, "1186": 1058, "1187": 1059, "1188": 1060, "1189": 1061, "1190": 1062, "1191": 1063, "1192": 1064, "1193": 1065, "1194": 1066, "1195": 1067, "1196": 1068, "1197": 1069, "1198": 1070, "1199": 1071, "1200": 1072, "1201": 1073, "1202": 1074, "1203": 1075, "1204_1205": 1076, "1206": 1077, "1207_1208": 1078, "1209": 1079, "1210": 1080, "1211": 1081, "1212": 1082, "1213": 1083, "1214": 1084, "1215": 1085, "1216": 1086, "1217": 1087, "1218": 1088, "1219": 1089, "1220": 1090, "1221": 1091, "1222": 1092, "1223": 1093, "1224": 1094, "1225": 1095, "1226": 1096, "1227": 1097, "1228": 1098, "1229": 1099, "1230": 1100, "1231_1232": 1101, "1233": 1102, "1234": 1103, "1235": 1104, "1236": 1105, "1237": 1106, "1238": 1107, "1239_1240": 1108, "1241": 1109, "1242": 1110, "1243": 1111, "1244": 1112, "1245": 1113, "1246": 1114, "1247": 1115, "1248": 1116, "1249": 1117, "1250": 1118, "1251": 1119, "1252": 1120, "1253": 1121, "1254": 1122, "1255": 1123, "1256": 1124, "1257": 1125, "1258": 1126, "1259": 1127, "1260": 1128, "1261_1262_1263": 1129, "1264": 1130, "1265": 1131, "1266": 1132, "1267": 1133, "1268": 1134, "1269": 1135, "1270": 1136, "1271_1272": 1137, "1273": 1138, "1274": 1139, "1275": 1140, "1276": 1141, "1277": 1142, "1278": 1143, "1279": 1144, "1280_1281": 1145, "1282": 1146, "1283": 1147, "1284_1285_1286": 1148, "1287": 1149, "1288": 1150, "1289": 1151, "1290_1291_1292": 1152, "1293": 1153, "1294": 1154, "1295": 1155, "1296": 1156, "1297": 1157, "1298": 1158, "1299": 1159, "1300": 1160, "1301": 1161, "1302": 1162, "1303": 1163, "1304_1305": 1164, "1306": 1165, "1307": 1166, "1308": 1167, "1309": 1168, "1310": 1169, "1311": 1170, "1312": 1171, "1313": 1172, "1314": 1173, "1315": 1174, "1316": 1175, "1317": 1176, "1318": 1177, "1319": 1178, "1320": 1179, "1321": 1180, "1322": 1181, "1323": 1182, "1324": 1183, "1325": 1184, "1326": 1185, "1327_1328": 1186, "1329": 1187, "1330": 1188, "1331": 1189, "1332": 1190, "1333": 1191, "1334": 1192, "1335": 1193, "1336": 1194, "1337": 1195, "1338": 1196, "1339": 1197, "1340": 1198, "1341_1342": 1199, "1343": 1200, "1344_1345": 1201, "1346": 1202, "1347_1348": 1203, "1349_1350": 1204, "1351": 1205, "1352": 1206, "1353": 1207, "1354": 1208, "1355": 1209, "1356": 1210, "1357": 1211, "1358": 1212, "1359": 1213, "1360": 1214, "1361": 1215, "1362": 1216, "1363": 1217, "1364": 1218, "1365": 1219, "1366": 1220, "1367": 1221, "1368": 1222, "1369": 1223, "1370": 1224, "1371": 1225, "1372": 1226, "1373": 1227, "1374_1375": 1228, "1376": 1229, "1377": 1230, "1378_1379": 1231, "1380": 1232, "1381": 1233, "1382": 1234, "1383": 1235, "1384": 1236, "1385": 1237, "1386": 1238, "1387": 1239, "1388": 1240, "1389_1390": 1241, "1391": 1242, "1392": 1243, "1393": 1244, "1394": 1245, "1395": 1246, "1396": 1247, "1397": 1248, "1398": 1249, "1399": 1250, "1400": 1251, "1401": 1252, "1402": 1253, "1403": 1254, "1404": 1255, "1405": 1256, "1406_1407": 1257, "1408": 1258, "1409": 1259, "1410": 1260, "1411": 1261, "1412": 1262, "1413": 1263, "1414": 1264, "1415_1416": 1265, "1417": 1266, "1418": 1267, "1419": 1268, "1420": 1269, "1421_1422": 1270, "1423": 1271, "1424": 1272, "1425": 1273, "1426": 1274, "1427": 1275, "1428": 1276, "1429_1430": 1277, "1431": 1278, 
"1432": 1279, "1433": 1280, "1434": 1281, "1435": 1282, "1436": 1283, "1437": 1284, "1438": 1285, "1439": 1286, "1440": 1287, "1441": 1288, "1442": 1289, "1443": 1290, "1444": 1291, "1445": 1292, "1446_1447": 1293, "1448": 1294, "1449": 1295, "1450": 1296, "1451": 1297, "1452": 1298, "1453": 1299, "1454": 1300, "1455": 1301, "1456": 1302, "1457": 1303, "1458": 1304, "1459": 1305, "1460": 1306, "1461": 1307, "1462": 1308, "1463": 1309, "1464_1465_1466": 1310, "1467": 1311, "1468": 1312, "1469": 1313, "1470": 1314, "1471_1472": 1315, "1473": 1316, "1474": 1317, "1475": 1318, "1476": 1319, "1477": 1320, "1478": 1321, "1479": 1322, "1480": 1323, "1481": 1324, "1482": 1325, "1483": 1326, "1484": 1327, "1485": 1328, "1486": 1329, "1487": 1330, "1488": 1331, "1489": 1332, "1490": 1333, "1491": 1334, "1492": 1335, "1493": 1336, "1494": 1337, "1495": 1338, "1496": 1339, "1497": 1340, "1498": 1341, "1499_1500": 1342, "1501": 1343, "1502": 1344, "1503": 1345, "1504": 1346, "1505": 1347, "1506": 1348, "1507": 1349, "1508": 1350, "1509": 1351, "1510": 1352, "1511": 1353, "1512": 1354, "1513": 1355, "1514": 1356, "1515": 1357, "1516": 1358, "1517": 1359, "1518": 1360, "1519": 1361, "1520_1521": 1362, "1522": 1363, "1523": 1364, "1524": 1365, "1525_1526": 1366, "1527": 1367, "1528": 1368, "1529": 1369, "1530": 1370, "1531": 1371, "1532": 1372, "1533_1534_1535": 1373, "1536": 1374, "1537": 1375, "1538": 1376, "1539": 1377, "1540": 1378, "1541_1542": 1379, "1543": 1380, "1544": 1381, "1545_1546": 1382, "1547": 1383, "1548": 1384, "1549_1550": 1385, "1551": 1386, "1552_1553": 1387, "1554": 1388, "1555": 1389, "1556": 1390, "1557": 1391, "1558": 1392, "1559": 1393, "1560": 1394, "1561": 1395, "1562_1563": 1396, "1564_1565": 1397, "1566_1567": 1398, "1568": 1399, "1569": 1400, "1570": 1401, "1571_1572": 1402, "1573": 1403, "1574": 1404, "1575": 1405, "1576": 1406, "1577": 1407, "1578": 1408, "1579": 1409, "1580": 1410, "1581": 1411, "1582": 1412, "1583": 1413, "1584": 1414, "1585": 1415, "1586": 1416, "1587": 1417, "1588": 1418, "1589_1590_1591": 1419, "1592": 1420, "1593": 1421, "1594_1595": 1422, "1596": 1423, "1597": 1424, "1598": 1425, "1599": 1426, "1600": 1427, "1601": 1428, "1602": 1429, "1603": 1430, "1604": 1431, "1605_1606": 1432, "1607": 1433, "1608": 1434, "1609": 1435, "1610": 1436, "1611": 1437, "1612": 1438, "1613": 1439, "1614_1615": 1440, "1616": 1441, "1617": 1442, "1618_1619": 1443, "1620": 1444, "1621": 1445, "1622": 1446, "1623_1624": 1447, "1625": 1448, "1626": 1449, "1627": 1450, "1628": 1451, "1629": 1452, "1630": 1453, "1631": 1454, "1632": 1455, "1633": 1456, "1634": 1457, "1635_1636": 1458, "1637": 1459, "1638": 1460, "1639_1640_1641": 1461, "1642_1643": 1462, "1644": 1463, "1645": 1464, "1646": 1465, "1647": 1466, "1648": 1467, "1649": 1468, "1650": 1469, "1651": 1470, "1652_1653": 1471, "1654": 1472, "1655": 1473, "1656": 1474, "1657": 1475, "1658_1659": 1476, "1660": 1477, "1661": 1478, "1662": 1479, "1663": 1480, "1664": 1481, "1665": 1482, "1666": 1483, "1667": 1484, "1668": 1485, "1669": 1486, "1670": 1487, "1671": 1488, "1672": 1489, "1673": 1490, "1674": 1491, "1675": 1492, "1676": 1493, "1677": 1494, "1678": 1495, "1679": 1496, "1680": 1497, "1681": 1498, "1682_1683_1684": 1499, "1685": 1500, "1686": 1501, "1687_1688": 1502, "1689": 1503, "1690": 1504, "1691": 1505, "1692": 1506, "1693": 1507, "1694": 1508, "1695": 1509, "1696": 1510, "1697": 1511, "1698": 1512, "1699": 1513, "1700": 1514, "1701": 1515, "1702": 1516, "1703": 1517, "1704": 1518, "1705_1706": 1519, "1707": 
1520, "1708": 1521, "1709": 1522, "1710": 1523, "1711": 1524, "1712": 1525, "1713": 1526, "1714": 1527, "1715": 1528, "1716": 1529, "1717": 1530, "1718": 1531, "1719": 1532, "1720": 1533, "1721": 1534, "1722": 1535, "1723_1724_1725": 1536, "1726": 1537, "1727_1728": 1538, "1729": 1539, "1730": 1540, "1731": 1541, "1732": 1542, "1733": 1543, "1734": 1544, "1735": 1545, "1736_1737_1738": 1546, "1739": 1547, "1740": 1548, "1741": 1549, "1742": 1550, "1743": 1551, "1744": 1552, "1745": 1553, "1746": 1554, "1747": 1555, "1748": 1556, "1749": 1557, "1750": 1558, "1751": 1559, "1752": 1560, "1753": 1561, "1754": 1562, "1755": 1563, "1756_1757": 1564, "1758": 1565, "1759": 1566, "1760": 1567, "1761": 1568, "1762": 1569, "1763": 1570, "1764": 1571, "1765": 1572, "1766": 1573, "1767": 1574, "1768": 1575, "1769": 1576, "1770": 1577, "1771": 1578, "1772": 1579, "1773": 1580, "1774": 1581, "1775": 1582, "1776": 1583, "1777": 1584, "1778": 1585, "1779_1780": 1586, "1781": 1587} -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | 291 | 292 | 293 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | 304 | 305 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | 320 | 321 | 322 | 323 | 324 | 325 | 326 | 327 | 328 | 329 | 330 | 331 | 332 | 333 | 334 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | 345 | 346 | 347 | 348 | 349 | 350 | 351 | 352 | 353 | 354 | 355 | 356 | 357 | 358 | 359 | 360 | 361 | 362 | 363 | 364 | 365 | 366 | 367 | 368 | 369 | 370 | 371 | 372 | 373 | 374 | 375 | 376 | 377 | 378 | 379 | 380 | 381 | 382 | 383 | 384 | 385 | 386 | 387 | 388 | 389 | 390 | 391 | 392 | 393 | 394 | 395 | 396 | 397 | 398 | 399 | 400 | 401 | 402 | 403 | 404 | 405 | 406 | 407 | 408 | 409 | 410 | 411 | 412 | 413 | 414 | 415 | 416 | 
-------------------------------------------------------------------------------- /MoCo/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # gpu_info = !nvidia-smi -i 0 3 | # gpu_info = '\n'.join(gpu_info) 4 | # print(gpu_info) 5 | # nohup python -u mocodataset.py > mocodataset.log 2>&1 & 6 | from datetime import datetime 7 | from functools import partial 8 | import cv2 9 | import numpy as np 10 | import torchvision 11 | from PIL import Image 12 | from torch.utils.data import DataLoader 13 | from torchvision import transforms 14 | from torchvision.datasets import CIFAR10, FashionMNIST, MNIST 15 | from torchvision.models import resnet 16 | from tqdm import tqdm 17 | import argparse 18 | import json 19 | import math 20 | import os 21 | import random 22 | import pandas as pd 23 | import torch 24 | import torch.nn as nn 25 | import torch.nn.functional as F 26 | 27 | """### Set arguments""" 28 | 29 | parser = argparse.ArgumentParser(description='Train MoCo on HUST-OBC') 30 | 31 | parser.add_argument('-a', '--arch', default='resnet18') 32 | 33 | # lr: 0.06 for batch 512 (or 0.03 for batch 256) 34 | parser.add_argument('--lr', '--learning-rate', default=0.006, type=float, metavar='LR', help='initial learning rate', dest='lr') 35 | parser.add_argument('--epochs', default=150, type=int, metavar='N', help='number of total epochs to run') 36 | parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int, help='learning rate schedule (when to drop lr by 10x); does not take effect if --cos is on') 37 | parser.add_argument('--cos', action='store_true', help='use cosine lr schedule') 38 | 39 | parser.add_argument('--batch_size', default=128, type=int, metavar='N', help='mini-batch size') 40 | parser.add_argument('--num_workers', default=9, type=int) 41 | parser.add_argument('--wd', default=5e-4, type=float, metavar='W', help='weight decay') 42 | 43 | # moco specific configs: 44 | parser.add_argument('--moco-dim',
default=128, type=int, help='feature dimension') 45 | parser.add_argument('--moco-k', default=24576, type=int, help='queue size; number of negative keys') 46 | parser.add_argument('--moco-m', default=0.99, type=float, help='moco momentum of updating key encoder') 47 | parser.add_argument('--moco-t', default=0.1, type=float, help='softmax temperature') 48 | 49 | parser.add_argument('--bn-splits', default=8, type=int, help='simulate multi-gpu behavior of BatchNorm in one gpu; 1 is SyncBatchNorm in multi-gpu') 50 | 51 | parser.add_argument('--symmetric', action='store_true', help='use a symmetric loss function that backprops to both crops') 52 | 53 | # knn monitor 54 | parser.add_argument('--knn-k', default=200, type=int, help='k in kNN monitor') 55 | parser.add_argument('--knn-t', default=0.1, type=float, help='softmax temperature in kNN monitor; can differ from moco-t') 56 | 57 | # utils 58 | parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') 59 | parser.add_argument('--results-dir', default='moco', type=str, metavar='PATH', help='path to cache (default: none)') 60 | ''' 61 | args = parser.parse_args() # running in command line 62 | ''' 63 | # args = parser.parse_args('') # running in ipynb 64 | args = parser.parse_args() # running in command line 65 | # set command line arguments here when running in ipynb 66 | args.cos = True 67 | args.schedule = [] # cos in use 68 | args.symmetric = False 69 | if args.results_dir == '': 70 | args.results_dir = './cache-' + datetime.now().strftime("%Y-%m-%d-%H-%M-%S-moco") 71 | 72 | print(args) 73 | 74 | """### Define data loaders""" 75 | 76 | class CIFAR10Pair(CIFAR10): 77 | """CIFAR10 Dataset. 78 | """ 79 | def __getitem__(self, index): 80 | img = self.data[index] 81 | img = Image.fromarray(img) 82 | 83 | if self.transform is not None: 84 | im_1 = self.transform(img) 85 | im_2 = self.transform(img) 86 | 87 | return im_1, im_2 88 | class MiniPair(MNIST): 89 | """MNIST Dataset.
90 | """ 91 | def __getitem__(self, index): 92 | img = self.data[index] 93 | img=img.numpy() 94 | img = Image.fromarray(img) 95 | 96 | if self.transform is not None: 97 | im_1 = self.transform(img) 98 | im_2 = self.transform(img) 99 | 100 | return im_1, im_2 101 | 102 | from torch.utils.data import TensorDataset, DataLoader 103 | class RandomGaussianBlur(object): 104 | def __init__(self, p=0.5, min_kernel_size=3, max_kernel_size=15, min_sigma=0.1, max_sigma=1.0): 105 | self.p = p 106 | self.min_kernel_size = min_kernel_size 107 | self.max_kernel_size = max_kernel_size 108 | self.min_sigma = min_sigma 109 | self.max_sigma = max_sigma 110 | 111 | def __call__(self, img): 112 | if random.random() < self.p and self.min_kernel_size 128 or sizey < 32: 177 | sizey = round(random.gauss(y, 30)) 178 | if x < 128: 179 | while sizex > 128 or sizex < 32: 180 | sizex = round(random.gauss(x, 30)) 181 | dx = 128 - sizex # 差值 182 | dy = 128 - sizey 183 | if dx > 0: 184 | xl = -1 185 | while xl > dx or xl < 0: 186 | xl = round(dx / 2) 187 | xl = round(random.gauss(xl, 10)) 188 | else: 189 | xl = 0 190 | if dy > 0: 191 | yl = -1 192 | while yl > dy or yl < 0: 193 | yl = round(dy / 2) 194 | yl = round(random.gauss(yl, 10)) 195 | else: 196 | yl = 0 197 | yr = dy - yl 198 | xr = dx - xl 199 | image1=jioayan(image) 200 | image2=jioayan(image) 201 | image1 = cv2.cvtColor(np.array(image1), cv2.COLOR_RGB2BGR) 202 | image1 = pengzhang(image1) 203 | image1 = Image.fromarray(cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)) 204 | image2 = cv2.cvtColor(np.array(image2), cv2.COLOR_RGB2BGR) 205 | image2 = pengzhang(image2) 206 | image2 = Image.fromarray(cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)) 207 | image1 = random_gaussian_blur(image1) 208 | image2 = random_gaussian_blur(image2) 209 | train_transform1 = transforms.Compose([ 210 | transforms.Resize((sizey, sizex)), 211 | transforms.Pad([xl, yl, xr, yr], fill=(255, 255, 255), padding_mode='constant'), 212 | transforms.RandomRotation(degrees=(-20, 20), center=(round(64), round(64)), fill=(255, 255, 255)), 213 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), 214 | transforms.RandomGrayscale(p=0.2), 215 | transforms.ToTensor(), 216 | transforms.Normalize([0.84959, 0.84959, 0.84959], [0.30949923, 0.30949923, 0.30949923])]) 217 | im_1 = train_transform1(image1) 218 | x, y = 80, 80 219 | sizey, sizex = 129, 129 220 | if y < 128: 221 | while sizey > 128 or sizey < 32: 222 | sizey = round(random.gauss(y, 30)) 223 | if x < 128: 224 | while sizex > 128 or sizex < 32: 225 | sizex = round(random.gauss(x, 30)) 226 | dx = 128 - sizex # 差值 227 | dy = 128 - sizey 228 | if dx > 0: 229 | xl = -1 230 | while xl > dx or xl < 0: 231 | xl = round(dx / 2) 232 | xl = round(random.gauss(xl, 10)) 233 | else: 234 | xl = 0 235 | if dy > 0: 236 | yl = -1 237 | while yl > dy or yl < 0: 238 | yl = round(dy / 2) 239 | yl = round(random.gauss(yl, 10)) 240 | else: 241 | yl = 0 242 | yr = dy - yl 243 | xr = dx - xl 244 | train_transform2 = transforms.Compose([ 245 | transforms.Resize((sizey, sizex)), 246 | transforms.Pad([xl, yl, xr, yr], fill=(255, 255, 255), padding_mode='constant'), 247 | transforms.RandomRotation(degrees=(-20, 20), center=(round(64), round(64)), fill=(255, 255, 255)), 248 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), 249 | transforms.RandomGrayscale(p=0.2), 250 | transforms.ToTensor(), 251 | transforms.Normalize([0.84959, 0.84959, 0.84959], [0.30949923, 0.30949923, 0.30949923])]) 252 | im_2 = train_transform2(image2) 253 | return 
im_1, im_2 254 | 255 | def __len__(self): 256 | return len(self.images) 257 | 258 | train_dataset = Mydata() 259 | train_loader = DataLoader(train_dataset, shuffle = True, batch_size = args.batch_size, num_workers=args.num_workers,drop_last=True,pin_memory=True,) 260 | class SplitBatchNorm(nn.BatchNorm2d): 261 | def __init__(self, num_features, num_splits, **kw): 262 | super().__init__(num_features, **kw) 263 | self.num_splits = num_splits 264 | 265 | def forward(self, input): 266 | N, C, H, W = input.shape 267 | if self.training or not self.track_running_stats: 268 | running_mean_split = self.running_mean.repeat(self.num_splits) 269 | running_var_split = self.running_var.repeat(self.num_splits) 270 | outcome = nn.functional.batch_norm( 271 | input.view(-1, C * self.num_splits, H, W), running_mean_split, running_var_split, 272 | self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits), 273 | True, self.momentum, self.eps).view(N, C, H, W) 274 | self.running_mean.data.copy_(running_mean_split.view(self.num_splits, C).mean(dim=0)) 275 | self.running_var.data.copy_(running_var_split.view(self.num_splits, C).mean(dim=0)) 276 | return outcome 277 | else: 278 | return nn.functional.batch_norm( 279 | input, self.running_mean, self.running_var, 280 | self.weight, self.bias, False, self.momentum, self.eps) 281 | 282 | class ModelBase(nn.Module): 283 | """ 284 | Common CIFAR ResNet recipe. 285 | Comparing with ImageNet ResNet recipe, it: 286 | (i) replaces conv1 with kernel=3, str=1 287 | (ii) removes pool1 288 | """ 289 | def __init__(self, feature_dim=128, arch=None, bn_splits=16): 290 | super(ModelBase, self).__init__() 291 | 292 | # use split batchnorm 293 | norm_layer = partial(SplitBatchNorm, num_splits=bn_splits) if bn_splits > 1 else nn.BatchNorm2d 294 | resnet_arch = getattr(resnet, arch) 295 | net = resnet_arch(num_classes=feature_dim, norm_layer=norm_layer) 296 | 297 | self.net = [] 298 | for name, module in net.named_children(): 299 | if name == 'conv1': 300 | module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 301 | if isinstance(module, nn.MaxPool2d): 302 | continue 303 | if isinstance(module, nn.Linear): 304 | self.net.append(nn.Flatten(1)) 305 | self.net.append(module) 306 | 307 | self.net = nn.Sequential(*self.net) 308 | 309 | def forward(self, x): 310 | x = self.net(x) 311 | # note: not normalized here 312 | return x 313 | 314 | """### Define MoCo wrapper""" 315 | 316 | class ModelMoCo(nn.Module): 317 | def __init__(self, dim=128, K=4096, m=0.99, T=0.1, arch='resnet18', bn_splits=8, symmetric=True): 318 | super(ModelMoCo, self).__init__() 319 | 320 | self.K = K 321 | self.m = m 322 | self.T = T 323 | self.symmetric = symmetric 324 | 325 | # create the encoders 326 | self.encoder_q = ModelBase(feature_dim=dim, arch=arch, bn_splits=bn_splits) 327 | self.encoder_k = ModelBase(feature_dim=dim, arch=arch, bn_splits=bn_splits) 328 | 329 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 330 | param_k.data.copy_(param_q.data) # initialize 331 | param_k.requires_grad = False # not update by gradient 332 | 333 | # create the queue 334 | self.register_buffer("queue", torch.randn(dim, K)) 335 | self.queue = nn.functional.normalize(self.queue, dim=0) 336 | 337 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 338 | 339 | @torch.no_grad() 340 | def _momentum_update_key_encoder(self): 341 | """ 342 | Momentum update of the key encoder 343 | """ 344 | for param_q, param_k in 
zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 345 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m) 346 | 347 | @torch.no_grad() 348 | def _dequeue_and_enqueue(self, keys): 349 | batch_size = keys.shape[0] 350 | 351 | ptr = int(self.queue_ptr) 352 | assert self.K % batch_size == 0 # for simplicity 353 | 354 | # replace the keys at ptr (dequeue and enqueue) 355 | self.queue[:, ptr:ptr + batch_size] = keys.t() # transpose 356 | ptr = (ptr + batch_size) % self.K # move pointer 357 | 358 | self.queue_ptr[0] = ptr 359 | 360 | @torch.no_grad() 361 | def _batch_shuffle_single_gpu(self, x): 362 | """ 363 | Batch shuffle, for making use of BatchNorm. 364 | """ 365 | # random shuffle index 366 | idx_shuffle = torch.randperm(x.shape[0]).cuda() 367 | 368 | # index for restoring 369 | idx_unshuffle = torch.argsort(idx_shuffle) 370 | 371 | return x[idx_shuffle], idx_unshuffle 372 | 373 | @torch.no_grad() 374 | def _batch_unshuffle_single_gpu(self, x, idx_unshuffle): 375 | """ 376 | Undo batch shuffle. 377 | """ 378 | return x[idx_unshuffle] 379 | 380 | def contrastive_loss(self, im_q, im_k): 381 | # compute query features 382 | q = self.encoder_q(im_q) # queries: NxC 383 | q = nn.functional.normalize(q, dim=1) # already normalized 384 | 385 | # compute key features 386 | with torch.no_grad(): # no gradient to keys 387 | # shuffle for making use of BN 388 | im_k_, idx_unshuffle = self._batch_shuffle_single_gpu(im_k) 389 | 390 | k = self.encoder_k(im_k_) # keys: NxC 391 | k = nn.functional.normalize(k, dim=1) # already normalized 392 | 393 | # undo shuffle 394 | k = self._batch_unshuffle_single_gpu(k, idx_unshuffle) 395 | 396 | # compute logits 397 | # Einstein sum is more intuitive 398 | # positive logits: Nx1 399 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) 400 | # negative logits: NxK 401 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) 402 | 403 | # logits: Nx(1+K) 404 | logits = torch.cat([l_pos, l_neg], dim=1) 405 | 406 | # apply temperature 407 | logits /= self.T 408 | 409 | # labels: positive key indicators 410 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() 411 | 412 | loss = nn.CrossEntropyLoss().cuda()(logits, labels) 413 | 414 | return loss, q, k 415 | 416 | def forward(self, im1, im2): 417 | """ 418 | Input: 419 | im_q: a batch of query images 420 | im_k: a batch of key images 421 | Output: 422 | loss 423 | """ 424 | 425 | # update the key encoder 426 | with torch.no_grad(): # no gradient to keys 427 | self._momentum_update_key_encoder() 428 | 429 | # compute loss 430 | if self.symmetric: # symmetric loss 431 | loss_12, q1, k2 = self.contrastive_loss(im1, im2) 432 | loss_21, q2, k1 = self.contrastive_loss(im2, im1) 433 | loss = loss_12 + loss_21 434 | k = torch.cat([k1, k2], dim=0) 435 | else: # asymmetric loss 436 | loss, q, k = self.contrastive_loss(im1, im2) 437 | 438 | self._dequeue_and_enqueue(k) 439 | 440 | return loss 441 | 442 | # create model 443 | model = ModelMoCo( 444 | dim=args.moco_dim, 445 | K=args.moco_k, 446 | m=args.moco_m, 447 | T=args.moco_t, 448 | arch=args.arch, 449 | bn_splits=args.bn_splits, 450 | symmetric=args.symmetric, 451 | ).cuda() 452 | print(model.encoder_q) 453 | 454 | """### Define train/test 455 | 456 | 457 | """ 458 | 459 | # train for one epoch 460 | def train(net, data_loader, train_optimizer, epoch, args, trainlist): 461 | net.train() 462 | adjust_learning_rate(train_optimizer, epoch, args, trainlist) 463 | 464 | total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader)
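# Each iteration below feeds two independently augmented views of the same batch through ModelMoCo, which returns the InfoNCE loss; total_loss accumulates a per-sample average for the progress bar.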
465 | for im_1,im_2 in train_bar: 466 | 467 | im_1, im_2 = im_1.cuda(non_blocking=True), im_2.cuda(non_blocking=True) 468 | 469 | loss = net(im_1, im_2) 470 | 471 | train_optimizer.zero_grad() 472 | loss.backward() 473 | train_optimizer.step() 474 | 475 | total_num += data_loader.batch_size 476 | total_loss += loss.item() * data_loader.batch_size 477 | train_bar.set_description('Train Epoch: [{}/{}], lr: {:.6f}, Loss: {:.10f}'.format(epoch, args.epochs, optimizer.param_groups[0]['lr'], total_loss / total_num)) 478 | 479 | return total_loss / total_num 480 | 481 | # lr scheduler for training 482 | def adjust_learning_rate(optimizer, epoch, args,trainlist): 483 | """Decay the learning rate based on schedule""" 484 | lr = args.lr 485 | if args.cos: # cosine lr schedule 486 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) 487 | else: # stepwise lr schedule 488 | for milestone in args.schedule: 489 | lr *= 0.1 if epoch >= milestone else 1. 490 | for param_group in optimizer.param_groups: 491 | param_group['lr'] = lr 492 | 493 | # test using a knn monitor 494 | def test(net, memory_data_loader, test_data_loader, epoch, args): 495 | net.eval() 496 | classes = len(memory_data_loader.dataset.classes) 497 | total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, [] 498 | with torch.no_grad(): 499 | # generate feature bank 500 | for data, target in tqdm(memory_data_loader, desc='Feature extracting'): 501 | feature = net(data.cuda(non_blocking=True)) 502 | feature = F.normalize(feature, dim=1) 503 | feature_bank.append(feature) 504 | # [D, N] 505 | feature_bank = torch.cat(feature_bank, dim=0).t().contiguous() 506 | # [N] 507 | feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device) 508 | # loop test data to predict the label by weighted knn search 509 | test_bar = tqdm(test_data_loader) 510 | for data, target in test_bar: 511 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) 512 | feature = net(data) 513 | feature = F.normalize(feature, dim=1) 514 | 515 | pred_labels = knn_predict(feature, feature_bank, feature_labels, classes, args.knn_k, args.knn_t) 516 | 517 | total_num += data.size(0) 518 | total_top1 += (pred_labels[:, 0] == target).float().sum().item() 519 | test_bar.set_description('Test Epoch: [{}/{}] Acc@1:{:.2f}%'.format(epoch, args.epochs, total_top1 / total_num * 100)) 520 | 521 | return total_top1 / total_num * 100 522 | 523 | def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t): 524 | # compute cos similarity between each feature vector and feature bank ---> [B, N] 525 | sim_matrix = torch.mm(feature, feature_bank) 526 | # [B, K] 527 | sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1) 528 | # [B, K] 529 | sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices) 530 | sim_weight = (sim_weight / knn_t).exp() 531 | 532 | # counts for each class 533 | one_hot_label = torch.zeros(feature.size(0) * knn_k, classes, device=sim_labels.device) 534 | # [B*K, C] 535 | one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0) 536 | # weighted score ---> [B, C] 537 | pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1) 538 | 539 | pred_labels = pred_scores.argsort(dim=-1, descending=True) 540 | return pred_labels 541 | 542 | """### Start training""" 543 | 544 | # define optimizer 545 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, 
weight_decay=args.wd, momentum=0.9) 546 | 547 | # load model if resume 548 | epoch_start = 1 549 | if args.resume != '': 550 | checkpoint = torch.load(args.resume, map_location=torch.device('cuda:0')) 551 | model.load_state_dict(checkpoint['state_dict']) 552 | optimizer.load_state_dict(checkpoint['optimizer']) 553 | epoch_start = checkpoint['epoch'] + 1 554 | print('Loaded from: {}'.format(args.resume)) 555 | 556 | # logging 557 | results = {'train_loss': []} 558 | if not os.path.exists(args.results_dir): 559 | os.mkdir(args.results_dir) 560 | # dump args 561 | with open(args.results_dir + '/args.json', 'w') as fid: 562 | json.dump(args.__dict__, fid, indent=2) 563 | 564 | # training loop 565 | trainlist = [] 566 | for epoch in range(epoch_start, args.epochs + 1): 567 | train_loss = train(model, train_loader, optimizer, epoch, args, trainlist) 568 | results['train_loss'].append(train_loss) 569 | data_frame = pd.DataFrame(data=results, index=range(epoch_start, epoch + 1)) 570 | data_frame.to_csv(args.results_dir + '/log.csv', index_label='epoch') 571 | # save model 572 | torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(),}, args.results_dir + '/model_last.pth') --------------------------------------------------------------------------------
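A usage sketch (assumed invocation, mirroring the argparse and file names above; note the script force-sets args.cos=True, args.schedule=[] and args.symmetric=False after parsing):

# fresh pretraining run, logging to moco/log.csv and saving moco/model_last.pth
#   nohup python -u MoCo/train.py --lr 0.006 --batch_size 128 --epochs 150 --results-dir moco > moco.log 2>&1 &
# resume from the last checkpoint (epoch counter and optimizer state come from the saved dict)
#   python MoCo/train.py --resume moco/model_last.pth --results-dir moco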