├── tools
│   ├── loss.py
│   ├── evaluate.py
│   ├── utils.py
│   └── model_train.py
├── data
│   ├── transform.py
│   ├── data_loader_1.py
│   ├── cifar100.py
│   └── imagenet.py
├── network
│   └── resnet.py
├── README.md
├── args.py
├── .gitignore
├── main.py
└── LICENSE

/README.md:
--------------------------------------------------------------------------------
1 | # ACHNet
2 | Repository for "Attention-guided Contrastive Hashing for Long-tailed Image Retrieval" (accepted at IJCAI 2022).
3 | paper link: https://www.ijcai.org/proceedings/2022/0142.pdf
4 |
5 | # Dataset Download
6 | link: https://pan.baidu.com/s/11R7Ncm4aowdKC9zh5tPe1Q \
7 | password: saf3 \
8 | For a fair comparison we use the datasets from Long-tailed Hashing; both ImageNet-100 and CIFAR-100 can be downloaded from the link above.
9 |
10 | # Model Training
11 | - First, create an empty `checkpoints` directory.
12 | - Set the hyper-parameters in `args.py`, then run `python main.py` to train the model.
13 |
14 | # Requirements
15 | Python 3.7.12
16 | PyTorch 1.8.0
--------------------------------------------------------------------------------
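A minimal end-to-end launch sketch to accompany the steps above. The dataset path and GPU id are placeholders, not values shipped with the repo; the flags mirror the defaults defined in `args.py`.

```python
# Hedged setup sketch: create the 'checkpoints' directory that main.py
# expects, then start a training run. '/path/to/datasets' is assumed to
# contain the unpacked cifar-100-IF100/ folder from the download link.
import os
import subprocess

os.makedirs('checkpoints', exist_ok=True)   # main.py saves checkpoints here
subprocess.run([
    'python', 'main.py',
    '--dataset', 'cifar-100-IF100',
    '--root', '/path/to/datasets',
    '--code_length', '32,64,96',
    '--batch_size', '8',
    '--gpu', '0',
], check=True)
```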
/tools/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class CELoss(nn.Module):
7 |     """Cross-entropy against one-hot (or soft) targets."""
8 |
9 |     def __init__(self):
10 |         super(CELoss, self).__init__()
11 |
12 |     def forward(self, assignments, targets):
13 |         batch_size = assignments.size(0)
14 |
15 |         # Softmax over the class logits; 1e-6 keeps the log away from zero.
16 |         assignments = F.softmax(assignments, dim=1)
17 |         loss = torch.sum(-torch.log(assignments + 1e-6) * targets) / batch_size
18 |
19 |         return loss
20 |
21 |
22 | class CenConLoss(nn.Module):
23 |     """Center-based contrastive loss: pulls each hash code towards its own
24 |     class center and pushes it away from the centers of the other classes."""
25 |
26 |     def __init__(self):
27 |         super(CenConLoss, self).__init__()
28 |         self.t = 1.0  # temperature that scales the cosine similarities
29 |
30 |     def forward(self, hashcode, center, label):
31 |         # (B, 1, K) x (1, C, K) -> cosine similarity matrix of shape (B, C)
32 |         cos_sim = F.cosine_similarity(hashcode.unsqueeze(1), center.unsqueeze(0), dim=2)
33 |
34 |         # label is one-hot: positives keep only the own-class similarity,
35 |         # while the denominator collects the similarities to all other centers.
36 |         positives = torch.exp(cos_sim * self.t) * label
37 |         denominator = torch.exp(cos_sim * self.t) * (1 - label)
38 |         loss = -torch.log(torch.sum(positives, dim=1) / torch.sum(denominator, dim=1))
39 |         loss = torch.mean(loss)
40 |
41 |         return loss
--------------------------------------------------------------------------------
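A quick smoke test of the two losses on toy tensors (shapes are what matters here; the values are illustrative). It assumes the repository root is on `PYTHONPATH` so `tools.loss` is importable.

```python
import torch
import torch.nn.functional as F
from tools.loss import CELoss, CenConLoss

B, K, C = 4, 32, 100                         # batch, code length, classes
hashcodes = torch.tanh(torch.randn(B, K))    # relaxed codes in (-1, 1), as the network emits
centers = torch.randn(C, K).sign()           # one center per class, entries in {-1, +1}
labels = F.one_hot(torch.tensor([0, 3, 3, 7]), num_classes=C).float()

logits = torch.randn(B, C)
print(CELoss()(logits, labels))              # scalar tensor
print(CenConLoss()(hashcodes, centers, labels))  # scalar tensor
```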
/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def get_args():
5 |     parser = argparse.ArgumentParser(description='code for ACHNet')
6 |     parser.add_argument('--dataset', default='cifar-100-IF100', type=str,
7 |                         help='Dataset name.')
8 |     parser.add_argument('--root',
9 |                         default=None,
10 |                         type=str,
11 |                         help='Path of dataset.')
12 |     parser.add_argument('--batch_size', default=8, type=int,
13 |                         help='Batch size.(default: 8)')
14 |     parser.add_argument('--lr', default=1e-5, type=float,
15 |                         help='Learning rate.(default: 1e-5)')
16 |     parser.add_argument('--code_length', default='32, 64, 96', type=str,
17 |                         help='Binary hash code lengths, comma-separated.(default: 32,64,96)')
18 |     parser.add_argument('--feature_dim', default=2000, type=int,
19 |                         help='Feature dimension.(default: 2000)')
20 |     parser.add_argument('--num_classes', default=100, type=int,
21 |                         help='Number of classes.(default: 100)')
22 |     parser.add_argument('--max_iter', default=100, type=int,
23 |                         help='Number of training iterations.(default: 100)')
24 |     parser.add_argument('--num_workers', default=6, type=int,
25 |                         help='Number of data loading threads.(default: 6)')
26 |     parser.add_argument('--topk', default=-1, type=int,
27 |                         help='Calculate mAP of top k.(default: -1, i.e. the whole database)')
28 |     parser.add_argument('--gpu', default=4, type=int,
29 |                         help='GPU device id.(default: 4; None runs on CPU)')
30 |     parser.add_argument('--lamb', default=0.2, type=float,
31 |                         help='Hyper-parameter: balance between CE loss and contrastive loss.')
32 |     parser.add_argument('--seed', default=3367, type=int,
33 |                         help='Random seed.(default: 3367)')
34 |     parser.add_argument('--evaluate-interval', default=4, type=int,
35 |                         help='Evaluation interval.(default: 4)')
36 |
37 |     args = parser.parse_args()
38 |
39 |     return args
--------------------------------------------------------------------------------
/network/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 | import torchvision.models as models
6 |
7 |
8 | def load_model(feature_dim, code_length, num_classes):
9 |
10 |     resnet34 = models.resnet34(pretrained=True)
11 |     model = Resnet34(resnet34, feature_dim, code_length, num_classes)
12 |
13 |     return model
14 |
15 |
16 | class Resnet34(nn.Module):
17 |     def __init__(self, origin_model, feature_dim, code_length=64, num_classes=100):
18 |         super(Resnet34, self).__init__()
19 |         self.code_length = code_length
20 |         # Backbone: ResNet-34 without its avgpool and fc head.
21 |         self.features = nn.Sequential(*list(origin_model.children())[:-2])
22 |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
23 |
24 |         self.fc_1 = nn.Linear(512, feature_dim)
25 |         self.fc_2 = nn.Linear(512, feature_dim)
26 |         self.fc = nn.Linear(feature_dim * 2, feature_dim)
27 |         self.hash_layer = nn.Linear(feature_dim, code_length)
28 |         self.classifier = nn.Linear(code_length, num_classes)
29 |         self.apply(_weights_init)
30 |
31 |         # Start both branches from identical weights.
32 |         dict_fc = self.fc_1.state_dict()
33 |         self.fc_2.load_state_dict(dict_fc)
34 |
35 |     def forward(self, x):
36 |
37 |         x = self.features(x)
38 |
39 |         x = self.avgpool(x)
40 |         x = F.relu(x.view(x.size(0), -1))
41 |         x1 = self.fc_1(x)
42 |         x2 = self.fc_2(x)
43 |
44 |         # Attention gate: the product of the two tanh-squashed branches
45 |         # weights every feature dimension before the branches are fused.
46 |         concept_selector1 = torch.tanh(x1)
47 |         concept_selector2 = torch.tanh(x2)
48 |
49 |         alpha = concept_selector2 * concept_selector1
50 |
51 |         x1 = x1 * alpha
52 |         x2 = x2 * alpha
53 |         x = torch.cat((x1, x2), dim=1)
54 |
55 |         x = self.fc(x)
56 |
57 |         direct_feature = x
58 |
59 |         x = torch.tanh(x)
60 |
61 |         hash_codes = self.hash_layer(x)
62 |
63 |         hash_codes = torch.tanh(hash_codes)
64 |
65 |         assignments = self.classifier(hash_codes)
66 |
67 |         return hash_codes, assignments, direct_feature
68 |
69 |
70 | def _weights_init(m):
71 |     if isinstance(m, nn.Linear):
72 |         init.kaiming_normal_(m.weight)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /data/transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import numpy as np 4 | 5 | 6 | def encode_onehot(labels, num_classes=100): 7 | """ 8 | one-hot labels 9 | 10 | Args: 11 | labels (numpy.ndarray): labels. 12 | num_classes (int): Number of classes. 13 | 14 | Returns: 15 | onehot_labels (numpy.ndarray): one-hot labels. 16 | """ 17 | onehot_labels = np.zeros((len(labels), num_classes)) 18 | 19 | for i in range(len(labels)): 20 | onehot_labels[i, labels[i]] = 1 21 | 22 | return onehot_labels 23 | 24 | 25 | class Onehot(object): 26 | def __init__(self, num_classes=10): 27 | self.num_classes = num_classes 28 | 29 | def __call__(self, sample): 30 | target_onehot = torch.zeros(self.num_classes) 31 | target_onehot[sample] = 1 32 | 33 | return target_onehot 34 | 35 | 36 | def train_transform_cifar(): 37 | """ 38 | Training images transform. 39 | 40 | Args 41 | None 42 | 43 | Returns 44 | transform(torchvision.transforms): transform 45 | """ 46 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 47 | std=[0.229, 0.224, 0.225]) 48 | return transforms.Compose([ 49 | transforms.Resize((256, 256)), 50 | transforms.CenterCrop((224, 224)), 51 | transforms.ToTensor(), 52 | normalize, 53 | ]) 54 | 55 | 56 | def train_transform_imagenet(): 57 | """ 58 | Training images transform. 
59 |
60 |     Args
61 |         None
62 |
63 |     Returns
64 |         transform(torchvision.transforms): transform
65 |     """
66 |     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
67 |                                      std=[0.229, 0.224, 0.225])
68 |     return transforms.Compose([
69 |         transforms.RandomResizedCrop(224),
70 |         transforms.RandomHorizontalFlip(),
71 |         transforms.ToTensor(),
72 |         normalize,
73 |     ])
74 |
75 |
76 | def query_transform():
77 |     """
78 |     Query images transform.
79 |
80 |     Args
81 |         None
82 |
83 |     Returns
84 |         transform(torchvision.transforms): transform
85 |     """
86 |     # Data transform
87 |     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
88 |                                      std=[0.229, 0.224, 0.225])
89 |     return transforms.Compose([
90 |         transforms.Resize((256, 256)),
91 |         transforms.CenterCrop((224, 224)),
92 |         transforms.ToTensor(),
93 |         normalize,
94 |     ])
--------------------------------------------------------------------------------
/tools/evaluate.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def mean_average_precision(query_code,
5 |                            database_code,
6 |                            query_labels,
7 |                            database_labels,
8 |                            device,
9 |                            topk=-1,
10 |                            ):
11 |     """
12 |     Calculate mean average precision (mAP).
13 |
14 |     Args:
15 |         query_code (torch.Tensor): Query data hash code.
16 |         database_code (torch.Tensor): Database data hash code.
17 |         query_labels (torch.Tensor): Query data targets, one-hot.
18 |         database_labels (torch.Tensor): Database data targets, one-hot.
19 |         device (torch.device): Using CPU or GPU.
20 |         topk (int): Calculate top k data map; -1 evaluates the whole database.
21 |
22 |     Returns:
23 |         meanAP (float): Mean Average Precision.
24 |     """
25 |     num_query = query_labels.shape[0]
26 |     mean_ap = 0.0
27 |     mean_h, mean_m, mean_t = 0.0, 0.0, 0.0
28 |     query_h, query_m, query_t = 0.0, 0.0, 0.0
29 |
30 |     for i in range(num_query):
31 |         # Retrieve images from database
32 |         retrieval = (query_labels[i, :] @ database_labels.t() > 0).float()
33 |
34 |         # Calculate hamming distance
35 |         hamming_dist = 0.5 * (database_code.shape[1] - query_code[i, :] @ database_code.t())
36 |
37 |         # Sort by hamming distance; slice only for a positive topk,
38 |         # so that topk=-1 really keeps the whole ranking.
39 |         order = torch.argsort(hamming_dist)
40 |         retrieval = retrieval[order[:topk]] if topk > 0 else retrieval[order]
41 |
42 |         # Retrieval count
43 |         retrieval_cnt = retrieval.sum().int().item()
44 |
45 |         # Can not retrieve images
46 |         if retrieval_cnt == 0:
47 |             continue
48 |
49 |         # Generate score for every position
50 |         score = torch.linspace(1, retrieval_cnt, retrieval_cnt).to(device)
51 |
52 |         # Acquire index
53 |         index = (torch.nonzero(retrieval == 1, as_tuple=False).squeeze() + 1.0).float()
54 |
55 |         # mAP of head, middle and tail classes: with 100 classes,
56 |         # `label // 33` maps classes 0-32 / 33-65 / 66-99 to 0 / 1 / 2.
57 |         label = torch.where(query_labels[i, :] == 1)[0]
58 |         num = label // 33
59 |         mean_ap += (score / index).mean()
60 |         if num == 0:
61 |             mean_h += (score / index).mean()
62 |             query_h += 1
63 |         elif num == 1:
64 |             mean_m += (score / index).mean()
65 |             query_m += 1
66 |         else:
67 |             mean_t += (score / index).mean()
68 |             query_t += 1
69 |
70 |     mean_ap = mean_ap / num_query
71 |     mean_ap_h = mean_h / query_h
72 |     mean_ap_m = mean_m / query_m
73 |     mean_ap_t = mean_t / query_t
74 |     torch.cuda.empty_cache()
75 |     mean_ap_class = [mean_ap_h, mean_ap_m, mean_ap_t]
76 |
77 |     return mean_ap, mean_ap_class
--------------------------------------------------------------------------------
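A toy run of `mean_average_precision` on hand-made codes (CPU only). The labels are arranged so that each of the head/middle/tail bands (classes 0-32, 33-65, 66-99) receives at least one query, which the per-band averages require; everything else is made up for illustration.

```python
import torch
from tools.evaluate import mean_average_precision  # assumes repo root on PYTHONPATH

K, C = 16, 100
torch.manual_seed(0)
query_code = torch.randn(3, K).sign()       # 3 queries, codes in {-1, +1}
database_code = torch.randn(9, K).sign()    # 9 database items

query_labels = torch.zeros(3, C)
query_labels[0, 5], query_labels[1, 40], query_labels[2, 80] = 1, 1, 1
database_labels = torch.zeros(9, C)
database_labels[0:3, 5] = 1     # relevant to query 0 (head band)
database_labels[3:6, 40] = 1    # relevant to query 1 (middle band)
database_labels[6:9, 80] = 1    # relevant to query 2 (tail band)

mAP, mAP_class = mean_average_precision(query_code, database_code,
                                        query_labels, database_labels,
                                        torch.device('cpu'), topk=-1)
print(float(mAP), [float(m) for m in mAP_class])
```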
/data/data_loader_1.py:
--------------------------------------------------------------------------------
1 | from PIL import ImageFile
2 | from data import cifar100, imagenet
3 |
4 | ImageFile.LOAD_TRUNCATED_IMAGES = True
5 |
6 |
7 | def load_data(dataset, root, batch_size, num_workers):
8 |     """
9 |     Load dataset.
10 |
11 |     Args
12 |         dataset(str): Dataset name.
13 |         root(str): Path of dataset.
14 |         batch_size(int): Batch size.
15 |         num_workers(int): Number of loading data threads.
16 |
17 |     Returns
18 |         train_dataloader, query_dataloader, retrieval_dataloader(torch.utils.data.DataLoader): Data loader.
19 |     """
20 |     # The six supported variants only differ in where the data lives: the
21 |     # CIFAR loader expects root/<dataset>/, while the ImageNet loader
22 |     # resolves its own subdirectory from the dataset name.
23 |     if dataset in ('cifar-100-IF1', 'cifar-100-IF50', 'cifar-100-IF100'):
24 |         return cifar100.load_data(root + '/' + dataset, batch_size, num_workers)
25 |     elif dataset in ('imagenet-100-IF1', 'imagenet-100-IF50', 'imagenet-100-IF100'):
26 |         return imagenet.load_data(dataset, root, batch_size, num_workers)
27 |     else:
28 |         raise ValueError("Invalid dataset name!")
--------------------------------------------------------------------------------
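For orientation, a minimal usage sketch of this loader; the root path is an assumption and must contain `cifar-100-IF100/images/{train,query,database}` as produced by the dataset download.

```python
from data import data_loader_1

train_loader, query_loader, retrieval_loader = data_loader_1.load_data(
    dataset='cifar-100-IF100',
    root='/path/to/datasets',   # placeholder path
    batch_size=8,
    num_workers=6,
)
for images, onehot_targets, indices in train_loader:
    # e.g. torch.Size([8, 3, 224, 224]) torch.Size([8, 100])
    print(images.shape, onehot_targets.shape)
    break
```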
/tools/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def generate_code(model, dataloader, code_length, num_classes, device, flag):
6 |     """
7 |     Generate hash codes for a whole data split.
8 |     Args
9 |         model(torch.nn.Module): Hashing network.
10 |         dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
11 |         code_length(int): Hash code length.
12 |         num_classes(int): Number of classes.
13 |         device(torch.device): Using gpu or cpu.
14 |         flag(int): 0 also extracts integer labels (query split), 1 skips them.
15 |     Returns
16 |         code(torch.Tensor): Hash code.
17 |         assignment(torch.Tensor): assignment.
18 |         label(torch.Tensor): label of sample.
19 |     """
20 |     model.eval()
21 |     with torch.no_grad():
22 |         N = len(dataloader.dataset)
23 |         code = torch.zeros([N, code_length])
24 |         assignment = torch.zeros([N, num_classes])
25 |         label = torch.zeros(N, dtype=torch.long)
26 |         for data, target, index in dataloader:
27 |             data = data.to(device)
28 |             hash_code, class_assignment, _ = model(data)
29 |             code[index, :] = hash_code.sign().cpu()  # binarize the relaxed codes
30 |             assignment[index, :] = class_assignment[:data.size(0), :].cpu()
31 |             if flag == 0:
32 |                 lab = torch.nonzero(target, as_tuple=False)[:, 1]  # one-hot -> index
33 |                 label[index] = lab.long().cpu()
34 |         torch.cuda.empty_cache()
35 |     return code, assignment, label
36 |
37 |
38 | def generate_hash_center(model, dataloader, device, code_length):
39 |     """
40 |     Generate hash centers (the mean relaxed code of every class).
41 |     Args
42 |         model(torch.nn.Module): Hashing network.
43 |         dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
44 |         device(torch.device): Using gpu or cpu.
45 |         code_length(int): Hash code length.
46 |     Returns
47 |         hash_center(torch.Tensor): one center per class.
48 |     """
49 |     model.eval()
50 |     with torch.no_grad():
51 |         # 100 classes are assumed throughout this repository.
52 |         hash_center = torch.zeros([100, code_length])
53 |         counter = torch.zeros([100])
54 |         for data, targets, _ in dataloader:
55 |             data, targets = data.to(device), targets.to(device)
56 |             hash_code, _, _ = model(data)
57 |             direct_feature = hash_code.to('cpu')
58 |             index = torch.nonzero(targets, as_tuple=False)[:, 1]
59 |             index = index.to('cpu')
60 |             for j in range(len(data)):
61 |                 hash_center[index[j], :] = hash_center[index[j], :] + direct_feature[j, :]
62 |                 counter[index[j]] = counter[index[j]] + 1
63 |
64 |         # Average the accumulated codes (every class must appear at least once).
65 |         for k in range(100):
66 |             hash_center[k, :] = hash_center[k, :] / counter[k]
67 |         torch.cuda.empty_cache()
68 |     return hash_center
69 |
70 |
71 | def get_correct_num(output, labels):
72 |     """
73 |     Get the number of correctly classified samples.
74 |     Args
75 |         output: output of classification layer.
76 |         labels: label of sample.
77 |     Returns
78 |         number: number of correctly classified samples.
79 |     """
80 |     return output.argmax(dim=1).eq(labels).sum().item()
81 |
82 |
83 | # Get the number of samples and correct predictions for each class
84 | def sample_num_per_class(output, label, class_num, flag):
85 |     class_sample_num = list(np.zeros([class_num], dtype=int))
86 |     class_correct_num = list(np.zeros([class_num], dtype=int))
87 |
88 |     for i in label:
89 |         class_sample_num[i] += 1
90 |     if flag == 0:
91 |         output = output.argmax(axis=1)
92 |     for i in range(len(label)):
93 |         if output[i] == label[i]:
94 |             class_correct_num[label[i]] += 1
95 |     return class_correct_num, class_sample_num
96 |
97 |
98 | # Calculate the classification accuracy for each class
99 | def correct_per_class(class_correct_num, class_sample_num, class_num):
100 |     accuracy_per_class = list(np.zeros(class_num))
101 |     for i in range(class_num):
102 |         if class_sample_num[i] != 0:
103 |             accuracy_per_class[i] = round(100 * class_correct_num[i] / class_sample_num[i], 3)
104 |     return accuracy_per_class
--------------------------------------------------------------------------------
/data/cifar100.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | from PIL import Image
6 | from torch.utils.data.dataloader import DataLoader
7 | from torch.utils.data.dataset import Dataset
8 |
9 | from data.transform import train_transform_cifar, query_transform, Onehot, encode_onehot
10 |
11 |
12 | def load_data(root, batch_size, num_workers):
13 |     """
14 |     Load the CIFAR-100 dataset.
15 |
16 |     Args
17 |         root(str): Path of dataset.
18 |         batch_size(int): Batch size.
19 |         num_workers(int): Number of data loading workers.
20 |
21 |     Returns
22 |         train_dataloader, query_dataloader, retrieval_dataloader(torch.utils.data.DataLoader): Data loader.
23 | """ 24 | root = os.path.join(root, 'images') 25 | num_classes = 100 26 | train_dataloader = DataLoader( 27 | ImagenetDataset( 28 | os.path.join(root, 'train'), 29 | transform=train_transform_cifar(), 30 | target_transform=Onehot(num_classes), 31 | ), 32 | batch_size=batch_size, 33 | num_workers=num_workers, 34 | shuffle=True, 35 | pin_memory=True, 36 | ) 37 | 38 | query_dataloader = DataLoader( 39 | ImagenetDataset( 40 | os.path.join(root, 'query'), 41 | transform=query_transform(), 42 | target_transform=Onehot(num_classes), 43 | ), 44 | batch_size=batch_size, 45 | num_workers=num_workers, 46 | shuffle=False, 47 | pin_memory=True, 48 | ) 49 | 50 | retrieval_dataloader = DataLoader( 51 | ImagenetDataset( 52 | os.path.join(root, 'database'), 53 | transform=query_transform(), 54 | target_transform=Onehot(num_classes), 55 | ), 56 | batch_size=batch_size, 57 | num_workers=num_workers, 58 | shuffle=False, 59 | pin_memory=True, 60 | ) 61 | 62 | return train_dataloader, query_dataloader, retrieval_dataloader, 63 | 64 | 65 | class ImagenetDataset(Dataset): 66 | classes = None 67 | class_to_idx = None 68 | 69 | def __init__(self, root, transform=None, target_transform=None): 70 | self.root = root 71 | self.transform = transform 72 | self.target_transform = target_transform 73 | self.data = [] 74 | self.targets = [] 75 | 76 | # Assume file alphabet order is the class order 77 | if ImagenetDataset.class_to_idx is None: 78 | ImagenetDataset.classes, ImagenetDataset.class_to_idx = self._find_classes(root) 79 | 80 | for i, cl in enumerate(ImagenetDataset.classes): 81 | cur_class = os.path.join(self.root, cl) 82 | files = os.listdir(cur_class) 83 | files = [os.path.join(cur_class, i) for i in files] 84 | self.data.extend(files) 85 | self.targets.extend([ImagenetDataset.class_to_idx[cl] for i in range(len(files))]) 86 | self.targets = np.asarray(self.targets) 87 | self.onehot_targets = torch.from_numpy(encode_onehot(self.targets, 100)).float() 88 | 89 | def get_onehot_targets(self): 90 | return self.onehot_targets 91 | 92 | def __len__(self): 93 | return len(self.data) 94 | 95 | def __getitem__(self, item): 96 | img, target = self.data[item], self.targets[item] 97 | 98 | img = Image.open(img).convert('RGB') 99 | 100 | if self.transform is not None: 101 | img = self.transform(img) 102 | if self.target_transform is not None: 103 | target = self.target_transform(target) 104 | return img, target, item 105 | 106 | def _find_classes(self, dir): 107 | """ 108 | Finds the class folders in a dataset. 109 | 110 | Args: 111 | dir (string): Root directory path. 112 | 113 | Returns: 114 | tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. 115 | 116 | Ensures: 117 | No class is a subdirectory of another. 
118 |         """
119 |         classes = [d.name for d in os.scandir(dir) if d.is_dir()]
120 |         classes.sort()
121 |         class_to_idx = {classes[i]: i for i in range(len(classes))}
122 |         return classes, class_to_idx
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import random
5 | import torch.optim as optim
6 |
7 | from args import get_args
8 | from loguru import logger
9 | from network import resnet
10 | from data import data_loader_1
11 | from tools import model_train, loss
12 |
13 |
14 | def main(args):
15 |
16 |     # Get device
17 |     if args.gpu is not None:
18 |         torch.cuda.set_device(args.gpu)
19 |         args.device = torch.device(args.gpu)
20 |     else:
21 |         args.device = torch.device("cpu")
22 |
23 |     # Set random seed (cudnn.benchmark favors speed over exact reproducibility)
24 |     if args.seed is not None:
25 |         torch.backends.cudnn.benchmark = True
26 |         random.seed(args.seed)
27 |         torch.manual_seed(args.seed)
28 |         torch.cuda.manual_seed(args.seed)
29 |         np.random.seed(args.seed)
30 |
31 |     logger.add('logs/{}_code_{}_lamb_{}_batch_size_{}.log'.format(
32 |         args.dataset,
33 |         args.code_length,
34 |         args.lamb,
35 |         args.batch_size,
36 |     ),
37 |         rotation='500 MB',
38 |         level='INFO',
39 |     )
40 |     logger.info(args)
41 |
42 |     # Build dataset
43 |     dataset = args.dataset.split('-')[0]
44 |
45 |     train_loader, query_loader, retrieval_loader = data_loader_1.load_data(args.dataset,
46 |                                                                            args.root,
47 |                                                                            args.batch_size,
48 |                                                                            args.num_workers
49 |                                                                            )
50 |
51 |     print('dataset loading end')
52 |
53 |     # Print per-class sample numbers
54 |     class_samples = torch.Tensor(np.zeros(args.num_classes))
55 |     for _, targets, _ in train_loader:
56 |         class_samples += torch.sum(targets, dim=0)
57 |     print('class sample number:{}'.format(class_samples))
58 |
59 |     # '32, 64, 96' -> [32, 64, 96]: one model is trained per code length
60 |     args.code_length = list(map(int, args.code_length.split(',')))
61 |
62 |     for length in args.code_length:
63 |
64 |         # Build network
65 |         model = resnet.load_model(args.feature_dim, length, args.num_classes)
66 |         model.to(args.device)
67 |
68 |         if dataset == 'cifar':
69 |             optimizer = optim.RMSprop(
70 |                 filter(lambda p: p.requires_grad, model.parameters()),
71 |                 lr=args.lr,
72 |                 weight_decay=5e-4,
73 |             )
74 |             print('cifar-100 optimizer end')
75 |
76 |         elif dataset == 'imagenet':
77 |             # The backbone gets a smaller learning rate than the hashing head.
78 |             feature_params = []
79 |             hashing_params = []
80 |             for p_name, p in model.named_parameters():
81 |                 if p_name.startswith('features'):
82 |                     feature_params += [p]
83 |                 else:
84 |                     hashing_params += [p]
85 |
86 |             optimizer = optim.RMSprop([
87 |                 {'params': feature_params, 'lr': 0.1*args.lr, 'weight_decay': 5e-4},
88 |                 {'params': hashing_params, 'lr': 10*args.lr, 'weight_decay': 5e-4}]
89 |             )
90 |             print('imagenet-100 optimizer end')
91 |         else:
92 |             raise ValueError('Unsupported dataset: {}'.format(args.dataset))
93 |
94 |         criterion_ce = loss.CELoss().to(args.device)
95 |         criterion_con = loss.CenConLoss().to(args.device)
96 |
97 |         # Training
98 |         checkpoint = model_train.train(args,
99 |                                        length,
100 |                                        model,
101 |                                        optimizer,
102 |                                        train_loader,
103 |                                        query_loader,
104 |                                        retrieval_loader,
105 |                                        criterion_ce,
106 |                                        criterion_con
107 |                                        )
108 |         logger.info('[code_length:{}][map:{:.4f}]'.format(length, checkpoint['map']))
109 |
110 |         # Save checkpoint
111 |         torch.save(
112 |             checkpoint,
113 |             os.path.join('checkpoints', '{}_code_{}_lamb_{}_map_{:.4f}_batchsize_{}_maxIter_{}.pt'.format(
114 |                 args.dataset,
115 |                 length,
116 |                 args.lamb,
117 |                 checkpoint['map'],
118 |                 args.batch_size,
119 |                 args.max_iter)
120 |             )
121 |         )
122 |
123 |
124 | if __name__ == '__main__':
125 |     main(get_args())
--------------------------------------------------------------------------------
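Each saved checkpoint bundles the binary codes and labels of the query and database splits, so retrieval experiments can be re-run without the network. A hedged inspection sketch; the file name below is a made-up instance of the pattern used in `main.py` (the real name embeds the mAP your run reached):

```python
import torch

# Hypothetical checkpoint name -- adjust to the file your training produced.
ckpt = torch.load('checkpoints/cifar-100-IF100_code_32_lamb_0.2_map_0.5000_batchsize_8_maxIter_100.pt',
                  map_location='cpu')
print(ckpt['map'])                  # best mAP observed during training
qB, rB = ckpt['qB'], ckpt['rB']     # query / database codes, entries in {-1, +1}
qL, rL = ckpt['qL'], ckpt['rL']     # the matching one-hot label matrices
print(qB.shape, rB.shape)
```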
/data/imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import numpy as np
5 | import torch
6 | import torchvision.transforms as transforms
7 | from PIL import Image, ImageFilter
8 | from torch.utils.data import DataLoader
9 | from torch.utils.data.dataset import Dataset
10 |
11 | from data.transform import encode_onehot, Onehot
12 |
13 |
14 | class GaussianBlur(object):
15 |     """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
16 |
17 |     def __init__(self, sigma=[.1, 2.]):
18 |         self.sigma = sigma
19 |
20 |     def __call__(self, x):
21 |         sigma = random.uniform(self.sigma[0], self.sigma[1])
22 |         x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
23 |         return x
24 |
25 |
26 | def load_data(dataset, root, batch_size, workers):
27 |     """
28 |     Load imagenet dataset.
29 |
30 |     Args
31 |         dataset (str): Dataset name; its 'IF' suffix selects the imbalance factor.
32 |         root (str): Path of imagenet dataset.
33 |         batch_size (int): Number of samples in one batch.
34 |         workers (int): Number of data loading threads.
35 |
36 |     Returns
37 |         train_dataloader, query_dataloader, retrieval_dataloader(torch.utils.data.dataloader.DataLoader): Data loader.
38 |     """
39 |     # Data transform
40 |     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
41 |                                      std=[0.229, 0.224, 0.225])
42 |     train_transform = transforms.Compose([
43 |         transforms.RandomResizedCrop(224),
44 |         transforms.RandomHorizontalFlip(),
45 |         transforms.ToTensor(),
46 |         normalize,
47 |     ])
48 |
49 |     # Alternative SimCLR-style augmentation (kept for reference, not active):
50 |     # train_transform = transforms.Compose([
51 |     #     transforms.RandomResizedCrop(224),
52 |     #     transforms.RandomHorizontalFlip(),
53 |     #     transforms.RandomApply([
54 |     #         transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
55 |     #     ], p=0.8),
56 |     #     transforms.RandomGrayscale(p=0.2),
57 |     #     transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
58 |     #     transforms.ToTensor(),
59 |     #     normalize,
60 |     # ])
61 |
62 |     query_transform = transforms.Compose([
63 |         transforms.Resize(256),
64 |         transforms.CenterCrop(224),
65 |         transforms.ToTensor(),
66 |         normalize,
67 |     ])
68 |
69 |     # Pick the training subdirectory from the imbalance-factor suffix
70 |     index = dataset.index("IF")
71 |     sub = dataset[index:]
72 |     if sub == 'IF100':
73 |         train_dir = os.path.join(root, 'train-alpha=0.99-IF=100.0')
74 |     elif sub == 'IF50':
75 |         train_dir = os.path.join(root, 'train-alpha=0.845-IF=50.0')
76 |     elif sub == 'IF20':
77 |         train_dir = os.path.join(root, 'train-IF20')
78 |     elif sub == 'IF10':
79 |         train_dir = os.path.join(root, 'train-IF10')
80 |     elif sub == 'IF1':
81 |         train_dir = os.path.join(root, 'train')
82 |     else:
83 |         # A bare return here would crash the caller's tuple unpacking.
84 |         raise ValueError('Unknown imbalance factor in dataset name: {}'.format(dataset))
85 |     query_dir = os.path.join(root, 'query')
86 |     database_dir = os.path.join(root, 'database')
87 |
88 |     train_dataset = ImagenetDataset(
89 |         train_dir,
90 |         transform=train_transform,
91 |         targets_transform=Onehot(100),
92 |     )
93 |
94 |     print('train set size: {}'.format(len(train_dataset)))
95 |
96 |     train_dataloader = DataLoader(
97 |         train_dataset,
98 |         batch_size=batch_size,
99 |         shuffle=True,
100 |         num_workers=workers,
101 |         pin_memory=True,
102 |     )
103 |
104 |     query_dataset = ImagenetDataset(
105 |         query_dir,
106 |         transform=query_transform,
107 |         targets_transform=Onehot(100),
108 |     )
109 |
110 |     query_dataloader = DataLoader(
111 |         query_dataset,
112 |         batch_size=batch_size,
113 |         num_workers=workers,
114 |         pin_memory=True,
115 |     )
116 |
117 |     database_dataset = ImagenetDataset(
118 |         database_dir,
119 |         transform=query_transform,
120 |         targets_transform=Onehot(100),
121 |     )
122 |
123 |     database_dataloader = DataLoader(
124 |         database_dataset,
125 |         batch_size=batch_size,
126 |         num_workers=workers,
127 |         pin_memory=True,
128 |     )
129 |
130 |     return train_dataloader, query_dataloader, database_dataloader
131 |
132 |
133 | class ImagenetDataset(Dataset):
134 |     classes = None
135 |     class_to_idx = None
136 |
137 |     def __init__(self, root, transform=None, targets_transform=None):
138 |         self.root = root
139 |         self.transform = transform
140 |         self.targets_transform = targets_transform
141 |         self.imgs = []
142 |         self.targets = []
143 |
144 |         # Assume file alphabet order is the class order
145 |         if ImagenetDataset.class_to_idx is None:
146 |             ImagenetDataset.classes, ImagenetDataset.class_to_idx = self._find_classes(root)
147 |
148 |         for i, cl in enumerate(ImagenetDataset.classes):
149 |             cur_class = os.path.join(self.root, cl)
150 |             files = os.listdir(cur_class)
151 |             files = [os.path.join(cur_class, f) for f in files]
152 |             self.imgs.extend(files)
153 |             self.targets.extend([ImagenetDataset.class_to_idx[cl] for _ in range(len(files))])
154 |         self.targets = np.asarray(self.targets)
155 |         self.onehot_targets = torch.from_numpy(encode_onehot(self.targets, 100)).float()
156 |         self.data = self.imgs
157 |
158 |     def get_onehot_targets(self):
159 |         return self.onehot_targets
160 |
161 |     def __len__(self):
162 |         return len(self.imgs)
163 |
164 |     def __getitem__(self, item):
165 |         img, target = self.imgs[item], self.targets[item]
166 |
167 |         img = Image.open(img).convert('RGB')
168 |
169 |         if self.transform is not None:
170 |             img = self.transform(img)
171 |         if self.targets_transform is not None:
172 |             target = self.targets_transform(target)
173 |         return img, target, item
174 |
175 |     def _find_classes(self, dir):
176 |         """
177 |         Finds the class folders in a dataset.
178 |
179 |         Args:
180 |             dir (string): Root directory path.
181 |
182 |         Returns:
183 |             tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
184 |
185 |         Ensures:
186 |             No class is a subdirectory of another.
187 |         """
188 |         classes = [d.name for d in os.scandir(dir) if d.is_dir()]
189 |         classes.sort()
190 |         class_to_idx = {classes[i]: i for i in range(len(classes))}
191 |         return classes, class_to_idx
--------------------------------------------------------------------------------
/tools/model_train.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | from loguru import logger
4 | from torch.optim.lr_scheduler import CosineAnnealingLR
5 | from tools.utils import *
6 | from tools.evaluate import mean_average_precision
7 |
8 |
9 | def train(args, length, model, optimizer, train_loader, query_loader,
10 |           retrieval_loader, criterion_ce, criterion_con):
11 |     """
12 |     Training network.
13 |
14 |     Args:
15 |         args: args.
16 |         length: code length.
17 |         model: resnet34.
18 |         optimizer: optimizer.
19 |         train_loader: train data loader.
20 |         query_loader: query data loader.
21 |         retrieval_loader: retrieval data loader.
22 |         criterion_ce: CE loss.
23 |         criterion_con: center-based contrastive loss, computed against
24 |             the per-class hash centers (see tools/loss.py).
25 |
26 |     Returns:
27 |         checkpoint: Checkpoint of network.
28 |     """
29 |     scheduler = CosineAnnealingLR(
30 |         optimizer,
31 |         args.max_iter,
32 |         args.lr / 100,
33 |     )
34 |
35 |     # Initialization
36 |     running_loss = 0.
37 |     best_map = 0.
38 |
39 |     # Training
40 |
41 |     for it in range(args.max_iter):
42 |
43 |         generate_hash_center_start = time.time()
44 |         hash_center = generate_hash_center(model, train_loader, args.device, length)
45 |         generate_hash_center_end = time.time()
46 |         print('iter[{}/{}] generate_hash_center time {:.3f}'.format(it, args.max_iter,
47 |                                                                     generate_hash_center_end - generate_hash_center_start))
48 |
49 |         model.train()
50 |         tic = time.time()
51 |         iter_num = 0
52 |         epoch_loss = 0
53 |         loss_con_num = 0
54 |
55 |         for data, targets, index in train_loader:
56 |             iter_num += 1
57 |
58 |             data, targets, index = data.to(args.device), targets.to(args.device), index.to(args.device)
59 |
60 |             optimizer.zero_grad()
61 |
62 |             hash_center = hash_center.to(args.device)  # use the configured device, not a hard-coded 'cuda'
63 |
64 |             hashcodes, assignments, direct_feature = model(data)
65 |
66 |             loss_ce = criterion_ce(assignments, targets)
67 |
68 |             loss_con = criterion_con(hashcodes, hash_center, targets)
69 |             loss = args.lamb * loss_ce + loss_con
70 |
71 |             running_loss = running_loss + loss.item()
72 |             epoch_loss = epoch_loss + loss.item()
73 |             loss_con_num = loss_con_num + loss_con.item()  # .item() avoids keeping the autograd graph alive
74 |
75 |             loss.backward()
76 |             optimizer.step()
77 |
78 |         # update step
79 |         scheduler.step()
80 |         training_time = time.time() - tic
81 |
82 |         print('iter[{}/{}] train time {:.3f} loss {:.3f} con_loss {:.3f}'.format(it, args.max_iter, training_time, epoch_loss,
83 |                                                                                  loss_con_num))
84 |
85 |         # Evaluate
86 |         if it % args.evaluate_interval == args.evaluate_interval - 1:
87 |
88 |             start = time.time()
89 |             query_code, query_assignment, label_q = generate_code(model,
90 |                                                                   query_loader,
91 |                                                                   length,
92 |                                                                   args.num_classes,
93 |                                                                   args.device,
94 |                                                                   0
95 |                                                                   )
96 |             retrieval_code, retrieval_assignment, _ = generate_code(model,
97 |                                                                     retrieval_loader,
98 |                                                                     length,
99 |                                                                     args.num_classes,
100 |                                                                     args.device,
101 |                                                                     1
102 |                                                                     )
103 |
104 |             correct = get_correct_num(query_assignment, label_q)
105 |
106 |             class_correct_num, class_sample_num = np.array(sample_num_per_class(query_assignment, label_q,
107 |                                                                                 100, 0), dtype=int)
108 |
109 |             train_total_correct_class = class_correct_num
110 |             train_total_sample_class = class_sample_num
111 |
112 |             train_accuracy_pre_class = np.array(correct_per_class(train_total_correct_class,
113 |                                                                   train_total_sample_class, 100))
114 |             train_three = [train_accuracy_pre_class[:33].mean(), train_accuracy_pre_class[33:66].mean(),
115 |                            train_accuracy_pre_class[66:].mean()]
116 |             print('epoch = {}, total test accuracy = {:.3f}, class accuracy = {}'.
117 |                   format(it, correct / 10000, train_three))  # 10,000 = query-set size
118 |
119 |             query_targets = query_loader.dataset.get_onehot_targets()
120 |             retrieval_targets = retrieval_loader.dataset.get_onehot_targets()
121 |
122 |             # Compute map
123 |             mAP, mAP_class = mean_average_precision(
124 |                 query_code.to(args.device),
125 |                 retrieval_code.to(args.device),
126 |                 query_targets.to(args.device),
127 |                 retrieval_targets.to(args.device),
128 |                 args.device,
129 |                 args.topk,
130 |             )
131 |
132 |             # Log
133 |             logger.info('[iter:{}/{}][loss:{:.2f}][map:{:.4f}][time:{:.2f}][map class:{}]'.format(
134 |                 it + 1,
135 |                 args.max_iter,
136 |                 running_loss / args.evaluate_interval,
137 |                 mAP,
138 |                 training_time,
139 |                 mAP_class
140 |             ))
141 |             running_loss = 0.
142 |
143 |             # Checkpoint
144 |             if best_map < mAP:
145 |                 best_map = mAP
146 |
147 |                 checkpoint = {
148 |                     'network': model.state_dict(),
149 |                     'qB': query_code.cpu(),
150 |                     'rB': retrieval_code.cpu(),
151 |                     'qL': query_targets.cpu(),
152 |                     'rL': retrieval_targets.cpu(),
153 |                     'qAssignment': query_assignment.cpu(),
154 |                     'rAssignment': retrieval_assignment.cpu(),
155 |                     'map': best_map,
156 |                     'hash_center': hash_center.cpu(),
157 |                     'lamb': args.lamb,
158 |                 }
159 |             end = time.time()
160 |             print('evaluate time = {:.3f}'.format(end - start))
161 |
162 |     return checkpoint
163 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------