Source code for autopycoin.models.training

"""
Subclasses of the Keras `Model` class (base, quantile and univariate models).
"""

import math
from typing import List, Union, Callable, Any
import itertools

import keras

import tensorflow.compat.v2 as tf
from keras.losses import LossFunctionWrapper

from ..utils import convert_to_list, quantiles_handler
from ..layers.base_layer import BaseLayer, QuantileLayer, UnivariateLayer
from ..constant import TENSOR_TYPE


class BaseModel(keras.Model, BaseLayer):
    """Base model which defines pre/post-processing methods to override.

    This model is meant to be inherited and brings the following features:

    - preprocessing : preprocess the inputs data.
    - post_processing : postprocess the outputs data.
    - init_params : initialize parameters before the `build` method.
    - metrics_wrapper : preprocess y_true or y_pred before metrics are computed.
    - losses_wrapper : preprocess y_true or y_pred before losses are computed.
    - Type checking.

    `metrics_wrapper` and `losses_wrapper` have to be overridden.
    """

    def __init__(self, *args: list, **kwargs: dict) -> None:
        super().__init__(*args, **kwargs)

    def handle_dim_in_losses_and_metrics(self, outputs: TENSOR_TYPE) -> None:
        """Build and wrap the compiled losses and metrics.

        When a compiled container (losses or metrics) is not built yet, it is
        built from `outputs` and each of its elements is passed through
        `losses_wrapper` (resp. `metrics_wrapper`). Wrapping happens only at
        that moment, so wrappers are never stacked on subsequent calls.

        Parameters
        ----------
        outputs : TENSOR_TYPE
            Model outputs used to build the compiled containers.
        """
        if self.compiled_loss and not self.compiled_loss.built:
            self.compiled_loss.build(outputs)
            self.compiled_loss._losses = tf.nest.map_structure(
                self.losses_wrapper, self.compiled_loss._losses
            )
        if self.compiled_metrics and not self.compiled_metrics.built:
            # `outputs` is passed both as y_pred and y_true to build the metrics.
            self.compiled_metrics.build(outputs, outputs)
            self.compiled_metrics._metrics = tf.nest.map_structure(
                self.metrics_wrapper, self.compiled_metrics._metrics
            )

    def losses_wrapper(
        self, loss: LossFunctionWrapper
    ) -> Union[Callable, LossFunctionWrapper]:
        """Wrap the `fn` function of a compiled loss.

        See `tf.keras.losses.LossFunctionWrapper` docstring for more
        information about `fn`.

        Raises
        ------
        NotImplementedError
            Always; subclasses have to override this method.
        """
        raise NotImplementedError("`losses_wrapper` has to be overriden")

    def metrics_wrapper(self, metrics: Any) -> Union[Callable, LossFunctionWrapper]:
        """Wrap the `update_state` function of a compiled metric.

        See `tf.keras.metrics.Metric` docstring for more information about
        `update_state`.

        Raises
        ------
        NotImplementedError
            Always; subclasses have to override this method.
        """
        raise NotImplementedError("`metrics_wrapper` has to be overriden")

    def _post_processing_wrapper(self, outputs: TENSOR_TYPE) -> TENSOR_TYPE:
        """Apply `post_processing` to every output with its associated loss(es).

        It handles the cases:

        - as many outputs as (nested) compiled losses: each output is
          post-processed with its own loss;
        - otherwise (a single output, a number of outputs different from the
          number of losses, or no loss at all): every output is
          post-processed with all the compiled losses (possibly None).

        Parameters
        ----------
        outputs : TENSOR_TYPE
            A single output or a nested structure of outputs.

        Returns
        -------
        TENSOR_TYPE
            The post-processed output when there is only one, else a tuple of
            post-processed outputs.
        """
        self.handle_dim_in_losses_and_metrics(outputs)

        losses = getattr(self.compiled_loss, "_losses", None)
        losses_is_nested = tf.nest.is_nested(losses)
        outputs_is_nested = tf.nest.is_nested(outputs)
        outputs = tuple(outputs) if outputs_is_nested else (outputs,)

        if losses_is_nested and len(outputs) == len(losses):
            # Multi outputs == multi losses: one loss per output.
            outputs = tf.nest.map_structure(
                lambda output, loss: self.post_processing(output, losses=loss),
                outputs,
                tuple(losses),
            )
        else:
            # Multi outputs != multi losses, or no losses: share all losses.
            outputs = tf.nest.map_structure(
                lambda output: self.post_processing(output, losses=losses), outputs
            )
        return outputs[0] if len(outputs) == 1 else outputs
