├── configs ├── __init__.py └── task_specific │ ├── wm.yaml │ ├── motor.yaml │ ├── social.yaml │ ├── emotion.yaml │ ├── gambling.yaml │ ├── language.yaml │ └── relational.yaml ├── SupplmentaryMaterial.pdf ├── data ├── __pycache__ │ ├── dataset.cpython-39.pyc │ └── task_configs.cpython-39.pyc ├── task_configs.py └── dataset.py ├── models ├── __pycache__ │ ├── convnext.cpython-39.pyc │ └── components.cpython-39.pyc ├── components.py └── convnext.py ├── training ├── __pycache__ │ └── trainer.cpython-39.pyc └── trainer.py ├── LICENSE ├── README.md └── main.py /configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SupplmentaryMaterial.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jyejay/3D_ConvNeXt_for_fMRI/HEAD/SupplmentaryMaterial.pdf -------------------------------------------------------------------------------- /data/__pycache__/dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jyejay/3D_ConvNeXt_for_fMRI/HEAD/data/__pycache__/dataset.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/convnext.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jyejay/3D_ConvNeXt_for_fMRI/HEAD/models/__pycache__/convnext.cpython-39.pyc -------------------------------------------------------------------------------- /data/__pycache__/task_configs.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jyejay/3D_ConvNeXt_for_fMRI/HEAD/data/__pycache__/task_configs.cpython-39.pyc -------------------------------------------------------------------------------- 
/models/__pycache__/components.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jyejay/3D_ConvNeXt_for_fMRI/HEAD/models/__pycache__/components.cpython-39.pyc -------------------------------------------------------------------------------- /training/__pycache__/trainer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jyejay/3D_ConvNeXt_for_fMRI/HEAD/training/__pycache__/trainer.cpython-39.pyc -------------------------------------------------------------------------------- /configs/task_specific/wm.yaml: -------------------------------------------------------------------------------- 1 | # Working Memory task configuration 2 | 3 | task: WM 4 | model_depths: [3, 3, 9, 3] 5 | model_dims: [96, 192, 384, 768] 6 | drop_path_rate: 0.0 7 | batch_size: 32 8 | learning_rate: 1.0e-4 # decimal point so YAML 1.1 loaders such as PyYAML read a float, not a string 9 | num_epochs: 100 10 | num_workers: 8 11 | project_name: hcp_classification_wm 12 | 13 | # Data augmentation settings (if needed) 14 | augmentation: 15 | enable: false 16 | # Add specific augmentation parameters here 17 | 18 | # Optimizer settings 19 | optimizer: 20 | type: AdamW 21 | weight_decay: 0.05 22 | 23 | # Learning rate scheduler settings 24 | scheduler: 25 | enable: false 26 | # Add scheduler parameters here if needed -------------------------------------------------------------------------------- /configs/task_specific/motor.yaml: -------------------------------------------------------------------------------- 1 | # Motor task configuration 2 | 3 | task: MOTOR 4 | model_depths: [3, 3, 9, 3] 5 | model_dims: [96, 192, 384, 768] 6 | drop_path_rate: 0.0 7 | batch_size: 32 8 | learning_rate: 1.0e-4 # decimal point so YAML 1.1 loaders such as PyYAML read a float, not a string 9 | num_epochs: 100 10 | num_workers: 8 11 | project_name: hcp_classification_motor 12 | 13 | # Data augmentation settings (if needed) 14 | augmentation: 15 | enable: false 16 | # Add specific augmentation parameters here 17 | 18
| # Optimizer settings 19 | optimizer: 20 | type: AdamW 21 | weight_decay: 0.05 22 | 23 | # Learning rate scheduler settings 24 | scheduler: 25 | enable: false 26 | # Add scheduler parameters here if needed -------------------------------------------------------------------------------- /configs/task_specific/social.yaml: -------------------------------------------------------------------------------- 1 | # Social task configuration 2 | 3 | task: SOCIAL 4 | model_depths: [3, 3, 9, 3] 5 | model_dims: [96, 192, 384, 768] 6 | drop_path_rate: 0.0 7 | batch_size: 32 8 | learning_rate: 1.0e-4 # decimal point so YAML 1.1 loaders such as PyYAML read a float, not a string 9 | num_epochs: 100 10 | num_workers: 8 11 | project_name: hcp_classification_social 12 | 13 | # Data augmentation settings (if needed) 14 | augmentation: 15 | enable: false 16 | # Add specific augmentation parameters here 17 | 18 | # Optimizer settings 19 | optimizer: 20 | type: AdamW 21 | weight_decay: 0.05 22 | 23 | # Learning rate scheduler settings 24 | scheduler: 25 | enable: false 26 | # Add scheduler parameters here if needed -------------------------------------------------------------------------------- /configs/task_specific/emotion.yaml: -------------------------------------------------------------------------------- 1 | # Emotion task configuration 2 | 3 | task: EMOTION 4 | model_depths: [3, 3, 9, 3] 5 | model_dims: [96, 192, 384, 768] 6 | drop_path_rate: 0.0 7 | batch_size: 32 8 | learning_rate: 1.0e-4 # decimal point so YAML 1.1 loaders such as PyYAML read a float, not a string 9 | num_epochs: 100 10 | num_workers: 8 11 | project_name: hcp_classification_emotion 12 | 13 | # Data augmentation settings (if needed) 14 | augmentation: 15 | enable: false 16 | # Add specific augmentation parameters here 17 | 18 | # Optimizer settings 19 | optimizer: 20 | type: AdamW 21 | weight_decay: 0.05 22 | 23 | # Learning rate scheduler settings 24 | scheduler: 25 | enable: false 26 | # Add scheduler parameters here if needed -------------------------------------------------------------------------------- /configs/task_specific/gambling.yaml:
-------------------------------------------------------------------------------- 1 | # Gambling task configuration 2 | 3 | task: GAMBLING 4 | model_depths: [3, 3, 9, 3] 5 | model_dims: [96, 192, 384, 768] 6 | drop_path_rate: 0.0 7 | batch_size: 32 8 | learning_rate: 1.0e-4 # decimal point so YAML 1.1 loaders such as PyYAML read a float, not a string 9 | num_epochs: 100 10 | num_workers: 8 11 | project_name: hcp_classification_gambling 12 | 13 | # Data augmentation settings (if needed) 14 | augmentation: 15 | enable: false 16 | # Add specific augmentation parameters here 17 | 18 | # Optimizer settings 19 | optimizer: 20 | type: AdamW 21 | weight_decay: 0.05 22 | 23 | # Learning rate scheduler settings 24 | scheduler: 25 | enable: false 26 | # Add scheduler parameters here if needed -------------------------------------------------------------------------------- /configs/task_specific/language.yaml: -------------------------------------------------------------------------------- 1 | # Language task configuration 2 | 3 | task: LANGUAGE 4 | model_depths: [3, 3, 9, 3] 5 | model_dims: [96, 192, 384, 768] 6 | drop_path_rate: 0.0 7 | batch_size: 32 8 | learning_rate: 1.0e-4 # decimal point so YAML 1.1 loaders such as PyYAML read a float, not a string 9 | num_epochs: 100 10 | num_workers: 8 11 | project_name: hcp_classification_language 12 | 13 | # Data augmentation settings (if needed) 14 | augmentation: 15 | enable: false 16 | # Add specific augmentation parameters here 17 | 18 | # Optimizer settings 19 | optimizer: 20 | type: AdamW 21 | weight_decay: 0.05 22 | 23 | # Learning rate scheduler settings 24 | scheduler: 25 | enable: false 26 | # Add scheduler parameters here if needed -------------------------------------------------------------------------------- /configs/task_specific/relational.yaml: -------------------------------------------------------------------------------- 1 | # Relational task configuration 2 | 3 | task: RELATIONAL 4 | model_depths: [3, 3, 9, 3] 5 | model_dims: [96, 192, 384, 768] 6 | drop_path_rate: 0.0 7 | batch_size: 32 8 | learning_rate: 1.0e-4 # decimal point so YAML 1.1 loaders such as PyYAML read a float, not a string 9 | num_epochs: 100 10 | num_workers: 8 11
| project_name: hcp_classification_relational 12 | 13 | # Data augmentation settings (if needed) 14 | augmentation: 15 | enable: false 16 | # Add specific augmentation parameters here 17 | 18 | # Optimizer settings 19 | optimizer: 20 | type: AdamW 21 | weight_decay: 0.05 22 | 23 | # Learning rate scheduler settings 24 | scheduler: 25 | enable: false 26 | # Add scheduler parameters here if needed -------------------------------------------------------------------------------- /data/task_configs.py: -------------------------------------------------------------------------------- 1 | ROOT_DIR = '/media/hcp_hdd/rs_HCP_ku/HCP_sample/' # Root of the data tree; HCPDataset joins it as ROOT_DIR/TASK/SUBJECT_ID. Change this to your local data path. 2 | 3 | TASK_CONFIGS = { # Per task: 'label_mapping' maps a label name (searched for as a substring of each .npy path by HCPDataset) to a class index; 'num_classes' sizes the model head and should equal len(label_mapping). 4 | 'WM': { 5 | 'label_mapping': { 6 | '0bk_body': 0, 7 | '0bk_faces': 1, 8 | '0bk_places': 2, 9 | '0bk_tools': 3, 10 | '2bk_body': 4, 11 | '2bk_faces': 5, 12 | '2bk_places': 6, 13 | '2bk_tools': 7 14 | }, 15 | 'num_classes': 8 16 | }, 17 | 'MOTOR': { 18 | 'label_mapping': { 19 | 'lf': 0, 20 | 'lh': 1, 21 | 'rf': 2, 22 | 'rh': 3, 23 | 't': 4 24 | }, 25 | 'num_classes': 5 26 | }, 27 | 'EMOTION': { 28 | 'label_mapping': { 29 | 'fear': 0, 30 | 'neut': 1 31 | }, 32 | 'num_classes': 2 33 | }, 34 | 'GAMBLING': { 35 | 'label_mapping': { 36 | 'loss': 0, 37 | 'win': 1 38 | }, 39 | 'num_classes': 2 40 | }, 41 | 'LANGUAGE': { 42 | 'label_mapping': { 43 | 'math': 0, 44 | 'story': 1 45 | }, 46 | 'num_classes': 2 47 | }, 48 | 'RELATIONAL': { 49 | 'label_mapping': { 50 | 'match': 0, 51 | 'relation': 1 52 | }, 53 | 'num_classes': 2 54 | }, 55 | 'SOCIAL': { 56 | 'label_mapping': { 57 | 'mental': 0, 58 | 'rnd': 1 59 | }, 60 | 'num_classes': 2 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 jyejay 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the
"Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | --- 24 | 25 | This project contains code from the following sources: 26 | 27 | 1. ConvNeXt architecture is modified from ConvNeXt 28 | (https://github.com/facebookresearch/ConvNeXt) 29 | Copyright (c) Meta Platforms, Inc. and affiliates. 30 | Licensed under Apache License 2.0 31 | 32 | 2. The GRN (Global Response Normalization) module is modified from ConvNeXt-V2 33 | (https://github.com/facebookresearch/ConvNeXt-V2) 34 | Copyright (c) Meta Platforms, Inc. and affiliates. 35 | Licensed under Apache License 2.0 36 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | from .task_configs import TASK_CONFIGS, ROOT_DIR 6 | 7 | 8 | class HCPDataset(Dataset): 9 | """ 10 | HCP Dataset for different tasks 11 | 12 | Args: 13 | task (str): Task name ('WM', 'MOTOR', etc.) 
14 | subject_ids (list): List of subject IDs to include 15 | transform (callable, optional): Optional transform to be applied on a sample 16 | """ 17 | def __init__(self, task, subject_ids, transform=None): 18 | self.file_paths = [] 19 | self.labels = [] 20 | self.transform = transform 21 | 22 | if task not in TASK_CONFIGS: 23 | raise ValueError(f"Task {task} not supported. Available tasks: {list(TASK_CONFIGS.keys())}") 24 | 25 | label_mapping = TASK_CONFIGS[task]['label_mapping'] 26 | 27 | # 데이터 파일 경로와 레이블 수집 28 | for subject_id in subject_ids: 29 | subject_dir = os.path.join(ROOT_DIR, task, subject_id) 30 | for root, _, files in os.walk(subject_dir): 31 | for file in files: 32 | if file.endswith('.npy'): 33 | file_path = os.path.join(root, file) 34 | self.file_paths.append(file_path) 35 | for label_str, label_num in label_mapping.items(): 36 | if label_str in file_path: 37 | self.labels.append(label_num) 38 | break 39 | 40 | def __len__(self): 41 | return len(self.file_paths) 42 | 43 | def __getitem__(self, idx): 44 | file_path = self.file_paths[idx] 45 | label = self.labels[idx] 46 | 47 | # 데이터 로드 48 | data = np.load(file_path) 49 | data_tensor = torch.tensor(data, dtype=torch.float32).unsqueeze(0) 50 | 51 | # 변환 적용 (필요한 경우) 52 | if self.transform: 53 | data_tensor = self.transform(data_tensor) 54 | 55 | label_tensor = torch.tensor(label, dtype=torch.long) 56 | return data_tensor, label_tensor -------------------------------------------------------------------------------- /models/components.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from timm.models.layers import DropPath 4 | 5 | """ 6 | The GRN (Global Response Normalization) module is modified from ConvNeXt-V2 7 | (https://github.com/facebookresearch/ConvNeXt-V2) 8 | Copyright (c) Meta Platforms, Inc. and affiliates. 
9 | Licensed under Apache License 2.0 10 | 11 | Modified to support 3D operations by jyejay 12 | 13 | """ 14 | 15 | class GRN3D(nn.Module): 16 | """ GRN (Global Response Normalization) layer for 3D data """ 17 | def __init__(self, num_channels): 18 | super().__init__() 19 | self.gamma = nn.Parameter(torch.zeros(1, num_channels, 1, 1, 1)) 20 | self.beta = nn.Parameter(torch.zeros(1, num_channels, 1, 1, 1)) 21 | 22 | def forward(self, x): 23 | Gx = torch.norm(x, p=2, dim=(2, 3, 4), keepdim=True) 24 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) 25 | return self.gamma * (x * Nx) + self.beta + x 26 | 27 | class LayerNorm3D(nn.Module): 28 | def __init__(self, num_channels, eps=1e-5): 29 | super().__init__() 30 | self.weight = nn.Parameter(torch.ones(num_channels)) 31 | self.bias = nn.Parameter(torch.zeros(num_channels)) 32 | self.eps = eps 33 | 34 | def forward(self, x): 35 | mean = x.mean(dim=1, keepdim=True) 36 | var = x.var(dim=1, keepdim=True, unbiased=False) 37 | x = (x - mean) / torch.sqrt(var + self.eps) 38 | weight = self.weight.view(1, -1, 1, 1, 1) 39 | bias = self.bias.view(1, -1, 1, 1, 1) 40 | return x * weight + bias 41 | 42 | class Block3D(nn.Module): 43 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): 44 | super().__init__() 45 | self.dwconv = nn.Conv3d(dim, dim, kernel_size=3, padding=1, groups=dim) 46 | self.norm = LayerNorm3D(dim) 47 | self.grn = GRN3D(dim*4) 48 | self.pwconv1 = nn.Conv3d(dim, dim * 4, kernel_size=1) 49 | self.act = nn.GELU() 50 | self.pwconv2 = nn.Conv3d(dim * 4, dim, kernel_size=1) 51 | self.gamma = nn.Parameter(torch.ones((dim,)) * layer_scale_init_value, 52 | requires_grad=True) if layer_scale_init_value > 0 else None 53 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 54 | 55 | def forward(self, x): 56 | input = x 57 | x = self.dwconv(x) 58 | x = self.norm(x) 59 | x = self.pwconv1(x) 60 | x = self.act(x) 61 | x = self.grn(x) 62 | x = self.pwconv2(x) 63 | if self.gamma is not None: 64 | gamma = self.gamma.view(1, -1, 1, 1, 1).expand_as(x) 65 | x = gamma * x 66 | return input + self.drop_path(x) -------------------------------------------------------------------------------- /models/convnext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .components import LayerNorm3D, Block3D 4 | 5 | """ 6 | This code is modified from ConvNeXt (https://github.com/facebookresearch/ConvNeXt) 7 | Copyright (c) Meta Platforms, Inc. and affiliates. 8 | Licensed under Apache License 2.0 9 | 10 | Modified to support 3D operations by jyejay 11 | 12 | """ 13 | 14 | class ConvNeXt3D(nn.Module): 15 | def __init__(self, 16 | in_chans=1, 17 | num_classes=5, 18 | depths=[3, 3, 9, 3], 19 | dims=[96, 192, 384, 768], 20 | drop_path_rate=0., 21 | layer_scale_init_value=1e-6): 22 | super().__init__() 23 | 24 | # Downsample layers 25 | self.downsample_layers = nn.ModuleList([ 26 | nn.Sequential( 27 | nn.Conv3d(in_chans, dims[0], kernel_size=4, stride=4), 28 | LayerNorm3D(dims[0]) 29 | ) 30 | ]) 31 | 32 | for i in range(1, len(dims)): 33 | self.downsample_layers.append(nn.Sequential( 34 | nn.Conv3d(dims[i - 1], dims[i], kernel_size=2, stride=2), 35 | LayerNorm3D(dims[i]) 36 | )) 37 | 38 | # Main stages 39 | self.stages = nn.ModuleList() 40 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 41 | cur = 0 42 | 43 | for i in range(4): 44 | stage = nn.Sequential(*[ 45 | Block3D(dim=dims[i], 46 | drop_path=dp_rates[cur + j], 47 | layer_scale_init_value=layer_scale_init_value) 48 | for j in range(depths[i]) 49 | ]) 50 | self.stages.append(stage) 51 | cur += depths[i] 52 | 53 | # Final norm and head 54 | self.norm = LayerNorm3D(dims[-1]) 
55 | self.head = nn.Linear(dims[-1], num_classes) 56 | 57 | self.apply(self._init_weights) 58 | 59 | def _init_weights(self, m): # only Conv3d modules are re-initialised; LayerNorm3D and the nn.Linear head keep their constructor defaults 60 | if isinstance(m, nn.Conv3d): 61 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 62 | if m.bias is not None: 63 | nn.init.constant_(m.bias, 0) 64 | 65 | def forward_features(self, x): # runs each downsample/stage pair, applies the final LayerNorm3D, then averages over the three spatial axes, giving (N, dims[-1]) 66 | for downsample, stage in zip(self.downsample_layers, self.stages): 67 | x = downsample(x) 68 | x = stage(x) 69 | x = self.norm(x) 70 | return x.mean([2, 3, 4]) # Global average pooling 71 | 72 | def forward(self, x): # assumes x is (N, in_chans, D, H, W) -- TODO confirm; returns (N, num_classes) logits 73 | x = self.forward_features(x) 74 | x = x.view(x.size(0), -1) # effectively a no-op: forward_features already returns a 2D tensor 75 | x = self.head(x) 76 | return x -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # JBHI HCP ConvNeXt 2 | This paper appears in: IEEE Journal of Biomedical and Health Informatics Digital Object Identifier: 10.1109/JBHI.2025.3606512 3 | 4 | Supplementary materials for Whole-Brain Task fMRI Decoding using Stage-Wise Residual-Optimized 3D ConvNeXt with Layer-Global Response Normalization 5 | 6 | # Running Instructions 7 | 8 | ### Prerequisites 9 | #### Data Directory Structure 10 | ``` 11 | # ROOT_DIR = "your_data_path/" 12 | └── EMOTION/ 13 | | └── subject_id/ # e.g., 100206 14 | | ├── fear/ # Label: 0 15 | | │ └── LR/ 16 | | │ ├── 0001.npy 17 | | │ ├── 0002.npy 18 | | │ └── ... 19 | | └── neut/ # Label: 1 20 | | └── LR/ 21 | | ├── 0001.npy 22 | | ├── 0002.npy 23 | | └── ...
24 | └── GAMBLING/ 25 | | └── subject_id/ 26 | | ├── loss/ # Label: 0 27 | | └── win/ # Label: 1 28 | └── LANGUAGE/ 29 | | └── subject_id/ 30 | | ├── math/ # Label: 0 31 | | └── story/ # Label: 1 32 | └── MOTOR/ 33 | | └── subject_id/ 34 | | ├── lf/ # Label: 0 35 | | ├── lh/ # Label: 1 36 | | ├── rf/ # Label: 2 37 | | ├── rh/ # Label: 3 38 | | └── t/ # Label: 4 39 | └── RELATIONAL/ 40 | | └── subject_id/ 41 | | ├── match/ # Label: 0 42 | | └── relation/ # Label: 1 43 | └── SOCIAL/ 44 | | └── subject_id/ 45 | | ├── mental/ # Label: 0 46 | | └── rnd/ # Label: 1 47 | └── WM/ 48 | └── subject_id/ 49 | ├── 0bk_body/ # Label: 0 50 | ├── 0bk_faces/ # Label: 1 51 | ├── 0bk_places/ # Label: 2 52 | ├── 0bk_tools/ # Label: 3 53 | ├── 2bk_body/ # Label: 4 54 | ├── 2bk_faces/ # Label: 5 55 | ├── 2bk_places/ # Label: 6 56 | └── 2bk_tools/ # Label: 7 57 | 58 | ``` 59 | #### Data Description 60 | - Each `.npy` file contains a 3D brain image 61 | - Labels are determined by the directory name (e.g., 'fear': 0, 'neut': 1) 62 | - Data is organized by subject IDs and task types 63 | 64 | #### Data Loading Process 65 | - **Dataset Class (`HCPDataset`):** 66 | - Loads 3D brain images from `.npy` files 67 | - Automatically assigns labels based on directory structure 68 | - Returns data as torch tensors with shape `[1, D, H, W]` 69 | 70 | Before running the code, you need to modify the data path in the configuration file: 71 | 72 | ```python 73 | # In data/task_configs.py 74 | ROOT_DIR = "your_data_path" # Change this to your data directory 75 | ``` 76 | 77 | ``` 78 | python main.py --task EMOTION --batch_size 32 --learning_rate 1e-4 --num_epochs 100 --device cuda:1 --model_dims 64 128 256 512 --drop_path_rate 0.1 79 | ``` 80 | 81 | 82 | ### Data arguments 83 | | Argument | Description | Default | 84 | |----------|-------------|---------| 85 | | `--task` | HCP task | - | 86 | 87 | ### Model arguments 88 | | Argument | Description | Default | 89 | |----------|-------------|---------| 90 | | 
`--model_depths` | Model layer depths | `[3, 3, 9, 3]` | 91 | | `--model_dims` | Model layer dimensions | `[96, 192, 384, 768]` | 92 | | `--drop_path_rate` | Drop path rate | `0.` | 93 | 94 | ### Training arguments 95 | | Argument | Description | Default | 96 | |----------|-------------|---------| 97 | | `--batch_size` | Batch size for training | `32` | 98 | | `--learning_rate` | Learning rate | `1e-4` | 99 | | `--num_epochs` | Number of epochs | `100` | 100 | | `--num_workers` | Number of workers | `8` | 101 | | `--device` | Device for training | `cuda` | 102 | 103 | ### Other arguments 104 | | Argument | Description | Default | 105 | |----------|-------------|---------| 106 | | `--seed` | Random seed | `1234` | 107 | | `--checkpoint_dir` | Checkpoint directory | `checkpoints` | 108 | | `--project_name` | Project name | `hcp_classification` | 109 | 110 | 111 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import numpy as np 5 | import argparse 6 | from pathlib import Path 7 | 8 | from models.convnext import ConvNeXt3D 9 | from data.task_configs import TASK_CONFIGS, ROOT_DIR 10 | from data.dataset import HCPDataset 11 | from training.trainer import train_fold, create_folds 12 | 13 | def set_seed(seed): 14 | """Seed Python's random module, NumPy and PyTorch (CPU and all CUDA devices) and make cuDNN deterministic, for reproducibility.""" 15 | random.seed(seed) 16 | np.random.seed(seed) 17 | torch.manual_seed(seed) 18 | torch.cuda.manual_seed(seed) 19 | torch.cuda.manual_seed_all(seed) 20 | torch.backends.cudnn.deterministic = True 21 | torch.backends.cudnn.benchmark = False # disable cuDNN autotuning, whose algorithm choice can vary between runs 22 | os.environ['PYTHONHASHSEED'] = str(seed) # only affects Python processes started after this point; the running interpreter's hash seed is fixed at startup 23 | 24 | def parse_args(): # build and parse the command-line options for one training run 25 | parser = argparse.ArgumentParser(description='Train HCP Classification Model') 26 | 27 | # Data arguments 28 | parser.add_argument('--task', type=str, required=True, choices=list(TASK_CONFIGS.keys()), 29 | help='Task to train on')
30 | 31 | # Model arguments 32 | parser.add_argument('--model_depths', type=int, nargs=4, default=[3, 3, 9, 3], 33 | help='Depth of each stage') 34 | parser.add_argument('--model_dims', type=int, nargs=4, default=[96, 192, 384, 768], 35 | help='Dimensions of each stage') 36 | parser.add_argument('--drop_path_rate', type=float, default=0., 37 | help='Drop path rate') 38 | 39 | # Training arguments 40 | parser.add_argument('--batch_size', type=int, default=32) 41 | parser.add_argument('--learning_rate', type=float, default=1e-4) 42 | parser.add_argument('--num_epochs', type=int, default=100) 43 | parser.add_argument('--num_workers', type=int, default=8) 44 | parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') 45 | 46 | # Other arguments 47 | parser.add_argument('--seed', type=int, default=1234) 48 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoints') 49 | parser.add_argument('--project_name', type=str, default='hcp_classification') 50 | 51 | return parser.parse_args() 52 | 53 | def main(): 54 | args = parse_args() 55 | set_seed(args.seed) 56 | 57 | # Create checkpoint directory 58 | Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True) 59 | 60 | # Prepare data and create folds 61 | subject_ids = [] 62 | for root, _, files in os.walk(os.path.join(ROOT_DIR, args.task)): 63 | for file in files: 64 | if file.endswith('.npy'): 65 | subject_id = root.split(os.sep)[-3] 66 | if subject_id not in subject_ids: 67 | subject_ids.append(subject_id) 68 | 69 | folds = create_folds(subject_ids, n_splits=5, random_state=args.seed) 70 | config = vars(args) # Convert args to dictionary 71 | 72 | # Cross-validation training 73 | fold_accuracies = [] 74 | for fold_idx, test_subjects in enumerate(folds): 75 | print(f"\nTraining Fold {fold_idx + 1}/5") 76 | 77 | # Prepare train/test split for this fold 78 | train_subjects = [subj for subj in subject_ids if subj not in test_subjects] 79 | 80 | train_dataset = 
HCPDataset(args.task, train_subjects) 81 | test_dataset = HCPDataset(args.task, test_subjects) 82 | 83 | # Create model for this fold 84 | model = ConvNeXt3D( 85 | in_chans=1, 86 | num_classes=TASK_CONFIGS[args.task]['num_classes'], 87 | depths=args.model_depths, 88 | dims=args.model_dims, 89 | drop_path_rate=args.drop_path_rate 90 | ).to(args.device) 91 | 92 | # Train model 93 | best_acc = train_fold(fold_idx, model, train_dataset, test_dataset, config) # best per-epoch accuracy (percent) on this fold's held-out subjects 94 | fold_accuracies.append(best_acc) 95 | 96 | print(f"Fold {fold_idx + 1} Best Accuracy: {best_acc:.2f}%") 97 | 98 | # Print final results 99 | print("\nCross-validation Results:") 100 | print(f"Mean Accuracy: {np.mean(fold_accuracies):.2f}% ± {np.std(fold_accuracies):.2f}%") # np.std uses ddof=0, i.e. the population standard deviation across folds 101 | print(f"Individual Fold Accuracies: {[f'{acc:.2f}%' for acc in fold_accuracies]}") 102 | 103 | if __name__ == '__main__': 104 | main() -------------------------------------------------------------------------------- /training/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import wandb 3 | from tqdm import tqdm 4 | from torch.utils.data import DataLoader 5 | from sklearn.model_selection import KFold 6 | from data.dataset import HCPDataset # NOTE(review): not referenced in this module 7 | from data.task_configs import TASK_CONFIGS # NOTE(review): not referenced in this module 8 | import os 9 | 10 | class Trainer: # bundles model, optimizer, loss and device; runs one training or evaluation pass over a DataLoader 11 | def __init__(self, model, optimizer, criterion, device, config): 12 | self.model = model 13 | self.optimizer = optimizer 14 | self.criterion = criterion 15 | self.device = device 16 | self.config = config 17 | 18 | def train_one_epoch(self, data_loader): # one optimisation pass; returns (sample-weighted mean loss, accuracy in percent) 19 | self.model.train() 20 | running_loss = 0.0 21 | correct = 0 22 | total = 0 23 | 24 | for inputs, labels in data_loader: 25 | inputs = inputs.to(self.device, non_blocking=True) 26 | labels = labels.to(self.device, non_blocking=True) 27 | 28 | self.optimizer.zero_grad() 29 | outputs = self.model(inputs) 30 | loss = self.criterion(outputs, labels) 31 | loss.backward() 32 | self.optimizer.step()
33 | 34 | running_loss += loss.item() * inputs.size(0) # assumes the criterion returns a batch mean; re-weight by batch size 35 | _, predicted = torch.max(outputs, 1) 36 | total += labels.size(0) 37 | correct += (predicted == labels).sum().item() 38 | 39 | epoch_loss = running_loss / total # raises ZeroDivisionError if data_loader yielded no samples 40 | epoch_acc = (correct / total) * 100 41 | return epoch_loss, epoch_acc 42 | 43 | def evaluate(self, data_loader): # same metrics as train_one_epoch, computed in eval mode without gradients or parameter updates 44 | self.model.eval() 45 | running_loss = 0.0 46 | correct = 0 47 | total = 0 48 | 49 | with torch.no_grad(): 50 | for inputs, labels in data_loader: 51 | inputs = inputs.to(self.device, non_blocking=True) 52 | labels = labels.to(self.device, non_blocking=True) 53 | 54 | outputs = self.model(inputs) 55 | loss = self.criterion(outputs, labels) 56 | 57 | running_loss += loss.item() * inputs.size(0) 58 | _, predicted = torch.max(outputs, 1) 59 | total += labels.size(0) 60 | correct += (predicted == labels).sum().item() 61 | 62 | epoch_loss = running_loss / total 63 | epoch_acc = (correct / total) * 100 64 | return epoch_loss, epoch_acc 65 | 66 | def train_fold(fold_idx, model, train_dataset, test_dataset, config): 67 | """Train model on train_dataset for config['num_epochs'] epochs, evaluating on test_dataset after each epoch; log metrics to wandb, save a checkpoint whenever test accuracy improves, and return the best test accuracy (percent).""" 68 | wandb.init( 69 | project=config['project_name'], 70 | name=f"fold_{fold_idx}", 71 | config=config, 72 | reinit=True 73 | ) 74 | 75 | train_loader = DataLoader( 76 | train_dataset, 77 | batch_size=config['batch_size'], 78 | shuffle=True, 79 | num_workers=config['num_workers'], 80 | pin_memory=True 81 | ) 82 | 83 | test_loader = DataLoader( 84 | test_dataset, 85 | batch_size=config['batch_size'], 86 | shuffle=False, 87 | num_workers=config['num_workers'], 88 | pin_memory=True 89 | ) 90 | 91 | trainer = Trainer( 92 | model=model, 93 | optimizer=torch.optim.AdamW(model.parameters(), lr=config['learning_rate']), # AdamW's default weight_decay is used; no weight-decay value is read from config 94 | criterion=torch.nn.CrossEntropyLoss(), 95 | device=config['device'], 96 | config=config 97 | ) 98 | 99 | best_acc = 0 # NOTE(review): the best epoch is chosen on the held-out test fold itself (no separate validation split), which biases best_acc upward 100 | 101 | for epoch in tqdm(range(config['num_epochs']), desc=f"Fold {fold_idx} Training"): 102 | train_loss, train_acc =
trainer.train_one_epoch(train_loader) 103 | test_loss, test_acc = trainer.evaluate(test_loader) 104 | 105 | if test_acc > best_acc: 106 | best_acc = test_acc 107 | checkpoint_name = f"task_{config['task']}_depths{''.join(map(str, config['model_depths']))}_dims{''.join(map(str, config['model_dims']))}_fold{fold_idx}_best.pth" # list values are concatenated without separators (e.g. depths3393); the same file is overwritten on every improvement 108 | checkpoint_path = os.path.join(config['checkpoint_dir'], checkpoint_name) 109 | 110 | torch.save({ 111 | 'epoch': epoch, 112 | 'model_state_dict': model.state_dict(), 113 | 'optimizer_state_dict': trainer.optimizer.state_dict(), 114 | 'best_acc': best_acc, 115 | 'task': config['task'], 116 | 'model_depths': config['model_depths'], 117 | 'model_dims': config['model_dims'], 118 | 'config': config # also store the full run configuration 119 | }, checkpoint_path) 120 | 121 | 122 | wandb.log({ 123 | 'epoch': epoch, 124 | 'train_loss': train_loss, 125 | 'train_acc': train_acc, 126 | 'test_loss': test_loss, 127 | 'test_acc': test_acc, 128 | 'best_acc': best_acc 129 | }) 130 | 131 | wandb.finish() # NOTE(review): not in a finally block, so an exception during training skips it 132 | return best_acc 133 | 134 | def create_folds(subject_ids, n_splits=5, random_state=42): 135 | """Split subject_ids into n_splits disjoint test folds with shuffled KFold (subject-level, not label-stratified), returning each fold as a list of subject IDs.""" 136 | kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state) 137 | folds = [] 138 | for _, test_idx in kf.split(subject_ids): 139 | folds.append([subject_ids[i] for i in test_idx]) 140 | return folds --------------------------------------------------------------------------------