├── py-bsds500 ├── bsds │ ├── __init__.py │ ├── correspond_pixels.pyx │ ├── thin.py │ ├── bsds_dataset.py │ ├── evaluate_boundaries.py │ └── evaluate_boundaries_parallel.py ├── src │ ├── kofn.hh │ ├── match.hh │ ├── csa.cc │ ├── Exception.cc │ ├── Point.hh │ ├── Exception.hh │ ├── kofn.cc │ ├── Random.cc │ ├── Timer.cc │ ├── Timer.hh │ ├── Random.hh │ ├── String.cc │ ├── String.hh │ ├── csa_types.h │ ├── csa_defs.h │ ├── Sort.hh │ ├── Array.hh │ ├── match.cc │ └── Matrix.hh ├── README.md ├── setup.py └── evaluate_parallel.py ├── assets ├── edge.png ├── fish.jpg └── fish_gt.png ├── scripts └── docker │ ├── build.sh │ └── run.sh ├── requirements.txt ├── Dockerfile ├── example.py ├── .gitignore ├── pipeline.py ├── README.md └── automatic_mask_and_probability_generator.py /py-bsds500/bsds/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/edge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ymgw55/segment-anything-edge-detection/HEAD/assets/edge.png -------------------------------------------------------------------------------- /assets/fish.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ymgw55/segment-anything-edge-detection/HEAD/assets/fish.jpg -------------------------------------------------------------------------------- /assets/fish_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ymgw55/segment-anything-edge-detection/HEAD/assets/fish_gt.png -------------------------------------------------------------------------------- /scripts/docker/build.sh: -------------------------------------------------------------------------------- 1 | docker build --build-arg USER_UID=$(id -u) --build-arg USER_GID=$(id -g) --build-arg 
USERNAME=$(whoami) -t ${USER}/samed . -------------------------------------------------------------------------------- /py-bsds500/src/kofn.hh: -------------------------------------------------------------------------------- 1 | 2 | #ifndef __kofn_hh__ 3 | #define __kofn_hh__ 4 | 5 | void kOfN (int k, int n, int* values); 6 | 7 | #endif // __kofn_hh__ 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Cython==0.29.36 2 | fire==0.4.0 3 | isort==5.10.1 4 | jupyterlab>=3.6.8 5 | numpy==1.23.1 6 | onnx>=1.16.2 7 | onnxruntime==1.15.1 8 | opencv-contrib-python>=4.8.1.78 9 | pandas==1.4.2 10 | protobuf==4.25.8 11 | pycocotools>=2.0.6 12 | scikit-image>=0.21.0 13 | scikit-learn>=1.3.2 14 | scipy>=1.10.0 15 | -------------------------------------------------------------------------------- /py-bsds500/src/match.hh: -------------------------------------------------------------------------------- 1 | 2 | #ifndef __match_hh__ 3 | #define __match_hh__ 4 | 5 | class Matrix; 6 | 7 | // returns the cost of the assignment 8 | double matchEdgeMaps ( 9 | const Matrix& bmap1, const Matrix& bmap2, 10 | double maxDist, double outlierCost, 11 | Matrix& match1, Matrix& match2); 12 | 13 | #endif // __match_hh__ 14 | -------------------------------------------------------------------------------- /scripts/docker/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # The GPU device id is the only argument; without it, print usage and exit 4 | if [ -z "$1" ]; then 5 | echo "Usage: $0 <device_id>" 6 | exit 1 7 | fi 8 | 9 | project_name="samed" 10 | device_id=$1 11 | container_name="${project_name}_${device_id}" 12 | docker_image="${USER}/${project_name}" 13 | 14 | docker run --rm -it --name ${container_name} \ 15 | -u $(id -u):$(id -g) \ 16 | --gpus device=${device_id} \ 17 | -v "$PWD":/working \ 18 | ${docker_image} 
bash -------------------------------------------------------------------------------- /py-bsds500/README.md: -------------------------------------------------------------------------------- 1 | # Python port of BSDS 500 boundary prediction evaluation suite 2 | 3 | Uses quite a lot of code from the original BSDS evaluation suite at 4 | [berkeley.edu/Research/Projects/CS/vision/bsds/](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/) 5 | 6 | 7 | Takes the original C++ source code that provides the `matchPixels` function for Matlab 8 | and wraps it with Cython to make it available from Python. 9 | 10 | Provides a Python implementation of the morphological thinning operation. 11 | 12 | Compile the extension module with: 13 | 14 | `python setup.py build_ext --inplace` 15 | 16 | Then run: 17 | 18 | `python verify.py ` 19 | 20 | You should get output that (almost) matches the text files in the 21 | `bench/data/test_2` directory within the BSDS package. 22 | -------------------------------------------------------------------------------- /py-bsds500/src/csa.cc: -------------------------------------------------------------------------------- 1 | 2 | #include "csa.hh" 3 | 4 | char* CSA::err_messages[] = 5 | { 6 | "Can't read from the input file.", 7 | "Not a correct assignment problem line.", 8 | "Error reading a node descriptor from the input.", 9 | "Error reading an arc descriptor from the input.", 10 | "Unknown line type in the input", 11 | "Inconsistent number of arcs in the input.", 12 | "Parsing noncontiguous node ID numbers not implemented.", 13 | "Can't obtain enough memory to solve this problem.", 14 | }; 15 | 16 | char* CSA::nomem_msg = "Insufficient memory.\n"; 17 | 18 | CSA::CSA (int n, int m, const int* graph) 19 | { 20 | assert(n>0); 21 | assert(m>0); 22 | assert(graph!=NULL); 23 | assert((n%2)==0); 24 | _init(n,m); 25 | main(graph); 26 | } 27 | 28 | CSA::~CSA () 29 | { 30 | _delete(); 31 | } 32 | 33 | 
-------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime 2 | 3 | RUN apt update -y && apt install -y \ 4 | git 5 | RUN apt-get update && apt-get install -y \ 6 | build-essential \ 7 | libsndfile1 \ 8 | vim \ 9 | libgl1-mesa-dev \ 10 | libglib2.0-0 11 | 12 | # Copy files from host to the image. 13 | COPY requirements.txt /tmp/requirements.txt 14 | 15 | # Install python package, remove copied file and cache. 16 | RUN pip install --upgrade pip && \ 17 | pip install -r /tmp/requirements.txt 18 | 19 | RUN pip install git+https://github.com/facebookresearch/segment-anything.git 20 | 21 | # Language settings 22 | ENV LANG=C.UTF-8 23 | ENV LANGUAGE=en_US 24 | 25 | # Create the user. 26 | ARG USERNAME 27 | ARG USER_UID 28 | ARG USER_GID 29 | RUN groupadd --gid $USER_GID $USERNAME \ 30 | && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \ 31 | && apt-get update \ 32 | && apt-get install -y --no-install-recommends sudo \ 33 | && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \ 34 | && chmod 0440 /etc/sudoers.d/$USERNAME \ 35 | && rm -rf /var/lib/apt/lists/* \ 36 | && mkdir -p /home/$USERNAME 37 | 38 | # Directory settings for login 39 | WORKDIR /working 40 | RUN chmod 777 /working 41 | -------------------------------------------------------------------------------- /py-bsds500/src/Exception.cc: -------------------------------------------------------------------------------- 1 | 2 | // Copyright (C) 2002 David R. Martin 3 | // 4 | // This program is free software; you can redistribute it and/or 5 | // modify it under the terms of the GNU General Public License as 6 | // published by the Free Software Foundation; either version 2 of the 7 | // License, or (at your option) any later version. 
8 | // 9 | // This program is distributed in the hope that it will be useful, but 10 | // WITHOUT ANY WARRANTY; without even the implied warranty of 11 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 12 | // General Public License for more details. 13 | // 14 | // You should have received a copy of the GNU General Public License 15 | // along with this program; if not, write to the Free Software 16 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 17 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html. 18 | 19 | #include 20 | #include 21 | #include "Exception.hh" 22 | 23 | Exception::Exception (const char* msg) 24 | : _msg (strdup (msg)) 25 | { 26 | } 27 | 28 | Exception::Exception (const Exception& that) 29 | : _msg (strdup (that._msg)) 30 | { 31 | } 32 | 33 | Exception::~Exception () 34 | { 35 | free (_msg); 36 | } 37 | 38 | const char* 39 | Exception::msg () const 40 | { 41 | return _msg; 42 | } 43 | 44 | -------------------------------------------------------------------------------- /py-bsds500/src/Point.hh: -------------------------------------------------------------------------------- 1 | 2 | #ifndef __Point_hh__ 3 | #define __Point_hh__ 4 | 5 | // Simple point template classes. 6 | // Probably only make sense for intrinsic types. 
7 | 8 | // 2D Points 9 | 10 | template <class T> 11 | class Point2D 12 | { 13 | public: 14 | Point2D () { x = 0; y = 0; } 15 | Point2D (T x, T y) { this->x = x; this->y = y; } 16 | T x,y; 17 | }; 18 | 19 | template <class T> 20 | inline int operator== (const Point2D<T>& a, const Point2D<T>& b) 21 | { return (a.x == b.x) && (a.y == b.y); } 22 | 23 | template <class T> 24 | inline int operator!= (const Point2D<T>& a, const Point2D<T>& b) 25 | { return (a.x != b.x) || (a.y != b.y); } 26 | 27 | typedef Point2D<int> Pixel; 28 | 29 | // 3D Points 30 | 31 | template <class T> 32 | class Point3D 33 | { 34 | public: 35 | Point3D () { x = 0; y = 0; z = 0; } 36 | Point3D (T x, T y, T z = 0) { this->x = x; this->y = y; this->z = z; } // z defaults to 0 so two-argument calls still compile and no longer leave z uninitialised 37 | T x,y,z; 38 | }; 39 | 40 | template <class T> 41 | inline int operator== (const Point3D<T>& a, const Point3D<T>& b) 42 | { return (a.x == b.x) && (a.y == b.y) && (a.z == b.z); } 43 | 44 | template <class T> 45 | inline int operator!= (const Point3D<T>& a, const Point3D<T>& b) 46 | { return (a.x != b.x) || (a.y != b.y) || (a.z != b.z); } 47 | 48 | typedef Point3D<int> Voxel; 49 | 50 | #endif // __Point_hh__ 51 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import cv2 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from segment_anything import sam_model_registry 7 | 8 | from automatic_mask_and_probability_generator import \ 9 | SamAutomaticMaskAndProbabilityGenerator 10 | 11 | 12 | def normalize_image(image): 13 | # Normalize the image to the range [0, 1] 14 | min_val = image.min() 15 | max_val = image.max() 16 | image = (image - min_val) / (max_val - min_val) 17 | 18 | return image 19 | 20 | 21 | def main(): 22 | device = "cuda" 23 | sam = sam_model_registry["default"]( 24 | checkpoint="model/sam_vit_h_4b8939.pth") 25 | sam.to(device=device) 26 | generator = SamAutomaticMaskAndProbabilityGenerator(sam, pred_iou_thresh = 0.88, stability_score_thresh = 0.95) 27 | 28 | 
img_path = 'assets/fish.jpg' 29 | image = cv2.imread(img_path) 30 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 31 | masks = generator.generate(image) 32 | 33 | p_max = None 34 | for mask in masks: 35 | p = mask["prob"] 36 | if p_max is None: 37 | p_max = p 38 | else: 39 | p_max = np.maximum(p_max, p) 40 | 41 | edges = normalize_image(p_max) 42 | edge_detection = cv2.ximgproc.createStructuredEdgeDetection( 43 | 'model/model.yml.gz') 44 | orimap = edge_detection.computeOrientation(edges) 45 | edges = edge_detection.edgesNms(edges, orimap) 46 | 47 | # make output directory 48 | Path('output/example').mkdir(parents=True, exist_ok=True) 49 | plt.imsave('output/example/edge.png', edges, cmap='binary') 50 | 51 | 52 | if __name__ == "__main__": 53 | main() 54 | -------------------------------------------------------------------------------- /py-bsds500/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from setuptools import setup, find_packages 4 | from setuptools.extension import Extension 5 | from Cython.Build import cythonize 6 | 7 | version = '0.1.dev1' 8 | 9 | here = os.path.abspath(os.path.dirname(__file__)) 10 | try: 11 | README = open(os.path.join(here, 'README.md')).read() 12 | except IOError: 13 | README = '' 14 | 15 | install_requires = [ 16 | 'numpy', 17 | 'scikit-image' 18 | # 'Theano', # we require a development version, see requirements.txt 19 | ] 20 | 21 | extensions = [ 22 | Extension( 23 | 'bsds.correspond_pixels', 24 | [os.path.join('bsds', 'correspond_pixels.pyx')], 25 | ), 26 | ] 27 | 28 | setup( 29 | name = "py-bsds500", 30 | version=version, 31 | description="BSDS-500 access library and evaluation suite for Python", 32 | long_description="\n\n".join([README]), 33 | classifiers=[ 34 | "Development Status :: 1 - Alpha", 35 | "Intended Audience :: Developers", 36 | "Intended Audience :: Science/Research", 37 | "License :: OSI Approved :: MIT License", 38 | "Programming Language :: Python 
:: 2.7", 39 | # "Programming Language :: Python :: 3", 40 | # "Programming Language :: Python :: 3.4", 41 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 42 | ], 43 | keywords="", 44 | author="Geoffrey French", 45 | # author_email="brittix1023 at gmail dot com", 46 | url="https://github.com/Britefury/py-bsds500", 47 | license="MIT", 48 | # packages=find_packages(), 49 | include_package_data=False, 50 | zip_safe=False, 51 | install_requires=install_requires, 52 | 53 | packages = find_packages(), 54 | ext_modules = cythonize(extensions) 55 | ) 56 | -------------------------------------------------------------------------------- /py-bsds500/src/Exception.hh: -------------------------------------------------------------------------------- 1 | 2 | #ifndef __Exception_hh__ 3 | #define __Exception_hh__ 4 | 5 | // A simple exception class that contains an error message. 6 | 7 | // Copyright (C) 2002 David R. Martin 8 | // 9 | // This program is free software; you can redistribute it and/or 10 | // modify it under the terms of the GNU General Public License as 11 | // published by the Free Software Foundation; either version 2 of the 12 | // License, or (at your option) any later version. 13 | // 14 | // This program is distributed in the hope that it will be useful, but 15 | // WITHOUT ANY WARRANTY; without even the implied warranty of 16 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 17 | // General Public License for more details. 18 | // 19 | // You should have received a copy of the GNU General Public License 20 | // along with this program; if not, write to the Free Software 21 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 22 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html. 23 | 24 | #include 25 | 26 | class Exception 27 | { 28 | public: 29 | 30 | // Always construct exception with a message, so we can print 31 | // a useful error/log message. 
32 | Exception (const char* msg); 33 | 34 | // We need to implement the copy constructor so that rethrowing 35 | // works. 36 | Exception (const Exception& that); 37 | 38 | virtual ~Exception (); 39 | 40 | // Retrieve the message that this exception carries. 41 | virtual const char* msg () const; 42 | 43 | protected: 44 | 45 | char* _msg; 46 | 47 | }; 48 | 49 | // write to output stream 50 | inline std::ostream& operator<< (std::ostream& out, const Exception& e) { 51 | out << e.msg(); 52 | return out; 53 | } 54 | 55 | #endif // __Exception_hh__ 56 | -------------------------------------------------------------------------------- /py-bsds500/src/kofn.cc: -------------------------------------------------------------------------------- 1 | 2 | #include "Random.hh" 3 | #include "kofn.hh" 4 | 5 | // O(n) implementation. 6 | static void 7 | _kOfN_largeK (int k, int n, int* values) 8 | { 9 | assert (k > 0); 10 | assert (k <= n); 11 | int j = 0; 12 | for (int i = 0; i < n; i++) { 13 | double prob = (double) (k - j) / (n - i); 14 | assert (prob <= 1); 15 | double x = Random::rand.fp (); 16 | if (x < prob) { 17 | values[j++] = i; 18 | } 19 | } 20 | assert (j == k); 21 | } 22 | 23 | // O(k*lg(k)) implementation; constant factor is about 2x the constant 24 | // factor for the O(n) implementation. 
25 | static void 26 | _kOfN_smallK (int k, int n, int* values) 27 | { 28 | assert (k > 0); 29 | assert (k <= n); 30 | if (k == 1) { 31 | values[0] = Random::rand.i32 (0, n - 1); 32 | return; 33 | } 34 | int leftN = n / 2; 35 | int rightN = n - leftN; 36 | int leftK = 0; 37 | int rightK = 0; 38 | for (int i = 0; i < k; i++) { 39 | int x = Random::rand.i32 (0, n - i - 1); 40 | if (x < leftN - leftK) { 41 | leftK++; 42 | } else { 43 | rightK++; 44 | } 45 | } 46 | if (leftK > 0) { _kOfN_smallK (leftK, leftN, values); } 47 | if (rightK > 0) { _kOfN_smallK (rightK, rightN, values + leftK); } 48 | for (int i = leftK; i < k; i++) { 49 | values[i] += leftN; 50 | } 51 | } 52 | 53 | // Return k randomly selected integers from the interval [0,n), in 54 | // increasing sorted order. 55 | void 56 | kOfN (int k, int n, int* values) 57 | { 58 | assert (k >= 0); 59 | assert (n >= 0); 60 | if (k == 0) { return; } 61 | static double log2 = log (2); 62 | double klogk = k * log (k) / log2; 63 | if (klogk < n / 2) { 64 | _kOfN_smallK (k, n, values); 65 | } else { 66 | _kOfN_largeK (k, n, values); 67 | } 68 | } 69 | 70 | -------------------------------------------------------------------------------- /py-bsds500/src/Random.cc: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "Random.hh" 8 | #include "String.hh" 9 | #include "Exception.hh" 10 | 11 | // Copyright (C) 2002 David R. Martin 12 | // 13 | // This program is free software; you can redistribute it and/or 14 | // modify it under the terms of the GNU General Public License as 15 | // published by the Free Software Foundation; either version 2 of the 16 | // License, or (at your option) any later version. 17 | // 18 | // This program is distributed in the hope that it will be useful, but 19 | // WITHOUT ANY WARRANTY; without even the implied warranty of 20 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the GNU 21 | // General Public License for more details. 22 | // 23 | // You should have received a copy of the GNU General Public License 24 | // along with this program; if not, write to the Free Software 25 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 26 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html. 27 | 28 | Random Random::rand; 29 | 30 | Random::Random () 31 | { 32 | reseed (0); 33 | } 34 | 35 | Random::Random (u_int64_t seed) 36 | { 37 | reseed (seed); 38 | } 39 | 40 | Random::Random (Random& that) 41 | { 42 | u_int64_t a = that.ui32 (); 43 | u_int64_t b = that.ui32 (); 44 | u_int64_t seed = (a << 32) | b; 45 | _init (seed); 46 | } 47 | 48 | void 49 | Random::reset () 50 | { 51 | _init (_seed); 52 | } 53 | 54 | void 55 | Random::reseed (u_int64_t seed) 56 | { 57 | if (seed == 0) { 58 | struct timeval t; 59 | gettimeofday (&t, NULL); 60 | u_int64_t a = (t.tv_usec >> 3) & 0xffff; 61 | u_int64_t b = t.tv_sec & 0xffff; 62 | u_int64_t c = (t.tv_sec >> 16) & 0xffff; 63 | seed = a | (b << 16) | (c << 32); 64 | } 65 | _init (seed); 66 | } 67 | 68 | void 69 | Random::_init (u_int64_t seed) 70 | { 71 | _seed = seed & 0xffffffffffffull; 72 | _xsubi[0] = (seed >> 0) & 0xffff; 73 | _xsubi[1] = (seed >> 16) & 0xffff; 74 | _xsubi[2] = (seed >> 32) & 0xffff; 75 | } 76 | 77 | -------------------------------------------------------------------------------- /py-bsds500/src/Timer.cc: -------------------------------------------------------------------------------- 1 | 2 | // Copyright (C) 2002 David R. Martin 3 | // 4 | // This program is free software; you can redistribute it and/or 5 | // modify it under the terms of the GNU General Public License as 6 | // published by the Free Software Foundation; either version 2 of the 7 | // License, or (at your option) any later version. 
8 | // 9 | // This program is distributed in the hope that it will be useful, but 10 | // WITHOUT ANY WARRANTY; without even the implied warranty of 11 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 12 | // General Public License for more details. 13 | // 14 | // You should have received a copy of the GNU General Public License 15 | // along with this program; if not, write to the Free Software 16 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 17 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html. 18 | 19 | #include 20 | #include 21 | #include 22 | #include "Timer.hh" 23 | 24 | typedef unsigned long long uint64; 25 | 26 | void 27 | Timer::_compute () 28 | { 29 | // Compute elapsed time. 30 | long sec = _elapsed_stop.tv_sec - _elapsed_start.tv_sec; 31 | long usec = _elapsed_stop.tv_usec - _elapsed_start.tv_usec; 32 | if (usec < 0) { 33 | sec -= 1; 34 | usec += 1000000; 35 | } 36 | _elapsed += (double) sec + usec / 1e6; 37 | 38 | // Compute CPU user and system times. 39 | _user += (double) (_cpu_stop.tms_utime - _cpu_start.tms_utime) 40 | / sysconf(_SC_CLK_TCK); 41 | _system += (double) (_cpu_stop.tms_stime - _cpu_start.tms_stime) 42 | / sysconf(_SC_CLK_TCK); 43 | } 44 | 45 | // Convert time in seconds into a nice human-friendly format: h:mm:ss.ss 46 | // Return a pointer to a static buffer. 47 | const char* 48 | Timer::formatTime (double sec, int precision) 49 | { 50 | static char buf[128]; 51 | 52 | // Limit range of precision for safety and sanity. 53 | precision = (precision < 0) ? 0 : precision; 54 | precision = (precision > 9) ? 9 : precision; 55 | uint64 base = 1; 56 | for (int digit = 0; digit < precision; digit++) { base *= 10;} 57 | 58 | bool neg = (sec < 0); 59 | uint64 ticks = (uint64) rint (fabs (sec) * base); 60 | uint64 rsec = ticks / base; // Rounded seconds. 
61 | uint64 frac = ticks % base; 62 | 63 | uint64 h = rsec / 3600; 64 | uint64 m = (rsec / 60) % 60; 65 | uint64 s = rsec % 60; 66 | 67 | sprintf (buf, "%s%llu:%02llu:%02llu", 68 | neg ? "-" : "", h, m, s); 69 | 70 | if (precision > 0) { 71 | static char fmt[10]; 72 | sprintf (fmt, ".%%0%dlld", precision); 73 | sprintf (buf + strlen (buf), fmt, frac); 74 | } 75 | 76 | return buf; 77 | } 78 | 79 | -------------------------------------------------------------------------------- /py-bsds500/bsds/correspond_pixels.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c++ 2 | # distutils: sources = src/csa.cc src/Exception.cc src/kofn.cc src/match.cc src/Matrix.cc src/Random.cc src/String.cc src/Timer.cc 3 | # distutils: extra_compile_args = -DNOBLAS 4 | # 5 | import math 6 | import numpy as np 7 | 8 | 9 | cdef extern from "../src/Matrix.hh": 10 | cppclass Matrix: 11 | Matrix () except + 12 | Matrix (int rows, int cols) except + 13 | Matrix (int rows, int cols, double* data) except + 14 | double* data () 15 | 16 | cdef extern from "../src/match.hh": 17 | double matchEdgeMaps(const Matrix& bmap1, const Matrix& bmap2, 18 | double maxDist, double outlierCost, 19 | Matrix& match1, Matrix& match2) 20 | 21 | 22 | 23 | cdef _correspond_pixels(double[::1,:] img0, double[::1,:] img1, double max_dist, double outlier_cost, 24 | double[::1,:] out0, double[::1,:] out1): 25 | cdef int rows = img0.shape[0] 26 | cdef int cols = img0.shape[1] 27 | cdef double idiag = math.sqrt(rows * rows + cols * cols) 28 | cdef double oc = outlier_cost * max_dist * idiag 29 | 30 | # Copy data to Matrix types; construct matrices, get views of their contents and copy 31 | # over 32 | # Constructing a Matrix from a double* acquired from views of im0 and img1 don't 33 | # work well at all... 
34 | cdef Matrix i0 = Matrix(rows, cols) 35 | cdef Matrix i1 = Matrix(rows, cols) 36 | cdef double[::1,:] i0_view = i0.data() 37 | cdef double[::1,:] i1_view = i1.data() 38 | i0_view[:,:] = img0[:,:] 39 | i1_view[:,:] = img1[:,:] 40 | 41 | # Output matrices 42 | cdef Matrix m0, m1 43 | 44 | # Perform the match 45 | cdef double cost = matchEdgeMaps(i0, i1, max_dist * idiag, oc, m0, m1) 46 | 47 | # Get views of the output matrices and copy to our output arrays 48 | cdef double[::1,:] o0_view = m0.data() 49 | cdef double[::1,:] o1_view = m1.data() 50 | out0[:,:] = o0_view[:,:] 51 | out1[:,:] = o1_view[:,:] 52 | 53 | return cost, oc 54 | 55 | 56 | def correspond_pixels(img0, img1, max_dist=0.0075, outlier_cost=100.0): 57 | """Match the edge pixels of two equally shaped 2-D maps with matchEdgeMaps. 58 | 59 | max_dist is multiplied by the image diagonal before matching, and the 60 | outlier cost passed on is outlier_cost * max_dist * diagonal. Returns 61 | (match0, match1, cost, scaled_outlier_cost); raises ValueError when the 62 | shapes differ, max_dist <= 0 or outlier_cost <= 1. 63 | """ 64 | if img0.shape != img1.shape: 65 | raise ValueError('img0.shape ({}) and img1.shape({}) do not match'.format(img0.shape, img1.shape)) 66 | if max_dist <= 0.0: 67 | raise ValueError('max_dist must be > 0 (it is {})'.format(max_dist)) 68 | if outlier_cost <= 1: 69 | raise ValueError('outlier_cost must be > 1 (it is {})'.format(outlier_cost)) 70 | 71 | 72 | i0 = img0.astype('float64').copy(order='F') 73 | i1 = img1.astype('float64').copy(order='F') 74 | o0 = np.zeros_like(i0, order='F') 75 | o1 = np.zeros_like(i1, order='F') 76 | max_dist = float(max_dist) 77 | outlier_cost = float(outlier_cost) 78 | cost, oc = _correspond_pixels(i0, i1, max_dist, outlier_cost, o0, o1) 79 | return o0, o1, cost, oc -------------------------------------------------------------------------------- /py-bsds500/src/Timer.hh: -------------------------------------------------------------------------------- 1 | 2 | #ifndef __Timer_hh__ 3 | #define __Timer_hh__ 4 | 5 | // Copyright (C) 2002 David R. Martin 6 | // 7 | // This program is free software; you can redistribute it and/or 8 | // modify it under the terms of the GNU General Public License as 9 | // published by the Free Software Foundation; either version 2 of the 10 | // License, or (at your option) any later version. 
11 | // 12 | // This program is distributed in the hope that it will be useful, but 13 | // WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 15 | // General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program; if not, write to the Free Software 19 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 20 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html. 21 | 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | 28 | class Timer 29 | { 30 | public: 31 | 32 | inline Timer (); 33 | inline ~Timer (); 34 | 35 | inline void start (); 36 | inline void stop (); 37 | inline void reset (); 38 | 39 | // All times are in seconds. 40 | inline double cpu (); 41 | inline double user (); 42 | inline double system (); 43 | inline double elapsed (); 44 | 45 | // Convert time in seconds into a nice human-friendly format: h:mm:ss.ss 46 | // Precision is the number of digits after the decimal. 47 | // Return a pointer to a static buffer. 
48 | static const char* formatTime (double sec, int precision = 2); 49 | 50 | private: 51 | 52 | void _compute (); 53 | 54 | enum State { stopped, running }; 55 | 56 | State _state; 57 | 58 | struct timeval _elapsed_start; 59 | struct timeval _elapsed_stop; 60 | double _elapsed; 61 | 62 | struct tms _cpu_start; 63 | struct tms _cpu_stop; 64 | double _user; 65 | double _system; 66 | }; 67 | 68 | Timer::Timer () 69 | { 70 | reset (); 71 | } 72 | 73 | Timer::~Timer () 74 | { 75 | } 76 | 77 | void 78 | Timer::reset () 79 | { 80 | _state = stopped; 81 | _elapsed = _user = _system = 0; 82 | } 83 | 84 | void 85 | Timer::start () 86 | { 87 | assert (_state == stopped); 88 | _state = running; 89 | gettimeofday (&_elapsed_start, NULL); 90 | times (&_cpu_start); 91 | } 92 | 93 | void 94 | Timer::stop () 95 | { 96 | assert (_state == running); 97 | gettimeofday (&_elapsed_stop, NULL); 98 | times (&_cpu_stop); 99 | _compute (); 100 | _state = stopped; 101 | } 102 | 103 | double 104 | Timer::cpu () 105 | { 106 | assert (_state == stopped); 107 | return _user + _system; 108 | } 109 | 110 | double 111 | Timer::user () 112 | { 113 | assert (_state == stopped); 114 | return _user; 115 | } 116 | 117 | double 118 | Timer::system () 119 | { 120 | assert (_state == stopped); 121 | return _system; 122 | } 123 | 124 | double 125 | Timer::elapsed () 126 | { 127 | assert (_state == stopped); 128 | return _elapsed; 129 | } 130 | 131 | #endif // __Timer_hh__ 132 | 133 | 134 | 135 | -------------------------------------------------------------------------------- /py-bsds500/src/Random.hh: -------------------------------------------------------------------------------- 1 | 2 | #ifndef __Random_hh__ 3 | #define __Random_hh__ 4 | 5 | // Copyright (C) 2002 David R. 
Martin 6 | // 7 | // This program is free software; you can redistribute it and/or 8 | // modify it under the terms of the GNU General Public License as 9 | // published by the Free Software Foundation; either version 2 of the 10 | // License, or (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, but 13 | // WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 15 | // General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program; if not, write to the Free Software 19 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 20 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html. 21 | 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | 28 | // All random numbers are generated from a single seed. This is true 29 | // even when private random streams (seperate from the global 30 | // Random::rand stream) are spawned from existing streams, since the new 31 | // streams are seeded automatically from the parent's random stream. 32 | // Any random stream can be reset so that a sequence of random values 33 | // can be replayed. 34 | 35 | // If seed==0, then the seed is generated from the system clock. 36 | 37 | typedef uint16_t u_int16_t; 38 | typedef uint32_t u_int32_t; 39 | typedef uint64_t u_int64_t; 40 | 41 | class Random 42 | { 43 | public: 44 | 45 | static Random rand; 46 | 47 | // These are defined in as the limits of int, but 48 | // here we need the limits of int32_t. 49 | static const int32_t int32_max = 2147483647; 50 | static const int32_t int32_min = -int32_max-1; 51 | static const u_int32_t u_int32_max = 4294967295u; 52 | 53 | // Seed from the system clock. 54 | Random (); 55 | 56 | // Specify seed. 57 | // If zero, seed from the system clock. 
58 | Random (u_int64_t seed); 59 | 60 | // Spawn off a new random stream seeded from the parent's stream. 61 | Random (Random& that); 62 | 63 | // Restore initial seed so we can replay a random sequence. 64 | void reset (); 65 | 66 | // Set the seed. 67 | // If zero, seed from the system clock. 68 | void reseed (u_int64_t seed); 69 | 70 | // double in [0..1) or [a..b) 71 | inline double fp (); 72 | inline double fp (double a, double b); 73 | 74 | // 32-bit signed integer in [-2^31,2^31) or [a..b] 75 | inline int32_t i32 (); 76 | inline int32_t i32 (int32_t a, int32_t b); 77 | 78 | // 32-bit unsigned integer in [0,2^32) or [a..b] 79 | inline u_int32_t ui32 (); 80 | inline u_int32_t ui32 (u_int32_t a, u_int32_t b); 81 | 82 | protected: 83 | 84 | void _init (u_int64_t seed); 85 | 86 | // The original seed for this random stream. 87 | u_int64_t _seed; 88 | 89 | // The current state for this random stream. 90 | u_int16_t _xsubi[3]; 91 | 92 | }; 93 | 94 | inline u_int32_t 95 | Random::ui32 () 96 | { 97 | return ui32(0,u_int32_max); 98 | } 99 | 100 | inline u_int32_t 101 | Random::ui32 (u_int32_t a, u_int32_t b) 102 | { 103 | assert (a <= b); 104 | double x = fp (); 105 | return (u_int32_t) floor (x * ((double)b - (double)a + 1) + a); 106 | } 107 | 108 | inline int32_t 109 | Random::i32 () 110 | { 111 | return i32(int32_min,int32_max); 112 | } 113 | 114 | inline int32_t 115 | Random::i32 (int32_t a, int32_t b) 116 | { 117 | assert (a <= b); 118 | double x = fp (); 119 | return (int32_t) floor (x * ((double)b - (double)a + 1) + a); 120 | } 121 | 122 | inline double 123 | Random::fp () 124 | { 125 | return erand48 (_xsubi); 126 | } 127 | 128 | inline double 129 | Random::fp (double a, double b) 130 | { 131 | assert (a < b); 132 | return erand48 (_xsubi) * (b - a) + a; 133 | } 134 | 135 | #endif // __Random_hh__ 136 | 137 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | #Do NOT edit this file manually. Use ./.gitignore_gen.sh > .gitignore 2 | ### Generated by gibo (https://github.com/simonwhitaker/gibo) 3 | ### https://raw.github.com/github/gitignore/4488915eec0b3a45b5c63ead28f286819c0917de/Global/Vim.gitignore 4 | 5 | # Swap 6 | [._]*.s[a-v][a-z] 7 | !*.svg # comment out if you don't need vector files 8 | [._]*.sw[a-p] 9 | [._]s[a-rt-v][a-z] 10 | [._]ss[a-gi-z] 11 | [._]sw[a-p] 12 | 13 | # Session 14 | Session.vim 15 | Sessionx.vim 16 | 17 | # Temporary 18 | .netrwhist 19 | *~ 20 | # Auto-generated tag files 21 | tags 22 | # Persistent undo 23 | [._]*.un~ 24 | 25 | 26 | ### https://raw.github.com/github/gitignore/4488915eec0b3a45b5c63ead28f286819c0917de/Python.gitignore 27 | 28 | # Byte-compiled / optimized / DLL files 29 | __pycache__/ 30 | *.py[cod] 31 | *$py.class 32 | 33 | # C extensions 34 | *.so 35 | 36 | # Distribution / packaging 37 | .Python 38 | build/ 39 | develop-eggs/ 40 | dist/ 41 | downloads/ 42 | eggs/ 43 | .eggs/ 44 | lib/ 45 | lib64/ 46 | parts/ 47 | sdist/ 48 | var/ 49 | wheels/ 50 | share/python-wheels/ 51 | *.egg-info/ 52 | .installed.cfg 53 | *.egg 54 | MANIFEST 55 | 56 | # PyInstaller 57 | # Usually these files are written by a python script from a template 58 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
59 | *.manifest 60 | *.spec 61 | 62 | # Installer logs 63 | pip-log.txt 64 | pip-delete-this-directory.txt 65 | 66 | # Unit test / coverage reports 67 | htmlcov/ 68 | .tox/ 69 | .nox/ 70 | .coverage 71 | .coverage.* 72 | .cache 73 | nosetests.xml 74 | coverage.xml 75 | *.cover 76 | *.py,cover 77 | .hypothesis/ 78 | .pytest_cache/ 79 | cover/ 80 | 81 | # Translations 82 | *.mo 83 | *.pot 84 | 85 | # Django stuff: 86 | *.log 87 | local_settings.py 88 | db.sqlite3 89 | db.sqlite3-journal 90 | 91 | # Flask stuff: 92 | instance/ 93 | .webassets-cache 94 | 95 | # Scrapy stuff: 96 | .scrapy 97 | 98 | # Sphinx documentation 99 | docs/_build/ 100 | 101 | # PyBuilder 102 | .pybuilder/ 103 | target/ 104 | 105 | # Jupyter Notebook 106 | .ipynb_checkpoints 107 | 108 | # IPython 109 | profile_default/ 110 | ipython_config.py 111 | 112 | # pyenv 113 | # For a library or package, you might want to ignore these files since the code is 114 | # intended to run in multiple environments; otherwise, check them in: 115 | # .python-version 116 | 117 | # pipenv 118 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 119 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 120 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 121 | # install all needed dependencies. 122 | #Pipfile.lock 123 | 124 | # poetry 125 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 126 | # This is especially recommended for binary packages to ensure reproducibility, and is more 127 | # commonly ignored for libraries. 128 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 129 | #poetry.lock 130 | 131 | # pdm 132 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
133 | #pdm.lock 134 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 135 | # in version control. 136 | # https://pdm.fming.dev/#use-with-ide 137 | .pdm.toml 138 | 139 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 140 | __pypackages__/ 141 | 142 | # Celery stuff 143 | celerybeat-schedule 144 | celerybeat.pid 145 | 146 | # SageMath parsed files 147 | *.sage.py 148 | 149 | # Environments 150 | .env 151 | .venv 152 | env/ 153 | venv/ 154 | ENV/ 155 | env.bak/ 156 | venv.bak/ 157 | 158 | # Spyder project settings 159 | .spyderproject 160 | .spyproject 161 | 162 | # Rope project settings 163 | .ropeproject 164 | 165 | # mkdocs documentation 166 | /site 167 | 168 | # mypy 169 | .mypy_cache/ 170 | .dmypy.json 171 | dmypy.json 172 | 173 | # Pyre type checker 174 | .pyre/ 175 | 176 | # pytype static type analyzer 177 | .pytype/ 178 | 179 | # Cython debug symbols 180 | cython_debug/ 181 | 182 | # PyCharm 183 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 184 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 185 | # and can be added to the global gitignore or merged into this file. For a more nuclear 186 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 187 | #.idea/ 188 | 189 | 190 | ############ Additional files ############ 191 | .vscode 192 | .DS_Store 193 | output/ 194 | model/ 195 | data/ -------------------------------------------------------------------------------- /py-bsds500/src/String.cc: -------------------------------------------------------------------------------- 1 | 2 | // Copyright (C) 2002 David R. 
Martin 3 | // 4 | // This program is free software; you can redistribute it and/or 5 | // modify it under the terms of the GNU General Public License as 6 | // published by the Free Software Foundation; either version 2 of the 7 | // License, or (at your option) any later version. 8 | // 9 | // This program is distributed in the hope that it will be useful, but 10 | // WITHOUT ANY WARRANTY; without even the implied warranty of 11 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 12 | // General Public License for more details. 13 | // 14 | // You should have received a copy of the GNU General Public License 15 | // along with this program; if not, write to the Free Software 16 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 17 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html. 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include "String.hh" 24 | 25 | String::String () 26 | { 27 | _length = 0; 28 | _size = defaultMinSize + 1; 29 | _text = new char [_size]; 30 | _text[_length] = '\0'; 31 | } 32 | 33 | String::String (const String& that) 34 | { 35 | _length = that._length; 36 | _size = that._size; 37 | _text = new char [_size]; 38 | memcpy (_text, that._text, _length + 1); 39 | } 40 | 41 | String::String (const char* fmt, ...) 
42 | { 43 | assert (fmt != NULL); 44 | 45 | _length = 0; 46 | _size = strlen (fmt) + 1; 47 | _text = new char [_size]; 48 | _text[_length] = '\0'; 49 | 50 | va_list ap; 51 | va_start (ap, fmt); 52 | _append (fmt, ap); 53 | va_end (ap); 54 | } 55 | 56 | String::~String () 57 | { 58 | assert (_text != NULL); 59 | delete [] _text; 60 | } 61 | 62 | String& 63 | String::operator= (const String& that) 64 | { 65 | if (&that == this) { return *this; } 66 | clear(); 67 | append ("%s", that.text()); 68 | return *this; 69 | } 70 | 71 | String& 72 | String::operator= (const char* s) 73 | { 74 | clear(); 75 | if (s != NULL) { 76 | append ("%s", s); 77 | } 78 | return *this; 79 | } 80 | 81 | void 82 | String::clear () 83 | { 84 | _length = 0; 85 | _text[0] = '\0'; 86 | } 87 | 88 | void 89 | String::append (char c) 90 | { 91 | _append (1, (const char*)&c); 92 | } 93 | 94 | void 95 | String::append (unsigned length, const char* s) 96 | { 97 | _append (length, s); 98 | } 99 | 100 | void 101 | String::append (const char* fmt, ...) 
102 | { 103 | assert (fmt != NULL); 104 | va_list ap; 105 | va_start (ap, fmt); 106 | _append (fmt, ap); 107 | va_end (ap); 108 | } 109 | 110 | const char& 111 | String::operator[] (unsigned i) const 112 | { 113 | assert (i < _length); 114 | return _text[i]; 115 | } 116 | 117 | bool 118 | String::nextLine (FILE* fp) 119 | { 120 | assert (fp != NULL); 121 | 122 | const int bufLen = 128; 123 | char buf[bufLen]; 124 | 125 | clear (); 126 | 127 | while (fgets (buf, bufLen, fp) != NULL) { 128 | _append (strlen (buf), buf); 129 | if (_text[_length - 1] == '\n') { 130 | _length--; 131 | _text[_length] = '\0'; 132 | return true; 133 | } 134 | } 135 | 136 | if (_length > 0) { 137 | assert (_text[_length - 1] != '\n'); 138 | return true; 139 | } else { 140 | return false; 141 | } 142 | } 143 | 144 | void 145 | String::_append (unsigned length, const char* s) 146 | { 147 | _grow (length + _length + 1); 148 | if (length > 0) { 149 | memcpy (_text + _length, s, length); 150 | _length += length; 151 | _text[_length] = '\0'; 152 | } 153 | } 154 | 155 | // On solaris and linux, vsnprintf returns the number of characters needed 156 | // to format the entire string. 157 | // On irix, vsnprintf returns the number of characters written. This is 158 | // at most length(buf)-1. 159 | // On some sytems, vsnprintf returns -1 if there wasn't enough space. 
160 | void 161 | String::_append (const char* fmt, va_list ap) 162 | { 163 | int bufLen = 128; 164 | char* buf; 165 | 166 | while (1) { 167 | buf = new char [bufLen]; 168 | int cnt = vsnprintf (buf, bufLen, fmt, ap); 169 | if (cnt < 0 || cnt >= bufLen - 1) { 170 | delete [] buf; 171 | bufLen *= 2; 172 | continue; 173 | } else { 174 | break; 175 | } 176 | } 177 | 178 | _append (strlen (buf), buf); 179 | delete [] buf; 180 | } 181 | 182 | void 183 | String::_grow (unsigned minSize) 184 | { 185 | if (minSize > _size) { 186 | char* old = _text; 187 | _size += minSize; 188 | _text = new char [_size]; 189 | memcpy (_text, old, _length + 1); 190 | delete [] old; 191 | } 192 | } 193 | 194 | -------------------------------------------------------------------------------- /py-bsds500/evaluate_parallel.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | import sys 5 | 6 | import numpy as np 7 | import tqdm 8 | from skimage.io import imread 9 | from skimage.util import img_as_float 10 | 11 | from bsds import evaluate_boundaries_parallel as evaluate_boundaries 12 | from bsds.bsds_dataset import Dataset 13 | from pathlib import Path 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser(description='Test output') 18 | parser.add_argument('data_path', type=str, 19 | help='the root path of the dataset') 20 | parser.add_argument('pred_path', type=str, 21 | help='the root path of the predictions') 22 | parser.add_argument('val_test', type=str, 23 | help='val or test') 24 | parser.add_argument('--thresholds', type=str, default='99', 25 | help='the number of thresholds') 26 | parser.add_argument('--suffix_ext', type=str, default='.png', 27 | help='suffix and extension') 28 | parser.add_argument('--num_workers', type=int, default=4, 29 | help='number of workers') 30 | parser.add_argument('--max_dist', type=float, default=0.0075) 31 | args = parser.parse_args() 32 | return args 33 | 
def _dataset_image_ext(data_path):
    """Return the image extension for the dataset named in *data_path*.

    Exits with an error when the path names neither supported dataset,
    instead of leaving the extension undefined.
    """
    if 'BSDS500' in data_path:
        return '.jpg'
    if 'NYUDv2' in data_path:
        return '.png'
    sys.exit(f'cannot infer the dataset from data_path {data_path!r}; '
             'the path must contain BSDS500 or NYUDv2')


