Skip to content

Rewrite pairwise to remove concatenation from blockwise #447

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Feb 8, 2021
103 changes: 83 additions & 20 deletions sgkit/distance/api.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import typing

import dask.array as da
import numpy as np
from typing_extensions import Literal

from sgkit.distance import metrics
from sgkit.typing import ArrayLike

MetricTypes = Literal["euclidean", "correlation"]


def pairwise_distance(
x: ArrayLike,
metric: str = "euclidean",
) -> np.ndarray:
metric: MetricTypes = "euclidean",
split_every: typing.Optional[int] = None,
) -> da.array:
"""Calculates the pairwise distance between all pairs of row vectors in the
given two dimensional array x.

Expand Down Expand Up @@ -47,6 +53,13 @@ def pairwise_distance(
metric
The distance metric to use. The distance function can be
'euclidean' or 'correlation'.
split_every
Determines the depth of the recursive aggregation in the reduction
step. This argument is directly passed to the call to``dask.reduction``
function in the reduce step of the map reduce.

Omit to let dask heuristically decide a good default. A default can
also be set globally with the split_every key in dask.config.

Returns
-------
Expand All @@ -63,41 +76,91 @@ def pairwise_distance(
>>> from sgkit.distance.api import pairwise_distance
>>> import dask.array as da
>>> x = da.array([[6, 4, 1,], [4, 5, 2], [9, 7, 3]]).rechunk(2, 2)
>>> pairwise_distance(x, metric='euclidean')
>>> pairwise_distance(x, metric='euclidean').compute()
array([[0. , 2.44948974, 4.69041576],
[2.44948974, 0. , 5.47722558],
[4.69041576, 5.47722558, 0. ]])

>>> import numpy as np
>>> x = np.array([[6, 4, 1,], [4, 5, 2], [9, 7, 3]])
>>> pairwise_distance(x, metric='euclidean')
>>> pairwise_distance(x, metric='euclidean').compute()
array([[0. , 2.44948974, 4.69041576],
[2.44948974, 0. , 5.47722558],
[4.69041576, 5.47722558, 0. ]])

>>> x = np.array([[6, 4, 1,], [4, 5, 2], [9, 7, 3]])
>>> pairwise_distance(x, metric='correlation')
array([[1.11022302e-16, 2.62956526e-01, 2.82353505e-03],
[2.62956526e-01, 0.00000000e+00, 2.14285714e-01],
[2.82353505e-03, 2.14285714e-01, 0.00000000e+00]])
>>> pairwise_distance(x, metric='correlation').compute()
array([[-4.44089210e-16, 2.62956526e-01, 2.82353505e-03],
[ 2.62956526e-01, 0.00000000e+00, 2.14285714e-01],
[ 2.82353505e-03, 2.14285714e-01, 0.00000000e+00]])
"""

try:
metric_ufunc = getattr(metrics, metric)
metric_map_func = getattr(metrics, f"{metric}_map")
metric_reduce_func = getattr(metrics, f"{metric}_reduce")
n_map_param = metrics.N_MAP_PARAM[metric]
except AttributeError:
raise NotImplementedError(f"Given metric: {metric} is not implemented.")

