# Register Two Point Sets

Similar to image registration, an n-dimensional “moving” point set can be transformed to align with a “fixed” point set. An ITK point set metric can be combined with an ITK optimizer to register the two sets.

In this example we create two itk.PointSet representations with an arbitrary offset and select parameters to align them with a TranslationTransform. We use the JensenHavrdaCharvatTsallisPointSetToPointSetMetricv4 class to quantify the difference between point sets and the GradientDescentOptimizerv4 class to iteratively reduce this difference by modifying transform parameters. Our example also includes sample code visualizing the parameter surface with matplotlib and itkwidgets as well as a sample hyperparameter search to optimize gradient descent performance.

[1]:

import os
import sys
import itertools
from math import pi, sin, cos, sqrt

import matplotlib.pyplot as plt
import numpy as np

import itk
from itkwidgets import view


## Construct Two Point Sets

[2]:

# Generate two circles with a small offset
def make_circles(dimension: int = 2, offset: list = None, radius: float = 100.0):
    """Create a fixed circle of points and a translated ("moving") copy of it.

    The fixed set samples angles theta = 0, 0.1, 0.2, ... up to 2*pi (inclusive
    of the last step that does not exceed it). For each angle the first
    coordinate is ``radius * cos(theta)`` and every remaining coordinate is
    ``radius * sin(theta)``. The moving set is the fixed set shifted by
    ``offset`` along each axis.

    Args:
        dimension: Number of coordinates per point.
        offset: Per-axis translation applied to the moving set. When missing,
            empty, or not of length ``dimension``, 2.0 is used on every axis.
        radius: Radius of the sampled circle.

    Returns:
        A ``(fixed_points, moving_points)`` tuple of
        ``itk.PointSet[itk.F, dimension]`` objects with the same point count.
    """
    PointSetType = itk.PointSet[itk.F, dimension]

    if not offset or len(offset) != dimension:
        offset = [2.0] * dimension

    fixed_points = PointSetType.New()
    moving_points = PointSetType.New()
    fixed_points.Initialize()
    moving_points.Initialize()

    step = 0.1
    for count in range(int(2 * pi / step) + 1):
        theta = count * step

        # Every coordinate must be filled in before the point is stored;
        # the moving point below indexes all `dimension` entries.
        fixed_point = [radius * cos(theta)]
        fixed_point.extend(radius * sin(theta) for _ in range(1, dimension))
        fixed_points.SetPoint(count, fixed_point)

        moving_point = [coord + shift for coord, shift in zip(fixed_point, offset)]
        moving_points.SetPoint(count, moving_point)

    return fixed_points, moving_points

[3]:

# Per-axis translation from the fixed circle to the moving circle
POINT_SET_OFFSET = [15.0, 15.0]
fixed_set, moving_set = make_circles(dimension=2, offset=POINT_SET_OFFSET)

[4]:

# Visualize point sets with matplotlib: the fixed set is drawn first, then the
# moving set. Each set is read using its own point count, so the plot does not
# assume both sets contain the same number of points.

fig = plt.figure()
ax = plt.axes()

n_points = fixed_set.GetNumberOfPoints()
fixed_coords = [fixed_set.GetPoint(i) for i in range(n_points)]
moving_coords = [moving_set.GetPoint(i)
                 for i in range(moving_set.GetNumberOfPoints())]

ax.scatter([p[0] for p in fixed_coords], [p[1] for p in fixed_coords])
ax.scatter([p[0] for p in moving_coords], [p[1] for p in moving_coords])

[4]:

<matplotlib.collections.PathCollection at 0x1d48559e790>


We will quantify the point set offset with JensenHavrdaCharvatTsallisPointSetToPointSetMetricv4 and minimize the metric value over 10 gradient descent iterations.

[5]:


# Optimizer classes used below: gradient descent performs the registration,
# and the exhaustive optimizer later samples the metric over a parameter grid.
# OptimizerType must be defined here because the gradient descent cell
# instantiates it via OptimizerType.New(...).
OptimizerType = itk.GradientDescentOptimizerv4Template[itk.D]
ExhaustiveOptimizerType = itk.ExhaustiveOptimizerv4[itk.D]

