[Insight-users] About optimizers for registration

DiLLa HaNdiNi daffodilsky at gmail.com
Sun Jan 15 20:45:25 EST 2006


Dear All,

I tried applying different optimizers to register the sample images from
BrainWeb.
I tried both the Powell optimizer and the RegularStepGradientDescent optimizer,
and they give me different results. The code, adapted from the ITK sample
programs, is attached below.
Powell's result is as expected, although changing some parameters (e.g. the
step size, the optimizer scales, or the number of multi-resolution levels)
also changes the result. Even when I provide an initial transformation close
to the answer, the registration ends up giving a wrong result.

My questions:
1. What is wrong with the code below, which uses the
RegularStepGradientDescent optimizer? Why doesn't it give the correct result?
2. What is the significance of the parameters in the code?
3. Is it fair to say that registration depends on the parameters we
specify? Is it then a guessing problem, where we only get the correct result
if we guess the parameters correctly?

Could anybody please answer these questions?
Thank you in advance.
Regards.

Dilla

Code:


#if defined(_MSC_VER)
#pragma warning ( disable : 4786 )
#endif

#include <cmath>
#include <cstdlib>
#include <iostream>

#include "itkMultiResolutionImageRegistrationMethod.h"
#include "itkMultiResolutionPyramidImageFilter.h"
#include "itkMattesMutualInformationImageToImageMetric.h"
#include "itkLinearInterpolateImageFunction.h"
#include "itkBSplineInterpolateImageFunction.h"
//#include "itkPowellOptimizer.h"
#include "itkRegularStepGradientDescentOptimizer.h"
#include "itkNormalVariateGenerator.h"

#include "itkImage.h"
#include "itkCastImageFilter.h"

#include "itkTimeProbesCollectorBase.h"

#include "itkEuler3DTransform.h"
#include "itkCenteredTransformInitializer.h"

#include "itkImageFileReader.h"
#include "itkImageFileWriter.h"

#include "itkResampleImageFilter.h"
#include "itkSubtractImageFilter.h"
#include "itkRescaleIntensityImageFilter.h"

#include "itkCommand.h"

// Observer attached to the optimizer: on every IterationEvent it prints the
// iteration number, the current metric value and the current parameter vector
// of the RegularStepGradientDescentOptimizer that emitted the event.
class CommandIterationUpdate : public itk::Command
{
public:
  typedef  CommandIterationUpdate   Self;
  typedef  itk::Command             Superclass;
  typedef itk::SmartPointer<Self>  Pointer;
  itkNewMacro( Self );
protected:
  CommandIterationUpdate() {};
public:
  typedef itk::RegularStepGradientDescentOptimizer OptimizerType;
  typedef   const OptimizerType   *    OptimizerPointer;

  // Non-const overload required by itk::Command; forwards to the const one.
  void Execute(itk::Object *caller, const itk::EventObject & event)
    {
    Execute( static_cast< const itk::Object * >( caller ), event );
    }

  // Prints one progress line per optimizer iteration.
  // Events other than IterationEvent are ignored, as are callers that are
  // not a RegularStepGradientDescentOptimizer (the dynamic_cast yields null
  // in that case, which must not be dereferenced).
  void Execute(const itk::Object * object, const itk::EventObject & event)
    {
    if( ! itk::IterationEvent().CheckEvent( &event ) )
      {
      return;
      }
    OptimizerPointer optimizer =
      dynamic_cast< OptimizerPointer >( object );
    if( ! optimizer )
      {
      return;
      }
    std::cout << optimizer->GetCurrentIteration() << "   ";
    std::cout << optimizer->GetValue() << "   ";
    std::cout << optimizer->GetCurrentPosition() << std::endl;
    }
};

// Observer attached to the multi-resolution registration method. At the start
// of each pyramid level (signalled by an IterationEvent from the registration)
// it adjusts the RegularStepGradientDescentOptimizer's step lengths so that
// they shrink monotonically as the registration moves to finer levels.
template <typename TRegistration>
class RegistrationInterfaceCommand : public itk::Command
{
public:
  typedef  RegistrationInterfaceCommand   Self;
  typedef  itk::Command                   Superclass;
  typedef  itk::SmartPointer<Self>        Pointer;
  itkNewMacro( Self );
protected:
  RegistrationInterfaceCommand() {};
public:
  typedef   TRegistration                              RegistrationType;
  typedef   RegistrationType *                         RegistrationPointer;
  typedef   itk::RegularStepGradientDescentOptimizer   OptimizerType;
  typedef   OptimizerType *                            OptimizerPointer;

