src_model_equivalence

Full_run.py

# %%
# A1. Setup & data
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

SEED = 7
# Synthetic task: 8000 samples, 6 features (5 informative, 1 redundant).
X, y = make_classification(
    n_samples=8000, n_features=6, n_informative=5, n_redundant=1,
    n_clusters_per_class=2, class_sep=1.6, random_state=SEED
)
# 60 / 20 / 20 split: train, validation (used later for greedy pruning), test.
X_tr, X_tmp, y_tr, y_tmp = train_test_split(X, y, test_size=0.40, random_state=SEED)
X_val, X_te, y_val, y_te = train_test_split(X_tmp, y_tmp, test_size=0.50, random_state=SEED)

# Domain bounds (box) for verification — tighten if needed.
# Per-feature min/max of the training split; the Z3 queries and the uniform
# sampling checks further down only consider points inside this box.
lo = X_tr.min(axis=0)
hi = X_tr.max(axis=0)

# A2. Large forest (the reference model the pruned forest is compared against)
big = RandomForestClassifier(
    n_estimators=60, max_depth=8, min_samples_leaf=2,
    random_state=SEED, n_jobs=-1
).fit(X_tr, y_tr)

print("Big RF — train/val acc:",
      (big.score(X_tr, y_tr), big.score(X_val, y_val)))


# %%
from copy import deepcopy
from sklearn.metrics import accuracy_score

def prune_forest_by_val_greedy(rf: "RandomForestClassifier", X_val, y_val, target_trees: int):
    """Drop trees that least reduce validation accuracy, until ``target_trees`` remain.

    At each step every remaining tree is tentatively removed, and the tree whose
    removal leaves the highest validation accuracy is dropped. Ties go to the
    later index. A candidate sub-forest predicts with the same rule as
    ``RandomForestClassifier.predict``: argmax of the averaged per-tree
    ``predict_proba``, mapped through ``rf.classes_``. The greedy objective
    therefore agrees with ``pruned.score(X_val, y_val)``, including exact ties.

    Args:
        rf: fitted forest; it is not modified.
        X_val, y_val: validation data used to score candidate sub-forests.
        target_trees: number of trees to keep (must be >= 1).

    Returns:
        ``(pruned, keep)``. ``keep`` lists the kept indices into
        ``rf.estimators_`` in ascending order. ``pruned`` is a deep copy of
        ``rf`` whose ``estimators_`` holds the *same* tree objects as ``rf`` at
        those indices, so callers can match kept trees by identity. When
        ``target_trees >= len(rf.estimators_)`` every tree is kept, and the
        return value is still the ``(pruned, keep)`` tuple.

    Raises:
        ValueError: if ``target_trees < 1``.
    """
    if target_trees < 1:
        raise ValueError(f"target_trees must be >= 1, got {target_trees}")

    keep = list(range(len(rf.estimators_)))
    if len(keep) > target_trees:
        y_true = np.asarray(y_val)
        # Compute each tree's class probabilities once and reuse them for every
        # candidate subset: shape (n_trees, n_samples, n_classes).
        per_tree = np.stack([est.predict_proba(X_val) for est in rf.estimators_])

        def subset_accuracy(idx):
            # Same decision rule as RandomForestClassifier.predict.
            proba = per_tree[idx].mean(axis=0)
            pred = rf.classes_.take(np.argmax(proba, axis=1))
            return float(np.mean(pred == y_true))

        while len(keep) > target_trees:
            best_drop, best_acc = None, -1.0
            for i in keep:
                acc = subset_accuracy([j for j in keep if j != i])
                if acc >= best_acc:  # '>=': among equally good drops, the last index wins
                    best_acc, best_drop = acc, i
            keep.remove(best_drop)

    pruned = deepcopy(rf)
    # Share (do not copy) the kept tree objects so they can be matched by id().
    pruned.estimators_ = [rf.estimators_[i] for i in keep]
    pruned.n_estimators = len(pruned.estimators_)
    return pruned, keep

