├── demo
│   ├── __init__.py
│   ├── huggingface_gradio.py
│   └── gradio_invert.py
├── helpers
│   ├── __init__.py
│   ├── augmentations.py
│   └── utils.py
├── figures
│   ├── main.png
│   └── astronaut.png
├── app.py
├── README.md
├── requirements.txt
├── invert.py
└── .gitignore

/demo/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/helpers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/figures/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamidkazemi22/CLIPInversion/HEAD/figures/main.png
--------------------------------------------------------------------------------
/figures/astronaut.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamidkazemi22/CLIPInversion/HEAD/figures/astronaut.png
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | from demo.huggingface_gradio import app
2 |
3 | if __name__ == "__main__":
4 |     app.launch(show_api=False, debug=True, share=True, enable_queue=True)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [//]: # (# CLIPInversion)
2 | # What do we learn from inverting CLIP models?
3 | **Warning: This paper contains sexually explicit images and
4 | language, offensive visuals and terminology, discussions on
5 | pornography, gender bias, and other potentially unsettling,
6 | distressing, and/or offensive content for certain readers.**
7 |
8 | [Paper](https://arxiv.org/abs/2403.02580)
9 | ![Inverted Images](figures/main.png)
10 |
11 | **Installing requirements:**
12 |
13 |
14 | ```bash
15 | pip install -r requirements.txt
16 | ```
17 | **How to run:**
18 |
19 |
20 | ```bash
21 | python invert.py \
22 | --num_iters 3400 \ # Number of iterations during the inversion process.
23 | --prompt "The map of the African continent" \ # The text prompt to invert.
24 | --img_size 64 \ # Size of the image at iteration 0.
25 | --tv 0.005 \ # Total Variation weight.
26 | --batch_size 13 \ # How many augmentations to use at each iteration.
27 | --bri 0.4 \ # ColorJitter Augmentation brightness degree.
28 | --con 0.4 \ # ColorJitter Augmentation contrast degree.
29 | --sat 0.4 \ # ColorJitter Augmentation saturation degree.
30 | --save_every 100 \ # Frequency at which to save intermediate results.
31 | --print_every 100 \ # Frequency at which to print intermediate information.
32 | --model_name ViT-B/16 # ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16'] 33 | ``` 34 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiofiles==23.2.1 2 | aiohttp==3.9.3 3 | aiosignal==1.3.1 4 | altair==5.2.0 5 | annotated-types==0.6.0 6 | anyio==4.2.0 7 | async-timeout==4.0.3 8 | attrs==23.2.0 9 | certifi==2024.2.2 10 | charset-normalizer==3.3.2 11 | click==8.1.7 12 | git+https://github.com/openai/CLIP.git 13 | colorama==0.4.6 14 | contourpy==1.2.0 15 | cycler==0.12.1 16 | exceptiongroup==1.2.0 17 | fastapi==0.109.2 18 | ffmpy==0.3.1 19 | filelock==3.13.1 20 | fonttools==4.48.1 21 | frozenlist==1.4.1 22 | fsspec==2024.2.0 23 | ftfy==6.1.3 24 | gradio==3.47.1 25 | gradio-client==0.6.0 26 | h11==0.14.0 27 | httpcore==1.0.2 28 | httpx==0.26.0 29 | huggingface-hub==0.20.3 30 | idna==3.6 31 | importlib-resources==6.1.1 32 | Jinja2==3.1.3 33 | jsonschema==4.21.1 34 | jsonschema-specifications==2023.12.1 35 | kiwisolver==1.4.5 36 | kornia==0.7.1 37 | linkify-it-py==2.0.3 38 | markdown-it-py==2.2.0 39 | MarkupSafe==2.1.5 40 | matplotlib==3.8.2 41 | mdit-py-plugins==0.3.3 42 | mdurl==0.1.2 43 | mpmath==1.3.0 44 | multidict==6.0.5 45 | networkx==3.2.1 46 | numpy==1.26.4 47 | nvidia-cublas-cu12==12.1.3.1 48 | nvidia-cuda-cupti-cu12==12.1.105 49 | nvidia-cuda-nvrtc-cu12==12.1.105 50 | nvidia-cuda-runtime-cu12==12.1.105 51 | nvidia-cudnn-cu12==8.9.2.26 52 | nvidia-cufft-cu12==11.0.2.54 53 | nvidia-curand-cu12==10.3.2.106 54 | nvidia-cusolver-cu12==11.4.5.107 55 | nvidia-cusparse-cu12==12.1.0.106 56 | nvidia-nccl-cu12==2.19.3 57 | nvidia-nvjitlink-cu12==12.3.101 58 | nvidia-nvtx-cu12==12.1.105 59 | orjson==3.9.13 60 | packaging==23.2 61 | pandas==2.2.0 62 | pillow==10.2.0 63 | pydantic==2.6.1 64 | pydantic-core==2.16.2 65 | pydub==0.25.1 66 | pygments==2.17.2 67 | pyparsing==3.1.1 68 | python-dateutil==2.8.2 69 | python-multipart==0.0.7 70 | pytz==2024.1 71 | PyYAML==6.0.1 72 | referencing==0.33.0 73 | regex==2023.12.25 74 | requests==2.31.0 75 | rich==13.7.0 76 | rpds-py==0.17.1 77 | ruff==0.2.1 78 | semantic-version==2.10.0 79 | shellingham==1.5.4 80 | six==1.16.0 81 | sniffio==1.3.0 82 | starlette==0.36.3 83 | sympy==1.12 84 | tomlkit==0.12.0 85 | toolz==0.12.1 86 | torch==2.2.0 87 | torchvision==0.17.0 88 | tqdm==4.66.1 89 | triton==2.2.0 90 | typer==0.9.0 91 | typing-extensions==4.9.0 92 | tzdata==2023.4 93 | uc-micro-py==1.0.3 94 | urllib3==2.2.0 95 | uvicorn==0.27.0.post1 96 | wcwidth==0.2.13 97 | websockets==11.0.3 98 | yarl==1.9.4 99 | zipp==3.17.0 100 | -------------------------------------------------------------------------------- /helpers/augmentations.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | from torch import nn as nn 5 | 6 | 7 | class TotalVariation(nn.Module): 8 | def __init__(self, p: int = 2): 9 | super().__init__() 10 | self.p = p 11 | 12 | def forward(self, x: torch.tensor) -> torch.tensor: 13 | x_wise = x[:, :, :, 1:] - x[:, :, :, :-1] 14 | y_wise = x[:, :, 1:, :] - x[:, :, :-1, :] 15 | diag_1 = x[:, :, 1:, 1:] - x[:, :, :-1, :-1] 16 | diag_2 = x[:, :, 1:, :-1] - x[:, :, :-1, 1:] 17 | return x_wise.norm(p=self.p, dim=(2, 3)).mean() + y_wise.norm(p=self.p, dim=(2, 3)).mean() + \ 18 | diag_1.norm(p=self.p, dim=(2, 3)).mean() + diag_2.norm(p=self.p, dim=(2, 3)).mean() 19 | 20 | 21 | class Jitter(nn.Module): 22 | def __init__(self, lim: int = 
32): 23 | super().__init__() 24 | self.lim = lim 25 | 26 | def forward(self, x: torch.tensor) -> torch.tensor: 27 | off1 = random.randint(-self.lim, self.lim) 28 | off2 = random.randint(-self.lim, self.lim) 29 | return torch.roll(x, shifts=(off1, off2), dims=(2, 3)) 30 | 31 | 32 | class JitterBatch(nn.Module): 33 | def __init__(self, lim: int = 32): 34 | super().__init__() 35 | self.lim = lim 36 | 37 | def forward(self, x: torch.tensor) -> torch.tensor: 38 | b, c, h, w = x.shape 39 | out = torch.tensor([]).cuda() 40 | for i in range(b): 41 | off1 = random.randint(-self.lim, self.lim) 42 | off2 = random.randint(-self.lim, self.lim) 43 | out1 = torch.roll(x[i:i + 1], shifts=(off1, off2), dims=(-1, -2)) 44 | out = torch.cat((out, out1), dim=0) 45 | return out 46 | 47 | 48 | class RepeatBatch(nn.Module): 49 | def __init__(self, repeat: int = 32): 50 | super().__init__() 51 | self.size = repeat 52 | 53 | def forward(self, img: torch.tensor): 54 | return img.repeat(self.size, 1, 1, 1) 55 | 56 | 57 | class ColorJitter(nn.Module): 58 | def __init__(self, batch_size: int, shuffle_every: bool = False, mean: float = 1., std: float = 1.): 59 | super().__init__() 60 | self.batch_size, self.mean_p, self.std_p = batch_size, mean, std 61 | self.mean = self.std = None 62 | self.shuffle() 63 | self.shuffle_every = shuffle_every 64 | 65 | def shuffle(self): 66 | self.mean = (torch.rand((self.batch_size, 3, 1, 1,)).cuda() - 0.5) * 2 * self.mean_p 67 | self.std = ((torch.rand((self.batch_size, 3, 1, 1,)).cuda() - 0.5) * 2 * self.std_p).exp() 68 | 69 | def forward(self, img: torch.tensor) -> torch.tensor: 70 | if self.shuffle_every: 71 | self.shuffle() 72 | return (img - self.mean) / self.std 73 | -------------------------------------------------------------------------------- /demo/huggingface_gradio.py: -------------------------------------------------------------------------------- 1 | __all__ = ["app"] 2 | 3 | import gradio as gr 4 | 5 | from demo.gradio_invert import run 6 | from PIL import Image 7 | import torchvision.transforms as transforms 8 | 9 | to_pil = transforms.ToPILImage() 10 | 11 | 12 | def run_detector(input_str, tv): 13 | for image in run(input_str, tv): 14 | yield to_pil(image[0]) 15 | 16 | 17 | css = """ 18 | .green { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ccffcc; border-radius:0.5rem;} 19 | .red { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ffad99; border-radius:0.5rem;} 20 | .hyperlinks { 21 | display: flex; 22 | align-items: center; 23 | align-content: center; 24 | padding-top: 12px; 25 | justify-content: flex-end; 26 | margin: 0 10px; /* Adjust the margin as needed */ 27 | text-decoration: none; 28 | color: #000; /* Set the desired text color */ 29 | } 30 | """ 31 | 32 | # Most likely human generated, #most likely AI written 33 | 34 | prompt = '''An astronaut exploring an alien planet, discovering a mysterious ancient artifact" for different models.''' 35 | print(prompt) 36 | # default_image = Image.open('figures/astronaut.png') 37 | with gr.Blocks(css=css, 38 | theme=gr.themes.Default(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"])) as app: 39 | with gr.Row(): 40 | with gr.Column(scale=3): 41 | gr.HTML("
What do we learn from inverting CLIP models?") 42 | with gr.Column(scale=3): 43 | gr.HTML( 44 | "
This space may generate sexually explicit and NSFW (Not Safe For Work) images.") 45 | with gr.Column(scale=1): 46 | gr.HTML(""" 47 |
48 | <a href="https://arxiv.org/abs/2403.02580">paper</a> 49 | 50 | <a href="https://github.com/hamidkazemi22/CLIPInversion">code</a> 51 | 52 | contact 53 | """, elem_classes="hyperlinks") 54 | with gr.Row(): 55 | input_box = gr.Textbox(value=prompt, placeholder="Enter prompt here", lines=2, label="Prompt", ) 56 | with gr.Row(): 57 | tv_number = gr.Number(0.01, label='tv') 58 | submit_button = gr.Button("Run Inversion", variant="primary") 59 | clear_button = gr.ClearButton() 60 | with gr.Column(scale=3): 61 | gr.HTML("
Generated Image:") 62 | with gr.Column(scale=1): 63 | # output_text = gr.Textbox(label="Prediction", value="Most likely AI-Generated") 64 | # output_text = gr.Image(type='pil', show_label=False, width=224, height=224, value=default_image) 65 | output_text = gr.Image(type='pil', show_label=False, width=224, height=224) 66 | with gr.Row(): 67 | gr.HTML("
") 68 | with gr.Row(): 69 | gr.HTML("") 70 | with gr.Row(): 71 | gr.HTML("
") 72 | 73 | with gr.Accordion("Disclaimer", open=False): 74 | gr.Markdown( 75 | """ 76 | - `Warning` : 77 | - Some prompts lead to NSFW images. 78 | """ 79 | ) 80 | 81 | with gr.Accordion("Cite our work", open=False): 82 | gr.Markdown( 83 | """ 84 | ```bibtex 85 | @misc{hans2024spotting, 86 | title={Spotting LLMs With Binoculars: Zero-Shot Detection of Machine-Generated Text}, 87 | author={Abhimanyu Hans and Avi Schwarzschild and Valeriia Cherepanova and Hamid Kazemi and Aniruddha Saha and Micah Goldblum and Jonas Geiping and Tom Goldstein}, 88 | year={2024}, 89 | eprint={2401.12070}, 90 | archivePrefix={arXiv}, 91 | primaryClass={cs.CL} 92 | } 93 | """ 94 | ) 95 | 96 | submit_button.click(run_detector, inputs=[input_box, tv_number], outputs=output_text, show_progress='hidden') 97 | clear_button.click(lambda: ("", ""), outputs=[input_box, output_text]) 98 | -------------------------------------------------------------------------------- /invert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import clip 4 | import kornia.augmentation as kaugs 5 | import torch 6 | import torch.nn as nn 7 | import torchvision 8 | from helpers.augmentations import ColorJitter, RepeatBatch, Jitter, TotalVariation 9 | from helpers.utils import Normalization, Scale, freeze_module 10 | from torch.nn.utils import clip_grad_norm_ 11 | 12 | torch.autograd.set_detect_anomaly(True) 13 | 14 | parser = argparse.ArgumentParser(description='inverting clip!') 15 | parser.add_argument('--num_iters', default=3400, type=int) 16 | parser.add_argument('--save_every', default=100, type=int) 17 | parser.add_argument('--print_every', default=50, type=int) 18 | parser.add_argument('--batch_size', default=13, type=int) 19 | parser.add_argument('-p', '--prompt', action='append', type=str, default=[]) 20 | parser.add_argument('-e', '--extra_prompts', action='append', type=str, default=[]) 21 | parser.add_argument('--lr', default=0.1, type=float) 22 | parser.add_argument('--tv', default=0.005, type=float) 23 | parser.add_argument('--jitter', action='store_true') 24 | parser.add_argument('--color', action='store_true') 25 | parser.add_argument('--img_size', default=64, type=int) 26 | parser.add_argument('--eps', default=2 / 255) 27 | parser.add_argument('--optimizer', default='adam') 28 | parser.add_argument('--bri', type=float, default=0.4) 29 | parser.add_argument('--con', type=float, default=0.4) 30 | parser.add_argument('--sat', type=float, default=0.4) 31 | parser.add_argument('--l1', type=float, default=0.) 32 | parser.add_argument('--trial', type=int, default=1) 33 | parser.add_argument('--cg_std', type=float, default=0.) 34 | parser.add_argument('--cg_mean', type=float, default=0.) 
35 | parser.add_argument('--model_name', default='ViT-B/16') 36 | parser.add_argument('--prompt_id', type=int, default=0) 37 | 38 | args = parser.parse_args() 39 | args.prompt = ' '.join(args.prompt) 40 | print(f'prompt: <{args.prompt}>') 41 | print(f'extra prompts are: {args.extra_prompts}') 42 | device = "cuda" if torch.cuda.is_available() else "cpu" 43 | # ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16'] 44 | model_names = [args.model_name] 45 | models = [] 46 | for model_name in model_names: 47 | model, preprocess = clip.load(model_name, device) 48 | model = model.float() 49 | model = model.cuda() 50 | models.append(model) 51 | normalizer = Normalization([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711]).cuda() 52 | scale = Scale(224) 53 | 54 | prompts = [args.prompt] 55 | text_inputs = torch.cat([clip.tokenize(f"{c}") for c in prompts]).to(device) 56 | 57 | corrects, total = 0, 0 58 | for model in models: 59 | freeze_module(model) 60 | image = torch.rand((1, 3, args.img_size, args.img_size)).cuda() 61 | image.requires_grad_() 62 | 63 | 64 | def get_optimizer(image): 65 | if args.optimizer == 'adam': 66 | optimizer = torch.optim.Adam([image], lr=args.lr) 67 | else: 68 | optimizer = torch.optim.LBFGS([image], lr=args.lr) 69 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=2000) 70 | 71 | return optimizer, scheduler 72 | 73 | 74 | optimizer, scheduler = get_optimizer(image) 75 | 76 | criterion = nn.CrossEntropyLoss() 77 | text_features_map = {} 78 | for model in models: 79 | text_feature = model.encode_text(text_inputs) 80 | text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True) 81 | text_features_map[model] = text_feature 82 | 83 | save_path = f'images/{args.prompt}/{args.trial}/{args.lr}_{args.tv}_{args.cg_std}_{args.cg_mean}' 84 | os.makedirs(save_path, exist_ok=True) 85 | 86 | seq = [] 87 | if args.jitter: 88 | jitter = Jitter() 89 | seq.append(jitter) 90 | seq.append(RepeatBatch(args.batch_size)) 91 | pre_aug = nn.Sequential(*seq) 92 | aug = kaugs.AugmentationSequential( 93 | kaugs.RandomAffine(30, [0.1, 0.1], [0.7, 1.2], p=.5, padding_mode='border'), 94 | same_on_batch=False, 95 | ) 96 | tv_module = TotalVariation() 97 | 98 | color_jitter = ColorJitter(args.batch_size, True, mean=args.cg_mean, std=args.cg_std) 99 | targets = torch.tensor([0] * args.batch_size).cuda() 100 | 101 | 102 | def forward(image, model): 103 | image_input = pre_aug(image) 104 | image_input = aug(image_input) 105 | scale = Scale(model.visual.input_resolution) 106 | image_input = scale(image_input) 107 | image_input = color_jitter(image_input) 108 | image_input = normalizer(image_input) 109 | image_features = model.encode_image(image_input) 110 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 111 | l2_loss = torch.norm(image_features - text_features_map[model], dim=1) 112 | loss = torch.mean(l2_loss) 113 | return loss, l2_loss 114 | 115 | 116 | change_scale_schedule = [900, 1800] 117 | 118 | softmax = nn.Softmax(dim=1) 119 | for i in range(args.num_iters): 120 | max_grad_norm = 1. 
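# Coarse-to-fine optimization: at the iterations in change_scale_schedule (900 and 1800) the image is upsampled 2x, capped at CLIP's 224-px input resolution, and the optimizer and LR scheduler are rebuilt for the new tensor.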
121 | if i in change_scale_schedule: 122 | new_res = image.shape[2] * 2 123 | if args.jitter: 124 | jitter.lim = jitter.lim * 2 125 | if new_res >= 224: 126 | new_res = 224 127 | up_sample = Scale(new_res) 128 | image = up_sample(image.detach()) 129 | image.requires_grad_(True) 130 | optimizer, scheduler = get_optimizer(image) 131 | 132 | 133 | def closure(): 134 | optimizer.zero_grad() 135 | other_loss = tv_module(image) 136 | loss = args.tv * other_loss 137 | image_input = image 138 | l1_loss = torch.norm(image_input, p=1) 139 | loss = loss + args.l1 * l1_loss 140 | for model in models: 141 | xent_loss, scores = forward(image_input, model) 142 | loss = loss + xent_loss * (1 / len(models)) 143 | loss.backward() 144 | clip_grad_norm_([image], max_grad_norm) 145 | image.data = torch.clip(image.data, 0, 1) 146 | if i % args.print_every == 0: 147 | print(f'{i:04d}: loss is {loss:.4f}, xent: {xent_loss:.4f}, tv: {other_loss:.4f}, l1: {l1_loss:.4f}') 148 | if i % args.save_every == 0: 149 | path = os.path.join(save_path, f'{i}.png') 150 | torchvision.utils.save_image(image, path, normalize=True, scale_each=True) 151 | return loss 152 | 153 | 154 | optimizer.step(closure) 155 | if i >= 3400: 156 | scheduler.step() 157 | path = os.path.join(save_path, 'final.png') 158 | 159 | torchvision.utils.save_image(image, path, normalize=True, scale_each=True) 160 | -------------------------------------------------------------------------------- /demo/gradio_invert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import clip 4 | import kornia.augmentation as kaugs 5 | import torch 6 | import torch.nn as nn 7 | import torchvision 8 | from helpers.augmentations import ColorJitter, RepeatBatch, Jitter, TotalVariation 9 | from helpers.utils import Normalization, Scale, freeze_module 10 | from torch.nn.utils import clip_grad_norm_ 11 | 12 | torch.autograd.set_detect_anomaly(True) 13 | 14 | parser = argparse.ArgumentParser(description='inverting clip!') 15 | parser.add_argument('--num_iters', default=3400, type=int) 16 | parser.add_argument('--save_every', default=200, type=int) 17 | parser.add_argument('--print_every', default=1, type=int) 18 | parser.add_argument('--batch_size', default=1, type=int) 19 | parser.add_argument('-p', '--prompt', action='append', type=str, default=[]) 20 | parser.add_argument('-e', '--extra_prompts', action='append', type=str, default=[]) 21 | parser.add_argument('--lr', default=0.1, type=float) 22 | parser.add_argument('--tv', default=0.005, type=float) 23 | parser.add_argument('--jitter', action='store_true') 24 | parser.add_argument('--color', action='store_true') 25 | parser.add_argument('--img_size', default=64, type=int) 26 | parser.add_argument('--eps', default=2 / 255) 27 | parser.add_argument('--optimizer', default='adam') 28 | parser.add_argument('--bri', type=float, default=0.1) 29 | parser.add_argument('--con', type=float, default=0.1) 30 | parser.add_argument('--sat', type=float, default=0.1) 31 | parser.add_argument('--l1', type=float, default=0.) 32 | parser.add_argument('--trial', type=int, default=1) 33 | parser.add_argument('--cg_std', type=float, default=0.) 34 | parser.add_argument('--cg_mean', type=float, default=0.) 
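# Demo defaults are lighter than invert.py (batch_size 1, color-jitter strengths 0.1); run() below overrides the prompt and the TV weight with the values coming from the Gradio UI.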
35 | parser.add_argument('--model_name', default='ViT-B/16') 36 | parser.add_argument('--prompt_id', type=int, default=0) 37 | parser.add_argument('--add_noise', type=int, default=1) 38 | args = parser.parse_args() 39 | args.prompt = ' '.join(args.prompt) 40 | print(f'prompt: <{args.prompt}>') 41 | print(f'extra prompts are: {args.extra_prompts}') 42 | device = "cuda" if torch.cuda.is_available() else "cpu" 43 | model_names = [args.model_name] 44 | models = [] 45 | for model_name in model_names: 46 | model, preprocess = clip.load(model_name, device) 47 | model.eval() 48 | model = model.float() 49 | model = model.to(device) 50 | models.append(model) 51 | 52 | normalizer = Normalization([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711]).to(device) 53 | 54 | 55 | def run(prompt, tv): 56 | prompts = [prompt] 57 | args.tv = tv 58 | text_inputs = torch.cat([clip.tokenize(f"{c}") for c in prompts]).to(device) 59 | 60 | for model in models: 61 | freeze_module(model) 62 | image = torch.rand((1, 3, args.img_size, args.img_size)).to(device) 63 | image.requires_grad_() 64 | 65 | def get_optimizer(image): 66 | if args.optimizer == 'adam': 67 | optimizer = torch.optim.Adam([image], lr=args.lr) 68 | else: 69 | optimizer = torch.optim.LBFGS([image], lr=args.lr) 70 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=2000) 71 | 72 | return optimizer, scheduler 73 | 74 | optimizer, scheduler = get_optimizer(image) 75 | 76 | text_features_map = {} 77 | for model in models: 78 | text_feature = model.encode_text(text_inputs) 79 | text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True) 80 | text_features_map[model] = text_feature 81 | 82 | save_path = f'images/gradio/{args.prompt}/{args.trial}/{args.lr}_{args.tv}_{args.cg_std}_{args.cg_mean}' 83 | os.makedirs(save_path, exist_ok=True) 84 | 85 | seq = [] 86 | if args.jitter: 87 | jitter = Jitter() 88 | seq.append(jitter) 89 | seq.append(RepeatBatch(args.batch_size)) 90 | pre_aug = nn.Sequential(*seq) 91 | aug = kaugs.AugmentationSequential( 92 | kaugs.ColorJitter(args.bri, args.con, args.sat, 0.1, p=1.0), 93 | kaugs.RandomAffine(30, [0.1, 0.1], [0.7, 1.2], p=.5, padding_mode='border'), 94 | same_on_batch=False, 95 | ) 96 | tv_module = TotalVariation() 97 | 98 | color_jitter = ColorJitter(args.batch_size, True, mean=args.cg_mean, std=args.cg_std) 99 | 100 | def forward(image, model): 101 | image_input = pre_aug(image) 102 | image_input = aug(image_input) 103 | scale = Scale(model.visual.input_resolution) 104 | image_input = scale(image_input) 105 | image_input = color_jitter(image_input) 106 | epsilon = torch.rand_like(image_input) * 0.007 107 | image_input = image_input + epsilon 108 | image_input = normalizer(image_input) 109 | image_features = model.encode_image(image_input) 110 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 111 | l2_loss = torch.norm(image_features - text_features_map[model], dim=1) 112 | loss = torch.mean(l2_loss) 113 | return loss, l2_loss 114 | 115 | change_scale_schedule = [900, 1800] 116 | 117 | for i in range(args.num_iters): 118 | max_grad_norm = 1. 
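# Same coarse-to-fine schedule as invert.py; in addition, the current image is yielded every iteration so the Gradio front end can stream intermediate results, and forward() perturbs the input with small uniform noise (up to 0.007) before normalization.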
119 | if i in change_scale_schedule: 120 | new_res = image.shape[2] * 2 121 | if args.jitter: 122 | jitter.lim = jitter.lim * 2 123 | if new_res >= 224: 124 | new_res = 224 125 | up_sample = Scale(new_res) 126 | image = up_sample(image.detach()) 127 | image.requires_grad_(True) 128 | optimizer, scheduler = get_optimizer(image) 129 | yield image 130 | 131 | def closure(): 132 | optimizer.zero_grad() 133 | other_loss = tv_module(image) 134 | loss = args.tv * other_loss 135 | image_input = image 136 | l1_loss = torch.norm(image_input, p=1) 137 | loss = loss + args.l1 * l1_loss 138 | for model in models: 139 | xent_loss, scores = forward(image_input, model) 140 | loss = loss + xent_loss * (1 / len(models)) 141 | loss.backward() 142 | clip_grad_norm_([image], max_grad_norm) 143 | image.data = torch.clip(image.data, 0, 1) 144 | if i % args.print_every == 0: 145 | print(f'{i:04d}: loss is {loss:.4f}, xent: {xent_loss:.4f}, tv: {other_loss:.4f}, l1: {l1_loss:.4f}') 146 | return loss 147 | 148 | optimizer.step(closure) 149 | if i >= 3400: 150 | scheduler.step() 151 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/pycharm,python 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=pycharm,python 3 | 4 | ### PyCharm ### 5 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 6 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 7 | 8 | # User-specific stuff 9 | .idea/**/workspace.xml 10 | .idea/**/tasks.xml 11 | .idea/**/usage.statistics.xml 12 | .idea/**/dictionaries 13 | .idea/**/shelf 14 | 15 | # AWS User-specific 16 | .idea/**/aws.xml 17 | 18 | # Generated files 19 | .idea/**/contentModel.xml 20 | 21 | # Sensitive or high-churn files 22 | .idea/**/dataSources/ 23 | .idea/**/dataSources.ids 24 | .idea/**/dataSources.local.xml 25 | .idea/**/sqlDataSources.xml 26 | .idea/**/dynamic.xml 27 | .idea/**/uiDesigner.xml 28 | .idea/**/dbnavigator.xml 29 | 30 | # Gradle 31 | .idea/**/gradle.xml 32 | .idea/**/libraries 33 | 34 | # Gradle and Maven with auto-import 35 | # When using Gradle or Maven with auto-import, you should exclude module files, 36 | # since they will be recreated, and may cause churn. Uncomment if using 37 | # auto-import. 
38 | # .idea/artifacts 39 | # .idea/compiler.xml 40 | # .idea/jarRepositories.xml 41 | # .idea/modules.xml 42 | # .idea/*.iml 43 | # .idea/modules 44 | # *.iml 45 | # *.ipr 46 | 47 | # CMake 48 | cmake-build-*/ 49 | 50 | # Mongo Explorer plugin 51 | .idea/**/mongoSettings.xml 52 | 53 | # File-based project format 54 | *.iws 55 | 56 | # IntelliJ 57 | out/ 58 | 59 | # mpeltonen/sbt-idea plugin 60 | .idea_modules/ 61 | 62 | # JIRA plugin 63 | atlassian-ide-plugin.xml 64 | 65 | # Cursive Clojure plugin 66 | .idea/replstate.xml 67 | 68 | # SonarLint plugin 69 | .idea/sonarlint/ 70 | 71 | # Crashlytics plugin (for Android Studio and IntelliJ) 72 | com_crashlytics_export_strings.xml 73 | crashlytics.properties 74 | crashlytics-build.properties 75 | fabric.properties 76 | 77 | # Editor-based Rest Client 78 | .idea/httpRequests 79 | 80 | # Android studio 3.1+ serialized cache file 81 | .idea/caches/build_file_checksums.ser 82 | 83 | ### PyCharm Patch ### 84 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 85 | 86 | # *.iml 87 | # modules.xml 88 | # .idea/misc.xml 89 | # *.ipr 90 | 91 | # Sonarlint plugin 92 | # https://plugins.jetbrains.com/plugin/7973-sonarlint 93 | .idea/**/sonarlint/ 94 | 95 | # SonarQube Plugin 96 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 97 | .idea/**/sonarIssues.xml 98 | 99 | # Markdown Navigator plugin 100 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 101 | .idea/**/markdown-navigator.xml 102 | .idea/**/markdown-navigator-enh.xml 103 | .idea/**/markdown-navigator/ 104 | 105 | # Cache file creation bug 106 | # See https://youtrack.jetbrains.com/issue/JBR-2257 107 | .idea/$CACHE_FILE$ 108 | 109 | # CodeStream plugin 110 | # https://plugins.jetbrains.com/plugin/12206-codestream 111 | .idea/codestream.xml 112 | 113 | # Azure Toolkit for IntelliJ plugin 114 | # https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij 115 | .idea/**/azureSettings.xml 116 | 117 | ### Python ### 118 | # Byte-compiled / optimized / DLL files 119 | __pycache__/ 120 | *.py[cod] 121 | *$py.class 122 | 123 | # C extensions 124 | *.so 125 | 126 | # Distribution / packaging 127 | .Python 128 | build/ 129 | develop-eggs/ 130 | dist/ 131 | downloads/ 132 | eggs/ 133 | .eggs/ 134 | lib/ 135 | lib64/ 136 | parts/ 137 | sdist/ 138 | var/ 139 | wheels/ 140 | share/python-wheels/ 141 | *.egg-info/ 142 | .installed.cfg 143 | *.egg 144 | MANIFEST 145 | 146 | # PyInstaller 147 | # Usually these files are written by a python script from a template 148 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
149 | *.manifest 150 | *.spec 151 | 152 | # Installer logs 153 | pip-log.txt 154 | pip-delete-this-directory.txt 155 | 156 | # Unit test / coverage reports 157 | htmlcov/ 158 | .tox/ 159 | .nox/ 160 | .coverage 161 | .coverage.* 162 | .cache 163 | nosetests.xml 164 | coverage.xml 165 | *.cover 166 | *.py,cover 167 | .hypothesis/ 168 | .pytest_cache/ 169 | cover/ 170 | 171 | # Translations 172 | *.mo 173 | *.pot 174 | 175 | # Django stuff: 176 | *.log 177 | local_settings.py 178 | db.sqlite3 179 | db.sqlite3-journal 180 | 181 | # Flask stuff: 182 | instance/ 183 | .webassets-cache 184 | 185 | # Scrapy stuff: 186 | .scrapy 187 | 188 | # Sphinx documentation 189 | docs/_build/ 190 | 191 | # PyBuilder 192 | .pybuilder/ 193 | target/ 194 | 195 | # Jupyter Notebook 196 | .ipynb_checkpoints 197 | 198 | # IPython 199 | profile_default/ 200 | ipython_config.py 201 | 202 | # pyenv 203 | # For a library or package, you might want to ignore these files since the code is 204 | # intended to run in multiple environments; otherwise, check them in: 205 | # .python-version 206 | 207 | # pipenv 208 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 209 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 210 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 211 | # install all needed dependencies. 212 | #Pipfile.lock 213 | 214 | # poetry 215 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 216 | # This is especially recommended for binary packages to ensure reproducibility, and is more 217 | # commonly ignored for libraries. 218 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 219 | #poetry.lock 220 | 221 | # pdm 222 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 223 | #pdm.lock 224 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 225 | # in version control. 226 | # https://pdm.fming.dev/#use-with-ide 227 | .pdm.toml 228 | 229 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 230 | __pypackages__/ 231 | 232 | # Celery stuff 233 | celerybeat-schedule 234 | celerybeat.pid 235 | 236 | # SageMath parsed files 237 | *.sage.py 238 | 239 | # Environments 240 | .env 241 | .venv 242 | env/ 243 | venv/ 244 | ENV/ 245 | env.bak/ 246 | venv.bak/ 247 | 248 | # Spyder project settings 249 | .spyderproject 250 | .spyproject 251 | 252 | # Rope project settings 253 | .ropeproject 254 | 255 | # mkdocs documentation 256 | /site 257 | 258 | # mypy 259 | .mypy_cache/ 260 | .dmypy.json 261 | dmypy.json 262 | 263 | # Pyre type checker 264 | .pyre/ 265 | 266 | # pytype static type analyzer 267 | .pytype/ 268 | 269 | # Cython debug symbols 270 | cython_debug/ 271 | 272 | # PyCharm 273 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 274 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 275 | # and can be added to the global gitignore or merged into this file. For a more nuclear 276 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
277 | #.idea/ 278 | 279 | ### Python Patch ### 280 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 281 | poetry.toml 282 | 283 | # ruff 284 | .ruff_cache/ 285 | 286 | # LSP config files 287 | pyrightconfig.json 288 | 289 | # End of https://www.toptal.com/developers/gitignore/api/pycharm,python 290 | 291 | 292 | images/* 293 | .idea/* -------------------------------------------------------------------------------- /helpers/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import random 4 | import sys 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.utils.data 10 | import torchvision.datasets as datasets 11 | import torchvision.transforms as transforms 12 | import torch.nn.functional as F 13 | 14 | 15 | def fix_seed(seed): 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed(seed) 18 | np.random.seed(seed) 19 | random.seed(seed) 20 | 21 | 22 | def get_loaders(batch_size=256, n_workers=4, dataset_name='cifar10', return_dataset=False): 23 | train_transform = transforms.Compose([ 24 | transforms.RandomCrop(size=32, padding=4), 25 | transforms.RandomHorizontalFlip(), 26 | transforms.ToTensor(), 27 | ]) 28 | test_transform = transforms.Compose([ 29 | transforms.ToTensor(), 30 | ]) 31 | dataset = datasets.CIFAR10 if dataset_name == 'cifar10' else datasets.CIFAR100 32 | train_dataset = dataset(f'data/datasets/{dataset_name}', download=True, 33 | transform=train_transform) 34 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, 35 | shuffle=True, num_workers=n_workers) 36 | test_dataset = dataset(f'data/datasets/{dataset_name}', download=True, train=False, 37 | transform=test_transform) 38 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, 39 | shuffle=False, num_workers=n_workers) 40 | if return_dataset: 41 | return train_loader, test_loader, train_dataset, test_dataset 42 | return train_loader, test_loader 43 | 44 | 45 | def get_imagenet(batch_size=256, n_workers=4, path='data/datasets/ILSVRC2012/{}', shuffle=True): 46 | train_transforms = transforms.Compose( 47 | [transforms.RandomResizedCrop(224), 48 | transforms.RandomHorizontalFlip(), 49 | transforms.ToTensor(), ]) 50 | 51 | eval_transforms = transforms.Compose( 52 | [transforms.Resize(256), 53 | transforms.CenterCrop(224), 54 | transforms.ToTensor(), ]) 55 | 56 | train_dataset = datasets.ImageFolder(root=path.format('train'), 57 | transform=train_transforms) 58 | test_dataset = datasets.ImageFolder(root=path.format('val'), 59 | transform=eval_transforms) 60 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, 61 | num_workers=n_workers, shuffle=True, pin_memory=True) 62 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, 63 | num_workers=n_workers, shuffle=True, pin_memory=True) 64 | return train_loader, test_loader 65 | 66 | 67 | class Normalization(nn.Module): 68 | def __init__(self, mean, std): 69 | super(Normalization, self).__init__() 70 | self.register_buffer('mean', torch.tensor(mean).view(-1, 1, 1)) 71 | self.register_buffer('std', torch.tensor(std).view(-1, 1, 1)) 72 | 73 | def forward(self, img): 74 | return (img - self.mean) / self.std 75 | 76 | 77 | class GuassianNoise(nn.Module): 78 | def __init__(self, mean=0., std=1.): 79 | super(GuassianNoise, self).__init__() 80 | self.register_buffer('mean', mean) 81 | self.register_buffer('std', std) 82 
| 83 | def forward(self, img): 84 | out = img + torch.randn(img.size()) * self.std + self.mean 85 | out = torch.clamp(out, 0., 1.) 86 | return out 87 | 88 | 89 | def train_step(loader, model_md, loss_fn, opt, epoch_n, scheduler=None, normal_fn=None, 90 | modify_fn=None, file=None): 91 | model_md.train() 92 | running_loss = 0.0 93 | running_corrects = 0 94 | total = 0 95 | for i, (image, label) in enumerate(loader): 96 | image = image.cuda() 97 | label = label.cuda() 98 | opt.zero_grad() 99 | image = modify_fn(image, label) if modify_fn else image 100 | image = normal_fn(image) if normal_fn else image 101 | output = model_md(image) 102 | preds = torch.argmax(output, -1) 103 | loss = loss_fn(output, label) 104 | loss.backward() 105 | opt.step() 106 | running_loss += loss.item() * image.shape[0] 107 | running_corrects += torch.sum(preds == label) 108 | total += image.shape[0] 109 | epoch_loss = running_loss / total 110 | epoch_acc = running_corrects.double() / total 111 | end = '\n' if i == (len(loader) - 1) else '\r' 112 | print(f'epoch: {epoch_n:04d}, Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}, {i + 1:04d}/{len(loader)}', 113 | end=end) 114 | scheduler.step() if scheduler else None 115 | 116 | 117 | def test_step(test_loader, model, loss_fn, normal_fn=None, modify_fn=None): 118 | model.eval() 119 | running_loss = 0.0 120 | running_corrects = 0 121 | total = 0 122 | for i, (image, label) in enumerate(test_loader): 123 | image = image.cuda() 124 | label = label.cuda() 125 | image = modify_fn(image, label) if modify_fn else image 126 | image = normal_fn(image) if normal_fn else image 127 | with torch.no_grad(): 128 | output = model(image) 129 | loss = loss_fn(output, label) 130 | preds = torch.argmax(output, 1) 131 | running_loss += loss.item() * image.shape[0] 132 | running_corrects += torch.sum(preds == label.data) 133 | total += image.shape[0] 134 | end = '\n' if i == (len(test_loader) - 1) else '\r' 135 | 136 | loss = running_loss / total 137 | accuracy = running_corrects.double() / total 138 | print(( 139 | f'Test Loss: {loss:.4f} Test Acc: {accuracy:.4f}, {i + 1:02d}/{len(test_loader)}'), 140 | end=end) 141 | accuracy = running_corrects.double() / total 142 | return accuracy 143 | 144 | 145 | def adv_test_step(test_loader, model, loss_fn, revertor=None, normal_fn=None, 146 | modify_fn=None): 147 | model.eval() 148 | running_loss = 0.0 149 | running_corrects = 0 150 | total = 0 151 | for i, (image, label) in enumerate(test_loader): 152 | image = image.cuda() 153 | label = label.cuda() 154 | # if i == 0: 155 | # plt.figure() 156 | # im = image[0].detach().cpu().numpy() 157 | # plt.imshow(np.moveaxis(im, 0, -1)) 158 | # plt.savefig('images/asli.png') 159 | image = modify_fn(image, label) if modify_fn else image 160 | if revertor is not None: 161 | image = image + revertor[label] 162 | # import matplotlib.pyplot as plt 163 | # if i == 0: 164 | # plt.figure() 165 | # im = image[0].detach().cpu().numpy() 166 | # plt.imshow(np.moveaxis(im, 0, -1)) 167 | # plt.savefig('images/adv.png') 168 | image = normal_fn(image) if normal_fn else image 169 | with torch.no_grad(): 170 | output = model(image) 171 | loss = loss_fn(output, label) 172 | preds = torch.argmax(output, 1) 173 | running_loss += loss.item() * image.shape[0] 174 | running_corrects += torch.sum(preds == label.data) 175 | total += image.shape[0] 176 | end = '\n' if i == (len(test_loader) - 1) else '\r' 177 | 178 | loss = running_loss / total 179 | accuracy = running_corrects.double() / total 180 | print(( 181 | f'Test Loss: {loss:.4f} 
Test Acc: {accuracy:.4f}, {i + 1:02d}/{len(test_loader)}'), 182 | end=end) 183 | accuracy = running_corrects.double() / total 184 | return accuracy 185 | 186 | 187 | def freeze_module(module: nn.Module, reverse=False): 188 | for param in module.parameters(): 189 | param.requires_grad = reverse 190 | 191 | 192 | def cross_entropy(pred, soft_targets): 193 | logsoftmax = nn.LogSoftmax(dim=1) 194 | return torch.mean(torch.sum(- soft_targets * logsoftmax(pred), 1)) 195 | 196 | 197 | def get_trainable_params(module: nn.Module): 198 | trainable_params = filter(lambda p: p.requires_grad, module.parameters()) 199 | return trainable_params 200 | 201 | 202 | def get_optimizer(lr, model=None, params=None): 203 | trainable_params = params if params else get_trainable_params(model) 204 | optimizer = torch.optim.SGD(trainable_params, 205 | lr=lr, momentum=0.9, 206 | dampening=0, weight_decay=1e-4, 207 | nesterov=True) 208 | return optimizer 209 | 210 | 211 | def params_num(module: nn.Module): 212 | return len(list(get_trainable_params(module))) 213 | 214 | 215 | def make_pgd(model: nn.Module, image: torch.Tensor, normal_fn, loss_fn, label, eps, step_size=2 / 255, 216 | iters=10): 217 | # copy_model = copy.deepcopy(model) 218 | copy_model = model 219 | copy_model.eval() 220 | copy_image = image.detach().clone() 221 | # freeze_module(copy_model) 222 | copy_image.requires_grad = True 223 | for step in range(iters): 224 | output = normal_fn(copy_image) 225 | output = copy_model(output) 226 | loss = loss_fn(output, label) 227 | loss.backward() 228 | adv_image = copy_image + step_size * copy_image.grad.sign() 229 | perturb = torch.clamp(adv_image - image, -eps, +eps) 230 | copy_image.data = torch.clamp(image.data + perturb.data, 0, 1) 231 | 232 | # del copy_model 233 | return copy_image 234 | 235 | 236 | def make_pgd_v2(model: nn.Module, image: torch.Tensor, normal_fn, loss_fn, label, eps, step_size=2 / 255, 237 | iters=10): 238 | model.eval() 239 | # freeze_module(model) 240 | copy_image = image.detach().clone() 241 | copy_image.requires_grad = True 242 | for step in range(iters): 243 | output = normal_fn(copy_image) 244 | output = model(output) 245 | loss = loss_fn(output, label) 246 | loss.backward() 247 | adv_image = copy_image + step_size * copy_image.grad.sign() 248 | perturb = torch.clamp(adv_image - image, -eps, +eps) 249 | copy_image.data = torch.clamp(image.data + perturb.data, 0, 1) 250 | return copy_image 251 | 252 | 253 | def make_target_pgd(model: nn.Module, image: torch.Tensor, normal_fn, loss_fn, 254 | target_label, eps, iters=10): 255 | copy_model = copy.deepcopy(model) 256 | copy_image = image.clone().detach() 257 | freeze_module(copy_model) 258 | for step in range(iters): 259 | copy_image.requires_grad = True 260 | output = normal_fn(copy_image) 261 | output = copy_model(output) 262 | loss = loss_fn(output, target_label) 263 | loss.backward() 264 | adv_image = copy_image - eps * copy_image.grad.sign() 265 | perturb = torch.clamp(adv_image - image, -eps, +eps) 266 | copy_image = image + perturb 267 | copy_image.detach_() 268 | copy_image.clamp_(0, 1) 269 | 270 | del copy_model 271 | return copy_image 272 | 273 | 274 | def make_adv(model: nn.Module, image: torch.Tensor, normal_fn, loss_fn, label, eps, 275 | lr=0.1): 276 | copy_model = copy.deepcopy(model) 277 | copy_image = image.clone().detach() 278 | freeze_module(copy_model) 279 | for step in range(10): 280 | copy_image.requires_grad = True 281 | output = normal_fn(copy_image) 282 | output = copy_model(output) 283 | loss = loss_fn(output, 
label) 284 | loss.backward() 285 | adv_image = copy_image + lr * copy_image.grad 286 | perturb = torch.clamp(adv_image - image, -eps, +eps) 287 | copy_image = image + perturb 288 | copy_image.detach_() 289 | copy_image.clamp_(0, 1) 290 | 291 | del copy_model 292 | return copy_image 293 | 294 | 295 | def create_file(name): 296 | try: 297 | os.makedirs(name) 298 | except: 299 | pass 300 | 301 | 302 | def get_params(model): 303 | num_params = sum(p.numel() for p in model.parameters()) 304 | return num_params 305 | 306 | 307 | def set_lr(optimizer, lr): 308 | for g in optimizer.param_groups: 309 | g['lr'] = lr 310 | 311 | 312 | class LogPrint: 313 | def __init__(self, path): 314 | self.path = path 315 | self.file = open(path, 'w') 316 | 317 | def log(self, s): 318 | sys.stdout.write(s) 319 | self.file.write(s) 320 | 321 | def close(self): 322 | self.file.close() 323 | 324 | def flush(self): 325 | self.file.flush() 326 | sys.stdout.flush() 327 | 328 | 329 | def make_image(image: torch.Tensor): 330 | batch_size, c, h, w = image.shape 331 | flattened = image.view(batch_size, -1) 332 | batch_min, batch_max = torch.min(flattened, 1, keepdim=True)[0], torch.max(flattened, 1, keepdim=True)[0] 333 | flattened -= batch_min 334 | flattened /= torch.clamp(batch_max - batch_min, min=1e-5) 335 | return flattened.view(batch_size, c, h, w) 336 | 337 | 338 | def gray_scale(image): 339 | return torch.mean(image, dim=1, keepdim=True) 340 | 341 | 342 | def make_resnet_sequential(model: nn.Module, normalization=None): 343 | modules = [] 344 | if normalization is not None: 345 | modules.append(normalization) 346 | 347 | for name, module in model.named_children(): 348 | modules.append(module) 349 | modules.insert(-1, nn.Flatten()) 350 | out = nn.Sequential(*modules) 351 | return out 352 | 353 | 354 | def make_resnet_complete_sequential(base, normalization=None): 355 | modules = [] 356 | if normalization is not None: 357 | modules.append(modules) 358 | for name, module in base.named_children(): 359 | if isinstance(module, nn.Sequential): 360 | modules.extend(module) 361 | else: 362 | modules.append(module) 363 | modules.insert(-1, nn.Flatten()) 364 | model = nn.Sequential(*modules) 365 | return model 366 | 367 | 368 | def zero_grad(image): 369 | if image.grad is not None: 370 | if image.grad.grad_fn is not None: 371 | image.grad.detach_() 372 | else: 373 | image.grad.requires_grad_(False) 374 | image.grad.data.zero_() 375 | 376 | 377 | class Logger: 378 | def __init__(self, out_dir, log_name, resume=False): 379 | if not os.path.isdir(out_dir): 380 | os.makedirs(out_dir) 381 | self.fpath = os.path.join(out_dir, log_name) 382 | if not resume: 383 | with open(self.fpath, "w") as f: 384 | f.truncate() 385 | 386 | def log(self, content, end='\n'): 387 | with open(self.fpath, "a") as f: 388 | f.write(f'{str(content)}{end}') 389 | 390 | 391 | def get_ds_info(dataset): 392 | if dataset == 'cifar10': 393 | return { 394 | 'mean': [0.4914, 0.4822, 0.4465], 395 | 'std': [0.2023, 0.1994, 0.2010], 396 | 'classes': ('plane', 'car', 'bird', 'cat', 'deer', 397 | 'dog', 'frog', 'horse', 'ship', 'truck') 398 | } 399 | raise Exception('the dataset info was not found.') 400 | 401 | 402 | def print_message(message, it, last_it): 403 | end = '\r' 404 | if it == last_it: 405 | end = '\n' 406 | print(message, end=end) 407 | 408 | 409 | class Scale(nn.Module): 410 | def __init__(self, size, mode='bicubic'): 411 | super(Scale, self).__init__() 412 | self.mode = mode 413 | self.size = size 414 | 415 | def forward(self, x): 416 | return 
F.interpolate(x, size=(self.size, self.size), mode=self.mode) 417 | 418 | 419 | --------------------------------------------------------------------------------