arviz.dict_to_dataset

arviz.dict_to_dataset(data, *, attrs=None, library=None, coords=None, dims=None, default_dims=None, index_origin=None, skip_event_dims=None)

Convert a dictionary or pytree of numpy arrays to an xarray.Dataset.

See https://jax.readthedocs.io/en/latest/pytrees.html for what a pytree is; pytrees include at least dictionaries and tuples.
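As a rough illustration of how nested containers map to the tuple-valued variable names described under Returns, the sketch below flattens dictionaries, tuples and lists into a flat mapping from names to leaves. `flatten_pytree` is a hypothetical helper written for this example, not ArviZ's actual implementation, and keying tuple/list elements by their integer position is an assumption.

```python
from typing import Any, Dict, Hashable, Tuple


def flatten_pytree(tree: Any, path: Tuple[Hashable, ...] = ()) -> Dict[Hashable, Any]:
    """Flatten nested dicts/tuples/lists into {name: leaf} (hypothetical helper).

    Top-level leaves keep their plain key; nested leaves are named by the
    tuple of keys along their path. Tuple/list positions use integer indices
    (an assumption). A bare leaf passed at the root is keyed by ().
    """
    if isinstance(tree, dict):
        items = tree.items()
    elif isinstance(tree, (tuple, list)):
        items = enumerate(tree)
    else:
        # Leaf reached: bare key at depth 1, full path tuple when nested deeper.
        return {path[0] if len(path) == 1 else path: tree}
    flat: Dict[Hashable, Any] = {}
    for key, subtree in items:
        flat.update(flatten_pytree(subtree, path + (key,)))
    return flat


print(flatten_pytree({"top": {"second": 1.0}, "top2": 1.0}))
# {('top', 'second'): 1.0, 'top2': 1.0}
```

The number of entries equals the number of leaves, which is why a nested input yields one Dataset variable per leaf.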

Parameters:
data : dict of {str : array_like or dict} or pytree

Data to convert. Keys are variable names.

attrs : dict, optional

JSON-serializable metadata to attach to the dataset, in addition to the defaults.

library : module, optional

Library used for performing inference. Will be attached to the attrs metadata.

coords : dict of {str : ndarray}, optional

Coordinates for the dataset.

dims : dict of {str : list of str}, optional

Dimensions of each variable. The keys are variable names (tuples of names for nested pytree leaves) and the values are lists of coordinate names.

default_dims : list of str, optional

Passed to numpy_to_data_array().

index_origin : int, optional

Passed to numpy_to_data_array().

skip_event_dims : bool, optional

If True, cut extra dims whenever present to match the shape of the data. This is necessary for probabilistic programming languages (PPLs) that use the same variable name in both the observed data and log likelihood groups, to account for their different shapes when observations are multivariate.

Returns:
xarray.Dataset

For nested pytrees, the name of each variable is the tuple of keys along the path to its leaf, for example ('top', 'second').

Notes

This function is available through two aliases: dict_to_dataset and pytree_to_dataset.

Examples

Convert a dictionary with two 2D variables to a Dataset.

In [1]: import arviz as az
   ...: import numpy as np
   ...: az.dict_to_dataset({'x': np.random.randn(4, 100), 'y': np.random.rand(4, 100)})
   ...: 
Out[1]: 
<xarray.Dataset> Size: 7kB
Dimensions:  (chain: 4, draw: 100)
Coordinates:
  * chain    (chain) int64 32B 0 1 2 3
  * draw     (draw) int64 800B 0 1 2 3 4 5 6 7 8 ... 91 92 93 94 95 96 97 98 99
Data variables:
    x        (chain, draw) float64 3kB -0.1619 1.607 1.245 ... 0.9672 0.4799
    y        (chain, draw) float64 3kB 0.2238 0.03489 0.3024 ... 0.9876 0.8047
Attributes:
    created_at:     2024-05-10T09:15:06.503956+00:00
    arviz_version:  0.19.0.dev0

Note that, unlike the xarray.Dataset constructor, ArviZ adds extra information to the generated Dataset, such as default dimension names for the sampling dimensions (chain and draw) and some attributes (created_at and arviz_version).
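The two attributes visible in these outputs, `created_at` and `arviz_version`, could be produced along the lines of the sketch below, which also merges the user's `attrs` on top of the defaults. `make_default_attrs` and its `version` argument are hypothetical; letting user keys override same-named defaults is an assumption, and whatever metadata the `library` parameter adds is deliberately left out.

```python
from datetime import datetime, timezone
from typing import Any, Dict, Optional


def make_default_attrs(
    attrs: Optional[Dict[str, Any]] = None, version: str = "0.0.0"
) -> Dict[str, Any]:
    """Build dataset attributes: defaults first, then user attrs (hypothetical)."""
    merged: Dict[str, Any] = {
        # ISO 8601 UTC timestamp, e.g. 2024-05-10T09:15:06.503956+00:00
        "created_at": datetime.now(timezone.utc).isoformat(),
        "arviz_version": version,
    }
    if attrs:
        # Assumption: user-supplied keys override defaults with the same name.
        merged.update(attrs)
    return merged


print(make_default_attrs({"sampler": "nuts"}, version="0.19.0.dev0"))
```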

The function is also general enough to work on pytrees such as nested dictionaries:

In [2]: az.pytree_to_dataset({'top': {'second': 1.}, 'top2': 1.})
Out[2]: 
<xarray.Dataset> Size: 32B
Dimensions:            (chain: 1, draw: 1)
Coordinates:
  * chain              (chain) int64 8B 0
  * draw               (draw) int64 8B 0
Data variables:
    ('top', 'second')  (chain, draw) float64 8B 1.0
    top2               (chain, draw) float64 8B 1.0
Attributes:
    created_at:     2024-05-10T09:15:06.527029+00:00
    arviz_version:  0.19.0.dev0

which has two variables (as many as leaves), named ('top', 'second') and top2.
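The outputs show that a scalar leaf such as 1. becomes a (chain: 1, draw: 1) variable, and in the next example a 1D array of length 100 becomes (chain: 1, draw: 100). A minimal sketch of that promotion rule on plain shape tuples, assuming length-1 axes are prepended until there are at least two (the same rule `numpy.atleast_2d` applies); `promote_shape` is a hypothetical helper.

```python
from typing import Tuple


def promote_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Prepend length-1 axes so the shape has at least (chain, draw) axes."""
    missing = max(0, 2 - len(shape))
    return (1,) * missing + tuple(shape)


print(promote_shape(()))            # scalar leaf -> (1, 1)
print(promote_shape((100,)))        # 1D draws    -> (1, 100)
print(promote_shape((1, 100, 10)))  # unchanged   -> (1, 100, 10)
```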

Dimensions and coordinates can be defined as usual, using the tuple name to refer to a nested variable:

In [3]: datadict = {
   ...:     "top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
   ...:     "d": np.random.randn(100),
   ...: }
   ...: az.dict_to_dataset(
   ...:     datadict,
   ...:     coords={"c": np.arange(10)},
   ...:     dims={("top", "b"): ["c"]}
   ...: )
   ...: 
Out[3]: 
<xarray.Dataset> Size: 10kB
Dimensions:       (chain: 1, draw: 100, c: 10)
Coordinates:
  * chain         (chain) int64 8B 0
  * draw          (draw) int64 800B 0 1 2 3 4 5 6 7 ... 92 93 94 95 96 97 98 99
  * c             (c) int64 80B 0 1 2 3 4 5 6 7 8 9
Data variables:
    d             (chain, draw) float64 800B -0.09468 0.175 ... 0.8065 2.772
    ('top', 'a')  (chain, draw) float64 800B 0.2191 1.704 ... -0.6839 2.638
    ('top', 'b')  (chain, draw, c) float64 8kB -0.2195 -0.04862 ... -0.3008
Attributes:
    created_at:     2024-05-10T09:15:06.540062+00:00
    arviz_version:  0.19.0.dev0