├── .gitignore ├── LICENSE ├── README.md ├── asset ├── CATDOG-LOG-240818-15:22:23_conf_f1.png ├── CATDOG-LOG-240818-15:22:23_iou_score.png ├── CATDOG-LOG-240818-15:22:23_loss.png ├── CATDOG-LOG-240818-15:22:23_map.png ├── CATDOG-LOG-240818-15:22:23_map_50.png ├── Figure_01.png ├── Figure_02.png ├── Figure_03.png ├── Figure_04.png ├── Figure_05.png ├── Figure_06.png ├── Figure_07.png ├── Figure_08.png ├── Figure_11.png ├── Figure_12.png ├── Figure_13.png ├── Figure_14.png ├── Figure_15.png ├── Figure_16.png ├── Figure_21.png ├── Figure_22.png ├── Figure_23.png ├── Figure_24.png ├── Figure_25.png ├── Figure_26.png ├── Figure_31.png ├── Figure_32.png ├── Figure_33.png ├── Figure_34.png ├── Figure_35.png ├── Figure_36.png ├── Figure_37.png ├── Figure_38.png ├── Figure_39.png ├── Figure_40.png ├── Figure_41.png ├── Figure_42.png ├── Figure_43.png ├── Figure_44.png ├── Figure_45.png └── vikit-learn.jpeg ├── doc ├── img │ └── clf │ │ ├── Figure_1.png │ │ ├── __results___6_1.png │ │ ├── __results___6_2.png │ │ ├── __results___6_3.png │ │ ├── __results___6_4.png │ │ └── __results___6_5.png ├── task.md └── tutorials │ └── clf.md ├── requirements.txt ├── setup.py ├── toolkit ├── __init__.py ├── clf_cfg.py └── clf_cli.py └── vklearn ├── __init__.py ├── datasets ├── __init__.py ├── characters.py ├── characters │ ├── ch_sim.txt │ ├── ch_sym_sim.txt │ ├── ch_sym_tra.txt │ ├── ch_tra.txt │ ├── en_sym.txt │ ├── ja.txt │ └── ja_sym.txt ├── cocodetection.py ├── cocopretrain.py ├── hwdb_gnt.py ├── hwdb_pot.py ├── images_folder.py ├── labelme_detection.py ├── labelme_joints.py ├── lsvt_joints.py ├── masksegment.py ├── ms_coco_classnames.json ├── mvtec_screws.py ├── ocr_instruct.py ├── ocr_printing.py ├── ocr_synthesizer.py ├── olhwdb2.py ├── oxford_iiit_pet.py ├── places365.py ├── plain_bbox.py ├── publaynet.py └── vocdetection.py ├── models ├── __init__.py ├── basic.py ├── classifier.py ├── component.py ├── detector.py ├── distiller.py ├── joints.py ├── ocr.py ├── segment.py ├── trimnetclf.py ├── trimnetdet.py ├── trimnetdst.py ├── trimnetjot.py ├── trimnetocr.py ├── trimnetseg.py └── trimnetx.py ├── pipelines ├── __init__.py ├── classifier.py ├── detector.py ├── joints.py ├── ocr.py └── segment.py ├── trainer ├── __init__.py ├── logging.py ├── task.py ├── tasks │ ├── __init__.py │ ├── classification.py │ ├── detection.py │ ├── distillation.py │ ├── joints.py │ ├── ocr.py │ └── segmentation.py └── trainer.py ├── utils ├── __init__.py └── focal_boost.py └── version.py /.gitignore: -------------------------------------------------------------------------------- 1 | # These are some examples of commonly ignored file patterns. 2 | # You should customize this list as applicable to your project. 
3 | # Learn more about .gitignore: 4 | # https://www.atlassian.com/git/tutorials/saving-changes/gitignore 5 | 6 | # Node artifact files 7 | node_modules/ 8 | dist/ 9 | 10 | # Compiled Java class files 11 | *.class 12 | 13 | # Compiled Python bytecode 14 | *.py[cod] 15 | 16 | # Log files 17 | *.log 18 | 19 | # Package files 20 | *.jar 21 | 22 | # Maven 23 | target/ 24 | dist/ 25 | 26 | # JetBrains IDE 27 | .idea/ 28 | 29 | # Unit test reports 30 | TEST*.xml 31 | 32 | # Generated by MacOS 33 | .DS_Store 34 | 35 | # Generated by Windows 36 | Thumbs.db 37 | 38 | # Applications 39 | *.app 40 | *.exe 41 | *.war 42 | 43 | # Large media files 44 | *.mp4 45 | *.tiff 46 | *.avi 47 | *.flv 48 | *.mov 49 | *.wmv 50 | 51 | # pyenv 52 | .python-version 53 | 54 | # pipenv 55 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 56 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 57 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 58 | # install all needed dependencies. 59 | #Pipfile.lock 60 | 61 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 62 | __pypackages__/ 63 | 64 | # Celery stuff 65 | celerybeat-schedule 66 | celerybeat.pid 67 | 68 | # SageMath parsed files 69 | *.sage.py 70 | 71 | # Environments 72 | .env 73 | .venv 74 | env/ 75 | venv/ 76 | ENV/ 77 | env.bak/ 78 | venv.bak/ 79 | 80 | # vim swap 81 | *.swp 82 | 83 | # log dir 84 | logs/ 85 | 86 | # checkpoints 87 | checkpoints/ 88 | 89 | # package builds 90 | /build/ 91 | /vikit_learn.egg-info/ 92 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # vikit-learn 2 | 3 | Vikit-learn is a computer vision processing toolkit developed using Python, based on deep learning technology. 4 | 5 | This package aims to provide a series of easy-to-use tools that can handle real-world tasks. 6 | 7 | The project is still under active construction and development, so please look forward to this work! 
8 | 9 | Current Support: Image Classification, Object Detection, Semantic Segmentation, Keypoint&Joint Detection, OCR 10 | 11 | ![](./asset/vikit-learn.jpeg) 12 | 13 | ## Installation 14 | 15 | ### Dependencies 16 | 17 | - matplotlib>=3.7.5 18 | - torch>=2.4.0 19 | - torchvision>=0.19.0 20 | - torchmetrics>=1.4.2 21 | - lightning-utilities>=0.11.7 22 | - faster-coco-eval>=1.6.0 23 | - pycocotools>=2.0.7 24 | - opencv-python>=4.10.0 25 | - flet>=0.24.0 26 | - shapely>=2.0.6 27 | - tqdm>=4.66.5 28 | - timm>=1.0.12 29 | 30 | ### With pip 31 | 32 | ```bash 33 | pip install git+https://github.com/bxt-kk/vikit-learn.git 34 | ``` 35 | 36 | ## Usage 37 | 38 | ### Training model 39 | 40 | ```python 41 | # Import `pytorch` and `vklearn` 42 | import torch 43 | from torch.utils.data import DataLoader 44 | 45 | from vklearn.trainer.trainer import Trainer 46 | from vklearn.trainer.tasks import Detection 47 | from vklearn.models.trimnetdet import TrimNetDet as Model 48 | from vklearn.datasets.oxford_iiit_pet import OxfordIIITPet 49 | 50 | 51 | dataset_root = '/kaggle/working/OxfordIIITPet' 52 | dataset_type = 'detection' 53 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 54 | batch_size = 16 55 | lr = 1e-3 56 | lrf = 0.2 57 | 58 | # Get default transforms from TRBNetX 59 | train_transforms, test_transforms = Model.get_transforms('cocox448') 60 | 61 | # Create datasets 62 | train_data = OxfordIIITPet( 63 | dataset_root, 64 | split='trainval', 65 | target_types=dataset_type, 66 | transforms=train_transforms) 67 | test_data = OxfordIIITPet( 68 | dataset_root, 69 | split='trainval', 70 | target_types=dataset_type, 71 | transforms=test_transforms) 72 | 73 | # Create model TrbnetX 74 | model = Model( 75 | categories=train_data.classes, 76 | ) 77 | 78 | # Create DataLoader 79 | train_loader = DataLoader( 80 | train_data, batch_size, 81 | shuffle=True, 82 | drop_last=True, 83 | collate_fn=model.collate_fn, 84 | num_workers=4) 85 | test_loader = DataLoader( 86 | test_data, batch_size, 87 | shuffle=False, 88 | drop_last=True, 89 | collate_fn=model.collate_fn, 90 | num_workers=4) 91 | 92 | print(len(train_loader)) 93 | 94 | # Build object detection task 95 | task = Detection( 96 | model, device, metric_options={'conf_thresh': 0.05}, 97 | ) 98 | 99 | # Build a trainer by specifying the training task and setting up trainer parameters 100 | trainer = Trainer( 101 | task, 102 | output='/kaggle/working/catdog', 103 | checkpoint=None, 104 | train_loader=train_loader, 105 | test_loader=test_loader, 106 | epochs=10, 107 | lr=lr, 108 | lrf=lrf, 109 | show_step=50, 110 | drop_optim=True, 111 | drop_lr_scheduler=True, 112 | save_epoch=5) 113 | 114 | # Initialize the trainer, then perform training. 115 | trainer.initialize() 116 | trainer.fit() 117 | ``` 118 | 119 | Upon training completion, there will be visualization images of model training results in the `/kaggle/working/logs/` directory: 120 | 121 | ![](./asset/CATDOG-LOG-240818-15:22:23_loss.png) 122 | ![](./asset/CATDOG-LOG-240818-15:22:23_conf_f1.png) 123 | ![](./asset/CATDOG-LOG-240818-15:22:23_iou_score.png) 124 | ![](./asset/CATDOG-LOG-240818-15:22:23_map.png) 125 | ![](./asset/CATDOG-LOG-240818-15:22:23_map_50.png) 126 | 127 | Based on the focal-boost loss function I designed, the model can be successfully trained on tasks with extremely low positive sample ratio. 
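
The full formulation lives in `vklearn/utils/focal_boost.py`. As a rough illustration of the idea such a loss builds on, a plain focal term down-weights easy, well-classified samples so that the scarce positives dominate the gradient. The sketch below is illustrative only and is not the focal-boost code used by vikit-learn:

```python
# Illustrative sketch of a plain focal BCE term -- NOT vikit-learn's focal-boost.
import torch
import torch.nn.functional as F

def focal_bce(
        logits:  torch.Tensor,
        targets: torch.Tensor,
        alpha:   float=0.25,
        gamma:   float=2.0,
    ) -> torch.Tensor:
    # Per-element binary cross-entropy, kept unreduced so it can be re-weighted.
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    p = torch.sigmoid(logits)
    # Probability assigned to the true class of each element.
    p_t = p * targets + (1 - p) * (1 - targets)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    # (1 - p_t)^gamma suppresses easy samples; rare positives drive the gradient.
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()
```

With `gamma=0` this reduces to ordinary weighted BCE; larger `gamma` values focus training more aggressively on hard, rare positives.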
128 | 129 | ## Using model 130 | 131 | We can call the trained model for object detection in the following way: 132 | 133 | ```python 134 | # Import `vklearn` 135 | from vklearn.models.trimnetdet import TrimNetDet as Model 136 | from vklearn.pipelines.detector import Detector as Pipeline 137 | 138 | 139 | pipeline = Pipeline.load_from_state(Model, '/kaggle/working/catdog-best.pt') 140 | 141 | import matplotlib.pyplot as plt 142 | from PIL import Image 143 | 144 | img = Image.open('??YOUR IMAGE PATH??') 145 | # Detect and display results 146 | objs = pipeline(img, align_size=448) 147 | print(len(objs), objs) 148 | fig = plt.figure() 149 | pipeline.plot_result(img, objs, fig) 150 | plt.show() 151 | ``` 152 | 153 | Here are some examples: 154 | 155 | ![](./asset/Figure_01.png) 156 | ![](./asset/Figure_02.png) 157 | ![](./asset/Figure_03.png) 158 | ![](./asset/Figure_04.png) 159 | ![](./asset/Figure_05.png) 160 | ![](./asset/Figure_06.png) 161 | ![](./asset/Figure_07.png) 162 | ![](./asset/Figure_08.png) 163 | 164 | Here are some examples of image classification: 165 | 166 | ![](./asset/Figure_11.png) 167 | ![](./asset/Figure_12.png) 168 | ![](./asset/Figure_13.png) 169 | ![](./asset/Figure_14.png) 170 | ![](./asset/Figure_15.png) 171 | ![](./asset/Figure_16.png) 172 | 173 | Here are some examples of semantic segmentation: 174 | 175 | ![](./asset/Figure_21.png) 176 | ![](./asset/Figure_22.png) 177 | ![](./asset/Figure_23.png) 178 | ![](./asset/Figure_24.png) 179 | ![](./asset/Figure_25.png) 180 | ![](./asset/Figure_26.png) 181 | 182 | Here are some objectives that support directional localization, which is implemented based on keypoint&joint detection technology: 183 | 184 | ![](./asset/Figure_31.png) 185 | ![](./asset/Figure_32.png) 186 | ![](./asset/Figure_33.png) 187 | ![](./asset/Figure_34.png) 188 | ![](./asset/Figure_35.png) 189 | 190 | Here are some examples of text detection in natural scenes: 191 | 192 | ![](./asset/Figure_36.png) 193 | ![](./asset/Figure_37.png) 194 | ![](./asset/Figure_38.png) 195 | ![](./asset/Figure_39.png) 196 | ![](./asset/Figure_40.png) 197 | 198 | Here are some examples of Optical Character Recognition (OCR): 199 | 200 | ![](./asset/Figure_41.png) 201 | ![](./asset/Figure_42.png) 202 | ![](./asset/Figure_43.png) 203 | ![](./asset/Figure_44.png) 204 | ![](./asset/Figure_45.png) 205 | -------------------------------------------------------------------------------- /asset/CATDOG-LOG-240818-15:22:23_conf_f1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/CATDOG-LOG-240818-15:22:23_conf_f1.png -------------------------------------------------------------------------------- /asset/CATDOG-LOG-240818-15:22:23_iou_score.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/CATDOG-LOG-240818-15:22:23_iou_score.png -------------------------------------------------------------------------------- /asset/CATDOG-LOG-240818-15:22:23_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/CATDOG-LOG-240818-15:22:23_loss.png -------------------------------------------------------------------------------- /asset/CATDOG-LOG-240818-15:22:23_map.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/CATDOG-LOG-240818-15:22:23_map.png -------------------------------------------------------------------------------- /asset/CATDOG-LOG-240818-15:22:23_map_50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/CATDOG-LOG-240818-15:22:23_map_50.png -------------------------------------------------------------------------------- /asset/Figure_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_01.png -------------------------------------------------------------------------------- /asset/Figure_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_02.png -------------------------------------------------------------------------------- /asset/Figure_03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_03.png -------------------------------------------------------------------------------- /asset/Figure_04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_04.png -------------------------------------------------------------------------------- /asset/Figure_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_05.png -------------------------------------------------------------------------------- /asset/Figure_06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_06.png -------------------------------------------------------------------------------- /asset/Figure_07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_07.png -------------------------------------------------------------------------------- /asset/Figure_08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_08.png -------------------------------------------------------------------------------- /asset/Figure_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_11.png -------------------------------------------------------------------------------- /asset/Figure_12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_12.png 
-------------------------------------------------------------------------------- /asset/Figure_13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_13.png -------------------------------------------------------------------------------- /asset/Figure_14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_14.png -------------------------------------------------------------------------------- /asset/Figure_15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_15.png -------------------------------------------------------------------------------- /asset/Figure_16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_16.png -------------------------------------------------------------------------------- /asset/Figure_21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_21.png -------------------------------------------------------------------------------- /asset/Figure_22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_22.png -------------------------------------------------------------------------------- /asset/Figure_23.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_23.png -------------------------------------------------------------------------------- /asset/Figure_24.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_24.png -------------------------------------------------------------------------------- /asset/Figure_25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_25.png -------------------------------------------------------------------------------- /asset/Figure_26.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_26.png -------------------------------------------------------------------------------- /asset/Figure_31.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_31.png -------------------------------------------------------------------------------- /asset/Figure_32.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_32.png -------------------------------------------------------------------------------- /asset/Figure_33.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_33.png -------------------------------------------------------------------------------- /asset/Figure_34.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_34.png -------------------------------------------------------------------------------- /asset/Figure_35.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_35.png -------------------------------------------------------------------------------- /asset/Figure_36.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_36.png -------------------------------------------------------------------------------- /asset/Figure_37.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_37.png -------------------------------------------------------------------------------- /asset/Figure_38.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_38.png -------------------------------------------------------------------------------- /asset/Figure_39.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_39.png -------------------------------------------------------------------------------- /asset/Figure_40.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_40.png -------------------------------------------------------------------------------- /asset/Figure_41.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_41.png -------------------------------------------------------------------------------- /asset/Figure_42.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_42.png -------------------------------------------------------------------------------- /asset/Figure_43.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_43.png -------------------------------------------------------------------------------- /asset/Figure_44.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_44.png -------------------------------------------------------------------------------- /asset/Figure_45.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/Figure_45.png -------------------------------------------------------------------------------- /asset/vikit-learn.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/asset/vikit-learn.jpeg -------------------------------------------------------------------------------- /doc/img/clf/Figure_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/doc/img/clf/Figure_1.png -------------------------------------------------------------------------------- /doc/img/clf/__results___6_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/doc/img/clf/__results___6_1.png -------------------------------------------------------------------------------- /doc/img/clf/__results___6_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/doc/img/clf/__results___6_2.png -------------------------------------------------------------------------------- /doc/img/clf/__results___6_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/doc/img/clf/__results___6_3.png -------------------------------------------------------------------------------- /doc/img/clf/__results___6_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/doc/img/clf/__results___6_4.png -------------------------------------------------------------------------------- /doc/img/clf/__results___6_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/doc/img/clf/__results___6_5.png -------------------------------------------------------------------------------- /doc/task.md: -------------------------------------------------------------------------------- 1 | ## Task 2 | 3 | > This `class` is used to configure a set of parameters relevant to a specific task in model training. 4 | 5 | Args: 6 | 7 | - `model`: Model object for a specific task. 8 | - `device`: Computation device supported by PyTorch. 9 | - `metric_start_epoch`: Sets the epoch from which metric calculation starts, defaults to 0. 10 | - `fit_features_start`: Sets the epoch from which feature extractor training starts, -1 means no training, defaults to -1. 11 | - `loss_options`: Optional parameters for the given loss calculation function. 12 | - `score_options`: Optional parameters for the given score calculation function. 
13 | - `metric_options`: Optional parameters for the given metric evaluation function. 14 | - `key_metrics`: Specifies which key evaluation metrics to track. 15 | - `best_metric`: Current best metric score, initialized to 0. 16 | 17 | ## Classification(Task) 18 | 19 | > This `class` is used to configure a set of parameters relevant to a specific task in classifier model training. 20 | 21 | Args: 22 | 23 | - `model`: Specify a classification model object. 24 | - `device`: Computation device supported by PyTorch. 25 | - `metric_start_epoch`: Sets the epoch from which metric calculation starts, defaults to 0. 26 | - `fit_features_start`: Sets the epoch from which feature extractor training starts, -1 means no training, defaults to -1. 27 | - `loss_options`: Set optional parameters for the classification model's loss function. 28 | - `score_options`: Set optional parameters for the classification model's scoring function. 29 | - `metric_options`: Set optional parameters for the classification model's metric evaluation function. 30 | - `key_metrics`: Specifies which key evaluation metrics to track. 31 | - `best_metric`: Current best metric score, initialized to 0. 32 | -------------------------------------------------------------------------------- /doc/tutorials/clf.md: -------------------------------------------------------------------------------- 1 | # vikit-learn・Classification 2 | 3 | In this tutorial, we'll learn how to train an image classifier using `vikit-learn`. We'll be using the OxfordIIITPet dataset, which contains images of cats and dogs, for this practical example. 4 | 5 | ![](../img/clf/Figure_1.png) 6 | 7 | ## Installing the `vikit-learn` Tool 8 | 9 | We can use `pip` to download and install `vikit-learn` directly from GitHub: 10 | 11 | ```bash 12 | pip install git+https://github.com/bxt-kk/vikit-learn.git 13 | ``` 14 | 15 | --- 16 | 17 | ## Writing the Training Script 18 | 19 | We need to write some script code to train our model. 20 | 21 | ### 1. Import the necessary packages from `vikit-learn` and `pytorch` 22 | 23 | ```python 24 | import torch 25 | from torch.utils.data import DataLoader 26 | 27 | from vklearn.trainer.trainer import Trainer 28 | from vklearn.trainer.tasks import Classification as Task 29 | from vklearn.models.trimnetclf import TrimNetClf as Model 30 | from vklearn.datasets.oxford_iiit_pet import OxfordIIITPet 31 | ``` 32 | 33 | - `Trainer`: A general training tool used to set training parameters and execute the training process. 34 | - `Classification`: Specifies the training parameters related to the classification task. 35 | - `TrimNetClf`: A built-in classifier model in `vikit-learn`. 36 | - `OxfordIIITPet`: A built-in dataset tool in `vikit-learn`. 37 | 38 | ### 2. Prepare the Training Data 39 | 40 | ```python 41 | dataset_root = '/kaggle/working/OxfordIIITPet' 42 | dataset_type = 'binary-category' 43 | 44 | train_transforms, test_transforms = Model.get_transforms() 45 | 46 | train_data = OxfordIIITPet( 47 | dataset_root, 48 | split='trainval', 49 | target_types=dataset_type, 50 | download=False, 51 | transforms=train_transforms) 52 | test_data = OxfordIIITPet( 53 | dataset_root, 54 | split='test', 55 | target_types=dataset_type, 56 | transforms=test_transforms) 57 | ``` 58 | 59 | First, we need to specify the location of the data with `dataset_root`. Then, we specify the type of data with `dataset_type = 'binary-category'`, which means binary classification data for cats and dogs. 
Additionally, we split the data into a training set (`split='trainval'`) and a test set (`split='test'`). 60 | 61 | **Note: If the data is not available in the local directory, we need to set `download` to `True` to download the data from the internet.** 62 | 63 | ```python 64 | batch_size = 128 65 | 66 | train_loader = DataLoader( 67 | train_data, batch_size, 68 | shuffle=True, 69 | drop_last=True, 70 | num_workers=4) 71 | test_loader = DataLoader( 72 | test_data, batch_size, 73 | shuffle=False, 74 | drop_last=True, 75 | num_workers=4) 76 | 77 | print(len(train_loader)) 78 | ``` 79 | 80 | We use the data loading tool `DataLoader` provided by `pytorch` to load the data. Here, we set `batch_size = 128`. 81 | 82 | ### 3. Create the Model and Training Task 83 | 84 | ```python 85 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 86 | model = Model(categories=train_data.bin_classes) 87 | task = Task(model, device) 88 | ``` 89 | 90 | We create a model using the `TrimNetClf` class. Here, we need to specify the number of classification categories and their names for the model. To do this, we use the value of `train_data.bin_classes` as the `categories` parameter for the model. Then, we create a training task object using the model object `model` and the device object `device`: `task = Task(model, device)`. 91 | 92 | ### 4. Initialize the Trainer 93 | 94 | ```python 95 | trainer = Trainer( 96 | task, 97 | output='/kaggle/working/catdog-clf', 98 | train_loader=train_loader, 99 | test_loader=test_loader, 100 | epochs=20, 101 | lr=1e-3, 102 | lrf=0.2, 103 | show_step=50, 104 | save_epoch=5) 105 | 106 | trainer.initialize() 107 | ``` 108 | 109 | By setting the trainer parameters, we can create a trainer for model training. After creating the trainer object, we need to initialize it with the `trainer.initialize()` method. 110 | 111 | We set the following parameters for the trainer: 112 | 113 | - `task`: Specifies the training task. 114 | - `output`: Sets the output path for training data, which is used to store checkpoints and logs. 115 | - `train_loader`: Specifies the training set loader. 116 | - `test_loader`: Specifies the test set loader. 117 | - `epochs`: Sets the total number of training epochs. 118 | - `lr`: Sets the learning rate. 119 | - `lrf`: Sets the learning rate decay factor. 120 | - `show_step`: Sets how often to print the training status. 121 | - `save_epoch`: Sets how often to save a checkpoint. 122 | 123 | ### 5. Execute the Training Task 124 | 125 | Finally, we start the model training with the following code: 126 | 127 | ```python 128 | trainer.fit() 129 | ``` 130 | 131 | After the model training is complete, we will see training logs in the `logs` subdirectory of the output path specified for the trainer: 132 | 133 | ![](../img/clf/__results___6_1.png) 134 | ![](../img/clf/__results___6_2.png) 135 | ![](../img/clf/__results___6_3.png) 136 | ![](../img/clf/__results___6_4.png) 137 | ![](../img/clf/__results___6_5.png) 138 | 139 | In addition to the logs, we will also see the following checkpoint files: 140 | 141 | ```bash 142 | - catdog-clf-4.pt 143 | - catdog-clf-9.pt 144 | - catdog-clf-14.pt 145 | - catdog-clf-19.pt 146 | - catdog-clf-best.pt 147 | ``` 148 | 149 | Generally, we select the one ending with `best.pt` for use, as it is the checkpoint with the highest score on the evaluation metrics of the test set. 
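
If a training run is interrupted, or you want to continue training from a saved state, the `Trainer` also accepts a `checkpoint` argument (it appears in the project README's detection example together with `drop_optim` and `drop_lr_scheduler`). The sketch below shows one plausible way to resume from the best checkpoint; the exact semantics of `drop_optim` and `drop_lr_scheduler` are assumed from their names, and the checkpoint path should be adjusted to wherever your checkpoint files were actually written:

```python
# Sketch: resume training from a saved checkpoint.
# Assumption: drop_optim / drop_lr_scheduler control whether the saved
# optimizer and learning-rate-scheduler states are discarded on load.
trainer = Trainer(
    task,
    output='/kaggle/working/catdog-clf',
    checkpoint='/kaggle/working/catdog-clf-best.pt',  # adjust to your checkpoint path
    drop_optim=False,
    drop_lr_scheduler=False,
    train_loader=train_loader,
    test_loader=test_loader,
    epochs=10,
    lr=1e-4,
    lrf=0.2,
    show_step=50,
    save_epoch=5)

trainer.initialize()
trainer.fit()
```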
150 | 151 | --- 152 | 153 | ## Using the Image Classifier 154 | 155 | After completing the training of the image classifier, we can use the trained classifier to automatically classify images. 156 | 157 | ### 1. First, import the necessary packages 158 | 159 | ```python 160 | import matplotlib.pyplot as plt 161 | from PIL import Image 162 | 163 | from vklearn.models.trimnetclf import TrimNetClf as Model 164 | from vklearn.pipelines.classifier import Classifier as Pipeline 165 | ``` 166 | 167 | `from vklearn.pipelines.classifier import Classifier` will import the pipeline tool `Classifier`, which greatly simplifies model invocation. 168 | 169 | ### 2. Specify the Model Class and Checkpoint File to Create a Classifier 170 | 171 | ```python 172 | pipeline = Pipeline.load_from_state( 173 | Model, '???/catdog-clf-best.pt') 174 | ``` 175 | 176 | **Note: Remember to replace `'???/catdog-clf-best.pt'` with the actual path to the checkpoint file on your computer.** 177 | 178 | ### 3. Open the Model for Classification Prediction and Visualize the Results 179 | 180 | After completing a series of preparations, we can use the following code to perform classification: 181 | 182 | ```python 183 | img = Image.open('??your image path??') 184 | result = pipeline(img) 185 | fig = plt.figure() 186 | pipeline.plot_result(img, result, fig) 187 | plt.show() 188 | ``` 189 | 190 | We use the above code to open an image `img = Image.open('??your image path??')` for classification prediction `result = pipeline(img)`, and visualize the prediction results: 191 | 192 | ![](../img/clf/Figure_1.png) 193 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>=3.7.5 2 | torch>=2.4.0 3 | torchvision>=0.19.0 4 | torchmetrics>=1.4.2 5 | lightning-utilities>=0.11.7 6 | faster-coco-eval>=1.6.0 7 | pycocotools>=2.0.7 8 | opencv-python>=4.10.0 9 | flet>=0.24.0 10 | shapely>=2.0.6 11 | tqdm>=4.66.5 12 | timm>=1.0.12 13 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from setuptools import setup, find_packages 3 | 4 | 5 | home = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | __version__ = '0.0.0' 8 | exec(open(os.path.join(home, 'vklearn/version.py')).read()) 9 | 10 | with open(os.path.join(home, 'README.md')) as f: 11 | readme = f.read() 12 | 13 | with open(os.path.join(home, 'requirements.txt')) as f: 14 | requirements = [item.strip() for item in f] 15 | 16 | 17 | setup( 18 | name='vikit-learn', 19 | version=__version__, 20 | author='jojowee', 21 | description='A computer vision toolkit that is easy-to-use and based on deep learning.', 22 | long_description=readme, 23 | license='Apache-2.0 License', 24 | url='https://github.com/bxt-kk/vikit-learn', 25 | project_urls={ 26 | 'Source Code': 'https://github.com/bxt-kk/vikit-learn', 27 | }, 28 | packages=find_packages(), 29 | package_data={'vklearn':[ 30 | 'datasets/ms_coco_classnames.json', 31 | 'datasets/characters/*.txt', 32 | ]}, 33 | install_requires=requirements, 34 | entry_points = { 35 | 'console_scripts': [ 36 | 'vkl-clf-cfg=toolkit.clf_cfg:entry', 37 | 'vkl-clf-cli=toolkit.clf_cli:entry', 38 | ], 39 | } 40 | ) 41 | -------------------------------------------------------------------------------- /toolkit/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/toolkit/__init__.py -------------------------------------------------------------------------------- /toolkit/clf_cli.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | from typing import Dict, Any 3 | import json 4 | 5 | from torch.utils.data import DataLoader 6 | from torch.optim import AdamW, Adam, SGD, NAdam 7 | import torch 8 | 9 | from vklearn.trainer.trainer import Trainer 10 | from vklearn.trainer.tasks import Classification as Task 11 | from vklearn.models.trimnetclf import TrimNetClf as Model 12 | from vklearn.datasets.images_folder import ImagesFolder 13 | from vklearn.datasets.oxford_iiit_pet import OxfordIIITPet 14 | from vklearn.datasets.places365 import Places365 15 | 16 | 17 | def load_from_json(path:str) -> Dict[str, Any]: 18 | with open(path) as f: 19 | cfg = json.load(f) 20 | return cfg 21 | 22 | 23 | def main(cfg:Dict[str, Any]): 24 | dataset_root = cfg['dataset']['root'] 25 | 26 | train_transforms, test_transforms = Model.get_transforms( 27 | cfg['dataset']['transform']) 28 | 29 | dataset_type = cfg['dataset']['type'] 30 | dataset_opts = dict() 31 | if dataset_type == 'ImagesFolder': 32 | Dataset = ImagesFolder 33 | dataset_opts['extensions'] = cfg['dataset']['extensions'] 34 | elif dataset_type == 'OxfordIIITPet': 35 | Dataset = OxfordIIITPet 36 | dataset_opts['target_types'] = 'binary-category' 37 | elif dataset_type == 'Places365': 38 | Dataset = Places365 39 | 40 | train_data = Dataset( 41 | dataset_root, 42 | split='train', 43 | transforms=train_transforms, 44 | **dataset_opts) 45 | test_data = Dataset( 46 | dataset_root, 47 | split='val', 48 | transforms=test_transforms, 49 | **dataset_opts) 50 | 51 | batch_size = cfg['dataset']['batch_size'] 52 | num_workers = cfg['dataset']['num_workers'] 53 | 54 | train_loader = DataLoader( 55 | train_data, batch_size, 56 | shuffle=True, 57 | drop_last=True, 58 | num_workers=num_workers) 59 | test_loader = DataLoader( 60 | test_data, batch_size, 61 | shuffle=False, 62 | drop_last=True, 63 | num_workers=num_workers) 64 | 65 | device_name = cfg['task']['device'] 66 | if device_name == 'auto': 67 | device_name = 'cuda' if torch.cuda.is_available() else 'cpu' 68 | device = torch.device(device_name) 69 | model = Model(categories=train_data.classes, **cfg['model']) 70 | task = Task( 71 | model=model, 72 | device=device, 73 | metric_start_epoch=cfg['task']['metric_start_epoch'], 74 | fit_features_start=cfg['task']['fit_features_start'], 75 | ) 76 | 77 | optim_method = { 78 | 'AdamW': AdamW, 79 | 'Adam': Adam, 80 | 'SGD': SGD, 81 | 'NAdam': NAdam, 82 | }[cfg['trainer']['optim_method']] 83 | 84 | trainer = Trainer( 85 | task, 86 | output=cfg['trainer']['output'], 87 | train_loader=train_loader, 88 | test_loader=test_loader, 89 | checkpoint=cfg['trainer']['checkpoint'], 90 | drop_optim=cfg['trainer']['drop_optim'], 91 | drop_lr_scheduler=cfg['trainer']['drop_lr_scheduler'], 92 | optim_method=optim_method, 93 | lr=cfg['trainer']['lr'], 94 | weight_decay=cfg['trainer']['weight_decay'], 95 | lrf=cfg['trainer']['lrf'], 96 | T_num=cfg['trainer']['T_num'], 97 | grad_steps=cfg['trainer']['grad_steps'], 98 | epochs=cfg['trainer']['epochs'], 99 | show_step=cfg['trainer']['show_step'], 100 | save_epoch=cfg['trainer']['save_epoch'], 101 | ) 102 | 103 | trainer.initialize() 104 | trainer.fit() 105 | 106 | 107 | 
def entry(): 108 | import argparse 109 | 110 | parser = argparse.ArgumentParser( 111 | description='Image classification trainer') 112 | parser.add_argument( 113 | 'path', 114 | metavar='configure path', 115 | type=str, 116 | help='The parameters configure of the trainer', 117 | ) 118 | 119 | args = parser.parse_args() 120 | cfg = load_from_json(args.path) 121 | main(cfg) 122 | 123 | 124 | if __name__ == "__main__": 125 | entry() 126 | -------------------------------------------------------------------------------- /vklearn/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | -------------------------------------------------------------------------------- /vklearn/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/vklearn/datasets/__init__.py -------------------------------------------------------------------------------- /vklearn/datasets/characters.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import os 3 | 4 | CHARACTERS_DIR = os.path.join( 5 | os.path.dirname(os.path.abspath(__file__)), 'characters/') 6 | 7 | CHARACTERS_DICT = { 8 | os.path.basename(os.path.splitext(path)[0]): path 9 | for path in glob(os.path.join(CHARACTERS_DIR, '*.txt'))} 10 | -------------------------------------------------------------------------------- /vklearn/datasets/characters/en_sym.txt: -------------------------------------------------------------------------------- 1 | ! 2 | " 3 | # 4 | $ 5 | % 6 | & 7 | ' 8 | ( 9 | ) 10 | * 11 | + 12 | , 13 | - 14 | . 15 | / 16 | 0 17 | 1 18 | 2 19 | 3 20 | 4 21 | 5 22 | 6 23 | 7 24 | 8 25 | 9 26 | : 27 | ; 28 | < 29 | = 30 | > 31 | ? 32 | @ 33 | A 34 | B 35 | C 36 | D 37 | E 38 | F 39 | G 40 | H 41 | I 42 | J 43 | K 44 | L 45 | M 46 | N 47 | O 48 | P 49 | Q 50 | R 51 | S 52 | T 53 | U 54 | V 55 | W 56 | X 57 | Y 58 | Z 59 | [ 60 | \ 61 | ] 62 | ^ 63 | _ 64 | ` 65 | a 66 | b 67 | c 68 | d 69 | e 70 | f 71 | g 72 | h 73 | i 74 | j 75 | k 76 | l 77 | m 78 | n 79 | o 80 | p 81 | q 82 | r 83 | s 84 | t 85 | u 86 | v 87 | w 88 | x 89 | y 90 | z 91 | { 92 | | 93 | } 94 | ~ 95 | ° 96 | ˇ 97 | ‖ 98 | ‘ 99 | ’ 100 | “ 101 | ” 102 | … 103 | ‰ 104 | ′ 105 | ″ 106 | ※ 107 | ① 108 | ② 109 | ③ 110 | ④ 111 | ⑤ 112 | ⑥ 113 | ⑦ 114 | ⑧ 115 | ⑨ 116 | ⑩ 117 | ● 118 | ★ 119 | 、 120 | 。 121 | 〇 122 | 〈 123 | 〉 124 | 《 125 | 》 126 | 「 127 | 」 128 | 『 129 | 』 130 | 【 131 | 】 132 | 〔 133 | 〕 134 | 〖 135 | 〗 136 | ・ 137 | ! 138 | ( 139 | ) 140 | , 141 | : 142 | ? 143 | { 144 | } 145 | ¥ -------------------------------------------------------------------------------- /vklearn/datasets/cocopretrain.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, Tuple, Dict 2 | import os.path 3 | 4 | from PIL import Image 5 | 6 | import torch 7 | from torchvision.datasets.vision import VisionDataset 8 | from torchvision import tv_tensors 9 | 10 | from tqdm import tqdm 11 | 12 | 13 | class CocoPretrain(VisionDataset): 14 | '''`MS Coco Detection `_ Dataset. 15 | 16 | It requires the `COCO API to be installed `_. 17 | 18 | Args: 19 | root: Root directory where images are downloaded to. 20 | annFile: Path to json annotation file. 21 | category_type: The type of category, can be ``name`` or ``supercategory``. 
22 | sub_categories: Select some target categories as a List of subcategories. 23 | max_datas_size: For large amounts of data, it is used to limit the number of samples, 24 | and the default is 0, which means no limit. 25 | transform: A function/transform that takes in a PIL image 26 | and returns a transformed version. E.g, ``transforms.PILToTensor`` 27 | target_transform: A function/transform that takes in the 28 | target and transforms it. 29 | transforms: A function/transform that takes input sample and its target as entry 30 | and returns a transformed version. 31 | ''' 32 | 33 | NAME_OTHER = 'other' 34 | 35 | def __init__( 36 | self, 37 | root: str, 38 | annFile: str, 39 | category_type: str='name', 40 | sub_categories: List[str] | None=None, 41 | max_datas_size: int=0, 42 | transform: Callable | None=None, 43 | target_transform: Callable | None=None, 44 | transforms: Callable | None=None, 45 | ): 46 | super().__init__(root, transforms, transform, target_transform) 47 | from pycocotools.coco import COCO 48 | 49 | assert category_type in ('name', 'supercategory') 50 | 51 | if sub_categories is None: sub_categories = [] 52 | 53 | self.coco = COCO(annFile) 54 | self.ids = list(sorted(self.coco.imgs.keys())) 55 | self.category_type = category_type 56 | 57 | self.coid2name = { 58 | clss['id']: clss['name'] 59 | for clss in self.coco.dataset['categories']} 60 | self.coid2supercategory = { 61 | clss['id']: clss['supercategory'] 62 | for clss in self.coco.dataset['categories']} 63 | self.coid2subcategory = { 64 | clss['id']: (clss[category_type] if clss[category_type] in sub_categories else self.NAME_OTHER) 65 | for clss in self.coco.dataset['categories']} 66 | 67 | idxs = sorted(self.coid2name.keys()) 68 | self.names = [self.coid2name[i] for i in idxs] 69 | self.supercategories = [] 70 | for i in idxs: 71 | category = self.coid2supercategory[i] 72 | if category in self.supercategories: continue 73 | self.supercategories.append(category) 74 | self.subcategories = sub_categories # + [self.NAME_OTHER] 75 | 76 | if len(sub_categories) > 0: 77 | self.classes = self.subcategories 78 | self.coid2class = self.coid2subcategory 79 | self.ids = self._drop_other_images(self.ids) 80 | elif category_type == 'name': 81 | self.classes = self.names 82 | self.coid2class = self.coid2name 83 | elif category_type == 'supercategory': 84 | self.classes = self.supercategories 85 | self.coid2class = self.coid2supercategory 86 | 87 | self.max_datas_size = max_datas_size if max_datas_size > 0 else len(self.ids) 88 | 89 | def __len__(self) -> int: 90 | return min(self.max_datas_size, len(self.ids)) 91 | 92 | def _drop_other_images(self, ids:List[int]) -> List[int]: 93 | new_ids = [] 94 | for _id in tqdm(ids, ncols=80): 95 | anns = self._load_anns(_id) 96 | for ann in anns: 97 | class_name = self.coid2class[ann['category_id']] 98 | if class_name != self.NAME_OTHER: 99 | new_ids.append(_id) 100 | break 101 | return new_ids 102 | 103 | def _load_image(self, id:int) -> Image.Image: 104 | path = self.coco.loadImgs(id)[0]['file_name'] 105 | return Image.open(os.path.join(self.root, path)).convert('RGB') 106 | 107 | def _load_anns(self, id:int) -> List[Any]: 108 | return [ 109 | ann for ann in self.coco.loadAnns(self.coco.getAnnIds(id)) 110 | if ann['category_id'] > 0] 111 | 112 | def _format_anns( 113 | self, 114 | anns: List[Any], 115 | image_size: Tuple[int, int], 116 | ) -> Dict[str, Any]: 117 | 118 | xywh2xyxy = lambda x, y, w, h: (x, y, x + w, y + h) 119 | validation = lambda ann: ann['iscrowd'] == 0 120 | box_list = 
[] 121 | label_list = [] 122 | for ann in anns: 123 | if not validation(ann): continue 124 | class_name = self.coid2class[ann['category_id']] 125 | if class_name == self.NAME_OTHER: continue 126 | box_list.append(xywh2xyxy(*ann['bbox'])) 127 | label_list.append(self.classes.index(class_name)) 128 | boxes = tv_tensors.BoundingBoxes( 129 | box_list, 130 | format='XYXY', 131 | canvas_size=(image_size[1], image_size[0]), 132 | ) 133 | labels = torch.LongTensor(label_list) 134 | return dict(boxes=boxes, labels=labels) 135 | 136 | def __getitem__(self, index:int) -> Tuple[Any, Any]: 137 | if not isinstance(index, int): 138 | raise ValueError( 139 | f'Index must be of type integer, got {type(index)} instead.') 140 | 141 | id = self.ids[index] 142 | anns = self._load_anns(id) 143 | if len(anns) == 0: 144 | return self.__getitem__((index + 1) % self.__len__()) 145 | 146 | image = self._load_image(id) 147 | target = self._format_anns(anns, image.size) 148 | 149 | if self.transforms is not None: 150 | image, target = self.transforms(image, target) 151 | 152 | multilabel = torch.zeros(len(self.classes)) 153 | for label_idx in target['labels']: 154 | multilabel[label_idx] = 1. 155 | multilabel /= max(len(target['labels']), 1) 156 | 157 | return image, multilabel 158 | -------------------------------------------------------------------------------- /vklearn/datasets/hwdb_gnt.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Tuple 2 | from glob import glob 3 | from collections import defaultdict 4 | import os 5 | import struct 6 | 7 | from torch import Tensor 8 | from torchvision.datasets.vision import VisionDataset 9 | 10 | from PIL import Image 11 | from tqdm import tqdm 12 | import numpy as np 13 | 14 | from .characters import CHARACTERS_DICT 15 | 16 | 17 | class HWDBGnt(VisionDataset): 18 | '''HWDB-Gnt Dataset 19 | 20 | Args: 21 | root: Root directory of the dataset. 22 | split: The dataset split, supports `"train"` (default) or `"test"`. 23 | characters_file: The path of character-set file. 24 | limit: Limit the number of data files to be loaded. 25 | transform: A function/transform that takes in a PIL image 26 | and returns a transformed version. E.g, ``transforms.PILToTensor`` 27 | target_transform: A function/transform that takes in the 28 | target and transforms it. 29 | transforms: A function/transform that takes input sample and its target as entry 30 | and returns a transformed version. 
31 | ''' 32 | 33 | CHARACTERS_FILE = CHARACTERS_DICT['ch_sym_sim'] 34 | 35 | def __init__( 36 | self, 37 | root: str, 38 | characters_file: str | None=None, 39 | split: str='train', 40 | limit: int=0, 41 | transforms: Callable | None=None, 42 | transform: Callable | None=None, 43 | target_transform: Callable | None=None, 44 | ): 45 | 46 | assert split in ('train', 'test') 47 | super().__init__(root, transforms, transform, target_transform) 48 | gnt_files = glob(os.path.join(root, f'*{split}', '*.gnt')) 49 | self._build_index(gnt_files, characters_file, limit) 50 | 51 | def _load_characters(self, characters_file:str) -> List[str]: 52 | with open(characters_file, encoding='utf-8') as f: 53 | return sorted(set(f.read().replace(' ', '').replace('\n', ''))) 54 | 55 | def _build_index( 56 | self, 57 | gnt_files: list, 58 | characters_file: str, 59 | limit: int, 60 | ): 61 | 62 | char2index = defaultdict(list) 63 | index = [] 64 | valid_characters = set() 65 | characters = self._load_characters(characters_file or self.CHARACTERS_FILE) 66 | gnt_files_sorted = sorted(gnt_files) 67 | if limit > 0: 68 | gnt_files_sorted = gnt_files_sorted[:limit] 69 | print('build index for gnt files...') 70 | for gnt_i, gnt_file in enumerate(tqdm(gnt_files_sorted, ncols=80)): 71 | with open(gnt_file, 'rb') as f: 72 | while True: 73 | pointer = f.tell() 74 | sample_size_bytes = f.read(4) 75 | if sample_size_bytes == b'': break 76 | sample_size = struct.unpack(' Tuple[Image.Image | Tensor, int]: 92 | gnt_i, pointer, character = self._data_index[idx] 93 | with open(self._gnt_files[gnt_i], 'rb') as f: 94 | f.seek(pointer + 4, 0) 95 | target_code = f.read(2).decode('gbk', 'ignore').strip('\x00') 96 | assert target_code == character 97 | width = struct.unpack(' List[str]: 55 | with open(characters_file, encoding='utf-8') as f: 56 | return sorted(set(f.read().replace(' ', '').replace('\n', ''))) 57 | 58 | def _build_index( 59 | self, 60 | pot_files: list, 61 | characters_file: str | None, 62 | limit: int, 63 | ): 64 | 65 | char2index = defaultdict(list) 66 | index = [] 67 | valid_characters = set() 68 | characters = self._load_characters(characters_file or self.CHARACTERS_FILE) 69 | pot_files_sorted = sorted(pot_files) 70 | if limit > 0: 71 | pot_files_sorted = pot_files_sorted[:limit] 72 | print('build index for pot files...') 73 | for pot_i, pot_file in enumerate(tqdm(pot_files_sorted, ncols=80)): 74 | with open(pot_file, 'rb') as f: 75 | while True: 76 | pointer = f.tell() 77 | sample_size_bytes = f.read(2) 78 | if sample_size_bytes == b'': break 79 | sample_size = struct.unpack(' Tuple[Image.Image | Tensor, int]: 95 | pot_i, pointer, character = self._data_index[idx] 96 | with open(self._pot_files[pot_i], 'rb') as f: 97 | f.seek(pointer + 2, 0) 98 | target_code = f.read(4)[::-1].decode('gbk', 'ignore').strip('\x00') 99 | assert target_code == character 100 | stroke_number = struct.unpack(' int: 44 | return len(self.paths) 45 | 46 | def _load_image(self, path:str) -> Image.Image: 47 | return Image.open(path).convert('RGB') 48 | 49 | def _load_label(self, path:str) -> int: 50 | category = os.path.basename(os.path.dirname(path)) 51 | return self.classes.index(category) 52 | 53 | def __getitem__(self, index:int) -> Tuple[Any, Any]: 54 | if not isinstance(index, int): 55 | raise ValueError(f'Index must be of type integer, got {type(index)} instead.') 56 | 57 | path = self.paths[index] 58 | image = self._load_image(path) 59 | target = self._load_label(path) 60 | 61 | if self.transforms is not None: 62 | image, target = 
self.transforms(image, target) 63 | 64 | return image, target 65 | -------------------------------------------------------------------------------- /vklearn/datasets/labelme_detection.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, Tuple 2 | import os.path 3 | from glob import glob 4 | import json 5 | 6 | from PIL import Image 7 | import numpy as np 8 | 9 | import torch 10 | from torchvision.datasets.vision import VisionDataset 11 | from torchvision import tv_tensors 12 | 13 | 14 | class LabelmeDetection(VisionDataset): 15 | '''`Labelme annotated format common dataset. 16 | 17 | Args: 18 | root: Root directory where images are downloaded to. 19 | split: The dataset split, supports ``""`` (default), ``"train"``, ``"valid"`` or ``"test"``. 20 | transform: A function/transform that takes in a PIL image 21 | and returns a transformed version. E.g, ``transforms.PILToTensor`` 22 | target_transform: A function/transform that takes in the 23 | target and transforms it. 24 | transforms: A function/transform that takes input sample and its target as entry 25 | and returns a transformed version. 26 | ''' 27 | 28 | def __init__( 29 | self, 30 | root: str, 31 | split: str='', 32 | transform: Callable | None=None, 33 | target_transform: Callable | None=None, 34 | transforms: Callable | None=None, 35 | ): 36 | super().__init__(root, transforms, transform, target_transform) 37 | assert split in ['', 'train', 'valid', 'test'] 38 | self.dataset_dir = os.path.abspath(os.path.join(root, split)) 39 | assert os.path.isdir(self.dataset_dir) 40 | 41 | self.label_paths = sorted(glob(os.path.join(self.dataset_dir, '*.json'))) 42 | self.classes = [] 43 | with open(os.path.join(root, 'classnames.txt')) as f: 44 | for name in f: 45 | name = name.strip() 46 | if not name: continue 47 | self.classes.append(name) 48 | 49 | def __len__(self) -> int: 50 | return len(self.label_paths) 51 | 52 | def _load_image(self, path:str) -> Image.Image: 53 | return Image.open(os.path.join(self.dataset_dir, path)).convert('RGB') 54 | 55 | def _load_anns(self, id:int) -> Tuple[List[Any], str]: 56 | label_path = self.label_paths[id] 57 | with open(label_path) as f: 58 | data = json.load(f) 59 | return data['shapes'], data['imagePath'] 60 | 61 | def _points2xyxy(self, points:List[List[float]]) -> List[float]: 62 | array = np.asarray(points, dtype=np.float32) 63 | x1, y1 = array.min(axis=0) 64 | x2, y2 = array.max(axis=0) 65 | return [x1, y1, x2, y2] 66 | 67 | def _format_anns( 68 | self, 69 | anns: List[Any], 70 | image_size: Tuple[int, int], 71 | ) -> dict[str, Any]: 72 | boxes = tv_tensors.BoundingBoxes( 73 | [self._points2xyxy(ann['points']) for ann in anns], 74 | format='XYXY', 75 | canvas_size=(image_size[1], image_size[0]), 76 | ) 77 | labels = torch.LongTensor([self.classes.index(ann['label']) for ann in anns]) 78 | return dict(boxes=boxes, labels=labels) 79 | 80 | def __getitem__(self, index:int) -> Tuple[Any, Any]: 81 | 82 | if not isinstance(index, int): 83 | raise ValueError(f'Index must be of type integer, got {type(index)} instead.') 84 | 85 | anns, imagePath = self._load_anns(index) 86 | if len(anns) == 0: 87 | return self.__getitem__((index + 1) % self.__len__()) 88 | 89 | image = self._load_image(imagePath) 90 | target = self._format_anns(anns, image.size) 91 | 92 | if self.transforms is not None: 93 | image, target = self.transforms(image, target) 94 | 95 | return image, target 96 | 
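# --- Usage sketch (added for illustration; the dataset path below is a
# --- placeholder, and the root is assumed to contain Labelme *.json files
# --- plus the classnames.txt required by __init__ above).
if __name__ == "__main__":
    dataset = LabelmeDetection('/path/to/labelme_dataset', split='')
    print(f'{len(dataset)} samples, classes: {dataset.classes}')
    image, target = dataset[0]
    # With no transforms, `image` is a PIL image and `target` holds
    # tv_tensors.BoundingBoxes plus a LongTensor of class indices.
    print(image.size, target['boxes'].shape, target['labels'])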
-------------------------------------------------------------------------------- /vklearn/datasets/lsvt_joints.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, Tuple, Dict 2 | from glob import glob 3 | import os.path 4 | import json 5 | # import math 6 | 7 | from PIL import Image 8 | import cv2 as cv 9 | import numpy as np 10 | 11 | import torch 12 | from torchvision.datasets.vision import VisionDataset 13 | from torchvision.tv_tensors import BoundingBoxes, Mask 14 | from torchvision.ops import box_convert 15 | 16 | 17 | class LSVTJoints(VisionDataset): 18 | '''`LSVT Joints Detection dataset. 19 | 20 | Args: 21 | root: Root directory where images are downloaded to. 22 | transform: A function/transform that takes in a PIL image 23 | and returns a transformed version. E.g, ``transforms.PILToTensor`` 24 | target_transform: A function/transform that takes in the 25 | target and transforms it. 26 | transforms: A function/transform that takes input sample and its target as entry 27 | and returns a transformed version. 28 | ''' 29 | LINE_THICKNESS = 7 30 | 31 | def __init__( 32 | self, 33 | root: str, 34 | ignore_illegibility: bool=True, 35 | transform: Callable | None=None, 36 | target_transform: Callable | None=None, 37 | transforms: Callable | None=None, 38 | ): 39 | super().__init__(root, transforms, transform, target_transform) 40 | 41 | self.image_paths = sorted(glob(os.path.join(root, 'train_full_images_*/*/*.jpg'))) 42 | label_path = os.path.join(root, 'train_full_labels.json') 43 | 44 | with open(label_path) as f: 45 | self.anns_dict = json.load(f) 46 | 47 | self.classes = ['text'] 48 | self.ignore_illegibility = ignore_illegibility 49 | 50 | def __len__(self) -> int: 51 | return len(self.image_paths) 52 | 53 | def _load_image(self, image_paths:str) -> Image.Image: 54 | return Image.open(image_paths).convert('RGB') 55 | 56 | def _format_anns( 57 | self, 58 | anns: List[Any], 59 | image_size: Tuple[int, int], 60 | ) -> Dict[str, Any]: 61 | 62 | bbox_list = [] 63 | for ann in anns: 64 | if self.ignore_illegibility and ann['illegibility']: continue 65 | (x, y), (w, h), a = cv.minAreaRect(np.array(ann['points'])) 66 | if w < h: 67 | w, h = h, w 68 | a -= 90 69 | pts = cv.boxPoints(((x, y), (w, h), a)) 70 | 71 | diameter = min(w, h) 72 | if diameter < self.LINE_THICKNESS: continue 73 | length = max(w, h) 74 | x1y1 = (pts[0] + pts[1]) * 0.5 75 | x2y2 = (pts[2] + pts[3]) * 0.5 76 | x1, y1 = x1y1 + (x2y2 - x1y1) * diameter / length * 0.5 77 | x2, y2 = x2y2 + (x1y1 - x2y2) * diameter / length * 0.5 78 | 79 | bbox_list.append([x1, y1, diameter, diameter]) 80 | bbox_list.append([x2, y2, diameter, diameter]) 81 | 82 | return dict( 83 | boxes=BoundingBoxes( 84 | bbox_list, 85 | format='CXCYWH', 86 | canvas_size=(image_size[1], image_size[0])), 87 | labels=torch.LongTensor([0] * len(bbox_list)), 88 | ) 89 | 90 | def _draw_masks(self, boxes:BoundingBoxes) -> Mask: 91 | ground = np.zeros(boxes.canvas_size, dtype=np.uint8) 92 | for i in range(boxes.shape[0]): 93 | if i % 2 == 0: continue 94 | pt1 = boxes[i - 1][:2].round().numpy().astype(int) 95 | pt2 = boxes[i][:2].round().numpy().astype(int) 96 | thickness = max(1, min(self.LINE_THICKNESS, round(0.33 * boxes[i -1][3].item()))) 97 | cv.line(ground, pt1, pt2, 1, thickness, lineType=cv.LINE_AA) 98 | return Mask(np.expand_dims(ground, 0)) 99 | 100 | def __getitem__(self, index:int) -> Tuple[Any, Any]: 101 | 102 | if not isinstance(index, int): 103 | raise ValueError(f'Index 
must be of type integer, got {type(index)} instead.') 104 | 105 | image_path = self.image_paths[index] 106 | image_id = os.path.splitext(os.path.basename(image_path))[0] 107 | anns = self.anns_dict[image_id] 108 | if len(anns) == 0: 109 | return self.__getitem__((index + 1) % self.__len__()) 110 | 111 | image = self._load_image(image_path) 112 | target = self._format_anns(anns, image.size) 113 | if target['boxes'].numel() == 0: 114 | return self.__getitem__((index + 1) % self.__len__()) 115 | 116 | num_boxes = len(target['boxes']) 117 | if self.transforms is not None: 118 | max_diameter = target['boxes'][:, 2:].max() 119 | target['boxes'][:, 2:] /= max_diameter 120 | image, target = self.transforms(image, target) 121 | target['boxes'][:, 2:] *= max_diameter 122 | 123 | for i in range(0, len(target['boxes']), 2): 124 | x1, y1 = target['boxes'][i][:2].tolist() 125 | x2, y2 = target['boxes'][i + 1][:2].tolist() 126 | if abs(x2 - x1) > abs(y2 - y1): 127 | if x1 <= x2: continue 128 | else: 129 | if y1 <= y2: continue 130 | target['boxes'][i][0] = x2 131 | target['boxes'][i][1] = y2 132 | target['boxes'][i + 1][0] = x1 133 | target['boxes'][i + 1][1] = y1 134 | 135 | target['masks'] = self._draw_masks(target['boxes']) 136 | target['boxes'] = box_convert(target['boxes'], 'cxcywh', 'xyxy') 137 | assert len(target['boxes']) == num_boxes 138 | 139 | return image, target 140 | 141 | 142 | if __name__ == "__main__": 143 | import matplotlib.pyplot as plt 144 | from matplotlib.pyplot import Circle 145 | 146 | dataset = LSVTJoints('/media/kk/Data/dataset/image/LSVT') 147 | print(len(dataset)) 148 | 149 | for i in range(len(dataset)): 150 | image, target = dataset[i] 151 | ax:plt.Axes = plt.subplot() 152 | img_arr = np.array(image, dtype=np.uint8) 153 | mask = target['masks'].numpy() 154 | img_arr[..., 1][mask[0] == 1] = 0 155 | ax.imshow(img_arr) 156 | for bnd_id, bbox in enumerate(target['boxes']): 157 | is_begin = bnd_id % 2 == 0 158 | x, y = (bbox[:2] + bbox[2:]) * 0.5 159 | diameter = (bbox[2:] - bbox[:2]).min() 160 | color = 'red' if is_begin else 'blue' 161 | ax.add_patch(Circle((x, y), diameter * 0.5, color=color, fill=False, linewidth=1)) 162 | ax.add_patch(Circle((x, y), 5, color=color, fill=True)) 163 | plt.show() 164 | if input('continue?>').strip() == 'q': break 165 | -------------------------------------------------------------------------------- /vklearn/datasets/masksegment.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Tuple, Any, Dict, List 2 | from xml.etree.ElementTree import parse as ET_parse 3 | from glob import glob 4 | import os 5 | 6 | from torchvision.datasets.vision import VisionDataset 7 | from torchvision import tv_tensors 8 | import torch 9 | import torch.nn.functional as F 10 | from PIL import Image 11 | 12 | 13 | class MaskSegment(VisionDataset): 14 | 15 | def __init__( 16 | self, 17 | root: str, 18 | split: str='train', 19 | categories: List[str] | None=None, 20 | transform: Callable | None=None, 21 | target_transform: Callable | None=None, 22 | transforms: Callable | None=None, 23 | ): 24 | 25 | super().__init__( 26 | root, transforms=transforms, transform=transform, target_transform=target_transform) 27 | 28 | labels_file = os.path.join(root, 'labels.txt') 29 | self.classes = categories 30 | if os.path.isfile(labels_file): 31 | with open(labels_file, 'w') as f: 32 | self.classes = [label.strip() for label in f] 33 | assert self.classes is not None 34 | 35 | self._images = 
sorted(glob(os.path.join(root, split, 'image', '*'))) 36 | self._masks = sorted(glob(os.path.join(root, split, 'mask', '*'))) 37 | assert len(self._images) == len(self._masks) 38 | 39 | def __len__(self) -> int: 40 | return len(self._images) 41 | 42 | def _format_mask(self, mask:tv_tensors.Mask) -> tv_tensors.Mask: 43 | return tv_tensors.Mask( 44 | F.one_hot(mask, len(self.classes)).transpose(0, 3).squeeze(-1)) 45 | 46 | def __getitem__(self, idx:int) -> Tuple[Any, Any]: 47 | image = Image.open(self._images[idx]).convert('RGB') 48 | target = Image.open(self._masks[idx]) 49 | assert target.mode == 'P' 50 | if image.size != target.size: 51 | image = image.resize(target.size, resample=Image.Resampling.BICUBIC) 52 | target = tv_tensors.Mask(target, dtype=torch.long) 53 | if self.transforms is not None: 54 | image, target = self.transforms(image, target) 55 | target = self._format_mask(target) 56 | return image, target 57 | -------------------------------------------------------------------------------- /vklearn/datasets/ms_coco_classnames.json: -------------------------------------------------------------------------------- 1 | {"0": "__background__", 2 | "1": "person", 3 | "2": "bicycle", 4 | "3": "car", 5 | "4": "motorcycle", 6 | "5": "airplane", 7 | "6": "bus", 8 | "7": "train", 9 | "8": "truck", 10 | "9": "boat", 11 | "10": "traffic light", 12 | "11": "fire hydrant", 13 | "12": "stop sign", 14 | "13": "parking meter", 15 | "14": "bench", 16 | "15": "bird", 17 | "16": "cat", 18 | "17": "dog", 19 | "18": "horse", 20 | "19": "sheep", 21 | "20": "cow", 22 | "21": "elephant", 23 | "22": "bear", 24 | "23": "zebra", 25 | "24": "giraffe", 26 | "25": "backpack", 27 | "26": "umbrella", 28 | "27": "handbag", 29 | "28": "tie", 30 | "29": "suitcase", 31 | "30": "frisbee", 32 | "31": "skis", 33 | "32": "snowboard", 34 | "33": "sports ball", 35 | "34": "kite", 36 | "35": "baseball bat", 37 | "36": "baseball glove", 38 | "37": "skateboard", 39 | "38": "surfboard", 40 | "39": "tennis racket", 41 | "40": "bottle", 42 | "41": "wine glass", 43 | "42": "cup", 44 | "43": "fork", 45 | "44": "knife", 46 | "45": "spoon", 47 | "46": "bowl", 48 | "47": "banana", 49 | "48": "apple", 50 | "49": "sandwich", 51 | "50": "orange", 52 | "51": "broccoli", 53 | "52": "carrot", 54 | "53": "hot dog", 55 | "54": "pizza", 56 | "55": "donut", 57 | "56": "cake", 58 | "57": "chair", 59 | "58": "couch", 60 | "59": "potted plant", 61 | "60": "bed", 62 | "61": "dining table", 63 | "62": "toilet", 64 | "63": "tv", 65 | "64": "laptop", 66 | "65": "mouse", 67 | "66": "remote", 68 | "67": "keyboard", 69 | "68": "cell phone", 70 | "69": "microwave", 71 | "70": "oven", 72 | "71": "toaster", 73 | "72": "sink", 74 | "73": "refrigerator", 75 | "74": "book", 76 | "75": "clock", 77 | "76": "vase", 78 | "77": "scissors", 79 | "78": "teddy bear", 80 | "79": "hair drier", 81 | "80": "toothbrush"} 82 | -------------------------------------------------------------------------------- /vklearn/datasets/mvtec_screws.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, Tuple, Dict 2 | from collections import defaultdict 3 | import os.path 4 | import json 5 | import math 6 | 7 | from PIL import Image 8 | import cv2 as cv 9 | import numpy as np 10 | 11 | import torch 12 | from torchvision.datasets.vision import VisionDataset 13 | from torchvision.tv_tensors import BoundingBoxes, Mask 14 | from torchvision.ops import box_convert 15 | 16 | 17 | class MVTecScrews(VisionDataset): 18 | 
'''MVTec-Screws Joints Detection dataset. 19 | 20 | Args: 21 | root: Root directory where images are downloaded to. 22 | split: The dataset split, supports ``""`` (default), ``"train"``, ``"val"`` or ``"test"``. 23 | transform: A function/transform that takes in a PIL image 24 | and returns a transformed version. E.g, ``transforms.PILToTensor`` 25 | target_transform: A function/transform that takes in the 26 | target and transforms it. 27 | transforms: A function/transform that takes input sample and its target as entry 28 | and returns a transformed version. 29 | ''' 30 | LINE_THICKNESS = 7 31 | 32 | def __init__( 33 | self, 34 | root: str, 35 | split: str='train', 36 | transform: Callable | None=None, 37 | target_transform: Callable | None=None, 38 | transforms: Callable | None=None, 39 | ): 40 | super().__init__(root, transforms, transform, target_transform) 41 | assert split in ['train', 'val', 'test'] 42 | 43 | self.images_dir = os.path.join(root, 'images/') 44 | 45 | label_path = os.path.join(root, f'mvtec_screws_{split}.json') 46 | 47 | with open(label_path) as f: 48 | data = json.load(f) 49 | 50 | self.id2category = {item['id'] - 1: item['name'] for item in data['categories']} 51 | self.classes = list(self.id2category.values()) 52 | 53 | self.image_id2filename = {item['id']: item['file_name'] for item in data['images']} 54 | self.image_ids = list(self.image_id2filename.keys()) 55 | 56 | self.anns_dict = defaultdict(list) 57 | for ann in data['annotations']: 58 | label = ann['category_id'] - 1 59 | bbox = ann['bbox'] 60 | image_id = ann['image_id'] 61 | self.anns_dict[image_id].append(dict( 62 | label=label, 63 | bbox=bbox, 64 | )) 65 | 66 | def __len__(self) -> int: 67 | return len(self.image_ids) 68 | 69 | def _load_image(self, image_id:int) -> Image.Image: 70 | filename = self.image_id2filename[image_id] 71 | return Image.open( 72 | os.path.join(self.images_dir, filename)).convert('RGB') 73 | 74 | def _format_anns( 75 | self, 76 | anns: List[Any], 77 | image_size: Tuple[int, int], 78 | ) -> Dict[str, Any]: 79 | 80 | bbox_list = [] 81 | label_list = [] 82 | for ann in anns: 83 | y, x, w, h, a = ann['bbox'] 84 | a = - a / math.pi * 180 85 | pts = cv.boxPoints(((x, y), (w, h), a)) 86 | 87 | diameter = min(w, h) 88 | length = max(w, h) 89 | x1y1 = (pts[0] + pts[1]) * 0.5 90 | x2y2 = (pts[2] + pts[3]) * 0.5 91 | x1, y1 = x1y1 + (x2y2 - x1y1) * diameter / length * 0.5 92 | x2, y2 = x2y2 + (x1y1 - x2y2) * diameter / length * 0.5 93 | 94 | bbox_list.append([x1, y1, diameter, diameter]) 95 | bbox_list.append([x2, y2, diameter, diameter]) 96 | label_list.append(ann['label']) 97 | label_list.append(ann['label']) 98 | 99 | return dict( 100 | boxes=BoundingBoxes( 101 | bbox_list, 102 | format='CXCYWH', 103 | canvas_size=(image_size[1], image_size[0])), 104 | labels=torch.LongTensor(label_list), 105 | ) 106 | 107 | def _draw_masks(self, boxes:BoundingBoxes) -> Mask: 108 | ground = np.zeros(boxes.canvas_size, dtype=np.uint8) 109 | for i in range(boxes.shape[0]): 110 | if i % 2 == 0: continue 111 | pt1 = boxes[i - 1][:2].round().numpy().astype(int) 112 | pt2 = boxes[i][:2].round().numpy().astype(int) 113 | cv.line(ground, pt1, pt2, 1, self.LINE_THICKNESS, lineType=cv.LINE_AA) 114 | return Mask(np.expand_dims(ground, 0)) 115 | 116 | def __getitem__(self, index:int) -> Tuple[Any, Any]: 117 | 118 | if not isinstance(index, int): 119 | raise ValueError(f'Index must be of type integer, got {type(index)} instead.') 120 | 121 | image_id = self.image_ids[index] 122 | anns = self.anns_dict[image_id] 123 
| if len(anns) == 0: 124 | return self.__getitem__((index + 1) % self.__len__()) 125 | 126 | image = self._load_image(image_id) 127 | target = self._format_anns(anns, image.size) 128 | 129 | num_boxes = len(target['boxes']) 130 | if self.transforms is not None: 131 | max_diameter = target['boxes'][:, 2:].max() 132 | target['boxes'][:, 2:] /= max_diameter 133 | image, target = self.transforms(image, target) 134 | target['boxes'][:, 2:] *= max_diameter 135 | target['masks'] = self._draw_masks(target['boxes']) 136 | target['boxes'] = box_convert(target['boxes'], 'cxcywh', 'xyxy') 137 | assert len(target['boxes']) == num_boxes 138 | 139 | return image, target 140 | 141 | 142 | if __name__ == "__main__": 143 | import matplotlib.pyplot as plt 144 | from matplotlib.pyplot import Circle 145 | 146 | dataset = MVTecScrews('/media/kk/Data/dataset/image/MVTec-Screws', 'test') 147 | print(len(dataset)) 148 | 149 | for i in range(3): 150 | image, target = dataset[i] 151 | ax:plt.Axes = plt.subplot() 152 | img_arr = np.array(image, dtype=np.uint8) 153 | mask = target['masks'].numpy() 154 | img_arr[..., 0][mask[0] == 1] = 0 155 | ax.imshow(img_arr) 156 | for bnd_id, bbox in enumerate(target['boxes']): 157 | is_begin = bnd_id % 2 == 0 158 | x, y, diameter, _ = bbox 159 | color = 'red' if is_begin else 'blue' 160 | ax.add_patch(Circle((x, y), diameter * 0.5, color=color, fill=False, linewidth=1)) 161 | ax.add_patch(Circle((x, y), 10, color=color, fill=True)) 162 | plt.show() 163 | -------------------------------------------------------------------------------- /vklearn/datasets/ocr_instruct.py: -------------------------------------------------------------------------------- 1 | from typing import List, Callable, Tuple, Sequence 2 | from glob import glob 3 | import os 4 | import json 5 | import random 6 | 7 | from torchvision.datasets.vision import VisionDataset 8 | from torch.utils.data import Dataset 9 | from torch import Tensor 10 | import torch 11 | from PIL import Image 12 | import cv2 as cv 13 | import numpy as np 14 | 15 | from .ocr_synthesizer import OCRSynthesizer 16 | 17 | from tqdm import tqdm 18 | 19 | 20 | class OCRInstruct(Dataset): 21 | '''`OCR Instruct dataset. 22 | 23 | Args: 24 | subject_datas: A list of sbuject datas. 25 | synthesizer: The OCRSynthesizer object. 26 | synthesis_rate: The adoption rate of the synthesis data. 27 | transform: A function/transform that takes in a PIL image 28 | and returns a transformed version. E.g, ``transforms.PILToTensor`` 29 | target_transform: A function/transform that takes in the 30 | target and transforms it. 31 | transforms: A function/transform that takes input sample and its target as entry 32 | and returns a transformed version. 
33 | ''' 34 | 35 | def __init__( 36 | self, 37 | subject_datas: List[str | Tuple[str, Callable] | VisionDataset], 38 | synthesizer: OCRSynthesizer, 39 | synthesis_rate: float=1., 40 | transforms: Callable | None=None, 41 | transform: Callable | None=None, 42 | target_transform: Callable | None=None, 43 | ): 44 | 45 | self.subjects = [] 46 | for data in subject_datas: 47 | if isinstance(data, str): 48 | self.subjects.append(InstructSubject( 49 | data, 50 | synthesizer.characters, 51 | synthesizer._reverse_rate, 52 | transforms=transforms, 53 | transform=transform, 54 | target_transform=target_transform, 55 | )) 56 | elif isinstance(data, Sequence) and isinstance(data[0], str): 57 | self.subjects.append(InstructSubject( 58 | data[0], 59 | synthesizer.characters, 60 | synthesizer._reverse_rate, 61 | transform=data[1], 62 | )) 63 | else: 64 | self.subjects.append(data) 65 | 66 | self.synthesizer = synthesizer 67 | self.subject_total = sum([len(subject) for subject in self.subjects]) 68 | 69 | self.synthesis_limit = int(len(synthesizer) * synthesis_rate) 70 | 71 | def __repr__(self): 72 | info = f'Dataset {self.__class__.__name__}\n' 73 | info += f'\tNumber of datapoints: {len(self)}\n' 74 | info += f'Synthesizer: {self.synthesizer}\n' 75 | info += 'Subjects:' 76 | for subject in self.subjects: 77 | info += f'\n* {str(subject)}' 78 | return info 79 | 80 | def __len__(self): 81 | return self.subject_total + self.synthesis_limit 82 | 83 | def __getitem__(self, idx:int): 84 | if idx >= self.subject_total: 85 | synthesis_size = len(self.synthesizer) 86 | if self.synthesis_limit == synthesis_size: 87 | return self.synthesizer[idx - self.subject_total] 88 | return self.synthesizer[random.randrange(synthesis_size)] 89 | begin = 0 90 | for subject in self.subjects: 91 | end = len(subject) + begin 92 | if idx < end: break 93 | begin = end 94 | return subject[idx - begin] 95 | 96 | 97 | class InstructSubject(VisionDataset): 98 | '''`Instruct Subject dataset. 99 | 100 | Args: 101 | root: Root directory where images are downloaded to. 102 | characters: A list of the characters. 103 | reverse_rate: The rate of images reversed. 104 | transform: A function/transform that takes in a PIL image 105 | and returns a transformed version. E.g, ``transforms.PILToTensor`` 106 | target_transform: A function/transform that takes in the 107 | target and transforms it. 108 | transforms: A function/transform that takes input sample and its target as entry 109 | and returns a transformed version. 
110 | ''' 111 | 112 | def __init__( 113 | self, 114 | root: str, 115 | characters: List[str], 116 | reverse_rate: float=0., 117 | transforms: Callable | None=None, 118 | transform: Callable | None=None, 119 | target_transform: Callable | None=None, 120 | ): 121 | 122 | super().__init__(root, transforms, transform, target_transform) 123 | 124 | self._reverse_rate = reverse_rate 125 | self.characters = characters 126 | self.char2index = {c: i for i, c in enumerate(self.characters)} 127 | 128 | print('preload subject dataset...') 129 | image_paths = sorted(glob(os.path.join(root, 'images/*.jpg'))) 130 | with open(os.path.join(root, 'labels.json')) as f: 131 | labels = json.load(f) 132 | self.items = [] 133 | for path in tqdm(image_paths, ncols=80): 134 | name = os.path.splitext(os.path.basename(path))[0] 135 | label = labels[name][0] 136 | if label['illegibility']: continue 137 | text = label['transcription'].strip() 138 | if not text: continue 139 | has_lack = False 140 | for c in text: 141 | if c in self.characters: continue 142 | has_lack = True 143 | break 144 | if has_lack: continue 145 | points = label['points'] 146 | self.items.append((path, text, points)) 147 | 148 | def __len__(self): 149 | return len(self.items) 150 | 151 | def _load_image( 152 | self, 153 | path: str, 154 | points: List[List[int]], 155 | ) -> Image.Image: 156 | 157 | rect = cv.minAreaRect(np.intp(points)) 158 | bbox = cv.boxPoints(rect) 159 | if bbox[1, 0] < bbox[3, 0]: 160 | width, height = rect[1] 161 | else: 162 | height, width = rect[1] 163 | bbox = bbox[[3, 0, 1, 2]] 164 | dst_pts = np.array([ 165 | [0, height], 166 | [0, 0], 167 | [width, 0], 168 | [width, height]], dtype=np.float32) 169 | M = cv.getPerspectiveTransform(bbox, dst_pts) 170 | width = int(round(width)) 171 | height = int(round(height)) 172 | im_arr = cv.imread(path) 173 | im_arr = cv.warpPerspective(im_arr, M, (width, height), flags=cv.INTER_AREA) 174 | if len(im_arr.shape) == 2: 175 | im_arr = cv.cvtColor(im_arr, cv.COLOR_GRAY2RGB) 176 | elif len(im_arr.shape) == 3: 177 | im_arr = cv.cvtColor(im_arr, cv.COLOR_BGR2RGB) 178 | image = Image.fromarray(im_arr) 179 | 180 | # if image.size[0] < image.size[1]: 181 | # image = image.transpose(Image.Transpose.ROTATE_90) 182 | return image 183 | 184 | def __getitem__(self, idx:int) -> Tuple[Image.Image | Tensor, Tensor, int]: 185 | path, text, points = self.items[idx] 186 | image = self._load_image(path, points) 187 | if (len(text) > 1) and (image.size[0] < image.size[1]): 188 | image = image.transpose(Image.Transpose.ROTATE_90) 189 | 190 | reverse = int(self._reverse_rate > max(1e-7, random.random())) 191 | if reverse: 192 | image = image.transpose(Image.Transpose.ROTATE_180) 193 | text = text[::-1] 194 | 195 | if self.transform is not None: 196 | image = self.transform(image) 197 | 198 | target = torch.LongTensor([self.char2index[c] for c in text]) 199 | return image, target, reverse 200 | -------------------------------------------------------------------------------- /vklearn/datasets/ocr_printing.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Tuple, Any 2 | from glob import glob 3 | import os 4 | import math 5 | import random 6 | 7 | from torch import Tensor 8 | from torchvision.datasets.vision import VisionDataset 9 | 10 | from PIL import Image, ImageFont, ImageDraw 11 | from fontTools.ttLib import TTFont 12 | import numpy as np 13 | 14 | from tqdm import tqdm 15 | 16 | from .characters import CHARACTERS_DICT 17 | 18 | 19 | 
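# --- Illustration (added note; not part of the original module) ---
# The Font class below filters its character set by probing the cmap tables of
# a TrueType font via fontTools (imported as TTFont above). A standalone sketch
# of that same coverage check, assuming only a local font file path supplied by
# the caller, could look like this:

def _sketch_font_supports_char(font_path: str, char: str) -> bool:
    # Returns True if any cmap subtable of the font maps this codepoint.
    font = TTFont(font_path)
    for table in font['cmap'].tables:
        if ord(char) in table.cmap:
            return True
    return False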
class Font: 20 | '''Printing Character Font 21 | 22 | Args: 23 | path: The font file path. 24 | size: Font size. 25 | chars: A list of characters. 26 | ''' 27 | 28 | def __init__( 29 | self, 30 | path: str, 31 | size: int, 32 | chars: list, 33 | ): 34 | 35 | self._name = os.path.basename(path) 36 | self._font = ImageFont.truetype(path, size) 37 | self.valid = self._font_checks(path, chars) 38 | self._lengths = dict() 39 | 40 | def _has_char(self, font:TTFont, char:str) -> bool: 41 | for table in font['cmap'].tables: 42 | if ord(char) in table.cmap.keys(): 43 | return True 44 | return False 45 | 46 | def _font_checks(self, font_path:str, chars:List[str]) -> List[str]: 47 | font = TTFont(font_path) 48 | valid = set(chars) 49 | for char in chars: 50 | if not self._has_char(font, char): 51 | valid.remove(char) 52 | return sorted(valid) 53 | 54 | def random_lack( 55 | self, 56 | text: str, 57 | chars: List[str] | None=None, 58 | ) -> str: 59 | 60 | chars = chars or self.valid 61 | _text = '' 62 | for c in text: 63 | if c == ' ' or c in self.valid: _text += c 64 | else: _text += random.choice(chars) 65 | return _text 66 | 67 | def text2image_with_anchor( 68 | self, 69 | text: str, 70 | direction: str='ltr', 71 | ) -> Tuple[Image.Image, Tuple[int, int]]: 72 | 73 | l, t, r, d = self._font.getbbox(text, direction=direction) 74 | anchor = (-l, -t) 75 | image = Image.new('L', (r - l, d - t), color=0) 76 | draw = ImageDraw.Draw(image) 77 | draw.text(anchor, text, font=self._font, fill=255, align='center', direction=direction) 78 | if direction == 'ttb': 79 | image = image.transpose(Image.Transpose.ROTATE_90) 80 | anchor = anchor[::-1] 81 | return image, anchor 82 | 83 | def get_length( 84 | self, 85 | text: str, 86 | direction: str='ltr', 87 | ) -> int: 88 | 89 | length = self._lengths.get(text) 90 | if length is None: 91 | length = self._font.getlength(text, direction=direction) 92 | self._lengths[text] = length 93 | return length 94 | 95 | def text2image_with_xyxys( 96 | self, 97 | text: str, 98 | direction: str='ltr', 99 | ) -> Tuple[Image.Image, List[Any]]: 100 | 101 | image, anchor = self.text2image_with_anchor(text, direction=direction) 102 | xyxys = [] 103 | lengths = [self.get_length(c, direction=direction) for c in text] 104 | bitmap = np.asarray(image, dtype=np.uint8) 105 | 106 | left = 0 107 | for i, length in enumerate(lengths): 108 | right = left + length 109 | if i == 0: right += anchor[0] 110 | 111 | col_l = math.floor(left) 112 | col_r = math.ceil(right) 113 | if text[i] != ' ': 114 | submap = bitmap[:, col_l:col_r] 115 | shrink = max(1, submap.shape[1] // 9) 116 | r0 = 0 117 | c0 = 0 118 | r1 = submap.shape[0] - 1 119 | c1 = submap.shape[1] - 1 120 | for r0 in range(submap.shape[0]): 121 | if submap[r0, shrink:-shrink].sum() != 0: break 122 | for r1 in range(submap.shape[0] - 1, 0, -1): 123 | if submap[r1, shrink:-shrink].sum() != 0: break 124 | for c0 in range(submap.shape[1]): 125 | if submap[:, c0].sum() != 0: break 126 | for c1 in range(submap.shape[1] - 1, 0, -1): 127 | if submap[:, c1].sum() != 0: break 128 | x0 = col_l + c0 129 | x1 = col_l + c1 130 | y0 = r0 131 | y1 = r1 132 | if x1 - x0 < 2: 133 | x0 -= 1 134 | x1 = x0 + 2 135 | if y1 - y0 < 2: 136 | y0 -= 1 137 | y1 = y0 + 2 138 | else: 139 | x0 = col_l 140 | x1 = col_r 141 | y0 = 0 142 | y1 = bitmap.shape[0] 143 | xyxys.append([x0, y0, x1, y1]) 144 | 145 | left = right 146 | return image, xyxys 147 | 148 | 149 | class PrintingCharacter(VisionDataset): 150 | '''Printing Character Dataset 151 | 152 | Args: 153 | root: Root 
directory of the dataset. 154 | characters_file: The path of character-set file. 155 | fontsize: The size of font, default is 48. 156 | limit: Limit the number of font files to be loaded. 157 | transform: A function/transform that takes in a PIL image 158 | and returns a transformed version. E.g, ``transforms.PILToTensor`` 159 | target_transform: A function/transform that takes in the 160 | target and transforms it. 161 | transforms: A function/transform that takes input sample and its target as entry 162 | and returns a transformed version. 163 | ''' 164 | 165 | CHARACTERS_FILE = CHARACTERS_DICT['en_sym'] 166 | 167 | def __init__( 168 | self, 169 | root: str, 170 | characters_file: str | None=None, 171 | fontsize: int=48, 172 | limit: int=0, 173 | transforms: Callable | None=None, 174 | transform: Callable | None=None, 175 | target_transform: Callable | None=None, 176 | ): 177 | 178 | super().__init__(root, transforms, transform, target_transform) 179 | characters_file = characters_file or self.CHARACTERS_FILE 180 | with open(characters_file, encoding='utf-8') as f: 181 | self.characters = sorted(set(f.read().replace(' ', '').replace('\n', ''))) 182 | 183 | font_paths = glob(os.path.join(root, '*/*')) 184 | if limit > 0: font_paths = font_paths[:limit] 185 | self.fonts = [] 186 | print('loading fonts...') 187 | for font_path in tqdm(sorted(font_paths), ncols=80): 188 | try: 189 | font = Font(font_path, fontsize, self.characters) 190 | except OSError as e: 191 | print(e, font_path) 192 | continue 193 | self.fonts.append(font) 194 | 195 | def __len__(self): 196 | return len(self.fonts) * len(self.characters) 197 | 198 | def __getitem__(self, idx:int) -> Tuple[Image.Image | Tensor, int]: 199 | font_id = idx % len(self.fonts) 200 | font = self.fonts[font_id] 201 | character_id = idx // len(self.fonts) % len(font.valid) 202 | character = font.valid[character_id] 203 | image = font.text2image_with_anchor(character)[0] 204 | if self.transform is not None: 205 | image = self.transform(image) 206 | return image, character_id 207 | -------------------------------------------------------------------------------- /vklearn/datasets/ocr_synthesizer.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Any, Tuple 2 | import random 3 | 4 | from torch import Tensor 5 | from torchvision.datasets.vision import VisionDataset 6 | import torch 7 | 8 | from PIL import Image 9 | 10 | from .ocr_printing import PrintingCharacter 11 | from .hwdb_gnt import HWDBGnt 12 | from .hwdb_pot import HWDBPot 13 | 14 | 15 | class OCRSynthesizer(VisionDataset): 16 | '''OCR-Synthesizer Dataset 17 | 18 | Args: 19 | root: Root directory of the dataset. 20 | fonts_dir: The directory of font files. 21 | characters_file: The path of character-set file. 22 | split: The dataset split, supports `"train"` (default) or `"test"`. 23 | transform: A function/transform that takes in a PIL image 24 | and returns a transformed version. E.g, ``transforms.PILToTensor`` 25 | target_transform: A function/transform that takes in the 26 | target and transforms it. 27 | transforms: A function/transform that takes input sample and its target as entry 28 | and returns a transformed version. 
29 | ''' 30 | 31 | def __init__( 32 | self, 33 | root: str, 34 | fonts_dir: str, 35 | characters_file: str | None=None, 36 | split: str='train', 37 | transforms: Callable | None=None, 38 | transform: Callable | None=None, 39 | target_transform: Callable | None=None, 40 | **kwargs, 41 | ): 42 | 43 | assert split in ('train', 'test') 44 | super().__init__(root, transforms, transform, target_transform) 45 | 46 | self._printing = PrintingCharacter( 47 | fonts_dir, 48 | characters_file, 49 | fontsize=kwargs.get('printing_fontsize', 48), 50 | limit=kwargs.get('printing_limit', 0)) 51 | self._hwdb_rate = kwargs.get('hwdb_rate', 0.) 52 | self._hwdb_list = [] 53 | if self._hwdb_rate > 0.: 54 | self._hwdb_list = [ 55 | HWDBGnt( 56 | kwargs.get('hwdb_gnt_dir'), 57 | characters_file, 58 | split=split, 59 | limit=kwargs.get('hwdb_limit', 0)), 60 | HWDBPot( 61 | kwargs.get('hwdb_pot_dir'), 62 | characters_file, 63 | split=split, 64 | limit=kwargs.get('hwdb_limit', 0))] 65 | 66 | self.corpus = [] 67 | text_length = kwargs.get('text_length', 10) 68 | print('loading corpus...') 69 | with open(root, encoding='utf-8') as f: 70 | while True: 71 | text = f.read(text_length) 72 | if len(text) < text_length: break 73 | text = text.strip().replace('\n', ' ') 74 | # k = text_length - len(text) 75 | # for _ in range(k): 76 | # text += random.choice(self._printing.characters) 77 | # assert len(text) == text_length 78 | # self.corpus.append(text) 79 | if text: self.corpus.append(text) 80 | 81 | self.characters = ['', ' '] + self._printing.characters 82 | self.char2index = {c: i for i, c in enumerate(self.characters)} 83 | self._use_debug = kwargs.get('use_debug', False) 84 | self._reverse_rate = kwargs.get('reverse_rate', 0.) 85 | self._letter_spacing = kwargs.get('letter_spacing', 0.) 
86 | self._layout_direction = kwargs.get('layout_direction', 'ltr') 87 | 88 | def __len__(self): 89 | return len(self.corpus) 90 | 91 | def _render_handwriting( 92 | self, 93 | text: str, 94 | image: Image.Image, 95 | xyxys: List[Any], 96 | direction: str='ltb', 97 | ) -> Image.Image: 98 | 99 | output = Image.new('L', image.size) 100 | x_hwdb = random.choice(self._hwdb_list) 101 | for i, (x0, y0, x1, y1) in enumerate(xyxys): 102 | character = text[i] 103 | if character == ' ': continue 104 | index = x_hwdb._char2index.get(character) 105 | anchor = [x0, y0] 106 | if index is not None: 107 | char_image, _ = x_hwdb[random.choice(index)] 108 | if direction == 'ttb': 109 | char_image = char_image.transpose(Image.Transpose.ROTATE_90) 110 | img_w, img_h = char_image.size 111 | box_w, box_h = x1 - x0, y1 - y0 112 | if box_w < box_h: 113 | dst_h = box_h 114 | dst_w = max(1, min(box_w, round(dst_h / img_h * img_w))) 115 | anchor[0] += (box_w - dst_w) // 2 116 | else: 117 | dst_w = box_w 118 | dst_h = max(1, min(box_h, round(dst_w / img_w * img_h))) 119 | anchor[1] += (box_h - dst_h) // 2 120 | char_image = char_image.resize( 121 | (dst_w, dst_h), resample=Image.Resampling.BILINEAR) 122 | else: 123 | char_image = image.crop((x0, y0, x1, y1)) 124 | output.paste(char_image, anchor) 125 | return output 126 | 127 | def update_letter_spacing( 128 | self, 129 | image: Image.Image, 130 | xyxys: List[Any], 131 | spacing: float, 132 | ) -> Image.Image: 133 | 134 | if spacing == 0: return image 135 | if len(xyxys) < 2: return image 136 | src_w, src_h = image.size 137 | if spacing > 0: 138 | offset_base = sum([ 139 | r - l for l, _, r, _ in xyxys]) / len(xyxys) 140 | else: 141 | offset_base = min([ 142 | xyxys[i + 1][0] - xyxys[i][0] for i in range(len(xyxys) - 1)]) 143 | offset = int(offset_base * spacing) 144 | if offset == 0: return image 145 | exp_w = (len(xyxys) - 1) * offset + src_w 146 | expanded = Image.new('L', (exp_w, src_h), color=0) 147 | exp_size = 0 148 | for l, _, r, _ in xyxys: 149 | sub_image = image.crop((l, 0, r, src_h)) 150 | expanded.paste(sub_image, (l + exp_size, 0, r + exp_size, src_h), mask=sub_image) 151 | exp_size += offset 152 | return expanded 153 | 154 | def __getitem__(self, idx:int) -> Tuple[Image.Image | Tensor, Tensor, int]: 155 | text = self.corpus[idx] 156 | 157 | # font = self._printing.fonts[idx % len(self._printing.fonts)] 158 | font = random.choice(self._printing.fonts) 159 | text = font.random_lack(text, ['#']) 160 | 161 | printing, xyxys = font.text2image_with_xyxys(text, direction=self._layout_direction) 162 | if min(printing.size) == 0: 163 | return self.__getitem__((idx + 1) % self.__len__()) 164 | 165 | applied_hwdb = self._hwdb_rate > max(1e-7, random.random()) 166 | if applied_hwdb: 167 | image = self._render_handwriting(text, printing, xyxys, direction=self._layout_direction) 168 | else: 169 | image = printing 170 | 171 | letter_spacing = random.uniform(-self._letter_spacing, self._letter_spacing) 172 | # if letter_spacing > -0.25: 173 | # update_size = int(letter_spacing * image.size[1]) 174 | # image = self.update_letter_spacing(image, xyxys, update_size) 175 | image = self.update_letter_spacing(image, xyxys, letter_spacing) 176 | 177 | reverse = int(self._reverse_rate > max(1e-7, random.random())) 178 | if reverse: 179 | image = image.transpose(Image.Transpose.ROTATE_180) 180 | text = text[::-1] 181 | 182 | if self.transform is not None: 183 | image.__ONLY_PRINT__ = not applied_hwdb 184 | image = self.transform(image) 185 | 186 | target = 
torch.LongTensor([self.char2index[c] for c in text]) 187 | 188 | if self._use_debug: 189 | return printing, image, target, reverse 190 | return image, target, reverse 191 | -------------------------------------------------------------------------------- /vklearn/datasets/oxford_iiit_pet.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Sequence, Tuple, Dict 2 | import os 3 | import os.path 4 | import pathlib 5 | import xml.etree.ElementTree as ET 6 | 7 | from PIL import Image 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torchvision.datasets.utils import download_and_extract_archive, verify_str_arg 12 | from torchvision.datasets.vision import VisionDataset 13 | from torchvision import tv_tensors 14 | 15 | 16 | class OxfordIIITPet(VisionDataset): 17 | '''`Oxford-IIIT Pet Dataset `_. 18 | 19 | Args: 20 | root: Root directory of the dataset. 21 | split: The dataset split, supports `"trainval"` (default) or `"test"`. 22 | target_types: Types of target to use. Can be `category` (default) or 23 | `segmentation`. Can also be a list to output a tuple with all specified target types. The types represent: 24 | 25 | - `category` (int): Label for one of the 37 pet categories. 26 | - `binary-category` (int): Binary label for cat or dog. 27 | - `segmentation` (PIL image): Segmentation trimap of the image. 28 | - `detection` (PIL image, Labels & BoundingBoxes): Detection annotation. 29 | 30 | If empty, `None` will be returned as target. 31 | 32 | transform: A function/transform that takes in a PIL image and returns a transformed 33 | version. E.g, `transforms.RandomCrop`. 34 | target_transform: A function/transform that takes in the target and transforms it. 35 | download: If True, downloads the dataset from the internet and puts it into 36 | `root/oxford-iiit-pet`. If dataset is already downloaded, it is not downloaded again. 37 | ''' 38 | 39 | _RESOURCES = ( 40 | ('https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz', '5c4f3ee8e5d25df40f4fd59a7f44e54c'), 41 | ('https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz', '95a8c909bbe2e81eed6a22bccdf3f68f'), 42 | ) 43 | _VALID_TARGET_TYPES = ('category', 'binary-category', 'segmentation', 'detection') 44 | 45 | def __init__( 46 | self, 47 | root: str | pathlib.Path, 48 | split: str='trainval', 49 | target_types: Sequence[str] | str='category', 50 | transforms: Callable | None=None, 51 | transform: Callable | None=None, 52 | target_transform: Callable | None=None, 53 | download: bool=False, 54 | ): 55 | 56 | if split == 'train': split = 'trainval' 57 | if split == 'val': split = 'test' 58 | 59 | self._split = verify_str_arg(split, 'split', ('trainval', 'test')) 60 | if isinstance(target_types, str): 61 | target_types = [target_types] 62 | self._target_types = [ 63 | verify_str_arg(target_type, 'target_types', self._VALID_TARGET_TYPES) 64 | for target_type in target_types] 65 | 66 | super().__init__(root, transforms=transforms, transform=transform, target_transform=target_transform) 67 | self._base_folder = pathlib.Path(self.root) # / 'oxford-iiit-pet' 68 | self._images_folder = self._base_folder / 'images' 69 | self._anns_folder = self._base_folder / 'annotations' 70 | self._segs_folder = self._anns_folder / 'trimaps' 71 | self._objs_folder = self._anns_folder / 'xmls' 72 | 73 | if download: 74 | self._download() 75 | 76 | if not self._check_exists(): 77 | raise RuntimeError( 78 | 'Dataset not found. 
You can use download=True to download it') 79 | 80 | image_ids = [] 81 | self._labels = [] 82 | self._bin_labels = [] 83 | with open(self._anns_folder / f'{self._split}.txt') as file: 84 | for line in file: 85 | image_id, label, bin_label, _ = line.strip().split() 86 | image_ids.append(image_id) 87 | self._labels.append(int(label) - 1) 88 | self._bin_labels.append(int(bin_label) - 1) 89 | 90 | self.bin_classes = ['Cat', 'Dog'] 91 | self.multi_classes = [ 92 | ' '.join(part.title() for part in raw_cls.split('_')) 93 | for raw_cls, _ in sorted( 94 | {(image_id.rsplit('_', 1)[0], label) for image_id, label in zip(image_ids, self._labels)}, 95 | key=lambda image_id_and_label: image_id_and_label[1], 96 | ) 97 | ] 98 | self.bin_class_to_idx = dict(zip(self.bin_classes, range(len(self.bin_classes)))) 99 | self.multi_class_to_idx = dict(zip(self.multi_classes, range(len(self.multi_classes)))) 100 | if len(target_types) == 1 and 'category' in target_types: 101 | self.classes = self.multi_classes 102 | else: 103 | self.classes = self.bin_classes 104 | 105 | self.tag_to_ix = dict(zip(['xmin', 'ymin', 'xmax', 'ymax'], range(4))) 106 | 107 | self._images = [self._images_folder / f'{image_id}.jpg' for image_id in image_ids] 108 | self._segs = [self._segs_folder / f'{image_id}.png' for image_id in image_ids] 109 | if 'detection' not in target_types: 110 | return 111 | _images = [] 112 | _segs = [] 113 | _labels = [] 114 | _bin_labels = [] 115 | self._objs = [] 116 | for i, image_id in enumerate(image_ids): 117 | xml_path = self._objs_folder / f'{image_id}.xml' 118 | if not os.path.exists(xml_path): continue 119 | _images.append(self._images[i]) 120 | _segs.append(self._segs[i]) 121 | _labels.append(self._labels[i]) 122 | _bin_labels.append(self._bin_labels[i]) 123 | self._objs.append(xml_path) 124 | self._images = _images 125 | self._segs = _segs 126 | self._labels = _labels 127 | self._bin_labels = _bin_labels 128 | 129 | def __len__(self) -> int: 130 | return len(self._images) 131 | 132 | def _load_obj(self, idx:int, image_size: Tuple[int, int]) -> Dict[str, Any]: 133 | path = self._objs[idx] 134 | with open(path) as f: 135 | tree = ET.parse(f) 136 | root = tree.getroot() 137 | for cursor in root: 138 | if cursor.tag == 'object': break 139 | for bndbox in cursor: 140 | if bndbox.tag == 'bndbox': break 141 | 142 | bbox = [0] * 4 143 | for item in bndbox: 144 | bbox_ix = self.tag_to_ix[item.tag] 145 | bbox[bbox_ix] = float(item.text) 146 | 147 | return dict( 148 | # labels=torch.LongTensor([self._labels[idx]]), 149 | labels=torch.LongTensor([self._bin_labels[idx]]), 150 | boxes=tv_tensors.BoundingBoxes( 151 | [bbox], 152 | format='XYXY', 153 | canvas_size=(image_size[1], image_size[0]), 154 | ) 155 | ) 156 | 157 | def _load_mask(self, idx:int) -> tv_tensors.Mask: 158 | image = Image.open(self._segs[idx]) 159 | mask = tv_tensors.Mask(image, dtype=torch.long) 160 | mask[mask == 2] = 0 161 | mask[mask == 1] = 2 162 | mask[mask == 3] = 1 163 | return tv_tensors.Mask( 164 | F.one_hot(mask, 3).transpose(0, 3).squeeze(-1)) 165 | 166 | def __getitem__(self, idx: int) -> Tuple[Any, Any]: 167 | image = Image.open(self._images[idx]).convert('RGB') 168 | 169 | target: Any = [] 170 | for target_type in self._target_types: 171 | if target_type == 'category': 172 | target.append(self._labels[idx]) 173 | elif target_type == 'binary-category': 174 | target.append(self._bin_labels[idx]) 175 | elif target_type == 'segmentation': 176 | target.append(self._load_mask(idx)) 177 | elif target_type == 'detection': 178 | 
target.append(self._load_obj(idx, image.size)) 179 | 180 | if not target: 181 | target = None 182 | elif len(target) == 1: 183 | target = target[0] 184 | else: 185 | target = tuple(target) 186 | 187 | if self.transforms: 188 | image, target = self.transforms(image, target) 189 | 190 | return image, target 191 | 192 | def _check_exists(self) -> bool: 193 | for folder in (self._images_folder, self._anns_folder): 194 | if not (os.path.exists(folder) and os.path.isdir(folder)): 195 | return False 196 | return True 197 | 198 | def _download(self) -> None: 199 | if self._check_exists(): return 200 | 201 | for url, md5 in self._RESOURCES: 202 | download_and_extract_archive( 203 | url, download_root=str(self._base_folder), md5=md5) 204 | -------------------------------------------------------------------------------- /vklearn/datasets/places365.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Tuple 2 | import os 3 | 4 | from PIL import Image 5 | from torchvision.datasets.vision import VisionDataset 6 | 7 | 8 | class Places365(VisionDataset): 9 | '''`Places365 `_ classification dataset. 10 | 11 | Args: 12 | root: Root directory where images are downloaded to. 13 | split: The dataset split, supports `"train"`(default), `"val"`. 14 | transform: A function/transform that takes in a PIL image 15 | and returns a transformed version. E.g, `transforms.PILToTensor` 16 | target_transform: A function/transform that takes in the 17 | target and transforms it. 18 | transforms: A function/transform that takes input sample and its target as entry 19 | and returns a transformed version. 20 | ''' 21 | 22 | def __init__( 23 | self, 24 | root: str, 25 | split: str='train', 26 | transform: Callable | None=None, 27 | target_transform: Callable | None=None, 28 | transforms: Callable | None=None, 29 | ): 30 | 31 | super().__init__(root, transforms, transform, target_transform) 32 | assert split in ['train', 'val'] 33 | self.dataset_dir = root 34 | self.classes = sorted(os.listdir(os.path.join(root, split))) 35 | paths_file = os.path.join(root, f'{split}.txt') 36 | self.paths = [] 37 | with open(paths_file) as f: 38 | for path in f: 39 | path = path.strip() 40 | if not path: continue 41 | self.paths.append(path) 42 | 43 | def __len__(self) -> int: 44 | return len(self.paths) 45 | 46 | def _load_image(self, path:str) -> Image.Image: 47 | return Image.open( 48 | os.path.join(self.dataset_dir, path)).convert('RGB') 49 | 50 | def _load_label(self, path:str) -> int: 51 | category = os.path.basename(os.path.dirname(path)) 52 | return self.classes.index(category) 53 | 54 | def __getitem__(self, index:int) -> Tuple[Any, Any]: 55 | if not isinstance(index, int): 56 | raise ValueError(f'Index must be of type integer, got {type(index)} instead.') 57 | 58 | path = self.paths[index] 59 | image = self._load_image(path) 60 | target = self._load_label(path) 61 | 62 | if self.transforms is not None: 63 | image, target = self.transforms(image, target) 64 | 65 | return image, target 66 | -------------------------------------------------------------------------------- /vklearn/datasets/plain_bbox.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, Tuple, Dict 2 | import os.path 3 | from glob import glob 4 | 5 | from PIL import Image 6 | 7 | import torch 8 | from torchvision.datasets.vision import VisionDataset 9 | from torchvision import tv_tensors 10 | 11 | 12 | class 
PlainBBox(VisionDataset): 13 | '''`Plain bbox annotated format common dataset. 14 | 15 | Args: 16 | root: Root directory where images are downloaded to. 17 | split: The dataset split, supports ``""`` (default), ``"train"``, ``"valid"`` or ``"test"``. 18 | transform: A function/transform that takes in a PIL image 19 | and returns a transformed version. E.g, ``transforms.PILToTensor`` 20 | target_transform: A function/transform that takes in the 21 | target and transforms it. 22 | transforms: A function/transform that takes input sample and its target as entry 23 | and returns a transformed version. 24 | ''' 25 | 26 | def __init__( 27 | self, 28 | root: str, 29 | split: str='', 30 | transform: Callable | None=None, 31 | target_transform: Callable | None=None, 32 | transforms: Callable | None=None, 33 | ): 34 | super().__init__(root, transforms, transform, target_transform) 35 | assert split in ['', 'train', 'valid', 'test'] 36 | self.dataset_dir = os.path.abspath(os.path.join(root, split)) 37 | assert os.path.isdir(self.dataset_dir) 38 | 39 | self.label_paths = sorted(glob(os.path.join(self.dataset_dir, 'labels/*.txt'))) 40 | 41 | def __len__(self) -> int: 42 | return len(self.label_paths) 43 | 44 | def _load_image(self, path:str) -> Image.Image: 45 | return Image.open( 46 | os.path.join(self.dataset_dir, 'images', path)).convert('RGB') 47 | 48 | def _load_anns(self, id:int) -> Tuple[List[Any], str]: 49 | label_path = self.label_paths[id] 50 | imagePath = os.path.splitext(os.path.basename(label_path))[0] + '.jpg' 51 | anns = [] 52 | with open(label_path) as f: 53 | for line in f: 54 | line = line.strip() 55 | if not line: continue 56 | items = line.split() 57 | anns.append(dict( 58 | label=int(items[0]), 59 | bndbox=list(map(float, items[1:])), 60 | )) 61 | return anns, imagePath 62 | 63 | def _bndbox2xyxy( 64 | self, 65 | cxcywh: List[float], 66 | image_size: Tuple[int, int], 67 | ) -> List[float]: 68 | 69 | cxcywh = cxcywh 70 | if len(cxcywh) == 8: 71 | xs = [cxcywh[i] for i in range(0, 7, 2)] 72 | ys = [cxcywh[i] for i in range(1, 8, 2)] 73 | cx = sum(xs) / 4 74 | cy = sum(ys) / 4 75 | w = max(xs) - min(xs) 76 | h = max(ys) - min(ys) 77 | else: 78 | cx, cy, w, h = cxcywh 79 | x1 = (cx - w / 2) * image_size[0] 80 | y1 = (cy - h / 2) * image_size[1] 81 | x2 = (cx + w / 2) * image_size[0] 82 | y2 = (cy + h / 2) * image_size[1] 83 | return [x1, y1, x2, y2] 84 | 85 | def _format_anns( 86 | self, 87 | anns: List[Any], 88 | image_size: Tuple[int, int], 89 | ) -> Dict[str, Any]: 90 | boxes = tv_tensors.BoundingBoxes( 91 | [self._bndbox2xyxy(ann['bndbox'], image_size) for ann in anns], 92 | format='XYXY', 93 | canvas_size=(image_size[1], image_size[0]), 94 | ) 95 | labels = torch.LongTensor([ann['label'] for ann in anns]) 96 | return dict(boxes=boxes, labels=labels) 97 | 98 | def __getitem__(self, index:int) -> Tuple[Any, Any]: 99 | 100 | if not isinstance(index, int): 101 | raise ValueError(f'Index must be of type integer, got {type(index)} instead.') 102 | 103 | anns, imagePath = self._load_anns(index) 104 | if len(anns) == 0: 105 | return self.__getitem__((index + 1) % self.__len__()) 106 | 107 | image = self._load_image(imagePath) 108 | target = self._format_anns(anns, image.size) 109 | 110 | if self.transforms is not None: 111 | image, target = self.transforms(image, target) 112 | 113 | return image, target 114 | -------------------------------------------------------------------------------- /vklearn/datasets/publaynet.py: 
-------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, Tuple, Dict 2 | import os.path 3 | 4 | from PIL import Image 5 | 6 | import torch 7 | from torchvision.datasets.vision import VisionDataset 8 | from torchvision import tv_tensors 9 | 10 | from tqdm import tqdm 11 | 12 | 13 | class PubLayNetDet(VisionDataset): 14 | '''`PubLayNet Detection Dataset. 15 | 16 | It requires the `COCO API to be installed `_. 17 | 18 | Args: 19 | root: Root directory where images are downloaded to. 20 | annFile: Path to json annotation file. 21 | transform: A function/transform that takes in a PIL image 22 | and returns a transformed version. E.g, ``transforms.PILToTensor`` 23 | target_transform: A function/transform that takes in the 24 | target and transforms it. 25 | transforms: A function/transform that takes input sample and its target as entry 26 | and returns a transformed version. 27 | ''' 28 | 29 | def __init__( 30 | self, 31 | root: str, 32 | annFile: str, 33 | transform: Callable | None=None, 34 | target_transform: Callable | None=None, 35 | transforms: Callable | None=None, 36 | ): 37 | super().__init__(root, transforms, transform, target_transform) 38 | from pycocotools.coco import COCO 39 | 40 | self.coco = COCO(annFile) 41 | 42 | print('discard redundant annotations.') 43 | ids = list(sorted(self.coco.imgs.keys())) 44 | paths = set(os.listdir(root)) 45 | self.ids = [] 46 | for ix in tqdm(ids, ncols=80): 47 | path = self.coco.loadImgs(ix)[0]['file_name'] 48 | if path not in paths: continue 49 | paths.remove(path) 50 | self.ids.append(ix) 51 | 52 | self.coid2class = { 53 | clss['id']: clss['name'] 54 | for clss in self.coco.dataset['categories']} 55 | 56 | self.classes = [self.coid2class[i] for i in sorted(self.coid2class.keys())] 57 | 58 | def __len__(self) -> int: 59 | return len(self.ids) 60 | 61 | def _load_image(self, id:int) -> Image.Image: 62 | path = self.coco.loadImgs(id)[0]['file_name'] 63 | return Image.open(os.path.join(self.root, path)).convert('RGB') 64 | 65 | def _load_anns(self, id:int) -> List[Any]: 66 | # return self.coco.loadAnns(self.coco.getAnnIds(id)) 67 | return [ 68 | ann for ann in self.coco.loadAnns(self.coco.getAnnIds(id)) 69 | if ann['category_id'] > 0] 70 | 71 | def _format_anns( 72 | self, 73 | anns: List[Any], 74 | image_size: Tuple[int, int], 75 | ) -> Dict[str, Any]: 76 | 77 | xywh2xyxy = lambda x, y, w, h: (x, y, x + w, y + h) 78 | validation = lambda ann: ann['iscrowd'] == 0 79 | box_list = [] 80 | label_list = [] 81 | for ann in anns: 82 | if not validation(ann): continue 83 | class_name = self.coid2class[ann['category_id']] 84 | box_list.append(xywh2xyxy(*ann['bbox'])) 85 | label_list.append(self.classes.index(class_name)) 86 | boxes = tv_tensors.BoundingBoxes( 87 | box_list, 88 | format='XYXY', 89 | canvas_size=(image_size[1], image_size[0]), 90 | ) 91 | labels = torch.LongTensor(label_list) 92 | return dict(boxes=boxes, labels=labels) 93 | 94 | def __getitem__(self, index:int) -> Tuple[Any, Any]: 95 | if not isinstance(index, int): 96 | raise ValueError( 97 | f'Index must be of type integer, got {type(index)} instead.') 98 | 99 | id = self.ids[index] 100 | anns = self._load_anns(id) 101 | if len(anns) == 0: 102 | return self.__getitem__((index + 1) % self.__len__()) 103 | 104 | image = self._load_image(id) 105 | target = self._format_anns(anns, image.size) 106 | 107 | if self.transforms is not None: 108 | image, target = self.transforms(image, target) 109 | 110 | return image, target 111 | 
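# --- Usage sketch (added for illustration; not part of the original module) ---
# A hedged example of loading PubLayNetDet as defined above. The paths are
# placeholders: `root` should point at a directory of PubLayNet page images and
# `annFile` at the matching COCO-format annotation JSON; pycocotools must be
# installed for the COCO loader used in the constructor.

if __name__ == "__main__":
    dataset = PubLayNetDet(
        root='/path/to/publaynet/train',
        annFile='/path/to/publaynet/train.json')
    print(len(dataset), dataset.classes)

    image, target = dataset[0]
    # Each target holds XYXY boxes and integer labels indexing `dataset.classes`.
    print(image.size, target['boxes'].shape, target['labels'][:5])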
-------------------------------------------------------------------------------- /vklearn/datasets/vocdetection.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Tuple, Any, Dict, List 2 | from pathlib import Path 3 | from xml.etree.ElementTree import parse as ET_parse 4 | 5 | from torchvision.datasets import VOCDetection as _VOCDetection 6 | from torchvision import tv_tensors 7 | import torch 8 | from PIL import Image 9 | 10 | from tqdm import tqdm 11 | 12 | 13 | class VOCDetection(_VOCDetection): 14 | DEFAULT_CLASSES = [ 15 | 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 16 | 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 17 | 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor', 18 | ] 19 | 20 | def __init__( 21 | self, 22 | root: str | Path, 23 | year: str='2012', 24 | image_set: str='train', 25 | sub_categories: List[str] | None=None, 26 | download: bool=False, 27 | transform: Callable | None=None, 28 | target_transform: Callable | None=None, 29 | transforms: Callable | None=None, 30 | ): 31 | super().__init__( 32 | root, year, image_set, download, transform, target_transform, transforms) 33 | 34 | self.classes = sub_categories or self.DEFAULT_CLASSES 35 | if len(self.classes) != len(self.DEFAULT_CLASSES): 36 | ids = list(range(len(self.images))) 37 | ids = self._drop_others(ids) 38 | self.images = [self.images[idx] for idx in ids] 39 | self.targets = [self.targets[idx] for idx in ids] 40 | 41 | def _drop_others(self, ids:List[int]) -> List[int]: 42 | new_ids = [] 43 | for idx in tqdm(ids, ncols=80): 44 | anns = self.parse_voc_xml(ET_parse(self.annotations[idx]).getroot()) 45 | anns = anns['annotation'] 46 | for obj in anns['object']: 47 | name = obj['name'] 48 | if name in self.classes: 49 | new_ids.append(idx) 50 | break 51 | return new_ids 52 | 53 | def _format_anns(self, anns:Dict[str, Any]) -> Dict[str, Any]: 54 | anns = anns['annotation'] 55 | box_list = [] 56 | label_list = [] 57 | for obj in anns['object']: 58 | name = obj['name'] 59 | if name not in self.classes: continue 60 | bbox_raw = obj['bndbox'] 61 | bbox = [float(bbox_raw[k]) for k in ('xmin', 'ymin', 'xmax', 'ymax')] 62 | box_list.append(bbox) 63 | label_list.append(self.classes.index(name)) 64 | size_w = int(anns['size']['width']) 65 | size_h = int(anns['size']['height']) 66 | boxes = tv_tensors.BoundingBoxes( 67 | box_list, 68 | format='XYXY', 69 | canvas_size=(size_h, size_w), 70 | ) 71 | labels = torch.LongTensor(label_list) 72 | return dict(boxes=boxes, labels=labels) 73 | 74 | def __getitem__(self, idx:int) -> Tuple[Any, Any]: 75 | img = Image.open(self.images[idx]).convert('RGB') 76 | anns = self.parse_voc_xml(ET_parse(self.annotations[idx]).getroot()) 77 | target = self._format_anns(anns) 78 | if self.transforms is not None: 79 | img, target = self.transforms(img, target) 80 | return img, target 81 | -------------------------------------------------------------------------------- /vklearn/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/vklearn/models/__init__.py -------------------------------------------------------------------------------- /vklearn/models/basic.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Mapping, Tuple, List 2 | import math 3 | 4 | from torch import Tensor 5 | 
import torch 6 | import torch.nn as nn 7 | from torchvision.transforms import v2 8 | from PIL import Image 9 | 10 | 11 | class Basic(nn.Module): 12 | 13 | def __init__(self): 14 | super().__init__() 15 | 16 | self._keep_features = False 17 | 18 | self._image2tensor = v2.Compose([ 19 | v2.ToImage(), 20 | v2.ToDtype(torch.float32, scale=True), 21 | v2.Normalize( 22 | mean=[0.485, 0.456, 0.406], 23 | std=[0.229, 0.224, 0.225], 24 | ), 25 | ]) 26 | 27 | @classmethod 28 | def get_transforms( 29 | cls, 30 | task_name: str='default', 31 | ) -> Tuple[v2.Transform, v2.Transform]: 32 | assert not 'this is an empty func' 33 | 34 | @classmethod 35 | def load_from_state(cls, state:Mapping[str, Any]) -> 'Basic': 36 | assert not 'this is an empty func' 37 | 38 | def hyperparameters(self) -> Dict[str, Any]: 39 | assert not 'this is an empty func' 40 | 41 | def collate_fn( 42 | self, 43 | batch: List[Any], 44 | ) -> Any: 45 | assert not 'this is an empty func' 46 | 47 | def preprocess( 48 | self, 49 | image: Image.Image, 50 | align_size: int, 51 | limit_size: int, 52 | fill_value: int, 53 | ) -> Tuple[Tensor, float, int, int]: 54 | 55 | src_w, src_h = image.size 56 | _align_size = math.ceil(align_size / limit_size) * limit_size 57 | scale = _align_size / max(src_w, src_h) 58 | dst_w, dst_h = round(scale * src_w), round(scale * src_h) 59 | sample = image.resize( 60 | size=(dst_w, dst_h), 61 | resample=Image.Resampling.BILINEAR) 62 | frame = Image.new( 63 | mode='RGB', 64 | size=(_align_size, _align_size), 65 | color=(fill_value, ) * 3) 66 | pad_x = (align_size - dst_w) // 2 67 | pad_y = (align_size - dst_h) // 2 68 | frame.paste(sample, box=(pad_x, pad_y)) 69 | # inputs = v2.Compose([ 70 | # v2.ToImage(), 71 | # v2.ToDtype(torch.float32, scale=True), 72 | # v2.Normalize( 73 | # mean=[0.485, 0.456, 0.406], 74 | # std=[0.229, 0.224, 0.225], 75 | # ) 76 | # ])(frame).unsqueeze(dim=0) 77 | inputs = self._image2tensor(frame).unsqueeze(dim=0) 78 | return inputs, scale, pad_x, pad_y 79 | 80 | def get_model_device(self) -> torch.device: 81 | return next(self.parameters()).device 82 | 83 | def train_features(self, flag:bool): 84 | self._keep_features = not flag 85 | -------------------------------------------------------------------------------- /vklearn/models/classifier.py: -------------------------------------------------------------------------------- 1 | from typing import List, Any, Dict, Tuple 2 | 3 | from torch import Tensor 4 | 5 | import torch 6 | 7 | from torchvision import tv_tensors 8 | from torchvision.transforms import v2 9 | 10 | from torchmetrics.classification import Precision, Recall 11 | 12 | from PIL import Image 13 | 14 | from .basic import Basic 15 | 16 | 17 | class Classifier(Basic): 18 | 19 | def __init__(self, categories:List[str]): 20 | super().__init__() 21 | 22 | self.categories = list(categories) 23 | self.num_classes = len(categories) 24 | self.precision_metric = Precision( 25 | task='multiclass', 26 | num_classes=self.num_classes, 27 | average='macro') 28 | self.recall_metric = Recall( 29 | task='multiclass', 30 | num_classes=self.num_classes, 31 | average='macro') 32 | 33 | def preprocess( 34 | self, 35 | image: Image.Image, 36 | align_size: int | Tuple[int, int], 37 | ) -> Tensor: 38 | 39 | if isinstance(align_size, int): 40 | src_w, src_h = image.size 41 | scale = align_size / min(src_w, src_h) 42 | dst_w, dst_h = round(src_w * scale), round(src_h * scale) 43 | x1 = (dst_w - align_size) // 2 44 | y1 = (dst_h - align_size) // 2 45 | x2 = x1 + align_size 46 | y2 = y1 + align_size 47 
| 48 | resized = image.resize((dst_w, dst_h), resample=Image.Resampling.BILINEAR) 49 | sampled = resized.crop((x1, y1, x2, y2)) 50 | else: 51 | sampled = image.resize(align_size, resample=Image.Resampling.BILINEAR) 52 | 53 | return self._image2tensor(sampled).unsqueeze(dim=0) 54 | 55 | 56 | def classify( 57 | self, 58 | image: Image.Image, 59 | top_k: int=10, 60 | align_size: int | Tuple[int, int]=224, 61 | ) -> List[Dict[str, Any]]: 62 | assert not 'this is an empty func' 63 | 64 | def calc_loss( 65 | self, 66 | inputs: Tensor, 67 | target: Tensor, 68 | weights: Dict[str, float] | None=None, 69 | alpha: float=0.25, 70 | gamma: float=2., 71 | ) -> Dict[str, Any]: 72 | assert not 'this is an empty func' 73 | 74 | def calc_score( 75 | self, 76 | inputs: Tensor, 77 | target: Tensor, 78 | thresh: float=0.5, 79 | eps: float=1e-5, 80 | ) -> Dict[str, Any]: 81 | assert not 'this is an empty func' 82 | 83 | def update_metric( 84 | self, 85 | inputs: Tensor, 86 | target: Tensor, 87 | thresh: float=0.5, 88 | ): 89 | 90 | predict = inputs.argmax(dim=-1) 91 | self.precision_metric.update(predict, target) 92 | self.recall_metric.update(predict, target) 93 | 94 | def compute_metric(self) -> Dict[str, Any]: 95 | precision = self.precision_metric.compute() 96 | recall = self.recall_metric.compute() 97 | f1_score = ( 98 | 2 * precision * recall / 99 | torch.clamp_min(precision + recall, 1e-5)) 100 | self.precision_metric.reset() 101 | self.recall_metric.reset() 102 | return dict( 103 | precision=precision, 104 | recall=recall, 105 | f1_score=f1_score, 106 | ) 107 | 108 | def collate_fn( 109 | self, 110 | batch: List[Any], 111 | ) -> Any: 112 | assert not 'this is an empty func' 113 | 114 | @classmethod 115 | def get_transforms( 116 | cls, 117 | task_name: str='default', 118 | ) -> Tuple[v2.Transform, v2.Transform]: 119 | 120 | train_transforms = None 121 | test_transforms = None 122 | 123 | if task_name == 'default': 124 | task_name = 'imagenetx224' 125 | 126 | if task_name == 'imagenetx224': 127 | aligned_size = 224 128 | elif task_name == 'imagenetx256': 129 | aligned_size = 256 130 | elif task_name == 'imagenetx384': 131 | aligned_size = 384 132 | elif task_name == 'imagenetx448': 133 | aligned_size = 448 134 | elif task_name == 'imagenetx512': 135 | aligned_size = 512 136 | elif task_name == 'imagenetx640': 137 | aligned_size = 640 138 | elif task_name == 'documentx224': 139 | aligned_size = 224 140 | else: 141 | raise ValueError(f'Unsupported the task `{task_name}`') 142 | 143 | if task_name.startswith('imagenet'): 144 | train_transforms = v2.Compose([ 145 | v2.ToImage(), 146 | v2.ScaleJitter( 147 | target_size=(aligned_size, aligned_size), 148 | scale_range=(1., 2.), 149 | antialias=True), 150 | v2.RandomPhotometricDistort(p=1), 151 | v2.RandomHorizontalFlip(p=0.5), 152 | v2.RandomChoice([ 153 | v2.GaussianBlur(7, sigma=(0.1, 2.0)), 154 | v2.RandomAdjustSharpness(2, p=0.5), 155 | v2.RandomEqualize(p=0.5), 156 | ]), 157 | v2.RandomCrop( 158 | size=(aligned_size, aligned_size), 159 | pad_if_needed=True, 160 | fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}), 161 | v2.ToDtype(torch.float32, scale=True), 162 | v2.Normalize( 163 | mean=[0.485, 0.456, 0.406], 164 | std=[0.229, 0.224, 0.225], 165 | ) 166 | ]) 167 | test_transforms = v2.Compose([ 168 | v2.ToImage(), 169 | v2.Resize( 170 | size=aligned_size, 171 | antialias=True), 172 | v2.CenterCrop(aligned_size), 173 | v2.ToDtype(torch.float32, scale=True), 174 | v2.Normalize( 175 | mean=[0.485, 0.456, 0.406], 176 | std=[0.229, 0.224, 0.225], 177 | ) 178 | 
]) 179 | 180 | elif task_name.startswith('document'): 181 | train_transforms = v2.Compose([ 182 | v2.ToImage(), 183 | v2.Resize( 184 | size=(aligned_size, aligned_size), 185 | antialias=True), 186 | v2.RandomPhotometricDistort(p=1), 187 | v2.RandomChoice([ 188 | v2.GaussianBlur(7, sigma=(0.1, 2.0)), 189 | v2.RandomAdjustSharpness(2, p=0.5), 190 | v2.RandomEqualize(p=0.5), 191 | ]), 192 | v2.ToDtype(torch.float32, scale=True), 193 | v2.Normalize( 194 | mean=[0.485, 0.456, 0.406], 195 | std=[0.229, 0.224, 0.225], 196 | ) 197 | ]) 198 | test_transforms = v2.Compose([ 199 | v2.ToImage(), 200 | v2.Resize( 201 | size=(aligned_size, aligned_size), 202 | antialias=True), 203 | v2.ToDtype(torch.float32, scale=True), 204 | v2.Normalize( 205 | mean=[0.485, 0.456, 0.406], 206 | std=[0.229, 0.224, 0.225], 207 | ) 208 | ]) 209 | 210 | return train_transforms, test_transforms 211 | -------------------------------------------------------------------------------- /vklearn/models/distiller.py: -------------------------------------------------------------------------------- 1 | from typing import List, Any, Dict, Tuple, Callable 2 | 3 | from torch import Tensor 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from torchvision import tv_tensors 10 | from torchvision.transforms import v2 11 | 12 | from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError 13 | 14 | from .basic import Basic 15 | 16 | 17 | class Distiller(Basic): 18 | 19 | def __init__( 20 | self, 21 | teacher: nn.Module, 22 | student: nn.Module, 23 | in_transform: Callable | None=None, 24 | out_project: nn.Module | None=None, 25 | ): 26 | 27 | super().__init__() 28 | 29 | if in_transform is None: 30 | in_transform = lambda x: x 31 | if out_project is None: 32 | out_project = nn.Identity() 33 | 34 | self.teacher = teacher 35 | self.student = student 36 | self.in_transform = in_transform 37 | self.out_project = out_project 38 | 39 | self.mse_metric = MeanSquaredError() 40 | self.mae_metric = MeanAbsoluteError() 41 | 42 | def forward(self, x:Tensor) -> Tuple[Tensor, Tensor]: 43 | with torch.no_grad(): 44 | teacher_targ = self.teacher(self.in_transform(x)) 45 | student_pred = self.out_project(self.student(x)) 46 | return teacher_targ, student_pred 47 | 48 | def calc_loss( 49 | self, 50 | inputs: Tuple[Tensor, Tensor], 51 | target: Tensor, 52 | ) -> Dict[str, Any]: 53 | 54 | teacher_targ, student_pred = inputs 55 | loss = F.mse_loss(student_pred, teacher_targ, reduction='mean') 56 | return dict( 57 | loss=loss, 58 | ) 59 | 60 | def calc_score( 61 | self, 62 | inputs: Tuple[Tensor, Tensor], 63 | target: Tensor, 64 | eps: float=1e-5, 65 | ) -> Dict[str, Any]: 66 | 67 | teacher_targ, student_pred = inputs 68 | cos_sim = F.cosine_similarity(student_pred, teacher_targ, dim=1).mean() 69 | return dict( 70 | cos_sim=cos_sim, 71 | ) 72 | 73 | def update_metric( 74 | self, 75 | inputs: Tuple[Tensor, Tensor], 76 | target: Tensor, 77 | ): 78 | 79 | teacher_targ, student_pred = inputs 80 | teacher_targ = teacher_targ.contiguous() 81 | self.mse_metric.update(student_pred, teacher_targ) 82 | self.mae_metric.update(student_pred, teacher_targ) 83 | 84 | def compute_metric(self) -> Dict[str, Any]: 85 | mse = self.mse_metric.compute() 86 | mae = self.mae_metric.compute() 87 | mss = 1 - mse / (1 + mse) 88 | self.mse_metric.reset() 89 | self.mae_metric.reset() 90 | return dict( 91 | mss=mss, 92 | mse=mse, 93 | mae=mae, 94 | ) 95 | 96 | def collate_fn( 97 | self, 98 | batch: List[Any], 99 | ) -> Any: 100 | assert not 
'this is an empty func' 101 | 102 | @classmethod 103 | def get_transforms( 104 | cls, 105 | task_name: str='default', 106 | ) -> Tuple[v2.Transform, v2.Transform]: 107 | 108 | train_transforms = None 109 | test_transforms = None 110 | 111 | if task_name in ('default', 'cocox512'): 112 | aligned_size = 512 113 | elif task_name == 'cocox384': 114 | aligned_size = 384 115 | elif task_name == 'cocox448': 116 | aligned_size = 448 117 | elif task_name == 'cocox640': 118 | aligned_size = 640 119 | else: 120 | raise ValueError(f'Unsupported the task `{task_name}`') 121 | 122 | train_transforms = v2.Compose([ 123 | v2.ToImage(), 124 | v2.ScaleJitter( 125 | target_size=(aligned_size, aligned_size), 126 | scale_range=(min(0.9, 384 / aligned_size), 1.1), 127 | antialias=True), 128 | v2.RandomPhotometricDistort(p=1), 129 | v2.RandomHorizontalFlip(p=0.5), 130 | v2.RandomChoice([ 131 | v2.GaussianBlur(7, sigma=(0.1, 2.0)), 132 | v2.RandomAdjustSharpness(2, p=0.5), 133 | v2.RandomEqualize(p=0.5), 134 | ]), 135 | v2.RandomCrop( 136 | size=(aligned_size, aligned_size), 137 | pad_if_needed=True, 138 | fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}), 139 | v2.ToDtype(torch.float32, scale=True), 140 | v2.Normalize( 141 | mean=[0.485, 0.456, 0.406], 142 | std=[0.229, 0.224, 0.225], 143 | ) 144 | ]) 145 | test_transforms = v2.Compose([ 146 | v2.ToImage(), 147 | v2.Resize( 148 | size=aligned_size - 1, 149 | max_size=aligned_size, 150 | antialias=True), 151 | v2.Pad( 152 | padding=aligned_size // 4, 153 | fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}), 154 | v2.CenterCrop(aligned_size), 155 | v2.ToDtype(torch.float32, scale=True), 156 | v2.Normalize( 157 | mean=[0.485, 0.456, 0.406], 158 | std=[0.229, 0.224, 0.225], 159 | ) 160 | ]) 161 | 162 | return train_transforms, test_transforms 163 | -------------------------------------------------------------------------------- /vklearn/models/trimnetclf.py: -------------------------------------------------------------------------------- 1 | from typing import List, Any, Dict, Mapping, Tuple 2 | 3 | from torch import Tensor 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from PIL import Image 10 | 11 | from .classifier import Classifier 12 | from .trimnetx import TrimNetX 13 | from .component import DEFAULT_ACTIVATION 14 | 15 | 16 | class TrimNetClf(Classifier): 17 | '''A light-weight and easy-to-train model for image classification 18 | 19 | Args: 20 | categories: Target categories. 21 | num_scans: Number of the Trim-Units. 22 | scan_range: Range factor of the Trim-Unit convolution. 23 | backbone: Specify a basic model as a feature extraction module. 24 | backbone_pretrained: Whether to load backbone pretrained weights. 
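    Example:
        A minimal usage sketch (illustrative only): the category names are
        placeholders, and backbone_pretrained=False is set here only to avoid
        downloading backbone weights.

            import torch

            model = TrimNetClf(categories=['cat', 'dog'], backbone_pretrained=False)
            logits = model(torch.randn(1, 3, 224, 224))  # class logits, shape (1, 2)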
25 | ''' 26 | 27 | def __init__( 28 | self, 29 | categories: List[str], 30 | dropout_p: float=0.2, 31 | num_scans: int | None=None, 32 | scan_range: int | None=None, 33 | backbone: str | None=None, 34 | backbone_pretrained: bool | None=None, 35 | ): 36 | super().__init__(categories) 37 | 38 | self.dropout_p = dropout_p 39 | 40 | self.trimnetx = TrimNetX( 41 | num_scans, scan_range, backbone, backbone_pretrained) 42 | 43 | merged_dim = self.trimnetx.merged_dim 44 | expanded_dim = merged_dim * 4 45 | 46 | self.predictor = nn.Sequential( 47 | nn.Linear(merged_dim, expanded_dim), 48 | DEFAULT_ACTIVATION(inplace=False), 49 | nn.Dropout(dropout_p, inplace=False), 50 | nn.Linear(expanded_dim, self.num_classes), 51 | ) 52 | 53 | self.alphas = nn.Parameter(torch.zeros( 54 | 1, self.num_classes, self.trimnetx.num_scans)) 55 | 56 | def train_features(self, flag:bool): 57 | self.trimnetx.train_features(flag) 58 | 59 | def forward(self, x:Tensor) -> Tensor: 60 | hs, _ = self.trimnetx(x) 61 | alphas = self.alphas.softmax(dim=-1) 62 | p = 0. 63 | times = len(hs) 64 | for t in range(times): 65 | h = F.adaptive_avg_pool2d(hs[t], 1).flatten(start_dim=1) 66 | p = p + self.predictor(h) * alphas[..., t] 67 | return p 68 | 69 | @classmethod 70 | def load_from_state(cls, state:Mapping[str, Any]) -> 'TrimNetClf': 71 | hyps = state['hyperparameters'] 72 | model = cls( 73 | categories = hyps['categories'], 74 | dropout_p = hyps['dropout_p'], 75 | num_scans = hyps['num_scans'], 76 | scan_range = hyps['scan_range'], 77 | backbone = hyps['backbone'], 78 | backbone_pretrained = False, 79 | ) 80 | model.load_state_dict(state['model']) 81 | return model 82 | 83 | def hyperparameters(self) -> Dict[str, Any]: 84 | return dict( 85 | categories = self.categories, 86 | dropout_p = self.dropout_p, 87 | num_scans = self.trimnetx.num_scans, 88 | scan_range = self.trimnetx.scan_range, 89 | backbone = self.trimnetx.backbone, 90 | ) 91 | 92 | def classify( 93 | self, 94 | image: Image.Image, 95 | top_k: int=10, 96 | align_size: int | Tuple[int, int]=224, 97 | ) -> Dict[str, Any]: 98 | 99 | device = self.get_model_device() 100 | x = self.preprocess(image, align_size) 101 | x = x.to(device) 102 | x = self.forward(x) 103 | top_k = min(self.num_classes, top_k) 104 | topk = x.squeeze(dim=0).softmax(dim=-1).topk(top_k) 105 | probs = [round(v, 5) for v in topk.values.tolist()] 106 | labels = [self.categories[cid] for cid in topk.indices] 107 | return dict( 108 | probs=dict(zip(labels, probs)), 109 | predict=labels[0], 110 | ) 111 | 112 | def calc_loss( 113 | self, 114 | inputs: Tensor, 115 | target: Tensor, 116 | weights: Dict[str, float] | None=None, 117 | alpha: float=0.25, 118 | gamma: float=2., 119 | ) -> Dict[str, Any]: 120 | 121 | reduction = 'mean' 122 | loss = F.cross_entropy(inputs, target, reduction=reduction) 123 | 124 | return dict( 125 | loss=loss, 126 | ) 127 | 128 | def calc_score( 129 | self, 130 | inputs: Tensor, 131 | target: Tensor, 132 | thresh: float=0.5, 133 | eps: float=1e-5, 134 | ) -> Dict[str, Any]: 135 | 136 | if len(target.shape) == 2: 137 | predict = torch.softmax(inputs, dim=-1) 138 | accuracy = (1 - 0.5 * torch.abs(predict - target).sum(dim=-1)).mean() 139 | else: 140 | predict = torch.argmax(inputs, dim=-1) 141 | accuracy = (predict == target).sum() / len(predict) 142 | 143 | return dict( 144 | accuracy=accuracy, 145 | ) 146 | -------------------------------------------------------------------------------- /vklearn/models/trimnetdst.py: 
-------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Mapping 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .distiller import Distiller 7 | from .component import MobileNetFeatures, DinoFeatures, ConvNormActive 8 | 9 | 10 | class TrimNetDst(Distiller): 11 | '''A light-weight and easy-to-train model for knowledge distillation 12 | 13 | Args: 14 | teacher_arch: The architecture name of the teacher model. 15 | student_arch: The architecture name of the student model. 16 | pretrained: Whether to load student model pretrained weights. 17 | ''' 18 | 19 | def __init__( 20 | self, 21 | teacher_arch: str='dinov2_vits14', 22 | student_arch: str='mobilenet_v3_small', 23 | pretrained: bool=True, 24 | ): 25 | 26 | if teacher_arch.startswith('mobilenet'): 27 | teacher = MobileNetFeatures(teacher_arch, pretrained=True) 28 | elif teacher_arch.startswith('dinov2'): 29 | teacher = DinoFeatures(teacher_arch) 30 | elif teacher_arch.startswith('cares'): 31 | teacher = DinoFeatures(teacher_arch) 32 | else: 33 | raise ValueError(f'Unsupported arch `{teacher_arch}`') 34 | 35 | student = MobileNetFeatures(student_arch, pretrained) 36 | 37 | in_transform = None 38 | if teacher.cell_size != student.cell_size: 39 | scale_factor = teacher.cell_size / student.cell_size 40 | in_transform = lambda x: F.interpolate( 41 | x, scale_factor=scale_factor, mode='bilinear') 42 | 43 | in_planes, out_planes = student.features_dim, teacher.features_dim 44 | out_project = nn.Sequential( 45 | ConvNormActive(in_planes, in_planes, 1), 46 | ConvNormActive(in_planes, in_planes, 3, groups=in_planes), 47 | ConvNormActive(in_planes, out_planes, 1, norm_layer=None, activation=None), 48 | ) 49 | 50 | super().__init__(teacher, student, in_transform, out_project) 51 | 52 | self.teacher_arch = teacher_arch 53 | self.student_arch = student_arch 54 | 55 | @classmethod 56 | def load_from_state(cls, state:Mapping[str, Any]) -> 'TrimNetDst': 57 | hyps = state['hyperparameters'] 58 | model = cls( 59 | teacher_arch = hyps['teacher_arch'], 60 | student_arch = hyps['student_arch'], 61 | pretrained = False, 62 | ) 63 | model.load_state_dict(state['model']) 64 | return model 65 | 66 | def hyperparameters(self) -> Dict[str, Any]: 67 | return dict( 68 | teacher_arch = self.teacher_arch, 69 | student_arch = self.student_arch, 70 | ) 71 | -------------------------------------------------------------------------------- /vklearn/models/trimnetocr.py: -------------------------------------------------------------------------------- 1 | from typing import List, Any, Dict, Mapping, Tuple 2 | 3 | from torch import Tensor 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from PIL import Image 10 | 11 | from .ocr import OCR 12 | from .trimnetx import TrimNetX 13 | from .component import LayerNorm2dChannel, CBANet 14 | 15 | 16 | class TrimNetOcr(OCR): 17 | '''A light-weight and easy-to-train model for ocr 18 | 19 | Args: 20 | categories: Target categories. 21 | num_scans: Number of the Trim-Units. 22 | scan_range: Range factor of the Trim-Unit convolution. 23 | backbone: Specify a basic model as a feature extraction module. 24 | backbone_pretrained: Whether to load backbone pretrained weights. 
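    Example:
        A minimal usage sketch (illustrative only): the character list and the
        image path are placeholders, and index 0 is assumed to be the CTC blank
        symbol, consistent with the decoding in `recognize`.

            from PIL import Image

            model = TrimNetOcr(categories=['<blank>'] + list('abcdefghijklmnopqrstuvwxyz'))
            result = model.recognize(Image.open('word.png'))
            print(result['text'])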
25 | ''' 26 | 27 | def __init__( 28 | self, 29 | categories: List[str], 30 | dropout_p: float=0., 31 | num_scans: int | None=None, 32 | scan_range: int | None=None, 33 | backbone: str | None='cares_large', 34 | backbone_pretrained: bool | None=False, 35 | ): 36 | 37 | assert backbone.startswith('cares') 38 | 39 | super().__init__(categories) 40 | 41 | self.dropout_p = dropout_p 42 | 43 | self.trimnetx = TrimNetX( 44 | num_scans, scan_range, backbone, backbone_pretrained, 45 | norm_layer=LayerNorm2dChannel) 46 | 47 | features_dim = self.trimnetx.features_dim 48 | merged_dim = self.trimnetx.merged_dim 49 | 50 | self.project = nn.Sequential( 51 | nn.Linear(features_dim + merged_dim, features_dim), 52 | nn.LayerNorm((features_dim, )), 53 | ) 54 | 55 | self.classifier = nn.Sequential( 56 | nn.Dropout(dropout_p, inplace=False), 57 | nn.Linear(features_dim, self.num_classes), 58 | ) 59 | 60 | for m in list(self.trimnetx.trim_units.modules()): 61 | if not isinstance(m, CBANet): continue 62 | m.channel_attention = nn.Identity() 63 | # m.spatial_attention = nn.Identity() 64 | 65 | self._temp_num_scans = self.trimnetx.num_scans 66 | 67 | def train_features(self, flag:bool): 68 | self.trimnetx.train_features(flag) 69 | 70 | def set_num_scans(self, num_scans:int): 71 | self._temp_num_scans = num_scans 72 | 73 | def forward(self, x:Tensor) -> Tuple[Tensor, List[Tensor]]: 74 | hs, f = self.trimnetx(x, self._temp_num_scans) 75 | fs = [f.squeeze(dim=2).transpose(1, 2)] 76 | for h in hs: 77 | h = h.squeeze(dim=2).transpose(1, 2) 78 | p = self.project(torch.cat([h, fs[-1]], dim=-1)) 79 | # fs.append(p + fs[-1]) 80 | fs.append((p + fs[-1]) * 0.5) 81 | x = fs.pop() 82 | x = self.classifier(x) 83 | return x, fs 84 | 85 | @classmethod 86 | def load_from_state(cls, state:Mapping[str, Any]) -> 'TrimNetOcr': 87 | hyps = state['hyperparameters'] 88 | model = cls( 89 | categories = hyps['categories'], 90 | dropout_p = hyps['dropout_p'], 91 | num_scans = hyps['num_scans'], 92 | scan_range = hyps['scan_range'], 93 | backbone = hyps['backbone'], 94 | backbone_pretrained = False, 95 | ) 96 | model.load_state_dict(state['model']) 97 | return model 98 | 99 | def hyperparameters(self) -> Dict[str, Any]: 100 | return dict( 101 | categories = self.categories, 102 | dropout_p = self.dropout_p, 103 | num_scans = self.trimnetx.num_scans, 104 | scan_range = self.trimnetx.scan_range, 105 | backbone = self.trimnetx.backbone, 106 | ) 107 | 108 | def recognize( 109 | self, 110 | image: Image.Image, 111 | top_k: int=0, 112 | align_size: int=32, 113 | to_gray: bool=True, 114 | whitelist: List[str] | None=None, 115 | ) -> Dict[str, Any]: 116 | 117 | device = self.get_model_device() 118 | if to_gray and (image.mode != 'L'): 119 | image = image.convert('L') 120 | if image.mode != 'RGB': 121 | image = image.convert('RGB') 122 | x = self.preprocess(image, align_size) 123 | x = x.to(device) 124 | x, _ = self.forward(x) 125 | 126 | if whitelist is not None: 127 | white_ixs = [self.categories.index(char) for char in whitelist] 128 | for ix in range(1, len(self.categories)): 129 | if ix in white_ixs: continue 130 | x[..., ix] = -10000 131 | 132 | preds = x.argmax(dim=2) # n, T 133 | mask = (F.pad(preds, [0, 1], value=0)[:, 1:] - preds) != 0 134 | preds = preds * mask 135 | decoded = self._categorie_arr[preds[0].cpu().numpy()] 136 | text = ''.join(decoded) 137 | 138 | probs = None 139 | labels = None 140 | top_k = min(self.num_classes, top_k) 141 | if top_k > 0: 142 | topk = x.squeeze(dim=0).softmax(dim=-1).topk(top_k) 143 | probs = 
topk.values.cpu().numpy() 144 | labels = topk.indices.cpu().numpy() 145 | return dict( 146 | probs=probs, 147 | labels=labels, 148 | text=text, 149 | decoded=decoded, 150 | ) 151 | 152 | def calc_loss( 153 | self, 154 | inputs: Tuple[Tensor, List[Tensor]], 155 | targets: Tensor, 156 | target_lengths: Tensor, 157 | zero_infinity: bool=False, 158 | weights: Dict[str, float] | None=None, 159 | ) -> Dict[str, Any]: 160 | 161 | predicts, features = inputs 162 | 163 | losses = super().calc_loss( 164 | predicts, targets, target_lengths, zero_infinity=zero_infinity) 165 | 166 | weights = weights or dict() 167 | auxi_weight = weights.get('auxi', 0) 168 | if auxi_weight == 0: features = [] 169 | 170 | loss = losses['loss'] 171 | losses['pf_loss'] = loss 172 | 173 | auxi_loss = 0 174 | for fid, feature in enumerate(features): 175 | predicts = self.classifier(feature) 176 | auxi_i_loss = super().calc_loss( 177 | predicts, targets, target_lengths, zero_infinity=zero_infinity)['loss'] 178 | auxi_loss = auxi_loss + auxi_i_loss 179 | losses[f'p{fid}_loss'] = auxi_i_loss 180 | 181 | losses['loss'] = ( 182 | loss * max(0, 1 - len(features) * auxi_weight) + 183 | auxi_loss * auxi_weight) 184 | return losses 185 | 186 | def calc_score( 187 | self, 188 | inputs: Tuple[Tensor, List[Tensor]], 189 | targets: Tensor, 190 | target_lengths: Tensor, 191 | ) -> Dict[str, Any]: 192 | return super().calc_score(inputs[0], targets, target_lengths) 193 | 194 | def update_metric( 195 | self, 196 | inputs: Tuple[Tensor, List[Tensor]], 197 | targets: Tensor, 198 | target_lengths: Tensor, 199 | ): 200 | return super().update_metric(inputs[0], targets, target_lengths) 201 | -------------------------------------------------------------------------------- /vklearn/models/trimnetseg.py: -------------------------------------------------------------------------------- 1 | from typing import List, Any, Dict, Mapping 2 | 3 | from torch import Tensor 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from PIL import Image 10 | 11 | from .segment import Segment 12 | from .trimnetx import TrimNetX 13 | from .component import SegPredictor 14 | 15 | 16 | class TrimNetSeg(Segment): 17 | '''A light-weight and easy-to-train model for image segmentation 18 | 19 | Args: 20 | categories: Target categories. 21 | num_scans: Number of the Trim-Units. 22 | scan_range: Range factor of the Trim-Unit convolution. 23 | backbone: Specify a basic model as a feature extraction module. 24 | backbone_pretrained: Whether to load backbone pretrained weights. 
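    Example:
        A minimal usage sketch (illustrative only): the category names and the
        image path are placeholders, and backbone_pretrained=False is set here
        only to avoid downloading backbone weights.

            import torch
            from PIL import Image

            model = TrimNetSeg(categories=['background', 'pet'], backbone_pretrained=False)
            with torch.no_grad():
                masks = model.segment(Image.open('pet.jpg').convert('RGB'))
            # masks: numpy array of shape (num_classes, height, width), values in [0, 1]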
25 | ''' 26 | 27 | def __init__( 28 | self, 29 | categories: List[str], 30 | num_scans: int | None=None, 31 | scan_range: int | None=None, 32 | backbone: str | None=None, 33 | backbone_pretrained: bool | None=None, 34 | ): 35 | super().__init__(categories) 36 | 37 | self.trimnetx = TrimNetX( 38 | num_scans, scan_range, backbone, backbone_pretrained) 39 | 40 | merged_dim = self.trimnetx.merged_dim 41 | 42 | self.scale_compensate = self.trimnetx.cell_size / 2**4 43 | 44 | self.predictor = SegPredictor(merged_dim, self.num_classes, self.trimnetx.num_scans) 45 | self.decoder = nn.Conv2d(self.predictor.embeded_dim, self.num_classes, 1) 46 | 47 | def train_features(self, flag:bool): 48 | self.trimnetx.train_features(flag) 49 | 50 | def forward_latent(self, x:Tensor) -> List[Tensor]: 51 | hs, _ = self.trimnetx(x) 52 | ps = self.predictor(hs) 53 | return ps 54 | 55 | def forward(self, x:Tensor) -> Tensor: 56 | ps = self.forward_latent(x) 57 | ps = [self.decoder(p) for p in ps] 58 | times = len(ps) 59 | for t in range(times): 60 | scale_factor = 2**(3 - t) 61 | if scale_factor == 1: continue 62 | ps[t] = F.interpolate( 63 | ps[t], scale_factor=scale_factor * self.scale_compensate, mode='bilinear') 64 | return torch.cat([p[..., None] for p in ps], dim=-1) 65 | 66 | @classmethod 67 | def load_from_state(cls, state:Mapping[str, Any]) -> 'TrimNetSeg': 68 | hyps = state['hyperparameters'] 69 | model = cls( 70 | categories = hyps['categories'], 71 | num_scans = hyps['num_scans'], 72 | scan_range = hyps['scan_range'], 73 | backbone = hyps['backbone'], 74 | backbone_pretrained = False, 75 | ) 76 | model.load_state_dict(state['model']) 77 | return model 78 | 79 | def hyperparameters(self) -> Dict[str, Any]: 80 | return dict( 81 | categories = self.categories, 82 | num_scans = self.trimnetx.num_scans, 83 | scan_range = self.trimnetx.scan_range, 84 | backbone = self.trimnetx.backbone, 85 | ) 86 | 87 | def segment( 88 | self, 89 | image: Image.Image, 90 | conf_thresh: float=0.5, 91 | align_size: int=448, 92 | ) -> List[Dict[str, Any]]: 93 | 94 | device = self.get_model_device() 95 | x, scale, pad_x, pad_y = self.preprocess( 96 | image, align_size, limit_size=32, fill_value=127) 97 | x = x.to(device) 98 | ps = self.forward_latent(x) 99 | p = self.decoder(ps[-1]) 100 | 101 | scale_factor = 2**(3 - (len(ps) - 1)) 102 | if scale_factor != 1: 103 | p = F.interpolate( 104 | p, scale_factor=scale_factor, mode='bilinear') 105 | 106 | src_w, src_h = image.size 107 | dst_w, dst_h = round(scale * src_w), round(scale * src_h) 108 | p = torch.softmax(p[..., pad_y:pad_y + dst_h, pad_x:pad_x + dst_w], dim=1) 109 | p[p < conf_thresh] = 0. 110 | p = F.interpolate(p, (src_h, src_w), mode='bilinear') 111 | return p[0].cpu().numpy() 112 | 113 | def dice_loss( 114 | self, 115 | inputs: Tensor, 116 | target: Tensor, 117 | smooth: float=1., 118 | ) -> Tensor: 119 | 120 | predict = torch.softmax(inputs, dim=1).flatten(2) 121 | ground = target.flatten(2) 122 | intersection = predict * ground 123 | dice = ( 124 | (intersection.sum(dim=2) * 2 + smooth) / 125 | (predict.sum(dim=2) + ground.sum(dim=2) + smooth) 126 | ).mean(dim=1) 127 | dice_loss = 1 - dice 128 | return dice_loss 129 | 130 | def calc_loss( 131 | self, 132 | inputs: Tensor, 133 | target: Tensor, 134 | ce_weight: Tensor | None=None, 135 | alpha: float=0.25, 136 | ) -> Dict[str, Any]: 137 | 138 | target = target.type_as(inputs) 139 | times = inputs.shape[-1] 140 | loss = 0. 
141 | for t in range(times): 142 | dice = self.dice_loss( 143 | inputs[..., t], 144 | target, 145 | ).mean() 146 | ce = torch.zeros_like(dice) 147 | if alpha < 1.: 148 | ce = F.cross_entropy( 149 | inputs[..., t], 150 | target.argmax(dim=1), 151 | weight=ce_weight, 152 | reduction='mean', 153 | ) 154 | loss_t = alpha * dice + (1 - alpha) * ce 155 | loss = loss + loss_t / times 156 | 157 | return dict( 158 | loss=loss, 159 | dice_loss=dice, 160 | ce_loss=ce, 161 | ) 162 | 163 | def calc_score( 164 | self, 165 | inputs: Tensor, 166 | target: Tensor, 167 | eps: float=1e-5, 168 | ) -> Dict[str, Any]: 169 | 170 | predicts = torch.softmax(inputs[..., -1], dim=1) 171 | distance = torch.abs(predicts - target).mean(dim=(2, 3)).mean() 172 | 173 | return dict( 174 | mae=distance, 175 | ) 176 | 177 | def update_metric( 178 | self, 179 | inputs: Tensor, 180 | target: Tensor, 181 | ): 182 | 183 | predicts = F.one_hot(inputs[..., -1].argmax(dim=1)).permute(0, 3, 1, 2) 184 | self.m_iou_metric.update(predicts.to(torch.int), target.to(torch.int)) 185 | -------------------------------------------------------------------------------- /vklearn/models/trimnetx.py: -------------------------------------------------------------------------------- 1 | from typing import List, Mapping, Any, Dict, Tuple, Callable 2 | 3 | from torch import Tensor 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .component import ConvNormActive, DEFAULT_NORM_LAYER 9 | from .component import MobileNetFeatures, DinoFeatures, CaresFeatures 10 | from .component import CBANet 11 | from .basic import Basic 12 | 13 | 14 | class TrimUnit(nn.Module): 15 | 16 | def __init__( 17 | self, 18 | in_planes: int, 19 | out_planes: int, 20 | head_dim: int, 21 | scan_range: int=4, 22 | norm_layer: Callable[..., nn.Module]=DEFAULT_NORM_LAYER, 23 | ): 24 | 25 | super().__init__() 26 | 27 | assert out_planes % head_dim == 0 28 | groups = out_planes // head_dim 29 | dense_dim = out_planes // scan_range 30 | 31 | self.cbanet = CBANet(in_planes, out_planes, norm_layer=norm_layer) 32 | self.convs = nn.ModuleList() 33 | self.denses = nn.ModuleList() 34 | for r in range(scan_range): 35 | self.convs.append(ConvNormActive( 36 | out_planes, 37 | out_planes, 38 | dilation=2**r, 39 | groups=groups, 40 | norm_layer=None, 41 | activation=None, 42 | )) 43 | self.denses.append(ConvNormActive(out_planes, dense_dim, 1, norm_layer=norm_layer)) 44 | self.merge = ConvNormActive(dense_dim * scan_range, out_planes, 1, norm_layer=norm_layer) 45 | 46 | def forward(self, x:Tensor) -> Tensor: 47 | x = self.cbanet(x) 48 | ds = [] 49 | for conv, dense in zip(self.convs, self.denses): 50 | x = x + conv(x) 51 | ds.append(dense(x)) 52 | m = self.merge(torch.cat(ds, dim=1)) 53 | return m 54 | 55 | 56 | class TrimNetX(Basic): 57 | '''A light-weight and easy-to-train model 58 | 59 | Args: 60 | num_scans: Number of the Trim-Units. 61 | scan_range: Range factor of the Trim-Unit convolution. 62 | backbone: Specify a basic model as a feature extraction module. 63 | backbone_pretrained: Whether to load backbone pretrained weights. 
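    Example:
        A minimal usage sketch (illustrative only), using the default
        MobileNetV3-Small backbone without pretrained weights:

            import torch

            net = TrimNetX(num_scans=3, backbone='mobilenet_v3_small', backbone_pretrained=False)
            hs, f = net(torch.randn(1, 3, 224, 224))
            # f:  backbone feature map at stride net.cell_size
            # hs: list of num_scans tensors, each with net.merged_dim (128 here) channels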
64 | ''' 65 | 66 | def __init__( 67 | self, 68 | num_scans: int | None=None, 69 | scan_range: int | None=None, 70 | backbone: str | None=None, 71 | backbone_pretrained: bool | None=None, 72 | norm_layer: Callable[..., nn.Module]=DEFAULT_NORM_LAYER, 73 | ): 74 | 75 | super().__init__() 76 | 77 | if num_scans is None: 78 | num_scans = 3 79 | if scan_range is None: 80 | scan_range = 4 81 | if backbone is None: 82 | backbone = 'mobilenet_v3_small' 83 | if backbone_pretrained is None: 84 | backbone_pretrained = True 85 | 86 | self.num_scans = num_scans 87 | self.scan_range = scan_range 88 | self.backbone = backbone 89 | 90 | if backbone == 'mobilenet_v3_small': 91 | self.features = MobileNetFeatures( 92 | backbone, backbone_pretrained) 93 | self.features_dim = self.features.features_dim 94 | self.merged_dim = 128 95 | 96 | elif backbone == 'mobilenet_v3_large': 97 | self.features = MobileNetFeatures( 98 | backbone, backbone_pretrained) 99 | self.features_dim = self.features.features_dim 100 | self.merged_dim = 192 101 | 102 | elif backbone == 'mobilenet_v3_larges': 103 | self.features = MobileNetFeatures( 104 | backbone, backbone_pretrained) 105 | self.features_dim = self.features.features_dim 106 | self.merged_dim = 192 107 | 108 | elif backbone == 'mobilenet_v2': 109 | self.features = MobileNetFeatures( 110 | backbone, backbone_pretrained) 111 | self.features_dim = self.features.features_dim 112 | self.merged_dim = 192 113 | 114 | elif backbone.startswith('mobilenet_v4'): 115 | self.features = MobileNetFeatures( 116 | backbone, backbone_pretrained) 117 | self.features_dim = self.features.features_dim 118 | 119 | self.merged_dim = 128 120 | if 'medium' in backbone: 121 | self.merged_dim = 192 122 | elif 'large' in backbone: 123 | self.merged_dim = 384 124 | 125 | elif backbone == 'dinov2_vits14': 126 | self.features = DinoFeatures(backbone) 127 | self.features_dim = self.features.features_dim 128 | self.merged_dim = self.features_dim # 384 129 | 130 | elif backbone == 'dinov2_vits14_h192': 131 | self.features = DinoFeatures(backbone.rstrip('_h192')) 132 | self.features_dim = self.features.features_dim 133 | self.merged_dim = 192 134 | 135 | elif backbone == 'cares_small': 136 | self.features = CaresFeatures(arch='small') 137 | self.features_dim = self.features.features_dim 138 | self.merged_dim = 128 139 | 140 | elif backbone.startswith('cares_large'): 141 | self.features = CaresFeatures(arch=backbone.lstrip('cares_')) 142 | self.features_dim = self.features.features_dim 143 | self.merged_dim = 192 144 | 145 | else: 146 | raise ValueError(f'Unsupported backbone `{backbone}`') 147 | 148 | self.cell_size = self.features.cell_size 149 | 150 | self.projects = nn.ModuleList([ 151 | ConvNormActive(self.merged_dim, self.merged_dim, 1, activation=None, norm_layer=norm_layer) 152 | for _ in range(num_scans - 1)]) 153 | 154 | self.trim_units = nn.ModuleList() 155 | for t in range(num_scans): 156 | in_planes = self.features_dim 157 | if t > 0: 158 | in_planes = self.features_dim + self.merged_dim 159 | self.trim_units.append(TrimUnit( 160 | in_planes, 161 | self.merged_dim, 162 | head_dim=16, 163 | scan_range=scan_range, 164 | norm_layer=norm_layer, 165 | )) 166 | 167 | def forward( 168 | self, 169 | x: Tensor, 170 | num_scans: int | None=None, 171 | ) -> Tuple[List[Tensor], Tensor]: 172 | 173 | if not self._keep_features: 174 | f = self.features(x) 175 | else: 176 | with torch.no_grad(): 177 | f = self.features(x) 178 | 179 | if num_scans is None: 180 | num_scans = self.num_scans 181 | 182 | if not 
num_scans: return [], f 183 | h = self.trim_units[0](f) 184 | ht = [h] 185 | for t in range(1, num_scans): 186 | e = self.projects[t - 1](h) 187 | h = self.trim_units[t](torch.cat([f, e], dim=1)) 188 | ht.append(h) 189 | return ht, f 190 | 191 | @classmethod 192 | def load_from_state(cls, state:Mapping[str, Any]) -> 'TrimNetX': 193 | hyps = state['hyperparameters'] 194 | model = cls( 195 | num_scans = hyps['num_scans'], 196 | scan_range = hyps['scan_range'], 197 | backbone = hyps['backbone'], 198 | backbone_pretrained = False, 199 | ) 200 | model.load_state_dict(state['model']) 201 | return model 202 | 203 | def hyperparameters(self) -> Dict[str, Any]: 204 | return dict( 205 | num_scans = self.num_scans, 206 | scan_range = self.scan_range, 207 | backbone = self.backbone, 208 | ) 209 | -------------------------------------------------------------------------------- /vklearn/pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/vklearn/pipelines/__init__.py -------------------------------------------------------------------------------- /vklearn/pipelines/classifier.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Any, Mapping, Tuple 2 | import io 3 | 4 | from PIL.Image import Image as PILImage 5 | from numpy import ndarray 6 | from PIL import Image 7 | from matplotlib.pyplot import Figure 8 | import torch 9 | 10 | from ..models.classifier import Classifier as Model 11 | 12 | 13 | class Classifier: 14 | '''This class is used for handling image classification tasks. 15 | 16 | Args: 17 | model: Classifier model. 18 | ''' 19 | 20 | def __init__(self, model:Model): 21 | self.model = model 22 | 23 | @classmethod 24 | def load_from_state( 25 | cls, 26 | model: Model, 27 | state: Mapping[str, Any] | str, 28 | ) -> 'Classifier': 29 | 30 | if isinstance(state, str): 31 | state = torch.load(state, map_location='cpu', weights_only=True) 32 | return cls(model.load_from_state(state).eval()) 33 | 34 | def export_onnx( 35 | self, 36 | f: str | io.BytesIO, 37 | align_size: int=224, 38 | ): 39 | inputs = torch.randn(1, 3, align_size, align_size) 40 | torch.onnx.export( 41 | model=self.model, 42 | args=inputs, 43 | f=f, 44 | input_names=['input'], 45 | output_names=['output'], 46 | dynamic_axes={ 47 | 'input': {0: 'batch_size'}, 48 | 'output': {0: 'batch_size'}, 49 | }, 50 | ) 51 | 52 | def to(self, device:torch.device) -> 'Classifier': 53 | self.model.to(device) 54 | return self 55 | 56 | def __call__( 57 | self, 58 | image: PILImage | str | ndarray, 59 | top_k: int=10, 60 | align_size: int | Tuple[int, int]=224, 61 | ) -> List[Dict[str, Any]]: 62 | '''Invoke the method for image classification. 63 | 64 | Args: 65 | image: The image to be classified. 66 | top_k: Specifies the number of top classes to return, sorted by probability in descending order. 67 | align_size: The size to which the image will be aligned after preprocessing. 
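        Example:
            A minimal sketch (illustrative only): the checkpoint path, image path
            and category names are placeholders, and TrimNetClf is used as the
            underlying model.

                from vklearn.models.trimnetclf import TrimNetClf

                classifier = Classifier.load_from_state(
                    TrimNetClf(categories=['cat', 'dog']), 'clf-best.pt')
                result = classifier('photo.jpg', top_k=2)
                print(result['predict'], result['probs'])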
68 | ''' 69 | 70 | if isinstance(image, str): 71 | image = Image.open(image) 72 | if isinstance(image, ndarray): 73 | image = Image.fromarray(image, mode='RGB') 74 | 75 | with torch.no_grad(): 76 | result = self.model.classify( 77 | image=image, 78 | top_k=top_k, 79 | align_size=align_size, 80 | ) 81 | return result 82 | 83 | def plot_result( 84 | self, 85 | image: PILImage, 86 | result: List[Dict[str, Any]], 87 | fig: Figure, 88 | ): 89 | '''This method visualizes the model prediction results. 90 | 91 | Args: 92 | image: The image used for classification. 93 | result: The data returned after the model performs the classification. 94 | fig: The matplotlib Figure object. 95 | ''' 96 | 97 | fig.subplots_adjust(left=0.05) 98 | ax = fig.add_subplot(1, 9, (1, 6)) 99 | ax.margins(x=0, y=0) 100 | ax.set_axis_off() 101 | ax.imshow(image) 102 | ax = fig.add_subplot(1, 9, (8, 9)) 103 | ax.margins(x=0, y=0) 104 | for pos in ['top', 'right']: 105 | ax.spines[pos].set_visible(False) 106 | probs = result['probs'] 107 | colors = ['lightgray'] * len(probs) 108 | colors[-1] = 'steelblue' 109 | keys = sorted(probs.keys(), key=lambda k: probs[k]) 110 | values = [probs[k] for k in keys] 111 | p = ax.barh(keys, values, color=colors) 112 | ax.bar_label(p, padding=1) 113 | ax.tick_params(axis='y', rotation=45) 114 | -------------------------------------------------------------------------------- /vklearn/pipelines/detector.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Any, Mapping 2 | import io 3 | 4 | import torch 5 | from PIL.Image import Image as PILImage 6 | from numpy import ndarray 7 | from PIL import Image 8 | from matplotlib.pyplot import Figure, Rectangle 9 | 10 | from ..models.detector import Detector as Model 11 | 12 | 13 | class Detector: 14 | '''This class is used for handling object detection tasks. 15 | 16 | Args: 17 | model: Object detector model. 18 | ''' 19 | 20 | def __init__(self, model:Model): 21 | self.model = model 22 | 23 | @classmethod 24 | def load_from_state( 25 | cls, 26 | model: Model, 27 | state: Mapping[str, Any] | str, 28 | ) -> 'Detector': 29 | 30 | if isinstance(state, str): 31 | state = torch.load(state, map_location='cpu', weights_only=True) 32 | return cls(model.load_from_state(state).eval()) 33 | 34 | def export_onnx( 35 | self, 36 | f: str | io.BytesIO, 37 | align_size: int=448, 38 | ): 39 | inputs = torch.randn(1, 3, align_size, align_size) 40 | torch.onnx.export( 41 | model=self.model, 42 | args=inputs, 43 | f=f, 44 | input_names=['input'], 45 | output_names=['output'], 46 | dynamic_axes={ 47 | 'input': {0: 'batch_size'}, 48 | 'output': {0: 'batch_size'}, 49 | }, 50 | ) 51 | 52 | def to(self, device:torch.device) -> 'Detector': 53 | self.model.to(device) 54 | return self 55 | 56 | def __call__( 57 | self, 58 | image: PILImage | str | ndarray, 59 | conf_thresh: float=0.5, 60 | recall_thresh: float=0.5, 61 | iou_thresh: float=0.5, 62 | align_size: int=448, 63 | mini_side: int=1, 64 | ) -> List[Dict[str, Any]]: 65 | '''Invoke the method for object detection. 66 | 67 | Args: 68 | image: The image to be detected. 69 | conf_thresh: Confidence threshold. 70 | recall_thresh: Recall score threshold. 71 | iou_thresh: Intersection over union threshold. 72 | align_size: The size to which the image will be aligned after preprocessing. 73 | mini_side: Minimum bounding box side length. 
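        Example:
            A minimal sketch (illustrative only): the checkpoint and image paths
            are placeholders, and TrimNetDet is assumed to be the detector model
            shipped with this project (its constructor arguments are not shown in
            this file).

                from vklearn.models.trimnetdet import TrimNetDet

                detector = Detector.load_from_state(
                    TrimNetDet(categories=['cat', 'dog']), 'det-best.pt')
                for obj in detector('street.jpg', conf_thresh=0.5):
                    print(obj['label'], obj['score'], obj['box'])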
74 | ''' 75 | 76 | if isinstance(image, str): 77 | image = Image.open(image) 78 | if isinstance(image, ndarray): 79 | image = Image.fromarray(image, mode='RGB') 80 | 81 | with torch.no_grad(): 82 | result = self.model.detect( 83 | image=image, 84 | conf_thresh=conf_thresh, 85 | recall_thresh=recall_thresh, 86 | iou_thresh=iou_thresh, 87 | align_size=align_size, 88 | mini_side=mini_side, 89 | ) 90 | return result 91 | 92 | def plot_result( 93 | self, 94 | image: PILImage, 95 | result: List[Dict[str, Any]], 96 | fig: Figure, 97 | color: str='red', 98 | text_color: str='white', 99 | ): 100 | '''This method visualizes the model prediction results. 101 | 102 | Args: 103 | image: The image used for classification. 104 | result: The data returned after the model performs the detection. 105 | fig: The matplotlib Figure object. 106 | color: The color of annotates. 107 | text_color: The color of the label text. 108 | ''' 109 | 110 | ax = fig.add_subplot() 111 | ax.imshow(image) 112 | for obj in result: 113 | x1, y1, x2, y2 = obj['box'] 114 | ax.add_patch(Rectangle( 115 | (x1, y1), x2 - x1, y2 - y1, color=color, fill=False)) 116 | ax.annotate( 117 | f"{obj['label']}: {round(obj['score'], 3)}", 118 | (x1, y1), 119 | color=text_color, 120 | ha='left', 121 | va='bottom', 122 | bbox=dict( 123 | color=color, 124 | pad=0, 125 | ), 126 | ) 127 | -------------------------------------------------------------------------------- /vklearn/pipelines/joints.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Any, Mapping, Sequence, Tuple 2 | import io 3 | 4 | import torch 5 | from PIL.Image import Image as PILImage 6 | from PIL import Image 7 | from numpy import ndarray 8 | import cv2 as cv 9 | from matplotlib.pyplot import Figure, Circle, Polygon 10 | 11 | from ..models.joints import Joints as Model 12 | 13 | 14 | class Joints: 15 | '''This class is used for handling keypoint&joint detection tasks. 16 | 17 | Args: 18 | model: Keypoint&joint detection model. 19 | ''' 20 | 21 | def __init__(self, model:Model): 22 | self.model = model 23 | 24 | @classmethod 25 | def load_from_state( 26 | cls, 27 | model: Model, 28 | state: Mapping[str, Any] | str, 29 | ) -> 'Joints': 30 | 31 | if isinstance(state, str): 32 | state = torch.load(state, map_location='cpu', weights_only=True) 33 | return cls(model.load_from_state(state).eval()) 34 | 35 | def export_onnx( 36 | self, 37 | f: str | io.BytesIO, 38 | align_size: int=448, 39 | ): 40 | inputs = torch.randn(1, 3, align_size, align_size) 41 | torch.onnx.export( 42 | model=self.model, 43 | args=inputs, 44 | f=f, 45 | input_names=['input'], 46 | output_names=['output'], 47 | dynamic_axes={ 48 | 'input': {0: 'batch_size'}, 49 | 'output': {0: 'batch_size'}, 50 | }, 51 | ) 52 | 53 | def to(self, device:torch.device) -> 'Joints': 54 | self.model.to(device) 55 | return self 56 | 57 | def __call__( 58 | self, 59 | image: PILImage | str | ndarray, 60 | joints_type: str='normal', 61 | conf_thresh: float=0.5, 62 | iou_thresh: float=0.5, 63 | align_size: int=448, 64 | score_thresh: float=0.5, 65 | ocr_params: Sequence[Tuple[float, int]]=((0.7, 2), (0.9, 2)), 66 | ) -> List[Dict[str, Any]]: 67 | '''Invoke the method for keypoint&joint detection. 68 | 69 | Args: 70 | image: The image to be detected. 71 | joints_type: The type of joints operation. 72 | conf_thresh: Confidence threshold. 73 | iou_thresh: Intersection over union threshold. 74 | align_size: The size to which the image will be aligned after preprocessing. 
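        Example:
            A minimal sketch (illustrative only): the checkpoint and image paths
            are placeholders, and TrimNetJot is assumed to be the keypoint&joint
            model shipped with this project (its constructor arguments are not
            shown in this file).

                from vklearn.models.trimnetjot import TrimNetJot

                joints = Joints.load_from_state(
                    TrimNetJot(categories=['text']), 'jot-best.pt')
                result = joints('page.jpg', conf_thresh=0.5)
                for obj in result['objs']:
                    print(obj['label'], obj['score'], obj['rect'])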
75 | ''' 76 | 77 | if isinstance(image, str): 78 | image = Image.open(image) 79 | if isinstance(image, ndarray): 80 | image = Image.fromarray(image, mode='RGB') 81 | 82 | with torch.no_grad(): 83 | result = self.model.detect( 84 | image=image, 85 | joints_type=joints_type, 86 | conf_thresh=conf_thresh, 87 | iou_thresh=iou_thresh, 88 | align_size=align_size, 89 | score_thresh=score_thresh, 90 | ocr_params=ocr_params, 91 | ) 92 | return result 93 | 94 | def plot_result( 95 | self, 96 | image: PILImage, 97 | result: List[Dict[str, Any]], 98 | fig: Figure, 99 | show_annotate: bool=True, 100 | show_nodes: bool=False, 101 | show_heatmap: bool=False, 102 | ): 103 | '''This method visualizes the model prediction results. 104 | 105 | Args: 106 | image: The image used for keypoint&joint detection. 107 | result: The data returned after the model performs the detection. 108 | fig: The matplotlib Figure object. 109 | ''' 110 | 111 | if show_heatmap: 112 | ax = fig.add_subplot(1, 2, 1) 113 | fig.add_subplot(1, 2, 2).imshow(result['heatmap']) 114 | else: 115 | ax = fig.add_subplot() 116 | ax.imshow(image) 117 | for obj in result['objs']: 118 | rect = obj['rect'] 119 | pts = cv.boxPoints(rect) 120 | ax.add_patch(Polygon(pts, closed=True, fill=False, color='red')) 121 | if not show_annotate: continue 122 | ax.annotate( 123 | f"{obj['label']}: {round(obj['score'], 3)}", 124 | (pts[1] + pts[2]) * 0.5, 125 | color='white', 126 | ha='center', 127 | va='center', 128 | rotation=-rect[-1], 129 | bbox=dict( 130 | boxstyle='rarrow', 131 | color='red', 132 | pad=0, 133 | ), 134 | ) 135 | if show_nodes: 136 | for node in result['nodes']: 137 | x1, y1, x2, y2 = node['box'] 138 | xy = (x1 + x2) * 0.5, (y1 + y2) * 0.5 139 | radius = min(x2 - x1, y2 - y1) * 0.25 140 | color = 'blue' 141 | fill = True 142 | alpha = 0.5 143 | if node['anchor'] == 1: color = 'yellow' 144 | ax.add_patch(Circle( 145 | xy, radius, color=color, fill=fill, linewidth=1, alpha=alpha)) 146 | -------------------------------------------------------------------------------- /vklearn/pipelines/ocr.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Any, Mapping 2 | import io 3 | 4 | from PIL.Image import Image as PILImage 5 | from numpy import ndarray 6 | from PIL import Image 7 | from matplotlib.pyplot import Figure 8 | import torch 9 | 10 | from ..models.ocr import OCR as Model 11 | 12 | 13 | class OCR: 14 | '''This class is used for handling image ocr tasks. 15 | 16 | Args: 17 | model: ocr model. 
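    Example:
        A minimal sketch (illustrative only): the character list, checkpoint path
        and image path are placeholders, and TrimNetOcr is used as the underlying
        model.

            from vklearn.models.trimnetocr import TrimNetOcr

            ocr = OCR.load_from_state(
                TrimNetOcr(categories=['<blank>'] + list('0123456789')), 'ocr-best.pt')
            result = ocr('digits.png', align_size=32)
            print(result['text'])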
18 | ''' 19 | 20 | def __init__(self, model:Model): 21 | self.model = model 22 | 23 | @classmethod 24 | def load_from_state( 25 | cls, 26 | model: Model, 27 | state: Mapping[str, Any] | str, 28 | ) -> 'OCR': 29 | 30 | if isinstance(state, str): 31 | state = torch.load(state, map_location='cpu', weights_only=True) 32 | return cls(model.load_from_state(state).eval()) 33 | 34 | def export_onnx( 35 | self, 36 | f: str | io.BytesIO, 37 | align_size: int=32, 38 | ): 39 | inputs = torch.randn(1, 3, align_size, align_size * 10) 40 | torch.onnx.export( 41 | model=self.model, 42 | args=inputs, 43 | f=f, 44 | input_names=['input'], 45 | output_names=['output'], 46 | dynamic_axes={ 47 | 'input': {3: 'img_width'}, 48 | 'output': {3: 'seq_length'}, 49 | }, 50 | ) 51 | 52 | def to(self, device:torch.device) -> 'OCR': 53 | self.model.to(device) 54 | return self 55 | 56 | def __call__( 57 | self, 58 | image: PILImage | str | ndarray, 59 | top_k: int=10, 60 | align_size: int=224, 61 | to_gray: bool=True, 62 | whitelist: List[str] | None=None, 63 | ) -> List[Dict[str, Any]]: 64 | '''Invoke the method for image ocr. 65 | 66 | Args: 67 | image: The image to be recognized. 68 | top_k: Specifies the number of top classes to return, sorted by probability in descending order. 69 | align_size: The size to which the image will be aligned after preprocessing. 70 | to_gray: Convert the image mode to be gray, default: True. 71 | whitelist: The whitelist of the characters, default: None is disable the whitelist. 72 | ''' 73 | 74 | if isinstance(image, str): 75 | image = Image.open(image) 76 | if isinstance(image, ndarray): 77 | image = Image.fromarray(image, mode='RGB') 78 | 79 | with torch.no_grad(): 80 | result = self.model.recognize( 81 | image=image, 82 | top_k=top_k, 83 | align_size=align_size, 84 | to_gray=to_gray, 85 | whitelist=whitelist, 86 | ) 87 | return result 88 | 89 | def plot_result( 90 | self, 91 | image: PILImage, 92 | result: List[Dict[str, Any]], 93 | fig: Figure, 94 | ): 95 | '''This method visualizes the model prediction results. 96 | 97 | Args: 98 | image: The image used for ocr. 99 | result: The data returned after the model performs the ocr. 100 | fig: The matplotlib Figure object. 101 | ''' 102 | 103 | text = result['text'] 104 | ax = fig.add_subplot() 105 | ax.imshow(image) 106 | ax.set_xlabel(text) 107 | -------------------------------------------------------------------------------- /vklearn/pipelines/segment.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping 2 | import io 3 | 4 | from PIL.Image import Image as PILImage 5 | from numpy import ndarray 6 | from PIL import Image 7 | from matplotlib.pyplot import Figure 8 | import torch 9 | import numpy as np 10 | 11 | from ..models.segment import Segment as Model 12 | 13 | 14 | class Segment: 15 | '''This class is used for handling semantic segmentation tasks. 16 | 17 | Args: 18 | model: Semantic segmentation model. 
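    Example:
        A minimal sketch (illustrative only): the checkpoint path, image path and
        category names are placeholders, and TrimNetSeg is used as the underlying
        model.

            from vklearn.models.trimnetseg import TrimNetSeg

            segmenter = Segment.load_from_state(
                TrimNetSeg(categories=['background', 'pet']), 'seg-best.pt')
            masks = segmenter('pet.jpg', conf_thresh=0.5)
            # masks: numpy array of shape (num_classes, height, width)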
19 | ''' 20 | 21 | 22 | def __init__(self, model:Model): 23 | self.model = model 24 | 25 | @classmethod 26 | def load_from_state( 27 | cls, 28 | model: Model, 29 | state: Mapping[str, Any] | str, 30 | ) -> 'Segment': 31 | 32 | if isinstance(state, str): 33 | state = torch.load(state, map_location='cpu', weights_only=True) 34 | return cls(model.load_from_state(state).eval()) 35 | 36 | def export_onnx( 37 | self, 38 | f: str | io.BytesIO, 39 | align_size: int=448, 40 | ): 41 | inputs = torch.randn(1, 3, align_size, align_size) 42 | torch.onnx.export( 43 | model=self.model, 44 | args=inputs, 45 | f=f, 46 | input_names=['input'], 47 | output_names=['output'], 48 | dynamic_axes={ 49 | 'input': {0: 'batch_size'}, 50 | 'output': {0: 'batch_size'}, 51 | }, 52 | ) 53 | 54 | def to(self, device:torch.device) -> 'Segment': 55 | self.model.to(device) 56 | return self 57 | 58 | def __call__( 59 | self, 60 | image: PILImage | str | ndarray, 61 | conf_thresh: float=0.5, 62 | align_size: int=448, 63 | ) -> ndarray: 64 | '''Invoke the method for semantic segmentation. 65 | 66 | Args: 67 | image: The image to be segmented. 68 | conf_thresh: Confidence threshold. 69 | align_size: The size to which the image will be aligned after preprocessing. 70 | ''' 71 | 72 | if isinstance(image, str): 73 | image = Image.open(image) 74 | if isinstance(image, ndarray): 75 | image = Image.fromarray(image, mode='RGB') 76 | 77 | with torch.no_grad(): 78 | result = self.model.segment( 79 | image=image, 80 | conf_thresh=conf_thresh, 81 | align_size=align_size, 82 | ) 83 | return result 84 | 85 | def plot_result( 86 | self, 87 | image: PILImage, 88 | result: ndarray, 89 | fig: Figure, 90 | ): 91 | '''This method visualizes the model prediction results. 92 | 93 | Args: 94 | image: The image used for classification. 95 | result: The data returned after the model performs the segmentation. 96 | fig: The matplotlib Figure object. 
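        Example:
            A short sketch of the intended call pattern (illustrative only;
            `segmenter`, `image` and `masks` are assumed to come from the
            pipeline example in the class docstring above).

                import matplotlib.pyplot as plt

                fig = plt.figure(figsize=(8, 4))
                segmenter.plot_result(image, masks, fig)
                plt.show()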
97 | ''' 98 | 99 | plot_cols = len(result) 100 | for i in range(plot_cols): 101 | ax = fig.add_subplot(1, plot_cols, i + 1) 102 | mask = 1 - result[i] 103 | frame = np.array(image, dtype=np.uint8) 104 | frame[..., 1] = (frame[..., 1] * mask).astype(np.uint8) 105 | ax.imshow(frame) 106 | -------------------------------------------------------------------------------- /vklearn/trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/vklearn/trainer/__init__.py -------------------------------------------------------------------------------- /vklearn/trainer/logging.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | from collections import defaultdict, deque 3 | import json 4 | import os 5 | from datetime import datetime 6 | 7 | from torch import Tensor 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | class Logger: 12 | DEFAULT_MODES = [ 13 | 'train', 14 | 'valid', 15 | 'test', 16 | 'metric', 17 | ] 18 | 19 | def __init__( 20 | self, 21 | name: str | None=None, 22 | log_dir: str | None=None, 23 | ): 24 | 25 | self._name = name or 'TRAIN' 26 | self.reset() 27 | 28 | self._log_dir = log_dir or os.path.join( 29 | os.path.abspath(os.path.curdir), 'logs') 30 | if not os.path.isdir(self._log_dir): 31 | os.makedirs(self._log_dir, exist_ok=True) 32 | self._dump_path = os.path.join(self._log_dir, ( 33 | f'{self._name}-LOG-' 34 | f'{datetime.now():%y%m%d-%H:%M:%S}' 35 | '.txt' 36 | )) 37 | 38 | def reset(self, maxlen:int=100): 39 | self._step = defaultdict(int) 40 | # self._logs = {mode: defaultdict(float) 41 | # self._logs = {mode: defaultdict(lambda: deque(maxlen=maxlen)) 42 | # for mode in self.DEFAULT_MODES} 43 | self._logs = {} 44 | for mode in self.DEFAULT_MODES: 45 | if mode in ('train', 'valid'): 46 | self._logs[mode] = defaultdict(lambda: deque(maxlen=maxlen)) 47 | else: 48 | self._logs[mode] = defaultdict(float) 49 | 50 | def update( 51 | self, 52 | mode: str, 53 | *datas: Dict[str, float | Tensor], 54 | ): 55 | 56 | for data in datas: 57 | for key, raw in data.items(): 58 | if isinstance(raw, Tensor): 59 | value = raw.item() 60 | else: 61 | value = raw 62 | # self._logs[mode][key] += value 63 | # self._logs[mode][key].append(value) 64 | if isinstance(self._logs[mode][key], deque): 65 | self._logs[mode][key].append(value) 66 | else: 67 | self._logs[mode][key] += value 68 | self._step[mode] += 1 69 | 70 | def calc_mean_value( 71 | self, 72 | values: float | deque, 73 | step: int, 74 | ) -> float: 75 | 76 | if not isinstance(values, deque): 77 | return values / max(1, step) 78 | return sum(values) / max(1, len(values)) 79 | 80 | def compute( 81 | self, 82 | mode: str, 83 | # mean: bool=True, 84 | ) -> Dict[str, float] | None: 85 | 86 | if self._step[mode] < 1: return dict() 87 | # step = self._step[mode] if mean else 1 88 | step = self._step[mode] 89 | # calc_mean = lambda vs: round(sum(vs) / max(1, len(vs)), 5) 90 | return { 91 | # k: round(v / step, 5) 92 | # for k, v in self._logs[mode].items()} 93 | # k: calc_mean(vs) for k, vs in self._logs[mode].items()} 94 | k: round(self.calc_mean_value(vs, step), 5) 95 | for k, vs in self._logs[mode].items()} 96 | 97 | def dumps(self, *modes:str) -> str: 98 | modes = modes or self.DEFAULT_MODES 99 | datas = {mode: self.compute(mode) for mode in modes} 100 | return json.dumps(datas, ensure_ascii=False) 101 | 102 | def dumpf(self, *modes:str) -> 
str: 103 | dump_str = self.dumps(*modes) 104 | with open(self._dump_path, 'a') as f: 105 | f.write(dump_str + '\n') 106 | return dump_str 107 | 108 | def plot(self): 109 | self.plot_from_disk(self._dump_path) 110 | 111 | @classmethod 112 | def plot_from_disk(cls, filepath:str): 113 | frames:Dict[str, Dict[str, List]] = dict() 114 | with open(filepath) as f: 115 | lines = f.readlines() 116 | for line in lines: 117 | line = line.strip() 118 | if not line: continue 119 | datas = json.loads(line) 120 | for mode, data in datas.items(): 121 | for key, value in data.items(): 122 | if key not in frames: 123 | frames[key] = dict() 124 | if mode not in frames[key]: 125 | frames[key][mode] = list() 126 | frames[key][mode].append(value) 127 | 128 | path = os.path.splitext(filepath)[0] 129 | for key, frame in frames.items(): 130 | plt.figure() 131 | plt.title(key) 132 | plt.xlabel('epoch') 133 | for mode, values in frame.items(): 134 | plt.plot(range(1, len(values) + 1), values, label=mode) 135 | plt.legend() 136 | plt.savefig(f'{path}_{key}.png') 137 | 138 | @classmethod 139 | def create_by_output(cls, output:str) -> 'Logger': 140 | basename = os.path.basename(output).upper() 141 | log_dir = os.path.join(os.path.dirname(output), 'logs') 142 | return cls(name=basename, log_dir=log_dir) 143 | -------------------------------------------------------------------------------- /vklearn/trainer/task.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Tuple 2 | from dataclasses import dataclass, field 3 | import os.path 4 | 5 | from torch.optim import Optimizer 6 | from torch.optim.lr_scheduler import LRScheduler 7 | import torch 8 | 9 | from .logging import Logger 10 | from ..models.basic import Basic as Model 11 | 12 | 13 | @dataclass 14 | class Task: 15 | '''This `class` is used to configure a set of parameters relevant to a specific task in model training. 16 | 17 | Args: 18 | model: Model object for a specific task. 19 | device: Computation device supported by PyTorch. 20 | metric_start_epoch: Sets the epoch from which metric calculation starts, defaults to 0. 21 | fit_features_start: Sets the epoch from which feature extractor training starts, -1 means no training, defaults to -1. 22 | loss_options: Optional parameters for the given loss calculation function. 23 | score_options: Optional parameters for the given score calculation function. 24 | metric_options: Optional parameters for the given metric evaluation function. 25 | key_metrics: Specifies which key evaluation metrics to track. 26 | best_metric: Current best metric score, initialized to 0. 
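    Example:
        A minimal configuration sketch (illustrative only): it uses the
        Classification task and TrimNetClf model defined elsewhere in this
        package, assumes a CUDA device is available, and omits the surrounding
        trainer wiring.

            import torch
            from vklearn.trainer.tasks import Classification
            from vklearn.models.trimnetclf import TrimNetClf

            task = Classification(
                model=TrimNetClf(categories=['cat', 'dog']),
                device=torch.device('cuda'),
                fit_features_start=3,  # start training the feature extractor at epoch 3
            )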
27 | ''' 28 | 29 | model: Model 30 | device: torch.device 31 | metric_start_epoch: int=0 32 | fit_features_start: int=-1 33 | loss_options: Dict[str, Any]=field(default_factory=dict) 34 | score_options: Dict[str, Any]=field(default_factory=dict) 35 | metric_options: Dict[str, Any]=field(default_factory=dict) 36 | key_metrics: Tuple[str]=field(default_factory=tuple) 37 | best_metric: float=0 38 | 39 | def sample_convert(self, sample: Any) -> Tuple[Any, Any]: 40 | 41 | assert not 'this is an empty func' 42 | 43 | def train_on_step( 44 | self, 45 | epoch: int, 46 | step: int, 47 | sample: Any, 48 | logger: Logger, 49 | grad_steps: int, 50 | ): 51 | 52 | inputs, target = self.sample_convert(sample) 53 | 54 | model = self.model 55 | model.train_features( 56 | self.fit_features_start >= 0 and 57 | epoch >= self.fit_features_start 58 | ) 59 | 60 | outputs = model(inputs) 61 | losses = model.calc_loss( 62 | outputs, *target, **self.loss_options) 63 | 64 | loss = losses['loss'] / grad_steps 65 | loss.backward() 66 | 67 | with torch.no_grad(): 68 | scores = model.calc_score( 69 | outputs, *target, **self.score_options) 70 | 71 | logger.update('train', losses, scores) 72 | 73 | def valid_on_step( 74 | self, 75 | epoch: int, 76 | step: int, 77 | sample: Any, 78 | logger: Logger, 79 | ): 80 | 81 | inputs, target = self.sample_convert(sample) 82 | 83 | model = self.model 84 | 85 | with torch.no_grad(): 86 | outputs = model(inputs) 87 | losses = model.calc_loss( 88 | outputs, *target, **self.loss_options) 89 | scores = model.calc_score( 90 | outputs, *target, **self.score_options) 91 | 92 | logger.update('valid', losses, scores) 93 | 94 | def test_on_step( 95 | self, 96 | epoch: int, 97 | step: int, 98 | sample: Any, 99 | logger: Logger, 100 | ): 101 | 102 | inputs, target = self.sample_convert(sample) 103 | 104 | model = self.model 105 | 106 | with torch.no_grad(): 107 | outputs = model(inputs) 108 | losses = model.calc_loss( 109 | outputs, *target, **self.loss_options) 110 | scores = model.calc_score( 111 | outputs, *target, **self.score_options) 112 | if epoch >= self.metric_start_epoch: 113 | model.update_metric( 114 | outputs, *target, **self.metric_options) 115 | 116 | logger.update('test', losses, scores) 117 | 118 | def end_on_epoch( 119 | self, 120 | epoch: int, 121 | logger: Logger, 122 | ): 123 | 124 | model = self.model 125 | metric = dict(zip(self.key_metrics, [0.] 
* len(self.key_metrics))) 126 | if epoch >= self.metric_start_epoch: 127 | metric = {k: v 128 | for k, v in model.compute_metric().items() 129 | if k in metric} 130 | logger.update('metric', metric) 131 | 132 | def save_checkpoint( 133 | self, 134 | epoch: int, 135 | output: str, 136 | optimizer: Optimizer, 137 | lr_scheduler: LRScheduler, 138 | ) -> str: 139 | 140 | out_path = os.path.splitext(output)[0] 141 | filename = f'{out_path}-{epoch}.pt' 142 | torch.save({ 143 | 'model': self.model.state_dict(), 144 | 'hyperparameters': self.model.hyperparameters(), 145 | 'optim': optimizer.state_dict(), 146 | 'lr_scheduler': lr_scheduler.state_dict()}, filename) 147 | return filename 148 | 149 | def choose_best_model( 150 | self, 151 | output: str, 152 | optimizer: Optimizer, 153 | lr_scheduler: LRScheduler, 154 | logger: Logger, 155 | ) -> bool: 156 | 157 | metric = logger.compute('metric')[self.key_metrics[0]] 158 | if self.best_metric >= metric: return False 159 | 160 | self.best_metric = metric 161 | out_path = os.path.splitext(output)[0] 162 | filename = f'{out_path}-best.pt' 163 | torch.save({ 164 | 'model': self.model.state_dict(), 165 | 'hyperparameters': self.model.hyperparameters(), 166 | 'optim': optimizer.state_dict(), 167 | 'lr_scheduler': lr_scheduler.state_dict()}, filename) 168 | return True 169 | -------------------------------------------------------------------------------- /vklearn/trainer/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .detection import Detection 2 | from .classification import Classification 3 | from .segmentation import Segmentation 4 | from .joints import Joints 5 | from .distillation import Distillation 6 | from .ocr import OCR 7 | -------------------------------------------------------------------------------- /vklearn/trainer/tasks/classification.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple 2 | from dataclasses import dataclass 3 | 4 | from ..task import Task 5 | from ...models.classifier import Classifier as Model 6 | 7 | 8 | @dataclass 9 | class Classification(Task): 10 | '''This `class` is used to configure a set of parameters relevant to a specific task in classifier model training. 11 | 12 | Args: 13 | model: Specify a classification model object. 14 | device: Computation device supported by PyTorch. 15 | metric_start_epoch: Sets the epoch from which metric calculation starts, defaults to 0. 16 | fit_features_start: Sets the epoch from which feature extractor training starts, -1 means no training, defaults to -1. 17 | loss_options: Set optional parameters for the classification model's loss function. 18 | score_options: Set optional parameters for the classification model's scoring function. 19 | metric_options: Set optional parameters for the classification model's metric evaluation function. 20 | key_metrics: Specifies which key evaluation metrics to track. 21 | best_metric: Current best metric score, initialized to 0. 
22 | ''' 23 | 24 | model: Model 25 | key_metrics: Tuple[str]=('f1_score', 'precision', 'recall') 26 | 27 | def sample_convert(self, sample: Any) -> Tuple[Any, Any]: 28 | inputs, target = [item.to(self.device) for item in sample] 29 | return inputs, [target] 30 | -------------------------------------------------------------------------------- /vklearn/trainer/tasks/detection.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple 2 | from dataclasses import dataclass 3 | 4 | from ..task import Task 5 | from ...models.detector import Detector as Model 6 | 7 | 8 | @dataclass 9 | class Detection(Task): 10 | '''This `class` is used to configure a set of parameters relevant to a specific task in object detector model training. 11 | 12 | Args: 13 | model: Specify an object detection model object. 14 | device: Computation device supported by PyTorch. 15 | metric_start_epoch: Sets the epoch from which metric calculation starts, defaults to 0. 16 | fit_features_start: Sets the epoch from which feature extractor training starts, -1 means no training, defaults to -1. 17 | loss_options: Set optional parameters for the object detection model's loss function. 18 | score_options: Set optional parameters for the object detection model's scoring function. 19 | metric_options: Set optional parameters for the object detection model's metric evaluation function. 20 | key_metrics: Specifies which key evaluation metrics to track. 21 | best_metric: Current best metric score, initialized to 0. 22 | ''' 23 | 24 | model: Model 25 | key_metrics: Tuple[str]=('map', 'map_50', 'map_75') 26 | 27 | def sample_convert(self, sample: Any) -> Tuple[Any, Any]: 28 | inputs, target_labels, target_bboxes = [ 29 | sample[i].to(self.device) for i in [0, 2, 3]] 30 | target_index = sample[1] 31 | target = target_index, target_labels, target_bboxes 32 | return inputs, target 33 | -------------------------------------------------------------------------------- /vklearn/trainer/tasks/distillation.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple 2 | from dataclasses import dataclass 3 | 4 | from ..task import Task 5 | from ...models.distiller import Distiller as Model 6 | 7 | 8 | @dataclass 9 | class Distillation(Task): 10 | '''This `class` is used to configure a set of parameters relevant to a specific task in distiller model training. 11 | 12 | Args: 13 | model: Specify a distillation model object. 14 | device: Computation device supported by PyTorch. 15 | metric_start_epoch: Sets the epoch from which metric calculation starts, defaults to 0. 16 | fit_features_start: Sets the epoch from which feature extractor training starts, -1 means no training, defaults to -1. 17 | loss_options: Set optional parameters for the distillation model's loss function. 18 | score_options: Set optional parameters for the distillation model's scoring function. 19 | metric_options: Set optional parameters for the distillation model's metric evaluation function. 20 | key_metrics: Specifies which key evaluation metrics to track. 21 | best_metric: Current best metric score, initialized to 0.
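    Example (a minimal sketch; `distiller_model` is a placeholder for a distiller model built from `vklearn.models`, whose construction is assumed — the fields shown correspond to this dataclass):

        import torch
        from vklearn.trainer.tasks import Distillation

        task = Distillation(
            model=distiller_model,  # assumed: a Distiller-based model instance
            device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        )
        # Distillation batches need no labels: `sample_convert` only moves
        # sample[0] (the input batch) to the device and passes [None] as target.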
22 | ''' 23 | 24 | model: Model 25 | key_metrics: Tuple[str]=('mss', 'mse', 'mae') 26 | 27 | def sample_convert(self, sample: Any) -> Tuple[Any, Any]: 28 | inputs = sample[0].to(self.device) 29 | return inputs, [None] 30 | -------------------------------------------------------------------------------- /vklearn/trainer/tasks/joints.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple 2 | from dataclasses import dataclass 3 | 4 | from ..task import Task 5 | from ...models.joints import Joints as Model 6 | 7 | 8 | @dataclass 9 | class Joints(Task): 10 | '''This `class` is used to configure a set of parameters relevant to a specific task in keypoint & joint detector model training. 11 | 12 | Args: 13 | model: Specify a keypoint & joint detection model object. 14 | device: Computation device supported by PyTorch. 15 | metric_start_epoch: Sets the epoch from which metric calculation starts, defaults to 0. 16 | fit_features_start: Sets the epoch from which feature extractor training starts, -1 means no training, defaults to -1. 17 | loss_options: Set optional parameters for the keypoint & joint detection model's loss function. 18 | score_options: Set optional parameters for the keypoint & joint detection model's scoring function. 19 | metric_options: Set optional parameters for the keypoint & joint detection model's metric evaluation function. 20 | key_metrics: Specifies which key evaluation metrics to track. 21 | best_metric: Current best metric score, initialized to 0. 22 | ''' 23 | 24 | model: Model 25 | key_metrics: Tuple[str]=('mjoin', 'map', 'map_50', 'map_75', 'miou') 26 | 27 | def sample_convert(self, sample: Any) -> Tuple[Any, Any]: 28 | inputs, target_labels, target_bboxes, target_masks = [ 29 | sample[i].to(self.device) for i in [0, 2, 3, 4]] 30 | target_index = sample[1] 31 | target = target_index, target_labels, target_bboxes, target_masks 32 | return inputs, target 33 | -------------------------------------------------------------------------------- /vklearn/trainer/tasks/ocr.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple 2 | from dataclasses import dataclass 3 | 4 | from ..task import Task 5 | from ...models.ocr import OCR as Model 6 | 7 | 8 | @dataclass 9 | class OCR(Task): 10 | '''This `class` is used to configure a set of parameters relevant to a specific task in OCR model training. 11 | 12 | Args: 13 | model: Specify an OCR model object. 14 | device: Computation device supported by PyTorch. 15 | metric_start_epoch: Sets the epoch from which metric calculation starts, defaults to 0. 16 | fit_features_start: Sets the epoch from which feature extractor training starts, -1 means no training, defaults to -1. 17 | loss_options: Set optional parameters for the OCR model's loss function. 18 | score_options: Set optional parameters for the OCR model's scoring function. 19 | metric_options: Set optional parameters for the OCR model's metric evaluation function. 20 | key_metrics: Specifies which key evaluation metrics to track. 21 | best_metric: Current best metric score, initialized to 0.
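    Example (a minimal sketch; `ocr_model` is a placeholder for an OCR model built from `vklearn.models`, whose construction is assumed — the fields shown correspond to this dataclass):

        import torch
        from vklearn.trainer.tasks import OCR

        task = OCR(
            model=ocr_model,  # assumed: an OCR-based model instance
            device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        )
        # Batches are expected to be (inputs, targets, target_lengths) triples;
        # `sample_convert` moves all three tensors to the device and returns
        # (inputs, (targets, target_lengths)).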
22 | ''' 23 | 24 | model: Model 25 | key_metrics: Tuple[str]=('c_score', 'cer') 26 | 27 | def sample_convert(self, sample: Any) -> Tuple[Any, Any]: 28 | inputs, targets, target_lengths = [item.to(self.device) for item in sample] 29 | return inputs, (targets, target_lengths) 30 | -------------------------------------------------------------------------------- /vklearn/trainer/tasks/segmentation.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple 2 | from dataclasses import dataclass 3 | 4 | from ..task import Task 5 | from ...models.segment import Segment as Model 6 | 7 | 8 | @dataclass 9 | class Segmentation(Task): 10 | '''This `class` is used to configure a set of parameters relevant to a specific task in segmentation model training. 11 | 12 | Args: 13 | model: Specify a segmentation model object. 14 | device: Computation device supported by PyTorch. 15 | metric_start_epoch: Sets the epoch from which metric calculation starts, defaults to 0. 16 | fit_features_start: Sets the epoch from which feature extractor training starts, -1 means no training, defaults to -1. 17 | loss_options: Set optional parameters for the segmentation model's loss function. 18 | score_options: Set optional parameters for the segmentation model's scoring function. 19 | metric_options: Set optional parameters for the segmentation model's metric evaluation function. 20 | key_metrics: Specifies which key evaluation metrics to track. 21 | best_metric: Current best metric score, initialized to 0. 22 | ''' 23 | 24 | model: Model 25 | key_metrics: Tuple[str]=('miou', ) 26 | 27 | def sample_convert(self, sample: Any) -> Tuple[Any, Any]: 28 | inputs, target = [item.to(self.device) for item in sample] 29 | return inputs, [target] 30 | -------------------------------------------------------------------------------- /vklearn/trainer/trainer.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | from pprint import pprint 3 | from typing import Callable 4 | from dataclasses import dataclass 5 | import math 6 | 7 | from torch.optim import Optimizer, AdamW 8 | from torch.optim.lr_scheduler import LambdaLR 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.utils.data import DataLoader 13 | 14 | from .task import Task 15 | from .logging import Logger 16 | 17 | 18 | @dataclass 19 | class Trainer: 20 | '''This `class` is used to set general parameters for model training and starts the model training process through the `fit` method. 21 | 22 | Args: 23 | task: The object of the training task. 24 | output: The path to the output file. 25 | train_loader: The data loader for the training set. 26 | valid_loader: The data loader for the validation set. 27 | test_loader: The data loader for the test set. 28 | checkpoint: The archive file for the model training parameters. 29 | drop_optim: Whether to drop the optimizer parameters from the archive file. 30 | drop_lr_scheduler: Whether to drop the learning rate parameters from the archive file. 31 | optim_method: The optimization method. 32 | lr: The learning rate. 33 | weight_decay: The regularization weight. 34 | lrf: The learning rate decay factor. 35 | T_num: The number of learning rate change cycles. 36 | grad_steps: The number of steps for gradient updates. 37 | grad_max_norm: The max norm of gradient. 38 | epochs: The total number of training epochs. 39 | show_step: Set the interval for displaying the training status in steps. 
40 | save_epoch: Set the interval for saving the model in epochs. 41 | logs_deque_limit: Set the limit of the logs deque. 42 | ''' 43 | 44 | task: Task 45 | output: str 46 | train_loader: DataLoader 47 | 48 | valid_loader: DataLoader=None 49 | test_loader: DataLoader=None 50 | checkpoint: str | None=None 51 | drop_optim: bool=False 52 | drop_lr_scheduler: bool=False 53 | optim_method: Callable[..., Optimizer]=AdamW 54 | lr: float=1e-3 55 | weight_decay: float | None=None 56 | lrf: float=1. 57 | T_num: float=1. 58 | grad_steps: int=1 59 | grad_max_norm: float=0. 60 | epochs: int=1 61 | show_step: int=50 62 | save_epoch: int=1 63 | logs_deque_limit: int=77 64 | 65 | def _dump_progress( 66 | self, 67 | epoch: int, 68 | step: int, 69 | loader: DataLoader, 70 | ) -> str: 71 | 72 | return ( 73 | f'epoch: {epoch + 1}/{self.epochs}, ' 74 | f'step: {step + 1}/{len(loader)}' 75 | ) 76 | 77 | def initialize(self): 78 | print('Preparing ...') 79 | self.device:torch.device = self.task.device 80 | print('device:', self.device) 81 | 82 | self.model:nn.Module = self.task.model.to(self.device) 83 | 84 | if self.weight_decay is None: 85 | self.weight_decay = 0. 86 | if self.optim_method is AdamW: 87 | self.weight_decay = 0.01 88 | 89 | self.optimizer:Optimizer = self.optim_method( 90 | self.model.parameters(), 91 | lr=self.lr, 92 | weight_decay=self.weight_decay) 93 | print('optimizer:', self.optimizer) 94 | 95 | for kind, loader in zip( 96 | ['train', 'valid', 'test'], 97 | [self.train_loader, self.valid_loader, self.test_loader], 98 | ): 99 | if loader is None: continue 100 | print(f'{kind} dataset:', loader.dataset) 101 | pprint(dict( 102 | batch_size=loader.batch_size, 103 | num_workers=loader.num_workers, 104 | )) 105 | 106 | assert self.lrf <= 1. 107 | self.lr_scheduler = LambdaLR(self.optimizer, lr_lambda= 108 | lambda epoch: 109 | (1 + math.cos(epoch / (self.epochs / self.T_num) * math.pi)) * 110 | 0.5 * (1 - self.lrf) + self.lrf) 111 | 112 | if self.checkpoint is not None: 113 | print('checkpoint:', self.checkpoint) 114 | state_dict = torch.load(self.checkpoint, weights_only=True) 115 | self.model.load_state_dict(state_dict['model'], strict=False) 116 | if not self.drop_optim: 117 | self.optimizer.load_state_dict(state_dict['optim']) 118 | if not self.drop_lr_scheduler: 119 | self.lr_scheduler.load_state_dict(state_dict['lr_scheduler']) 120 | 121 | def fit( 122 | self, 123 | max_train_step: int=0, 124 | max_test_step: int=0, 125 | ): 126 | 127 | task = self.task 128 | optimizer = self.optimizer 129 | lr_scheduler = self.lr_scheduler 130 | logger = Logger.create_by_output(self.output) 131 | print('-' * 80) 132 | print('Training ...') 133 | for epoch in range(self.epochs): 134 | self.model.train() 135 | optimizer.zero_grad() 136 | logger.reset(maxlen=self.logs_deque_limit) 137 | 138 | train_loader = self.train_loader 139 | 140 | valid_loader = self.valid_loader 141 | if valid_loader is not None: 142 | valid_generator = iter(valid_loader) 143 | 144 | print('train mode:', self.model.training) 145 | print(f'lr={lr_scheduler.get_last_lr()}') 146 | for step, sample in enumerate(train_loader): 147 | task.train_on_step(epoch, step, sample, logger, self.grad_steps) 148 | 149 | if (step + 1) % self.show_step == 0: 150 | print(self._dump_progress(epoch, step, train_loader)) 151 | print(logger.dumps('train')) 152 | 153 | if valid_loader is not None: 154 | try: 155 | sample = next(valid_generator) 156 | except StopIteration: 157 | valid_generator = iter(valid_loader) 158 | sample = next(valid_generator) 159 | 
task.valid_on_step(epoch, step, sample, logger) 160 | 161 | if (step + 1) % self.show_step == 0: 162 | print('valid:', logger.dumps('valid')) 163 | 164 | if (step + 1) % self.grad_steps == 0: 165 | if self.grad_max_norm > 0.: 166 | torch.nn.utils.clip_grad_norm_( 167 | self.model.parameters(), max_norm=self.grad_max_norm) 168 | optimizer.step() 169 | optimizer.zero_grad() 170 | 171 | if (max_train_step > 0) and ((step + 1) >= max_train_step): 172 | break 173 | 174 | optimizer.zero_grad() 175 | lr_scheduler.step() 176 | 177 | if (epoch + 1) % self.save_epoch == 0: 178 | checkpoint_filename = task.save_checkpoint( 179 | epoch, self.output, optimizer, lr_scheduler) 180 | print('save checkpoint -> {}'.format(checkpoint_filename)) 181 | 182 | self.model.eval() 183 | test_loader = self.test_loader 184 | 185 | if test_loader is not None: 186 | print('train mode:', self.model.training) 187 | for step, sample in enumerate(test_loader): 188 | task.test_on_step(epoch, step, sample, logger) 189 | 190 | if (step + 1) % self.show_step == 0: 191 | print(self._dump_progress(epoch, step, test_loader)) 192 | print(logger.dumps('test')) 193 | 194 | if (max_test_step > 0) and ((step + 1) >= max_test_step): 195 | break 196 | 197 | task.end_on_epoch(epoch, logger) 198 | print(logger.dumpf()) 199 | 200 | print('A new best model emerges:', task.choose_best_model( 201 | self.output, optimizer, lr_scheduler, logger)) 202 | 203 | checkpoint_filename = task.save_checkpoint( 204 | epoch, self.output, optimizer, lr_scheduler) 205 | print('finished and save checkpoint -> {}'.format(checkpoint_filename)) 206 | logger.plot() 207 | -------------------------------------------------------------------------------- /vklearn/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bxt-kk/vikit-learn/9cab54cd542e861fb43ffc87191209595aa45ebe/vklearn/utils/__init__.py -------------------------------------------------------------------------------- /vklearn/utils/focal_boost.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | import math 3 | 4 | from torch import Tensor 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | from torchvision.ops import ( 9 | sigmoid_focal_loss, 10 | ) 11 | 12 | 13 | def focal_boost_iter( 14 | inputs: Tensor, 15 | target_index: List[Tensor], 16 | sample_mask: Tensor | None, 17 | conf_id: int, 18 | num_confs: int, 19 | alpha: float, 20 | gamma: float, 21 | ) -> Tuple[Tensor, Tensor, Tensor]: 22 | 23 | reduction = 'mean' 24 | 25 | pred_conf = inputs[..., conf_id] 26 | targ_conf = torch.zeros_like(pred_conf) 27 | targ_conf[target_index] = 1. 28 | 29 | if sample_mask is None: 30 | sample_mask = targ_conf >= -1 31 | 32 | sampled_pred = torch.masked_select(pred_conf, sample_mask) 33 | sampled_targ = torch.masked_select(targ_conf, sample_mask) 34 | sampled_loss = sigmoid_focal_loss( 35 | inputs=sampled_pred, 36 | targets=sampled_targ, 37 | alpha=alpha, 38 | gamma=gamma, 39 | reduction=reduction, 40 | ) 41 | 42 | obj_loss = 0. 
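    # The block below adds an objectness term on the cells assigned to
    # ground-truth instances (when any exist): `obj_mask` marks those cells,
    # `instance_weight` down-weights images that own many positive cells
    # (1 / count of positive cells per image, via bincount of the batch indices
    # in target_index[0]), and the weighted BCE-with-logits losses are summed
    # and divided by the batch size. `sample_mask` is then replaced by every
    # cell whose predicted confidence is at least the weakest true-positive
    # prediction, so the next confidence stage in the cascade computes its
    # focal loss only over these harder candidate cells.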
43 | obj_mask = targ_conf > 0.5 44 | if obj_mask.sum() > 0: 45 | # Lab code <<< 46 | # obj_pred = torch.masked_select(pred_conf, obj_mask) 47 | # obj_targ = torch.masked_select(targ_conf, obj_mask) 48 | # obj_loss = F.binary_cross_entropy_with_logits( 49 | # obj_pred, obj_targ, reduction=reduction) 50 | # instance_weight = ( 51 | # 1 / torch.clamp_min(targ_conf.flatten(start_dim=1).sum(dim=1), 1) 52 | # )[target_index[0]] 53 | instance_weight = 1 / target_index[0].bincount().type_as(targ_conf)[target_index[0]] 54 | obj_pred = pred_conf[target_index] 55 | obj_targ = targ_conf[target_index] 56 | obj_loss = (instance_weight * F.binary_cross_entropy_with_logits( 57 | obj_pred, obj_targ, reduction='none')).sum() / inputs.shape[0] 58 | # >>> 59 | 60 | obj_pred_min = obj_pred.detach().min() 61 | sample_mask = pred_conf.detach() >= obj_pred_min 62 | 63 | F_sigma = lambda t: (math.cos(t / num_confs * math.pi) + 1) * 0.5 64 | sigma_0 = 0.25 65 | 66 | sigma_T = F_sigma(num_confs - 1) 67 | scale_r = (1 - sigma_0) / (1 - sigma_T) 68 | sigma_k = F_sigma(conf_id) * scale_r - scale_r + 1 69 | # num_foreground_per_img = (sampled_targ.sum() / len(pred_conf)).item() 70 | sn_ratio = min(1, 2 * targ_conf.sum().item() / targ_conf.numel()) 71 | an_ratio = 1 - sn_ratio 72 | 73 | conf_loss = ( 74 | # obj_loss * sigma_k / max(1, num_foreground_per_img**0.5) + 75 | obj_loss * sigma_k * an_ratio + 76 | sampled_loss) / num_confs 77 | 78 | return conf_loss, sampled_loss, sample_mask 79 | 80 | 81 | def focal_boost_loss( 82 | inputs: Tensor, 83 | target_index: List[Tensor], 84 | num_confs: int, 85 | alpha: float=0.25, 86 | gamma: float=2., 87 | ) -> Tuple[Tensor, Tensor]: 88 | 89 | conf_loss, sampled_loss, sample_mask = focal_boost_iter( 90 | inputs, target_index, None, 0, num_confs, alpha, gamma) 91 | for conf_id in range(1, num_confs): 92 | conf_loss_i, sampled_loss, sample_mask = focal_boost_iter( 93 | inputs, target_index, sample_mask, conf_id, num_confs, alpha, gamma) 94 | conf_loss += conf_loss_i 95 | return conf_loss, sampled_loss 96 | 97 | 98 | def focal_boost_predict( 99 | inputs: Tensor, 100 | num_confs: int, 101 | recall_thresh: float, 102 | ) -> Tensor: 103 | 104 | predict = torch.ones_like(inputs[..., 0]) 105 | for conf_id in range(num_confs - 1): 106 | predict[torch.sigmoid(inputs[..., conf_id]) < recall_thresh] = 0. 107 | predict = predict * torch.sigmoid(inputs[..., num_confs - 1]) 108 | return predict 109 | 110 | 111 | def focal_boost_positive( 112 | inputs: Tensor, 113 | num_confs: int, 114 | conf_thresh: float=0.5, 115 | recall_thresh: float=0.5, 116 | top_k: int=0, 117 | ) -> Tensor: 118 | 119 | predict = focal_boost_predict(inputs, num_confs, recall_thresh) 120 | if top_k <= 0: 121 | return predict >= conf_thresh 122 | samples = predict.flatten(start_dim=1) 123 | top_k = min(samples.shape[1], top_k) 124 | conf_threshes = torch.clamp_min( 125 | samples.topk(top_k).values[:, -1], conf_thresh) 126 | for _ in range(len(predict.shape) - 1): 127 | conf_threshes.unsqueeze_(dim=-1) 128 | return predict >= conf_threshes 129 | -------------------------------------------------------------------------------- /vklearn/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.0.2' 2 | --------------------------------------------------------------------------------
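Taken together, vklearn/trainer/trainer.py drives training from a Task plus data loaders, and `initialize()` must be called before `fit()` so that the optimizer, LR scheduler and any checkpoint are prepared. Below is a minimal end-to-end sketch: `classifier_model`, `train_dataset`, `valid_dataset` and `test_dataset` are placeholders assumed to be built elsewhere, while the fields and the call sequence follow the sources above.

    import torch
    from torch.utils.data import DataLoader
    from vklearn.trainer.trainer import Trainer
    from vklearn.trainer.tasks import Classification

    # `classifier_model` and the datasets are assumed placeholders.
    task = Classification(
        model=classifier_model,
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        fit_features_start=3,
    )
    trainer = Trainer(
        task=task,
        output='checkpoints/catdog',  # hypothetical prefix -> checkpoints/catdog-<epoch>.pt and checkpoints/catdog-best.pt
        train_loader=DataLoader(train_dataset, batch_size=32, shuffle=True),
        valid_loader=DataLoader(valid_dataset, batch_size=32),
        test_loader=DataLoader(test_dataset, batch_size=32),
        lr=1e-3,
        lrf=0.01,      # cosine schedule: LR multiplier decays from 1.0 to lrf over epochs / T_num epochs
        grad_steps=4,  # accumulate gradients over 4 steps before each optimizer.step()
        epochs=20,
        show_step=50,
        save_epoch=1,
    )
    trainer.initialize()  # builds the optimizer/scheduler and loads `checkpoint` if one is given
    trainer.fit()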
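The functions in vklearn/utils/focal_boost.py operate on a prediction tensor whose trailing dimension starts with `num_confs` cascaded confidence logits, with positive cells addressed by a list of index tensors. The toy call below illustrates the call signatures; the (batch, cells, channels) layout is an assumption for illustration only, since the real layout is defined by the detector models.

    import torch
    from vklearn.utils.focal_boost import focal_boost_loss, focal_boost_positive

    num_confs = 3
    # Dummy predictions: 2 images, 100 cells; the first num_confs channels are the
    # cascaded confidence logits, the remaining channels stand in for box/class outputs.
    inputs = torch.randn(2, 100, num_confs + 4)
    # Positive cells given as index tensors: (batch indices, cell indices).
    target_index = [torch.tensor([0, 1]), torch.tensor([10, 42])]

    conf_loss, sampled_loss = focal_boost_loss(inputs, target_index, num_confs)
    positive_mask = focal_boost_positive(inputs, num_confs, conf_thresh=0.5, top_k=10)
    print(conf_loss.item(), positive_mask.shape)  # scalar loss, boolean mask of shape (2, 100)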