├── imgs ├── archive.png ├── bt_example.png ├── sl_example.png └── HAM10000_images.png ├── exp.py ├── main.py ├── infer-hf.py ├── README.md ├── test-hf.py ├── train-pt.py └── train-hf.py /imgs/archive.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hcd233/Skin_Cancer_Classification/HEAD/imgs/archive.png -------------------------------------------------------------------------------- /imgs/bt_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hcd233/Skin_Cancer_Classification/HEAD/imgs/bt_example.png -------------------------------------------------------------------------------- /imgs/sl_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hcd233/Skin_Cancer_Classification/HEAD/imgs/sl_example.png -------------------------------------------------------------------------------- /imgs/HAM10000_images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hcd233/Skin_Cancer_Classification/HEAD/imgs/HAM10000_images.png -------------------------------------------------------------------------------- /exp.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import json 3 | from tqdm import tqdm 4 | from PIL import Image 5 | import os 6 | import shutil 7 | 8 | 9 | labels_dir = "./archive/labels/" 10 | imgs_dir = "./archive/HAM10000_images/" 11 | 12 | 13 | def MakeLabels(dataframe: pd.DataFrame) -> list: 14 | """image_id: str -> PIL.Image""" 15 | 16 | assert "image_id" in dataframe.columns and "dx" in dataframe.columns 17 | 18 | lens, _ = dataframe.shape 19 | # postfix = ".jpg" 20 | dataset = [] 21 | t = tqdm(range(lens)) 22 | for i in t: 23 | t.set_description("Make labels") 24 | label_path = dataframe['image_id'][i] + ".json" 25 | img_path = dataframe['image_id'][i] + ".jpg" 26 | content = {"labels": [{"name": dataframe['dx'][i]}]} 27 | with open(labels_dir + label_path,mode='w+',encoding="utf-8") as file: 28 | json.dump(obj=content,fp=file,indent=4) 29 | shutil.copy(imgs_dir+img_path,labels_dir+img_path) 30 | 31 | if __name__ == "__main__": 32 | df = pd.read_csv("./archive/HAM4000_metadata.csv") 33 | MakeLabels(df) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # import library 2 | import pandas as pd 3 | from PIL import Image 4 | from tqdm import tqdm 5 | from transformers import pipeline 6 | 7 | 8 | def compute_accuracy(): 9 | correct = 0 10 | lens, _ = df.shape 11 | postfix = ".jpg" 12 | t = tqdm(range(lens)) 13 | 14 | for i in t: 15 | t.set_description(f"Correct Num: {correct}") 16 | img = Image.open(IMAGES_DIR + df['image_id'][i] + postfix) 17 | res = classifier(img) 18 | if IDX2LABEL[res[0]['label']] == df['dx'][i]: 19 | correct += 1 20 | return correct / lens 21 | 22 | 23 | IMAGES_DIR = './archive/HAM10000_images/' 24 | METADATA_PATH = './archive/HAM10000_metadata.csv' 25 | MODEL_DIR = './checkpoints/vit-large-91' 26 | 27 | IDX2LABEL = {'LABEL_0': 'vasc', 28 | 'LABEL_1': 'bcc', 29 | 'LABEL_2': 'mel', 30 | 'LABEL_3': 'nv', 31 | 'LABEL_4': 'df', 32 | 'LABEL_5': 'akiec', 33 | 'LABEL_6': 'bkl'} 34 | 35 | df = pd.read_csv(METADATA_PATH, usecols=['image_id', 'dx']) 36 | classifier = pipeline("image-classification", model=MODEL_DIR) 37 | 38 | if __name__ == '__main__': 39 | acc = compute_accuracy() 40 | print("Accuracy for entire dataset: {:.6f}".format(acc)) 41 | # 94.0589 % 42 | -------------------------------------------------------------------------------- /infer-hf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from PIL import Image 5 | from transformers import pipeline 6 | 7 | # argparse 8 | parser = argparse.ArgumentParser(description='Infer Model') 9 | 10 | parser.add_argument('--image_path', type=str, default=None, 11 | help='single image inference') 12 | parser.add_argument('--model_dir', type=str, default='./checkpoints/vit-large-91', 13 | help='path to pretrained model directory') 14 | 15 | parser.add_argument('--batch_images_dir', type=str, default=None, 16 | help='batch images inference') 17 | 18 | args = parser.parse_args() 19 | 20 | # hyperparameter 21 | 22 | MODEL_DIR = args.model_dir 23 | 24 | IMAGE_PATH = args.image_path 25 | 26 | BATCH_IMAGES_DIR = args.batch_images_dir 27 | 28 | IDX2LABEL = {'LABEL_0': 'vasc', 29 | 'LABEL_1': 'bcc', 30 | 'LABEL_2': 'mel', 31 | 'LABEL_3': 'nv', 32 | 'LABEL_4': 'df', 33 | 'LABEL_5': 'akiec', 34 | 'LABEL_6': 'bkl', 35 | 'LABEL_7': 'not a cancer image'} 36 | 37 | # Build Pipline 38 | 39 | assert not (IMAGE_PATH is None and BATCH_IMAGES_DIR is None), "Invalid Image Input." 40 | 41 | if __name__ == '__main__': 42 | classifier = pipeline("image-classification", model=MODEL_DIR) 43 | 44 | if IMAGE_PATH is not None: 45 | img = Image.open(IMAGE_PATH) 46 | 47 | infer_result = classifier(img) 48 | 49 | img.close() 50 | print(f"Image: {IMAGE_PATH} Result: {IDX2LABEL[infer_result[0]['label']]}") 51 | print(infer_result) 52 | 53 | if BATCH_IMAGES_DIR is not None: 54 | files = os.listdir(BATCH_IMAGES_DIR) 55 | lens = len(files) 56 | imgs = [] 57 | for i in range(lens): 58 | imgs.append(Image.open(BATCH_IMAGES_DIR + '\\' + files[i])) 59 | for i in range(lens): 60 | infer_result = classifier(imgs[i]) 61 | print(f"Image: {files[i]} Result: {IDX2LABEL[infer_result[0]['label']]} Confi: {infer_result[0]['score']}") 62 | print(infer_result[:2]) 63 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Skin_Cancer_Classification 2 | 皮肤癌分类任务,采用HAM10000数据集,总共7个类别 3 | ## Train-hf.py 4 | 5 | 这是一个用于训练图像分类模型的代码。在运行代码之前,用户需要安装以下依赖库:argparse, os, pandas, numpy, PIL, datasets, torchvision, tqdm和transformers。用户还需要从Hugging Face上下载所需的预训练模型。 6 | 7 | ### 参数说明 8 | 9 | - `--metadata_path`:metadata文件的路径。默认为"./archive/HAM10000_metadata.csv"。 10 | - `--images_dir`:图像文件夹的路径。默认为"./archive/HAM10000_images/"。 11 | - `--model_dir`:预训练模型的路径。默认为"../model/vit-large-patch16-224-in21k"。 12 | - `--checkpoints_dir`:保存检查点文件的文件夹路径。默认为"./checkpoints"。 13 | - `--learning_rate`:学习率。默认为1e-5。 14 | - `--batch_size`:批大小。默认为64。 15 | - `--epochs`:训练轮数。默认为5。 16 | - `--warmup_ratio`:预热步骤的比例。默认为0.1。 17 | - `--split`:训练-验证数据集的分割比例。默认为0.8。 18 | - `--gpu`:指定使用哪张GPU。默认为"0"。 19 | - `--logging_steps`:每隔多少步记录一次训练日志。默认为50。 20 | 21 | 用户可以在命令行中传递这些参数,例如: 22 | ```shell 23 | python train-hf.py --metadata_path ./archive/HAM4000_metadata.csv \ 24 | --images_dir ./archive/HAM10000_images/ \ 25 | --checkpoints_dir ./checkpoints \ 26 | --learning_rate 1e-4 \ 27 | --batch_size 64 \ 28 | --epochs 20 \ 29 | --warmup_ratio 0.1 \ 30 | --model_dir ../model/vit-large-patch16-224-in21k \ 31 | --gpu 5,6,7 \ 32 | --logging_steps 1 33 | ``` 34 | 在代码运行过程中,会执行以下步骤: 35 | 36 | 1. 读取metadata文件,获取图像文件名和标签。 37 | 2. 将图像读入内存,并随机打乱。 38 | 3. 将数据集划分为训练集和验证集。 39 | 4. 对图像进行预处理,包括随机裁剪、归一化和转换为tensor。 40 | 5. 加载预训练模型,构建分类器。 41 | 6. 训练模型,并在验证集上评估模型性能。 42 | 7. 在训练过程中,每隔logging_steps步记录一次训练日志,包括损失值、准确率等指标。 43 | 8. 在训练结束后,保存模型的权重文件到checkpoints_dir文件夹中。 44 | 45 | 46 | 47 | ## Infer-hf.py 48 | 这个Python脚本是一个基于Transformers库的图像分类推理模型,能够对单张或批量的图像进行分类预测。使用该脚本需要在命令行中指定相关参数。 49 | ### 参数说明 50 | `--image_path`:指定单张图片的路径,例如 `--image_path ./test.jpg` 默认为`None` 51 | 52 | `--model_dir`:指定预训练模型的路径,例如`--model_dir ./checkpoints/vit-large-91`默认`./checkpoints/vit-large-91` 53 | 54 | `--batch_images_dir`:指定需要批量预测的图片所在文件夹的路径,例如`--batch_images_dir ./batch_images`默认为`None 55 | ` 56 | ### 输出说明 57 | 在脚本运行时,会先根据指定的预训练模型构建一个图像分类推理模型。接下来,如果指定了`--image_path`则会对该路径下的单张图片进行预测,输出预测结果和置信度。如果指定了`--batch_images_dir` 58 | 则会对该文件夹下的所有图片进行批量预测,输出每张图片的预测结果和置信度。 59 | ### 使用示例 60 | 1. 对单张图片进行预测: 61 | ```shell 62 | python infer-hf.py --image_path ./test.jpg --model_dir ./checkpoints/vit-large-91 63 | ``` 64 | ![alt example 1](./imgs/sl_example.png) 65 | 2. 夹下的所有图片进行批量预测: 66 | ```shell 67 | python infer-hf.py --batch_images_dir ./batch_images --model_dir ./checkpoints/vit-large-91 68 | ``` 69 | ![alt example 1](./imgs/bt_example.png) 70 | 71 | ## Test-hf.py 72 | 对测试集进行测试,修改超参数Run一下即可,小心爆显存,实测3060laptop batch_size 200以内没问题 73 | -------------------------------------------------------------------------------- /test-hf.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from PIL import Image 3 | from datasets import Dataset 4 | from torch import device 5 | from tqdm import tqdm 6 | from transformers import pipeline 7 | 8 | # hyper 9 | TEST_IMAGES_DIR = "./archive/test/" 10 | TEST_TRUTH_CSV = "./archive/test_truth.csv" 11 | MODEL_PATH = "./checkpoints/vit-large-89-386" 12 | BATCH_SIZE = 128 13 | DEVICE = device("cuda:0") 14 | SAVE_LOG = f"./logs/test-{MODEL_PATH[MODEL_PATH.rfind('/') + 1:]}.txt" 15 | 16 | 17 | # utils functions 18 | 19 | def ReadTestImage(dataframe: pd.DataFrame, prefix) -> list: 20 | assert "image" in dataframe.columns and "label" in dataframe.columns 21 | 22 | lens, _ = dataframe.shape 23 | postfix = ".jpg" 24 | dataset = [] 25 | t = tqdm(range(lens)) 26 | for i in t: 27 | t.set_description("Reading Image") 28 | img = Image.open(prefix + dataframe['image'][i] + postfix) 29 | dataset.append( 30 | { 31 | "name": dataframe['image'][i], 32 | "image": img, 33 | "label": dataframe['label'][i] 34 | } 35 | ) 36 | img.close() 37 | return dataset 38 | 39 | 40 | # csv 41 | # image MEL NV BCC AKIEC BKL DF VASC 42 | # 0 False False True False False False False False 43 | # 1 False False True False False False False False 44 | # 2 False False False False False True False False 45 | # 3 False False True False False False False False 46 | # 4 False False True False False False False False 47 | # ... ... ... ... ... ... ... ... ... 48 | # 1507 False False False False False True False False 49 | # 1508 False False True False False False False False 50 | # 1509 False False False False True False False False 51 | # 1510 False False False False False True False False 52 | # 1511 False False False True False False False False 53 | # 54 | # [1512 rows x 8 columns] 55 | 56 | 57 | df = pd.read_csv(TEST_TRUTH_CSV) 58 | df['label'] = df.iloc[:, 1:].idxmax(axis=1) 59 | df = df.drop(columns=['MEL', 'NV', 'BCC', 'AKIEC', 'BKL', 'DF', 'VASC']) 60 | 61 | # processed csv 62 | # image label 63 | # 0 ISIC_0034524 NV 64 | # 1 ISIC_0034525 NV 65 | # 2 ISIC_0034526 BKL 66 | # 3 ISIC_0034527 NV 67 | # 4 ISIC_0034528 NV 68 | # ... ... ... 69 | # 1507 ISIC_0036060 BKL 70 | # 1508 ISIC_0036061 NV 71 | # 1509 ISIC_0036062 AKIEC 72 | # 1510 ISIC_0036063 BKL 73 | # 1511 ISIC_0036064 BCC 74 | # 75 | # [1512 rows x 2 columns] 76 | 77 | # dataset 78 | IDX2LABEL = {'LABEL_0': 'VASC', 79 | 'LABEL_1': 'BCC', 80 | 'LABEL_2': 'MEL', 81 | 'LABEL_3': 'NV', 82 | 'LABEL_4': 'DF', 83 | 'LABEL_5': 'AKIEC', 84 | 'LABEL_6': 'BKL', 85 | 'LABEL_7': 'Not A Cancer Image'} 86 | 87 | ds = ReadTestImage(df, prefix=TEST_IMAGES_DIR) 88 | 89 | ds = Dataset.from_list(ds) 90 | 91 | # pineline 92 | 93 | classifier = pipeline(task="image-classification", 94 | model=MODEL_PATH, 95 | device=DEVICE, 96 | batch_size=BATCH_SIZE) 97 | 98 | # test 99 | t = tqdm(range(0, len(ds), BATCH_SIZE)) 100 | t.set_description("Testing") 101 | 102 | infer_result = [] 103 | for i in t: 104 | batch_result = classifier(ds['image'][i:i + BATCH_SIZE]) 105 | infer_result += batch_result 106 | 107 | acc = sum([1 for i in range(len(infer_result)) if IDX2LABEL[infer_result[i][0]['label']] == ds['label'][i]]) / len( 108 | infer_result) \ 109 | * 100 110 | 111 | with open(SAVE_LOG, mode="w+", encoding="utf-8") as file: 112 | t = tqdm(range(len(infer_result))) 113 | t.set_description("Writing Result") 114 | file.write("Accuracy in {:d} samples : {:.6f}% \n".format(len(infer_result), acc)) 115 | for i in t: 116 | file.write(f"IMAGE: {ds['name'][i]} " 117 | f"Truth: {ds['label'][i]} " 118 | f"Pred: {IDX2LABEL[infer_result[i][0]['label']]} " 119 | f"Confi: {infer_result[i][0]['score']} \n") 120 | 121 | print("Accuracy in {:d} samples : {:.6f}%".format(len(infer_result), acc)) 122 | -------------------------------------------------------------------------------- /train-pt.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import pandas as pd 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from PIL import Image 8 | from torch.optim import AdamW, SGD, lr_scheduler 9 | from torch.utils.data import DataLoader, Dataset 10 | from torchvision.transforms import Compose, ToTensor, Normalize, CenterCrop, Resize, RandomResizedCrop 11 | from tqdm import tqdm 12 | from vit_pytorch import ViT 13 | 14 | # hyperparameter 15 | 16 | IMAGES_DIR = "./archive/HAM10000_images/" 17 | METADATA_CSV = "./archive/HAM4000_metadata.csv" 18 | CHECKPOINTS = "./checkpoints/" 19 | LOG_STEP = 3 20 | SAVE_PER_EPOCH_NUM = 2 21 | IS_PARALLEL = True 22 | DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 23 | NUM_GPUS = 2 24 | DEVICES = [torch.device(f"cuda:{i}") for i in range(min(torch.cuda.device_count(), NUM_GPUS))] 25 | 26 | IMAGE_SIZE = 384 27 | PATCH_SIZE = 16 28 | 29 | SPLIT = 0.8 30 | EPOCH = 40 31 | LEARNING_RATE = 1e-3 32 | BATCH_SIZE = 64 33 | 34 | if IS_PARALLEL: 35 | DEVICE = DEVICES[0] 36 | BATCH_SIZE *= len(DEVICES) 37 | LEARNING_RATE *= len(DEVICES) 38 | 39 | 40 | # utils functions and classes 41 | 42 | class ImageDataset(Dataset): 43 | def __init__(self, images, labels, transform=None): 44 | super().__init__() 45 | self.transform = transform 46 | self.images = images 47 | self.labels = labels 48 | 49 | def __len__(self): 50 | return len(self.images) 51 | 52 | def __getitem__(self, idx): 53 | with Image.open(self.images[idx]).convert("RGB") as img: 54 | 55 | if self.transform: 56 | return self.transform(img), self.labels[idx] 57 | else: 58 | return img, self.labels[idx] 59 | 60 | 61 | def ReadImage(dataframe: pd.DataFrame, images_path: str) -> list: 62 | """image_id: str -> PIL.Image""" 63 | assert "image_id" in dataframe.columns and "dx" in dataframe.columns 64 | 65 | lens, _ = dataframe.shape 66 | postfix = ".jpg" 67 | dataset = [] 68 | t = tqdm(range(lens)) 69 | for i in t: 70 | t.set_description("Reading Image") 71 | dataset.append( 72 | { 73 | "image": images_path + dataframe['image_id'][i] + postfix, 74 | "label": torch.tensor(LABEL2IDX[dataframe['dx'][i]]) 75 | } 76 | ) 77 | return dataset 78 | 79 | 80 | # load dataset 81 | 82 | # IDX2LABEL = dict(enumerate(set(df['dx']))) 83 | IDX2LABEL = {0: 'vasc', 84 | 1: 'bcc', 85 | 2: 'mel', 86 | 3: 'nv', 87 | 4: 'df', 88 | 5: 'akiec', 89 | 6: 'bkl'} 90 | 91 | # LABEL2IDX = {v: k for k, v in IDX2LABEL.items()} 92 | LABEL2IDX = {'vasc': 0, 93 | 'bcc': 1, 94 | 'mel': 2, 95 | 'nv': 3, 96 | 'df': 4, 97 | 'akiec': 5, 98 | 'bkl': 6} 99 | 100 | df = pd.read_csv(METADATA_CSV) 101 | 102 | ds = ReadImage(df, IMAGES_DIR) 103 | 104 | # split 105 | 106 | random.seed(1919810) 107 | random.shuffle(ds) 108 | 109 | lens = len(ds) 110 | train_ds = ds[:int(lens * SPLIT)] 111 | dev_ds = ds[int(lens * SPLIT):] 112 | 113 | print(f"Train : {len(train_ds)}") 114 | print(f"DEV : {len(dev_ds)}") 115 | 116 | # train_ds = ImageDataset(**{ 117 | # 'images': samp["image"], 118 | # 'labels': samp["label"], } for samp in train_ds) 119 | 120 | # train_ds = ImageDataset(*[(samp["image"], samp["label"]) for samp in train_ds]) 121 | 122 | transform = Compose([ 123 | # Resize(IMAGE_SIZE), 124 | RandomResizedCrop((IMAGE_SIZE, IMAGE_SIZE)), 125 | ToTensor(), 126 | Normalize(0, 1), 127 | ]) 128 | 129 | train_ds = ImageDataset(images=[samp["image"] for samp in train_ds], 130 | labels=[samp["label"] for samp in train_ds], 131 | transform=transform) 132 | 133 | dev_ds = ImageDataset(images=[samp["image"] for samp in dev_ds], 134 | labels=[samp["label"] for samp in dev_ds], 135 | transform=transform) 136 | 137 | train_dl = DataLoader(dataset=train_ds, 138 | batch_size=BATCH_SIZE, 139 | shuffle=True, 140 | num_workers=0) 141 | dev_dl = DataLoader(dataset=dev_ds, 142 | batch_size=BATCH_SIZE, 143 | shuffle=True, 144 | num_workers=0) 145 | 146 | # define model 147 | 148 | model = ViT( 149 | image_size=IMAGE_SIZE, 150 | patch_size=PATCH_SIZE, 151 | num_classes=len(IDX2LABEL), 152 | dim=2048, 153 | depth=16, 154 | heads=20, 155 | mlp_dim=4096, 156 | dropout=0.1, 157 | emb_dropout=0 158 | ).to(DEVICE) 159 | 160 | if IS_PARALLEL: 161 | model = nn.DataParallel(model, device_ids=DEVICES) 162 | 163 | if __name__ == '__main__': 164 | # optimizer = SGD(model.parameters(), lr=LEARNING_RATE) 165 | 166 | optimizer = AdamW(model.parameters(), lr=LEARNING_RATE) 167 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCH) 168 | 169 | epoch = tqdm(range(EPOCH)) 170 | for e in epoch: 171 | epoch.set_description(f"Epoch {e}") 172 | model.train() 173 | for batch_idx, (images, labels) in enumerate(train_dl): 174 | optimizer.zero_grad() 175 | images = images.to(DEVICE) 176 | labels = labels.to(DEVICE) 177 | 178 | out = model(images).to(DEVICE) 179 | # out = torch.argmax(out, dim=-1) 180 | loss = F.cross_entropy(out, labels) 181 | 182 | loss.backward() 183 | optimizer.step() 184 | scheduler.step() 185 | 186 | if batch_idx % LOG_STEP == 0: 187 | print('Train Epoch: {} [{}/{} ({:.3f}%)] Lr: {:e} Loss: {:.6f}'.format( 188 | e, batch_idx, len(train_dl), 189 | 100. * batch_idx / len(train_dl), optimizer.param_groups[0]['lr'], loss.item())) 190 | dev_loss = [] 191 | model.eval() 192 | with torch.no_grad(): 193 | acc = [] 194 | 195 | for batch_idx, (images, labels) in enumerate(dev_dl): 196 | images = images.to(DEVICE) 197 | labels = labels.to(DEVICE) 198 | 199 | out = model(images).to(DEVICE) 200 | 201 | pred = torch.argmax(out, dim=-1) 202 | 203 | acc += [1 if pred[i] == labels[i] else 0 for i in range(len(pred))] 204 | 205 | loss = F.cross_entropy(out, labels) 206 | dev_loss.append(loss) 207 | acc = sum(acc) / len(acc) 208 | 209 | dev_loss = torch.mean(torch.tensor(dev_loss)) 210 | 211 | print("\nEpoch {} Validation loss: {:.6f} Accuracy: {:.6f}%.\n".format(e, dev_loss, acc * 100)) 212 | 213 | if e % SAVE_PER_EPOCH_NUM == 0 or e == EPOCH - 1: 214 | save_file = "VIT-large-{}px-{}patch-{}epoch-{:.4f}loss.bin".format(IMAGE_SIZE, PATCH_SIZE, e, dev_loss) 215 | torch.save(model.state_dict(), CHECKPOINTS + save_file) 216 | print(f"Save to {CHECKPOINTS + save_file}\n") 217 | -------------------------------------------------------------------------------- /train-hf.py: -------------------------------------------------------------------------------- 1 | # import library 2 | import argparse 3 | import os 4 | from random import shuffle, seed 5 | import evaluate 6 | import numpy as np 7 | import pandas as pd 8 | from PIL import Image 9 | from datasets import Dataset, DatasetDict 10 | from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor 11 | from tqdm import tqdm 12 | from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer, \ 13 | DefaultDataCollator 14 | 15 | # argparse 16 | parser = argparse.ArgumentParser(description='Train Model') 17 | 18 | parser.add_argument('--metadata_path', type=str, default='./archive/HAM10000_metadata.csv', 19 | help='path to metadata file') 20 | parser.add_argument('--images_dir', type=str, default='./archive/HAM10000_images/', 21 | help='path to images directory') 22 | parser.add_argument('--model_dir', type=str, default='../model/vit-large-patch16-224-in21k', 23 | help='path to pretrained model directory') 24 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', 25 | help='path to save model checkpoints') 26 | parser.add_argument('--learning_rate', type=float, default=1e-5, help='learning rate') 27 | parser.add_argument('--batch_size', type=int, default=64, help='batch size') 28 | parser.add_argument('--epochs', type=int, default=5, help='number of epochs to train for') 29 | parser.add_argument('--warmup_ratio', type=float, default=0.1, help='ratio of warmup steps to total training steps') 30 | parser.add_argument('--split', type=float, default=0.8, help='train-validation split ratio') 31 | parser.add_argument('--gpu', type=str, default='0', help='CUDA visible devices') 32 | parser.add_argument('--logging_steps', type=int, default='50', help='Print log per step') 33 | args = parser.parse_args() 34 | 35 | # hyperparameter 36 | 37 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 38 | 39 | METADATA_PATH = args.metadata_path 40 | 41 | IMAGES_DIR = args.images_dir 42 | 43 | MODEL_DIR = args.model_dir 44 | 45 | CHECKPOINTS_DIR = args.checkpoints_dir 46 | 47 | LEARNING_RATE = args.learning_rate 48 | 49 | BATCH_SIZE = args.batch_size 50 | 51 | EPOCHS = args.epochs 52 | 53 | WARMUP_RATIO = args.warmup_ratio 54 | 55 | SPLIT = args.split 56 | 57 | LOGGING_STEPS = args.logging_steps 58 | 59 | RAW_PATH = "./archive/raw" 60 | 61 | 62 | # utils functions 63 | 64 | def ReadImage(dataframe: pd.DataFrame, images_path: str) -> list: 65 | """image_id: str -> PIL.Image""" 66 | 67 | assert "image_id" in dataframe.columns and "dx" in dataframe.columns 68 | 69 | lens, _ = dataframe.shape 70 | postfix = ".jpg" 71 | dataset = [] 72 | t = tqdm(range(lens)) 73 | for i in t: 74 | t.set_description("Reading Image") 75 | img = Image.open(images_path + dataframe['image_id'][i] + postfix) 76 | dataset.append( 77 | { 78 | "image": img, 79 | "label": dataframe['dx'][i] 80 | } 81 | ) 82 | img.close() 83 | return dataset 84 | 85 | 86 | def ReadRaw() -> list: 87 | imgs = os.listdir(RAW_PATH) 88 | dataset = [] 89 | t = tqdm(imgs) 90 | for i in t: 91 | t.set_description("Reading Raw Image") 92 | img = Image.open(RAW_PATH + "/" + i) 93 | dataset.append( 94 | { 95 | "image": img, 96 | "label": 'not a cancer image' 97 | } 98 | ) 99 | img.close() 100 | return dataset 101 | 102 | 103 | def transforms(examples): 104 | trans = _transforms() 105 | examples["pixel_values"] = [trans(img.convert("RGB")) for img in examples["image"]] 106 | examples["label"] = [LABEL2IDX[label] for label in examples["label"]] 107 | del examples["image"] 108 | return examples 109 | 110 | 111 | def _transforms(): 112 | normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std) 113 | size = ( 114 | image_processor.size["shortest_edge"] 115 | if "shortest_edge" in image_processor.size 116 | else (image_processor.size["height"], image_processor.size["width"]) 117 | ) 118 | return Compose([RandomResizedCrop(size), ToTensor(), normalize]) 119 | 120 | 121 | def compute_metrics(eval_pred): 122 | predictions, labels = eval_pred 123 | predictions = np.argmax(predictions, axis=1) 124 | return acc.compute(predictions=predictions, references=labels) 125 | 126 | 127 | # import dataset 128 | 129 | # Actinic keratoses and intraepithelial carcinoma / Bowen's disease (akiec), 130 | # basal cell carcinoma (bcc), 131 | # benign keratosis-like lesions (solar lentigines / seborrheic keratoses and lichen-planus like keratoses, bkl), 132 | # dermatofibroma (df), 133 | # melanoma (mel), 134 | # melanocytic nevi (nv), 135 | # vascular lesions (angiomas, angiokeratomas, pyogenic granulomas and hemorrhage, vasc). 136 | 137 | # IDX2LABEL = dict(enumerate(set(df['dx']))) 138 | IDX2LABEL = {0: 'vasc', 139 | 1: 'bcc', 140 | 2: 'mel', 141 | 3: 'nv', 142 | 4: 'df', 143 | 5: 'akiec', 144 | 6: 'bkl', 145 | 7: 'not a cancer image'} 146 | 147 | # LABEL2IDX = {v: k for k, v in IDX2LABEL.items()} 148 | LABEL2IDX = {'vasc': 0, 149 | 'bcc': 1, 150 | 'mel': 2, 151 | 'nv': 3, 152 | 'df': 4, 153 | 'akiec': 5, 154 | 'bkl': 6, 155 | 'not a cancer image': 7} 156 | 157 | if __name__ == '__main__': 158 | seed(114514) 159 | 160 | df = pd.read_csv(METADATA_PATH, usecols=['image_id', 'dx']) 161 | 162 | ds = ReadImage(dataframe=df, images_path=IMAGES_DIR) 163 | shuffle(ds) 164 | 165 | raw_ds = ReadRaw() 166 | shuffle(raw_ds) 167 | 168 | train_ds = ds[:int(SPLIT * len(ds))] + raw_ds[:int(SPLIT * len(raw_ds))] 169 | dev_ds = ds[int(SPLIT * len(ds)):] + raw_ds[int(SPLIT * len(raw_ds)):] 170 | 171 | ds = { 172 | "train": Dataset.from_list(train_ds), 173 | "dev": Dataset.from_list(dev_ds) 174 | } 175 | 176 | ds = DatasetDict(ds) 177 | 178 | # preprocess dataset 179 | 180 | image_processor = AutoImageProcessor.from_pretrained(MODEL_DIR) 181 | 182 | ds = ds.with_transform(transforms) 183 | data_collator = DefaultDataCollator() 184 | 185 | # define metric 186 | 187 | acc = evaluate.load("accuracy") 188 | 189 | # import model 190 | 191 | model = AutoModelForImageClassification.from_pretrained(MODEL_DIR, 192 | num_labels=len(IDX2LABEL), 193 | ignore_mismatched_sizes=True) 194 | 195 | # train model 196 | 197 | training_args = TrainingArguments( 198 | output_dir=CHECKPOINTS_DIR, 199 | remove_unused_columns=False, 200 | evaluation_strategy="epoch", 201 | save_strategy="epoch", 202 | learning_rate=LEARNING_RATE, 203 | per_device_train_batch_size=BATCH_SIZE, 204 | gradient_accumulation_steps=4, 205 | per_device_eval_batch_size=BATCH_SIZE, 206 | num_train_epochs=EPOCHS, 207 | warmup_ratio=WARMUP_RATIO, 208 | logging_steps=LOGGING_STEPS, 209 | load_best_model_at_end=True, 210 | metric_for_best_model="accuracy" 211 | ) 212 | 213 | trainer = Trainer( 214 | model=model, 215 | args=training_args, 216 | data_collator=data_collator, 217 | train_dataset=ds["train"], 218 | eval_dataset=ds["dev"], 219 | tokenizer=image_processor, 220 | compute_metrics=compute_metrics, 221 | ) 222 | trainer.train() 223 | --------------------------------------------------------------------------------