├── .devcontainer.json ├── .dvc ├── .gitignore └── config ├── .dvcignore ├── .gitattributes ├── .github └── workflows │ ├── deploy-model-sagemaker.yml │ ├── deploy-model-template.yml │ └── dvc-studio.yml ├── .gitignore ├── .gitlab-ci.yml ├── README.md ├── data ├── .gitignore └── pool_data.dvc ├── dvc.lock ├── dvc.yaml ├── models └── .gitignore ├── notebooks └── TrainSegModel.ipynb ├── params.yaml ├── requirements.txt ├── results └── .gitignore ├── sagemaker ├── code │ ├── inference.py │ └── requirements.txt └── deploy_model.py └── src ├── data_split.py ├── endpoint_prediction.py ├── evaluate.py └── train.py /.devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "example-cv", 3 | "image": "mcr.microsoft.com/devcontainers/python:3.10", 4 | "runArgs": ["--ipc=host"], 5 | "features": { 6 | "ghcr.io/devcontainers/features/nvidia-cuda:1": { 7 | "installCudnn": true 8 | }, 9 | "ghcr.io/iterative/features/nvtop:1": {} 10 | }, 11 | "extensions": [ 12 | "Iterative.dvc", 13 | "ms-python.python", 14 | "redhat.vscode-yaml" 15 | ], 16 | "postCreateCommand": "pip install --user -r requirements.txt" 17 | } 18 | -------------------------------------------------------------------------------- /.dvc/.gitignore: -------------------------------------------------------------------------------- 1 | /config.local 2 | /tmp 3 | /cache 4 | -------------------------------------------------------------------------------- /.dvc/config: -------------------------------------------------------------------------------- 1 | [core] 2 | remote = storage 3 | ['remote "storage"'] 4 | url = https://remote.dvc.org/get-started-pools 5 | -------------------------------------------------------------------------------- /.dvcignore: -------------------------------------------------------------------------------- 1 | # Add patterns of files dvc should ignore, which could improve 2 | # the performance. 
Learn more at 3 | # https://dvc.org/doc/user-guide/dvcignore 4 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.dvc linguist-language=YAML 2 | dvc.lock linguist-language=YAML 3 | -------------------------------------------------------------------------------- /.github/workflows/deploy-model-sagemaker.yml: -------------------------------------------------------------------------------- 1 | name: Deploy model (Sagemaker) 2 | 3 | on: 4 | push: 5 | tags: 6 | - "*" 7 | 8 | permissions: 9 | contents: write 10 | id-token: write 11 | 12 | jobs: 13 | parse: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: "Parse GTO tag" 18 | id: gto 19 | uses: iterative/gto-action@v2 20 | outputs: 21 | event: ${{ steps.gto.outputs.event }} 22 | name: ${{ steps.gto.outputs.name }} 23 | stage: ${{ steps.gto.outputs.stage }} 24 | version: ${{ steps.gto.outputs.version }} 25 | 26 | deploy-model: 27 | needs: parse 28 | if: "${{ needs.parse.outputs.event == 'assignment' }}" 29 | environment: cloud 30 | runs-on: ubuntu-latest 31 | steps: 32 | - uses: actions/checkout@v3 33 | with: 34 | fetch-depth: 0 35 | 36 | - uses: aws-actions/configure-aws-credentials@v4 37 | with: 38 | aws-region: us-east-2 39 | role-to-assume: ${{ vars.AWS_SANDBOX_ROLE }} 40 | role-duration-seconds: 43200 41 | 42 | - name: Set up Python 43 | uses: actions/setup-python@v4 44 | with: 45 | python-version: '3.8' 46 | cache: 'pip' 47 | cache-dependency-path: requirements.txt 48 | 49 | - run: pip install -r requirements.txt 50 | 51 | - run: dvc remote add -d --local storage s3://dvc-public/remote/get-started-pools 52 | 53 | - run: | 54 | MODEL_DATA=$(dvc get --show-url . 
model.tar.gz) 55 | python sagemaker/deploy_model.py \ 56 | --name ${{ needs.parse.outputs.name }} \ 57 | --stage ${{ needs.parse.outputs.stage }} \ 58 | --version ${{ needs.parse.outputs.version }} \ 59 | --model_data $MODEL_DATA \ 60 | --role ${{ vars.AWS_SANDBOX_ROLE }} 61 | -------------------------------------------------------------------------------- /.github/workflows/deploy-model-template.yml: -------------------------------------------------------------------------------- 1 | name: Deploy Model (Template) 2 | 3 | on: 4 | # the workflow is triggered whenever a tag is pushed to the repository 5 | push: 6 | tags: 7 | - "*" 8 | jobs: 9 | 10 | # This job parses the git tag with the GTO GitHub Action to identify model registry actions 11 | parse: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v3 15 | - name: "Parse GTO tag" 16 | id: gto 17 | uses: iterative/gto-action@v2 18 | outputs: 19 | event: ${{ steps.gto.outputs.event }} 20 | name: ${{ steps.gto.outputs.name }} 21 | stage: ${{ steps.gto.outputs.stage }} 22 | version: ${{ steps.gto.outputs.version }} 23 | 24 | deploy-model: 25 | needs: parse 26 | # using the outputs from the "parse" job, we run this job only for actions 27 | # in the model registry and only when the model was assigned to a stage called "prod" 28 | if: ${{ needs.parse.outputs.event == 'assignment' && needs.parse.outputs.stage == 'prod' }} 29 | runs-on: ubuntu-latest 30 | steps: 31 | - uses: iterative/setup-dvc@v1 32 | # this step uses DVC to download the model from our remote repository and deploys the model 33 | # Model deployment is mocked here as it is specific to each deployment environment 34 | # The DVC Studio token is used to avoid having to store specific remote storage credentials on GitHub 35 | - name: Get Model For Deployment 36 | run: | 37 | dvc config --global studio.token ${{ secrets.DVC_STUDIO_TOKEN }} 38 | dvc artifacts get ${{ github.server_url }}/${{ github.repository }} ${{ needs.parse.outputs.name 
}} --rev ${{ needs.parse.outputs.version }} 39 | echo "The right model is available and you can use the rest of this command to deploy it. Good job!" 40 | -------------------------------------------------------------------------------- /.github/workflows/dvc-studio.yml: -------------------------------------------------------------------------------- 1 | name: DVC Studio Experiment 2 | 3 | on: 4 | 5 | push: 6 | tags-ignore: 7 | - '**' 8 | 9 | workflow_dispatch: 10 | inputs: 11 | exp-run-args: 12 | description: 'Args to be passed to dvc exp run call' 13 | required: false 14 | type: string 15 | default: '' 16 | parent-sha: 17 | description: 'SHA of the commit to start the experiment from' 18 | required: false 19 | type: string 20 | default: '' 21 | cloud: 22 | description: 'Cloud compute provider to host the runner' 23 | required: false 24 | default: 'aws' 25 | type: choice 26 | options: 27 | - aws 28 | - azure 29 | - gcp 30 | type: 31 | description: 'https://registry.terraform.io/providers/iterative/iterative/latest/docs/resources/task#machine-type' 32 | required: false 33 | default: 'g5.2xlarge' 34 | region: 35 | description: 'https://registry.terraform.io/providers/iterative/iterative/latest/docs/resources/task#cloud-region' 36 | required: false 37 | default: 'us-east' 38 | spot: 39 | description: 'Request a spot instance' 40 | required: false 41 | default: false 42 | type: boolean 43 | storage: 44 | description: 'Disk size in GB' 45 | required: false 46 | default: 40 47 | type: number 48 | timeout: 49 | description: 'Timeout in seconds' 50 | required: false 51 | default: 3600 52 | type: number 53 | 54 | permissions: 55 | contents: write 56 | id-token: write 57 | pull-requests: write 58 | 59 | jobs: 60 | 61 | deploy-runner: 62 | if: ${{ (github.actor == 'iterative-studio[bot]') || (github.event_name == 'workflow_dispatch') }} 63 | environment: cloud 64 | runs-on: ubuntu-latest 65 | 66 | steps: 67 | - uses: actions/checkout@v3 68 | with: 69 | ref: ${{ 
inputs.parent-sha || '' }} 70 | - uses: iterative/setup-cml@v2 71 | - uses: aws-actions/configure-aws-credentials@v4 72 | with: 73 | aws-region: us-east-2 74 | role-to-assume: ${{ vars.AWS_SANDBOX_ROLE }} 75 | role-duration-seconds: 43200 76 | - name: Create Runner 77 | env: 78 | REPO_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }} 79 | run: | 80 | cml runner launch --single \ 81 | --labels=cml \ 82 | --cloud=${{ inputs.cloud || 'aws' }} \ 83 | --cloud-region=${{ inputs.region || 'us-east' }} \ 84 | --cloud-hdd-size=${{ inputs.storage || '40' }} \ 85 | --cloud-type=${{ inputs.type || 'g5.2xlarge' }} \ 86 | --idle-timeout=${{ inputs.timeout || '3600' }} \ 87 | ${{ (inputs.spot == 'true' && '--cloud-spot') || '' }} 88 | 89 | runner-job: 90 | needs: deploy-runner 91 | runs-on: [ self-hosted, cml ] 92 | environment: cloud 93 | container: 94 | image: iterativeai/cml:latest-gpu 95 | options: --gpus all --ipc host 96 | 97 | steps: 98 | - uses: actions/checkout@v3 99 | with: 100 | ref: ${{ inputs.parent-sha || '' }} 101 | - uses: aws-actions/configure-aws-credentials@v4 102 | with: 103 | aws-region: us-east-2 104 | role-to-assume: ${{ vars.AWS_SANDBOX_ROLE }} 105 | role-duration-seconds: 43200 106 | 107 | - run: pip install -r requirements.txt 108 | 109 | - name: Train 110 | env: 111 | REPO_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }} 112 | DVC_STUDIO_TOKEN: ${{ secrets.DVC_STUDIO_TOKEN }} 113 | DVCLIVE_LOGLEVEL: DEBUG 114 | run: | 115 | cml ci --fetch-depth 0 116 | dvc exp run --pull --allow-missing ${{ github.event.inputs.exp-run-args }} 117 | dvc remote add --local push_remote s3://dvc-public/remote/get-started-pools 118 | 119 | - name: Workflow Dispatch Sharing 120 | if: github.event_name == 'workflow_dispatch' 121 | env: 122 | DVC_STUDIO_TOKEN: ${{ secrets.DVC_STUDIO_TOKEN }} 123 | run: | 124 | dvc exp push origin -r push_remote 125 | 126 | - name: Commit-based Sharing 127 | if: github.actor == 'iterative-studio[bot]' 128 | env: 129 | REPO_TOKEN: ${{ 
secrets.PERSONAL_ACCESS_TOKEN }} 130 | run: | 131 | dvc push -r push_remote 132 | cml pr --squash --skip-ci . 133 | echo "## Metrics" > report.md 134 | dvc metrics diff main --md >> report.md 135 | echo "## Params" >> report.md 136 | dvc params diff main --md >> report.md 137 | cml comment create --pr report.md 138 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .venv/ 2 | /model.tar.gz 3 | -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | # Deploy Model (Template) 2 | 3 | workflow: 4 | rules: 5 | # Run the pipeline whenever a tag is pushed to the repository 6 | - if: $CI_COMMIT_TAG 7 | 8 | parse: 9 | # This job parses the model tag to identify model registry actions 10 | image: python:3.11-slim 11 | script: 12 | # Install GTO to parse model tags 13 | - pip install gto 14 | # This job parses the model tags to identify model registry actions 15 | - echo "CI_COMMIT_TAG - ${CI_COMMIT_TAG}" 16 | - echo MODEL_NAME="$(gto check-ref ${CI_COMMIT_TAG} --name)" >> parse.env 17 | - echo MODEL_VERSION="$(gto check-ref ${CI_COMMIT_TAG} --version)" >> parse.env 18 | - echo MODEL_EVENT="$(gto check-ref ${CI_COMMIT_TAG} --event)" >> parse.env 19 | - echo MODEL_STAGE="$(gto check-ref ${CI_COMMIT_TAG} --stage)" >> parse.env 20 | # Print variables saved to parse.env 21 | - cat parse.env 22 | artifacts: 23 | reports: 24 | dotenv: parse.env 25 | 26 | deploy-model: 27 | needs: 28 | - job: parse 29 | artifacts: true 30 | image: python:3.11-slim 31 | script: 32 | # Check if the model is assigned to prod (variables from parse.env are only available in the 'script' section) 33 | - if [[ $MODEL_EVENT == 'assignment' && $MODEL_STAGE == 'prod' ]]; then echo "Deploy model"; else exit 1; fi 34 | # Install DVC 35 | - pip install dvc 36 
| # Build commands to download and deploy the model 37 | - dvc config --global studio.token ${DVC_STUDIO_TOKEN} 38 | - dvc artifacts get ${CI_REPOSITORY_URL} ${MODEL_NAME} --rev ${MODEL_VERSION} 39 | - echo "The right model is available and you can use the rest of this command to deploy it. Good job!" -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![DVC](https://img.shields.io/badge/-Open_in_Studio-grey.svg?style=flat-square&logo=dvc)](https://studio.iterative.ai/team/Iterative/projects/example-get-started-experiments-y8toqd433r) 2 | [![DVC-metrics](https://img.shields.io/badge/dynamic/json?style=flat-square&colorA=grey&colorB=F46737&label=Dice%20Metric&url=https://github.com/iterative/example-get-started-experiments/raw/main/results/evaluate/metrics.json&query=dice_multi)](https://github.com/iterative/example-get-started-experiments/raw/main/results/evaluate/metrics.json) 3 | 4 | [Train Report](./results/train/report.md) - [Evaluation Report](./results/evaluate/report.md) 5 | 6 | # DVC Get Started: Experiments 7 | 8 | This is an auto-generated repository for use in [DVC](https://dvc.org) 9 | [Get Started: Experiments](https://dvc.org/doc/start/experiment-management). 10 | 11 | This is a Computer Vision (CV) project that solves the problem of segmenting out 12 | swimming pools from satellite images. 13 | 14 | [Example results](./results/evaluate/plots/images/) 15 | 16 | We use a slightly modified version of the [BH-Pools dataset](http://patreo.dcc.ufmg.br/2020/07/29/bh-pools-watertanks-datasets/): 17 | we split the original 4k images into tiles of 1024x1024 pixels. 18 | 19 | 20 | 🐛 Please report any issues found in this project here - 21 | [example-repos-dev](https://github.com/iterative/example-repos-dev). 22 | 23 | ## Installation 24 | 25 | Python 3.8+ is required to run code from this repo. 
26 | 27 | ```console 28 | $ git clone https://github.com/iterative/example-get-started-experiments 29 | $ cd example-get-started-experiments 30 | ``` 31 | 32 | Now let's install the requirements. But before we do that, we **strongly** 33 | recommend creating a virtual environment with a tool such as 34 | [virtualenv](https://virtualenv.pypa.io/en/stable/): 35 | 36 | ```console 37 | $ python -m venv .venv 38 | $ source .venv/bin/activate 39 | $ pip install -r requirements.txt 40 | ``` 41 | 42 | This DVC project comes with a preconfigured DVC 43 | [remote storage](https://dvc.org/doc/commands-reference/remote) that holds raw 44 | data (input), intermediate, and final results that are produced. This is a 45 | read-only HTTP remote. 46 | 47 | ```console 48 | $ dvc remote list 49 | storage https://remote.dvc.org/get-started-pools 50 | ``` 51 | 52 | You can run [`dvc pull`](https://man.dvc.org/pull) to download the data: 53 | 54 | ```console 55 | $ dvc pull 56 | ``` 57 | 58 | ## Running in your environment 59 | 60 | Run [`dvc exp run`](https://man.dvc.org/exp/run) to reproduce the 61 | [pipeline](https://dvc.org/doc/user-guide/pipelines/defining-pipelinese): 62 | 63 | ```console 64 | $ dvc exp run 65 | Data and pipelines are up to date. 66 | ``` 67 | 68 | If you'd like to test commands like [`dvc push`](https://man.dvc.org/push), 69 | that require write access to the remote storage, the easiest way would be to set 70 | up a "local remote" on your file system: 71 | 72 | > This kind of remote is located in the local file system, but is external to 73 | > the DVC project. 
74 | 75 | ```console 76 | $ mkdir -p /tmp/dvc-storage 77 | $ dvc remote add local /tmp/dvc-storage 78 | ``` 79 | 80 | You should now be able to run: 81 | 82 | ```console 83 | $ dvc push -r local 84 | ``` 85 | 86 | ## Existing stages 87 | 88 | There is a couple of git tags in this project : 89 | 90 | ### [1-notebook-dvclive](https://github.com/iterative/example-get-started-experiments/tree/1-notebook-dvclive) 91 | 92 | Contains an end-to-end Jupyter notebook that loads data, trains a model and 93 | reports model performance. 94 | [DVCLive](https://dvc.org/doc/dvclive) is used for experiment tracking. 95 | See this [blog post](https://iterative.ai/blog/exp-tracking-dvc-python) for more 96 | details. 97 | 98 | ### [2-dvc-pipeline](https://github.com/iterative/example-get-started-experiments/tree/2-dvc-pipeline) 99 | 100 | Contains a DVC pipeline `dvc.yaml` that was created by refactoring the above 101 | notebook into individual pipeline stages. 102 | 103 | The pipeline artifacts (processed data, model file, etc) are automatically 104 | versioned. 105 | 106 | This tag also contains a GitHub Actions workflow that reruns the pipeline if any 107 | changes are introduced to the pipeline-related files. 108 | [CML](https://cml.dev/) is used in this workflow to provision a cloud-based GPU 109 | machine as well as report model performance results in Pull Requests. 110 | 111 | ## Model Deployment 112 | 113 | Check out the [GitHub Workflow](https://github.com/iterative/example-get-started-experiments/blob/main/.github/workflows/deploy-model.yml) 114 | that uses the [Iterative Studio Model Registry](https://dvc.org/doc/studio/user-guide/model-registry/what-is-a-model-registry). 115 | to deploy the model to [AWS Sagemaker](https://aws.amazon.com/es/sagemaker/) whenever a new [version is registered](https://dvc.org/doc/studio/user-guide/model-registry/register-version). 
116 | 117 | ## Project structure 118 | 119 | The data files, DVC files, and results change as stages are created one by one. 120 | After cloning and using [`dvc pull`](https://man.dvc.org/pull) to download 121 | data, models, and plots tracked by DVC, the workspace should look like this: 122 | 123 | ```console 124 | $ tree -L 2 125 | . 126 | ├── LICENSE 127 | ├── README.md 128 | ├── data. # <-- Directory with raw and intermediate data 129 | │ ├── pool_data # <-- Raw image data 130 | │ ├── pool_data.dvc # <-- .dvc file - a placeholder/pointer to raw data 131 | │ ├── test_data # <-- Processed test data 132 | │ └── train_data # <-- Processed train data 133 | ├── dvc.lock 134 | ├── dvc.yaml # <-- DVC pipeline file 135 | ├── models 136 | │ └── model.pkl # <-- Trained model file 137 | ├── notebooks 138 | │ └── TrainSegModel.ipynb # <-- Initial notebook (refactored into `dvc.yaml`) 139 | ├── params.yaml # <-- Parameters file 140 | ├── requirements.txt # <-- Python dependencies needed in the project 141 | ├── results # <-- DVCLive reports and plots 142 | │ ├── evaluate 143 | │ └── train 144 | └── src # <-- Source code to run the pipeline stages 145 | ├── data_split.py 146 | ├── evaluate.py 147 | └── train.py 148 | ``` 149 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | /pool_data 2 | /test_data 3 | /train_data 4 | -------------------------------------------------------------------------------- /data/pool_data.dvc: -------------------------------------------------------------------------------- 1 | outs: 2 | - md5: 14d187e749ee5614e105741c719fa185.dir 3 | size: 18999874 4 | nfiles: 183 5 | path: pool_data 6 | hash: md5 7 | -------------------------------------------------------------------------------- /dvc.lock: -------------------------------------------------------------------------------- 1 | schema: '2.0' 2 | stages: 3 | data_split: 
4 | cmd: python src/data_split.py 5 | deps: 6 | - path: data/pool_data 7 | hash: md5 8 | md5: 14d187e749ee5614e105741c719fa185.dir 9 | size: 18999874 10 | nfiles: 183 11 | - path: src/data_split.py 12 | hash: md5 13 | md5: 280fa1684c5496fb9f76ff8c96bd2561 14 | size: 1035 15 | params: 16 | params.yaml: 17 | base: 18 | random_seed: 42 19 | data_split: 20 | test_regions: 21 | - REGION_1 22 | outs: 23 | - path: data/test_data 24 | hash: md5 25 | md5: 1bb16eb1219b47a8bf711ade27c476e4.dir 26 | size: 2087761 27 | nfiles: 24 28 | - path: data/train_data 29 | hash: md5 30 | md5: a28a7e4d342c27c1d7ad3c17ec1dfa7a.dir 31 | size: 16905965 32 | nfiles: 158 33 | train: 34 | cmd: python src/train.py 35 | deps: 36 | - path: data/train_data 37 | hash: md5 38 | md5: a28a7e4d342c27c1d7ad3c17ec1dfa7a.dir 39 | size: 16905965 40 | nfiles: 158 41 | - path: src/train.py 42 | hash: md5 43 | md5: 9db72f1631f53eecb232bc48992c425a 44 | size: 2507 45 | params: 46 | params.yaml: 47 | base: 48 | random_seed: 42 49 | train: 50 | valid_pct: 0.1 51 | arch: shufflenet_v2_x2_0 52 | img_size: 256 53 | batch_size: 8 54 | fine_tune_args: 55 | epochs: 8 56 | base_lr: 0.01 57 | outs: 58 | - path: models/model.pkl 59 | hash: md5 60 | md5: 63cb30df484bfa1d6ea2cdebac74876c 61 | size: 201725 62 | - path: models/model.pth 63 | hash: md5 64 | md5: 07b113fe1ab01de2d8a5453dccacdd3e 65 | size: 165147 66 | - path: results/train 67 | hash: md5 68 | md5: a26429ff680d9b01c7e92043821bf41c.dir 69 | size: 955 70 | nfiles: 5 71 | evaluate: 72 | cmd: python src/evaluate.py 73 | deps: 74 | - path: data/test_data 75 | hash: md5 76 | md5: 1bb16eb1219b47a8bf711ade27c476e4.dir 77 | size: 2087761 78 | nfiles: 24 79 | - path: models/model.pkl 80 | hash: md5 81 | md5: 63cb30df484bfa1d6ea2cdebac74876c 82 | size: 201725 83 | - path: src/evaluate.py 84 | hash: md5 85 | md5: 84d2fd3b371546730396a763a51527a0 86 | size: 3322 87 | params: 88 | params.yaml: 89 | base: 90 | random_seed: 42 91 | evaluate: 92 | n_samples_to_save: 10 93 | 
outs: 94 | - path: results/evaluate 95 | hash: md5 96 | md5: 34985c391291e22ac0cfdf6e83ec7268.dir 97 | size: 1257936 98 | nfiles: 11 99 | sagemaker: 100 | cmd: cp models/model.pth sagemaker/code/model.pth && cd sagemaker && tar -cpzf 101 | model.tar.gz code/ && cd .. && mv sagemaker/model.tar.gz . && rm sagemaker/code/model.pth 102 | deps: 103 | - path: models/model.pth 104 | hash: md5 105 | md5: 07b113fe1ab01de2d8a5453dccacdd3e 106 | size: 165147 107 | outs: 108 | - path: model.tar.gz 109 | hash: md5 110 | md5: c615b812da343d71fcda080597a24525 111 | size: 145778 112 | -------------------------------------------------------------------------------- /dvc.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | - results/train/params.yaml 3 | metrics: 4 | - results/train/metrics.json 5 | - results/evaluate/metrics.json 6 | plots: 7 | - results/train/plots/metrics: 8 | x: step 9 | - results/evaluate/plots/images 10 | artifacts: 11 | pool-segmentation: 12 | path: models/model.pkl 13 | type: model 14 | desc: This is a Computer Vision (CV) model that's segmenting out swimming pools 15 | from satellite images. 
16 | labels: 17 | - cv 18 | - segmentation 19 | - satellite-images 20 | - shufflenet_v2_x2_0 21 | stages: 22 | data_split: 23 | cmd: python src/data_split.py 24 | deps: 25 | - data/pool_data 26 | - src/data_split.py 27 | params: 28 | - base 29 | - data_split 30 | outs: 31 | - data/test_data 32 | - data/train_data 33 | train: 34 | cmd: python src/train.py 35 | deps: 36 | - data/train_data 37 | - src/train.py 38 | params: 39 | - base 40 | - train 41 | outs: 42 | - models/model.pkl 43 | - models/model.pth 44 | - results/train 45 | evaluate: 46 | cmd: python src/evaluate.py 47 | deps: 48 | - data/test_data 49 | - models/model.pkl 50 | - src/evaluate.py 51 | params: 52 | - base 53 | - evaluate 54 | outs: 55 | - results/evaluate 56 | sagemaker: 57 | cmd: cp models/model.pth sagemaker/code/model.pth && cd sagemaker && tar -cpzf 58 | model.tar.gz code/ && cd .. && mv sagemaker/model.tar.gz . && rm sagemaker/code/model.pth 59 | deps: 60 | - models/model.pth 61 | outs: 62 | - model.tar.gz 63 | -------------------------------------------------------------------------------- /models/.gitignore: -------------------------------------------------------------------------------- 1 | /model.pkl 2 | /model.pth 3 | -------------------------------------------------------------------------------- /notebooks/TrainSegModel.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import shutil\n", 11 | "from functools import partial\n", 12 | "from pathlib import Path\n", 13 | "import warnings\n", 14 | "\n", 15 | "import numpy as np\n", 16 | "import torch\n", 17 | "from box import ConfigBox\n", 18 | "from dvclive import Live\n", 19 | "from dvclive.fastai import DVCLiveCallback\n", 20 | "from fastai.data.all import Normalize, get_files\n", 21 | "from fastai.metrics import DiceMulti\n", 22 | "from 
fastai.vision.all import (Resize, SegmentationDataLoaders,\n", 23 | " imagenet_stats, models, unet_learner)\n", 24 | "from ruamel.yaml import YAML\n", 25 | "from PIL import Image\n", 26 | "\n", 27 | "os.chdir(\"..\")\n", 28 | "warnings.filterwarnings(\"ignore\")" 29 | ] 30 | }, 31 | { 32 | "attachments": {}, 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "### Load data and split it into train/test\n", 37 | "\n", 38 | "We have some [data in DVC](https://dvc.org/doc/start/data-management/data-versioning) that we can pull. \n", 39 | "\n", 40 | "This data includes:\n", 41 | "* satellite images\n", 42 | "* masks of the swimming pools in each satellite image\n", 43 | "\n", 44 | "DVC can help connect your data to your repo, but it isn't necessary to have your data in DVC to start tracking experiments with DVC and DVCLive." 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "!dvc pull" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "test_regions = [\"REGION_1-\"]\n", 63 | "\n", 64 | "img_fpaths = get_files(Path(\"data\") / \"pool_data\" / \"images\", extensions=\".jpg\")\n", 65 | "\n", 66 | "train_data_dir = Path(\"data\") / \"train_data\"\n", 67 | "train_data_dir.mkdir(exist_ok=True)\n", 68 | "test_data_dir = Path(\"data\") / \"test_data\"\n", 69 | "test_data_dir.mkdir(exist_ok=True)\n", 70 | "for img_path in img_fpaths:\n", 71 | " msk_path = Path(\"data\") / \"pool_data\" / \"masks\" / f\"{img_path.stem}.png\"\n", 72 | " if any(region in str(img_path) for region in test_regions):\n", 73 | " shutil.copy(img_path, test_data_dir)\n", 74 | " shutil.copy(msk_path, test_data_dir)\n", 75 | " else:\n", 76 | " shutil.copy(img_path, train_data_dir)\n", 77 | " shutil.copy(msk_path, train_data_dir)" 78 | ] 79 | }, 80 | { 81 | "attachments": {}, 82 | "cell_type": "markdown", 83 
| "metadata": {}, 84 | "source": [ 85 | "### Create a data loader\n", 86 | "\n", 87 | "Load and prepare the images and masks by creating a data loader." 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "def get_mask_path(x, train_data_dir):\n", 97 | " return Path(train_data_dir) / f\"{Path(x).stem}.png\"" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "bs = 8\n", 107 | "valid_pct = 0.20\n", 108 | "img_size = 256\n", 109 | "\n", 110 | "data_loader = SegmentationDataLoaders.from_label_func(\n", 111 | " path=train_data_dir,\n", 112 | " fnames=get_files(train_data_dir, extensions=\".jpg\"),\n", 113 | " label_func=partial(get_mask_path, train_data_dir=train_data_dir),\n", 114 | " codes=[\"not-pool\", \"pool\"],\n", 115 | " bs=bs,\n", 116 | " valid_pct=valid_pct,\n", 117 | " item_tfms=Resize(img_size),\n", 118 | " batch_tfms=[\n", 119 | " Normalize.from_stats(*imagenet_stats),\n", 120 | " ],\n", 121 | " )" 122 | ] 123 | }, 124 | { 125 | "attachments": {}, 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "### Review a sample batch of data\n", 130 | "\n", 131 | "Below are some examples of the images overlaid with their masks." 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "data_loader.show_batch(alpha=0.7)" 141 | ] 142 | }, 143 | { 144 | "attachments": {}, 145 | "cell_type": "markdown", 146 | "metadata": {}, 147 | "source": [ 148 | "### Train multiple models with different learning rates using `DVCLiveCallback`\n", 149 | "\n", 150 | "Set up model training, using DVCLive to capture the results of each experiment." 
151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "def dice(mask_pred, mask_true, classes=[0, 1], eps=1e-6):\n", 160 | " dice_list = []\n", 161 | " for c in classes:\n", 162 | " y_true = mask_true == c\n", 163 | " y_pred = mask_pred == c\n", 164 | " intersection = 2.0 * np.sum(y_true * y_pred)\n", 165 | " dice = intersection / (np.sum(y_true) + np.sum(y_pred) + eps)\n", 166 | " dice_list.append(dice)\n", 167 | " return np.mean(dice_list)\n", 168 | "\n", 169 | "\n", 170 | "def evaluate(learn):\n", 171 | " test_img_fpaths = sorted(get_files(Path(\"data\") / \"test_data\", extensions=\".jpg\"))\n", 172 | " test_dl = learn.dls.test_dl(test_img_fpaths)\n", 173 | " preds, _ = learn.get_preds(dl=test_dl)\n", 174 | " masks_pred = np.array(preds[:, 1, :] > 0.5, dtype=np.uint8)\n", 175 | " test_mask_fpaths = [\n", 176 | " get_mask_path(fpath, Path(\"data\") / \"test_data\") for fpath in test_img_fpaths\n", 177 | " ]\n", 178 | " masks_true = [Image.open(mask_path) for mask_path in test_mask_fpaths]\n", 179 | "\n", 180 | " dice_multi = 0.0\n", 181 | " for ii in range(len(masks_true)):\n", 182 | " mask_pred, mask_true = masks_pred[ii], masks_true[ii]\n", 183 | " mask_pred = np.array(\n", 184 | " Image.fromarray(mask_pred).resize((mask_true.shape[1], mask_true.shape[0])),\n", 185 | " dtype=int\n", 186 | " )\n", 187 | " mask_true = np.array(mask_true, dtype=int)\n", 188 | " dice_multi += dice(mask_true, mask_pred) / len(masks_true)\n", 189 | "\n", 190 | " return dice_multi" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "train_arch = 'shufflenet_v2_x2_0'\n", 200 | "\n", 201 | "for base_lr in [0.001, 0.005, 0.01]:\n", 202 | " # initialize dvclive, optionally provide output path, and show report in notebook\n", 203 | " # don't save dvc experiment until post-training metrics 
below\n", 204 | " with Live(\"results/train\", report=\"notebook\", save_dvc_exp=False) as live:\n", 205 | " # log a parameter\n", 206 | " live.log_param(\"train_arch\", train_arch)\n", 207 | " fine_tune_args = {\n", 208 | " 'epochs': 8,\n", 209 | " 'base_lr': base_lr\n", 210 | " }\n", 211 | " # log a dict of parameters\n", 212 | " live.log_params(fine_tune_args)\n", 213 | "\n", 214 | " learn = unet_learner(data_loader, \n", 215 | " arch=getattr(models, train_arch), \n", 216 | " metrics=DiceMulti)\n", 217 | " # train model and automatically capture metrics with DVCLiveCallback\n", 218 | " learn.fine_tune(\n", 219 | " **fine_tune_args,\n", 220 | " cbs=[DVCLiveCallback(live=live)])\n", 221 | "\n", 222 | " # save model artifact to dvc\n", 223 | " models_dir = Path(\"models\")\n", 224 | " models_dir.mkdir(exist_ok=True)\n", 225 | " learn.export(fname=(models_dir / \"model.pkl\").absolute())\n", 226 | " torch.save(learn.model, (models_dir / \"model.pth\").absolute())\n", 227 | " live.log_artifact(\n", 228 | " str(models_dir / \"model.pkl\"),\n", 229 | " type=\"model\",\n", 230 | " name=\"pool-segmentation\",\n", 231 | " desc=\"This is a Computer Vision (CV) model that's segmenting out swimming pools from satellite images.\",\n", 232 | " labels=[\"cv\", \"segmentation\", \"satellite-images\", \"unet\"],\n", 233 | " )\n", 234 | "\n", 235 | " # add additional post-training summary metrics.\n", 236 | " with Live(\"results/evaluate\") as live:\n", 237 | " live.summary[\"dice_multi\"] = evaluate(learn)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "# Compare experiments\n", 247 | "!dvc exp show --only-changed" 248 | ] 249 | }, 250 | { 251 | "attachments": {}, 252 | "cell_type": "markdown", 253 | "metadata": {}, 254 | "source": [ 255 | "### Review sample preditions vs ground truth\n", 256 | "\n", 257 | "Below are some example of the predicted masks." 
258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "learn.show_results(max_n=6, alpha=0.7)" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": null, 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [] 275 | } 276 | ], 277 | "metadata": { 278 | "kernelspec": { 279 | "display_name": "Python 3 (ipykernel)", 280 | "language": "python", 281 | "name": "python3" 282 | }, 283 | "language_info": { 284 | "codemirror_mode": { 285 | "name": "ipython", 286 | "version": 3 287 | }, 288 | "file_extension": ".py", 289 | "mimetype": "text/x-python", 290 | "name": "python", 291 | "nbconvert_exporter": "python", 292 | "pygments_lexer": "ipython3", 293 | "version": "3.11.6" 294 | }, 295 | "vscode": { 296 | "interpreter": { 297 | "hash": "949777d72b0d2535278d3dc13498b2535136f6dfe0678499012e853ee9abcab1" 298 | } 299 | } 300 | }, 301 | "nbformat": 4, 302 | "nbformat_minor": 4 303 | } 304 | -------------------------------------------------------------------------------- /params.yaml: -------------------------------------------------------------------------------- 1 | base: 2 | random_seed: 42 3 | 4 | data_split: 5 | test_regions: 6 | - REGION_1 7 | 8 | train: 9 | valid_pct: 0.1 10 | arch: shufflenet_v2_x2_0 11 | img_size: 256 12 | batch_size: 8 13 | fine_tune_args: 14 | epochs: 8 15 | base_lr: 0.01 16 | 17 | evaluate: 18 | n_samples_to_save: 10 19 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dvc[s3]>=3.29.0 2 | dvclive>=3.0.1 3 | fastai 4 | python-box 5 | sagemaker 6 | -------------------------------------------------------------------------------- /results/.gitignore: -------------------------------------------------------------------------------- 1 | /train 2 | /evaluate 3 | 
def model_fn(model_dir, context):
    """SageMaker hook: load the serialized PyTorch model shipped in the code dir.

    Falls back to a CPU `map_location` when no GPU is present so the same
    artifact loads on serverless (CPU-only) endpoints.
    """
    kwargs = {
        "f": os.path.join(model_dir, "code/model.pth")
    }
    if not torch.cuda.is_available():
        kwargs["map_location"] = torch.device("cpu")
    model = torch.load(**kwargs)
    return model


def input_fn(request_body, request_content_type, context):
    """SageMaker hook: decode raw request bytes into a normalized image tensor.

    NOTE(review): any truthy content type is accepted; presumably callers send
    encoded image bytes (PNG/JPEG) -- confirm the upstream contract.
    """
    if request_content_type:
        img_pil = Image.open(io.BytesIO(request_body))
        # ImageNet mean/std to match the fastai training-time normalization.
        img_transform = Compose([Resize(512), ToTensor(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
        img_tensor = img_transform(img_pil).unsqueeze_(0)
        return img_tensor
    else:
        raise ValueError(f"Unsupported request_content_type {request_content_type}")


def predict_fn(input_object, model, context):
    """SageMaker hook: run a forward pass with gradients disabled.

    BUGFIX: the input tensor is moved to the model's device; previously the
    model went to CUDA while the input stayed on CPU, crashing GPU inference.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    input_object = input_object.to(device)
    with torch.no_grad():
        result = model(input_object)
    return result


def output_fn(prediction_output, content_type):
    """SageMaker hook: binarize the 'pool' channel and serialize as .npy bytes.

    BUGFIX: the tensor is detached and moved to CPU *before* the numpy
    conversion.  The old code converted first and then called `.cpu()` on a
    numpy array -- CUDA tensors cannot be converted to numpy directly, and
    ndarrays have no `.cpu()`, so it failed either way on GPU.
    """
    if isinstance(prediction_output, torch.Tensor):
        prediction_output = prediction_output.detach().cpu()
    # Channel 1 holds the 'pool' class score; threshold at 0.5.
    output = np.array(
        prediction_output[:, 1, :] > 0.5, dtype=np.uint8
    )
    buffer = io.BytesIO()
    np.save(buffer, output)
    return buffer.getvalue()
-------------------------------------------------------------------------------- /sagemaker/deploy_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | import sys 4 | 5 | import boto3 6 | import botocore 7 | 8 | from sagemaker.deserializers import JSONDeserializer 9 | from sagemaker.pytorch import PyTorchModel 10 | from sagemaker.serverless import ServerlessInferenceConfig 11 | 12 | 13 | memory_size = { 14 | "dev": 4096 , 15 | "staging": 4096, 16 | "prod": 6144 , 17 | "default": 4096, 18 | } 19 | max_concurrency = { 20 | "dev": 5, 21 | "staging": 5, 22 | "prod": 10, 23 | "default": 5, 24 | } 25 | 26 | 27 | def deploy( 28 | name: str, 29 | stage: str, 30 | version: str, 31 | model_data: str, 32 | role: str, 33 | ): 34 | sagemaker_logger = logging.getLogger("sagemaker") 35 | sagemaker_logger.setLevel(logging.DEBUG) 36 | sagemaker_logger.addHandler(logging.StreamHandler(sys.stdout)) 37 | 38 | version_name = re.sub( 39 | r"[^a-zA-Z0-9\-]", "-", f"{name}-{version}") 40 | 41 | model = PyTorchModel( 42 | name=version_name, 43 | model_data=model_data, 44 | framework_version="1.12", 45 | py_version="py38", 46 | role=role, 47 | env={ 48 | "SAGEMAKER_MODEL_SERVER_TIMEOUT": "3600", 49 | "TS_MAX_RESPONSE_SIZE": "2000000000", 50 | "TS_MAX_REQUEST_SIZE": "2000000000", 51 | "MMS_MAX_RESPONSE_SIZE": "2000000000", 52 | "MMS_MAX_REQUEST_SIZE": "2000000000", 53 | }, 54 | ) 55 | 56 | stage_name = re.sub( 57 | r"[^a-zA-Z0-9\-]", "-", f"{name}-{stage}") 58 | try: 59 | boto3.client("sagemaker").delete_endpoint(EndpointName=stage_name) 60 | except botocore.exceptions.ClientError as e: 61 | sagemaker_logger.warn(e) 62 | try: 63 | boto3.client("sagemaker").delete_endpoint_config(EndpointConfigName=stage_name) 64 | except botocore.exceptions.ClientError as e: 65 | sagemaker_logger.warn(e) 66 | 67 | return model.deploy( 68 | initial_instance_count=1, 69 | deserializer=JSONDeserializer(), 70 | endpoint_name=stage_name, 
71 | serverless_inference_config=ServerlessInferenceConfig( 72 | memory_size_in_mb=memory_size[stage], 73 | max_concurrency=max_concurrency[stage] 74 | ) 75 | ) 76 | 77 | 78 | if __name__ == "__main__": 79 | import argparse 80 | 81 | parser = argparse.ArgumentParser(description="Deploy a model to Amazon SageMaker") 82 | 83 | parser.add_argument("--name", type=str, required=True, help="Name of the model") 84 | parser.add_argument("--stage", type=str, required=True, help="Stage of the model") 85 | parser.add_argument("--version", type=str, required=True, help="Version of the model") 86 | parser.add_argument("--model_data", type=str, required=True, help="S3 location of the model data") 87 | parser.add_argument("--role", type=str, required=True, help="ARN of the IAM role to use") 88 | 89 | args = parser.parse_args() 90 | 91 | deploy(name=args.name, stage=args.stage, version=args.version, model_data=args.model_data, role=args.role) 92 | -------------------------------------------------------------------------------- /src/data_split.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | from box import ConfigBox 6 | from fastai.vision.all import get_files 7 | from ruamel.yaml import YAML 8 | 9 | 10 | yaml = YAML(typ="safe") 11 | 12 | 13 | def data_split(): 14 | params = ConfigBox(yaml.load(open("params.yaml", encoding="utf-8"))) 15 | np.random.seed(params.base.random_seed) 16 | img_fpaths = get_files(Path("data") / "pool_data" / "images", extensions=".jpg") 17 | 18 | train_data_dir = Path("data") / "train_data" 19 | train_data_dir.mkdir(exist_ok=True) 20 | test_data_dir = Path("data") / "test_data" 21 | test_data_dir.mkdir(exist_ok=True) 22 | for img_path in img_fpaths: 23 | msk_path = Path("data") / "pool_data" / "masks" / f"{img_path.stem}.png" 24 | if any(region in str(img_path) for region in params.data_split.test_regions): 25 | shutil.copy(img_path, 
def data_split():
    """Split the raw pool dataset into data/train_data and data/test_data.

    Images (and their same-stem .png masks) whose path contains one of the
    configured ``data_split.test_regions`` strings go to the test split;
    everything else goes to the train split.  Files are copied, not moved,
    so data/pool_data stays intact.
    """
    # Close params.yaml deterministically (the old code leaked the handle).
    with open("params.yaml", encoding="utf-8") as params_file:
        params = ConfigBox(yaml.load(params_file))
    np.random.seed(params.base.random_seed)
    img_fpaths = get_files(Path("data") / "pool_data" / "images", extensions=".jpg")

    train_data_dir = Path("data") / "train_data"
    train_data_dir.mkdir(exist_ok=True)
    test_data_dir = Path("data") / "test_data"
    test_data_dir.mkdir(exist_ok=True)
    for img_path in img_fpaths:
        # Mask shares the image's stem but lives under masks/ as PNG.
        msk_path = Path("data") / "pool_data" / "masks" / f"{img_path.stem}.png"
        if any(region in str(img_path) for region in params.data_split.test_regions):
            dest_dir = test_data_dir
        else:
            dest_dir = train_data_dir
        shutil.copy(img_path, dest_dir)
        shutil.copy(msk_path, dest_dir)


if __name__ == "__main__":
    data_split()
# Default rendering: class 0 (background) black, class 1 (pool) blue.
_DEFAULT_COLOR_MAP = {0: (0, 0, 0), 1: (0, 0, 255)}


def paint_mask(mask, color_map=None):
    """Render an integer class mask as an RGB PIL image.

    Args:
        mask: 2-D integer array of class ids.
        color_map: mapping class id -> RGB tuple; defaults to black
            background / blue pool.  (Default moved out of the signature to
            avoid the mutable-default-argument pitfall.)
    """
    if color_map is None:
        color_map = _DEFAULT_COLOR_MAP
    vis = np.zeros(mask.shape + (3,))
    # Use the unpacked color directly (the old loop ignored it and
    # re-indexed the dict).
    for class_id, color in color_map.items():
        vis[mask == class_id] = color
    return Image.fromarray(vis.astype(np.uint8))


def endpoint_prediction(
    img_path: str,
    endpoint_name: str,
    output_path: str = "predictions",
):
    """Send one image to a SageMaker endpoint and save a blended mask overlay.

    The image is resized to the training resolution (``train.img_size`` from
    the DVC params), sent as PNG bytes, and the returned mask is blended (50%)
    over the original image and written to
    ``<output_path>/<endpoint_name>/<image stem>.png``.
    """
    params = dvc.api.params_show()
    img_size = params["train"]["img_size"]
    predictor = PyTorchPredictor(endpoint_name, serializer=IdentitySerializer(), deserializer=NumpyDeserializer())

    output_file = Path(output_path) / endpoint_name / Path(img_path).name
    output_file.parent.mkdir(exist_ok=True, parents=True)

    # Renamed from `io` -- the old local shadowed the stdlib module name.
    payload = BytesIO()
    Image.open(img_path).resize((img_size, img_size)).save(payload, format="PNG")
    result = predictor.predict(payload.getvalue())[0]

    img_pil = Image.open(img_path)
    overlay_img_pil = Image.blend(
        img_pil.convert("RGBA"),
        paint_mask(result).convert("RGBA").resize(img_pil.size),
        0.5
    )
    overlay_img_pil.save(str(output_file.with_suffix(".png")))


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Run inference on an image using a SageMaker endpoint')
    parser.add_argument('--img_path', type=str, help='path to the input image')
    parser.add_argument('--endpoint_name', type=str, help='name of the SageMaker endpoint to use')
    parser.add_argument('--output_path', type=str, default='predictions', help='path to save the output predictions')

    args = parser.parse_args()

    endpoint_prediction(args.img_path, args.endpoint_name, args.output_path)
def dice(mask_pred, mask_true, classes=(0, 1), eps=1e-6):
    """Mean Dice coefficient over *classes* for a pair of integer masks.

    Args:
        mask_pred: predicted class mask (array-like of ints).
        mask_true: ground-truth class mask, same shape.
        classes: class ids to average over (tuple default avoids the
            mutable-default pitfall of the old ``[0, 1]``).
        eps: stabilizer against division by zero when a class is absent
            from both masks.

    Returns:
        float in [0, 1]; 1.0 means perfect per-class overlap.
    """
    dice_list = []
    for c in classes:
        y_true = mask_true == c
        y_pred = mask_pred == c
        intersection = 2.0 * np.sum(y_true * y_pred)
        dice_list.append(intersection / (np.sum(y_true) + np.sum(y_pred) + eps))
    return np.mean(dice_list)


# Default rendering: class 0 (background) black, class 1 (pool) blue.
_DEFAULT_COLOR_MAP = {0: (0, 0, 0), 1: (0, 0, 255)}


def paint_mask(mask, color_map=None):
    """Render an integer class mask as an RGB PIL image.

    ``color_map`` maps class id -> RGB tuple; the default is moved out of the
    signature to avoid the mutable-default-argument pitfall, and the loop uses
    the unpacked color directly instead of re-indexing the dict.
    """
    if color_map is None:
        color_map = _DEFAULT_COLOR_MAP
    vis = np.zeros(mask.shape + (3,))
    for class_id, color in color_map.items():
        vis[mask == class_id] = color
    return Image.fromarray(vis.astype(np.uint8))


def stack_images(im1, im2):
    """Place *im1* and *im2* side by side on a new RGB canvas (im1 left)."""
    dst = Image.new("RGB", (im1.width + im2.width, im1.height))
    dst.paste(im1, (0, 0))
    dst.paste(im2, (im1.width, 0))
    return dst


def get_overlay_image(img_fpath, mask_true, mask_pred):
    """Build a side-by-side overlay: ground truth (left) vs prediction errors (right).

    The right panel colors the confusion of true vs predicted class:
    TN black, FN purple, FP yellow, TP blue.
    """
    img_pil = Image.open(img_fpath)
    overlay_img_true = Image.blend(
        img_pil.convert("RGBA"), paint_mask(mask_true).convert("RGBA"), 0.5
    )

    new_color_map = {
        0: (0, 0, 0),  # no color - TN
        1: (255, 0, 255),  # purple - FN
        2: (255, 255, 0),  # yellow - FP
        3: (0, 0, 255),  # blue - TP
    }
    # Encodes (true, pred) pairs into 0..3: true + 2*pred.
    combined_mask = mask_true + 2 * mask_pred

    overlay_img_pred = Image.blend(
        img_pil.convert("RGBA"),
        paint_mask(combined_mask, color_map=new_color_map).convert("RGBA"),
        0.5,
    )
    stacked_image = stack_images(overlay_img_true, overlay_img_pred)
    return stacked_image


def get_mask_path(x, train_data_dir):
    """Return the mask path for image *x*: same stem, .png, in *train_data_dir*."""
    return Path(train_data_dir) / f"{Path(x).stem}.png"


def evaluate():
    """Score the exported model on the test split and log results with DVCLive.

    Computes the mean multi-class Dice over all test images and logs up to
    ``evaluate.n_samples_to_save`` ground-truth/prediction overlay images.
    """
    # Close params.yaml deterministically (the old code leaked the handle).
    with open("params.yaml", encoding="utf-8") as params_file:
        params = ConfigBox(yaml.load(params_file))
    model_fpath = Path("models") / "model.pkl"
    # NOTE(review): cpu=False assumes a GPU is available -- confirm for CI runs.
    learn = load_learner(model_fpath, cpu=False)
    test_img_fpaths = sorted(get_files(Path("data") / "test_data", extensions=".jpg"))
    test_dl = learn.dls.test_dl(test_img_fpaths)
    preds, _ = learn.get_preds(dl=test_dl)
    # Channel 1 is the 'pool' class score; threshold at 0.5.
    masks_pred = np.array(preds[:, 1, :] > 0.5, dtype=np.uint8)
    test_mask_fpaths = [
        get_mask_path(fpath, Path("data") / "test_data") for fpath in test_img_fpaths
    ]
    masks_true = [Image.open(mask_path) for mask_path in test_mask_fpaths]
    with Live("results/evaluate") as live:
        dice_multi = 0.0
        for ii in range(len(masks_true)):
            mask_pred, mask_true = masks_pred[ii], masks_true[ii]
            # Resize the prediction back to this image's native resolution.
            mask_pred = np.array(
                Image.fromarray(mask_pred).resize((mask_true.shape[1], mask_true.shape[0])),
                dtype=int
            )
            mask_true = np.array(mask_true, dtype=int)
            # Dice is symmetric per class, so argument order does not change
            # the value; passed in signature order (pred, true) for clarity.
            dice_multi += dice(mask_pred, mask_true) / len(masks_true)

            if ii < params.evaluate.n_samples_to_save:
                stacked_image = get_overlay_image(
                    test_img_fpaths[ii], mask_true, mask_pred
                )
                stacked_image = stacked_image.resize((512, 256))
                live.log_image(f"{Path(test_img_fpaths[ii]).stem}.png", stacked_image)

        live.summary["dice_multi"] = dice_multi


if __name__ == "__main__":
    evaluate()
def get_mask_path(x, train_data_dir):
    """Return the mask path for image *x*: same stem, .png, in *train_data_dir*."""
    return Path(train_data_dir) / f"{Path(x).stem}.png"


def train():
    """Train a U-Net pool-segmentation model and log everything with DVCLive.

    Reads hyperparameters from params.yaml, builds segmentation dataloaders
    over data/train_data, fine-tunes a fastai U-Net, exports the learner
    (models/model.pkl) plus the raw torch module (models/model.pth), and
    registers model.pkl as a DVC model artifact.

    Raises:
        ValueError: if ``train.arch`` is not a model exposed by fastai's
            torchvision-backed ``models`` namespace.
    """
    # Close params.yaml deterministically (the old code leaked the handle).
    with open("params.yaml", encoding="utf-8") as params_file:
        params = ConfigBox(yaml.load(params_file))

    # Seed every RNG in play for reproducible splits and weight init.
    np.random.seed(params.base.random_seed)
    torch.manual_seed(params.base.random_seed)
    random.seed(params.base.random_seed)
    train_data_dir = Path("data") / "train_data"

    data_loader = SegmentationDataLoaders.from_label_func(
        path=train_data_dir,
        fnames=get_files(train_data_dir, extensions=".jpg"),
        label_func=partial(get_mask_path, train_data_dir=train_data_dir),
        codes=["not-pool", "pool"],
        bs=params.train.batch_size,
        valid_pct=params.train.valid_pct,
        item_tfms=Resize(params.train.img_size),
        batch_tfms=[
            Normalize.from_stats(*imagenet_stats),
        ],
    )

    # Validate the requested arch against the lowercase model constructors in
    # fastai's `models` namespace, excluding non-model helper entries.
    model_names = [
        name
        for name in dir(models)
        if not name.startswith("_")
        and name.islower()
        and name not in ("all", "tvm", "unet", "xresnet")
    ]
    if params.train.arch not in model_names:
        raise ValueError(f"Unsupported model, must be one of:\n{model_names}")

    with Live("results/train") as live:
        learn = unet_learner(
            data_loader, arch=getattr(models, params.train.arch), metrics=DiceMulti
        )

        # DVCLiveCallback streams per-epoch metrics into the live run.
        learn.fine_tune(
            **params.train.fine_tune_args,
            cbs=[DVCLiveCallback(live=live)],
        )
        models_dir = Path("models")
        models_dir.mkdir(exist_ok=True)
        learn.export(fname=(models_dir / "model.pkl").absolute())
        torch.save(learn.model, (models_dir / "model.pth").absolute())
        live.log_artifact(
            str(models_dir / "model.pkl"),
            type="model",
            name="pool-segmentation",
            desc="This is a Computer Vision (CV) model that's segmenting out swimming pools from satellite images.",
            labels=["cv", "segmentation", "satellite-images", params.train.arch],
        )


if __name__ == "__main__":
    train()