Slicer3:Python:DemianExamples

From SlicerWiki
Jump to: navigation, search
Home < Slicer3:Python:DemianExamples

Introduction

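Demian Wassermann developed a set of tutorial slides and examples for using Python and NumPy in Slicer3. Each example below is a scripted module: an `XML` string describing the module's parameters, followed by an `Execute` function (and its helpers) whose keyword arguments match those parameter names.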

Median Filter in a Masked Region

# Module description for the masked median filter. Every parameter <name>
# below matches a keyword argument of Execute(). The image parameters reach
# Execute() as strings that it resolves with scene.GetNodeByID(), i.e. MRML
# node IDs. Note that the XML default/minimum for medianFilterRadius (2)
# differs from Execute()'s keyword default (0).
XML = """<?xml version="1.0" encoding="utf-8"?>
<executable>

  <category>Demo Scripted Modules</category>
  <title>Masked median filtering</title>
  <description>
Perform median filtering over a masked section of an image
</description>
  <version>1.0</version>
  <documentation-url></documentation-url>
  <license></license>
  <contributor>Demian Wassermann</contributor>

  <parameters>
   <label>IO</label>
    <description>Input/output parameters</description>

    <image type = "scalar" >
      <name>inputVolume</name>
      <longflag>inputVolume</longflag>
      <label>Input Image</label>
      <channel>input</channel>
      <description>Input image to be filtered</description>
    </image>

    <integer>
      <name>medianFilterRadius</name>
      <longflag>medianFilterRadius</longflag>
      <label>Radius of the median filter</label>
      <default>2</default>
      <step>1</step>
      <channel>input</channel>
      <constraints>
        <minimum>2</minimum>
        <maximum>100</maximum>
      </constraints>
    </integer>


    <image type="label">
      <name>inputMaskVolume</name>
      <longflag>inputMaskVolume</longflag>
      <label>Input Mask Volume</label>
      <channel>input</channel>
      <description>Input mask to work on it</description>
    </image>

    <integer>
      <name>labelToUse</name>
      <longflag>labelToUse</longflag>
      <label>Label to use for the mask</label>
      <default>1</default>
      <step>1</step>
      <channel>input</channel>
      <constraints>
        <minimum>0</minimum>
        <maximum>255</maximum>
      </constraints>
    </integer>

    <image type = "scalar">
      <name>outputFilteredVolume</name>
      <longflag>outputFilteredVolume</longflag>
      <label>Output Image</label>
      <channel>output</channel>
      <description>Image that was median filtered</description>
    </image>

  </parameters>

</executable>
"""


from Slicer import slicer
from scipy import ndimage
import numpy

def Execute(
    inputVolume="",
    medianFilterRadius=0,
    inputMaskVolume="",
    labelToUse=1,
    outputFilteredVolume=""
):
  """Median-filter the voxels of inputVolume whose mask label equals labelToUse.

  Voxels outside the selected label are copied unchanged from the input, so
  the output differs from the input only inside the mask.

  Parameters (MRML node IDs unless noted):
    inputVolume: ID of the scalar volume to filter.
    medianFilterRadius: (int) half-width in voxels of the cubic median
      window; the window spans 2 * radius + 1 voxels along every axis.
    inputMaskVolume: ID of the label volume selecting the voxels to filter;
      its voxel array must have the same shape as inputVolume's.
    labelToUse: (int) mask value marking the voxels to filter.
    outputFilteredVolume: ID of the volume that receives the result.

  Raises:
    ValueError: if the mask and input voxel arrays differ in shape.
  """
  scene = slicer.MRMLScene

  # Resolve the node IDs handed in by Slicer.
  inputVolumeNode = scene.GetNodeByID( inputVolume )
  inputMaskVolumeNode = scene.GetNodeByID( inputMaskVolume )
  outputFilteredVolumeNode = scene.GetNodeByID( outputFilteredVolume )

  # Give the output the input's dimensions, scalar type and IJK-to-RAS matrix.
  setupTheOutputNode( inputVolumeNode, outputFilteredVolumeNode )

  input_array = inputVolumeNode.GetImageData().ToArray()
  output_array = outputFilteredVolumeNode.GetImageData().ToArray()
  mask_array = inputMaskVolumeNode.GetImageData().ToArray()

  # Assign in place so the values land in the array returned by ToArray()
  # (presumably a view of the output node's image buffer), then notify.
  output_array[:] = _masked_median_filter(
      input_array, mask_array, labelToUse, medianFilterRadius )
  outputFilteredVolumeNode.Modified()


def _masked_median_filter( input_array, mask_array, label, radius ):
  """Return a copy of input_array, median-filtered where mask_array == label.

  The median window is 2 * radius + 1 voxels wide on every axis. To keep the
  work small, only the bounding box of the selected voxels is filtered. The
  box is padded by `radius` (clipped to the array) so each selected voxel is
  filtered against its true neighbours rather than against reflected data at
  the box edge. All other voxels are copied unchanged.

  Raises ValueError if the two arrays have different shapes.
  """
  image = numpy.asarray( input_array )
  mask = numpy.asarray( mask_array )
  if mask.shape != image.shape:
    raise ValueError( "mask shape %s does not match image shape %s"
                      % ( mask.shape, image.shape ) )

  result = image.copy()
  points = numpy.transpose( numpy.nonzero( mask == label ) )
  if len( points ) == 0:
    # Nothing selected: the output is a plain copy of the input.
    return result

  radius = max( int( radius ), 0 )
  lower = numpy.maximum( points.min( 0 ) - radius, 0 )
  upper = numpy.minimum( points.max( 0 ) + radius + 1, image.shape )
  window = tuple( slice( lo, hi ) for lo, hi in zip( lower, upper ) )

  # `size` is the full window width, not a radius.
  filtered = ndimage.median_filter( image[ window ], size=2 * radius + 1 )
  result[ tuple( points.T ) ] = filtered[ tuple( ( points - lower ).T ) ]
  return result




def setupTheOutputNode( inputVolumeNode, outputFilteredVolumeNode ):
  """Shape the output volume node after the input volume node.

  Ensures the output node has image data. That data gets the input's
  dimensions and scalar type, origin (0, 0, 0) and spacing (1, 1, 1), and
  freshly allocated scalars. The input's IJK-to-RAS matrix is then copied to
  the output node, which is marked modified.
  """
  sourceImage = inputVolumeNode.GetImageData()
  targetImage = outputFilteredVolumeNode.GetImageData()

  # Attach a new image to the output node when it does not have one yet.
  if not targetImage:
    targetImage = slicer.vtkImageData()
    outputFilteredVolumeNode.SetAndObserveImageData( targetImage )

  # Same voxel grid and scalar type as the input. Origin and spacing stay at
  # identity; presumably the IJK-to-RAS matrix copied below carries the
  # physical geometry instead.
  dims = sourceImage.GetDimensions()
  targetImage.SetDimensions( dims[0], dims[1], dims[2] )
  targetImage.SetScalarType( sourceImage.GetScalarType() )
  targetImage.SetOrigin( 0, 0, 0 )
  targetImage.SetSpacing( 1, 1, 1 )
  targetImage.AllocateScalars()

  ijkToRas = slicer.vtkMatrix4x4()
  inputVolumeNode.GetIJKToRASMatrix( ijkToRas )
  outputFilteredVolumeNode.SetIJKToRASMatrix( ijkToRas )

  outputFilteredVolumeNode.Modified()

K-Medoids Fiber Clustering

# Module description for K-medoids fiber clustering. Every parameter <name>
# below matches a keyword argument of Execute(); the fiber bundles reach
# Execute() as strings that it resolves with scene.GetNodeByID().
# NOTE(review): outputFiberBundle's <geometry> has no type="fiberbundle"
# attribute, unlike inputFiberBundle -- confirm which node type Slicer3
# creates for it by default.
XML = """<?xml version="1.0" encoding="utf-8"?>
<executable>

  <category>Demo Scripted Modules</category>
  <title>K-Medoids fiber clustering</title>
  <description>
Fiber Clustering simple K-Medoids 
</description>
  <version>1.0</version>
  <documentation-url></documentation-url>
  <license></license>
  <contributor>Demian Wassermann</contributor>

  <parameters>
   <label>IO</label>
    <description>Input/output parameters</description>

    <geometry type = "fiberbundle" >
      <name>inputFiberBundle</name>
      <longflag>inputFiberBundle</longflag>
      <label>Input Fiber Bundle</label>
      <channel>input</channel>
      <description>Input bundle</description>
    </geometry>

    <geometry >
      <name>outputFiberBundle</name>
      <longflag>outputFiberBundle</longflag>
      <label>Output Fiber Bundle</label>
      <channel>output</channel>
      <description>Clustered bundle</description>
    </geometry>

    <integer>
      <name>numberOfClusters</name>
      <longflag>numberOfClusters</longflag>
      <label>Number of clusters for K-Medoids</label>
      <default>5</default>
      <step>1</step>
      <channel>input</channel>
      <constraints>
        <minimum>2</minimum>
        <maximum>100</maximum>
      </constraints>
    </integer>

  </parameters>
  <parameters advanced="true">
    <label>Advanced</label>
    <integer>
      <name>subsampling</name>
      <longflag>subsampling</longflag>
      <label>Number of fiber points to keep</label>
      <default>15</default>
      <step>1</step>
      <channel>input</channel>
      <constraints>
        <minimum>2</minimum>
        <maximum>1000</maximum>
      </constraints>
    </integer>
    <integer>
      <name>minimumFiberLength</name>
      <longflag>minimumFiberLength</longflag>
      <label>minimum fiber length to consider valid</label>
      <default>15</default>
      <step>1</step>
      <channel>input</channel>
      <constraints>
        <minimum>2</minimum>
        <maximum>1000</maximum>
      </constraints>
    </integer>
  </parameters>

</executable>
"""

from Slicer import slicer
import numpy

# Warning, this example needs the package Pycluster 
# http://bonsai.ims.u-tokyo.ac.jp/~mdehoon/software/cluster/software.htm#pycluster
import Pycluster




def Execute (inputFiberBundle="", outputFiberBundle="", numberOfClusters=2, subsampling=15, minimumFiberLength=15 ):
  """Cluster the fibers of a bundle with K-medoids and label their points.

  The output bundle gets the input's points and lines plus a per-point
  'Cluster' array. Points of clustered fibers hold their fiber's cluster
  number (1..numberOfClusters, numbered by first appearance). Every other
  point, e.g. on fibers that are too short, holds 0.

  Parameters:
    inputFiberBundle: MRML node ID of the bundle to cluster.
    outputFiberBundle: MRML node ID of the bundle receiving the labels.
    numberOfClusters: number of K-medoids clusters.
    subsampling: roughly how many points of each fiber are kept when
      computing fiber-to-fiber distances (all points of shorter fibers).
    minimumFiberLength: fibers with fewer points than this are not clustered.
  """
  scene = slicer.MRMLScene

  inputFiberBundleNode = scene.GetNodeByID(inputFiberBundle)
  outputFiberBundleNode = scene.GetNodeByID(outputFiberBundle)

  # Prepare the output polydata and its per-point 'Cluster' label array.
  clusters = setupTheOutputNode( inputFiberBundleNode, outputFiberBundleNode )
  clusters_array = clusters.ToArray().squeeze()

  # Extract the fibers that are long enough, then subsample each one.
  fibers, lines = fibers_from_vtkPolyData( inputFiberBundleNode.GetPolyData(), minimumFiberLength )

  # The step must be at least 1 and shrink as `subsampling` grows; integer
  # division keeps it a valid slice step.
  subsampledFibers = [ fiber[::max( len(fiber) // subsampling, 1 )] for fiber in fibers ]

  # Symmetric pairwise fiber distance matrix (zero diagonal).
  numberOfFibers = len(fibers)
  distanceMatrix = numpy.zeros( (numberOfFibers, numberOfFibers), dtype=float )
  for i in range( numberOfFibers ):
    for j in range( i ):
      distanceMatrix[ i, j ] = dist_hausdorff_min( subsampledFibers[i], subsampledFibers[j] )
      distanceMatrix[ j, i ] = distanceMatrix[ i, j ]

  # kmedoids(...)[0] holds one cluster id per fiber; renumber them 1..k.
  fiberClusters = renumberLabels( Pycluster.kmedoids( distanceMatrix, numberOfClusters, npass=100 )[0] )
  print( fiberClusters )

  # Paint each clustered fiber's points with its label; everything else is 0.
  clusters_array[:] = 0
  for fiberPointIds, label in zip( lines, fiberClusters ):
    clusters_array[ fiberPointIds ] = label

  clusters.Modified()


def dist2( i, j ):
  """Euclidean distance between point i and each point (row) of j.

  For i of shape (d,) and j of shape (m, d), returns the m distances; for two
  single points, returns one scalar. Coordinates are on the last axis.
  """
  return numpy.sqrt( ( ( i - j ) ** 2 ).sum( -1 ) )


def dist_hausdorff_asym_mean( i, j ):
  """Mean, over the points of i, of the distance to the closest point of j.

  i and j are (n, d) and (m, d) arrays of points. The result is directed:
  in general dist_hausdorff_asym_mean(i, j) != dist_hausdorff_asym_mean(j, i).
  """
  return numpy.mean( [ dist2( point, j ).min() for point in i ] )


def dist_hausdorff_min( i, j ):
  """Symmetric fiber distance: the smaller of the two directed mean
  closest-point distances between point sets i and j.

  Note this is a mean-closest-point variant, not the classical Hausdorff
  distance (which takes the max of the directed max distances).
  """
  # Builtin min on purpose: numpy.min's second positional argument is `axis`,
  # not a second value to compare.
  return min( dist_hausdorff_asym_mean( i, j ), dist_hausdorff_asym_mean( j, i ) )



def fibers_from_vtkPolyData(vtkPolyData, minimumFiberLength):
    """Extract the polylines ("fibers") of vtkPolyData that are long enough.

    Parameters:
      vtkPolyData: polydata whose lines are the fibers.
      minimumFiberLength: minimum number of points for a fiber to be kept;
        a fiber with exactly this many points is kept.

    Returns:
      (fibersList, linesList): for each kept fiber, in order, the array of
      its point coordinates and the array of its point ids.
    """
    # The lines connectivity is walked as a flat array laid out as
    # [n0, id, ..., id, n1, id, ...]: a point count followed by that many ids.
    lines = vtkPolyData.GetLines().GetData().ToArray().squeeze()
    points = vtkPolyData.GetPoints().GetData().ToArray()

    fibersList = []
    linesList = []
    actualLineIndex = 0
    numberOfFibers = vtkPolyData.GetLines().GetNumberOfCells()
    for _ in range( numberOfFibers ):
        numberOfPoints = lines[actualLineIndex]
        pointIds = lines[actualLineIndex + 1: actualLineIndex + numberOfPoints + 1]
        if numberOfPoints >= minimumFiberLength:
            fibersList.append( points[ pointIds ] )
            linesList.append( pointIds )
        # Skip the count entry plus this fiber's ids to reach the next fiber.
        actualLineIndex += numberOfPoints + 1

    return fibersList, linesList

def _isNullVtkObject( obj ):
  """Return True when a wrapped VTK getter handed back "no object".

  Depending on the Python/VTK bridge, a missing object may come back as
  None, [] or "". NOTE(review): confirm which one Slicer3's bridge returns.
  """
  return obj is None or ( isinstance( obj, ( list, tuple, str ) ) and len( obj ) == 0 )


def setupTheOutputNode( inputFiberBundleNode, outputFiberBundleNode ):
  """Copy the input bundle's geometry to the output and return its label array.

  The output polydata is created if missing. It is given the input's points
  and lines objects (shared, not copied). The output's point data then gets a
  'Cluster' unsigned-int array with one tuple per point; an existing
  'Cluster' array of the right size is reused. Returns that array (values of
  a newly created array are not initialised here).
  """
  if _isNullVtkObject( outputFiberBundleNode.GetPolyData() ):
    outputFiberBundleNode.SetAndObservePolyData( slicer.vtkPolyData() )

  inputPolyData = inputFiberBundleNode.GetPolyData()
  outputPolyData = outputFiberBundleNode.GetPolyData()
  outputPolyData.SetPoints( inputPolyData.GetPoints() )
  outputPolyData.SetLines( inputPolyData.GetLines() )
  outputPolyData.Update()

  numberOfPoints = outputPolyData.GetPoints().GetNumberOfPoints()
  clusters = outputPolyData.GetPointData().GetScalars( 'Cluster' )
  # (Re)create the label array when it is missing or sized for other points.
  if _isNullVtkObject( clusters ) or clusters.GetNumberOfTuples() != numberOfPoints:
    clusters = slicer.vtkUnsignedIntArray()
    clusters.SetNumberOfComponents( 1 )
    clusters.SetNumberOfTuples( numberOfPoints )
    clusters.SetName( 'Cluster' )
    outputPolyData.GetPointData().AddArray( clusters )

  return clusters

def renumberLabels(labelArray):
  """Relabel values to 1, 2, 3, ... in order of first appearance.

  Returns a copy of labelArray (made with its .copy() method) in which every
  occurrence of the first distinct value becomes 1, of the second becomes 2,
  and so on. The input is not modified.
  """
  rankOf = {}
  for label in labelArray:
    if label not in rankOf:
      rankOf[label] = len(rankOf) + 1

  renumbered = labelArray.copy()
  for position, label in enumerate(labelArray):
    renumbered[position] = rankOf[label]

  return renumbered