# Keep 18 trees of the big forest, chosen greedily on the validation split,
# then compare validation accuracy of both forests.
pruned, kept_idx = prune_forest_by_val_greedy(big, X_val, y_val, target_trees=18)
print("Pruned RF — trees kept:", len(kept_idx))
print("Big vs Pruned val acc:", big.score(X_val, y_val), pruned.score(X_val, y_val))


# %%
# C1. Z3 imports
from z3 import Real, Bool, And, Or, Not, If, Solver, sat

def encode_sklearn_tree_as_z3(tree, x_vars):
    """
    Build a Z3 term equal to the hard vote of one fitted sklearn decision tree.

    Each internal node becomes ``If(x_vars[feature] <= threshold, <left>, <right>)``.
    Each leaf becomes ``RealVal(argmax of the leaf's value row)``, i.e. the class
    index the leaf predicts (0 or 1 for a binary task).

    Args:
        tree: fitted estimator exposing ``tree_`` (e.g. an entry of
            ``RandomForestClassifier.estimators_``).
        x_vars: Z3 real variables, indexed by feature number.

    Returns:
        A Z3 real-valued expression over ``x_vars``.

    Note: this is a *hard* vote. sklearn's forest ``predict`` averages the
    leaves' class probabilities instead, so a forest encoding built from these
    terms can disagree with ``rf.predict`` near the decision boundary.
    """
    from z3 import RealVal

    structure = tree.tree_
    left_of = structure.children_left
    right_of = structure.children_right

    def build(node):
        left, right = left_of[node], right_of[node]
        if left == right:  # leaf: both child slots hold the same (sentinel) id
            return RealVal(int(np.argmax(structure.value[node][0])))
        split = x_vars[structure.feature[node]] <= structure.threshold[node]
        return If(split, build(left), build(right))

    return build(0)

def encode_forest_as_z3_avg_vote(rf: "RandomForestClassifier", x_vars):
    """Return a Z3 real term for the average of the forest's per-tree hard votes.

    Each tree is encoded with ``encode_sklearn_tree_as_z3``. For a binary task
    every tree term is 0 or 1, so the average lies in [0, 1].

    Raises:
        ValueError: if ``rf.estimators_`` is empty (the average is undefined).
    """
    estimators = list(rf.estimators_)
    if not estimators:
        raise ValueError("forest has no estimators; average vote is undefined")
    from z3 import Sum

    tree_terms = [encode_sklearn_tree_as_z3(est, x_vars) for est in estimators]
    # Real-sorted sum divided by a Python int stays Real (no integer division).
    return Sum(*tree_terms) / len(tree_terms)

def make_box_constraints(x_vars, lo, hi):
    """Return the Z3 conjunction of ``lo[i] <= x_vars[i] <= hi[i]`` for every variable.

    The bounds are emitted as a lower then an upper bound per variable, in
    variable order. ``lo`` and ``hi`` are indexed by variable position.
    """
    bounds = [
        constraint
        for idx, var in enumerate(x_vars)
        for constraint in (var >= lo[idx], var <= hi[idx])
    ]
    return And(*bounds)


# %%
# D1. Build Z3 problem
from z3 import RealVal

d = X.shape[1]
x = [Real(f"x{i}") for i in range(d)]  # one Z3 real-valued variable per input feature

# Symbolic average hard votes of both forests, over the same input variables.
big_out   = encode_forest_as_z3_avg_vote(big, x)
prun_out  = encode_forest_as_z3_avg_vote(pruned, x)

def to_label(expr):  # >= 0.5 ⇔ class 1
    """Map a Z3 average-vote term to a 0/1 label term; exactly 0.5 maps to 1."""
    threshold = RealVal(0.5)
    positive, negative = RealVal(1), RealVal(0)
    return If(expr >= threshold, positive, negative)

from z3 import is_rational_value, unsat

s = Solver()
s.add(make_box_constraints(x, lo, hi))
s.add(to_label(big_out) != to_label(prun_out))   # ask for a disagreement

