├── .coveragerc ├── .github └── workflows │ └── python-app.yml ├── .gitignore ├── .gitmodules ├── .isort.cfg ├── .pre-commit-config.yaml ├── AUTHORS.rst ├── CHANGELOG.rst ├── CONTRIBUTORS.md ├── Dockerfile ├── LICENSE.txt ├── README.md ├── aws_backend ├── README.md ├── app.py ├── architecture.png ├── backend │ ├── __init__.py │ └── stack.py ├── cdk.context.json ├── cdk.json ├── requirements.txt └── src │ └── Dockerfile ├── core ├── README.md ├── __init__.py ├── helpers.py ├── hip.py ├── main.py ├── mocks.py ├── nlp.py ├── proto.py ├── proxy.py ├── requirements.txt ├── setup.sh ├── tests │ ├── model_brain_segmentation_sanity.py │ ├── model_medical_label_sanity.py │ ├── model_prefilter_sanity.py │ └── model_vqa_sanity.py └── words.py ├── docs ├── Makefile ├── _static │ └── .gitignore ├── authors.rst ├── changelog.rst ├── conf.py ├── index.rst └── license.rst ├── examples ├── tutorial_01_tensorboard_mnist │ ├── __init__.py │ └── mnist │ │ ├── __init__.py │ │ ├── dataloader.py │ │ ├── demo.py │ │ └── model.py └── tutorial_02_saliency_map │ └── demo.py ├── misc ├── medtorch │ ├── Artboard 2.png │ ├── Artboard 2.svg │ └── Artboard 2@4x.png ├── q&aid-mini │ ├── 1x │ │ ├── Artboard 3.png │ │ ├── Artboard 4.png │ │ └── Artboard 5.png │ ├── 4x │ │ ├── Artboard 3@4x.png │ │ ├── Artboard 4@4x.png │ │ └── Artboard 5@4x.png │ └── SVG │ │ ├── Artboard 3.svg │ │ ├── Artboard 4.svg │ │ └── Artboard 5.svg ├── q&aid.ai ├── q&aid │ ├── Artboard 1.png │ ├── Artboard 1.svg │ └── Artboard 1@4x.png ├── q_aid_logo_small.png └── q_aid_logo_small1.png ├── requirements.txt ├── setup.cfg ├── setup.py ├── src └── pytorchxai │ ├── __init__.py │ ├── plugin │ ├── __init__.py │ ├── pytorchxai │ │ └── static │ │ │ ├── index.js │ │ │ └── style.css │ └── pytorchxai_plugin.py │ └── xai │ ├── __init__.py │ ├── cam_gradcam.py │ ├── cam_scorecam.py │ ├── cam_utils.py │ ├── gradient_guided_backprop.py │ ├── gradient_guided_gradcam.py │ ├── gradient_integrated_grad.py │ ├── gradient_smooth_grad.py │ ├── gradient_vanilla_backprop.py │ ├── utils.py │ └── visualizations.py └── tests ├── conftest.py └── xai ├── __init__.py ├── test_gradcam.py ├── test_guided_backprop.py ├── test_guided_gradcam.py ├── test_integrated_gradients.py ├── test_scorecam.py ├── test_smooth_grad.py ├── test_vanilla_backprop.py └── utils.py /.coveragerc: -------------------------------------------------------------------------------- 1 | # .coveragerc to control coverage.py 2 | [run] 3 | branch = True 4 | source = pytorchxai 5 | # omit = bad_file.py 6 | 7 | [paths] 8 | source = 9 | src/ 10 | */site-packages/ 11 | 12 | [report] 13 | # Regexes for lines to exclude from consideration 14 | exclude_lines = 15 | # Have to re-enable the standard pragma 16 | pragma: no cover 17 | 18 | # Don't complain about missing debug-only code: 19 | def __repr__ 20 | if self\.debug 21 | 22 | # Don't complain if tests don't hit defensive assertion code: 23 | raise AssertionError 24 | raise NotImplementedError 25 | 26 | # Don't complain if non-runnable code isn't run: 27 | if 0: 28 | if __name__ == .__main__.: 29 | -------------------------------------------------------------------------------- /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Q&Aid 5 | 6 | on: [push] 7 | 8 | 
jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v2 13 | - name: Set up Python 3.8 14 | uses: actions/setup-python@v2 15 | with: 16 | python-version: 3.8 17 | - name: Install dependencies 18 | run: | 19 | python -m pip install --upgrade pip 20 | pip install flake8 pytest 21 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 22 | python setup.py install 23 | - name: Lint with flake8 24 | run: | 25 | # stop the build if there are Python syntax errors or undefined names 26 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 27 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 28 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 29 | - name: Test with pytest 30 | run: | 31 | pytest 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Temporary and binary files 2 | *~ 3 | *.py[cod] 4 | *.so 5 | *.cfg 6 | !.isort.cfg 7 | !setup.cfg 8 | *.orig 9 | *.log 10 | *.pot 11 | __pycache__/* 12 | .cache/* 13 | .*.swp 14 | */.ipynb_checkpoints/* 15 | .DS_Store 16 | runs/ 17 | 18 | # Project files 19 | .ropeproject 20 | .project 21 | .pydevproject 22 | .settings 23 | .idea 24 | tags 25 | 26 | # Package files 27 | *.egg 28 | *.eggs/ 29 | .installed.cfg 30 | *.egg-info 31 | 32 | # Unittest and coverage 33 | htmlcov/* 34 | .coverage 35 | .tox 36 | junit.xml 37 | coverage.xml 38 | .pytest_cache/ 39 | 40 | # Build and docs folder/files 41 | build/* 42 | dist/* 43 | sdist/* 44 | docs/api/* 45 | docs/_rst/* 46 | docs/_build/* 47 | cover/* 48 | MANIFEST 49 | 50 | # Per-project virtualenvs 51 | .venv*/ 52 | MICCAI19-MedVQA.full 53 | 54 | # OSX 55 | # 56 | .DS_Store 57 | 58 | # Xcode 59 | # 60 | build/ 61 | *.pbxuser 62 | !default.pbxuser 63 | *.mode1v3 64 | !default.mode1v3 65 | *.mode2v3 66 | !default.mode2v3 67 | *.perspectivev3 68 | !default.perspectivev3 69 | xcuserdata 70 | *.xccheckout 71 | *.moved-aside 72 | DerivedData 73 | *.hmap 74 | *.ipa 75 | *.xcuserstate 76 | 77 | # Android/IntelliJ 78 | # 79 | build/ 80 | .idea 81 | .gradle 82 | local.properties 83 | *.iml 84 | 85 | # node.js 86 | # 87 | node_modules/ 88 | npm-debug.log 89 | yarn-error.log 90 | 91 | # BUCK 92 | buck-out/ 93 | \.buckd/ 94 | *.keystore 95 | !debug.keystore 96 | 97 | # fastlane 98 | # 99 | # It is recommended to not store the screenshots in the git repo. Instead, use fastlane to re-generate the 100 | # screenshots whenever they are needed. 
101 | # For more information about the recommended setup visit: 102 | # https://docs.fastlane.tools/best-practices/source-control/ 103 | 104 | */fastlane/report.xml 105 | */fastlane/Preview.html 106 | */fastlane/screenshots 107 | 108 | # Bundle artifact 109 | *.jsbundle 110 | 111 | # CocoaPods 112 | /ios/Pods/ 113 | 114 | #amplify 115 | amplify/\#current-cloud-backend 116 | amplify/.config/local-* 117 | amplify/mock-data 118 | amplify/backend/amplify-meta.json 119 | amplify/backend/awscloudformation 120 | build/ 121 | dist/ 122 | node_modules/ 123 | awsconfiguration.json 124 | amplifyconfiguration.json 125 | amplify-build-config.json 126 | amplify-gradle-config.json 127 | amplifytools.xcconfig 128 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "core/models"] 2 | path = core/models 3 | url = https://github.com/medtorch/Q-Aid-Models 4 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | line_length=88 3 | indent=' ' 4 | skip=.tox,.venv,build,dist 5 | known_standard_library=setuptools,pkg_resources 6 | known_test=pytest 7 | known_first_party=pytorchxai 8 | sections=FUTURE,STDLIB,COMPAT,TEST,THIRDPARTY,FIRSTPARTY,LOCALFOLDER 9 | default_section=THIRDPARTY 10 | multi_line_output=3 11 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: '^docs/conf.py' 2 | 3 | repos: 4 | - repo: git://github.com/pre-commit/pre-commit-hooks 5 | rev: v2.2.3 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: check-added-large-files 9 | - id: check-ast 10 | - id: check-json 11 | - id: check-merge-conflict 12 | - id: check-xml 13 | - id: check-yaml 14 | - id: debug-statements 15 | - id: end-of-file-fixer 16 | - id: requirements-txt-fixer 17 | - id: mixed-line-ending 18 | args: ['--fix=no'] 19 | - id: flake8 20 | args: ['--max-line-length=88'] # default of Black 21 | 22 | - repo: https://github.com/pre-commit/mirrors-isort 23 | rev: v4.3.4 24 | hooks: 25 | - id: isort 26 | -------------------------------------------------------------------------------- /AUTHORS.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Contributors 3 | ============ 4 | 5 | * Tudor Cebere 6 | * Bogdan Cebere 7 | * George-Cristian Muraru 8 | * Andrei Manolache 9 | -------------------------------------------------------------------------------- /CHANGELOG.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Changelog 3 | ========= 4 | 5 | Version 0.1 6 | =========== 7 | 8 | - Feature A added 9 | - FIX: nasty bug #1729 fixed 10 | - add your changes here! 
11 | -------------------------------------------------------------------------------- /CONTRIBUTORS.md: -------------------------------------------------------------------------------- 1 | [Tudor Cebere](https://github.com/tudorcebere) 2 | 3 | [Bogdan Cebere](https://github.com/bcebere) 4 | 5 | [Andrei Manolache](https://github.com/andreimano) 6 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.8-slim 2 | 3 | RUN mkdir /service 4 | WORKDIR /service 5 | 6 | COPY core . 7 | COPY core/requirements.txt . 8 | 9 | RUN apt-get update 10 | RUN apt-get install python-opencv -y 11 | 12 | RUN pip install --upgrade pip 13 | RUN pip install -r requirements.txt 14 | 15 | EXPOSE 80 16 | CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"] 17 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2020 Tudor Cebere 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | Q&Aid 3 |

4 | 5 | ![License: MIT](https://img.shields.io/badge/License-MIT-green.svg) 6 | ![Q&Aid](https://github.com/medtorch/Q-Aid/workflows/Q&Aid/badge.svg) 7 | 8 | ## Features 9 | 10 | - :fire: Collection of healthcare AI models under [core](core), created using PyTorch. 11 | - :key: Served using [FastAPI](https://fastapi.tiangolo.com/). 12 | - :cyclone: Full deployment scripts for AWS. 13 | - :zap: Compatible React-Native app under the [app](app) folder. 14 | 15 | ## Installation 16 | 17 | ## Usage 18 | 19 | ### Models 20 | 21 | Read more about the models [here](https://github.com/medtorch/Q-Aid-Models). 22 | ### App 23 | 24 | Read more about the app [here](https://github.com/medtorch/Q-Aid-App). 25 | 26 | ### Server 27 | 28 | Read more about the server setup [here](https://github.com/medtorch/Q-Aid-Core/blob/master/core/README.md). 29 | 30 | ### AWS deployment 31 | See the [AWS README](aws_backend/README.md). 32 | 33 | 34 | ## Contributors 35 | 36 | See [CONTRIBUTORS.md](CONTRIBUTORS.md). 37 | 38 | ## License 39 | [MIT License](https://choosealicense.com/licenses/mit/) 40 | -------------------------------------------------------------------------------- /aws_backend/README.md: -------------------------------------------------------------------------------- 1 |

2 | Q&Aid 3 |

4 | 5 | ![License: MIT](https://img.shields.io/badge/License-MIT-green.svg) 6 | ![Q&Aid](https://github.com/medtorch/Q-Aid/workflows/Q&Aid/badge.svg) 7 | 8 | # AWS backend deployment 9 | 10 | ## Introduction 11 | 12 | The code is used to deploy the Q&Aid backend in AWS. 13 | 14 | The code is inspired by https://aws-blog.de/2020/03/building-a-fargate-based-container-app-with-cognito-authentication.html 15 | 16 | This CDK app sets up the infrastructure used for the integration with the Application Load Balancer. Furthermore, it includes not only the login but also the logout workflow. 17 | 18 | ## Architecture 19 | 20 | ![Architecture](architecture.png) 21 | 22 | This stack provisions the following resources: 23 | 24 | - A DNS record for the application. 25 | - An SSL/TLS certificate. 26 | - An Application Load Balancer with that DNS record and certificate. 27 | - An ECR container registry to push our Docker image to. 28 | - An ECS Fargate service to run our Q&Aid backend. 29 | 30 | ## Prerequisites 31 | 32 | - CDK is installed. 33 | - Docker is installed. 34 | - You have a public hosted zone in your account (you can use Route53 for that). 35 | 36 | ## Steps to deploy 37 | 38 | 1. Review the variables in `backend/stack.py` and edit these variables as described in the [blog article](https://aws-blog.de/2020/03/building-a-fargate-based-container-app-with-cognito-authentication.html): 39 | 40 | ```python 41 | APP_DNS_NAME = "q-and-aid.com" 42 | HOSTED_ZONE_ID = "Z09644041TWEBPC10I0YZ" 43 | HOSTED_ZONE_NAME = "q-and-aid.com" 44 | ``` 45 | 2. Make sure you have a valid AWS profile. You can generate one using 46 | ```amplify configure``` 47 | 3. Run `cdk --profile medqaid-profile bootstrap aws://unknown-account/eu-central-1` to bootstrap the environment. 48 | 4. Run `cdk synth` to check that the CDK app works as expected; you can inspect the generated template if you're curious. 49 | 5. Run `cdk deploy` to deploy the resources (see the consolidated command sketch below).
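For convenience, the steps above can be chained as in the sketch below. This is a minimal sketch, not a verified script: it assumes the commands are run from the `aws_backend` folder (where `cdk.json` lives), that the AWS profile is named `medqaid-profile` as in step 3, and that the same `--profile` flag is also passed to `synth` and `deploy`; the `aws://unknown-account/eu-central-1` placeholder is kept verbatim from step 3 and should be replaced with your own account id and region.

```bash
# Consolidated deploy sketch -- assumptions noted above.
cd aws_backend

# One-time per account/region: provision the CDK bootstrap resources.
cdk --profile medqaid-profile bootstrap aws://unknown-account/eu-central-1

# Synthesize the CloudFormation template so it can be reviewed before deploying.
cdk --profile medqaid-profile synth

# Build the Docker image in src/ and deploy the Fargate service behind the ALB.
cdk --profile medqaid-profile deploy
```

Keeping `synth` as a separate step lets you diff the generated template before anything is actually deployed.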
50 | -------------------------------------------------------------------------------- /aws_backend/app.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from aws_cdk import core 4 | 5 | from backend.stack import FargateStack 6 | 7 | 8 | app = core.App() 9 | FargateStack(app, "med-qaid-core-backend-v3", env={"region": "eu-central-1"}) 10 | 11 | app.synth() 12 | -------------------------------------------------------------------------------- /aws_backend/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/medtorch/Q-Aid-Core/9ea70c1a1ab66323aa21484a8066512c9cd4fc43/aws_backend/architecture.png -------------------------------------------------------------------------------- /aws_backend/backend/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/medtorch/Q-Aid-Core/9ea70c1a1ab66323aa21484a8066512c9cd4fc43/aws_backend/backend/__init__.py -------------------------------------------------------------------------------- /aws_backend/backend/stack.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib.parse 3 | 4 | from aws_cdk import core 5 | 6 | import aws_cdk.aws_certificatemanager as certificatemanager 7 | import aws_cdk.aws_cognito as cognito 8 | import aws_cdk.aws_ec2 as ec2 9 | import aws_cdk.aws_ecs as ecs 10 | import aws_cdk.aws_ecs_patterns as ecs_patterns 11 | import aws_cdk.aws_ecr_assets as ecr_assets 12 | import aws_cdk.aws_elasticloadbalancingv2 as elb 13 | import aws_cdk.aws_route53 as route53 14 | 15 | 16 | APP_DNS_NAME = "q-and-aid.com" 17 | HOSTED_ZONE_ID = "Z09644041TWEBPC10I0YZ" 18 | HOSTED_ZONE_NAME = "q-and-aid.com" 19 | 20 | 21 | class FargateStack(core.Stack): 22 | def __init__(self, scope: core.Construct, id: str, **kwargs) -> None: 23 | super().__init__(scope, id, **kwargs) 24 | 25 | # Get the hosted Zone and create a certificate for our domain 26 | 27 | hosted_zone = route53.HostedZone.from_hosted_zone_attributes( 28 | self, 29 | "HostedZone", 30 | hosted_zone_id=HOSTED_ZONE_ID, 31 | zone_name=HOSTED_ZONE_NAME, 32 | ) 33 | 34 | cert = certificatemanager.DnsValidatedCertificate( 35 | self, "Certificate", hosted_zone=hosted_zone, domain_name=APP_DNS_NAME 36 | ) 37 | 38 | # Set up a new VPC 39 | 40 | vpc = ec2.Vpc(self, "med-qaid-vpc", max_azs=2) 41 | 42 | # Set up an ECS Cluster for fargate 43 | 44 | cluster = ecs.Cluster(self, "med-qaid-cluster", vpc=vpc) 45 | 46 | # Define the Docker Image for our container (the CDK will do the build and push for us!) 47 | docker_image = ecr_assets.DockerImageAsset( 48 | self, 49 | "med-qaid-app", 50 | directory=os.path.join(os.path.dirname(__file__), "..", "src"), 51 | ) 52 | 53 | # Define the fargate service + ALB 54 | 55 | fargate_service = ecs_patterns.ApplicationLoadBalancedFargateService( 56 | self, 57 | "FargateService", 58 | cluster=cluster, 59 | certificate=cert, 60 | domain_name=f"{APP_DNS_NAME}", 61 | domain_zone=hosted_zone, 62 | cpu=2048, 63 | memory_limit_mib=16384, 64 | task_image_options={ 65 | "image": ecs.ContainerImage.from_docker_image_asset(docker_image), 66 | "environment": {"PORT": "80",}, 67 | }, 68 | ) 69 | 70 | # Allow 10 seconds for in flight requests before termination, the default of 5 minutes is much too high. 
71 | fargate_service.target_group.set_attribute( 72 | key="deregistration_delay.timeout_seconds", value="10" 73 | ) 74 | -------------------------------------------------------------------------------- /aws_backend/cdk.context.json: -------------------------------------------------------------------------------- 1 | { 2 | "@aws-cdk/core:enableStackNameDuplicates": "true", 3 | "aws-cdk:enableDiffNoFail": "true" 4 | } 5 | -------------------------------------------------------------------------------- /aws_backend/cdk.json: -------------------------------------------------------------------------------- 1 | { 2 | "app": "python3 app.py" 3 | } 4 | -------------------------------------------------------------------------------- /aws_backend/requirements.txt: -------------------------------------------------------------------------------- 1 | -e . 2 | -------------------------------------------------------------------------------- /aws_backend/src/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM bcebere/qaid:latest 2 | 3 | WORKDIR /service 4 | EXPOSE 80 5 | CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"] 6 | -------------------------------------------------------------------------------- /core/README.md: -------------------------------------------------------------------------------- 1 |

2 | Q&Aid 3 |

4 | 5 | 6 | # Core logic 7 | 8 | ## Introduction 9 | 10 | Scripts for testing and deploying the core logic behind Q&Aid. 11 | 12 | 13 | ## Prerequisites 14 | 15 | ``` 16 | ./setup.sh 17 | ``` 18 | 19 | Download the VQA model from https://drive.google.com/file/d/1dqJjthrbdnIs41ZdC_ZGVQnoZbuGMNCR/view?usp=sharing 20 | and save it to the path models/model_vqa/MICCAI19-MedVQA/saved_models/BAN_MEVF/model_epoch19.pth 21 | 22 | ## Run the server 23 | 24 | ``` 25 | uvicorn main:app --host 0.0.0.0 --port 8000 26 | ``` 27 | 28 | GET http://127.0.0.1:8000/capabilities should return the list of available models. 29 | 30 | ## Tests 31 | 32 | Run the scripts in the `tests` folder for checking each model. 33 | The tests require a running server instance. 34 | 35 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/medtorch/Q-Aid-Core/9ea70c1a1ab66323aa21484a8066512c9cd4fc43/core/__init__.py -------------------------------------------------------------------------------- /core/helpers.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | 3 | 4 | def hash_input(val): 5 | hash_object = hashlib.sha1(val.encode("utf-8")) 6 | return hash_object.digest() 7 | -------------------------------------------------------------------------------- /core/hip.py: -------------------------------------------------------------------------------- 1 | import base64 2 | 3 | from models.model_vqa.inference import VQA 4 | from models.model_brain_segmentation.inference import Segmentation 5 | 6 | 7 | class HealthIntelProviderLocal: 8 | def __init__(self, name, capabilities): 9 | self.name = name 10 | self.capabilities = capabilities 11 | 12 | self.cache = {} 13 | self.models = {} 14 | 15 | for feat in capabilities: 16 | self.cache[feat] = {} 17 | if feat == "vqa": 18 | self.models[feat] = VQA() 19 | elif feat == "segmentation": 20 | print("loading segmentation capability") 21 | self.models[feat] = Segmentation() 22 | else: 23 | raise "not implemented" 24 | 25 | def vqa(self, question: str, image_b64: str, topic: str): 26 | if not self.supports("vqa", topic): 27 | raise NotImplementedError() 28 | 29 | results = {} 30 | try: 31 | result = self.models["vqa"].ask(question, image_b64) 32 | results["vqa"] = result 33 | except BaseException as e: 34 | print("vqa failed ", e) 35 | 36 | return results 37 | 38 | def segment(self, image_b64: str, topic: str): 39 | if not self.supports("segmentation", topic): 40 | raise NotImplementedError() 41 | 42 | results = {} 43 | try: 44 | result = base64.b64encode( 45 | self.models["segmentation"].ask(image_b64) 46 | ).decode() 47 | results["segmentation"] = result 48 | except BaseException as e: 49 | print("segmentation failed ", e) 50 | 51 | return results 52 | 53 | def supports(self, model: str, topic: str): 54 | if model not in self.capabilities: 55 | return False 56 | if topic not in self.capabilities[model]: 57 | return False 58 | 59 | return True 60 | -------------------------------------------------------------------------------- /core/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from fastapi import FastAPI 4 | 5 | from proto import QuestionProto, ImageProto, NLPProto 6 | from proxy import Proxy, Filter 7 | from mocks import generate_mocks 8 | from nlp import NLP 9 | 10 | app = FastAPI() 11 | proxy = Proxy() 12 | 
nlp = NLP() 13 | 14 | for hip_mock in generate_mocks(): 15 | proxy.register(hip_mock) 16 | 17 | 18 | @app.get("/sources") 19 | def get_sources(): 20 | return proxy.sources() 21 | 22 | 23 | @app.get("/capabilities") 24 | def get_sources(): 25 | return proxy.capabilities() 26 | 27 | 28 | @app.post("/vqa") 29 | def vqa_task(q: QuestionProto): 30 | try: 31 | prefilter = proxy.prefilter(q.image_b64) 32 | if not prefilter["valid"]: 33 | return {"error": "invalid input"} 34 | 35 | result = proxy.ask(q.question, q.image_b64, prefilter["topic"]) 36 | result = proxy.aggregate(result) 37 | return {"answer": result} 38 | except BaseException as e: 39 | return {"error": str(e)} 40 | 41 | 42 | @app.post("/segmentation") 43 | def segmentation_task(q: ImageProto): 44 | try: 45 | prefilter = proxy.prefilter(q.image_b64) 46 | if not prefilter["valid"]: 47 | return {"error": "invalid input"} 48 | 49 | result = proxy.segment(q.image_b64, prefilter["topic"]) 50 | return {"answer": result} 51 | except BaseException as e: 52 | return {"error": str(e)} 53 | 54 | 55 | @app.post("/prefilter") 56 | def prefilter_task(q: ImageProto): 57 | try: 58 | result = proxy.prefilter(q.image_b64) 59 | 60 | if not result["valid"]: 61 | return {"answer": result} 62 | 63 | result["anomalies"] = proxy.anomalies(q.image_b64, result["topic"]) 64 | 65 | return {"answer": result} 66 | except BaseException as e: 67 | return {"error": str(e)} 68 | 69 | 70 | @app.post("/nlp") 71 | def nlp_task(q: NLPProto): 72 | try: 73 | return {"answer": nlp.ask(q.data)} 74 | except BaseException as e: 75 | return {"error": str(e)} 76 | -------------------------------------------------------------------------------- /core/mocks.py: -------------------------------------------------------------------------------- 1 | from hip import HealthIntelProviderLocal 2 | 3 | full_topics = [ 4 | "xr_elbow", 5 | "xr_forearm", 6 | "xr_hand", 7 | "xr_hummerus", 8 | "xr_shoulder", 9 | "xr_wrist", 10 | "xr_chest", 11 | "scan_brain", 12 | "scan_breast", 13 | "scan_eyes", 14 | "scan_heart", 15 | ] 16 | 17 | 18 | def generate_mocks(): 19 | return [ 20 | HealthIntelProviderLocal("Gotham General Hospital", {"vqa": ["scan_brain"]}), 21 | HealthIntelProviderLocal( 22 | "Metropolis General Hospital", 23 | {"vqa": ["scan_brain", "xr_chest"], "segmentation": ["scan_brain"]}, 24 | ), 25 | HealthIntelProviderLocal("Smallville Medical Center", {"vqa": ["xr_chest"]}), 26 | HealthIntelProviderLocal("Mercy General Hospital", {"vqa": ["xr_chest"]}), 27 | HealthIntelProviderLocal( 28 | "St. 
Mary's Hospital", {"segmentation": ["scan_brain"]} 29 | ), 30 | ] 31 | -------------------------------------------------------------------------------- /core/nlp.py: -------------------------------------------------------------------------------- 1 | import re 2 | import nltk 3 | import string 4 | from words import greeting_inputs, definition_inputs, vqa_filter 5 | 6 | nltk.download("punkt") 7 | nltk.download("wordnet") 8 | 9 | 10 | class NLP: 11 | def __init__(self): 12 | self.wnlemmatizer = nltk.stem.WordNetLemmatizer() 13 | self.punctuation_removal = dict( 14 | (ord(punctuation), None) for punctuation in string.punctuation 15 | ) 16 | 17 | def perform_lemmatization(self, tokens): 18 | return [self.wnlemmatizer.lemmatize(token) for token in tokens] 19 | 20 | def get_processed_text(self, document): 21 | return self.perform_lemmatization( 22 | nltk.word_tokenize(document.lower().translate(self.punctuation_removal)) 23 | ) 24 | 25 | def is_greeting(self, words, query): 26 | if query in greeting_inputs: 27 | return True 28 | for word in words: 29 | if word in greeting_inputs: 30 | return True 31 | return False 32 | 33 | def is_definition(self, words, query): 34 | for sep in definition_inputs: 35 | if sep in query: 36 | return query.split(sep)[1] 37 | return None 38 | 39 | def is_vqa_safe(self, words, query): 40 | for word in words: 41 | if word in vqa_filter: 42 | return True 43 | return False 44 | 45 | def ask(self, query): 46 | words = self.get_processed_text(query) 47 | query = " ".join(words) 48 | 49 | if self.is_greeting(words, query): 50 | return {"type": "greeting"} 51 | 52 | define = self.is_definition(words, query) 53 | if define: 54 | return { 55 | "type": "wiki", 56 | "data": "".join(self.get_processed_text(define)), 57 | } 58 | if self.is_vqa_safe(words, query): 59 | return { 60 | "type": "vqa", 61 | } 62 | return { 63 | "type": "invalid", 64 | } 65 | -------------------------------------------------------------------------------- /core/proto.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class QuestionProto(BaseModel): 5 | image_b64: str 6 | question: str 7 | 8 | 9 | class ImageProto(BaseModel): 10 | image_b64: str 11 | 12 | 13 | class NLPProto(BaseModel): 14 | data: str 15 | -------------------------------------------------------------------------------- /core/proxy.py: -------------------------------------------------------------------------------- 1 | from models.model_prefilter.inference import Prefilter 2 | from models.model_medical_label.inference import ImageRouter 3 | 4 | from helpers import hash_input 5 | 6 | 7 | class Filter: 8 | def __init__(self): 9 | self.prefilter = Prefilter() 10 | self.router = ImageRouter() 11 | self.cache = {} 12 | 13 | def ask(self, image_b64: str): 14 | h = hash_input(image_b64) 15 | if h in self.cache: 16 | return self.cache[h] 17 | 18 | valid = 0 == self.prefilter.ask(image_b64) 19 | 20 | result = { 21 | "valid": valid, 22 | } 23 | 24 | if valid: 25 | result["topic"] = self.router.ask(image_b64) 26 | 27 | self.cache[h] = result 28 | 29 | return result 30 | 31 | 32 | class Proxy: 33 | def __init__(self): 34 | self.providers = [] 35 | self.filter = Filter() 36 | 37 | def register(self, handler): 38 | print("registering ", handler.name) 39 | self.providers.append(handler) 40 | 41 | def sources(self): 42 | result = [] 43 | for provider in self.providers: 44 | result.append(provider.name) 45 | return result 46 | 47 | def capabilities(self): 48 | result = {} 49 | 
for provider in self.providers: 50 | for model in provider.capabilities: 51 | if model not in result: 52 | result[model] = [] 53 | for topic in provider.capabilities[model]: 54 | if topic not in result[model]: 55 | result[model].append(topic) 56 | return result 57 | 58 | def aggregate(self, results): 59 | data = results["hip"] 60 | results["total"] = 0 61 | results["aggregated"] = {} 62 | 63 | for provider in data: 64 | results["total"] += 1 65 | for model in data[provider]: 66 | if model not in results["aggregated"]: 67 | results["aggregated"][model] = {} 68 | val = data[provider][model].lower() 69 | if val not in results["aggregated"][model]: 70 | results["aggregated"][model][val] = 0 71 | results["aggregated"][model][val] += 1 72 | return results 73 | 74 | def prefilter(self, image_b64: str): 75 | return self.filter.ask(image_b64) 76 | 77 | def ask(self, question: str, image_b64: str, topic: str): 78 | results = {"hip": {}} 79 | for provider in self.providers: 80 | if not provider.supports("vqa", topic): 81 | continue 82 | 83 | results["hip"][provider.name] = provider.vqa(question, image_b64, topic) 84 | 85 | return results 86 | 87 | def segment(self, image_b64: str, topic: str): 88 | results = {"hip": {}} 89 | for provider in self.providers: 90 | if not provider.supports("segmentation", topic): 91 | continue 92 | results["hip"][provider.name] = provider.segment(image_b64, topic) 93 | 94 | return results 95 | 96 | def anomalies(self, image_b64: str, topic: str): 97 | filter_q = "is there something abnormal in the image?" 98 | questions = { 99 | "what": "what is abnormal in the image?", 100 | "why": "why is this abnormal?", 101 | "where": "where is something abnormal?", 102 | } 103 | 104 | has_anomaly = self.ask(filter_q, image_b64, topic) 105 | has_anomaly = self.aggregate(has_anomaly) 106 | 107 | results = {"has": 0, "total": has_anomaly["total"]} 108 | 109 | if has_anomaly["total"] == 0: 110 | return results 111 | 112 | if "yes" not in has_anomaly["aggregated"]["vqa"]: 113 | return results 114 | 115 | results["has"] = has_anomaly["aggregated"]["vqa"]["yes"] 116 | for qtype in questions: 117 | res = self.ask(questions[qtype], image_b64, topic) 118 | res = self.aggregate(res) 119 | 120 | results[qtype] = res["aggregated"]["vqa"] 121 | 122 | return results 123 | -------------------------------------------------------------------------------- /core/requirements.txt: -------------------------------------------------------------------------------- 1 | bunch 2 | dataset 3 | fastapi 4 | h5py 5 | numpy 6 | opencv-python 7 | pandas 8 | Pillow 9 | torch 10 | torchvision 11 | uvicorn 12 | medpy 13 | scikit-image 14 | nltk 15 | -------------------------------------------------------------------------------- /core/setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | 4 | pip install -r requirements.txt 5 | 6 | git submodule init && git submodule update 7 | 8 | cd models/model_vqa && git submodule init && git submodule update 9 | cd - 10 | 11 | 12 | -------------------------------------------------------------------------------- /core/tests/model_brain_segmentation_sanity.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import os 3 | import requests 4 | import json 5 | from PIL import Image 6 | from io import BytesIO 7 | from skimage.io import imsave 8 | import cv2 9 | import numpy as np 10 | 11 | samples = "./samples/" 12 | 13 | requests_session = requests.Session() 14 | server = 
"http://127.0.0.1:8000/segmentation" 15 | 16 | 17 | for subdir, dirs, files in os.walk(samples): 18 | for f in files: 19 | path = subdir + f 20 | print(path) 21 | 22 | with open(path, "rb") as image_file: 23 | encoded_string = base64.b64encode(image_file.read()).decode() 24 | 25 | payload = { 26 | "image_b64": encoded_string, 27 | } 28 | 29 | r = requests_session.post(server, json=payload, timeout=10) 30 | 31 | data = json.loads(r.text) 32 | output = data["answer"]["hip"] 33 | 34 | for source in output: 35 | string = output[source]["segmentation"] 36 | decoded = base64.b64decode(string) 37 | decoded = BytesIO(decoded) 38 | img = Image.open(decoded) 39 | -------------------------------------------------------------------------------- /core/tests/model_medical_label_sanity.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from io import BytesIO 3 | import base64 4 | import json 5 | 6 | medical_imgs = [ 7 | "https://media.sciencephoto.com/image/c0371577/800wm/C0371577-Stroke,_MRI_brain_scan.jpg", 8 | "https://prod-images-static.radiopaedia.org/images/34839897/e0bfac31c00d077d18aca7ab33b495_gallery.jpeg", 9 | "https://prod-images-static.radiopaedia.org/images/157210/332aa0c67cb2e035e372c7cb3ceca2_jumbo.jpg", 10 | "https://image.freepik.com/photos-gratuite/technologie-rayon-chirurgie-x-ray-xray_1172-444.jpg", 11 | "https://media.wired.com/photos/5ba015a0ab6e142d95f93dac/125:94/w_1196,h_900,c_limit/R.Kim-eyescan-w.jpg", 12 | "https://encrypted-tbn0.gstatic.com/images?q=tbn%3AANd9GcQ4IevMy0w_3XO3Wc-PNRB5lBgwqvoSSttiAw&usqp=CAU", 13 | "https://www.startradiology.com/uploads/images/english-class-x-elbow-fig-5-normal-anatomy-elbow-lateral-blanco.jpg", 14 | ] 15 | 16 | 17 | requests_session = requests.Session() 18 | server = "http://127.0.0.1:8000/prefilter" 19 | 20 | 21 | for medical_img in medical_imgs: 22 | response = requests.get(medical_img) 23 | img = BytesIO(response.content).getvalue() 24 | encoded_string = base64.b64encode(img).decode() 25 | 26 | payload = { 27 | "image_b64": encoded_string, 28 | } 29 | 30 | r = requests_session.post(server, json=payload, timeout=10) 31 | 32 | data = json.loads(r.text) 33 | print(data) 34 | output = data["answer"] 35 | -------------------------------------------------------------------------------- /core/tests/model_prefilter_sanity.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from io import BytesIO 3 | import base64 4 | import json 5 | 6 | medical_imgs = [ 7 | "https://media.sciencephoto.com/image/c0371577/800wm/C0371577-Stroke,_MRI_brain_scan.jpg", 8 | "https://prod-images-static.radiopaedia.org/images/34839897/e0bfac31c00d077d18aca7ab33b495_gallery.jpeg", 9 | "https://prod-images-static.radiopaedia.org/images/157210/332aa0c67cb2e035e372c7cb3ceca2_jumbo.jpg", 10 | "https://www.mqmi.com.au/wp-content/uploads/2019/10/CT-CORONARY-ANGIOGRAM-Severe-1.jpg", 11 | "http://www.medicalradiation.com/wp-content/uploads/fluoroscopy.jpg", 12 | "https://image.freepik.com/photos-gratuite/technologie-rayon-chirurgie-x-ray-xray_1172-444.jpg", 13 | "https://prod-images-static.radiopaedia.org/images/51665621/badcab5dfbb1423245a3343156b347_big_gallery.jpeg", 14 | ] 15 | 16 | nonmedical_imgs = [ 17 | "https://i.pinimg.com/originals/e0/3d/5b/e03d5b812b2734826f76960eca5b5541.jpg", 18 | "https://i.pinimg.com/originals/82/61/79/826179defbbdbc3ec7fdc37e15ea6bab.jpg", 19 | 
"https://www.lifesavvy.com/thumbcache/0/0/31c7385df31261da25272193d5226120/p/uploads/2019/05/daf3eeae-3.jpg", 20 | "https://acumass.com/wp-content/uploads/2016/02/selfie-pay.jpeg", 21 | ] 22 | 23 | 24 | requests_session = requests.Session() 25 | server = "http://127.0.0.1:8000/prefilter" 26 | 27 | 28 | print("Medical") 29 | for medical_img in medical_imgs: 30 | response = requests.get(medical_img) 31 | img = BytesIO(response.content).getvalue() 32 | encoded_string = base64.b64encode(img).decode() 33 | 34 | payload = { 35 | "image_b64": encoded_string, 36 | } 37 | 38 | r = requests_session.post(server, json=payload, timeout=10) 39 | 40 | data = json.loads(r.text) 41 | print(data) 42 | output = data["answer"] 43 | print(output) 44 | 45 | print("Nonmedical") 46 | for non_medical_img in nonmedical_imgs: 47 | response = requests.get(non_medical_img) 48 | img = BytesIO(response.content).getvalue() 49 | encoded_string = base64.b64encode(img).decode() 50 | 51 | payload = { 52 | "image_b64": encoded_string, 53 | } 54 | 55 | r = requests_session.post(server, json=payload, timeout=10) 56 | 57 | data = json.loads(r.text) 58 | output = data["answer"] 59 | print(output) 60 | -------------------------------------------------------------------------------- /core/tests/model_vqa_sanity.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | from pathlib import Path 4 | 5 | import requests 6 | 7 | model_root = Path("../models/model_vqa/MICCAI19-MedVQA") 8 | 9 | data = json.load(open(model_root / "data_RAD/testset.json")) 10 | img_folder = model_root / "data_RAD/images/" 11 | 12 | questions = {} 13 | for entry in data: 14 | if entry["image_organ"] not in questions: 15 | questions[entry["image_organ"]] = [] 16 | 17 | question = entry["question"] 18 | filename = img_folder / entry["image_name"] 19 | 20 | with open(filename, "rb") as image_file: 21 | encoded_string = base64.b64encode(image_file.read()).decode() 22 | 23 | obj = { 24 | "image_b64": encoded_string, 25 | "name": entry["image_name"], 26 | "question": question, 27 | "expected_answer": entry["answer"], 28 | } 29 | questions[entry["image_organ"]].append(obj) 30 | 31 | fail = 0 32 | ok = 0 33 | total = 0 34 | 35 | requests_session = requests.Session() 36 | server = "http://127.0.0.1:8000" 37 | 38 | for tag in questions: 39 | if tag != "HEAD" and tag != "CHEST": 40 | continue 41 | 42 | for q in questions[tag]: 43 | payload = { 44 | "image_b64": q["image_b64"], 45 | } 46 | r = requests_session.post(server + "/prefilter", json=payload, timeout=10) 47 | 48 | data = json.loads(r.text) 49 | print(data) 50 | result = data["answer"] 51 | 52 | if not result["valid"]: 53 | continue 54 | 55 | payload = { 56 | "question": q["question"], 57 | "image_b64": q["image_b64"], 58 | } 59 | r = requests_session.post(server + "/vqa", json=payload, timeout=10) 60 | 61 | data = json.loads(r.text) 62 | print(data) 63 | result = data["answer"] 64 | 65 | matching = 0 66 | for hospital in result["hip"]: 67 | expected = str(q["expected_answer"]).lower() 68 | actual = str(result["hip"][hospital]["vqa"]).lower() 69 | if expected == actual: 70 | matching += 1 71 | 72 | total += 1 73 | if matching == 0: 74 | fail += 1 75 | print("FAIL: ", q["name"], q["expected_answer"], result) 76 | else: 77 | ok += 1 78 | print("OK: ", q["name"], " q:", q["question"], " a:", expected) 79 | 80 | print("Total ", total) 81 | print("ok ", ok) 82 | print("failed ", fail) 83 | 
-------------------------------------------------------------------------------- /core/words.py: -------------------------------------------------------------------------------- 1 | greeting_inputs = [ 2 | "hi", 3 | "hei", 4 | "hello", 5 | "hey", 6 | "helloo", 7 | "hellooo", 8 | "g morining", 9 | "gmorning", 10 | "good morning", 11 | "morning", 12 | "good day", 13 | "good afternoon", 14 | "good evening", 15 | "greetings", 16 | "greeting", 17 | "good to see you", 18 | "its good seeing you", 19 | "how are you", 20 | "howre you", 21 | "how are you doing", 22 | "how ya doin", 23 | "how ya doin", 24 | "how is everything", 25 | "how is everything going", 26 | "hows everything going", 27 | "how is you", 28 | "hows you", 29 | "how are things", 30 | "howre things", 31 | "how is it going", 32 | "hows it going", 33 | "hows it goin", 34 | "hows it goin", 35 | "how is life been treating you", 36 | "hows life been treating you", 37 | "how have you been", 38 | "howve you been", 39 | "what is up", 40 | "whats up", 41 | "what is cracking", 42 | "whats cracking", 43 | "what is good", 44 | "whats good", 45 | "what is happening", 46 | "whats happening", 47 | "what is new", 48 | "whats new", 49 | "what is neww", 50 | "gday", 51 | "howdy", 52 | ] 53 | 54 | definition_inputs = [ 55 | "what is ", 56 | "characterize ", 57 | "describe ", 58 | "detail", 59 | "what is the meaning of", 60 | "meaning of", 61 | "what is the definition of", 62 | "definition of", 63 | "how do you define", 64 | "how would you define", 65 | "explain", 66 | "how do you explain", 67 | "how do you define", 68 | "what is", 69 | "characterize", 70 | "how do you characterize", 71 | "how would you characterize", 72 | "describe", 73 | "how do you describe", 74 | "how would you describe", 75 | "detail", 76 | "how do you detail", 77 | "how would you detail", 78 | "explain", 79 | "how do you explain", 80 | "how would you explain", 81 | "meaning of ", 82 | "definition of ", 83 | "how do you define", 84 | "explain ", 85 | "define ", 86 | ] 87 | 88 | 89 | vqa_filter = [ 90 | "regions", 91 | "brain", 92 | "infarcted", 93 | "abnormality", 94 | "pathology", 95 | "image", 96 | "imaging", 97 | "organ", 98 | "system", 99 | "pictured", 100 | "swelling", 101 | "grey", 102 | "matter", 103 | "plane", 104 | "oriented", 105 | "skull", 106 | "fracture", 107 | "mass", 108 | "located", 109 | "near", 110 | "compressing", 111 | "section", 112 | "compression", 113 | "patient", 114 | "midbrain", 115 | "structures", 116 | "shifted", 117 | "midline", 118 | "shift", 119 | "cerebral", 120 | "parenchyma", 121 | "cerebellum", 122 | "transverse", 123 | "crossed", 124 | "evidence", 125 | "midlight", 126 | "vertebral", 127 | "arteries", 128 | "patent", 129 | "vertebro", 130 | "basilar", 131 | "arterial", 132 | "network", 133 | "viewed", 134 | "artery", 135 | "blurring", 136 | "white", 137 | "junctions", 138 | "right", 139 | "temporal", 140 | "lobe", 141 | "definitive", 142 | "border", 143 | "between", 144 | "calcifications", 145 | "left", 146 | "middle", 147 | "appear", 148 | "present", 149 | "weighted", 150 | "lesion", 151 | "causing", 152 | "significant", 153 | "brainstem", 154 | "herniation", 155 | "secondary", 156 | "associated", 157 | "identified", 158 | "visible", 159 | "shifting", 160 | "edema", 161 | "cytotoxic", 162 | "smooth", 163 | "appearing", 164 | "taken", 165 | "motion", 166 | "artifact", 167 | "lobes", 168 | "herniated", 169 | "would", 170 | "describe", 171 | "characteristics", 172 | "abnormal", 173 | "findings", 174 | "normal", 175 | "scan", 176 | "tissue", 177 | 
"ischemic", 178 | "atrophy", 179 | "focal", 180 | "diffuse", 181 | "hemisphere", 182 | "lesions", 183 | "abnormalities", 184 | "enhancing", 185 | "ring", 186 | "suggestive", 187 | "most", 188 | "likely", 189 | "enhancement", 190 | "enhanced", 191 | "which", 192 | "ischemia", 193 | "affect", 194 | "neighboring", 195 | "structure", 196 | "effect", 197 | "suggest", 198 | "causes", 199 | "hyperintensity", 200 | "relative", 201 | "eyes", 202 | "primary", 203 | "dense", 204 | "surrounding", 205 | "denser", 206 | "contrast", 207 | "applied", 208 | "hematoma", 209 | "blood", 210 | "lateral", 211 | "ventricle", 212 | "hemorrhage", 213 | "radiological", 214 | "description", 215 | "color", 216 | "possible", 217 | "choroid", 218 | "fissure", 219 | "around", 220 | "basal", 221 | "ganglia", 222 | "enlarged", 223 | "alteration", 224 | "posterior", 225 | "medulla", 226 | "swollen", 227 | "infiltrating", 228 | "involvement", 229 | "ventricles", 230 | "compressed", 231 | "feeding", 232 | "hypodensity", 233 | "besides", 234 | "these", 235 | "hyperintensities", 236 | "cysts", 237 | "cyst", 238 | "atrophied", 239 | "shrunk", 240 | "fractures", 241 | "bone", 242 | "bright", 243 | "center", 244 | "that", 245 | "form", 246 | "subdural", 247 | "dark", 248 | "areas", 249 | "show", 250 | "suspect", 251 | "abcess", 252 | "cancer", 253 | "finding", 254 | "abscess", 255 | "long", 256 | "modality", 257 | "take", 258 | "complete", 259 | "hours", 260 | "minutes", 261 | "appreciate", 262 | "diagnosis", 263 | "impression", 264 | "made", 265 | "problem", 266 | "originate", 267 | "process", 268 | "visualize", 269 | "physical", 270 | "injury", 271 | "medical", 272 | "cause", 273 | "fractured", 274 | "anoxic", 275 | "infarction", 276 | "axial", 277 | "seen", 278 | "demonstrate", 279 | "displayed", 280 | "sulci", 281 | "blunted", 282 | "blunting", 283 | "characterize", 284 | "characterization", 285 | "caudate", 286 | "nucleus", 287 | "involved", 288 | "disease", 289 | "origin", 290 | "vascular", 291 | "neoplastic", 292 | "shown", 293 | "intraventricular", 294 | "acute", 295 | "head", 296 | "region", 297 | "shows", 298 | "display", 299 | "hydrocephalus", 300 | "sequence", 301 | "demonstrates", 302 | "hyper", 303 | "intense", 304 | "signal", 305 | "bleed", 306 | "kind", 307 | "rectus", 308 | "muscles", 309 | "infarcts", 310 | "adjective", 311 | "wedge", 312 | "shaped", 313 | "calcification", 314 | "calcified", 315 | "location", 316 | "dependent", 317 | "layering", 318 | "occipital", 319 | "horns", 320 | "hyperdensities", 321 | "represent", 322 | "indicative", 323 | "area", 324 | "gray", 325 | "tell", 326 | "best", 327 | "weighting", 328 | "about", 329 | "epidural", 330 | "bleeds", 331 | "hemorrhagic", 332 | "acquire", 333 | "sturctures", 334 | "cranial", 335 | "nerves", 336 | "affected", 337 | "possibly", 338 | "major", 339 | "enhance", 340 | "enchanced", 341 | "noncontrast", 342 | "large", 343 | "captured", 344 | "condition", 345 | "cortex", 346 | "cortical", 347 | "differentiated", 348 | "differentiation", 349 | "ventricular", 350 | "enlargement", 351 | "infarct", 352 | "vessel", 353 | "signs", 354 | "visualized", 355 | "cerebrum", 356 | "subarachnoid", 357 | "bleeding", 358 | "sign", 359 | "cortexes", 360 | "unaltered", 361 | "bulging", 362 | "called", 363 | "bottom", 364 | "indicates", 365 | "fluid", 366 | "accumulation", 367 | "evinced", 368 | "horn", 369 | "enhancements", 370 | "gyral", 371 | "hyperintense", 372 | "hypointense", 373 | "intensity", 374 | "junction", 375 | "altered", 376 | "locations", 377 | "embolus", 378 | 
"from", 379 | "vessels", 380 | "organs", 381 | "impacted", 382 | "depicted", 383 | "signaling", 384 | "method", 385 | "hyperlucencies", 386 | "lighting", 387 | "periphery", 388 | "outer", 389 | "singular", 390 | "multilobulated", 391 | "lobulation", 392 | "single", 393 | "bilateral", 394 | "hemispheres", 395 | "contains", 396 | "hyperdense", 397 | "cerebellar", 398 | "attenuated", 399 | "probably", 400 | "term", 401 | "corpus", 402 | "callosum", 403 | "open", 404 | "orbits", 405 | "size", 406 | "increased", 407 | "result", 408 | "material", 409 | "restricted", 410 | "diffusion", 411 | "picture", 412 | "smaller", 413 | "many", 414 | "hyperdensity", 415 | "vasculature", 416 | "uniform", 417 | "density", 418 | "same", 419 | "throughout", 420 | "clot", 421 | "differential", 422 | "shape", 423 | "sided", 424 | "thrombosis", 425 | "them", 426 | "agent", 427 | "radiolucent", 428 | "radioopaque", 429 | "larger", 430 | "slice", 431 | "superior", 432 | "characterized", 433 | "flair", 434 | "protocol", 435 | "saggital", 436 | "territory", 437 | "edematous", 438 | "category", 439 | "hemmorhage", 440 | "high", 441 | "half", 442 | "largest", 443 | "showing", 444 | "able", 445 | "types", 446 | "incidentally", 447 | "central", 448 | "intesities", 449 | "nature", 450 | "first", 451 | "test", 452 | "suspected", 453 | "spared", 454 | "pathological", 455 | "caused", 456 | "call", 457 | "distribution", 458 | "typical", 459 | "parts", 460 | "sagittal", 461 | "symmetrical", 462 | "symmetric", 463 | "brighter", 464 | "indicated", 465 | "your", 466 | "diagnoses", 467 | "alternate", 468 | "lens", 469 | "synonymous", 470 | "going", 471 | "favor", 472 | "viral", 473 | "parasitic", 474 | "obvious", 475 | "seem", 476 | "space", 477 | "portion", 478 | "represents", 479 | "spaces", 480 | "genetic", 481 | "etiology", 482 | "were", 483 | "involve", 484 | "neuro", 485 | "deficits", 486 | "will", 487 | "predicted", 488 | "suggested", 489 | "optic", 490 | "chiasm", 491 | "highlighted", 492 | "unified", 493 | "mean", 494 | "makes", 495 | "know", 496 | "tells", 497 | "infectious", 498 | "frontal", 499 | "could", 500 | "multiple", 501 | "assessed", 502 | "difference", 503 | "infiltrate", 504 | "sinuses", 505 | "level", 506 | "hyperlucency", 507 | "indicate", 508 | "hyperintensitites", 509 | "structural", 510 | "deviation", 511 | "lungs", 512 | "pneumothorax", 513 | "chest", 514 | "trachea", 515 | "aortic", 516 | "aneurysm", 517 | "costovertebral", 518 | "angles", 519 | "anterior", 520 | "under", 521 | "diaphragm", 522 | "ribs", 523 | "fracturing", 524 | "lung", 525 | "penetration", 526 | "inspiratory", 527 | "effort", 528 | "position", 529 | "hilar", 530 | "soft", 531 | "densities", 532 | "symmetry", 533 | "hilums", 534 | "hilum", 535 | "equivalent", 536 | "tissues", 537 | "descending", 538 | "silhouette", 539 | "contour", 540 | "tortuosity", 541 | "aorta", 542 | "format", 543 | "opacities", 544 | "noted", 545 | "nodules", 546 | "intraparenchymal", 547 | "heart", 548 | "apical", 549 | "field", 550 | "narrowed", 551 | "pulmonary", 552 | "cardiac", 553 | "difficult", 554 | "delineate", 555 | "costophrenic", 556 | "angle", 557 | "sharp", 558 | "film", 559 | "hypointensity", 560 | "hemidiaphragm", 561 | "lower", 562 | "fields", 563 | "cavitary", 564 | "primarily", 565 | "evaluated", 566 | "imaged", 567 | "study", 568 | "hila", 569 | "happening", 570 | "apices", 571 | "asymmetrical", 572 | "breasts", 573 | "aortopulmonary", 574 | "window", 575 | "characteristic", 576 | "catheter", 577 | "extend", 578 | "into", 579 | "markings", 580 | 
"width", 581 | "exceed", 582 | "thorax", 583 | "wide", 584 | "shadow", 585 | "obscured", 586 | "clearly", 587 | "both", 588 | "sides", 589 | "nodular", 590 | "inferior", 591 | "arch", 592 | "bones", 593 | "lighter", 594 | "wider", 595 | "compared", 596 | "free", 597 | "pathologic", 598 | "supraclavicular", 599 | "fossae", 600 | "subdiaphram", 601 | "important", 602 | "prominent", 603 | "prominently", 604 | "extensive", 605 | "infiltration", 606 | "infiltrates", 607 | "collection", 608 | "subdiaphragmatic", 609 | "fifth", 610 | "broken", 611 | "subcutaneous", 612 | "neck", 613 | "opacity", 614 | "apex", 615 | "pneumonthorax", 616 | "inappropriate", 617 | "properly", 618 | "exposed", 619 | "dilated", 620 | "nodule", 621 | "small", 622 | "path", 623 | "principally", 624 | "solitary", 625 | "stomach", 626 | "undulations", 627 | "along", 628 | "column", 629 | "deviated", 630 | "appreciated", 631 | "solid", 632 | "cystic", 633 | "anything", 634 | "male", 635 | "female", 636 | "highlight", 637 | "pleural", 638 | "effusion", 639 | "sufficient", 640 | "enough", 641 | "diagnose", 642 | "typically", 643 | "evaluate", 644 | "kidneys", 645 | "bladde", 646 | "ureters", 647 | "pericardial", 648 | "sustain", 649 | "damage", 650 | "anywhere", 651 | "humerus", 652 | "abnormally", 653 | "inflated", 654 | "hylar", 655 | "lymphadenopathy", 656 | "widened", 657 | "mediastinum", 658 | "failure", 659 | "cardiomegaly", 660 | "coronal", 661 | "constitute", 662 | "loculated", 663 | "loculation", 664 | "standing", 665 | "widening", 666 | "mediastium", 667 | "demonstrated", 668 | "depict", 669 | "interstitial", 670 | "xray", 671 | "underexposed", 672 | "radiograph", 673 | "foreign", 674 | "body", 675 | "intubated", 676 | "contain", 677 | "cardiovascular", 678 | "dissection", 679 | "clavicle", 680 | "endotracheal", 681 | "tube", 682 | "placed", 683 | "pneumoperitoneum", 684 | "thoracic", 685 | "elevation", 686 | "hyperinflated", 687 | "defect", 688 | "systems", 689 | "expiration", 690 | "bump", 691 | "upper", 692 | "quadrant", 693 | "patients", 694 | "symmtery", 695 | "look", 696 | "being", 697 | "higher", 698 | "masses", 699 | "their", 700 | "azygoesophageal", 701 | "recess", 702 | "direction", 703 | "elevated", 704 | "lateralized", 705 | "only", 706 | "confined", 707 | "abdomen", 708 | "abdominal", 709 | "cavity", 710 | "consolidations", 711 | "chostrochondral", 712 | "common", 713 | "aspirations", 714 | "cilia", 715 | "alveoli", 716 | "blocked", 717 | "mismatch", 718 | "else", 719 | "need", 720 | "order", 721 | "localize", 722 | "tracheal", 723 | "rays", 724 | "identify", 725 | "liver", 726 | "hemidiaphragms", 727 | "flattened", 728 | "either", 729 | "knob", 730 | "pneuomothorax", 731 | "clavicles", 732 | "clavicular", 733 | "flat", 734 | "borders", 735 | "physiology", 736 | "healthy", 737 | "airway", 738 | "walls", 739 | "thickened", 740 | "bases", 741 | "thickening", 742 | "wall", 743 | "wrong", 744 | "described", 745 | "distributions", 746 | "dots", 747 | "hyperlucent", 748 | "laterality", 749 | "patchy", 750 | "margins", 751 | "base", 752 | "adenopathy", 753 | "pleura", 754 | "thick", 755 | "hemithorax", 756 | "lucent", 757 | "decreased", 758 | "depressed", 759 | "nipple", 760 | "markers", 761 | "superimposed", 762 | "opacification", 763 | "opacificaions", 764 | "terms", 765 | "consistent", 766 | "gender", 767 | "greater", 768 | "diameter", 769 | "consolidation", 770 | "orientation", 771 | "acquired", 772 | "leads", 773 | "circumscribed", 774 | "adequate", 775 | "inspiration", 776 | "mediastinal", 777 | 
"pneumothroax", 778 | "line", 779 | "curly", 780 | "gastric", 781 | "bubble", 782 | "outline", 783 | "circumferential", 784 | "colon", 785 | "rotated", 786 | "positioned", 787 | "inappropriately", 788 | "hemidiaphragmatic", 789 | "backwards", 790 | "mirror", 791 | "three", 792 | "circular", 793 | "vein", 794 | "venous", 795 | "ground", 796 | "glass", 797 | "hemodiaphragm", 798 | "presence", 799 | "pneumonia", 800 | "effusions", 801 | "plain", 802 | "special", 803 | "concerning", 804 | "filling", 805 | "hyperinflation", 806 | "collapsed", 807 | "collapse", 808 | "easily", 809 | "procedure", 810 | "might", 811 | "reveal", 812 | "displaced", 813 | "determine", 814 | "tram", 815 | "track", 816 | "indictate", 817 | "deviating", 818 | "subclavian", 819 | "please", 820 | "lymph", 821 | "nodes", 822 | "tension", 823 | "expect", 824 | "plaques", 825 | "surfaces", 826 | "hemithoraces", 827 | "determines", 828 | "absence", 829 | "unilateral", 830 | "exterior", 831 | "interior", 832 | "inside", 833 | "outside", 834 | "observed", 835 | "clear", 836 | "vertical", 837 | "within", 838 | "superficial", 839 | "skin", 840 | "aeration", 841 | ] 842 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = ../build/sphinx/ 9 | AUTODOCDIR = api 10 | AUTODOCBUILD = sphinx-apidoc 11 | PROJECT = PyTorchXAI 12 | MODULEDIR = ../src/pytorchxai 13 | 14 | # User-friendly check for sphinx-build 15 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $?), 1) 16 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 17 | endif 18 | 19 | # Internal variables. 20 | PAPEROPT_a4 = -D latex_paper_size=a4 21 | PAPEROPT_letter = -D latex_paper_size=letter 22 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 23 | # the i18n builder cannot share the environment and doctrees with the others 24 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
25 | 26 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext doc-requirements 27 | 28 | help: 29 | @echo "Please use \`make ' where is one of" 30 | @echo " html to make standalone HTML files" 31 | @echo " dirhtml to make HTML files named index.html in directories" 32 | @echo " singlehtml to make a single large HTML file" 33 | @echo " pickle to make pickle files" 34 | @echo " json to make JSON files" 35 | @echo " htmlhelp to make HTML files and a HTML help project" 36 | @echo " qthelp to make HTML files and a qthelp project" 37 | @echo " devhelp to make HTML files and a Devhelp project" 38 | @echo " epub to make an epub" 39 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 40 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 41 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 42 | @echo " text to make text files" 43 | @echo " man to make manual pages" 44 | @echo " texinfo to make Texinfo files" 45 | @echo " info to make Texinfo files and run them through makeinfo" 46 | @echo " gettext to make PO message catalogs" 47 | @echo " changes to make an overview of all changed/added/deprecated items" 48 | @echo " xml to make Docutils-native XML files" 49 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 50 | @echo " linkcheck to check all external links for integrity" 51 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 52 | 53 | clean: 54 | rm -rf $(BUILDDIR)/* $(AUTODOCDIR) 55 | 56 | $(AUTODOCDIR): $(MODULEDIR) 57 | mkdir -p $@ 58 | $(AUTODOCBUILD) -f -o $@ $^ 59 | 60 | doc-requirements: $(AUTODOCDIR) 61 | 62 | html: doc-requirements 63 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 64 | @echo 65 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 66 | 67 | dirhtml: doc-requirements 68 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 69 | @echo 70 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 71 | 72 | singlehtml: doc-requirements 73 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 74 | @echo 75 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 76 | 77 | pickle: doc-requirements 78 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 79 | @echo 80 | @echo "Build finished; now you can process the pickle files." 81 | 82 | json: doc-requirements 83 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 84 | @echo 85 | @echo "Build finished; now you can process the JSON files." 86 | 87 | htmlhelp: doc-requirements 88 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 89 | @echo 90 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 91 | ".hhp project file in $(BUILDDIR)/htmlhelp." 92 | 93 | qthelp: doc-requirements 94 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 95 | @echo 96 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 97 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 98 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/$(PROJECT).qhcp" 99 | @echo "To view the help file:" 100 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/$(PROJECT).qhc" 101 | 102 | devhelp: doc-requirements 103 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 104 | @echo 105 | @echo "Build finished." 
106 | @echo "To view the help file:" 107 | @echo "# mkdir -p $HOME/.local/share/devhelp/$(PROJECT)" 108 | @echo "# ln -s $(BUILDDIR)/devhelp $HOME/.local/share/devhelp/$(PROJEC)" 109 | @echo "# devhelp" 110 | 111 | epub: doc-requirements 112 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 113 | @echo 114 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 115 | 116 | patch-latex: 117 | find _build/latex -iname "*.tex" | xargs -- \ 118 | sed -i'' 's~includegraphics{~includegraphics\[keepaspectratio,max size={\\textwidth}{\\textheight}\]{~g' 119 | 120 | latex: doc-requirements 121 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 122 | $(MAKE) patch-latex 123 | @echo 124 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 125 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 126 | "(use \`make latexpdf' here to do that automatically)." 127 | 128 | latexpdf: doc-requirements 129 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 130 | $(MAKE) patch-latex 131 | @echo "Running LaTeX files through pdflatex..." 132 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 133 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 134 | 135 | latexpdfja: doc-requirements 136 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 137 | @echo "Running LaTeX files through platex and dvipdfmx..." 138 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 139 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 140 | 141 | text: doc-requirements 142 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 143 | @echo 144 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 145 | 146 | man: doc-requirements 147 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 148 | @echo 149 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 150 | 151 | texinfo: doc-requirements 152 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 153 | @echo 154 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 155 | @echo "Run \`make' in that directory to run these through makeinfo" \ 156 | "(use \`make info' here to do that automatically)." 157 | 158 | info: doc-requirements 159 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 160 | @echo "Running Texinfo files through makeinfo..." 161 | make -C $(BUILDDIR)/texinfo info 162 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 163 | 164 | gettext: doc-requirements 165 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 166 | @echo 167 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 168 | 169 | changes: doc-requirements 170 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 171 | @echo 172 | @echo "The overview file is in $(BUILDDIR)/changes." 173 | 174 | linkcheck: doc-requirements 175 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 176 | @echo 177 | @echo "Link check complete; look for any errors in the above output " \ 178 | "or in $(BUILDDIR)/linkcheck/output.txt." 179 | 180 | doctest: doc-requirements 181 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 182 | @echo "Testing of doctests in the sources finished, look at the " \ 183 | "results in $(BUILDDIR)/doctest/output.txt." 184 | 185 | xml: doc-requirements 186 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 187 | @echo 188 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 
189 | 190 | pseudoxml: doc-requirements 191 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 192 | @echo 193 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 194 | -------------------------------------------------------------------------------- /docs/_static/.gitignore: -------------------------------------------------------------------------------- 1 | # Empty directory 2 | -------------------------------------------------------------------------------- /docs/authors.rst: -------------------------------------------------------------------------------- 1 | .. _authors: 2 | .. include:: ../AUTHORS.rst 3 | -------------------------------------------------------------------------------- /docs/changelog.rst: -------------------------------------------------------------------------------- 1 | .. _changes: 2 | .. include:: ../CHANGELOG.rst 3 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # This file is execfile()d with the current directory set to its containing dir. 4 | # 5 | # Note that not all possible configuration values are present in this 6 | # autogenerated file. 7 | # 8 | # All configuration values have a default; values that are commented out 9 | # serve to show the default. 10 | 11 | import os 12 | import sys 13 | import inspect 14 | import shutil 15 | 16 | __location__ = os.path.join( 17 | os.getcwd(), os.path.dirname(inspect.getfile(inspect.currentframe())) 18 | ) 19 | 20 | # If extensions (or modules to document with autodoc) are in another directory, 21 | # add these directories to sys.path here. If the directory is relative to the 22 | # documentation root, use os.path.abspath to make it absolute, like shown here. 23 | sys.path.insert(0, os.path.join(__location__, "../src")) 24 | 25 | # -- Run sphinx-apidoc ------------------------------------------------------ 26 | # This hack is necessary since RTD does not issue `sphinx-apidoc` before running 27 | # `sphinx-build -b html . _build/html`. See Issue: 28 | # https://github.com/rtfd/readthedocs.org/issues/1139 29 | # DON'T FORGET: Check the box "Install your project inside a virtualenv using 30 | # setup.py install" in the RTD Advanced Settings. 31 | # Additionally it helps us to avoid running apidoc manually 32 | 33 | try: # for Sphinx >= 1.7 34 | from sphinx.ext import apidoc 35 | except ImportError: 36 | from sphinx import apidoc 37 | 38 | output_dir = os.path.join(__location__, "api") 39 | module_dir = os.path.join(__location__, "../src/pytorchxai") 40 | try: 41 | shutil.rmtree(output_dir) 42 | except FileNotFoundError: 43 | pass 44 | 45 | try: 46 | import sphinx 47 | from pkg_resources import parse_version 48 | 49 | cmd_line_template = "sphinx-apidoc -f -o {outputdir} {moduledir}" 50 | cmd_line = cmd_line_template.format(outputdir=output_dir, moduledir=module_dir) 51 | 52 | args = cmd_line.split(" ") 53 | if parse_version(sphinx.__version__) >= parse_version("1.7"): 54 | args = args[1:] 55 | 56 | apidoc.main(args) 57 | except Exception as e: 58 | print("Running `sphinx-apidoc` failed!\n{}".format(e)) 59 | 60 | # -- General configuration ----------------------------------------------------- 61 | 62 | # If your documentation needs a minimal Sphinx version, state it here. 63 | # needs_sphinx = '1.0' 64 | 65 | # Add any Sphinx extension module names here, as strings. 
They can be extensions 66 | # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 67 | extensions = [ 68 | "sphinx.ext.autodoc", 69 | "sphinx.ext.intersphinx", 70 | "sphinx.ext.todo", 71 | "sphinx.ext.autosummary", 72 | "sphinx.ext.viewcode", 73 | "sphinx.ext.coverage", 74 | "sphinx.ext.doctest", 75 | "sphinx.ext.ifconfig", 76 | "sphinx.ext.mathjax", 77 | "sphinx.ext.napoleon", 78 | ] 79 | 80 | # Add any paths that contain templates here, relative to this directory. 81 | templates_path = ["_templates"] 82 | 83 | # The suffix of source filenames. 84 | source_suffix = ".rst" 85 | 86 | # The encoding of source files. 87 | # source_encoding = 'utf-8-sig' 88 | 89 | # The master toctree document. 90 | master_doc = "index" 91 | 92 | # General information about the project. 93 | project = u"PyTorchXAI" 94 | copyright = u"2020, Tudor Cebere" 95 | 96 | # The version info for the project you're documenting, acts as replacement for 97 | # |version| and |release|, also used in various other places throughout the 98 | # built documents. 99 | # 100 | # The short X.Y version. 101 | version = "" # Is set by calling `setup.py docs` 102 | # The full version, including alpha/beta/rc tags. 103 | release = "" # Is set by calling `setup.py docs` 104 | 105 | # The language for content autogenerated by Sphinx. Refer to documentation 106 | # for a list of supported languages. 107 | # language = None 108 | 109 | # There are two options for replacing |today|: either, you set today to some 110 | # non-false value, then it is used: 111 | # today = '' 112 | # Else, today_fmt is used as the format for a strftime call. 113 | # today_fmt = '%B %d, %Y' 114 | 115 | # List of patterns, relative to source directory, that match files and 116 | # directories to ignore when looking for source files. 117 | exclude_patterns = ["_build"] 118 | 119 | # The reST default role (used for this markup: `text`) to use for all documents. 120 | # default_role = None 121 | 122 | # If true, '()' will be appended to :func: etc. cross-reference text. 123 | # add_function_parentheses = True 124 | 125 | # If true, the current module name will be prepended to all description 126 | # unit titles (such as .. function::). 127 | # add_module_names = True 128 | 129 | # If true, sectionauthor and moduleauthor directives will be shown in the 130 | # output. They are ignored by default. 131 | # show_authors = False 132 | 133 | # The name of the Pygments (syntax highlighting) style to use. 134 | pygments_style = "sphinx" 135 | 136 | # A list of ignored prefixes for module index sorting. 137 | # modindex_common_prefix = [] 138 | 139 | # If true, keep warnings as "system message" paragraphs in the built documents. 140 | # keep_warnings = False 141 | 142 | 143 | # -- Options for HTML output --------------------------------------------------- 144 | 145 | # The theme to use for HTML and HTML Help pages. See the documentation for 146 | # a list of builtin themes. 147 | html_theme = "alabaster" 148 | 149 | # Theme options are theme-specific and customize the look and feel of a theme 150 | # further. For a list of options available for each theme, see the 151 | # documentation. 152 | html_theme_options = {"sidebar_width": "300px", "page_width": "1200px"} 153 | 154 | # Add any paths that contain custom themes here, relative to this directory. 155 | # html_theme_path = [] 156 | 157 | # The name for this set of Sphinx documents. If None, it defaults to 158 | # " v documentation". 
159 | try: 160 | from pytorchxai.plugin import __version__ as version 161 | except ImportError: 162 | pass 163 | else: 164 | release = version 165 | 166 | # A shorter title for the navigation bar. Default is the same as html_title. 167 | # html_short_title = None 168 | 169 | # The name of an image file (relative to this directory) to place at the top 170 | # of the sidebar. 171 | # html_logo = "" 172 | 173 | # The name of an image file (within the static path) to use as favicon of the 174 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 175 | # pixels large. 176 | # html_favicon = None 177 | 178 | # Add any paths that contain custom static files (such as style sheets) here, 179 | # relative to this directory. They are copied after the builtin static files, 180 | # so a file named "default.css" will overwrite the builtin "default.css". 181 | html_static_path = ["_static"] 182 | 183 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 184 | # using the given strftime format. 185 | # html_last_updated_fmt = '%b %d, %Y' 186 | 187 | # If true, SmartyPants will be used to convert quotes and dashes to 188 | # typographically correct entities. 189 | # html_use_smartypants = True 190 | 191 | # Custom sidebar templates, maps document names to template names. 192 | # html_sidebars = {} 193 | 194 | # Additional templates that should be rendered to pages, maps page names to 195 | # template names. 196 | # html_additional_pages = {} 197 | 198 | # If false, no module index is generated. 199 | # html_domain_indices = True 200 | 201 | # If false, no index is generated. 202 | # html_use_index = True 203 | 204 | # If true, the index is split into individual pages for each letter. 205 | # html_split_index = False 206 | 207 | # If true, links to the reST sources are added to the pages. 208 | # html_show_sourcelink = True 209 | 210 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 211 | # html_show_sphinx = True 212 | 213 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 214 | # html_show_copyright = True 215 | 216 | # If true, an OpenSearch description file will be output, and all pages will 217 | # contain a tag referring to it. The value of this option must be the 218 | # base URL from which the finished HTML is served. 219 | # html_use_opensearch = '' 220 | 221 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 222 | # html_file_suffix = None 223 | 224 | # Output file base name for HTML help builder. 225 | htmlhelp_basename = "pytorchxai-doc" 226 | 227 | 228 | # -- Options for LaTeX output -------------------------------------------------- 229 | 230 | latex_elements = { 231 | # The paper size ('letterpaper' or 'a4paper'). 232 | # 'papersize': 'letterpaper', 233 | # The font size ('10pt', '11pt' or '12pt'). 234 | # 'pointsize': '10pt', 235 | # Additional stuff for the LaTeX preamble. 236 | # 'preamble': '', 237 | } 238 | 239 | # Grouping the document tree into LaTeX files. List of tuples 240 | # (source start file, target name, title, author, documentclass [howto/manual]). 241 | latex_documents = [ 242 | ("index", "user_guide.tex", u"PyTorchXAI Documentation", u"Tudor Cebere", "manual"), 243 | ] 244 | 245 | # The name of an image file (relative to this directory) to place at the top of 246 | # the title page. 247 | # latex_logo = "" 248 | 249 | # For "manual" documents, if this is true, then toplevel headings are parts, 250 | # not chapters. 
251 | # latex_use_parts = False 252 | 253 | # If true, show page references after internal links. 254 | # latex_show_pagerefs = False 255 | 256 | # If true, show URL addresses after external links. 257 | # latex_show_urls = False 258 | 259 | # Documents to append as an appendix to all manuals. 260 | # latex_appendices = [] 261 | 262 | # If false, no module index is generated. 263 | # latex_domain_indices = True 264 | 265 | # -- External mapping ------------------------------------------------------------ 266 | python_version = ".".join(map(str, sys.version_info[0:2])) 267 | intersphinx_mapping = { 268 | "sphinx": ("http://www.sphinx-doc.org/en/stable", None), 269 | "python": ("https://docs.python.org/" + python_version, None), 270 | "matplotlib": ("https://matplotlib.org", None), 271 | "numpy": ("https://docs.scipy.org/doc/numpy", None), 272 | "sklearn": ("http://scikit-learn.org/stable", None), 273 | "pandas": ("http://pandas.pydata.org/pandas-docs/stable", None), 274 | "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), 275 | } 276 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | ========== 2 | PyTorchXAI 3 | ========== 4 | 5 | This is the documentation of **PyTorchXAI**. 6 | 7 | .. note:: 8 | 9 | This is the main page of your project's `Sphinx`_ documentation. 10 | It is formatted in `reStructuredText`_. Add additional pages 11 | by creating rst-files in ``docs`` and adding them to the `toctree`_ below. 12 | Use then `references`_ in order to link them from this page, e.g. 13 | :ref:`authors` and :ref:`changes`. 14 | 15 | It is also possible to refer to the documentation of other Python packages 16 | with the `Python domain syntax`_. By default you can reference the 17 | documentation of `Sphinx`_, `Python`_, `NumPy`_, `SciPy`_, `matplotlib`_, 18 | `Pandas`_, `Scikit-Learn`_. You can add more by extending the 19 | ``intersphinx_mapping`` in your Sphinx's ``conf.py``. 20 | 21 | The pretty useful extension `autodoc`_ is activated by default and lets 22 | you include documentation from docstrings. Docstrings can be written in 23 | `Google style`_ (recommended!), `NumPy style`_ and `classical style`_. 24 | 25 | 26 | Contents 27 | ======== 28 | 29 | .. toctree:: 30 | :maxdepth: 2 31 | 32 | License 33 | Authors 34 | Changelog 35 | Module Reference 36 | 37 | 38 | Indices and tables 39 | ================== 40 | 41 | * :ref:`genindex` 42 | * :ref:`modindex` 43 | * :ref:`search` 44 | 45 | .. _toctree: http://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html 46 | .. _reStructuredText: http://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html 47 | .. _references: http://www.sphinx-doc.org/en/stable/markup/inline.html 48 | .. _Python domain syntax: http://sphinx-doc.org/domains.html#the-python-domain 49 | .. _Sphinx: http://www.sphinx-doc.org/ 50 | .. _Python: http://docs.python.org/ 51 | .. _Numpy: http://docs.scipy.org/doc/numpy 52 | .. _SciPy: http://docs.scipy.org/doc/scipy/reference/ 53 | .. _matplotlib: https://matplotlib.org/contents.html# 54 | .. _Pandas: http://pandas.pydata.org/pandas-docs/stable 55 | .. _Scikit-Learn: http://scikit-learn.org/stable 56 | .. _autodoc: http://www.sphinx-doc.org/en/stable/ext/autodoc.html 57 | .. _Google style: https://github.com/google/styleguide/blob/gh-pages/pyguide.md#38-comments-and-docstrings 58 | .. _NumPy style: https://numpydoc.readthedocs.io/en/latest/format.html 59 | .. 
_classical style: http://www.sphinx-doc.org/en/stable/domains.html#info-field-lists 60 | -------------------------------------------------------------------------------- /docs/license.rst: -------------------------------------------------------------------------------- 1 | .. _license: 2 | 3 | ======= 4 | License 5 | ======= 6 | 7 | .. include:: ../LICENSE.txt 8 | -------------------------------------------------------------------------------- /examples/tutorial_01_tensorboard_mnist/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/medtorch/Q-Aid-Core/9ea70c1a1ab66323aa21484a8066512c9cd4fc43/examples/tutorial_01_tensorboard_mnist/__init__.py -------------------------------------------------------------------------------- /examples/tutorial_01_tensorboard_mnist/mnist/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/medtorch/Q-Aid-Core/9ea70c1a1ab66323aa21484a8066512c9cd4fc43/examples/tutorial_01_tensorboard_mnist/mnist/__init__.py -------------------------------------------------------------------------------- /examples/tutorial_01_tensorboard_mnist/mnist/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, transforms 3 | 4 | train_batch_size = 64 5 | test_batch_size = 1000 6 | 7 | use_cuda = torch.cuda.is_available() 8 | 9 | kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {} 10 | train_loader = torch.utils.data.DataLoader( 11 | datasets.MNIST( 12 | "data", 13 | train=True, 14 | download=True, 15 | transform=transforms.Compose( 16 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] 17 | ), 18 | ), 19 | batch_size=train_batch_size, 20 | shuffle=True, 21 | **kwargs 22 | ) 23 | 24 | test_loader = torch.utils.data.DataLoader( 25 | datasets.MNIST( 26 | "data", 27 | train=False, 28 | transform=transforms.Compose( 29 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] 30 | ), 31 | ), 32 | batch_size=test_batch_size, 33 | shuffle=True, 34 | **kwargs 35 | ) 36 | -------------------------------------------------------------------------------- /examples/tutorial_01_tensorboard_mnist/mnist/demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.optim as optim 4 | from examples.tutorial_01_tensorboard_mnist.mnist import model 5 | from examples.tutorial_01_tensorboard_mnist.mnist.dataloader import ( 6 | test_loader, 7 | train_loader 8 | ) 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | writer = SummaryWriter() 12 | 13 | epochs = 5 14 | lr = 0.01 15 | momentum = 0.5 16 | seed = 1 17 | save_model = True 18 | 19 | 20 | use_cuda = torch.cuda.is_available() 21 | torch.manual_seed(seed) 22 | 23 | device = torch.device("cuda" if use_cuda else "cpu") 24 | 25 | optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) 26 | 27 | 28 | def train(model, device, train_loader, optimizer, epoch): 29 | model.train() 30 | avg_loss = 0 31 | correct = 0 32 | 33 | for batch_idx, (data, target) in enumerate(train_loader): 34 | data, target = data.to(device), target.to(device) 35 | optimizer.zero_grad() 36 | 37 | output = model(data) 38 | 39 | loss = F.nll_loss(output, target) 40 | loss.backward() 41 | 42 | optimizer.step() 43 | 44 | avg_loss += F.nll_loss(output, target, reduction="sum").item() 
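# Note: the loss is recomputed here with reduction="sum" so that avg_loss,
# once divided by len(train_loader.dataset) below, is a true per-sample
# average over the epoch rather than a mean of batch means.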
45 | 46 | pred = output.argmax(dim=1, keepdim=True) 47 | 48 | correct += pred.eq(target.view_as(pred)).sum().item() 49 | 50 | print( 51 | "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( 52 | epoch, 53 | batch_idx * len(data), 54 | len(train_loader.dataset), 55 | 100.0 * batch_idx / len(train_loader), 56 | loss.item(), 57 | ) 58 | ) 59 | 60 | accuracy = 100.0 * correct / len(train_loader.dataset) 61 | avg_loss /= len(train_loader.dataset) 62 | 63 | writer.add_scalar("Loss/train", avg_loss, epoch) 64 | writer.add_scalar("Accuracy/train", accuracy, epoch) 65 | 66 | return avg_loss 67 | 68 | 69 | def model_test(model, device, test_loader): 70 | model.eval() 71 | test_loss = 0 72 | correct = 0 73 | with torch.no_grad(): 74 | # noinspection PyPackageRequirements 75 | for data, target in test_loader: 76 | data, target = data.to(device), target.to(device) 77 | output = model(data) 78 | test_loss += F.nll_loss(output, target, reduction="sum").item() 79 | pred = output.argmax(dim=1, keepdim=True) 80 | correct += pred.eq(target.view_as(pred)).sum().item() 81 | 82 | test_loss /= len(test_loader.dataset) 83 | 84 | print( 85 | "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( 86 | test_loss, 87 | correct, 88 | len(test_loader.dataset), 89 | 100.0 * correct / len(test_loader.dataset), 90 | ) 91 | ) 92 | accuracy = 100.0 * correct / len(test_loader.dataset) 93 | 94 | writer.add_scalar("Loss/test", test_loss, epoch) 95 | writer.add_scalar("Accuracy/test", accuracy, epoch) 96 | 97 | return test_loss, accuracy 98 | 99 | 100 | train_losses = [] 101 | test_losses = [] 102 | accuracy_list = [] 103 | for epoch in range(1, epochs + 1): 104 | trn_loss = train(model, device, train_loader, optimizer, epoch) 105 | test_loss, accuracy = model_test(model, device, test_loader) 106 | train_losses.append(trn_loss) 107 | test_losses.append(test_loss) 108 | accuracy_list.append(accuracy) 109 | 110 | 111 | writer.close() 112 | -------------------------------------------------------------------------------- /examples/tutorial_01_tensorboard_mnist/mnist/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | use_cuda = torch.cuda.is_available() 6 | device = torch.device("cuda" if use_cuda else "cpu") 7 | 8 | 9 | class Net(nn.Module): 10 | def __init__(self): 11 | super(Net, self).__init__() 12 | self.conv1 = nn.Conv2d(1, 20, 5, 1) 13 | self.conv2 = nn.Conv2d(20, 50, 5, 1) 14 | self.fc1 = nn.Linear(4 * 4 * 50, 500) 15 | self.fc2 = nn.Linear(500, 10) 16 | 17 | def forward(self, x): 18 | x = F.relu(self.conv1(x)) 19 | x = F.max_pool2d(x, 2, 2) 20 | x = F.relu(self.conv2(x)) 21 | x = F.max_pool2d(x, 2, 2) 22 | x = x.view(-1, 4 * 4 * 50) 23 | x = F.relu(self.fc1(x)) 24 | x = self.fc2(x) 25 | return F.log_softmax(x, dim=1) 26 | 27 | 28 | model = Net().to(device) 29 | -------------------------------------------------------------------------------- /examples/tutorial_02_saliency_map/demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torchvision.transforms as T 4 | from PIL import Image 5 | from torch.utils.tensorboard import SummaryWriter 6 | from torchvision import models 7 | 8 | from pytorchxai.xai.utils import preprocess_image 9 | from pytorchxai.xai.visualizations import GradientVisualization 10 | 11 | pretrained_model = models.alexnet(pretrained=True) 12 | writer = SummaryWriter() 13 | 14 | parser = 
argparse.ArgumentParser() 15 | parser.add_argument("--image") 16 | parser.add_argument("--target_class") 17 | args = parser.parse_args() 18 | 19 | original_image = Image.open(args.image).convert("RGB") 20 | writer.add_image("input", T.ToTensor()(original_image), 0) 21 | 22 | images = [] 23 | 24 | original_image = original_image.resize((224, 224), Image.ANTIALIAS) 25 | prep_img = preprocess_image(original_image) 26 | 27 | print(prep_img.shape) 28 | vis = GradientVisualization(pretrained_model) 29 | 30 | output = vis.generate(original_image, prep_img, int(args.target_class)) 31 | 32 | for g in output: 33 | print("adding ", g) 34 | writer.add_image(g, output[g], 0) 35 | 36 | writer.close() 37 | -------------------------------------------------------------------------------- /misc/medtorch/Artboard 2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/medtorch/Q-Aid-Core/9ea70c1a1ab66323aa21484a8066512c9cd4fc43/misc/medtorch/Artboard 2.png -------------------------------------------------------------------------------- /misc/medtorch/Artboard 2.svg: -------------------------------------------------------------------------------- 1 | Artboard 2MEDTORCH -------------------------------------------------------------------------------- /misc/medtorch/Artboard 2@4x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/medtorch/Q-Aid-Core/9ea70c1a1ab66323aa21484a8066512c9cd4fc43/misc/medtorch/Artboard 2@4x.png -------------------------------------------------------------------------------- /misc/q&aid-mini/1x/Artboard 3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/medtorch/Q-Aid-Core/9ea70c1a1ab66323aa21484a8066512c9cd4fc43/misc/q&aid-mini/1x/Artboard 3.png -------------------------------------------------------------------------------- /misc/q&aid-mini/1x/Artboard 4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/medtorch/Q-Aid-Core/9ea70c1a1ab66323aa21484a8066512c9cd4fc43/misc/q&aid-mini/1x/Artboard 4.png -------------------------------------------------------------------------------- /misc/q&aid-mini/1x/Artboard 5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/medtorch/Q-Aid-Core/9ea70c1a1ab66323aa21484a8066512c9cd4fc43/misc/q&aid-mini/1x/Artboard 5.png -------------------------------------------------------------------------------- /misc/q&aid-mini/4x/Artboard 3@4x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/medtorch/Q-Aid-Core/9ea70c1a1ab66323aa21484a8066512c9cd4fc43/misc/q&aid-mini/4x/Artboard 3@4x.png -------------------------------------------------------------------------------- /misc/q&aid-mini/4x/Artboard 4@4x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/medtorch/Q-Aid-Core/9ea70c1a1ab66323aa21484a8066512c9cd4fc43/misc/q&aid-mini/4x/Artboard 4@4x.png -------------------------------------------------------------------------------- /misc/q&aid-mini/4x/Artboard 5@4x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/medtorch/Q-Aid-Core/9ea70c1a1ab66323aa21484a8066512c9cd4fc43/misc/q&aid-mini/4x/Artboard 5@4x.png 
-------------------------------------------------------------------------------- /misc/q&aid-mini/SVG/Artboard 3.svg: -------------------------------------------------------------------------------- 1 | Artboard 3& -------------------------------------------------------------------------------- /misc/q&aid-mini/SVG/Artboard 4.svg: -------------------------------------------------------------------------------- 1 | Artboard 4& -------------------------------------------------------------------------------- /misc/q&aid-mini/SVG/Artboard 5.svg: -------------------------------------------------------------------------------- 1 | Artboard 5 -------------------------------------------------------------------------------- /misc/q&aid.ai: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/medtorch/Q-Aid-Core/9ea70c1a1ab66323aa21484a8066512c9cd4fc43/misc/q&aid.ai -------------------------------------------------------------------------------- /misc/q&aid/Artboard 1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/medtorch/Q-Aid-Core/9ea70c1a1ab66323aa21484a8066512c9cd4fc43/misc/q&aid/Artboard 1.png -------------------------------------------------------------------------------- /misc/q&aid/Artboard 1.svg: -------------------------------------------------------------------------------- 1 | Artboard 1& -------------------------------------------------------------------------------- /misc/q&aid/Artboard 1@4x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/medtorch/Q-Aid-Core/9ea70c1a1ab66323aa21484a8066512c9cd4fc43/misc/q&aid/Artboard 1@4x.png -------------------------------------------------------------------------------- /misc/q_aid_logo_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/medtorch/Q-Aid-Core/9ea70c1a1ab66323aa21484a8066512c9cd4fc43/misc/q_aid_logo_small.png -------------------------------------------------------------------------------- /misc/q_aid_logo_small1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/medtorch/Q-Aid-Core/9ea70c1a1ab66323aa21484a8066512c9cd4fc43/misc/q_aid_logo_small1.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pillow 2 | torchvision 3 | torch 4 | pytest-cov 5 | matplotlib 6 | tensorboard 7 | # ============================================================================= 8 | # DEPRECATION WARNING: 9 | # 10 | # The file `requirements.txt` does not influence the package dependencies and 11 | # will not be automatically created in the next version of PyScaffold (v4.x). 12 | # 13 | # Please have look at the docs for better alternatives 14 | # (`Dependency Management` section). 15 | # ============================================================================= 16 | # 17 | # Add your pinned requirements so that they can be easily installed with: 18 | # pip install -r requirements.txt 19 | # Remember to also add them in setup.cfg but unpinned. 
20 | # Example: 21 | # numpy==1.13.3 22 | # scipy==1.0 23 | # 24 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # This file is used to configure your project. 2 | # Read more about the various options under: 3 | # http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files 4 | 5 | [metadata] 6 | name = PyTorchXAI 7 | description = Add a short description here! 8 | author = Tudor Cebere 9 | author-email = tudorcebere@gmail.com 10 | license = mit 11 | long-description = file: README.rst 12 | long-description-content-type = text/x-rst; charset=UTF-8 13 | url = https://github.com/pyscaffold/pyscaffold/ 14 | project-urls = 15 | Documentation = https://pyscaffold.org/ 16 | # Change if running only on Windows, Mac or Linux (comma-separated) 17 | platforms = any 18 | # Add here all kinds of additional classifiers as defined under 19 | # https://pypi.python.org/pypi?%3Aaction=list_classifiers 20 | classifiers = 21 | Development Status :: 4 - Beta 22 | Programming Language :: Python 23 | 24 | [options] 25 | zip_safe = False 26 | packages = find: 27 | include_package_data = True 28 | package_dir = 29 | = src 30 | # DON'T CHANGE THE FOLLOWING LINE! IT WILL BE UPDATED BY PYSCAFFOLD! 31 | setup_requires = pyscaffold>=3.2a0,<3.3a0 32 | # Add here dependencies of your project (semicolon/line-separated), e.g. 33 | # install_requires = numpy; scipy 34 | # The usage of test_requires is discouraged, see `Dependency Management` docs 35 | # tests_require = pytest; pytest-cov 36 | # Require a specific Python version, e.g. Python 2.7 or >= 3.4 37 | # python_requires = >=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.* 38 | 39 | [options.packages.find] 40 | where = src 41 | exclude = 42 | tests 43 | 44 | [options.extras_require] 45 | # Add here additional requirements for extra features, to install with: 46 | # `pip install PyTorchXAI[PDF]` like: 47 | # PDF = ReportLab; RXP 48 | # Add here test requirements (semicolon/line-separated) 49 | testing = 50 | pytest 51 | pytest-cov 52 | 53 | [options.entry_points] 54 | # Add here console scripts like: 55 | # console_scripts = 56 | # script_name = pytorchxai.module:function 57 | # For example: 58 | # console_scripts = 59 | # fibonacci = pytorchxai.skeleton:run 60 | # And any other entry points, for example: 61 | # pyscaffold.cli = 62 | # awesome = pyscaffoldext.awesome.extension:AwesomeExtension 63 | tensorboard_plugins= 64 | pytorchxai = pytorchxai.plugin.pytorchxai_plugin:PyTorchXAIPlugin 65 | 66 | [test] 67 | # py.test options when running `python setup.py test` 68 | # addopts = --verbose 69 | extras = True 70 | 71 | [tool:pytest] 72 | # Options for py.test: 73 | # Specify command line options as you would do when invoking py.test directly. 74 | # e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml 75 | # in order to write a coverage file that can be read by Jenkins. 
76 | addopts = 77 | --cov pytorchxai --cov-report term-missing 78 | --verbose 79 | norecursedirs = 80 | dist 81 | build 82 | .tox 83 | testpaths = tests 84 | 85 | [aliases] 86 | dists = bdist_wheel 87 | 88 | [bdist_wheel] 89 | # Use this option if your package is pure-python 90 | universal = 1 91 | 92 | [build_sphinx] 93 | source_dir = docs 94 | build_dir = build/sphinx 95 | 96 | [devpi:upload] 97 | # Options for the devpi: PyPI server and packaging tool 98 | # VCS export must be deactivated since we are using setuptools-scm 99 | no-vcs = 1 100 | formats = bdist_wheel 101 | 102 | [flake8] 103 | # Some sane defaults for the code style checker flake8 104 | ignore = E501, E402, # line too long 105 | exclude = 106 | .tox 107 | build 108 | dist 109 | .eggs 110 | docs/conf.py 111 | 112 | [pyscaffold] 113 | # PyScaffold's parameters when the project was created. 114 | # This will be used when updating. Do not change! 115 | version = 3.2.3 116 | package = pytorchxai 117 | extensions = 118 | pre_commit 119 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Setup file for pytorchxai. 4 | Use setup.cfg to configure your project. 5 | 6 | This file was generated with PyScaffold 3.2.3. 7 | PyScaffold helps you to put up the scaffold of your new Python project. 8 | Learn more under: https://pyscaffold.org/ 9 | """ 10 | import sys 11 | from pkg_resources import VersionConflict, require 12 | from setuptools import setup 13 | 14 | try: 15 | require("setuptools>=38.3") 16 | except VersionConflict: 17 | print("Error: version of setuptools is too old (<38.3)!") 18 | sys.exit(1) 19 | 20 | 21 | if __name__ == "__main__": 22 | setup(use_pyscaffold=True) 23 | -------------------------------------------------------------------------------- /src/pytorchxai/__init__.py: -------------------------------------------------------------------------------- 1 | from . import plugin 2 | from . import xai 3 | 4 | __all__ = ["plugin", "xai"] 5 | -------------------------------------------------------------------------------- /src/pytorchxai/plugin/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import pytorchxai_plugin 2 | 3 | __all__ = ["pytorchxai_plugin"] 4 | -------------------------------------------------------------------------------- /src/pytorchxai/plugin/pytorchxai/static/index.js: -------------------------------------------------------------------------------- 1 | export async function render() { 2 | const input_img = document.createElement("input"); 3 | const input_preview = document.createElement("img"); 4 | 5 | input_img.name = "input_img"; 6 | input_img.type = "file"; 7 | input_img.addEventListener("change", function() {showImage(input_img);}); 8 | 9 | input_preview.id = "input_preview"; 10 | 11 | document.body.appendChild(input_img); 12 | document.body.appendChild(input_preview); 13 | } 14 | 15 | function showImage(input) 16 | { 17 | var reader; 18 | const input_preview = document.getElementById("input_preview"); 19 | 20 | if (input.files && input.files[0]) { 21 | reader = new FileReader(); 22 | 23 | reader.onload = function(e) { 24 | input_preview.setAttribute('src', e.target.result); 25 | } 26 | 27 | reader.readAsDataURL(input.files[0]); 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/pytorchxai/plugin/pytorchxai/static/style.css: -------------------------------------------------------------------------------- 1 | .container { 2 | height: 200px; 3 | position: relative; 4 | border: 3px solid green; 5 | } 6 | -------------------------------------------------------------------------------- /src/pytorchxai/plugin/pytorchxai_plugin.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import mimetypes 4 | from pathlib import Path 5 | 6 | from tensorboard.backend import http_util 7 | from tensorboard.plugins import base_plugin 8 | from werkzeug import wrappers 9 | 10 | 11 | class PyTorchXAIPlugin(base_plugin.TBPlugin): 12 | plugin_name = "pytorchxai" 13 | 14 | def __init__(self, context): 15 | """Instantiates ExamplePlugin. 16 | 17 | Args: 18 | context: A base_plugin.TBContext instance. 19 | """ 20 | plugin_directory_path_part = f"/data/plugin/{self.plugin_name}/" 21 | self._multiplexer = context.multiplexer 22 | self._offset_path = len(plugin_directory_path_part) 23 | self._prefix_path = Path(__file__).parent / "pytorchxai" 24 | 25 | def is_active(self): 26 | """Returns whether there is relevant data for the plugin to process. 27 | 28 | When there are no runs with greeting data, TensorBoard will hide the 29 | plugin from the main navigation bar. 
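        Concretely, it returns True only if at least one run has logged
        summaries under this plugin's name (via ``PluginRunToTagToContent``),
        so the PyTorchXAI tab only appears once such data has been written.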
30 | """ 31 | return bool(self._multiplexer.PluginRunToTagToContent(self.plugin_name)) 32 | 33 | def get_plugin_apps(self): 34 | return { 35 | "/static/*": self._serve_static_file, 36 | } 37 | 38 | def frontend_metadata(self): 39 | return base_plugin.FrontendMetadata( 40 | es_module_path="/static/index.js", tab_name="PyTorchXAI" 41 | ) 42 | 43 | @wrappers.Request.application 44 | def _serve_static_file(self, request): 45 | static_path_part = request.path[self._offset_path:] 46 | resource_path = Path(static_path_part) 47 | 48 | if not resource_path.parent != "static": 49 | return http_util.Respond( 50 | request, "Resource not found", "text/plain", code=404 51 | ) 52 | 53 | resource_absolute_path = str(self._prefix_path / resource_path) 54 | with open(resource_absolute_path, "rb") as read_file: 55 | mimetype = mimetypes.guess_type(resource_absolute_path)[0] 56 | return http_util.Respond(request, read_file.read(), content_type=mimetype) 57 | -------------------------------------------------------------------------------- /src/pytorchxai/xai/__init__.py: -------------------------------------------------------------------------------- 1 | from . import utils, visualizations 2 | 3 | __all__ = ["utils", "visualizations"] 4 | -------------------------------------------------------------------------------- /src/pytorchxai/xai/cam_gradcam.py: -------------------------------------------------------------------------------- 1 | """ 2 | Gradient-weighted Class Activation Mapping(Grad-CAM) is an algorithm that can be used to visualize the class activation maps of a Convolutional Neural Network. 3 | 4 | Algorithm details: 5 | - The algorithm finds the final convolutional layer in the network. 6 | - It examines the gradient information flowing into that layer. 7 | - The output of Grad-CAM is a heatmap visualization for a given class label. 8 | 9 | [1] Selvaraju, Ramprasaath R., et al. "Grad-cam: Visual explanations from deep networks via gradient-based localization." Proceedings of the IEEE International Conference on Computer Vision. 2017. 10 | """ 11 | import numpy as np 12 | import torch 13 | import torchvision.transforms as T 14 | from PIL import Image 15 | 16 | from pytorchxai.xai.cam_utils import CamExtractor 17 | from pytorchxai.xai.utils import apply_colormap_on_image 18 | 19 | 20 | class GradCam: 21 | def __init__(self, model): 22 | self.model = model 23 | self.model.eval() 24 | self.extractor = CamExtractor(self.model) 25 | 26 | def generate_cam(self, input_image, target_class=None): 27 | """ 28 | Does a full forward pass on the model and generates the activations maps. 29 | Args: 30 | input_image: input image 31 | target_class: optional target class 32 | Returns: 33 | The activations maps. 
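        Example:
            A minimal usage sketch (assumes a torchvision-style model exposing
            ``features``/``classifier`` blocks and an already preprocessed
            1x3x224x224 tensor ``prep_img``; the class id 243 is illustrative):

                from torchvision import models
                from pytorchxai.xai.cam_gradcam import GradCam

                model = models.alexnet(pretrained=True)
                cam = GradCam(model).generate_cam(prep_img, target_class=243)
                # `cam` is an HxW array scaled to [0, 1] at the input resolution.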
34 | """ 35 | conv_output, model_output = self.extractor.forward_pass(input_image) 36 | if target_class is None: 37 | target_class = np.argmax(model_output.data.numpy()) 38 | 39 | one_hot_output = torch.FloatTensor(1, model_output.size()[-1]).zero_() 40 | one_hot_output[0][target_class] = 1 41 | 42 | self.model.features.zero_grad() 43 | self.model.classifier.zero_grad() 44 | 45 | model_output.backward(gradient=one_hot_output, retain_graph=True) 46 | guided_gradients = self.extractor.gradients.data.numpy()[0] 47 | target = conv_output.data.numpy()[0] 48 | 49 | weights = np.mean(guided_gradients, axis=(1, 2)) 50 | 51 | cam = np.ones(target.shape[1:], dtype=np.float32) 52 | 53 | for i, w in enumerate(weights): 54 | cam += w * target[i, :, :] 55 | cam = np.maximum(cam, 0) 56 | cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam)) 57 | cam = np.uint8(cam * 255) 58 | cam = ( 59 | np.uint8( 60 | Image.fromarray(cam).resize( 61 | (input_image.shape[2], input_image.shape[3]), Image.ANTIALIAS, 62 | ) 63 | ) / 255 64 | ) 65 | return cam 66 | 67 | def generate(self, orig_image, input_image, target_class=None): 68 | """ 69 | Generates and returns the activations maps. 70 | 71 | Args: 72 | orig_image: Original resized image. 73 | input_image: Preprocessed input image. 74 | target_class: Expected category. 75 | Returns: 76 | Colored and grayscale Grad-Cam heatmaps. 77 | Heatmap over the original image 78 | """ 79 | cam = self.generate_cam(input_image, target_class) 80 | heatmap, heatmap_on_image = apply_colormap_on_image(orig_image, cam, "hsv") 81 | return { 82 | "gradcam_heatmap": T.ToTensor()(heatmap), 83 | "gradcam_heatmap_on_image": T.ToTensor()(heatmap_on_image), 84 | "gradcam_grayscale": T.ToTensor()(cam), 85 | } 86 | -------------------------------------------------------------------------------- /src/pytorchxai/xai/cam_scorecam.py: -------------------------------------------------------------------------------- 1 | """Score-CAM is a gradient-free visualization method, extended from Grad-CAM and Grad-CAM++. It achieves better visual performance and fairness for interpreting the decision making process. 2 | 3 | Algorithm details: 4 | - For an input image, it extracts and saves the K activation maps from the last convolutional layer. 5 | - The activation maps are normalized using the maximum and minimum for each map. 6 | - The activations maps are multiplied by the original image to create K images. 7 | - Each of the generated images becomes an input to the CNN, and the probability of the target class (probability through softmax rather than score) is calculated. 8 | - The K probability values are regarded as the importance levels of the K Activation Maps, and the importance levels are multiplied by the activation maps, and then added together to obtain the ScoreCAM. 9 | - 10 | 11 | [1] Wang et al. "Score-CAM: Score-Weighted Visual Explanations for Convolutional Neural Networks" 12 | 13 | """ 14 | import numpy as np 15 | import torch 16 | import torch.nn.functional as F 17 | import torchvision.transforms as T 18 | from PIL import Image 19 | 20 | from pytorchxai.xai.cam_utils import CamExtractor 21 | from pytorchxai.xai.utils import apply_colormap_on_image 22 | 23 | 24 | class ScoreCam: 25 | def __init__(self, model): 26 | self.model = model 27 | self.model.eval() 28 | self.extractor = CamExtractor(self.model) 29 | 30 | def generate_cam(self, input_image, target_class=None): 31 | """ 32 | Does a full forward pass on the model and generates the activations maps. 
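        Unlike Grad-CAM, no gradients are involved: each activation map is
        upsampled to 224x224, normalised to [0, 1], and weighted by the
        softmax score the model assigns to the target class when shown the
        input masked by that map.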
33 | Args: 34 | input_image: input image 35 | target_class: optional target class 36 | Returns: 37 | The activations maps. 38 | """ 39 | conv_output, model_output = self.extractor.forward_pass(input_image) 40 | if target_class is None: 41 | target_class = np.argmax(model_output.data.numpy()) 42 | 43 | # Get convolution outputs 44 | target = conv_output[0] 45 | # Create empty numpy array for cam 46 | cam = np.ones(target.shape[1:], dtype=np.float32) 47 | # Multiply each weight with its conv output and then, sum 48 | for i in range(len(target)): 49 | # Unsqueeze to 4D 50 | saliency_map = torch.unsqueeze(torch.unsqueeze(target[i, :, :], 0), 0) 51 | # Upsampling to input size 52 | saliency_map = F.interpolate( 53 | saliency_map, size=(224, 224), mode="bilinear", align_corners=False, 54 | ) 55 | if saliency_map.max() == saliency_map.min(): 56 | continue 57 | # Scale between 0-1 58 | norm_saliency_map = (saliency_map - saliency_map.min()) / ( 59 | saliency_map.max() - saliency_map.min() 60 | ) 61 | # Get the target score 62 | w = F.softmax( 63 | self.extractor.forward_pass(input_image * norm_saliency_map)[1], dim=1, 64 | )[0][target_class] 65 | cam += w.data.numpy() * target[i, :, :].data.numpy() 66 | cam = np.maximum(cam, 0) 67 | cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam)) # Normalize between 0-1 68 | cam = np.uint8(cam * 255) # Scale between 0-255 to visualize 69 | cam = ( 70 | np.uint8( 71 | Image.fromarray(cam).resize( 72 | (input_image.shape[2], input_image.shape[3]), Image.ANTIALIAS, 73 | ) 74 | ) / 255 75 | ) 76 | return cam 77 | 78 | def generate(self, orig_image, input_image, target_class=None): 79 | """ 80 | Generates and returns the activations maps. 81 | 82 | Args: 83 | orig_image: Original resized image. 84 | input_image: Preprocessed input image. 85 | target_class: Expected category. 86 | Returns: 87 | Colored and grayscale ScoreCam heatmaps. 88 | Heatmap over the original image 89 | """ 90 | 91 | cam = self.generate_cam(input_image, target_class) 92 | heatmap, heatmap_on_image = apply_colormap_on_image(orig_image, cam, "hsv") 93 | return { 94 | "scorecam_heatmap": T.ToTensor()(heatmap), 95 | "scorecam_heatmap_on_image": T.ToTensor()(heatmap_on_image), 96 | "scorecam_grayscale": T.ToTensor()(cam), 97 | } 98 | -------------------------------------------------------------------------------- /src/pytorchxai/xai/cam_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Class Activation Mapping helpers 3 | """ 4 | import torch 5 | 6 | 7 | class CamExtractor: 8 | def __init__(self, model): 9 | self.model = model 10 | self.last_conv = None 11 | for module_pos, module in self.model.features._modules.items(): 12 | if isinstance(module, torch.nn.Conv2d): 13 | self.last_conv = module_pos 14 | if not self.last_conv: 15 | raise ("invalid input model") 16 | 17 | self.gradients = None 18 | 19 | def _save_gradient(self, grad): 20 | self.gradients = grad 21 | 22 | def _forward_pass_on_convolutions(self, x): 23 | """ 24 | Does a forward pass on convolutions, hooks the function at given layer 25 | Args: 26 | x: input image 27 | Returns: 28 | The output of the last convolutional layer. 29 | The output of the model. 
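        Note:
            A gradient hook is registered on the output of the last Conv2d
            layer found in ``model.features``, so after a backward pass that
            gradient is available in ``self.gradients``.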
30 | """ 31 | conv_output = None 32 | for module_pos, module in self.model.features._modules.items(): 33 | x = module(x) 34 | if module_pos == self.last_conv: 35 | x.register_hook(self._save_gradient) 36 | conv_output = x 37 | return conv_output, x 38 | 39 | def forward_pass(self, x): 40 | """ 41 | Does a full forward pass on the model 42 | Args: 43 | x: input image 44 | Returns: 45 | The output of the last convolutional layer. 46 | The output of the model. 47 | """ 48 | conv_output, x = self._forward_pass_on_convolutions(x) 49 | x = x.view(x.size(0), -1) 50 | x = self.model.classifier(x) 51 | return conv_output, x 52 | -------------------------------------------------------------------------------- /src/pytorchxai/xai/gradient_guided_backprop.py: -------------------------------------------------------------------------------- 1 | """ 2 | Guided Backpropagation(guided saliency) generates heat maps that are intended to provide insight into what aspects of an input image a convolutional neural network is using to make a prediction, focusing on what image features the neuron detects, not in what kind of stuff it doesn’t detect. 3 | 4 | Algorithm details: 5 | - We backpropagate positive error signals – i.e. we set the negative gradients to zero. This is the application of the ReLU to the error signal itself during the backward pass. 6 | - Like vanilla backpropagation, we also restrict ourselves to only positive inputs. 7 | 8 | [1] Springenberg et al. "Striving for Simplicity: The All Convolutional Net", 2014. 9 | """ 10 | import torch 11 | from torch.nn import ReLU 12 | 13 | from pytorchxai.xai.utils import ( 14 | convert_to_grayscale, 15 | get_positive_negative_saliency, 16 | normalize_gradient 17 | ) 18 | 19 | 20 | class GuidedBackprop: 21 | def __init__(self, model): 22 | self.model = model 23 | self.gradients = None 24 | self.forward_relu_outputs = [] 25 | 26 | self.model.eval() 27 | self._update_relus() 28 | self._hook_layers() 29 | 30 | def _hook_layers(self): 31 | """ 32 | Method for registering a hook to the first layer 33 | """ 34 | 35 | def hook_function(module, grad_in, grad_out): 36 | self.gradients = grad_in[0] 37 | 38 | first_layer = list(self.model.features._modules.items())[0][1] 39 | first_layer.register_backward_hook(hook_function) 40 | 41 | def _update_relus(self): 42 | """ 43 | Updates relu activation functions so that: 44 | - they store the output in the forward pass. 45 | - they set the negative gradients to zero. 46 | """ 47 | 48 | def relu_backward_hook_function(module, grad_in, grad_out): 49 | """ 50 | If there is a negative gradient, change it to zero 51 | """ 52 | corresponding_forward_output = self.forward_relu_outputs[-1] 53 | corresponding_forward_output[corresponding_forward_output > 0] = 1 54 | modified_grad_out = corresponding_forward_output * torch.clamp( 55 | grad_in[0], min=0.0 56 | ) 57 | del self.forward_relu_outputs[-1] 58 | return (modified_grad_out,) 59 | 60 | def relu_forward_hook_function(module, ten_in, ten_out): 61 | """ 62 | Store results of forward pass 63 | """ 64 | self.forward_relu_outputs.append(ten_out) 65 | 66 | for pos, module in self.model.features._modules.items(): 67 | if isinstance(module, ReLU): 68 | module.register_backward_hook(relu_backward_hook_function) 69 | module.register_forward_hook(relu_forward_hook_function) 70 | 71 | def generate_gradients(self, input_image, target_class): 72 | """ 73 | Generates the gradients using guided backpropagation from the given model and image. 
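        A minimal usage sketch (assumes a torchvision-style model with a
        ``features`` block and a preprocessed input tensor ``prep_img``; the
        class id 243 is illustrative):

            from pytorchxai.xai.gradient_guided_backprop import GuidedBackprop

            gbp = GuidedBackprop(model)
            grads = gbp.generate_gradients(prep_img, target_class=243)
            # `grads` is a CxHxW numpy array of input gradients.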
74 | 75 | Args: 76 | input_image: Preprocessed input image. 77 | target_class: Expected category. 78 | Returns: 79 | The gradients computed using the guided backpropagation. 80 | """ 81 | model_output = self.model(input_image) 82 | self.model.zero_grad() 83 | 84 | one_hot_output = torch.FloatTensor(1, model_output.size()[-1]).zero_() 85 | one_hot_output[0][target_class] = 1 86 | 87 | model_output.backward(gradient=one_hot_output) 88 | 89 | gradients_as_arr = self.gradients.data.numpy()[0] 90 | return gradients_as_arr 91 | 92 | def generate(self, orig_image, input_image, target_class): 93 | """ 94 | Generates and returns multiple saliency maps, based on guided backpropagation. 95 | 96 | Args: 97 | orig_image: Original resized image. 98 | input_image: Preprocessed input image. 99 | target_class: Expected category. 100 | Returns: 101 | Colored and grayscale gradients for the guided backpropagation. 102 | Positive and negative saliency maps. 103 | Grayscale gradients multiplied with the image itself. 104 | """ 105 | guided_grads = self.generate_gradients(input_image, target_class) 106 | 107 | color_guided_grads = normalize_gradient(guided_grads) 108 | grayscale_guided_grads = normalize_gradient(convert_to_grayscale(guided_grads)) 109 | 110 | pos_sal, neg_sal = get_positive_negative_saliency(guided_grads) 111 | pos_sal_grads = normalize_gradient(pos_sal) 112 | neg_sal_grads = normalize_gradient(neg_sal) 113 | 114 | grad_times_image = guided_grads[0] * input_image.detach().numpy()[0] 115 | grad_times_image = convert_to_grayscale(grad_times_image) 116 | grad_times_image = normalize_gradient(grad_times_image) 117 | 118 | return { 119 | "guided_grads_colored": color_guided_grads, 120 | "guided_grads_grayscale": grayscale_guided_grads, 121 | "guided_grads_grayscale_grad_times_image": grad_times_image, 122 | "saliency_maps_positive": pos_sal_grads, 123 | "saliency_maps_negative": neg_sal_grads, 124 | } 125 | -------------------------------------------------------------------------------- /src/pytorchxai/xai/gradient_guided_gradcam.py: -------------------------------------------------------------------------------- 1 | """ 2 | GradCAM helps vizualing which parts of an input image trigger the predicted class, by backpropagating the gradients to the last convolutional layer, producing a coarse heatmap. 3 | 4 | Guided GradCAM is obtained by fusing GradCAM with Guided Backpropagation via element-wise multiplication, and results in a heatmap highliting much finer details. 5 | 6 | This technique is only useful for inspecting an already trained network, not for training it, as the backpropagation on ReLU will be changed for computing the Guided Backpropagation. 7 | 8 | [1] Selvaraju, Ramprasaath R., et al. "Grad-cam: Visual explanations from deep networks via gradient-based localization." Proceedings of the IEEE International Conference on Computer Vision. 2017. 9 | """ 10 | import numpy as np 11 | 12 | from pytorchxai.xai.cam_gradcam import GradCam 13 | from pytorchxai.xai.gradient_guided_backprop import GuidedBackprop 14 | from pytorchxai.xai.utils import convert_to_grayscale, normalize_gradient 15 | 16 | 17 | class GuidedGradCam: 18 | def __init__(self, model): 19 | self.model = model 20 | 21 | self.gradcam = GradCam(model) 22 | self.gbp = GuidedBackprop(model) 23 | 24 | def generate(self, orig_image, input_image, target_class): 25 | """ 26 | Guided gradcam is just pointwise multiplication of the cam mask and 27 | the guided backprop mask. 28 | 29 | Args: 30 | orig_image: Original resized image. 
31 | input_image: Preprocessed input image. 32 | target_class: Expected category. 33 | Returns: 34 | Colored and grayscale gradients for the guided Grad-CAM. 35 | """ 36 | cam = self.gradcam.generate_cam(input_image, target_class) 37 | guided_grads = self.gbp.generate_gradients(input_image, target_class) 38 | 39 | cam_gb = np.multiply(cam, guided_grads) 40 | 41 | guided_gradcam = normalize_gradient(cam_gb) 42 | grayscale_cam_gb = convert_to_grayscale(cam_gb) 43 | guided_gradcam_grayscale = normalize_gradient(grayscale_cam_gb) 44 | 45 | return { 46 | "guided_gradcam": guided_gradcam, 47 | "guided_gradcam_grayscale": guided_gradcam_grayscale, 48 | } 49 | -------------------------------------------------------------------------------- /src/pytorchxai/xai/gradient_integrated_grad.py: -------------------------------------------------------------------------------- 1 | """ 2 | Integrated Gradients aims to explain the relationship between a model's predictions in terms of its features. It has many use cases including understanding feature importances, identifying data skew, and debugging model performance. 3 | 4 | Algorithm details: 5 | - It constructs a sequence of images interpolating from a baseline (black) to the actual image. 6 | - It averages the gradients across these images. 7 | 8 | Other use cases: 9 | - Text Classification 10 | - Language translation 11 | - Search Ranking 12 | 13 | [1] Sundararajan, Taly, Yan et al. "Axiomatic Attribution for Deep Networks.", Proceedings of International Conference on Machine Learning (ICML), 2017. 14 | [2] https://github.com/ankurtaly/Integrated-Gradients 15 | """ 16 | import numpy as np 17 | import torch 18 | 19 | from pytorchxai.xai.utils import convert_to_grayscale, normalize_gradient 20 | 21 | 22 | class IntegratedGradients: 23 | """ 24 | Produces gradients generated with integrated gradients from the image 25 | """ 26 | 27 | def __init__(self, model): 28 | self.model = model 29 | self.gradients = None 30 | 31 | self.model.eval() 32 | self._hook_layers() 33 | 34 | def _hook_layers(self): 35 | def hook_function(module, grad_in, grad_out): 36 | self.gradients = grad_in[0] 37 | 38 | first_layer = list(self.model.features._modules.items())[0][1] 39 | first_layer.register_backward_hook(hook_function) 40 | 41 | def generate_images_on_linear_path(self, input_image, steps): 42 | """ 43 | Generates "steps" intermediary images. 44 | 45 | Args: 46 | input_image: Preprocessed input image. 47 | steps: Numbers of intermediary images. 48 | Returns: 49 | An array of "steps" images. 50 | """ 51 | step_list = np.arange(steps + 1) / steps 52 | 53 | return [input_image * step for step in step_list] 54 | 55 | def generate_gradients(self, input_image, target_class): 56 | """ 57 | Generates the gradients for the given model and image. 58 | 59 | Args: 60 | input_image: Preprocessed input image. 61 | target_class: Expected category. 62 | Returns: 63 | The gradients. 64 | """ 65 | model_output = self.model(input_image) 66 | self.model.zero_grad() 67 | 68 | one_hot_output = torch.FloatTensor(1, model_output.size()[-1]).zero_() 69 | one_hot_output[0][target_class] = 1 70 | 71 | model_output.backward(gradient=one_hot_output) 72 | 73 | return self.gradients.data.numpy()[0] 74 | 75 | def generate_integrated_gradients(self, input_image, target_class, steps): 76 | """ 77 | Generates "steps" intermediary images and generates gradients for all of them. Returns the average of the gradients. 78 | 79 | Args: 80 | input_image: Preprocessed input image. 
81 | target_class: Expected category. 82 | steps: Number of intermediary images. 83 | Returns: 84 | The average of the gradients. 85 | """ 86 | xbar_list = self.generate_images_on_linear_path(input_image, steps) 87 | integrated_grads = np.zeros(input_image.size()) 88 | 89 | for xbar_image in xbar_list: 90 | single_integrated_grad = self.generate_gradients(xbar_image, target_class) 91 | integrated_grads = integrated_grads + single_integrated_grad / steps 92 | 93 | return integrated_grads[0] 94 | 95 | def generate(self, orig_image, input_image, target_class): 96 | """ 97 | Generates heatmaps using the integrated gradients method. 98 | 99 | Args: 100 | orig_image: The original image. 101 | input_image: Preprocessed input image. 102 | target_class: Expected category. 103 | Returns: 104 | The heatmaps. 105 | """ 106 | integrated_grads = self.generate_integrated_gradients( 107 | input_image, target_class, 5 108 | ) 109 | grayscale_integrated_grads = normalize_gradient( 110 | convert_to_grayscale(integrated_grads) 111 | ) 112 | 113 | grad_times_image = integrated_grads[0] * input_image.detach().numpy()[0] 114 | grad_times_image = convert_to_grayscale(grad_times_image) 115 | grad_times_image = normalize_gradient(grad_times_image) 116 | 117 | return { 118 | "integrated_gradients": grayscale_integrated_grads, 119 | "integrated_gradients_times_image": grad_times_image, 120 | } 121 | -------------------------------------------------------------------------------- /src/pytorchxai/xai/gradient_smooth_grad.py: -------------------------------------------------------------------------------- 1 | """ 2 | The SmoothGrad technique adds Gaussian noise to the original image, computes the gradients multiple times, and averages the results. 3 | It can augment other sensitivity techniques, such as vanilla gradients, integrated gradients, guided backpropagation, or GradCAM. 4 | 5 | 6 | [1] D. Smilkov, N. Thorat, N. Kim, F. Viégas, M. Wattenberg. "SmoothGrad: removing noise by adding noise", 2017. 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | from torch.autograd import Variable 12 | 13 | from pytorchxai.xai.gradient_vanilla_backprop import VanillaBackprop 14 | from pytorchxai.xai.utils import convert_to_grayscale, normalize_gradient 15 | 16 | 17 | class SmoothGrad: 18 | def __init__(self, model): 19 | self.model = model 20 | 21 | self.backprop = VanillaBackprop(model) 22 | 23 | def generate_smooth_grad( 24 | self, prep_img, target_class, param_n, param_sigma_multiplier 25 | ): 26 | """ 27 | Generates smooth gradients using the underlying backpropagation method (vanilla backpropagation here). 28 | Args: 29 | prep_img (torch Variable): Preprocessed image. 30 | target_class (int): Target ImageNet class index. 31 | param_n (int): Number of noisy images used to smooth the gradient. 32 | param_sigma_multiplier (int): Sigma multiplier when calculating the std of the noise. 33 | Returns: 34 | The gradients.
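# --- Illustrative aside (not part of gradient_smooth_grad.py) ----------------
# A hedged sketch of SmoothGrad in its generic form: average the attribution of
# several noisy copies of the input. `saliency_fn` stands for any gradient-based
# attribution function (vanilla backprop, guided backprop, ...), so both its
# name and its exact return type are assumptions for illustration.
import torch


def smooth_grad_sketch(saliency_fn, x, target_class, n_samples=25, sigma_frac=0.15):
    # Noise scale chosen relative to the dynamic range of the input, as in [1].
    sigma = sigma_frac * (x.max() - x.min()).item()
    total = None
    for _ in range(n_samples):
        noisy = x + torch.randn_like(x) * sigma       # Gaussian-perturbed copy
        grads = saliency_fn(noisy, target_class)      # attribution of the noisy copy
        total = grads if total is None else total + grads
    return total / n_samples                          # averaged, de-noised attribution
# ------------------------------------------------------------------------------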
35 | """ 36 | # Generate an empty image/matrix 37 | smooth_grad = np.zeros(prep_img.size()[1:]) 38 | 39 | mean = 0 40 | sigma = ( 41 | param_sigma_multiplier / (torch.max(prep_img) - torch.min(prep_img)).item() 42 | ) 43 | for x in range(param_n): 44 | # Generate noise 45 | noise = Variable( 46 | prep_img.data.new(prep_img.size()).normal_(mean, sigma ** 2) 47 | ) 48 | # Add noise to the image 49 | noisy_img = prep_img + noise 50 | # Calculate gradients 51 | vanilla_grads = self.backprop.generate_gradients(noisy_img, target_class) 52 | # Add gradients to smooth_grad 53 | smooth_grad = smooth_grad + vanilla_grads 54 | # Average it out 55 | smooth_grad = smooth_grad / param_n 56 | return smooth_grad 57 | 58 | def generate(self, orig_image, input_image, target_class): 59 | """ 60 | Generates and returns multiple sensitivy heatmaps, based on SmoothGrad technique. 61 | Args: 62 | orig_image: Original resized image. 63 | input_image: Preprocessed input image. 64 | target_class: Expected category. 65 | Returns: 66 | Colored and grayscale gradients for the SmoothGrad backpropagation. 67 | """ 68 | param_n = 5 69 | param_sigma_multiplier = 4 70 | smooth_grad = self.generate_smooth_grad( 71 | input_image, target_class, param_n, param_sigma_multiplier 72 | ) 73 | 74 | color_smooth_grad = normalize_gradient(smooth_grad) 75 | 76 | grayscale_smooth_grad = convert_to_grayscale(smooth_grad) 77 | grayscale_smooth_grad = normalize_gradient(grayscale_smooth_grad) 78 | 79 | return { 80 | "smooth_grad_colored": color_smooth_grad, 81 | "smooth_grad_grayscale": grayscale_smooth_grad, 82 | } 83 | -------------------------------------------------------------------------------- /src/pytorchxai/xai/gradient_vanilla_backprop.py: -------------------------------------------------------------------------------- 1 | """ 2 | Vanilla Backpropagation generates heat maps that are intended to provide insight into what aspects of an input image a convolutional neural network is using to make a prediction. 3 | 4 | Algorithm details: 5 | - The algorithm uses only the positive inputs. 6 | 7 | [1] Springenberg et al. "Striving for Simplicity: The All Convolutional Net", 2014. 8 | """ 9 | 10 | import torch 11 | 12 | from pytorchxai.xai.utils import convert_to_grayscale, normalize_gradient 13 | 14 | 15 | class VanillaBackprop: 16 | def __init__(self, model): 17 | self.model = model 18 | self.gradients = None 19 | 20 | self.model.eval() 21 | self._hook_layers() 22 | 23 | def _hook_layers(self): 24 | def hook_function(module, grad_in, grad_out): 25 | self.gradients = grad_in[0] 26 | 27 | first_layer = list(self.model.features._modules.items())[0][1] 28 | first_layer.register_backward_hook(hook_function) 29 | 30 | def generate_gradients(self, input_image, target_class): 31 | """ 32 | Generates the gradients using vanilla backpropagation from the given model and image. 33 | 34 | Args: 35 | input_image: Preprocessed input image. 36 | target_class: Expected category. 37 | Returns: 38 | The gradients computed using the vanilla backpropagation. 
39 | """ 40 | 41 | model_output = self.model(input_image) 42 | self.model.zero_grad() 43 | 44 | one_hot_output = torch.FloatTensor(1, model_output.size()[-1]).zero_() 45 | one_hot_output[0][target_class] = 1 46 | 47 | model_output.backward(gradient=one_hot_output) 48 | gradients_as_arr = self.gradients.data.numpy()[0] 49 | return gradients_as_arr 50 | 51 | def generate(self, orig_image, input_image, target_class): 52 | """ 53 | Generates and returns multiple saliency maps, based on vanilla backpropagation. 54 | 55 | Args: 56 | orig_image: Original resized image. 57 | input_image: Preprocessed input image. 58 | target_class: Expected category. 59 | Returns: 60 | Colored and grayscale gradients for the vanilla backpropagation. 61 | Grayscale gradients multiplied with the image itself. 62 | """ 63 | vanilla_grads = self.generate_gradients(input_image, target_class) 64 | 65 | color_vanilla_bp = normalize_gradient(vanilla_grads) 66 | 67 | grayscale_vanilla_bp = convert_to_grayscale(vanilla_grads) 68 | grayscale_vanilla_bp = normalize_gradient(grayscale_vanilla_bp) 69 | 70 | grad_times_image = vanilla_grads[0] * input_image.detach().numpy()[0] 71 | grad_times_image = convert_to_grayscale(grad_times_image) 72 | grad_times_image = normalize_gradient(grad_times_image) 73 | 74 | return { 75 | "vanilla_colored_backpropagation": color_vanilla_bp, 76 | "vanilla_grayscale_backpropagation": grayscale_vanilla_bp, 77 | "vanilla_grayscale_grad_times_image": grad_times_image, 78 | } 79 | -------------------------------------------------------------------------------- /src/pytorchxai/xai/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import matplotlib.cm as mpl_color_map 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from torch.autograd import Variable 8 | 9 | 10 | def convert_to_grayscale(im_as_arr): 11 | """ 12 | Converts 3d image to grayscale 13 | 14 | Args: 15 | im_as_arr (numpy arr): RGB image with shape (D,W,H) 16 | 17 | returns: 18 | grayscale_im (numpy_arr): Grayscale image with shape (1,W,D) 19 | """ 20 | grayscale_im = np.sum(np.abs(im_as_arr), axis=0) 21 | im_max = np.percentile(grayscale_im, 99) 22 | im_min = np.min(grayscale_im) 23 | grayscale_im = np.clip((grayscale_im - im_min) / (im_max - im_min), 0, 1) 24 | grayscale_im = np.expand_dims(grayscale_im, axis=0) 25 | return grayscale_im 26 | 27 | 28 | def normalize_gradient(gradient): 29 | """ 30 | Args: 31 | gradient (np arr): Numpy array of the gradient with shape (3, 224, 224) 32 | """ 33 | gradient = gradient - gradient.min() 34 | gradient /= gradient.max() 35 | 36 | return gradient 37 | 38 | 39 | def apply_colormap_on_image(orig_image, activation, colormap_name): 40 | """ 41 | Apply heatmap on image 42 | Args: 43 | activation_map (numpy arr): Activation map (grayscale) 0-255 44 | colormap_name (str): Name of the colormap 45 | """ 46 | # Get colormap 47 | color_map = mpl_color_map.get_cmap(colormap_name) 48 | no_trans_heatmap = color_map(activation) 49 | # Change alpha channel in colormap to make sure original image is displayed 50 | heatmap = copy.copy(no_trans_heatmap) 51 | heatmap[:, :, 3] = 0.4 52 | heatmap = Image.fromarray((heatmap * 255).astype(np.uint8)) 53 | no_trans_heatmap = Image.fromarray((no_trans_heatmap * 255).astype(np.uint8)) 54 | 55 | heatmap_on_image = Image.new("RGBA", orig_image.size) 56 | heatmap_on_image = Image.alpha_composite( 57 | heatmap_on_image, orig_image.convert("RGBA") 58 | ) 59 | heatmap_on_image = 
Image.alpha_composite(heatmap_on_image, heatmap) 60 | 61 | return no_trans_heatmap, heatmap_on_image 62 | 63 | 64 | def preprocess_image(pil_im, resize_im=True): 65 | """ 66 | Processes image for CNNs 67 | 68 | Args: 69 | PIL_img (PIL_img): PIL Image or numpy array to process 70 | resize_im (bool): Resize to 224 or not 71 | returns: 72 | im_as_var (torch variable): Variable that contains processed float tensor 73 | """ 74 | # mean and std list for channels (Imagenet) 75 | mean = [0.485, 0.456, 0.406] 76 | std = [0.229, 0.224, 0.225] 77 | 78 | # ensure or transform incoming image to PIL image 79 | if type(pil_im) != Image.Image: 80 | try: 81 | pil_im = Image.fromarray(pil_im) 82 | except Exception as e: 83 | print( 84 | "Please check input. err: ", e, 85 | ) 86 | 87 | # Resize image 88 | if resize_im: 89 | pil_im = pil_im.resize((224, 224), Image.ANTIALIAS) 90 | 91 | im_as_arr = np.float32(pil_im) 92 | im_as_arr = im_as_arr.transpose(2, 0, 1) # Convert array to D,W,H 93 | # Normalize the channels 94 | for channel, _ in enumerate(im_as_arr): 95 | im_as_arr[channel] /= 255 96 | im_as_arr[channel] -= mean[channel] 97 | im_as_arr[channel] /= std[channel] 98 | # Convert to float tensor 99 | im_as_ten = torch.from_numpy(im_as_arr).float() 100 | # Add one more channel to the beginning. Tensor shape = 1,3,224,224 101 | im_as_ten.unsqueeze_(0) 102 | # Convert to Pytorch variable 103 | im_as_var = Variable(im_as_ten, requires_grad=True) 104 | return im_as_var 105 | 106 | 107 | def recreate_image(im_as_var): 108 | """ 109 | Recreates images from a torch variable, sort of reverse preprocessing 110 | Args: 111 | im_as_var (torch variable): Image to recreate 112 | returns: 113 | recreated_im (numpy arr): Recreated image in array 114 | """ 115 | reverse_mean = [-0.485, -0.456, -0.406] 116 | reverse_std = [1 / 0.229, 1 / 0.224, 1 / 0.225] 117 | recreated_im = copy.copy(im_as_var.data.numpy()[0]) 118 | for c in range(3): 119 | recreated_im[c] /= reverse_std[c] 120 | recreated_im[c] -= reverse_mean[c] 121 | recreated_im[recreated_im > 1] = 1 122 | recreated_im[recreated_im < 0] = 0 123 | recreated_im = np.round(recreated_im * 255) 124 | 125 | recreated_im = np.uint8(recreated_im).transpose(1, 2, 0) 126 | return recreated_im 127 | 128 | 129 | def get_positive_negative_saliency(gradient): 130 | """ 131 | Generates positive and negative saliency maps based on the gradient 132 | Args: 133 | gradient (numpy arr): Gradient of the operation to visualize 134 | 135 | returns: 136 | pos_saliency ( ) 137 | """ 138 | pos_saliency = np.maximum(0, gradient) / gradient.max() 139 | neg_saliency = np.maximum(0, -gradient) / -gradient.min() 140 | return pos_saliency, neg_saliency 141 | -------------------------------------------------------------------------------- /src/pytorchxai/xai/visualizations.py: -------------------------------------------------------------------------------- 1 | from pytorchxai.xai.cam_gradcam import GradCam 2 | from pytorchxai.xai.cam_scorecam import ScoreCam 3 | from pytorchxai.xai.gradient_guided_backprop import GuidedBackprop 4 | from pytorchxai.xai.gradient_guided_gradcam import GuidedGradCam 5 | from pytorchxai.xai.gradient_integrated_grad import IntegratedGradients 6 | from pytorchxai.xai.gradient_smooth_grad import SmoothGrad 7 | from pytorchxai.xai.gradient_vanilla_backprop import VanillaBackprop 8 | 9 | 10 | class GradientVisualization: 11 | def __init__(self, model): 12 | self.model = model 13 | 14 | self.visualizations = [ 15 | GuidedBackprop(model), 16 | VanillaBackprop(model), 17 | 
ScoreCam(model), 18 | GradCam(model), 19 | GuidedGradCam(model), 20 | IntegratedGradients(model), 21 | SmoothGrad(model), 22 | ] 23 | 24 | def generate(self, orig_image, input_image, target): 25 | results = {} 26 | for v in self.visualizations: 27 | results.update(v.generate(orig_image, input_image, target)) 28 | return results 29 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Dummy conftest.py for pytorchxai. 4 | 5 | If you don't know what this is for, just leave it empty. 6 | Read more about conftest.py under: 7 | https://pytest.org/latest/plugins.html 8 | """ 9 | 10 | # import pytest 11 | -------------------------------------------------------------------------------- /tests/xai/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/medtorch/Q-Aid-Core/9ea70c1a1ab66323aa21484a8066512c9cd4fc43/tests/xai/__init__.py -------------------------------------------------------------------------------- /tests/xai/test_gradcam.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from .utils import MODELS, create_image 4 | 5 | from pytorchxai.xai.cam_gradcam import GradCam 6 | from pytorchxai.xai.cam_utils import CamExtractor 7 | 8 | 9 | @pytest.mark.parametrize("model", MODELS) 10 | def test_sanity(model): 11 | generator = GradCam(model) 12 | assert generator is not None 13 | 14 | 15 | @pytest.mark.parametrize("model", MODELS) 16 | def test_generate_cam(model): 17 | generator = GradCam(model) 18 | 19 | _, test_image, test_target = create_image() 20 | 21 | guided_grads = generator.generate_cam(test_image, test_target) 22 | 23 | assert guided_grads.shape == (224, 224) 24 | 25 | 26 | @pytest.mark.parametrize("model", MODELS) 27 | def test_generate(model): 28 | generator = GradCam(model) 29 | 30 | test_image, test_input, test_target = create_image() 31 | 32 | output = generator.generate(test_image, test_input, test_target) 33 | 34 | expected = [ 35 | "gradcam_heatmap", 36 | "gradcam_heatmap_on_image", 37 | "gradcam_grayscale", 38 | ] 39 | 40 | for check in expected: 41 | assert check in output 42 | print(check) 43 | if "grayscale" in check: 44 | assert output[check].shape == (1, 224, 224) 45 | else: 46 | assert output[check].shape == ( 47 | 4, 48 | 224, 49 | 224, 50 | ) # gradcam adds alpha to the image 51 | 52 | 53 | @pytest.mark.parametrize("model", MODELS) 54 | def test_extractor(model): 55 | extractor = CamExtractor(model) 56 | _, test_input, _ = create_image() 57 | 58 | conv_output, model_output = extractor.forward_pass(test_input) 59 | 60 | assert model_output.shape[1] == list(model.children())[-1][-1].out_features 61 | assert conv_output is not None 62 | -------------------------------------------------------------------------------- /tests/xai/test_guided_backprop.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from .utils import MODELS, create_image 4 | 5 | from pytorchxai.xai.gradient_guided_backprop import GuidedBackprop 6 | 7 | 8 | @pytest.mark.parametrize("model", MODELS) 9 | def test_sanity(model): 10 | generator = GuidedBackprop(model) 11 | assert generator is not None 12 | 13 | 14 | @pytest.mark.parametrize("model", MODELS) 15 | def test_generate_gradients(model): 16 | generator = GuidedBackprop(model) 17 | 18 | _, 
test_image, test_target = create_image() 19 | 20 | guided_grads = generator.generate_gradients(test_image, test_target) 21 | 22 | assert guided_grads.shape == (3, 224, 224) 23 | 24 | 25 | @pytest.mark.parametrize("model", MODELS) 26 | def test_generate(model): 27 | generator = GuidedBackprop(model) 28 | 29 | test_image, test_input, test_target = create_image() 30 | 31 | output = generator.generate(test_image, test_input, test_target) 32 | 33 | expected = [ 34 | "guided_grads_colored", 35 | "guided_grads_grayscale", 36 | "guided_grads_grayscale_grad_times_image", 37 | "saliency_maps_positive", 38 | "saliency_maps_negative", 39 | ] 40 | 41 | for check in expected: 42 | assert check in output 43 | if "grayscale" in check: 44 | assert output[check].shape == (1, 224, 224) 45 | else: 46 | assert output[check].shape == (3, 224, 224) 47 | -------------------------------------------------------------------------------- /tests/xai/test_guided_gradcam.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from .utils import MODELS, create_image 4 | 5 | from pytorchxai.xai.gradient_guided_gradcam import GuidedGradCam 6 | 7 | 8 | @pytest.mark.parametrize("model", MODELS) 9 | def test_sanity(model): 10 | generator = GuidedGradCam(model) 11 | assert generator is not None 12 | 13 | 14 | @pytest.mark.parametrize("model", MODELS) 15 | def test_generate(model): 16 | generator = GuidedGradCam(model) 17 | 18 | test_image, test_input, test_target = create_image() 19 | 20 | output = generator.generate(test_image, test_input, test_target) 21 | 22 | expected = [ 23 | "guided_gradcam", 24 | "guided_gradcam_grayscale", 25 | ] 26 | 27 | for check in expected: 28 | assert check in output 29 | if "grayscale" in check: 30 | assert output[check].shape == (1, 224, 224) 31 | else: 32 | assert output[check].shape == (3, 224, 224) 33 | -------------------------------------------------------------------------------- /tests/xai/test_integrated_gradients.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from .utils import MODELS, create_image 4 | 5 | from pytorchxai.xai.gradient_integrated_grad import IntegratedGradients 6 | 7 | 8 | @pytest.mark.parametrize("model", MODELS) 9 | def test_sanity(model): 10 | generator = IntegratedGradients(model) 11 | assert generator is not None 12 | 13 | 14 | @pytest.mark.parametrize("model", MODELS) 15 | def test_generate(model): 16 | generator = IntegratedGradients(model) 17 | 18 | test_image, test_input, test_target = create_image() 19 | 20 | output = generator.generate(test_image, test_input, test_target) 21 | 22 | expected = [ 23 | "integrated_gradients", 24 | "integrated_gradients_times_image", 25 | ] 26 | 27 | for check in expected: 28 | assert check in output 29 | assert output[check].shape == (1, 224, 224) 30 | 31 | 32 | @pytest.mark.parametrize("model", MODELS) 33 | def test_generate_images_on_linear_path(model): 34 | generator = IntegratedGradients(model) 35 | 36 | _, test_input, _ = create_image() 37 | 38 | imgs = generator.generate_images_on_linear_path(test_input, 10) 39 | 40 | assert len(imgs) == 11 41 | for img in imgs: 42 | assert img.shape == (1, 3, 224, 224) 43 | 44 | 45 | @pytest.mark.parametrize("model", MODELS) 46 | def test_generate_integrated_gradients(model): 47 | generator = IntegratedGradients(model) 48 | 49 | _, test_input, test_target = create_image() 50 | 51 | output = generator.generate_integrated_gradients(test_input, test_target, 2) 52 | 53 | 
assert output.shape == (3, 224, 224) 54 | -------------------------------------------------------------------------------- /tests/xai/test_scorecam.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from .utils import MODELS, create_image 4 | 5 | from pytorchxai.xai.cam_scorecam import ScoreCam 6 | 7 | 8 | @pytest.mark.parametrize("model", MODELS[:1]) 9 | def test_sanity(model): 10 | generator = ScoreCam(model) 11 | assert generator is not None 12 | 13 | 14 | @pytest.mark.parametrize("model", MODELS[:1]) 15 | def test_generate_cam(model): 16 | generator = ScoreCam(model) 17 | 18 | _, test_image, test_target = create_image() 19 | 20 | guided_grads = generator.generate_cam(test_image, test_target) 21 | 22 | assert guided_grads.shape == (224, 224) 23 | 24 | 25 | @pytest.mark.parametrize("model", MODELS[:1]) 26 | def test_generate(model): 27 | generator = ScoreCam(model) 28 | 29 | test_image, test_input, test_target = create_image() 30 | 31 | output = generator.generate(test_image, test_input, test_target) 32 | 33 | expected = [ 34 | "scorecam_heatmap", 35 | "scorecam_heatmap_on_image", 36 | "scorecam_grayscale", 37 | ] 38 | 39 | for check in expected: 40 | assert check in output 41 | print(check) 42 | if "grayscale" in check: 43 | assert output[check].shape == (1, 224, 224) 44 | else: 45 | assert output[check].shape == ( 46 | 4, 47 | 224, 48 | 224, 49 | ) # scorecam adds alpha to the image 50 | -------------------------------------------------------------------------------- /tests/xai/test_smooth_grad.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from .utils import MODELS, create_image 4 | 5 | from pytorchxai.xai.gradient_smooth_grad import SmoothGrad 6 | 7 | 8 | @pytest.mark.parametrize("model", MODELS) 9 | def test_sanity(model): 10 | generator = SmoothGrad(model) 11 | assert generator is not None 12 | 13 | 14 | @pytest.mark.parametrize("model", MODELS) 15 | def test_generate_smooth_grad(model): 16 | generator = SmoothGrad(model) 17 | 18 | _, test_image, test_target = create_image() 19 | 20 | guided_grads = generator.generate_smooth_grad(test_image, test_target, 5, 4) 21 | 22 | assert guided_grads.shape == (3, 224, 224) 23 | 24 | 25 | @pytest.mark.parametrize("model", MODELS) 26 | def test_generate(model): 27 | generator = SmoothGrad(model) 28 | 29 | test_image, test_input, test_target = create_image() 30 | 31 | output = generator.generate(test_image, test_input, test_target) 32 | 33 | expected = [ 34 | "smooth_grad_colored", 35 | "smooth_grad_grayscale", 36 | ] 37 | 38 | for check in expected: 39 | assert check in output 40 | if "grayscale" in check: 41 | assert output[check].shape == (1, 224, 224) 42 | else: 43 | assert output[check].shape == (3, 224, 224) 44 | -------------------------------------------------------------------------------- /tests/xai/test_vanilla_backprop.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from .utils import MODELS, create_image 4 | 5 | from pytorchxai.xai.gradient_vanilla_backprop import VanillaBackprop 6 | 7 | 8 | @pytest.mark.parametrize("model", MODELS) 9 | def test_sanity(model): 10 | generator = VanillaBackprop(model) 11 | assert generator is not None 12 | 13 | 14 | @pytest.mark.parametrize("model", MODELS) 15 | def test_generate_gradients(model): 16 | generator = VanillaBackprop(model) 17 | 18 | _, test_image, test_target = create_image() 19 | 20 | guided_grads = 
generator.generate_gradients(test_image, test_target) 21 | 22 | assert guided_grads.shape == (3, 224, 224) 23 | 24 | 25 | @pytest.mark.parametrize("model", MODELS) 26 | def test_generate(model): 27 | generator = VanillaBackprop(model) 28 | 29 | test_image, test_input, test_target = create_image() 30 | 31 | output = generator.generate(test_image, test_input, test_target) 32 | 33 | expected = [ 34 | "vanilla_colored_backpropagation", 35 | "vanilla_grayscale_backpropagation", 36 | "vanilla_grayscale_grad_times_image", 37 | ] 38 | 39 | for check in expected: 40 | assert check in output 41 | if "grayscale" in check: 42 | assert output[check].shape == (1, 224, 224) 43 | else: 44 | assert output[check].shape == (3, 224, 224) 45 | -------------------------------------------------------------------------------- /tests/xai/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from torchvision import models 4 | 5 | from pytorchxai.xai.utils import preprocess_image 6 | 7 | 8 | def create_image(width=244, height=244): 9 | width = int(width) 10 | height = int(height) 11 | 12 | rgb_array = np.random.rand(height, width, 3) * 255 13 | image = Image.fromarray(rgb_array.astype("uint8")).convert("RGB") 14 | image = image.resize((224, 224), Image.ANTIALIAS) 15 | prep = preprocess_image(image) 16 | 17 | target = 42 18 | 19 | return image, prep, target 20 | 21 | 22 | MODELS = [models.alexnet(pretrained=True), models.vgg19(pretrained=True)] 23 | --------------------------------------------------------------------------------
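# --- Illustrative aside (appendix; not part of the repository files above) ---
# A hedged, end-to-end sketch tying the pieces together: preprocess a PIL image
# with the helpers from pytorchxai.xai.utils, run every registered visualizer
# through GradientVisualization, and inspect the returned saliency maps. The
# function name `run_all_visualizations` and the random test image are
# illustrative; any 224x224 RGB PIL image and ImageNet class index should do.
import numpy as np
from PIL import Image
from torchvision import models

from pytorchxai.xai.utils import preprocess_image
from pytorchxai.xai.visualizations import GradientVisualization


def run_all_visualizations(pil_image, target_class):
    model = models.alexnet(pretrained=True)            # same model family the tests exercise
    pil_image = pil_image.resize((224, 224))           # the CAM overlays expect a 224x224 PIL image
    prep = preprocess_image(pil_image)                 # 1x3x224x224 tensor with grad enabled
    results = GradientVisualization(model).generate(pil_image, prep, target_class)
    for name, saliency_map in results.items():
        # Grayscale maps are (1, 224, 224); colored maps are (3, 224, 224),
        # or (4, 224, 224) for the CAM overlays that carry an alpha channel.
        print(name, np.asarray(saliency_map).shape)
    return results


if __name__ == "__main__":
    rgb = (np.random.rand(224, 224, 3) * 255).astype("uint8")
    run_all_visualizations(Image.fromarray(rgb), target_class=42)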