# src_z3_overconfidence

# Full_run.py

# %%
 
import argparse, json, math, os
import numpy as np
from sklearn.datasets import load_digits, make_classification
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score
import z3
import torch, torch.nn as nn, torch.nn.functional as F
import matplotlib.pyplot as plt

def wilson_lcb(correct: int, total: int, z: float = 1.96) -> float:
    """Return the lower end of the Wilson score interval for a binomial rate.

    Args:
        correct: number of successes (here: correctly classified samples).
        total: number of trials; a non-positive value yields 0.0.
        z: standard-normal quantile of the two-sided interval (1.96 ~ 95%).

    Returns:
        The lower confidence bound, clamped below at 0.0.
    """
    if total <= 0:
        return 0.0
    rate = correct / total
    z_sq = z * z
    # Shared normaliser of the Wilson centre and half-width.
    scale = 1.0 + z_sq / total
    midpoint = (rate + z_sq / (2 * total)) / scale
    half_width = z * math.sqrt((rate * (1 - rate) + z_sq / (4 * total)) / total) / scale
    return max(0.0, midpoint - half_width)

def make_ood_like(X, rng: np.random.Generator, noise_sigma=0.8):
    """Build an out-of-distribution-style copy of ``X``.

    The columns are shuffled with a single random permutation, Gaussian noise
    with standard deviation ``noise_sigma`` is added, and the result is
    clipped to [0, 1]. ``rng`` is consumed in that order (permutation first,
    then noise), so the output is reproducible for a given generator state.
    """
    column_order = rng.permutation(X.shape[1])
    shuffled = X[:, column_order]
    noisy = shuffled + rng.normal(0, noise_sigma, shuffled.shape)
    return np.clip(noisy, 0, 1)

def temp_hack(probs, factor=8.0):
    """Sharpen each row of a probability matrix by raising it to ``factor``.

    Mathematically this is ``probs**factor`` renormalised per row, i.e. a
    softmax of ``factor * log(probs)``. It is evaluated in log space with the
    row maximum subtracted. For large ``factor`` the direct power underflows
    every entry of a row to 0, and ``p / p.sum()`` then yields NaN; the
    log-space form avoids that.

    Args:
        probs: 2-D array (the reductions use ``axis=1``) of non-negative
            values, one row per sample.
        factor: positive sharpening exponent; values > 1 make each row more
            peaked (artificially over-confident).

    Returns:
        Array of the same shape whose rows sum to 1. Zero entries stay 0; a
        row that is entirely zero yields NaN, as with the direct formula.
    """
    probs = np.asarray(probs)
    with np.errstate(divide="ignore"):
        log_p = np.log(probs)  # exact zeros map to -inf -> probability 0
    scaled = factor * log_p
    # Shift so the row maximum is 0 before exponentiating (no underflow of
    # the dominant entries, no overflow).
    scaled = scaled - scaled.max(axis=1, keepdims=True)
    p = np.exp(scaled)
    return p / p.sum(axis=1, keepdims=True)

def make_imbalance(seed=0, n_samples=4000):
    """Train a small MLP on a heavily imbalanced synthetic two-class task.

    ``make_classification`` generates a 98% / 2% class split. The data are
    halved with a stratified split, and an ``MLPClassifier`` with one hidden
    layer of 32 units is fitted on the first half.

    Returns:
        (probs, y_test, accuracy): ``predict_proba`` output on the held-out
        half (one column per class), the held-out labels, and the accuracy of
        the arg-max predictions.
    """
    features, labels = make_classification(
        n_samples=n_samples,
        n_features=32,
        n_informative=6,
        n_redundant=2,
        weights=[0.98, 0.02],
        random_state=seed,
    )
    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, test_size=0.5, random_state=seed, stratify=labels
    )
    model = MLPClassifier(hidden_layer_sizes=(32,), max_iter=200, random_state=seed)
    model.fit(X_train, y_train)
    test_probs = model.predict_proba(X_test)
    accuracy = accuracy_score(y_test, test_probs.argmax(axis=1))
    return test_probs, y_test, accuracy

