Source code for cleanlab.internal.neighbor.search
from __future__ import annotations
from typing import TYPE_CHECKING
from sklearn.neighbors import NearestNeighbors
if TYPE_CHECKING:
    from cleanlab.typing import Metric
def construct_knn(n_neighbors: int, metric: Metric, **knn_kwargs) -> NearestNeighbors:
    """
    Constructs a k-nearest neighbors search object (an unfitted scikit-learn `NearestNeighbors` instance).
    To run cleanlab with your own approximate-KNN library, implement a similar function that returns an object with the interface described under Returns.

    Parameters
    ----------
    n_neighbors :
        The number of nearest neighbors to consider.
    metric :
        The distance metric to use for computing distances between points.
        See :py:mod:`~cleanlab.internal.neighbor.metric` for more information.
    **knn_kwargs:
        Additional keyword arguments to be passed to the search index constructor.
        See https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.NearestNeighbors.html for more details on the available options.

    Returns
    -------
    knn :
        A k-nearest neighbors search object compatible with the scikit-learn NearestNeighbors class interface.

        Implements:

        - `fit` method: Accepts a feature array `X` and builds the search index,
          enabling subsequent neighbor searches on that data.
        - `kneighbors` method: Finds the k nearest neighbors of each query point and returns their distances and indices. Handles two scenarios:

          1. When a query array `features: np.ndarray` is provided, it returns the distances and indices of the neighbors of each point in the query array.
          2. When no query array is provided (`features = None`), it returns the neighbors of each indexed point, excluding each point from its own neighbor list.

          An optional `n_neighbors` argument overrides the number of neighbors returned by that call; it defaults to the value passed to the constructor.

        Attributes:

        - `n_neighbors`: Number of neighbors to consider.
        - `metric`: Distance metric used to compute distances between points.
        - `metric_params`: Additional parameters for the distance metric function.

        Optional:

        - `kneighbors_graph` method: Not required; constructing the KNN graph is handled by
          :py:ref:`construct_knn_graph_from_index <cleanlab.internal.neighbor.neighbor.construct_knn_graph_from_index>` instead.

        Fitted Attributes:

        - `n_features_in_`: Number of features observed during fit.
        - `effective_metric_params_`: Metric parameters used in distance computation.
        - `effective_metric_`: Metric used for computing distances to neighbors.
        - `n_samples_fit_`: Number of samples in the fitted data.

        Additional:

        - `__sklearn_is_fitted__`: Method returning a boolean that indicates whether the object is fitted.
          When it is defined, scikit-learn's `check_is_fitted` calls it instead of looking for fitted attributes (names ending with a trailing underscore).

        Any alternative k-nearest neighbors implementation that meets this specification, such as a wrapper around an approximate-KNN library, can be used in place of the scikit-learn object returned here.

    Note
    ----
    The `metric` argument can be a string naming a metric supported by scikit-learn (for example, `"euclidean"` or `"cosine"`) or a callable that takes two arguments (the two points) and returns the distance between them.
    The additional keyword arguments (`**knn_kwargs`) are passed unchanged to the `NearestNeighbors` constructor.
    """
    sklearn_knn = NearestNeighbors(n_neighbors=n_neighbors, metric=metric, **knn_kwargs)
    return sklearn_knn
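The two `kneighbors` scenarios in the docstring can be seen directly on the object `construct_knn` returns. The sketch below builds the same `NearestNeighbors` instance with scikit-learn (without importing cleanlab); `n_neighbors=2`, `metric="euclidean"`, and the four-point toy dataset are illustrative choices, picked so that no two distances tie.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.utils.validation import check_is_fitted

# Same object construct_knn(n_neighbors=2, metric="euclidean") would return:
# an unfitted scikit-learn NearestNeighbors instance.
knn = NearestNeighbors(n_neighbors=2, metric="euclidean")

# Four 1-D points at 0, 1, 3 and 7; fit() builds the search index.
X = np.array([[0.0], [1.0], [3.0], [7.0]])
knn.fit(X)
check_is_fitted(knn)  # raises NotFittedError only if fit() was never called

# Fitted attributes listed in the docstring.
print(knn.n_samples_fit_, knn.n_features_in_)  # 4 1

# Scenario 1: explicit query array. A query that coincides with an
# indexed point gets that point back as its nearest neighbor (distance 0).
dist_q, idx_q = knn.kneighbors(np.array([[0.0]]))
print(idx_q.tolist())  # [[0, 1]]

# Scenario 2: no query array. Each indexed point is excluded
# from its own neighbor list.
dist_all, idx_all = knn.kneighbors()
print(idx_all.tolist())  # [[1, 2], [0, 2], [1, 0], [2, 1]]

# Per-call override of the number of neighbors.
idx_one = knn.kneighbors(n_neighbors=1, return_distance=False)
print(idx_one.tolist())  # [[1], [0], [1], [2]]
```

In scenario 2, the point at 3 lists index 1 (distance 2) before index 0 (distance 3), and never lists itself.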
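The docstring suggests writing a similar function to plug in your own KNN library. A minimal sketch of such a replacement follows. `BruteForceKNN` and `construct_custom_knn` are hypothetical names (not part of cleanlab or scikit-learn). The class does exact brute-force search in NumPy and supports only `"euclidean"` or a callable metric; an approximate-KNN library would replace the distance computation and sort. The query argument is named `X` as in scikit-learn, whereas the docstring calls it `features`.

```python
from __future__ import annotations

from typing import Callable, Optional, Union

import numpy as np


class BruteForceKNN:
    """Hypothetical exact-search index mimicking the NearestNeighbors interface."""

    def __init__(
        self,
        n_neighbors: int = 5,
        metric: Union[str, Callable] = "euclidean",
        metric_params: Optional[dict] = None,
    ):
        # Constructor attributes required by the interface.
        self.n_neighbors = n_neighbors
        self.metric = metric
        self.metric_params = metric_params

    def _pairwise(self, A: np.ndarray, B: np.ndarray) -> np.ndarray:
        """Distance matrix of shape (len(A), len(B))."""
        if isinstance(self.metric, str) and self.metric == "euclidean":
            # Broadcast to (len(A), len(B), n_features), then reduce over features.
            return np.sqrt(((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1))
        if callable(self.metric):
            params = self.metric_params or {}
            return np.array([[self.metric(a, b, **params) for b in B] for a in A], dtype=float)
        raise ValueError(f"Unsupported metric in this sketch: {self.metric!r}")

    def fit(self, X) -> BruteForceKNN:
        self._fit_X = np.asarray(X, dtype=float)
        # Fitted attributes (trailing underscore), set only by fit().
        self.n_samples_fit_, self.n_features_in_ = self._fit_X.shape
        self.effective_metric_ = self.metric
        self.effective_metric_params_ = dict(self.metric_params or {})
        return self

    def kneighbors(self, X=None, n_neighbors: Optional[int] = None, return_distance: bool = True):
        # Per-call override of the number of neighbors.
        k = self.n_neighbors if n_neighbors is None else n_neighbors
        if X is None:
            # Scenario 2: query the indexed points themselves and
            # exclude each point from its own neighbor list.
            dist = self._pairwise(self._fit_X, self._fit_X)
            np.fill_diagonal(dist, np.inf)
        else:
            # Scenario 1: query an explicit array of points.
            dist = self._pairwise(np.asarray(X, dtype=float), self._fit_X)
        # Stable sort keeps lower indices first when distances tie.
        idx = np.argsort(dist, axis=1, kind="stable")[:, :k]
        neigh_dist = np.take_along_axis(dist, idx, axis=1)
        return (neigh_dist, idx) if return_distance else idx

    def __sklearn_is_fitted__(self) -> bool:
        # True once fit() has set the fitted attributes.
        return hasattr(self, "n_samples_fit_")


def construct_custom_knn(n_neighbors: int, metric, **knn_kwargs) -> BruteForceKNN:
    """Hypothetical analogue of construct_knn backed by BruteForceKNN."""
    return BruteForceKNN(n_neighbors=n_neighbors, metric=metric, **knn_kwargs)


# Usage on four 1-D points at 0, 1, 3 and 7.
knn = construct_custom_knn(n_neighbors=2, metric="euclidean")
knn.fit(np.array([[0.0], [1.0], [3.0], [7.0]]))
dist, idx = knn.kneighbors()
print(idx.tolist())  # [[1, 2], [0, 2], [1, 0], [2, 1]]
```

Setting the diagonal of the self-distance matrix to `np.inf` reproduces the "not its own neighbor" rule for the no-query case, and `kneighbors_graph` is left out because the docstring lists it as optional.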