├── .gitattributes
├── .gitignore
├── README.md
├── backend
├── .gitattributes
├── .gitignore
├── README.md
├── app
│ ├── __init__.py
│ ├── config
│ │ ├── __init__.py
│ │ └── settings.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── resnet.py
│ │ └── status.py
│ ├── routers
│ │ ├── __init__.py
│ │ ├── data.py
│ │ ├── train.py
│ │ └── unlearn.py
│ ├── services
│ │ ├── __init__.py
│ │ ├── train.py
│ │ ├── unlearn_FT.py
│ │ ├── unlearn_GA.py
│ │ ├── unlearn_RL.py
│ │ ├── unlearn_custom.py
│ │ └── unlearn_retrain.py
│ ├── threads
│ │ ├── __init__.py
│ │ ├── train_thread.py
│ │ ├── unlearn_FT_thread.py
│ │ ├── unlearn_GA_thread.py
│ │ ├── unlearn_RL_thread.py
│ │ ├── unlearn_custom_thread.py
│ │ └── unlearn_retrain_thread.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── attack.py
│ │ ├── data_loader.py
│ │ ├── evaluation.py
│ │ ├── helpers.py
│ │ ├── visualization.py
│ │ └── visualize_distributions.py
├── data
│ ├── 0
│ │ ├── 0000.json
│ │ └── a000.json
│ ├── 1
│ │ ├── 0001.json
│ │ └── a001.json
│ ├── 2
│ │ ├── 0002.json
│ │ └── a002.json
│ ├── 3
│ │ ├── 0003.json
│ │ └── a003.json
│ ├── 4
│ │ ├── 0004.json
│ │ └── a004.json
│ ├── 5
│ │ ├── 0005.json
│ │ └── a005.json
│ ├── 6
│ │ ├── 0006.json
│ │ └── a006.json
│ ├── 7
│ │ ├── 0007.json
│ │ └── a007.json
│ ├── 8
│ │ ├── 0008.json
│ │ └── a008.json
│ └── 9
│ │ ├── 0009.json
│ │ └── a009.json
├── index.html
├── main.py
└── pyproject.toml
├── demo.mp4
├── frontend
├── .gitignore
├── README.md
├── components.json
├── package.json
├── pnpm-lock.yaml
├── public
│ ├── index.html
│ ├── logo.svg
│ ├── manifest.json
│ └── robots.txt
├── src
│ ├── app
│ │ ├── App.tsx
│ │ ├── index.css
│ │ └── react-app-env.d.ts
│ ├── components
│ │ ├── Core
│ │ │ ├── Embeddings
│ │ │ │ ├── ConnectionLine.tsx
│ │ │ │ ├── ConnectionLineWrapper.tsx
│ │ │ │ ├── Legend.tsx
│ │ │ │ ├── ScatterPlot.tsx
│ │ │ │ └── Tooltip.tsx
│ │ │ └── PrivacyAttack
│ │ │ │ ├── AttackAnalytics.tsx
│ │ │ │ ├── AttackPlot.tsx
│ │ │ │ ├── AttackSuccessFailure.tsx
│ │ │ │ ├── InstancePanel.tsx
│ │ │ │ ├── Legend.tsx
│ │ │ │ └── Tooltip.tsx
│ │ ├── Header
│ │ │ ├── Header.tsx
│ │ │ ├── Tab.tsx
│ │ │ ├── TabPlusButton.tsx
│ │ │ └── Tabs.tsx
│ │ ├── MetricsView
│ │ │ ├── Accuracy
│ │ │ │ └── VerticalBarChart.tsx
│ │ │ ├── CKA
│ │ │ │ └── LineChart.tsx
│ │ │ └── Predictions
│ │ │ │ ├── BubbleMatrix.tsx
│ │ │ │ ├── BubbleMatrixLegend.tsx
│ │ │ │ ├── CorrelationMatrix.tsx
│ │ │ │ └── CorrelationMatrixLegend.tsx
│ │ ├── ModelScreening
│ │ │ ├── Experiments
│ │ │ │ ├── ColumnHeaders.tsx
│ │ │ │ ├── Columns.tsx
│ │ │ │ ├── CustomUnlearning.tsx
│ │ │ │ ├── DataTable.tsx
│ │ │ │ ├── HyperparameterInput.tsx
│ │ │ │ ├── Legend.tsx
│ │ │ │ ├── MethodFilterHeader.tsx
│ │ │ │ ├── MethodUnlearning.tsx
│ │ │ │ ├── TableBody.tsx
│ │ │ │ └── TableHeader.tsx
│ │ │ └── Progress
│ │ │ │ ├── AddModelsButton.tsx
│ │ │ │ ├── Pagination.tsx
│ │ │ │ ├── Stepper.tsx
│ │ │ │ └── UnlearningConfiguration.tsx
│ │ ├── UI
│ │ │ ├── badge.tsx
│ │ │ ├── button.tsx
│ │ │ ├── chart.tsx
│ │ │ ├── checkbox.tsx
│ │ │ ├── context-menu.tsx
│ │ │ ├── dialog.tsx
│ │ │ ├── hover-card.tsx
│ │ │ ├── icons.tsx
│ │ │ ├── input.tsx
│ │ │ ├── label.tsx
│ │ │ ├── pagination.tsx
│ │ │ ├── radio-group.tsx
│ │ │ ├── scroll-area.tsx
│ │ │ ├── select.tsx
│ │ │ ├── separator.tsx
│ │ │ ├── slider.tsx
│ │ │ ├── stepper.tsx
│ │ │ ├── table.tsx
│ │ │ └── tabs.tsx
│ │ └── common
│ │ │ ├── CustomButton.tsx
│ │ │ ├── DatasetModeSelector.tsx
│ │ │ ├── Indicator.tsx
│ │ │ ├── Subtitle.tsx
│ │ │ ├── Title.tsx
│ │ │ └── View.tsx
│ ├── constants
│ │ ├── colors.ts
│ │ ├── common.ts
│ │ ├── correlations.ts
│ │ ├── embeddings.ts
│ │ ├── experiments.ts
│ │ └── privacyAttack.ts
│ ├── hooks
│ │ ├── useClasses.ts
│ │ └── useModelExperiment.ts
│ ├── index.tsx
│ ├── stores
│ │ ├── attackStore.ts
│ │ ├── baseConfigStore.ts
│ │ ├── experimentsStore.ts
│ │ ├── forgetClassStore.ts
│ │ ├── modelDataStore.ts
│ │ ├── runningIndexStore.ts
│ │ ├── runningStatusStore.ts
│ │ └── thresholdStore.ts
│ ├── types
│ │ ├── attack.ts
│ │ ├── data.ts
│ │ ├── embeddings.ts
│ │ └── experiments.ts
│ ├── utils
│ │ ├── api
│ │ │ ├── dataTable.ts
│ │ │ ├── privacyAttack.ts
│ │ │ ├── requests.ts
│ │ │ └── unlearning.ts
│ │ ├── config
│ │ │ └── unlearning.ts
│ │ ├── data
│ │ │ ├── accuracies.ts
│ │ │ ├── colors.ts
│ │ │ ├── experiments.ts
│ │ │ ├── getButterflyLegendData.ts
│ │ │ ├── getCkaData.ts
│ │ │ ├── getProgressSteps.ts
│ │ │ └── running-status-context.ts
│ │ └── util.ts
│ └── views
│ │ ├── Core
│ │ ├── Core.tsx
│ │ ├── Embedding.tsx
│ │ └── PrivacyAttack.tsx
│ │ ├── MetricsView
│ │ ├── ClassWiseAnalysis.tsx
│ │ ├── LayerWiseSimilarity.tsx
│ │ ├── MetricsView.tsx
│ │ └── PredictionMatrix.tsx
│ │ └── ModelScreening
│ │ ├── Experiments.tsx
│ │ ├── ModelScreening.tsx
│ │ └── Progress.tsx
├── tailwind.config.js
└── tsconfig.json
└── img
├── attack.png
└── embeddings.png
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.mp4 filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__/
2 |
3 | backend/umap_visualizations/
4 | backend/trained_models/
5 | backend/unlearned_models/
6 | backend/uploaded_models/
7 | backend/attack/
8 | backend/.python-version
9 | backend/*.png
10 | paper/
11 | z_dist_exp/
12 | myenv/
13 | *.py
14 | *.pdf
15 | *.mdc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Machine Unlearning Comparator
2 |
5 | This tool facilitates the evaluation and comparison of machine unlearning methods by providing interactive visualizations and analytical insights. It enables a systematic examination of model behavior through privacy attacks and performance metrics, offering comprehensive analysis of various unlearning techniques.
6 |
7 | ## Demo
8 |
9 | Try our live demo: [Machine Unlearning Comparator](https://gnueaj.github.io/Machine-Unlearning-Comparator/)
10 |
11 | [Watch the video](https://youtu.be/yAyAYp2msDk?si=Q-8IgVlrk8uSBceu)
12 |
13 | ## Features
14 |
15 | ### Built-in Baseline Methods
16 | The Machine Unlearning Comparator provides a comparison of several baseline methods:
17 | - **Fine-Tuning**: Leverages catastrophic forgetting by fine-tuning the model on the remaining data with an increased learning rate
18 | - **Gradient-Ascent**: Moves in the direction of increasing loss for forget samples using negative gradients
19 | - **Random-Labeling**: Fine-tunes the model by randomly reassigning labels for forget samples, excluding the original forget class labels
20 |
21 | ### **Custom Method Integration** ✨
22 | **Upload and evaluate your own unlearning methods!** The comparator supports custom implementations, enabling you to:
23 | - **Benchmark** your novel approaches against established baselines
24 | - **Upload** your custom unlearning implementations for comparison
25 | - **Compare** results using standardized evaluation metrics and privacy attacks
26 |
27 | It also provides a range of visualizations and privacy-attack evaluations to assess the effectiveness of each method.
28 |
29 | ## How to Start
30 |
31 | ### Backend
32 |
33 | 1. **Install Dependencies Using Hatch**
34 | ```shell
35 | hatch shell
36 | ```
37 |
38 | 2. **Start the Backend Server**
39 | ```shell
40 | hatch run start
41 | ```
42 |
43 | ### Frontend
44 |
45 | 1. **Install Dependencies Using pnpm**
46 | ```shell
47 | pnpm install
48 | ```
49 |
50 | 2. **Start the Frontend Server**
51 | ```shell
52 | pnpm start
53 | ```
54 |
55 | ## Related Resources
56 | - [ResNet18 CIFAR-10 Unlearning Models on Hugging Face](https://huggingface.co/jaeunglee/resnet18-cifar10-unlearning)
57 |
--------------------------------------------------------------------------------
/backend/.gitattributes:
--------------------------------------------------------------------------------
1 | /demo.mp4 filter=lfs diff=lfs merge=lfs -text
2 | *.mp4 filter=lfs diff=lfs merge=lfs -text
3 |
--------------------------------------------------------------------------------
/backend/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__/
2 | /umap_visualizations/
3 | /trained_models/
4 | /unlearned_models/
5 | /uploaded_models/
6 | /*.png
7 | d3ex/
8 | attack/
9 | attack_exp/
--------------------------------------------------------------------------------
/backend/README.md:
--------------------------------------------------------------------------------
1 | # machine-unlearning-dashboard
2 | machine-unlearning-dashboard
3 |
4 | ## Development
5 |
6 | Install [pnpm](https://pnpm.io/installation) to set up the development environment.
7 |
8 | ```shell
9 | git clone https://github.com/gnueaj/mu-dashboard.git && cd mu-dashboard/frontend
10 | pnpm install
11 | pnpm dev
12 | ```
13 |
14 | ## Backend
15 | ```shell
16 | cd backend
17 | hatch shell
18 | hatch run start
19 |
20 | Swagger - http://127.0.0.1:8000/docs
21 | e.g.
22 |
23 | POST /train
24 | {
25 | "seed": 42,
26 | "batch_size": 128,
27 | "learning_rate": 0.01,
28 | "epochs": 30
29 | }
30 |
31 | GET /status
32 | ```
33 | API 추가할 때: router -> service -> main 순 구현
34 | Swagger: http://127.0.0.1:8000/docs
35 |
36 | ### Training
37 | - /train POST 요청을 보내 트레이닝을 시작
38 | - 주기적으로 /train/status GET 요청을 보내 트레이닝 진행 상황을 확인
39 | - 트레이닝이 완료되면 /train/result GET 요청을 보내 SVG 파일 리스트를 받음
40 |
41 | ### Inference
42 | - /inference POST 요청을 보내 Inference를 시작
43 | - 주기적으로 /inference/status 엔드포인트로 GET 요청을 보내 Inference 진행 상황을 확인
44 | - Inference가 완료되면 /inference/result GET 요청을 보내 SVG 파일 리스트를 받음
45 |
46 | ### Unlearn
47 | #### Retrain
48 | ```shell
49 | POST /unlearn
50 | {
51 | "batch_size": 128,
52 | "learning_rate": 0.001,
53 | "epochs": 15,
54 | "forget_class": 1
55 | }
56 | ```
--------------------------------------------------------------------------------
/backend/app/__init__.py:
--------------------------------------------------------------------------------
1 | # This file can be empty or contain package-level imports if needed
--------------------------------------------------------------------------------
/backend/app/config/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains all configuration settings for the application.
3 | """
4 |
5 | from .settings import (
6 | UMAP_N_NEIGHBORS,
7 | UMAP_MIN_DIST,
8 | UMAP_INIT,
9 | UMAP_RANDOM_STATE,
10 | UMAP_N_JOBS,
11 | UMAP_DATA_SIZE,
12 | UMAP_DATASET,
13 | MAX_GRAD_NORM,
14 | UNLEARN_SEED,
15 | MOMENTUM,
16 | WEIGHT_DECAY,
17 | BATCH_SIZE,
18 | LEARNING_RATE,
19 | EPOCHS,
20 | DECREASING_LR,
21 | GAMMA
22 | )
23 |
# Public names re-exported from ``settings``, grouped by purpose. The groups
# are concatenated in their original order so the export list is unchanged.
__all__ = (
    # UMAP settings
    [
        'UMAP_N_NEIGHBORS', 'UMAP_MIN_DIST', 'UMAP_INIT', 'UMAP_RANDOM_STATE',
        'UMAP_N_JOBS', 'UMAP_DATA_SIZE', 'UMAP_DATASET',
    ]
    # Training settings
    + [
        'MAX_GRAD_NORM', 'UNLEARN_SEED', 'MOMENTUM', 'WEIGHT_DECAY',
        'BATCH_SIZE', 'LEARNING_RATE', 'EPOCHS',
    ]
    # Learning rate schedule
    + ['DECREASING_LR', 'GAMMA']
)
--------------------------------------------------------------------------------
/backend/app/config/settings.py:
--------------------------------------------------------------------------------
# --- UMAP projection settings ---------------------------------------------
UMAP_N_NEIGHBORS = 7
UMAP_MIN_DIST = 0.4
UMAP_INIT = 'pca'
UMAP_RANDOM_STATE = 42
UMAP_N_JOBS = -1          # NOTE(review): presumably "use all cores" - confirm
UMAP_DATA_SIZE = 2_000    # NOTE(review): presumably number of points projected - confirm
UMAP_DATASET = 'train'

# --- Optimisation defaults -------------------------------------------------
MAX_GRAD_NORM = 100.0     # not used by the visible services; presumably a clipping bound
UNLEARN_SEED = 2048       # passed to set_seed() by the training/unlearning services
MOMENTUM = 0.9            # SGD momentum
WEIGHT_DECAY = 5e-4       # SGD weight decay
BATCH_SIZE = 128          # default batch size for POST /train
LEARNING_RATE = 0.1       # default learning rate for POST /train
EPOCHS = 200              # default epoch count for POST /train

# --- Learning-rate schedule ------------------------------------------------
DECREASING_LR = [80, 120]  # milestones passed to MultiStepLR
GAMMA = 0.2                # decay factor (GA/RL services currently hard-code 0.2)
--------------------------------------------------------------------------------
/backend/app/models/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains neural network model architectures and related data structures.
3 | It includes implementations of ResNet and status tracking classes for training and unlearning processes.
4 | """
5 |
6 | from app.models.resnet import get_resnet18
7 | from app.models.status import TrainingStatus, UnlearningStatus
8 |
9 | __all__ = [
10 | 'get_resnet18',
11 | 'TrainingStatus',
12 | 'UnlearningStatus'
13 | ]
--------------------------------------------------------------------------------
/backend/app/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torchvision import models
3 |
def get_resnet18(num_classes=10):
    """Build an untrained torchvision ResNet-18 adapted for small images.

    The default stem convolution is replaced by a 3x3, stride-1 convolution,
    and the stem max-pool is replaced by an identity. The classifier head is
    resized to ``num_classes`` outputs.

    Args:
        num_classes: Number of output logits. Defaults to 10.

    Returns:
        A randomly initialised torchvision ResNet-18 module.
    """
    net = models.resnet18(weights=None)  # no pretrained weights
    # The replacement layers are created in the same order as before, so the
    # random initialisation drawn from the global RNG is unchanged.
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.maxpool = nn.Identity()
    feature_dim = net.fc.in_features
    net.fc = nn.Linear(in_features=feature_dim, out_features=num_classes)
    return net
--------------------------------------------------------------------------------
/backend/app/models/status.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
class TrainingStatus:
    """Mutable record of the current training run's progress and metrics.

    The training router creates one instance and reads it in its endpoints.
    The training service sets ``is_training``, ``cancel_requested`` and
    ``progress`` on it, and it is also passed to the training thread.
    """

    def __init__(self):
        self.is_training = False
        self.progress = 0
        self.current_epoch = 0
        self.total_epochs = 0
        self.current_loss = 0
        self.best_loss = 9999.99  # sentinel larger than any expected loss
        self.current_accuracy = 0
        self.best_accuracy = 0
        self.test_loss = 0
        self.test_accuracy = 0
        self.best_test_accuracy = 0
        self.train_class_accuracies: Dict[int, float] = {}
        self.test_class_accuracies: Dict[int, float] = {}
        self.start_time = 0
        self.estimated_time_remaining = 0
        self.umap_embeddings = None
        self.cancel_requested = False
        # Read by GET /train/result. Without this attribute the endpoint
        # raised AttributeError; None means "no results yet" (-> 404).
        self.svg_files = None

    def reset(self):
        """Restore every field to its initial value in place."""
        self.__init__()
25 |
class UnlearningStatus:
    """Mutable record of the current unlearning job's progress and metrics.

    ``progress`` holds a short textual phase (initially "Idle"). Call
    :meth:`reset` to return every field to its initial value.
    """

    def __init__(self):
        # Lifecycle and bookkeeping (assignment order matches the attribute order).
        self.is_unlearning, self.recent_id, self.progress = False, None, "Idle"
        self.current_epoch, self.total_epochs = 0, 0
        self.current_unlearn_loss, self.current_unlearn_accuracy = 0, 0
        self.estimated_time_remaining = 0
        self.cancel_requested = False
        self.forget_class = -1  # -1: no class selected yet

        # Metrics used when unlearning is done by retraining.
        self.current_loss, self.current_accuracy = 0, 0
        self.best_loss, self.best_accuracy = 9999.99, 0
        self.test_loss, self.test_accuracy, self.best_test_accuracy = 9999.99, 0, 0

        # Evaluation-phase fields.
        self.method = ""
        self.p_training_loss, self.p_training_accuracy = 0, 0
        self.p_test_loss, self.p_test_accuracy = 0, 0

        self.train_class_accuracies: Dict[int, float] = {}
        self.test_class_accuracies: Dict[int, float] = {}
        self.start_time = None

    def reset(self):
        """Re-run the initialiser so every field returns to its default."""
        type(self).__init__(self)
--------------------------------------------------------------------------------
/backend/app/routers/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This package contains route definitions for the API endpoints.
3 |
4 | Available routers:
5 | - train_router: Handles model training related endpoints
6 | - unlearn_router: Manages model unlearning operations
7 | - data_router: Manages data operations and preprocessing
8 |
9 | These routers are used to organize and structure the API endpoints
10 | for different functionalities of the application.
11 | """
12 |
13 | from .train import router as train_router
14 | from .unlearn import router as unlearn_router
15 | from .data import router as data_router
16 |
# Routers re-exported for the application entry point.
__all__ = [
    'train_router',
    'unlearn_router',
    'data_router',
]
--------------------------------------------------------------------------------
/backend/app/routers/train.py:
--------------------------------------------------------------------------------
1 | from fastapi import APIRouter, BackgroundTasks, HTTPException
2 | from pydantic import BaseModel, Field
3 | from app.services import run_training
4 | from app.models import TrainingStatus
5 | from app.config import (
6 | BATCH_SIZE,
7 | LEARNING_RATE,
8 | EPOCHS
9 | )
10 |
# The module's router, and the single TrainingStatus instance shared by every
# endpoint below and handed to the background training task.
router, status = APIRouter(), TrainingStatus()
13 |
class TrainingRequest(BaseModel):
    # Request body for POST /train. Deliberately no docstring: pydantic would
    # publish it as the schema description.
    # A `seed` field used to be accepted here; the service currently always
    # seeds with UNLEARN_SEED instead.
    batch_size: int = Field(BATCH_SIZE, description="Batch size for training")
    learning_rate: float = Field(LEARNING_RATE, description="Learning rate for optimizer")
    epochs: int = Field(EPOCHS, description="Number of training epochs")
19 |
@router.post("/train")
async def start_training(request: TrainingRequest, background_tasks: BackgroundTasks):
    """Start a training run as a background task.

    Raises:
        HTTPException: 400 if a training run is already active.

    Returns:
        A confirmation message; training proceeds after the response is sent.
    """
    if status.is_training:
        raise HTTPException(status_code=400, detail="Training is already in progress")
    status.reset()  # Reset status before starting new training
    # Claim the "running" slot now. The background task only starts after
    # this response is sent, so without this a second POST arriving in that
    # window would also pass the check above and start a concurrent run.
    status.is_training = True
    background_tasks.add_task(run_training, request, status)
    return {"message": "Training started"}
27 |
@router.get("/train/status")
async def get_status():
    """Return a snapshot of the shared training status.

    Each response key has the same name as the status attribute it is read
    from.
    """
    exposed_fields = (
        "is_training",
        "progress",
        "current_epoch",
        "total_epochs",
        "current_loss",
        "best_loss",
        "current_accuracy",
        "best_accuracy",
        "test_loss",
        "test_accuracy",
        "train_class_accuracies",
        "test_class_accuracies",
        "estimated_time_remaining",
    )
    return {field: getattr(status, field) for field in exposed_fields}
45 |
@router.get("/train/result")
async def get_training_result():
    """Return the SVG file list recorded for the last training run.

    Raises:
        HTTPException: 400 while training is running; 404 when no result
            list is available.
    """
    if status.is_training:
        raise HTTPException(status_code=400, detail="Training is still in progress")
    # `svg_files` is not guaranteed to exist on the status object. Read it
    # defensively so a missing result yields the intended 404 rather than an
    # AttributeError (HTTP 500).
    svg_files = getattr(status, "svg_files", None)
    if svg_files is None:
        raise HTTPException(status_code=404, detail="No training results available")
    return {"svg_files": svg_files}
53 |
@router.post("/train/cancel")
async def cancel_training():
    """Request cancellation of the running training job.

    Only sets ``status.cancel_requested``; the training service polls that
    flag and stops its worker thread.

    Raises:
        HTTPException: 400 if no training run is active.
    """
    if status.is_training:
        status.cancel_requested = True
        return {"message": "Cancellation requested. Training will stop after the current epoch."}
    raise HTTPException(status_code=400, detail="No training in progress")
--------------------------------------------------------------------------------
/backend/app/services/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Services package for managing model training and unlearning operations.
3 |
4 | This package provides modules for training neural networks and implementing various unlearning methods.
5 | The service modules define recipes that are executed by separate threads to handle the actual computation.
6 |
7 | Available services:
8 | train: Model training with configurable hyperparameters
9 | unlearn_GA: Unlearning using gradient ascent method
10 | unlearn_RL: Unlearning using random labeling method
11 | unlearn_FT: Unlearning using fine-tuning method
12 | unlearn_custom: Custom unlearning method for inference
13 |
14 | Each service module follows a similar pattern:
15 | 1. Takes training/unlearning parameters as input
16 | 2. Sets up the model, data, and optimization components
17 | 3. Passes the configuration to a dedicated execution thread
18 | 4. Provides status tracking and result handling
19 | """
20 |
21 | from .train import run_training
22 | from .unlearn_GA import run_unlearning_GA
23 | from .unlearn_RL import run_unlearning_RL
24 | from .unlearn_FT import run_unlearning_FT
25 | from .unlearn_retrain import run_unlearning_retrain
26 | from .unlearn_custom import run_unlearning_custom
27 |
# Service entry points re-exported for the routers.
__all__ = [
    'run_training',
    'run_unlearning_GA',
    'run_unlearning_RL',
    'run_unlearning_FT',
    'run_unlearning_retrain',
    'run_unlearning_custom',
]
--------------------------------------------------------------------------------
/backend/app/services/train.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 |
6 | from app.threads import TrainingThread
7 | from app.models import get_resnet18
8 | from app.utils import set_seed, get_data_loaders
9 | from app.config import (
10 | MOMENTUM,
11 | WEIGHT_DECAY,
12 | UNLEARN_SEED
13 | )
14 |
async def training(request, status):
    """Train a ResNet-18 on a worker thread and wait until it has finished.

    Seeds the RNGs with ``UNLEARN_SEED``, builds augmented data loaders, an
    SGD optimiser (Nesterov momentum) and a cosine-annealing schedule over
    ``request.epochs``, then runs ``TrainingThread``. The coroutine polls the
    thread every 0.1 s without blocking the event loop.

    If ``status.cancel_requested`` is set, the thread's ``stop()`` is called
    once and the coroutine keeps waiting until the thread has actually
    exited.

    Args:
        request: Object with ``epochs``, ``batch_size`` and ``learning_rate``.
        status: Shared training status, passed to the thread and polled for
            ``cancel_requested``.

    Returns:
        The same ``status`` object.
    """
    print(f"Starting training with {request.epochs} epochs...")
    set_seed(UNLEARN_SEED)
    device = torch.device(
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )

    # Only the loaders are used here; the raw datasets are ignored.
    train_loader, test_loader, _train_set, _test_set = get_data_loaders(
        batch_size=request.batch_size,
        augmentation=True
    )
    model = get_resnet18().to(device=device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=request.learning_rate,
        momentum=MOMENTUM,
        weight_decay=WEIGHT_DECAY,
        nesterov=True
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer,
        T_max=request.epochs,
    )

    training_thread = TrainingThread(
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        epochs=request.epochs,
        status=status,
        model_name="resnet18",
        dataset_name="CIFAR10",
        learning_rate=request.learning_rate
    )
    training_thread.start()

    stop_sent = False
    while training_thread.is_alive():
        await asyncio.sleep(0.1)
        if status.cancel_requested and not stop_sent:
            training_thread.stop()
            print("Cancel requested. Stopping training thread...")
            # Previously the loop broke out here and returned at once. The
            # caller then cleared `is_training` while this thread was still
            # running, so a new run could start alongside it. Keep waiting
            # for the thread to exit instead.
            stop_sent = True

    return status
72 |
async def run_training(request, status):
    """Background-task entry point for training.

    Marks ``status`` as training while :func:`training` runs. Whatever the
    outcome (success, cancellation or an exception), it then clears the
    training and cancel flags and sets ``progress`` to 100.

    Returns:
        Whatever :func:`training` returns.
    """
    try:
        status.is_training, status.cancel_requested = True, False
        return await training(request, status)
    finally:
        # Runs on success, cancellation and error alike.
        status.is_training, status.cancel_requested, status.progress = False, False, 100
--------------------------------------------------------------------------------
/backend/app/services/unlearn_GA.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 |
6 | from app.threads import UnlearningGAThread
7 | from app.models import get_resnet18
8 | from app.utils.helpers import set_seed
9 | from app.utils.data_loader import get_data_loaders
10 | from app.config import (
11 | MOMENTUM,
12 | WEIGHT_DECAY,
13 | DECREASING_LR,
14 | UNLEARN_SEED
15 | )
16 |
async def unlearning_GA(request, status, base_weights_path):
    """Run gradient-ascent (GA) unlearning for one class on a worker thread.

    ``model_before`` is loaded from
    ``unlearned_models/<forget_class>/000<forget_class>.pth`` and
    ``model_after`` from ``base_weights_path``; only ``model_after`` is given
    to the optimiser. A shuffled loader over the training samples labelled
    ``request.forget_class`` is built, and everything is handed to
    ``UnlearningGAThread``. The coroutine then polls the thread every 0.1 s
    without blocking the event loop.

    Args:
        request: Object providing ``forget_class``, ``epochs``,
            ``batch_size`` and ``learning_rate``.
        status: Shared unlearning status. It is polled for
            ``cancel_requested``, and ``is_unlearning`` is cleared once the
            thread exits.
        base_weights_path: Checkpoint loaded into the model being unlearned.

    Returns:
        The same ``status`` object.
    """
    print(f"Starting GA unlearning for class {request.forget_class} with {request.epochs} epochs...")
    set_seed(UNLEARN_SEED)

    device = torch.device(
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )

    # Create Unlearning Settings
    model_before = get_resnet18().to(device)
    model_after = get_resnet18().to(device)
    model_before.load_state_dict(torch.load(f"unlearned_models/{request.forget_class}/000{request.forget_class}.pth", map_location=device))
    model_after.load_state_dict(torch.load(base_weights_path, map_location=device))

    train_loader, test_loader, train_set, test_set = get_data_loaders(
        batch_size=request.batch_size,
        augmentation=False
    )

    # Training-set indices of the class to forget.
    forget_indices = [
        i for i, (_, label) in enumerate(train_set)
        if label == request.forget_class
    ]
    forget_subset = torch.utils.data.Subset(
        dataset=train_set,
        indices=forget_indices
    )
    forget_loader = torch.utils.data.DataLoader(
        dataset=forget_subset,
        batch_size=request.batch_size,
        shuffle=True
    )

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        params=model_after.parameters(),
        lr=request.learning_rate,
        momentum=MOMENTUM,
        weight_decay=WEIGHT_DECAY
    )
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer=optimizer,
        milestones=DECREASING_LR,
        gamma=0.2
    )
    unlearning_GA_thread = UnlearningGAThread(
        request=request,
        status=status,
        model_before=model_before,
        model_after=model_after,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,

        forget_loader=forget_loader,
        train_loader=train_loader,
        test_loader=test_loader,
        train_set=train_set,
        test_set=test_set,
        device=device,
        base_weights_path=base_weights_path
    )
    unlearning_GA_thread.start()

    # Poll until the thread exits. On cancellation, call stop() and log only
    # once; before, this ran again on every 0.1 s tick until the thread ended.
    stop_requested = False
    while unlearning_GA_thread.is_alive():
        await asyncio.sleep(0.1)
        if status.cancel_requested and not stop_requested:
            unlearning_GA_thread.stop()
            print("Cancellation requested, stopping the unlearning process...")
            stop_requested = True

    status.is_unlearning = False

    # thread end
    if unlearning_GA_thread.exception:
        print(f"An error occurred during GA unlearning: {str(unlearning_GA_thread.exception)}")
    elif status.cancel_requested:
        print("Unlearning process was cancelled.")
    else:
        print("Unlearning process completed successfully.")

    return status
106 |
async def run_unlearning_GA(request, status, base_weights_path):
    """Background-task entry point for gradient-ascent unlearning.

    Sets ``status.is_unlearning`` and progress "Unlearning", then awaits
    :func:`unlearning_GA`. It always finishes with ``is_unlearning`` and
    ``cancel_requested`` cleared and progress "Completed".

    Returns:
        Whatever :func:`unlearning_GA` returns.
    """
    try:
        status.is_unlearning = True
        status.progress = "Unlearning"
        updated_status = await unlearning_GA(request, status, base_weights_path)
        return updated_status
    finally:
        # unlearning_GA only clears this flag after its thread has run. If
        # setup fails earlier (e.g. a checkpoint cannot be loaded), the flag
        # would otherwise stay True forever. This matches
        # run_unlearning_custom.
        status.is_unlearning = False
        status.cancel_requested = False
        status.progress = "Completed"
--------------------------------------------------------------------------------
/backend/app/services/unlearn_RL.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 |
6 | from app.threads import UnlearningRLThread
7 | from app.models import get_resnet18
8 | from app.utils.helpers import set_seed
9 | from app.utils.data_loader import get_data_loaders
10 |
11 | from app.config import (
12 | MOMENTUM,
13 | WEIGHT_DECAY,
14 | DECREASING_LR,
15 | UNLEARN_SEED
16 | )
17 |
async def unlearning_RL(request, status, base_weights_path):
    """Run random-labeling (RL) unlearning for one class on a worker thread.

    ``model_before`` is loaded from
    ``unlearned_models/<forget_class>/000<forget_class>.pth`` and
    ``model_after`` from ``base_weights_path``; only ``model_after`` is given
    to the optimiser. Shuffled retain (all other classes) and forget (only
    ``request.forget_class``) loaders are built, and everything is handed to
    ``UnlearningRLThread``. The coroutine then polls the thread every 0.1 s
    without blocking the event loop.

    Args:
        request: Object providing ``forget_class``, ``epochs``,
            ``batch_size`` and ``learning_rate``.
        status: Shared unlearning status. It is polled for
            ``cancel_requested``, and ``is_unlearning`` is cleared once the
            thread exits.
        base_weights_path: Checkpoint loaded into the model being unlearned.

    Returns:
        The same ``status`` object.
    """
    print(f"Starting RL unlearning for class {request.forget_class} with {request.epochs} epochs...")
    set_seed(UNLEARN_SEED)

    device = torch.device(
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )

    # Create Unlearning Settings
    model_before = get_resnet18().to(device)
    model_after = get_resnet18().to(device)
    model_before.load_state_dict(torch.load(f"unlearned_models/{request.forget_class}/000{request.forget_class}.pth", map_location=device))
    model_after.load_state_dict(torch.load(base_weights_path, map_location=device))

    train_loader, test_loader, train_set, test_set = get_data_loaders(
        batch_size=request.batch_size,
        augmentation=False
    )

    # Split the training indices into retain/forget in a single pass.
    # Previously the whole dataset (every sample loaded and transformed) was
    # iterated twice. NOTE(review): assumes the non-augmented transforms draw
    # no random numbers, so one pass and two passes leave the RNG in the
    # same state - confirm in get_data_loaders.
    retain_indices, forget_indices = [], []
    for idx, (_, label) in enumerate(train_set):
        if label == request.forget_class:
            forget_indices.append(idx)
        else:
            retain_indices.append(idx)

    # Create retain loader (excluding forget class)
    retain_subset = torch.utils.data.Subset(
        dataset=train_set,
        indices=retain_indices
    )
    retain_loader = torch.utils.data.DataLoader(
        dataset=retain_subset,
        batch_size=request.batch_size,
        shuffle=True
    )

    # Create forget loader (only forget class)
    forget_subset = torch.utils.data.Subset(
        dataset=train_set,
        indices=forget_indices
    )
    forget_loader = torch.utils.data.DataLoader(
        dataset=forget_subset,
        batch_size=request.batch_size,
        shuffle=True
    )

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        params=model_after.parameters(),
        lr=request.learning_rate,
        momentum=MOMENTUM,
        weight_decay=WEIGHT_DECAY
    )
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer=optimizer,
        milestones=DECREASING_LR,
        gamma=0.2
    )

    unlearning_RL_thread = UnlearningRLThread(
        request=request,
        status=status,
        model_before=model_before,
        model_after=model_after,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,

        retain_loader=retain_loader,
        forget_loader=forget_loader,
        train_loader=train_loader,
        test_loader=test_loader,
        train_set=train_set,
        test_set=test_set,
        device=device,
        base_weights_path=base_weights_path
    )

    unlearning_RL_thread.start()

    # Poll until the thread exits. On cancellation, call stop() and log only
    # once; before, this ran again on every 0.1 s tick until the thread ended.
    stop_requested = False
    while unlearning_RL_thread.is_alive():
        await asyncio.sleep(0.1)
        if status.cancel_requested and not stop_requested:
            unlearning_RL_thread.stop()
            print("Cancellation requested, stopping the unlearning process...")
            stop_requested = True

    status.is_unlearning = False

    # thread end
    if unlearning_RL_thread.exception:
        print(f"An error occurred during RL unlearning: {str(unlearning_RL_thread.exception)}")
    elif status.cancel_requested:
        print("Unlearning process was cancelled.")
    else:
        print("Unlearning process completed successfully.")

    return status
126 |
async def run_unlearning_RL(request, status, base_weights_path):
    """Background-task entry point for random-labeling unlearning.

    Sets ``status.is_unlearning`` and progress "Unlearning", then awaits
    :func:`unlearning_RL`. It always finishes with ``is_unlearning`` and
    ``cancel_requested`` cleared and progress "Completed".

    Returns:
        Whatever :func:`unlearning_RL` returns.
    """
    try:
        status.is_unlearning = True
        status.progress = "Unlearning"
        updated_status = await unlearning_RL(request, status, base_weights_path)
        return updated_status
    finally:
        # unlearning_RL only clears this flag after its thread has run. If
        # setup fails earlier (e.g. a checkpoint cannot be loaded), the flag
        # would otherwise stay True forever. This matches
        # run_unlearning_custom.
        status.is_unlearning = False
        status.cancel_requested = False
        status.progress = "Completed"
--------------------------------------------------------------------------------
/backend/app/services/unlearn_custom.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | import torch
4 | import torch.nn as nn
5 |
6 | from app.threads import UnlearningCustomThread
7 | from app.utils.helpers import set_seed
8 | from app.utils.data_loader import get_data_loaders
9 | from app.models import get_resnet18
10 | from app.config import UNLEARN_SEED
11 |
12 |
async def unlearning_custom(forget_class, status, weights_path, base_weights):
    """Evaluate a user-supplied unlearned model on a worker thread.

    ``model_before`` is loaded from
    ``unlearned_models/<forget_class>/000<forget_class>.pth`` and ``model``
    from ``weights_path``. Both are handed, with non-augmented loaders
    (batch size 1000), to ``UnlearningCustomThread``. The coroutine then
    polls the thread every 0.1 s without blocking the event loop.

    Args:
        forget_class: Class index that was supposed to be forgotten.
        status: Shared unlearning status. It is polled for
            ``cancel_requested``, and ``is_unlearning`` is cleared once the
            thread exits.
        weights_path: Path of the state dict for the model under evaluation.
        base_weights: Passed through to the thread unchanged.

    Returns:
        The same ``status`` object.
    """
    print(f"Starting custom unlearning inference for class {forget_class}...")
    set_seed(UNLEARN_SEED)
    train_loader, test_loader, train_set, test_set = get_data_loaders(
        batch_size=1000,
        augmentation=False
    )

    criterion = nn.CrossEntropyLoss()
    device = torch.device(
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    model_before = get_resnet18().to(device)
    model_before.load_state_dict(torch.load(f"unlearned_models/{forget_class}/000{forget_class}.pth", map_location=device))
    model = get_resnet18().to(device)
    # NOTE(review): weights_path looks like a temporary file (the caller
    # deletes it). If it comes from a user upload, torch.load unpickles
    # arbitrary objects; consider weights_only=True where the installed
    # torch supports it.
    model.load_state_dict(torch.load(weights_path, map_location=device))

    unlearning_thread = UnlearningCustomThread(
        forget_class=forget_class,
        status=status,
        model_before=model_before,
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        train_set=train_set,
        test_set=test_set,
        criterion=criterion,
        device=device,
        base_weights=base_weights
    )
    unlearning_thread.start()
    print("unlearning started")

    # Poll until the thread exits. On cancellation, call stop() and log only
    # once; before, this ran again on every 0.1 s tick until the thread ended.
    stop_requested = False
    while unlearning_thread.is_alive():
        await asyncio.sleep(0.1)
        if status.cancel_requested and not stop_requested:
            unlearning_thread.stop()
            print("Cancellation requested, stopping the unlearning process...")
            stop_requested = True

    # The loop above only ends once the thread is dead, so the former
    # "did not stop within the timeout period" check was unreachable and has
    # been removed.
    status.is_unlearning = False

    # thread end
    if unlearning_thread.exception:
        print(f"An error occurred during custom unlearning: {str(unlearning_thread.exception)}")
    elif status.cancel_requested:
        print("Unlearning process was cancelled.")
    else:
        print("Unlearning process completed successfully.")

    return status
71 |
async def run_unlearning_custom(forget_class, status, weights_path, base_weights):
    """Run ``unlearning_custom`` with status bookkeeping and weight-file cleanup.

    Marks ``status.is_unlearning`` True for the duration of the run. Afterwards,
    whether the run succeeded or raised, it:

    * resets ``is_unlearning`` and ``cancel_requested`` to False;
    * sets ``progress`` to "Completed";
    * deletes ``weights_path``.

    A file that is already gone is ignored, so cleanup never replaces an
    exception raised by the run itself.

    Args:
        forget_class: Forwarded to ``unlearning_custom``.
        status: Shared status object mutated as described above.
        weights_path: Weights file consumed by the run and removed afterwards.
        base_weights: Forwarded to ``unlearning_custom``.

    Returns:
        The status object returned by ``unlearning_custom``.

    Raises:
        Whatever ``unlearning_custom`` raises, unchanged.
    """
    try:
        status.is_unlearning = True
        updated_status = await unlearning_custom(forget_class, status, weights_path, base_weights)
        return updated_status
    finally:
        status.is_unlearning = False
        status.cancel_requested = False
        status.progress = "Completed"
        # An exception escaping a ``finally`` block would hide the original
        # error from the try body, so a missing file is tolerated here.
        try:
            os.remove(weights_path)
        except FileNotFoundError:
            pass
--------------------------------------------------------------------------------
/backend/app/services/unlearn_retrain.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 |
6 | from app.threads import UnlearningRetrainThread
7 | from app.models import get_resnet18
8 | from app.utils.helpers import set_seed
9 | from app.utils.data_loader import get_data_loaders
10 | from app.utils.visualization import (
11 | compute_umap_embedding,
12 | )
13 | from app.utils.evaluation import (
14 | get_layer_activations_and_predictions,
15 | )
16 | from app.config import (
17 | MOMENTUM,
18 | WEIGHT_DECAY,
19 | UNLEARN_SEED
20 | )
21 |
async def unlearning_retrain(request, status):
    """Retrain ResNet-18 from scratch on CIFAR-10 with one class held out.

    The training subset contains every training sample whose label differs from
    ``request.forget_class``. Training uses Nesterov SGD (``MOMENTUM``,
    ``WEIGHT_DECAY``) with cosine annealing over ``request.epochs`` and runs in
    ``UnlearningRetrainThread``. This coroutine polls every 0.5 s. On
    cancellation it sends ``stop()`` once and then waits for the thread to
    exit. The caller's cleanup therefore never runs while the worker is still
    training.

    Args:
        request: Object providing ``forget_class``, ``epochs``, ``batch_size``
            and ``learning_rate``.
        status: Shared status object. ``progress`` (reset to 0) and
            ``forget_class`` are written; ``cancel_requested`` is polled.

    Returns:
        The same ``status`` object.
    """
    print(
        f"Starting unlearning for class {request.forget_class} "
        f"with {request.epochs} epochs..."
    )
    set_seed(UNLEARN_SEED)
    device = torch.device(
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )

    train_loader, test_loader, train_set, test_set = get_data_loaders(
        batch_size=request.batch_size,
        augmentation=True
    )

    # Create dataset excluding the forget class.
    # Labels are read from the dataset's ``targets`` list when available.
    # Iterating ``train_set`` would run the augmentation transform over every
    # image just to read its label.
    # NOTE: that pass also drew from torch's RNG, so seeded results are not
    # bit-identical to runs made before this change.
    labels = getattr(train_set, "targets", None)
    if labels is None:
        labels = [label for _, label in train_set]
    indices = [
        i for i, label in enumerate(labels)
        if label != request.forget_class
    ]
    subset = torch.utils.data.Subset(train_set, indices)
    unlearning_loader = torch.utils.data.DataLoader(
        dataset=subset,
        batch_size=request.batch_size,
        shuffle=True
    )

    model = get_resnet18().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=request.learning_rate,
        momentum=MOMENTUM,
        weight_decay=WEIGHT_DECAY,
        nesterov=True
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer,
        T_max=request.epochs,
    )

    status.progress = 0
    status.forget_class = request.forget_class

    unlearning_thread = UnlearningRetrainThread(
        model=model,
        unlearning_loader=unlearning_loader,
        full_train_loader=train_loader,
        test_loader=test_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        epochs=request.epochs,
        status=status,
        model_name="resnet18",
        dataset_name=f"CIFAR10_without_class_{request.forget_class}",
        learning_rate=request.learning_rate,
        forget_class=request.forget_class
    )
    unlearning_thread.start()

    # Wait for the worker to exit even after cancellation. Returning early would
    # let the caller reset ``status`` while the thread is still training.
    stop_requested = False
    while unlearning_thread.is_alive():
        await asyncio.sleep(0.5)
        if status.cancel_requested and not stop_requested:
            unlearning_thread.stop()
            stop_requested = True
            print("Cancel requested. Stopping unlearning thread...")

    if unlearning_thread.exception:
        print(
            f"An error occurred during Retrain unlearning: "
            f"{str(unlearning_thread.exception)}"
        )

    return status
106 |
async def run_unlearning_retrain(request, status):
    """Run ``unlearning_retrain`` while keeping ``status`` consistent.

    During the run, ``status.is_unlearning`` is True and ``status.progress`` is
    "Unlearning". Afterwards, success or failure, the flags are cleared and
    ``progress`` becomes "Completed". Any exception from the run propagates
    after this cleanup.
    """
    try:
        status.is_unlearning, status.progress = True, "Unlearning"
        return await unlearning_retrain(request, status)
    finally:
        # Leave the shared status idle no matter how the run ended.
        status.is_unlearning = False
        status.cancel_requested = False
        status.progress = "Completed"
--------------------------------------------------------------------------------
/backend/app/threads/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Thread package for managing model training and unlearning operations.
3 |
4 | This package provides thread classes that handle the actual computation for training
5 | and various unlearning methods. Each thread runs independently and provides status tracking.
6 |
7 | Available threads:
8 | TrainingThread: Handles model training operations
9 | UnlearningGAThread: Unlearning using gradient ascent method
10 | UnlearningRLThread: Unlearning using random labeling method
11 | UnlearningFTThread: Unlearning using fine-tuning method
12 | UnlearningRetrainThread: Unlearning by retraining from scratch
13 | UnlearningCustomThread: Custom unlearning method for inference
14 |
15 | Each thread class follows a similar pattern:
16 | 1. Inherits from threading.Thread
17 | 2. Takes model, data, and optimization components as input
18 | 3. Implements run() method for the main computation loop
19 | 4. Provides methods for status updates and graceful termination
20 | """
21 |
22 | from .train_thread import TrainingThread
23 | from .unlearn_GA_thread import UnlearningGAThread
24 | from .unlearn_RL_thread import UnlearningRLThread
25 | from .unlearn_FT_thread import UnlearningFTThread
26 | from .unlearn_retrain_thread import UnlearningRetrainThread
27 | from .unlearn_custom_thread import UnlearningCustomThread
28 |
29 | __all__ = ['TrainingThread', 'UnlearningGAThread', 'UnlearningRLThread', 'UnlearningFTThread', 'UnlearningRetrainThread', 'UnlearningCustomThread']
--------------------------------------------------------------------------------
/backend/app/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This package contains utility functions and modules.
3 | """
4 |
5 | from .data_loader import (
6 | load_cifar10_data,
7 | get_data_loaders
8 | )
9 | from .evaluation import (
10 | get_layer_activations_and_predictions,
11 | evaluate_model,
12 | evaluate_model_with_distributions,
13 | calculate_cka_similarity
14 | )
15 | from .helpers import (
16 | set_seed,
17 | save_model,
18 | format_distribution,
19 | compress_prob_array
20 | )
21 | from .visualization import compute_umap_embedding
22 |
23 |
24 | __all__ = [
25 | 'load_cifar10_data', 'get_data_loaders',
26 | 'get_layer_activations_and_predictions', 'evaluate_model',
27 | 'evaluate_model_with_distributions', 'calculate_cka_similarity', 'set_seed', 'save_model',
28 | 'compute_umap_embedding', 'format_distribution', 'compress_prob_array'
29 | ]
--------------------------------------------------------------------------------
/backend/app/utils/data_loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torchvision import datasets, transforms
3 | from torch.utils.data import DataLoader
4 | from app.config import UNLEARN_SEED
5 |
def load_cifar10_data():
    """Return the raw CIFAR-10 training split as ``(images, labels)``.

    torchvision downloads the dataset into ``./data`` if needed. ``images`` is
    the dataset's ``data`` attribute, already shaped (N, 32, 32, 3). ``labels``
    is a numpy array built from ``targets``. No transform is applied.
    """
    cifar_train = datasets.CIFAR10(
        root='./data',
        train=True,
        download=True,
        transform=None,
    )
    images = cifar_train.data
    labels = np.array(cifar_train.targets)
    return images, labels
16 |
def get_data_loaders(batch_size, augmentation=False):
    """Create CIFAR-10 train/test datasets and matching DataLoaders.

    Both splits are converted to tensors and normalised with fixed per-channel
    mean/std. With ``augmentation`` the training split additionally gets
    ``RandomCrop(32, padding=4)`` and a random horizontal flip first.

    Args:
        batch_size: Batch size of the (shuffled) training loader. The test
            loader always uses batches of 100 and is not shuffled.
        augmentation: Whether to prepend the random crop/flip to the training
            transform.

    Returns:
        ``(train_loader, test_loader, train_set, test_set)``.
    """
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    if augmentation:
        augment = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
    else:
        augment = []

    train_transform = transforms.Compose(augment + [to_tensor, normalize])
    test_transform = transforms.Compose([to_tensor, normalize])

    train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
    test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

    # num_workers=0 keeps loading in the calling process.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_set, batch_size=100, shuffle=False, num_workers=0)
    print("loaded loaders")
    return train_loader, test_loader, train_set, test_set
38 |
def get_fixed_umap_indices(total_samples=2000, seed=UNLEARN_SEED):
    """Pick a reproducible, class-balanced subset of CIFAR-10 training indices.

    For each of the 10 classes, ``total_samples // 10`` indices are drawn
    without replacement. The draw uses a dedicated ``torch.Generator`` seeded
    with ``seed``, so the global RNG is untouched and the selection is stable
    across calls. A class with fewer samples contributes all of them.

    Args:
        total_samples: Target subset size across all classes.
        seed: Seed for the private generator.

    Returns:
        dict mapping class index (0-9) to a list of training-set indices.
    """
    import torch

    _, y_train = load_cifar10_data()
    num_classes = 10
    targets_tensor = torch.tensor(y_train)

    samples_per_class = total_samples // num_classes
    generator = torch.Generator()
    generator.manual_seed(seed)

    indices_dict = {}
    for cls in range(num_classes):
        # flatten() always yields a 1-D tensor. squeeze() collapsed a single
        # match to a 0-d tensor, on which len() raises.
        class_indices = (targets_tensor == cls).nonzero(as_tuple=False).flatten()
        perm = torch.randperm(len(class_indices), generator=generator)
        indices_dict[cls] = class_indices[perm[:samples_per_class]].tolist()

    return indices_dict
--------------------------------------------------------------------------------
/backend/app/utils/visualization.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import os
3 | import time
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | from umap import UMAP
8 |
9 | from app.config import (
10 | UMAP_N_NEIGHBORS,
11 | UMAP_MIN_DIST,
12 | UMAP_INIT,
13 | UMAP_N_JOBS
14 | )
15 |
async def compute_umap_embedding(
    activation,
    labels,
    forget_class=-1,
    forget_labels=None,
    save_dir='umap_visualizations'
):
    """Project ``activation`` to 2-D with UMAP and save a labelled SVG scatter.

    Args:
        activation: Feature matrix passed to ``UMAP.fit_transform``; presumably
            one row per sample (n_samples, n_features). Confirm with callers.
        labels: Per-sample class indices (0-9) used to colour points. Indexed
            with boolean masks, so presumably a numpy array.
        forget_class: Class whose legend entry gets a " (forget)" suffix;
            -1 disables the suffix.
        forget_labels: Optional boolean mask of forget samples. These are drawn
            with an 'x' marker and get an extra legend entry.
        save_dir: Output directory (created if missing). The file is named
            ``<YYYYmmdd_HHMMSS>_umap_layer_last.svg``.

    Returns:
        The 2-D embedding returned by ``UMAP.fit_transform``.
    """
    class_names = [
        'airplane',
        'automobile',
        'bird',
        'cat',
        'deer',
        'dog',
        'frog',
        'horse',
        'ship',
        'truck'
    ]
    if forget_class != -1:
        class_names[forget_class] += " (forget)"

    colors = plt.cm.tab10(np.linspace(0, 1, 10))
    os.makedirs(save_dir, exist_ok=True)

    umap = UMAP(n_components=2,
                n_neighbors=UMAP_N_NEIGHBORS,
                min_dist=UMAP_MIN_DIST,
                init=UMAP_INIT,
                n_jobs=UMAP_N_JOBS)
    print("UMAP start!")
    start_time = time.time()
    embedding = umap.fit_transform(activation)
    print(f"UMAP done! Time taken: {time.time() - start_time:.2f}s")

    fig = plt.figure(figsize=(12, 11))
    try:
        # vmin/vmax pin label k to the k-th tab10 colour in every branch, so the
        # points match the fixed legend even when some classes are absent.
        color_kwargs = dict(cmap='tab10', alpha=0.7, vmin=0, vmax=9)

        if forget_labels is not None:
            # Non-forget points as dots.
            non_forget_mask = ~forget_labels
            plt.scatter(
                embedding[non_forget_mask, 0],
                embedding[non_forget_mask, 1],
                c=labels[non_forget_mask],
                s=20,
                **color_kwargs
            )
            # Forget points with an 'x' marker.
            plt.scatter(
                embedding[forget_labels, 0],
                embedding[forget_labels, 1],
                c=labels[forget_labels],
                s=50,
                marker='x',
                linewidths=2.5,
                **color_kwargs
            )
        else:
            plt.scatter(
                embedding[:, 0],
                embedding[:, 1],
                c=labels,
                s=20,
                **color_kwargs
            )

        legend_elements = [
            plt.Line2D(
                [0], [0],
                marker='o',
                color='w',
                label=class_names[i],
                markerfacecolor=colors[i],
                markersize=10
            ) for i in range(10)
        ]
        if forget_labels is not None:
            legend_elements.append(
                plt.Line2D(
                    [0], [0],
                    marker='x',
                    color='k',
                    label='Forget Data',
                    markerfacecolor='k',
                    markersize=10,
                    markeredgewidth=3,
                    linestyle='None',
                    markeredgecolor='k'
                )
            )

        plt.legend(
            handles=legend_elements,
            title="Predicted Classes",
            loc='upper right',
            bbox_to_anchor=(0.99, 0.99),
            fontsize='x-large',
            title_fontsize='x-large'
        )
        plt.axis('off')
        plt.text(
            0.5, -0.05, 'Last Layer',
            fontsize=24,
            ha='center',
            va='bottom',
            transform=plt.gca().transAxes
        )
        plt.tight_layout()

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f'{timestamp}_umap_layer_last.svg'
        filepath = os.path.join(save_dir, filename)

        plt.savefig(
            filepath,
            format='svg',
            dpi=300,
            bbox_inches='tight',
            pad_inches=0.1
        )
    finally:
        # pyplot keeps every open figure alive; close this one so repeated
        # calls do not accumulate figures in memory.
        plt.close(fig)

    print("\nUMAP embeddings computation and saving completed!")
    return embedding
--------------------------------------------------------------------------------
/backend/main.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI
2 | from contextlib import asynccontextmanager
3 | from fastapi.middleware.cors import CORSMiddleware
4 | from app.routers import train, unlearn, data
5 | from app.utils.helpers import download_weights_from_hub
6 |
# Origins accepted by the CORS middleware.
# TODO: Update URL after deployment ("*" currently admits every origin).
ALLOW_ORIGINS = ["*"]
9 |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    Calls ``download_weights_from_hub()`` once before the application starts
    serving. Nothing is done at shutdown.
    """
    download_weights_from_hub()  # startup
    yield  # application runs here
14 |
def setup_middleware(app: FastAPI) -> None:
    """Install CORS middleware allowing ``ALLOW_ORIGINS`` and any method/header."""
    # NOTE(review): wildcard origins combined with allow_credentials=True is
    # permissive; check the framework's CORS handling before deploying.
    cors_options = dict(
        allow_origins=ALLOW_ORIGINS,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app.add_middleware(CORSMiddleware, **cors_options)
23 |
def register_routers(app: FastAPI) -> None:
    """Mount the train, unlearn and data routers on ``app`` (in that order)."""
    for router_module in (train, unlearn, data):
        app.include_router(router_module.router)
28 |
def create_app() -> FastAPI:
    """Build the FastAPI app.

    The app is created with the ``lifespan`` hook, then gets CORS middleware,
    then the routers.
    """
    application = FastAPI(lifespan=lifespan)
    for configure in (setup_middleware, register_routers):
        configure(application)
    return application
34 |
# The ASGI application, built once every helper above has been defined.
app = create_app()


@app.get("/")
async def root():
    """Root endpoint returning a static welcome payload."""
    payload = {"message": "Welcome to the MU Dashboard API"}
    return payload
--------------------------------------------------------------------------------
/backend/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [tool.hatch.envs.default]
6 | python = "3.10"
7 |
8 | [project]
9 | name = "Machine-Unlearning-Comparator"
10 | version = "0.1.0"
11 | description = "Machine Unlearning Comparator"
12 | readme = "README.md"
13 | requires-python = ">=3.8, <3.12"
14 | license = "MIT"
15 | keywords = []
16 | authors = [
17 | { name = "gnueaj", email = "dlwodnd00@gmail.com" },
18 | ]
19 | classifiers = [
20 | "Development Status :: 4 - Beta",
21 | "Programming Language :: Python",
22 | "Programming Language :: Python :: 3.8",
23 | "Programming Language :: Python :: 3.9",
24 | "Programming Language :: Python :: 3.10",
25 | "Programming Language :: Python :: 3.11",
26 | "Programming Language :: Python :: Implementation :: CPython",
27 | "Programming Language :: Python :: Implementation :: PyPy",
28 | ]
29 | dependencies = [
30 | "fastapi==0.109.0", #
31 | "uvicorn==0.27.0", #
32 | "torch==2.1.2", # (CUDA 12.1)
33 | "torchvision==0.16.2", # (torch 2.1.2)
34 | "numpy==1.24.3", #
35 | "umap-learn==0.5.5", #
36 | "scikit-learn==1.3.0", #
37 | "packaging>=21.0",
38 | "matplotlib",
39 | "python-multipart",
40 | "seaborn",
41 | "torch_cka",
42 | "huggingface_hub",
43 | ]
44 |
45 |
46 | [project.urls]
47 | Documentation = "https://github.com/gnueaj/Machine-Unlearning-Comparator#readme"
48 | Issues = "https://github.com/gnueaj/Machine-Unlearning-Comparator/issues"
49 | Source = "https://github.com/gnueaj/Machine-Unlearning-Comparator"
50 |
51 | [tool.hatch.envs.types]
52 | extra-dependencies = [
53 | "mypy>=1.0.0",
54 | ]
55 | [tool.hatch.envs.types.scripts]
56 | cov = "pytest --cov-report=term-missing --cov-config=pyproject.toml --cov=src/fastapi_resnet_cifar10 --cov=tests {args}"
57 | no-cov = "cov --no-cov {args}"
58 | start = "uvicorn main:app --reload"
59 |
60 | [tool.hatch.envs.default.scripts]
61 | start = "uvicorn main:app --host 0.0.0.0 --port 8000 --reload"
62 | # start = "uvicorn main:app --reload"
63 |
64 | [tool.hatch.build.targets.wheel]
65 | packages = ["app"]
66 |
67 | [tool.coverage.run]
68 | branch = true
69 | parallel = true
70 |
71 | [tool.coverage.report]
72 | exclude_lines = [
73 | "no cov",
74 | "if __name__ == .__main__.:",
75 | "if TYPE_CHECKING:",
76 | ]
77 |
--------------------------------------------------------------------------------
/demo.mp4:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:0a1cb36f049cf96027e9abf76baf5c6b7fac146177d86e56a24149c5fad99328
3 | size 248814650
4 |
--------------------------------------------------------------------------------
/frontend/.gitignore:
--------------------------------------------------------------------------------
1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
2 |
3 | # dependencies
4 | /node_modules
5 | /.pnp
6 | .pnp.js
7 |
8 | # testing
9 | /coverage
10 |
11 | # production
12 | /build
13 |
14 | # misc
15 | .DS_Store
16 | .env.local
17 | .env.development.local
18 | .env.test.local
19 | .env.production.local
20 |
21 | npm-debug.log*
22 | yarn-debug.log*
23 | yarn-error.log*
24 |
--------------------------------------------------------------------------------
/frontend/README.md:
--------------------------------------------------------------------------------
1 | # Machine Unlearning Dashboard
2 | The MU (Machine Unlearning) Dashboard provides model training, unlearning evaluation, and attack simulation.
3 |
4 | ## How to start the frontend
5 |
6 | * Install [pnpm](https://pnpm.io/installation) to set up the development environment.
7 |
8 | **1. Clone the repository and move to the frontend directory**
9 | ```shell
10 | git clone https://github.com/gnueaj/Machine-Unlearning-Comparator.git && cd Machine-Unlearning-Comparator/frontend
11 | ```
12 |
13 | **2. Download necessary modules using `pnpm`**
14 | ```shell
15 | pnpm install
16 | ```
17 |
18 | **3. Run the project**
19 | ```shell
20 | pnpm dev   # or: pnpm start
21 | ```
22 |
--------------------------------------------------------------------------------
/frontend/components.json:
--------------------------------------------------------------------------------
1 | {
2 | "$schema": "https://ui.shadcn.com/schema.json",
3 | "style": "default",
4 | "rsc": false,
5 | "tsx": true,
6 | "tailwind": {
7 | "config": "tailwind.config.js",
8 | "css": "src/app/index.css",
9 | "baseColor": "slate",
10 | "cssVariables": true
11 | },
12 | "aliases": {
13 | "components": "@/src/components",
14 | "utils": "@/util.ts"
15 | }
16 | }
17 |
--------------------------------------------------------------------------------
/frontend/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "my-app",
3 | "version": "0.1.0",
4 | "private": true,
5 | "dependencies": {
6 | "@fortawesome/fontawesome-svg-core": "^6.5.2",
7 | "@fortawesome/free-regular-svg-icons": "^6.5.2",
8 | "@fortawesome/free-solid-svg-icons": "^6.5.2",
9 | "@fortawesome/react-fontawesome": "^0.2.2",
10 | "@radix-ui/react-checkbox": "^1.1.4",
11 | "@radix-ui/react-context-menu": "^2.2.2",
12 | "@radix-ui/react-dialog": "^1.1.2",
13 | "@radix-ui/react-hover-card": "^1.1.2",
14 | "@radix-ui/react-label": "^2.1.0",
15 | "@radix-ui/react-radio-group": "^1.2.0",
16 | "@radix-ui/react-scroll-area": "^1.1.0",
17 | "@radix-ui/react-select": "^2.1.1",
18 | "@radix-ui/react-separator": "^1.1.0",
19 | "@radix-ui/react-slider": "^1.2.0",
20 | "@radix-ui/react-slot": "^1.1.0",
21 | "@radix-ui/react-tabs": "^1.1.3",
22 | "@tanstack/react-table": "^8.20.5",
23 | "@testing-library/jest-dom": "^5.17.0",
24 | "@testing-library/react": "^13.4.0",
25 | "@testing-library/user-event": "^13.5.0",
26 | "@types/d3": "^7.4.3",
27 | "@types/jest": "^27.5.2",
28 | "@types/node": "^16.18.101",
29 | "@types/react": "^18.3.3",
30 | "@types/react-dom": "^18.3.0",
31 | "class-variance-authority": "^0.7.0",
32 | "clsx": "^2.1.1",
33 | "d3": "^7.9.0",
34 | "lucide-react": "^0.438.0",
35 | "react": "^18.3.1",
36 | "react-dom": "^18.3.1",
37 | "react-icons": "^5.3.0",
38 | "react-scripts": "5.0.1",
39 | "recharts": "^2.12.7",
40 | "tailwind-merge": "^2.5.2",
41 | "tailwindcss": "^3.4.10",
42 | "tailwindcss-animate": "^1.0.7",
43 | "typescript": "^4.9.5",
44 | "web-vitals": "^2.1.4",
45 | "zustand": "^5.0.3"
46 | },
47 | "scripts": {
48 | "start": "react-scripts start",
49 | "build": "react-scripts build",
50 | "test": "react-scripts test",
51 | "eject": "react-scripts eject",
52 | "dev": "pnpm start"
53 | },
54 | "eslintConfig": {
55 | "extends": [
56 | "react-app",
57 | "react-app/jest"
58 | ]
59 | },
60 | "browserslist": {
61 | "production": [
62 | ">0.2%",
63 | "not dead",
64 | "not op_mini all"
65 | ],
66 | "development": [
67 | "last 1 chrome version",
68 | "last 1 firefox version",
69 | "last 1 safari version"
70 | ]
71 | }
72 | }
73 |
--------------------------------------------------------------------------------
/frontend/public/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
12 |
13 |
14 |
15 | Machine Unlearning Comparator
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/frontend/public/logo.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/frontend/public/manifest.json:
--------------------------------------------------------------------------------
1 | {
2 | "short_name": "React App",
3 | "name": "Create React App Sample",
4 | "icons": [
5 | {
6 | "src": "favicon.ico",
7 | "sizes": "64x64 32x32 24x24 16x16",
8 | "type": "image/x-icon"
9 | },
10 | {
11 | "src": "logo192.png",
12 | "type": "image/png",
13 | "sizes": "192x192"
14 | },
15 | {
16 | "src": "logo512.png",
17 | "type": "image/png",
18 | "sizes": "512x512"
19 | }
20 | ],
21 | "start_url": ".",
22 | "display": "standalone",
23 | "theme_color": "#000000",
24 | "background_color": "#ffffff"
25 | }
26 |
--------------------------------------------------------------------------------
/frontend/public/robots.txt:
--------------------------------------------------------------------------------
1 | # https://www.robotstxt.org/robotstxt.html
2 | User-agent: *
3 | Disallow:
4 |
--------------------------------------------------------------------------------
/frontend/src/app/App.tsx:
--------------------------------------------------------------------------------
1 | import { useState, useEffect, useCallback } from "react";
2 |
3 | import Header from "../components/Header/Header";
4 | import ModelScreening from "../views/ModelScreening/ModelScreening";
5 | import Core from "../views/Core/Core";
6 | import MetricsView from "../views/MetricsView/MetricsView";
7 | import { useExperimentsStore } from "../stores/experimentsStore";
8 | import { calculateZoom } from "../utils/util";
9 |
/**
 * Fixed layout dimensions for the dashboard. Derived sizes are getters so
 * they always follow the base values.
 */