def build_bins(conf, acc, num_bins=6, wilson_z=1.96):
    """Summarise predictions into equal-width confidence bins.

    ``[0, 1]`` is cut into ``num_bins`` equal intervals. Each interval is
    half-open ``[lo, hi)``, except the last, which also includes ``hi``.
    Bins that receive no samples are left out of the result.

    Args:
        conf: per-sample confidence values (numpy array).
        acc: per-sample 0/1 correctness indicators aligned with ``conf``.
        num_bins: number of equal-width bins.
        wilson_z: z-value forwarded to ``wilson_lcb``.

    Returns:
        List of dicts with keys ``bin``, ``lo``, ``hi``, ``conf`` (mean
        confidence inside the bin), ``LCB`` (Wilson lower bound on the bin's
        accuracy) and ``support`` (number of samples in the bin).
    """
    edges = np.linspace(0.0, 1.0, num_bins + 1)
    last = num_bins - 1
    summary = []
    for idx, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        upper_ok = conf <= hi if idx == last else conf < hi
        members = (conf >= lo) & upper_ok
        support = int(members.sum())
        if not support:
            continue
        n_correct = int(acc[members].sum())
        summary.append({
            "bin": idx,
            "lo": float(lo),
            "hi": float(hi),
            "conf": float(conf[members].mean()),
            "LCB": float(wilson_lcb(n_correct, support, wilson_z)),
            "support": support,
        })
    return summary

def plot_bins(meta, out_path="calibration_plot.png", title="Calibration bins"):
    """Plot per-bin mean confidence and accuracy LCB, then save to ``out_path``.

    ``meta`` is a list of bin dicts such as ``build_bins`` returns. The x-axis
    is the bin index, not the confidence value. The figure is closed after
    saving.
    """
    xs = [entry["bin"] for entry in meta]
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(xs, [entry["conf"] for entry in meta], marker="o", label="mean_confidence")
    ax.plot(xs, [entry["LCB"] for entry in meta], marker="o", label="LCB_accuracy")
    ax.set_title(title)
    ax.set_xlabel("bin index")
    ax.set_ylabel("value")
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(out_path, dpi=180)
    plt.close(fig)

def z3_prob_vector_for_violating_bin(meta, epsilon, num_classes, min_support=0):
    """Ask Z3 for a probability vector whose confidence lands in an over-confident bin.

    A bin is "violating" when its mean confidence exceeds its accuracy LCB by
    more than ``epsilon`` (``conf > LCB + epsilon``) and it has at least
    ``min_support`` samples. Among violating bins, the one with the largest
    ``conf - LCB`` gap is targeted.

    Z3 then searches the probability simplex (``p_i >= 0``, ``sum(p) == 1``)
    for a vector whose top-1 entry (the predicted class) satisfies both:
    - it lies inside the bin (``[lo, hi)``, or ``[lo, 1]`` for the last bin);
    - it is strictly above ``LCB + epsilon``.

    Args:
        meta: bin summaries as produced by ``build_bins`` (keys ``lo``,
            ``hi``, ``conf``, ``LCB``, ``support``).
        epsilon: required margin above the bin's LCB.
        num_classes: length of the probability vector.
        min_support: minimum bin support to be eligible. The default 0 keeps
            every non-empty bin eligible.

    Returns:
        ``(p_star, target_idx, (lo, hi), LCB + epsilon)``, where ``p_star`` is
        a float32 numpy array. Returns ``None`` when no bin violates (this
        includes an empty ``meta``) or when Z3 finds no such vector.
    """
    candidates = [
        b for b in meta
        if b["support"] >= min_support and b["conf"] > b["LCB"] + epsilon
    ]
    if not candidates:
        return None
    vb = max(candidates, key=lambda b: b["conf"] - b["LCB"])
    lo, hi, LCB = vb["lo"], vb["hi"], vb["LCB"]

    solver = z3.Solver()
    p = [z3.Real(f"p_{i}") for i in range(num_classes)]
    pred = [z3.Bool(f"pred_{i}") for i in range(num_classes)]
    for p_i in p:
        solver.add(p_i >= 0)
    solver.add(z3.Sum(p) == 1)
    # Exactly one class is marked as the prediction ...
    solver.add(z3.PbEq([(pred_i, 1) for pred_i in pred], 1))
    # ... and it must be the arg-max. Without this, a bin below 0.5 could be
    # "hit" by a non-maximal entry while another class dominates, so the
    # vector's real confidence would lie outside the bin.
    for i in range(num_classes):
        for j in range(num_classes):
            if i != j:
                solver.add(z3.Implies(pred[i], p[i] >= p[j]))
    conf_expr = z3.Sum([z3.If(pred[i], p[i], 0) for i in range(num_classes)])
    solver.add(conf_expr >= lo)
    # Bins are half-open [lo, hi) except the last one, which is closed at 1.0.
    solver.add(conf_expr <= hi if hi >= 1.0 else conf_expr < hi)
    solver.add(conf_expr > LCB + epsilon)
    if solver.check() != z3.sat:
        return None
    model = solver.model()

    def _to_float(val):
        # Linear real constraints yield rational model values; the algebraic
        # branch is a defensive fallback.
        if z3.is_rational_value(val):
            return float(val.as_fraction())
        if z3.is_algebraic_value(val):
            return float(val.approx(20).as_fraction())
        raise TypeError(f"unexpected Z3 value for a probability: {val}")

    p_star = np.array(
        [_to_float(model.eval(p_i, model_completion=True)) for p_i in p],
        dtype=np.float32,
    )
    target_idx = next(
        i for i, pred_i in enumerate(pred)
        if z3.is_true(model.eval(pred_i, model_completion=True))
    )
    return p_star, target_idx, (lo, hi), LCB + epsilon