def _parse_thresholds(spec):
    """Parse the ``--thresholds`` value.

    Returns an ``int`` (number of thresholds) when *spec* is an integer,
    otherwise a float array parsed from a ``[a, b, c]`` list.  Exits with
    an error for any other format.
    """
    spec = spec.strip()
    try:
        return int(spec)
    except ValueError:
        pass

    if not (spec.startswith('[') and spec.endswith(']')):
        sys.exit('Bad threshold format; '
                 'should be a python list of floats (`[a, b, c]`)')
    try:
        return np.array([float(t.strip()) for t in spec[1:-1].split(',')])
    except ValueError:
        sys.exit('Bad threshold format; '
                 'should be a python list of floats (`[a, b, c]`)')


def main():
    """Evaluate boundary predictions and report ODS / OIS / AP.

    Results are cached as a pickle under ``<pred_path>/results``; when the
    cache file for the requested thresholds exists it is loaded instead of
    re-running the evaluation.  The summary line is printed and also
    written to a ``.txt`` file next to the cache.

    Invalid command-line input terminates the process with a non-zero
    exit status and a message on stderr.
    """
    args = parse_args()
    pred_path = args.pred_path
    suffix_ext = args.suffix_ext
    val_test = args.val_test

    ext = _dataset_image_ext(args.data_path)
    thresholds = _parse_thresholds(args.thresholds)

    ds = Dataset(args.data_path, ext)

    if val_test == 'val':
        SAMPLE_NAMES = ds.val_sample_names
    elif val_test == 'test':
        SAMPLE_NAMES = ds.test_sample_names
    else:
        sys.exit(
            'need to specify either val or test, not {}'.format(val_test))

    def load_gt_boundaries(sample_name):
        return ds.boundaries(sample_name)

    def load_pred(sample_name):
        # Crop, then zero-pad, the prediction so that it has exactly the
        # shape of the first ground-truth boundary map.
        # NOTE(review): the two-entry pad spec assumes a 2-D prediction.
        sample_path = os.path.join(pred_path, f'{sample_name}{suffix_ext}')
        pred = img_as_float(imread(sample_path))
        bnds = ds.boundaries(sample_name)
        tgt_shape = bnds[0].shape
        pred = pred[:tgt_shape[0], :tgt_shape[1]]
        pred = np.pad(pred, [(0, tgt_shape[0]-pred.shape[0]),
                             (0, tgt_shape[1]-pred.shape[1])], mode='constant')
        return pred

    output_dir = Path(pred_path) / 'results'
    print(f'output_dir: {output_dir}')

    os.makedirs(output_dir, exist_ok=True)
    results_path = os.path.join(f'{output_dir}',
                                f'results_thr{thresholds}.pkl')
    if os.path.exists(results_path):
        # NOTE: pickle.load executes arbitrary code from the file; only
        # point pred_path at result directories you trust.
        with open(results_path, 'rb') as f:
            results = pickle.load(f)
        SAMPLE_NAMES, sample_results, threshold_results, overall_result = \
            results
    else:
        sample_results, threshold_results, overall_result = \
            evaluate_boundaries.pr_evaluation(
                thresholds, SAMPLE_NAMES, load_gt_boundaries,
                load_pred, progress=tqdm.tqdm, num_workers=num_workers_of(args),
                max_dist=args.max_dist) if False else \
            evaluate_boundaries.pr_evaluation(
                thresholds, SAMPLE_NAMES, load_gt_boundaries,
                load_pred, progress=tqdm.tqdm, num_workers=args.num_workers,
                max_dist=args.max_dist)
        results = (SAMPLE_NAMES, sample_results,
                   threshold_results, overall_result)
        with open(results_path, 'wb') as f:
            pickle.dump(results, f)

    ods = overall_result.f1
    ois = overall_result.best_f1

    # Average precision: area under the precision/recall curve, integrated
    # over the per-threshold results in reverse order.
    rs = [res.recall for res in threshold_results]
    ps = [res.precision for res in threshold_results]
    ap = np.trapz(ps[::-1], rs[::-1])

    with open(os.path.join(f'{output_dir}', f'results_thr{thresholds}.txt'),
              'w') as f:
        text = f'ODS: {ods:.3f}, OIS: {ois:.3f} AP: {ap:.3f}'
        print(text, file=f)
        print(text)