export const CONFIG = {
  TOTAL_WIDTH: 1805,
  EXPERIMENTS_WIDTH: 1032,
  CORE_WIDTH: 1312,
  /** Core width minus the experiments width, minus a 1-unit gap. */
  get PROGRESS_WIDTH() {
    const { CORE_WIDTH, EXPERIMENTS_WIDTH } = this;
    return CORE_WIDTH - EXPERIMENTS_WIDTH - 1;
  },
  /** Width left over after the core area. */
  get ANALYSIS_VIEW_WIDTH() {
    const { TOTAL_WIDTH, CORE_WIDTH } = this;
    return TOTAL_WIDTH - CORE_WIDTH;
  },

  EXPERIMENTS_PROGRESS_HEIGHT: 256,
  CORE_HEIGHT: 820,
  /** Sum of the core height and the experiments/progress height. */
  get TOTAL_HEIGHT() {
    const { CORE_HEIGHT, EXPERIMENTS_PROGRESS_HEIGHT } = this;
    return CORE_HEIGHT + EXPERIMENTS_PROGRESS_HEIGHT;
  },
} as const;
27 |
/**
 * Root component.
 *
 * Keeps a `zoom` value in sync with calculateZoom() on window resize. It
 * renders an empty fragment until the mount effect has run, and renders the
 * section gated on `isExperimentLoading` only while no experiment is loading.
 * NOTE(review): `zoom` is not read in the visible code; presumably it is
 * applied to the layout markup. Confirm.
 */
