# Source code for datacube.api.core

# This file is part of the Open Data Cube, see https://opendatacube.org for more information
#
# Copyright (c) 2015-2025 ODC Contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import collections.abc
import datetime
import logging
import uuid
from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence
from itertools import groupby
from typing import TYPE_CHECKING, Any, TypeAlias, cast

import deprecat
import numpy
import xarray
from dask import array as da
from odc.geo import CRS, XY, Resolution, res_, resyx_, yx_
from odc.geo.geobox import GeoBox, GeoboxTiles
from odc.geo.geom import Geometry, bbox_union, box, intersects
from odc.geo.warp import Resampling
from odc.geo.xr import xr_coords
from typing_extensions import override

from datacube.cfg import GeneralisedCfg, GeneralisedEnv, GeneralisedRawCfg, ODCConfig
from datacube.model import Dataset, ExtraDimensions, ExtraDimensionSlices, Measurement
from datacube.model.utils import xr_apply
from datacube.storage import BandInfo, reproject_and_fuse
from datacube.utils import ignore_exceptions_if
from datacube.utils.dates import normalise_dt

if TYPE_CHECKING:
    from odc.geo.crs import MaybeCRS
    from pandas import DataFrame

    from datacube.model import GridSpec
    from datacube.utils.geometry import GeoBox as LegacyGeoBox

from ..drivers import new_datasource
from ..index import Index, extract_geom_from_query, index_connect
from ..migration import ODC2DeprecationWarning
from ..model import QueryField
from ..storage._load import FuserFunction, ProgressFunction
from .query import GroupBy, Query, _normalise_geobox, query_group_by

_LOG: logging.Logger = logging.getLogger(__name__)


DataFrameLike: TypeAlias = list[dict[str, str | int | float | None]]


class TerminateCurrentLoad(Exception):  # noqa: N818
    """Raise from a user-supplied ``progress_cbk`` to stop the running ``.load``.

    The non-lazy loader catches this, flags the partially filled result with a
    ``dc_partial_load`` attribute and returns what has been read so far.
    """


