Source code for quara.loss_function.weighted_relative_entropy

from typing import List, Tuple, Union

import numpy as np

from quara.loss_function.probability_based_loss_function import (
    ProbabilityBasedLossFunction,
    ProbabilityBasedLossFunctionOption,
)
from quara.math.entropy import (
    relative_entropy,
    gradient_relative_entropy_2nd,
    hessian_relative_entropy_2nd,
)
from quara.math.probability import validate_prob_dist
from quara.utils import matrix_util


class WeightedRelativeEntropyOption(ProbabilityBasedLossFunctionOption):
    """Option object for :class:`WeightedRelativeEntropy`."""

    def __init__(
        self, mode_weight: str = None, weights: List = None, weight_name: str = None
    ):
        """Constructor

        mode_weight should be one of the following values:

        - "identity" then uses identity (unit) weights.
        - "custom" then uses user custom weights.

        If ``weights`` is not None, ``mode_weight`` is overridden with
        "custom" regardless of the value passed.

        Parameters
        ----------
        mode_weight : str, optional
            mode weight, by default None
        weights : List, optional
            list of weight, by default None
        weight_name : str, optional
            weight name for reporting, by default None

        Raises
        ------
        ValueError
            If the resolved ``mode_weight`` is neither "identity" nor "custom".
            This includes the case where both ``mode_weight`` and ``weights``
            are None.
        """
        if weights is not None:
            # Explicitly supplied weights always take precedence over the mode.
            mode_weight = "custom"

        if mode_weight not in ("identity", "custom"):
            raise ValueError(f"unsupported mode_weight={mode_weight}")

        super().__init__(
            mode_weight=mode_weight, weights=weights, weight_name=weight_name
        )