[6]:

# Translation-only transform whose parameters the optimizer will update.
# The point sets built above are two-dimensional, so the transform is too.
dim = 2
TransformType = itk.TranslationTransform[itk.D, dim]

transform = TransformType.New()
transform.SetIdentity()  # begin from zero translation

[7]:

# Reuse the concrete point set class as the metric's template argument
PointSetType = type(fixed_set)
PointSetMetricType = itk.JensenHavrdaCharvatTsallisPointSetToPointSetMetricv4[PointSetType]

# Metric comparing the fixed points with the moving points mapped through
# `transform`; the optimizer adjusts `transform` to minimize this value.
metric_settings = {
    'FixedPointSet': fixed_set,
    'MovingPointSet': moving_set,
    'MovingTransform': transform,
    'PointSetSigma': 5.0,
    'KernelSigma': 10.0,
    'UseAnisotropicCovariances': False,
    'CovarianceKNeighborhood': 5,
    'EvaluationKNeighborhood': 10,
    'Alpha': 1.1,
}
metric = PointSetMetricType.New(**metric_settings)
metric.Initialize()

[8]:

# Scales estimator that helps the optimizer choose step sizes along each
# transform parameter, evaluated over the metric's virtual-domain point set
ShiftScalesType = itk.RegistrationParameterScalesFromPhysicalShift[PointSetMetricType]
virtual_domain_points = metric.GetVirtualTransformedPointSet()
shift_scale_estimator = ShiftScalesType.New(Metric=metric,
                                            VirtualDomainPointSet=virtual_domain_points,
                                            TransformForward=True)

[9]:

max_iterations = 10

# Gradient descent optimizer configuration.
# NOTE(review): MinimumConvergenceValue=-1 presumably keeps the convergence
# check from stopping early, so all max_iterations run — confirm.
gradient_descent_settings = dict(
    Metric=metric,
    NumberOfIterations=max_iterations,
    ScalesEstimator=shift_scale_estimator,
    MaximumStepSizeInPhysicalUnits=8.0,
    MinimumConvergenceValue=-1,
    DoEstimateLearningRateAtEachIteration=False,
    DoEstimateLearningRateOnce=True,
    ReturnBestParametersAndValue=True,
)
optimizer = OptimizerType.New(**gradient_descent_settings)

[10]:

# Gradient descent positions keyed by step: key 0 holds the starting position,
# key k + 1 the position logged when the optimizer reports iteration k.
iteration_data = dict()

# Track gradient descent iterations with observers
def print_iteration():
    """Print the optimizer's current iteration, metric value and learning rate."""
    print(f'It: {optimizer.GetCurrentIteration()}'
          f' metric value: {optimizer.GetCurrentMetricValue():.6f} '
          f' learning rate: {optimizer.GetLearningRate()}')

def log_iteration():
    """Record the optimizer's current transform position in iteration_data."""
    iteration_data[optimizer.GetCurrentIteration() + 1] = list(optimizer.GetCurrentPosition())

# The callbacks only run once attached to the optimizer. The StartEvent
# observer prints the initial state before the first iteration.
optimizer.AddObserver(itk.StartEvent(), print_iteration)
optimizer.AddObserver(itk.IterationEvent(), print_iteration)
optimizer.AddObserver(itk.IterationEvent(), log_iteration)

# Set first value to default transform position
iteration_data[0] = list(optimizer.GetCurrentPosition())

[11]:

# Run the gradient descent registration, then report where it ended up.
optimizer.StartOptimization()

# NOTE(review): this reports GetCurrentIteration() - 1, which printed 9 while
# the stop reason says 10 iterations were reached — confirm the intended count.
print('Number of iterations:', optimizer.GetCurrentIteration() - 1)
print('Moving-source final value:', optimizer.GetCurrentMetricValue())
print('Moving-source final position:', list(optimizer.GetCurrentPosition()))
print('Optimizer scales:', list(optimizer.GetScales()))
print('Optimizer learning rate:', optimizer.GetLearningRate())
print('Stop reason:', optimizer.GetStopConditionDescription())

