├── .gitignore ├── assets └── Fig1.png ├── pretrained_model └── download.sh ├── scripts ├── test_vtab.sh ├── test_cifar.sh └── test_cub.sh ├── models └── load_model.py ├── LICENSE ├── utils.py ├── readme.md ├── datasets └── load_dataset.py └── main.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .vscode/ 3 | datasets/data/ 4 | .npz -------------------------------------------------------------------------------- /assets/Fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gfyddha/Fly-CL/HEAD/assets/Fig1.png -------------------------------------------------------------------------------- /pretrained_model/download.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # download pre-trained Vision Transformer checkpoint 4 | wget https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz 5 | mv B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz vit_base_patch16_224_in21k.npz -------------------------------------------------------------------------------- /scripts/test_vtab.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | cd .. 4 | 5 | python main.py --dataset VTAB --num_classes 50 --num_tasks 5 --model_name vit_base_patch16_224 --embedding_dim 768 --expand_dim 10000 --synaptic_degree 300 --coding_level 0.3 --seed 2023 --batch_size 128 --gpu 6 --data_augmentation vit --ridge_lower 6 --ridge_upper 10 -------------------------------------------------------------------------------- /scripts/test_cifar.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | cd .. 
import argparse
import timm
import torch
import torch.nn as nn


def load_model(model_name):
    """Return a pretrained backbone with its classification head removed.

    Parameters
    ----------
    model_name : str
        Either "vit_base_patch16_224" (weights downloaded by timm) or
        "resnet-50" (weights read from a local checkpoint file).

    Returns
    -------
    nn.Module
        A feature extractor created with ``num_classes=0`` (no classifier
        head), suitable for embedding extraction.

    Raises
    ------
    ValueError
        If ``model_name`` is not one of the supported names.
    """
    if model_name == "vit_base_patch16_224":
        # num_classes=0 strips the classification head so forward() emits
        # the pooled embedding instead of logits.
        return timm.create_model(
            "vit_base_patch16_224",
            pretrained=True,
            num_classes=0,
        )

    if model_name == "resnet-50":
        model = timm.create_model(
            "resnet50",
            pretrained=False,
            checkpoint_path='./pretrained_model/resnet50-11ad3fa6.pth',
            num_classes=1000,
        )
        state_dict = model.state_dict()
        # BUGFIX: timm's ResNet-50 names its head "fc", not "classifier",
        # so the original filter never removed anything and relied solely
        # on strict=False to skip the head weights.  Filter both names so
        # the headless model is loaded from a head-free state dict.
        keys_to_remove = [k for k in state_dict
                          if "classifier" in k or k.startswith("fc.")]
        for k in keys_to_remove:
            del state_dict[k]
        model_new = timm.create_model("resnet50", pretrained=False, num_classes=0)
        # strict=False tolerates any residual key mismatch with the
        # headless architecture.
        model_new.load_state_dict(state_dict, strict=False)
        return model_new

    raise ValueError(f"Unknown model name: {model_name}")
