int
main(int argc, char * argv[])
{
if (argc < 7)
{
std::cerr << "Usage: " << std::endl;
std::cerr << argv[0];
std::cerr << " inputScalarImage inputLabeledImage";
std::cerr << " outputLabeledImage numberOfIterations";
std::cerr << " smoothingFactor numberOfClasses";
std::cerr << " mean1 mean2 ... meanN " << std::endl;
return EXIT_FAILURE;
}
const char * inputImageFileName = argv[1];
const char * inputLabelImageFileName = argv[2];
const char * outputImageFileName = argv[3];
const unsigned int numberOfIterations = std::stoi(argv[4]);
const double smoothingFactor = std::stod(argv[5]);
const unsigned int numberOfClasses = std::stoi(argv[6]);
constexpr unsigned int numberOfArgumentsBeforeMeans = 7;
if (static_cast<unsigned int>(argc) <
numberOfClasses + numberOfArgumentsBeforeMeans)
{
std::cerr << "Error: " << std::endl;
std::cerr << numberOfClasses << " classes have been specified ";
std::cerr << "but not enough means have been provided in the command ";
std::cerr << "line arguments " << std::endl;
return EXIT_FAILURE;
}
using PixelType = short;
reader->SetFileName(inputImageFileName);
using LabelPixelType = unsigned char;
labelReader->SetFileName(inputLabelImageFileName);
using ScalarToArrayFilterType =
scalarToArrayFilter->SetInput(reader->GetOutput());
mrfFilter->SetInput(scalarToArrayFilter->GetOutput());
mrfFilter->SetNumberOfClasses(numberOfClasses);
mrfFilter->SetMaximumNumberOfIterations(numberOfIterations);
mrfFilter->SetErrorTolerance(1
e-7);
mrfFilter->SetSmoothingFactor(smoothingFactor);
using SupervisedClassifierType =
classifier->SetDecisionRule(classifierDecisionRule);
using MembershipFunctionType =
double meanDistance = 0;
MembershipFunctionType::CentroidType centroid(1);
for (unsigned int i = 0; i < numberOfClasses; ++i)
{
MembershipFunctionPointer membershipFunction =
centroid[0] = std::stod(argv[i + numberOfArgumentsBeforeMeans]);
membershipFunction->SetCentroid(centroid);
classifier->AddMembershipFunction(membershipFunction);
meanDistance += static_cast<double>(centroid[0]);
}
if (numberOfClasses > 0)
{
meanDistance /= numberOfClasses;
}
else
{
std::cerr << "ERROR: numberOfClasses is 0" << std::endl;
return EXIT_FAILURE;
}
mrfFilter->SetSmoothingFactor(smoothingFactor);
mrfFilter->SetNeighborhoodRadius(1);
std::vector<double> weights;
weights.push_back(1.5);
weights.push_back(2.0);
weights.push_back(1.5);
weights.push_back(2.0);
weights.push_back(0.0);
weights.push_back(2.0);
weights.push_back(1.5);
weights.push_back(2.0);
weights.push_back(1.5);
double totalWeight = 0;
for (double weight : weights)
{
totalWeight += weight;
}
for (double & weight : weights)
{
weight = static_cast<double>(weight * meanDistance / (2 * totalWeight));
}
mrfFilter->SetMRFNeighborhoodWeight(weights);
mrfFilter->SetClassifier(classifier);
using OutputImageType = MRFFilterType::OutputImageType;
using RescalerType =
RescaledOutputImageType>;
intensityRescaler->SetOutputMinimum(0);
intensityRescaler->SetOutputMaximum(255);
intensityRescaler->SetInput(mrfFilter->GetOutput());
writer->SetInput(intensityRescaler->GetOutput());
writer->SetFileName(outputImageFileName);
try
{
writer->Update();
}
catch (const itk::ExceptionObject & excp)
{
std::cerr << "Problem encountered while writing ";
std::cerr << " image file : " << argv[2] << std::endl;
std::cerr << excp << std::endl;
return EXIT_FAILURE;
}
std::cout << "Number of Iterations : ";
std::cout << mrfFilter->GetNumberOfIterations() << std::endl;
std::cout << "Stop condition: " << std::endl;
std::cout << " (1) Maximum number of iterations " << std::endl;
std::cout << " (2) Error tolerance: " << std::endl;
std::cout << mrfFilter->GetStopCondition() << std::endl;
return EXIT_SUCCESS;
}