# D2. Check
res = s.check()
if res == sat:
    m = s.model()
    # model_completion=True guarantees a concrete value even if Z3's model
    # omits a variable (m[xi] would then be None).
    # The constraints are linear, so model values are rational. Converting the
    # exact fraction (instead of a rounded 10-digit decimal) keeps a point lying
    # exactly on a split threshold on the same side of it.
    # NOTE: sklearn trees evaluate inputs as float32, so a point within float32
    # rounding of a threshold may still route differently — the next cell
    # re-checks with big/pruned.predict.
    _cex_vals = [m.eval(xi, model_completion=True) for xi in x]
    cex = np.array(
        [float(v.as_fraction()) if is_rational_value(v)
         else float(v.as_decimal(20).replace("?", ""))
         for v in _cex_vals],
        dtype=float,
    )
    print("❌ Counterexample found (big vs pruned differ):", cex)
elif res == unsat:
    print("✅ Equivalent on the bounded box [lo, hi].")
else:
    # `unknown` (e.g. a timeout or resource limit) proves nothing either way.
    print(f"⚠️ Solver returned {res}: equivalence not established ({s.reason_unknown()}).")


# %%
# Cross-check the Z3 counterexample against the real sklearn models.
# `cex` exists only if the solver cell found a counterexample; otherwise this
# raises NameError. Z3 encodes hard per-tree votes (leaf argmax), while
# RandomForestClassifier.predict averages per-tree probabilities, so the two
# predictions printed here are not guaranteed to differ.
x_cex = cex.reshape(1, -1)
print("big:",    big.predict(x_cex), big.predict_proba(x_cex)[0])
print("pruned:", pruned.predict(x_cex), pruned.predict_proba(x_cex)[0])


# %%
# Monte-Carlo estimate of how often the two forests' sklearn predictions
# disagree for points drawn uniformly from the box [lo, hi].
rng = np.random.default_rng(0)
Z = rng.uniform(lo, hi, size=(20000, len(lo)))
emp = (big.predict(Z) != pruned.predict(Z)).mean()
print("sampled disagree rate:", emp)


# %% [markdown]
# ### Extras

# %%
import numpy as np

def tree_leaf_and_path(dt, x):
    """
    Route one sample through a fitted sklearn decision tree.

    Args:
        dt: estimator exposing ``tree_``.
        x: 2-D array; only row 0 is used.

    Returns:
        ``(leaf_class, path)``. ``leaf_class`` is the argmax index of the reached
        leaf's value row (a class index, not a ``classes_`` label). ``path``
        lists the visited node ids from the root (0) to the leaf. A sample equal
        to a threshold goes left.
    """
    tree = dt.tree_
    left, right = tree.children_left, tree.children_right
    visited = [0]
    node = 0
    while left[node] != right[node]:
        go_left = x[0, tree.feature[node]] <= tree.threshold[node]
        node = left[node] if go_left else right[node]
        visited.append(node)
    leaf_counts = tree.value[node][0]
    return int(np.argmax(leaf_counts)), visited

def pretty_path(dt, path, feature_names=None, x=None):
    """Render a root-to-leaf node path as a conjunction of split conditions.

    Each consecutive pair ``(node, child)`` in ``path`` yields
    ``"<name> <= thr"`` when ``child`` is the node's left child and
    ``"<name> > thr"`` otherwise. Thresholds use 6 significant digits, and
    names default to ``x[<feature index>]``. When ``x`` (a 1×d array) is given,
    the sample's feature value is appended. A single-node path gives "".
    """
    t = dt.tree_
    pieces = []
    for parent, child in zip(path, path[1:]):
        feat = t.feature[parent]
        thr = t.threshold[parent]
        label = f"x[{feat}]" if feature_names is None else feature_names[feat]
        op = "<=" if child == t.children_left[parent] else ">"
        piece = f"{label} {op} {thr:.6g}"
        if x is not None:
            piece += f"  (x={x[0, feat]:.6g})"
        pieces.append(piece)
    return " ∧ ".join(pieces)

