#include <cstddef>
#include <exception>
#include <iostream>
#include <math.h>
#include <string>
#include <vector>
#define NUM_CLASSES 3
#define MAX_NUM_ITER 1
int main( int argc, char *argv[] )
{
if( argc != 4 )
{
std::cerr << "Missing Parameters " << std::endl;
std::cerr << "Usage: " << argv[0];
std::cerr << " inputImage trainimage outputImage" << std::endl;
return 1;
}
std::cout<< "Gibbs Prior Test Begins: " << std::endl;
const unsigned short NUMBANDS = 1;
const unsigned short NDIMENSION = 3;
NDIMENSION> VecImageType;
ReaderType::Pointer inputimagereader = ReaderType::New();
ReaderType::Pointer trainingimagereader = ReaderType::New();
WriterType::Pointer writer = WriterType::New();
inputimagereader->SetFileName( argv[1] );
trainingimagereader->SetFileName( argv[2] );
writer->SetFileName( argv[3] );
VecImageType::Pointer vecImage = VecImageType::New();
typedef VecImageType::PixelType VecImagePixelType;
VecImageType::SizeType vecImgSize = { {181 , 217, 1} };
VecImageType::IndexType index;
index.Fill(0);
VecImageType::RegionType region;
region.SetSize( vecImgSize );
region.SetIndex( index );
vecImage->SetLargestPossibleRegion( region );
vecImage->SetBufferedRegion( region );
vecImage->Allocate();
enum { VecImageDimension = VecImageType::ImageDimension };
VecIterator vecIt( vecImage, vecImage->GetBufferedRegion() );
inputimagereader->Update();
trainingimagereader->Update();
ClassIterator inputIt( inputimagereader->GetOutput(), inputimagereader->GetOutput()->GetBufferedRegion() );
typedef VecImageType::PixelType DataVector;
DataVector dblVec;
while ( !vecIt.IsAtEnd() )
{
dblVec[0] = inputIt.Get();
vecIt.Set(dblVec);
++vecIt;
++inputIt;
}
namespace stat = itk::Statistics;
typedef VecImageType::PixelType VecImagePixelType;
typedef stat::MahalanobisDistanceMembershipFunction< VecImagePixelType >
MembershipFunctionType;
typedef MembershipFunctionType::Pointer MembershipFunctionPointer;
typedef std::vector< MembershipFunctionPointer >
MembershipFunctionPointerVector;
MembershipFunctionType, ClassImageType>
ImageGaussianModelEstimatorType;
ImageGaussianModelEstimatorType::Pointer
applyEstimateModel = ImageGaussianModelEstimatorType::New();
applyEstimateModel->SetNumberOfModels(NUM_CLASSES);
applyEstimateModel->SetInputImage(vecImage);
applyEstimateModel->SetTrainingImage(trainingimagereader->GetOutput());
applyEstimateModel->Update();
std::cout << " site 1 " << std::endl;
applyEstimateModel->Print(std::cout);
MembershipFunctionPointerVector membershipFunctions =
applyEstimateModel->GetMembershipFunctions();
std::cout << " site 2 " << std::endl;
DecisionRuleType::Pointer myDecisionRule = DecisionRuleType::New();
std::cout << " site 3 " << std::endl;
ClassImageType > ClassifierType;
typedef ClassifierType::Pointer ClassifierPointer;
ClassifierPointer myClassifier = ClassifierType::New();
myClassifier->SetNumberOfClasses(NUM_CLASSES);
myClassifier->SetDecisionRule((DecisionRuleBasePointer) myDecisionRule );
for( unsigned int i=0; i<NUM_CLASSES; i++ )
{
myClassifier->AddMembershipFunction( membershipFunctions[i] );
}
GibbsPriorFilterType;
GibbsPriorFilterType::Pointer applyGibbsImageFilter =
GibbsPriorFilterType::New();
applyGibbsImageFilter->SetNumberOfClasses(NUM_CLASSES);
applyGibbsImageFilter->SetMaximumNumberOfIterations(MAX_NUM_ITER);
applyGibbsImageFilter->SetClusterSize(10);
applyGibbsImageFilter->SetBoundaryGradient(6);
applyGibbsImageFilter->SetObjectLabel(1);
applyGibbsImageFilter->SetInput(vecImage);
applyGibbsImageFilter->SetClassifier( myClassifier );
applyGibbsImageFilter->SetTrainingImage(trainingimagereader->GetOutput());
applyGibbsImageFilter->Update();
std::cout << "applyGibbsImageFilter: " << applyGibbsImageFilter;
writer->SetInput( applyGibbsImageFilter->GetOutput() );
writer->Update();
return 0;
}