28 | export default function App() {
29 | const { isExperimentLoading } = useExperimentsStore();
30 |
31 | const [isPageLoading, setIsPageLoading] = useState(true);
32 | const [zoom, setZoom] = useState(1);
33 |
// Stable across renders (empty deps), so the effect below subscribes once.
34 | const handleResize = useCallback(() => {
35 | setZoom(calculateZoom());
36 | }, []);
37 |
// After mount: clear the loading flag, listen for resizes and apply the
// current zoom immediately. The cleanup removes the listener on unmount.
38 | useEffect(() => {
39 | setIsPageLoading(false);
40 |
41 | window.addEventListener("resize", handleResize);
42 | handleResize();
43 |
44 | return () => window.removeEventListener("resize", handleResize);
45 | }, [handleResize]);
46 |
// The first render, before the effect above has run, produces an empty fragment.
47 | if (isPageLoading) return <>>;
48 |
49 | return (
50 |
51 |
52 | {!isExperimentLoading && (
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 | )}
61 |
62 | );
63 | }
64 |
--------------------------------------------------------------------------------
/frontend/src/app/index.css:
--------------------------------------------------------------------------------
1 | @tailwind base;
2 | @tailwind components;
3 | @tailwind utilities;
4 |
5 | @layer base {
6 | :root {
7 | --background: 0 0% 100%;
8 | --foreground: 222.2 47.4% 11.2%;
9 |
10 | --muted: 210 40% 96.1%;
11 | --muted-foreground: 215.4 16.3% 46.9%;
12 |
13 | --popover: 0 0% 100%;
14 | --popover-foreground: 222.2 47.4% 11.2%;
15 |
16 | --border: 214.3 31.8% 91.4%;
17 | --input: 214.3 31.8% 91.4%;
18 |
19 | --card: 0 0% 100%;
20 | --card-foreground: 222.2 47.4% 11.2%;
21 |
22 | --primary: 222.2 47.4% 11.2%;
23 | --primary-foreground: 210 40% 98%;
24 |
25 | --secondary: 210 40% 96.1%;
26 | --secondary-foreground: 222.2 47.4% 11.2%;
27 |
28 | --accent: 210 40% 96.1%;
29 | --accent-foreground: 222.2 47.4% 11.2%;
30 |
31 | --destructive: 0 100% 50%;
32 | --destructive-foreground: 210 40% 98%;
33 |
34 | --ring: 215 20.2% 65.1%;
35 |
36 | --radius: 0.5rem;
37 | }
38 |
39 | .dark {
40 | --background: 224 71% 4%;
41 | --foreground: 213 31% 91%;
42 |
43 | --muted: 223 47% 11%;
44 | --muted-foreground: 215.4 16.3% 56.9%;
45 |
46 | --accent: 216 34% 17%;
47 | --accent-foreground: 210 40% 98%;
48 |
49 | --popover: 224 71% 4%;
50 | --popover-foreground: 215 20.2% 65.1%;
51 |
52 | --border: 216 34% 17%;
53 | --input: 216 34% 17%;
54 |
55 | --card: 224 71% 4%;
56 | --card-foreground: 213 31% 91%;
57 |
58 | --primary: 210 40% 98%;
59 | --primary-foreground: 222.2 47.4% 1.2%;
60 |
61 | --secondary: 222.2 47.4% 11.2%;
62 | --secondary-foreground: 210 40% 98%;
63 |
64 | --destructive: 0 63% 31%;
65 | --destructive-foreground: 210 40% 98%;
66 |
67 | --ring: 216 34% 17%;
68 |
69 | --radius: 0.5rem;
70 | }
71 | }
72 |
73 | @layer base {
74 | :root {
75 | --chart-1: 12 76% 61%;
76 | --chart-2: 173 58% 39%;
77 | --chart-3: 197 37% 24%;
78 | --chart-4: 43 74% 66%;
79 | --chart-5: 27 87% 67%;
80 | }
81 |
82 | .dark {
83 | --chart-1: 220 70% 50%;
84 | --chart-2: 160 60% 45%;
85 | --chart-3: 30 80% 55%;
86 | --chart-4: 280 65% 60%;
87 | --chart-5: 340 75% 55%;
88 | }
89 | }
90 |
91 | * {
92 | box-sizing: border-box;
93 | margin: 0;
94 | padding: 0;
95 | }
96 |
97 | body {
98 | display: flex;
99 | justify-content: flex-start;
100 | align-items: flex-start;
101 | font-family: "Roboto Condensed", "pretendard", -apple-system, "Roboto",
102 | sans-serif;
103 | -webkit-font-smoothing: antialiased;
104 | -moz-osx-font-smoothing: grayscale;
105 | }
106 |
107 | input[type="number"]::-webkit-outer-spin-button,
108 | input[type="number"]::-webkit-inner-spin-button {
109 | -webkit-appearance: none;
110 | margin: 0;
111 | }
112 | input:disabled,
113 | select:disabled {
114 | cursor: not-allowed;
115 | }
116 |
117 | form {
118 | width: 100%;
119 | }
120 |
--------------------------------------------------------------------------------
/frontend/src/app/react-app-env.d.ts:
--------------------------------------------------------------------------------
1 | ///
2 |
--------------------------------------------------------------------------------
/frontend/src/components/Core/Embeddings/ConnectionLine.tsx:
--------------------------------------------------------------------------------
1 | import React from "react";
2 |
3 | import { calculateZoom } from "../../../utils/util";
4 |
/** Endpoints of the line; either may be null, in which case nothing is drawn. */
5 | interface ConnectionLineProps {
6 | from: { x: number; y: number } | null;
7 | to: { x: number; y: number } | null;
8 | }
9 |
/**
 * Draws a 2px black line between `from` and `to`.
 *
 * The line is a rotated, fixed-position element rather than an SVG path.
 * Both endpoints are divided by a zoom factor. The line is shortened by 4px
 * at each end so it does not cover the markers it connects.
 */
10 | export default function ConnectionLine({ from, to }: ConnectionLineProps) {
11 | if (!from || !to) {
12 | return null;
13 | }
14 |
// Scale factor from calculateZoom(), reduced by the fraction of the window
// width taken by a vertical scrollbar. Presumably the layout is zoomed and
// the incoming coordinates are unzoomed viewport pixels. Confirm.
15 | const zoom = calculateZoom();
16 | const scrollbarWidth =
17 | window.innerWidth - document.documentElement.clientWidth;
18 | const adjustedZoom =
19 | ((window.innerWidth - scrollbarWidth) / window.innerWidth) * zoom;
20 |
// Full-viewport, click-through overlay anchored at the top-left corner.
21 | const lineStyle: React.CSSProperties = {
22 | position: "fixed",
23 | left: 0,
24 | top: 0,
25 | pointerEvents: "none",
26 | };
27 |
28 | const x1 = from.x / adjustedZoom;
29 | const y1 = from.y / adjustedZoom;
30 | const x2 = to.x / adjustedZoom;
31 | const y2 = to.y / adjustedZoom;
32 |
// Segment length and direction; the direction is in degrees for CSS rotate().
33 | const length = Math.hypot(x2 - x1, y2 - y1);
34 | const angle = (Math.atan2(y2 - y1, x2 - x1) * 180) / Math.PI;
35 |
// Trim 4px from each end. NOTE(review): for endpoints closer than 8px the
// width becomes negative.
36 | const shortenedLength = length - 8;
37 |
// 4px step along the line direction, used to start the line 4px past `from`.
38 | const offsetX = 4 * Math.cos((angle * Math.PI) / 180);
39 | const offsetY = 4 * Math.sin((angle * Math.PI) / 180);
40 |
// Horizontal bar translated to the shifted start point, then rotated about
// its own top-left corner (transformOrigin "0 0").
41 | const linePositionStyle: React.CSSProperties = {
42 | position: "absolute",
43 | transformOrigin: "0 0",
44 | transform: `translate(${x1 + offsetX}px, ${
45 | y1 + offsetY
46 | }px) rotate(${angle}deg)`,
47 | width: `${shortenedLength}px`,
48 | height: "2px",
49 | backgroundColor: "black",
50 | };
51 |
52 | return (
53 |
56 | );
57 | }
57 | }
58 |
--------------------------------------------------------------------------------
/frontend/src/components/Core/Embeddings/ConnectionLineWrapper.tsx:
--------------------------------------------------------------------------------
1 | import React, { memo, useRef, useEffect, useState } from "react";
2 |
3 | import ConnectionLine from "./ConnectionLine";
4 | import { Coordinate } from "../../../types/embeddings";
5 |
/** Current endpoints of the connection line; null when an end is unset. */
6 | type Position = {
7 | from: Coordinate | null;
8 | to: Coordinate | null;
9 | };
10 |
11 | interface Props {
12 | positionRef: React.MutableRefObject;
13 | }
14 |
/**
 * Re-renders the connection line when the positions stored in `positionRef`
 * change.
 *
 * The ref is mutated elsewhere without triggering React renders. This
 * component polls it every 16ms, compares it with the last rendered snapshot,
 * and bumps a dummy counter to force a re-render on change. It is wrapped in
 * `memo`, so parent renders with the same ref do not re-render it.
 */
15 | const ConnectionLineWrapper = memo(({ positionRef }: Props) => {
// Counter whose only purpose is to force a re-render.
16 | const [, setUpdateKey] = useState(0);
// Copy of the position at the last forced re-render.
17 | const prevPositionRef = useRef({ from: null, to: null });
18 |
19 | useEffect(() => {
20 | const hasPositionChanged = () => {
21 | const current = positionRef.current;
22 | const prev = prevPositionRef.current;
23 |
// An endpoint switched between null and non-null.
24 | if (!current.from !== !prev.from || !current.to !== !prev.to) return true;
// Same null-ness as before and an endpoint is missing: nothing to compare.
25 | if (!current.from || !current.to) return false;
// Cannot be reached after the two checks above; it narrows `prev` for the
// property reads below.
26 | if (!prev.from || !prev.to) return true;
27 |
28 | return (
29 | current.from.x !== prev.from.x ||
30 | current.from.y !== prev.from.y ||
31 | current.to.x !== prev.to.x ||
32 | current.to.y !== prev.to.y
33 | );
34 | };
35 |
// Poll the ref; on change, snapshot it (shallow copies, so later in-place
// mutation of the ref's objects is still detected) and force a re-render.
36 | const intervalId = setInterval(() => {
37 | if (hasPositionChanged()) {
38 | prevPositionRef.current = {
39 | from: positionRef.current.from
40 | ? { ...positionRef.current.from }
41 | : null,
42 | to: positionRef.current.to ? { ...positionRef.current.to } : null,
43 | };
44 | setUpdateKey((prev) => prev + 1);
45 | }
46 | }, 16);
47 |
48 | return () => clearInterval(intervalId);
49 | }, [positionRef]);
50 |
51 | return (
52 |
56 | );
57 | });
58 |
59 | export default ConnectionLineWrapper;
60 |
--------------------------------------------------------------------------------
/frontend/src/components/Core/PrivacyAttack/AttackSuccessFailure.tsx:
--------------------------------------------------------------------------------
1 | import InstancePanel from "./InstancePanel";
2 | import { Bin, Data, CategoryType, Image } from "../../../types/attack";
3 |
/** Props for AttackSuccessFailure; most are presumably forwarded to InstancePanel. */
4 | interface AttackSuccessFailureProps {
5 | mode: "A" | "B";
6 | thresholdValue: number;
7 | hoveredId: number | null;
8 | data: Data;
9 | imageMap: Map;
10 | attackScore: number;
11 | setHoveredId: (val: number | null) => void;
12 | onElementClick: (
13 | event: React.MouseEvent,
14 | elementData: Bin & { type: CategoryType }
15 | ) => void;
16 | }
17 |
/**
 * Shows the privacy score for one model (`mode` "A" or "B"). The score is
 * computed as `1 - attackScore`.
 */
