Imitando el algoritmo watershed de ImageJ#

En ImageJ hay un algoritmo llamado “Watershed” que permite dividir objetos densos segmentados dentro de una imagen binaria. Este cuaderno demuestra cómo lograr una operación similar en Python.

El “Watershed” en ImageJ se aplicó a una imagen binaria utilizando esta macro:

open("../BioImageAnalysisNotebooks/data/blobs_otsu.tif");
run("Watershed");
from skimage.io import imread, imshow
import matplotlib.pyplot as plt
import napari_segment_blobs_and_things_with_membranes as nsbatwm
import numpy as np

from scipy import ndimage as ndi
from skimage.feature import peak_local_max
from skimage.filters import gaussian, sobel
from skimage.measure import label
from skimage.segmentation import watershed
from skimage.morphology import binary_opening

El punto de partida para la demostración es una imagen binaria.

binary_image = imread("../../data/blobs_otsu.tif")
imshow(binary_image)
<matplotlib.image.AxesImage at 0x1f8e41b2700>
../_images/1a67ed79a4d4aad6da50c333def145da3c6d63c765ee7abfe2b9596de8270df6.png

Después de aplicar la macro mostrada anteriormente, la imagen resultante en ImageJ se ve así:

binary_watershed_imagej = imread("../../data/blobs_otsu_watershed.tif")
imshow(binary_watershed_imagej)
<matplotlib.image.AxesImage at 0x1f8e4209100>
../_images/86e8366055208bd6e1f9322796080904a9f59e2f2e95f6df24b086204337e884.png

El plugin de Napari napari-segment-blobs-and-things-with-membranes ofrece una función para imitar la funcionalidad de ImageJ.

binary_watershed_nsbatwm = nsbatwm.split_touching_objects(binary_image)
imshow(binary_watershed_nsbatwm)
<matplotlib.image.AxesImage at 0x1f8e42a63d0>
../_images/7b1305d67025f669d96b7d40ee328b1386a1680344b38718b319f651ab255b28.png

Comparando resultados#

Al comparar los resultados, es obvio que no son 100% idénticos.

fig, axs = plt.subplots(1, 2, figsize=(10,10))

axs[0].imshow(binary_watershed_imagej)
axs[0].set_title("ImageJ")
axs[1].imshow(binary_watershed_nsbatwm)
axs[1].set_title("nsbatwm")
Text(0.5, 1.0, 'nsbatwm')
../_images/8f620d03e62c5d9540f7b8417ba99ffa6fd4ce0c330ae762e92e786d90f631e6.png

Ajuste fino de resultados#

Es posible modificar el resultado ajustando el parámetro sigma.

fig, axs = plt.subplots(1, 4, figsize=(10,10))

for i, sigma in enumerate(np.arange(2, 6, 1)):
    result = nsbatwm.split_touching_objects(binary_image, sigma=sigma)
    axs[i].imshow(result)
    axs[i].set_title("sigma="+str(sigma))
../_images/2b270e8a59cbce9ef7d0f05de36435c9cd49fb2b6085dc659caed85cc9d956a2.png

¿Cómo funciona?#

Internamente, el algoritmo watershed de ImageJ utiliza una imagen de distancia y detección de puntos. El siguiente código intenta replicar el resultado.

Nuevamente, comenzamos con la imagen binaria.

imshow(binary_image)
<matplotlib.image.AxesImage at 0x1f8e55b87f0>
../_images/1a67ed79a4d4aad6da50c333def145da3c6d63c765ee7abfe2b9596de8270df6.png

El primer paso es producir una imagen de distancia.

distance = ndi.distance_transform_edt(binary_image)
imshow(distance)
C:\Users\haase\mambaforge\envs\bio39\lib\site-packages\skimage\io\_plugins\matplotlib_plugin.py:150: UserWarning: Float image out of standard range; displaying image with stretched contrast.
  lo, hi, cmap = _get_display_range(image)
<matplotlib.image.AxesImage at 0x1f8e55f17c0>
../_images/a915b049790f4c508c064e452983cc5e39f2a9e725651d71344f68f30352d00f.png

Para evitar objetos divididos muy pequeños, difuminamos la imagen de distancia utilizando el parámetro sigma.

sigma = 3.5

blurred_distance = gaussian(distance, sigma=sigma)
imshow(blurred_distance)
<matplotlib.image.AxesImage at 0x1f8e56ec5e0>
../_images/cf45e136aa34b9b8b861abc84d50895f757c31ed010993bedc83c9fbe7a543e1.png

Dentro de esta imagen difuminada, buscamos máximos locales y los recibimos como una lista de coordenadas.

fp = np.ones((3,) * binary_image.ndim)
coords = peak_local_max(blurred_distance, footprint=fp, labels=binary_image)

# mostramos solo los primeros 5
coords[:5]
array([[  8, 254],
       [ 97,   1],
       [ 10, 108],
       [230, 180],
       [182, 179]], dtype=int64)

A continuación, escribimos estos máximos en una nueva imagen y los etiquetamos.

mask = np.zeros(distance.shape, dtype=bool)
mask[tuple(coords.T)] = True
markers = label(mask)
imshow(markers, cmap='jet')
C:\Users\haase\mambaforge\envs\bio39\lib\site-packages\skimage\io\_plugins\matplotlib_plugin.py:150: UserWarning: Low image data range; displaying image with stretched contrast.
  lo, hi, cmap = _get_display_range(image)
<matplotlib.image.AxesImage at 0x1f8e59c1c40>
../_images/eb86a0069cb107c60d8e3a783746b77c651dba9b4359d01bd5c84455c02d4df6.png

Luego, aplicamos el algoritmo Watershed de scikit-image (ejemplo). Toma una imagen de distancia y una imagen de etiquetas como entrada. La entrada opcional es la binary_image para limitar la propagación de las etiquetas demasiado lejos.

labels = watershed(-blurred_distance, markers, mask=binary_image)
imshow(labels, cmap='jet')
<matplotlib.image.AxesImage at 0x1f8e5a97f40>
../_images/9fd2db3a41881adc4680b7461e9d482e8ff369a2ab220a46293b725e5d81cd07.png

Para crear una imagen binaria nuevamente como lo hace ImageJ, ahora identificamos los bordes entre las etiquetas.

# identificar bordes que cortan etiquetas
edges_labels = sobel(labels)
edges_binary = sobel(binary_image)

edges = np.logical_xor(edges_labels != 0, edges_binary != 0)
imshow(edges)
<matplotlib.image.AxesImage at 0x1f8e5aaea00>
../_images/fef6ed859fb8cd3a99dc0c12fe342189109bcefdabbe3ce0a2d0abe0920e9444.png

Luego, restamos esos bordes de la binary_image original.

almost = np.logical_not(edges) * binary_image
imshow(almost)
<matplotlib.image.AxesImage at 0x1f8e55c8610>
../_images/5b617b10b44acd24fb80a9af3c5e1d6870481916b0d727eab774ad6adb6e0020.png

Como este resultado aún no es perfecto, aplicamos una apertura binaria.

result = binary_opening(almost)
imshow(result)
<matplotlib.image.AxesImage at 0x1f8e3b81e50>
../_images/7b1305d67025f669d96b7d40ee328b1386a1680344b38718b319f651ab255b28.png