 | ~ball_tree_neighbors ()
void | init_options (const std::map< std::string, flexible_type > &_opts) override
void | train (const sframe &X, const std::vector< flexible_type > &ref_labels, const std::vector< dist_component_type > &composite_distance_params, const std::map< std::string, flexible_type > &opts) override
sframe | query (const v2::ml_data &mld_queries, const std::vector< flexible_type > &query_labels, const size_t k, const double radius, const bool include_self_edges) const override
size_t | get_version () const override
void | save_impl (turi::oarchive &oarc) const override
void | load_version (turi::iarchive &iarc, size_t version) override

Ball tree nearest neighbors class
Implements the ball tree method for k-nearest neighbors search.
The ball tree works by partitioning the reference data into successively smaller balls and recording the center (i.e. pivot) and radius of each ball. A ball tree query uses the pivots and radii to exclude many of the balls from the k-nearest neighbor search, allowing it to run in sub-linear time.
In addition to the objects contained in the nearest_neighbors_model base class, the ball tree contains the following:
- membership: Each element of this vector indicates which node the corresponding reference data point belongs to. After the tree is constructed, the elements in this vector correspond to leaf nodes of the tree only.
- pivots: The reference data point at the center of each tree node.
- node_radii: The distance from the pivot of each node to the most distant reference point belonging to the tree node.
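The three structures listed above can be illustrated with a minimal sketch. It is not the code in ball_tree_neighbors.hpp: the names `toy_ball_tree`, `build_node`, and `build_toy_ball_tree`, the `left`/`right` child arrays, the Euclidean distance, the first-member pivot choice, and the farthest-pair split rule are all assumptions made for illustration.

```cpp
#include <cmath>
#include <cstddef>
#include <numeric>
#include <vector>

using point = std::vector<double>;

// Euclidean distance; the real model supports composite distances.
double euclidean(const point& a, const point& b) {
  double s = 0.0;
  for (std::size_t i = 0; i < a.size(); ++i) s += (a[i] - b[i]) * (a[i] - b[i]);
  return std::sqrt(s);
}

// Hypothetical container mirroring the three members described above.
struct toy_ball_tree {
  std::vector<std::size_t> membership;  // node id of each reference point
  std::vector<std::size_t> pivots;      // reference point index used as pivot, per node
  std::vector<double> node_radii;       // max distance from pivot to a member, per node
  std::vector<int> left, right;         // child node ids, -1 for leaves (sketch-only)
};

// Recursively builds the node that holds `members` (must be non-empty)
// and returns its id.
int build_node(toy_ball_tree& t, const std::vector<point>& X,
               const std::vector<std::size_t>& members, std::size_t leaf_size) {
  const int id = static_cast<int>(t.pivots.size());

  // Assumption: the first member serves as the pivot.
  const std::size_t pivot = members[0];
  double radius = 0.0;
  std::size_t far_a = pivot;
  for (std::size_t m : members) {
    const double d = euclidean(X[pivot], X[m]);
    if (d > radius) { radius = d; far_a = m; }
  }
  t.pivots.push_back(pivot);
  t.node_radii.push_back(radius);
  t.left.push_back(-1);
  t.right.push_back(-1);

  // Small or degenerate nodes become leaves and record membership.
  if (members.size() <= leaf_size || radius == 0.0) {
    for (std::size_t m : members) t.membership[m] = static_cast<std::size_t>(id);
    return id;
  }

  // Assumed split rule: take the point farthest from far_a as a second
  // anchor and assign every member to the nearer of the two anchors.
  std::size_t far_b = far_a;
  double best = -1.0;
  for (std::size_t m : members) {
    const double d = euclidean(X[far_a], X[m]);
    if (d > best) { best = d; far_b = m; }
  }
  std::vector<std::size_t> near_a, near_b;
  for (std::size_t m : members) {
    const bool closer_to_a = euclidean(X[m], X[far_a]) <= euclidean(X[m], X[far_b]);
    (closer_to_a ? near_a : near_b).push_back(m);
  }

  const int l = build_node(t, X, near_a, leaf_size);
  const int r = build_node(t, X, near_b, leaf_size);
  t.left[id] = l;
  t.right[id] = r;
  return id;
}

toy_ball_tree build_toy_ball_tree(const std::vector<point>& X, std::size_t leaf_size) {
  toy_ball_tree t;
  t.membership.assign(X.size(), 0);
  if (X.empty()) return t;
  std::vector<std::size_t> all(X.size());
  std::iota(all.begin(), all.end(), std::size_t{0});
  build_node(t, X, all, leaf_size);
  return t;
}
```

Once construction finishes, every entry of `membership` refers to a node without children, which matches the description of the membership vector above. A production tree would typically pick a more central pivot; the first-member choice only keeps the sketch short.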
Definition at line 46 of file ball_tree_neighbors.hpp.
sframe turi::nearest_neighbors::ball_tree_neighbors::query (const v2::ml_data &mld_queries, const std::vector< flexible_type > &query_labels, const size_t k, const double radius, const bool include_self_edges) const override
Find neighbors of queries in a created ball tree model.
For each query, the method keeps track of the current k-nearest neighbors in the ball tree. At each node, the distance from the query to the closest possible point in each child node is computed, and if this distance is greater than the distance to the current k-th nearest neighbor, that child node (and its descendants) is skipped in the traversal.
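The pruning rule described above can be sketched as follows. This is an illustration under stated assumptions, not the library's traversal: `toy_node`, `lower_bound_dist`, `search`, and `toy_query` are hypothetical names, the distance is Euclidean, and passing infinity as `radius` is assumed to mean "no distance limit".

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <queue>
#include <utility>
#include <vector>

using point = std::vector<double>;

double euclidean(const point& a, const point& b) {
  double s = 0.0;
  for (std::size_t i = 0; i < a.size(); ++i) s += (a[i] - b[i]) * (a[i] - b[i]);
  return std::sqrt(s);
}

// Hypothetical node layout: a ball (pivot, radius), two children, and,
// for leaves only, the indices of the reference points it holds.
struct toy_node {
  point pivot;
  double radius = 0.0;
  int left = -1, right = -1;
  std::vector<std::size_t> members;
};

using candidate = std::pair<double, std::size_t>;  // (distance, reference index)
using max_heap = std::priority_queue<candidate>;   // top() is the current k-th nearest

// Smallest distance any point inside the node's ball can have to q.
double lower_bound_dist(const toy_node& n, const point& q) {
  return std::max(0.0, euclidean(n.pivot, q) - n.radius);
}

void search(const std::vector<toy_node>& tree, const std::vector<point>& X, int id,
            const point& q, std::size_t k, double radius,
            max_heap& best, std::size_t& visited) {
  const toy_node& n = tree[id];
  const double bound = lower_bound_dist(n, q);
  // Skip the node and its descendants if nothing inside can qualify.
  if (bound > radius) return;
  if (best.size() == k && bound >= best.top().first) return;
  ++visited;

  if (n.left < 0) {  // leaf: test every member against the current best k
    for (std::size_t i : n.members) {
      const double d = euclidean(X[i], q);
      if (d > radius) continue;
      if (best.size() < k) {
        best.emplace(d, i);
      } else if (d < best.top().first) {
        best.pop();
        best.emplace(d, i);
      }
    }
    return;
  }

  // Visit the more promising child first so the k-th distance shrinks early.
  int first = n.left, second = n.right;
  if (lower_bound_dist(tree[second], q) < lower_bound_dist(tree[first], q)) {
    std::swap(first, second);
  }
  search(tree, X, first, q, k, radius, best, visited);
  search(tree, X, second, q, k, radius, best, visited);
}

// Returns up to k (distance, reference index) pairs, nearest first.
// Node 0 is assumed to be the root.
std::vector<candidate> toy_query(const std::vector<toy_node>& tree,
                                 const std::vector<point>& X, const point& q,
                                 std::size_t k, double radius,
                                 std::size_t* visited_out = nullptr) {
  std::vector<candidate> out;
  std::size_t visited = 0;
  if (!tree.empty() && k > 0) {
    max_heap best;
    search(tree, X, 0, q, k, radius, best, visited);
    while (!best.empty()) { out.push_back(best.top()); best.pop(); }
    std::reverse(out.begin(), out.end());
  }
  if (visited_out) *visited_out = visited;
  return out;
}
```

The position of each pair in the returned vector, plus one, plays the role of the rank column in the result described for this method. The optional `visited_out` counter is included only to make the effect of pruning observable.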
- Parameters
-
[in] | mld_queries | query data |
[in] | query_labels | std::vector< flexible_type > query labels |
[in] | k | size_t max number of neighbors to return for each query |
[in] | radius | double max distance for returned neighbors to each query |
[out] | ret | sframe SFrame with four columns: query label, reference label, distance, and rank. |
- Note
- Assumes that data is already in the right shape.