def _flip_coalition(votes, target_label):
    """Return the fewest tree indices whose removal flips a 0/1 hard-vote majority.

    The majority label is ``int(mean(votes) >= 0.5)``. Only removing a tree that
    votes ``target_label`` moves the mean away from ``target_label``. Such trees
    are removed in ascending index order until the remaining trees' label
    changes. Returns ``[]`` if that cannot happen before the forest would be
    empty, i.e. every tree votes ``target_label``.
    """
    votes = np.asarray(votes, dtype=float)
    remaining_sum = float(votes.sum())
    remaining_n = len(votes)
    coalition = []
    for i in np.flatnonzero(votes == target_label):
        remaining_sum -= votes[i]
        remaining_n -= 1
        if remaining_n == 0:
            break
        coalition.append(int(i))
        if int(remaining_sum / remaining_n >= 0.5) != target_label:
            return coalition
    return []


def trace_forest_disagreement(x_cex, big, pruned, feature_names=None, top_k=10):
    """
    Explain, at one input, how BIG and PRUNED reach their hard-vote decisions.

    Forest-level numbers use per-tree ``predict`` outputs, assumed to be 0/1
    class indices (binary task). They do not use the averaged probabilities
    that sklearn's forest ``predict`` uses.

    Prints:
      - mean vote (prob1) and its 0.5-thresholded label for both forests
      - how many kept trees vote differently in the two forests
      - how many trees pruning removed, and their mean vote
      - split paths of up to ``top_k`` removed trees, ranked by single-tree
        effect on BIG's mean vote
      - the size of the smallest set of BIG's trees whose removal flips BIG's
        majority vote at ``x_cex``

    A tree of ``big`` counts as kept when the same object (by ``id``) appears
    in ``pruned.estimators_``.

    Args:
      x_cex: the input point (reshaped to a single row).
      big, pruned: forests exposing ``estimators_``.
      feature_names: optional names used when printing split paths.
      top_k: maximum number of removed trees whose paths are printed.

    Returns:
      dict with "big_prob", "pruned_prob", "removed_idx", "diff_idx" and
      "coalition". Indices refer to ``big.estimators_``. "coalition" is ``[]``
      when no removal can flip BIG's vote.
    """
    xb = x_cex.reshape(1, -1)

    # Query every tree exactly once; all statistics below reuse these outputs.
    big_raw = [t.predict(xb)[0] for t in big.estimators_]
    prun_raw = [t.predict(xb)[0] for t in pruned.estimators_]
    big_prob = np.mean(big_raw)
    prun_prob = np.mean(prun_raw)
    print(f"[forest] big: prob1={big_prob:.6f}, pred={int(big_prob>=0.5)}")
    print(f"[forest] pruned: prob1={prun_prob:.6f}, pred={int(prun_prob>=0.5)}")

    big_votes = np.array(big_raw, dtype=int)

    # Position of each pruned tree, keyed by object identity.
    pruned_pos = {id(t): k for k, t in enumerate(pruned.estimators_)}
    kept_idx = [i for i, t in enumerate(big.estimators_) if id(t) in pruned_pos]

    # Both sides of each comparison refer to the same tree object, so this is
    # expected to be empty; it is kept as a sanity check.
    diff_idx = [
        i for i in kept_idx
        if big_raw[i] != prun_raw[pruned_pos[id(big.estimators_[i])]]
    ]
    print(f"[kept trees] disagreeing predictions: {len(diff_idx)} / {len(kept_idx)}")

    kept_set = set(kept_idx)
    removed_idx = [i for i in range(len(big.estimators_)) if i not in kept_set]
    removed_vote = big_votes[removed_idx].mean() if removed_idx else np.nan
    print(f"[removed trees] count={len(removed_idx)}, mean vote={removed_vote}")

    if removed_idx:
        # Single-tree effect = how far BIG's mean vote would move without that tree.
        base = big_votes.mean()
        total = big_votes.sum()
        effects = sorted(
            ((abs(base - (total - big_votes[i]) / (len(big_votes) - 1)), i)
             for i in removed_idx),
            reverse=True,
        )
        show = [j for _, j in effects[:min(top_k, len(effects))]]
        print(f"[removed trees] top contributors (by single-tree effect on mean vote): {show}")
        for i in show:
            leaf, path = tree_leaf_and_path(big.estimators_[i], xb)
            print(f"  - tree#{i} vote={leaf} | path: {pretty_path(big.estimators_[i], path, feature_names, xb)}")

    # Smallest "responsible coalition" of BIG's trees for its decision at x_cex.
    target_label = int(big_prob >= 0.5)
    coalition = _flip_coalition(big_votes, target_label)
    if coalition:
        print(f"[minimal-ish coalition] #trees to flip BIG at x_cex (greedy): {len(coalition)}")
    else:
        print("[minimal-ish coalition] BIG's vote at x_cex cannot be flipped by removing trees (all trees vote alike)")
    return {
        "big_prob": big_prob, "pruned_prob": prun_prob,
        "removed_idx": removed_idx, "diff_idx": diff_idx, "coalition": coalition
    }


