├── config
│   ├── cifar-10c.yaml
│   └── imagenet-c.yaml
├── conda_env.yml
├── utils.py
├── dataset.py
├── README.md
└── main.py

/config/cifar-10c.yaml:
--------------------------------------------------------------------------------
# @package _global_

defaults:
  - _self_
  - user: username

wandb:
  project: surgical-finetuning
  exp_name: none
  sweep_filename: none
  use: false

data:
  dataset_name: cifar10
  model_name: Standard
  corruption_types: [brightness]
  severity: 5
  batch_size: 64
  num_workers: 2

args:
  train_mode: eval
  tune_option: first_two
  train_n: 1000
  epochs: 20
  seed: 0
  log_dir: cifar
  auto_tune: none

hydra:
  output_subdir: hydra
  run:
    dir: ./results/${data.dataset_name}/${now:%Y.%m.%d}
--------------------------------------------------------------------------------
/config/imagenet-c.yaml:
--------------------------------------------------------------------------------
# @package _global_

defaults:
  - _self_
  - user: annie

wandb:
  project: surgical-finetuning
  exp_name: none
  sweep_filename: none
  use: false

data:
  dataset_name: imagenet-c
  model_name: Standard_R50
  corruption_types: [brightness]
  severity: 5
  batch_size: 64
  num_workers: 2

args:
  train_mode: eval
  tune_option: first_two
  train_n: 1000
  epochs: 10
  seed: 0
  log_dir: imagenet
  auto_tune: none

hydra:
  output_subdir: hydra
  run:
    dir: ./results/${data.dataset_name}/${data.corruption_types[0]}
--------------------------------------------------------------------------------
/conda_env.yml:
--------------------------------------------------------------------------------
name: fl
channels:
  - defaults
dependencies:
  - python=3.8
  - pip=21.1.3
  - absl-py=0.13.0
  - pyparsing=2.4.7
  - jupyterlab=3.0.14
  - scikit-image=0.18.1
  - nvidia::cudatoolkit=10.2
  - pytorch::pytorch
  - pytorch::torchvision
  - pytorch::torchaudio
  - pip:
    - wandb
    - numpy==1.22.2
    - termcolor==1.1.0
    - tensorboard==2.8.0
    - imageio==2.9.0
    - imageio-ffmpeg==0.4.4
    - hydra-core==1.2.0
    - hydra-submitit-launcher==1.1.5
    - ipdb==0.13.9
    - yapf==0.31.0
    - sklearn==0.0
    - matplotlib==3.4.2
    - opencv-python==4.5.3.56
    - git+https://github.com/RobustBench/robustbench.git
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import random

import numpy as np
import omegaconf
import torch
import wandb


def set_seed_everywhere(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def to_torch(xs, device):
    return tuple(torch.as_tensor(x, device=device) for x in xs)


def setup_wandb(cfg):
    cfg_dict = omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
    wandb.init(
        entity=cfg.user.wandb_id,
        project=cfg.wandb.project,
        settings=wandb.Settings(start_method="thread"),
        name=cfg.wandb.exp_name,
        # reinit=True,
    )
    wandb.config.update(cfg_dict, allow_val_change=True)
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
from pathlib import Path
from robustbench.data import load_cifar10c
from torch.utils.data import DataLoader, Subset, TensorDataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import random_split
import torchvision
import numpy as np


def get_loaders(cfg, corruption_type, severity):
    if cfg.data.dataset_name == "cifar10":
        x_corr, y_corr = load_cifar10c(
            10000, severity, cfg.user.root_dir, False, [corruption_type]
        )
        assert cfg.args.train_n <= 9000
        labels = {}
        num_classes = int(max(y_corr)) + 1
        for i in range(num_classes):
            labels[i] = [ind for ind, n in enumerate(y_corr) if n == i]
        num_ex = cfg.args.train_n // num_classes
        tr_idxs = []
        val_idxs = []
        test_idxs = []
        for i in range(len(labels.keys())):
            np.random.shuffle(labels[i])
            tr_idxs.append(labels[i][:num_ex])
            val_idxs.append(labels[i][num_ex:num_ex + 10])
            test_idxs.append(labels[i][num_ex + 10:num_ex + 100])
        tr_idxs = np.concatenate(tr_idxs)
        val_idxs = np.concatenate(val_idxs)
        test_idxs = np.concatenate(test_idxs)

        tr_dataset = TensorDataset(x_corr[tr_idxs], y_corr[tr_idxs])
        val_dataset = TensorDataset(x_corr[val_idxs], y_corr[val_idxs])
        te_dataset = TensorDataset(x_corr[test_idxs], y_corr[test_idxs])

    elif cfg.data.dataset_name == "imagenet-c":
        data_root = Path(cfg.user.root_dir)
        image_dir = data_root / "ImageNet-C" / corruption_type / str(severity)
        dataset = ImageFolder(image_dir, transform=transforms.ToTensor())
        indices = list(range(len(dataset.imgs)))  # 50k examples --> 50 per class
        assert cfg.args.train_n <= 20000
        labels = {}
        y_corr = dataset.targets
        for i in range(max(y_corr) + 1):
            labels[i] = [ind for ind, n in enumerate(y_corr) if n == i]
        num_ex = cfg.args.train_n // (max(y_corr) + 1)
        tr_idxs = []
        val_idxs = []
        test_idxs = []
        for i in range(len(labels.keys())):
            np.random.shuffle(labels[i])
            tr_idxs.append(labels[i][:num_ex])
            val_idxs.append(labels[i][num_ex:num_ex + 10])
            test_idxs.append(labels[i][num_ex + 10:num_ex + 20])
        tr_idxs = np.concatenate(tr_idxs)
        val_idxs = np.concatenate(val_idxs)
        test_idxs = np.concatenate(test_idxs)
        tr_dataset = Subset(dataset, tr_idxs)
        val_dataset = Subset(dataset, val_idxs)
        te_dataset = Subset(dataset, test_idxs)

    # The original file stops here without returning anything, but main.py expects a
    # dict of DataLoaders keyed "train"/"val"/"test"; the wrapping below is an assumed
    # completion using the batch_size and num_workers settings from the config.
    loaders = {
        "train": DataLoader(
            tr_dataset,
            batch_size=cfg.data.batch_size,
            shuffle=True,
            num_workers=cfg.data.num_workers,
        ),
        "val": DataLoader(
            val_dataset,
            batch_size=cfg.data.batch_size,
            shuffle=False,
            num_workers=cfg.data.num_workers,
        ),
        "test": DataLoader(
            te_dataset,
            batch_size=cfg.data.batch_size,
            shuffle=False,
            num_workers=cfg.data.num_workers,
        ),
    }
    return loaders
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Surgical Fine-Tuning Improves Adaptation to Distribution Shifts

This repo provides starter code for the following paper, published at ICLR 2023:
> [Surgical Fine-Tuning Improves Adaptation to Distribution Shifts](https://openreview.net/pdf?id=APuPRxjHvZ).

The purpose of this repo is to provide a sample implementation of surgical fine-tuning, which is simple to add to existing codebases: just optimize the parameters in the desired layers, as in the sketch below. Here we provide sample code for running on the CIFAR-10-C and ImageNet-C datasets.
The fine-tuning pipeline is all in `main.py`, with argument configs for the datasets in `config/`.
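
To make that concrete, here is a minimal illustrative sketch of what surgical fine-tuning looks like in PyTorch. It mirrors how `main.py` builds its optimizer (chaining the chosen blocks' parameters into `optim.Adam`); the helper name `surgical_adam` is only for illustration, and the attribute names `conv1`/`block1` match the RobustBench CIFAR-10 WideResNet used here, so substitute your own modules as needed.

```
import itertools

import torch.optim as optim


def surgical_adam(model, lr=1e-3, weight_decay=1e-4):
    # Surgical fine-tuning: hand only the chosen layers' parameters to the
    # optimizer, so every other parameter keeps its pre-trained value.
    first_two_block = [model.conv1.parameters(), model.block1.parameters()]
    params = list(itertools.chain(*first_two_block))
    return optim.Adam(params, lr=lr, weight_decay=weight_decay)
```

During training you call `loss.backward()` and `opt.step()` as usual; parameters outside the selected blocks simply never receive updates.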

## Environment

Create an environment with the following command:
```
conda env create -f conda_env.yml
```


## **Sample Commands for Surgical Fine-Tuning**

Before running, download the data ([CIFAR-10-C](https://zenodo.org/record/2535967) or [ImageNet-C](https://zenodo.org/record/2235448)) and update the paths in the configs accordingly.

```
python main.py --config-name='cifar-10c' args.train_n=1000 args.seed=0 data.corruption_types=['defocus_blur'] wandb.use=True
python main.py --config-name='cifar-10c' args.train_n=1000 args.seed=0 data.corruption_types=[frost,gaussian_blur,gaussian_noise,glass_blur,impulse_noise,jpeg_compression,motion_blur,pixelate,saturate,shot_noise,snow,spatter,speckle_noise,zoom_blur] wandb.use=False args.auto_tune=none args.epochs=15
python main.py --config-name='imagenet-c' args.train_n=5000 args.seed=0 data.corruption_types=[brightness,contrast,defocus_blur,elastic_transform,fog,frost,gaussian_noise,glass_blur,impulse_noise,jpeg_compression,motion_blur,pixelate,shot_noise,snow,zoom_blur] wandb.use=False args.auto_tune=none args.epochs=10
```

## Running Auto-RGN

Auto-RGN chooses which layers to emphasize automatically by rescaling each layer's learning rate with its relative gradient norm (see `get_lr_weights` in `main.py`); a sketch follows the commands.

```
python main.py --config-name='cifar-10c' args.train_n=1000 args.seed=0 data.corruption_types=[frost,gaussian_blur,gaussian_noise,glass_blur,impulse_noise,jpeg_compression,motion_blur,pixelate,saturate,shot_noise,snow,spatter,speckle_noise,zoom_blur] wandb.use=True args.auto_tune=RGN args.epochs=15

python main.py --config-name='imagenet-c' args.train_n=5000 args.seed=2 data.corruption_types=[brightness,contrast,defocus_blur,elastic_transform,fog,frost,gaussian_noise,glass_blur,impulse_noise,jpeg_compression,motion_blur,pixelate,shot_noise,snow,zoom_blur] wandb.use=False args.auto_tune=RGN args.epochs=10
```
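
For reference, the weighting computed by `get_lr_weights` boils down to the simplified single-batch sketch below (the real code averages over a few batches, skips BatchNorm parameters, and also implements an `eb-criterion` variant); the function name `rgn_scales` is only illustrative.

```
import torch
import torch.nn.functional as F


def rgn_scales(model, x, y):
    # Relative gradient norm per parameter tensor: ||grad|| / ||param||,
    # normalized so the layer with the largest value gets scale 1.0.
    loss = F.cross_entropy(model(x), y)
    grads = torch.autograd.grad(loss, list(model.parameters()))
    scales = {
        name: (g.norm() / p.norm()).item()
        for (name, p), g in zip(model.named_parameters(), grads)
    }
    max_scale = max(scales.values())
    return {name: s / max_scale for name, s in scales.items()}
```

`main.py` then passes one parameter group per layer to `optim.Adam`, with the base learning rate multiplied by these scales.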

## Citing Surgical Fine-Tuning

If surgical fine-tuning or this repository is useful in your own research, you can use the following BibTeX entry:

    @article{lee2022surgical,
      title={Surgical fine-tuning improves adaptation to distribution shifts},
      author={Lee, Yoonho and Chen, Annie S and Tajwar, Fahim and Kumar, Ananya and Yao, Huaxiu and Liang, Percy and Finn, Chelsea},
      journal={International Conference on Learning Representations},
      year={2023}
    }
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import csv
import itertools
import os
import time
from collections import defaultdict
import copy

import pathlib
from pathlib import Path
from datetime import date
import hydra
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.distributions import Categorical
import torch.nn.functional as F
import torch.optim as optim
from robustbench.model_zoo.enums import ThreatModel
from robustbench.utils import load_model
import wandb
import pandas as pd
from PIL import Image

import utils
from dataset import get_loaders


@torch.no_grad()
def test(model, loader, criterion, cfg):
    model.eval()
    all_test_corrects = []
    total_loss = 0.0
    for x, y in loader:
        x, y = x.cuda(), y.cuda()
        logits = model(x)
        loss = criterion(logits, y)
        all_test_corrects.append(torch.argmax(logits, dim=-1) == y)
        total_loss += loss
    acc = torch.cat(all_test_corrects).float().mean().detach().item()
    total_loss = total_loss / len(loader)
    total_loss = total_loss.detach().item()
    return acc, total_loss


def get_lr_weights(model, loader, cfg):
    layer_names = [
        n for n, _ in model.named_parameters() if "bn" not in n
    ]
    metrics = defaultdict(list)
    average_metrics = defaultdict(float)
    partial_loader = itertools.islice(loader, 5)
    xent_grads, entropy_grads = [], []
    for x, y in partial_loader:
        x, y = x.cuda(), y.cuda()
        logits = model(x)

        loss_xent = F.cross_entropy(logits, y)
        grad_xent = torch.autograd.grad(
            outputs=loss_xent, inputs=model.parameters(), retain_graph=True
        )
        xent_grads.append([g.detach() for g in grad_xent])

    def get_grad_norms(model, grads, cfg):
        _metrics = defaultdict(list)
        grad_norms, rel_grad_norms = [], []
        for (name, param), grad in zip(model.named_parameters(), grads):
            if name not in layer_names:
                continue
            if cfg.args.auto_tune == 'eb-criterion':
                tmp = (grad * grad) / (torch.var(grad, dim=0, keepdim=True) + 1e-8)
                _metrics[name] = tmp.mean().item()
            else:
                _metrics[name] = torch.norm(grad).item() / torch.norm(param).item()

        return _metrics

    for xent_grad in xent_grads:
        xent_grad_metrics = get_grad_norms(model, xent_grad, cfg)
        for k, v in xent_grad_metrics.items():
            metrics[k].append(v)
    for k, v in metrics.items():
        average_metrics[k] = np.array(v).mean(0)
    return average_metrics


def train(model, loader, criterion, opt, cfg, orig_model=None):
    all_train_corrects = []
    total_loss = 0.0
    magnitudes = defaultdict(float)

    for x, y in loader:
        x, y = x.cuda(), y.cuda()
        logits = model(x)
        loss = criterion(logits, y)
        all_train_corrects.append(torch.argmax(logits, dim=-1) == y)
        total_loss += loss

        opt.zero_grad()
        loss.backward()
        opt.step()

    acc = torch.cat(all_train_corrects).float().mean().detach().item()
    total_loss = total_loss / len(loader)
    total_loss = total_loss.detach().item()
    return acc, total_loss, magnitudes


@hydra.main(config_path="config", config_name="config")
def main(cfg):
    cfg.args.log_dir = pathlib.Path.cwd()
    cfg.args.log_dir = os.path.join(
        cfg.args.log_dir, "results", cfg.data.dataset_name, date.today().strftime("%Y.%m.%d"), cfg.args.auto_tune
    )
    print(f"Log dir: {cfg.args.log_dir}")
    os.makedirs(cfg.args.log_dir, exist_ok=True)

    tune_options = [
        "first_two_block",
        "second_block",
        "third_block",
        "last",
        "all",
    ]
    if cfg.data.dataset_name == "imagenet-c":
        tune_options.append("fourth_block")
    if cfg.args.auto_tune != 'none':
        tune_options = ["all"]
    if cfg.args.epochs == 0: tune_options = ['all']
    corruption_types = cfg.data.corruption_types
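    # For each corruption: loop over the layer-tuning options, sweep the (lr, wd)
    # grid, fine-tune for cfg.args.epochs epochs per setting, and report the test
    # accuracy of the epoch/setting with the best validation accuracy.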
    for corruption_type in corruption_types:
        cfg.wandb.exp_name = f"{cfg.data.dataset_name}_corruption{corruption_type}"
        if cfg.wandb.use:
            utils.setup_wandb(cfg)
        utils.set_seed_everywhere(cfg.args.seed)
        loaders = get_loaders(cfg, corruption_type, cfg.data.severity)

        for tune_option in tune_options:
            tune_metrics = defaultdict(list)
            lr_wd_grid = [
                (1e-1, 1e-4),
                (1e-2, 1e-4),
                (1e-3, 1e-4),
                (1e-4, 1e-4),
                (1e-5, 1e-4),
            ]
            for lr, wd in lr_wd_grid:
                dataset_name = (
                    "imagenet"
                    if cfg.data.dataset_name == "imagenet-c"
                    else cfg.data.dataset_name
                )
                model = load_model(
                    cfg.data.model_name,
                    cfg.user.ckpt_dir,
                    dataset_name,
                    ThreatModel.corruptions,
                )

                orig_model = copy.deepcopy(model)
                model = model.cuda()

                if cfg.data.dataset_name == "cifar10":
                    tune_params_dict = {
                        "all": [model.parameters()],
                        "first_two_block": [
                            model.conv1.parameters(),
                            model.block1.parameters(),
                        ],
                        "second_block": [
                            model.block2.parameters(),
                        ],
                        "third_block": [
                            model.block3.parameters(),
                        ],
                        "last": [model.fc.parameters()],
                    }
                elif cfg.data.dataset_name == "imagenet-c":
                    tune_params_dict = {
                        "all": [model.model.parameters()],
                        "first_second": [
                            model.model.conv1.parameters(),
                            model.model.layer1.parameters(),
                            model.model.layer2.parameters(),
                        ],
                        "first_two_block": [
                            model.model.conv1.parameters(),
                            model.model.layer1.parameters(),
                        ],
                        "second_block": [
                            model.model.layer2.parameters(),
                        ],
                        "third_block": [
                            model.model.layer3.parameters(),
                        ],
                        "fourth_block": [
                            model.model.layer4.parameters(),
                        ],
                        "last": [model.model.fc.parameters()],
                    }

                params_list = list(itertools.chain(*tune_params_dict[tune_option]))

                opt = optim.Adam(params_list, lr=lr, weight_decay=wd)
                N = sum(p.numel() for p in params_list if p.requires_grad)

                print(
                    f"\nTrain mode={cfg.args.train_mode}, using {cfg.args.train_n} corrupted images for training"
                )
                print(
                    f"Re-training {tune_option} ({N} params). lr={lr}, wd={wd}. Corruption {corruption_type}"
                )
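
                # Per-epoch auto-tuning (if enabled): 'RGN' rebuilds the optimizer with
                # each layer's lr scaled by its normalized relative gradient norm;
                # 'eb-criterion' zeroes the lr of layers whose criterion score is below 0.95.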
                criterion = F.cross_entropy
                layer_weights = [0 for layer, _ in model.named_parameters() if 'bn' not in layer]
                layer_names = [layer for layer, _ in model.named_parameters() if 'bn' not in layer]
                for epoch in range(1, cfg.args.epochs + 1):
                    if cfg.args.train_mode == "train":
                        model.train()
                    if cfg.args.auto_tune != 'none':
                        if cfg.args.auto_tune == 'RGN':
                            weights = get_lr_weights(model, loaders["train"], cfg)
                            max_weight = max(weights.values())
                            for k, v in weights.items():
                                weights[k] = v / max_weight
                            layer_weights = [sum(x) for x in zip(layer_weights, weights.values())]
                            tune_metrics['layer_weights'] = layer_weights
                            params = defaultdict()
                            for n, p in model.named_parameters():
                                if "bn" not in n:
                                    params[n] = p
                            params_weights = []
                            for param, weight in weights.items():
                                params_weights.append({"params": params[param], "lr": weight * lr})
                            opt = optim.Adam(params_weights, lr=lr, weight_decay=wd)
                        elif cfg.args.auto_tune == 'eb-criterion':
                            # Go by individual layers
                            weights = get_lr_weights(model, loaders["train"], cfg)
                            print(f"Epoch {epoch}, autotuning weights {min(weights.values()), max(weights.values())}")
                            tune_metrics['max_weight'].append(max(weights.values()))
                            tune_metrics['min_weight'].append(min(weights.values()))
                            print(weights.values())
                            for k, v in weights.items():
                                weights[k] = 0.0 if v < 0.95 else 1.0
                            print("weight values", weights.values())
                            layer_weights = [sum(x) for x in zip(layer_weights, weights.values())]
                            tune_metrics['layer_weights'] = layer_weights
                            params = defaultdict()
                            for n, p in model.named_parameters():
                                if "bn" not in n:
                                    params[n] = p
                            params_weights = []
                            for k, v in params.items():
                                if k in weights.keys():
                                    params_weights.append({"params": params[k], "lr": weights[k] * lr})
                                else:
                                    params_weights.append({"params": params[k], "lr": 0.0})
                            opt = optim.Adam(params_weights, lr=lr, weight_decay=wd)

                        else:
                            # Log rough fraction of parameters being tuned
                            no_weight = 0
                            for elt in params_weights:
                                if elt['lr'] == 0.:
                                    no_weight += elt['params'][0].flatten().shape[0]
                            total_params = sum(p.numel() for p in model.parameters())
                            tune_metrics['frac_params'].append((total_params - no_weight) / total_params)
                            print(f"Tuning {(total_params - no_weight)} out of {total_params} total")

                    acc_tr, loss_tr, grad_magnitudes = train(
                        model, loaders["train"], criterion, opt, cfg, orig_model=orig_model
                    )
                    acc_te, loss_te = test(model, loaders["test"], criterion, cfg)
                    acc_val, loss_val = test(model, loaders["val"], criterion, cfg)
                    tune_metrics["acc_train"].append(acc_tr)
                    tune_metrics["acc_val"].append(acc_val)
                    tune_metrics["acc_te"].append(acc_te)
                    log_dict = {
                        f"{tune_option}/train/acc": acc_tr,
                        f"{tune_option}/train/loss": loss_tr,
                        f"{tune_option}/val/acc": acc_val,
                        f"{tune_option}/val/loss": loss_val,
                        f"{tune_option}/test/acc": acc_te,
                        f"{tune_option}/test/loss": loss_te,
                    }
                    print(f"Epoch {epoch:2d} Train acc: {acc_tr:.4f}, Val acc: {acc_val:.4f}")

                    if cfg.wandb.use:
                        wandb.log(log_dict)

                tune_metrics["lr_tested"].append(lr)
                tune_metrics["wd_tested"].append(wd)

            # Get test acc according to best val acc
            best_run_idx = np.argmax(np.array(tune_metrics["acc_val"]))
            best_testacc = tune_metrics["acc_te"][best_run_idx]
            best_lr_wd = best_run_idx // (cfg.args.epochs)

            print(
                f"Best epoch: {best_run_idx % (cfg.args.epochs)}, Test Acc: {best_testacc}"
            )

            data = {
                "corruption_type": corruption_type,
                "train_mode": cfg.args.train_mode,
                "tune_option": tune_option,
                "auto_tune": cfg.args.auto_tune,
                "train_n": cfg.args.train_n,
                "seed": cfg.args.seed,
                "lr": tune_metrics["lr_tested"][best_lr_wd],
                "wd": tune_metrics["wd_tested"][best_lr_wd],
                "val_acc": tune_metrics["acc_val"][best_run_idx],
                "best_testacc": best_testacc,
            }

            recorded = False
            fieldnames = data.keys()
            csv_file_name = f"{cfg.args.log_dir}/results_seed{cfg.args.seed}.csv"
            write_header = True if not os.path.exists(csv_file_name) else False
            while not recorded:
                try:
                    with open(csv_file_name, "a") as f:
                        csv_writer = csv.DictWriter(f, fieldnames=fieldnames, restval=0.0)
                        if write_header:
                            csv_writer.writeheader()
                        csv_writer.writerow(data)
                    recorded = True
                except:
                    time.sleep(5)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------