if __name__ == '__main__':
    main()
// Growable, NUL-terminated character string with value semantics and a
// printf-style interface for construction and appending.
class String
{
public:

    // Constructors.
    // The default constructor creates an empty string; the fmt form
    // formats its arguments printf-style (fmt must not be NULL).
    String ();
    String (const String& that);
    String (const char* fmt, ...);

    // Destructor.
    ~String ();

    // Assignment operators.
    // Assigning a NULL char pointer leaves the string empty.
    String& operator= (const String& that);
    String& operator= (const char* s);

    // Accessors.
    // length() does not count the terminating NUL; text() is the
    // NUL-terminated buffer; operator[] asserts i < length().
    unsigned length () const { return _length; }
    const char* text () const { return _text; }
    const char& operator[] (unsigned i) const;

    // Modifiers.
    // clear() empties the string but keeps the allocated buffer.
    // append(length, s) copies exactly length bytes without formatting;
    // append(fmt, ...) formats printf-style.
    void clear ();
    void append (char c);
    void append (unsigned length, const char* s);
    void append (const char* fmt, ...);

    // Load next line from file; newline is discarded.
    // Return true if new data; false on EOF.
    bool nextLine (FILE* fp);

    // Implicit conversion to const char* is useful so that other
    // modules that take strings as arguments don't have to know about
    // the String class, and the caller doesn't have to explicitly
    // call the text() method.
    operator const char* () const { return text(); }

private:

    // Character capacity (excluding the NUL) of a default-constructed
    // string.
    static const unsigned defaultMinSize = 16;

    // Raw and formatted append primitives behind the public append()s.
    void _append (unsigned length, const char* s);
    void _append (const char* fmt, va_list ap);

    // Reallocate when minSize exceeds the current capacity.
    void _grow (unsigned minSize);

    unsigned _length;   // number of characters, excluding the NUL
    unsigned _size;     // allocated size of _text in bytes
    char* _text;        // owned, always NUL-terminated buffer

};

// The comparison operators below rely on the implicit conversion to
// const char* and compare with strcmp, so ordering is that of the
// NUL-terminated text.  Each returns an int holding the truth value of
// the comparison.

// == operator
inline int operator== (const String& x, const String& y)
{ return strcmp (x, y) == 0; }
inline int operator== (const String& x, const char* y)
{ return strcmp (x, y) == 0; }
inline int operator== (const char* x, const String& y)
{ return strcmp (x, y) == 0; }

// != operator
inline int operator!= (const String& x, const String& y)
{ return strcmp (x, y) != 0; }
inline int operator!= (const String& x, const char* y)
{ return strcmp (x, y) != 0; }
inline int operator!= (const char* x, const String& y)
{ return strcmp (x, y) != 0; }

// < operator
inline int operator< (const String& x, const String& y)
{ return strcmp (x, y) < 0; }
inline int operator< (const String& x, const char* y)
{ return strcmp (x, y) < 0; }
inline int operator< (const char* x, const String& y)
{ return strcmp (x, y) < 0; }

// > operator
inline int operator> (const String& x, const String& y)
{ return strcmp (x, y) > 0; }
inline int operator> (const String& x, const char* y)
{ return strcmp (x, y) > 0; }
inline int operator> (const char* x, const String& y)
{ return strcmp (x, y) > 0; }

// <= operator
inline int operator<= (const String& x, const String& y)
{ return strcmp (x, y) <= 0; }
inline int operator<= (const String& x, const char* y)
{ return strcmp (x, y) <= 0; }
inline int operator<= (const char* x, const String& y)
{ return strcmp (x, y) <= 0; }

// >= operator
inline int operator>= (const String& x, const String& y)
{ return strcmp (x, y) >= 0; }
inline int operator>= (const String& x, const char* y)
{ return strcmp (x, y) >= 0; }
inline int operator>= (const char* x, const String& y)
{ return strcmp (x, y) >= 0; }

// write to output stream
// The text is written as a C string.
inline std::ostream& operator<< (std::ostream& out, const String& s) {
    out << (const char*)s;
    return out;
}
/*
  Right-hand-side node.  Which optional fields exist depends on the
  USE_P_REFINE / USE_P_UPDATE / USE_SP_AUG* / STORE_REV_ARCS build flags.
*/
typedef struct rhs_node {
    struct {
#ifdef USE_P_REFINE
        /*
          depth-first search flags.
          dfs is to determine whether
          admissible graph contains a
          cycle in p_refine().
        */
        unsigned srchng : 1;
        unsigned srched : 1;
#endif
        /*
          flag to indicate this node's
          matching arc (if any) is
          priced in.
        */
        unsigned priced_in : 1;
    } node_info;
    /*
      lhs node this rhs node is matched to.
    */
    lhs_ptr matched;
    /*
      price of this node.
    */
    double p;
#ifdef USE_SP_AUG_FORWARD
    struct lr_arc *aug_path;
#endif
#if defined(USE_P_REFINE) || defined(USE_P_UPDATE) || defined(USE_SP_AUG)
    /*
      number of epsilons of price change
      required at this node to accomplish
      p_refine()'s or p_update()'s goal.
    */
    long key;
    /*
      fields to maintain buckets of nodes as
      lists in p_refine() and p_update().
    */
    struct rhs_node *prev, *next;
#endif
#ifdef STORE_REV_ARCS
    /*
      first back arc in the arc array
      associated with this node.
    */
    struct rl_arc *priced_out;
    /*
      first priced-in back arc in the arc
      array associated with this node.
    */
    struct rl_arc *back_arcs;
#endif
} *rhs_ptr;

#ifdef STORE_REV_ARCS
/*
  Back (rhs-to-lhs) arc.  Only compiled in when reverse arcs are stored.
*/
typedef struct rl_arc {
    /*
      lhs node associated with this back
      arc. some would have liked the name
      head better.
    */
    lhs_ptr tail;
#if defined(USE_P_UPDATE) || defined(USE_SP_AUG_BACKWARD)
    /*
      cost of this back arc. this cost gets
      modified to incorporate other arc
      costs in p_update() and sp_aug(),
      while forward arc costs remain
      constant throughout.
    */
    double c;
#endif
#if defined(USE_PRICE_OUT) || defined(USE_SP_AUG_BACKWARD)
    /*
      this arc's reverse in the forward arc
      list.
    */
    struct lr_arc *rev;
#endif
} *rl_aptr;
#endif

/*
  Forward (lhs-to-rhs) arc.
*/
typedef struct lr_arc {
    /*
      rhs node associated with this arc.
    */
    rhs_ptr head;
    /*
      arc cost.
    */
    double c;
#ifdef USE_SP_AUG_FORWARD
    lhs_ptr tail;
#endif
#ifdef STORE_REV_ARCS
    /*
      this arc's reverse in the back arc
      list.
    */
    struct rl_arc *rev;
#endif
} *lr_aptr;

/*
  Stack of node pointers stored as char *; bottom and top delimit the
  occupied range.
*/
typedef struct stack_st {
    /*
      Sometimes stacks have lhs nodes, and
      other times they have rhs nodes. So
      there's a little type clash;
      everything gets cast to (char *) so we
      can use the same structure for both.
    */
    char **bottom;
    char **top;
} *stack;
209 | */ 210 | char **head; 211 | char **tail; 212 | char **storage; 213 | char **end; 214 | unsigned max_size; 215 | } *queue; 216 | -------------------------------------------------------------------------------- /py-bsds500/src/csa_defs.h: -------------------------------------------------------------------------------- 1 | #define TRUE 1 2 | #define FALSE 0 3 | #define MAXLINE 100 4 | #define DEFAULT_SCALE_FACTOR 10 5 | #define DEFAULT_PO_COST_THRESH (2.0 * sqrt((double) n) * \ 6 | sqrt(sqrt((double) n))) 7 | #define DEFAULT_PO_WORK_THRESH 50 8 | #define DEFAULT_UPD_FAC 2 9 | #if defined(USE_SP_AUG_FORWARD) || defined(USE_SP_AUG_BACKWARD) 10 | #ifndef USE_SP_AUG 11 | #define USE_SP_AUG 12 | #endif 13 | #endif 14 | 15 | #ifdef USE_SP_AUG 16 | #define EXCESS_THRESH 127 17 | #else 18 | #define EXCESS_THRESH 0 19 | #endif 20 | 21 | #if defined(USE_P_UPDATE) || defined(STRONG_PO) 22 | #define WORK_TYPE unsigned 23 | #define REFINE_WORK relabelings 24 | #endif 25 | 26 | #if defined(DEBUG) && defined(ROUND_COSTS) 27 | #define MAGIC_MARKER 0xAAAAAAAA 28 | #endif 29 | 30 | #ifdef QUEUE_ORDER 31 | #define ACTIVE_TYPE queue 32 | #define create_active(size) active = q_create(size) 33 | #define make_active(v) enq(active, (char *) v) 34 | #define get_active_node(v) v = (lhs_ptr) deq(active) 35 | #else 36 | #define ACTIVE_TYPE stack 37 | #define create_active(size) active = st_create(size) 38 | #define make_active(v) st_push(active, (char *) v) 39 | #define get_active_node(v) v = (lhs_ptr) st_pop(active) 40 | #endif 41 | 42 | #define st_push(s, el) \ 43 | {\ 44 | *(s->top) = (char *) el;\ 45 | s->top++;\ 46 | } 47 | 48 | #define st_empty(s) (s->top == s->bottom) 49 | 50 | #define enq(q, el) \ 51 | {\ 52 | *(q->tail) = el;\ 53 | if (q->tail == q->end) q->tail = q->storage;\ 54 | else q->tail++;\ 55 | } 56 | 57 | #define q_empty(q) (q->head == q->tail ? 
1 : 0) 58 | 59 | #define insert_list(node, head) \ 60 | {\ 61 | node->next = (*(head));\ 62 | (*(head))->prev = node;\ 63 | (*(head)) = node;\ 64 | node->prev = tail_rhs_node;\ 65 | } 66 | 67 | #define delete_list(node, head) \ 68 | {\ 69 | if (node->prev == tail_rhs_node)\ 70 | (*(head)) = node->next;\ 71 | node->prev->next = node->next;\ 72 | node->next->prev = node->prev;\ 73 | } 74 | 75 | /* 76 | The author hereby apologizes for the following incomprehensible 77 | muddle. Price-outs involve moving arcs around in the data structure, 78 | and it turns out to be faster to copy them field-by-field than to use 79 | memcpy() because they're so small. But the set of fields an arc has 80 | depends on lots of things, hence this mess. 81 | */ 82 | 83 | #if defined(USE_PRICE_OUT) || defined(ROUND_COSTS) 84 | #ifdef STORE_REV_ARCS 85 | #ifdef ROUND_COSTS 86 | #define copy_lr_arc(a, b) \ 87 | {\ 88 | b->head = a->head;\ 89 | b->c_init = a->c_init;\ 90 | b->c = a->c;\ 91 | b->rev = a->rev;\ 92 | } 93 | #else /* ROUND_COSTS */ 94 | #define copy_lr_arc(a, b) \ 95 | {\ 96 | b->head = a->head;\ 97 | b->c = a->c;\ 98 | b->rev = a->rev;\ 99 | } 100 | #endif /* ROUND_COSTS */ 101 | 102 | #ifdef USE_P_UPDATE 103 | #define copy_rl_arc(a, b) \ 104 | { b->tail = a->tail; b->c = a->c; b->rev = a->rev; } 105 | #else /* USE_P_UPDATE */ 106 | #define copy_rl_arc(a, b) \ 107 | { b->tail = a->tail; b->rev = a->rev; } 108 | #endif /* USE_P_UPDATE */ 109 | 110 | #define exch_rl_arcs(a, b) \ 111 | {\ 112 | copy_rl_arc(b, tail_rl_arc);\ 113 | copy_rl_arc(a, b);\ 114 | copy_rl_arc(tail_rl_arc, a);\ 115 | } 116 | #else /* STORE_REV_ARCS */ 117 | #ifdef PREC_COSTS 118 | #define copy_lr_arc(a, b) \ 119 | {\ 120 | b->head = a->head;\ 121 | b->c = a->c;\ 122 | } 123 | #else /* PREC_COSTS */ 124 | #define copy_lr_arc(a, b) \ 125 | {\ 126 | b->head = a->head;\ 127 | b->c_init = a->c_init;\ 128 | b->c = a->c;\ 129 | } 130 | #endif /* PREC_COSTS */ 131 | #endif /* STORE_REV_ARCS */ 132 | 133 | #define 
exch_lr_arcs(a, b) \ 134 | {\ 135 | copy_lr_arc(b, tail_lr_arc);\ 136 | copy_lr_arc(a, b);\ 137 | copy_lr_arc(tail_lr_arc, a);\ 138 | } 139 | 140 | extern lr_aptr tail_lr_arc; 141 | #ifdef STORE_REV_ARCS 142 | extern rl_aptr tail_rl_arc; 143 | #endif 144 | 145 | #ifdef STORE_REV_ARCS 146 | #define price_in_rev(a) \ 147 | { \ 148 | register rl_aptr b_a = --a->head->back_arcs; \ 149 | register rl_aptr a_r = a->rev; \ 150 | if (b_a != a_r) \ 151 | { \ 152 | register lr_aptr b_r = b_a->rev; \ 153 | exch_rl_arcs(b_a, a_r); \ 154 | b_r->rev = a_r; \ 155 | a->rev = b_a; \ 156 | } \ 157 | } 158 | 159 | #define price_out_rev(a) \ 160 | { \ 161 | register rl_aptr b_a = a->head->back_arcs; \ 162 | register rl_aptr a_r = a->rev; \ 163 | if (b_a != a_r) \ 164 | { \ 165 | register lr_aptr b_r = b_a->rev; \ 166 | exch_rl_arcs(b_a, a_r); \ 167 | b_r->rev = a_r; \ 168 | a->rev = b_a; \ 169 | } \ 170 | a->head->back_arcs++; \ 171 | } 172 | 173 | #define handle_rev_pointers(a, b) { a->rev->rev = b; b->rev->rev = a; } 174 | #else /* STORE_REV_ARCS */ 175 | #define price_in_rev(a) /* do nothing */ 176 | #define price_out_rev(a) /* do nothing */ 177 | #define handle_rev_pointers(a, b) /* do nothing */ 178 | #endif /* STORE_REV_ARCS */ 179 | 180 | #define price_in_unm_arc(v, a) \ 181 | { \ 182 | register lr_aptr f_a = --v->first; \ 183 | price_in_rev(a); \ 184 | if (f_a != a) \ 185 | { \ 186 | if (v->matched == f_a) v->matched = a; \ 187 | handle_rev_pointers(a, f_a); \ 188 | exch_lr_arcs(a, f_a); \ 189 | } \ 190 | } 191 | 192 | #define price_in_mch_arc(v, a) \ 193 | { \ 194 | register lr_aptr f_a = --v->first; \ 195 | price_in_rev(a); \ 196 | a->head->node_info.priced_in = TRUE; \ 197 | if (f_a != a) \ 198 | { \ 199 | v->matched = f_a; \ 200 | handle_rev_pointers(a, f_a); \ 201 | exch_lr_arcs(a, f_a); \ 202 | } \ 203 | } 204 | 205 | #define price_out_unm_arc(v, a) \ 206 | { \ 207 | register lr_aptr f_a = v->first++; \ 208 | price_out_rev(a); \ 209 | if (f_a != a) \ 210 | { \ 211 | if 
def normalize_image(image):
    """Linearly rescale *image* so its values span the range [0, 1].

    A constant image (``max == min``) has no range to stretch; it is
    mapped to all zeros instead of dividing by zero and returning NaNs.

    Args:
        image: numpy array of pixel values.

    Returns:
        Array of the same shape with the minimum mapped to 0 and the
        maximum mapped to 1.
    """
    min_val = image.min()
    max_val = image.max()

    span = max_val - min_val
    if span == 0:
        # Every pixel equals min_val, so the numerator below is already
        # zero; dividing by 1 keeps the usual result dtype without the
        # 0/0 that would turn the whole image into NaN.
        span = 1

    return (image - min_val) / span
parser.add_argument('--stability_score_offset', type=float, default=1.0, 42 | help='Stability score offset') 43 | parser.add_argument('--box_nms_thresh', type=float, default=0.7, 44 | help='NMS threshold for box suppression') 45 | parser.add_argument('--crop_n_layers', type=int, default=0, 46 | help='Number of layers to crop') 47 | parser.add_argument('--crop_nms_thresh', type=float, default=0.7, 48 | help='NMS threshold for cropping') 49 | parser.add_argument('--crop_overlap_ratio', type=float, default=512/1500, 50 | help='Overlap ratio for cropping') 51 | parser.add_argument('--crop_n_points_downscale_factor', 52 | type=int, default=1, 53 | help='Downscale factor for number of points in crop') 54 | parser.add_argument('--min_mask_region_area', type=int, default=0, 55 | help='Minimum mask region area') 56 | parser.add_argument('--output_mode', type=str, default="binary_mask", 57 | help='Output mode of the mask generator') 58 | parser.add_argument('--nms_threshold', type=float, default=0.7, 59 | help='NMS threshold') 60 | parser.add_argument('--bzp', type=int, default=0, 61 | help='boundary zero padding') 62 | parser.add_argument('--pred_iou_thresh_filtering', action='store_true', 63 | help='filter by pred_iou_thresh') 64 | parser.add_argument('--stability_score_thresh_filtering', 65 | action='store_true', 66 | help='filter by stability_score_thresh') 67 | 68 | # gaussian kernel size for post processing before edge nms 69 | parser.add_argument('--kernel_size', type=int, default=3, 70 | help='kernel size') 71 | 72 | args = parser.parse_args() 73 | return args 74 | 75 | 76 | def make_output_dir(args): 77 | dataset = args.dataset 78 | data_split = args.data_split 79 | 80 | outut_root_dir = Path('output') / dataset 81 | outut_root_dir.mkdir(parents=True, exist_ok=True) 82 | 83 | last_exp_num = 0 84 | for exp_dir in outut_root_dir.glob('exp*'): 85 | exp_num = int(exp_dir.stem[3:]) 86 | last_exp_num = max(last_exp_num, exp_num) 87 | 88 | output_dir = \ 89 | 
outut_root_dir / f'exp{str(last_exp_num + 1).zfill(3)}' / data_split 90 | output_dir.mkdir(parents=True, exist_ok=True) 91 | 92 | # save args as a text file 93 | with open(output_dir / 'args.txt', 'w') as f: 94 | for k, v in vars(args).items(): 95 | f.write(f'{k}: {v}\n') 96 | 97 | return output_dir 98 | 99 | 100 | def main(): 101 | 102 | args = get_args() 103 | dataset = args.dataset 104 | assert dataset in ["BSDS500", "NYUDv2"] 105 | data_split = args.data_split 106 | # assert data_split in ["train", "val", "test"] 107 | 108 | device = "cuda" 109 | sam = sam_model_registry["default"]( 110 | checkpoint="model/sam_vit_h_4b8939.pth") 111 | sam.to(device=device) 112 | generator = SamAutomaticMaskAndProbabilityGenerator( 113 | model=sam, 114 | points_per_side=args.points_per_side, 115 | points_per_batch=args.points_per_batch, 116 | pred_iou_thresh=args.pred_iou_thresh, 117 | stability_score_thresh=args.stability_score_thresh, 118 | stability_score_offset=args.stability_score_offset, 119 | box_nms_thresh=args.box_nms_thresh, 120 | crop_n_layers=args.crop_n_layers, 121 | crop_nms_thresh=args.crop_nms_thresh, 122 | crop_overlap_ratio=args.crop_overlap_ratio, 123 | crop_n_points_downscale_factor=args.crop_n_points_downscale_factor, 124 | min_mask_region_area=args.min_mask_region_area, 125 | output_mode=args.output_mode, 126 | nms_threshold=args.nms_threshold, 127 | bzp=args.bzp, 128 | pred_iou_thresh_filtering=args.pred_iou_thresh_filtering, 129 | stability_score_thresh_filtering=args.stability_score_thresh_filtering, 130 | ) 131 | 132 | kernel_size = args.kernel_size 133 | 134 | # make output directory 135 | output_dir = make_output_dir(args) 136 | 137 | img_dir = Path('data') / dataset / 'images' / data_split 138 | 139 | if dataset == 'BSDS500': 140 | suf = 'jpg' 141 | else: 142 | suf = 'png' 143 | 144 | for img_path in tqdm(list(img_dir.glob(f'*.{suf}'))): 145 | 146 | name = img_path.stem 147 | image = cv2.imread(str(img_path)) 148 | image = cv2.cvtColor(image, 
cv2.COLOR_BGR2RGB) 149 | masks = generator.generate(image) 150 | 151 | p_max = None 152 | for mask in masks: 153 | p = mask["prob"] 154 | if p_max is None: 155 | p_max = p 156 | else: 157 | p_max = np.maximum(p_max, p) 158 | 159 | edges = normalize_image(p_max) 160 | 161 | if kernel_size > 0: 162 | assert kernel_size % 2 == 1 163 | edges = cv2.GaussianBlur(edges, (kernel_size, kernel_size), 0) 164 | 165 | edge_detection = cv2.ximgproc.createStructuredEdgeDetection( 166 | 'model/model.yml.gz') 167 | orimap = edge_detection.computeOrientation(edges) 168 | edges = edge_detection.edgesNms(edges, orimap) 169 | edges = (edges * 255).astype(np.uint8) 170 | 171 | cv2.imwrite(str(output_dir / f'{name}.png'), edges) 172 | 173 | 174 | if __name__ == "__main__": 175 | main() 176 | -------------------------------------------------------------------------------- /py-bsds500/bsds/thin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # Thinning morphological operation applied using lookup tables. 5 | # We convert the 3x3 neighbourhood surrounding a pixel to an index 6 | # used to lookup the output in a lookup table. 
import numpy as np


