├── README.md ├── data ├── __init__.py └── datasets.py ├── dataset_paths.py ├── earlystop.py ├── models ├── __init__.py ├── clip │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── clip.py │ ├── model.py │ ├── model_features.py │ └── simple_tokenizer.py ├── clip_models.py └── networks │ ├── __init__.py │ ├── customnet.py │ └── xception.py ├── networks ├── __init__.py ├── base_model.py └── trainer.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── plots.ipynb ├── train.py ├── utils └── utils.py └── validate.py /README.md: -------------------------------------------------------------------------------- 1 | # DeCLIP 2 | 3 | [![arXiv](https://img.shields.io/badge/-arXiv-B31B1B.svg?style=for-the-badge)](https://doi.org/10.48550/arXiv.2409.08849) 4 | 5 | **Official PyTorch Implementation of the Paper:** 6 | 7 | > **Ștefan Smeu, Elisabeta Oneață, Dan Oneață** 8 | > [DeCLIP: Decoding CLIP Representations for Deepfake Localization](https://arxiv.org/pdf/2409.08849) 9 | > *WACV, 2025* 10 | 11 | ## Data 12 | 13 | To set up your data, follow these steps: 14 | 15 | 1. **Download the datasets:** 16 | - **Dolos Dataset:** Follow instructions from [Dolos GitHub repo](https://github.com/bit-ml/dolos) 17 | - **AutoSplice Dataset:** Follow instructions from [AutoSplice GitHub repo](https://github.com/shanface33/AutoSplice_Dataset) 18 | 19 | 2. **Organize the data:** 20 | 21 | After downloading, place the datasets in the `datasets` folder to match the following structure: 22 | 23 | ```plaintext 24 | ├── data/ 25 | ├── datasets/ 26 | │ ├── AutoSplice/ 27 | │ ├── dolos_data/ 28 | │ │ ├── celebahq/ 29 | │ │ │ ├── fake/ 30 | │ │ │ │ ├── lama/ 31 | │ │ │ │ ├── ldm/ 32 | │ │ │ │ ├── pluralistic/ 33 | │ │ │ │ ├── repaint-p2-9k/ 34 | │ │ │ ├── real/ 35 | │ │ ├── ffhq/ 36 | ├── models/ 37 | ├── train.py 38 | ├── validate.py 39 | ├── ... 40 | 41 | ## Installation 42 | 43 | Main prerequisites: 44 | 45 | * `Python 3.10.14` 46 | * `pytorch=2.2.2 (cuda 11.8)` 47 | * `pytorch-cuda=11.8` 48 | * `torchvision=0.17.2` 49 | * `scikit-learn=1.3.2` 50 | * `pandas==2.1.1` 51 | * `numpy=1.26.4` 52 | * `pillow=10.0.1` 53 | * `seaborn=0.13.0` 54 | * `matplotlib=3.7.1` 55 | * `tensorboardX=2.6.2.2` 56 | 57 | ## Train 58 | 59 | To train the models mentioned in the article, follow these steps: 60 | 61 | 1. **Set up training and validation data paths** in `options/train_options.py` or specify them as arguments when running the training routine. 62 | 63 | 2. **Run the training command** using the following template: 64 | 65 | ```bash 66 | python train.py --name= --train_dataset= --arch= --decoder_type= --feature_layer= --fix_backbone --fully_supervised 67 | ``` 68 | 69 | Example commands: 70 | 71 | Train on Repaint-P2: 72 | 73 | ```bash 74 | python train.py --name=test_repaint --train_dataset=repaint-p2-9k --data_root_path=datasets/dolos_data/celebahq/ --arch=CLIP:ViT-L/14 --decoder_type=conv-20 --feature_layer=layer20 --fix_backbone --fully_supervised 75 | ``` 76 | 77 | Where: 78 | 79 | - `arch` specifies the architecture, such as CLIP:RN50, CLIP:ViT-L/14, CLIP:xceptionnet, or CLIP:ViT-L/14,RN50. 80 | - `decoder_type` can be linear, attention, conv-4, conv-12, or conv-20. 81 | - `feature_layer` ranges from layer0 to layer23 for ViTs and from layer1 to layer4 for ResNets. 82 | 83 | Exceptions: 84 | 85 | - For CLIP:xceptionnet, features are always extracted from the 2nd block. 
86 | - For CLIP:ViT-L/14,RN50, the `feature_layer` argument specifies the layer of the ViT; for RN50, features are always extracted from layer3.
87 | - Use `--fully_supervised` for localization tasks. Omit it for image-level detection tasks.
88 | 
89 | ## Pretrained Models
90 | We provide trained models for the networks that rely on ViT and ViT+RN50 backbones, listed in the table below.
91 | 
92 | | Backbone | Feature Layer | Decoder | Training Dataset | Download Link |
93 | |----------|---------------|---------|------------------|---------------|
94 | | ViT | layer20 | conv-20 | Pluralistic | [Download](https://storage.cloud.google.com/bitdefender_ml_artifacts/declip/backbone_VIT/ViT_layer20_conv-20_pluralistic.pth) |
95 | | ViT | layer20 | conv-20 | LaMa | [Download](https://storage.cloud.google.com/bitdefender_ml_artifacts/declip/backbone_VIT/ViT_layer20_conv-20_lama.pth) |
96 | | ViT | layer20 | conv-20 | RePaint-p2-9k | [Download](https://storage.cloud.google.com/bitdefender_ml_artifacts/declip/backbone_VIT/ViT_layer20_conv-20_repaint-p2-9k.pth) |
97 | | ViT | layer20 | conv-20 | LDM | [Download](https://storage.cloud.google.com/bitdefender_ml_artifacts/declip/backbone_VIT/ViT_layer20_conv-20_ldm.pth) |
98 | | ViT | layer20 | conv-20 | COCO-SD | [Download](https://storage.cloud.google.com/bitdefender_ml_artifacts/declip/backbone_VIT/ViT_layer20_conv-20_cocosd.pth) |
99 | | ViT+RN50 | layer20+layer3 | conv-20 | Pluralistic | [Download](https://storage.cloud.google.com/bitdefender_ml_artifacts/declip/backbone_VIT%2BRN50/ViT_layer20%2BRN50_layer3_conv-20_pluralistic.pth) |
100 | | ViT+RN50 | layer20+layer3 | conv-20 | LaMa | [Download](https://storage.cloud.google.com/bitdefender_ml_artifacts/declip/backbone_VIT%2BRN50/ViT_layer20%2BRN50_layer3_conv-20_lama.pth) |
101 | | ViT+RN50 | layer20+layer3 | conv-20 | RePaint-p2-9k | [Download](https://storage.cloud.google.com/bitdefender_ml_artifacts/declip/backbone_VIT%2BRN50/ViT_layer20%2BRN50_layer3_conv-20_repaint-p2-9k.pth) |
102 | | ViT+RN50 | layer20+layer3 | conv-20 | LDM | [Download](https://storage.cloud.google.com/bitdefender_ml_artifacts/declip/backbone_VIT%2BRN50/ViT_layer20%2BRN50_layer3_conv-20_ldm.pth) |
103 | 
104 | Additionally, the checkpoints can be downloaded with **gsutil** from [this GCS bucket](https://console.cloud.google.com/storage/browser/bitdefender_ml_artifacts/declip?pageState=(%22StorageObjectListTable%22:(%22f%22:%22%255B%255D%22))). The weights are located in the **backbone_VIT** and **backbone_VIT+RN50** folders, where each checkpoint follows the naming convention `<backbone>_<feature_layer>_<decoder_type>_<training_dataset>`, with the training dataset lower-cased. When features from ViT and RN50 are concatenated, a `+` character joins the two backbones and the two feature layers.
105 | 
106 | 
107 | ## Evaluation
108 | 
109 | To evaluate a model, use the following template:
110 | 
111 | ```bash
112 | python validate.py --arch=CLIP:ViT-L/14 --ckpt=path/to/the/saved/model/checkpoint/model_epoch_best.pth --result_folder=path/to/save/the/results --fully_supervised
113 | ```
114 | 
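For example, the LaMa-trained ViT checkpoint from the table above can be fetched and evaluated end to end. The sketch below is illustrative only: the `gs://` path is inferred from the download links, `checkpoints/` and `results/lama` are placeholder directories, and if the defaults in `options/` differ from the training configuration you may also need to pass the matching `--decoder_type` and `--feature_layer`.

```bash
# Download the ViT (layer20, conv-20) checkpoint trained on LaMa
mkdir -p checkpoints results
gsutil cp gs://bitdefender_ml_artifacts/declip/backbone_VIT/ViT_layer20_conv-20_lama.pth checkpoints/

# Evaluate it on the localization task
python validate.py --arch=CLIP:ViT-L/14 --ckpt=checkpoints/ViT_layer20_conv-20_lama.pth --result_folder=results/lama --fully_supervised
```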

115 | ## License
116 | 
117 | The code is licensed under CC BY-NC-SA 4.0
118 | 119 | 120 | This repository also integrates code from the following repositories: 121 | ``` 122 | @inproceedings{ojha2023fakedetect, 123 | title={Towards Universal Fake Image Detectors that Generalize Across Generative Models}, 124 | author={Ojha, Utkarsh and Li, Yuheng and Lee, Yong Jae}, 125 | booktitle={CVPR}, 126 | year={2023}, 127 | } 128 | ``` 129 | ``` 130 | @inproceedings{patchforensics, 131 | title={What makes fake images detectable? Understanding properties that generalize}, 132 | author={Chai, Lucy and Bau, David and Lim, Ser-Nam and Isola, Phillip}, 133 | booktitle={European Conference on Computer Vision}, 134 | year={2020} 135 | } 136 | ``` 137 | 138 | ## Citation 139 | 140 | If you find this work useful in your research, please cite it. 141 | 142 | ``` 143 | @InProceedings{DeCLIP, 144 | author = {Smeu, Stefan and Oneata, Elisabeta and Oneata, Dan}, 145 | title = {DeCLIP: Decoding CLIP representations for deepfake localization}, 146 | booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)}, 147 | year = {2025} 148 | } 149 | ``` 150 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .datasets import RealFakeDataset, RealFakeDetectionDataset 3 | 4 | 5 | def create_dataloader(opt): 6 | shuffle = True if opt.data_label == 'train' else False 7 | if opt.fully_supervised: 8 | dataset = RealFakeDataset(opt) 9 | else: 10 | dataset = RealFakeDetectionDataset(opt) 11 | 12 | data_loader = torch.utils.data.DataLoader(dataset, 13 | batch_size=opt.batch_size, 14 | shuffle=shuffle, 15 | num_workers=int(opt.num_threads)) 16 | return data_loader 17 | -------------------------------------------------------------------------------- /data/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import pickle 4 | from io import BytesIO 5 | from PIL import Image, ImageOps, ImageFile 6 | import torchvision.transforms as transforms 7 | from torch.utils.data import Dataset 8 | from random import shuffle 9 | 10 | ImageFile.LOAD_TRUNCATED_IMAGES = True 11 | 12 | # Constants for mean and standard deviation 13 | MEAN = [0.48145466, 0.4578275, 0.40821073] 14 | 15 | STD = [0.26862954, 0.26130258, 0.27577711] 16 | 17 | # Helper functions 18 | def recursively_read(rootdir, must_contain, exts=["png", "jpg", "JPEG", "jpeg", 'tif', 'tiff']): 19 | out = [] 20 | for r, d, f in os.walk(rootdir): 21 | for file in f: 22 | if file.split('.')[-1] in exts and must_contain in os.path.join(r, file): 23 | out.append(os.path.join(r, file)) 24 | return out 25 | 26 | def get_list(path, must_contain=''): 27 | if path.endswith(".pickle"): 28 | with open(path, 'rb') as f: 29 | image_list = pickle.load(f) 30 | return [item for item in image_list if must_contain in item] 31 | return recursively_read(path, must_contain) 32 | 33 | def randomJPEGcompression(image): 34 | qf = random.randint(30, 100) 35 | output_io_stream = BytesIO() 36 | image.save(output_io_stream, "JPEG", quality=qf, optimize=True) 37 | output_io_stream.seek(0) 38 | return Image.open(output_io_stream) 39 | 40 | # Base Dataset class 41 | class BaseDataset(Dataset): 42 | def __init__(self, opt): 43 | self.opt = opt 44 | self._init_data() 45 | 46 | def _init_data(self): 47 | pass 48 | 49 | def _get_data(self): 50 | pass 51 | 52 | def _get_transform(self): 53 | transform_list = 
[transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=MEAN, std=STD)] 54 | if self.opt.data_label == 'train': 55 | if self.opt.data_aug == "blur": 56 | transform_list.insert(1, transforms.GaussianBlur(kernel_size=5, sigma=(0.4, 2.0))) 57 | elif self.opt.data_aug == "color_jitter": 58 | transform_list.insert(1, transforms.ColorJitter(0.3, 0.3, 0.3, 0.3)) 59 | elif self.opt.data_aug == "jpeg_compression": 60 | transform_list.insert(1, transforms.Lambda(randomJPEGcompression)) 61 | elif self.opt.data_aug == "all": 62 | transform_list.insert(1, transforms.ColorJitter(0.3, 0.3, 0.3, 0.3)) 63 | transform_list.insert(2, transforms.Lambda(randomJPEGcompression)) 64 | transform_list.insert(3, transforms.GaussianBlur(kernel_size=5, sigma=(0.4, 2.0))) 65 | return transforms.Compose(transform_list) 66 | 67 | def __len__(self): 68 | pass 69 | 70 | class RealFakeDataset(BaseDataset): 71 | def __init__(self, opt): 72 | super().__init__(opt) 73 | self.mask_transf = self._get_mask_transform() 74 | 75 | def _init_data(self): 76 | if self.opt.data_label == "train": 77 | self.input_path = self.opt.train_path 78 | self.masks_path = self.opt.train_masks_ground_truth_path 79 | elif self.opt.data_label == "valid": 80 | self.input_path = self.opt.valid_path 81 | self.masks_path = self.opt.valid_masks_ground_truth_path 82 | elif self.opt.data_label == "test": 83 | self.input_path = self.opt.test_path 84 | self.masks_path = self.opt.test_masks_ground_truth_path 85 | 86 | fake_list = self._get_data() 87 | 88 | self.labels_dict = self._set_labels(fake_list) 89 | self.fake_list = fake_list 90 | shuffle(self.fake_list) 91 | self.transform = self._get_transform() 92 | 93 | def _get_data(self): 94 | fake_list = get_list(self.input_path) 95 | 96 | return fake_list 97 | 98 | def _get_mask_transform(self): 99 | return transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()]) 100 | 101 | def get_mask_from_file(self, file_name): 102 | if "autosplice" in self.opt.train_dataset: 103 | file_name = file_name[:file_name.rfind('_')] + "_mask.png" 104 | self.mask_path = os.path.join(self.masks_path, file_name) 105 | mask = Image.open(self.mask_path).convert("L") 106 | if self.opt.train_dataset in ['pluralistic', 'lama', 'repaint-p2-9k', 'ldm', 'ldm_clean', 'ldm_real']: 107 | mask = ImageOps.invert(mask) 108 | return self.mask_transf(mask).view(-1) 109 | 110 | def _set_labels(self, fake_list): 111 | # masks images should be .png 112 | labels = {img: img.split("/")[-1].replace(".jpg", ".png") for img in fake_list} 113 | return labels 114 | 115 | def __len__(self): 116 | return len(self.fake_list) 117 | 118 | def __getitem__(self, idx): 119 | img_path = self.fake_list[idx] 120 | label = self.labels_dict[img_path] 121 | img = Image.open(img_path).convert("RGB") 122 | img = self.transform(img) 123 | label = self.get_mask_from_file(label) 124 | 125 | return img, label, img_path, self.mask_path 126 | 127 | class RealFakeDetectionDataset(BaseDataset): 128 | def __init__(self, opt): 129 | super().__init__(opt) 130 | 131 | def _get_data(self): 132 | fake_list = get_list(self.input_path) 133 | real_list = get_list(self.input_path_real) 134 | 135 | return real_list, fake_list 136 | 137 | def _init_data(self): 138 | if self.opt.data_label == "train": 139 | self.input_path = self.opt.train_path 140 | self.input_path_real = self.opt.train_real_list_path 141 | self.masks_path = self.opt.train_masks_ground_truth_path 142 | elif self.opt.data_label == "valid": 143 | self.input_path = self.opt.valid_path 144 
| self.input_path_real = self.opt.valid_real_list_path 145 | self.masks_path = self.opt.valid_masks_ground_truth_path 146 | elif self.opt.data_label == "test": 147 | self.input_path = self.opt.test_path 148 | self.input_path_real = self.opt.test_real_list_path 149 | self.masks_path = self.opt.test_masks_ground_truth_path 150 | 151 | real_list, fake_list = self._get_data() 152 | self.labels_dict = self._set_labels(real_list, fake_list) 153 | self.total_list = real_list + fake_list 154 | shuffle(self.total_list) 155 | self.transform = self._get_transform() 156 | 157 | def _set_labels(self, real_list, fake_list): 158 | labels = {img: 0 for img in real_list} 159 | labels.update({img: 1 for img in fake_list}) 160 | return labels 161 | 162 | def __len__(self): 163 | return len(self.total_list) 164 | 165 | def __getitem__(self, idx): 166 | img_path = self.total_list[idx] 167 | label = self.labels_dict[img_path] 168 | img = Image.open(img_path).convert("RGB") 169 | img = self.transform(img) 170 | 171 | return img, label, img_path -------------------------------------------------------------------------------- /dataset_paths.py: -------------------------------------------------------------------------------- 1 | def get_dolos_localisation_dataset_paths(dataset): 2 | paths = dict( 3 | fake_path=f'datasets/dolos_data/celebahq/fake/{dataset}/images/test', 4 | masks_path=f'datasets/dolos_data/celebahq/fake/{dataset}/masks/test', 5 | key=dataset 6 | ) 7 | return paths 8 | 9 | def get_dolos_detection_dataset_paths(dataset): 10 | paths = dict( 11 | real_path=f'datasets/dolos_data/celebahq/fake/{dataset}/images/test', 12 | fake_path=f'datasets/dolos_data/celebahq/real/{dataset}/test', 13 | masks_path=f'datasets/dolos_data/celebahq/fake/{dataset}/masks/test', 14 | key=dataset 15 | ), 16 | return paths 17 | 18 | def get_autosplice_localisation_dataset_paths(compression): 19 | paths = dict( 20 | fake_path=f'datasets/AutoSplice/Forged_JPEG{compression}', 21 | masks_path=f'datasets/AutoSplice/Mask', 22 | key=f'autosplice_jpeg{compression}' 23 | ) 24 | return paths 25 | 26 | LOCALISATION_DATASET_PATHS = [ 27 | get_dolos_localisation_dataset_paths('pluralistic'), 28 | get_dolos_localisation_dataset_paths('lama'), 29 | get_dolos_localisation_dataset_paths('repaint-p2-9k'), 30 | get_dolos_localisation_dataset_paths('ldm'), 31 | # TO BE PUBLISHED 32 | # get_dolos_localisation_dataset_paths('ldm_clean'), 33 | # get_dolos_localisation_dataset_paths('ldm_real'), 34 | 35 | get_autosplice_localisation_dataset_paths("75"), 36 | get_autosplice_localisation_dataset_paths("90"), 37 | get_autosplice_localisation_dataset_paths("100"), 38 | ] 39 | 40 | 41 | DETECTION_DATASET_PATHS = [ 42 | get_dolos_detection_dataset_paths('pluralistic'), 43 | get_dolos_detection_dataset_paths('lama'), 44 | get_dolos_detection_dataset_paths('repaint-p2-9k'), 45 | get_dolos_detection_dataset_paths('ldm'), 46 | # TO BE PUBLISHED 47 | # get_dolos_detection_dataset_paths('ldm_clean'), 48 | # get_dolos_detection_dataset_paths('ldm_real'), 49 | ] 50 | -------------------------------------------------------------------------------- /earlystop.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class EarlyStopping: 4 | """Early stops the training if validation loss doesn't improve after a given patience.""" 5 | def __init__(self, patience=1, verbose=False, delta=0): 6 | """ 7 | Args: 8 | patience (int): How long to wait after last time validation loss improved. 
9 | Default: 7 10 | verbose (bool): If True, prints a message for each validation loss improvement. 11 | Default: False 12 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 13 | Default: 0 14 | """ 15 | self.patience = patience 16 | self.verbose = verbose 17 | self.counter = 0 18 | self.best_score = None 19 | self.early_stop = False 20 | self.score_max = -np.Inf 21 | self.delta = delta 22 | 23 | def __call__(self, score, model): 24 | if self.best_score is None: 25 | self.best_score = score 26 | self.save_checkpoint(score, model) 27 | elif score < self.best_score - self.delta: 28 | self.counter += 1 29 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 30 | if self.counter >= self.patience: 31 | self.early_stop = True 32 | else: 33 | self.best_score = score 34 | self.save_checkpoint(score, model) 35 | self.counter = 0 36 | 37 | def save_checkpoint(self, score, model): 38 | '''Saves model when validation loss decrease.''' 39 | if self.verbose: 40 | print(f'Validation accuracy increased ({self.score_max:.6f} --> {score:.6f}). Saving model ...') 41 | # model.save_networks('best') 42 | self.score_max = score -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip_models import CLIPModelLocalisation 2 | 3 | 4 | VALID_NAMES = [ 5 | 'CLIP:RN50', 6 | 'CLIP:ViT-L/14', 7 | 'CLIP:xceptionnet', 8 | 'CLIP:ViT-L/14,RN50', 9 | ] 10 | 11 | def get_model(opt): 12 | name, layer, decoder_type = opt.arch, opt.feature_layer, opt.decoder_type 13 | 14 | assert name in VALID_NAMES 15 | 16 | return CLIPModelLocalisation(name.split(':')[1], intermidiate_layer_output = layer, decoder_type=decoder_type) 17 | 18 | -------------------------------------------------------------------------------- /models/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /models/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bit-ml/DeCLIP/09d4293c78ed648b103b7c453e7d4e47be0707f3/models/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /models/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .model_features import build_model_features 15 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 16 | 17 | try: 18 | from torchvision.transforms import InterpolationMode 19 | BICUBIC = InterpolationMode.BICUBIC 20 | except ImportError: 21 | BICUBIC = Image.BICUBIC 22 | 23 | 24 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 25 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 26 | 27 | 28 | __all__ = ["available_models", "load", "tokenize"] 29 | _tokenizer = _Tokenizer() 30 | 31 | _MODELS = { 32 | "RN50": 
"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 33 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 34 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 35 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 36 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 37 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 38 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 39 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 40 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 41 | } 42 | 43 | 44 | def _download(url: str, root: str): 45 | os.makedirs(root, exist_ok=True) 46 | filename = os.path.basename(url) 47 | 48 | expected_sha256 = url.split("/")[-2] 49 | download_target = os.path.join(root, filename) 50 | 51 | if os.path.exists(download_target) and not os.path.isfile(download_target): 52 | raise RuntimeError(f"{download_target} exists and is not a regular file") 53 | 54 | if os.path.isfile(download_target): 55 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 56 | return download_target 57 | else: 58 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 59 | 60 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 61 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 62 | while True: 63 | buffer = source.read(8192) 64 | if not buffer: 65 | break 66 | 67 | output.write(buffer) 68 | loop.update(len(buffer)) 69 | 70 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 71 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") 72 | 73 | return download_target 74 | 75 | 76 | def _convert_image_to_rgb(image): 77 | return image.convert("RGB") 78 | 79 | 80 | def _transform(n_px): 81 | return Compose([ 82 | Resize(n_px, interpolation=BICUBIC), 83 | CenterCrop(n_px), 84 | _convert_image_to_rgb, 85 | ToTensor(), 86 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 87 | ]) 88 | 89 | 90 | def available_models() -> List[str]: 91 | """Returns the names of available CLIP models""" 92 | return list(_MODELS.keys()) 93 | 94 | 95 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None, attention_features = False, intermidiate_layers = None): 96 | """Load a CLIP model 97 | 98 | Parameters 99 | ---------- 100 | name : str 101 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 102 | 103 | device : Union[str, torch.device] 104 | The device to put the loaded model 105 | 106 | 
jit : bool 107 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 108 | 109 | download_root: str 110 | path to download the model files; by default, it uses "~/.cache/clip" 111 | 112 | Returns 113 | ------- 114 | model : torch.nn.Module 115 | The CLIP model 116 | 117 | preprocess : Callable[[PIL.Image], torch.Tensor] 118 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 119 | """ 120 | if name in _MODELS: 121 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 122 | elif os.path.isfile(name): 123 | model_path = name 124 | else: 125 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 126 | 127 | with open(model_path, 'rb') as opened_file: 128 | try: 129 | # loading JIT archive 130 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 131 | state_dict = None 132 | except RuntimeError: 133 | # loading saved state dict 134 | if jit: 135 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 136 | jit = False 137 | state_dict = torch.load(opened_file, map_location="cpu") 138 | 139 | if not jit: 140 | if intermidiate_layers: 141 | model = build_model(state_dict or model.state_dict(), intermidiate_layers = intermidiate_layers).to(device) 142 | else: 143 | if attention_features: 144 | model = build_model_features(state_dict or model.state_dict()).to(device) 145 | else: 146 | model = build_model(state_dict or model.state_dict()).to(device) 147 | if str(device) == "cpu": 148 | model.float() 149 | return model, _transform(model.visual.input_resolution) 150 | 151 | # patch the device names 152 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 153 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 154 | 155 | def patch_device(module): 156 | try: 157 | graphs = [module.graph] if hasattr(module, "graph") else [] 158 | except RuntimeError: 159 | graphs = [] 160 | 161 | if hasattr(module, "forward1"): 162 | graphs.append(module.forward1.graph) 163 | 164 | for graph in graphs: 165 | for node in graph.findAllNodes("prim::Constant"): 166 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 167 | node.copyAttributes(device_node) 168 | 169 | model.apply(patch_device) 170 | patch_device(model.encode_image) 171 | patch_device(model.encode_text) 172 | 173 | # patch dtype to float32 on CPU 174 | if str(device) == "cpu": 175 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 176 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 177 | float_node = float_input.node() 178 | 179 | def patch_float(module): 180 | try: 181 | graphs = [module.graph] if hasattr(module, "graph") else [] 182 | except RuntimeError: 183 | graphs = [] 184 | 185 | if hasattr(module, "forward1"): 186 | graphs.append(module.forward1.graph) 187 | 188 | for graph in graphs: 189 | for node in graph.findAllNodes("aten::to"): 190 | inputs = list(node.inputs()) 191 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 192 | if inputs[i].node()["value"] == 5: 193 | inputs[i].node().copyAttributes(float_node) 194 | 195 | model.apply(patch_float) 196 | patch_float(model.encode_image) 197 | patch_float(model.encode_text) 198 | 199 | model.float() 200 | 201 | return model, 
_transform(model.input_resolution.item()) 202 | 203 | 204 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 205 | """ 206 | Returns the tokenized representation of given input string(s) 207 | 208 | Parameters 209 | ---------- 210 | texts : Union[str, List[str]] 211 | An input string or a list of input strings to tokenize 212 | 213 | context_length : int 214 | The context length to use; all CLIP models use 77 as the context length 215 | 216 | truncate: bool 217 | Whether to truncate the text in case its encoding is longer than the context length 218 | 219 | Returns 220 | ------- 221 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 222 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 223 | """ 224 | if isinstance(texts, str): 225 | texts = [texts] 226 | 227 | sot_token = _tokenizer.encoder["<|startoftext|>"] 228 | eot_token = _tokenizer.encoder["<|endoftext|>"] 229 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 230 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 231 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 232 | else: 233 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 234 | 235 | for i, tokens in enumerate(all_tokens): 236 | if len(tokens) > context_length: 237 | if truncate: 238 | tokens = tokens[:context_length] 239 | tokens[-1] = eot_token 240 | else: 241 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 242 | result[i, :len(tokens)] = torch.tensor(tokens) 243 | 244 | return result 245 | -------------------------------------------------------------------------------- /models/clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.relu1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.relu2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.relu3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.relu1(self.bn1(self.conv1(x))) 46 | out = self.relu2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.relu3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x[:1], key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0, 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | return x.squeeze(0) 92 | 93 | 94 | class ModifiedResNet(nn.Module): 95 | """ 96 | A ResNet class that is similar to torchvision's but contains the following changes: 97 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
98 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 99 | - The final pooling layer is a QKV attention instead of an average pool 100 | """ 101 | 102 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64, 103 | intermidiate_layers = False): 104 | super().__init__() 105 | self.intermidiate_layers = intermidiate_layers 106 | self.output_dim = output_dim 107 | self.input_resolution = input_resolution 108 | 109 | # the 3-layer stem 110 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 111 | self.bn1 = nn.BatchNorm2d(width // 2) 112 | self.relu1 = nn.ReLU(inplace=True) 113 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 114 | self.bn2 = nn.BatchNorm2d(width // 2) 115 | self.relu2 = nn.ReLU(inplace=True) 116 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 117 | self.bn3 = nn.BatchNorm2d(width) 118 | self.relu3 = nn.ReLU(inplace=True) 119 | self.avgpool = nn.AvgPool2d(2) 120 | 121 | # residual layers 122 | self._inplanes = width # this is a *mutable* variable used during construction 123 | self.layer1 = self._make_layer(width, layers[0]) 124 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 125 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 126 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 127 | 128 | embed_dim = width * 32 # the ResNet feature dimension 129 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 130 | 131 | def _make_layer(self, planes, blocks, stride=1): 132 | layers = [Bottleneck(self._inplanes, planes, stride)] 133 | 134 | self._inplanes = planes * Bottleneck.expansion 135 | for _ in range(1, blocks): 136 | layers.append(Bottleneck(self._inplanes, planes)) 137 | 138 | return nn.Sequential(*layers) 139 | 140 | def forward(self, x): 141 | out = {} 142 | def stem(x): 143 | x = self.relu1(self.bn1(self.conv1(x))) 144 | x = self.relu2(self.bn2(self.conv2(x))) 145 | x = self.relu3(self.bn3(self.conv3(x))) 146 | x = self.avgpool(x) 147 | return x 148 | 149 | x = x.type(self.conv1.weight.dtype) 150 | x = stem(x) 151 | x = self.layer1(x) 152 | out['layer1'] = x 153 | x = self.layer2(x) 154 | out['layer2'] = x 155 | x = self.layer3(x) 156 | out['layer3'] = x 157 | x = self.layer4(x) 158 | out['layer4'] = x 159 | x = self.attnpool(x) 160 | out["final_embeddings"] = x 161 | 162 | if self.intermidiate_layers: 163 | return out 164 | else: 165 | return x 166 | 167 | 168 | class LayerNorm(nn.LayerNorm): 169 | """Subclass torch's LayerNorm to handle fp16.""" 170 | 171 | def forward(self, x: torch.Tensor): 172 | orig_type = x.dtype 173 | ret = super().forward(x.type(torch.float32)) 174 | return ret.type(orig_type) 175 | 176 | 177 | class QuickGELU(nn.Module): 178 | def forward(self, x: torch.Tensor): 179 | return x * torch.sigmoid(1.702 * x) 180 | 181 | 182 | class ResidualAttentionBlock(nn.Module): 183 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 184 | super().__init__() 185 | 186 | self.attn = nn.MultiheadAttention(d_model, n_head) 187 | self.ln_1 = LayerNorm(d_model) 188 | self.mlp = nn.Sequential(OrderedDict([ 189 | ("c_fc", nn.Linear(d_model, d_model * 4)), 190 | ("gelu", QuickGELU()), 191 | ("c_proj", nn.Linear(d_model * 4, d_model)) 192 | ])) 193 | self.ln_2 = LayerNorm(d_model) 194 | self.attn_mask = attn_mask 195 | 196 | def attention(self, x: torch.Tensor): 197 | 
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 198 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 199 | 200 | def forward(self, x: torch.Tensor): 201 | x = x + self.attention(self.ln_1(x)) 202 | x = x + self.mlp(self.ln_2(x)) 203 | return x 204 | 205 | 206 | class Transformer(nn.Module): 207 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 208 | super().__init__() 209 | self.width = width 210 | self.layers = layers 211 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 212 | 213 | def forward(self, x: torch.Tensor): 214 | out = {} 215 | for idx, layer in enumerate(self.resblocks.children()): 216 | x = layer(x) 217 | out['layer'+str(idx)] = x # shape:LND. choose cls token feature (choose x[0] for CLS token only) 218 | return out, x 219 | 220 | # return self.resblocks(x) # This is the original code 221 | 222 | 223 | class VisionTransformer(nn.Module): 224 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, intermidiate_layers: bool = False): 225 | super().__init__() 226 | self.input_resolution = input_resolution 227 | self.output_dim = output_dim 228 | self.intermidiate_layers = intermidiate_layers 229 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 230 | 231 | scale = width ** -0.5 232 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 233 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 234 | self.ln_pre = LayerNorm(width) 235 | 236 | self.transformer = Transformer(width, layers, heads) 237 | 238 | self.ln_post = LayerNorm(width) 239 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 240 | 241 | 242 | 243 | def forward(self, x: torch.Tensor): 244 | x = self.conv1(x) # shape = [*, width, grid, grid] 245 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 246 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 247 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 248 | x = x + self.positional_embedding.to(x.dtype) 249 | x = self.ln_pre(x) 250 | 251 | x = x.permute(1, 0, 2) # NLD -> LND 252 | out, x = self.transformer(x) 253 | x = x.permute(1, 0, 2) # LND -> NLD 254 | 255 | x = self.ln_post(x[:, 0, :]) 256 | 257 | 258 | out['before_projection'] = x 259 | 260 | if self.proj is not None: 261 | x = x @ self.proj 262 | out['after_projection'] = x 263 | 264 | # Return both intermediate features and final clip feature 265 | if self.intermidiate_layers: 266 | return out 267 | else: 268 | # This only returns CLIP features 269 | return x 270 | 271 | 272 | class CLIP(nn.Module): 273 | def __init__(self, 274 | embed_dim: int, 275 | # vision 276 | image_resolution: int, 277 | vision_layers: Union[Tuple[int, int, int, int], int], 278 | vision_width: int, 279 | vision_patch_size: int, 280 | # text 281 | context_length: int, 282 | vocab_size: int, 283 | transformer_width: int, 284 | transformer_heads: int, 285 | transformer_layers: int, 286 | # optional 287 | intermidiate_layers: bool = False 288 | ): 289 | super().__init__() 290 | 291 | self.context_length = context_length 292 | 293 | if isinstance(vision_layers, (tuple, list)): 294 | vision_heads = 
vision_width * 32 // 64 295 | self.visual = ModifiedResNet( 296 | layers=vision_layers, 297 | output_dim=embed_dim, 298 | heads=vision_heads, 299 | input_resolution=image_resolution, 300 | width=vision_width, 301 | intermidiate_layers = intermidiate_layers 302 | ) 303 | else: 304 | vision_heads = vision_width // 64 305 | self.visual = VisionTransformer( 306 | input_resolution=image_resolution, 307 | patch_size=vision_patch_size, 308 | width=vision_width, 309 | layers=vision_layers, 310 | heads=vision_heads, 311 | output_dim=embed_dim, 312 | intermidiate_layers = intermidiate_layers 313 | ) 314 | 315 | self.transformer = Transformer( 316 | width=transformer_width, 317 | layers=transformer_layers, 318 | heads=transformer_heads, 319 | attn_mask=self.build_attention_mask() 320 | ) 321 | 322 | self.vocab_size = vocab_size 323 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 324 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 325 | self.ln_final = LayerNorm(transformer_width) 326 | 327 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 328 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 329 | 330 | self.initialize_parameters() 331 | 332 | def initialize_parameters(self): 333 | nn.init.normal_(self.token_embedding.weight, std=0.02) 334 | nn.init.normal_(self.positional_embedding, std=0.01) 335 | 336 | if isinstance(self.visual, ModifiedResNet): 337 | if self.visual.attnpool is not None: 338 | std = self.visual.attnpool.c_proj.in_features ** -0.5 339 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 340 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 341 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 342 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 343 | 344 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 345 | for name, param in resnet_block.named_parameters(): 346 | if name.endswith("bn3.weight"): 347 | nn.init.zeros_(param) 348 | 349 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 350 | attn_std = self.transformer.width ** -0.5 351 | fc_std = (2 * self.transformer.width) ** -0.5 352 | for block in self.transformer.resblocks: 353 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 354 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 355 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 356 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 357 | 358 | if self.text_projection is not None: 359 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 360 | 361 | def build_attention_mask(self): 362 | # lazily create causal attention mask, with full attention between the vision tokens 363 | # pytorch uses additive attention mask; fill with -inf 364 | mask = torch.empty(self.context_length, self.context_length) 365 | mask.fill_(float("-inf")) 366 | mask.triu_(1) # zero out the lower diagonal 367 | return mask 368 | 369 | @property 370 | def dtype(self): 371 | return self.visual.conv1.weight.dtype 372 | 373 | def encode_image(self, image): 374 | return self.visual(image.type(self.dtype)) 375 | 376 | def encode_text(self, text): 377 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 378 | 379 | x = x + self.positional_embedding.type(self.dtype) 380 | x = x.permute(1, 0, 2) # NLD -> LND 381 | x = self.transformer(x) 382 | x = x.permute(1, 0, 2) # LND -> NLD 
383 | x = self.ln_final(x).type(self.dtype) 384 | 385 | # x.shape = [batch_size, n_ctx, transformer.width] 386 | # take features from the eot embedding (eot_token is the highest number in each sequence) 387 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 388 | 389 | return x 390 | 391 | def forward(self, image, text): 392 | image_features = self.encode_image(image) 393 | text_features = self.encode_text(text) 394 | 395 | # normalized features 396 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 397 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 398 | 399 | # cosine similarity as logits 400 | logit_scale = self.logit_scale.exp() 401 | logits_per_image = logit_scale * image_features @ text_features.t() 402 | logits_per_text = logits_per_image.t() 403 | 404 | # shape = [global_batch_size, global_batch_size] 405 | return logits_per_image, logits_per_text 406 | 407 | 408 | def convert_weights(model: nn.Module): 409 | """Convert applicable model parameters to fp16""" 410 | 411 | def _convert_weights_to_fp16(l): 412 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 413 | l.weight.data = l.weight.data.half() 414 | if l.bias is not None: 415 | l.bias.data = l.bias.data.half() 416 | 417 | if isinstance(l, nn.MultiheadAttention): 418 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 419 | tensor = getattr(l, attr) 420 | if tensor is not None: 421 | tensor.data = tensor.data.half() 422 | 423 | for name in ["text_projection", "proj"]: 424 | if hasattr(l, name): 425 | attr = getattr(l, name) 426 | if attr is not None: 427 | attr.data = attr.data.half() 428 | 429 | model.apply(_convert_weights_to_fp16) 430 | 431 | 432 | def build_model(state_dict: dict, intermidiate_layers: bool = False): 433 | vit = "visual.proj" in state_dict 434 | 435 | if vit: 436 | vision_width = state_dict["visual.conv1.weight"].shape[0] 437 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 438 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 439 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 440 | image_resolution = vision_patch_size * grid_size 441 | else: 442 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 443 | vision_layers = tuple(counts) 444 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 445 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 446 | vision_patch_size = None 447 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 448 | image_resolution = output_width * 32 449 | 450 | embed_dim = state_dict["text_projection"].shape[1] 451 | context_length = state_dict["positional_embedding"].shape[0] 452 | vocab_size = state_dict["token_embedding.weight"].shape[0] 453 | transformer_width = state_dict["ln_final.weight"].shape[0] 454 | transformer_heads = transformer_width // 64 455 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) 456 | 457 | model = CLIP( 458 | embed_dim, 459 | image_resolution, vision_layers, vision_width, vision_patch_size, 460 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, intermidiate_layers = intermidiate_layers 461 | ) 462 | 463 | for key in 
["input_resolution", "context_length", "vocab_size"]: 464 | if key in state_dict: 465 | del state_dict[key] 466 | 467 | convert_weights(model) 468 | model.load_state_dict(state_dict) 469 | return model.eval() -------------------------------------------------------------------------------- /models/clip/model_features.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | # import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class LayerNorm(nn.LayerNorm): 11 | """Subclass torch's LayerNorm to handle fp16.""" 12 | 13 | def forward(self, x: torch.Tensor): 14 | orig_type = x.dtype 15 | ret = super().forward(x.type(torch.float32)) 16 | return ret.type(orig_type) 17 | 18 | 19 | class QuickGELU(nn.Module): 20 | def forward(self, x: torch.Tensor): 21 | return x * torch.sigmoid(1.702 * x) 22 | 23 | 24 | class ResidualAttentionBlock(nn.Module): 25 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 26 | super().__init__() 27 | 28 | self.attn = nn.MultiheadAttention(d_model, n_head) # d_model => (token) embed_dim 29 | self.ln_1 = LayerNorm(d_model) 30 | self.mlp = nn.Sequential(OrderedDict([ 31 | ("c_fc", nn.Linear(d_model, d_model * 4)), 32 | ("gelu", QuickGELU()), 33 | ("c_proj", nn.Linear(d_model * 4, d_model)) 34 | ])) 35 | self.ln_2 = LayerNorm(d_model) 36 | self.attn_mask = attn_mask 37 | 38 | def attention(self, x: torch.Tensor): 39 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 40 | # Yup, pytorch's forward for MultideadAttention expects arguments (query, key, value, ...) 41 | # query => [L, N, E]; key and value => [S, N, E] 42 | # L: target dim; S: source dim; E: (token) embedding dim; N: batch 43 | return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask, average_attn_weights=False) 44 | 45 | def forward(self, x: torch.Tensor): 46 | attention_res = self.attention(self.ln_1(x)) 47 | x, weights = x + attention_res[0], attention_res[1] 48 | # x => attn_output => shape = [L, N, E] 49 | # weights => attn_output_weights => shape = [layers, N, heads, L, S] 50 | x = x + self.mlp(self.ln_2(x)) 51 | return x, weights 52 | 53 | 54 | class Transformer(nn.Module): 55 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 56 | super().__init__() 57 | self.width = width 58 | self.layers = layers 59 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 60 | 61 | def forward(self, x: torch.Tensor): 62 | weights_all_blocks = [] 63 | 64 | # Go through all the blocks (layers) 65 | for block in self.resblocks: 66 | x, weight = block(x) 67 | weights_all_blocks.append(weight) 68 | 69 | return x, torch.stack(weights_all_blocks) 70 | 71 | 72 | class VisionTransformer(nn.Module): 73 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 74 | super().__init__() 75 | self.input_resolution = input_resolution 76 | self.output_dim = output_dim 77 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 78 | 79 | scale = width ** -0.5 80 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 81 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 82 | self.ln_pre = 
LayerNorm(width) 83 | 84 | self.transformer = Transformer(width, layers, heads) 85 | 86 | self.ln_post = LayerNorm(width) 87 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 88 | 89 | self.get_cls = False 90 | 91 | def forward(self, x: torch.Tensor): 92 | # The conv1 uses kernel_size=patch_size, stride=patch_size, therefore stride==kernel_size and 93 | # that will do the equivalent of chopping the image into patches and converting those into "tokens" 94 | # with a depth (number of output layers / channels in the conv1) equals to "width". 95 | # Note1: grid => input_resolution/patch_size 96 | # Note2: width => d_model => (token) embedding dim 97 | x = self.conv1(x) # shape = [*, width, grid, grid] or [batch, embedding_dim, grid, grid] 98 | 99 | # Reshape it to flatten the grid (it will look more like a "sentence" made of "tokens") 100 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid**2] or [batch, embedding_dim, grid**2] 101 | 102 | # Swap axis so each token (patch) leads to its own learned high dim embedding "juice" 103 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] or [batch, grid**2, embedding_dim] 104 | 105 | # Concatenate an extra token (learned) that will represent the class (it will become the first token, it's the [CLS] for BERT (?) stuff) 106 | # Finally, we get to what is usual for transformers with a final shape = [batch, sentence_length, embedding_dim] 107 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, 1 + grid**2, width] 108 | 109 | # Add to the previous tensor (it's not concat now!) the positional embeddings (learned) 110 | x = x + self.positional_embedding.to(x.dtype) 111 | 112 | # The normalization layer learns a gain and a bias during training, 113 | # but it also "uses statistics computed from input data (x) in both training and evaluation modes". 114 | x = self.ln_pre(x) 115 | 116 | # Pytorch transformer stuff expects the batch as the second dimension 117 | x = x.permute(1, 0, 2) # shape: [N, L, E] -> [L, N, E] 118 | 119 | x, weights = self.transformer(x) 120 | # x => attn_output => shape = [L, N, E] 121 | # weights => attn_output_weights => shape = [layers, N, heads, L, S] 122 | # N: batch 123 | # L = S: 1 + grid**2 124 | # E: width or (token) embedding dim 125 | 126 | # Undo the mess with the dimensions as explained... 127 | x = x.permute(1, 0, 2) # shape: [L, N, E] -> [N, L, E] 128 | 129 | # At this point we have x.shape = [*, 1 + grid ** 2, width] or [batch, sentence_length, embedding_dim] 130 | # but the line below will keep only the class embedding token (learned!), the index=0 below, 131 | # and normalize it (as before, the normalization layer learns a gain and a bias...) 132 | x = self.ln_post(x[:, 0, :]) # shape = [*, width] or [batch, emdedding_dim] 133 | # Explanation about the CLS from the ViT creators: 134 | # https://github.com/google-research/vision_transformer/issues/61#issuecomment-802233921 135 | # https://github.com/google-research/vision_transformer/issues/83#issuecomment-805661088 136 | # "After training the embedding is fixed and it will be exactly the same for all inputs 137 | # in the first layer - but due to interactions with the other tokens in every layer, the 138 | # value will be input-dependent (and strongly correlate with the input's class) at the output layer." 
139 | 140 | # self.proj.shape = [width, output_dim] 141 | if self.get_cls: 142 | return x 143 | 144 | x = x @ self.proj # Project the output from the transformer (the CLS token only) into the choosen output dimension 145 | # The operation is [batch, width] x [width, output_dim] = [batch, output_dim] 146 | # CLS embeddings, "x", has size "width" and (for the available models) it's always bigger than the output (output_dim). 147 | # Therefore we are compressing "x" into the output (smaller number of dimensions). 148 | 149 | return x, weights 150 | 151 | 152 | class CLIP(nn.Module): 153 | def __init__(self, 154 | # output dimension (embeddings generated by text and image encoders) 155 | output_embed_dim: int, 156 | # vision 157 | image_resolution: int, 158 | vision_layers: Union[Tuple[int, int, int, int], int], 159 | vision_width: int, 160 | vision_patch_size: int, 161 | # text 162 | context_length: int, 163 | vocab_size: int, 164 | transformer_width: int, 165 | transformer_heads: int, 166 | transformer_layers: int 167 | ): 168 | super().__init__() 169 | 170 | self.context_length = context_length 171 | 172 | self.vision_heads = vision_width // 64 173 | self.visual = VisionTransformer( 174 | input_resolution=image_resolution, 175 | patch_size=vision_patch_size, 176 | width=vision_width, 177 | layers=vision_layers, 178 | heads=self.vision_heads, 179 | output_dim=output_embed_dim 180 | ) 181 | 182 | self.transformer = Transformer( 183 | width=transformer_width, 184 | layers=transformer_layers, 185 | heads=transformer_heads, 186 | attn_mask=self.build_attention_mask() 187 | ) 188 | 189 | self.vocab_size = vocab_size 190 | self.token_embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=transformer_width) 191 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 192 | self.ln_final = LayerNorm(transformer_width) 193 | 194 | self.text_projection = nn.Parameter(torch.empty(transformer_width, output_embed_dim)) 195 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 196 | 197 | self.initialize_parameters() 198 | 199 | def initialize_parameters(self): 200 | nn.init.normal_(self.token_embedding.weight, std=0.02) 201 | nn.init.normal_(self.positional_embedding, std=0.01) 202 | 203 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 204 | attn_std = self.transformer.width ** -0.5 205 | fc_std = (2 * self.transformer.width) ** -0.5 206 | for block in self.transformer.resblocks: 207 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 208 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 209 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 210 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 211 | 212 | if self.text_projection is not None: 213 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 214 | 215 | def build_attention_mask(self): 216 | # Lazily create the causal attention mask. 217 | # Pytorch uses additive attention mask (it adds the mask inside the softmax, remember the exponentials!), 218 | # therefore we fill it with "-inf" and make the lower diagonal part all zeros. 219 | # Note1: This is used only for the text encoder. For the visual transformer we want 220 | # full attention between the vision tokens (non-causal) because the image is 2D 221 | # and we only use a 1D thing like for text to reuse the transformer architecture... 222 | # Note2: A system is called "causal" when its output depends only on present and past inputs. 
223 | # For text, the "sentence" is seen as a time-series, so the positions are usually referred to as T ("time"). 224 | # That's why we apply this weird mask and it will have the effect that all 225 | # elements ABOVE the main diagonal will be zeroed. 226 | # t0 t1 t2 t3 227 | # t0 w11 0 0 0 => At t0, you can only use t <= t0 228 | # t1 w21 w22 0 0 => ... 229 | # t2 w31 w32 w33 0 => ... 230 | # t3 w41 w42 w43 w45 => At t3, the final row, you can use everything because <= t3 231 | # Apparently it's "cheaper" to use a full matrix and at the end chop off half of it... 232 | mask = torch.empty(self.context_length, self.context_length) 233 | mask.fill_(float("-inf")) 234 | mask.triu_(1) # zero out the lower diagonal 235 | return mask 236 | 237 | @property 238 | def dtype(self): 239 | return self.visual.conv1.weight.dtype 240 | 241 | def encode_image(self, image, get_cls=False): 242 | self.visual.get_cls = get_cls 243 | return self.visual(image.type(self.dtype)) 244 | 245 | def encode_text(self, text): 246 | # Here text is the input "tokenized" text (clip.tokenize). 247 | # It has shape = [batch, context_length] 248 | 249 | x = self.token_embedding(text).type(self.dtype) # [batch_size, context_length, transformer_width] 250 | # transformer_width => d_model => token embeddings 251 | # context_length => sentence_length 252 | 253 | # A clear difference between this and the vision one is the lack of a CLS token. 254 | x = x + self.positional_embedding.type(self.dtype) 255 | 256 | # Pytorch transformer stuff expects the batch as the second dimension 257 | x = x.permute(1, 0, 2) # shape: [N, L, E] -> [L, N, E] 258 | 259 | x, weights = self.transformer(x) 260 | # x => attn_output => shape = [L, N, E] 261 | # weights => attn_output_weights => shape = [layers, N, heads, L, S] 262 | # N: batch 263 | # L = S: context_length 264 | # E: transformer_width or (token) embedding_dim 265 | 266 | # Undo the mess with the dimensions as explained... 267 | x = x.permute(1, 0, 2) # shape: [L, N, E] -> [N, L, E] 268 | 269 | # x.shape = [batch_size, context_length, transformer_width] 270 | x = self.ln_final(x).type(self.dtype) 271 | 272 | # Take features ONLY from the eot embedding. The eot_token is the token with the highest value (index) in each sequence. 273 | # Although we are only using that token, it went through all the transformer layers, 274 | # so it carries lots if information (but I don't know why they chose exactly the last one) 275 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1), :] @ self.text_projection # I added the ":" for clarity... 
276 | # The operation is [batch_size, transformer_width] x [transformer_width, output_embed_dim] = [batch_size, output_embed_dim] 277 | 278 | return x, weights 279 | 280 | def forward(self, image, text): 281 | image_features, _ = self.encode_image(image) 282 | text_features, _ = self.encode_text(text) 283 | 284 | # normalized features 285 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 286 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 287 | 288 | # cosine similarity as logits 289 | logit_scale = self.logit_scale.exp() 290 | logits_per_image = logit_scale * image_features @ text_features.t() 291 | logits_per_text = logits_per_image.t() 292 | 293 | # shape = [global_batch_size, global_batch_size] 294 | return logits_per_image, logits_per_text 295 | 296 | 297 | def convert_weights(model: nn.Module): 298 | """Convert applicable model parameters to fp16""" 299 | 300 | def _convert_weights_to_fp16(l): 301 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 302 | l.weight.data = l.weight.data.half() 303 | if l.bias is not None: 304 | l.bias.data = l.bias.data.half() 305 | 306 | if isinstance(l, nn.MultiheadAttention): 307 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 308 | tensor = getattr(l, attr) 309 | if tensor is not None: 310 | tensor.data = tensor.data.half() 311 | 312 | for name in ["text_projection", "proj"]: 313 | if hasattr(l, name): 314 | attr = getattr(l, name) 315 | if attr is not None: 316 | attr.data = attr.data.half() 317 | 318 | model.apply(_convert_weights_to_fp16) 319 | 320 | 321 | def build_model_features(state_dict: dict, name: str = "", fp16: bool = True): 322 | vit = "visual.proj" in state_dict 323 | 324 | if vit: 325 | vision_width = state_dict["visual.conv1.weight"].shape[0] 326 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 327 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 328 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 329 | image_resolution = vision_patch_size * grid_size 330 | else: 331 | raise ValueError("Modified to work only with ViT image encoders...") 332 | 333 | output_embed_dim = state_dict["text_projection"].shape[1] 334 | context_length = state_dict["positional_embedding"].shape[0] 335 | vocab_size = state_dict["token_embedding.weight"].shape[0] 336 | transformer_width = state_dict["ln_final.weight"].shape[0] 337 | transformer_heads = transformer_width // 64 338 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) 339 | 340 | model = CLIP( 341 | output_embed_dim, 342 | image_resolution, vision_layers, vision_width, vision_patch_size, 343 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 344 | ) 345 | 346 | for key in ["input_resolution", "context_length", "vocab_size"]: 347 | if key in state_dict: 348 | del state_dict[key] 349 | 350 | if fp16: 351 | convert_weights(model) 352 | 353 | model.load_state_dict(state_dict) 354 | 355 | print(f"Model stats for {name}") 356 | print(f"- output_embed_dim: {output_embed_dim}") 357 | print(f"- vision_width: {vision_width}") 358 | print(f"- vision_layers: {vision_layers}") 359 | print(f"- vision_patch_size: {vision_patch_size}") 360 | print(f"- vision_heads: {model.vision_heads}") 361 | print(f"- grid_size: {grid_size}") 362 | print(f"- image_resolution: {image_resolution}") 363 | 
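A short, hedged usage sketch of the `build_model_features` helper above. The function name and its arguments come from the code in this file; the checkpoint path and the way the state dict is obtained (via `torch.jit.load` on an OpenAI CLIP checkpoint, then reading its `state_dict()`) are assumptions for illustration only.

```python
# Hypothetical usage sketch (not part of the repository code).
# Assumes a downloaded ViT-L/14 CLIP checkpoint at the path below.
import torch

ckpt_path = "ViT-L-14.pt"  # assumed local path to an OpenAI CLIP checkpoint
jit_model = torch.jit.load(ckpt_path, map_location="cpu")
state_dict = jit_model.state_dict()

# Build the modified CLIP model that also returns attention weights.
model = build_model_features(state_dict, name="ViT-L/14", fp16=False)

# Encode a dummy batch of images; encode_image returns (features, attention_weights).
images = torch.randn(2, 3, model.visual.input_resolution, model.visual.input_resolution)
with torch.no_grad():
    features, attn_weights = model.encode_image(images)
print(features.shape)  # [2, output_embed_dim]
```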
print(f"- context_length: {context_length}") 364 | print(f"- vocab_size: {vocab_size}") 365 | print(f"- transformer_width: {transformer_width}") 366 | print(f"- transformer_heads: {transformer_heads}") 367 | print(f"- transformer_layers: {transformer_layers}") 368 | print(f"- total number of parameters: {np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}") 369 | return model.eval() -------------------------------------------------------------------------------- /models/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /models/clip_models.py: -------------------------------------------------------------------------------- 1 | from .clip import clip 2 | import torch.nn as nn 3 | import torch 4 | import torch.nn.functional as F 5 | import models.networks.customnet as customnetworks 6 | from models.clip.model import ResidualAttentionBlock 7 | import re 8 | 9 | # Model for localisation 10 | class CLIPModelLocalisation(nn.Module): 11 | def __init__(self, name, intermidiate_layer_output = 
None, decoder_type = "conv-4"): 12 | super(CLIPModelLocalisation, self).__init__() 13 | 14 | self.intermidiate_layer_output = intermidiate_layer_output 15 | self.decoder_type = decoder_type 16 | self.name = name # architecure 17 | 18 | if self.intermidiate_layer_output: 19 | assert "layer" in self.intermidiate_layer_output or "all" in self.intermidiate_layer_output or "xceptionnet" in self.intermidiate_layer_output 20 | 21 | self._set_backbone() 22 | self._set_decoder() 23 | 24 | def _set_backbone(self): 25 | # Set up the backbone model architecture and parameters 26 | if self.name in ["RN50", "ViT-L/14"]: 27 | self.model, self.preprocess = clip.load(self.name, device="cpu", intermidiate_layers = (self.intermidiate_layer_output != None)) 28 | elif self.name == "xceptionnet": 29 | # XceptionNet 30 | layername = 'block%d' % 2 31 | extra_output = 'block%d' % 1 32 | self.model = customnetworks.make_patch_xceptionnet( 33 | layername=layername, extra_output=extra_output, num_classes=2) 34 | # ViT+RN fusion 35 | elif "RN50" in self.name and "ViT-L/14" in self.name: 36 | name = self.name.split(",") 37 | model1, self.preprocess = clip.load(name[0], device="cpu", intermidiate_layers = (self.intermidiate_layer_output != None)) 38 | model2, self.preprocess = clip.load(name[1], device="cpu", intermidiate_layers = (self.intermidiate_layer_output != None)) 39 | self.model = [model1.to("cuda"), model2.to("cuda")] 40 | 41 | def _set_decoder(self): 42 | # Set up decoder architecture 43 | upscaling_layers = [] 44 | if "conv" in self.decoder_type: 45 | filter_sizes = self._get_conv_filter_sizes(self.name, self.intermidiate_layer_output, self.decoder_type) 46 | num_convs = int(re.search(r'\d{0,3}$', self.decoder_type).group()) 47 | 48 | for i in range(1, len(filter_sizes)): 49 | upscaling_layers.append(nn.Conv2d(filter_sizes[i-1], filter_sizes[i], kernel_size=5, padding=2)) 50 | upscaling_layers.append(nn.BatchNorm2d(filter_sizes[i])) 51 | upscaling_layers.append(nn.ReLU()) 52 | for _ in range(num_convs//4 - 1): 53 | upscaling_layers.append(nn.Conv2d(filter_sizes[i], filter_sizes[i], kernel_size=5, padding=2)) 54 | upscaling_layers.append(nn.BatchNorm2d(filter_sizes[i])) 55 | upscaling_layers.append(nn.ReLU()) 56 | 57 | # skip some upscaling layers if the input is too large (case for CNNs) 58 | skip_upscaling = ( 59 | self.intermidiate_layer_output == "layer2" and i == 1 60 | or self.intermidiate_layer_output == "layer1" and i <= 2 61 | ) and ("RN50" in self.name or "xceptionnet" in self.name) 62 | if skip_upscaling: 63 | continue 64 | 65 | upscaling_layers.append(nn.Upsample(scale_factor=2, mode='bilinear')) 66 | 67 | # CNNs output may not be in (256, 256) - usually a (224, 224) size 68 | if "RN50" in self.name or "xceptionnet" in self.name: 69 | upscaling_layers.append(nn.Upsample(size=(256, 256), mode='bilinear')) 70 | 71 | upscaling_layers.append(nn.Conv2d(64, 1, kernel_size=5, padding=2)) 72 | 73 | elif self.decoder_type == "linear": 74 | # Xceptionnet 75 | if self.name == "xceptionnet": 76 | upscaling_layers.append(nn.Linear(784, 1)) 77 | # CLIP 78 | else: 79 | upscaling_layers.append(nn.Linear(1024, 1)) 80 | 81 | elif self.decoder_type == "attention": 82 | transformer_width = 1024 83 | transformer_heads = transformer_width // 64 84 | attn_mask = self._build_attention_mask() 85 | self.att1 = ResidualAttentionBlock(transformer_width, transformer_heads, attn_mask) 86 | self.att2 = ResidualAttentionBlock(transformer_width, transformer_heads, attn_mask) 87 | upscaling_layers.append(nn.Linear(1024, 1)) 88 | 
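To make the decoder construction above more concrete, here is a minimal, hedged sketch of how the convolutional decoder's shape progression works for the ViT case (`filter_sizes = [1024, 512, 256, 128, 64]`, `conv-20`): each of the four stages applies several 5x5 convolutions followed by a 2x upsampling, taking a 16x16 feature map to 256x256 before the final 1-channel projection. The numbers below are derived from the code in `_set_decoder` and `_get_conv_filter_sizes`; the standalone layer list is an illustration, not the repository's own API.

```python
# Illustrative shape walk-through for decoder_type="conv-20" on ViT features.
# Input: [batch, 1024, 16, 16] (ViT-L/14 patch tokens reshaped to a grid).
# Stage i (i = 1..4): Conv2d(filter_sizes[i-1] -> filter_sizes[i]) + BN + ReLU,
#   then (20 // 4 - 1) = 4 extra Conv+BN+ReLU blocks at the same width,
#   then Upsample(scale_factor=2).
# Spatial size:   16 -> 32 -> 64 -> 128 -> 256
# Channel width: 1024 -> 512 -> 256 -> 128 -> 64
# Final layer: Conv2d(64, 1, kernel_size=5, padding=2) -> [batch, 1, 256, 256]
```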
89 | self.fc = nn.Sequential(*upscaling_layers) 90 | 91 | def _get_conv_filter_sizes(self, name, intermidiate_layer_output, decoder_type): 92 | assert "conv" in decoder_type 93 | 94 | if "RN50" in name and "ViT-L/14" in name: 95 | num_layers = len(name.split(",")) 96 | return [1024*num_layers, 512, 256, 128, 64] 97 | elif "RN50" in name: 98 | if intermidiate_layer_output == "layer1": 99 | return [256, 512, 256, 128, 64] 100 | elif intermidiate_layer_output == "layer2": 101 | return [512, 512, 256, 128, 64] 102 | elif intermidiate_layer_output == "layer3": 103 | return [1024, 512, 256, 128, 64] 104 | elif intermidiate_layer_output == "layer4": 105 | return [2048, 512, 256, 128, 64] 106 | elif "xceptionnet" in name: 107 | return [256, 512, 256, 128, 64] 108 | else: 109 | return [1024, 512, 256, 128, 64] 110 | 111 | def _unify_linear_layer_outputs(self, linear_outputs): 112 | output = torch.cat(linear_outputs, dim=1) 113 | output = output.view(output.size()[0], int(output.size()[1]**0.5), int(output.size()[1]**0.5)) 114 | output = torch.nn.functional.interpolate(output.unsqueeze(1), size = (256, 256), mode = 'bicubic') 115 | return output 116 | 117 | # standard CLIP method 118 | def _build_attention_mask(self): 119 | # Lazily create the causal attention mask, with full attention between the vision tokens. 120 | # PyTorch uses an additive attention mask, so fill with -inf and zero the lower triangle. 121 | context_length = 257 122 | mask = torch.empty(context_length, context_length) 123 | mask.fill_(float("-inf")) 124 | mask.triu_(1) # zero out the lower diagonal 125 | return mask 126 | 127 | def _feature_map_transform(self, input): 128 | output = input.permute(1, 2, 0) 129 | output = output.view(output.size()[0], output.size()[1], int(output.size()[2]**0.5), int(output.size()[2]**0.5)) 130 | return output 131 | 132 | def feature_extraction(self, x): 133 | if self.name == "RN50" or self.name == "ViT-L/14": 134 | features = self.model.encode_image(x) 135 | if self.intermidiate_layer_output: 136 | features = features[self.intermidiate_layer_output] 137 | # otherwise choose the last layer 138 | else: 139 | if self.name == "RN50": 140 | features = features["layer4"] 141 | else: 142 | features = features["layer23"] 143 | # ViT+RN fusion 144 | elif "RN50" in self.name and "ViT-L/14" in self.name: 145 | # given ViT feature layer 146 | features_vit = self.model[0].encode_image(x)[self.intermidiate_layer_output] 147 | features_vit = self._feature_map_transform(features_vit[1:]) 148 | # explicitly use the RN50 3rd layer to match the feature dimension 149 | features_rn50 = self.model[1].encode_image(x)["layer3"] 150 | features_rn50 = F.interpolate(features_rn50, size=(16, 16), mode='bilinear', align_corners=False) 151 | features = torch.cat([features_vit, features_rn50], 1) 152 | # for xceptionnet 153 | else: 154 | features = self.model(x) 155 | features = features[0] 156 | 157 | return features 158 | 159 | def forward(self, x): 160 | # Feature extraction 161 | features = self.feature_extraction(x) 162 | 163 | # Forward step 164 | # ViT+RN fusion convolutional decoder 165 | if "RN50" in self.name and "ViT-L/14" in self.name and "conv" in self.decoder_type: 166 | output = self.fc(features) 167 | 168 | # Linear decoder 169 | elif self.decoder_type == "linear": 170 | # xceptionnet + linear 171 | if self.name == "xceptionnet": 172 | features = features.view(features.size()[0], features.size()[1], -1) 173 | features = features.permute(1, 0, 2) 174 | linear_outputs = [self.fc(input_part) for input_part in features[0:]] 175 | # CLIP + linear 176 | else: 177 | 
linear_outputs = [self.fc(input_part) for input_part in features[1:]] 178 | 179 | output = self._unify_linear_layer_outputs(linear_outputs) 180 | 181 | # Attention decoder 182 | elif self.decoder_type == "attention": 183 | features = self.att1(features) 184 | features = self.att2(features) 185 | linear_outputs = [self.fc(input_part) for input_part in features[1:]] 186 | output = self._unify_linear_layer_outputs(linear_outputs) 187 | 188 | # Convolutional decoder over RN 189 | elif "conv" in self.decoder_type and "RN50" == self.name: 190 | output = self.fc(features) 191 | 192 | # Convolutional decoder over ViT 193 | else: 194 | features = features[1:] 195 | output = self._feature_map_transform(features) 196 | output = self.fc(output) 197 | 198 | output = torch.flatten(output, start_dim =1) 199 | return output -------------------------------------------------------------------------------- /models/networks/__init__.py: -------------------------------------------------------------------------------- 1 | # models from https://github.com/bit-ml/dolos/tree/main/dolos/methods/patch_forensics/networks -------------------------------------------------------------------------------- /models/networks/customnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Customized version of pytorch resnet, alexnets. 3 | ''' 4 | import models.networks.xception 5 | import sys 6 | from torchvision.models import resnet 7 | from collections import OrderedDict 8 | from torch import nn 9 | import os 10 | import math 11 | import torch 12 | import numpy 13 | # current_path = os.path.dirname(os.path.abspath(__file__)) 14 | # sys.path.append(current_path) 15 | 16 | 17 | def make_patch_resnet(depth, layername, extra_output, num_classes=2): 18 | def change_out(layers): 19 | ind, layer = [(i, l) for i, (n, l) in enumerate(layers) 20 | if n == layername][0] 21 | if layername.startswith('layer'): 22 | bn = list(layer.modules())[-1 if depth < 23 | 50 else -2] # find final batchnorm 24 | assert (isinstance(bn, nn.BatchNorm2d)) 25 | num_ch = bn.num_features 26 | else: 27 | num_ch = 64 28 | layers[ind+1:] = [('convout', nn.Conv2d(num_ch, 29 | num_classes, kernel_size=1))] 30 | return layers 31 | model = CustomResNet(depth, modify_sequence=change_out, 32 | extra_output=extra_output) 33 | return model 34 | 35 | 36 | def make_patch_xceptionnet(layername, extra_output, frontend=None, num_classes=2): 37 | def change_out(layers): 38 | ind, layer = [(i, l) for i, (n, l) in enumerate(layers) 39 | if n == layername][0] 40 | if layername.startswith('block'): 41 | module_list = list(layer.modules()) 42 | bn = module_list[-1] # hack to find final batchnorm 43 | if not isinstance(bn, nn.BatchNorm2d): 44 | bn = module_list[-2] 45 | assert (isinstance(bn, nn.BatchNorm2d)) 46 | num_ch = bn.num_features 47 | elif layername.startswith('relu'): 48 | bn = layers[ind-1][1] 49 | assert (isinstance(bn, nn.BatchNorm2d)) 50 | num_ch = bn.num_features 51 | else: 52 | raise NotImplementedError 53 | # FC layer for localisation 54 | # layers[ind+1:] = [('convout', nn.Conv2d(num_ch, 55 | # num_classes, kernel_size=1))] 56 | layers[ind+1:] = [] 57 | return layers 58 | model = CustomXceptionNet( 59 | modify_sequence=change_out, frontend=frontend, extra_output=extra_output) 60 | return model 61 | 62 | 63 | def make_xceptionnet_long(): 64 | # a modified xception net with blocks of kernel size 1 65 | from . 
import xception 66 | 67 | def change_out(layers): 68 | channels = [3, 32, 64, 128, 256, 728, 728, 728, 728, 728, 728, 728, 69 | 728, 728, 1024, 1536, 2048] 70 | ind, layer = [(i, l) for i, (n, l) in enumerate(layers) 71 | if n == 'block2'][0] 72 | new_layers = [ 73 | # made all strides = 1 74 | ('pblock3', models.networks.xception.PixelBlock(channels[4], channels[5], 75 | 2, 1, start_with_relu=True, grow_first=True)), 76 | ('pblock4', models.networks.xception.PixelBlock(channels[5], channels[6], 77 | 3, 1, start_with_relu=True, grow_first=True)), 78 | ] 79 | num_ch = channels[9] 80 | new_layers.append(('convout', nn.Conv2d(num_ch, 2, kernel_size=1))) 81 | layers[ind+1:] = new_layers 82 | return layers 83 | model = CustomXceptionNet(modify_sequence=change_out) 84 | return model 85 | 86 | 87 | class CustomResNet(nn.Module): 88 | ''' 89 | Customizable ResNet, compatible with pytorch's resnet, but: 90 | * The top-level sequence of modules can be modified to add 91 | or remove or alter layers. 92 | * Extra outputs can be produced, to allow backprop and access 93 | to internal features. 94 | * Pooling is replaced by resizable GlobalAveragePooling so that 95 | any size can be input (e.g., any multiple of 32 pixels). 96 | * halfsize=True halves striding on the first pooling to 97 | set the default size to 112x112 instead of 224x224. 98 | ''' 99 | 100 | def __init__(self, size=None, block=None, layers=None, num_classes=1000, 101 | extra_output=None, modify_sequence=None, halfsize=False): 102 | standard_sizes = { 103 | 18: (resnet.BasicBlock, [2, 2, 2, 2]), 104 | 34: (resnet.BasicBlock, [3, 4, 6, 3]), 105 | 50: (resnet.Bottleneck, [3, 4, 6, 3]), 106 | 101: (resnet.Bottleneck, [3, 4, 23, 3]), 107 | 152: (resnet.Bottleneck, [3, 8, 36, 3]) 108 | } 109 | assert (size in standard_sizes) == (block is None) == (layers is None) 110 | if size in standard_sizes: 111 | block, layers = standard_sizes[size] 112 | if modify_sequence is None: 113 | def modify_sequence(x): return x 114 | self.inplanes = 64 115 | norm_layer = nn.BatchNorm2d 116 | self._norm_layer = norm_layer # for recent resnet 117 | self.dilation = 1 118 | self.groups = 1 119 | self.base_width = 64 120 | sequence = modify_sequence([ 121 | ('conv1', nn.Conv2d(3, 64, kernel_size=7, stride=2, 122 | padding=3, bias=False)), 123 | ('bn1', norm_layer(64)), 124 | ('relu', nn.ReLU(inplace=True)), 125 | ('maxpool', nn.MaxPool2d(3, stride=1 if halfsize else 2, 126 | padding=1)), 127 | ('layer1', self._make_layer(block, 64, layers[0])), 128 | ('layer2', self._make_layer(block, 128, layers[1], stride=2)), 129 | ('layer3', self._make_layer(block, 256, layers[2], stride=2)), 130 | ('layer4', self._make_layer(block, 512, layers[3], stride=2)), 131 | ('avgpool', GlobalAveragePool2d()), 132 | ('fc', nn.Linear(512 * block.expansion, num_classes)) 133 | ]) 134 | super(CustomResNet, self).__init__() 135 | for name, layer in sequence: 136 | setattr(self, name, layer) 137 | self.extra_output = extra_output 138 | 139 | def _make_layer(self, block, channels, depth, stride=1): 140 | return resnet.ResNet._make_layer(self, block, channels, depth, stride) 141 | 142 | def forward(self, x): 143 | extra = [] 144 | for name, module in self._modules.items(): 145 | x = module(x) 146 | if self.extra_output and name in self.extra_output: 147 | extra.append(x) 148 | if self.extra_output: 149 | return (x,) + tuple(extra) 150 | return x 151 | 152 | 153 | class CustomXceptionNet(nn.Module): 154 | ''' 155 | Customizable Xceptionnet, compatible with 
https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/xception.py 156 | but: 157 | * The top-level sequence of modules can be modified to add 158 | or remove or alter layers. 159 | * Extra outputs can be produced, to allow backprop and access 160 | to internal features. 161 | * halfsize=True halves striding on the first convolution to 162 | allow 151x151 images to be processed rather than 299x299 only. 163 | ''' 164 | 165 | def __init__(self, channels=None, num_classes=1000, 166 | extra_output=None, modify_sequence=None, frontend=None, 167 | halfsize=False): 168 | # from . import xception 169 | # from .heads import LFS_Head, FAD_Head, LFS_FAD_Head 170 | 171 | if channels is None: 172 | channels = [3, 32, 64, 128, 256, 728, 728, 728, 728, 728, 728, 728, 173 | 728, 728, 1024, 1536, 2048] 174 | assert (len(channels) == 17) 175 | if modify_sequence is None: 176 | def modify_sequence(x): return x 177 | ''' 178 | if frontend == "lfs": 179 | channels[0] = 6 180 | layers = [("lfs", LFS_Head())] 181 | elif frontend == "fad": 182 | channels[0] = 12 183 | layers = [("fad", FAD_Head(299))] 184 | elif frontend == "lfs+fad": 185 | channels[0] = 18 186 | layers = [("lfs+fad", LFS_FAD_Head(299))] 187 | el 188 | ''' 189 | if frontend == None: 190 | layers = [] 191 | else: 192 | assert False 193 | 194 | sequence = modify_sequence(layers + [ 195 | ('conv1', nn.Conv2d(channels[0], channels[1], kernel_size=3, 196 | stride=1 if halfsize else 2, padding=0, 197 | bias=False)), 198 | ('bn1', nn.BatchNorm2d(channels[1])), 199 | ('relu1', nn.ReLU(inplace=True)), 200 | ('conv2', nn.Conv2d(channels[1], channels[2], 3, bias=False)), 201 | ('bn2', nn.BatchNorm2d(channels[2])), 202 | ('relu2', nn.ReLU(inplace=True)), 203 | ('block1', models.networks.xception.Block(channels[2], 204 | channels[3], 2, 2, start_with_relu=False, grow_first=True)), 205 | ('block2', models.networks.xception.Block(channels[3], channels[4], 206 | 2, 2, start_with_relu=True, grow_first=True)), 207 | ('block3', models.networks.xception.Block(channels[4], channels[5], 208 | 2, 2, start_with_relu=True, grow_first=True)), 209 | ('block4', models.networks.xception.Block(channels[5], channels[6], 210 | 3, 1, start_with_relu=True, grow_first=True)), 211 | ('block5', models.networks.xception.Block(channels[6], channels[7], 212 | 3, 1, start_with_relu=True, grow_first=True)), 213 | ('block6', models.networks.xception.Block(channels[7], channels[8], 214 | 3, 1, start_with_relu=True, grow_first=True)), 215 | ('block7', models.networks.xception.Block(channels[8], channels[9], 216 | 3, 1, start_with_relu=True, grow_first=True)), 217 | ('block8', models.networks.xception.Block(channels[9], channels[10], 218 | 3, 1, start_with_relu=True, grow_first=True)), 219 | ('block9', models.networks.xception.Block(channels[10], channels[11], 220 | 3, 1, start_with_relu=True, grow_first=True)), 221 | ('block10', models.networks.xception.Block(channels[11], channels[12], 222 | 3, 1, start_with_relu=True, grow_first=True)), 223 | ('block11', models.networks.xception.Block(channels[12], channels[13], 224 | 3, 1, start_with_relu=True, grow_first=True)), 225 | ('block12', models.networks.xception.Block(channels[13], channels[14], 226 | 2, 2, start_with_relu=True, grow_first=False)), 227 | ('conv3', models.networks.xception.SeparableConv2d(channels[14], channels[15], 228 | 3, 1, 1)), 229 | ('bn3', nn.BatchNorm2d(channels[15])), 230 | ('relu3', nn.ReLU(inplace=True)), 231 | ('conv4', models.networks.xception.SeparableConv2d(channels[15], 
channels[16], 232 | 3, 1, 1)), 233 | ('bn4', nn.BatchNorm2d(channels[16])), 234 | ('relu4', nn.ReLU(inplace=True)), 235 | # does adaptive_avg_pool and flatten 236 | ('avgpool', GlobalAveragePool2d()), 237 | ('fc', nn.Linear(channels[16], num_classes)) 238 | ]) 239 | 240 | super(CustomXceptionNet, self).__init__() 241 | for name, layer in sequence: 242 | setattr(self, name, layer) 243 | self.extra_output = extra_output 244 | 245 | def forward(self, x): 246 | extra = [] 247 | for name, module in self._modules.items(): 248 | x = module(x) 249 | if self.extra_output and name in self.extra_output: 250 | extra.append(x) 251 | if self.extra_output: 252 | return (x,) + tuple(extra) 253 | return x 254 | 255 | 256 | class Vectorize(nn.Module): 257 | def __init__(self): 258 | super(Vectorize, self).__init__() 259 | 260 | def forward(self, x): 261 | x = x.view(x.size(0), int(numpy.prod(x.size()[1:]))) 262 | return x 263 | 264 | 265 | class GlobalAveragePool2d(nn.Module): 266 | def __init__(self): 267 | super(GlobalAveragePool2d, self).__init__() 268 | 269 | def forward(self, x): 270 | x = torch.mean(x.view(x.size(0), x.size(1), -1), dim=2) 271 | return x 272 | 273 | 274 | if __name__ == '__main__': 275 | import torch.utils.model_zoo as model_zoo 276 | # Verify that at the default settings, pytorch standard pretrained 277 | # models can be loaded into each of the custom nets. 278 | print('Loading resnet18') 279 | model = CustomResNet(18) 280 | model.load_state_dict(model_zoo.load_url(resnet.ResNet18_Weights.IMAGENET1K_V2)) 281 | print('Loading resnet34') 282 | model = CustomResNet(34) 283 | model.load_state_dict(model_zoo.load_url(resnet.ResNet34_Weights.IMAGENET1K_V2)) 284 | print('Loading resnet50') 285 | model = CustomResNet(50) 286 | model.load_state_dict(model_zoo.load_url(resnet.ResNet50_Weights.IMAGENET1K_V2)) 287 | print('Loading resnet101') 288 | model = CustomResNet(101) 289 | model.load_state_dict(model_zoo.load_url(resnet.ResNet101_Weights.IMAGENET1K_V2)) 290 | print('Loading resnet152') 291 | model = CustomResNet(152) 292 | model.load_state_dict(model_zoo.load_url(resnet.ResNet152_Weights.IMAGENET1K_V2)) 293 | 294 | print('Loading xceptionnet') 295 | model = CustomXceptionNet() 296 | model.load_state_dict(model_zoo.load_url( 297 | 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth')) 298 | -------------------------------------------------------------------------------- /models/networks/xception.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source:https://raw.githubusercontent.com/Cadene/pretrained-models.pytorch/master/pretrainedmodels/models/xception.py 3 | Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch) 4 | 5 | @author: tstandley 6 | Adapted by cadene 7 | 8 | Creates an Xception Model as defined in: 9 | 10 | Francois Chollet 11 | Xception: Deep Learning with Depthwise Separable Convolutions 12 | https://arxiv.org/pdf/1610.02357.pdf 13 | 14 | This weights ported from the Keras implementation. 
Achieves the following performance on the validation set: 15 | 16 | Loss:0.9173 Prec@1:78.892 Prec@5:94.292 17 | 18 | REMEMBER to set your image size to 3x299x299 for both test and validation 19 | 20 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], 21 | std=[0.5, 0.5, 0.5]) 22 | 23 | The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 24 | """ 25 | from __future__ import print_function, division, absolute_import 26 | import math 27 | import torch 28 | import torch.nn as nn 29 | import torch.nn.functional as F 30 | import torch.utils.model_zoo as model_zoo 31 | from torch.nn import init 32 | 33 | __all__ = ['xception'] 34 | 35 | pretrained_settings = { 36 | 'xception': { 37 | 'imagenet': { 38 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth', 39 | 'input_space': 'RGB', 40 | 'input_size': [3, 299, 299], 41 | 'input_range': [0, 1], 42 | 'mean': [0.5, 0.5, 0.5], 43 | 'std': [0.5, 0.5, 0.5], 44 | 'num_classes': 1000, 45 | 'scale': 0.8975 # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 46 | } 47 | } 48 | } 49 | 50 | 51 | class SeparableConv2d(nn.Module): 52 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False): 53 | super(SeparableConv2d, self).__init__() 54 | 55 | self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, 56 | stride, padding, dilation, groups=in_channels, bias=bias) 57 | self.pointwise = nn.Conv2d( 58 | in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias) 59 | 60 | def forward(self, x): 61 | x = self.conv1(x) 62 | x = self.pointwise(x) 63 | return x 64 | 65 | # modified to replace 3x3 convs with 1x1 convs 66 | 67 | 68 | class PixelBlock(nn.Module): 69 | def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True): 70 | super(PixelBlock, self).__init__() 71 | assert (strides == 1) 72 | if out_filters != in_filters or strides != 1: 73 | self.skip = nn.Conv2d(in_filters, out_filters, 74 | 1, stride=strides, bias=False) 75 | self.skipbn = nn.BatchNorm2d(out_filters) 76 | else: 77 | self.skip = None 78 | 79 | rep = [] 80 | 81 | filters = in_filters 82 | if grow_first: 83 | rep.append(nn.ReLU(inplace=True)) 84 | rep.append(SeparableConv2d(in_filters, out_filters, 85 | 1, stride=1, padding=0, bias=False)) 86 | rep.append(nn.BatchNorm2d(out_filters)) 87 | filters = out_filters 88 | 89 | for i in range(reps-1): 90 | rep.append(nn.ReLU(inplace=True)) 91 | rep.append(SeparableConv2d(filters, filters, 92 | 1, stride=1, padding=0, bias=False)) 93 | rep.append(nn.BatchNorm2d(filters)) 94 | 95 | if not grow_first: 96 | rep.append(nn.ReLU(inplace=True)) 97 | rep.append(SeparableConv2d(in_filters, out_filters, 98 | 1, stride=1, padding=0, bias=False)) 99 | rep.append(nn.BatchNorm2d(out_filters)) 100 | 101 | if not start_with_relu: 102 | rep = rep[1:] 103 | else: 104 | rep[0] = nn.ReLU(inplace=False) 105 | 106 | if strides != 1: 107 | pass 108 | # take out the maxpool 109 | # rep.append(nn.MaxPool2d(3,strides,1)) 110 | self.rep = nn.Sequential(*rep) 111 | 112 | def forward(self, inp): 113 | x = self.rep(inp) 114 | 115 | if self.skip is not None: 116 | skip = self.skip(inp) 117 | skip = self.skipbn(skip) 118 | else: 119 | skip = inp 120 | 121 | x += skip 122 | return x 123 | 124 | 125 | class Block(nn.Module): 126 | def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True): 127 | super(Block, self).__init__() 128 | 
129 | if out_filters != in_filters or strides != 1: 130 | self.skip = nn.Conv2d(in_filters, out_filters, 131 | 1, stride=strides, bias=False) 132 | self.skipbn = nn.BatchNorm2d(out_filters) 133 | else: 134 | self.skip = None 135 | 136 | rep = [] 137 | 138 | filters = in_filters 139 | if grow_first: 140 | rep.append(nn.ReLU(inplace=True)) 141 | rep.append(SeparableConv2d(in_filters, out_filters, 142 | 3, stride=1, padding=1, bias=False)) 143 | rep.append(nn.BatchNorm2d(out_filters)) 144 | filters = out_filters 145 | 146 | for i in range(reps-1): 147 | rep.append(nn.ReLU(inplace=True)) 148 | rep.append(SeparableConv2d(filters, filters, 149 | 3, stride=1, padding=1, bias=False)) 150 | rep.append(nn.BatchNorm2d(filters)) 151 | 152 | if not grow_first: 153 | rep.append(nn.ReLU(inplace=True)) 154 | rep.append(SeparableConv2d(in_filters, out_filters, 155 | 3, stride=1, padding=1, bias=False)) 156 | rep.append(nn.BatchNorm2d(out_filters)) 157 | 158 | if not start_with_relu: 159 | rep = rep[1:] 160 | else: 161 | rep[0] = nn.ReLU(inplace=False) 162 | 163 | if strides != 1: 164 | rep.append(nn.MaxPool2d(3, strides, 1)) 165 | self.rep = nn.Sequential(*rep) 166 | 167 | def forward(self, inp): 168 | x = self.rep(inp) 169 | 170 | if self.skip is not None: 171 | skip = self.skip(inp) 172 | skip = self.skipbn(skip) 173 | else: 174 | skip = inp 175 | 176 | x += skip 177 | return x 178 | 179 | 180 | class Xception(nn.Module): 181 | """ 182 | Xception optimized for the ImageNet dataset, as specified in 183 | https://arxiv.org/pdf/1610.02357.pdf 184 | """ 185 | 186 | def __init__(self, num_classes=1000): 187 | """ Constructor 188 | Args: 189 | num_classes: number of classes 190 | """ 191 | super(Xception, self).__init__() 192 | self.num_classes = num_classes 193 | 194 | self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False) 195 | self.bn1 = nn.BatchNorm2d(32) 196 | self.relu1 = nn.ReLU(inplace=True) 197 | 198 | self.conv2 = nn.Conv2d(32, 64, 3, bias=False) 199 | self.bn2 = nn.BatchNorm2d(64) 200 | self.relu2 = nn.ReLU(inplace=True) 201 | # do relu here 202 | 203 | self.block1 = Block( 204 | 64, 128, 2, 2, start_with_relu=False, grow_first=True) 205 | self.block2 = Block( 206 | 128, 256, 2, 2, start_with_relu=True, grow_first=True) 207 | self.block3 = Block( 208 | 256, 728, 2, 2, start_with_relu=True, grow_first=True) 209 | 210 | self.block4 = Block( 211 | 728, 728, 3, 1, start_with_relu=True, grow_first=True) 212 | self.block5 = Block( 213 | 728, 728, 3, 1, start_with_relu=True, grow_first=True) 214 | self.block6 = Block( 215 | 728, 728, 3, 1, start_with_relu=True, grow_first=True) 216 | self.block7 = Block( 217 | 728, 728, 3, 1, start_with_relu=True, grow_first=True) 218 | 219 | self.block8 = Block( 220 | 728, 728, 3, 1, start_with_relu=True, grow_first=True) 221 | self.block9 = Block( 222 | 728, 728, 3, 1, start_with_relu=True, grow_first=True) 223 | self.block10 = Block( 224 | 728, 728, 3, 1, start_with_relu=True, grow_first=True) 225 | self.block11 = Block( 226 | 728, 728, 3, 1, start_with_relu=True, grow_first=True) 227 | 228 | self.block12 = Block( 229 | 728, 1024, 2, 2, start_with_relu=True, grow_first=False) 230 | 231 | self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1) 232 | self.bn3 = nn.BatchNorm2d(1536) 233 | self.relu3 = nn.ReLU(inplace=True) 234 | 235 | # do relu here 236 | self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1) 237 | self.bn4 = nn.BatchNorm2d(2048) 238 | 239 | self.fc = nn.Linear(2048, num_classes) 240 | 241 | # #------- init weights -------- 242 | # for m in self.modules(): 243 | # if 
isinstance(m, nn.Conv2d): 244 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 245 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 246 | # elif isinstance(m, nn.BatchNorm2d): 247 | # m.weight.data.fill_(1) 248 | # m.bias.data.zero_() 249 | # #----------------------------- 250 | 251 | def features(self, input): 252 | x = self.conv1(input) 253 | x = self.bn1(x) 254 | x = self.relu1(x) 255 | 256 | x = self.conv2(x) 257 | x = self.bn2(x) 258 | x = self.relu2(x) 259 | 260 | x = self.block1(x) 261 | x = self.block2(x) 262 | x = self.block3(x) 263 | x = self.block4(x) 264 | x = self.block5(x) 265 | x = self.block6(x) 266 | x = self.block7(x) 267 | x = self.block8(x) 268 | x = self.block9(x) 269 | x = self.block10(x) 270 | x = self.block11(x) 271 | x = self.block12(x) 272 | 273 | x = self.conv3(x) 274 | x = self.bn3(x) 275 | x = self.relu3(x) 276 | 277 | x = self.conv4(x) 278 | x = self.bn4(x) 279 | return x 280 | 281 | def logits(self, features): 282 | x = nn.ReLU(inplace=True)(features) 283 | 284 | x = F.adaptive_avg_pool2d(x, (1, 1)) 285 | x = x.view(x.size(0), -1) 286 | x = self.last_linear(x) 287 | return x 288 | 289 | def forward(self, input): 290 | x = self.features(input) 291 | x = self.logits(x) 292 | return x 293 | 294 | 295 | def xception(num_classes=1000, pretrained='imagenet'): 296 | model = Xception(num_classes=num_classes) 297 | if pretrained: 298 | settings = pretrained_settings['xception'][pretrained] 299 | 300 | # load matching keys (except for final fc layer) 301 | model = Xception(num_classes=num_classes) 302 | pretrained_state = model_zoo.load_url(settings['url']) 303 | model_state = model.state_dict() 304 | pretrained_state = {k: v for k, v in pretrained_state.items() 305 | if k in model_state and v.size() == model_state[k].size()} 306 | print(list(pretrained_state.keys())) 307 | model_state.update(pretrained_state) 308 | model.load_state_dict(model_state) 309 | 310 | model.input_space = settings['input_space'] 311 | model.input_size = settings['input_size'] 312 | model.input_range = settings['input_range'] 313 | model.mean = settings['mean'] 314 | model.std = settings['std'] 315 | 316 | # TODO: ugly 317 | model.last_linear = model.fc 318 | del model.fc 319 | return model 320 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bit-ml/DeCLIP/09d4293c78ed648b103b7c453e7d4e47be0707f3/networks/__init__.py -------------------------------------------------------------------------------- /networks/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import init 5 | 6 | 7 | class BaseModel(nn.Module): 8 | def __init__(self, opt): 9 | super(BaseModel, self).__init__() 10 | self.opt = opt 11 | self.total_steps = 0 12 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 13 | self.device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu') 14 | 15 | def save_networks(self, save_filename): 16 | save_path = os.path.join(self.save_dir, save_filename) 17 | 18 | # serialize model and optimizer to dict 19 | state_dict = { 20 | 'model': self.model.state_dict(), 21 | 'optimizer' : self.optimizer.state_dict(), 22 | 'total_steps' : self.total_steps, 23 | 'feature_layer': self.opt.feature_layer, 24 | 'decoder_type': self.opt.decoder_type, 25 | } 26 | 27 | 
torch.save(state_dict, save_path) 28 | 29 | 30 | def eval(self): 31 | self.model.eval() 32 | 33 | def test(self): 34 | with torch.no_grad(): 35 | self.forward() 36 | 37 | 38 | def init_weights(net, init_type='normal', gain=0.02): 39 | def init_func(m): 40 | classname = m.__class__.__name__ 41 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 42 | if init_type == 'normal': 43 | init.normal_(m.weight.data, 0.0, gain) 44 | elif init_type == 'xavier': 45 | init.xavier_normal_(m.weight.data, gain=gain) 46 | elif init_type == 'kaiming': 47 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 48 | elif init_type == 'orthogonal': 49 | init.orthogonal_(m.weight.data, gain=gain) 50 | else: 51 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 52 | if hasattr(m, 'bias') and m.bias is not None: 53 | init.constant_(m.bias.data, 0.0) 54 | elif classname.find('BatchNorm2d') != -1: 55 | init.normal_(m.weight.data, 1.0, gain) 56 | init.constant_(m.bias.data, 0.0) 57 | 58 | print('initialize network with %s' % init_type) 59 | net.apply(init_func) 60 | -------------------------------------------------------------------------------- /networks/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from networks.base_model import BaseModel 5 | from models import get_model 6 | from utils.utils import compute_batch_iou, compute_batch_localization_f1, compute_batch_ap 7 | 8 | class Trainer(BaseModel): 9 | def name(self): 10 | return 'Trainer' 11 | 12 | def __init__(self, opt): 13 | super(Trainer, self).__init__(opt) 14 | self.opt = opt 15 | self.model = get_model(opt) 16 | 17 | # Initialize all possible parameters in the final layer 18 | for fc in self.model.fc: 19 | try: 20 | torch.nn.init.normal_(fc.weight.data, 0.0, opt.init_gain) 21 | except: 22 | pass 23 | 24 | if opt.fix_backbone: 25 | params = [] 26 | for name, p in self.model.named_parameters(): 27 | if "fc" in name and "resblock" not in name: 28 | params.append(p) 29 | else: 30 | p.requires_grad = False 31 | else: 32 | print("Your backbone is not fixed. Are you sure you want to proceed? If this is a mistake, enable the --fix_backbone command during training and rerun") 33 | import time 34 | time.sleep(3) 35 | params = self.model.parameters() 36 | 37 | if opt.optim == 'adam': 38 | self.optimizer = torch.optim.AdamW(params, lr=opt.lr, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay) 39 | elif opt.optim == 'sgd': 40 | self.optimizer = torch.optim.SGD(params, lr=opt.lr, momentum=0.0, weight_decay=opt.weight_decay) 41 | else: 42 | raise ValueError("optim should be [adam, sgd]") 43 | 44 | self.loss_fn = nn.BCEWithLogitsLoss() 45 | 46 | self.model.to(opt.gpu_ids[0]) 47 | 48 | if opt.fully_supervised: 49 | self.ious = [] 50 | self.F1_best = [] 51 | self.F1_fixed = [] 52 | self.ap = [] 53 | else: 54 | self.logits = [] 55 | self.labels = [] 56 | 57 | def adjust_learning_rate(self, min_lr=1e-6): 58 | for param_group in self.optimizer.param_groups: 59 | param_group['lr'] /= 10. 
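As a small, hedged illustration of the checkpoint format produced by `save_networks` above: the saved file is a plain dictionary with the keys listed in the code, so it can be reloaded with `torch.load` and fed back into the model and optimizer. The checkpoint path below is an assumption for the example; only the dictionary keys come from the repository code.

```python
# Hypothetical reload sketch (assumed path; keys taken from save_networks above).
import torch

ckpt = torch.load("checkpoints/test_repaint/model_epoch_best.pth", map_location="cpu")

# Restore weights, optimizer state, and bookkeeping fields.
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
total_steps = ckpt["total_steps"]
feature_layer = ckpt["feature_layer"]   # e.g. "layer20"
decoder_type = ckpt["decoder_type"]     # e.g. "conv-20"
```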
60 | if param_group['lr'] < min_lr: 61 | return False 62 | return True 63 | 64 | def set_input(self, input): 65 | self.input = input[0].to(self.device) 66 | self.label = input[1].to(self.device).float() 67 | 68 | def forward(self): 69 | self.output = self.model(self.input) 70 | 71 | if self.opt.fully_supervised: 72 | # resize prediction to ground truth mask size 73 | if self.label.size()[1] != 256 * 256: 74 | label_size = (int(self.label.size()[1] ** 0.5), int(self.label.size()[1] ** 0.5)) 75 | self.output = self.output.view(-1, 1, 256, 256) 76 | self.output = F.interpolate(self.output, size=label_size, mode='bilinear', align_corners=False) 77 | self.output = torch.flatten(self.output, start_dim=1).unsqueeze(1) 78 | 79 | if not self.opt.fully_supervised: 80 | self.output = torch.mean(self.output, dim=1) 81 | 82 | def get_loss(self): 83 | return self.loss_fn(self.output.squeeze(1), self.label) 84 | 85 | def optimize_parameters(self): 86 | self.forward() 87 | outputs = self.output 88 | 89 | if self.opt.fully_supervised: 90 | sigmoid_outputs = torch.sigmoid(outputs) 91 | 92 | # unflatten outputs and ground truth masks 93 | sigmoid_outputs = sigmoid_outputs.view(sigmoid_outputs.size(0), int(sigmoid_outputs.size(1)**0.5), int(sigmoid_outputs.size(1)**0.5)) 94 | labels = self.label.view(self.label.size(0), int(self.label.size(1)**0.5), int(self.label.size(1)**0.5)) 95 | 96 | iou = compute_batch_iou(sigmoid_outputs, labels) 97 | self.ious.extend(iou) 98 | 99 | F1_best, F1_fixed = compute_batch_localization_f1(sigmoid_outputs, labels) 100 | self.F1_best.extend(F1_best) 101 | self.F1_fixed.extend(F1_fixed) 102 | 103 | ap = compute_batch_ap(sigmoid_outputs, labels) 104 | self.ap.extend(ap) 105 | else: 106 | self.logits.append(outputs) 107 | self.labels.append(self.label) 108 | 109 | self.optimizer.zero_grad() 110 | self.loss = self.loss_fn(outputs, self.label) 111 | self.loss.backward() 112 | self.optimizer.step() 113 | 114 | def format_output(self): 115 | if not self.opt.fully_supervised: 116 | self.logits = torch.cat(self.logits, dim=0) 117 | self.labels = torch.cat(self.labels, dim=0) 118 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bit-ml/DeCLIP/09d4293c78ed648b103b7c453e7d4e47be0707f3/options/__init__.py -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | 5 | class BaseOptions(): 6 | def __init__(self): 7 | self.initialized = False 8 | 9 | def initialize(self, parser): 10 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') 11 | parser.add_argument('--fully_supervised', action='store_true', help='use fully supervision with local manipulation ground truth masks') 12 | 13 | parser.add_argument('--arch', type=str, default='CLIP:ViT-L/14', help='see models/__init__.py') 14 | parser.add_argument('--fix_backbone', action='store_true', help='train only the decoder') 15 | 16 | parser.add_argument('--weight_decay', type=float, default=0.0, help='loss weight for l2 reg') 17 | parser.add_argument('--batch_size', type=int, default=64, help='input batch size') 18 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. 
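The `adjust_learning_rate` method above divides every parameter group's learning rate by 10 and reports whether the minimum threshold has been reached. A brief, hedged sketch of how a training loop might call it together with early stopping; the loop structure and variable names are illustrative assumptions, only the method's return convention comes from the code.

```python
# Illustrative early-stopping loop (assumed structure around Trainer).
# adjust_learning_rate() returns False once the lr would drop below min_lr.
for epoch in range(opt.niter):
    train_one_epoch(trainer)          # hypothetical helper
    val_metric = validate(trainer)    # hypothetical helper

    if not improved(val_metric):      # hypothetical patience check
        keep_going = trainer.adjust_learning_rate()
        if not keep_going:
            print("Learning rate below minimum; stopping training.")
            break
```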
use -1 for CPU') 19 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') 20 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization (normal/xavier/kaiming/orthogonal)') 21 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 22 | 23 | self.initialized = True 24 | return parser 25 | 26 | def gather_options(self): 27 | # initialize parser with basic options 28 | if not self.initialized: 29 | parser = argparse.ArgumentParser( 30 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 31 | parser = self.initialize(parser) 32 | 33 | # get the basic options 34 | opt, _ = parser.parse_known_args() 35 | self.parser = parser 36 | 37 | return parser.parse_args() 38 | 39 | def print_options(self, opt): 40 | message = '' 41 | message += '----------------- Options ---------------\n' 42 | for k, v in sorted(vars(opt).items()): 43 | comment = '' 44 | default = self.parser.get_default(k) 45 | if v != default: 46 | comment = '\t[default: %s]' % str(default) 47 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 48 | message += '----------------- End -------------------' 49 | print(message) 50 | 51 | # save to the disk 52 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 53 | os.makedirs(expr_dir, exist_ok=True) 54 | file_name = os.path.join(expr_dir, 'opt.txt') 55 | with open(file_name, 'wt') as opt_file: 56 | opt_file.write(message) 57 | opt_file.write('\n') 58 | 59 | def parse(self, print_options=True): 60 | 61 | opt = self.gather_options() 62 | opt.data_label = self.data_label 63 | 64 | if print_options: 65 | self.print_options(opt) 66 | 67 | # set gpu ids 68 | str_ids = opt.gpu_ids.split(',') 69 | opt.gpu_ids = [] 70 | for str_id in str_ids: 71 | id = int(str_id) 72 | if id >= 0: 73 | opt.gpu_ids.append(id) 74 | if len(opt.gpu_ids) > 0: 75 | torch.cuda.set_device(opt.gpu_ids[0]) 76 | 77 | self.opt = opt 78 | return self.opt 79 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | def initialize(self, parser): 6 | parser = BaseOptions.initialize(self, parser) 7 | parser.add_argument('--ckpt', type=str, help='path to the trained model\'s checkpoint') 8 | parser.add_argument('--result_folder', type=str, default='result', help='path to the folder to log the test resutls') 9 | parser.add_argument('--output_save_path', type=str, default=None, help="The path to which the resulted images will be saved, along side the scores for each input sample") 10 | 11 | # TODO: uncomment these line for backwards compability 12 | parser.add_argument('--decoder_type', type=str, default='conv-20', help='type of decoder (linear/attention/conv-4/conv-12/conv-20)') 13 | parser.add_argument('--feature_layer', type=str, default=None, help='layer of the backbone from which to extract features') 14 | self.data_label = 'test' 15 | return parser 16 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def initialize(self, parser): 6 | parser = BaseOptions.initialize(self, parser) 7 | parser.add_argument('--train_dataset', 
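For completeness, a hedged example of how the option classes defined above are typically consumed: `TrainOptions().parse()` builds the parser, prints the options, and resolves `gpu_ids` into a list. The invocation below mirrors the usage implied by `train.py` and `validate.py`, but the exact call sites in those scripts are not shown here, so treat this as an illustration.

```python
# Illustrative option parsing (mirrors the BaseOptions.parse flow above).
from options.train_options import TrainOptions

opt = TrainOptions().parse(print_options=True)
# opt.gpu_ids is now a list of ints (e.g. [0]), opt.data_label == 'train',
# and the full option dump has been written to checkpoints/<name>/opt.txt.
print(opt.arch, opt.decoder_type, opt.feature_layer)
```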
type=str, default='pluralistic', help='the dataset on which to train') 8 | parser.add_argument('--decoder_type', type=str, default='conv-20', help='type of decoder (linear/attention/conv-4/conv-12/conv-20)') 9 | parser.add_argument('--feature_layer', type=str, default=None, help='layer of the backbone from which to extract features') 10 | parser.add_argument('--data_aug', type=str, default=None, help='if specified, perform additional data augmentation (blur/color_jitter/jpeg_compression/all)') 11 | 12 | parser.add_argument('--earlystop_epoch', type=int, default=5) 13 | parser.add_argument('--optim', type=str, default='adam', help='optim to use (sgd/adam)') 14 | parser.add_argument('--beta1', type=float, default=0.9, help='momentum term of adam') 15 | parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate for adam') 16 | 17 | parser.add_argument('--loss_freq', type=int, default=50, help='frequency of showing loss on tensorboard') 18 | parser.add_argument('--niter', type=int, default=400, help='total epochs') 19 | 20 | 21 | parser.add_argument('--data_root_path', type=str, default=None, help='Root path for dolos data only! Explicitly fill in the other data paths for other dataset.') 22 | parser.add_argument('--train_path', type=str, default='datasets/dolos_data/celebahq/fake/ldm/images/train', help='folder path to training fake data') 23 | parser.add_argument('--valid_path', type=str, default='datasets/dolos_data/celebahq/fake/ldm/images/valid', help='folder path to validation fake data') 24 | parser.add_argument('--train_masks_ground_truth_path', type=str, default='datasets/dolos_data/celebahq/fake/ldm/masks/train', help='path to train ground truth masks (only for fully_supervised training)') 25 | parser.add_argument('--valid_masks_ground_truth_path', type=str, default='datasets/dolos_data/celebahq/fake/ldm/masks/valid', help='path to valid ground truth masks (only for fully_supervised training)') 26 | parser.add_argument('--train_real_list_path', default='datasets/dolos_data/celebahq/real/train', help='folder path to training real data') 27 | parser.add_argument('--valid_real_list_path', default='datasets/dolos_data/celebahq/real/valid', help='folder path to validation real data') 28 | parser.add_argument('--checkpoints_dir', type=str, default='checkpoints/', help='models are saved here') 29 | 30 | self.data_label = 'train' 31 | return parser 32 | -------------------------------------------------------------------------------- /plots.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 11, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "AVG IID: 73.8505\n", 13 | "AVG OOD: 34.720000000000006\n", 14 | "MIN OOD: 4.903\n" 15 | ] 16 | }, 17 | { 18 | "data": { 19 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAigAAAHECAYAAAATY9HhAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAACPt0lEQVR4nOzddVzU9x/A8dfRKWCAih1gIiri7EBnizVrxtycMXXTny5ch9t0Tp2znbmZsztxM6aYiN1FCAbdcff7A71541Ti4L7g+/l73OO3+3w/9/m8jxN486mvSqPRaBBCCCGEUBATYwcghBBCCPFfkqAIIYQQQnEkQRFCCCGE4kiCIoQQQgjFkQRFCCGEEIojCYoQQgghFEcSFCGEEEIojiQoQgghhFAcSVCEEEIIoTiSoAghhBBCccyMHYAQQgghlCkiIoJZs2bx999/8/DhQxwdHWnZsiUffPABJUqU0KmbmJjInDlz2LlzJw8ePMDZ2ZlOnTrx3nvvYW1tne2+VXIvHiGEEEL8V0REBG+88QYhISF069YNT09PgoODWblyJU5OTqxbt45ixYoBkJ6ezltvvcWJEyfw9fWlQYMGXLlyhdWrV9OgQQOWLl2KiUn2Jm1kBEUIIYQQmcyfP5/g4GD+97//MXz4cG1569at6d+/P7/88gvfffcdAJs2beLEiRMMHDiQzz//XFvX1dWVKVOmsHXrVrp165at/mUNihBCCCEy8ff3B6BHjx465fXq1aN8+fJs376d5ORkALZs2QLAkCFDdOr2798fKysrNm/enO3+JUERQgghRCYpKSkAetePWFtbk5CQwLVr19BoNJw/fx5nZ2dcXV116llZWVG9enXOnz+f7f5likcIIYQoxHx8fF543c/PT2951apVuX37Nv7+/rRp00Zb/uDBA27dugXA/fv3KVOmDImJiVStWlVvOy4uLgQEBBAXF4ednV2W45YE5Ym79dq8vJLIF3MjSry8ksgXoZpkY4cgntgdddHYIYhnPIy+mqftpz66laftZ8Vbb72Fn58fX3/9NSkpKdSpU4fQ0FB++ukn1Go1kLFzJykpCQALCwu97VhaWmrrSoIihBBCFGTqdIM19bwRkpepX78+M2fO5LvvvmPcuHEAqFQq2rVrR61atVi1ahV2dnZYWVkB/04J/dfTdSrZ3WosCYoQQggh9Grbti2tW7fmxo0bxMTEUK5cOVxcXPjggw8AqFy5Mo6OjlhbWxMWFqa3jfDwcOzs7LI1egKSoAghhBDKo1EbOwItU1NT3N3dtc9TUlLw9/enQoUKVKhQAYBatWpx8uRJQkJCdBbKJiUlcfnyZerWrZvtfmUXjxBCCKE0arXhHgY2ffp0oqKiGDlypLbM19cXgKVLl+rUXb16NUlJSdrr2SEjKEIIIYTCaBQygtK+fXtat25N+fLlSUpKYv/+/Zw4cYL+/fvrHLzWo0cPNm/ezB9//EFsbCxeXl5cvXqVVatW4e3tTdeuXbPdtyQoQgghhNDL09OTffv2ER4ejrm5OTVq1OCXX36hQ4cOOvVMTU1ZuHAhc+bMYdeuXezYsYMSJUowZMgQRo0ahampabb7lnvxPCHbjJVDthkrh2wzVg7ZZqwseb3NOCU4+webPY9FmdoGays/yQiKEEIIoTQKmeIxJlkkK4QQQgjFkREUIYQQQmkMeFBbQSUJihBCCKE0MsUjUzxCCCGEUB4ZQRFCCCGUJg8OWCtoJEERQgghFEYpB7UZk0zxCCGEEEJxZARFCCGEUBqZ4pEERQghhFAcmeKRBEUIIYRQHDkHRdagCCGEEEJ5ZARFCCGEUBqZ4pEERQghhFAcWSQrUzxCCCGEUB4ZQRFCCCGURqZ4JEERQgghFEemeGSKRwghhBDKIyMoQgghhMJoNHIOiiQoQgghhNLIGhSZ4hFCCCGE8sgIihBCCKE0skhWEhQhhBBCcWSKRxIUIYQQQnHkZoGyBkUIIYQQyiMjKEIIIYTSyBSP8hOUnTt38scff3Dnzh2ioqIyXVepVFy6dCn/AxNCCCHyiiySVXaCsnDhQmbMmEGxYsWoW7cuDg4Oxg5JCCGEeGXExcWxfPlydu/eTXBwMBYWFpQpU4YePXrQu3dvzM3NtXUTExOZM2cOO3fu5MGDBzg7O9OpUyfee+89rK2ts923ohOUFStW4OXlxeLFi7GwsDB2OEIIIUT+UMAUT1paGoMHD+bSpUt069aNN998k5SUFPbu3cu3335LQEAAP//8MwDp6ekMGzaMEydO4OvrS4MGDbhy5QqLFy/m3LlzLF26FBOT7C17VXSCEhMTQ8eOHSU5EUII8WpRwBTPiRMnuHDhAm+//TYff/yxtvzNN9+kZ8+e7Nixg6+//ho7Ozs2bdrEiRMnGDhwIJ9//rm2rqurK1OmTGHr1q1069YtW/0rOkGpUqUKjx8/NnYYecp1+wrMSpfUey39UQTBr/f+t8DMFPs3umLhXhkL9yqYVyqPytycx99OI27zrux1bMi2CgEbRztqtmtAtdZ1KeleFoeSRUlPSeP+1XucXneQU+sOotFotPXf+HkEXr1avLDNG/9c4Lc3v89RPD0nv4t339YA/NRiLI/vhueonYLIztGO+u1fw7NVfcpWK4dTyaKkpaQRdPUeh9cd4NCfB3Q+C1MzU3wGtqd8jYqUr1kR16plMLMwZ9HHczm4Zn+OYlCZmNC8d2ua9mhB2WrlMbc0J+pBJLcCb7Jh2irCbt831NstVHr17sq836YCMG7MZ6z4fX2WXmdubs7w9wbT840uVKpcnvS0dC5dvMpvC/5gy6ZX7+eRUsTGxgLg7OysU25qakrx4sW5ceOGdgBhy5YtAAwZMkSnbv/+/Zk5cyabN28uXAnKyJEj+fLLL+nZsyelSpUydjh5Rh0bR8yqjZnKNQmJOs9VVlYU/XAUkJG8pD+KxKyUc6bXZYUh2yoMand6jR7fv0NMeCQ3j13kfOhj7Io7UKt9A3r9NBz3lp6seO8Xbf2Le08RGfxQb1v1ujejWHkXrv59NkexVPeph3ff1iTHJWJpl/1524LOu1NjhvwwgsjwCC4fu8DjkIcUKeGIV7vXGPrTKDxa1mPWyKna+pY2lgz8+h0Aoh5EEvUwiuKuJXLcv6WNFeMWfULNJh7cuXiLw+v/JjU5BaeSRXFvUIOSFUtLgqJHadeSTJ76BXGx8djZ22b5debm5vy5cTFNmzfk7t1g1qzciMrEhDZtm7No2S9Uq16VKT/8moeRK5QCRlDq1auHjY0NCxcuxMXFBU9PT5KTk9m1axdHjhzh/fffx8LCAo1Gw/nz53F2dsbV1VWnDSsrK6pXr8758+ez3b+iExQfHx/i4+Pp3LkzPj4+lClTJtMclkqlYtSoUUaK0DDUsXFEL/j9pfU0ScmEj55I6rWbpD+KwGH4IByHD8pRn4ZsqzB4dOs+y96ZypUDATp/ne+euoYxmydRu2NDarX35sLuEwBc2nuKS3tPZWrHqogNLYZ3IS05lVPrD2U7Dtui9vSc/C6B245iV8KRyq/VyPmbKqDCbocy/e0fOH
vgtM5nse6nlXy9ZQreHRvh1eE1Tu3yByA5MYWpg7/j7qU7RD+IpPvYPvQY1yfH/b/94whqNvFgycT5/LVqb6brpmamOW67MPt17o9EREaxY9s+Rr//TpZf9/a7/WnavCEnjp/hjW5vk/DkDzNbWxs27/iD/304kt27DhAYcCGvQlckQ97N2MfH54XX/fz89JaXKFGCuXPn8vXXXzNu3DhtuaWlJd9//z09e/YEICoqisTERKpWraq3HRcXFwICAoiLi8POzi7LcSs6Qbl79y6//vor8fHxbN26VW+dwpCgZFlaGklHTyqvrULg5rGLesvjHkbjv3I/7T/qS6XXamgTlOep170ZFtaWnN16lITI2GzH0ePHdwHY/MVSBswf95LahdOlo/p/EUU/jOLAyr30/uhNqr9WS5ugpKemce7vAIP0Xb5WJRp3a47/1iN6kxOA9DQ54fO/ho0YRLPmr9Gt00CaNn8tW6/t1LktAL/8PF+bnADExycwfepcfl81l7ff6c8Hoz81aMwia+zs7KhYsSLe3t40adKEpKQkNm3axBdffIFKpaJHjx4kJSUBPHe9qKWlJZCxy6fQJCjffvstDx8+ZOLEiXh7e1OkSBFjh5Q3zM2x7eiDaUlnNIlJpFy/RfKZ84oY4hP//kJSp7/8F9PTdSPHV+v/i+RF6vdqTq12DVj+7s8kRMVl+/WvgvS0NADUeZQkNPZtBsCxrYextrehbhsvipUqTlxkLBePnufB3bA86bcgq+pWic+/Hs/Ceb9z7OipbCcozi7FAbhzJyjTtbt3ggFo1iJ7bRYKBvz5/7wRkpe5cuUK/fv3Z/DgwUyYMEFb3rVrV/r168e3335Ly5YtsbKyAiAlJUVvO8nJyQDZ3mqs6ATlzJkzDBkyhMGDBxs7lDxlVqIYxSdN1ClLDb7P46+nknzmnJGiEgAmpibU65HxS+vawcAX1i1Xryqlqpfj4c1Qbh3L3uGBjq7F6frlYM5sPMylfadzHG9hZmJqQtMeLQE4d9AwIyb/VcmjCgDFXUsw7dBc7Iv++0eRWq3mwIo9/P7VYjTyxwOQsVhy7sKphATf5/tvp+eojcePI6lcpSLly5fh+rVbOtfKVygDQNlyrlhZWZKUlJzrmAsMBWwzXr58OSkpKbRv316n3MTEhHbt2nH27FnOnTtHixYtsLa2JixMfwIfHh6OnZ1dtkZPQOH34rG2ts60eriwidu6h/DhEwhq04t7jTsR+sZQYtdvw6y0C86zfsC8aiVjh/hK6/BxP0pVK8flAwFcO/TiZLFhv4zRkxNrDmSrD5VKRe9pI0lOSGLr18tzHGth1+eTgZStVp6zB05z/tDZPOnDvnjGYZD9vxjCZf+LfNR6DEOr9+fH/l/x4G4YbQZ1oNv7b+RJ3wXRhI9HUdujOmNGfpLj5GH/3oMAjJswEisrS225jY01Y8eP0D53cCikI+jPo1Yb7pFDDx48eBJK5jbSnoxmpqWloVKpqFWrFg8ePCAkJESnXlJSEpcvX6Z27drZ7l/RCUqrVq04fPiwscPIU9EL/yDp5FnUEVFokpJJvXmHiB9mErNiPSbWVq/0wlVja/xWO5oP68yDGyGsHTfnhXWt7K3x6PRajhbHNn2nI5Vfq8GGT34jMSY+NyEXWq+/1ZGOw3wJuRHM/LEz86wfE5UKgPs3Q5g9ahr3b4aQnJDEpX/OM2vkz6jT0+kwtAum5ooefM4X9ep7MHb8cObOXsqpk2dz3M7Ceb9z4dxlvF+rx+HjO5g89Qum/Pwlh/23U6x4UaKjYgD9vyRF3qpSJWNEceNG3V2mqampbN++HVNTU23i4evrC8DSpUt16q5evZqkpCTt9exQdILy0Ucf8ejRI7755hvu3buns6K/sIvbsB0Ay3oeRo7k1dRo0Ov4fv0W4deCWdDvOxKjX5w41O3WFAsbKy7sOZmtxbHFK5ak3Ye9Ofnn3znellzYtRncgYHfDCX42j1+7Psl8dF5tz4n4UmCGLD/ZKZpnHuX7/Aw6AHW9ja4VimTZzEUBKampsxZ8BM3b9xh8qRfctVWfHwCndv355dp80lPS2PA4N749ujIsaOn6NyuH6ampqSmphIZGW2Y4AsKjdpwjxwaPHgwTk5OrF69mhEjRrBy5UoWLVpEz549uXr1KoMGDcLFxQWAHj164OXlxR9//MHHH3/MunXrmDRpElOnTsXb25uuXbtmu39F/xnQsGFDVCoVFy5cYM2aNXrrFNabBaY/+WY0sbYyciSvnqZvd6DLl4O4f+Uev735PfGPY176Gu8n0zvHV2XvYDDnqmUwt7SgQe+WNOjdUm+djw7+AsDyYdP0bm0uzNq93ZkBX71N0JW7TO7/NTGP8/aX1P1boVSu60Z8TILe6/FPElVzq1f7dGtbOxuqVK0IQMhD/buuZsz6nhmzvmfB3OV8PvGHF7YXH5/A99/O4PtvZ+iUl69QBjt7W84GXNBOKbwyFDBiVLp0adavX8/cuXM5evQohw8fxtzcnKpVqzJp0iR69eqlrWtqasrChQuZM2cOu3btYseOHZQoUYIhQ4YwatQoTE2zvz1f0QlKt27dUD0Zcn3VWNauDkBqiBwIlZ9ajOhCx0/6E3LxDosG/JCl0ZCynpUpXaNCxuJY/8vZ6i8y+OFz16xUa12XIs5OnNvuT1JcwnMPhiusOo3oTt+JA7lz8RZT3vyGuBxs286uC0fO0bRnS8q4l810zczCDJcKGQdGPgp+kOexKFlKcgorfl+n95pHnRp41KmJ/9FT3Lhxm1Mnc76guXffbgBsWLctx22I3ClTpgw//PDiBPMpW1tbPvroIz766COD9K3oBGXy5MnGDiFPmVUsR/r9B2ie7CF/yrSUC0U/Hg1A/M6cHdX9lMrOFtPiRdHExZP+KCJXbRV2PmO68/r43gSfu8WigT+8dFrnqYb9Mg5BOr76xYtjreytsXd2IikmgdiHUQDcv3SXDZ/8prf+sDVfUMTZid1T17xSR90D+L7/Br3G9+PWuRv8NOBbg0/rWNvb4OjsREJsAtEPIrXlp3YdI+LjN3mtcxP2LdvJrcAb2mvd3n8DWwdbLh09T/STz+9VlZSUzLgxn+u99uEno/GoU5O1qzfpHHVvbW2Fa5nSJCYmEhKs+4eXnb0tcbG6328tWjVmzNh3uX3rLr8vXWv4N6F0CtjFY2yKTlAKO9vXW1JkQC+Szpwn/X446oQEzMqUxrppQ0ysLEk4fJyY//yVUuStvphXyPjrzsK9ckY7Xdth6VkLgOSzF3TupWPTqgnFv/mIuK17ePz11Fy1VZjV69mc18f3Jj0tndsnr9BkSPtMdSKDH3L6PwtgLe2s8ejciNTkFE5vePHi2JrtGtD755GcWn+QdRPmGzT+wqRpz5b0Gt+P9LR0rp24zOtDOmWq8yj4AYfX/6V93nlkd0pXzjhiu1yNjKmH5m+0xt2rGgBXT13RuS+PV7uGDJs2hsPrDrBwwmxteXJiMr+Nn83/lnzK5+u+59RufyLDI6jsWRV37xpEP4xiyUT57HKibn0Ptuz4g38OH6dbZ
93F/8dO7ubSxatcv3aL5ORkatepQYuWjXkQ/oiB/d7TOcDtlaGAKR5jU1SCEhoamqPXlS5d2sCR5I+kk2cxL18GC/cqWHnWRGVlhToujuSzF4jfsZ/4Hfsyvca6cQOsvOrolFl51oInSQWQ5aTCkG0VdEXLZty7xdTMlGbvdNRb56b/pUwJSt1uTbC0tcrxybEisxJlMxbdmZqZ0n5oF711Lh+7oJOgeLSoS/VGtXTquHlVw+1JggJk+caBF44E8rXvx/i+/wY1m3pgY29D1MMo/P7YzeZf1xH1zIiLMIz167bR2qcpDbzrYmZuRnBQKLNnLmLWzEVEvWqLY4WWSqOgrTHVqlXL0ZqTy5ezN++vz916bXLdhjCMuRE5v9GbMKxQzSt0MJbC7Y7SfzsGYRwPo6/mafuJO34xWFvWncYarK38pKgRlFGjRmVKUA4cOMDly5dp1KiRdk/29evX8ff3p0aNGrRq1coYoQohhBB5R9agKCtBGTNmjM7z3bt3s3TpUtasWUOdOrpTEQEBAbzzzjsMGzYsP0MUQgghRD5Q9EFtCxcupH///pmSE4C6devSv39/FixYYITIhBBCiDykgKPujU3RCcrNmzcpVarUc6+XLl2amzdv5mNEQgghRD5QwEmyxqboBMXe3h5/f//nXvf398fe3j4fIxJCCCHygYygKDtBad++Pfv27WPy5MlERPx7yFhERAQ//vgj+/bty3QbaCGEEEIUfIpaJPtf48aN4+LFiyxbtozly5fj5OQEQGRkJBqNBk9PT8aNG2fkKIUQQggDK8BTM4ai6ATF1taWFStWsGnTJvbu3UtQUBAAtWvXpm3btnTv3j1HNyASQgghFK0AT80YiqITFMi4Q2KvXr107poohBBCiMJNUQnK5s2bc/S6bt26GTQOIYQQwqhkBEVZCconn3yCSqUiO6fvq1QqSVCEEEIULsq5C43RKCpB+f33340dghBCCCEUQFEJire3t7FDEEIIIYxPpniUlaAIIYQQAklQUPhBbUIIIYR4NckIihBCCKE0clCbJChCCCGE4sgUjyQoQgghhOLINmNZgyKEEEII5ZERFCGEEEJpZIpHEhQhhBBCcSRBkSkeIYQQQiiPjKAIIYQQSiPbjCVBEUIIIZRGo1bGLp5Zs2Yxe/bsF9Y5dOgQLi4uAKSlpbFkyRI2bNhASEgIjo6O+Pj4MHbsWJycnLLVtyQoQgghhNCrbdu2lCtXLlN5aGgov/zyCzVr1tQmJwATJ05k69attGrVinfeeYfg4GCWL1/OmTNnWLt2LTY2NlnuWxIUIYQQQmkUski2WrVqVKtWLVP5L7/8AkDv3r21ZceOHWPr1q20bt2aefPmactr1qzJ+++/z5IlSxg9enSW+5ZFskIIIYTSaNSGexhYeno6GzduxMbGhs6dO2vLt2zZAsCQIUN06rdr1w5XV1ft9aySBEUIIYQQWXbo0CHCw8Pp0KEDdnZ22vLAwEBMTEzw9PTM9Jq6dety7949oqKistyPTPEIIYQQSmPARbI+Pj4vvO7n55et9v78808A+vTpo1MeFhaGk5MTFhYWmV7zdJ1KWFgYjo6OWepHEhQhhBBCaRSyBuW/Hjx4wMGDB3Fzc6NOnTo615KSknBwcND7OktLS22drJIERQghhFAaAyYo2R0heZGNGzeSnp6uszj2KSsrK1JSUvS+Ljk5WVsnq2QNihBCCCFeSqPRsH79eqysrPD19c10vWTJkkRGRupNUsLDw7V1skoSFCGEEEJpNBrDPQzk2LFjBAUF0a5dO4oUKZLpuoeHB2q1msDAwEzXAgICKFeuXJbXn4AkKEIIIYTyqNWGexjIunXrAPRO7wDaUZUlS5bolO/du5eQkBC9oy4vImtQhBBCCPFCERER7Nu3j0qVKuHl5aW3TuPGjencuTPbt29nxIgR+Pj4EBwczLJly6hSpUqm81FeRhIUIYQQQmkUci+ep7Zs2UJqaupzR0+emjx5Mm5ubmzcuJFvvvkGR0dHfH19GTt2LLa2ttnqU6XRGHCCqgC7W6+NsUMQT8yNKGHsEMQToZpkY4cgntgdddHYIYhnPIy+mqftJ0x922Bt2Xy45OWVFEjWoAghhBBCcWSKRwghhFAahU3xGIMkKE88CLc3dgjiia9/djN2COKJqZ/cMHYI4ol/zLN+m3pR8GkUepJsfpIERQghhFAaGUGRNShCCCGEUB4ZQRFCCCGURiNTPIocQbl9+zbjxo2jY8eODBw4kM2bN+utt3///pfeRloIIYQocNQawz0KKMWNoDx8+JC+ffsSHR0NwK1btzh16hQHDhxg6tSp2ls2AyQkJBAaGmqsUIUQQgiRRxQ3grJgwQLi4uL49ttvOXXqFDt27KBt27bs3buX4cOHa2/ZLIQQQhRaCrwXT35TXIJy7NgxunXrRu/evbGzs6Ny5cr8+uuvjB07Fn9/f9577z29t3IWQgghCg2Z4lFeghIaGoqnp2em8hEjRjBx4kT++ecfxowZQ2pqav4HJ4QQQoh8obg1KLa2tiQlJem9NnjwYNRqNVOmTOGDDz6gTRu5f44QQohCSHbxKC9BKVu2LGfPnmXgwIF6rw8ZMoTU1FSmT5/OuXPn8jk6IYQQIh8U4KkZQ1HcFE/jxo3566+/iIuLe26dYcOGMXbsWB49epSPkQkhhBAivyhuBKVr166kpKRw9+5datas+dx6I0aMwN7engsXLuRjdEIIIUTek3vxKDBBKV++POPHj89S3TfffDOPoxFCCCGMQKZ4lJegPHXhwgXu3buHk5MTXl5emJubGzskIYQQIn9IgqK8BCUlJYXRo0dz+PBhbVnZsmVZvHgxZcuWNWJkQgghhMgvilsku3jxYg4dOoS7uztvvfUWLVq04N69e3z55ZfGDk0IIYTIHxq14R4FlOJGUHbt2kXt2rVZs2YNpqamAPz8888sXryYyMhInJycjByhEEIIkcdkikd5IyhBQUF06tRJm5wAdO/eHY1Gw927d40YmRBCCCHyi+JGUBITEylWrJhOWdGiRQGee8KsEEIIUZhoZARFeQnKi2g08oEJIYR4BUiCoswExc/Pj5CQEO3zxMREVCoV27ZtIzAwUKeuSqVi+PDh+R2iEEIIIfKQIhOU3bt3s3v37kzlGzduzFQmCYoQQohCR06SVV6C8vvvvxs7BCGEEMK4ZIpHeQmKt7e3sUMQQgghhJEpLkF5EbVaTVhYGMWLF8fCwsLY4QghhBB5Q0EjKHFxcfz222/s3buXkJAQrKysKF++PAMGDMDX11dbLzExkTlz5rBz504ePHiAs7MznTp14r333sPa2jrb/RaoBCUiIgIfHx+WLFlCo0aNjB2OEEIIkSeUsms1PDycQYMGERkZSffu3alSpQqJiYncuXOH0NBQbb309HSGDRvGiRMn8PX1pUGDBly5coXFixdz7tw5li5diolJ9o5eK1AJCijnQxNCCCHyjEJGUD766CPi4+PZsmULpUqVem69TZs2ceLECQYOHMjnn3+uLXd1dWXKlCls3bqVbt26ZatvxZ0kK4QQ
QgjjO336NP7+/gwdOpRSpUqRnp5OfHy83rpbtmwBYMiQITrl/fv3x8rKis2bN2e7f0lQhBBCCKVRawz3yKGDBw8CUK5cOcaMGUOdOnWoV68eTZs2Ze7cuaSnpwMZMxvnz5/H2dkZV1dXnTasrKyoXr0658+fz3b/BWqKx9zcnAYNGuDg4GDsUIQQQog8Y8ij7n18fF543c/PT2/5zZs3Afjss88oU6YMkyZNAmD16tXMnDmT+/fv89133xEVFUViYiJVq1bV246LiwsBAQHExcVhZ2eX5bgLVILi4ODAH3/8YewwhBBCiELv6XSOtbU1K1eu1O6e7dixI506dWLdunUMGTJEu0PnebtrLS0tgYxdPoUmQfHx8eHTTz99bvb3119/MWnSpOdmf0IIIUSBZMARlJz+jrSysgKgS5cuOsmHhYUFXbp0Yc6cORw/fpz27dsDkJKSored5ORkgGxvNVZ0ghISEkJCQsJzrycmJupscxJCCCEKBQWcdF+yZEkASpQokena07Lo6GgcHR2xtrYmLCxMbzvh4eHY2dlla/QECvgi2UePHmkzPCGEEEIYjqenJwD379/PdO1pMlKsWDFUKhW1atXiwYMHOjf6BUhKSuLy5cvUrl072/0rbgTl5MmTHD9+XPt837593L17N1O96Ohodu7cSfXq1fMzPCGEECLPGXKRbE75+PhQpEgRtmzZwsiRI7UjIPHx8WzatAlzc3OaNm0KgK+vLydPnmTp0qU656CsXr2apKQknRNns0pxCcrx48eZPXs2kHGn4r1797J37169dcuXL8/EiRPzMzwhhBAi7ykgQbG3t+ezzz7j448/plevXvTq1QuVSsWGDRsIDw9n3Lhx2sPbevTowebNm/njjz+IjY3Fy8uLq1evsmrVKry9venatWu2+1dpFHY0a2xsLDExMWg0Gtq0aaN3kaxKpcLGxgZHR0eD9XvStbvB2hK5U+tnD2OHIJ6Y+skNY4cgnlgWf8nYIYhn3HoUkKftR/VrZbC2HFf/lavXHzx4kN9++42LFy+iVqtxc3PjrbfeolOnTjr14uPjmTNnDrt27eLhw4eUKFGCjh07MmrUKGxsbLLdr+ISlGdt2rSJBg0aUKZMmTzvSykJSrEeLag0aywAtyfM4dHq/S+s77b6KxyaewJwslxPSM/ayioP/wVYlnV+YZ3gqau4/8u6LLVnSMZIULYE3OKrLSdeWMdEpeLMV32ee/2bLSfYFHALgK1jOlGumH22Yjh0NYTlR69wNSyKdLWGys5F6N2gKl09K2arHUMyRoJi7WhHtfZeVG3liXO1stiXLEp6ShoPrgZxdt1Bzv55CP7zY8vC1oom73WhentvHMsUJzU5ldCzNzm6YDu3/7mY5b7Lv1adwWs/f+71f+ZuxW/K2hy/t9xQYoLSqm1T3hrWnyrulXBycuBB+CMuBF5m8bwVBJw6l6U2OnRpg3fj+tSo5Ua1Wm7Y29uxed0O/jfy+Z+DEuR5gtLHgAnK2twlKMaiuCmeZ3XvroykIb9YlC5GuUnvkh6XiKndy7djOQ/pSJHGtVEnJmNibZmtvsIXbcO0iG3mCyoVpcb0xMTcjOgDZ7LVZkHmXtKJ4S1q6r0WcO8hJ24/oEmV59+H4uDVEDYF3MLGwoyElLRs97/m+DUm7zqDo7UFHT3KY25qwv5LQXy5+Tg3wqP4X7u62W6zoKrRqSGdfnib2PBI7hy7RHTICWxLFKF6uwZ0/WkYVVp6sn7kTG19qyI2vLXhK5zdyvDgahCnVvphYWuFe9v6DFz1KVs/WsjZtQezFcOdY5e46385U/m9k1dz/f4Ki4+/fJ/h7w8h4nEk+3b9TeTjKMpXLEubDi1p38WH8aO+YMu6nS9tZ9T/hlKjtjtxcfGEhYZjb5+9nR6FlRLWoBibohOUp44ePcrt27eJiorKdLNAlUrFqFGjjBSZYVWYPoa0yFgid/lTamS3F9a1qlyaMp8NImz+For6Nn3paMh/hS/arre8SAtPTMzNiD9/k4RzN7PVZkFWrZQT1Uo56b02aNE+AHrWr6z3ekR8Et9uPUm7muV4FJfI6bsPs9V3SGQc0/eexcHagpXDXsfVKeMH9PAWtXhz4V5+P3YVnxplqVO2eLbaLage377Pmrd/5tqBszojJQd++pOhW76lRkdvqnVowJVdJwFoMa4nzm5luLzrBOtHzULzZBTxwE9/8u627+jwzWBuHjxPbFhElmO463+Zg79sNOj7KkyKOxdj6KhBPAx/RMcWvXn8KFJ77bWmXqza/BvjPh6ZpQRl0hc/Exb6gDu37tGwSX1Wb1mUl6EXHArYZmxsik5Q7t27x6hRo7hx48Zz72JcWBIUl3c6U6RJba70+oIiTV6yHcvUhIozx5J8N5yQaasp6tvUYHGUePN1AB6u0L8w+VVzPTyKc8GPcba3ppmb/hGU77Zl/KKc2Kk+49ceyXYfWwJuk5Ku5i3vqtrkBKCItQXvNKvB11tPsO7UjVcmQblzVP9URvzDaE6v9KP1R32o8Fp1bYJSrZ0XAH9PW69NTgASHsfgv2gX7b4aSN3eLTj066a8D/4V4VqmFKamppw9c0EnOQHwP3KK2Ng4ihbXn/D/l/+RU3kRoigEFJ2gTJo0idu3bzNu3DiaNGli0EWxSmJVpQxlJg4gfPF24o5femmCUvqDN7CpVZHLXT9Bk4PphOcxK+6AY1sv0uMSebzpkMHaLcg2nM4YRepWrxKmJpmPDdoScIu/roQwo29THG2yN8321Inb4QB6p5CaVM0oO/mkzqsuPS3j5mTqtH8TEbsSjgBE3nuQqf7TsopNamYrQXGq4EKDwW2xtLMm7mE0905cIeKOfAZP3bl1j+TkFOrUq4lTUUciI6K01xo0qoe9vR17dhwwXoCFgEzxKDxBOXnyJAMHDmTYsGHGDiXvmJpQ6dcPSA59RPDklS+tblunCqXe70XYnI0Gn4Ip0dcHEwtzHvz5F+r4JIO2XRAlpaax49wdTFUqetSrlOl6aFQ8U3cH0MmjPK2q5Xwh993HsQCU17OotoS9NdbmZoTHJJKYkoa1haK/ZfOUytQEjx7NALh5MFBbnhAZi72LE47lnHl0XfeQKKdyGVOfxSo9f/2QPh7dm+LRXXdk8tLOE2z/+DeSYp5/uvWrIjoqhp++ncln341nzz8b2LfrL6IioilXoQxt2rfg8F/H+Hz8JGOHWbDJFI+yExRTU1MqVKhg7DDyVOlxfTJGQ7p/hiZJ/30MnlJZWVDx1w9IuhZE6Iw/DR5L8X5tAXi4UqZ3APZeDCI2KZVmVUtT0kF3QbFareGLTf5YW5jxcYf6ueonNikVADtLc73X7azMSUxNIy459ZVOUNp80heXamW5fiCAm4f+vXX79QNnqdevFS3H9WTD6FnavzxtitrT8J0OAFg56FkQrkfC4xj2/7ia6wfOEhX8EDNLc0p7VKL1R72p0dEbuxIOLHvju0y7iF5FSxesIvheKFN+/Zp+g3pqy+/cuseGNVszTf0IkV2K/mnXoEEDLl7M+hbBgsa2blVKj+lJ2IK
txJ9++e6Asp8NwrKcC5c6fYTmyVC3oRRpVgerCiWJP/dqLY59kafTO728Mi+OXeF/ldN3HzKrf3OKWOu/g6cwHO+32tFoWCce3ghh09h5Otf+nr6eys1rU6NTQ4pVLs3tfy5gYWOFe9t6xIRH4lim+HPXsP3Xw+shPHxmFCY1IZmbB88RdPo6w3f9QLkG7ri1qce1facN+v4KomFjBjPhs9Es/20Nvy9aw8MHj6lctQIffj6GXxb8SPVa7kz5ZubLGxJ6aWQERdn34vn44485cOAAO3e+fCV4gWNqQsWZH5B0K5SQqateWt3+tZo4v9WB+zPXk3jpjsHDKTHgyeJYGT0B4MaDaAKDHuFSxJqmVXWnB+4+imG23zl8PSvSzK10rvuyt8oYOYlLTtV7Pe4lIyyFXYPBbWn/zSAeXAvm977fkxQdr3M97kEUi7p+wYnle7G0taLBwLZUbe3Jxe3+rB/5KwDxj2JyFUNKXCIXthwFoHzDarlqqzBo2KQ+n3w1Fr/dB/n+i2kE3Q0hKTGJi+euMGLweO6HhjP0vYGULe9q7FALLrUBHwWUokdQvvzyS2xtbRk/fjxTpkyhbNmymPxnoaJKpWL58uVGijDnTG2tsK6c8c3rdVv/YWgVfx5FxZ9HEbZoGylBD1CZmOD6YT9cP+ynt36DexsAuPD6OBIv3slyLGbFHHB8vYEsjn2GdnFs3cyLY28+jCElXc2Ws7fZcva23td3nbUDgOl9mtK6+ovXp5QvZk9kQjJ3H8dmWmj7MDaRxNQ0XIpYv5LTOw3fbk+7rwYSfiWIP/r/QMJj/YlG/KMYdn+5nN1f6v4sqNC4BgCh527lOpb4J32bZ/PMocKo9evNATimZwdOUmIS585coF1nH2rWrkbQ3ZBMdYTICkX/xAsODgbQnvUfGhpqzHAMSp2SxsNV+/Res6ldCdvalYk9fomkmyHEn75KWmTsc+sX7doUUztrHq7eDxoN6ZGx2YqleJ/Wsjj2Gcmp6ewIzFgc213P4tjSjrZ0r5u5HODw9VAexSXRtkZZ7CzNKe348rUP3hVdOBv0iH9u3M+0lfif6xl3EW1Q0SUH76RgazyiM20m9uP+xTusePNHEiPjst3G00W1F7b8k+t4ytStAkCUnt1CrxoLi4xpzedtJS5aLKM8JVX/qKB4OZniUXiCcuBA4d2mpklK4c6Hc/VeK/2/PtjWrsyjdX/pHHUfc1j/0dFFmtXB1M6aOx/Py3TUvbmzE6b2NqQ+iCQ9Vv/ugxL92wDwcMWenLyVQmffpXvEJKXQ3C3z4ljIONTtK19vva99Z6kfj+KSGOPjkemo+8j4ZKISknG0scTJ9t+/wn3rVmTZP5dZe+I6vp4VtWehxCSmsPhwxpkgb3hVMdTbKxCavd+NVuPfIPTcLVYMmJxpWkeHSoW5tQWpCck6xbW7N6VOz6YEnbrGlT26a0asneywKWpPQkSsTuJTqnZF7p/PPCpWu3sTanZ5jbTkVC7u8M/dmysETvqfYfC7fek3sAerl60nPOzfwwlb+DShfkNPkhKTOHMiY7eVmZkZ5SqUIS0tjXt3go0VdsEiCYqyExSRe2UmDqB479bcGvcrj//MfD8G+6a1sapYOmNx7PncD4MXBk+nd553cmxOrTlxjQUHLzK8RU1Gtvr3rBtXJzvGve7JlF1neHPhXl6vVU571H14TCKDGrm/Moe0AXj0bEar8W+gTkvn3omrNBzSLlOdqOBHBK7PmI40t7Zg/Om53Dp8gch7D9Co1ZT1cqNsfTceXg9h3ciZmXbdeA9+nRbjenJwxgadE2PfmPcB6vR0Qs/dJjYsAlNLc1w9KuFatwrpqWns+HQJ0cGP8vYLUADs2rqfI3/707Tla+w9tpG9O/7i0YNHVHarROvXm2FiYsJP3/1KVGQ0AC6lSrDffxPB90JpXk/3BnNtO7SkbceM+86UcC4GQF0vD36a9Q0AkRFR/PjVjHx8d0IpJEF5xTm/KYtjn3XrYTQB9/Qvjs1L/Rq6UdrRlt+PXmF74B3UGg2VSjgwqrWHUW8WaAxOZUsAYGJmymtDO+itc+fYJW2Ckp6SxsVt/pT1cqNSs1oARNwO58BPa/FfvJu0l2zff9apFfup1LQWZb3csClqjwqICY/k7J8HOb5kN+GX7+XuzRUSGo2Gt/uOYeA7vencvR2vd2qFtbUVUZEx/L3/CMsWrubI31kbaapR251e/brqlJWvWJbyFcsCEHwv9JVMUGSKR2F3Mx40aBAqlYrFixdjZmbGoEGDXvoaQy2SVcrdjIVx7mYs9DPG3YyFfkq8m/GrLK/vZvzAp4XB2nL2y97NMpVCUSMowcHBqFQq7ZkFTxfJCiGEEK8SGUFRWILy30WxhXmRrBBCCCGeT1EJihBCCCEAjcrYERidJChCCCGEwsgUTwFIUGJiYli/fj2BgYFER0ejVut+agX1JFkhhBBCPJ+iE5T79+/Tr18/wsLCsLe3Jy4uDgcHB2JiYlCr1Tg5OWFtbW3sMIUQQgiD0qhlikfRNwucOXMmUVFRLF26lD179qDRaJgxYwanT59m6NCh2Nrasnr1amOHKYQQQhiURm24R0GV6wQlMTGR8PBwQkND9T5y4+jRo/Tq1YtGjRqhUv2bTVpbWzNhwgQqV67MtGnTcvsWhBBCCKEwOZ7i2bx5M4sWLeLmzZvPraNSqbh0KeeHC0VERODu7g5k3MsBIDn53/ttNG3alPnz5+e4fSGEEEKJNLKLJ2cJysaNG/n0008xNTXFy8uLkiVLahMIQ3J0dCQ2NuPOvHZ2dpibm+uMyqhUKuLjX3ATMSGEEKIAKshTM4aSo6xiyZIlODg4sGrVKipXNuwN1Z5VsWJFbtzIOGpbpVJRs2ZNNm3aRK9evUhPT2fjxo2UK1cuz/oXQgghhHHkaA3K3bt3adeuXZ4mJwBNmjRh79692mmdoUOHcv78eby9vWncuDGXL1/mrbfeytMYhBBCiPymUasM9iiocjSC4uDggIWFhaFjyWT48OG8/fbb2r7atGnD7Nmz2bJlCyYmJrRv354OHfTf7VQIIYQoqJRzG1/jyVGC0qpVK06cOIFGo9HZXWNoKpUqUyLUpk0b2rRpk2d9CiGEEMamlJGPpxtV9Nm2bRtubm7a52lpaSxZsoQNGzYQEhKCo6MjPj4+jB07Ficnp2z3naME5X//+x/9+vXjq6++4uOPP8bW1jYnzWRbREQEAEWLFs2X/oQQQohXnZeXF717985UXqpUKZ3nEydOZOvWrbRq1Yp33nmH4OBgli9fzpkzZ1i7di02NjbZ6jdHCcoHH3yAtbU169atY9u2bVSoUAF7e/tM9QxxDP2DBw+YMWMG+/fvJy4uDsjY0dOmTRvGjh2Li4tLrtoXQgghlEYpIygAZcuWxdfX94V1jh07xtatW2ndujXz5s3TltesWZP333+fJUuWMHr06Gz1m6ME5cSJE9r/TkxM5PLly3rr5Xb6JygoiH79+vHo0SMqV65Mo0aNALh58yabNm3i8OHDrF69mrJly+aqHyGEEEJJlLYGJTU1le
TkZOzs7PRe37JlCwBDhgzRKW/Xrh2urq5s2bIlfxKUK1eu5ORl2fbTTz8RFRXF7NmzM6072bdvH+PGjeOnn35i1qxZ+RKPEEII8arZs2cPW7duJT09HXt7e1q2bMnYsWMpU6aMtk5gYCAmJiZ4enpmen3dunXZvn07UVFRODo6ZrlfRd8s8NixY/Tr10/voti2bdvSt29fNm/enP+BCSGEEHnIkFM8Pj4+L7zu5+f33Gu1atWiXbt2VKhQgZSUFE6fPs26des4fPiwzlloYWFhODk56d3h+3QpRlhYWOFJUNRq9QvPWqlSpQoapY2DCSGEELmklKPuN2zYoPO8c+fOtGzZkmHDhvHDDz+wePFiAJKSknBwcNDbhqWlpbZOduQqQdmxYwfr1q3j8uXLxMbGYmdnR82aNenVqxedOnXKTdMAeHp6EhgYSN++ffVeDwwM1DucJIQQQogMLxohyYkWLVpQp04d/P39SU5OxtLSEisrK1JSUvTWf3rYqpWVVbb6ydFJshqNhg8//JAJEybg7+9PXFwcRYsWJT4+nmPHjjFhwgTGjx+fk6Z1fPTRRxw4cIBly5aRmpqqLU9NTWXJkiUcOHCAjz76KNf9CCGEEEqiURvukRfKlClDWloaUVFRAJQsWZLIyEi9SUp4eLi2TnbkaARlzZo1bNu2jZo1azJhwgS8vb0xNTUlPT2dEydOMG3aNHbu3ImXlxf9+vXLSRcA/PDDDzg4ODBlyhRmz56tve/OvXv3iI+Pp1y5cnz//fc6rzHE1mYhhBDCmNQKmeJ5njt37mBubq49gM3Dw4Nbt24RGBhIgwYNdOoGBARQrly5bK0/gRyOoGzYsAFXV1dWrlxJo0aNMDU1BcDU1JRGjRqxYsUKXF1dWb9+fU6a1woODiYtLY1SpUpRpEgRoqKiiIqKokiRIpQqVYrU1FSCg4N1HkFBQbnqUwghhBAQGRmpt3z79u1cvHiRpk2bahfFPj0nZcmSJTp19+7dS0hIyEvPUdEnRyMoN2/epE+fPs+dT7KysqJNmzasXbs2J81rHThwIFevF0IIIQoiJSySnTdvHmfOnOG1117TDgqcOXOGvXv3UqJECT777DNt3caNG9O5c2e2b9/OiBEj8PHxITg4mGXLllGlSpVM56NkRY4Xyb5s94zsrhFCCCFyRgknyTZs2JBbt26xbds2IiMj0Wg0uLq68tZbb/Huu+9SrFgxnfqTJ0/Gzc2NjRs38s033+Do6Iivry9jx47N0S1xcpSgVK5cWXtQmr5RlKSkJPbv3//CLcLZERwczLFjx3j06BFdunShTJkypKSk8OjRI4oXL54vd1YWQggh8osS/sb38fF56RkqzzI3N2f48OEMHz7cIP3naA1Kz549CQ0N5c033+TYsWOkpaUBkJ6ejr+/P4MGDSI0NJSePXvmOsDp06fTrl07vvjiC3799VftGpOUlBQ6derE6tWrc92HEEIIIZQlRyMoffv25dSpU+zYsYO3334bExMTHBwciI6ORq1Wo9Fo6NChQ6528ACsW7eOhQsXMmDAAO3dEZ+ys7OjVatW/PXXXwwePDhX/QghhBBKooQpHmPLUYKiUqmYNm0arVq1YsOGDVy6dIno6Gjs7OyoUaMGPXv2pHPnzrkObtWqVfj4+PD555/rXU3s7u7OqlWrct2PEEIIoSRK32acH3J1kmznzp0Nkog8z61bt+jdu/dzrxctWpSIiIg8618IIYQQxqHoe/GYmZlpj8jVJzw8/Lm3fhZCCCEKKiVsMza2HC2SzS81atTg77//1nstLS2N7du3U6dOnfwNSgghhMhjGo3hHgWVohOUAQMG4O/vz48//sjDhw+BjPvwXL58mREjRnD37l0GDhxo5CiFEEIIYWiKnuJp164d7733HvPmzeP3338H0O6v1mg0jB07liZNmhgzRCGEEMLgZJGsghOUhIQElixZQt26ddm4cSNbt27l1q1bqNVqKlSogK+vL7Vq1TJ2mEIIIYTByRoUBScoNjY2zJ8/n6+++opmzZpRvXp1Y4ckhBBCiHyi2AQFwNXVVbYRCyGEeOUU5MWthpKjRbLVq1dnzpw5L6wzb948atSokaOgnurRowebN28mKSkpV+0IIYQQBYlaozLYo6DK0QiKRqPJ0t2Kc3tHYw8PD3bv3o2vry8DBgygfPnyWFtbZ6rXoEGDXPUDkJJumus2hGEcH3vJ2CGIJz459qGxQxBPDOkz0dghiHwka1DycIonJiYGS0vLXLUxZMgQ7X9///33qFS6H5hGo0GlUnH58uVc9SOEEEIIZclygnLy5Emd5yEhIZnKIOOOxvfv32fbtm1UrFgxV8H9+OOPuXq9EEIIURAV5KkZQ8lygjJw4EDtCIZKpWLz5s1s3rxZb12NRoOJiQkff/xxroLr3r17rl4vhBBCFESyRjYbCcqoUaNQqVRoNBrmzJmDt7c33t7emeqZmJjg6OhIw4YNqVy5skGDFUIIIcSrIcsJypgxY7T/vWnTJtq0acOgQYPyJCghhBDiVSZTPDlcJHvgwAFDxyGEEEKIJ2QXTw4TlPT0dFJSUjJt+T127Bh+fn5YW1vTu3dvypYta5AghRBCCPFqyVGCMmXKFFavXs3Ro0ext7cHYMeOHUyYMEF79sm6devYtGkTpUqVMly0QgghxCtAbewAFCBHJ8meOnWKhg0bapMTgNmzZ1OkSBGmTJnChx9+SGxsLEuXLjVYoEIIIcSrQoPKYI+CKkcjKPfv36du3bra50FBQdy+fZtRo0bh6+sLZJybcvjwYcNEKYQQQohXSo5GUOLi4rCzs9M+P336NCqVimbNmmnLqlatSlhYWO4jFEIIIV4xao3hHgVVjkZQSpQoQXBwsPb5sWPHsLKyombNmtqyhIQEzMwUfbNkIYQQQpHUBXhqxlBylEF4enpy4MAB/vrrLywtLdmzZw+vvfYa5ubm2jrBwcG4uLgYLFAhhBDiVVGQ144YSo4SlOHDh+Pn58d7770HZJweO3LkSO315ORkTp06Rbt27QwTpRBCCCFeKTlKUNzd3fnzzz+19+Lp0KEDHh4e2uuXLl3itddeo3PnzgYJUgghhHiVKHWbsVqtpm/fvgQGBtKoUSOWLVumcz0xMZE5c+awc+dOHjx4gLOzM506deK9997LdHbay+R4kYi7u/tzbwZYt25d5syZk9OmhRBCiFeaUqd4li9fzvXr1/VeS09PZ9iwYZw4cQJfX18aNGjAlStXWLx4MefOnWPp0qWYmGR9b45BVrFGR0eTkJAgh7IJIYQQhVRQUBAzZ85k3Lhx/PDDD5mub9q0iRMnTjBw4EA+//xzbbmrqytTpkxh69atdOvWLcv95WibMUB8fDyTJ0+mSZMmvPbaa/j4+GivBQYG8u6773Lx4sWcNi+EEEK8stQGfBjK559/TpUqVRg4cKDe61u2bAFgyJAhOuX9+/fHyspKuywkq3I0ghIbG0v//v25fv061atXx8nJiZs3b2qvu7m5cerUKbZv366z9VgIIYQQL6e0NSh//vknp06dYsOGDXqnaTQaD
efPn8fZ2RlXV1eda1ZWVlSvXp3z589nq88cJSjz5s3j+vXrTJ48mW7dujF79mydNSfW1tZ4e3vj7++fk+aFEEIIYSDPznDo4+fn98Lr4eHh/PTTTwwZMoRq1arprRMVFUViYiJVq1bVe93FxYWAgIBMB72+SI6mePbt20fTpk1fOJdUunRpwsPDc9K8EEII8UpT0r14vv76a5ycnBg9evRz6yQlJQFgYWGh97qlpSWQscsnq3I0ghIWFsbrr7/+wjo2NjbExsbmpHkhhBDilaY24Cael42QvMiOHTs4cOAAS5cuxcrK6rn1nl5LSUnRez05ORkgW1uNc5Sg2NraEhER8cI6wcHBODk55aR5IYQQQhhZSkoKkyZNomnTpri6unL37l2d60lJSdy9exdbW1uKFSuGtbX1c+/BFx4ejp2dXZandyCHCUrt2rX566+/njuX9ODBAw4dOkTLli1z0rwQQgjxSlPCvXiSkpKIiIjgyJEjemdNAgICeP311+nYsSMzZsygVq1anDx5kpCQEJ2FsklJSVy+fJm6detmq/8cJSiDBg3i3XffZdiwYXz33Xc6127evMnnn39OcnLyc7ciCSGEEOL5lHATYmtra2bOnKn32gcffICbmxujRo3SnoHm6+vLyZMnWbp0qc45KKtXryYpKQlfX99s9Z+jBKVZs2aMHj2a2bNn07lzZ+1dixs2bEhMTAwajYYJEyZQr169nDQvhBBCvNKUsM3Y3Nyc9u3bP/d6sWLFdK736NGDzZs388cffxAbG4uXlxdXr15l1apVeHt707Vr12z1n+OTZEePHo2Xlxd//PEHgYGBREVFoVKpaNGiBYMHD6ZRo0Y5bVoIIYQQBYypqSkLFy5kzpw57Nq1ix07dlCiRAmGDBnCqFGjMDU1zVZ7Ko1Go4SRJKP7p2QvY4cgnkjVGH/uVWRofOxDY4cgngjvM9HYIYhnlD2Z850xWbG+1JsGa6vX/ZUGays/5WgEZfPmzVSrVu25B7YAXL16lcuXL2fr3H19goKCWLZsGYGBgURHR6NW6w58qVQq9u/fn6s+hBBCCCWRkYMcHtT2ySefvDQpOHDgABMn5i7jv379Ot27d2ft2rUkJycTFBSEtbU1ycnJhISEYGpqKjcoFEIIIQqhHN8s8GXS09NRqXI3VD9r1ixMTU3ZsmULy5cvB+Czzz7jyJEjfPXVV8TExPD1118bIFohhBBCOZR4s8D8lmcJyp07dyhSpEiu2jh16hRvvPEGlStXzpTs9OvXjyZNmjB9+vRc9SGEEEIojVpluEdBleU1KP+drvHz8yMkJCRTPbVaTWhoKKdPn6ZFixa5Ci4mJoby5csDGdudQPcc//r16z93j7YQQgghCq4sJyibNm3S/rdKpeLy5ctcvnxZb12VSkWdOnX49NNPcxVcsWLFiIyMBMDOzg4rKyuCg4O115OSkp577r8QQghRUCnhJFljy3KC8vRmQxqNhjZt2jB48GAGDRqUqZ6pqSlFihTBxsYm18FVqVKFq1evap97enqyevVqfHx8UKvVrF27lipVquS6HyGEEEJJZBdPNhKUZ8/VHz16NA0bNtQpyws+Pj4sXryYpKQkrKysGDVqFEOGDKFNmzZAxkjN3Llz8zQGIYQQQuS/HJ2DMnr0aEPHoVf//v3p37+/9nmDBg1Yu3Yt27Ztw8TEhNdffx1PT898iUUIIYTILwV5cauh5Pioe2OpWbMmNWvWNHYYeaZEz2a4zfkAgBv/m0f4qn9PK7SpVo7S73bE1qMylqWLYmpnQ+qjaBJvhnJ/2R4idh7Pcj9FGtek9sZvnns9eNYm7n5fME8fNBSXns2oPvd9AK7+bx73Vx7QXrOrWYHiHRrg1MIDq/IumDvZk/o4hij/SwTN2Urc+du56tvjzy8o2sIDgIOl+6BJL8ibBXPm0ImzrNi8h5v3QoiOjaN4UUdqVKnAoB4d8KxeNVP99HQ1m/cdYqvfEa7fCSIlJZXiRR2p5VaJ0QN7UqHMy89MCgl/SPu3/vfc6+2bN2TqxPz5A00pSm1ZiVnpknqvpT+OILT9G/8WmJpi94YvFm6VMXevgnnF8qjMzYmYNI34LTuz3bfKxpoig/th3boZZqVKoklOJvniFWL/WEvyyYCcvqUC4dX7js+swCUohZlF6WJU+mEo6XGJmNpZZ7puV6cSRdt7E3vmOrGnrpIWk4CFsyNFX69P9SUf8mDdQa6PmZWtPqOPXiT66MVM5THH9S+AflVYli5G1R/fIS0uETM9n4Xb1HcpUt+N2LM3ebTjOOnxSdjVqoBL96aU6Pwal4bN4NHOEznq2/Wd9jg1qUl6Ygqm1ha5fSsF0vTFa1i6fgeORexo3ag+jkXsCQoN5y//M+z/5xTfTxhOl9ZNtPUTEpN4/5sZHA+8RLVK5fH1aYaFhTkPHkdy5sJV7oaEZSlBecq9UjlaN6qfqbxK+TIGeX8FjTo2jtjVGzKVa57ZVQmgsrbCafwoICN5SX8cgVlJlxz1qbK3w2XRTMwrVSD15m3iNm5DZW2NdYvGOM/9mYjvfiZ+664ctV0QyBqUApCgJCQksGPHDm7fvk1UVBT/vXWQSqXihx9+MFJ0hlX1l1GkRsYSsfM4ru9lvi31w01HeLD270zlpnbWeOz8Aec3WnB/yS7iAm5kuc/ooxcJ+vnP3IRdKFWb+R6pkbE83HGCcqMy34EzfMNhLr83i8Q7YTrlzj2bUmPuB7j/PJzH+86gSU3LVr/WlUtT6fMBBM3bhrNvY0zLOefqfRREjyKiWL5xJ8WcHNgw93uKOTpor50IvMQ7n/zInD826CQo3/y6hOOBl/hizBB6d2ydqc3UtOx9Du6VyvHegB45fxOFjDo2jpjffn9pPU1SMg8/mEjK1RuoH0dQ5N1BOAwbnKM+HYYNxrxSBRIOHOLxp9/Bk1HE6LmLcfl9Lo4fjibJ/yTpDx7lqH2hfIpOUAIDAxkxYoR2q7E+hSVBKTW0Iw5Na3Ghx1c4NK2tt44mRf8P2fS4RKL+CsTGrSzWFUtlK0ERmbm+2xHHprU42/1rnJrW0lsnZPFuveUPNhyhwvg3sKlcGtvq5Yg7dyvL/apMTag+ZwxJ98K5/dNanH0b5yj+gi70wSPUag213SvrJCcA3nVqYGttRWR0rLbs0o077Pz7GO2bN9SbnACYmyn6R13hkZZG0tGcjRz+l3XLjAQ0esEybXICoI6MInbVepz+Nwrbrh2IWfSHQfpTGlmDovAE5bvvvkOtVjN79my8vb1zfTKtUllXdaX8Z29y/7edxPhffm6C8jwm1hY4PPlFGn/5XrZea1WhJCXfbo+ZnQ0pD6OI8b9E0u2wl7+wkLKp6kqlz94k+LedRPtffm6C8iKa1PSM/09Lz9bryo/riV2tCgR0+uy5yeiroLxrSczNzLhw9SaR0bE4Odhrr506f4X4xCSd6Zedfx0FoEPLRsTGJ/D38QDCHz7Gwd6Ohp41KVc6+1MMDx9H8efOA0THxOFQxI461avgXrFc7t9cAaWyMMemQxtMXZzRJCWRev0WyQHnQJ13
KyVMixUFID3kfqZraU/KLBvUhcKaoBg7AAVQdIJy7do1Ro8erd1WXCiZmuA2+31SQh5x98dVWXqJVYWSlOjVHJWJCeYlHHBqUw/LUsUImrmRhMt3s9W9c6/mOPdqrlP2aPsxboyfT3p0fLbaKuhUpiZUnz2G5JBH3P5hdY7aKFK/KrbVypIc+pj4K0FZfp29Z2XKje3BvVmbiQ3M+qhLYeRgb8e4t/sw9bdVdBv+Ma0a1cexiB1B9x/wt38AjerW4sv3h2jrX7iW8fW6/+ARHd8eT1RMnPaaSqWid6fWTBwxCFPTrN/Z41jABY4FXNApa+BRne/HD6OUc/FcvsOCx7R4MYp9q3uaeFpIKBHfTiX5zLk86VMdFY1pieKYli5F2m3dn2tmrhnriczLl82TvoUyKDpBKVq0KNbWmRcoFiblxr+Bba0KnO/6BeqkrJ2Ka1WxJOUm9NY+Vyencvub3wmdtzXL/aY+jubOpBVE7j9N0r2HmFiZY1enMuUn9qd450ZYlHDkfLcvQfPqLNUqP/4N7GpXJCAbn8WzzBztqDYrY4fHjS+XZfmvSxMrC6rPHkPC1WDuTluf7X4Lo4Hd21PapThfzljEht1/a8vLlXbBt20znamfiCfTPVMXrqJ1o/qMHtyLksWLcu7KTb6bvZS12/0o6lAkS2tKrCwtGN6vG60b16dMyRIAXLsdxLyVGzkReJmhEyezbs4kbKysDPuGFSx+226Sz54n9dZdNPEJmJUphd0b3bDt3oniM3/kwdtjSL1u+KQ68Z/j2HXrhMOwwTz+bJL2+8nE0QH7fr0y/tvezuD9KoWMoOThzQINoXPnzuzdu9fYYeQZu7pVKfN+D0LmbyP29LUsvy7qr7P8U7IXR8v04XTDUQTP3ED5if2o/vsnqMyzlnMmXg0mZPZmEq4EoU5IIi0ilqi/znKhx1ck3Q2nSMPqFH3dK6dvrcCxr1eF8h90J2jeNmJOZf2zeMrExpJayz/CpnJp7s3azMNt/ll+baUvB2BV3oXLY2Zne1qosFqybjvjv5+Fb9tm7FwyjRObFrF21neUKenMJz/NY/rif0e4NE9+cVUsW4qpE0dTqWxpbKyteK1uTaZ/NgYTExW/b9xFahYWLBdzdGD0oJ7UqFKBIna2FLGzxat2NRZ8/zEe7pW5FxrOxt0H8+x9K1HMoj9IPnUWdUQkmuRkUm/eIXLyL8SuWo+JlRVF3s3ZItiXiZ6/jLSwcGzatMBl5QIc//ceTp/9j5JrF6OOiQFAoy68f0BpVIZ7FFSKTlDef/99bG1tGTFiBMeOHSMoKIjQ0NBMjwLJ1AS3WaNJvBXKvSlrctSEJi2dpLvhBE1fz72f1lL0dS9KDe2Yq7DS4xJ5uPEwAEUaVc9VWwWFytSE6rPGkHDzPrdz8FmY2FjisWIijq9VJ2jeNm5Nyvr5MQ6NauA6pB13Z2wg/lL2pucKq5PnLjNjyVpavlaPj4a9SdlSzlhbWVKjSgV++eIDnIs5sXzjLoLuPwDA3i7jthotGtbNNI3jXqk8ri4liE9M4lZQ5pubZpWZqSk92rcE4PSFKzlupzCJ37ANAMu6HnnSvvpxBOGDRxH752ZMbGyw69UV6yYNSdj3N48/+TajTmRUnvQtlEHRUzzm5ua4u7uzcOFCDh58/l8tz7tpoZKZ2lphXSXjVgGN7+n/pVhl+kiqTB9J6MLt3P5y2QvbizwQQIXPB+DQuGa2pnr0SX2c8deJqfWrMYxtamuFTZXSALQI0r/2xH36SNynjyR44Q5ufLFM57W1V07EsVEN7s3anK3kBMC+dgVUJiZU/LgPFT/uo7dOi9C1AJxq/SFxF+9kq/2C6ODxjAO4GnhkTpCtrSyp7V4Zv6OnuHLzLmVLOVOhTCnOX72Fva2t3vaK2GWUJyWn5iqup4t1E5OSc9VOYZEeGQ2ASR7+nFBHRBI1dRZRU3XPd7L08gQg5dJVPa8qHGSKR+EJyrRp01i8eDHVqlXDy8sLBweHl7+ogNCkpBG+cr/ea7a1K2HnUYlo/8sk3QzJ0vSPZamMFe+GmCKwr+8GQNK98Fy3VRCoU9K4v9JP7zW72hWx96hElP9lEm+GEv3M9I+pvQ0eaz7DwcuNuzM2cHty9kdf4q8EPbfvEr6NMbOz5v6qA6DRkBoZq7deYZPyZCrm2a3Ez4qMzkigzc1MAXjNsxbb/P7hxt3gTHVTUlK5F5rx79jVJXeLW89dydi+X6bkq3c2jT6WtTMSyDQ9u2zymm2n1wFI2KP/e6cwkARF4QnKpk2baNWqVaG8IaA6KYUb4+frvVZ2Qm/sPCrx8M+/dY66t6tTmbjAm5nqmxUrQvnPBgAQuf+M7rWi9pgXtSc1Ipa0iH9/4D+vrRI9m1HctzHq5FQebT2ao/dW0KiTUrj6P/2fRYUJb2DvUYnwP//WOerezMEWj7WfU6RuFW5PWcvd6S9f3Gr+zGeR+uSziDx0nshD5/XWd2pWGzM7a65NWPBKHXVfr5Y7q7ftY/3uv3ijYytcihfVXjt8MpCAS9extDDHs0bGcfdtmnoxc9mf7D7kT/+ubantXllbf/7qzcTGJ+BdpzrFizpqy2PjE3gYEYW9rQ0lnim/dOMO1SqVw8REd6rIP+Aif2zaA0Dn1q/O+TRmFcqRHvYATVKSTrlpKRccPxwDQPwu/X9oZZXK1hbT4kVRx8WjfhzxzAUVKitLNIm6fdt0aINNx7YkB14g8e9/ctW3UDZFJygJCQm0aNHC2GEoRuVpIzB3sic24AYpIY/QqNVYli2BU+t6mNpY8njnccJXH9B5Tam3O1BuQm/u/fynzomx7ovGo0lLJy7wFin3H2NiaY6dZxXs61VFnZrGzY8WkBz0ML/fYoFRc+kEitStQuLtMFQmKipMeCNTnUe7TupMybi+3Z4KH/bmztQ/ufPzunyMtmB5vWkDNtStiX/ARXyHfYxPYy+KOTlwOyiUgyfOotFoGDukD45FMqZcbKys+O5/7zL6q+kMnjCJNk28cC7mxPmrNzlz8RpFHYvw5Zi3dfrwO3qKL6b/Rtc2Tfl+/HBt+dSFK7kXGk6d6lW0idH120EcD7wEwOhBPfGs4ZZPXwnjs2nbEvs33yA54BzpYeGo4xMxK1MaqyYNMbGyJPGIP7ErdE+ith/cF/MKGWfGmLtlJIu2Xdph6ZlxplDy2Qs69+WxbtWUYl99RPz2PUR885O2XGVlSek960k6fob0kNCMn3d1amHpUZPUW3d59Mm3hXqXYeF9Z1mn6ASlevXq3LuXvYPHCrPQeVsp2sEbu9qVMG9VB5W5GWkRsUT/c4GH6w/yaEvWRzzClu/FsVltijRwx6yoPSqViuSwCMLXHCB04Q4SZMHmC1mXzRjmt65Ykgof9tZbJyno4SuxZsTQTExMmPvtBNZs28/ug/74HT1FUnIKDva2NGtQhze7vk7j+rqHGTauV5vVM79mwaot+AdcJDYhgeJOjvTu2Jrh/bvhXMwpS3138WmC39HTXLx2myOnzpGWlk4xJwfaNW9Ivy5tqV/LPS/esmIlnz6LWfmyWLh
XwbJOLVTWVqhj40gJvED8zn0k7NyX6TVWjRpgVd9Tp8yyTi0s6/x76GFWbhyoSUklYe/fWHrWwqphPQDSgkKImrOYuNUb0CQX7rVAcpIsqDT/vbmNgpw6dYpRo0bx22+/4eGRNyvFn/qnZK88bV9kXWpB3hdXyDQ+9qGxQxBPhPeZ+PJKIt+UPZm3619mlBtgsLbG3VthsLbyk6JHUDZs2EDJkiXp27cvnp6elC1bNtPccGG5F48QQggh/qXoBGXTpk3a/z5z5gxnzpzJVEcSFCGEEIXNq7Ms/vkUnaBcuSIHIgkhhHj1KHbtRT5SdIIihBBCCOOIiIhg6tSpXLx4kfDwcBISEihRogR16tRh6NCh1KxZU6d+WloaS5YsYcOGDYSEhODo6IiPjw9jx47FySlrC9WfJQmKEEIIoTBK2MUTGxvL7du3ady4MaVLl8ba2pqQkBA2bdpE7969mT9/Ps2aNdPWnzhxIlu3bqVVq1a88847BAcHs3z5cs6cOcPatWuxsbHJVv+KT1BiYmJYv349gYGBREdHo/7PHWJVKhXLly83UnRCCCGE4SlhDUr58uVZsybzCdn9+vWjVatW/Pbbb9oE5dixY2zdupXWrVszb948bd2aNWvy/vvvs2TJEkaPHp2t/hWdoNy/f59+/foRFhaGvb09cXFxODg4EBMTg1qtxsnJCWtra2OHKYQQQrwyihcvjqWlJbGx/55OvmXLFgCGDBmiU7ddu3a4urqyZcuWbCcoir6b8cyZM4mKimLp0qXs2bMHjUbDjBkzOH36NEOHDsXW1pbVq/Xf3E0IIYQoqDQGfORWamoqERERPHz4kHPnzjF+/HgSEhJo2bKltk5gYCAmJiZ4enpmen3dunW5d+8eUVFR2epX0SMoR48epVevXjRq1IjIyEhtubW1NRMmTOD69etMmzaNn3766QWtCCGEEAWL2oD7eHx8fF543c/vxYfOnTlzhkGDBmmf29vb8+677zJq1ChtWVhYGE5OTlhYWGR6vYuLi7aOo6NjluNWdIISERGBu3vG0dJmZhmhJj9zvHHTpk2ZP1//Td6EEEIIkXvVqlVj6dKlpKSkcOfOHbZs2UJ8fDwpKSna381JSUk4ODjofb2lpaW2TnYoOkFxdHTUznHZ2dlhbm5OaGio9rpKpSI+Pt5Y4QkhhBB5wpCLZF82QvIyDg4ONG787128u3fvjq+vL0FBQSxatAgAKysrUlJS9L7+6cCClZVVtvpV9BqUihUrcuPGDSAjGalZsyabNm0iJSWFxMRENm7cSLly5YwcpRBCCGFYSlqD8l8ODg60bt2aw4cPExwcDEDJkiWJjIzUm6SEh4dr62SHohOUJk2asHfvXm32NXToUM6fP4+3tzeNGzfm8uXLvPXWW8YNUgghhDAwtQEfeeHpdE1MTAwAHh4eqNVqAgMDM9UNCAigXLly2Vp/AgpPUIYPH87Ro0e181dt2rRh9uzZNG3alObNmzN9+nR69Ohh5CiFEEKIwufRo0d6y4ODg/Hz88Pe3p7KlSsD4OvrC8CSJUt06u7du5eQkBDt9exQ9BoUlUqVaUVwmzZtaNOmjZEiEkIIIfKeEk6SXbBgAUePHqV58+aUKVMGgFu3brF582YSEhKYPHmydgChcePGdO7cme3btzNixAh8fHwIDg5m2bJlVKlSJdP5KFmhqARl8+bNOXpdt27dDBqHEEIIYUyG3GacU61atSI8PJw9e/YQERFBWloazs7OtGzZksGDB+Ph4aFTf/Lkybi5ubFx40a++eYbHB0d8fX1ZezYsdja2ma7f5VGozH+V+GJatWqoVKpyE5IKpWKy5cv57rvf0r2ynUbwjBSNQr400EA0PjYh8YOQTwR3meisUMQzyh7Mnc7Y17m8wr9DdbWpDurDNZWflLUCMrvv/9u7BCEEEIIo1PMyIERKSpB8fb2NnYIQgghhNEp4WaBxqboXTxCCCGEeDUpagRFCCGEEMpYJGtskqAIIYQQCiPpiUzxCCGEEEKBZARFCCGEUBhZJCsJihBCCKE4sgZFEhQhhBBCcSQ9kTUoQgghhFAgGUERQgghFEbWoEiCIoQQQiiORiZ5ZIpHCCGEEMojIyhCCCGEwsgUjyQoQgghhOLINmOZ4hFCCCGEAskIihBCCKEwMn4iCYoQQgihODLFI1M8QgghhFAgGUERQgghFEZ28UiCIoQQQiiOHNQmCYoQQgihODKCImtQhBBCCKFAMoLyRLUGD40dgngiLc7YEYinZrX81dghiCdGb//I2CGIfCRTPJKgCCGEEIojUzwyxSOEEEIIBZIRFCGEEEJh1BqZ4pEERQghhFAYJaQnd+7cYdu2bfzzzz8EBQURHx9P6dKlady4McOGDcPZ2VmnflpaGkuWLGHDhg2EhITg6OiIj48PY8eOxcnJKdv9K3qKJyoqiitXrjz3+pUrV4iOjs7HiIQQQohXw/r161myZAmlS5dm2LBhTJw4kTp16rBq1So6d+7MzZs3depPnDiRadOmUbFiRb788kt69OjB5s2bGTRoEAkJCdnuX9EjKD///DMXLlxg8+bNeq9PnDgRDw8Pvvnmm/wNTAghhMhDSrgXT7t27Rg2bBhFihTRlvXp0wdPT0++/PJLfv31V2bOnAnAsWPH2Lp1K61bt2bevHna+jVr1uT9999nyZIljB49Olv9K3oExd/fn1atWj33euvWrTl27Fg+RiSEEELkPY0B/5dTtWvX1klOnurUqRMAV69e1ZZt2bIFgCFDhujUbdeuHa6urtrr2aHoBOXBgweULl36uddLlizJgwcP8jEiIYQQ4tUWHh4OQPHixbVlgYGBmJiY4Onpmal+3bp1uXfvHlFRUdnqR9FTPFZWVoSFhT33elhYGGZmin4LQgghRLYZ8hwUHx+fF1738/PLVntPp3V69OihLQsLC8PJyQkLC4tM9V1cXLR1HB0ds9yPokdQatSowfbt20lKSsp0LTk5me3bt1O9enUjRCaEEELkHTUagz0Maf78+ezZs4c2bdrQvXt3bXlSUpLe5ATA0tJSWyc7FD38MHDgQEaNGsWQIUMYP368Nhm5fPky06dP5969e0yYMMHIUQohhBCGZcij7rM7QvI8y5cvZ8aMGXh7e/Pzzz+jUqm016ysrEhJSdH7uuTkZG2d7FB0guLj48PIkSOZN28eAwcO1Lmm0WgYPnw4bdu2NVJ0QgghxKth6dKlTJ48mUaNGjFv3jysra11rpcsWZI7d+6QkpKSaSTl6ZqVkiVLZqtPRScoAB988AGtW7dm69at3L17F4AKFSrQpUsXateubeTohBBCCMNT0r14Fi5cyLRp02jWrBlz5szRTtk8y8PDg1u3bhEYGEiDBg10rgUEBFCuXLlsrT+BApCgQMZWJ0lGhBBCvCo0Cjnqfv78+cyYMYNWrVrx66+/Pnedia+vL5s3b2bJkiU6CcrevXsJCQlhzJgx2e67QCQoQgghhMhfK1euZMaMGRQvXpy2bduya9cuneu2tra0adMGgMaNG9O5c2e2b9/OiBEj8PHxITg4mGXLll
GlSpVM56NkhaISlNmzZ6NSqRg5ciQmJibMnj37pa9RqVSMGjUqH6ITQggh8ocSTpI9f/48AI8ePeLTTz/NdN3V1VWboABMnjwZNzc3Nm7cyDfffIOjoyO+vr6MHTsWW1vbbPev0ihlHAmoVq0aKpWKwMBALCwsqFat2ktfo1KpuHz5cq77ftylRa7bEIaRFmfsCMRTK26WNXYI4onR2we+vJLIN5Ye7fK0/S7lOhusrW33thusrfykqBGUp1uhns5xGWprlBBCCCEKFkUlKK6uri98LoQQQrwKDHkOSkGl6JNkJ06cSGBg4HOvnzt3jokTJ+ZjREIIIUTeU+pJsvlJ0QnKpk2buHfv3nOvBwcHs3nz5vwLSAghhBD5QlFTPNmVkJAgNwsUQghR6Cho/4rRKO63e2hoKCEhIdrnt27d4uTJk5nqRUdHs3r1asqXL5+f4QkhhBB5TkknyRqL4hKUjRs3as9DUalUzJ8/n/nz52eqp9FoMDEx4YcffjBClEIIIUTekUWyCkxQ2rRpg6urKxqNhk8//ZTevXtTt25dnToqlQobGxtq165NqVKljBSpEEIIIfKK4hKUatWqaQ9oO3nyJD179qROnTpGjkoIIYTIPwV5942hKC5BedaPP/5o7BCEEEKIfCeLZBW+zfj06dOsWLFCp2zXrl34+PhQv359vv/+eyNFJoQQQoi8pOgEZf78+Rw5ckT7PDg4mA8//JCEhARKly7NihUrWLdunREjFEIIIQxPDmpTeIJy7do16tWrp32+fft2VCoVmzdvZtu2bTRp0oT169cbMUIhhBDC8DQG/F9BpegEJTIykuLFi2ufnzhxAi8vL1xcXABo1aoVd+7cMVJ0QgghhMgrik5Q7OzsiIqKAiAtLY2AgADq16+vvW5mZkZSUpKRohNCCCHyhlqjMdijoFJ0glK1alW2bNlCREQEa9euJSkpicaNG2uvh4SEUKxYMSNGKIQQQhiexoCPgkrR24zfeecdRo4cSZMmTQCoVauWzpqUI0eOUKNGDWOFJ4QQQog8ougEpXnz5ixfvpz9+/djb2/PgAEDtNciIiIoXbo03bp1M16AQgghRB4oyLtvDEXRCQqAl5cXXl5emcqLFi3K7NmzjRCREEIIkbckQSkACYoQQgjxqpGTZBWWoEycOBGVSsV3332HqakpEydOfOlrVCpVgb6jseOiNZi66L/hoTryMZGDemS+YGKCZZsOWLZqh2mFSqjMLVBHPibt+hUSVixGHRqco1hsx3yI1eudAYgc1h/1/ZActVNQFV+9BtOS+j+L9IjHPOqp57N4RpEJH2LdKePr9+jN/qSHZu3rZ1a5CpZNm2Lh5YVpqdKYFCmCOiqK1HPniF+7mrTr17P3RgqBZhP74OJRiaIVS2JV1J60pBRigx9xY+9pApbtIykqTlvXsYILVTs0oELz2jhVLIlNcQeSouO5H3CDM4t3E3Tscrb7L12/Kg3H+FKqXhXMrCyIuh3GhT8PErB0Lxr1q/eL49Dpi6zc+Tc3g8OIjk2guFMRalQqy6DOrajjXlFbL+xRJIs27ePSrSDuP4wgJj4BR3tbyrgUp3vr1+jUrAHmZqZZ6vPz2SvYevDEC+t413Jj0Vejc/XehHIpKkHZtGkTKpWKr7/+GlNTUzZt2vTS1xT0BAVAHRdL0tbMB85pkhIzV7aypsjn32Nepz5pN6+T7LcbUlMwKVYCsxq1MXUtm6MExbxBY6xe74wmIQGVjU1O3kahoI6LJUHP4X+aRD2fxTMsGjXGulNn1AkJmGTz62f/v/9hUaMmqVevkHz4EJrERMyqVMHKxwfLFi2I/vZrkg8fzlabBV39dzoQfuEOdw9fIOFxDObWlpSqV4XG/+tJ7f6tWO37NbH3IwBoMqEX1bo24tG1YG79FUhSVBxFK5Wictt6VHm9Pge++p2ApXuz3HfltvXouuAD0pJTubrNn6SoeCq3qUurrwZS2suN7SNn5dXbVqQZK7awdIsfjva2tGpQGyd7O+6FPeSvk+fZfzyQ70cPoHPzBgAEhT9i5+FT1K5anureHhSxsyE6Np4jAZf5cu4qth08yYIv3sPM9OVJSmtvD0o7F9V7bfuhkwSHP6Zp3eoGfa9KIlM8CktQrly58sLnhZUmPo7E1cuyVNdu1HjM69Qnbs7PJO/elrlCFr7x/0tVxAG7MRNIPuSHiVNRzGvXzXYbhYUmLo745cuy9RqVgwNFJkwg6YAfJkWLYuGZva9f0v79xHz/faYRF6s2bXD47AuKjJ/Aw2PHIC0tW+0WZLNqvkt6cmqm8iYfvsFrY3zxHtUVv8+XAXDn73OcnLedBxfv6tQt07AavVZ+QvNP+3FtxwniH0S9tF8LO2ten/IO6nQ1f/b5nvBztwH4Z9p6eq+eiHunhlzvcpKr2/xz/R4LgkeRMSzfeoBiDvasn/YJxRzstddOXLjG0G9mM2ftTm2C4ulWkSPLJmNionuCRWpaOiMmzeHkxev4HQ+kXeN6vExrbw9ae3tkKo+JT2DZFj/MzUzxbdkwl+9QuQryCbCGouhzUIQu08pVsWzZluRDfvqTE4D09Gy3azf6QwDi5/+Si+heXUXGZ3z9Ymb+kqPXJ27aqHc6KGn/ftKCgjBxcMSsUqXchFjg6EtOAK5tPw6AU0UXbdnF9YczJScAwcevEOR/GTNLc0rXr5qlft06NsCmuANXt/lrk5On8Rz5OWNkrc5Anyy/j4Iu9FEEao2G2lXL6yQnkDG9YmttSWTMv9Nt5uZmmZITAHMzU1o1yEg27t5/mKuYth86SVJKKj4N6+BUxC5XbQllU9QIyqtKZW6BRcu2mJZwQZOUSNqdW6RdDAS1WqeeZYs2ACQf8kNlY4u5d2NMijujiY0h9dyZHK0ZsfRpj0WjZsRM+hRNbIxB3k+BZm6BVZu2mLg8+Sxu3iL1XObP4imrdu2xataMqM8/RROTB1+/9CejJjlIPAujSm0yRqceXg7KUn11asbXTZ2Wta9f2cY1gYxRmf8KPn6F1IQkSteviqmFGekphX9Eq3zJEpibmXLhxj0iY+J0EoJTl24Qn5hM6waZRzn+Kz1dzZGASwC4lS+dq5g27D8GQK82jV9Ss2CTRbIKS1Bysm1YpVIxatSoPIgm/5gULYb9+M91ytLDQombOZm0C4HaMrOq1QAwdS6J3W+rMCniqL2mUatJ3rWF+IW/PveXaaZ+S7hg8+4Ykv/aS+rxf3L/RgoB02LFcPhM97NICw0l5qfJpAYG6pSbuLhgP3oMiXv3kvyP4b9+5tVrYFahIukPH5B2+/bLX1AIeQ3riLmtFZb21rh4VKKMtzsPLt3lxNznjCA+w961GOWa1CA1IYngE1mbLi5aOWORdMTt+5muadLVRAc9pLh7WRzKORNxIzR7b6YAcrC3ZeyArvy8fDPdxv1A6wYeONrbEBT+iL9PXaCRhztfDO+T6XWRMXGs3n0IjSbjv/3PXeVe2EM6Nq1PS6/aOY4n8Optrt8LpXwpZ7xrueXmrSmerEGRBMXokvfvIvXSOdLv3UGTmICpS2msO
nfHsl0Xinz9E9ET3iP9zk0AVA5OANi88x4p/kcyduw8eoiZW3XsRo3HqlN31NFRWVvPolJhN24iJCUSv2BmHr7DgiNx1y5Szp8j/c4d1AkJmJYqjU337lh37oLT5J+IGP0eaTczPgtUKhw+mYgmMZHYWYb/+qns7Sky8VMAYufMyXLSWdh4DeuIrbOj9vntvwLZPX4BiRGxL3ydqYUZnX59DzMrCw5+v5rk6IQs9Wdhbw1ASoz+RdHJsRnllkVenYXkAzu1wrVEMb6cu4oNfke15eVKlqBry4aZpn4AomLjmb9ut/a5SqVicJfWvN+/S65iWb8/o/+ebRrlqh1RMCgqQfHz8zN2CPkucc1ynefp924TP3c6mqRErLv3xab/EGJ/yPiLXqVSZdQJvkfcT99of2mlnTtD7I9f4vDLb1h1603iuhUvXVBp5fsG5rXrEvP1R2ji415Y91UR//t/Pos7t4mdMR1NYiK2ffpiO3gI0V9mfBY2vd7AwrMukZ98hCbOwF8/KyscJ/2AWdmyxK9eRfLBvw3bfgEy3ytjC6lN8SKUru9Gs0/6MHDX92waMo0HF+7ofY3KREWHX0bi2sCdK1uPcWrBjnyMuPBZsmU/s1Ztp3+H5vTr0JxijkW4HRLOr6u2MfHX37l6J4T/DfTVeU1FVxfOrfuV9HQ1DyKi8DtxjrlrdxJw5RZzJg7Hwd4223HExiey51hAoV8c+5RM8Shskayrq2uOHoVR0q6tAJjV/Hd+V/0kkUg9cTTTX9Tpd26iDr+PiY0tpmXKv7Btk9JlsBk4lKR9O0k9fdzAkRc+iVszPgsLj4zPwrRMGeyGDiVx105Sjhv462dlhdOPk7Hw8CD+z7XELVxg2PYLqIRHMdzYc4r1AyZj5WhHhxnD9dZTmajoOPM93Ds35Oo2f3Z+MC9b/aQ8GSGxKGKt97rlkxGW5JisjcgUdCcvXueXFVtp6VWLD9/qQRmX4lhbWlCjUllmfDgU56IO/L7tAMHhj/S+3tTUhFIlijKgU0u+GN6Hc9fvMGftzhzFsuPwSZKSU16ZxbFqNAZ75MbChQsZO3Ysr7/+OtWqVXvp/e/S0tJYuHAh7dq1o1atWjRt2pSvvvqKyMjIbPetqBEU8S9NdBQAKisrbZk6JAjca2gTlUyveVKusrR8Ydum5SqgsrDEqm1HrNp21FvHaeEqAGK+/4xU/yPZDb9QUT/9LKwzPguz8hlfP+sOHbHuoP/rV3xlxtcv6vPPSP4na18/lbU1jj9OwaJOHeJXr5LkRI/YkMdEXA/BuVYFrJ3sSIz893vBxMyUjr9mJCeXN/3DrnHzs32oWsTN+5SsU4miFUvx4PwdnWsqUxMcypYgPTWN6HsPDPF2FO/g6YsANKiVeReUtaUFtauUx+/EOS7fDqaMS/EXttXUM+MX28lLN3IUy9PFsW+0bZKj1xc0StlmPG3aNIoUKUL16tVJSEggIiLihfUnTpzI1q1badWqFe+88w7BwcEsX76cM2fOsHbtWmyycU5UgUhQLly4QGBgINHR0aj/M3JQ0NegPI+Ze8Y3c3rYv4v1Us6exrJ1O8zK69lyamaOSamM0aT08LAXtq0ODyNp73a91yy8GmFStBjJR/5CkxCP+iVtvQrMn/zFkB6a8Vmkh4WRuOM5X7/XGmFarBhJf/+FJj6e9LCsff1UtrY4TpmKRc2axP3xO/FLFhsm+ELI1iVjLdazPwtMzE3pMncMVdp5cXH9YXaPXwg5GCIPOnqRGj2aUKGlB1e2HtO5VqZhNcxtrAjyv/xK7OABSE3NeJ+R0fr/KIp4ssU4K6fDPoiIAsBMzzbklzl3/Q5X74ZQvpQzDWpmbcu4MIx9+/ZRrlw5AAYOHPjCBOXYsWNs3bqV1q1bM2/ev6OXNWvW5P3332fJkiWMHp31k38VnaAkJyfz/vvvc+jQITQaDSqVSjsv9/S/C3KCYlqmPOkPwyE5SafcxLkktiPGApDy9z5tecrRg6gHv4tF01aYbdtA2vV/dyZY9x2EiZ09qYFn0ET9+w9IZWOLqmgxNPFxaCIzytNv3yB+1lT9Mf3wCyZFi5Hw+2+v1FH3puXKk/4gHJL+81m4lMT+/bEAJO7P+CzSbt4g5mf9Xz+nGb9gWqwYcb/9lulsE1URB0wcHFBHR6OJif633M4Op6nTMK9WjbilSzKthXnVOFUsSfyjaO10i5ZKRZMJvbAt4UDIqWvaha+mFmZ0XTCWSj6enF/9N3s/WfzS5MTC3ho7Z0eSYxN1DnC7tvMkzSb2xb3LawQs26s9C8XU0pymE3oBEPjHq7NWrl71SqzefYj1+4/Sq20TXIo5aq8dDrjE2au3sTQ3x9M944+mS7eCcC/viqmpbhKSkJjMlKUbAWhWr6bOtdj4RB5FxWBnY0UJJwe9cWzYl7E4tlfbwr21+FlqhaxBeZqcZMWWLVsAGDJkiE55u3btcHV1ZcuWLYUnQZk7dy6HDh1i+PDhNG7cmEGDBjF58mScnJxYuHAhqampTJkyxdhh5phFs1ZYd+tD6sVA1A/D0SQkYFKqNBZejVBZWpJy8hiJm9b8+4LkJOJ+mYz9lz9SZMosUo4eRh3xEDO3GpjX9EAdGUHcnJ91+2jUDLuxE0ny20X8L5Pz+R0WHFatWmHTuw+p5wJJD8/4LExLl8bytYzPItn/GAlr17y8oRew6d4du7eGELdsqc5ptY7fTsK8WjXSQoJBpcJ28FuZXpt85AhpN3M2NF7QVGxdh6Yf9yH05FWigx6SGBmHbXEHyrxWDcfyLsQ9iGLvx4u09dv88DaVfDxJeBxDXHgEjcZ2z9Rm0LHLBPv/e0+equ28aD99OBfWHWLP+IXa8pS4RPZ9vJgu89+n99rPuLrVn6ToOCq3qUfRKqW5uuP4K3OKLEDb1zx5rbY7/uev0m3c97T2rkNxR3tuBYdz6MxFNBoNH7zZBccni14XrN/N2Su3qONekVLFi2JlaU7YoyiOnL1EbHwinu4VGdq9rU4fB06c44u5K+nawptJowdkiiEuIZHdR89gYW5G1xbe+fK+lcCQUzw+Pi8+XNBQG1QCAwMxMTHB09Mz07W6deuyfft2oqKicHR0zFJ7ik5Qdu/eTdu2bRk7dqx2gY2LiwuNGjWicePG9OjRgy1btjB27FjjBppDqecDMHUth1nlKphVr4XKyhpNfBypl86T/NdeUv7ak/k1Z08R/b8RWPcdhLlnfVQ2tqijIkjauYWEtcvRRDw2wjsp+FLOBmBWthxmVatgXuvJZxEXR8r58yTt20vS3syfhaGYlioJgJlrGezeGqK3TnpY2CuToNw9fBHH8n/j2sAd55oVsCxiQ2pCMpG3wzi6cSMBS/aQFB2vre9QtgQANsWK0Gis/hs6HmWjToLyIjf2nmZt70k0HO1L1Y4NMLU0J+pOOH99s4KApXn370CJTExMmPPpCNbsOcTuf85w4EQgScmpFLGzoVndGvTv2JzGdf69H05Pn8bYWFly4cZdTl28QVJKCva2NtSoVJZ2jerSrfVrWboPz7N2HD5FYnIK7ZvUeyUW
xxZkYWFhODk5YWFhkemai4uLtk5WExSVRsF7mWrXrs0nn3zCm2++SXR0NA0bNmThwoU0b94cyFhd/Oeff7J///5c9/W4S4tctyEMI012PSvGiptljR2CeGL09oHGDkE8w9KjXZ62X93ZcKNFlx+8+K7QWTVw4EBOnz7NpUuX9F6vXr06Li4u/P3335muzZw5k7lz57J27Vq9Iyz6KHoE5dnVvra2tpiYmOgs0HF0dOTBg1djNb0QQohXh1J28WSHlZUVKSkpeq8lJydr62SVos5B+S9XV1fu3bsHgJmZGRUqVODgwYPa60eOHKFEiRLGCk8IIYQQT5QsWZLIyEi9SUp4eLi2TlYpOkFp2LChzvRNt27d2LVrFwMHDmTAgAHs27ePTp06GTFCIYQQwvDUGo3BHvnFw8MDtVpN4H/uWwYQEBBAuXLlsrz+BBSeoAwZMoSvv/5am40NHTqUQYMGce3aNW7evEnfvn2ztWVJCCGEKAg0BvxffvH1zbjlwZIlS3TK9+7dS0hIiPZ6Vil6DYqzszPOzs7a5yYmJnz66ad8+umnRoxKCCGEeDVs3ryZ0NCMO3eHhISg0WiYO3eu9vp7772n/e/GjRvTuXNntm/fzogRI/Dx8SE4OJhly5ZRpUqVTOejvIxid/HExcXx3nvv0aVLF954440870928SiH7OJRDtnFoxyyi0dZ8noXT+Xi9QzW1s1HZ3L82oEDB3LixPN3AV29elXneWpqKkuWLGHjxo2EhITg6OhI69atGTt2LEWLFs1W34odQbGzs+PcuXN06ZK723MLIYQQBY1SdvH88ccf2apvbm7O8OHDGT5c/w09s0OxCQpA1apVCQoKMnYYQgghRL7SaNQvr1TIKXqR7LvvvsuaNWu4efOmsUMRQgghRD5S9AjKtWvXcHV1xdfXl1atWlG+fPlMh7wU5JsFCiGEEPqoFTLFY0yKTlBmz56t/e99+/bprSMJihBCiMJGoftX8pWiExRD3WFRCCGEEAWLohMUV1dXY4cghBBC5DuZ4lF4giKEEEK8imSKR+EJysSJE19aR6VS8cMPP+RDNEIIIYTIL4pOUDZt2vTSOpKgCCGEKGzy8yZ/SqXoBOXKlSuZytLT0wkKCmLRokXcuHGDRYsWGSEyIYQQIu8o5SRZY1L0QW36mJqaUqFCBSZNmoSdnR3Tp083dkhCCCGEMLACl6A8q2XLluzZs8fYYQghhBAGpdFoDPYoqBQ9xfMyiYmJxMbGGjsMIYQQwqBkm3EBTlDOnz/P77//jpubm7FDEUIIIQyqII98GIqiExQfHx+95dHR0cTHx2NmZsaUKVPyOSohhBBC5DVFJyilS5fOVKZSqahZsyYVK1akT58+eusIIYQQBZlsM1Z4gvLHH38YOwQhhBAi38kUTwHfxSOEEEKIwknRIyhCCCHEq0h28SgsQalWrRoqlSpbr1GpVFy6dCmPIhJCCCHyn0zxKCxB6datW7YTFCGEEEIUPopKUCZPnmzsEIQQQgijk108CktQnicoKAg/Pz/u3r0LQPny5fHx8aFs2bJGjkwIIYQwPLlZYAFIUGbOnMnChQtJT0/XKZ86dSpDhw5l3LhxRopMCCGEEHlF0QnKihUrmDdvHh4eHgwZMoQqVaoAcP36dZYuXcrChQspUaIEAwYMMHKkQgghhOHIFE8BSFBq1arFypUrMTc315ZXrVqVNm3a0LdvX1asWCEJihBCiEJFdvEo/KC2kJAQOnfurJOcPGVhYUGXLl0ICQkxQmRCCCFE3tEY8H8FlaITFGdnZ1JSUp57PTU1FRcXl3yMSAghhHi17N27l969e+Pp6UmDBg0YMWIE165dy/N+FZ2g9OzZkw0bNhAXF5fpWmxsLBs2bKBnz55GiEwIIYTIOxqNxmCP3Fi3bh1jxowhMTGRCRMmMGLECK5evUrfvn25evWqgd6tfopeg1K3bl38/Pzo0qUL/fv3p3LlygDcuHGD1atXU6xYMTw9PTl58qTO6xo0aGCMcIUQQgiDUMIalOjoaCZPnkzJkiVZvXo1dnZ2AHTo0IFOnTrx/fff8/vvv+dZ/4pOUIYMGaL972nTpmlPmX36wd2/f5+3335bW0ej0aBSqbh8+XL+BiqEEEIUMn5+fsTFxTFkyBBtcgJQunRp2rVrx6ZNm7h//z6lSpXKk/4VnaD8+OOPxg5BCCGEyHeGHD/x8fF54XU/Pz+95YGBgUDGbMZ/1a1bl02bNnH+/PlXM0Hp3r17vvVVbNvBfOtLiIJivLEDEOIVlZZiuB2qL0tQnic8PByAkiVLZrr2tCwsLCzngb2EohMUIYQQQuTO80ZIXiYxMRHIONbjv56WJSUl5Tywl1D0Lh4hhBBCGIe1tTWA3uM+npZZWVnlWf+SoAghhBAik6fnjOmbxnlapm/6x1AkQRFCCCFEJh4eHgAEBARkunb27FkAateunWf9S4IihBBCiEzatGmDra0t69at0zkwNTQ0lN27d+Pt7Z1nO3gAVBolnAYjhBBCCMVZs2YNX331FW5ubvTp04eUlBRWrFhBZGQkq1evplq1annWtyQoQgghhHiu3bt3s3jxYq5du4a5uTleXl6MHTs2T5MTkARFCCGEEAoka1CEEEIIoTiSoAghhBBCcSRBEUIIIYTiSIIihBBCCMWRBEUIIYQQiiMJisIcP34cd3d3Nm7c+NK6s2bNwt3dneDg4HyITLxI69atGThwoLHDKBSy8z0gsu+TTz7B3d3d2GEAEBwcjLu7O7NmzdIpd3d355NPPsmXvoRyyd2MhVCAmJgYli9fjre3Nw0bNszy644fP86gQYN0yqytrSlbtiwdOnTgnXfewdLSEoA7d+6wbds2/vnnH4KCgoiPj6d06dI0btyYYcOG4ezsbND3JER+u3z5Mvv376d79+6UKVPG2OGIXJIERQgD2L17d65eHxMTw+zZsxk9enS2EpSn2rVrh4+PDwCPHz9mx44dzJw5kzNnzrBo0SIA1q9fz8qVK2nVqhUdOnTAysqKs2fPsmrVKrZu3crq1aupXLlyrt6HEIZw7tw5TEyyP8B/+fJlZs+ejbe3d6YExdXVlXPnzmFqamqoMEUekwRFFEgajYaEhARsbW2NHQoAFhYWRu2/WrVq+Pr6ap8PHDiQXr16cfjwYc6dO4eHhwft2rVj2LBhFClSRFuvT58+eHp68uWXX/Lrr78yc+ZMY4QvCoG4uDjs7OwM0tbTUT9DUqlUedKuyDuyBqUAiIuLY9KkSTRt2hQPDw+6d+/Orl279NZ9Or8cFRXFZ599RqNGjahbty5vv/02d+7cAcDPz4+ePXtSp04dmjVrxoIFC/Lx3WTfxo0bcXd35+jRoyxYsIB27dpRu3ZtlixZAsCePXsYMGAA9erVw8PDg27durFu3bpM7TxdJ3LlyhXefvtt6tatS/369Rk9ejT37t3TqatWq5k/fz4DBw6kadOm1KpVi2bNmvHxxx8TGhr63Lb1ld2+fZuRI0dSv3596taty7vvvsvdu3d13t/T0Y/Zs2fj7u6Ou7s7rVu3zvHXzNzcnMa
NGwNo31vt2rV1kpOnOnXqBMDVq1dz3F9eyu5n8XT9wsmTJ+nXrx+enp40btyYqVOnkp6eTkpKCj///DMtWrSgdu3a9OzZU3tn1metWrWKd955h+bNm1OrVi0aNWrEmDFjuHbtWj686+x59ntk7ty5tG7dmlq1atGuXTv++OOPl75+4MCBz/339t/1IM+u5dizZw+9evWiTp06jBw5EoDw8HCmTJlC9+7d8fb21sYxY8YMkpKSsvR+9K1BOXToEIMGDaJRo0bUrl2b5s2bM3ToUE6dOgVk/OybOHEiAIMGDdJ+Hz1t50VrUPz8/Hjrrbdo0KABtWvXxsfHh88++4yIiIgsxSvyhoygKFxaWhrvvvsuZ86coW3btjRq1IjQ0FA+/fRTKlas+NzXDR06lOLFizN69GgePHjA0qVLefvtt/nggw+YMmUKffv2pUePHuzcuZPp06fj6upK586d8/GdZd9PP/1EYmIi3bp1o2jRopQsWZJff/2VOXPm0LBhQ0aPHo2lpSVHjhzh888/5+7du0yYMEGnjbCwMAYNGkTr1q358MMPuXXrFmvWrCEgIICNGzfi4uICQGpqKr/99huvv/46LVq0wN7enqtXr7JhwwaOHTvG1q1bcXR0fGnM4eHhDBgwgNatWzNhwgTu3r3LihUreO+999i2bRsmJiY0aNCAiRMn8uOPP9K2bVvatm0LkOvRodu3bwNQtGjRl8YIULx48Vz1l1dy8llcunSJAwcO0LNnT7p27crBgwdZtGgRpqamXLt2jbi4ON5++20SExNZunQpw4cPx8/PT2cEYNGiRdSpU4c333wTJycn7ty5w/r16/nnn3/YvHkz5cqVy+evxMv9/PPPxMXF0bt3bywsLNi+fTuTJk3i0aNHjBs3zqB9+fn5sXz5cvr27Uvv3r15eteUq1evsmfPHnx8fOjZsycajYYTJ06wYMECLl26xG+//Zbtvk6ePMmIESOoXLky77zzDo6Ojjx69IiAgAAuXbqEl5cXffr0wcLCgrVr1zJixAgqVaoE8NLPaebMmcydO5dy5coxYMAASpYsSWhoKH/99Rfh4eEv/f4ReUgjFMXf31/j5uam2bBhg0aj0WjWrVuncXNz00yaNEmn3pkzZzTu7u4aNzc3TVBQkLb8448/1ri5uWm++OILnfpLly7VuLm5aTw9PXXqJycnaxo3bqzp06dPHr6r3NmwYYPGzc1N06ZNG01cXJy2/OLFixp3d3fNd999l+k13377raZatWqae/fuactatWqlcXNz0yxatEin7t69ezVubm6ajz/+WFumVqs1CQkJmdr9559/NG5ubprffvtNp7xVq1aaAQMGZCpzc3PTbNu2Tad8wYIFGjc3N83hw4e1ZUFBQRo3NzfNr7/++qIvRSZP/71MmzZN8/jxY83jx481169f10ydOlXj5uamad26tSY5OfmFbYwZM0bn35yx/fd7ILufhZubm8bd3V1z5swZnXJfX1+Nu7u7ZtiwYRq1Wq0t37dvn8bNzU2zZs0anfrx8fGZ+rx27ZqmZs2amq+//jrH7y8vPP0ead68uSY6OlpbnpycrOnVq5emWrVqmrt372o0mn9/RjxrwIABmlatWult+7/fG0//rdaoUUNz9erVTPUTExM16enpmcqnT5+ucXNz0wQGBmZq67//7v/b5w8//KBxc3PTPHz48EVfBu3Xwd/fP9M1fX0FBgZq3NzcNL1799b7eet7HyL/yBSPwu3duxeAESNG6JTXrVuXRo0aPfd1b7/9ts5zb29vIGPa4dnFYxYWFnh4eGj/2layN998U2dUYdu2bWg0Gnr16kVERITOo3Xr1qjVao4eParThq2tbaapmLZt21K5cmX27duHWq0GMuarra2tgYwphpiYGCIiIqhWrRr29vacO3cuSzE7OztnGpl6OvXydMrNEBYsWECjRo1o1KgRnTp14rfffqNhw4YsWbLkhetj5s+fz549e2jTpg3du3c3WDyGlJPPwtPTk7p16+qUeXl5odFoGDRoECqVSlveoEEDIPPnYWNjA2Ssd4qLiyMiIoJixYpRsWJFAgMDDfkWDaZ///4603gWFhYMGTIEtVrN/v37DdpXixYtcHNzy1RuZWWlXeCamppKVFQUERERNGnSBCDL3zvPsre3BzIWo6empuYial3btm0D4H//+5/2835WThbqCsORKR6Fu3fvHk5OThQrVizTtSpVqmT6BfxU2bJldZ4//aH133IABwcHoqKich9sHvvvlNbNmzcBdBaH/tejR490npcrV07vL+wqVapw8+ZNIiIitFMd+/fvZ9GiRVy4cCHTD8Wsfr30fb2fTkdktY2oqKhM/RctWlRnN0KPHj3o0qWLdiFghQoVXjo0vXz5cmbMmIG3tzc///yzzi9tpcnuZ/G8f+f6rj0t/287J0+eZM6cOQQEBGRaO6HULaz6dmFVqVIFQGfdkyFUqFBBb3l6ejqLFy9m06ZN3LlzR5v0P5WTnzUDBgzgr7/+4rvvvmPatGl4enri7e1N586d9X7WWfU0Ka1Zs2aO2xB5RxKUQup5W+kK8hY7KysrnedPf/AtWLDguaMEOf3htX//fkaNGkWtWrWYOHEipUqV0vY/btw47Xz7y7zo653VNsaMGcOJEyd0yvz8/HR+SZYtW1Y7MpMVS5cuZfLkyTRq1Ih58+ZpRyiUKCefxYu+7s/7q/jZdi5cuMBbb71FmTJlGDduHGXKlMHa2hqVSsX3339PYmJiLt9VwZCWlvbca8/7NzNlyhSWL19Ou3btePfddylWrBjm5uaEh4fzySefZPnf/bMcHR1Zt24dZ86c4dixY5w6dYo5c+YwZ84cfvrpJzp27JjtNoXySYKicOXKleP27ds8fvw40yjKjRs3jBSVMlSoUIHDhw9TokSJLP8FdO/ePVJSUjIlNDdu3MDOzk476rB582YsLS1ZsWKFzg/ihIQEYmJiDPcmnnjR6MXHH3+cqc8SJUrkuK+FCxcybdo0mjVrxpw5cxS/9TK/PwvIGPpPS0tj0aJFmZLcqKgoxX7Nbt68SZs2bXTKnv6cKF++/HNf5+joyMWLFzOVBwUFZTuGzZs34+Xlxa+//qpTfvDgwWy39SwTExO8vLzw8vIC4P79+3Tv3p2ff/5Zm6BkdxSwQoUKHDp0iEuXLmmnwYVyyASbwj3d0TF//nyd8oCAAI4dO2aMkBTj6dTO9OnT9c5Lx8bGkpKSolMWHx+fadvlvn37tD/Yn/51bWJigkqlyjQ8PXfu3ExlhvB0/js6OjrTtVq1atG4cWOdR05/Qc6fP59p06bRqlUr5s6dq9hftM/K78/iaZ+QeZRr9erVmaYNlWTVqlU6SVtKSgpLly7FxMREu5Vdn4oVKxIfH59pfcjixYuzHYOJiUmmr1tqaioLFy7MdltP6dvuW6pUKYoXL05kZKS27EXfR/p06dIFyPgZom8LdE5Ge4ThyAiKwnXv3p0NGzbw+++/ExYWxmuvvcb9+/dZuXIlNWrU0PtXz6uidu3ajB07ll9++YXOnTvTuXNnSpYsye
PHj7l27Rp+fn7s2LFDZyqkXLlyLFiwgBs3buDh4cHNmzdZs2YNRYsWZezYsdp67du3Z8+ePQwcOJDu3buj0Wg4cuQIN27cwMnJyeDvxcnJifLly7Njxw7Kli1L8eLFsba2ztVZKP+1cuVKZsyYQfHixWnbtm2ms3RsbW0z/fWtBPn9WQC8/vrrLFu2jHfffZfevXtjZWXFmTNnOHLkCOXKlSM9PT1P+s2tYsWK0atXL3r27Im5uTnbt2/n4sWLDBs27IUjKH369GHJkiW89957DBo0CGtra/7++2/i4uKyHUP79u1ZvXo177//Pk2aNCE6Oppt27blKhn+4osvuH//Pk2aNMHV1ZX09HT++usvrl+/zoABA7T1ateujYmJCfPnzyc6OhobGxvKlClDnTp19Lbr4eHBiBEjmD9/Pl27dqVz586UKlWKsLAw/Pz8+PHHH6levXqO4xa5IwmKwpmZmbFo0SJmzJjB7t27+fvvv6lcuTI//PADN27ceKUTFICRI0dSq1Yt/vjjD1asWEF8fDxOTk5UrPj/9u4upKk3jgP4d2aCTSVEa+4gEsSZEzNcJbki0yxrFaJBmL2YXXRhCEKCOhBakMKUBLeLxIt0kVaQBTEJwrxogVMpW+VFEZUvCUbvajTN878QR2Orlv3LM/p+wJvnd855nochfHme52yrUFZW5rMVolKpYLFYYDabYTaboVAosGXLFlRUVCAuLs5zncFgwNTUFFpbW1FXVwelUgm9Xo+2tjYUFhb+kbnU19ejpqYGDQ0N+Pz5MwRB+F8DysOHDwHMHRw2Go0+dUEQZBlQFuOzSE1N9ZxxsFgsCAsLg06nw8WLF2EymTA6OvpH+v1d5eXluH//Pi5fvozx8XEIggCj0YiioqIf3icIAs6dO4ezZ8+isbERkZGR2LFjB8rLyz1bKoGqrKxEREQEOjs7cfv2bc+bbLm5uQs+K5Kbm4vr16/jxo0bePPmDcLDw5GQkACTyYT9+/d7rlOr1aipqUFzczNMJhOmp6eRl5f33YACzJ1jSkpKwoULF9DS0oKZmRmsWLEC6enpUKlUCxov/T8UEtew6B+RlZUFQRAC+mZNomDS0dGBqqoq2Gy2Bf2WE5Ec8QwKERERyQ4DChEREckOAwoRERHJDs+gEBERkexwBYWIiIhkhwGFiIiIZIcBhYiIiGSHAYWIiIhkhwGF6B9lsVig0WjgdDoXeyhERD4YUIhkYGRkBBqNBpWVlYs9FCIiWeBv8RD9ow4ePAiDwQC1Wr3YQyEi8sGAQvSPio6ORnR09GIPg4jIL27xEC0yi8WCbdu2AQCuXbsGjUbj+evo6AAAOJ1OaDQaWCwWuFwuHD9+HGlpadBoNBgZGQEA9PT0oLq6GgaDATqdDikpKdizZw+sViu+fPnit19/Z1A0Gg0OHz6Mt2/forq6Gps3b0ZycjJ2796Nq1ev/vL8Hj16hNLSUqSnpyM5ORmZmZk4deoUxsfHfa6trKz0zOnSpUvYu3cv1qxZA71ej+rqanz69OmX+yei4MQVFKJFlpaWhiNHjsBmsyExMRHZ2dmemlar9bp2YGAATU1NWLduHfbt24d3795h6dKlAIDm5mY8f/4cqampyMjIgNvtxr1792CxWOB0OtHS0oIlS5YENKaPHz/iwIEDCAsLQ05ODtxuN27evAmj0YiQkBDk5eUF9Jzu7m6UlpYCAHJycqBWq/H48WO0t7ejq6sLbW1tiI+P97mvrq4ODocDmZmZ2LRpE5xOJ65cuYKXL1/CZrMF1DcRBTmJiBbd8PCwJIqiVFFR4bfe09MjiaIoiaIotbe3+71maGhImp2d9WlvaGiQRFGU7Ha7V3tjY6MkiqLU09Pj1T7fj9FolGZmZjztT58+lbRarbRr166A5jQxMSGlpaVJiYmJUl9fn1etqalJEkVRKi4u9mqvqKiQRFGUMjIypNHRUU/79PS0VFhYKImiKD148CCg/okouHGLhyiIaLVaFBQU+K3Fx8dDoVD4tB89ehQAcOfOnYD7CQ8PR1VVldeKy+rVq6HT6fDs2TNMTk7+9BldXV14//49DAYD1q9f71U7duwYBEHA3bt38erVK597T5w44XV4NzQ0FPn5+QAAl8sV8DyIKHhxi4coiKSkpHy3NjU1BZvNhlu3buHFixeYnJyE9M1vgfo78/E9CQkJiIiI8GlXqVQA5raAlErlD58xODgIANi4caNPLTQ0FBs2bMDo6CgGBwd93iRKTk72uScuLg4A8OHDh8AmQURBjQGFKIjExMT4bZ+enkZRURFcLhdEUYTBYEB0dDRCQ+f+xa1WK9xud8D9REVF+W2ff97Xr19/+oz5A62xsbF+6/Pt/g6+RkZG+rTNr+bMzs7+tG8iCn4MKERBxN8WDjC3neJyuZCfn4/a2lqv2vj4OKxW698Ynpf5kPH69Wu/9fl2f2GEiIhnUIhkYH51IJCVCX+GhoYAANu3b/ep9fX1LXxgv2H+DaTe3l6f2szMDPr7+wEASUlJf3VcRBQcGFCIZCAqKgoKhQJjY2MLul8QBAC+YWB4eBj19fW/Pb6FyM7OxvLly2G32zEwMOBVa21txcjICPR6Pb/Jloj84hYPkQwolUqsXbsW/f39OHnyJFatWoWQkBBkZWUhMTHxp/dnZmYiISEB58+fx5MnT6DVajE2Nobu7m5s3brV75syf5pSqcSZM2dQVlaGQ4cOYefOnZ7vQXE4HIiNjcXp06f/+riIKDgwoBDJhNlsRm1tLRwOB+x2OyRJgkqlCiigLFu2DK2traivr0dvby/6+/sRHx+PkpISFBcXo7Oz8y/MwFd2djba2trQ1NQEh8OBiYkJxMTEoKCgACUlJVi5cuWijIuI5E8hffseIhEREZEM8AwKERERyQ4DChEREckOAwoRERHJDgMKERERyQ4DChEREckOAwoRERHJDgMKERERyQ4DChEREckOAwoRERHJDgMKERERyQ4DChEREckOAwoRERHJDgMKERERyc5/7UmUbzpRZVAAAAAASUVORK5CYII=", 20 | "text/plain": [ 21 | "
" 22 | ] 23 | }, 24 | "metadata": {}, 25 | "output_type": "display_data" 26 | } 27 | ], 28 | "source": [ 29 | "import seaborn as sns\n", 30 | "import pandas as pd\n", 31 | "import matplotlib.pyplot as plt\n", 32 | "\n", 33 | "# Base path for score files\n", 34 | "base_path = \"results/\"\n", 35 | "runs_common_name = \"test\"\n", 36 | "\n", 37 | "# Initialize dataset names\n", 38 | "datasets = [\"ldm\", \"repaint-p2-9k\", \"lama\", \"pluralistic\"]\n", 39 | "\n", 40 | "# Initialize the data dictionary to store IOU scores\n", 41 | "data = {dataset: {} for dataset in datasets}\n", 42 | "\n", 43 | "# Function to read and process the scores from file\n", 44 | "def process_scores(file_path, dataset, test_datasets):\n", 45 | " try:\n", 46 | " with open(file_path, \"r\") as file:\n", 47 | " for idx, line in enumerate(file):\n", 48 | " if idx == 0:\n", 49 | " continue\n", 50 | " parts = line.split()\n", 51 | " other_dataset = parts[0][:-1]\n", 52 | " if other_dataset not in test_datasets:\n", 53 | " continue\n", 54 | " iou = float(parts[1])\n", 55 | " data[dataset][other_dataset] = iou\n", 56 | " except Exception as e:\n", 57 | " print(f\"Error processing {file_path}: {e}\")\n", 58 | " for d in test_datasets:\n", 59 | " data[dataset][d] = 0\n", 60 | " # Reorder the dictionary according to the test_datasets order\n", 61 | " data[dataset] = {key: data[dataset][key] for key in test_datasets}\n", 62 | "\n", 63 | "# Process each dataset\n", 64 | "for dataset in datasets:\n", 65 | " file_path = f\"{base_path}{dataset}_{runs_common_name}/scores.txt\"\n", 66 | " process_scores(file_path, dataset, datasets)\n", 67 | "\n", 68 | "# Calculate average ID and OOD scores\n", 69 | "avg_iid = sum(data[dataset][dataset] for dataset in datasets) / len(datasets)\n", 70 | "\n", 71 | "total_values = []\n", 72 | "min_ood = float('inf')\n", 73 | "\n", 74 | "for i, dataset in enumerate(datasets):\n", 75 | " for j, other_dataset in enumerate(datasets):\n", 76 | " if i != j:\n", 77 | " total_values.append(data[dataset][other_dataset])\n", 78 | " min_ood = min(min_ood, data[dataset][other_dataset])\n", 79 | "\n", 80 | "avg_ood = sum(total_values) / len(total_values)\n", 81 | "\n", 82 | "print(\"AVG IID:\", avg_iid)\n", 83 | "print(\"AVG OOD:\", avg_ood)\n", 84 | "print(\"MIN OOD:\", min_ood)\n", 85 | "\n", 86 | "# Create a DataFrame from the data dictionary\n", 87 | "df = pd.DataFrame(data)\n", 88 | "\n", 89 | "# Plot the heatmap\n", 90 | "sns.set(font_scale=1.2)\n", 91 | "heatmap = sns.heatmap(df, annot=True, fmt='.1f', vmin=0, vmax=90)\n", 92 | "heatmap.set(xlabel='train on', ylabel='test on')\n", 93 | "\n", 94 | "plt.show()\n" 95 | ] 96 | } 97 | ], 98 | "metadata": { 99 | "kernelspec": { 100 | "display_name": "py", 101 | "language": "python", 102 | "name": "python3" 103 | }, 104 | "language_info": { 105 | "codemirror_mode": { 106 | "name": "ipython", 107 | "version": 3 108 | }, 109 | "file_extension": ".py", 110 | "mimetype": "text/x-python", 111 | "name": "python", 112 | "nbconvert_exporter": "python", 113 | "pygments_lexer": "ipython3", 114 | "version": "3.10.14" 115 | } 116 | }, 117 | "nbformat": 4, 118 | "nbformat_minor": 2 119 | } 120 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import os 3 | import time 4 | from tensorboardX import SummaryWriter 5 | 6 | from validate import validate, validate_fully_supervised 7 | from utils.utils import 
compute_accuracy_detection, compute_average_precision_detection 8 | from data import create_dataloader 9 | from earlystop import EarlyStopping 10 | from networks.trainer import Trainer 11 | from options.train_options import TrainOptions 12 | from utils.utils import derive_datapaths 13 | import torch.multiprocessing 14 | 15 | def get_val_opt(opt): 16 | val_opt = deepcopy(opt) 17 | val_opt.data_label = 'valid' 18 | return val_opt 19 | 20 | 21 | if __name__ == '__main__': 22 | torch.multiprocessing.set_sharing_strategy('file_system') 23 | 24 | opt = TrainOptions().parse() 25 | if opt.data_root_path: 26 | opt = derive_datapaths(opt) 27 | val_opt = get_val_opt(opt) 28 | 29 | model = Trainer(opt) 30 | 31 | data_loader = create_dataloader(opt) 32 | val_loader = create_dataloader(val_opt) 33 | 34 | train_writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, "train")) 35 | val_writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, "val")) 36 | 37 | early_stopping = EarlyStopping(patience=opt.earlystop_epoch, delta=-0.001, verbose=True) 38 | start_time = time.time() 39 | best_iou = 0 40 | print ("Length of training dataset: %d" %(len(data_loader.dataset))) 41 | for epoch in range(opt.niter): 42 | print(f"Epoch {epoch}") 43 | epoch_loss = 0 44 | for i, data in enumerate(data_loader): 45 | model.total_steps += 1 46 | 47 | model.set_input(data) 48 | model.optimize_parameters() 49 | 50 | if model.total_steps % opt.loss_freq == 0: 51 | print(f"Train loss: {round(model.loss.item(), 4)} at step: {model.total_steps};\t Iter time: {round((time.time() - start_time) / model.total_steps, 4)}") 52 | epoch_loss += model.loss 53 | train_writer.add_scalar('loss', model.loss, model.total_steps) 54 | 55 | epoch_loss /= len(data_loader) 56 | if opt.fully_supervised: 57 | # compute train performance 58 | mean_iou = sum(model.ious)/len(model.ious) 59 | model.ious = [] 60 | print(f"Epoch mean train IOU: {round(mean_iou, 2)}") 61 | 62 | mean_F1_best = sum(model.F1_best)/len(model.F1_best) 63 | model.F1_best = [] 64 | print(f"Epoch mean train F1_best: {round(mean_F1_best, 4)}") 65 | mean_F1_fixed = sum(model.F1_fixed)/len(model.F1_fixed) 66 | model.F1_fixed = [] 67 | print(f"Epoch mean train F1_fixed: {round(mean_F1_fixed, 4)}") 68 | 69 | mean_ap = sum(model.ap)/len(model.ap) 70 | model.ap = [] 71 | print(f"Epoch mean train Mean AP: {round(mean_ap, 4)}") 72 | else: 73 | model.format_output() 74 | mean_acc = compute_accuracy_detection(model.logits, model.labels) 75 | print(f"Epoch mean train ACC: {round(mean_acc, 2)}") 76 | mean_ap = compute_average_precision_detection(model.logits, model.labels) 77 | print(f"Epoch mean train AP: {round(mean_ap, 4)}") 78 | 79 | model.logits = [] 80 | model.labels = [] 81 | 82 | # Validation 83 | model.eval() 84 | print('Validation') 85 | if opt.fully_supervised: 86 | ious, f1_best, f1_fixed, mean_ap, _ = validate_fully_supervised(model.model, val_loader, opt.train_dataset) 87 | mean_iou = sum(ious)/len(ious) 88 | val_writer.add_scalar('iou', mean_iou, model.total_steps) 89 | print(f"(Val @ epoch {epoch}) IOU: {round(mean_iou, 2)}") 90 | 91 | mean_f1_best = sum(f1_best)/len(f1_best) 92 | val_writer.add_scalar('F1_best', mean_f1_best, model.total_steps) 93 | print(f"(Val @ epoch {epoch}) F1 best: {round(mean_f1_best, 4)}") 94 | 95 | mean_f1_fixed = sum(f1_fixed)/len(f1_fixed) 96 | val_writer.add_scalar('F1_fixed', mean_f1_fixed, model.total_steps) 97 | print(f"(Val @ epoch {epoch}) F1 fixed: {round(mean_f1_fixed, 4)}") 98 | 99 | mean_ap = sum(mean_ap)/len(mean_ap) 100 
| val_writer.add_scalar('Mean AP', mean_ap, model.total_steps) 101 | print(f"(Val @ epoch {epoch}) Mean AP: {round(mean_ap, 4)}") 102 | 103 | # save best model weights or those at save_epoch_freq 104 | if mean_iou > best_iou: 105 | print('saving best model at the end of epoch %d' % (epoch)) 106 | model.save_networks( 'model_epoch_best.pth' ) 107 | best_iou = mean_iou 108 | 109 | early_stopping(mean_iou, model) 110 | else: 111 | ap, r_acc, f_acc, acc = validate(model.model, val_loader) 112 | val_writer.add_scalar('accuracy', acc, model.total_steps) 113 | val_writer.add_scalar('ap', ap, model.total_steps) 114 | print(f"(Val @ epoch {epoch}) ACC: {acc}; AP: {ap}") 115 | 116 | # save best model weights or those at save_epoch_freq 117 | if ap > best_iou: 118 | print('saving best model at the end of epoch %d' % (epoch)) 119 | print(ap, best_iou) 120 | model.save_networks( 'model_epoch_best.pth' ) 121 | best_iou = ap 122 | 123 | early_stopping(acc, model) 124 | 125 | if early_stopping.early_stop: 126 | cont_train = model.adjust_learning_rate() 127 | if cont_train: 128 | print("Learning rate dropped by 10, continue training...") 129 | early_stopping = EarlyStopping(patience=opt.earlystop_epoch, delta=-0.002, verbose=True) 130 | else: 131 | print("Early stopping.") 132 | break 133 | model.train() 134 | print() -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from PIL import Image 6 | from torchmetrics import ConfusionMatrix, AveragePrecision 7 | from sklearn.metrics import average_precision_score 8 | 9 | 10 | # Function to derive data paths (train/val/test + fake/real) for dolos dataset 11 | def derive_datapaths(opt): 12 | root_path = opt.data_root_path.rstrip('/') 13 | 14 | opt.train_path = f"{root_path}/fake/{opt.train_dataset}/images/train" 15 | opt.valid_path = f"{root_path}/fake/{opt.train_dataset}/images/valid" 16 | opt.train_masks_ground_truth_path = f"{root_path}/fake/{opt.train_dataset}/masks/train" 17 | opt.valid_masks_ground_truth_path = f"{root_path}/fake/{opt.train_dataset}/masks/valid" 18 | opt.train_real_list_path = f"{root_path}/real/train" 19 | opt.valid_real_list_path = f"{root_path}/real/valid" 20 | 21 | return opt 22 | 23 | # Localisation 24 | # IOU 25 | def compute_iou(pred_mask, gt_mask, threshold=0.5): 26 | if torch.is_tensor(pred_mask) and torch.is_tensor(gt_mask): 27 | pred_mask = (pred_mask > threshold).to(torch.uint8) 28 | gt_mask = (gt_mask > threshold).to(torch.uint8) 29 | 30 | intersection = torch.logical_and(pred_mask, gt_mask).float().sum() 31 | union = torch.logical_or(pred_mask, gt_mask).float().sum() 32 | 33 | iou_score = 100.0 * ((intersection + 1e-8) / (union + 1e-8)).item() 34 | else: 35 | pred_mask = (pred_mask > threshold).astype(np.uint8) 36 | gt_mask = (gt_mask > threshold).astype(np.uint8) 37 | 38 | intersection = np.logical_and(pred_mask, gt_mask).sum() 39 | union = np.logical_or(pred_mask, gt_mask).sum() 40 | 41 | iou_score = 100.0 * ((intersection + 1e-8) / (union + 1e-8)) 42 | 43 | return iou_score 44 | 45 | def compute_batch_iou(predictions, ground_truths, threshold=0.5): 46 | assert len(predictions) == len(ground_truths), "Both lists must have the same length" 47 | 48 | iou_scores = [] 49 | for pred_mask, gt_mask in zip(predictions, ground_truths): 50 | iou_score = compute_iou(pred_mask, gt_mask, threshold) 51 | 
iou_scores.append(iou_score) 52 | 53 | return iou_scores 54 | 55 | # F1 - inspired from Guillaro, Fabrizio et al. TruFor paper https://github.com/grip-unina/TruFor/blob/main/test_docker/metrics.py 56 | # adapted for GPU use 57 | def min_filter(tensor, size): 58 | padding = (size - 1) // 2 59 | tensor = F.pad(tensor, (padding, padding, padding, padding), value=float("inf")) 60 | tensor = tensor.unsqueeze(0) * -1 61 | result = F.max_pool2d(tensor, kernel_size=size, stride=1) 62 | return result.squeeze() * -1 63 | 64 | def max_filter(tensor, size): 65 | padding = (size - 1) // 2 66 | tensor = F.pad(tensor, (padding, padding, padding, padding)) 67 | tensor = tensor.unsqueeze(0) 68 | result = F.max_pool2d(tensor, kernel_size=size, stride=1) 69 | return result.squeeze() 70 | 71 | def extract_ground_truths(gt, erode_size=15, dilate_size=11): 72 | gt_eroded = min_filter(gt, erode_size) 73 | gt_dilated = torch.logical_not(max_filter(gt, dilate_size)) 74 | return gt_dilated, gt_eroded 75 | 76 | def compute_f1(fp, tp, fn): 77 | return (2 * tp + 1e-32) / torch.maximum(2 * tp + fn + fp, torch.tensor(1e-32)) 78 | 79 | def dynamic_threshold_metrics(preds, gt_dilated, gt_eroded): 80 | preds, gt_dilated, gt_eroded = preds.flatten(), gt_dilated.flatten(), gt_eroded.flatten() 81 | inds = torch.argsort(preds) 82 | inds = inds[(gt_dilated[inds] + gt_eroded[inds]) > 0] 83 | thresholds = preds[inds] 84 | gt_dilated, gt_eroded = gt_dilated[inds], gt_eroded[inds] 85 | tn = torch.cumsum(gt_dilated, dim=0) 86 | fn = torch.cumsum(gt_eroded, dim=0) 87 | fp, tp = torch.sum(gt_dilated) - tn, torch.sum(gt_eroded) - fn 88 | mask = F.pad(thresholds[1:] > thresholds[:-1], (0, 1), mode="constant") 89 | return fp[mask], tp[mask], fn[mask], tn[mask] 90 | 91 | def fixed_threshold_metrics(preds, gt, gt_dilated, gt_eroded, threshold): 92 | preds = (preds > threshold).flatten().int() 93 | gt, gt_dilated, gt_eroded = gt.flatten().int(), gt_dilated.flatten().int(), gt_eroded.flatten().int() 94 | gt, preds = gt[(gt_dilated + gt_eroded) > 0], preds[(gt_dilated + gt_eroded) > 0] 95 | cm = ConfusionMatrix(task="binary", num_classes=2).to(gt.device)(gt, preds) 96 | return cm[1, 0], cm[1, 1], cm[0, 1], cm[0, 0] 97 | 98 | def localization_f1(pred, gt): 99 | if not isinstance(pred, torch.Tensor): 100 | pred = torch.tensor(pred, dtype=torch.float32) 101 | if not isinstance(gt, torch.Tensor): 102 | gt = torch.tensor(gt, dtype=torch.float32) 103 | pred, gt = pred.float(), gt.float() 104 | gt_dilated, gt_eroded = extract_ground_truths(gt) 105 | 106 | # Best threshold F1 107 | try: 108 | fp, tp, fn, tn = dynamic_threshold_metrics(pred, gt_dilated, gt_eroded) 109 | f1_dynamic = compute_f1(fp, tp, fn) 110 | best_f1 = torch.max(f1_dynamic) 111 | except Exception as e: 112 | print(e) 113 | best_f1 = torch.tensor(np.nan) 114 | 115 | # Fixed threshold F1 116 | try: 117 | fp, tp, fn, tn = fixed_threshold_metrics(pred, gt, gt_dilated, gt_eroded, 0.5) 118 | f1_fixed = compute_f1(fp, tp, fn) 119 | except Exception as e: 120 | print(e) 121 | f1_fixed = torch.tensor(np.nan) 122 | 123 | return max(best_f1, f1_fixed), f1_fixed 124 | 125 | def compute_batch_localization_f1(preds_list, gts_list): 126 | assert len(preds_list) == len(gts_list), "Both lists must have the same length" 127 | 128 | batch_f1_scores_best = [] 129 | batch_f1_scores_fixed = [] 130 | for preds, gt in zip(preds_list, gts_list): 131 | best_f1, fixed_f1 = localization_f1(preds, gt) 132 | batch_f1_scores_best.append(best_f1.item()) 133 | batch_f1_scores_fixed.append(fixed_f1.item()) 134 | 
135 | 
136 | # Average Precision
137 | def compute_ap(pred_mask, gt_mask):
138 |     ap_cls = AveragePrecision(task="binary").to(pred_mask.device)
139 |     pred_mask = pred_mask.flatten()
140 |     # ground truth has to be of type int
141 |     gt_mask = gt_mask.flatten().int()
142 | 
143 |     ap_score = ap_cls(pred_mask, gt_mask)
144 |     return ap_score.item()
145 | 
146 | def compute_batch_ap(predictions, ground_truths):
147 |     assert len(predictions) == len(ground_truths), "Both lists must have the same length"
148 | 
149 |     ap_scores = []
150 |     for pred_mask, gt_mask in zip(predictions, ground_truths):
151 |         ap_score = compute_ap(pred_mask, gt_mask)
152 |         ap_scores.append(ap_score)
153 | 
154 |     return ap_scores
155 | 
156 | # Detection
157 | def compute_accuracy_detection(logits, labels, threshold=0.5):
158 |     predicted_classes = (logits >= threshold).float()
159 |     correct_predictions = (predicted_classes == labels).sum().item()
160 |     accuracy = correct_predictions / labels.size()[0]
161 |     return accuracy * 100
162 | 
163 | def compute_average_precision_detection(logits, labels):
164 |     probabilities_np = logits.detach().cpu().numpy()
165 |     labels_np = labels.detach().cpu().numpy()
166 |     ap = average_precision_score(labels_np, probabilities_np)
167 |     return ap
168 | 
169 | # Find best threshold for detection classification
170 | def find_best_threshold(y_true, y_pred):
171 |     "We assume the first half is real (0) and the second half is fake (1)"
172 | 
173 |     N = len(y_true)
174 | 
175 |     if torch.max(y_pred[0:N//2]) <= torch.min(y_pred[N//2:N]): # perfectly separable case
176 |         return (torch.max(y_pred[0:N//2]) + torch.min(y_pred[N//2:N])) / 2
177 | 
178 |     best_acc = 0
179 |     best_thres = 0
180 |     for thres in y_pred:
181 |         temp = copy.deepcopy(y_pred)
182 |         temp[temp>=thres] = 1
183 |         temp[temp<thres] = 0
184 | 
185 |         acc = (temp == y_true).sum() / N  # accuracy at this candidate threshold
186 |         if acc >= best_acc:
187 |             best_thres = thres
188 |             best_acc = acc
189 | 
190 |     return best_thres.item()
191 | 
192 | # Outputs saving
193 | def generate_outputs(output_save_path, pred, img_names=[]):
194 |     for i in range(len(pred)):
195 |         if isinstance(pred[i], torch.Tensor):
196 |             pred_i = pred[i].detach().cpu().numpy()
197 |         else:
198 |             pred_i = pred[i]
199 | 
200 |         pred_mask_array = pred_i * 255
201 |         pred_mask = Image.fromarray(pred_mask_array).convert(mode="L")
202 |         pred_mask.save(output_save_path + "/" + img_names[i].split("/")[-1])
--------------------------------------------------------------------------------
/validate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | import torch.utils.data
5 | import numpy as np
6 | from models import get_model
7 | from PIL import Image, ImageOps
8 | from dataset_paths import DETECTION_DATASET_PATHS, LOCALISATION_DATASET_PATHS
9 | import random
10 | import shutil
11 | from utils.utils import compute_batch_iou, compute_batch_localization_f1, compute_batch_ap, generate_outputs, find_best_threshold, compute_accuracy_detection, compute_average_precision_detection
12 | from data.datasets import RealFakeDataset, RealFakeDetectionDataset
13 | import torchvision
14 | from torchvision.transforms import functional as F
15 | from torchvision import transforms
16 | from options.test_options import TestOptions
17 | 
18 | 
19 | SEED = 0
20 | def set_seed():
21 |     torch.manual_seed(SEED)
22 |     torch.cuda.manual_seed(SEED)
23 |     np.random.seed(SEED)
24 |     random.seed(SEED)
25 | 
26 | 
27 | MEAN = {
28 |     "imagenet":[0.485, 0.456, 0.406],
29 |     "clip":[0.48145466, 0.4578275, 0.40821073]
30 | }
31 | STD = {
32 |     "imagenet":[0.229, 0.224, 0.225],
33 |     "clip":[0.26862954, 0.26130258, 0.27577711]
34 | }
35 | 
36 | def validate(model, loader):
37 | 
38 |     with torch.no_grad():
39 |         y_true, y_pred = [], []
40 |         all_img_paths = []
41 |         print ("Length of dataset: %d" %(len(loader.dataset)))
42 |         for img, label, img_names in loader:
43 |             in_tens = img.cuda()
44 |             outputs = torch.sigmoid(model(in_tens))
45 |             outputs = torch.mean(outputs, dim=1)  # average over the output dimension to get one score per image
46 | 
47 |             y_pred.extend(outputs)
48 |             y_true.extend(label)
49 |             all_img_paths.extend(img_names)
50 | 
51 |         y_pred = torch.stack(y_pred).to('cpu')
52 |         y_true = torch.stack(y_true).to('cpu')
53 | 
54 |         # Acc based on the best thres
55 |         best_thres = find_best_threshold(y_true, y_pred)
56 |         mean_acc_best_th = compute_accuracy_detection(y_pred, y_true, threshold = best_thres)
57 |         mean_acc = compute_accuracy_detection(y_pred, y_true)
58 |         mean_ap = compute_average_precision_detection(y_pred, y_true)
59 | 
60 |     return mean_ap, mean_acc, mean_acc_best_th, best_thres, all_img_paths
61 | 
62 | def validate_fully_supervised(model, loader, dataset_name, output_save_path = ''):
63 |     with torch.no_grad():
64 |         ious = []
65 |         f1_best = []
66 |         f1_fixed = []
67 |         all_img_paths = []
68 |         mean_ap = []
69 |         print ("Length of dataset: %d" %(len(loader.dataset)))
70 |         for _, data in enumerate(loader):
71 |             img, _, img_paths, masks_paths = data
72 | 
73 |             in_tens = img.cuda()
74 |             outputs = torch.sigmoid(model(in_tens).squeeze(1))
75 | 
76 |             if dataset_name in ["pluralistic", "lama", "repaint-p2-9k", "ldm", "ldm_clean", "ldm_real"]:
77 |                 masks = [ImageOps.invert(Image.open(mask_path).convert("L")) for mask_path in masks_paths]
78 |             else:
79 |                 masks = [Image.open(mask_path).convert("L") for mask_path in masks_paths]
80 | 
81 |             masks = [((transforms.ToTensor()(x).to(outputs.device)) > 0.5).float().squeeze() for x in masks]
82 | 
83 |             outputs = outputs.view(outputs.size(0), int(outputs.size(1)**0.5), int(outputs.size(1)**0.5))
84 |             resized_outputs = []
85 |             for i, output in enumerate(outputs):
86 |                 if output.size() != masks[i].size():
87 |                     output_resized = F.resize(output.unsqueeze(0), masks[i].size(), interpolation=torchvision.transforms.InterpolationMode.BILINEAR).squeeze(0)
88 |                     resized_outputs.append(output_resized)
89 |                 else:
90 |                     resized_outputs.append(output)
91 | 
92 |             batch_ious = compute_batch_iou(resized_outputs, masks, threshold = 0.5)
93 |             batch_F1_best, batch_F1_fixed = compute_batch_localization_f1(resized_outputs, masks)
94 |             batch_ap = compute_batch_ap(resized_outputs, masks)
95 | 
96 |             if output_save_path:
97 |                 generate_outputs(output_save_path + "/" + dataset_name, resized_outputs, img_paths)
98 | 
99 |             ious.extend(batch_ious)
100 |             f1_best.extend(batch_F1_best)
101 |             f1_fixed.extend(batch_F1_fixed)
102 |             all_img_paths.extend(img_paths)
103 |             mean_ap.extend(batch_ap)
104 | 
105 |     return ious, f1_best, f1_fixed, mean_ap, all_img_paths
106 | 
107 | def save_scores_to_file(ious, f1_best, f1_fixed, aps, img_paths, file_path):
108 |     with open(file_path + "/scores.txt", 'w') as file:
109 |         file.write(f'Image path \t iou \t f1_best \t f1_fixed \t ap\n')
110 |         for iou, f1_b, f1_f, ap, img_path in zip(ious, f1_best, f1_fixed, aps, img_paths):
111 |             file.write(f'{img_path} \t {iou} \t {f1_b} \t {f1_f} \t {ap}\n')
112 | 
113 | def save_scores_to_file_detection(aps, acc0s, acc1s, img_paths, file_path):
114 |     with open(file_path + "/scores.txt", 'w') as file:
115 |         file.write(f'Image path \t AP \t Acc_fixed \t Acc_best \n')
116 |         for ap, acc0, acc1, img_path in zip(aps, acc0s, acc1s, img_paths):
117 |             file.write(f'{img_path} \t {ap} \t {acc0} \t {acc1}\n')
118 | 
119 | if __name__ == '__main__':
120 |     opt = TestOptions().parse(print_options=False)
121 |     assert opt.ckpt is not None
122 | 
123 |     # Add test args
124 |     state_dict = torch.load(opt.ckpt, map_location='cpu')
125 |     try:
126 |         opt.feature_layer = state_dict['feature_layer']
127 |         opt.decoder_type = state_dict['decoder_type']
128 |     except KeyError:
129 |         print('No feature_layer or decoder_type in the checkpoint state_dict, using the info from the feature_layer and decoder_type args')
130 | 
131 |     # Load model
132 |     model = get_model(opt)
133 |     model.load_state_dict(state_dict['model'], strict=False)
134 |     print ("Model loaded..")
135 | 
136 |     model.eval()
137 |     model.cuda()
138 | 
139 |     if os.path.exists(opt.result_folder):
140 |         shutil.rmtree(opt.result_folder)
141 |     os.makedirs(opt.result_folder)
142 | 
143 |     if opt.fully_supervised:
144 |         dataset_paths = LOCALISATION_DATASET_PATHS
145 |         # Create and write the header of the results file
146 |         with open( os.path.join(opt.result_folder,'scores.txt'), 'a') as f:
147 |             f.write('dataset \t iou \t f1_best \t f1_fixed \t ap \n' )
148 |     else:
149 |         dataset_paths = DETECTION_DATASET_PATHS
150 |         with open( os.path.join(opt.result_folder,'scores.txt'), 'a') as f:
151 |             f.write('dataset \t AP \t Acc_fixed \t Acc_best \t Best_threshold \n' )
152 | 
153 | 
154 |     for dataset_path in (dataset_paths):
155 |         print(f"Testing on {dataset_path['key']}")
156 | 
157 |         set_seed()
158 |         opt.test_path = dataset_path['fake_path']
159 |         opt.test_masks_ground_truth_path = dataset_path['masks_path']
160 |         opt.train_dataset = dataset_path['key']
161 | 
162 |         if opt.fully_supervised:
163 |             dataset = RealFakeDataset(opt)
164 |         else:
165 |             opt.test_real_list_path = dataset_path['real_path']
166 |             dataset = RealFakeDetectionDataset(opt)
167 | 
168 |         loader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=False, num_workers=4)
169 | 
170 |         # Localisation
171 |         if opt.fully_supervised:
172 |             if opt.output_save_path:
173 |                 output_save_path = opt.output_save_path + "/" + dataset_path['key']
174 |                 if not os.path.exists(output_save_path):
175 |                     os.makedirs(output_save_path)
176 | 
177 |             ious, f1_best, f1_fixed, ap, original_img_paths = validate_fully_supervised(model, loader, dataset_path['key'], output_save_path = opt.output_save_path)
178 |             mean_iou = sum(ious)/len(ious)
179 |             mean_f1_best = sum(f1_best)/len(f1_best)
180 |             mean_f1_fixed = sum(f1_fixed)/len(f1_fixed)
181 |             mean_ap = sum(ap)/len(ap)
182 |             if opt.output_save_path:
183 |                 save_scores_to_file(ious, f1_best, f1_fixed, ap, original_img_paths, output_save_path)
184 | 
185 |             with open( os.path.join(opt.result_folder,'scores.txt'), 'a') as f:
186 |                 f.write(dataset_path['key']+': ' + str(round(mean_iou, 3))+ '\t' +\
187 |                         str(round(mean_f1_best, 4))+ '\t' +\
188 |                         str(round(mean_f1_fixed, 4))+ '\t' +\
189 |                         str(round(mean_ap, 4))+ '\t' +\
190 |                         '\n' )
191 |             print(dataset_path['key']+': IOU = ' + str(round(mean_iou, 3)))
192 |             print(dataset_path['key']+': F1_best = ' + str(round(mean_f1_best, 4)))
193 |             print(dataset_path['key']+': F1_fixed = ' + str(round(mean_f1_fixed, 4)))
194 |             print(dataset_path['key']+': AP = ' + str(round(mean_ap, 4)))
195 |             print()
196 | 
197 |         # Detection
198 |         else:
199 |             mean_ap, mean_acc, mean_acc_best_th, best_thres, all_img_paths = validate(model, loader)
200 | 
201 |             with open( os.path.join(opt.result_folder,'scores.txt'), 'a') as f:
202 |                 f.write(dataset_path['key']+': ' + str(round(mean_ap, 4))+ '\t' +\
203 |                         str(round(mean_acc, 4)) + '\t' +\
204 |                         str(round(mean_acc_best_th, 4)) + '\t' +\
205 |                         str(best_thres) + '\t' +\
206 |                         '\n' )
207 |             print(dataset_path['key']+': AP = ' + str(round(mean_ap, 4)))
208 |             print(dataset_path['key']+': Acc_fixed = ' + str(round(mean_acc, 4)))
209 |             print(dataset_path['key']+': Acc_best = ' + str(round(mean_acc_best_th, 4)))
210 |             print(dataset_path['key']+': Best_threshold = ' + str(round(best_thres, 4)))
211 |             print()
--------------------------------------------------------------------------------