  // Reconfigures the optimizer for the level that is about to start.
  // Non-IterationEvent events, non-registration callers and non-RSGD
  // optimizers are ignored (the dynamic_casts yield null for them).
  void Execute(itk::Object * object, const itk::EventObject & event)
    {
    if( ! itk::IterationEvent().CheckEvent( &event ) )
      {
      return;
      }

    RegistrationPointer registration =
      dynamic_cast< RegistrationPointer >( object );
    if( ! registration )
      {
      return;
      }
    OptimizerPointer optimizer =
      dynamic_cast< OptimizerPointer >( registration->GetOptimizer() );
    if( ! optimizer )
      {
      return;
      }

    const unsigned long currentLevel   = registration->GetCurrentLevel();
    const unsigned long numberOfLevels = registration->GetNumberOfLevels();

    // Step lengths used at the coarsest level (level 0).
    const double coarsestMaximumStepLength = 0.1;
    const double coarsestMinimumStepLength = 0.001;

    // Fraction of the coarsest-level step lengths used at this level:
    // N/N at level 0, then (N-1)/N, (N-2)/N, ..., 1/N at the finest level.
    // Finer pyramid levels need smaller steps, so the factor must decrease
    // monotonically with the level index.
    double levelFactor = 1.0;
    if( numberOfLevels > 0 && currentLevel < numberOfLevels )
      {
      levelFactor = static_cast< double >( numberOfLevels - currentLevel )
                  / static_cast< double >( numberOfLevels );
      }

    optimizer->SetMaximumStepLength( coarsestMaximumStepLength * levelFactor );
    optimizer->SetMinimumStepLength( coarsestMinimumStepLength * levelFactor );

    std::cout << "Multi-resolution level: " << currentLevel << std::endl;
    }

  // The const overload is required by itk::Command; nothing to do here.
  void Execute(const itk::Object * , const itk::EventObject & )
    { return; }
};