def train_torch_digits(seed=0):
    """Train a 64-64-10 MLP on sklearn's digits data and report test accuracy.

    Features are scaled by 1/16 and cast to float32. ``seed`` controls four
    things:
    - the stratified 80/20 train/test split;
    - the torch weight initialisation;
    - the mini-batch shuffling;
    - reproducibility of the whole run.

    Torch seeding happens inside ``torch.random.fork_rng``, so the caller's
    global CPU RNG state is restored afterwards.

    Returns:
        (model, acc): the trained model (switched to eval mode) and its
        accuracy on the held-out 20%.
    """
    X, y = load_digits(return_X_y=True)
    X = (X / 16.0).astype(np.float32)
    Xtr, Xt, ytr, yt = train_test_split(X, y, test_size=0.2, stratify=y, random_state=seed)
    Xtr_t = torch.tensor(Xtr)
    ytr_t = torch.tensor(ytr, dtype=torch.long)
    Xt_t = torch.tensor(Xt)
    yt_t = torch.tensor(yt, dtype=torch.long)

    class MLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(64, 64)
            self.fc2 = nn.Linear(64, 10)

        def forward(self, x):
            return self.fc2(torch.relu(self.fc1(x)))

    batch_size, epochs = 128, 15
    # Previously `seed` only affected the data split, so the weights and the
    # shuffling changed on every call. Seed locally to make runs
    # reproducible without clobbering the global RNG.
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(seed)
        model = MLP()
        opt = torch.optim.Adam(model.parameters(), lr=1e-3)
        model.train()
        for _ in range(epochs):
            order = torch.randperm(len(Xtr_t))
            for start in range(0, len(Xtr_t), batch_size):
                batch = order[start:start + batch_size]
                loss = F.cross_entropy(model(Xtr_t[batch]), ytr_t[batch])
                opt.zero_grad()
                loss.backward()
                opt.step()
    model.eval()
    with torch.no_grad():
        acc = (model(Xt_t).argmax(1) == yt_t).float().mean().item()
    return model, acc