[docs] class Datacube: """ Interface to search, read and write a datacube. """ def __init__( self, index: Index | None = None, config: GeneralisedCfg | None = None, env: GeneralisedEnv | None = None, raw_config: GeneralisedRawCfg | None = None, app: str | None = None, validate_connection: bool = True, ) -> None: """ Create an interface for the query and storage access. :param index: The database index to use. If provided, config, app, env and raw_config should all be None. :param config: One of: - None (Use provided ODCEnvironment or Index, or perform default config loading.) - An ODCConfig object - A file system path pointing to the location of the config file. - A list of file system paths to search for config files. The first readable file found will be used. If an index or an explicit ODCEnvironment is supplied, config and raw_config should be None. :param env: The datacube environment to use. Either an explicit ODCEnvironment object, or a str which is a section name in the loaded config file. Defaults to 'default'. Falls back to 'datacube' with a deprecation warning if config file does not contain a 'default' section. Allows you to have multiple datacube instances in one configuration, specified on load, e.g. 'dev', 'test' or 'landsat', 'modis' etc. If env is an ODCEnvironment object, config and index should both None. :param raw_config: Explicit configuration to use. Either as a string (serialised in ini or yaml format) or a dictionary (deserialised). If provided, config should be None. If an index or an explicit ODCEnvironment is supplied, config and raw_config should be None. :param app: A short, alphanumeric name to identify this application. The application name is used to track down problems with database queries, so it is strongly advised that be used. Should be None if an index is supplied. :param validate_connection: Check that the database connection is available and valid. Defaults to True. Ignored if index is passed. 
""" # Validate arguments if index is not None: # If an explicit index is provided, all other index-creation arguments should be None. should_be_none: list[str] = [] if config is not None: should_be_none.append("config") if raw_config is not None: should_be_none.append("raw_config") if app is not None: should_be_none.append("app") if env is not None: should_be_none.append("env") if should_be_none: raise ValueError( f"When an explicit index is provided, these arguments should be None: {','.join(should_be_none)}" ) # Explicit index passed in? Use it. self.index = index return # Obtain an ODCEnvironment object: cfg_env = ODCConfig.get_environment( env=env, config=config, raw_config=raw_config ) self.index = index_connect( cfg_env, application_name=app, validate_connection=validate_connection )
[docs] def list_products( self, with_pandas: bool = True, dataset_count: bool = False ) -> DataFrame | DataFrameLike: """ List all products in the datacube. This will produce a ``pandas.DataFrame`` or list of dicts containing useful information about each product, including: 'name' 'description' 'license' 'default_crs' or 'grid_spec.crs' 'default_resolution' or 'grid_spec.crs' 'dataset_count' (optional) :param with_pandas: Return the list as a Pandas DataFrame. Defaults to True. If False, return a list of dicts. :param dataset_count: Return a "dataset_count" column containing the number of datasets for each product. This can take several minutes on large datacubes. Defaults to False. :return: A table or list of every product in the datacube. """ def _get_non_default(product, col): load_hints = product.load_hints() if load_hints: if col == "crs": return load_hints.get("output_crs", None) return load_hints.get(col, None) return getattr(product.grid_spec, col, None) # Read properties from each datacube product cols = [ "name", "description", "license", "default_crs", "default_resolution", ] rows = [ [ ( getattr(pr, col, None) # if 'default_crs' and 'default_resolution' are not None # return 'default_crs' and 'default_resolution' if getattr(pr, col, None) or "default" not in col # else get crs and resolution from load_hints or grid_spec # as per output_geobox() handling logic else _get_non_default(pr, col.replace("default_", "")) ) for col in cols ] for pr in self.index.products.get_all() ] # Optionally compute dataset count for each product and add to row/cols # Product lists are sorted by product name to ensure 1:1 match if dataset_count: # Load counts counts = [(p.name, c) for p, c in self.index.datasets.count_by_product()] # Sort both rows and counts by product name from operator import itemgetter rows = sorted(rows, key=itemgetter(0)) counts = sorted(counts, key=itemgetter(0)) # Add sorted count to each existing row rows = [row + [count[1]] for row, count in 
zip(rows, counts)] cols += ["dataset_count"] # If pandas not requested, return list of dicts if not with_pandas: return [dict(zip(cols, row)) for row in rows] # Return pandas dataframe with each product as a row import pandas return pandas.DataFrame(rows, columns=cols).set_index("name", drop=False)
[docs] @deprecat.deprecat( deprecated_args={ "show_archived": { "reason": "The show_archived argument has never done anything and will be removed in future.", "version": "1.9.0", "category": ODC2DeprecationWarning, } } ) def list_measurements( self, show_archived: bool = False, with_pandas: bool = True ) -> DataFrame | DataFrameLike: """ List measurements for each product :param show_archived: include archived products in the result. :param with_pandas: return the list as a Pandas DataFrame, otherwise as a list of dict. (defaults to True) """ measurements = self._list_measurements() if not with_pandas: return measurements import pandas return pandas.DataFrame.from_records(measurements).set_index( ["product", "measurement"] )
def _list_measurements(self) -> list[dict[str, Any]]: measurements = [] dts = self.index.products.get_all() for dt in dts: if dt.measurements: for name, measurement in dt.measurements.items(): row = { "product": dt.name, "measurement": name, } if "attrs" in measurement: row.update(measurement["attrs"]) row.update({k: v for k, v in measurement.items() if k != "attrs"}) measurements.append(row) return measurements #: pylint: disable=too-many-arguments, too-many-locals
[docs] def load( self, product: str | None = None, measurements: str | list[str] | None = None, output_crs: MaybeCRS = None, resolution: ( int | float | tuple[int | float, int | float] | Resolution | None ) = None, resampling: Resampling | dict[str, Resampling] | None = None, align: XY[float] | Iterable[float] | None = None, skip_broken_datasets: bool | None = None, dask_chunks: dict[str, str | int] | None = None, like: GeoBox | xarray.Dataset | xarray.DataArray | None = None, fuse_func: FuserFunction | Mapping[str, FuserFunction | None] | None = None, datasets: Sequence[Dataset] | None = None, dataset_predicate: Callable[[Dataset], bool] | None = None, progress_cbk: ProgressFunction | None = None, patch_url: Callable[[str], str] | None = None, limit: int | None = None, driver: Any | None = None, **query: QueryField, ) -> xarray.Dataset: r""" Load data as an ``xarray.Dataset`` object. Each measurement will be a data variable in the :class:`xarray.Dataset`. See the `xarray documentation <https://xarray.pydata.org/en/stable/data-structures.html>`_ for usage of the :class:`xarray.Dataset` and :class:`xarray.DataArray` objects. **Product and Measurements** A product can be specified using the product name:: product='ls5_ndvi_albers' See :meth:`list_products` for the list of products with their names and properties. A product name MUST be supplied unless search is bypassed all together by supplying an explicit list of datasets. The ``measurements`` argument is a list of measurement names, as listed in :meth:`list_measurements`. If not provided, all measurements for the product will be returned:: measurements=['red', 'nir', 'swir2'] **Dimensions** Spatial dimensions can be specified using the ``longitude``/``latitude`` and ``x``/``y`` fields. The CRS of this query is assumed to be WGS84/EPSG:4326 unless the ``crs`` field is supplied, even if the stored data is in another projection or the ``output_crs`` is specified. 
The dimensions ``longitude``/``latitude`` and ``x``/``y`` can be used interchangeably:: latitude=(-34.5, -35.2), longitude=(148.3, 148.7) or:: x=(1516200, 1541300), y=(-3867375, -3867350), crs='EPSG:3577' You can also specify a polygon with an arbitrary CRS (in e.g. the native CRS):: geopolygon=polygon(coords, crs="EPSG:3577") Or an iterable of polygons (search is done against the union of all polygons):: geopolygon=[poly1, poly2, poly3, ....] You can also pass a WKT string, or a GeoJSON string or any other object that can be passed to the odc.geo.Geometry constructor, or an iterable of any of the above. Performance and accuracy of geopolygon queries may vary depending on the index driver in use and the CRS. The ``time`` dimension can be specified using a single or tuple of datetime objects or strings with ``YYYY-MM-DD hh:mm:ss`` format. Data will be loaded inclusive of the start and finish times. A ``None`` value in the range indicates an open range, with the provided date serving as either the upper or lower bound. E.g.:: time=('2000-01-01', '2001-12-31') time=('2000-01', '2001-12') time=('2000', '2001') time=('2000') time=('2000', None) # all data from 2000 onward time=(None, '2000') # all data up to and including 2000 For 3D datasets, where the product definition contains an ``extra_dimension`` specification, these dimensions can be queried using that dimension's name. E.g.:: z=(10, 30) or:: z=5 or:: wvl=(560.3, 820.5) For EO-specific datasets that are based around scenes, the time dimension can be reduced to the day level, using solar day to keep scenes together:: group_by='solar_day' For data that has different values for the scene overlap that requires more complex rules for combining data, a function can be provided to the merging into a single time slice. See :func:`datacube.helpers.ga_pq_fuser` for an example implementation. see :func:`datacube.api.query.query_group_by` for `group_by` built-in functions. 
**Output** To reproject or resample data, supply the ``output_crs``, ``resolution``, ``resampling`` and ``align`` fields. By default, the resampling method is 'nearest'. However, any stored overview layers may be used when down-sampling, which may override (or hybridise) the choice of resampling method. To reproject data to 30 m resolution for EPSG:3577:: dc.load(product='ls5_nbar_albers', x=(148.15, 148.2), y=(-35.15, -35.2), time=('1990', '1991'), output_crs='EPSG:3577`, resolution=30, resampling='cubic' ) odc-geo style xy objects are preferred for passing in resolution and align pairs to avoid x/y ordering ambiguity. :param product: The name of the product to be loaded. Either ``product`` or ``datasets`` must be supplied :param measurements: Measurements name or list of names to be included, as listed in :meth:`list_measurements`. These will be loaded as individual ``xr.DataArray`` variables in the output ``xarray.Dataset`` object. If a list is specified, the measurements will be returned in the order requested. By default, all available measurements are included. :param output_crs: The CRS of the returned data, for example ``EPSG:3577``. If no CRS is supplied, the CRS of the stored data is used if available. Any form that can be converted to a CRS by odc-geo is accepted. This differs from the ``crs`` parameter described above, which is used to define the CRS of the coordinates in the query itself. :param resolution: The spatial resolution of the returned data. If using square pixels with an inverted Y axis, it should be provided as an int or float. If not, it should be provided as an odc-geo XY object to avoid coordinate-order ambiguity. If passed as a tuple, y,x order is assumed for backwards compatibility. Units are in the coordinate space of ``output_crs``. This includes the direction (as indicated by a positive or negative number). :param resampling: The resampling method to use if re-projection is required. 
This could be a string or a dictionary mapping band name to resampling mode. When using a dict use ``'\*'`` to indicate "apply to all other bands", for example ``{'\*': 'cubic', 'fmask': 'nearest'}`` would use ``cubic`` for all bands except ``fmask`` for which ``nearest`` will be used. Valid values are:: 'nearest', 'average', 'bilinear', 'cubic', 'cubic_spline', 'lanczos', 'mode', 'gauss', 'max', 'min', 'med', 'q1', 'q3' Default is to use ``nearest`` for all bands. .. seealso:: :meth:`load_data` :param align: Load data such that point 'align' lies on the pixel boundary. A pair of floats between 0 and 1. An odc-geo XY object is preferred to avoid coordinate-order ambiguity. If passed as a tuple, x,y order is assumed for backwards compatibility. Default is ``(0, 0)`` :param skip_broken_datasets: Optional. If this is True, then don't break when failing to load a broken dataset. If None, the value will come from the environment variable of the same name. Default is False. :param dask_chunks: If the data should be lazily loaded using :class:`dask.array.Array`, specify the chunking size in each output dimension. See the documentation on using `xarray with dask <https://xarray.pydata.org/en/stable/dask.html>`_ for more information. :param like: Use the output of a previous :meth:`load()` to load data into the same spatial grid and resolution (i.e. :class:`odc.geo.geobox.GeoBox` or an xarray `Dataset` or `DataArray`). E.g.:: pq = dc.load(product='ls5_pq_albers', like=nbar_dataset) :param fuse_func: Function used to fuse/combine/reduce data with the ``group_by`` parameter. By default, data is simply copied over the top of each other in a relatively undefined manner. This function can perform a specific combining step. This can be a dictionary if different fusers are needed per band (similar format to the resampling dict described above). :param datasets: Optional. 
If this is a non-empty list of :class:`datacube.model.Dataset` objects, these will be loaded instead of performing a database lookup. :param dataset_predicate: Optional. A function that can be passed to restrict loaded datasets. A predicate function should take a :class:`datacube.model.Dataset` object (e.g. as returned from :meth:`find_datasets`) and return a boolean. For example, loaded data could be filtered to January observations only by passing the following predicate function that returns True for datasets acquired in January:: def filter_jan(dataset): return dataset.time.begin.month == 1 . :param progress_cbk: ``Int, Int -> None``, if supplied will be called for every file read with ``files_processed_so_far, total_files``. This is only applicable to non-lazy loads, ignored when using dask. :param patch_url: if supplied, will be used to patch/sign the url(s), as required to access some commercial archives (e.g. Microsoft Planetary Computer). :param limit: Optional. If provided, limit the maximum number of datasets returned. Useful for testing and debugging. Can also be provided via the ``dc_load_limit`` config option. :param driver: Optional. If provided, use the specified driver to load the data. :param query: Search parameters for products and dimension ranges as described above. For example: ``'x', 'y', 'time', 'crs'``. 
:return: Requested data in a :class:`xarray.Dataset` """ if product is None and datasets is None: raise ValueError("Must specify a product or supply datasets") if datasets is None: assert product is not None # For type checker if limit is None: # check if a value was provided via the envvar limit = self.index.environment["dc_load_limit"] datasets = self.find_datasets( ensure_location=True, dataset_predicate=dataset_predicate, like=like, limit=limit, product=product, **query, ) elif isinstance(datasets, collections.abc.Iterator): datasets = list(datasets) if len(datasets) == 0: return xarray.Dataset() ds, *_ = datasets datacube_product = ds.product # Retrieve extra_dimension from product definition extra_dims: ExtraDimensions | None = None if datacube_product: extra_dims = datacube_product.extra_dimensions # Extract extra_dims slice information extra_dims_slice = cast( ExtraDimensionSlices, { k: query.pop(k, None) for k in list(query.keys()) if k in extra_dims.dims and query.get(k, None) is not None }, ) extra_dims = extra_dims[extra_dims_slice] # Check if empty if extra_dims.has_empty_dim(): return xarray.Dataset() if type(resolution) is tuple: _LOG.warning( "Resolution should be provided as a single int or float, or the axis order specified " "using odc.geo.resxy_ or odc.geo.resyx_" ) if resolution[0] == -resolution[1]: resolution = res_(resolution[1]) else: _LOG.warning( "Assuming resolution has been provided in (y, x) ordering. 
Please specify the order " "with odc.geo.resxy_ or odc.geo.resyx_" ) resolution = resyx_(*resolution) load_hints = datacube_product.load_hints() grid_spec = None if load_hints is not None else datacube_product.grid_spec geobox = output_geobox( like=like, output_crs=output_crs, resolution=resolution, align=align, grid_spec=grid_spec, load_hints=load_hints, datasets=datasets, geopolygon=cast(Geometry | None, query.pop("geopolygon", None)), **query, ) group_by = query_group_by(**query) # type: ignore[arg-type] grouped = self.group_datasets(datasets, group_by) measurement_dicts = datacube_product.lookup_measurements(measurements) if skip_broken_datasets is None: # default to value from env var, which defaults to False skip_broken_datasets = self.index.environment["skip_broken_datasets"] result = self.load_data( grouped, geobox, measurement_dicts, resampling=resampling, fuse_func=fuse_func, dask_chunks=dask_chunks, skip_broken_datasets=skip_broken_datasets, progress_cbk=progress_cbk, extra_dims=extra_dims, patch_url=patch_url, driver=driver, ) return result
[docs] def find_datasets( self, ensure_location: bool = False, dataset_predicate: Callable[[Dataset], bool] | None = None, like: GeoBox | xarray.Dataset | xarray.DataArray | None = None, limit: int | None = None, **search_terms: QueryField, ) -> list[Dataset]: """ Search the index and return all datasets for a product matching the search terms. :param ensure_location: only return datasets that have locations :param dataset_predicate: an optional predicate to filter datasets :param like: Use the output of a previous :meth:`load()` to load data into the same spatial grid and resolution (i.e. :class:`odc.geo.geobox.GeoBox` or an xarray `Dataset` or `DataArray`). E.g.:: pq = dc.load(product='ls5_pq_albers', like=nbar_dataset) :param limit: if provided, limit the maximum number of datasets returned :param search_terms: see :class:`datacube.api.query.Query` :return: list of datasets .. seealso:: :meth:`group_datasets` :meth:`load_data` :meth:`find_datasets_lazy` """ return list( self.find_datasets_lazy( limit=limit, ensure_location=ensure_location, dataset_predicate=dataset_predicate, like=like, **search_terms, ) )
[docs] def find_datasets_lazy( self, limit: int | None = None, ensure_location: bool = False, dataset_predicate: Callable[[Dataset], bool] | None = None, like: GeoBox | xarray.Dataset | xarray.DataArray | None = None, **kwargs: QueryField, ) -> Iterable[Dataset]: """ Find datasets matching query. :param limit: if provided, limit the maximum number of datasets returned :param ensure_location: only return datasets that have locations :param dataset_predicate: an optional predicate to filter datasets :param like: Use the output of a previous :meth:`load()` to load data into the same spatial grid and resolution (i.e. :class:`odc.geo.geobox.GeoBox` or an xarray `Dataset` or `DataArray`). E.g.:: pq = dc.load(product='ls5_pq_albers', like=nbar_dataset) :param kwargs: see :class:`datacube.api.query.Query` :return: iterator of datasets .. seealso:: :meth:`group_datasets` :meth:`load_data` :meth:`find_datasets` """ if like is not None: like = _normalise_geobox(like) query = Query(self.index, like=like, **kwargs) # type: ignore[arg-type] if not query.product: raise ValueError("must specify a product") datasets = self.index.datasets.search(limit=limit, **query.search_terms) if query.geopolygon is not None and not self.index.supports_spatial_indexes: datasets = select_datasets_inside_polygon(datasets, query.geopolygon) if ensure_location: datasets = (dataset for dataset in datasets if dataset.uri) # If a predicate function is provided, use this to filter datasets before load if dataset_predicate is not None: datasets = (dataset for dataset in datasets if dataset_predicate(dataset)) return datasets
[docs] @staticmethod def group_datasets( datasets: Iterable[Dataset], group_by: GroupBy ) -> xarray.DataArray: """ Group datasets along defined non-spatial dimensions (ie. time). :param datasets: a list of datasets, typically from :meth:`find_datasets` :param group_by: Contains: - a function that returns a label for a dataset - name of the new dimension - unit for the new dimension - function to sort by before grouping .. seealso:: :meth:`find_datasets`, :meth:`load_data`, :meth:`query_group_by` """ if isinstance(group_by, str): group_by = query_group_by(group_by=group_by) def ds_sorter(ds: Dataset) -> Any: return group_by.sort_key(ds), getattr(ds, "id", 0) def norm_axis_value(x: Any) -> Any: if isinstance(x, datetime.datetime): # For datetime we convert to UTC, then strip timezone info # to avoid numpy/pandas warning about timezones return numpy.datetime64(normalise_dt(x), "ns") return x def mk_group(group: Iterable[Dataset]) -> tuple[Any, Iterable[Dataset]]: dss = tuple(sorted(group, key=ds_sorter)) return norm_axis_value(group_by.group_key(dss)), dss datasets = sorted(datasets, key=group_by.group_by_func) groups = [ mk_group(group) for _, group in groupby(datasets, group_by.group_by_func) ] groups.sort(key=lambda x: x[0]) coords = numpy.asarray([coord for coord, _ in groups]) data = numpy.empty(len(coords), dtype=object) for i, (_, dss) in enumerate(groups): data[i] = dss sources = xarray.DataArray(data, dims=[group_by.dimension], coords=[coords]) if coords.dtype.kind == "M": # skip units for time dimensions as it breaks .to_netcdf(..) functionality #972 sources[group_by.dimension].attrs["units"] = group_by.units return sources
[docs] @staticmethod def create_storage( coords: Mapping[str, xarray.DataArray], geobox: GeoBox | xarray.Dataset | xarray.DataArray, measurements: list[Measurement], data_func: ( Callable[[Measurement, tuple[int, ...]], numpy.ndarray] | None ) = None, extra_dims: ExtraDimensions | None = None, ) -> xarray.Dataset: """ Create a :class:`xarray.Dataset` and (optionally) fill it with data. This function makes the in memory storage structure to hold datacube data. :param coords: OrderedDict holding `DataArray` objects defining the dimensions not specified by `geobox` :param geobox: A GeoBox defining the output spatial projection and resolution :param measurements: list of :class:`datacube.model.Measurement` :param data_func: Callable `Measurement -> np.ndarray` function to fill the storage with data. It is called once for each measurement, with the measurement as an argument. It should return an appropriately shaped numpy array. If not provided memory is allocated and filled with `nodata` value defined on a given Measurement. :param extra_dims: A ExtraDimensions describing any additional dimensions on top of (t, y, x) .. seealso:: :meth:`find_datasets` :meth:`group_datasets` """ from collections import OrderedDict from copy import deepcopy spatial_ref = "spatial_ref" def empty_func(m: Measurement, shape: tuple[int, ...]) -> numpy.ndarray: return numpy.full(shape, m.nodata, dtype=m.dtype) geobox = _normalise_geobox(geobox) crs_attrs = {} if geobox.crs is not None: crs_attrs["crs"] = str(geobox.crs) crs_attrs["grid_mapping"] = spatial_ref # Assumptions # - 3D dims must fit between (t) and (y, x) or (lat, lon) # 2D defaults # retrieve dims from coords if DataArray dims_default: tuple[Hashable, ...] 
= () if coords != {}: coords_value = next(iter(coords.values())) if isinstance(coords_value, xarray.DataArray): dims_default = coords_value.dims + geobox.dimensions if not dims_default: dims_default = tuple(coords) + geobox.dimensions shape_default = ( tuple(c.size for k, c in coords.items() if k in dims_default) + geobox.shape ) coords_default: OrderedDict[str, xarray.DataArray] = OrderedDict(**coords) coords_default.update( [(str(k), v) for k, v in xr_coords(geobox, spatial_ref).items()] ) arrays = [] ds_coords = deepcopy(coords_default) for m in measurements: if "extra_dim" not in m: # 2D default case arrays.append((m, shape_default, coords_default, dims_default)) elif extra_dims: # 3D case name = m.extra_dim new_dims = dims_default[:1] + (name,) + dims_default[1:] new_coords = deepcopy(coords_default) new_coords[name] = extra_dims._coords[name].copy() new_coords[name].attrs.update(crs_attrs) ds_coords.update(new_coords) new_shape = ( shape_default[:1] + (len(new_coords[name].values),) + shape_default[1:] ) arrays.append((m, new_shape, new_coords, new_dims)) data_func = data_func or (lambda m, shape: empty_func(m, shape)) def mk_data_var( m: Measurement, shape: tuple[int, ...], coords: OrderedDict[str, xarray.DataArray], dims: tuple[Hashable, ...], data_func: Callable[[Measurement, tuple[int, ...]], numpy.ndarray], ) -> xarray.DataArray: data = data_func(m, shape) attrs = dict(**m.dataarray_attrs(), **crs_attrs) return xarray.DataArray( data, name=m.name, coords=coords, dims=dims, attrs=attrs ) return xarray.Dataset( { m.name: mk_data_var(m, shape, coords, dims, data_func) for m, shape, coords, dims in arrays }, coords=ds_coords, attrs=crs_attrs, )
@staticmethod def _dask_load( sources: xarray.DataArray, geobox: GeoBox, measurements: list[Measurement], dask_chunks: dict[str, str | int], skip_broken_datasets: bool = False, extra_dims: ExtraDimensions | None = None, patch_url: Callable[[str], str] | None = None, ) -> xarray.Dataset: chunk_sizes = _calculate_chunk_sizes(sources, geobox, dask_chunks, extra_dims) needed_irr_chunks = chunk_sizes[0] if extra_dims: extra_dim_chunks = chunk_sizes[1] grid_chunks = chunk_sizes[-1] gbt = GeoboxTiles(geobox, grid_chunks) dsk = {} def chunk_datasets(dss, gbt): out = {} for ds in dss: dsk[_tokenize_dataset(ds)] = ds for idx in gbt.tiles(ds.extent): out.setdefault(idx, []).append(ds) return out chunked_srcs = xr_apply( sources, lambda _, dss: chunk_datasets(dss, gbt), dtype=object ) def data_func(measurement, shape): if "extra_dim" in measurement: chunks = needed_irr_chunks + extra_dim_chunks + grid_chunks else: chunks = needed_irr_chunks + grid_chunks return _make_dask_array( chunked_srcs, dsk, gbt, measurement, chunks=chunks, skip_broken_datasets=skip_broken_datasets, extra_dims=extra_dims, patch_url=patch_url, ) return Datacube.create_storage( cast(Mapping[str, xarray.DataArray], sources.coords), geobox, measurements, data_func, extra_dims, ) @staticmethod def _xr_load( sources: xarray.DataArray, geobox: GeoBox, measurements: list[Measurement], skip_broken_datasets: bool = False, progress_cbk: ProgressFunction | None = None, extra_dims: ExtraDimensions | None = None, patch_url: Callable[[str], str] | None = None, ) -> xarray.Dataset: def mk_cbk(cbk: ProgressFunction | None) -> ProgressFunction | None: if cbk is None: return None n = 0 t_size = sum(len(x) for x in sources.values.ravel()) n_total = 0 for m in measurements: if "extra_dim" in m: assert extra_dims is not None # for type-checker index_subset = extra_dims.measurements_slice(m.extra_dim) n_total += t_size * len( m.extra_dim.get("measurement_map")[index_subset] ) else: n_total += t_size def _cbk(*ignored): 
nonlocal n n += 1 return cbk(n, n_total) return _cbk data = Datacube.create_storage( cast(Mapping[str, xarray.DataArray], sources.coords), geobox, measurements, extra_dims=extra_dims, ) _cbk = mk_cbk(progress_cbk) # Create a list of read IO operations read_ios = [] for index, datasets in numpy.ndenumerate(sources.values): for m in measurements: if "extra_dim" in m: # When we want to support 3D native reads, we can start by replacing the for loop with # read_ios.append(((index + extra_dim_index), (datasets, m, index_subset))) assert extra_dims is not None # for type-checker index_subset = extra_dims.measurements_index(m.extra_dim) for result_index, extra_dim_index in enumerate( range(*index_subset) ): read_ios.append( ((index + (result_index,)), (datasets, m, extra_dim_index)) ) else: # Get extra_dim index if available extra_dim_index = m.get("extra_dim_index", None) read_ios.append((index, (datasets, m, extra_dim_index))) # Perform the read IO operations for index, (datasets, m, extra_dim_index) in read_ios: data_slice = data[m.name].values[index] try: _fuse_measurement( data_slice, datasets, geobox, m, skip_broken_datasets=skip_broken_datasets, progress_cbk=_cbk, extra_dim_index=extra_dim_index, patch_url=patch_url, ) except (TerminateCurrentLoad, KeyboardInterrupt): data.attrs["dc_partial_load"] = True return data return data
[docs] @staticmethod def load_data( sources: xarray.DataArray, geobox: GeoBox | xarray.Dataset | xarray.DataArray, measurements: Mapping[str, Measurement] | list[Measurement], resampling: Resampling | dict[str, Resampling] | None = None, fuse_func: FuserFunction | Mapping[str, FuserFunction | None] | None = None, dask_chunks: dict[str, str | int] | None = None, skip_broken_datasets: bool = False, progress_cbk: ProgressFunction | None = None, extra_dims: ExtraDimensions | None = None, patch_url: Callable[[str], str] | None = None, driver: Any | None = None, **extra, ) -> xarray.Dataset: """ Load data from :meth:`group_datasets` into an :class:`xarray.Dataset`. :param sources: DataArray holding a list of :class:`datacube.model.Dataset`, grouped along the time dimension :param geobox: A GeoBox defining the output spatial projection and resolution :param measurements: list of `Measurement` objects :param resampling: The resampling method to use if re-projection is required. This could be a string or a dictionary mapping band name to resampling mode. When using a dict use ``'*'`` to indicate "apply to all other bands", for example ``{'*': 'cubic', 'fmask': 'nearest'}`` would use `cubic` for all bands except ``fmask`` for which `nearest` will be used. Valid values are: ``'nearest', 'cubic', 'bilinear', 'cubic_spline', 'lanczos', 'average', 'mode', 'gauss', 'max', 'min', 'med', 'q1', 'q3'`` Default is to use ``nearest`` for all bands. :param fuse_func: function to merge successive arrays as an output. Can be a dictionary just like resampling. :param dask_chunks: If provided, the data will be loaded on demand using :class:`dask.array.Array`. Should be a dictionary specifying the chunking size for each output dimension. Unspecified dimensions will be auto-guessed, currently this means use chunk size of 1 for non-spatial dimensions and use whole dimension (no chunking unless specified) for spatial dimensions. 
See the documentation on using `xarray with dask <https://xarray.pydata.org/en/stable/dask.html>`_ for more information. :param skip_broken_datasets: do not include broken datasets in the result. :param progress_cbk: Int, Int -> None if supplied will be called for every file read with `files_processed_so_far, total_files`. This is only applicable to non-lazy loads, ignored when using dask. :param extra_dims: A ExtraDimensions describing any additional dimensions on top of (t, y, x) :param patch_url: if supplied, will be used to patch/sign the url(s), as required to access some commercial archives. :param driver: Optional. If provided, use the specified driver to load the data. .. seealso:: :meth:`find_datasets` :meth:`group_datasets` """ measurements = per_band_load_data_settings( measurements, resampling=resampling, fuse_func=fuse_func ) geobox = _normalise_geobox(geobox) if driver is not None: from ..storage._loader import driver_based_load return driver_based_load( driver, sources, geobox, measurements, dask_chunks, skip_broken_datasets=skip_broken_datasets, extra_dims=extra_dims, patch_url=patch_url, ) if dask_chunks is not None: return Datacube._dask_load( sources, geobox, measurements, dask_chunks, skip_broken_datasets=skip_broken_datasets, extra_dims=extra_dims, patch_url=patch_url, ) else: return Datacube._xr_load( sources, geobox, measurements, skip_broken_datasets=skip_broken_datasets, progress_cbk=progress_cbk, extra_dims=extra_dims, patch_url=patch_url, )
@override
def __str__(self) -> str:
    """Short human-readable description showing the index in use."""
    index_repr = repr(self.index)
    return f"Datacube<index={index_repr}>"

