int
main()
{
SampleType::Pointer sample = SampleType::New();
sample->SetMeasurementVectorSize(2);
MeasurementVectorType mv;
for (unsigned int i = 0; i < 1000; ++i)
{
mv[0] = (float)i;
mv[1] = (float)((1000 - i) / 2);
sample->PushBack(mv);
}
TreeGeneratorType::Pointer treeGenerator = TreeGeneratorType::New();
treeGenerator->SetSample(sample);
treeGenerator->SetBucketSize(16);
treeGenerator->Update();
using CentroidTreeGeneratorType =
CentroidTreeGeneratorType::Pointer centroidTreeGenerator =
CentroidTreeGeneratorType::New();
centroidTreeGenerator->SetSample(sample);
centroidTreeGenerator->SetBucketSize(16);
centroidTreeGenerator->Update();
using TreeType = TreeGeneratorType::KdTreeType;
using NodeType = TreeType::KdTreeNodeType;
TreeType::Pointer tree = treeGenerator->GetOutput();
TreeType::Pointer centroidTree = centroidTreeGenerator->GetOutput();
NodeType * root = tree->GetRoot();
if (root->IsTerminal())
{
std::cout << "Root node is a terminal node." << std::endl;
}
else
{
std::cout << "Root node is not a terminal node." << std::endl;
}
unsigned int partitionDimension;
float partitionValue;
root->GetParameters(partitionDimension, partitionValue);
std::cout << "Dimension chosen to split the space = " << partitionDimension
<< std::endl;
std::cout << "Split point on the partition dimension = " << partitionValue
<< std::endl;
std::cout << "Address of the left chile of the root node = " << root->Left()
<< std::endl;
std::cout << "Address of the right chile of the root node = "
<< root->Right() << std::endl;
root = centroidTree->GetRoot();
std::cout << "Number of the measurement vectors under the root node"
<< " in the tree hierarchy = " << root->Size() << std::endl;
NodeType::CentroidType centroid;
root->GetWeightedCentroid(centroid);
std::cout << "Sum of the measurement vectors under the root node = "
<< centroid << std::endl;
std::cout << "Number of the measurement vectors under the left child"
<< " of the root node = " << root->Left()->Size() << std::endl;
MeasurementVectorType queryPoint;
queryPoint[0] = 10.0;
queryPoint[1] = 7.0;
using DistanceMetricType =
DistanceMetricType::Pointer distanceMetric = DistanceMetricType::New();
DistanceMetricType::OriginType origin(2);
for (unsigned int i = 0; i < sample->GetMeasurementVectorSize(); ++i)
{
origin[i] = queryPoint[i];
}
distanceMetric->SetOrigin(origin);
unsigned int numberOfNeighbors = 3;
TreeType::InstanceIdentifierVectorType neighbors;
tree->Search(queryPoint, numberOfNeighbors, neighbors);
std::cout
<< "\n*** kd-tree knn search result using an Euclidean distance metric:"
<< std::endl
<< "query point = [" << queryPoint << "]" << std::endl
<< "k = " << numberOfNeighbors << std::endl;
std::cout << "measurement vector : distance from querry point "
<< std::endl;
std::vector<double> distances1(numberOfNeighbors);
for (unsigned int i = 0; i < numberOfNeighbors; ++i)
{
distances1[i] =
distanceMetric->Evaluate(tree->GetMeasurementVector(neighbors[i]));
std::cout << "[" << tree->GetMeasurementVector(neighbors[i])
<< "] : " << distances1[i] << std::endl;
}
std::vector<double> distances2;
tree->Search(queryPoint, numberOfNeighbors, neighbors, distances2);
std::cout << "\n*** kd-tree knn search result directly from tree:"
<< std::endl
<< "query point = [" << queryPoint << "]" << std::endl
<< "k = " << numberOfNeighbors << std::endl;
std::cout << "measurement vector : distance from querry point "
<< std::endl;
for (unsigned int i = 0; i < numberOfNeighbors; ++i)
{
std::cout << "[" << tree->GetMeasurementVector(neighbors[i])
<< "] : " << distances2[i] << std::endl;
{
std::cerr << "Mismatched distance values by tree." << std::endl;
return EXIT_FAILURE;
}
}
std::vector<double> distances3;
centroidTree->Search(queryPoint, numberOfNeighbors, neighbors, distances3);
centroidTree->Search(queryPoint, numberOfNeighbors, neighbors);
std::cout << "\n*** Weighted centroid kd-tree knn search result:"
<< std::endl
<< "query point = [" << queryPoint << "]" << std::endl
<< "k = " << numberOfNeighbors << std::endl;
std::cout
<< "measurement vector : distance_by_distMetric : distance_by_tree"
<< std::endl;
std::vector<double> distances4(numberOfNeighbors);
for (unsigned int i = 0; i < numberOfNeighbors; ++i)
{
distances4[i] = distanceMetric->Evaluate(
centroidTree->GetMeasurementVector(neighbors[i]));
std::cout << "[" << centroidTree->GetMeasurementVector(neighbors[i])
<< "] : " << distances4[i] << " : "
<< distances3[i] << std::endl;
{
std::cerr << "Mismatched distance values by centroid tree."
<< std::endl;
return EXIT_FAILURE;
}
}
double radius = 437.0;
tree->Search(queryPoint, radius, neighbors);
std::cout << "\nSearching points within a hyper-spherical kernel:"
<< std::endl;
std::cout << "*** kd-tree radius search result:" << std::endl
<< "query point = [" << queryPoint << "]" << std::endl
<< "search radius = " << radius << std::endl;
std::cout << "measurement vector : distance" << std::endl;
for (auto neighbor : neighbors)
{
std::cout << "[" << tree->GetMeasurementVector(neighbor) << "] : "
<< distanceMetric->Evaluate(
tree->GetMeasurementVector(neighbor))
<< std::endl;
}
centroidTree->Search(queryPoint, radius, neighbors);
std::cout << "\n*** Weighted centroid kd-tree radius search result:"
<< std::endl
<< "query point = [" << queryPoint << "]" << std::endl
<< "search radius = " << radius << std::endl;
std::cout << "measurement vector : distance" << std::endl;
for (auto neighbor : neighbors)
{
std::cout << "[" << centroidTree->GetMeasurementVector(neighbor) << "] : "
<< distanceMetric->Evaluate(
centroidTree->GetMeasurementVector(neighbor))
<< std::endl;
}
return EXIT_SUCCESS;
}