[Insight-users] Wrong values when computing Mean and Covariance !!

Ricardo Ferrari rjf.araraquara at gmail.com
Sun Dec 27 18:40:17 EST 2009


Dear ITK-Users,

I am working on a class to estimate the initial parameters (mean vector,
covariance matrix and weighting vector) of a Gaussian Mixture Model. For
that I am using the KdTreeBasedKmeansEstimator and the refactored statistics
classes. More specifically, I am using itk::Statistics::CovarianceSampleFilter
to compute the mean vectors and covariance matrices.
However, the results I am getting are completely wrong. After tracing through the
code, it seems that the mean vector is always zero.

Could anybody please help me solve this problem? I am not sure whether I am
missing something.

Thank you,
Ricardo


    ///
    /// Create a image of array from different image contrasts
    ///
    typedef itk::ScalarToArrayCastImageFilter< InputImageType,
ArrayImageType > CasterType;
    CasterType::Pointer caster = CasterType::New();
    for( int i=0; i < NumberOfContrasts; ++i )
        caster->SetInput( i, ReadMincImage< InputImageType >(
inputFileName[i] ) );
    caster->Update();


    // Expose the array image produced by the caster as a list sample, so the
    // statistics framework can treat each pixel as one measurement vector.
    // NOTE(review): the adaptor is instantiated on TArrayImageType while the
    // caster outputs ArrayImageType -- presumably the same type; confirm.
    typedef itk::Statistics::ImageToListSampleAdaptor< TArrayImageType >
        ImageToListSampleAdaptorType;
    typename ImageToListSampleAdaptorType::Pointer adaptor =
        ImageToListSampleAdaptorType::New();
    adaptor->SetImage( caster->GetOutput() );

    // Build a weighted-centroid k-d tree over that sample. Its output is what
    // the k-means estimator consumes; m_BuckedSize sets the tree bucket size.
    typedef itk::Statistics::WeightedCentroidKdTreeGenerator< ImageToListSampleAdaptorType >
        TreeGeneratorType;
    typename TreeGeneratorType::Pointer treeGenerator = TreeGeneratorType::New();
    treeGenerator->SetSample( adaptor );
    treeGenerator->SetBucketSize( m_BuckedSize );
    treeGenerator->Update();

    typedef itk::Statistics::KdTreeBasedKmeansEstimator < typename
TreeGeneratorType::KdTreeType > EstimatorType;
    typename EstimatorType::Pointer estimator = EstimatorType::New();

    ///
    /// Estimate classes mean vector
    ///
    typename EstimatorType::ParametersType initialMeans( NumberOfContrasts *
NumberOfClasses );
    estimator->SetParameters( initialMeans );
    estimator->SetKdTree( treeGenerator->GetOutput() );
    estimator->SetMaximumIteration( m_MaxNumberOfIterations );
    estimator->SetCentroidPositionChangesThreshold(
m_CentroidChangePosThresh );
    estimator->StartOptimization();

    typename EstimatorType::ParametersType estimatedMeans =
estimator->GetParameters();

    ///
    /// Compute classes labels: assign every sample to one of the k-means classes
    ///
    typedef itk::Statistics::SampleClassifierFilter< ImageToListSampleAdaptorType >
        SampleClassifierFilterType;

    // Class i gets the label (i + 1) * 100. The covariance loop later in this
    // method looks the class subsamples up with the same labels.
    typedef typename SampleClassifierFilterType::ClassLabelVectorObjectType
        ClassLabelVectorObjectType;
    typename ClassLabelVectorObjectType::Pointer classLabelsObject =
        ClassLabelVectorObjectType::New();

    typedef typename SampleClassifierFilterType::ClassLabelVectorType
        ClassLabelVectorType;
    ClassLabelVectorType& classLabelVector = classLabelsObject->Get();
    for( unsigned int i = 0; i < NumberOfClasses; i++ )
    {
        classLabelVector.push_back(
            typename SampleClassifierFilterType::ClassLabelType( (i + 1) * 100 ) );
    }

    // Set a decision rule type.
    // NOTE(review): MinimumDecisionRule2 presumably selects the class with the
    // smallest membership score, i.e. the nearest k-means centroid; confirm.
    typedef itk::Statistics::MinimumDecisionRule2 DecisionRuleType;
    typename DecisionRuleType::Pointer decisionRule = DecisionRuleType::New();

    // Membership functions produced by the k-means estimator, one per class.
    // (The former local copy of the underlying vector was never used.)
    const typename SampleClassifierFilterType::MembershipFunctionVectorObjectType
        *membershipFunctionsObject = estimator->GetOutput();

    // Instantiate and pass all the required inputs to the filter
    typename SampleClassifierFilterType::Pointer classifier =
        SampleClassifierFilterType::New();
    classifier->SetInput( adaptor );
    classifier->SetNumberOfClasses( NumberOfClasses );
    classifier->SetClassLabels( classLabelsObject );
    classifier->SetDecisionRule( decisionRule );
    classifier->SetMembershipFunctions( membershipFunctionsObject );
    classifier->Update();

    const typename SampleClassifierFilterType::MembershipSampleType*
        membershipSample = classifier->GetOutput();