class QuantileModel(BaseModel, QuantileLayer):  # pylint: disable=abstract-method
    """Overloads the Keras `Model` class to integrate a `quantiles` attribute.

    During the compiling phase, the model checks whether each loss function
    defines a `quantiles` attribute. If so, the model defines several
    attributes based on the `quantiles` found in the loss functions and
    propagates them to its sublayers. Be careful: in that case the model is
    no longer built.

    During the first call, all compiled losses and metrics are built and a
    second check is performed to ensure that each output is not associated
    with different `quantiles` values, otherwise a ValueError is raised.

    When subclassing this model, pre/post-processing methods can be defined.
    A `post_processing` method is already defined in order to transpose the
    `quantiles` dimension.

    See :class:`autopycoin.layers.QuantileLayer` for more information on how
    to access the `quantiles` dimension in the building phase.

    Attributes
    ----------
    has_quantiles : bool
        True if `quantiles` is not None else False. It is defined during the
        `compile` method. Defaults to False.
    quantiles : list[List[float]] or None
        The quantiles used in the model. `quantiles` is a list of lists
        depending on the number of outputs the model computes. It is defined
        during the `compile` method. Defaults to None.
    n_quantiles : list[int] or int
        The number of quantiles the model computes. It is defined during the
        `compile` method. Defaults to 0.
    """

    NOT_INSPECT = ['_check_quantiles_requirements', 'call', 'build', 'compile']

    def __init__(
        self, apply_quantiles_transpose: bool = True, *args: list, **kwargs: dict
    ) -> None:
        super().__init__(*args, **kwargs)
        self._has_quantiles = False
        self._quantiles = None
        self._n_quantiles = 0
        self._additional_shapes = [[]]
        self.apply_quantiles_transpose = apply_quantiles_transpose

    def compile(
        self,
        optimizer="rmsprop",
        loss=None,  # TODO: multiple loss one output (qloss, mse) -> leads to mse loss over estimated due to the quantiles -> raise an error? use a wrapper to select only the 0.5 quantile?
        metrics=None,
        loss_weights=None,
        weighted_metrics=None,
        run_eagerly=None,
        steps_per_execution=None,
        **kwargs,
    ):
        """Compile method from tensorflow.

        When compiling with losses defining a `quantiles` attribute, this
        attribute is propagated to the submodels and sublayers before the
        regular Keras compilation.
        """
        # Check whether `quantiles` exists in the losses.
        quantiles = self._check_quantiles_in_loss(loss)

        # Define the attributes associated with `quantiles` and propagate
        # them to the sublayers.
        if quantiles:
            self._set_quantiles(quantiles)

        super().compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            loss_weights=loss_weights,
            weighted_metrics=weighted_metrics,
            run_eagerly=run_eagerly,
            steps_per_execution=steps_per_execution,
            **kwargs,
        )

    def handle_dim_in_losses_and_metrics(self, outputs: TENSOR_TYPE):
        """Build and wrap losses and metrics only when the model has quantiles."""
        if self.has_quantiles:
            return super().handle_dim_in_losses_and_metrics(outputs)

    def _check_quantiles_in_loss(
        self,
        loss: Union[
            str,
            tf.keras.losses.Loss,
            LossFunctionWrapper,
            List[Union[str, tf.keras.losses.Loss, LossFunctionWrapper]],
        ],
    ) -> Union[List[Union[List[int], int]], None]:
        """Check whether the loss functions define a `quantiles` attribute.

        Returns
        -------
        list or None
            For a tuple/list of losses, the concatenation of the quantiles
            found in each loss (possibly empty). For a single loss, the
            quantiles returned by `quantiles_handler` if it defines a
            `quantiles` attribute, else None.
        """
        # Case of multiple losses.
        if isinstance(loss, (tuple, list)):
            found = (self._check_quantiles_in_loss(loss_fn) for loss_fn in loss)
            return list(itertools.chain.from_iterable(q for q in found if q))
        # One loss.
        if hasattr(loss, "quantiles"):
            return quantiles_handler(loss.quantiles)
        return None

    # TODO: Avoid to rebuild weights when quantiles of model is not None
    def _set_quantiles(
        self,
        value: List[List[float]],
        additional_shapes: Union[None, List[List[int]]] = None,
        n_quantiles: Union[None, List[List[int]]] = None,
    ) -> None:
        """Set the attributes linked to the quantiles found in the loss functions.

        The attributes are then propagated to every direct sublayer that
        defines `_set_quantiles`.
        """
        super()._set_quantiles(value, additional_shapes, n_quantiles)

        # Propagate to the sublayers.
        for layer in self.layers:
            if hasattr(layer, "_set_quantiles"):
                layer._set_quantiles(  # pylint: disable=protected-access
                    value, self._additional_shapes, self.n_quantiles
                )

    def losses_wrapper(
        self, loss: LossFunctionWrapper
    ) -> Union[Callable, LossFunctionWrapper]:
        """Remove the quantile dimension from y_pred or add it to y_true.

        - Losses without a `quantiles` attribute get their `fn` wrapped so
          that the quantile dimension of y_pred is dropped when needed.
        - Quantile losses are wrapped so that y_true gets a trailing
          quantile dimension when needed.

        Wrapped losses are marked with a `_done` attribute so that wrapping
        the same loss again is a no-op.
        """
        # TODO: We override the fn function which can be a Loss instance and turn it into function.
        # As below we have to recreate an instance of the loss otherwise we lose informations as the attributes etc...
        if loss is None or hasattr(loss, "_done"):
            return loss

        if hasattr(loss, "quantiles"):
            loss = LossFunctionWrapper(loss)
            return _add_dimension_to_ytrue(loss, type(loss))

        # NOTE(review): assumes `loss` exposes `fn` (e.g. a LossFunctionWrapper);
        # a plain `Loss` subclass would raise AttributeError here.
        loss.fn = _remove_dimension_to_ypred(loss.fn)
        loss._done = True  # pylint: disable=protected-access
        return loss

    def metrics_wrapper(
        self,
        metric: Union[
            None, str, keras.metrics.Metric, List[Union[str, keras.metrics.Metric]]
        ],
    ) -> Union[Callable, LossFunctionWrapper]:
        """Remove the quantile dimension from y_pred or add it to y_true.

        - Metrics without a `quantiles` attribute get their `update_state`
          wrapped so that the quantile dimension of y_pred is dropped when
          needed.
        - Quantile metrics are rebuilt so that y_true gets a trailing
          quantile dimension when needed.

        Wrapped metrics are marked with a `_done` attribute so that wrapping
        the same metric again is a no-op.
        """
        # TODO: We override the update_state function which can be a Loss instance and turn it into function.
        # As below we have to recreate an instance of the loss otherwise we lose informations as the attributes etc...
        if metric is None or hasattr(metric, "_done"):
            return metric

        if hasattr(metric, "quantiles"):
            return _add_dimension_to_ytrue(metric, type(metric))

        metric.update_state = _remove_dimension_to_ypred(metric.update_state)
        metric._done = True  # pylint: disable=protected-access
        return metric

    def _check_quantiles_requirements(
        self,
        outputs: TENSOR_TYPE,
        losses: Union[None, tf.keras.losses.Loss, List[tf.keras.losses.Loss]] = None,
    ) -> bool:
        """Check that the quantile requirements are valid, else raise a ValueError.

        Returns
        -------
        bool
            True if the outputs carry the same number of quantiles as the
            losses. False when no check applies (no losses, model without
            quantiles, no quantile loss and no quantile output) or when a
            single-quantile loss is used.

        Raises
        ------
        ValueError
            If the losses don't define the same `quantiles` attribute for one output.
            If the output contains a `quantiles` dimension and there isn't at least one quantile loss.
            If the output doesn't contain a `quantiles` dimension and there is at least one quantile loss.
            If the output `quantiles` don't match the losses `quantiles`.
        """
        if not (losses and self.has_quantiles):
            return False

        # TODO: optimization, this calculation is made twice (One in _handle_quantiles_dim_in_losses_and_metrics)
        quantiles_in_losses = [
            loss.quantiles[0]
            if hasattr(loss, "quantiles")
            else loss.fn.quantiles[0]
            if hasattr(loss.fn, "quantiles")
            else None
            for loss in convert_to_list(losses)
        ]
        quantiles_in_losses = [q for q in quantiles_in_losses if q is not None]

        check_uniform_quantiles = self._check_uniform_quantiles_through_losses(
            quantiles_in_losses
        )
        check_quantiles_in_outputs = self._check_quantiles_in_outputs(outputs)

        if not check_uniform_quantiles:
            raise ValueError(
                f"`quantiles` has to be identical through losses. Got losses {quantiles_in_losses}."
            )
        if not any(quantiles_in_losses) and check_quantiles_in_outputs:
            raise ValueError(
                f"It is not allowed to train a quantile model without a quantile loss. Got a loss {losses} and an output shape {outputs.shape}."
            )
        if any(quantiles_in_losses):
            if self._compare_quantiles_in_outputs_and_losses(
                outputs, quantiles_in_losses
            ):
                return True
            if self._is_single_quantile(quantiles_in_losses):
                return False
            raise ValueError(
                f"Quantiles in losses and outputs are not the same. Maybe you are trying to train a no quantile model "
                f"with a quantile loss. It is possible only if there is one quantile defined as [[0.5]]. \n"
                f"got outputs shape: {outputs.shape} and quantiles in losses: {quantiles_in_losses}"
            )
        return False

    def _check_uniform_quantiles_through_losses(
        self, quantiles_in_losses: List[List[Union[int, float]]]
    ) -> bool:
        """Return True if all losses define an identical `quantiles` attribute.

        An empty list (no quantile loss) is considered uniform.
        """
        return all(q == quantiles_in_losses[0] for q in quantiles_in_losses)

    def _compare_quantiles_in_outputs_and_losses(
        self, outputs: TENSOR_TYPE, quantiles_in_losses: List[List[Union[int, float]]]
    ) -> bool:
        """Return True if the outputs and the quantile losses have the same number of quantiles.

        The quantile dimension of `outputs` is read on axis 0 — presumably
        before any quantile transpose; TODO confirm against the caller.
        """
        return len(quantiles_in_losses[0]) == outputs.shape[0]

    def _is_single_quantile(
        self, quantiles_in_losses: List[List[Union[int, float]]]
    ) -> bool:
        """Return True if the first quantile loss defines exactly one quantile."""
        return len(quantiles_in_losses[0]) == 1