# Thinning morphological operation applied using lookup tables.
# We convert the 3x3 neighbourhood surrounding a pixel to an index
# used to lookup the output in a lookup table.

# Bit masks for each neighbour
#   1   2   4
#   8  16  32
#  64 128 256
NEIGH_MASK_EAST = 32
NEIGH_MASK_NORTH_EAST = 4
NEIGH_MASK_NORTH = 2
NEIGH_MASK_NORTH_WEST = 1
NEIGH_MASK_WEST = 8
NEIGH_MASK_SOUTH_WEST = 64
NEIGH_MASK_SOUTH = 128
NEIGH_MASK_SOUTH_EAST = 256
NEIGH_MASK_CENTRE = 16

# Masks in a list
# MASKS[0] = centre
# MASKS[1..8] = start from east, counter-clockwise
MASKS = [NEIGH_MASK_CENTRE,
         NEIGH_MASK_EAST, NEIGH_MASK_NORTH_EAST, NEIGH_MASK_NORTH, NEIGH_MASK_NORTH_WEST,
         NEIGH_MASK_WEST, NEIGH_MASK_SOUTH_WEST, NEIGH_MASK_SOUTH, NEIGH_MASK_SOUTH_EAST,
         ]

# Constant listing all indices
_LUT_INDS = np.arange(512)


def binary_image_to_lut_indices(x):
    """
    Convert a binary image to an index image that can be used with a lookup table
    to perform morphological operations. Non-zero elements in the image are interpreted
    as 1, zero elements as 0

    :param x: a 2D NumPy array.
    :return: a 2D int32 NumPy array, same shape as x
    :raises ValueError: if x is not 2-dimensional
    """
    if x.ndim != 2:
        raise ValueError('x should have 2 dimensions, not {}'.format(x.ndim))

    # If the dtype of x is not bool, convert.
    # np.bool_ is the NumPy boolean scalar type; the bare `np.bool` alias was
    # removed in NumPy 1.24 and raises AttributeError there.
    if x.dtype != np.bool_:
        x = x != 0

    # Add a one pixel border of zeros so that every pixel has 8 neighbours
    x = np.pad(x, [(1, 1), (1, 1)], mode='constant')

    # Convert to LUT indices: each of the nine shifted views contributes the
    # bit of the neighbour it represents
    lut_indices = x[:-2, :-2] * NEIGH_MASK_NORTH_WEST + \
                  x[:-2, 1:-1] * NEIGH_MASK_NORTH + \
                  x[:-2, 2:] * NEIGH_MASK_NORTH_EAST + \
                  x[1:-1, :-2] * NEIGH_MASK_WEST + \
                  x[1:-1, 1:-1] * NEIGH_MASK_CENTRE + \
                  x[1:-1, 2:] * NEIGH_MASK_EAST + \
                  x[2:, :-2] * NEIGH_MASK_SOUTH_WEST + \
                  x[2:, 1:-1] * NEIGH_MASK_SOUTH + \
                  x[2:, 2:] * NEIGH_MASK_SOUTH_EAST

    return lut_indices.astype(np.int32)


def apply_lut(x, lut):
    """
    Perform a morphological operation on the binary image x using the supplied lookup table

    :param x: a 2D NumPy array; non-zero elements are treated as 1
    :param lut: a 1D NumPy array with 512 entries, one per 3x3 neighbourhood
    :return: `lut` evaluated at every pixel's neighbourhood index, same shape as x
    :raises ValueError: if lut is not 1-dimensional with 512 entries
    """
    if lut.ndim != 1:
        raise ValueError('lut should have 1 dimension, not {}'.format(lut.ndim))

    if lut.shape[0] != 512:
        raise ValueError('lut should have 512 entries, not {}'.format(lut.shape[0]))

    lut_indices = binary_image_to_lut_indices(x)

    return lut[lut_indices]


def identity_lut():
    """
    Create identity lookup table: the output is the value of the centre pixel

    :return: a new boolean array with 512 entries
    """
    return (_LUT_INDS & NEIGH_MASK_CENTRE) != 0


def _lut_mutate_mask(lut):
    """
    Get a mask that shows which neighbourhood shapes result in changes to the image

    :param lut: lookup table
    :return: mask indicating which lookup indices result in changes
    """
    return lut != identity_lut()


def lut_masks_zero(neigh):
    """
    Create a LUT index mask for which the specified neighbour is 0

    :param neigh: neighbour index; counter-clockwise from 1 starting at the eastern
    neighbour (indices above 8 wrap around, so 9 is the eastern neighbour again)
    :return: a LUT index mask
    """
    if neigh > 8:
        neigh -= 8
    return (_LUT_INDS & MASKS[neigh]) == 0


def lut_masks_one(neigh):
    """
    Create a LUT index mask for which the specified neighbour is 1

    :param neigh: neighbour index; counter-clockwise from 1 starting at the eastern
    neighbour (indices above 8 wrap around, so 9 is the eastern neighbour again)
    :return: a LUT index mask
    """
    if neigh > 8:
        neigh -= 8
    return (_LUT_INDS & MASKS[neigh]) != 0


def _thin_cond_g1():
    """
    Thinning morphological operation; condition G1

    :return: a LUT index mask
    """
    b = np.zeros(512, dtype=int)
    for i in range(1, 5):
        b += lut_masks_zero(2*i-1) & (lut_masks_one(2*i) | lut_masks_one(2*i+1))
    return b == 1


def _thin_cond_g2():
    """
    Thinning morphological operation; condition G2

    :return: a LUT index mask
    """
    n1 = np.zeros(512, dtype=int)
    n2 = np.zeros(512, dtype=int)
    for k in range(1, 5):
        n1 += (lut_masks_one(2*k-1) | lut_masks_one(2*k))
        n2 += (lut_masks_one(2*k) | lut_masks_one(2*k+1))
    m = np.minimum(n1, n2)
    return (m >= 2) & (m <= 3)


def _thin_cond_g3():
    """
    Thinning morphological operation; condition G3

    :return: a LUT index mask
    """
    return ((lut_masks_one(2) | lut_masks_one(3) | lut_masks_zero(8)) & lut_masks_one(1)) == 0


def _thin_cond_g3_prime():
    """
    Thinning morphological operation; condition G3'

    :return: a LUT index mask
    """
    return ((lut_masks_one(6) | lut_masks_one(7) | lut_masks_zero(4)) & lut_masks_one(5)) == 0


def _thin_iter_1_lut():
    """
    Thinning morphological operation; lookup table for iteration 1

    :return: lookup table
    """
    lut = identity_lut()
    cond = _thin_cond_g1() & _thin_cond_g2() & _thin_cond_g3()
    lut[cond] = False
    return lut


def _thin_iter_2_lut():
    """
    Thinning morphological operation; lookup table for iteration 2

    :return: lookup table
    """
    lut = identity_lut()
    cond = _thin_cond_g1() & _thin_cond_g2() & _thin_cond_g3_prime()
    lut[cond] = False
    return lut


def binary_thin(x, max_iter=None):
    """
    Binary thinning morphological operation

    :param x: a binary image, or an image that is to be converted to a binary image
    :param max_iter: maximum number of iterations; default is `None` that results in an infinite
    number of iterations (note that `binary_thin` will automatically terminate when no more changes occur)
    :return: the thinned image as a boolean array; when the very first pass finds nothing
    to remove (or `max_iter` is 0), `x` itself is returned unchanged
    """
    thin1 = _thin_iter_1_lut()
    thin2 = _thin_iter_2_lut()
    thin1_mut = _lut_mutate_mask(thin1)
    thin2_mut = _lut_mutate_mask(thin2)

    iter_count = 0
    while max_iter is None or iter_count < max_iter:
        # Iter 1
        lut_indices = binary_image_to_lut_indices(x)
        # Stop as soon as a sub-iteration would not change any pixel
        if not thin1_mut[lut_indices].any():
            break

        x = thin1[lut_indices]

        # Iter 2
        lut_indices = binary_image_to_lut_indices(x)
        if not thin2_mut[lut_indices].any():
            break

        x = thin2[lut_indices]

        iter_count += 1

    return x

#ifndef __Sort_hh__
#define __Sort_hh__

//
// A fast in-place sorting routine that can be customized to a
// specific type with all swap and compare operations inlined.
//
// For arrays of types for which assignment, > and < exist (such as
// int,float,etc. or appropriately-defined user types), the usage is
// simple:
//
//   double* a = new double [100];
//   sort(a,100);
//
// This will sort the array into increasing order.  To sort in another
// order, or to sort more complex types, you must provide compare and
// swap routines:
//
//   sortSwap(cl,i,j)
//     - Swap elements i and j.
//
//   sortCmp(cl,i,j)
//     - Compare elements i and j, returning -1,0,1 for <,=,>.
//
// The argument 'cl' is a closure.  Note that the sorting routine does
// not evaluate cl in any context other than these two routines.
//
// The postcondition of sort() is (sortCmp(cl,i,j) <= 0) for
// all 0 <= i < j < n, i.e. increasing order.
//
// Here is an example of how to sort an array of points by x
// coordinate in decreasing order:
//
//   struct Point { int x, y; };
//   static inline void sortSwap (Point* a, int i, int j) {
//     swap(a[i],a[j]);
//   }
//   static inline int sortCmp (Point* a, int i, int j) {
//     return a[j].x - a[i].x;
//   }
//   Point* points = new Point [100];
//   sort(points,100);
//

// Copyright (C) 2002 David R. Martin
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License as
// published by the Free Software Foundation; either version 2 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
// 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html.

// NOTE(review): the include target and every template parameter list were
// missing from this copy (angle-bracketed text stripped); they are
// reconstructed here from how the code uses them -- confirm against upstream.
#include <assert.h>

// Public routines for sorting arrays of simple values that have
// assignment, < and > defined.
template <class T>
static inline void sortSwap (T* a, int i, int j) {
    T tmp = a[i]; a[i] = a[j]; a[j] = tmp;
}
template <class T>
static inline int sortCmp (T* a, int i, int j) {
    if (a[i] < a[j]) { return -1; }
    if (a[i] > a[j]) { return 1; }
    return 0;
}

// Private routine.
// Sort elements [start,start+n) using insertion sort.
template <class Closure>
void
__insertionSort (Closure cl, int start, int n)
{
    for (int i = start; i < start+n-1; i++) {
        // Sink element i+1 towards the front until it is in order.
        for (int j = i+1; j > start; j--) {
            if (sortCmp(cl,j-1,j) <= 0) { break; }
            sortSwap(cl,j-1,j);
        }
    }
}

// Private routine.
// Sort elements [start,start+n) using selection sort.
template <class Closure>
void
__selectionSort1 (Closure cl, int start, int n)
{
    for (int i = start; i < start + n - 1; i++) {
        // Skip over duplicate elements.
        if (i > start && sortCmp(cl,i,i-1) == 0) { continue; }
        // Find the smallest element in [i,end] and move it to the front.
        int minLoc = i;
        for (int j = i + 1; j < start + n; j++) {
            if (sortCmp(cl,j,minLoc) < 0) {
                minLoc = j;
            }
        }
        if (minLoc > i) {
            sortSwap (cl, i, minLoc);
        }
    }
}

// Private routine.
// Sort elements [start,start+n) using double-ended selection sort.
template <class Closure>
void
__selectionSort2 (Closure cl, int start, int n)
{
    int i = start;
    int j = start + n - 1;
    while (i < j) {
        // Skip over duplicate elements.
        if (i > start && sortCmp(cl,i,i-1) == 0) { i++; continue; }
        if (j < start+n-1 && sortCmp(cl,j,j+1) == 0) { j--; continue; }
        // Find the min and max elements in [i,j].
        int minLoc=i, maxLoc=i;
        for (int k = i + 1; k <= j; k++) {
            if (sortCmp(cl,k,minLoc) < 0) { minLoc = k; }
            if (sortCmp(cl,k,maxLoc) > 0) { maxLoc = k; }
        }
        // Move the min element to the front and the max element to
        // the back.
        if (minLoc == maxLoc) { break; }
        if (minLoc > maxLoc) {
            sortSwap(cl,minLoc,maxLoc);
            int tmp=minLoc; minLoc=maxLoc; maxLoc=tmp;
        }
        if (minLoc > i) { sortSwap(cl,i,minLoc); }
        if (maxLoc < j) { sortSwap(cl,j,maxLoc); }
        i++; j--;
    }
}

// Private routine.
// Return whichever of the indices x, y, z holds the median of the three
// elements, as defined by sortCmp.
// Used internally in __quickSort to pick a pivot.
template <class Closure>
int
__3median (Closure cl, int x, int y, int z)
{
    return sortCmp(cl,x,y) > 0
        ? (sortCmp(cl,y,z) > 0
           ? y : (sortCmp(cl,x,z) > 0 ? z : x))
        : (sortCmp(cl,y,z) < 0
           ? y : (sortCmp(cl,x,z) < 0 ? z : x));
}

// Private routine.
// Sort elements [start,start+n) using quick sort.
template <class Closure>
void
__quickSort (Closure cl, int start, int n)
{
    // Use insertion sort for small arrays.
    if (n < 16) {
        __insertionSort (cl, start, n);
        //__selectionSort1 (cl, start, n);
        //__selectionSort2 (cl, start, n);
        return;
    }

    // Pick the median of elements n/4, n/2, 3n/4 as the pivot, and
    // move it to the front.
    int x = start + (n >> 2);
    int y = start + (n >> 1);
    int z = x + (n >> 1);
    int pivotLoc = __3median (cl, x, y, z);
    sortSwap (cl, start, pivotLoc);

    // Segregate array elements into three groups.  Those equal to the
    // pivot (=), those less than the pivot (<), and those greater
    // than the pivot (>).  After this loop, the array will look like
    // this:
    //   S   P    RL
    //   =====<<<<<>>>>>
    //
    // Where S=start P=pivot R=right L=left.
    //
    int pivot = start;
    int left = start + 1;
    int right = start + n - 1;
    while (1) {
      restart:
        // Advance 'left' over < elements, collecting = elements next to
        // the pivot run, until a > element is found.
        while (left <= right) {
            int c = sortCmp (cl, left, pivot);
            if (c > 0) { break; }
            if (c < 0) { left++; continue; }
            if (left != pivot+1) { sortSwap (cl, left, pivot+1); }
            pivot++; left++;
        }
        // Retreat 'right' over > elements until a < element is found.  An
        // = element is moved into the pivot run and the scan starts over.
        while (left <= right) {
            int c = sortCmp (cl, right, pivot);
            if (c < 0) { break; }
            if (c > 0) { right--; continue; }
            assert (left < right);
            sortSwap (cl, left, right);
            if (left != pivot+1) { sortSwap (cl, left, pivot+1); }
            pivot++; left++; right--;
            goto restart;
        }
        if (left > right) { break; }
        sortSwap (cl, left, right);
    }
    assert (pivot >= start);
    assert (right >= pivot);
    assert (left == right + 1);
    assert (left <= start + n);

    int numEq = pivot - start + 1;
    int numLt = right - pivot;
    int numLe = left - start;
    int numGt = n - numLe;
    assert (numEq + numLt + numGt == n);

    // Copy pivot values into middle.
    int count = (numEq < numLt) ? numEq : numLt;
    int dist = numLe - count;
    for (int i = 0; i < count; i++) {
        sortSwap (cl, start + i, start + i + dist);
    }

    // Recursively sort the < and > chunks.
    if (numLt > 0) { __quickSort (cl, start, numLt); }
    if (numGt > 0) { __quickSort (cl, left, numGt); }
}

// Public sort routine.
// Sorts elements [0,n) of the closure into increasing sortCmp order.
template <class Closure>
inline void
sort (Closure cl, int n)
{
    __quickSort (cl, 0, n);
    // Check the postcondition.
    for (int i = 1; i < n; i++) {
        assert (sortCmp (cl, i-1, i) <= 0);
    }
}

#endif // __Sort_hh__

4 | 5 |

