24 from math
import pi, sin, cos
# Build the two point sets used by the registration test.
#
# "Fixed" points are sampled around a circle of radius RADIUS. Each "moving"
# point is the fixed point with the same index, shifted by 2.0 along every axis.
#
# Args:
#     l_dimension: number of coordinates per point (default 2).
# Returns:
#     (fixed_points, moving_points), both created via PointSetType.New() and
#     filled with SetPoint(count, ...) using the same indices.
#
# NOTE(review): this excerpt is missing the original lines that define `step`,
# `theta` and the per-iteration `fixed_point` list. RADIUS and PointSetType
# are also defined outside this view. Presumably theta = count * step and
# fixed_point starts empty on each iteration; confirm against the full file.
30 def make_circles(l_dimension: int = 2):
# The same translation of 2.0 is applied to every axis of every point.
34 offset = [2.0] * l_dimension
36 fixed_points = PointSetType.New()
37 moving_points = PointSetType.New()
38 fixed_points.Initialize()
39 moving_points.Initialize()
# int(2 * pi / step) + 1 samples, indexed by `count`, which is also the
# point id in both sets.
42 for count
in range(0, int(2 * pi / step) + 1):
# The first coordinate is the cosine term. Every remaining coordinate gets
# the same sine term, so for l_dimension > 2 all coordinates after the first
# hold equal values.
47 fixed_point.append(RADIUS * cos(theta))
48 for dim
in range(1, l_dimension):
49 fixed_point.append(RADIUS * sin(theta))
50 fixed_points.SetPoint(count, fixed_point)
# Moving point = fixed point translated by `offset`, stored under the same id.
52 moving_point = [fixed_point[dim] + offset[dim]
for dim
in range(0, l_dimension)]
53 moving_points.SetPoint(count, moving_point)
55 return fixed_points, moving_points
# Register the moving circle onto the fixed circle with an affine transform.
# Print a per-point report, and fail if any point pair disagrees by more
# than `tolerance`.
#
# Args:
#     l_dimension: point dimensionality forwarded to make_circles (default 2).
# Raises:
#     Exception: if any coordinate of a per-point `difference` exceeds
#         `tolerance` in absolute value.
#
# NOTE(review): this excerpt omits several original lines:
#   - the definitions of `num_iterations` and `tolerance`;
#   - the opening of the `difference = [...]` expression;
#   - some keyword arguments to PointSetMetricType.New / OptimizerType.New
#     and the closing parentheses of those calls;
#   - the `print(` openers around the adjacent f-strings.
# AffineTransformType, PointSetMetricType, ShiftScalesType, OptimizerType and
# the `itk` module are defined or imported outside this view.
# Confirm against the full file before relying on these notes.
58 def test_registration(l_dimension: int = 2):
78 fixed_set, moving_set = make_circles(l_dimension)
# The optimization starts from the identity affine transform.
80 transform = AffineTransformType.New()
81 transform.SetIdentity()
# Point-set metric between the two circles. It optimizes `transform` as the
# moving transform, with anisotropic covariances disabled, a covariance
# k-neighborhood of 5 and an evaluation k-neighborhood of 10.
83 metric = PointSetMetricType.New(
84 FixedPointSet=fixed_set,
85 MovingPointSet=moving_set,
88 UseAnisotropicCovariances=
False,
89 CovarianceKNeighborhood=5,
90 EvaluationKNeighborhood=10,
91 MovingTransform=transform,
# Scales estimator passed to the optimizer below (ScalesEstimator=...). It is
# built from the metric and the metric's virtual transformed point set.
96 shift_scale_estimator = ShiftScalesType.New(
97 Metric=metric, VirtualDomainPointSet=metric.GetVirtualTransformedPointSet()
# Optimizer settings:
#   - runs `num_iterations` iterations;
#   - steps are capped at 3.0 physical units;
#   - MinimumConvergenceValue=0.0 with a window of 10, presumably so the
#     convergence check never stops the run early (confirm).
100 optimizer = OptimizerType.New(
102 NumberOfIterations=num_iterations,
103 ScalesEstimator=shift_scale_estimator,
104 MaximumStepSizeInPhysicalUnits=3.0,
105 MinimumConvergenceValue=0.0,
106 ConvergenceWindowSize=10,
# Observer callback: reports the current iteration number and the metric
# value (6 decimal places).
109 def print_iteration():
111 f
"It: {optimizer.GetCurrentIteration()}"
112 f
" metric value: {optimizer.GetCurrentMetricValue():.6f} "
# Invoke print_iteration on every itk.IterationEvent fired by the optimizer.
115 optimizer.AddObserver(itk.IterationEvent(), print_iteration)
118 optimizer.StartOptimization()
# Summary of the finished optimization.
120 print(f
"Number of iterations: {num_iterations}")
121 print(f
"Moving-source final value: {optimizer.GetCurrentMetricValue()}")
122 print(f
"Moving-source final position: {list(optimizer.GetCurrentPosition())}")
123 print(f
"Optimizer scales: {list(optimizer.GetScales())}")
124 print(f
"Optimizer learning rate: {optimizer.GetLearningRate()}")
# Header of the tab-separated per-point report.
127 print(
"Fixed\tMoving\tMovingTransformed\tFixedTransformed\tDiff")
# Each point set is mapped back through the inverse of its metric transform.
129 moving_inverse = metric.GetMovingTransform().GetInverseTransform()
130 fixed_inverse = metric.GetFixedTransform().GetInverseTransform()
# Format a point's coordinates as "[a,b,...]" with 4 decimal places.
132 def print_point(vals: list) -> str:
133 return f
'[{",".join(f"{x:.4f}" for x in vals)}]'
# NOTE(review): GetNumberOfComponents() is used as the point count here.
# It presumably equals the number of points in each set; confirm.
135 for n
in range(0, metric.GetNumberOfComponents()):
136 transformed_moving_point = moving_inverse.TransformPoint(moving_set.GetPoint(n))
137 transformed_fixed_point = fixed_inverse.TransformPoint(fixed_set.GetPoint(n))
# Per-axis difference (transformed moving minus transformed fixed). The
# opening of this `difference` expression is not visible in the excerpt.
140 transformed_moving_point[dim] - transformed_fixed_point[dim]
141 for dim
in range(0, l_dimension)
# One report row: fixed, moving, both transformed points, and the difference.
145 f
"{print_point(fixed_set.GetPoint(n))}"
146 f
"\t{print_point(moving_set.GetPoint(n))}"
147 f
"\t{print_point(transformed_moving_point)}"
148 f
"\t{print_point(transformed_fixed_point)}"
149 f
"\t{print_point(difference)}"
# Fail the test if any coordinate of `difference` exceeds `tolerance` in
# absolute value.
152 if any(
abs(difference[dim]) > tolerance
for dim
in range(0, l_dimension)):
156 raise Exception(
"Transform outside of allowable tolerance")
# Reached only if no tolerance check above raised.
158 print(
"Transform is within allowable tolerance.")
# Script entry point: a single command-line argument selects the point
# dimension passed to test_registration.
# NOTE(review): `sys` must be imported outside this view. The visible lines
# call test_registration only when exactly one argument is given. If the
# original has no fallback branch after line 164, running with no argument
# would either do nothing or, if line 164 is not nested under the inner `if`,
# reference an undefined `dimension`. Confirm against the full file.
161 if __name__ ==
"__main__":
162 if len(sys.argv) == 2:
163 dimension = int(sys.argv[1])
164 test_registration(dimension)