/** Populates @p image with a synthetic RGB test pattern; defined after main(). */
static void CreateImage(ColorImageType::Pointer image);
int
main(int, char *[])
{
ColorImageType::Pointer image = ColorImageType::New();
CreateImage(image);
using MembershipFunctionPointer = MembershipFunctionType::Pointer;
using MembershipFunctionPointerVector = std::vector<MembershipFunctionPointer>;
ImageKmeansModelEstimatorType::Pointer kmeansEstimator = ImageKmeansModelEstimatorType::New();
kmeansEstimator->SetInputImage(image);
kmeansEstimator->SetNumberOfModels(3);
kmeansEstimator->SetThreshold(0.01);
kmeansEstimator->SetOffsetAdd(0.01);
kmeansEstimator->SetOffsetMultiply(0.01);
kmeansEstimator->SetMaxSplitAttempts(10);
kmeansEstimator->Update();
ClassifierType::Pointer classifier = ClassifierType::New();
DecisionRuleType::Pointer decisionRule = DecisionRuleType::New();
classifier->SetDecisionRule(decisionRule);
classifier->SetNumberOfClasses(3);
using ClassLabelVectorObjectType = ClassifierType::ClassLabelVectorObjectType;
using ClassLabelVectorType = ClassifierType::ClassLabelVectorType;
using MembershipFunctionVectorObjectType = ClassifierType::MembershipFunctionVectorObjectType;
using MembershipFunctionVectorType = ClassifierType::MembershipFunctionVectorType;
MembershipFunctionPointerVector kmeansMembershipFunctions = kmeansEstimator->GetMembershipFunctions();
MembershipFunctionVectorObjectType::Pointer membershipFunctionsVectorObject =
MembershipFunctionVectorObjectType::New();
classifier->SetMembershipFunctions(membershipFunctionsVectorObject);
MembershipFunctionVectorType & membershipFunctionsVector = membershipFunctionsVectorObject->Get();
for (auto & kmeansMembershipFunction : kmeansMembershipFunctions)
{
membershipFunctionsVector.push_back(kmeansMembershipFunction.GetPointer());
}
ClassLabelVectorObjectType::Pointer classLabelsObject = ClassLabelVectorObjectType::New();
classifier->SetClassLabels(classLabelsObject);
ClassLabelVectorType & classLabelsVector = classLabelsObject->Get();
classLabelsVector.push_back(50);
classLabelsVector.push_back(150);
classLabelsVector.push_back(250);
SampleAdaptorType::Pointer sample = SampleAdaptorType::New();
sample->SetImage(image);
classifier->SetInput(sample);
classifier->Update();
ScalarImageType::Pointer outputImage = ScalarImageType::New();
outputImage->SetRegions(image->GetLargestPossibleRegion());
outputImage->Allocate();
outputImage->FillBuffer(0);
const ClassifierType::MembershipSampleType * membershipSample = classifier->GetOutput();
ClassifierType::MembershipSampleType::ConstIterator membershipIterator = membershipSample->Begin();
outputImage->GetLargestPossibleRegion());
while (membershipIterator != membershipSample->End())
{
int classLabel = membershipIterator.GetClassLabel();
outputIterator.Set(classLabel);
++membershipIterator;
++outputIterator;
}
WriterType::Pointer inputWriter = WriterType::New();
inputWriter->SetFileName("input.mha");
inputWriter->SetInput(image);
inputWriter->Update();
ScalarWriterType::Pointer outputWriter = ScalarWriterType::New();
outputWriter->SetFileName("output.mha");
outputWriter->SetInput(outputImage);
outputWriter->Update();
return EXIT_SUCCESS;
}
/**
 * Allocates @p image as a 200 x 300 region starting at index (0, 0) and fills
 * it with a test pattern:
 *  - red   where 100 < x < 150 and 100 < y < 150,
 *  - green where  50 < x <  70 and  50 < y <  70,
 *  - black everywhere else.
 *
 * @param image Image to (re)allocate and fill; modified in place.
 */
void
CreateImage(ColorImageType::Pointer image)
{
  ColorImageType::IndexType start;
  start[0] = 0;
  start[1] = 0;

  ColorImageType::SizeType size;
  size[0] = 200;
  size[1] = 300;

  // The size must be set on the region; otherwise the region stays empty and
  // nothing is allocated or iterated.
  ColorImageType::RegionType region;
  region.SetSize(size);
  region.SetIndex(start);

  image->SetRegions(region);
  image->Allocate();

  ColorImageType::PixelType redPixel;
  redPixel[0] = 255;
  redPixel[1] = 0;
  redPixel[2] = 0;

  ColorImageType::PixelType greenPixel;
  greenPixel[0] = 0;
  greenPixel[1] = 255;
  greenPixel[2] = 0;

  ColorImageType::PixelType blackPixel;
  blackPixel[0] = 0;
  blackPixel[1] = 0;
  blackPixel[2] = 0;

  using IndexValueType = ColorImageType::IndexType::IndexValueType;

  // True when idx lies strictly between lo and hi on both axes (open square).
  const auto insideOpenSquare = [](const ColorImageType::IndexType & idx, IndexValueType lo, IndexValueType hi) {
    return idx[0] > lo && idx[0] < hi && idx[1] > lo && idx[1] < hi;
  };

  // NOTE(review): requires itkImageRegionIterator.h; assumed to be included above this chunk.
  itk::ImageRegionIterator<ColorImageType> imageIterator(image, region);
  while (!imageIterator.IsAtEnd())
  {
    const ColorImageType::IndexType index = imageIterator.GetIndex();
    if (insideOpenSquare(index, 100, 150))
    {
      imageIterator.Set(redPixel);
    }
    else if (insideOpenSquare(index, 50, 70))
    {
      imageIterator.Set(greenPixel);
    }
    else
    {
      imageIterator.Set(blackPixel);
    }
    ++imageIterator;
  }
}