# %%
# Explain the solver's counterexample `cex` (NameError if the solver cell found none).
_ = trace_forest_disagreement(cex, big, pruned, feature_names=None, top_k=8)


# %% [markdown]
# ### Margin-based counterexamples (average votes differ by at least eps)
# 

# %%
from z3 import Solver, Real, RealVal, And, Or, Sum, ToReal, If, sat

def z3_counterexample_margin(big, pruned, lo, hi, eps=0.10):
    """Find a point in the box [lo, hi] where the forests' average votes differ by >= eps.

    Both forests are encoded with ``encode_forest_as_z3_avg_vote`` (mean of
    per-tree hard votes). Z3 is then asked for ``x`` with ``lo <= x <= hi``
    and ``|avg_big(x) - avg_pruned(x)| >= eps``.

    Args:
        big, pruned: fitted forests exposing ``estimators_``.
        lo, hi: per-feature lower/upper bounds; ``len(lo)`` is the number of features.
        eps: required gap between the two average votes.

    Returns:
        The witness point as a float array of length ``len(lo)``, or ``None``
        when the solver does not answer ``sat``. ``None`` covers both "no such
        point exists" and an ``unknown`` result.
    """
    from z3 import is_rational_value

    d = len(lo)
    x = [Real(f"x{i}") for i in range(d)]

    big_out = encode_forest_as_z3_avg_vote(big, x)
    prun_out = encode_forest_as_z3_avg_vote(pruned, x)

    s = Solver()
    s.add(make_box_constraints(x, lo, hi))
    # |big_out - prun_out| >= eps, written as a disjunction of two linear constraints.
    s.add(Or(big_out - prun_out >= RealVal(eps),
             prun_out - big_out >= RealVal(eps)))

    if s.check() != sat:
        return None
    m = s.model()
    # model_completion=True: concrete values even for variables the model omits.
    # Linear constraints give rational values; converting the exact fraction
    # keeps threshold-exact coordinates on the same side of their split.
    values = [m.eval(xi, model_completion=True) for xi in x]
    return np.array(
        [float(v.as_fraction()) if is_rational_value(v)
         else float(v.as_decimal(20).replace("?", ""))
         for v in values],
        dtype=float,
    )

# Example: search for a point where the two forests' average hard votes differ by >= 0.25.
cex_margin = z3_counterexample_margin(big, pruned, lo, hi, eps=0.25)
print("margin-CE:", cex_margin)
if cex_margin is not None:
    # Recompute both hard-vote averages from the sklearn trees at the witness point.
    print("big/pruned probs at margin-CE:",
          np.mean([t.predict(cex_margin.reshape(1,-1))[0] for t in big.estimators_]),
          np.mean([t.predict(cex_margin.reshape(1,-1))[0] for t in pruned.estimators_]))


# %% [markdown]
# ### viz

# %%
# z3_rf_viz.py
import numpy as np
import matplotlib.pyplot as plt