License 2 | 3 | Copyright (c) 2025 Heming Zou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
import random

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def random_initialization(seed: int = 2025):
    """Seed every RNG used by the project for reproducible runs.

    Covers Python's `random`, NumPy, and PyTorch (CPU and all CUDA
    devices).  Also forces deterministic cuDNN behavior; disabling cuDNN
    trades speed for exact repeatability.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = False
    np.random.seed(seed)
    random.seed(seed)


def get_parameters(model: nn.Module):
    """Return the list of trainable (requires_grad) parameters of `model`."""
    return [p for p in model.parameters() if p.requires_grad]


@torch.no_grad()
def feature_extract(model: nn.Module, data_loader: DataLoader, device: torch.device):
    """Run `model` over every batch of `data_loader` and collect outputs.

    Returns
    -------
    (Tensor, Tensor)
        Embeddings and labels, each concatenated along dim 0 and living
        on `device`.
    """
    # tqdm only provides a progress bar, so keep it an optional,
    # function-local dependency.
    from tqdm import tqdm

    # NOTE: the @torch.no_grad() decorator already disables autograd here;
    # the original additionally nested a redundant `with torch.no_grad():`.
    embedding_list, label_list = [], []
    for data, label in tqdm(data_loader):
        data, label = data.to(device), label.to(device)
        embedding_list.append(model(data))
        label_list.append(label)
    return torch.cat(embedding_list, dim=0), torch.cat(label_list, dim=0)


def target2onehot(targets, n_classes):
    """Convert a 1-D tensor of class indices to an (N, n_classes) one-hot matrix."""
    onehot = torch.zeros(targets.shape[0], n_classes).to(targets.device)
    onehot.scatter_(dim=1, index=targets.long().view(-1, 1), value=1.0)
    return onehot
Training Time in Pre-trained Model-based Continual Representation Learning."** 5 | 6 | ![](assets/Fig1.png) 7 | 8 | ## 🧠 Abstract 9 | 10 | Using a nearly-frozen pretrained model, the continual representation learning paradigm reframes parameter updates as a similarity-matching problem to mitigate catastrophic forgetting. However, directly leveraging pretrained features for downstream tasks often suffers from multicollinearity in the similarity-matching stage, and more advanced methods can be computationally prohibitive for real-time, low-latency applications. Inspired by the fly olfactory circuit, we propose Fly-CL, a bio-inspired framework compatible with a wide range of pretrained backbones. Fly-CL substantially reduces training time while achieving performance comparable to or exceeding that of current state-of-the-art methods. We theoretically show how Fly-CL progressively resolves multicollinearity, enabling more effective similarity matching with low time complexity. Extensive simulation experiments across diverse network architectures and data regimes validate Fly-CL’s effectiveness in addressing this challenge through a biologically inspired design. 11 | 12 | ## ⚙️ Environment Setup 13 | 14 | Experiment Configuration using Miniconda. People can create environment using following command 15 | ``` 16 | conda create -n FlyCL python=3.9 17 | conda activate FlyCL 18 | conda install pytorch==1.13.1 torchvision==0.14.1 pytorch-cuda=11.7 -c pytorch -c nvidia 19 | conda install "numpy<2.0.0" 20 | conda install timm==0.9.16 tqdm 21 | conda install scipy 22 | ``` 23 | 24 | ## Pre-trained Model Download 25 | 26 | Download the pretrained models using the provided script `pretrained_model/download.sh` 27 | 28 | ## 🚀 Running Experiments 29 | 30 | We provide example scripts for running experiments with the CIFAR-100, CUB-200-2011, and VTAB datasets. 
import random

import torch
import numpy as np
from PIL import Image
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Subset, ConcatDataset, Dataset


class CustomDataset(Dataset):
    """Image dataset backed by a list of file paths and integer labels.

    Parameters
    ----------
    data : sequence of str
        Paths to image files.
    targets : sequence of int
        Class label for each entry of `data`.
    transform : callable, optional
        Transform applied to the PIL image before it is returned.
    """

    def __init__(self, data, targets, transform=None):
        self.data = data
        self.targets = targets
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        label = self.targets[idx]
        # Force RGB so grayscale / palette images yield 3 channels.
        image = Image.open(self.data[idx]).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label


def build_transform(is_cifar: bool = False, data_augmentation=None) -> list:
    """Build the list of preprocessing transforms for one dataset split.

    BUGFIX: the original annotation claimed `transforms.Compose`, but the
    function returns a plain list; callers wrap it in `Compose` themselves.

    Parameters
    ----------
    is_cifar : bool
        CIFAR images are tiny, so they are resized directly to the model
        input size instead of the usual resize-256-then-center-crop-224.
    data_augmentation : str or None
        None (no normalization), "resnet" (ImageNet statistics) or "vit"
        (0.5 mean/std, matching the augreg ViT checkpoints).

    Raises
    ------
    ValueError
        For any other `data_augmentation` value.
    """
    input_size = 224
    pipeline = []
    if input_size > 32:
        size = input_size if is_cifar else int((256 / 224) * input_size)
        pipeline.append(transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC))
        pipeline.append(transforms.CenterCrop(input_size))
    pipeline.append(transforms.ToTensor())
    if data_augmentation is None:
        pass
    elif data_augmentation == "resnet":
        pipeline.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    elif data_augmentation == "vit":
        pipeline.append(transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
    else:
        raise ValueError(f"Unsupported data augmentation: {data_augmentation}")
    return pipeline


def load_dataset(args, domain_name=None, train=None):
    """Load a dataset and split it into per-task DataLoaders for CIL.

    Classes are shuffled (via the module-level `random`, so seed first)
    and partitioned into `args.num_tasks` contiguous groups of
    `num_classes // num_tasks` classes each.

    Returns
    -------
    (dict[int, DataLoader], dict[int, DataLoader])
        Train and test loaders keyed by task index.
    """
    dataset = args.dataset
    root = args.root
    num_classes = args.num_classes
    num_tasks = args.num_tasks
    batch_size = args.batch_size
    data_augmentation = args.data_augmentation

    # Train and test use the same deterministic preprocessing pipeline.
    is_cifar = dataset == "CIFAR-100"
    train_transform = transforms.Compose(build_transform(is_cifar=is_cifar, data_augmentation=data_augmentation))
    test_transform = transforms.Compose(build_transform(is_cifar=is_cifar, data_augmentation=data_augmentation))

    # Load the full dataset.
    if dataset == "CIFAR-100":
        full_train_dataset = datasets.CIFAR100(root=root, train=True, download=True, transform=train_transform)
        full_test_dataset = datasets.CIFAR100(root=root, train=False, download=True, transform=test_transform)
    elif dataset == "CUB-200-2011":
        full_train_dataset = datasets.ImageFolder(root=f"{root}/cub/train/", transform=train_transform)
        full_test_dataset = datasets.ImageFolder(root=f"{root}/cub/test/", transform=test_transform)
    elif dataset == "VTAB":
        full_train_dataset = datasets.ImageFolder(root=f"{root}/vtab/train/", transform=train_transform)
        full_test_dataset = datasets.ImageFolder(root=f"{root}/vtab/test/", transform=test_transform)
    else:
        raise ValueError(f"Unsupported dataset: {dataset}")

    # Random class order, then contiguous groups of `class_per_task` classes.
    class_per_task = num_classes // num_tasks
    random_classes = random.sample(list(range(num_classes)), num_classes)
    task_classes = [
        random_classes[i * class_per_task:(i + 1) * class_per_task]
        for i in range(num_tasks)
    ]

    # One DataLoader per task; a set makes the per-label membership test O(1)
    # instead of a linear scan over the task's class list.
    train_loader = {}
    test_loader = {}
    for i, classes_in_task in enumerate(task_classes):
        class_set = set(classes_in_task)
        train_subset = Subset(
            full_train_dataset,
            indices=[idx for idx, label in enumerate(full_train_dataset.targets) if label in class_set]
        )
        test_subset = Subset(
            full_test_dataset,
            indices=[idx for idx, label in enumerate(full_test_dataset.targets) if label in class_set]
        )
        train_loader[i] = DataLoader(train_subset, batch_size=batch_size, shuffle=True,
                                     num_workers=8, pin_memory=True)
        test_loader[i] = DataLoader(test_subset, batch_size=batch_size, shuffle=False,
                                    num_workers=8, pin_memory=True)

    return train_loader, test_loader
import argparse
import time

import torch
import numpy as np


def get_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser for a Fly-CL continual-learning run."""
    parser = argparse.ArgumentParser(description="Input hyperparameters for the experiment.")

    # Continual Learning Task Setting
    parser.add_argument('--dataset', default='CIFAR-100', help='Choose dataset')
    parser.add_argument('--root', default='../data', help='Dataset path')
    parser.add_argument('--num_classes', type=int, default=100, help='Total number of classes')
    parser.add_argument('--num_tasks', type=int, default=10, help='Number of tasks')

    # Model Architecture
    parser.add_argument('--model_name', type=str, default="vit_base_patch16_224", help='model name')
    parser.add_argument('--embedding_dim', type=int, default=768, help='Embedding dimension of pre-trained model')
    parser.add_argument('--expand_dim', type=int, default=10000, help='Expansion dimension of FlyModel')
    parser.add_argument('--synaptic_degree', type=int, default=100, help='Number of connections')
    parser.add_argument('--coding_level', type=float, default=0.01, help='Top-k number')

    # Training Configuration
    parser.add_argument('--seed', type=int, default=2025, help='Random seed')
    parser.add_argument('--ridge_lower', type=float, default=4, help='lower bound for ridge coefficient (log10)')
    # BUGFIX: this help string previously said "lower bound".
    parser.add_argument('--ridge_upper', type=float, default=10, help='upper bound for ridge coefficient (log10)')
    parser.add_argument('--data_augmentation', default=None, help='choose which normalization or not')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
    parser.add_argument('--gpu', type=int, default=0, help='Choose gpu')

    return parser


