├── tests ├── __init__.py ├── tmp │ └── .gitkeep ├── sample │ ├── sample.jpg │ ├── pytest_val.csv │ ├── pytest_test.csv │ ├── config.yaml │ └── pytest_train.csv ├── test_mean_std.py ├── test_models.py ├── test_seed.py ├── test_class_label_map.py ├── test_device.py ├── test_loss_fn.py ├── test_class_weight.py ├── test_meter.py ├── test_metric.py ├── test_checkpoint.py ├── test_dataset.py ├── test_config.py ├── test_logger.py └── test_helper.py ├── .vscode └── settings.json ├── src ├── libs │ ├── __init__.py │ ├── transformer.py │ ├── class_id_map.py │ ├── seed.py │ ├── mean_std.py │ ├── device.py │ ├── dataset_csv.py │ ├── models │ │ └── __init__.py │ ├── metric.py │ ├── loss_fn │ │ ├── __init__.py │ │ └── class_weight.py │ ├── checkpoint.py │ ├── meter.py │ ├── logger.py │ ├── dataset.py │ ├── config.py │ └── helper.py ├── util_scripts │ ├── visualize_model.py │ ├── make_csv_files.py │ └── make_configs.py ├── evaluate.py └── train.py ├── dockerfiles ├── Dockerfile.pre-commit ├── Dockerfile └── Dockerfile.poetry ├── .gitattributes ├── .dockerignore ├── scripts ├── build_docker.sh ├── run_docker.sh ├── experiment.sh └── run_docker_and_do_exps.sh ├── .gitignore ├── results ├── model=resnet18-learning_rate=0.001-dataset_name=dammy │ └── config.yaml ├── model=resnet18-learning_rate=0.01-dataset_name=dammy │ └── config.yaml ├── model=resnet18-learning_rate=0.01-dataset_name=flower │ └── config.yaml ├── model=resnet34-learning_rate=0.001-dataset_name=dammy │ └── config.yaml ├── model=resnet34-learning_rate=0.01-dataset_name=dammy │ └── config.yaml ├── model=resnet34-learning_rate=0.01-dataset_name=flower │ └── config.yaml ├── model=resnet18-learning_rate=0.001-dataset_name=flower │ └── config.yaml └── model=resnet34-learning_rate=0.001-dataset_name=flower │ └── config.yaml ├── .flake8 ├── .github └── workflows │ └── mypy_pytest.yaml ├── LICENSE ├── notebooks └── template.ipynb ├── pyproject.toml ├── .pre-commit-config.yaml ├── docs └── FOR_AOLAB_MEMBERS.md ├── README.md └── requirements.txt /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/tmp/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/libs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dockerfiles/Dockerfile.pre-commit: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/libs/transformer.py: -------------------------------------------------------------------------------- 1 | # 今回は torchvision で提供されているものを用います. 
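# (The Japanese comment above notes that the transforms provided by torchvision are used as-is, which is why this module is intentionally left empty.)
# A minimal sketch of the kind of pipeline that would typically be composed elsewhere (e.g. in train.py, which is not shown in this dump).
# The resize size and the use of get_mean()/get_std() are assumptions based on tests/sample/config.yaml and libs/mean_std.py:
#
#     from torchvision import transforms
#     from libs.mean_std import get_mean, get_std
#
#     train_transform = transforms.Compose([
#         transforms.Resize((224, 224)),  # height/width taken from the config
#         transforms.ToTensor(),          # scales pixel values to [0, 1]
#         transforms.Normalize(mean=get_mean(), std=get_std()),  # ImageNet statistics
#     ])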
2 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Jupyter notebook 2 | *.ipynb linguist-documentation 3 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .github 2 | .vscode 3 | docs 4 | imgs 5 | notebooks 6 | results 7 | src 8 | tests 9 | -------------------------------------------------------------------------------- /tests/sample/sample.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiskw713/pytorch_template/HEAD/tests/sample/sample.jpg -------------------------------------------------------------------------------- /tests/sample/pytest_val.csv: -------------------------------------------------------------------------------- 1 | image_path,class_id,label 2 | ./tests/sample/sample.jpg,0,daisy 3 | ./tests/sample/sample.jpg,0,daisy 4 | -------------------------------------------------------------------------------- /tests/sample/pytest_test.csv: -------------------------------------------------------------------------------- 1 | image_path,class_id,label 2 | ./tests/sample/sample.jpg,0,daisy 3 | ./tests/sample/sample.jpg,0,daisy 4 | -------------------------------------------------------------------------------- /scripts/build_docker.sh: -------------------------------------------------------------------------------- 1 | PROJECT_ROOT=$(dirname $(cd $(dirname $0) && pwd)) 2 | 3 | docker image build -t pytorch_template_image \ 4 | -f $PROJECT_ROOT/dockerfiles/Dockerfile \ 5 | $PROJECT_ROOT 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | .python-version 3 | *.png 4 | *.jpg 5 | __pycache__ 6 | .DS_Store 7 | *.prm 8 | *.pth 9 | events.out.tfevents.* 10 | wandb 11 | wandb_result 12 | .mypy_cache 13 | .coverage* 14 | .pytest_cache 15 | 16 | !tests/sample/*.jpg 17 | !src/imgs/*.png -------------------------------------------------------------------------------- /tests/sample/config.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 32 2 | dataset_name: flower 3 | height: 224 4 | learning_rate: 0.0003 5 | max_epoch: 50 6 | model: resnet18 7 | num_workers: 2 8 | pretrained: true 9 | use_class_weight: true 10 | width: 224 11 | topk: 12 | - 1 13 | - 3 14 | - 5 15 | -------------------------------------------------------------------------------- /results/model=resnet18-learning_rate=0.001-dataset_name=dammy/config.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 32 2 | dataset_name: dammy 3 | height: 224 4 | learning_rate: 0.001 5 | max_epoch: 50 6 | model: resnet18 7 | num_workers: 2 8 | pretrained: true 9 | topk: 10 | - 1 11 | - 3 12 | use_class_weight: true 13 | width: 224 14 | -------------------------------------------------------------------------------- /results/model=resnet18-learning_rate=0.01-dataset_name=dammy/config.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 32 2 | dataset_name: dammy 3 | height: 224 4 | learning_rate: 0.01 5 | max_epoch: 50 6 | model: resnet18 7 | num_workers: 2 8 | pretrained: true 9 | topk: 10 | - 1 11 | - 3 12 | 
use_class_weight: true 13 | width: 224 14 | -------------------------------------------------------------------------------- /results/model=resnet18-learning_rate=0.01-dataset_name=flower/config.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 32 2 | dataset_name: flower 3 | height: 224 4 | learning_rate: 0.01 5 | max_epoch: 50 6 | model: resnet18 7 | num_workers: 2 8 | pretrained: true 9 | topk: 10 | - 1 11 | - 3 12 | use_class_weight: true 13 | width: 224 14 | -------------------------------------------------------------------------------- /results/model=resnet34-learning_rate=0.001-dataset_name=dammy/config.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 32 2 | dataset_name: dammy 3 | height: 224 4 | learning_rate: 0.001 5 | max_epoch: 50 6 | model: resnet34 7 | num_workers: 2 8 | pretrained: true 9 | topk: 10 | - 1 11 | - 3 12 | use_class_weight: true 13 | width: 224 14 | -------------------------------------------------------------------------------- /results/model=resnet34-learning_rate=0.01-dataset_name=dammy/config.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 32 2 | dataset_name: dammy 3 | height: 224 4 | learning_rate: 0.01 5 | max_epoch: 50 6 | model: resnet34 7 | num_workers: 2 8 | pretrained: true 9 | topk: 10 | - 1 11 | - 3 12 | use_class_weight: true 13 | width: 224 14 | -------------------------------------------------------------------------------- /results/model=resnet34-learning_rate=0.01-dataset_name=flower/config.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 32 2 | dataset_name: flower 3 | height: 224 4 | learning_rate: 0.01 5 | max_epoch: 50 6 | model: resnet34 7 | num_workers: 2 8 | pretrained: true 9 | topk: 10 | - 1 11 | - 3 12 | use_class_weight: true 13 | width: 224 14 | -------------------------------------------------------------------------------- /results/model=resnet18-learning_rate=0.001-dataset_name=flower/config.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 32 2 | dataset_name: flower 3 | height: 224 4 | learning_rate: 0.001 5 | max_epoch: 50 6 | model: resnet18 7 | num_workers: 2 8 | pretrained: true 9 | topk: 10 | - 1 11 | - 3 12 | use_class_weight: true 13 | width: 224 14 | -------------------------------------------------------------------------------- /results/model=resnet34-learning_rate=0.001-dataset_name=flower/config.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 32 2 | dataset_name: flower 3 | height: 224 4 | learning_rate: 0.001 5 | max_epoch: 50 6 | model: resnet34 7 | num_workers: 2 8 | pretrained: true 9 | topk: 10 | - 1 11 | - 3 12 | use_class_weight: true 13 | width: 224 14 | -------------------------------------------------------------------------------- /tests/test_mean_std.py: -------------------------------------------------------------------------------- 1 | from src.libs.mean_std import get_mean, get_std 2 | 3 | 4 | def test_get_mean() -> None: 5 | mean = get_mean(norm_value=1.0) 6 | assert mean == [123.675, 116.28, 103.53] 7 | 8 | 9 | def test_get_std() -> None: 10 | std = get_std(norm_value=1.0) 11 | return std == [58.395, 57.12, 57.375] 12 | -------------------------------------------------------------------------------- /tests/sample/pytest_train.csv: 
-------------------------------------------------------------------------------- 1 | image_path,class_id,label 2 | ./tests/sample/sample.jpg,0,daisy 3 | ./tests/sample/sample.jpg,0,daisy 4 | ./tests/sample/sample.jpg,0,daisy 5 | ./tests/sample/sample.jpg,0,daisy 6 | ./tests/sample/sample.jpg,0,daisy 7 | ./tests/sample/sample.jpg,0,daisy 8 | ./tests/sample/sample.jpg,0,daisy 9 | ./tests/sample/sample.jpg,0,daisy 10 | -------------------------------------------------------------------------------- /src/libs/class_id_map.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | 4 | def get_cls2id_map() -> Dict[str, int]: 5 | cls2id_map = {"daisy": 0, "dandelion": 1, "rose": 2, "sunflower": 3, "tulip": 4} 6 | 7 | return cls2id_map 8 | 9 | 10 | def get_id2cls_map() -> Dict[int, str]: 11 | cls2id_map = get_cls2id_map() 12 | return {val: key for key, val in cls2id_map.items()} 13 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from src.libs.models import get_model 5 | 6 | 7 | def test_get_model() -> None: 8 | with pytest.raises(ValueError): 9 | get_model("modelname", 10, False) 10 | 11 | model = get_model("resnet18", 10) 12 | 13 | x = torch.rand((2, 3, 112, 112)) 14 | y = model(x) 15 | 16 | assert y.shape == (2, 10) 17 | -------------------------------------------------------------------------------- /src/libs/seed.py: -------------------------------------------------------------------------------- 1 | import random 2 | from logging import getLogger 3 | 4 | import torch 5 | 6 | logger = getLogger(__name__) 7 | 8 | 9 | def set_seed(seed: int = 0) -> None: 10 | # set seed 11 | random.seed(seed) 12 | torch.manual_seed(seed) 13 | torch.cuda.manual_seed_all(seed) 14 | torch.backends.cudnn.deterministic = True 15 | 16 | logger.info("Finished setting up seed.") 17 | -------------------------------------------------------------------------------- /scripts/run_docker.sh: -------------------------------------------------------------------------------- 1 | PROJECT_ROOT=$(dirname $(cd $(dirname $0) && pwd)) 2 | 3 | docker container run \ 4 | --gpus all --shm-size=8g \ 5 | -itd --rm \ 6 | -v /etc/group:/etc/group:ro \ 7 | -v /etc/password:/etc/password:ro \ 8 | -u $(id -u $USER):$(id -g $USER) \ 9 | -p 8888:8888 \ 10 | -v $PROJECT_ROOT:/project \ 11 | --name pytorch_template_cnt \ 12 | pytorch_template_image bash 13 | -------------------------------------------------------------------------------- /tests/test_seed.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | 5 | from src.libs.seed import set_seed 6 | 7 | 8 | def test_get_model() -> None: 9 | set_seed(seed=0) 10 | value1 = random.random() 11 | tensor1 = torch.randn(1, 2, 3) 12 | 13 | set_seed(seed=0) 14 | value2 = random.random() 15 | tensor2 = torch.randn(1, 2, 3) 16 | 17 | assert value1 == value2 18 | assert torch.all(tensor1 == tensor2) 19 | -------------------------------------------------------------------------------- /tests/test_class_label_map.py: -------------------------------------------------------------------------------- 1 | from src.libs.class_id_map import get_cls2id_map, get_id2cls_map 2 | 3 | 4 | def test_get_cls2id_map() -> None: 5 | cls2id_map = get_cls2id_map() 6 | 7 | assert len(cls2id_map) == 5 8 | assert 
cls2id_map["daisy"] == 0 9 | 10 | 11 | def test_get_id2cls_map() -> None: 12 | id2cls_map = get_id2cls_map() 13 | 14 | assert len(id2cls_map) == 5 15 | assert id2cls_map[0] == "daisy" 16 | -------------------------------------------------------------------------------- /scripts/experiment.sh: -------------------------------------------------------------------------------- 1 | python utils/make_csv_files.py 2 | python utils/make_configs.py --model resnet18 resnet34 resnet50 --learning_rate 0.003 0.0003 3 | 4 | files="./result/*" 5 | for filepath in $files; do 6 | if [ -d $filepath ] ; then 7 | python train.py "${filepath}/config.yaml" 8 | python evaluate.py "${filepath}/config.yaml" validation 9 | python evaluate.py "${filepath}/config.yaml" test 10 | fi 11 | done 12 | -------------------------------------------------------------------------------- /scripts/run_docker_and_do_exps.sh: -------------------------------------------------------------------------------- 1 | # TODO: add train and create config codes. 2 | PROJECT_ROOT=$(dirname $(cd $(dirname $0) && pwd)) 3 | 4 | docker container run \ 5 | --gpus all --shm-size=8g \ 6 | -d --rm --restart always \ 7 | -v /etc/group:/etc/group:ro \ 8 | -v /etc/password:/etc/password:ro \ 9 | -u $(id -u $USER):$(id -g $USER) \ 10 | -p 8888:8888 \ 11 | -v $PROJECT_ROOT:/project \ 12 | --name pytorch_template_cnt \ 13 | pytorch_template_image python src/train.py 14 | -------------------------------------------------------------------------------- /src/libs/mean_std.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | from typing import List 3 | 4 | logger = getLogger(__name__) 5 | 6 | 7 | def get_mean(norm_value: float = 255) -> List[float]: 8 | # mean of imagenet 9 | mean = [123.675 / norm_value, 116.28 / norm_value, 103.53 / norm_value] 10 | 11 | logger.info(f"mean value: {mean}") 12 | return mean 13 | 14 | 15 | def get_std(norm_value: float = 255) -> List[float]: 16 | # std fo imagenet 17 | std = [58.395 / norm_value, 57.12 / norm_value, 57.375 / norm_value] 18 | 19 | logger.info(f"std value: {std}") 20 | return std 21 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # E402 ... Module level import not at top of file 3 | # PT011 ... set the match parameter in pytest.raises({exception}) 4 | # E501 ...Line too long 5 | # E203 ... Whitespace before ':' 6 | # W503 ... Line break occurred before a binary operator 7 | # W504 ... Line break occurred after a binary operator 8 | ignore = E402,PT011,E501,E203,W503,W504 9 | # max-line-length setting is the same as black 10 | max-line-length = 88 11 | # commit cannot be done when cyclomatic complexity is more than 10. 
12 | max-complexity = 10 13 | max-expression-complexity = 7 14 | max-cognitive-complexity = 8 15 | 16 | exclude =.git,__pycache__,.mypy_cache 17 | -------------------------------------------------------------------------------- /src/libs/device.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | 3 | import torch 4 | 5 | logger = getLogger(__name__) 6 | 7 | 8 | def get_device(allow_only_gpu: bool = True) -> str: 9 | if torch.cuda.is_available(): 10 | device = "cuda" 11 | torch.backends.cudnn.benchmark = True 12 | else: 13 | if allow_only_gpu: 14 | message = ( 15 | "You can use only cpu while you don't" 16 | "allow the use of cpu alone during training." 17 | ) 18 | logger.error(message) 19 | raise ValueError(message) 20 | 21 | device = "cpu" 22 | logger.warning( 23 | "CPU will be used for training. It is better to use GPUs instead" 24 | "because training CNN is computationally expensive." 25 | ) 26 | 27 | return device 28 | -------------------------------------------------------------------------------- /tests/test_device.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pytest_mock import MockFixture 3 | 4 | from src.libs.device import get_device 5 | 6 | 7 | @pytest.mark.parametrize( 8 | ("cuda_available", "allow_only_gpu", "expected"), 9 | [ 10 | (False, False, "cpu"), 11 | (True, True, "cuda"), 12 | (True, False, "cuda"), 13 | ], 14 | ) 15 | def test_get_device1( 16 | mocker: MockFixture, cuda_available: bool, allow_only_gpu: bool, expected: str 17 | ) -> None: 18 | mocker.patch("torch.cuda.is_available").return_value = cuda_available 19 | 20 | assert get_device(allow_only_gpu=allow_only_gpu) == expected 21 | 22 | 23 | def test_get_device2(mocker: MockFixture): 24 | mocker.patch("torch.cuda.is_available").return_value = False 25 | 26 | with pytest.raises(ValueError): 27 | get_device(allow_only_gpu=True) 28 | -------------------------------------------------------------------------------- /tests/test_loss_fn.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from src.libs.loss_fn import get_criterion 5 | 6 | 7 | def test_get_criterion() -> None: 8 | with pytest.raises(ValueError): 9 | get_criterion(True, device="cpu") 10 | 11 | with pytest.raises(ValueError): 12 | get_criterion(True, dataset_name="pytest") 13 | 14 | with pytest.raises(ValueError): 15 | get_criterion(True, dataset_name="hoge", device="cpu") 16 | 17 | criterion = get_criterion(False) 18 | 19 | pred = torch.rand((2, 10)) 20 | pred = torch.softmax(pred, dim=1) 21 | gt = torch.tensor([0, 1]) 22 | 23 | loss = criterion(pred, gt) 24 | assert loss > 0 25 | assert criterion.weight is None 26 | 27 | criterion = get_criterion(True, dataset_name="pytest", device="cpu") 28 | assert criterion.weight is not None 29 | -------------------------------------------------------------------------------- /src/libs/dataset_csv.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from logging import getLogger 3 | 4 | logger = getLogger(__name__) 5 | 6 | __all__ = ["DATASET_CSVS"] 7 | 8 | 9 | @dataclasses.dataclass(frozen=True) 10 | class DatasetCSV: 11 | train: str 12 | val: str 13 | test: str 14 | 15 | 16 | DATASET_CSVS = { 17 | # paths from `src` directory 18 | "flower": DatasetCSV( 19 | train="./csv/train.csv", 20 | val="./csv/val.csv", 21 | test="./csv/test.csv", 22 | ), 23 | 
"dammy": DatasetCSV( 24 | train="./csv/dammy/train.csv", 25 | val="./csv/dammy/val.csv", 26 | test="./csv/dammy/test.csv", 27 | ), 28 | # paths to the csv files for pytest is from project root 29 | "pytest": DatasetCSV( 30 | train="./tests/sample/pytest_train.csv", 31 | val="./tests/sample/pytest_val.csv", 32 | test="./tests/sample/pytest_test.csv", 33 | ), 34 | } 35 | -------------------------------------------------------------------------------- /dockerfiles/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG PYTORCH="1.9.0" 2 | ARG CUDA="11.1" 3 | ARG CUDNN="8" 4 | 5 | FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-runtime 6 | 7 | RUN apt update && \ 8 | apt install -y \ 9 | ffmpeg libsm6 libxext6 ninja-build libglib2.0-0 libsm6 libxrender-dev \ 10 | gcc vim git watch 11 | 12 | COPY ./requirements.txt /project/requirements.txt 13 | 14 | WORKDIR /project 15 | RUN python -m pip install --upgrade pip && \ 16 | pip install -r requirements.txt 17 | 18 | RUN jupyter notebook --generate-config && \ 19 | echo "c.NotebookApp.notebook_dir = '/project'" >> ~/.jupyter/jupyter_notebook_config.py && \ 20 | echo "c.NotebookApp.token = ''" >> ~/.jupyter/jupyter_notebook_config.py && \ 21 | echo "c.NotebookApp.password = ''" >> ~/.jupyter/jupyter_notebook_config.py && \ 22 | echo "c.NotebookApp.ip = '0.0.0.0'" >> ~/.jupyter/jupyter_notebook_config.py 23 | 24 | CMD ["/bin/bash"] 25 | -------------------------------------------------------------------------------- /src/libs/models/__init__.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | 3 | import torch.nn as nn 4 | import torchvision 5 | 6 | __all__ = ["get_model"] 7 | 8 | model_names = ["resnet18", "resnet34", "resnet50"] 9 | logger = getLogger(__name__) 10 | 11 | 12 | def get_model(name: str, n_classes: int, pretrained: bool = True) -> nn.Module: 13 | name = name.lower() 14 | if name not in model_names: 15 | message = ( 16 | "There is no model appropriate to your choice. " 17 | "You have to choose resnet18, resnet34, resnet50 as a model." 
18 | ) 19 | logger.error(message) 20 | raise ValueError(message) 21 | 22 | logger.info("{} will be used as a model.".format(name)) 23 | 24 | model = getattr(torchvision.models, name)(pretrained=pretrained) 25 | in_features = model.fc.in_features 26 | model.fc = nn.Linear(in_features=in_features, out_features=n_classes, bias=True) 27 | 28 | return model 29 | -------------------------------------------------------------------------------- /tests/test_class_weight.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import pandas as pd 4 | import pytest 5 | import torch 6 | 7 | from src.libs.loss_fn.class_weight import get_class_num, get_class_weight 8 | 9 | 10 | @pytest.fixture() 11 | def train_data_csv() -> Tuple[pd.DataFrame, str]: 12 | csv_file = "./src/csv/train.csv" 13 | df = pd.read_csv(csv_file) 14 | return df, csv_file 15 | 16 | 17 | def test_get_class_num(train_data_csv: Tuple[pd.DataFrame, str]) -> None: 18 | train_data, csv_file = train_data_csv 19 | 20 | class_num = get_class_num(csv_file) 21 | 22 | assert class_num.shape == (5,) 23 | assert len(train_data) == class_num.sum().item() 24 | assert torch.all(class_num > 0) 25 | 26 | 27 | def test_get_class_weight(train_data_csv: Tuple[pd.DataFrame, str]) -> None: 28 | _, csv_file = train_data_csv 29 | 30 | class_weight = get_class_weight(csv_file) 31 | 32 | assert class_weight.shape == (5,) 33 | assert torch.all(class_weight > 0) 34 | assert class_weight.dtype == torch.float 35 | -------------------------------------------------------------------------------- /tests/test_meter.py: -------------------------------------------------------------------------------- 1 | from logging import DEBUG, INFO 2 | 3 | import pytest 4 | from _pytest.logging import LogCaptureFixture 5 | 6 | from src.libs.meter import AverageMeter, ProgressMeter 7 | 8 | 9 | @pytest.fixture() 10 | def average_meter() -> AverageMeter: 11 | meter = AverageMeter("acc", ":.1f") 12 | meter.update(8.0, 1.0) 13 | meter.update(12.0, 1.0) 14 | return meter 15 | 16 | 17 | def test_average_meter(average_meter: AverageMeter) -> None: 18 | assert average_meter.get_average() == 10.0 19 | 20 | 21 | def test_progress_meter(average_meter: AverageMeter, caplog: LogCaptureFixture) -> None: 22 | caplog.set_level(DEBUG) 23 | 24 | meter = ProgressMeter(2, [average_meter]) 25 | meter.display(2) 26 | 27 | # test logs 28 | assert ( 29 | "src.libs.meter", 30 | DEBUG, 31 | "Progress meter is set up.", 32 | ) in caplog.record_tuples 33 | assert ( 34 | "src.libs.meter", 35 | INFO, 36 | "[2/2]\tacc 12.0 (avg. 10.0)", 37 | ) in caplog.record_tuples 38 | -------------------------------------------------------------------------------- /src/libs/metric.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import torch 4 | 5 | 6 | def calc_accuracy( 7 | output: torch.Tensor, target: torch.Tensor, topk: Tuple[int] = (1,) 8 | ) -> List[float]: 9 | """Computes the accuracy over the k top predictions. 10 | 11 | Args: 12 | output: (N, C). model output. 13 | target: (N, C). ground truth. 14 | topk: if you set (1, 5), top 1 and top 5 accuracy are calcuated. 
15 | Return: 16 | res: List of calculated top k accuracy 17 | """ 18 | with torch.no_grad(): 19 | maxk = max(topk) 20 | batch_size = target.size(0) 21 | 22 | _, pred = output.topk(maxk, 1, True, True) 23 | pred = pred.t() 24 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 25 | 26 | res = [] 27 | for k in topk: 28 | correct_k = correct[:k].contiguous().view(-1) 29 | correct_k = correct_k.float().sum(0, keepdim=True) 30 | res.append(correct_k.mul_(100.0 / batch_size).item()) 31 | return res 32 | -------------------------------------------------------------------------------- /tests/test_metric.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from src.libs.metric import calc_accuracy 5 | 6 | 7 | @pytest.fixture() 8 | def predictions() -> torch.Tensor: 9 | # shape (N, C) = (5, 5) 10 | preds = torch.tensor( 11 | [ 12 | [0.05, 0.1, 0.15, 0.2, 0.5], 13 | [0.05, 0.1, 0.15, 0.2, 0.5], 14 | [0.05, 0.1, 0.15, 0.2, 0.5], 15 | [0.05, 0.1, 0.15, 0.2, 0.5], 16 | [0.05, 0.1, 0.15, 0.2, 0.5], 17 | ] 18 | ) 19 | return preds 20 | 21 | 22 | @pytest.fixture() 23 | def ground_truths() -> torch.Tensor: 24 | # shape (N, ) = (1, ) 25 | gts = torch.tensor([4, 3, 2, 1, 0]) 26 | return gts 27 | 28 | 29 | def test_calc_accuracy(predictions: torch.Tensor, ground_truths: torch.Tensor) -> None: 30 | top1, top2, top3, top4, top5 = calc_accuracy( 31 | predictions, ground_truths, topk=(1, 2, 3, 4, 5) 32 | ) 33 | assert top1 == 1 / 5 * 100 34 | assert top2 == 2 / 5 * 100 35 | assert top3 == 3 / 5 * 100 36 | assert top4 == 4 / 5 * 100 37 | assert top5 == 5 / 5 * 100 38 | -------------------------------------------------------------------------------- /.github/workflows/mypy_pytest.yaml: -------------------------------------------------------------------------------- 1 | name: mypy_pytest 2 | on: [push, pull_request] 3 | 4 | jobs: 5 | mypy_pytest: 6 | name: running mypy and pytest 7 | runs-on: ubuntu-latest 8 | steps: 9 | - name: Checkout 10 | uses: actions/checkout@v2 11 | 12 | # set up python 13 | - name: Setting up python. 14 | uses: actions/setup-python@v1 15 | with: 16 | python-version: 3.8 17 | 18 | # install poetry 19 | - name: Install Poetry 20 | run: | 21 | curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python 22 | 23 | - name: Add path for Poetry 24 | run: echo "$HOME/.poetry/bin" >> $GITHUB_PATH 25 | 26 | # インストールした Poetry を使って必要な Python パッケージをインストールする 27 | - name: Install Dependencies 28 | run: poetry install --no-interaction 29 | 30 | # mypy 31 | - name: Type check 32 | run: poetry run mypy . 
--ignore-missing-imports 33 | 34 | # pytest 35 | - name: Pytest 36 | run: poetry run pytest -v --cov=src --cov-report term-missing 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 yiskw713 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /notebooks/template.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import sys\n", 11 | "sys.path.append(os.pardir)\n", 12 | "\n", 13 | "# auto reload module\n", 14 | "%load_ext autoreload\n", 15 | "%autoreload 2" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 2, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [] 31 | } 32 | ], 33 | "metadata": { 34 | "kernelspec": { 35 | "display_name": "Python 3", 36 | "language": "python", 37 | "name": "python3" 38 | }, 39 | "language_info": { 40 | "codemirror_mode": { 41 | "name": "ipython", 42 | "version": 3 43 | }, 44 | "file_extension": ".py", 45 | "mimetype": "text/x-python", 46 | "name": "python", 47 | "nbconvert_exporter": "python", 48 | "pygments_lexer": "ipython3", 49 | "version": "3.7.3" 50 | } 51 | }, 52 | "nbformat": 4, 53 | "nbformat_minor": 2 54 | } 55 | -------------------------------------------------------------------------------- /src/util_scripts/visualize_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | import hiddenlayer as hl 6 | import torch 7 | 8 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 9 | from libs.models import get_model 10 | 11 | 12 | def get_arguments() -> argparse.Namespace: 13 | """parse all the arguments from command line inteface return a list of 14 | parsed arguments.""" 15 | 16 | parser = argparse.ArgumentParser(description="Model visualization.") 17 | parser.add_argument( 18 | "model", 19 | type=str, 20 | choices=["resnet18", "resnet34", "resnet50"], 21 | help="name of the model you want to visualize.", 22 | ) 23 | 
parser.add_argument( 24 | "--save_dir", 25 | type=str, 26 | default="./imgs", 27 | help="a directory where images will be saved", 28 | ) 29 | 30 | return parser.parse_args() 31 | 32 | 33 | def main() -> None: 34 | args = get_arguments() 35 | 36 | model = get_model(args.model, 10) 37 | save_path = os.path.join(args.save_dir, f"{args.model}.png") 38 | 39 | hl_graph = hl.build_graph(model, torch.zeros([1, 3, 224, 224])) 40 | hl_graph.save(save_path, format="png") 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /src/libs/loss_fn/__init__.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | from typing import Optional 3 | 4 | import torch.nn as nn 5 | 6 | from ..dataset_csv import DATASET_CSVS 7 | from .class_weight import get_class_weight 8 | 9 | __all__ = ["get_criterion"] 10 | logger = getLogger(__name__) 11 | 12 | 13 | def get_criterion( 14 | use_class_weight: bool = False, 15 | dataset_name: Optional[str] = None, 16 | device: Optional[str] = None, 17 | ) -> nn.Module: 18 | 19 | if use_class_weight: 20 | if dataset_name is None: 21 | message = "dataset_name used for training should be specified." 22 | logger.error(message) 23 | raise ValueError(message) 24 | 25 | if device is None: 26 | message = "you should specify a device when you use class weight." 27 | logger.error(message) 28 | raise ValueError(message) 29 | 30 | if dataset_name not in DATASET_CSVS: 31 | message = "dataset_name is invalid." 32 | logger.error(message) 33 | raise ValueError(message) 34 | 35 | train_csv_file = DATASET_CSVS[dataset_name].train 36 | class_weight = get_class_weight(train_csv_file).to(device) 37 | criterion = nn.CrossEntropyLoss(weight=class_weight) 38 | else: 39 | criterion = nn.CrossEntropyLoss() 40 | 41 | return criterion 42 | -------------------------------------------------------------------------------- /src/libs/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | from logging import getLogger 3 | from typing import Tuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | 9 | logger = getLogger(__name__) 10 | 11 | 12 | def save_checkpoint( 13 | result_path: str, 14 | epoch: int, 15 | model: nn.Module, 16 | optimizer: optim.Optimizer, 17 | best_loss: float, 18 | ) -> None: 19 | 20 | save_states = { 21 | "epoch": epoch, 22 | "state_dict": model.state_dict(), 23 | "optimizer": optimizer.state_dict(), 24 | "best_loss": best_loss, 25 | } 26 | 27 | torch.save(save_states, os.path.join(result_path, "checkpoint.pth")) 28 | logger.debug("successfully saved the ckeckpoint.") 29 | 30 | 31 | def resume( 32 | resume_path: str, model: nn.Module, optimizer: optim.Optimizer 33 | ) -> Tuple[int, nn.Module, optim.Optimizer, float]: 34 | try: 35 | checkpoint = torch.load(resume_path, map_location=lambda storage, loc: storage) 36 | logger.info("loading checkpoint {}".format(resume_path)) 37 | except FileNotFoundError("there is no checkpoint at the result folder.") as e: 38 | logger.exception(f"{e}") 39 | 40 | begin_epoch = checkpoint["epoch"] 41 | best_loss = checkpoint["best_loss"] 42 | model.load_state_dict(checkpoint["state_dict"]) 43 | 44 | optimizer.load_state_dict(checkpoint["optimizer"]) 45 | 46 | logger.info("training will start from {} epoch".format(begin_epoch)) 47 | 48 | return begin_epoch, model, optimizer, best_loss 49 | 
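# Usage sketch for save_checkpoint()/resume() above (mirrors tests/test_checkpoint.py; the paths and values are illustrative only):
#
#     model = torchvision.models.resnet18()
#     optimizer = optim.Adam(model.parameters(), lr=0.0003)
#     save_checkpoint("./result", epoch=10, model=model, optimizer=optimizer, best_loss=0.5)
#     # ...later, to restart training from the saved state:
#     begin_epoch, model, optimizer, best_loss = resume("./result/checkpoint.pth", model, optimizer)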
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "pytorch_template" 3 | version = "0.1.0" 4 | description = "pytorch project template" 5 | authors = ["yiskw713"] 6 | license = "MIT" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.9" 10 | Pillow = "^9.0.1" 11 | PyYAML = "^6.0" 12 | numpy = "^1.22.3" 13 | pandas = "^1.4.1" 14 | wandb = "^0.12.11" 15 | hiddenlayer = "^0.3" 16 | graphviz = "^0.19.1" 17 | opencv-python = "^4.5.5" 18 | pydantic = "^1.9.0" 19 | 20 | [tool.poetry.dev-dependencies] 21 | black = "^22.1.0" 22 | flake8 = "^4.0.1" 23 | isort = "^5.10.1" 24 | mypy = "^0.931" 25 | pytest = "^7.0.1" 26 | pytest-cov = "^3.0.0" 27 | pytest-mock = "^3.7.0" 28 | jupyterlab = "^3.3.0" 29 | bandit = "^1.7.4" 30 | poethepoet = "^0.13.1" 31 | flake8-bugbear = "22.1.11" 32 | flake8-builtins = "1.5.3" 33 | flake8-eradicate = "1.2.0" 34 | pep8-naming = "0.12.1" 35 | flake8-expression-complexity = "0.0.10" 36 | flake8-cognitive-complexity = "0.1.0" 37 | flake8-pytest-style = "1.6.0" 38 | 39 | [build-system] 40 | requires = ["poetry-core>=1.0.0"] 41 | build-backend = "poetry.core.masonry.api" 42 | 43 | [tool.isort] 44 | ensure_newline_before_comments = true 45 | force_grid_wrap = 0 46 | include_trailing_comma = true 47 | line_length = 88 48 | multi_line_output = 3 49 | use_parentheses = true 50 | 51 | [tool.poe.tasks] 52 | install_sklearn = """ 53 | pip install \ 54 | scipy==1.7.3 \ 55 | scikit-learn 56 | """ 57 | 58 | install_torch = """ 59 | pip install \ 60 | torch==1.10.2 \ 61 | torchvision \ 62 | timm \ 63 | torchinfo \ 64 | imgaug \ 65 | """ -------------------------------------------------------------------------------- /dockerfiles/Dockerfile.poetry: -------------------------------------------------------------------------------- 1 | ARG PYTHON_ENV=python:3.9.9-slim 2 | 3 | # build stage 4 | FROM $PYTHON_ENV as build 5 | 6 | COPY ./pyproject.toml /app/pyproject.toml 7 | RUN apt update && \ 8 | apt install -y \ 9 | ffmpeg libsm6 libxext6 ninja-build libglib2.0-0 libsm6 libxrender-dev gcc 10 | 11 | WORKDIR /app 12 | # install python dependencies 13 | RUN python -m pip install --upgrade pip && \ 14 | python -m pip install poetry && \ 15 | poetry config virtualenvs.in-project true && \ 16 | poetry install --no-dev && \ 17 | rm -rf ~/.cache 18 | 19 | # install python dependencies with poethepoet 20 | # some packages cannot installed with poetry (e.g. 
pytorch) 21 | RUN poetry add poethepoet && \ 22 | poetry run poe install_torch && \ 23 | poetry run poe install_sklearn && \ 24 | poetry remove poethepoet && \ 25 | rm -rf ~/.cache 26 | 27 | # development stage 28 | FROM $BPYTHON_$PYTHON_ENV as development 29 | 30 | COPY --from=build /app/pyproject.toml /app/pyproject.toml 31 | COPY --from=build /app/poetry.lock /app/poetry.lock 32 | COPY --from=build /app/.venv /app/.venv 33 | ENV PATH=/app/.venv/bin:$PATH 34 | 35 | COPY --from=build /usr/bin/* /usr/bin/ 36 | RUN apt update && \ 37 | apt install -y git vim 38 | 39 | WORKDIR /app 40 | RUN python -m pip install --upgrade pip && \ 41 | python -m pip install poetry && \ 42 | poetry config virtualenvs.in-project true && \ 43 | poetry install && \ 44 | rm -rf ~/.cache 45 | 46 | # production stage 47 | FROM $BPYTHON_$PYTHON_ENV as production 48 | 49 | RUN apt update && \ 50 | apt clean && \ 51 | rm -rf /var/lib/apt/lists/* 52 | 53 | COPY --from=build /app/.venv /app/.venv 54 | ENV PATH=/app/.venv/bin:$PATH 55 | 56 | CMD ["/bin/bash"] 57 | -------------------------------------------------------------------------------- /tests/test_checkpoint.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | from typing import Tuple 4 | 5 | import pytest 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torchvision import models 10 | 11 | from src.libs.checkpoint import resume, save_checkpoint 12 | 13 | 14 | @pytest.fixture() 15 | def model_optim() -> Tuple[nn.Module, optim.Optimizer]: 16 | model = models.resnet18() 17 | optimizer = optim.Adam(model.parameters(), lr=0.1) 18 | 19 | return model, optimizer 20 | 21 | 22 | def test_checkpoint(model_optim: Tuple[nn.Module, optim.Optimizer]) -> None: 23 | model = model_optim[0] 24 | optimizer = model_optim[1] 25 | epoch = 100 26 | best_loss = 0.1 27 | result_path = "./tests/tmp" 28 | 29 | save_checkpoint(result_path, epoch, model, optimizer, best_loss) 30 | checkpoint_path = os.path.join(result_path, "checkpoint.pth") 31 | 32 | assert os.path.exists(checkpoint_path) 33 | 34 | model2 = copy.deepcopy(model) 35 | optimizer2 = copy.deepcopy(optimizer) 36 | 37 | begin_epoch, model2, optimizer2, best_loss2 = resume( 38 | checkpoint_path, model2, optimizer2 39 | ) 40 | 41 | assert epoch == begin_epoch 42 | 43 | for state, state2 in zip(optimizer.state_dict(), optimizer2.state_dict()): 44 | assert state == state2 45 | 46 | assert best_loss == best_loss2 47 | 48 | # check if models have the same weights 49 | # https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351 50 | for key_item1, key_item2 in zip( 51 | model.state_dict().items(), model2.state_dict().items() 52 | ): 53 | assert torch.equal(key_item1[1], key_item2[1]) 54 | 55 | os.remove(checkpoint_path) 56 | 57 | assert not os.path.exists(checkpoint_path) 58 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from torchvision import transforms 5 | 6 | from src.libs.dataset import FlowersDataset, get_dataloader 7 | 8 | 9 | @pytest.mark.parametrize("batch_size", [1, 2]) 10 | def test_get_dataloader(batch_size) -> None: 11 | loader = get_dataloader( 12 | dataset_name="pytest", 13 | split="train", 14 | batch_size=batch_size, 15 | shuffle=True, 16 | num_workers=1, 17 | pin_memory=False, 18 | 
drop_last=True, 19 | transform=transforms.Compose( 20 | [ 21 | transforms.Resize((224, 224)), 22 | transforms.ToTensor(), 23 | ] 24 | ), 25 | ) 26 | 27 | assert len(loader) == 8 // batch_size 28 | assert isinstance(loader, DataLoader) 29 | 30 | for sample in loader: 31 | assert sample["img"].shape == (batch_size, 3, 224, 224) 32 | assert sample["img"].dtype == torch.float 33 | 34 | assert isinstance(sample["label"], list) 35 | assert isinstance(sample["label"][0], str) 36 | 37 | assert sample["class_id"].shape == (batch_size,) 38 | assert sample["class_id"].dtype == torch.int64 39 | break 40 | 41 | 42 | class TestFlowersDataset(object): 43 | @pytest.fixture() 44 | def data(self): 45 | data = FlowersDataset("./tests/sample/pytest_train.csv") 46 | return data 47 | 48 | def test_len(self, data): 49 | assert len(data) == 8 50 | 51 | def test_get_n_classes(self, data): 52 | assert data.get_n_classes() == 1 53 | 54 | def test_getitem(self, data): 55 | sample = data.__getitem__(0) 56 | 57 | assert "class_id" in sample 58 | assert "label" in sample 59 | assert "img" in sample 60 | -------------------------------------------------------------------------------- /src/libs/loss_fn/class_weight.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | 3 | import pandas as pd 4 | import torch 5 | 6 | logger = getLogger(__name__) 7 | 8 | 9 | def get_class_num(train_csv_file: str) -> torch.Tensor: 10 | """ 11 | get the number of samples in each class 12 | Args: 13 | train_csv_file: the path to the train csv file 14 | """ 15 | try: 16 | df = pd.read_csv(train_csv_file) 17 | except FileNotFoundError as err: 18 | logger.exception(f"{err}") 19 | raise err 20 | 21 | df = pd.read_csv(train_csv_file) 22 | n_classes = df["class_id"].nunique() 23 | 24 | nums = {} 25 | for i in range(n_classes): 26 | nums[i] = 0 27 | for i in range(len(df)): 28 | nums[df.iloc[i, 1]] += 1 29 | class_num = [] 30 | for val in nums.values(): 31 | class_num.append(val) 32 | class_num = torch.tensor(class_num) 33 | 34 | logger.debug(f"the number of samples per class: {class_num}") 35 | 36 | return class_num 37 | 38 | 39 | def get_class_weight(train_csv_file: str) -> torch.Tensor: 40 | """Class weight for CrossEntropy in Flowers Recognition Dataset Class 41 | weight is calculated in the way described in: 42 | 43 | D. Eigen and R. 
Fergus, “Predicting depth, surface normals and semantic labels 44 | with a common multi-scale convolutional architecture,” in ICCV 2015, 45 | openaccess: 46 | https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Eigen_Predicting_Depth_Surface_ICCV_2015_paper.pdf 47 | """ 48 | 49 | class_num = get_class_num(train_csv_file) 50 | total = class_num.sum().item() 51 | frequency = class_num.float() / total 52 | median = torch.median(frequency) 53 | class_weight = median / frequency 54 | 55 | logger.debug(f"class weight: {class_num}") 56 | 57 | return class_weight 58 | -------------------------------------------------------------------------------- /src/libs/meter.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | from typing import List 3 | 4 | logger = getLogger(__name__) 5 | 6 | 7 | class AverageMeter(object): 8 | """Computes and stores the average and current value.""" 9 | 10 | def __init__(self, name: str, fmt: str = ":f") -> None: 11 | self.name = name 12 | self.fmt = fmt 13 | self._reset() 14 | logger.debug("Average meter is set up.") 15 | 16 | def _reset(self) -> None: 17 | self.val = 0.0 18 | self.avg = 0.0 19 | self.sum = 0.0 20 | self.count = 0 21 | 22 | def update(self, val: float, n: int = 1) -> None: 23 | # `val` is the average value of `n` samples 24 | self.val = val 25 | self.sum += val * n 26 | self.count += n 27 | self.avg = self.sum / self.count 28 | 29 | def get_average(self) -> float: 30 | return self.avg 31 | 32 | def __str__(self) -> str: 33 | fmtstr = "{name} {val" + self.fmt + "} (avg. {avg" + self.fmt + "})" 34 | return fmtstr.format(**self.__dict__) 35 | 36 | 37 | class ProgressMeter(object): 38 | def __init__( 39 | self, num_batches: int, meters: List[AverageMeter], prefix: str = "" 40 | ) -> None: 41 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 42 | self.meters = meters 43 | self.prefix = prefix 44 | 45 | logger.debug("Progress meter is set up.") 46 | 47 | def display(self, batch: int) -> None: 48 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 49 | 50 | # show current values and average values 51 | entries += [str(meter) for meter in self.meters] 52 | logger.info("\t".join(entries)) 53 | 54 | def _get_batch_fmtstr(self, num_batches: int) -> str: 55 | num_digits = len(str(num_batches // 1)) 56 | # format the number of digits for string 57 | fmt = "{:" + str(num_digits) + "d}" 58 | return "[" + fmt + "/" + fmt.format(num_batches) + "]" 59 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v3.3.0 4 | hooks: 5 | - id: check-added-large-files 6 | args: [--maxkb=5000] 7 | - id: check-json 8 | - id: check-toml 9 | - id: check-xml 10 | - id: check-yaml 11 | - id: debug-statements 12 | - id: detect-aws-credentials 13 | args: [--allow-missing-credentials] 14 | - id: detect-private-key 15 | - id: end-of-file-fixer 16 | - id: name-tests-test 17 | args: [--django] # to match `test*.py`. 18 | # - id: no-commit-to-branch # to protect specific branches from direct checkins. 
19 | # args: [--branch, master] 20 | - id: pretty-format-json 21 | args: [--autofix] 22 | 23 | - repo: https://gitlab.com/pycqa/flake8 24 | rev: 3.8.1 25 | hooks: 26 | - id: flake8 27 | # max-line-length setting is the same as black 28 | # commit cannot be done when cyclomatic complexity is more than 10. 29 | # E402 ... Module level import not at top of file 30 | # PT011 ... set the match parameter in pytest.raises({exception}) 31 | # E501 ...Line too long 32 | # E203 ... Whitespace before ':' 33 | # W503 ... Line break occurred before a binary operator 34 | # W504 ... Line break occurred after a binary operator 35 | args: [--max-line-length, "88", "--ignore=E402,PT011,E501,E203,W503,W504", --max-complexity, "10", --max-expression-complexity=7, --max-cognitive-complexity=8] 36 | additional_dependencies: [flake8-bugbear, flake8-builtins, flake8-eradicate, pep8-naming, flake8-expression-complexity, flake8-cognitive-complexity, flake8-pytest-style] 37 | 38 | - repo: https://github.com/psf/black 39 | rev: stable 40 | hooks: 41 | - id: black 42 | language_version: python3 43 | 44 | - repo: https://github.com/pycqa/isort 45 | rev: 5.5.2 46 | hooks: 47 | - id: isort 48 | args: ["--settings-path=pyproject.toml"] 49 | 50 | # for docstrings in python codes 51 | - repo: https://github.com/myint/docformatter 52 | rev: master 53 | hooks: 54 | - id: docformatter 55 | args: [--in-place] 56 | 57 | # for markdown 58 | - repo: https://github.com/markdownlint/markdownlint 59 | rev: master # or specific git tag 60 | hooks: 61 | - id: markdownlint 62 | # ignore line length of makrdownlint 63 | args: [-r, ~MD013] 64 | -------------------------------------------------------------------------------- /src/libs/logger.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | 3 | import pandas as pd 4 | 5 | logger = getLogger(__name__) 6 | 7 | 8 | class TrainLogger(object): 9 | def __init__(self, log_path: str, resume: bool) -> None: 10 | self.log_path = log_path 11 | self.columns = [ 12 | "epoch", 13 | "lr", 14 | "train_time[sec]", 15 | "train_loss", 16 | "train_acc@1", 17 | "train_f1s", 18 | "val_time[sec]", 19 | "val_loss", 20 | "val_acc@1", 21 | "val_f1s", 22 | ] 23 | 24 | if resume: 25 | self.df = self._load_log() 26 | else: 27 | self.df = pd.DataFrame(columns=self.columns) 28 | 29 | def _load_log(self) -> pd.DataFrame: 30 | try: 31 | df = pd.read_csv(self.log_path) 32 | logger.info("successfully loaded log csv file.") 33 | return df 34 | except FileNotFoundError as err: 35 | logger.exception(f"{err}") 36 | raise err 37 | 38 | def _save_log(self) -> None: 39 | self.df.to_csv(self.log_path, index=False) 40 | logger.debug("training logs are saved.") 41 | 42 | def update( 43 | self, 44 | epoch: int, 45 | lr: float, 46 | train_time: int, 47 | train_loss: float, 48 | train_acc: float, 49 | train_f1s: float, 50 | val_time: int, 51 | val_loss: float, 52 | val_acc1: float, 53 | val_f1s: float, 54 | ) -> None: 55 | tmp = pd.Series( 56 | [ 57 | epoch, 58 | lr, 59 | train_time, 60 | train_loss, 61 | train_acc, 62 | train_f1s, 63 | val_time, 64 | val_loss, 65 | val_acc1, 66 | val_f1s, 67 | ], 68 | index=self.columns, 69 | ) 70 | 71 | self.df = self.df.append(tmp, ignore_index=True) 72 | self._save_log() 73 | 74 | logger.info( 75 | f"epoch: {epoch}\tepoch time[sec]: {train_time + val_time}\tlr: {lr}\t" 76 | f"train loss: {train_loss:.4f}\tval loss: {val_loss:.4f}\t" 77 | f"val_acc1: {val_acc1:.5f}\tval_f1s: {val_f1s:.5f}" 78 | ) 79 | 
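# Usage sketch for TrainLogger above (argument order follows self.columns; mirrors tests/test_logger.py):
#
#     train_logger = TrainLogger("./result/log.csv", resume=False)
#     train_logger.update(
#         epoch=0, lr=0.1,
#         train_time=20, train_loss=0.05, train_acc=30.0, train_f1s=30.0,
#         val_time=10, val_loss=0.03, val_acc1=28.3, val_f1s=28.3,
#     )  # appends one row, rewrites log.csv, and logs a one-line epoch summary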
-------------------------------------------------------------------------------- /tests/test_config.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import pytest 4 | from _pytest.capture import CaptureFixture 5 | 6 | from src.libs.config import Config, convert_list2tuple, get_config 7 | 8 | 9 | @pytest.fixture() 10 | def base_dict() -> Dict[str, Any]: 11 | _dict = { 12 | "batch_size": 32, 13 | "dataset_name": "flower", 14 | "height": 224, 15 | "learning_rate": 0.0003, 16 | "max_epoch": 50, 17 | "model": "resnet18", 18 | "num_workers": 2, 19 | "pretrained": True, 20 | "topk": (1, 3, 5), 21 | "use_class_weight": True, 22 | "width": 224, 23 | } 24 | return _dict 25 | 26 | 27 | class TestConfig(object): 28 | def test_type_check(self, base_dict: Dict[str, Any]) -> None: 29 | for k in base_dict.keys(): 30 | _dict = base_dict.copy() 31 | if k == "pretrained" or k == "use_class_weight": 32 | _dict[k] = "test" 33 | else: 34 | _dict[k] = True 35 | 36 | with pytest.raises(TypeError): 37 | Config(**_dict) 38 | 39 | def test_value_check(self, base_dict: Dict[str, Any]) -> None: 40 | _dict = base_dict.copy() 41 | _dict["dataset_name"] = "hoge" 42 | with pytest.raises(ValueError): 43 | Config(**_dict) 44 | 45 | _dict = base_dict.copy() 46 | _dict["max_epoch"] = -100 47 | with pytest.raises(ValueError): 48 | Config(**_dict) 49 | 50 | def test_type_check_element(self, base_dict: Dict[str, Any]) -> None: 51 | for val in [("train", "test"), (True, False), (1.2, 2.5)]: 52 | _dict = base_dict.copy() 53 | _dict["topk"] = val 54 | 55 | with pytest.raises(TypeError): 56 | Config(**_dict) 57 | 58 | def test_post_init(self, base_dict: Dict[str, Any], capfd: CaptureFixture) -> None: 59 | Config(**base_dict) 60 | 61 | # test printed string 62 | _, err = capfd.readouterr() 63 | assert err == "" 64 | 65 | 66 | def test_convert_list2tuple() -> None: 67 | _dict = {"test": [1, 2, 3], "train": ["hoge", "foo"], "validation": [True, False]} 68 | 69 | _dict = convert_list2tuple(_dict) 70 | 71 | for val in _dict.values(): 72 | assert isinstance(val, tuple) 73 | 74 | 75 | def test_get_config(base_dict: Dict[str, Any]) -> None: 76 | config = get_config("tests/sample/config.yaml") 77 | 78 | for key, val in base_dict.items(): 79 | assert val == getattr(config, key) 80 | -------------------------------------------------------------------------------- /tests/test_logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | from logging import DEBUG, INFO 3 | from typing import Tuple, Union 4 | 5 | import pytest 6 | from _pytest.logging import LogCaptureFixture 7 | 8 | from src.libs.logger import TrainLogger 9 | 10 | 11 | @pytest.fixture() 12 | def epoch_result() -> Tuple[Union[int, float], ...]: 13 | results = ( 14 | 0, 15 | 0.1, 16 | 20, 17 | 0.05, 18 | 30.0, 19 | 30.0, 20 | 10, 21 | 0.03, 22 | 28.3, 23 | 28.3, 24 | ) 25 | return results 26 | 27 | 28 | def test_update(epoch_result: Tuple[Union[int, float], ...], caplog: LogCaptureFixture): 29 | caplog.set_level(DEBUG) 30 | 31 | log_path = "./tests/tmp/log.csv" 32 | logger = TrainLogger(log_path, False) 33 | logger.update(*epoch_result) 34 | 35 | # test logs 36 | assert ( 37 | "src.libs.logger", 38 | DEBUG, 39 | "training logs are saved.", 40 | ) in caplog.record_tuples 41 | 42 | assert ( 43 | "src.libs.logger", 44 | INFO, 45 | f"epoch: {epoch_result[0]}\t" 46 | f"epoch time[sec]: {epoch_result[2] + epoch_result[6]}\t" 47 | f"lr: {epoch_result[1]}\t" 48 | 
f"train loss: {epoch_result[3]:.4f}\tval loss: {epoch_result[7]:.4f}\t" 49 | f"val_acc1: {epoch_result[8]:.5f}\tval_f1s: {epoch_result[9]:.5f}", 50 | ) in caplog.record_tuples 51 | 52 | assert os.path.exists(log_path) 53 | assert len(logger.df) == 1 54 | 55 | logger.update(*epoch_result) 56 | assert len(logger.df) == 2 57 | 58 | 59 | def test_init_load(caplog: LogCaptureFixture): 60 | caplog.set_level(DEBUG) 61 | 62 | log_path = "./tests/tmp/no_exist_log.csv" 63 | 64 | with pytest.raises(FileNotFoundError): 65 | logger = TrainLogger(log_path, True) 66 | 67 | log_path = "./tests/tmp/log.csv" 68 | logger = TrainLogger(log_path, True) 69 | 70 | # test logs 71 | assert ( 72 | "src.libs.logger", 73 | INFO, 74 | "successfully loaded log csv file.", 75 | ) in caplog.record_tuples 76 | 77 | assert len(logger.df) == 2 78 | assert logger.df.iloc[0]["epoch"] == 0 79 | assert logger.df.iloc[0]["lr"] == 0.1 80 | assert logger.df.iloc[0]["train_time[sec]"] == 20 81 | assert logger.df.iloc[0]["train_loss"] == 0.05 82 | assert logger.df.iloc[0]["train_acc@1"] == 30.0 83 | assert logger.df.iloc[0]["train_f1s"] == 30.0 84 | assert logger.df.iloc[0]["val_time[sec]"] == 10 85 | assert logger.df.iloc[0]["val_loss"] == 0.03 86 | assert logger.df.iloc[0]["val_acc@1"] == 28.3 87 | assert logger.df.iloc[0]["val_f1s"] == 28.3 88 | 89 | os.remove(log_path) 90 | -------------------------------------------------------------------------------- /src/libs/dataset.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | from typing import Any, Dict, Optional 3 | 4 | import pandas as pd 5 | import torch 6 | from PIL import Image 7 | from torch.utils.data import DataLoader, Dataset 8 | from torchvision import transforms 9 | 10 | from .dataset_csv import DATASET_CSVS 11 | 12 | __all__ = ["get_dataloader"] 13 | 14 | logger = getLogger(__name__) 15 | 16 | 17 | def get_dataloader( 18 | dataset_name: str, 19 | split: str, 20 | batch_size: int, 21 | shuffle: bool, 22 | num_workers: int, 23 | pin_memory: bool, 24 | drop_last: bool = False, 25 | transform: Optional[transforms.Compose] = None, 26 | ) -> DataLoader: 27 | if dataset_name not in DATASET_CSVS: 28 | message = f"dataset_name should be selected from {list(DATASET_CSVS.keys())}." 29 | logger.error(message) 30 | raise ValueError(message) 31 | 32 | if split not in ["train", "val", "test"]: 33 | message = "split should be selected from ['train', 'val', 'test']." 
34 | logger.error(message) 35 | raise ValueError(message) 36 | 37 | logger.info(f"Dataset: {dataset_name}\tSplit: {split}\tBatch size: {batch_size}.") 38 | 39 | csv_file = getattr(DATASET_CSVS[dataset_name], split) 40 | 41 | data = FlowersDataset(csv_file, transform=transform) 42 | dataloader = DataLoader( 43 | data, 44 | batch_size=batch_size, 45 | shuffle=shuffle, 46 | num_workers=num_workers, 47 | pin_memory=pin_memory, 48 | drop_last=drop_last, 49 | ) 50 | 51 | return dataloader 52 | 53 | 54 | class FlowersDataset(Dataset): 55 | def __init__( 56 | self, csv_file: str, transform: Optional[transforms.Compose] = None 57 | ) -> None: 58 | super().__init__() 59 | 60 | try: 61 | self.df = pd.read_csv(csv_file) 62 | except FileNotFoundError as e: 63 | logger.exception(f"{e}") 64 | raise e 65 | self.n_classes = self.df["class_id"].nunique() 66 | self.transform = transform 67 | 68 | logger.info(f"the number of classes: {self.n_classes}") 69 | logger.info(f"the number of samples: {len(self.df)}") 70 | 71 | def __len__(self) -> int: 72 | return len(self.df) 73 | 74 | def __getitem__(self, idx: int) -> Dict[str, Any]: 75 | img_path = self.df.iloc[idx]["image_path"] 76 | img = Image.open(img_path) 77 | 78 | if self.transform is not None: 79 | img = self.transform(img) 80 | 81 | cls_id = self.df.iloc[idx]["class_id"] 82 | cls_id = torch.tensor(cls_id).long() 83 | 84 | label = self.df.iloc[idx]["label"] 85 | 86 | sample = {"img": img, "class_id": cls_id, "label": label, "img_path": img_path} 87 | 88 | return sample 89 | 90 | def get_n_classes(self) -> int: 91 | return self.n_classes 92 | -------------------------------------------------------------------------------- /docs/FOR_AOLAB_MEMBERS.md: -------------------------------------------------------------------------------- 1 | # Image Classification Template for B4 in Aolab 2 | 3 | This is one example of code for an image classification problem, intended as study material for lab members. 4 | 5 | ## Requirements 6 | 7 | * python >= 3.7 8 | * pytorch >= 1.0 9 | * pyyaml 10 | * scikit-learn 11 | * wandb 12 | 13 | The required python packages can be installed with `pip install -r requirements.txt`. 14 | 15 | ## Dataset 16 | 17 | We use the Flowers Recognition Dataset. 18 | It can be downloaded from [here](https://www.kaggle.com/alxmamaev/flowers-recognition/download). 19 | 20 | ## Directory Structure 21 | 22 | The following directory structure is assumed. 23 | Basically, `libs` holds the scripts needed to run the main scripts such as `train.py` and `evaluate.py`. 24 | `utils` holds the other scripts (e.g. `make_csv_files.py` and `make_configs.py`, which are not directly related to the scripts you actually want to run). 25 | The dataset does not have to be located exactly as shown below. 26 | 27 | ```Directory Structure 28 | root/ ──── csv/ 29 | ├─ libs/ 30 | ├─ result/ 31 | ├─ utils/ 32 | ├─ notebook/ 33 | ├─ dataset ─── flowers/ 34 | ├─ scripts ─── experiment.sh 35 | ├ .gitignore 36 | ├ README.md 37 | ├ FOR_AOLAB_MEMBERS.md 38 | ├ requirements.txt 39 | ├ evaluate.py 40 | └ train.py 41 | ``` 42 | 43 | ## Features of this repository 44 | 45 | * The `Config` class that describes the experimental settings is implemented as a python `dataclass`. 46 | 47 | * Types and missing or extra values can be checked 48 | * Since it is an immutable object, the experimental settings cannot be changed by accident 49 | 50 | * Config files can be generated automatically with `utils/make_configs.py` 51 | * A batch of experiments can be run at once with `scripts/experiment.sh` 52 | * Use of typing 53 | * Code formatting and detection of logical errors with black / isort / flake8 54 | 55 | ## Steps for writing the code 56 | 57 | The code is written in the following order. 58 | 59 | 1. Create the csv files for the dataset class (`utils/make_csv_files.py`) 60 | 1. Define the `Config` class (`libs/config.py`) and automatically generate config files (`utils/make_configs.py`) 61 | 1. Create the dataset class (`libs/dataset.py`)
62 | 
63 | \- preprocessing code for the dataset images (`libs/transformer.py`)
64 | 
65 | \- a script with the mean and standard deviation needed by the preprocessing code (`libs/mean.py`)
66 | 
67 | \- a script that maps class indices to labels (`libs/class_weight_map.py`)
68 | 
69 | 1. Define the model (e.g. `libs/models/mymodel.py`)
70 | 1. Define the loss function (e.g. `libs/loss_fn/myloss.py`)
71 | 1. Other code needed for training (`libs/checkpoint.py`, `libs/class_weight.py`, `libs/metric.py`)
72 | 1. Training code (`train.py`)
73 | 
74 | \- uses a config file (`result/r18_lr0.0005/config.yaml`)
75 | 
76 | 1. Evaluation code (`evaluate.py`)
77 | 1. A shell script that runs training and testing in one go (`experiment.sh`)
78 | 
79 | In this code, the image/label pairs of the dataset are written out to csv files in advance.
80 | I use csv because it is easy to include information other than the label, and I personally find it easier to read than, e.g., a json file.
81 | It also makes it easy to handle additional meta information whenever you want to use it.
82 | 
83 | When running training, a configuration file is created first and then loaded.
84 | Compared to describing every experimental setting with `argparse`, this is easier and more readable, and above all it makes it simple to keep the experimental settings and to run many scripts in one go.
85 | 
86 | ## Experiment
87 | 
88 | Running the script below runs the experiments. Running `utils/make_configs.py` generates the configuration files automatically.
89 | Pass the parameters you want to change as command line arguments.
90 | 
91 | ```shell
92 | sh experiment.sh
93 | ```
94 | 
95 | ## Miscellaneous
96 | 
97 | Code readability really matters. You should follow pep8 and leave as many comments as you can.
98 | 
99 | * black
100 | * flake8
101 | * isort
102 | 
103 | Use tools such as these to keep the code cleanly formatted.
104 | 
-------------------------------------------------------------------------------- /src/util_scripts/make_csv_files.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 | import sys
5 | from typing import Dict, List, Union
6 | 
7 | import pandas as pd
8 | 
9 | sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
10 | from libs.class_id_map import get_cls2id_map
11 | 
12 | 
13 | def get_arguments() -> argparse.Namespace:
14 |     """parse all the arguments from the command line interface and return a list of
15 |     parsed arguments."""
16 | 
17 |     parser = argparse.ArgumentParser(
18 |         description="make csv files for flowers recognition dataset"
19 |     )
20 |     parser.add_argument(
21 |         "--dataset_dir",
22 |         type=str,
23 |         default="../dataset/flowers/",
24 |         help="path to a dataset directory",
25 |     )
26 |     parser.add_argument(
27 |         "--save_dir",
28 |         type=str,
29 |         default="./csv",
30 |         help="a directory where csv files will be saved",
31 |     )
32 | 
33 |     return parser.parse_args()
34 | 
35 | 
36 | def split_data(
37 |     data: Dict[str, Dict[str, List[Union[int, str]]]],
38 |     img_paths: List[str],
39 |     cls_name: str,
40 |     cls_id: int,
41 | ) -> None:
42 | 
43 |     for i, path in enumerate(img_paths):
44 |         path = os.path.abspath(path)
45 |         if i % 5 == 4:
46 |             # for test
47 |             data["test"]["image_path"].append(path)
48 |             data["test"]["label"].append(cls_name)
49 |             data["test"]["class_id"].append(cls_id)
50 |         elif i % 5 == 3:
51 |             # for validation
52 |             data["val"]["image_path"].append(path)
53 |             data["val"]["label"].append(cls_name)
54 |             data["val"]["class_id"].append(cls_id)
55 |         else:
56 |             # for training
57 |             data["train"]["image_path"].append(path)
58 |             data["train"]["label"].append(cls_name)
59 |             data["train"]["class_id"].append(cls_id)
60 | 
61 | 
62 | def main() -> None:
63 |     args = get_arguments()
64 | 
65 |     cls2id_map = get_cls2id_map()
66 | 
67 |     data: Dict[str, Dict[str, List[Union[int, str]]]] = {
68 |         "train": {
69 |             "image_path": [],
70 |             "class_id": [],
71 |             "label": [],
72 |         },
73 |         "val": {
74 |             "image_path": [],
75 |             "class_id": [],
76 |             "label": [],
77 |         },
78 |         "test": {
79 |             "image_path": [],
80 |             "class_id": [],
81 |             "label": [],
82 |         },
83 |     }
84 | 
85 | 
# 各ディレクトリから画像のパスを指定 86 | # train : val : test = 6 : 2 : 2 になるように分割 87 | for cls_name in cls2id_map.keys(): 88 | img_paths = glob.glob(os.path.join(args.dataset_dir, cls_name, "*.jpg")) 89 | 90 | split_data(data, img_paths, cls_name, cls2id_map[cls_name]) 91 | 92 | # list を DataFrame に変換 93 | train_df = pd.DataFrame( 94 | data["train"], 95 | columns=["image_path", "class_id", "label"], 96 | ) 97 | 98 | val_df = pd.DataFrame( 99 | data["val"], 100 | columns=["image_path", "class_id", "label"], 101 | ) 102 | 103 | test_df = pd.DataFrame( 104 | data["test"], 105 | columns=["image_path", "class_id", "label"], 106 | ) 107 | 108 | # 保存ディレクトリがなければ,作成 109 | os.makedirs(args.save_dir, exist_ok=True) 110 | 111 | # 保存 112 | train_df.to_csv(os.path.join(args.save_dir, "train.csv"), index=None) 113 | val_df.to_csv(os.path.join(args.save_dir, "val.csv"), index=None) 114 | test_df.to_csv(os.path.join(args.save_dir, "test.csv"), index=None) 115 | 116 | print("Finished making csv files.") 117 | 118 | 119 | if __name__ == "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /src/libs/config.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from logging import getLogger 3 | from pprint import pformat 4 | from typing import Any, Dict, Tuple 5 | 6 | import yaml 7 | 8 | from .dataset_csv import DATASET_CSVS 9 | 10 | __all__ = ["get_config"] 11 | 12 | logger = getLogger(__name__) 13 | 14 | 15 | @dataclasses.dataclass(frozen=True) 16 | class Config: 17 | """Experimental configuration class.""" 18 | 19 | model: str = "resnet18" 20 | pretrained: bool = True 21 | 22 | # whether you use class weight to calculate cross entropy or not 23 | use_class_weight: bool = True 24 | 25 | batch_size: int = 32 26 | 27 | width: int = 224 28 | height: int = 224 29 | 30 | num_workers: int = 2 31 | max_epoch: int = 50 32 | 33 | learning_rate: float = 0.003 34 | 35 | dataset_name: str = "flower" 36 | 37 | topk: Tuple[int, ...] = (1, 3) 38 | 39 | def __post_init__(self) -> None: 40 | self._type_check() 41 | self._value_check() 42 | 43 | logger.info( 44 | "Experiment Configuration\n" + pformat(dataclasses.asdict(self), width=1) 45 | ) 46 | 47 | def _value_check(self) -> None: 48 | if self.dataset_name not in DATASET_CSVS: 49 | message = ( 50 | f"dataset_name should be selected from {list(DATASET_CSVS.keys())}." 51 | ) 52 | logger.error(message) 53 | raise ValueError(message) 54 | 55 | if self.max_epoch <= 0: 56 | message = "max_epoch must be positive." 57 | logger.error(message) 58 | raise ValueError(message) 59 | 60 | def _type_check(self) -> None: 61 | """Reference: 62 | https://qiita.com/obithree/items/1c2b43ca94e4fbc3aa8d 63 | """ 64 | 65 | _dict = dataclasses.asdict(self) 66 | 67 | for field, field_type in self.__annotations__.items(): 68 | # if you use type annotation class provided by `typing`, 69 | # you should convert it to the type class used in python. 70 | # e.g.) Tuple[int] -> tuple 71 | # https://stackoverflow.com/questions/51171908/extracting-data-from-typing-types 72 | 73 | # check the instance is Tuple or not. 74 | # https://github.com/zalando/connexion/issues/739 75 | if hasattr(field_type, "__origin__"): 76 | # e.g.) Tuple[int].__args__[0] -> `int` 77 | element_type = field_type.__args__[0] 78 | 79 | # e.g.) 
Tuple[int].__origin__ -> `tuple` 80 | field_type = field_type.__origin__ 81 | 82 | self._type_check_element(field, _dict[field], element_type) 83 | 84 | # bool is the subclass of int, 85 | # so need to use `type() is` instead of `isinstance` 86 | if type(_dict[field]) is not field_type: 87 | message = f"The type of '{field}' field is supposed to be {field_type}." 88 | logger.error(message) 89 | raise TypeError(message) 90 | 91 | def _type_check_element( 92 | self, field: str, vals: Tuple[Any], element_type: type 93 | ) -> None: 94 | for val in vals: 95 | if type(val) is not element_type: 96 | message = ( 97 | f"The element of '{field}' field is supposed to be {element_type}." 98 | ) 99 | logger.error(message) 100 | raise TypeError(message) 101 | 102 | 103 | def convert_list2tuple(_dict: Dict[str, Any]) -> Dict[str, Any]: 104 | # cannot use list in dataclass because mutable defaults are not allowed. 105 | for key, val in _dict.items(): 106 | if isinstance(val, list): 107 | _dict[key] = tuple(val) 108 | 109 | logger.debug("converted list to tuple in dictionary.") 110 | return _dict 111 | 112 | 113 | def get_config(config_path: str) -> Config: 114 | with open(config_path, "r") as f: 115 | config_dict = yaml.safe_load(f) 116 | 117 | config_dict = convert_list2tuple(config_dict) 118 | config = Config(**config_dict) 119 | 120 | logger.info("successfully loaded configuration.") 121 | return config 122 | -------------------------------------------------------------------------------- /src/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import datetime 4 | import os 5 | from logging import DEBUG, INFO, basicConfig, getLogger 6 | 7 | import pandas as pd 8 | import torch 9 | from torchvision.transforms import Compose, Normalize, ToTensor 10 | 11 | from libs.class_id_map import get_cls2id_map 12 | from libs.config import get_config 13 | from libs.dataset import get_dataloader 14 | from libs.device import get_device 15 | from libs.helper import evaluate 16 | from libs.loss_fn import get_criterion 17 | from libs.mean_std import get_mean, get_std 18 | from libs.models import get_model 19 | 20 | logger = getLogger(__name__) 21 | 22 | 23 | def get_arguments() -> argparse.Namespace: 24 | """parse all the arguments from command line inteface return a list of 25 | parsed arguments.""" 26 | 27 | parser = argparse.ArgumentParser( 28 | description=""" 29 | train a network for image classification 30 | with Flowers Recognition Dataset 31 | """ 32 | ) 33 | parser.add_argument("config", type=str, help="path of a config file") 34 | parser.add_argument("mode", type=str, help="validation or test") 35 | parser.add_argument( 36 | "--model", 37 | type=str, 38 | default=None, 39 | help="""path to the trained model. If you do not specify, the trained model, 40 | 'best_acc1_model.prm' in result directory will be used.""", 41 | ) 42 | parser.add_argument( 43 | "--debug", 44 | action="store_true", 45 | help="Add --debug option if you want to see debug-level logs.", 46 | ) 47 | 48 | return parser.parse_args() 49 | 50 | 51 | def main() -> None: 52 | args = get_arguments() 53 | 54 | # configuration 55 | config = get_config(args.config) 56 | result_path = os.path.dirname(args.config) 57 | 58 | if args.mode not in ["validation", "test"]: 59 | message = "args.mode is invalid. 
['validation', 'test']" 60 | logger.error(message) 61 | raise ValueError(message) 62 | 63 | # setting logger configuration 64 | logname = os.path.join( 65 | result_path, f"{datetime.datetime.now():%Y-%m-%d}_{args.mode}.log" 66 | ) 67 | basicConfig( 68 | level=DEBUG if args.debug else INFO, 69 | format="[%(asctime)s] %(name)s %(levelname)s: %(message)s", 70 | datefmt="%Y-%m-%d %H:%M:%S", 71 | filename=logname, 72 | ) 73 | 74 | # cpu or cuda 75 | device = get_device(allow_only_gpu=True) 76 | 77 | # Dataloader 78 | transform = Compose([ToTensor(), Normalize(mean=get_mean(), std=get_std())]) 79 | 80 | loader = get_dataloader( 81 | config.dataset_name, 82 | "val" if args.mode == "validation" else "test", 83 | batch_size=1, 84 | shuffle=False, 85 | num_workers=config.num_workers, 86 | pin_memory=True, 87 | transform=transform, 88 | ) 89 | 90 | # the number of classes 91 | n_classes = len(get_cls2id_map()) 92 | 93 | model = get_model(config.model, n_classes, pretrained=config.pretrained) 94 | 95 | # send the model to cuda/cpu 96 | model.to(device) 97 | 98 | # load the state dict of the model 99 | if args.model is not None: 100 | state_dict = torch.load(args.model) 101 | else: 102 | state_dict = torch.load(os.path.join(result_path, "best_model.prm")) 103 | 104 | model.load_state_dict(state_dict) 105 | 106 | # criterion for loss 107 | criterion = get_criterion(config.use_class_weight, config.dataset_name, device) 108 | 109 | # train and validate model 110 | logger.info(f"---------- Start evaluation for {args.mode} data ----------") 111 | 112 | # evaluation 113 | loss, acc1, f1s, c_matrix = evaluate(loader, model, criterion, device) 114 | 115 | logger.info("loss: {:.5f}\tacc1: {:.2f}\tF1 Score: {:.2f}".format(loss, acc1, f1s)) 116 | 117 | df = pd.DataFrame( 118 | {"loss": [loss], "acc@1": [acc1], "f1score": [f1s]}, 119 | columns=["loss", "acc@1", "f1score"], 120 | index=None, 121 | ) 122 | 123 | df.to_csv(os.path.join(result_path, "{}_log.csv").format(args.mode), index=False) 124 | 125 | with open( 126 | os.path.join(result_path, "{}_c_matrix.csv").format(args.mode), "w" 127 | ) as file: 128 | writer = csv.writer(file, lineterminator="\n") 129 | writer.writerows(c_matrix) 130 | 131 | logger.info("Done.") 132 | 133 | 134 | if __name__ == "__main__": 135 | main() 136 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch Template 2 | 3 | ![status-badge](https://github.com/yiskw713/pytorch_template/workflows/mypy_pytest/badge.svg) 4 | 5 | project for pytorch implementation example of image classification 6 | 7 | ## Requirements 8 | 9 | * python >= 3.7 10 | * pytorch >= 1.0 11 | * pyyaml 12 | * scikit-learn 13 | * [wandb](https://wandb.ai/) 14 | * [pre-commit](https://pre-commit.com/) (for pre-commit formatting, type check and testing) 15 | * [hiddenlayer](https://github.com/waleedka/hiddenlayer) 16 | * [graphviz](https://graphviz.gitlab.io/download/) 17 | * [python wrapper for graphviz](https://github.com/xflr6/graphviz) 18 | 19 | Please run `poetry install` to install the necessary packages. 20 | 21 | You can also setup the environment using docker and docker-compose. 22 | 23 | ## Dataset 24 | 25 | Flowers Recognition Dataset 26 | Download the dataset from [HERE](https://www.kaggle.com/alxmamaev/flowers-recognition/download). 27 | 28 | ## Directory Structure 29 | 30 | ```Directory Structure 31 | . 
32 | ├── docs/ 33 | ├── LICENSE 34 | ├── README.md 35 | ├── dataset/ 36 | │ └── flowers/ 37 | ├── pyproject.toml 38 | ├── .gitignore 39 | ├── .gitattributes 40 | ├── .pre-commit-config.yaml 41 | ├── poetry.lock 42 | ├── docker-compose.yaml 43 | ├── Dockerfile 44 | ├── tests/ 45 | └── src/ 46 | ├── csv 47 | ├── libs/ 48 | ├── utils 49 | ├── notebook/ 50 | ├── result/ 51 | ├── scripts/ 52 | │ └── experiment.sh 53 | ├── train.py 54 | └── evaluate.py 55 | ``` 56 | 57 | ## Features 58 | 59 | * configuration class using `dataclasses.dataclass` (`libs/config.py`) 60 | * type check. 61 | * detection of unnecessary / extra parameters in a specified configuration. 62 | * `dataclass` is an immutable object, 63 | which prevents the setting from being changed by mistake. 64 | * automatically generating configuration files (`utils/make_configs.py`) 65 | * e.g.) run this command 66 | 67 | ```bash 68 | python utils/make_configs.py --model resnet18 resnet30 resnet50 --learning_rate 0.001 0.0001 --dataset_name flower 69 | ``` 70 | 71 | then you can get all of the combinations with `model` and `learning_rate` (total 6 config files), 72 | while the other parameters are set by default as described in `libs/config.py`. 73 | 74 | You can choose which data you use in experiment by specifying `dataset_name`. 75 | The lists of data for training, validation and testing are saved as csv files. 76 | You can see the paths to them in `libs/dataset_csv.py` and get them corresponding to `dataset_name`. 77 | If you want to use another dataset, please add csv files and the paths in `DATASET_CSVS` in `libs/dataset_csv.py`. 78 | 79 | You can also set tuple object parameters in configs like the below. 80 | 81 | ```bash 82 | python utils/make_configs.py --model resnet18 --topk 1 3 --topk 1 3 5 83 | ``` 84 | 85 | By running this, you can get two configurations, 86 | in one of which topk parameter is (1, 3) 87 | and in the other topk parameter is (1, 3, 5). 88 | * running all the experiments by running shell scripts (`scripts/experiment.sh`) 89 | * support type annotation (`typing`) 90 | * code formatting with `black`, `isort` and `flake8` 91 | * visualize model for debug using [`hiddenlayer`](https://github.com/waleedka/hiddenlayer) (`src/utils/visualize_model.py`) 92 | 93 | ## Experiment 94 | 95 | Please see `scripts/experiment.sh` for the detail. 96 | You can set configurations and run all the experiments by the below command. 
97 | 98 | ```sh 99 | sh scripts/experiment.sh 100 | ``` 101 | 102 | ### Setup dependencies 103 | 104 | If you use local environment, then run 105 | 106 | ```sh 107 | poetry install 108 | ``` 109 | 110 | If you use docker, then run 111 | 112 | ```sh 113 | docker-compose up -d --build 114 | docker-compose run mlserver bash 115 | ``` 116 | 117 | ### training 118 | 119 | ```sh 120 | python train.py ./result/xxxx/config.yaml 121 | ``` 122 | 123 | ### evaluation 124 | 125 | ```shell 126 | python evaluate.py ./result/xxxx/config.yaml validation 127 | python evaluate.py ./result/xxxx/config.yaml test 128 | ``` 129 | 130 | ### Model visualization 131 | 132 | ```shell 133 | python utils/visualize_model.py MODEL_NAME 134 | ``` 135 | 136 | ## Formatting 137 | 138 | * black 139 | * flake8 140 | * isort 141 | 142 | ## TODO 143 | 144 | * [x] pytorch implementation of image classification 145 | * [x] configuration class using `dataclasses.dataclass` 146 | * [x] auto generation of config yaml files 147 | * [x] shell script to run all the experiment 148 | * [x] support `typing` (type annotation) 149 | * [x] test code (run testing with pre-commit check) 150 | * [x] `mypy` (pre-commit check) 151 | * [x] formatting (pre-commit `isort`, `black` and `flake8`) 152 | * [x] calculate cyclomatic complexity / expression complexity / cognitive complexity (`flake8` extension) 153 | * [x] CI for testing using GitHub Actions 154 | * [x] visualization of models 155 | * [x] add Dockerfile and docker-compose.yaml 156 | 157 | ## License 158 | 159 | This repository is released under the [MIT License](./LICENSE) 160 | -------------------------------------------------------------------------------- /src/libs/helper.py: -------------------------------------------------------------------------------- 1 | import time 2 | from logging import getLogger 3 | from typing import Any, Dict, Optional, Tuple 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from sklearn.metrics import confusion_matrix, f1_score 10 | from torch.utils.data import DataLoader 11 | 12 | from .meter import AverageMeter, ProgressMeter 13 | from .metric import calc_accuracy 14 | 15 | __all__ = ["train", "evaluate"] 16 | 17 | logger = getLogger(__name__) 18 | 19 | 20 | def do_one_iteration( 21 | sample: Dict[str, Any], 22 | model: nn.Module, 23 | criterion: Any, 24 | device: str, 25 | iter_type: str, 26 | optimizer: Optional[optim.Optimizer] = None, 27 | ) -> Tuple[int, float, float, np.ndarray, np.ndarray]: 28 | 29 | if iter_type not in ["train", "evaluate"]: 30 | message = "iter_type must be either 'train' or 'evaluate'." 31 | logger.error(message) 32 | raise ValueError(message) 33 | 34 | if iter_type == "train" and optimizer is None: 35 | message = "optimizer must be set during training." 
36 | logger.error(message) 37 | raise ValueError(message) 38 | 39 | x = sample["img"].to(device) 40 | t = sample["class_id"].to(device) 41 | 42 | batch_size = x.shape[0] 43 | 44 | # compute output and loss 45 | output = model(x) 46 | loss = criterion(output, t) 47 | 48 | # measure accuracy and record loss 49 | accs = calc_accuracy(output, t, topk=(1,)) 50 | acc1 = accs[0] 51 | 52 | # keep predicted results and gts for calculate F1 Score 53 | _, pred = output.max(dim=1) 54 | gt = t.to("cpu").numpy() 55 | pred = pred.to("cpu").numpy() 56 | 57 | if iter_type == "train" and optimizer is not None: 58 | # compute gradient and do SGD step 59 | optimizer.zero_grad() 60 | loss.backward() 61 | optimizer.step() 62 | 63 | return batch_size, loss.item(), acc1, gt, pred 64 | 65 | 66 | def train( 67 | loader: DataLoader, 68 | model: nn.Module, 69 | criterion: Any, 70 | optimizer: optim.Optimizer, 71 | epoch: int, 72 | device: str, 73 | interval_of_progress: int = 50, 74 | ) -> Tuple[float, float, float]: 75 | 76 | batch_time = AverageMeter("Time", ":6.3f") 77 | data_time = AverageMeter("Data", ":6.3f") 78 | losses = AverageMeter("Loss", ":.4e") 79 | top1 = AverageMeter("Acc@1", ":6.2f") 80 | 81 | progress = ProgressMeter( 82 | len(loader), 83 | [batch_time, data_time, losses, top1], 84 | prefix="Epoch: [{}]".format(epoch), 85 | ) 86 | 87 | # keep predicted results and gts for calculate F1 Score 88 | gts = [] 89 | preds = [] 90 | 91 | # switch to train mode 92 | model.train() 93 | 94 | end = time.time() 95 | for i, sample in enumerate(loader): 96 | # measure data loading time 97 | data_time.update(time.time() - end) 98 | 99 | batch_size, loss, acc1, gt, pred = do_one_iteration( 100 | sample, model, criterion, device, "train", optimizer 101 | ) 102 | 103 | losses.update(loss, batch_size) 104 | top1.update(acc1, batch_size) 105 | 106 | # save the ground truths and predictions in lists 107 | gts += list(gt) 108 | preds += list(pred) 109 | 110 | # measure elapsed time 111 | batch_time.update(time.time() - end) 112 | end = time.time() 113 | 114 | # show progress bar per 50 iteration 115 | if i != 0 and i % interval_of_progress == 0: 116 | progress.display(i) 117 | 118 | # calculate F1 Score 119 | f1s = f1_score(gts, preds, average="macro") 120 | 121 | return losses.get_average(), top1.get_average(), f1s 122 | 123 | 124 | def evaluate( 125 | loader: DataLoader, model: nn.Module, criterion: Any, device: str 126 | ) -> Tuple[float, float, float, np.ndarray]: 127 | losses = AverageMeter("Loss", ":.4e") 128 | top1 = AverageMeter("Acc@1", ":6.2f") 129 | 130 | # keep predicted results and gts for calculate F1 Score 131 | gts = [] 132 | preds = [] 133 | 134 | # calculate confusion matrix 135 | n_classes = loader.dataset.get_n_classes() 136 | c_matrix = np.zeros((n_classes, n_classes), dtype=np.int32) 137 | 138 | # switch to evaluate mode 139 | model.eval() 140 | 141 | with torch.no_grad(): 142 | for sample in loader: 143 | batch_size, loss, acc1, gt, pred = do_one_iteration( 144 | sample, model, criterion, device, "evaluate" 145 | ) 146 | 147 | losses.update(loss, batch_size) 148 | top1.update(acc1, batch_size) 149 | 150 | # keep predicted results and gts for calculate F1 Score 151 | gts += list(gt) 152 | preds += list(pred) 153 | 154 | c_matrix += confusion_matrix( 155 | gt, 156 | pred, 157 | labels=[i for i in range(n_classes)], 158 | ) 159 | 160 | f1s = f1_score(gts, preds, average="macro") 161 | 162 | return losses.get_average(), top1.get_average(), f1s, c_matrix 163 | 
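# Usage sketch (illustrative; the real entry points are src/train.py and
# src/evaluate.py, which build the loaders, model, criterion and optimizer
# with the surrounding libs):
#
#     train_loss, train_acc1, train_f1s = train(
#         train_loader, model, criterion, optimizer, epoch, device
#     )
#     val_loss, val_acc1, val_f1s, c_matrix = evaluate(
#         val_loader, model, criterion, device
#     )
#
# `c_matrix` is an (n_classes, n_classes) confusion matrix accumulated over
# the whole loader, so it can be written out to csv afterwards.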
-------------------------------------------------------------------------------- /tests/test_helper.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | import torch.optim as optim 7 | from pytest_mock import MockFixture 8 | from torchvision import transforms 9 | 10 | from src.libs.dataset import get_dataloader 11 | from src.libs.helper import do_one_iteration, evaluate, train 12 | from src.libs.loss_fn import get_criterion 13 | from src.libs.models import get_model 14 | 15 | 16 | @pytest.fixture() 17 | def sample(): 18 | img = torch.randn(2, 3, 112, 112) 19 | class_id = torch.tensor([0, 1]).long() 20 | label = ["daisy", "dandelion"] 21 | 22 | return {"img": img, "class_id": class_id, "label": label} 23 | 24 | 25 | @pytest.fixture() 26 | def model_optimizer(): 27 | model = get_model("resnet18", 5) 28 | optimizer = optim.Adam(model.parameters(), lr=0.0003) 29 | return (model, optimizer) 30 | 31 | 32 | @pytest.fixture() 33 | def criterion(): 34 | return get_criterion() 35 | 36 | 37 | def test_do_one_iteration1(sample, model_optimizer, criterion): 38 | # check iteration for training 39 | model, optimizer = model_optimizer 40 | original_model = copy.deepcopy(model) 41 | 42 | batch_size, loss, acc1, gt, pred = do_one_iteration( 43 | sample, model, criterion, "cpu", "train", optimizer 44 | ) 45 | 46 | assert batch_size == 2 47 | assert loss > 0 48 | assert 0 <= acc1 <= 100.0 49 | assert np.all(gt == np.array([0, 1])) 50 | assert pred.shape == (2,) 51 | 52 | # check if models have the same weights 53 | # https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351 54 | for key_item1, key_item2 in zip( 55 | model.state_dict().items(), original_model.state_dict().items() 56 | ): 57 | # if the weights are completely identical, training does not work. 58 | assert not torch.equal(key_item1[1], key_item2[1]) 59 | 60 | 61 | def test_do_one_iteration2(sample, model_optimizer, criterion): 62 | # check iteration for evaluation 63 | model, optimizer = model_optimizer 64 | original_model = copy.deepcopy(model) 65 | 66 | model.eval() 67 | batch_size, loss, acc1, gt, pred = do_one_iteration( 68 | sample, model, criterion, "cpu", "evaluate" 69 | ) 70 | 71 | assert batch_size == 2 72 | assert loss > 0 73 | assert 0 <= acc1 <= 100.0 74 | assert np.all(gt == np.array([0, 1])) 75 | assert pred.shape == (2,) 76 | 77 | # check if models have the same weights 78 | # https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351 79 | for key_item1, key_item2 in zip( 80 | model.state_dict().items(), original_model.state_dict().items() 81 | ): 82 | # if the weights are completely identical, training does not work. 
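# (in the evaluation iteration no optimizer step runs, so here the weights are expected to stay identical)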
83 | assert torch.equal(key_item1[1], key_item2[1]) 84 | 85 | 86 | def test_do_one_iteration3(sample, model_optimizer, criterion): 87 | model, optimizer = model_optimizer 88 | with pytest.raises(ValueError): 89 | do_one_iteration(sample, model, criterion, "cpu", "test") 90 | 91 | with pytest.raises(ValueError): 92 | do_one_iteration(sample, model, criterion, "cpu", "train") 93 | 94 | 95 | def test_train(mocker: MockFixture, model_optimizer, criterion): 96 | model, optimizer = model_optimizer 97 | 98 | mocker.patch("src.libs.helper.do_one_iteration").return_value = ( 99 | 2, 100 | 0.1, 101 | 50.0, 102 | np.array([0, 1]), 103 | np.array([1, 1]), 104 | ) 105 | 106 | loader = get_dataloader( 107 | "pytest", 108 | "train", 109 | batch_size=2, 110 | shuffle=False, 111 | num_workers=1, 112 | pin_memory=False, 113 | drop_last=True, 114 | transform=transforms.Compose( 115 | [ 116 | transforms.Resize((224, 224)), 117 | transforms.ToTensor(), 118 | ] 119 | ), 120 | ) 121 | 122 | # make small dataset 123 | loader.dataset.df = loader.dataset.df[:10] 124 | 125 | loss, acc1, f1s = train( 126 | loader, model, criterion, optimizer, 0, "cpu", interval_of_progress=1 127 | ) 128 | 129 | assert model.training 130 | assert loss == 0.1 131 | assert acc1 == 50.0 132 | assert 0 <= f1s <= 1.0 133 | 134 | 135 | def test_evaluate(mocker: MockFixture, model_optimizer, criterion): 136 | model, _ = model_optimizer 137 | 138 | mocker.patch("src.libs.helper.do_one_iteration").return_value = ( 139 | 2, 140 | 0.1, 141 | 50.0, 142 | np.array([0, 1]), 143 | np.array([1, 1]), 144 | ) 145 | 146 | loader = get_dataloader( 147 | "pytest", 148 | "test", 149 | batch_size=2, 150 | shuffle=False, 151 | num_workers=1, 152 | pin_memory=False, 153 | drop_last=False, 154 | transform=transforms.Compose( 155 | [ 156 | transforms.Resize((224, 224)), 157 | transforms.ToTensor(), 158 | ] 159 | ), 160 | ) 161 | 162 | # make small dataset 163 | loader.dataset.df = loader.dataset.df[:10] 164 | n_classes = loader.dataset.get_n_classes() 165 | 166 | loss, acc1, f1s, c_matrix = evaluate(loader, model, criterion, "cpu") 167 | 168 | assert not model.training 169 | assert loss == 0.1 170 | assert acc1 == 50.0 171 | assert 0 <= f1s <= 1.0 172 | assert c_matrix.shape == (n_classes, n_classes) 173 | -------------------------------------------------------------------------------- /src/util_scripts/make_configs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dataclasses 3 | import itertools 4 | import os 5 | import sys 6 | 7 | import yaml 8 | 9 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 10 | 11 | 12 | from typing import Any, Dict, List, Tuple 13 | 14 | from libs.config import Config 15 | 16 | 17 | def str2bool(val: str) -> bool: 18 | if isinstance(val, bool): 19 | return val 20 | if val.lower() in ("yes", "true", "t", "y", "1"): 21 | return True 22 | elif val.lower() in ("no", "false", "f", "n", "0"): 23 | return False 24 | else: 25 | raise argparse.ArgumentTypeError("Boolean value expected.") 26 | 27 | 28 | def get_arguments() -> argparse.Namespace: 29 | """parse all the arguments from command line inteface return a list of 30 | parsed arguments.""" 31 | 32 | parser = argparse.ArgumentParser(description="make configuration yaml files.") 33 | 34 | parser.add_argument( 35 | "--root_dir", 36 | type=str, 37 | default="./result", 38 | help="path to a directory where you want to make config files and directories.", 39 | ) 40 | 41 | fields = dataclasses.fields(Config) 42 | 
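# Each Config field is exposed as a command line option below: bool fields go
# through str2bool, tuple-typed fields (e.g. topk) take repeated groups via
# action="append" / nargs="+", and fields without a default value are required.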
43 | for field in fields: 44 | type_func = str2bool if field.type is bool else field.type 45 | 46 | if isinstance(field.default, dataclasses._MISSING_TYPE): 47 | # default value is not set. 48 | # do not specify boolean type in argparse 49 | # ref: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 50 | parser.add_argument( 51 | f"--{field.name}", 52 | type=type_func, 53 | nargs="*", 54 | required=True, 55 | ) 56 | elif hasattr(field.type, "__origin__"): 57 | # the field type is Tuple or not. 58 | # https://github.com/zalando/connexion/issues/739 59 | parser.add_argument( 60 | f"--{field.name}", 61 | type=field.type.__args__[0], 62 | action="append", 63 | nargs="+", 64 | default=[list(field.default)], 65 | ) 66 | else: 67 | # default value is provided in config dataclass. 68 | parser.add_argument( 69 | f"--{field.name}", 70 | type=type_func, 71 | nargs="*", 72 | default=field.default, 73 | ) 74 | 75 | return parser.parse_args() 76 | 77 | 78 | def convert_tuple2list(_dict: Dict[str, Any]) -> Dict[str, Any]: 79 | # cannot use tuple in yaml file for safe loading. 80 | for key, val in _dict.items(): 81 | if isinstance(val, tuple): 82 | _dict[key] = tuple(val) 83 | 84 | return _dict 85 | 86 | 87 | def parse_params( 88 | args_dict: Dict[str, Any] 89 | ) -> Tuple[Dict[str, Any], List[str], List[List[Any]]]: 90 | 91 | base_config = {} 92 | variable_keys = [] 93 | variable_values = [] 94 | 95 | for k, v in args_dict.items(): 96 | if isinstance(v, list): 97 | variable_keys.append(k) 98 | variable_values.append(v) 99 | else: 100 | base_config[k] = v 101 | 102 | return base_config, variable_keys, variable_values 103 | 104 | 105 | def get_n_options( 106 | variable_keys: List[str], variable_values: List[List[Any]] 107 | ) -> Dict[str, int]: 108 | cnt = {} 109 | for k, v in zip(variable_keys, variable_values): 110 | cnt[k] = len(v) 111 | 112 | return cnt 113 | 114 | 115 | def generate_and_save_config( 116 | base_config: Dict[str, Any], 117 | variable_keys: List[str], 118 | values: Tuple[Any], 119 | root_dir: str, 120 | n_options_dict: Dict[str, int], 121 | ) -> None: 122 | config = base_config.copy() 123 | param_list = [] 124 | for k, v in zip(variable_keys, values): 125 | config[k] = v 126 | 127 | if n_options_dict[k] == 1: 128 | continue 129 | else: 130 | param_list.append(f"{k}={v}") 131 | 132 | dir_name = "-".join(param_list) 133 | dir_path = os.path.join(root_dir, dir_name) 134 | 135 | if not os.path.exists(dir_path): 136 | os.makedirs(dir_path) 137 | 138 | config_path = os.path.join(dir_path, "config.yaml") 139 | 140 | # save configuration file as yaml 141 | with open(config_path, "w") as f: 142 | yaml.dump(config, f, default_flow_style=False) 143 | 144 | 145 | def main() -> None: 146 | args = get_arguments() 147 | 148 | # convert Namespace to dictionary. 149 | args_dict = vars(args).copy() 150 | del args_dict["root_dir"] 151 | 152 | base_config, variable_keys, variable_values = parse_params(args_dict) 153 | 154 | # base_config may contain tuple object and they should be converted. 155 | base_config = convert_tuple2list(base_config) 156 | 157 | # get direct product 158 | product = itertools.product(*variable_values) 159 | 160 | # get the number of options for each key. 161 | n_options_dict = get_n_options(variable_keys, variable_values) 162 | 163 | # make a directory and save configuration file there. 
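# e.g.) --model resnet18 resnet34 --learning_rate 0.01 0.001 expands to the
# direct product (4 configs); each one is saved to
# <root_dir>/<key>=<value>-<key>=<value>/config.yaml, and keys that have only
# one option are left out of the directory name.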
164 | for values in product: 165 | generate_and_save_config( 166 | base_config, variable_keys, values, args.root_dir, n_options_dict 167 | ) 168 | 169 | print("Finished making configuration files.") 170 | 171 | 172 | if __name__ == "__main__": 173 | main() 174 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os 4 | import time 5 | from logging import DEBUG, INFO, basicConfig, getLogger 6 | 7 | import torch 8 | import torch.optim as optim 9 | import wandb 10 | from torchvision.transforms import ( 11 | ColorJitter, 12 | Compose, 13 | Normalize, 14 | RandomHorizontalFlip, 15 | RandomResizedCrop, 16 | ToTensor, 17 | ) 18 | 19 | from libs.checkpoint import resume, save_checkpoint 20 | from libs.class_id_map import get_cls2id_map 21 | from libs.config import get_config 22 | from libs.dataset import get_dataloader 23 | from libs.device import get_device 24 | from libs.helper import evaluate, train 25 | from libs.logger import TrainLogger 26 | from libs.loss_fn import get_criterion 27 | from libs.mean_std import get_mean, get_std 28 | from libs.models import get_model 29 | from libs.seed import set_seed 30 | 31 | logger = getLogger(__name__) 32 | 33 | 34 | def get_arguments() -> argparse.Namespace: 35 | """parse all the arguments from command line inteface return a list of 36 | parsed arguments.""" 37 | 38 | parser = argparse.ArgumentParser( 39 | description=""" 40 | train a network for image classification with Flowers Recognition Dataset. 41 | """ 42 | ) 43 | parser.add_argument("config", type=str, help="path of a config file") 44 | parser.add_argument( 45 | "--resume", 46 | action="store_true", 47 | help="Add --resume option if you start training from checkpoint.", 48 | ) 49 | parser.add_argument( 50 | "--use_wandb", 51 | action="store_true", 52 | help="Add --use_wandb option if you want to use wandb.", 53 | ) 54 | parser.add_argument( 55 | "--debug", 56 | action="store_true", 57 | help="Add --debug option if you want to see debug-level logs.", 58 | ) 59 | parser.add_argument( 60 | "--seed", 61 | type=int, 62 | default=0, 63 | help="random seed", 64 | ) 65 | 66 | return parser.parse_args() 67 | 68 | 69 | def main() -> None: 70 | args = get_arguments() 71 | 72 | # save log files in the directory which contains config file. 
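# e.g.) for ./result/model=resnet18-learning_rate=0.01/config.yaml,
# result_path is ./result/model=resnet18-learning_rate=0.01 and the experiment
# name passed to wandb is 'model=resnet18-learning_rate=0.01'.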
73 | result_path = os.path.dirname(args.config) 74 | experiment_name = os.path.basename(result_path) 75 | 76 | # setting logger configuration 77 | logname = os.path.join(result_path, f"{datetime.datetime.now():%Y-%m-%d}_train.log") 78 | basicConfig( 79 | level=DEBUG if args.debug else INFO, 80 | format="[%(asctime)s] %(name)s %(levelname)s: %(message)s", 81 | datefmt="%Y-%m-%d %H:%M:%S", 82 | filename=logname, 83 | ) 84 | 85 | # fix seed 86 | set_seed() 87 | 88 | # configuration 89 | config = get_config(args.config) 90 | 91 | # cpu or cuda 92 | device = get_device(allow_only_gpu=False) 93 | 94 | # Dataloader 95 | train_transform = Compose( 96 | [ 97 | RandomResizedCrop(size=(config.height, config.width)), 98 | RandomHorizontalFlip(), 99 | ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1), 100 | ToTensor(), 101 | Normalize(mean=get_mean(), std=get_std()), 102 | ] 103 | ) 104 | 105 | val_transform = Compose([ToTensor(), Normalize(mean=get_mean(), std=get_std())]) 106 | 107 | train_loader = get_dataloader( 108 | config.dataset_name, 109 | "train", 110 | batch_size=config.batch_size, 111 | shuffle=True, 112 | num_workers=config.num_workers, 113 | pin_memory=True, 114 | drop_last=True, 115 | transform=train_transform, 116 | ) 117 | 118 | val_loader = get_dataloader( 119 | config.dataset_name, 120 | "val", 121 | batch_size=1, 122 | shuffle=False, 123 | num_workers=config.num_workers, 124 | pin_memory=True, 125 | transform=val_transform, 126 | ) 127 | 128 | # the number of classes 129 | n_classes = len(get_cls2id_map()) 130 | 131 | # define a model 132 | model = get_model(config.model, n_classes, pretrained=config.pretrained) 133 | 134 | # send the model to cuda/cpu 135 | model.to(device) 136 | 137 | optimizer = optim.Adam(model.parameters(), lr=config.learning_rate) 138 | 139 | # keep training and validation log 140 | begin_epoch = 0 141 | best_loss = float("inf") 142 | 143 | # resume if you want 144 | if args.resume: 145 | resume_path = os.path.join(result_path, "checkpoint.pth") 146 | begin_epoch, model, optimizer, best_loss = resume(resume_path, model, optimizer) 147 | 148 | log_path = os.path.join(result_path, "log.csv") 149 | train_logger = TrainLogger(log_path, resume=args.resume) 150 | 151 | # criterion for loss 152 | criterion = get_criterion(config.use_class_weight, config.dataset_name, device) 153 | 154 | # Weights and biases 155 | if args.use_wandb: 156 | wandb.init( 157 | name=experiment_name, 158 | config=config, 159 | project="image_classification_template", 160 | job_type="training", 161 | dirs="./wandb_result/", 162 | ) 163 | # Magic 164 | wandb.watch(model, log="all") 165 | 166 | # train and validate model 167 | logger.info("Start training.") 168 | 169 | for epoch in range(begin_epoch, config.max_epoch): 170 | # training 171 | start = time.time() 172 | train_loss, train_acc1, train_f1s = train( 173 | train_loader, model, criterion, optimizer, epoch, device 174 | ) 175 | train_time = int(time.time() - start) 176 | 177 | # validation 178 | start = time.time() 179 | val_loss, val_acc1, val_f1s, c_matrix = evaluate( 180 | val_loader, model, criterion, device 181 | ) 182 | val_time = int(time.time() - start) 183 | 184 | # save a model if top1 acc is higher than ever 185 | if best_loss > val_loss: 186 | best_loss = val_loss 187 | torch.save( 188 | model.state_dict(), 189 | os.path.join(result_path, "best_model.prm"), 190 | ) 191 | 192 | # save checkpoint every epoch 193 | save_checkpoint(result_path, epoch, model, optimizer, best_loss) 194 | 195 | # write logs to 
dataframe and csv file 196 | train_logger.update( 197 | epoch, 198 | optimizer.param_groups[0]["lr"], 199 | train_time, 200 | train_loss, 201 | train_acc1, 202 | train_f1s, 203 | val_time, 204 | val_loss, 205 | val_acc1, 206 | val_f1s, 207 | ) 208 | 209 | # save logs to wandb 210 | if args.use_wandb: 211 | wandb.log( 212 | { 213 | "lr": optimizer.param_groups[0]["lr"], 214 | "train_time[sec]": train_time, 215 | "train_loss": train_loss, 216 | "train_acc@1": train_acc1, 217 | "train_f1s": train_f1s, 218 | "val_time[sec]": val_time, 219 | "val_loss": val_loss, 220 | "val_acc@1": val_acc1, 221 | "val_f1s": val_f1s, 222 | }, 223 | step=epoch, 224 | ) 225 | 226 | # save models 227 | torch.save(model.state_dict(), os.path.join(result_path, "final_model.prm")) 228 | 229 | # delete checkpoint 230 | os.remove(os.path.join(result_path, "checkpoint.pth")) 231 | 232 | logger.info("Done") 233 | 234 | 235 | if __name__ == "__main__": 236 | main() 237 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | anyio==3.5.0; python_full_version >= "3.6.2" and python_version >= "3.7" 2 | appnope==0.1.2; sys_platform == "darwin" and python_version >= "3.8" and platform_system == "Darwin" 3 | argon2-cffi-bindings==21.2.0; python_version >= "3.7" 4 | argon2-cffi==21.3.0; python_version >= "3.7" 5 | asttokens==2.0.5; python_version >= "3.8" 6 | atomicwrites==1.4.0; python_version >= "3.7" and python_full_version < "3.0.0" and sys_platform == "win32" or sys_platform == "win32" and python_version >= "3.7" and python_full_version >= "3.4.0" 7 | attrs==21.4.0; python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" and python_version >= "3.7" 8 | babel==2.9.1; python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.4.0" and python_version >= "3.7" 9 | backcall==0.2.0; python_version >= "3.8" 10 | bandit==1.7.4; python_version >= "3.7" 11 | black==22.1.0; python_full_version >= "3.6.2" 12 | bleach==4.1.0; python_version >= "3.7" 13 | certifi==2021.10.8; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.6.0" and python_version >= "3.6" 14 | cffi==1.15.0; implementation_name == "pypy" and python_version >= "3.7" and python_full_version >= "3.6.1" 15 | charset-normalizer==2.0.12; python_full_version >= "3.6.0" and python_version >= "3.6" 16 | click==8.0.4; python_version >= "3.6" and python_full_version >= "3.6.2" 17 | colorama==0.4.4; sys_platform == "win32" and python_version >= "3.7" and python_full_version >= "3.6.2" and platform_system == "Windows" and (python_version >= "3.7" and python_full_version < "3.0.0" and sys_platform == "win32" or sys_platform == "win32" and python_version >= "3.7" and python_full_version >= "3.5.0") and (python_version >= "3.8" and python_full_version < "3.0.0" and sys_platform == "win32" or sys_platform == "win32" and python_version >= "3.8" and python_full_version >= "3.5.0") 18 | coverage==6.3.2; python_version >= "3.7" 19 | debugpy==1.5.1; python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" and python_version >= "3.7" 20 | decorator==5.1.1; python_version >= "3.8" 21 | defusedxml==0.7.1; python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" and python_version >= "3.7" 22 | docker-pycreds==0.4.0; python_version >= "3.6" 23 | entrypoints==0.4; python_full_version >= "3.6.1" and 
python_version >= "3.7" 24 | executing==0.8.3; python_version >= "3.8" 25 | flake8==4.0.1; python_version >= "3.6" 26 | gitdb==4.0.9; python_version >= "3.7" 27 | gitpython==3.1.27; python_version >= "3.7" 28 | graphviz==0.19.1; python_version >= "3.6" 29 | hiddenlayer==0.3 30 | idna==3.3; python_full_version >= "3.6.2" and python_version >= "3.7" 31 | iniconfig==1.1.1; python_version >= "3.7" 32 | ipykernel==6.9.1; python_version >= "3.7" 33 | ipython-genutils==0.2.0; python_version >= "3.7" 34 | ipython==8.1.1; python_version >= "3.8" 35 | isort==5.10.1; python_full_version >= "3.6.1" and python_version < "4.0" 36 | jedi==0.18.1; python_version >= "3.8" 37 | jinja2==3.0.3; python_version >= "3.7" 38 | json5==0.9.6; python_version >= "3.7" 39 | jsonschema==4.4.0; python_version >= "3.7" 40 | jupyter-client==7.1.2; python_full_version >= "3.7.0" and python_version >= "3.7" 41 | jupyter-core==4.9.2; python_full_version >= "3.6.1" and python_version >= "3.7" 42 | jupyter-server==1.13.5; python_version >= "3.7" 43 | jupyterlab-pygments==0.1.2; python_version >= "3.7" 44 | jupyterlab-server==2.10.3; python_version >= "3.7" 45 | jupyterlab==3.3.1; python_version >= "3.7" 46 | markupsafe==2.1.0; python_version >= "3.7" 47 | matplotlib-inline==0.1.3; python_version >= "3.8" 48 | mccabe==0.6.1; python_version >= "3.6" 49 | mistune==0.8.4; python_version >= "3.7" 50 | mypy-extensions==0.4.3; python_full_version >= "3.6.2" and python_version >= "3.6" 51 | mypy==0.931; python_version >= "3.6" 52 | nbclassic==0.3.6; python_version >= "3.7" 53 | nbclient==0.5.12; python_full_version >= "3.7.0" and python_version >= "3.7" 54 | nbconvert==6.4.2; python_version >= "3.7" 55 | nbformat==5.1.3; python_full_version >= "3.7.0" and python_version >= "3.7" 56 | nest-asyncio==1.5.4; python_full_version >= "3.7.0" and python_version >= "3.7" 57 | notebook-shim==0.1.0; python_version >= "3.7" 58 | notebook==6.4.8; python_version >= "3.7" 59 | numpy==1.22.3; python_version >= "3.8" 60 | opencv-python==4.5.5.64; python_version >= "3.6" 61 | packaging==21.3; python_version >= "3.7" 62 | pandas==1.4.1; python_version >= "3.8" 63 | pandocfilters==1.5.0; python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.4.0" and python_version >= "3.7" 64 | parso==0.8.3; python_version >= "3.8" 65 | pastel==0.2.1; python_full_version >= "3.6.2" 66 | pathspec==0.9.0; python_full_version >= "3.6.2" 67 | pathtools==0.1.2; python_version >= "3.6" 68 | pbr==5.8.1; python_version >= "3.7" 69 | pexpect==4.8.0; sys_platform != "win32" and python_version >= "3.8" 70 | pickleshare==0.7.5; python_version >= "3.8" 71 | pillow==9.0.1; python_version >= "3.7" 72 | platformdirs==2.5.1; python_version >= "3.7" and python_full_version >= "3.6.2" 73 | pluggy==1.0.0; python_version >= "3.7" 74 | poethepoet==0.13.1; python_full_version >= "3.6.2" 75 | prometheus-client==0.13.1; python_version >= "3.7" 76 | promise==2.3; python_version >= "3.6" 77 | prompt-toolkit==3.0.28; python_full_version >= "3.6.2" and python_version >= "3.8" 78 | protobuf==3.19.4; python_version >= "3.6" 79 | psutil==5.9.0; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.4.0" and python_version >= "3.6" 80 | ptyprocess==0.7.0; sys_platform != "win32" and python_version >= "3.8" and os_name != "nt" 81 | pure-eval==0.2.2; python_version >= "3.8" 82 | py==1.11.0; python_full_version >= "3.6.1" and python_version >= "3.7" and implementation_name == "pypy" 83 | pycodestyle==2.8.0; python_version >= "3.6" and 
python_full_version < "3.0.0" or python_full_version >= "3.5.0" and python_version >= "3.6" 84 | pycparser==2.21; implementation_name == "pypy" and python_version >= "3.7" and python_full_version >= "3.6.1" 85 | pydantic==1.9.0; python_full_version >= "3.6.1" 86 | pyflakes==2.4.0; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.4.0" and python_version >= "3.6" 87 | pygments==2.11.2; python_version >= "3.8" 88 | pyparsing==3.0.7; python_version >= "3.7" 89 | pyrsistent==0.18.1; python_version >= "3.7" 90 | pytest-cov==3.0.0; python_version >= "3.6" 91 | pytest-mock==3.7.0; python_version >= "3.7" 92 | pytest==7.0.1; python_version >= "3.6" 93 | python-dateutil==2.8.2; python_full_version >= "3.6.1" and python_version >= "3.8" 94 | pytz==2021.3; python_version >= "3.8" and python_full_version < "3.0.0" or python_full_version >= "3.4.0" and python_version >= "3.8" 95 | pywin32==303; sys_platform == "win32" and platform_python_implementation != "PyPy" and python_version >= "3.7" 96 | pywinpty==1.1.6; os_name == "nt" and python_version >= "3.7" 97 | pyyaml==6.0; python_version >= "3.6" 98 | pyzmq==22.3.0; python_full_version >= "3.6.1" and python_version >= "3.7" 99 | requests==2.27.1; python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.6.0" and python_version >= "3.7" 100 | send2trash==1.8.0; python_version >= "3.7" 101 | sentry-sdk==1.5.7; python_version >= "3.6" 102 | setproctitle==1.2.2; python_version >= "3.6" 103 | shortuuid==1.0.8; python_version >= "3.6" 104 | six==1.16.0; python_version >= "3.8" and python_full_version < "3.0.0" or python_full_version >= "3.3.0" and python_version >= "3.8" 105 | smmap==5.0.0; python_version >= "3.7" 106 | sniffio==1.2.0; python_full_version >= "3.6.2" and python_version >= "3.7" 107 | stack-data==0.2.0; python_version >= "3.8" 108 | stevedore==3.5.0; python_version >= "3.7" 109 | termcolor==1.1.0; python_full_version >= "3.6.2" and python_full_version < "4.0.0" and python_version >= "3.6" 110 | terminado==0.13.3; python_version >= "3.7" 111 | testpath==0.6.0; python_version >= "3.7" 112 | tomli==2.0.1; python_version >= "3.7" and python_full_version >= "3.6.2" 113 | tornado==6.1; python_full_version >= "3.6.1" and python_version >= "3.7" 114 | traitlets==5.1.1; python_full_version >= "3.7.0" and python_version >= "3.8" 115 | typing-extensions==4.1.1; python_version >= "3.6" and python_full_version >= "3.6.2" and python_version < "3.10" 116 | urllib3==1.26.8; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.6.0" and python_version < "4" and python_version >= "3.6" 117 | wandb==0.12.11; python_version >= "3.6" 118 | wcwidth==0.2.5; python_full_version >= "3.6.2" and python_version >= "3.8" 119 | webencodings==0.5.1; python_version >= "3.7" 120 | websocket-client==1.3.1; python_version >= "3.7" 121 | yaspin==2.1.0; python_full_version >= "3.6.2" and python_full_version < "4.0.0" and python_version >= "3.6" 122 | imageio==2.16.1 123 | imgaug==0.4.0 124 | joblib==1.1.0 125 | matplotlib==3.5.1 126 | scikit-image==0.19.2 127 | scikit-learn==1.0.2 128 | scipy==1.7.3 129 | timm==0.5.4 130 | torch==1.9.0 131 | torchinfo==1.6.3 132 | torchvision==0.10.0 133 | h5py==3.6.0 134 | flake8-bugbear==22.1.11 135 | flake8-builtins==1.5.3 136 | flake8-eradicate==1.2.0 137 | pep8-naming==0.12.1 138 | flake8-expression-complexity==0.0.10 139 | flake8-cognitive-complexity==0.1.0 140 | flake8-pytest-style==1.6.0 
--------------------------------------------------------------------------------