Source code for skmap.io.base

"""
Raster data input and output
"""

from decorator import contextmanager
from rasterio.io import DatasetWriter

import copy
import math
import os
import tempfile
import time
from base64 import b64decode, encodebytes
from contextlib import ExitStack
from copy import deepcopy
from datetime import datetime
from io import BytesIO
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Callable, List, Optional, Tuple, Union
from uuid import uuid4

import numpy
import numpy as np
import pandas as pd
import rasterio
import requests
from dateutil.relativedelta import relativedelta
from IPython.display import HTML
from matplotlib import pyplot
from matplotlib._animation_data import DISPLAY_TEMPLATE, JS_INCLUDE, STYLE_INCLUDE
from minio import Minio
from mpl_toolkits.axes_grid1 import make_axes_locatable
from numpy.typing import NDArray
from osgeo import gdal
from pandas import DataFrame, Series, to_datetime
from PIL import Image
from pystac.item import Item
from rasterio.windows import Window, from_bounds
from shapely.geometry import box, shape

import skmap_bindings as sb
from skmap import SKMapBase, SKMapGroupRunner, SKMapRunner, parallel
from skmap.misc import (
    _eval,
    date_range,
    del_memmap,
    load_memmap,
    make_tempdir,
    new_memmap,
    ref_memmap,
    ttprint,
    vrt_warp,
)

_INT_DTYPE = (
    "uint8",
    "uint8",
    "int16",
    "uint16",
    "int32",
    "uint32",
    "int64",
    "uint64",
    "int",
    "uint",
)


def _nodata_replacement(dtype: str):
    """
    Return the in-memory value used to flag missing pixels for ``dtype``.

    :param dtype: Name of the array dtype.
    :returns: The largest value representable by ``dtype`` when it is listed
        in ``_INT_DTYPE``, otherwise ``np.nan``.
    """
    return np.iinfo(dtype).max if dtype in _INT_DTYPE else np.nan


def _fit_in_dtype(data: NDArray, dtype: str, nodata: int) -> NDArray:
    """
    Round and clamp ``data`` so it fits in an integer ``dtype``.

    Values are rounded to the nearest integer and limited to the range of
    ``dtype``. When ``nodata`` sits on one of the range limits, data values
    equal to it are moved one step inside the range so they do not collide
    with ``nodata``. Non-integer dtypes are returned untouched.

    :param data: Array to adjust.
    :param dtype: Name of the target dtype.
    :param nodata: The ``nodata`` value of the target raster.
    :returns: The adjusted array (the input itself for non-integer dtypes).
    """
    if dtype not in _INT_DTYPE:
        return data

    limits = np.iinfo(dtype)
    lowest, highest = limits.min, limits.max

    fitted = np.rint(data)
    fitted = np.where((fitted < lowest), lowest, fitted)
    fitted = np.where((fitted > highest), highest, fitted)

    if nodata == lowest:
        fitted = np.where((fitted == nodata), (lowest + 1), fitted)
    elif nodata == highest:
        fitted = np.where((fitted == nodata), (highest - 1), fitted)

    return fitted


