Exact Inference with Widgets

Here, we show a very simple example of how to observe the marginal posterior probabilities of every node given the observed state of a single node. We will use the Huang graph [HD99].

Simulate data

[1]:
%matplotlib inline
from pybbn.graph.dag import BbnUtil
from pybbn.graph.jointree import EvidenceBuilder, EvidenceType
from pybbn.pptc.inferencecontroller import InferenceController
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from collections import namedtuple

# Lightweight record pairing a node's name with its marginal (stored in `s`).
Marginal = namedtuple('Marginal', ['name', 's'])

# Use the ggplot look for every figure drawn in this notebook.
plt.style.use('ggplot')

# Fix NumPy's global random seed.
np.random.seed(37)

def potential_to_series(p):
    """
    Convert a potential into a pandas Series.

    For every entry in ``p.entries``, the first value of that entry's own
    ``entries`` mapping is used as the index label and the entry's ``value``
    is used as the data point.

    :param p: Object exposing an iterable ``entries``; each item must have an
        ``entries`` attribute with a ``values()`` method and a ``value``
        attribute.
    :return: ``pd.Series`` of the entry values, indexed by the labels, in the
        order the entries are iterated.
    :raises IndexError: If an entry's ``entries`` mapping has no values.
    """
    labels = []
    values = []

    for entry in p.entries:
        # list(...) makes the result indexable whether values() returns a
        # list or a dict view, so no try/except fallback (and no second
        # values() call) is needed.
        labels.append(list(entry.entries.values())[0])
        values.append(entry.value)

    return pd.Series(values, index=labels)

def get_marginals(join_tree):
    """
    Collect one ``Marginal`` record per BBN node of the join tree.

    :param join_tree: Object providing ``get_bbn_nodes()`` and
        ``get_bbn_potential(node)``.
    :return: List of ``Marginal(name, s)`` tuples, where ``name`` is the
        node's variable name and ``s`` is the Series built from the node's
        potential. The list follows the order of ``get_bbn_nodes()``.
    """
    return [
        Marginal(
            bbn_node.variable.name,
            potential_to_series(join_tree.get_bbn_potential(bbn_node)),
        )
        for bbn_node in join_tree.get_bbn_nodes()
    ]

# Load the pre-defined Huang graph that ships with pybbn; `bbn` is also used
# later to size the plot grid from its node count.
bbn = BbnUtil.get_huang_graph()

# Convert the BBN to a join tree; this is the object the widget callback
# sets evidence on and reads marginals from.
join_tree = InferenceController.apply(bbn)

Visualize

[2]:
import math
from ipywidgets import interact

@interact(a=[('unobserved', -1), ('off', 0), ('on', 1)])
def f(a=-1):
    """
    Redraw the marginals of every node for the selected state of node 'a'.

    :param a: Widget selection. ``-1`` clears all evidence, ``1`` observes
        node 'a' as 'on', and any other value observes it as 'off'.
    :return: None. The bar charts are drawn as a side effect.
    """
    n_cols = 4
    n_rows = math.ceil(len(bbn.get_nodes()) / n_cols)

    # Build the evidence (if any) first, then reset the tree and apply it.
    has_evidence = a != -1
    if has_evidence:
        state = 'off' if a != 1 else 'on'
        builder = EvidenceBuilder()
        builder = builder.with_node(join_tree.get_bbn_node_by_name('a'))
        builder = builder.with_evidence(state, 1.0)
        evidence = builder.build()

    join_tree.unobserve_all()
    if has_evidence:
        join_tree.set_observation(evidence)

    # Order the panels by node name.
    ordered = sorted(get_marginals(join_tree), key=lambda m: m.name)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5), sharey=True)

    for (node_name, series), ax in zip(ordered, np.ravel(axes)):
        series.plot(kind='bar', legend=False, ax=ax)
        ax.set_title(node_name)
        ax.set_ylim([0.0, 1.0])
        ax.set_xlabel('')

    plt.tight_layout()