# ---------- smoothing (scipy optional) ----------
def _smooth(Z, sigma):
    if sigma is None or sigma <= 0:
        return Z
    try:
        from scipy.ndimage import gaussian_filter  # noqa
        return gaussian_filter(Z, sigma=float(sigma))
    except Exception:
        # tiny separable box blur fallback
        k = max(1, int(round(sigma * 2)))
        if k % 2 == 0: k += 1
        pad = k // 2
        Zp = np.pad(Z, pad, mode="edge")
        W = np.ones((k, k), dtype=float) / (k * k)
        out = np.zeros_like(Z)
        for i in range(out.shape[0]):
            for j in range(out.shape[1]):
                out[i, j] = (Zp[i:i+k, j:j+k] * W).sum()
        return out

# ---------- core helpers ----------
def rf_vote_prob(estimators, X):
    """Average over ``estimators`` of each one's ``predict_proba`` column 1 (P(class=1)).

    Returns one value per row of ``X``.
    """
    class1_columns = [est.predict_proba(X)[:, 1] for est in estimators]
    return np.asarray(class1_columns).mean(axis=0)

def _grid(lo, hi, center, i, j, n):
    xs = np.linspace(lo[i], hi[i], n)
    ys = np.linspace(lo[j], hi[j], n)
    XX, YY = np.meshgrid(xs, ys)
    base = np.tile(center, (n * n, 1))
    for r in range(n):
        base[r * n:(r + 1) * n, i] = XX[r]
        base[r * n:(r + 1) * n, j] = YY[r]
    return xs, ys, base

def best_dims_for_disagreement(big, pruned, lo, hi, samples=4000, rng=0):
    """Pick the feature pair (i, j), i < j, whose 2-D view shows the most |Δvote| structure.

    Draws ``samples`` points uniformly from the box [lo, hi] (seeded by ``rng``)
    and computes |rf_vote_prob(big) - rf_vote_prob(pruned)| at each. For every
    feature pair it bins the points on a 16×16 grid and scores the pair by the
    variance of the per-cell mean |Δvote|, ignoring empty cells. Returns the
    best pair; ties keep the first pair in (i, j) order.

    Raises:
        ValueError: if fewer than two features are given.
    """
    d = len(lo)
    if d < 2:
        raise ValueError(f"need at least 2 features to choose a 2-D slice, got {d}")
    r = np.random.default_rng(rng)
    X = r.uniform(lo, hi, size=(samples, d))
    dv = np.abs(rf_vote_prob(big.estimators_, X) - rf_vote_prob(pruned.estimators_, X))

    bins = 16
    span = np.asarray(hi, dtype=float) - np.asarray(lo, dtype=float)
    # Zero-width feature: avoid 0/0; all its samples land in bin 0.
    span = np.where(span != 0, span, 1.0)
    cell = np.clip(((X - lo) / span * bins).astype(int), 0, bins - 1)  # per-feature bin index

    scores = {}
    for i in range(d):
        for j in range(i + 1, d):
            flat = cell[:, i] * bins + cell[:, j]
            # bincount accumulates in sample order, same as a per-sample += loop.
            grid_sum = np.bincount(flat, weights=dv, minlength=bins * bins).reshape(bins, bins)
            grid_cnt = np.bincount(flat, minlength=bins * bins).reshape(bins, bins)
            with np.errstate(invalid='ignore', divide='ignore'):
                grid = grid_sum / grid_cnt  # NaN for empty cells
            scores[(i, j)] = np.nanvar(grid)
    return max(scores, key=scores.get)

def _imshow_copper(ZZ, xs, ys, title, cbar_label, xlabel=None, ylabel=None, vmin=None, vmax=None):
    """Display a 2-D array as a 'copper' heatmap and call ``plt.show()``.

    Args:
      ZZ: 2-D array; rows run along ``ys`` (drawn with ``origin="lower"``),
        columns along ``xs``.
      xs, ys: coordinate vectors; their min/max set the image extent.
      title, cbar_label: plot title and colour-bar label.
      xlabel, ylabel: axis labels, skipped when None.
      vmin, vmax: colour limits. A limit left as None is filled from the finite
        entries of ``ZZ`` (5th percentile for ``vmin``, 95th for ``vmax``) so a
        few extreme cells do not wash out the scale. An explicitly passed limit
        is always kept.
    """
    if vmin is None or vmax is None:
        finite = ZZ[np.isfinite(ZZ)]
        if finite.size:
            if vmin is None:
                vmin = np.nanpercentile(finite, 5)
            if vmax is None:
                vmax = np.nanpercentile(finite, 95)

    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(
        ZZ,
        extent=[xs.min(), xs.max(), ys.min(), ys.max()],
        origin="lower",
        aspect="equal",
        cmap="copper",
        vmin=vmin,
        vmax=vmax,
    )
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label(cbar_label)

    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)

    ax.set_title(title)
    fig.tight_layout()
    plt.show()