def _remove_dimension_to_ypred(fn):
    """Wrap `fn` so that the quantile dimension of y_pred is dropped when needed.

    When y_pred is flagged as carrying quantiles and has a higher rank than
    y_true, only the central quantile along the last axis is kept, so that
    y_true and y_pred are broadcastable.

    Parameters
    ----------
    fn : callable
        Function called as ``fn(y_true, y_pred, *args, **kwargs)``.

    Returns
    -------
    callable
        A `tf.function` wrapping `fn`.
    """

    @tf.function(experimental_relax_shapes=True)
    def new_fn(y_true, y_pred, *args, **kwargs):
        # `y_pred.quantiles` is presumably a flag provided by the project's
        # quantile tensor type — TODO confirm.
        if y_pred.quantiles and y_pred.shape.rank > y_true.shape.rank:
            # 0-based index of the central quantile on the last axis
            # (0 for 1 quantile, 1 for 3 quantiles, 2 for 5 quantiles).
            # Assumes the quantiles are sorted along that axis — TODO confirm.
            median_idx = math.ceil(y_pred.shape[-1] / 2) - 1
            y_pred = y_pred[..., median_idx]
        return fn(y_true, y_pred, *args, **kwargs)

    return new_fn


def _add_dimension_to_ytrue(fn, obj):
    """Rebuild `fn` so that a quantile dimension is added to y_true when needed.

    When y_pred has a higher rank than y_true, a trailing dimension is added
    to y_true so that y_true and y_pred are broadcastable.

    Parameters
    ----------
    fn : object
        Loss/metric object exposing `get_config()` and `fn.quantiles`, and
        callable as ``fn(y_true, y_pred, *args, **kwargs)``.
    obj : type
        Class used to rebuild the wrapper, called as ``obj(new_fn, **config)``.

    Returns
    -------
    object
        A new `obj` instance wrapping the dimension-adjusting function. Its
        `fn.quantiles` is copied from `fn.fn.quantiles`, and it is marked
        with `_done = True`.
    """

    @tf.function(experimental_relax_shapes=True)
    def new_fn(y_true, y_pred, *args, **kwargs):
        if y_pred.shape.rank > y_true.shape.rank:
            y_true = tf.expand_dims(y_true, -1)
        return fn(y_true, y_pred, *args, **kwargs)

    config = fn.get_config()
    quantiles = fn.fn.quantiles
    wrapped = obj(new_fn, **config)
    # Keep the quantiles reachable on the new wrapper and flag it as wrapped.
    wrapped.fn.quantiles = quantiles
    wrapped._done = True  # pylint: disable=protected-access
    return wrapped