def cegis_input(model, p_star, target_idx, lo, hi, lcb_plus_eps, steps=1200, lr=0.12, lam=1e-3,
                seed=0, n_features=64):
    """Search for an input whose softmax output lands in the target confidence window.

    Adam runs on ``x`` (kept in ``[0, 1]^n_features`` by clamping after each
    step) to minimise ``KL(p_star || softmax(model(x))) + lam * ||x||^2``.
    The search stops as soon as class ``target_idx`` is the arg-max and its
    probability lies in ``[max(lo, lcb_plus_eps), hi]``.

    The window test runs on the current ``x`` *before* it is updated. When
    ``hit`` is not None, the returned ``x`` / ``q`` are therefore exactly the
    pair that satisfied it. Previously, one more optimiser step was taken
    after the hit, which could move ``q`` back out of the window.

    Args:
        model: torch module mapping a 1-D tensor of ``n_features`` values to
            1-D logits. Presumably already in eval mode; its mode is not
            changed here, and its parameter ``.grad`` buffers are not written.
        p_star: target probability vector (numpy array) for the KL term.
        target_idx: class whose probability must hit the window.
        lo, hi: bin bounds; ``lcb_plus_eps`` raises the lower bound.
        steps, lr, lam: iteration budget, Adam learning rate, L2 weight.
        seed: seed of a local generator for the random start. The global
            torch RNG is not touched.
        n_features: input dimensionality (64 for the 8x8 digits model).

    Returns:
        (x, q, hit): the final input and the model's softmax output for it,
        both as numpy arrays, and the step index of the hit (None if there
        was no hit within ``steps``).
    """
    p_star_t = torch.tensor(p_star)
    gen = torch.Generator().manual_seed(seed)
    x = torch.rand(n_features, generator=gen, requires_grad=True)
    opt_x = torch.optim.Adam([x], lr=lr)
    floor = max(lo, lcb_plus_eps)

    def kl(p, q):
        return torch.sum(p * (torch.log(p + 1e-9) - torch.log(q + 1e-9)))

    def in_window(q):
        # Confidence is the top-1 probability, so the target must also be the arg-max.
        return int(q.argmax()) == target_idx and floor <= float(q[target_idx]) <= hi

    hit = None
    for t in range(steps):
        q = torch.softmax(model(x), dim=0)
        if in_window(q.detach()):
            hit = t
            break
        loss = kl(p_star_t, q) + lam * torch.sum(x * x)
        # Differentiate w.r.t. x only; the model's parameter grads stay untouched.
        (grad_x,) = torch.autograd.grad(loss, [x])
        x.grad = grad_x
        opt_x.step()
        with torch.no_grad():
            x.clamp_(0.0, 1.0)
    with torch.no_grad():
        q = torch.softmax(model(x), dim=0)
    return x.detach().numpy(), q.detach().numpy(), hit



# %%
 
import os, json, numpy as np, z3
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score

# Experiment grid. Each run does the following:
# 1. Derive a probability matrix: clean, artificially sharpened, OOD-shifted,
#    or from an imbalanced binary task.
# 2. Bin it by confidence and write the calibration artifacts.
# 3. Ask Z3 whether an over-confident bin exists.
# 4. For 10-class runs, try to synthesise an input landing in the worst bin.
runs = [
    {"tag":"temp_hack_strong", "mode":"temp_hack", "temp_factor":10.0, "bins":6, "min_support":30, "epsilon":0.02},
    {"tag":"ood_shift",        "mode":"ood",       "noise_sigma":0.9,  "bins":6, "min_support":20, "epsilon":0.02},
    {"tag":"imbalance",        "mode":"imbalance",                         "bins":6, "min_support":10, "epsilon":0.02},
    {"tag":"clean_strict",     "mode":"clean",                             "bins":6, "min_support":50, "epsilon":0.02},
]

# base data once
X_all, y_all = load_digits(return_X_y=True)
X_all = X_all / 16.0

# The split and the sklearn model are identical for every run (fixed
# random_state), so train them once instead of once per run.
Xtr, Xt, ytr, yt = train_test_split(X_all, y_all, test_size=0.2, random_state=0, stratify=y_all)
clf = MLPClassifier(hidden_layer_sizes=(64,), max_iter=120, random_state=0).fit(Xtr, ytr)
probs_clean = clf.predict_proba(Xt)
acc_clean = accuracy_score(yt, probs_clean.argmax(1))

# The torch digits model for CEGIS is trained lazily, at most once.
torch_model = None

