Skip to content

Commit aa1fba9

Browse files
authored
MNT Improve Nearest Neighbor documentation + code consistency (#19793)
1 parent c957eb3 commit aa1fba9

File tree

1 file changed

+27
-32
lines changed

1 file changed

+27
-32
lines changed

sklearn/neighbors/_base.py

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -444,20 +444,19 @@ def _fit(self, X, y=None):
444444
self.n_samples_fit_ = X.data.shape[0]
445445
return self
446446

447-
if self.effective_metric_ == 'precomputed':
447+
if self.metric == 'precomputed':
448448
X = _check_precomputed(X)
449+
# Precomputed matrix X must be squared
450+
if X.shape[0] != X.shape[1]:
451+
raise ValueError("Precomputed matrix must be square."
452+
" Input is a {}x{} matrix."
453+
.format(X.shape[0], X.shape[1]))
449454
self.n_features_in_ = X.shape[1]
450455

451456
n_samples = X.shape[0]
452457
if n_samples == 0:
453458
raise ValueError("n_samples must be greater than 0")
454459

455-
# Precomputed matrix X must be squared
456-
if self.metric == 'precomputed' and X.shape[0] != X.shape[1]:
457-
raise ValueError("Precomputed matrix must be a square matrix."
458-
" Input is a {}x{} matrix."
459-
.format(X.shape[0], X.shape[1]))
460-
461460
if issparse(X):
462461
if self.algorithm not in ('auto', 'brute'):
463462
warnings.warn("cannot use tree with sparse input: "
@@ -514,14 +513,12 @@ def _fit(self, X, y=None):
514513
if self.n_neighbors <= 0:
515514
raise ValueError(
516515
"Expected n_neighbors > 0. Got %d" %
517-
self.n_neighbors
518-
)
519-
else:
520-
if not isinstance(self.n_neighbors, numbers.Integral):
521-
raise TypeError(
522-
"n_neighbors does not take %s value, "
523-
"enter integer value" %
524-
type(self.n_neighbors))
516+
self.n_neighbors)
517+
elif not isinstance(self.n_neighbors, numbers.Integral):
518+
raise TypeError(
519+
"n_neighbors does not take %s value, "
520+
"enter integer value" %
521+
type(self.n_neighbors))
525522

526523
return self
527524

@@ -654,18 +651,16 @@ class from an array representing our data set and ask who's
654651
elif n_neighbors <= 0:
655652
raise ValueError(
656653
"Expected n_neighbors > 0. Got %d" %
657-
n_neighbors
658-
)
659-
else:
660-
if not isinstance(n_neighbors, numbers.Integral):
661-
raise TypeError(
662-
"n_neighbors does not take %s value, "
663-
"enter integer value" %
664-
type(n_neighbors))
654+
n_neighbors)
655+
elif not isinstance(n_neighbors, numbers.Integral):
656+
raise TypeError(
657+
"n_neighbors does not take %s value, "
658+
"enter integer value" %
659+
type(n_neighbors))
665660

666661
if X is not None:
667662
query_is_train = False
668-
if self.effective_metric_ == 'precomputed':
663+
if self.metric == 'precomputed':
669664
X = _check_precomputed(X)
670665
else:
671666
X = self._validate_data(X, accept_sparse='csr', reset=False)
@@ -687,7 +682,7 @@ class from an array representing our data set and ask who's
687682
n_jobs = effective_n_jobs(self.n_jobs)
688683
chunked_results = None
689684
if (self._fit_method == 'brute' and
690-
self.effective_metric_ == 'precomputed' and issparse(X)):
685+
self.metric == 'precomputed' and issparse(X)):
691686
results = _kneighbors_from_graph(
692687
X, n_neighbors=n_neighbors,
693688
return_distance=return_distance)
@@ -793,8 +788,8 @@ def kneighbors_graph(self, X=None, n_neighbors=None,
793788
Returns
794789
-------
795790
A : sparse-matrix of shape (n_queries, n_samples_fit)
796-
`n_samples_fit` is the number of samples in the fitted data
797-
`A[i, j]` is assigned the weight of edge that connects `i` to `j`.
791+
`n_samples_fit` is the number of samples in the fitted data.
792+
`A[i, j]` gives the weight of the edge connecting `i` to `j`.
798793
The matrix is of CSR format.
799794
800795
Examples
@@ -980,7 +975,7 @@ class from an array representing our data set and ask who's
980975

981976
if X is not None:
982977
query_is_train = False
983-
if self.effective_metric_ == 'precomputed':
978+
if self.metric == 'precomputed':
984979
X = _check_precomputed(X)
985980
else:
986981
X = self._validate_data(X, accept_sparse='csr', reset=False)
@@ -992,7 +987,7 @@ class from an array representing our data set and ask who's
992987
radius = self.radius
993988

994989
if (self._fit_method == 'brute' and
995-
self.effective_metric_ == 'precomputed' and issparse(X)):
990+
self.metric == 'precomputed' and issparse(X)):
996991
results = _radius_neighbors_from_graph(
997992
X, radius=radius, return_distance=return_distance)
998993

@@ -1116,9 +1111,9 @@ def radius_neighbors_graph(self, X=None, radius=None, mode='connectivity',
11161111
Returns
11171112
-------
11181113
A : sparse-matrix of shape (n_queries, n_samples_fit)
1119-
`n_samples_fit` is the number of samples in the fitted data
1120-
`A[i, j]` is assigned the weight of edge that connects `i` to `j`.
1121-
The matrix if of format CSR.
1114+
`n_samples_fit` is the number of samples in the fitted data.
1115+
`A[i, j]` gives the weight of the edge connecting `i` to `j`.
1116+
The matrix is of CSR format.
11221117
11231118
Examples
11241119
--------

0 commit comments

Comments
 (0)