def _read_raster(
    raster_idx,
    raster_files,
    band,
    window,
    dtype,
    data_mask,
    expected_shape,
    try_without_window,
    scale,
    gdal_opts,
    overview,
    verbose,
):
    """
    Read one band of one raster file (worker used by ``read_rasters``).

    :param raster_idx: Position in ``raster_files`` of the file to read. It is
        returned unchanged so the caller can place the result.
    :param raster_files: Sequence with the raster paths.
    :param band: Band index passed to ``ds.read`` and ``ds.overviews``.
    :param window: Optional window passed to ``ds.read``. On a read failure it
        also provides the shape of the fallback array.
    :param dtype: Name of the dtype the data is read into.
    :param data_mask: Optional array; pixels where it is falsy are set to
        ``np.nan``. A 3D mask is reduced to its first layer. A mask whose shape
        differs from the read data is ignored with a warning.
    :param expected_shape: Optional shape of the fallback array created when
        the read fails. Takes precedence over the shape of ``window``.
    :param try_without_window: Currently unused.
    :param scale: Factor the result is multiplied by when different from 1.0.
    :param gdal_opts: GDAL configuration options, set through
        ``gdal.SetConfigOption`` before opening the file.
    :param overview: Optional overview factor; only used when listed in
        ``ds.overviews(band)``, otherwise the full resolution is read.
    :param verbose: Print details about failed reads.
    :returns: Tuple ``(array, raster_idx, data_exists)``. ``data_exists`` is
        ``False`` when the read failed and ``array`` is a fallback filled with
        ``_nodata_replacement(dtype)``.
    :raises Exception: The original read error, when the read fails and neither
        ``window`` nor ``expected_shape`` provides a fallback shape.
    """
    for key, value in gdal_opts.items():
        gdal.SetConfigOption(key, value)

    raster_file = raster_files[raster_idx]
    array_mm = None
    nodata = None
    data_exists = False

    try:
        # The context manager closes the dataset even if the read fails.
        with rasterio.open(raster_file) as ds:
            ttprint("Start reading")
            read_args = {"out_dtype": dtype, "window": window}
            if overview is not None and overview in ds.overviews(band):
                # NOTE(review): floor division makes math.ceil a no-op here; the
                # expression is kept because read_rasters sizes its output
                # buffer with the very same one.
                read_args["out_shape"] = (
                    1,
                    math.ceil(ds.height // overview),
                    math.ceil(ds.width // overview),
                )
            array_mm = ds.read(band, **read_args)
            ttprint("End reading")

            nodata = ds.nodatavals[0]
        data_exists = True
    except Exception as ex:
        ttprint(f"Exception: {ex}")

        fallback_shape = None
        if window is not None:
            if verbose:
                ttprint(f"ERROR: Failed to read {raster_file} window {window}")
            fallback_shape = (int(window.height), int(window.width))

        if expected_shape is not None:
            if verbose:
                ttprint(f"Full nan image for {raster_file}")
            fallback_shape = expected_shape

        if fallback_shape is None:
            # No way to build a placeholder: surface the real error.
            raise

        array_mm = np.full(fallback_shape, _nodata_replacement(dtype), dtype=dtype)

    if data_exists:
        if data_mask is not None:
            if len(data_mask.shape) == 3:
                data_mask = data_mask[:, :, 0]

            if data_mask.shape == array_mm.shape:
                array_mm[np.logical_not(data_mask)] = np.nan
            else:
                ttprint(
                    f"WARNING: incompatible data_mask shape {data_mask.shape} != {array_mm.shape}"
                )

        if nodata is not None:
            ttprint("Start _nodata_replacement")
            array_mm[array_mm == nodata] = _nodata_replacement(dtype)
            ttprint("End _nodata_replacement")

    if scale != 1.0:
        ttprint("Start scaling")
        array_mm = array_mm * scale
        ttprint("End scaling")

    return array_mm, raster_idx, data_exists


def _read_auth_raster(
    raster_files, url_pos, bands, username, password, dtype, nodata, verbose=False
):
    """
    Download and read one raster behind HTTP basic authentication.

    :param raster_files: Sequence with the raster urls.
    :param url_pos: Position in ``raster_files`` of the url to read. It is
        returned unchanged so the caller can order the results.
    :param bands: Band index or list of band indexes to read; ``None`` reads
        all the bands.
    :param username: Username for the basic access authentication.
    :param password: Password for the basic access authentication.
    :param dtype: Dtype the read data is converted to.
    :param nodata: Value replaced by ``_nodata_replacement(dtype)``. When
        ``None`` the nodata of the first band is used.
    :param verbose: Print the url being read.
    :returns: Tuple ``(url_pos, data, ds_params)`` where ``data`` is an array
        of shape ``(bands, height, width)`` and ``ds_params`` a dict with the
        properties needed to create a similar raster. Both are ``None`` when
        the download or the read fails.
    """
    url = raster_files[url_pos]

    data = None
    ds_params = None

    try:
        response = requests.get(url, auth=(username, password), stream=True)
        # Fail early on 4xx/5xx instead of trying to parse an error page.
        response.raise_for_status()

        with rasterio.io.MemoryFile(response.content) as memfile:
            if verbose:
                ttprint(f"Reading {url} to {memfile.name}")

            with memfile.open() as ds:
                if bands is None:
                    bands = list(range(1, ds.count + 1))
                elif isinstance(bands, int):
                    # Keep the band axis so the result is always 3D.
                    bands = [bands]

                if nodata is None:
                    nodata = ds.nodatavals[0]

                band_data = ds.read(bands).astype(dtype)
                if nodata is not None:
                    band_data[band_data == nodata] = _nodata_replacement(dtype)

                # rasterio returns the data as (bands, rows, cols).
                nbands, height, width = band_data.shape

                ds_params = {
                    "driver": ds.driver,
                    "width": width,
                    "height": height,
                    "count": nbands,
                    "dtype": ds.dtypes[0],
                    "crs": ds.crs,
                    "transform": ds.transform,
                }
                data = band_data

    except Exception as ex:
        ttprint(f"Invalid raster file {url}: {ex}")
        data, ds_params = None, None

    return url_pos, data, ds_params


@contextmanager
def _new_raster(base_raster, raster_file, data, window=None, dtype=None, nodata=None):
    """
    Context manager yielding a new GeoTIFF opened for writing.

    The new raster takes its size and band count from ``data`` and its CRS and
    transform from ``base_raster``. The parent folder of ``raster_file`` is
    created when missing.

    :param base_raster: Path of the raster used as reference.
    :param raster_file: Path of the raster to create.
    :param data: 2D or 3D array; only its shape is used here. The last
        dimension of a 3D array is the number of bands.
    :param window: Optional window used to shift the reference transform.
    :param dtype: Dtype of the new raster; ``None`` uses the dtype of the first
        band of ``base_raster``.
    :param nodata: Nodata of the new raster; ``None`` uses the nodata of
        ``base_raster``.
    """
    out_path = raster_file if isinstance(raster_file, Path) else Path(raster_file)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    if len(data.shape) < 3:
        data = np.stack([data], axis=2)

    n_rows, n_cols, n_bands = data.shape

    with rasterio.open(base_raster, "r") as reference:
        out_dtype = reference.dtypes[0] if dtype is None else dtype
        out_nodata = reference.nodata if nodata is None else nodata

        out_transform = reference.transform
        if window is not None:
            out_transform = rasterio.windows.transform(window, out_transform)

        profile = {
            "driver": "GTiff",
            "height": n_rows,
            "width": n_cols,
            "count": n_bands,
            "dtype": out_dtype,
            "crs": reference.crs,
            "compress": "LZW",
            "transform": out_transform,
            "nodata": out_nodata,
        }

        with rasterio.open(out_path, "w", **profile) as dataset:
            yield dataset


def _save_raster(
    fn_base_raster: str,
    raster_file: str,
    ref_array,
    i: int,
    spatial_win: Window | None = None,
    dtype: str | None = None,
    nodata=None,
    fit_in_dtype=False,
    on_each_outfile: Callable | None = None,
):
    """
    Write layer ``i`` of a shared array to a new single-band raster.

    :param fn_base_raster: Path of the reference raster.
    :param raster_file: Path of the raster to create.
    :param ref_array: Keyword arguments for ``load_memmap`` describing the
        shared array.
    :param i: Index, on the last axis, of the layer to save.
    :param spatial_win: Optional window forwarded to ``_new_raster``.
    :param dtype: Optional dtype forwarded to ``_new_raster``.
    :param nodata: Optional nodata forwarded to ``_new_raster``.
    :param fit_in_dtype: Pass the layer through ``_fit_in_dtype`` before
        writing it.
    :param on_each_outfile: Optional callback invoked with ``raster_file``
        after the raster is written.
    :returns: ``raster_file``.
    """
    array = load_memmap(**ref_array)
    # Basic slicing: writes to `layer` land in the shared array, as the
    # in-place updates below require.
    layer = array[:, :, i]

    with _new_raster(
        fn_base_raster, raster_file, layer, spatial_win, dtype, nodata
    ) as new_raster:  # type: DatasetWriter
        band_dtype = new_raster.dtypes[0]

        if fit_in_dtype:
            layer[...] = _fit_in_dtype(layer, band_dtype, new_raster.nodata)

        missing = np.isnan(layer)
        layer[missing] = new_raster.nodata
        new_raster.write(layer.astype(band_dtype), indexes=1)

    if on_each_outfile is not None:
        on_each_outfile(raster_file)

    return raster_file


def save_rasters_cpp(
    base_raster: Union[List, str],
    out_data: np.array,
    out_files: Union[List, str],
    out_dir: str = ".",
    out_idx: List = None,
    out_s3: Union[List, str] = None,
    window: Window = None,
    nodata: int = None,
    dtype: type = np.int16,
    n_jobs: int = 8,
    gdal_opts: Optional[dict] = None,
    gdal_co: Optional[dict] = None,
    verbose=False,
):
    """
    Save array layers as rasters using the C++ backend (``skmap_bindings``).

    :param base_raster: Reference raster path, or a list with one reference
        raster per output file. A single path is repeated for every output.
    :param out_data: Array holding the layers to be written.
    :param out_files: Output name, or list of output names, without the
        ``.tif`` extension.
    :param out_dir: Output folder. Replaced by a temporary folder when
        ``out_s3`` is provided.
    :param out_idx: Indexes of the layers of ``out_data`` to write. By default
        ``0 .. len(out_files) - 1``.
    :param out_s3: Optional S3 destination(s) forwarded to the backend.
    :param window: Window of the reference raster to write. By default the
        full extent of the (first) reference raster.
    :param nodata: Nodata value of the outputs. By default the nodata of the
        first band of the (first) reference raster, converted with ``int()``,
        so a reference raster without nodata raises ``TypeError``.
    :param dtype: Selects the backend writer: ``np.uint8``, ``np.uint16`` and
        ``np.float32`` have dedicated writers, anything else uses the int16
        one.
    :param n_jobs: Number of parallel jobs, capped to the number of outputs.
    :param gdal_opts: GDAL options forwarded to the backend. By default an
        empty dict.
    :param gdal_co: GDAL creation options used to build the
        ``gdal_translate`` command. By default deflate compression (level 9)
        with 1024x1024 tiles.
    :param verbose: Use ``True`` to print the saving progress.
    :returns: ``out_s3`` when provided, otherwise the list of
        ``<out_dir>/<out_file>.tif`` paths.
    """
    if gdal_opts is None:
        gdal_opts = {}
    if gdal_co is None:
        gdal_co = {
            "COMPRESS": "deflate",
            "ZLEVEL": "9",
            "TILED": "TRUE",
            "BLOCKXSIZE": "1024",
            "BLOCKYSIZE": "1024",
        }

    if isinstance(out_files, str):
        out_files = [out_files]

    n_layers = len(out_files)
    n_jobs = min(n_jobs, n_layers)

    if window is None or nodata is None:
        # rasterio cannot open a list: use the first raster as reference, and
        # open it only once for both defaults.
        if isinstance(base_raster, (list, tuple)):
            ref_raster = base_raster[0]
        else:
            ref_raster = base_raster

        with rasterio.open(ref_raster) as ds:
            if window is None:
                window = rasterio.windows.Window(0, 0, ds.width, ds.height)
            if nodata is None:
                nodata = int(ds.nodatavals[0])

    if out_idx is None:
        out_idx = list(range(0, n_layers))

    if out_s3 is not None:
        out_dir = str(make_tempdir())

    # An empty dict yields no "-co" flag at all instead of a dangling one.
    co_args = " ".join(f"-co {k}={v}" for k, v in gdal_co.items())
    gdal_cmd = f"gdal_translate -a_nodata {nodata} {co_args}"

    if isinstance(base_raster, (str, Path)):
        base_raster = [str(base_raster)] * n_layers

    write_fn = sb.writeInt16Data
    if dtype == np.uint8:
        write_fn = sb.writeByteData
    elif dtype == np.uint16:
        write_fn = sb.writeUInt16Data
    elif dtype == np.float32:
        write_fn = sb.writeData

    if verbose:
        ttprint(f"Saving {n_layers} layers using window={window} to {out_dir}")

    write_fn(
        out_data,
        n_jobs,
        gdal_opts,
        base_raster,
        out_dir,
        out_files,
        out_idx,
        window.col_off,
        window.row_off,
        window.width,
        window.height,
        nodata,
        gdal_cmd,
        out_s3,
    )

    if verbose:
        ttprint("End")

    if out_s3 is not None:
        return out_s3

    return [str(Path(out_dir).joinpath(f"{o}.tif")) for o in out_files]
def read_rasters_cpp(
    raster_files: Union[List[Union[str, Path]], str, Path, None] = None,
    band: Union[List[int], int] = 1,
    window: Optional[Window] = None,
    n_jobs: int = 8,
    out_data: Optional[NDArray[np.float32]] = None,
    out_idx: Optional[List] = None,
    dtype: type = np.float32,
    gdal_opts: Optional[dict] = None,
    verbose=False,
):
    """
    Read rasters in parallel using the C++ backend, aggregating them into a
    single array.

    :param raster_files: A raster path or a list with the raster paths.
    :param band: The band to be read from each raster file
    :param window: The window (if any) to read from the raster. By default the
        full extent of the first raster.
    :param n_jobs: The number threads to read in parallel, capped to the
        number of rasters
    :param out_data: a pre-allocated array to write into
    :param out_idx: permutation array
    :param dtype: Datatype (currently only `np.float32` is supported)
    :param gdal_opts: additional options to be passed to GDAL
    :param verbose: Whether to print extra output

    :returns: A 2D array of ``n_bands`` by ``n_pixels``
    :rtype: NDArray[np.float32]
    :raises ValueError: If no raster file is provided.

    Examples
    ========

    >>> import rasterio
    >>> import tempfile
    >>> import numpy as np
    >>> from skmap.io.base import read_rasters_cpp
    >>> from pathlib import Path
    >>> # Create a dummy raster
    >>> with tempfile.TemporaryDirectory() as tempdir:
    ...     band_1 = np.random.rand(100,100).astype(np.float32)
    ...     band_2 = np.random.rand(100,100).astype(np.float32)
    ...     transform = rasterio.transform.from_origin(0,0,1,1)
    ...     raster_name = Path(tempdir)/"example.tif"
    ...     with rasterio.open(
    ...         raster_name,
    ...         "w",
    ...         height=100,
    ...         width=100,
    ...         count=2,
    ...         dtype=np.float32,
    ...         crs="EPSG:4326",
    ...         transform=transform
    ...     ) as raster:
    ...         raster.write(band_1, 1)
    ...         raster.write(band_2, 2)
    ...     # read the raster bands in parallel
    ...     bands = read_rasters_cpp(raster_files=[raster_name, raster_name], band=[1,2])
    ...     # the shape is bands x n_pix
    ...     assert bands.shape == (2,100*100)
    ...     # bands are selected as above
    ...     np.testing.assert_equal(band_1.reshape(100*100), bands[0,:])
    ...     np.testing.assert_equal(band_2.reshape(100*100), bands[1,:])

    """
    if raster_files is None:
        raster_files = []
    elif isinstance(raster_files, (str, Path)):
        raster_files = [raster_files]

    if gdal_opts is None:
        gdal_opts = {}

    if isinstance(band, int):
        band = [band]

    if len(raster_files) == 0:
        ttprint("No raster files provided, nothing will be read")
        raise ValueError(f"Should provide at least one raster file, got {raster_files}")

    # Convert every entry, so lists mixing str and Path are handled too.
    raster_files = [str(r) for r in raster_files]

    n_layers = len(raster_files)
    n_jobs = min(n_jobs, n_layers)

    if window is None:
        with rasterio.open(raster_files[0]) as ds:
            window = rasterio.windows.Window(0, 0, ds.width, ds.height)

    if out_data is None:
        out_data = np.empty((n_layers, window.width * window.height), dtype=dtype)

    if out_idx is None:
        out_idx = list(range(0, n_layers))

    nodata_out = np.nan

    if verbose:
        ttprint(
            f"Reading {n_layers} layers using window={window} and array={out_data.shape}"
        )

    sb.readData(
        out_data,
        n_jobs,
        raster_files,
        out_idx,
        window.col_off,
        window.row_off,
        window.width,
        window.height,
        band,
        gdal_opts,
        None,
        nodata_out,
    )

    if verbose:
        ttprint("End")

    return out_data
def read_rasters(
    raster_files: Union[List, str, None] = None,
    band: int = 1,
    window: Window | None = None,
    bounds: Optional[list] = None,
    dtype: str = "float32",
    n_jobs: int = 8,
    data_mask: NDArray[np.float32] = None,
    scale: float = 1.0,
    expected_shape=None,
    try_without_window: bool = False,
    gdal_opts: Optional[dict] = None,
    overview=None,
    max_rasters=None,
    verbose=False,
) -> NDArray[np.float32]:
    """
    Read raster files aggregating them into a single array.
    A single band (``band``) of each raster is read.

    The ``nodata`` value is replaced by ``np.nan`` in case of ``dtype=float*``,
    and for ``dtype=*int*`` it's replaced by the highest possible value
    inside the range (for ``int16`` this value is ``32767``).

    :param raster_files: A raster path or a list with the raster paths. The
        last one is used as reference for the size and CRS of the output.
    :param band: The band to be read from each raster.
    :param window: Read the data according to the spatial window. By default is
        ``None``, reading all the raster data.
    :param bounds: Optional ``[left, bottom, right, top]`` in ``EPSG:4326``.
        When provided it is converted to the CRS of the reference raster and
        replaces ``window``.
    :param dtype: Convert the read data to specific ``dtype``. By default it
        reads in ``float32``. Pay attention to the precision limitations when
        using ``float16`` [1].
    :param n_jobs: Number of parallel jobs used to read the raster files.
    :param data_mask: A array with the same space dimensions of the read data,
        where all the values equal ``0`` are converted to ``np.nan``. Requires
        ``dtype`` as ``float16`` or ``float32``.
    :param scale: Factor the read data is multiplied by.
    :param expected_shape: The expected size (space dimension) of the read data.
        In case of error in reading any of the raster files, this is used to
        create a empty 2D array. By default is ``None``, throwing a exception
        if the raster doesn't exists.
    :param try_without_window: Currently unused by the reader.
    :param gdal_opts: GDAL configuration options set before each read.
    :param overview: Overview level to be read. In COG files are usually
        `[2, 4, 8, 16, 32, 64, 128, 256]`.
    :param max_rasters: Size of the last dimension of the returned array. By
        default the number of raster files.
    :param verbose: Use ``True`` to print the reading progress.

    :returns: A 3D array, where the last dimension refers to the read files.
    :rtype: Numpy.array

    Examples
    ========

    >>> import rasterio
    >>> from skmap.io.base import read_rasters
    >>>
    >>> # skmap COG layers - NDVI seasons for 2000
    >>> # these actually 404
    >>> raster_files = [
    ...     'http://s3.eu-central-1.wasabisys.com/skmap/lcv/lcv_ndvi_landsat.glad.ard_p50_30m_0..0cm_200003_skmap_epsg3035_v1.0.tif', # winter
    ...     'http://s3.eu-central-1.wasabisys.com/skmap/lcv/lcv_ndvi_landsat.glad.ard_p50_30m_0..0cm_200006_skmap_epsg3035_v1.0.tif', # spring
    ...     'http://s3.eu-central-1.wasabisys.com/skmap/lcv/lcv_ndvi_landsat.glad.ard_p50_30m_0..0cm_200009_skmap_epsg3035_v1.0.tif', # summer
    ...     'http://s3.eu-central-1.wasabisys.com/skmap/lcv/lcv_ndvi_landsat.glad.ard_p50_30m_0..0cm_200012_skmap_epsg3035_v1.0.tif'  # fall
    ... ]
    >>>
    >>> # Transform for the EPSG:3035
    >>> eu_transform = rasterio.open(raster_files[0]).transform # doctest: +SKIP
    >>> # Bounding box window over Wageningen, NL
    >>> window = rasterio.windows.from_bounds(left=4020659, bottom=3213544, right=4023659, top=3216544, transform=eu_transform) # doctest: +SKIP
    >>>
    >>> data = read_rasters(raster_files=raster_files, window=window, verbose=True) # doctest: +SKIP
    >>> print(f'Data shape: {data.shape}') # doctest: +SKIP

    References
    ==========

    [1] `Float16 Precision <https://github.com/numpy/numpy/issues/8063>`_

    """
    if data_mask is not None and dtype not in ("float16", "float32"):
        raise Exception("The data_mask requires dtype as float")

    if raster_files is None:
        raster_files = []
    elif isinstance(raster_files, str):
        raster_files = [raster_files]

    if gdal_opts is None:
        gdal_opts = {}

    if len(raster_files) < n_jobs:
        n_jobs = len(raster_files)

    if verbose:
        ttprint(f"Reading {len(raster_files)} raster file(s) using {n_jobs} workers")

    # The reference dataset is only needed to size the output; close it as
    # soon as that is done.
    with rasterio.open(raster_files[-1]) as ds:
        if bounds is not None and len(bounds) == 4:
            # `import rasterio` alone does not guarantee the warp submodule
            # is loaded.
            from rasterio.warp import transform_geom

            bounds = shape(
                transform_geom(
                    src_crs="EPSG:4326",
                    dst_crs=ds.crs,
                    geom=box(*bounds),
                )
            ).bounds

            window = from_bounds(*bounds, ds.transform).round_lengths()
            if verbose:
                ttprint(f"Transform {bounds} into {window}")

        if overview is not None:
            overviews = ds.overviews(band)
            if overview not in overviews:
                raise Exception(
                    f"Overview {overview} is invalid for {raster_files[-1]}.\n"
                    f"Use one of overviews: {overviews}"
                )
            # Same expression used by _read_raster for its out_shape.
            height, width = (
                math.ceil(ds.height // overview),
                math.ceil(ds.width // overview),
            )
        elif window is not None:
            height, width = window.height, window.width
        else:
            height, width = ds.height, ds.width

    ttprint("Start new_memmap")
    if max_rasters is not None:
        array_mm = new_memmap(dtype, shape=(height, width, max_rasters))
    else:
        # TOCHECK: this was multiplied by 10 before, but why?
        array_mm = new_memmap(dtype, shape=(height, width, len(raster_files)))
    ttprint(f"End new_memmap of shape {array_mm.shape}")

    args = [
        (
            raster_idx,
            raster_files,
            band,
            window,
            dtype,
            data_mask,
            expected_shape,
            try_without_window,
            scale,
            gdal_opts,
            overview,
            verbose,
        )
        for raster_idx in range(0, len(raster_files))
    ]

    for array, raster_idx, data_exists in parallel.job(
        _read_raster,
        args,
        n_jobs=n_jobs,
        joblib_args={"backend": "loky", "return_as": "generator"},
    ):
        # A failed read is only tolerated when the caller supplied
        # expected_shape, as documented above; the placeholder returned by
        # the worker is then stored like any other layer.
        if not data_exists and expected_shape is None:
            raster_file = raster_files[raster_idx]
            raise Exception(f"The raster {raster_file} not exists")

        ttprint(array.shape)
        array_mm[:, :, raster_idx] = array

    return array_mm
# if not keep_memmap: # return del_memmap(array_mm, True) # else: # return array_mm
def read_auth_rasters(
    raster_files: List,
    username: str,
    password: str,
    bands=None,
    dtype: str = "float16",
    n_jobs: int = 4,
    return_base_raster: bool = False,
    nodata=None,
    verbose: bool = False,
):
    """
    Read raster files trough a authenticate HTTP service, aggregating them into
    a single array. For raster files without authentication it's better to use
    read_rasters.

    The ``nodata`` value is replaced by ``np.nan`` in case of ``dtype=float*``,
    and for ``dtype=*int*`` it's replaced by the highest possible value
    inside the range (for ``int16`` this value is ``32767``).

    :param raster_files: A list with the raster urls.
    :param username: Username to provide to the basic access authentication.
    :param password: Password to provide to the basic access authentication.
    :param bands: Which bands needs to be read. By default is ``None`` reading
        all the bands.
    :param dtype: Convert the read data to specific ``dtype``. By default it
        reads in ``float16`` to save memory, however pay attention in the
        precision limitations for this ``dtype`` [1].
    :param n_jobs: Number of parallel jobs used to read the raster files.
    :param return_base_raster: Return an empty raster with the same properties
        of the read rasters ``(height, width, n_bands, crs, dtype, transform)``.
    :param nodata: Use this value if the nodata property is not defined in the
        read rasters.
    :param verbose: Use ``True`` to print the reading progress.

    :returns: A 4D array, where the first dimension refers to the bands and the
        last dimension to read files. Files that could not be read are left
        out. If ``return_base_raster=True`` the second value will be a base
        raster path.
    :rtype: Numpy.array or Tuple[Numpy.array, Path]
    :raises ValueError: If none of the raster files could be read.

    Examples
    ========

    >>> from skmap.io.base import read_auth_rasters
    >>>
    >>> # Do the registration in
    >>> # https://glad.umd.edu/ard/user-registration
    >>> username = '<YOUR_USERNAME>'
    >>> password = '<YOUR_PASSWORD>'
    >>> raster_files = [
    ...     'https://glad.umd.edu/dataset/landsat_v1.1/47N/092W_47N/850.tif',
    ...     'https://glad.umd.edu/dataset/landsat_v1.1/47N/092W_47N/851.tif',
    ...     'https://glad.umd.edu/dataset/landsat_v1.1/47N/092W_47N/852.tif',
    ...     'https://glad.umd.edu/dataset/landsat_v1.1/47N/092W_47N/853.tif'
    ... ]
    >>>
    >>> data, base_raster = read_auth_rasters(
    ...     raster_files,
    ...     username,
    ...     password,
    ...     return_base_raster=True,
    ...     verbose=True
    ... ) # doctest: +SKIP
    >>> print(f'Data: shape={data.shape}, dtype={data.dtype} and base_raster={base_raster}') # doctest: +SKIP

    References
    ==========

    [1] `Float16 Precision <https://github.com/numpy/numpy/issues/8063>`_

    """
    if verbose:
        ttprint(
            f"Reading {len(raster_files)} remote raster files using {n_jobs} workers"
        )

    args = [
        (raster_files, url_pos, bands, username, password, dtype, nodata)
        for url_pos in range(0, len(raster_files))
    ]

    raster_data = {}
    fn_base_raster = None

    for url_pos, data, ds_params in parallel.job(
        _read_auth_raster, args, n_jobs=n_jobs
    ):
        # Only keep results that come with their raster properties; anything
        # else is a failed read.
        if data is None or ds_params is None:
            continue

        raster_data[url_pos] = data

        if return_base_raster and fn_base_raster is None:
            # Reserve a file name and close the handle before rasterio writes
            # to it; a file that is still open cannot be reopened on every
            # platform.
            with tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as tmp:
                tmp_name = tmp.name

            with rasterio.open(
                tmp_name,
                "w",
                driver=ds_params["driver"],
                width=ds_params["width"],
                height=ds_params["height"],
                count=ds_params["count"],
                crs=ds_params["crs"],
                dtype=ds_params["dtype"],
                transform=ds_params["transform"],
            ) as ds:
                fn_base_raster = ds.name

    if not raster_data:
        raise ValueError(
            f"None of the {len(raster_files)} remote raster files could be read"
        )

    # Stack following the order of raster_files, not the completion order.
    raster_data = np.stack([raster_data[i] for i in sorted(raster_data)], axis=-1)

    if return_base_raster:
        if verbose:
            ttprint(f"The base raster is {fn_base_raster}")
        return raster_data, fn_base_raster

    return raster_data
def save_rasters(
    base_raster: str,
    raster_files: List,
    array,
    window: Window = None,
    bounds: Optional[list] = None,
    dtype: str = None,
    nodata=None,
    array_idx: Optional[List] = None,
    fit_in_dtype: bool = False,
    n_jobs: int = 8,
    on_each_outfile: Callable = None,
    verbose: bool = False,
):
    """
    Save a 3D array in multiple raster files using as reference one base raster.
    The last dimension is used to split the array in different rasters. GeoTIFF
    is the only output format supported. It always replaces the ``np.nan``
    value by the specified ``nodata``.

    :param base_raster: The base raster path used to retrieve the parameters
        ``(height, width, n_bands, crs, dtype, transform)`` for the new rasters.
    :param raster_files: A list containing the paths for the new raster. It
        creates the folder hierarchy if not exists.
    :param array: 3D data array.
    :param window: Save the data considering a spatial window, even if the
        ``base_rasters`` refers to a bigger area. For example, it's possible to
        have a base raster covering the whole Europe and save the data using a
        window that cover just part of Wageningen. By default is ``None``
        saving the raster data in position ``0, 0`` of the raster grid.
    :param bounds: Optional ``[left, bottom, right, top]`` in ``EPSG:4326``.
        When provided it is converted to the CRS of the base raster and
        replaces ``window``.
    :param dtype: Convert the data to a specific ``dtype`` before save it. By
        default is ``None`` using the same ``dtype`` from the base raster.
    :param nodata: Use the specified value as ``nodata`` for the new rasters.
        By default is ``None`` using the same ``nodata`` from the base raster.
    :param array_idx: Indexes, on the last dimension of ``array``, of the
        layers to save, one per raster file. By default all the layers.
    :param fit_in_dtype: If ``True`` the values outside of ``dtype`` range are
        truncated to the minimum and maximum representation. It's also change
        the minimum and maximum data values, if they exist, to avoid overlap
        with ``nodata`` (see the ``_fit_in_dtype`` function). For example, if
        ``dtype='uint8'`` and ``nodata=0``, all data values equal to ``0`` are
        re-scaled to ``1`` in the new rasters.
    :param n_jobs: Number of parallel jobs used to save the raster files.
    :param on_each_outfile: Optional callback invoked, in the calling process,
        with the path of each raster as soon as it is saved.
    :param verbose: Use ``True`` to print the saving progress.

    :returns: A list containing the path for new rasters.
    :rtype: List[Path]

    Examples
    ========

    >>> import rasterio
    >>> from skmap.io.base import read_rasters, save_rasters
    >>>
    >>> # skmap COG layers - NDVI seasons for 2019
    >>> raster_files = [
    ...     'http://s3.eu-central-1.wasabisys.com/skmap/lcv/lcv_ndvi_landsat.glad.ard_p50_30m_0..0cm_201903_skmap_epsg3035_v1.0.tif', # winter
    ...     'http://s3.eu-central-1.wasabisys.com/skmap/lcv/lcv_ndvi_landsat.glad.ard_p50_30m_0..0cm_201906_skmap_epsg3035_v1.0.tif', # spring
    ...     'http://s3.eu-central-1.wasabisys.com/skmap/lcv/lcv_ndvi_landsat.glad.ard_p50_30m_0..0cm_201909_skmap_epsg3035_v1.0.tif', # summer
    ...     'http://s3.eu-central-1.wasabisys.com/skmap/lcv/lcv_ndvi_landsat.glad.ard_p50_30m_0..0cm_201912_skmap_epsg3035_v1.0.tif'  # fall
    ... ]
    >>>
    >>> # Transform for the EPSG:3035
    >>> eu_transform = rasterio.open(raster_files[0]).transform # doctest: +SKIP
    >>> # Bounding box window over Wageningen, NL
    >>> window = rasterio.windows.from_bounds(left=4020659, bottom=3213544, right=4023659, top=3216544, transform=eu_transform) #doctest: +SKIP
    >>>
    >>> data = read_rasters(raster_files=raster_files, window=window, verbose=True) #doctest: +SKIP
    >>>
    >>> # Save in the current execution folder
    >>> raster_files = [
    ...     './lcv_ndvi_landsat.glad.ard_p50_30m_0..0cm_201903_wageningen_epsg3035_v1.0.tif',
    ...     './lcv_ndvi_landsat.glad.ard_p50_30m_0..0cm_201906_wageningen_epsg3035_v1.0.tif',
    ...     './lcv_ndvi_landsat.glad.ard_p50_30m_0..0cm_201909_wageningen_epsg3035_v1.0.tif',
    ...     './lcv_ndvi_landsat.glad.ard_p50_30m_0..0cm_201912_wageningen_epsg3035_v1.0.tif'
    ... ] # doctest: +SKIP
    >>>
    >>> save_rasters(raster_files[0], raster_files, data, window=window, verbose=True) #doctest: +SKIP

    """
    if isinstance(raster_files, str):
        raster_files = [raster_files]

    if array_idx is None or len(array_idx) == 0:
        array_idx = list(range(0, array.shape[-1]))

    if len(array_idx) != len(raster_files):
        raise Exception(
            f"The array shape {array.shape} is incompatible with the raster_files size {len(raster_files)}."
        )

    # Opening the base raster validates it up front; the handle is released as
    # soon as the optional bounds are converted.
    with rasterio.open(base_raster) as ds:
        if bounds is not None and len(bounds) == 4:
            # `import rasterio` alone does not guarantee the warp submodule
            # is loaded.
            from rasterio.warp import transform_geom

            bounds = shape(
                transform_geom(
                    src_crs="EPSG:4326",
                    dst_crs=ds.crs,
                    geom=box(*bounds),
                )
            ).bounds

            window = from_bounds(*bounds, ds.transform).round_lengths()
            if verbose:
                ttprint(f"Transform {bounds} into {window}")

    if verbose:
        ttprint(f"Saving {len(raster_files)} raster files using {n_jobs} workers")

    ref_array = ref_memmap(array)

    args = [
        (base_raster, raster_file, ref_array, i, window, dtype, nodata, fit_in_dtype)
        for raster_file, i in zip(raster_files, array_idx)
    ]

    out_files = []
    for out_raster in parallel.job(
        _save_raster,
        args,
        n_jobs=n_jobs,
        joblib_args={"backend": "loky", "return_as": "generator"},
    ):
        # The callback is not forwarded to the workers; it runs here as each
        # result arrives.
        if on_each_outfile is not None:
            on_each_outfile(out_raster)
        out_files.append(out_raster)

    return out_files
[docs] class RasterData(SKMapBase): PLACEHOLDER_DT = "{dt}" INTERVAL_DT_SEP = "_" GROUP_COL = "group" NAME_COL = "name" PATH_COL = "input_path" BAND_COL = "input_band" TEMPORAL_COL = "temporal" DT_COL = "date" START_DT_COL = "start_date" END_DT_COL = "end_date" TRANSFORM_SEP = "." def __init__( self, raster_files: Union[List, str, dict], raster_mask: str = None, raster_mask_val=np.nan, max_rasters: int = None, verbose=False, ) -> None: if isinstance(raster_files, str): raster_files = {"default": [raster_files]} elif isinstance(raster_files, list): raster_files = {"default": raster_files} self.raster_files = raster_files self.verbose = verbose self.raster_mask = raster_mask self.raster_mask_val = raster_mask_val rows = [] for group in raster_files.keys(): if isinstance(raster_files[group], str): rows.append([group, raster_files[group], 1, None, None]) else: for r in raster_files[group]: if isinstance(r, tuple): if len(r) == 2: rows.append([group, r[0], r[1], None, None]) elif len(r) == 4: rows.append([group, r[0], r[1], r[2], r[3]]) else: raise Exception( f"Wrong tuple size {len(r)}. Please provide 2 or 4 size tuple." 
) else: rows.append([group, r, 1, None, None]) self.info = DataFrame( rows, columns=[ RasterData.GROUP_COL, RasterData.PATH_COL, RasterData.BAND_COL, RasterData.START_DT_COL, RasterData.END_DT_COL, ], ) self.info[RasterData.TEMPORAL_COL] = self.info.apply( lambda r: RasterData.PLACEHOLDER_DT in str(r[RasterData.PATH_COL]), axis=1 ) self.info[RasterData.NAME_COL] = self.info.apply( lambda r: Path(str(r[RasterData.PATH_COL]).split("?")[0]).stem if not r[RasterData.TEMPORAL_COL] else None, axis=1, ) self.date_args = {} self._active_group = None has_date = ~self.info[RasterData.START_DT_COL].isnull().any() if has_date: self.info[RasterData.TEMPORAL_COL] = True for g in self.info[RasterData.GROUP_COL].unique(): self.date_args[g] = { "date_style": "interval", "date_format": "%Y%m%d", "ignore_29feb": True, } self.info.reset_index(drop=True, inplace=True) self.max_rasters = max_rasters

    def _new_info_row(
        self,
        raster_file: str,
        name: str,
        group: str = None,
        dates: list = [],
        date_format=None,
        date_style=None,
        ignore_29feb=True,
    ):
        """
        Build the dict describing one raster layer for ``self.info``.

        The row always carries path, name, group and band ``1``. When
        ``dates`` is non-empty and a date style is known, the temporal flag
        and the start/end dates are added as well. As a side effect the date
        arguments of ``group`` are registered in ``self.date_args``.

        :param raster_file: value stored in the path column.
        :param name: value stored in the name column.
        :param group: group name; ``None`` or any name containing
            ``"default"`` is stored as ``"default"``.
        :param dates: start and end date (``datetime`` or strings parsed with
            ``date_format``); only the first two items are read.
        :param date_format: ``strptime`` format for string dates.
        :param date_style: date style stored in ``self.date_args``.
        :param ignore_29feb: flag stored in ``self.date_args``.
        :returns: the row as a plain ``dict``.
        """
        # NOTE(review): ``dates`` has a mutable default; it is only read here.
        row = {}

        if group is None or "default" in group:
            group = "default"

        row[RasterData.PATH_COL] = raster_file
        row[RasterData.NAME_COL] = name
        row[RasterData.GROUP_COL] = group
        row[RasterData.BAND_COL] = 1

        if self._active_group is not None:
            # Missing date arguments are inherited from the active group.
            if date_style is None:
                date_style = self.date_args[self._active_group]["date_style"]
            if date_format is None:
                date_format = self.date_args[self._active_group]["date_format"]
            # NOTE(review): both groups now share the very same dict object.
            self.date_args[group] = self.date_args[self._active_group]
        else:
            self.date_args[group] = {
                "date_style": date_style,
                "date_format": date_format,
                "ignore_29feb": ignore_29feb,
            }

        if len(dates) > 0 and date_style is not None:
            row[RasterData.TEMPORAL_COL] = True
            dt1, dt2 = (dates[0], dates[1])
            if isinstance(dt1, str):
                dt1 = datetime.strptime(dt1, date_format)
            if isinstance(dt2, str):
                dt2 = datetime.strptime(dt2, date_format)
            row[RasterData.START_DT_COL] = dt1
            row[RasterData.END_DT_COL] = dt2
        else:
            # NOTE(review): repeats the assignments made above; the temporal
            # columns are simply left out of the row in this branch.
            row[RasterData.PATH_COL] = raster_file
            row[RasterData.NAME_COL] = name

        return row

    # NOTE(review): no ``self`` parameter and no decorator is visible here, so
    # this only works when called on the class (``RasterData.from_stac_items``).
    def from_stac_items(
        stac_items: List[Item],
        bands: List[str] = None,
        to_crs=rasterio.crs.CRS.from_epsg(4326),
        spatial_res=None,
        resamp_method="near",
        n_jobs: int = 10,
        verbose=False,
    ):
        """
        Create a ``RasterData`` from a list of STAC items.

        The asset hrefs of the requested bands are collected (each href once),
        passed to ``vrt_warp`` and the resulting VRT files are grouped per
        band, using the item datetime as both start and end date.

        :param stac_items: STAC items; the asset keys of the first item define
            the available bands.
        :param bands: asset keys to read. When ``None`` the first asset key of
            the first item is used.
        :param to_crs: target CRS, handed to ``vrt_warp`` as WKT.
        :param spatial_res: handed to ``vrt_warp`` as ``tr``.
        :param resamp_method: handed to ``vrt_warp`` as ``r_method``.
        :param n_jobs: NOTE(review): not used anywhere in this body.
        :param verbose: forwarded to ``RasterData``; also prints the band that
            was picked when ``bands`` is ``None``.
        :returns: a new ``RasterData`` with one group per band.
        """
        all_bands = list(stac_items[0].assets.keys())
        if bands is None:
            if verbose:
                ttprint(f"Reading band {all_bands[0]} from {all_bands}")
            bands = [all_bands[0]]

        stac_info = []
        stac_href = {}  # hrefs already collected, used only for de-duplication
        for i in stac_items:
            for band in i.assets.keys():
                if band in bands:
                    href = i.assets[band].href
                    if href not in stac_href:
                        stac_href[href] = False
                        stac_info.append(
                            {
                                "href": i.assets[band].href,
                                "band": band,
                                # timezone info is dropped from the item date
                                "date": i.datetime.replace(tzinfo=None),
                            }
                        )

        stac_info = pd.DataFrame(stac_info)

        raster_file, vrt_files = vrt_warp(
            stac_info["href"],
            dst_crs=to_crs.to_wkt(),
            tr=spatial_res,
            r_method=resamp_method,
            return_input_files=True,
        )

        # Attach each generated VRT to the href it was built from.
        vrt_info = pd.DataFrame({"href": raster_file, "vrt": vrt_files})
        stac_info = stac_info.merge(vrt_info, on="href", how="inner")

        groups = {}
        for g, row in stac_info.groupby("band"):
            if g not in groups:
                groups[g] = []
            # tuples of (path, band, start date, end date)
            groups[g] += [(v, 1, d, d) for v, d in zip(row["vrt"], row["date"])]

        return RasterData(groups, verbose=verbose)

    def _set_date(
        self,
        text,
        dt1,
        dt2,
        date_format=None,
        date_style=None,
        ignore_29feb=None,
        **kwargs,
    ):
        """
        Render ``text`` through ``_eval`` for the period ``dt1``-``dt2``.

        ``_eval`` receives ``kwargs`` merged with every local variable of this
        method, including ``dt``, the formatted date string built below.

        :param text: template, converted with ``str`` before evaluation.
        :param dt1: start date of the period.
        :param dt2: end date of the period.
        :param date_format: ``strftime`` format; defaults to the one of the
            active group.
        :param date_style: ``"start_date"``, ``"end_date"`` or anything else
            for ``start<INTERVAL_DT_SEP>end``; defaults to the style of the
            active group.
        :param ignore_29feb: when truthy and ``date_format`` contains ``%j``,
            both dates are shifted with ``relativedelta(leapdays=-1)``.
        :param kwargs: extra names exposed to ``_eval``.
        """
        # NOTE(review): unlike the two arguments below, ``ignore_29feb`` is not
        # resolved from ``self.date_args`` when it is ``None``.
        if "gr" in kwargs and "default" in kwargs.get("gr"):
            # The local ``gr`` takes precedence over ``kwargs['gr']`` in the
            # merged mapping passed to ``_eval``.
            gr = ""

        if date_format is None:
            date_format = self.date_args[self._active_group]["date_format"]
        if date_style is None:
            date_style = self.date_args[self._active_group]["date_style"]

        if ignore_29feb and "%j" in date_format:
            dt1 = dt1 + relativedelta(leapdays=-1)
            dt2 = dt2 + relativedelta(leapdays=-1)

        if date_style == "start_date":
            dt = f"{dt1.strftime(date_format)}"
        elif date_style == "end_date":
            dt = f"{dt2.strftime(date_format)}"
        else:
            dt = f"{dt1.strftime(date_format)}"
            dt += f"{RasterData.INTERVAL_DT_SEP}"
            dt += f"{dt2.strftime(date_format)}"

        return _eval(str(text), {**kwargs, **locals()})

    def timespan(
        self,
        start_date,
        end_date,
        date_unit,
        date_step,
        date_style: str = "interval",
        date_format: str = "%Y%m%d",
        ignore_29feb=True,
        group: [list, str] = [],
    ):
        """
        Expand the temporal rows of ``self.info`` over a range of dates.

        For every selected group, each temporal row is replaced by one row per
        period returned by ``date_range``, with its path rendered by
        ``_set_date``. Non-temporal rows are kept with empty dates. Names are
        recomputed from the path stem and the index of ``self.info`` is reset.

        :param start_date: forwarded to ``date_range``.
        :param end_date: forwarded to ``date_range``.
        :param date_unit: forwarded to ``date_range``.
        :param date_step: forwarded to ``date_range``.
        :param date_style: stored in ``self.date_args`` and used by
            ``_set_date``.
        :param date_format: stored in ``self.date_args`` and used by
            ``date_range`` and ``_set_date``.
        :param ignore_29feb: stored in ``self.date_args`` and used by
            ``date_range`` and ``_set_date``.
        :param group: group name(s) to expand; empty means all groups.
        :returns: ``self``.
        """
        # NOTE(review): ``group`` has a mutable default; it is only read here.
        if isinstance(group, str):
            group = [group]

        to_drop = []
        to_add = []

        for _group, ginfo in self.info.groupby(RasterData.GROUP_COL):
            if len(group) > 0 and _group not in group:
                continue

            self.date_args[_group] = {
                "date_style": date_style,
                "date_format": date_format,
                "ignore_29feb": ignore_29feb,
            }

            dates = date_range(
                start_date,
                end_date,
                date_unit=date_unit,
                date_step=date_step,
                date_format=date_format,
                ignore_29feb=ignore_29feb,
            )

            def fun(r):
                # Returns three parallel lists (paths, starts, ends) that are
                # exploded into individual rows below.
                if r[RasterData.TEMPORAL_COL]:
                    names, start, end = [], [], []
                    for dt1, dt2 in dates:
                        names.append(
                            self._set_date(
                                r[RasterData.PATH_COL],
                                dt1,
                                dt2,
                                date_format=date_format,
                                date_style=date_style,
                                ignore_29feb=ignore_29feb,
                            )
                        )
                        start.append(dt1)
                        end.append(dt2)
                    return Series([names, start, end])
                else:
                    return Series([[r[RasterData.PATH_COL]], [None], [None]])

            temporal_cols = [
                RasterData.PATH_COL,
                RasterData.START_DT_COL,
                RasterData.END_DT_COL,
            ]

            ginfo[temporal_cols] = ginfo.apply(fun, axis=1)
            ginfo = ginfo.explode(temporal_cols)
            ginfo[RasterData.NAME_COL] = ginfo.apply(
                lambda r: Path(r[RasterData.PATH_COL]).stem, axis=1
            )

            to_drop.append(ginfo.index)
            to_add.append(ginfo)

        # Swap the original rows of the processed groups for the expanded ones.
        for idx in to_drop:
            self.info = self.info.drop(index=idx)
        self.info = pd.concat([self.info] + to_add).reset_index(drop=True)

        return self

    # NOTE(review): a second ``_base_raster`` is defined further down in this
    # class body; as the later definition wins, this one is never called.
    def _base_raster(self) -> Optional[bool]:
        """
        Return ``True`` as soon as one path in ``self.info`` is reachable.

        Paths containing ``"http"`` are checked with an HTTP HEAD request
        (status 200); any other path is checked on the filesystem. ``None`` is
        returned implicitly when no path qualifies.
        """
        for filepath in list(self.info[RasterData.PATH_COL]):
            if "http" in filepath:
                res = requests.head(filepath)
                if res.status_code == 200:
                    return True
            else:
                if Path(filepath).exists():
                    return True

    def read(
        self,
        window: Window = None,
        bounds: list = None,
        dtype: str = "float32",
        expected_shape=None,
        overview: int = None,
        n_jobs: int = 4,
        scale: float = 1,
        gdal_opts: dict = {},
    ):
        """
        Read every raster listed in ``self.info`` into ``self.array``.

        ``window`` and ``bounds`` are remembered on the instance. When
        ``self.raster_mask`` is set, that raster is read first and converted
        into a boolean mask that is forwarded to ``read_rasters``.

        :param window: forwarded to ``read_rasters``.
        :param bounds: forwarded to ``read_rasters``.
        :param dtype: forwarded to ``read_rasters``.
        :param expected_shape: forwarded to ``read_rasters``.
        :param overview: forwarded to ``read_rasters``.
        :param n_jobs: forwarded to ``read_rasters``.
        :param scale: forwarded to ``read_rasters``.
        :param gdal_opts: forwarded to ``read_rasters``.
        :returns: ``self``.
        """
        # NOTE(review): ``gdal_opts`` has a mutable default; it is only
        # forwarded here.
        self.window = window
        self.bounds = bounds

        data_mask = None
        if self.raster_mask is not None:
            self._verbose(
                f"Masking {self.raster_mask_val} values considering {Path(self.raster_mask).name}"
            )
            data_mask = read_rasters(
                [self.raster_mask],
                window=window,
                overview=overview,
                gdal_opts=gdal_opts,
            )
            # NOTE(review): ``is np.nan`` is an identity test; a NaN that is
            # not the ``np.nan`` object itself takes the ``else`` branch.
            if self.raster_mask_val is np.nan:
                data_mask = np.logical_not(np.isnan(data_mask))
            else:
                data_mask = data_mask != self.raster_mask_val

        self.base_raster = self._base_raster()

        raster_files = []
        # FIXME: add supporting for band_list
        for band, rows in self.info.groupby(RasterData.BAND_COL):
            raster_files += [Path(r) for r in rows[RasterData.PATH_COL]]

        self._verbose(
            f"RasterData with {len(raster_files)} rasters"
            + f" and {len(self.info[RasterData.GROUP_COL].unique())} group(s)"
        )

        # NOTE(review): ``band`` is whatever the loop above left behind, i.e.
        # the last band value, and is applied to all files.
        self.array = read_rasters(
            raster_files,
            band=band,
            window=self.window,
            bounds=bounds,
            data_mask=data_mask,
            dtype=dtype,
            expected_shape=expected_shape,
            n_jobs=n_jobs,
            overview=overview,
            scale=scale,
            gdal_opts=gdal_opts,
            verbose=self.verbose,
            max_rasters=self.max_rasters,
        )

        self._verbose(f"Read array shape: {self.array.shape}")

        return self

    def run(
        self,
        process: SKMapRunner,
        group: [list, str] = [],
        outname: str = None,
        drop_input: bool = False,
    ):
        """
        Run a process on this ``RasterData``.

        ``SKMapGroupRunner`` instances are delegated to ``_group_run``. Any
        other process is called once with ``rdata=self`` and the info rows it
        returns are appended to ``self.info`` with fresh index values.

        :param process: the process to execute.
        :param group: group name(s) handed to ``_group_run`` and, when
            ``drop_input`` is set, to ``drop``.
        :param outname: forwarded to the process when not ``None``.
        :param drop_input: drop ``group`` once the process has finished.
        :returns: ``self``.
        """
        # NOTE(review): ``group`` has a mutable default; it is only read here.
        if isinstance(process, SKMapGroupRunner):
            self._group_run(process, group, outname)
        else:
            process_name = process.__class__.__name__
            start = time.time()
            self._verbose(f"Running {process_name}" + f" on {self.array.shape}")

            kwargs = {"rdata": self}
            if outname is not None:
                kwargs["outname"] = outname

            _, new_info = process.run(**kwargs)

            if new_info.shape[0] > 0:
                # Re-index the new rows so they continue after the existing ones.
                idx_offset = self._idx_offset()
                new_info.index = list(range(idx_offset, idx_offset + new_info.shape[0]))
                self.info = pd.concat([self.info, new_info])

            self._verbose(
                "Execution"
                + f" time for {process_name}: {(time.time() - start):.2f} segs"
            )

        if drop_input:
            self.drop(group)

        return self

    def _group_run(
        self,
        process: SKMapGroupRunner,
        group: [list, str] = [],
        outname: str = None,
    ):
        """
        Run a group-wise process on every eligible group.

        A group is eligible when the temporal flag of its first row equals
        ``process.temporal`` and, if ``group`` is non-empty, when it is listed
        there. The process is called once with all eligible groups and the
        info rows it returns are appended to ``self.info``.

        :param process: the group runner to execute.
        :param group: group name(s) to restrict the run to; empty means all.
        :param outname: forwarded to ``process.run``.
        :returns: ``self``.
        """
        if isinstance(group, str):
            group = [group]

        to_add_info = []
        group_list = []
        ginfo_list = []

        for _group, ginfo in self.info.groupby(RasterData.GROUP_COL):
            if ginfo[RasterData.TEMPORAL_COL].iloc[0] != process.temporal:
                self._verbose(
                    f"Skipping {process.__class__.__name__} for {_group} group"
                )
                continue

            if len(group) > 0 and _group not in group:
                continue

            expr_group = f'{RasterData.GROUP_COL} == "{_group}"'
            ginfo = self.info.query(expr_group)
            group_list.append(_group)
            ginfo_list.append(ginfo)

        process_name = process.__class__.__name__
        start = time.time()
        self._verbose(f"Running {process_name}" + f" {len(group_list)} groups")

        _, new_info = process.run(self, group_list, ginfo_list, outname)
        to_add_info.append(new_info)

        self._verbose(
            "Execution" + f" time for {process_name}: {(time.time() - start):.2f} segs"
        )

        self._active_group = None

        if len(to_add_info) > 0:
            new_info = pd.concat(to_add_info)
            idx_offset = self._idx_offset()
            new_info.index = list(range(idx_offset, idx_offset + new_info.shape[0]))
            self.info = pd.concat([self.info, new_info])

        return self

    def drop(self, group):
        """
        Remove the info rows of the given group(s).

        :param group: one group name or a list of them.
        :returns: ``self``.
        """
        # NOTE(review): only ``self.info`` is changed here; ``self.array`` is
        # left untouched although the log message mentions data.
        if isinstance(group, str):
            group = [group]

        self._verbose(f"Dropping data and info for groups: {group}")
        idx = self.info[self.info[RasterData.GROUP_COL].isin(group)].index
        self.info = self.info.drop(idx)

        return self

    def rename(self, groups: dict):
        """
        Rename groups and move their date arguments to the new names.

        :param groups: mapping of old group name to new group name. It is
            applied with ``Series.replace`` to both the group and the name
            column of ``self.info``.
        :returns: ``self``.
        """
        self.info[RasterData.GROUP_COL] = self.info[RasterData.GROUP_COL].replace(
            groups
        )
        self.info[RasterData.NAME_COL] = self.info[RasterData.NAME_COL].replace(groups)

        for old_group in groups.keys():
            new_group = groups[old_group]
            self.date_args[new_group] = self.date_args[old_group]
            del self.date_args[old_group]

        return self

    def filter_date(
        self,
        start_date,
        end_date=None,
        date_format="%Y-%m-%d",
        date_overlap=False,
        return_array=False,
        return_copy=True,
        return_idx=False,
    ):
        """
        Select the layers whose dates fall in a given period.

        :param start_date: lower limit, parsed with ``date_format``.
        :param end_date: optional upper limit, parsed with ``date_format``.
        :param date_format: format handed to ``to_datetime``.
        :param date_overlap: when true a row is kept if either its start or
            its end date satisfies the limit, instead of only the matching one.
        :param return_array: forwarded to ``_filter``.
        :param return_copy: forwarded to ``_filter``.
        :param return_idx: forwarded to ``_filter``.
        :returns: whatever ``_filter`` returns for the selected rows.
        """
        start_dt_col, end_dt_col = (RasterData.START_DT_COL, RasterData.END_DT_COL)
        info_main = self.info

        # With a single date column there is no end date column to compare.
        if RasterData.DT_COL in info_main.columns:
            start_dt_col, end_dt_col = (RasterData.DT_COL, None)

        if date_overlap:
            # NOTE(review): ``end_dt_col`` is ``None`` when the single date
            # column is present, so this lookup would use ``None`` as key.
            dt_mask = np.logical_or(
                info_main[start_dt_col] >= to_datetime(start_date, format=date_format),
                info_main[end_dt_col] >= to_datetime(start_date, format=date_format),
            )
        else:
            dt_mask = info_main[start_dt_col] >= to_datetime(
                start_date, format=date_format
            )

        if end_date is not None and end_dt_col is not None:
            if date_overlap:
                dt_mask_end = np.logical_or(
                    info_main[end_dt_col] <= to_datetime(end_date, format=date_format),
                    info_main[start_dt_col] <= to_datetime(end_date, format=date_format),
                )
            else:
                dt_mask_end = info_main[end_dt_col] <= to_datetime(
                    end_date, format=date_format
                )

            dt_mask = np.logical_and(dt_mask, dt_mask_end)

        return self._filter(
            info_main[dt_mask],
            return_array=return_array,
            return_copy=return_copy,
            return_idx=return_idx,
        )

    def filter_contains(
        self, text, return_array=False, return_copy=True, return_idx=False
    ):
        """
        Select the layers whose name contains ``text``.

        :param text: inserted as-is between double quotes in a
            ``str.contains`` query on the name column.
        :returns: whatever ``filter`` returns.
        """
        return self.filter(
            f'{self.NAME_COL}.str.contains("{text}")',
            return_array=return_array,
            return_copy=return_copy,
            return_idx=return_idx,
        )

    def filter(self, expr, return_array=False, return_copy=True, return_idx=False):
        """
        Select layers with a ``DataFrame.query`` expression on ``self.info``.

        :param expr: query expression.
        :returns: whatever ``_filter`` returns for the selected rows.
        """
        return self._filter(
            self.info.query(expr),
            return_array=return_array,
            return_copy=return_copy,
            return_idx=return_idx,
        )

    def _filter(
        self,
        info,
        return_info=False,
        return_array=False,
        return_copy=True,
        return_idx=False,
    ):
        """
        Apply an info selection and return it in the requested form.

        The flags are tested in this order, the first truthy one wins:
        ``return_idx`` (list of index values), ``return_array`` (slice of
        ``self.array`` along the third axis), ``return_info`` (the info rows),
        ``return_copy`` (shallow copy of ``self`` holding the selection). When
        none is set, ``self`` is modified in place and returned.

        :param info: rows of ``self.info`` to keep.
        """
        # Active filters
        if self._active_group is not None:
            info = info.query(f'{RasterData.GROUP_COL} == "{self._active_group}"')

        if return_idx:
            return list(info.index)
        elif return_array:
            return self.array[:, :, info.index]
        elif return_info:
            return info
        elif return_copy:
            # Shallow copy: attributes other than array/info stay shared.
            rdata = copy.copy(self)
            rdata.array = self.array[:, :, info.index]
            rdata.info = info
            return rdata
        else:
            self.array = self.array[:, :, info.index]
            self.info = info
            return self

    def _array(self):
        """Return the array restricted by the active group filter, if any."""
        return self._filter(self.info, return_array=True)

    def _info(self):
        """Return the info rows restricted by the active group filter, if any."""
        return self._filter(self.info, return_info=True)

    def _base_raster(self):
        """
        Return the first raster path in ``self.info`` that is available.

        Paths containing ``"http"`` must answer an HTTP HEAD request with
        status 200; other paths must be existing files.

        :raises Exception: when no path qualifies.
        """
        for _, row in self.info.iterrows():
            path = row[RasterData.PATH_COL]
            if "http" in str(path):
                res = requests.head(path)
                if res.status_code == 200:
                    return path
            elif os.path.isfile(path):
                return path

        raise Exception("No base raster is available.")

    def to_dir(
        self,
        out_dir: Union[Path, str],
        group_expr: str = None,
        dtype: str = None,
        nodata=None,
        fit_in_dtype: bool = False,
        n_jobs: int = 4,
        return_outfiles=False,
        on_each_outfile: Callable = None,
    ):
        """
        Save layers as ``<name>.tif`` files in ``out_dir``.

        :param out_dir: destination directory.
        :param group_expr: optional ``DataFrame.query`` expression selecting
            the rows of ``self.info`` to save.
        :param dtype: forwarded to ``save_rasters``.
        :param nodata: forwarded to ``save_rasters``.
        :param fit_in_dtype: forwarded to ``save_rasters``.
        :param n_jobs: forwarded to ``save_rasters``.
        :param return_outfiles: return the list of output paths instead of
            ``self``.
        :param on_each_outfile: forwarded to ``save_rasters``.
        :returns: ``self``, or the output paths when ``return_outfiles`` is
            set and there was something to save.
        """
        if isinstance(out_dir, str):
            out_dir = Path(out_dir)

        info = self.info
        if group_expr is not None:
            info = self.info.query(group_expr)

        if info.size == 0:
            # NOTE(review): ``self`` is returned here even when
            # ``return_outfiles`` is set.
            ttprint("No rasters to save. Double check group_expr arg.")
            return self

        base_raster = self._base_raster()

        outfiles = [
            out_dir.joinpath(f"{name}.tif") for name in list(info[RasterData.NAME_COL])
        ]

        self._verbose(f"Saving rasters in {out_dir}")

        save_rasters(
            base_raster,
            outfiles,
            self.array,
            array_idx=info.index,
            window=self.window,
            bounds=self.bounds,
            dtype=dtype,
            nodata=nodata,
            fit_in_dtype=fit_in_dtype,
            n_jobs=n_jobs,
            on_each_outfile=on_each_outfile,
            verbose=self.verbose,
        )

        if return_outfiles:
            return outfiles
        else:
            return self

    def to_s3(
        self,
        host: Union[str, list],
        access_key: str,
        secret_key: str,
        path: str,
        secure: bool = True,
        tmp_dir: str = None,
        group_expr: str = None,
        dtype: str = None,
        nodata=None,
        fit_in_dtype: bool = False,
        n_jobs: int = None,
        verbose_cp=False,
    ):
        """
        Save layers to a temporary directory and upload them with Minio.

        Each file written by ``to_dir`` is uploaded to
        ``<bucket>/<prefix>/<file name>`` and then removed locally.

        :param host: one host or a list of hosts. With a list, the host of a
            file is picked from the bytes of its stem, modulo the list length.
        :param access_key: Minio access key.
        :param secret_key: Minio secret key.
        :param path: ``bucket/prefix``; the part before the first ``/`` is the
            bucket, the remainder the object prefix.
        :param secure: forwarded to ``Minio``.
        :param tmp_dir: local staging directory; a temporary one is created
            when ``None``.
        :param group_expr: forwarded to ``to_dir``.
        :param dtype: forwarded to ``to_dir``.
        :param nodata: forwarded to ``to_dir``.
        :param fit_in_dtype: forwarded to ``to_dir``.
        :param n_jobs: forwarded to ``to_dir``.
        :param verbose_cp: print a message for every uploaded file.
        :returns: ``self``.
        """
        bucket = path.split("/")[0]
        prefix = "/".join(path.split("/")[1:])

        if tmp_dir is None:
            # NOTE(review): the TemporaryDirectory object is not kept; only
            # its name is reused as the staging path.
            tmp_dir = Path(tempfile.TemporaryDirectory().name)
            tmp_dir = tmp_dir.joinpath(prefix)

        def _to_s3(outfile) -> None:
            # Callback run by ``to_dir`` for every written file.
            _host = host
            if isinstance(host, list):
                ih = int.from_bytes(str(outfile.stem).encode(), "little") % len(host)
                _host = host[ih]

            client = Minio(_host, access_key, secret_key, secure=secure)
            name = f"{outfile.name}"

            if verbose_cp:
                ttprint(f"Copying {outfile} to http://{host}/{bucket}/{prefix}/{name}")

            client.fput_object(bucket, f"{prefix}/{name}", outfile)
            os.remove(outfile)

        outfiles = self.to_dir(
            tmp_dir,
            group_expr=group_expr,
            dtype=dtype,
            nodata=nodata,
            fit_in_dtype=fit_in_dtype,
            n_jobs=n_jobs,
            return_outfiles=True,
            on_each_outfile=_to_s3,
        )

        # NOTE(review): the messages format ``host`` itself, so a list of
        # hosts is printed as a list inside the URL.
        name = outfiles[len(outfiles) - 1].name
        last_url = f"http://{host}/{bucket}/{prefix}/{name}"

        self._verbose(f"{len(outfiles)} rasters copied to s3")
        self._verbose(f"Last raster in s3: {last_url}")

        return self

    def __del__(self) -> None:
        """Print a message and release ``self.array`` through ``del_memmap``."""
        # NOTE(review): raises ``AttributeError`` if ``self.array`` was never
        # set on the instance.
        print("Deleting")
        del_memmap(self.array)

    # NOTE(review): the ``with`` statement calls ``__exit__`` with three
    # exception arguments, which this signature does not accept; no
    # ``__enter__`` is visible in this part of the file.
    def __exit__(self):
        """Release the array by delegating to ``__del__``."""
        self.__del__()

    def _get_titles(self, img_title, bands):
        """
        Build one title per layer of the given group(s).

        :param img_title: ``"date"`` for ``start - end``, ``"index"`` for the
            position of the layer, ``"name"`` for the layer name; anything
            else yields empty titles.
        :param bands: value inserted in the ``group==`` query expression.
        :returns: list of titles.
        """
        f_arr = self.filter(f"group=={bands}")

        if img_title == "date":
            titles = list(
                f_arr.info["start_date"].astype(str)
                + " - "
                + f_arr.info["end_date"].astype(str)
            )
        elif img_title == "index":
            titles = [str(i) for i in range(f_arr.info.shape[0])]
        elif img_title == "name":
            titles = f_arr.info["name"].to_list()
        else:
            titles = [""] * f_arr.info.shape[0]

        return titles
[docs] def point_query( self, x: list, y: list, cols: int = 3, titles: list = None, label_xaxis: str = "index", return_data: bool = False, ): """ Makes point queries on dataset and provide plots and data :param x: longitude value(s) of the given point(s) :param y: latitude value(s) of the given point(s) :param cols: column count of the desired layout. Default is 3. :param titles: list of the titles that will be placed on top of the each graph :param label_xaxis: labels of the x axes. it could be `index`, `name`,`date` or None. :param return_data: If the user wants to access the data sampled from rasters, this needs to be set to True. Default is False Examples ======== >>> import geopandas as gpd >>> from skmap.data import toy >>> rasterdata = toy.ndvi_rdata(gappy=False) #doctest: +SKIP >>> points = gpd.read_file('./skmap/data/toy/samples/samples.gpkg') >>> rasterdata.point_query(x=points.geometry.x.to_list(), y=points.geometry.y.to_list() , label_xaxis='index', cols=3, titles=points.label) #doctest: +SKIP """ df = pd.DataFrame() df["x"], df["y"], df["title"] = x, y, titles bbox = rasterio.open(self._base_raster()).bounds # filtering points based on the bounds of the base raster df = df[ (bbox.left <= df["x"]) & (df["x"] <= bbox.right) & (bbox.bottom <= df["y"]) & (df["y"] <= bbox.top) ] with rasterio.open(self._base_raster()) as src: row_id, col_id = rasterio.transform.rowcol(src.transform, df.x, df.y) df["data"] = np.array(self.array[row_id, col_id]).tolist() # if data is required no need to create figures if return_data: return df.data.to_numpy() labels_x = self._get_titles(label_xaxis) fig, axs = pyplot.subplots( ncols=cols, nrows=math.ceil(len(x) / cols), figsize=(6 * cols, 2 * math.ceil(len(x) / cols)), sharex=True, sharey=True, ) mgc = df.shape[0] # maximum graph count for i, ax in enumerate(axs.flatten()): if i < mgc: ax.plot(labels_x, df.data[i], "-o", markersize=4, color="blue", lw=1) ax.set_title(titles[i], fontsize=10) ax.tick_params(axis="x", 
rotation=90) else: ax.axis("off") pyplot.tight_layout() pyplot.close() return fig
def _vminmax(self, vmm, arr): """ To check and calculate the boundaries of the data. If the bounds are supplied it will return it, If not function will return the 1 and 99% of the data as bounds. :param vmm: supplied min/max bounds of data :param arr: the data will be used to generate a image """ if vmm[0]: return vmm return np.nanquantile(arr.flatten(), [0.02, 0.98]) def _op_io(self, figure): """ converts figure to image and ascii representation of it to use with HTML embeded animation. :param figure: matplotlib figure object """ buffer = BytesIO() figure.savefig(buffer, format="png", bbox_inches="tight") img64 = encodebytes(buffer.getvalue()).decode("ascii") return img64 def _percent_clip(self, arr): """ To calculate and scale the band upper and lower limits to generate a composite image from 3 bands. returns the scaled data :param arr: the data usually single band data in np.array format. """ return (arr - np.nanpercentile(arr, 1)) / ( np.nanpercentile(arr, 99) - np.nanpercentile(arr, 1) ) def _mutate_baseshot(self, img, arr, titletext, textfontsize): """ takes imshow generated mock image copies it and replaces the nested image data. :param img: the mock image, generated with pyplot.imshow :param arr: the scaled data for the frame :param title_params: it is a dict. It will be used for the title generation on the frame. 
""" c_img = deepcopy(img) c_img.set_data(arr) if titletext: c_img._axes.set_title( label=titletext, fontdict=dict(fontsize=textfontsize), pad=1 ) return c_img.get_figure() def _gen_baseshot( self, arr, scaling: int = 1, img_style: dict = None, cbar_props: dict = None, composite=False, ): # base figure with predefined style # no axis labels tick_params = dict(left=False, labelleft=False, labelbottom=False, bottom=False) # scaling the figsize based on the passed array shape # the base figsize is 3.15 inc = 8cm almost half short side of a A4 page fig, ax = pyplot.subplots() rc, cc = arr.shape[:2] fig.set_size_inches(scaling * 3.15, scaling * 3.15 * rc / cc) # generation of basedata based on the array shape basedata = np.zeros(rc * cc).reshape(rc, cc) if composite: basedata = np.zeros(rc * cc * 3).reshape(rc, cc, 3) # crafting the base image ax.tick_params(**tick_params) ax.margins(x=0) if img_style: baseimg = ax.imshow( basedata, **img_style ) # img_style = dict(vmin=, vmax=, cmap=) else: baseimg = ax.imshow(basedata) # if there will be a colorbar there will be a colorbar if cbar_props: # cbar_props is dict(label='text') div = make_axes_locatable(ax) pyplot.colorbar( baseimg, orientation="horizontal", label=cbar_props["label"], cax=div.append_axes("bottom", size="3%", pad=0.05), ) pyplot.close() return baseimg def _band_manage(self, groups): """ to structure the band(s) data based on the provided band information. either single or multiple band data. :params groups: list of band names. 
""" if len(groups) == 1: # single band raster arr = self.filter(f"group=={groups}").array elif len(groups) == 3: # composite arr = [] band1 = self.filter(f"group=={groups}[0]", return_array=True) band2 = self.filter(f"group=={groups}[1]", return_array=True) band3 = self.filter(f"group=={groups}[2]", return_array=True) alpha = np.ones(band3.shape) mask = np.any(np.isnan(np.stack([band1, band2, band3], axis=-1)), axis=-1) alpha[mask] = 0 for i in range(band1.shape[2]): arr.append( np.stack( [ np.clip(self._percent_clip(band1[:, :, i]), 0, 1), np.clip(self._percent_clip(band2[:, :, i]), 0, 1), np.clip(self._percent_clip(band3[:, :, i]), 0, 1), alpha[:, :, i], ], axis=2, ) ) else: raise Exception("""The band count should either be one or three. Current plotting capabilites are limited to single or composite image generation.""") return arr
[docs] def plot( self, groups: list = None, cmap: str = "Spectral_r", cbar_title: str = None, img_title_text: str or list = "index", img_title_fontsize: int = 10, vminmax: tuple = (None, None), to_img: str = None, dpi: int = 100, layout_col: int = 4, ): """ Generates a grid plot to view and save with a colorscale with a desired layout. :param cmap : This sets the colorscale with given matplotlib.colormap. Default is Spectral_r :param cbar_title : This sets the colorbar title if the cbar exists in the plot. Default is None. :param img_title_text : This sets the image titles that will be display on top of the each image. Default is `index`. :param img_ltitle_fontsize : This sets the fontsize of the image label which will be on top of the image. Default is 10. :param v_minmax : This sets the loower and upper limits of the data that will be plot and the colorbar. Default is None and will be calculated on he fly. :param groups : This used for to generate composite plot. Pass one or tree group names (groups) which will be used to generate. Default is None. :param to_img : This sets the directory adn the format of the file where the generated image will be saved. Default is None. :param dpi : dot per inch value to save the figure. If the `to_img` param provided :param layout_col : This controls the column count that will be used in the grid plot. Default is 3. 
""" if not groups: groups = [self.info.group.to_list()[0]] arr = self._band_manage(groups=groups) if isinstance(img_title_text, str): img_title_text = self._get_titles(img_title_text, groups) if len(groups) == 3: img_cnt = len(arr) composite = True baseimg = self._gen_baseshot(arr[0][:, :, 0]) elif len(groups) == 1: img_cnt = arr.shape[2] composite = False vminmax = self._vminmax(vminmax, arr) baseimg = self._gen_baseshot(arr[:, :, 0]) if img_cnt < layout_col: layout_col = img_cnt layout_row = math.ceil(img_cnt / layout_col) set_h = baseimg.get_size()[0] / baseimg.get_figure()._dpi set_w = baseimg.get_size()[1] / baseimg.get_figure()._dpi if set_w > set_h: set_w = set_w * 3.15 / set_h set_h = 3.15 else: set_h = set_h * 3.15 / set_w set_w = 3.15 grd_fig, grd_axs = pyplot.subplots( nrows=layout_row, ncols=layout_col, gridspec_kw=dict(wspace=0.1, hspace=0.1), figsize=( set_w * layout_col + (layout_col - 1) * 0.1, set_h * layout_row + (layout_row - 1) * 0.1, # + 1 ), ) def _preprocess(arr_, ind, composite): if composite: return np.flipud(arr_[ind]) else: return np.flipud(arr_[:, :, ind]) matrix_params = dict(vmin=vminmax[0], vmax=vminmax[1], cmap=cmap) def gen_pane( ind, arr, ax, composite, matrix_params, img_title_text, img_title_fontsize ) -> None: try: ax.pcolorfast( _preprocess(arr, ind, composite=composite), **matrix_params ) ax.set_title(img_title_text[ind], fontsize=img_title_fontsize, pad=1) ax.tick_params( left=False, bottom=False, labelleft=False, labelbottom=False ) except IndexError: ax.axis("off") try: for i, ax in enumerate(grd_axs.flatten()): gen_pane( i, arr, ax, composite, matrix_params, img_title_text, img_title_fontsize, ) except AttributeError: gen_pane( 0, arr, grd_axs, composite, matrix_params, img_title_text, img_title_fontsize, ) pyplot.close() if not composite: grd_fig.subplots_adjust(left=0, right=1, bottom=0, top=1) w, h = grd_fig.get_size_inches() cbar_ax = grd_fig.add_axes([0.1, 1 + (3.15 / h) * 0.1, 0.8, 0.16 / h]) cbar_ax = 
grd_fig.colorbar( pyplot.imshow( arr[:, :, 0], vmin=vminmax[0], vmax=vminmax[1], cmap=cmap ), orientation="horizontal", cax=cbar_ax, ticklocation="top", ).set_label(label=cbar_title) pyplot.tight_layout() pyplot.close() if to_img: grd_fig.savefig( to_img, format=f"{to_img.split('.')[-1]}", dpi=dpi, bbox_inches="tight" ) return grd_fig
def _idx_offset(self): return self.info.index.max() + 1
[docs] def animate( self, cmap: str = "Spectral_r", groups: list = None, scaling: float = 2, cbar_title: str = None, img_title_text: str or list = "index", img_title_fontsize: int = 10, vminmax: tuple = (None, None), interval: int = 250, to_gif: str = None, n_jobs: int = 4, ): """ Generates an animation with the given band(s) and saves it. :param cmap: colormap name that will derived from `matplotlib.colormaps()` :param groups: this is used for to select the band(s) or to generate a composite images, that will be used as animation frame. Default is None but it will select the first band on RasterData. :param scaling: scaling can be used to increase/decrease the frame size. Default is 2. :param cbar_title: """ if not groups: groups = [self.info.group.to_list()[0]] arr = self._band_manage(groups=groups) if isinstance(img_title_text, str): img_title_text = self._get_titles(img_title_text, groups) if len(groups) == 3: img_cnt = len(arr) baseimg = self._gen_baseshot( arr=arr[0][:, :, 0], scaling=scaling, composite=True ) args = [ (baseimg, arr[i], img_title_text[i], img_title_fontsize) for i in range(img_cnt) ] elif len(groups) == 1: img_cnt = arr.shape[2] vminmax = self._vminmax(vminmax, arr) baseimg = self._gen_baseshot( arr=arr[:, :, 0], scaling=scaling, img_style=dict(vmin=vminmax[0], vmax=vminmax[1], cmap=cmap), cbar_props=dict(label=cbar_title), composite=False, ) args = [ (baseimg, arr[:, :, i], img_title_text[i], img_title_fontsize) for i in range(img_cnt) ] int_fig = [f for f in parallel.job(self._mutate_baseshot, args, n_jobs=n_jobs)] int_img = [ j for j in parallel.job( self._op_io, [(fig,) for fig in int_fig], n_jobs=n_jobs ) ] if to_gif is not None: with ExitStack() as stack: imgs = ( stack.enter_context(Image.open(BytesIO(b64decode(img)))) for img in int_img ) img = next(imgs) img.save( to_gif, format="GIF", append_images=imgs, save_all=True, duration=interval, loop=0, ) template = ' frames[{0}] = "data:image/{1};base64,{2}"\n' embedded_frames = "\n" + 
"".join( template.format(i, "png", imgdata.replace("\n", "\\\n")) for i, imgdata in enumerate(int_img) ) mode_dict = dict(once_checked="", loop_checked="checked", reflect_checked="") with TemporaryDirectory() as tmpdir: path = Path(tmpdir, "temp.html") with open(path, "w") as of: of.write(JS_INCLUDE + STYLE_INCLUDE) of.write( DISPLAY_TEMPLATE.format( id=uuid4().hex, Nframes=img_cnt, fill_frames=embedded_frames, interval=interval, **mode_dict, ) ) html_rep = path.read_text() return HTML(html_rep)