It: 0 metric value: 0.000000  learning rate: 1.0
It: 0 metric value: -0.043464  learning rate: 5594.753388446298
It: 1 metric value: -0.054787  learning rate: 5594.753388446298
It: 2 metric value: -0.062597  learning rate: 5594.753388446298
It: 3 metric value: -0.064588  learning rate: 5594.753388446298
It: 4 metric value: -0.064807  learning rate: 5594.753388446298
It: 5 metric value: -0.064815  learning rate: 5594.753388446298
It: 6 metric value: -0.064815  learning rate: 5594.753388446298
It: 7 metric value: -0.064815  learning rate: 5594.753388446298
It: 8 metric value: -0.064815  learning rate: 5594.753388446298
It: 9 metric value: -0.064815  learning rate: 5594.753388446298
It: 10 metric value: -0.064815  learning rate: 5594.753388446298
Number of iterations: 9
Moving-source final value: -0.06481531061643396
Moving-source final position: [15.000412861069881, 14.99997463945473]
Optimizer scales: [1.0000000000010232, 1.0000000000010232]
Optimizer learning rate: 5594.753388446298
Stop reason: GradientDescentOptimizerv4Template: Maximum number of iterations (10) exceeded.


## Resample Moving Point Set

[12]:

# Inverses of the metric's fixed and moving transforms, used to map each
# point set back for the alignment plot
fixed_inverse = metric.GetFixedTransform().GetInverseTransform()
moving_inverse = metric.GetMovingTransform().GetInverseTransform()

[13]:

# Map every point through the inverse of its metric transform so the moving
# set can be compared with the fixed set.
transformed_fixed_set = PointSetType.New()
transformed_moving_set = PointSetType.New()

# Each set is iterated over its own point count instead of
# metric.GetNumberOfComponents(), so the loops do not assume the two sets
# have the same size.
for n in range(moving_set.GetNumberOfPoints()):
    transformed_moving_set.SetPoint(n, moving_inverse.TransformPoint(moving_set.GetPoint(n)))

for n in range(fixed_set.GetNumberOfPoints()):
    transformed_fixed_set.SetPoint(n, fixed_inverse.TransformPoint(fixed_set.GetPoint(n)))

[14]:

# Compare fixed point set with resampled moving point set to see alignment.
# Each set is read using its own point count rather than the fixed set's.

fig = plt.figure()
ax = plt.axes()

n_points = fixed_set.GetNumberOfPoints()
fixed_coords = [fixed_set.GetPoint(i) for i in range(n_points)]
aligned_coords = [transformed_moving_set.GetPoint(i)
                  for i in range(transformed_moving_set.GetNumberOfPoints())]

ax.scatter([p[0] for p in fixed_coords], [p[1] for p in fixed_coords])
ax.scatter([p[0] for p in aligned_coords], [p[1] for p in aligned_coords])

[14]:

<matplotlib.collections.PathCollection at 0x1d488213850>


We can use the ITK ExhaustiveOptimizerv4 class to evaluate the metric over a grid of transform parameters. This maps out the metric surface so we can see how the gradient descent optimizer moved across it.

[15]:

# Set up an exhaustive optimizer that samples the metric on a grid of
# transform parameters, so the gradient descent path can be drawn on the
# resulting surface.

# Create a new transform and metric for analysis, configured like the metric
# used for gradient descent
transform = TransformType.New()
transform.SetIdentity()

metric = PointSetMetricType.New(
    FixedPointSet=fixed_set,
    MovingPointSet=moving_set,
    MovingTransform=transform,
    PointSetSigma=5.0,
    KernelSigma=10.0,
    UseAnisotropicCovariances=False,
    CovarianceKNeighborhood=5,
    EvaluationKNeighborhood=10,
    Alpha=1.1)
metric.Initialize()

# Remove any observers attached to the gradient descent optimizer, then
# replace it with the exhaustive optimizer
optimizer.RemoveAllObservers()
optimizer = ExhaustiveOptimizerType.New(
    Metric=metric)