for r in runs:
    tag = r["tag"]
    mode = r["mode"]
    outdir = os.path.join(".", f"out_{tag}")
    os.makedirs(outdir, exist_ok=True)
    print(f"\n--- RUN: {tag} (mode={mode}) ---")
    print(f"[sklearn] clean accuracy={acc_clean:.3f}")

    # --- choose mode
    if mode == "clean":
        probs, y_eval = probs_clean, yt
    elif mode == "temp_hack":
        # Over-sharpen the clean probabilities (artificial over-confidence).
        probs, y_eval = temp_hack(probs_clean, factor=r.get("temp_factor", 8.0)), yt
    elif mode == "ood":
        # Shift the *test* features (same rows as yt) so labels stay aligned.
        X_ood = make_ood_like(Xt, np.random.default_rng(0), noise_sigma=r.get("noise_sigma", 0.8))
        probs, y_eval = clf.predict_proba(X_ood), yt
    elif mode == "imbalance":
        probs, y_eval, _ = make_imbalance(seed=0, n_samples=4000)
    else:
        raise ValueError("unknown mode")

    # Sanity check for every mode (explicit raise: `assert` vanishes under -O).
    if probs.shape[0] != len(y_eval):
        raise ValueError(f"{mode} size mismatch: probs={probs.shape[0]} vs labels={len(y_eval)}")

    # --- eval vectors
    K = probs.shape[1]
    preds = probs.argmax(axis=1)
    conf = probs.max(axis=1)
    acc = (preds == y_eval).astype(int)
    print(f"[eval] mean conf={conf.mean():.3f}, accuracy={acc.mean():.3f}")

    # --- bins + artifacts
    meta = build_bins(conf, acc, num_bins=r["bins"], wilson_z=1.96)
    for b in meta:
        b["num_classes"] = K

    with open(os.path.join(outdir, "calibration.json"), "w", encoding="utf-8") as fh:
        json.dump({"bins": meta, "mode": mode, "epsilon": r["epsilon"], "min_support": r["min_support"]},
                  fh, indent=2)

    plot_bins(meta, out_path=os.path.join(outdir, "calibration_plot.png"), title=f"Calibration bins ({tag})")

    # --- SMT decision (exists violating bin with support >= min_support)
    s = z3.Solver()
    B = [z3.Bool(f"Bin_{i}") for i in range(len(meta))]
    for B_i, b in zip(B, meta):
        s.add(z3.Implies(B_i, b["conf"] > b["LCB"] + r["epsilon"]))
        s.add(z3.Implies(B_i, b["support"] >= r["min_support"]))
    s.add(z3.Or(B) if B else z3.BoolVal(False))
    print("SMT overconfident bin exists?", "YES" if s.check() == z3.sat else "NO")

    # --- Z3: violating probability target + CEGIS (only if K matches torch model)
    out = z3_prob_vector_for_violating_bin(meta, r["epsilon"], K)
    if out is None:
        print("  -> no violating probability vector (no non-empty gap).")
        print(f"Artifacts -> {os.path.abspath(outdir)}: calibration.json, calibration_plot.png")
        continue

    p_star, target_idx, (lo, hi), lcb_eps = out
    print(f"  -> Z3 p* conf={p_star[target_idx]:.4f}, bin=({lo:.2f},{hi:.2f}), LCB+eps={lcb_eps:.4f}, K={K}")

    # CEGIS demo is implemented for the 10-class digits model (64→64→10).
    if K != 10:
        print(f"  -> [CEGIS] skipped (Torch demo model is 10-class; this run has K={K}). "
              f"Artifacts (calibration.json/plot) written.")
        print(f"Artifacts -> {os.path.abspath(outdir)}: calibration.json, calibration_plot.png")
        continue

    # ---- 10-class only: torch digits model (trained once) + CEGIS input synthesis
    if torch_model is None:
        torch_model, _ = train_torch_digits(seed=0)
    x_img, q, hit = cegis_input(torch_model, p_star, target_idx, lo, hi, lcb_eps,
                                steps=1200, lr=0.12, lam=1e-3)
    print(f"  -> CEGIS hit? {'YES' if hit is not None else 'NO'} at step {hit}; q[target]={q[target_idx]:.4f}")

    # matplotlib.pyplot is already imported at the top of the file as `plt`.
    plt.figure(figsize=(3, 3))
    plt.imshow(x_img.reshape(8, 8), cmap="gray")
    plt.axis("off")
    plt.title(f"{tag} q[target]={q[target_idx]:.3f}")
    plt.tight_layout()
    plt.savefig(os.path.join(outdir, "cegis_input.png"), dpi=180)
    plt.close()

    print(f"Artifacts -> {os.path.abspath(outdir)}: calibration.json, calibration_plot.png, cegis_input.png")