/*
    const typename
SampleClassifierFilterType::MembershipSampleType::ClassLabelType label1( 100
);
      const typename
SampleClassifierFilterType::MembershipSampleType::ClassSampleType*
subSample1 = membershipSample->GetClassSample( label1 );
    std::cout << subSample1->GetTotalFrequency() << std::endl;

    const typename
SampleClassifierFilterType::MembershipSampleType::ClassLabelType label2( 200
);
      const typename
SampleClassifierFilterType::MembershipSampleType::ClassSampleType*
subSample2 = membershipSample->GetClassSample( label2 );
    std::cout << subSample2->GetTotalFrequency() << std::endl;

    const typename
SampleClassifierFilterType::MembershipSampleType::ClassLabelType label3( 300
);
      const typename
SampleClassifierFilterType::MembershipSampleType::ClassSampleType*
subSample3 = membershipSample->GetClassSample( label3 );
    std::cout << subSample3->GetTotalFrequency() << std::endl;
*/



    ///
    /// Compute the covariance matrices and the weight vectors
    ///
    typedef typename
SampleClassifierFilterType::MembershipSampleType::ClassSampleType
ClassSampleType;
    typedef itk::Statistics::CovarianceSampleFilter< ClassSampleType >
                CovarianceSampleFilterType;

    /// Resize weight class vector
    TGMM_InitStrategy< TArrayImageType >::m_WeightClassesVector.SetSize(
NumberOfClasses );

    /// Now, get the mean & covariance into the parameters vector
    for( unsigned int i=0; i < NumberOfClasses; ++i )
    {
        typename
SampleClassifierFilterType::MembershipSampleType::ClassLabelType classLabel(
(i + 1) * 100 );

        /// Get a pointer only to the samples of classLabel
        const ClassSampleType* sampleClass =
membershipSample->GetClassSample( classLabel );

/*  The result here is okay
std::cout << sampleClass->Size() << std::endl;

        double sum = 0;
        for( int ii=0; ii < sampleClass->Size(); ii++ )
        {
            typename ClassSampleType::MeasurementVectorType val =
sampleClass->GetMeasurementVectorByIndex( ii );
            sum = sum + val[0];
        }
std::cout << "Mean value = " << (double)( sum / sampleClass->Size() ) <<
std::endl;
*/


        /// Compute covariance matrix
        typename CovarianceSampleFilterType::Pointer covarianceFilter =
CovarianceSampleFilterType::New();
        covarianceFilter->SetInput( sampleClass );
        covarianceFilter->Update();

          const typename
CovarianceSampleFilterType::MeasurementVectorDecoratedType *MeanDecorator =
covarianceFilter->GetMeanOutput();
          typename CovarianceSampleFilterType::MeasurementVectorType mean =
MeanDecorator->Get();

/* The results here are wrong */
std::cout << "Mean Vector1: " << MeanDecorator->Get() << std::endl;
std::cout << "Mean Vector2: " << covarianceFilter->GetMean() << std::endl;

          const typename CovarianceSampleFilterType::MatrixDecoratedType
*CovDecorator = covarianceFilter->GetCovarianceMatrixOutput();
          typename CovarianceSampleFilterType::MatrixType cov =
CovDecorator->Get();

/* The results here are wrong */
std::cout << "Covariance Matrix1: " << CovDecorator->Get() << std::endl;
std::cout << "Covariance Matrix2: " <<
covarianceFilter->GetCovarianceMatrix() << std::endl;

        /// Create a parameter array for the current class
        ParametersType *params = new ParametersType( NumberOfContrasts *
NumberOfContrasts + NumberOfContrasts );

        unsigned int index = 0;
        for( unsigned j=0; j < NumberOfContrasts; j++ )
        {
            /// Copy mean vector
            params->SetElement( index++, static_cast< double >(
estimatedMeans[i * NumberOfContrasts + j] ) );

            /// Copy covariance matrix
            for( unsigned k=0; k < NumberOfContrasts; k++ )
                params->SetElement( index++, (double)cov[j][k] );
        }

        TGMM_InitStrategy< TArrayImageType >::m_Parameters.push_back( params
);
        TGMM_InitStrategy< TArrayImageType
>::m_WeightClassesVector.SetElement( i,
(double)sampleClass->GetTotalFrequency() / adaptor->GetTotalFrequency() );

    }
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://www.itk.org/pipermail/insight-users/attachments/20091227/03048f71/attachment-0001.htm>


More information about the Insight-users mailing list