[Insight-users] Wrong values when computing Mean and Covariance !!
Ricardo Ferrari
rjf.araraquara at gmail.com
Sun Dec 27 18:40:17 EST 2009
Dear ITK-Users,
I am working on a class to estimate the initial parameters (mean vector,
covariance matrix and weight vector) for a Gaussian mixture model. For
that I am using the KdTreeBasedKmeansEstimator and the refactored statistics
classes. More specifically, I am using itk::Statistics::CovarianceSampleFilter
to compute the mean vector and covariance matrix of each class.
However, the results I am getting are completely wrong. After tracing through
the code, it seems that the mean vector is always zero.
Could anybody help me solve this problem, please? I am not sure whether I am
missing something.
Thank you,
Ricardo
///
/// Create a image of array from different image contrasts
///
typedef itk::ScalarToArrayCastImageFilter< InputImageType,
ArrayImageType > CasterType;
CasterType::Pointer caster = CasterType::New();
for( int i=0; i < NumberOfContrasts; ++i )
caster->SetInput( i, ReadMincImage< InputImageType >(
inputFileName[i] ) );
caster->Update();
///
/// Wrap the array image as a list sample and organize that sample in a
/// k-d tree, which is what the k-means estimator operates on.
///
/// NOTE(review): the adaptor is typed on TArrayImageType while the caster
/// outputs ArrayImageType; presumably these name the same type — confirm.
///
typedef itk::Statistics::ImageToListSampleAdaptor< TArrayImageType >
  ImageToListSampleAdaptorType;
typedef itk::Statistics::WeightedCentroidKdTreeGenerator<
  ImageToListSampleAdaptorType > TreeGeneratorType;

// Sample view over the caster output.
typename ImageToListSampleAdaptorType::Pointer adaptor =
  ImageToListSampleAdaptorType::New();
adaptor->SetImage( caster->GetOutput() );

// Build the tree over that sample; m_BuckedSize is the leaf bucket size.
typename TreeGeneratorType::Pointer treeGenerator = TreeGeneratorType::New();
treeGenerator->SetSample( adaptor );
treeGenerator->SetBucketSize( m_BuckedSize );
treeGenerator->Update();
typedef itk::Statistics::KdTreeBasedKmeansEstimator < typename
TreeGeneratorType::KdTreeType > EstimatorType;
typename EstimatorType::Pointer estimator = EstimatorType::New();
///
/// Estimate classes mean vector
///
typename EstimatorType::ParametersType initialMeans( NumberOfContrasts *
NumberOfClasses );
estimator->SetParameters( initialMeans );
estimator->SetKdTree( treeGenerator->GetOutput() );
estimator->SetMaximumIteration( m_MaxNumberOfIterations );
estimator->SetCentroidPositionChangesThreshold(
m_CentroidChangePosThresh );
estimator->StartOptimization();
typename EstimatorType::ParametersType estimatedMeans =
estimator->GetParameters();
///
/// Compute classes labels
///
typedef itk::Statistics::SampleClassifierFilter<
ImageToListSampleAdaptorType > SampleClassifierFilterType;
typedef typename SampleClassifierFilterType::ClassLabelVectorObjectType
ClassLabelVectorObjectType;
typename ClassLabelVectorObjectType::Pointer classLabelsObject =
ClassLabelVectorObjectType::New();
typedef typename SampleClassifierFilterType::ClassLabelVectorType
ClassLabelVectorType;
ClassLabelVectorType& classLabelVector = classLabelsObject->Get();
for( unsigned int i=0; i < NumberOfClasses; i++ )
classLabelVector.push_back( typename
SampleClassifierFilterType::ClassLabelType( (i + 1) * 100 ) );
//Set a decision rule type
typedef itk::Statistics::MinimumDecisionRule2 DecisionRuleType;
typename DecisionRuleType::Pointer decisionRule =
DecisionRuleType::New();
const typename
SampleClassifierFilterType::MembershipFunctionVectorObjectType
*membershipFunctionsObject = estimator->GetOutput();
const typename
SampleClassifierFilterType::MembershipFunctionVectorType membershipFunctions
= membershipFunctionsObject->Get();
//Instantiate and pass all the required inputs to the filter
typename SampleClassifierFilterType::Pointer classifier =
SampleClassifierFilterType::New();
classifier->SetInput( adaptor );
classifier->SetNumberOfClasses( NumberOfClasses );
classifier->SetClassLabels( classLabelsObject );
classifier->SetDecisionRule( decisionRule );
classifier->SetMembershipFunctions( membershipFunctionsObject );
classifier->Update();
const typename SampleClassifierFilterType::MembershipSampleType*
membershipSample = classifier->GetOutput();
/*
const typename
SampleClassifierFilterType::MembershipSampleType::ClassLabelType label1( 100
);
const typename
SampleClassifierFilterType::MembershipSampleType::ClassSampleType*
subSample1 = membershipSample->GetClassSample( label1 );
std::cout << subSample1->GetTotalFrequency() << std::endl;
const typename
SampleClassifierFilterType::MembershipSampleType::ClassLabelType label2( 200
);
const typename
SampleClassifierFilterType::MembershipSampleType::ClassSampleType*
subSample2 = membershipSample->GetClassSample( label2 );
std::cout << subSample2->GetTotalFrequency() << std::endl;
const typename
SampleClassifierFilterType::MembershipSampleType::ClassLabelType label3( 300
);
const typename
SampleClassifierFilterType::MembershipSampleType::ClassSampleType*
subSample3 = membershipSample->GetClassSample( label3 );
std::cout << subSample3->GetTotalFrequency() << std::endl;
*/
///
/// Compute the covariance matrices and the weight vectors: for every class,
/// store [mean | covariance] parameters and the class weight (the fraction
/// of all samples assigned to that class).
///
typedef typename
  SampleClassifierFilterType::MembershipSampleType::ClassSampleType
  ClassSampleType;