18 | export default function AttackSuccessFailure({
19 | mode,
20 | thresholdValue,
21 | hoveredId,
22 | data,
23 | imageMap,
24 | attackScore,
25 | setHoveredId,
26 | onElementClick,
27 | }: AttackSuccessFailureProps) {
// Displayed as "Privacy Score"; an exact 1 is shown as "1", otherwise 3 decimals.
28 | const forgettingQualityScore = 1 - attackScore;
29 |
30 | return (
31 |
32 |
33 | Privacy Score ={" "}
34 |
35 | {forgettingQualityScore === 1 ? 1 : forgettingQualityScore.toFixed(3)}
36 |
37 |
38 |
39 |
49 |
59 |
60 |
61 | );
62 | }
63 |
--------------------------------------------------------------------------------
/frontend/src/components/Header/Header.tsx:
--------------------------------------------------------------------------------
1 | import { FileText } from "lucide-react";
2 |
3 | import Tabs from "./Tabs";
4 | import { Logo, GithubIcon } from "../UI/icons";
5 | import { useBaseConfigStore } from "../../stores/baseConfigStore";
6 | import { DATASETS, NEURAL_NETWORK_MODELS } from "../../constants/common";
7 | import {
8 | Select,
9 | SelectContent,
10 | SelectItem,
11 | SelectTrigger,
12 | SelectValue,
13 | } from "../UI/select";
14 |
15 | export default function Header() {
16 | const { dataset, setDataset, neuralNetworkModel, setNeuralNetworkModel } =
17 | useBaseConfigStore();
18 |
19 | const handleGithubIconClick = () => {
20 | window.open(
21 | "https://github.com/gnueaj/Machine-Unlearning-Comparator",
22 | "_blank"
23 | );
24 | };
25 |
26 | const handleDatasetChange = (dataset: string) => {
27 | setDataset(dataset);
28 | };
29 |
30 | const handleNeuralNetworkModelChange = (model: string) => {
31 | setNeuralNetworkModel(model);
32 | };
33 |
34 | return (
35 |
36 |
37 |
38 |
39 |
40 |
41 | Unlearning Comparator
42 |
43 |
44 |
45 |
46 | Architecture
47 |
63 |
64 |
65 | Dataset
66 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
93 |
94 |
95 | );
96 | }
97 |
--------------------------------------------------------------------------------
/frontend/src/components/Header/Tab.tsx:
--------------------------------------------------------------------------------
1 | import { MultiplicationSignIcon } from "../UI/icons";
2 | import { useClasses } from "../../hooks/useClasses";
3 | import { useForgetClassStore } from "../../stores/forgetClassStore";
4 | import { cn } from "../../utils/util";
5 |
6 | interface Props {
7 | setOpen: (open: boolean) => void;
8 | fetchAndSaveExperiments: (forgetClass: string) => Promise;
9 | }
10 |
11 | export default function Tab({ setOpen, fetchAndSaveExperiments }: Props) {
12 | const classes = useClasses();
13 | const forgetClass = useForgetClassStore((state) => state.forgetClass);
14 | const saveForgetClass = useForgetClassStore((state) => state.saveForgetClass);
15 | const selectedForgetClasses = useForgetClassStore(
16 | (state) => state.selectedForgetClasses
17 | );
18 | const deleteSelectedForgetClass = useForgetClassStore(
19 | (state) => state.deleteSelectedForgetClass
20 | );
21 |
22 | const handleForgetClassChange = async (value: string) => {
23 | if (classes[forgetClass] !== value) {
24 | saveForgetClass(value);
25 | await fetchAndSaveExperiments(value);
26 | }
27 | };
28 |
29 | const handleDeleteClick = async (targetClass: string) => {
30 | const firstSelectedForgetClass = selectedForgetClasses[0];
31 | const secondSelectedForgetClass = selectedForgetClasses[1];
32 | const targetSelectedForgetClassesIndex = selectedForgetClasses.indexOf(
33 | classes.indexOf(targetClass)
34 | );
35 |
36 | deleteSelectedForgetClass(targetClass);
37 |
38 | if (targetClass === classes[forgetClass]) {
39 | if (selectedForgetClasses.length === 1) {
40 | saveForgetClass(-1);
41 | setOpen(true);
42 | } else {
43 | const autoSelectedForgetClass =
44 | targetSelectedForgetClassesIndex === 0
45 | ? classes[secondSelectedForgetClass]
46 | : classes[firstSelectedForgetClass];
47 | saveForgetClass(autoSelectedForgetClass);
48 | await fetchAndSaveExperiments(autoSelectedForgetClass);
49 | }
50 | }
51 | };
52 |
53 | return (
54 | <>
55 | {selectedForgetClasses.map((selectedForgetClass, idx) => {
56 | const isSelectedForgetClass = selectedForgetClass === forgetClass;
57 | const forgetClassName = classes[selectedForgetClass];
58 |
59 | return (
60 |
61 |
handleForgetClassChange(forgetClassName)}
69 | >
70 |
76 | Forget: {forgetClassName}
77 |
78 |
79 |
handleDeleteClick(forgetClassName)}
81 | className={cn(
82 | "w-3.5 h-3.5 p-[1px] cursor-pointer rounded-full absolute right-2.5 bg-transparent transition text-gray-500",
83 | isSelectedForgetClass
84 | ? "hover:bg-gray-300"
85 | : "hover:bg-gray-700"
86 | )}
87 | />
88 | {isSelectedForgetClass && (
89 |
90 | )}
91 |
92 | );
93 | })}
94 | >
95 | );
96 | }
97 |
--------------------------------------------------------------------------------
/frontend/src/components/Header/Tabs.tsx:
--------------------------------------------------------------------------------
1 | import { useState } from "react";
2 |
3 | import Tab from "./Tab";
4 | import ForgetClassTabPlusButton from "./TabPlusButton";
5 | import { Experiment, Experiments } from "../../types/data";
6 | import { useClasses } from "../../hooks/useClasses";
7 | import { useExperimentsStore } from "../../stores/experimentsStore";
8 | import { useForgetClassStore } from "../../stores/forgetClassStore";
9 | import { fetchAllExperimentsData } from "../../utils/api/unlearning";
10 |
11 | export default function Tabs() {
12 | const classes = useClasses();
13 | const saveExperiments = useExperimentsStore((state) => state.saveExperiments);
14 | const selectedForgetClasses = useForgetClassStore(
15 | (state) => state.selectedForgetClasses
16 | );
17 | const setIsExperimentsLoading = useExperimentsStore(
18 | (state) => state.setIsExperimentsLoading
19 | );
20 |
21 | const hasNoSelectedForgetClass = selectedForgetClasses.length === 0;
22 |
23 | const [open, setOpen] = useState(hasNoSelectedForgetClass);
24 |
25 | const fetchAndSaveExperiments = async (forgetClass: string) => {
26 | const classIndex = classes.indexOf(forgetClass);
27 | setIsExperimentsLoading(true);
28 | try {
29 | const allData: Experiments = await fetchAllExperimentsData(classIndex);
30 |
31 | if ("detail" in allData) {
32 | saveExperiments({});
33 | } else {
34 | Object.values(allData).forEach((experiment: Experiment) => {
35 | if (experiment && "points" in experiment) {
36 | delete experiment.points;
37 | }
38 | });
39 | saveExperiments(allData);
40 | }
41 | } finally {
42 | setIsExperimentsLoading(false);
43 | }
44 | };
45 |
46 | return (
47 |
48 |
52 |
58 |
59 | );
60 | }
61 |
--------------------------------------------------------------------------------
/frontend/src/components/MetricsView/Predictions/BubbleMatrixLegend.tsx:
--------------------------------------------------------------------------------
1 | import * as d3 from "d3";
2 |
3 | import { Arrow } from "../../UI/icons";
4 |
5 | export default function BubbleChartLegend() {
6 | const colorScale = d3
7 | .scaleSequential((t) => d3.interpolateTurbo(0.05 + 0.95 * t))
8 | .domain([0, 1]);
9 | const numStops = 100;
10 | const bubbleColorScale = Array.from({ length: numStops }, (_, i) =>
11 | colorScale(i / (numStops - 1))
12 | );
13 | const gradient = `linear-gradient(to right, ${bubbleColorScale.join(", ")})`;
14 |
15 | return (
16 |
17 |
18 |
19 |
20 |
21 |
22 | Small
23 | Proportion
24 |
25 |
26 |
27 | Large
28 | Proportion
29 |
30 |
31 |
32 |
33 |
34 |
38 |
39 | 0
40 |
41 |
42 | 1
43 |
44 |
45 |
46 |
47 | Low
48 | Confidence
49 |
50 |
51 |
52 | High
53 | Confidence
54 |
55 |
56 |
57 |
58 | );
59 | }
60 |
--------------------------------------------------------------------------------
/frontend/src/components/MetricsView/Predictions/CorrelationMatrixLegend.tsx:
--------------------------------------------------------------------------------
1 | import * as d3 from "d3";
2 |
3 | import { Arrow } from "../../UI/icons";
4 |
5 | export default function PredictionMatrixLegend() {
6 | const colorScale = d3
7 | .scaleSequential((t) => d3.interpolateGreys(0.05 + 0.95 * t))
8 | .domain([0, 1]);
9 | const numStops = 10;
10 | const bubbleColorScale = Array.from({ length: numStops }, (_, i) =>
11 | colorScale(i / (numStops - 1))
12 | );
13 | const gradient = `linear-gradient(to right, ${bubbleColorScale.join(", ")})`;
14 |
15 | return (
16 |
17 |
18 |
19 |
20 | Row
21 | Proportion
22 |
23 |
24 |
25 | Row
26 | Confidence
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 | 0
36 |
37 |
38 | 1
39 |
40 |
41 |
42 |
Low
43 |
44 |
High
45 |
46 |
47 |
48 | );
49 | }
50 |
51 | function RectangleIcon({ className }: { className: string }) {
52 | return (
53 |
79 | );
80 | }
81 |
82 | function RightArrowIcon() {
83 | return (
84 |
98 | );
99 | }
100 |
101 | function LeftArrowIcon() {
102 | return (
103 |
117 | );
118 | }
119 |
--------------------------------------------------------------------------------
/frontend/src/components/ModelScreening/Experiments/CustomUnlearning.tsx:
--------------------------------------------------------------------------------
1 | import React from "react";
2 |
3 | import { Input } from "../../UI/input";
4 | import { cn } from "../../../utils/util";
5 |
6 | interface Props extends React.InputHTMLAttributes {
7 | fileName: string;
8 | }
9 |
10 | export default function CustomUnlearning({ fileName, ...props }: Props) {
11 | return (
12 |
13 |
Upload File
14 |
15 |
22 |
23 |
26 | {fileName ? fileName : "Choose File"}
27 |
28 |
29 |
30 |
31 | );
32 | }
33 |
--------------------------------------------------------------------------------
/frontend/src/components/ModelScreening/Experiments/HyperparameterInput.tsx:
--------------------------------------------------------------------------------
1 | import React, { useState, useEffect } from "react";
2 |
3 | import {
4 | EPOCH,
5 | LEARNING_RATE,
6 | BATCH_SIZE,
7 | } from "../../../constants/experiments";
8 | import { Input } from "../../UI/input";
9 | import { PlusIcon } from "../../UI/icons";
10 | import { COLORS } from "../../../constants/colors";
11 | import { cn } from "../../../utils/util";
12 |
13 | interface Props
14 | extends Omit, "list"> {
15 | title: string;
16 | initialValue: string;
17 | paramList: (string | number)[];
18 | onPlusClick: (id: string, value: string) => void;
19 | }
20 |
21 | const CONFIG = {
22 | EPOCHS_MIN: 1,
23 | LEARNING_RATE_MIN: 0,
24 | LEARNING_RATE_MAX: 1,
25 | BATCH_SIZE_MIN: 1,
26 | } as const;
27 |
28 | export default function HyperparameterInput({
29 | title,
30 | initialValue,
31 | paramList,
32 | onPlusClick,
33 | ...props
34 | }: Props) {
35 | const [value, setValue] = useState(initialValue);
36 | const [isDisabled, setIsDisabled] = useState(false);
37 |
38 | useEffect(() => {
39 | setValue(initialValue);
40 | }, [initialValue]);
41 |
42 | useEffect(() => {
43 | if (paramList.length === 5 || value.trim() === "") {
44 | setIsDisabled(true);
45 | } else {
46 | setIsDisabled(false);
47 | }
48 | }, [paramList.length, value]);
49 |
50 | const processedTitle = title.replace(/\s+/g, "");
51 | const id = processedTitle.charAt(0).toLowerCase() + processedTitle.slice(1);
52 | const isIntegerInput = id === EPOCH || id === BATCH_SIZE;
53 |
54 | const handleValueChange = (event: React.ChangeEvent) => {
55 | const inputValue = event.currentTarget.value;
56 |
57 | if (inputValue === "") {
58 | setValue(inputValue);
59 | setIsDisabled(true);
60 | return;
61 | }
62 | setIsDisabled(false);
63 |
64 | let newValue = inputValue;
65 | if (isIntegerInput && newValue.includes(".")) {
66 | newValue = parseInt(newValue, 10).toString();
67 | }
68 |
69 | switch (id) {
70 | case EPOCH: {
71 | const validValue = Math.max(Number(newValue), CONFIG.EPOCHS_MIN);
72 | setValue(String(validValue));
73 | break;
74 | }
75 | case BATCH_SIZE: {
76 | const validValue = Math.max(Number(newValue), CONFIG.BATCH_SIZE_MIN);
77 | setValue(String(validValue));
78 | break;
79 | }
80 | case LEARNING_RATE: {
81 | setValue(newValue);
82 | const numericValue = Number(newValue);
83 | if (
84 | numericValue >= CONFIG.LEARNING_RATE_MAX ||
85 | numericValue <= CONFIG.LEARNING_RATE_MIN
86 | ) {
87 | setIsDisabled(true);
88 | } else {
89 | setIsDisabled(false);
90 | }
91 | break;
92 | }
93 | default:
94 | setValue(newValue);
95 | break;
96 | }
97 | };
98 |
99 | const handlePlusClick = () => {
100 | if (!isDisabled) {
101 | onPlusClick(id, value);
102 | }
103 | };
104 |
105 | return (
106 |
107 |
{title}
108 |
116 |
125 |
126 | );
127 | }
128 |
--------------------------------------------------------------------------------
/frontend/src/components/ModelScreening/Experiments/Legend.tsx:
--------------------------------------------------------------------------------
1 | import * as d3 from "d3";
2 |
3 | const CONFIG = {
4 | COLOR_TEMPERATURE_LOW: 0.03,
5 | COLOR_TEMPERATURE_HIGH: 0.77,
6 | } as const;
7 |
8 | export default function DualMetricsLegend() {
9 | const greenScale = d3
10 | .scaleSequential((t) =>
11 | d3.interpolateGreens(
12 | CONFIG.COLOR_TEMPERATURE_LOW + CONFIG.COLOR_TEMPERATURE_HIGH * t
13 | )
14 | )
15 | .domain([0, 1]);
16 | const accuracyGradient = `linear-gradient(to right, ${greenScale(
17 | 0
18 | )} 0%, ${greenScale(1)} 100%)`;
19 |
20 | const blueScale = d3
21 | .scaleSequential((t) =>
22 | d3.interpolateBlues(
23 | CONFIG.COLOR_TEMPERATURE_LOW + CONFIG.COLOR_TEMPERATURE_HIGH * t
24 | )
25 | )
26 | .domain([0, 1]);
27 | const efficiencyGradient = `linear-gradient(to right, ${blueScale(
28 | 0
29 | )} 0%, ${blueScale(1)} 100%)`;
30 |
31 | const orangeScale = d3
32 | .scaleSequential((t) =>
33 | d3.interpolateOranges(
34 | CONFIG.COLOR_TEMPERATURE_LOW + CONFIG.COLOR_TEMPERATURE_HIGH * t
35 | )
36 | )
37 | .domain([0, 1]);
38 | const forgettingQualityGradient = `linear-gradient(to right, ${orangeScale(
39 | 0
40 | )} 0%, ${orangeScale(1)} 100%)`;
41 |
42 | return (
43 |
44 |
45 |
Accuracy
46 |
47 |
51 |
52 | Worst
53 |
54 |
55 | Best
56 |
57 |
58 |
59 |
60 |
61 |
Efficiency
62 |
63 |
67 |
68 | Worst
69 |
70 |
71 | Best
72 |
73 |
74 |
75 |
76 |
77 |
78 | Privacy
79 |
80 |
81 |
85 |
86 | 0
87 |
88 |
89 | 1
90 |
91 |
92 |
93 |
94 |
95 | );
96 | }
97 |
--------------------------------------------------------------------------------
/frontend/src/components/ModelScreening/Experiments/MethodFilterHeader.tsx:
--------------------------------------------------------------------------------
1 | import { useState, useRef, useEffect } from "react";
2 | import { createPortal } from "react-dom";
3 |
4 | import { FilterIcon } from "../../UI/icons";
5 | import { UNLEARNING_METHODS } from "../../../constants/experiments";
6 | import { cn } from "../../../utils/util";
7 |
8 | const MOUSEDOWN = "mousedown";
9 |
10 | export default function MethodFilterHeader({ column }: { column: any }) {
11 | const [filterValues, setFilterValues] = useState([]);
12 | const [showDropdown, setShowDropdown] = useState(false);
13 | const [dropdownCoords, setDropdownCoords] = useState({ top: 0, left: 0 });
14 |
15 | const containerRef = useRef(null);
16 | const dropdownRef = useRef(null);
17 |
18 | useEffect(() => {
19 | if (showDropdown && containerRef.current) {
20 | const rect = containerRef.current.getBoundingClientRect();
21 | setDropdownCoords({
22 | top: rect.bottom,
23 | left: rect.left,
24 | });
25 | }
26 | }, [showDropdown]);
27 |
28 | const handleSelect = (value: string) => () => {
29 | let newFilterValues: string[];
30 |
31 | if (filterValues.includes(value)) {
32 | newFilterValues = filterValues.filter((val) => val !== value);
33 | } else if (value === "") {
34 | newFilterValues = [];
35 | } else {
36 | newFilterValues = [...filterValues, value];
37 | }
38 |
39 | setFilterValues(newFilterValues);
40 | column.setFilterValue(newFilterValues);
41 | setShowDropdown(false);
42 | };
43 |
44 | useEffect(() => {
45 | function handleClickOutside(event: MouseEvent) {
46 | if (
47 | containerRef.current &&
48 | !containerRef.current.contains(event.target as Node) &&
49 | dropdownRef.current &&
50 | !dropdownRef.current.contains(event.target as Node)
51 | ) {
52 | setShowDropdown(false);
53 | }
54 | }
55 |
56 | if (showDropdown) {
57 | document.addEventListener(MOUSEDOWN, handleClickOutside);
58 | } else {
59 | document.removeEventListener(MOUSEDOWN, handleClickOutside);
60 | }
61 | return () => {
62 | document.removeEventListener(MOUSEDOWN, handleClickOutside);
63 | };
64 | }, [showDropdown]);
65 |
66 | return (
67 |
68 |
Method
69 |
setShowDropdown(true)}
73 | />
74 | {showDropdown &&
75 | createPortal(
76 |
84 |
88 | All
89 |
90 | {UNLEARNING_METHODS.map((method) => (
91 |
101 | {method}
102 |
103 | ))}
104 |
,
105 | document.body
106 | )}
107 |
108 | );
109 | }
110 |
--------------------------------------------------------------------------------
/frontend/src/components/ModelScreening/Experiments/MethodUnlearning.tsx:
--------------------------------------------------------------------------------
1 | import React, { useState, useEffect } from "react";
2 |
3 | import HyperparameterInput from "./HyperparameterInput";
4 | import {
5 | EPOCH,
6 | LEARNING_RATE,
7 | BATCH_SIZE,
8 | } from "../../../constants/experiments";
9 | import { Badge } from "../../UI/badge";
10 | import { getDefaultUnlearningConfig } from "../../../utils/config/unlearning";
11 |
12 | interface Props extends React.InputHTMLAttributes {
13 | method: string;
14 | epochsList: string[];
15 | learningRateList: string[];
16 | batchSizeList: string[];
17 | setEpoch: (epoch: string[]) => void;
18 | setLearningRate: (lr: string[]) => void;
19 | setBatchSize: (bs: string[]) => void;
20 | onPlusClick: (id: string, value: string) => void;
21 | onBadgeClick: (event: React.MouseEvent) => void;
22 | }
23 |
24 | export default function MethodUnlearning({
25 | method,
26 | epochsList,
27 | learningRateList,
28 | batchSizeList,
29 | setEpoch,
30 | setLearningRate,
31 | setBatchSize,
32 | onPlusClick,
33 | onBadgeClick,
34 | ...props
35 | }: Props) {
36 | const [initialValues, setInitialValues] = useState(["10", "0.01", "64"]);
37 |
38 | useEffect(() => {
39 | const { epoch, learning_rate, batch_size } =
40 | getDefaultUnlearningConfig(method);
41 |
42 | setInitialValues([epoch, learning_rate, batch_size]);
43 | setEpoch([epoch]);
44 | setLearningRate([learning_rate]);
45 | setBatchSize([batch_size]);
46 | }, [method, setBatchSize, setEpoch, setLearningRate]);
47 |
48 | return (
49 |
50 |
Hyperparameters
51 |
52 |
53 |
60 | {epochsList.length > 0 && (
61 |
62 | {epochsList.map((epoch, idx) => (
63 |
69 | {epoch}
70 |
71 | ))}
72 |
73 | )}
74 |
75 |
76 |
83 | {learningRateList.length > 0 && (
84 |
85 | {learningRateList.map((rate, idx) => (
86 |
92 | {rate}
93 |
94 | ))}
95 |
96 | )}
97 |
98 |
99 |
106 | {batchSizeList.length > 0 && (
107 |
108 | {batchSizeList.map((batch, idx) => (
109 |
115 | {batch}
116 |
117 | ))}
118 |
119 | )}
120 |
121 |
122 |
123 | );
124 | }
125 |
--------------------------------------------------------------------------------
/frontend/src/components/ModelScreening/Experiments/TableHeader.tsx:
--------------------------------------------------------------------------------
1 | import { Table as TableType, flexRender } from "@tanstack/react-table";
2 |
3 | import { Table, TableHead, TableHeader, TableRow } from "../../UI/table";
4 | import { COLUMN_WIDTHS } from "./Columns";
5 | import { ExperimentData } from "../../../types/data";
6 |
7 | interface Props {
8 | table: TableType;
9 | }
10 |
11 | export default function _TableHeader({ table }: Props) {
12 | return (
13 |
14 |
15 | {table.getHeaderGroups().map((headerGroup) => (
16 |
17 | {headerGroup.headers.map((header) => {
18 | const columnWidth =
19 | COLUMN_WIDTHS[header.column.id as keyof typeof COLUMN_WIDTHS];
20 | return (
21 |
28 | {header.isPlaceholder
29 | ? null
30 | : flexRender(
31 | header.column.columnDef.header,
32 | header.getContext()
33 | )}
34 |
35 | );
36 | })}
37 |
38 | ))}
39 |
40 |
41 | );
42 | }
43 |
--------------------------------------------------------------------------------
/frontend/src/components/ModelScreening/Progress/AddModelsButton.tsx:
--------------------------------------------------------------------------------
1 | import { useState, useEffect } from "react";
2 |
3 | import UnlearningConfiguration from "./UnlearningConfiguration";
4 | import Button from "../../common/CustomButton";
5 | import { PlusIcon } from "../../UI/icons";
6 | import { useRunningStatusStore } from "../../../stores/runningStatusStore";
7 | import { cn } from "../../../utils/util";
8 | import {
9 | Dialog,
10 | DialogContent,
11 | DialogHeader,
12 | DialogTitle,
13 | DialogTrigger,
14 | } from "../../UI/dialog";
15 |
16 | export default function AddExperimentsButton() {
17 | const { isRunning } = useRunningStatusStore();
18 |
19 | const [open, setOpen] = useState(false);
20 |
21 | useEffect(() => {
22 | if (isRunning) setOpen(false);
23 | }, [isRunning]);
24 |
25 | return (
26 |
45 | );
46 | }
47 |
--------------------------------------------------------------------------------
/frontend/src/components/ModelScreening/Progress/Pagination.tsx:
--------------------------------------------------------------------------------
1 | import React from "react";
2 |
3 | import {
4 | Pagination,
5 | PaginationContent,
6 | PaginationItem,
7 | PaginationNext,
8 | PaginationPrevious,
9 | } from "../../UI/pagination";
10 | import { PREV, NEXT } from "../../../views/ModelScreening/Progress";
11 | import { useRunningStatusStore } from "../../../stores/runningStatusStore";
12 |
13 | interface Props extends React.LiHTMLAttributes {
14 | currentPage: number;
15 | }
16 |
17 | export default function ProgressPagination({ currentPage, ...props }: Props) {
18 | const { totalExperimentsCount } = useRunningStatusStore();
19 |
20 | return (
21 |
22 |
23 |
24 |
25 |
26 |
27 | {currentPage} / {totalExperimentsCount}
28 |
29 |
30 |
31 |
32 |
33 |
34 | );
35 | }
36 |
--------------------------------------------------------------------------------
/frontend/src/components/ModelScreening/Progress/Stepper.tsx:
--------------------------------------------------------------------------------
1 | import { memo } from "react";
2 | import { Check, Dot, Loader2 } from "lucide-react";
3 |
4 | import {
5 | Stepper,
6 | StepperDescription,
7 | StepperItem,
8 | StepperSeparator,
9 | StepperTitle,
10 | StepperTrigger,
11 | } from "../../UI/stepper";
12 | import { Button } from "../../UI/button";
13 | import { Step } from "../../../views/ModelScreening/Progress";
14 |
15 | const _Stepper = memo(function _Stepper({
16 | steps,
17 | activeStep,
18 | completedSteps,
19 | isRunning,
20 | }: {
21 | steps: Step[];
22 | activeStep: number;
23 | completedSteps: number[];
24 | isRunning: boolean;
25 | }) {
26 | return (
27 |
28 | {steps.map((step, idx) => {
29 | const isNotLastStep = idx !== steps.length - 1;
30 |
31 | let state: "completed" | "active" | "inactive";
32 | if (step.step === activeStep) {
33 | state = "active";
34 | } else if (completedSteps.includes(step.step)) {
35 | state = "completed";
36 | } else {
37 | state = "inactive";
38 | }
39 |
40 | return (
41 |
45 | {isNotLastStep && (
46 |
47 |
48 |
49 | )}
50 |
51 |
64 |
65 |
66 |
67 | {step.title}
68 |
69 |
70 | {step.description.split("\n").map((el, idx) => (
71 |
72 | {el
73 | .split("**")
74 | .map((part, partIdx) =>
75 | partIdx % 2 === 1 ? (
76 | {part}
77 | ) : (
78 | part
79 | )
80 | )}
81 |
82 | ))}
83 |
84 |
85 |
86 | );
87 | })}
88 |
89 | );
90 | });
91 |
92 | export default _Stepper;
93 |
--------------------------------------------------------------------------------
/frontend/src/components/UI/badge.tsx:
--------------------------------------------------------------------------------
1 | import * as React from "react";
2 | import { cva, type VariantProps } from "class-variance-authority";
3 |
4 | import { cn } from "../../utils/util";
5 |
6 | const badgeVariants = cva(
7 | "inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold transition-colors focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2",
8 | {
9 | variants: {
10 | variant: {
11 | default:
12 | "border-transparent bg-primary text-primary-foreground hover:bg-primary/80",
13 | secondary:
14 | "border-transparent bg-secondary text-secondary-foreground hover:bg-secondary/80",
15 | destructive:
16 | "border-transparent bg-destructive text-destructive-foreground hover:bg-destructive/80",
17 | outline: "text-foreground",
18 | },
19 | },
20 | defaultVariants: {
21 | variant: "default",
22 | },
23 | }
24 | );
25 |
26 | export interface BadgeProps
27 | extends React.HTMLAttributes,
28 | VariantProps {}
29 |
30 | function Badge({ className, variant, ...props }: BadgeProps) {
31 | return (
32 |
33 | );
34 | }
35 |
36 | export { Badge, badgeVariants };
37 |
--------------------------------------------------------------------------------
/frontend/src/components/UI/button.tsx:
--------------------------------------------------------------------------------
1 | import React from "react";
2 | import { Slot } from "@radix-ui/react-slot";
3 | import { cva, type VariantProps } from "class-variance-authority";
4 |
5 | import { cn } from "../../utils/util";
6 |
7 | const buttonVariants = cva(
8 | "inline-flex items-center justify-center whitespace-nowrap rounded-md text-sm font-medium ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50",
9 | {
10 | variants: {
11 | variant: {
12 | default: "bg-primary text-primary-foreground hover:bg-primary/90",
13 | destructive:
14 | "bg-destructive text-destructive-foreground hover:bg-destructive/90",
15 | outline:
16 | "border border-input bg-background hover:bg-accent hover:text-accent-foreground",
17 | secondary:
18 | "bg-secondary text-secondary-foreground hover:bg-secondary/80",
19 | ghost: "hover:bg-accent hover:text-accent-foreground",
20 | link: "text-primary underline-offset-4 hover:underline",
21 | },
22 | size: {
23 | default: "h-10 px-4 py-2",
24 | sm: "h-9 rounded-md px-3",
25 | lg: "h-11 rounded-md px-8",
26 | icon: "h-10 w-10",
27 | },
28 | },
29 | defaultVariants: {
30 | variant: "default",
31 | size: "default",
32 | },
33 | }
34 | );
35 |
36 | export interface ButtonProps
37 | extends React.ButtonHTMLAttributes,
38 | VariantProps {
39 | asChild?: boolean;
40 | }
41 |
42 | const Button = React.forwardRef(
43 | ({ className, variant, size, asChild = false, ...props }, ref) => {
44 | const Comp = asChild ? Slot : "button";
45 | return (
46 |
51 | );
52 | }
53 | );
54 | Button.displayName = "Button";
55 |
56 | export { Button, buttonVariants };
57 |
--------------------------------------------------------------------------------
/frontend/src/components/UI/checkbox.tsx:
--------------------------------------------------------------------------------
1 | import * as React from "react";
2 | import * as CheckboxPrimitive from "@radix-ui/react-checkbox";
3 | import { Check } from "lucide-react";
4 |
5 | import { cn } from "../../utils/util";
6 |
7 | const Checkbox = React.forwardRef<
8 | React.ElementRef,
9 | React.ComponentPropsWithoutRef
10 | >(({ className, ...props }, ref) => (
11 |
19 |
22 |
23 |
24 |
25 | ));
26 | Checkbox.displayName = CheckboxPrimitive.Root.displayName;
27 |
28 | export { Checkbox };
29 |
--------------------------------------------------------------------------------
/frontend/src/components/UI/dialog.tsx:
--------------------------------------------------------------------------------
1 | import * as React from "react";
2 | import * as DialogPrimitive from "@radix-ui/react-dialog";
3 | import { X } from "lucide-react";
4 |
5 | import { cn } from "../../utils/util";
6 |
7 | const Dialog = DialogPrimitive.Root;
8 |
9 | const DialogTrigger = DialogPrimitive.Trigger;
10 |
11 | const DialogPortal = DialogPrimitive.Portal;
12 |
13 | const DialogClose = DialogPrimitive.Close;
14 |
15 | const DialogOverlay = React.forwardRef<
16 | React.ElementRef,
17 | React.ComponentPropsWithoutRef
18 | >(({ className, ...props }, ref) => (
19 |
27 | ));
28 | DialogOverlay.displayName = DialogPrimitive.Overlay.displayName;
29 |
30 | const DialogContent = React.forwardRef<
31 | React.ElementRef,
32 | React.ComponentPropsWithoutRef
33 | >(({ className, children, ...props }, ref) => (
34 |
35 |
36 |
44 | {children}
45 |
46 |
47 | Close
48 |
49 |
50 |
51 | ));
52 | DialogContent.displayName = DialogPrimitive.Content.displayName;
53 |
54 | const DialogHeader = ({
55 | className,
56 | ...props
57 | }: React.HTMLAttributes) => (
58 |
65 | );
66 | DialogHeader.displayName = "DialogHeader";
67 |
68 | const DialogFooter = ({
69 | className,
70 | ...props
71 | }: React.HTMLAttributes) => (
72 |
79 | );
80 | DialogFooter.displayName = "DialogFooter";
81 |
82 | const DialogTitle = React.forwardRef<
83 | React.ElementRef,
84 | React.ComponentPropsWithoutRef
85 | >(({ className, ...props }, ref) => (
86 |
94 | ));
95 | DialogTitle.displayName = DialogPrimitive.Title.displayName;
96 |
97 | const DialogDescription = React.forwardRef<
98 | React.ElementRef,
99 | React.ComponentPropsWithoutRef
100 | >(({ className, ...props }, ref) => (
101 |
106 | ));
107 | DialogDescription.displayName = DialogPrimitive.Description.displayName;
108 |
109 | export {
110 | Dialog,
111 | DialogPortal,
112 | DialogOverlay,
113 | DialogClose,
114 | DialogTrigger,
115 | DialogContent,
116 | DialogHeader,
117 | DialogFooter,
118 | DialogTitle,
119 | DialogDescription,
120 | };
121 |
--------------------------------------------------------------------------------
/frontend/src/components/UI/hover-card.tsx:
--------------------------------------------------------------------------------
1 | import * as React from "react";
2 | import * as HoverCardPrimitive from "@radix-ui/react-hover-card";
3 |
4 | import { cn } from "../../utils/util";
5 |
6 | const HoverCard = HoverCardPrimitive.Root;
7 |
8 | const HoverCardTrigger = HoverCardPrimitive.Trigger;
9 |
10 | const HoverCardContent = React.forwardRef<
11 | React.ElementRef,
12 | React.ComponentPropsWithoutRef
13 | >(({ className, align = "center", sideOffset = 4, ...props }, ref) => (
14 |
24 | ));
25 | HoverCardContent.displayName = HoverCardPrimitive.Content.displayName;
26 |
27 | export { HoverCard, HoverCardTrigger, HoverCardContent };
28 |
--------------------------------------------------------------------------------
/frontend/src/components/UI/input.tsx:
--------------------------------------------------------------------------------
1 | import React from "react";
2 |
3 | import { cn } from "../../utils/util";
4 |
5 | export interface InputProps
6 | extends React.InputHTMLAttributes {}
7 |
8 | const Input = React.forwardRef(
9 | ({ className, type, ...props }, ref) => {
10 | return (
11 |
20 | );
21 | }
22 | );
23 | Input.displayName = "Input";
24 |
25 | export { Input };
26 |
--------------------------------------------------------------------------------
/frontend/src/components/UI/label.tsx:
--------------------------------------------------------------------------------
1 | import React from "react";
2 | import * as LabelPrimitive from "@radix-ui/react-label";
3 | import { cva, type VariantProps } from "class-variance-authority";
4 |
5 | import { cn } from "../../utils/util";
6 |
7 | const labelVariants = cva(
8 | "text-sm leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70"
9 | );
10 |
11 | const Label = React.forwardRef<
12 | React.ElementRef,
13 | React.ComponentPropsWithoutRef &
14 | VariantProps
15 | >(({ className, ...props }, ref) => (
16 |
21 | ));
22 | Label.displayName = LabelPrimitive.Root.displayName;
23 |
24 | export { Label };
25 |
--------------------------------------------------------------------------------
/frontend/src/components/UI/pagination.tsx:
--------------------------------------------------------------------------------
1 | import * as React from "react";
2 | import { ChevronLeft, ChevronRight, MoreHorizontal } from "lucide-react";
3 |
4 | import { cn } from "../../utils/util";
5 | import { ButtonProps, buttonVariants } from "./button";
6 |
7 | const Pagination = ({ className, ...props }: React.ComponentProps<"nav">) => ( // Pagination: root container; takes nav-element props (JSX body missing from this dump).
8 |
14 | );
15 | Pagination.displayName = "Pagination";
16 |
17 | const PaginationContent = React.forwardRef< // PaginationContent: ref-forwarding container typed with ul-element props.
18 | HTMLUListElement,
19 | React.ComponentProps<"ul">
20 | >(({ className, ...props }, ref) => (
21 |
26 | ));
27 | PaginationContent.displayName = "PaginationContent";
28 |
29 | const PaginationItem = React.forwardRef< // PaginationItem: ref-forwarding item typed with li-element props.
30 | HTMLLIElement,
31 | React.ComponentProps<"li">
32 | >(({ className, ...props }, ref) => (
33 |
34 | ));
35 | PaginationItem.displayName = "PaginationItem";
36 |
37 | type PaginationLinkProps = { // Optional isActive flag + anchor props; the Pick arguments (presumably ButtonProps' "size") are not visible in this dump.
38 | isActive?: boolean;
39 | } & Pick &
40 | React.ComponentProps<"a">;
41 |
42 | const PaginationLink = ({ // PaginationLink: page link; size defaults to "icon"; renders children or a fallback (fallback missing from this dump).
43 | className,
44 | isActive,
45 | size = "icon",
46 | children,
47 | ...props
48 | }: PaginationLinkProps & { children?: React.ReactNode }) => (
49 |
60 | {children || }
61 |
62 | );
63 | PaginationLink.displayName = "PaginationLink";
64 |
65 | const PaginationPrevious = ({ // PaginationPrevious: "previous page" control; its props type argument is not visible in this dump.
66 | className,
67 | ...props
68 | }: React.ComponentProps) => (
69 |
75 |
76 | {/* Previous */}
77 |
78 | );
79 | PaginationPrevious.displayName = "PaginationPrevious";
80 |
81 | const PaginationNext = ({ // PaginationNext: "next page" control; its props type argument is not visible in this dump.
82 | className,
83 | ...props
84 | }: React.ComponentProps) => (
85 |
91 | {/* Next */}
92 |
93 |
94 | );
95 | PaginationNext.displayName = "PaginationNext";
96 |
97 | const PaginationEllipsis = ({ // PaginationEllipsis: skipped-pages marker; "More pages" below is presumably screen-reader text — confirm.
98 | className,
99 | ...props
100 | }: React.ComponentProps<"span">) => (
101 |
106 |
107 | More pages
108 |
109 | );
110 | PaginationEllipsis.displayName = "PaginationEllipsis";
111 |
112 | export {
113 | Pagination,
114 | PaginationContent,
115 | PaginationEllipsis,
116 | PaginationItem,
117 | PaginationLink,
118 | PaginationNext,
119 | PaginationPrevious,
120 | };
121 |
--------------------------------------------------------------------------------
/frontend/src/components/UI/radio-group.tsx:
--------------------------------------------------------------------------------
1 | import * as React from "react";
2 | import * as RadioGroupPrimitive from "@radix-ui/react-radio-group";
3 | import { Circle } from "lucide-react";
4 |
5 | import { cn } from "../../utils/util";
6 |
7 | const RadioGroup = React.forwardRef< // RadioGroup: ref-forwarding wrapper for the Radix radio-group root (JSX body missing from this dump).
8 | React.ElementRef,
9 | React.ComponentPropsWithoutRef
10 | >(({ className, ...props }, ref) => {
11 | return (
12 |
17 | );
18 | });
19 | RadioGroup.displayName = RadioGroupPrimitive.Root.displayName;
20 |
21 | const RadioGroupItem = React.forwardRef< // RadioGroupItem: single option; JSX body missing from this dump.
22 | React.ElementRef,
23 | React.ComponentPropsWithoutRef
24 | >(({ className, color, ...props }, ref) => { // color is destructured so it is not forwarded via ...props; its use is in the missing JSX.
25 | return (
26 |
41 |
42 |
43 |
44 |
45 | );
46 | });
47 | RadioGroupItem.displayName = RadioGroupPrimitive.Item.displayName;
48 |
49 | export { RadioGroup, RadioGroupItem };
50 |
--------------------------------------------------------------------------------
/frontend/src/components/UI/scroll-area.tsx:
--------------------------------------------------------------------------------
1 | import React from "react";
2 | import * as ScrollAreaPrimitive from "@radix-ui/react-scroll-area";
3 |
4 | import { cn } from "../../utils/util";
5 |
6 | const ScrollArea = React.forwardRef< // ScrollArea: ref-forwarding scroll container that renders children (JSX body missing from this dump).
7 | React.ElementRef,
8 | React.ComponentPropsWithoutRef
9 | >(({ className, children, ...props }, ref) => (
10 |
15 |
16 | {children}
17 |
18 |
19 |
20 |
21 | ));
22 | ScrollArea.displayName = ScrollAreaPrimitive.Root.displayName;
23 |
24 | const ScrollBar = React.forwardRef< // ScrollBar: scrollbar track; orientation defaults to "vertical" (JSX body missing from this dump).
25 | React.ElementRef,
26 | React.ComponentPropsWithoutRef
27 | >(({ className, orientation = "vertical", ...props }, ref) => (
28 |
41 |
42 |
43 | ));
44 | ScrollBar.displayName = ScrollAreaPrimitive.ScrollAreaScrollbar.displayName;
45 |
46 | export { ScrollArea, ScrollBar };
47 |
--------------------------------------------------------------------------------
/frontend/src/components/UI/separator.tsx:
--------------------------------------------------------------------------------
1 | import * as React from "react";
2 | import * as SeparatorPrimitive from "@radix-ui/react-separator";
3 |
4 | import { cn } from "../../utils/util";
5 |
6 | const Separator = React.forwardRef< // Separator: ref-forwarding divider (JSX body missing from this dump).
7 | React.ElementRef,
8 | React.ComponentPropsWithoutRef
9 | >(
10 | (
11 | { className, orientation = "horizontal", decorative = true, ...props }, // Defaults: horizontal, decorative.
12 | ref
13 | ) => (
14 |
25 | )
26 | );
27 | Separator.displayName = SeparatorPrimitive.Root.displayName;
28 |
29 | export { Separator };
30 |
--------------------------------------------------------------------------------
/frontend/src/components/UI/slider.tsx:
--------------------------------------------------------------------------------
1 | import React from "react";
2 | import * as SliderPrimitive from "@radix-ui/react-slider";
3 |
4 | import { cn } from "../../utils/util";
5 |
6 | const Slider = React.forwardRef< // Slider: ref-forwarding wrapper for the Radix slider root (JSX body missing from this dump).
7 | React.ElementRef,
8 | React.ComponentPropsWithoutRef
9 | >(({ className, ...props }, ref) => (
10 |
18 |
19 |
20 |
21 |
22 |
23 | ));
24 | Slider.displayName = SliderPrimitive.Root.displayName;
25 |
26 | export { Slider };
27 |
--------------------------------------------------------------------------------
/frontend/src/components/UI/stepper.tsx:
--------------------------------------------------------------------------------
1 | import { cn } from "../../utils/util";
2 |
3 | interface StepperProps { // Shared props for every Stepper part: optional className plus required children.
4 | className?: string;
5 | children: React.ReactNode;
6 | }
7 |
8 | const Stepper = ({ className, children, ...delegated }: StepperProps) => { // Stepper: outer container; delegated is empty at the type level since StepperProps declares only className/children.
9 | return (
10 |
11 | {children}
12 |
13 | );
14 | };
15 |
16 | const StepperItem = ({ className, children, ...delegated }: StepperProps) => { // StepperItem: one step wrapper (JSX body missing from this dump).
17 | return (
18 |
25 | {children}
26 |
27 | );
28 | };
29 |
30 | const StepperTrigger = ({ // StepperTrigger: interactive part of a step (JSX body missing from this dump).
31 | className,
32 | children,
33 | ...delegated
34 | }: StepperProps) => {
35 | return (
36 |
45 | );
46 | };
47 |
48 | const StepperSeparator = ({ className, ...delegated }: StepperProps) => { // NOTE(review): StepperProps makes children required here, and children lands in delegated — confirm whether it is meant to be forwarded.
49 | return (
50 |
61 | );
62 | };
63 |
64 | const StepperTitle = ({ className, children, ...delegated }: StepperProps) => { // StepperTitle: step heading text.
65 | return (
66 |
70 | {children}
71 |
72 | );
73 | };
74 |
75 | const StepperDescription = ({ // StepperDescription: secondary step text.
76 | className,
77 | children,
78 | ...delegated
79 | }: StepperProps) => {
80 | return (
81 |
85 | {children}
86 |
87 | );
88 | };
89 |
90 | export {
91 | Stepper,
92 | StepperItem,
93 | StepperTrigger,
94 | StepperSeparator,
95 | StepperTitle,
96 | StepperDescription,
97 | };
98 |
--------------------------------------------------------------------------------
/frontend/src/components/UI/table.tsx:
--------------------------------------------------------------------------------
1 | import React from "react";
2 |
3 | import { cn } from "../../utils/util";
4 |
5 | const Table = React.forwardRef< // Table: ref-forwarding table element wrapper (JSX body missing from this dump).
6 | HTMLTableElement,
7 | React.HTMLAttributes
8 | >(({ className, ...props }, ref) => (
9 |
16 | ));
17 | Table.displayName = "Table";
18 |
19 | const TableHeader = React.forwardRef< // TableHeader: table section element (header).
20 | HTMLTableSectionElement,
21 | React.HTMLAttributes
22 | >(({ className, ...props }, ref) => (
23 |
24 | ));
25 | TableHeader.displayName = "TableHeader";
26 |
27 | const TableBody = React.forwardRef< // TableBody: table section element (body).
28 | HTMLTableSectionElement,
29 | React.HTMLAttributes
30 | >(({ className, ...props }, ref) => (
31 |
36 | ));
37 | TableBody.displayName = "TableBody";
38 |
39 | const TableFooter = React.forwardRef< // TableFooter: table section element (footer); part of its JSX/class string is truncated in this dump.
40 | HTMLTableSectionElement,
41 | React.HTMLAttributes
42 | >(({ className, ...props }, ref) => (
43 | tr]:last:border-b-0",
47 | className
48 | )}
49 | {...props}
50 | />
51 | ));
52 | TableFooter.displayName = "TableFooter";
53 |
54 | const TableRow = React.forwardRef< // TableRow: table row element wrapper.
55 | HTMLTableRowElement,
56 | React.HTMLAttributes
57 | >(({ className, ...props }, ref) => (
58 |
66 | ));
67 | TableRow.displayName = "TableRow";
68 |
69 | const TableHead = React.forwardRef< // TableHead: header cell, typed with th attributes.
70 | HTMLTableCellElement,
71 | React.ThHTMLAttributes
72 | >(({ className, ...props }, ref) => (
73 | |
81 | ));
82 | TableHead.displayName = "TableHead";
83 |
84 | const TableCell = React.forwardRef< // TableCell: data cell, typed with td attributes.
85 | HTMLTableCellElement,
86 | React.TdHTMLAttributes
87 | >(({ className, ...props }, ref) => (
88 | |
96 | ));
97 | TableCell.displayName = "TableCell";
98 |
99 | const TableCaption = React.forwardRef< // TableCaption: table caption element wrapper.
100 | HTMLTableCaptionElement,
101 | React.HTMLAttributes
102 | >(({ className, ...props }, ref) => (
103 |
108 | ));
109 | TableCaption.displayName = "TableCaption";
110 |
111 | export {
112 | Table,
113 | TableHeader,
114 | TableBody,
115 | TableFooter,
116 | TableHead,
117 | TableRow,
118 | TableCell,
119 | TableCaption,
120 | };
121 |
--------------------------------------------------------------------------------
/frontend/src/components/UI/tabs.tsx:
--------------------------------------------------------------------------------
1 | import * as React from "react";
2 | import * as TabsPrimitive from "@radix-ui/react-tabs";
3 |
4 | import { cn } from "../../utils/util";
5 |
6 | const Tabs = TabsPrimitive.Root; // Tabs is the Radix root component, re-exported unchanged.
7 |
8 | const TabsList = React.forwardRef< // TabsList: ref-forwarding tab list (JSX body missing from this dump).
9 | React.ElementRef,
10 | React.ComponentPropsWithoutRef
11 | >(({ className, ...props }, ref) => (
12 |
20 | ));
21 | TabsList.displayName = TabsPrimitive.List.displayName;
22 |
23 | const TabsTrigger = React.forwardRef< // TabsTrigger: ref-forwarding tab button (JSX body missing from this dump).
24 | React.ElementRef,
25 | React.ComponentPropsWithoutRef
26 | >(({ className, ...props }, ref) => (
27 |
35 | ));
36 | TabsTrigger.displayName = TabsPrimitive.Trigger.displayName;
37 |
38 | const TabsContent = React.forwardRef< // TabsContent: ref-forwarding tab panel (JSX body missing from this dump).
39 | React.ElementRef,
40 | React.ComponentPropsWithoutRef
41 | >(({ className, ...props }, ref) => (
42 |
50 | ));
51 | TabsContent.displayName = TabsPrimitive.Content.displayName;
52 |
53 | export { Tabs, TabsList, TabsTrigger, TabsContent };
54 |
--------------------------------------------------------------------------------
/frontend/src/components/common/CustomButton.tsx:
--------------------------------------------------------------------------------
1 | import React from "react";
2 |
3 | import { Button } from "../UI/button";
4 | import { cn } from "../../utils/util";
5 |
6 | interface Props // Button props minus className (Omit's first type argument is not visible in this dump), plus children and an optional className.
7 | extends Omit, "className"> {
8 | children: React.ReactNode;
9 | className?: string;
10 | }
11 |
12 | export default function CustomButton({ children, className, ...props }: Props) { // Wraps the UI Button; remaining props are spread (JSX body missing from this dump).
13 | return (
14 |
23 | );
24 | }
25 |
--------------------------------------------------------------------------------
/frontend/src/components/common/DatasetModeSelector.tsx:
--------------------------------------------------------------------------------
1 | import { Label } from "../UI/label";
2 | import { RadioGroup, RadioGroupItem } from "../UI/radio-group";
3 | import { TRAIN, TEST } from "../../constants/common";
4 |
5 | interface Props { // dataset is the currently selected split value; onValueChange receives the newly chosen value.
6 | dataset: string;
7 | onValueChange: (value: string) => void;
8 | }
9 |
10 | export default function DatasetModeSelector({ dataset, onValueChange }: Props) { // Radio selector between TRAIN and TEST (JSX body missing from this dump).
11 | const isTrainChecked = dataset === TRAIN; // true when the TRAIN option is the current value.
12 |
13 | return (
14 |
15 |
20 |
21 |
28 |
31 |
32 |
33 |
40 |
43 |
44 |
45 |
46 | );
47 | }
48 |
--------------------------------------------------------------------------------
/frontend/src/components/common/Indicator.tsx:
--------------------------------------------------------------------------------
1 | import { cn } from "../../utils/util";
2 |
3 | interface Props { // Either a custom text or one of two preset prompts selected by about.
4 | about?: "AB" | "ForgetClass";
5 | text?: string;
6 | }
7 |
8 | export default function Indicator({ about, text }: Props) { // Renders a hint message (JSX body missing from this dump).
9 | const content = text // A truthy text wins; otherwise "AB" picks the model prompt and anything else the forget-class prompt.
10 | ? text
11 | : about === "AB"
12 | ? "Select both model A and model B."
13 | : "Select the target forget class first.";
14 |
15 | return (
16 |
22 | {content}
23 |
24 | );
25 | }
26 |
--------------------------------------------------------------------------------
/frontend/src/components/common/Subtitle.tsx:
--------------------------------------------------------------------------------
1 | import { cn } from "../../utils/util";
2 |
3 | interface Props { // title text, optional trailing element (false renders nothing), optional extra classes.
4 | title: string;
5 | AdditionalContent?: JSX.Element | false;
6 | className?: string;
7 | }
8 |
9 | export default function Subtitle({ // Renders title followed by AdditionalContent (JSX wrapper missing from this dump).
10 | title,
11 | AdditionalContent,
12 | className,
13 | }: Props) {
14 | return (
15 |
21 | {title}
22 | {AdditionalContent}
23 |
24 | );
25 | }
26 |
--------------------------------------------------------------------------------
/frontend/src/components/common/Title.tsx:
--------------------------------------------------------------------------------
1 | import React from "react";
2 |
3 | import { cn } from "../../utils/util";
4 |
5 | interface Props { // title text plus optional id, classes, trailing element and click handler.
6 | title: string;
7 | id?: string;
8 | className?: string;
9 | AdditionalContent?: JSX.Element | false;
10 | onClick?: (e: React.MouseEvent) => void;
11 | }
12 |
13 | export default function Title({ // Renders title followed by AdditionalContent (JSX wrapper missing from this dump).
14 | title,
15 | id,
16 | className,
17 | AdditionalContent,
18 | onClick,
19 | }: Props) {
20 | return (
21 |
29 | {title}
30 | {AdditionalContent}
31 |
32 | );
33 | }
34 |
--------------------------------------------------------------------------------
/frontend/src/components/common/View.tsx:
--------------------------------------------------------------------------------
1 | import { cn } from "../../utils/util";
2 |
3 | interface Props { // Layout panel: size, extra classes and per-side border toggles.
4 | width?: number;
5 | height: number | string;
6 | className?: string;
7 | borderTop?: boolean;
8 | borderRight?: boolean;
9 | borderBottom?: boolean;
10 | borderLeft?: boolean;
11 | children: React.ReactNode;
12 | }
13 |
14 | export default function View({ // Panel wrapper; every border flag defaults to false (JSX body missing from this dump).
15 | width,
16 | height,
17 | className,
18 | borderTop = false,
19 | borderRight = false,
20 | borderBottom = false,
21 | borderLeft = false,
22 | children,
23 | }: Props) {
24 | return (
25 |
38 | );
39 | }
40 |
--------------------------------------------------------------------------------
/frontend/src/constants/colors.ts:
--------------------------------------------------------------------------------
1 | // Tableau 10 categorical palette. Order is significant: callers pick colours by index.
2 | export const TABLEAU10: string[] = [
3 |   "#4E79A7", // 0 blue
4 |   "#F28E2B", // 1 orange
5 |   "#E15759", // 2 red
6 |   "#76B7B2", // 3 teal
7 |   "#59A14F", // 4 green
8 |   "#EDC948", // 5 yellow
9 |   "#B07AA1", // 6 purple
10 |   "#FF9DA7", // 7 pink
11 |   "#9C755F", // 8 brown
12 |   "#BAB0AC", // 9 gray
13 | ];
14 |
15 | // Named UI colours; `as const` keeps every value as a readonly string literal type.
16 | export const COLORS = {
17 |   // Base
18 |   WHITE: "#FFF",
19 |   BLACK: "#000",
20 |   // Purple accent and its light tint
21 |   PURPLE: "#A855F7",
22 |   LIGHT_PURPLE: "#E6D0FD",
23 |   // Emerald accent and its light tint
24 |   EMERALD: "#10B981",
25 |   LIGHT_EMERALD: "#C8EBDA",
26 |   // Grays
27 |   GRAY: "#777",
28 |   DARK_GRAY: "#6a6a6a",
29 |   LIGHT_GRAY: "#D4D4D4",
30 |   // Chart grid lines
31 |   GRID_COLOR: "#EFEFEF",
32 |   // Button background, idle and hovered
33 |   BUTTON_BG_COLOR: "#585858",
34 |   HOVERED_BUTTON_BG_COLOR: "#696969",
35 | } as const;
36 |
--------------------------------------------------------------------------------
/frontend/src/constants/common.ts:
--------------------------------------------------------------------------------
1 | // Base URL of the backend API. It can be overridden at build time through the
2 | // REACT_APP_API_URL environment variable (assumes a Create React App build, which
3 | // inlines REACT_APP_* variables); otherwise the local dev server is used. Trailing
4 | // slashes are stripped so appending "/route" never produces "//".
5 | export const API_URL = (
6 |   process.env.REACT_APP_API_URL || "http://localhost:8000"
7 | ).replace(/\/+$/, "");
8 |
9 | // Options for the base configuration; the first entry of each is the default
10 | // picked by the base config store.
11 | export const DATASETS = ["CIFAR-10", "Fashion-MNIST"];
12 | export const NEURAL_NETWORK_MODELS = ["ResNet-18", "ViT-B/16"];
13 |
14 | // Dataset split identifiers.
15 | export const TRAIN = "train";
16 | export const TEST = "test";
17 |
18 | export const BASELINE = "baseline";
19 | export const COMPARISON = "comparison";
20 |
21 | // Privacy-attack metrics (the attack store starts with ENTROPY).
22 | export const ENTROPY = "entropy";
23 | export const CONFIDENCE = "confidence";
24 |
25 | // Privacy-attack directions (the attack store starts with UNLEARN).
26 | export const UNLEARN = "unlearn";
27 | export const RETRAIN = "retrain";
28 |
29 | // Class names; a class's position in the list is its numeric class index.
30 | export const CIFAR_10_CLASSES = [
31 |   "airplane",
32 |   "automobile",
33 |   "bird",
34 |   "cat",
35 |   "deer",
36 |   "dog",
37 |   "frog",
38 |   "horse",
39 |   "ship",
40 |   "truck",
41 | ];
42 |
43 | export const FASHION_MNIST_CLASSES = [
44 |   "T-shirt",
45 |   "Trouser",
46 |   "Pullover",
47 |   "Dress",
48 |   "Coat",
49 |   "Sandal",
50 |   "Shirt",
51 |   "Sneaker",
52 |   "Bag",
53 |   "Boot",
54 | ];
55 |
56 | // Presumably milliseconds — confirm against the animation consumers.
57 | export const ANIMATION_DURATION = 500;
58 |
59 | export const FONT_CONFIG = {
60 |   LIGHT_FONT_WEIGHT: 300,
61 |
62 |   FONT_SIZE_10: 10,
63 |   FONT_SIZE_12: 12,
64 |   FONT_SIZE_13: 13,
65 |   FONT_SIZE_14: 14,
66 |
67 |   ROBOTO_CONDENSED: "Roboto Condensed",
68 | };
69 |
70 | export const STROKE_CONFIG = {
71 |   DEFAULT_STROKE_WIDTH: 2,
72 |   LIGHT_STROKE_WIDTH: 1,
73 |   STROKE_DASHARRAY: "3 3", // dash/gap pattern string
74 | };
75 |
--------------------------------------------------------------------------------
/frontend/src/constants/correlations.ts:
--------------------------------------------------------------------------------
1 | import { COLORS } from "./colors";
2 | import { ChartConfig } from "../components/UI/chart";
3 |
4 | // Colour identity of each compared model, shared by the chart config and the legend.
5 | const MODEL_A_COLOR = COLORS.EMERALD;
6 | const MODEL_B_COLOR = COLORS.PURPLE;
7 |
8 | // Series keys of the CKA line chart: forget-class and other-class CKA per model.
9 | export const CKA_DATA_KEYS = [
10 |   "modelAForgetCka",
11 |   "modelAOtherCka",
12 |   "modelBForgetCka",
13 |   "modelBOtherCka",
14 | ];
15 |
16 | // CSS string forcing axis tick labels to black (presumably injected via a style tag).
17 | export const LINE_CHART_TICK_STYLE = `
18 | .recharts-cartesian-axis-tick text {
19 | fill: ${COLORS.BLACK} !important;
20 | }
21 | `;
22 |
23 | // Recharts chart config keyed by series; both series of a model share its colour.
24 | export const LINE_CHART_CONFIG = {
25 |   layer: { label: "Layer", color: COLORS.BLACK },
26 |   modelAForgetCka: { color: MODEL_A_COLOR },
27 |   modelAOtherCka: { color: MODEL_A_COLOR },
28 |   modelBForgetCka: { color: MODEL_B_COLOR },
29 |   modelBOtherCka: { color: MODEL_B_COLOR },
30 | } satisfies ChartConfig;
31 |
32 | // Legend entries: marker shape x model colour.
33 | // NOTE(review): circle vs cross presumably distinguishes the forget-class and
34 | // other-class series, but both rows read "Model A"/"Model B" — confirm labels.
35 | export const LINE_CHART_LEGEND_DATA = [
36 |   { type: "circle", color: MODEL_A_COLOR, label: "Model A", spacing: "py-0.5" },
37 |   { type: "circle", color: MODEL_B_COLOR, label: "Model B", spacing: "py-0.5" },
38 |   { type: "cross", color: MODEL_A_COLOR, label: "Model A", spacing: "py-0.5" },
39 |   { type: "cross", color: MODEL_B_COLOR, label: "Model B", spacing: "py-0.5" },
40 | ] as const;
41 |
--------------------------------------------------------------------------------
/frontend/src/constants/embeddings.ts:
--------------------------------------------------------------------------------
1 | import { ViewMode } from "../types/embeddings";
2 |
3 | // Embedding view modes in display order. `length` is presumably the width reserved
4 | // for the explanation tooltip — confirm against the consumer.
5 | export const VIEW_MODES: ViewMode[] = [
6 |   { label: "All", length: 60, explanation: "Shows all instances from both retain and forget classes." },
7 |   { label: "Target to Forget", length: 125, explanation: "Highlights all forget class instances that the model is supposed to unlearn." },
8 |   { label: "Correctly Forgotten", length: 145, explanation: "Highlights forget class instances that the model successfully unlearned and now misclassifies." },
9 |   { label: "Not Forgotten", length: 110, explanation: "Highlights forget class instances that the model failed to unlearn and still correctly classifies." },
10 |   { label: "Overly Forgotten", length: 130, explanation: "Highlights retain class instances that the model was not supposed to unlearn but did." },
11 | ];
12 |
--------------------------------------------------------------------------------
/frontend/src/constants/experiments.ts:
--------------------------------------------------------------------------------
1 | // Unlearning methods offered in the UI, in display order.
2 | export const UNLEARNING_METHODS = [
3 |   "Fine-Tuning",
4 |   "Random Labeling",
5 |   "Gradient Ascent",
6 |   "MyMethod1",
7 |   "Upload",
8 | ];
9 |
10 | // Keys of the hyperparameter fields.
11 | export const EPOCH = "epoch";
12 | export const LEARNING_RATE = "learningRate";
13 | export const BATCH_SIZE = "batchSize";
14 |
--------------------------------------------------------------------------------
/frontend/src/constants/privacyAttack.ts:
--------------------------------------------------------------------------------
1 | import { TABLEAU10 } from "./colors";
2 |
3 | // One legend row of the attack line graph.
4 | type LineGraphLegendData = {
5 |   color: string;
6 |   label: string;
7 | };
8 |
9 | // A selectable threshold strategy. `length` is presumably the width reserved for
10 | // the explanation tooltip — confirm against the consumer.
11 | type ThresholdStrategy = {
12 |   strategy: string;
13 |   explanation: string;
14 |   length: number;
15 | };
16 |
17 | // Threshold strategies in display order; the attack store defaults to the first.
18 | export const THRESHOLD_STRATEGIES: ThresholdStrategy[] = [
19 |   { strategy: "Custom Threshold", length: 150, explanation: "Manually set the threshold by dragging a slider for custom control." },
20 |   { strategy: "Max Attack Score", length: 155, explanation: "Maximizes the attack score based on false positive and false negative rates." },
21 |   { strategy: "Max Success Rate", length: 150, explanation: "Maximizes the probability of correctly identifying the model's type as retrained or unlearned." },
22 |   { strategy: "Common Threshold", length: 160, explanation: "Sets a single threshold that maximizes the sum of attack scores across two different models." },
23 | ];
24 |
25 | // Line colours are fixed Tableau 10 slots: red, blue, green.
26 | const [BLUE, , RED, , GREEN] = TABLEAU10;
27 |
28 | export const LINE_GRAPH_LEGEND_DATA: LineGraphLegendData[] = [
29 |   { color: RED, label: "Attack Score" },
30 |   { color: BLUE, label: "False Positive Rate" },
31 |   { color: GREEN, label: "False Negative Rate" },
32 | ];
33 |
--------------------------------------------------------------------------------
/frontend/src/hooks/useClasses.ts:
--------------------------------------------------------------------------------
1 | import { useBaseConfigStore } from "../stores/baseConfigStore";
2 | import { CIFAR_10_CLASSES, FASHION_MNIST_CLASSES } from "../constants/common";
3 |
4 | /**
5 |  * Returns the class-name list for the dataset selected in the base config store.
6 |  * Any dataset other than "CIFAR-10" resolves to the Fashion-MNIST classes.
7 |  */
8 | export function useClasses() {
9 |   const selectedDataset = useBaseConfigStore((state) => state.dataset);
10 |   if (selectedDataset === "CIFAR-10") {
11 |     return CIFAR_10_CLASSES;
12 |   }
13 |   return FASHION_MNIST_CLASSES;
14 | }
15 |
--------------------------------------------------------------------------------
/frontend/src/hooks/useModelExperiment.ts:
--------------------------------------------------------------------------------
1 | import { useExperimentsStore } from "../stores/experimentsStore";
2 | import { useModelDataStore } from "../stores/modelDataStore";
3 |
4 | // Shared implementation: reads the model id stored in `slot` and returns the
5 | // experiment keyed by that id (undefined when no experiment has that key).
6 | const useExperimentForSlot = (slot: "modelA" | "modelB") => {
7 |   const experiments = useExperimentsStore((state) => state.experiments);
8 |   const modelId = useModelDataStore((state) => state[slot]);
9 |   return experiments[modelId];
10 | };
11 |
12 | // Experiment currently selected as model A.
13 | export const useModelAExperiment = () => useExperimentForSlot("modelA");
14 |
15 | // Experiment currently selected as model B.
16 | export const useModelBExperiment = () => useExperimentForSlot("modelB");
17 |
--------------------------------------------------------------------------------
/frontend/src/index.tsx:
--------------------------------------------------------------------------------
1 | import React from "react";
2 | import ReactDOM from "react-dom/client";
3 |
4 | import "./app/index.css";
5 | import App from "./app/App";
6 |
7 | const root = ReactDOM.createRoot( // React 18 root mounted on the #root element.
8 | document.getElementById("root") as HTMLElement // The cast assumes #root exists in public/index.html; a missing element would fail at createRoot.
9 | );
10 | root.render(
11 |
12 |
13 |
14 | );
15 |
--------------------------------------------------------------------------------
/frontend/src/stores/attackStore.ts:
--------------------------------------------------------------------------------
1 | import { create } from "zustand";
2 |
3 | import { ENTROPY, UNLEARN } from "../constants/common";
4 | import { THRESHOLD_STRATEGIES } from "../constants/privacyAttack";
5 |
6 | // UI state of the privacy-attack view (plain store, not persisted).
7 | type AttackState = {
8 |   metric: string; // starts as ENTROPY
9 |   direction: string; // starts as UNLEARN
10 |   strategy: string; // presumably one of THRESHOLD_STRATEGIES[].strategy; starts with the first
11 |   worstCaseModel: "A" | "B" | null; // null means no worst-case model is selected
12 |   setMetric: (metric: string) => void;
13 |   setDirection: (direction: string) => void;
14 |   setStrategy: (strategy: string) => void;
15 |   // Accepts null so the selection can be cleared again (it starts out as null).
16 |   setWorstCaseModel: (model: "A" | "B" | null) => void;
17 | };
18 |
19 | export const useAttackStateStore = create((set) => ({
20 |   metric: ENTROPY,
21 |   direction: UNLEARN,
22 |   strategy: THRESHOLD_STRATEGIES[0].strategy,
23 |   worstCaseModel: null,
24 |   setMetric: (metric: string) => set({ metric }),
25 |   setDirection: (direction: string) => set({ direction }),
26 |   setStrategy: (strategy: string) => set({ strategy }),
27 |   setWorstCaseModel: (model: "A" | "B" | null) => set({ worstCaseModel: model }),
28 | }));
29 |
--------------------------------------------------------------------------------
/frontend/src/stores/baseConfigStore.ts:
--------------------------------------------------------------------------------
1 | import { create } from "zustand";
2 | import { persist } from "zustand/middleware";
3 |
4 | import { DATASETS, NEURAL_NETWORK_MODELS } from "../constants/common";
5 |
6 | // Dataset / architecture selection, persisted to sessionStorage under "config".
7 | type BaseConfigState = {
8 |   dataset: string;
9 |   neuralNetworkModel: string;
10 |   setDataset: (dataset: string) => void;
11 |   setNeuralNetworkModel: (neuralNetworkModel: string) => void;
12 | };
13 |
14 | // JSON adapter over sessionStorage; a missing or empty entry reads back as null.
15 | const sessionJsonStorage = {
16 |   getItem: (name: string) => {
17 |     const raw = sessionStorage.getItem(name);
18 |     return raw ? JSON.parse(raw) : null;
19 |   },
20 |   setItem: (name: string, value: unknown) =>
21 |     sessionStorage.setItem(name, JSON.stringify(value)),
22 |   removeItem: (name: string) => sessionStorage.removeItem(name),
23 | };
24 |
25 | export const useBaseConfigStore = create()(
26 |   persist(
27 |     (set) => ({
28 |       // Defaults: first entry of each option list.
29 |       dataset: DATASETS[0],
30 |       neuralNetworkModel: NEURAL_NETWORK_MODELS[0],
31 |       setDataset: (dataset) => set({ dataset }),
32 |       setNeuralNetworkModel: (neuralNetworkModel) => set({ neuralNetworkModel }),
33 |     }),
34 |     { name: "config", storage: sessionJsonStorage }
35 |   )
36 | );
37 |
--------------------------------------------------------------------------------
/frontend/src/stores/experimentsStore.ts:
--------------------------------------------------------------------------------
1 | import { create } from "zustand";
2 | import { persist } from "zustand/middleware";
3 | import { ExperimentData, Experiments } from "../types/data";
4 |
5 | // Experiment results keyed by experiment ID, persisted to sessionStorage.
6 | type ExperimentsState = {
7 |   experiments: Experiments; // keyed by experiment.ID, or by a caller-supplied tempIdx
8 |   isExperimentLoading: boolean; // transient UI flag; deliberately not persisted
9 |   addExperiment: (experiment: ExperimentData, tempIdx?: number) => void;
10 |   updateExperiment: (experiment: ExperimentData, idx: number) => void;
11 |   saveExperiments: (experiments: Experiments) => void;
12 |   deleteExperiment: (id: string) => void;
13 |   setIsExperimentsLoading: (loading: boolean) => void;
14 | };
15 |
16 | export const useExperimentsStore = create()(
17 |   persist(
18 |     (set, get) => ({
19 |       experiments: {},
20 |       isExperimentLoading: false,
21 |
22 |       // Stores `experiment` without its `points` field, under tempIdx when given,
23 |       // otherwise under experiment.ID.
24 |       addExperiment: (experiment, tempIdx) => {
25 |         const { points, ...experimentWithoutPoints } = experiment;
26 |         const key = tempIdx !== undefined ? tempIdx : experiment.ID;
27 |         set({
28 |           experiments: { ...get().experiments, [key]: experimentWithoutPoints },
29 |         });
30 |       },
31 |
32 |       // Removes the entry stored under `idx` and stores `experiment` under its ID.
33 |       // NOTE(review): unlike addExperiment this keeps `points` — confirm whether it
34 |       // should be stripped here too to keep sessionStorage small.
35 |       updateExperiment: (experiment, idx) => {
36 |         const { [idx]: _, ...remainingExperiments } = get().experiments;
37 |         set({
38 |           experiments: {
39 |             ...remainingExperiments,
40 |             [experiment.ID]: experiment,
41 |           },
42 |         });
43 |       },
44 |
45 |       // Replaces the whole map.
46 |       saveExperiments: (experiments) => set({ experiments }),
47 |
48 |       deleteExperiment: (id) => {
49 |         const { [id]: _, ...remainingExperiments } = get().experiments;
50 |         set({ experiments: remainingExperiments });
51 |       },
52 |
53 |       setIsExperimentsLoading: (loading) =>
54 |         set({ isExperimentLoading: loading }),
55 |     }),
56 |     {
57 |       name: "experiments",
58 |       // Persist only the experiments map. Persisting isExperimentLoading would
59 |       // restore `true` after a reload that happened mid-load.
60 |       partialize: (state) => ({ experiments: state.experiments }),
61 |       storage: {
62 |         getItem: (key) => {
63 |           const value = sessionStorage.getItem(key);
64 |           return value ? JSON.parse(value) : null;
65 |         },
66 |         setItem: (key, value) =>
67 |           sessionStorage.setItem(key, JSON.stringify(value)),
68 |         removeItem: (key) => sessionStorage.removeItem(key),
69 |       },
70 |     }
71 |   )
72 | );
73 |
--------------------------------------------------------------------------------
/frontend/src/stores/forgetClassStore.ts:
--------------------------------------------------------------------------------
1 | import { create } from "zustand";
2 | import { persist } from "zustand/middleware";
3 | import { CIFAR_10_CLASSES, FASHION_MNIST_CLASSES } from "../constants/common";
4 | import { useBaseConfigStore } from "./baseConfigStore";
5 |
6 | // Forget-class selection, persisted to sessionStorage under "forgetclass".
7 | type ForgetClassState = {
8 |   forgetClass: number; // class index; -1 (the initial value) means none selected
9 |   selectedForgetClasses: number[]; // class indices, kept free of duplicates
10 |   saveForgetClass: (forgetClass: string | number) => void;
11 |   addSelectedForgetClass: (forgetClass: string) => void;
12 |   deleteSelectedForgetClass: (forgetClass: string) => void;
13 | };
14 |
15 | // Class-name list for the dataset currently chosen in the base config store
16 | // (read via getState(), so it works outside React render).
17 | const getClassesForDataset = () => {
18 |   const dataset = useBaseConfigStore.getState().dataset;
19 |   return dataset === "CIFAR-10" ? CIFAR_10_CLASSES : FASHION_MNIST_CLASSES;
20 | };
21 |
22 | export const useForgetClassStore = create()(
23 |   persist(
24 |     (set, get) => ({
25 |       forgetClass: -1,
26 |       selectedForgetClasses: [],
27 |
28 |       // Accepts a class index or a class name; an unknown name stores -1.
29 |       saveForgetClass: (forgetClass) => {
30 |         const classes = getClassesForDataset();
31 |
32 |         set({
33 |           forgetClass:
34 |             typeof forgetClass === "string"
35 |               ? classes.indexOf(forgetClass)
36 |               : forgetClass,
37 |         });
38 |       },
39 |
40 |       // Adds a class by name. Unknown names are ignored (previously indexOf's -1
41 |       // was appended), and a class already present is not added twice.
42 |       addSelectedForgetClass: (forgetClass) => {
43 |         const target = getClassesForDataset().indexOf(forgetClass);
44 |         const selected = get().selectedForgetClasses;
45 |         if (target === -1 || selected.includes(target)) return;
46 |         set({ selectedForgetClasses: [...selected, target] });
47 |       },
48 |
49 |       // Removes a class by name; an unknown name leaves valid indices untouched.
50 |       deleteSelectedForgetClass: (forgetClass) => {
51 |         const classes = getClassesForDataset();
52 |
53 |         const target = classes.indexOf(forgetClass);
54 |         set({
55 |           selectedForgetClasses: get().selectedForgetClasses.filter(
56 |             (item) => item !== target
57 |           ),
58 |         });
59 |       },
60 |     }),
61 |     {
62 |       name: "forgetclass",
63 |       storage: {
64 |         getItem: (key) => {
65 |           const value = sessionStorage.getItem(key);
66 |           return value ? JSON.parse(value) : null;
67 |         },
68 |         setItem: (key, value) =>
69 |           sessionStorage.setItem(key, JSON.stringify(value)),
70 |         removeItem: (key) => sessionStorage.removeItem(key),
71 |       },
72 |     }
73 |   )
74 | );
75 |
--------------------------------------------------------------------------------
/frontend/src/stores/modelDataStore.ts:
--------------------------------------------------------------------------------
1 | import { create } from "zustand";
2 | import { persist } from "zustand/middleware";
3 |
4 | // Ids of the two models being compared, persisted to sessionStorage under "models".
5 | type ModelState = {
6 |   modelA: string; // empty string until chosen
7 |   modelB: string; // empty string until chosen
8 |   saveModelA: (modelA: string) => void;
9 |   saveModelB: (modelB: string) => void;
10 | };
11 |
12 | // JSON adapter over sessionStorage; a missing or empty entry reads back as null.
13 | const sessionJsonStorage = {
14 |   getItem: (name: string) => {
15 |     const raw = sessionStorage.getItem(name);
16 |     return raw ? JSON.parse(raw) : null;
17 |   },
18 |   setItem: (name: string, value: unknown) =>
19 |     sessionStorage.setItem(name, JSON.stringify(value)),
20 |   removeItem: (name: string) => sessionStorage.removeItem(name),
21 | };
22 |
23 | export const useModelDataStore = create()(
24 |   persist(
25 |     (set) => ({
26 |       modelA: "",
27 |       modelB: "",
28 |       saveModelA: (modelA) => set({ modelA }),
29 |       saveModelB: (modelB) => set({ modelB }),
30 |     }),
31 |     { name: "models", storage: sessionJsonStorage }
32 |   )
33 | );
34 |
--------------------------------------------------------------------------------
/frontend/src/stores/runningIndexStore.ts:
--------------------------------------------------------------------------------
1 | import { create } from "zustand";
2 | import { persist } from "zustand/middleware";
3 |
4 | // Index of the running item, persisted to sessionStorage under "running-index".
5 | type RunningIndexState = {
6 |   runningIndex: number;
7 |   updateRunningIndex: (runningIndex: number) => void;
8 | };
9 |
10 | // JSON adapter over sessionStorage; a missing or empty entry reads back as null.
11 | const sessionJsonStorage = {
12 |   getItem: (name: string) => {
13 |     const raw = sessionStorage.getItem(name);
14 |     return raw ? JSON.parse(raw) : null;
15 |   },
16 |   setItem: (name: string, value: unknown) =>
17 |     sessionStorage.setItem(name, JSON.stringify(value)),
18 |   removeItem: (name: string) => sessionStorage.removeItem(name),
19 | };
20 |
21 | export const useRunningIndexStore = create()(
22 |   persist(
23 |     (set, get) => ({
24 |       runningIndex: 0,
25 |       // Skips set() when the value is unchanged, avoiding a no-op notification.
26 |       updateRunningIndex: (runningIndex) => {
27 |         if (get().runningIndex === runningIndex) return;
28 |         set({ runningIndex });
29 |       },
30 |     }),
31 |     { name: "running-index", storage: sessionJsonStorage }
32 |   )
33 | );
34 |
--------------------------------------------------------------------------------
/frontend/src/stores/runningStatusStore.ts:
--------------------------------------------------------------------------------
1 | import { create } from "zustand";
2 | import { persist } from "zustand/middleware";
3 | import { UnlearningStatus } from "../types/experiments";
4 |
// Blank status used to seed every experiment slot before any status update
// has been applied to it.
const initialStatus: UnlearningStatus = {
  is_unlearning: false,
  progress: "Idle",
  recent_id: null,
  current_epoch: 0,
  total_epochs: 0,
  current_unlearn_loss: 0,
  current_unlearn_accuracy: 0,
  p_training_loss: 0,
  p_training_accuracy: 0,
  p_test_loss: 0,
  p_test_accuracy: 0,
  method: "",
  estimated_time_remaining: 0,
  elapsed_time: 0,
  completed_steps: [],
  learning_rate: 0,
  batch_size: 0,
};