@override
def __repr__(self) -> str:
    """Same text as :meth:`__str__`."""
    return self.__str__()
def close(self) -> None:
    """
    Close any open connections held by the underlying index.
    """
    index = self.index
    index.close()
def __enter__(self) -> Datacube:
    """Enter a ``with`` block; the Datacube itself is the context value."""
    return self

def __exit__(self, type_, value, traceback) -> None:
    """Leave a ``with`` block by closing the Datacube.

    Returning None (falsy) means any exception raised in the block propagates.
    """
    self.close()
    return None
def per_band_load_data_settings( measurements: list[Measurement] | Mapping[str, Measurement], resampling: Resampling | Mapping[str, Resampling] | None = None, fuse_func: FuserFunction | Mapping[str, FuserFunction | None] | None = None, ) -> list[Measurement]: def with_resampling(m, resampling, default=None): m = m.copy() m["resampling_method"] = resampling.get(m.name, default) return m def with_fuser(m, fuser, default=None): m = m.copy() m["fuser"] = fuser.get(m.name, default) return m if resampling is not None and not isinstance(resampling, dict): resampling = {"*": resampling} if fuse_func is None or callable(fuse_func): fuse_func = {"*": fuse_func} if isinstance(measurements, dict): measurements = list(measurements.values()) if resampling is not None: measurements = [ with_resampling(m, resampling, default=resampling.get("*")) for m in measurements ] if fuse_func is not None: measurements = [ with_fuser(m, fuse_func, default=fuse_func.get("*")) for m in measurements ] return measurements def output_geobox( like: GeoBox | LegacyGeoBox | xarray.Dataset | xarray.DataArray | None = None, output_crs: Any = None, resolution: ( int | float | tuple[int | float, int | float] | Resolution | None ) = None, align: XY[float] | Iterable[float] | None = None, grid_spec: GridSpec | None = None, load_hints: Mapping[str, Any] | None = None, datasets: Iterable[Dataset] | None = None, geopolygon: Geometry | None = None, **query: QueryField, ) -> GeoBox: """Configure output geobox from user provided output specs.""" if like is not None: assert output_crs is None, "'like' and 'output_crs' are not supported together" assert resolution is None, "'like' and 'resolution' are not supported together" assert align is None, "'like' and 'align' are not supported together" return _normalise_geobox(like) if load_hints: if output_crs is None: output_crs = load_hints.get("output_crs", None) if resolution is None: resolution = cast( int | float | tuple[int | float, int | float] | None, 
load_hints.get("resolution", None), ) if align is None: align = load_hints.get("align", None) if output_crs is not None: if resolution is None: raise ValueError("Must specify 'resolution' when specifying 'output_crs'") crs = CRS(output_crs) elif grid_spec is not None: # specification from grid_spec crs = grid_spec.crs if resolution is None: resolution = grid_spec.resolution align = align or grid_spec.alignment else: raise ValueError( "Product has no default CRS.\nMust specify 'output_crs' and 'resolution'" ) # Try figuring out bounds # 1. Explicitly defined with geopolygon # 2. Extracted from x=,y= # 3. Computed from dataset footprints # 4. fail with ValueError if geopolygon is None: geopolygon = extract_geom_from_query(**query) if geopolygon is None: if datasets is None: raise ValueError("Bounds are not specified") geopolygon = get_bounds(datasets, crs) if type(resolution) is tuple: _LOG.warning( "Resolution should be provided as a single int or float, or the axis order specified " "using odc.geo.resxy_ or odc.geo.resyx_" ) if resolution[0] == -resolution[1]: resolution = resolution[1] else: _LOG.warning( "Assuming resolution has been provided in (y, x) ordering. 
Please specify the order " "with odc.geo.resxy_ or odc.geo.resyx_" ) resolution = resyx_(*resolution) resolution = res_(cast(Resolution | int | float, resolution)) if align is not None: align = yx_(align) return GeoBox.from_geopolygon(geopolygon, resolution, crs, align) def select_datasets_inside_polygon( datasets: Iterable[Dataset], polygon: Geometry ) -> Iterable[Dataset]: # Check against the bounding box of the original scene, can throw away some portions # (Only needed for index drivers without spatial index support) query_crs = polygon.crs for dataset in datasets: if dataset.extent and intersects(polygon, dataset.extent.to_crs(query_crs)): yield dataset def fuse_lazy( datasets: Iterable[Dataset], geobox: GeoBox, measurement: Measurement, skip_broken_datasets: bool = False, prepend_dims: int = 0, extra_dim_index: int | None = None, patch_url: Callable[[str], str] | None = None, ) -> numpy.ndarray: prepend_shape = (1,) * prepend_dims data = numpy.full(geobox.shape, measurement.nodata, dtype=measurement.dtype) _fuse_measurement( data, datasets, geobox, measurement, skip_broken_datasets=skip_broken_datasets, extra_dim_index=extra_dim_index, patch_url=patch_url, ) return data.reshape(prepend_shape + geobox.shape) def _fuse_measurement( dest: numpy.ndarray, datasets: Iterable[Dataset], geobox: GeoBox, measurement: Measurement, skip_broken_datasets: bool = False, progress_cbk: ProgressFunction | None = None, extra_dim_index: int | None = None, patch_url: Callable[[str], str] | None = None, ) -> None: srcs = [] for ds in datasets: src = None with ignore_exceptions_if(skip_broken_datasets): src = new_datasource( BandInfo( ds, measurement.name, extra_dim_index=extra_dim_index, patch_url=patch_url, ) ) if src is None: if not skip_broken_datasets: raise ValueError(f"Failed to load dataset: {ds.id}") else: srcs.append(src) reproject_and_fuse( srcs, dest, geobox, dest.dtype.type(measurement.nodata), resampling=measurement.get("resampling_method", "nearest"), 
fuse_func=measurement.get("fuser", None), skip_broken_datasets=skip_broken_datasets, progress_cbk=progress_cbk, extra_dim_index=extra_dim_index, ) def get_bounds(datasets: Iterable[Dataset], crs: CRS) -> Geometry: bbox = bbox_union(ds.extent.to_crs(crs).boundingbox for ds in datasets if ds.extent) return box(*bbox, crs=crs) # type: ignore[misc] def _calculate_chunk_sizes( sources: xarray.DataArray, geobox: GeoBox, dask_chunks: dict[str, str | int], extra_dims: ExtraDimensions | None = None, ) -> tuple[tuple, ...]: extra_dim_names: tuple[str, ...] = () extra_dim_shapes: tuple[int, ...] = () if extra_dims is not None: extra_dim_names, extra_dim_shapes = extra_dims.chunk_size() valid_keys = sources.dims + extra_dim_names + geobox.dimensions bad_keys = cast(set[str], set(dask_chunks)) - cast(set[str], set(valid_keys)) if bad_keys: raise KeyError( f"Unknown dask_chunk dimension {bad_keys}. Valid dimensions are: {valid_keys}" ) chunk_maxsz = dict( zip( sources.dims + extra_dim_names + geobox.dimensions, sources.shape + extra_dim_shapes + geobox.shape, ) ) # defaults: 1 for non-spatial, whole dimension for Y/X chunk_defaults = dict( [(dim, 1) for dim in sources.dims] + [(dim, 1) for dim in extra_dim_names] + [(dim, -1) for dim in geobox.dimensions] ) def _resolve(k, v: str | int | None) -> int: if v is None or v == "auto": v = _resolve(k, chunk_defaults[k]) if isinstance(v, int): if v < 0: return chunk_maxsz[k] return v raise ValueError("Chunk should be one of int|'auto'") irr_chunks = tuple(_resolve(dim, dask_chunks.get(str(dim))) for dim in sources.dims) extra_dim_chunks = tuple( _resolve(dim, dask_chunks.get(str(dim))) for dim in extra_dim_names ) grid_chunks = tuple( _resolve(dim, dask_chunks.get(str(dim))) for dim in geobox.dimensions ) if extra_dim_chunks: return irr_chunks, extra_dim_chunks, grid_chunks else: return irr_chunks, grid_chunks def _tokenize_dataset(dataset: Dataset) -> str: return f"dataset-{dataset.id.hex}" # pylint: disable=too-many-locals def 
_make_dask_array( chunked_srcs: xarray.DataArray, dsk, gbt, measurement: Measurement, chunks, skip_broken_datasets: bool = False, extra_dims: ExtraDimensions | None = None, patch_url: Callable[[str], str] | None = None, ): dsk = dsk.copy() # this contains mapping from dataset id to dataset object token = uuid.uuid4().hex dsk_name = f"dc_load_{measurement.name}-{token}" needed_irr_chunks, grid_chunks = chunks[:-2], chunks[-2:] actual_irr_chunks = (1,) * len(needed_irr_chunks) # we can have up to 4 empty chunk shapes: whole, right edge, bottom edge and # bottom right corner # W R # B BR empties: dict[tuple[int, int], str] = {} def _mk_empty(shape: tuple[int, int]) -> str: name = empties.get(shape, None) if name is not None: return name name = "empty_{}x{}-{token}".format(*shape, token=token) dsk[name] = ( numpy.full, actual_irr_chunks + shape, measurement.nodata, measurement.dtype, ) empties[shape] = name return name for irr_index, tiled_dss in numpy.ndenumerate(chunked_srcs.values): key_prefix = (dsk_name, *irr_index) # all spatial chunks for idx in numpy.ndindex(gbt.shape.shape): dss = tiled_dss.get(idx, None) if dss is None: val3d = _mk_empty(gbt.chunk_shape(idx).yx) # 3D case if "extra_dim" in measurement: assert extra_dims is not None # For type checker index_subset = extra_dims.measurements_index(measurement.extra_dim) for result_index, _ in numpy.ndenumerate(range(*index_subset)): dsk[key_prefix + result_index + idx] = val3d else: dsk[key_prefix + idx] = val3d else: val = ( fuse_lazy, [_tokenize_dataset(ds) for ds in dss], gbt[idx], measurement, skip_broken_datasets, len(needed_irr_chunks), ) # 3D case if "extra_dim" in measurement: # Do extra_dim subsetting here assert extra_dims is not None # For type checker index_subset = extra_dims.measurements_index(measurement.extra_dim) for result_index, extra_dim_index in enumerate( range(*index_subset) ): dsk[key_prefix + (result_index,) + idx] = val + ( extra_dim_index, patch_url, ) else: # Get extra_dim index if 
available extra_dim_index = measurement.get("extra_dim_index", None) dsk[key_prefix + idx] = val + (extra_dim_index, patch_url) y_shapes = [grid_chunks[0]] * gbt.shape[0] x_shapes = [grid_chunks[1]] * gbt.shape[1] y_shapes[-1], x_shapes[-1] = gbt.chunk_shape(tuple(n - 1 for n in gbt.shape)) extra_dim_shape: tuple = () if "extra_dim" in measurement: assert extra_dims is not None # For type checker dim_name = measurement.extra_dim extra_dim_shape += (len(extra_dims.measurements_values(dim_name)),) data = da.Array( dsk, dsk_name, chunks=actual_irr_chunks + (tuple(y_shapes), tuple(x_shapes)), dtype=measurement.dtype, shape=(chunked_srcs.shape + extra_dim_shape + gbt.base.shape), ) if needed_irr_chunks != actual_irr_chunks: data = data.rechunk(chunks=chunks) return data