Source code for quara.objects.prob_dist

from typing import Tuple
import numpy as np


from typing import Tuple


class ProbDist:
    """Read-only container pairing a flat array with an optional logical shape.

    ``ps`` is stored as given (presumably the probabilities of the
    distribution -- confirm against callers).  ``shape`` is only used by
    tuple indexing, where ``ps`` is reshaped to it before the lookup.
    """

    def __init__(self, ps: np.ndarray, shape: Tuple[int, ...] = None):
        """Store ``ps`` and ``shape`` without copying or validating them.

        Args:
            ps: array of values; kept by reference.
            shape: optional shape used for multidimensional access.  When
                ``None``, only one-dimensional (integer) indexing works.
        """
        self._ps = ps
        self._shape = shape

    @property
    def ps(self):
        """The underlying array (read only: no setter is defined)."""
        return self._ps

    @property
    def shape(self):
        """The logical shape, or ``None`` (read only: no setter is defined)."""
        return self._shape

    def __getitem__(self, idx):
        """Return the element selected by ``idx``.

        Args:
            idx: either an integer (Python ``int`` or NumPy integer scalar)
                for one-dimensional access such as ``prob_dist[0]``, or a
                tuple for multidimensional access such as
                ``prob_dist[0, 1]``.

        Returns:
            ``ps[idx]`` for an integer; for a tuple, the result of reshaping
            ``ps`` to ``shape`` and applying each tuple item in turn.

        Raises:
            ValueError: a tuple index was used but ``shape`` is ``None``.
            TypeError: ``idx`` is neither an integer nor a tuple.
        """
        # bool is a subclass of int, but indexing an ndarray with True/False
        # does not mean "element 1/0", so keep rejecting it with TypeError.
        if isinstance(idx, (int, np.integer)) and not isinstance(idx, bool):
            # One-dimensional access, e.g. prob_dist[0]
            return self._ps[idx]

        if isinstance(idx, tuple):
            # Multidimensional access, e.g. prob_dist[0, 1]
            if self._shape is None:
                raise ValueError(
                    "multidimensional access requires a shape, "
                    "but this ProbDist was created with shape=None"
                )
            target = self._ps.reshape(self._shape)
            # Apply the items one at a time (not as a single NumPy index) so
            # that each item always addresses the leading remaining axis.
            for i in idx:
                target = target[i]
            return target

        raise TypeError(
            "ProbDist indices must be an integer or a tuple, "
            f"not {type(idx).__name__}"
        )

    def __str__(self):
        """Return a two-line description showing ``shape`` and ``ps``."""
        return f"shape = {self.shape}\nps = {self._ps}"