def select_ridge_parameter(Features, Y, ridge_lower, ridge_upper):
    """Pick the ridge coefficient minimizing Generalized Cross-Validation.

    Candidate ridges are 10**k for integer k in [ridge_lower, ridge_upper).
    `Features` is the (n_samples, n_features) design matrix and `Y` the
    (n_samples, n_classes) one-hot target matrix.  Returns the selected
    ridge as a 0-dim tensor.
    """
    U, S, Vh = torch.linalg.svd(Features, full_matrices=False)
    S_sq = S ** 2
    UTY = U.T @ Y
    ridges = torch.tensor(10.0 ** np.arange(ridge_lower, ridge_upper))
    n_samples = Features.shape[0]

    gcv_scores = []
    for ridge in ridges:
        # Shrinkage factors of the ridge smoother in the SVD basis.
        diag = S_sq / (S_sq + ridge)
        df = diag.sum()  # effective degrees of freedom
        Y_hat = U @ (diag[:, None] * UTY)
        residual = torch.norm(Y - Y_hat) ** 2
        gcv = (residual / n_samples) / (1 - df / n_samples) ** 2
        gcv_scores.append(gcv.item())

    return ridges[np.argmin(gcv_scores)]


if __name__ == "__main__":
    # Heavy, run-time-only dependencies are imported here so that the two
    # functions above can be imported (e.g. by tests) without pulling in
    # timm / torchvision / tqdm.
    from datasets.load_dataset import load_dataset
    from models.load_model import load_model
    from utils import random_initialization, feature_extract, target2onehot

    parser = get_parser()
    args = parser.parse_args()
    cuda_available = torch.cuda.is_available()
    device = torch.device(f"cuda:{args.gpu}" if cuda_available else "cpu")
    random_initialization(args.seed)

    if args.dataset in ("CIFAR-100", "CUB-200-2011", "VTAB"):
        print("Load and Split CIL Dataset...")
        train_loader, test_loader = load_dataset(args)
        print("Load and Split CIL Dataset Done")
    else:
        # BUGFIX: an unknown dataset previously fell through and crashed
        # later with a NameError on train_loader; fail fast instead.
        raise ValueError(f"Unsupported dataset: {args.dataset}")

    pretrained_model = load_model(args.model_name)
    pretrained_model.out_dim = args.embedding_dim
    pretrained_model.eval()
    pretrained_model.to(device)

    # Sparse random fly projection: each of the expand_dim output units
    # connects to `synaptic_degree` randomly chosen input dimensions.
    non_zero_per_col = args.synaptic_degree
    projection_matrix = torch.zeros(args.expand_dim, args.embedding_dim)
    for row in range(args.expand_dim):
        selected_cols = torch.randperm(args.embedding_dim)[:non_zero_per_col]
        projection_matrix[row, selected_cols] = torch.randn(non_zero_per_col)
    projection_matrix = projection_matrix.to(device).to_sparse_csc()

    acc = {}
    training_time = []
    feature_extract_time = []
    # Running sufficient statistics of the ridge solution, accumulated
    # across tasks: Q = X @ Y_onehot, G = X @ X.T (X is the sparse code).
    Q = torch.zeros(args.expand_dim, args.num_classes).to(device)
    G = torch.zeros(args.expand_dim, args.expand_dim).to(device)
    print("Start Continual Learning")
    for task in range(args.num_tasks):
        acc[task] = []
        training_start = time.time()
        feature_extract_start = time.time()
        train_embeddings, train_labels = feature_extract(pretrained_model, train_loader[task], device)
        feature_extract_end = time.time()
        feature_extract_time.append(feature_extract_end - feature_extract_start)

        # Project to (expand_dim, N), then keep only the top coding_level
        # fraction of activations per sample (winner-take-all sparse code).
        train_embeddings = torch.sparse.mm(projection_matrix, train_embeddings.T)
        values, indices = train_embeddings.topk(int(args.expand_dim * args.coding_level), dim=0, largest=True)
        sparse_code = torch.zeros_like(train_embeddings)
        sparse_code.scatter_(0, indices, values)
        train_embeddings = sparse_code

        Y = target2onehot(train_labels, args.num_classes)
        Q = Q + train_embeddings @ Y
        G = G + train_embeddings @ train_embeddings.T
        ridge = select_ridge_parameter(train_embeddings.T, Y, args.ridge_lower, args.ridge_upper)
        # Cholesky factor + triangular solves instead of a matrix inverse
        # (~40% faster per the original author's note).
        L = torch.linalg.cholesky(G + ridge * torch.eye(G.size(dim=0)).to(device))
        Wo = torch.cholesky_solve(Q, L)
        training_end = time.time()
        training_time.append(training_end - training_start)

        # Evaluate on every task seen so far.
        for sub_task in range(task + 1):
            test_embeddings, test_labels = feature_extract(pretrained_model, test_loader[sub_task], device)
            test_embeddings = torch.sparse.mm(projection_matrix, test_embeddings.T)
            values, indices = test_embeddings.topk(int(args.expand_dim * args.coding_level), dim=0, largest=True)
            sparse_code = torch.zeros_like(test_embeddings)
            sparse_code.scatter_(0, indices, values)
            test_embeddings = sparse_code.T.to_sparse_csc()
            logits = torch.sparse.mm(test_embeddings, Wo)
            # BUGFIX: squeeze(1) instead of squeeze() keeps a one-sample
            # batch 1-D rather than collapsing it to a 0-dim tensor.
            predicts = torch.topk(logits, k=1, dim=1, largest=True, sorted=True)[1].squeeze(1)
            test_accuracy = np.mean(predicts.cpu().numpy() == test_labels.cpu().numpy()) * 100
            acc[sub_task].append(test_accuracy)

    # Display the accuracy matrix: row i = accuracy on task i, column j =
    # measured after training task j (below-diagonal cells stay "0.00").
    acc_matrix = [["{:.2f}".format(0.00) for _ in range(args.num_tasks)] for _ in range(len(acc))]
    for i, (task, values) in enumerate(acc.items()):
        for j, value in enumerate(values):
            acc_matrix[i][i + j] = round(value, 2)

    print("Accuracy Matrix")
    for row in acc_matrix:
        print(row)
    print()

    print("Average Accuracy")
    A_t = []
    for j in range(args.num_tasks):
        cnt = 0.0
        for i in range(j + 1):
            cnt += acc_matrix[i][j]
        cnt /= (j + 1)
        A_t.append(cnt)
        print(round(cnt, 2), end=", ")
    print("\n")

    print("Accumulated Accuracy")
    print(round(np.mean(A_t), 2))
    print()

    print("Training Time")
    for task_time in training_time:
        print(round(task_time, 2), end=", ")
    print("\n")

    print("Average Training Time")
    print(round(np.mean(training_time), 2))
    print()

    print("Feature Extract Time")
    for task_time in feature_extract_time:
        print(round(task_time, 2), end=", ")
    print("\n")

    print("Average Feature Extract Time")
    print(round(np.mean(feature_extract_time), 2))
    print()