int main(  )
{
/*  if( argc < 4 )
    {
    std::cerr << "Missing Parameters " << std::endl;
    std::cerr << "Usage: " << argv[0];
    std::cerr << " fixedImageFile  movingImageFile ";
    std::cerr << " outputImagefile  [differenceAfterRegistration] ";
    std::cerr << " [differenceBeforeRegistration] ";
    std::cerr << " [initialStepLength] "<< std::endl;
    return EXIT_FAILURE;
    }
 */
  const    unsigned int    Dimension = 3;
  typedef  unsigned char   PixelType;

  typedef itk::Image< PixelType, Dimension >  FixedImageType;
  typedef itk::Image< PixelType, Dimension >  MovingImageType;
  typedef   float     InternalPixelType;
  typedef itk::Image< InternalPixelType, Dimension > InternalImageType;

  typedef itk::Euler3DTransform< double > TransformType;

  typedef itk::RegularStepGradientDescentOptimizer OptimizerType;

 typedef itk::MattesMutualInformationImageToImageMetric<
                                    InternalImageType,
                                    InternalImageType >    MetricType;

 typedef itk:: LinearInterpolateImageFunction<
                                    InternalImageType,
                                    double          >    InterpolatorType;


/*   typedef itk::BSplineInterpolateImageFunction<
                       InternalImageType, double >  InterpolatorType;
*/

  typedef itk::MultiResolutionImageRegistrationMethod<
                                    InternalImageType,
                                    InternalImageType >   RegistrationType;

  typedef itk::MultiResolutionPyramidImageFilter<
                                    InternalImageType,
                                    InternalImageType >
FixedImagePyramidType;
  typedef itk::MultiResolutionPyramidImageFilter<
                                    InternalImageType,
                                    InternalImageType >
MovingImagePyramidType;

  MetricType::Pointer         metric        = MetricType::New();
  OptimizerType::Pointer      optimizer     = OptimizerType::New();
  InterpolatorType::Pointer   interpolator  = InterpolatorType::New();
  RegistrationType::Pointer   registration  = RegistrationType::New();

  registration->SetMetric(        metric        );
  registration->SetOptimizer(     optimizer     );
  registration->SetInterpolator(  interpolator  );


  TransformType::Pointer  transform = TransformType::New();
  registration->SetTransform( transform );

  typedef itk::ImageFileReader< FixedImageType  > FixedImageReaderType;
  typedef itk::ImageFileReader< MovingImageType > MovingImageReaderType;

  FixedImageReaderType::Pointer  fixedImageReader  =
FixedImageReaderType::New();
  MovingImageReaderType::Pointer movingImageReader =
MovingImageReaderType::New();

  fixedImageReader->SetFileName("brainweb1e1a10f20.mha");
  movingImageReader->SetFileName( "brainweb1e1a10f20Rot10Tx15.mha" );

  fixedImageReader->Update();
  FixedImageType::SizeType size;
  size =
fixedImageReader->GetOutput()->GetLargestPossibleRegion().GetSize();

  unsigned int numberOfBins = 30;
  double percentOfSamples = 0.2;
  unsigned int numberOfSamples = (unsigned
int)(percentOfSamples*size[0]*size[1]*size[2]);
  metric->SetNumberOfSpatialSamples(numberOfSamples);
  metric->SetNumberOfHistogramBins(numberOfBins);

  typedef itk::CastImageFilter<
                        FixedImageType, InternalImageType >
FixedCastFilterType;
  typedef itk::CastImageFilter<
                        MovingImageType, InternalImageType >
MovingCastFilterType;

  FixedCastFilterType::Pointer fixedCaster   = FixedCastFilterType::New();
  MovingCastFilterType::Pointer movingCaster = MovingCastFilterType::New();

  fixedCaster->SetInput(fixedImageReader->GetOutput());
  movingCaster->SetInput(movingImageReader->GetOutput());

  FixedImagePyramidType::Pointer fixedImagePyramid =
      FixedImagePyramidType::New();
  MovingImagePyramidType::Pointer movingImagePyramid =
      MovingImagePyramidType::New();
  registration->SetFixedImagePyramid(fixedImagePyramid);
  registration->SetMovingImagePyramid(movingImagePyramid);

  registration->SetFixedImage(    fixedCaster->GetOutput()    );
  registration->SetMovingImage(   movingCaster->GetOutput()   );

////////////////////////////////////////////

  InternalImageType::IndexType regStart;
  regStart[0] = 0;
  regStart[1] = 0;
  regStart[2] = 0;

  InternalImageType::SizeType regSize;
  regSize[0] = size[0];
  regSize[1] = size[1];
  regSize[2] = size[2];

  InternalImageType::RegionType region;
  region.SetIndex(regStart);
  region.SetSize(regSize);

  registration->SetFixedImageRegion(region);

/////////////////////////////////////////////

  fixedImageReader->Update();

  movingImageReader->Update();


  typedef FixedImageType::SpacingType    SpacingType;
  typedef FixedImageType::PointType      OriginType;
  typedef FixedImageType::RegionType     RegionType;
  typedef FixedImageType::SizeType       SizeType;

  FixedImageType::Pointer fixedImage = fixedImageReader->GetOutput();

  const SpacingType fixedSpacing = fixedImage->GetSpacing();
  OriginType  fixedOrigin  = fixedImage->GetOrigin();
  const RegionType  fixedRegion  = fixedImage->GetLargestPossibleRegion();
  const SizeType    fixedSize    = fixedRegion.GetSize();


  TransformType::InputPointType centerFixed;

  centerFixed[0] = regStart[0] + fixedSpacing[0] * regSize[0] / 2.0 ;
  centerFixed[1] = regStart[1] + fixedSpacing[1] * regSize[1] / 2.0 ;
  centerFixed[2] = regStart[2] + fixedSpacing[2] * regSize[2] / 2.0 ;
  // moving image
  MovingImageType::Pointer movingImage = movingImageReader->GetOutput();

  const SpacingType movingSpacing = movingImage->GetSpacing();
  const OriginType  movingOrigin  = movingImage->GetOrigin();
  const RegionType  movingRegion  = movingImage->GetLargestPossibleRegion();
  const SizeType    movingSize    = movingRegion.GetSize();

  TransformType::InputPointType centerMoving;

  centerMoving[0] = movingOrigin[0] + movingSpacing[0] * movingSize[0] / 2.0
;
  centerMoving[1] = movingOrigin[1] + movingSpacing[1] * movingSize[1] / 2.0
;
  centerMoving[2] = movingOrigin[2] + movingSpacing[2] * movingSize[2] / 2.0
;

  transform->SetCenter( centerFixed );
  transform->SetTranslation( centerMoving - centerFixed );

 // transform->SetIdentity(  );

  registration->SetInitialTransformParameters( transform->GetParameters() );

  typedef OptimizerType::ScalesType       OptimizerScalesType;
  OptimizerScalesType optimizerScales( transform->GetNumberOfParameters() );
  const double translationScale = 1.0;

  optimizerScales[0] = 1.0;//rotation
  optimizerScales[1] = 1.0;//rotation
  optimizerScales[2] = 1.0;//rotation
  optimizerScales[3] = translationScale;
  optimizerScales[4] = translationScale;
  optimizerScales[5] = translationScale;

  optimizer->SetScales( optimizerScales );

/*  double initialStepLength = 5;


  // Power's Optimizer
  optimizer->SetMaximumIteration(200);
  optimizer->SetStepLength(2.0);
  optimizer->SetStepTolerance(0.01);
//  optimizer->SetMaximize(false);
*/
  // Create the Command observer and register it with the optimizer.
  //
 optimizer->SetNumberOfIterations( 200 );
//optimizer->SetMaximumStepLength( 10.00 );
//      optimizer->SetMinimumStepLength( 2 );
  CommandIterationUpdate::Pointer observer = CommandIterationUpdate::New();
  optimizer->AddObserver( itk::IterationEvent(), observer );
  // Create the command observer for registration
  //
  typedef RegistrationInterfaceCommand<RegistrationType> CommandType;
  CommandType::Pointer command = CommandType::New();
  registration->AddObserver( itk::IterationEvent(), command );

 // if ( argc > 6 ) {
// registration->SetNumberOfLevels(std::atoi(argv[6]));
 // } else {
   registration->SetNumberOfLevels( 5 );
 // }

  itk::TimeProbesCollectorBase timer;
  try
    {
 timer.Start("registration");
    registration->StartRegistration();
 timer.Stop("registration");
    }
  catch( itk::ExceptionObject & err )
    {
    std::cerr << "ExceptionObject caught !" << std::endl;
    std::cerr << err << std::endl;
    return EXIT_FAILURE;
    }

  timer.Report();

  OptimizerType::ParametersType finalParameters =
                    registration->GetLastTransformParameters();

  const double finalAngleX           = finalParameters[0];
  const double finalAngleY           = finalParameters[1];
  const double finalAngleZ           = finalParameters[2];
  const double finalTranslationX     = finalParameters[3];
  const double finalTranslationY     = finalParameters[4];
  const double finalTranslationZ     = finalParameters[5];

  const unsigned int numberOfIterations = optimizer->GetCurrentIteration();

  const double bestValue = optimizer->GetValue();


  // Print out results
  //
  const double finalAngleXInDegrees = finalAngleX * 45.0 / atan(1.0);
  const double finalAngleYInDegrees = finalAngleY * 45.0 / atan(1.0);
  const double finalAngleZInDegrees = finalAngleZ * 45.0 / atan(1.0);

  std::cout << "Result = " << std::endl;

  std::cout << " Angle X (degrees)   = " << finalAngleXInDegrees  <<
std::endl;
  std::cout << " Angle Y (degrees)   = " << finalAngleYInDegrees  <<
std::endl;
  std::cout << " Angle Z (degrees)   = " << finalAngleZInDegrees  <<
std::endl;
  std::cout << " Translation X = " << finalTranslationX  << std::endl;
  std::cout << " Translation Y = " << finalTranslationY  << std::endl;
  std::cout << " Translation Z = " << finalTranslationZ  << std::endl;
  std::cout << " Iterations    = " << numberOfIterations << std::endl;
  std::cout << " Metric value  = " << bestValue          << std::endl;

  return EXIT_SUCCESS;
}
-------------- next part --------------
An HTML attachment was scrubbed...
URL: http://public.kitware.com/pipermail/insight-users/attachments/20060116/148231c5/attachment.html


More information about the Insight-users mailing list