├── .gitattributes ├── .gitignore ├── README.md ├── backend ├── .gitattributes ├── .gitignore ├── README.md ├── app │ ├── __init__.py │ ├── config │ │ ├── __init__.py │ │ └── settings.py │ ├── models │ │ ├── __init__.py │ │ ├── resnet.py │ │ └── status.py │ ├── routers │ │ ├── __init__.py │ │ ├── data.py │ │ ├── train.py │ │ └── unlearn.py │ ├── services │ │ ├── __init__.py │ │ ├── train.py │ │ ├── unlearn_FT.py │ │ ├── unlearn_GA.py │ │ ├── unlearn_RL.py │ │ ├── unlearn_custom.py │ │ └── unlearn_retrain.py │ ├── threads │ │ ├── __init__.py │ │ ├── train_thread.py │ │ ├── unlearn_FT_thread.py │ │ ├── unlearn_GA_thread.py │ │ ├── unlearn_RL_thread.py │ │ ├── unlearn_custom_thread.py │ │ └── unlearn_retrain_thread.py │ └── utils │ │ ├── __init__.py │ │ ├── attack.py │ │ ├── data_loader.py │ │ ├── evaluation.py │ │ ├── helpers.py │ │ ├── visualization.py │ │ └── visualize_distributions.py ├── data │ ├── 0 │ │ ├── 0000.json │ │ └── a000.json │ ├── 1 │ │ ├── 0001.json │ │ └── a001.json │ ├── 2 │ │ ├── 0002.json │ │ └── a002.json │ ├── 3 │ │ ├── 0003.json │ │ └── a003.json │ ├── 4 │ │ ├── 0004.json │ │ └── a004.json │ ├── 5 │ │ ├── 0005.json │ │ └── a005.json │ ├── 6 │ │ ├── 0006.json │ │ └── a006.json │ ├── 7 │ │ ├── 0007.json │ │ └── a007.json │ ├── 8 │ │ ├── 0008.json │ │ └── a008.json │ └── 9 │ │ ├── 0009.json │ │ └── a009.json ├── index.html ├── main.py └── pyproject.toml ├── demo.mp4 ├── frontend ├── .gitignore ├── README.md ├── components.json ├── package.json ├── pnpm-lock.yaml ├── public │ ├── index.html │ ├── logo.svg │ ├── manifest.json │ └── robots.txt ├── src │ ├── app │ │ ├── App.tsx │ │ ├── index.css │ │ └── react-app-env.d.ts │ ├── components │ │ ├── Core │ │ │ ├── Embeddings │ │ │ │ ├── ConnectionLine.tsx │ │ │ │ ├── ConnectionLineWrapper.tsx │ │ │ │ ├── Legend.tsx │ │ │ │ ├── ScatterPlot.tsx │ │ │ │ └── Tooltip.tsx │ │ │ └── PrivacyAttack │ │ │ │ ├── AttackAnalytics.tsx │ │ │ │ ├── AttackPlot.tsx │ │ │ │ ├── AttackSuccessFailure.tsx │ │ │ │ ├── 
InstancePanel.tsx │ │ │ │ ├── Legend.tsx │ │ │ │ └── Tooltip.tsx │ │ ├── Header │ │ │ ├── Header.tsx │ │ │ ├── Tab.tsx │ │ │ ├── TabPlusButton.tsx │ │ │ └── Tabs.tsx │ │ ├── MetricsView │ │ │ ├── Accuracy │ │ │ │ └── VerticalBarChart.tsx │ │ │ ├── CKA │ │ │ │ └── LineChart.tsx │ │ │ └── Predictions │ │ │ │ ├── BubbleMatrix.tsx │ │ │ │ ├── BubbleMatrixLegend.tsx │ │ │ │ ├── CorrelationMatrix.tsx │ │ │ │ └── CorrelationMatrixLegend.tsx │ │ ├── ModelScreening │ │ │ ├── Experiments │ │ │ │ ├── ColumnHeaders.tsx │ │ │ │ ├── Columns.tsx │ │ │ │ ├── CustomUnlearning.tsx │ │ │ │ ├── DataTable.tsx │ │ │ │ ├── HyperparameterInput.tsx │ │ │ │ ├── Legend.tsx │ │ │ │ ├── MethodFilterHeader.tsx │ │ │ │ ├── MethodUnlearning.tsx │ │ │ │ ├── TableBody.tsx │ │ │ │ └── TableHeader.tsx │ │ │ └── Progress │ │ │ │ ├── AddModelsButton.tsx │ │ │ │ ├── Pagination.tsx │ │ │ │ ├── Stepper.tsx │ │ │ │ └── UnlearningConfiguration.tsx │ │ ├── UI │ │ │ ├── badge.tsx │ │ │ ├── button.tsx │ │ │ ├── chart.tsx │ │ │ ├── checkbox.tsx │ │ │ ├── context-menu.tsx │ │ │ ├── dialog.tsx │ │ │ ├── hover-card.tsx │ │ │ ├── icons.tsx │ │ │ ├── input.tsx │ │ │ ├── label.tsx │ │ │ ├── pagination.tsx │ │ │ ├── radio-group.tsx │ │ │ ├── scroll-area.tsx │ │ │ ├── select.tsx │ │ │ ├── separator.tsx │ │ │ ├── slider.tsx │ │ │ ├── stepper.tsx │ │ │ ├── table.tsx │ │ │ └── tabs.tsx │ │ └── common │ │ │ ├── CustomButton.tsx │ │ │ ├── DatasetModeSelector.tsx │ │ │ ├── Indicator.tsx │ │ │ ├── Subtitle.tsx │ │ │ ├── Title.tsx │ │ │ └── View.tsx │ ├── constants │ │ ├── colors.ts │ │ ├── common.ts │ │ ├── correlations.ts │ │ ├── embeddings.ts │ │ ├── experiments.ts │ │ └── privacyAttack.ts │ ├── hooks │ │ ├── useClasses.ts │ │ └── useModelExperiment.ts │ ├── index.tsx │ ├── stores │ │ ├── attackStore.ts │ │ ├── baseConfigStore.ts │ │ ├── experimentsStore.ts │ │ ├── forgetClassStore.ts │ │ ├── modelDataStore.ts │ │ ├── runningIndexStore.ts │ │ ├── runningStatusStore.ts │ │ └── thresholdStore.ts │ ├── types │ │ ├── attack.ts 
│ │ ├── data.ts │ │ ├── embeddings.ts │ │ └── experiments.ts │ ├── utils │ │ ├── api │ │ │ ├── dataTable.ts │ │ │ ├── privacyAttack.ts │ │ │ ├── requests.ts │ │ │ └── unlearning.ts │ │ ├── config │ │ │ └── unlearning.ts │ │ ├── data │ │ │ ├── accuracies.ts │ │ │ ├── colors.ts │ │ │ ├── experiments.ts │ │ │ ├── getButterflyLegendData.ts │ │ │ ├── getCkaData.ts │ │ │ ├── getProgressSteps.ts │ │ │ └── running-status-context.ts │ │ └── util.ts │ └── views │ │ ├── Core │ │ ├── Core.tsx │ │ ├── Embedding.tsx │ │ └── PrivacyAttack.tsx │ │ ├── MetricsView │ │ ├── ClassWiseAnalysis.tsx │ │ ├── LayerWiseSimilarity.tsx │ │ ├── MetricsView.tsx │ │ └── PredictionMatrix.tsx │ │ └── ModelScreening │ │ ├── Experiments.tsx │ │ ├── ModelScreening.tsx │ │ └── Progress.tsx ├── tailwind.config.js └── tsconfig.json └── img ├── attack.png └── embeddings.png /.gitattributes: -------------------------------------------------------------------------------- 1 | *.mp4 filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | 3 | backend/umap_visualizations/ 4 | backend/trained_models/ 5 | backend/unlearned_models/ 6 | backend/uploaded_models/ 7 | backend/attack/ 8 | backend/.python-version 9 | backend/*.png 10 | paper/ 11 | z_dist_exp/ 12 | myenv/ 13 | *.py 14 | *.pdf 15 | *.mdc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Machine Unlearning Comparator 2 | 5 | This tool facilitates the evaluation and comparison of machine unlearning methods by providing interactive visualizations and analytical insights. It enables a systematic examination of model behavior through privacy attacks and performance metrics, offering comprehensive analysis of various unlearning techniques. 
6 | 7 | ## Demo 8 | 9 | Try our live demo: [Machine Unlearning Comparator](https://gnueaj.github.io/Machine-Unlearning-Comparator/) 10 | 11 | [Watch the video](https://youtu.be/yAyAYp2msDk?si=Q-8IgVlrk8uSBceu) 12 | 13 | ## Features 14 | 15 | ### Built-in Baseline Methods 16 | The Machine Unlearning Comparator provides comparison of various baseline methods: 17 | - **Fine-Tuning**: Leverages catastrophic forgetting by fine-tuning the model on remaining data with increased learning rate 18 | - **Gradient-Ascent**: Moves in the direction of increasing loss for forget samples using negative gradients 19 | - **Random-Labeling**: Fine-tunes the model by randomly reassigning labels for forget samples, excluding the original forget class labels 20 | 21 | ### **Custom Method Integration** ✨ 22 | **Upload and evaluate your own unlearning methods!** The comparator supports custom implementations, enabling you to: 23 | - **Benchmark** your novel approaches against established baselines 24 | - **Upload** your custom unlearning implementations for comparison 25 | - **Compare** results using standardized evaluation metrics and privacy attacks 26 | 27 | It includes various visualizations and evaluations through privacy attacks to assess the effectiveness of each method. 28 | 29 | ## How to Start 30 | 31 | ### Backend 32 | 33 | 1. **Install Dependencies Using Hatch** 34 | ```shell 35 | hatch shell 36 | ``` 37 | 38 | 2. **Start the Backend Server** 39 | ```shell 40 | hatch run start 41 | ``` 42 | 43 | ### Frontend 44 | 45 | 1. **Install Dependencies Using pnpm** 46 | ```shell 47 | pnpm install 48 | ``` 49 | 50 | 2. 
**Start the Frontend Server** 51 | ```shell 52 | pnpm start 53 | ``` 54 | 55 | ## Related Resources 56 | - [ResNet18 CIFAR-10 Unlearning Models on Hugging Face](https://huggingface.co/jaeunglee/resnet18-cifar10-unlearning) 57 | -------------------------------------------------------------------------------- /backend/.gitattributes: -------------------------------------------------------------------------------- 1 | /home/jaeung/mu-dashboard/backend/demo.mp4 filter=lfs diff=lfs merge=lfs -text 2 | *.mp4 filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /backend/.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | backend/umap_visualizations/ 3 | backend/trained_models/ 4 | backend/unlearned_models/ 5 | backend/uploaded_models/ 6 | backend/*.png 7 | d3ex/ 8 | attack/ 9 | attack_exp/ -------------------------------------------------------------------------------- /backend/README.md: -------------------------------------------------------------------------------- 1 | # machine-unlearning-dashboard 2 | machine-unlearning-dashboard 3 | 4 | ## Development 5 | 6 | Install [pnpm](https://pnpm.io/installation) to set up the development environment. 7 | 8 | ```shell 9 | git clone https://github.com/gnueaj/mu-dashboard.git && cd mu-dashboard/frontend 10 | pnpm install 11 | pnpm dev 12 | ``` 13 | 14 | ## Backend 15 | ```shell 16 | cd backend 17 | hatch shell 18 | hatch run start 19 | 20 | Swagger - http://127.0.0.1:8000/docs 21 | e.g. 
22 | 23 | POST /train 24 | { 25 | "seed": 42, 26 | "batch_size": 128, 27 | "learning_rate": 0.01, 28 | "epochs": 30 29 | } 30 | 31 | GET /train/status 32 | ``` 33 | API 추가할 때: router -> service -> main 순 구현 34 | Swagger: http://127.0.0.1:8000/docs 35 | 36 | ### Training 37 | - /train POST 요청을 보내 트레이닝을 시작 38 | - 주기적으로 /train/status GET 요청을 보내 트레이닝 진행 상황을 확인 39 | - 트레이닝이 완료되면 /train/result GET 요청을 보내 SVG 파일 리스트를 받음 40 | 41 | ### Inference 42 | - /inference POST 요청을 보내 Inference를 시작 43 | - 주기적으로 /inference/status 엔드포인트로 GET 요청을 보내 Inference 진행 상황을 확인 44 | - Inference가 완료되면 /inference/result GET 요청을 보내 SVG 파일 리스트를 받음 45 | 46 | ### Unlearn 47 | #### Retrain 48 | ```shell 49 | POST /unlearn 50 | { 51 | "batch_size": 128, 52 | "learning_rate": 0.001, 53 | "epochs": 15, 54 | "forget_class": 1 55 | } 56 | ``` -------------------------------------------------------------------------------- /backend/app/__init__.py: -------------------------------------------------------------------------------- 1 | # This file can be empty or contain package-level imports if needed -------------------------------------------------------------------------------- /backend/app/config/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains all configuration settings for the application.
3 | """ 4 | 5 | from .settings import ( 6 | UMAP_N_NEIGHBORS, 7 | UMAP_MIN_DIST, 8 | UMAP_INIT, 9 | UMAP_RANDOM_STATE, 10 | UMAP_N_JOBS, 11 | UMAP_DATA_SIZE, 12 | UMAP_DATASET, 13 | MAX_GRAD_NORM, 14 | UNLEARN_SEED, 15 | MOMENTUM, 16 | WEIGHT_DECAY, 17 | BATCH_SIZE, 18 | LEARNING_RATE, 19 | EPOCHS, 20 | DECREASING_LR, 21 | GAMMA 22 | ) 23 | 24 | __all__ = [ 25 | # UMAP settings 26 | 'UMAP_N_NEIGHBORS', 27 | 'UMAP_MIN_DIST', 28 | 'UMAP_INIT', 29 | 'UMAP_RANDOM_STATE', 30 | 'UMAP_N_JOBS', 31 | 'UMAP_DATA_SIZE', 32 | 'UMAP_DATASET', 33 | 34 | # Training settings 35 | 'MAX_GRAD_NORM', 36 | 'UNLEARN_SEED', 37 | 'MOMENTUM', 38 | 'WEIGHT_DECAY', 39 | 'BATCH_SIZE', 40 | 'LEARNING_RATE', 41 | 'EPOCHS', 42 | 43 | # Learning rate schedule 44 | 'DECREASING_LR', 45 | 'GAMMA' 46 | ] -------------------------------------------------------------------------------- /backend/app/config/settings.py: -------------------------------------------------------------------------------- 1 | UMAP_N_NEIGHBORS = 7 2 | UMAP_MIN_DIST = 0.4 3 | UMAP_INIT = 'pca' 4 | UMAP_RANDOM_STATE = 42 5 | UMAP_N_JOBS = -1 6 | UMAP_DATA_SIZE = 2000 7 | UMAP_DATASET = 'train' 8 | 9 | MAX_GRAD_NORM = 100.0 10 | UNLEARN_SEED = 2048 11 | MOMENTUM = 0.9 12 | WEIGHT_DECAY = 5e-4 13 | BATCH_SIZE = 128 14 | LEARNING_RATE = 0.1 15 | EPOCHS = 200 16 | 17 | DECREASING_LR = [80, 120] 18 | GAMMA = 0.2 -------------------------------------------------------------------------------- /backend/app/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains neural network model architectures and related data structures. 3 | It includes implementations of ResNet and status tracking classes for training and unlearning processes. 
4 | """ 5 | 6 | from app.models.resnet import get_resnet18 7 | from app.models.status import TrainingStatus, UnlearningStatus 8 | 9 | __all__ = [ 10 | 'get_resnet18', 11 | 'TrainingStatus', 12 | 'UnlearningStatus' 13 | ] -------------------------------------------------------------------------------- /backend/app/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision import models 3 | 4 | def get_resnet18(num_classes=10): 5 | model = models.resnet18(weights=None) 6 | model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 7 | model.maxpool = nn.Identity() 8 | model.fc = nn.Linear(model.fc.in_features, num_classes) 9 | return model -------------------------------------------------------------------------------- /backend/app/models/status.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | class TrainingStatus: 4 | def __init__(self): 5 | self.is_training = False 6 | self.progress = 0 7 | self.current_epoch = 0 8 | self.total_epochs = 0 9 | self.current_loss = 0 10 | self.best_loss = 9999.99 11 | self.current_accuracy = 0 12 | self.best_accuracy = 0 13 | self.test_loss = 0 14 | self.test_accuracy = 0 15 | self.best_test_accuracy = 0 16 | self.train_class_accuracies: Dict[int, float] = {} 17 | self.test_class_accuracies: Dict[int, float] = {} 18 | self.start_time = 0 19 | self.estimated_time_remaining = 0 20 | self.umap_embeddings = None 21 | self.cancel_requested = False 22 | 23 | def reset(self): 24 | self.__init__() 25 | 26 | class UnlearningStatus: 27 | def __init__(self): 28 | self.is_unlearning = False 29 | self.recent_id = None 30 | self.progress = "Idle" 31 | self.current_epoch = 0 32 | self.total_epochs = 0 33 | self.current_unlearn_loss = 0 34 | self.current_unlearn_accuracy = 0 35 | self.estimated_time_remaining = 0 36 | self.cancel_requested = False 37 | self.forget_class = -1 
38 | 39 | # for retraining 40 | self.current_loss = 0 41 | self.current_accuracy = 0 42 | self.best_loss = 9999.99 43 | self.best_accuracy = 0 44 | self.test_loss = 9999.99 45 | self.test_accuracy = 0 46 | self.best_test_accuracy = 0 47 | 48 | # for Evaluation progress 49 | self.method = "" 50 | self.p_training_loss = 0 51 | self.p_training_accuracy = 0 52 | self.p_test_loss = 0 53 | self.p_test_accuracy = 0 54 | 55 | self.train_class_accuracies: Dict[int, float] = {} 56 | self.test_class_accuracies: Dict[int, float] = {} 57 | self.start_time = None 58 | 59 | def reset(self): 60 | self.__init__() -------------------------------------------------------------------------------- /backend/app/routers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This package contains route definitions for the API endpoints. 3 | 4 | Available routers: 5 | - train_router: Handles model training related endpoints 6 | - unlearn_router: Manages model unlearning operations 7 | - data_router: Manages data operations and preprocessing 8 | 9 | These routers are used to organize and structure the API endpoints 10 | for different functionalities of the application. 
11 | """ 12 | 13 | from .train import router as train_router 14 | from .unlearn import router as unlearn_router 15 | from .data import router as data_router 16 | 17 | __all__ = ['train_router', 'unlearn_router', 'data_router'] -------------------------------------------------------------------------------- /backend/app/routers/train.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, BackgroundTasks, HTTPException 2 | from pydantic import BaseModel, Field 3 | from app.services import run_training 4 | from app.models import TrainingStatus 5 | from app.config import ( 6 | BATCH_SIZE, 7 | LEARNING_RATE, 8 | EPOCHS 9 | ) 10 | 11 | router = APIRouter() 12 | status = TrainingStatus() 13 | 14 | class TrainingRequest(BaseModel): 15 | # seed: int = Field(default=1111, description="Random seed for reproducibility") 16 | batch_size: int = Field(default=BATCH_SIZE, description="Batch size for training") 17 | learning_rate: float = Field(default=LEARNING_RATE, description="Learning rate for optimizer") 18 | epochs: int = Field(default=EPOCHS, description="Number of training epochs") 19 | 20 | @router.post("/train") 21 | async def start_training(request: TrainingRequest, background_tasks: BackgroundTasks): 22 | if status.is_training: 23 | raise HTTPException(status_code=400, detail="Training is already in progress") 24 | status.reset() # Reset status before starting new training 25 | background_tasks.add_task(run_training, request, status) 26 | return {"message": "Training started"} 27 | 28 | @router.get("/train/status") 29 | async def get_status(): 30 | return { 31 | "is_training": status.is_training, 32 | "progress": status.progress, 33 | "current_epoch": status.current_epoch, 34 | "total_epochs": status.total_epochs, 35 | "current_loss": status.current_loss, 36 | "best_loss": status.best_loss, 37 | "current_accuracy": status.current_accuracy, 38 | "best_accuracy": status.best_accuracy, 39 | "test_loss": 
status.test_loss, 40 | "test_accuracy": status.test_accuracy, 41 | "train_class_accuracies": status.train_class_accuracies, 42 | "test_class_accuracies": status.test_class_accuracies, 43 | "estimated_time_remaining": status.estimated_time_remaining 44 | } 45 | 46 | @router.get("/train/result") 47 | async def get_training_result(): 48 | if status.is_training: 49 | raise HTTPException(status_code=400, detail="Training is still in progress") 50 | if getattr(status, "svg_files", None) is None:  # TrainingStatus.__init__ never defines svg_files, so direct attribute access raised AttributeError (HTTP 500) instead of this 404 51 | raise HTTPException(status_code=404, detail="No training results available") 52 | return {"svg_files": status.svg_files} 53 | 54 | @router.post("/train/cancel") 55 | async def cancel_training(): 56 | if not status.is_training: 57 | raise HTTPException(status_code=400, detail="No training in progress") 58 | status.cancel_requested = True 59 | return {"message": "Cancellation requested. Training will stop after the current epoch."} -------------------------------------------------------------------------------- /backend/app/services/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Services package for managing model training and unlearning operations. 3 | 4 | This package provides modules for training neural networks and implementing various unlearning methods. 5 | The service modules define recipes that are executed by separate threads to handle the actual computation. 6 | 7 | Available services: 8 | train: Model training with configurable hyperparameters 9 | unlearn_GA: Unlearning using gradient ascent method 10 | unlearn_RL: Unlearning using random labeling method 11 | unlearn_FT: Unlearning using fine-tuning method 12 | unlearn_custom: Custom unlearning method for inference 13 | 14 | Each service module follows a similar pattern: 15 | 1. Takes training/unlearning parameters as input 16 | 2. Sets up the model, data, and optimization components 17 | 3. Passes the configuration to a dedicated execution thread 18 | 4.
Provides status tracking and result handling 19 | """ 20 | 21 | from .train import run_training 22 | from .unlearn_GA import run_unlearning_GA 23 | from .unlearn_RL import run_unlearning_RL 24 | from .unlearn_FT import run_unlearning_FT 25 | from .unlearn_retrain import run_unlearning_retrain 26 | from .unlearn_custom import run_unlearning_custom 27 | 28 | __all__ = ['run_training', 'run_unlearning_GA', 'run_unlearning_RL', 'run_unlearning_FT', 'run_unlearning_retrain', 'run_unlearning_custom'] -------------------------------------------------------------------------------- /backend/app/services/train.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | 6 | from app.threads import TrainingThread 7 | from app.models import get_resnet18 8 | from app.utils import set_seed, get_data_loaders 9 | from app.config import ( 10 | MOMENTUM, 11 | WEIGHT_DECAY, 12 | UNLEARN_SEED 13 | ) 14 | 15 | async def training(request, status): 16 | print(f"Starting training with {request.epochs} epochs...") 17 | set_seed(UNLEARN_SEED) 18 | device = torch.device( 19 | "cuda" if torch.cuda.is_available() 20 | else "mps" if torch.backends.mps.is_available() 21 | else "cpu" 22 | ) 23 | 24 | ( 25 | train_loader, 26 | test_loader, 27 | train_set, 28 | test_set 29 | ) = get_data_loaders( 30 | batch_size=request.batch_size, 31 | augmentation=True 32 | ) 33 | model = get_resnet18().to(device=device) 34 | 35 | criterion = nn.CrossEntropyLoss() 36 | optimizer = optim.SGD( 37 | model.parameters(), 38 | lr=request.learning_rate, 39 | momentum=MOMENTUM, 40 | weight_decay=WEIGHT_DECAY, 41 | nesterov=True 42 | ) 43 | scheduler = optim.lr_scheduler.CosineAnnealingLR( 44 | optimizer=optimizer, 45 | T_max=request.epochs, 46 | ) 47 | 48 | training_thread = TrainingThread( 49 | model=model, 50 | train_loader=train_loader, 51 | test_loader=test_loader, 52 | criterion=criterion, 53 | 
optimizer=optimizer, 54 | scheduler=scheduler, 55 | device=device, 56 | epochs=request.epochs, 57 | status=status, 58 | model_name="resnet18", 59 | dataset_name="CIFAR10", 60 | learning_rate=request.learning_rate 61 | ) 62 | training_thread.start() 63 | 64 | while training_thread.is_alive(): 65 | await asyncio.sleep(0.1) 66 | if status.cancel_requested: 67 | training_thread.stop() 68 | print("Cancel requested. Stopping training thread...") 69 | break 70 | 71 | return status 72 | 73 | async def run_training(request, status): 74 | try: 75 | status.is_training = True 76 | status.cancel_requested = False 77 | updated_status = await training(request, status) 78 | return updated_status 79 | finally: 80 | status.is_training = False 81 | status.cancel_requested = False 82 | status.progress = 100 -------------------------------------------------------------------------------- /backend/app/services/unlearn_GA.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | 6 | from app.threads import UnlearningGAThread 7 | from app.models import get_resnet18 8 | from app.utils.helpers import set_seed 9 | from app.utils.data_loader import get_data_loaders 10 | from app.config import ( 11 | MOMENTUM, 12 | WEIGHT_DECAY, 13 | DECREASING_LR, 14 | UNLEARN_SEED 15 | ) 16 | 17 | async def unlearning_GA(request, status, base_weights_path): 18 | print(f"Starting GA unlearning for class {request.forget_class} with {request.epochs} epochs...") 19 | set_seed(UNLEARN_SEED) 20 | 21 | device = torch.device( 22 | "cuda" if torch.cuda.is_available() 23 | else "mps" if torch.backends.mps.is_available() 24 | else "cpu" 25 | ) 26 | 27 | # Create Unlearning Settings 28 | model_before = get_resnet18().to(device) 29 | model_after = get_resnet18().to(device) 30 | model_before.load_state_dict(torch.load(f"unlearned_models/{request.forget_class}/000{request.forget_class}.pth", 
map_location=device)) 31 | model_after.load_state_dict(torch.load(base_weights_path, map_location=device)) 32 | 33 | ( 34 | train_loader, 35 | test_loader, 36 | train_set, 37 | test_set 38 | ) = get_data_loaders( 39 | batch_size=request.batch_size, 40 | augmentation=False 41 | ) 42 | 43 | forget_indices = [ 44 | i for i, (_, label) in enumerate(train_set) 45 | if label == request.forget_class 46 | ] 47 | forget_subset = torch.utils.data.Subset( 48 | dataset=train_set, 49 | indices=forget_indices 50 | ) 51 | forget_loader = torch.utils.data.DataLoader( 52 | dataset=forget_subset, 53 | batch_size=request.batch_size, 54 | shuffle=True 55 | ) 56 | 57 | criterion = nn.CrossEntropyLoss() 58 | optimizer = optim.SGD( 59 | params=model_after.parameters(), 60 | lr=request.learning_rate, 61 | momentum=MOMENTUM, 62 | weight_decay=WEIGHT_DECAY 63 | ) 64 | scheduler = optim.lr_scheduler.MultiStepLR( 65 | optimizer=optimizer, 66 | milestones=DECREASING_LR, 67 | gamma=0.2 68 | ) 69 | unlearning_GA_thread = UnlearningGAThread( 70 | request=request, 71 | status=status, 72 | model_before=model_before, 73 | model_after=model_after, 74 | criterion=criterion, 75 | optimizer=optimizer, 76 | scheduler=scheduler, 77 | 78 | forget_loader=forget_loader, 79 | train_loader=train_loader, 80 | test_loader=test_loader, 81 | train_set=train_set, 82 | test_set=test_set, 83 | device=device, 84 | base_weights_path=base_weights_path 85 | ) 86 | unlearning_GA_thread.start() 87 | 88 | # thread start 89 | while unlearning_GA_thread.is_alive(): 90 | await asyncio.sleep(0.1) 91 | if status.cancel_requested: 92 | unlearning_GA_thread.stop() 93 | print("Cancellation requested, stopping the unlearning process...") 94 | 95 | status.is_unlearning = False 96 | 97 | # thread end 98 | if unlearning_GA_thread.exception: 99 | print(f"An error occurred during GA unlearning: {str(unlearning_GA_thread.exception)}") 100 | elif status.cancel_requested: 101 | print("Unlearning process was cancelled.") 102 | else: 103 | 
print("Unlearning process completed successfully.") 104 | 105 | return status 106 | 107 | async def run_unlearning_GA(request, status, base_weights_path): 108 | try: 109 | status.is_unlearning = True 110 | status.progress = "Unlearning" 111 | updated_status = await unlearning_GA(request, status, base_weights_path) 112 | return updated_status 113 | finally: 114 | status.cancel_requested = False 115 | status.progress = "Completed" 116 | status.is_unlearning = False  # unlearning_GA clears this only after its worker loop; reset it here too so an earlier exception cannot leave it stuck at True -------------------------------------------------------------------------------- /backend/app/services/unlearn_RL.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | 6 | from app.threads import UnlearningRLThread 7 | from app.models import get_resnet18 8 | from app.utils.helpers import set_seed 9 | from app.utils.data_loader import get_data_loaders 10 | 11 | from app.config import ( 12 | MOMENTUM, 13 | WEIGHT_DECAY, 14 | DECREASING_LR, 15 | UNLEARN_SEED 16 | ) 17 | 18 | async def unlearning_RL(request, status, base_weights_path): 19 | print(f"Starting RL unlearning for class {request.forget_class} with {request.epochs} epochs...") 20 | set_seed(UNLEARN_SEED) 21 | 22 | device = torch.device( 23 | "cuda" if torch.cuda.is_available() 24 | else "mps" if torch.backends.mps.is_available() 25 | else "cpu" 26 | ) 27 | 28 | # Create Unlearning Settings 29 | model_before = get_resnet18().to(device) 30 | model_after = get_resnet18().to(device) 31 | model_before.load_state_dict(torch.load(f"unlearned_models/{request.forget_class}/000{request.forget_class}.pth", map_location=device)) 32 | model_after.load_state_dict(torch.load(base_weights_path, map_location=device)) 33 | 34 | ( 35 | train_loader, 36 | test_loader, 37 | train_set, 38 | test_set 39 | ) = get_data_loaders( 40 | batch_size=request.batch_size, 41 | augmentation=False 42 | ) 43 | 44 | # Create retain loader (excluding forget class) 45 | retain_indices = [ 46 | i
for i, (_, label) in enumerate(train_set) 47 | if label != request.forget_class 48 | ] 49 | retain_subset = torch.utils.data.Subset( 50 | dataset=train_set, 51 | indices=retain_indices 52 | ) 53 | retain_loader = torch.utils.data.DataLoader( 54 | dataset=retain_subset, 55 | batch_size=request.batch_size, 56 | shuffle=True 57 | ) 58 | 59 | # Create forget loader (only forget class) 60 | forget_indices = [ 61 | i for i, (_, label) in enumerate(train_set) 62 | if label == request.forget_class 63 | ] 64 | forget_subset = torch.utils.data.Subset( 65 | dataset=train_set, 66 | indices=forget_indices 67 | ) 68 | forget_loader = torch.utils.data.DataLoader( 69 | dataset=forget_subset, 70 | batch_size=request.batch_size, 71 | shuffle=True 72 | ) 73 | 74 | criterion = nn.CrossEntropyLoss() 75 | optimizer = optim.SGD( 76 | params=model_after.parameters(), 77 | lr=request.learning_rate, 78 | momentum=MOMENTUM, 79 | weight_decay=WEIGHT_DECAY 80 | ) 81 | scheduler = optim.lr_scheduler.MultiStepLR( 82 | optimizer=optimizer, 83 | milestones=DECREASING_LR, 84 | gamma=0.2 85 | ) 86 | 87 | unlearning_RL_thread = UnlearningRLThread( 88 | request=request, 89 | status=status, 90 | model_before=model_before, 91 | model_after=model_after, 92 | criterion=criterion, 93 | optimizer=optimizer, 94 | scheduler=scheduler, 95 | 96 | retain_loader=retain_loader, 97 | forget_loader=forget_loader, 98 | train_loader=train_loader, 99 | test_loader=test_loader, 100 | train_set=train_set, 101 | test_set=test_set, 102 | device=device, 103 | base_weights_path=base_weights_path 104 | ) 105 | 106 | unlearning_RL_thread.start() 107 | 108 | # thread start 109 | while unlearning_RL_thread.is_alive(): 110 | await asyncio.sleep(0.1) 111 | if status.cancel_requested: 112 | unlearning_RL_thread.stop() 113 | print("Cancellation requested, stopping the unlearning process...") 114 | 115 | status.is_unlearning = False 116 | 117 | # thread end 118 | if unlearning_RL_thread.exception: 119 | print(f"An error occurred 
during RL unlearning: {str(unlearning_RL_thread.exception)}") 120 | elif status.cancel_requested: 121 | print("Unlearning process was cancelled.") 122 | else: 123 | print("Unlearning process completed successfully.") 124 | 125 | return status 126 | 127 | async def run_unlearning_RL(request, status, base_weights_path): 128 | try: 129 | status.is_unlearning = True 130 | status.progress = "Unlearning" 131 | updated_status = await unlearning_RL(request, status, base_weights_path) 132 | return updated_status 133 | finally: 134 | status.cancel_requested = False 135 | status.progress = "Completed" 136 | status.is_unlearning = False  # unlearning_RL clears this only after its worker loop; reset it here too so an earlier exception cannot leave it stuck at True -------------------------------------------------------------------------------- /backend/app/services/unlearn_custom.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | 6 | from app.threads import UnlearningCustomThread 7 | from app.utils.helpers import set_seed 8 | from app.utils.data_loader import get_data_loaders 9 | from app.models import get_resnet18 10 | from app.config import UNLEARN_SEED 11 | 12 | 13 | async def unlearning_custom(forget_class, status, weights_path, base_weights): 14 | print(f"Starting custom unlearning inference for class {forget_class}...") 15 | set_seed(UNLEARN_SEED) 16 | ( 17 | train_loader, 18 | test_loader, 19 | train_set, 20 | test_set 21 | ) = get_data_loaders( 22 | batch_size=1000, 23 | augmentation=False 24 | ) 25 | 26 | criterion = nn.CrossEntropyLoss() 27 | device = torch.device("cuda" if torch.cuda.is_available() 28 | else "mps" if torch.backends.mps.is_available() 29 | else "cpu") 30 | model_before = get_resnet18().to(device) 31 | model_before.load_state_dict(torch.load(f"unlearned_models/{forget_class}/000{forget_class}.pth", map_location=device)) 32 | model = get_resnet18().to(device) 33 | model.load_state_dict(torch.load(weights_path, map_location=device)) 34 | 35 | unlearning_thread = UnlearningCustomThread( 36 |
forget_class=forget_class, 37 | status=status, 38 | model_before=model_before, 39 | model=model, 40 | train_loader=train_loader, 41 | test_loader=test_loader, 42 | train_set=train_set, 43 | test_set=test_set, 44 | criterion=criterion, 45 | device=device, 46 | base_weights=base_weights 47 | ) 48 | unlearning_thread.start() 49 | print("unlearning started") 50 | # thread start 51 | while unlearning_thread.is_alive(): 52 | await asyncio.sleep(0.1) 53 | if status.cancel_requested: 54 | unlearning_thread.stop() 55 | print("Cancellation requested, stopping the unlearning process...") 56 | 57 | status.is_unlearning = False 58 | 59 | if unlearning_thread.is_alive(): 60 | print("Warning: Unlearning thread did not stop within the timeout period.") 61 | 62 | # thread end 63 | if unlearning_thread.exception: 64 | print(f"An error occurred during custom unlearning: {str(unlearning_thread.exception)}") 65 | elif status.cancel_requested: 66 | print("Unlearning process was cancelled.") 67 | else: 68 | print("Unlearning process completed successfully.") 69 | 70 | return status 71 | 72 | async def run_unlearning_custom(forget_class, status, weights_path, base_weights): 73 | try: 74 | status.is_unlearning = True 75 | updated_status = await unlearning_custom(forget_class, status, weights_path, base_weights) 76 | return updated_status 77 | finally: 78 | status.is_unlearning = False 79 | status.cancel_requested = False 80 | status.progress = "Completed" 81 | os.remove(weights_path) -------------------------------------------------------------------------------- /backend/app/services/unlearn_retrain.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | 6 | from app.threads import UnlearningRetrainThread 7 | from app.models import get_resnet18 8 | from app.utils.helpers import set_seed 9 | from app.utils.data_loader import get_data_loaders 10 | from 
app.utils.visualization import ( 11 | compute_umap_embedding, 12 | ) 13 | from app.utils.evaluation import ( 14 | get_layer_activations_and_predictions, 15 | ) 16 | from app.config import ( 17 | MOMENTUM, 18 | WEIGHT_DECAY, 19 | UNLEARN_SEED 20 | ) 21 | 22 | async def unlearning_retrain(request, status): 23 | print( 24 | f"Starting unlearning for class {request.forget_class} " 25 | f"with {request.epochs} epochs..." 26 | ) 27 | set_seed(UNLEARN_SEED) 28 | device = torch.device( 29 | "cuda" if torch.cuda.is_available() 30 | else "mps" if torch.backends.mps.is_available() 31 | else "cpu" 32 | ) 33 | 34 | ( 35 | train_loader, 36 | test_loader, 37 | train_set, 38 | test_set 39 | ) = get_data_loaders( 40 | batch_size=request.batch_size, 41 | augmentation=True 42 | ) 43 | 44 | # Create dataset excluding the forget class 45 | indices = [ 46 | i for i, (_, label) in enumerate(train_set) 47 | if label != request.forget_class 48 | ] 49 | subset = torch.utils.data.Subset(train_set, indices) 50 | unlearning_loader = torch.utils.data.DataLoader( 51 | dataset=subset, 52 | batch_size=request.batch_size, 53 | shuffle=True 54 | ) 55 | 56 | model = get_resnet18().to(device) 57 | criterion = nn.CrossEntropyLoss() 58 | optimizer = optim.SGD( 59 | model.parameters(), 60 | lr=request.learning_rate, 61 | momentum=MOMENTUM, 62 | weight_decay=WEIGHT_DECAY, 63 | nesterov=True 64 | ) 65 | scheduler = optim.lr_scheduler.CosineAnnealingLR( 66 | optimizer=optimizer, 67 | T_max=request.epochs, 68 | ) 69 | 70 | status.progress = 0 71 | status.forget_class = request.forget_class 72 | 73 | unlearning_thread = UnlearningRetrainThread( 74 | model=model, 75 | unlearning_loader=unlearning_loader, 76 | full_train_loader=train_loader, 77 | test_loader=test_loader, 78 | criterion=criterion, 79 | optimizer=optimizer, 80 | scheduler=scheduler, 81 | device=device, 82 | epochs=request.epochs, 83 | status=status, 84 | model_name="resnet18", 85 | dataset_name=f"CIFAR10_without_class_{request.forget_class}", 86 
| learning_rate=request.learning_rate, 87 | forget_class=request.forget_class 88 | ) 89 | unlearning_thread.start() 90 | 91 | while unlearning_thread.is_alive(): 92 | await asyncio.sleep(0.5) 93 | if status.cancel_requested: 94 | unlearning_thread.stop() 95 | print("Cancel requested. Stopping unlearning thread...") 96 | break 97 | 98 | if unlearning_thread.exception: 99 | print( 100 | f"An error occurred during Retrain unlearning: " 101 | f"{str(unlearning_thread.exception)}" 102 | ) 103 | return status 104 | 105 | return status 106 | 107 | async def run_unlearning_retrain(request, status): 108 | try: 109 | status.is_unlearning = True 110 | status.progress = "Unlearning" 111 | updated_status = await unlearning_retrain(request, status) 112 | return updated_status 113 | finally: 114 | status.is_unlearning = False 115 | status.cancel_requested = False 116 | status.progress = "Completed" -------------------------------------------------------------------------------- /backend/app/threads/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Thread package for managing model training and unlearning operations. 3 | 4 | This package provides thread classes that handle the actual computation for training 5 | and various unlearning methods. Each thread runs independently and provides status tracking. 6 | 7 | Available threads: 8 | TrainingThread: Handles model training operations 9 | UnlearningGAThread: Unlearning using gradient ascent method 10 | UnlearningRLThread: Unlearning using random labeling method 11 | UnlearningFTThread: Unlearning using fine-tuning method 12 | UnlearningRetrainThread: Unlearning by retraining from scratch 13 | UnlearningCustomThread: Custom unlearning method for inference 14 | 15 | Each thread class follows a similar pattern: 16 | 1. Inherits from threading.Thread 17 | 2. Takes model, data, and optimization components as input 18 | 3. Implements run() method for the main computation loop 19 | 4. 
Provides methods for status updates and graceful termination 20 | """ 21 | 22 | from .train_thread import TrainingThread 23 | from .unlearn_GA_thread import UnlearningGAThread 24 | from .unlearn_RL_thread import UnlearningRLThread 25 | from .unlearn_FT_thread import UnlearningFTThread 26 | from .unlearn_retrain_thread import UnlearningRetrainThread 27 | from .unlearn_custom_thread import UnlearningCustomThread 28 | 29 | __all__ = ['TrainingThread', 'UnlearningGAThread', 'UnlearningRLThread', 'UnlearningFTThread', 'UnlearningRetrainThread', 'UnlearningCustomThread'] -------------------------------------------------------------------------------- /backend/app/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This package contains utility functions and modules. 3 | """ 4 | 5 | from .data_loader import ( 6 | load_cifar10_data, 7 | get_data_loaders 8 | ) 9 | from .evaluation import ( 10 | get_layer_activations_and_predictions, 11 | evaluate_model, 12 | evaluate_model_with_distributions, 13 | calculate_cka_similarity 14 | ) 15 | from .helpers import ( 16 | set_seed, 17 | save_model, 18 | format_distribution, 19 | compress_prob_array 20 | ) 21 | from .visualization import compute_umap_embedding 22 | 23 | 24 | __all__ = [ 25 | 'load_cifar10_data', 'get_data_loaders', 26 | 'get_layer_activations_and_predictions', 'evaluate_model', 27 | 'evaluate_model_with_distributions', 'calculate_cka_similarity', 'set_seed', 'save_model', 28 | 'compute_umap_embedding', 'format_distribution', 'compress_prob_array' 29 | ] -------------------------------------------------------------------------------- /backend/app/utils/data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torchvision import datasets, transforms 3 | from torch.utils.data import DataLoader 4 | from app.config import UNLEARN_SEED 5 | 6 | def load_cifar10_data(): 7 | """Load CIFAR-10 
training data with automatic download""" 8 | # Use torchvision's CIFAR10 dataset to handle downloading 9 | train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=None) 10 | 11 | # Convert to numpy array in the format we need 12 | x_train = train_set.data # This is already in (N, 32, 32, 3) format 13 | y_train = np.array(train_set.targets) 14 | 15 | return x_train, y_train 16 | 17 | def get_data_loaders(batch_size, augmentation=False): 18 | base_transforms = [ 19 | transforms.ToTensor(), 20 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 21 | ] 22 | 23 | train_transform = transforms.Compose( 24 | ([ 25 | transforms.RandomCrop(32, padding=4), 26 | transforms.RandomHorizontalFlip(), 27 | ] if augmentation else []) + base_transforms 28 | ) 29 | 30 | test_transform = transforms.Compose(base_transforms) 31 | 32 | train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform) 33 | test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform) 34 | train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0) 35 | test_loader = DataLoader(test_set, batch_size=100, shuffle=False, num_workers=0) 36 | print("loaded loaders") 37 | return train_loader, test_loader, train_set, test_set 38 | 39 | def get_fixed_umap_indices(total_samples=2000, seed=UNLEARN_SEED): 40 | import torch 41 | _, y_train = load_cifar10_data() 42 | num_classes = 10 43 | targets_tensor = torch.tensor(y_train) 44 | 45 | samples_per_class = total_samples // num_classes 46 | generator = torch.Generator() 47 | generator.manual_seed(seed) 48 | 49 | indices_dict = {} 50 | for i in range(num_classes): 51 | class_indices = (targets_tensor == i).nonzero(as_tuple=False).squeeze() 52 | perm = torch.randperm(len(class_indices), generator=generator) 53 | indices_dict[i] = class_indices[perm[:samples_per_class]].tolist() 54 | 55 | return indices_dict 
-------------------------------------------------------------------------------- /backend/app/utils/visualization.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import os 3 | import time 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | from umap import UMAP 8 | 9 | from app.config import ( 10 | UMAP_N_NEIGHBORS, 11 | UMAP_MIN_DIST, 12 | UMAP_INIT, 13 | UMAP_N_JOBS 14 | ) 15 | 16 | async def compute_umap_embedding( 17 | activation, 18 | labels, 19 | forget_class=-1, 20 | forget_labels=None, 21 | save_dir='umap_visualizations' 22 | ): 23 | umap_embedding = [] 24 | 25 | class_names = [ 26 | 'airplane', 27 | 'automobile', 28 | 'bird', 29 | 'cat', 30 | 'deer', 31 | 'dog', 32 | 'frog', 33 | 'horse', 34 | 'ship', 35 | 'truck' 36 | ] 37 | if(forget_class != -1): 38 | class_names[forget_class] += " (forget)" 39 | 40 | colors = plt.cm.tab10(np.linspace(0, 1, 10)) 41 | if not os.path.exists(save_dir): 42 | os.makedirs(save_dir) 43 | 44 | umap = UMAP(n_components=2, 45 | n_neighbors=UMAP_N_NEIGHBORS, 46 | min_dist=UMAP_MIN_DIST, 47 | init=UMAP_INIT, 48 | n_jobs=UMAP_N_JOBS) 49 | print(f"UMAP start!") 50 | start_time = time.time() 51 | embedding = umap.fit_transform(activation) 52 | print(f"UMAP done! 
Time taken: {time.time() - start_time:.2f}s") 53 | 54 | umap_embedding = embedding 55 | plt.figure(figsize=(12, 11)) 56 | 57 | # Plot non-forget points 58 | if forget_labels is not None: 59 | non_forget_mask = ~forget_labels 60 | scatter = plt.scatter( 61 | embedding[non_forget_mask, 0], 62 | embedding[non_forget_mask, 1], 63 | c=labels[non_forget_mask], 64 | cmap='tab10', 65 | s=20, 66 | alpha=0.7, 67 | vmin=0, 68 | vmax=9 69 | ) 70 | 71 | # Plot forget points with 'x' marker 72 | forget_mask = forget_labels 73 | plt.scatter( 74 | embedding[forget_mask, 0], 75 | embedding[forget_mask, 1], 76 | c=labels[forget_mask], 77 | cmap='tab10', 78 | s=50, 79 | alpha=0.7, 80 | marker='x', 81 | linewidths=2.5, 82 | vmin=0, 83 | vmax=9 84 | ) 85 | else: 86 | scatter = plt.scatter( 87 | embedding[:, 0], 88 | embedding[:, 1], 89 | c=labels, 90 | cmap='tab10', 91 | s=20, 92 | alpha=0.7 93 | ) 94 | 95 | legend_elements = [ 96 | plt.Line2D( 97 | [0], [0], 98 | marker='o', 99 | color='w', 100 | label=class_names[i], 101 | markerfacecolor=colors[i], 102 | markersize=10 103 | ) for i in range(10) 104 | ] 105 | 106 | if forget_labels is not None: 107 | legend_elements.append( 108 | plt.Line2D( 109 | [0], [0], 110 | marker='x', 111 | color='k', 112 | label='Forget Data', 113 | markerfacecolor='k', 114 | markersize=10, 115 | markeredgewidth=3, 116 | linestyle='None', 117 | markeredgecolor='k' 118 | ) 119 | ) 120 | 121 | plt.legend( 122 | handles=legend_elements, 123 | title="Predicted Classes", 124 | loc='upper right', 125 | bbox_to_anchor=(0.99, 0.99), 126 | fontsize='x-large', 127 | title_fontsize='x-large' 128 | ) 129 | plt.axis('off') 130 | plt.text( 131 | 0.5, -0.05, f'Last Layer', 132 | fontsize=24, 133 | ha='center', 134 | va='bottom', 135 | transform=plt.gca().transAxes 136 | ) 137 | plt.tight_layout() 138 | 139 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 140 | filename = f'{timestamp}_umap_layer_last.svg' 141 | filepath = os.path.join(save_dir, filename) 142 | 143 | 
plt.savefig( 144 | filepath, 145 | format='svg', 146 | dpi=300, 147 | bbox_inches='tight', 148 | pad_inches=0.1 149 | ) 150 | 151 | print("\nUMAP embeddings computation and saving completed!") 152 | return umap_embedding -------------------------------------------------------------------------------- /backend/main.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | from contextlib import asynccontextmanager 3 | from fastapi.middleware.cors import CORSMiddleware 4 | from app.routers import train, unlearn, data 5 | from app.utils.helpers import download_weights_from_hub 6 | 7 | # Constants 8 | ALLOW_ORIGINS = ["*"] # TODO: Update URL after deployment 9 | 10 | @asynccontextmanager 11 | async def lifespan(app: FastAPI): 12 | download_weights_from_hub() 13 | yield 14 | 15 | def setup_middleware(app: FastAPI) -> None: 16 | app.add_middleware( 17 | CORSMiddleware, 18 | allow_origins=ALLOW_ORIGINS, 19 | allow_credentials=True, 20 | allow_methods=["*"], 21 | allow_headers=["*"], 22 | ) 23 | 24 | def register_routers(app: FastAPI) -> None: 25 | app.include_router(train.router) 26 | app.include_router(unlearn.router) 27 | app.include_router(data.router) 28 | 29 | def create_app() -> FastAPI: 30 | app = FastAPI(lifespan=lifespan) 31 | setup_middleware(app) 32 | register_routers(app) 33 | return app 34 | 35 | # Create application instance after all definitions 36 | app = create_app() 37 | 38 | @app.get("/") 39 | async def root(): 40 | return {"message": "Welcome to the MU Dashboard API"} -------------------------------------------------------------------------------- /backend/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [tool.hatch.envs.default] 6 | python = "3.10" 7 | 8 | [project] 9 | name = "Machine-Unlearning-Comparator" 10 | version = "0.1.0" 11 | description = 
"Machine Unlearning Comparator" 12 | readme = "README.md" 13 | requires-python = ">=3.8, <3.12" 14 | license = "MIT" 15 | keywords = [] 16 | authors = [ 17 | { name = "gnueaj", email = "dlwodnd00@gmail.com" }, 18 | ] 19 | classifiers = [ 20 | "Development Status :: 4 - Beta", 21 | "Programming Language :: Python", 22 | "Programming Language :: Python :: 3.8", 23 | "Programming Language :: Python :: 3.9", 24 | "Programming Language :: Python :: 3.10", 25 | "Programming Language :: Python :: 3.11", 26 | "Programming Language :: Python :: Implementation :: CPython", 27 | "Programming Language :: Python :: Implementation :: PyPy", 28 | ] 29 | dependencies = [ 30 | "fastapi==0.109.0", # 31 | "uvicorn==0.27.0", # 32 | "torch==2.1.2", # (CUDA 12.1) 33 | "torchvision==0.16.2", # (torch 2.1.2) 34 | "numpy==1.24.3", # 35 | "umap-learn==0.5.5", # 36 | "scikit-learn==1.3.0", # 37 | "packaging>=21.0", 38 | "matplotlib", 39 | "python-multipart", 40 | "seaborn", 41 | "torch_cka", 42 | "huggingface_hub", 43 | ] 44 | 45 | 46 | [project.urls] 47 | Documentation = "https://github.com/gnueaj/Machine-Unlearning-Comparator#readme" 48 | Issues = "https://github.com/gnueaj/Machine-Unlearning-Comparator/issues" 49 | Source = "https://github.com/gnueaj/Machine-Unlearning-Comparator" 50 | 51 | [tool.hatch.envs.types] 52 | extra-dependencies = [ 53 | "mypy>=1.0.0", 54 | ] 55 | [tool.hatch.envs.types.scripts] 56 | cov = "pytest --cov-report=term-missing --cov-config=pyproject.toml --cov=src/fastapi_resnet_cifar10 --cov=tests {args}" 57 | no-cov = "cov --no-cov {args}" 58 | start = "uvicorn main:app --reload" 59 | 60 | [tool.hatch.envs.default.scripts] 61 | start = "uvicorn main:app --host 0.0.0.0 --port 8000 --reload" 62 | # start = "uvicorn main:app --reload" 63 | 64 | [tool.hatch.build.targets.wheel] 65 | packages = ["app"] 66 | 67 | [tool.coverage.run] 68 | branch = true 69 | parallel = true 70 | 71 | [tool.coverage.report] 72 | exclude_lines = [ 73 | "no cov", 74 | "if __name__ == 
.__main__.:", 75 | "if TYPE_CHECKING:", 76 | ] 77 | -------------------------------------------------------------------------------- /demo.mp4: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:0a1cb36f049cf96027e9abf76baf5c6b7fac146177d86e56a24149c5fad99328 3 | size 248814650 4 | -------------------------------------------------------------------------------- /frontend/.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 2 | 3 | # dependencies 4 | /node_modules 5 | /.pnp 6 | .pnp.js 7 | 8 | # testing 9 | /coverage 10 | 11 | # production 12 | /build 13 | 14 | # misc 15 | .DS_Store 16 | .env.local 17 | .env.development.local 18 | .env.test.local 19 | .env.production.local 20 | 21 | npm-debug.log* 22 | yarn-debug.log* 23 | yarn-error.log* 24 | -------------------------------------------------------------------------------- /frontend/README.md: -------------------------------------------------------------------------------- 1 | # Machine Unlearning Dashboard 2 | It's MU (Machine Unlearning) Dashboard that provides training, evaluation of unlearning, and attack simulation. 3 | 4 | ## How to start the frontend 5 | 6 | * Install [pnpm](https://pnpm.io/installation) to set up the development environment. 7 | 8 | **1. Clone the repository and move to the frontend directory** 9 | ```shell 10 | git clone https://github.com/gnueaj/mu-dashboard.git && cd mu-dashboard/frontend 11 | ``` 12 | 13 | **2. Download necessary modules using `pnpm`** 14 | ```shell 15 | pnpm install 16 | ``` 17 | 18 | **3. 
Run the project** 19 | ```shell 20 | pnpm dev # or: pnpm start 21 | ``` 22 | -------------------------------------------------------------------------------- /frontend/components.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://ui.shadcn.com/schema.json", 3 | "style": "default", 4 | "rsc": false, 5 | "tsx": true, 6 | "tailwind": { 7 | "config": "tailwind.config.js", 8 | "css": "src/app/index.css", 9 | "baseColor": "slate", 10 | "cssVariables": true 11 | }, 12 | "aliases": { 13 | "components": "@/src/components", 14 | "utils": "@/util.ts" 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /frontend/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "my-app", 3 | "version": "0.1.0", 4 | "private": true, 5 | "dependencies": { 6 | "@fortawesome/fontawesome-svg-core": "^6.5.2", 7 | "@fortawesome/free-regular-svg-icons": "^6.5.2", 8 | "@fortawesome/free-solid-svg-icons": "^6.5.2", 9 | "@fortawesome/react-fontawesome": "^0.2.2", 10 | "@radix-ui/react-checkbox": "^1.1.4", 11 | "@radix-ui/react-context-menu": "^2.2.2", 12 | "@radix-ui/react-dialog": "^1.1.2", 13 | "@radix-ui/react-hover-card": "^1.1.2", 14 | "@radix-ui/react-label": "^2.1.0", 15 | "@radix-ui/react-radio-group": "^1.2.0", 16 | "@radix-ui/react-scroll-area": "^1.1.0", 17 | "@radix-ui/react-select": "^2.1.1", 18 | "@radix-ui/react-separator": "^1.1.0", 19 | "@radix-ui/react-slider": "^1.2.0", 20 | "@radix-ui/react-slot": "^1.1.0", 21 | "@radix-ui/react-tabs": "^1.1.3", 22 | "@tanstack/react-table": "^8.20.5", 23 | "@testing-library/jest-dom": "^5.17.0", 24 | "@testing-library/react": "^13.4.0", 25 | "@testing-library/user-event": "^13.5.0", 26 | "@types/d3": "^7.4.3", 27 | "@types/jest": "^27.5.2", 28 | "@types/node": "^16.18.101", 29 | "@types/react": "^18.3.3", 30 | "@types/react-dom": "^18.3.0", 31 | "class-variance-authority":
"^0.7.0", 32 | "clsx": "^2.1.1", 33 | "d3": "^7.9.0", 34 | "lucide-react": "^0.438.0", 35 | "react": "^18.3.1", 36 | "react-dom": "^18.3.1", 37 | "react-icons": "^5.3.0", 38 | "react-scripts": "5.0.1", 39 | "recharts": "^2.12.7", 40 | "tailwind-merge": "^2.5.2", 41 | "tailwindcss": "^3.4.10", 42 | "tailwindcss-animate": "^1.0.7", 43 | "typescript": "^4.9.5", 44 | "web-vitals": "^2.1.4", 45 | "zustand": "^5.0.3" 46 | }, 47 | "scripts": { 48 | "start": "react-scripts start", 49 | "build": "react-scripts build", 50 | "test": "react-scripts test", 51 | "eject": "react-scripts eject", 52 | "dev": "pnpm start" 53 | }, 54 | "eslintConfig": { 55 | "extends": [ 56 | "react-app", 57 | "react-app/jest" 58 | ] 59 | }, 60 | "browserslist": { 61 | "production": [ 62 | ">0.2%", 63 | "not dead", 64 | "not op_mini all" 65 | ], 66 | "development": [ 67 | "last 1 chrome version", 68 | "last 1 firefox version", 69 | "last 1 safari version" 70 | ] 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /frontend/public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 12 | 13 | 14 | 15 | Machine Unlearning Comparator 16 | 17 | 18 |
19 | 20 | 21 | -------------------------------------------------------------------------------- /frontend/public/logo.svg: -------------------------------------------------------------------------------- 1 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 21 | 25 | -------------------------------------------------------------------------------- /frontend/public/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "short_name": "React App", 3 | "name": "Create React App Sample", 4 | "icons": [ 5 | { 6 | "src": "favicon.ico", 7 | "sizes": "64x64 32x32 24x24 16x16", 8 | "type": "image/x-icon" 9 | }, 10 | { 11 | "src": "logo192.png", 12 | "type": "image/png", 13 | "sizes": "192x192" 14 | }, 15 | { 16 | "src": "logo512.png", 17 | "type": "image/png", 18 | "sizes": "512x512" 19 | } 20 | ], 21 | "start_url": ".", 22 | "display": "standalone", 23 | "theme_color": "#000000", 24 | "background_color": "#ffffff" 25 | } 26 | -------------------------------------------------------------------------------- /frontend/public/robots.txt: -------------------------------------------------------------------------------- 1 | # https://www.robotstxt.org/robotstxt.html 2 | User-agent: * 3 | Disallow: 4 | -------------------------------------------------------------------------------- /frontend/src/app/App.tsx: -------------------------------------------------------------------------------- 1 | import { useState, useEffect, useCallback } from "react"; 2 | 3 | import Header from "../components/Header/Header"; 4 | import ModelScreening from "../views/ModelScreening/ModelScreening"; 5 | import Core from "../views/Core/Core"; 6 | import MetricsView from "../views/MetricsView/MetricsView"; 7 | import { useExperimentsStore } from "../stores/experimentsStore"; 8 | import { calculateZoom } from "../utils/util"; 9 | 10 | export const CONFIG = { 11 | TOTAL_WIDTH: 1805, 12 | EXPERIMENTS_WIDTH: 1032, 13 | CORE_WIDTH: 1312, 14 | get PROGRESS_WIDTH() 
{ 15 | return this.CORE_WIDTH - this.EXPERIMENTS_WIDTH - 1; 16 | }, 17 | get ANALYSIS_VIEW_WIDTH() { 18 | return this.TOTAL_WIDTH - this.CORE_WIDTH; 19 | }, 20 | 21 | EXPERIMENTS_PROGRESS_HEIGHT: 256, 22 | CORE_HEIGHT: 820, 23 | get TOTAL_HEIGHT() { 24 | return this.CORE_HEIGHT + this.EXPERIMENTS_PROGRESS_HEIGHT; 25 | }, 26 | } as const; 27 | 28 | export default function App() { 29 | const { isExperimentLoading } = useExperimentsStore(); 30 | 31 | const [isPageLoading, setIsPageLoading] = useState(true); 32 | const [zoom, setZoom] = useState(1); 33 | 34 | const handleResize = useCallback(() => { 35 | setZoom(calculateZoom()); 36 | }, []); 37 | 38 | useEffect(() => { 39 | setIsPageLoading(false); 40 | 41 | window.addEventListener("resize", handleResize); 42 | handleResize(); 43 | 44 | return () => window.removeEventListener("resize", handleResize); 45 | }, [handleResize]); 46 | 47 | if (isPageLoading) return <>; 48 | 49 | return ( 50 |
51 |
52 | {!isExperimentLoading && ( 53 |
54 |
55 | 56 | 57 |
58 | 59 |
60 | )} 61 |
62 | ); 63 | } 64 | -------------------------------------------------------------------------------- /frontend/src/app/index.css: -------------------------------------------------------------------------------- 1 | @tailwind base; 2 | @tailwind components; 3 | @tailwind utilities; 4 | 5 | @layer base { 6 | :root { 7 | --background: 0 0% 100%; 8 | --foreground: 222.2 47.4% 11.2%; 9 | 10 | --muted: 210 40% 96.1%; 11 | --muted-foreground: 215.4 16.3% 46.9%; 12 | 13 | --popover: 0 0% 100%; 14 | --popover-foreground: 222.2 47.4% 11.2%; 15 | 16 | --border: 214.3 31.8% 91.4%; 17 | --input: 214.3 31.8% 91.4%; 18 | 19 | --card: 0 0% 100%; 20 | --card-foreground: 222.2 47.4% 11.2%; 21 | 22 | --primary: 222.2 47.4% 11.2%; 23 | --primary-foreground: 210 40% 98%; 24 | 25 | --secondary: 210 40% 96.1%; 26 | --secondary-foreground: 222.2 47.4% 11.2%; 27 | 28 | --accent: 210 40% 96.1%; 29 | --accent-foreground: 222.2 47.4% 11.2%; 30 | 31 | --destructive: 0 100% 50%; 32 | --destructive-foreground: 210 40% 98%; 33 | 34 | --ring: 215 20.2% 65.1%; 35 | 36 | --radius: 0.5rem; 37 | } 38 | 39 | .dark { 40 | --background: 224 71% 4%; 41 | --foreground: 213 31% 91%; 42 | 43 | --muted: 223 47% 11%; 44 | --muted-foreground: 215.4 16.3% 56.9%; 45 | 46 | --accent: 216 34% 17%; 47 | --accent-foreground: 210 40% 98%; 48 | 49 | --popover: 224 71% 4%; 50 | --popover-foreground: 215 20.2% 65.1%; 51 | 52 | --border: 216 34% 17%; 53 | --input: 216 34% 17%; 54 | 55 | --card: 224 71% 4%; 56 | --card-foreground: 213 31% 91%; 57 | 58 | --primary: 210 40% 98%; 59 | --primary-foreground: 222.2 47.4% 1.2%; 60 | 61 | --secondary: 222.2 47.4% 11.2%; 62 | --secondary-foreground: 210 40% 98%; 63 | 64 | --destructive: 0 63% 31%; 65 | --destructive-foreground: 210 40% 98%; 66 | 67 | --ring: 216 34% 17%; 68 | 69 | --radius: 0.5rem; 70 | } 71 | } 72 | 73 | @layer base { 74 | :root { 75 | --chart-1: 12 76% 61%; 76 | --chart-2: 173 58% 39%; 77 | --chart-3: 197 37% 24%; 78 | --chart-4: 43 74% 66%; 79 | --chart-5: 27 
87% 67%; 80 | } 81 | 82 | .dark { 83 | --chart-1: 220 70% 50%; 84 | --chart-2: 160 60% 45%; 85 | --chart-3: 30 80% 55%; 86 | --chart-4: 280 65% 60%; 87 | --chart-5: 340 75% 55%; 88 | } 89 | } 90 | 91 | * { 92 | box-sizing: border-box; 93 | margin: 0; 94 | padding: 0; 95 | } 96 | 97 | body { 98 | display: flex; 99 | justify-content: flex-start; 100 | align-items: flex-start; 101 | font-family: "Roboto Condensed", "pretendard", -apple-system, "Roboto", 102 | sans-serif; 103 | -webkit-font-smoothing: antialiased; 104 | -moz-osx-font-smoothing: grayscale; 105 | } 106 | 107 | input[type="number"]::-webkit-outer-spin-button, 108 | input[type="number"]::-webkit-inner-spin-button { 109 | -webkit-appearance: none; 110 | margin: 0; 111 | } 112 | input:disabled, 113 | select:disabled { 114 | cursor: not-allowed; 115 | } 116 | 117 | form { 118 | width: 100%; 119 | } 120 | -------------------------------------------------------------------------------- /frontend/src/app/react-app-env.d.ts: -------------------------------------------------------------------------------- 1 | /// 2 | -------------------------------------------------------------------------------- /frontend/src/components/Core/Embeddings/ConnectionLine.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | 3 | import { calculateZoom } from "../../../utils/util"; 4 | 5 | interface ConnectionLineProps { 6 | from: { x: number; y: number } | null; 7 | to: { x: number; y: number } | null; 8 | } 9 | 10 | export default function ConnectionLine({ from, to }: ConnectionLineProps) { 11 | if (!from || !to) { 12 | return null; 13 | } 14 | 15 | const zoom = calculateZoom(); 16 | const scrollbarWidth = 17 | window.innerWidth - document.documentElement.clientWidth; 18 | const adjustedZoom = 19 | ((window.innerWidth - scrollbarWidth) / window.innerWidth) * zoom; 20 | 21 | const lineStyle: React.CSSProperties = { 22 | position: "fixed", 23 | left: 0, 24 | top: 0, 25 
| pointerEvents: "none", 26 | }; 27 | 28 | const x1 = from.x / adjustedZoom; 29 | const y1 = from.y / adjustedZoom; 30 | const x2 = to.x / adjustedZoom; 31 | const y2 = to.y / adjustedZoom; 32 | 33 | const length = Math.hypot(x2 - x1, y2 - y1); 34 | const angle = (Math.atan2(y2 - y1, x2 - x1) * 180) / Math.PI; 35 | 36 | const shortenedLength = length - 8; 37 | 38 | const offsetX = 4 * Math.cos((angle * Math.PI) / 180); 39 | const offsetY = 4 * Math.sin((angle * Math.PI) / 180); 40 | 41 | const linePositionStyle: React.CSSProperties = { 42 | position: "absolute", 43 | transformOrigin: "0 0", 44 | transform: `translate(${x1 + offsetX}px, ${ 45 | y1 + offsetY 46 | }px) rotate(${angle}deg)`, 47 | width: `${shortenedLength}px`, 48 | height: "2px", 49 | backgroundColor: "black", 50 | }; 51 | 52 | return ( 53 |
54 |
55 |
56 | ); 57 | } 58 | -------------------------------------------------------------------------------- /frontend/src/components/Core/Embeddings/ConnectionLineWrapper.tsx: -------------------------------------------------------------------------------- 1 | import React, { memo, useRef, useEffect, useState } from "react"; 2 | 3 | import ConnectionLine from "./ConnectionLine"; 4 | import { Coordinate } from "../../../types/embeddings"; 5 | 6 | type Position = { 7 | from: Coordinate | null; 8 | to: Coordinate | null; 9 | }; 10 | 11 | interface Props { 12 | positionRef: React.MutableRefObject; 13 | } 14 | 15 | const ConnectionLineWrapper = memo(({ positionRef }: Props) => { 16 | const [, setUpdateKey] = useState(0); 17 | const prevPositionRef = useRef({ from: null, to: null }); 18 | 19 | useEffect(() => { 20 | const hasPositionChanged = () => { 21 | const current = positionRef.current; 22 | const prev = prevPositionRef.current; 23 | 24 | if (!current.from !== !prev.from || !current.to !== !prev.to) return true; 25 | if (!current.from || !current.to) return false; 26 | if (!prev.from || !prev.to) return true; 27 | 28 | return ( 29 | current.from.x !== prev.from.x || 30 | current.from.y !== prev.from.y || 31 | current.to.x !== prev.to.x || 32 | current.to.y !== prev.to.y 33 | ); 34 | }; 35 | 36 | const intervalId = setInterval(() => { 37 | if (hasPositionChanged()) { 38 | prevPositionRef.current = { 39 | from: positionRef.current.from 40 | ? { ...positionRef.current.from } 41 | : null, 42 | to: positionRef.current.to ? 
{ ...positionRef.current.to } : null, 43 | }; 44 | setUpdateKey((prev) => prev + 1); 45 | } 46 | }, 16); 47 | 48 | return () => clearInterval(intervalId); 49 | }, [positionRef]); 50 | 51 | return ( 52 | 56 | ); 57 | }); 58 | 59 | export default ConnectionLineWrapper; 60 | -------------------------------------------------------------------------------- /frontend/src/components/Core/PrivacyAttack/AttackSuccessFailure.tsx: -------------------------------------------------------------------------------- 1 | import InstancePanel from "./InstancePanel"; 2 | import { Bin, Data, CategoryType, Image } from "../../../types/attack"; 3 | 4 | interface AttackSuccessFailureProps { 5 | mode: "A" | "B"; 6 | thresholdValue: number; 7 | hoveredId: number | null; 8 | data: Data; 9 | imageMap: Map; 10 | attackScore: number; 11 | setHoveredId: (val: number | null) => void; 12 | onElementClick: ( 13 | event: React.MouseEvent, 14 | elementData: Bin & { type: CategoryType } 15 | ) => void; 16 | } 17 | 18 | export default function AttackSuccessFailure({ 19 | mode, 20 | thresholdValue, 21 | hoveredId, 22 | data, 23 | imageMap, 24 | attackScore, 25 | setHoveredId, 26 | onElementClick, 27 | }: AttackSuccessFailureProps) { 28 | const forgettingQualityScore = 1 - attackScore; 29 | 30 | return ( 31 |
32 |

33 | Privacy Score ={" "} 34 | 35 | {forgettingQualityScore === 1 ? 1 : forgettingQualityScore.toFixed(3)} 36 | 37 |

38 |
39 | 49 | 59 |
60 |
61 | ); 62 | } 63 | -------------------------------------------------------------------------------- /frontend/src/components/Header/Header.tsx: -------------------------------------------------------------------------------- 1 | import { FileText } from "lucide-react"; 2 | 3 | import Tabs from "./Tabs"; 4 | import { Logo, GithubIcon } from "../UI/icons"; 5 | import { useBaseConfigStore } from "../../stores/baseConfigStore"; 6 | import { DATASETS, NEURAL_NETWORK_MODELS } from "../../constants/common"; 7 | import { 8 | Select, 9 | SelectContent, 10 | SelectItem, 11 | SelectTrigger, 12 | SelectValue, 13 | } from "../UI/select"; 14 | 15 | export default function Header() { 16 | const { dataset, setDataset, neuralNetworkModel, setNeuralNetworkModel } = 17 | useBaseConfigStore(); 18 | 19 | const handleGithubIconClick = () => { 20 | window.open( 21 | "https://github.com/gnueaj/Machine-Unlearning-Comparator", 22 | "_blank" 23 | ); 24 | }; 25 | 26 | const handleDatasetChange = (dataset: string) => { 27 | setDataset(dataset); 28 | }; 29 | 30 | const handleNeuralNetworkModelChange = (model: string) => { 31 | setNeuralNetworkModel(model); 32 | }; 33 | 34 | return ( 35 |
36 |
37 |
38 |
39 | 40 | 41 | Unlearning Comparator 42 | 43 |
44 |
45 |
46 | Architecture 47 | 63 |
64 |
65 | Dataset 66 | 82 |
83 |
84 | 85 |
86 |
87 |
88 | 89 | 93 |
94 |
95 | ); 96 | } 97 | -------------------------------------------------------------------------------- /frontend/src/components/Header/Tab.tsx: -------------------------------------------------------------------------------- 1 | import { MultiplicationSignIcon } from "../UI/icons"; 2 | import { useClasses } from "../../hooks/useClasses"; 3 | import { useForgetClassStore } from "../../stores/forgetClassStore"; 4 | import { cn } from "../../utils/util"; 5 | 6 | interface Props { 7 | setOpen: (open: boolean) => void; 8 | fetchAndSaveExperiments: (forgetClass: string) => Promise; 9 | } 10 | 11 | export default function Tab({ setOpen, fetchAndSaveExperiments }: Props) { 12 | const classes = useClasses(); 13 | const forgetClass = useForgetClassStore((state) => state.forgetClass); 14 | const saveForgetClass = useForgetClassStore((state) => state.saveForgetClass); 15 | const selectedForgetClasses = useForgetClassStore( 16 | (state) => state.selectedForgetClasses 17 | ); 18 | const deleteSelectedForgetClass = useForgetClassStore( 19 | (state) => state.deleteSelectedForgetClass 20 | ); 21 | 22 | const handleForgetClassChange = async (value: string) => { 23 | if (classes[forgetClass] !== value) { 24 | saveForgetClass(value); 25 | await fetchAndSaveExperiments(value); 26 | } 27 | }; 28 | 29 | const handleDeleteClick = async (targetClass: string) => { 30 | const firstSelectedForgetClass = selectedForgetClasses[0]; 31 | const secondSelectedForgetClass = selectedForgetClasses[1]; 32 | const targetSelectedForgetClassesIndex = selectedForgetClasses.indexOf( 33 | classes.indexOf(targetClass) 34 | ); 35 | 36 | deleteSelectedForgetClass(targetClass); 37 | 38 | if (targetClass === classes[forgetClass]) { 39 | if (selectedForgetClasses.length === 1) { 40 | saveForgetClass(-1); 41 | setOpen(true); 42 | } else { 43 | const autoSelectedForgetClass = 44 | targetSelectedForgetClassesIndex === 0 45 | ? 
classes[secondSelectedForgetClass] 46 | : classes[firstSelectedForgetClass]; 47 | saveForgetClass(autoSelectedForgetClass); 48 | await fetchAndSaveExperiments(autoSelectedForgetClass); 49 | } 50 | } 51 | }; 52 | 53 | return ( 54 | <> 55 | {selectedForgetClasses.map((selectedForgetClass, idx) => { 56 | const isSelectedForgetClass = selectedForgetClass === forgetClass; 57 | const forgetClassName = classes[selectedForgetClass]; 58 | 59 | return ( 60 |
61 |
handleForgetClassChange(forgetClassName)} 69 | > 70 | 76 | Forget: {forgetClassName} 77 | 78 |
79 | handleDeleteClick(forgetClassName)} 81 | className={cn( 82 | "w-3.5 h-3.5 p-[1px] cursor-pointer rounded-full absolute right-2.5 bg-transparent transition text-gray-500", 83 | isSelectedForgetClass 84 | ? "hover:bg-gray-300" 85 | : "hover:bg-gray-700" 86 | )} 87 | /> 88 | {isSelectedForgetClass && ( 89 |
90 | )} 91 |
92 | ); 93 | })} 94 | 95 | ); 96 | } 97 | -------------------------------------------------------------------------------- /frontend/src/components/Header/Tabs.tsx: -------------------------------------------------------------------------------- 1 | import { useState } from "react"; 2 | 3 | import Tab from "./Tab"; 4 | import ForgetClassTabPlusButton from "./TabPlusButton"; 5 | import { Experiment, Experiments } from "../../types/data"; 6 | import { useClasses } from "../../hooks/useClasses"; 7 | import { useExperimentsStore } from "../../stores/experimentsStore"; 8 | import { useForgetClassStore } from "../../stores/forgetClassStore"; 9 | import { fetchAllExperimentsData } from "../../utils/api/unlearning"; 10 | 11 | export default function Tabs() { 12 | const classes = useClasses(); 13 | const saveExperiments = useExperimentsStore((state) => state.saveExperiments); 14 | const selectedForgetClasses = useForgetClassStore( 15 | (state) => state.selectedForgetClasses 16 | ); 17 | const setIsExperimentsLoading = useExperimentsStore( 18 | (state) => state.setIsExperimentsLoading 19 | ); 20 | 21 | const hasNoSelectedForgetClass = selectedForgetClasses.length === 0; 22 | 23 | const [open, setOpen] = useState(hasNoSelectedForgetClass); 24 | 25 | const fetchAndSaveExperiments = async (forgetClass: string) => { 26 | const classIndex = classes.indexOf(forgetClass); 27 | setIsExperimentsLoading(true); 28 | try { 29 | const allData: Experiments = await fetchAllExperimentsData(classIndex); 30 | 31 | if ("detail" in allData) { 32 | saveExperiments({}); 33 | } else { 34 | Object.values(allData).forEach((experiment: Experiment) => { 35 | if (experiment && "points" in experiment) { 36 | delete experiment.points; 37 | } 38 | }); 39 | saveExperiments(allData); 40 | } 41 | } finally { 42 | setIsExperimentsLoading(false); 43 | } 44 | }; 45 | 46 | return ( 47 |
48 | 52 | 58 |
59 | ); 60 | } 61 | -------------------------------------------------------------------------------- /frontend/src/components/MetricsView/Predictions/BubbleMatrixLegend.tsx: -------------------------------------------------------------------------------- 1 | import * as d3 from "d3"; 2 | 3 | import { Arrow } from "../../UI/icons"; 4 | 5 | export default function BubbleChartLegend() { 6 | const colorScale = d3 7 | .scaleSequential((t) => d3.interpolateTurbo(0.05 + 0.95 * t)) 8 | .domain([0, 1]); 9 | const numStops = 100; 10 | const bubbleColorScale = Array.from({ length: numStops }, (_, i) => 11 | colorScale(i / (numStops - 1)) 12 | ); 13 | const gradient = `linear-gradient(to right, ${bubbleColorScale.join(", ")})`; 14 | 15 | return ( 16 |
17 |
18 |
19 |
20 |
21 |

22 | Small 23 | Proportion 24 |

25 | 26 |

27 | Large 28 | Proportion 29 |

30 |
31 | 32 |
33 |
34 |
38 |
39 | 0 40 |
41 |
42 | 1 43 |
44 |
45 |
46 |

47 | Low 48 | Confidence 49 |

50 | 51 |

52 | High 53 | Confidence 54 |

55 |
56 |
57 |
58 | ); 59 | } 60 | -------------------------------------------------------------------------------- /frontend/src/components/MetricsView/Predictions/CorrelationMatrixLegend.tsx: -------------------------------------------------------------------------------- 1 | import * as d3 from "d3"; 2 | 3 | import { Arrow } from "../../UI/icons"; 4 | 5 | export default function PredictionMatrixLegend() { 6 | const colorScale = d3 7 | .scaleSequential((t) => d3.interpolateGreys(0.05 + 0.95 * t)) 8 | .domain([0, 1]); 9 | const numStops = 10; 10 | const bubbleColorScale = Array.from({ length: numStops }, (_, i) => 11 | colorScale(i / (numStops - 1)) 12 | ); 13 | const gradient = `linear-gradient(to right, ${bubbleColorScale.join(", ")})`; 14 | 15 | return ( 16 |
17 |
18 | 19 |
20 | Row 21 | Proportion 22 |
23 | 24 |
25 | Row 26 | Confidence 27 |
28 | 29 |
30 | 31 |
32 |
33 |
34 |
35 | 0 36 |
37 |
38 | 1 39 |
40 |
41 |
42 | Low 43 | 44 | High 45 |
46 |
47 |
48 | ); 49 | } 50 | 51 | function RectangleIcon({ className }: { className: string }) { 52 | return ( 53 | 61 | 70 | 78 | 79 | ); 80 | } 81 | 82 | function RightArrowIcon() { 83 | return ( 84 | 92 | 97 | 98 | ); 99 | } 100 | 101 | function LeftArrowIcon() { 102 | return ( 103 | 111 | 116 | 117 | ); 118 | } 119 | -------------------------------------------------------------------------------- /frontend/src/components/ModelScreening/Experiments/CustomUnlearning.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | 3 | import { Input } from "../../UI/input"; 4 | import { cn } from "../../../utils/util"; 5 | 6 | interface Props extends React.InputHTMLAttributes { 7 | fileName: string; 8 | } 9 | 10 | export default function CustomUnlearning({ fileName, ...props }: Props) { 11 | return ( 12 |
13 |

Upload File

14 |
15 | 22 |
23 | 26 | {fileName ? fileName : "Choose File"} 27 | 28 |
29 |
30 |
31 | ); 32 | } 33 | -------------------------------------------------------------------------------- /frontend/src/components/ModelScreening/Experiments/HyperparameterInput.tsx: -------------------------------------------------------------------------------- 1 | import React, { useState, useEffect } from "react"; 2 | 3 | import { 4 | EPOCH, 5 | LEARNING_RATE, 6 | BATCH_SIZE, 7 | } from "../../../constants/experiments"; 8 | import { Input } from "../../UI/input"; 9 | import { PlusIcon } from "../../UI/icons"; 10 | import { COLORS } from "../../../constants/colors"; 11 | import { cn } from "../../../utils/util"; 12 | 13 | interface Props 14 | extends Omit, "list"> { 15 | title: string; 16 | initialValue: string; 17 | paramList: (string | number)[]; 18 | onPlusClick: (id: string, value: string) => void; 19 | } 20 | 21 | const CONFIG = { 22 | EPOCHS_MIN: 1, 23 | LEARNING_RATE_MIN: 0, 24 | LEARNING_RATE_MAX: 1, 25 | BATCH_SIZE_MIN: 1, 26 | } as const; 27 | 28 | export default function HyperparameterInput({ 29 | title, 30 | initialValue, 31 | paramList, 32 | onPlusClick, 33 | ...props 34 | }: Props) { 35 | const [value, setValue] = useState(initialValue); 36 | const [isDisabled, setIsDisabled] = useState(false); 37 | 38 | useEffect(() => { 39 | setValue(initialValue); 40 | }, [initialValue]); 41 | 42 | useEffect(() => { 43 | if (paramList.length === 5 || value.trim() === "") { 44 | setIsDisabled(true); 45 | } else { 46 | setIsDisabled(false); 47 | } 48 | }, [paramList.length, value]); 49 | 50 | const processedTitle = title.replace(/\s+/g, ""); 51 | const id = processedTitle.charAt(0).toLowerCase() + processedTitle.slice(1); 52 | const isIntegerInput = id === EPOCH || id === BATCH_SIZE; 53 | 54 | const handleValueChange = (event: React.ChangeEvent) => { 55 | const inputValue = event.currentTarget.value; 56 | 57 | if (inputValue === "") { 58 | setValue(inputValue); 59 | setIsDisabled(true); 60 | return; 61 | } 62 | setIsDisabled(false); 63 | 64 | let newValue = inputValue; 
65 | if (isIntegerInput && newValue.includes(".")) { 66 | newValue = parseInt(newValue, 10).toString(); 67 | } 68 | 69 | switch (id) { 70 | case EPOCH: { 71 | const validValue = Math.max(Number(newValue), CONFIG.EPOCHS_MIN); 72 | setValue(String(validValue)); 73 | break; 74 | } 75 | case BATCH_SIZE: { 76 | const validValue = Math.max(Number(newValue), CONFIG.BATCH_SIZE_MIN); 77 | setValue(String(validValue)); 78 | break; 79 | } 80 | case LEARNING_RATE: { 81 | setValue(newValue); 82 | const numericValue = Number(newValue); 83 | if ( 84 | numericValue >= CONFIG.LEARNING_RATE_MAX || 85 | numericValue <= CONFIG.LEARNING_RATE_MIN 86 | ) { 87 | setIsDisabled(true); 88 | } else { 89 | setIsDisabled(false); 90 | } 91 | break; 92 | } 93 | default: 94 | setValue(newValue); 95 | break; 96 | } 97 | }; 98 | 99 | const handlePlusClick = () => { 100 | if (!isDisabled) { 101 | onPlusClick(id, value); 102 | } 103 | }; 104 | 105 | return ( 106 |
107 | {title} 108 | 116 |
123 | 124 |
125 |
126 | ); 127 | } 128 | -------------------------------------------------------------------------------- /frontend/src/components/ModelScreening/Experiments/Legend.tsx: -------------------------------------------------------------------------------- 1 | import * as d3 from "d3"; 2 | 3 | const CONFIG = { 4 | COLOR_TEMPERATURE_LOW: 0.03, 5 | COLOR_TEMPERATURE_HIGH: 0.77, 6 | } as const; 7 | 8 | export default function DualMetricsLegend() { 9 | const greenScale = d3 10 | .scaleSequential((t) => 11 | d3.interpolateGreens( 12 | CONFIG.COLOR_TEMPERATURE_LOW + CONFIG.COLOR_TEMPERATURE_HIGH * t 13 | ) 14 | ) 15 | .domain([0, 1]); 16 | const accuracyGradient = `linear-gradient(to right, ${greenScale( 17 | 0 18 | )} 0%, ${greenScale(1)} 100%)`; 19 | 20 | const blueScale = d3 21 | .scaleSequential((t) => 22 | d3.interpolateBlues( 23 | CONFIG.COLOR_TEMPERATURE_LOW + CONFIG.COLOR_TEMPERATURE_HIGH * t 24 | ) 25 | ) 26 | .domain([0, 1]); 27 | const efficiencyGradient = `linear-gradient(to right, ${blueScale( 28 | 0 29 | )} 0%, ${blueScale(1)} 100%)`; 30 | 31 | const orangeScale = d3 32 | .scaleSequential((t) => 33 | d3.interpolateOranges( 34 | CONFIG.COLOR_TEMPERATURE_LOW + CONFIG.COLOR_TEMPERATURE_HIGH * t 35 | ) 36 | ) 37 | .domain([0, 1]); 38 | const forgettingQualityGradient = `linear-gradient(to right, ${orangeScale( 39 | 0 40 | )} 0%, ${orangeScale(1)} 100%)`; 41 | 42 | return ( 43 |
44 |
45 |

Accuracy

46 |
47 |
51 | 52 | Worst 53 | 54 | 55 | Best 56 | 57 |
58 |
59 |
60 |
61 |

Efficiency

62 |
63 |
67 | 68 | Worst 69 | 70 | 71 | Best 72 | 73 |
74 |
75 |
76 |
77 |

78 | Privacy 79 |

80 |
81 |
85 | 86 | 0 87 | 88 | 89 | 1 90 | 91 |
92 |
93 |
94 |
95 | ); 96 | } 97 | -------------------------------------------------------------------------------- /frontend/src/components/ModelScreening/Experiments/MethodFilterHeader.tsx: -------------------------------------------------------------------------------- 1 | import { useState, useRef, useEffect } from "react"; 2 | import { createPortal } from "react-dom"; 3 | 4 | import { FilterIcon } from "../../UI/icons"; 5 | import { UNLEARNING_METHODS } from "../../../constants/experiments"; 6 | import { cn } from "../../../utils/util"; 7 | 8 | const MOUSEDOWN = "mousedown"; 9 | 10 | export default function MethodFilterHeader({ column }: { column: any }) { 11 | const [filterValues, setFilterValues] = useState([]); 12 | const [showDropdown, setShowDropdown] = useState(false); 13 | const [dropdownCoords, setDropdownCoords] = useState({ top: 0, left: 0 }); 14 | 15 | const containerRef = useRef(null); 16 | const dropdownRef = useRef(null); 17 | 18 | useEffect(() => { 19 | if (showDropdown && containerRef.current) { 20 | const rect = containerRef.current.getBoundingClientRect(); 21 | setDropdownCoords({ 22 | top: rect.bottom, 23 | left: rect.left, 24 | }); 25 | } 26 | }, [showDropdown]); 27 | 28 | const handleSelect = (value: string) => () => { 29 | let newFilterValues: string[]; 30 | 31 | if (filterValues.includes(value)) { 32 | newFilterValues = filterValues.filter((val) => val !== value); 33 | } else if (value === "") { 34 | newFilterValues = []; 35 | } else { 36 | newFilterValues = [...filterValues, value]; 37 | } 38 | 39 | setFilterValues(newFilterValues); 40 | column.setFilterValue(newFilterValues); 41 | setShowDropdown(false); 42 | }; 43 | 44 | useEffect(() => { 45 | function handleClickOutside(event: MouseEvent) { 46 | if ( 47 | containerRef.current && 48 | !containerRef.current.contains(event.target as Node) && 49 | dropdownRef.current && 50 | !dropdownRef.current.contains(event.target as Node) 51 | ) { 52 | setShowDropdown(false); 53 | } 54 | } 55 | 56 | if 
(showDropdown) { 57 | document.addEventListener(MOUSEDOWN, handleClickOutside); 58 | } else { 59 | document.removeEventListener(MOUSEDOWN, handleClickOutside); 60 | } 61 | return () => { 62 | document.removeEventListener(MOUSEDOWN, handleClickOutside); 63 | }; 64 | }, [showDropdown]); 65 | 66 | return ( 67 |
68 | Method 69 | setShowDropdown(true)} 73 | /> 74 | {showDropdown && 75 | createPortal( 76 |
84 |
88 | All 89 |
90 | {UNLEARNING_METHODS.map((method) => ( 91 |
101 | {method} 102 |
103 | ))} 104 |
, 105 | document.body 106 | )} 107 |
108 | ); 109 | } 110 | -------------------------------------------------------------------------------- /frontend/src/components/ModelScreening/Experiments/MethodUnlearning.tsx: -------------------------------------------------------------------------------- 1 | import React, { useState, useEffect } from "react"; 2 | 3 | import HyperparameterInput from "./HyperparameterInput"; 4 | import { 5 | EPOCH, 6 | LEARNING_RATE, 7 | BATCH_SIZE, 8 | } from "../../../constants/experiments"; 9 | import { Badge } from "../../UI/badge"; 10 | import { getDefaultUnlearningConfig } from "../../../utils/config/unlearning"; 11 | 12 | interface Props extends React.InputHTMLAttributes { 13 | method: string; 14 | epochsList: string[]; 15 | learningRateList: string[]; 16 | batchSizeList: string[]; 17 | setEpoch: (epoch: string[]) => void; 18 | setLearningRate: (lr: string[]) => void; 19 | setBatchSize: (bs: string[]) => void; 20 | onPlusClick: (id: string, value: string) => void; 21 | onBadgeClick: (event: React.MouseEvent) => void; 22 | } 23 | 24 | export default function MethodUnlearning({ 25 | method, 26 | epochsList, 27 | learningRateList, 28 | batchSizeList, 29 | setEpoch, 30 | setLearningRate, 31 | setBatchSize, 32 | onPlusClick, 33 | onBadgeClick, 34 | ...props 35 | }: Props) { 36 | const [initialValues, setInitialValues] = useState(["10", "0.01", "64"]); 37 | 38 | useEffect(() => { 39 | const { epoch, learning_rate, batch_size } = 40 | getDefaultUnlearningConfig(method); 41 | 42 | setInitialValues([epoch, learning_rate, batch_size]); 43 | setEpoch([epoch]); 44 | setLearningRate([learning_rate]); 45 | setBatchSize([batch_size]); 46 | }, [method, setBatchSize, setEpoch, setLearningRate]); 47 | 48 | return ( 49 |
50 |

Hyperparameters

51 |
52 |
53 | 60 | {epochsList.length > 0 && ( 61 |
62 | {epochsList.map((epoch, idx) => ( 63 | 69 | {epoch} 70 | 71 | ))} 72 |
73 | )} 74 |
75 |
76 | 83 | {learningRateList.length > 0 && ( 84 |
85 | {learningRateList.map((rate, idx) => ( 86 | 92 | {rate} 93 | 94 | ))} 95 |
96 | )} 97 |
98 |
99 | 106 | {batchSizeList.length > 0 && ( 107 |
108 | {batchSizeList.map((batch, idx) => ( 109 | 115 | {batch} 116 | 117 | ))} 118 |
119 | )} 120 |
121 |
122 |
123 | ); 124 | } 125 | -------------------------------------------------------------------------------- /frontend/src/components/ModelScreening/Experiments/TableHeader.tsx: -------------------------------------------------------------------------------- 1 | import { Table as TableType, flexRender } from "@tanstack/react-table"; 2 | 3 | import { Table, TableHead, TableHeader, TableRow } from "../../UI/table"; 4 | import { COLUMN_WIDTHS } from "./Columns"; 5 | import { ExperimentData } from "../../../types/data"; 6 | 7 | interface Props { 8 | table: TableType; 9 | } 10 | 11 | export default function _TableHeader({ table }: Props) { 12 | return ( 13 | 14 | 15 | {table.getHeaderGroups().map((headerGroup) => ( 16 | 17 | {headerGroup.headers.map((header) => { 18 | const columnWidth = 19 | COLUMN_WIDTHS[header.column.id as keyof typeof COLUMN_WIDTHS]; 20 | return ( 21 | 28 | {header.isPlaceholder 29 | ? null 30 | : flexRender( 31 | header.column.columnDef.header, 32 | header.getContext() 33 | )} 34 | 35 | ); 36 | })} 37 | 38 | ))} 39 | 40 |
41 | ); 42 | } 43 | -------------------------------------------------------------------------------- /frontend/src/components/ModelScreening/Progress/AddModelsButton.tsx: -------------------------------------------------------------------------------- 1 | import { useState, useEffect } from "react"; 2 | 3 | import UnlearningConfiguration from "./UnlearningConfiguration"; 4 | import Button from "../../common/CustomButton"; 5 | import { PlusIcon } from "../../UI/icons"; 6 | import { useRunningStatusStore } from "../../../stores/runningStatusStore"; 7 | import { cn } from "../../../utils/util"; 8 | import { 9 | Dialog, 10 | DialogContent, 11 | DialogHeader, 12 | DialogTitle, 13 | DialogTrigger, 14 | } from "../../UI/dialog"; 15 | 16 | export default function AddExperimentsButton() { 17 | const { isRunning } = useRunningStatusStore(); 18 | 19 | const [open, setOpen] = useState(false); 20 | 21 | useEffect(() => { 22 | if (isRunning) setOpen(false); 23 | }, [isRunning]); 24 | 25 | return ( 26 | setOpen(val)}> 27 | 28 | 37 | 38 | 39 | 40 | Model Builder 41 | 42 | 43 | 44 | 45 | ); 46 | } 47 | -------------------------------------------------------------------------------- /frontend/src/components/ModelScreening/Progress/Pagination.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | 3 | import { 4 | Pagination, 5 | PaginationContent, 6 | PaginationItem, 7 | PaginationNext, 8 | PaginationPrevious, 9 | } from "../../UI/pagination"; 10 | import { PREV, NEXT } from "../../../views/ModelScreening/Progress"; 11 | import { useRunningStatusStore } from "../../../stores/runningStatusStore"; 12 | 13 | interface Props extends React.LiHTMLAttributes { 14 | currentPage: number; 15 | } 16 | 17 | export default function ProgressPagination({ currentPage, ...props }: Props) { 18 | const { totalExperimentsCount } = useRunningStatusStore(); 19 | 20 | return ( 21 | 22 | 23 | 24 | 25 | 26 | 27 | {currentPage} / 
{totalExperimentsCount} 28 | 29 | 30 | 31 | 32 | 33 | 34 | ); 35 | } 36 | -------------------------------------------------------------------------------- /frontend/src/components/ModelScreening/Progress/Stepper.tsx: -------------------------------------------------------------------------------- 1 | import { memo } from "react"; 2 | import { Check, Dot, Loader2 } from "lucide-react"; 3 | 4 | import { 5 | Stepper, 6 | StepperDescription, 7 | StepperItem, 8 | StepperSeparator, 9 | StepperTitle, 10 | StepperTrigger, 11 | } from "../../UI/stepper"; 12 | import { Button } from "../../UI/button"; 13 | import { Step } from "../../../views/ModelScreening/Progress"; 14 | 15 | const _Stepper = memo(function _Stepper({ 16 | steps, 17 | activeStep, 18 | completedSteps, 19 | isRunning, 20 | }: { 21 | steps: Step[]; 22 | activeStep: number; 23 | completedSteps: number[]; 24 | isRunning: boolean; 25 | }) { 26 | return ( 27 | 28 | {steps.map((step, idx) => { 29 | const isNotLastStep = idx !== steps.length - 1; 30 | 31 | let state: "completed" | "active" | "inactive"; 32 | if (step.step === activeStep) { 33 | state = "active"; 34 | } else if (completedSteps.includes(step.step)) { 35 | state = "completed"; 36 | } else { 37 | state = "inactive"; 38 | } 39 | 40 | return ( 41 | 45 | {isNotLastStep && ( 46 | 47 |
48 | 49 | )} 50 | 51 | 64 | 65 |
66 | 67 | {step.title} 68 | 69 | 70 | {step.description.split("\n").map((el, idx) => ( 71 |

72 | {el 73 | .split("**") 74 | .map((part, partIdx) => 75 | partIdx % 2 === 1 ? ( 76 | {part} 77 | ) : ( 78 | part 79 | ) 80 | )} 81 |

82 | ))} 83 |
84 |
85 | 86 | ); 87 | })} 88 | 89 | ); 90 | }); 91 | 92 | export default _Stepper; 93 | -------------------------------------------------------------------------------- /frontend/src/components/UI/badge.tsx: -------------------------------------------------------------------------------- 1 | import * as React from "react"; 2 | import { cva, type VariantProps } from "class-variance-authority"; 3 | 4 | import { cn } from "../../utils/util"; 5 | 6 | const badgeVariants = cva( 7 | "inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold transition-colors focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2", 8 | { 9 | variants: { 10 | variant: { 11 | default: 12 | "border-transparent bg-primary text-primary-foreground hover:bg-primary/80", 13 | secondary: 14 | "border-transparent bg-secondary text-secondary-foreground hover:bg-secondary/80", 15 | destructive: 16 | "border-transparent bg-destructive text-destructive-foreground hover:bg-destructive/80", 17 | outline: "text-foreground", 18 | }, 19 | }, 20 | defaultVariants: { 21 | variant: "default", 22 | }, 23 | } 24 | ); 25 | 26 | export interface BadgeProps 27 | extends React.HTMLAttributes, 28 | VariantProps {} 29 | 30 | function Badge({ className, variant, ...props }: BadgeProps) { 31 | return ( 32 |
33 | ); 34 | } 35 | 36 | export { Badge, badgeVariants }; 37 | -------------------------------------------------------------------------------- /frontend/src/components/UI/button.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import { Slot } from "@radix-ui/react-slot"; 3 | import { cva, type VariantProps } from "class-variance-authority"; 4 | 5 | import { cn } from "../../utils/util"; 6 | 7 | const buttonVariants = cva( 8 | "inline-flex items-center justify-center whitespace-nowrap rounded-md text-sm font-medium ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50", 9 | { 10 | variants: { 11 | variant: { 12 | default: "bg-primary text-primary-foreground hover:bg-primary/90", 13 | destructive: 14 | "bg-destructive text-destructive-foreground hover:bg-destructive/90", 15 | outline: 16 | "border border-input bg-background hover:bg-accent hover:text-accent-foreground", 17 | secondary: 18 | "bg-secondary text-secondary-foreground hover:bg-secondary/80", 19 | ghost: "hover:bg-accent hover:text-accent-foreground", 20 | link: "text-primary underline-offset-4 hover:underline", 21 | }, 22 | size: { 23 | default: "h-10 px-4 py-2", 24 | sm: "h-9 rounded-md px-3", 25 | lg: "h-11 rounded-md px-8", 26 | icon: "h-10 w-10", 27 | }, 28 | }, 29 | defaultVariants: { 30 | variant: "default", 31 | size: "default", 32 | }, 33 | } 34 | ); 35 | 36 | export interface ButtonProps 37 | extends React.ButtonHTMLAttributes, 38 | VariantProps { 39 | asChild?: boolean; 40 | } 41 | 42 | const Button = React.forwardRef( 43 | ({ className, variant, size, asChild = false, ...props }, ref) => { 44 | const Comp = asChild ? 
Slot : "button"; 45 | return ( 46 | 51 | ); 52 | } 53 | ); 54 | Button.displayName = "Button"; 55 | 56 | export { Button, buttonVariants }; 57 | -------------------------------------------------------------------------------- /frontend/src/components/UI/checkbox.tsx: -------------------------------------------------------------------------------- 1 | import * as React from "react"; 2 | import * as CheckboxPrimitive from "@radix-ui/react-checkbox"; 3 | import { Check } from "lucide-react"; 4 | 5 | import { cn } from "../../utils/util"; 6 | 7 | const Checkbox = React.forwardRef< 8 | React.ElementRef, 9 | React.ComponentPropsWithoutRef 10 | >(({ className, ...props }, ref) => ( 11 | 19 | 22 | 23 | 24 | 25 | )); 26 | Checkbox.displayName = CheckboxPrimitive.Root.displayName; 27 | 28 | export { Checkbox }; 29 | -------------------------------------------------------------------------------- /frontend/src/components/UI/dialog.tsx: -------------------------------------------------------------------------------- 1 | import * as React from "react"; 2 | import * as DialogPrimitive from "@radix-ui/react-dialog"; 3 | import { X } from "lucide-react"; 4 | 5 | import { cn } from "../../utils/util"; 6 | 7 | const Dialog = DialogPrimitive.Root; 8 | 9 | const DialogTrigger = DialogPrimitive.Trigger; 10 | 11 | const DialogPortal = DialogPrimitive.Portal; 12 | 13 | const DialogClose = DialogPrimitive.Close; 14 | 15 | const DialogOverlay = React.forwardRef< 16 | React.ElementRef, 17 | React.ComponentPropsWithoutRef 18 | >(({ className, ...props }, ref) => ( 19 | 27 | )); 28 | DialogOverlay.displayName = DialogPrimitive.Overlay.displayName; 29 | 30 | const DialogContent = React.forwardRef< 31 | React.ElementRef, 32 | React.ComponentPropsWithoutRef 33 | >(({ className, children, ...props }, ref) => ( 34 | 35 | 36 | 44 | {children} 45 | 46 | 47 | Close 48 | 49 | 50 | 51 | )); 52 | DialogContent.displayName = DialogPrimitive.Content.displayName; 53 | 54 | const DialogHeader = ({ 55 | 
className, 56 | ...props 57 | }: React.HTMLAttributes) => ( 58 |
65 | ); 66 | DialogHeader.displayName = "DialogHeader"; 67 | 68 | const DialogFooter = ({ 69 | className, 70 | ...props 71 | }: React.HTMLAttributes) => ( 72 |
79 | ); 80 | DialogFooter.displayName = "DialogFooter"; 81 | 82 | const DialogTitle = React.forwardRef< 83 | React.ElementRef, 84 | React.ComponentPropsWithoutRef 85 | >(({ className, ...props }, ref) => ( 86 | 94 | )); 95 | DialogTitle.displayName = DialogPrimitive.Title.displayName; 96 | 97 | const DialogDescription = React.forwardRef< 98 | React.ElementRef, 99 | React.ComponentPropsWithoutRef 100 | >(({ className, ...props }, ref) => ( 101 | 106 | )); 107 | DialogDescription.displayName = DialogPrimitive.Description.displayName; 108 | 109 | export { 110 | Dialog, 111 | DialogPortal, 112 | DialogOverlay, 113 | DialogClose, 114 | DialogTrigger, 115 | DialogContent, 116 | DialogHeader, 117 | DialogFooter, 118 | DialogTitle, 119 | DialogDescription, 120 | }; 121 | -------------------------------------------------------------------------------- /frontend/src/components/UI/hover-card.tsx: -------------------------------------------------------------------------------- 1 | import * as React from "react"; 2 | import * as HoverCardPrimitive from "@radix-ui/react-hover-card"; 3 | 4 | import { cn } from "../../utils/util"; 5 | 6 | const HoverCard = HoverCardPrimitive.Root; 7 | 8 | const HoverCardTrigger = HoverCardPrimitive.Trigger; 9 | 10 | const HoverCardContent = React.forwardRef< 11 | React.ElementRef, 12 | React.ComponentPropsWithoutRef 13 | >(({ className, align = "center", sideOffset = 4, ...props }, ref) => ( 14 | 24 | )); 25 | HoverCardContent.displayName = HoverCardPrimitive.Content.displayName; 26 | 27 | export { HoverCard, HoverCardTrigger, HoverCardContent }; 28 | -------------------------------------------------------------------------------- /frontend/src/components/UI/input.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | 3 | import { cn } from "../../utils/util"; 4 | 5 | export interface InputProps 6 | extends React.InputHTMLAttributes {} 7 | 8 | const Input = React.forwardRef( 9 | ({ 
className, type, ...props }, ref) => { 10 | return ( 11 | 20 | ); 21 | } 22 | ); 23 | Input.displayName = "Input"; 24 | 25 | export { Input }; 26 | -------------------------------------------------------------------------------- /frontend/src/components/UI/label.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import * as LabelPrimitive from "@radix-ui/react-label"; 3 | import { cva, type VariantProps } from "class-variance-authority"; 4 | 5 | import { cn } from "../../utils/util"; 6 | 7 | const labelVariants = cva( 8 | "text-sm leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70" 9 | ); 10 | 11 | const Label = React.forwardRef< 12 | React.ElementRef, 13 | React.ComponentPropsWithoutRef & 14 | VariantProps 15 | >(({ className, ...props }, ref) => ( 16 | 21 | )); 22 | Label.displayName = LabelPrimitive.Root.displayName; 23 | 24 | export { Label }; 25 | -------------------------------------------------------------------------------- /frontend/src/components/UI/pagination.tsx: -------------------------------------------------------------------------------- 1 | import * as React from "react"; 2 | import { ChevronLeft, ChevronRight, MoreHorizontal } from "lucide-react"; 3 | 4 | import { cn } from "../../utils/util"; 5 | import { ButtonProps, buttonVariants } from "./button"; 6 | 7 | const Pagination = ({ className, ...props }: React.ComponentProps<"nav">) => ( 8 |