type RunningStatusState = {
  isRunning: boolean;
  // statuses[forgetClass][experimentIndex]
  statuses: UnlearningStatus[][];
  activeStep: number;
  totalExperimentsCount: number;
  updateIsRunning: (isRunning: boolean) => void;
  initStatus: (forgetClass: number, count: number) => void;
  updateStatus: (payload: UpdateStatusPayload) => void;
  updateActiveStep: (step: number) => void;
};

type UpdateStatusPayload = {
  status: UnlearningStatus;
  forgetClass: number;
  experimentIndex: number;
  progress: string;
  elapsedTime: number;
  completedSteps: number[];
  learningRate?: number;
  batchSize?: number;
};

/**
 * Zustand store tracking per-experiment unlearning status, grouped by forget
 * class. Persisted to `sessionStorage` under the key "running-status".
 */
export const useRunningStatusStore = create<RunningStatusState>()(
  persist(
    (set, get) => ({
      isRunning: false,
      // One (initially empty) status list for each of 10 forget-class slots.
      statuses: Array.from({ length: 10 }, () => []),
      activeStep: 0,
      totalExperimentsCount: 0,

      updateIsRunning: (isRunning) => set({ isRunning }),

      /**
       * Resets the status list of `forgetClass` to `count` fresh copies of
       * `initialStatus` and records `count` as the total experiment count.
       */
      initStatus: (forgetClass, count) => {
        const newStatuses = [...get().statuses];
        newStatuses[forgetClass] = Array.from({ length: count }, () => ({
          ...initialStatus,
          // A plain spread would make every slot share initialStatus's
          // completed_steps array; give each slot its own.
          completed_steps: [],
        }));
        set({ statuses: newStatuses, totalExperimentsCount: count });
      },

      /**
       * Merges `status` plus the explicit payload fields into the slot
       * `statuses[forgetClass][experimentIndex]`. Out-of-range indices are
       * ignored, and the update is skipped when nothing except elapsed_time
       * changed.
       */
      updateStatus: ({
        status,
        forgetClass,
        experimentIndex,
        progress,
        elapsedTime,
        completedSteps,
        learningRate,
        batchSize,
      }) => {
        const classStatuses = get().statuses[forgetClass] || [];
        if (experimentIndex < 0 || experimentIndex >= classStatuses.length)
          return;

        const currentStatus = classStatuses[experimentIndex];
        const newStatus: UnlearningStatus = {
          ...currentStatus,
          ...status,
          progress,
          elapsed_time: elapsedTime,
          completed_steps: completedSteps,
          learning_rate: learningRate,
          batch_size: batchSize,
        };

        // Compare everything except elapsed_time, which changes on every tick;
        // avoids re-rendering subscribers for clock-only updates.
        const { elapsed_time: _, ...oldWithoutTime } = currentStatus;
        const { elapsed_time: __, ...newWithoutTime } = newStatus;
        if (JSON.stringify(oldWithoutTime) === JSON.stringify(newWithoutTime))
          return;

        const updatedClassStatuses = [...classStatuses];
        updatedClassStatuses[experimentIndex] = newStatus;
        const newStatuses = [...get().statuses];
        newStatuses[forgetClass] = updatedClassStatuses;
        set({ statuses: newStatuses });
      },

      updateActiveStep: (step) => set({ activeStep: step }),
    }),
    {
      name: "running-status",
      storage: {
        getItem: (key) => {
          const value = sessionStorage.getItem(key);
          return value ? JSON.parse(value) : null;
        },
        setItem: (key, value) =>
          sessionStorage.setItem(key, JSON.stringify(value)),
        removeItem: (key) => sessionStorage.removeItem(key),
      },
    }
  )
);
118 |
--------------------------------------------------------------------------------
/frontend/src/stores/thresholdStore.ts:
--------------------------------------------------------------------------------
1 | import { create } from "zustand";
2 |
interface ThresholdState {
  strategyThresholds: {
    A: number[];
    B: number[];
  };
  initializeThresholds: (isMetricEntropy: boolean) => void;
  setStrategyThresholds: (mode: "A" | "B", thresholds: number[]) => void;
}

