# Source code for quara.math.probability

import numpy as np

# ANSI escape sequences wrapped around the "Warning!" prefix that
# validate_prob_dist prints when it reports a failure without raising.
start_red = "\033[31m"  # SGR 31: switch the foreground color to red
end_color = "\033[0m"  # SGR 0: reset text attributes to the default


def validate_prob_dist(
    prob_dist: np.ndarray,
    eps: float = None,
    validate_sum: bool = True,
    raise_error: bool = True,
    message: str = "",
) -> None:
    """Validate a probability distribution.

    Checks that no entry is negative beyond the tolerance ``eps`` and,
    optionally, that the entries sum to 1 within the same tolerance.

    Parameters
    ----------
    prob_dist : np.ndarray
        the probability distribution.
    eps : float, optional
        the absolute tolerance, by default None, which is replaced by 1e-8.
        An entry ``p < 0`` is accepted when ``abs(p) <= eps``, and the sum
        is accepted when ``abs(sum - 1) <= eps``.
    validate_sum : bool, optional
        whether to validate sum=1, by default True. Any truthy value
        enables the check.
    raise_error : bool, optional
        raises error when validation fails, by default True. When False,
        a warning is printed for every failed check instead.
    message : str, optional
        additional message included when validation fails, by default "".

    Raises
    ------
    ValueError
        some elements of prob_dist are negative numbers
        (only when ``raise_error`` is True).
    ValueError
        the sum of prob_dist is not 1
        (only when ``raise_error`` and ``validate_sum`` are both True).
    """
    if eps is None:
        eps = 1e-8

    probs = np.asarray(prob_dist)

    # whether each probability is a non-negative number.
    # One vectorized mask replaces a per-element np.isclose call; the
    # expression itself is the same one the scalar check used.
    negative = (probs < 0) & ~np.isclose(probs, 0, atol=eps, rtol=0.0)
    for index in np.flatnonzero(negative):
        prob = probs.flat[index]
        if raise_error:
            # Raising here reports the first offending entry only.
            raise ValueError(
                f"({message}) each probability must be a non-negative number. there is {prob} in a probability distribution({index})"
            )
        print(
            f"{start_red}Warning!{end_color} ({message}) each probability must be a non-negative number. there is {prob} in a probability distribution({index})"
        )

    # whether the sum of probabilities equals 1.
    # Truthiness (not ``is True``) so flags such as np.bool_(True) are honored.
    if validate_sum:
        sum_p = np.sum(probs)
        if not np.isclose(sum_p, 1.0, atol=eps, rtol=0.0):
            if raise_error:
                raise ValueError(
                    f"({message}) the sum of prob_dist must be 1. the sum of prob_dist={sum_p}"
                )
            print(
                f"{start_red}Warning!{end_color} ({message}) the sum of prob_dist must be 1. the sum of prob_dist={sum_p}"
            )