int main( int argc, char * argv [] )
{
if( argc < 7 )
{
std::cerr << "Usage: " << std::endl;
std::cerr << argv[0];
std::cerr << " inputScalarImage inputLabeledImage";
std::cerr << " outputLabeledImage numberOfIterations";
std::cerr << " smoothingFactor numberOfClasses";
std::cerr << " mean1 mean2 ... meanN " << std::endl;
return EXIT_FAILURE;
}
const char * inputImageFileName = argv[1];
const char * inputLabelImageFileName = argv[2];
const char * outputImageFileName = argv[3];
const unsigned int numberOfIterations = std::stoi( argv[4] );
const double smoothingFactor = std::stod( argv[5] );
const unsigned int numberOfClasses = std::stoi( argv[6] );
constexpr unsigned int numberOfArgumentsBeforeMeans = 7;
if( static_cast<unsigned int>(argc) <
numberOfClasses + numberOfArgumentsBeforeMeans )
{
std::cerr << "Error: " << std::endl;
std::cerr << numberOfClasses << " classes have been specified ";
std::cerr << "but not enough means have been provided in the command ";
std::cerr << "line arguments " << std::endl;
return EXIT_FAILURE;
}
using PixelType = signed short;
ReaderType::Pointer reader = ReaderType::New();
reader->SetFileName( inputImageFileName );
using LabelPixelType = unsigned char;
LabelReaderType::Pointer labelReader = LabelReaderType::New();
labelReader->SetFileName( inputLabelImageFileName );
ImageType, ArrayImageType >;
ScalarToArrayFilterType::Pointer
scalarToArrayFilter = ScalarToArrayFilterType::New();
scalarToArrayFilter->SetInput( reader->GetOutput() );
MRFFilterType::Pointer mrfFilter = MRFFilterType::New();
mrfFilter->SetInput( scalarToArrayFilter->GetOutput() );
mrfFilter->SetNumberOfClasses( numberOfClasses );
mrfFilter->SetMaximumNumberOfIterations( numberOfIterations );
mrfFilter->SetErrorTolerance( 1
e-7 );
mrfFilter->SetSmoothingFactor( smoothingFactor );
ArrayImageType,
LabelImageType >;
SupervisedClassifierType::Pointer classifier =
SupervisedClassifierType::New();
DecisionRuleType::Pointer classifierDecisionRule = DecisionRuleType::New();
classifier->SetDecisionRule( classifierDecisionRule );
using MembershipFunctionType =
using MembershipFunctionPointer = MembershipFunctionType::Pointer;
double meanDistance = 0;
MembershipFunctionType::CentroidType centroid(1);
for( unsigned int i=0; i < numberOfClasses; i++ )
{
MembershipFunctionPointer membershipFunction =
MembershipFunctionType::New();
centroid[0] = std::stod( argv[i+numberOfArgumentsBeforeMeans] );
membershipFunction->SetCentroid( centroid );
classifier->AddMembershipFunction( membershipFunction );
meanDistance += static_cast< double > (centroid[0]);
}
if (numberOfClasses > 0)
{
meanDistance /= numberOfClasses;
}
else
{
std::cerr << "ERROR: numberOfClasses is 0" << std::endl;
return EXIT_FAILURE;
}
mrfFilter->SetSmoothingFactor( smoothingFactor );
mrfFilter->SetNeighborhoodRadius( 1 );
std::vector< double > weights;
weights.push_back(1.5);
weights.push_back(2.0);
weights.push_back(1.5);
weights.push_back(2.0);
weights.push_back(0.0);
weights.push_back(2.0);
weights.push_back(1.5);
weights.push_back(2.0);
weights.push_back(1.5);
double totalWeight = 0;
for(std::vector< double >::const_iterator wcIt = weights.begin();
wcIt != weights.end(); ++wcIt )
{
totalWeight += *wcIt;
}
for(double & weight : weights)
{
weight =
static_cast< double > ( weight * meanDistance / (2 * totalWeight));
}
mrfFilter->SetMRFNeighborhoodWeight( weights );
mrfFilter->SetClassifier( classifier );
using OutputImageType = MRFFilterType::OutputImageType;
OutputImageType, RescaledOutputImageType >;
RescalerType::Pointer intensityRescaler = RescalerType::New();
intensityRescaler->SetOutputMinimum( 0 );
intensityRescaler->SetOutputMaximum( 255 );
intensityRescaler->SetInput( mrfFilter->GetOutput() );
WriterType::Pointer writer = WriterType::New();
writer->SetInput( intensityRescaler->GetOutput() );
writer->SetFileName( outputImageFileName );
try
{
writer->Update();
}
{
std::cerr << "Problem encountered while writing ";
std::cerr << " image file : " << argv[2] << std::endl;
std::cerr << excp << std::endl;
return EXIT_FAILURE;
}
std::cout << "Number of Iterations : ";
std::cout << mrfFilter->GetNumberOfIterations() << std::endl;
std::cout << "Stop condition: " << std::endl;
std::cout << " (1) Maximum number of iterations " << std::endl;
std::cout << " (2) Error tolerance: " << std::endl;
std::cout << mrfFilter->GetStopCondition() << std::endl;
return EXIT_SUCCESS;
}