class UnivariateModel(QuantileModel, UnivariateLayer):
    """Extension of `QuantileModel` which integrates the `n_variates` attributes.

    For the moment, as soon as one of the input tensors is multivariate, all
    `additional_shapes` are extended by `n_variates`. In other words, every
    layer extended by `additional_shapes` is a multivariate layer.

    Attributes
    ----------
    is_multivariate : bool
        True if the inputs rank is higher than 2. Defaults to False.
    n_variates : list[None | int]
        The number of variates in the inputs. Defaults to [].
    """

    NOT_INSPECT = ['_check_quantiles_requirements', 'call', 'build', 'compile']

    def __init__(
        self, apply_multivariate_transpose: bool = True, *args: list, **kwargs: dict
    ) -> None:
        super().__init__(*args, **kwargs)
        self.apply_multivariate_transpose = apply_multivariate_transpose

    def handle_dim_in_losses_and_metrics(self, outputs: TENSOR_TYPE):
        """Build and wrap losses and metrics.

        Univariate models defer to `QuantileModel` (which only acts when the
        model has quantiles); multivariate models call the `BaseModel`
        implementation directly, bypassing that quantile gate.
        """
        if not self.is_multivariate:
            return super().handle_dim_in_losses_and_metrics(outputs)
        return BaseModel.handle_dim_in_losses_and_metrics(self, outputs)

    def init_params(
        self,
        inputs_shape: Union[tf.TensorShape, List[tf.TensorShape]],
        n_variates: Union[None, List[Union[None, int]]] = None,
        is_multivariate: Union[None, bool] = None,
        additional_shapes: Union[None, List[List[int]]] = None,
    ) -> None:
        """Initialize the univariate-related attributes before `build`.

        The parent `init_params` is called first (filtering the first shape
        in case of multiple input tensors, setting `is_multivariate` and
        `n_variates`, extending `additional_shapes`), then the resulting
        attributes are forwarded to every direct sublayer exposing
        `init_params`.
        """
        super().init_params(
            inputs_shape,
            n_variates=n_variates,
            is_multivariate=is_multivariate,
            additional_shapes=additional_shapes,
        )

        # Forward the freshly computed attributes to the sublayers.
        for sublayer in self.layers:
            if not hasattr(sublayer, "init_params"):
                continue
            sublayer.init_params(
                inputs_shape,
                self.n_variates,
                self.is_multivariate,
                self._additional_shapes,  # pylint: disable=protected-access
            )