"""ArviZ rcparams. Based on matplotlib's implementation."""
import locale
import logging
import os
import pprint
import re
import sys
import warnings
from collections.abc import MutableMapping
from pathlib import Path
from typing import Any, Dict
from typing_extensions import Literal
NO_GET_ARGS: bool = False
try:
from typing_extensions import get_args
except ImportError:
NO_GET_ARGS = True
import numpy as np
_log = logging.getLogger(__name__)
ScaleKeyword = Literal["log", "negative_log", "deviance"]
ICKeyword = Literal["loo", "waic"]
_identity = lambda x: x
def _make_validate_choice(accepted_values, allow_none=False, typeof=str):
"""Validate value is in accepted_values.
Parameters
----------
accepted_values : iterable
Iterable containing all accepted_values.
allow_none: boolean, optional
Whether to accept ``None`` in addition to the values in ``accepted_values``.
typeof: type, optional
Type the values should be converted to.
"""
# no blank lines allowed after function docstring by pydocstyle,
# but black requires white line before function
def validate_choice(value):
if allow_none and (value is None or isinstance(value, str) and value.lower() == "none"):
return None
try:
value = typeof(value)
except (ValueError, TypeError) as err:
raise ValueError(f"Could not convert to {typeof.__name__}") from err
if isinstance(value, str):
value = value.lower()
if value in accepted_values:
# Convert value to python boolean if string matches
value = {"true": True, "false": False}.get(value, value)
return value
raise ValueError(
f'{value} is not one of {accepted_values}{" nor None" if allow_none else ""}'
)
return validate_choice
def _make_validate_choice_regex(accepted_values, accepted_values_regex, allow_none=False):
"""Validate value is in accepted_values with regex.
Parameters
----------
accepted_values : iterable
Iterable containing all accepted_values.
accepted_values_regex : iterable
Iterable containing all accepted_values with regex string.
allow_none: boolean, optional
Whether to accept ``None`` in addition to the values in ``accepted_values``.
typeof: type, optional
Type the values should be converted to.
"""
# no blank lines allowed after function docstring by pydocstyle,
# but black requires white line before function
def validate_choice_regex(value):
if allow_none and (value is None or isinstance(value, str) and value.lower() == "none"):
return None
value = str(value)
if isinstance(value, str):
value = value.lower()
if value in accepted_values:
# Convert value to python boolean if string matches
value = {"true": True, "false": False}.get(value, value)
return value
elif any(re.match(pattern, value) for pattern in accepted_values_regex):
return value
raise ValueError(
f"{value} is not one of {accepted_values} "
f'or in regex {accepted_values_regex}{" nor None" if allow_none else ""}'
)
return validate_choice_regex
def _validate_positive_int(value):
"""Validate value is a natural number."""
try:
value = int(value)
except ValueError as err:
raise ValueError("Could not convert to int") from err
if value > 0:
return value
else:
raise ValueError("Only positive values are valid")
def _validate_float(value):
"""Validate value is a float."""
try:
value = float(value)
except ValueError as err:
raise ValueError("Could not convert to float") from err
return value
def _validate_probability(value):
"""Validate a probability: a float between 0 and 1."""
value = _validate_float(value)
if (value < 0) or (value > 1):
raise ValueError("Only values between 0 and 1 are valid.")
return value
def _validate_boolean(value):
"""Validate value is a float."""
if isinstance(value, str):
value = value.lower()
if value not in {True, False, "true", "false"}:
raise ValueError("Only boolean values are valid.")
return value is True or value == "true"
def _add_none_to_validator(base_validator):
"""Create a validator function that catches none and then calls base_fun."""
# no blank lines allowed after function docstring by pydocstyle,
# but black requires white line before function
def validate_with_none(value):
if value is None or isinstance(value, str) and value.lower() == "none":
return None
return base_validator(value)
return validate_with_none
def _validate_bokeh_marker(value):
"""Validate the markers."""
try:
from bokeh.core.enums import MarkerType
except ImportError:
MarkerType = [
"asterisk",
"circle",
"circle_cross",
"circle_dot",
"circle_x",
"circle_y",
"cross",
"dash",
"diamond",
"diamond_cross",
"diamond_dot",
"dot",
"hex",
"hex_dot",
"inverted_triangle",
"plus",
"square",
"square_cross",
"square_dot",
"square_pin",
"square_x",
"star",
"star_dot",
"triangle",
"triangle_dot",
"triangle_pin",
"x",
"y",
]
all_markers = list(MarkerType)
if value not in all_markers:
raise ValueError(f"{value} is not one of {all_markers}")
return value
def _validate_dict_of_lists(values):
if isinstance(values, dict):
return {key: tuple(item) for key, item in values.items()}
validated_dict = {}
for value in values:
tup = value.split(":", 1)
if len(tup) != 2:
raise ValueError(f"Could not interpret '{value}' as key: list or str")
key, vals = tup
key = key.strip(' "')
vals = [val.strip(' "') for val in vals.strip(" [],").split(",")]
if key in validated_dict:
warnings.warn(f"Repeated key {key} when validating dict of lists")
validated_dict[key] = tuple(vals)
return validated_dict
def make_iterable_validator(scalar_validator, length=None, allow_none=False, allow_auto=False):
"""Validate value is an iterable datatype."""
# based on matplotlib's _listify_validator function
def validate_iterable(value):
if allow_none and (value is None or isinstance(value, str) and value.lower() == "none"):
return None
if isinstance(value, str):
if allow_auto and value.lower() == "auto":
return "auto"
value = tuple(v.strip("([ ])") for v in value.split(",") if v.strip())
if np.iterable(value) and not isinstance(value, (set, frozenset)):
val = tuple(scalar_validator(v) for v in value)
if length is not None and len(val) != length:
raise ValueError(f"Iterable must be of length: {length}")
return val
raise ValueError("Only ordered iterable values are valid")
return validate_iterable
_validate_float_or_none = _add_none_to_validator(_validate_float)
_validate_positive_int_or_none = _add_none_to_validator(_validate_positive_int)
_validate_bokeh_bounds = make_iterable_validator( # pylint: disable=invalid-name
_validate_float_or_none, length=2, allow_none=True, allow_auto=True
)
METAGROUPS = {
"posterior_groups": ["posterior", "posterior_predictive", "sample_stats", "log_likelihood"],
"prior_groups": ["prior", "prior_predictive", "sample_stats_prior"],
"posterior_groups_warmup": [
"_warmup_posterior",
"_warmup_posterior_predictive",
"_warmup_sample_stats",
],
"latent_vars": ["posterior", "prior"],
"observed_vars": ["posterior_predictive", "observed_data", "prior_predictive"],
}
defaultParams = { # pylint: disable=invalid-name
"data.http_protocol": ("https", _make_validate_choice({"https", "http"})),
"data.load": ("lazy", _make_validate_choice({"lazy", "eager"})),
"data.metagroups": (METAGROUPS, _validate_dict_of_lists),
"data.index_origin": (0, _make_validate_choice({0, 1}, typeof=int)),
"data.log_likelihood": (True, _validate_boolean),
"data.save_warmup": (False, _validate_boolean),
"plot.backend": ("matplotlib", _make_validate_choice({"matplotlib", "bokeh"})),
"plot.density_kind": ("kde", _make_validate_choice({"kde", "hist"})),
"plot.max_subplots": (40, _validate_positive_int_or_none),
"plot.point_estimate": (
"mean",
_make_validate_choice({"mean", "median", "mode"}, allow_none=True),
),
"plot.bokeh.bounds_x_range": ("auto", _validate_bokeh_bounds),
"plot.bokeh.bounds_y_range": ("auto", _validate_bokeh_bounds),
"plot.bokeh.figure.dpi": (60, _validate_positive_int),
"plot.bokeh.figure.height": (500, _validate_positive_int),
"plot.bokeh.figure.width": (500, _validate_positive_int),
"plot.bokeh.layout.order": (
"default",
_make_validate_choice_regex(
{"default", r"column", r"row", "square", "square_trimmed"}, {r"\d*column", r"\d*row"}
),
),
"plot.bokeh.layout.sizing_mode": (
"fixed",
_make_validate_choice(
{
"fixed",
"stretch_width",
"stretch_height",
"stretch_both",
"scale_width",
"scale_height",
"scale_both",
}
),
),
"plot.bokeh.layout.toolbar_location": (
"above",
_make_validate_choice({"above", "below", "left", "right"}, allow_none=True),
),
"plot.bokeh.marker": ("cross", _validate_bokeh_marker),
"plot.bokeh.output_backend": ("webgl", _make_validate_choice({"canvas", "svg", "webgl"})),
"plot.bokeh.show": (True, _validate_boolean),
"plot.bokeh.tools": (
"reset,pan,box_zoom,wheel_zoom,lasso_select,undo,save,hover",
lambda x: x,
),
"plot.matplotlib.show": (False, _validate_boolean),
"stats.ci_prob": (0.94, _validate_probability),
"stats.information_criterion": (
"loo",
_make_validate_choice({"loo", "waic"} if NO_GET_ARGS else set(get_args(ICKeyword))),
),
"stats.ic_pointwise": (True, _validate_boolean),
"stats.ic_scale": (
"log",
_make_validate_choice(
{"log", "negative_log", "deviance"} if NO_GET_ARGS else set(get_args(ScaleKeyword))
),
),
"stats.ic_compare_method": (
"stacking",
_make_validate_choice({"stacking", "bb-pseudo-bma", "pseudo-bma"}),
),
}
# map from deprecated params to (version, new_param, fold2new, fnew2old)
deprecated_map = {"stats.hdi_prob": ("0.18.0", "stats.ci_prob", _identity, _identity)}
class RcParams(MutableMapping):
"""Class to contain ArviZ default parameters.
It is implemented as a dict with validation when setting items.
"""
validate = {key: validate_fun for key, (_, validate_fun) in defaultParams.items()}
# validate values on the way in
def __init__(self, *args, **kwargs):
self._underlying_storage: Dict[str, Any] = {}
super().__init__()
self.update(*args, **kwargs)
def __setitem__(self, key, val):
"""Add validation to __setitem__ function."""
if key in deprecated_map:
version, key_new, fold2new, _ = deprecated_map[key]
warnings.warn(
f"{key} is deprecated since {version}, use {key_new} instead",
FutureWarning,
)
key = key_new
val = fold2new(val)
try:
try:
cval = self.validate[key](val)
except ValueError as verr:
raise ValueError(f"Key {key}: {str(verr)}") from verr
self._underlying_storage[key] = cval
except KeyError as err:
raise KeyError(
f"{key} is not a valid rc parameter "
f"(see rcParams.keys() for a list of valid parameters)"
) from err
def __getitem__(self, key):
"""Use underlying dict's getitem method."""
if key in deprecated_map:
version, key_new, _, fnew2old = deprecated_map[key]
warnings.warn(
f"{key} is deprecated since {version}, use {key_new} instead",
FutureWarning,
)
if key not in self._underlying_storage:
key = key_new
else:
fnew2old = _identity
return fnew2old(self._underlying_storage[key])
def __delitem__(self, key):
"""Raise TypeError if someone ever tries to delete a key from RcParams."""
raise TypeError("RcParams keys cannot be deleted")
def clear(self):
"""Raise TypeError if someone ever tries to delete all keys from RcParams."""
raise TypeError("RcParams keys cannot be deleted")
def pop(self, key, default=None):
"""Raise TypeError if someone ever tries to delete a key from RcParams."""
raise TypeError(
"RcParams keys cannot be deleted. Use .get(key) of RcParams[key] to check values"
)
def popitem(self):
"""Raise TypeError if someone ever tries to delete a key from RcParams."""
raise TypeError(
"RcParams keys cannot be deleted. Use .get(key) of RcParams[key] to check values"
)
def setdefault(self, key, default=None):
"""Raise error when using setdefault, defaults are handled on initialization."""
raise TypeError(
"Defaults in RcParams are handled on object initialization during library"
"import. Use arvizrc file instead."
""
)
def __repr__(self):
"""Customize repr of RcParams objects."""
class_name = self.__class__.__name__
indent = len(class_name) + 1
repr_split = pprint.pformat(
self._underlying_storage,
indent=1,
width=80 - indent,
).split("\n")
repr_indented = ("\n" + " " * indent).join(repr_split)
return f"{class_name}({repr_indented})"
def __str__(self):
"""Customize str/print of RcParams objects."""
return "\n".join(map("{0[0]:<22}: {0[1]}".format, sorted(self._underlying_storage.items())))
def __iter__(self):
"""Yield sorted list of keys."""
yield from sorted(self._underlying_storage.keys())
def __len__(self):
"""Use underlying dict's len method."""
return len(self._underlying_storage)
def find_all(self, pattern):
"""
Find keys that match a regex pattern.
Return the subset of this RcParams dictionary whose keys match,
using :func:`re.search`, the given ``pattern``.
Notes
-----
Changes to the returned dictionary are *not* propagated to
the parent RcParams dictionary.
"""
pattern_re = re.compile(pattern)
return RcParams((key, value) for key, value in self.items() if pattern_re.search(key))
def copy(self):
"""Get a copy of the RcParams object."""
return dict(self._underlying_storage)
def get_arviz_rcfile():
"""Get arvizrc file.
The file location is determined in the following order:
- ``$PWD/arvizrc``
- ``$ARVIZ_DATA/arvizrc``
- On Linux,
- ``$XDG_CONFIG_HOME/arviz/arvizrc`` (if ``$XDG_CONFIG_HOME``
is defined)
- or ``$HOME/.config/arviz/arvizrc`` (if ``$XDG_CONFIG_HOME``
is not defined)
- On other platforms,
- ``$HOME/.arviz/arvizrc`` if ``$HOME`` is defined
Otherwise, the default defined in ``rcparams.py`` file will be used.
"""
# no blank lines allowed after function docstring by pydocstyle,
# but black requires white line before function
def gen_candidates():
yield os.path.join(os.getcwd(), "arvizrc")
arviz_data_dir = os.environ.get("ARVIZ_DATA")
if arviz_data_dir:
yield os.path.join(arviz_data_dir, "arvizrc")
xdg_base = os.environ.get("XDG_CONFIG_HOME", str(Path.home() / ".config"))
if sys.platform.startswith(("linux", "freebsd")):
configdir = str(Path(xdg_base, "arviz"))
else:
configdir = str(Path.home() / ".arviz")
yield os.path.join(configdir, "arvizrc")
for fname in gen_candidates():
if os.path.exists(fname) and not os.path.isdir(fname):
return fname
return None
def read_rcfile(fname):
"""Return :class:`arviz.RcParams` from the contents of the given file.
Unlike `rc_params_from_file`, the configuration class only contains the
parameters specified in the file (i.e. default values are not filled in).
"""
_error_details_fmt = 'line #%d\n\t"%s"\n\tin file "%s"'
config = RcParams()
with open(fname, "r", encoding="utf8") as rcfile:
try:
multiline = False
for line_no, line in enumerate(rcfile, 1):
strippedline = line.split("#", 1)[0].strip()
if not strippedline:
continue
if multiline:
if strippedline == "}":
multiline = False
val = aux_val
else:
aux_val.append(strippedline)
continue
else:
tup = strippedline.split(":", 1)
if len(tup) != 2:
error_details = _error_details_fmt % (line_no, line, fname)
_log.warning("Illegal %s", error_details)
continue
key, val = tup
key = key.strip()
val = val.strip()
if key in config:
_log.warning("Duplicate key in file %r line #%d.", fname, line_no)
if key in {"data.metagroups"}:
aux_val = []
multiline = True
continue
try:
config[key] = val
except ValueError as verr:
error_details = _error_details_fmt % (line_no, line, fname)
raise ValueError(f"Bad val {val} on {error_details}\n\t{str(verr)}") from verr
except UnicodeDecodeError:
_log.warning(
"Cannot decode configuration file %s with encoding "
"%s, check LANG and LC_* variables.",
fname,
locale.getpreferredencoding(do_setlocale=False) or "utf-8 (default)",
)
raise
return config
def rc_params(ignore_files=False):
"""Read and validate arvizrc file."""
fname = None if ignore_files else get_arviz_rcfile()
defaults = RcParams([(key, default) for key, (default, _) in defaultParams.items()])
if fname is not None:
file_defaults = read_rcfile(fname)
defaults.update(file_defaults)
return defaults
rcParams = rc_params() # pylint: disable=invalid-name
[docs]
class rc_context: # pylint: disable=invalid-name
"""
Return a context manager for managing rc settings.
Parameters
----------
rc : dict, optional
Mapping containing the rcParams to modify temporally.
fname : str, optional
Filename of the file containing the rcParams to use inside the rc_context.
Examples
--------
This allows one to do::
with az.rc_context(fname='pystan.rc'):
idata = az.load_arviz_data("radon")
az.plot_posterior(idata, var_names=["gamma"])
The plot would have settings from 'screen.rc'
A dictionary can also be passed to the context manager::
with az.rc_context(rc={'plot.max_subplots': None}, fname='pystan.rc'):
idata = az.load_arviz_data("radon")
az.plot_posterior(idata, var_names=["gamma"])
The 'rc' dictionary takes precedence over the settings loaded from
'fname'. Passing a dictionary only is also valid.
"""
# Based on mpl.rc_context
def __init__(self, rc=None, fname=None):
self._orig = rcParams.copy()
if fname:
file_rcparams = read_rcfile(fname)
rcParams.update(file_rcparams)
if rc:
rcParams.update(rc)
def __enter__(self):
"""Define enter method of context manager."""
return self
def __exit__(self, exc_type, exc_value, exc_tb):
"""Define exit method of context manager."""
rcParams.update(self._orig)