xarray/core/concat.py: 39 changes (20 additions, 19 deletions)
@@ -1,6 +1,6 @@
from __future__ import annotations

-from typing import TYPE_CHECKING, Any, Hashable, Iterable, overload
+from typing import TYPE_CHECKING, Any, Hashable, Iterable, cast, overload

import pandas as pd

@@ -14,42 +14,41 @@
    merge_attrs,
    merge_collected,
)
+from .types import T_DataArray, T_Dataset
from .variable import Variable
from .variable import concat as concat_vars

if TYPE_CHECKING:
    from .dataarray import DataArray
    from .dataset import Dataset
    from .types import CombineAttrsOptions, CompatOptions, ConcatOptions, JoinOptions


@overload
def concat(
-    objs: Iterable[Dataset],
-    dim: Hashable | DataArray | pd.Index,
+    objs: Iterable[T_Dataset],
+    dim: Hashable | T_DataArray | pd.Index,
    data_vars: ConcatOptions | list[Hashable] = "all",
    coords: ConcatOptions | list[Hashable] = "different",
    compat: CompatOptions = "equals",
    positions: Iterable[Iterable[int]] | None = None,
    fill_value: object = dtypes.NA,
    join: JoinOptions = "outer",
    combine_attrs: CombineAttrsOptions = "override",
-) -> Dataset:
+) -> T_Dataset:
    ...


@overload
def concat(
-    objs: Iterable[DataArray],
-    dim: Hashable | DataArray | pd.Index,
+    objs: Iterable[T_DataArray],
+    dim: Hashable | T_DataArray | pd.Index,
    data_vars: ConcatOptions | list[Hashable] = "all",
    coords: ConcatOptions | list[Hashable] = "different",
    compat: CompatOptions = "equals",
    positions: Iterable[Iterable[int]] | None = None,
    fill_value: object = dtypes.NA,
    join: JoinOptions = "outer",
    combine_attrs: CombineAttrsOptions = "override",
-) -> DataArray:
+) -> T_DataArray:
    ...


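Note: the overloads above swap the concrete Dataset/DataArray annotations for the T_Dataset/T_DataArray type variables, so the static return type of concat follows the element type of the input iterable. A minimal sketch of the pattern, assuming the bound-TypeVar definitions mirror xarray/core/types.py:

from typing import Iterable, TypeVar

from xarray import Dataset

# Assumed to mirror the definition in xarray/core/types.py:
T_Dataset = TypeVar("T_Dataset", bound=Dataset)


def first(objs: Iterable[T_Dataset]) -> T_Dataset:
    # The bound TypeVar ties the return type to the element type,
    # so a subclass passed in is the type inferred for the result.
    return next(iter(objs))


class MyDataset(Dataset):
    __slots__ = ()  # xarray expects subclasses to declare __slots__


ds: MyDataset = first([MyDataset(), MyDataset()])  # accepted by mypy

Because the TypeVar is bound to Dataset, anything Dataset-like still type-checks, while subclass information is no longer widened away.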
@@ -402,7 +401,7 @@ def process_subset_opt(opt, subset):

# determine dimensional coordinate names and a dict mapping name to DataArray
def _parse_datasets(
-    datasets: Iterable[Dataset],
+    datasets: Iterable[T_Dataset],
) -> tuple[dict[Hashable, Variable], dict[Hashable, int], set[Hashable], set[Hashable]]:

    dims: set[Hashable] = set()
@@ -429,16 +428,16 @@ def _parse_datasets(


def _dataset_concat(
-    datasets: list[Dataset],
-    dim: str | DataArray | pd.Index,
+    datasets: list[T_Dataset],
+    dim: str | T_DataArray | pd.Index,
    data_vars: str | list[str],
    coords: str | list[str],
    compat: CompatOptions,
    positions: Iterable[Iterable[int]] | None,
    fill_value: object = dtypes.NA,
    join: JoinOptions = "outer",
    combine_attrs: CombineAttrsOptions = "override",
-) -> Dataset:
+) -> T_Dataset:
    """
    Concatenate a sequence of datasets along a new or existing dimension
    """
@@ -482,7 +481,7 @@ def _dataset_concat(

    # case where concat dimension is a coordinate or data_var but not a dimension
    if (dim in coord_names or dim in data_names) and dim not in dim_names:
-        datasets = [ds.expand_dims(dim) for ds in datasets]
+        # TODO: Overriding type because .expand_dims has incorrect typing:
+        datasets = [cast(T_Dataset, ds.expand_dims(dim)) for ds in datasets]

    # determine which variables to concatenate
    concat_over, equals, concat_dim_lengths = _calc_concat_over(
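Note: typing.cast is a no-op at runtime; it only re-labels the value for the type checker, which is why it is a low-risk workaround while .expand_dims is annotated as returning plain Dataset. A sketch of the idea under that assumption (expand_all is an illustrative helper, not xarray API):

from typing import TypeVar, cast

from xarray import Dataset

T_Dataset = TypeVar("T_Dataset", bound=Dataset)  # assumed, as in xarray/core/types.py


def expand_all(datasets: list[T_Dataset], dim: str) -> list[T_Dataset]:
    # .expand_dims is annotated as returning Dataset, which would widen
    # the element type back to plain Dataset; cast() restores T_Dataset
    # for the checker without touching the object at runtime.
    return [cast(T_Dataset, ds.expand_dims(dim)) for ds in datasets]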
@@ -590,7 +590,7 @@ def get_indexes(name):
            # preserves original variable order
            result_vars[name] = result_vars.pop(name)

-    result = Dataset(result_vars, attrs=result_attrs)
+    result = type(datasets[0])(result_vars, attrs=result_attrs)

    absent_coord_names = coord_names - set(result.variables)
    if absent_coord_names:
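Note: constructing the result through type(datasets[0]) rather than Dataset directly is the runtime counterpart of the T_Dataset annotation: it preserves the concrete class of the first input. A sketch with a hypothetical subclass:

from xarray import Dataset


class LabeledDataset(Dataset):
    __slots__ = ()  # hypothetical subclass, only for illustration


first = LabeledDataset({"x": ("t", [1, 2])})
# type(first) is LabeledDataset, so constructing through it preserves
# the subclass in the concatenated result:
result = type(first)({"x": ("t", [1, 2, 3])})
assert isinstance(result, LabeledDataset)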
@@ -618,16 +618,16 @@ def get_indexes(name):


def _dataarray_concat(
-    arrays: Iterable[DataArray],
-    dim: str | DataArray | pd.Index,
+    arrays: Iterable[T_DataArray],
+    dim: str | T_DataArray | pd.Index,
    data_vars: str | list[str],
    coords: str | list[str],
    compat: CompatOptions,
    positions: Iterable[Iterable[int]] | None,
    fill_value: object = dtypes.NA,
    join: JoinOptions = "outer",
    combine_attrs: CombineAttrsOptions = "override",
-) -> DataArray:
+) -> T_DataArray:
    from .dataarray import DataArray

    arrays = list(arrays)
@@ -650,7 +650,8 @@ def _dataarray_concat(
            if compat == "identical":
                raise ValueError("array names not identical")
            else:
-                arr = arr.rename(name)
+                # TODO: Overriding type because .rename has incorrect typing:
+                arr = cast(T_DataArray, arr.rename(name))
        datasets.append(arr._to_temp_dataset())

    ds = _dataset_concat(
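Note: the diff is type-level only; the runtime behavior of concat is unchanged except for preserving the input's concrete class. A quick usage sketch of what the new overloads promise to a type checker:

import xarray as xr

a = xr.DataArray([1, 2], dims="t", name="v")
b = xr.DataArray([3, 4], dims="t", name="v")

# With the T_DataArray overload, the result is typed like the inputs
# (including DataArray subclasses) instead of being widened:
out = xr.concat([a, b], dim="t")
print(out.sizes)  # Frozen({'t': 4})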