// Chooses which synthetic test image main() classifies.
//  - USE_CONTROLLED_IMAGE undefined (current setting): pseudo-random pixels.
//  - USE_CONTROLLED_IMAGE defined: fixed black background with a green and a
//    red square, i.e. three well separated colour clusters.
// To switch, replace the #undef below with #define. Note that the #undef also
// cancels any -DUSE_CONTROLLED_IMAGE passed on the compiler command line.
#undef USE_CONTROLLED_IMAGE

#if defined(USE_CONTROLLED_IMAGE)
// Defines and allocates `image`, then paints the fixed three-colour pattern.
static void ControlledImage(ImageType::Pointer image);
#else
// Defines and allocates `image`, then fills it with rand()-based pixel values.
static void RandomImage(ImageType::Pointer image);
#endif
int
main(int , char * [])
{
ImageType::Pointer image = ImageType::New();
#ifdef USE_CONTROLLED_IMAGE
ControlledImage(image);
#else
RandomImage(image);
#endif
ImageToListSampleFilterType::Pointer imageToListSampleFilter = ImageToListSampleFilterType::New();
imageToListSampleFilter->SetInput(image);
imageToListSampleFilter->Update();
unsigned int numberOfClasses = 3;
ParametersType params(numberOfClasses + numberOfClasses * numberOfClasses);
std::vector<ParametersType> initialParameters(numberOfClasses);
for (unsigned int i = 0; i < 3; i++)
{
params[i] = 5.0;
}
unsigned int counter = 0;
for (unsigned int i = 0; i < 3; i++)
{
for (unsigned int j = 0; j < 3; j++)
{
if (i == j)
{
params[3 + counter] = 5;
}
else
{
params[3 + counter] = 0;
}
counter++;
}
}
initialParameters[0] = params;
params[0] = 210.0;
params[1] = 5.0;
params[2] = 5.0;
counter = 0;
for (unsigned int i = 0; i < 3; i++)
{
for (unsigned int j = 0; j < 3; j++)
{
if (i == j)
{
params[3 + counter] = 5;
}
else
{
params[3 + counter] = 0;
}
counter++;
}
}
initialParameters[1] = params;
params[0] = 5.0;
params[1] = 210.0;
params[2] = 5.0;
counter = 0;
for (unsigned int i = 0; i < 3; i++)
{
for (unsigned int j = 0; j < 3; j++)
{
if (i == j)
{
params[3 + counter] = 5;
}
else
{
params[3 + counter] = 0;
}
counter++;
}
}
initialParameters[2] = params;
std::cout << "Initial parameters: " << std::endl;
for (unsigned int i = 0; i < numberOfClasses; i++)
{
std::cout << initialParameters[i] << std::endl;
}
std::cout << "Number of samples: " << imageToListSampleFilter->GetOutput()->GetTotalFrequency() << std::endl;
std::vector<ComponentType::Pointer> components;
for (unsigned int i = 0; i < numberOfClasses; i++)
{
components.push_back(ComponentType::New());
(components[i])->SetSample(imageToListSampleFilter->GetOutput());
(components[i])->SetParameters(initialParameters[i]);
}
using EstimatorType =
EstimatorType::Pointer estimator = EstimatorType::New();
estimator->SetSample(imageToListSampleFilter->GetOutput());
estimator->SetMaximumIteration(200);
initialProportions[0] = 0.33;
initialProportions[1] = 0.33;
initialProportions[2] = 0.33;
std::cout << "Initial proportions: " << initialProportions << std::endl;
estimator->SetInitialProportions(initialProportions);
for (unsigned int i = 0; i < numberOfClasses; i++)
{
estimator->AddComponent(components[i]);
}
estimator->Update();
for (unsigned int i = 0; i < numberOfClasses; i++)
{
std::cout << "Cluster[" << i << "]" << std::endl;
std::cout << " Parameters:" << std::endl;
std::cout << " " << (components[i])->GetFullParameters() << std::endl;
std::cout << " Proportion: ";
std::cout << " " << estimator->GetProportions()[i] << std::endl;
}
DecisionRuleType::Pointer decisionRule = DecisionRuleType::New();
using ClassLabelVectorObjectType = FilterType::ClassLabelVectorObjectType;
using ClassLabelVectorType = FilterType::ClassLabelVectorType;
ClassLabelVectorObjectType::Pointer classLabelsObject = ClassLabelVectorObjectType::New();
ClassLabelVectorType & classLabelVector = classLabelsObject->Get();
using ClassLabelType = FilterType::ClassLabelType;
ClassLabelType class0 = 0;
classLabelVector.push_back(class0);
ClassLabelType class1 = 1;
classLabelVector.push_back(class1);
ClassLabelType class2 = 2;
classLabelVector.push_back(class2);
FilterType::Pointer sampleClassifierFilter = FilterType::New();
sampleClassifierFilter->SetInput(imageToListSampleFilter->GetOutput());
sampleClassifierFilter->SetNumberOfClasses(numberOfClasses);
sampleClassifierFilter->SetClassLabels(classLabelsObject);
sampleClassifierFilter->SetDecisionRule(decisionRule);
sampleClassifierFilter->SetMembershipFunctions(estimator->GetOutput());
sampleClassifierFilter->Update();
const FilterType::MembershipSampleType * membershipSample = sampleClassifierFilter->GetOutput();
FilterType::MembershipSampleType::ConstIterator iter = membershipSample->Begin();
while (iter != membershipSample->End())
{
std::cout << (int)iter.GetMeasurementVector()[0] << " " << (int)iter.GetMeasurementVector()[1] << " "
<< (int)iter.GetMeasurementVector()[2] << " : " << iter.GetClassLabel() << std::endl;
++iter;
}
return EXIT_SUCCESS;
}
#ifdef USE_CONTROLLED_IMAGE
/**
 * Defines, allocates and paints a deterministic 10x10 test image.
 * Every pixel is black except:
 *   - a 2x2 square at indices [3,4] x [3,4], painted with the pixel named
 *     `green` (channel 1 = 255);
 *   - a 2x2 square at [7,8] x [7,8], painted with the pixel named `red`
 *     (channel 0 = 255).
 * This gives three clearly separated colour clusters.
 *
 * @param image  Image whose regions are overwritten and whose buffer is
 *               (re)allocated.
 *
 * NOTE(review): assumes ImageType's pixel is a 3-component vector and that
 * itkImageRegionIterator.h is included at the top of the file. Confirm both.
 */
