Multivariate Gaussian Inference with Widgets

This notebook shows how to do multivariate Gaussian inference with widgets. A slider sets the observed value of one variable, X, and we visualize how the distributions of the other variables change in response. We will be using the Cowell graph [Cow98], a linear-Gaussian chain Y → X → Z.

Simulate data

[1]:
%matplotlib inline
import numpy as np
from pybbn.gaussian.inference import GaussianInference
import matplotlib.pyplot as plt

# Fix NumPy's legacy global seed so the simulated data (and every number below) is reproducible.
np.random.seed(37)

# ggplot look without its background grid; the override must come after style.use(),
# otherwise the style would set axes.grid again.
plt.style.use('ggplot')
plt.rcParams.update({'axes.grid': False})

def get_cowell_data(n=10000, rng=None):
    """Simulate samples from the linear-Gaussian chain Y -> X -> Z (the Cowell graph).

    Each variable is its parent plus unit-variance Gaussian noise:
    Y ~ N(0, 1), X | Y ~ N(Y, 1), Z | X ~ N(X, 1).

    Parameters
    ----------
    n : int, default 10000
        Number of samples to draw.
    rng : object with ``normal(loc, scale, size)``, optional
        Random source, e.g. ``np.random.default_rng(seed)``. ``None`` (default)
        uses NumPy's legacy global state, so draws depend on ``np.random.seed``.

    Returns
    -------
    D : numpy.ndarray of shape (n, 3)
        One row per sample; columns are ordered as ``names``.
    names : list of str
        Column names ``['Y', 'X', 'Z']``.
    """
    # Default to the module-level sampler so seeding via np.random.seed keeps working.
    normal = np.random.normal if rng is None else rng.normal
    Y = normal(0, 1, n)
    X = normal(Y, 1, n)  # loc broadcasts: each X_i is centred on its own Y_i
    Z = normal(X, 1, n)

    D = np.vstack([Y, X, Z]).T
    return D, ['Y', 'X', 'Z']

def get_mvn():
    """Fit a multivariate Gaussian model to the simulated Cowell data.

    The sample mean vector and sample covariance matrix of the data returned
    by ``get_cowell_data()`` parameterize a ``GaussianInference`` model whose
    variables are named by that function's column names.

    Returns
    -------
    GaussianInference
        Model built from (names, means, covariance).
    """
    data, names = get_cowell_data()
    means = np.mean(data, axis=0)
    # Rows are samples and columns are variables, hence rowvar=False (same as np.cov(data.T)).
    covariance = np.cov(data, rowvar=False)
    return GaussianInference(names, means, covariance)

g = get_mvn()
[2]:
import pandas as pd

# Tabulate the fitted model's marginals: one row per variable with its name, mean and variance.
pd.DataFrame(g.marginals)
[2]:
name mean var
0 Y -0.001723 0.990700
1 X 0.007448 2.016406
2 Z 0.002459 3.033838

Visualize

[3]:
from ipywidgets import interact

# Samples from the unconditioned model, drawn once and reused as the reference in every redraw.
samples1 = g.sample_marginals(size=10000)

@interact(x=(-5, 5, 1))
def f(x=None):
    """Plot each variable's marginal before and after observing ``X = x``.

    Called by the ``interact`` slider whenever ``x`` changes. In every panel
    the blue histogram is the unconditioned sample ``samples1``. For
    non-evidence variables the conditioned samples are overlaid as a
    histogram and as a green KDE on a twin y-axis. The evidence variable
    ``X`` shows only its unconditioned histogram plus a red line at ``x``.

    Parameters
    ----------
    x : int or None
        Observed value of ``X``. ``None`` (only reachable by calling ``f``
        directly, not from the slider) means no evidence: the unconditioned
        model is sampled, every panel is drawn the same way, and no
        observation line is added.
    """
    evidence = 'X'
    observed = x is not None
    gg = g.do_inference(evidence, x) if observed else g

    samples2 = gg.sample_marginals(size=5000)

    fig, axes = plt.subplots(1, 3, figsize=(15, 3))
    axes = np.ravel(axes)

    kind = 'hist'
    alpha = 0.15
    for (name, s2), ax in zip(samples2.items(), axes):
        ax2 = ax.twinx()
        if name == evidence and observed:
            # X is clamped to the observed value, so its conditioned samples are not drawn.
            # Guarding on `observed` also avoids axvline(x=None), which raises in matplotlib.
            samples1[name].plot(kind=kind, ax=ax2, color='blue', alpha=alpha)
            ax.axvline(x=x, color='red')
        else:
            samples1[name].plot(kind=kind, ax=ax, color='blue', alpha=alpha)
            s2.plot(kind=kind, ax=ax)
            s2.plot(kind='kde', ax=ax2, color='green')

        ax2.set_ylabel('')
        ax.set_title(name)
        ax.set_ylabel('')

    plt.tight_layout()