
In this post, we’ll see a way to segment nuclei in a confocal microscopy image stack using skimage.

First import all we need

As the analysis progresses, I usually gather all the import statements in the first cell rather than spreading them across the notebook. Once development is done, this will go into a script, and the imports will already be sorted out.

I’m using the tifffile plugin (see http://www.lfd.uci.edu/~gohlke/code/tifffile.py.html) for the I/O.

In [3]:
import numpy as np
import matplotlib.pyplot as plt

import skimage.io as io
io.use_plugin('tifffile')

from skimage.filter import threshold_otsu, threshold_adaptive, rank
from skimage.morphology import label
from skimage.measure import regionprops
from skimage.feature import peak_local_max
from scipy import ndimage
from skimage.morphology import disk, watershed
import pandas as pd

from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.spatial import distance as dist
import scipy.cluster.hierarchy as hier

Load the image

We also set the pixel sizes for later. In a wonderful world, this would be an OME-XML file and we would read the metadata, carefully collected by the biologist who acquired the image, directly from the file. Of course this is rarely the case.

In [5]:
image_stack = io.imread('../files/test_nuclei_stack.tif')
z_size, x_size, y_size = image_stack.shape

z_scale = 1.5 # µm per plane
xy_scale = 0.71 # µm per pixel
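
These two scales are what turn pixel and plane indices into physical positions. A minimal sketch of that conversion, assuming the values above; the helper name `to_microns` and the sample coordinates are hypothetical and not part of the original notebook:

```python
# Voxel sizes as set in the post (µm).
Z_SCALE = 1.5    # µm per plane
XY_SCALE = 0.71  # µm per pixel


def to_microns(x_pix, y_pix, z_plane, xy_scale=XY_SCALE, z_scale=Z_SCALE):
    """Convert (x, y) pixel indices and a plane index to micrometres."""
    return (x_pix * xy_scale, y_pix * xy_scale, z_plane * z_scale)


# Hypothetical position: pixel (100, 50) on plane 10.
x_um, y_um, z_um = to_microns(100, 50, 10)
```

Because the $z$ step is roughly twice the in-plane pixel size, distances measured in raw index units along $z$ are not comparable to distances in $x$ or $y$ until they are scaled this way.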

Look at the images

In [7]:
nrows = int(np.ceil(np.sqrt(z_size)))
ncols = int(z_size // nrows + 1)

fig, axes = plt.subplots(nrows, ncols, figsize=(3*ncols, 3*nrows))
for n in range(z_size):
    i = n // ncols
    j = n % ncols
    axes[i, j].imshow(image_stack[n, ...],
                      interpolation='nearest', cmap='gray')
    
## Remove empty plots 
for ax in axes.ravel():
    if not(len(ax.images)):
        fig.delaxes(ax)
fig.tight_layout()

Any analysis should start with a look at the histogram

In [8]:
fig, ax = plt.subplots(figsize=(8, 4))

ax.hist(image_stack.flatten(), log=True,
        bins=4096, range=(0, 4096))

_ = ax.set_title('Min value: %i \n'
                 'Max value: %i \n'
                 'Image shape: %s \n'
                 % (image_stack.min(),
                    image_stack.max(),
                    image_stack.shape))

Some brief comments on the histogram above. We see that the full 12-bit dynamic range is occupied, which is good. Yet the much higher peak at the maximum value (4095) tells us that some of the signal is saturated, which is not so good; we’ll see why later on. Lastly, the background doesn’t look like the expected background noise for fluorescence imagery, which should follow Poisson statistics, modelled here as $P(I)\,dI = \frac{I}{\sigma_I^2} \exp\left(-I^2/\sigma_I^2\right) dI$ (a Rayleigh-shaped curve), like the left graph below. What we see above is the result of an automated background correction, like in the right graph below. So we are not really dealing with raw images…

In [11]:
intensities = np.arange(0, 4096)
background = 1000
dI = 1
proba = (intensities/background**2) * np.exp( - intensities**2 / background**2) * dI
proba /= proba.sum()
n_pixels = 512*512
proba *= n_pixels
proba = proba.astype(int)

fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(9, 4))
ax0.step(intensities, proba, 'k')
ax0.set_title('Full histogram')
ax0.set_ylabel('Bin count')
ax0.set_xlabel('Intensity')

ax1.step(intensities[background:], proba[background:], 'k')
ax1.set_title('Background corrected histogram')
ax1.set_xlabel('Intensity')


fig.tight_layout()
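
The saturation noted in the histogram comments can also be quantified rather than read off the plot. A minimal sketch, assuming 12-bit data as stated above; the function name `saturated_fraction` and the toy stack are hypothetical:

```python
import numpy as np


def saturated_fraction(stack, bit_depth=12):
    """Return the fraction of pixels sitting at the maximum value of the bit depth."""
    max_value = 2 ** bit_depth - 1
    return np.count_nonzero(stack == max_value) / stack.size


# Hypothetical toy stack: 2 planes of 4 x 4 pixels, 4 of them saturated.
toy_stack = np.full((2, 4, 4), 100, dtype=np.uint16)
toy_stack[0, 0, :] = 4095
fraction = saturated_fraction(toy_stack)
```

On the real data one would call `saturated_fraction(image_stack)`; a large fraction would mean that intensity-based measurements underestimate the brightest objects.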

Detection parameters

We usually need some settings, derived from what we know of the objects we are trying to segment. We’ll also normally do some kind of smoothing, specified by the size in pixels of the filter kernel. So here we go:

In [12]:
smooth_size = 5 # pixels
min_radius = 4
max_radius = 20

Overview of the segmentation strategy:

In a nutshell, the idea is to (1) threshold, then (2) label the regions above threshold, then (3) segment the labeled regions. We’ll do this on each image plane. After that, we’ll cluster the resulting objects along the $z$ axis.
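
Steps (1) and (2) can be tried on a synthetic image with scipy.ndimage alone. This is only a sketch: the toy image, the fixed 0.5 cut-off and the variable names are assumptions made for illustration, and the watershed of step (3) is left out.

```python
import numpy as np
from scipy import ndimage

# Hypothetical toy image: two bright disks on a dark background.
yy, xx = np.mgrid[0:40, 0:40]
toy = np.zeros((40, 40))
toy[(yy - 10) ** 2 + (xx - 10) ** 2 <= 25] = 1.0
toy[(yy - 28) ** 2 + (xx - 30) ** 2 <= 36] = 1.0

# (1) threshold: keep the pixels brighter than a fixed cut-off
binary = toy > 0.5

# (2) label: give each connected region its own integer id
labels, n_regions = ndimage.label(binary)

# Distance to the background; its local maxima are the seeds that a
# watershed would use in step (3) to split touching objects.
distance = ndimage.distance_transform_edt(binary)
```

Here the two disks do not touch, so labeling alone separates them; the distance transform only matters once objects overlap.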

Per plane segmentation

Of course this is always the difficult part. Here we use the threshold_otsu function from skimage. We compute an overall threshold on a maximum intensity projection along the $z$ axis, and use this threshold to label each plane individually. To ease the segmentation, we first smooth each plane with a median filter and then apply a local contrast enhancement.
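
For intuition about what an Otsu threshold is, here is a minimal NumPy sketch of the criterion: pick the histogram split that maximises the variance between the two resulting classes. It is an illustration, not the skimage implementation; the function name `otsu_threshold`, the bin count and the toy data are assumptions.

```python
import numpy as np


def otsu_threshold(image, nbins=256):
    """Return the bin center that maximises the between-class variance."""
    counts, edges = np.histogram(image.ravel(), bins=nbins)
    centers = (edges[:-1] + edges[1:]) / 2
    counts = counts.astype(float)

    # Class weights for every possible split of the histogram.
    weight_low = np.cumsum(counts)
    weight_high = np.cumsum(counts[::-1])[::-1]

    # Class means for every possible split.
    mean_low = np.cumsum(counts * centers) / weight_low
    mean_high = (np.cumsum((counts * centers)[::-1]) / weight_high[::-1])[::-1]

    # Between-class variance for a split between bin i and bin i + 1.
    variance = weight_low[:-1] * weight_high[1:] * (mean_low[:-1] - mean_high[1:]) ** 2
    return centers[np.argmax(variance)]


# Hypothetical bimodal data: a dim background and a brighter foreground.
low = np.linspace(80, 120, 500)
high = np.linspace(900, 1100, 100)
toy = np.concatenate([low, high])
threshold = otsu_threshold(toy)
```

The returned value falls in the gap between the two populations, which is why a single global threshold computed on the maximum intensity projection can be reused on every plane.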

In [13]:
## Computing threshold on the maximum intensity projection with `threshold_otsu`
max_int_proj = image_stack.max(axis=0)
thresh_global = threshold_otsu(max_int_proj)

smoothed_stack = np.zeros_like(image_stack)
labeled_stack = smoothed_stack.copy()
## Labeling for each z plane:
for z, frame in enumerate(image_stack):
    smoothed = rank.median(frame, disk(smooth_size))
    smoothed = rank.enhance_contrast(smoothed, disk(smooth_size))
    smoothed_stack[z] = smoothed
    im_max = smoothed.max()
    thresh = thresh_global
    # thresh = threshold_otsu(smoothed)
    if im_max < thresh_global:
        labeled_stack[z] = np.zeros(smoothed.shape, dtype=np.int32)
    else:
        binary = smoothed > thresh
        #binary = threshold_adaptive(smoothed, block_size=smooth_size)
        
        distance = ndimage.distance_transform_edt(binary)
        local_maxi = peak_local_max(distance, min_distance=2*min_radius,
                                    indices=False, labels=smoothed)
        markers = ndimage.label(local_maxi)[0]
        labeled_stack[z] = watershed(-distance, markers, mask=binary)
/home/guillaume/python3/lib/python3.3/site-packages/scikit_image-0.10dev-py3.3-linux-x86_64.egg/skimage/filter/rank/generic.py:63: UserWarning: Bitdepth of 11 may result in bad rank filter performance due to large number of bins.
  "performance due to large number of bins." % bitdepth)

Result of the segmentation

In [14]:
fig, axes = plt.subplots(nrows, ncols*2, figsize=(3*ncols, 1.5*nrows))

for z in range(z_size):
    i = z // ncols
    j = z % ncols * 2
    axes[i, j].imshow(smoothed_stack[z, ...], interpolation='nearest', cmap='gray')
    axes[i, j+1].imshow(labeled_stack[z, ...], interpolation='nearest', cmap='Dark2')
    
    axes[i, j].set_xticks([])
    axes[i, j].set_yticks([])
    axes[i, j+1].set_xticks([])
    axes[i, j+1].set_yticks([])

## Remove empty plots 
for ax in axes.ravel():
    if not(len(ax.images)):
        fig.delaxes(ax)

        
fig.tight_layout()

A closer look:

In [15]:
fig, axes = plt.subplots(1, 2, figsize=(8, 16))

z = z_size // 2
axes[0].imshow(smoothed_stack[z, ...], interpolation='nearest', cmap='gray')
axes[1].imshow(labeled_stack[z, ...], interpolation='nearest', cmap='Dark2')

Computing the properties of the labeled regions:

We use the handy regionprops function from skimage.measure. For convenience, we store the computed properties in a pandas DataFrame object (this is particularly useful if you have more images or more time points and later want to manipulate the collected data).

In [20]:
properties = []
columns = ('x', 'y', 'z', 'I', 'w')
indices = []
for z, frame in enumerate(labeled_stack):
    f_prop = regionprops(frame.astype(int),
                         intensity_image=image_stack[z])
    for d in f_prop:
        radius = (d.area / np.pi)**0.5
        if (min_radius  < radius < max_radius):
            properties.append([d.weighted_centroid[0],
                               d.weighted_centroid[1],
                               z, d.mean_intensity * d.area,
                               radius])
            indices.append(d.label)
if not len(indices):
    all_props = pd.DataFrame([], index=[])
indices = pd.Index(indices, name='label')
properties = pd.DataFrame(properties, index=indices, columns=columns)
properties['I'] /= properties['I'].max()
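
The size filter in the loop above treats each labeled region as a disk of the same area, so its radius is $\sqrt{A/\pi}$. A small standalone sketch of that test, using the radius bounds from the detection parameters; the helper names and the list of areas are hypothetical:

```python
import math

MIN_RADIUS = 4   # pixels, as in the detection parameters
MAX_RADIUS = 20  # pixels


def equivalent_radius(area):
    """Radius of the disk that has the same area as the region (in pixels)."""
    return math.sqrt(area / math.pi)


def keep_region(area, min_radius=MIN_RADIUS, max_radius=MAX_RADIUS):
    """True if the equivalent radius lies strictly between the two bounds."""
    return min_radius < equivalent_radius(area) < max_radius


# Hypothetical region areas, in pixels.
areas = [30, 106, 600, 1500]
kept = [a for a in areas if keep_region(a)]
```

With these bounds, regions smaller than about 50 pixels or larger than about 1257 pixels are discarded as noise or as merged objects.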

Here is what we collected:

In [21]:
properties.head()
Out[21]:
               x          y  z         I         w
label
6      71.249178  40.507531  5  0.394966  5.808687
1      70.186552  39.377994  6  0.854559  8.156394
2      82.987933  50.992160  6  0.171050  4.583498
1      70.022599  38.560650  7  0.990137  8.758578
3      83.256457  50.622169  7  0.316493  5.262410

5 rows × 5 columns

Let’s look at what we collected, plotted over the segmented image

In [22]:
fig, axes = plt.subplots(nrows, ncols, figsize=(3*ncols, 3*nrows))

for z in range(z_size):
    plane_props = properties[properties['z'] == z]
    if not(plane_props.shape[0]) :
        continue
    i = z // ncols
    j = z % ncols
    axes[i, j].imshow(labeled_stack[z, ...],
                      interpolation='nearest', cmap='Dark2')
    axes[i, j].set_xticks([])
    axes[i, j].set_yticks([])
    x_lim = axes[i, j].get_xlim()
    y_lim = axes[i, j].get_ylim()    
    
    
    axes[i, j].scatter(plane_props['y'], plane_props['x'],  
                       s=plane_props['I']*200, alpha=0.4)
    axes[i, j].scatter(plane_props['y'], plane_props['x'],
                       s=40, marker='+', alpha=0.4)
    axes[i, j].set_xlim(x_lim)
    axes[i, j].set_ylim(y_lim)    
 

## Remove empty plots 
for ax in axes.ravel():
    if not(len(ax.images)):
        fig.delaxes(ax)

        
fig.tight_layout()

Below is another way to look at the segmented nuclei positions, this time over intensity projections of the original image.

In [23]:
fig = plt.figure(figsize=(12, 12))
colors = plt.cm.jet(properties.index.astype(np.int32))

# xy projection:
ax_xy = fig.add_subplot(111)
ax_xy.imshow(image_stack.max(axis=0), cmap='gray')
ax_xy.scatter(properties['y'],
              properties['x'], c=colors, alpha=1)

divider = make_axes_locatable(ax_xy)
ax_zx = divider.append_axes("top", 2, pad=0.2, sharex=ax_xy)
ax_zx.imshow(image_stack.max(axis=1), aspect=z_scale/xy_scale, cmap='gray')
ax_zx.scatter(properties['y'],
              properties['z'], c=colors, alpha=1)
ax_yz = divider.append_axes("right", 2, pad=0.2, sharey=ax_xy)
ax_yz.imshow(image_stack.max(axis=2).T, aspect=xy_scale/z_scale, cmap='gray')
ax_yz.scatter(properties['z'],
              properties['x'], c=colors, alpha=1)
plt.draw()

Clustering

Now that we have detected the nuclei in each plane, we want to regroup them across $z$ so that we have only one 3D position for each of the 4 (or is it 5?) nuclei.

We do so by applying hierarchical clustering, from the scipy.cluster module, to the positions in the $(x, y)$ plane. As the cut-off distance for the clustering, we quite naturally use the maximum radius we defined earlier.

In [24]:
positions = properties[['x', 'y']].copy()

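# Note: `hier.linkage` interprets a 2D array as observation vectors, so the
# rows of the square distance matrix are used as feature vectors here.
dist_mat = dist.squareform(dist.pdist(positions.values))
link_mat = hier.linkage(dist_mat)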
cluster_idx = hier.fcluster(link_mat, max_radius,
                            criterion='distance')
properties['new_label'] = cluster_idx
properties.set_index('new_label', drop=True, append=False, inplace=True)
properties.index.name = 'label'
properties = properties.sort_index()
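
To see how the distance cut-off groups detections, here is a toy sketch with two well-separated objects. It clusters on the raw pairwise distances passed in condensed form, which is not exactly what the cell above does, so the cut-off values are not directly comparable; the points and the 5 pixel cut-off are hypothetical.

```python
import numpy as np
from scipy.spatial import distance as dist
import scipy.cluster.hierarchy as hier

# Hypothetical (x, y) detections: two objects seen on three planes each.
toy_positions = np.array([[10.0, 10.0], [10.5, 9.5], [11.0, 10.2],
                          [60.0, 40.0], [59.5, 40.5], [60.3, 39.8]])

# Condensed pairwise distances, the 1-D form that `linkage` accepts directly.
condensed = dist.pdist(toy_positions)
link = hier.linkage(condensed, method='single')

# Cut the tree: clusters are merged only while they are closer than 5 pixels.
toy_labels = hier.fcluster(link, 5, criterion='distance')
```

Each of the two objects ends up with a single label shared by its three detections, which is the behaviour wanted when regrouping per-plane detections into one nucleus.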

Now the detected elements are regrouped by label:

In [25]:
properties.head()
Out[25]:
               x          y   z         I         w
label
1      57.060095  52.985963  18  0.255932  5.556623
1      57.433341  52.431252  17  0.357211  5.698035
1      55.415167  54.268103  16  0.301234  5.232079
1      55.324279  54.091615  15  0.297468  5.382026
2      44.683366  49.372282  13  0.290805  4.982787

5 rows × 5 columns

Next we want to get the center of each cluster. To do so we run a weighted average of the positions for each cluster, using the measured intensity as weight.

In [26]:
def df_average(df, weights_column):
    '''Computes the average of each column of a dataframe, weighted
    by the values of the column `weights_column`.
    
    Parameters:
    -----------
    df: a pandas DataFrame instance
    weights_column: a string, the column name of the weights column 
    
    Returns:
    --------
    
    values: pandas Series instance indexed by the column names of `df`,
        holding the weighted average value of each column
    '''
    
    values = df.copy().iloc[0]
    norm = df[weights_column].sum()
    for col in df.columns:
        try:
            v = (df[col] * df[weights_column]).sum() / norm
        except TypeError:
            v = df[col].iloc[0]
        values[col] = v
    return values

In [27]:
cell_positions = properties.groupby(level='label').apply(df_average, 'I')
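
The same intensity-weighted average can be checked by hand on a single cluster with `np.average`. A minimal sketch; the coordinates and weights below are made up for illustration:

```python
import numpy as np

# Hypothetical detections of one nucleus on three planes: columns x, y, z.
coords = np.array([[70.0, 40.0, 5.0],
                   [71.0, 39.0, 6.0],
                   [72.0, 38.0, 7.0]])
intensity = np.array([1.0, 2.0, 1.0])  # weights

# Intensity-weighted mean of each column.
center = np.average(coords, axis=0, weights=intensity)
```

The brightest plane pulls the center towards itself, which is the point of weighting by intensity: planes that cut through the middle of a nucleus count more than planes that only graze it.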

So here is what we were looking for:

In [28]:
cell_positions
Out[28]:
               x          y          z         I         w
label
1      56.335143  53.412561  16.471683  0.307242  5.474775
2      44.899683  49.921097  11.522186  0.419746  5.945169
3      82.425622  50.970181  11.551490  0.364361  5.626174
4      70.343025  37.816205   8.129519  0.794513  7.884338

4 rows × 5 columns

The final result

Let’s plot all that

In [29]:
fig = plt.figure(figsize=(12, 12))
colors = plt.cm.jet(properties.index.astype(np.int32))

# xy projection:
ax_xy = fig.add_subplot(111)
ax_xy.imshow(image_stack.max(axis=0), cmap='gray')
ax_xy.scatter(properties['y'],
              properties['x'],
              c=colors, alpha=0.2)


ax_xy.scatter(cell_positions['y'],
              cell_positions['x'],
              c='r', s=50, alpha=1.)


divider = make_axes_locatable(ax_xy)
ax_yz = divider.append_axes("top", 2, pad=0.2, sharex=ax_xy)
ax_yz.imshow(image_stack.max(axis=1), aspect=z_scale/xy_scale, cmap='gray')
ax_yz.scatter(properties['y'],
              properties['z'],
              c=colors, alpha=0.2)

ax_yz.scatter(cell_positions['y'],
              cell_positions['z'],
              c='r', s=50, alpha=1.)


ax_zx = divider.append_axes("right", 2, pad=0.2, sharey=ax_xy)
ax_zx.imshow(image_stack.max(axis=2).T, aspect=xy_scale/z_scale, cmap='gray')
ax_zx.scatter(properties['z'],
              properties['x'],
              c=colors, alpha=0.2)

ax_zx.scatter(cell_positions['z'],
              cell_positions['x'],
              c='r', s=50, alpha=1.)

plt.draw()