void
ControlledImage(ImageType::Pointer image)
{
  ImageType::IndexType start;
  start[0] = 0;
  start[1] = 0;

  ImageType::SizeType size;
  size[0] = 10;
  size[1] = 10;

  // Both index and size must be set: a default-constructed region has zero
  // size, so Allocate() would produce an empty image.
  ImageType::RegionType region;
  region.SetIndex(start);
  region.SetSize(size);

  image->SetRegions(region);
  image->Allocate();

  ImageType::PixelType green;
  green[0] = 0;
  green[1] = 255;
  green[2] = 0;

  ImageType::PixelType red;
  red[0] = 255;
  red[1] = 0;
  red[2] = 0;

  ImageType::PixelType black;
  black[0] = 0;
  black[1] = 0;
  black[2] = 0;

  itk::ImageRegionIterator<ImageType> imageIterator(image, region);
  for (imageIterator.GoToBegin(); !imageIterator.IsAtEnd(); ++imageIterator)
  {
    const ImageType::IndexType index = imageIterator.GetIndex();
    const bool inGreenSquare = index[0] > 2 && index[0] < 5 && index[1] > 2 && index[1] < 5;
    const bool inRedSquare = index[0] > 6 && index[0] < 9 && index[1] > 6 && index[1] < 9;

    if (inGreenSquare)
    {
      imageIterator.Set(green);
    }
    else if (inRedSquare)
    {
      imageIterator.Set(red);
    }
    else
    {
      imageIterator.Set(black);
    }
  }
}
#else
/**
 * Defines and allocates a 10x10 image and fills each of the three channels of
 * every pixel with a pseudo-random value in [0, 255] taken from rand().
 * rand() is not seeded here, so results repeat unless srand() is called
 * elsewhere.
 *
 * @param image  Image whose regions are overwritten and whose buffer is
 *               (re)allocated.
 *
 * NOTE(review): assumes ImageType's pixel is a 3-component vector and that
 * itkImageRegionIterator.h is included at the top of the file. Confirm both.
 */
void
RandomImage(ImageType::Pointer image)
{
  ImageType::IndexType start;
  start[0] = 0;
  start[1] = 0;

  ImageType::SizeType size;
  size[0] = 10;
  size[1] = 10;

  // Both index and size must be set: a default-constructed region has zero
  // size, so Allocate() would produce an empty image.
  ImageType::RegionType region;
  region.SetIndex(start);
  region.SetSize(size);

  image->SetRegions(region);
  image->Allocate();

  itk::ImageRegionIterator<ImageType> imageIterator(image, region);
  for (imageIterator.GoToBegin(); !imageIterator.IsAtEnd(); ++imageIterator)
  {
    // rand() % 256 keeps every channel in [0, 255]. Scaling with
    // rand() * 255 would overflow int (undefined behaviour) for most values
    // rand() can return.
    ImageType::PixelType pixel;
    pixel[0] = rand() % 256;
    pixel[1] = rand() % 256;
    pixel[2] = rand() % 256;
    imageIterator.Set(pixel);
  }
}
#endif