// Builds fresh default thresholds for both modes. The first entry is 1.25
// when the entropy metric is used and 3.75 otherwise; the rest are 0.
const buildDefaultThresholds = (isMetricEntropy: boolean) => {
  const first = isMetricEntropy ? 1.25 : 3.75;
  return {
    A: [first, 0, 0, 0],
    B: [first, 0, 0, 0],
  };
};

/** Zustand store for the attack-strategy thresholds of model A and model B. */
export const useThresholdStore = create<ThresholdState>((set) => ({
  // Initial state matches the entropy defaults.
  strategyThresholds: buildDefaultThresholds(true),
  initializeThresholds: (isMetricEntropy: boolean) =>
    set(() => ({
      strategyThresholds: buildDefaultThresholds(isMetricEntropy),
    })),
  setStrategyThresholds: (mode, thresholds) =>
    set((state) => ({
      strategyThresholds: {
        ...state.strategyThresholds,
        [mode]: thresholds,
      },
    })),
}));
32 |
--------------------------------------------------------------------------------
/frontend/src/types/attack.ts:
--------------------------------------------------------------------------------
1 | import { AttackResult } from "./data";
2 |
/** A value associated with one image, identified by its index. */
export interface Bin {
  img_idx: number;
  value: number;
}

/** An attack result tagged with a free-form type string. */
export type AttackResultWithType = AttackResult & { type: string };

/** Chart input for the privacy-attack view, or null when nothing is loaded. */
export type Data = null | {
  retrainData: Bin[];
  unlearnData: Bin[];
  lineChartData: AttackResultWithType[];
};

/** Which model a data point belongs to. */
export type CategoryType = "retrain" | "unlearn";

/** An image index paired with its base64-encoded payload. */
export interface Image {
  index: number;
  base64: string;
}

/** Tooltip content: a bin plus the model category it came from. */
export interface TooltipData extends Bin {
  type: CategoryType;
}

/** Screen coordinates for a tooltip. */
export interface TooltipPosition {
  x: number;
  y: number;
}
33 |
--------------------------------------------------------------------------------
/frontend/src/types/data.ts:
--------------------------------------------------------------------------------
/** One per-class row comparing model A and model B accuracy. */
export interface GapDataItem {
  category: string;
  classLabel: string;
  gap: number;
  fill: string;
  modelAAccuracy: number;
  modelBAccuracy: number;
}

/** Number lists keyed by string (presumably class labels — confirm). */
export type Dist = Record<string, number[]>;

/** CKA matrices for the forget class and the remaining classes. */
type CkaSplit = {
  forget_class: number[][];
  other_classes: number[][];
};

type CKA = {
  layers: string[];
  train: CkaSplit;
  test: CkaSplit;
};

/** Six numeric fields followed by a string-keyed number map. */
export type Point = [
  number,
  number,
  number,
  number,
  number,
  number,
  Record<string, number>
];

export type AttackValue = {
  img: number;
  entropy: number;
  confidence: number;
};

export type AttackResult = {
  threshold: number;
  fpr: number;
  fnr: number;
  attack_score: number;
};

export type AttackResults = Record<
  | "entropy_above_unlearn"
  | "entropy_above_retrain"
  | "confidence_above_unlearn"
  | "confidence_above_retrain",
  AttackResult[]
>;

export type AttackData = {
  values: AttackValue[];
  results: AttackResults;
};

// Metric columns may be numeric or a string placeholder.
type NumberOrString = number | string;

/** Full payload of one experiment data file. */
export type ExperimentData = {
  CreatedAt: string;
  ID: string;
  FC: number;
  Type: string;
  Base: string;
  Method: string;
  Epoch: NumberOrString;
  BS: NumberOrString;
  LR: NumberOrString;
  UA: NumberOrString;
  RA: NumberOrString;
  TUA: NumberOrString;
  TRA: NumberOrString;
  PA: NumberOrString;
  RTE: NumberOrString;
  FQS: NumberOrString;
  accs: number[];
  label_dist: Dist;
  conf_dist: Dist;
  t_accs: number[];
  t_label_dist: Dist;
  t_conf_dist: Dist;
  cka: CKA;
  points: Point[];
  attack: AttackData;
};
90 |
// NOTE(review): a bare `Omit` (no type arguments) does not compile. It is
// reconstructed here as dropping `points`, which Core.tsx keeps in its own
// state. Confirm the omitted keys against the original source.
export type Experiment = Omit<ExperimentData, "points">;

// Experiments keyed by string (presumably the experiment ID — confirm).
export type Experiments = { [key: string]: Experiment };
94 |
--------------------------------------------------------------------------------
/frontend/src/types/embeddings.ts:
--------------------------------------------------------------------------------
export type Coordinate = { x: number; y: number };

/** Start and end coordinates of a connection; either end may be unset. */
export type Position = {
  from: Coordinate | null;
  to: Coordinate | null;
};

/** The currently hovered instance, or null when nothing is hovered. */
export type HoverInstance = {
  imgIdx: number;
  source: "A" | "B";
  modelAProb?: Prob;
  modelBProb?: Prob;
} | null;

export type ViewMode = {
  label: string;
  explanation: string;
  length: number;
};

/** String-keyed probabilities. */
export type Prob = { [key: string]: number };

export type SelectedData = (number | Prob)[][];

/**
 * Cached d3 selections for the embedding SVG.
 *
 * NOTE(review): `d3.Selection` requires four type arguments; the ones for
 * svg/gMain/gDot were missing and are reconstructed to match what
 * `d3.select(svgElement)` and `.append("g")` return. Confirm against the code
 * that fills this ref.
 */
export type SvgElementsRefType = {
  svg: d3.Selection<SVGSVGElement, unknown, null, undefined> | null;
  gMain: d3.Selection<SVGGElement, unknown, null, undefined> | null;
  gDot: d3.Selection<SVGGElement, unknown, null, undefined> | null;
  circles: d3.Selection<
    SVGCircleElement,
    (number | Prob)[],
    SVGGElement,
    undefined
  > | null;
  crosses: d3.Selection<
    SVGPathElement,
    (number | Prob)[],
    SVGGElement,
    undefined
  > | null;
};
42 |
--------------------------------------------------------------------------------
/frontend/src/types/experiments.ts:
--------------------------------------------------------------------------------
// Configuration Data
/** Parameters sent to start a predefined unlearning run. */
export interface UnlearningConfigurationData {
  method: string;
  forget_class: number;
  epochs: number;
  learning_rate: number;
  batch_size: number;
  base_weights: string;
}

// Status
/** Accuracy values keyed by class label. */
export interface ClassAccuracies {
  [key: string]: number;
}