6 | 7 | This repository provides code for performing edge detection using the Automatic Mask Generation (AMG) of the Segment Anything Model (SAM) [1]. Since the code used in the paper is not currently available to the public, this implementation is based on the descriptions provided in the paper. 8 | 9 | The image on the left is taken from the BSDS. The middle is the ground truth edge. The image on the right is the result of applying edge detection. 10 | 11 | --- 12 | 13 | ## Docker 14 | 15 | This repository is intended to be run in a Docker environment. If you are not familiar with Docker, please install the packages listed in [requirements.txt](requirements.txt). 16 | 17 | ### Docker build 18 | 19 | Create a Docker image as follows: 20 | 21 | ```bash 22 | $ bash scripts/docker/build.sh 23 | ``` 24 | 25 | ### Docker run 26 | Run the Docker container by passing the GPU ID as an argument: 27 | ```bash 28 | $ bash scripts/docker/run.sh 0 29 | ``` 30 | 31 | --- 32 | 33 | ## Data 34 | 35 | ### BSDS500 36 | download BSDS500 [2] dataset from [official site](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html). 37 | 38 | If you cannot download it, the following mirror repositories may be helpful. 39 | - https://github.com/BIDS/BSDS500 40 | 41 | Then prepare the following directory structure: 42 | 43 | ```bash 44 | data/BSDS500/ 45 | ├── groundTruth 46 | │ └── test 47 | │ ├── 100007.mat 48 | │ ├── 100039.mat 49 | │ ... 50 | │ 51 | └── images 52 | ├── test 53 | │ ├── 100007.jpg 54 | │ ├── 100039.jpg 55 | │ ... 56 | │ 57 | ├── train 58 | └── val 59 | ``` 60 | 61 | ### NYUDv2 62 | 63 | download NYUDv2 [3] test dataset from [EDTER](https://github.com/MengyangPu/EDTER). 64 | Then prepare the following directory structure: 65 | 66 | ```bash 67 | data/NYUDv2/ 68 | ├── groundTruth 69 | │ └── test 70 | │ ├── img_5001.mat 71 | │ ├── img_5002.mat 72 | │ ... 73 | │ 74 | └── images 75 | ├── test 76 | │ ├── img_5001.png 77 | │ ├── img_5002.png 78 | │ ... 
79 | │ 80 | ├── train 81 | └── val 82 | ``` 83 | 84 | --- 85 | 86 | ## Model 87 | 88 | Create a directory to download the model as follows: 89 | 90 | ```bash 91 | mkdir model 92 | ``` 93 | 94 | ### SAM 95 | 96 | Download the SAM model as follows: 97 | 98 | ```bash 99 | wget -P model https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth 100 | ``` 101 | 102 | ### Edge-NMS 103 | 104 | In the original paper [1], Canny edge NMS [4] was used for edge NMS. 105 | However, in our environment, it did not produce the edges reported in the paper. 106 | Therefore, we temporarily used OpenCV's Structured Forests [5] model for edge NMS. 107 | 108 | Download the Structured Forests model as follows: 109 | 110 | ```bash 111 | wget -P model https://cdn.rawgit.com/opencv/opencv_extra/3.3.0/testdata/cv/ximgproc/model.yml.gz 112 | ``` 113 | 114 | ## Prediction 115 | 116 | To generate the image above, do the following: 117 | ``` 118 | python example.py 119 | ``` 120 | The output result is generated in `output/example`. 121 | 122 | 123 | 124 | Predict edges as follows: 125 | 126 | ```bash 127 | python pipeline.py --dataset BSDS500 --data_split test 128 | ``` 129 | 130 | Other arguments for initializing `SamAutomaticMaskAndProbabilityGenerator` can be passed as follows. 131 | 132 | ```bash 133 | -h, --help show this help message and exit 134 | --dataset DATASET BSDS500 or NYUDv2 135 | --data_split DATA_SPLIT 136 | train, val, or test 137 | --points_per_side POINTS_PER_SIDE 138 | Number of points per side. 
139 | --points_per_batch POINTS_PER_BATCH 140 | Number of points per batch 141 | --pred_iou_thresh PRED_IOU_THRESH 142 | Prediction IOU threshold 143 | --stability_score_thresh STABILITY_SCORE_THRESH 144 | Stability score threshold 145 | --stability_score_offset STABILITY_SCORE_OFFSET 146 | Stability score offset 147 | --box_nms_thresh BOX_NMS_THRESH 148 | NMS threshold for box suppression 149 | --crop_n_layers CROP_N_LAYERS 150 | Number of layers to crop 151 | --crop_nms_thresh CROP_NMS_THRESH 152 | NMS threshold for cropping 153 | --crop_overlap_ratio CROP_OVERLAP_RATIO 154 | Overlap ratio for cropping 155 | --crop_n_points_downscale_factor CROP_N_POINTS_DOWNSCALE_FACTOR 156 | Downscale factor for number of points in crop 157 | --min_mask_region_area MIN_MASK_REGION_AREA 158 | Minimum mask region area 159 | --output_mode OUTPUT_MODE 160 | Output mode of the mask generator 161 | --nms_threshold NMS_THRESHOLD 162 | NMS threshold 163 | --bzp BZP boundary zero padding 164 | --pred_iou_thresh_filtering 165 | filter by pred_iou_thresh 166 | --stability_score_thresh_filtering 167 | filter by stability_score_thresh 168 | --kernel_size KERNEL_SIZE 169 | kernel size 170 | ``` 171 | 172 | See [6] for more details about boundary zero padding. 173 | 174 | The output is written to `output/${dataset}/exp${exp_num}/${data_split}`, where `${exp_num}` is a zero-padded, automatically incremented experiment number (e.g. `exp001`). 175 | 176 | # Evaluation 177 | We use [py-bsds500](https://github.com/Britefury/py-bsds500/tree/master) to evaluate the detected edges. A copy with some bugs fixed is included in the `py-bsds500` directory. 
178 | Compile the extension module with: 179 | 180 | ```bash 181 | cd py-bsds500 182 | python setup.py build_ext --inplace 183 | ``` 184 | 185 | Then evaluate ODS, OIS, and AP as follows: 186 | 187 | ```bash 188 | cd py-bsds500/ 189 | python evaluate_parallel.py ../data/BSDS500 ../output/BSDS500/exp${exp}/ test --max_dist 0.0075 190 | python evaluate_parallel.py ../data/NYUDv2 ../output/NYUDv2/exp${exp}/ test --max_dist 0.011 191 | ``` 192 | 193 | Note that, following previous work, the localization tolerance is set to 0.0075 for BSDS500 and 0.011 for NYUDv2. 194 | 195 | # Todo 196 | - Since there is a large gap with the original paper in terms of performance, we would like to be able to reproduce the results. 197 | - A high-performance Canny edge NMS needs to be implemented to reproduce the settings of the original paper. 198 | 199 | ## Reference 200 | 201 | ### Code 202 | 203 | The code in this repository mainly uses code from the following repositories. Thank you. 204 | - [segment-anything](https://github.com/facebookresearch/segment-anything) 205 | - [py-bsds500](https://github.com/Britefury/py-bsds500/tree/master) 206 | - [opencv_contrib](https://github.com/opencv/opencv_contrib) 207 | 208 | ### Paper 209 | 210 | [1] Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alexander C. Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick. Segment Anything. ICCV 2023. 211 | 212 | [2] Pablo Arbelaez, Michael Maire, Charless C. Fowlkes, and Jitendra Malik. Contour detection and hierarchical image segmentation. IEEE Trans. Pattern Anal. Mach. Intell 2011. 213 | 214 | [3] Nathan Silberman, Derek Hoiem, Pushmeet Kohli, and Rob Fergus. Indoor segmentation and support inference from RGBD images. ECCV 2012. 215 | 216 | [4] John F. Canny. A computational approach to edge detection. IEEE Trans. Pattern Anal. Mach. Intell 1986. 217 | 218 | [5] Piotr Dollar and C. Lawrence Zitnick. 
Fast edge detection using structured forests. IEEE Trans. Pattern Anal. Mach. Intell 2015. 219 | 220 | [6] Hiroaki Yamagiwa, Yusuke Takase, Hiroyuki Kambe, and Ryosuke Nakamoto. Zero-Shot Edge Detection With SCESAME: Spectral Clustering-Based Ensemble for Segment Anything Model Estimation. WACV Workshop 2024. 221 | 222 | --- 223 | ## Related Work 224 | 225 | The following is a list of studies on SAM and edge detection. Please let me know if you would like to add new research. 226 | 227 | - Wenya Yang, Xiao-Diao Chen, Wen Wu, Hongshuai Qin, Kangming Yan, Xiaoyang Mao, and Haichuan Song. [Boosting Deep Unsupervised Edge Detection via Segment Anything Model](https://ieeexplore.ieee.org/abstract/document/10490131). IEEE Transactions on Industrial Informatics 2024. 228 | - Xingchen Li, Yifan Duan, Beibei Wang, Haojie Ren, Guoliang You, Yu Sheng, 229 | Jianmin Ji, and Yanyong Zhang. [EdgeCalib: Multi-Frame Weighted Edge Features for AutomaticTargetless LiDAR-Camera Calibration](https://arxiv.org/abs/2310.16629). arXiv 2023. 230 | - Hiroaki Yamagiwa, Yusuke Takase, Hiroyuki Kambe, and Ryosuke Nakamoto. [Zero-Shot Edge Detection With SCESAME: Spectral Clustering-Based Ensemble for Segment Anything Model Estimation](https://openaccess.thecvf.com/content/WACV2024W/Pretrain/html/Yamagiwa_Zero-Shot_Edge_Detection_With_SCESAME_Spectral_Clustering-Based_Ensemble_for_Segment_WACVW_2024_paper.html). WACV Workshop 2024. 231 | 232 | # Contribution 233 | 234 | I may be slow to respond, but everyone is welcome to contribute. -------------------------------------------------------------------------------- /py-bsds500/src/Array.hh: -------------------------------------------------------------------------------- 1 | 2 | #ifndef __Array_hh__ 3 | #define __Array_hh__ 4 | 5 | // Arrays that reduce bugs by: 6 | // - Being allocatable on the stack, so destructors get called 7 | // automatically. 8 | // - Doing bounds checking. 9 | // - Providing easy initialization. 
10 | // - Encapsulating the address calculation. 11 | 12 | // The arrays are allocated as single blocks so that all elements are 13 | // contiguous in memory. Latter indices change more quickly than 14 | // former indices. Clients can rely on this ordering. 15 | 16 | // Copyright (C) 2003 David R. Martin 17 | // 18 | // This program is free software; you can redistribute it and/or 19 | // modify it under the terms of the GNU General Public License as 20 | // published by the Free Software Foundation; either version 2 of the 21 | // License, or (at your option) any later version. 22 | // 23 | // This program is distributed in the hope that it will be useful, but 24 | // WITHOUT ANY WARRANTY; without even the implied warranty of 25 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 26 | // General Public License for more details. 27 | // 28 | // You should have received a copy of the GNU General Public License 29 | // along with this program; if not, write to the Free Software 30 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 31 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html. 
32 | 33 | #include 34 | #include 35 | 36 | template 37 | class Array1D 38 | { 39 | public: 40 | 41 | Array1D () { 42 | _alloc(0); 43 | } 44 | Array1D (unsigned n) { 45 | _alloc(n); 46 | } 47 | ~Array1D () { 48 | _delete(); 49 | } 50 | void resize (unsigned n) { 51 | if (!issize(n)) { 52 | _delete(); 53 | _alloc(n); 54 | } 55 | } 56 | void init (const Elem& elem) { 57 | for (unsigned i = 0; i < _n; i++) { 58 | _array[i] = elem; 59 | } 60 | } 61 | bool issize (unsigned n) const { 62 | return (_n == n); 63 | } 64 | int size () const { 65 | return _n; 66 | } 67 | Elem* data () { 68 | return _array; 69 | } 70 | Elem& operator() (unsigned i) { 71 | assert (i < _n); 72 | return _array[i]; 73 | } 74 | const Elem& operator() (unsigned i) const { 75 | assert (i < _n); 76 | return _array[i]; 77 | } 78 | 79 | private: 80 | 81 | void _alloc (unsigned n) { 82 | _n = n; 83 | _array = new Elem [_n]; 84 | } 85 | void _delete () { 86 | assert (_array != NULL); 87 | delete [] _array; 88 | _array = NULL; 89 | } 90 | 91 | unsigned _n; 92 | Elem* _array; 93 | 94 | }; // class Array1D 95 | 96 | template 97 | class Array2D 98 | { 99 | public: 100 | 101 | Array2D () { 102 | _alloc(0,0); 103 | } 104 | Array2D (unsigned d0, unsigned d1) { 105 | _alloc(d0,d1); 106 | } 107 | ~Array2D () { 108 | _delete(); 109 | } 110 | void resize (unsigned d0, unsigned d1) { 111 | if (!issize(d0,d1)) { 112 | _delete(); 113 | _alloc(d0,d1); 114 | } 115 | } 116 | void init (const Elem& elem) { 117 | for (unsigned i = 0; i < _n; i++) { 118 | _array[i] = elem; 119 | } 120 | } 121 | bool issize (unsigned d0, unsigned d1) const { 122 | return (_dim[0] == d0 && _dim[1] == d1); 123 | } 124 | int size (unsigned d) const { 125 | assert (d < 2); 126 | return _dim[d]; 127 | } 128 | Elem* data () { 129 | return _array; 130 | } 131 | Elem& operator() (unsigned i, unsigned j) { 132 | assert (i < _dim[0]); 133 | assert (j < _dim[1]); 134 | unsigned index = i * _dim[1] + j; 135 | assert (index < _n); 136 | return 
// Array3D: a heap-allocated three-dimensional array of Elem whose
// element access is bounds-checked with assert().
//
// All elements live in a single contiguous block obtained with new[];
// element (i,j,k) is stored at offset (i*dim1 + j)*dim2 + k, so the last
// index changes most quickly.  The object owns the block and releases
// it in its destructor.  Copy construction and assignment perform deep
// copies, so two Array3D objects never share -- and never double-delete
// -- the same block.
template <class Elem>
class Array3D
{
public:

    // Construct an empty array (all dimensions zero).
    Array3D () {
        _alloc(0,0,0);
    }
    // Construct a d0 x d1 x d2 array.  The elements are whatever new[]
    // produces for Elem; call init() to give them a known value.
    Array3D (unsigned d0, unsigned d1, unsigned d2) {
        _alloc(d0,d1,d2);
    }
    // Deep copy: allocate a block of the same shape and copy each element.
    Array3D (const Array3D& that) {
        _alloc(that._dim[0],that._dim[1],that._dim[2]);
        _copy(that);
    }
    ~Array3D () {
        _delete();
    }
    // Deep-copy assignment.  Self-assignment is a no-op.
    Array3D& operator= (const Array3D& that) {
        if (this != &that) {
            resize(that._dim[0],that._dim[1],that._dim[2]);
            _copy(that);
        }
        return *this;
    }
    // Change the shape to d0 x d1 x d2.  When the shape actually changes
    // the old block is released and a new one allocated, so the previous
    // contents are NOT preserved.
    void resize (unsigned d0, unsigned d1, unsigned d2) {
        if (!issize(d0,d1,d2)) {
            _delete();
            _alloc(d0,d1,d2);
        }
    }
    // Assign elem to every element.
    void init (const Elem& elem) {
        for (unsigned i = 0; i < _n; i++) {
            _array[i] = elem;
        }
    }
    // True when the array currently has exactly this shape.
    bool issize (unsigned d0, unsigned d1, unsigned d2) const {
        return (_dim[0] == d0 && _dim[1] == d1 && _dim[2] == d2);
    }
    // Extent of dimension d (0, 1 or 2), returned as int.
    int size (unsigned d) const {
        assert (d < 3);
        return _dim[d];
    }
    // Pointer to the first element of the contiguous block.
    Elem* data () {
        return _array;
    }
    const Elem* data () const {
        return _array;
    }
    // Element access; every index is checked with assert().
    Elem& operator() (unsigned i, unsigned j, unsigned k) {
        return _array[_index(i,j,k)];
    }
    const Elem& operator() (unsigned i, unsigned j, unsigned k) const {
        return _array[_index(i,j,k)];
    }

private:

    // Map (i,j,k) to an offset into the block, asserting the bounds.
    unsigned _index (unsigned i, unsigned j, unsigned k) const {
        assert (i < _dim[0]);
        assert (j < _dim[1]);
        assert (k < _dim[2]);
        unsigned index = (i * _dim[1] + j) * _dim[2] + k;
        assert (index < _n);
        return index;
    }
    // Allocate a fresh block for the given shape and record the shape.
    void _alloc (unsigned d0, unsigned d1, unsigned d2) {
        _n = d0 * d1 * d2;
        _array = new Elem [_n];
        _dim[0] = d0;
        _dim[1] = d1;
        _dim[2] = d2;
    }
    // Release the owned block.
    void _delete () {
        assert (_array != NULL);
        delete [] _array;
        _array = NULL;
    }
    // Element-wise copy from an array of identical shape.
    void _copy (const Array3D& that) {
        assert (_n == that._n);
        for (unsigned i = 0; i < _n; i++) {
            _array[i] = that._array[i];
        }
    }

    unsigned _n;        // total number of elements
    Elem* _array;       // owned block of _n elements
    unsigned _dim[3];   // extent of each dimension

}; // class Array3D
class Dataset (object):
    """
    BSDS dataset wrapper

    Given the path to the root of the BSDS dataset, this class provides
    methods for loading images and ground truths.

    Attributes:

    ext - the image file extension, including the dot (e.g. `'.jpg'`)
    data_path - the root path of the dataset
    images_path - the path of the images directory within the data dir
    gt_path - the path of the groundTruth directory within the data dir
    train_sample_names - a list of names of training images
    val_sample_names - a list of names of validation images
    test_sample_names - a list of names of test images
    """
    def __init__(self, data_path, ext):
        """
        Constructor

        :param data_path: the path to the root of the dataset
        :param ext: the image file extension, including the dot; it is
        matched case-insensitively when listing the samples
        """
        self.ext = ext
        self.data_path = data_path
        self.images_path = os.path.join(self.data_path, 'images')
        self.gt_path = os.path.join(self.data_path, 'groundTruth')

        self.train_sample_names = self._sample_names(self.images_path, 'train')
        self.val_sample_names = self._sample_names(self.images_path, 'val')
        self.test_sample_names = self._sample_names(self.images_path, 'test')

    def _sample_names(self, dir, subset):
        """
        List the sample names found in one subset directory.

        :param dir: the images directory
        :param subset: the name of the subset sub-directory (e.g. `'train'`)
        :return: a sorted list of `subset/name` sample names, one for each
        file in `dir/subset` whose extension matches `self.ext`; an empty
        list if the sub-directory does not exist
        """
        path = os.path.join(dir, subset)
        if not os.path.isdir(path):
            return []

        wanted_ext = self.ext.lower()
        names = []
        # os.listdir returns entries in arbitrary order; sort them so the
        # sample order is reproducible between runs and machines.
        for filename in sorted(os.listdir(path)):
            name, ext = os.path.splitext(filename)
            if ext.lower() == wanted_ext:
                names.append(os.path.join(subset, name))
        return names

    def read_image(self, name):
        """
        Load the image identified by the sample name (you can get the names
        from the `train_sample_names`, `val_sample_names` and
        `test_sample_names` attributes)
        :param name: the sample name
        :return: the image as a floating point array, as converted by
        `img_as_float`
        """
        path = os.path.join(self.images_path, name + self.ext)
        return img_as_float(imread(path))

    def get_image_shape(self, name):
        """
        Get the shape of the image identified by the sample name (you can
        get the names from the `train_sample_names`, `val_sample_names` and
        `test_sample_names` attributes)
        :param name: the sample name
        :return: a tuple of the form `(height, width, channels)`; the
        channel count is always reported as 3
        """
        path = os.path.join(self.images_path, name + self.ext)
        # Only the header is needed; the context manager closes the file
        # handle straight away instead of leaving it to the garbage
        # collector.
        with Image.open(path) as img:
            return img.height, img.width, 3

    def ground_truth_mat(self, name):
        """
        Load the ground truth Matlab file identified by the sample name
        (you can get the names from the `train_sample_names`,
        `val_sample_names` and `test_sample_names` attributes)
        :param name: the sample name
        :return: the `groundTruth` entry from the Matlab file
        """
        path = os.path.join(self.gt_path, name + '.mat')
        return self.load_ground_truth_mat(path)

    def segmentations(self, name):
        """
        Load the ground truth segmentations identified by the sample name
        (you can get the names from the `train_sample_names`,
        `val_sample_names` and `test_sample_names` attributes)
        :param name: the sample name
        :return: a list of int32 arrays, each of which contains a
        segmentation ground truth
        """
        path = os.path.join(self.gt_path, name + '.mat')
        return self.load_segmentations(path)

    def boundaries(self, name):
        """
        Load the ground truth boundaries identified by the sample name
        (you can get the names from the `train_sample_names`,
        `val_sample_names` and `test_sample_names` attributes)
        :param name: the sample name
        :return: a list of arrays, each of which contains a
        boundary ground truth
        """
        path = os.path.join(self.gt_path, name + '.mat')
        return self.load_boundaries(path)

    @staticmethod
    def load_ground_truth_mat(path):
        """
        Load the ground truth Matlab file at the specified path
        and return the `groundTruth` entry.
        :param path: path
        :return: the 'groundTruth' entry from the Matlab file
        """
        gt = loadmat(path)
        return gt['groundTruth']

    @staticmethod
    def load_segmentations(path):
        """
        Load the ground truth segmentations from the Matlab file
        at the specified path.
        :param path: path
        :return: a list of int32 arrays, each of which contains a
        segmentation ground truth
        """
        gt = Dataset.load_ground_truth_mat(path)
        # One ground truth per column of the `groundTruth` entry.
        num_gts = gt.shape[1]
        return [gt[0,i]['Segmentation'][0,0].astype(np.int32) for i in range(num_gts)]

    @staticmethod
    def load_boundaries(path):
        """
        Load the ground truth boundaries from the Matlab file
        at the specified path.
        :param path: path
        :return: a list of arrays, each of which contains a
        boundary ground truth
        """
        gt = Dataset.load_ground_truth_mat(path)
        # One ground truth per column of the `groundTruth` entry.
        num_gts = gt.shape[1]
        return [gt[0,i]['Boundaries'][0,0] for i in range(num_gts)]
class BSDSHEDAugDataset (object):
    """
    BSDS HED augmented dataset wrapper

    Given the path to the root of the augmented dataset, this class
    provides methods for loading augmented images and their ground truths.

    The augmented dataset can be downloaded from:

    http://vcl.ucsd.edu/hed/HED-BSDS.tar

    See their repo for more information:

    http://github.com/s9xie/hed

    Attributes:

    bsds_dataset - standard BSDS dataset
    root_path - the root path of the dataset
    sample_name_to_fold - maps each sample name to the fold directory
        (`'train'` or `'test'`) that holds it in the augmented dataset
    """

    AUG_SCALES = [
        '', '_scale_0.5', '_scale_1.5'
    ]

    AUG_ROTS = [
        '0.0', '22.5', '45.0', '67.5', '90.0', '112.5', '135.0', '157.5', '180.0', '202.5', '225.0', '247.5',
        '270.0', '292.5', '315.0', '337.5'
    ]

    AUG_FLIPS = [
        '1_0', '1_1'
    ]

    # Every (scale, rotation, flip) combination: 3 * 16 * 2 = 96 entries.
    # Explicit nested loops are used because a comprehension in a class
    # body cannot see the other class-level names from its inner clauses.
    ALL_AUGS = []
    for _scale in AUG_SCALES:
        for _rot in AUG_ROTS:
            for _flip in AUG_FLIPS:
                ALL_AUGS.append((_scale, _rot, _flip))
    # Do not leave the loop variables behind as class attributes.
    del _scale, _rot, _flip

    def __init__(self, bsds_dataset, root_path):
        """
        Constructor

        :param bsds_dataset: the standard BSDS dataset
        :param root_path: the path to the root of the augmented dataset
        """
        self.bsds_dataset = bsds_dataset
        self.root_path = root_path

        # Validation samples are looked up in the 'train' fold as well;
        # only the test samples use the 'test' fold.
        self.sample_name_to_fold = {}
        for name in bsds_dataset.train_sample_names:
            self.sample_name_to_fold[name] = 'train'
        for name in bsds_dataset.val_sample_names:
            self.sample_name_to_fold[name] = 'train'
        for name in bsds_dataset.test_sample_names:
            self.sample_name_to_fold[name] = 'test'

    def _data_path(self, data_type, scale, rot, flip, name, ext):
        """
        Build the path of one augmented file.

        :param data_type: `'data'` for images or `'gt'` for ground truths
        :param scale: augmentation scale, one of `AUG_SCALES`
        :param rot: augmentation rotation, one of `AUG_ROTS`
        :param flip: augmentation flip, one of `AUG_FLIPS`
        :param name: the sample name; only its last path component is
        used in the file name
        :param ext: the file extension, including the dot
        :return: the path of the file within `root_path`
        :raises KeyError: if `name` is not a known sample name
        :raises ValueError: if `data_type`, `scale`, `rot` or `flip` is
        not one of the permitted values
        """
        fold = self.sample_name_to_fold[name]
        if data_type not in {'data', 'gt'}:
            raise ValueError("data_type should be 'data' or 'gt', not {}".format(data_type))
        if scale not in self.AUG_SCALES:
            raise ValueError("scale should be one of {}, not {}".format(self.AUG_SCALES, scale))
        if rot not in self.AUG_ROTS:
            raise ValueError("rot should be one of {}, not {}".format(self.AUG_ROTS, rot))
        if flip not in self.AUG_FLIPS:
            raise ValueError("flip should be one of {}, not {}".format(self.AUG_FLIPS, flip))
        return os.path.join(self.root_path, fold, 'aug_{}{}'.format(data_type, scale), '{}_{}'.format(rot, flip),
                            '{}{}'.format(os.path.split(name)[1], ext))

    @classmethod
    def augment_names(cls, names):
        """
        Add augmentation parameters to the supplied list of names. Converts a
        sequence of names into a sequence of tuples that provide the name along
        with augmentation parameters. Each name is combined with all possible
        combinations of augmentation parameters. By default, there are 96
        possible augmentations, so the resulting list will be 96x the length
        of `names`.

        The tuples returned can be used as parameters for the `read_image`,
        `get_image_shape` and `mean_boundaries` methods.

        :param names: a sequence of names
        :return: list of `(name, scale_aug, rotate_aug, flip_aug)` tuples
        """
        return [(n, s, r, f) for n in names for (s, r, f) in cls.ALL_AUGS]

    def read_image(self, name, scale, rot, flip):
        """
        Load the image identified by the sample name and augmentation
        parameters.
        The sample name `name` should come from the `train_sample_names`,
        `val_sample_names` and `test_sample_names` attributes of a
        `Dataset` instance.
        The `scale`, `rot` and `flip` augmentation parameters should
        come from `AUG_SCALES`, `AUG_ROTS` and `AUG_FLIPS` attributes
        of the `BSDSHEDAugDataset` class
        :param name: the sample name
        :param scale: augmentation scale
        :param rot: augmentation rotation
        :param flip: augmentation flip
        :return: the image as a float32 array, as converted by
        `img_as_float`
        """
        path = self._data_path('data', scale, rot, flip, name, '.jpg')
        return img_as_float(imread(path)).astype(np.float32)

    def get_image_shape(self, name, scale, rot, flip):
        """
        Get the shape of the image identified by the sample name
        and augmentation parameters.
        The sample name `name` should come from the `train_sample_names`,
        `val_sample_names` and `test_sample_names` attributes of a
        `Dataset` instance.
        The `scale`, `rot` and `flip` augmentation parameters should
        come from `AUG_SCALES`, `AUG_ROTS` and `AUG_FLIPS` attributes
        of the `BSDSHEDAugDataset` class
        :param name: the sample name
        :param scale: augmentation scale
        :param rot: augmentation rotation
        :param flip: augmentation flip
        :return: a tuple of the form `(height, width, channels)`; the
        channel count is always reported as 3
        """
        path = self._data_path('data', scale, rot, flip, name, '.jpg')
        # Only the header is needed; the context manager closes the file
        # handle straight away instead of leaving it to the garbage
        # collector.
        with Image.open(path) as img:
            return img.height, img.width, 3

    def mean_boundaries(self, name, scale, rot, flip):
        """
        Load the ground truth boundary image identified by the sample name
        and augmentation parameters.

        See the `read_image` method for more information on the sample
        name and augmentation parameters

        :param name: the sample name
        :param scale: augmentation scale
        :param rot: augmentation rotation
        :param flip: augmentation flip
        :return: a single float32 array holding the ground truth image
        converted to grey scale
        """
        path = self._data_path('gt', scale, rot, flip, name, '.png')
        return self.load_mean_boundaries(path)

    @staticmethod
    def load_mean_boundaries(path):
        """
        Load the ground truth boundary image at the specified path.
        :param path: path of the image file
        :return: a single float32 array holding the image converted to
        grey scale with `rgb2gray`
        """
        return rgb2gray(img_as_float(imread(path))).astype(np.float32)
def evaluate_boundaries(predicted_boundaries, gt_boundaries,
                        thresholds=99, max_dist=0.0075, apply_thinning=True,
                        progress=None):
    """
    Evaluate the accuracy of a predicted boundary over a range of thresholds

    :param predicted_boundaries: the predicted boundaries as a (H,W)
    floating point array where each pixel represents the strength of the
    predicted boundary
    :param gt_boundaries: a list of ground truth boundaries, as returned
    by the `load_boundaries` or `boundaries` methods
    :param thresholds: either an integer (Python or NumPy) specifying the
    number of thresholds to use, or a 1D array, list or tuple specifying
    the thresholds. An integer `n` gives `n` evenly spaced thresholds from
    `1 / (n + 1)` to `1 - 1 / (n + 1)`.
    :param max_dist: (default=0.0075) maximum distance parameter
    used for determining pixel matches. This value is multiplied by the
    length of the diagonal of the image to get the threshold used
    for matching pixels.
    :param apply_thinning: (default=True) if True, apply morphological
    thinning to the predicted boundaries before evaluation
    :param progress: a function that can be used to monitor progress;
    use `tqdm.tqdm` or `tqdm.tqdm_notebook` from the `tqdm` package
    to generate a progress bar.
    :return: tuple `(count_r, sum_r, count_p, sum_p, thresholds)` where each
    of the first four entries are arrays that can be used to compute
    recall and precision at each threshold with:
    ```
    recall = count_r / (sum_r + (sum_r == 0))
    precision = count_p / (sum_p + (sum_p == 0))
    ```
    The thresholds are also returned.
    :raises ValueError: if `thresholds` is not an integer or a
    one-dimensional sequence of thresholds
    """
    if progress is None:
        progress = lambda x, *args, **kwargs: x

    # Handle thresholds
    if isinstance(thresholds, (int, np.integer)):
        # NumPy integers (e.g. the result of `len()` arithmetic on arrays)
        # are accepted alongside plain Python ints.
        n_thresh = int(thresholds)
        thresholds = np.linspace(1.0 / (n_thresh + 1),
                                 1.0 - 1.0 / (n_thresh + 1), n_thresh)
    elif isinstance(thresholds, (np.ndarray, list, tuple)):
        # `np.asarray` returns an ndarray argument unchanged.
        thresholds = np.asarray(thresholds)
        if thresholds.ndim != 1:
            raise ValueError('thresholds array should have 1 dimension, '
                             'not {}'.format(thresholds.ndim))
    else:
        raise ValueError('thresholds should be an int or a 1D array of '
                         'thresholds, not a {}'.format(type(thresholds)))

    sum_p = np.zeros(thresholds.shape)
    count_p = np.zeros(thresholds.shape)
    sum_r = np.zeros(thresholds.shape)
    count_r = np.zeros(thresholds.shape)

    for i_t, thresh in enumerate(progress(list(thresholds))):
        # Binarise the prediction at this threshold
        predicted_boundaries_bin = predicted_boundaries >= thresh

        # Marks predicted pixels matched by at least one ground truth
        acc_prec = np.zeros(predicted_boundaries_bin.shape, dtype=bool)

        if apply_thinning:
            predicted_boundaries_bin = thin.binary_thin(
                predicted_boundaries_bin)

        for gt in gt_boundaries:

            match1, match2, cost, oc = correspond_pixels.correspond_pixels(
                predicted_boundaries_bin, gt, max_dist=max_dist
            )
            match1 = match1 > 0
            match2 = match2 > 0
            # Precision accumulator
            acc_prec = acc_prec | match1
            # Recall
            sum_r[i_t] += gt.sum()
            count_r[i_t] += match2.sum()

        # Precision
        sum_p[i_t] = predicted_boundaries_bin.sum()
        count_p[i_t] = acc_prec.sum()

    return count_r, sum_r, count_p, sum_p, thresholds
def compute_rec_prec_f1(count_r, sum_r, count_p, sum_p):
    """
    Compute recall, precision and F1-score given `count_r`, `sum_r`,
    `count_p` and `sum_p`; see `evaluate_boundaries`.
    :param count_r: number of matched ground truth pixels
    :param sum_r: total number of ground truth pixels
    :param count_p: number of matched predicted pixels
    :param sum_p: total number of predicted pixels
    :return: tuple `(recall, precision, f1)`; a score whose denominator
    is zero is reported as zero
    """
    # `x + (x == 0)` turns a zero denominator into one, so empty inputs
    # give a score of zero instead of a division by zero.
    rec = count_r / (sum_r + (sum_r == 0))
    prec = count_p / (sum_p + (sum_p == 0))
    f1_denom = (prec + rec + ((prec+rec) == 0))
    f1 = 2.0 * prec * rec / f1_denom
    return rec, prec, f1


SampleResult = namedtuple('SampleResult', ['sample_name', 'threshold',
                                           'recall', 'precision', 'f1'])
ThresholdResult = namedtuple('ThresholdResult', ['threshold', 'recall',
                                                 'precision', 'f1'])
OverallResult = namedtuple('OverallResult', ['threshold', 'recall',
                                             'precision', 'f1',
                                             'best_recall', 'best_precision',
                                             'best_f1', 'area_pr'])


def pr_evaluation(thresholds, sample_names, load_gt_boundaries, load_pred,
                  progress=None):
    """
    Perform an evaluation of predictions against ground truths for an image
    set over a given set of thresholds.

    :param thresholds: either an integer (Python or NumPy) specifying the
    number of thresholds to use, or a 1D array or sequence specifying the
    thresholds. An integer `n` gives `n` evenly spaced thresholds from
    `1 / (n + 1)` to `1 - 1 / (n + 1)`.
    :param sample_names: the names of the samples that are to be evaluated;
    may be empty, in which case all scores are zero
    :param load_gt_boundaries: a callable that loads the ground truth for a
    named sample; of the form `load_gt_boundaries(sample_name) -> gt`
    where `gt` is passed to `evaluate_boundaries` as its list of ground
    truth boundaries
    :param load_pred: a callable that loads the prediction for a
    named sample; of the form `load_pred(sample_name) -> pred`
    where `pred` is a 2D NumPy array
    :param progress: default=None a callable -- such as `tqdm` -- that
    accepts an iterator over the sample names in order to track progress
    :return: `(sample_results, threshold_results, overall_result)`
    where `sample_results` is a list of `SampleResult` named tuples with one
    for each sample, `threshold_results` is a list of `ThresholdResult`
    named tuples, with one for each threshold and `overall_result`
    is an `OverallResult` named tuple giving the over all results. The
    attributes in these structures will now be described:
    :raises ValueError: if no thresholds are given or `thresholds` is an
    array with more than one dimension

    `SampleResult`:
    - `sample_name`: the name identifying the sample to which this result
        applies
    - `threshold`: the threshold at which the best F1-score was obtained for
        the given sample
    - `recall`: the recall score obtained at the best threshold
    - `precision`: the precision score obtained at the best threshold
    - `f1`: the F1-score obtained at the best threshold

    `ThresholdResult`:
    - `threshold`: the threshold value to which this result applies
    - `recall`: the recall score over all samples at `threshold`
    - `precision`: the precision score over all samples at `threshold`
    - `f1`: the F1-score over all samples at `threshold`

    `OverallResult`:
    - `threshold`: the threshold at which the best F1-score over
        all samples is obtained
    - `recall`: the recall score over all samples at `threshold`
    - `precision`: the precision score over all samples at `threshold`
    - `f1`: the F1-score over all samples at `threshold`
    - `best_recall`: the recall score over all samples at the best
        threshold *for each individual sample*
    - `best_precision`: the precision score over all samples at the
        best threshold *for each individual sample*
    - `best_f1`: the F1-score over all samples at the best threshold
        *for each individual sample*
    - `area_pr`: the area under the precision-recall curve
    """
    if progress is None:
        progress = lambda x, *args: x

    # Resolve the thresholds before looking at any sample, so that they
    # are known even when `sample_names` is empty. The spacing for an
    # integer count is the one `evaluate_boundaries` uses.
    if isinstance(thresholds, (int, np.integer)):
        n_thresh = int(thresholds)
        if n_thresh < 1:
            raise ValueError('at least one threshold is required, '
                             'not {}'.format(n_thresh))
        used_thresholds = np.linspace(1.0 / (n_thresh + 1),
                                      1.0 - 1.0 / (n_thresh + 1), n_thresh)
    else:
        used_thresholds = np.asarray(thresholds)
        if used_thresholds.ndim != 1:
            raise ValueError('thresholds array should have 1 dimension, '
                             'not {}'.format(used_thresholds.ndim))
        n_thresh = used_thresholds.shape[0]
        if n_thresh < 1:
            raise ValueError('at least one threshold is required')

    # Per-threshold counts accumulated over all samples
    count_r_overall = np.zeros((n_thresh,))
    sum_r_overall = np.zeros((n_thresh,))
    count_p_overall = np.zeros((n_thresh,))
    sum_p_overall = np.zeros((n_thresh,))

    # Counts accumulated at each sample's own best threshold
    count_r_best = 0
    sum_r_best = 0
    count_p_best = 0
    sum_p_best = 0

    sample_results = []
    for sample_name in progress(sample_names):
        # Load the prediction and the ground truth
        pred = load_pred(sample_name)
        gt_b = load_gt_boundaries(sample_name)

        # Evaluate predictions; the thresholds are passed as the resolved
        # array so every sample is scored at exactly the same values.
        count_r, sum_r, count_p, sum_p, _ = \
            evaluate_boundaries(pred, gt_b, thresholds=used_thresholds,
                                apply_thinning=True)

        count_r_overall += count_r
        sum_r_overall += sum_r
        count_p_overall += count_p
        sum_p_overall += sum_p

        # Compute precision, recall and F1
        rec, prec, f1 = compute_rec_prec_f1(count_r, sum_r, count_p, sum_p)

        # Find best F1 score
        best_ndx = np.argmax(f1)

        count_r_best += count_r[best_ndx]
        sum_r_best += sum_r[best_ndx]
        count_p_best += count_p[best_ndx]
        sum_p_best += sum_p[best_ndx]

        sample_results.append(SampleResult(sample_name,
                                           used_thresholds[best_ndx],
                                           rec[best_ndx], prec[best_ndx],
                                           f1[best_ndx]))

    # Compute overall precision, recall and F1
    rec_overall, prec_overall, f1_overall = compute_rec_prec_f1(
        count_r_overall, sum_r_overall, count_p_overall, sum_p_overall)

    # Find best F1 score
    best_i_ovr = np.argmax(f1_overall)

    threshold_results = [
        ThresholdResult(used_thresholds[thresh_i], rec_overall[thresh_i],
                        prec_overall[thresh_i], f1_overall[thresh_i])
        for thresh_i in range(n_thresh)
    ]

    # Area under the precision-recall curve: interpolate precision at
    # recall values 0.00, 0.01, ..., 0.99 and sum the strips. Recall
    # values outside the observed range contribute zero.
    rec_unique, rec_unique_ndx = np.unique(rec_overall, return_index=True)
    prec_unique = prec_overall[rec_unique_ndx]
    if rec_unique.shape[0] > 1:
        prec_interp = np.interp(np.arange(0, 1, 0.01), rec_unique,
                                prec_unique, left=0.0, right=0.0)
        area_pr = prec_interp.sum() * 0.01
    else:
        area_pr = 0.0

    rec_best, prec_best, f1_best = compute_rec_prec_f1(
        float(count_r_best), float(sum_r_best), float(count_p_best),
        float(sum_p_best)
    )

    overall_result = OverallResult(used_thresholds[best_i_ovr],
                                   rec_overall[best_i_ovr],
                                   prec_overall[best_i_ovr],
                                   f1_overall[best_i_ovr], rec_best,
                                   prec_best, f1_best, area_pr)

    return sample_results, threshold_results, overall_result
51 | Array2D match1 (width,height); 52 | Array2D match2 (width,height); 53 | for (int x = 0; x < width; x++) { 54 | for (int y = 0; y < height; y++) { 55 | match1(x,y) = Pixel(-1,-1); 56 | match2(x,y) = Pixel(-1,-1); 57 | } 58 | } 59 | 60 | // Radius of search window. 61 | const int r = (int) ceil (maxDist); 62 | 63 | // Figure out which nodes are matchable, i.e. within maxDist 64 | // of another node. 65 | Array2D matchable1 (width,height); 66 | Array2D matchable2 (width,height); 67 | matchable1.init(false); 68 | matchable2.init(false); 69 | for (int y1 = 0; y1 < height; y1++) { 70 | for (int x1 = 0; x1 < width; x1++) { 71 | if (!bmap1(y1,x1)) { continue; } 72 | for (int v = -r; v <= r; v++) { 73 | for (int u = -r; u <= r; u++) { 74 | const double d2 = u*u + v*v; 75 | if (d2 > maxDist*maxDist) { continue; } 76 | const int x2 = x1 + u; 77 | const int y2 = y1 + v; 78 | if (x2 < 0 || x2 >= width) { continue; } 79 | if (y2 < 0 || y2 >= height) { continue; } 80 | if (!bmap2(y2,x2)) { continue; } 81 | matchable1(x1,y1) = true; 82 | matchable2(x2,y2) = true; 83 | } 84 | } 85 | } 86 | } 87 | 88 | // Count the number of nodes on each side of the match. 89 | // Construct nodeID->pixel and pixel->nodeID maps. 90 | // Node IDs range from [0,n1) and [0,n2). 91 | int n1=0, n2=0; 92 | std::vector nodeToPix1; 93 | std::vector nodeToPix2; 94 | Array2D pixToNode1 (width,height); 95 | Array2D pixToNode2 (width,height); 96 | for (int x = 0; x < width; x++) { 97 | for (int y = 0; y < height; y++) { 98 | pixToNode1(x,y) = -1; 99 | pixToNode2(x,y) = -1; 100 | Pixel pix (x,y); 101 | if (matchable1(x,y)) { 102 | pixToNode1(x,y) = n1; 103 | nodeToPix1.push_back(pix); 104 | n1++; 105 | } 106 | if (matchable2(x,y)) { 107 | pixToNode2(x,y) = n2; 108 | nodeToPix2.push_back(pix); 109 | n2++; 110 | } 111 | } 112 | } 113 | 114 | // Construct the list of edges between pixels within maxDist. 
115 | std::vector edges; 116 | for (int x1 = 0; x1 < width; x1++) { 117 | for (int y1 = 0; y1 < height; y1++) { 118 | if (!matchable1(x1,y1)) { continue; } 119 | for (int u = -r; u <= r; u++) { 120 | for (int v = -r; v <= r; v++) { 121 | const double d2 = u*u + v*v; 122 | if (d2 > maxDist*maxDist) { continue; } 123 | const int x2 = x1 + u; 124 | const int y2 = y1 + v; 125 | if (x2 < 0 || x2 >= width) { continue; } 126 | if (y2 < 0 || y2 >= height) { continue; } 127 | if (!matchable2(x2,y2)) { continue; } 128 | Edge e; 129 | e.i = pixToNode1(x1,y1); 130 | e.j = pixToNode2(x2,y2); 131 | e.w = sqrt(d2); 132 | assert (e.i >= 0 && e.i < n1); 133 | assert (e.j >= 0 && e.j < n2); 134 | assert (e.w < outlierCost); 135 | edges.push_back(e); 136 | } 137 | } 138 | } 139 | } 140 | 141 | // The cardinality of the match is n. 142 | const int n = n1 + n2; 143 | const int nmin = std::min(n1,n2); 144 | const int nmax = std::max(n1,n2); 145 | 146 | // Compute the degree of various outlier connections. 147 | const int d1 = std::max(0,std::min(degree,n1-1)); // from map1 148 | const int d2 = std::max(0,std::min(degree,n2-1)); // from map2 149 | const int d3 = std::min(degree,std::min(n1,n2)); // between outliers 150 | const int dmax = std::max(d1,std::max(d2,d3)); 151 | 152 | assert (n1 == 0 || (d1 >= 0 && d1 < n1)); 153 | assert (n2 == 0 || (d2 >= 0 && d2 < n2)); 154 | assert (d3 >= 0 && d3 <= nmin); 155 | 156 | // Count the number of edges. 157 | int m = 0; 158 | m += edges.size(); // real connections 159 | m += d1 * n1; // outlier connections 160 | m += d2 * n2; // outlier connections 161 | m += d3 * nmax; // outlier-outlier connections 162 | m += n; // high-cost perfect match overlay 163 | 164 | // If the graph is empty, then there's nothing to do. 165 | if (m == 0) { 166 | return 0; 167 | } 168 | 169 | // Weight of outlier connections. 170 | const int ow = (int) ceil (outlierCost * multiplier); 171 | 172 | // Scratch array for outlier edges. 
173 | Array1D outliers (dmax); 174 | 175 | // Construct the input graph for the assignment problem. 176 | Array2D igraph (m,3); 177 | int count = 0; 178 | // real edges 179 | for (int a = 0; a < (int)edges.size(); a++) { 180 | int i = edges[a].i; 181 | int j = edges[a].j; 182 | assert (i >= 0 && i < n1); 183 | assert (j >= 0 && j < n2); 184 | igraph(count,0) = i; 185 | igraph(count,1) = j; 186 | igraph(count,2) = (int) rint (edges[a].w * multiplier); 187 | count++; 188 | } 189 | // outliers edges for map1, exclude diagonal 190 | for (int i = 0; i < n1; i++) { 191 | kOfN(d1,n1-1,outliers.data()); 192 | for (int a = 0; a < d1; a++) { 193 | int j = outliers(a); 194 | if (j >= i) { j++; } 195 | assert (i != j); 196 | assert (j >= 0 && j < n1); 197 | igraph(count,0) = i; 198 | igraph(count,1) = n2 + j; 199 | igraph(count,2) = ow; 200 | count++; 201 | } 202 | } 203 | // outliers edges for map2, exclude diagonal 204 | for (int j = 0; j < n2; j++) { 205 | kOfN(d2,n2-1,outliers.data()); 206 | for (int a = 0; a < d2; a++) { 207 | int i = outliers(a); 208 | if (i >= j) { i++; } 209 | assert (i != j); 210 | assert (i >= 0 && i < n2); 211 | igraph(count,0) = n1 + i; 212 | igraph(count,1) = j; 213 | igraph(count,2) = ow; 214 | count++; 215 | } 216 | } 217 | // outlier-to-outlier edges 218 | for (int i = 0; i < nmax; i++) { 219 | kOfN(d3,nmin,outliers.data()); 220 | for (int a = 0; a < d3; a++) { 221 | const int j = outliers(a); 222 | assert (j >= 0 && j < nmin); 223 | if (n1 < n2) { 224 | assert (i >= 0 && i < n2); 225 | assert (j >= 0 && j < n1); 226 | igraph(count,0) = n1 + i; 227 | igraph(count,1) = n2 + j; 228 | } else { 229 | assert (i >= 0 && i < n1); 230 | assert (j >= 0 && j < n2); 231 | igraph(count,0) = n1 + j; 232 | igraph(count,1) = n2 + i; 233 | } 234 | igraph(count,2) = ow; 235 | count++; 236 | } 237 | } 238 | // perfect match overlay (diagonal) 239 | for (int i = 0; i < n1; i++) { 240 | igraph(count,0) = i; 241 | igraph(count,1) = n2 + i; 242 | igraph(count,2) = 
ow * multiplier; 243 | count++; 244 | } 245 | for (int i = 0; i < n2; i++) { 246 | igraph(count,0) = n1 + i; 247 | igraph(count,1) = i; 248 | igraph(count,2) = ow * multiplier; 249 | count++; 250 | } 251 | assert (count == m); 252 | 253 | // Check all the edges, and set the values up for CSA. 254 | for (int i = 0; i < m; i++) { 255 | assert(igraph(i,0) >= 0 && igraph(i,0) < n); 256 | assert(igraph(i,1) >= 0 && igraph(i,1) < n); 257 | igraph(i,0) += 1; 258 | igraph(i,1) += 1+n; 259 | } 260 | 261 | // Solve the assignment problem. 262 | CSA csa(2*n,m,igraph.data()); 263 | assert(csa.edges()==n); 264 | 265 | Array2D ograph (n,3); 266 | for (int i = 0; i < n; i++) { 267 | int a,b,c; 268 | csa.edge(i,a,b,c); 269 | ograph(i,0)=a-1; ograph(i,1)=b-1-n; ograph(i,2)=c; 270 | } 271 | 272 | // Check the solution. 273 | // Count the number of high-cost edges from the perfect match 274 | // overlay that were used in the match. 275 | int overlayCount = 0; 276 | for (int a = 0; a < n; a++) { 277 | const int i = ograph(a,0); 278 | const int j = ograph(a,1); 279 | const int c = ograph(a,2); 280 | assert (i >= 0 && i < n); 281 | assert (j >= 0 && j < n); 282 | assert (c >= 0); 283 | // edge from high-cost perfect match overlay 284 | if (c == ow * multiplier) { overlayCount++; } 285 | // skip outlier edges 286 | if (i >= n1) { continue; } 287 | if (j >= n2) { continue; } 288 | // for edges between real nodes, check the edge weight 289 | const Pixel pix1 = nodeToPix1[i]; 290 | const Pixel pix2 = nodeToPix2[j]; 291 | const int dx = pix1.x - pix2.x; 292 | const int dy = pix1.y - pix2.y; 293 | const int w = (int) rint (sqrt(dx*dx+dy*dy)*multiplier); 294 | assert (w == c); 295 | } 296 | 297 | // Print a warning if any of the edges from the perfect match overlay 298 | // were used. This should happen rarely. If it happens frequently, 299 | // then the outlier connectivity should be increased. 
300 | if (overlayCount > 5) { 301 | fprintf (stderr, "%s:%d: WARNING: The match includes %d " 302 | "outlier(s) from the perfect match overlay.\n", 303 | __FILE__, __LINE__, overlayCount); 304 | } 305 | 306 | // Compute match arrays. 307 | for (int a = 0; a < n; a++) { 308 | // node ids 309 | const int i = ograph(a,0); 310 | const int j = ograph(a,1); 311 | // skip outlier edges 312 | if (i >= n1) { continue; } 313 | if (j >= n2) { continue; } 314 | // map node ids to pixels 315 | const Pixel pix1 = nodeToPix1[i]; 316 | const Pixel pix2 = nodeToPix2[j]; 317 | // record edges 318 | match1(pix1.x,pix1.y) = pix2; 319 | match2(pix2.x,pix2.y) = pix1; 320 | } 321 | for (int x = 0; x < width; x++) { 322 | for (int y = 0; y < height; y++) { 323 | if (bmap1(y,x)) { 324 | if (match1(x,y) != Pixel(-1,-1)) { 325 | m1(y,x) = match1(x,y).x*height + match1(x,y).y + 1; 326 | } 327 | } 328 | if (bmap2(y,x)) { 329 | if (match2(x,y) != Pixel(-1,-1)) { 330 | m2(y,x) = match2(x,y).x*height + match2(x,y).y + 1; 331 | } 332 | } 333 | } 334 | } 335 | 336 | // Compute the match cost. 337 | double cost = 0; 338 | for (int x = 0; x < width; x++) { 339 | for (int y = 0; y < height; y++) { 340 | if (bmap1(y,x)) { 341 | if (match1(x,y) == Pixel(-1,-1)) { 342 | cost += outlierCost; 343 | } else { 344 | const int dx = x - match1(x,y).x; 345 | const int dy = y - match1(x,y).y; 346 | cost += 0.5 * sqrt (dx*dx + dy*dy); 347 | } 348 | } 349 | if (bmap2(y,x)) { 350 | if (match2(x,y) == Pixel(-1,-1)) { 351 | cost += outlierCost; 352 | } else { 353 | const int dx = x - match2(x,y).x; 354 | const int dy = y - match2(x,y).y; 355 | cost += 0.5 * sqrt (dx*dx + dy*dy); 356 | } 357 | } 358 | } 359 | } 360 | 361 | // Return the match cost. 
@contextlib.contextmanager
def tqdm_joblib(total: Optional[int] = None, **kwargs):
    """
    Context manager that drives a tqdm progress bar from joblib.

    While the context is active, ``joblib.parallel.BatchCompletionCallBack``
    is replaced by a subclass that advances the bar by the size of each
    completed batch. On exit the original callback class is restored and
    the bar is closed, even if the body raised.

    :param total: total number of tasks, forwarded to ``tqdm``
    :param kwargs: extra keyword arguments forwarded to ``tqdm``
    :return: yields the ``tqdm`` progress bar instance
    """
    pbar = tqdm(total=total, miniters=1, smoothing=0, **kwargs)

    class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
        def __call__(self, *args, **kwargs):
            # Advance the bar first, then let joblib do its own bookkeeping.
            pbar.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    old_batch_callback = joblib.parallel.BatchCompletionCallBack
    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback

    try:
        yield pbar
    finally:
        joblib.parallel.BatchCompletionCallBack = old_batch_callback
        pbar.close()


def evaluate_sample(sample_name, load_pred: Callable,
                    load_gt_boundaries: Callable, thresholds, max_dist):
    """
    Evaluate a single sample over all thresholds (one unit of parallel work).

    :param sample_name: name identifying the sample
    :param load_pred: callable of the form `load_pred(sample_name) -> pred`
    :param load_gt_boundaries: callable of the form
        `load_gt_boundaries(sample_name) -> gt`
    :param thresholds: either an integer specifying the number of thresholds
        to use or a 1D array specifying the thresholds
    :param max_dist: maximum distance parameter forwarded to
        `evaluate_boundaries`
    :return: tuple `(sample_result, count_r, sum_r, count_p, sum_p,
        best_ndx, used_thresholds)` where `sample_result` is a `SampleResult`
        taken at the threshold index `best_ndx` with the highest F1-score
    """
    pred = load_pred(sample_name)
    gt_b = load_gt_boundaries(sample_name)

    count_r, sum_r, count_p, sum_p, used_thresholds = evaluate_boundaries(
        pred, gt_b, thresholds=thresholds, apply_thinning=True,
        max_dist=max_dist)

    rec, prec, f1 = compute_rec_prec_f1(count_r, sum_r, count_p, sum_p)

    best_ndx = np.argmax(f1)

    sample_result = SampleResult(
        sample_name,
        used_thresholds[best_ndx], rec[best_ndx], prec[best_ndx],
        f1[best_ndx])

    return sample_result, count_r, sum_r, \
        count_p, sum_p, best_ndx, used_thresholds


def evaluate_boundaries_bin(predicted_boundaries_bin, gt_boundaries,
                            max_dist=0.0075, apply_thinning=True):
    """
    Evaluate the accuracy of a predicted boundary.

    :param predicted_boundaries_bin: the predicted boundaries as a (H,W)
        binary array
    :param gt_boundaries: a list of ground truth boundaries, as returned
        by the `load_boundaries` or `boundaries` methods
    :param max_dist: (default=0.0075) maximum distance parameter
        used for determining pixel matches. This value is multiplied by the
        length of the diagonal of the image to get the threshold used
        for matching pixels.
    :param apply_thinning: (default=True) if True, apply morphological
        thinning to the predicted boundaries before evaluation
    :return: tuple `(count_r, sum_r, count_p, sum_p)` where each of
        the four entries are values that can be used to compute
        recall and precision with:
        ```
        recall = count_r / (sum_r + (sum_r == 0))
        precision = count_p / (sum_p + (sum_p == 0))
        ```
    """
    acc_prec = np.zeros(predicted_boundaries_bin.shape, dtype=bool)
    predicted_boundaries_bin = predicted_boundaries_bin != 0

    if apply_thinning:
        predicted_boundaries_bin = thin.binary_thin(predicted_boundaries_bin)

    sum_r = 0
    count_r = 0
    for gt in gt_boundaries:
        match1, match2, cost, oc = correspond_pixels.correspond_pixels(
            predicted_boundaries_bin, gt, max_dist=max_dist
        )
        match1 = match1 > 0
        match2 = match2 > 0
        # Precision accumulator: a predicted pixel counts as correct if it
        # is matched against any of the ground truths.
        acc_prec = acc_prec | match1
        # Recall
        sum_r += gt.sum()
        count_r += match2.sum()

    # Precision
    sum_p = predicted_boundaries_bin.sum()
    count_p = acc_prec.sum()

    return count_r, sum_r, count_p, sum_p


def _resolve_thresholds(thresholds):
    """
    Turn a `thresholds` argument into the 1D array of threshold values.

    :param thresholds: either an integer specifying the number of thresholds
        to use or a 1D NumPy array specifying the thresholds
    :return: 1D NumPy array; for an integer `n` these are `n` evenly spaced
        values from `1 / (n + 1)` to `1 - 1 / (n + 1)`, for an array the
        array itself
    :raises ValueError: if `thresholds` is an array that is not 1D, or is
        neither an int nor a NumPy array
    """
    if isinstance(thresholds, int):
        return np.linspace(1.0 / (thresholds + 1),
                           1.0 - 1.0 / (thresholds + 1), thresholds)
    if isinstance(thresholds, np.ndarray):
        if thresholds.ndim != 1:
            raise ValueError('thresholds array should have 1 dimension, '
                             'not {}'.format(thresholds.ndim))
        return thresholds
    raise ValueError('thresholds should be an int or a NumPy array, not '
                     'a {}'.format(type(thresholds)))


def evaluate_boundaries(predicted_boundaries, gt_boundaries,
                        thresholds=99, max_dist=0.0075, apply_thinning=True,
                        progress=None):
    """
    Evaluate the accuracy of a predicted boundary and a range of thresholds

    :param predicted_boundaries: the predicted boundaries as a (H,W)
        floating point array where each pixel represents the strength of the
        predicted boundary
    :param gt_boundaries: a list of ground truth boundaries, as returned
        by the `load_boundaries` or `boundaries` methods
    :param thresholds: either an integer specifying the number of thresholds
        to use or a 1D array specifying the thresholds
    :param max_dist: (default=0.0075) maximum distance parameter
        used for determining pixel matches. This value is multiplied by the
        length of the diagonal of the image to get the threshold used
        for matching pixels.
    :param apply_thinning: (default=True) if True, apply morphological
        thinning to the predicted boundaries before evaluation
    :param progress: a function that can be used to monitor progress;
        use `tqdm.tqdm` or `tqdm.tqdm_notebook` from the `tqdm` package
        to generate a progress bar.
    :return: tuple `(count_r, sum_r, count_p, sum_p, thresholds)` where each
        of the first four entries are arrays that can be used to compute
        recall and precision at each threshold with:
        ```
        recall = count_r / (sum_r + (sum_r == 0))
        precision = count_p / (sum_p + (sum_p == 0))
        ```
        The thresholds are also returned.
    :raises ValueError: if `thresholds` is neither an int nor a 1D NumPy
        array
    """
    if progress is None:
        progress = lambda x, *args, **kwargs: x

    thresholds = _resolve_thresholds(thresholds)

    sum_p = np.zeros(thresholds.shape)
    count_p = np.zeros(thresholds.shape)
    sum_r = np.zeros(thresholds.shape)
    count_r = np.zeros(thresholds.shape)

    for i_t, thresh in enumerate(progress(list(thresholds))):
        predicted_boundaries_bin = predicted_boundaries >= thresh

        acc_prec = np.zeros(predicted_boundaries_bin.shape, dtype=bool)

        if apply_thinning:
            predicted_boundaries_bin = thin.binary_thin(
                predicted_boundaries_bin)

        for gt in gt_boundaries:
            match1, match2, cost, oc = correspond_pixels.correspond_pixels(
                predicted_boundaries_bin, gt, max_dist=max_dist
            )
            match1 = match1 > 0
            match2 = match2 > 0
            # Precision accumulator
            acc_prec = acc_prec | match1
            # Recall
            sum_r[i_t] += gt.sum()
            count_r[i_t] += match2.sum()

        # Precision
        sum_p[i_t] = predicted_boundaries_bin.sum()
        count_p[i_t] = acc_prec.sum()

    return count_r, sum_r, count_p, sum_p, thresholds


def compute_rec_prec_f1(count_r, sum_r, count_p, sum_p):
    """
    Compute recall, precision and F1-score given `count_r`, `sum_r`,
    `count_p` and `sum_p`; see `evaluate_boundaries`.

    A zero denominator is replaced by one, so the corresponding score is
    reported as 0 instead of raising a division error.

    :param count_r: number of matched ground truth pixels
    :param sum_r: total number of ground truth pixels
    :param count_p: number of matched predicted pixels
    :param sum_p: total number of predicted pixels
    :return: tuple `(recall, precision, f1)`
    """
    rec = count_r / (sum_r + (sum_r == 0))
    prec = count_p / (sum_p + (sum_p == 0))
    f1_denom = (prec + rec + ((prec + rec) == 0))
    f1 = 2.0 * prec * rec / f1_denom
    return rec, prec, f1


SampleResult = namedtuple('SampleResult', ['sample_name', 'threshold',
                                           'recall', 'precision', 'f1'])
ThresholdResult = namedtuple('ThresholdResult', ['threshold', 'recall',
                                                 'precision', 'f1'])
OverallResult = namedtuple('OverallResult', ['threshold', 'recall',
                                             'precision', 'f1',
                                             'best_recall', 'best_precision',
                                             'best_f1', 'area_pr'])


def pr_evaluation(thresholds, sample_names: list,
                  load_gt_boundaries: Callable, load_pred: Callable,
                  progress=None, num_workers=1, max_dist=0.0075):
    """
    Perform an evaluation of predictions against ground truths for an image
    set over a given set of thresholds.

    :param thresholds: either an integer specifying the number of thresholds
        to use or a 1D array specifying the thresholds
    :param sample_names: the names of the samples that are to be evaluated
    :param load_gt_boundaries: a callable that loads the ground truth for a
        named sample; of the form `load_gt_boundaries(sample_name) -> gt`
    :param load_pred: a callable that loads the prediction for a
        named sample; of the form `load_pred(sample_name) -> pred`
        where `pred` is a 2D NumPy array
    :param progress: default=None a callable -- such as `tqdm` -- that
        accepts an iterator over the sample names in order to track progress
    :param num_workers: (default=1) number of joblib jobs used to evaluate
        the samples
    :param max_dist: (default=0.0075) maximum distance parameter forwarded
        to `evaluate_boundaries`
    :return: `(sample_results, threshold_results, overall_result)`
        where `sample_results` is a list of `SampleResult` named tuples with
        one for each sample, `threshold_results` is a list of
        `ThresholdResult` named tuples, with one for each threshold and
        `overall_result` is an `OverallResult` named tuple giving the over
        all results. The attributes in these structures will now be
        described:

        `SampleResult`:
        - `sample_name`: the name identifying the sample to which this result
            applies
        - `threshold`: the threshold at which the best F1-score was obtained
            for the given sample
        - `recall`: the recall score obtained at the best threshold
        - `precision`: the precision score obtained at the best threshold
        - `f1`: the F1-score obtained at the best threshold

        `ThresholdResult`:
        - `threshold`: the threshold value to which this result applies
        - `recall`: the average recall score for all samples
        - `precision`: the average precision score for all samples
        - `f1`: the average F1-score for all samples

        `OverallResult`:
        - `threshold`: the threshold at which the best average F1-score over
            all samples is obtained
        - `recall`: the average recall score for all samples at `threshold`
        - `precision`: the average precision score for all samples at
            `threshold`
        - `f1`: the average F1-score for all samples at `threshold`
        - `best_recall`: the average recall score for all samples at the best
            threshold *for each individual sample*
        - `best_precision`: the average precision score for all samples at
            the best threshold *for each individual sample*
        - `best_f1`: the average F1-score for all samples at the best
            threshold *for each individual sample*
        - `area_pr`: the area under the precision-recall curve, computed by
            sampling the interpolated curve at recall steps of 0.01
    :raises ValueError: if `thresholds` is neither an int nor a 1D NumPy
        array
    """
    if progress is None:
        progress = lambda x, *args, **kwargs: x

    # Resolve the thresholds once, up front. This validates the argument
    # before any worker starts and keeps `used_thresholds` defined even when
    # there are no samples to evaluate.
    used_thresholds = _resolve_thresholds(thresholds)
    n_thresh = used_thresholds.shape[0]

    sample_results = []
    count_r_overall = np.zeros((n_thresh,))
    sum_r_overall = np.zeros((n_thresh,))
    count_p_overall = np.zeros((n_thresh,))
    sum_p_overall = np.zeros((n_thresh,))
    count_r_best = 0
    sum_r_best = 0
    count_p_best = 0
    sum_p_best = 0

    if len(sample_names) > 0:
        with tqdm_joblib(len(sample_names)):
            results = joblib.Parallel(n_jobs=num_workers)(
                joblib.delayed(evaluate_sample)(
                    sample_name, load_pred, load_gt_boundaries,
                    used_thresholds, max_dist)
                for sample_name in sample_names
            )
    else:
        # Nothing to evaluate: skip the progress bar and the worker pool.
        results = []

    for _, result in zip(progress(sample_names), results):
        sample_result, count_r, sum_r, \
            count_p, sum_p, best_ndx, _ = result

        count_r_overall += count_r
        sum_r_overall += sum_r
        count_p_overall += count_p
        sum_p_overall += sum_p

        count_r_best += count_r[best_ndx]
        sum_r_best += sum_r[best_ndx]
        count_p_best += count_p[best_ndx]
        sum_p_best += sum_p[best_ndx]

        sample_results.append(sample_result)

    # Compute overall precision, recall and F1
    rec_overall, prec_overall, f1_overall = compute_rec_prec_f1(
        count_r_overall, sum_r_overall, count_p_overall, sum_p_overall)

    # Find best F1 score
    best_i_ovr = np.argmax(f1_overall)

    threshold_results = [
        ThresholdResult(used_thresholds[thresh_i],
                        rec_overall[thresh_i],
                        prec_overall[thresh_i],
                        f1_overall[thresh_i])
        for thresh_i in range(n_thresh)
    ]

    # Area under the PR curve: interpolate precision over the distinct
    # recall values and sum it at recall steps of 0.01.
    rec_unique, rec_unique_ndx = np.unique(rec_overall, return_index=True)
    prec_unique = prec_overall[rec_unique_ndx]
    if rec_unique.shape[0] > 1:
        prec_interp = np.interp(np.arange(0, 1, 0.01), rec_unique,
                                prec_unique, left=0.0, right=0.0)
        area_pr = prec_interp.sum() * 0.01
    else:
        area_pr = 0.0

    rec_best, prec_best, f1_best = compute_rec_prec_f1(
        float(count_r_best), float(sum_r_best), float(count_p_best),
        float(sum_p_best)
    )

    overall_result = OverallResult(used_thresholds[best_i_ovr],
                                   rec_overall[best_i_ovr],
                                   prec_overall[best_i_ovr],
                                   f1_overall[best_i_ovr], rec_best,
                                   prec_best, f1_best, area_pr)

    return sample_results, threshold_results, overall_result
// NEXTPOW2 selects whether the nextpow2 declarations below are compiled in:
// disabled (0) on Apple platforms, enabled (1) everywhere else.
#if defined(__APPLE__)
#define NEXTPOW2 0
#else
#define NEXTPOW2 1
#endif

// Declaration of a MATLAB-flavoured matrix of doubles.  Most operations come
// in two forms: an in-place member function and a friend free function that
// takes the matrix as its first argument and returns a new Matrix.
class Matrix
{
public:

    // Fill policy applied when a matrix is created or resized.
    enum FillType {
        undef, zeros, ones, eye, rand, randn
    };

    // Factories for special matrices (rectangular, then square).
    friend Matrix zeros (int rows, int cols);
    friend Matrix ones (int rows, int cols);
    friend Matrix eye (int rows, int cols);
    friend Matrix rand (int rows, int cols);
    friend Matrix randn (int rows, int cols);
    friend Matrix zeros (int sz);
    friend Matrix ones (int sz);
    friend Matrix eye (int sz);
    friend Matrix rand (int sz);
    friend Matrix randn (int sz);

    // Empty matrix.
    Matrix ();

    // Square sz-by-sz matrix with the given fill.
    Matrix (int sz, FillType type = zeros);

    // rows-by-cols matrix with the given fill.
    Matrix (int rows, int cols, FillType type = zeros);

    // Copy constructor.
    Matrix (const Matrix& that);

    // Wrap an existing buffer; the matrix is not responsible for
    // freeing the data.
    Matrix (int rows, int cols, double* data);

    // Destructor.
    ~Matrix ();

    // Reshape; only valid when the number of elements does not change.
    void reshape (int rows, int cols);
    friend Matrix reshape (const Matrix& a, int rows, int cols);

    // Resize to rows-by-cols with the given fill.
    void resize (int rows, int cols, FillType type = zeros);

    // Shape and size queries.
    bool isvec () const;
    bool isrowvec () const;
    bool iscolvec () const;
    bool isempty () const;
    bool isscalar () const;
    bool issize (int rows, int cols) const;
    int nrows () const;
    int ncols () const;
    int numel () const;
    int length () const;
    Matrix size () const;
    bool samesize (const Matrix& a) const;
    friend bool isvec (const Matrix& a);
    friend bool isrowvec (const Matrix& a);
    friend bool iscolvec (const Matrix& a);
    friend bool isempty (const Matrix& a);
    friend bool isscalar (const Matrix& a);
    friend bool issize (const Matrix& a, int rows, int cols);
    friend int nrows (const Matrix& a);
    friend int ncols (const Matrix& a);
    friend int numel (const Matrix& a);
    friend int length (const Matrix& a);
    friend Matrix size (const Matrix& a);
    friend bool samesize (const Matrix& a, const Matrix& b);

    // TODO: catenation
    void horzcat (const Matrix& a);
    void vertcat (const Matrix& a);
    void blkdiag (const Matrix& a);
    friend Matrix horzcat (const Matrix& a, const Matrix& b);
    friend Matrix vertcat (const Matrix& a, const Matrix& b);
    friend Matrix blkdiag (const Matrix& a, const Matrix& b);

    // Conversion between linear indices and (i,j) subscripts.
    void ind2sub (const Matrix& ind, Matrix& i, Matrix& j);
    void sub2ind (const Matrix& i, const Matrix& j, Matrix& ind);
    friend void ind2sub (
        const Matrix& siz, const Matrix& ind, Matrix& i, Matrix& j);
    friend void sub2ind (
        const Matrix& siz, const Matrix& i, const Matrix& j, Matrix& ind);

    // Raw data array access.
    bool iswrapped () const;
    friend bool iswrapped (const Matrix &a);
    double* data ();

    // Element access by single index.
    double& operator() (int index);
    const double& operator() (int index) const;

    // Element access by (row, col).
    double& operator() (int row, int col);
    const double& operator() (int row, int col) const;

    // Sub-matrix extraction and insertion.
    Matrix operator() (int r1, int r2, int c1, int c2) const;
    void insert (const Matrix& m, int r1, int r2, int c1, int c2);
    friend Matrix insert (const Matrix& a, const Matrix& m,
                          int r1, int r2, int c1, int c2);

    // Assignment from a scalar or from another matrix.
    void operator= (const double& val);
    void operator= (const Matrix& that);

    // Equality test.
    bool isequal (const Matrix& a) const;
    friend bool isequal (const Matrix& a, const Matrix& b);

    // Range filling.
    void linspace (double a, double b);
    void logspace (double a, double b);
    friend Matrix linspace (double a, double b, int n);
    friend Matrix logspace (double a, double b, int n);

    // Element shuffling.
    void transpose ();              // transpose
    void fliplr ();                 // swap columns
    void flipud ();                 // swap rows
    void rot90 (int k = 1);         // rotate CC k*90 degrees
    friend Matrix transpose (const Matrix& a);
    friend Matrix fliplr (const Matrix& a);
    friend Matrix flipud (const Matrix& a);
    friend Matrix rot90 (const Matrix& a, int k);

    // Replication.
    Matrix repmat (int m, int n) const;
    friend Matrix repmat (const Matrix& a, int m, int n);

    // Locate the non-zero elements.
    Matrix find () const;
    friend Matrix find (const Matrix& a);

    // Gather / scatter through a matrix of indices.
    Matrix operator() (const Matrix& indices) const;
    Matrix gather (const Matrix& indices) const;
    void scatter (const Matrix& indices, const Matrix& values);
    void scatter (const Matrix& indices, double value);
    friend Matrix scatter (
        const Matrix& a, const Matrix& indices, const Matrix& values);
    friend Matrix scatter (
        const Matrix& a, const Matrix& indices, double value);

    // Triangular masking.
    void tril (int k = 0);          // save lower triangle
    void triu (int k = 0);          // save upper triangle
    friend Matrix tril (const Matrix& a, int k);
    friend Matrix triu (const Matrix& a, int k);

    // Diagonal access.
    Matrix getdiag (int k = 0) const;               // get vector
    void setdiag (double val, int k = 0);           // set from scalar
    void setdiag (const Matrix& d, int k = 0);      // set from vector
    // diag: builds a diagonal matrix when a is a vector, and extracts the
    // diagonal as a column vector when a is a matrix.
    friend Matrix diag (const Matrix& a, int k);

    // Reductions.
    bool any () const;              // logical-or reduction
    bool all () const;              // logical-and reduction
    double sum () const;            // sum reduction
    Matrix rsum () const;           // row sums
    Matrix csum () const;           // column sums
    double prod () const;           // product reduction
    Matrix rprod () const;          // row products
    Matrix cprod () const;          // column products
    friend bool any (const Matrix& a);
    friend bool all (const Matrix& a);
    friend double sum (const Matrix& a);
    friend Matrix rsum (const Matrix& a);
    friend Matrix csum (const Matrix& a);
    friend double prod (const Matrix& a);
    friend Matrix rprod (const Matrix& a);
    friend Matrix cprod (const Matrix& a);

    // Minimum and maximum reductions.
    double min () const;                    // total min
    double min (int& index) const;          // total min with index
    Matrix rmin () const;                   // row mins
    Matrix rmin (Matrix& indices) const;    // row mins with indices
    Matrix cmin () const;                   // column mins
    Matrix cmin (Matrix& indices) const;    // column mins with indices
    double max () const;                    // total max
    double max (int& index) const;          // total max with index
    Matrix rmax () const;                   // row maxs
    Matrix rmax (Matrix& indices) const;    // row maxs with indices
    Matrix cmax () const;                   // column maxs
    Matrix cmax (Matrix& indices) const;    // column maxs with indices
    friend double min (const Matrix& a);
    friend double min (const Matrix& a, int& index);
    friend Matrix rmin (const Matrix& a);
    friend Matrix rmin (const Matrix& a, Matrix& indices);
    friend Matrix cmin (const Matrix& a);
    friend Matrix cmin (const Matrix& a, Matrix& indices);
    friend double max (const Matrix& a);
    friend double max (const Matrix& a, int& index);
    friend Matrix rmax (const Matrix& a);
    friend Matrix rmax (const Matrix& a, Matrix& indices);
    friend Matrix cmax (const Matrix& a);
    friend Matrix cmax (const Matrix& a, Matrix& indices);

    // Two-operand min and max.
    friend Matrix min (const Matrix& a, const Matrix& b);
    friend Matrix min (const Matrix& a, double b);
    friend Matrix min (double a, const Matrix& b);
    friend Matrix max (const Matrix& a, const Matrix& b);
    friend Matrix max (const Matrix& a, double b);
    friend Matrix max (double a, const Matrix& b);

    // Rounding.
    void ceil ();                   // round toward +inf
    void floor ();                  // round toward -inf
    void round ();                  // round to nearest
    void fix ();                    // round towards zero
    friend Matrix ceil (const Matrix& a);
    friend Matrix floor (const Matrix& a);
    friend Matrix round (const Matrix& a);
    friend Matrix fix (const Matrix& a);

    // NaN / infinity tests.
    void iznan ();
    void izinf ();
    void izfinite ();
    friend Matrix iznan (const Matrix& a);
    friend Matrix izinf (const Matrix& a);
    friend Matrix izfinite (const Matrix& a);

    // Sign handling.
    void abs ();                    // absolute value
    void sign ();                   // -1,0,1 for negative,zero,positive
    friend Matrix abs (const Matrix& a);
    friend Matrix sign (const Matrix& a);

    // Element-wise exponentials and logarithms.
    void exp ();                    // e^x
    void log ();                    // natural log
    void log10 ();                  // base 10 log
    void log2 ();                   // base 2 log
    void pow2 ();                   // 2^x
    void sqrt ();                   // square root
#if NEXTPOW2
    void nextpow2 ();               // smallest i s.t. 2^i>x
#endif // NEXTPOW2
    friend Matrix exp (const Matrix& a);
    friend Matrix log (const Matrix& a);
    friend Matrix log10 (const Matrix& a);
    friend Matrix log2 (const Matrix& a);
    friend Matrix pow2 (const Matrix& a);
    friend Matrix sqrt (const Matrix& a);
#if NEXTPOW2
    friend Matrix nextpow2 (const Matrix& a);
#endif // NEXTPOW2

    // Trigonometric and hyperbolic functions.
    void sin ();                    // Sine.
    void sinh ();                   // Hyperbolic sine.
    void asin ();                   // Inverse sine.
    void asinh ();                  // Inverse hyperbolic sine.
    void cos ();                    // Cosine.
    void cosh ();                   // Hyperbolic cosine.
    void acos ();                   // Inverse cosine.
    void acosh ();                  // Inverse hyperbolic cosine.
    void tan ();                    // Tangent.
    void tanh ();                   // Hyperbolic tangent.
    void atan ();                   // Inverse tangent.
    void atanh ();                  // Inverse hyperbolic tangent.
    void sec ();                    // Secant.
    void sech ();                   // Hyperbolic secant.
    void asec ();                   // Inverse secant.
    void asech ();                  // Inverse hyperbolic secant.
    void csc ();                    // Cosecant.
    void csch ();                   // Hyperbolic cosecant.
    void acsc ();                   // Inverse cosecant.
    void acsch ();                  // Inverse hyperbolic cosecant.
    void cot ();                    // Cotangent.
    void coth ();                   // Hyperbolic cotangent.
    void acot ();                   // Inverse cotangent.
    void acoth ();                  // Inverse hyperbolic cotangent.
    friend Matrix sin (const Matrix& a);
    friend Matrix sinh (const Matrix& a);
    friend Matrix asin (const Matrix& a);
    friend Matrix asinh (const Matrix& a);
    friend Matrix cos (const Matrix& a);
    friend Matrix cosh (const Matrix& a);
    friend Matrix acos (const Matrix& a);
    friend Matrix acosh (const Matrix& a);
    friend Matrix tan (const Matrix& a);
    friend Matrix tanh (const Matrix& a);
    friend Matrix atan (const Matrix& a);
    friend Matrix atanh (const Matrix& a);
    friend Matrix sec (const Matrix& a);
    friend Matrix sech (const Matrix& a);
    friend Matrix asec (const Matrix& a);
    friend Matrix asech (const Matrix& a);
    friend Matrix csc (const Matrix& a);
    friend Matrix csch (const Matrix& a);
    friend Matrix acsc (const Matrix& a);
    friend Matrix acsch (const Matrix& a);
    friend Matrix cot (const Matrix& a);
    friend Matrix coth (const Matrix& a);
    friend Matrix acot (const Matrix& a);
    friend Matrix acoth (const Matrix& a);

    // Compound assignment, all element-wise.  The helper macro declares
    // the scalar and the matrix overload for one operator.
#define MATRIX_DECL_ASSIGN_OP(OP) \
    Matrix& operator OP (const double& val); \
    Matrix& operator OP (const Matrix& that);
    MATRIX_DECL_ASSIGN_OP(+=);
    MATRIX_DECL_ASSIGN_OP(-=);
    MATRIX_DECL_ASSIGN_OP(*=);
    MATRIX_DECL_ASSIGN_OP(/=);
    MATRIX_DECL_ASSIGN_OP(^=);      // exponentiation, not xor
#undef MATRIX_DECL_ASSIGN_OP

    // Binary operators, all element-wise.  The helper macro declares the
    // matrix/matrix, matrix/scalar and scalar/matrix overloads.
#define MATRIX_DECL_BINARY_OP(OP) \
    friend Matrix operator OP (const Matrix& a, const Matrix& b); \
    friend Matrix operator OP (const Matrix& a, double b); \
    friend Matrix operator OP (double a, const Matrix& b);
    MATRIX_DECL_BINARY_OP(+);
    MATRIX_DECL_BINARY_OP(-);
    MATRIX_DECL_BINARY_OP(*);
    MATRIX_DECL_BINARY_OP(/);
    MATRIX_DECL_BINARY_OP(^);       // exponentiation, not xor
    MATRIX_DECL_BINARY_OP(<);
    MATRIX_DECL_BINARY_OP(>);
    MATRIX_DECL_BINARY_OP(==);
    MATRIX_DECL_BINARY_OP(!=);
    MATRIX_DECL_BINARY_OP(<=);
    MATRIX_DECL_BINARY_OP(>=);
    MATRIX_DECL_BINARY_OP(&&);
    MATRIX_DECL_BINARY_OP(||);
#undef MATRIX_DECL_BINARY_OP

    // Unary operators.
#define MATRIX_DECL_UNARY_OP(OP) \
    friend Matrix operator OP (const Matrix& a);
    MATRIX_DECL_UNARY_OP(!);
#undef MATRIX_DECL_UNARY_OP

    // Miscellaneous two-operand functions, declared for the same three
    // operand combinations as the binary operators.
#define MATRIX_DECL_BINARY_FN(FN) \
    friend Matrix FN (const Matrix& a, const Matrix& b); \
    friend Matrix FN (const Matrix& a, double b); \
    friend Matrix FN (double a, const Matrix& b);
    MATRIX_DECL_BINARY_FN(rem);
    MATRIX_DECL_BINARY_FN(mod);
    MATRIX_DECL_BINARY_FN(atan2);
#undef MATRIX_DECL_BINARY_FN

    // Matrix multiplication.
    friend Matrix mtimes (const Matrix& a, const Matrix& b);

protected:

    // Storage management helpers.
    void _alloc (int rows, int cols, FillType type);
    void _delete ();
    void _zero ();

    int _rows, _cols, _n;
    double* _data;
    bool _wrapped;

}; // class Matrix

// Namespace-scope redeclarations that supply the default arguments for the
// friend functions declared inside the class.
Matrix rot90 (const Matrix& a, int k = 1);
Matrix tril (const Matrix& a, int k = 0);
Matrix triu (const Matrix& a, int k = 0);
Matrix diag (const Matrix& a, int k = 0);
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from segment_anything import SamAutomaticMaskGenerator
from segment_anything.modeling import Sam
from segment_anything.utils.amg import (MaskData, area_from_rle,
                                        batched_mask_to_box, box_xyxy_to_xywh,
                                        batch_iterator,
                                        uncrop_boxes_xyxy, uncrop_points,
                                        calculate_stability_score,
                                        coco_encode_rle, generate_crop_boxes,
                                        is_box_near_crop_edge,
                                        mask_to_rle_pytorch, rle_to_mask,
                                        uncrop_masks)
from torchvision.ops.boxes import batched_nms, box_area  # type: ignore


def batched_mask_to_prob(masks: torch.Tensor) -> torch.Tensor:
    """
    Convert a batch of mask logits into per-pixel probability maps.

    For implementation, see the following issue comment:

    "To get the probability map for a mask,
    we simply do element-wise sigmoid over the logits."
    URL: https://github.com/facebookresearch/segment-anything/issues/226

    Args:
      masks: Tensor of shape [B, H, W] holding the unthresholded mask
        logits (the caller requests them with return_logits=True).

    Returns:
      Tensor of shape [B, H, W] representing batch of probability maps.
    """
    # sigmoid already returns a tensor on the input's device.
    return torch.sigmoid(masks)


def batched_sobel_filter(probs: torch.Tensor, masks: torch.Tensor, bzp: int
                         ) -> torch.Tensor:
    """
    Turn probability maps into edge-strength maps restricted to mask borders.

    For implementation, see section D.2 of the paper:

    "we apply a Sobel filter to the remaining masks' unthresholded probability
    maps and set values to zero if they do not intersect with the outer
    boundary pixels of a mask."
    URL: https://arxiv.org/abs/2304.02643

    Args:
      probs: Tensor of shape [B, H, W] representing batch of probability maps.
      masks: Tensor of shape [B, H, W] representing batch of binary masks.
        Must have the same batch size as probs.
      bzp: Boundary zero padding width in pixels. If > 0, a border of this
        width is set to zero on all four sides of every map. See "Zero-Shot
        Edge Detection With SCESAME: Spectral Clustering-Based Ensemble for
        Segment Anything Model Estimation".

    Returns:
      Tensor of shape [B, H, W] with filtered probability maps.
    """
    # Add a channel dimension: [B, H, W] -> [B, 1, H, W].
    probs = probs.unsqueeze(1)

    # Sobel kernels of shape [1, 1, 3, 3]. They follow the dtype of probs so
    # that conv2d does not fail on non-float32 inputs.
    sobel_filter_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                                  dtype=probs.dtype, device=probs.device
                                  )[None, None]
    sobel_filter_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                                  dtype=probs.dtype, device=probs.device
                                  )[None, None]

    # Gradient magnitude of the probability maps.
    G_x = F.conv2d(probs, sobel_filter_x, padding=1)
    G_y = F.conv2d(probs, sobel_filter_y, padding=1)
    probs = torch.sqrt(G_x ** 2 + G_y ** 2)

    # Boundary of every binary mask, computed for the whole batch at once:
    # a pixel is on the boundary when the mask's Sobel response is non-zero.
    mask_maps = masks.unsqueeze(1).to(probs.dtype)
    M_x = F.conv2d(mask_maps, sobel_filter_x, padding=1)
    M_y = F.conv2d(mask_maps, sobel_filter_y, padding=1)
    outer_boundary = ((M_x ** 2 + M_y ** 2) > 0).to(probs.dtype)

    # Set to zero values that don't touch the mask's outer boundary.
    probs = probs * outer_boundary

    # Boundary zero padding (BZP), applied to all maps in one go.
    if bzp > 0:
        probs[:, :, 0:bzp, :] = 0
        probs[:, :, -bzp:, :] = 0
        probs[:, :, :, 0:bzp] = 0
        probs[:, :, :, -bzp:] = 0

    # Remove the channel dimension
    return probs.squeeze(1)


class SamAutomaticMaskAndProbabilityGenerator(SamAutomaticMaskGenerator):
    def __init__(
        self,
        model: Sam,
        points_per_side: Optional[int] = 16,
        points_per_batch: int = 64,
        pred_iou_thresh: float = 0.88,
        stability_score_thresh: float = 0.95,
        stability_score_offset: float = 1.0,
        box_nms_thresh: float = 0.7,
        crop_n_layers: int = 0,
        crop_nms_thresh: float = 0.7,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,
        min_mask_region_area: int = 0,
        output_mode: str = "binary_mask",
        nms_threshold: float = 0.7,
        bzp: int = 0,
        pred_iou_thresh_filtering: bool = False,
        stability_score_thresh_filtering: bool = False,
    ) -> None:
        """
        Using a SAM model, generates masks for the entire image.
        Generates a grid of point prompts over the image, then filters
        low quality and duplicate masks. The default settings are chosen
        for SAM with a ViT-H backbone.

        Arguments:
          model (Sam): The SAM model to use for mask prediction.
          points_per_side (int or None): The number of points to be sampled
            along one side of the image. The total number of points is
            points_per_side**2. If None, 'point_grids' must provide explicit
            point sampling.
          points_per_batch (int): Sets the number of points run simultaneously
            by the model. Higher numbers may be faster but use more GPU memory.
          pred_iou_thresh (float): A filtering threshold in [0,1], using the
            model's predicted mask quality. Only applied when
            'pred_iou_thresh_filtering' is True.
          stability_score_thresh (float): A filtering threshold in [0,1], using
            the stability of the mask under changes to the cutoff used to
            binarize the model's mask predictions. Only applied when
            'stability_score_thresh_filtering' is True.
          stability_score_offset (float): The amount to shift the cutoff when
            calculating the stability score.
          box_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks.
          crop_n_layers (int): If >0, mask prediction will be run again on
            crops of the image. Sets the number of layers to run, where each
            layer has 2**i_layer number of image crops.
          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks between different crops.
          crop_overlap_ratio (float): Sets the degree to which crops overlap.
            In the first crop layer, crops will overlap by this fraction of
            the image length. Later layers with more crops scale down this
            overlap.
          crop_n_points_downscale_factor (int): The number of points-per-side
            sampled in layer n is scaled down by
            crop_n_points_downscale_factor**n.
          point_grids (list(np.ndarray) or None): A list over explicit grids
            of points used for sampling, normalized to [0,1]. The nth grid in
            the list is used in the nth crop layer. Exclusive with
            points_per_side.
          min_mask_region_area (int): If >0, postprocessing will be applied
            to remove disconnected regions and holes in masks with area smaller
            than min_mask_region_area. Requires opencv.
          output_mode (str): The form masks are returned in. Can be
            'binary_mask', 'uncompressed_rle', or 'coco_rle'. 'coco_rle'
            requires pycocotools. For large resolutions, 'binary_mask' may
            consume large amounts of memory.
          nms_threshold (float): The box IoU threshold used for the
            non-maximal suppression run on every batch of point prompts.
            The step is skipped when the value is not positive.
          bzp (int): Boundary zero padding width in pixels. If >0, a border of
            this width is zeroed in every edge probability map.
          pred_iou_thresh_filtering (bool): Whether masks are filtered by
            'pred_iou_thresh'.
          stability_score_thresh_filtering (bool): Whether masks are filtered
            by 'stability_score_thresh'.
        """
        super().__init__(
            model,
            points_per_side,
            points_per_batch,
            pred_iou_thresh,
            stability_score_thresh,
            stability_score_offset,
            box_nms_thresh,
            crop_n_layers,
            crop_nms_thresh,
            crop_overlap_ratio,
            crop_n_points_downscale_factor,
            point_grids,
            min_mask_region_area,
            output_mode,
        )
        self.nms_threshold = nms_threshold
        self.bzp = bzp
        self.pred_iou_thresh_filtering = pred_iou_thresh_filtering
        self.stability_score_thresh_filtering = \
            stability_score_thresh_filtering

    @torch.no_grad()
    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """
        Generates masks and edge probability maps for the given image.

        Arguments:
          image (np.ndarray): The image to generate masks for, in HWC uint8
            format.

        Returns:
          list(dict(str, any)): A list over records for masks. Each record is
            a dict containing the following keys:
              segmentation (dict(str, any) or np.ndarray): The mask. If
                output_mode='binary_mask', is an array of shape HW. Otherwise,
                is a dictionary containing the RLE.
              bbox (list(float)): The box around the mask, in XYWH format.
              area (int): The area in pixels of the mask.
              predicted_iou (float): The model's own prediction of the mask's
                quality. Filtered by the pred_iou_thresh parameter when
                pred_iou_thresh_filtering is enabled.
              point_coords (list(list(float))): The point coordinates input
                to the model to generate this mask.
              stability_score (float): A measure of the mask's quality.
                Filtered on using the stability_score_thresh parameter when
                stability_score_thresh_filtering is enabled.
              crop_box (list(float)): The crop of the image used to generate
                the mask, given in XYWH format.
              prob: The Sobel-filtered probability map of the mask, padded
                with zeros to the size of the original image.
        """

        # Generate masks
        mask_data = self._generate_masks(image)

        # Filter small disconnected regions and holes in masks
        if self.min_mask_region_area > 0:
            mask_data = self.postprocess_small_regions(
                mask_data,
                self.min_mask_region_area,
                max(self.box_nms_thresh, self.crop_nms_thresh),
            )

        # Encode masks
        if self.output_mode == "coco_rle":
            mask_data["segmentations"] = [
                coco_encode_rle(rle) for rle in mask_data["rles"]
            ]
        elif self.output_mode == "binary_mask":
            mask_data["segmentations"] = [
                rle_to_mask(rle) for rle in mask_data["rles"]
            ]
        else:
            mask_data["segmentations"] = mask_data["rles"]

        # Write mask records
        curr_anns = []
        for idx, segmentation in enumerate(mask_data["segmentations"]):
            ann = {
                "segmentation": segmentation,
                "area": area_from_rle(mask_data["rles"][idx]),
                "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
                "predicted_iou": mask_data["iou_preds"][idx].item(),
                "point_coords": [mask_data["points"][idx].tolist()],
                "stability_score": mask_data["stability_score"][idx].item(),
                "crop_box": box_xyxy_to_xywh(
                    mask_data["crop_boxes"][idx]).tolist(),
                "prob": mask_data["probs"][idx],
            }
            curr_anns.append(ann)

        return curr_anns

    def _process_crop(
        self,
        image: np.ndarray,
        crop_box: List[int],
        crop_layer_idx: int,
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        """
        Run the point grid of one crop through the model.

        Returns the de-duplicated mask data of the crop, with boxes, points
        and probability maps expressed in the frame of the original image.
        """
        # Crop the image and calculate embeddings
        x0, y0, x1, y1 = crop_box
        cropped_im = image[y0:y1, x0:x1, :]
        cropped_im_size = cropped_im.shape[:2]
        self.predictor.set_image(cropped_im)

        # Get points for this crop
        points_scale = np.array(cropped_im_size)[None, ::-1]
        points_for_image = self.point_grids[crop_layer_idx] * points_scale

        # Generate masks for this crop in batches
        data = MaskData()
        for (points,) in batch_iterator(self.points_per_batch,
                                        points_for_image):
            batch_data = self._process_batch(
                points, cropped_im_size, crop_box, orig_size)
            data.cat(batch_data)
            del batch_data
        self.predictor.reset_image()

        # Remove duplicates within this crop.
        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            torch.zeros_like(data["boxes"][:, 0]),  # categories
            iou_threshold=self.box_nms_thresh,
        )
        data.filter(keep_by_nms)

        # Return to the original image frame
        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
        data["points"] = uncrop_points(data["points"], crop_box)
        data["crop_boxes"] = torch.tensor(
            [crop_box for _ in range(len(data["rles"]))])

        # Paste the crop-sized probability maps into zero-filled maps of the
        # original image size, mirroring what uncrop_masks does for masks.
        padded_probs = torch.zeros((data["probs"].shape[0], *orig_size),
                                   dtype=torch.float32,
                                   device=data["probs"].device)
        padded_probs[:, y0:y1, x0:x1] = data["probs"]
        data["probs"] = padded_probs

        return data

    def _generate_masks(self, image: np.ndarray) -> MaskData:
        """
        Collect the mask data of all image crops and merge duplicates.

        The returned MaskData has been converted with to_numpy().
        """
        orig_size = image.shape[:2]
        crop_boxes, layer_idxs = generate_crop_boxes(
            orig_size, self.crop_n_layers, self.crop_overlap_ratio
        )

        # Iterate over image crops
        data = MaskData()
        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
            crop_data = self._process_crop(
                image, crop_box, layer_idx, orig_size)
            data.cat(crop_data)

        # Remove duplicate masks between crops
        if len(crop_boxes) > 1:
            # Prefer masks from smaller crops
            scores = 1 / box_area(data["crop_boxes"])
            scores = scores.to(data["boxes"].device)
            keep_by_nms = batched_nms(
                data["boxes"].float(),
                scores,
                torch.zeros_like(data["boxes"][:, 0]),  # categories
                iou_threshold=self.crop_nms_thresh,
            )
            data.filter(keep_by_nms)

        data.to_numpy()
        return data

    def _process_batch(
        self,
        points: np.ndarray,
        im_size: Tuple[int, ...],
        crop_box: List[int],
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        """
        Predict masks for one batch of point prompts and post-process them.

        The returned MaskData stores the masks as RLEs in the frame of the
        original image, while its "probs" entry holds the Sobel-filtered
        probability maps at the size of the crop.
        """
        orig_h, orig_w = orig_size

        # Run model on this batch
        transformed_points = self.predictor.transform.apply_coords(
            points, im_size)
        in_points = torch.as_tensor(
            transformed_points, device=self.predictor.device)
        in_labels = torch.ones(
            in_points.shape[0], dtype=torch.int, device=in_points.device)
        # Logits are requested so that probability maps can be derived below.
        masks, iou_preds, _ = self.predictor.predict_torch(
            in_points[:, None, :],
            in_labels[:, None],
            multimask_output=True,
            return_logits=True,
        )

        # Serialize predictions and store in MaskData
        data = MaskData(
            masks=masks.flatten(0, 1),
            iou_preds=iou_preds.flatten(0, 1),
            points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
        )
        del masks

        # Optionally filter by predicted IoU
        if self.pred_iou_thresh_filtering and self.pred_iou_thresh > 0.0:
            keep_mask = data["iou_preds"] > self.pred_iou_thresh
            data.filter(keep_mask)

        # Calculate stability score
        data["stability_score"] = calculate_stability_score(
            data["masks"],
            self.predictor.model.mask_threshold,
            self.stability_score_offset,
        )

        # Optionally filter by stability score
        if self.stability_score_thresh_filtering and \
                self.stability_score_thresh > 0.0:
            keep_mask = data["stability_score"] >= self.stability_score_thresh
            data.filter(keep_mask)

        # Derive probability maps, then threshold masks and calculate boxes
        data["probs"] = batched_mask_to_prob(data["masks"])
        data["masks"] = data["masks"] > self.predictor.model.mask_threshold
        data["boxes"] = batched_mask_to_box(data["masks"])

        # Filter boxes that touch crop boundaries
        keep_mask = ~is_box_near_crop_edge(
            data["boxes"], crop_box, [0, 0, orig_w, orig_h])
        if not torch.all(keep_mask):
            data.filter(keep_mask)

        # filter by nms
        if self.nms_threshold > 0.0:
            keep_by_nms = batched_nms(
                data["boxes"].float(),
                data["iou_preds"],
                torch.zeros_like(data["boxes"][:, 0]),  # categories
                iou_threshold=self.nms_threshold,
            )
            data.filter(keep_by_nms)

        # apply sobel filter for probability map
        data["probs"] = batched_sobel_filter(data["probs"], data["masks"],
                                             bzp=self.bzp)

        # Compress to RLE
        data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
        data["rles"] = mask_to_rle_pytorch(data["masks"])
        del data["masks"]

        return data