In this post, we’ll see a way to segment nuclei in a confocal microscopy image stack using skimage.
First import all we need¶
As I go on with the analysis, I usually add all the import statements to the first cell, rather than having them spread across the notebook. Once development is done, this will go into a script, and those imports will already be sorted out.
I’m using the tifffile plugin (see http://www.lfd.uci.edu/~gohlke/code/tifffile.py.html) for the I/O.
import numpy as np
import matplotlib.pyplot as plt
import skimage.io as io
io.use_plugin('tifffile')
from skimage.filter import threshold_otsu, threshold_adaptive, rank
from skimage.morphology import label
from skimage.measure import regionprops
from skimage.feature import peak_local_max
from scipy import ndimage
from skimage.morphology import disk, watershed
import pandas as pd
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.spatial import distance as dist
import scipy.cluster.hierarchy as hier
Load the image¶
Also, we set the pixel sizes for later. In a wonderful world, this would be an OME-XML file, and we would read the metadata, carefully collected by the biologist who acquired the image, directly from the file. Of course this is rarely the case.
image_stack = io.imread('../files/test_nuclei_stack.tif')
z_size, x_size, y_size = image_stack.shape
z_scale = 1.5 # µm per plane
xy_scale = 0.71 # µm per pixel
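As a minimal sketch of how these two scales can be used later on, the helper below converts a `(z, x, y)` voxel index into micrometers. The function name `to_microns` is my own and purely illustrative; the scale values are the ones set above.

```python
import numpy as np

# Voxel sizes, same values as set above
z_scale = 1.5    # µm per plane
xy_scale = 0.71  # µm per pixel

def to_microns(zxy_index):
    """Convert (z, x, y) voxel indices to micrometers (hypothetical helper)."""
    scales = np.array([z_scale, xy_scale, xy_scale])
    return np.asarray(zxy_index, dtype=float) * scales

position_um = to_microns([10, 100, 200])
```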
Look at the images¶
nrows = int(np.ceil(np.sqrt(z_size)))
ncols = int(z_size // nrows + 1)
fig, axes = plt.subplots(nrows, ncols, figsize=(3*ncols, 3*nrows))
for n in range(z_size):
    i = n // ncols
    j = n % ncols
    axes[i, j].imshow(image_stack[n, ...],
                      interpolation='nearest', cmap='gray')
## Remove empty plots
for ax in axes.ravel():
    if not(len(ax.images)):
        fig.delaxes(ax)
fig.tight_layout()
Any analysis should start with a look at the histogram¶
fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(image_stack.flatten(), log=True,
        bins=4096, range=(0, 4096))
_ = ax.set_title('Min value: %i \n'
                 'Max value: %i \n'
                 'Image shape: %s \n'
                 % (image_stack.min(),
                    image_stack.max(),
                    image_stack.shape))
Some brief comments on the histogram above. We see that the full 12-bit dynamic range is occupied, which is good. Yet the much higher peak at the maximum value (4095) tells us that some of the signal is saturated, which is not so good; we’ll see why later on. Lastly, the background doesn’t look like the expected background noise for fluorescence imagery, which should follow a Poisson distribution, sketched here with the skewed form $P(I)\,dI = \frac{I}{\sigma_I^2} \exp\left(-\frac{I^2}{\sigma_I^2}\right) dI$ (strictly speaking a Rayleigh-shaped curve rather than a true Poisson law), like the left graph below. What we see above is the result of an automated background correction, like in the right graph below. So we don’t really deal with raw images…
intensities = np.arange(0, 4096)
background = 1000
dI = 1
proba = (intensities/background**2) * np.exp( - intensities**2 / background**2) * dI
proba /= proba.sum()
n_pixels = 512*512
proba *= n_pixels
proba = proba.astype(int)
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(9, 4))
ax0.step(intensities, proba, 'k')
ax0.set_title('Full histogram')
ax0.set_ylabel('Bin count')
ax0.set_xlabel('Intensity')
ax1.step(intensities[background:], proba[background:], 'k')
ax1.set_title('Background corrected histogram')
ax1.set_xlabel('Intensity')
fig.tight_layout()
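The saturation seen in the histogram can also be put into a number. Here is a small sketch, run on a synthetic 12-bit stack rather than on the data from this post; the function name `saturated_fraction` and the synthetic array are assumptions made for the example.

```python
import numpy as np

def saturated_fraction(stack, max_value=4095):
    """Fraction of pixels at (or above) the detector maximum."""
    stack = np.asarray(stack)
    return np.count_nonzero(stack >= max_value) / stack.size

# Synthetic stand-in for a 12-bit stack (not the data from the post)
rng = np.random.default_rng(0)
demo = rng.integers(0, 4096, size=(4, 64, 64), dtype=np.uint16)
demo[:, :8, :8] = 4095  # force a saturated corner on every plane
frac = saturated_fraction(demo)
```

On a real stack, one would call `saturated_fraction(image_stack)`; a value well above zero is a hint that intensity-based measurements inside the nuclei should be taken with care.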
Detection parameters¶
We usually need some settings, derived from what we know of the objects we are trying to segment. We’ll also normally do some kind of smoothing, which is specified as the size in pixels of the filter kernel. So here we go:
smooth_size = 5 # pixels
min_radius = 4
max_radius = 20
Overview of the segmentation strategy:¶
In a nutshell, the idea is to (1) threshold, then (2) label the regions above threshold, then (3) segment the labeled regions. We’ll do this on each image plane. After that, we’ll cluster the resulting objects across the $z$ axis.
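To make step (1) more concrete, here is a NumPy-only sketch of Otsu's method, one common way to pick a global threshold: it chooses the value that maximizes the variance between the "background" and "foreground" classes. This is an illustration written for this post under that assumption, not the skimage implementation; the name `otsu_threshold` and the synthetic pixel values are hypothetical.

```python
import numpy as np

def otsu_threshold(image, nbins=256):
    """Return the threshold maximizing the between-class variance."""
    counts, edges = np.histogram(np.ravel(image), bins=nbins)
    centers = 0.5 * (edges[:-1] + edges[1:])
    counts = counts.astype(float)
    weighted = counts * centers
    # Class weights below / above each candidate split
    w0 = np.cumsum(counts)
    w1 = np.cumsum(counts[::-1])[::-1]
    # Class means (the small floor guards against empty classes)
    m0 = np.cumsum(weighted) / np.maximum(w0, 1e-12)
    m1 = np.cumsum(weighted[::-1])[::-1] / np.maximum(w1, 1e-12)
    # Between-class variance for a split between bin i and bin i + 1
    var_between = w0[:-1] * w1[1:] * (m0[:-1] - m1[1:]) ** 2
    return centers[np.argmax(var_between)]

# Synthetic bimodal intensities: dim background and bright objects
rng = np.random.default_rng(1)
pixels = np.concatenate([rng.normal(50, 5, 5000), rng.normal(200, 5, 5000)])
t = otsu_threshold(pixels)
```

With two well-separated populations like these, the returned threshold falls somewhere in the gap between them.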
Per plane segmentation¶
Of course this is always the difficult part. Here we use the threshold_otsu function from skimage.
We compute an overall threshold over a maximum intensity projection along the $z$ axis, and use this threshold to label each plane individually. To ease the segmentation, we first filter each plane: smoothing with a median filter, then a local contrast enhancement.
## Computing threshold on the maximum intensity projection with `threshold_otsu`
max_int_proj = image_stack.max(axis=0)
thresh_global = threshold_otsu(max_int_proj)
smoothed_stack = np.zeros_like(image_stack)
labeled_stack = smoothed_stack.copy()
## Labeling for each z plane:
for z, frame in enumerate(image_stack):
    smoothed = rank.median(frame, disk(smooth_size))
    smoothed = rank.enhance_contrast(smoothed, disk(smooth_size))
    smoothed_stack[z] = smoothed
    im_max = smoothed.max()
    thresh = thresh_global
    # thresh = threshold_otsu(smoothed)
    if im_max < thresh_global:
        labeled_stack[z] = np.zeros(smoothed.shape, dtype=np.int32)
    else:
        binary = smoothed > thresh
        #binary = threshold_adaptive(smoothed, block_size=smooth_size)
        distance = ndimage.distance_transform_edt(binary)
        local_maxi = peak_local_max(distance, min_distance=2*min_radius,
                                    indices=False, labels=smoothed)
        markers = ndimage.label(local_maxi)[0]
        labeled_stack[z] = watershed(-distance, markers, mask=binary)
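The marker step is what lets the watershed split touching nuclei, so here is a small sketch of it on a synthetic mask made of two overlapping disks. To keep it dependent on NumPy and SciPy only, I swap `peak_local_max` for a plain maximum filter; the disk geometry and the `min_distance` value are assumptions chosen for the example.

```python
import numpy as np
from scipy import ndimage

# Synthetic binary mask: two overlapping disks (illustrative only)
yy, xx = np.mgrid[:60, :100]
binary = (((yy - 30)**2 + (xx - 35)**2 < 18**2) |
          ((yy - 30)**2 + (xx - 65)**2 < 18**2))

# Distance to the nearest background pixel: largest near each disk center
distance = ndimage.distance_transform_edt(binary)

# A pixel is kept as a peak if it equals the maximum of its neighborhood
min_distance = 8  # hypothetical value, playing the role of 2 * min_radius
footprint = np.ones((2 * min_distance + 1,) * 2, dtype=bool)
local_max = ndimage.maximum_filter(distance, footprint=footprint)
peaks = (distance == local_max) & binary

# One marker per connected group of peak pixels
markers, n_markers = ndimage.label(peaks)
```

Even though the two disks form a single connected region, the distance map has one peak per disk, so each gets its own marker for the watershed to grow from.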
Result of the segmentation¶
fig, axes = plt.subplots(nrows, ncols*2, figsize=(3*ncols, 1.5*nrows))
for z in range(z_size):
    i = z // ncols
    j = z % ncols * 2
    axes[i, j].imshow(smoothed_stack[z, ...], interpolation='nearest', cmap='gray')
    axes[i, j+1].imshow(labeled_stack[z, ...], interpolation='nearest', cmap='Dark2')
    axes[i, j].set_xticks([])
    axes[i, j].set_yticks([])
    axes[i, j+1].set_xticks([])
    axes[i, j+1].set_yticks([])
## Remove empty plots
for ax in axes.ravel():
    if not(len(ax.images)):
        fig.delaxes(ax)
fig.tight_layout()
A closer look:
fig, axes = plt.subplots(1, 2, figsize=(8, 16))
z = z_size // 2
axes[0].imshow(smoothed_stack[z, ...], interpolation='nearest', cmap='gray')
axes[1].imshow(labeled_stack[z, ...], interpolation='nearest', cmap='Dark2')
Computing the properties of the labeled regions:¶
We use the handy regionprops function from skimage.measure. For convenience, we store the computed properties in a pandas DataFrame object (this is particularly useful if you have more images or more timepoints, and want to manipulate the collected data later on).
properties = []
columns = ('x', 'y', 'z', 'I', 'w')
indices = []
for z, frame in enumerate(labeled_stack):
    f_prop = regionprops(frame.astype(int),
                         intensity_image=image_stack[z])
    for d in f_prop:
        # Radius of the disk with the same area as the region
        radius = (d.area / np.pi)**0.5
        if (min_radius < radius < max_radius):
            properties.append([d.weighted_centroid[0],
                               d.weighted_centroid[1],
                               z, d.mean_intensity * d.area,
                               radius])
            indices.append(d.label)
if not len(indices):
    all_props = pd.DataFrame([], index=[])
indices = pd.Index(indices, name='label')
properties = pd.DataFrame(properties, index=indices, columns=columns)
properties['I'] /= properties['I'].max()
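To show what the columns hold, here is a sketch that computes the same three quantities (equivalent radius, intensity-weighted centroid, summed intensity) by hand on a synthetic disk, using NumPy and `scipy.ndimage.center_of_mass` in place of `regionprops`. The disk and its flat intensity are assumptions made for the example.

```python
import numpy as np
from scipy import ndimage

# Synthetic labeled plane: one disk of radius 10 px centered at (32, 20)
yy, xx = np.mgrid[:64, :64]
labels = ((yy - 32)**2 + (xx - 20)**2 <= 10**2).astype(int)
intensity = labels.astype(float)  # flat intensity inside the disk

mask = labels == 1
area = np.count_nonzero(mask)                 # pixel count of the region
radius = np.sqrt(area / np.pi)                # radius of the equal-area disk
centroid = ndimage.center_of_mass(intensity, labels, 1)  # intensity weighted
total_intensity = intensity[mask].sum()       # same as mean_intensity * area
```

The recovered radius is close to 10 pixels and the centroid sits on the disk center, which is what the `w`, `x` and `y` columns store for each detected region.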
Here is what we collected:
properties.head()
Let’s look at what we collected, plotted over the segmented image:
fig, axes = plt.subplots(nrows, ncols, figsize=(3*ncols, 3*nrows))
for z in range(z_size):
    plane_props = properties[properties['z'] == z]
    if not(plane_props.shape[0]):
        continue
    i = z // ncols
    j = z % ncols
    axes[i, j].imshow(labeled_stack[z, ...],
                      interpolation='nearest', cmap='Dark2')
    axes[i, j].set_xticks([])
    axes[i, j].set_yticks([])
    x_lim = axes[i, j].get_xlim()
    y_lim = axes[i, j].get_ylim()
    axes[i, j].scatter(plane_props['y'], plane_props['x'],
                       s=plane_props['I']*200, alpha=0.4)
    axes[i, j].scatter(plane_props['y'], plane_props['x'],
                       s=40, marker='+', alpha=0.4)
    axes[i, j].set_xlim(x_lim)
    axes[i, j].set_ylim(y_lim)
## Remove empty plots
for ax in axes.ravel():
    if not(len(ax.images)):
        fig.delaxes(ax)
fig.tight_layout()
Below is another way to look at the segmented nuclei positions, this time over intensity projections of the original image.
fig = plt.figure(figsize=(12, 12))
colors = plt.cm.jet(properties.index.astype(np.int32))
# xy projection:
ax_xy = fig.add_subplot(111)
ax_xy.imshow(image_stack.max(axis=0), cmap='gray')
ax_xy.scatter(properties['y'],
properties['x'], c=colors, alpha=1)
divider = make_axes_locatable(ax_xy)
ax_zx = divider.append_axes("top", 2, pad=0.2, sharex=ax_xy)
ax_zx.imshow(image_stack.max(axis=1), aspect=z_scale/xy_scale, cmap='gray')
ax_zx.scatter(properties['y'],
properties['z'], c=colors, alpha=1)
ax_yz = divider.append_axes("right", 2, pad=0.2, sharey=ax_xy)
ax_yz.imshow(image_stack.max(axis=2).T, aspect=xy_scale/z_scale, cmap='gray')
ax_yz.scatter(properties['z'],
properties['x'], c=colors, alpha=1)
plt.draw()
Clustering¶
Now that we have detected the nuclei in each plane, we want to regroup them across $z$ so that we have only one 3D position for each of the 4 (or is it 5?) nuclei.
We do so by applying a hierarchical clustering, with the scipy.cluster module, to the positions in the $(x, y)$ plane. As a parameter for the clustering, we quite naturally use the maximum radius we defined earlier.
positions = properties[['x', 'y']].copy()
# linkage expects the condensed distance matrix returned by pdist,
# not the square form (which it would treat as raw observations)
dist_mat = dist.pdist(positions.values)
link_mat = hier.linkage(dist_mat)
cluster_idx = hier.fcluster(link_mat, max_radius,
                            criterion='distance')
properties['new_label'] = cluster_idx
properties.set_index('new_label', drop=True, append=False, inplace=True)
properties.index.name = 'label'
properties = properties.sort_index()
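Here is the clustering step in isolation, as a sketch on made-up positions: two nuclei, each detected on three planes at slightly different $(x, y)$ positions. Note that `hier.linkage` takes either raw observations or a condensed distance matrix as returned by `pdist`. The coordinates are hypothetical.

```python
import numpy as np
from scipy.spatial import distance as dist
import scipy.cluster.hierarchy as hier

# Hypothetical (x, y) detections: two nuclei seen on three planes each
positions = np.array([[10.0, 10.0], [11.0, 9.5], [10.5, 10.5],
                      [60.0, 40.0], [61.0, 41.0], [59.5, 40.5]])
max_radius = 20  # same role as the detection parameter above

condensed = dist.pdist(positions)    # condensed pairwise distances
link_mat = hier.linkage(condensed)   # single linkage by default
cluster_idx = hier.fcluster(link_mat, max_radius, criterion='distance')
```

With `criterion='distance'`, detections end up in the same cluster as long as they can be chained together by steps shorter than `max_radius`, so the three detections of each nucleus share one label here.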
Now the detected elements are regrouped by label:
properties.head()
Next we want to get the center of each cluster. To do so we run a weighted average of the positions for each cluster, using the measured intensity as weight.
def df_average(df, weights_column):
    '''Computes the average of each column of a dataframe, weighted
    by the values of the column `weights_column`.
    Parameters:
    -----------
    df: a pandas DataFrame instance
    weights_column: a string, the column name of the weights column
    Returns:
    --------
    values: pandas Series indexed by the column names of `df`,
        holding the weighted average value of each column
    '''
    values = df.copy().iloc[0]
    norm = df[weights_column].sum()
    for col in df.columns:
        try:
            v = (df[col] * df[weights_column]).sum() / norm
        except TypeError:
            # Non-numeric column: keep the first value
            v = df[col].iloc[0]
        values[col] = v
    return values
cell_positions = properties.groupby(level='label').apply(df_average, 'I')
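For all-numeric columns, the same weighted average can be written more compactly with `np.average`. The sketch below is an alternative under that assumption, checked on a tiny made-up table; `weighted_mean` and the values in `df` are hypothetical.

```python
import numpy as np
import pandas as pd

# Hypothetical detections: label, position and normalized intensity
df = pd.DataFrame({'label': [1, 1, 2],
                   'x': [10.0, 12.0, 50.0],
                   'z': [3.0, 4.0, 7.0],
                   'I': [1.0, 3.0, 2.0]}).set_index('label')

def weighted_mean(group, weights_column='I'):
    """Average every column of `group`, weighted by `weights_column`."""
    weights = group[weights_column].to_numpy()
    averaged = np.average(group.to_numpy(), axis=0, weights=weights)
    return pd.Series(averaged, index=group.columns)

centers = df.groupby(level='label').apply(weighted_mean)
```

For label 1, the brighter detection (weight 3) pulls the center toward `x = 12`, giving `x = 11.5` instead of the plain mean of 11.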
So here is what we were looking for:
cell_positions
The final result¶
Let’s plot all that:
fig = plt.figure(figsize=(12, 12))
colors = plt.cm.jet(properties.index.astype(np.int32))
# xy projection:
ax_xy = fig.add_subplot(111)
ax_xy.imshow(image_stack.max(axis=0), cmap='gray')
ax_xy.scatter(properties['y'],
properties['x'],
c=colors, alpha=0.2)
ax_xy.scatter(cell_positions['y'],
cell_positions['x'],
c='r', s=50, alpha=1.)
divider = make_axes_locatable(ax_xy)
ax_yz = divider.append_axes("top", 2, pad=0.2, sharex=ax_xy)
ax_yz.imshow(image_stack.max(axis=1), aspect=z_scale/xy_scale, cmap='gray')
ax_yz.scatter(properties['y'],
properties['z'],
c=colors, alpha=0.2)
ax_yz.scatter(cell_positions['y'],
cell_positions['z'],
c='r', s=50, alpha=1.)
ax_zx = divider.append_axes("right", 2, pad=0.2, sharey=ax_xy)
ax_zx.imshow(image_stack.max(axis=2).T, aspect=xy_scale/z_scale, cmap='gray')
ax_zx.scatter(properties['z'],
properties['x'],
c=colors, alpha=0.2)
ax_zx.scatter(cell_positions['z'],
cell_positions['x'],
c='r', s=50, alpha=1.)
plt.draw()
A brief final comment: the detection looks fine, although we only have 4 nuclei, and it looks very much like (especially in the right panel) two nuclei were merged in the process. I don’t really see how to avoid that with this method; maybe by searching for two local intensity maxima along the $z$ axis in each cluster (I’ve tried it), but this is not very noise resistant. Maybe a random walker segmentation in 3D could avoid this, but I must say I don’t really see how to implement that strategy. One of the difficulties with a random walker here is that, as the data is saturated, you don’t have well-defined local maxima, but instead big regions at the maximum value.
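For the record, the "two maxima along $z$" idea can be sketched as below. This is not the code I tried at the time, only an illustration on a made-up intensity profile, using `scipy.signal.find_peaks`; the profile and the `prominence` value are assumptions, and on noisy or saturated data the result depends a lot on that setting.

```python
import numpy as np
from scipy.signal import find_peaks

# Hypothetical summed intensity of one cluster on each z plane:
# two bumps suggest two nuclei stacked along z
z = np.arange(20)
profile = np.exp(-(z - 5)**2 / 8.0) + 0.8 * np.exp(-(z - 14)**2 / 8.0)

# `prominence` is an assumed tuning knob that filters out small wiggles
peaks, _ = find_peaks(profile, prominence=0.2)
n_nuclei_guess = len(peaks)
```

A cluster whose profile shows two prominent peaks would then be split at the plane of lowest intensity between them.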