/** Snapshot of an unlearning run's progress and metrics. */
export interface UnlearningStatus {
  is_unlearning: boolean;
  progress: string;
  recent_id: string | null;
  current_epoch: number;
  total_epochs: number;
  current_unlearn_loss: number;
  current_unlearn_accuracy: number;
  p_training_loss: number;
  p_training_accuracy: number;
  p_test_loss: number;
  p_test_accuracy: number;
  method: string;
  estimated_time_remaining: number;
  elapsed_time: number;
  completed_steps: number[];
  learning_rate?: number;
  batch_size?: number;
}

// others
export interface Action {
  type: string;
  payload: string | number;
}

// `d3.ScaleLinear` requires its Range and Output type arguments; the scales
// built for these metrics map numbers to numbers.
export type PerformanceMetrics = {
  [key: string]: d3.ScaleLinear<number, number>;
};
45 |
--------------------------------------------------------------------------------
/frontend/src/utils/api/dataTable.ts:
--------------------------------------------------------------------------------
1 | import { API_URL } from "../../constants/common";
2 |
/**
 * Sends `DELETE /data/{forgetClass}/{fileName}`.
 *
 * On any failure (network error or non-2xx status) the error is logged,
 * shown to the user via `alert`, and rethrown unchanged.
 */
export async function deleteRow(forgetClass: number, fileName: string) {
  try {
    const res = await fetch(`${API_URL}/data/${forgetClass}/${fileName}`, {
      method: "DELETE",
    });
    if (res.ok) return;

    throw new Error(`Status Code: ${res.status}, Message: ${res.statusText}`);
  } catch (err) {
    console.error("Failed to delete the row:", err);

    const message =
      err instanceof Error
        ? `Failed to delete the row: ${err.message}`
        : "An unknown error occurred while deleting the row . . .";
    alert(message);

    throw err;
  }
}
26 |
/**
 * Fetches `/data/{forgetClass}/{fileName}` and returns the parsed JSON body.
 *
 * On any failure (network, non-2xx status, invalid JSON) the error is logged,
 * shown via `alert`, and rethrown unchanged.
 */
export async function downloadJSON(forgetClass: number, fileName: string) {
  try {
    const res = await fetch(`${API_URL}/data/${forgetClass}/${fileName}`);
    if (!res.ok) {
      throw new Error(`Status Code: ${res.status}, Message: ${res.statusText}`);
    }

    const payload = await res.json();
    return payload;
  } catch (err) {
    console.error("Failed to download the JSON file:", err);

    const message =
      err instanceof Error
        ? `Failed to download the JSON file: ${err.message}`
        : "An unknown error occurred while downloading the JSON file . . .";
    alert(message);

    throw err;
  }
}
50 |
/**
 * Downloads model weights and saves them as `<fileName>.pth` through a
 * temporary `<a download>` element.
 *
 * Names starting with "000" are fetched from `/trained_models`; all others
 * from `/data/{forgetClass}/{fileName}/weights`.
 *
 * @returns the downloaded Blob.
 * @throws rethrows any failure after logging it and showing an `alert`.
 */
export async function downloadPTH(forgetClass: number, fileName: string) {
  const fetchUrl = fileName.startsWith("000")
    ? `${API_URL}/trained_models`
    : `${API_URL}/data/${forgetClass}/${fileName}/weights`;

  try {
    const response = await fetch(fetchUrl);

    if (!response.ok) {
      throw new Error(
        `Status Code: ${response.status}, Message: ${response.statusText}`
      );
    }

    const blob = await response.blob();

    const blobUrl = window.URL.createObjectURL(blob);
    const a = document.createElement("a");
    a.href = blobUrl;
    a.download = `${fileName}.pth`;
    document.body.appendChild(a);
    try {
      a.click();
    } finally {
      // Always detach the temporary anchor and release the object URL, even
      // if triggering the download throws.
      document.body.removeChild(a);
      window.URL.revokeObjectURL(blobUrl);
    }

    return blob;
  } catch (error) {
    console.error("Failed to download the PTH file:", error);

    if (error instanceof Error) {
      alert(`Failed to download the PTH file: ${error.message}`);
    } else {
      alert("An unknown error occurred while downloading the PTH file . . .");
    }

    throw error;
  }
}
90 |
--------------------------------------------------------------------------------
/frontend/src/utils/api/privacyAttack.ts:
--------------------------------------------------------------------------------
1 | import { API_URL } from "../../constants/common";
2 |
/**
 * Fetches `/image/all_subset/{forgetClass}` and returns the parsed JSON.
 *
 * Any failure is rethrown as a new Error whose message is prefixed with
 * "fetchAllSubsetImages failed: ".
 */
export async function fetchAllSubsetImages(forgetClass: number) {
  try {
    const res = await fetch(`${API_URL}/image/all_subset/${forgetClass}`);
    if (!res.ok) {
      throw new Error(`Status Code: ${res.status}, Message: ${res.statusText}`);
    }
    const body = await res.json();
    return body;
  } catch (err) {
    const reason = err instanceof Error ? err.message : "Unknown error";
    throw new Error(`fetchAllSubsetImages failed: ${reason}`);
  }
}
22 |
--------------------------------------------------------------------------------
/frontend/src/utils/api/requests.ts:
--------------------------------------------------------------------------------
1 | import { API_URL } from "../../constants/common";
2 | import { UnlearningStatus } from "../../types/experiments";
3 |
/**
 * Fetches the model file listing from `/{end}` and returns the parsed JSON.
 *
 * On failure the error is logged, shown via `alert`, and rethrown unchanged.
 */
export async function fetchModelFiles(
  end: "trained_models" | "unlearned_models"
) {
  try {
    const res = await fetch(`${API_URL}/${end}`);
    if (!res.ok) {
      throw new Error(`Status Code: ${res.status}, Message: ${res.statusText}`);
    }
    const files = await res.json();
    return files;
  } catch (err) {
    console.error(`Failed to fetch model files (${end}):`, err);

    const message =
      err instanceof Error
        ? `Failed to fetch model files: ${err.message}`
        : "An unknown Error occurred while fetching model files . . .";
    alert(message);

    throw err;
  }
}
29 |
/**
 * Fetches the current unlearning status from `/unlearn/status`.
 *
 * Hyphens are stripped from `method` before the status is returned.
 *
 * @throws Error "Failed to fetch unlearning status: ..." on any failure.
 */
export async function fetchUnlearningStatus(): Promise<UnlearningStatus> {
  try {
    const response = await fetch(`${API_URL}/unlearn/status`);

    if (!response.ok) {
      throw new Error(
        `Status Code: ${response.status}, Message: ${response.statusText}`
      );
    }

    const data = await response.json();

    // Only strings can be normalized; the old truthiness check would call
    // .replace on a non-string value and throw.
    if (typeof data.method === "string") {
      data.method = data.method.replace(/-/g, "");
    }

    return data;
  } catch (error) {
    throw new Error(`Failed to fetch unlearning status: ${error}`);
  }
}
51 |
--------------------------------------------------------------------------------
/frontend/src/utils/api/unlearning.ts:
--------------------------------------------------------------------------------
1 | import { ExperimentData } from "../../types/data";
2 | import { UnlearningConfigurationData } from "../../types/experiments";
3 | import { API_URL } from "../../constants/common";
4 |
/**
 * Starts a predefined unlearning run.
 *
 * Sends `POST /unlearn/{method}` with a JSON body containing every field of
 * `runningConfig` except `method` (which goes in the URL instead).
 *
 * @throws rethrows any failure after logging it and showing an `alert`.
 */
export async function executeMethodUnlearning(
  runningConfig: UnlearningConfigurationData
) {
  const method = runningConfig.method;
  // Fields are copied explicitly so only the documented keys are sent.
  const data: Omit<UnlearningConfigurationData, "method"> = {
    forget_class: runningConfig.forget_class,
    epochs: runningConfig.epochs,
    learning_rate: runningConfig.learning_rate,
    batch_size: runningConfig.batch_size,
    base_weights: runningConfig.base_weights,
  };

  try {
    const response = await fetch(`${API_URL}/unlearn/${method}`, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify(data),
    });

    if (!response.ok) {
      throw new Error(
        `Status Code: ${response.status}, Message: ${response.statusText}`
      );
    }
  } catch (error) {
    console.error("Failed to unlearn with the predefined setting:", error);

    if (error instanceof Error) {
      alert(`Failed to unlearn with the predefined setting: ${error.message}`);
    } else {
      alert("An unknown error occurred while unlearning . . .");
    }

    throw error;
  }
}
41 |
/**
 * Starts a custom unlearning run from an uploaded weights file.
 *
 * Posts multipart form data (`weights_file`, `forget_class`) to
 * `/unlearn/custom`. On failure the error is logged, shown via `alert`, and
 * rethrown unchanged.
 */
export async function executeCustomUnlearning(
  customFile: File,
  forgetClass: number
) {
  try {
    const body = new FormData();
    body.append("weights_file", customFile);
    body.append("forget_class", forgetClass.toString());

    const res = await fetch(`${API_URL}/unlearn/custom`, {
      method: "POST",
      body,
    });
    if (res.ok) return;

    throw new Error(`Status Code: ${res.status}, Message: ${res.statusText}`);
  } catch (err) {
    console.error("Failed to unlearn with the custom file:", err);

    const message =
      err instanceof Error
        ? `Failed to unlearn with the custom file: ${err.message}`
        : "An unknown error occurred while executing custom unlearning . . .";
    alert(message);

    throw err;
  }
}
75 |
/**
 * Fetches one experiment data file from `/data/{forgetClass}/{fileName}`.
 *
 * @returns the parsed JSON body, typed as ExperimentData.
 * @throws Error "Failed to fetch an unlearned data file: ..." on any failure.
 */
export async function fetchFileData(
  forgetClass: number,
  fileName: string
): Promise<ExperimentData> {
  try {
    const response = await fetch(`${API_URL}/data/${forgetClass}/${fileName}`);

    if (!response.ok) {
      throw new Error(
        `Server error: ${response.status} ${response.statusText}`
      );
    }

    return await response.json();
  } catch (error) {
    console.error("Failed to fetch an unlearned data file:", error);

    const errorMessage =
      error instanceof Error ? error.message : "Unknown error occurred.";

    throw new Error(`Failed to fetch an unlearned data file: ${errorMessage}`);
  }
}
99 |
/**
 * Fetches every experiment data file for `forgetClass` from
 * `/data/{forgetClass}/all` and returns the parsed JSON.
 *
 * @throws rethrows any failure (including non-2xx responses) after logging it
 *   and showing an `alert`.
 */
export async function fetchAllExperimentsData(forgetClass: number) {
  try {
    const response = await fetch(`${API_URL}/data/${forgetClass}/all`);

    // Without this check an HTTP error body would be parsed and returned to
    // callers as if it were experiment data.
    if (!response.ok) {
      throw new Error(
        `Status Code: ${response.status}, Message: ${response.statusText}`
      );
    }

    return await response.json();
  } catch (error) {
    console.error("Failed to fetch all unlearned data file:", error);

    if (error instanceof Error) {
      alert(`Failed to fetch all unlearned data file: ${error.message}`);
    } else {
      alert(
        "An unknown error occurred while fetching all unlearned data file . . ."
      );
    }

    throw error;
  }
}
119 |
/**
 * Fetches the weight file names for `forgetClass` from
 * `/data/{forgetClass}/all_weights_name` and returns the parsed JSON.
 *
 * @throws rethrows any failure (including non-2xx responses) after logging it
 *   and showing an `alert`.
 */
export async function fetchAllWeightNames(forgetClass: number) {
  try {
    const response = await fetch(
      `${API_URL}/data/${forgetClass}/all_weights_name`
    );

    // Reject HTTP errors instead of returning the error body as a name list.
    if (!response.ok) {
      throw new Error(
        `Status Code: ${response.status}, Message: ${response.statusText}`
      );
    }

    return await response.json();
  } catch (error) {
    console.error("Failed to fetch all weights names:", error);

    if (error instanceof Error) {
      alert(`Failed to fetch all weights names: ${error.message}`);
    } else {
      alert("An unknown error occurred while fetching all weights names . . .");
    }

    throw error;
  }
}
139 |
--------------------------------------------------------------------------------
/frontend/src/utils/config/unlearning.ts:
--------------------------------------------------------------------------------
/**
 * Returns the default form values (as strings) for an unlearning method.
 *
 * "ft" and "rl" have their own presets; every other method falls back to
 * 7 epochs, learning rate 0.0001 and batch size 256.
 */
export function getDefaultUnlearningConfig(method: string) {
  switch (method) {
    case "ft":
      return { epoch: "1", learning_rate: "0.001", batch_size: "64" };
    case "rl":
      return { epoch: "3", learning_rate: "0.001", batch_size: "64" };
    default:
      return { epoch: "7", learning_rate: "0.0001", batch_size: "256" };
  }
}
20 |
--------------------------------------------------------------------------------
/frontend/src/utils/data/accuracies.ts:
--------------------------------------------------------------------------------
1 | import { TABLEAU10 } from "../../constants/colors";
2 | import { GapDataItem } from "../../types/data";
3 |
// Decimal places kept when rounding accuracy gaps.
const GAP_FIX_LENGTH = 3;

/**
 * Builds one row per key of `baseAcc`, pairing model A's accuracy
 * (`baseAcc`) with model B's (`compAcc`) at the same position.
 *
 * Each row gets a letter category ("A", "B", ...), the key as `classLabel`,
 * `gap = B - A` rounded to GAP_FIX_LENGTH decimals, and a TABLEAU10 fill
 * color. Returns [] if either accuracy list is missing.
 */
export function getAccuracyGap(
  baseAcc: number[] | undefined,
  compAcc: number[] | undefined
) {
  if (!baseAcc || !compAcc) return [];

  return Object.keys(baseAcc).map((classLabel, idx) => {
    const modelAAccuracy = baseAcc[idx];
    const modelBAccuracy = compAcc[idx];
    const rawGap = modelBAccuracy - modelAAccuracy;
    return {
      category: String.fromCharCode(65 + idx),
      classLabel,
      gap: parseFloat(rawGap.toFixed(GAP_FIX_LENGTH)),
      fill: TABLEAU10[idx],
      modelAAccuracy,
      modelBAccuracy,
    };
  });
}
26 |
/**
 * Returns the largest absolute `gap` in `gapData`, rounded to
 * GAP_FIX_LENGTH decimals.
 *
 * Returns 0 for an empty list; the previous `Math.max(...[])` produced
 * -Infinity, which is not a usable bound for callers.
 */
export function getMaxGap(gapData: GapDataItem[]) {
  // Absolute gaps are >= 0, so a seed of 0 leaves non-empty results unchanged.
  const maxAbsGap = gapData.reduce(
    (max, item) => Math.max(max, Math.abs(item.gap)),
    0
  );
  return Number(maxAbsGap.toFixed(GAP_FIX_LENGTH));
}
32 |
--------------------------------------------------------------------------------
/frontend/src/utils/data/colors.ts:
--------------------------------------------------------------------------------
1 | import { COLORS } from "../../constants/colors";
2 |
/**
 * Returns text and background colors for an experiment type badge.
 *
 * "Original" gets dark text; every other type gets white text. The
 * background is left undefined for types other than Original, Retrained and
 * Unlearned.
 */
export function getTypeColors(type: string) {
  const color = type === "Original" ? "#222222" : "#FFFFFF";

  let backgroundColor;
  switch (type) {
    case "Original":
      backgroundColor = "#F3F4F6";
      break;
    case "Retrained":
      backgroundColor = COLORS.DARK_GRAY;
      break;
    case "Unlearned":
      backgroundColor = "#FF8C00";
      break;
  }

  return { color, backgroundColor };
}
15 |
--------------------------------------------------------------------------------
/frontend/src/utils/data/experiments.ts:
--------------------------------------------------------------------------------
1 | import * as d3 from "d3";
2 |
3 | import { Dist } from "../../types/data";
4 | import { Experiment, Experiments } from "../../types/data";
5 | import { TRAIN } from "../../constants/common";
6 |
const metrics = ["UA", "TUA", "RA", "TRA", "PA", "RTE", "FQS"] as const;
type Metric = (typeof metrics)[number];

/**
 * Builds, for each metric in `metrics`, a linear scale mapping the
 * metric's [min, max] across all experiments in `data` onto [0, 1].
 *
 * Values are coerced with `Number`; non-numeric entries become NaN and are
 * ignored by d3's extent computation.
 */
export function calculatePerformanceMetrics(data: Experiments) {
  const experiments = Object.values(data);

  return metrics.reduce((acc, key) => {
    const values = experiments.map((d) => Number(d[key]));
    // One pass for both bounds; same NaN/undefined handling as d3.min/d3.max.
    const [min, max] = d3.extent(values);
    acc[key] = d3.scaleLinear().domain([min!, max!]).range([0, 1]);
    return acc;
  }, {} as Record<Metric, d3.ScaleLinear<number, number>>);
}
29 |
type BubbleChartData = {
  x: number;
  y: number;
  label: number;
  conf: number;
}[];

/**
 * Flattens the label/confidence distributions of `data` into bubble-chart
 * cells.
 *
 * The `label_dist`/`conf_dist` pair is used when `datasetMode === TRAIN`;
 * otherwise the `t_`-prefixed pair is used. Each cell has `y` = row position
 * in label_dist, `x` = position within that row, `label` = the label_dist
 * value, and `conf` = conf_dist at the same row key and position.
 */
export function extractBubbleChartData(datasetMode: string, data: Experiment) {
  const useTrain = datasetMode === TRAIN;
  const labelDist = useTrain ? data.label_dist : data.t_label_dist;
  const confDist = useTrain ? data.conf_dist : data.t_conf_dist;

  const result: BubbleChartData = [];
  Object.entries(labelDist).forEach(([gtKey, dist], gtIdx) => {
    Object.entries(dist).forEach(([, labelValue], predIdx) => {
      result.push({
        x: predIdx,
        y: gtIdx,
        label: labelValue,
        conf: confDist[gtKey][predIdx],
      });
    });
  });

  return result;
}
72 |
--------------------------------------------------------------------------------
/frontend/src/utils/data/getButterflyLegendData.ts:
--------------------------------------------------------------------------------
1 | import { COLORS } from "../../constants/colors";
2 |
/**
 * Returns the four butterfly-chart legend entries for model A or model B.
 *
 * When `isAboveThresholdUnlearn` is true, the "Pred. Retrained" entries come
 * first; otherwise the "Pred. <model>" entries come first.
 */
export function getButterflyLegendData(
  isAboveThresholdUnlearn: boolean,
  isModelA: boolean
) {
  const model = isModelA ? "Model A" : "Model B";

  const retrainedPredRetrained = {
    label: "From Retrained / Pred. Retrained",
    side: "left",
    color: COLORS.LIGHT_GRAY,
  };
  const modelPredRetrained = {
    label: `From ${model} / Pred. Retrained`,
    side: "right",
    color: isModelA ? COLORS.LIGHT_EMERALD : COLORS.LIGHT_PURPLE,
  };
  const retrainedPredModel = {
    label: `From Retrained / Pred. ${model}`,
    side: "left",
    color: COLORS.DARK_GRAY,
  };
  const modelPredModel = {
    label: `From ${model} / Pred. ${model}`,
    side: "right",
    color: isModelA ? COLORS.EMERALD : COLORS.PURPLE,
  };

  if (isAboveThresholdUnlearn) {
    return [
      retrainedPredRetrained,
      modelPredRetrained,
      retrainedPredModel,
      modelPredModel,
    ];
  }
  return [
    retrainedPredModel,
    modelPredModel,
    retrainedPredRetrained,
    modelPredRetrained,
  ];
}
57 |
--------------------------------------------------------------------------------
/frontend/src/utils/data/getCkaData.ts:
--------------------------------------------------------------------------------
1 | import { Experiment } from "../../types/data";
2 | import { TRAIN } from "../../constants/common";
3 |
/**
 * Builds one row per CKA layer with the forget-class and other-class CKA
 * values of both models.
 *
 * Uses the `train` matrices when `dataset === TRAIN`, otherwise `test`. For
 * each matrix, the value for row i is entry [i][i]. Layer names come from
 * model A.
 */
export const getCkaData = (
  dataset: string,
  modelAExperiment: Experiment,
  modelBExperiment: Experiment
) => {
  const layerNames = modelAExperiment.cka.layers;

  const isTrain = dataset === TRAIN;
  const ckaA = isTrain ? modelAExperiment.cka.train : modelAExperiment.cka.test;
  const ckaB = isTrain ? modelBExperiment.cka.train : modelBExperiment.cka.test;

  // Entry [i][i] of each row.
  const diagonal = (matrix: number[][]) => matrix.map((row, i) => row[i]);

  const aForget = diagonal(ckaA.forget_class);
  const aOther = diagonal(ckaA.other_classes);
  const bForget = diagonal(ckaB.forget_class);
  const bOther = diagonal(ckaB.other_classes);

  return layerNames.map((layer, i) => ({
    layer,
    modelAForgetCka: aForget[i],
    modelAOtherCka: aOther[i],
    modelBForgetCka: bForget[i],
    modelBOtherCka: bOther[i],
  }));
};
39 |
--------------------------------------------------------------------------------
/frontend/src/utils/data/getProgressSteps.ts:
--------------------------------------------------------------------------------
1 | import { UnlearningStatus } from "../../types/experiments";
2 |
const TO_FIXED_LENGTH = 3;
const UMAP = "UMAP";
const CKA = "CKA";
const UNLEARNING = "Unlearning";
const IDLE = "Idle";

/**
 * Builds the markdown-ish descriptions for the three progress steps
 * (Unlearn, Evaluate, Analyze).
 *
 * With no status every value is shown as "-". Otherwise each value is shown
 * only once the relevant step appears in `status.completed_steps`. Accuracies
 * are printed with TO_FIXED_LENGTH decimals, except that an exact 0 prints as
 * "0".
 */
export const getProgressSteps = (
  status: UnlearningStatus | null,
  activeStep: number,
  umapProgress: number,
  ckaProgress: number
) => {
  if (!status) {
    return [
      {
        step: 1,
        title: "Unlearn",
        description: `Method: **-** | Epochs: **-**\nCurrent Unlearning Accuracy: **-**`,
      },
      {
        step: 2,
        title: "Evaluate",
        description: `Training Accuracy: **-**\nTest Accuracy: **-**`,
      },
      {
        step: 3,
        title: "Analyze",
        description: `Computing UMAP Embedding\nCalculating CKA Values`,
      },
    ];
  }

  const {
    method,
    progress,
    current_unlearn_accuracy: currentUnlearnAccuracy,
    current_epoch: currentEpoch,
    total_epochs: totalEpochs,
    p_training_accuracy: trainingAccuracy,
    p_test_accuracy: testAccuracy,
    completed_steps: completedSteps,
  } = status;

  const formatAccuracy = (value: number) =>
    value === 0 ? "0" : value.toFixed(TO_FIXED_LENGTH);

  const step1Done = completedSteps.includes(1);
  const step2Done = completedSteps.includes(2);
  const step3Done = completedSteps.includes(3);
  const idleOrUnlearning = progress === IDLE || progress === UNLEARNING;
  // All three steps are complete and the backend has moved past analysis.
  const finished = step3Done && idleOrUnlearning;

  // Step 1: Unlearn
  const methodText = method || "-";
  const epochText = step1Done ? `${currentEpoch}/${totalEpochs}` : "-";
  const showUnlearnAccuracy =
    step1Done && (currentEpoch > 1 || (totalEpochs === 1 && step2Done));
  const unlearnAccuracyText = showUnlearnAccuracy
    ? formatAccuracy(currentUnlearnAccuracy)
    : "-";

  // Step 2: Evaluate
  const showTrainingAccuracy =
    step3Done || (step2Done && progress.includes("Test"));
  const trainingText = showTrainingAccuracy
    ? formatAccuracy(trainingAccuracy)
    : "-";
  const testText = step3Done ? formatAccuracy(testAccuracy) : "-";

  // Step 3: Analyze
  const showUmap =
    (activeStep === 3 && (progress.includes(UMAP) || progress.includes(CKA))) ||
    finished;
  const umapLine = showUmap
    ? `Computing UMAP Embedding... **${
        progress.includes(UMAP) ? umapProgress : "100"
      }%**`
    : "Computing UMAP Embedding";

  const showCka = (activeStep === 3 && progress.includes(CKA)) || finished;
  const ckaLine = showCka
    ? `Calculating CKA Values... **${idleOrUnlearning ? "100" : ckaProgress}%**`
    : "Calculating CKA Values";

  const doneLine = finished ? `Done! Model ID: **${status.recent_id}**` : "";

  return [
    {
      step: 1,
      title: "Unlearn",
      description: `Method: **${methodText}** | Epoch: **${epochText}**\nCurrent Unlearning Accuracy: **${unlearnAccuracyText}**`,
    },
    {
      step: 2,
      title: "Evaluate",
      description: `Training Accuracy: **${trainingText}**\nTest Accuracy: **${testText}**`,
    },
    {
      step: 3,
      title: "Analyze",
      description: `${umapLine}\n${ckaLine}\n${doneLine}`,
    },
  ];
};
106 |
--------------------------------------------------------------------------------
/frontend/src/utils/data/running-status-context.ts:
--------------------------------------------------------------------------------
1 | import { UnlearningStatus } from "../../types/experiments";
2 |
/**
 * Returns the progress label to display for `status`.
 *
 * "Idle" whenever no unlearning is running. While running, an "Idle"
 * progress is reported as "Unlearning"; any other progress is passed through.
 */
export function getCurrentProgress(status: UnlearningStatus) {
  if (!status.is_unlearning) return "Idle";
  return status.progress === "Idle" ? "Unlearning" : status.progress;
}
10 |
/**
 * Maps a progress label to the list of progress steps that should be shown
 * as reached.
 *
 * - [1] while still unlearning ("Unlearning" with epochs remaining or
 *   `is_unlearning` still set)
 * - [1, 2] while any "Evaluating" phase runs
 * - [1, 2, 3] otherwise (analysis phases and all remaining states)
 */
export function getCompletedSteps(progress: string, status: UnlearningStatus) {
  const stillUnlearning =
    progress === "Unlearning" &&
    (status.current_epoch !== status.total_epochs || status.is_unlearning);
  if (stillUnlearning) return [1];

  if (progress.includes("Evaluating")) return [1, 2];

  // UMAP/CKA analysis and every other state had separate but identical
  // branches; they are merged here.
  return [1, 2, 3];
}
26 |
--------------------------------------------------------------------------------
/frontend/src/utils/util.ts:
--------------------------------------------------------------------------------
1 | import { type ClassValue, clsx } from "clsx";
2 | import { twMerge } from "tailwind-merge";
3 | import { CONFIG } from "../app/App";
4 |
/**
 * Combines class names with clsx, then lets tailwind-merge resolve
 * conflicting Tailwind utility classes.
 */
export function cn(...inputs: ClassValue[]) {
  const combined = clsx(inputs);
  return twMerge(combined);
}
8 |
/**
 * Computes the zoom factor that fits CONFIG.TOTAL_WIDTH into the viewport
 * width.
 *
 * If the zoomed content (48px plus the progress and core heights) would be
 * taller than the viewport, a vertical scrollbar will appear. In that case
 * the scrollbar's width is measured and subtracted before computing the zoom.
 */