x = da.asarray(x)
x_distance = da.blockwise(
# Lambda wraps reshape for broadcast
lambda _x, _y: metric_ufunc(_x[:, None, :], _y),
"jk",
if x.ndim != 2:
raise ValueError(f"2-dimensional array expected, got '{x.ndim}'")

# setting this variable outside of _pairwise to avoid its recreation
# in every iteration, which eventually leads to increase in dask
# graph serialisation/deserialisation time significantly
metric_param = np.empty(n_map_param, dtype=x.dtype)

def _pairwise(f: ArrayLike, g: ArrayLike) -> ArrayLike:
result: ArrayLike = metric_map_func(f[:, None, :], g, metric_param)
# Adding a new axis to help combine chunks along this axis in the
# reduction step (see the _aggregate and _combine functions below).
return result[..., np.newaxis]

# concatenate in blockwise leads to high memory footprints, so instead
# we perform blockwise without contraction followed by reduction.
# More about this issue: https://github.com/dask/dask/issues/6874
out = da.blockwise(
_pairwise,
"ijk",
x,
"ji",
"ik",
x,
"ki",
dtype="float64",
concatenate=True,
"jk",
dtype=x.dtype,
concatenate=False,
)

def _aggregate(x_chunk: ArrayLike, **_: typing.Any) -> ArrayLike:
"""Last function to be executed when resolving the dask graph,
producing the final output. It is always invoked, even when the reduced
Array counts a single chunk along the reduced axes."""
x_chunk = x_chunk.reshape(x_chunk.shape[:-2] + (-1, n_map_param))
result: ArrayLike = metric_reduce_func(x_chunk)
return result

def _chunk(x_chunk: ArrayLike, **_: typing.Any) -> ArrayLike:
return x_chunk

def _combine(x_chunk: ArrayLike, **_: typing.Any) -> ArrayLike:
"""Function used for intermediate recursive aggregation (see
split_every argument to ``da.reduction below``). If the
reduction can be performed in less than 3 steps, it will
not be invoked at all."""
# reduce chunks by summing along the -2 axis
x_chunk_reshaped = x_chunk.reshape(x_chunk.shape[:-2] + (-1, n_map_param))
return x_chunk_reshaped.sum(axis=-2)[..., np.newaxis]

r = da.reduction(
out,
chunk=_chunk,
combine=_combine,
aggregate=_aggregate,
axis=-1,
dtype=x.dtype,
meta=np.ndarray((0, 0), dtype=x.dtype),
split_every=split_every,
name="pairwise",
)
x_distance = da.triu(x_distance, 1) + da.triu(x_distance).T
return x_distance.compute() # type: ignore[no-any-return]

t = da.triu(r)
return t + t.T
184 changes: 116 additions & 68 deletions sgkit/distance/metrics.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,112 @@
"""This module implements various distance metrics."""
"""
This module implements various distance metrics. To implement a new distance
metric, two methods need to be written, one of them suffixed by 'map' and the other
suffixed by 'reduce'. An entry for the same should be added in the N_MAP_PARAM
dictionary below.
"""

import numpy as np
from numba import guvectorize

from sgkit.typing import ArrayLike

# The number of parameters for the map step of the respective distance metric
N_MAP_PARAM = {
"correlation": 6,
"euclidean": 1,
}


@guvectorize(  # type: ignore
    [
        "void(float32[:], float32[:], float32[:], float32[:])",
        "void(float64[:], float64[:], float64[:], float64[:])",
        "void(int8[:], int8[:], int8[:], float64[:])",
    ],
    "(n),(n),(p)->(p)",
    nopython=True,
    cache=True,
)
def euclidean_map(
    x: ArrayLike, y: ArrayLike, _: ArrayLike, out: ArrayLike
) -> None:  # pragma: no cover
    """Euclidean distance "map" function for partial vector pairs.

    Computes the squared-difference sum for one pair of vector chunks; the
    final square root is applied later in the "reduce" step, so partial
    results from different chunks can simply be summed.

    Parameters
    ----------
    x
        An array chunk, a partial vector.
    y
        Another array chunk, a partial vector.
    _
        A dummy variable whose (p,)-shaped layout only serves to give the
        gufunc its output size; its values are never read.
    out
        The (p,)-shaped output array; every element receives the squared
        sum for this pair of chunks (this gufunc returns nothing itself).
    """
    square_sum = 0.0
    m = x.shape[0]
    # Ignore missing values: entries that are negative in either vector
    # are treated as missing and skipped.
    for i in range(m):
        if x[i] >= 0 and y[i] >= 0:
            square_sum += (x[i] - y[i]) ** 2
    # Broadcast the scalar partial sum across the whole (p,) output slot.
    out[:] = square_sum


def euclidean_reduce(v: ArrayLike) -> ArrayLike:  # pragma: no cover
    """Corresponding "reduce" function for euclidean distance.

    Aggregates the partial squared sums produced by ``euclidean_map`` and
    takes the square root to yield the final pairwise distances.

    Parameters
    ----------
    v
        The euclidean array on which the map step of euclidean distance
        has been applied.

    Returns
    -------
    An ndarray containing the square root of the sum of the squared sums
    obtained from the map step of ``euclidean_map``.
    """
    # Collapse the two trailing axes (chunk axis and map-parameter axis)
    # into the accumulated squared sum, then take the square root.
    aggregated: ArrayLike = v.sum(axis=(-2, -1))
    return np.sqrt(aggregated)


@guvectorize( # type: ignore
[
"void(float32[:], float32[:], float32[:])",
"void(float64[:], float64[:], float64[:])",
"void(int8[:], int8[:], float64[:])",
"void(float32[:], float32[:], float32[:], float32[:])",
"void(float64[:], float64[:], float64[:], float64[:])",
"void(int8[:], int8[:], int8[:], float64[:])",
],
"(n),(n)->()",
"(n),(n),(p)->(p)",
nopython=True,
cache=True,
)
def correlation(x: ArrayLike, y: ArrayLike, out: ArrayLike) -> None: # pragma: no cover
"""Calculates the correlation between two vectors.
def correlation_map(
x: ArrayLike, y: ArrayLike, _: ArrayLike, out: ArrayLike
) -> None: # pragma: no cover
"""Pearson correlation "map" function for partial vector pairs.

Parameters
----------
x
[array-like, shape: (M,)]
A vector
An array chunk, a partial vector
y
[array-like, shape: (M,)]
Another vector
Another array chunk, a partial vector
_
A dummy variable to map the size of output
out
The output array, which has the output of pearson correlation.

Returns
-------
A scalar representing the pearson correlation coefficient between two vectors x and y.

Examples
--------
>>> from sgkit.distance.metrics import correlation
>>> import dask.array as da
>>> import numpy as np
>>> x = da.array([4, 3, 2, 3], dtype='i1')
>>> y = da.array([5, 6, 7, 0], dtype='i1')
>>> correlation(x, y).compute()
1.2626128

>>> correlation(x, x).compute()
-1.1920929e-07
An ndarray, which contains the output of the calculation of the application
of pearson correlation on the given pair of chunks, without the aggregation.
"""

m = x.shape[0]
valid_indices = np.zeros(m, dtype=np.float64)

Expand All @@ -66,61 +126,49 @@ def correlation(x: ArrayLike, y: ArrayLike, out: ArrayLike) -> None: # pragma:
_y[valid_idx] = y[i]
valid_idx += 1

cov = ((_x - _x.mean()) * (_y - _y.mean())).sum()
denom = (_x.std() * _y.std()) / _x.shape[0]

value = np.nan
if denom > 0:
value = 1.0 - (cov / (_x.std() * _y.std()) / _x.shape[0])
out[0] = value
out[:] = np.array(
[
np.sum(_x),
np.sum(_y),
np.sum(_x * _x),
np.sum(_y * _y),
np.sum(_x * _y),
len(_x),
]
)


@guvectorize( # type: ignore
[
"void(float32[:], float32[:], float32[:])",
"void(float64[:], float64[:], float64[:])",
"void(int8[:], int8[:], float64[:])",
"void(float32[:, :], float32[:])",
"void(float64[:, :], float64[:])",
],
"(n),(n)->()",
"(p, m)->()",
nopython=True,
cache=True,
)
def euclidean(x: ArrayLike, y: ArrayLike, out: ArrayLike) -> None: # pragma: no cover
"""Calculates the euclidean distance between two vectors.

def correlation_reduce(v: ArrayLike, out: ArrayLike) -> None: # pragma: no cover
"""Corresponding "reduce" function for pearson correlation
Parameters
----------
x
[array-like, shape: (M,)]
A vector
y
[array-like, shape: (M,)]
Another vector
v
The correlation array on which the pearson correlation map step has been
applied on chunks
out
The output scalar, which has the output of euclidean between two vectors.
An ndarray, which is a symmetric matrix of pearson correlation

Returns
-------
A scalar representing the euclidean distance between two vectors x and y.

Examples
--------
>>> from sgkit.distance.metrics import euclidean
>>> import dask.array as da
>>> import numpy as np
>>> x = da.array([4, 3, 2, 3], dtype='i1')
>>> y = da.array([5, 6, 7, 0], dtype='i1')
>>> euclidean(x, y).compute()
6.6332495807108

>>> euclidean(x, x).compute()
0.0

An ndarray, which contains the result of the calculation of the application
of euclidean distance on all the chunks.
"""
square_sum = 0.0
m = x.shape[0]
# Ignore missing values
for i in range(m):
if x[i] >= 0 and y[i] >= 0:
square_sum += (x[i] - y[i]) ** 2
out[0] = np.sqrt(square_sum)
v = v.sum(axis=0)
n = v[5]
num = n * v[4] - v[0] * v[1]
denom1 = np.sqrt(n * v[2] - v[0] ** 2)
denom2 = np.sqrt(n * v[3] - v[1] ** 2)
denom = denom1 * denom2
value = np.nan
if denom > 0:
value = 1 - (num / denom)
out[0] = value
Loading