typedef itk::Statistics::CovarianceSampleFilter< ClassSampleType >
  CovarianceSampleFilterType;
/// Resize weight class vector
TGMM_InitStrategy< TArrayImageType >::m_WeightClassesVector.SetSize(
  NumberOfClasses );
/// Now, get the mean & covariance into the parameters vector
for( unsigned int i = 0; i < NumberOfClasses; ++i )
  {
  // Same label scheme as the classifier: 100, 200, ...
  typename
    SampleClassifierFilterType::MembershipSampleType::ClassLabelType
    classLabel( ( i + 1 ) * 100 );
  /// Get a pointer only to the samples of classLabel
  const ClassSampleType* sampleClass =
    membershipSample->GetClassSample( classLabel );
  /* The result here is okay
  std::cout << sampleClass->Size() << std::endl;
  double sum = 0;
  for( int ii=0; ii < sampleClass->Size(); ii++ )
  {
  typename ClassSampleType::MeasurementVectorType val =
  sampleClass->GetMeasurementVectorByIndex( ii );
  sum = sum + val[0];
  }
  std::cout << "Mean value = " << (double)( sum / sampleClass->Size() ) <<
  std::endl;
  */
  /// Compute covariance matrix
  typename CovarianceSampleFilterType::Pointer covarianceFilter =
    CovarianceSampleFilterType::New();
  covarianceFilter->SetInput( sampleClass );
  covarianceFilter->Update();

  // NOTE(review): the post reports these means as zero/wrong, while the manual
  // mean over the same sampleClass (commented block above) is correct. That
  // points at CovarianceSampleFilter's handling of class subsamples rather
  // than at this code — confirm against the ITK version in use. The filter's
  // mean is only printed here; the stored mean is the k-means centroid.
  const typename
    CovarianceSampleFilterType::MeasurementVectorDecoratedType *MeanDecorator =
    covarianceFilter->GetMeanOutput();
  std::cout << "Mean Vector1: " << MeanDecorator->Get() << std::endl;
  std::cout << "Mean Vector2: " << covarianceFilter->GetMean() << std::endl;

  const typename CovarianceSampleFilterType::MatrixDecoratedType
    *CovDecorator = covarianceFilter->GetCovarianceMatrixOutput();
  typename CovarianceSampleFilterType::MatrixType cov =
    CovDecorator->Get();
  /* The results here are wrong */
  std::cout << "Covariance Matrix1: " << CovDecorator->Get() << std::endl;
  std::cout << "Covariance Matrix2: " <<
    covarianceFilter->GetCovarianceMatrix() << std::endl;

  /// Create a parameter array for the current class. Layout, for each
  /// contrast j in turn: mean[j] followed by covariance row j, i.e.
  /// NumberOfContrasts * (NumberOfContrasts + 1) values in total.
  /// The heap object is handed to m_Parameters (presumably released by its
  /// owner — confirm).
  ParametersType *params = new ParametersType( NumberOfContrasts *
                                               NumberOfContrasts +
                                               NumberOfContrasts );
  unsigned int index = 0;
  for( unsigned j = 0; j < NumberOfContrasts; j++ )
    {
    /// Copy mean component j (k-means centroid, layout c * N + j)
    params->SetElement( index++, static_cast< double >(
      estimatedMeans[i * NumberOfContrasts + j] ) );
    /// Copy covariance row j
    for( unsigned k = 0; k < NumberOfContrasts; k++ )
      {
      params->SetElement( index++, static_cast< double >( cov[j][k] ) );
      }
    }
  TGMM_InitStrategy< TArrayImageType >::m_Parameters.push_back( params );
  // Class weight = samples in this class / samples in the whole image.
  TGMM_InitStrategy< TArrayImageType >::m_WeightClassesVector.SetElement( i,
    static_cast< double >( sampleClass->GetTotalFrequency() ) /
    adaptor->GetTotalFrequency() );
  }
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://www.itk.org/pipermail/insight-users/attachments/20091227/03048f71/attachment-0001.htm>
More information about the Insight-users
mailing list