# ---------- visuals ----------
def plot_vote_diff_slice(
    big, pruned, lo, hi, *, center=None, dims=None, n=220, pts=None,
    eps=None, smooth_sigma=1.0, title="|Δ vote probability| on 2D slice"
):
    """
    Heatmap of |P_big(1) - P_pruned(1)| on a 2D slice.

    P(1) is ``rf_vote_prob``, the mean per-tree ``predict_proba`` for class 1.

    Args:
      center: fixed values for the other dims (default: mid-box)
      dims: (i, j) to plot (default: auto via best_dims_for_disagreement)
      n: grid resolution (use ~150–250 for readability)
      pts: accepted for API compatibility but not drawn. ``_imshow_copper``
        calls ``plt.show()`` itself, so markers cannot be overlaid afterwards.
      eps: if set, values < eps are masked (NaN), after smoothing
      smooth_sigma: gaussian sigma (or box-blur fallback) for visual smoothing
    """
    if center is None:
        center = (lo + hi) / 2.0
    if dims is None:
        dims = best_dims_for_disagreement(big, pruned, lo, hi)
    i, j = dims
    xs, ys, base = _grid(lo, hi, center, i, j, n)

    vb = rf_vote_prob(big.estimators_, base)
    vp = rf_vote_prob(pruned.estimators_, base)
    ZZ = np.abs(vb - vp).reshape(n, n)

    ZZ = _smooth(ZZ, smooth_sigma)
    if eps is not None:
        ZZ = np.where(ZZ >= eps, ZZ, np.nan)

    _imshow_copper(
        ZZ, xs, ys, title,
        cbar_label="|Δ vote prob|",
        xlabel=f"x[{i}]",
        ylabel=f"x[{j}]"
    )

def plot_label_disagreement_region(
    big, pruned, lo, hi, *, center=None, dims=None, n=220, pts=None,
    smooth_sigma=0.0, title="Label disagreement region"
):
    """Heatmap showing where the two forests' 0.5-thresholded labels differ on a 2-D slice.

    Labels are ``rf_vote_prob(...) >= 0.5``. The map is 1 where the labels
    differ and 0 where they agree, optionally blurred by ``_smooth``.
    ``center`` defaults to the box midpoint and ``dims`` to
    ``best_dims_for_disagreement``. ``pts`` is accepted but unused.
    """
    mid = (lo + hi) / 2.0 if center is None else center
    pair = best_dims_for_disagreement(big, pruned, lo, hi) if dims is None else dims
    i, j = pair
    xs, ys, grid_points = _grid(lo, hi, mid, i, j, n)

    big_labels = rf_vote_prob(big.estimators_, grid_points) >= 0.5
    pruned_labels = rf_vote_prob(pruned.estimators_, grid_points) >= 0.5
    disagree = (big_labels != pruned_labels).reshape(n, n).astype(float)
    disagree = _smooth(disagree, smooth_sigma)

    _imshow_copper(
        disagree, xs, ys, title,
        cbar_label="disagree=1 / agree=0",
        xlabel=f"x[{i}]",
        ylabel=f"x[{j}]",
    )