class WeightedRelativeEntropy(ProbabilityBasedLossFunction):
    """Weighted Relative Entropy loss function.

    The value is the sum over ``i`` of
    ``w_i * relative_entropy(q_i, p_i(var))``. Here ``p_i`` is
    ``func_prob_dists[i]`` and ``q_i`` is ``prob_dists_q[i]``. When no
    weights are set (None or empty), every ``w_i`` is 1.
    """

    def __init__(
        self,
        num_var: int = None,
        func_prob_dists: List = None,
        func_gradient_prob_dists: List = None,
        func_hessian_prob_dists: List = None,
        prob_dists_q: List[np.ndarray] = None,
        weights: Union[List[float], List[np.float64]] = None,
    ):
        """Constructor

        Parameters
        ----------
        num_var : int, optional
            number of variables, by default None
        func_prob_dists : List[Callable[[np.ndarray], np.ndarray]], optional
            functions map variables to a probability distribution.
        func_gradient_prob_dists : List[Callable[[int, np.ndarray], np.ndarray]], optional
            functions map variables and an index of variables to gradient of probability distributions.
        func_hessian_prob_dists : List[Callable[[int, int, np.ndarray], np.ndarray]], optional
            functions map variables and indices of variables to Hessian of probability distributions.
        prob_dists_q : List[np.ndarray], optional
            vectors of ``q``, by default None.
        weights : Union[List[float], List[np.float64]], optional
            weights, by default None

        Raises
        ------
        ValueError
            If any weight is not a real number (float or np.float64).
        """
        super().__init__(
            num_var,
            func_prob_dists,
            func_gradient_prob_dists,
            func_hessian_prob_dists,
            prob_dists_q,
        )

        # validate
        self._validate_weights(weights)
        self._weights = weights

        # update on_value, on_gradient and on_hessian
        self._update_on_value_true()
        self._update_on_gradient_true()
        self._update_on_hessian_true()

    @staticmethod
    def _has_weights(weights) -> bool:
        """Returns True when ``weights`` is neither None nor empty."""
        # ``len`` instead of truthiness so numpy arrays are handled as well.
        return weights is not None and len(weights) > 0

    def _validate_weights(self, weights: List[float]) -> None:
        """Checks that every weight is a real number (float or np.float64).

        Raises
        ------
        ValueError
            If some weight is not a float (np.float64 is a float subclass).
        """
        if not self._has_weights(weights):
            return
        for index, weight in enumerate(weights):
            # weights are real values; np.float64 subclasses float
            if not isinstance(weight, float):
                raise ValueError(
                    f"values of weights must be real numbers(float or np.float64). dtype of weights[{index}] is {type(weight)}"
                )

    @property
    def weights(self) -> List[float]:
        """returns weights.

        Returns
        -------
        List[float]
            weights. May be None when no weights have been set.
        """
        return self._weights

    def set_weights(self, weights: List[float]) -> None:
        """sets weights.

        Parameters
        ----------
        weights : List[float]
            weights.

        Raises
        ------
        ValueError
            If any weight is not a real number (float or np.float64).
        """
        self._validate_weights(weights)
        self._weights = weights

    def _sets_weight_by_mode(
        self, mode_weight: str, data: List[Tuple[int, np.ndarray]]
    ) -> None:
        """Applies weights according to ``mode_weight``.

        "identity" leaves the weights unchanged. "custom" applies
        ``self.option.weights``. ``data`` is not used by this loss function.
        """
        if mode_weight == "identity":
            pass
        elif mode_weight == "custom":
            # This class stores scalar weights via ``set_weights``.
            self.set_weights(self.option.weights)

    def _update_on_value_true(self) -> bool:
        """validates and updates ``on_value`` to True.

        see :func:`~quara.data_analysis.loss_function.LossFunction._update_on_value_true`
        """
        if self.on_func_prob_dists is True and self.on_prob_dists_q is True:
            self._set_on_value(True)
        return self.on_value

    def _update_on_gradient_true(self) -> bool:
        """validates and updates ``on_gradient`` to True.

        see :func:`~quara.data_analysis.loss_function.LossFunction._update_on_gradient_true`
        """
        if (
            self.on_func_prob_dists is True
            and self.on_func_gradient_prob_dists is True
            and self.on_prob_dists_q is True
        ):
            self._set_on_gradient(True)
        return self.on_gradient

    def _update_on_hessian_true(self) -> bool:
        """validates and updates ``on_hessian`` to True.

        see :func:`~quara.data_analysis.loss_function.LossFunction._update_on_hessian_true`
        """
        if (
            self.on_func_prob_dists is True
            and self.on_func_gradient_prob_dists is True
            and self.on_func_hessian_prob_dists is True
            and self.on_prob_dists_q is True
        ):
            self._set_on_hessian(True)
        return self.on_hessian

    def _weight_at(self, index: int) -> float:
        """Returns the weight of the ``index``-th term, or 1.0 if unweighted."""
        # Multiplying by 1.0 leaves float/ndarray results bit-identical, so the
        # unweighted case needs no separate code path.
        return self.weights[index] if self._has_weights(self.weights) else 1.0

    def _calc_grad_ps(self, index: int, var: np.ndarray) -> np.ndarray:
        """Stacks the gradients of the ``index``-th distribution column-wise.

        Column ``alpha`` is ``func_gradient_prob_dists[index](alpha, var)``.
        """
        grad_func = self.func_gradient_prob_dists[index]
        return np.stack([grad_func(alpha, var) for alpha in range(self.num_var)], 1)

    def _calc_hess_ps(self, index: int, var: np.ndarray) -> np.ndarray:
        """Builds the Hessians of the ``index``-th distribution.

        The result has shape (len(p), num_var, num_var); entry
        ``[:, alpha, beta]`` is ``func_hessian_prob_dists[index](alpha, beta, var)``.
        """
        hess_func = self.func_hessian_prob_dists[index]
        tmp_hess = [
            [hess_func(alpha, beta, var) for beta in range(self.num_var)]
            for alpha in range(self.num_var)
        ]
        # (num_var, num_var, len(p)) -> (len(p), num_var, num_var)
        return np.array(tmp_hess).transpose(2, 0, 1)

    def value(self, var: np.ndarray, validate: bool = False) -> np.float64:
        """returns the value of Weighted Relative Entropy.

        see :func:`~quara.data_analysis.loss_function.LossFunction.value`
        """
        val = 0.0
        for index, func_p in enumerate(self.func_prob_dists):
            q = self.prob_dists_q[index]
            p = func_p(var)
            if validate:
                validate_prob_dist(
                    p,
                    raise_error=False,
                    message=f"WeightedRelativeEntropy.value({index})",
                )
            val += self._weight_at(index) * relative_entropy(
                q, p, is_valid_required=False
            )
        return val

    def gradient(self, var: np.ndarray, validate: bool = False) -> np.ndarray:
        """returns the gradient of Weighted Relative Entropy.

        see :func:`~quara.data_analysis.loss_function.LossFunction.gradient`
        """
        grad = np.zeros(self.num_var, dtype=np.float64)
        for index, func_p in enumerate(self.func_prob_dists):
            grad_ps = self._calc_grad_ps(index, var)
            q = self.prob_dists_q[index]
            p = func_p(var)
            if validate:
                validate_prob_dist(
                    p,
                    raise_error=False,
                    message=f"WeightedRelativeEntropy.gradient({index})",
                )
            grad += self._weight_at(index) * gradient_relative_entropy_2nd(
                q, p, grad_ps, is_valid_required=False
            )
        return grad

    def hessian(self, var: np.ndarray, validate: bool = False) -> np.ndarray:
        """returns the Hessian of Weighted Relative Entropy.

        see :func:`~quara.data_analysis.loss_function.LossFunction.hessian`
        """
        hess = np.zeros((self.num_var, self.num_var), dtype=np.float64)
        for index, func_p in enumerate(self.func_prob_dists):
            grad_ps = self._calc_grad_ps(index, var)
            hess_ps = self._calc_hess_ps(index, var)
            q = self.prob_dists_q[index]
            p = func_p(var)
            if validate:
                validate_prob_dist(
                    p,
                    raise_error=False,
                    message=f"WeightedRelativeEntropy.hessian({index})",
                )
            hess += self._weight_at(index) * hessian_relative_entropy_2nd(
                q, p, grad_ps, hess_ps, is_valid_required=False
            )
        return hess