├── .gitignore ├── .pylintrc ├── LICENSE ├── MANIFEST.in ├── README.md ├── examples ├── app │ ├── .gitignore │ ├── README.md │ ├── api.py │ ├── requirements.txt │ └── run.sh ├── imagenet │ ├── README.md │ ├── classy_train.py │ ├── configs │ │ ├── resnet18-nbdt.json │ │ ├── resnet18.json │ │ └── resnet50.json │ └── losses │ │ ├── __init__.py │ │ └── nbdt_losses.py └── load_pretrained_nbdts.ipynb ├── main.py ├── nbdt ├── __init__.py ├── analysis.py ├── bin │ ├── nbdt │ ├── nbdt-hierarchy │ ├── nbdt-wnids │ └── original ├── data │ ├── __init__.py │ ├── ade20k.py │ ├── cifar.py │ ├── custom.py │ ├── imagenet.py │ ├── lip.py │ ├── pascal_context.py │ └── transforms.py ├── graph.py ├── hierarchies │ ├── ADE20K │ │ └── graph-induced-HRNet-w48.json │ ├── CIFAR10 │ │ ├── graph-induced-ResNet10.json │ │ ├── graph-induced-ResNet18.json │ │ ├── graph-induced-wrn28_10_cifar10.json │ │ ├── graph-induced.json │ │ └── graph-wordnet.json │ ├── CIFAR100 │ │ ├── graph-induced-ResNet10.json │ │ ├── graph-induced-ResNet18.json │ │ ├── graph-induced-wrn28_10_cifar100.json │ │ ├── graph-induced.json │ │ ├── graph-wordnet-single.json │ │ └── graph-wordnet.json │ ├── Cityscapes │ │ ├── graph-induced-HRNet-w18-v1.json │ │ └── graph-induced-HRNet-w48.json │ ├── Imagenet1000 │ │ ├── graph-induced-efficientnet_b7b.json │ │ └── graph-induced.json │ ├── LookIntoPerson │ │ └── graph-induced-HRNet-w48-cls20.json │ ├── PascalContext │ │ └── graph-induced-HRNet-w48-cls59.json │ └── TinyImagenet200 │ │ ├── graph-induced-ResNet18.json │ │ ├── graph-induced-wrn28_10.json │ │ ├── graph-induced.json │ │ ├── graph-wordnet-single.json │ │ └── graph-wordnet.json ├── hierarchy.py ├── loss.py ├── metrics.py ├── model.py ├── models │ ├── __init__.py │ ├── resnet.py │ ├── utils.py │ └── wideresnet.py ├── templates │ └── tree-template.html ├── thirdparty │ ├── nx.py │ └── wn.py ├── tree.py ├── utils.py └── wnids │ ├── ADE20K.txt │ ├── CIFAR10.txt │ ├── CIFAR100.txt │ ├── Imagenet1000.txt │ ├── LookIntoPerson.txt │ ├── PascalContext.txt │ └── TinyImagenet200.txt ├── pytest.ini ├── requirements.txt ├── scripts ├── gen_train_eval_nopretrained.ps1 ├── gen_train_eval_nopretrained.sh ├── gen_train_eval_pretrained.ps1 ├── gen_train_eval_pretrained.sh ├── gen_train_eval_resnet.ps1 ├── gen_train_eval_resnet.sh ├── gen_train_eval_wideresnet.ps1 ├── gen_train_eval_wideresnet.sh ├── generate_hierarchies_wordnet.ps1 └── generate_hierarchies_wordnet.sh ├── setup.py └── tests ├── __init__.py ├── conftest.py ├── test_inference.py └── test_train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | *.xml 3 | /data 4 | checkpoint 5 | out 6 | ./*.html 7 | 8 | # Created by https://www.gitignore.io/api/macos,python 9 | # Edit at https://www.gitignore.io/?templates=macos,python 10 | 11 | ### macOS ### 12 | # General 13 | .DS_Store 14 | .AppleDouble 15 | .LSOverride 16 | 17 | # Icon must end with two \r 18 | Icon 19 | 20 | # Thumbnails 21 | ._* 22 | 23 | # Files that might appear in the root of a volume 24 | .DocumentRevisions-V100 25 | .fseventsd 26 | .Spotlight-V100 27 | .TemporaryItems 28 | .Trashes 29 | .VolumeIcon.icns 30 | .com.apple.timemachine.donotpresent 31 | 32 | # Directories potentially created on remote AFP share 33 | .AppleDB 34 | .AppleDesktop 35 | Network Trash Folder 36 | Temporary Items 37 | .apdisk 38 | 39 | ### Python ### 40 | # Byte-compiled / optimized / DLL files 41 | __pycache__/ 42 | *.py[cod] 43 | *$py.class 44 | *.ipynb_checkpoints/ 45 | 46 | # C extensions 47 | *.so 48 | 
49 | # Distribution / packaging 50 | .Python 51 | build/ 52 | develop-eggs/ 53 | dist/ 54 | downloads/ 55 | eggs/ 56 | .eggs/ 57 | lib/ 58 | lib64/ 59 | parts/ 60 | sdist/ 61 | var/ 62 | wheels/ 63 | pip-wheel-metadata/ 64 | share/python-wheels/ 65 | *.egg-info/ 66 | .installed.cfg 67 | *.egg 68 | MANIFEST 69 | 70 | # PyInstaller 71 | # Usually these files are written by a python script from a template 72 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 73 | *.manifest 74 | *.spec 75 | 76 | # Installer logs 77 | pip-log.txt 78 | pip-delete-this-directory.txt 79 | 80 | # Unit test / coverage reports 81 | htmlcov/ 82 | .tox/ 83 | .nox/ 84 | .coverage 85 | .coverage.* 86 | .cache 87 | nosetests.xml 88 | coverage.xml 89 | *.cover 90 | .hypothesis/ 91 | .pytest_cache/ 92 | 93 | # Translations 94 | *.mo 95 | *.pot 96 | 97 | # Scrapy stuff: 98 | .scrapy 99 | 100 | # Sphinx documentation 101 | docs/_build/ 102 | 103 | # PyBuilder 104 | target/ 105 | 106 | # pyenv 107 | .python-version 108 | 109 | # pipenv 110 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 111 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 112 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 113 | # install all needed dependencies. 114 | #Pipfile.lock 115 | 116 | # celery beat schedule file 117 | celerybeat-schedule 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # Mr Developer 130 | .mr.developer.cfg 131 | .project 132 | .pydevproject 133 | 134 | # mkdocs documentation 135 | /site 136 | 137 | # mypy 138 | .mypy_cache/ 139 | .dmypy.json 140 | dmypy.json 141 | 142 | # Pyre type checker 143 | .pyre/ 144 | 145 | # End of https://www.gitignore.io/api/macos,python 146 | 147 | # Elastic Beanstalk Files 148 | .elasticbeanstalk/* 149 | !.elasticbeanstalk/*.cfg.yml 150 | !.elasticbeanstalk/*.global.yml 151 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [TYPECHECK] 2 | 3 | # List of members which are set dynamically and missed by Pylint inference 4 | # system, and so shouldn't trigger E1101 when accessed. https://stackoverflow.com/a/53572939 5 | generated-members=numpy.*, torch.* 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Alvin Wan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include nbdt/templates/*.html 2 | include nbdt/wnids/*.txt 3 | include nbdt/hierarchies/*/*.json 4 | include requirements.txt 5 | include README.md 6 | include LICENSE 7 | -------------------------------------------------------------------------------- /examples/app/.gitignore: -------------------------------------------------------------------------------- 1 | *.lock 2 | 3 | # Elastic Beanstalk Files 4 | .elasticbeanstalk/* 5 | !.elasticbeanstalk/*.cfg.yml 6 | !.elasticbeanstalk/*.global.yml 7 | -------------------------------------------------------------------------------- /examples/app/README.md: -------------------------------------------------------------------------------- 1 | A super simple Flask app for serving NBDT predictions. 2 | 3 | Deployed via DigitalOcean with uWSGI and nginx: 4 | 5 | https://www.digitalocean.com/community/tutorials/how-to-serve-flask-applications-with-uswgi-and-nginx-on-ubuntu-18-04 6 | 7 | 8 | Also run `sudo ufw enable` and `sudo ufw allow 'OpenSSH'`, or else you'll get locked out. 9 | -------------------------------------------------------------------------------- /examples/app/api.py: -------------------------------------------------------------------------------- 1 | """Single-file example for serving an NBDT model. 2 | 3 | This functions as a simple single-endpoint API, using Flask. 
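Example query, once the server is running (a minimal sketch: assumes the API is reachable at http://localhost:5000, e.g. started via run.sh, and that the `requests` package is installed; it is not listed in this example's requirements.txt):

    import requests

    # Classify an image by URL (form field "url"); the URL below is only a placeholder.
    resp = requests.post("http://localhost:5000/", data={"url": "https://example.com/cat.jpg"})
    print(resp.json())  # {"success": ..., "prediction": ..., "decisions": [...]}

    # Or upload a local image file instead (form field "file").
    with open("cat.jpg", "rb") as f:
        print(requests.post("http://localhost:5000/", files={"file": f}).json())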
4 | """ 5 | 6 | 7 | from flask import Flask, flash, request, redirect, url_for, jsonify 8 | from flask_cors import CORS 9 | from nbdt.model import HardNBDT 10 | from nbdt.models import wrn28_10_cifar10 11 | from torchvision import transforms 12 | from nbdt.utils import DATASET_TO_CLASSES, load_image_from_path 13 | from nbdt.thirdparty.wn import maybe_install_wordnet 14 | from werkzeug.utils import secure_filename 15 | from PIL import Image 16 | import os 17 | 18 | maybe_install_wordnet() 19 | app = Flask(__name__) 20 | app.config['SECRET_KEY'] = os.urandom(24) 21 | 22 | CORS(app) 23 | 24 | 25 | ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'} 26 | 27 | 28 | def inference(im): 29 | # load pretrained NBDT 30 | model = wrn28_10_cifar10() 31 | model = HardNBDT( 32 | pretrained=True, 33 | dataset='CIFAR10', 34 | arch='wrn28_10_cifar10', 35 | model=model) 36 | 37 | # load + transform image 38 | transform = transforms.Compose([ 39 | transforms.Resize(32), 40 | transforms.CenterCrop(32), 41 | transforms.ToTensor(), 42 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 43 | ]) 44 | x = transform(im)[None] 45 | 46 | # run inference 47 | outputs, decisions = model.forward_with_decisions(x) # use `model(x)` to obtain just logits 48 | _, predicted = outputs.max(1) 49 | return { 50 | 'success': True, 51 | 'prediction': DATASET_TO_CLASSES['CIFAR10'][predicted[0]], 52 | 'decisions': [{ 53 | 'name': info['name'], 54 | 'prob': info['prob'] 55 | } for info in decisions[0]] 56 | } 57 | 58 | 59 | def allowed_file(filename): 60 | return '.' in filename and \ 61 | filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS 62 | 63 | 64 | image_urls = { 65 | 'cat': 'https://images.pexels.com/photos/126407/pexels-photo-126407.jpeg?auto=compress&cs=tinysrgb&h=32', 66 | 'bear': 'https://images.pexels.com/photos/158109/kodiak-brown-bear-adult-portrait-wildlife-158109.jpeg?auto=compress&cs=tinysrgb&h=32', 67 | 'dog': 'https://images.pexels.com/photos/1490908/pexels-photo-1490908.jpeg?auto=compress&cs=tinysrgb&h=32', 68 | } 69 | 70 | 71 | @app.route('/', methods=['GET', 'POST']) 72 | def upload_file(): 73 | """ 74 | To use this endpoint. You may use ANY of the following: 75 | 76 | 1. POST a URL with name "url", or 77 | 2. call this page with a query param "url", or 78 | 3. POST a file with name "file" to this URL 79 | 80 | Note that the ordering above is the order of priority. If a URL is posted, 81 | the uploaded file and the query param will be ignored. 82 | """ 83 | if request.method == 'POST' or request.args.get('url', None): 84 | url = request.form.get('url', request.args.get('url', None)) 85 | if url: 86 | try: 87 | im = load_image_from_path(url) 88 | except OSError as e: 89 | return jsonify({ 90 | "message": "Make sure you're passing in a path to an " 91 | "image and not to a webpage. Here was the " 92 | f" exact error: {e}", 93 | "success": False}) 94 | return jsonify(inference(im)) 95 | # check if the post request has the file part 96 | if 'file' not in request.files: 97 | return jsonify({ 98 | 'success': False, 99 | 'message': 'No file part' 100 | }) 101 | file = request.files['file'] 102 | # if user does not select file, browser also 103 | # submit an empty part without filename 104 | if file.filename == '': 105 | return jsonify({ 106 | 'sucess': False, 107 | 'message': 'No selected file' 108 | }) 109 | if file and allowed_file(file.filename): 110 | im = Image.open(file.stream) 111 | return jsonify(inference(im)) 112 | return jsonify({ 113 | 'success': False, 114 | 'message': f'nope. 
Allowed file? ({file.filename}) Got a file? ({bool(file)})' 115 | }) 116 | return jsonify({ 117 | 'success': False, 118 | 'message': 'You might be looking for the main page. Please see nbdt.alvinwan.com/demo. Here are some sample URLs you can use:', 119 | 'image_urls': image_urls 120 | }) 121 | 122 | 123 | if __name__ == '__main__': 124 | app.run() 125 | -------------------------------------------------------------------------------- /examples/app/requirements.txt: -------------------------------------------------------------------------------- 1 | nbdt==0.0.3 2 | flask==1.1.1 3 | uwsgi==2.0.18 4 | flask-cors==3.0.9 5 | -------------------------------------------------------------------------------- /examples/app/run.sh: -------------------------------------------------------------------------------- 1 | uwsgi --socket 0.0.0.0:5000 --protocol=http -w api:app 2 | -------------------------------------------------------------------------------- /examples/imagenet/README.md: -------------------------------------------------------------------------------- 1 | # Neural-Backed Decision Trees on ImageNet 2 | 3 | This example adds just a loss hook to the Classy Vision workflow; `classy_train.py` is 100% boilerplate. To launch a run with 8 GPUs on one node, use: 4 | 5 | ```bash 6 | python -m torch.distributed.launch \ 7 | --nproc_per_node=${NUM_CUDA_DEVICES:-8} \ 8 | --use_env \ 9 | classy_train.py \ 10 | --config=${CONFIG:-configs/resnet18-nbdt.json} \ 11 | --distributed_backend ddp 12 | ``` -------------------------------------------------------------------------------- /examples/imagenet/classy_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | This is the main script used for training Classy Vision jobs. 9 | 10 | This can be used for training on your local machine, using CPU or GPU, and 11 | for distributed training. This script also supports Tensorboard, Visdom and 12 | checkpointing. 13 | 14 | Example: 15 | For training locally, simply specify a configuration file and whether 16 | to use CPU or GPU: 17 | 18 | $ ./classy_train.py --device gpu --config configs/my_config.json 19 | 20 | For distributed training, this can be invoked via 21 | :func:`torch.distributed.launch`. 
For instance 22 | 23 | $ python -m torch.distributed.launch \ 24 | --nnodes=1 \ 25 | --nproc_per_node=1 \ 26 | --master_addr=localhost \ 27 | --master_port=29500 \ 28 | --use_env \ 29 | classy_train.py \ 30 | --device=gpu \ 31 | --config=configs/resnet50_synthetic_image_classy_config.json \ 32 | --num_workers=1 \ 33 | --log_freq=100 34 | 35 | For other use cases, try 36 | 37 | $ ./classy_train.py --help 38 | """ 39 | 40 | import logging 41 | import os 42 | from datetime import datetime 43 | from pathlib import Path 44 | 45 | import torch 46 | from classy_vision.generic.distributed_util import get_rank, get_world_size 47 | from classy_vision.generic.opts import check_generic_args, parse_train_arguments 48 | from classy_vision.generic.registry_utils import import_all_packages_from_directory 49 | from classy_vision.generic.util import load_checkpoint, load_json 50 | from classy_vision.hooks import ( 51 | CheckpointHook, 52 | LossLrMeterLoggingHook, 53 | ModelComplexityHook, 54 | ProfilerHook, 55 | ProgressBarHook, 56 | TensorboardPlotHook, 57 | VisdomHook, 58 | ) 59 | from classy_vision.tasks import FineTuningTask, build_task 60 | from classy_vision.trainer import DistributedTrainer, LocalTrainer 61 | from torchvision import set_image_backend, set_video_backend 62 | 63 | 64 | hydra_available = False 65 | 66 | 67 | def main(args, config): 68 | # Global flags 69 | torch.manual_seed(0) 70 | set_image_backend(args.image_backend) 71 | set_video_backend(args.video_backend) 72 | 73 | task = build_task(config) 74 | 75 | # Load checkpoint, if available. 76 | if args.checkpoint_load_path: 77 | task.set_checkpoint(args.checkpoint_load_path) 78 | 79 | # Load a checkpoint containing a pre-trained model. This is how we 80 | # implement fine-tuning of existing models. 81 | if args.pretrained_checkpoint_path: 82 | assert isinstance( 83 | task, FineTuningTask 84 | ), "Can only use a pretrained checkpoint for fine tuning tasks" 85 | task.set_pretrained_checkpoint(args.pretrained_checkpoint_path) 86 | 87 | # Configure hooks to do tensorboard logging, checkpoints and so on. 88 | # `configure_hooks` adds default hooks, while extra hooks can be specified 89 | # in config file and stored in `task.hooks`. Here, we merge them when we 90 | # set the final hooks of the task. 91 | task.set_hooks(configure_hooks(args, config) + task.hooks) 92 | 93 | # LocalTrainer is used for a single replica. DistributedTrainer will set up 94 | # training to use PyTorch's DistributedDataParallel. 95 | trainer_class = {"none": LocalTrainer, "ddp": DistributedTrainer}[ 96 | args.distributed_backend 97 | ] 98 | 99 | trainer = trainer_class() 100 | 101 | logging.info( 102 | f"Starting training on rank {get_rank()} worker. " 103 | f"World size is {get_world_size()}" 104 | ) 105 | # That's it! When this call returns, training is done. 
106 | trainer.train(task) 107 | 108 | output_folder = Path(args.checkpoint_folder).resolve() 109 | logging.info("Training successful!") 110 | logging.info(f'Results of this training run are available at: "{output_folder}"') 111 | 112 | 113 | def configure_hooks(args, config): 114 | hooks = [LossLrMeterLoggingHook(args.log_freq), ModelComplexityHook()] 115 | 116 | # Make a folder to store checkpoints and tensorboard logging outputs 117 | suffix = datetime.now().isoformat() 118 | base_folder = f"{Path(__file__).parent}/output_{suffix}" 119 | if args.checkpoint_folder == "": 120 | args.checkpoint_folder = base_folder + "/checkpoints" 121 | os.makedirs(args.checkpoint_folder, exist_ok=True) 122 | 123 | logging.info(f"Logging outputs to {base_folder}") 124 | logging.info(f"Logging checkpoints to {args.checkpoint_folder}") 125 | 126 | if not args.skip_tensorboard: 127 | try: 128 | from torch.utils.tensorboard import SummaryWriter 129 | 130 | tb_writer = SummaryWriter(log_dir=Path(base_folder) / "tensorboard") 131 | hooks.append(TensorboardPlotHook(tb_writer)) 132 | except ImportError: 133 | logging.warning("tensorboard not installed, skipping tensorboard hooks") 134 | 135 | args_dict = vars(args) 136 | args_dict["config"] = config 137 | hooks.append( 138 | CheckpointHook( 139 | args.checkpoint_folder, args_dict, checkpoint_period=args.checkpoint_period 140 | ) 141 | ) 142 | 143 | if args.profiler: 144 | hooks.append(ProfilerHook()) 145 | if args.show_progress: 146 | hooks.append(ProgressBarHook()) 147 | if args.visdom_server != "": 148 | hooks.append(VisdomHook(args.visdom_server, args.visdom_port)) 149 | 150 | return hooks 151 | 152 | 153 | if hydra_available: 154 | 155 | @hydra.main(config_path="hydra_configs", config_name="args") 156 | def hydra_main(cfg): 157 | args = cfg 158 | check_generic_args(cfg) 159 | config = omegaconf.OmegaConf.to_container(cfg.config) 160 | main(args, config) 161 | 162 | 163 | # run all the things: 164 | if __name__ == "__main__": 165 | logger = logging.getLogger() 166 | logger.setLevel(logging.INFO) 167 | 168 | logging.info("Classy Vision's default training script.") 169 | 170 | # This imports all modules in the same directory as classy_train.py 171 | # Because of the way Classy Vision's registration decorators work, 172 | # importing a module has a side effect of registering it with Classy 173 | # Vision. This means you can give classy_train.py a config referencing your 174 | # custom module (e.g. my_dataset) and it'll actually know how to 175 | # instantiate it. 
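    # For example, losses/nbdt_losses.py in this directory registers itself with
    # `@register_loss("NBDTTreeSupLoss")`, which is why configs/resnet18-nbdt.json
    # can simply reference {"loss": {"name": "NBDTTreeSupLoss"}}.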
176 | file_root = Path(__file__).parent 177 | import_all_packages_from_directory(file_root) 178 | 179 | if hydra_available: 180 | hydra_main() 181 | else: 182 | args = parse_train_arguments() 183 | config = load_json(args.config_file) 184 | main(args, config) 185 | -------------------------------------------------------------------------------- /examples/imagenet/configs/resnet18-nbdt.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "classification_task", 3 | "num_epochs": 90, 4 | "loss": { 5 | "name": "NBDTTreeSupLoss" 6 | }, 7 | "dataset": { 8 | "train": { 9 | "name": "image_path", 10 | "batchsize_per_replica": 32, 11 | "num_samples": null, 12 | "use_shuffle": true, 13 | "transforms": [{ 14 | "name": "apply_transform_to_key", 15 | "transforms": [ 16 | {"name": "RandomResizedCrop", "size": 224}, 17 | {"name": "RandomHorizontalFlip"}, 18 | {"name": "ToTensor"}, 19 | { 20 | "name": "Normalize", 21 | "mean": [0.485, 0.456, 0.406], 22 | "std": [0.229, 0.224, 0.225] 23 | } 24 | ], 25 | "key": "input" 26 | }], 27 | "image_folder": "/data/imagenetwhole/ilsvrc2012/train" 28 | }, 29 | "test": { 30 | "name": "image_path", 31 | "batchsize_per_replica": 32, 32 | "num_samples": null, 33 | "use_shuffle": false, 34 | "transforms": [{ 35 | "name": "apply_transform_to_key", 36 | "transforms": [ 37 | {"name": "Resize", "size": 256}, 38 | {"name": "CenterCrop", "size": 224}, 39 | {"name": "ToTensor"}, 40 | { 41 | "name": "Normalize", 42 | "mean": [0.485, 0.456, 0.406], 43 | "std": [0.229, 0.224, 0.225] 44 | } 45 | ], 46 | "key": "input" 47 | }], 48 | "image_folder": "/data/imagenetwhole/ilsvrc2012/val" 49 | } 50 | }, 51 | "meters": { 52 | "accuracy": { 53 | "topk": [1, 5] 54 | } 55 | }, 56 | "model": { 57 | "name": "resnet", 58 | "num_blocks": [2, 2, 2, 2], 59 | "small_input": false, 60 | "zero_init_bn_residuals": true, 61 | "heads": [ 62 | { 63 | "name": "fully_connected", 64 | "unique_id": "default_head", 65 | "num_classes": 1000, 66 | "fork_block": "block3-1", 67 | "in_plane": 2048 68 | } 69 | ] 70 | }, 71 | "optimizer": { 72 | "name": "sgd", 73 | "param_schedulers": { 74 | "lr": { 75 | "name": "composite", 76 | "schedulers": [ 77 | {"name": "linear", "start_value": 0.1, "end_value": 0.4}, 78 | {"name": "multistep", "values": [0.4, 0.04, 0.004, 0.0004], "milestones": [30, 60, 80]} 79 | ], 80 | "update_interval": "epoch", 81 | "interval_scaling": ["rescaled", "fixed"], 82 | "lengths": [0.0555, 0.9445] 83 | } 84 | }, 85 | "weight_decay": 1e-4, 86 | "momentum": 0.9 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /examples/imagenet/configs/resnet18.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "classification_task", 3 | "num_epochs": 90, 4 | "loss": { 5 | "name": "CrossEntropyLoss" 6 | }, 7 | "dataset": { 8 | "train": { 9 | "name": "image_path", 10 | "batchsize_per_replica": 32, 11 | "num_samples": null, 12 | "use_shuffle": true, 13 | "transforms": [{ 14 | "name": "apply_transform_to_key", 15 | "transforms": [ 16 | {"name": "RandomResizedCrop", "size": 224}, 17 | {"name": "RandomHorizontalFlip"}, 18 | {"name": "ToTensor"}, 19 | { 20 | "name": "Normalize", 21 | "mean": [0.485, 0.456, 0.406], 22 | "std": [0.229, 0.224, 0.225] 23 | } 24 | ], 25 | "key": "input" 26 | }], 27 | "image_folder": "/data/imagenetwhole/ilsvrc2012/train" 28 | }, 29 | "test": { 30 | "name": "image_path", 31 | "batchsize_per_replica": 32, 32 | "num_samples": null, 33 | 
"use_shuffle": false, 34 | "transforms": [{ 35 | "name": "apply_transform_to_key", 36 | "transforms": [ 37 | {"name": "Resize", "size": 256}, 38 | {"name": "CenterCrop", "size": 224}, 39 | {"name": "ToTensor"}, 40 | { 41 | "name": "Normalize", 42 | "mean": [0.485, 0.456, 0.406], 43 | "std": [0.229, 0.224, 0.225] 44 | } 45 | ], 46 | "key": "input" 47 | }], 48 | "image_folder": "/data/imagenetwhole/ilsvrc2012/val" 49 | } 50 | }, 51 | "meters": { 52 | "accuracy": { 53 | "topk": [1, 5] 54 | } 55 | }, 56 | "model": { 57 | "name": "resnet", 58 | "num_blocks": [2, 2, 2, 2], 59 | "small_input": false, 60 | "zero_init_bn_residuals": true, 61 | "heads": [ 62 | { 63 | "name": "fully_connected", 64 | "unique_id": "default_head", 65 | "num_classes": 1000, 66 | "fork_block": "block3-1", 67 | "in_plane": 2048 68 | } 69 | ] 70 | }, 71 | "optimizer": { 72 | "name": "sgd", 73 | "param_schedulers": { 74 | "lr": { 75 | "name": "composite", 76 | "schedulers": [ 77 | {"name": "linear", "start_value": 0.1, "end_value": 0.4}, 78 | {"name": "multistep", "values": [0.4, 0.04, 0.004, 0.0004], "milestones": [30, 60, 80]} 79 | ], 80 | "update_interval": "epoch", 81 | "interval_scaling": ["rescaled", "fixed"], 82 | "lengths": [0.0555, 0.9445] 83 | } 84 | }, 85 | "weight_decay": 1e-4, 86 | "momentum": 0.9 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /examples/imagenet/configs/resnet50.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "classification_task", 3 | "num_epochs": 90, 4 | "loss": { 5 | "name": "CrossEntropyLoss" 6 | }, 7 | "dataset": { 8 | "train": { 9 | "name": "image_path", 10 | "batchsize_per_replica": 32, 11 | "num_samples": null, 12 | "use_shuffle": true, 13 | "transforms": [{ 14 | "name": "apply_transform_to_key", 15 | "transforms": [ 16 | {"name": "RandomResizedCrop", "size": 224}, 17 | {"name": "RandomHorizontalFlip"}, 18 | {"name": "ToTensor"}, 19 | { 20 | "name": "Normalize", 21 | "mean": [0.485, 0.456, 0.406], 22 | "std": [0.229, 0.224, 0.225] 23 | } 24 | ], 25 | "key": "input" 26 | }], 27 | "image_folder": "/data/imagenetwhole/ilsvrc2012/train" 28 | }, 29 | "test": { 30 | "name": "image_path", 31 | "batchsize_per_replica": 32, 32 | "num_samples": null, 33 | "use_shuffle": false, 34 | "transforms": [{ 35 | "name": "apply_transform_to_key", 36 | "transforms": [ 37 | {"name": "Resize", "size": 256}, 38 | {"name": "CenterCrop", "size": 224}, 39 | {"name": "ToTensor"}, 40 | { 41 | "name": "Normalize", 42 | "mean": [0.485, 0.456, 0.406], 43 | "std": [0.229, 0.224, 0.225] 44 | } 45 | ], 46 | "key": "input" 47 | }], 48 | "image_folder": "/data/imagenetwhole/ilsvrc2012/val" 49 | } 50 | }, 51 | "meters": { 52 | "accuracy": { 53 | "topk": [1, 5] 54 | } 55 | }, 56 | "model": { 57 | "name": "resnet", 58 | "num_blocks": [3, 4, 6, 3], 59 | "small_input": false, 60 | "zero_init_bn_residuals": true, 61 | "heads": [ 62 | { 63 | "name": "fully_connected", 64 | "unique_id": "default_head", 65 | "num_classes": 1000, 66 | "fork_block": "block3-2", 67 | "in_plane": 2048 68 | } 69 | ] 70 | }, 71 | "optimizer": { 72 | "name": "sgd", 73 | "param_schedulers": { 74 | "lr": { 75 | "name": "composite", 76 | "schedulers": [ 77 | {"name": "linear", "start_value": 0.1, "end_value": 0.4}, 78 | {"name": "multistep", "values": [0.4, 0.04, 0.004, 0.0004], "milestones": [30, 60, 80]} 79 | ], 80 | "update_interval": "epoch", 81 | "interval_scaling": ["rescaled", "fixed"], 82 | "lengths": [0.0555, 0.9445] 83 | } 84 | }, 85 | 
"weight_decay": 1e-4, 86 | "momentum": 0.9 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /examples/imagenet/losses/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from pathlib import Path 8 | 9 | from classy_vision.generic.registry_utils import import_all_modules 10 | 11 | 12 | FILE_ROOT = Path(__file__).parent 13 | 14 | # Automatically import any Python files in the losses/ directory 15 | import_all_modules(FILE_ROOT, "losses") 16 | -------------------------------------------------------------------------------- /examples/imagenet/losses/nbdt_losses.py: -------------------------------------------------------------------------------- 1 | from classy_vision.losses import ClassyLoss, register_loss 2 | import torch.nn as nn 3 | from nbdt.loss import SoftTreeSupLoss 4 | 5 | 6 | @register_loss("NBDTTreeSupLoss") 7 | class NBDTTreeSupLoss(SoftTreeSupLoss, ClassyLoss): 8 | def __init__(self, tree_supervision_weight=5): 9 | super().__init__( 10 | criterion=nn.CrossEntropyLoss().cuda(), 11 | dataset='Imagenet1000', 12 | tree_supervision_weight=tree_supervision_weight, 13 | hierarchy='induced-efficientnet_b7b' 14 | ) 15 | 16 | @classmethod 17 | def from_config(cls, config): 18 | # We don't need anything from the config 19 | return cls( 20 | tree_supervision_weight=config.get("tree_supervision_weight", 5) 21 | ) 22 | -------------------------------------------------------------------------------- /examples/load_pretrained_nbdts.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%pip install nbdt" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "from nbdt.model import SoftNBDT\n", 19 | "from nbdt.models import ResNet18, wrn28_10_cifar10, wrn28_10_cifar100, wrn28_10 # use wrn28_10 for TinyImagenet200\n", 20 | "from torchvision import transforms\n", 21 | "from nbdt.utils import DATASET_TO_CLASSES, load_image_from_path, maybe_install_wordnet\n", 22 | "from IPython.display import display" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "model = wrn28_10_cifar10()\n", 32 | "model = SoftNBDT(\n", 33 | " pretrained=True,\n", 34 | " dataset='CIFAR10',\n", 35 | " arch='wrn28_10_cifar10',\n", 36 | " model=model)" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "image_urls = {\n", 46 | " 'cat': 'https://images.pexels.com/photos/126407/pexels-photo-126407.jpeg?auto=compress&cs=tinysrgb&dpr=2&w=300',\n", 47 | " 'bear': 'https://images.pexels.com/photos/158109/kodiak-brown-bear-adult-portrait-wildlife-158109.jpeg?auto=compress&cs=tinysrgb&dpr=2&w=300',\n", 48 | " 'dog': 'https://images.pexels.com/photos/1490908/pexels-photo-1490908.jpeg?auto=compress&cs=tinysrgb&dpr=2&w=300'\n", 49 | "}" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "# show image\n", 
59 | "im = load_image_from_path(image_urls['cat'])\n", 60 | "display(im)" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "# load + transform image\n", 70 | "transforms = transforms.Compose([\n", 71 | " transforms.Resize(32),\n", 72 | " transforms.CenterCrop(32),\n", 73 | " transforms.ToTensor(),\n", 74 | " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n", 75 | "])\n", 76 | "x = transforms(im)[None]" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "# run inference\n", 86 | "outputs = model(x) # to get intermediate decisions, use `model.forward_with_decisions(x)` and add `hierarchy='wordnet' to SoftNBDT\n", 87 | "_, predicted = outputs.max(1)\n", 88 | "cls = DATASET_TO_CLASSES['CIFAR10'][predicted[0]]\n", 89 | "print(cls)" 90 | ] 91 | } 92 | ], 93 | "metadata": { 94 | "kernelspec": { 95 | "display_name": "pytorch-1.2", 96 | "language": "python", 97 | "name": "pytorch-1.2" 98 | }, 99 | "language_info": { 100 | "codemirror_mode": { 101 | "name": "ipython", 102 | "version": 3 103 | }, 104 | "file_extension": ".py", 105 | "mimetype": "text/x-python", 106 | "name": "python", 107 | "nbconvert_exporter": "python", 108 | "pygments_lexer": "ipython3", 109 | "version": "3.7.4" 110 | } 111 | }, 112 | "nbformat": 4, 113 | "nbformat_minor": 4 114 | } 115 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Neural-Backed Decision Trees training on CIFAR10, CIFAR100, TinyImagenet200 3 | 4 | The original version of this `main.py` was taken from kuangliu/pytorch-cifar. 
5 | The script has since been heavily modified to support a number of different 6 | configurations and options: alvinwan/neural-backed-decision-trees 7 | """ 8 | import os 9 | import argparse 10 | import numpy as np 11 | import torch 12 | from torch import nn, optim 13 | import torch.nn.functional as F 14 | import torch.backends.cudnn as cudnn 15 | import torchvision 16 | import torchvision.transforms as transforms 17 | 18 | from nbdt import data, analysis, loss, models, metrics, tree as T 19 | from nbdt.utils import progress_bar, generate_checkpoint_fname, generate_kwargs, Colors 20 | from nbdt.thirdparty.wn import maybe_install_wordnet 21 | from nbdt.models.utils import load_state_dict, make_kwarg_optional 22 | from nbdt.tree import Tree 23 | 24 | 25 | def main(): 26 | maybe_install_wordnet() 27 | datasets = data.cifar.names + data.imagenet.names + data.custom.names 28 | parser = argparse.ArgumentParser(description="PyTorch CIFAR Training") 29 | parser.add_argument( 30 | "--batch-size", default=512, type=int, help="Batch size used for training" 31 | ) 32 | parser.add_argument( 33 | "--epochs", 34 | "-e", 35 | default=200, 36 | type=int, 37 | help="By default, lr schedule is scaled accordingly", 38 | ) 39 | parser.add_argument("--dataset", default="CIFAR10", choices=datasets) 40 | parser.add_argument( 41 | "--arch", default="ResNet18", choices=list(models.get_model_choices()) 42 | ) 43 | parser.add_argument("--lr", default=0.1, type=float, help="learning rate") 44 | parser.add_argument( 45 | "--resume", "-r", action="store_true", help="resume from checkpoint" 46 | ) 47 | 48 | # extra general options for main script 49 | parser.add_argument( 50 | "--path-resume", default="", help="Overrides checkpoint path generation" 51 | ) 52 | parser.add_argument( 53 | "--name", default="", help="Name of experiment. Used for checkpoint filename" 54 | ) 55 | parser.add_argument( 56 | "--pretrained", 57 | action="store_true", 58 | help="Download pretrained model. Not all models support this.", 59 | ) 60 | parser.add_argument("--eval", help="eval only", action="store_true") 61 | parser.add_argument( 62 | "--dataset-test", 63 | choices=datasets, 64 | help="If not set, automatically set to train dataset", 65 | ) 66 | parser.add_argument( 67 | "--disable-test-eval", 68 | help="Allows you to run model inference on a test dataset " 69 | " different from train dataset. 
Use an analyzer to define " 70 | "a metric.", 71 | action="store_true", 72 | ) 73 | 74 | # options specific to this project and its dataloaders 75 | parser.add_argument( 76 | "--loss", choices=loss.names, default=["CrossEntropyLoss"], nargs="+" 77 | ) 78 | parser.add_argument("--metric", choices=metrics.names, default="top1") 79 | parser.add_argument( 80 | "--analysis", choices=analysis.names, help="Run analysis after each epoch" 81 | ) 82 | 83 | # other dataset, loss or analysis specific options 84 | data.custom.add_arguments(parser) 85 | T.add_arguments(parser) 86 | loss.add_arguments(parser) 87 | analysis.add_arguments(parser) 88 | 89 | args = parser.parse_args() 90 | loss.set_default_values(args) 91 | 92 | device = "cuda" if torch.cuda.is_available() else "cpu" 93 | best_acc = 0 # best test accuracy 94 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 95 | 96 | # Data 97 | print("==> Preparing data..") 98 | dataset_train = getattr(data, args.dataset) 99 | dataset_test = getattr(data, args.dataset_test or args.dataset) 100 | 101 | transform_train = dataset_train.transform_train() 102 | transform_test = dataset_test.transform_val() 103 | 104 | dataset_train_kwargs = generate_kwargs( 105 | args, 106 | dataset_train, 107 | name=f"Dataset {dataset_train.__class__.__name__}", 108 | globals=locals(), 109 | ) 110 | dataset_test_kwargs = generate_kwargs( 111 | args, 112 | dataset_test, 113 | name=f"Dataset {dataset_test.__class__.__name__}", 114 | globals=locals(), 115 | ) 116 | trainset = dataset_train( 117 | **dataset_train_kwargs, 118 | root="./data", 119 | train=True, 120 | download=True, 121 | transform=transform_train, 122 | ) 123 | testset = dataset_test( 124 | **dataset_test_kwargs, 125 | root="./data", 126 | train=False, 127 | download=True, 128 | transform=transform_test, 129 | ) 130 | 131 | assert trainset.classes == testset.classes or args.disable_test_eval, ( 132 | trainset.classes, 133 | testset.classes, 134 | ) 135 | 136 | trainloader = torch.utils.data.DataLoader( 137 | trainset, batch_size=args.batch_size, shuffle=True, num_workers=2 138 | ) 139 | testloader = torch.utils.data.DataLoader( 140 | testset, batch_size=100, shuffle=False, num_workers=2 141 | ) 142 | 143 | Colors.cyan(f"Training with dataset {args.dataset} and {len(trainset.classes)} classes") 144 | Colors.cyan( 145 | f"Testing with dataset {args.dataset_test or args.dataset} and {len(testset.classes)} classes" 146 | ) 147 | 148 | # Model 149 | print("==> Building model..") 150 | model = getattr(models, args.arch) 151 | 152 | if args.pretrained: 153 | print("==> Loading pretrained model..") 154 | model = make_kwarg_optional(model, dataset=args.dataset) 155 | net = model(pretrained=True, num_classes=len(trainset.classes)) 156 | else: 157 | net = model(num_classes=len(trainset.classes)) 158 | 159 | net = net.to(device) 160 | if device == "cuda": 161 | net = torch.nn.DataParallel(net) 162 | cudnn.benchmark = True 163 | 164 | checkpoint_fname = generate_checkpoint_fname(**vars(args)) 165 | checkpoint_path = "./checkpoint/{}.pth".format(checkpoint_fname) 166 | print(f"==> Checkpoints will be saved to: {checkpoint_path}") 167 | 168 | resume_path = args.path_resume or checkpoint_path 169 | if args.resume: 170 | # Load checkpoint. 171 | print("==> Resuming from checkpoint..") 172 | assert os.path.isdir("checkpoint"), "Error: no checkpoint directory found!" 173 | if not os.path.exists(resume_path): 174 | print("==> No checkpoint found. 
Skipping...") 175 | else: 176 | checkpoint = torch.load(resume_path, map_location=torch.device(device)) 177 | 178 | if "net" in checkpoint: 179 | load_state_dict(net, checkpoint["net"]) 180 | best_acc = checkpoint["acc"] 181 | start_epoch = checkpoint["epoch"] 182 | Colors.cyan( 183 | f"==> Checkpoint found for epoch {start_epoch} with accuracy " 184 | f"{best_acc} at {resume_path}" 185 | ) 186 | else: 187 | load_state_dict(net, checkpoint) 188 | Colors.cyan(f"==> Checkpoint found at {resume_path}") 189 | 190 | # hierarchy 191 | tree = Tree.create_from_args(args, classes=trainset.classes) 192 | 193 | # loss 194 | criterion = None 195 | for _loss in args.loss: 196 | if criterion is None and not hasattr(nn, _loss): 197 | criterion = nn.CrossEntropyLoss() 198 | class_criterion = getattr(loss, _loss) 199 | loss_kwargs = generate_kwargs( 200 | args, 201 | class_criterion, 202 | name=f"Loss {args.loss}", 203 | globals=locals(), 204 | ) 205 | criterion = class_criterion(**loss_kwargs) 206 | 207 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) 208 | scheduler = optim.lr_scheduler.MultiStepLR( 209 | optimizer, milestones=[int(3 / 7.0 * args.epochs), int(5 / 7.0 * args.epochs)] 210 | ) 211 | 212 | class_analysis = getattr(analysis, args.analysis or "Noop") 213 | analyzer_kwargs = generate_kwargs( 214 | args, 215 | class_analysis, 216 | name=f"Analyzer {args.analysis}", 217 | globals=locals(), 218 | ) 219 | analyzer = class_analysis(**analyzer_kwargs) 220 | 221 | metric = getattr(metrics, args.metric)() 222 | 223 | # Training 224 | @analyzer.train_function 225 | def train(epoch): 226 | if hasattr(criterion, "set_epoch"): 227 | criterion.set_epoch(epoch, args.epochs) 228 | 229 | print("\nEpoch: %d / LR: %.04f" % (epoch, scheduler.get_last_lr()[0])) 230 | net.train() 231 | train_loss = 0 232 | metric.clear() 233 | for batch_idx, (inputs, targets) in enumerate(trainloader): 234 | inputs, targets = inputs.to(device), targets.to(device) 235 | optimizer.zero_grad() 236 | outputs = net(inputs) 237 | loss = criterion(outputs, targets) 238 | loss.backward() 239 | optimizer.step() 240 | 241 | train_loss += loss.item() 242 | metric.forward(outputs, targets) 243 | transform = trainset.transform_val_inverse().to(device) 244 | stat = analyzer.update_batch(outputs, targets, transform(inputs)) 245 | 246 | progress_bar( 247 | batch_idx, 248 | len(trainloader), 249 | "Loss: %.3f | Acc: %.3f%% (%d/%d) %s" 250 | % ( 251 | train_loss / (batch_idx + 1), 252 | 100.0 * metric.report(), 253 | metric.correct, 254 | metric.total, 255 | f"| {analyzer.name}: {stat}" if stat else "", 256 | ), 257 | ) 258 | scheduler.step() 259 | 260 | 261 | @analyzer.test_function 262 | def test(epoch, checkpoint=True): 263 | nonlocal best_acc 264 | net.eval() 265 | test_loss = 0 266 | metric.clear() 267 | with torch.no_grad(): 268 | for batch_idx, (inputs, targets) in enumerate(testloader): 269 | inputs, targets = inputs.to(device), targets.to(device) 270 | outputs = net(inputs) 271 | 272 | if not args.disable_test_eval: 273 | loss = criterion(outputs, targets) 274 | test_loss += loss.item() 275 | metric.forward(outputs, targets) 276 | transform = testset.transform_val_inverse().to(device) 277 | stat = analyzer.update_batch(outputs, targets, transform(inputs)) 278 | 279 | progress_bar( 280 | batch_idx, 281 | len(testloader), 282 | "Loss: %.3f | Acc: %.3f%% (%d/%d) %s" 283 | % ( 284 | test_loss / (batch_idx + 1), 285 | 100.0 * metric.report(), 286 | metric.correct, 287 | metric.total, 288 | f"| 
{analyzer.name}: {stat}" if stat else "", 289 | ), 290 | ) 291 | 292 | # Save checkpoint. 293 | acc = 100.0 * metric.report() 294 | print( 295 | "Accuracy: {}, {}/{} | Best Accuracy: {}".format( 296 | acc, metric.correct, metric.total, best_acc 297 | ) 298 | ) 299 | if acc > best_acc and checkpoint: 300 | Colors.green(f"Saving to {checkpoint_fname} ({acc})..") 301 | state = { 302 | "net": net.state_dict(), 303 | "acc": acc, 304 | "epoch": epoch, 305 | } 306 | os.makedirs("checkpoint", exist_ok=True) 307 | torch.save(state, f"./checkpoint/{checkpoint_fname}.pth") 308 | best_acc = acc 309 | 310 | if args.disable_test_eval and (not args.analysis or args.analysis == "Noop"): 311 | Colors.red( 312 | " * Warning: `disable_test_eval` is used but no custom metric " 313 | "`--analysis` is supplied. I suggest supplying an analysis to perform " 314 | "custom loss and accuracy calculation." 315 | ) 316 | 317 | if args.eval: 318 | if not args.resume and not args.pretrained: 319 | Colors.red( 320 | " * Warning: Model is not loaded from checkpoint. " 321 | "Use --resume or --pretrained (if supported)" 322 | ) 323 | with analyzer.epoch_context(0): 324 | test(0, checkpoint=False) 325 | else: 326 | for epoch in range(start_epoch, args.epochs): 327 | with analyzer.epoch_context(epoch): 328 | train(epoch) 329 | test(epoch) 330 | 331 | print(f"Best accuracy: {best_acc} // Checkpoint name: {checkpoint_fname}") 332 | 333 | if __name__ == '__main__': 334 | main() -------------------------------------------------------------------------------- /nbdt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alvinwan/neural-backed-decision-trees/a7a2ee6f735bbc1b3d8c7c4f9ecdd02c6a75fc1e/nbdt/__init__.py -------------------------------------------------------------------------------- /nbdt/bin/nbdt: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Run evaluation on a single image, using an NBDT""" 3 | 4 | from nbdt.model import SoftNBDT, HardNBDT 5 | from pytorchcv.models.wrn_cifar import wrn28_10_cifar10 6 | from torchvision import transforms 7 | from nbdt.utils import DATASET_TO_CLASSES, load_image_from_path 8 | from nbdt.thirdparty.wn import maybe_install_wordnet 9 | import sys 10 | 11 | maybe_install_wordnet() 12 | 13 | assert len(sys.argv) > 1, "Need to pass image URL or image path as argument" 14 | 15 | # load pretrained NBDT 16 | model = wrn28_10_cifar10() 17 | model = SoftNBDT( 18 | pretrained=True, dataset="CIFAR10", arch="wrn28_10_cifar10", model=model 19 | ) 20 | 21 | # load + transform image 22 | im = load_image_from_path(sys.argv[1]) 23 | transform = transforms.Compose( 24 | [ 25 | transforms.Resize(32), 26 | transforms.CenterCrop(32), 27 | transforms.ToTensor(), 28 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 29 | ] 30 | ) 31 | x = transform(im)[None] 32 | 33 | # run inference 34 | outputs, decisions = model.forward_with_decisions( 35 | x 36 | ) # use `model(x)` to obtain just logits 37 | _, predicted = outputs.max(1) 38 | cls = DATASET_TO_CLASSES["CIFAR10"][predicted[0]] 39 | print( 40 | "Prediction:", 41 | cls, 42 | "// Decisions:", 43 | ", ".join( 44 | [ 45 | "{} (Confidence: {:.2f}%)".format(info["name"], (1 - info["entropy"]) * 100) 46 | for info in decisions[0] 47 | ][1:] 48 | ), 49 | ) # [1:] to skip the root 50 | -------------------------------------------------------------------------------- /nbdt/bin/nbdt-hierarchy: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from nbdt.hierarchy import generate_hierarchy, test_hierarchy, generate_hierarchy_vis 4 | from nbdt.graph import get_parser 5 | from nbdt.thirdparty.wn import maybe_install_wordnet 6 | 7 | 8 | def main(): 9 | 10 | parser = get_parser() 11 | args = parser.parse_args() 12 | 13 | if not args.path: 14 | maybe_install_wordnet() 15 | generate_hierarchy(**vars(args)) 16 | test_hierarchy(args) 17 | generate_hierarchy_vis(args) 18 | 19 | 20 | if __name__ == "__main__": 21 | main() 22 | -------------------------------------------------------------------------------- /nbdt/bin/nbdt-wnids: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Generates wnids using class names for torchvision dataset""" 3 | 4 | import argparse 5 | import torchvision 6 | from nbdt import data 7 | from nltk.corpus import wordnet as wn 8 | from pathlib import Path 9 | from nbdt.utils import Colors, generate_kwargs 10 | from nbdt.thirdparty.wn import ( 11 | maybe_install_wordnet, 12 | synset_to_wnid, 13 | write_wnids, 14 | FakeSynset, 15 | ) 16 | import os 17 | 18 | maybe_install_wordnet() 19 | 20 | datasets = ( 21 | ("CIFAR10", "CIFAR100", "Cityscapes") 22 | + data.imagenet.names 23 | + data.custom.names 24 | + data.pascal_context.names 25 | + data.lip.names 26 | + data.ade20k.names 27 | ) 28 | 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--dataset", choices=datasets, default="CIFAR10") 32 | parser.add_argument("--root", default="./nbdt/wnids") 33 | parser.add_argument( 34 | "--classes", 35 | type=str, 36 | nargs="*", 37 | help="INSTEAD of writing WNIDs for a dataset, convert JUST" 38 | " this class name to a WNID.", 39 | ) 40 | data.custom.add_arguments(parser) 41 | args = parser.parse_args() 42 | 43 | if args.classes: 44 | classes = args.classes 45 | else: 46 | dataset = getattr(data, args.dataset) 47 | dataset_kwargs = generate_kwargs( 48 | args, 49 | dataset, 50 | name=f"Dataset {args.dataset}", 51 | keys=data.custom.keys, 52 | globals=globals(), 53 | ) 54 | if args.dataset not in ["Cityscapes", "PascalContext", "LookIntoPerson", "ADE20K"]: 55 | dataset_kwargs["download"] = True 56 | 57 | dataset = dataset(**dataset_kwargs, root="./data") 58 | 59 | classes = dataset.classes 60 | if args.dataset == "Cityscapes": 61 | classes = [cls.name for cls in dataset.classes if not cls.ignore_in_eval] 62 | if args.dataset == "PascalContext": 63 | classes = [cls for cls in dataset.classes if cls != "background"] 64 | 65 | path = Path(os.path.join(args.root, f"{args.dataset}.txt")) 66 | os.makedirs(path.parent, exist_ok=True) 67 | failures = [] 68 | 69 | hardcoded_mapping = { 70 | "aquarium_fish": wn.synsets("fingerling", pos=wn.NOUN)[0], 71 | "arcade_machine": wn.synsets("slot_machine", pos=wn.NOUN)[0], 72 | "background": wn.synsets("background", pos=wn.NOUN)[1], 73 | "barrel": wn.synsets("barrel", pos=wn.NOUN)[1], 74 | "beaver": wn.synsets("beaver", pos=wn.NOUN)[-1], 75 | "booth": wn.synsets("booth", pos=wn.NOUN)[1], 76 | "blind": wn.synsets("blind", pos=wn.NOUN)[2], 77 | "bulletin_board": wn.synsets("bulletin_board", pos=wn.NOUN)[1], 78 | "canopy": wn.synsets("canopy", pos=wn.NOUN)[2], 79 | "case": wn.synsets("case", pos=wn.NOUN)[-1], 80 | "castle": wn.synsets("castle", pos=wn.NOUN)[1], 81 | "column": wn.synsets("column", pos=wn.NOUN)[5], 82 | "cushion": wn.synsets("cushion", pos=wn.NOUN)[2], 83 | "diningtable": 
wn.synsets("dining_table", pos=wn.NOUN)[0], 84 | "earth": wn.synsets("earth", pos=wn.NOUN)[1], 85 | "escalator": wn.synsets("escalator", pos=wn.NOUN)[1], 86 | "flatfish": wn.synsets("flatfish", pos=wn.NOUN)[1], 87 | "food": wn.synsets("food", pos=wn.NOUN)[1], 88 | "glove": wn.synsets("glove", pos=wn.NOUN)[1], 89 | "grandstand": wn.synsets("grandstand", pos=wn.NOUN)[1], 90 | "lamp": wn.synsets("lamp", pos=wn.NOUN)[1], 91 | "land": wn.synsets("land", pos=wn.NOUN)[1], 92 | "leopard": wn.synsets("leopard", pos=wn.NOUN)[1], 93 | "left-arm": wn.synsets("arm", pos=wn.NOUN)[0], 94 | "left-leg": wn.synsets("leg", pos=wn.NOUN)[0], 95 | "left-shoe": wn.synsets("shoe", pos=wn.NOUN)[0], 96 | "lobster": wn.synsets("lobster", pos=wn.NOUN)[1], 97 | "maple_tree": wn.synsets("maple", pos=wn.NOUN)[1], 98 | "microwave": wn.synsets("microwave", pos=wn.NOUN)[1], 99 | "monitor": wn.synsets("monitor", pos=wn.NOUN)[3], 100 | "otter": wn.synsets("otter", pos=wn.NOUN)[1], 101 | "ottoman": wn.synsets("ottoman", pos=wn.NOUN)[2], 102 | "path": wn.synsets("path", pos=wn.NOUN)[2], 103 | "plant": wn.synsets("plant", pos=wn.NOUN)[1], 104 | "plate": wn.synsets("plate", pos=wn.NOUN)[3], 105 | "pottedplant": wn.synsets("plant", pos=wn.NOUN)[1], 106 | "raccoon": wn.synsets("raccoon", pos=wn.NOUN)[1], 107 | "radiator": wn.synsets("radiator", pos=wn.NOUN)[1], 108 | "ray": wn.synsets("ray", pos=wn.NOUN)[-1], 109 | "rider": wn.synsets("rider", pos=wn.NOUN)[2], 110 | "runway": wn.synsets("runway", pos=wn.NOUN)[3], 111 | "seal": wn.synsets("seal", pos=wn.NOUN)[-1], 112 | "shrew": wn.synsets("shrew", pos=wn.NOUN)[1], 113 | "sign": wn.synsets("sign", pos=wn.NOUN)[1], 114 | "skunk": wn.synsets("skunk", pos=wn.NOUN)[1], 115 | "stage": wn.synsets("stage", pos=wn.NOUN)[2], 116 | "step": wn.synsets("step", pos=wn.NOUN)[3], 117 | "table": wn.synsets("table", pos=wn.NOUN)[1], 118 | "tiger": wn.synsets("tiger", pos=wn.NOUN)[1], 119 | "toilet": wn.synsets("toilet", pos=wn.NOUN)[1], 120 | "traffic_sign": wn.synsets("street_sign", pos=wn.NOUN)[0], 121 | "turtle": wn.synsets("turtle", pos=wn.NOUN)[1], 122 | "tvmonitor": wn.synsets("tv_monitor", pos=wn.NOUN)[0], 123 | "upper-clothes": wn.synsets("top", pos=wn.NOUN)[9], 124 | "van": wn.synsets("van", pos=wn.NOUN)[-1], 125 | "washer": wn.synsets("washer", pos=wn.NOUN)[2], 126 | "water": wn.synsets("water", pos=wn.NOUN)[1], 127 | "whale": wn.synsets("whale", pos=wn.NOUN)[1], 128 | } 129 | 130 | wnids = [] 131 | for i, cls in enumerate(classes): 132 | if cls in hardcoded_mapping: 133 | synset = hardcoded_mapping[cls] 134 | else: 135 | synsets = wn.synsets(cls, pos=wn.NOUN) 136 | if not synsets: 137 | Colors.red(f"==> Failed to find synset for {cls}. 
Using fake synset...") 138 | failures.append(cls) 139 | synsets = [FakeSynset.create_from_offset(i)] 140 | synset = synsets[0] 141 | wnid = synset_to_wnid(synset) 142 | print(f"{wnid}: ({cls}) {synset.definition()}") 143 | wnids.append(wnid) 144 | 145 | if not args.classes: 146 | write_wnids(wnids, path) 147 | Colors.green(f"==> Wrote to {path}") 148 | 149 | if failures: 150 | Colors.red(f"==> Warning: failed to find wordnet IDs for {failures}") 151 | -------------------------------------------------------------------------------- /nbdt/bin/original: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Run evaluation on a single image, using a baseline neural network""" 3 | 4 | from nbdt.model import SoftNBDT, HardNBDT 5 | from pytorchcv.models.wrn_cifar import wrn28_10_cifar10 6 | from torchvision import transforms 7 | from nbdt.utils import DATASET_TO_CLASSES, load_image_from_path 8 | from nbdt.thirdparty.wn import maybe_install_wordnet 9 | import torch.nn.functional as F 10 | from torch.distributions import Categorical 11 | import sys 12 | 13 | maybe_install_wordnet() 14 | 15 | assert len(sys.argv) > 1, "Need to pass image URL or image path as argument" 16 | 17 | # load pretrained model 18 | model = wrn28_10_cifar10(pretrained=True) 19 | model.eval() 20 | 21 | # load + transform image 22 | im = load_image_from_path(sys.argv[1]) 23 | transform = transforms.Compose( 24 | [ 25 | transforms.Resize(32), 26 | transforms.CenterCrop(32), 27 | transforms.ToTensor(), 28 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 29 | ] 30 | ) 31 | x = transform(im)[None] 32 | 33 | model.eval() 34 | 35 | probs = F.softmax(model(x), dim=1)[0] 36 | confidence = (1 - Categorical(probs=probs).entropy()) * 100.0 37 | 38 | print( 39 | "Probabilities per class: " 40 | + ", ".join( 41 | [ 42 | f"{cls} ({p.item() * 100:.2f}%)" 43 | for p, cls in sorted( 44 | zip(probs, DATASET_TO_CLASSES["CIFAR10"]), 45 | key=lambda t: t[0], 46 | reverse=True, 47 | ) 48 | ] 49 | ) 50 | + f"// Confidence: {confidence:.2f}%" 51 | ) 52 | -------------------------------------------------------------------------------- /nbdt/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .custom import * 2 | from .imagenet import * 3 | from .pascal_context import * 4 | from .lip import * 5 | from .ade20k import * 6 | from torchvision.datasets import * 7 | from .cifar import * 8 | -------------------------------------------------------------------------------- /nbdt/data/ade20k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from PIL import Image 4 | 5 | import cv2 6 | import numpy as np 7 | import random 8 | 9 | import torch 10 | from torch.nn import functional as F 11 | from torch.utils import data 12 | 13 | __all__ = names = ("ADE20K",) 14 | 15 | 16 | class BaseDataset(data.Dataset): 17 | def __init__( 18 | self, 19 | ignore_label=-1, 20 | base_size=2048, 21 | crop_size=(512, 1024), 22 | downsample_rate=1, 23 | scale_factor=16, 24 | mean=[0.485, 0.456, 0.406], 25 | std=[0.229, 0.224, 0.225], 26 | ): 27 | 28 | self.base_size = base_size 29 | self.crop_size = crop_size 30 | self.ignore_label = ignore_label 31 | 32 | self.mean = mean 33 | self.std = std 34 | self.scale_factor = scale_factor 35 | self.downsample_rate = 1.0 / downsample_rate 36 | 37 | self.files = [] 38 | 39 | def __len__(self): 40 | return len(self.files) 41 | 42 | def 
input_transform(self, image): 43 | image = image.astype(np.float32)[:, :, ::-1] 44 | image = image / 255.0 45 | image -= self.mean 46 | image /= self.std 47 | return image 48 | 49 | def label_transform(self, label): 50 | return np.array(label).astype("int32") 51 | 52 | def pad_image(self, image, h, w, size, padvalue): 53 | pad_image = image.copy() 54 | pad_h = max(size[0] - h, 0) 55 | pad_w = max(size[1] - w, 0) 56 | if pad_h > 0 or pad_w > 0: 57 | pad_image = cv2.copyMakeBorder( 58 | image, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=padvalue 59 | ) 60 | 61 | return pad_image 62 | 63 | def rand_crop(self, image, label): 64 | h, w = image.shape[:-1] 65 | image = self.pad_image(image, h, w, self.crop_size, (0.0, 0.0, 0.0)) 66 | label = self.pad_image(label, h, w, self.crop_size, (self.ignore_label,)) 67 | 68 | new_h, new_w = label.shape 69 | x = random.randint(0, new_w - self.crop_size[1]) 70 | y = random.randint(0, new_h - self.crop_size[0]) 71 | image = image[y : y + self.crop_size[0], x : x + self.crop_size[1]] 72 | label = label[y : y + self.crop_size[0], x : x + self.crop_size[1]] 73 | 74 | return image, label 75 | 76 | def center_crop(self, image, label): 77 | h, w = image.shape[:2] 78 | x = int(round((w - self.crop_size[1]) / 2.0)) 79 | y = int(round((h - self.crop_size[0]) / 2.0)) 80 | image = image[y : y + self.crop_size[0], x : x + self.crop_size[1]] 81 | label = label[y : y + self.crop_size[0], x : x + self.crop_size[1]] 82 | 83 | return image, label 84 | 85 | def image_resize(self, image, long_size, label=None): 86 | h, w = image.shape[:2] 87 | if h > w: 88 | new_h = long_size 89 | new_w = np.int(w * long_size / h + 0.5) 90 | else: 91 | new_w = long_size 92 | new_h = np.int(h * long_size / w + 0.5) 93 | 94 | image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR) 95 | if label is not None: 96 | label = cv2.resize(label, (new_w, new_h), interpolation=cv2.INTER_NEAREST) 97 | else: 98 | return image 99 | 100 | return image, label 101 | 102 | def multi_scale_aug(self, image, label=None, rand_scale=1, rand_crop=True): 103 | long_size = np.int(self.base_size * rand_scale + 0.5) 104 | if label is not None: 105 | image, label = self.image_resize(image, long_size, label) 106 | if rand_crop: 107 | image, label = self.rand_crop(image, label) 108 | return image, label 109 | else: 110 | image = self.image_resize(image, long_size) 111 | return image 112 | 113 | def gen_sample( 114 | self, image, label, multi_scale=True, is_flip=True, center_crop_test=False 115 | ): 116 | if multi_scale: 117 | rand_scale = 0.5 + random.randint(0, self.scale_factor) / 10.0 118 | image, label = self.multi_scale_aug(image, label, rand_scale=rand_scale) 119 | 120 | if center_crop_test: 121 | image, label = self.image_resize(image, self.base_size, label) 122 | image, label = self.center_crop(image, label) 123 | 124 | image = self.input_transform(image) 125 | label = self.label_transform(label) 126 | 127 | image = image.transpose((2, 0, 1)) 128 | 129 | if is_flip: 130 | flip = np.random.choice(2) * 2 - 1 131 | image = image[:, :, ::flip] 132 | label = label[:, ::flip] 133 | 134 | if self.downsample_rate != 1: 135 | label = cv2.resize( 136 | label, 137 | None, 138 | fx=self.downsample_rate, 139 | fy=self.downsample_rate, 140 | interpolation=cv2.INTER_NEAREST, 141 | ) 142 | 143 | return image, label 144 | 145 | 146 | class ADE20K(BaseDataset): 147 | def __init__( 148 | self, 149 | root="./data/", 150 | list_path="ADE20K/training.odgt", 151 | num_samples=None, 152 | num_classes=150, 153 | 
multi_scale=True, 154 | flip=True, 155 | ignore_label=-1, 156 | base_size=512, 157 | crop_size=(512, 512), 158 | center_crop_test=False, 159 | downsample_rate=1, 160 | scale_factor=16, 161 | mean=[0.485, 0.456, 0.406], 162 | std=[0.229, 0.224, 0.225], 163 | ): 164 | 165 | super(ADE20K, self).__init__( 166 | ignore_label, base_size, crop_size, downsample_rate, scale_factor, mean, std 167 | ) 168 | 169 | self.root = root 170 | self.list_path = list_path 171 | self.num_classes = num_classes 172 | self.class_weights = None 173 | 174 | self.multi_scale = multi_scale 175 | self.flip = flip 176 | self.center_crop_test = center_crop_test 177 | 178 | self.img_list = [ 179 | json.loads(x.rstrip()) for x in open(os.path.join(root, list_path), "r") 180 | ] 181 | 182 | self.files = self.read_files() 183 | if num_samples: 184 | self.files = self.files[:num_samples] 185 | 186 | self.classes = [ 187 | "wall", 188 | "building", 189 | "sky", 190 | "floor", 191 | "tree", 192 | "ceiling", 193 | "road", 194 | "bed", 195 | "windowpane", 196 | "grass", 197 | "cabinet", 198 | "sidewalk", 199 | "person", 200 | "earth", 201 | "door", 202 | "table", 203 | "mountain", 204 | "plant", 205 | "curtain", 206 | "chair", 207 | "car", 208 | "water", 209 | "painting", 210 | "sofa", 211 | "shelf", 212 | "house", 213 | "sea", 214 | "mirror", 215 | "rug", 216 | "field", 217 | "armchair", 218 | "seat", 219 | "fence", 220 | "desk", 221 | "rock", 222 | "wardrobe", 223 | "lamp", 224 | "bathtub", 225 | "railing", 226 | "cushion", 227 | "pedestal", 228 | "box", 229 | "column", 230 | "signboard", 231 | "chest_of_drawers", 232 | "counter", 233 | "sand", 234 | "sink", 235 | "skyscraper", 236 | "fireplace", 237 | "refrigerator", 238 | "grandstand", 239 | "path", 240 | "stairs", 241 | "runway", 242 | "case", 243 | "pool_table", 244 | "pillow", 245 | "screen_door", 246 | "stairway", 247 | "river", 248 | "bridge", 249 | "bookcase", 250 | "blind", 251 | "coffee_table", 252 | "toilet", 253 | "flower", 254 | "book", 255 | "hill", 256 | "bench", 257 | "countertop", 258 | "stove", 259 | "palm_tree", 260 | "kitchen_island", 261 | "computer", 262 | "swivel_chair", 263 | "boat", 264 | "bar", 265 | "arcade_machine", 266 | "hovel", 267 | "bus", 268 | "towel", 269 | "light_source", 270 | "truck", 271 | "tower", 272 | "chandelier", 273 | "awning", 274 | "streetlight", 275 | "booth", 276 | "television_receiver", 277 | "airplane", 278 | "dirt_track", 279 | "apparel", 280 | "pole", 281 | "land", 282 | "handrail", 283 | "escalator", 284 | "ottoman", 285 | "bottle", 286 | "buffet", 287 | "poster", 288 | "stage", 289 | "van", 290 | "ship", 291 | "fountain", 292 | "conveyer_belt", 293 | "canopy", 294 | "washer", 295 | "toy", 296 | "swimming_pool", 297 | "stool", 298 | "barrel", 299 | "basket", 300 | "waterfall", 301 | "tent", 302 | "bag", 303 | "minibike", 304 | "cradle", 305 | "oven", 306 | "ball", 307 | "food", 308 | "step", 309 | "storage_tank", 310 | "brand", 311 | "microwave", 312 | "flowerpot", 313 | "animal", 314 | "bicycle", 315 | "lake", 316 | "dishwasher", 317 | "screen", 318 | "blanket", 319 | "sculpture", 320 | "exhaust_hood", 321 | "sconce", 322 | "vase", 323 | "traffic_light", 324 | "tray", 325 | "trash_can", 326 | "fan", 327 | "pier", 328 | "crt_screen", 329 | "plate", 330 | "monitor", 331 | "bulletin_board", 332 | "shower", 333 | "radiator", 334 | "drinking_glass", 335 | "clock", 336 | "flag", 337 | ] 338 | 339 | def read_files(self): 340 | files = [] 341 | for item in self.img_list: 342 | image_path = 
item["fpath_img"].replace("ADEChallengeData2016", "ADE20K") 343 | label_path = item["fpath_segm"].replace("ADEChallengeData2016", "ADE20K") 344 | name = os.path.splitext(os.path.basename(image_path))[0] 345 | files.append( 346 | {"img": image_path, "label": label_path, "name": name,} 347 | ) 348 | return files 349 | 350 | def resize_image_label(self, image, label, size): 351 | scale = size / min(image.shape[0], image.shape[1]) 352 | image = cv2.resize( 353 | image, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR 354 | ) 355 | label = cv2.resize( 356 | label, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_NEAREST 357 | ) 358 | return image, label 359 | 360 | def convert_label(self, label): 361 | # Convert labels to -1 to 149 362 | return np.array(label).astype("int32") - 1 363 | 364 | def __getitem__(self, index): 365 | item = self.files[index] 366 | name = item["name"] 367 | image = cv2.imread(os.path.join(self.root, item["img"]), cv2.IMREAD_COLOR) 368 | size = image.shape 369 | label = cv2.imread(os.path.join(self.root, item["label"]), cv2.IMREAD_GRAYSCALE) 370 | label = self.convert_label(label) 371 | 372 | if "validation" in self.list_path: 373 | image = self.input_transform(image) 374 | image = image.transpose((2, 0, 1)) 375 | label = self.label_transform(label) 376 | else: 377 | image, label = self.resize_image_label(image, label, self.base_size) 378 | image, label = self.gen_sample( 379 | image, label, self.multi_scale, self.flip, self.center_crop_test 380 | ) 381 | 382 | return image.copy(), label.copy(), np.array(size), name 383 | -------------------------------------------------------------------------------- /nbdt/data/cifar.py: -------------------------------------------------------------------------------- 1 | """Wrappers around CIFAR datasets""" 2 | 3 | from torchvision import datasets, transforms 4 | from . 
import transforms as transforms_custom 5 | 6 | __all__ = names = ("CIFAR10", "CIFAR100") 7 | 8 | 9 | class CIFAR: 10 | @staticmethod 11 | def transform_train(): 12 | return transforms.Compose( 13 | [ 14 | transforms.RandomCrop(32, padding=4), 15 | transforms.RandomHorizontalFlip(), 16 | transforms.ToTensor(), 17 | transforms.Normalize( 18 | (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) 19 | ), 20 | ] 21 | ) 22 | 23 | @staticmethod 24 | def transform_val(): 25 | return transforms.Compose( 26 | [ 27 | transforms.ToTensor(), 28 | transforms.Normalize( 29 | (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) 30 | ), 31 | ] 32 | ) 33 | 34 | @staticmethod 35 | def transform_val_inverse(): 36 | return transforms_custom.InverseNormalize( 37 | (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) 38 | ) 39 | 40 | 41 | class CIFAR10(datasets.CIFAR10, CIFAR): 42 | pass 43 | 44 | 45 | class CIFAR100(datasets.CIFAR100, CIFAR): 46 | pass 47 | -------------------------------------------------------------------------------- /nbdt/data/custom.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | from collections import defaultdict 5 | from nbdt.utils import DATASET_TO_NUM_CLASSES, DATASETS 6 | from collections import defaultdict 7 | from nbdt.thirdparty.wn import get_wnids, FakeSynset, wnid_to_synset, wnid_to_name 8 | from nbdt.thirdparty.nx import get_leaves, get_leaf_to_path, read_graph 9 | from nbdt.tree import Tree 10 | from nbdt.utils import ( 11 | dataset_to_default_path_graph, 12 | dataset_to_default_path_wnids, 13 | hierarchy_to_path_graph, 14 | ) 15 | from . import imagenet 16 | from . import cifar 17 | import torch.nn as nn 18 | import random 19 | 20 | 21 | __all__ = names = ( 22 | "CIFAR10IncludeLabels", 23 | "CIFAR100IncludeLabels", 24 | "TinyImagenet200IncludeLabels", 25 | "Imagenet1000IncludeLabels", 26 | "CIFAR10ExcludeLabels", 27 | "CIFAR100ExcludeLabels", 28 | "TinyImagenet200ExcludeLabels", 29 | "Imagenet1000ExcludeLabels", 30 | "CIFAR10ResampleLabels", 31 | "CIFAR100ResampleLabels", 32 | "TinyImagenet200ResampleLabels", 33 | "Imagenet1000ResampleLabels", 34 | ) 35 | 36 | 37 | def add_arguments(parser): 38 | parser.add_argument("--probability-labels", nargs="*", type=float) 39 | parser.add_argument("--include-labels", nargs="*", type=int) 40 | parser.add_argument("--exclude-labels", nargs="*", type=int) 41 | parser.add_argument("--include-classes", nargs="*", type=int) 42 | 43 | 44 | class ResampleLabelsDataset(Dataset): 45 | """ 46 | Dataset that includes only the labels provided, with a limited number of 47 | samples. Note that labels are integers in [0, k) for a k-class dataset. 48 | 49 | :drop_classes bool: Modifies the dataset so that it is only a m-way 50 | classification where m of k classes are kept. Otherwise, 51 | the problem is still k-way. 
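    Hypothetical usage sketch (added for illustration, not part of the original
    docstring; assumes the CIFAR10 wrapper from nbdt.data.cifar and a local copy
    of the dataset). Keeping roughly half of every class:

    >>> base = CIFAR10(root="./data", download=True)
    >>> resampled = ResampleLabelsDataset(base, probability_labels=0.5, seed=0)
    >>> len(resampled) <= len(base)  # samples are dropped per class at random
    True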
52 | """ 53 | 54 | accepts_probability_labels = True 55 | 56 | def __init__(self, dataset, probability_labels=1, drop_classes=False, seed=0): 57 | self.dataset = dataset 58 | self.classes = dataset.classes 59 | self.labels = list(range(len(self.classes))) 60 | self.probability_labels = self.get_probability_labels( 61 | dataset, probability_labels 62 | ) 63 | 64 | self.drop_classes = drop_classes 65 | if self.drop_classes: 66 | self.classes, self.labels = self.get_classes_after_drop( 67 | dataset, probability_labels 68 | ) 69 | 70 | assert self.labels, "No labels are included in `include_labels`" 71 | 72 | self.new_to_old = self.build_index_mapping(seed=seed) 73 | 74 | def get_probability_labels(self, dataset, ps): 75 | if not isinstance(ps, (tuple, list)): 76 | return [ps] * len(dataset.classes) 77 | if len(ps) == 1: 78 | return ps * len(dataset.classes) 79 | assert len(ps) == len(dataset.classes), ( 80 | f"Length of probabilities vector {len(ps)} must equal that of the " 81 | f"dataset classes {len(dataset.classes)}." 82 | ) 83 | return ps 84 | 85 | def apply_drop(self, dataset, ps): 86 | classes = [cls for p, cls in zip(ps, dataset.classes) if p > 0] 87 | labels = [i for p, i in zip(ps, range(len(dataset.classes))) if p > 0] 88 | return classes, labels 89 | 90 | def build_index_mapping(self, seed=0): 91 | """Iterates over all samples in dataset. 92 | 93 | Remaps all to-be-included samples to [0, n) where n is the number of 94 | samples with a class in the whitelist. 95 | 96 | Additionally, the outputted list is truncated to match the number of 97 | desired samples. 98 | """ 99 | random.seed(seed) 100 | 101 | new_to_old = [] 102 | for old, (_, label) in enumerate(self.dataset): 103 | if random.random() < self.probability_labels[label]: 104 | new_to_old.append(old) 105 | return new_to_old 106 | 107 | def __getitem__(self, index_new): 108 | index_old = self.new_to_old[index_new] 109 | sample, label_old = self.dataset[index_old] 110 | 111 | label_new = label_old 112 | if self.drop_classes: 113 | label_new = self.include_labels.index(label_old) 114 | 115 | return sample, label_new 116 | 117 | def __len__(self): 118 | return len(self.new_to_old) 119 | 120 | 121 | class IncludeLabelsDataset(ResampleLabelsDataset): 122 | 123 | accepts_include_labels = True 124 | accepts_probability_labels = False 125 | 126 | def __init__(self, dataset, include_labels=(0,)): 127 | super().__init__( 128 | dataset, 129 | probability_labels=[ 130 | int(cls in include_labels) for cls in range(len(dataset.classes)) 131 | ], 132 | ) 133 | 134 | 135 | def get_resample_labels_dataset(dataset): 136 | class Cls(ResampleLabelsDataset): 137 | def __init__(self, *args, root="./data", probability_labels=1, **kwargs): 138 | super().__init__( 139 | dataset=dataset(*args, root=root, **kwargs), 140 | probability_labels=probability_labels, 141 | ) 142 | 143 | Cls.__name__ = f"{dataset.__class__.__name__}ResampleLabels" 144 | return Cls 145 | 146 | 147 | CIFAR10ResampleLabels = get_resample_labels_dataset(cifar.CIFAR10) 148 | CIFAR100ResampleLabels = get_resample_labels_dataset(cifar.CIFAR100) 149 | TinyImagenet200ResampleLabels = get_resample_labels_dataset(imagenet.TinyImagenet200) 150 | Imagenet1000ResampleLabels = get_resample_labels_dataset(imagenet.Imagenet1000) 151 | 152 | 153 | class IncludeClassesDataset(IncludeLabelsDataset): 154 | """ 155 | Dataset that includes only the labels provided, with a limited number of 156 | samples. Note that classes are strings, like 'cat' or 'dog'. 
157 | """ 158 | 159 | accepts_include_labels = False 160 | accepts_include_classes = True 161 | 162 | def __init__(self, dataset, include_classes=()): 163 | super().__init__( 164 | dataset, 165 | include_labels=[dataset.classes.index(cls) for cls in include_classes], 166 | ) 167 | 168 | 169 | def get_include_labels_dataset(dataset): 170 | class Cls(IncludeLabelsDataset): 171 | def __init__(self, *args, root="./data", include_labels=(0,), **kwargs): 172 | super().__init__( 173 | dataset=dataset(*args, root=root, **kwargs), 174 | include_labels=include_labels, 175 | ) 176 | 177 | Cls.__name__ = f"{dataset.__class__.__name__}IncludeLabels" 178 | return Cls 179 | 180 | 181 | CIFAR10IncludeLabels = get_include_labels_dataset(cifar.CIFAR10) 182 | CIFAR100IncludeLabels = get_include_labels_dataset(cifar.CIFAR100) 183 | TinyImagenet200IncludeLabels = get_include_labels_dataset(imagenet.TinyImagenet200) 184 | Imagenet1000IncludeLabels = get_include_labels_dataset(imagenet.Imagenet1000) 185 | 186 | 187 | class ExcludeLabelsDataset(IncludeLabelsDataset): 188 | 189 | accepts_include_labels = False 190 | accepts_exclude_labels = True 191 | 192 | def __init__(self, dataset, exclude_labels=(0,)): 193 | k = len(dataset.classes) 194 | include_labels = set(range(k)) - set(exclude_labels) 195 | super().__init__(dataset=dataset, include_labels=include_labels) 196 | 197 | 198 | def get_exclude_labels_dataset(dataset): 199 | class Cls(ExcludeLabelsDataset): 200 | def __init__(self, *args, root="./data", exclude_labels=(0,), **kwargs): 201 | super().__init__( 202 | dataset=dataset(*args, root=root, **kwargs), 203 | exclude_labels=exclude_labels, 204 | ) 205 | 206 | Cls.__name__ = f"{dataset.__class__.__name__}ExcludeLabels" 207 | return Cls 208 | 209 | 210 | CIFAR10ExcludeLabels = get_exclude_labels_dataset(cifar.CIFAR10) 211 | CIFAR100ExcludeLabels = get_exclude_labels_dataset(cifar.CIFAR100) 212 | TinyImagenet200ExcludeLabels = get_exclude_labels_dataset(imagenet.TinyImagenet200) 213 | Imagenet1000ExcludeLabels = get_exclude_labels_dataset(imagenet.Imagenet1000) 214 | -------------------------------------------------------------------------------- /nbdt/data/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torchvision.datasets as datasets 3 | import torchvision.transforms as transforms 4 | from . 
import transforms as transforms_custom 5 | from torch.utils.data import Dataset 6 | from pathlib import Path 7 | import zipfile 8 | import urllib.request 9 | import shutil 10 | import time 11 | 12 | 13 | __all__ = names = ( 14 | "TinyImagenet200", 15 | "Imagenet1000", 16 | ) 17 | 18 | 19 | class TinyImagenet200(Dataset): 20 | """Tiny imagenet dataloader""" 21 | 22 | url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip" 23 | 24 | dataset = None 25 | 26 | def __init__(self, root="./data", *args, train=True, download=False, **kwargs): 27 | super().__init__() 28 | 29 | if download: 30 | self.download(root=root) 31 | dataset = _TinyImagenet200Train if train else _TinyImagenet200Val 32 | self.root = root 33 | self.dataset = dataset(root, *args, **kwargs) 34 | self.classes = self.dataset.classes 35 | self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)} 36 | 37 | @staticmethod 38 | def transform_train(input_size=64): 39 | return transforms.Compose( 40 | [ 41 | transforms.RandomCrop(input_size, padding=8), 42 | transforms.RandomHorizontalFlip(), 43 | transforms.ToTensor(), 44 | transforms.Normalize( 45 | [0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262] 46 | ), 47 | ] 48 | ) 49 | 50 | @staticmethod 51 | def transform_val(input_size=-1): 52 | return transforms.Compose( 53 | [ 54 | transforms.ToTensor(), 55 | transforms.Normalize( 56 | [0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262] 57 | ), 58 | ] 59 | ) 60 | 61 | @staticmethod 62 | def transform_val_inverse(): 63 | return transforms_custom.InverseNormalize( 64 | [0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262] 65 | ) 66 | 67 | def download(self, root="./"): 68 | """Download and unzip Imagenet200 files in the `root` directory.""" 69 | dir = os.path.join(root, "tiny-imagenet-200") 70 | dir_train = os.path.join(dir, "train") 71 | if os.path.exists(dir) and os.path.exists(dir_train): 72 | print("==> Already downloaded.") 73 | return 74 | 75 | path = Path(os.path.join(root, "tiny-imagenet-200.zip")) 76 | if not os.path.exists(path): 77 | os.makedirs(path.parent, exist_ok=True) 78 | 79 | print("==> Downloading TinyImagenet200...") 80 | with urllib.request.urlopen(self.url) as response, open( 81 | str(path), "wb" 82 | ) as out_file: 83 | shutil.copyfileobj(response, out_file) 84 | 85 | print("==> Extracting TinyImagenet200...") 86 | with zipfile.ZipFile(str(path)) as zf: 87 | zf.extractall(root) 88 | 89 | def __getitem__(self, i): 90 | return self.dataset[i] 91 | 92 | def __len__(self): 93 | return len(self.dataset) 94 | 95 | 96 | class _TinyImagenet200Train(datasets.ImageFolder): 97 | def __init__(self, root="./data", *args, **kwargs): 98 | super().__init__(os.path.join(root, "tiny-imagenet-200/train"), *args, **kwargs) 99 | 100 | 101 | class _TinyImagenet200Val(datasets.ImageFolder): 102 | def __init__(self, root="./data", *args, **kwargs): 103 | super().__init__(os.path.join(root, "tiny-imagenet-200/val"), *args, **kwargs) 104 | 105 | self.path_to_class = {} 106 | with open(os.path.join(self.root, "val_annotations.txt")) as f: 107 | for line in f.readlines(): 108 | parts = line.split() 109 | path = os.path.join(self.root, "images", parts[0]) 110 | self.path_to_class[path] = parts[1] 111 | 112 | self.classes = list(sorted(set(self.path_to_class.values()))) 113 | self.class_to_idx = {label: self.classes.index(label) for label in self.classes} 114 | 115 | def __getitem__(self, i): 116 | sample, _ = super().__getitem__(i) 117 | path, _ = self.samples[i] 118 | label = self.path_to_class[path] 119 | target = self.class_to_idx[label] 120 
| return sample, target 121 | 122 | def __len__(self): 123 | return super().__len__() 124 | 125 | 126 | class Imagenet1000(Dataset): 127 | """ImageNet dataloader""" 128 | 129 | dataset = None 130 | 131 | def __init__(self, root="./data", *args, train=True, download=False, **kwargs): 132 | super().__init__() 133 | 134 | if download: 135 | self.download(root=root) 136 | dataset = _Imagenet1000Train if train else _Imagenet1000Val 137 | self.root = root 138 | self.dataset = dataset(root, *args, **kwargs) 139 | self.classes = self.dataset.classes 140 | self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)} 141 | 142 | def download(self, root="./"): 143 | dir = os.path.join(root, "imagenet-1000") 144 | dir_train = os.path.join(dir, "train") 145 | if os.path.exists(dir) and os.path.exists(dir_train): 146 | print("==> Already downloaded.") 147 | return 148 | 149 | msg = "Please symlink existing ImageNet dataset rather than downloading." 150 | raise RuntimeError(msg) 151 | 152 | @staticmethod 153 | def transform_train(input_size=224): 154 | return transforms.Compose( 155 | [ 156 | transforms.RandomResizedCrop(input_size), # TODO: may need updating 157 | transforms.RandomHorizontalFlip(), 158 | transforms.ToTensor(), 159 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 160 | ] 161 | ) 162 | 163 | @staticmethod 164 | def transform_val(input_size=224): 165 | return transforms.Compose( 166 | [ 167 | transforms.Resize(input_size + 32), 168 | transforms.CenterCrop(input_size), 169 | transforms.ToTensor(), 170 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 171 | ] 172 | ) 173 | 174 | @staticmethod 175 | def transform_val_inverse(): 176 | return transforms_custom.InverseNormalize( 177 | (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) 178 | ) 179 | 180 | def __getitem__(self, i): 181 | return self.dataset[i] 182 | 183 | def __len__(self): 184 | return len(self.dataset) 185 | 186 | 187 | class _Imagenet1000Train(datasets.ImageFolder): 188 | def __init__(self, root="./data", *args, **kwargs): 189 | super().__init__(os.path.join(root, "imagenet-1000/train"), *args, **kwargs) 190 | 191 | 192 | class _Imagenet1000Val(datasets.ImageFolder): 193 | def __init__(self, root="./data", *args, **kwargs): 194 | super().__init__(os.path.join(root, "imagenet-1000/val"), *args, **kwargs) 195 | 196 | -------------------------------------------------------------------------------- /nbdt/data/lip.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | import os 8 | 9 | import cv2 10 | import numpy as np 11 | import random 12 | 13 | import torch 14 | from torch.nn import functional as F 15 | from torch.utils import data 16 | 17 | __all__ = names = ("LookIntoPerson",) 18 | 19 | 20 | class BaseDataset(data.Dataset): 21 | def __init__( 22 | self, 23 | ignore_label=-1, 24 | base_size=2048, 25 | crop_size=(512, 1024), 26 | downsample_rate=1, 27 | scale_factor=16, 28 | mean=[0.485, 0.456, 0.406], 29 | std=[0.229, 0.224, 0.225], 30 | ): 31 | 32 | self.base_size = base_size 33 | self.crop_size = crop_size 34 | self.ignore_label = ignore_label 35 | 36 | self.mean = mean 37 | self.std = std 38 | self.scale_factor = scale_factor 39 | self.downsample_rate = 1.0 / downsample_rate 40 | 41 | self.files = [] 42 | 43 | def __len__(self): 44 | return len(self.files) 45 | 46 | def input_transform(self, image): 47 | image = image.astype(np.float32)[:, :, ::-1] 48 | image = image / 255.0 49 | image -= self.mean 50 | image /= self.std 51 | return image 52 | 53 | def label_transform(self, label): 54 | return np.array(label).astype("int32") 55 | 56 | def pad_image(self, image, h, w, size, padvalue): 57 | pad_image = image.copy() 58 | pad_h = max(size[0] - h, 0) 59 | pad_w = max(size[1] - w, 0) 60 | if pad_h > 0 or pad_w > 0: 61 | pad_image = cv2.copyMakeBorder( 62 | image, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=padvalue 63 | ) 64 | 65 | return pad_image 66 | 67 | def rand_crop(self, image, label): 68 | h, w = image.shape[:-1] 69 | image = self.pad_image(image, h, w, self.crop_size, (0.0, 0.0, 0.0)) 70 | label = self.pad_image(label, h, w, self.crop_size, (self.ignore_label,)) 71 | 72 | new_h, new_w = label.shape 73 | x = random.randint(0, new_w - self.crop_size[1]) 74 | y = random.randint(0, new_h - self.crop_size[0]) 75 | image = image[y : y + self.crop_size[0], x : x + self.crop_size[1]] 76 | label = label[y : y + self.crop_size[0], x : x + self.crop_size[1]] 77 | 78 | return image, label 79 | 80 | def center_crop(self, image, label): 81 | h, w = image.shape[:2] 82 | x = int(round((w - self.crop_size[1]) / 2.0)) 83 | y = int(round((h - self.crop_size[0]) / 2.0)) 84 | image = image[y : y + self.crop_size[0], x : x + self.crop_size[1]] 85 | label = label[y : y + self.crop_size[0], x : x + self.crop_size[1]] 86 | 87 | return image, label 88 | 89 | def image_resize(self, image, long_size, label=None): 90 | h, w = image.shape[:2] 91 | if h > w: 92 | new_h = long_size 93 | new_w = np.int(w * long_size / h + 0.5) 94 | else: 95 | new_w = long_size 96 | new_h = np.int(h * long_size / w + 0.5) 97 | 98 | image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR) 99 | if label is not None: 100 | label = cv2.resize(label, (new_w, new_h), interpolation=cv2.INTER_NEAREST) 101 | else: 102 | return image 103 | 104 | return image, label 105 | 106 | def multi_scale_aug(self, image, label=None, rand_scale=1, rand_crop=True): 107 | long_size = np.int(self.base_size * rand_scale + 0.5) 108 | if label is not None: 109 | image, label = self.image_resize(image, long_size, label) 110 | if rand_crop: 111 | image, label = self.rand_crop(image, label) 112 | return image, label 113 | else: 114 | image = self.image_resize(image, long_size) 115 | return image 116 | 117 | def gen_sample( 118 | self, image, label, multi_scale=True, is_flip=True, center_crop_test=False 119 | ): 120 | if multi_scale: 121 | rand_scale = 0.5 + 
random.randint(0, self.scale_factor) / 10.0 122 | image, label = self.multi_scale_aug(image, label, rand_scale=rand_scale) 123 | 124 | if center_crop_test: 125 | image, label = self.image_resize(image, self.base_size, label) 126 | image, label = self.center_crop(image, label) 127 | 128 | image = self.input_transform(image) 129 | label = self.label_transform(label) 130 | 131 | image = image.transpose((2, 0, 1)) 132 | 133 | if is_flip: 134 | flip = np.random.choice(2) * 2 - 1 135 | image = image[:, :, ::flip] 136 | label = label[:, ::flip] 137 | 138 | if self.downsample_rate != 1: 139 | label = cv2.resize( 140 | label, 141 | None, 142 | fx=self.downsample_rate, 143 | fy=self.downsample_rate, 144 | interpolation=cv2.INTER_NEAREST, 145 | ) 146 | 147 | return image, label 148 | 149 | 150 | class LookIntoPerson(BaseDataset): 151 | def __init__( 152 | self, 153 | root="./data/", 154 | list_path="LookIntoPerson/trainList.txt", 155 | num_samples=None, 156 | num_classes=20, 157 | multi_scale=True, 158 | flip=True, 159 | ignore_label=-1, 160 | base_size=473, 161 | crop_size=(473, 473), 162 | downsample_rate=1, 163 | scale_factor=11, 164 | center_crop_test=False, 165 | mean=[0.485, 0.456, 0.406], 166 | std=[0.229, 0.224, 0.225], 167 | ): 168 | 169 | super(LookIntoPerson, self).__init__( 170 | ignore_label, base_size, crop_size, downsample_rate, scale_factor, mean, std 171 | ) 172 | 173 | self.root = root 174 | self.num_classes = num_classes 175 | self.list_path = list_path 176 | self.class_weights = None 177 | self.classes = [ 178 | "background", 179 | "hat", 180 | "hair", 181 | "glove", 182 | "sunglasses", 183 | "upper-clothes", 184 | "dress", 185 | "coat", 186 | "socks", 187 | "pants", 188 | "jumpsuits", 189 | "scarf", 190 | "skirt", 191 | "face", 192 | "left-arm", 193 | "right-arm", 194 | "left-leg", 195 | "right-leg", 196 | "left-shoe", 197 | "right-shoe", 198 | ] 199 | 200 | self.multi_scale = multi_scale 201 | self.flip = flip 202 | self.img_list = [ 203 | line.strip().split() for line in open(os.path.join(root, list_path)) 204 | ] 205 | 206 | self.files = self.read_files() 207 | if num_samples: 208 | self.files = self.files[:num_samples] 209 | 210 | def read_files(self): 211 | files = [] 212 | for item in self.img_list: 213 | image_path, label_path = item[:2] 214 | name = os.path.splitext(os.path.basename(label_path))[0] 215 | sample = { 216 | "img": image_path, 217 | "label": label_path, 218 | "name": name, 219 | } 220 | files.append(sample) 221 | return files 222 | 223 | def resize_image(self, image, label, size): 224 | image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR) 225 | label = cv2.resize(label, size, interpolation=cv2.INTER_NEAREST) 226 | return image, label 227 | 228 | def __getitem__(self, index): 229 | item = self.files[index] 230 | name = item["name"] 231 | 232 | image = cv2.imread( 233 | os.path.join(self.root, "LookIntoPerson/TrainVal_images/", item["img"]), 234 | cv2.IMREAD_COLOR, 235 | ) 236 | label = cv2.imread( 237 | os.path.join( 238 | self.root, "LookIntoPerson/TrainVal_parsing_annotations/", item["label"] 239 | ), 240 | cv2.IMREAD_GRAYSCALE, 241 | ) 242 | size = label.shape 243 | 244 | if "testval" in self.list_path: 245 | image = cv2.resize(image, self.crop_size, interpolation=cv2.INTER_LINEAR) 246 | image = self.input_transform(image) 247 | image = image.transpose((2, 0, 1)) 248 | 249 | return image.copy(), label.copy(), np.array(size), name 250 | 251 | if self.flip: 252 | flip = np.random.choice(2) * 2 - 1 253 | image = image[:, ::flip, :] 254 | label = 
label[:, ::flip] 255 | 256 | if flip == -1: 257 | right_idx = [15, 17, 19] 258 | left_idx = [14, 16, 18] 259 | for i in range(0, 3): 260 | right_pos = np.where(label == right_idx[i]) 261 | left_pos = np.where(label == left_idx[i]) 262 | label[right_pos[0], right_pos[1]] = left_idx[i] 263 | label[left_pos[0], left_pos[1]] = right_idx[i] 264 | 265 | image, label = self.resize_image(image, label, self.crop_size) 266 | image, label = self.gen_sample(image, label, self.multi_scale, False) 267 | 268 | return image.copy(), label.copy(), np.array(size), name 269 | -------------------------------------------------------------------------------- /nbdt/data/pascal_context.py: -------------------------------------------------------------------------------- 1 | ########################################################################### 2 | # Created by: Hang Zhang 3 | # Email: zhang.hang@rutgers.edu 4 | # Copyright (c) 2017 5 | ########################################################################### 6 | 7 | from PIL import Image, ImageOps, ImageFilter 8 | import os 9 | import math 10 | import random 11 | import numpy as np 12 | from tqdm import trange 13 | 14 | import torch 15 | import torch.utils.data as data 16 | 17 | __all__ = names = ("PascalContext",) 18 | 19 | 20 | class BaseDataset(data.Dataset): 21 | def __init__( 22 | self, 23 | root, 24 | split, 25 | mode=None, 26 | transform=None, 27 | target_transform=None, 28 | base_size=520, 29 | crop_size=480, 30 | ): 31 | self.root = root 32 | self.transform = transform 33 | self.target_transform = target_transform 34 | self.split = split 35 | self.mode = mode if mode is not None else split 36 | self.base_size = base_size 37 | self.crop_size = crop_size 38 | if self.mode == "train": 39 | print( 40 | "BaseDataset: base_size {}, crop_size {}".format(base_size, crop_size) 41 | ) 42 | 43 | def __getitem__(self, index): 44 | raise NotImplemented 45 | 46 | @property 47 | def num_class(self): 48 | return self.NUM_CLASS 49 | 50 | @property 51 | def pred_offset(self): 52 | raise NotImplemented 53 | 54 | def make_pred(self, x): 55 | return x + self.pred_offset 56 | 57 | def _val_sync_transform(self, img, mask): 58 | outsize = self.crop_size 59 | short_size = outsize 60 | w, h = img.size 61 | if w > h: 62 | oh = short_size 63 | ow = int(1.0 * w * oh / h) 64 | else: 65 | ow = short_size 66 | oh = int(1.0 * h * ow / w) 67 | img = img.resize((ow, oh), Image.BILINEAR) 68 | mask = mask.resize((ow, oh), Image.NEAREST) 69 | # center crop 70 | w, h = img.size 71 | x1 = int(round((w - outsize) / 2.0)) 72 | y1 = int(round((h - outsize) / 2.0)) 73 | img = img.crop((x1, y1, x1 + outsize, y1 + outsize)) 74 | mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize)) 75 | # final transform 76 | return img, self._mask_transform(mask) 77 | 78 | def _sync_transform(self, img, mask): 79 | # random mirror 80 | if random.random() < 0.5: 81 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 82 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 83 | crop_size = self.crop_size 84 | # random scale (short edge) 85 | w, h = img.size 86 | long_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) 87 | if h > w: 88 | oh = long_size 89 | ow = int(1.0 * w * long_size / h + 0.5) 90 | short_size = ow 91 | else: 92 | ow = long_size 93 | oh = int(1.0 * h * long_size / w + 0.5) 94 | short_size = oh 95 | img = img.resize((ow, oh), Image.BILINEAR) 96 | mask = mask.resize((ow, oh), Image.NEAREST) 97 | # pad crop 98 | if short_size < crop_size: 99 | padh = crop_size - oh if oh < crop_size 
else 0 100 | padw = crop_size - ow if ow < crop_size else 0 101 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 102 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0) 103 | # random crop crop_size 104 | w, h = img.size 105 | x1 = random.randint(0, w - crop_size) 106 | y1 = random.randint(0, h - crop_size) 107 | img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size)) 108 | mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size)) 109 | # final transform 110 | return img, self._mask_transform(mask) 111 | 112 | def _mask_transform(self, mask): 113 | return torch.from_numpy(np.array(mask)).long() 114 | 115 | 116 | class PascalContext(BaseDataset): 117 | NUM_CLASS = 59 118 | 119 | def __init__( 120 | self, 121 | root="./data", 122 | split="train", 123 | mode=None, 124 | transform=None, 125 | target_transform=None, 126 | **kwargs 127 | ): 128 | super(PascalContext, self).__init__( 129 | root, split, mode, transform, target_transform, **kwargs 130 | ) 131 | from detail import Detail 132 | 133 | # from detail import mask 134 | root = os.path.join(root, "PascalContext") 135 | annFile = os.path.join(root, "trainval_merged.json") 136 | imgDir = os.path.join(root, "JPEGImages") 137 | # training mode 138 | self.detail = Detail(annFile, imgDir, split) 139 | self.transform = transform 140 | self.target_transform = target_transform 141 | self.ids = self.detail.getImgs() 142 | # generate masks 143 | self._mapping = np.sort( 144 | np.array( 145 | [ 146 | 0, 147 | 2, 148 | 259, 149 | 260, 150 | 415, 151 | 324, 152 | 9, 153 | 258, 154 | 144, 155 | 18, 156 | 19, 157 | 22, 158 | 23, 159 | 397, 160 | 25, 161 | 284, 162 | 158, 163 | 159, 164 | 416, 165 | 33, 166 | 162, 167 | 420, 168 | 454, 169 | 295, 170 | 296, 171 | 427, 172 | 44, 173 | 45, 174 | 46, 175 | 308, 176 | 59, 177 | 440, 178 | 445, 179 | 31, 180 | 232, 181 | 65, 182 | 354, 183 | 424, 184 | 68, 185 | 326, 186 | 72, 187 | 458, 188 | 34, 189 | 207, 190 | 80, 191 | 355, 192 | 85, 193 | 347, 194 | 220, 195 | 349, 196 | 360, 197 | 98, 198 | 187, 199 | 104, 200 | 105, 201 | 366, 202 | 189, 203 | 368, 204 | 113, 205 | 115, 206 | ] 207 | ) 208 | ) 209 | self.classes = [ 210 | "background", 211 | "aeroplane", 212 | "mountain", 213 | "mouse", 214 | "track", 215 | "road", 216 | "bag", 217 | "motorbike", 218 | "fence", 219 | "bed", 220 | "bedclothes", 221 | "bench", 222 | "bicycle", 223 | "diningtable", 224 | "bird", 225 | "person", 226 | "floor", 227 | "boat", 228 | "train", 229 | "book", 230 | "bottle", 231 | "tree", 232 | "window", 233 | "plate", 234 | "platform", 235 | "tvmonitor", 236 | "building", 237 | "bus", 238 | "cabinet", 239 | "shelves", 240 | "light", 241 | "pottedplant", 242 | "wall", 243 | "car", 244 | "ground", 245 | "cat", 246 | "sidewalk", 247 | "truck", 248 | "ceiling", 249 | "rock", 250 | "chair", 251 | "wood", 252 | "food", 253 | "horse", 254 | "cloth", 255 | "sign", 256 | "computer", 257 | "sheep", 258 | "keyboard", 259 | "flower", 260 | "sky", 261 | "cow", 262 | "grass", 263 | "cup", 264 | "curtain", 265 | "snow", 266 | "water", 267 | "sofa", 268 | "dog", 269 | "door", 270 | ] 271 | self._key = np.array(range(len(self._mapping))).astype("uint8") 272 | mask_file = os.path.join(root, self.split + ".pth") 273 | print("mask_file:", mask_file) 274 | if os.path.exists(mask_file): 275 | self.masks = torch.load(mask_file) 276 | else: 277 | self.masks = self._preprocess(mask_file) 278 | 279 | def _class_to_index(self, mask): 280 | # assert the values 281 | values = np.unique(mask) 282 | for i in range(len(values)): 
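# (added explanatory comment) every value in the raw Detail mask must be one of the 60 tracked ids in self._mapping; np.digitize below remaps each id to its position in the sorted mapping (0..59), and _mask_transform later subtracts 1 so "background" becomes -1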
283 | assert values[i] in self._mapping 284 | index = np.digitize(mask.ravel(), self._mapping, right=True) 285 | return self._key[index].reshape(mask.shape) 286 | 287 | def _preprocess(self, mask_file): 288 | masks = {} 289 | tbar = trange(len(self.ids)) 290 | print( 291 | "Preprocessing mask, this will take a while." 292 | + "But don't worry, it only run once for each split." 293 | ) 294 | for i in tbar: 295 | img_id = self.ids[i] 296 | mask = Image.fromarray(self._class_to_index(self.detail.getMask(img_id))) 297 | masks[img_id["image_id"]] = mask 298 | tbar.set_description("Preprocessing masks {}".format(img_id["image_id"])) 299 | torch.save(masks, mask_file) 300 | return masks 301 | 302 | def __getitem__(self, index): 303 | img_id = self.ids[index] 304 | path = img_id["file_name"] 305 | iid = img_id["image_id"] 306 | img = Image.open(os.path.join(self.detail.img_folder, path)).convert("RGB") 307 | if self.mode == "test": 308 | if self.transform is not None: 309 | img = self.transform(img) 310 | return img, os.path.basename(path) 311 | # convert mask to 60 categories 312 | mask = self.masks[iid] 313 | # synchrosized transform 314 | if self.mode == "train": 315 | img, mask = self._sync_transform(img, mask) 316 | elif self.mode == "val": 317 | img, mask = self._val_sync_transform(img, mask) 318 | else: 319 | assert self.mode == "testval" 320 | mask = self._mask_transform(mask) 321 | # general resize, normalize and toTensor 322 | if self.transform is not None: 323 | img = self.transform(img) 324 | if self.target_transform is not None: 325 | mask = self.target_transform(mask) 326 | return img, mask 327 | 328 | def _mask_transform(self, mask): 329 | target = np.array(mask).astype("int32") - 1 330 | return torch.from_numpy(target).long() 331 | 332 | def __len__(self): 333 | return len(self.ids) 334 | 335 | @property 336 | def pred_offset(self): 337 | return 1 338 | -------------------------------------------------------------------------------- /nbdt/data/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class InverseNormalize: 5 | def __init__(self, mean, std): 6 | self.mean = torch.Tensor(mean)[None, :, None, None] 7 | self.std = torch.Tensor(std)[None, :, None, None] 8 | 9 | def __call__(self, sample): 10 | return (sample * self.std) + self.mean 11 | 12 | def to(self, device): 13 | self.mean = self.mean.to(device) 14 | self.std = self.std.to(device) 15 | return self 16 | -------------------------------------------------------------------------------- /nbdt/hierarchies/CIFAR10/graph-induced-ResNet10.json: -------------------------------------------------------------------------------- 1 | {"directed": true, "multigraph": false, "graph": {}, "nodes": [{"label": "airplane", "id": "n02691156"}, {"label": "car", "id": "n02958343"}, {"label": "bird", "id": "n01503061"}, {"label": "cat", "id": "n02121620"}, {"label": "deer", "id": "n02430045"}, {"label": "dog", "id": "n02084071"}, {"label": "frog", "id": "n01639765"}, {"label": "horse", "id": "n02374451"}, {"label": "ship", "id": "n04194289"}, {"label": "truck", "id": "n04490091"}, {"label": "motor_vehicle", "id": "n03791235"}, {"label": "vehicle", "id": "n04524313"}, {"label": "placental", "id": "n01886756"}, {"label": "vertebrate", "id": "n01471682"}, {"label": "conveyance", "id": "n03100490"}, {"label": "mammal", "id": "n01861778"}, {"label": "whole", "id": "n00003553"}, {"label": "chordate", "id": "n01466257"}, {"label": "object", "id": "n00002684"}], "links": 
[{"source": "n03791235", "target": "n02958343"}, {"source": "n03791235", "target": "n04490091"}, {"source": "n04524313", "target": "n04194289"}, {"source": "n04524313", "target": "n03791235"}, {"source": "n01886756", "target": "n02084071"}, {"source": "n01886756", "target": "n02374451"}, {"source": "n01471682", "target": "n02121620"}, {"source": "n01471682", "target": "n01639765"}, {"source": "n03100490", "target": "n02691156"}, {"source": "n03100490", "target": "n04524313"}, {"source": "n01861778", "target": "n02430045"}, {"source": "n01861778", "target": "n01886756"}, {"source": "n00003553", "target": "n01471682"}, {"source": "n00003553", "target": "n03100490"}, {"source": "n01466257", "target": "n01503061"}, {"source": "n01466257", "target": "n01861778"}, {"source": "n00002684", "target": "n00003553"}, {"source": "n00002684", "target": "n01466257"}]} -------------------------------------------------------------------------------- /nbdt/hierarchies/CIFAR10/graph-induced-ResNet18.json: -------------------------------------------------------------------------------- 1 | {"directed": true, "multigraph": false, "graph": {}, "nodes": [{"label": "airplane", "id": "n02691156"}, {"label": "car", "id": "n02958343"}, {"label": "bird", "id": "n01503061"}, {"label": "cat", "id": "n02121620"}, {"label": "deer", "id": "n02430045"}, {"label": "dog", "id": "n02084071"}, {"label": "frog", "id": "n01639765"}, {"label": "horse", "id": "n02374451"}, {"label": "ship", "id": "n04194289"}, {"label": "truck", "id": "n04490091"}, {"label": "motor_vehicle", "id": "n03791235"}, {"label": "carnivore", "id": "n02075296"}, {"label": "craft", "id": "n03125870"}, {"label": "vertebrate", "id": "n01471682"}, {"label": "ungulate", "id": "n02370806"}, {"label": "vehicle", "id": "n04524313"}, {"label": "placental", "id": "n01886756"}, {"label": "chordate", "id": "n01466257"}, {"label": "whole", "id": "n00003553"}], "links": [{"source": "n03791235", "target": "n02958343"}, {"source": "n03791235", "target": "n04490091"}, {"source": "n02075296", "target": "n02121620"}, {"source": "n02075296", "target": "n02084071"}, {"source": "n03125870", "target": "n02691156"}, {"source": "n03125870", "target": "n04194289"}, {"source": "n01471682", "target": "n01503061"}, {"source": "n01471682", "target": "n01639765"}, {"source": "n02370806", "target": "n02430045"}, {"source": "n02370806", "target": "n02374451"}, {"source": "n04524313", "target": "n03791235"}, {"source": "n04524313", "target": "n03125870"}, {"source": "n01886756", "target": "n02075296"}, {"source": "n01886756", "target": "n02370806"}, {"source": "n01466257", "target": "n01471682"}, {"source": "n01466257", "target": "n01886756"}, {"source": "n00003553", "target": "n04524313"}, {"source": "n00003553", "target": "n01466257"}]} -------------------------------------------------------------------------------- /nbdt/hierarchies/CIFAR10/graph-induced-wrn28_10_cifar10.json: -------------------------------------------------------------------------------- 1 | {"directed": true, "multigraph": false, "graph": {}, "nodes": [{"label": "airplane", "id": "n02691156"}, {"label": "car", "id": "n02958343"}, {"label": "bird", "id": "n01503061"}, {"label": "cat", "id": "n02121620"}, {"label": "deer", "id": "n02430045"}, {"label": "dog", "id": "n02084071"}, {"label": "frog", "id": "n01639765"}, {"label": "horse", "id": "n02374451"}, {"label": "ship", "id": "n04194289"}, {"label": "truck", "id": "n04490091"}, {"label": "carnivore", "id": "n02075296"}, {"label": "motor_vehicle", "id": "n03791235"}, 
{"label": "craft", "id": "n03125870"}, {"label": "ungulate", "id": "n02370806"}, {"label": "vertebrate", "id": "n01471682"}, {"label": "chordate", "id": "n01466257"}, {"label": "animal", "id": "n00015388"}, {"label": "vehicle", "id": "n04524313"}, {"label": "whole", "id": "n00003553"}], "links": [{"source": "n02075296", "target": "n02121620"}, {"source": "n02075296", "target": "n02084071"}, {"source": "n03791235", "target": "n02958343"}, {"source": "n03791235", "target": "n04490091"}, {"source": "n03125870", "target": "n02691156"}, {"source": "n03125870", "target": "n04194289"}, {"source": "n02370806", "target": "n02430045"}, {"source": "n02370806", "target": "n02374451"}, {"source": "n01471682", "target": "n01503061"}, {"source": "n01471682", "target": "n01639765"}, {"source": "n01466257", "target": "n02075296"}, {"source": "n01466257", "target": "n01471682"}, {"source": "n00015388", "target": "n02370806"}, {"source": "n00015388", "target": "n01466257"}, {"source": "n04524313", "target": "n03791235"}, {"source": "n04524313", "target": "n03125870"}, {"source": "n00003553", "target": "n00015388"}, {"source": "n00003553", "target": "n04524313"}]} -------------------------------------------------------------------------------- /nbdt/hierarchies/CIFAR10/graph-induced.json: -------------------------------------------------------------------------------- 1 | graph-induced-wrn28_10_cifar10.json -------------------------------------------------------------------------------- /nbdt/hierarchies/CIFAR10/graph-wordnet.json: -------------------------------------------------------------------------------- 1 | {"directed": true, "multigraph": false, "graph": {}, "nodes": [{"label": "airplane", "contraction": {"n03510583": {"label": "heavier-than-air_craft"}, "n02686568": {"label": "aircraft"}}, "id": "n02691156"}, {"label": "craft", "contraction": {"n04524313": {"label": "vehicle"}, "n03100490": {"label": "conveyance"}}, "id": "n03125870"}, {"label": "instrumentality", "contraction": {"n00021939": {"label": "artifact"}}, "id": "n03575240"}, {"label": "whole", "contraction": {"n00002684": {"label": "object"}, "n00001930": {"label": "physical_entity"}, "n00001740": {"label": "entity"}}, "id": "n00003553"}, {"label": "car", "id": "n02958343"}, {"label": "motor_vehicle", "contraction": {"n04170037": {"label": "self-propelled_vehicle"}, "n04576211": {"label": "wheeled_vehicle"}, "n03094503": {"label": "container"}}, "id": "n03791235"}, {"label": "bird", "id": "n01503061"}, {"label": "vertebrate", "contraction": {"n01466257": {"label": "chordate"}, "n00015388": {"label": "animal"}, "n00004475": {"label": "organism"}, "n00004258": {"label": "living_thing"}}, "id": "n01471682"}, {"label": "cat", "contraction": {"n02120997": {"label": "feline"}}, "id": "n02121620"}, {"label": "carnivore", "id": "n02075296"}, {"label": "placental", "contraction": {"n01861778": {"label": "mammal"}}, "id": "n01886756"}, {"label": "deer", "contraction": {"n02399000": {"label": "ruminant"}, "n02394477": {"label": "even-toed_ungulate"}}, "id": "n02430045"}, {"label": "ungulate", "id": "n02370806"}, {"label": "dog", "contraction": {"n02083346": {"label": "canine"}}, "id": "n02084071"}, {"label": "frog", "contraction": {"n01627424": {"label": "amphibian"}}, "id": "n01639765"}, {"label": "horse", "contraction": {"n02374149": {"label": "equine"}, "n02373336": {"label": "odd-toed_ungulate"}}, "id": "n02374451"}, {"label": "ship", "contraction": {"n04530566": {"label": "vessel"}}, "id": "n04194289"}, {"label": "truck", "id": "n04490091"}], 
"links": [{"source": "n03125870", "target": "n02691156"}, {"source": "n03125870", "target": "n04194289"}, {"source": "n03575240", "target": "n03125870"}, {"source": "n03575240", "target": "n03791235"}, {"source": "n00003553", "target": "n03575240"}, {"source": "n00003553", "target": "n01471682"}, {"source": "n03791235", "target": "n02958343"}, {"source": "n03791235", "target": "n04490091"}, {"source": "n01471682", "target": "n01503061"}, {"source": "n01471682", "target": "n01886756"}, {"source": "n01471682", "target": "n01639765"}, {"source": "n02075296", "target": "n02121620"}, {"source": "n02075296", "target": "n02084071"}, {"source": "n01886756", "target": "n02075296"}, {"source": "n01886756", "target": "n02370806"}, {"source": "n02370806", "target": "n02430045"}, {"source": "n02370806", "target": "n02374451"}]} -------------------------------------------------------------------------------- /nbdt/hierarchies/CIFAR100/graph-induced-ResNet18.json: -------------------------------------------------------------------------------- 1 | {"directed": true, "multigraph": false, "graph": {}, "nodes": [{"label": "apple", "id": "n07739125"}, {"label": "fingerling", "id": "n02512752"}, {"label": "baby", "id": "n09827683"}, {"label": "bear", "id": "n02131653"}, {"label": "beaver", "id": "n02363005"}, {"label": "bed", "id": "n02818832"}, {"label": "bee", "id": "n02206856"}, {"label": "beetle", "id": "n02164464"}, {"label": "bicycle", "id": "n02834778"}, {"label": "bottle", "id": "n02876657"}, {"label": "bowl", "id": "n02881193"}, {"label": "male_child", "id": "n10285313"}, {"label": "bridge", "id": "n02898711"}, {"label": "bus", "id": "n02924116"}, {"label": "butterfly", "id": "n02274259"}, {"label": "camel", "id": "n02437136"}, {"label": "can", "id": "n02946921"}, {"label": "castle", "id": "n02980441"}, {"label": "caterpillar", "id": "n02309337"}, {"label": "cattle", "id": "n02402425"}, {"label": "chair", "id": "n03001627"}, {"label": "chimpanzee", "id": "n02481823"}, {"label": "clock", "id": "n03046257"}, {"label": "cloud", "id": "n11439690"}, {"label": "cockroach", "id": "n02233338"}, {"label": "sofa", "id": "n04256520"}, {"label": "crab", "id": "n01976957"}, {"label": "crocodile", "id": "n01697178"}, {"label": "cup", "id": "n03147509"}, {"label": "dinosaur", "id": "n01699831"}, {"label": "dolphinfish", "id": "n02581957"}, {"label": "elephant", "id": "n02503517"}, {"label": "flatfish", "id": "n02657368"}, {"label": "forest", "id": "n08438533"}, {"label": "fox", "id": "n02118333"}, {"label": "girl", "id": "n10129825"}, {"label": "hamster", "id": "n02342885"}, {"label": "house", "id": "n03544360"}, {"label": "kangaroo", "id": "n01877134"}, {"label": "keyboard", "id": "n03614007"}, {"label": "lamp", "id": "n03636248"}, {"label": "lawn_mower", "id": "n03649909"}, {"label": "leopard", "id": "n02128385"}, {"label": "lion", "id": "n02129165"}, {"label": "lizard", "id": "n01674464"}, {"label": "lobster", "id": "n01982650"}, {"label": "man", "id": "n10287213"}, {"label": "maple", "id": "n12752205"}, {"label": "motorcycle", "id": "n03790512"}, {"label": "mountain", "id": "n09359803"}, {"label": "mouse", "id": "n02330245"}, {"label": "mushroom", "id": "n13001041"}, {"label": "oak", "id": "n12268246"}, {"label": "orange", "id": "n07747607"}, {"label": "orchid", "id": "n12041446"}, {"label": "otter", "id": "n02444819"}, {"label": "palm", "id": "n12582231"}, {"label": "pear", "id": "n07767847"}, {"label": "pickup", "id": "n03930630"}, {"label": "pine", "id": "n11608250"}, {"label": "plain", "id": "n09393605"}, 
{"label": "plate", "id": "n03959485"}, {"label": "poppy", "id": "n11900569"}, {"label": "porcupine", "id": "n02346627"}, {"label": "opossum", "id": "n01874928"}, {"label": "rabbit", "id": "n02324045"}, {"label": "raccoon", "id": "n02508021"}, {"label": "ray", "id": "n01495701"}, {"label": "road", "id": "n04096066"}, {"label": "rocket", "id": "n04099429"}, {"label": "rose", "id": "n12620196"}, {"label": "sea", "id": "n09426788"}, {"label": "seal", "id": "n02076196"}, {"label": "shark", "id": "n01482330"}, {"label": "shrew", "id": "n01891633"}, {"label": "shutout", "id": "n07476495"}, {"label": "skyscraper", "id": "n04233124"}, {"label": "snail", "id": "n01944390"}, {"label": "snake", "id": "n01726692"}, {"label": "spider", "id": "n01772222"}, {"label": "squirrel", "id": "n02355227"}, {"label": "streetcar", "id": "n04335435"}, {"label": "sunflower", "id": "n11978233"}, {"label": "sweet_pepper", "id": "n12901264"}, {"label": "table", "id": "n04379243"}, {"label": "tank", "id": "n04389033"}, {"label": "telephone", "id": "n04401088"}, {"label": "television", "id": "n06277280"}, {"label": "tiger", "id": "n02129604"}, {"label": "tractor", "id": "n04465501"}, {"label": "train", "id": "n04468005"}, {"label": "trout", "id": "n07794452"}, {"label": "tulip", "id": "n12454159"}, {"label": "turtle", "id": "n01662784"}, {"label": "wardrobe", "id": "n04550184"}, {"label": "whale", "id": "n02062744"}, {"label": "willow", "id": "n12724942"}, {"label": "wolf", "id": "n02114100"}, {"label": "woman", "id": "n10787470"}, {"label": "worm", "id": "n01922303"}, {"id": "f00000100"}, {"id": "f00000101"}, {"id": "f00000102"}, {"id": "f00000103"}, {"id": "f00000104"}, {"id": "f00000105"}, {"id": "f00000106"}, {"id": "f00000107"}, {"id": "f00000108"}, {"id": "f00000109"}, {"id": "f00000110"}, {"id": "f00000111"}, {"id": "f00000112"}, {"id": "f00000113"}, {"id": "f00000114"}, {"id": "f00000115"}, {"id": "f00000116"}, {"id": "f00000117"}, {"id": "f00000118"}, {"id": "f00000119"}, {"id": "f00000120"}, {"id": "f00000121"}, {"id": "f00000122"}, {"id": "f00000123"}, {"id": "f00000124"}, {"id": "f00000125"}, {"id": "f00000126"}, {"id": "f00000127"}, {"id": "f00000128"}, {"id": "f00000129"}, {"id": "f00000130"}, {"id": "f00000131"}, {"id": "f00000132"}, {"id": "f00000133"}, {"id": "f00000134"}, {"id": "f00000135"}, {"id": "f00000136"}, {"id": "f00000137"}, {"id": "f00000138"}, {"id": "f00000139"}, {"id": "f00000140"}, {"id": "f00000141"}, {"id": "f00000142"}, {"id": "f00000143"}, {"id": "f00000144"}, {"id": "f00000145"}, {"id": "f00000146"}, {"id": "f00000147"}, {"id": "f00000148"}, {"id": "f00000149"}, {"id": "f00000150"}, {"id": "f00000151"}, {"id": "f00000152"}, {"id": "f00000153"}, {"id": "f00000154"}, {"id": "f00000155"}, {"id": "f00000156"}, {"id": "f00000157"}, {"id": "f00000158"}, {"id": "f00000159"}, {"id": "f00000160"}, {"id": "f00000161"}, {"id": "f00000162"}, {"id": "f00000163"}, {"id": "f00000164"}, {"id": "f00000165"}, {"id": "f00000166"}, {"id": "f00000167"}, {"id": "f00000168"}, {"id": "f00000169"}, {"id": "f00000170"}, {"id": "f00000171"}, {"id": "f00000172"}, {"id": "f00000173"}, {"id": "f00000174"}, {"id": "f00000175"}, {"id": "f00000176"}, {"id": "f00000177"}, {"id": "f00000178"}, {"id": "f00000179"}, {"id": "f00000180"}, {"id": "f00000181"}, {"id": "f00000182"}, {"id": "f00000183"}, {"id": "f00000184"}, {"id": "f00000185"}, {"id": "f00000186"}, {"id": "f00000187"}, {"id": "f00000188"}, {"id": "f00000189"}, {"id": "f00000190"}, {"id": "f00000191"}, {"id": "f00000192"}, {"id": "f00000193"}, {"id": 
"f00000194"}, {"id": "f00000195"}, {"id": "f00000196"}, {"id": "f00000197"}, {"id": "f00000198"}], "links": [{"source": "f00000100", "target": "n07739125"}, {"source": "f00000100", "target": "n07747607"}, {"source": "f00000101", "target": "n09393605"}, {"source": "f00000101", "target": "n09426788"}, {"source": "f00000102", "target": "n02834778"}, {"source": "f00000102", "target": "n03790512"}, {"source": "f00000103", "target": "n12901264"}, {"source": "f00000103", "target": "f00000100"}, {"source": "f00000104", "target": "n11439690"}, {"source": "f00000104", "target": "f00000101"}, {"source": "f00000105", "target": "n06277280"}, {"source": "f00000105", "target": "n04550184"}, {"source": "f00000106", "target": "n02980441"}, {"source": "f00000106", "target": "n04233124"}, {"source": "f00000107", "target": "n11900569"}, {"source": "f00000107", "target": "n11978233"}, {"source": "f00000108", "target": "n02924116"}, {"source": "f00000108", "target": "n03930630"}, {"source": "f00000109", "target": "n12752205"}, {"source": "f00000109", "target": "n12268246"}, {"source": "f00000110", "target": "n02818832"}, {"source": "f00000110", "target": "n03001627"}, {"source": "f00000111", "target": "n03649909"}, {"source": "f00000111", "target": "n04465501"}, {"source": "f00000112", "target": "n02129165"}, {"source": "f00000112", "target": "n02129604"}, {"source": "f00000113", "target": "n09359803"}, {"source": "f00000113", "target": "f00000106"}, {"source": "f00000114", "target": "n03614007"}, {"source": "f00000114", "target": "n04401088"}, {"source": "f00000115", "target": "n04335435"}, {"source": "f00000115", "target": "n04468005"}, {"source": "f00000116", "target": "n03544360"}, {"source": "f00000116", "target": "n12582231"}, {"source": "f00000117", "target": "n08438533"}, {"source": "f00000117", "target": "n04096066"}, {"source": "f00000118", "target": "n07767847"}, {"source": "f00000118", "target": "f00000103"}, {"source": "f00000119", "target": "n12041446"}, {"source": "f00000119", "target": "n12620196"}, {"source": "f00000120", "target": "n02581957"}, {"source": "f00000120", "target": "n02062744"}, {"source": "f00000121", "target": "n04256520"}, {"source": "f00000121", "target": "f00000110"}, {"source": "f00000122", "target": "n12454159"}, {"source": "f00000122", "target": "f00000119"}, {"source": "f00000123", "target": "n02508021"}, {"source": "f00000123", "target": "n07476495"}, {"source": "f00000124", "target": "n02164464"}, {"source": "f00000124", "target": "n02233338"}, {"source": "f00000125", "target": "n02128385"}, {"source": "f00000125", "target": "f00000112"}, {"source": "f00000126", "target": "n03046257"}, {"source": "f00000126", "target": "n03959485"}, {"source": "f00000127", "target": "n01697178"}, {"source": "f00000127", "target": "n04389033"}, {"source": "f00000128", "target": "n02481823"}, {"source": "f00000128", "target": "n02503517"}, {"source": "f00000129", "target": "n12724942"}, {"source": "f00000129", "target": "f00000109"}, {"source": "f00000130", "target": "n02512752"}, {"source": "f00000130", "target": "n07794452"}, {"source": "f00000131", "target": "n01726692"}, {"source": "f00000131", "target": "n01922303"}, {"source": "f00000132", "target": "n02876657"}, {"source": "f00000132", "target": "n02946921"}, {"source": "f00000133", "target": "n02881193"}, {"source": "f00000133", "target": "n03147509"}, {"source": "f00000134", "target": "n02342885"}, {"source": "f00000134", "target": "n02330245"}, {"source": "f00000135", "target": "n04099429"}, {"source": "f00000135", "target": 
"f00000113"}, {"source": "f00000136", "target": "n02118333"}, {"source": "f00000136", "target": "n02114100"}, {"source": "f00000137", "target": "n01976957"}, {"source": "f00000137", "target": "n01772222"}, {"source": "f00000138", "target": "n02346627"}, {"source": "f00000138", "target": "f00000123"}, {"source": "f00000139", "target": "n11608250"}, {"source": "f00000139", "target": "f00000129"}, {"source": "f00000140", "target": "n10129825"}, {"source": "f00000140", "target": "n10787470"}, {"source": "f00000141", "target": "n01482330"}, {"source": "f00000141", "target": "f00000120"}, {"source": "f00000142", "target": "f00000107"}, {"source": "f00000142", "target": "f00000122"}, {"source": "f00000143", "target": "n04379243"}, {"source": "f00000143", "target": "f00000121"}, {"source": "f00000144", "target": "n13001041"}, {"source": "f00000144", "target": "n01944390"}, {"source": "f00000145", "target": "n02206856"}, {"source": "f00000145", "target": "n02309337"}, {"source": "f00000146", "target": "n02437136"}, {"source": "f00000146", "target": "n02402425"}, {"source": "f00000147", "target": "f00000125"}, {"source": "f00000147", "target": "f00000136"}, {"source": "f00000148", "target": "f00000102"}, {"source": "f00000148", "target": "f00000111"}, {"source": "f00000149", "target": "n10285313"}, {"source": "f00000149", "target": "n10287213"}, {"source": "f00000150", "target": "n02274259"}, {"source": "f00000150", "target": "f00000145"}, {"source": "f00000151", "target": "f00000108"}, {"source": "f00000151", "target": "f00000115"}, {"source": "f00000152", "target": "f00000105"}, {"source": "f00000152", "target": "f00000143"}, {"source": "f00000153", "target": "n02898711"}, {"source": "f00000153", "target": "f00000116"}, {"source": "f00000154", "target": "n01874928"}, {"source": "f00000154", "target": "f00000138"}, {"source": "f00000155", "target": "n01877134"}, {"source": "f00000155", "target": "f00000146"}, {"source": "f00000156", "target": "n02131653"}, {"source": "f00000156", "target": "f00000128"}, {"source": "f00000157", "target": "n02657368"}, {"source": "f00000157", "target": "n01495701"}, {"source": "f00000158", "target": "f00000104"}, {"source": "f00000158", "target": "f00000117"}, {"source": "f00000159", "target": "n09827683"}, {"source": "f00000159", "target": "f00000140"}, {"source": "f00000160", "target": "f00000126"}, {"source": "f00000160", "target": "f00000133"}, {"source": "f00000161", "target": "n01699831"}, {"source": "f00000161", "target": "n01662784"}, {"source": "f00000162", "target": "n01982650"}, {"source": "f00000162", "target": "f00000137"}, {"source": "f00000163", "target": "n02363005"}, {"source": "f00000163", "target": "n02444819"}, {"source": "f00000164", "target": "n02324045"}, {"source": "f00000164", "target": "n02355227"}, {"source": "f00000165", "target": "f00000135"}, {"source": "f00000165", "target": "f00000153"}, {"source": "f00000166", "target": "f00000149"}, {"source": "f00000166", "target": "f00000159"}, {"source": "f00000167", "target": "f00000114"}, {"source": "f00000167", "target": "f00000132"}, {"source": "f00000168", "target": "n03636248"}, {"source": "f00000168", "target": "f00000160"}, {"source": "f00000169", "target": "n01674464"}, {"source": "f00000169", "target": "f00000131"}, {"source": "f00000170", "target": "f00000130"}, {"source": "f00000170", "target": "f00000161"}, {"source": "f00000171", "target": "f00000124"}, {"source": "f00000171", "target": "f00000150"}, {"source": "f00000172", "target": "f00000134"}, {"source": "f00000172", "target": 
"f00000164"}, {"source": "f00000173", "target": "n01891633"}, {"source": "f00000173", "target": "f00000144"}, {"source": "f00000174", "target": "n02076196"}, {"source": "f00000174", "target": "f00000163"}, {"source": "f00000175", "target": "f00000127"}, {"source": "f00000175", "target": "f00000170"}, {"source": "f00000176", "target": "f00000158"}, {"source": "f00000176", "target": "f00000165"}, {"source": "f00000177", "target": "f00000155"}, {"source": "f00000177", "target": "f00000156"}, {"source": "f00000178", "target": "f00000162"}, {"source": "f00000178", "target": "f00000171"}, {"source": "f00000179", "target": "f00000154"}, {"source": "f00000179", "target": "f00000172"}, {"source": "f00000180", "target": "f00000148"}, {"source": "f00000180", "target": "f00000151"}, {"source": "f00000181", "target": "f00000152"}, {"source": "f00000181", "target": "f00000167"}, {"source": "f00000182", "target": "f00000141"}, {"source": "f00000182", "target": "f00000157"}, {"source": "f00000183", "target": "f00000169"}, {"source": "f00000183", "target": "f00000173"}, {"source": "f00000184", "target": "f00000118"}, {"source": "f00000184", "target": "f00000142"}, {"source": "f00000185", "target": "f00000147"}, {"source": "f00000185", "target": "f00000177"}, {"source": "f00000186", "target": "f00000175"}, {"source": "f00000186", "target": "f00000183"}, {"source": "f00000187", "target": "f00000168"}, {"source": "f00000187", "target": "f00000181"}, {"source": "f00000188", "target": "f00000174"}, {"source": "f00000188", "target": "f00000179"}, {"source": "f00000189", "target": "f00000139"}, {"source": "f00000189", "target": "f00000176"}, {"source": "f00000190", "target": "f00000182"}, {"source": "f00000190", "target": "f00000186"}, {"source": "f00000191", "target": "f00000185"}, {"source": "f00000191", "target": "f00000188"}, {"source": "f00000192", "target": "f00000178"}, {"source": "f00000192", "target": "f00000184"}, {"source": "f00000193", "target": "f00000180"}, {"source": "f00000193", "target": "f00000187"}, {"source": "f00000194", "target": "f00000166"}, {"source": "f00000194", "target": "f00000193"}, {"source": "f00000195", "target": "f00000190"}, {"source": "f00000195", "target": "f00000192"}, {"source": "f00000196", "target": "f00000189"}, {"source": "f00000196", "target": "f00000195"}, {"source": "f00000197", "target": "f00000194"}, {"source": "f00000197", "target": "f00000196"}, {"source": "f00000198", "target": "f00000191"}, {"source": "f00000198", "target": "f00000197"}]} -------------------------------------------------------------------------------- /nbdt/hierarchies/CIFAR100/graph-induced.json: -------------------------------------------------------------------------------- 1 | graph-induced-wrn28_10_cifar100.json -------------------------------------------------------------------------------- /nbdt/hierarchies/Cityscapes/graph-induced-HRNet-w18-v1.json: -------------------------------------------------------------------------------- 1 | {"directed": true, "multigraph": false, "graph": {}, "nodes": [{"label": "road", "id": "n04096066"}, {"label": "sidewalk", "id": "n04215402"}, {"label": "building", "id": "n02913152"}, {"label": "wall", "id": "n04546855"}, {"label": "fence", "id": "n03327234"}, {"label": "pole", "id": "n03976657"}, {"label": "traffic_light", "id": "n06874185"}, {"label": "street_sign", "id": "n06794110"}, {"label": "vegetation", "id": "n08436759"}, {"label": "terrain", "id": "n08674563"}, {"label": "sky", "id": "n09436708"}, {"label": "person", "id": "n00007846"}, {"label": 
"rider", "id": "n10530150"}, {"label": "car", "id": "n02958343"}, {"label": "truck", "id": "n04490091"}, {"label": "bus", "id": "n02924116"}, {"label": "train", "id": "n04468005"}, {"label": "motorcycle", "id": "n03790512"}, {"label": "bicycle", "id": "n02834778"}, {"label": "wheeled_vehicle", "id": "n04576211"}, {"label": "cyclist", "id": "n00003553"}, {"label": "public_transport", "id": "n04019101"}, {"label": "signage", "id": "n00033020"}, {"label": "long_vehicle", "id": "n03100490"}, {"label": "structure", "id": "n04341686"}, {"label": "pole-like", "id": "n00001740"}, {"label": "people", "id": "n00002684"}, {"label": "physical_entity", "id": "n00001930"}, {"label": "(generated)", "id": "f00000028"}, {"label": "(generated)", "id": "f00000029"}, {"label": "pavement", "id": "f00000030"}, {"label": "(generated)", "id": "f00000031", "alt": "people, long_vehicle, pole-like, sky"}, {"label": "(generated)", "id": "f00000032", "alt": "vehicle, people, pole-like, sky"}, {"label": "(generated)", "id": "f00000033"}, {"label": "landscape", "id": "f00000034"}, {"label": "(generated)", "id": "f00000035"}, {"label": "(generated)", "id": "f00000036"}], "links": [{"source": "n04576211", "target": "n03790512"}, {"source": "n04576211", "target": "n02834778"}, {"source": "n00003553", "target": "n10530150"}, {"source": "n00003553", "target": "n04576211"}, {"source": "n04019101", "target": "n02924116"}, {"source": "n04019101", "target": "n04468005"}, {"source": "n00033020", "target": "n06874185"}, {"source": "n00033020", "target": "n06794110"}, {"source": "n03100490", "target": "n04490091"}, {"source": "n03100490", "target": "n04019101"}, {"source": "n04341686", "target": "n04546855"}, {"source": "n04341686", "target": "n03327234"}, {"source": "n00001740", "target": "n03976657"}, {"source": "n00001740", "target": "n00033020"}, {"source": "n00002684", "target": "n00007846"}, {"source": "n00002684", "target": "n00003553"}, {"source": "n00001930", "target": "n08674563"}, {"source": "n00001930", "target": "n04341686"}, {"source": "f00000028", "target": "n09436708"}, {"source": "f00000028", "target": "n00001740"}, {"source": "f00000029", "target": "n03100490"}, {"source": "f00000029", "target": "f00000028"}, {"source": "f00000030", "target": "n04215402"}, {"source": "f00000030", "target": "n00001930"}, {"source": "f00000031", "target": "n00002684"}, {"source": "f00000031", "target": "f00000029"}, {"source": "f00000032", "target": "n02958343"}, {"source": "f00000032", "target": "f00000031"}, {"source": "f00000033", "target": "f00000030"}, {"source": "f00000033", "target": "f00000032"}, {"source": "f00000034", "target": "n02913152"}, {"source": "f00000034", "target": "n08436759"}, {"source": "f00000035", "target": "f00000033"}, {"source": "f00000035", "target": "f00000034"}, {"source": "f00000036", "target": "n04096066"}, {"source": "f00000036", "target": "f00000035"}]} 2 | -------------------------------------------------------------------------------- /nbdt/hierarchies/Cityscapes/graph-induced-HRNet-w48.json: -------------------------------------------------------------------------------- 1 | {"directed": true, "multigraph": false, "graph": {}, "nodes": [{"label": "road", "id": "n04096066"}, {"label": "sidewalk", "id": "n04215402"}, {"label": "building", "id": "n02913152"}, {"label": "wall", "id": "n04546855"}, {"label": "fence", "id": "n03327234"}, {"label": "pole", "id": "n03976657"}, {"label": "traffic_light", "id": "n06874185"}, {"label": "street_sign", "id": "n06794110"}, {"label": "vegetation", "id": 
"n08436759"}, {"label": "terrain", "id": "n08674563"}, {"label": "sky", "id": "n09436708"}, {"label": "person", "id": "n00007846"}, {"label": "rider", "id": "n10530150"}, {"label": "car", "id": "n02958343"}, {"label": "truck", "id": "n04490091"}, {"label": "bus", "id": "n02924116"}, {"label": "train", "id": "n04468005"}, {"label": "motorcycle", "id": "n03790512"}, {"label": "bicycle", "id": "n02834778"}, {"label": "whole", "id": "n00003553"}, {"label": "object", "id": "n00002684"}, {"label": "communication", "id": "n00033020"}, {"label": "public_transport", "id": "n04019101"}, {"label": "conveyance", "id": "n03100490"}, {"label": "structure", "id": "n04341686"}, {"label": "entity", "id": "n00001740"}, {"label": "physical_entity", "id": "n00001930"}, {"label": "(generated)", "id": "f00000027"}, {"label": "(generated)", "id": "f00000028"}, {"label": "(generated)", "id": "f00000029"}, {"label": "(generated)", "id": "f00000030"}, {"label": "(generated)", "id": "f00000031"}, {"label": "(generated)", "id": "f00000032"}, {"label": "(generated)", "id": "f00000033"}, {"label": "(generated)", "id": "f00000034"}, {"label": "(generated)", "id": "f00000035"}, {"label": "(generated)", "id": "f00000036"}], "links": [{"source": "n00003553", "target": "n10530150"}, {"source": "n00003553", "target": "n03790512"}, {"source": "n00002684", "target": "n02834778"}, {"source": "n00002684", "target": "n00003553"}, {"source": "n00033020", "target": "n06874185"}, {"source": "n00033020", "target": "n06794110"}, {"source": "n04019101", "target": "n02924116"}, {"source": "n04019101", "target": "n04468005"}, {"source": "n03100490", "target": "n04490091"}, {"source": "n03100490", "target": "n04019101"}, {"source": "n04341686", "target": "n04546855"}, {"source": "n04341686", "target": "n03327234"}, {"source": "n00001740", "target": "n03976657"}, {"source": "n00001740", "target": "n00033020"}, {"source": "n00001930", "target": "n00007846"}, {"source": "n00001930", "target": "n00002684"}, {"source": "f00000027", "target": "n08674563"}, {"source": "f00000027", "target": "n04341686"}, {"source": "f00000028", "target": "n03100490"}, {"source": "f00000028", "target": "n00001740"}, {"source": "f00000029", "target": "n09436708"}, {"source": "f00000029", "target": "f00000028"}, {"source": "f00000030", "target": "n00001930"}, {"source": "f00000030", "target": "f00000029"}, {"source": "f00000031", "target": "f00000027"}, {"source": "f00000031", "target": "f00000030"}, {"source": "f00000032", "target": "n02958343"}, {"source": "f00000032", "target": "f00000031"}, {"source": "f00000033", "target": "n04215402"}, {"source": "f00000033", "target": "f00000032"}, {"source": "f00000034", "target": "n02913152"}, {"source": "f00000034", "target": "n08436759"}, {"source": "f00000035", "target": "f00000033"}, {"source": "f00000035", "target": "f00000034"}, {"source": "f00000036", "target": "n04096066"}, {"source": "f00000036", "target": "f00000035"}]} -------------------------------------------------------------------------------- /nbdt/hierarchies/Imagenet1000/graph-induced.json: -------------------------------------------------------------------------------- 1 | graph-induced-efficientnet_b7b.json -------------------------------------------------------------------------------- /nbdt/hierarchies/LookIntoPerson/graph-induced-HRNet-w48-cls20.json: -------------------------------------------------------------------------------- 1 | {"directed": true, "multigraph": false, "graph": {}, "nodes": [{"label": "background", "id": "n05933834"}, {"label": 
"hat", "id": "n03497657"}, {"label": "hair", "id": "n05254795"}, {"label": "glove", "id": "n03441112"}, {"label": "sunglasses", "id": "n04356056"}, {"label": "top", "id": "n04453666"}, {"label": "dress", "id": "n03236735"}, {"label": "coat", "id": "n03057021"}, {"label": "sock", "id": "n04254777"}, {"label": "bloomers", "id": "n02854739"}, {"label": "jump_suit", "id": "n03605598"}, {"label": "scarf", "id": "n04143897"}, {"label": "skirt", "id": "n04231272"}, {"label": "face", "id": "n05600637"}, {"label": "arm", "id": "n05563770"}, {"label": "(generated)", "id": "f00000015"}, {"label": "leg", "id": "n05560787"}, {"label": "(generated)", "id": "f00000017"}, {"label": "shoe", "id": "n04199027"}, {"label": "(generated)", "id": "f00000019"}, {"label": "(generated)", "id": "f00000020"}, {"label": "artifact", "id": "n00021939"}, {"label": "(generated)", "id": "f00000022"}, {"label": "(generated)", "id": "f00000023"}, {"label": "whole", "id": "n00003553"}, {"label": "covering", "id": "n03122748"}, {"label": "object", "id": "n00002684"}, {"label": "physical_entity", "id": "n00001930"}, {"label": "(generated)", "id": "f00000028"}, {"label": "(generated)", "id": "f00000029"}, {"label": "entity", "id": "n00001740"}, {"label": "(generated)", "id": "f00000031"}, {"label": "(generated)", "id": "f00000032"}, {"label": "(generated)", "id": "f00000033"}, {"label": "(generated)", "id": "f00000034"}, {"label": "(generated)", "id": "f00000035"}, {"label": "clothing", "id": "n03051540"}, {"label": "(generated)", "id": "f00000037"}, {"label": "(generated)", "id": "f00000038"}], "links": [{"source": "f00000020", "target": "n04199027"}, {"source": "f00000020", "target": "f00000019"}, {"source": "n00021939", "target": "n03441112"}, {"source": "n00021939", "target": "n04356056"}, {"source": "f00000022", "target": "n05560787"}, {"source": "f00000022", "target": "f00000017"}, {"source": "f00000023", "target": "n04254777"}, {"source": "f00000023", "target": "f00000020"}, {"source": "n00003553", "target": "n04143897"}, {"source": "n00003553", "target": "n00021939"}, {"source": "n03122748", "target": "n03236735"}, {"source": "n03122748", "target": "n04231272"}, {"source": "n00002684", "target": "n03497657"}, {"source": "n00002684", "target": "n00003553"}, {"source": "n00001930", "target": "n03605598"}, {"source": "n00001930", "target": "n00002684"}, {"source": "f00000028", "target": "n05563770"}, {"source": "f00000028", "target": "f00000015"}, {"source": "f00000029", "target": "f00000022"}, {"source": "f00000029", "target": "f00000023"}, {"source": "n00001740", "target": "n03122748"}, {"source": "n00001740", "target": "n00001930"}, {"source": "f00000031", "target": "n05254795"}, {"source": "f00000031", "target": "n05600637"}, {"source": "f00000032", "target": "f00000029"}, {"source": "f00000032", "target": "n00001740"}, {"source": "f00000033", "target": "n02854739"}, {"source": "f00000033", "target": "f00000032"}, {"source": "f00000034", "target": "f00000028"}, {"source": "f00000034", "target": "f00000031"}, {"source": "f00000035", "target": "f00000033"}, {"source": "f00000035", "target": "f00000034"}, {"source": "n03051540", "target": "n04453666"}, {"source": "n03051540", "target": "n03057021"}, {"source": "f00000037", "target": "f00000035"}, {"source": "f00000037", "target": "n03051540"}, {"source": "f00000038", "target": "n05933834"}, {"source": "f00000038", "target": "f00000037"}]} -------------------------------------------------------------------------------- 
/nbdt/hierarchies/PascalContext/graph-induced-HRNet-w48-cls59.json: -------------------------------------------------------------------------------- 1 | {"directed": true, "multigraph": false, "graph": {}, "nodes": [{"label": "airplane", "id": "n02691156"}, {"label": "mountain", "id": "n09359803"}, {"label": "mouse", "id": "n02330245"}, {"label": "path", "id": "n09387222"}, {"label": "road", "id": "n04096066"}, {"label": "bag", "id": "n02773037"}, {"label": "minibike", "id": "n03769722"}, {"label": "fence", "id": "n03327234"}, {"label": "bed", "id": "n02818832"}, {"label": "bedclothes", "id": "n02820210"}, {"label": "bench", "id": "n02828884"}, {"label": "bicycle", "id": "n02834778"}, {"label": "dining_table", "id": "n03201208"}, {"label": "bird", "id": "n01503061"}, {"label": "person", "id": "n00007846"}, {"label": "floor", "id": "n03365592"}, {"label": "boat", "id": "n02858304"}, {"label": "train", "id": "n04468005"}, {"label": "book", "id": "n06410904"}, {"label": "bottle", "id": "n02876657"}, {"label": "tree", "id": "n13104059"}, {"label": "window", "id": "n04587648"}, {"label": "plate", "id": "n03959485"}, {"label": "platform", "id": "n03961939"}, {"label": "television_monitor", "id": "n04405762"}, {"label": "building", "id": "n02913152"}, {"label": "bus", "id": "n02924116"}, {"label": "cabinet", "id": "n02933112"}, {"label": "shelf", "id": "n04190052"}, {"label": "light", "id": "n11473954"}, {"label": "plant", "id": "n00017222"}, {"label": "wall", "id": "n04546855"}, {"label": "car", "id": "n02958343"}, {"label": "land", "id": "n09334396"}, {"label": "cat", "id": "n02121620"}, {"label": "sidewalk", "id": "n04215402"}, {"label": "truck", "id": "n04490091"}, {"label": "ceiling", "id": "n02990373"}, {"label": "rock", "id": "n09416076"}, {"label": "chair", "id": "n03001627"}, {"label": "wood", "id": "n15098161"}, {"label": "food", "id": "n00021265"}, {"label": "horse", "id": "n02374451"}, {"label": "fabric", "id": "n03309808"}, {"label": "sign", "id": "n06793231"}, {"label": "computer", "id": "n03082979"}, {"label": "sheep", "id": "n02411705"}, {"label": "keyboard", "id": "n03614007"}, {"label": "flower", "id": "n11669921"}, {"label": "sky", "id": "n09436708"}, {"label": "cow", "id": "n02403454"}, {"label": "grass", "id": "n12102133"}, {"label": "cup", "id": "n03147509"}, {"label": "curtain", "id": "n03151077"}, {"label": "snow", "id": "n11508382"}, {"label": "water", "id": "n14845743"}, {"label": "sofa", "id": "n04256520"}, {"label": "dog", "id": "n02084071"}, {"label": "door", "id": "n03221720"}, {"label": "artifact", "id": "n00021939"}, {"label": "whole", "id": "n00003553"}, {"label": "object", "id": "n00002684"}, {"label": "physical_entity", "id": "n00001930"}, {"label": "instrumentality", "id": "n03575240"}, {"label": "entity", "id": "n00001740"}, {"label": "(generated)", "id": "f00000065"}, {"label": "(generated)", "id": "f00000066"}, {"label": "(generated)", "id": "f00000067"}, {"label": "(generated)", "id": "f00000068"}, {"label": "(generated)", "id": "f00000069"}, {"label": "(generated)", "id": "f00000070"}, {"label": "(generated)", "id": "f00000071"}, {"label": "(generated)", "id": "f00000072"}, {"label": "(generated)", "id": "f00000073"}, {"label": "(generated)", "id": "f00000074"}, {"label": "(generated)", "id": "f00000075"}, {"label": "(generated)", "id": "f00000076"}, {"label": "(generated)", "id": "f00000077"}, {"label": "abstraction", "id": "n00002137"}, {"label": "(generated)", "id": "f00000079"}, {"label": "(generated)", "id": "f00000080"}, {"label": "(generated)", 
"id": "f00000081"}, {"label": "(generated)", "id": "f00000082"}, {"label": "(generated)", "id": "f00000083"}, {"label": "(generated)", "id": "f00000084"}, {"label": "(generated)", "id": "f00000085"}, {"label": "(generated)", "id": "f00000086"}, {"label": "(generated)", "id": "f00000087"}, {"label": "(generated)", "id": "f00000088"}, {"label": "(generated)", "id": "f00000089"}, {"label": "(generated)", "id": "f00000090"}, {"label": "(generated)", "id": "f00000091"}, {"label": "(generated)", "id": "f00000092"}, {"label": "(generated)", "id": "f00000093"}, {"label": "(generated)", "id": "f00000094"}, {"label": "(generated)", "id": "f00000095"}, {"label": "(generated)", "id": "f00000096"}, {"label": "(generated)", "id": "f00000097"}, {"label": "(generated)", "id": "f00000098"}, {"label": "(generated)", "id": "f00000099"}, {"label": "(generated)", "id": "f00000100"}, {"label": "(generated)", "id": "f00000101"}, {"label": "(generated)", "id": "f00000102"}, {"label": "(generated)", "id": "f00000103"}, {"label": "(generated)", "id": "f00000104"}, {"label": "(generated)", "id": "f00000105"}, {"label": "(generated)", "id": "f00000106"}, {"label": "(generated)", "id": "f00000107"}, {"label": "(generated)", "id": "f00000108"}, {"label": "(generated)", "id": "f00000109"}, {"label": "(generated)", "id": "f00000110"}, {"label": "(generated)", "id": "f00000111"}, {"label": "(generated)", "id": "f00000112"}, {"label": "(generated)", "id": "f00000113"}, {"label": "(generated)", "id": "f00000114"}, {"label": "(generated)", "id": "f00000115"}, {"label": "(generated)", "id": "f00000116"}], "links": [{"source": "n00021939", "target": "n04546855"}, {"source": "n00021939", "target": "n04215402"}, {"source": "n00003553", "target": "n02933112"}, {"source": "n00003553", "target": "n02990373"}, {"source": "n00002684", "target": "n06410904"}, {"source": "n00002684", "target": "n00021939"}, {"source": "n00001930", "target": "n13104059"}, {"source": "n00001930", "target": "n00003553"}, {"source": "n03575240", "target": "n03082979"}, {"source": "n03575240", "target": "n03151077"}, {"source": "n00001740", "target": "n02818832"}, {"source": "n00001740", "target": "n03309808"}, {"source": "f00000065", "target": "n03365592"}, {"source": "f00000065", "target": "n02958343"}, {"source": "f00000066", "target": "n09416076"}, {"source": "f00000066", "target": "n02403454"}, {"source": "f00000067", "target": "n04096066"}, {"source": "f00000067", "target": "n03575240"}, {"source": "f00000068", "target": "n09359803"}, {"source": "f00000068", "target": "n00001930"}, {"source": "f00000069", "target": "n02330245"}, {"source": "f00000069", "target": "n09387222"}, {"source": "f00000070", "target": "n02924116"}, {"source": "f00000070", "target": "n03001627"}, {"source": "f00000071", "target": "n02876657"}, {"source": "f00000071", "target": "n00017222"}, {"source": "f00000072", "target": "n02820210"}, {"source": "f00000072", "target": "f00000068"}, {"source": "f00000073", "target": "n03201208"}, {"source": "f00000073", "target": "n03961939"}, {"source": "f00000074", "target": "n00002684"}, {"source": "f00000074", "target": "n00001740"}, {"source": "f00000075", "target": "n02121620"}, {"source": "f00000075", "target": "n00021265"}, {"source": "f00000076", "target": "n04587648"}, {"source": "f00000076", "target": "n02084071"}, {"source": "f00000077", "target": "n12102133"}, {"source": "f00000077", "target": "f00000066"}, {"source": "n00002137", "target": "n15098161"}, {"source": "n00002137", "target": "n06793231"}, {"source": "f00000079", 
"target": "n03769722"}, {"source": "f00000079", "target": "n02374451"}, {"source": "f00000080", "target": "n03614007"}, {"source": "f00000080", "target": "f00000075"}, {"source": "f00000081", "target": "n02691156"}, {"source": "f00000081", "target": "n03327234"}, {"source": "f00000082", "target": "n11508382"}, {"source": "f00000082", "target": "f00000074"}, {"source": "f00000083", "target": "n02773037"}, {"source": "f00000083", "target": "n09334396"}, {"source": "f00000084", "target": "f00000065"}, {"source": "f00000084", "target": "f00000076"}, {"source": "f00000085", "target": "n09436708"}, {"source": "f00000085", "target": "f00000072"}, {"source": "f00000086", "target": "n02834778"}, {"source": "f00000086", "target": "f00000067"}, {"source": "f00000087", "target": "n04468005"}, {"source": "f00000087", "target": "f00000069"}, {"source": "f00000088", "target": "f00000071"}, {"source": "f00000088", "target": "f00000079"}, {"source": "f00000089", "target": "n03221720"}, {"source": "f00000089", "target": "f00000073"}, {"source": "f00000090", "target": "n02858304"}, {"source": "f00000090", "target": "n11669921"}, {"source": "f00000091", "target": "n04256520"}, {"source": "f00000091", "target": "f00000080"}, {"source": "f00000092", "target": "n01503061"}, {"source": "f00000092", "target": "f00000086"}, {"source": "f00000093", "target": "f00000070"}, {"source": "f00000093", "target": "f00000085"}, {"source": "f00000094", "target": "f00000081"}, {"source": "f00000094", "target": "f00000092"}, {"source": "f00000095", "target": "n00007846"}, {"source": "f00000095", "target": "n03959485"}, {"source": "f00000096", "target": "f00000084"}, {"source": "f00000096", "target": "f00000089"}, {"source": "f00000097", "target": "n04405762"}, {"source": "f00000097", "target": "f00000077"}, {"source": "f00000098", "target": "n02913152"}, {"source": "f00000098", "target": "f00000090"}, {"source": "f00000099", "target": "f00000083"}, {"source": "f00000099", "target": "f00000094"}, {"source": "f00000100", "target": "f00000088"}, {"source": "f00000100", "target": "f00000095"}, {"source": "f00000101", "target": "f00000087"}, {"source": "f00000101", "target": "f00000098"}, {"source": "f00000102", "target": "f00000097"}, {"source": "f00000102", "target": "f00000099"}, {"source": "f00000103", "target": "f00000082"}, {"source": "f00000103", "target": "f00000093"}, {"source": "f00000104", "target": "n02411705"}, {"source": "f00000104", "target": "f00000091"}, {"source": "f00000105", "target": "n14845743"}, {"source": "f00000105", "target": "f00000096"}, {"source": "f00000106", "target": "n00002137"}, {"source": "f00000106", "target": "f00000102"}, {"source": "f00000107", "target": "n04190052"}, {"source": "f00000107", "target": "n03147509"}, {"source": "f00000108", "target": "n02828884"}, {"source": "f00000108", "target": "n11473954"}, {"source": "f00000109", "target": "n04490091"}, {"source": "f00000109", "target": "f00000101"}, {"source": "f00000110", "target": "f00000104"}, {"source": "f00000110", "target": "f00000107"}, {"source": "f00000111", "target": "f00000103"}, {"source": "f00000111", "target": "f00000105"}, {"source": "f00000112", "target": "f00000106"}, {"source": "f00000112", "target": "f00000108"}, {"source": "f00000113", "target": "f00000100"}, {"source": "f00000113", "target": "f00000110"}, {"source": "f00000114", "target": "f00000109"}, {"source": "f00000114", "target": "f00000111"}, {"source": "f00000115", "target": "f00000112"}, {"source": "f00000115", "target": "f00000113"}, {"source": "f00000116", 
"target": "f00000114"}, {"source": "f00000116", "target": "f00000115"}]} -------------------------------------------------------------------------------- /nbdt/hierarchies/TinyImagenet200/graph-induced.json: -------------------------------------------------------------------------------- 1 | graph-induced-wrn28_10.json -------------------------------------------------------------------------------- /nbdt/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import defaultdict 5 | from nbdt.tree import Node, Tree 6 | from nbdt.model import HardEmbeddedDecisionRules, SoftEmbeddedDecisionRules 7 | from math import log 8 | from nbdt.utils import ( 9 | Colors, 10 | dataset_to_default_path_graph, 11 | dataset_to_default_path_wnids, 12 | hierarchy_to_path_graph, 13 | coerce_tensor, 14 | uncoerce_tensor, 15 | ) 16 | from pathlib import Path 17 | import os 18 | 19 | __all__ = names = ( 20 | "HardTreeSupLoss", 21 | "SoftTreeSupLoss", 22 | "SoftTreeLoss", 23 | "CrossEntropyLoss", 24 | ) 25 | 26 | 27 | def add_arguments(parser): 28 | parser.add_argument( 29 | "--xent-weight", "--xw", type=float, help="Weight for cross entropy term" 30 | ) 31 | parser.add_argument( 32 | "--xent-weight-end", 33 | "--xwe", 34 | type=float, 35 | help="Weight for cross entropy term at end of training." 36 | "If not set, set to cew", 37 | ) 38 | parser.add_argument( 39 | "--xent-weight-power", "--xwp", type=float, help="Raise progress to this power." 40 | ) 41 | parser.add_argument( 42 | "--tree-supervision-weight", 43 | "--tsw", 44 | type=float, 45 | default=1, 46 | help="Weight assigned to tree supervision losses", 47 | ) 48 | parser.add_argument( 49 | "--tree-supervision-weight-end", 50 | "--tswe", 51 | type=float, 52 | help="Weight assigned to tree supervision losses at " 53 | "end of training. If not set, this is equal to tsw", 54 | ) 55 | parser.add_argument( 56 | "--tree-supervision-weight-power", 57 | "--tswp", 58 | type=float, 59 | help="Raise progress to this power. > 1 to trend " 60 | "towards tswe more slowly. < 1 to trend more quickly", 61 | ) 62 | parser.add_argument( 63 | "--tree-start-epochs", 64 | "--tse", 65 | type=int, 66 | help="epoch count to start tree supervision loss from (generate tree at that pt)", 67 | ) 68 | parser.add_argument( 69 | "--tree-update-end-epochs", 70 | "--tuene", 71 | type=int, 72 | help="epoch count to stop generating new trees at", 73 | ) 74 | parser.add_argument( 75 | "--tree-update-every-epochs", 76 | "--tueve", 77 | type=int, 78 | help="Recompute tree from weights every (this many) epochs", 79 | ) 80 | 81 | 82 | def set_default_values(args): 83 | assert not ( 84 | args.hierarchy and args.path_graph 85 | ), "Only one, between --hierarchy and --path-graph can be provided." 
86 | if args.hierarchy and not args.path_graph: 87 | args.path_graph = hierarchy_to_path_graph(args.dataset, args.hierarchy) 88 | if not args.path_graph: 89 | args.path_graph = dataset_to_default_path_graph(args.dataset) 90 | if not args.path_wnids: 91 | args.path_wnids = dataset_to_default_path_wnids(args.dataset) 92 | 93 | 94 | CrossEntropyLoss = nn.CrossEntropyLoss 95 | 96 | 97 | class TreeSupLoss(nn.Module): 98 | 99 | accepts_tree = lambda tree, **kwargs: tree 100 | accepts_criterion = lambda criterion, **kwargs: criterion 101 | accepts_dataset = lambda trainset, **kwargs: trainset.__class__.__name__ 102 | accepts_path_graph = True 103 | accepts_path_wnids = True 104 | accepts_tree_supervision_weight = True 105 | accepts_classes = lambda trainset, **kwargs: trainset.classes 106 | accepts_hierarchy = True 107 | accepts_tree_supervision_weight_end = True 108 | accepts_tree_supervision_weight_power = True 109 | accepts_xent_weight = True 110 | accepts_xent_weight_end = True 111 | accepts_xent_weight_power = True 112 | 113 | def __init__( 114 | self, 115 | dataset, 116 | criterion, 117 | path_graph=None, 118 | path_wnids=None, 119 | classes=None, 120 | hierarchy=None, 121 | Rules=HardEmbeddedDecisionRules, 122 | tree=None, 123 | tree_supervision_weight=1.0, 124 | tree_supervision_weight_end=None, 125 | tree_supervision_weight_power=1, # 1 for linear 126 | xent_weight=1, 127 | xent_weight_end=None, 128 | xent_weight_power=1, 129 | ): 130 | super().__init__() 131 | 132 | if not tree: 133 | tree = Tree(dataset, path_graph, path_wnids, classes, hierarchy=hierarchy) 134 | self.num_classes = len(tree.classes) 135 | self.tree = tree 136 | self.rules = Rules(tree=tree) 137 | self.tree_supervision_weight = tree_supervision_weight 138 | self.tree_supervision_weight_end = ( 139 | tree_supervision_weight_end 140 | if tree_supervision_weight_end is not None 141 | else tree_supervision_weight 142 | ) 143 | self.tree_supervision_weight_power = tree_supervision_weight_power 144 | self.xent_weight = xent_weight 145 | self.xent_weight_end = ( 146 | xent_weight_end if xent_weight_end is not None else xent_weight 147 | ) 148 | self.xent_weight_power = xent_weight_power 149 | self.criterion = criterion 150 | self.progress = 1 151 | self.epochs = 0 152 | 153 | @staticmethod 154 | def assert_output_not_nbdt(outputs): 155 | """ 156 | >>> x = torch.randn(1, 3, 224, 224) 157 | >>> TreeSupLoss.assert_output_not_nbdt(x) # all good! 158 | >>> x._nbdt_output_flag = True 159 | >>> TreeSupLoss.assert_output_not_nbdt(x) #doctest: +ELLIPSIS 160 | Traceback (most recent call last): 161 | ... 162 | AssertionError: ... 163 | >>> from nbdt.model import NBDT 164 | >>> import torchvision.models as models 165 | >>> model = models.resnet18() 166 | >>> y = model(x) 167 | >>> TreeSupLoss.assert_output_not_nbdt(y) # all good! 168 | >>> model = NBDT('CIFAR10', model, arch='ResNet18') 169 | >>> y = model(x) 170 | >>> TreeSupLoss.assert_output_not_nbdt(y) #doctest: +ELLIPSIS 171 | Traceback (most recent call last): 172 | ... 173 | AssertionError: ... 174 | """ 175 | assert getattr(outputs, "_nbdt_output_flag", False) is False, ( 176 | "Uh oh! Looks like you passed an NBDT model's output to an NBDT " 177 | "loss. NBDT losses are designed to take in the *original* model's " 178 | "outputs, as input. NBDT models are designed to only be used " 179 | "during validation and inference, not during training. Confused? " 180 | " Check out github.com/alvinwan/nbdt#convert-neural-networks-to-decision-trees" 181 | " for examples and instructions." 
182 | ) 183 | 184 | def forward_tree(self, outputs, targets): 185 | raise NotImplementedError() 186 | 187 | def get_weight(self, start, end, power=1): 188 | progress = self.progress ** power 189 | return (1 - progress) * start + progress * end 190 | 191 | def forward(self, outputs, targets): 192 | loss_xent = self.criterion(outputs, targets) 193 | loss_tree = self.forward_tree(outputs, targets) 194 | 195 | tree_weight = self.get_weight( 196 | self.tree_supervision_weight, 197 | self.tree_supervision_weight_end, 198 | self.tree_supervision_weight_power, 199 | ) 200 | xent_weight = self.get_weight( 201 | self.xent_weight, self.xent_weight_end, self.xent_weight_power 202 | ) 203 | return loss_xent * xent_weight + loss_tree * tree_weight 204 | 205 | def set_epoch(self, cur, total): 206 | self.epochs = cur 207 | self.progress = cur / total 208 | if hasattr(super(), "set_epoch"): 209 | super().set_epoch(cur, total) 210 | 211 | 212 | class HardTreeSupLoss(TreeSupLoss): 213 | def forward_tree(self, outputs, targets): 214 | """ 215 | The supplementary losses are all uniformly down-weighted so that on 216 | average, each sample incurs half of its loss from standard cross entropy 217 | and half of its loss from all nodes. 218 | 219 | The code below is structured weirdly to minimize number of tensors 220 | constructed and moved from CPU to GPU or vice versa. In short, 221 | all outputs and targets for nodes with 2 children are gathered and 222 | moved onto GPU at once. Same with those with 3, with 4 etc. On CIFAR10, 223 | the max is 2. On CIFAR100, the max is 8. 224 | """ 225 | self.assert_output_not_nbdt(outputs) 226 | 227 | loss = 0 228 | num_losses = outputs.size(0) * len(self.tree.inodes) / 2.0 229 | 230 | outputs_subs = defaultdict(lambda: []) 231 | targets_subs = defaultdict(lambda: []) 232 | targets_ints = [int(target) for target in targets.cpu().long()] 233 | for node in self.tree.inodes: 234 | ( 235 | _, 236 | outputs_sub, 237 | targets_sub, 238 | ) = HardEmbeddedDecisionRules.get_node_logits_filtered( 239 | node, outputs, targets_ints 240 | ) 241 | 242 | key = node.num_classes 243 | assert outputs_sub.size(0) == len(targets_sub) 244 | outputs_subs[key].append(outputs_sub) 245 | targets_subs[key].extend(targets_sub) 246 | 247 | for key in outputs_subs: 248 | outputs_sub = torch.cat(outputs_subs[key], dim=0) 249 | targets_sub = torch.Tensor(targets_subs[key]).long().to(outputs_sub.device) 250 | 251 | if not outputs_sub.size(0): 252 | continue 253 | fraction = ( 254 | outputs_sub.size(0) / float(num_losses) * self.tree_supervision_weight 255 | ) 256 | loss += self.criterion(outputs_sub, targets_sub) * fraction 257 | return loss 258 | 259 | 260 | class SoftTreeSupLoss(TreeSupLoss): 261 | def __init__(self, *args, Rules=None, **kwargs): 262 | super().__init__(*args, Rules=SoftEmbeddedDecisionRules, **kwargs) 263 | 264 | def forward_tree(self, outputs, targets): 265 | self.assert_output_not_nbdt(outputs) 266 | return self.criterion(self.rules(outputs), targets) 267 | 268 | 269 | class SoftTreeLoss(SoftTreeSupLoss): 270 | 271 | accepts_tree_start_epochs = True 272 | accepts_tree_update_every_epochs = True 273 | accepts_tree_update_end_epochs = True 274 | accepts_arch = True 275 | accepts_net = lambda net, **kwargs: net 276 | accepts_checkpoint_path = lambda checkpoint_path, **kwargs: checkpoint_path 277 | 278 | def __init__( 279 | self, 280 | *args, 281 | arch=None, 282 | checkpoint_path="./", 283 | net=None, 284 | tree_start_epochs=67, 285 | tree_update_every_epochs=10, 286 | 
tree_update_end_epochs=120, 287 | **kwargs, 288 | ): 289 | super().__init__(*args, **kwargs) 290 | self.start_epochs = tree_start_epochs 291 | self.update_every_epochs = tree_update_every_epochs 292 | self.update_end_epochs = tree_update_end_epochs 293 | self.net = net 294 | self.arch = arch 295 | self.checkpoint_path = checkpoint_path 296 | 297 | def forward_tree(self, outputs, targets): 298 | if self.epochs < self.start_epochs: 299 | return self.criterion(outputs, targets) # regular xent 300 | self.assert_output_not_nbdt(outputs) 301 | return self.criterion(self.rules(outputs), targets) 302 | 303 | def set_epoch(self, *args, **kwargs): 304 | super().set_epoch(*args, **kwargs) 305 | offset = self.epochs - self.start_epochs 306 | if ( 307 | offset >= 0 308 | and offset % self.update_every_epochs == 0 309 | and self.epochs < self.update_end_epochs 310 | ): 311 | checkpoint_dir = self.checkpoint_path.replace(".pth", "") 312 | path_graph = os.path.join(checkpoint_dir, f"graph-epoch{self.epochs}.json") 313 | self.tree.update_from_model( 314 | self.net, self.arch, self.tree.dataset, path_graph=path_graph 315 | ) 316 | 317 | 318 | class SoftSegTreeSupLoss(SoftTreeSupLoss): 319 | def forward(self, outputs, targets): 320 | self.assert_output_not_nbdt(outputs) 321 | 322 | loss = self.criterion(outputs, targets) 323 | coerced_outputs = coerce_tensor(outputs) 324 | bayesian_outputs = self.rules(coerced_outputs) 325 | bayesian_outputs = uncoerce_tensor(bayesian_outputs, outputs.shape) 326 | loss += self.criterion(bayesian_outputs, targets) * self.tree_supervision_weight 327 | return loss 328 | -------------------------------------------------------------------------------- /nbdt/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | __all__ = names = ("top1", "top2", "top5", "top10") 5 | 6 | 7 | class TopK: 8 | def __init__(self, k=1): 9 | self.k = k 10 | self.clear() 11 | 12 | def clear(self): 13 | self.correct = 0 14 | self.total = 0 15 | 16 | def forward(self, outputs, targets): 17 | _, preds = torch.topk(outputs, self.k) 18 | results = [(pred == target).any() for pred, target in zip(preds, targets)] 19 | self.correct += sum(results).item() 20 | self.total += targets.size(0) 21 | 22 | def report(self): 23 | return self.correct / (self.total or 1) 24 | 25 | def __repr__(self): 26 | return f"Top{self.k}: {self.report()}" 27 | 28 | def __str__(self): 29 | return repr(self) 30 | 31 | 32 | top1 = lambda: TopK(1) 33 | top2 = lambda: TopK(2) 34 | top5 = lambda: TopK(5) 35 | top10 = lambda: TopK(10) 36 | -------------------------------------------------------------------------------- /nbdt/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .wideresnet import * 3 | from pytorchcv.models.efficientnet import * 4 | from torchvision.models import * 5 | 6 | 7 | def get_model_choices(): 8 | from types import ModuleType 9 | 10 | for key, value in globals().items(): 11 | if not key.startswith('__') and not isinstance(value, ModuleType) and callable(value): 12 | yield key 13 | -------------------------------------------------------------------------------- /nbdt/models/resnet.py: -------------------------------------------------------------------------------- 1 | """ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 8 | """ 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from nbdt.models.utils import get_pretrained_model 13 | 14 | 15 | __all__ = ("ResNet10", "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152") 16 | 17 | 18 | model_urls = { 19 | ( 20 | "ResNet10", 21 | "CIFAR10", 22 | ): "https://github.com/alvinwan/neural-backed-decision-trees/releases/download/0.0.1/ckpt-CIFAR10-ResNet10.pth", 23 | ( 24 | "ResNet10", 25 | "CIFAR100", 26 | ): "https://github.com/alvinwan/neural-backed-decision-trees/releases/download/0.0.1/ckpt-CIFAR100-ResNet10.pth", 27 | ( 28 | "ResNet18", 29 | "CIFAR10", 30 | ): "https://github.com/alvinwan/neural-backed-decision-trees/releases/download/0.0.1/ckpt-CIFAR10-ResNet18.pth", 31 | ( 32 | "ResNet18", 33 | "CIFAR100", 34 | ): "https://github.com/alvinwan/neural-backed-decision-trees/releases/download/0.0.1/ckpt-CIFAR100-ResNet18.pth", 35 | ( 36 | "ResNet18", 37 | "TinyImagenet200", 38 | ): "https://github.com/alvinwan/neural-backed-decision-trees/releases/download/0.0.1/ckpt-TinyImagenet200-ResNet18.pth", 39 | } 40 | 41 | 42 | class BasicBlock(nn.Module): 43 | expansion = 1 44 | 45 | def __init__(self, in_planes, planes, stride=1): 46 | super(BasicBlock, self).__init__() 47 | self.conv1 = nn.Conv2d( 48 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 49 | ) 50 | self.bn1 = nn.BatchNorm2d(planes) 51 | self.conv2 = nn.Conv2d( 52 | planes, planes, kernel_size=3, stride=1, padding=1, bias=False 53 | ) 54 | self.bn2 = nn.BatchNorm2d(planes) 55 | 56 | self.shortcut = nn.Sequential() 57 | if stride != 1 or in_planes != self.expansion * planes: 58 | self.shortcut = nn.Sequential( 59 | nn.Conv2d( 60 | in_planes, 61 | self.expansion * planes, 62 | kernel_size=1, 63 | stride=stride, 64 | bias=False, 65 | ), 66 | nn.BatchNorm2d(self.expansion * planes), 67 | ) 68 | 69 | def forward(self, x): 70 | out = F.relu(self.bn1(self.conv1(x))) 71 | out = self.bn2(self.conv2(out)) 72 | out += self.shortcut(x) 73 | out = F.relu(out) 74 | return out 75 | 76 | 77 | class Bottleneck(nn.Module): 78 | expansion = 4 79 | 80 | def __init__(self, in_planes, planes, stride=1): 81 | super(Bottleneck, self).__init__() 82 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 83 | self.bn1 = nn.BatchNorm2d(planes) 84 | self.conv2 = nn.Conv2d( 85 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 86 | ) 87 | self.bn2 = nn.BatchNorm2d(planes) 88 | self.conv3 = nn.Conv2d( 89 | planes, self.expansion * planes, kernel_size=1, bias=False 90 | ) 91 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 92 | 93 | self.shortcut = nn.Sequential() 94 | if stride != 1 or in_planes != self.expansion * planes: 95 | self.shortcut = nn.Sequential( 96 | nn.Conv2d( 97 | in_planes, 98 | self.expansion * planes, 99 | kernel_size=1, 100 | stride=stride, 101 | bias=False, 102 | ), 103 | nn.BatchNorm2d(self.expansion * planes), 104 | ) 105 | 106 | def forward(self, x): 107 | out = F.relu(self.bn1(self.conv1(x))) 108 | out = F.relu(self.bn2(self.conv2(out))) 109 | out = self.bn3(self.conv3(out)) 110 | out += self.shortcut(x) 111 | out = F.relu(out) 112 | return out 113 | 114 | 115 | class ResNet(nn.Module): 116 | def __init__(self, block, num_blocks, num_classes=10): 117 | super(ResNet, self).__init__() 118 | self.in_planes = 64 119 | 120 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 121 | self.bn1 = nn.BatchNorm2d(64) 122 | self.layer1 = self._make_layer(block, 64, 
num_blocks[0], stride=1) 123 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 124 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 125 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 126 | self.linear = nn.Linear(512 * block.expansion, num_classes) 127 | 128 | def _make_layer(self, block, planes, num_blocks, stride): 129 | strides = [stride] + [1] * (num_blocks - 1) 130 | layers = [] 131 | for stride in strides: 132 | layers.append(block(self.in_planes, planes, stride)) 133 | self.in_planes = planes * block.expansion 134 | return nn.Sequential(*layers) 135 | 136 | def features(self, x): 137 | out = F.relu(self.bn1(self.conv1(x))) 138 | out = self.layer1(out) 139 | out = self.layer2(out) 140 | out = self.layer3(out) 141 | out = self.layer4(out) 142 | out = F.avg_pool2d(out, out.size()[2:]) # global average pooling 143 | out = out.view(out.size(0), -1) 144 | return out 145 | 146 | def forward(self, x): 147 | out = self.features(x) 148 | out = self.linear(out) 149 | return out 150 | 151 | 152 | def _ResNet(arch, *args, pretrained=False, progress=True, dataset="CIFAR10", **kwargs): 153 | model = ResNet(*args, **kwargs) 154 | model = get_pretrained_model( 155 | arch, dataset, model, model_urls, pretrained=pretrained, progress=progress 156 | ) 157 | return model 158 | 159 | 160 | def ResNet10(pretrained=False, progress=True, **kwargs): 161 | return _ResNet( 162 | "ResNet10", 163 | BasicBlock, 164 | [1, 1, 1, 1], 165 | pretrained=pretrained, 166 | progress=progress, 167 | **kwargs 168 | ) 169 | 170 | 171 | def ResNet18(pretrained=False, progress=True, **kwargs): 172 | return _ResNet( 173 | "ResNet18", 174 | BasicBlock, 175 | [2, 2, 2, 2], 176 | pretrained=pretrained, 177 | progress=progress, 178 | **kwargs 179 | ) 180 | 181 | 182 | def ResNet34(pretrained=False, progress=True, **kwargs): 183 | return _ResNet( 184 | "ResNet34", 185 | BasicBlock, 186 | [3, 4, 6, 3], 187 | pretrained=pretrained, 188 | progress=progress, 189 | **kwargs 190 | ) 191 | 192 | 193 | def ResNet50(pretrained=False, progress=True, **kwargs): 194 | return _ResNet( 195 | "ResNet50", 196 | Bottleneck, 197 | [3, 4, 6, 3], 198 | pretrained=pretrained, 199 | progress=progress, 200 | **kwargs 201 | ) 202 | 203 | 204 | def ResNet101(pretrained=False, progress=True, **kwargs): 205 | return _ResNet( 206 | "ResNet101", 207 | Bottleneck, 208 | [3, 4, 23, 3], 209 | pretrained=pretrained, 210 | progress=progress, 211 | **kwargs 212 | ) 213 | 214 | 215 | def ResNet152(pretrained=False, progress=True, **kwargs): 216 | return _ResNet( 217 | "ResNet152", 218 | Bottleneck, 219 | [3, 8, 36, 3], 220 | pretrained=pretrained, 221 | progress=progress, 222 | **kwargs 223 | ) 224 | 225 | 226 | def test(): 227 | net = ResNet18() 228 | y = net(torch.randn(1, 3, 32, 32)) 229 | print(y.size()) 230 | 231 | 232 | # test() 233 | -------------------------------------------------------------------------------- /nbdt/models/utils.py: -------------------------------------------------------------------------------- 1 | from torch.hub import load_state_dict_from_url 2 | from nbdt.utils import Colors 3 | from pathlib import Path 4 | import torch 5 | 6 | # TODO(alvin): fix checkpoint structure so that this isn't neededd 7 | def load_state_dict(net, state_dict): 8 | try: 9 | net.load_state_dict(state_dict) 10 | except RuntimeError as e: 11 | if "Missing key(s) in state_dict:" in str(e): 12 | net.load_state_dict( 13 | { 14 | key.replace("module.", "", 1): value 15 | for key, value in state_dict.items() 
16 | } 17 | ) 18 | 19 | 20 | def make_kwarg_optional(init, **optional_kwargs): 21 | """Returns wrapper function that attempts 'optional' kwargs. 22 | 23 | If initialization fails, retries initialization without 'optional' kwargs. 24 | """ 25 | 26 | def f(**kwargs): 27 | try: 28 | net = init(**optional_kwargs, **kwargs) 29 | except TypeError as e: # likely because `dataset` not allowed arg 30 | print(e) 31 | 32 | try: 33 | net = init(**kwargs) 34 | except Exception as e: 35 | Colors.red(f"Fatal error: {e}") 36 | exit() 37 | return net 38 | 39 | return f 40 | 41 | 42 | def get_pretrained_model( 43 | arch, 44 | dataset, 45 | model, 46 | model_urls, 47 | pretrained=False, 48 | progress=True, 49 | root=".cache/torch/checkpoints", 50 | ): 51 | if pretrained: 52 | state_dict = load_state_dict_from_key( 53 | [(arch, dataset)], 54 | model_urls, 55 | pretrained, 56 | progress, 57 | root, 58 | device=get_model_device(model), 59 | ) 60 | state_dict = coerce_state_dict(state_dict, model.state_dict()) 61 | model.load_state_dict(state_dict) 62 | return model 63 | 64 | 65 | def coerce_state_dict(state_dict, reference_state_dict): 66 | if "net" in state_dict: 67 | state_dict = state_dict["net"] 68 | has_reference_module = list(reference_state_dict)[0].startswith("module.") 69 | has_module = list(state_dict)[0].startswith("module.") 70 | if not has_reference_module and has_module: 71 | state_dict = { 72 | key.replace("module.", "", 1): value for key, value in state_dict.items() 73 | } 74 | elif has_reference_module and not has_module: 75 | state_dict = {"module." + key: value for key, value in state_dict.items()} 76 | return state_dict 77 | 78 | 79 | def get_model_device(model): 80 | return next(model.parameters()).device 81 | 82 | 83 | def load_state_dict_from_key( 84 | keys, 85 | model_urls, 86 | pretrained=False, 87 | progress=True, 88 | root=".cache/torch/checkpoints", 89 | device="cpu", 90 | ): 91 | valid_keys = [key for key in keys if key in model_urls] 92 | if not valid_keys: 93 | raise UserWarning(f"None of the keys {keys} correspond to a pretrained model.") 94 | key = valid_keys[-1] 95 | url = model_urls[key] 96 | Colors.green(f"Loading pretrained model {key} from {url}") 97 | return load_state_dict_from_url( 98 | url, 99 | Path.home() / root, 100 | progress=progress, 101 | check_hash=False, 102 | map_location=torch.device(device), 103 | ) 104 | -------------------------------------------------------------------------------- /nbdt/models/wideresnet.py: -------------------------------------------------------------------------------- 1 | from pytorchcv.models.wrn_cifar import ( 2 | wrn28_10_cifar10, 3 | wrn28_10_cifar100, 4 | get_wrn_cifar, 5 | ) 6 | from nbdt.models.utils import get_pretrained_model 7 | import torch.nn as nn 8 | 9 | 10 | __all__ = ("wrn28_10", "wrn28_10_cifar10", "wrn28_10_cifar100") 11 | 12 | 13 | model_urls = { 14 | ( 15 | "wrn28_10", 16 | "TinyImagenet200", 17 | ): "https://github.com/alvinwan/neural-backed-decision-trees/releases/download/0.0.1/ckpt-TinyImagenet200-wrn28_10.pth" 18 | } 19 | 20 | 21 | def _wrn(arch, model, pretrained=False, progress=True, dataset="CIFAR10"): 22 | model = get_pretrained_model( 23 | arch, dataset, model, model_urls, pretrained=pretrained, progress=progress 24 | ) 25 | return model 26 | 27 | 28 | def wrn28_10(pretrained=False, progress=True, dataset="CIFAR10", **kwargs): 29 | """Replace `final_pool` (8x8 average pooling) with a global average pooling. 
30 | 31 | If this gets crappy accuracy for TinyImagenet200, it's probably because the 32 | final pooled feature map is 16x16 instead of 8x8. So needs another stride 2 33 | stage, technically. 34 | """ 35 | model = get_wrn_cifar(blocks=28, width_factor=10, model_name="wrn28_10", **kwargs) 36 | model.features.final_pool = nn.AdaptiveAvgPool2d((1, 1)) 37 | model = _wrn( 38 | "wrn28_10", model, pretrained=pretrained, progress=progress, dataset=dataset 39 | ) 40 | return model 41 | -------------------------------------------------------------------------------- /nbdt/thirdparty/nx.py: -------------------------------------------------------------------------------- 1 | """Utilities acting directly on networkx objects""" 2 | 3 | from nbdt.utils import makeparentdirs 4 | import networkx as nx 5 | import json 6 | import random 7 | from nbdt.utils import DATASETS, METHODS, fwd 8 | from networkx.readwrite.json_graph import node_link_data, node_link_graph 9 | from sklearn.cluster import AgglomerativeClustering 10 | from pathlib import Path 11 | import nbdt.models as models 12 | import torch 13 | import argparse 14 | import os 15 | 16 | 17 | def is_leaf(G, node): 18 | return len(G.succ[node]) == 0 19 | 20 | 21 | def get_leaves(G, root=None): 22 | nodes = G.nodes if root is None else nx.descendants(G, root) | {root} 23 | for node in nodes: 24 | if is_leaf(G, node): 25 | yield node 26 | 27 | 28 | def get_roots(G): 29 | for node in G.nodes: 30 | if len(G.pred[node]) == 0: 31 | yield node 32 | 33 | 34 | def get_root(G): 35 | roots = list(get_roots(G)) 36 | assert len(roots) == 1, f"Multiple ({len(roots)}) found" 37 | return roots[0] 38 | 39 | 40 | def get_depth(G): 41 | def _get_depth(node): 42 | if not G.succ[node]: 43 | return 1 44 | return max([_get_depth(child) for child in G.succ[node]]) + 1 45 | 46 | return max([_get_depth(root) for root in get_roots(G)]) 47 | 48 | 49 | def get_leaf_to_path(G): 50 | leaf_to_path = {} 51 | for root in get_roots(G): 52 | frontier = [(root, 0, [])] 53 | while frontier: 54 | node, child_index, path = frontier.pop(0) 55 | path = path + [(child_index, node)] 56 | if is_leaf(G, node): 57 | leaf_to_path[node] = path 58 | continue 59 | frontier.extend([(child, i, path) for i, child in enumerate(G.succ[node])]) 60 | return leaf_to_path 61 | 62 | 63 | def write_graph(G, path): 64 | makeparentdirs(path) 65 | with open(str(path), "w") as f: 66 | json.dump(node_link_data(G), f) 67 | 68 | 69 | def read_graph(path): 70 | if not os.path.exists(path): 71 | parent = Path(fwd()).parent 72 | print(f"No such file or directory: {path}. 
Looking in {str(parent)}") 73 | path = parent / path 74 | with open(path) as f: 75 | return node_link_graph(json.load(f)) 76 | -------------------------------------------------------------------------------- /nbdt/thirdparty/wn.py: -------------------------------------------------------------------------------- 1 | """Utilities for NLTK WordNet synsets and WordNet IDs""" 2 | import networkx as nx 3 | import json 4 | import random 5 | from nbdt.utils import DATASETS, METHODS, fwd, get_directory 6 | from networkx.readwrite.json_graph import node_link_data, node_link_graph 7 | from sklearn.cluster import AgglomerativeClustering 8 | from pathlib import Path 9 | import nbdt.models as models 10 | import torch 11 | import argparse 12 | import os 13 | import nltk 14 | 15 | 16 | def maybe_install_wordnet(): 17 | try: 18 | nltk.data.find("corpora/wordnet") 19 | except Exception as e: 20 | print(e) 21 | nltk.download("wordnet") 22 | 23 | 24 | def get_wnids(path_wnids): 25 | if not os.path.exists(path_wnids): 26 | parent = Path(fwd()).parent 27 | print(f"No such file or directory: {path_wnids}. Looking in {str(parent)}") 28 | path_wnids = parent / path_wnids 29 | with open(path_wnids) as f: 30 | wnids = [wnid.strip() for wnid in f.readlines()] 31 | return wnids 32 | 33 | 34 | def get_wnids_from_dataset(dataset, root="./nbdt/wnids"): 35 | directory = get_directory(dataset, root) 36 | return get_wnids(f"{directory}.txt") 37 | 38 | 39 | ########## 40 | # SYNSET # 41 | ########## 42 | 43 | 44 | def synset_to_wnid(synset): 45 | return f"{synset.pos()}{synset.offset():08d}" 46 | 47 | 48 | def wnid_to_synset(wnid): 49 | from nltk.corpus import wordnet as wn # entire script should not depend on wn 50 | 51 | offset = int(wnid[1:]) 52 | pos = wnid[0] 53 | 54 | try: 55 | return wn.synset_from_pos_and_offset(wnid[0], offset) 56 | except: 57 | return FakeSynset(wnid) 58 | 59 | 60 | def wnid_to_name(wnid): 61 | return synset_to_name(wnid_to_synset(wnid)) 62 | 63 | 64 | def synset_to_name(synset): 65 | return synset.name().split(".")[0] 66 | 67 | 68 | def write_wnids(wnids, path): 69 | makeparentdirs(path) 70 | with open(str(path), "w") as f: 71 | f.write("\n".join(wnids)) 72 | 73 | 74 | class FakeSynset: 75 | def __init__(self, wnid): 76 | self.wnid = wnid 77 | 78 | assert isinstance(wnid, str) 79 | 80 | @staticmethod 81 | def create_from_offset(offset): 82 | return FakeSynset("f{:08d}".format(offset)) 83 | 84 | def offset(self): 85 | return int(self.wnid[1:]) 86 | 87 | def pos(self): 88 | return "f" 89 | 90 | def name(self): 91 | return "(generated)" 92 | 93 | def definition(self): 94 | return "(generated)" 95 | -------------------------------------------------------------------------------- /nbdt/tree.py: -------------------------------------------------------------------------------- 1 | """Tree and node utilities for navigating the NBDT hierarchy""" 2 | import torchvision.datasets as datasets 3 | import torch 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | from collections import defaultdict 7 | from nbdt.utils import DATASET_TO_NUM_CLASSES, DATASETS 8 | from collections import defaultdict 9 | from nbdt.thirdparty.wn import get_wnids, FakeSynset, wnid_to_synset, wnid_to_name 10 | from nbdt.thirdparty.nx import read_graph, get_leaves, get_leaf_to_path 11 | from nbdt.utils import ( 12 | dataset_to_default_path_graph, 13 | dataset_to_default_path_wnids, 14 | hierarchy_to_path_graph, 15 | ) 16 | import torch.nn as nn 17 | import random 18 | 19 | 20 | def dataset_to_dummy_classes(dataset): 21 | 
assert dataset in DATASETS 22 | num_classes = DATASET_TO_NUM_CLASSES[dataset] 23 | return [FakeSynset.create_from_offset(i).wnid for i in range(num_classes)] 24 | 25 | 26 | def add_arguments(parser): 27 | parser.add_argument( 28 | "--hierarchy", 29 | help="Hierarchy to use. If supplied, will be used to " 30 | "generate --path-graph. --path-graph takes precedence.", 31 | ) 32 | parser.add_argument( 33 | "--path-graph", help="Path to graph-*.json file." 34 | ) # WARNING: hard-coded suffix -build in generate_checkpoint_fname 35 | parser.add_argument("--path-wnids", help="Path to wnids.txt file.") 36 | 37 | 38 | class Node: 39 | def __init__(self, tree, wnid, other_class=False): 40 | self.tree = tree 41 | 42 | self.wnid = wnid 43 | self.name = wnid_to_name(wnid) 44 | self.synset = wnid_to_synset(wnid) 45 | 46 | self.original_classes = tree.classes 47 | self.num_original_classes = len(self.tree.wnids_leaves) 48 | 49 | self.has_other = other_class and not (self.is_root() or self.is_leaf()) 50 | self.num_children = len(self.succ) 51 | 52 | self.num_classes = self.num_children + int(self.has_other) 53 | 54 | ( 55 | self.class_index_to_child_index, 56 | self.child_index_to_class_index, 57 | ) = self.build_class_mappings() 58 | self.classes = self.build_classes() 59 | 60 | assert len(self.classes) == self.num_classes, ( 61 | f"Number of classes {self.num_classes} does not equal number of " 62 | f"class names found ({len(self.classes)}): {self.classes}" 63 | ) 64 | 65 | self.leaves = list(self.get_leaves()) 66 | self.num_leaves = len(self.leaves) 67 | 68 | def wnid_to_class_index(self, wnid): 69 | return self.tree.wnids_leaves.index(wnid) 70 | 71 | def wnid_to_child_index(self, wnid): 72 | return [child.wnid for child in self.children].index(wnid) 73 | 74 | @property 75 | def parent(self): 76 | if not self.parents: 77 | return None 78 | return self.parents[0] 79 | 80 | @property 81 | def pred(self): 82 | return self.tree.G.pred[self.wnid] 83 | 84 | @property 85 | def parents(self): 86 | return [self.tree.wnid_to_node[wnid] for wnid in self.pred] 87 | 88 | @property 89 | def succ(self): 90 | return self.tree.G.succ[self.wnid] 91 | 92 | @property 93 | def children(self): 94 | return [self.tree.wnid_to_node[wnid] for wnid in self.succ] 95 | 96 | def get_leaves(self): 97 | return get_leaves(self.tree.G, self.wnid) 98 | 99 | def is_leaf(self): 100 | return len(self.succ) == 0 101 | 102 | def is_root(self): 103 | return len(self.pred) == 0 104 | 105 | def build_class_mappings(self): 106 | if self.is_leaf(): 107 | return {}, {} 108 | 109 | old_to_new = defaultdict(lambda: []) 110 | new_to_old = defaultdict(lambda: []) 111 | for new_index, child in enumerate(self.succ): 112 | for leaf in get_leaves(self.tree.G, child): 113 | old_index = self.wnid_to_class_index(leaf) 114 | old_to_new[old_index].append(new_index) 115 | new_to_old[new_index].append(old_index) 116 | 117 | if not self.has_other: 118 | return old_to_new, new_to_old 119 | 120 | new_index = self.num_children 121 | for old in range(self.num_original_classes): 122 | if old not in old_to_new: 123 | old_to_new[old].append(new_index) 124 | new_to_old[new_index].append(old) 125 | return old_to_new, new_to_old 126 | 127 | def build_classes(self): 128 | return [ 129 | ",".join([self.original_classes[old] for old in old_indices]) 130 | for new_index, old_indices in sorted( 131 | self.child_index_to_class_index.items(), key=lambda t: t[0] 132 | ) 133 | ] 134 | 135 | @property 136 | def class_counts(self): 137 | """Number of old classes in each new class""" 
138 | return [len(old_indices) for old_indices in self.child_index_to_class_index.values()] 139 | 140 | @staticmethod 141 | def dim(nodes): 142 | return sum([node.num_classes for node in nodes]) 143 | 144 | 145 | class Tree: 146 | def __init__( 147 | self, dataset, path_graph=None, path_wnids=None, classes=None, hierarchy=None 148 | ): 149 | if dataset and hierarchy and not path_graph: 150 | path_graph = hierarchy_to_path_graph(dataset, hierarchy) 151 | if dataset and not path_graph: 152 | path_graph = dataset_to_default_path_graph(dataset) 153 | if dataset and not path_wnids: 154 | path_wnids = dataset_to_default_path_wnids(dataset) 155 | if dataset and not classes: 156 | classes = dataset_to_dummy_classes(dataset) 157 | 158 | self.load_hierarchy(dataset, path_graph, path_wnids, classes) 159 | 160 | def load_hierarchy(self, dataset, path_graph, path_wnids, classes): 161 | self.dataset = dataset 162 | self.path_graph = path_graph 163 | self.path_wnids = path_wnids 164 | self.classes = classes 165 | self.G = read_graph(path_graph) 166 | self.wnids_leaves = get_wnids(path_wnids) 167 | self.wnid_to_class = { 168 | wnid: cls for wnid, cls in zip(self.wnids_leaves, self.classes) 169 | } 170 | self.wnid_to_class_index = {wnid: i for i, wnid in enumerate(self.wnids_leaves)} 171 | self.wnid_to_node = self.get_wnid_to_node() 172 | self.nodes = [self.wnid_to_node[wnid] for wnid in sorted(self.wnid_to_node)] 173 | self.inodes = [node for node in self.nodes if not node.is_leaf()] 174 | self.leaves = [self.wnid_to_node[wnid] for wnid in self.wnids_leaves] 175 | 176 | def update_from_model( 177 | self, model, arch, dataset, classes=None, path_wnids=None, path_graph=None 178 | ): 179 | from nbdt.hierarchy import generate_hierarchy # avoid circular import 180 | assert model is not None, "`model` cannot be NoneType" 181 | path_graph = generate_hierarchy( 182 | dataset=dataset, method="induced", arch=arch, model=model, path=path_graph, 183 | ) 184 | tree = Tree(dataset, path_graph=path_graph, path_wnids=path_wnids, classes=classes, hierarchy="induced") 185 | self.load_hierarchy( 186 | dataset=tree.dataset, 187 | path_graph=tree.path_graph, 188 | path_wnids=tree.path_wnids, 189 | classes=tree.classes 190 | ) 191 | 192 | @classmethod 193 | def create_from_args(cls, args, classes=None): 194 | return cls( 195 | args.dataset, 196 | args.path_graph, 197 | args.path_wnids, 198 | classes=classes, 199 | hierarchy=args.hierarchy, 200 | ) 201 | 202 | @property 203 | def root(self): 204 | for node in self.inodes: 205 | if node.is_root(): 206 | return node 207 | raise UserWarning("Should not be reachable. 
Tree should always have root") 208 | 209 | def get_wnid_to_node(self): 210 | wnid_to_node = {} 211 | for wnid in self.G: 212 | wnid_to_node[wnid] = Node(self, wnid) 213 | return wnid_to_node 214 | 215 | def get_leaf_to_steps(self): 216 | node = self.inodes[0] 217 | leaf_to_path = get_leaf_to_path(self.G) 218 | leaf_to_steps = {} 219 | for leaf in self.wnids_leaves: 220 | next_indices = [index for index, _ in leaf_to_path[leaf][1:]] + [-1] 221 | leaf_to_steps[leaf] = [ 222 | { 223 | "node": self.wnid_to_node[wnid], 224 | "name": self.wnid_to_node[wnid].name, 225 | "next_index": next_index, # curr node's next child index to traverse 226 | } 227 | for next_index, (_, wnid) in zip(next_indices, leaf_to_path[leaf]) 228 | ] 229 | return leaf_to_steps 230 | 231 | def visualize(self, path_html, dataset=None, **kwargs): 232 | """ 233 | :param path_html: Where to write the final generated visualization 234 | """ 235 | from nbdt.hierarchy import generate_hierarchy_vis_from # avoid circular import 236 | generate_hierarchy_vis_from( 237 | self.G, 238 | dataset=dataset, 239 | path_html=path_html, 240 | **kwargs 241 | ) 242 | -------------------------------------------------------------------------------- /nbdt/utils.py: -------------------------------------------------------------------------------- 1 | """Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - init_params: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 5 | """ 6 | import os 7 | import sys 8 | import time 9 | import math 10 | import numpy as np 11 | import torch 12 | from urllib.request import urlopen, Request 13 | from PIL import Image 14 | import torch.nn as nn 15 | import torch.nn.init as init 16 | from pathlib import Path 17 | import io 18 | 19 | # tree-generation constants 20 | METHODS = ("wordnet", "random", "induced") 21 | DATASETS = ( 22 | "CIFAR10", 23 | "CIFAR100", 24 | "TinyImagenet200", 25 | "Imagenet1000", 26 | "Cityscapes", 27 | "PascalContext", 28 | "LookIntoPerson", 29 | "ADE20K", 30 | ) 31 | DATASET_TO_NUM_CLASSES = { 32 | "CIFAR10": 10, 33 | "CIFAR100": 100, 34 | "TinyImagenet200": 200, 35 | "Imagenet1000": 1000, 36 | "Cityscapes": 19, 37 | "PascalContext": 59, 38 | "LookIntoPerson": 20, 39 | "ADE20K": 150, 40 | } 41 | DATASET_TO_CLASSES = { 42 | "CIFAR10": [ 43 | "airplane", 44 | "automobile", 45 | "bird", 46 | "cat", 47 | "deer", 48 | "dog", 49 | "frog", 50 | "horse", 51 | "ship", 52 | "truck", 53 | ] 54 | } 55 | 56 | 57 | def fwd(): 58 | """Get file's working directory""" 59 | return Path(__file__).parent.absolute() 60 | 61 | 62 | def dataset_to_default_path_graph(dataset): 63 | return hierarchy_to_path_graph(dataset, "induced") 64 | 65 | 66 | def hierarchy_to_path_graph(dataset, hierarchy): 67 | return os.path.join(fwd(), f"hierarchies/{dataset}/graph-{hierarchy}.json") 68 | 69 | 70 | def dataset_to_default_path_wnids(dataset): 71 | return os.path.join(fwd(), f"wnids/{dataset}.txt") 72 | 73 | 74 | def get_directory(dataset, root="./nbdt/hierarchies"): 75 | return os.path.join(root, dataset) 76 | 77 | 78 | def generate_kwargs(args, object, name="Dataset", globals={}, kwargs=None): 79 | kwargs = kwargs or {} 80 | 81 | for key in dir(object): 82 | accepts_key = getattr(object, key, False) 83 | if not key.startswith("accepts_") or not accepts_key: 84 | continue 85 | key = key.replace("accepts_", "", 1) 86 | assert key in args or callable(accepts_key) 87 | 88 | value = getattr(args, key, None) 89 | if callable(accepts_key): 90 | 
kwargs[key] = accepts_key(**globals) 91 | Colors.cyan(f"{key}:\t(callable)") 92 | elif accepts_key and value is not None: 93 | kwargs[key] = value 94 | Colors.cyan(f"{key}:\t{value}") 95 | elif value is not None: 96 | Colors.red(f"Warning: {name} does not support custom " f"{key}: {value}") 97 | return kwargs 98 | 99 | 100 | def load_image_from_path(path): 101 | """Path can be local or a URL""" 102 | headers = { 103 | "User-Agent": "Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.3" 104 | } 105 | if path.startswith(("http://", "https://")): 106 | request = Request(path, headers=headers) 107 | file = io.BytesIO(urlopen(request).read()) 108 | else: 109 | file = path 110 | return Image.open(file) 111 | 112 | 113 | def makeparentdirs(path): 114 | dir = Path(path).parent 115 | os.makedirs(dir, exist_ok=True) 116 | 117 | 118 | class Colors: 119 | RED = "\x1b[31m" 120 | GREEN = "\x1b[32m" 121 | ENDC = "\033[0m" 122 | BOLD = "\033[1m" 123 | CYAN = "\x1b[36m" 124 | 125 | @classmethod 126 | def red(cls, *args): 127 | print(cls.RED + args[0], *args[1:], cls.ENDC) 128 | 129 | @classmethod 130 | def green(cls, *args): 131 | print(cls.GREEN + args[0], *args[1:], cls.ENDC) 132 | 133 | @classmethod 134 | def cyan(cls, *args): 135 | print(cls.CYAN + args[0], *args[1:], cls.ENDC) 136 | 137 | @classmethod 138 | def bold(cls, *args): 139 | print(cls.BOLD + args[0], *args[1:], cls.ENDC) 140 | 141 | 142 | def get_mean_and_std(dataset): 143 | """Compute the mean and std value of dataset.""" 144 | dataloader = torch.utils.data.DataLoader( 145 | dataset, batch_size=1, shuffle=True, num_workers=2 146 | ) 147 | mean = torch.zeros(3) 148 | std = torch.zeros(3) 149 | print("==> Computing mean and std..") 150 | for inputs, targets in dataloader: 151 | for i in range(3): 152 | mean[i] += inputs[:, i, :, :].mean() 153 | std[i] += inputs[:, i, :, :].std() 154 | mean.div_(len(dataset)) 155 | std.div_(len(dataset)) 156 | return mean, std 157 | 158 | 159 | def init_params(net): 160 | """Init layer parameters.""" 161 | for m in net.modules(): 162 | if isinstance(m, nn.Conv2d): 163 | init.kaiming_normal_(m.weight, mode="fan_out") 164 | if m.bias is not None: 165 | init.constant_(m.bias, 0) 166 | elif isinstance(m, nn.BatchNorm2d): 167 | init.constant_(m.weight, 1) 168 | init.constant_(m.bias, 0) 169 | elif isinstance(m, nn.Linear): 170 | init.normal_(m.weight, std=1e-3) 171 | if m.bias is not None: 172 | init.constant_(m.bias, 0) 173 | 174 | 175 | try: 176 | _, term_width = os.popen("stty size", "r").read().split() 177 | term_width = int(term_width) 178 | except Exception as e: 179 | print(e) 180 | term_width = 50 181 | 182 | TOTAL_BAR_LENGTH = 65.0 183 | last_time = time.time() 184 | begin_time = last_time 185 | 186 | 187 | def progress_bar(current, total, msg=None): 188 | global last_time, begin_time 189 | if current == 0: 190 | begin_time = time.time() # Reset for new bar. 
191 | 192 | cur_len = int(TOTAL_BAR_LENGTH * current / total) 193 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 194 | 195 | sys.stdout.write(" [") 196 | for i in range(cur_len): 197 | sys.stdout.write("=") 198 | sys.stdout.write(">") 199 | for i in range(rest_len): 200 | sys.stdout.write(".") 201 | sys.stdout.write("]") 202 | 203 | cur_time = time.time() 204 | step_time = cur_time - last_time 205 | last_time = cur_time 206 | tot_time = cur_time - begin_time 207 | 208 | L = [] 209 | L.append(" Step: %s" % format_time(step_time)) 210 | L.append(" | Tot: %s" % format_time(tot_time)) 211 | if msg: 212 | L.append(" | " + msg) 213 | 214 | msg = "".join(L) 215 | sys.stdout.write(msg) 216 | for i in range(term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3): 217 | sys.stdout.write(" ") 218 | 219 | # Go back to the center of the bar. 220 | for i in range(term_width - int(TOTAL_BAR_LENGTH / 2) + 2): 221 | sys.stdout.write("\b") 222 | sys.stdout.write(" %d/%d " % (current + 1, total)) 223 | 224 | if current < total - 1: 225 | sys.stdout.write("\r") 226 | else: 227 | sys.stdout.write("\n") 228 | sys.stdout.flush() 229 | 230 | 231 | def format_time(seconds): 232 | days = int(seconds / 3600 / 24) 233 | seconds = seconds - days * 3600 * 24 234 | hours = int(seconds / 3600) 235 | seconds = seconds - hours * 3600 236 | minutes = int(seconds / 60) 237 | seconds = seconds - minutes * 60 238 | secondsf = int(seconds) 239 | seconds = seconds - secondsf 240 | millis = int(seconds * 1000) 241 | 242 | f = "" 243 | i = 1 244 | if days > 0: 245 | f += str(days) + "D" 246 | i += 1 247 | if hours > 0 and i <= 2: 248 | f += str(hours) + "h" 249 | i += 1 250 | if minutes > 0 and i <= 2: 251 | f += str(minutes) + "m" 252 | i += 1 253 | if secondsf > 0 and i <= 2: 254 | f += str(secondsf) + "s" 255 | i += 1 256 | if millis > 0 and i <= 2: 257 | f += str(millis) + "ms" 258 | i += 1 259 | if f == "": 260 | f = "0ms" 261 | return f 262 | 263 | 264 | def set_np_printoptions(): 265 | np.set_printoptions(formatter={"float": lambda x: "{0:0.3f}".format(x)}) 266 | 267 | 268 | def generate_checkpoint_fname( 269 | dataset, 270 | arch, 271 | path_graph, 272 | wnid=None, 273 | name="", 274 | trainset=None, 275 | include_labels=(), 276 | exclude_labels=(), 277 | include_classes=(), 278 | num_samples=0, 279 | tree_supervision_weight=1, 280 | fine_tune=False, 281 | loss="CrossEntropyLoss", 282 | lr=0.1, 283 | tree_supervision_weight_end=None, 284 | tree_supervision_weight_power=1, 285 | xent_weight=1, 286 | xent_weight_end=None, 287 | xent_weight_power=1, 288 | tree_start_epochs=None, 289 | tree_update_every_epochs=None, 290 | tree_update_end_epochs=None, 291 | **kwargs, 292 | ): 293 | fname = "ckpt" 294 | fname += "-" + dataset 295 | fname += "-" + arch 296 | if lr != 0.1: 297 | fname += f"-lr{lr}" 298 | if name: 299 | fname += "-" + name 300 | if path_graph and "TreeSupLoss" in loss: 301 | path = Path(path_graph) 302 | fname += "-" + path.stem.replace("graph-", "", 1) 303 | if include_labels: 304 | labels = ",".join(map(str, include_labels)) 305 | fname += f"-incl{labels}" 306 | if exclude_labels: 307 | labels = ",".join(map(str, exclude_labels)) 308 | fname += f"-excl{labels}" 309 | if include_classes: 310 | labels = ",".join(map(str, include_classes)) 311 | fname += f"-incc{labels}" 312 | if num_samples != 0 and num_samples is not None: 313 | fname += f"-samples{num_samples}" 314 | if len(loss) > 1 or loss[0] != "CrossEntropyLoss": 315 | fname += f'-{",".join(loss)}' 316 | if tree_supervision_weight not in (None, 1): 317 | fname 
+= f"-tsw{tree_supervision_weight}" 318 | if tree_supervision_weight_end not in (tree_supervision_weight, None): 319 | fname += f"-tswe{tree_supervision_weight_end}" 320 | if tree_supervision_weight_power not in (None, 1): 321 | fname += f"-tswp{tree_supervision_weight_power}" 322 | if xent_weight not in (None, 1): 323 | fname += f"-xw{xent_weight}" 324 | if xent_weight_end not in (xent_weight, None): 325 | fname += f"-xwe{xent_weight_end}" 326 | if xent_weight_power not in (None, 1): 327 | fname += f"-xwp{xent_weight_power}" 328 | if "SoftTreeLoss" in loss: 329 | if tree_start_epochs is not None: 330 | fname += f"-tse{tree_start_epochs}" 331 | if tree_update_every_epochs is not None: 332 | fname += f"-tueve{tree_update_every_epochs}" 333 | if tree_update_end_epochs is not None: 334 | fname += f"-tuene{tree_update_end_epochs}" 335 | return fname 336 | 337 | 338 | def coerce_tensor(x, is_label=False): 339 | if is_label: 340 | return x.reshape(-1, 1) 341 | else: 342 | return x.permute(0, 2, 3, 1).reshape(-1, x.shape[1]) 343 | 344 | 345 | def uncoerce_tensor(x, original_shape): 346 | n, c, h, w = original_shape 347 | return x.reshape(n, h, w, c).permute(0, 3, 1, 2) 348 | 349 | -------------------------------------------------------------------------------- /nbdt/wnids/ADE20K.txt: -------------------------------------------------------------------------------- 1 | n04546855 2 | n02913152 3 | n09436708 4 | n03365592 5 | n13104059 6 | n02990373 7 | n04096066 8 | n02818832 9 | n04589745 10 | n12102133 11 | n02933112 12 | n04215402 13 | n00007846 14 | n14842992 15 | n03221720 16 | n04379243 17 | n09359803 18 | n00017222 19 | n03151077 20 | n03001627 21 | n02958343 22 | n09225146 23 | n03876519 24 | n04256520 25 | n04190052 26 | n03544360 27 | n09426788 28 | n03773035 29 | n04118021 30 | n08569998 31 | n02738535 32 | n08647616 33 | n03327234 34 | n03179701 35 | n09416076 36 | n04550184 37 | n03636649 38 | n02808440 39 | n04047401 40 | n03151500 41 | n02797692 42 | n02883344 43 | n03074380 44 | n04217882 45 | n03015254 46 | n03116530 47 | n15019030 48 | n04223580 49 | n04233124 50 | n03346455 51 | n04070727 52 | n03452953 53 | n08616311 54 | n04298171 55 | n04120842 56 | n02975212 57 | n03982430 58 | n03938244 59 | n04153025 60 | n04298308 61 | n09411430 62 | n02898711 63 | n02870880 64 | n02851099 65 | n03063968 66 | n04446521 67 | n11669921 68 | n06410904 69 | n09303008 70 | n02828884 71 | n03118245 72 | n04330340 73 | n12582231 74 | n03620600 75 | n03082979 76 | n04373704 77 | n02858304 78 | n02796995 79 | n04243941 80 | n03547054 81 | n02924116 82 | n04459362 83 | n03665366 84 | n04490091 85 | n04460130 86 | n03005285 87 | n02763901 88 | n04335886 89 | n02873839 90 | n04405907 91 | n02691156 92 | n03205760 93 | n02728440 94 | n03976657 95 | n09335240 96 | n02788148 97 | n03295773 98 | n03858418 99 | n02876657 100 | n02912065 101 | n06793426 102 | n04296562 103 | n04520170 104 | n04194289 105 | n03388043 106 | n03100897 107 | n02951843 108 | n04554684 109 | n03964744 110 | n04371225 111 | n04326896 112 | n02795169 113 | n02801938 114 | n09475292 115 | n04411264 116 | n02773037 117 | n03769722 118 | n03125729 119 | n03862676 120 | n02778669 121 | n07555863 122 | n04314914 123 | n04388743 124 | n06845599 125 | n03761084 126 | n03991062 127 | n00015388 128 | n02834778 129 | n09328904 130 | n03207941 131 | n04152829 132 | n02849154 133 | n04157320 134 | n03531546 135 | n04148936 136 | n04522168 137 | n06874185 138 | n04476259 139 | n02747177 140 | n03320046 141 | n03933529 142 | n04152593 143 | 
n03959485 144 | n03782190 145 | n02916538 146 | n04208936 147 | n04041069 148 | n03438257 149 | n03046257 150 | n03354903 -------------------------------------------------------------------------------- /nbdt/wnids/CIFAR10.txt: -------------------------------------------------------------------------------- 1 | n02691156 2 | n02958343 3 | n01503061 4 | n02121620 5 | n02430045 6 | n02084071 7 | n01639765 8 | n02374451 9 | n04194289 10 | n04490091 -------------------------------------------------------------------------------- /nbdt/wnids/CIFAR100.txt: -------------------------------------------------------------------------------- 1 | n07739125 2 | n02512752 3 | n09827683 4 | n02131653 5 | n02363005 6 | n02818832 7 | n02206856 8 | n02164464 9 | n02834778 10 | n02876657 11 | n02881193 12 | n10285313 13 | n02898711 14 | n02924116 15 | n02274259 16 | n02437136 17 | n02946921 18 | n02980441 19 | n02309337 20 | n02402425 21 | n03001627 22 | n02481823 23 | n03046257 24 | n11439690 25 | n02233338 26 | n04256520 27 | n01976957 28 | n01697178 29 | n03147509 30 | n01699831 31 | n02581957 32 | n02503517 33 | n02657368 34 | n08438533 35 | n02118333 36 | n10129825 37 | n02342885 38 | n03544360 39 | n01877134 40 | n03614007 41 | n03636248 42 | n03649909 43 | n02128385 44 | n02129165 45 | n01674464 46 | n01982650 47 | n10287213 48 | n12752205 49 | n03790512 50 | n09359803 51 | n02330245 52 | n13001041 53 | n12268246 54 | n07747607 55 | n12041446 56 | n02444819 57 | n12582231 58 | n07767847 59 | n03930630 60 | n11608250 61 | n09393605 62 | n03959485 63 | n11900569 64 | n02346627 65 | n01874928 66 | n02324045 67 | n02508021 68 | n01495701 69 | n04096066 70 | n04099429 71 | n12620196 72 | n09426788 73 | n02076196 74 | n01482330 75 | n01891633 76 | n07476495 77 | n04233124 78 | n01944390 79 | n01726692 80 | n01772222 81 | n02355227 82 | n04335435 83 | n11978233 84 | n12901264 85 | n04379243 86 | n04389033 87 | n04401088 88 | n06277280 89 | n02129604 90 | n04465501 91 | n04468005 92 | n07794452 93 | n12454159 94 | n01662784 95 | n04550184 96 | n02062744 97 | n12724942 98 | n02114100 99 | n10787470 100 | n01922303 -------------------------------------------------------------------------------- /nbdt/wnids/LookIntoPerson.txt: -------------------------------------------------------------------------------- 1 | n05933834 2 | n03497657 3 | n05254795 4 | n03441112 5 | n04356056 6 | n04453666 7 | n03236735 8 | n03057021 9 | n04254777 10 | n02854739 11 | n03605598 12 | n04143897 13 | n04231272 14 | n05600637 15 | n05563770 16 | f00000015 17 | n05560787 18 | f00000017 19 | n04199027 20 | f00000019 -------------------------------------------------------------------------------- /nbdt/wnids/PascalContext.txt: -------------------------------------------------------------------------------- 1 | n02691156 2 | n09359803 3 | n02330245 4 | n09387222 5 | n04096066 6 | n02773037 7 | n03769722 8 | n03327234 9 | n02818832 10 | n02820210 11 | n02828884 12 | n02834778 13 | n03201208 14 | n01503061 15 | n00007846 16 | n03365592 17 | n02858304 18 | n04468005 19 | n06410904 20 | n02876657 21 | n13104059 22 | n04587648 23 | n03959485 24 | n03961939 25 | n04405762 26 | n02913152 27 | n02924116 28 | n02933112 29 | n04190052 30 | n11473954 31 | n00017222 32 | n04546855 33 | n02958343 34 | n09334396 35 | n02121620 36 | n04215402 37 | n04490091 38 | n02990373 39 | n09416076 40 | n03001627 41 | n15098161 42 | n00021265 43 | n02374451 44 | n03309808 45 | n06793231 46 | n03082979 47 | n02411705 48 | n03614007 49 | n11669921 50 | n09436708 51 | 
n02403454 52 | n12102133 53 | n03147509 54 | n03151077 55 | n11508382 56 | n14845743 57 | n04256520 58 | n02084071 59 | n03221720 -------------------------------------------------------------------------------- /nbdt/wnids/TinyImagenet200.txt: -------------------------------------------------------------------------------- 1 | n02124075 2 | n04067472 3 | n04540053 4 | n04099969 5 | n07749582 6 | n01641577 7 | n02802426 8 | n09246464 9 | n07920052 10 | n03970156 11 | n03891332 12 | n02106662 13 | n03201208 14 | n02279972 15 | n02132136 16 | n04146614 17 | n07873807 18 | n02364673 19 | n04507155 20 | n03854065 21 | n03838899 22 | n03733131 23 | n01443537 24 | n07875152 25 | n03544143 26 | n09428293 27 | n03085013 28 | n02437312 29 | n07614500 30 | n03804744 31 | n04265275 32 | n02963159 33 | n02486410 34 | n01944390 35 | n09256479 36 | n02058221 37 | n04275548 38 | n02321529 39 | n02769748 40 | n02099712 41 | n07695742 42 | n02056570 43 | n02281406 44 | n01774750 45 | n02509815 46 | n03983396 47 | n07753592 48 | n04254777 49 | n02233338 50 | n04008634 51 | n02823428 52 | n02236044 53 | n03393912 54 | n07583066 55 | n04074963 56 | n01629819 57 | n09332890 58 | n02481823 59 | n03902125 60 | n03404251 61 | n09193705 62 | n03637318 63 | n04456115 64 | n02666196 65 | n03796401 66 | n02795169 67 | n02123045 68 | n01855672 69 | n01882714 70 | n02917067 71 | n02988304 72 | n04398044 73 | n02843684 74 | n02423022 75 | n02669723 76 | n04465501 77 | n02165456 78 | n03770439 79 | n02099601 80 | n04486054 81 | n02950826 82 | n03814639 83 | n04259630 84 | n03424325 85 | n02948072 86 | n03179701 87 | n03400231 88 | n02206856 89 | n03160309 90 | n01984695 91 | n03977966 92 | n03584254 93 | n04023962 94 | n02814860 95 | n01910747 96 | n04596742 97 | n03992509 98 | n04133789 99 | n03937543 100 | n02927161 101 | n01945685 102 | n02395406 103 | n02125311 104 | n03126707 105 | n04532106 106 | n02268443 107 | n02977058 108 | n07734744 109 | n03599486 110 | n04562935 111 | n03014705 112 | n04251144 113 | n04356056 114 | n02190166 115 | n03670208 116 | n02002724 117 | n02074367 118 | n04285008 119 | n04560804 120 | n04366367 121 | n02403003 122 | n07615774 123 | n04501370 124 | n03026506 125 | n02906734 126 | n01770393 127 | n04597913 128 | n03930313 129 | n04118538 130 | n04179913 131 | n04311004 132 | n02123394 133 | n04070727 134 | n02793495 135 | n02730930 136 | n02094433 137 | n04371430 138 | n04328186 139 | n03649909 140 | n04417672 141 | n03388043 142 | n01774384 143 | n02837789 144 | n07579787 145 | n04399382 146 | n02791270 147 | n03089624 148 | n02814533 149 | n04149813 150 | n07747607 151 | n03355925 152 | n01983481 153 | n04487081 154 | n03250847 155 | n03255030 156 | n02892201 157 | n02883205 158 | n03100240 159 | n02415577 160 | n02480495 161 | n01698640 162 | n01784675 163 | n04376876 164 | n03444034 165 | n01917289 166 | n01950731 167 | n03042490 168 | n07711569 169 | n04532670 170 | n03763968 171 | n07768694 172 | n02999410 173 | n03617480 174 | n06596364 175 | n01768244 176 | n02410509 177 | n03976657 178 | n01742172 179 | n03980874 180 | n02808440 181 | n02226429 182 | n02231487 183 | n02085620 184 | n01644900 185 | n02129165 186 | n02699494 187 | n03837869 188 | n02815834 189 | n07720875 190 | n02788148 191 | n02909870 192 | n03706229 193 | n07871810 194 | n03447447 195 | n02113799 196 | n12267677 197 | n03662601 198 | n02841315 199 | n07715103 200 | n02504458 201 | -------------------------------------------------------------------------------- /pytest.ini: 
-------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = --doctest-modules 3 | norecursedirs = .svn _build tmp* *.egg-info __pycache__ examples/app 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorchcv 2 | requests 3 | pyparsing 4 | zipp 5 | torch==1.6.0 6 | torchvision==0.7.0 7 | nltk 8 | scikit-learn 9 | networkx 10 | pytest 11 | opencv-python==4.4.0.42 12 | -------------------------------------------------------------------------------- /scripts/gen_train_eval_nopretrained.ps1: -------------------------------------------------------------------------------- 1 | # Want to train with wordnet hierarchy? Just set `--hierarchy=wordnet` below. 2 | # This script is for networks that do NOT come with a pretrained checkpoint provided either by a model zoo or by the NBDT utility itself. 3 | 4 | $model="ResNet18" 5 | $dataset="CIFAR10" 6 | $weight=1 7 | 8 | # 0. train the baseline neural network 9 | python main.py --dataset=$dataset --arch=$model 10 | 11 | # 1. generate hierarchy -- for models without a pretrained checkpoint, use `checkpoint` 12 | nbdt-hierarchy --dataset=$dataset --checkpoint="./checkpoint/ckpt-$dataset-$model.pth" 13 | 14 | # 2. train with soft tree supervision loss -- for models without a pretrained checkpoint, use `path-resume` OR just train from scratch, without `path-resume` 15 | # python main.py --lr=0.01 --dataset=${dataset} --arch=${model} --hierarchy=induced-${model} --path-resume=./checkpoint/ckpt-${dataset}-${model}.pth --loss=SoftTreeSupLoss --tree-supervision-weight=${weight} # fine-tuning 16 | python main.py --dataset=$dataset --arch=$model --hierarchy=induced-$model --loss=SoftTreeSupLoss --tree-supervision-weight=$weight # training from scratch 17 | 18 | # 3. evaluate with soft then hard inference 19 | $analysisRules = @("SoftEmbeddedDecisionRules", "HardEmbeddedDecisionRules") 20 | 21 | foreach ($analysis in $analysisRules) { 22 | python main.py --dataset=$dataset --arch=$model --hierarchy=induced-$model --loss=SoftTreeSupLoss --eval --resume --analysis=$analysis --tree-supervision-weight=$weight 23 | } 24 | -------------------------------------------------------------------------------- /scripts/gen_train_eval_nopretrained.sh: -------------------------------------------------------------------------------- 1 | # Want to train with wordnet hierarchy? Just set `--hierarchy=wordnet` below. 2 | # This script is for networks that do NOT come with a pretrained checkpoint provided either by a model zoo or by the NBDT utility itself. 3 | 4 | model="ResNet18" 5 | dataset=CIFAR10 6 | weight=1 7 | 8 | # 0. train the baseline neural network 9 | python main.py --dataset=${dataset} --arch=${model} 10 | 11 | # 1. generate hierarchy -- for models without a pretrained checkpoint, use `checkpoint` 12 | nbdt-hierarchy --dataset=${dataset} --checkpoint=./checkpoint/ckpt-${dataset}-${model}.pth 13 | 14 | # 2. 
train with soft tree supervision loss -- for models without a pretrained checkpoint, use `path-resume` OR just train from scratch, without `path-resume` 15 | # python main.py --lr=0.01 --dataset=${dataset} --arch=${model} --hierarchy=induced-${model} --path-resume=./checkpoint/ckpt-${dataset}-${model}.pth --loss=SoftTreeSupLoss --tree-supervision-weight=${weight} # fine-tuning 16 | python main.py --dataset=${dataset} --arch=${model} --hierarchy=induced-${model} --loss=SoftTreeSupLoss --tree-supervision-weight=${weight} # training from scratch 17 | 18 | # 3. evaluate with soft then hard inference 19 | for analysis in SoftEmbeddedDecisionRules HardEmbeddedDecisionRules; do 20 | python main.py --dataset=${dataset} --arch=${model} --hierarchy=induced-${model} --loss=SoftTreeSupLoss --eval --resume --analysis=${analysis} --tree-supervision-weight=${weight} 21 | done 22 | -------------------------------------------------------------------------------- /scripts/gen_train_eval_pretrained.ps1: -------------------------------------------------------------------------------- 1 | # Want to train with wordnet hierarchy? Just set `--hierarchy=wordnet` below. 2 | # This script is for networks that DO come with a pretrained checkpoint provided either by a model zoo or by the NBDT utility itself. 3 | 4 | $model="wrn28_10_cifar10" 5 | $dataset="CIFAR10" 6 | $weight=1 7 | 8 | # 1. generate hierarchy 9 | nbdt-hierarchy --dataset=$dataset --arch=$model 10 | 11 | # 2. train with soft tree supervision loss 12 | python main.py --lr=0.01 --dataset=$dataset --arch=$model --hierarchy=induced-$model --pretrained --loss=SoftTreeSupLoss --tree-supervision-weight=$weight 13 | 14 | # 3. evaluate with soft then hard inference 15 | $analysisRules = @("SoftEmbeddedDecisionRules", "HardEmbeddedDecisionRules") 16 | 17 | foreach ($analysis in $analysisRules) { 18 | python main.py --dataset=$dataset --arch=$model --hierarchy=induced-$model --loss=SoftTreeSupLoss --eval --resume --analysis=$analysis --tree-supervision-weight=$weight 19 | } 20 | -------------------------------------------------------------------------------- /scripts/gen_train_eval_pretrained.sh: -------------------------------------------------------------------------------- 1 | # Want to train with wordnet hierarchy? Just set `--hierarchy=wordnet` below. 2 | # This script is for networks that DO come with a pretrained checkpoint provided either by a model zoo or by the NBDT utility itself. 3 | 4 | model=wrn28_10_cifar10 5 | dataset=CIFAR10 6 | weight=1 7 | 8 | # 1. generate hierarchy 9 | nbdt-hierarchy --dataset=${dataset} --arch=${model} 10 | 11 | # 2. train with soft tree supervision loss 12 | python main.py --lr=0.01 --dataset=${dataset} --arch=${model} --hierarchy=induced-${model} --pretrained --loss=SoftTreeSupLoss --tree-supervision-weight=${weight} 13 | 14 | # 3. evaluate with soft then hard inference 15 | for analysis in SoftEmbeddedDecisionRules HardEmbeddedDecisionRules; do 16 | python main.py --dataset=${dataset} --arch=${model} --hierarchy=induced-${model} --loss=SoftTreeSupLoss --eval --resume --analysis=${analysis} --tree-supervision-weight=${weight} 17 | done 18 | -------------------------------------------------------------------------------- /scripts/gen_train_eval_resnet.ps1: -------------------------------------------------------------------------------- 1 | # Want to train with wordnet hierarchy? Just set `--hierarchy=wordnet` below. 
2 | 3 | $MODELS = @("CIFAR10 1", "CIFAR100 1", "TinyImagenet200 10") 4 | 5 | foreach ($model in $MODELS) { 6 | 7 | $params = $model.split(" ") 8 | 9 | $dataset=$params[0] 10 | $weight=$params[1] 11 | 12 | 13 | 14 | # 1. generate hierarchy 15 | nbdt-hierarchy --dataset=$dataset --arch=ResNet18 16 | 17 | # 2. train with soft tree supervision loss 18 | python main.py --dataset=$dataset --arch=ResNet18 --hierarchy=induced-ResNet18 --loss=SoftTreeSupLoss --tree-supervision-weight=$weight 19 | 20 | # 3. evaluate with soft then hard inference 21 | 22 | $analysisRules = @("SoftEmbeddedDecisionRules", "HardEmbeddedDecisionRules") 23 | 24 | foreach ($analysis in $analysisRules) { 25 | python main.py --dataset=$dataset --arch=ResNet18 --hierarchy=induced-ResNet18 --loss=SoftTreeSupLoss --tree-supervision-weight=$weight --eval --resume --analysis=$analysis 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /scripts/gen_train_eval_resnet.sh: -------------------------------------------------------------------------------- 1 | # Want to train with wordnet hierarchy? Just set `--hierarchy=wordnet` below. 2 | 3 | for i in "CIFAR10 1" "CIFAR100 1" "TinyImagenet200 10"; do 4 | read dataset weight <<< "${i}"; 5 | 6 | # 1. generate hierarchy 7 | nbdt-hierarchy --dataset=${dataset} --arch=ResNet18 8 | 9 | # 2. train with soft tree supervision loss 10 | python main.py --dataset=${dataset} --arch=ResNet18 --hierarchy=induced-ResNet18 --loss=SoftTreeSupLoss --tree-supervision-weight=${weight} 11 | 12 | # 3. evaluate with soft then hard inference 13 | for analysis in SoftEmbeddedDecisionRules HardEmbeddedDecisionRules; do 14 | python main.py --dataset=${dataset} --arch=ResNet18 --hierarchy=induced-ResNet18 --loss=SoftTreeSupLoss --tree-supervision-weight=${weight} --eval --resume --analysis=${analysis} 15 | done 16 | done; 17 | -------------------------------------------------------------------------------- /scripts/gen_train_eval_wideresnet.ps1: -------------------------------------------------------------------------------- 1 | # Want to train with wordnet hierarchy? Just set `--hierarchy=wordnet` below. 2 | 3 | $MODEL_NAME="wrn28_10" 4 | $CIFAR100="CIFAR100" + " " + $MODEL_NAME + "_cifar100 1" 5 | $CIFAR10="CIFAR10 $MODEL_NAME" + "_cifar10 1" 6 | $MODELS=@($CIFAR10, $CIFAR100, "TinyImagenet200 $MODEL_NAME 10") 7 | 8 | foreach ($model in $MODELS) { 9 | 10 | $params = $model.split(" ") 11 | 12 | $dataset=$params[0] 13 | $model=$params[1] 14 | $weight=$params[2] 15 | 16 | # 1. generate hierarchy 17 | nbdt-hierarchy --dataset=$dataset --arch=$model 18 | 19 | # 2. train with soft tree supervision loss 20 | python main.py --lr=0.01 --dataset=$dataset --arch=$model --hierarchy=induced-$model --pretrained --loss=SoftTreeSupLoss --tree-supervision-weight=$weight 21 | 22 | # 3. evaluate with soft then hard inference 23 | $analysisRules = @("SoftEmbeddedDecisionRules", "HardEmbeddedDecisionRules") 24 | 25 | foreach ($analysis in $analysisRules) { 26 | python main.py --dataset=${dataset} --arch=${model} --hierarchy=induced-${model} --loss=SoftTreeSupLoss --eval --resume --analysis=${analysis} --tree-supervision-weight=${weight} 27 | } 28 | } -------------------------------------------------------------------------------- /scripts/gen_train_eval_wideresnet.sh: -------------------------------------------------------------------------------- 1 | # Want to train with wordnet hierarchy? Just set `--hierarchy=wordnet` below. 
2 | 3 | MODEL_NAME="wrn28_10" 4 | for i in "CIFAR10 ${MODEL_NAME}_cifar10 1" "CIFAR100 ${MODEL_NAME}_cifar100 1" "TinyImagenet200 ${MODEL_NAME} 10"; do 5 | read dataset model weight <<< "${i}"; 6 | 7 | # 1. generate hierarchy 8 | nbdt-hierarchy --dataset=${dataset} --arch=${model} 9 | 10 | # 2. train with soft tree supervision loss 11 | python main.py --lr=0.01 --dataset=${dataset} --arch=${model} --hierarchy=induced-${model} --pretrained --loss=SoftTreeSupLoss --tree-supervision-weight=${weight} 12 | 13 | # 3. evaluate with soft then hard inference 14 | for analysis in SoftEmbeddedDecisionRules HardEmbeddedDecisionRules; do 15 | python main.py --dataset=${dataset} --arch=${model} --hierarchy=induced-${model} --loss=SoftTreeSupLoss --eval --resume --analysis=${analysis} --tree-supervision-weight=${weight} 16 | done 17 | done; 18 | -------------------------------------------------------------------------------- /scripts/generate_hierarchies_wordnet.ps1: -------------------------------------------------------------------------------- 1 | python -c "import nltk;nltk.download('wordnet')" 2 | 3 | # Generate WNIDs 4 | $DATASETS = @("CIFAR10", "CIFAR100") 5 | foreach ($dataset in $DATASETS) { 6 | nbdt-wnids --dataset=$dataset 7 | } 8 | 9 | # Generate and test hierarchies 10 | $MORE_DATASETS = @("CIFAR10", "CIFAR100", "TinyImagenet200") 11 | foreach ($dataset in $MORE_DATASETS) { 12 | nbdt-hierarchy --dataset=$dataset --method=wordnet; 13 | } -------------------------------------------------------------------------------- /scripts/generate_hierarchies_wordnet.sh: -------------------------------------------------------------------------------- 1 | python -c "import nltk;nltk.download('wordnet')" 2 | 3 | # Generate WNIDs 4 | for dataset in CIFAR10 CIFAR100; 5 | do 6 | nbdt-wnids --dataset=${dataset} 7 | done; 8 | 9 | # Generate and test hierarchies 10 | for dataset in CIFAR10 CIFAR100 TinyImagenet200; 11 | do 12 | nbdt-hierarchy --dataset=${dataset} --method=wordnet; 13 | done; 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | VERSION = "0.0.4" 4 | 5 | with open("requirements.txt", "r") as f: 6 | install_requires = f.readlines() 7 | 8 | 9 | with open("README.md", "r") as fh: 10 | long_description = fh.read() 11 | 12 | 13 | setuptools.setup( 14 | name="nbdt", 15 | version=VERSION, 16 | author="Alvin Wan", # TODO: proper way to list all paper authors? 17 | author_email="hi@alvinwan.com", 18 | description="Making decision trees competitive with state-of-the-art " 19 | "neural networks on CIFAR10, CIFAR100, TinyImagenet200, " 20 | "Imagenet. 
Transform any image classification neural network " 21 | "into an interpretable neural-backed decision tree.", 22 | long_description=long_description, 23 | long_description_content_type="text/markdown", 24 | url="https://github.com/alvinwan/neural-backed-decision-trees", 25 | packages=setuptools.find_packages(), 26 | install_requires=install_requires, 27 | download_url="https://github.com/alvinwan/neural-backed-decision-trees/archive/%s.zip" 28 | % VERSION, 29 | scripts=["nbdt/bin/nbdt-hierarchy", "nbdt/bin/nbdt-wnids", "nbdt/bin/nbdt"], 30 | classifiers=[ 31 | "Intended Audience :: Developers", 32 | "Programming Language :: Python :: 3", 33 | "License :: OSI Approved :: MIT License", 34 | "Operating System :: OS Independent", 35 | ], 36 | python_requires=">=3.5", 37 | include_package_data=True, 38 | ) 39 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alvinwan/neural-backed-decision-trees/a7a2ee6f735bbc1b3d8c7c4f9ecdd02c6a75fc1e/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn as nn 4 | from nbdt.models import ResNet18 5 | 6 | 7 | collect_ignore = ["setup.py", "main.py"] 8 | 9 | 10 | @pytest.fixture 11 | def label_cifar10(): 12 | return torch.randint(10, (1,)) 13 | 14 | 15 | @pytest.fixture 16 | def input_cifar10(): 17 | return torch.randn(1, 3, 32, 32) 18 | 19 | 20 | @pytest.fixture 21 | def input_cifar100(): 22 | return torch.randn(1, 3, 32, 32) 23 | 24 | 25 | @pytest.fixture 26 | def input_tinyimagenet200(): 27 | return torch.randn(1, 3, 64, 64) 28 | 29 | 30 | @pytest.fixture 31 | def criterion(): 32 | return nn.CrossEntropyLoss() 33 | 34 | 35 | @pytest.fixture 36 | def resnet18_cifar10(): 37 | return ResNet18(num_classes=10) 38 | 39 | 40 | @pytest.fixture 41 | def resnet18_cifar100(): 42 | return ResNet18(num_classes=100) 43 | 44 | 45 | @pytest.fixture 46 | def resnet18_tinyimagenet200(): 47 | return ResNet18(num_classes=200) 48 | -------------------------------------------------------------------------------- /tests/test_inference.py: -------------------------------------------------------------------------------- 1 | """Tests that models work inference-time""" 2 | 3 | from nbdt.model import SoftNBDT, HardNBDT 4 | 5 | 6 | def test_nbdt_soft_cifar10(input_cifar10, resnet18_cifar10): 7 | model_soft = SoftNBDT( 8 | dataset="CIFAR10", model=resnet18_cifar10, hierarchy="induced" 9 | ) 10 | model_soft(input_cifar10) 11 | 12 | 13 | def test_nbdt_soft_cifar100(input_cifar100, resnet18_cifar100): 14 | model_soft = SoftNBDT( 15 | dataset="CIFAR100", model=resnet18_cifar100, hierarchy="induced" 16 | ) 17 | model_soft(input_cifar100) 18 | 19 | 20 | def test_nbdt_soft_tinyimagenet200(input_tinyimagenet200, resnet18_tinyimagenet200): 21 | model_soft = SoftNBDT( 22 | dataset="TinyImagenet200", model=resnet18_tinyimagenet200, hierarchy="induced" 23 | ) 24 | model_soft(input_tinyimagenet200) 25 | 26 | 27 | def test_nbdt_hard_cifar10(input_cifar10, resnet18_cifar10): 28 | model_hard = HardNBDT( 29 | dataset="CIFAR10", model=resnet18_cifar10, hierarchy="induced" 30 | ) 31 | model_hard(input_cifar10) 32 | 33 | 34 | def test_nbdt_hard_cifar100(input_cifar100, resnet18_cifar100): 35 | model_hard = HardNBDT( 36 | dataset="CIFAR100", 
model=resnet18_cifar100, hierarchy="induced" 37 | ) 38 | model_hard(input_cifar100) 39 | 40 | 41 | def test_nbdt_hard_tinyimagenet200(input_tinyimagenet200, resnet18_tinyimagenet200): 42 | model_hard = HardNBDT( 43 | dataset="TinyImagenet200", model=resnet18_tinyimagenet200, hierarchy="induced" 44 | ) 45 | model_hard(input_tinyimagenet200) 46 | -------------------------------------------------------------------------------- /tests/test_train.py: -------------------------------------------------------------------------------- 1 | """Tests that train utilities work as advertised""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from nbdt.loss import SoftTreeSupLoss, HardTreeSupLoss 6 | from nbdt.model import HardNBDT 7 | 8 | 9 | def test_criterion_cifar10(criterion, label_cifar10): 10 | criterion = SoftTreeSupLoss( 11 | dataset="CIFAR10", criterion=criterion, hierarchy="induced" 12 | ) 13 | criterion(torch.randn((1, 10)), label_cifar10) 14 | 15 | 16 | def test_criterion_cifar100(criterion): 17 | criterion = SoftTreeSupLoss( 18 | dataset="CIFAR100", criterion=criterion, hierarchy="induced" 19 | ) 20 | criterion(torch.randn((1, 100)), torch.randint(100, (1,))) 21 | 22 | 23 | def test_criterion_tinyimagenet200(criterion): 24 | criterion = SoftTreeSupLoss( 25 | dataset="TinyImagenet200", criterion=criterion, hierarchy="induced" 26 | ) 27 | criterion(torch.randn((1, 200)), torch.randint(200, (1,))) 28 | 29 | 30 | def test_nbdt_gradient_hard(resnet18_cifar10, input_cifar10, label_cifar10, criterion): 31 | output_cifar10 = resnet18_cifar10(input_cifar10) 32 | assert output_cifar10.requires_grad 33 | 34 | criterion = HardTreeSupLoss( 35 | dataset="CIFAR10", criterion=criterion, hierarchy="induced" 36 | ) 37 | loss = criterion(output_cifar10, label_cifar10) 38 | loss.backward() 39 | 40 | 41 | def test_nbdt_gradient_soft(resnet18_cifar10, input_cifar10, label_cifar10, criterion): 42 | output_cifar10 = resnet18_cifar10(input_cifar10) 43 | assert output_cifar10.requires_grad 44 | 45 | criterion = SoftTreeSupLoss( 46 | dataset="CIFAR10", criterion=criterion, hierarchy="induced" 47 | ) 48 | loss = criterion(output_cifar10, label_cifar10) 49 | loss.backward() 50 | --------------------------------------------------------------------------------
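For orientation, here is a minimal end-to-end usage sketch (not a file in the repository) assembled from the fixtures and calls in tests/conftest.py, tests/test_train.py, and tests/test_inference.py above. The random tensors stand in for a real CIFAR10 batch and the backbone is untrained, so the outputs are illustrative only.

import torch
import torch.nn as nn
from nbdt.loss import SoftTreeSupLoss
from nbdt.model import SoftNBDT
from nbdt.models import ResNet18

# Backbone with one output per CIFAR10 class.
model = ResNet18(num_classes=10)

# Training: wrap the usual cross-entropy criterion with the soft tree-supervision loss.
criterion = SoftTreeSupLoss(dataset="CIFAR10", criterion=nn.CrossEntropyLoss(), hierarchy="induced")
inputs, labels = torch.randn(1, 3, 32, 32), torch.randint(10, (1,))
loss = criterion(model(inputs), labels)
loss.backward()

# Inference: wrap the backbone with an NBDT head; predictions are routed through the induced hierarchy.
nbdt = SoftNBDT(dataset="CIFAR10", model=model, hierarchy="induced")
outputs = nbdt(inputs)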