# Copyright 2020 Arthur Coqué, Pôle OFB-INRAE ECLA, UR RECOVER
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module gathers classes related to L3 products.
In this module are defined L3 products : the abstract class L3Product and its
child classes L3AlgoProduct and L3MaskProduct. Basically, these classes wrap
a well formated xarray Dataset and offer a few useful methods.
Example::
S2_ndvi = L3AlgoProduct(dataset)
S2_ndvi.plot()
S2_ndvi.save(path_to_file)
s2cloudless_mask = L3MaskProduct(dataset)
s2cloudless_mask.save(path_to_file)
"""
import warnings
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union
import numpy as np
import xarray as xr
from matplotlib import pyplot as plt
from sisppeo.utils.exceptions import InputError
from sisppeo.utils.products import CoordinatesMixin, get_enc
warnings.filterwarnings('ignore', category=xr.SerializationWarning)
# pylint: disable=invalid-name
# Ok for a custom type.
N = Union[int, float]
P = List[Union[Path, str]]
[docs]@dataclass
class L3Product(ABC, CoordinatesMixin):
"""Abstract class inherited by both L3AlgoProduct and L3MaskProduct.
Attributes:
dataset: A dataset containing processed data.
"""
__slots__ = 'dataset',
dataset: xr.Dataset
[docs] @classmethod
@abstractmethod
def from_file(cls, filename: Union[str, Path]):
"""Loads and returns a L3Product from file.
Args:
filename: The path to the L3Product (saved as a netCDF file).
"""
@property
def title(self) -> str:
"""Returns the title of the underlying dataset."""
return self.dataset.attrs['title']
@property
def product_type(self) -> str:
"""Returns the product_type of the product used to get this dataset."""
return self.dataset.attrs['title'].rsplit(' ', 1)[1]
@property
def data_vars(self):
"""Returns a list of DataArrays corresponding to variables."""
return [data_var for data_var in self.dataset.data_vars
if data_var not in ('crs', 'product_metadata')]
[docs] def plot(self, data_var) -> None:
"""Plots a given variable.
data_var: The name of the variable/DataArray of interest (e.g., a
band, aCDOM, etc).
"""
if data_var not in self.data_vars:
msg = (f'"{data_var}" is not a variable of this product; please, '
f'choose one from the following list: {self.data_vars}.')
raise InputError(msg)
self.dataset[data_var].plot()
plt.show()
[docs] @abstractmethod
def save(self, filename: P) -> None:
"""Saves this product into a netCDF file.
Args:
filename: Path of the output file.
"""
[docs]@dataclass
class L3AlgoProduct(L3Product):
"""An L3Product embedding data obtained by using a wc/land algorithm.
Attributes:
dataset: A dataset containing processed data.
"""
__slots__ = ()
[docs] @classmethod
def from_file(cls, filename):
return L3AlgoProduct(xr.open_dataset(filename))
@property
def algo(self) -> str:
"""Returns the name of the algorithm used to get this dataset."""
return self.title.split(' ', 1)[0]
[docs] def save(self, filename: P) -> None:
"""See base class."""
enc = {data_var: get_enc(self.dataset[data_var].values, 0.001, True)
for data_var in self.data_vars}
enc.update({
'crs': {'dtype': 'byte'},
'product_metadata': {'dtype': 'byte'},
'x': get_enc(self.dataset.x.values, 0.1),
'y': get_enc(self.dataset.y.values, 0.1)
})
self.dataset.to_netcdf(filename, encoding=enc)
[docs]@dataclass
class L3MaskProduct(L3Product):
"""An L3Product embedding data obtained by using a mask algorithm.
Attributes:
dataset: A dataset containing processed data.
"""
__slots__ = ()
[docs] @classmethod
def from_file(cls, filename):
return L3MaskProduct(xr.open_dataset(filename))
@property
def mask(self):
"""Returns the name of the mask used to get this dataset."""
return self.title.split(' ', 1)[0]
[docs] def save(self, filename: Union[Path, str]) -> None:
"""See base class."""
self.dataset.to_netcdf(filename, encoding={
self.mask: {'dtype': 'bool'},
'crs': {'dtype': 'byte'},
'product_metadata': {'dtype': 'byte'},
'x': {'dtype': 'int32'},
'y': {'dtype': 'int32'}
})
[docs]def mask_product(l3_algo: L3AlgoProduct,
l3_masks: Union[L3MaskProduct, List[L3MaskProduct]],
lst_mask_type: Union[str, List[str]],
inplace=False) -> Optional[L3AlgoProduct]:
"""Masks an L3AlgoProduct.
Masks an L3AlgoProduct with one or more L3MaskProducts. It can be used for
instance to get rid of clouds or to extract only water areas.
Args:
l3_algo: The L3AlgoProduct to be masked.
l3_masks: The mask or list of masks to use.
lst_mask_type: The type of the input mask (or the list of the types of
input masks). Can either be 'IN' or 'OUT', indicating if the
corresponding mask is inclusive or exclusive.
inplace: If True, do operation inplace and return None.
Returns:
A masked L3AlgoProduct.
"""
if not inplace:
l3_algo = deepcopy(l3_algo)
if isinstance(l3_masks, L3MaskProduct):
l3_masks = [l3_masks]
if isinstance(lst_mask_type, str):
lst_mask_type = [lst_mask_type]
# Get the shared bounding box (i.e. intersection)
x_min = max([l3_mask.x.values.min() for l3_mask in l3_masks]
+ [l3_algo.x.values.min()])
x_max = min([l3_mask.x.values.max() for l3_mask in l3_masks]
+ [l3_algo.x.values.max()])
y_min = max([l3_mask.y.values.min() for l3_mask in l3_masks]
+ [l3_algo.y.values.min()])
y_max = min([l3_mask.y.values.max() for l3_mask in l3_masks]
+ [l3_algo.y.values.max()])
# Clip masks with the previous bounding box
arr_masks = [l3_mask.dataset[l3_mask.mask].sel(
x=slice(x_min, x_max), y=slice(y_max, y_min)).values
for l3_mask in l3_masks]
# Merge 'IN' masks (<=> what to include)
idx_in = [i for i, mask_type in enumerate(lst_mask_type)
if mask_type.upper() == 'IN']
mask_in = np.sum([arr_masks[i] for i in idx_in], axis=0)
# Merge 'OUT' masks (<=> what to exclude)
idx_out = [i for i, mask_type in enumerate(lst_mask_type)
if mask_type.upper() == 'OUT']
mask_out = np.sum([arr_masks[i] for i in idx_out], axis=0)
# Create the final mask
if not idx_in:
mask = np.where(mask_out == 0, True, False)
elif not idx_out:
mask = np.where(mask_in > 0, True, False)
else:
mask = np.where((mask_in > 0) & (mask_out == 0), True, False)
# Apply the previously computed mask to the product
l3_algo.dataset = l3_algo.dataset.sel(x=slice(x_min, x_max),
y=slice(y_max, y_min))
for var in l3_algo.data_vars:
l3_algo.dataset[var].values = np.where(
mask, l3_algo.dataset[var].values, np.nan
)
# Store masks' names
masks = []
dico = {'s2cloudless': 'cloudmask', 'waterdetect': 'watermask'}
for l3_mask, mask_type in zip(l3_masks, lst_mask_type):
if l3_mask.mask in ('s2cloudless', 'waterdetect'):
version = l3_mask.dataset[l3_mask.mask].attrs['version']
l3_algo.dataset.attrs[dico[l3_mask.mask]] = f'{l3_mask.mask} ({version}) [{mask_type}]'
else:
masks.append(f'{l3_mask.mask} [{mask_type}]')
if masks:
l3_algo.dataset.attrs['masks'] = masks
if not inplace:
return l3_algo
return None