├── utils ├── __init__.py └── hash_functions.py ├── datasets ├── __init__.py ├── dataset_registry.py ├── forestnet_download.sh ├── dataset_factory.py ├── bigearthnet_download.sh ├── utils.py ├── bigearthnet.py └── forestnet.py ├── figures ├── approach.png ├── examples.png └── experimental_results.png ├── experiments.sh ├── requirements.txt ├── model ├── __init__.py ├── model_registry.py ├── model_factory.py ├── rgb_vit.py └── vit.py ├── configs ├── rgb_vit.yaml └── prithvi_vit.yaml ├── inference.sh ├── tsne.py ├── README.md ├── speed_test_milvus.py ├── inference.py ├── visualization.py ├── experiments.py └── LICENSE /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from . import bigearthnet, forestnet 2 | from .dataset_factory import load_dataset 3 | -------------------------------------------------------------------------------- /figures/approach.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/remote-sensing-image-retrieval/HEAD/figures/approach.png -------------------------------------------------------------------------------- /figures/examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/remote-sensing-image-retrieval/HEAD/figures/examples.png -------------------------------------------------------------------------------- /figures/experimental_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/remote-sensing-image-retrieval/HEAD/figures/experimental_results.png -------------------------------------------------------------------------------- /experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python experiments.py --match any --distance_function hamming --hash_method trivial,lsh,none --hash_length 32,768 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | torchgeo 4 | timm 5 | numpy 6 | pandas 7 | pyyaml 8 | lshashpy3 9 | h5py 10 | milvus 11 | pymilvus -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | # load model classes 2 | from .vit import PrithviViT 3 | from .rgb_vit import RGB_ViT 4 | 5 | from .model_factory import build_model 6 | -------------------------------------------------------------------------------- /datasets/dataset_registry.py: -------------------------------------------------------------------------------- 1 | 2 | DATASET_REGISTRY = {} 3 | 4 | def register_dataset(dataset_name, dataset_fn): 5 | DATASET_REGISTRY[dataset_name] = dataset_fn 6 | -------------------------------------------------------------------------------- /model/model_registry.py: -------------------------------------------------------------------------------- 1 | 2 | MODEL_REGISTRY = {} 3 | 4 | 5 | def register_model(model_class): 6 | MODEL_REGISTRY[model_class.__name__] = model_class 7 | return model_class 8 | 
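# --- Illustrative usage sketch (added; not part of the original file) ---
# Model classes register themselves with the decorator so that
# model_factory.build_model can later construct them by name from the config.
# The class name below is hypothetical; the real entries are PrithviViT and RGB_ViT.
#
#   from torch import nn
#   from .model_registry import register_model, MODEL_REGISTRY
#
#   @register_model
#   class DummyEncoder(nn.Module):
#       def forward(self, x):
#           return x.flatten(1)
#
#   model = MODEL_REGISTRY['DummyEncoder']()  # this lookup is what build_model does internally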
-------------------------------------------------------------------------------- /configs/rgb_vit.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: RGB_ViT 3 | img_size: 224 4 | data_mean: 5 | - 5000 6 | - 5000 7 | - 5000 8 | data_std: 9 | - 5000 10 | - 5000 11 | - 5000 12 | dataset: 13 | name: BigEarthNetBGR 14 | split: test 15 | dataloader: 16 | batch_size: 16 17 | num_workers: 4 18 | pin_memory: False 19 | shuffle: False 20 | -------------------------------------------------------------------------------- /datasets/forestnet_download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # get directory with data 4 | data_dir="${DATA_DIR:-data}" 5 | mkdir $data_dir 6 | cd $data_dir || exit 7 | 8 | # Download data 9 | wget http://download.cs.stanford.edu/deep/ForestNetDataset.zip 10 | unzip ForestNetDataset.zip 11 | mv deep/downloads/ForestNetDataset . 12 | 13 | # Remove zip and empty tmp dir 14 | rm ForestNetDataset.zip 15 | rm -r deep 16 | 17 | 18 | # (~3 Gb, ~40 min) -------------------------------------------------------------------------------- /model/model_factory.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | import os 4 | import torch 5 | import subprocess 6 | from .model_registry import MODEL_REGISTRY 7 | 8 | 9 | def build_model(cfg): 10 | # Get model class from registry 11 | model_name = cfg['model']['name'] 12 | logging.info(f'Load model {model_name}') 13 | assert model_name in MODEL_REGISTRY, (f'model {model_name} not registered.' 14 | f'Select a model from {MODEL_REGISTRY.keys()}') 15 | model_constructor = MODEL_REGISTRY[model_name] 16 | 17 | # Init model 18 | model = model_constructor(**cfg['model']) 19 | logging.debug('Model initialized') 20 | 21 | return model 22 | -------------------------------------------------------------------------------- /datasets/dataset_factory.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | from .dataset_registry import DATASET_REGISTRY 4 | from torchvision import transforms 5 | 6 | 7 | def load_dataset(cfg): 8 | # load settings from cfg 9 | dataset_name = cfg['dataset']['name'] 10 | logging.info(f"Load dataset {dataset_name} ({cfg['dataset']['split']} split)") 11 | assert dataset_name in DATASET_REGISTRY, (f"Dataset {dataset_name} not registered. 
" 12 | f"Select a dataset from {DATASET_REGISTRY.keys()} ") 13 | # get dataset fc from registry 14 | dataset_fn = DATASET_REGISTRY[dataset_name] 15 | # load dataset 16 | dataset = dataset_fn(cfg) 17 | return dataset 18 | -------------------------------------------------------------------------------- /configs/prithvi_vit.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: PrithviViT 3 | weights: weights/Prithvi_100M.pt 4 | depth: 12 5 | embed_dim: 768 6 | img_size: 224 7 | in_chans: 6 8 | num_frames: 1 9 | num_heads: 12 10 | patch_size: 16 11 | tubelet_size: 1 12 | data_mean: 13 | - 775.2290211032589 14 | - 1080.992780391705 15 | - 1228.5855250417867 16 | - 2497.2022620507532 17 | - 2204.2139147975554 18 | - 1610.8324823273745 19 | data_std: 20 | - 1281.526139861424 21 | - 1270.0297974547493 22 | - 1399.4802505642526 23 | - 1368.3446143747644 24 | - 1291.6764008585435 25 | - 1154.505683480695 26 | dataset: 27 | name: BigEarthNet 28 | split: test 29 | dataloader: 30 | batch_size: 16 31 | num_workers: 4 32 | pin_memory: False 33 | shuffle: False 34 | -------------------------------------------------------------------------------- /datasets/bigearthnet_download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # get directory with data 4 | data_dir="${DATA_DIR:-data}" 5 | mkdir $data_dir 6 | cd $data_dir || exit 7 | 8 | # create BigEarthNet dir 9 | mkdir BigEarthNet 10 | cd BigEarthNet || exit 11 | 12 | # download splits 13 | wget https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/master/splits/train.csv -O bigearthnet-train.csv 14 | wget https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/master/splits/val.csv -O bigearthnet-val.csv 15 | wget https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/master/splits/test.csv -O bigearthnet-test.csv 16 | 17 | # download BigEarthNet 18 | curl -O https://bigearth.net/downloads/BigEarthNet-S2-v1.0.tar.gz 19 | tar -xvzf BigEarthNet-S2-v1.0.tar.gz 20 | # delete tar file 21 | rm BigEarthNet-S2-v1.0.tar.gz 22 | -------------------------------------------------------------------------------- /model/rgb_vit.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torchvision 4 | from torch import nn 5 | from .model_registry import register_model 6 | 7 | 8 | @register_model 9 | class RGB_ViT(nn.Module): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | self.model = torchvision.models.vision_transformer.vit_b_16( 13 | weights=torchvision.models.vision_transformer.ViT_B_16_Weights) 14 | 15 | def forward(self, input): 16 | # squeeze time dimension 17 | input = input.squeeze(2) 18 | # Reorder the BGR channels to RGB 19 | input = input[:, [2, 1, 0]] 20 | 21 | # Run forward pass 22 | x = self.model._process_input(input) 23 | n = x.shape[0] 24 | 25 | # Expand the class token to the full batch 26 | batch_class_token = self.model.class_token.expand(n, -1, -1) 27 | x = torch.cat([batch_class_token, x], dim=1) 28 | 29 | x = self.model.encoder(x) 30 | 31 | # Classifier "token" as used by standard language architectures 32 | x = x[:, 0] 33 | 34 | return x 35 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class DictTransforms: 5 | def __init__(self, 6 | dict_transform 
: dict, 7 | ): 8 | self.dict_transform = dict_transform 9 | 10 | def __call__(self, sample): 11 | # Apply your transforms to the 'image' key 12 | for key, function in self.dict_transform.items(): 13 | sample[key] = function(sample[key]) 14 | return sample 15 | 16 | 17 | class SelectChannels: 18 | def __init__(self, channels): 19 | self.channels = channels 20 | 21 | def __call__(self, tensor): 22 | return tensor[self.channels] 23 | 24 | 25 | class Unsqueeze: 26 | def __init__(self, dim): 27 | self.dim = dim 28 | 29 | def __call__(self, tensor): 30 | return tensor.unsqueeze(dim=self.dim) 31 | 32 | 33 | class ConvertType: 34 | def __init__(self, dtype): 35 | self.dtype = dtype 36 | 37 | def __call__(self, tensor): 38 | return tensor.to(self.dtype) 39 | 40 | 41 | class AddMeanChannels: 42 | """ 43 | Add missing channels to the tensor based on the mean values. Results in zeros after standardization. 44 | """ 45 | def __init__(self, mean): 46 | self.mean = mean 47 | self.mean_tensor = None 48 | 49 | def __call__(self, tensor): 50 | if self.mean_tensor is None: 51 | # Init tensor with mean values 52 | self.mean_tensor = (torch.ones([len(self.mean) - len(tensor), *tensor.shape[1:]]) * 53 | torch.tensor(self.mean)[len(tensor):, None, None]) 54 | # Add mean values for missing channels 55 | tensor = torch.concat([tensor, self.mean_tensor]) 56 | return tensor 57 | -------------------------------------------------------------------------------- /datasets/bigearthnet.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | from torchvision import transforms 5 | from torchgeo.datasets import BigEarthNet 6 | from functools import partial 7 | from .dataset_registry import register_dataset 8 | from .utils import SelectChannels, Unsqueeze, DictTransforms, ConvertType, AddMeanChannels 9 | 10 | 11 | def init_bigearthnet(bands, normalize, num_classes, cfg, *args, **kwargs): 12 | """ 13 | Init BigEarthNet dataset, with S2 data and 43 classes as default. 14 | """ 15 | # Get dataset parameters 16 | split = cfg['dataset']['split'] 17 | satellite = cfg['dataset']['satellite'] if 'satellite' in cfg['dataset'] else 's2' 18 | 19 | # Get BigEarthNet directory 20 | DATA_DIR = os.getenv('DATA_DIR', 'data') 21 | bigearthnet_dir = os.path.join(DATA_DIR, 'BigEarthNet') 22 | # Check if data is downloaded 23 | assert os.path.isdir(os.path.join(bigearthnet_dir, BigEarthNet.metadata['s2']['directory'])), \ 24 | "Download BigEarthNet with `sh datasets/bigearthnet_download.sh` or specify the DATA_DIR via a env variable." 25 | 26 | # Init transforms 27 | image_transforms = [ 28 | SelectChannels(bands), 29 | ConvertType(torch.float), 30 | transforms.Resize(size=cfg['model']['img_size'], antialias=True), 31 | ] 32 | 33 | if normalize: 34 | if len(bands) != len(cfg['model']['data_mean']): 35 | # Add mean channels values for missing channels (e.g. 
for BGR data) 36 | image_transforms.append(AddMeanChannels(cfg['model']['data_mean'])) 37 | # Normalize images 38 | image_transforms.append(transforms.Normalize(mean=cfg['model']['data_mean'], std=cfg['model']['data_std'])) 39 | image_transforms.append(Unsqueeze(dim=1)) # add time dim 40 | 41 | ben_transforms = DictTransforms({'image': transforms.Compose(image_transforms)}) 42 | 43 | # Init dataset 44 | dataset = BigEarthNet( 45 | root=bigearthnet_dir, 46 | split=split, 47 | bands=satellite, 48 | num_classes=num_classes, 49 | transforms=ben_transforms, 50 | ) 51 | 52 | return dataset 53 | 54 | 55 | # Add dataset to the registry 56 | # Using the six channels from Prithvi 57 | register_dataset('BigEarthNet', partial(init_bigearthnet, [1, 2, 3, 8, 10, 11], True, 43)) 58 | 59 | register_dataset('BigEarthNetBGR', partial(init_bigearthnet, [1, 2, 3], True, 43)) 60 | 61 | register_dataset('BigEarthNetVisual', partial(init_bigearthnet, [3, 2, 1], False, 43)) 62 | 63 | register_dataset('BigEarthNet19', partial(init_bigearthnet, [1, 2, 3, 8, 10, 11], True, 19)) 64 | 65 | register_dataset('BigEarthNet19BGR', partial(init_bigearthnet, [1, 2, 3], True, 19)) 66 | 67 | register_dataset('BigEarthNet19Visual', partial(init_bigearthnet, [3, 2, 1], False, 19)) -------------------------------------------------------------------------------- /inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Run inference for ForestNet 4 | # python inference.py -c configs/prithvi_vit.yaml -d ForestNet -s train 5 | python inference.py -c configs/prithvi_vit.yaml -d ForestNet -s val 6 | python inference.py -c configs/prithvi_vit.yaml -d ForestNet -s test 7 | 8 | # python inference.py -c configs/prithvi_vit.yaml -d ForestNet4 -s train 9 | python inference.py -c configs/prithvi_vit.yaml -d ForestNet4 -s val 10 | python inference.py -c configs/prithvi_vit.yaml -d ForestNet4 -s test 11 | 12 | # python inference.py -c configs/prithvi_vit.yaml -d ForestNetBGR -s train 13 | python inference.py -c configs/prithvi_vit.yaml -d ForestNetBGR -s val 14 | python inference.py -c configs/prithvi_vit.yaml -d ForestNetBGR -s test 15 | 16 | # python inference.py -c configs/prithvi_vit.yaml -d ForestNet4BGR -s train 17 | python inference.py -c configs/prithvi_vit.yaml -d ForestNet4BGR -s val 18 | python inference.py -c configs/prithvi_vit.yaml -d ForestNet4BGR -s test 19 | 20 | # Run inference for BigEarthNet 21 | #python inference.py -c configs/prithvi_vit.yaml -d BigEarthNet -s train 22 | python inference.py -c configs/prithvi_vit.yaml -d BigEarthNet -s val 23 | python inference.py -c configs/prithvi_vit.yaml -d BigEarthNet -s test 24 | 25 | #python inference.py -c configs/prithvi_vit.yaml -d BigEarthNetBGR -s train 26 | python inference.py -c configs/prithvi_vit.yaml -d BigEarthNetBGR -s val 27 | python inference.py -c configs/prithvi_vit.yaml -d BigEarthNetBGR -s test 28 | 29 | #python inference.py -c configs/prithvi_vit.yaml -d BigEarthNet19 -s train 30 | python inference.py -c configs/prithvi_vit.yaml -d BigEarthNet19 -s val 31 | python inference.py -c configs/prithvi_vit.yaml -d BigEarthNet19 -s test 32 | 33 | #python inference.py -c configs/prithvi_vit.yaml -d BigEarthNet19BGR -s train 34 | python inference.py -c configs/prithvi_vit.yaml -d BigEarthNet19BGR -s val 35 | python inference.py -c configs/prithvi_vit.yaml -d BigEarthNet19BGR -s test 36 | 37 | # Run inference with vanilla ViT 38 | # python inference.py -c configs/rgb_vit.yaml -d ForestNetBGR -s train 39 | 
python inference.py -c configs/rgb_vit.yaml -d ForestNetBGR -s val 40 | python inference.py -c configs/rgb_vit.yaml -d ForestNetBGR -s test 41 | 42 | # python inference.py -c configs/rgb_vit.yaml -d ForestNet4BGR -s train 43 | python inference.py -c configs/rgb_vit.yaml -d ForestNet4BGR -s val 44 | python inference.py -c configs/rgb_vit.yaml -d ForestNet4BGR -s test 45 | 46 | #python inference.py -c configs/rgb_vit.yaml -d BigEarthNetBGR -s train 47 | python inference.py -c configs/rgb_vit.yaml -d BigEarthNetBGR -s val 48 | python inference.py -c configs/rgb_vit.yaml -d BigEarthNetBGR -s test 49 | 50 | #python inference.py -c configs/rgb_vit.yaml -d BigEarthNet19BGR -s train 51 | python inference.py -c configs/rgb_vit.yaml -d BigEarthNet19BGR -s val 52 | python inference.py -c configs/rgb_vit.yaml -d BigEarthNet19BGR -s test 53 | -------------------------------------------------------------------------------- /tsne.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import argparse 4 | import torch 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import matplotlib.colors as mcolors 8 | from sklearn.manifold import TSNE 9 | from utils.hash_functions import get_hash 10 | 11 | 12 | def main(args): 13 | # Init embeddings with shape [sample, embedding] 14 | output_path = os.getenv('OUTPUT_PATH', os.path.join('output', 'embeddings')) 15 | test_embeddings = torch.load(os.path.join(output_path, args.model, args.dataset, 'test', 'embeddings.pt'), 16 | map_location='cpu') 17 | # Multi-labels with shape [sample, label] 18 | test_labels = torch.load(os.path.join(output_path, args.model, args.dataset, 'test', 'labels.pt'), 19 | map_location='cpu') 20 | 21 | np.random.seed(42) 22 | binary_emb = get_hash(test_embeddings, 'trivial', length=768) 23 | trivial_hash = get_hash(test_embeddings, 'trivial', length=32) 24 | ls_hash = get_hash(test_embeddings, 'lsh', length=32) 25 | 26 | # Select a subset to avoid computation costs 27 | num_samples = 1000 28 | indices = np.random.choice(len(test_embeddings), min(num_samples, len(test_embeddings)), replace=False) 29 | sampled_binary_emb = binary_emb[indices] 30 | sampled_embeddings = test_embeddings[indices] 31 | sampled_trivial_hash = trivial_hash[indices] 32 | sampled_ls_hash = ls_hash[indices] 33 | sampled_labels = test_labels[indices].nonzero()[:, 1] 34 | assert len(sampled_labels) == len(indices), "t-SNE plot is not working with multi-label datasets." 
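    # Note (added comment): each of the four loop iterations below projects one
    # representation (raw embedding, binarized embedding, LSH hash, trivial hash)
    # to 2D with t-SNE and saves a scatter plot coloured by the sample labels.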
35 | 36 | for i, (name, vector) in enumerate([('embedding', sampled_embeddings), 37 | ('binary', sampled_binary_emb), 38 | ('lsh', sampled_ls_hash), 39 | ('trivial', sampled_trivial_hash)]): 40 | # Compute t-SNE embeddings 41 | tsne = TSNE(n_components=2, random_state=42) 42 | tsne_results = tsne.fit_transform(vector) 43 | 44 | fig, ax = plt.subplots(figsize=(3, 3)) 45 | ax.scatter(tsne_results[:, 0], tsne_results[:, 1], c=sampled_labels, cmap='Dark2', alpha=0.9) 46 | ax.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False) 47 | 48 | fig.subplots_adjust(left=0.01, right=0.99, top=0.99, bottom=0.01) 49 | if os.path.isdir(args.output_dir): 50 | output_file = os.path.join(args.output_dir, f"{args.model}_{args.dataset}_tsne_{name}.pdf") 51 | else: 52 | output_file = args.output_dir 53 | 54 | plt.savefig(output_file) 55 | 56 | 57 | if __name__ == '__main__': 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument('-o', '--output_dir', type=str, default='output/figures', 60 | help='Path to output dir') 61 | parser.add_argument('-d', '--dataset', type=str, default='ForestNet4') 62 | parser.add_argument('-m', '--model', type=str, default='PrithviViT') 63 | 64 | args = parser.parse_args() 65 | 66 | main(args) 67 | -------------------------------------------------------------------------------- /utils/hash_functions.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | from lshashpy3 import LSHash 5 | 6 | 7 | def get_hash(embedding, method='trivial', *args, **kwargs): 8 | """ 9 | Convert embedding or list of embeddings to hash codes using the defined method. 10 | """ 11 | # Check for type 12 | if not isinstance(embedding, torch.Tensor): 13 | if isinstance(embedding, list): 14 | # Iterate over list of embeddings 15 | return [get_hash(e, method, *args, **kwargs) for e in embedding] 16 | else: 17 | raise TypeError 18 | 19 | # Create hash codes based on method 20 | if method == 'trivial': 21 | return trivial_hash(embedding, *args, **kwargs) 22 | elif method == 'lsh': 23 | return lshash(embedding, *args, **kwargs) 24 | elif method == 'none': 25 | return embedding 26 | else: 27 | raise NotImplementedError 28 | 29 | def trivial_hash(embedding: torch.Tensor, length: str = 64, threshold: float = 0., seed=None): 30 | """ 31 | Creates a trivial binary hash by averaging multiple embedding dimensions and binarization with a threshold. 32 | 33 | :param embedding: torch tensor with shape [samples, embedding] 34 | :param length: hash length (must divide the embedding length without a rest) 35 | :param threshold: value to binarize float embedding 36 | :return: binary hash codes as int (0 and 1) with shape [samples, hash] 37 | """ 38 | assert embedding.size(-1) % length == 0, \ 39 | f"Cannot create hash with length {length} with embedding dim {embedding.size(-1)}" 40 | resize_factor = int(embedding.size(-1) / length) 41 | binary_hash = embedding.reshape([-1, resize_factor, length]).mean(dim=1) > threshold 42 | return binary_hash.int() 43 | 44 | 45 | def lshash(embedding: torch.Tensor, length: str = 64, store: str = None, seed: int = 42): 46 | """ 47 | Creates a binary hash by applying LSH from the paper: 48 | Tang, Y. K., Mao, X. L., Hao, Y. J., Xu, C., & Huang, H. (2017). 49 | Locality-sensitive hashing for finding nearest neighbors in probability distributions. 50 | In Social Media Processing: 6th National Conference, SMP 2017, Beijing, China, September 14-17, 2017, 51 | Proceedings (pp. 3-15). 
Springer Singapore. 52 | 53 | Using the implementation from https://github.com/loretoparisi/lshash. 54 | 55 | :param embedding: torch tensor with shape [samples, embedding] 56 | :param length: hash length 57 | :param store: hash table file name (optional) 58 | :param seed: seed for reproducibility. Leads to similar hash codes with multiple calls. 59 | :return: binary hash codes as int (0 and 1) with shape [samples, hash] 60 | """ 61 | # Init Locality-Sensitive Hashing 62 | np.random.seed(seed) 63 | lsh = LSHash(hash_size=length, input_dim=embedding.size(-1), num_hashtables=1, 64 | hashtable_filename=store, overwrite=True) 65 | 66 | # Generate hashes for each embedding 67 | hashes = [] 68 | for e in embedding: 69 | h = lsh.index(e.tolist(), extra_data=None) 70 | hashes.append(list(map(int, h[0]))) 71 | 72 | return torch.tensor(hashes).int() 73 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Remote sensing image retrieval 2 | 3 | This is the official implementation of the paper **Multi-Spectral Remote Sensing Image Retrieval using Geospatial Foundation Models**. 4 | The experiments are explained in our [paper](https://arxiv.org/abs/2403.02059), and you find more information about Prithvi on [Hugging Face](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M). 5 | 6 | ## Approach 7 | 8 | ![approach.png](figures%2Fapproach.png) 9 | 10 | GeoFM embeddings enable simple but accurate content-based image retrieval of remote sensing images. Optionally, the embeddings are compressed into smaller binary vectors to speed up the process and reduce memory usage. 11 | For each query image, similar images from the database are returned and sorted based on their distance to the query embedding. 12 | 13 | ## Experimental results 14 | 15 | ![experimental_results.png](figures%2Fexperimental_results.png) 16 | 17 | The table presents the mAP@20 results for all evaluated models and datasets. We highlight the best-performing method in bold and underline the second-best one. LSH is reported with a 95% confidence interval based on five seeds as this method uses random hyperplanes. 18 | 19 | 20 | ![examples.png](figures%2Fexamples.png) 21 | 22 | The figure displays examples from two datasets with query images (left), their labels, and retrieved images (right) using Prithvi and the trivial hash method. Images with green frames indicate positive matches, while those with red frames have different labels. Orange shows partial correct matches, where the number represents the number of label matches within the multi-labels. 23 | 24 | ## Setup 25 | 26 | Create an environment and install the required packages with: 27 | ```sh 28 | # create env 29 | conda create -n "rsir" python=3.10 30 | # activate env 31 | conda activate rsir 32 | # install pytorch (e.g., with CUDA 12.1, see https://pytorch.org for other versions) 33 | pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu121 34 | # install requirements 35 | pip install -r requirements.txt 36 | ``` 37 | 38 | ### Datasets 39 | 40 | We provide bash script for downloading the datasets. 
Run the following script from the project root: 41 | 42 | ```shell 43 | # optionally specific the data directory (default: 'data/') 44 | export DATA_DIR=data 45 | 46 | # Download BigEarthNet (~66 Gb, ~1h) 47 | sh datasets/bigearthnet_download.sh 48 | 49 | # Download ForestNet (3 Gb, ~5 min) 50 | sh datasets/forestnet_download.sh 51 | ``` 52 | 53 | ### Models 54 | 55 | You can download the model weights for Prithvi-100M from [Hugging Face](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M/blob/main/Prithvi_100M.pt) with the following commands. 56 | 57 | ```shell 58 | mkdir weights 59 | cd weights && wget https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M/resolve/main/Prithvi_100M.pt 60 | ``` 61 | 62 | The weights are saved at `weights/Prithvi_100M.pt` but you can also update the path in the config file `configs/prithvi_vit_us.yaml`. 63 | 64 | The weights for the vanilla ViT with RGB channels are downloaded automatically. 65 | 66 | 67 | ## Run experiments 68 | 69 | You can save the embeddings of a dataset with: 70 | ```shell 71 | # Save embeddings 72 | python inference.py -c configs/prithvi_vit.yaml --dataset ForestNet --split val 73 | python inference.py -c configs/prithvi_vit.yaml --dataset ForestNet --split test 74 | ``` 75 | 76 | If you want to save the embeddings of all evaluated models and dataset versions, you can run: 77 | ```shell 78 | bash inference.sh 79 | ``` 80 | 81 | Evaluate all saved embeddings with given 82 | ```shell 83 | # Run experiments 84 | python experiments.py --match any --distance_function hamming --hash_method trivial --hash_length 32 85 | # You can also combine multiple methods 86 | python experiments.py --match any --distance_function hamming --hash_method trivial,lsh,none --hash_length 32,768 87 | ``` 88 | 89 | 90 | ### Speed experiments 91 | 92 | You need a running [Milvus](https://milvus.io) instance for these experiments. 93 | 94 | With saved BigEarthNet embeddings, run the experiments with: 95 | ```shell 96 | python speed_test_milvus.py 97 | ``` 98 | 99 | If you want to run the experiments on another machine, connect to Milvus via ssh. 
100 | 101 | ```shell 102 | ssh -L19530:localhost:19530 103 | ``` 104 | 105 | ## Citation 106 | 107 | ```text 108 | @article{RSIR2024, 109 | title={{Multi-Spectral Remote Sensing Image Retrieval using Geospatial Foundation Models}}, 110 | author={Blumenstiel, Benedikt and Moor, Viktoria and Kienzler, Romeo and Brunschwiler, Thomas}, 111 | journal={arXiv preprint arXiv:2403.02059}, 112 | year={2024} 113 | } 114 | ``` 115 | -------------------------------------------------------------------------------- /speed_test_milvus.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | import time 5 | import tqdm 6 | import numpy as np 7 | from pymilvus import ( 8 | connections, 9 | utility, 10 | FieldSchema, 11 | CollectionSchema, 12 | DataType, 13 | Collection, 14 | ) 15 | 16 | # Init database 17 | connections.connect("default", host="localhost", port="19530") 18 | 19 | 20 | def run_experiment(queries, database, data_type='float', length=768): 21 | if data_type == 'float': 22 | dtype = DataType.FLOAT_VECTOR 23 | index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"} 24 | search_params = {"metric_type": "L2"} # "params": {"nprobe": 10}, 25 | elif data_type == 'bool': 26 | dtype = DataType.BINARY_VECTOR 27 | index = {"index_type": "BIN_IVF_FLAT", "params": {"nlist": 128}, "metric_type": "HAMMING"} 28 | search_params = {"metric_type": "HAMMING"} 29 | else: 30 | raise NotImplementedError 31 | 32 | fields = [ 33 | FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=False), 34 | FieldSchema(name="vector", dtype=dtype, dim=length) 35 | ] 36 | schema = CollectionSchema(fields, "test speed") 37 | try: 38 | collection = Collection("igarss", schema) 39 | except: 40 | utility.drop_collection("igarss") 41 | collection = Collection("igarss", schema) 42 | pass 43 | 44 | # Register data 45 | # Iterate over database to avoid hitting the max entries threshold 46 | step_size = 10000 47 | for i in tqdm.tqdm(range(0, len(database), step_size), desc='Upload images'): 48 | m = min(len(database), i + 10000) 49 | entities = [ 50 | [i for i in range(i, m)], 51 | database[i:m] 52 | ] 53 | insert_result = collection.insert(entities) 54 | collection.flush() 55 | 56 | # Create index 57 | collection.create_index("vector", index) 58 | 59 | collection.load() 60 | 61 | # Run retrieval test 62 | 63 | time_start = time.time() 64 | for n, query in enumerate(queries): 65 | result = collection.search([query], "vector", search_params, limit=20, output_fields=["pk"]) 66 | if (n+1) % 1000 == 0: 67 | print(f'Average retrieval time after {n+1} samples: {(time.time() - time_start) / (n+1):.4f} s/query') 68 | 69 | # Drop database 70 | utility.drop_collection("igarss") 71 | 72 | 73 | def main(): 74 | output_dir = 'output/embeddings/PrithviViT/BigEarthNet' 75 | 76 | print('\nLoad Binary hash codes with length 32') 77 | # Load hash codes 78 | queries = torch.load(os.path.join(output_dir, 'val/hash_codes.pt')).numpy()[:, :32] 79 | database = torch.load(os.path.join(output_dir, 'test/hash_codes.pt')).numpy()[:, :32] 80 | # Create binary vectors 81 | queries = [bytes(q) for q in np.packbits(queries, axis=-1)] 82 | database = [bytes(d) for d in np.packbits(database, axis=-1)] 83 | 84 | print('Experiment with 10K data') 85 | run_experiment(queries[:1000], database[:10000], data_type='bool', length=32) 86 | print('Experiment with 50K data') 87 | run_experiment(queries[:1000], database[:50000], data_type='bool', length=32) 88 | 
print('Experiment with 100K data') 89 | run_experiment(queries[:1000], database[:100000], data_type='bool', length=32) 90 | 91 | # Load embeddings 92 | print('\nLoad float embeddings with length 768') 93 | queries = torch.load(os.path.join(output_dir, 'val/embeddings.pt')).numpy() 94 | database = torch.load(os.path.join(output_dir, 'test/embeddings.pt')).numpy() 95 | 96 | print('Experiment with 10K data') 97 | run_experiment(queries[:1000], database[:10000], data_type='float', length=768) 98 | print('Experiment with 50K data') 99 | run_experiment(queries[:1000], database[:50000], data_type='float', length=768) 100 | print('Experiment with 100K data') 101 | run_experiment(queries[:1000], database[:100000], data_type='float', length=768) 102 | 103 | print('\nUse binary embeddings with length 768') 104 | queries = (queries > 0).to(int) 105 | database = (database > 0).to(int) 106 | queries = [bytes(q) for q in np.packbits(queries, axis=-1)] 107 | database = [bytes(d) for d in np.packbits(database, axis=-1)] 108 | 109 | print('Experiment with 10K data') 110 | run_experiment(queries[:1000], database[:10000], data_type='bool', length=768) 111 | print('Experiment with 50K data') 112 | run_experiment(queries[:1000], database[:50000], data_type='bool', length=768) 113 | print('Experiment with 100K data') 114 | run_experiment(queries[:1000], database[:100000], data_type='bool', length=768) 115 | 116 | 117 | if __name__ == '__main__': 118 | main() 119 | -------------------------------------------------------------------------------- /datasets/forestnet.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import h5py 4 | import json 5 | import torch 6 | import numpy as np 7 | import pandas as pd 8 | from torch.utils.data import Dataset 9 | from torchvision import transforms, io 10 | import torch.nn.functional as F 11 | from functools import partial 12 | from .dataset_registry import register_dataset 13 | from .utils import Unsqueeze, SelectChannels, AddMeanChannels 14 | 15 | 16 | LABEL_NAMES = ['Oil palm plantation', 'Timber plantation', 'Other large-scale plantations', 'Grassland shrubland', 'Small-scale agriculture', 'Small-scale mixed plantation', 'Small-scale oil palm plantation', 'Mining', 'Fish pond', 'Logging', 'Secondary forest', 'Other'] 17 | LABEL_DICT = {'Oil palm plantation': 0, 'Timber plantation': 1, 'Other large-scale plantations': 2, 'Grassland shrubland': 3, 'Small-scale agriculture': 4, 'Small-scale mixed plantation': 5, 'Small-scale oil palm plantation': 6, 'Mining': 7, 'Fish pond': 8, 'Logging': 9, 'Secondary forest': 10, 'Other': 11} 18 | LABEL_NAMES_MERGED = ['Plantation', 'Grassland shrubland', 'Smallholder agriculture', 'Other'] 19 | MERGED_LABEL_DICT = {'Plantation': 0, 'Grassland shrubland': 1, 'Smallholder agriculture': 2, 'Other': 3} 20 | 21 | MERGE_LABEL_DICT = {0: 0, 1: 0, 2: 0, 3: 1, 4: 2, 5: 2, 6: 2, 7: 3, 8: 3, 9: 3, 10: 3, 11: 3} 22 | 23 | 24 | class ForestNet(Dataset): 25 | """ 26 | Dataset class for ForestNet, introduced in: 27 | Irvin, J., Sheng, H., Ramachandran, N., Johnson-Yu, S., Zhou, S., Story, K., ... & Ng, A. Y. (2020). 28 | Forestnet: Classifying drivers of deforestation in indonesia using deep learning on satellite imagery. 29 | arXiv preprint arXiv:2011.05479. 
https://arxiv.org/pdf/2011.05479.pdf 30 | """ 31 | def __init__(self, dataset_root, split, transform=None, merge_labels=False): 32 | self.dataset_root = dataset_root 33 | self.split = split 34 | self.transform = transform 35 | 36 | assert os.path.isdir(dataset_root), ('ForestNet data not found. ' 37 | 'Download dataset with `sh datasets/forestnet_download.sh`') 38 | 39 | # Load file names and labels 40 | split_df = pd.read_csv(os.path.join(dataset_root, f'{split}.csv')) 41 | self.filenames = split_df['example_path'].values 42 | 43 | if not merge_labels: 44 | self.labels = split_df['label'].values 45 | self.labels = [LABEL_DICT[l] for l in self.labels] 46 | self.label_names = LABEL_NAMES 47 | else: 48 | self.labels = split_df['merged_label'].values 49 | self.labels = [MERGED_LABEL_DICT[l] for l in self.labels] 50 | self.label_names = LABEL_NAMES_MERGED 51 | 52 | self.num_labels = len(self.label_names) 53 | self.labels = F.one_hot(torch.tensor(self.labels), self.num_labels) 54 | 55 | def __len__(self): 56 | return len(self.filenames) 57 | 58 | def __getitem__(self, index): 59 | # Load BGR and infrared data from png and npy file using the composite data 60 | visible = io.read_image(os.path.join(self.dataset_root, self.filenames[index], 'images', 'visible', 'composite.png')) 61 | infrared = torch.tensor(np.load(os.path.join( 62 | self.dataset_root, self.filenames[index], 'images', 'infrared', 'composite.npy')).astype(int)) 63 | # Stack data into the expected order 64 | data = torch.concat([visible[[2, 1, 0]], infrared.permute(2, 0, 1)]) 65 | # Scale data from 0 - 255 (BGR) to 0 - 10000 (HLS) 66 | data = data / 255 * 10000 67 | 68 | if self.transform is not None: 69 | data = self.transform(data) 70 | 71 | sample = { 72 | 'image': data, 73 | 'label': self.labels[index] 74 | } 75 | 76 | return sample 77 | 78 | 79 | def init_forestnet(bands, normalize, merge_labels, cfg, *args, **kwargs): 80 | """ 81 | Init m-ForestNet dataset. 82 | """ 83 | # Get dataset parameters 84 | split = cfg['dataset']['split'] 85 | 86 | # Get BigEarthNet directory 87 | DATA_DIR = os.getenv('DATA_DIR', 'data') 88 | forestnet_dir = os.path.join(DATA_DIR, 'ForestNetDataset') 89 | 90 | # Init transforms 91 | image_transforms = [ 92 | SelectChannels(bands), 93 | transforms.Resize(size=cfg['model']['img_size'], antialias=True), 94 | ] 95 | 96 | if normalize: 97 | if len(bands) != len(cfg['model']['data_mean']): 98 | # Add mean channels values for missing channels (e.g. 
for BGR data) 99 | image_transforms.append(AddMeanChannels(cfg['model']['data_mean'])) 100 | # Normalize images 101 | image_transforms.append(transforms.Normalize(mean=cfg['model']['data_mean'], std=cfg['model']['data_std'])) 102 | image_transforms.append(Unsqueeze(dim=1)) # add time dim 103 | 104 | # Init dataset 105 | dataset = ForestNet( 106 | dataset_root=forestnet_dir, 107 | split=split, 108 | transform=transforms.Compose(image_transforms), 109 | merge_labels=merge_labels, 110 | ) 111 | 112 | return dataset 113 | 114 | 115 | # Add datasets to the registry 116 | register_dataset('ForestNet', partial(init_forestnet, [0, 1, 2, 3, 4, 5], True, False)) 117 | 118 | register_dataset('ForestNet4', partial(init_forestnet, [0, 1, 2, 3, 4, 5], True, True)) 119 | 120 | register_dataset('ForestNetBGR', partial(init_forestnet, [0, 1, 2], True, False)) 121 | 122 | register_dataset('ForestNet4BGR', partial(init_forestnet, [0, 1, 2], True, True)) 123 | 124 | register_dataset('ForestNetVisual', partial(init_forestnet, [2, 1, 0], False, False)) 125 | 126 | register_dataset('ForestNet4Visual', partial(init_forestnet, [2, 1, 0], False, True)) 127 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import time 4 | import yaml 5 | import logging 6 | import torch 7 | import sys 8 | import os 9 | from torch.utils.data import DataLoader 10 | from model import build_model 11 | from datasets import load_dataset 12 | from datetime import datetime, timezone, timedelta 13 | from utils.hash_functions import get_hash 14 | 15 | output_path = os.getenv('OUTPUT_PATH', os.path.join('output', 'embeddings')) 16 | 17 | 18 | def run_inference(cfg, args): 19 | # Create output dir (default: output_path///) 20 | output_dir = args.output_dir or os.path.join(output_path, 21 | cfg['model']['name'], 22 | f"{cfg['dataset']['name']}{args.input_size or ''}", 23 | cfg['dataset']['split']) 24 | assert not os.path.isdir(output_dir) or len(os.listdir(output_dir)) == 0, \ 25 | (f"Output directory already exists and is not empty ({output_dir}). 
" 26 | f"Specify the directory with --output_dir ") 27 | os.makedirs(output_dir, exist_ok=True) 28 | 29 | logging.info(f"Running inference with model {cfg['model']['name']} for dataset {cfg['dataset']['name']}.") 30 | # Init dataset 31 | dataset = load_dataset(cfg) 32 | 33 | # Init data loader 34 | data_loader = DataLoader( 35 | dataset, 36 | **cfg['dataloader'], 37 | ) 38 | logging.info('DataLoader initialized') 39 | 40 | if torch.cuda.is_available(): 41 | device = torch.device('cuda') 42 | else: 43 | device = torch.device('cpu') 44 | 45 | # Init model 46 | model = build_model(cfg) 47 | model = model.to(device) 48 | model.eval() 49 | logging.info('Model loaded') 50 | 51 | embeddings = [] 52 | labels = [] 53 | time_start = time.time() 54 | num_batches = len(data_loader) 55 | i = 0 56 | 57 | # Run inference 58 | logging.info(f'Starting inference on {len(dataset)} samples') 59 | for batch in data_loader: 60 | # Load input 61 | input = batch['image'] 62 | label = batch['label'] 63 | input = input.to(device) 64 | 65 | # Compute model embedding 66 | with torch.no_grad(): 67 | embedding = model(input) 68 | 69 | embeddings.append(embedding.cpu()) 70 | labels.append(label) 71 | 72 | # Log progress 73 | i += 1 74 | if i % 100 == 0: 75 | speed = i / (time.time() - time_start) 76 | eta = timedelta(seconds=int((num_batches - i) / speed)) 77 | logging.info(f"Batch {i:5d}/{num_batches:4d} - Speed {speed:.2f} batches/s - ETA: {eta}") 78 | 79 | logging.info('Finished inference') 80 | 81 | batch_size = cfg['dataloader']['batch_size'] if 'batch_size' in cfg['dataloader'] else 1 82 | sample_speed = (time.time() - time_start) / (i * batch_size) 83 | logging.info(f'Average inference time: {sample_speed:.4f} s/sample') 84 | 85 | # Combine batch embeddings 86 | embeddings = torch.concat(embeddings, dim=0) 87 | labels = torch.concat(labels, dim=0) 88 | # Create hash codes 89 | hash_codes = get_hash(embeddings, method='lsh', length=64) 90 | logging.info('Hash codes generated') 91 | 92 | # Save embeddings, labels, and hashes (using numpy because of smaller files) 93 | logging.info(f'Saving {len(embeddings)} embeddings, labels, and hash_codes to {output_dir}') 94 | torch.save(embeddings, os.path.join(output_dir, 'embeddings.pt')) 95 | torch.save(labels, os.path.join(output_dir, 'labels.pt')) 96 | torch.save(hash_codes, os.path.join(output_dir, 'hash_codes.pt')) 97 | 98 | logging.info('Files saved') 99 | 100 | 101 | if __name__ == '__main__': 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument('-c', '--config_file', type=str, required=True, help='Path to config file') 104 | parser.add_argument('-o', '--output_dir', type=str, default=None, 105 | help='Path to output dir for embeddings and labels ' 106 | '(default: output/embeddings///)') 107 | parser.add_argument('-d', '--dataset', type=str, 108 | help='Overwrite the dataset name in the config file') 109 | parser.add_argument('-s', '--split', type=str, 110 | help='Overwrite the dataset split in the config file') 111 | parser.add_argument('--input_size', type=int, 112 | help='Overwrite the size of the model input in the config file') 113 | parser.add_argument('--data_dir', type=str, 114 | help='Path to data directory (default `data`)') 115 | parser.add_argument('--log_level', type=str, default='INFO', 116 | help='Log level (DEBUG, INFO, WARNING, ERROR)') 117 | parser.add_argument('--log_file', type=str, default=None, 118 | help='Log file') 119 | args = parser.parse_args() 120 | 121 | # Load config file 122 | with open(args.config_file, 'r') as f: 123 | 
cfg = yaml.safe_load(f) 124 | 125 | # Overwrite dataset and split from optional args 126 | if args.dataset: 127 | cfg['dataset']['name'] = args.dataset 128 | if args.split: 129 | cfg['dataset']['split'] = args.split 130 | if args.input_size: 131 | cfg['model']['img_size'] = args.input_size 132 | 133 | # Set data dir as env variable if specified 134 | if args.data_dir: 135 | os.environ['DATA_DIR'] = args.data_dir 136 | 137 | # init logger 138 | current_time = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%MZ") 139 | log_file = args.log_file or f"logs/{current_time}_{cfg['model']['name']}_{cfg['dataset']['name']}.log" 140 | os.makedirs(os.path.dirname(log_file), exist_ok=True) 141 | logging.basicConfig( 142 | level=args.log_level.upper(), 143 | handlers=[logging.FileHandler(log_file), logging.StreamHandler()], 144 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' 145 | ) 146 | 147 | logging.info(f'Config:\n {yaml.dump(cfg)}') 148 | 149 | try: 150 | run_inference(cfg, args) 151 | except Exception as e: 152 | # log potential error 153 | logging.error(f'{type(e)}: {e}') 154 | raise e 155 | -------------------------------------------------------------------------------- /visualization.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import argparse 4 | import textwrap 5 | import torch 6 | import yaml 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | from datasets import load_dataset 10 | from experiments import get_similarity, get_hash 11 | 12 | 13 | def plot_retrival_results(queries, results, correct, labels): 14 | """ 15 | Create figure with label names, queries and retrieved images. 16 | 17 | :param queries: tensor of shape [sample, channel, h, w] 18 | :param results: tensor of shape [sample, images, channel, h, w] 19 | :param correct: tensor of shape [sample, images] with bool values of retrieved image is correct 20 | :param labels: list of shape [sample] with label names of each sample 21 | """ 22 | num_samples, num_images = results.shape[:2] 23 | 24 | # Scale by max sensor value per query 25 | scale_max = torch.max(torch.amax(queries, dim=(1, 2, 3)), torch.amax(results, dim=(1, 2, 3, 4))) * 0.5 26 | # Convert all tensors to RGB format 27 | queries_rgb = (queries / scale_max[:, None, None, None] * 255).to(int).permute(0, 2, 3, 1) 28 | results_rgb = (results / scale_max[:, None, None, None, None] * 255).to(int).permute(0, 1, 3, 4, 2) 29 | queries_rgb = queries_rgb.clip(0, 255) 30 | results_rgb = results_rgb.clip(0, 255) 31 | 32 | # Create subplots 33 | fig, axs = plt.subplots(num_samples, num_images + 2, figsize=(14, num_samples * 1.2), 34 | gridspec_kw={'width_ratios': [1, 2.] 
+ [1] * num_images}) 35 | fig.subplots_adjust(hspace=0.1, wspace=0.1, left=0.01, right=0.99, top=0.99, bottom=0.01) 36 | 37 | for i in range(num_samples): 38 | # Display query image 39 | axs[i, 0].imshow(queries_rgb[i]) 40 | axs[i, 0].axis('off') 41 | # Label to string 42 | # label_str = ',\n'.join([textwrap.shorten(l, width=30, placeholder='...') for l in labels[i]]) 43 | label_str = ',\n'.join(labels[i]) 44 | # Replace a very long label from BigEarthNet 45 | label_str = label_str.replace("Land principally occupied by agriculture, " 46 | "with significant areas of natural vegetation", 47 | "Agriculture with\n natural vegetation") 48 | axs[i, 1].text(0.5, 0.5, label_str, ha='center', va='center', fontsize=10) 49 | axs[i, 1].axis('off') 50 | 51 | for j in range(num_images): 52 | # Display result image 53 | axs[i, j + 2].imshow(results_rgb[i, j]) 54 | axs[i, j + 2].tick_params(left=False, right=False, labelleft=False, 55 | labelbottom=False, bottom=False) 56 | # Add frame based on correctness 57 | if correct[i, j] == len(labels[i]): 58 | frame_color = 'green' 59 | elif correct[i, j]: 60 | frame_color = 'orange' 61 | # Add number of correct labels 62 | axs[i, j + 2].text(200, 200, str(correct[i, j].item()), fontsize=10, 63 | color='white', ha='center', va='center') 64 | else: 65 | frame_color = 'red' 66 | for spine in axs[i, j + 2].spines.values(): 67 | spine.set_edgecolor(frame_color) 68 | spine.set_linewidth(3) 69 | 70 | # axs[0, 0].set_title("Query") 71 | # axs[0, 1 + round(num_images / 2)].set_title("Retrieved images") 72 | # plt.tight_layout() 73 | 74 | 75 | def main(args): 76 | # Init dataset 77 | with open(args.config_file, 'r') as f: 78 | # Load config file 79 | cfg = yaml.safe_load(f) 80 | cfg['dataset']['name'] = args.dataset_visual 81 | cfg['dataset']['split'] = 'val' 82 | val_dataset = load_dataset(cfg) 83 | cfg['dataset']['split'] = 'test' 84 | test_dataset = load_dataset(cfg) 85 | 86 | # Init embeddings with shape [sample, embedding] 87 | output_path = os.getenv('OUTPUT_PATH', os.path.join('output', 'embeddings')) 88 | val_embeddings = torch.load(os.path.join(output_path, cfg['model']['name'], args.dataset, 'val', 'embeddings.pt'), 89 | map_location='cpu') 90 | test_embeddings = torch.load(os.path.join(output_path, cfg['model']['name'], args.dataset, 'test', 'embeddings.pt'), 91 | map_location='cpu') 92 | # Multi-labels with shape [sample, label] 93 | val_labels = torch.load(os.path.join(output_path, cfg['model']['name'], args.dataset, 'val', 'labels.pt'), 94 | map_location='cpu') 95 | test_labels = torch.load(os.path.join(output_path, cfg['model']['name'], args.dataset, 'test', 'labels.pt'), 96 | map_location='cpu') 97 | 98 | # Load label names 99 | if 'BigEarthNet19' in args.dataset: 100 | label_names = val_dataset.class_sets[19] 101 | elif 'BigEarthNet' in args.dataset: 102 | label_names = val_dataset.class_sets[43] 103 | elif 'ForestNet' in args.dataset: 104 | label_names = val_dataset.label_names 105 | 106 | # Select sample queries 107 | np.random.seed(42) 108 | indices = args.indices or np.random.choice(range(len(val_dataset)), args.num_queries) 109 | print('Selected indices:', indices) 110 | 111 | # Retrival 112 | val_hash, test_hash = get_hash([val_embeddings, test_embeddings], 'trivial', length=32) 113 | similarity = get_similarity(val_hash[indices], test_hash, distance='hamming') 114 | # Select top k results 115 | ranking = similarity.topk(args.num_samples, sorted=True, dim=-1)[1] 116 | 117 | # Get correct match 118 | correct = 
val_labels[indices].unsqueeze(1).repeat(1, len(test_labels), 1) 119 | # count label matches 120 | correct = correct.logical_and(correct == test_labels).sum(dim=-1) 121 | correct = torch.gather(correct, 1, ranking) 122 | 123 | # Get images 124 | queries = [] 125 | retrieved_images = [] 126 | labels = [] 127 | for i, r in zip(indices, ranking): 128 | queries.append(val_dataset[i]['image']) 129 | retrieved_images.append(torch.stack([test_dataset[n]['image'] for n in r])) 130 | label_idx = val_labels[i].nonzero().flatten() 131 | labels.append([label_names[l] for l in label_idx]) 132 | queries = torch.stack(queries) 133 | retrieved_images = torch.stack(retrieved_images) 134 | 135 | plot_retrival_results(queries, retrieved_images, correct, labels) 136 | 137 | if os.path.isdir(args.output_dir): 138 | output_file = os.path.join(args.output_dir, f"{cfg['model']['name']}_{args.dataset}_retrieval.pdf") 139 | else: 140 | output_file = args.output_dir 141 | 142 | plt.savefig(output_file) 143 | plt.show() 144 | 145 | 146 | if __name__ == '__main__': 147 | parser = argparse.ArgumentParser() 148 | parser.add_argument('-c', '--config_file', type=str, default='configs/prithvi_vit_us.yaml', 149 | help='Path to config file') 150 | parser.add_argument('-o', '--output_dir', type=str, default='output/figures', 151 | help='Path to output dir') 152 | parser.add_argument('-d', '--dataset', type=str, default='ForestNet') 153 | parser.add_argument('-v', '--dataset_visual', type=str, default='ForestNetVisual') 154 | parser.add_argument('-i', '--indices', nargs='+', type=str, default=[84, 435]) 155 | parser.add_argument('-n', '--num_samples', type=int, default=9) 156 | parser.add_argument('-q', '--num_queries', type=int, default=4) 157 | args = parser.parse_args() 158 | 159 | main(args) 160 | -------------------------------------------------------------------------------- /experiments.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import pandas as pd 4 | import tqdm 5 | import argparse 6 | import glob 7 | import logging 8 | import torch 9 | import numpy as np 10 | import torch.nn.functional as F 11 | from datetime import datetime, timezone 12 | from torchmetrics.retrieval import RetrievalMAP, RetrievalNormalizedDCG, RetrievalPrecision 13 | from utils.hash_functions import get_hash 14 | 15 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 16 | 17 | 18 | def get_similarity(queries, database, distance='hamming'): 19 | if distance == 'hamming': 20 | return -torch.cdist(queries.float(), database.float(), p=1) 21 | elif distance == 'euclidean': 22 | return -torch.cdist(queries.float(), database.float(), p=2) 23 | elif distance == 'cosine': 24 | return F.cosine_similarity(queries.float().unsqueeze(1), database.float().unsqueeze(0), dim=-1) 25 | elif distance == 'dotproduct': 26 | return torch.einsum('ab,cb->ac', queries, database).float() 27 | else: 28 | raise NotImplementedError 29 | 30 | 31 | def run_experiment(val_embeddings, test_embeddings, val_labels, test_labels, distance_function, hash_method, 32 | hash_length, k, match, seed=None): 33 | # Run multiple seeds for lsh hash 34 | if hash_method == 'lsh' and seed is None: 35 | mean_average_precision = [] 36 | precision = [] 37 | # Run experiment 5 times with different seeds 38 | for s in range(1, 6): 39 | map, p = run_experiment(val_embeddings, test_embeddings, val_labels, test_labels, distance_function, 40 | hash_method, hash_length, k, match, seed=s*42) 41 | mean_average_precision.append(map) 42 | 
precision.append(p) 43 | # Compute mean and 95% confidence interval 44 | map_mean = np.mean(mean_average_precision) * 100 45 | map_ci = np.std(mean_average_precision, ddof=1) / np.sqrt(len(mean_average_precision)) * 1.96 * 100 46 | precision_mean = np.mean(precision) * 100 47 | precision_ci = np.std(precision, ddof=1) / np.sqrt(len(precision)) * 1.96 * 100 48 | 49 | return f'{map_mean:.2f} ± {map_ci:.2f}', f'{precision_mean:.2f} ± {precision_ci:.2f}' 50 | 51 | logging.info(f'Top {k} experiment with {hash_method} hash, length {hash_length}, and {distance_function} distance.') 52 | val_hash, test_hash = get_hash([val_embeddings, test_embeddings], hash_method, 53 | length=hash_length, seed=seed) 54 | 55 | # Init metrics 56 | map_metric = RetrievalMAP(top_k=k) 57 | precision_metric = RetrievalPrecision(top_k=k) 58 | # Passing only the top k results to ndcg results in different values. Skipping this metric. 59 | # ndcg_metric = RetrievalNormalizedDCG(top_k=k) 60 | 61 | # Iterate over results to avoid OOM errors 62 | step_size = 100 63 | for i in tqdm.tqdm(range(0, len(val_hash), step_size), desc='Compute retrival results'): 64 | # Get similarity values 65 | similarity = get_similarity(val_hash[i:i+step_size], test_hash, distance=distance_function) 66 | similarity = similarity.to(device) 67 | 68 | target = val_labels[i:i+step_size].unsqueeze(1).repeat(1, len(test_labels), 1) 69 | if match == 'any': 70 | # Count any positive overlap for reported experiments 71 | target = target.logical_and(target == test_labels).any(dim=-1) 72 | elif match == 'exact': 73 | target = (target == test_labels).all(dim=-1) 74 | 75 | # Select top k results to reduce computation time (no influence on mAP metric) 76 | assert k < similarity.shape[-1] 77 | ranking = similarity.topk(k, sorted=True, dim=-1)[1] 78 | similarity_k = torch.gather(similarity, 1, ranking) 79 | target_k = torch.gather(target, 1, ranking) 80 | indexes_k = torch.arange(i, i+len(ranking)).unsqueeze(1).repeat(1, k) 81 | 82 | # Add samples to retrieval metrics 83 | map_metric.update(similarity_k, target_k, indexes_k) 84 | precision_metric.update(similarity_k, target_k, indexes_k) 85 | 86 | # Compute metrics 87 | mean_average_precision = map_metric.compute().item() 88 | precision = precision_metric.compute().item() 89 | 90 | # Log results 91 | logging.debug(f'Retrival mAP@{k}: {mean_average_precision:.4f}') 92 | logging.debug(f'Retrival Precision@{k}: {precision:.4f}') 93 | 94 | return mean_average_precision, precision 95 | 96 | 97 | def main(args): 98 | output_path = os.getenv('OUTPUT_PATH', os.path.join('output', 'embeddings')) 99 | # expects results in the structure ////embeddings.pt 100 | val_embedding_files = sorted(glob.glob(os.path.join(output_path, args.folder_pattern, 'val', 'embeddings.pt'))) 101 | 102 | results = pd.DataFrame([], columns=['Dataset', 'Model', 'Match', 'Distance', 'Hash method', 'Hash length', 103 | 'Top K', 'mAP@K', 'Precision@K']) 104 | results_path = args.results_file or os.path.join('output', 'results.csv') 105 | 106 | logging.info(f'Found {len(val_embedding_files)} model+dataset combinations.') 107 | 108 | for val_embedding_file in val_embedding_files: 109 | # Embeddings with shape [sample, embedding] 110 | val_embeddings = torch.load(val_embedding_file, map_location=device) 111 | test_embeddings = torch.load(val_embedding_file.replace('val', 'test'), map_location=device) 112 | # Multi-labels with shape [sample, label] 113 | val_labels = torch.load(val_embedding_file.replace('embeddings.pt', 'labels.pt'), 114 | 
map_location=device) 115 | test_labels = torch.load(val_embedding_file.replace('embeddings.pt', 'labels.pt') 116 | .replace('val', 'test'), map_location=device) 117 | model, dataset = val_embedding_file.split('/')[-4:-2] 118 | logging.info(f'Embedding and labels loaded for {model} and {dataset}.') 119 | 120 | for match in args.match.split(','): 121 | for distance_function in args.distance_function.split(','): 122 | for hash_method in args.hash_method.split(','): 123 | hash_lengths = args.hash_length.split(',') if hash_method != 'none' else [val_embeddings.size(-1)] 124 | for hash_length in hash_lengths: 125 | hash_length = int(hash_length) 126 | for k in args.top_k.split(','): 127 | k = int(k) 128 | metrics = run_experiment(val_embeddings, test_embeddings, val_labels, test_labels, 129 | distance_function, hash_method, hash_length, k, match) 130 | results.loc[len(results)] = [dataset, model, match, distance_function, hash_method, 131 | hash_length, k, *metrics] 132 | 133 | # Save results 134 | results.to_csv(results_path) 135 | logging.debug(f'Saved metrics in {results_path}') 136 | logging.info(f'Finished experiments. Results are saved in {results_path}') 137 | 138 | 139 | if __name__ == '__main__': 140 | parser = argparse.ArgumentParser() 141 | parser.add_argument('-o', '--output_dir', type=str, default=None, 142 | help='Path to output dir for embeddings and labels ' 143 | '(default: output/embeddings//)') 144 | parser.add_argument('--folder_pattern', type=str, default='*/*', 145 | help='Pattern for output dir, default: assumes //') 146 | parser.add_argument('--match', type=str, default='any', help='Select match type (any, exact)') 147 | parser.add_argument('--distance_function', type=str, default='hamming', 148 | help='Distance function (hamming, euclidean, cosine, dotproduct)') 149 | parser.add_argument('--hash_method', type=str, default='trivial', 150 | help='Method (trivial, lsh, none)') 151 | parser.add_argument('--hash_length', type=str, default='32', help='Hash length') 152 | parser.add_argument('--top_k', type=str, default='20', help='Number of retrieved samples') 153 | parser.add_argument('--results_file', type=str, default=None, help='Results file') 154 | parser.add_argument('--log_level', type=str, default='INFO', 155 | help='Log level (DEBUG, INFO, WARNING, ERROR)') 156 | parser.add_argument('--log_file', type=str, default=None, help='Log file') 157 | args = parser.parse_args() 158 | 159 | # init logger 160 | current_time = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%MZ") 161 | log_file = args.log_file or f"logs/{current_time}_experiments.log" 162 | os.makedirs(os.path.dirname(log_file), exist_ok=True) 163 | logging.basicConfig( 164 | level=args.log_level.upper(), 165 | handlers=[logging.FileHandler(log_file), logging.StreamHandler()], 166 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' 167 | ) 168 | 169 | try: 170 | main(args) 171 | except Exception as e: 172 | # log potential error 173 | logging.error(f'{type(e)}: {e}') 174 | raise e 175 | -------------------------------------------------------------------------------- /model/vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | import os 13 | import logging 14 | 15 | from functools import partial 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | from timm.models.vision_transformer import Block 21 | from timm.models.layers import to_2tuple 22 | 23 | import numpy as np 24 | 25 | from einops import rearrange 26 | from .model_registry import register_model 27 | 28 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 29 | """ 30 | embed_dim: output dimension for each position 31 | pos: a list of positions to be encoded: size (M,) 32 | out: (M, D) 33 | """ 34 | assert embed_dim % 2 == 0 35 | omega = np.arange(embed_dim // 2, dtype=np.float32) 36 | omega /= embed_dim / 2. 37 | omega = 1. / 10000**omega # (D/2,) 38 | 39 | pos = pos.reshape(-1) # (M,) 40 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 41 | 42 | emb_sin = np.sin(out) # (M, D/2) 43 | emb_cos = np.cos(out) # (M, D/2) 44 | 45 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 46 | return emb 47 | 48 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 49 | assert embed_dim % 2 == 0 50 | 51 | # use half of dimensions to encode grid_h 52 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 53 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 54 | 55 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 56 | return emb 57 | 58 | def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 59 | """ 60 | grid_size: 3d tuple of grid size: t, h, w 61 | return: 62 | pos_embed: L, D 63 | """ 64 | 65 | assert embed_dim % 16 == 0 66 | 67 | t_size, h_size, w_size = grid_size 68 | 69 | w_embed_dim = embed_dim // 16 * 6 70 | h_embed_dim = embed_dim // 16 * 6 71 | t_embed_dim = embed_dim // 16 * 4 72 | 73 | w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size)) 74 | h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size)) 75 | t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size)) 76 | 77 | w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1)) 78 | h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1)) 79 | t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0) 80 | 81 | pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1) 82 | 83 | if cls_token: 84 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 85 | return pos_embed 86 | 87 | 88 | class PatchEmbed(nn.Module): 89 | """ Frames of 2D Images to Patch Embedding 90 | The 3D version of timm.models.vision_transformer.PatchEmbed 91 | """ 92 | def __init__( 93 | self, 94 | img_size=224, 95 | patch_size=16, 96 | num_frames=3, 97 | tubelet_size=1, 98 | in_chans=3, 99 | embed_dim=768, 100 | norm_layer=None, 101 | flatten=True, 102 | bias=True, 103 | ): 104 | super().__init__() 105 | img_size = to_2tuple(img_size) 106 | patch_size = to_2tuple(patch_size) 107 | self.img_size = img_size 108 | self.patch_size = patch_size 109 | self.num_frames = num_frames 110 | self.tubelet_size = tubelet_size 111 | self.grid_size = (num_frames // tubelet_size, img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 112 | self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2] 113 
| self.flatten = flatten 114 | 115 | self.proj = nn.Conv3d(in_chans, embed_dim, 116 | kernel_size=(tubelet_size, patch_size[0], patch_size[1]), 117 | stride=(tubelet_size, patch_size[0], patch_size[1]), bias=bias) 118 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 119 | 120 | def forward(self, x): 121 | B, C, T, H, W = x.shape 122 | x = self.proj(x) 123 | if self.flatten: 124 | x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C 125 | x = self.norm(x) 126 | return x 127 | 128 | 129 | @register_model 130 | class PrithviViT(nn.Module): 131 | """ 132 | Encoder with VisionTransformer backbone 133 | """ 134 | def __init__(self, img_size=224, patch_size=16, 135 | num_frames=3, tubelet_size=1, 136 | in_chans=3, embed_dim=1024, depth=24, num_heads=16, 137 | mlp_ratio=4., norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 138 | norm_pix_loss=False, weights=None, *args, **kwargs): 139 | super().__init__() 140 | # -------------------------------------------------------------------------- 141 | # MAE encoder specifics 142 | self.patch_embed = PatchEmbed(img_size, patch_size, num_frames, tubelet_size, in_chans, embed_dim) 143 | num_patches = self.patch_embed.num_patches 144 | 145 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 146 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding 147 | 148 | self.blocks = nn.ModuleList([ 149 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 150 | for i in range(depth)]) 151 | self.norm = norm_layer(embed_dim) 152 | # -------------------------------------------------------------------------- 153 | 154 | self.norm_pix_loss = norm_pix_loss 155 | 156 | self.initialize_weights() 157 | 158 | # Load model weights (guard against the default weights=None before checking the file) 159 | if weights is not None and os.path.isfile(weights): 160 | logging.info(f'Loading weights from {weights}') 161 | state_dict = torch.load(weights, map_location='cuda' if torch.cuda.is_available() else 'cpu') 162 | # discard pos_embedding weights 163 | del state_dict['pos_embed'] 164 | self.load_state_dict(state_dict, strict=False) 165 | else: 166 | # No weights provided 167 | logging.warning(f'No weights file found at {weights}, keeping randomly initialized model') 168 | 169 | def initialize_weights(self): 170 | # initialization 171 | # initialize (and freeze) pos_embed by sin-cos embedding 172 | pos_embed = get_3d_sincos_pos_embed(self.pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True) 173 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 174 | 175 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 176 | w = self.patch_embed.proj.weight.data 177 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 178 | 179 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
180 | torch.nn.init.normal_(self.cls_token, std=.02) 181 | 182 | # initialize nn.Linear and nn.LayerNorm 183 | self.apply(self._init_weights) 184 | 185 | def _init_weights(self, m): 186 | if isinstance(m, nn.Linear): 187 | # we use xavier_uniform following official JAX ViT: 188 | torch.nn.init.xavier_uniform_(m.weight) 189 | if isinstance(m, nn.Linear) and m.bias is not None: 190 | nn.init.constant_(m.bias, 0) 191 | elif isinstance(m, nn.LayerNorm): 192 | nn.init.constant_(m.bias, 0) 193 | nn.init.constant_(m.weight, 1.0) 194 | 195 | def patchify(self, imgs): 196 | """ 197 | imgs: B, C, T, H, W 198 | x: B, L, D 199 | """ 200 | p = self.patch_embed.patch_size[0] 201 | tub = self.patch_embed.tubelet_size 202 | x = rearrange(imgs, 'b c (t tub) (h p) (w q) -> b (t h w) (tub p q c)', tub=tub, p=p, q=p) 203 | 204 | return x 205 | 206 | def unpatchify(self, x): 207 | """ 208 | x: B, L, D 209 | imgs: B, C, T, H, W 210 | """ 211 | p = self.patch_embed.patch_size[0] 212 | num_p = self.patch_embed.img_size[0] // p 213 | tub = self.patch_embed.tubelet_size 214 | imgs = rearrange(x, 'b (t h w) (tub p q c) -> b c (t tub) (h p) (w q)', h=num_p, w=num_p, tub=tub, p=p, q=p) 215 | return imgs 216 | 217 | def forward(self, x): 218 | # embed patches 219 | x = self.patch_embed(x) 220 | 221 | # add pos embed w/o cls token 222 | x = x + self.pos_embed[:, 1:, :] 223 | 224 | # append cls token 225 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 226 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 227 | x = torch.cat((cls_tokens, x), dim=1) 228 | 229 | # apply Transformer blocks 230 | for blk in self.blocks: 231 | x = blk(x) 232 | x = self.norm(x) 233 | 234 | # Return mean patch embedding 235 | x = x[:, 1:].mean(dim=1) 236 | 237 | return x 238 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------