export function calculateZoom() {
  // Measures the scrollbar width with a hidden, always-scrolling probe element.
  const measureScrollbarWidth = () => {
    const probe = document.createElement("div");
    probe.style.visibility = "hidden";
    probe.style.overflow = "scroll";
    document.body.appendChild(probe);

    const child = probe.appendChild(document.createElement("div"));
    const width = probe.offsetWidth - child.offsetWidth;
    probe.remove();
    return width;
  };

  const viewportWidth = window.innerWidth;
  const contentHeight =
    48 + CONFIG.EXPERIMENTS_PROGRESS_HEIGHT + CONFIG.CORE_HEIGHT;
  const zoom = viewportWidth / CONFIG.TOTAL_WIDTH;

  if (contentHeight * zoom > window.innerHeight) {
    return (viewportWidth - measureScrollbarWidth()) / CONFIG.TOTAL_WIDTH;
  }
  return zoom;
}
35 |
--------------------------------------------------------------------------------
/frontend/src/views/Core/Core.tsx:
--------------------------------------------------------------------------------
1 | import React, { useState, useEffect } from "react";
2 |
3 | import View from "../../components/common/View";
4 | import Title from "../../components/common/Title";
5 | import Indicator from "../../components/common/Indicator";
6 | import Embedding from "./Embedding";
7 | import PrivacyAttack from "./PrivacyAttack";
8 | import { CONFIG } from "../../app/App";
9 | import { fetchFileData, fetchAllWeightNames } from "../../utils/api/unlearning";
10 | import { useForgetClassStore } from "../../stores/forgetClassStore";
11 | import { useModelDataStore } from "../../stores/modelDataStore";
12 | import { Point } from "../../types/data";
13 | import { cn } from "../../utils/util";
14 |
15 | const EMBEDDINGS = "embeddings"; // display mode for the embeddings panel; also compared against the clicked element's id
16 | const ATTACK = "attack"; // display mode for the privacy-attack panel
17 |
18 | export default function Core() { // Core view: switches between the embeddings and privacy-attack display modes and loads the points of models A and B for the selected forget class.
19 | const forgetClass = useForgetClassStore((state) => state.forgetClass);
20 | const modelA = useModelDataStore((state) => state.modelA);
21 | const modelB = useModelDataStore((state) => state.modelB);
22 |
23 | const [displayMode, setDisplayMode] = useState(EMBEDDINGS); // starts in embeddings mode
24 | const [modelAPoints, setModelAPoints] = useState([]); // `points` from model A's data file; reset to [] after a failed fetch
25 | const [modelBPoints, setModelBPoints] = useState([]); // `points` from model B's data file; reset to [] after a failed fetch
26 |
27 | const isEmbeddingMode = displayMode === EMBEDDINGS;
28 | const forgetClassExist = forgetClass !== -1; // -1 is treated as "no forget class selected"
29 |
30 | const handleDisplayModeChange = (e: React.MouseEvent) => { // picks the mode from the clicked element's id; any id other than EMBEDDINGS selects ATTACK
31 | const id = e.currentTarget.id;
32 |
33 | if (id === EMBEDDINGS) {
34 | setDisplayMode(EMBEDDINGS);
35 | } else {
36 | setDisplayMode(ATTACK);
37 | }
38 | };
39 |
40 | useEffect(() => { // (re)load model A's points whenever the forget class or model A changes
41 | async function loadModelAData() {
42 | if (!forgetClassExist) return;
43 |
44 | const ids: string[] = await fetchAllWeightNames(forgetClass); // NOTE(review): outside the try block, so a rejection here becomes an unhandled promise rejection
45 | const slicedIds = ids.map((id) => id.slice(0, -4)); // drop the last 4 characters — presumably a file extension such as ".pth"; confirm
46 |
47 | if (!modelA || !slicedIds.includes(modelA)) return; // skip unknown models; NOTE(review): previously loaded points are left in place
48 |
49 | try {
50 | const data = await fetchFileData(forgetClass, modelA);
51 | setModelAPoints(data.points);
52 | } catch (error) {
53 | console.error(`Failed to fetch an model A data file: ${error}`);
54 | setModelAPoints([]);
55 | }
56 | }
57 | loadModelAData();
58 | }, [forgetClass, forgetClassExist, modelA]); // NOTE(review): no cancellation — a slow earlier request can overwrite points for a newer selection
59 |
60 | useEffect(() => { // (re)load model B's points whenever the forget class or model B changes
61 | async function loadModelBData() {
62 | if (!forgetClassExist) return;
63 |
64 | const ids: string[] = await fetchAllWeightNames(forgetClass); // NOTE(review): outside the try block, so a rejection here becomes an unhandled promise rejection
65 | const slicedIds = ids.map((id) => id.slice(0, -4)); // drop the last 4 characters — presumably a file extension such as ".pth"; confirm
66 |
67 | if (!modelB || !slicedIds.includes(modelB)) return; // skip unknown models; NOTE(review): previously loaded points are left in place
68 |
69 | try {
70 | const data = await fetchFileData(forgetClass, modelB);
71 | setModelBPoints(data.points);
72 | } catch (error) {
73 | console.error(`Error fetching model B file data: ${error}`);
74 | setModelBPoints([]);
75 | }
76 | }
77 | loadModelBData();
78 | }, [forgetClass, forgetClassExist, modelB]); // NOTE(review): no cancellation — a slow earlier request can overwrite points for a newer selection
79 |
80 | return (
81 |
87 |
88 |
}
96 | onClick={handleDisplayModeChange}
97 | />
98 | }
106 | onClick={handleDisplayModeChange}
107 | />
108 |
109 | {forgetClassExist ? (
110 | isEmbeddingMode ? (
111 |
112 | ) : (
113 |
117 | )
118 | ) : (
119 |
120 | )}
121 |
122 | );
123 | }
124 |
125 | function UnderLine() { // NOTE(review): markup not visible in this listing — presumably a decorative underline for the active mode button; confirm
126 | return ;
127 | }
128 |
--------------------------------------------------------------------------------
/frontend/src/views/Core/PrivacyAttack.tsx:
--------------------------------------------------------------------------------
1 | import { useState, useEffect } from "react";
2 |
3 | import Indicator from "../../components/common/Indicator";
4 | import Legend from "../../components/Core/PrivacyAttack/Legend";
5 | import AttackAnalytics from "../../components/Core/PrivacyAttack/AttackAnalytics";
6 | import { Prob } from "../../types/embeddings";
7 | import { Separator } from "../../components/UI/separator";
8 | import { useModelDataStore } from "../../stores/modelDataStore";
9 | import { fetchFileData } from "../../utils/api/unlearning";
10 | import { useForgetClassStore } from "../../stores/forgetClassStore";
11 | import { useAttackStateStore } from "../../stores/attackStore";
12 | import { ExperimentData } from "../../types/data";
13 | import { ENTROPY, UNLEARN } from "../../constants/common";
14 | import { THRESHOLD_STRATEGIES } from "../../constants/privacyAttack";
15 |
16 | interface Props { // points for models A and B — presumably the arrays loaded by the Core view
17 | modelAPoints: (number | Prob)[][]; // one row per point; row layout not visible here
18 | modelBPoints: (number | Prob)[][];
19 | }
20 |
21 | export default function PrivacyAttack({ modelAPoints, modelBPoints }: Props) { // Privacy-attack panel for models A and B; also loads the retrained model's data for the current forget class.
22 | const forgetClass = useForgetClassStore((state) => state.forgetClass);
23 | const modelA = useModelDataStore((state) => state.modelA);
24 | const modelB = useModelDataStore((state) => state.modelB);
25 | const setMetric = useAttackStateStore((state) => state.setMetric);
26 | const setDirection = useAttackStateStore((state) => state.setDirection);
27 | const setStrategy = useAttackStateStore((state) => state.setStrategy);
28 |
29 | const [retrainData, setRetrainData] = useState(); // retrained model's data file; undefined until the first successful fetch
30 |
31 | const isModelAOriginal = modelA.startsWith("000"); // model ids starting with "000" are treated as the original model
32 | const isModelBOriginal = modelB.startsWith("000");
33 |
34 | useEffect(() => { // reset attack settings (ENTROPY metric, UNLEARN direction, first threshold strategy) on mount and whenever a setter reference changes
35 | setMetric(ENTROPY);
36 | setDirection(UNLEARN);
37 | setStrategy(THRESHOLD_STRATEGIES[0].strategy);
38 | }, [setDirection, setMetric, setStrategy]);
39 |
40 | useEffect(() => { // load the retrained model's data whenever the forget class changes
41 | async function loadRetrainData() {
42 | if (forgetClass === -1) return; // -1: no forget class selected
43 | try {
44 | const data = await fetchFileData(forgetClass, `a00${forgetClass}`); // retrained-model id "a00<class>" — only well-formed for single-digit classes
45 | setRetrainData(data);
46 | } catch (error) {
47 | console.error(`Failed to fetch an retrained data file: ${error}`); // NOTE(review): retrainData is not cleared here or on class change, so the previous class's data can stay displayed
48 | }
49 | }
50 | loadRetrainData();
51 | }, [forgetClass]);
52 |
53 | return (
54 |
55 |
56 |
57 | {!retrainData ? (
58 |
59 | ) : isModelAOriginal ? (
60 |
61 | ) : (
62 |
68 | )}
69 |
73 | {!retrainData ? (
74 |
75 | ) : isModelBOriginal ? (
76 |
77 | ) : (
78 |
84 | )}
85 |
86 |
87 | );
88 | }
89 |
--------------------------------------------------------------------------------
/frontend/src/views/MetricsView/ClassWiseAnalysis.tsx:
--------------------------------------------------------------------------------
1 | import { useMemo, useState } from "react";
2 |
3 | import Subtitle from "../../components/common/Subtitle";
4 | import Indicator from "../../components/common/Indicator";
5 | import VerticalBarChart from "../../components/MetricsView/Accuracy/VerticalBarChart";
6 | import { useForgetClassStore } from "../../stores/forgetClassStore";
7 | import { useModelDataStore } from "../../stores/modelDataStore";
8 | import { getAccuracyGap, getMaxGap } from "../../utils/data/accuracies";
9 | import {
10 | useModelAExperiment,
11 | useModelBExperiment,
12 | } from "../../hooks/useModelExperiment";
13 |
14 | export default function ClassWiseAnalysis() { // Class-wise accuracy comparison between models A and B (train and test accuracy gaps).
15 | const forgetClass = useForgetClassStore((state) => state.forgetClass); // NOTE(review): the render below tests `forgetClass !== undefined`, while sibling views use -1 as "no forget class" — confirm which sentinel applies
16 | const modelA = useModelDataStore((state) => state.modelA);
17 | const modelB = useModelDataStore((state) => state.modelB);
18 | const modelAExperiment = useModelAExperiment();
19 | const modelBExperiment = useModelBExperiment();
20 |
21 | const [hoveredClass, setHoveredClass] = useState(null); // hovered class or null — usage not visible here, presumably shared by the charts
22 |
23 | const accuracyData = useMemo(() => { // per-class A/B accuracy gaps for train (`accs`) and test (`t_accs`), plus the largest gap over both
24 | const trainAccuracyGap = getAccuracyGap(
25 | modelAExperiment?.accs,
26 | modelBExperiment?.accs
27 | );
28 | const testAccuracyGap = getAccuracyGap(
29 | modelAExperiment?.t_accs,
30 | modelBExperiment?.t_accs
31 | );
32 | const trainMaxGap = getMaxGap(trainAccuracyGap);
33 | const testMaxGap = getMaxGap(testAccuracyGap);
34 | const maxGap = Math.max(trainMaxGap, testMaxGap); // a single maximum across train and test — presumably so both charts share one scale
35 |
36 | return {
37 | trainAccuracyGap,
38 | testAccuracyGap,
39 | maxGap,
40 | };
41 | }, [modelAExperiment, modelBExperiment]);
42 |
43 | return (
44 |
45 |
46 | {forgetClass !== undefined ? (
47 | modelA !== "" && modelB !== "" ? (
48 |
49 |
56 |
64 |
65 | ) : (
66 |
67 | )
68 | ) : (
69 |
70 | )}
71 |
72 | );
73 | }
74 |
--------------------------------------------------------------------------------
/frontend/src/views/MetricsView/LayerWiseSimilarity.tsx:
--------------------------------------------------------------------------------
1 | import { useState } from "react";
2 |
3 | import Subtitle from "../../components/common/Subtitle";
4 | import Indicator from "../../components/common/Indicator";
5 | import LineChart from "../../components/MetricsView/CKA/LineChart";
6 | import DatasetModeSelector from "../../components/common/DatasetModeSelector";
7 | import { useModelDataStore } from "../../stores/modelDataStore";
8 | import { useForgetClassStore } from "../../stores/forgetClassStore";
9 | import { TRAIN } from "../../constants/common";
10 |
11 | export default function LayerWiseSimilarity() { // Layer-wise (CKA) similarity of models A and B before vs. after unlearning, with a dataset selector.
12 | const forgetClass = useForgetClassStore((state) => state.forgetClass);
13 | const modelA = useModelDataStore((state) => state.modelA);
14 | const modelB = useModelDataStore((state) => state.modelB);
15 |
16 | const [selectedDataset, setSelectedDataset] = useState(TRAIN); // dataset shown in the chart; defaults to TRAIN
17 |
18 | const areAllModelsSelected = modelA !== "" && modelB !== ""; // "" means no model selected
19 | const forgetClassExist = forgetClass !== -1; // -1 means no forget class selected
20 |
21 | return (
22 |
23 |
24 |
25 | {forgetClassExist && areAllModelsSelected && (
26 |
30 | )}
31 |
32 | {forgetClassExist ? (
33 | areAllModelsSelected ? (
34 |
35 |
36 | Similarity Between Before and After Unlearning
37 |
38 |
39 |
40 | ) : (
41 |
42 | )
43 | ) : (
44 |
45 | )}
46 |
47 | );
48 | }
49 |
--------------------------------------------------------------------------------
/frontend/src/views/MetricsView/MetricsView.tsx:
--------------------------------------------------------------------------------
1 | import View from "../../components/common/View";
2 | import Title from "../../components/common/Title";
3 | import ClassWiseAnalysis from "./ClassWiseAnalysis";
4 | import PredictionMatrix from "./PredictionMatrix";
5 | import LayerWiseSimilarity from "./LayerWiseSimilarity";
6 | import { CONFIG } from "../../app/App";
7 |
8 | export default function MetricsCore() { // Metrics view; markup not visible here — presumably lays out ClassWiseAnalysis, PredictionMatrix and LayerWiseSimilarity in a View sized from CONFIG
9 | return (
10 |
16 |
17 |
22 |
23 | );
24 | }
25 |
--------------------------------------------------------------------------------
/frontend/src/views/MetricsView/PredictionMatrix.tsx:
--------------------------------------------------------------------------------
1 | import { useState, useCallback } from "react";
2 |
3 | import Subtitle from "../../components/common/Subtitle";
4 | import DatasetModeSelector from "../../components/common/DatasetModeSelector";
5 | import BubbleMatrix from "../../components/MetricsView/Predictions/BubbleMatrix";
6 | import BubbleMatrixLegend from "../../components/MetricsView/Predictions/BubbleMatrixLegend";
7 | import CorrelationMatrix from "../../components/MetricsView/Predictions/CorrelationMatrix";
8 | import CorrelationMatrixLegend from "../../components/MetricsView/Predictions/CorrelationMatrixLegend";
9 | import Indicator from "../../components/common/Indicator";
10 | import { useForgetClassStore } from "../../stores/forgetClassStore";
11 | import { useModelDataStore } from "../../stores/modelDataStore";
12 | import { TRAIN } from "../../constants/common";
13 | import {
14 | useModelAExperiment,
15 | useModelBExperiment,
16 | } from "../../hooks/useModelExperiment";
17 |
18 | export interface MatrixProps { // props for the Matrix wrapper, which renders either a correlation or a bubble matrix
19 | mode: "A" | "B"; // presumably which model (A or B) this matrix shows
20 | modelType: string;
21 | datasetMode: string; // presumably the selected dataset (TRAIN by default)
22 | hoveredY: number | null; // hovered row index; null when nothing is hovered
23 | onHover: (y: number | null) => void; // reports the hovered row (or null)
24 | showYAxis?: boolean;
25 | }
26 |
27 | export default function PredictionMatrix() { // Prediction matrices for models A and B, toggleable between a correlation ("corr") and a bubble chart.
28 | const forgetClass = useForgetClassStore((state) => state.forgetClass);
29 | const modelA = useModelDataStore((state) => state.modelA);
30 | const modelB = useModelDataStore((state) => state.modelB);
31 | const modelAExperiment = useModelAExperiment();
32 | const modelBExperiment = useModelBExperiment();
33 |
34 | const [selectedDataset, setSelectedDataset] = useState(TRAIN); // defaults to TRAIN
35 | const [hoveredY, setHoveredY] = useState(null); // hovered row, presumably shared by both matrices
36 | const [chartMode, setChartMode] = useState<"corr" | "bubble">("corr"); // starts with the correlation matrix
37 |
38 | const areAllModelsSelected = modelA !== "" && modelB !== ""; // "" means no model selected
39 | const forgetClassExist = forgetClass !== -1; // -1 means no forget class selected
40 | const isChartModeCorr = chartMode === "corr";
41 |
42 | const handleHover = useCallback((y: number | null) => setHoveredY(y), []); // memoized so the callback identity is stable across renders
43 |
44 | const handleModeBtnClick = () => { // flips chartMode between "corr" and "bubble"
45 | isChartModeCorr ? setChartMode("bubble") : setChartMode("corr");
46 | };
47 |
48 | function Matrix(props: MatrixProps) { // NOTE(review): declared inside the component, so every render creates a new component type and React remounts the matrix (losing its state); consider hoisting it and passing chartMode as a prop
49 | return isChartModeCorr ? (
50 |
51 | ) : (
52 |
53 | );
54 | }
55 |
56 | function MatrixLegend() { // NOTE(review): also declared inside the component, so it is re-created as a new component type on every render
57 | return isChartModeCorr ? (
58 |
59 | ) : (
60 |
61 | );
62 | }
63 |
64 | return (
65 |
66 |
70 |
71 |
72 | {forgetClassExist && areAllModelsSelected && (
73 |
77 | )}
78 |
79 | {forgetClassExist ? (
80 | !areAllModelsSelected ? (
81 |
82 | ) : (
83 |
84 |
85 |
86 | {modelAExperiment && (
87 |
94 | )}
95 | {modelBExperiment && (
96 |
104 | )}
105 |
106 |
107 | )
108 | ) : (
109 |
110 | )}
111 |
112 | );
113 | }
114 |
--------------------------------------------------------------------------------
/frontend/src/views/ModelScreening/Experiments.tsx:
--------------------------------------------------------------------------------
1 | import { useState } from "react";
2 |
3 | import View from "../../components/common/View";
4 | import Title from "../../components/common/Title";
5 | import Indicator from "../../components/common/Indicator";
6 | import DataTable from "../../components/ModelScreening/Experiments/DataTable";
7 | import Legend from "../../components/ModelScreening/Experiments/Legend";
8 | import { useForgetClassStore } from "../../stores/forgetClassStore";
9 | import { ArrowDownIcon, ArrowUpIcon } from "../../components/UI/icons";
10 | import { CONFIG } from "../../app/App";
11 |
12 | export default function Experiments() { // Experiments panel with an expand/collapse toggle; its content depends on whether a forget class is selected.
13 | const forgetClass = useForgetClassStore((state) => state.forgetClass);
14 |
15 | const [isExpanded, setIsExpanded] = useState(false); // collapsed by default
16 |
17 | const forgetClassExist = forgetClass !== -1; // -1 means no forget class selected
18 |
19 | const handleExpandClick = () => { // toggles isExpanded
20 | setIsExpanded((prevState) => !prevState);
21 | };
22 |
23 | return (
24 |
29 |
30 |
31 |
32 |
36 | {isExpanded ? (
37 |
38 | ) : (
39 |
40 | )}
41 |
42 |
43 | {forgetClassExist &&
}
44 |
45 | {forgetClassExist ? (
46 |
47 | ) : (
48 |
49 | )}
50 |
51 | );
52 | }
53 |
--------------------------------------------------------------------------------
/frontend/src/views/ModelScreening/ModelScreening.tsx:
--------------------------------------------------------------------------------
1 | import Experiments from "./Experiments";
2 | import Progress from "./Progress";
3 | import { Separator } from "../../components/UI/separator";
4 |
5 | export default function ModelScreening() { // Model-screening column; markup not visible here — presumably stacks Experiments and Progress separated by a Separator
6 | return (
7 |
12 | );
13 | }
14 |
--------------------------------------------------------------------------------
/frontend/src/views/ModelScreening/Progress.tsx:
--------------------------------------------------------------------------------
1 | import { useState, useEffect, useMemo } from "react";
2 |
3 | import View from "../../components/common/View";
4 | import Stepper from "../../components/ModelScreening/Progress/Stepper";
5 | import Indicator from "../../components/common/Indicator";
6 | import AddModelsButton from "../../components/ModelScreening/Progress/AddModelsButton";
7 | import Pagination from "../../components/ModelScreening/Progress/Pagination";
8 | import { useForgetClassStore } from "../../stores/forgetClassStore";
9 | import { CONFIG } from "../../app/App";
10 | import { useRunningIndexStore } from "../../stores/runningIndexStore";
11 | import { useRunningStatusStore } from "../../stores/runningStatusStore";
12 | import { getProgressSteps } from "../../utils/data/getProgressSteps";
13 |
14 | export type Step = { // one entry rendered by the progress Stepper
15 | step: number;
16 | title: string;
17 | description: string;
18 | };
19 |
20 | export const PREV = "prev"; // element id of the "previous page" pagination control
21 | export const NEXT = "next"; // element id of the "next page" pagination control
22 |
23 | export default function Progress() { // Progress stepper for the selected forget class, paginated over its experiment runs.
24 | const forgetClass = useForgetClassStore((state) => state.forgetClass);
25 | const runningIndex = useRunningIndexStore((state) => state.runningIndex); // appears 0-based: page = runningIndex + 1
26 | const { isRunning, statuses, activeStep, totalExperimentsCount } =
27 | useRunningStatusStore();
28 |
29 | const [umapProgress, setUmapProgress] = useState(0); // simulated UMAP progress, 0..100
30 | const [ckaProgress, setCkaProgress] = useState(0); // simulated CKA progress, 0..100
31 | const [currentPage, setCurrentPage] = useState(runningIndex + 1); // 1-based page in the pagination
32 |
33 | const displayedPageIdx = currentPage - 1; // 0-based index into statuses[forgetClass]
34 | const forgetClassExist = forgetClass !== -1; // NOTE(review): the pagination condition in the JSX reads statuses[forgetClass].length without this guard, which throws if totalExperimentsCount > 0 while forgetClass is -1 (assuming statuses has no -1 entry)
35 |
36 | useEffect(() => { // jump back to page 1 when a run starts
37 | if (isRunning) {
38 | setCurrentPage(1);
39 | }
40 | }, [isRunning]);
41 |
42 | useEffect(() => { // follow the currently running experiment's page
43 | setCurrentPage(runningIndex + 1);
44 | }, [runningIndex]);
45 |
46 | const currentStatus = // status of the displayed run, or null when that page has none
47 | forgetClassExist && statuses[forgetClass].length > displayedPageIdx
48 | ? statuses[forgetClass][displayedPageIdx]
49 | : null;
50 |
51 | const progress = currentStatus ? currentStatus.progress : ""; // "" when there is no status to show
52 |
53 | const steps: Step[] = useMemo(
54 | () =>
55 | forgetClassExist
56 | ? getProgressSteps(currentStatus, activeStep, umapProgress, ckaProgress)
57 | : [],
58 | [activeStep, ckaProgress, currentStatus, forgetClassExist, umapProgress]
59 | );
60 |
61 | useEffect(() => { // simulated UMAP/CKA progress: counts 0 -> 100 over ~10 s in 100 ms ticks; restarts whenever `progress` changes
62 | let intervalId: NodeJS.Timeout | null = null;
63 | const startTime = Date.now();
64 | const durationInSeconds = 10;
65 | const maxProgress = durationInSeconds * 10; // = 100 ticks of 100 ms
66 |
67 | if (forgetClassExist && progress) { // NOTE(review): the interval also runs (without updating state) when progress mentions neither UMAP nor CKA
68 | intervalId = setInterval(() => {
69 | const elapsedTime = Date.now() - startTime;
70 | const progressValue = Math.min(
71 | Math.floor(elapsedTime / 100),
72 | maxProgress
73 | );
74 |
75 | if (progress.includes("UMAP")) {
76 | setUmapProgress(progressValue);
77 | } else if (progress.includes("CKA")) {
78 | setCkaProgress(progressValue);
79 | }
80 |
81 | if (progressValue === maxProgress) { // stop ticking once the simulated bar is full
82 | clearInterval(intervalId!);
83 | }
84 | }, 100);
85 | }
86 |
87 | return () => {
88 | if (intervalId) clearInterval(intervalId);
89 | };
90 | }, [forgetClassExist, progress]);
91 |
92 | const handlePaginationClick = (event: React.MouseEvent) => { // PREV/NEXT paging, clamped to pages 1..totalExperimentsCount
93 | const id = event.currentTarget.id;
94 | if (id === PREV && currentPage > 1) {
95 | setCurrentPage((prevPage) => prevPage - 1);
96 | } else if (id === NEXT && currentPage < totalExperimentsCount) {
97 | setCurrentPage((prevPage) => prevPage + 1);
98 | }
99 | };
100 |
101 | return (
102 |
106 |
107 | {forgetClassExist ? (
108 |
118 | ) : (
119 |
120 | )}
121 | {totalExperimentsCount > 0 && statuses[forgetClass].length > 0 && (
122 |
123 | )}
124 |
125 | );
126 | }
127 |
--------------------------------------------------------------------------------
/frontend/tailwind.config.js:
--------------------------------------------------------------------------------
1 | /** @type {import('tailwindcss').Config} */
2 | module.exports = {
3 | content: ["./src/**/*.{ts,tsx,js,jsx}"],
4 | theme: {
5 | container: {
6 | center: true,
7 | padding: "2rem",
8 | screens: {
9 | "2xl": "1400px",
10 | },
11 | },
12 | extend: {
13 | colors: {
14 | border: "hsl(var(--border))",
15 | input: "hsl(var(--input))",
16 | ring: "hsl(var(--ring))",
17 | background: "hsl(var(--background))",
18 | foreground: "hsl(var(--foreground))",
19 | primary: {
20 | DEFAULT: "hsl(var(--primary))",
21 | foreground: "hsl(var(--primary-foreground))",
22 | },
23 | secondary: {
24 | DEFAULT: "hsl(var(--secondary))",
25 | foreground: "hsl(var(--secondary-foreground))",
26 | },
27 | destructive: {
28 | DEFAULT: "hsl(var(--destructive))",
29 | foreground: "hsl(var(--destructive-foreground))",
30 | },
31 | muted: {
32 | DEFAULT: "hsl(var(--muted))",
33 | foreground: "hsl(var(--muted-foreground))",
34 | },
35 | accent: {
36 | DEFAULT: "hsl(var(--accent))",
37 | foreground: "hsl(var(--accent-foreground))",
38 | },
39 | popover: {
40 | DEFAULT: "hsl(var(--popover))",
41 | foreground: "hsl(var(--popover-foreground))",
42 | },
43 | card: {
44 | DEFAULT: "hsl(var(--card))",
45 | foreground: "hsl(var(--card-foreground))",
46 | },
47 | },
48 | borderRadius: {
49 | lg: "var(--radius)",
50 | md: "calc(var(--radius) - 2px)",
51 | sm: "calc(var(--radius) - 4px)",
52 | },
53 | keyframes: {
54 | bgPulse: {
55 | "0%, 100%": { backgroundColor: "#f0f6fa" },
56 | "50%": { backgroundColor: "#e0e6f4" },
57 | },
58 | "accordion-down": {
59 | from: { height: 0 },
60 | to: { height: "var(--radix-accordion-content-height)" },
61 | },
62 | "accordion-up": {
63 | from: { height: "var(--radix-accordion-content-height)" },
64 | to: { height: 0 },
65 | },
66 | },
67 | animation: {
68 | bgPulse: "bgPulse 2s infinite",
69 | "accordion-down": "accordion-down 0.2s ease-out",
70 | "accordion-up": "accordion-up 0.2s ease-out",
71 | },
72 | },
73 | },
74 | plugins: [require("tailwindcss-animate")],
75 | };
76 |
--------------------------------------------------------------------------------
/frontend/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "target": "es5",
4 | "lib": ["dom", "dom.iterable", "esnext"],
5 | "allowJs": true,
6 | "skipLibCheck": true,
7 | "esModuleInterop": true,
8 | "allowSyntheticDefaultImports": true,
9 | "strict": true,
10 | "forceConsistentCasingInFileNames": true,
11 | "noFallthroughCasesInSwitch": true,
12 | "module": "esnext",
13 | "moduleResolution": "node",
14 | "resolveJsonModule": true,
15 | "isolatedModules": true,
16 | "noEmit": true,
17 | "jsx": "react-jsx",
18 | "baseUrl": ".",
19 | "paths": {
20 | "@/*": ["./*"]
21 | }
22 | },
23 | "include": ["src/app/*", "src/index.tsx", "src/stores"]
24 | }
25 |
--------------------------------------------------------------------------------
/img/attack.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gnueaj/Machine-Unlearning-Comparator/bf0c2c08534ff1bda21300f5da8c407b9af02b61/img/attack.png
--------------------------------------------------------------------------------
/img/embeddings.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gnueaj/Machine-Unlearning-Comparator/bf0c2c08534ff1bda21300f5da8c407b9af02b61/img/embeddings.png
--------------------------------------------------------------------------------