# `BinarySpaceTree`

The `BinarySpaceTree` class represents a generic multidimensional binary space partitioning tree. It is heavily templatized to control splitting behavior and other behaviors, and it is the actual class underlying trees such as the `KDTree`. In general, the `BinarySpaceTree` class is not meant to be used directly; instead, one of its variants (such as `KDTree` or `MeanSplitKDTree`) should be used.

For users who want to use `BinarySpaceTree` directly or with custom behavior, the full class is still detailed in the subsections below. `BinarySpaceTree` supports the TreeType API and can be used with mlpack's tree-based algorithms, although using custom behavior may require a template typedef.
- Template parameters
- Constructors
- Basic tree properties
- Bounding distances with the tree
- `BoundType` template parameter
- `StatisticType` template parameter
- `SplitType` template parameter
- Tree traversals
- Example usage
### See also

- `KDTree`
- `MeanSplitKDTree`
- Binary space partitioning on Wikipedia
- Tree-Independent Dual-Tree Algorithms (pdf)
## Template parameters

The `BinarySpaceTree` class takes five template parameters. The first three of these are required by the TreeType API (see also this more detailed section). The full signature of the class is:

```cpp
template<typename DistanceType,
         typename StatisticType,
         typename MatType,
         template<typename BoundDistanceType,
                  typename BoundElemType,
                  typename...> class BoundType,
         template<typename SplitBoundType,
                  typename SplitMatType> class SplitType>
class BinarySpaceTree;
```
- `DistanceType`: the distance metric to use for distance computations. By default, this is `EuclideanDistance`.
- `StatisticType`: this holds auxiliary information in each tree node. By default, `EmptyStatistic` is used, which holds no information.
  - See the `StatisticType` section for more details.
- `MatType`: the type of matrix used to represent points. Must be a type matching the Armadillo API. By default, `arma::mat` is used, but other types such as `arma::fmat` or similar will work just fine.
- `BoundType`: the class defining the bound for each node. By default, `HRectBound` is used.
  - The `BoundType` may place additional restrictions on the `DistanceType` parameter; for instance, `HRectBound` requires that `DistanceType` be an `LMetric`.
  - See the `BoundType` section for more details.
- `SplitType`: the class defining how an individual `BinarySpaceTree` node should be split. By default, `MidpointSplit` is used.
  - See the `SplitType` section for more details.
Note that the TreeType API requires trees to have only three template parameters. In order to use a `BinarySpaceTree`, which has five template parameters, with an mlpack algorithm that needs a TreeType, it is easiest to define a template typedef:

```cpp
template<typename DistanceType, typename StatisticType, typename MatType>
using CustomTree = BinarySpaceTree<DistanceType, StatisticType, MatType,
    CustomBoundType, CustomSplitType>;
```

Here, `CustomBoundType` and `CustomSplitType` are the desired bound and split strategies. This is the way that all `BinarySpaceTree` variants (such as `KDTree`) are defined.
## Constructors

`BinarySpaceTree`s are efficiently constructed by permuting points in a dataset in a quicksort-like algorithm. However, this means that after construction, the ordering of points in the tree's dataset (accessed with `node.Dataset()`) may differ from the ordering in the original dataset.
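The following stdlib-only sketch makes the quicksort-like construction concrete. It assumes a `MidpointSplit`-style rule: split along the widest dimension at the midpoint of its range. It also stores points as `std::vector<double>` rather than Armadillo columns. `ToySplitNode` and `BuildToyTree()` are hypothetical names, and mlpack's actual implementation differs in detail. The key idea is that every swap made while partitioning is mirrored in `oldFromNew`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <utility>
#include <vector>

using Point = std::vector<double>;

// Hypothetical node record: descendants occupy [begin, begin + count) of the
// permuted dataset; `children` is empty for a leaf.
struct ToySplitNode
{
  std::size_t begin = 0;
  std::size_t count = 0;
  std::vector<ToySplitNode> children;
};

// Recursively split points[begin, begin + count) at the midpoint of its widest
// dimension, swapping points in place (quicksort-style) and mirroring each
// swap in oldFromNew so every point's original index is remembered.
ToySplitNode BuildToyTree(std::vector<Point>& points,
                          std::vector<std::size_t>& oldFromNew,
                          std::size_t begin, std::size_t count,
                          std::size_t maxLeafSize)
{
  ToySplitNode node;
  node.begin = begin;
  node.count = count;
  if (count <= maxLeafSize || count < 2)
    return node;

  // Find the widest dimension and the midpoint of its range.
  std::size_t splitDim = 0;
  double bestWidth = -1.0, splitVal = 0.0;
  for (std::size_t d = 0; d < points[begin].size(); ++d)
  {
    double lo = points[begin][d], hi = points[begin][d];
    for (std::size_t i = begin; i < begin + count; ++i)
    {
      lo = std::min(lo, points[i][d]);
      hi = std::max(hi, points[i][d]);
    }
    if (hi - lo > bestWidth)
    {
      bestWidth = hi - lo;
      splitDim = d;
      splitVal = (lo + hi) / 2.0;
    }
  }

  // Partition: points below splitVal move to the front, the rest to the back.
  std::size_t left = begin, right = begin + count;
  while (left < right)
  {
    if (points[left][splitDim] < splitVal)
    {
      ++left;
    }
    else
    {
      --right;
      std::swap(points[left], points[right]);
      std::swap(oldFromNew[left], oldFromNew[right]);
    }
  }

  // If nothing was separated (e.g. all points identical), keep this as a leaf.
  const std::size_t leftCount = left - begin;
  if (leftCount == 0 || leftCount == count)
    return node;

  node.children.push_back(
      BuildToyTree(points, oldFromNew, begin, leftCount, maxLeafSize));
  node.children.push_back(
      BuildToyTree(points, oldFromNew, left, count - leftCount, maxLeafSize));
  return node;
}

// Convenience wrapper: initialize oldFromNew to the identity, then build.
ToySplitNode BuildToyTree(std::vector<Point>& points,
                          std::vector<std::size_t>& oldFromNew,
                          std::size_t maxLeafSize)
{
  oldFromNew.resize(points.size());
  std::iota(oldFromNew.begin(), oldFromNew.end(), std::size_t{0});
  return BuildToyTree(points, oldFromNew, 0, points.size(), maxLeafSize);
}
```

After the call, `points[i]` equals the original point `oldFromNew[i]`, and each node's points occupy one contiguous range.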
- `node = BinarySpaceTree(data, maxLeafSize=20)`
- `node = BinarySpaceTree(data, oldFromNew, maxLeafSize=20)`
- `node = BinarySpaceTree(data, oldFromNew, newFromOld, maxLeafSize=20)`
  - Construct a `BinarySpaceTree` on the given `data`, using `maxLeafSize` as the maximum number of points held in a leaf.
  - Default template parameters are used, meaning that this tree will be a `KDTree`.
  - By default, `data` is copied. Avoid a copy by using `std::move()` (e.g. `std::move(data)`); when doing this, `data` will be set to an empty matrix.
  - Optionally, construct mappings from old points to new points. `oldFromNew` and `newFromOld` will have length `data.n_cols`, and:
    - `oldFromNew[i]` indicates that point `i` in the tree's dataset was originally point `oldFromNew[i]` in `data`; that is, `node.Dataset().col(i)` is the point `data.col(oldFromNew[i])`.
    - `newFromOld[i]` indicates that point `i` in `data` is now point `newFromOld[i]` in the tree's dataset; that is, `node.Dataset().col(newFromOld[i])` is the point `data.col(i)`.
- `node = BinarySpaceTree<DistanceType, StatisticType, MatType, BoundType, SplitType>(data, maxLeafSize=20)`
- `node = BinarySpaceTree<DistanceType, StatisticType, MatType, BoundType, SplitType>(data, oldFromNew, maxLeafSize=20)`
- `node = BinarySpaceTree<DistanceType, StatisticType, MatType, BoundType, SplitType>(data, oldFromNew, newFromOld, maxLeafSize=20)`
  - Construct a `BinarySpaceTree` on the given `data`, using custom template parameters to control the behavior of the tree, and using `maxLeafSize` as the maximum number of points held in a leaf.
  - By default, `data` is copied. Avoid a copy by using `std::move()` (e.g. `std::move(data)`); when doing this, `data` will be set to an empty matrix.
  - Optionally, construct mappings from old points to new points. `oldFromNew` and `newFromOld` will have length `data.n_cols`, and:
    - `oldFromNew[i]` indicates that point `i` in the tree's dataset was originally point `oldFromNew[i]` in `data`; that is, `node.Dataset().col(i)` is the point `data.col(oldFromNew[i])`.
    - `newFromOld[i]` indicates that point `i` in `data` is now point `newFromOld[i]` in the tree's dataset; that is, `node.Dataset().col(newFromOld[i])` is the point `data.col(i)`.
- `node = BinarySpaceTree()`
  - Construct an empty `BinarySpaceTree` with no children, no points, and default template parameters.
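The relationship between `oldFromNew` and `newFromOld` can be checked with plain vectors: each is the inverse permutation of the other. In this hypothetical sketch, `InvertMapping()` and `PermuteLikeTree()` are illustrative helpers, not mlpack functions. Each point is a single `double` standing in for one column of `data`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// newFromOld is the inverse permutation of oldFromNew: if position i of the
// tree's dataset holds original point oldFromNew[i], then original point
// oldFromNew[i] now lives at position i.
std::vector<std::size_t> InvertMapping(const std::vector<std::size_t>& oldFromNew)
{
  std::vector<std::size_t> newFromOld(oldFromNew.size());
  for (std::size_t i = 0; i < oldFromNew.size(); ++i)
    newFromOld[oldFromNew[i]] = i;
  return newFromOld;
}

// Reorder an original dataset the way the tree's dataset is ordered:
// permuted[i] corresponds to node.Dataset().col(i) == data.col(oldFromNew[i]).
template<typename T>
std::vector<T> PermuteLikeTree(const std::vector<T>& data,
                               const std::vector<std::size_t>& oldFromNew)
{
  std::vector<T> permuted(data.size());
  for (std::size_t i = 0; i < data.size(); ++i)
    permuted[i] = data[oldFromNew[i]];
  return permuted;
}
```

With `oldFromNew = {2, 0, 3, 1}`, the inverse is `newFromOld = {1, 3, 0, 2}`, and `permuted[newFromOld[i]] == data[i]` holds for every `i`.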
Notes:

- The name `node` is used here for `BinarySpaceTree` objects instead of `tree`, because each `BinarySpaceTree` object is a single node in the tree. The constructor returns the node that is the root of the tree.
- Inserting or removing individual points in a `BinarySpaceTree` is not supported, because this generally results in a tree with very loose bounding boxes. It is better to simply build a new `BinarySpaceTree` on the modified dataset. For trees that support individual insertions and deletions, see the `RectangleTree` class and all its variants (e.g. `RTree`, `RStarTree`, etc.).
- See also the developer documentation on tree constructors.
### Constructor parameters

| name | type | description | default |
|------|------|-------------|---------|
| `data` | `MatType` | Column-major matrix to build the tree on. Pass with `std::move(data)` to avoid copying the matrix. | (N/A) |
| `maxLeafSize` | `size_t` | Maximum number of points to store in each leaf. | `20` |
| `oldFromNew` | `std::vector<size_t>` | Mappings from points in `node.Dataset()` to points in `data`. | (N/A) |
| `newFromOld` | `std::vector<size_t>` | Mappings from points in `data` to points in `node.Dataset()`. | (N/A) |
## Basic tree properties

Once a `BinarySpaceTree` object is constructed, various properties of the tree can be accessed or inspected. Many of these functions are required by the TreeType API.
### Navigating the tree

- `node.NumChildren()` returns the number of children in `node`. This is either `2` if `node` has children, or `0` if `node` is a leaf.
- `node.IsLeaf()` returns a `bool` indicating whether or not `node` is a leaf.
- `node.Child(i)` returns a `BinarySpaceTree&` that is the `i`th child. `i` must be `0` or `1`.
  - This function should only be called if `node.NumChildren()` is not `0` (i.e. if `node` is not a leaf). Note that this returns a valid `BinarySpaceTree&` that can itself be used just like the root node of the tree!
- `node.Left()` and `node.Right()` are convenience functions specific to `BinarySpaceTree` that return `BinarySpaceTree*` (pointers) to the left and right children, respectively, or `NULL` if `node` has no children.
- `node.Parent()` will return a `BinarySpaceTree*` that points to the parent of `node`, or `NULL` if `node` is the root of the `BinarySpaceTree`.
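The navigation calls above are enough to write generic recursive code. The following self-contained sketch uses a hypothetical `ToyNavNode` class that exposes the same call names (`NumChildren()`, `IsLeaf()`, `Child(i)`, `Left()`, `Right()`, `Parent()`) so that it compiles without mlpack. `CountLeaves()` and `Depth()` use only those calls. They are written as templates and may also work with a real `BinarySpaceTree`, but that is an assumption that has not been tested here.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>

// Hypothetical stand-in exposing the documented navigation calls; this is not
// mlpack's BinarySpaceTree.
class ToyNavNode
{
 public:
  explicit ToyNavNode(ToyNavNode* parentNode = nullptr) : parent(parentNode) {}

  // Turn this leaf into an internal node with exactly two children.
  void Split()
  {
    left = std::make_unique<ToyNavNode>(this);
    right = std::make_unique<ToyNavNode>(this);
  }

  std::size_t NumChildren() const { return left ? 2 : 0; }
  bool IsLeaf() const { return NumChildren() == 0; }
  ToyNavNode& Child(std::size_t i) { return (i == 0) ? *left : *right; }
  ToyNavNode* Left() { return left.get(); }
  ToyNavNode* Right() { return right.get(); }
  ToyNavNode* Parent() { return parent; }

 private:
  ToyNavNode* parent;
  std::unique_ptr<ToyNavNode> left;
  std::unique_ptr<ToyNavNode> right;
};

// Count the leaves under `node` using only NumChildren(), IsLeaf(), and Child().
template<typename NodeType>
std::size_t CountLeaves(NodeType& node)
{
  if (node.IsLeaf())
    return 1;
  std::size_t total = 0;
  for (std::size_t i = 0; i < node.NumChildren(); ++i)
    total += CountLeaves(node.Child(i));
  return total;
}

// Depth of `node`: how many Parent() steps it takes to reach the root.
template<typename NodeType>
std::size_t Depth(NodeType& node)
{
  std::size_t depth = 0;
  for (NodeType* p = node.Parent(); p != nullptr; p = p->Parent())
    ++depth;
  return depth;
}
```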
### Accessing members of a tree

- `node.Bound()` will return a `BoundType&` object that represents the bound of `node`. For the default `HRectBound`, this is the smallest hyperrectangle that encloses all the descendant points of `node`.
- `node.Stat()` will return a `StatisticType&` holding the statistics of the node that were computed during tree construction.
- `node.Distance()` will return a `DistanceType&`.

See also the developer documentation for basic tree functionality in mlpack.
### Accessing data held in a tree

- `node.Dataset()` will return a `const MatType&` that is the dataset the tree was built on. Note that this is a permuted version of the `data` matrix passed to the constructor.
- `node.NumPoints()` returns a `size_t` indicating the number of points held directly in `node`.
  - If `node` is not a leaf, this will return `0`, as `BinarySpaceTree` only holds points directly in its leaves.
  - If `node` is a leaf, then the number of points will be less than or equal to the `maxLeafSize` that was specified when the tree was constructed.
- `node.Point(i)` returns a `size_t` indicating the index of the `i`'th point in `node.Dataset()`.
  - `i` must be in the range `[0, node.NumPoints() - 1]` (inclusive).
  - `node` must be a leaf (as non-leaves do not hold any points).
  - The `i`'th point in `node` can then be accessed as `node.Dataset().col(node.Point(i))`.
  - In a `BinarySpaceTree`, because of the permutation of points done during construction, point indices are contiguous: `node.Point(i + j)` is the same as `node.Point(i) + j` for valid `i` and `j`.
- `node.NumDescendants()` returns a `size_t` indicating the number of points held in all descendant leaves of `node`.
  - If `node` is the root of the tree, then `node.NumDescendants()` will be equal to `node.Dataset().n_cols`.
- `node.Descendant(i)` returns a `size_t` indicating the index of the `i`'th descendant point in `node.Dataset()`.
  - `i` must be in the range `[0, node.NumDescendants() - 1]` (inclusive).
  - `node` does not need to be a leaf.
  - The `i`'th descendant point in `node` can then be accessed as `node.Dataset().col(node.Descendant(i))`.
  - In a `BinarySpaceTree`, because of the permutation of points done during construction, point indices are contiguous: `node.Descendant(i + j)` is the same as `node.Descendant(i) + j` for valid `i` and `j`.
- `node.Begin()` returns a `size_t` indicating the index of the first descendant point of `node`.
  - This is equivalent to `node.Descendant(0)`.
- `node.Count()` returns a `size_t` indicating the number of descendant points of `node`.
  - This is equivalent to `node.NumDescendants()`.
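Because descendant indices are contiguous, a node's descendants can be described by just a starting index and a count, which is what `Begin()` and `Count()` report. The following sketch uses a hypothetical `NodeRange` struct and helper functions (not mlpack API) to illustrate this. It shows that `Descendant(i)` reduces to `begin + i`, and that all of a node's points can be read as one slice of the permuted dataset.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical summary of one node: where its descendants start in the
// permuted dataset and how many there are.
struct NodeRange
{
  std::size_t begin;
  std::size_t count;
};

// With contiguous indices, the i'th descendant index is simply begin + i.
std::size_t DescendantIndex(const NodeRange& node, std::size_t i)
{
  assert(i < node.count);
  return node.begin + i;
}

// Copy out a node's descendant points (one scalar per point here) as a single
// contiguous slice of the permuted dataset, without visiting any children.
std::vector<double> DescendantPoints(const std::vector<double>& permutedData,
                                     const NodeRange& node)
{
  const auto first = permutedData.begin() +
      static_cast<std::ptrdiff_t>(node.begin);
  return std::vector<double>(first,
      first + static_cast<std::ptrdiff_t>(node.count));
}

// The two children of a non-leaf node split its range into adjacent pieces.
bool ChildrenPartitionParent(const NodeRange& parent, const NodeRange& left,
                             const NodeRange& right)
{
  return left.begin == parent.begin &&
         right.begin == left.begin + left.count &&
         left.count + right.count == parent.count;
}
```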
### Accessing computed bound quantities of a tree

The following quantities are cached for each node in a `BinarySpaceTree`, so accessing them does not require any computation. In the documentation below, `ElemType` is the element type of the given `MatType`; e.g., if `MatType` is `arma::mat`, then `ElemType` is `double`.

- `node.FurthestPointDistance()` returns an `ElemType` representing the distance between the center of the bound of `node` and the furthest point held by `node`.
  - If `node` is not a leaf, this returns 0 (because `node` does not hold any points).
- `node.FurthestDescendantDistance()` returns an `ElemType` representing the distance between the center of the bound of `node` and the furthest descendant point held by `node`.
- `node.MinimumBoundDistance()` returns an `ElemType` representing the minimum possible distance from the center of the node to any edge of its bound.
- `node.ParentDistance()` returns an `ElemType` representing the distance between the center of the bound of `node` and the center of the bound of its parent.
  - If `node` is the root of the tree, `0` is returned.

Note: for more details on each bound quantity, see the developer documentation on bound quantities for trees.
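For a hyperrectangle bound under the Euclidean distance, these quantities have simple geometric meanings. The stdlib-only sketch below (`Box`, `Euclidean()`, and `FurthestDistance()` are hypothetical names) computes them. The furthest-descendant distance is measured from the box center. The minimum bound distance is half the smallest width, since the nearest face of a box lies that far from its center. The parent distance is the distance between the two centers. This illustrates the definitions only; it is not mlpack's code.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Point = std::vector<double>;

// Euclidean distance between two points of equal dimensionality.
double Euclidean(const Point& a, const Point& b)
{
  double sum = 0.0;
  for (std::size_t d = 0; d < a.size(); ++d)
    sum += (a[d] - b[d]) * (a[d] - b[d]);
  return std::sqrt(sum);
}

// Hypothetical axis-aligned box with per-dimension ranges [lo[d], hi[d]].
struct Box
{
  Point lo, hi;

  Point Center() const
  {
    Point c(lo.size());
    for (std::size_t d = 0; d < lo.size(); ++d)
      c[d] = (lo[d] + hi[d]) / 2.0;
    return c;
  }

  // Distance from the center to the nearest face: half the smallest width.
  double MinimumBoundDistance() const
  {
    double minWidth = hi[0] - lo[0];
    for (std::size_t d = 1; d < lo.size(); ++d)
      minWidth = std::min(minWidth, hi[d] - lo[d]);
    return minWidth / 2.0;
  }
};

// Distance from the box center to the furthest of `points`; applied to a
// node's descendants, this corresponds to FurthestDescendantDistance().
double FurthestDistance(const Box& box, const std::vector<Point>& points)
{
  const Point c = box.Center();
  double furthest = 0.0;
  for (const Point& p : points)
    furthest = std::max(furthest, Euclidean(c, p));
  return furthest;
}
```

For a child box, `Euclidean(child.Center(), parent.Center())` gives the analogue of `ParentDistance()`.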
### Other functionality

- `node.Center(center)` computes the center of the bound of `node` and stores it in `center`.
  - `center` should be of type `arma::Col<ElemType>&`, where `ElemType` is the element type of the specified `MatType`.
  - `center` will be set to have size equivalent to the dimensionality of the dataset held by `node`.
  - This is equivalent to calling `node.Bound().Center(center)`.
- A `BinarySpaceTree` can be serialized with `data::Save()` and `data::Load()`.
## Bounding distances with the tree

The primary use of trees in mlpack is bounding distances to points or other tree nodes. The following functions can be used for these tasks.
- `node.GetNearestChild(point)`
- `node.GetFurthestChild(point)`
  - Return a `size_t` indicating the index of the child (`0` for left, `1` for right) that is closest to (or furthest from) `point`, with respect to the `MinDistance()` (or `MaxDistance()`) function.
  - If there is a tie, `0` (the left child) is returned.
  - If `node` is a leaf, `0` is returned.
  - `point` should be a column vector with the same element type as `MatType`. (e.g., if `MatType` is `arma::mat`, then `point` should be an `arma::vec`.)
- `node.GetNearestChild(other)`
- `node.GetFurthestChild(other)`
  - Return a `size_t` indicating the index of the child (`0` for left, `1` for right) that is closest to (or furthest from) the `BinarySpaceTree` node `other`, with respect to the `MinDistance()` (or `MaxDistance()`) function.
  - If there is a tie, `2` (an invalid index) is returned. Note that this behavior differs from the version above that takes a point.
  - If `node` is a leaf, `0` is returned.
- `node.MinDistance(point)`
- `node.MinDistance(other)`
  - Return a `double` indicating the minimum possible distance between `node` and `point`, or the `BinarySpaceTree` node `other`.
  - This is equivalent to the minimum possible distance between any point contained in the bounding hyperrectangle of `node` and `point`, or between any point contained in the bounding hyperrectangle of `node` and any point contained in the bounding hyperrectangle of `other`.
  - `point` should be a column vector with the same element type as `MatType`. (e.g., if `MatType` is `arma::mat`, then `point` should be an `arma::vec`.)
- `node.MaxDistance(point)`
- `node.MaxDistance(other)`
  - Return a `double` indicating the maximum possible distance between `node` and `point`, or the `BinarySpaceTree` node `other`.
  - This is equivalent to the maximum possible distance between any point contained in the bounding hyperrectangle of `node` and `point`, or between any point contained in the bounding hyperrectangle of `node` and any point contained in the bounding hyperrectangle of `other`.
  - `point` should be a column vector with the same element type as `MatType`. (e.g., if `MatType` is `arma::mat`, then `point` should be an `arma::vec`.)
- `node.RangeDistance(point)`
- `node.RangeDistance(other)`
  - Return a `RangeType<ElemType>` whose lower bound is `node.MinDistance(point)` or `node.MinDistance(other)`, and whose upper bound is `node.MaxDistance(point)` or `node.MaxDistance(other)`.
  - `ElemType` is the element type of `MatType`.
  - `point` should be a column vector with the same element type as `MatType`. (e.g., if `MatType` is `arma::mat`, then `point` should be an `arma::vec`.)
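For a hyperrectangle bound with the Euclidean distance, the point versions of these functions come down to per-dimension arithmetic. The following is a hedged, stdlib-only sketch of that geometry; `Box`, `MinDistance()`, `MaxDistance()`, `NearestChild()`, and `FurthestChild()` are hypothetical names, not mlpack API. `NearestChild()` and `FurthestChild()` follow the documented tie rule for the point versions and return `0` (the left child) on a tie.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Point = std::vector<double>;

// Hypothetical axis-aligned box with per-dimension ranges [lo[d], hi[d]].
struct Box { Point lo, hi; };

// Smallest Euclidean distance from `p` to any point inside `b` (0 if inside).
double MinDistance(const Box& b, const Point& p)
{
  double sum = 0.0;
  for (std::size_t d = 0; d < p.size(); ++d)
  {
    // The per-dimension gap is positive only if p lies outside [lo, hi].
    const double gap = std::max({b.lo[d] - p[d], p[d] - b.hi[d], 0.0});
    sum += gap * gap;
  }
  return std::sqrt(sum);
}

// Largest Euclidean distance from `p` to any point inside `b`: in each
// dimension, take the farther of the two faces.
double MaxDistance(const Box& b, const Point& p)
{
  double sum = 0.0;
  for (std::size_t d = 0; d < p.size(); ++d)
  {
    const double farthest = std::max(std::abs(p[d] - b.lo[d]),
                                     std::abs(p[d] - b.hi[d]));
    sum += farthest * farthest;
  }
  return std::sqrt(sum);
}

// 0 for the left child, 1 for the right child; ties go to the left child.
std::size_t NearestChild(const Box& left, const Box& right, const Point& p)
{
  return (MinDistance(right, p) < MinDistance(left, p)) ? 1 : 0;
}

// Same convention, comparing maximum distances instead.
std::size_t FurthestChild(const Box& left, const Box& right, const Point& p)
{
  return (MaxDistance(right, p) > MaxDistance(left, p)) ? 1 : 0;
}
```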
## Tree traversals

Like every mlpack tree, the `BinarySpaceTree` class provides a single-tree and dual-tree traversal that can be paired with a `RuleType` class to implement a single-tree or dual-tree algorithm.

- `BinarySpaceTree::SingleTreeTraverser`
  - Implements a depth-first single-tree traverser.
- `BinarySpaceTree::DualTreeTraverser`
  - Implements a dual-depth-first dual-tree traverser.

In addition to those two classes, which are required by the TreeType API, an additional traverser is available:

- `BinarySpaceTree::BreadthFirstDualTreeTraverser`
  - Implements a dual-breadth-first dual-tree traverser.
  - Note: this traverser is not useful for all tasks. Because the `BinarySpaceTree` only holds points in the leaves, no base cases (i.e. comparisons between points) will be called until all pairs of intermediate nodes have been scored!
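The pairing of a traverser with a rule object can be illustrated on a toy 1-D tree. This is a hedged sketch of the general pattern only. `ToyTreeNode`, `NearestNeighborRule`, and `DepthFirstTraverse()` are hypothetical names, and the `BaseCase()`/`Score()` split is written in the spirit of a `RuleType` rather than copied from mlpack's exact interface. In the sketch, `Score()` returns infinity to prune a node, base cases run only in leaves, and the left child is simply visited before the right one.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <memory>
#include <vector>

// Hypothetical 1-D tree node: its descendants are points[begin, begin + count),
// all inside the interval [lo, hi]. Leaves have no children.
struct ToyTreeNode
{
  std::size_t begin, count;
  double lo, hi;
  std::unique_ptr<ToyTreeNode> left, right;
};

// Hypothetical nearest neighbor rule: BaseCase() compares the query with one
// reference point; Score() returns a lower bound on the distance to a node,
// or infinity if the node cannot contain a better neighbor.
struct NearestNeighborRule
{
  const std::vector<double>& points;
  double query;
  double bestDistance = std::numeric_limits<double>::infinity();
  std::size_t bestIndex = 0;
  std::size_t baseCases = 0;  // Number of point-to-point comparisons made.

  void BaseCase(std::size_t i)
  {
    ++baseCases;
    const double d = std::abs(points[i] - query);
    if (d < bestDistance)
    {
      bestDistance = d;
      bestIndex = i;
    }
  }

  double Score(const ToyTreeNode& node) const
  {
    const double minDist = (query < node.lo) ? (node.lo - query)
        : (query > node.hi) ? (query - node.hi) : 0.0;
    return (minDist > bestDistance) ? std::numeric_limits<double>::infinity()
                                    : minDist;
  }
};

// Depth-first single-tree traversal: score the node, skip it if pruned, run
// base cases in leaves, and otherwise recurse into the left then right child.
void DepthFirstTraverse(const ToyTreeNode& node, NearestNeighborRule& rule)
{
  if (rule.Score(node) == std::numeric_limits<double>::infinity())
    return;
  if (!node.left)
  {
    for (std::size_t i = node.begin; i < node.begin + node.count; ++i)
      rule.BaseCase(i);
    return;
  }
  DepthFirstTraverse(*node.left, rule);
  DepthFirstTraverse(*node.right, rule);
}
```

With points `{1.0, 2.0, 4.0, 5.0}` split into leaves `[1, 2]` and `[4, 5]`, a query of `1.2` finds its neighbor in the left leaf. The right leaf is then pruned, so only two base cases are run.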
## `BoundType`

Each node in a `BinarySpaceTree` corresponds to some region in space that contains all of the descendant points in the node. This region is represented by the `BoundType` class. The use of different `BoundType`s can mean different shapes for each node in the tree. For instance, the `HRectBound` class uses a hyperrectangle bound, which is the smallest rectangle that encloses all of the points.

mlpack supplies several drop-in `BoundType` classes, and it is also possible to write a custom `BoundType` for use with `BinarySpaceTree`:

- `HRectBound`: hyperrectangle bound; encloses the descendant points in the smallest possible hyperrectangle.
- `BallBound`: ball bound; encloses the descendant points in the ball with the smallest possible radius.
- `HollowBallBound`: hollow ball bound; equivalent to a ball bound with a ball subtracted from it.
- `CellBound`: bound enclosing a contiguous subregion of a hyperrectangle.
- Custom `BoundType`s: implement a fully custom `BoundType`.

Note: this section is still under construction; not all bound types are documented yet.
### `HRectBound`

The `HRectBound` class represents a hyper-rectangle bound; that is, a rectangle-shaped bound in arbitrary dimensions (e.g. a “box”). An `HRectBound` can be used to perform a variety of distance-based bounding tasks. `HRectBound` is used directly by the `KDTree` class.
#### Constructors

`HRectBound` allows configurable behavior via its two template parameters:

`HRectBound<DistanceType, ElemType>`

Different constructor forms can be used to specify different template parameters (and thus different bound behavior).
- `b = HRectBound(dimensionality)`
  - Construct an `HRectBound` with the given `dimensionality`.
  - The bound will be empty with an invalid center (i.e., `b` will not contain any points at all).
  - The bound will use the Euclidean distance for distance computation, and will expect data to have elements with type `double`.
- `b = HRectBound<DistanceType>(dimensionality)`
  - Construct an `HRectBound` with the given `dimensionality` that will use the given `DistanceType` class to compute distances.
  - `DistanceType` is required to be an `LMetric`, as the distance calculation must be decomposable across dimensions.
  - The bound will expect data to have elements with type `double`.
- `b = HRectBound<DistanceType, ElemType>(dimensionality)`
  - Construct an `HRectBound` with the given `dimensionality` that will use the given `DistanceType` class to compute distances, and expect data to have elements with type `ElemType`.
  - `DistanceType` is required to be an `LMetric`, as the distance calculation must be decomposable across dimensions.
  - `ElemType` should generally be `double` or `float`.
Note: these constructors provide an empty bound; be sure to grow the bound or directly modify the bound before using it!
#### Accessing and modifying properties of the bound

The individual bounds associated with each dimension of an `HRectBound` can be accessed and modified.

- `b.Clear()` will reset the bound to an empty bound (i.e. containing no points).
- `b.Dim()` will return a `size_t` indicating the dimensionality of the bound.
- `b[dim]` will return a `Range` object holding the lower and upper bounds of `b` in dimension `dim`.
- The lower and upper bounds of an `HRectBound` can be directly modified in a few ways:
  - `b[dim].Lo() = lo` will set the lower bound of `b` in dimension `dim` to `lo` (a `double`, or an `ElemType` if a custom `ElemType` is being used).
  - `b[dim].Hi() = hi` will set the upper bound of `b` in dimension `dim` to `hi`.
  - `b[dim] = Range(lo, hi)` will set the bounds for `b` in dimension `dim` to the (inclusive) range `[lo, hi]`.
  - Notes:
    - If a bound in a dimension is set such that `hi < lo`, then the bound will contain nothing and have zero volume.
    - Manually modifying bounds in this way will invalidate `MinWidth()`; if `MinWidth()` is to be used, call `b.RecomputeMinWidth()`.
- `b.MinWidth()` returns the minimum width of the bound in any dimension as a `double`. This value is cached and no computation is performed when calling `b.MinWidth()`. If the bound is empty, `0` is returned.
- `b.Distance()` returns either an `EuclideanDistance` distance metric object, or a `DistanceType` if a custom `DistanceType` has been specified in the constructor.
- `b.Center(center)` will compute the center of the `HRectBound` (i.e. the vector with elements equal to the midpoint of `b` in each dimension) and store it in the vector `center`. `center` should be of type `arma::vec`.
- `b.Volume()` computes the volume of the hyperrectangle specified by `b`. The volume is returned as a `double`.
- `b.Diameter()` computes the longest diagonal of the hyperrectangle specified by `b`.
- An `HRectBound` can be serialized with `data::Save()` and `data::Load()`.
Note: if a custom `ElemType` was specified in the constructor, then:

- `b[dim]` will return a `RangeType<ElemType>`;
- `b.MinWidth()`, `b.Volume()`, and `b.Diameter()` will return `ElemType`; and
- `b.Center(center)` expects `center` to be of type `arma::Col<ElemType>`.
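The quantities above follow directly from the per-dimension widths. The hedged sketch below assumes the Euclidean distance and uses a hypothetical `Interval`/`ToyHRect` representation instead of mlpack's `Range` and `HRectBound`. The volume is the product of the widths, and an empty dimension (`hi < lo`) is treated as width 0, so its volume is 0. The diameter (longest diagonal) is the Euclidean norm of the width vector, and the minimum width is the smallest width.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for one dimension of a hyperrectangle, [lo, hi].
struct Interval
{
  double lo, hi;
  // An interval with hi < lo is empty, so its width is treated as 0.
  double Width() const { return std::max(hi - lo, 0.0); }
};

using ToyHRect = std::vector<Interval>;

// Product of the per-dimension widths; any empty dimension gives volume 0.
double Volume(const ToyHRect& b)
{
  double volume = 1.0;
  for (const Interval& r : b)
    volume *= r.Width();
  return volume;
}

// Longest diagonal under the Euclidean distance: the norm of the width vector.
double Diameter(const ToyHRect& b)
{
  double sum = 0.0;
  for (const Interval& r : b)
    sum += r.Width() * r.Width();
  return std::sqrt(sum);
}

// Smallest width in any dimension (the value MinWidth() caches).
double MinWidth(const ToyHRect& b)
{
  double minWidth = b.empty() ? 0.0 : b[0].Width();
  for (const Interval& r : b)
    minWidth = std::min(minWidth, r.Width());
  return minWidth;
}
```

For the box `[0, 1] x [0, 2] x [0, 2]`, this gives volume 4, diameter 3, and minimum width 1.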
#### Growing and shrinking the bound

The `HRectBound` uses the `|=` and `&=` operators to perform set operations with data points or other bounds.

- `b |= data` expands `b` to include all of the data points in `data`. `data` should be a column-major `arma::mat`. The expansion operation is minimal, so `b` is not expanded any more than necessary.
- `b |= bound` expands `b` to fully include `bound`, where `bound` is another `HRectBound`. The expansion/union operation is minimal, so `b` is not expanded any more than necessary.
- `b & bound` returns a new `HRectBound` whose bounding hyperrectangle is the intersection of the bounding hyperrectangles of `b` and `bound`. If `b` and `bound` do not intersect, then the returned `HRectBound` will be empty.
- `b &= bound` is equivalent to `b = (b & bound)` (i.e. it performs an in-place intersection with `bound`).
Notes:

- When another bound is passed, it must have the same type as `b`; so, if a custom `DistanceType` and `ElemType` were specified, then `bound` must have type `HRectBound<DistanceType, ElemType>`.
- If a custom `ElemType` was specified, then any `data` argument should be a matrix with that `ElemType` (e.g. `arma::Mat<ElemType>`).
- Each function expects the other bound or dataset to have dimensionality that matches `b`.
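The union and intersection operations act independently in each dimension. This stdlib-only sketch uses hypothetical helpers (`EmptyBox()`, `GrowToPoints()`, `Union()`, `Intersection()`), not mlpack's operators. Union takes the per-dimension minimum of the lower ends and maximum of the upper ends. Intersection does the opposite, so disjoint boxes yield a dimension with `hi < lo`, i.e. an empty box. Starting from `[+inf, -inf]` makes the first point define the box.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical stand-in for one dimension of a box; hi < lo means empty.
struct Interval
{
  double lo, hi;
};
using ToyHRect = std::vector<Interval>;

// An empty box: [+inf, -inf] in every dimension, so any point will expand it.
ToyHRect EmptyBox(std::size_t dims)
{
  const double inf = std::numeric_limits<double>::infinity();
  return ToyHRect(dims, Interval{inf, -inf});
}

// Minimal expansion so the box contains every point (the idea behind b |= data).
void GrowToPoints(ToyHRect& b, const std::vector<std::vector<double>>& points)
{
  for (const std::vector<double>& p : points)
  {
    for (std::size_t d = 0; d < b.size(); ++d)
    {
      b[d].lo = std::min(b[d].lo, p[d]);
      b[d].hi = std::max(b[d].hi, p[d]);
    }
  }
}

// Smallest box containing both inputs (the idea behind b |= bound).
ToyHRect Union(const ToyHRect& a, const ToyHRect& b)
{
  ToyHRect result(a.size());
  for (std::size_t d = 0; d < a.size(); ++d)
    result[d] = Interval{std::min(a[d].lo, b[d].lo), std::max(a[d].hi, b[d].hi)};
  return result;
}

// Intersection (the idea behind b & bound); disjoint inputs yield hi < lo.
ToyHRect Intersection(const ToyHRect& a, const ToyHRect& b)
{
  ToyHRect result(a.size());
  for (std::size_t d = 0; d < a.size(); ++d)
    result[d] = Interval{std::max(a[d].lo, b[d].lo), std::min(a[d].hi, b[d].hi)};
  return result;
}
```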
#### Bounding distances to other objects

Once an `HRectBound` has been successfully created and set to the desired bounding hyperrectangle, there are a number of functions that can bound the distance between an `HRectBound` and other objects.
- `b.Contains(point)`
- `b.Contains(bound)`
  - Return a `bool` indicating whether or not `b` contains the given `point` (an `arma::vec`) or another `bound` (an `HRectBound`).
  - When passing another `bound`, `true` will be returned if `bound` even partially overlaps with `b`.
- `b.MinDistance(point)`
- `b.MinDistance(bound)`
  - Return a `double` whose value is the minimum possible distance between `b` and either a `point` (an `arma::vec`) or another `bound` (an `HRectBound`).
  - The minimum distance between `b` and another point or bound is the length of the shortest possible line that can connect the other point or bound to `b`.
  - If `point` or `bound` is contained in `b`, then the returned distance is 0.
- `b.MaxDistance(point)`
- `b.MaxDistance(bound)`
  - Return a `double` whose value is the maximum possible distance between `b` and either a `point` (an `arma::vec`) or another `bound` (an `HRectBound`).
  - The maximum distance between `b` and a given `point` is the furthest possible distance between `point` and any possible point falling within the bounding hyperrectangle of `b`.
  - The maximum distance between `b` and another `bound` is the furthest possible distance between any possible point falling within the bounding hyperrectangle of `b`, and any possible point falling within the bounding hyperrectangle of `bound`.
  - Note that this definition means that even if `b.Contains(point)` or `b.Contains(bound)` is `true`, the maximum distance may be greater than `0`.
- `b.RangeDistance(point)`
- `b.RangeDistance(bound)`
  - Compute the minimum and maximum distance between `b` and `point` or `bound`, returning the result as a `Range` object.
  - This is more efficient than calling `b.MinDistance()` and `b.MaxDistance()` separately.
- `b.Overlap(bound)`
  - Returns a `double` whose value is the volume of overlap of `b` and the given `bound`.
  - This is equivalent to `(b & bound).Volume()` (but more efficient!).
Note: if a custom `DistanceType` and `ElemType` were specified in the constructor, then all distances will be computed with respect to the specified `DistanceType`, and all return values will be either `ElemType` or `RangeType<ElemType>` (except for `Contains()`, which will still return a `bool`).
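The requirement that `DistanceType` be an `LMetric` is what makes a single-pass `RangeDistance()` possible. Both the minimum and maximum L_p distances decompose into per-dimension terms that can be accumulated together. The hedged sketch below covers finite `p >= 1` only (the L-infinity case would use a maximum instead of a sum). It uses hypothetical helper names and a hypothetical `Interval`/`ToyHRect` representation, not mlpack's implementation. It reproduces the Manhattan-distance case from the example usage below: for the box `[2, 3]^3` and the point `[1.5, 1.5, 4.0]`, the range is `[2, 5]`.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in for one dimension of a box.
struct Interval
{
  double lo, hi;
};
using ToyHRect = std::vector<Interval>;

// Minimum and maximum L_p distance (finite p >= 1) between a point and a box,
// accumulated in one pass over the dimensions.
std::pair<double, double> RangeDistance(const ToyHRect& b,
                                        const std::vector<double>& point,
                                        double p)
{
  double minSum = 0.0, maxSum = 0.0;
  for (std::size_t d = 0; d < b.size(); ++d)
  {
    const double gap = std::max({b[d].lo - point[d], point[d] - b[d].hi, 0.0});
    const double farthest = std::max(std::abs(point[d] - b[d].lo),
                                     std::abs(point[d] - b[d].hi));
    minSum += std::pow(gap, p);
    maxSum += std::pow(farthest, p);
  }
  return {std::pow(minSum, 1.0 / p), std::pow(maxSum, 1.0 / p)};
}

// Minimum Euclidean distance between two boxes: per-dimension interval gaps.
double MinDistance(const ToyHRect& a, const ToyHRect& b)
{
  double sum = 0.0;
  for (std::size_t d = 0; d < a.size(); ++d)
  {
    const double gap = std::max({a[d].lo - b[d].hi, b[d].lo - a[d].hi, 0.0});
    sum += gap * gap;
  }
  return std::sqrt(sum);
}

// Volume of the intersection of two boxes (what Overlap() reports).
double Overlap(const ToyHRect& a, const ToyHRect& b)
{
  double volume = 1.0;
  for (std::size_t d = 0; d < a.size(); ++d)
    volume *= std::max(std::min(a[d].hi, b[d].hi) - std::max(a[d].lo, b[d].lo),
                       0.0);
  return volume;
}
```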
#### Example usage

```cpp
// Assumes mlpack 4 (which requires C++17); mlpack.hpp also includes Armadillo.
#include <mlpack.hpp>

int main()
{
  // Create a bound that is the unit cube in 3 dimensions, by setting the values
  // manually. The bounding range for all three dimensions is [0.0, 1.0].
  mlpack::HRectBound b(3);
  b[0] = mlpack::Range(0.0, 1.0);
  b[1].Lo() = 0.0;
  b[1].Hi() = 1.0;
  b[2] = b[1];

  // The minimum width is not correct if we modify bound dimensions manually, so
  // we have to recompute it.
  b.RecomputeMinWidth();

  std::cout << "Bounding box created manually:" << std::endl;
  for (size_t i = 0; i < 3; ++i)
  {
    std::cout << " - Dimension " << i << ": [" << b[i].Lo() << ", "
        << b[i].Hi() << "]." << std::endl;
  }

  // Create a small dataset of 5 points, and then create a bound that contains
  // all of those points.
  arma::mat dataset(3, 5);
  dataset.col(0) = arma::vec("2.0 2.0 2.0");
  dataset.col(1) = arma::vec("2.5 2.5 2.5");
  dataset.col(2) = arma::vec("3.0 2.0 3.0");
  dataset.col(3) = arma::vec("2.0 3.0 2.0");
  dataset.col(4) = arma::vec("3.0 3.0 3.0");

  // The bounding box of `dataset` is [2.0, 3.0] in all three dimensions.
  mlpack::HRectBound b2(3);
  b2 |= dataset;

  std::cout << "Bounding box created on dataset:" << std::endl;
  for (size_t i = 0; i < 3; ++i)
  {
    std::cout << " - Dimension " << i << ": [" << b2[i].Lo() << ", "
        << b2[i].Hi() << "]." << std::endl;
  }

  // Create a new bound that is the union of the two bounds.
  mlpack::HRectBound b3 = b;
  b3 |= b2;

  std::cout << "Union-ed bounding box:" << std::endl;
  for (size_t i = 0; i < 3; ++i)
  {
    std::cout << " - Dimension " << i << ": [" << b3[i].Lo() << ", "
        << b3[i].Hi() << "]." << std::endl;
  }

  // Create a new bound that is the intersection of the two bounds (this will be
  // empty!).
  mlpack::HRectBound b4 = (b & b2);

  std::cout << "Intersection bounding box:" << std::endl;
  for (size_t i = 0; i < 3; ++i)
  {
    std::cout << " - Dimension " << i << ": [" << b4[i].Lo() << ", "
        << b4[i].Hi() << "].";
    if (b4[i].Hi() < b4[i].Lo())
      std::cout << " (Empty!)";
    std::cout << std::endl;
  }

  // Print statistics about the union bound and intersection bound.
  std::cout << "Union-ed bound details:" << std::endl;
  std::cout << " - Dimensionality: " << b3.Dim() << "." << std::endl;
  std::cout << " - Minimum width: " << b3.MinWidth() << "." << std::endl;
  std::cout << " - Diameter: " << b3.Diameter() << "." << std::endl;
  std::cout << " - Volume: " << b3.Volume() << "." << std::endl;
  arma::vec center;
  b3.Center(center);
  std::cout << " - Center: " << center.t();
  std::cout << std::endl;

  std::cout << "Intersection bound details:" << std::endl;
  std::cout << " - Dimensionality: " << b4.Dim() << "." << std::endl;
  std::cout << " - Minimum width: " << b4.MinWidth() << "." << std::endl;
  std::cout << " - Diameter: " << b4.Diameter() << "." << std::endl;
  std::cout << " - Volume: " << b4.Volume() << "." << std::endl;
  b4.Center(center);
  std::cout << " - Center: " << center.t();
  std::cout << std::endl;

  // Compute the minimum distance between a point inside the unit cube and the
  // unit cube bound.
  const double d1 = b.MinDistance(arma::vec("0.5 0.5 0.5"));
  std::cout << "Minimum distance between unit cube bound and [0.5, 0.5, 0.5]: "
      << d1 << "." << std::endl;

  // Use Contains(). In this case, the 'else' branch will be taken.
  if (b.Contains(arma::vec("1.5 1.5 1.5")))
    std::cout << "Unit cube bound contains [1.5, 1.5, 1.5]." << std::endl;
  else
    std::cout << "Unit cube does not contain [1.5, 1.5, 1.5]." << std::endl;
  std::cout << std::endl;

  // Compute the maximum distance between a point inside the unit cube and the
  // unit cube bound.
  const double d2 = b.MaxDistance(arma::vec("0.5 0.5 0.5"));
  std::cout << "Maximum distance between unit cube bound and [0.5, 0.5, 0.5]: "
      << d2 << "." << std::endl;

  // Compute the minimum and maximum distances between the unit cube bound and
  // the bound built on data points.
  const mlpack::Range r = b.RangeDistance(b2);
  std::cout << "Distances between unit cube bound and dataset bound: ["
      << r.Lo() << ", " << r.Hi() << "]." << std::endl;

  // Create a random bound.
  mlpack::HRectBound br(3);
  for (size_t i = 0; i < 3; ++i)
    br[i] = mlpack::Range(mlpack::Random(), mlpack::Random() + 1);

  std::cout << "Randomly created bound:" << std::endl;
  for (size_t i = 0; i < 3; ++i)
  {
    std::cout << " - Dimension " << i << ": [" << br[i].Lo() << ", "
        << br[i].Hi() << "]." << std::endl;
  }

  // Compute the overlap of various bounds.
  const double o1 = b.Overlap(b2); // This will be 0: the bounds don't overlap.
  const double o2 = b.Overlap(b3); // This will be 1; b3 fully overlaps b, and
                                   // the volume of b is 1 (it is the unit cube).
  const double o3 = br.Overlap(b); // br and b do not fully overlap.

  std::cout << "Overlap of unit cube and data bound: " << o1 << "." << std::endl;
  std::cout << "Overlap of unit cube and union bound: " << o2 << "." << std::endl;
  std::cout << "Overlap of unit cube and random bound: " << o3 << "."
      << std::endl;

  // Create a bound using the Manhattan (L1) distance and compute the minimum and
  // maximum distance to a point.
  mlpack::HRectBound<mlpack::ManhattanDistance> mb(3);
  mb |= dataset; // This will set the bound to [2.0, 3.0] in every dimension.
  const mlpack::Range r2 = mb.RangeDistance(arma::vec("1.5 1.5 4.0"));
  std::cout << "Distance between Manhattan distance HRectBound and "
      << "[1.5, 1.5, 4.0]: [" << r2.Lo() << ", " << r2.Hi() << "]." << std::endl;

  // Create a bound using the Chebyshev (L-inf) distance, using random 32-bit
  // floating point elements, and compute the minimum and maximum distance to a
  // point. The bound's dimensionality (3) must match the data.
  arma::fmat floatData(3, 25, arma::fill::randu);
  mlpack::HRectBound<mlpack::ChebyshevDistance, float> cb(3);
  cb |= floatData;

  // Note the use of arma::fvec to represent a point, since ElemType is float.
  const mlpack::RangeType<float> r3 =
      cb.RangeDistance(arma::fvec("1.5 1.5 4.0"));
  std::cout << "Distance between Chebyshev distance HRectBound and "
      << "[1.5, 1.5, 4.0]: [" << r3.Lo() << ", " << r3.Hi() << "]." << std::endl;

  return 0;
}
```
### `BallBound`

The `BallBound` class represents a ball with a center and a radius. A `BallBound` can be used to perform a variety of distance-based bounding tasks. `BallBound` is used directly by the `BallTree` class.
#### Constructors

`BallBound` allows configurable behavior via its three template parameters:

`BallBound<DistanceType, ElemType, VecType>`

The three template parameters are described below:

- `DistanceType`: specifies the distance metric to use for distance calculations. Defaults to `EuclideanDistance`.
- `ElemType`: specifies the element type of the bound. By default this is `double`, but it can also be `float`. Generally this should be a floating-point type.
- `VecType`: specifies the vector type to use to store the center of the ball bound. By default this is `arma::Col<ElemType>`. The element type of the given `VecType` should be the same as `ElemType`.
Different constructor forms can be used to specify different template parameters (and thus different bound behavior).
- `b = BallBound(dimensionality)`
  - Construct a `BallBound` with the given `dimensionality`.
  - The bound will be empty with an invalid center (i.e., `b` will not contain any points at all).
  - The bound will use the Euclidean distance for distance computation, and will expect data to have elements with type `double`.
- `b = BallBound<DistanceType, ElemType, VecType>(dimensionality)`
  - Construct a `BallBound` with the given `dimensionality` that will use the given `DistanceType`, `ElemType`, and `VecType` parameters.
  - Note that it is not required to specify all three template parameters.
  - See above for details on the meaning of each template parameter.
  - The bound will be empty with an invalid center (i.e., `b` will not contain any points at all).
Note: these constructors provide an empty bound; be sure to grow the bound or directly modify the bound before using it!
- `b = BallBound(radius, center)`
  - Construct a `BallBound` with the given `radius` and `center`.
  - `radius` should have type `double`.
  - `center` should have vector type `arma::vec`.
- `b = BallBound<DistanceType, ElemType, VecType>(radius, center)`
  - Construct a `BallBound` with the given `radius` and `center`.
  - `radius` should have type `ElemType`.
  - `center` should have type `VecType`.
  - Note that it is not required to specify all three template parameters.
  - See above for details on the meaning of each template parameter.
Accessing and modifying properties of the bound
The properties of the BallBound can be directly accessed and modified.
- b.Dim() will return a size_t indicating the dimensionality of the bound.
- b.Center() returns an arma::vec& containing the center of the ball bound. Its elements can be directly modified.
- b.Radius() will return a double that is the radius of the ball.
  - b.Radius() = r will set the radius of the ball to r.
- b[dim] will return a Range object representing the extents of the bound in dimension dim.
  - The range is defined as [b.Center()[dim] - b.Radius(), b.Center()[dim] + b.Radius()].
  - Note: unlike HRectBound, it is not possible to set individual bound dimensions with b[dim]. Use b.Center() and b.Radius() instead.
- b.Diameter() returns the diameter of the ball. This is always equal to 2 * b.Radius().
- b.MinWidth() returns the minimum width of the bound in any dimension as a double. This is always equal to b.Diameter().
- b.Distance() returns either a EuclideanDistance distance metric object, or a DistanceType if a custom DistanceType has been specified in the constructor.
- b.Center(center) will store the center of the BallBound in the vector center. center should be of type arma::vec.
- A BallBound can be serialized with data::Save() and data::Load().
Note: if a custom ElemType and/or VecType were specified in the constructor, then:
- b.Radius(), b.MinWidth(), b.Volume(), and b.Diameter() will return ElemType;
- b[dim] will return a RangeType<ElemType>;
- b.Center() will return a VecType&; and
- b.Center(center) expects center to be of type VecType.
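The per-dimension extents, diameter, and minimum width listed above all follow from the center and radius alone. As a rough illustration, the standalone C++ sketch below recomputes them with plain standard-library types. SimpleBall, Extent(), Diameter(), and MinWidth() are hypothetical names used only for this sketch. They are not mlpack's API, which returns a Range object from b[dim] rather than a std::pair.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in for a ball bound: a center point and a radius.
struct SimpleBall
{
  std::vector<double> center;
  double radius;
};

// Extents of the ball in dimension dim, returned as a (lo, hi) pair:
// [center[dim] - radius, center[dim] + radius].
std::pair<double, double> Extent(const SimpleBall& b, const std::size_t dim)
{
  assert(dim < b.center.size());
  return { b.center[dim] - b.radius, b.center[dim] + b.radius };
}

// A ball has the same width (its diameter) in every dimension, so the
// minimum width is simply the diameter.
double Diameter(const SimpleBall& b) { return 2.0 * b.radius; }
double MinWidth(const SimpleBall& b) { return Diameter(b); }
```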
Growing the bound
The BallBound uses the logical |= operator to grow the bound to include points or other BallBounds.
- b |= data expands b to include all of the data points in data. data should be a column-major arma::mat. Each expansion step is minimal: the ball grows only as much as needed to include the next point. However, the final ball is not guaranteed to be the smallest possible enclosing ball.
  - The bound is grown using Jack Ritter’s bounding sphere algorithm, which may move the center of the bound as it iteratively adds points to the bound.
  - If the bound is empty, the center is initialized to the first point of data.
  - If the bound is not empty, then data is expected to have dimensionality that matches b.Dim().
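The incremental growth step described above can be sketched as follows. This is a minimal illustration that assumes the Euclidean distance and std::vector points. GrowableBall, EuclideanDistance(), and Grow() are hypothetical names, and mlpack's own implementation may differ in details. Each point that lies outside the current ball pulls the center toward it just far enough that the new ball contains both the old ball and the point.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Point = std::vector<double>;

// Hypothetical growable ball; a negative radius marks an empty bound.
struct GrowableBall
{
  Point center;
  double radius = -1.0;
};

double EuclideanDistance(const Point& a, const Point& b)
{
  assert(a.size() == b.size());
  double sum = 0.0;
  for (std::size_t i = 0; i < a.size(); ++i)
    sum += (a[i] - b[i]) * (a[i] - b[i]);
  return std::sqrt(sum);
}

// Grow the ball so that it contains every point in points.
void Grow(GrowableBall& ball, const std::vector<Point>& points)
{
  if (points.empty())
    return;

  if (ball.radius < 0.0)
  {
    // Empty bound: start from the first point with zero radius.
    ball.center = points[0];
    ball.radius = 0.0;
  }

  for (const Point& p : points)
  {
    const double d = EuclideanDistance(ball.center, p);
    if (d > ball.radius)
    {
      // Move the center toward p by (d - r) / 2 and set the radius to
      // (d + r) / 2: the smallest ball containing the old ball and p.
      const double step = (d - ball.radius) / (2.0 * d);
      for (std::size_t i = 0; i < p.size(); ++i)
        ball.center[i] += step * (p[i] - ball.center[i]);
      ball.radius = 0.5 * (d + ball.radius);
    }
  }
}
```

Because the result depends on the order in which points are visited, the final ball is generally somewhat larger than the true smallest enclosing ball.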
Bounding distances to other objects
Once a BallBound has been successfully created and set to the desired bounding ball, there are a number of functions that can bound the distance between a BallBound and other objects.
b.Contains(point)
- Return a bool indicating whether or not b contains the given point (an arma::vec).
b.MinDistance(point)
b.MinDistance(bound)
- Return a double whose value is the minimum possible distance between b and either a point (an arma::vec) or another bound (a BallBound).
  - The minimum distance between b and another point is the distance between the point and b’s center minus b’s radius.
  - The minimum distance between b and another bound is the distance between the centers minus the radii of the bounds.
  - If point is contained in b, or if bound overlaps b, then the returned distance is 0.
b.MaxDistance(point)
b.MaxDistance(bound)
- Return a double whose value is the maximum possible distance between b and either a point (an arma::vec) or another bound (a BallBound).
  - The maximum distance between b and a given point is the distance between the point and b’s center plus b’s radius.
  - The maximum distance between b and another bound is the distance between the centers plus the radii of the bounds.
  - Note that this definition means that even if b.Contains(point) is true, or if b overlaps bound, the maximum distance may be greater than 0.
b.RangeDistance(point)
b.RangeDistance(bound)
- Compute the minimum and maximum distance between b and point or bound, returning the result as a Range object.
  - This is more efficient than calling b.MinDistance() and b.MaxDistance().
Note: if a custom DistanceType, ElemType, or VecType were specified in the constructor, then:
- all distances will be computed with respect to the specified DistanceType;
- all point arguments should have type VecType; and
- all return values will either be ElemType or RangeType<ElemType> (except for Contains(), which will still return a bool).
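The bounding rules above can be written down directly. The following sketch assumes the Euclidean distance and uses the hypothetical types Ball and Interval in place of mlpack's BallBound and Range. It also suggests why RangeDistance() can be cheaper than separate MinDistance() and MaxDistance() calls: the distance to the center is computed once and reused for both ends of the interval.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Point = std::vector<double>;

// Hypothetical stand-ins for a Euclidean ball bound and a [lo, hi] range.
struct Ball
{
  Point center;
  double radius;
};

struct Interval
{
  double lo, hi;
};

double Euclidean(const Point& a, const Point& b)
{
  assert(a.size() == b.size());
  double sum = 0.0;
  for (std::size_t i = 0; i < a.size(); ++i)
    sum += (a[i] - b[i]) * (a[i] - b[i]);
  return std::sqrt(sum);
}

bool Contains(const Ball& b, const Point& p)
{
  return Euclidean(b.center, p) <= b.radius;
}

// Point versus ball: distance to the center, minus or plus the radius. The
// minimum is clamped at 0 for points inside the ball.
Interval RangeDistance(const Ball& b, const Point& p)
{
  const double d = Euclidean(b.center, p);
  return { std::max(0.0, d - b.radius), d + b.radius };
}

// Ball versus ball: distance between centers, minus or plus both radii. The
// minimum is clamped at 0 for overlapping balls.
Interval RangeDistance(const Ball& a, const Ball& b)
{
  const double d = Euclidean(a.center, b.center);
  return { std::max(0.0, d - a.radius - b.radius), d + a.radius + b.radius };
}
```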
Example usage
// Create a bound that is the unit ball in 3 dimensions, by setting the center
// and radius in the constructor.
mlpack::BallBound b(1.0, arma::vec(3, arma::fill::zeros));
std::cout << "Bounding ball created manually:" << std::endl;
std::cout << " - Center: " << b.Center().t();
std::cout << " - Radius: " << b.Radius() << "." << std::endl;
for (size_t i = 0; i < 3; ++i)
{
std::cout << " - Dimension " << i << " range: [" << b[i].Lo() << ", "
<< b[i].Hi() << "]." << std::endl;
}
// Create a small dataset of 5 points, and then create a bound that contains all
// of those points.
arma::mat dataset(3, 5);
dataset.col(0) = arma::vec("2.0 2.0 2.0");
dataset.col(1) = arma::vec("2.5 2.5 2.5");
dataset.col(2) = arma::vec("3.0 2.0 3.0");
dataset.col(3) = arma::vec("2.0 3.0 2.0");
dataset.col(4) = arma::vec("3.0 3.0 3.0");
// The bounding ball will be computed using Jack Ritter's algorithm.
mlpack::BallBound b2(3);
b2 |= dataset;
std::cout << "Bounding ball created on dataset:" << std::endl;
std::cout << " - Center: " << b2.Center().t();
std::cout << " - Radius: " << b2.Radius() << "." << std::endl;
for (size_t i = 0; i < 3; ++i)
{
std::cout << " - Dimension " << i << ": [" << b2[i].Lo() << ", " << b2[i].Hi()
<< "]." << std::endl;
}
// Compute the minimum distance between a point inside the unit ball and the
// unit ball bound.
const double d1 = b.MinDistance(arma::vec("0.5 0.5 0.5"));
std::cout << "Minimum distance between unit ball bound and [0.5, 0.5, 0.5]: "
<< d1 << "." << std::endl;
// Use Contains(). In this case, the 'else' will be taken.
if (b.Contains(arma::vec("1.5 1.5 1.5")))
std::cout << "Unit ball bound contains [1.5, 1.5, 1.5]." << std::endl;
else
std::cout << "Unit ball does not contain [1.5, 1.5, 1.5]." << std::endl;
std::cout << std::endl;
// Compute the maximum distance between a point inside the unit ball and the
// unit ball bound.
const double d2 = b.MaxDistance(arma::vec("0.5 0.5 0.5"));
std::cout << "Maximum distance between unit ball bound and [0.5, 0.5, 0.5]: "
<< d2 << "." << std::endl;
// Compute the minimum and maximum distances between the unit ball bound and the
// bound built on data points.
const mlpack::Range r = b.RangeDistance(b2);
std::cout << "Distances between unit ball bound and dataset bound: [" << r.Lo()
<< ", " << r.Hi() << "]." << std::endl;
// Create a random bound with radius between 1 and 2 and random center.
mlpack::BallBound br(3);
br.Radius() = 1.0 + mlpack::Random();
br.Center() = arma::randu<arma::vec>(3);
std::cout << "Randomly created bound:" << std::endl;
std::cout << " - Center: " << br.Center().t();
std::cout << " - Radius: " << br.Radius() << "." << std::endl;
for (size_t i = 0; i < 3; ++i)
{
std::cout << " - Dimension " << i << ": [" << br[i].Lo() << ", " << br[i].Hi()
<< "]." << std::endl;
}
// Create a bound using the Manhattan (L1) distance and compute the minimum and
// maximum distance to a point.
mlpack::BallBound<mlpack::ManhattanDistance> mb(3);
mb |= dataset; // Expand the bound to include the points in the dataset.
const mlpack::Range r2 = mb.RangeDistance(arma::vec("1.5 1.5 4.0"));
std::cout << "Distance between Manhattan distance BallBound and "
<< "[1.5, 1.5, 4.0]: [" << r2.Lo() << ", " << r2.Hi() << "]." << std::endl;
// Create a bound using the Chebyshev (L-inf) distance, using random 32-bit
// floating point elements, and compute the minimum and maximum distance to a
// point.
arma::fmat floatData(3, 25, arma::fill::randu);
mlpack::BallBound<mlpack::ChebyshevDistance, float> cb;
cb |= floatData; // Expand the bound to include the points in the dataset.
// Note the use of arma::fvec to represent a point, since ElemType is float.
const mlpack::RangeType<float> r3 = cb.RangeDistance(arma::fvec("1.5 1.5 4.0"));
std::cout << "Distance between Chebyshev distance BallBound and "
<< "[1.5, 1.5, 4.0]: [" << r3.Lo() << ", " << r3.Hi() << "]." << std::endl;
🔗 HollowBallBound
The HollowBallBound class represents a bounding shape that is an arbitrary-dimensional ball bound with another smaller ball subtracted from its inside. A HollowBallBound consists of a center point, an outer radius, and a secondary center point and inner radius. An example HollowBallBound is shown below in two dimensions; the shaded area represents the region held within the bound.
HollowBallBound is used directly by the VPTree class.
Constructors
HollowBallBound allows configurable behavior via its two template parameters:
HollowBallBound<DistanceType, ElemType>
Different constructor forms can be used to specify different template parameters (and thus different bound behavior).
b = HollowBallBound(dimensionality)
- Construct a HollowBallBound with the given dimensionality.
  - The bound will be empty with invalid centers and radii (i.e., b will not contain any points at all).
  - The bound will use the Euclidean distance for distance computation, and will expect data to have elements with type double.
b = HollowBallBound<DistanceType, ElemType>(dimensionality)
- Construct a HollowBallBound with the given dimensionality that will use the given DistanceType class to compute distances, and expect data to have elements with type ElemType.
  - ElemType should generally be double or float.
Note: these constructors provide an empty bound; be sure to grow the bound or directly modify the bound before using it!
b = HollowBallBound(innerRadius, outerRadius, center)
- Construct a HollowBallBound with the given innerRadius for the inner ball, outerRadius for the outer ball, and center.
  - Both the inner and outer balls are centered at center.
  - innerRadius and outerRadius should have type double.
  - center should have type arma::vec.
  - The bound will use the Euclidean distance for distance computation, and will expect data to have elements with type double.
b = HollowBallBound<DistanceType, ElemType>(innerRadius, outerRadius, center)
- Construct a HollowBallBound with the given innerRadius for the inner ball, outerRadius for the outer ball, and center.
  - Both the inner and outer balls are centered at center.
  - innerRadius and outerRadius should have type ElemType.
  - center should be a vector with element type ElemType (e.g. arma::Col<ElemType>).
  - The bound will use the given DistanceType class to compute distances, and expect data to have elements with type ElemType.
Accessing and modifying properties of the bound
The properties of a HollowBallBound can be directly accessed and modified.
- b.Dim() will return a size_t indicating the dimensionality of the bound.
- b.Center() returns an arma::vec& containing the center of the outer ball. Its elements can be directly modified.
- b.HollowCenter() returns an arma::vec& containing the center of the inner ball. Its elements can be directly modified.
  - It is possible that b.HollowCenter() is outside of the outer ball!
- b.OuterRadius() will return a double that is the radius of the outer ball.
  - b.OuterRadius() = r will set the radius of the outer ball to r.
- b.InnerRadius() will return a double that is the radius of the inner ball.
  - b.InnerRadius() = r will set the radius of the inner ball to r.
  - It is possible that b.InnerRadius() > b.OuterRadius(); this implies that the hollow center is outside the outer ball (otherwise the bound is empty).
- b[dim] will return a Range object representing the extents of the bound in dimension dim.
  - The range is defined as [b.Center()[dim] - b.OuterRadius(), b.Center()[dim] + b.OuterRadius()].
  - Note: this returns the maximum extents of the bound and does not consider the inner (hollow) ball.
- b.Diameter() returns the diameter of the outer ball. This is always equal to 2 * b.OuterRadius().
- b.MinWidth() returns the minimum width of the bound in any dimension as a double. This is always equal to b.Diameter().
- b.Distance() returns either a EuclideanDistance distance metric object, or a DistanceType if a custom DistanceType has been specified in the constructor.
- b.Center(center) will store the center of the outer ball of the HollowBallBound in the vector center. center should be of type arma::vec.
- A HollowBallBound can be serialized with data::Save() and data::Load().
Note: if a custom ElemType was specified in the constructor, then:
- b[dim] will return a RangeType<ElemType>;
- b.OuterRadius(), b.InnerRadius(), b.MinWidth(), and b.Diameter() will return ElemType;
- b.Center() and b.HollowCenter() will return arma::Col<ElemType>&; and
- b.Center(center) expects center to be of type arma::Col<ElemType>.
Growing the bound
The HollowBallBound uses the logical |= operator to grow the bound to include points or other bounds.
- b |= data expands b so the outer ball includes all of the data points in data, shrinking the inner ball as necessary. data should be a column-major arma::mat. Each expansion step is minimal: the outer ball grows only as much as needed to include the next point. However, the final outer ball is not guaranteed to be the smallest possible enclosing ball.
  - The bound is grown using Jack Ritter’s bounding sphere algorithm, which may move the center of the bound as it iteratively adds points to the bound. (The hollow center is not moved.)
  - If the bound is empty, the centers are initialized to the first point of data.
  - If the bound is not empty, then data is expected to have dimensionality that matches b.Dim().
- b |= bound expands b to include all of the volume included in bound. The center points will not be modified.
  - The outer ball’s radius will be expanded so that the outer ball contains the outer balls of both b and bound.
  - The inner (hollow) ball’s radius will be shrunk so that the inner ball fits within the intersection of the inner balls of b and bound. (This may result in b.InnerRadius() being 0.)
Notes:
- The growth operation does not grow the inner (hollow) ball. Properties related to the inner ball should be set manually with b.HollowCenter() and b.InnerRadius().
- If a custom ElemType was specified, then any data argument should be a matrix with that ElemType (e.g. arma::Mat<ElemType>).
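The radius updates for b |= bound can be sketched as follows, assuming the Euclidean distance and two non-empty bounds. HollowBall and GrowToInclude() are hypothetical names. This is one conservative way to satisfy the rules above and is not necessarily mlpack's exact code. The outer radius grows until the outer ball around b's center encloses the outer ball of bound. The inner radius shrinks until the hole around b's hollow center fits inside both holes.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Point = std::vector<double>;

// Hypothetical stand-in for a non-empty hollow ball bound.
struct HollowBall
{
  Point center;       // center of the outer ball
  Point hollowCenter; // center of the inner (hollow) ball
  double outerRadius;
  double innerRadius;
};

double Euclidean(const Point& a, const Point& b)
{
  assert(a.size() == b.size());
  double sum = 0.0;
  for (std::size_t i = 0; i < a.size(); ++i)
    sum += (a[i] - b[i]) * (a[i] - b[i]);
  return std::sqrt(sum);
}

// Grow a so that it covers everything b covers, without moving either center.
void GrowToInclude(HollowBall& a, const HollowBall& b)
{
  // The outer ball around a.center must enclose b's whole outer ball.
  a.outerRadius = std::max(a.outerRadius,
      Euclidean(a.center, b.center) + b.outerRadius);

  // The hole around a.hollowCenter must also fit inside b's hole. If
  // a.hollowCenter lies outside b's hole, no hole fits and the radius is 0.
  const double fitsInB = std::max(0.0,
      b.innerRadius - Euclidean(a.hollowCenter, b.hollowCenter));
  a.innerRadius = std::min(a.innerRadius, fitsInB);
}
```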
Bounding distances to other objects
Once a HollowBallBound has been successfully created and set to the desired bounding balls, there are a number of functions that can bound the distance between a HollowBallBound and other objects.
b.Contains(point)
b.Contains(bound)
- Return a bool indicating whether or not b contains the given point (an arma::vec) or another bound (a HollowBallBound).
  - When passing another bound, true will be returned if bound even partially overlaps with b.
b.MinDistance(point)
b.MinDistance(bound)
- Return a double whose value is the minimum possible distance between b and either a point (an arma::vec) or another bound (a HollowBallBound).
  - The minimum distance between b and another point or bound is the length of the shortest possible line that can connect the other point or bound to b.
  - If point or bound is contained in b, then the returned distance is 0.
b.MaxDistance(point)
b.MaxDistance(bound)
- Return a double whose value is the maximum possible distance between b and either a point (an arma::vec) or another bound (a HollowBallBound).
  - The maximum distance between b and a given point is the furthest possible distance between point and any possible point falling within the outer ball of b.
  - The maximum distance between b and another bound is the furthest possible distance between any possible point falling within the outer ball of b, and any possible point falling within the outer ball of bound.
  - Note that this definition means that even if b.Contains(point) or b.Contains(bound) is true, the maximum distance may be greater than 0.
b.RangeDistance(point)
b.RangeDistance(bound)
- Compute the minimum and maximum distance between b and point or bound, returning the result as a Range object.
  - This is more efficient than calling b.MinDistance() and b.MaxDistance().
Note: if a custom DistanceType and ElemType were specified in the constructor, then all distances will be computed with respect to the specified DistanceType and all return values will either be ElemType or RangeType<ElemType> (except for Contains(), which will still return a bool).
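For a single point, the minimum and maximum distances described above can be bounded using only distances to the two centers. The sketch below assumes the Euclidean distance and a hypothetical HollowBall struct, not mlpack's class. By the triangle inequality, MinDistance() here is always a valid lower bound, and it is exact when the two centers coincide and the inner radius does not exceed the outer radius. The hole never reduces the maximum distance, so MaxDistance() uses only the outer ball.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Point = std::vector<double>;

// Hypothetical stand-in for a hollow ball bound.
struct HollowBall
{
  Point center;       // center of the outer ball
  Point hollowCenter; // center of the inner (hollow) ball
  double outerRadius;
  double innerRadius;
};

double Euclidean(const Point& a, const Point& b)
{
  assert(a.size() == b.size());
  double sum = 0.0;
  for (std::size_t i = 0; i < a.size(); ++i)
    sum += (a[i] - b[i]) * (a[i] - b[i]);
  return std::sqrt(sum);
}

// Positive outsideGap: p is outside the outer ball.
// Positive holeGap: p is inside the hole.
// Otherwise p lies in the shell and the distance is 0.
double MinDistance(const HollowBall& b, const Point& p)
{
  const double outsideGap = Euclidean(p, b.center) - b.outerRadius;
  const double holeGap = b.innerRadius - Euclidean(p, b.hollowCenter);
  return std::max({ 0.0, outsideGap, holeGap });
}

// Farthest point of the outer ball from p.
double MaxDistance(const HollowBall& b, const Point& p)
{
  return Euclidean(p, b.center) + b.outerRadius;
}

// Inside the outer ball but not inside the hole.
bool Contains(const HollowBall& b, const Point& p)
{
  return Euclidean(p, b.center) <= b.outerRadius &&
      Euclidean(p, b.hollowCenter) >= b.innerRadius;
}
```

Consider the hollow unit ball from the example below, which has an inner radius of 0.5. The origin lies in the hole, so MinDistance() returns 0.5. The point [0.9, 0.9, 0.9] lies just outside the outer ball, at a distance of about 0.559.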
Example usage
// Create a hollow ball bound in 3 dimensions whose outer ball is the unit ball
// and whose inner ball is the ball with radius 0.5 centered at the origin.
// The bounding range for all three dimensions is [-1.0, 1.0].
mlpack::HollowBallBound b(0.5, 1.0, arma::vec(3));
std::cout << "Hollow unit ball bound created manually:" << std::endl;
std::cout << " - Center: " << b.Center().t();
std::cout << " - Outer radius: " << b.OuterRadius() << "." << std::endl;
std::cout << " - Hollow center: " << b.HollowCenter().t();
std::cout << " - Inner radius: " << b.InnerRadius() << "." << std::endl;
for (size_t i = 0; i < 3; ++i)
{
std::cout << " - Dimension " << i << " extents: [" << b[i].Lo() << ", "
<< b[i].Hi() << "]." << std::endl;
}
std::cout << std::endl;
// Create a small dataset of 5 points.
arma::mat dataset(3, 5);
dataset.col(0) = arma::vec("2.0 2.0 2.0");
dataset.col(1) = arma::vec("2.5 2.5 2.5");
dataset.col(2) = arma::vec("3.0 2.0 3.0");
dataset.col(3) = arma::vec("2.0 3.0 2.0");
dataset.col(4) = arma::vec("3.0 3.0 3.0");
// If we simply build a HollowBallBound to enclose those points, the hollow part
// of the ball is unmodified and remains empty.
mlpack::HollowBallBound b2(3);
b2 |= dataset;
std::cout << "Hollow ball bound on points with only `operator|=()`:"
<< std::endl;
std::cout << " - Center: " << b2.Center().t();
std::cout << " - Outer radius: " << b2.OuterRadius() << "." << std::endl;
std::cout << " - Hollow center: " << b2.HollowCenter().t();
std::cout << " - Inner radius: " << b2.InnerRadius() << "." << std::endl;
std::cout << std::endl;
// On the other hand, if we initialize a HollowBallBound to a non-empty bound,
// then `operator|=()` will shrink the hollow ball as necessary.
//
// We initialize this ball bound to a "slice" with radii [3.6, 3.7].
mlpack::HollowBallBound b3(3.6, 3.7, arma::vec(3));
b3 |= dataset;
std::cout << "Hollow ball bound on points with pre-initialization and "
<< "`operator|=()`:" << std::endl;
std::cout << " - Center: " << b3.Center().t();
std::cout << " - Outer radius: " << b3.OuterRadius() << "." << std::endl;
std::cout << " - Hollow center: " << b3.HollowCenter().t();
std::cout << " - Inner radius: " << b3.InnerRadius() << "." << std::endl;
std::cout << std::endl;
// Manually create a hollow ball bound whose hollow center is different than the
// outer ball's center.
mlpack::HollowBallBound b4(3);
b4.OuterRadius() = 3.0;
b4.InnerRadius() = 1.5;
b4.Center() = arma::vec(3);
b4.HollowCenter() = arma::vec("1.0 1.0 1.0");
// Compute the minimum distance between the hollow unit ball bound and a point
// just outside its outer ball.
const double d1 = b.MinDistance(arma::vec("0.9 0.9 0.9"));
std::cout << "Minimum distance between hollow unit ball bound and [0.9, 0.9, "
<< "0.9]: " << d1 << "." << std::endl;
// Compute the minimum distance between a point inside the hollow unit ball's
// inner ball (so the point is not contained in the bound---it is within the
// hollow section).
const double d2 = b.MinDistance(arma::vec("0.0 0.0 0.0"));
std::cout << "Minimum distance between hollow unit ball bound and [0.0, 0.0, "
<< "0.0]: " << d2 << "." << std::endl;
std::cout << std::endl;
// Use Contains(). In this case, the 'else' will be taken.
if (b.Contains(arma::vec("1.5 1.5 1.5")))
{
std::cout << "Hollow unit ball bound contains [1.5, 1.5, 1.5]." << std::endl;
}
else
{
std::cout << "Hollow unit ball bound does not contain [1.5, 1.5, 1.5]."
<< std::endl;
}
std::cout << std::endl;
// Compute the maximum distance between a point inside the unit ball and the
// manually created bound b4 (whose hollow center is offset from its center).
const double d3 = b4.MaxDistance(arma::vec("0.1 0.1 0.1"));
std::cout << "Maximum distance between manually created hollow ball bound and "
<< "[0.1, 0.1, 0.1]: " << d3 << "." << std::endl;
// Compute the minimum and maximum distances between the hollow unit ball bound
// and the bound built on data points.
const mlpack::Range r = b.RangeDistance(b3);
std::cout << "Distances between hollow unit ball bound and second hollow "
<< "dataset bound: [" << r.Lo() << ", " << r.Hi() << "]." << std::endl;
// Create a bound using the Manhattan (L1) distance and compute the minimum and
// maximum distance to a point.
mlpack::HollowBallBound<mlpack::ManhattanDistance> mb(2.0, 5.0, arma::vec(3));
const mlpack::Range r2 = mb.RangeDistance(arma::vec("1.5 1.5 4.0"));
std::cout << "Distance between Manhattan distance HollowBallBound and "
<< "[1.5, 1.5, 4.0]: [" << r2.Lo() << ", " << r2.Hi() << "]." << std::endl;
// Create a bound using the Chebyshev (L-inf) distance, using random 32-bit
// floating point elements, and compute the minimum and maximum distance to a
// point.
arma::fmat floatData(3, 25, arma::fill::randu);
mlpack::HollowBallBound<mlpack::ChebyshevDistance, float> cb;
cb |= floatData;
// Note the use of arma::fvec to represent a point, since ElemType is float.
const mlpack::RangeType<float> r3 = cb.RangeDistance(arma::fvec("1.5 1.5 4.0"));
std::cout << "Distance between Chebyshev distance HollowBallBound and "
<< "[1.5, 1.5, 4.0]: [" << r3.Lo() << ", " << r3.Hi() << "]." << std::endl;
🔗 CellBound
The CellBound class represents a bound made up of a contiguous subregion of a hyperrectangle. Suppose that the region represented by a hyperrectangle was linearized and then ordered with Z-ordering. Under this scheme, a CellBound can be represented as containing all points whose linearization falls between a “start address” and an “end address”. A simple depiction of a 2-dimensional CellBound is shown below.
In the example above, p_1 represents the point that is the “start address”, and p_2 represents the point that is the “end address”; any points with address in between those (e.g. the shaded region) are contained in the CellBound.
CellBound is used directly by the UBTree (universal B-tree) class.
Addressing in a CellBound
In a CellBound, each point is mapped to an ordered “address” that indicates its position in the bound using Z-ordering (also called Morton ordering). The mathematical details of this mapping are described in the UB-tree paper; although mlpack uses a slightly modified implementation, the general idea is the same.
The following two functions can be used to convert to and from linearized addresses:
PointToAddress(addr, point)
- Compute the address of point and store it in addr.
  - addr should be of type arma::uvec or arma::u32_vec, depending on the precision of point.
  - point should be an Armadillo vector type (e.g. arma::vec or arma::fvec).
AddressToPoint(point, addr)
- Compute the point that would map to the address addr and store it in point.
  - addr should be of type arma::uvec or arma::u32_vec.
  - point should be an Armadillo vector type (e.g. arma::vec or arma::fvec) whose precision should match that of addr.
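To give a feel for Z-ordering, the sketch below interleaves the bits of integer coordinates into a single Morton code and back. It is a simplification. It assumes the coordinates have already been quantized to unsigned integers, whereas mlpack's PointToAddress() and AddressToPoint() work directly on floating-point vectors using mlpack's own modified mapping. InterleaveBits() and DeinterleaveBits() are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Interleave the bits of integer coordinates into one Z-order (Morton) code.
// Bits are taken from the most significant position down; within each bit
// position the coordinates are visited in dimension order.
std::uint64_t InterleaveBits(const std::vector<std::uint32_t>& coords,
                             const unsigned bitsPerCoord)
{
  assert(bitsPerCoord <= 32 && coords.size() * bitsPerCoord <= 64);
  std::uint64_t code = 0;
  for (int bit = static_cast<int>(bitsPerCoord) - 1; bit >= 0; --bit)
    for (const std::uint32_t c : coords)
      code = (code << 1) | ((c >> bit) & 1u);
  return code;
}

// Invert InterleaveBits() by reading the bits back in the order they were
// written.
std::vector<std::uint32_t> DeinterleaveBits(const std::uint64_t code,
                                            const std::size_t dims,
                                            const unsigned bitsPerCoord)
{
  assert(dims > 0 && bitsPerCoord <= 32 && dims * bitsPerCoord <= 64);
  std::vector<std::uint32_t> coords(dims, 0);
  const std::size_t total = dims * bitsPerCoord;
  for (std::size_t k = 0; k < total; ++k)
  {
    // The k-th bit written sits at position (total - 1 - k) in the code.
    const std::uint32_t b =
        static_cast<std::uint32_t>((code >> (total - 1 - k)) & 1u);
    const std::size_t bit = bitsPerCoord - 1 - k / dims;
    coords[k % dims] |= b << bit;
  }
  return coords;
}
```

In two dimensions with one bit per coordinate, the (x, y) cells (0, 0), (0, 1), (1, 0), and (1, 1) receive codes 0, 1, 2, and 3, which traces the characteristic "Z" shape.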
Constructors
CellBound allows configurable behavior via its two template parameters:
CellBound<DistanceType, ElemType>
Different constructor forms can be used to specify different template parameters (and thus different bound behavior).
b = CellBound(dimensionality)
- Construct a CellBound with the given dimensionality.
  - The bound will be empty (i.e., b will not contain any points at all).
  - The bound will use the Euclidean distance for distance computation, and will expect data to have elements with type double.
b = CellBound<DistanceType>(dimensionality)
- Construct a CellBound with the given dimensionality that will use the given DistanceType class to compute distances.
  - DistanceType is required to be an LMetric, as the distance calculation must be decomposable across dimensions.
  - The bound will expect data to have elements with type double.
b = CellBound<DistanceType, ElemType>(dimensionality)
- Construct a CellBound with the given dimensionality that will use the given DistanceType class to compute distances, and expect data to have elements with type ElemType.
  - DistanceType is required to be an LMetric, as the distance calculation must be decomposable across dimensions.
  - ElemType should generally be double or float.
Note: these constructors provide an empty bound; be sure to grow the bound before using it!
Accessing properties of the bound
The individual bounds associated with each dimension of a CellBound can be accessed, but should not be directly modified; see the "Growing the bound" section for ways to grow a CellBound.
- b.Clear() will reset the bound to an empty bound (i.e. containing no points).
- b.Dim() will return a size_t indicating the dimensionality of the bound.
- b[dim] will return a Range object holding the lower and upper bounds of the outer hyperrectangle of b in dimension dim.
  - Note: this is not a tight bounding shape! It is equivalent to the full outer hyperrectangle in the introductory figure above, not the subregion of the hyperrectangle that b represents.
- b.LoAddress() and b.HiAddress() return arma::uvec& references representing the lower and upper addresses of the bound.
- A tighter bounding shape for b can be obtained by representing the CellBound as the union of a set of hyperrectangles.
  - b.NumBounds() returns the number of hyperrectangles required to represent b’s bound tightly.
  - b.LoBound() and b.HiBound() return arma::mat& references representing the low and high corners of each of the bounding hyperrectangles.
  - b.LoBound().col(i) and b.HiBound().col(i) represent the corners of the i-th bounding hyperrectangle.
- b.MinWidth() returns the minimum width of the bound in any dimension as a double. This value is cached and no computation is performed when calling b.MinWidth(). If the bound is empty, 0 is returned.
- b.Distance() returns either a EuclideanDistance distance metric object, or a DistanceType if a custom DistanceType has been specified in the constructor.
- b.Center(center) will compute the center of the CellBound (i.e. the vector with elements equal to the midpoint of b in each dimension) and store it in the vector center. center should be of type arma::vec.
- b.Diameter() computes the longest diagonal of the hyperrectangle specified by b.
- A CellBound can be serialized with data::Save() and data::Load().
Note: if a custom ElemType was specified in the constructor, then:
- b[dim] will return a RangeType<ElemType>;
- b.LoAddress() and b.HiAddress() will return arma::Col<T>& references, where T is uint32_t if ElemType is 32 bits and uint64_t if ElemType is 64 bits;
- b.LoBound() and b.HiBound() will return arma::Mat<ElemType>&;
- b.MinWidth() and b.Diameter() will return ElemType; and
- b.Center(center) expects center to be of type arma::Col<ElemType>.
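The center and diameter described above are properties of the outer hyperrectangle. As a minimal sketch assuming the Euclidean distance, and using a hypothetical OuterBox struct that holds one [lo, hi] interval per dimension, they are the per-dimension midpoint and the length of the main diagonal.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical outer hyperrectangle: one [lo, hi] interval per dimension.
struct OuterBox
{
  std::vector<double> lo, hi;
};

// Center: the midpoint of the box in every dimension.
std::vector<double> Center(const OuterBox& box)
{
  assert(box.lo.size() == box.hi.size());
  std::vector<double> center(box.lo.size());
  for (std::size_t i = 0; i < center.size(); ++i)
    center[i] = 0.5 * (box.lo[i] + box.hi[i]);
  return center;
}

// Diameter: the length of the longest diagonal, i.e. the square root of the
// sum of squared widths (Euclidean distance assumed).
double Diameter(const OuterBox& box)
{
  assert(box.lo.size() == box.hi.size());
  double sum = 0.0;
  for (std::size_t i = 0; i < box.lo.size(); ++i)
    sum += (box.hi[i] - box.lo[i]) * (box.hi[i] - box.lo[i]);
  return std::sqrt(sum);
}
```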
Growing the bound
The CellBound uses the logical |= operator to grow the bound to contain sets of points or other bounds.
- b |= data expands b to include all of the data points in data. data should be a column-major arma::mat. The expansion operation is minimal, so b is not expanded any more than necessary.
  - The LoAddress() and HiAddress() members must be manually updated to the desired values after the expansion (the example below does this with b.UpdateAddressBounds(dataset)). This is handled automatically when a CellBound is created by building a BinarySpaceTree with UBTreeSplit.
- b |= bound expands b to fully include bound, where bound is another CellBound. The expansion/union operation is minimal, so b is not expanded any more than necessary.
Notes:
- When another bound is passed, it must have the same type as b; so, if a custom DistanceType and ElemType were specified, then bound must have type CellBound<DistanceType, ElemType>.
- If a custom ElemType was specified, then any data argument should be a matrix with that ElemType (e.g. arma::Mat<ElemType>).
- Each function expects the other bound or dataset to have dimensionality that matches b.
Bounding distances to other objects
Once a CellBound has been successfully created and set to the desired subset of its bounding hyperrectangle, there are a number of functions that can bound the distance between a CellBound and other objects.
b.Contains(point)
- Return a bool indicating whether or not b contains the given point (an arma::vec).
b.MinDistance(point)
b.MinDistance(bound)
- Return a double whose value is the minimum possible distance between b and either a point (an arma::vec) or another bound (a CellBound).
  - The minimum distance between b and another point or bound is the length of the shortest possible line that can connect the other point or bound to b.
  - If point or bound is contained in b, then the returned distance is 0.
b.MaxDistance(point)
b.MaxDistance(bound)
- Return a double whose value is the maximum possible distance between b and either a point (an arma::vec) or another bound (a CellBound).
  - The maximum distance between b and a given point is the furthest possible distance between point and any possible point falling within the bounding shape of b.
  - The maximum distance between b and another bound is the furthest possible distance between any possible point falling within the bounding shape of b, and any possible point falling within the bounding shape of bound.
  - Note that this definition means that even if b.Contains(point) or b.Contains(bound) is true, the maximum distance may be greater than 0.
b.RangeDistance(point)
b.RangeDistance(bound)
- Compute the minimum and maximum distance between b and point or bound, returning the result as a Range object.
  - This is more efficient than calling b.MinDistance() and b.MaxDistance().
Note: if a custom DistanceType and ElemType were specified in the constructor, then all distances will be computed with respect to the specified DistanceType and all return values will either be ElemType or RangeType<ElemType> (except for Contains(), which will still return a bool).
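Because the tight shape of a CellBound is a union of hyperrectangles (the columns of b.LoBound() and b.HiBound()), distances to it can be bounded one box at a time. The following sketch assumes the Euclidean distance and uses the hypothetical helpers Box, MinDistanceToUnion(), and MaxDistanceToUnion() rather than mlpack's internals. It takes the nearest box for the minimum distance and the farthest corner over all boxes for the maximum distance.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

using Point = std::vector<double>;

// Hypothetical axis-aligned box with low and high corners (comparable to one
// column of b.LoBound() and b.HiBound()).
struct Box
{
  Point lo, hi;
};

// Euclidean distance from p to the nearest point of the box (0 if inside).
double MinDistanceToBox(const Box& box, const Point& p)
{
  assert(p.size() == box.lo.size() && p.size() == box.hi.size());
  double sum = 0.0;
  for (std::size_t i = 0; i < p.size(); ++i)
  {
    double gap = 0.0;
    if (p[i] < box.lo[i])
      gap = box.lo[i] - p[i];
    else if (p[i] > box.hi[i])
      gap = p[i] - box.hi[i];
    sum += gap * gap;
  }
  return std::sqrt(sum);
}

// Euclidean distance from p to the farthest corner of the box.
double MaxDistanceToBox(const Box& box, const Point& p)
{
  assert(p.size() == box.lo.size() && p.size() == box.hi.size());
  double sum = 0.0;
  for (std::size_t i = 0; i < p.size(); ++i)
  {
    const double gap = std::max(std::abs(p[i] - box.lo[i]),
                                std::abs(p[i] - box.hi[i]));
    sum += gap * gap;
  }
  return std::sqrt(sum);
}

// The nearest point of a union lies in its nearest box; the farthest point
// lies in its farthest box.
double MinDistanceToUnion(const std::vector<Box>& boxes, const Point& p)
{
  double best = std::numeric_limits<double>::infinity();
  for (const Box& box : boxes)
    best = std::min(best, MinDistanceToBox(box, p));
  return best;
}

double MaxDistanceToUnion(const std::vector<Box>& boxes, const Point& p)
{
  double worst = 0.0;
  for (const Box& box : boxes)
    worst = std::max(worst, MaxDistanceToBox(box, p));
  return worst;
}
```

For example, consider the boxes [0, 1] x [0, 1] and [2, 3] x [0, 1]. The point (1.5, 0.5) lies inside the outer hyperrectangle [0, 3] x [0, 1] but outside the union, and its minimum distance to the union is 0.5. This gap is exactly what the tighter shape captures.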
Example usage
// Create a random dataset of 50 points in 3 dimensions.
arma::mat dataset(3, 50, arma::fill::randu);
// Now create a CellBound that contains those points via the |= operator.
mlpack::CellBound b(3);
b |= dataset;
b.UpdateAddressBounds(dataset);
std::cout << "Outer bounding box of CellBound:" << std::endl;
for (size_t i = 0; i < 3; ++i)
{
std::cout << " - Dimension " << i << ": [" << b[i].Lo() << ", " << b[i].Hi()
<< "]." << std::endl;
}
// Create another random dataset, but shifted to fit in a box ranging from
// [2, 2, 2] to [3, 3, 3].
arma::mat dataset2(3, 50, arma::fill::randu);
dataset2 += 2.0;
mlpack::CellBound b2(3);
b2 |= dataset2;
b2.UpdateAddressBounds(dataset2);
std::cout << "Outer bounding box of second CellBound:" << std::endl;
for (size_t i = 0; i < 3; ++i)
{
std::cout << " - Dimension " << i << ": [" << b2[i].Lo() << ", " << b2[i].Hi()
<< "]." << std::endl;
}
// Compute union of two CellBounds.
mlpack::CellBound b3(3);
b3 |= b;
b3 |= b2;
std::cout << "Outer bounding box of union CellBound:" << std::endl;
for (size_t i = 0; i < 3; ++i)
{
std::cout << " - Dimension " << i << ": [" << b3[i].Lo() << ", " << b3[i].Hi()
<< "]." << std::endl;
}
// Print statistics about the union bound.
std::cout << "Union bound details:" << std::endl;
std::cout << " - Dimensionality: " << b3.Dim() << "." << std::endl;
std::cout << " - Minimum width: " << b3.MinWidth() << "." << std::endl;
std::cout << " - Diameter: " << b3.Diameter() << "." << std::endl;
arma::vec center;
b3.Center(center);
std::cout << " - Center: " << center.t();
std::cout << std::endl;
// Compute the minimum distance between a point and the first two bounds.
const double d1 = b.MinDistance(arma::vec("1.5 1.5 1.5"));
const double d2 = b2.MinDistance(arma::vec("1.5 1.5 1.5"));
std::cout << "Minimum distance between first bound and [1.5, 1.5, 1.5]: "
<< d1 << "." << std::endl;
std::cout << "Minimum distance between second bound and [1.5, 1.5, 1.5]: "
<< d2 << "." << std::endl;
// Use Contains(). In this case, the 'else' will be taken.
if (b.Contains(arma::vec("1.5 1.5 1.5")))
std::cout << "First bound contains [1.5, 1.5, 1.5]." << std::endl;
else
std::cout << "First bound does not contain [1.5, 1.5, 1.5]." << std::endl;
std::cout << std::endl;
// Compute the maximum distance between a point inside the unit cube and the
// first bound.
const double d3 = b.MaxDistance(arma::vec("0.5 0.5 0.5"));
std::cout << "Maximum distance between first bound and [0.5, 0.5, 0.5]: " << d3
<< "." << std::endl;
// Compute the minimum and maximum distances between first and second bounds.
const mlpack::Range r = b.RangeDistance(b2);
std::cout << "Distances between first bound and second bound: [" << r.Lo()
<< ", " << r.Hi() << "]." << std::endl;
// Create a bound using the Manhattan (L1) distance and compute the minimum and
// maximum distance to a point.
mlpack::CellBound<mlpack::ManhattanDistance> mb(3);
mb |= dataset;
mb.UpdateAddressBounds(dataset);
const mlpack::Range r2 = mb.RangeDistance(arma::vec("1.5 1.5 4.0"));
std::cout << "Distance between Manhattan distance CellBound and "
<< "[1.5, 1.5, 4.0]: [" << r2.Lo() << ", " << r2.Hi() << "]." << std::endl;
// Create a bound using the Chebyshev (L-inf) distance, using random 32-bit
// floating point elements, and compute the minimum and maximum distance to a
// point.
arma::fmat floatData(3, 25, arma::fill::randu);
floatData += 2.0f; // Shift the points so every coordinate lies in [2.0, 3.0].
mlpack::CellBound<mlpack::ChebyshevDistance, float> cb(3);
cb |= floatData; // The outer bound now lies within [2.0, 3.0] in every dimension.
cb.UpdateAddressBounds(floatData);
// Note the use of arma::fvec to represent a point, since ElemType is float.
const mlpack::RangeType<float> r3 = cb.RangeDistance(arma::fvec("1.5 1.5 4.0"));
std::cout << "Distance between Chebyshev distance CellBound and "
<< "[1.5, 1.5, 4.0]: [" << r3.Lo() << ", " << r3.Hi() << "]." << std::endl;
🔗 Custom BoundTypes
The BinarySpaceTree class allows an arbitrary BoundType template parameter to be specified for custom behavior. By default, this is HRectBound (a hyper-rectangle bound), but it is also possible to implement a custom BoundType. Any custom BoundType class must implement the following functions:
// NOTE: the custom BoundType class must take at least two template parameters.
template<typename DistanceType, typename ElemType>
class BoundType
{
public:
// A default constructor must be available.
BoundType();
// Initialize the bound to an empty bound in the given dimensionality.
BoundType(const size_t dimensionality);
// A copy and move constructor must be available. (If your class is simple,
// you can generally omit this and use the default-generated versions, which
// are commented out below.)
BoundType(const BoundType& other);
BoundType(BoundType&& other);
// BoundType(const BoundType& other) = default;
// BoundType(BoundType&& other) = default;
// Return the minimum and maximum ranges of the bound in the given dimension.
mlpack::RangeType<ElemType> operator[](const size_t dim) const;
// Return the longest possible distance between two points contained in the
// bound. (Examples: for a ball bound, this is just the regular diameter.
// For a rectangle bound, this is the length of the longest diagonal.)
ElemType Diameter() const;
// Return the minimum width of the bound in any dimension.
ElemType MinWidth() const;
// Return the DistanceType object that this bound uses for distance
// calculations.
DistanceType& Distance();
// Expand the bound so that it includes all of the data points in `points`.
// `points` will be a matrix type whose element type matches `ElemType`.
template<typename MatType>
BoundType& operator|=(const MatType& points);
// Compute the minimum possible distance between the given point and the
// bound. `VecType` will be a single column vector with element type that
// matches `ElemType`.
template<typename VecType>
ElemType MinDistance(const VecType& point) const;
// Compute the minimum possible distance between this bound and the given
// other bound.
ElemType MinDistance(const BoundType& other) const;
// Compute the maximum possible distance between the given point and the
// bound. `VecType` will be a single column vector with element type that
// matches `ElemType`.
template<typename VecType>
ElemType MaxDistance(const VecType& point) const;
// Compute the maximum possible distance between this bound and the given
// other bound.
ElemType MaxDistance(const BoundType& other) const;
// Compute the minimum and maximum distances between the given point and the
// bound, returning them in a Range object. `VecType` will be a single column
// vector with element type that matches `ElemType`.
template<typename VecType>
mlpack::RangeType<ElemType> RangeDistance(const VecType& point) const;
// Compute the minimum and maximum distances between this bound and the given
// other bound, returning them in a Range object.
mlpack::RangeType<ElemType> RangeDistance(const BoundType& other) const;
// Compute the center of the bound and store it into the given `center`
// vector.
void Center(arma::Col<ElemType>& center) const;
// Serialize the bound to disk using the cereal library.
template<typename Archive>
void serialize(Archive& ar, const uint32_t version);
};
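As a hedged illustration of what several of these functions mean for a concrete shape, the standalone sketch below implements `operator|=`, `Diameter()`, `MinWidth()`, bound-to-bound `MinDistance()`/`MaxDistance()`, and `Center()` for a Euclidean axis-aligned box. It stores points as `std::vector<double>` instead of Armadillo types and omits the remaining required members, so the hypothetical `ToyBoxBound` is not itself a drop-in `BoundType`.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical Euclidean box bound; an empty bound has lo > hi everywhere.
class ToyBoxBound
{
 public:
  explicit ToyBoxBound(std::size_t dim) :
      lo(dim, std::numeric_limits<double>::max()),
      hi(dim, std::numeric_limits<double>::lowest()) { }

  // Expand the bound to include every point (each inner vector is a point).
  ToyBoxBound& operator|=(const std::vector<std::vector<double>>& points)
  {
    for (const auto& p : points)
    {
      assert(p.size() == lo.size());
      for (std::size_t d = 0; d < lo.size(); ++d)
      {
        lo[d] = std::min(lo[d], p[d]);
        hi[d] = std::max(hi[d], p[d]);
      }
    }
    return *this;
  }

  // Length of the longest diagonal.
  double Diameter() const
  {
    double sq = 0.0;
    for (std::size_t d = 0; d < lo.size(); ++d)
      sq += (hi[d] - lo[d]) * (hi[d] - lo[d]);
    return std::sqrt(sq);
  }

  // Smallest side length over all dimensions.
  double MinWidth() const
  {
    double w = std::numeric_limits<double>::max();
    for (std::size_t d = 0; d < lo.size(); ++d)
      w = std::min(w, hi[d] - lo[d]);
    return w;
  }

  // Distance between the closest possible pair of points, one in each box.
  double MinDistance(const ToyBoxBound& other) const
  {
    double sq = 0.0;
    for (std::size_t d = 0; d < lo.size(); ++d)
    {
      const double gap = std::max({ 0.0, other.lo[d] - hi[d], lo[d] - other.hi[d] });
      sq += gap * gap;
    }
    return std::sqrt(sq);
  }

  // Distance between the farthest possible pair of points, one in each box.
  double MaxDistance(const ToyBoxBound& other) const
  {
    double sq = 0.0;
    for (std::size_t d = 0; d < lo.size(); ++d)
    {
      const double span = std::max(hi[d] - other.lo[d], other.hi[d] - lo[d]);
      sq += span * span;
    }
    return std::sqrt(sq);
  }

  // Store the center of the box into `center`.
  void Center(std::vector<double>& center) const
  {
    center.resize(lo.size());
    for (std::size_t d = 0; d < lo.size(); ++d)
      center[d] = (lo[d] + hi[d]) / 2.0;
  }

 private:
  std::vector<double> lo, hi;
};
```

Note that `Center()` is `const` here, since a tree that exposes a `const` center query would need to call it on a `const` bound.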
Some aspects of the behavior of BinarySpaceTree depend on the traits of a particular bound. Optionally, you may define a BoundTraits specialization for your bound type, of the following form:
// Replace `BoundType` below with the name of the custom class.
template<typename DistanceType, typename ElemType>
struct BoundTraits<BoundType<DistanceType, ElemType>>
{
//! If true, then the bounds for each dimension are tight. If false, then the
//! bounds for each dimension may be looser than the range of all points held
//! in the bound. This defaults to false if the struct is not defined.
static const bool HasTightBounds = false;
};
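Because the trait is a compile-time constant, generic code can branch on it without any runtime cost. The standalone sketch below declares its own `BoundTraits`-style primary template in the global namespace (defaulting to `false`, as the comment above describes) plus one specialization, and shows one possible way to consume the flag with `if constexpr`. The types `Euclidean`, `LooseBall`, and `TightBox` and the function `DescribeBounds()` are hypothetical; this is not how BinarySpaceTree itself is written.

```cpp
#include <cassert>
#include <string>

// Primary template: bounds are not assumed to be tight unless specialized.
template<typename BoundType>
struct BoundTraits
{
  static const bool HasTightBounds = false;
};

// Placeholder distance type and two hypothetical bound types, each taking
// the usual <DistanceType, ElemType> template parameters.
struct Euclidean { };
template<typename DistanceType, typename ElemType> class LooseBall { };
template<typename DistanceType, typename ElemType> class TightBox { };

// Specialization: every TightBox has tight per-dimension bounds.
template<typename DistanceType, typename ElemType>
struct BoundTraits<TightBox<DistanceType, ElemType>>
{
  static const bool HasTightBounds = true;
};

// Generic code can select behavior at compile time from the trait.
template<typename BoundType>
std::string DescribeBounds()
{
  if constexpr (BoundTraits<BoundType>::HasTightBounds)
    return "tight: ranges match the contained points exactly";
  else
    return "loose: ranges may be wider than the contained points";
}
```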
Note that if a custom SplitType is being used, the custom BoundType will also have to implement any functions required by the custom SplitType. In addition, custom RuleTypes used with tree traversals may have additional requirements on the BoundType; the functions listed above are merely the minimum required to use a BoundType with a BinarySpaceTree.
🔗 StatisticType
Each node in a BinarySpaceTree holds an instance of the StatisticType class. This class can be used to store additional bounding information or other cached quantities that a BinarySpaceTree does not already compute.
mlpack provides a few existing StatisticType classes, and a custom StatisticType can also be easily implemented:
- EmptyStatistic: an empty statistic class that does not hold any information
- Custom StatisticTypes: implement a fully custom StatisticType
Note: this section is still under construction—not all statistic types are documented yet.
🔗 EmptyStatistic
The EmptyStatistic class is an empty placeholder class that is used as the default StatisticType template parameter for mlpack trees.
The class does not hold any members and provides no functionality. See the implementation.
🔗 Custom StatisticTypes
A custom StatisticType is trivial to implement. Only a default constructor and a constructor taking a BinarySpaceTree node are necessary.
class CustomStatistic
{
public:
// Default constructor required by the StatisticType policy.
CustomStatistic();
// Construct a CustomStatistic for the given fully-constructed
// `BinarySpaceTree` node. Here we have templatized the tree type to make it
// easy to handle any type of `BinarySpaceTree`.
template<typename TreeType>
CustomStatistic(TreeType& node);
//
// Adding any additional precomputed bound quantities can be done; these
// quantities should be computed in the constructor. They can then be
// accessed from the tree with `node.Stat()`.
//
};
Example: suppose we wanted to know, for each node, the exact time at which it was created. A StatisticType could be created that has a std::time_t member, whose value is computed in the constructor.
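A minimal sketch of that idea, assuming only what the CustomStatistic outline above requires: a default constructor and a constructor templated on the node type. The class name `CreationTimeStatistic`, its `CreationTime()` accessor, and `MockNode` (a stand-in for a real BinarySpaceTree node so the snippet compiles without mlpack) are all hypothetical. The `enable_if` constraint is an extra precaution so that the templated constructor cannot hijack copy construction from a non-const statistic.

```cpp
#include <cassert>
#include <ctime>
#include <type_traits>

class CreationTimeStatistic
{
 public:
  // Default constructor required by the StatisticType policy.
  CreationTimeStatistic() : creationTime(0) { }

  // Called with a fully-constructed node; the node itself is unused because
  // only the wall-clock time is recorded. The constraint excludes the
  // statistic's own type so that copies still use the copy constructor.
  template<typename TreeType,
           typename = std::enable_if_t<
               !std::is_same_v<std::decay_t<TreeType>, CreationTimeStatistic>>>
  explicit CreationTimeStatistic(TreeType& /* node */) :
      creationTime(std::time(nullptr)) { }

  // Seconds since the epoch at which the node was built (0 if unknown).
  std::time_t CreationTime() const { return creationTime; }

 private:
  std::time_t creationTime;
};

// Stand-in for a tree node, only so that this sketch compiles on its own.
struct MockNode { };
```

In a real tree, the value would then be read back through `node.Stat().CreationTime()`.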
🔗 SplitType
The SplitType template parameter controls the algorithm used to split each node of a BinarySpaceTree while building. The splitting strategy used can be entirely arbitrary: the SplitType only needs to specify whether a node should be split, and if so, which points should go to the left child, and which should go to the right child.
mlpack provides several drop-in choices for SplitType, and it is also possible to write a fully custom split:
- MidpointSplit: splits on the midpoint of the dimension with maximum width
- MeanSplit: splits on the mean value of the points in the dimension with maximum width
- VantagePointSplit: splits by selecting a ‘vantage point’ and then dividing points into ‘near’ and ‘far’ sets
- RPTreeMeanSplit: projects points onto a random vector, splitting on the median value of the projections, or in some cases on the distance from the mean value
- RPTreeMaxSplit: projects points onto a random vector, splitting on a random offset of the median of projected points
- UBTreeSplit: splits a CellBound into two balanced children
- Custom SplitTypes: implement a fully custom SplitType class
Note: this section is still under construction—not all split types are documented yet.
🔗 MidpointSplit
The MidpointSplit class is a splitting strategy that can be used by BinarySpaceTree. It is the default strategy for splitting KDTrees.
The splitting strategy for the MidpointSplit class is, given a set of points:
- Find the dimension of the points with maximum width.
- Split in that dimension.
- Points less than the midpoint (i.e. (max + min) / 2) will go to the left child.
- Points greater than or equal to the midpoint will go to the right child.
For implementation details, see the source code.
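The four steps above can be sketched in standalone C++17 as follows. Points are stored as a `std::vector` of coordinate vectors rather than an Armadillo matrix, and the function name `MidpointPartition` is hypothetical; this is an illustration of the rule, not mlpack's implementation (which works in place on the tree's matrix and can also record index mappings).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using Point = std::vector<double>;

// Reorders `points` so that the left child's points come first, and returns
// the number of points that go to the left child.
inline std::size_t MidpointPartition(std::vector<Point>& points)
{
  assert(!points.empty());
  const std::size_t dims = points[0].size();

  // Step 1: find the dimension with maximum width.
  std::size_t splitDim = 0;
  double bestWidth = -1.0, splitLo = 0.0, splitHi = 0.0;
  for (std::size_t d = 0; d < dims; ++d)
  {
    const auto [minIt, maxIt] = std::minmax_element(points.begin(), points.end(),
        [d](const Point& a, const Point& b) { return a[d] < b[d]; });
    const double width = (*maxIt)[d] - (*minIt)[d];
    if (width > bestWidth)
    {
      bestWidth = width;
      splitDim = d;
      splitLo = (*minIt)[d];
      splitHi = (*maxIt)[d];
    }
  }

  // Steps 2-4: points below the midpoint go left; all others go right.
  const double midpoint = (splitLo + splitHi) / 2.0;
  const auto firstRight = std::partition(points.begin(), points.end(),
      [&](const Point& p) { return p[splitDim] < midpoint; });
  return static_cast<std::size_t>(firstRight - points.begin());
}
```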
🔗 MeanSplit
The MeanSplit class is a splitting strategy that can be used by BinarySpaceTree. It is the splitting strategy used by the MeanSplitKDTree class.
The splitting strategy for the MeanSplit class is, given a set of points:
- Find the dimension d of the points with maximum width.
- Compute the mean value m of the points in dimension d.
- Split in dimension d.
- Points less than m will go to the left child.
- Points greater than or equal to m will go to the right child.
In practice, the MeanSplit splitting strategy often results in a tree with fewer leaf nodes than MidpointSplit, because each split is more likely to be balanced. However, counterintuitively, a more balanced tree can be worse for search tasks like nearest neighbor search, because unbalanced nodes are more easily pruned away during search. In general, using MidpointSplit for nearest neighbor search is 20-80% faster, but this is not true for every dataset or task.
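To see how the two thresholds can differ in balance, consider a node's coordinates along the chosen dimension. The standalone sketch below computes both the midpoint and the mean threshold and counts how many points each would send to the left child; the helper names are hypothetical and the values in the usage note are a made-up example, not a benchmark.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <vector>

// Threshold used by a midpoint split: (max + min) / 2.
inline double MidpointThreshold(const std::vector<double>& values)
{
  const auto [minIt, maxIt] = std::minmax_element(values.begin(), values.end());
  return (*minIt + *maxIt) / 2.0;
}

// Threshold used by a mean split: the arithmetic mean of the values.
inline double MeanThreshold(const std::vector<double>& values)
{
  assert(!values.empty());
  return std::accumulate(values.begin(), values.end(), 0.0) / values.size();
}

// Number of values strictly below the threshold (these go to the left child).
inline std::size_t LeftCount(const std::vector<double>& values, double threshold)
{
  return static_cast<std::size_t>(std::count_if(values.begin(), values.end(),
      [threshold](double v) { return v < threshold; }));
}
```

For the values {0, 6, 7, 8, 10}, the midpoint threshold is 5 and sends one of five points left, while the mean threshold is 6.2 and sends two left. As the paragraph above notes, the better-balanced split is not automatically the faster one for search.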
For implementation details, see the source code.
🔗 VantagePointSplit
The VantagePointSplit class is a splitting strategy that can be used by BinarySpaceTree. It is the default strategy for splitting VPTrees, and is detailed in the paper.
Due to the nature of the split, VantagePointSplit should always be used with the HollowBallBound.
The splitting strategy for the VantagePointSplit class is, given a set of points:
- Select a vantage point from a sample of 100 random candidate points (or use the full set if there are fewer than 100 points):
  - Compute the distances between each candidate point and 100 additional random samples (or the full set if there are fewer than 100 points).
  - Select the vantage point as the candidate with maximum average distance to the additional random samples.
- Compute a boundary distance mu that is the median distance between the vantage point and its random samples.
- Points with distance less than mu from the vantage point will go to the left child.
- Points with distance greater than mu from the vantage point will go to the right child.
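A standalone sketch of these steps for the small-node case, where (per the description above) the full set of points is used instead of random samples, so the result is deterministic. It assumes Euclidean distance and `std::vector` points; the names `Dist`, `VantageSplitResult`, and `VantagePointPartition` are hypothetical. Two further assumptions are labeled in the comments: the upper median is used for an even number of points, and points at exactly distance mu go to the right child, since the description above does not specify that case.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Point = std::vector<double>;

inline double Dist(const Point& a, const Point& b)
{
  double sq = 0.0;
  for (std::size_t d = 0; d < a.size(); ++d)
    sq += (a[d] - b[d]) * (a[d] - b[d]);
  return std::sqrt(sq);
}

struct VantageSplitResult
{
  std::size_t vantageIndex;  // index of the vantage point before reordering
  double mu;                 // boundary distance
  std::size_t leftCount;     // number of points sent to the left child
};

// Every point is a candidate, and every point also serves as a "sample".
inline VantageSplitResult VantagePointPartition(std::vector<Point>& points)
{
  assert(points.size() >= 2);
  const std::size_t n = points.size();

  // Choose the candidate with maximum average distance to the other points.
  std::size_t best = 0;
  double bestAvg = -1.0;
  for (std::size_t i = 0; i < n; ++i)
  {
    double sum = 0.0;
    for (std::size_t j = 0; j < n; ++j)
      sum += Dist(points[i], points[j]);  // j == i contributes 0
    const double avg = sum / (n - 1);
    if (avg > bestAvg) { bestAvg = avg; best = i; }
  }

  // Copy the vantage point before std::partition reorders `points`.
  const Point vantage = points[best];
  std::vector<double> dists;
  for (const Point& p : points)
    dists.push_back(Dist(vantage, p));
  // Assumption: the upper median is used when n is even.
  std::nth_element(dists.begin(), dists.begin() + n / 2, dists.end());
  const double mu = dists[n / 2];

  // Closer than mu: left child; everything else (ties included): right.
  const auto firstRight = std::partition(points.begin(), points.end(),
      [&](const Point& p) { return Dist(vantage, p) < mu; });
  return { best, mu, static_cast<std::size_t>(firstRight - points.begin()) };
}
```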
The VantagePointSplit class has three template parameters:
VantagePointSplit<BoundType, MatType, MaxNumSamples = 100>
If a custom number of samples S is desired, the easiest way to specify it is via a template typedef:
template<typename BoundType, typename MatType>
using MyVantagePointSplit = VantagePointSplit<BoundType, MatType, S>;
Then, MyVantagePointSplit can be used directly with BinarySpaceTree as a SplitType.
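The alias works because BinarySpaceTree expects a SplitType that takes exactly two type parameters, and the alias fixes the third, non-type parameter. The standalone sketch below reproduces only that C++ mechanism, using the hypothetical stand-ins `VantagePointSplitSketch` and `TreeSketch` so it compiles without mlpack; `S = 50` is an arbitrary choice.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-in with the same template shape as VantagePointSplit.
template<typename BoundType, typename MatType, std::size_t MaxNumSamples = 100>
struct VantagePointSplitSketch
{
  static constexpr std::size_t NumSamples = MaxNumSamples;
};

// Hypothetical stand-in for a tree whose SplitType parameter takes exactly
// two type parameters, like BinarySpaceTree's.
template<template<typename, typename> class SplitType>
struct TreeSketch
{
  using Split = SplitType<int, double>;  // placeholder bound and matrix types
};

// The alias template fixes the third (non-type) parameter; here S = 50.
template<typename BoundType, typename MatType>
using MyVantagePointSplit = VantagePointSplitSketch<BoundType, MatType, 50>;

// The alias is accepted wherever a two-parameter SplitType is expected.
static_assert(TreeSketch<MyVantagePointSplit>::Split::NumSamples == 50,
              "MyVantagePointSplit uses 50 samples");
```

Passing a three-parameter template directly to a two-parameter template template parameter relies on C++17's relaxed matching rules, which not every compiler enables by default, so the alias is the more portable choice.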
For implementation details, see the source code.
🔗 RPTreeMeanSplit
The RPTreeMeanSplit class is a splitting strategy that can be used by BinarySpaceTree. It is the splitting strategy used by the RPTree class, and uses a random projection to split points. The general idea is described in the paper by Dasgupta and Freund, as the RPTree-Mean version of the ChooseRule() function.
The splitting strategy for the RPTreeMeanSplit class is, given a set of points:
- Draw a random vector z.
- Sample up to 100 points and compute d, the average squared pairwise distance between the sampled points.
- If the squared diameter of the bounding box of the points is less than or equal to 10 * d:
  - Project all points onto the vector z, and compute the median v of the projected values.
  - Points with projected value less than v will go to the left child.
  - Points with projected value greater than or equal to v will go to the right child.
- Otherwise:
  - Compute the mean s of all points.
  - Points with distance from s less than the median distance from s will go to the left child.
  - Points with distance from s greater than or equal to the median distance from s will go to the right child.
The implementation strategy differs slightly from the RPTree-Mean version in the paper: instead of computing the true average pairwise distance between all points, a sample of 100 points is used.
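The second rule above (splitting by distance from the mean) is the less familiar of the two, so here is a standalone sketch of just that rule. It assumes Euclidean distance, `std::vector` points, and all points rather than a sample; the function name `MeanDistancePartition` is hypothetical, and the upper median for an even number of points is an assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Point = std::vector<double>;

// Reorders `points` so that those closer to the mean than the median distance
// come first; returns how many points go to the left child.
inline std::size_t MeanDistancePartition(std::vector<Point>& points)
{
  assert(!points.empty());
  const std::size_t n = points.size(), dims = points[0].size();

  // Compute the mean s of all points.
  Point s(dims, 0.0);
  for (const Point& p : points)
    for (std::size_t d = 0; d < dims; ++d)
      s[d] += p[d] / n;

  auto distToMean = [&s](const Point& p)
  {
    double sq = 0.0;
    for (std::size_t d = 0; d < p.size(); ++d)
      sq += (p[d] - s[d]) * (p[d] - s[d]);
    return std::sqrt(sq);
  };

  // Median distance from s (assumption: upper median for even n).
  std::vector<double> dists;
  for (const Point& p : points)
    dists.push_back(distToMean(p));
  std::nth_element(dists.begin(), dists.begin() + n / 2, dists.end());
  const double median = dists[n / 2];

  // Closer than the median distance: left child; otherwise: right child.
  const auto firstRight = std::partition(points.begin(), points.end(),
      [&](const Point& p) { return distToMean(p) < median; });
  return static_cast<std::size_t>(firstRight - points.begin());
}
```

Unlike a projection, this rule separates an inner "core" from an outer "shell", which is why it suits data whose diameter is large relative to its typical interpoint distances.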
For implementation details, see the source code.
🔗 RPTreeMaxSplit
The RPTreeMaxSplit class is a splitting strategy that can be used by BinarySpaceTree. It is the splitting strategy used by the MaxRPTree class, and uses a random projection to split points. The general idea is described in the paper by Dasgupta and Freund, as the RPTree-Max version of the ChooseRule() function.
The splitting strategy for the RPTreeMaxSplit class is, given a set of points:
- Draw a random vector z.
- Sample up to 100 points (call this sample S).
- Compute v, the median value of projections of points in S onto z.
- Points with projection onto z less than v will go to the left child.
- Points with projection onto z greater than or equal to v will go to the right child.
The implementation strategy differs slightly from the RPTree-Max version in the paper: instead of computing the median on all points, a sample of 100 points is used.
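The projection-and-median steps, including the use of a sample S to compute v, can be sketched as follows. To keep the example deterministic, the direction `z` and the sample indices are passed in rather than drawn at random; the names `Dot` and `ProjectionMedianPartition` are hypothetical, the upper median is assumed for an even sample size, and this is not mlpack's implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using Point = std::vector<double>;

inline double Dot(const Point& a, const Point& b)
{
  double sum = 0.0;
  for (std::size_t d = 0; d < a.size(); ++d)
    sum += a[d] * b[d];
  return sum;
}

// Computes v, the median projection onto z of the sampled points only, then
// reorders all points so that those with projection below v come first.
// Returns the number of points that go to the left child.
inline std::size_t ProjectionMedianPartition(std::vector<Point>& points,
                                             const Point& z,
                                             const std::vector<std::size_t>& sample)
{
  assert(!sample.empty());
  std::vector<double> proj;
  for (std::size_t i : sample)
    proj.push_back(Dot(points[i], z));
  std::nth_element(proj.begin(), proj.begin() + proj.size() / 2, proj.end());
  const double v = proj[proj.size() / 2];

  const auto firstRight = std::partition(points.begin(), points.end(),
      [&](const Point& p) { return Dot(p, z) < v; });
  return static_cast<std::size_t>(firstRight - points.begin());
}
```

With six collinear points at x = 0, ..., 5 and z = [1, 0], a sample of the first three points gives v = 1 (one point goes left), while using every point gives v = 3 (three points go left). A small sample therefore trades some balance for speed.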
For implementation details, see the source code.
🔗 UBTreeSplit
The UBTreeSplit class is a splitting strategy that can be used by BinarySpaceTree. It is the splitting strategy used by the UBTree class (the universal B-tree), and it requires that the BoundType being used is CellBound.
The splitting strategy for the UBTreeSplit class is simple: with each point mapped to its corresponding linearized address, those points with address less than the median address go to the left child; other points go to the right child.
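Universal B-trees are commonly described in terms of a Z-order (Morton) curve, where a point's linearized address is formed by interleaving the bits of its coordinates. The standalone sketch below assumes non-negative 2-D integer coordinates below 2^16 and splits at the median address. The names `MortonAddress`, `GridPoint`, and `MedianAddressPartition` are hypothetical, and mlpack's own address computation for floating-point coordinates is not reproduced here, so treat this only as an illustration of the idea.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Interleave the bits of x and y: bit i of x goes to bit 2i and bit i of y
// goes to bit 2i + 1 of the result.
inline std::uint32_t MortonAddress(std::uint16_t x, std::uint16_t y)
{
  std::uint32_t address = 0;
  for (unsigned i = 0; i < 16; ++i)
  {
    address |= static_cast<std::uint32_t>((x >> i) & 1u) << (2 * i);
    address |= static_cast<std::uint32_t>((y >> i) & 1u) << (2 * i + 1);
  }
  return address;
}

struct GridPoint { std::uint16_t x, y; };

// Reorders points so that those with address below the median address come
// first; returns the number of points in the left child.
inline std::size_t MedianAddressPartition(std::vector<GridPoint>& points)
{
  assert(!points.empty());
  std::vector<std::uint32_t> addresses;
  for (const GridPoint& p : points)
    addresses.push_back(MortonAddress(p.x, p.y));
  std::nth_element(addresses.begin(), addresses.begin() + addresses.size() / 2,
                   addresses.end());
  const std::uint32_t median = addresses[addresses.size() / 2];

  const auto firstRight = std::partition(points.begin(), points.end(),
      [median](const GridPoint& p) { return MortonAddress(p.x, p.y) < median; });
  return static_cast<std::size_t>(firstRight - points.begin());
}
```

For example, (3, 3) maps to address 15 and (2, 0) maps to address 4. Splitting at the median address keeps the two children roughly equal in size, which matches the "two balanced children" summary earlier.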
For implementation details, see the source code.
🔗 Custom SplitTypes
Custom split strategies for a binary space tree can be implemented via the SplitType template parameter. By default, the MidpointSplit splitting strategy is used, but it is also possible to implement and use a custom SplitType. Any custom SplitType class must implement the following signature:
// NOTE: the custom SplitType class must take two template parameters: the
// bound type and the matrix type used by the tree.
template<typename BoundType, typename MatType>
class SplitType
{
 public:
  // The SplitType class must provide a SplitInfo struct that will contain the
  // information necessary to perform a split. There are no required members
  // here; the BinarySpaceTree class merely passes these around in the
  // SplitNode() and PerformSplit() functions (see below).
  struct SplitInfo { };

  // Given that a node contains the points
  // `data.cols(begin, begin + count - 1)`, determine whether the node should be
  // split. If so, `true` should be returned and `splitInfo` should be set with
  // the necessary information so that `PerformSplit()` can actually perform the
  // split.
  //
  // If the node should not be split, `false` should be returned, and
  // `splitInfo` is ignored.
  static bool SplitNode(const BoundType& bound,
                        MatType& data,
                        const size_t begin,
                        const size_t count,
                        SplitInfo& splitInfo);

  // Perform the split using the `splitInfo` object, which was populated by a
  // previous call to `SplitNode()`. This should reorder the points in the
  // subset `data.cols(begin, begin + count - 1)` such that the points for the
  // left child come first, and then the points for the right child come last.
  //
  // This should return the index of the first point that goes to the right
  // child. This is equivalent to `begin + leftPoints` where `leftPoints` is
  // the number of points that went to the left child. Very specifically, on
  // exit,
  //
  //  `data.cols(begin, begin + leftPoints - 1)` should contain only points
  //      that will go to the left child;
  //  `data.cols(begin + leftPoints, begin + count - 1)` should contain only
  //      points that will go to the right child;
  //  the value `begin + leftPoints` should be returned.
  //
  // `oldFromNew` maps each column index of the reordered `data` to the
  // original index of that point; whenever two columns of `data` are swapped,
  // the corresponding two entries of `oldFromNew` must be swapped as well.
  static size_t PerformSplit(MatType& data,
                             const size_t begin,
                             const size_t count,
                             const SplitInfo& splitInfo,
                             std::vector<size_t>& oldFromNew);
};
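For illustration, here is a hypothetical `FirstDimensionMeanSplit` that follows the shape of the interface above: it splits on the mean of dimension 0 and keeps `oldFromNew` consistent while reordering. So that it compiles without Armadillo, `ToyMat` is a minimal column-major stand-in providing only the `n_rows`, `operator()(row, col)`, and `swap_cols()` members this sketch uses (Armadillo matrices provide members with the same names). Treat it as a sketch under those assumptions rather than a tested mlpack split.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Minimal column-major matrix stand-in (hypothetical, for this sketch only).
struct ToyMat
{
  std::size_t n_rows, n_cols;
  std::vector<double> mem;

  ToyMat(std::size_t rows, std::size_t cols) :
      n_rows(rows), n_cols(cols), mem(rows * cols, 0.0) { }

  double& operator()(std::size_t r, std::size_t c) { return mem[c * n_rows + r]; }
  double operator()(std::size_t r, std::size_t c) const { return mem[c * n_rows + r]; }

  void swap_cols(std::size_t a, std::size_t b)
  {
    for (std::size_t r = 0; r < n_rows; ++r)
      std::swap((*this)(r, a), (*this)(r, b));
  }
};

// Hypothetical split: points below the mean of dimension 0 go left.
template<typename BoundType, typename MatType>
class FirstDimensionMeanSplit
{
 public:
  struct SplitInfo { double splitVal = 0.0; };

  static bool SplitNode(const BoundType& /* bound */, MatType& data,
                        const std::size_t begin, const std::size_t count,
                        SplitInfo& splitInfo)
  {
    if (count < 2)
      return false;

    double sum = 0.0, lo = data(0, begin), hi = data(0, begin);
    for (std::size_t i = begin; i < begin + count; ++i)
    {
      sum += data(0, i);
      lo = std::min(lo, data(0, i));
      hi = std::max(hi, data(0, i));
    }
    if (lo == hi)
      return false;  // All points are equal in dimension 0; cannot split.

    splitInfo.splitVal = sum / count;
    // Guard against rounding leaving the left child empty.
    return splitInfo.splitVal > lo;
  }

  static std::size_t PerformSplit(MatType& data, const std::size_t begin,
                                  const std::size_t count,
                                  const SplitInfo& splitInfo,
                                  std::vector<std::size_t>& oldFromNew)
  {
    // Invariant: columns [begin, left) belong to the left child.
    std::size_t left = begin;
    for (std::size_t i = begin; i < begin + count; ++i)
    {
      if (data(0, i) < splitInfo.splitVal)
      {
        data.swap_cols(left, i);
        std::swap(oldFromNew[left], oldFromNew[i]);
        ++left;
      }
    }
    return left;  // Index of the first point in the right child.
  }
};
```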
🔗 Example usage
The BinarySpaceTree class is only really necessary when a custom bound type or custom splitting strategy is intended to be used. For simpler use cases, one of the typedefs of BinarySpaceTree (such as KDTree) will suffice.
For this reason, all of the examples below explicitly specify all five template parameters of BinarySpaceTree.
Writing a custom bound type and writing a custom splitting strategy are discussed in the previous sections. Each of the parameters in the examples below can be trivially changed for different behavior.
Build a BinarySpaceTree on the cloud dataset and print basic statistics about the tree.
// See https://datasets.mlpack.org/cloud.csv.
arma::mat dataset;
mlpack::data::Load("cloud.csv", dataset, true);
// Build the binary space tree with a leaf size of 10. (This means that nodes
// are split until they contain 10 or fewer points.)
//
// The std::move() means that `dataset` will be empty after this call, and no
// data will be copied during tree building.
mlpack::BinarySpaceTree<mlpack::EuclideanDistance,
mlpack::EmptyStatistic,
arma::mat,
mlpack::HRectBound,
mlpack::MidpointSplit> tree(std::move(dataset), 10);
// Print the bounding box of the root node.
std::cout << "Bounding box of root node:" << std::endl;
for (size_t i = 0; i < tree.Bound().Dim(); ++i)
{
std::cout << " - Dimension " << i << ": [" << tree.Bound()[i].Lo() << ", "
<< tree.Bound()[i].Hi() << "]." << std::endl;
}
std::cout << std::endl;
// Print the number of descendant points of the root, and of each of its
// children.
std::cout << "Descendant points of root: "
<< tree.NumDescendants() << "." << std::endl;
std::cout << "Descendant points of left child: "
<< tree.Left()->NumDescendants() << "." << std::endl;
std::cout << "Descendant points of right child: "
<< tree.Right()->NumDescendants() << "." << std::endl;
std::cout << std::endl;
// Compute the center of the BinarySpaceTree.
arma::vec center;
tree.Center(center);
std::cout << "Center of tree: " << center.t();
Build two BinarySpaceTrees on subsets of the corel dataset and compute minimum and maximum distances between different nodes in the trees.
// See https://datasets.mlpack.org/corel-histogram.csv.
arma::mat dataset;
mlpack::data::Load("corel-histogram.csv", dataset, true);
// Convenience typedef for the tree type.
using TreeType = mlpack::BinarySpaceTree<mlpack::EuclideanDistance,
mlpack::EmptyStatistic,
arma::mat,
mlpack::HRectBound,
mlpack::MidpointSplit>;
// Build trees on the first half and the second half of points.
TreeType tree1(dataset.cols(0, dataset.n_cols / 2));
TreeType tree2(dataset.cols(dataset.n_cols / 2 + 1, dataset.n_cols - 1));
// Compute the maximum distance between the trees.
std::cout << "Maximum distance between tree root nodes: "
<< tree1.MaxDistance(tree2) << "." << std::endl;
// Get the leftmost grandchild of the first tree's root---if it exists.
if (!tree1.IsLeaf() && !tree1.Child(0).IsLeaf())
{
TreeType& node1 = tree1.Child(0).Child(0);
// Get the rightmost grandchild of the second tree's root---if it exists.
if (!tree2.IsLeaf() && !tree2.Child(1).IsLeaf())
{
TreeType& node2 = tree2.Child(1).Child(1);
// Print the minimum and maximum distance between the nodes.
mlpack::Range dists = node1.RangeDistance(node2);
std::cout << "Possible distances between two grandchild nodes: ["
<< dists.Lo() << ", " << dists.Hi() << "]." << std::endl;
// Print the minimum distance between the first node and the first
// descendant point of the second node.
const size_t descendantIndex = node2.Descendant(0);
const double descendantMinDist =
node1.MinDistance(node2.Dataset().col(descendantIndex));
std::cout << "Minimum distance between grandchild node and descendant "
<< "point: " << descendantMinDist << "." << std::endl;
// Which child of node2 is closer to node1?
const size_t closerIndex = node2.GetNearestChild(node1);
if (closerIndex == 0)
std::cout << "The left child of node2 is closer to node1." << std::endl;
else if (closerIndex == 1)
std::cout << "The right child of node2 is closer to node1." << std::endl;
else // closerIndex == 2 in this case.
std::cout << "Both children of node2 are equally close to node1."
<< std::endl;
// And which child of node1 is further from node2?
const size_t furtherIndex = node1.GetFurthestChild(node2);
if (furtherIndex == 0)
std::cout << "The left child of node1 is further from node2."
<< std::endl;
else if (furtherIndex == 1)
std::cout << "The right child of node1 is further from node2."
<< std::endl;
else // furtherIndex == 2 in this case.
std::cout << "Both children of node1 are equally far from node2."
<< std::endl;
}
}
Build a BinarySpaceTree on 32-bit floating point data and save it to disk.
// See https://datasets.mlpack.org/corel-histogram.csv.
arma::fmat dataset;
mlpack::data::Load("corel-histogram.csv", dataset);
// Build the BinarySpaceTree using 32-bit floating point data as the matrix
// type. We will still use the default EmptyStatistic and EuclideanDistance
// parameters. A leaf size of 100 is used here.
mlpack::BinarySpaceTree<mlpack::EuclideanDistance,
mlpack::EmptyStatistic,
arma::fmat,
mlpack::HRectBound,
mlpack::MidpointSplit> tree(std::move(dataset), 100);
// Save the tree to disk with the name 'tree'.
mlpack::data::Save("tree.bin", "tree", tree);
std::cout << "Saved tree with " << tree.Dataset().n_cols << " points to "
<< "'tree.bin'." << std::endl;
Load a 32-bit floating point BinarySpaceTree from disk, then traverse it manually and find the number of leaf nodes with fewer than 10 points.
// This assumes the tree has already been saved to 'tree.bin' (as in the example
// above).
// This convenient typedef saves us a long type name!
using TreeType = mlpack::BinarySpaceTree<mlpack::EuclideanDistance,
mlpack::EmptyStatistic,
arma::fmat,
mlpack::HRectBound,
mlpack::MidpointSplit>;
TreeType tree;
mlpack::data::Load("tree.bin", "tree", tree);
std::cout << "Tree loaded with " << tree.NumDescendants() << " points."
<< std::endl;
// Recurse in a depth-first manner. Count both the total number of leaves, and
// the number of leaves with fewer than 10 points.
size_t leafCount = 0;
size_t totalLeafCount = 0;
std::stack<TreeType*> stack;
stack.push(&tree);
while (!stack.empty())
{
  TreeType* node = stack.top();
  stack.pop();

  if (node->IsLeaf())
  {
    // Only leaves hold points directly (NumPoints() is 0 for internal nodes),
    // so only leaves are counted.
    ++totalLeafCount;
    if (node->NumPoints() < 10)
      ++leafCount;
  }
  else
  {
    stack.push(node->Left());
    stack.push(node->Right());
  }
}
// Note that it would be possible to use TreeType::SingleTreeTraverser to
// perform the recursion above, but that is more well-suited for more complex
// tasks that require pruning and other non-trivial behavior; so using a simple
// stack is the better option here.
// Print the results.
std::cout << leafCount << " out of " << totalLeafCount << " leaves have fewer "
<< "than 10 points." << std::endl;
Build a BinarySpaceTree and map between original points and new points.
// See https://datasets.mlpack.org/cloud.csv.
arma::mat dataset;
mlpack::data::Load("cloud.csv", dataset, true);
// Build the tree.
std::vector<size_t> oldFromNew, newFromOld;
mlpack::BinarySpaceTree<mlpack::EuclideanDistance,
mlpack::EmptyStatistic,
arma::mat,
mlpack::HRectBound,
mlpack::MidpointSplit> tree(
dataset, oldFromNew, newFromOld);
// oldFromNew and newFromOld will be set to the same size as the dataset.
std::cout << "Number of points in dataset: " << dataset.n_cols << "."
<< std::endl;
std::cout << "Size of oldFromNew: " << oldFromNew.size() << "." << std::endl;
std::cout << "Size of newFromOld: " << newFromOld.size() << "." << std::endl;
std::cout << std::endl;
// See where point 42 in the tree's dataset came from.
std::cout << "Point 42 in the permuted tree's dataset:" << std::endl;
std::cout << " " << tree.Dataset().col(42).t();
std::cout << "Was originally point " << oldFromNew[42] << ":" << std::endl;
std::cout << " " << dataset.col(oldFromNew[42]).t();
std::cout << std::endl;
// See where point 7 in the original dataset was mapped.
std::cout << "Point 7 in original dataset:" << std::endl;
std::cout << " " << dataset.col(7).t();
std::cout << "Mapped to point " << newFromOld[7] << ":" << std::endl;
std::cout << " " << tree.Dataset().col(newFromOld[7]).t();