├── static_params.py ├── requirements.txt ├── inference.py ├── README.md ├── .gitignore ├── dataset.py ├── settings.py ├── lightning_model.py └── train.py /static_params.py: -------------------------------------------------------------------------------- 1 | MODELS = dict( 2 | SQUEEZENET=dict( 3 | MEAN=[0.485, 0.456, 0.406], STD=[0.229, 0.224, 0.225] 4 | ), 5 | ) 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.11.0 2 | torchvision>=0.12.0 3 | albumentations==1.1.0 4 | deep-utils>=0.9.6 5 | pytorch-lightning==1.6.2 6 | opencv-python=>4.6.0.66 7 | scikit-learn>=1.1.1 -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | from time import time 3 | from argparse import ArgumentParser 4 | from deep_utils import TorchVisionInference 5 | 6 | if __name__ == '__main__': 7 | parser = ArgumentParser() 8 | parser.add_argument("--model_path", required=True, help="path to saved model") 9 | parser.add_argument("--device", default="cpu", help="cuda or cpu") 10 | parser.add_argument("--img_path", required=True, help="path to image") 11 | args = parser.parse_args() 12 | model = TorchVisionInference(args.model_path, device=args.device) 13 | if os.path.isdir(args.img_path): 14 | model.infer_directory(args.img_path) 15 | else: 16 | tic = time() 17 | prediction = model.infer(args.img_path) 18 | toc = time() 19 | print(f"predicted class is for {args.img_path} is {prediction}\ninference time: {toc - tic}") 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-lightning-classification-template 2 | 3 | This is a template project for training image classification tasks. 4 | To start the training modify the `settings.py` based on your requirements. 5 | 6 | # Dataset format: 7 | 8 | The format of dataset directory should be as follows: 9 | 10 | ``` 11 | ├── dataset 12 | │ ├── class-1 13 | │ │ ├──image-name.jpg 14 | │ │ ├──image-name.jpg 15 | │ │ ├──... 16 | │ ├── class-2 17 | │ │ ├──image-name.jpg 18 | │ │ ├──image-name.jpg 19 | │ │ ├──... 20 | ... 21 | ``` 22 | 23 | Note: 24 | 25 | 1. image-name is arbitrary 26 | 2. class-1 and class-2 should be renamed to the real name of the input classes. 27 | 28 | # Training: 29 | 30 | ```commandline 31 | cd pytorch-lightning-image-classification-template 32 | python train.py --model_name squeezenet --dataset_dir --output_dir 33 | ``` 34 | 35 | run `python train.py -h` for more configuration. 36 | 37 | # Inference 38 | 39 | ```commandline 40 | python inference.py --model_path --img_path 41 | ``` 42 | 43 | # References 44 | 45 | https://github.com/pooya-mohammadi/deep_utils -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .idea 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import cv2 4 | from sklearn.model_selection import train_test_split 5 | from deep_utils import crawl_directory_dataset 6 | 7 | 8 | class ImageClassificationDataset(Dataset): 9 | def __init__(self, image_list, label_list, transform=None, class_to_id=None): 10 | self.images = image_list 11 | self.labels = label_list 12 | self.transform = transform 13 | self.class_to_id = class_to_id 14 | 15 | def __len__(self): 16 | return len(self.images) 17 | 18 | def __getitem__(self, idx): 19 | if torch.is_tensor(idx): 20 | idx = idx.tolist() 21 | image_path = self.images[idx] 22 | img = cv2.imread(image_path)[..., ::-1] # bgr2rgb 23 | if self.transform: 24 | img = self.transform(image=img)["image"] 25 | label_name = self.labels[idx] 26 | label = torch.tensor(self.class_to_id[label_name]).type(torch.long) 27 | sample = (img, label) 28 | return sample 29 | 30 | @staticmethod 31 | def get_loaders(config): 32 | x, y = crawl_directory_dataset(config.dataset_dir) 33 | x_train, x_val, y_train, y_val = train_test_split(x, y, 34 | test_size=config.validation_size, 35 | stratify=y) 36 | class_to_id = {v: k for k, v in enumerate(set(y_train))} 37 | train_dataset = ImageClassificationDataset(x_train, y_train, transform=config.train_transform, 38 | class_to_id=class_to_id) 39 | train_loader = torch.utils.data.DataLoader(train_dataset, 40 | batch_size=config.batch_size, 41 | shuffle=True, 42 | num_workers=config.n_workers, 43 | # pin_memory=config.pin_memory 44 | ) 45 | 46 | val_dataset = ImageClassificationDataset(x_val, y_val, transform=config.val_transform, class_to_id=class_to_id) 47 | val_loader = torch.utils.data.DataLoader(val_dataset, 48 | batch_size=config.batch_size, 49 | shuffle=False, 50 | num_workers=config.n_workers, 51 | # pin_memory=config.pin_memory 52 | ) 53 | 54 | return train_loader, val_loader 55 | -------------------------------------------------------------------------------- /settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | import torch 4 | import albumentations as A 5 | from albumentations.pytorch import ToTensorV2 6 | 7 | 8 | @dataclass(init=True, repr=True) 9 | class DirConfig: 10 | dataset_dir = "./dataset" 11 | output_dir = "./output" 12 | file_name = "best" 13 | 14 | 15 | @dataclass(init=True, repr=True) 16 | class ModelConfig: 17 | model_name = "squeezenet" 18 | input_size = 224 19 | mean = [0.485, 0.456, 0.406] 20 | std = [0.229, 0.224, 0.225] 21 | last_layer_nodes = 512 22 | 23 | n_classes = len( 24 | [d for d in os.listdir(DirConfig.dataset_dir) if 25 | os.path.isdir(os.path.join(DirConfig.dataset_dir, d))]) if os.path.isdir(DirConfig.dataset_dir) else 0 26 | 27 | 28 | @dataclass(init=True, repr=True) 29 | class AugConfig: 30 | train_transform = A.Compose( 31 | [A.Resize(height=ModelConfig.input_size, width=ModelConfig.input_size), 32 | A.Rotate(limit=20, p=0.2), 33 | A.HorizontalFlip(p=0.5), 34 | A.Normalize(ModelConfig.mean, ModelConfig.std, max_pixel_value=255.0), 35 | ToTensorV2() 36 | ]) 37 | val_transform = A.Compose( 38 | [A.Resize(ModelConfig.input_size, ModelConfig.input_size), 39 | A.Normalize(ModelConfig.mean, ModelConfig.std, max_pixel_value=255.0), 40 | ToTensorV2() 41 | ]) 42 | 43 | 44 | @dataclass(init=True, repr=True) 45 | class DeviceConfig: 46 | device = "cuda" if torch.cuda.is_available() else "cpu" 47 | pin_memory = True if torch.cuda.is_available() else False 48 | n_workers = 8 49 | 50 | @classmethod 51 | def update_device(cls): 52 | if cls.device is None: 53 | cls.device = "cuda" if torch.cuda.is_available() else "cpu" 54 | cls.pin_memory = True if cls.device != 'cpu' else False 55 | 56 | 57 | @dataclass(init=True, repr=True) 58 | class SaveConfig: 59 | save_model_w_weight = False 60 | 61 | 62 | @dataclass(init=True) 63 | class Config(DirConfig, ModelConfig, AugConfig, DeviceConfig, SaveConfig): 64 | train_epochs = 5 65 | train_lr = 1e-3 66 | 67 | finetune_epochs = 5 68 | finetune_lr = 1e-4 69 | finetune_layers = 50 70 | 71 | lr_reduce_factor = 0.1 72 | lr_patience = 5 73 | 74 | validation_size = 0.2 75 | batch_size = 64 76 | 77 | def update_config_param(self, args): 78 | variables = vars(args) 79 | for k, v in variables.items(): 80 | if hasattr(self, k): 81 | setattr(self, k, v) 82 | else: 83 | raise ValueError(f"value {k} is not defined in Config...") 84 | self.update() 85 | 86 | def update_model(self): 87 | self.n_classes = len( 88 | [d for d in os.listdir(self.dataset_dir) if os.path.isdir(os.path.join(self.dataset_dir, d))]) 89 | 90 | def update(self): 91 | self.update_device() 92 | self.update_model() 93 | 94 | def __repr__(self): 95 | variables = vars(self) 96 | return f"{self.__class__.__name__} -> " + ", ".join(f"{k}: {v}" for k, v in variables.items()) 97 | 98 | def vars(self) -> dict: 99 | out = dict() 100 | for key in dir(self): 101 | val = getattr(self, key) 102 | if (key.startswith("__") and key.endswith("__")) or type(val).__name__ == "method": 103 | continue 104 | else: 105 | out[key] = val 106 | return out 107 | -------------------------------------------------------------------------------- /lightning_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | from deep_utils import log_print 4 | from sklearn.metrics import f1_score 5 | from torch.optim.lr_scheduler import ReduceLROnPlateau 6 | from deep_utils import TorchVisionModel 7 | from torch import nn 8 | 9 | 10 | class LitModel(pl.LightningModule): 11 | def __init__(self, model_name, n_classes, last_layer_nodes, lr, lr_reduce_factor, lr_patience, logger=None, 12 | verbose=1): 13 | super(LitModel, self).__init__() 14 | self.save_hyperparameters() 15 | 16 | self.lr = self.hparams.lr 17 | self.model = TorchVisionModel(model_name=self.hparams.model_name, 18 | num_classes=self.hparams.n_classes, 19 | last_layer_nodes=self.hparams.last_layer_nodes) 20 | self.criterion = nn.CrossEntropyLoss() 21 | self.outer_logger = logger 22 | self.verbose = verbose 23 | self.epoch = 0 24 | 25 | def forward(self, x): 26 | logit = self.model(x) 27 | return logit 28 | 29 | def training_step(self, batch, batch_idx): 30 | return self.get_loss_acc(batch) 31 | 32 | def test_step(self, batch, batch_idx): 33 | return self.get_loss_acc(batch) 34 | 35 | def test_epoch_end(self, outputs) -> None: 36 | self.calculate_metrics(outputs, type_="test") 37 | 38 | def validation_epoch_end(self, outputs) -> None: 39 | self.calculate_metrics(outputs, type_="val") 40 | 41 | def training_epoch_end(self, outputs) -> None: 42 | self.calculate_metrics(outputs, type_='train') 43 | self.epoch += 1 44 | 45 | def calculate_metrics(self, outputs, type_="train"): 46 | labels, preds = [], [] 47 | r_acc, r_loss, size = 0, 0, 0 48 | for row in outputs: 49 | r_acc += row["acc"] 50 | r_loss += row["loss"] 51 | size += row["bs"] 52 | preds.extend(row['preds']) 53 | labels.extend(row["labels"]) 54 | f1_value = f1_score(labels, preds, average="weighted") 55 | loss = r_loss / size 56 | acc = r_acc / size 57 | log_print(self.outer_logger, 58 | f"Epoch: {self.epoch} - {type_}-acc: {acc} - {type_}-loss: {loss} - {type_}-f1-score: {f1_value}") 59 | self.log(f"{type_}_f1_score", f1_value) 60 | self.log(f"{type_}_loss", loss.item()) 61 | self.log(f"{type_}_acc", acc, ) 62 | return acc, f1_value, loss 63 | 64 | def validation_step(self, batch, batch_idx): 65 | return self.get_loss_acc(batch) 66 | 67 | def get_loss_acc(self, batch): 68 | images, labels = batch 69 | batch_size = images.size(0) 70 | logits = self.model(images) 71 | loss = self.criterion(logits, labels) * batch_size 72 | _, preds = torch.max(logits, 1) 73 | corrects = torch.sum(preds == labels.data) 74 | return {"acc": corrects.item(), 75 | "loss": loss, 76 | "bs": batch_size, 77 | "preds": preds.cpu().numpy().tolist(), 78 | "labels": labels.cpu().numpy().tolist() 79 | } 80 | 81 | def configure_optimizers(self): 82 | optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) 83 | scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=self.hparams.lr_reduce_factor, 84 | patience=self.hparams.lr_patience, verbose=True) 85 | return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"} 86 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | import pytorch_lightning as pl 5 | from deep_utils import mkdir_incremental, BlocksTorch, get_logger, log_print 6 | from settings import Config 7 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 8 | from lightning_model import LitModel 9 | from dataset import ImageClassificationDataset 10 | from argparse import ArgumentParser 11 | 12 | parser = ArgumentParser() 13 | parser.add_argument("--model_name", type=str, default="squeezenet", help="model name to be trained, default is squeezenet") 14 | parser.add_argument("--dataset_dir", type=Path, required=True, help="path to the dataset to train the model on it") 15 | parser.add_argument("--output_dir", type=Path, default="./output", 16 | help="path to the output directory, default: ./output") 17 | parser.add_argument("--train_epochs", type=int, default=5, help="number of training epochs") 18 | parser.add_argument("--finetune_epochs", type=int, default=5, help="number of fine-tuning epochs") 19 | parser.add_argument("--device", default="cuda", help="what should be the device for training, default is cuda") 20 | parser.add_argument("--n_workers", type=int, default=8, help="Number of workers for data-loaders") 21 | parser.add_argument("--finetune_layers", type=int, default=50, 22 | help="Number of layers that should be finetuned starting from the end of the layers. Default is 50") 23 | 24 | 25 | def main(): 26 | config = Config() 27 | args = parser.parse_args() 28 | config.update_config_param(args) 29 | output_dir = mkdir_incremental(config.output_dir) 30 | train_dir = output_dir / "train" 31 | finetune_dir = output_dir / "finetune" 32 | logger = get_logger("pytorch-lightning-image-classification", log_path=output_dir / "log.log") 33 | log_print(logger, f"Config files: {config}") 34 | model_checkpoint = ModelCheckpoint(dirpath=output_dir, 35 | filename=config.file_name, 36 | monitor="val_loss", 37 | verbose=True) 38 | learning_rate_monitor = LearningRateMonitor(logging_interval="epoch") 39 | trainer = pl.Trainer(gpus=1 if config.device == "cuda" else 0, 40 | max_epochs=config.train_epochs, 41 | min_epochs=config.train_epochs // 10, 42 | callbacks=[model_checkpoint, learning_rate_monitor], 43 | default_root_dir=train_dir) 44 | lit_model = LitModel(config.model_name, config.n_classes, config.last_layer_nodes, config.train_lr, 45 | config.lr_reduce_factor, config.lr_patience) 46 | lit_model.model.model_ft.classifier[1].apply(BlocksTorch.weights_init) 47 | train_loader, val_loader = ImageClassificationDataset.get_loaders(config) 48 | log_print(logger, "Training the model...") 49 | trainer.fit(model=lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader) 50 | log_print(logger, "Fine-tuning the model...") 51 | lit_model.lr = config.finetune_lr 52 | BlocksTorch.set_parameter_grad(lit_model, config.finetune_layers) 53 | 54 | trainer = pl.Trainer(gpus=1 if config.device == "cuda" else 0, 55 | max_epochs=config.finetune_epochs, 56 | min_epochs=config.finetune_epochs // 10, 57 | callbacks=[model_checkpoint, learning_rate_monitor], 58 | default_root_dir=finetune_dir) 59 | trainer.fit(model=lit_model, 60 | train_dataloaders=train_loader, 61 | val_dataloaders=val_loader) 62 | log_print(logger, "Testing val_loader...") 63 | trainer.test(lit_model, ckpt_path="best", dataloaders=val_loader) 64 | 65 | log_print(logger, "Testing train_loader...") 66 | trainer.test(lit_model, ckpt_path="best", dataloaders=train_loader) 67 | 68 | # Adding artifacts to weights 69 | weight_path = output_dir / f"{config.file_name}.ckpt" 70 | best_weight = torch.load(weight_path) 71 | best_weight['id_to_class'] = {v: k for k, v in train_loader.dataset.class_to_id.items()} 72 | for k, v in config.vars().items(): 73 | if k not in best_weight: 74 | best_weight[k] = v 75 | else: 76 | log_print(logger, f"[Warning] Did not save {k} = {v} because there is a variable with the same name!") 77 | if config.save_model_w_weight: 78 | best_weight['model'] = lit_model.model 79 | torch.save(best_weight, weight_path) 80 | 81 | 82 | if __name__ == '__main__': 83 | main() 84 | --------------------------------------------------------------------------------