def plot_removed_tree_effects_at_x(
    big, pruned, x_cex, top_k=15, title="Top removed trees at CE"
):
    """Bar chart of how much each pruned-away tree moves BIG's mean P(class 1) at ``x_cex``.

    A tree of ``big`` counts as removed when that object (by ``id``) is absent
    from ``pruned.estimators_``. Each bar is |mean - mean without that tree|,
    computed over BIG's per-tree ``predict_proba[:, 1]`` at ``x_cex``. The
    ``top_k`` largest bars are shown, labelled by tree index. If nothing was
    removed, prints a message and returns.
    """
    row = x_cex.reshape(1, -1)
    p1 = np.array([est.predict_proba(row)[:, 1][0] for est in big.estimators_], dtype=float)
    in_pruned = {id(est) for est in pruned.estimators_}
    removed = [k for k, est in enumerate(big.estimators_) if id(est) not in in_pruned]
    if not removed:
        print("No removed trees.")
        return

    mean_all = p1.mean()
    total = p1.sum()
    others = len(p1) - 1
    ranked = sorted(
        ((abs(mean_all - (total - p1[k]) / others), k) for k in removed),
        reverse=True,
    )[:min(top_k, len(removed))]
    heights = [h for h, _ in ranked]
    labels = [k for _, k in ranked]

    positions = range(len(ranked))
    plt.figure(figsize=(8, 4))
    plt.bar(positions, heights, color="#b87333")
    plt.xticks(positions, labels, rotation=45)
    plt.ylabel("Δ avg vote if removed")
    plt.title(title)
    plt.tight_layout()
    plt.show()

def plot_conflict_where_big_confident(
    big, pruned, lo, hi, *, center=None, dims=None, n=220, m=0.2,
    title="Pruned disagrees where BIG margin is high"
):
    """Map of slice points where the labels differ AND BIG's P(class 1) is at least ``m`` from 0.5.

    Labels are ``rf_vote_prob(...) >= 0.5`` for each forest. The map is 1 at
    such conflicts and 0 elsewhere. ``center`` defaults to the box midpoint and
    ``dims`` to ``best_dims_for_disagreement``.
    """
    mid = (lo + hi) / 2.0 if center is None else center
    pair = best_dims_for_disagreement(big, pruned, lo, hi) if dims is None else dims
    i, j = pair
    xs, ys, grid_points = _grid(lo, hi, mid, i, j, n)

    big_p1 = rf_vote_prob(big.estimators_, grid_points)
    big_label = big_p1 >= 0.5
    pruned_label = rf_vote_prob(pruned.estimators_, grid_points) >= 0.5
    confident = (big_p1 >= 0.5 + m) | (big_p1 <= 0.5 - m)
    conflict_map = ((big_label != pruned_label) & confident).reshape(n, n).astype(float)

    _imshow_copper(
        conflict_map, xs, ys, title,
        cbar_label=f"disagree & margin≥{m}",
        xlabel=f"x[{i}]",
        ylabel=f"x[{j}]",
    )



# %%
# Collect counterexamples found above: `cex` exists only if the solver cell
# found one, and `cex_margin` may be None. The plotting helpers accept `pts`
# but do not currently draw it.
pts = [] 
if 'cex' in globals(): pts.append(cex)
if 'cex_margin' in globals() and cex_margin is not None: pts.append(cex_margin)

# 1) |Δ vote probability| on the auto-selected 2-D slice through the box centre
plot_vote_diff_slice(big, pruned, lo, hi, center=(lo+hi)/2, dims=None, n=300, pts=pts, eps=0.06)

# 2) Show where labels differ on that slice
plot_label_disagreement_region(big, pruned, lo, hi, center=(lo+hi)/2, dims=None, n=300, pts=pts)

# 3) Explain which removed trees mattered at a specific CE
if 'cex' in globals():
    plot_removed_tree_effects_at_x(big, pruned, cex, top_k=12)

# 4) Safety view: disagreements only where BIG is confident by ≥ 0.05 margin (m=0.05)
plot_conflict_where_big_confident(big, pruned, lo, hi, center=(lo+hi)/2, dims=None, n=300, m=0.05)

# 5) (Optional) Force the slice through a particular CE and feature pair
# dims=(0,3) or any other two indices; center=cex to slice through the counterexample
if 'cex' in globals():
    plot_vote_diff_slice(big, pruned, lo, hi, center=cex, dims=(0,3), n=350, pts=[cex], eps=0.05)