# Source code for pyjacket.graphtools.skeleton

"""A skeleton is a grid whose features are one pixel wide."""

import numpy as np
from scipy.ndimage import convolve
import numpy as np

# Library of 3x3 binary neighbourhoods that is_branch_point() accepts as
# branch points. The resulting array has shape (26, 3, 3): one 3x3 template
# per entry. In every template the centre pixel is set and three or four of
# its eight neighbours are set. A candidate pixel's 3x3 window must equal one
# of these templates exactly (element for element) to count as a branch point.
BRANCH_PATTERNS = np.array([
    [[0, 1, 0],
     [1, 1, 1],
     [0, 0, 0]],

    [[1, 0, 1,],
     [0, 1, 0,],
     [1, 0, 0,]],

    [[1, 0, 1],
     [0, 1, 0],
     [0, 1, 0]],

    [[0, 1, 0],
     [1, 1, 0],
     [0, 0, 1]],

    [[0, 0, 1],
     [1, 1, 1],
     [0, 1, 0]],

    [[1, 0, 0],
     [1, 1, 1],
     [0, 1, 0]],

    [[0, 1, 0],
     [1, 1, 0],
     [0, 1, 0]],

    [[0, 0, 0],
     [1, 1, 1],
     [0, 1, 0]],

    [[0, 1, 0],
     [0, 1, 1],
     [0, 1, 0]],

    [[1, 0, 0],
     [0, 1, 0],
     [1, 0, 1]],

    [[0, 0, 1],
     [0, 1, 0],
     [1, 0, 1]],

    [[1, 0, 1],
     [0, 1, 0],
     [0, 0, 1]],

    [[1, 0, 0],
     [0, 1, 1],
     [1, 0, 0]],

    [[0, 1, 0],
     [0, 1, 0],
     [1, 0, 1]],

    [[0, 0, 1],
     [1, 1, 0],
     [0, 0, 1]],

    [[0, 0, 1],
     [1, 1, 0],
     [0, 1, 0]],

    [[1, 0, 0],
     [0, 1, 1],
     [0, 1, 0]],

    [[0, 1, 0],
     [0, 1, 1],
     [1, 0, 0]],

    [[1, 1, 0],
     [0, 1, 1],
     [0, 1, 0]],

    [[0, 1, 0],
     [1, 1, 1],
     [1, 0, 0]],

    [[0, 1, 0],
     [1, 1, 0],
     [0, 1, 1]],

    [[0, 1, 0],
     [0, 1, 1],
     [1, 1, 0]],

    [[0, 1, 0],
     [1, 1, 1],
     [0, 0, 1]],

    [[0, 1, 1],
     [1, 1, 0],
     [0, 1, 0]],

    # Four-way junction along the rows/columns ("+" shape).
    [[0, 1, 0],
     [1, 1, 1],
     [0, 1, 0]],

    # Four-way junction along the diagonals ("x" shape).
    [[1, 0, 1],
     [0, 1, 0],
     [1, 0, 1]],
    
    ])

# 3x3 kernel that weights each of the eight surrounding pixels with 1 and the
# centre pixel with 0, so convolving a 0/1 image with it yields, per pixel,
# the number of set neighbours.
NEIGHBOR_KERNEL = np.ones((3, 3), dtype=int)
NEIGHBOR_KERNEL[1, 1] = 0

def critical_points(skeleton: np.ndarray):
    """Find the end points and branch points of a skeleton image.

    Parameters
    ----------
    skeleton : np.ndarray
        2-D image of dtype ``np.uint8``. Foreground pixels are assumed to be
        stored as 1 and background as 0: the neighbour count, the bitwise
        ``&`` with a boolean mask, and the exact template comparison in
        ``is_branch_point`` all rely on that encoding (a 0/255 image will not
        produce meaningful results).

    Returns
    -------
    end_points : np.ndarray
        Array of shape ``(N, 2)`` holding the ``(y, x)`` index of every set
        pixel that has exactly one set neighbour among its eight neighbours.
    branch_points : list of np.ndarray
        One ``(y, x)`` index array per set pixel that has at least three set
        neighbours and whose 3x3 neighbourhood is accepted by
        ``is_branch_point``.

    Raises
    ------
    ValueError
        If ``skeleton.dtype`` is not ``np.uint8``.
    """
    # Raise explicitly instead of using `assert`: assertions are stripped when
    # Python runs with -O, and `assert cond, ValueError(...)` would raise an
    # AssertionError rather than the intended ValueError.
    if skeleton.dtype != np.uint8:
        raise ValueError('Skeleton must be np.uint8')

    # Per-pixel count of set neighbours; everything outside the image border
    # is treated as background (constant padding with 0).
    neighbor_count = convolve(skeleton, NEIGHBOR_KERNEL, mode="constant", cval=0)

    # A skeleton pixel with a single neighbour terminates a line.
    end_points = np.argwhere(skeleton & (neighbor_count == 1))

    # Three or more neighbours is necessary but not sufficient for a branch
    # point, so each candidate is confirmed against the known templates.
    branch_candidates = np.argwhere(skeleton & (neighbor_count >= 3))
    branch_points = [yx for yx in branch_candidates if is_branch_point(skeleton, *yx)]

    return end_points, branch_points
def is_branch_point(skeleton: np.ndarray, y, x):
    """Tell whether the pixel at ``(y, x)`` matches a known branch template.

    The 3x3 neighbourhood centred on ``(y, x)`` is copied into a zero-filled
    window (so parts that fall outside the image read as 0) and compared
    against every entry of ``BRANCH_PATTERNS``.

    Parameters
    ----------
    skeleton : np.ndarray
        2-D skeleton image.
    y, x : int
        Row and column index of the pixel to test.

    Returns
    -------
    bool
        ``True`` if the window equals one of the templates exactly,
        ``False`` otherwise.
    """
    # Clip the neighbourhood to the image bounds.
    top, bottom = max(0, y - 1), min(skeleton.shape[0], y + 2)
    left, right = max(0, x - 1), min(skeleton.shape[1], x + 2)

    # Where the clipped region lands inside the 3x3 window (centre is [1, 1]).
    rows = slice(1 - (y - top), 1 + (bottom - y))
    cols = slice(1 - (x - left), 1 + (right - x))

    window = np.zeros((3, 3), dtype=skeleton.dtype)
    window[rows, cols] = skeleton[top:bottom, left:right]

    # Stop at the first template that matches the window element for element.
    for pattern in BRANCH_PATTERNS:
        if np.array_equal(pattern, window):
            return True
    return False