# Metric value at each sampled parameter position, keyed by position tuple
param_space = dict()
def log_exhaustive_iteration():
    """Store the metric value at the optimizer's current parameter position."""
    param_space[tuple(optimizer.GetCurrentPosition())] = optimizer.GetCurrentValue()

# The callback only runs once attached; without it param_space stays empty
optimizer.AddObserver(itk.IterationEvent(), log_exhaustive_iteration)

# One step count and one scale per transform parameter
n_parameters = transform.GetNumberOfParameters()

# Collect a moderate number of steps along each transform parameter
step_count = 25
optimizer.SetNumberOfSteps([step_count] * n_parameters)

# Step a reasonable distance along each transform parameter
scales = optimizer.GetScales()
scales.SetSize(n_parameters)

scale_size = 1.0
for param_index in range(n_parameters):
    scales.SetElement(param_index, scale_size)

optimizer.SetScales(scales)

[16]:

# Sample the whole parameter grid, then summarize the extrema found
optimizer.StartOptimization()

report_lines = [
    f'MinimumMetricValue: {optimizer.GetMinimumMetricValue():.4f}\t'
    f'MaximumMetricValue: {optimizer.GetMaximumMetricValue():.4f}',
    f'MinimumMetricValuePosition: {list(optimizer.GetMinimumMetricValuePosition())}\t'
    f'MaximumMetricValuePosition: {list(optimizer.GetMaximumMetricValuePosition())}',
    f'StopConditionDescription: {optimizer.GetStopConditionDescription()}\t',
]
print('\n'.join(report_lines))

MinimumMetricValue: -0.0648     MaximumMetricValue: -0.0153
MinimumMetricValuePosition: [15.0, 15.0]        MaximumMetricValuePosition: [-25.0, -25.0]
StopConditionDescription: ExhaustiveOptimizerv4: Completed sampling of parametric space of size 2

[17]:

# Order the logged gradient descent positions by step index and split them into
# x and y coordinate lists for overlaying on the surface plot
descent_path = [iteration_data[i] for i in range(len(iteration_data))]
descent_x_vals = [position[0] for position in descent_path]
descent_y_vals = [position[1] for position in descent_path]

[18]:

# Plot the sampled metric surface (colored by metric value), its extrema, and
# the gradient descent path in one matplotlib figure

fig = plt.figure()
ax = plt.axes()

surface_positions = list(param_space.keys())
ax.scatter([x for x, _ in surface_positions],
           [y for _, y in surface_positions],
           c=list(param_space.values()),
           cmap='Greens', zorder=1)

minimum_position = optimizer.GetMinimumMetricValuePosition()
maximum_position = optimizer.GetMaximumMetricValuePosition()
ax.plot(minimum_position[0], minimum_position[1], 'kv')  # minimum: black down-triangle
ax.plot(maximum_position[0], maximum_position[1], 'w^')  # maximum: white up-triangle

path_length = len(iteration_data)
for start in range(path_length):
    # Segment between consecutive positions; the last slice holds one point
    ax.plot(descent_x_vals[start:start + 2], descent_y_vals[start:start + 2], 'rx-')
ax.plot(descent_x_vals[0], descent_y_vals[0], 'ro')  # starting position
ax.plot(descent_x_vals[path_length - 1],
        descent_y_vals[path_length - 1], 'bo')  # final position

[18]:

[<matplotlib.lines.Line2D at 0x1d48834c940>]


We can also view and export the surface as an image using itkwidgets.

[19]:

# Arrange the sampled surface as a 2D array: columns follow ascending x and
# rows follow descending y, so row 0 holds the largest y value
x_vals = sorted({x for x, _ in param_space})
y_vals = sorted({y for _, y in param_space}, reverse=True)

array = np.array([[param_space[(x, y)] for x in x_vals] for y in y_vals])

image_view = itk.GetImageViewFromArray(array)

[20]:

# Display the metric-surface image built above with itkwidgets
view(image_view)