diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2e4b5f6..8647a5e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,8 +13,6 @@ jobs: steps: - uses: actions/checkout@v4 - with: - lfs: true - name: Set up Python 3.9 uses: actions/setup-python@v5 @@ -25,7 +23,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -e ".[dev]" + pip install -e ".[dev,calibration]" - name: Run tests with coverage run: | diff --git a/experiments/compare_exposures.py b/experiments/compare_exposures.py new file mode 100644 index 0000000..14fdd06 --- /dev/null +++ b/experiments/compare_exposures.py @@ -0,0 +1,141 @@ +"""Compare effective weight/exposure trajectories: reClAMM vs Balancer 50/50. + +Prints weight stats and saves a plot of weight[AAVE] over time for both pools. +""" + +import jax.numpy as jnp +import numpy as np +import matplotlib.pyplot as plt +from quantammsim.runners.jax_runners import do_run_on_historic_data + + +def to_daily_price_shift_base(exponent): + return 1.0 - exponent / 124649.0 + + +TOKENS = ["AAVE", "ETH"] +START = "2024-06-01 00:00:00" +END = "2025-06-01 00:00:00" + +CONFIGS = { + "reClAMM on-chain (pr=1.5)": { + "fingerprint": { + "tokens": TOKENS, "rule": "reclamm", + "startDateString": START, "endDateString": END, + "initial_pool_value": 1_000_000.0, "do_arb": True, + "fees": 0.0, "gas_cost": 0.0, "arb_fees": 0.0, + "chunk_period": 60, "weight_interpolation_period": 60, + }, + "params": { + "price_ratio": jnp.array(1.5), + "centeredness_margin": jnp.array(0.5), + "daily_price_shift_base": jnp.array(to_daily_price_shift_base(0.1)), + }, + }, + "reClAMM wide (pr=4)": { + "fingerprint": { + "tokens": TOKENS, "rule": "reclamm", + "startDateString": START, "endDateString": END, + "initial_pool_value": 1_000_000.0, "do_arb": True, + "fees": 0.0, "gas_cost": 0.0, "arb_fees": 0.0, + "chunk_period": 60, "weight_interpolation_period": 60, + }, + "params": { + "price_ratio": jnp.array(4.0), + "centeredness_margin": jnp.array(0.2), + "daily_price_shift_base": jnp.array(to_daily_price_shift_base(1.0)), + }, + }, + "reClAMM Phase 2 (pr=4, m=0.1)": { + "fingerprint": { + "tokens": TOKENS, "rule": "reclamm", + "startDateString": START, "endDateString": END, + "initial_pool_value": 1_000_000.0, "do_arb": True, + "fees": 0.0, "gas_cost": 0.0, "arb_fees": 0.0, + "chunk_period": 60, "weight_interpolation_period": 60, + }, + "params": { + "price_ratio": jnp.array(4.0), + "centeredness_margin": jnp.array(0.1), + "daily_price_shift_base": jnp.array(to_daily_price_shift_base(0.001)), + }, + }, + "Balancer 50/50": { + "fingerprint": { + "tokens": TOKENS, "rule": "balancer", + "startDateString": START, "endDateString": END, + "initial_pool_value": 1_000_000.0, "do_arb": True, + "fees": 0.0, "gas_cost": 0.0, "arb_fees": 0.0, + "chunk_period": 60, "weight_interpolation_period": 60, + }, + "params": { + "initial_weights_logits": jnp.zeros(2), + }, + }, +} + +results = {} +for name, cfg in CONFIGS.items(): + print(f"Running {name}...") + r = do_run_on_historic_data( + run_fingerprint=cfg["fingerprint"], params=cfg["params"] + ) + results[name] = r + +# Compute effective weights (value fraction in token 0 = AAVE) +print("\n" + "=" * 90) +print(f" {'Config':<35s} {'w_AAVE mean':>10s} {'w_AAVE std':>10s} " + f"{'w_AAVE min':>10s} {'w_AAVE max':>10s} {'vs HODL':>10s}") +print("-" * 90) + +daily = 1440 # subsample to daily for stats and plotting +fig, axes = plt.subplots(3, 1, figsize=(14, 10), sharex=True) + +for name, r in results.items(): + reserves = np.array(r["reserves"]) + prices = np.array(r["prices"]) + values = reserves * prices # (T, 2) + total = values.sum(axis=1, keepdims=True) + weights = values / np.clip(total, 1e-10, None) # (T, 2) + w_aave = weights[::daily, 0] + + hodl_value = float((reserves[0] * prices[-1]).sum()) + vs_hodl = r["final_value"] / hodl_value - 1.0 + + print(f" {name:<35s} {w_aave.mean():>10.4f} {w_aave.std():>10.4f} " + f"{w_aave.min():>10.4f} {w_aave.max():>10.4f} {vs_hodl * 100:>9.2f}%") + + days = np.arange(len(w_aave)) + axes[0].plot(days, w_aave, label=name, alpha=0.8) + + # Pool value over time + pool_val = np.array(r["value"])[::daily] + axes[1].plot(days[:len(pool_val)], pool_val / 1e6, label=name, alpha=0.8) + +print("=" * 90) + +# HODL line +r0 = results[list(results.keys())[0]] +prices_daily = np.array(r0["prices"])[::daily] +reserves_0 = np.array(r0["reserves"])[0] +hodl_val = (reserves_0 * prices_daily).sum(axis=1) / 1e6 +axes[1].plot(np.arange(len(hodl_val)), hodl_val, label="HODL", ls="--", color="gray", alpha=0.7) + +# Price ratio (AAVE/ETH) on third axis +price_ratio_series = prices_daily[:, 0] / prices_daily[:, 1] +axes[2].plot(np.arange(len(price_ratio_series)), price_ratio_series, color="black", alpha=0.7) +axes[2].set_ylabel("AAVE/ETH price") +axes[2].set_xlabel("Days") + +axes[0].set_ylabel("AAVE weight (value fraction)") +axes[0].axhline(0.5, ls="--", color="gray", alpha=0.5) +axes[0].legend(fontsize=8) +axes[0].set_title("Effective AAVE exposure over time") + +axes[1].set_ylabel("Pool value ($M)") +axes[1].legend(fontsize=8) +axes[1].set_title("Pool value over time") + +plt.tight_layout() +plt.savefig("reclamm_exposure_comparison.png", dpi=150) +print("\nSaved reclamm_exposure_comparison.png") diff --git a/experiments/tune_reclamm_params.py b/experiments/tune_reclamm_params.py index 0951e2c..4969923 100644 --- a/experiments/tune_reclamm_params.py +++ b/experiments/tune_reclamm_params.py @@ -38,6 +38,17 @@ def main(): parser.add_argument("--interpolation", default="geometric", choices=["geometric", "constant_arc_length"]) parser.add_argument("--centeredness-scaling", action="store_true") + parser.add_argument("--noise-trader-ratio", type=float, default=0.0) + parser.add_argument("--start-date", default="2024-06-01 00:00:00") + parser.add_argument("--end-date", default="2025-01-01 00:00:00", + help="End of training / start of test") + parser.add_argument("--end-test-date", default="2025-06-01 00:00:00") + parser.add_argument("--bout-offset", type=int, default=None, + help="bout_offset in minutes (default: 10080 = 7 days)") + parser.add_argument("--val-fraction", type=float, default=None, + help="Validation holdout fraction (default: 0.2, use 0 to disable)") + parser.add_argument("--overfitting-penalty", type=float, default=None, + help="Overfitting penalty weight (default: 0.2)") args = parser.parse_args() learn_speed = args.interpolation == "constant_arc_length" @@ -48,29 +59,33 @@ def main(): fp = { "rule": "reclamm", "tokens": ["AAVE", "ETH"], - "startDateString": "2024-06-01 00:00:00", - "endDateString": "2025-01-01 00:00:00", - "endTestDateString": "2025-06-01 00:00:00", + "startDateString": args.start_date, + "endDateString": args.end_date, + "endTestDateString": args.end_test_date, "initial_pool_value": 1_000_000.0, "do_arb": True, "fees": args.fees, "gas_cost": args.gas_cost, "arb_fees": 0.0, "protocol_fee_split": 0.5, + "noise_trader_ratio": args.noise_trader_ratio, "return_val": args.objective, "reclamm_interpolation_method": args.interpolation, "reclamm_centeredness_scaling": args.centeredness_scaling, "reclamm_learn_arc_length_speed": learn_speed, "reclamm_use_shift_exponent": True, + **({"bout_offset": args.bout_offset} if args.bout_offset is not None else {}), "optimisation_settings": { "method": "optuna", "n_parameter_sets": 1, + **({"val_fraction": args.val_fraction} if args.val_fraction is not None else {}), "optuna_settings": { "make_scalar": True, "expand_around": False, "n_trials": args.n_trials, "multi_objective": False, "parameter_config": param_config, + **({"overfitting_penalty": args.overfitting_penalty} if args.overfitting_penalty is not None else {}), }, }, } diff --git a/pyproject.toml b/pyproject.toml index 413faf9..8859b55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,10 @@ docs = [ "sphinx-automodapi", "sphinx-rtd-theme", ] +calibration = [ + "numpyro>=0.15.0", + "arviz>=0.15.0", +] [tool.hatch.build.targets.wheel] packages = [ diff --git a/quantammsim/calibration/__init__.py b/quantammsim/calibration/__init__.py new file mode 100644 index 0000000..a888758 --- /dev/null +++ b/quantammsim/calibration/__init__.py @@ -0,0 +1,41 @@ +from quantammsim.calibration.grid_interpolation import ( + PoolCoeffs, + PoolCoeffsDaily, + PoolGridInterpolator, + build_scipy_interpolator, + interpolate_pool, + interpolate_pool_daily, + load_daily_grid, + load_valid_pool_grids, + pivot_grid, + precompute_pool_coeffs, + precompute_pool_coeffs_daily, +) +from quantammsim.calibration.joint_fit import ( + JointData, + fit_joint, + predict_new_pool_joint, +) +from quantammsim.calibration.learned_mapping import ( + build_targets, + cross_validate_loo, + fit_mapping, + predict_pool, +) +from quantammsim.calibration.loss import ( + K_OBS, + noise_volume, + pack_params, + pool_loss, + unpack_params, +) +from quantammsim.calibration.per_pool_fit import ( + fit_all_pools, + fit_single_pool, + make_initial_guess, +) +from quantammsim.calibration.pool_data import ( + build_pool_attributes, + build_x_obs, + match_grids_to_panel, +) diff --git a/quantammsim/calibration/grid_interpolation.py b/quantammsim/calibration/grid_interpolation.py new file mode 100644 index 0000000..bf0c55f --- /dev/null +++ b/quantammsim/calibration/grid_interpolation.py @@ -0,0 +1,463 @@ +"""PCHIP interpolation layer for precomputed arb-volume grids. + +Two grid formats: + - v1 (scalar): cadence x gas_cost -> median daily V_arb (single scalar) + - v2 (daily): cadence x gas_cost x day -> per-day V_arb (vector output) + +Interpolation in (log(cadence), gas_cost) space using PCHIP (monotone +piecewise cubic Hermite), which avoids Runge oscillation on non-uniform grids. + +Two interfaces: + 1. scipy-based: RegularGridInterpolator(method='pchip') for validation/plotting + 2. JAX-compatible: precomputed slopes + Hermite cubic eval, fully differentiable + +The JAX path uses tensor-product evaluation: + - Along cadence: Hermite cubic with scipy-precomputed PCHIP slopes + - Along gas: PCHIP slopes computed on the fly from intermediate values +""" + +import os +from typing import Dict, NamedTuple, Tuple + +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd +from scipy.interpolate import PchipInterpolator, RegularGridInterpolator + +GRID_DIR = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + "results", + "pool_grids", +) + +GRID_DIR_V2 = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + "results", + "pool_grids_v2", +) + + +# ── Grid loading ───────────────────────────────────────────────────────── + + +def load_pool_grid(pool_id_prefix: str, grid_dir: str = GRID_DIR) -> pd.DataFrame: + """Load a single pool's grid CSV.""" + path = os.path.join(grid_dir, f"{pool_id_prefix}_grid.csv") + return pd.read_csv(path) + + +def load_valid_pool_grids(grid_dir: str = GRID_DIR) -> Dict[str, pd.DataFrame]: + """Load all pool grid CSVs that have valid (non-NaN) data.""" + grids = {} + for f in sorted(os.listdir(grid_dir)): + if f.endswith("_grid.csv") and f != "grid_summary.csv": + prefix = f.replace("_grid.csv", "") + df = pd.read_csv(os.path.join(grid_dir, f)) + if df["median_daily_arb_volume"].notna().any(): + grids[prefix] = df + return grids + + +def pivot_grid( + df: pd.DataFrame, value_col: str = "median_daily_arb_volume" +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Pivot grid DataFrame to (log_cadences, gas_costs, values) arrays.""" + pivot = df.pivot(index="cadence", columns="gas_cost", values=value_col) + cadences = pivot.index.values.astype(float) + gas_costs = pivot.columns.values.astype(float) + values = pivot.values.astype(float) + return np.log(cadences), gas_costs, values + + +# ── Scipy interpolation ───────────────────────────────────────────────── + + +def build_scipy_interpolator( + df: pd.DataFrame, value_col: str = "median_daily_arb_volume" +) -> RegularGridInterpolator: + """Build a scipy RegularGridInterpolator with PCHIP method.""" + log_cadences, gas_costs, values = pivot_grid(df, value_col) + return RegularGridInterpolator( + (log_cadences, gas_costs), + values, + method="pchip", + bounds_error=False, + fill_value=None, + ) + + +def query_scipy( + interp: RegularGridInterpolator, cadence: float, gas_cost: float +) -> float: + """Query scipy interpolator at (cadence, gas_cost). Cadence in minutes.""" + log_cad = np.log(np.clip(cadence, 1.0, 60.0)) + return float(interp(np.array([[log_cad, gas_cost]]))[0]) + + +# ── JAX-compatible PCHIP ──────────────────────────────────────────────── + + +class PoolCoeffs(NamedTuple): + """Precomputed coefficients for one pool's 2D PCHIP interpolation.""" + + log_cadences: jnp.ndarray # (n_cad,) + gas_costs: jnp.ndarray # (n_gas,) + values: jnp.ndarray # (n_cad, n_gas) + slopes_cad: jnp.ndarray # (n_cad, n_gas) PCHIP slopes along cadence axis + + +def precompute_pool_coeffs( + df: pd.DataFrame, value_col: str = "median_daily_arb_volume" +) -> PoolCoeffs: + """Precompute PCHIP slopes along cadence axis using scipy. + + These slopes are used by the JAX evaluation function for the first + interpolation axis (cadence). The second axis (gas) is computed on + the fly in JAX to maintain full differentiability. + """ + log_cadences, gas_costs, values = pivot_grid(df, value_col) + + n_gas = values.shape[1] + slopes = np.zeros_like(values) + for j in range(n_gas): + pchip = PchipInterpolator(log_cadences, values[:, j]) + slopes[:, j] = pchip.derivative()(log_cadences) + + return PoolCoeffs( + log_cadences=jnp.array(log_cadences), + gas_costs=jnp.array(gas_costs), + values=jnp.array(values), + slopes_cad=jnp.array(slopes), + ) + + +@jax.jit +def _pchip_slopes(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + """Compute PCHIP slopes via Fritsch-Carlson method. JAX-compatible. + + x: (n,) sorted knot positions + y: (n,) values at knots + Returns: (n,) slopes at knots + """ + h = x[1:] - x[:-1] + delta = (y[1:] - y[:-1]) / h + + # Interior points: weighted harmonic mean of neighboring secants + w1 = 2 * h[1:] + h[:-1] + w2 = h[1:] + 2 * h[:-1] + + sign_agree = (delta[:-1] * delta[1:]) > 0 + + # When sign_agree is False, d_mid is masked to 0. But the harmonic mean + # can produce Inf when deltas have opposite signs and w1==w2 (denominator + # cancels). JAX's where can't mask the NaN gradient of Inf (0*NaN=NaN). + # Fix: replace deltas with 1.0 when sign_agree is False, ensuring hm + # is always finite. The value is irrelevant since it gets masked. + d0 = jnp.where(sign_agree, delta[:-1], 1.0) + d1 = jnp.where(sign_agree, delta[1:], 1.0) + d0 = jnp.where(d0 == 0, 1e-30, d0) + d1 = jnp.where(d1 == 0, 1e-30, d1) + + hm = (w1 + w2) / (w1 / d0 + w2 / d1) + hm = jnp.where(jnp.isfinite(hm), hm, 0.0) + d_mid = jnp.where(sign_agree, hm, 0.0) + + # Endpoints: one-sided shape-preserving + d0 = ((2 * h[0] + h[1]) * delta[0] - h[0] * delta[1]) / (h[0] + h[1]) + d0 = jnp.where(d0 * delta[0] <= 0, 0.0, d0) + d0 = jnp.where( + (delta[0] * delta[1] < 0) & (jnp.abs(d0) > 3 * jnp.abs(delta[0])), + 3 * delta[0], + d0, + ) + + dn = ((2 * h[-1] + h[-2]) * delta[-1] - h[-1] * delta[-2]) / (h[-1] + h[-2]) + dn = jnp.where(dn * delta[-1] <= 0, 0.0, dn) + dn = jnp.where( + (delta[-1] * delta[-2] < 0) & (jnp.abs(dn) > 3 * jnp.abs(delta[-1])), + 3 * delta[-1], + dn, + ) + + return jnp.concatenate([d0[None], d_mid, dn[None]]) + + +@jax.jit +def interpolate_pool( + coeffs: PoolCoeffs, log_cadence: jnp.ndarray, gas_cost: jnp.ndarray +) -> jnp.ndarray: + """Evaluate 2D PCHIP at (log_cadence, gas_cost). JAX-differentiable. + + Tensor-product approach: + 1. Hermite cubic along cadence for all gas columns (precomputed slopes) + 2. PCHIP slopes along gas through intermediate values (computed on the fly) + 3. Hermite cubic along gas to final value + + Args: + coeffs: PoolCoeffs from precompute_pool_coeffs + log_cadence: scalar, log of cadence in minutes + gas_cost: scalar, effective profit threshold in USD + Returns: + V_arb: scalar, interpolated median daily arb volume + """ + log_cads = coeffs.log_cadences + gas = coeffs.gas_costs + vals = coeffs.values + sl_cad = coeffs.slopes_cad + + # Clamp to grid bounds + log_cadence = jnp.clip(log_cadence, log_cads[0], log_cads[-1]) + gas_cost = jnp.clip(gas_cost, gas[0], gas[-1]) + + # ── Step 1: Hermite along cadence for all gas columns ── + idx = jnp.searchsorted(log_cads, log_cadence) - 1 + idx = jnp.clip(idx, 0, log_cads.shape[0] - 2) + + h = log_cads[idx + 1] - log_cads[idx] + t = (log_cadence - log_cads[idx]) / h + t2 = t * t + t3 = t2 * t + + h00 = 2 * t3 - 3 * t2 + 1 + h10 = t3 - 2 * t2 + t + h01 = -2 * t3 + 3 * t2 + h11 = t3 - t2 + + v_at_gas = ( + h00 * vals[idx, :] + + h01 * vals[idx + 1, :] + + h * (h10 * sl_cad[idx, :] + h11 * sl_cad[idx + 1, :]) + ) + + # ── Step 2: PCHIP slopes along gas ── + gas_slopes = _pchip_slopes(gas, v_at_gas) + + # ── Step 3: Hermite along gas ── + jdx = jnp.searchsorted(gas, gas_cost) - 1 + jdx = jnp.clip(jdx, 0, gas.shape[0] - 2) + + hg = gas[jdx + 1] - gas[jdx] + s = (gas_cost - gas[jdx]) / hg + s2 = s * s + s3 = s2 * s + + g00 = 2 * s3 - 3 * s2 + 1 + g10 = s3 - 2 * s2 + s + g01 = -2 * s3 + 3 * s2 + g11 = s3 - s2 + + return ( + g00 * v_at_gas[jdx] + + g01 * v_at_gas[jdx + 1] + + hg * (g10 * gas_slopes[jdx] + g11 * gas_slopes[jdx + 1]) + ) + + +# ── Per-day (v2) grid support ────────────────────────────────────────────── + + +class PoolCoeffsDaily(NamedTuple): + """Precomputed coefficients for per-day 2D PCHIP interpolation. + + Like PoolCoeffs but values/slopes have a day dimension: + values: (n_cad, n_gas, n_days) + slopes_cad: (n_cad, n_gas, n_days) + dates: (n_days,) ordinal dates for alignment with panel + """ + + log_cadences: jnp.ndarray # (n_cad,) + gas_costs: jnp.ndarray # (n_gas,) + values: jnp.ndarray # (n_cad, n_gas, n_days) + slopes_cad: jnp.ndarray # (n_cad, n_gas, n_days) + dates: jnp.ndarray # (n_days,) ordinal dates + + +def load_daily_grid( + pool_id_prefix: str, grid_dir: str = GRID_DIR_V2 +) -> pd.DataFrame: + """Load a pool's per-day grid parquet.""" + path = os.path.join(grid_dir, f"{pool_id_prefix}_daily.parquet") + return pd.read_parquet(path) + + +def precompute_pool_coeffs_daily(df: pd.DataFrame) -> PoolCoeffsDaily: + """Build PoolCoeffsDaily from per-day grid DataFrame. + + Args: + df: DataFrame with columns [cadence, gas_cost, date, daily_arb_volume] + + Returns: + PoolCoeffsDaily with 3D values (n_cad, n_gas, n_days) + """ + df = df.copy() + df["date"] = pd.to_datetime(df["date"]) + + cadences = np.array(sorted(df["cadence"].unique()), dtype=float) + gas_costs = np.array(sorted(df["gas_cost"].unique()), dtype=float) + dates = np.array(sorted(df["date"].unique())) + log_cadences = np.log(cadences) + + n_cad = len(cadences) + n_gas = len(gas_costs) + n_days = len(dates) + + # Build 3D array: cadence x gas x day + values = np.zeros((n_cad, n_gas, n_days)) + cad_idx = {c: i for i, c in enumerate(cadences)} + gas_idx = {g: i for i, g in enumerate(gas_costs)} + date_idx = {d: i for i, d in enumerate(dates)} + + for _, row in df.iterrows(): + ci = cad_idx.get(float(row["cadence"])) + gi = gas_idx.get(float(row["gas_cost"])) + di = date_idx.get(row["date"]) + if ci is not None and gi is not None and di is not None: + values[ci, gi, di] = row["daily_arb_volume"] + + # Compute PCHIP slopes along cadence axis for each (gas, day) + slopes = np.zeros_like(values) + for j in range(n_gas): + for k in range(n_days): + col = values[:, j, k] + if np.all(np.isfinite(col)): + pchip = PchipInterpolator(log_cadences, col) + slopes[:, j, k] = pchip.derivative()(log_cadences) + + # Convert dates to ordinals for JAX + date_ordinals = np.array([ + pd.Timestamp(d).toordinal() for d in dates + ], dtype=np.int32) + + return PoolCoeffsDaily( + log_cadences=jnp.array(log_cadences), + gas_costs=jnp.array(gas_costs), + values=jnp.array(values), + slopes_cad=jnp.array(slopes), + dates=jnp.array(date_ordinals), + ) + + +@jax.jit +def interpolate_pool_daily( + coeffs: PoolCoeffsDaily, + log_cadence: jnp.ndarray, + gas_cost: jnp.ndarray, +) -> jnp.ndarray: + """Evaluate 2D PCHIP at (log_cadence, gas_cost) for all days. + + Same tensor-product approach as interpolate_pool, but values are 3D + (n_cad, n_gas, n_days) so the output is (n_days,). + + The Hermite basis coefficients are scalars that broadcast over the + day dimension of values. + + Args: + coeffs: PoolCoeffsDaily from precompute_pool_coeffs_daily + log_cadence: scalar, log of cadence in minutes + gas_cost: scalar, effective profit threshold in USD + Returns: + V_arb: (n_days,) interpolated daily arb volume + """ + log_cads = coeffs.log_cadences + gas = coeffs.gas_costs + vals = coeffs.values # (n_cad, n_gas, n_days) + sl_cad = coeffs.slopes_cad # (n_cad, n_gas, n_days) + + # Clamp to grid bounds + log_cadence = jnp.clip(log_cadence, log_cads[0], log_cads[-1]) + gas_cost = jnp.clip(gas_cost, gas[0], gas[-1]) + + # ── Step 1: Hermite along cadence for all gas columns, all days ── + idx = jnp.searchsorted(log_cads, log_cadence) - 1 + idx = jnp.clip(idx, 0, log_cads.shape[0] - 2) + + h = log_cads[idx + 1] - log_cads[idx] + t = (log_cadence - log_cads[idx]) / h + t2 = t * t + t3 = t2 * t + + h00 = 2 * t3 - 3 * t2 + 1 + h10 = t3 - 2 * t2 + t + h01 = -2 * t3 + 3 * t2 + h11 = t3 - t2 + + # vals[idx, :, :] is (n_gas, n_days) — scalars broadcast + v_at_gas = ( + h00 * vals[idx, :, :] + + h01 * vals[idx + 1, :, :] + + h * (h10 * sl_cad[idx, :, :] + h11 * sl_cad[idx + 1, :, :]) + ) # (n_gas, n_days) + + # ── Step 2: PCHIP slopes along gas, vmapped over days ── + gas_slopes = jax.vmap( + lambda y_col: _pchip_slopes(gas, y_col), + in_axes=1, out_axes=1, + )(v_at_gas) # (n_gas, n_days) + + # ── Step 3: Hermite along gas ── + jdx = jnp.searchsorted(gas, gas_cost) - 1 + jdx = jnp.clip(jdx, 0, gas.shape[0] - 2) + + hg = gas[jdx + 1] - gas[jdx] + s = (gas_cost - gas[jdx]) / hg + s2 = s * s + s3 = s2 * s + + g00 = 2 * s3 - 3 * s2 + 1 + g10 = s3 - 2 * s2 + s + g01 = -2 * s3 + 3 * s2 + g11 = s3 - s2 + + # v_at_gas[jdx] is (n_days,) — scalars broadcast + return ( + g00 * v_at_gas[jdx] + + g01 * v_at_gas[jdx + 1] + + hg * (g10 * gas_slopes[jdx] + g11 * gas_slopes[jdx + 1]) + ) # (n_days,) + + +# ── Convenience class ──────────────────────────────────────────────────── + + +class PoolGridInterpolator: + """Collection of PCHIP interpolators for all valid pool grids. + + Provides both scipy (for validation) and JAX (for optimization) access. + """ + + def __init__(self, grid_dir: str = GRID_DIR): + self.grid_dir = grid_dir + grids = load_valid_pool_grids(grid_dir) + + self._scipy_interps = {} + self._jax_coeffs = {} + self._pool_ids = sorted(grids.keys()) + + for pid, df in grids.items(): + self._scipy_interps[pid] = build_scipy_interpolator(df) + self._jax_coeffs[pid] = precompute_pool_coeffs(df) + + @property + def pool_ids(self): + return list(self._pool_ids) + + @property + def n_pools(self): + return len(self._pool_ids) + + def query_scipy(self, pool_id: str, cadence: float, gas_cost: float) -> float: + """Query scipy PCHIP at (cadence_minutes, gas_cost_usd).""" + return query_scipy(self._scipy_interps[pool_id], cadence, gas_cost) + + def query_jax(self, pool_id: str, log_cadence, gas_cost): + """Query JAX PCHIP at (log_cadence, gas_cost). Differentiable.""" + return interpolate_pool(self._jax_coeffs[pool_id], log_cadence, gas_cost) + + def get_coeffs(self, pool_id: str) -> PoolCoeffs: + """Get precomputed JAX coefficients for a single pool.""" + return self._jax_coeffs[pool_id] + + def get_scipy(self, pool_id: str) -> RegularGridInterpolator: + """Get scipy interpolator for a single pool.""" + return self._scipy_interps[pool_id] diff --git a/quantammsim/calibration/joint_fit.py b/quantammsim/calibration/joint_fit.py new file mode 100644 index 0000000..4f891c6 --- /dev/null +++ b/quantammsim/calibration/joint_fit.py @@ -0,0 +1,479 @@ +"""Joint end-to-end optimization (Option A) for the direct calibration pipeline. + +A parametric f_params maps pool_attributes → (cadence, gas, noise_coeffs), +optimized simultaneously across all pools through the grid interpolation loss. + +Two noise modes: + - "per_pool_noise": each pool has independent noise_coeffs (most flexible) + - "shared_noise": noise_coeffs = bias_noise + x_attr @ W_noise (generalizes) + +The cadence/gas mapping is always shared: + log_cadence = bias_cad + x_attr @ W_cad + log_gas = bias_gas + x_attr @ W_gas +""" + +from typing import Dict, List, NamedTuple, Optional + +import jax +import jax.numpy as jnp +import numpy as np +import scipy.optimize + +from quantammsim.calibration.grid_interpolation import ( + PoolCoeffsDaily, + interpolate_pool_daily, +) +from quantammsim.calibration.loss import K_OBS +from quantammsim.calibration.pool_data import build_pool_attributes, build_x_obs + + +class JointData(NamedTuple): + """Batched data for joint optimization.""" + pool_data: list # list of dicts with coeffs, x_obs, y_obs, day_indices + x_attr: jnp.ndarray # (n_pools, K_attr) pool attributes (no intercept) + pool_ids: list # list of pool_id prefixes + attr_names: list # attribute column names + + +def prepare_joint_data( + matched: Dict[str, dict], + drop_chain_dummies: bool = False, +) -> JointData: + """Build batched JAX arrays from matched pool data. + + Args: + matched: dict from match_grids_to_panel + drop_chain_dummies: if True, remove chain_* columns from attributes + (reduces feature count for small n) + + Returns: + JointData with per-pool JAX arrays and shared attribute matrix. + """ + X_attr, attr_names, pool_ids = build_pool_attributes(matched) + + if drop_chain_dummies: + keep = [i for i, name in enumerate(attr_names) + if not name.startswith("chain_")] + X_attr = X_attr[:, keep] + attr_names = [attr_names[i] for i in keep] + + pool_data = [] + for pid in pool_ids: + entry = matched[pid] + panel = entry["panel"] + x_obs = build_x_obs(panel) + y_obs = panel["log_volume"].values.astype(float) + + pool_data.append({ + "coeffs": entry["coeffs"], + "x_obs": jnp.array(x_obs), + "y_obs": jnp.array(y_obs), + "day_indices": jnp.array(entry["day_indices"]), + }) + + return JointData( + pool_data=pool_data, + x_attr=jnp.array(X_attr), + pool_ids=pool_ids, + attr_names=attr_names, + ) + + +def pack_joint_params( + bias_cad: float, + bias_gas: float, + W_cad: jnp.ndarray, + W_gas: jnp.ndarray, + noise_params: jnp.ndarray, +) -> jnp.ndarray: + """Pack joint params into flat array. + + Layout: [bias_cad, bias_gas, W_cad(k_attr), W_gas(k_attr), noise_params...] + + noise_params is either: + - (n_pools, K_OBS) for per_pool_noise mode + - (1 + K_attr, K_OBS) for shared_noise mode (row 0 = noise bias) + """ + return jnp.concatenate([ + jnp.array([bias_cad, bias_gas]), + W_cad.ravel(), + W_gas.ravel(), + noise_params.ravel(), + ]) + + +def unpack_joint_params( + flat: jnp.ndarray, config: dict +) -> dict: + """Unpack flat array to structured params. + + config must have: k_attr, n_pools, mode + """ + k_attr = config["k_attr"] + mode = config["mode"] + + bias_cad = flat[0] + bias_gas = flat[1] + W_cad = flat[2:2 + k_attr] + W_gas = flat[2 + k_attr:2 + 2 * k_attr] + rest = flat[2 + 2 * k_attr:] + + if mode == "per_pool_noise": + n_pools = config["n_pools"] + noise_coeffs = rest.reshape(n_pools, K_OBS) + return { + "bias_cad": bias_cad, "bias_gas": bias_gas, + "W_cad": W_cad, "W_gas": W_gas, + "noise_coeffs": noise_coeffs, + } + else: # shared_noise + # noise_params: (1 + k_attr, K_OBS) — row 0 is bias + W_noise_full = rest.reshape(1 + k_attr, K_OBS) + return { + "bias_cad": bias_cad, "bias_gas": bias_gas, + "W_cad": W_cad, "W_gas": W_gas, + "bias_noise": W_noise_full[0], # (K_OBS,) + "W_noise": W_noise_full[1:], # (k_attr, K_OBS) + } + + +def _make_pool_loss_fn( + pool_idx: int, + pool_data_i: dict, + x_attr_i: jnp.ndarray, + config: dict, +): + """Create a JIT'd loss function for a single pool. + + Closes over pool-specific data; takes only params_flat as input. + Each pool gets its own small JIT'd computation graph. + """ + coeffs = pool_data_i["coeffs"] + x_obs = pool_data_i["x_obs"] + y_obs = pool_data_i["y_obs"] + day_indices = pool_data_i["day_indices"] + mode = config["mode"] + i = pool_idx + + @jax.jit + def pool_loss_fn(params_flat): + params = unpack_joint_params(params_flat, config) + log_cad = params["bias_cad"] + jnp.dot(x_attr_i, params["W_cad"]) + log_gas = params["bias_gas"] + jnp.dot(x_attr_i, params["W_gas"]) + + if mode == "per_pool_noise": + noise_c = params["noise_coeffs"][i] + else: + noise_c = params["bias_noise"] + jnp.dot(x_attr_i, params["W_noise"]) + + v_arb_all = interpolate_pool_daily(coeffs, log_cad, jnp.exp(log_gas)) + v_arb = v_arb_all[day_indices] + v_noise = jnp.exp(x_obs @ noise_c) + log_v_pred = jnp.log(jnp.maximum(v_arb + v_noise, 1e-6)) + return jnp.mean((log_v_pred - y_obs) ** 2) + + return pool_loss_fn + + +def make_joint_loss_fn( + jdata: JointData, + mode: str = "per_pool_noise", + alpha_cad: float = 0.01, + alpha_gas: float = 0.01, +): + """Create per-pool JIT'd loss functions and a Python-level aggregator. + + Each pool gets its own small JIT'd computation graph (compiled + independently), avoiding a massive unrolled trace. The outer + function sums per-pool losses in Python and adds regularization. + + Loss averages over pools (not observations), giving equal weight + to each pool regardless of observation count. + + L2 regularization is applied to W_cad and W_gas only (not biases). + + Args: + jdata: JointData from prepare_joint_data + mode: "per_pool_noise" or "shared_noise" + alpha_cad: L2 regularization on W_cad + alpha_gas: L2 regularization on W_gas + + Returns: + loss_fn(params_flat) -> scalar loss + """ + n_pools = len(jdata.pool_data) + k_attr = jdata.x_attr.shape[1] + config = {"k_attr": k_attr, "n_pools": n_pools, "mode": mode} + + # Build per-pool JIT'd loss functions + pool_loss_fns = [] + pool_val_and_grad_fns = [] + for i in range(n_pools): + fn = _make_pool_loss_fn(i, jdata.pool_data[i], jdata.x_attr[i], config) + pool_loss_fns.append(fn) + pool_val_and_grad_fns.append(jax.value_and_grad(fn)) + + def loss_fn(params_flat): + total = sum(fn(params_flat) for fn in pool_loss_fns) + data_loss = total / n_pools + + params = unpack_joint_params(params_flat, config) + reg = alpha_cad * jnp.sum(params["W_cad"] ** 2) + \ + alpha_gas * jnp.sum(params["W_gas"] ** 2) + return data_loss + reg + + # Attach per-pool functions for the value_and_grad wrapper + loss_fn._pool_val_and_grad_fns = pool_val_and_grad_fns + loss_fn._n_pools = n_pools + loss_fn._config = config + loss_fn._alpha_cad = alpha_cad + loss_fn._alpha_gas = alpha_gas + + return loss_fn + + +def make_initial_joint_params( + jdata: JointData, + mode: str = "per_pool_noise", + init_from_option_c: Optional[Dict[str, dict]] = None, +) -> jnp.ndarray: + """Create initial parameter vector. + + If init_from_option_c is provided, warm-start from Option C per-pool fits: + - bias_cad, W_cad from OLS on per-pool fitted log_cadence + - bias_gas, W_gas from OLS on per-pool fitted log_gas + - noise_coeffs from per-pool fits + + Otherwise, use defaults: cadence=12min, gas=$1 for all pools. + """ + n_pools = len(jdata.pool_data) + k_attr = jdata.x_attr.shape[1] + x_attr_np = np.array(jdata.x_attr) + + if init_from_option_c is not None: + pool_ids = jdata.pool_ids + # Filter out pools with NaN losses from warm start + valid = {p: init_from_option_c[p] for p in pool_ids + if p in init_from_option_c + and np.isfinite(init_from_option_c[p].get("loss", float("nan")))} + + if len(valid) < len(pool_ids): + default_lc = np.log(12.0) + default_lg = np.log(1.0) + for p in pool_ids: + if p not in valid: + valid[p] = { + "log_cadence": default_lc, + "log_gas": default_lg, + "noise_coeffs": np.zeros(K_OBS), + } + + log_cads = np.array([valid[p]["log_cadence"] for p in pool_ids]) + log_gases = np.array([valid[p]["log_gas"] for p in pool_ids]) + noise_all = np.array([valid[p]["noise_coeffs"] for p in pool_ids]) + + # OLS with intercept: X_aug = [1, x_attr]; solve for [bias, W] + X_aug = np.column_stack([np.ones(n_pools), x_attr_np]) + cad_params, _, _, _ = np.linalg.lstsq(X_aug, log_cads, rcond=None) + gas_params, _, _, _ = np.linalg.lstsq(X_aug, log_gases, rcond=None) + bias_cad, W_cad = cad_params[0], cad_params[1:] + bias_gas, W_gas = gas_params[0], gas_params[1:] + + if mode == "per_pool_noise": + noise_params = noise_all + else: + # OLS with intercept for noise mapping + noise_aug, _, _, _ = np.linalg.lstsq(X_aug, noise_all, rcond=None) + # noise_aug: (1+k_attr, K_OBS) — row 0 is bias + noise_params = noise_aug + else: + # Default: all pools get cadence=12min, gas=$1 + bias_cad = np.log(12.0) + bias_gas = np.log(1.0) # = 0.0 + W_cad = np.zeros(k_attr) + W_gas = np.zeros(k_attr) + + if mode == "per_pool_noise": + # Initialize noise via OLS per pool + noise_params = np.zeros((n_pools, K_OBS)) + for i, pd in enumerate(jdata.pool_data): + x_obs_np = np.array(pd["x_obs"]) + y_obs_np = np.array(pd["y_obs"]) + c, _, _, _ = np.linalg.lstsq(x_obs_np, y_obs_np, rcond=None) + noise_params[i] = c + else: + # Initialize shared noise from pooled OLS + all_x = np.vstack([np.array(pd["x_obs"]) for pd in jdata.pool_data]) + all_y = np.concatenate([np.array(pd["y_obs"]) for pd in jdata.pool_data]) + c, _, _, _ = np.linalg.lstsq(all_x, all_y, rcond=None) + # (1+k_attr, K_OBS): bias row + zero weight rows + noise_params = np.zeros((1 + k_attr, K_OBS)) + noise_params[0, :] = c + + return pack_joint_params( + float(bias_cad), + float(bias_gas), + jnp.array(W_cad), + jnp.array(W_gas), + jnp.array(noise_params), + ) + + +def _make_bounds(k_attr, n_pools, mode): + """Build scipy bounds for joint params.""" + # bias_cad, bias_gas: unbounded + bounds = [(None, None)] * 2 + # W_cad, W_gas: unbounded + bounds += [(None, None)] * (2 * k_attr) + + if mode == "per_pool_noise": + bounds += [(None, None)] * (n_pools * K_OBS) + else: + bounds += [(None, None)] * ((1 + k_attr) * K_OBS) + + return bounds + + +def fit_joint( + matched: Dict[str, dict], + mode: str = "per_pool_noise", + init_from_option_c: Optional[Dict[str, dict]] = None, + maxiter: int = 500, + alpha_cad: float = 0.01, + alpha_gas: float = 0.01, + drop_chain_dummies: bool = False, +) -> dict: + """Joint end-to-end optimization across all pools. + + Args: + matched: dict from match_grids_to_panel + mode: "per_pool_noise" or "shared_noise" + init_from_option_c: Optional Option C results for warm start. + Pools with NaN losses are silently excluded from warm start. + maxiter: max L-BFGS-B iterations + alpha_cad: L2 regularization on W_cad (not bias) + alpha_gas: L2 regularization on W_gas (not bias) + drop_chain_dummies: if True, remove chain_* columns from attributes + + Returns dict with fitted params and diagnostics. + """ + jdata = prepare_joint_data(matched, drop_chain_dummies=drop_chain_dummies) + loss_fn = make_joint_loss_fn(jdata, mode=mode, + alpha_cad=alpha_cad, alpha_gas=alpha_gas) + init = make_initial_joint_params(jdata, mode=mode, + init_from_option_c=init_from_option_c) + + n_pools = len(jdata.pool_data) + k_attr = jdata.x_attr.shape[1] + config = {"k_attr": k_attr, "n_pools": n_pools, "mode": mode} + bounds = _make_bounds(k_attr, n_pools, mode) + + # Per-pool value_and_grad — each pool has its own small JIT graph + pool_vg_fns = loss_fn._pool_val_and_grad_fns + + # Indices for W_cad and W_gas in the flat param vector (for reg gradient) + w_cad_start = 2 + w_cad_end = 2 + k_attr + w_gas_start = 2 + k_attr + w_gas_end = 2 + 2 * k_attr + + def scipy_wrapper(params_np): + params_j = jnp.array(params_np) + + # Sum per-pool losses and gradients + total_val = 0.0 + total_grad = jnp.zeros_like(params_j) + for vg_fn in pool_vg_fns: + v, g = vg_fn(params_j) + total_val += float(v) + total_grad = total_grad + g + + data_loss = total_val / n_pools + data_grad = total_grad / n_pools + + # Regularization on W_cad and W_gas (not biases) + reg = (alpha_cad * float(jnp.sum(params_j[w_cad_start:w_cad_end] ** 2)) + + alpha_gas * float(jnp.sum(params_j[w_gas_start:w_gas_end] ** 2))) + + reg_grad = jnp.zeros_like(params_j) + reg_grad = reg_grad.at[w_cad_start:w_cad_end].set( + 2 * alpha_cad * params_j[w_cad_start:w_cad_end]) + reg_grad = reg_grad.at[w_gas_start:w_gas_end].set( + 2 * alpha_gas * params_j[w_gas_start:w_gas_end]) + + val = data_loss + reg + grad = data_grad + reg_grad + return val, np.array(grad, dtype=np.float64) + + init_np = np.array(init, dtype=np.float64) + init_loss = float(loss_fn(jnp.array(init_np))) + + result = scipy.optimize.minimize( + scipy_wrapper, + init_np, + method="L-BFGS-B", + jac=True, + bounds=bounds, + options={"maxiter": maxiter, "ftol": 1e-10, "gtol": 1e-8}, + ) + + params = unpack_joint_params(jnp.array(result.x), config) + + out = { + "init_loss": init_loss, + "bias_cad": float(params["bias_cad"]), + "bias_gas": float(params["bias_gas"]), + "W_cad": np.array(params["W_cad"]), + "W_gas": np.array(params["W_gas"]), + "loss": float(result.fun), + "converged": result.success, + "mode": mode, + "k_attr": k_attr, + "pool_ids": jdata.pool_ids, + "attr_names": jdata.attr_names, + } + + if mode == "per_pool_noise": + out["noise_coeffs"] = np.array(params["noise_coeffs"]) + else: + out["bias_noise"] = np.array(params["bias_noise"]) + out["W_noise"] = np.array(params["W_noise"]) + + return out + + +def predict_new_pool_joint( + result: dict, + x_attr: np.ndarray, +) -> dict: + """Predict simulator settings for a new pool using joint-fitted mapping. + + In per_pool_noise mode, only cadence and gas are predicted (noise + coefficients are per-pool and can't generalize). Use shared_noise + mode for full deployment predictions including noise_coeffs. + + Args: + result: dict from fit_joint + x_attr: (K_attr,) pool attribute vector — must match the k_attr + from training (check result['attr_names'] for the feature order). + No intercept — just the real features. + + Returns dict with cadence_minutes, gas_usd, and noise_coeffs (shared_noise only). + """ + x = np.asarray(x_attr) + log_cadence = result["bias_cad"] + float(x @ result["W_cad"]) + log_gas = result["bias_gas"] + float(x @ result["W_gas"]) + + out = { + "log_cadence": log_cadence, + "log_gas": log_gas, + "cadence_minutes": float(np.exp(log_cadence)), + "gas_usd": float(np.exp(log_gas)), + } + + if result["mode"] == "shared_noise": + out["noise_coeffs"] = np.array( + result["bias_noise"] + x @ result["W_noise"] + ) + + return out diff --git a/quantammsim/calibration/learned_mapping.py b/quantammsim/calibration/learned_mapping.py new file mode 100644 index 0000000..294d008 --- /dev/null +++ b/quantammsim/calibration/learned_mapping.py @@ -0,0 +1,126 @@ +"""Learned mapping: pool attributes -> (cadence, gas, noise_coeffs). + +Ridge regression from pool-level features to per-pool fitted parameters. +Trained on per-pool fit results from per_pool_fit.py; used to predict +simulator settings for new/hypothetical pools. +""" + +from typing import Dict, List + +import numpy as np + +from quantammsim.calibration.loss import K_OBS + + +def build_targets( + fit_results: Dict[str, dict], + pool_order: List[str], +) -> np.ndarray: + """Stack per-pool fitted params into (n_pools, 2+K_OBS) target matrix. + + Columns: [log_cadence, log_gas, noise_coeffs...] + Row ordering matches pool_order. + """ + n_pools = len(pool_order) + Y = np.zeros((n_pools, 2 + K_OBS)) + for i, pid in enumerate(pool_order): + r = fit_results[pid] + Y[i, 0] = r["log_cadence"] + Y[i, 1] = r["log_gas"] + Y[i, 2:] = r["noise_coeffs"] + return Y + + +def fit_mapping( + X_attr: np.ndarray, + Y_target: np.ndarray, + alpha: float = 1.0, +) -> dict: + """Fit Ridge regression: X_attr -> Y_target. + + Multi-output Ridge: one shared regularization strength across all + target columns. + + Returns dict with weights, intercept, and diagnostics. + """ + # Ridge with intercept: center Y, solve on centered data + n, k = X_attr.shape + Y_mean = Y_target.mean(axis=0) + X_mean = X_attr.mean(axis=0) + Xc = X_attr - X_mean + Yc = Y_target - Y_mean + + # W = (Xc^T Xc + alpha * I)^{-1} Xc^T Yc + A = Xc.T @ Xc + alpha * np.eye(k) + W = np.linalg.solve(A, Xc.T @ Yc) # (K_attr, K_target) + intercept = Y_mean - X_mean @ W # (K_target,) + + Y_pred = X_attr @ W + intercept + ss_res = np.sum((Y_target - Y_pred) ** 2) + ss_tot = np.sum((Y_target - Y_target.mean(axis=0)) ** 2) + r2 = 1 - ss_res / max(ss_tot, 1e-10) + + return { + "weights": W, # (K_attr, K_target) + "intercept": intercept, # (K_target,) + "alpha": alpha, + "r2_train": float(r2), + } + + +def predict_pool( + mapping: dict, + x_attr: np.ndarray, +) -> dict: + """Predict simulator settings for a new pool. + + Args: + mapping: dict from fit_mapping + x_attr: (K_attr,) single pool attribute vector + + Returns dict with cadence_minutes, gas_usd, noise_coeffs, etc. + """ + y = x_attr @ mapping["weights"] + mapping["intercept"] + + log_cadence = float(y[0]) + log_gas = float(y[1]) + noise_coeffs = np.array(y[2:]) + + return { + "log_cadence": log_cadence, + "log_gas": log_gas, + "cadence_minutes": float(np.exp(log_cadence)), + "gas_usd": float(np.exp(log_gas)), + "noise_coeffs": noise_coeffs, + } + + +def cross_validate_loo( + X_attr: np.ndarray, + Y_target: np.ndarray, + alpha: float = 1.0, +) -> dict: + """Leave-one-out cross-validation. + + Returns per-pool prediction errors and summary statistics. + """ + n = X_attr.shape[0] + errors = np.zeros((n, Y_target.shape[1])) + + for i in range(n): + mask = np.ones(n, dtype=bool) + mask[i] = False + X_train = X_attr[mask] + Y_train = Y_target[mask] + + model = fit_mapping(X_train, Y_train, alpha=alpha) + y_pred = X_attr[i] @ model["weights"] + model["intercept"] + errors[i] = Y_target[i] - y_pred + + mse = np.mean(errors ** 2, axis=0) + return { + "per_pool_errors": errors, + "mse_per_target": mse, + "rmse_per_target": np.sqrt(mse), + "mean_rmse": float(np.mean(np.sqrt(mse))), + } diff --git a/quantammsim/calibration/loss.py b/quantammsim/calibration/loss.py new file mode 100644 index 0000000..13b2e65 --- /dev/null +++ b/quantammsim/calibration/loss.py @@ -0,0 +1,74 @@ +"""Per-pool loss function for the direct calibration pipeline. + +JAX-differentiable loss: sum((log(V_arb_i + V_noise_i) - log(V_obs_i))^2) +where V_arb comes from per-day grid interpolation and V_noise from +log-linear regression on observation covariates. +""" + +from typing import Tuple + +import jax.numpy as jnp + +from quantammsim.calibration.grid_interpolation import ( + PoolCoeffsDaily, + interpolate_pool_daily, +) + +K_OBS = 8 # observation-level covariates + + +def noise_volume( + noise_coeffs: jnp.ndarray, x_obs: jnp.ndarray +) -> jnp.ndarray: + """V_noise = exp(x_obs @ noise_coeffs). Shape: (n_obs,).""" + return jnp.exp(x_obs @ noise_coeffs) + + +def pack_params( + log_cadence: float, log_gas: float, noise_coeffs: jnp.ndarray +) -> jnp.ndarray: + """Pack into flat array: [log_cadence, log_gas, noise_coeffs...].""" + return jnp.concatenate([ + jnp.array([log_cadence, log_gas]), + jnp.asarray(noise_coeffs), + ]) + + +def unpack_params( + flat: jnp.ndarray, +) -> Tuple[float, float, jnp.ndarray]: + """Unpack flat array to (log_cadence, log_gas, noise_coeffs).""" + return flat[0], flat[1], flat[2:] + + +def pool_loss( + params_flat: jnp.ndarray, + coeffs: PoolCoeffsDaily, + x_obs: jnp.ndarray, + y_obs: jnp.ndarray, + day_indices: jnp.ndarray, +) -> jnp.ndarray: + """Per-pool log-space L2 loss with per-day V_arb. + + Args: + params_flat: [log_cadence, log_gas, noise_coeffs...] from pack_params + coeffs: PoolCoeffsDaily with per-day grid values + x_obs: (n_obs, K_OBS) observation covariates + y_obs: (n_obs,) log(V_obs) — observed log volume + day_indices: (n_obs,) int indices mapping panel rows to grid days + + Returns: + Scalar mean squared error in log space. + """ + log_cadence, log_gas, noise_coeffs = unpack_params(params_flat) + + # Per-day V_arb from grid interpolation + v_arb_all = interpolate_pool_daily(coeffs, log_cadence, jnp.exp(log_gas)) # (n_days,) + v_arb = v_arb_all[day_indices] # (n_obs,) + + # Per-day V_noise from covariates + v_noise = noise_volume(noise_coeffs, x_obs) # (n_obs,) + + # Log-space L2 loss + log_v_pred = jnp.log(jnp.maximum(v_arb + v_noise, 1e-6)) + return jnp.mean((log_v_pred - y_obs) ** 2) diff --git a/quantammsim/calibration/per_pool_fit.py b/quantammsim/calibration/per_pool_fit.py new file mode 100644 index 0000000..037a124 --- /dev/null +++ b/quantammsim/calibration/per_pool_fit.py @@ -0,0 +1,127 @@ +"""Per-pool fitting via L-BFGS-B for the direct calibration pipeline. + +Fits (log_cadence, log_gas, noise_coeffs) per pool by minimizing +the log-space L2 loss using scipy.optimize.minimize with JAX gradients. +""" + +from typing import Dict, Optional + +import jax +import jax.numpy as jnp +import numpy as np +import scipy.optimize + +from quantammsim.calibration.grid_interpolation import PoolCoeffsDaily +from quantammsim.calibration.loss import K_OBS, pack_params, pool_loss +from quantammsim.calibration.pool_data import build_x_obs + + +def make_initial_guess(x_obs: np.ndarray, y_obs: np.ndarray) -> np.ndarray: + """Initial params: cadence=12min, gas=$1, noise_coeffs from OLS. + + OLS: noise_coeffs = lstsq(x_obs, y_obs) — assumes all volume is noise. + This overestimates noise but gives a reasonable starting point. + """ + noise_coeffs, _, _, _ = np.linalg.lstsq(x_obs, y_obs, rcond=None) + init = np.zeros(2 + K_OBS) + init[0] = np.log(12.0) # log_cadence + init[1] = np.log(1.0) # log_gas (= 0.0) + init[2:] = noise_coeffs + return init + + +def fit_single_pool( + coeffs: PoolCoeffsDaily, + x_obs: np.ndarray, + y_obs: np.ndarray, + day_indices: np.ndarray, + init: Optional[np.ndarray] = None, + bounds: Optional[dict] = None, +) -> dict: + """Fit (log_cadence, log_gas, noise_coeffs) for one pool via L-BFGS-B. + + Returns dict with fitted params, loss, and convergence status. + """ + if init is None: + init = make_initial_guess(x_obs, y_obs) + + # Default bounds + if bounds is None: + bounds = {} + log_cad_bounds = bounds.get("log_cadence", (np.log(1.0), np.log(60.0))) + log_gas_bounds = bounds.get("log_gas", (np.log(0.001), np.log(50.0))) + noise_bounds = bounds.get("noise_coeffs", (-20.0, 20.0)) + + scipy_bounds = [ + log_cad_bounds, + log_gas_bounds, + ] + [(noise_bounds[0], noise_bounds[1])] * K_OBS + + # Convert to JAX arrays + x_obs_j = jnp.array(x_obs) + y_obs_j = jnp.array(y_obs) + day_idx_j = jnp.array(day_indices) + + # Value and gradient function + @jax.jit + def loss_and_grad(params_flat): + loss = pool_loss(params_flat, coeffs, x_obs_j, y_obs_j, day_idx_j) + grad = jax.grad(pool_loss, argnums=0)( + params_flat, coeffs, x_obs_j, y_obs_j, day_idx_j + ) + return loss, grad + + def scipy_wrapper(params_np): + params_j = jnp.array(params_np) + loss, grad = loss_and_grad(params_j) + return float(loss), np.array(grad, dtype=np.float64) + + result = scipy.optimize.minimize( + scipy_wrapper, + init, + method="L-BFGS-B", + jac=True, + bounds=scipy_bounds, + options={"maxiter": 500, "ftol": 1e-10, "gtol": 1e-8}, + ) + + log_cadence = float(result.x[0]) + log_gas = float(result.x[1]) + noise_coeffs = np.array(result.x[2:]) + + return { + "log_cadence": log_cadence, + "log_gas": log_gas, + "noise_coeffs": noise_coeffs, + "loss": float(result.fun), + "converged": result.success, + "cadence_minutes": float(np.exp(log_cadence)), + "gas_usd": float(np.exp(log_gas)), + } + + +def fit_all_pools( + matched: Dict[str, dict], + n_workers: int = 1, +) -> Dict[str, dict]: + """Fit all matched pools. Returns prefix -> fit_result with metadata.""" + results = {} + + for prefix, entry in matched.items(): + panel = entry["panel"] + coeffs = entry["coeffs"] + day_indices = entry["day_indices"] + + x_obs = build_x_obs(panel) + y_obs = panel["log_volume"].values.astype(float) + + result = fit_single_pool(coeffs, x_obs, y_obs, day_indices) + + # Add metadata + result["chain"] = entry["chain"] + result["fee"] = entry["fee"] + result["tokens"] = entry["tokens"] + + results[prefix] = result + + return results diff --git a/quantammsim/calibration/pool_data.py b/quantammsim/calibration/pool_data.py new file mode 100644 index 0000000..039880b --- /dev/null +++ b/quantammsim/calibration/pool_data.py @@ -0,0 +1,302 @@ +"""Data assembly for the direct calibration pipeline. + +Matches precomputed per-day arb grids to panel observations and builds +model-ready arrays for the loss function. +""" + +import json +import os +from typing import Dict, List, Tuple + +import numpy as np +import pandas as pd + +from quantammsim.calibration.grid_interpolation import ( + PoolCoeffsDaily, + load_daily_grid, + precompute_pool_coeffs_daily, +) + +K_OBS = 8 # observation-level covariates + +# Default path for cached token market caps +_MCAP_PATH = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + "local_data", "noise_calibration", "token_mcaps.json", +) + +# Asset type classification (fallback if not in mcap JSON) +_STABLECOINS = { + "USDC", "USDT", "DAI", "WXDAI", "xDAI", "GHO", "LUSD", "crvUSD", + "FRAX", "sDAI", "scUSD", "DOLA", + "waBasUSDC", "waEthUSDC", +} +_NATIVE_LST = { + "WETH", "ETH", "wstETH", "stETH", "rETH", "cbETH", + "WBTC", "BTC", "cbBTC", + "WMATIC", "MATIC", "POL", "wPOL", + "WAVAX", "AVAX", + "GNO", "S", "wS", "stS", + "JitoSOL", + "waEthLidoWETH", "waEthLidowstETH", + "waBasWETH", "waGnoGNO", "waGnowstETH", +} + + +def _load_token_mcaps(path: str = None) -> dict: + """Load cached token market caps. Returns {} if file missing.""" + path = path or _MCAP_PATH + if os.path.exists(path): + with open(path) as f: + return json.load(f) + return {} + + +def _get_asset_type(symbol: str, mcaps: dict) -> int: + """Return asset type: 0=stable, 1=native/LST, 2=volatile.""" + if symbol in mcaps and "asset_type" in mcaps[symbol]: + t = mcaps[symbol]["asset_type"] + return {"stable": 0, "native_lst": 1, "volatile": 2}.get(t, 2) + if symbol in _STABLECOINS: + return 0 + if symbol in _NATIVE_LST: + return 1 + return 2 + + +def _parse_tokens(tokens_str: str) -> List[str]: + """Parse comma-separated token string into list.""" + if isinstance(tokens_str, (list, tuple)): + return list(tokens_str) + return [t.strip() for t in tokens_str.split(",")] + + +def match_grids_to_panel( + grid_dir: str, panel: pd.DataFrame, pools_path: str = None, +) -> Dict[str, dict]: + """Match grid parquets to panel rows by pool_id prefix. + + For each _daily.parquet in grid_dir, find the panel pool whose + pool_id starts with the same 16-char prefix. Build PoolCoeffsDaily + and compute day_indices mapping panel dates to grid date indices. + + Args: + grid_dir: directory containing {prefix}_daily.parquet files + panel: panel DataFrame with pool observations + pools_path: path to pools.parquet for weight metadata. + Defaults to local_data/noise_calibration/pools.parquet. + + Returns dict: prefix -> { + 'panel': DataFrame (obs for this pool), + 'coeffs': PoolCoeffsDaily (per-day), + 'day_indices': np.ndarray (panel date -> grid day index), + 'pool_id': full pool_id from panel, + 'chain': str, 'fee': float, 'tokens': str, + 'weights': list of float (pool weights, e.g. [0.5, 0.5]), + } + """ + # Discover grid files + grid_prefixes = [] + for f in sorted(os.listdir(grid_dir)): + if f.endswith("_daily.parquet"): + prefix = f.replace("_daily.parquet", "") + grid_prefixes.append(prefix) + + if not grid_prefixes: + return {} + + # Load pool metadata for weights + if pools_path is None: + pools_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + "local_data", "noise_calibration", "pools.parquet", + ) + pools_meta = {} + if os.path.exists(pools_path): + pools_df = pd.read_parquet(pools_path) + pools_df["_prefix"] = pools_df["pool_id"].str[:16] + for _, row in pools_df.iterrows(): + w = row.get("weights") + if w is not None: + try: + weights = [float(x) for x in w] + except (TypeError, ValueError): + weights = [0.5, 0.5] + else: + weights = [0.5, 0.5] + pools_meta[row["_prefix"]] = {"weights": weights} + + # Ensure date column + panel = panel.copy() + if "date" in panel.columns: + panel["date"] = pd.to_datetime(panel["date"]) + + # Build prefix -> panel rows mapping + panel["_prefix"] = panel["pool_id"].str[:16] + + matched = {} + for prefix in grid_prefixes: + pool_rows = panel[panel["_prefix"] == prefix] + if len(pool_rows) == 0: + continue + + # Load and precompute grid + grid_df = load_daily_grid(prefix, grid_dir) + coeffs = precompute_pool_coeffs_daily(grid_df) + + # Build date alignment: panel date ordinals -> grid day indices + grid_ordinals = np.array(coeffs.dates) + grid_ord_to_idx = {int(o): i for i, o in enumerate(grid_ordinals)} + + panel_dates = pd.to_datetime(pool_rows["date"]) + panel_ordinals = np.array([d.toordinal() for d in panel_dates]) + + # Filter to dates present in both panel and grid + valid_mask = np.array([int(o) in grid_ord_to_idx for o in panel_ordinals]) + pool_rows = pool_rows[valid_mask].copy() + panel_ordinals = panel_ordinals[valid_mask] + + if len(pool_rows) == 0: + continue + + day_indices = np.array([grid_ord_to_idx[int(o)] for o in panel_ordinals]) + + row0 = pool_rows.iloc[0] + weights = pools_meta.get(prefix, {}).get("weights", [0.5, 0.5]) + + matched[prefix] = { + "panel": pool_rows.reset_index(drop=True), + "coeffs": coeffs, + "day_indices": day_indices, + "pool_id": row0["pool_id"], + "chain": row0["chain"], + "fee": float(np.exp(row0["log_fee"])) if "swap_fee" not in pool_rows.columns + else float(row0.get("swap_fee", np.exp(row0["log_fee"]))), + "tokens": row0["tokens"], + "weights": weights, + } + + return matched + + +def build_x_obs(panel_rows: pd.DataFrame) -> np.ndarray: + """Build (n_obs, 8) observation covariate matrix from panel rows. + + Columns: [1, log_tvl_lag1, log_sigma, tvl*sigma, tvl*fee, + sigma*fee, dow_sin, dow_cos] + + Where: + log_sigma = log(max(volatility, 1e-6)) + tvl = log_tvl_lag1 + fee = log_fee + dow_sin = sin(2*pi*weekday/7), dow_cos = cos(2*pi*weekday/7) + weekday: Monday=0, ..., Sunday=6 + """ + n = len(panel_rows) + x = np.zeros((n, K_OBS)) + + tvl = panel_rows["log_tvl_lag1"].values.astype(float) + sigma = np.log(np.maximum(panel_rows["volatility"].values.astype(float), 1e-6)) + fee = panel_rows["log_fee"].values.astype(float) + weekdays = pd.to_datetime(panel_rows["date"]).dt.weekday.values.astype(float) + + x[:, 0] = 1.0 # intercept + x[:, 1] = tvl # log_tvl_lag1 + x[:, 2] = sigma # log_sigma + x[:, 3] = tvl * sigma # tvl × sigma + x[:, 4] = tvl * fee # tvl × fee + x[:, 5] = sigma * fee # sigma × fee + x[:, 6] = np.sin(2 * np.pi * weekdays / 7) # dow_sin + x[:, 7] = np.cos(2 * np.pi * weekdays / 7) # dow_cos + + return x + + +def build_pool_attributes( + matched: Dict[str, dict], + mcap_path: str = None, +) -> Tuple[np.ndarray, List[str], List[str]]: + """Build (n_pools, K_attr) pool attribute matrix. + + Columns: + chain_dummies..., log_fee, mean_log_tvl, + log_mcap_product, has_stable, same_asset_type, weight_imbalance + + No intercept column — bias terms are handled by the model internally. + Chain dummies: one-hot with the first chain (alphabetically) as reference. + Market caps loaded from cached JSON (run scripts/fetch_token_mcaps.py). + + Returns: (X_attr, attr_names, pool_ids) + """ + mcaps = _load_token_mcaps(mcap_path) + + pool_ids = sorted(matched.keys()) + n_pools = len(pool_ids) + + # Collect per-pool attributes + chains = [] + log_fees = [] + mean_tvls = [] + log_mcap_products = [] + has_stables = [] + same_asset_types = [] + weight_imbalances = [] + + for pid in pool_ids: + entry = matched[pid] + chains.append(entry["chain"]) + log_fee = entry["panel"]["log_fee"].values[0] + log_fees.append(float(log_fee)) + mean_tvls.append(float(entry["panel"]["log_tvl_lag1"].mean())) + + # Token-level features + tokens = _parse_tokens(entry["tokens"]) + tok_a = tokens[0] if len(tokens) > 0 else "UNKNOWN" + tok_b = tokens[1] if len(tokens) > 1 else "UNKNOWN" + + # Market cap product + mcap_a = mcaps.get(tok_a, {}).get("mcap_usd", 1e6) # $1M fallback + mcap_b = mcaps.get(tok_b, {}).get("mcap_usd", 1e6) + log_mcap_products.append(np.log(max(mcap_a, 1.0) * max(mcap_b, 1.0))) + + # Asset type: 0=stable, 1=native/LST, 2=volatile + type_a = _get_asset_type(tok_a, mcaps) + type_b = _get_asset_type(tok_b, mcaps) + has_stables.append(1.0 if (type_a == 0 or type_b == 0) else 0.0) + same_asset_types.append(1.0 if type_a == type_b else 0.0) + + # Weight imbalance: 0 for 50/50, 0.3 for 80/20 + weights = entry.get("weights", [0.5, 0.5]) + if len(weights) >= 2: + weight_imbalances.append(abs(weights[0] - weights[1])) + else: + weight_imbalances.append(0.0) + + # Chain dummies (first alphabetically is reference) + unique_chains = sorted(set(chains)) + chain_dummies = unique_chains[1:] # drop reference + + attr_names = ( + [f"chain_{c}" for c in chain_dummies] + + [ + "log_fee", "mean_log_tvl", "log_mcap_product", + "has_stable", "same_asset_type", "weight_imbalance", + ] + ) + k_attr = len(attr_names) + + X = np.zeros((n_pools, k_attr)) + for i, pid in enumerate(pool_ids): + chain = chains[i] + for j, cd in enumerate(chain_dummies): + if chain == cd: + X[i, j] = 1.0 + base = len(chain_dummies) + X[i, base] = log_fees[i] + X[i, base + 1] = mean_tvls[i] + X[i, base + 2] = log_mcap_products[i] + X[i, base + 3] = has_stables[i] + X[i, base + 4] = same_asset_types[i] + X[i, base + 5] = weight_imbalances[i] + + return X, attr_names, pool_ids diff --git a/quantammsim/core_simulator/__init__.py b/quantammsim/core_simulator/__init__.py index 10c118a..010490e 100644 --- a/quantammsim/core_simulator/__init__.py +++ b/quantammsim/core_simulator/__init__.py @@ -17,6 +17,7 @@ import jax.numpy as jnp # noqa: F401 from jax import config config.update("jax_enable_x64", True) + config.update("jax_compilation_cache_dir", "/tmp/jax_cache") except ImportError as e: raise ImportError( "JAX is required for core simulator. Please install jax and jaxlib." diff --git a/quantammsim/core_simulator/dynamic_inputs.py b/quantammsim/core_simulator/dynamic_inputs.py index 598d3f7..367363e 100644 --- a/quantammsim/core_simulator/dynamic_inputs.py +++ b/quantammsim/core_simulator/dynamic_inputs.py @@ -77,14 +77,12 @@ def empty_dynamic_input_arrays() -> DynamicInputArrays: """Create a canonical empty bundle.""" return DynamicInputArrays( trades=None, - fees=jnp.zeros((1,), dtype=jnp.float64), - gas_cost=jnp.zeros((1,), dtype=jnp.float64), - arb_fees=jnp.zeros((1,), dtype=jnp.float64), - lp_supply=jnp.ones((1,), dtype=jnp.float64), + fees=jnp.zeros((1,)), + gas_cost=jnp.zeros((1,)), + arb_fees=jnp.zeros((1,)), + lp_supply=jnp.ones((1,)), # Columns: has_event, target_price_ratio, end_step, start_price_ratio_override - reclamm_price_ratio_updates=jnp.array( - [[0.0, 0.0, 0.0, jnp.nan]], dtype=jnp.float64 - ), + reclamm_price_ratio_updates=jnp.array([[0.0, 0.0, 0.0, jnp.nan]]), ) @@ -100,22 +98,22 @@ def resolve_dynamic_input_components( "fees": ( arrays.fees if dynamic_input_flags["has_dynamic_fees"] - else jnp.asarray([static_dict["fees"]], dtype=jnp.float64) + else jnp.asarray([static_dict["fees"]]) ), "gas_cost": ( arrays.gas_cost if dynamic_input_flags["has_dynamic_gas_cost"] - else jnp.asarray([static_dict["gas_cost"]], dtype=jnp.float64) + else jnp.asarray([static_dict["gas_cost"]]) ), "arb_fees": ( arrays.arb_fees if dynamic_input_flags["has_dynamic_arb_fees"] - else jnp.asarray([static_dict["arb_fees"]], dtype=jnp.float64) + else jnp.asarray([static_dict["arb_fees"]]) ), "lp_supply": ( arrays.lp_supply if dynamic_input_flags["has_lp_supply"] - else jnp.ones((1,), dtype=jnp.float64) + else jnp.ones((1,)) ), "reclamm_price_ratio_updates": ( arrays.reclamm_price_ratio_updates @@ -150,7 +148,7 @@ def materialize_dynamic_inputs( static_dict: dict, scan_len: int, do_trades: bool, - dtype=jnp.float64, + dtype=None, ) -> DynamicInputArrays: """Resolve and broadcast dynamic inputs for a specific scan length.""" if dynamic_input_flags is None and dynamic_inputs is not None: diff --git a/quantammsim/noise_calibration/__init__.py b/quantammsim/noise_calibration/__init__.py new file mode 100644 index 0000000..f1fddde --- /dev/null +++ b/quantammsim/noise_calibration/__init__.py @@ -0,0 +1,39 @@ +"""Noise calibration package for Balancer pool volume models. + +Public API re-exports from submodules. +""" + +# scipy.signal patch (must run before arviz import) +try: + from scipy.signal import gaussian as _ # noqa: F401 +except ImportError: + from scipy.signal.windows import gaussian as _gauss + import scipy.signal + scipy.signal.gaussian = _gauss + +from .constants import ( + K_COEFF, COEFF_NAMES, BALANCER_API_URL, BALANCER_API_CHAINS, CACHE_DIR, + K_CLUSTERS_DEFAULT, K_FEATURES_DEFAULT, +) +from .token_classification import classify_token_tier, _normalise_symbol +from .data_pipeline import ( + _graphql_request, enumerate_balancer_pools, fetch_pool_snapshots, + fetch_all_snapshots, fetch_token_prices, compute_pair_volatility, + assemble_panel, +) +from .data_validation import validate_panel +from .covariate_encoding import encode_covariates, encode_covariates_structural +from .model import noise_model, noise_model_dp_sigma, noise_model_ibp, noise_model_ibp_dp, stick_breaking_weights, structural_noise_model +from .formula_arb import formula_arb_volume_daily_jax +from .inference import ( + _get_theta_samples, _build_model_kwargs, run_svi, run_nuts, + run_svi_then_nuts, +) +from .postprocessing import ( + extract_noise_params, predict_new_pool, check_convergence, + run_prior_predictive, assign_dp_clusters, assign_ibp_dp_joint, + extract_structural_params, predict_new_pool_structural, +) +from .plotting import plot_diagnostics +from .output import generate_output_json, _save_sample_cache +from .cli import main diff --git a/quantammsim/noise_calibration/cli.py b/quantammsim/noise_calibration/cli.py new file mode 100644 index 0000000..3e169ac --- /dev/null +++ b/quantammsim/noise_calibration/cli.py @@ -0,0 +1,417 @@ +"""CLI entry point for noise calibration.""" + +import argparse +import json +import os +import sys +from datetime import date, timedelta + +import numpy as np +import pandas as pd + +from .constants import CACHE_DIR +from .data_pipeline import ( + enumerate_balancer_pools, fetch_all_snapshots, + fetch_token_prices, assemble_panel, +) +from .data_validation import validate_panel +from .covariate_encoding import encode_covariates +from .inference import run_svi, run_nuts, run_svi_then_nuts +from .postprocessing import ( + extract_noise_params, predict_new_pool, + check_convergence, run_prior_predictive, +) +from .plotting import plot_diagnostics +from .output import generate_output_json, _save_sample_cache + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="Unified Bayesian hierarchical noise volume model " + "for Balancer pools (gold standard)" + ) + + # Actions + parser.add_argument("--fetch", action="store_true", + help="Fetch data from Balancer API") + parser.add_argument("--fit", action="store_true", + help="Run inference (SVI default)") + parser.add_argument("--nuts", action="store_true", + help="Use NUTS instead of SVI") + parser.add_argument("--svi-init-nuts", action="store_true", + help="SVI-initialized NUTS (fast warmup)") + parser.add_argument("--plot", action="store_true", + help="Generate diagnostic plots") + parser.add_argument("--prior-predictive", action="store_true", + help="Include prior predictive check") + parser.add_argument("--validate", action="store_true", + help="Run data validation pass") + parser.add_argument("--predict", action="store_true", + help="Predict for unseen pool") + + # Output + parser.add_argument("--output", default=None, + help="Output JSON path") + parser.add_argument("--output-dir", default="results", + help="Plot output directory (default: results)") + + # Predict args + parser.add_argument("--chain", default=None, + help="Chain for --predict") + parser.add_argument("--tokens", nargs="+", default=None, + help="Tokens for --predict") + parser.add_argument("--fee", type=float, default=0.003, + help="Fee for --predict") + + # NUTS hyperparameters + parser.add_argument("--num-warmup", type=int, default=1000, + help="NUTS warmup iterations (default: 1000)") + parser.add_argument("--num-samples", type=int, default=2000, + help="NUTS/SVI samples (default: 2000)") + parser.add_argument("--num-chains", type=int, default=4, + help="NUTS chains (default: 4)") + parser.add_argument("--target-accept", type=float, default=0.85, + help="NUTS target accept prob (default: 0.85)") + parser.add_argument("--max-tree-depth", type=int, default=10, + help="NUTS max tree depth (default: 10)") + parser.add_argument("--seed", type=int, default=42, + help="Random seed (default: 42)") + + # SVI hyperparameters + parser.add_argument("--svi-steps", type=int, default=20000, + help="SVI optimization steps (default: 20000)") + parser.add_argument("--svi-lr", type=float, default=1e-3, + help="SVI learning rate (default: 1e-3)") + + # Model variant + parser.add_argument("--model", choices=["tier", "dp_sigma", "ibp", "ibp_dp", + "structural"], + default="tier", + help="Noise model variant: 'tier' (per-tier sigma_eps), " + "'dp_sigma' (DP mixture on sigma_eps), " + "'ibp' (IBP latent features), " + "'ibp_dp' (IBP features + DP noise clusters), or " + "'structural' (structural mixture: arb + MoE noise)") + parser.add_argument("--k-clusters", type=int, default=6, + help="Number of DP mixture components " + "(capacity ceiling, default: 6)") + parser.add_argument("--k-features", type=int, default=6, + help="Number of IBP latent features " + "(default: 6)") + + # Data + parser.add_argument("--train-days", type=int, default=90, + help="Use only the last N days of data for fitting " + "(default: 90). Aligns with Balancer API hourly price " + "coverage window. Set to 0 to use all data.") + parser.add_argument("--min-tvl", type=float, default=10000.0, + help="Pool enumeration TVL filter") + parser.add_argument("--cache-dir", default=None, + help="Cache directory") + parser.add_argument("--device", choices=["cpu", "gpu", "auto"], + default="auto", + help="JAX device (default: auto)") + + return parser.parse_args() + + +def main(): + args = _parse_args() + + if not any([args.fetch, args.fit, args.predict, args.validate]): + print("ERROR: At least one of --fetch, --fit, --predict, --validate " + "is required", file=sys.stderr) + sys.exit(1) + + cache_dir = args.cache_dir or CACHE_DIR + + # --- JAX setup (BEFORE any JAX ops / imports) --- + if args.fit or args.predict or args.prior_predictive: + # Set device before importing JAX + if args.device == "cpu": + os.environ.setdefault("JAX_PLATFORMS", "cpu") + elif args.device == "gpu": + os.environ.setdefault("JAX_PLATFORMS", "cuda") + # auto: don't touch JAX_PLATFORMS, let JAX pick + + # Set host device count for NUTS multi-chain BEFORE JAX init + if args.nuts or args.svi_init_nuts: + import numpyro as _np_pre + _np_pre.set_host_device_count( + min(args.num_chains, os.cpu_count() or 4) + ) + + import jax + import numpyro + numpyro.enable_x64() + + # --- File paths --- + pools_cache = os.path.join(cache_dir, "pools.parquet") + snaps_cache = os.path.join(cache_dir, "pool_snapshots.parquet") + prices_cache = os.path.join(cache_dir, "token_prices") + panel_cache = os.path.join(cache_dir, "panel.parquet") + + # --- Fetch --- + if args.fetch: + print("Phase 1: Fetching data from Balancer API") + print("=" * 60) + + print("\n1. Enumerating pools...") + pools_df = enumerate_balancer_pools(min_tvl=args.min_tvl) + os.makedirs(cache_dir, exist_ok=True) + pools_df.to_parquet(pools_cache, index=False) + print(f" Saved {len(pools_df)} pools -> {pools_cache}") + + print("\n2. Fetching daily snapshots...") + snapshots_df = fetch_all_snapshots(pools_df, cache_path=snaps_cache) + + print("\n3. Fetching token prices...") + token_addr_by_chain = {} + for _, pool in pools_df.iterrows(): + chain = pool["chain"] + tokens = pool["tokens"] + addresses = pool["token_addresses"] + if chain not in token_addr_by_chain: + token_addr_by_chain[chain] = {} + for sym, addr in zip(tokens, addresses): + if sym and addr: + token_addr_by_chain[chain][sym] = addr + + token_prices = fetch_token_prices( + token_addr_by_chain, cache_dir=prices_cache + ) + + print("\n4. Assembling panel (with lagged TVL)...") + panel = assemble_panel(pools_df, snapshots_df, token_prices) + panel.to_parquet(panel_cache, index=False) + print(f" Saved panel -> {panel_cache}") + + print(f"\nFetch complete. Panel: {len(panel)} obs, " + f"{panel['pool_id'].nunique()} pools") + + # --- Validate --- + if args.validate: + if not os.path.exists(panel_cache): + print(f"ERROR: Panel cache not found at {panel_cache}", + file=sys.stderr) + print("Run with --fetch first.", file=sys.stderr) + sys.exit(1) + panel = pd.read_parquet(panel_cache) + validate_panel(panel) + + # --- Fit --- + if args.fit: + print("\nUnified Noise Volume Model") + print("=" * 60) + + # Load panel + if not os.path.exists(panel_cache): + print(f"ERROR: Panel cache not found at {panel_cache}", + file=sys.stderr) + print("Run with --fetch first.", file=sys.stderr) + sys.exit(1) + + panel = pd.read_parquet(panel_cache) + print(f" Loaded panel: {len(panel)} obs, " + f"{panel['pool_id'].nunique()} pools, " + f"{panel['chain'].nunique()} chains") + + # Filter to recent window for training + if args.train_days > 0: + max_date = panel["date"].max() + if not isinstance(max_date, date): + max_date = pd.Timestamp(max_date).date() + cutoff = max_date - timedelta(days=args.train_days) + n_before = len(panel) + panel = panel[ + panel["date"].apply( + lambda d: d >= cutoff if isinstance(d, date) + else pd.Timestamp(d).date() >= cutoff + ) + ].copy() + print(f" Filtered to last {args.train_days} days " + f"(>= {cutoff}): {len(panel)} obs " + f"(dropped {n_before - len(panel)})") + + # Ensure lagged TVL exists (in case loaded from old cache) + if "log_tvl_lag1" not in panel.columns: + print(" Adding lagged TVL to cached panel...") + panel = panel.sort_values(["pool_id", "date"]).reset_index(drop=True) + panel["log_tvl_lag1"] = panel.groupby("pool_id")["log_tvl"].shift(1) + panel = panel.dropna(subset=["log_tvl_lag1"]).reset_index(drop=True) + + # Filter: need at least 10 days per pool + pool_counts = panel.groupby("pool_id").size() + valid_pools = pool_counts[pool_counts >= 10].index + panel = panel[panel["pool_id"].isin(valid_pools)].copy() + print(f" After filtering (>= 10 days): {len(panel)} obs, " + f"{panel['pool_id'].nunique()} pools") + + # Select model variant + if args.model == "structural": + from .model import structural_noise_model + from .covariate_encoding import encode_covariates_structural + model_fn = structural_noise_model + data = encode_covariates_structural(panel) + print(f" Model: structural mixture (arb + MoE noise)") + elif args.model == "ibp_dp": + from .model import noise_model_ibp_dp + model_fn = noise_model_ibp_dp + data = encode_covariates(panel, include_tiers=False) + data["K_features"] = args.k_features + data["K_clusters"] = args.k_clusters + print(f" Model: IBP+DP hybrid " + f"(K_features={args.k_features}, " + f"K_clusters={args.k_clusters})") + elif args.model == "ibp": + from .model import noise_model_ibp + model_fn = noise_model_ibp + data = encode_covariates(panel, include_tiers=False) + data["K_features"] = args.k_features + print(f" Model: IBP latent features " + f"(K_features={args.k_features})") + elif args.model == "dp_sigma": + from .model import noise_model_dp_sigma + model_fn = noise_model_dp_sigma + data = encode_covariates(panel, include_tiers=False) + data["K_clusters"] = args.k_clusters + print(f" Model: DP mixture on sigma_eps " + f"(K_clusters={args.k_clusters})") + else: + model_fn = None # default = noise_model + data = encode_covariates(panel) + + # Prior predictive + prior_samples = None + if args.prior_predictive: + print("\n Running prior predictive check...") + prior_samples = run_prior_predictive(data, model_fn=model_fn) + + # Inference + mcmc_obj = None + elbo_losses = None + inference_config = {"seed": args.seed} + + if args.svi_init_nuts: + inference_config["method"] = "svi_init_nuts" + inference_config["svi_steps"] = args.svi_steps + inference_config["svi_lr"] = args.svi_lr + inference_config["num_warmup"] = args.num_warmup + inference_config["num_samples"] = args.num_samples + inference_config["num_chains"] = args.num_chains + inference_config["target_accept"] = args.target_accept + inference_config["max_tree_depth"] = args.max_tree_depth + + mcmc_obj, elbo_losses = run_svi_then_nuts( + data, + svi_steps=args.svi_steps, + svi_lr=args.svi_lr, + num_warmup=args.num_warmup, + num_samples=args.num_samples, + num_chains=args.num_chains, + target_accept=args.target_accept, + max_tree_depth=args.max_tree_depth, + seed=args.seed, + model_fn=model_fn, + ) + samples = mcmc_obj + convergence = check_convergence(mcmc_obj, method="nuts") + + elif args.nuts: + inference_config["method"] = "nuts" + inference_config["num_warmup"] = args.num_warmup + inference_config["num_samples"] = args.num_samples + inference_config["num_chains"] = args.num_chains + inference_config["target_accept"] = args.target_accept + inference_config["max_tree_depth"] = args.max_tree_depth + + mcmc_obj = run_nuts( + data, + num_warmup=args.num_warmup, + num_samples=args.num_samples, + num_chains=args.num_chains, + target_accept=args.target_accept, + max_tree_depth=args.max_tree_depth, + seed=args.seed, + model_fn=model_fn, + ) + samples = mcmc_obj + convergence = check_convergence(mcmc_obj, method="nuts") + + else: + inference_config["method"] = "svi" + inference_config["svi_steps"] = args.svi_steps + inference_config["svi_lr"] = args.svi_lr + inference_config["num_samples"] = args.num_samples + + samples, elbo_losses = run_svi( + data, + num_steps=args.svi_steps, + lr=args.svi_lr, + seed=args.seed, + num_samples=args.num_samples, + model_fn=model_fn, + ) + convergence = check_convergence(elbo_losses, method="svi") + + if args.model == "structural": + from .postprocessing import extract_structural_params + pool_params = extract_structural_params(samples, data) + arb_freqs = [p["arb_frequency"] for p in pool_params] + print(f"\n Per-pool arb_frequency: " + f"mean={np.mean(arb_freqs):.1f}, " + f"range=[{np.min(arb_freqs)}, {np.max(arb_freqs)}]") + else: + pool_params = extract_noise_params(samples, data) + b_c_vals = [p["noise_params"]["b_c"] for p in pool_params] + b_0_vals = [p["noise_params"]["b_0"] for p in pool_params] + print(f"\n Per-pool b_c: mean={np.mean(b_c_vals):.3f}, " + f"std={np.std(b_c_vals):.3f}, " + f"range=[{np.min(b_c_vals):.3f}, {np.max(b_c_vals):.3f}]") + print(f" Per-pool b_0: mean={np.mean(b_0_vals):.3f}, " + f"std={np.std(b_0_vals):.3f}") + + if args.output: + generate_output_json( + pool_params, samples, data, convergence, + args.output, inference_config, + ) + + if args.plot: + print("\nGenerating diagnostic plots...") + plot_diagnostics( + samples, data, output_dir=args.output_dir, + elbo_losses=elbo_losses, mcmc=mcmc_obj, + prior_samples=prior_samples, + ) + + # Cache samples for --predict + _save_sample_cache(samples, data, cache_dir) + + # --- Predict --- + if args.predict: + if args.chain is None or args.tokens is None: + print("ERROR: --predict requires --chain and --tokens", + file=sys.stderr) + sys.exit(1) + + # Load cached samples + sample_cache = os.path.join(cache_dir, "unified_samples.npz") + data_cache = os.path.join(cache_dir, "unified_data.json") + + if not os.path.exists(sample_cache): + print(f"ERROR: Sample cache not found at {sample_cache}", + file=sys.stderr) + print("Run with --fit first.", file=sys.stderr) + sys.exit(1) + + cached = np.load(sample_cache) + sample_dict = {k: cached[k] for k in cached.files} + + with open(data_cache) as f: + data_meta = json.load(f) + + result = predict_new_pool( + sample_dict, data_meta, args.chain, args.tokens, args.fee + ) + print(json.dumps(result, indent=2)) diff --git a/quantammsim/noise_calibration/constants.py b/quantammsim/noise_calibration/constants.py new file mode 100644 index 0000000..05a217d --- /dev/null +++ b/quantammsim/noise_calibration/constants.py @@ -0,0 +1,60 @@ +"""Constants for noise calibration.""" + +import os + +K_COEFF = 4 +COEFF_NAMES = ["intercept", "b_tvl", "b_sigma", "b_weekend"] + +BALANCER_API_URL = "https://api-v3.balancer.fi/" + +BALANCER_API_CHAINS = [ + "MAINNET", "POLYGON", "ARBITRUM", "GNOSIS", "BASE", "SONIC", "OPTIMISM", + "AVALANCHE", +] + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + "local_data", "noise_calibration", +) + +# Tier 0: blue-chip — top by volume, wrapped native, major stables +_TIER_0 = { + "ETH", "WETH", "BTC", "WBTC", "cbBTC", "USDC", "USDT", "DAI", + "wstETH", "stETH", "rETH", "cbETH", "WMATIC", "MATIC", "POL", + "WAVAX", "AVAX", "GNO", "WXDAI", "xDAI", + "S", "wS", +} + +K_CLUSTERS_DEFAULT = 6 +K_FEATURES_DEFAULT = 6 + +# Structural model: observation-level covariates (expanded from K_COEFF=4) +K_OBS_COEFF = 8 +OBS_COEFF_NAMES = [ + "intercept", "b_tvl", "b_sigma", + "b_tvl_sigma", "b_tvl_fee", "b_sigma_fee", + "b_dow_sin", "b_dow_cos", +] + +# Gas costs per arb transaction (USD) by chain +GAS_COSTS = { + "MAINNET": None, # time-varying, loaded from CSV + "POLYGON": 0.005, + "ARBITRUM": 0.005, + "BASE": 0.005, + "GNOSIS": 0.01, + "OPTIMISM": 0.005, + "SONIC": 0.005, + "AVALANCHE": 0.005, + "MODE": 0.005, + "FRAXTAL": 0.005, +} + +# Tier 1: mid-cap DeFi blue-chips (approx CoinGecko rank < 200) +_TIER_1 = { + "AAVE", "LINK", "UNI", "BAL", "MKR", "CRV", "COMP", "SNX", + "LDO", "RPL", "SUSHI", "YFI", "1INCH", "ENS", "DYDX", + "FXS", "FRAX", "LUSD", "sDAI", "GHO", "crvUSD", + "ARB", "OP", "PENDLE", "ENA", "EIGEN", + "SAFE", "COW", +} diff --git a/quantammsim/noise_calibration/covariate_encoding.py b/quantammsim/noise_calibration/covariate_encoding.py new file mode 100644 index 0000000..a298a6b --- /dev/null +++ b/quantammsim/noise_calibration/covariate_encoding.py @@ -0,0 +1,228 @@ +"""Covariate encoding for the hierarchical noise model.""" + +import numpy as np +import pandas as pd + +from .constants import K_COEFF, K_OBS_COEFF, GAS_COSTS + + +def encode_covariates(panel: pd.DataFrame, include_tiers: bool = True) -> dict: + """Build NumPyro-ready arrays from the panel DataFrame. + + Returns dict with arrays for the model plus metadata for output/prediction. + Key difference from hierarchical script: x_obs uses log_tvl_lag1 not log_tvl. + """ + pool_meta = panel.drop_duplicates("pool_id").reset_index(drop=True) + pool_ids = pool_meta["pool_id"].values + pool_id_to_idx = {pid: i for i, pid in enumerate(pool_ids)} + N_pools = len(pool_ids) + + pool_idx = panel["pool_id"].map(pool_id_to_idx).values + + # --- Build X_pool (pool-level covariates, data-driven) --- + chains = sorted(panel["chain"].unique()) + ref_chain = chains[0] + chain_cols = [] + chain_names = [] + for c in chains[1:]: + chain_cols.append((pool_meta["chain"] == c).astype(float).values) + chain_names.append(f"chain_{c}") + + tier_a_vals = sorted(pool_meta["tier_A"].astype(str).unique()) + ref_tier_a = tier_a_vals[0] + tier_a_cols = [] + tier_a_names = [] + if include_tiers: + for t in tier_a_vals[1:]: + tier_a_cols.append( + (pool_meta["tier_A"].astype(str) == t).astype(float).values + ) + tier_a_names.append(f"tier_A_{t}") + + tier_b_vals = sorted(pool_meta["tier_B"].astype(str).unique()) + ref_tier_b = tier_b_vals[0] + tier_b_cols = [] + tier_b_names = [] + if include_tiers: + for t in tier_b_vals[1:]: + tier_b_cols.append( + (pool_meta["tier_B"].astype(str) == t).astype(float).values + ) + tier_b_names.append(f"tier_B_{t}") + + columns = [np.ones((N_pools, 1))] + col_names = ["intercept"] + + for arr, name in zip(chain_cols, chain_names): + columns.append(arr.reshape(-1, 1)) + col_names.append(name) + for arr, name in zip(tier_a_cols, tier_a_names): + columns.append(arr.reshape(-1, 1)) + col_names.append(name) + for arr, name in zip(tier_b_cols, tier_b_names): + columns.append(arr.reshape(-1, 1)) + col_names.append(name) + columns.append(pool_meta["log_fee"].values.reshape(-1, 1)) + col_names.append("log_fee") + + X_pool = np.hstack(columns) + K_cov = X_pool.shape[1] + + # --- Observation-level arrays (uses LAGGED TVL) --- + x_obs = np.column_stack([ + np.ones(len(panel)), + panel["log_tvl_lag1"].values, + panel["volatility"].values, + panel["weekend"].values, + ]).astype(np.float64) + + y_obs = panel["log_volume"].values.astype(np.float64) + + # --- Per-pool tier_A index for per-tier sigma_eps --- + tier_A_per_pool = pool_meta["tier_A"].values.astype(np.int32) + + print(f" Encoded: N_obs={len(y_obs)}, N_pools={N_pools}, " + f"K_coeff={K_COEFF}, K_cov={K_cov}") + print(f" Covariates: {col_names}") + print(f" Tier distribution: " + f"T0={np.sum(tier_A_per_pool == 0)}, " + f"T1={np.sum(tier_A_per_pool == 1)}, " + f"T2={np.sum(tier_A_per_pool == 2)}") + + return { + "pool_idx": pool_idx.astype(np.int32), + "X_pool": X_pool.astype(np.float64), + "x_obs": x_obs, + "y_obs": y_obs, + "pool_ids": list(pool_ids), + "pool_meta": pool_meta, + "covariate_names": col_names, + "tier_A_per_pool": tier_A_per_pool, + "N_pools": N_pools, + "K_cov": K_cov, + "ref_chain": ref_chain, + "ref_tier_a": ref_tier_a, + "ref_tier_b": ref_tier_b, + "chains": chains, + } + + +def _tier_pair_idx(a: int, b: int) -> int: + """Encode (tier_A, tier_B) pair as a single index. + + Upper triangle of 3x3 grid: + (0,0)->0, (0,1)->1, (0,2)->2, (1,1)->3, (1,2)->4, (2,2)->5. + """ + return a * (5 - a) // 2 + b - a + + +def encode_covariates_structural( + panel: pd.DataFrame, + gas: np.ndarray = None, +) -> dict: + """Build NumPyro-ready arrays for the structural mixture model. + + Extends encode_covariates with: + - x_obs: 8 columns (intercept, tvl, log_sigma, interactions, DOW harmonics) + - Additional arrays: sigma_daily, fee, gas, chain_idx, tier_idx, lag_log_tvl + - n_chains, n_tiers computed from panel + + Parameters + ---------- + panel : pd.DataFrame + Output of assemble_panel(), must have log_sigma, dow_sin, dow_cos, + tvl_x_sigma, tvl_x_fee, sigma_x_fee columns. + gas : np.ndarray, optional + Per-observation gas costs in USD. If None, uses default (0.01 for all). + """ + # Ensure structural columns exist (compute from base columns if missing) + if "log_sigma" not in panel.columns: + panel = panel.copy() + panel["log_sigma"] = np.log(np.maximum(panel["volatility"].values, 1e-6)) + dow = panel["date"].apply( + lambda d: d.weekday() if hasattr(d, "weekday") + else pd.Timestamp(d).weekday() + ) + panel["dow_sin"] = np.sin(2.0 * np.pi * dow / 7.0) + panel["dow_cos"] = np.cos(2.0 * np.pi * dow / 7.0) + panel["tvl_x_sigma"] = panel["log_tvl_lag1"] * panel["log_sigma"] + panel["tvl_x_fee"] = panel["log_tvl_lag1"] * panel["log_fee"] + panel["sigma_x_fee"] = panel["log_sigma"] * panel["log_fee"] + + # Reuse X_pool construction from encode_covariates (with tiers for gating) + base = encode_covariates(panel, include_tiers=True) + + # --- Observation-level x_obs: 8 columns --- + x_obs = np.column_stack([ + np.ones(len(panel)), # intercept + panel["log_tvl_lag1"].values, # lagged TVL + panel["log_sigma"].values, # log(volatility) + panel["tvl_x_sigma"].values, # tvl × sigma interaction + panel["tvl_x_fee"].values, # tvl × fee interaction + panel["sigma_x_fee"].values, # sigma × fee interaction + panel["dow_sin"].values, # DOW harmonic sin + panel["dow_cos"].values, # DOW harmonic cos + ]).astype(np.float64) + + # --- Additional arrays for the structural model --- + sigma_daily = (panel["volatility"] / np.sqrt(365.0)).values.astype(np.float64) + fee_per_obs = np.exp(panel["log_fee"].values).astype(np.float64) + lag_log_tvl = panel["log_tvl_lag1"].values.astype(np.float64) + + # Gas: per-observation + if gas is not None: + gas_arr = np.asarray(gas, dtype=np.float64) + else: + gas_arr = np.full(len(panel), 0.01, dtype=np.float64) + + # Chain index: integer per pool + pool_meta = base["pool_meta"] + chains = base["chains"] + chain_to_idx = {c: i for i, c in enumerate(chains)} + chain_idx_per_pool = np.array( + [chain_to_idx[c] for c in pool_meta["chain"]], dtype=np.int32, + ) + + # Tier pair index: per pool + tier_idx_per_pool = np.array( + [_tier_pair_idx(int(row["tier_A"]), int(row["tier_B"])) + for _, row in pool_meta.iterrows()], + dtype=np.int32, + ) + + # Count unique tier pairs and chains + n_chains = len(chains) + tier_pairs = set() + for _, row in pool_meta.iterrows(): + tier_pairs.add((int(row["tier_A"]), int(row["tier_B"]))) + n_tiers = 6 # fixed: upper triangle of 3x3 + + print(f" Structural encoding: N_obs={len(panel)}, " + f"N_pools={base['N_pools']}, n_chains={n_chains}, n_tiers={n_tiers}") + + return { + # Base arrays (same as encode_covariates) + "pool_idx": base["pool_idx"], + "X_pool": base["X_pool"], + "x_obs": x_obs, + "y_obs": base["y_obs"], + "pool_ids": base["pool_ids"], + "pool_meta": pool_meta, + "covariate_names": base["covariate_names"], + "tier_A_per_pool": base["tier_A_per_pool"], + "N_pools": base["N_pools"], + "K_cov": base["K_cov"], + "ref_chain": base["ref_chain"], + "ref_tier_a": base["ref_tier_a"], + "ref_tier_b": base["ref_tier_b"], + "chains": chains, + # Structural model extras + "sigma_daily": sigma_daily, + "fee": fee_per_obs, + "gas": gas_arr, + "chain_idx": chain_idx_per_pool, + "tier_idx": tier_idx_per_pool, + "lag_log_tvl": lag_log_tvl, + "n_chains": n_chains, + "n_tiers": n_tiers, + } diff --git a/quantammsim/noise_calibration/data_pipeline.py b/quantammsim/noise_calibration/data_pipeline.py new file mode 100644 index 0000000..b824cdc --- /dev/null +++ b/quantammsim/noise_calibration/data_pipeline.py @@ -0,0 +1,516 @@ +"""Data pipeline: fetch pools, snapshots, prices, and assemble panel.""" + +import json +import os +import time +import urllib.request +from datetime import datetime + +import numpy as np +import pandas as pd + +from .constants import BALANCER_API_URL, BALANCER_API_CHAINS +from .token_classification import classify_token_tier + + +def _graphql_request(query: dict, base_url: str = BALANCER_API_URL, + timeout: int = 30) -> dict: + """Send a GraphQL request to the Balancer V3 API.""" + data = json.dumps(query).encode("utf-8") + req = urllib.request.Request( + base_url, + data=data, + headers={ + "Content-Type": "application/json", + "User-Agent": "quantammsim/1.0", + }, + ) + with urllib.request.urlopen(req, timeout=timeout) as resp: + return json.loads(resp.read().decode("utf-8")) + + +def enumerate_balancer_pools( + chains: list = None, + pool_types: list = None, + min_tvl: float = 10000.0, +) -> pd.DataFrame: + """Enumerate all WEIGHTED + RECLAMM pools across chains from Balancer API.""" + if chains is None: + chains = BALANCER_API_CHAINS + if pool_types is None: + pool_types = ["WEIGHTED", "RECLAMM"] + + all_pools = [] + for chain in chains: + print(f" Querying {chain}...", end=" ", flush=True) + query = { + "query": """ + query GetPools($chain: GqlChain!, $types: [GqlPoolType!], + $minTvl: Float) { + poolGetPools( + where: { + chainIn: [$chain] + poolTypeIn: $types + minTvl: $minTvl + } + ) { + id + chain + type + createTime + protocolVersion + poolTokens { + symbol + weight + address + } + dynamicData { + totalLiquidity + swapFee + } + } + } + """, + "variables": { + "chain": chain, + "types": pool_types, + "minTvl": min_tvl, + }, + } + + try: + body = _graphql_request(query) + pools = body.get("data", {}).get("poolGetPools", []) + except Exception as e: + print(f"FAILED ({e})") + continue + + for p in pools: + tokens = [t["symbol"] for t in p.get("poolTokens", [])] + weights = [t.get("weight") for t in p.get("poolTokens", [])] + token_addresses = [t.get("address", "") for t in p.get("poolTokens", [])] + tvl = float(p.get("dynamicData", {}).get("totalLiquidity", 0)) + fee = float(p.get("dynamicData", {}).get("swapFee", 0)) + + all_pools.append({ + "pool_id": p["id"], + "chain": p["chain"], + "pool_type": p["type"], + "protocol_version": p.get("protocolVersion", 0), + "tokens": tokens, + "token_addresses": token_addresses, + "weights": weights, + "swap_fee": fee, + "create_time": p.get("createTime", 0), + "current_tvl": tvl, + }) + + print(f"{len(pools)} pools") + time.sleep(0.3) + + df = pd.DataFrame(all_pools) + print(f"\n Total: {len(df)} pools across {len(chains)} chains") + return df + + +def fetch_pool_snapshots(pool_id: str, chain: str, + base_url: str = BALANCER_API_URL) -> pd.DataFrame: + """Fetch ALL_TIME daily snapshots for a single pool.""" + query = { + "query": """ + query GetSnapshots($poolId: String!, $chain: GqlChain!, + $range: GqlPoolSnapshotDataRange!) { + poolGetSnapshots(id: $poolId, chain: $chain, range: $range) { + timestamp + volume24h + totalLiquidity + totalShares + } + } + """, + "variables": { + "poolId": pool_id, + "chain": chain, + "range": "ALL_TIME", + }, + } + + body = _graphql_request(query) + snapshots = body.get("data", {}).get("poolGetSnapshots", []) + + if not snapshots: + return pd.DataFrame(columns=["timestamp", "volume_usd", + "total_liquidity_usd", "total_shares"]) + + records = [] + for snap in snapshots: + records.append({ + "timestamp": int(snap["timestamp"]), + "volume_usd": float(snap["volume24h"]), + "total_liquidity_usd": float(snap["totalLiquidity"]), + "total_shares": float(snap.get("totalShares", 0)), + }) + + df = pd.DataFrame(records) + df["date"] = pd.to_datetime(df["timestamp"], unit="s").dt.date + df = df.sort_values("timestamp").drop_duplicates("date", keep="last") + return df + + +def fetch_all_snapshots(pools_df: pd.DataFrame, + cache_path: str = None) -> pd.DataFrame: + """Fetch daily snapshots for all pools, with caching.""" + cached = pd.DataFrame() + cached_pool_ids = set() + if cache_path and os.path.exists(cache_path): + cached = pd.read_parquet(cache_path) + cached_pool_ids = set(cached["pool_id"].unique()) + print(f" Cache has {len(cached_pool_ids)} pools, " + f"{len(cached)} pool-days") + + if len(pools_df) == 0: + print(" No pools to fetch.") + return cached if len(cached) > 0 else pd.DataFrame( + columns=["pool_id", "chain", "date", "volume_usd", + "total_liquidity_usd", "total_shares"] + ) + to_fetch = pools_df[~pools_df["pool_id"].isin(cached_pool_ids)] + print(f" Need to fetch {len(to_fetch)} new pools") + + new_records = [] + for i, (_, pool) in enumerate(to_fetch.iterrows()): + if (i + 1) % 10 == 0 or i == 0: + print(f" Fetching {i+1}/{len(to_fetch)}: {pool['pool_id'][:10]}... " + f"({pool['chain']})", flush=True) + try: + snap_df = fetch_pool_snapshots(pool["pool_id"], pool["chain"]) + if len(snap_df) > 0: + snap_df["pool_id"] = pool["pool_id"] + snap_df["chain"] = pool["chain"] + cols = ["pool_id", "chain", "date", "volume_usd", + "total_liquidity_usd"] + if "total_shares" in snap_df.columns: + cols.append("total_shares") + new_records.append(snap_df[cols]) + except Exception as e: + print(f" FAILED {pool['pool_id'][:10]}: {e}") + time.sleep(0.5) + + if new_records: + new_df = pd.concat(new_records, ignore_index=True) + combined = pd.concat([cached, new_df], ignore_index=True) + else: + combined = cached + + if cache_path and len(combined) > 0: + os.makedirs(os.path.dirname(cache_path), exist_ok=True) + combined.to_parquet(cache_path, index=False) + print(f" Saved cache: {len(combined)} pool-days -> {cache_path}") + + return combined + + +def fetch_token_prices(token_addresses_by_chain: dict, + cache_dir: str = None) -> dict: + """Fetch hourly token prices from Balancer API.""" + if cache_dir: + os.makedirs(cache_dir, exist_ok=True) + + prices = {} + + for chain, tokens in token_addresses_by_chain.items(): + uncached = {} + for symbol, address in tokens.items(): + cache_key = f"{chain}_{symbol}".replace("/", "_") + cp = os.path.join(cache_dir, f"{cache_key}.parquet") if cache_dir else None + + if cp and os.path.exists(cp): + prices[(chain, symbol)] = pd.read_parquet(cp) + else: + uncached[symbol] = address + + if not uncached: + continue + + addr_to_symbol = {addr: sym for sym, addr in uncached.items()} + addresses = list(uncached.values()) + + print(f" Fetching {len(addresses)} prices on {chain}...", flush=True) + + batch_size = 20 + for batch_start in range(0, len(addresses), batch_size): + batch_addrs = addresses[batch_start:batch_start + batch_size] + query = { + "query": """ + query GetPrices($chain: GqlChain!, $addresses: [String!]!, + $range: GqlTokenChartDataRange!) { + tokenGetHistoricalPrices( + addresses: $addresses, chain: $chain, range: $range + ) { + address + prices { + timestamp + price + } + } + } + """, + "variables": { + "chain": chain, + "addresses": batch_addrs, + "range": "ONE_YEAR", + }, + } + + try: + body = _graphql_request(query, timeout=60) + results = body.get("data", {}).get( + "tokenGetHistoricalPrices", []) + for result in results: + addr = result.get("address", "") + price_list = result.get("prices", []) + symbol = addr_to_symbol.get(addr) + if symbol and price_list: + pdf = pd.DataFrame(price_list) + pdf["timestamp"] = pdf["timestamp"].astype(int) + pdf["price"] = pdf["price"].astype(float) + prices[(chain, symbol)] = pdf + if cache_dir: + cache_key = f"{chain}_{symbol}".replace("/", "_") + cp = os.path.join(cache_dir, f"{cache_key}.parquet") + pdf.to_parquet(cp, index=False) + except Exception as e: + print(f" FAILED batch on {chain}: {e}") + + time.sleep(0.5) + + print(f" Got prices for {len(prices)} token-chain pairs") + return prices + + +def compute_pair_volatility( + snapshots_df: pd.DataFrame, + pool_row: pd.Series, + token_prices: dict, +) -> pd.Series: + """Compute daily annualised volatility for a pool's pair ratio.""" + tokens = pool_row["tokens"] + chain = pool_row["chain"] + + if len(tokens) < 2: + return pd.Series(dtype=float) + + def _get_price_df(symbol): + key = (chain, symbol) + if key in token_prices: + return token_prices[key] + for k, v in token_prices.items(): + if k[1] == symbol: + return v + return None + + p0_df = _get_price_df(tokens[0]) + p1_df = _get_price_df(tokens[1]) + + stables = {"USDC", "USDT", "DAI", "LUSD", "GHO", "crvUSD", "sDAI", + "WXDAI", "xDAI", "USDC.e", "USDbC"} + + if tokens[0] in stables and tokens[1] in stables: + dates = snapshots_df["date"].unique() + return pd.Series(0.01, index=dates) + + if p0_df is None and tokens[0] not in stables: + dates = snapshots_df["date"].unique() + return pd.Series(0.5, index=dates) + if p1_df is None and tokens[1] not in stables: + dates = snapshots_df["date"].unique() + return pd.Series(0.5, index=dates) + + if tokens[0] in stables: + if p1_df is None or len(p1_df) == 0: + dates = snapshots_df["date"].unique() + return pd.Series(0.5, index=dates) + ratio_df = p1_df.copy() + ratio_df["ratio"] = 1.0 / ratio_df["price"] + elif tokens[1] in stables: + if p0_df is None or len(p0_df) == 0: + dates = snapshots_df["date"].unique() + return pd.Series(0.5, index=dates) + ratio_df = p0_df.copy() + ratio_df["ratio"] = ratio_df["price"] + else: + if p0_df is None or p1_df is None or len(p0_df) == 0 or len(p1_df) == 0: + dates = snapshots_df["date"].unique() + return pd.Series(0.5, index=dates) + merged = pd.merge_asof( + p0_df.sort_values("timestamp"), + p1_df.sort_values("timestamp"), + on="timestamp", + suffixes=("_0", "_1"), + tolerance=7200, + ).dropna() + if len(merged) == 0: + dates = snapshots_df["date"].unique() + return pd.Series(0.5, index=dates) + ratio_df = merged.copy() + ratio_df["ratio"] = merged["price_0"] / merged["price_1"] + + ratio_df["datetime"] = pd.to_datetime(ratio_df["timestamp"], unit="s") + ratio_df["date"] = ratio_df["datetime"].dt.date + ratio_df = ratio_df.sort_values("timestamp") + + ratio_df["log_return"] = np.log( + ratio_df["ratio"] / ratio_df["ratio"].shift(1) + ) + ratio_df = ratio_df.dropna(subset=["log_return"]) + + daily_vol = ratio_df.groupby("date")["log_return"].std() + daily_vol_ann = daily_vol * np.sqrt(24 * 365) + + return daily_vol_ann + + +def assemble_panel( + pools_df: pd.DataFrame, + snapshots_df: pd.DataFrame, + token_prices: dict, +) -> pd.DataFrame: + """Assemble the full panel DataFrame with lagged TVL. + + Adds log_tvl_lag1 = per-pool shift(1) of log_tvl to break + the TVL-volume simultaneity bias. Drops the first observation + per pool (~1 obs per pool). + """ + records = [] + pool_ids = snapshots_df["pool_id"].unique() + n_pools = len(pool_ids) + + # Track volatility fallback rate + n_obs_total = 0 + n_obs_fallback = 0 + pools_all_fallback = [] # pools where every obs hit fallback + + for i, pool_id in enumerate(pool_ids): + if (i + 1) % 20 == 0 or i == 0: + print(f" Assembling {i+1}/{n_pools}...", flush=True) + + pool_snaps = snapshots_df[snapshots_df["pool_id"] == pool_id] + pool_meta = pools_df[pools_df["pool_id"] == pool_id] + if len(pool_meta) == 0: + continue + pool_row = pool_meta.iloc[0] + + tokens = pool_row["tokens"] + if len(tokens) < 2: + continue + + chain = pool_row["chain"] + swap_fee = pool_row["swap_fee"] + + tiers = sorted([classify_token_tier(t) for t in tokens]) + tier_a = tiers[0] + tier_b = tiers[1] if len(tiers) > 1 else tiers[0] + + vol_series = compute_pair_volatility(pool_snaps, pool_row, token_prices) + + pool_obs = 0 + pool_fallback = 0 + has_shares = "total_shares" in pool_snaps.columns + for _, snap in pool_snaps.iterrows(): + date = snap["date"] + volume = snap["volume_usd"] + tvl = snap["total_liquidity_usd"] + shares = float(snap["total_shares"]) if has_shares else 0.0 + + if tvl <= 0 or volume <= 0: + continue + + used_fallback = False + if isinstance(vol_series, pd.Series) and date in vol_series.index: + vol = vol_series[date] + else: + vol = 0.5 + used_fallback = True + + if not np.isfinite(vol) or vol <= 0: + vol = 0.5 + used_fallback = True + + n_obs_total += 1 + if used_fallback: + n_obs_fallback += 1 + pool_fallback += 1 + pool_obs += 1 + + if isinstance(date, datetime): + is_weekend = date.weekday() >= 5 + else: + is_weekend = pd.Timestamp(date).weekday() >= 5 + + # DOW harmonics (deterministic from date) + if isinstance(date, datetime): + dow = date.weekday() + else: + dow = pd.Timestamp(date).weekday() + dow_sin = np.sin(2.0 * np.pi * dow / 7.0) + dow_cos = np.cos(2.0 * np.pi * dow / 7.0) + + record = { + "pool_id": pool_id, + "chain": chain, + "date": date, + "log_volume": np.log(volume), + "log_tvl": np.log(tvl), + "volatility": vol, + "log_sigma": np.log(max(vol, 1e-6)), + "weekend": 1.0 if is_weekend else 0.0, + "log_fee": np.log(max(swap_fee, 1e-6)), + "dow_sin": dow_sin, + "dow_cos": dow_cos, + "tier_A": tier_a, + "tier_B": tier_b, + "tokens": ",".join(tokens[:2]), + "swap_fee": swap_fee, + } + if shares > 0: + record["total_shares"] = shares + records.append(record) + + if pool_obs > 0 and pool_fallback == pool_obs: + pools_all_fallback.append( + (pool_id[:16], chain, ",".join(tokens[:2])) + ) + + panel = pd.DataFrame(records) + + # Add lagged TVL to break simultaneity bias + panel = panel.sort_values(["pool_id", "date"]).reset_index(drop=True) + panel["log_tvl_lag1"] = panel.groupby("pool_id")["log_tvl"].shift(1) + n_before = len(panel) + panel = panel.dropna(subset=["log_tvl_lag1"]).reset_index(drop=True) + n_dropped = n_before - len(panel) + + # Interaction terms (use lagged TVL to break simultaneity) + panel["tvl_x_sigma"] = panel["log_tvl_lag1"] * panel["log_sigma"] + panel["tvl_x_fee"] = panel["log_tvl_lag1"] * panel["log_fee"] + panel["sigma_x_fee"] = panel["log_sigma"] * panel["log_fee"] + + print(f"\n Panel: {len(panel)} observations, " + f"{panel['pool_id'].nunique()} pools, " + f"{panel['chain'].nunique()} chains") + print(f" Dropped {n_dropped} first-day obs for lagged TVL") + + # Volatility coverage report + if n_obs_total > 0: + pct = 100 * n_obs_fallback / n_obs_total + print(f"\n Volatility coverage:") + print(f" {n_obs_fallback}/{n_obs_total} obs used fallback " + f"vol=0.5 ({pct:.1f}%)") + print(f" {len(pools_all_fallback)} pools had 100% fallback") + if pools_all_fallback: + for pid, ch, toks in pools_all_fallback[:10]: + print(f" {pid}... ({ch}) {toks}") + if len(pools_all_fallback) > 10: + print(f" ... and {len(pools_all_fallback) - 10} more") + + return panel diff --git a/quantammsim/noise_calibration/data_validation.py b/quantammsim/noise_calibration/data_validation.py new file mode 100644 index 0000000..8cb44e7 --- /dev/null +++ b/quantammsim/noise_calibration/data_validation.py @@ -0,0 +1,41 @@ +"""Data validation for noise calibration panels.""" + +import numpy as np +import pandas as pd + + +def validate_panel(panel: pd.DataFrame) -> pd.DataFrame: + """Run data validation checks. Prints warnings but does NOT drop rows.""" + print("\n Data validation:") + + # Pools with constant volume + vol_std = panel.groupby("pool_id")["log_volume"].std() + constant_vol = vol_std[vol_std < 0.01] + if len(constant_vol) > 0: + print(f" WARNING: {len(constant_vol)} pools have near-constant " + f"log(volume) (std < 0.01)") + for pid in constant_vol.index[:5]: + print(f" {pid[:16]}... std={constant_vol[pid]:.4f}") + if len(constant_vol) > 5: + print(f" ... and {len(constant_vol) - 5} more") + + # TVL jumps > 10x between consecutive days + panel_sorted = panel.sort_values(["pool_id", "date"]) + tvl_ratio = panel_sorted.groupby("pool_id")["log_tvl"].diff().abs() + big_jumps = tvl_ratio[tvl_ratio > np.log(10)] + if len(big_jumps) > 0: + affected_pools = panel_sorted.loc[big_jumps.index, "pool_id"].nunique() + print(f" WARNING: {len(big_jumps)} TVL jumps > 10x across " + f"{affected_pools} pools") + + # Days where volume > TVL + high_vol = panel[panel["log_volume"] > panel["log_tvl"]] + if len(high_vol) > 0: + affected_pools = high_vol["pool_id"].nunique() + print(f" WARNING: {len(high_vol)} days where volume > TVL across " + f"{affected_pools} pools (potential wash trading)") + + if len(constant_vol) == 0 and len(big_jumps) == 0 and len(high_vol) == 0: + print(" All checks passed.") + + return panel diff --git a/quantammsim/noise_calibration/formula_arb.py b/quantammsim/noise_calibration/formula_arb.py new file mode 100644 index 0000000..80fe335 --- /dev/null +++ b/quantammsim/noise_calibration/formula_arb.py @@ -0,0 +1,36 @@ +"""JAX-differentiable LVR formula for arb volume. + +Based on arXiv:2305.14604v2 §6 with gas costs and discrete-time correction. +Reference: scripts/plot_formula_arb_vs_real.py:formula_arb_volume_daily (line 58). +""" + +import jax.numpy as jnp + + +def formula_arb_volume_daily_jax(sigma_daily, tvl, fee, gas_usd, cadence_minutes): + """Analytical arb volume per day for a CPMM with gas costs. + + All inputs are JAX scalars or arrays (must be broadcastable). + + Parameters + ---------- + sigma_daily : float + Daily volatility of the log price ratio (NOT annualised). + tvl : float + Pool TVL in USD. + fee : float + Swap fee as fraction (e.g. 0.003 for 30bp). + gas_usd : float + All-in gas cost per arb tx in USD. + cadence_minutes : float + Effective arb cadence in minutes (= simulator's arb_frequency). + """ + block_time_s = cadence_minutes * 60.0 + delta = 2.0 * jnp.sqrt(2.0 * jnp.maximum(gas_usd, 0.0) / jnp.maximum(tvl, 1e-6)) + bLVR = sigma_daily**2 * tvl / 8.0 + sqrt_term = sigma_daily * jnp.sqrt(block_time_s / (2.0 * 86400.0)) + correction = jnp.maximum( + 1.0 - delta / (2.0 * fee) - sqrt_term / (fee + delta / 2.0), + 0.0, + ) + return bLVR * correction / fee diff --git a/quantammsim/noise_calibration/inference.py b/quantammsim/noise_calibration/inference.py new file mode 100644 index 0000000..e315a5b --- /dev/null +++ b/quantammsim/noise_calibration/inference.py @@ -0,0 +1,270 @@ +"""Inference runners: SVI, NUTS, SVI-initialized NUTS.""" + +import numpy as np + +from .constants import K_COEFF +from .model import noise_model + + +def _get_theta_samples(sample_dict: dict, X_pool: np.ndarray, + data: dict = None) -> np.ndarray: + """Get theta samples, reconstructing from non-centered params if needed. + + MCMC.get_samples() includes the deterministic "theta" site. + SVI's Predictive(guide, ...) does NOT — the guide only samples latent + variables. In that case, reconstruct theta manually: + mu = X_pool @ B^T + L_Sigma = diag(sigma_theta) @ L_Omega + theta = mu + eta @ L_Sigma^T + + For marginalized IBP (W present, z_logit absent): compute MAP feature + assignments from data, then theta = X_pool @ B.T + Z_MAP @ W. + Requires data dict for pool_idx, x_obs, y_obs. + + For legacy STE IBP (z_logit present): theta = X_pool @ B.T + Z_hard @ W. + """ + if "theta" in sample_dict: + return np.array(sample_dict["theta"]) + + # Marginalized IBP path: compute MAP assignments from data + if "W" in sample_dict and "z_logit" not in sample_dict: + B = np.array(sample_dict["B"]) # (S, K_coeff, K_cov) + W = np.array(sample_dict["W"]) # (S, K_features, K_coeff) + + # MAP assignments: (N_pools, K_features) binary + if "v" in sample_dict: + # Hybrid IBP+DP: joint MAP over (features, clusters) + from .postprocessing import assign_ibp_dp_joint + Z_map, _ = assign_ibp_dp_joint(sample_dict, data) + else: + from .postprocessing import assign_ibp_features + Z_map = assign_ibp_features(sample_dict, data) + + mu = np.einsum("pd,sjd->spj", X_pool, B) + # Z_map doesn't vary across samples — broadcast + feature_effect = np.einsum("pk,skj->spj", Z_map.astype(float), W) + return mu + feature_effect + + # Legacy STE IBP path: theta = X_pool @ B.T + Z_hard @ W + if "z_logit" in sample_dict: + B = np.array(sample_dict["B"]) # (S, K_coeff, K_cov) + W = np.array(sample_dict["W"]) # (S, K_features, K_coeff) + z_logit = np.array(sample_dict["z_logit"]) # (S, N_pools, K_features) + Z_hard = (z_logit > 0).astype(float) + + mu = np.einsum("pd,sjd->spj", X_pool, B) # (S, N_pools, K_coeff) + feature_effect = np.einsum("spk,skj->spj", Z_hard, W) + return mu + feature_effect + + B = np.array(sample_dict["B"]) # (S, K_coeff, K_cov) + sigma_theta = np.array(sample_dict["sigma_theta"]) # (S, K_coeff) + L_Omega = np.array(sample_dict["L_Omega"]) # (S, K_coeff, K_coeff) + eta = np.array(sample_dict["eta"]) # (S, N_pools, K_coeff) + + # mu[s, p, j] = sum_d X_pool[p, d] * B[s, j, d] -> (S, N_pools, K_coeff) + mu = np.einsum("pd,sjd->spj", X_pool, B) + + # L_Sigma = diag(sigma_theta) @ L_Omega -> (S, K_coeff, K_coeff) + L_Sigma = sigma_theta[:, :, None] * L_Omega + + # offset = eta @ L_Sigma^T -> (S, N_pools, K_coeff) + offset = np.einsum("spi,sji->spj", eta, L_Sigma) + + return mu + offset + + +def _build_model_kwargs(data: dict, model_fn=None) -> dict: + """Convert data dict to jnp arrays for the model. + + Uses inspect.signature on model_fn to decide which kwargs to include: + - tier_A_per_pool: only if model_fn accepts it + - K_clusters: only if model_fn accepts it and data has it + """ + import inspect + import jax.numpy as jnp + + if model_fn is None: + model_fn = noise_model + + params = set(inspect.signature(model_fn).parameters.keys()) + + kwargs = dict( + pool_idx=jnp.array(data["pool_idx"]), + X_pool=jnp.array(data["X_pool"]), + x_obs=jnp.array(data["x_obs"]), + y_obs=jnp.array(data["y_obs"]), + N_pools=data["N_pools"], + K_coeff=K_COEFF, + K_cov=data["K_cov"], + ) + + if "tier_A_per_pool" in params: + kwargs["tier_A_per_pool"] = jnp.array(data["tier_A_per_pool"]) + + if "K_clusters" in params and "K_clusters" in data: + kwargs["K_clusters"] = data["K_clusters"] + + if "K_features" in params and "K_features" in data: + kwargs["K_features"] = data["K_features"] + + # Structural model parameters + if "sigma_daily" in params and "sigma_daily" in data: + kwargs["sigma_daily"] = jnp.array(data["sigma_daily"]) + if "lag_log_tvl" in params and "lag_log_tvl" in data: + kwargs["lag_log_tvl"] = jnp.array(data["lag_log_tvl"]) + if "fee" in params and "fee" in data: + kwargs["fee"] = jnp.array(data["fee"]) + if "gas" in params and "gas" in data: + kwargs["gas"] = jnp.array(data["gas"]) + if "chain_idx" in params and "chain_idx" in data: + kwargs["chain_idx"] = jnp.array(data["chain_idx"]) + if "tier_idx" in params and "tier_idx" in data: + kwargs["tier_idx"] = jnp.array(data["tier_idx"]) + if "n_chains" in params and "n_chains" in data: + kwargs["n_chains"] = data["n_chains"] + if "n_tiers" in params and "n_tiers" in data: + kwargs["n_tiers"] = data["n_tiers"] + if "K_archetypes" in params and "K_archetypes" in data: + kwargs["K_archetypes"] = data["K_archetypes"] + + return kwargs + + +def run_svi(data, num_steps=20000, lr=1e-3, seed=0, + num_samples=1000, model_fn=None) -> tuple: + """Run SVI with AutoNormal guide. + + Returns (samples_dict, elbo_losses). + """ + import jax + import jax.numpy as jnp + import numpyro + from numpyro.infer import SVI, Trace_ELBO, Predictive + from numpyro.infer.autoguide import AutoNormal + + if model_fn is None: + model_fn = noise_model + + model_kwargs = _build_model_kwargs(data, model_fn=model_fn) + + print(f"\n Running SVI: {num_steps} steps, lr={lr}") + guide = AutoNormal(model_fn) + optimizer = numpyro.optim.Adam(lr) + svi = SVI(model_fn, guide, optimizer, loss=Trace_ELBO()) + + rng_key = jax.random.PRNGKey(seed) + svi_result = svi.run(rng_key, num_steps, **model_kwargs) + + elbo_losses = np.array(svi_result.losses) + print(f" SVI complete. Final ELBO: {elbo_losses[-1]:.2f}") + print(f" ELBO last 100 std: {np.std(elbo_losses[-100:]):.2f}") + + # Draw posterior samples + predictive = Predictive( + guide, params=svi_result.params, num_samples=num_samples, + ) + samples = predictive(jax.random.PRNGKey(seed + 1), **model_kwargs) + samples = {k: np.array(v) for k, v in samples.items()} + + print(f" Drew {num_samples} posterior samples.") + return samples, elbo_losses + + +def run_nuts(data, num_warmup=1000, num_samples=2000, num_chains=4, + target_accept=0.85, max_tree_depth=10, seed=42, + init_values=None, model_fn=None): + """Run NUTS MCMC. + + Uses init_to_value if init_values provided (for SVI-initialized NUTS). + Returns the MCMC object. + """ + import jax + import jax.numpy as jnp + import numpyro + from numpyro.infer import MCMC, NUTS, init_to_value + + if model_fn is None: + model_fn = noise_model + + # Note: set_host_device_count must be called before JAX init. + # We handle this in main(). Here we just verify device count. + n_devices = len(jax.devices("cpu")) + if n_devices < num_chains: + print(f" WARNING: Only {n_devices} CPU devices available for " + f"{num_chains} chains. Chains will run sequentially.") + + init_strategy = None + if init_values is not None: + init_strategy = init_to_value( + values={k: jnp.array(v) for k, v in init_values.items()} + ) + + kernel = NUTS( + model_fn, + target_accept_prob=target_accept, + max_tree_depth=max_tree_depth, + init_strategy=init_strategy, + ) + mcmc = MCMC( + kernel, + num_warmup=num_warmup, + num_samples=num_samples, + num_chains=num_chains, + progress_bar=True, + ) + + model_kwargs = _build_model_kwargs(data, model_fn=model_fn) + rng_key = jax.random.PRNGKey(seed) + + print(f"\n Running NUTS: {num_chains} chains x " + f"({num_warmup} warmup + {num_samples} samples)") + print(f" target_accept={target_accept}, max_tree_depth={max_tree_depth}") + if init_values is not None: + print(" Using SVI-initialized starting values.") + + mcmc.run(rng_key, **model_kwargs) + mcmc.print_summary(exclude_deterministic=True) + return mcmc + + +def run_svi_then_nuts(data, svi_steps=5000, svi_lr=1e-3, + num_warmup=500, num_samples=2000, num_chains=4, + target_accept=0.85, max_tree_depth=10, seed=42, + model_fn=None): + """Run SVI first, then use posterior means as NUTS init. + + Returns (MCMC, elbo_losses). + """ + if model_fn is None: + model_fn = noise_model + + # Phase 1: SVI + print(" Phase 1: SVI warm-start") + samples, elbo_losses = run_svi( + data, num_steps=svi_steps, lr=svi_lr, seed=seed, num_samples=100, + model_fn=model_fn, + ) + + # Extract posterior means for init + init_values = {} + skip_keys = {"y", "theta", "w"} + for k, v in samples.items(): + if k in skip_keys: + continue + init_values[k] = np.mean(v, axis=0) + + # Phase 2: NUTS from SVI init + print("\n Phase 2: NUTS from SVI-initialized values") + mcmc = run_nuts( + data, + num_warmup=num_warmup, + num_samples=num_samples, + num_chains=num_chains, + target_accept=target_accept, + max_tree_depth=max_tree_depth, + seed=seed, + init_values=init_values, + model_fn=model_fn, + ) + + return mcmc, elbo_losses diff --git a/quantammsim/noise_calibration/model.py b/quantammsim/noise_calibration/model.py new file mode 100644 index 0000000..4a1ba19 --- /dev/null +++ b/quantammsim/noise_calibration/model.py @@ -0,0 +1,465 @@ +"""NumPyro noise volume models.""" + +import jax +import jax.numpy as jnp + +from .formula_arb import formula_arb_volume_daily_jax + + +def _pad_with_ref(alpha): + """Prepend a zero for the reference category.""" + return jnp.concatenate([jnp.zeros(1), alpha]) + + +def stick_breaking_weights(v): + """Convert Beta stick-breaking fractions to K-simplex weights. + + v: array of shape (K-1,) with values in (0, 1). + Returns weights of shape (K,) summing to 1. + + w_1 = v_1 + w_k = v_k * prod_{j> jnp.arange(K_features)[None, :]) & 1 + ).astype(jnp.float32) # (n_configs, K_features) + + # Log-prior for each config from IBP stick-breaking + log_pi = jnp.log(pi + 1e-30) + log_1mpi = jnp.log(1.0 - pi + 1e-30) + log_prior = configs @ log_pi + (1.0 - configs) @ log_1mpi # (n_configs,) + + # Per-config feature effect + feature_effects = configs @ W # (n_configs, K_coeff) + + # Per-observation means + mu_pop_obs = jnp.sum(mu_pop[pool_idx] * x_obs, axis=1) # (N_obs,) + feature_mu = x_obs @ feature_effects.T # (N_obs, n_configs) + mu_obs = mu_pop_obs[:, None] + feature_mu # (N_obs, n_configs) + + # Log-likelihood per obs per config + log_lik = dist.StudentT(df, mu_obs, sigma_eps).log_prob( + y_obs[:, None] + ) # (N_obs, n_configs) + + # Sum log-likelihoods within each pool + pool_log_liks = jnp.zeros((N_pools, n_configs)) + pool_log_liks = pool_log_liks.at[pool_idx].add(log_lik) + + # Marginal: logsumexp over configs per pool + log_marginal = logsumexp( + log_prior[None, :] + pool_log_liks, axis=1 + ) # (N_pools,) + numpyro.factor("log_lik", log_marginal.sum()) + else: + # === Prior predictive: sample explicit assignments === + with numpyro.plate("pools", N_pools): + z_features = numpyro.sample( + "z_features", + dist.Bernoulli(probs=pi).expand([K_features]).to_event(1), + ) + + theta = mu_pop + z_features @ W # (N_pools, K_coeff) + numpyro.deterministic("theta", theta) + + theta_obs = theta[pool_idx] + mu_obs = jnp.sum(theta_obs * x_obs, axis=1) + + with numpyro.plate("obs", pool_idx.shape[0]): + numpyro.sample( + "y", dist.StudentT(df, mu_obs, sigma_eps), + ) + + +def noise_model_ibp_dp(pool_idx, X_pool, x_obs, y_obs=None, + N_pools=None, K_coeff=4, K_cov=None, + K_features=6, K_clusters=6): + """Hybrid IBP+DP noise model. + + IBP latent features for mean heterogeneity (theta = X_pool @ B.T + z @ W), + DP mixture for noise heterogeneity (per-cluster sigma_eps). Joint + marginalization over (2^K_features × K_clusters) configurations when + y_obs is provided. + """ + import numpyro + import numpyro.distributions as dist + from jax.scipy.special import logsumexp + + # --- Population effects --- + B = numpyro.sample( + "B", dist.Normal(0.0, 5.0).expand([K_coeff, K_cov]).to_event(2) + ) + + # Student-t degrees of freedom + df = numpyro.sample("df", dist.Gamma(2.0, 0.1)) + + # --- IBP prior on feature prevalences --- + alpha_ibp = numpyro.sample("alpha_ibp", dist.Gamma(2.0, 1.0)) + with numpyro.plate("features", K_features): + v_ibp = numpyro.sample("v_ibp", dist.Beta(alpha_ibp, 1.0)) + pi = jnp.cumprod(v_ibp) # decreasing prevalences + + # --- Feature effect matrix --- + sigma_w = numpyro.sample("sigma_w", dist.HalfNormal(2.0)) + W = numpyro.sample( + "W", dist.Normal(0.0, sigma_w).expand([K_features, K_coeff]).to_event(2) + ) + + # --- DP mixture on sigma_eps --- + alpha_dp = numpyro.sample("alpha_dp", dist.Gamma(1.0, 1.0)) + with numpyro.plate("sticks", K_clusters - 1): + v = numpyro.sample("v", dist.Beta(1.0, alpha_dp)) + w = numpyro.deterministic("w", stick_breaking_weights(v)) + + sigma_eps = numpyro.sample( + "sigma_eps", + dist.HalfNormal(2.0).expand([K_clusters]).to_event(1), + ) + + # --- Population mean per pool --- + mu_pop = X_pool @ B.T # (N_pools, K_coeff) + + if y_obs is not None: + # === Joint marginalization over IBP configs × DP clusters === + + # Enumerate all 2^K binary feature configurations + n_configs = 2 ** K_features + configs = ( + (jnp.arange(n_configs)[:, None] >> jnp.arange(K_features)[None, :]) & 1 + ).astype(jnp.float32) # (n_configs, K_features) + + # Log-prior for each IBP config + log_pi = jnp.log(pi + 1e-30) + log_1mpi = jnp.log(1.0 - pi + 1e-30) + log_ibp_prior = configs @ log_pi + (1.0 - configs) @ log_1mpi # (n_configs,) + + # Per-config feature effect + feature_effects = configs @ W # (n_configs, K_coeff) + + # Per-observation means for each IBP config + mu_pop_obs = jnp.sum(mu_pop[pool_idx] * x_obs, axis=1) # (N_obs,) + feature_mu = x_obs @ feature_effects.T # (N_obs, n_configs) + mu_obs = mu_pop_obs[:, None] + feature_mu # (N_obs, n_configs) + + # Log-likelihood per obs per IBP config per DP cluster + log_lik = dist.StudentT( + df, mu_obs[:, :, None], sigma_eps[None, None, :] + ).log_prob(y_obs[:, None, None]) # (N_obs, n_configs, K_clusters) + + # Sum log-likelihoods within each pool + pool_log_liks = jnp.zeros((N_pools, n_configs, K_clusters)) + pool_log_liks = pool_log_liks.at[pool_idx].add(log_lik) + + # Joint prior: IBP config prior × DP cluster weight + log_joint_prior = log_ibp_prior[:, None] + jnp.log(w + 1e-30)[None, :] # (n_configs, K_clusters) + + # Marginal log-likelihood per pool: logsumexp over (configs, clusters) + log_marginal = logsumexp( + log_joint_prior[None, :, :] + pool_log_liks, axis=(1, 2) + ) # (N_pools,) + numpyro.factor("log_lik", log_marginal.sum()) + else: + # === Prior predictive: sample explicit assignments === + with numpyro.plate("pools", N_pools): + z_features = numpyro.sample( + "z_features", + dist.Bernoulli(probs=pi).expand([K_features]).to_event(1), + ) + z_cluster = numpyro.sample("z_cluster", dist.Categorical(probs=w)) + + theta = mu_pop + z_features @ W # (N_pools, K_coeff) + numpyro.deterministic("theta", theta) + + theta_obs = theta[pool_idx] + mu_obs = jnp.sum(theta_obs * x_obs, axis=1) + sigma_obs = sigma_eps[z_cluster[pool_idx]] + + with numpyro.plate("obs", pool_idx.shape[0]): + numpyro.sample("y", dist.StudentT(df, mu_obs, sigma_obs)) + + +def structural_noise_model(pool_idx, X_pool, x_obs, y_obs=None, + sigma_daily=None, lag_log_tvl=None, + fee=None, gas=None, + chain_idx=None, tier_idx=None, + N_pools=None, K_obs_coeff=None, + K_cov=None, tier_A_per_pool=None, + n_chains=8, n_tiers=6, + **kwargs): + """Structural model: LVR arb + hierarchical per-pool noise. + + Decomposes observed total volume into arb (LVR formula with learnable + cadence) and noise (per-pool theta with hierarchical prior, same as + noise_model). All continuous — AutoNormal guide works directly. + """ + import numpyro + import numpyro.distributions as dist + + K_obs_coeff = K_obs_coeff or x_obs.shape[1] + K_cov = K_cov or X_pool.shape[1] + + # --- Arb cadence parameters --- + # Informative prior: exp(2.5) ≈ 12 min cadence (empirical median from + # formula-vs-real analysis on major pairs). sigma=0.5 allows range ~4-35 min. + alpha_0 = numpyro.sample("alpha_0", dist.Normal(2.5, 0.5)) + alpha_chain = numpyro.sample( + "alpha_chain", + dist.Normal(0, 0.5).expand([n_chains - 1]).to_event(1), + ) + alpha_tier = numpyro.sample( + "alpha_tier", + dist.Normal(0, 0.5).expand([n_tiers - 1]).to_event(1), + ) + alpha_tvl = numpyro.sample("alpha_tvl", dist.Normal(0, 0.3)) + + # Per-observation cadence (broadcast pool-level indices to obs) + log_cadence = ( + alpha_0 + + _pad_with_ref(alpha_chain)[chain_idx[pool_idx]] + + _pad_with_ref(alpha_tier)[tier_idx[pool_idx]] + + alpha_tvl * lag_log_tvl + ) + cadence = jnp.exp(jnp.clip(log_cadence, -2.0, 6.0)) # 0.1 to 400 min + + # V_arb per obs (deterministic given cadence + observables) + V_arb = formula_arb_volume_daily_jax( + sigma_daily, jnp.exp(lag_log_tvl), fee, gas, cadence, + ) + + # --- Hierarchical per-pool noise (same structure as noise_model) --- + B = numpyro.sample( + "B", dist.Normal(0.0, 5.0).expand([K_obs_coeff, K_cov]).to_event(2) + ) + sigma_theta = numpyro.sample( + "sigma_theta", dist.HalfNormal(2.0).expand([K_obs_coeff]).to_event(1) + ) + L_Omega = numpyro.sample( + "L_Omega", dist.LKJCholesky(K_obs_coeff, concentration=2.0) + ) + + # Non-centered pool effects + L_Sigma = jnp.diag(sigma_theta) @ L_Omega + + with numpyro.plate("pools", N_pools): + eta = numpyro.sample( + "eta", dist.Normal(0.0, 1.0).expand([K_obs_coeff]).to_event(1) + ) + + mu_pop = X_pool @ B.T # (N_pools, K_obs_coeff) + theta = mu_pop + eta @ L_Sigma.T # (N_pools, K_obs_coeff) + numpyro.deterministic("theta", theta) + + # Per-obs noise volume + theta_obs = theta[pool_idx] + log_V_noise = jnp.sum(theta_obs * x_obs, axis=1) + V_noise = jnp.exp(log_V_noise) + + # --- Observation model --- + df = numpyro.sample("df", dist.Gamma(2.0, 0.1)) + + # Per-tier sigma_eps (same as noise_model) + sigma_eps = numpyro.sample( + "sigma_eps", dist.HalfNormal(3.0).expand([3]).to_event(1) + ) + sigma_obs = sigma_eps[tier_A_per_pool[pool_idx]] + + mu = jnp.log(jnp.maximum(V_arb + V_noise, 1e-6)) + + if y_obs is not None: + with numpyro.plate("obs", pool_idx.shape[0]): + numpyro.sample("y", dist.StudentT(df, mu, sigma_obs), obs=y_obs) + else: + with numpyro.plate("obs", pool_idx.shape[0]): + numpyro.sample("y", dist.StudentT(df, mu, sigma_obs)) diff --git a/quantammsim/noise_calibration/output.py b/quantammsim/noise_calibration/output.py new file mode 100644 index 0000000..1d7b632 --- /dev/null +++ b/quantammsim/noise_calibration/output.py @@ -0,0 +1,314 @@ +"""JSON output and sample caching.""" + +import json +import os + +import numpy as np + +from .constants import K_COEFF, COEFF_NAMES, K_OBS_COEFF, OBS_COEFF_NAMES + + +def generate_output_json(pool_params, samples, data, convergence, + output_path, inference_config): + """Write structured JSON output. + + Dispatches format based on whether samples contain DP mixture parameters + (detected via "v" in sample_dict), or structural model parameters + (detected via "W_gate" in sample_dict). + """ + if hasattr(samples, "get_samples"): + sample_dict = samples.get_samples() + else: + sample_dict = samples + + # Structural model path: has arb cadence params + hierarchical noise + is_structural = "alpha_0" in sample_dict and "B" in sample_dict + if is_structural: + _generate_structural_output( + pool_params, sample_dict, data, convergence, + output_path, inference_config, + ) + return + + # Detection priority: check hybrid first, then pure IBP, then DP + is_ibp_dp = ("W" in sample_dict and "v" in sample_dict + and "z_logit" not in sample_dict) + is_ibp = ("W" in sample_dict and "v" not in sample_dict + and "z_logit" not in sample_dict) + is_ibp_ste = "z_logit" in sample_dict # legacy STE artifacts + is_dp = "v" in sample_dict and "W" not in sample_dict + + B_median = np.median(np.array(sample_dict["B"]), axis=0).tolist() + sigma_eps_median = np.median( + np.array(sample_dict["sigma_eps"]), axis=0 + ) + df_median = float(np.median(np.array(sample_dict["df"]))) + + if is_ibp_dp: + model_name = "hierarchical_student_t_ibp_dp" + sigma_eps_structure = "dp_mixture" + + W_median = np.median(np.array(sample_dict["W"]), axis=0).tolist() + v_ibp_median = np.median(np.array(sample_dict["v_ibp"]), axis=0) + pi = np.cumprod(v_ibp_median).tolist() + + from .model import stick_breaking_weights + import jax.numpy as jnp + v_median = np.median(np.array(sample_dict["v"]), axis=0) + cluster_weights = np.array( + stick_breaking_weights(jnp.array(v_median)) + ).tolist() + + population_effects = { + "B": B_median, + "sigma_eps": sigma_eps_median.tolist() if hasattr(sigma_eps_median, 'tolist') else sigma_eps_median, + "df": df_median, + "W": W_median, + "feature_prevalences": pi, + "alpha_ibp": float( + np.median(np.array(sample_dict["alpha_ibp"])) + ), + "cluster_weights": cluster_weights, + "alpha_dp": float( + np.median(np.array(sample_dict["alpha_dp"])) + ), + } + + # Joint MAP assignments + from .postprocessing import assign_ibp_dp_joint + feat_assignments, cluster_assignments = assign_ibp_dp_joint( + sample_dict, data + ) + + pool_entries = {} + for i, p in enumerate(pool_params): + entry = { + "chain": p["chain"], + "tokens": p["tokens"], + "theta_median": p["theta_median"], + "theta_std": p["theta_std"], + "noise_params": p["noise_params"], + "feature_assignments": feat_assignments[i].tolist(), + "cluster_assignment": int(cluster_assignments[i]), + } + pool_entries[p["pool_id"]] = entry + + elif is_ibp or is_ibp_ste: + model_name = "hierarchical_student_t_ibp" + sigma_eps_structure = "scalar" + + W_median = np.median(np.array(sample_dict["W"]), axis=0).tolist() + v_ibp_median = np.median(np.array(sample_dict["v_ibp"]), axis=0) + pi = np.cumprod(v_ibp_median).tolist() + + population_effects = { + "B": B_median, + "sigma_eps": float(sigma_eps_median), + "df": df_median, + "W": W_median, + "feature_prevalences": pi, + "alpha_ibp": float( + np.median(np.array(sample_dict["alpha_ibp"])) + ), + } + + # Per-pool feature assignments + if is_ibp_ste: + # Legacy STE path: threshold z_logit + z_logit = np.array(sample_dict["z_logit"]) + z_logit_median = np.median(z_logit, axis=0) + feature_assignments = (z_logit_median > 0).astype(int).tolist() + else: + # Marginalized path: MAP assignments from data + from .postprocessing import assign_ibp_features + feature_assignments = assign_ibp_features( + sample_dict, data + ).tolist() + + pool_entries = {} + for i, p in enumerate(pool_params): + entry = { + "chain": p["chain"], + "tokens": p["tokens"], + "theta_median": p["theta_median"], + "theta_std": p["theta_std"], + "noise_params": p["noise_params"], + "feature_assignments": feature_assignments[i], + } + pool_entries[p["pool_id"]] = entry + + elif is_dp: + model_name = "hierarchical_student_t_dp_sigma" + sigma_eps_structure = "dp_mixture" + else: + model_name = "unified_hierarchical_student_t" + sigma_eps_structure = "per_tier" + + if not is_ibp and not is_ibp_dp: + sigma_theta_median = np.median( + np.array(sample_dict["sigma_theta"]), axis=0 + ).tolist() + + # Correlation matrix + L_Omega = np.array(sample_dict["L_Omega"]) + Omega = np.einsum("sij,skj->sik", L_Omega, L_Omega) + Omega_median = np.median(Omega, axis=0).tolist() + + population_effects = { + "B": B_median, + "sigma_theta": sigma_theta_median, + "sigma_eps": sigma_eps_median.tolist() if hasattr(sigma_eps_median, 'tolist') else sigma_eps_median, + "df": df_median, + "correlation_matrix": Omega_median, + } + + if is_dp: + from .model import stick_breaking_weights + import jax.numpy as jnp + v_median = np.median(np.array(sample_dict["v"]), axis=0) + w = stick_breaking_weights(jnp.array(v_median)) + population_effects["cluster_weights"] = np.array(w).tolist() + population_effects["alpha_dp"] = float( + np.median(np.array(sample_dict["alpha_dp"])) + ) + + pool_entries = { + p["pool_id"]: { + "chain": p["chain"], + "tokens": p["tokens"], + "theta_median": p["theta_median"], + "theta_std": p["theta_std"], + "noise_params": p["noise_params"], + } + for p in pool_params + } + + output = { + "model": model_name, + "model_spec": { + "K_coeff": K_COEFF, + "K_cov": data["K_cov"], + "coeff_names": COEFF_NAMES, + "covariate_names": data["covariate_names"], + "likelihood": "StudentT", + "tvl_lag": "log_tvl_lag1", + "sigma_eps_structure": sigma_eps_structure, + }, + "inference": inference_config, + "population_effects": population_effects, + "convergence": convergence, + "n_pools": len(pool_params), + "n_obs": len(data["y_obs"]), + "pools": pool_entries, + } + + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + with open(output_path, "w") as f: + json.dump(output, f, indent=2, default=str) + print(f" Wrote {len(pool_params)} pool params -> {output_path}") + + +def _generate_structural_output(pool_params, sample_dict, data, convergence, + output_path, inference_config): + """Write structural model JSON output (LVR arb + hierarchical noise).""" + alpha_0 = float(np.median(np.array(sample_dict["alpha_0"]))) + alpha_chain = np.median(np.array(sample_dict["alpha_chain"]), axis=0).tolist() + alpha_tier = np.median(np.array(sample_dict["alpha_tier"]), axis=0).tolist() + alpha_tvl = float(np.median(np.array(sample_dict["alpha_tvl"]))) + + B_median = np.median(np.array(sample_dict["B"]), axis=0).tolist() + sigma_theta_median = np.median( + np.array(sample_dict["sigma_theta"]), axis=0 + ).tolist() + + # Correlation matrix + L_Omega = np.array(sample_dict["L_Omega"]) + Omega = np.einsum("sij,skj->sik", L_Omega, L_Omega) + Omega_median = np.median(Omega, axis=0).tolist() + + df_median = float(np.median(np.array(sample_dict["df"]))) + sigma_eps_median = np.median( + np.array(sample_dict["sigma_eps"]), axis=0 + ).tolist() + + population_effects = { + "alpha_0": alpha_0, + "alpha_chain": alpha_chain, + "alpha_tier": alpha_tier, + "alpha_tvl": alpha_tvl, + "B": B_median, + "sigma_theta": sigma_theta_median, + "correlation_matrix": Omega_median, + "df": df_median, + "sigma_eps": sigma_eps_median, + } + + pool_entries = {} + for p in pool_params: + pool_entries[p["pool_id"]] = { + "chain": p["chain"], + "tokens": p["tokens"], + "arb_frequency": p["arb_frequency"], + "noise_params": p["noise_params"], + } + + output = { + "model": "structural_mixture", + "model_spec": { + "K_obs_coeff": K_OBS_COEFF, + "obs_coeff_names": OBS_COEFF_NAMES, + "K_cov": data["K_cov"], + "covariate_names": data["covariate_names"], + "likelihood": "StudentT", + }, + "inference": inference_config, + "population_effects": population_effects, + "convergence": convergence, + "n_pools": len(pool_params), + "n_obs": len(data["y_obs"]), + "pools": pool_entries, + } + + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + with open(output_path, "w") as f: + json.dump(output, f, indent=2, default=str) + print(f" Wrote {len(pool_params)} pool params -> {output_path}") + + +def _save_sample_cache(samples, data, cache_dir): + """Cache posterior samples and data arrays for --predict reuse.""" + if hasattr(samples, "get_samples"): + sample_dict = samples.get_samples() + else: + sample_dict = samples + + os.makedirs(cache_dir, exist_ok=True) + + # Save only the samples needed for prediction and diagnostics. + # Skip "y" (S x N_obs, can be >1GB) and "theta" (S x N_pools x K, + # reconstructible from B, eta, sigma_theta, L_Omega). + skip_keys = {"y", "theta"} + sample_cache = os.path.join(cache_dir, "unified_samples.npz") + np.savez_compressed( + sample_cache, + **{k: np.array(v) for k, v in sample_dict.items() + if k not in skip_keys}, + ) + + # Data arrays for predict + data_cache = os.path.join(cache_dir, "unified_data.json") + cache_data = { + "pool_ids": data["pool_ids"], + "covariate_names": data["covariate_names"], + "K_cov": data["K_cov"], + "N_pools": data["N_pools"], + "ref_chain": data["ref_chain"], + "ref_tier_a": data["ref_tier_a"], + "ref_tier_b": data["ref_tier_b"], + "chains": data["chains"], + } + with open(data_cache, "w") as f: + json.dump(cache_data, f, indent=2) + + print(f" Cached samples -> {sample_cache}") + print(f" Cached data metadata -> {data_cache}") diff --git a/quantammsim/noise_calibration/plotting.py b/quantammsim/noise_calibration/plotting.py new file mode 100644 index 0000000..727443b --- /dev/null +++ b/quantammsim/noise_calibration/plotting.py @@ -0,0 +1,335 @@ +"""Diagnostic plots for noise calibration.""" + +import os + +import numpy as np +import pandas as pd + +from .constants import K_COEFF, COEFF_NAMES +from .inference import _get_theta_samples + + +def plot_diagnostics(samples, data, output_dir, elbo_losses=None, + mcmc=None, prior_samples=None): + """Generate up to 9 diagnostic plots.""" + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + os.makedirs(output_dir, exist_ok=True) + + if hasattr(samples, "get_samples"): + sample_dict = samples.get_samples() + else: + sample_dict = samples + + theta_samples = _get_theta_samples( + sample_dict, np.array(data["X_pool"]), data=data + ) # (S, N_pools, K_coeff) + theta_median = np.median(theta_samples, axis=0) + + pool_idx = data["pool_idx"] + x_obs = data["x_obs"] + y_obs = data["y_obs"] + pool_meta = data["pool_meta"] + pool_ids = data["pool_ids"] + + # --- 1. Prior predictive check --- + if prior_samples is not None: + y_prior = prior_samples.get("y", None) + if y_prior is not None: + fig, ax = plt.subplots(figsize=(10, 5)) + # Flatten a subsample of prior draws + y_prior_flat = y_prior.flatten() + # Clip for display + clip_lo, clip_hi = np.percentile(y_prior_flat, [0.5, 99.5]) + y_prior_clipped = y_prior_flat[ + (y_prior_flat >= clip_lo) & (y_prior_flat <= clip_hi) + ] + ax.hist(y_prior_clipped, bins=100, alpha=0.5, density=True, + color="steelblue", label="Prior predictive") + ax.hist(y_obs, bins=100, alpha=0.5, density=True, + color="coral", label="Observed") + ax.set_xlabel("log(volume)") + ax.set_ylabel("Density") + ax.set_title("Prior predictive check: log-volume") + ax.legend() + plt.tight_layout() + path = os.path.join(output_dir, "prior_predictive.png") + plt.savefig(path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path}") + + # --- 2. ELBO loss curve (SVI only) --- + if elbo_losses is not None: + fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + + ax = axes[0] + ax.plot(elbo_losses, alpha=0.3, color="steelblue", linewidth=0.5) + # Smoothed + window = min(100, len(elbo_losses) // 10) + if window > 1: + smoothed = pd.Series(elbo_losses).rolling(window).mean().values + ax.plot(smoothed, color="red", linewidth=1.5, label=f"Rolling {window}") + ax.legend() + ax.set_xlabel("Step") + ax.set_ylabel("ELBO loss") + ax.set_title("ELBO convergence") + + ax = axes[1] + # Last 20% of training + start = len(elbo_losses) * 4 // 5 + ax.plot(range(start, len(elbo_losses)), elbo_losses[start:], + color="steelblue", linewidth=0.8) + ax.set_xlabel("Step") + ax.set_ylabel("ELBO loss") + ax.set_title("ELBO convergence (last 20%)") + + plt.tight_layout() + path = os.path.join(output_dir, "elbo_convergence.png") + plt.savefig(path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path}") + + # --- 3. Trace plots (NUTS only) --- + if mcmc is not None: + try: + import arviz as az + idata = az.from_numpyro(mcmc) + var_names = ["sigma_theta", "sigma_eps", "df"] + available = [v for v in var_names if v in idata.posterior] + if available: + axes = az.plot_trace(idata, var_names=available, compact=True) + fig = axes.ravel()[0].figure + fig.set_size_inches(14, 3 * len(available)) + path = os.path.join(output_dir, "trace_plots.png") + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" Saved: {path}") + except Exception as e: + print(f" WARNING: Trace plots failed: {e}") + + # --- 4. Posterior predictive: predicted vs observed --- + # Compute y_pred and r2 here; r2 is reused in plot 9 (model summary). + theta_obs = theta_median[pool_idx] + y_pred = np.sum(theta_obs * x_obs, axis=1) + r2 = 1 - np.var(y_obs - y_pred) / np.var(y_obs) + + fig, axes = plt.subplots(1, 2, figsize=(14, 6)) + + ax = axes[0] + ax.scatter(y_obs, y_pred, alpha=0.1, s=4, color="steelblue") + lims = [min(y_obs.min(), y_pred.min()), max(y_obs.max(), y_pred.max())] + ax.plot(lims, lims, "r--", linewidth=1) + ax.set_xlabel("Observed log(volume)") + ax.set_ylabel("Predicted log(volume)") + ax.set_title("Posterior predictive check") + ax.text(0.05, 0.95, f"R² = {r2:.3f}", transform=ax.transAxes, + fontsize=11, verticalalignment="top") + + ax = axes[1] + residuals = y_obs - y_pred + ax.hist(residuals, bins=60, color="steelblue", edgecolor="white", alpha=0.8) + ax.axvline(0, color="red", linestyle="--") + ax.set_xlabel("Residual") + + sigma_eps_samples = np.array(sample_dict.get("sigma_eps", [0])) + if sigma_eps_samples.ndim > 1: + sigma_str = ", ".join(f"{np.median(sigma_eps_samples[:, i]):.2f}" + for i in range(sigma_eps_samples.shape[1])) + else: + sigma_str = f"{np.median(sigma_eps_samples):.2f}" + ax.set_title(f"Residuals (sigma_eps ~ [{sigma_str}])") + + plt.tight_layout() + path = os.path.join(output_dir, "posterior_predictive.png") + plt.savefig(path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path}") + + # --- 5. Per-pool b_c by chain/tier --- + b_tvl_all = theta_median[:, 1] + + fig, axes = plt.subplots(1, 2, figsize=(14, 6)) + + ax = axes[0] + chains_present = sorted(pool_meta["chain"].unique()) + chain_data = [] + chain_labels = [] + for c in chains_present: + mask = pool_meta["chain"].values == c + if mask.sum() > 0: + chain_data.append(b_tvl_all[mask]) + chain_labels.append(f"{c}\n(n={mask.sum()})") + if chain_data: + ax.boxplot(chain_data, tick_labels=chain_labels, vert=True) + ax.axhline(1.0, color="red", linestyle="--", linewidth=0.8, alpha=0.6) + ax.set_ylabel("Per-pool b_c (TVL elasticity)") + ax.set_title("TVL elasticity by chain") + + ax = axes[1] + tier_a_vals = pool_meta["tier_A"].values.astype(int) + tier_labels_map = {0: "Blue-chip", 1: "Mid-cap", 2: "Long-tail"} + tier_data = [] + tier_labels = [] + for t in [0, 1, 2]: + mask = tier_a_vals == t + if mask.sum() > 0: + tier_data.append(b_tvl_all[mask]) + tier_labels.append(f"{tier_labels_map[t]}\n(n={mask.sum()})") + if tier_data: + ax.boxplot(tier_data, tick_labels=tier_labels, vert=True) + ax.axhline(1.0, color="red", linestyle="--", linewidth=0.8, alpha=0.6) + ax.set_ylabel("Per-pool b_c (TVL elasticity)") + ax.set_title("TVL elasticity by token tier (best token)") + + plt.tight_layout() + path = os.path.join(output_dir, "per_pool_b_c.png") + plt.savefig(path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path}") + + # --- 6. Correlation matrix posterior --- + L_Omega_samples = np.array(sample_dict["L_Omega"]) # (S, K, K) + Omega_samples = np.einsum("sij,skj->sik", L_Omega_samples, L_Omega_samples) + Omega_median = np.median(Omega_samples, axis=0) + + fig, ax = plt.subplots(figsize=(7, 6)) + im = ax.imshow(Omega_median, vmin=-1, vmax=1, cmap="RdBu_r") + ax.set_xticks(range(K_COEFF)) + ax.set_yticks(range(K_COEFF)) + ax.set_xticklabels(COEFF_NAMES, rotation=45, ha="right") + ax.set_yticklabels(COEFF_NAMES) + for i in range(K_COEFF): + for j in range(K_COEFF): + ax.text(j, i, f"{Omega_median[i, j]:.2f}", ha="center", + va="center", fontsize=10, + color="white" if abs(Omega_median[i, j]) > 0.5 else "black") + plt.colorbar(im, ax=ax, shrink=0.8) + ax.set_title("Posterior median correlation matrix (Omega)") + plt.tight_layout() + path = os.path.join(output_dir, "correlation_matrix.png") + plt.savefig(path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path}") + + # --- 7. Shrinkage plot: OLS b_c vs hierarchical b_c --- + ols_b_c = np.zeros(len(pool_ids)) + for i, pid in enumerate(pool_ids): + mask = pool_idx == i + if mask.sum() < 5: + ols_b_c[i] = np.nan + continue + x_i = x_obs[mask] + y_i = y_obs[mask] + try: + beta, _, _, _ = np.linalg.lstsq(x_i, y_i, rcond=None) + ols_b_c[i] = beta[1] # TVL coefficient + except np.linalg.LinAlgError: + ols_b_c[i] = np.nan + + hier_b_c = theta_median[:, 1] + valid = np.isfinite(ols_b_c) + + if valid.sum() > 2: + fig, ax = plt.subplots(figsize=(8, 8)) + ax.scatter(ols_b_c[valid], hier_b_c[valid], alpha=0.6, s=20, + color="steelblue") + + pop_b_c = np.median(hier_b_c) + ax.axhline(pop_b_c, color="red", linestyle="--", linewidth=0.8, + label=f"Population median = {pop_b_c:.3f}") + + lims = [min(np.nanmin(ols_b_c[valid]), hier_b_c[valid].min()) - 0.2, + max(np.nanmax(ols_b_c[valid]), hier_b_c[valid].max()) + 0.2] + ax.plot(lims, lims, "k:", linewidth=0.8, alpha=0.5) + ax.set_xlabel("Per-pool OLS b_c (lagged TVL)") + ax.set_ylabel("Hierarchical posterior median b_c") + ax.set_title("Shrinkage: OLS vs hierarchical TVL elasticity") + ax.legend() + plt.tight_layout() + path = os.path.join(output_dir, "shrinkage_b_c.png") + plt.savefig(path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path}") + + # --- 8. beta_tvl vs beta_vol scatter colored by chain --- + fig, ax = plt.subplots(figsize=(10, 7)) + pool_id_to_chain = dict(zip(pool_meta["pool_id"], pool_meta["chain"])) + chain_colors = {} + cmap = plt.cm.tab10 + unique_chains = sorted(pool_meta["chain"].unique()) + for i, c in enumerate(unique_chains): + chain_colors[c] = cmap(i % 10) + + beta_tvl_arr = theta_median[:, 1] + beta_vol_arr = theta_median[:, 2] + for i, pid in enumerate(pool_ids): + c = pool_id_to_chain.get(pid, "?") + ax.scatter(beta_tvl_arr[i], beta_vol_arr[i], + color=chain_colors.get(c, "gray"), alpha=0.6, s=20, + edgecolors="white", linewidths=0.3) + + from matplotlib.lines import Line2D + handles = [Line2D([0], [0], marker="o", color="w", + markerfacecolor=chain_colors[c], markersize=8, + label=c) + for c in unique_chains if c in chain_colors] + ax.legend(handles=handles, fontsize=8, loc="best") + ax.set_xlabel("b_tvl (TVL elasticity)") + ax.set_ylabel("b_sigma (volatility sensitivity)") + ax.set_title("Pool-specific coefficients by chain") + ax.axhline(0, color="gray", linewidth=0.5, linestyle="--") + ax.axvline(0, color="gray", linewidth=0.5, linestyle="--") + plt.tight_layout() + path = os.path.join(output_dir, "beta_tvl_vs_beta_vol.png") + plt.savefig(path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path}") + + # --- 9. Model summary panel --- + B_samples = np.array(sample_dict["B"]) + B_median = np.median(B_samples, axis=0) # (K_coeff, K_cov) + sigma_theta_med = np.median(np.array(sample_dict["sigma_theta"]), axis=0) + df_med = np.median(np.array(sample_dict["df"])) + sigma_eps_med = np.median(np.array(sample_dict["sigma_eps"]), axis=0) + + col_names = data["covariate_names"] + + fig, ax = plt.subplots(figsize=(12, 8)) + ax.axis("off") + + summary = "Group-level regression B (posterior median):\n" + header = f" {'covariate':<20s}" + for cn in COEFF_NAMES: + header += f" {cn:>10s}" + summary += header + "\n" + summary += " " + "-" * (20 + 11 * K_COEFF) + "\n" + for j, name in enumerate(col_names): + line = f" {name:<20s}" + for k in range(K_COEFF): + line += f" {B_median[k, j]:>10.3f}" + summary += line + "\n" + + summary += f"\nsigma_theta: [{', '.join(f'{v:.3f}' for v in sigma_theta_med)}]\n" + summary += f"\nCorrelation matrix (Omega):\n" + for i in range(K_COEFF): + row = " [" + " ".join(f"{Omega_median[i, j]:>6.3f}" + for j in range(K_COEFF)) + "]\n" + summary += row + + tier_names = ["blue-chip", "mid-cap", "long-tail"] + sigma_eps_str = ", ".join(f"{tier_names[i]}={sigma_eps_med[i]:.3f}" + for i in range(len(sigma_eps_med))) + summary += f"\nsigma_eps: [{sigma_eps_str}]\n" + summary += f"df (Student-t): {df_med:.1f}\n" + summary += f"R^2: {r2:.3f}\n" + + ax.text(0.02, 0.98, summary, transform=ax.transAxes, + fontsize=7, verticalalignment="top", fontfamily="monospace") + ax.set_title("Model Summary") + plt.tight_layout() + path = os.path.join(output_dir, "model_summary.png") + plt.savefig(path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path}") diff --git a/quantammsim/noise_calibration/postprocessing.py b/quantammsim/noise_calibration/postprocessing.py new file mode 100644 index 0000000..190f07c --- /dev/null +++ b/quantammsim/noise_calibration/postprocessing.py @@ -0,0 +1,659 @@ +"""Post-processing: extract params, predict, convergence, prior predictive.""" + +import numpy as np + +from .constants import K_COEFF, COEFF_NAMES, K_OBS_COEFF, OBS_COEFF_NAMES +from .token_classification import classify_token_tier +from .covariate_encoding import _tier_pair_idx +from .inference import _get_theta_samples, _build_model_kwargs +from .model import noise_model + + +def extract_noise_params(samples, data, use_median=True) -> list: + """Extract per-pool noise params from posterior samples. + + Handles both MCMC.get_samples() and SVI samples dict. + Applies weekend absorption: b_0_eff = b_0_raw + b_weekend * (2/7). + """ + # Get theta samples + if hasattr(samples, "get_samples"): + # MCMC object + sample_dict = samples.get_samples() + else: + sample_dict = samples + + theta_samples = _get_theta_samples( + sample_dict, np.array(data["X_pool"]), data=data + ) # (S, N_pools, K_coeff) + + agg_fn = np.median if use_median else np.mean + theta_agg = agg_fn(theta_samples, axis=0) # (N_pools, K_coeff) + theta_std = np.std(theta_samples, axis=0) + + pool_ids = data["pool_ids"] + pool_meta = data["pool_meta"] + + results = [] + for i, pool_id in enumerate(pool_ids): + meta = pool_meta.iloc[i] + b_0_raw, b_tvl, b_sigma, b_weekend = theta_agg[i] + std_vals = theta_std[i] + + # Weekend absorption: simulator has no weekend indicator, + # so fold the expected weekend effect into the intercept. + b_0_effective = b_0_raw + b_weekend * (2.0 / 7.0) + + tokens = meta["tokens"] + if isinstance(tokens, str): + tokens = tokens.split(",") + + results.append({ + "pool_id": pool_id, + "chain": str(meta["chain"]), + "tokens": tokens, + "theta_median": [float(x) for x in theta_agg[i]], + "theta_std": [float(x) for x in std_vals], + "b_weekend": float(b_weekend), + "noise_params": { + "b_0": float(b_0_effective), + "b_sigma": float(b_sigma), + "b_c": float(b_tvl), + "b_weekend": float(b_weekend), + "base_fee": float(meta["swap_fee"]), + }, + }) + + return results + + +def predict_new_pool(samples, data, chain: str, tokens: list, + fee: float, feature_assignments=None) -> dict: + """Predict noise params for an unseen pool using population effects. + + Constructs z_new, computes mu_new = B @ z_new across all posterior samples, + returns point estimate + 90% credible intervals with weekend absorption. + """ + if hasattr(samples, "get_samples"): + sample_dict = samples.get_samples() + else: + sample_dict = samples + + # Build z_new using data-driven column names + col_names = data["covariate_names"] + z_new = np.zeros(len(col_names), dtype=np.float64) + + tiers = sorted([classify_token_tier(t) for t in tokens]) + tier_a = str(tiers[0]) + tier_b = str(tiers[1]) if len(tiers) > 1 else tier_a + + for i, name in enumerate(col_names): + if name == "intercept": + z_new[i] = 1.0 + elif name == "log_fee": + z_new[i] = np.log(max(fee, 1e-6)) + elif name == f"chain_{chain}": + z_new[i] = 1.0 + elif name == f"tier_A_{tier_a}": + z_new[i] = 1.0 + elif name == f"tier_B_{tier_b}": + z_new[i] = 1.0 + + # mu_new = B @ z_new across all posterior samples + B_samples = np.array(sample_dict["B"]) # (S, K_coeff, K_cov) + mu_samples = np.einsum("skd,d->sk", B_samples, z_new) # (S, K_coeff) + + # IBP path: add feature effects + is_ibp = "W" in sample_dict + if is_ibp: + W_samples = np.array(sample_dict["W"]) # (S, K_features, K_coeff) + if feature_assignments is not None: + # User-specified binary features + Z = np.array(feature_assignments) # (K_features,) + feature_effect = np.einsum("skj,k->sj", W_samples, Z) + prediction_source = "ibp_user_features" + else: + # Marginal: weight by prevalences pi = cumprod(v_ibp) + v_ibp = np.array(sample_dict["v_ibp"]) # (S, K_features) + pi = np.cumprod(v_ibp, axis=1) # (S, K_features) + feature_effect = np.einsum("skj,sk->sj", W_samples, pi) + prediction_source = "ibp_marginal" + mu_samples = mu_samples + feature_effect + else: + prediction_source = "population_level" + + mu_median = np.median(mu_samples, axis=0) + mu_q05 = np.percentile(mu_samples, 5, axis=0) + mu_q95 = np.percentile(mu_samples, 95, axis=0) + + # Weekend absorption + b_0_raw, b_tvl, b_sigma, b_weekend = mu_median + b_0_effective = b_0_raw + b_weekend * (2.0 / 7.0) + + result = { + "chain": chain, + "tokens": tokens, + "fee": fee, + "prediction_source": prediction_source, + "noise_params": { + "b_0": float(b_0_effective), + "b_sigma": float(b_sigma), + "b_c": float(b_tvl), + "b_weekend": float(b_weekend), + "base_fee": float(fee), + }, + "credible_intervals_90": { + name: { + "median": float(mu_median[k]), + "q05": float(mu_q05[k]), + "q95": float(mu_q95[k]), + } + for k, name in enumerate(COEFF_NAMES) + }, + } + + print(f"\n Predicted noise_params for {chain} {tokens} (fee={fee}):") + for name, ci in result["credible_intervals_90"].items(): + print(f" {name:12s}: {ci['median']:+.3f} " + f"[{ci['q05']:+.3f}, {ci['q95']:+.3f}]") + print(f"\n Effective b_0 (weekend-absorbed): {b_0_effective:.3f}") + + return result + + +def check_convergence(mcmc_or_losses, method="nuts") -> dict: + """Compute convergence diagnostics. + + For NUTS: R-hat, ESS, divergences. + For SVI: final ELBO, ELBO stability. + """ + if method == "svi": + losses = np.array(mcmc_or_losses) + return { + "method": "svi", + "final_elbo": float(losses[-1]), + "elbo_last_100_std": float(np.std(losses[-100:])), + "elbo_last_100_mean": float(np.mean(losses[-100:])), + } + + # NUTS diagnostics + import arviz as az + + mcmc = mcmc_or_losses + idata = az.from_numpyro(mcmc) + + n_chains = idata.posterior.sizes.get("chain", 1) + + rhat_max = float("nan") + if n_chains >= 2: + rhat = az.rhat(idata) + rhat_vals = [] + for var in rhat.data_vars: + if var == "theta": + continue + vals = rhat[var].values + rhat_vals.extend(vals.flatten()) + rhat_max = float(np.nanmax(rhat_vals)) if rhat_vals else float("nan") + + ess = az.ess(idata) + ess_vals = [] + for var in ess.data_vars: + if var == "theta": + continue + vals = ess[var].values + ess_vals.extend(vals.flatten()) + ess_min = float(np.nanmin(ess_vals)) if ess_vals else float("nan") + + divergences = int(idata.sample_stats["diverging"].sum().values) + + print(f"\n Convergence diagnostics:") + if n_chains >= 2: + print(f" R-hat max: {rhat_max:.4f} " + f"{'OK' if rhat_max < 1.05 else 'WARNING'}") + else: + print(f" R-hat max: N/A (need >= 2 chains)") + print(f" ESS min: {ess_min:.0f} " + f"{'OK' if ess_min > 400 else 'WARNING'}") + print(f" Divergences: {divergences} " + f"{'OK' if divergences == 0 else 'WARNING'}") + + return { + "method": "nuts", + "r_hat_max": rhat_max, + "ess_min": ess_min, + "divergences": divergences, + } + + +def assign_dp_clusters(samples, data) -> np.ndarray: + """Compute posterior MAP cluster assignments for DP mixture model. + + Uses median posterior samples for v->w, sigma_eps, df, theta to compute + per-pool-per-cluster log-likelihoods, then returns argmax assignments. + """ + from scipy.special import logsumexp + from .model import stick_breaking_weights + + if hasattr(samples, "get_samples"): + sample_dict = samples.get_samples() + else: + sample_dict = samples + + # Median posterior parameters + v_med = np.median(np.array(sample_dict["v"]), axis=0) + sigma_eps_med = np.median(np.array(sample_dict["sigma_eps"]), axis=0) + df_med = float(np.median(np.array(sample_dict["df"]))) + + # Compute w from v via stick-breaking (using numpy) + import jax.numpy as jnp + w = np.array(stick_breaking_weights(jnp.array(v_med))) + + # Reconstruct theta + theta_samples = _get_theta_samples( + sample_dict, np.array(data["X_pool"]), data=data + ) + theta_med = np.median(theta_samples, axis=0) # (N_pools, K_coeff) + + # Per-observation predicted means + pool_idx = np.array(data["pool_idx"]) + x_obs = np.array(data["x_obs"]) + y_obs = np.array(data["y_obs"]) + N_pools = data["N_pools"] + K_clusters = len(sigma_eps_med) + + theta_obs = theta_med[pool_idx] + mu_obs = np.sum(theta_obs * x_obs, axis=1) # (N_obs,) + + # Log-likelihood per observation per cluster + from scipy.stats import t as t_dist + log_lik_per_k = np.zeros((len(y_obs), K_clusters)) + for k in range(K_clusters): + log_lik_per_k[:, k] = t_dist.logpdf( + y_obs, df_med, loc=mu_obs, scale=sigma_eps_med[k] + ) + + # Sum within pools + pool_log_liks = np.zeros((N_pools, K_clusters)) + for i in range(len(y_obs)): + pool_log_liks[pool_idx[i]] += log_lik_per_k[i] + + # Posterior cluster probabilities: log p(z=k|data) = log w_k + sum log p(y|k) + log_posterior = np.log(w + 1e-30)[None, :] + pool_log_liks + # MAP assignment + assignments = np.argmax(log_posterior, axis=1).astype(np.int64) + return assignments + + +def assign_ibp_features(samples, data) -> np.ndarray: + """Compute MAP feature assignments for marginalized IBP model. + + Enumerates all 2^K binary feature configurations per pool, evaluates + per-pool log-posterior (log-prior + log-likelihood), returns argmax + config as (N_pools, K_features) binary ndarray. + + Uses median posterior parameters for B, W, v_ibp, sigma_eps, df. + """ + from scipy.stats import t as t_dist + + if hasattr(samples, "get_samples"): + sample_dict = samples.get_samples() + else: + sample_dict = samples + + B_med = np.median(np.array(sample_dict["B"]), axis=0) # (K_coeff, K_cov) + W_med = np.median(np.array(sample_dict["W"]), axis=0) # (K_features, K_coeff) + v_ibp_med = np.median(np.array(sample_dict["v_ibp"]), axis=0) # (K_features,) + sigma_eps_med = float(np.median(np.array(sample_dict["sigma_eps"]))) + df_med = float(np.median(np.array(sample_dict["df"]))) + + pi = np.cumprod(v_ibp_med) # (K_features,) + K_features = len(pi) + + pool_idx = np.array(data["pool_idx"]) + X_pool = np.array(data["X_pool"]) + x_obs = np.array(data["x_obs"]) + y_obs = np.array(data["y_obs"]) + N_pools = data["N_pools"] + + # Enumerate all 2^K configs + n_configs = 2 ** K_features + configs = ( + (np.arange(n_configs)[:, None] >> np.arange(K_features)[None, :]) & 1 + ).astype(float) # (n_configs, K_features) + + # Log-prior per config + log_pi = np.log(pi + 1e-30) + log_1mpi = np.log(1.0 - pi + 1e-30) + log_prior = configs @ log_pi + (1.0 - configs) @ log_1mpi # (n_configs,) + + # Population mean + mu_pop = X_pool @ B_med.T # (N_pools, K_coeff) + + # Feature effects per config + feature_effects = configs @ W_med # (n_configs, K_coeff) + + # Per-obs means: mu_pop_obs + feature_mu + mu_pop_obs = np.sum(mu_pop[pool_idx] * x_obs, axis=1) # (N_obs,) + feature_mu = x_obs @ feature_effects.T # (N_obs, n_configs) + mu_obs = mu_pop_obs[:, None] + feature_mu # (N_obs, n_configs) + + # Log-likelihood per obs per config + log_lik = t_dist.logpdf( + y_obs[:, None], df_med, loc=mu_obs, scale=sigma_eps_med + ) # (N_obs, n_configs) + + # Sum within pools + pool_log_liks = np.zeros((N_pools, n_configs)) + for i in range(len(y_obs)): + pool_log_liks[pool_idx[i]] += log_lik[i] + + # Posterior = log_prior + pool_log_liks; MAP config per pool + log_posterior = log_prior[None, :] + pool_log_liks # (N_pools, n_configs) + best_config_idx = np.argmax(log_posterior, axis=1) # (N_pools,) + + return configs[best_config_idx].astype(int) # (N_pools, K_features) + + +def assign_ibp_dp_joint(samples, data) -> tuple: + """Compute MAP joint (feature, cluster) assignments for hybrid IBP+DP model. + + Enumerates all (2^K_features × K_clusters) joint configurations per pool, + evaluates joint log-posterior, returns argmax assignments. + + Returns: + (feature_assignments, cluster_assignments): + feature_assignments: (N_pools, K_features) binary ndarray + cluster_assignments: (N_pools,) int ndarray + """ + from scipy.stats import t as t_dist + from .model import stick_breaking_weights + + if hasattr(samples, "get_samples"): + sample_dict = samples.get_samples() + else: + sample_dict = samples + + B_med = np.median(np.array(sample_dict["B"]), axis=0) # (K_coeff, K_cov) + W_med = np.median(np.array(sample_dict["W"]), axis=0) # (K_features, K_coeff) + v_ibp_med = np.median(np.array(sample_dict["v_ibp"]), axis=0) # (K_features,) + v_med = np.median(np.array(sample_dict["v"]), axis=0) # (K_clusters-1,) + sigma_eps_med = np.median(np.array(sample_dict["sigma_eps"]), axis=0) # (K_clusters,) + df_med = float(np.median(np.array(sample_dict["df"]))) + + pi = np.cumprod(v_ibp_med) # (K_features,) + K_features = len(pi) + K_clusters = len(sigma_eps_med) + + import jax.numpy as jnp + w = np.array(stick_breaking_weights(jnp.array(v_med))) + + pool_idx = np.array(data["pool_idx"]) + X_pool = np.array(data["X_pool"]) + x_obs = np.array(data["x_obs"]) + y_obs = np.array(data["y_obs"]) + N_pools = data["N_pools"] + + # Enumerate all 2^K configs + n_configs = 2 ** K_features + configs = ( + (np.arange(n_configs)[:, None] >> np.arange(K_features)[None, :]) & 1 + ).astype(float) # (n_configs, K_features) + + # IBP log-prior per config + log_pi = np.log(pi + 1e-30) + log_1mpi = np.log(1.0 - pi + 1e-30) + log_ibp_prior = configs @ log_pi + (1.0 - configs) @ log_1mpi # (n_configs,) + + # Joint log-prior: IBP config × DP cluster + log_joint_prior = log_ibp_prior[:, None] + np.log(w + 1e-30)[None, :] # (n_configs, K_clusters) + + # Population mean + mu_pop = X_pool @ B_med.T # (N_pools, K_coeff) + feature_effects = configs @ W_med # (n_configs, K_coeff) + + # Per-obs means + mu_pop_obs = np.sum(mu_pop[pool_idx] * x_obs, axis=1) # (N_obs,) + feature_mu = x_obs @ feature_effects.T # (N_obs, n_configs) + mu_obs = mu_pop_obs[:, None] + feature_mu # (N_obs, n_configs) + + # Log-likelihood per obs per config per cluster + log_lik = np.zeros((len(y_obs), n_configs, K_clusters)) + for k in range(K_clusters): + log_lik[:, :, k] = t_dist.logpdf( + y_obs[:, None], df_med, loc=mu_obs, scale=sigma_eps_med[k] + ) + + # Sum within pools + pool_log_liks = np.zeros((N_pools, n_configs, K_clusters)) + for i in range(len(y_obs)): + pool_log_liks[pool_idx[i]] += log_lik[i] + + # Joint posterior: log_joint_prior + pool_log_liks + log_posterior = log_joint_prior[None, :, :] + pool_log_liks # (N_pools, n_configs, K_clusters) + + # Flatten to (N_pools, n_configs * K_clusters), argmax, unravel + flat = log_posterior.reshape(N_pools, -1) + best_flat_idx = np.argmax(flat, axis=1) + best_config_idx = best_flat_idx // K_clusters + best_cluster_idx = best_flat_idx % K_clusters + + feature_assignments = configs[best_config_idx].astype(int) # (N_pools, K_features) + cluster_assignments = best_cluster_idx.astype(np.int64) # (N_pools,) + + return feature_assignments, cluster_assignments + + +def extract_structural_params(samples, data, use_median=True) -> list: + """Extract per-pool arb frequency and noise coefficients from structural model. + + Parameters + ---------- + samples : dict + Posterior samples from SVI/NUTS with structural_noise_model. + data : dict + Output of encode_covariates_structural(). + + Returns + ------- + list of dict + Per-pool dicts with: pool_id, chain, tokens, arb_frequency, noise_params. + """ + if hasattr(samples, "get_samples"): + sample_dict = samples.get_samples() + else: + sample_dict = samples + + agg_fn = np.median if use_median else np.mean + + # Cadence parameters + alpha_0 = agg_fn(np.array(sample_dict["alpha_0"])) + alpha_chain = agg_fn(np.array(sample_dict["alpha_chain"]), axis=0) + alpha_tier = agg_fn(np.array(sample_dict["alpha_tier"]), axis=0) + alpha_tvl = agg_fn(np.array(sample_dict["alpha_tvl"])) + + # Per-pool theta from hierarchical model + # theta = X_pool @ B.T + eta @ L_Sigma.T + B = agg_fn(np.array(sample_dict["B"]), axis=0) + eta = agg_fn(np.array(sample_dict["eta"]), axis=0) + sigma_theta = agg_fn(np.array(sample_dict["sigma_theta"]), axis=0) + L_Omega = agg_fn(np.array(sample_dict["L_Omega"]), axis=0) + + X_pool = np.array(data["X_pool"]) + L_Sigma = np.diag(sigma_theta) @ L_Omega + theta = X_pool @ B.T + eta @ L_Sigma.T # (N_pools, K_obs_coeff) + + chain_idx = np.array(data["chain_idx"]) + tier_idx = np.array(data["tier_idx"]) + pool_meta = data["pool_meta"] + pool_ids = data["pool_ids"] + + # Per-pool cadence + padded_chain = np.concatenate([[0.0], alpha_chain]) + padded_tier = np.concatenate([[0.0], alpha_tier]) + + # Per-pool log_tvl (median across observations) + pool_idx_arr = np.array(data["pool_idx"]) + lag_log_tvl = np.array(data["lag_log_tvl"]) + N_pools = data["N_pools"] + + pool_tvl_median = np.zeros(N_pools) + for p in range(N_pools): + mask = pool_idx_arr == p + if mask.any(): + pool_tvl_median[p] = np.median(lag_log_tvl[mask]) + + results = [] + for i, pool_id in enumerate(pool_ids): + meta = pool_meta.iloc[i] + log_cadence = ( + alpha_0 + + padded_chain[chain_idx[i]] + + padded_tier[tier_idx[i]] + + alpha_tvl * pool_tvl_median[i] + ) + cadence = np.exp(np.clip(log_cadence, -2.0, 6.0)) + arb_freq = int(np.clip(np.round(cadence), 1, 60)) + + noise_coeffs = { + name: float(theta[i, k]) + for k, name in enumerate(OBS_COEFF_NAMES) + } + + tokens = meta["tokens"] + if isinstance(tokens, str): + tokens = tokens.split(",") + + results.append({ + "pool_id": pool_id, + "chain": str(meta["chain"]), + "tokens": tokens, + "arb_frequency": arb_freq, + "noise_params": noise_coeffs, + }) + + return results + + +def predict_new_pool_structural( + samples, data, chain: str, tokens: list, fee: float, tvl_est: float, +) -> dict: + """Predict cadence and noise coefficients for a hypothetical pool. + + Uses the structural model's arb cadence parameters and hierarchical B + regression for noise (population mean, no pool-specific random effect). + + Parameters + ---------- + samples : dict + Posterior samples from structural_noise_model. + data : dict + Output of encode_covariates_structural(). + chain : str + Chain name. + tokens : list of str + Token symbols. + fee : float + Swap fee (fraction). + tvl_est : float + Estimated TVL in USD. + """ + if hasattr(samples, "get_samples"): + sample_dict = samples.get_samples() + else: + sample_dict = samples + + agg_fn = np.median + + # Cadence parameters + alpha_0 = agg_fn(np.array(sample_dict["alpha_0"])) + alpha_chain = agg_fn(np.array(sample_dict["alpha_chain"]), axis=0) + alpha_tier = agg_fn(np.array(sample_dict["alpha_tier"]), axis=0) + alpha_tvl = agg_fn(np.array(sample_dict["alpha_tvl"])) + + # Hierarchical B for noise population mean + B = agg_fn(np.array(sample_dict["B"]), axis=0) + + # Construct chain and tier indices for the new pool + chains = data["chains"] + chain_to_idx = {c: i for i, c in enumerate(chains)} + c_idx = chain_to_idx.get(chain, 0) # fallback to reference + + tiers = sorted([classify_token_tier(t) for t in tokens]) + tier_a = tiers[0] + tier_b = tiers[1] if len(tiers) > 1 else tier_a + t_idx = _tier_pair_idx(tier_a, tier_b) + + padded_chain = np.concatenate([[0.0], alpha_chain]) + padded_tier = np.concatenate([[0.0], alpha_tier]) + + log_tvl = np.log(max(tvl_est, 1.0)) + log_cadence = ( + alpha_0 + + padded_chain[c_idx] + + padded_tier[t_idx] + + alpha_tvl * log_tvl + ) + cadence = np.exp(np.clip(log_cadence, -2.0, 6.0)) + arb_freq = int(np.clip(np.round(cadence), 1, 60)) + + # Construct X_pool_new for hierarchical regression + col_names = data["covariate_names"] + z_new = np.zeros(len(col_names), dtype=np.float64) + tier_a_str = str(tier_a) + tier_b_str = str(tier_b) + + for i, name in enumerate(col_names): + if name == "intercept": + z_new[i] = 1.0 + elif name == "log_fee": + z_new[i] = np.log(max(fee, 1e-6)) + elif name == f"chain_{chain}": + z_new[i] = 1.0 + elif name == f"tier_A_{tier_a_str}": + z_new[i] = 1.0 + elif name == f"tier_B_{tier_b_str}": + z_new[i] = 1.0 + + # Population mean theta for new pool (no random effect) + theta_new = z_new @ B.T # (K_obs_coeff,) + + noise_coeffs = { + name: float(theta_new[k]) + for k, name in enumerate(OBS_COEFF_NAMES) + } + + return { + "chain": chain, + "tokens": tokens, + "fee": fee, + "tvl_est": tvl_est, + "arb_frequency": arb_freq, + "noise_params": noise_coeffs, + } + + +def run_prior_predictive(data, num_samples=500, model_fn=None) -> dict: + """Run prior predictive check (no observations).""" + import jax + from numpyro.infer import Predictive + + if model_fn is None: + model_fn = noise_model + + model_kwargs = _build_model_kwargs(data, model_fn=model_fn) + model_kwargs["y_obs"] = None # no observations + + predictive = Predictive(model_fn, num_samples=num_samples) + rng_key = jax.random.PRNGKey(99) + prior_samples = predictive(rng_key, **model_kwargs) + prior_samples = {k: np.array(v) for k, v in prior_samples.items()} + + print(f" Prior predictive: drew {num_samples} samples") + y_prior = prior_samples.get("y", None) + if y_prior is not None: + print(f" Prior log-volume range: " + f"[{np.percentile(y_prior, 1):.1f}, " + f"{np.percentile(y_prior, 99):.1f}]") + print(f" Observed log-volume range: " + f"[{data['y_obs'].min():.1f}, {data['y_obs'].max():.1f}]") + + return prior_samples diff --git a/quantammsim/noise_calibration/token_classification.py b/quantammsim/noise_calibration/token_classification.py new file mode 100644 index 0000000..aad9caa --- /dev/null +++ b/quantammsim/noise_calibration/token_classification.py @@ -0,0 +1,23 @@ +"""Token tier classification.""" + +from .constants import _TIER_0, _TIER_1 + + +def _normalise_symbol(symbol: str) -> str: + """Normalise wrapped/bridged variants to canonical form.""" + s = symbol.strip() + mapping = { + "WETH": "WETH", "WBTC": "WBTC", "cbBTC": "cbBTC", + "WMATIC": "WMATIC", "WAVAX": "WAVAX", "WXDAI": "WXDAI", "wS": "wS", + } + return mapping.get(s, s) + + +def classify_token_tier(symbol: str) -> int: + """Classify a token symbol into tier 0/1/2.""" + s = _normalise_symbol(symbol) + if s in _TIER_0: + return 0 + if s in _TIER_1: + return 1 + return 2 diff --git a/quantammsim/pools/G3M/balancer/balancer.py b/quantammsim/pools/G3M/balancer/balancer.py index fcee272..89cf46d 100644 --- a/quantammsim/pools/G3M/balancer/balancer.py +++ b/quantammsim/pools/G3M/balancer/balancer.py @@ -331,6 +331,7 @@ def calculate_reserves_with_dynamic_inputs( materialized_inputs.trades, run_fingerprint["do_trades"], run_fingerprint["do_arb"], + materialized_inputs.lp_supply, ) return reserves diff --git a/quantammsim/pools/G3M/balancer/balancer_reserves.py b/quantammsim/pools/G3M/balancer/balancer_reserves.py index 82c6629..e36463b 100644 --- a/quantammsim/pools/G3M/balancer/balancer_reserves.py +++ b/quantammsim/pools/G3M/balancer/balancer_reserves.py @@ -147,7 +147,7 @@ def _jax_calc_balancer_reserves_with_fees_scan_function_using_precalcs( tokens_to_drop, gamma, n, - 0, + -1e-15, ) optimal_arb_trade = jnp.where( @@ -350,6 +350,7 @@ def _jax_calc_balancer_reserves_with_dynamic_fees_and_trades_scan_function_using prev_reserves = carry_list[1] counter = carry_list[2] + prev_lp_supply = carry_list[3] # input_list contains weights, prices, precalcs and fee/arb amounts prices = input_list[0] @@ -359,7 +360,16 @@ def _jax_calc_balancer_reserves_with_dynamic_fees_and_trades_scan_function_using gamma = input_list[4] arb_thresh = input_list[5] arb_fees = input_list[6] - trade = input_list[7] if do_trades else None + trade = input_list[7] + lp_supply = input_list[8] + + # Scale reserves for LP supply changes (proportional deposits/withdrawals) + lp_supply_change = lp_supply != prev_lp_supply + prev_reserves = jnp.where( + lp_supply_change, + prev_reserves * lp_supply / prev_lp_supply, + prev_reserves, + ) fees_are_being_charged = gamma != 1.0 @@ -385,7 +395,7 @@ def _jax_calc_balancer_reserves_with_dynamic_fees_and_trades_scan_function_using tokens_to_drop, gamma, n, - 0, + -1e-15, ) optimal_arb_trade = jnp.where( @@ -421,6 +431,7 @@ def _jax_calc_balancer_reserves_with_dynamic_fees_and_trades_scan_function_using prices, reserves, counter, + lp_supply, ], reserves @@ -436,6 +447,7 @@ def _jax_calc_balancer_reserves_with_dynamic_inputs( trades=None, do_trades=False, do_arb=True, + lp_supply_array=None, ): """ Calculate AMM reserves considering fees and arbitrage opportunities using signature variations, @@ -497,6 +509,14 @@ def _jax_calc_balancer_reserves_with_dynamic_inputs( if do_trades and trades is None: raise ValueError("Trades must be provided when do_trades=True.") + if lp_supply_array is None: + lp_supply_array = jnp.array(1.0) + lp_supply_array = jnp.where( + lp_supply_array.size == 1, + jnp.full(prices.shape[0], lp_supply_array), + lp_supply_array, + ) + # pre-calculate some values that are repeatedly used in optimal arb calculations _, active_trade_directions, tokens_to_drop, leave_one_out_idxs = ( precalc_shared_values_for_all_signatures(all_sig_variations, n_assets) @@ -529,18 +549,22 @@ def _jax_calc_balancer_reserves_with_dynamic_inputs( initial_prices, initial_reserves, 0, + lp_supply_array[0], ] - scan_inputs = [ - prices, - active_initial_weights, - per_asset_ratios, - all_other_assets_ratios, - gamma, - arb_thresh, - arb_fees, - ] - if do_trades: - scan_inputs.append(trades) - _, reserves = scan(scan_fn, carry_list_init, scan_inputs) + _, reserves = scan( + scan_fn, + carry_list_init, + [ + prices, + active_initial_weights, + per_asset_ratios, + all_other_assets_ratios, + gamma, + arb_thresh, + arb_fees, + trades, + lp_supply_array, + ], + ) return reserves diff --git a/quantammsim/pools/G3M/optimal_n_pool_arb.py b/quantammsim/pools/G3M/optimal_n_pool_arb.py index 822e7ce..d9e3964 100644 --- a/quantammsim/pools/G3M/optimal_n_pool_arb.py +++ b/quantammsim/pools/G3M/optimal_n_pool_arb.py @@ -164,21 +164,18 @@ def construct_optimal_trade_jnp( valid_post_trade_reserves = ( jnp.sum(initial_reserves + active_overall_trade > 0) == n ) - valid_post_trade_constant = ( - jnp.prod( - ( - initial_reserves - + active_overall_trade - * (fee_gamma ** (trade_to_direction_jnp(active_overall_trade))) - ) - ** initial_weights + post_trade_constant = jnp.prod( + ( + initial_reserves + + active_overall_trade + * (fee_gamma ** (trade_to_direction_jnp(active_overall_trade))) ) - - initial_constant - >= slack + ** initial_weights ) + relative_diff = (post_trade_constant - initial_constant) / initial_constant + valid_post_trade_constant = relative_diff >= slack valid_trade = jnp.logical_and(valid_post_trade_reserves, valid_post_trade_constant) return jnp.where(valid_trade, active_overall_trade, 0) - # return active_overall_trade, valid_post_trade_reserves * valid_post_trade_constant construct_optimal_trade_jnp_vmapped = vmap( @@ -336,18 +333,16 @@ def calc_optimal_trade_for_one_signature( valid_post_trade_reserves = ( jnp.sum(initial_reserves + active_overall_trade > 0) == n ) - valid_post_trade_constant = ( - jnp.prod( - ( - initial_reserves - + active_overall_trade - * (fee_gamma ** (trade_to_direction_jnp(active_overall_trade))) - ) - ** initial_weights + post_trade_constant = jnp.prod( + ( + initial_reserves + + active_overall_trade + * (fee_gamma ** (trade_to_direction_jnp(active_overall_trade))) ) - - initial_constant - >= slack + ** initial_weights ) + relative_diff = (post_trade_constant - initial_constant) / initial_constant + valid_post_trade_constant = relative_diff >= slack valid_trade = jnp.logical_and(valid_post_trade_reserves, valid_post_trade_constant) return jnp.where(valid_trade, active_overall_trade, 0) # return { @@ -411,7 +406,7 @@ def parallelised_optimal_trade_sifter( tokens_to_drop, fee_gamma, n, - 0, + slack, ) profits = -(overall_trades * local_prices).sum(-1) @@ -457,7 +452,7 @@ def wrapped_parallelised_optimal_trade_sifter( tokens_to_drop, fee_gamma, n, - slack=0, + slack=slack, ) return trade diff --git a/quantammsim/pools/G3M/quantamm/TFMM_base_pool.py b/quantammsim/pools/G3M/quantamm/TFMM_base_pool.py index 7091ee3..b739abe 100644 --- a/quantammsim/pools/G3M/quantamm/TFMM_base_pool.py +++ b/quantammsim/pools/G3M/quantamm/TFMM_base_pool.py @@ -417,6 +417,10 @@ def calculate_reserves_with_dynamic_inputs( do_trades=run_fingerprint["do_trades"], dtype=arb_acted_upon_local_prices.dtype, ) + lp_supply_array_broadcast = materialized_inputs.lp_supply + # if we are doing trades, the trades array must be of the same length as the other arrays + if run_fingerprint["do_trades"]: + assert materialized_inputs.trades.shape[0] == max_len protocol_fee_split = run_fingerprint.get("protocol_fee_split", 0.0) reserves = _jax_calc_quantAMM_reserves_with_dynamic_inputs( initial_reserves, diff --git a/quantammsim/pools/base_pool.py b/quantammsim/pools/base_pool.py index 0415ef9..b695d8c 100644 --- a/quantammsim/pools/base_pool.py +++ b/quantammsim/pools/base_pool.py @@ -1,11 +1,12 @@ from abc import ABC, abstractmethod -from typing import Dict, Any, Optional +from typing import Dict, Any, Optional, Tuple +from functools import partial import numpy as np import jax.numpy as jnp from jax.nn import softmax -from jax.lax import stop_gradient -from jax import tree_util +from jax.lax import stop_gradient, dynamic_slice +from jax import tree_util, jit, vmap from quantammsim.core_simulator.param_utils import make_vmap_in_axes_dict @@ -283,6 +284,99 @@ def add_noise( params[key] = jnp.array(params[key]) return params + @partial(jit, static_argnums=(2, 3)) + def calculate_volatility_array(self, prices, run_fingerprint, subsample_freq=5): + """Annualised daily realised volatility broadcast to minute-level array. + + Pure-JAX implementation (vmap + dynamic_slice) — JIT-compatible and + callable from within traced contexts (e.g. forward_pass). + + Parameters + ---------- + prices : jnp.ndarray, shape (T, 2) + Minute-level prices for two tokens. + run_fingerprint : dict + Must contain ``tokens`` and ``numeraire`` for ordering. + subsample_freq : int + Subsample within each day to reduce microstructure noise. + + Returns + ------- + jnp.ndarray, shape (T,) + Annualised volatility, constant within each day. + """ + ordered_prices, needs_swap = self._handle_numeraire_ordering( + prices, run_fingerprint, + ) + asset_prices = ordered_prices[:, 0] / ordered_prices[:, 1] + n_minutes = len(asset_prices) + + # Guard: need at least one full day for vmap + dynamic_slice + if n_minutes < 1440: + return jnp.full(n_minutes, 0.1) * jnp.sqrt(365.0) + + n_days = n_minutes // 1440 + + def calculate_daily_volatility(day_idx): + start_idx = day_idx * 1440 + window_prices = dynamic_slice(asset_prices, [start_idx], [1440]) + subsampled_prices = window_prices[::subsample_freq] + log_prices = jnp.log(jnp.maximum(subsampled_prices, 1e-8)) + returns = jnp.diff(log_prices) + num_nonzero_returns = jnp.sum(returns != 0) + total_returns = len(returns) + adjusted_variance = ( + num_nonzero_returns * jnp.var(returns) / total_returns + ) + dt = subsample_freq / 1440 + vol = jnp.sqrt(adjusted_variance) / jnp.sqrt(dt) + return vol + + daily_volatilities = vmap(calculate_daily_volatility)(jnp.arange(n_days)) + volatility_array = jnp.repeat(daily_volatilities, 1440) + + remaining_minutes = n_minutes - len(volatility_array) + if remaining_minutes > 0: + last_vol = ( + daily_volatilities[-1] if len(daily_volatilities) > 0 else 0.1 + ) + volatility_array = jnp.concatenate( + [volatility_array, jnp.full(remaining_minutes, last_vol)] + ) + + return volatility_array * jnp.sqrt(365.0) + + @partial(jit, static_argnums=(2,)) + def _handle_numeraire_ordering( + self, + prices: jnp.ndarray, + run_fingerprint: Dict[str, Any], + ) -> Tuple[jnp.ndarray, bool]: + """Reorder prices so numeraire token is in second position. + + Parameters + ---------- + prices : jnp.ndarray, shape (..., 2) + Price array with two tokens. + run_fingerprint : dict + Must contain ``tokens`` (sorted) and ``numeraire``. + + Returns + ------- + (ordered_prices, needs_swap) : (jnp.ndarray, bool) + """ + tokens = sorted(run_fingerprint["tokens"]) + numeraire = run_fingerprint["numeraire"] + if numeraire is None or numeraire not in tokens: + numeraire = tokens[-1] + needs_swap = tokens.index(numeraire) == 0 + + if needs_swap: + ordered_prices = prices[..., ::-1] + else: + ordered_prices = prices + return ordered_prices, needs_swap + def _tree_flatten(self): children = () aux_data = dict() # static values diff --git a/quantammsim/pools/noise_trades.py b/quantammsim/pools/noise_trades.py index 82e7674..3f45792 100644 --- a/quantammsim/pools/noise_trades.py +++ b/quantammsim/pools/noise_trades.py @@ -108,3 +108,167 @@ def calculate_reserves_after_noise_trade( ) reserves = current_reserves * ratio_of_value_of_trade_to_reserves return reserves + + +@jit +def reclamm_tsoukalas_sqrt_noise_volume( + effective_value_usd, + gamma, + volatility, + arb_volume_this_period, + noise_params=None, +): + """reClAMM Tsoukalas sqrt model: effective TVL regressor. + + Predicts per-minute noise trader volume using: + V_daily = (a_0 - a_f*fee + a_sigma*sigma + + a_c*sqrt(c_eff/1e6)) * 1e6 + V_noise = max(0, V_daily/1440 - arb_volume_this_period) + + where c_eff = (Ra+Va)*pA + (Rb+Vb)*pB is the effective TVL (real + + virtual reserves valued in USD). For a concentrated liquidity pool, + effective reserves determine execution quality and routing decisions, + so they are the natural driver of noise volume. + + Parameters + ---------- + effective_value_usd : float + Effective TVL in USD: (Ra+Va)*pA + (Rb+Vb)*pB. + gamma : float + Fee parameter (1 - fee_rate). + volatility : float + Annualised daily realised volatility of the price ratio. + arb_volume_this_period : float + Arb volume already accounted for this time step (USD). + noise_params : dict, optional + Regression coefficients. Keys: a_0_base, a_f, a_sigma, + a_c, base_fee. + + Returns + ------- + float + Per-minute noise volume (USD), floored at zero. + """ + if noise_params is None: + noise_params = {} + a_0_base = noise_params.get("a_0_base", 0.5) + a_f = noise_params.get("a_f", 0.0) + a_sigma = noise_params.get("a_sigma", 2.0) + a_c = noise_params.get("a_c", 1.0) + base_fee = noise_params.get("base_fee", 0.003) + + fee = 1.0 - gamma + a_0 = a_0_base + base_fee * a_f + daily_vol = ( + a_0 - a_f * fee + + a_sigma * volatility + + a_c * jnp.sqrt(effective_value_usd / 1e6) + ) * 1e6 + return jnp.maximum(0.0, daily_vol / 1440.0 - arb_volume_this_period) + + +@jit +def reclamm_tsoukalas_log_noise_volume( + effective_value_usd, + gamma, + volatility, + arb_volume_this_period, + noise_params=None, +): + """reClAMM Tsoukalas log model: log(c_eff/1e6) instead of sqrt. + + Same specification as the sqrt variant but uses log regressor, + which may fit better for pools spanning a wide TVL range. + + Parameters + ---------- + effective_value_usd : float + Effective TVL in USD: (Ra+Va)*pA + (Rb+Vb)*pB. + gamma : float + Fee parameter (1 - fee_rate). + volatility : float + Annualised daily realised volatility of the price ratio. + arb_volume_this_period : float + Arb volume already accounted for this time step (USD). + noise_params : dict, optional + Regression coefficients (same keys as sqrt variant). + + Returns + ------- + float + Per-minute noise volume (USD), floored at zero. + """ + if noise_params is None: + noise_params = {} + a_0_base = noise_params.get("a_0_base", 0.5) + a_f = noise_params.get("a_f", 0.0) + a_sigma = noise_params.get("a_sigma", 2.0) + a_c = noise_params.get("a_c", 1.0) + base_fee = noise_params.get("base_fee", 0.003) + + fee = 1.0 - gamma + a_0 = a_0_base + base_fee * a_f + daily_vol = ( + a_0 - a_f * fee + + a_sigma * volatility + + a_c * jnp.log(jnp.maximum(effective_value_usd / 1e6, 1e-30)) + ) * 1e6 + return jnp.maximum(0.0, daily_vol / 1440.0 - arb_volume_this_period) + + +@jit +def reclamm_loglinear_noise_volume( + effective_value_usd, + gamma, + volatility, + arb_volume_this_period, + noise_params=None, +): + """Loglinear noise volume from hierarchical cross-pool calibration. + + Predicts per-minute noise volume using: + log(V_daily) = b_0 + b_sigma * volatility + b_c * log(TVL) + V_noise = max(0, exp(log_daily_vol) / 1440 - arb_volume) + + where b_0 is a pool-specific intercept (BLUP from the hierarchical + model, absorbing chain, token tier, and fee effects), and b_sigma, + b_c are shared fixed effects estimated from cross-pool variation. + + Note: ``gamma`` is accepted for interface compatibility with the + other noise volume functions but is not used; fee effects are + absorbed into ``b_0`` via the hierarchical model's BLUP. + + Parameters + ---------- + effective_value_usd : float + Effective TVL in USD: (Ra+Va)*pA + (Rb+Vb)*pB. + gamma : float + Fee parameter (1 - fee_rate). Unused — kept for uniform + calling convention across noise models. + volatility : float + Annualised daily realised volatility of the price ratio. + arb_volume_this_period : float + Arb volume already accounted for this time step (USD). + noise_params : dict, optional + Hierarchical model coefficients. Keys: b_0, b_sigma, b_c. + + Returns + ------- + float + Per-minute noise volume (USD), floored at zero. + """ + if noise_params is None: + noise_params = {} + b_0 = noise_params.get("b_0", -6.7) + b_sigma = noise_params.get("b_sigma", -0.0007) + b_c = noise_params.get("b_c", 1.04) + + log_daily_vol = ( + b_0 + + b_sigma * volatility + + b_c * jnp.log(jnp.maximum(effective_value_usd, 1.0)) + ) + daily_vol = jnp.exp(log_daily_vol) + return jnp.maximum(0.0, daily_vol / 1440.0 - arb_volume_this_period) + + diff --git a/quantammsim/pools/reCLAMM/reclamm.py b/quantammsim/pools/reCLAMM/reclamm.py index 762301c..5b913ff 100644 --- a/quantammsim/pools/reCLAMM/reclamm.py +++ b/quantammsim/pools/reCLAMM/reclamm.py @@ -18,6 +18,22 @@ from quantammsim.core_simulator.dynamic_inputs import materialize_dynamic_inputs from quantammsim.pools.base_pool import AbstractPool + + +def _prepare_dynamic_array(arr, start_index, bout_length, arb_frequency, max_len): + """Slice and decimate a dynamic input array to match arb_prices shape. + + Scalar (1,) arrays are broadcast to (max_len,). + Full-length arrays are sliced to the bout window then decimated. + """ + if arr.shape[0] <= 1: + return jnp.broadcast_to(arr, (max_len,) + arr.shape[1:]) + sliced = dynamic_slice(arr, (start_index[0],), (bout_length - 1,)) + if arb_frequency != 1: + sliced = sliced[::arb_frequency] + return sliced + + from quantammsim.pools.reCLAMM.reclamm_reserves import ( initialise_reclamm_reserves, calibrate_arc_length_speed, @@ -206,9 +222,36 @@ def calculate_reserves_with_fees( prices: jnp.ndarray, start_index: jnp.ndarray, additional_oracle_input: Optional[jnp.ndarray] = None, + lp_supply_array: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: s = self._init_pool_state(params, run_fingerprint, prices, start_index) + bout_length = run_fingerprint["bout_length"] + arb_freq = run_fingerprint["arb_frequency"] + lp_prepared = ( + _prepare_dynamic_array( + lp_supply_array, start_index, bout_length, + arb_freq, s.arb_prices.shape[0], + ) + if lp_supply_array is not None else None + ) + + noise_model = run_fingerprint.get("noise_model", "ratio") + noise_params = run_fingerprint.get("reclamm_noise_params", None) + if noise_params is not None and type(noise_params) is not dict: + noise_params = dict(noise_params) + + if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + volatility_array = self.calculate_volatility_array( + prices, run_fingerprint, + ) + arb_vol = _prepare_dynamic_array( + volatility_array, start_index, bout_length, + arb_freq, s.arb_prices.shape[0], + ) + else: + arb_vol = None + if run_fingerprint["do_arb"]: return _jax_calc_reclamm_reserves_with_fees( s.initial_reserves, s.Va, s.Vb, @@ -225,6 +268,11 @@ def calculate_reserves_with_fees( arc_length_speed=s.arc_length_speed, centeredness_scaling=s.centeredness_scaling, protocol_fee_split=run_fingerprint.get("protocol_fee_split", 0.0), + noise_trader_ratio=run_fingerprint.get("noise_trader_ratio", 0.0), + lp_supply_array=lp_prepared, + noise_model=noise_model, + noise_params=noise_params, + volatility_array=arb_vol, ) return jnp.broadcast_to(s.initial_reserves, s.arb_prices.shape) @@ -236,6 +284,7 @@ def calculate_reserves_and_fee_revenue_with_fees( prices: jnp.ndarray, start_index: jnp.ndarray, additional_oracle_input: Optional[jnp.ndarray] = None, + lp_supply_array: Optional[jnp.ndarray] = None, ): """Calculate reserves and LP fee revenue with fees. @@ -247,6 +296,32 @@ def calculate_reserves_and_fee_revenue_with_fees( """ s = self._init_pool_state(params, run_fingerprint, prices, start_index) + bout_length = run_fingerprint["bout_length"] + arb_freq = run_fingerprint["arb_frequency"] + lp_prepared = ( + _prepare_dynamic_array( + lp_supply_array, start_index, bout_length, + arb_freq, s.arb_prices.shape[0], + ) + if lp_supply_array is not None else None + ) + + noise_model = run_fingerprint.get("noise_model", "ratio") + noise_params = run_fingerprint.get("reclamm_noise_params", None) + if noise_params is not None and type(noise_params) is not dict: + noise_params = dict(noise_params) + + if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + volatility_array = self.calculate_volatility_array( + prices, run_fingerprint, + ) + arb_vol = _prepare_dynamic_array( + volatility_array, start_index, bout_length, + arb_freq, s.arb_prices.shape[0], + ) + else: + arb_vol = None + if run_fingerprint["do_arb"]: return _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( s.initial_reserves, s.Va, s.Vb, @@ -263,6 +338,11 @@ def calculate_reserves_and_fee_revenue_with_fees( arc_length_speed=s.arc_length_speed, centeredness_scaling=s.centeredness_scaling, protocol_fee_split=run_fingerprint.get("protocol_fee_split", 0.0), + noise_trader_ratio=run_fingerprint.get("noise_trader_ratio", 0.0), + lp_supply_array=lp_prepared, + noise_model=noise_model, + noise_params=noise_params, + volatility_array=arb_vol, ) return ( jnp.broadcast_to(s.initial_reserves, s.arb_prices.shape), @@ -301,6 +381,22 @@ def calculate_reserves_and_fee_revenue_with_dynamic_inputs( dtype=s.arb_prices.dtype, ) + noise_model = run_fingerprint.get("noise_model", "ratio") + noise_params = run_fingerprint.get("reclamm_noise_params", None) + if noise_params is not None and type(noise_params) is not dict: + noise_params = dict(noise_params) + + if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + volatility_array = self.calculate_volatility_array( + prices, run_fingerprint, + ) + arb_vol = _prepare_dynamic_array( + volatility_array, start_index, bout_length, + run_fingerprint["arb_frequency"], max_len, + ) + else: + arb_vol = None + return _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( s.initial_reserves, s.Va, s.Vb, s.arb_prices, @@ -317,6 +413,11 @@ def calculate_reserves_and_fee_revenue_with_dynamic_inputs( arc_length_speed=s.arc_length_speed, centeredness_scaling=s.centeredness_scaling, protocol_fee_split=run_fingerprint.get("protocol_fee_split", 0.0), + noise_trader_ratio=run_fingerprint.get("noise_trader_ratio", 0.0), + lp_supply_array=materialized_inputs.lp_supply, + noise_model=noise_model, + noise_params=noise_params, + volatility_array=arb_vol, ) @partial(jit, static_argnums=(2,)) @@ -379,6 +480,22 @@ def calculate_reserves_with_dynamic_inputs( dtype=s.arb_prices.dtype, ) + noise_model = run_fingerprint.get("noise_model", "ratio") + noise_params = run_fingerprint.get("reclamm_noise_params", None) + if noise_params is not None and type(noise_params) is not dict: + noise_params = dict(noise_params) + + if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + volatility_array = self.calculate_volatility_array( + prices, run_fingerprint, + ) + arb_vol = _prepare_dynamic_array( + volatility_array, start_index, bout_length, + run_fingerprint["arb_frequency"], max_len, + ) + else: + arb_vol = None + return _jax_calc_reclamm_reserves_with_dynamic_inputs( s.initial_reserves, s.Va, s.Vb, s.arb_prices, @@ -395,6 +512,11 @@ def calculate_reserves_with_dynamic_inputs( arc_length_speed=s.arc_length_speed, centeredness_scaling=s.centeredness_scaling, protocol_fee_split=run_fingerprint.get("protocol_fee_split", 0.0), + noise_trader_ratio=run_fingerprint.get("noise_trader_ratio", 0.0), + lp_supply_array=materialized_inputs.lp_supply, + noise_model=noise_model, + noise_params=noise_params, + volatility_array=arb_vol, ) def init_base_parameters( diff --git a/quantammsim/pools/reCLAMM/reclamm_reserves.py b/quantammsim/pools/reCLAMM/reclamm_reserves.py index 3825f11..2a46ad3 100644 --- a/quantammsim/pools/reCLAMM/reclamm_reserves.py +++ b/quantammsim/pools/reCLAMM/reclamm_reserves.py @@ -30,6 +30,12 @@ from quantammsim.pools.G3M.G3M_trades import ( _jax_calc_G3M_trade_from_exact_in_given_out, ) +from quantammsim.pools.noise_trades import ( + calculate_reserves_after_noise_trade, + reclamm_tsoukalas_sqrt_noise_volume, + reclamm_tsoukalas_log_noise_volume, + reclamm_loglinear_noise_volume, +) # Reference balance for initialisation (matches Solidity _INITIALIZATION_MAX_BALANCE_A) _INITIALIZATION_MAX_BALANCE_A = 1e6 @@ -598,7 +604,7 @@ def apply_target_price_ratio_to_virtual_balances(Ra, Rb, Va, Vb, target_price_ra def _reclamm_scan_step_zero_fees( carry_list, - prices, + input_list, centeredness_margin, daily_price_shift_base, seconds_per_step, @@ -611,11 +617,23 @@ def _reclamm_scan_step_zero_fees( 1. Update virtual balances (path-dependent) 2. Compute analytical constant-product arb (no fee friction) - Carry: [real_reserves (2,), Va (0-d), Vb (0-d)] + Carry: [real_reserves (2,), Va (0-d), Vb (0-d), prev_lp_supply (0-d)] + Input: [prices (2,), lp_supply (0-d)] """ prev_reserves = carry_list[0] Va = carry_list[1] Vb = carry_list[2] + prev_lp_supply = carry_list[3] + + prices = input_list[0] + lp_supply = input_list[1] + + # Scale both real and virtual reserves by LP supply ratio. + scale = lp_supply / prev_lp_supply + lp_supply_change = lp_supply != prev_lp_supply + prev_reserves = jnp.where(lp_supply_change, prev_reserves * scale, prev_reserves) + Va = jnp.where(lp_supply_change, Va * scale, Va) + Vb = jnp.where(lp_supply_change, Vb * scale, Vb) Ra = prev_reserves[0] Rb = prev_reserves[1] @@ -690,7 +708,7 @@ def _reclamm_scan_step_zero_fees( Rb_new = jnp.where(clamp_a, Rb + edge_a[1], jnp.where(clamp_b, Rb + edge_b[1], Rb_new)) new_reserves = jnp.array([Ra_new, Rb_new]) - return [new_reserves, Va, Vb], new_reserves + return [new_reserves, Va, Vb, lp_supply], new_reserves # --------------------------------------------------------------------------- @@ -702,7 +720,7 @@ def _reclamm_scan_step_zero_fees( def _reclamm_scan_step_zero_fees_full_state( carry_list, - prices, + input_list, centeredness_margin, daily_price_shift_base, seconds_per_step, @@ -711,7 +729,7 @@ def _reclamm_scan_step_zero_fees_full_state( ): """TEST-ONLY: scan step that outputs (reserves, Va, Vb).""" new_carry, new_reserves = _reclamm_scan_step_zero_fees( - carry_list, prices, centeredness_margin, daily_price_shift_base, seconds_per_step, + carry_list, input_list, centeredness_margin, daily_price_shift_base, seconds_per_step, arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, ) @@ -731,15 +749,19 @@ def _reclamm_scan_step_with_fees_and_revenue( arc_length_speed=0.0, centeredness_scaling=False, protocol_fee_split=0.0, + noise_trader_ratio=0.0, + noise_model="ratio", + noise_params=None, ): """Single scan step for reClAMM pool with fees, returning LP fee revenue. Primary implementation — ``_reclamm_scan_step_with_fees`` wraps this. - Carry: [real_reserves (2,), Va, Vb, step_idx, active_start_ratio, + Carry: [real_reserves (2,), Va, Vb, prev_lp_supply, step_idx, active_start_ratio, active_target_ratio, active_start_step, active_end_step, active_enabled] Input: [prices, active_initial_weights, per_asset_ratios, - all_other_assets_ratios, gamma, arb_thresh, arb_fees, price_ratio_update] + all_other_assets_ratios, gamma, arb_thresh, arb_fees, price_ratio_update, + lp_supply, (optional) volatility] Returns ------- @@ -750,15 +772,13 @@ def _reclamm_scan_step_with_fees_and_revenue( prev_reserves = carry_list[0] Va = carry_list[1] Vb = carry_list[2] - step_idx = carry_list[3] - active_start_ratio = carry_list[4] - active_target_ratio = carry_list[5] - active_start_step = carry_list[6] - active_end_step = carry_list[7] - active_enabled = carry_list[8] - - Ra = prev_reserves[0] - Rb = prev_reserves[1] + prev_lp_supply = carry_list[3] + step_idx = carry_list[4] + active_start_ratio = carry_list[5] + active_target_ratio = carry_list[6] + active_start_step = carry_list[7] + active_end_step = carry_list[8] + active_enabled = carry_list[9] prices = input_list[0] active_initial_weights = input_list[1] @@ -768,7 +788,21 @@ def _reclamm_scan_step_with_fees_and_revenue( arb_thresh = input_list[5] arb_fees = input_list[6] price_ratio_update = input_list[7] + lp_supply = input_list[8] + + # Scale both real and virtual reserves by LP supply ratio. + # Matches ReClammPool.sol onBeforeAddLiquidity / onBeforeRemoveLiquidity: + # all balances (real + virtual) scale proportionally with BPT supply. + scale = lp_supply / prev_lp_supply + lp_supply_change = lp_supply != prev_lp_supply + prev_reserves = jnp.where(lp_supply_change, prev_reserves * scale, prev_reserves) + Va = jnp.where(lp_supply_change, Va * scale, Va) + Vb = jnp.where(lp_supply_change, Vb * scale, Vb) + + Ra = prev_reserves[0] + Rb = prev_reserves[1] + # Price-ratio schedule: apply target price ratio changes over time. event_has = price_ratio_update[0] > 0.5 event_target_ratio = jnp.maximum( jnp.where(jnp.isfinite(price_ratio_update[1]), price_ratio_update[1], 1.0), @@ -936,6 +970,45 @@ def _skip_schedule_state(_): Ra_new = Ra + applied_trade[0] Rb_new = Rb + applied_trade[1] + # --- Noise model dispatch --- + # noise_model is a concrete Python string (passed via Partial as static + # aux_data), so if/elif branches resolve at trace time. + if noise_model == "ratio": + noisy_reserves = calculate_reserves_after_noise_trade( + applied_trade, jnp.array([Ra_new, Rb_new]), prices, + noise_trader_ratio, gamma, + ) + Ra_new = jnp.where(noise_trader_ratio > 0, noisy_reserves[0], Ra_new) + Rb_new = jnp.where(noise_trader_ratio > 0, noisy_reserves[1], Rb_new) + elif noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + volatility = input_list[9] + arb_volume = 0.5 * jnp.sum(jnp.abs(applied_trade) * prices) + real_value = jnp.sum(jnp.array([Ra_new, Rb_new]) * prices) + effective_value = (Ra_new + Va) * prices[0] + (Rb_new + Vb) * prices[1] + + _np = noise_params if noise_params is not None else {} + if noise_model == "tsoukalas_sqrt": + noise_vol = reclamm_tsoukalas_sqrt_noise_volume( + effective_value, gamma, volatility, + arb_volume, _np, + ) + elif noise_model == "tsoukalas_log": + noise_vol = reclamm_tsoukalas_log_noise_volume( + effective_value, gamma, volatility, + arb_volume, _np, + ) + else: # loglinear + noise_vol = reclamm_loglinear_noise_volume( + effective_value, gamma, volatility, + arb_volume, _np, + ) + + noise_fee_income = (1.0 - gamma) * noise_vol + scale = 1.0 + noise_fee_income / jnp.maximum(real_value, 1e-8) + Ra_new = Ra_new * scale + Rb_new = Rb_new * scale + # else: "arb_only" — no noise trades + # Clamp-to-edge: if a real reserve would go negative, apply an # exact-in-given-out edge trade that drains that token to _DUST_USD # worth of reserves (preserving the AMM invariant). @@ -978,6 +1051,7 @@ def _skip_schedule_state(_): new_reserves, Va, Vb, + lp_supply, step_idx + 1.0, active_start_ratio, active_target_ratio, @@ -1000,6 +1074,9 @@ def _reclamm_scan_step_with_fees( arc_length_speed=0.0, centeredness_scaling=False, protocol_fee_split=0.0, + noise_trader_ratio=0.0, + noise_model="ratio", + noise_params=None, ): """Single scan step for reClAMM pool with fees (reserves only). @@ -1018,6 +1095,9 @@ def _reclamm_scan_step_with_fees( arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, + noise_trader_ratio=noise_trader_ratio, + noise_model=noise_model, + noise_params=noise_params, ) return new_carry, new_reserves @@ -1064,6 +1144,7 @@ def _jax_calc_reclamm_reserves_zero_fees( seconds_per_step, arc_length_speed=0.0, centeredness_scaling=False, + lp_supply_array=None, ): """Calculate reClAMM reserves over time with zero fees. @@ -1085,12 +1166,22 @@ def _jax_calc_reclamm_reserves_zero_fees( If > 0, use constant-arc-length thermostat instead of geometric. centeredness_scaling : bool If True, scale speed by margin/centeredness (proportional controller). + lp_supply_array : jnp.ndarray, optional + LP token supply over time, shape (T,). Defaults to constant 1.0. Returns ------- reserves : jnp.ndarray, shape (T, 2) Real reserves over time. """ + if lp_supply_array is None: + lp_supply_array = jnp.array(1.0) + lp_supply_array = jnp.where( + lp_supply_array.size == 1, + jnp.full(prices.shape[0], lp_supply_array), + lp_supply_array, + ) + scan_fn = Partial( _reclamm_scan_step_zero_fees, centeredness_margin=centeredness_margin, @@ -1100,8 +1191,8 @@ def _jax_calc_reclamm_reserves_zero_fees( centeredness_scaling=centeredness_scaling, ) - carry_init = [initial_reserves, initial_Va, initial_Vb] - _, reserves = scan(scan_fn, carry_init, prices) + carry_init = [initial_reserves, initial_Va, initial_Vb, lp_supply_array[0]] + _, reserves = scan(scan_fn, carry_init, [prices, lp_supply_array]) return reserves @@ -1116,6 +1207,7 @@ def _jax_calc_reclamm_reserves_zero_fees_full_state( seconds_per_step, arc_length_speed=0.0, centeredness_scaling=False, + lp_supply_array=None, ): """TEST-ONLY: Like _jax_calc_reclamm_reserves_zero_fees but returns Va/Vb. @@ -1125,6 +1217,14 @@ def _jax_calc_reclamm_reserves_zero_fees_full_state( Va_history : jnp.ndarray, shape (T,) Vb_history : jnp.ndarray, shape (T,) """ + if lp_supply_array is None: + lp_supply_array = jnp.array(1.0) + lp_supply_array = jnp.where( + lp_supply_array.size == 1, + jnp.full(prices.shape[0], lp_supply_array), + lp_supply_array, + ) + scan_fn = Partial( _reclamm_scan_step_zero_fees_full_state, centeredness_margin=centeredness_margin, @@ -1134,12 +1234,12 @@ def _jax_calc_reclamm_reserves_zero_fees_full_state( centeredness_scaling=centeredness_scaling, ) - carry_init = [initial_reserves, initial_Va, initial_Vb] - _, (reserves, Va_history, Vb_history) = scan(scan_fn, carry_init, prices) + carry_init = [initial_reserves, initial_Va, initial_Vb, lp_supply_array[0]] + _, (reserves, Va_history, Vb_history) = scan(scan_fn, carry_init, [prices, lp_supply_array]) return reserves, Va_history, Vb_history -@jit +@partial(jit, static_argnames=('noise_model',)) def _jax_calc_reclamm_reserves_with_fees( initial_reserves, initial_Va, @@ -1155,12 +1255,25 @@ def _jax_calc_reclamm_reserves_with_fees( arc_length_speed=0.0, centeredness_scaling=False, protocol_fee_split=0.0, + noise_trader_ratio=0.0, + lp_supply_array=None, + noise_model="ratio", + noise_params=None, + volatility_array=None, ): """Calculate reClAMM reserves over time with fees. Uses the G3M optimal arb machinery with constant weights [0.5, 0.5] applied to effective reserves (real + virtual). """ + if lp_supply_array is None: + lp_supply_array = jnp.array(1.0) + lp_supply_array = jnp.where( + lp_supply_array.size == 1, + jnp.full(prices.shape[0], lp_supply_array), + lp_supply_array, + ) + n_assets = 2 weights = jnp.array([0.5, 0.5]) gamma = 1.0 - fees @@ -1195,12 +1308,22 @@ def _jax_calc_reclamm_reserves_with_fees( arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, + noise_trader_ratio=noise_trader_ratio, + noise_model=noise_model, + noise_params=noise_params if noise_params is not None else {}, ) + scan_inputs = [prices, active_initial_weights, per_asset_ratios, + all_other_assets_ratios, gamma_array, arb_thresh_array, arb_fees_array, + price_ratio_updates, lp_supply_array] + if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + scan_inputs.append(volatility_array) + carry_init = [ initial_reserves, initial_Va, initial_Vb, + lp_supply_array[0], jnp.float64(0.0), # step_idx jnp.float64(0.0), # active_start_ratio jnp.float64(0.0), # active_target_ratio @@ -1208,16 +1331,11 @@ def _jax_calc_reclamm_reserves_with_fees( jnp.float64(0.0), # active_end_step jnp.array(False), # active_enabled ] - _, reserves = scan( - scan_fn, - carry_init, - [prices, active_initial_weights, per_asset_ratios, - all_other_assets_ratios, gamma_array, arb_thresh_array, arb_fees_array, price_ratio_updates], - ) + _, reserves = scan(scan_fn, carry_init, scan_inputs) return reserves -@partial(jit, static_argnums=(11,)) +@partial(jit, static_argnums=(11,), static_argnames=('noise_model',)) def _jax_calc_reclamm_reserves_with_dynamic_inputs( initial_reserves, initial_Va, @@ -1236,8 +1354,21 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs( arc_length_speed=0.0, centeredness_scaling=False, protocol_fee_split=0.0, + noise_trader_ratio=0.0, + lp_supply_array=None, + noise_model="ratio", + noise_params=None, + volatility_array=None, ): """Calculate reClAMM reserves with time-varying fees/arb arrays.""" + if lp_supply_array is None: + lp_supply_array = jnp.array(1.0) + lp_supply_array = jnp.where( + lp_supply_array.size == 1, + jnp.full(prices.shape[0], lp_supply_array), + lp_supply_array, + ) + n_assets = 2 weights = jnp.array([0.5, 0.5]) @@ -1285,12 +1416,22 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs( arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, + noise_trader_ratio=noise_trader_ratio, + noise_model=noise_model, + noise_params=noise_params if noise_params is not None else {}, ) + scan_inputs = [prices, active_initial_weights, per_asset_ratios, + all_other_assets_ratios, gamma, arb_thresh, arb_fees, + price_ratio_updates, lp_supply_array] + if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + scan_inputs.append(volatility_array) + carry_init = [ initial_reserves, initial_Va, initial_Vb, + lp_supply_array[0], jnp.float64(0.0), # step_idx jnp.float64(0.0), # active_start_ratio jnp.float64(0.0), # active_target_ratio @@ -1298,12 +1439,7 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs( jnp.float64(0.0), # active_end_step jnp.array(False), # active_enabled ] - _, reserves = scan( - scan_fn, - carry_init, - [prices, active_initial_weights, per_asset_ratios, - all_other_assets_ratios, gamma, arb_thresh, arb_fees, price_ratio_updates], - ) + _, reserves = scan(scan_fn, carry_init, scan_inputs) return reserves @@ -1351,6 +1487,8 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs_full_state( price_ratio_updates, (prices.shape[0], price_ratio_updates.shape[1]) ) + lp_supply_array = jnp.ones(prices.shape[0], dtype=prices.dtype) + _, active_trade_directions, tokens_to_drop, leave_one_out_idxs = ( precalc_shared_values_for_all_signatures(all_sig_variations, n_assets) ) @@ -1380,6 +1518,7 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs_full_state( initial_reserves, initial_Va, initial_Vb, + lp_supply_array[0], jnp.float64(0.0), # step_idx jnp.float64(0.0), # active_start_ratio jnp.float64(0.0), # active_target_ratio @@ -1391,12 +1530,13 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs_full_state( scan_fn, carry_init, [prices, active_initial_weights, per_asset_ratios, - all_other_assets_ratios, gamma, arb_thresh, arb_fees, price_ratio_updates], + all_other_assets_ratios, gamma, arb_thresh, arb_fees, + price_ratio_updates, lp_supply_array], ) return reserves, Va_history, Vb_history -@jit +@partial(jit, static_argnames=('noise_model',)) def _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( initial_reserves, initial_Va, @@ -1412,6 +1552,11 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( arc_length_speed=0.0, centeredness_scaling=False, protocol_fee_split=0.0, + noise_trader_ratio=0.0, + lp_supply_array=None, + noise_model="ratio", + noise_params=None, + volatility_array=None, ): """Calculate reClAMM reserves and LP fee revenue over time with fees. @@ -1421,6 +1566,14 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( fee_revenue : jnp.ndarray, shape (T,) LP fee revenue per timestep in USD. """ + if lp_supply_array is None: + lp_supply_array = jnp.array(1.0) + lp_supply_array = jnp.where( + lp_supply_array.size == 1, + jnp.full(prices.shape[0], lp_supply_array), + lp_supply_array, + ) + n_assets = 2 weights = jnp.array([0.5, 0.5]) gamma = 1.0 - fees @@ -1454,12 +1607,22 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, + noise_trader_ratio=noise_trader_ratio, + noise_model=noise_model, + noise_params=noise_params if noise_params is not None else {}, ) + scan_inputs = [prices, active_initial_weights, per_asset_ratios, + all_other_assets_ratios, gamma_array, arb_thresh_array, arb_fees_array, + price_ratio_updates, lp_supply_array] + if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + scan_inputs.append(volatility_array) + carry_init = [ initial_reserves, initial_Va, initial_Vb, + lp_supply_array[0], jnp.float64(0.0), # step_idx jnp.float64(0.0), # active_start_ratio jnp.float64(0.0), # active_target_ratio @@ -1467,16 +1630,11 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( jnp.float64(0.0), # active_end_step jnp.array(False), # active_enabled ] - _, (reserves, fee_revenue) = scan( - scan_fn, - carry_init, - [prices, active_initial_weights, per_asset_ratios, - all_other_assets_ratios, gamma_array, arb_thresh_array, arb_fees_array, price_ratio_updates], - ) + _, (reserves, fee_revenue) = scan(scan_fn, carry_init, scan_inputs) return reserves, fee_revenue -@partial(jit, static_argnums=(11,)) +@partial(jit, static_argnums=(11,), static_argnames=('noise_model',)) def _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( initial_reserves, initial_Va, @@ -1495,6 +1653,11 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( arc_length_speed=0.0, centeredness_scaling=False, protocol_fee_split=0.0, + noise_trader_ratio=0.0, + lp_supply_array=None, + noise_model="ratio", + noise_params=None, + volatility_array=None, ): """Calculate reClAMM reserves and LP fee revenue with time-varying fees/arb arrays. @@ -1504,6 +1667,14 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( fee_revenue : jnp.ndarray, shape (T,) LP fee revenue per timestep in USD. """ + if lp_supply_array is None: + lp_supply_array = jnp.array(1.0) + lp_supply_array = jnp.where( + lp_supply_array.size == 1, + jnp.full(prices.shape[0], lp_supply_array), + lp_supply_array, + ) + n_assets = 2 weights = jnp.array([0.5, 0.5]) @@ -1550,12 +1721,22 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, + noise_trader_ratio=noise_trader_ratio, + noise_model=noise_model, + noise_params=noise_params if noise_params is not None else {}, ) + scan_inputs = [prices, active_initial_weights, per_asset_ratios, + all_other_assets_ratios, gamma, arb_thresh, arb_fees, + price_ratio_updates, lp_supply_array] + if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + scan_inputs.append(volatility_array) + carry_init = [ initial_reserves, initial_Va, initial_Vb, + lp_supply_array[0], jnp.float64(0.0), # step_idx jnp.float64(0.0), # active_start_ratio jnp.float64(0.0), # active_target_ratio @@ -1563,10 +1744,5 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( jnp.float64(0.0), # active_end_step jnp.array(False), # active_enabled ] - _, (reserves, fee_revenue) = scan( - scan_fn, - carry_init, - [prices, active_initial_weights, per_asset_ratios, - all_other_assets_ratios, gamma, arb_thresh, arb_fees, price_ratio_updates], - ) + _, (reserves, fee_revenue) = scan(scan_fn, carry_init, scan_inputs) return reserves, fee_revenue diff --git a/quantammsim/utils/data_processing/historic_data_utils.py b/quantammsim/utils/data_processing/historic_data_utils.py index 31fb76c..b83a97b 100644 --- a/quantammsim/utils/data_processing/historic_data_utils.py +++ b/quantammsim/utils/data_processing/historic_data_utils.py @@ -83,9 +83,6 @@ def start_and_end_calcs( if oracle_values is not None: oracle_values = oracle_values[remainder_idx:] - print("start_date: ", start_date) - print("end_date: ", end_date) - print("unix_values: ", unix_values) start_idx = np.where(unix_values == start_date)[0][0] end_idx = np.where(unix_values == end_date)[0][0] + 1 else: diff --git a/scripts/benchmark_reclamm_interpolation.py b/scripts/benchmark_reclamm_interpolation.py new file mode 100644 index 0000000..f462bbf --- /dev/null +++ b/scripts/benchmark_reclamm_interpolation.py @@ -0,0 +1,648 @@ +"""Benchmark reClAMM range shift interpolation: current vs optimal midpoint. + +Compares total arb loss during a range shift under different interpolation methods: + Geometric VB -- exponential decay of overvalued virtual (what contracts do) + Linear VB -- uniform steps in VB + Linear Z -- uniform steps in Z = sqrt(P)*VA - VB/sqrt(P) (optimal, from note) + Optimal 2-step -- exact midpoint via quadratic formula (Section 5 of note) + Brute-force optimal -- JAX gradient-optimised Z-target sequence + +Key result: per-step loss ~ (DeltaZ)^2 / (4X). Equal Z-increments minimise +total loss, analogous to TFMM optimal intermediate for G3M weight changes. + +Usage: + cd + source ~/miniconda3/etc/profile.d/conda.sh && conda activate qsim-reclamm + python scripts/benchmark_reclamm_interpolation.py +""" + +import numpy as np +import matplotlib.pyplot as plt + +import jax +import jax.numpy as jnp +from scipy.optimize import minimize as scipy_minimize + +jax.config.update("jax_enable_x64", True) + + +# ── Core reClAMM mechanics ───────────────────────────────────────────────── + + +def compute_VA_from_VB(RA, RB, VB, Q): + """Contract rule (eq 15): VA = RA*(VB + RB) / ((Q-1)*VB - RB).""" + return RA * (VB + RB) / ((Q - 1) * VB - RB) + + +def compute_Z(VA, VB, P): + """Z = sqrt(P)*VA - VB/sqrt(P) (eq 12).""" + sqP = np.sqrt(P) + return sqP * VA - VB / sqP + + +def pool_value(RA, RB, P): + """Real pool value: P*RA + RB (eq 3).""" + return P * RA + RB + + +def micro_step(RA, RB, VA_new, VB_new, P): + """Virtual-balance update then arb to equilibrium Y/X = P. + + Returns (RA_new, RB_new, arb_loss). + """ + val_before = pool_value(RA, RB, P) + X = RA + VA_new + Y = RB + VB_new + L = X * Y + X_eq = np.sqrt(L / P) + Y_eq = P * X_eq + RA_new = X_eq - VA_new + RB_new = Y_eq - VB_new + return RA_new, RB_new, val_before - pool_value(RA_new, RB_new, P) + + +def solve_VB_for_Z(RA, RB, Z_star, Q, P): + """Solve quadratic for VB achieving Z(VB) = Z_star. + + Derived by substituting VA = RA*(VB+RB)/((Q-1)*VB-RB) into + Z = sqrt(P)*VA - VB/sqrt(P), then collecting terms in VB. + + NOTE: The research note (eq 28) has a sign error: the RB/sqrt(P) + term in b should be positive, not negative. Re-derived here from + scratch. + + Returns the physically valid root (VB > RB/(Q-1), positive). + Raises ValueError if no valid root exists. + """ + sqP = np.sqrt(P) + a = -(Q - 1) / sqP + b = sqP * RA + RB / sqP - (Q - 1) * Z_star # +RB/sqP, not minus + c = sqP * RA * RB + Z_star * RB + disc = b * b - 4 * a * c + if disc < -1e-6: + raise ValueError(f"negative discriminant: {disc:.4e}") + disc = max(disc, 0.0) + sd = np.sqrt(disc) + r1, r2 = (-b + sd) / (2 * a), (-b - sd) / (2 * a) + floor = RB / (Q - 1) + 1e-12 + ok = [r for r in (r1, r2) if r > floor] + if not ok: + raise ValueError(f"no valid root: r1={r1:.4f}, r2={r2:.4f}, floor={floor:.4f}") + return min(ok) + + +# ── Interpolation methods ────────────────────────────────────────────────── + + +def run_shift(RA, RB, VA_stale, VB_start, VB_end, Q, P, N, schedule): + """Execute N-step range shift (B overvalued, VB decreasing). + + schedule: "geometric" | "linear_VB" | "linear_Z" + + VA_stale: the current (possibly stale) VA -- used only for Z_start + in the linear_Z schedule. All micro-steps compute VA from the + contract rule with current reserves. + """ + # For linear_Z, precompute Z endpoints using contract-rule VA + if schedule == "linear_Z": + VA_start_cr = compute_VA_from_VB(RA, RB, VB_start, Q) + Z0 = compute_Z(VA_start_cr, VB_start, P) + VA_end_approx = compute_VA_from_VB(RA, RB, VB_end, Q) + Z_end = compute_Z(VA_end_approx, VB_end, P) + + total_loss = 0.0 + RA_c, RB_c = RA, RB + + for i in range(1, N + 1): + frac = i / N + if schedule == "geometric": + VB_i = VB_start * (VB_end / VB_start) ** frac + elif schedule == "linear_VB": + VB_i = VB_start + frac * (VB_end - VB_start) + elif schedule == "linear_Z": + Z_i = Z0 + frac * (Z_end - Z0) + VB_i = solve_VB_for_Z(RA_c, RB_c, Z_i, Q, P) + else: + raise ValueError(schedule) + + VA_i = compute_VA_from_VB(RA_c, RB_c, VB_i, Q) + RA_c, RB_c, loss = micro_step(RA_c, RB_c, VA_i, VB_i, P) + total_loss += loss + + return total_loss, RA_c, RB_c + + +def run_shift_optimal_2step(RA, RB, VA_stale, VB_start, VB_end, Q, P): + """Exact 2-step optimal midpoint (Section 5 of the note). + + Computes Z* = (Z_start + Z_end) / 2, solves quadratic for VB_mid. + """ + VA_start_cr = compute_VA_from_VB(RA, RB, VB_start, Q) + Z0 = compute_Z(VA_start_cr, VB_start, P) + VA_end_approx = compute_VA_from_VB(RA, RB, VB_end, Q) + Z2 = compute_Z(VA_end_approx, VB_end, P) + Z_star = (Z0 + Z2) / 2.0 + + # Step 1: jump to Z-midpoint + VB_mid = solve_VB_for_Z(RA, RB, Z_star, Q, P) + VA_mid = compute_VA_from_VB(RA, RB, VB_mid, Q) + RA1, RB1, loss1 = micro_step(RA, RB, VA_mid, VB_mid, P) + + # Step 2: jump to endpoint + VA_end = compute_VA_from_VB(RA1, RB1, VB_end, Q) + RA2, RB2, loss2 = micro_step(RA1, RB1, VA_end, VB_end, P) + + return loss1 + loss2, RA2, RB2 + + +# ── Scenario setup ───────────────────────────────────────────────────────── + + +def setup_centered_pool(P, price_ratio, R_scale=10000.0): + """Centered pool at price P with contract-rule-consistent virtuals. + + Returns (RA, RB, VA, VB, Q). + """ + Q = np.sqrt(price_ratio) + q4 = price_ratio ** 0.25 + + RA = R_scale + RB = P * R_scale + VA = RA / (q4 - 1) + VB = RB / (q4 - 1) + + return RA, RB, VA, VB, Q + + +def setup_decentered_pool(P_init, P_final, price_ratio, R_scale=10000.0): + """Centered pool at P_init, arb to P_final, then refresh virtuals. + + The refresh applies the contract rule to get consistent (VA, VB) at + the post-arb reserves, then arbs once more. This gives a decentered + but fully consistent state (equilibrium + contract rule). + + Returns (RA, RB, VA, VB, Q). + """ + Q = np.sqrt(price_ratio) + q4 = price_ratio ** 0.25 + + RA0 = R_scale + RB0 = P_init * R_scale + VA0 = RA0 / (q4 - 1) + VB0 = RB0 / (q4 - 1) + + # Arb to P_final (L preserved, virtuals stale) + X0 = RA0 + VA0 + Y0 = RB0 + VB0 + L = X0 * Y0 + X_new = np.sqrt(L / P_final) + Y_new = np.sqrt(L * P_final) + RA = X_new - VA0 + RB = Y_new - VB0 + + # Refresh: apply contract rule for current VB, then arb + VB = VB0 + VA = compute_VA_from_VB(RA, RB, VB, Q) + RA, RB, _ = micro_step(RA, RB, VA, VB, P_final) + + return RA, RB, VA, VB, Q + + +# ── JAX-differentiable versions for brute-force optimisation ────────────── + + +def _compute_VA_from_VB_jax(RA, RB, VB, Q): + return RA * (VB + RB) / ((Q - 1) * VB - RB) + + +def _compute_Z_jax(VA, VB, P): + sqP = jnp.sqrt(P) + return sqP * VA - VB / sqP + + +def _pool_value_jax(RA, RB, P): + return P * RA + RB + + +def _micro_step_jax(RA, RB, VA, VB, P): + val_before = _pool_value_jax(RA, RB, P) + X = RA + VA + Y = RB + VB + L = X * Y + X_eq = jnp.sqrt(L / P) + Y_eq = P * X_eq + RA_new = X_eq - VA + RB_new = Y_eq - VB + return RA_new, RB_new, val_before - _pool_value_jax(RA_new, RB_new, P) + + +def _solve_VB_for_Z_jax(RA, RB, Z_star, Q, P): + sqP = jnp.sqrt(P) + a = -(Q - 1) / sqP + b = sqP * RA + RB / sqP - (Q - 1) * Z_star + c = sqP * RA * RB + Z_star * RB + disc = jnp.maximum(b * b - 4 * a * c, 1e-30) + sd = jnp.sqrt(disc) + r1 = (-b + sd) / (2 * a) + r2 = (-b - sd) / (2 * a) + floor = RB / (Q - 1) + 1e-8 + return jnp.where(r2 > floor, r2, r1) + + +def _z_targets_from_raw(raw_params, Z_start, Z_end): + """Map unconstrained params -> sorted Z targets via softplus gaps.""" + gaps = jax.nn.softplus(raw_params) + gaps = gaps / jnp.sum(gaps) * (Z_end - Z_start) + return Z_start + jnp.cumsum(gaps) + + +def _make_loss_fn(N): + """Build a JIT-compiled loss function for a given N (unrolled loop).""" + + def total_loss(raw_params, RA, RB, Q, P, Z_start, Z_end): + Z_all = _z_targets_from_raw(raw_params, Z_start, Z_end) + RA_c, RB_c = RA, RB + total = 0.0 + for i in range(N): + VB_i = _solve_VB_for_Z_jax(RA_c, RB_c, Z_all[i], Q, P) + VA_i = _compute_VA_from_VB_jax(RA_c, RB_c, VB_i, Q) + RA_c, RB_c, loss = _micro_step_jax(RA_c, RB_c, VA_i, VB_i, P) + total = total + loss + return total + + return jax.jit(jax.value_and_grad(total_loss)) + + +def optimise_z_targets(RA, RB, Q, P, Z_start, Z_end, N, verbose=False): + """Find the Z-target sequence minimising total arb loss. + + Returns (optimal_loss, optimal_Z_targets_array_of_length_N). + """ + loss_and_grad_fn = _make_loss_fn(N) + RA_j = jnp.float64(RA) + RB_j = jnp.float64(RB) + Q_j = jnp.float64(Q) + P_j = jnp.float64(P) + Zs_j = jnp.float64(Z_start) + Ze_j = jnp.float64(Z_end) + + def objective(x): + val, grad = loss_and_grad_fn( + jnp.array(x, dtype=jnp.float64), RA_j, RB_j, Q_j, P_j, Zs_j, Ze_j + ) + return float(val), np.array(grad, dtype=np.float64) + + x0 = np.zeros(N) # softplus(0) = ln2, uniform gaps → linear Z init + result = scipy_minimize(objective, x0, jac=True, method="L-BFGS-B") + + optimal_Z = np.array( + _z_targets_from_raw(jnp.array(result.x), Zs_j, Ze_j) + ) + if verbose: + print(f" N={N}: loss={result.fun:.6f} " + f"nit={result.nit} success={result.success}") + return result.fun, optimal_Z + + +# ── Experiments ──────────────────────────────────────────────────────────── + + +def main(): + # --- Scenario: centered pool, moderate VB decay --- + P = 2.0 # token A costs 2 units of token B + price_ratio = 4.0 # rho, so Q = sqrt(4) = 2 + R_scale = 10000.0 + decay_fraction = 0.90 # VB_end = 0.90 * VB_start (10% decay) + + RA, RB, VA, VB, Q = setup_centered_pool(P, price_ratio, R_scale) + VB_start = VB + VB_end = VB * decay_fraction + + # Diagnostics + C = min(RA * VB, RB * VA) / max(RA * VB, RB * VA) + is_above = RA * VB > RB * VA + X = RA + VA + print("=" * 72) + print(f"Scenario: centered pool at P={P}, price_ratio={price_ratio}, Q={Q:.4f}") + print(f" RA={RA:.2f} RB={RB:.2f} VA={VA:.2f} VB={VB:.2f}") + print(f" Effective X={X:.2f} Pool value = {pool_value(RA, RB, P):.2f}") + print(f" Centeredness = {C:.4f} is_above = {is_above}") + print(f" VB shift: {VB_start:.2f} -> {VB_end:.2f} ({decay_fraction:.0%})") + VB_floor = RB / (Q - 1) + print(f" VB floor (denominator > 0): {VB_floor:.2f}") + Z_start = compute_Z(VA, VB, P) + VA_end_cr = compute_VA_from_VB(RA, RB, VB_end, Q) + Z_end = compute_Z(VA_end_cr, VB_end, P) + print(f" Z_start = {Z_start:.4f} Z_end = {Z_end:.4f}") + print(f" Approx 1-step loss ~ (DeltaZ)^2/(4X) = {(Z_end-Z_start)**2/(4*X):.2f}") + print("=" * 72) + + # ── Experiment 1: Loss vs N ──────────────────────────────────────── + + N_values = [1, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48, 64, 96, 128] + schedules = ["geometric", "linear_VB", "linear_Z"] + results = {s: [] for s in schedules} + + for N in N_values: + for sched in schedules: + try: + loss, _, _ = run_shift( + RA, RB, VA, VB_start, VB_end, Q, P, N, sched + ) + except (ValueError, AssertionError) as e: + loss = np.nan + results[sched].append(loss) + + # Optimal 2-step (single point) + try: + loss_opt2, _, _ = run_shift_optimal_2step( + RA, RB, VA, VB_start, VB_end, Q, P + ) + except (ValueError, AssertionError): + loss_opt2 = np.nan + + # Table + loss_1 = results["geometric"][0] + print(f"\n{'N':>5s} {'Geo VB':>12s} {'Lin VB':>12s} {'Lin Z':>12s}" + f" {'Geo/1step':>9s} {'LinZ/1step':>10s} {'LinZ/Geo':>9s}") + print("-" * 80) + for j, N in enumerate(N_values): + g = results["geometric"][j] + lv = results["linear_VB"][j] + lz = results["linear_Z"][j] + print(f"{N:>5d} {g:>12.6f} {lv:>12.6f} {lz:>12.6f}" + f" {g / loss_1:>9.4f} {lz / loss_1:>10.4f} {lz / g:>9.4f}") + + print(f"\n Optimal 2-step loss: {loss_opt2:.6f}") + print(f" Geometric N=2 loss: {results['geometric'][1]:.6f}" + f" (opt/geo = {loss_opt2 / results['geometric'][1]:.4f})") + print(f" Linear Z N=2 loss: {results['linear_Z'][1]:.6f}" + f" (opt/linZ = {loss_opt2 / results['linear_Z'][1]:.4f})") + + # ── Experiment 2: Z and VB trajectories at N=8 ───────────────────── + + N_viz = 8 + traj_data = {} + for sched in schedules: + VB_traj, Z_traj, loss_traj = [VB_start], [], [] + VA_s = VA # stale + Z_traj.append(compute_Z(VA_s, VB_start, P)) + + RA_c, RB_c = RA, RB + if sched == "linear_Z": + Z0 = Z_traj[0] + VA_end_a = compute_VA_from_VB(RA, RB, VB_end, Q) + Z_end_val = compute_Z(VA_end_a, VB_end, P) + + for i in range(1, N_viz + 1): + frac = i / N_viz + if sched == "geometric": + VB_i = VB_start * (VB_end / VB_start) ** frac + elif sched == "linear_VB": + VB_i = VB_start + frac * (VB_end - VB_start) + else: + Z_i = Z0 + frac * (Z_end_val - Z0) + VB_i = solve_VB_for_Z(RA_c, RB_c, Z_i, Q, P) + + try: + VA_i = compute_VA_from_VB(RA_c, RB_c, VB_i, Q) + VB_traj.append(VB_i) + Z_traj.append(compute_Z(VA_i, VB_i, P)) + RA_c, RB_c, loss = micro_step(RA_c, RB_c, VA_i, VB_i, P) + loss_traj.append(loss) + except (ValueError, AssertionError): + break + + traj_data[sched] = { + "VB": np.array(VB_traj), + "Z": np.array(Z_traj), + "loss": np.array(loss_traj), + } + + # ── Experiment 3: sweep shift size at N=2 ────────────────────────── + + decay_sweep = np.linspace(0.80, 0.99, 30) + sweep = {s: [] for s in ["geometric", "linear_Z", "optimal_2step"]} + for df in decay_sweep: + VB_e = VB * df + try: + g, _, _ = run_shift(RA, RB, VA, VB_start, VB_e, Q, P, 2, "geometric") + lz, _, _ = run_shift(RA, RB, VA, VB_start, VB_e, Q, P, 2, "linear_Z") + o2, _, _ = run_shift_optimal_2step(RA, RB, VA, VB_start, VB_e, Q, P) + except (AssertionError, ValueError): + g = lz = o2 = np.nan + sweep["geometric"].append(g) + sweep["linear_Z"].append(lz) + sweep["optimal_2step"].append(o2) + + # ── Plots ────────────────────────────────────────────────────────── + + colours = {"geometric": "C0", "linear_VB": "C1", "linear_Z": "C2"} + labels = { + "geometric": "Geometric VB (contract)", + "linear_VB": "Linear VB", + "linear_Z": "Linear Z (optimal)", + } + + fig, axes = plt.subplots(2, 2, figsize=(13, 10)) + + # (0,0) Loss vs N + ax = axes[0, 0] + for s in schedules: + ax.plot(N_values, results[s], "o-", ms=4, color=colours[s], label=labels[s]) + ax.axhline(loss_opt2, color="C3", ls=":", label=f"Optimal 2-step = {loss_opt2:.4f}") + ax.set_xlabel("Steps N") + ax.set_ylabel("Total arb loss") + ax.set_title("Arb loss vs interpolation steps") + ax.set_xscale("log") + ax.set_yscale("log") + ax.legend(fontsize=7) + ax.grid(True, alpha=0.3) + + # (0,1) Ratio linear_Z / geometric + ax = axes[0, 1] + ratios = np.array(results["linear_Z"]) / np.array(results["geometric"]) + ax.plot(N_values, ratios, "o-", color="C2") + ax.axhline(1.0, color="gray", ls="--", alpha=0.5) + ax.set_xlabel("Steps N") + ax.set_ylabel("Loss(Linear Z) / Loss(Geometric VB)") + ax.set_title("Relative improvement of Z-optimal") + ax.grid(True, alpha=0.3) + + # (1,0) Z trajectories at N=8 + ax = axes[1, 0] + steps = np.arange(N_viz + 1) + for s in schedules: + ax.plot(steps, traj_data[s]["Z"], "o-", ms=4, color=colours[s], label=labels[s]) + ax.set_xlabel("Step") + ax.set_ylabel("Z = sqrt(P)*VA - VB/sqrt(P)") + ax.set_title(f"Z trajectory (N={N_viz})") + ax.legend(fontsize=7) + ax.grid(True, alpha=0.3) + + # (1,1) 2-step loss vs shift size + ax = axes[1, 1] + shift_pct = (1 - decay_sweep) * 100 + ax.plot(shift_pct, sweep["geometric"], color="C0", label="Geometric VB (N=2)") + ax.plot(shift_pct, sweep["linear_Z"], color="C2", label="Linear Z (N=2)") + ax.plot(shift_pct, sweep["optimal_2step"], ":", color="C3", label="Optimal 2-step") + ax.set_xlabel("Shift size (% VB decay)") + ax.set_ylabel("Arb loss") + ax.set_title("2-step loss vs shift magnitude") + ax.legend(fontsize=7) + ax.grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig("reclamm_interpolation_benchmark.png", dpi=150) + print("\nSaved reclamm_interpolation_benchmark.png") + + # ── Per-step loss bar chart for N=8 ──────────────────────────────── + + fig2, ax = plt.subplots(figsize=(10, 5)) + x = np.arange(1, N_viz + 1) + w = 0.25 + for i, s in enumerate(schedules): + ax.bar(x + i * w, traj_data[s]["loss"], w, color=colours[s], label=labels[s]) + ax.set_xlabel("Step") + ax.set_ylabel("Per-step arb loss") + ax.set_title(f"Per-step loss distribution (N={N_viz})") + ax.legend(fontsize=8) + ax.set_xticks(x + w) + plt.tight_layout() + plt.savefig("reclamm_interpolation_perstep.png", dpi=150) + print("Saved reclamm_interpolation_perstep.png") + + # ── Experiment 4: small-shift regime (paper's approximation valid) ─── + + print("\n" + "=" * 72) + print("Experiment 4: Optimal 2-step vs Geometric N=2 at small shifts") + print(" (reserves nearly constant → paper's analysis should hold)") + print("-" * 72) + print(f" {'Decay %':>8s} {'Geo N=2':>12s} {'LinZ N=2':>12s} " + f"{'Opt2':>12s} {'Opt2/Geo':>9s} {'Opt2/LinZ':>9s}") + print("-" * 72) + + small_decays = [0.999, 0.998, 0.995, 0.99, 0.98, 0.95, 0.90, 0.80] + for df in small_decays: + VB_e = VB * df + try: + g, _, _ = run_shift(RA, RB, VA, VB_start, VB_e, Q, P, 2, "geometric") + lz, _, _ = run_shift(RA, RB, VA, VB_start, VB_e, Q, P, 2, "linear_Z") + o2, _, _ = run_shift_optimal_2step( + RA, RB, VA, VB_start, VB_e, Q, P + ) + except (ValueError, AssertionError) as e: + print(f" {(1-df)*100:>7.1f}% FAILED: {e}") + continue + print(f" {(1-df)*100:>7.1f}% {g:>12.6f} {lz:>12.6f} " + f"{o2:>12.6f} {o2/g:>9.6f} {o2/lz:>9.6f}") + + print("=" * 72) + + # ── Experiment 5: brute-force JAX-optimised Z targets ──────────────── + + print("\n" + "=" * 72) + print("Experiment 5: Brute-force optimal Z targets (JAX + L-BFGS-B)") + print(" Parameterisation: softplus gaps → sorted Z targets") + print(" Initialised at linear Z (uniform gaps)") + print("-" * 72) + + opt_N_values = [2, 3, 4, 6, 8, 12, 16, 24, 32] + opt_losses = {} + opt_Z_trajs = {} + + for N in opt_N_values: + loss_bf, Z_bf = optimise_z_targets( + RA, RB, Q, P, Z_start, Z_end, N, verbose=True + ) + opt_losses[N] = loss_bf + opt_Z_trajs[N] = Z_bf + + # Comparison table + print(f"\n {'N':>5s} {'Geometric':>12s} {'Linear Z':>12s} " + f"{'BF Optimal':>12s} {'BF/LinZ':>9s} {'BF/Geo':>9s}") + print("-" * 72) + for N in opt_N_values: + idx = N_values.index(N) if N in N_values else None + g = results["geometric"][idx] if idx is not None else np.nan + lz = results["linear_Z"][idx] if idx is not None else np.nan + bf = opt_losses[N] + print(f" {N:>5d} {g:>12.6f} {lz:>12.6f} " + f"{bf:>12.6f} {bf/lz:>9.6f} {bf/g:>9.6f}") + + # ── Plot: overlay brute-force on the main loss-vs-N chart ──────────── + + fig3, axes3 = plt.subplots(1, 2, figsize=(14, 5)) + + # (left) Loss vs N with brute-force overlay + ax = axes3[0] + for s in schedules: + ax.plot(N_values, results[s], "o-", ms=4, color=colours[s], + label=labels[s]) + bf_Ns = sorted(opt_losses.keys()) + bf_vals = [opt_losses[n] for n in bf_Ns] + ax.plot(bf_Ns, bf_vals, "s--", ms=5, color="C3", label="BF Optimal (JAX)") + ax.set_xlabel("Steps N") + ax.set_ylabel("Total arb loss") + ax.set_title("Arb loss vs interpolation steps (with BF optimal)") + ax.set_xscale("log") + ax.set_yscale("log") + ax.legend(fontsize=7) + ax.grid(True, alpha=0.3) + + # (right) Z trajectory comparison at N=8 + ax = axes3[1] + N_cmp = 8 + steps_cmp = np.arange(N_cmp + 1) + + # Geometric: compute Z trajectory from VB + z_geo = [Z_start] + RA_t, RB_t = RA, RB + for i in range(1, N_cmp + 1): + frac = i / N_cmp + VB_i = VB_start * (VB_end / VB_start) ** frac + VA_i = compute_VA_from_VB(RA_t, RB_t, VB_i, Q) + z_geo.append(compute_Z(VA_i, VB_i, P)) + RA_t, RB_t, _ = micro_step(RA_t, RB_t, VA_i, VB_i, P) + + # Linear Z + z_linz = [Z_start] + RA_t, RB_t = RA, RB + for i in range(1, N_cmp + 1): + frac = i / N_cmp + Z_i = Z_start + frac * (Z_end - Z_start) + VB_i = solve_VB_for_Z(RA_t, RB_t, Z_i, Q, P) + VA_i = compute_VA_from_VB(RA_t, RB_t, VB_i, Q) + z_linz.append(compute_Z(VA_i, VB_i, P)) + RA_t, RB_t, _ = micro_step(RA_t, RB_t, VA_i, VB_i, P) + + # BF optimal + z_bf = [Z_start] + list(opt_Z_trajs[N_cmp]) + # Trace actual Z achieved after arb at each step + z_bf_actual = [Z_start] + RA_t, RB_t = RA, RB + for i in range(N_cmp): + VB_i = solve_VB_for_Z(RA_t, RB_t, opt_Z_trajs[N_cmp][i], Q, P) + VA_i = compute_VA_from_VB(RA_t, RB_t, VB_i, Q) + z_bf_actual.append(compute_Z(VA_i, VB_i, P)) + RA_t, RB_t, _ = micro_step(RA_t, RB_t, VA_i, VB_i, P) + + ax.plot(steps_cmp, z_geo, "o-", ms=4, color="C0", label="Geometric VB") + ax.plot(steps_cmp, z_linz, "o-", ms=4, color="C2", label="Linear Z") + ax.plot(steps_cmp, z_bf_actual, "s--", ms=5, color="C3", + label="BF Optimal") + ax.plot(steps_cmp, np.linspace(Z_start, Z_end, N_cmp + 1), + ":", color="gray", alpha=0.5, label="Ideal linear Z") + ax.set_xlabel("Step") + ax.set_ylabel("Z = sqrt(P)*VA - VB/sqrt(P)") + ax.set_title(f"Z trajectory comparison (N={N_cmp})") + ax.legend(fontsize=7) + ax.grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig("reclamm_interpolation_bruteforce.png", dpi=150) + print("\nSaved reclamm_interpolation_bruteforce.png") + + +if __name__ == "__main__": + main() diff --git a/scripts/build_pool_grids.py b/scripts/build_pool_grids.py new file mode 100644 index 0000000..9b710c2 --- /dev/null +++ b/scripts/build_pool_grids.py @@ -0,0 +1,695 @@ +"""Build 2D arb-volume grids (cadence x gas) for all Binance-matchable pools. + +v2: Per-day daily arb volumes at each grid point, correct pool weights, + pool type dispatch (WEIGHTED vs RECLAMM). + +For each real Balancer pool where both tokens have Binance minute data: + - Uses actual LP supply trajectory (BPT totalShares from panel) + - Uses actual initial TVL from panel + - Uses correct pool weights from pools.parquet + - Dispatches reCLAMM pools with on-chain params from pools_history.db + - Sweeps cadence x gas_cost as scalar grid + - Stores per-day V_arb at each grid point (not aggregated) + +Output: results/pool_grids_v2/{pool_id_prefix}_daily.parquet + summary CSV. + +Usage: + python scripts/build_pool_grids.py + python scripts/build_pool_grids.py --workers 6 --train-days 90 +""" + +import os +os.environ.setdefault("JAX_PLATFORMS", "cpu") + +import argparse +import ast +import sqlite3 +import time +from concurrent.futures import ProcessPoolExecutor, as_completed + +import jax.numpy as jnp +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from quantammsim.runners.jax_runners import do_run_on_historic_data +from quantammsim.utils.data_processing.historic_data_utils import get_historic_parquet_data + +# ── Output ──────────────────────────────────────────────────────────────── +OUTPUT_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "results", "pool_grids_v2", +) + +# ── Grid ────────────────────────────────────────────────────────────────── +CADENCES = [1, 2, 3, 5, 8, 12, 20, 30, 45, 60] + +# Finer gas grids — concentrated in $0.1-$3.0 where most fitted values land. +# Mainnet: 19 points (was 10). Fills the $0.25→$1.50 gap that caused +# zero-arb artifacts on low-volatility days. +GAS_COSTS_MAINNET = [ + 0.0, 0.05, 0.1, 0.15, 0.25, 0.35, 0.5, 0.65, 0.8, + 1.0, 1.25, 1.5, 2.0, 3.0, 5.0, 8.0, 12.0, 20.0, 50.0, +] +# L2: 14 points (was 10). Fills $0.05→$0.50 range. +GAS_COSTS_L2 = [ + 0.0, 0.001, 0.003, 0.005, 0.01, 0.02, 0.05, + 0.1, 0.15, 0.25, 0.5, 0.75, 1.0, 2.0, +] + +# ── Token mapping ───────────────────────────────────────────────────────── +# Maps Balancer pool token symbols to Binance trading symbols. +# Validated via Balancer hourly vs Binance minute price comparison: +# all wrappers below have daily return correlation > 0.75 with their +# underlying, or are stablecoins (basis < 0.5%). +TOKEN_MAP = { + # Wrapped natives (corr > 0.96) + "WBTC": "BTC", "WETH": "ETH", "cbBTC": "BTC", + # ETH LSTs / Aave wrappers (corr 0.87-0.97) + "wstETH": "ETH", "stETH": "ETH", "rETH": "ETH", "cbETH": "ETH", + "waEthLidoWETH": "ETH", "waEthLidowstETH": "ETH", + "waBasWETH": "ETH", # Aave wrapped WETH on Base (corr 0.975) + "waGnowstETH": "ETH", # Aave wrapped wstETH on Gnosis (corr 0.976) + # GNO wrappers (corr 0.66-0.98) + "waGnoGNO": "GNO", # Aave wrapped GNO on Gnosis (corr 0.979) + "osGNO": "GNO", # StakeWise staked GNO (corr 0.755) + # S (Sonic) wrappers (corr 0.945) + "wS": "S", + "stS": "S", # Staked Sonic + # SOL LSTs (corr 0.922) + "JitoSOL": "SOL", # Jito staked SOL + # POL/MATIC variants (corr 0.96) + "wPOL": "POL", "WMATIC": "POL", "MATIC": "POL", + # Stablecoin equivalents (all ~$1.00, basis < 0.5%) + "USDC.e": "USDC", "USDbC": "USDC", "waBasUSDC": "USDC", + "DAI": "USDC", # Stale Binance data; basis < 10bps vs USDC + "WXDAI": "USDC", "sDAI": "USDC", # Gnosis DAI variants → USDC + "USDT": "USDC", + "DOLA": "USDC", + "scUSD": "USDC", +} + +# ── reCLAMM pool → DB table mapping ────────────────────────────────────── +RECLAMM_DB_PATH = "/Users/matthew/Projects/reclamm-simulations/data/pools_history.db" +RECLAMM_POOL_TABLE_MAP = { + "0x9d1fcf346ea1b0": "AAVE_WETH", +} + + +def _get_binance_tokens(): + """Get set of tokens with Binance minute parquets.""" + data_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "quantammsim", "data", + ) + tokens = set() + for f in os.listdir(data_dir): + if f.endswith("_USD.parquet"): + tokens.add(f.replace("_USD.parquet", "")) + return tokens + + +def _map_token(tok, binance_tokens): + """Map a Balancer token symbol to Binance symbol, or None.""" + mapped = TOKEN_MAP.get(tok, tok) + return mapped if mapped in binance_tokens else None + + +def gas_costs_for_chain(chain): + return GAS_COSTS_L2 if chain != "MAINNET" else GAS_COSTS_MAINNET + + +def _parse_weights(weights_raw): + """Parse weights from pools.parquet (numpy array of strings or list).""" + if weights_raw is None: + return None + if isinstance(weights_raw, str): + weights_raw = ast.literal_eval(weights_raw) + if hasattr(weights_raw, 'tolist'): + weights_raw = weights_raw.tolist() + parsed = [] + for w in weights_raw: + if w is None or str(w).lower() == 'none': + return None + parsed.append(float(w)) + return parsed + + +def _parse_tokens(tokens_raw): + """Parse tokens from pools.parquet.""" + if isinstance(tokens_raw, str): + try: + return ast.literal_eval(tokens_raw) + except (ValueError, SyntaxError): + return [t.strip() for t in tokens_raw.split(",")] + if hasattr(tokens_raw, 'tolist'): + return tokens_raw.tolist() + return list(tokens_raw) + + +# ── Pool metadata loading ───────────────────────────────────────────────── + +def load_pools_metadata(): + """Load pools.parquet to get weights, pool_type, and true token count.""" + pools_path = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "local_data", "noise_calibration", "pools.parquet", + ) + pools = pd.read_parquet(pools_path) + meta = {} + for _, row in pools.iterrows(): + pool_id = row["pool_id"] + prefix = pool_id[:16] + tokens = _parse_tokens(row["tokens"]) + weights = _parse_weights(row["weights"]) + meta[prefix] = { + "pool_id": pool_id, + "pool_type": row["pool_type"], + "tokens_full": tokens, + "n_tokens": len(tokens), + "weights": weights, + } + return meta + + +def load_reclamm_params(pool_id_prefix): + """Load reCLAMM on-chain params from pools_history.db.""" + table_name = RECLAMM_POOL_TABLE_MAP.get(pool_id_prefix) + if table_name is None: + return None + if not os.path.exists(RECLAMM_DB_PATH): + return None + conn = sqlite3.connect(RECLAMM_DB_PATH) + try: + df = pd.read_sql(f"SELECT * FROM [{table_name}] ORDER BY timestamp DESC LIMIT 1", conn) + if len(df) == 0: + return None + row = df.iloc[0] + return { + "price_ratio": float(row["price_ratio"]), + "centeredness_margin": float(row["margin"]), + "shift_exponent": float(row["shift_rate"]), + "swap_fee": float(row["swap_fee"]), + } + except Exception: + return None + finally: + conn.close() + + +# ── Panel loading and pool matching ─────────────────────────────────────── + +def load_panel_and_match(train_days): + """Load panel, filter to last N days, find matchable 2-token pools.""" + panel_path = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "local_data", "noise_calibration", "panel.parquet", + ) + panel = pd.read_parquet(panel_path) + + if "obs_date" not in panel.columns and "date" in panel.columns: + panel = panel.rename(columns={"date": "obs_date"}) + panel["obs_date"] = pd.to_datetime(panel["obs_date"]) + + if "tvl" not in panel.columns and "log_tvl" in panel.columns: + panel["tvl"] = np.exp(panel["log_tvl"]) + + cutoff = panel["obs_date"].max() - pd.Timedelta(days=train_days) + panel = panel[panel["obs_date"] >= cutoff].copy() + + binance_tokens = _get_binance_tokens() + pools_meta = load_pools_metadata() + + pools = [] + for pool_id, grp in panel.groupby("pool_id"): + prefix = pool_id[:16] + meta = pools_meta.get(prefix) + if meta is None: + continue + + # Skip multi-token pools + if meta["n_tokens"] > 2: + continue + + pool_type = meta["pool_type"] + + # Get tokens for Binance matching from panel + row = grp.iloc[0] + tokens_str = row["tokens"] + toks = [t.strip() for t in tokens_str.split(",")] + if len(toks) != 2: + continue + + mapped = [] + for t in toks: + m = _map_token(t, binance_tokens) + if m is None: + break + mapped.append(m) + else: + if mapped[0] == mapped[1]: + continue + + chain = row["chain"] + fee = row.get("swap_fee", np.exp(row["log_fee"])) + + # Get weights + weights = meta["weights"] + if pool_type == "WEIGHTED" and weights is None: + weights = [0.5, 0.5] + + # reCLAMM params + reclamm_params = None + if pool_type == "RECLAMM": + reclamm_params = load_reclamm_params(prefix) + if reclamm_params is None: + print(f" SKIP reCLAMM {'/'.join(mapped)} ({chain}): " + f"no DB params for {prefix}") + continue + + pools.append({ + "pool_id": pool_id, + "pool_id_prefix": prefix, + "tokens": mapped, + "chain": chain, + "fee": float(fee), + "panel_data": grp, + "pool_type": pool_type, + "weights": weights, + "reclamm_params": reclamm_params, + }) + + return pools + + +def build_lp_supply_df(panel_pool): + """Build lp_supply_df from daily BPT supply. Returns (df, initial_tvl).""" + panel_pool = panel_pool.sort_values("obs_date") + + has_bpt = ( + "total_shares" in panel_pool.columns + and not panel_pool["total_shares"].isna().all() + and (panel_pool["total_shares"] > 0).any() + ) + + initial_tvl = float(panel_pool["tvl"].iloc[0]) if len(panel_pool) > 0 else 0.0 + + if not has_bpt: + return None, initial_tvl + + bpt = panel_pool[["obs_date", "total_shares", "tvl"]].drop_duplicates("obs_date") + initial_bpt = bpt["total_shares"].iloc[0] + + if initial_bpt <= 0 or initial_tvl <= 0: + return None, initial_tvl + + unix_ms = bpt["obs_date"].apply( + lambda d: int(pd.Timestamp(d).timestamp() * 1000) + ).values + lp_supply = (bpt["total_shares"].values / initial_bpt).astype(float) + + return pd.DataFrame({"unix": unix_ms, "lp_supply": lp_supply}), initial_tvl + + +def get_date_range(panel_data): + """Get start/end date strings from panel data.""" + dates = panel_data["obs_date"].sort_values() + start = dates.iloc[0].strftime("%Y-%m-%d %H:%M:%S") + end = dates.iloc[-1].strftime("%Y-%m-%d %H:%M:%S") + return start, end + + +# ── Simulation ──────────────────────────────────────────────────────────── + +def run_arb_sim(tokens, fee, initial_tvl, start, end, cadence, gas_cost, + lp_supply_df=None, weights=None, pool_type="WEIGHTED", + reclamm_params=None, price_data=None): + """Run arb-only sim at one (cadence, gas_cost) point. + + Returns: pd.Series with date index → daily arb volume. + """ + fp = { + "tokens": tokens, + "startDateString": start, + "endDateString": end, + "initial_pool_value": initial_tvl, + "fees": fee, + "gas_cost": float(gas_cost), + "arb_fees": 0.0, + "do_arb": True, + "noise_trader_ratio": 0.0, + "arb_frequency": int(cadence), + "chunk_period": 1440, + "weight_interpolation_period": 1440, + } + + if pool_type == "RECLAMM" and reclamm_params is not None: + fp["rule"] = "reclamm" + fp["reclamm_use_shift_exponent"] = True + fp["fees"] = reclamm_params["swap_fee"] + params = { + "price_ratio": jnp.array(reclamm_params["price_ratio"]), + "centeredness_margin": jnp.array(reclamm_params["centeredness_margin"]), + "shift_exponent": jnp.array(reclamm_params["shift_exponent"]), + } + else: + fp["rule"] = "balancer" + if weights is not None and len(weights) == 2: + logits = np.log(np.array(weights, dtype=float)) + params = {"initial_weights_logits": jnp.array(logits)} + else: + params = {"initial_weights_logits": jnp.array([0.0, 0.0])} + + result = do_run_on_historic_data( + fp, params, lp_supply_df=lp_supply_df, verbose=False, + price_data=price_data, + ) + + reserves = np.array(result["reserves"]) + prices = np.array(result["data_dict"]["prices"]) + unix_ms = np.array(result["data_dict"]["unix_values"]) + start_idx = int(result["data_dict"]["start_idx"]) + + T = reserves.shape[0] - 1 + prices_window = prices[start_idx:start_idx + T + 1] + delta_r = np.diff(reserves, axis=0) + step_vol = np.sum(np.abs(delta_r * prices_window[1:]), axis=1) / 2.0 + + dates = pd.to_datetime( + unix_ms[start_idx + 1:start_idx + T + 1], unit="ms", + ).normalize() + daily = pd.DataFrame( + {"date": dates, "volume": step_vol}, + ).groupby("date")["volume"].sum() + + return daily + + +def _run_cadence_sweep(pool_info, cadence, gas_costs): + """Worker: sweep all gas costs for one (pool, cadence). + + Returns list of dicts with per-day data: one dict per (gas, date) pair, + plus a summary list. + """ + os.environ.setdefault("JAX_PLATFORMS", "cpu") + + tokens = pool_info["tokens"] + fee = pool_info["fee"] + initial_tvl = pool_info["initial_tvl"] + start = pool_info["start"] + end = pool_info["end"] + lp_supply_df = pool_info["lp_supply_df"] + weights = pool_info.get("weights") + pool_type = pool_info.get("pool_type", "WEIGHTED") + reclamm_params = pool_info.get("reclamm_params") + + price_data = pool_info.get("price_data") + + daily_rows = [] + summary_rows = [] + + for gas in gas_costs: + try: + daily = run_arb_sim( + tokens, fee, initial_tvl, start, end, cadence, gas, + lp_supply_df=lp_supply_df, + weights=weights, + pool_type=pool_type, + reclamm_params=reclamm_params, + price_data=price_data, + ) + # Store per-day data + for date_val, vol in daily.items(): + daily_rows.append({ + "cadence": cadence, + "gas_cost": gas, + "date": date_val, + "daily_arb_volume": vol, + }) + # Summary for diagnostics + summary_rows.append({ + "cadence": cadence, + "gas_cost": gas, + "total_arb_volume": daily.sum(), + "median_daily_arb_volume": daily.median(), + "mean_daily_arb_volume": daily.mean(), + "n_days": len(daily), + }) + except Exception as e: + summary_rows.append({ + "cadence": cadence, + "gas_cost": gas, + "total_arb_volume": np.nan, + "median_daily_arb_volume": np.nan, + "mean_daily_arb_volume": np.nan, + "n_days": 0, + "error": str(e), + }) + + return daily_rows, summary_rows + + +# ── Plotting ────────────────────────────────────────────────────────────── + +def plot_pool_grid(summary_df, pool_id, tokens, chain, fee, tvl, pool_type, + gas_costs, output_dir): + """Simple 2-panel diagnostic: V_arb vs cadence, gas attenuation.""" + fig, axes = plt.subplots(1, 3, figsize=(16, 4.5)) + df = summary_df + + ax = axes[0] + for gas in gas_costs: + sub = df[df["gas_cost"] == gas].sort_values("cadence") + if len(sub) > 0: + ax.plot(sub["cadence"], sub["median_daily_arb_volume"], + "o-", label=f"${gas}", markersize=3) + ax.set_xlabel("Cadence (min)") + ax.set_ylabel("Median daily V_arb ($)") + ax.set_title("V_arb vs cadence") + ax.legend(fontsize=5, ncol=2, title="gas") + + ax = axes[1] + for cadence in CADENCES: + sub = df[df["cadence"] == cadence].sort_values("gas_cost") + v0 = sub[sub["gas_cost"] == 0.0]["median_daily_arb_volume"].values + if len(v0) > 0 and v0[0] > 0: + ratio = sub["median_daily_arb_volume"].values / v0[0] + ax.plot(sub["gas_cost"].values, ratio, "o-", + label=f"{cadence}min", markersize=3) + ax.set_xlabel("Gas cost ($)") + ax.set_ylabel("V_arb / V_arb(gas=0)") + ax.set_title("Gas attenuation") + ax.set_ylim(-0.05, 1.05) + ax.legend(fontsize=5, ncol=2) + + ax = axes[2] + for gas in gas_costs: + sub = df[df["gas_cost"] == gas].sort_values("cadence") + vals = sub["median_daily_arb_volume"].values + if len(vals) > 0 and np.all(vals > 0): + ax.plot(sub["cadence"], vals, "o-", label=f"${gas}", markersize=3) + ax.set_xscale("log") + ax.set_yscale("log") + ax.set_xlabel("Cadence (min)") + ax.set_ylabel("Median daily V_arb ($)") + ax.set_title("Log-log") + ax.legend(fontsize=5, ncol=2, title="gas") + + tok_str = "/".join(tokens) + type_str = f" [{pool_type}]" if pool_type != "WEIGHTED" else "" + fig.suptitle( + f"{tok_str} ({chain}, fee={fee:.2%}, TVL=${tvl:,.0f}){type_str}\n" + f"{pool_id[:16]}", + fontsize=10, + ) + fig.tight_layout() + path = os.path.join(output_dir, f"{pool_id[:16]}_grid.png") + fig.savefig(path, dpi=120, bbox_inches="tight") + plt.close() + + +# ── Main ────────────────────────────────────────────────────────────────── + +def main(): + parser = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--workers", type=int, default=6) + parser.add_argument("--train-days", type=int, default=90) + parser.add_argument("--pools", type=str, default=None, + help="Comma-separated pool_id prefixes to run (default: all)") + args = parser.parse_args() + + os.makedirs(OUTPUT_DIR, exist_ok=True) + + print("Loading panel and matching pools...") + pools = load_panel_and_match(args.train_days) + print(f"Found {len(pools)} matchable 2-token pools\n") + + for p in pools: + lp_df, tvl = build_lp_supply_df(p["panel_data"]) + p["lp_supply_df"] = lp_df + p["initial_tvl"] = tvl + p["start"], p["end"] = get_date_range(p["panel_data"]) + + pools = [p for p in pools if p["panel_data"]["obs_date"].nunique() >= 14] + pools.sort(key=lambda p: p["initial_tvl"], reverse=True) + + # Filter to specific pools if requested + if args.pools: + requested = set(args.pools.split(",")) + pools = [p for p in pools if p["pool_id_prefix"] in requested] + print(f"Filtered to {len(pools)} requested pools\n") + + # Print pool summary + print(f"{'#':>3} {'Tokens':<12} {'Chain':<10} {'Type':<9} {'Weights':<10} " + f"{'Fee':>6} {'TVL':>12} {'Days':>5} {'Pool ID':<18}") + print("-" * 95) + for i, p in enumerate(pools): + n_days = p["panel_data"]["obs_date"].nunique() + w_str = "/".join(f"{w:.0%}" for w in p["weights"]) if p["weights"] else "N/A" + print(f"{i+1:3d} {'/'.join(p['tokens']):<12} {p['chain']:<10} " + f"{p['pool_type']:<9} {w_str:<10} " + f"{p['fee']:5.2%} ${p['initial_tvl']:>10,.0f} " + f"{n_days:5d} {p['pool_id'][:16]}") + + all_summaries = [] + t_total = time.time() + + for pool_idx, pool in enumerate(pools): + pool_id = pool["pool_id"] + prefix = pool["pool_id_prefix"] + tokens = pool["tokens"] + chain = pool["chain"] + fee = pool["fee"] + tvl = pool["initial_tvl"] + pool_type = pool["pool_type"] + gas_costs = gas_costs_for_chain(chain) + n_runs = len(CADENCES) * len(gas_costs) + + if tvl <= 0: + print(f"\n SKIP {'/'.join(tokens)} ({chain}): TVL=0") + continue + + w_str = "/".join(f"{w:.0%}" for w in pool["weights"]) if pool["weights"] else "N/A" + print(f"\n{'='*60}") + print(f" [{pool_idx+1}/{len(pools)}] {'/'.join(tokens)} " + f"({chain}, {pool_type}, {w_str}, fee={fee:.2%}, TVL=${tvl:,.0f})") + print(f" Grid: {len(CADENCES)} cadences x {len(gas_costs)} gas = {n_runs} runs") + print(f"{'='*60}") + + t0 = time.time() + + # Preload price data once per pool (avoids re-reading parquet per run) + sorted_tokens = sorted(tokens) + price_data = get_historic_parquet_data(sorted_tokens, ["close"]) + + pool_info = { + "tokens": tokens, + "fee": fee, + "initial_tvl": tvl, + "start": pool["start"], + "end": pool["end"], + "lp_supply_df": pool["lp_supply_df"], + "weights": pool["weights"], + "pool_type": pool_type, + "reclamm_params": pool.get("reclamm_params"), + "price_data": price_data, + } + + all_daily_rows = [] + all_summary_rows = [] + + if args.workers <= 1: + for cadence in CADENCES: + daily_rows, summary_rows = _run_cadence_sweep( + pool_info, cadence, gas_costs, + ) + all_daily_rows.extend(daily_rows) + all_summary_rows.extend(summary_rows) + for r in summary_rows: + print(f" cad={r['cadence']:3d} gas=${r['gas_cost']:6.3f} -> " + f"median=${r['median_daily_arb_volume']:,.0f}/day") + else: + futures = {} + with ProcessPoolExecutor(max_workers=args.workers) as executor: + for cadence in CADENCES: + fut = executor.submit( + _run_cadence_sweep, pool_info, cadence, gas_costs, + ) + futures[fut] = cadence + + done = 0 + for fut in as_completed(futures): + cadence = futures[fut] + done += 1 + try: + daily_rows, summary_rows = fut.result() + all_daily_rows.extend(daily_rows) + all_summary_rows.extend(summary_rows) + medians = [r["median_daily_arb_volume"] for r in summary_rows + if not np.isnan(r.get("median_daily_arb_volume", np.nan))] + if medians: + print(f" [{done:2d}/{len(CADENCES)}] cad={cadence:3d} — " + f"V_arb: ${min(medians):,.0f} – ${max(medians):,.0f}/day") + else: + print(f" [{done:2d}/{len(CADENCES)}] cad={cadence:3d} — " + f"all failed") + except Exception as e: + print(f" [{done:2d}/{len(CADENCES)}] cad={cadence:3d} FAILED: {e}") + + elapsed = time.time() - t0 + print(f" {n_runs} runs in {elapsed:.1f}s ({elapsed/max(n_runs,1):.2f}s/run)") + + # Save per-day parquet + if all_daily_rows: + daily_df = pd.DataFrame(all_daily_rows) + daily_df["date"] = pd.to_datetime(daily_df["date"]) + parquet_path = os.path.join(OUTPUT_DIR, f"{prefix}_daily.parquet") + daily_df.to_parquet(parquet_path, index=False) + print(f" Saved {len(daily_df)} daily rows -> {parquet_path}") + + # Save summary CSV for diagnostics + summary_df = pd.DataFrame(all_summary_rows) + csv_path = os.path.join(OUTPUT_DIR, f"{prefix}_summary.csv") + summary_df.to_csv(csv_path, index=False) + + # Plot + if len(summary_df) > 0: + plot_pool_grid(summary_df, pool_id, tokens, chain, fee, tvl, + pool_type, gas_costs, OUTPUT_DIR) + + # Global summary + g0 = summary_df[summary_df["gas_cost"] == 0.0] if len(summary_df) > 0 else pd.DataFrame() + all_summaries.append({ + "pool_id": prefix, + "tokens": "/".join(tokens), + "chain": chain, + "pool_type": pool_type, + "weights": str(pool["weights"]), + "fee": fee, + "tvl": tvl, + "n_days": summary_df["n_days"].max() if len(summary_df) > 0 else 0, + "n_daily_rows": len(all_daily_rows), + "v_arb_cad1_gas0": g0[g0["cadence"] == 1]["median_daily_arb_volume"].values[0] + if len(g0[g0["cadence"] == 1]) > 0 else np.nan, + "v_arb_cad60_gas0": g0[g0["cadence"] == 60]["median_daily_arb_volume"].values[0] + if len(g0[g0["cadence"] == 60]) > 0 else np.nan, + "elapsed_s": elapsed, + }) + + total_elapsed = time.time() - t_total + + if all_summaries: + summary_df = pd.DataFrame(all_summaries) + summary_path = os.path.join(OUTPUT_DIR, "grid_summary.csv") + summary_df.to_csv(summary_path, index=False) + print(f"\n{'='*60}") + print(f" Summary saved: {summary_path}") + print(f" Total: {len(all_summaries)} pools, " + f"{total_elapsed:.0f}s ({total_elapsed/60:.1f} min)") + print(f"{'='*60}") + + +if __name__ == "__main__": + main() diff --git a/scripts/calibrate_noise_bayesian.py b/scripts/calibrate_noise_bayesian.py new file mode 100644 index 0000000..efdfb4b --- /dev/null +++ b/scripts/calibrate_noise_bayesian.py @@ -0,0 +1,853 @@ +"""Bayesian hierarchical noise volume model across Balancer pools. + +Full Bayesian version of the noise calibration: ALL K=4 per-pool +coefficients (intercept, TVL elasticity, volatility response, weekend +effect) vary per pool with pool-level covariates modulating their priors, +and an LKJ-decomposed covariance capturing correlations between +coefficients. + +Generative model: + For pool i with pool-level covariates z_i, day t: + + mu_i = B . z_i # K-vector population mean + eta_i ~ N(0, I_K) # non-centered offsets + theta_i = mu_i + diag(sigma) . L . eta_i # per-pool coefficients + + log(V_{i,t}) ~ N(theta_i . x_{i,t}, sigma_eps^2) + + theta_i = [intercept_i, b_tvl_i, b_sigma_i, b_weekend_i] + x_{i,t} = [1, log_tvl, volatility, weekend] + z_i = [1, chain_dummies(6), tier_A_dummies(2), tier_B_dummies(2), log_fee] + +Priors: + B_{k,d} ~ N(0, 5^2) + sigma_k ~ HalfNormal(2.0) + L ~ LKJCholesky(K=4, eta=2) + sigma_eps ~ HalfNormal(3.0) + +Usage: + # Full pipeline: fit + output + diagnostics + python scripts/calibrate_noise_bayesian.py \\ + --fit --output results/bayesian_noise_params.json --plot + + # Predict for an unseen pool + python scripts/calibrate_noise_bayesian.py \\ + --predict --chain BASE --tokens ETH USDC --fee 0.003 + + # Custom NUTS settings + python scripts/calibrate_noise_bayesian.py \\ + --fit --num-warmup 2000 --num-samples 4000 --num-chains 4 +""" + +import argparse +import json +import os +import sys + +import numpy as np +import pandas as pd + +# arviz 0.17.x imports scipy.signal.gaussian which was removed in scipy 1.13+. +# Patch it back from scipy.signal.windows before any arviz import. +try: + from scipy.signal import gaussian as _ # noqa: F401 +except ImportError: + from scipy.signal.windows import gaussian as _gauss + import scipy.signal + scipy.signal.gaussian = _gauss + +# --------------------------------------------------------------------------- +# Reuse constants and helpers from the frequentist script +# --------------------------------------------------------------------------- + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "local_data", "noise_calibration" +) + +# Reference levels for dummy coding (dropped categories) +REF_CHAIN = "ARBITRUM" +REF_TIER = 0 + +# Ordered non-reference chains (alphabetical excluding REF_CHAIN) +CHAIN_ORDER = ["BASE", "GNOSIS", "MAINNET", "OPTIMISM", "POLYGON", "SONIC"] + +K = 4 # number of per-pool coefficients +D = 12 # pool-level covariate dimension: 1 + 6 chains + 2 tier_A + 2 tier_B + 1 log_fee + +COEFF_NAMES = ["intercept", "b_tvl", "b_sigma", "b_weekend"] + + +# --------------------------------------------------------------------------- +# Token tier helpers (duplicated to avoid import fragility) +# --------------------------------------------------------------------------- + +_TIER_0 = { + "ETH", "WETH", "BTC", "WBTC", "cbBTC", "USDC", "USDT", "DAI", + "wstETH", "stETH", "rETH", "cbETH", "WMATIC", "MATIC", "POL", + "WAVAX", "AVAX", "GNO", "WXDAI", "xDAI", + "S", "wS", +} + +_TIER_1 = { + "AAVE", "LINK", "UNI", "BAL", "MKR", "CRV", "COMP", "SNX", + "LDO", "RPL", "SUSHI", "YFI", "1INCH", "ENS", "DYDX", + "FXS", "FRAX", "LUSD", "sDAI", "GHO", "crvUSD", + "ARB", "OP", "PENDLE", "ENA", "EIGEN", + "SAFE", "COW", +} + + +def classify_token_tier(symbol: str) -> int: + s = symbol.strip() + if s in _TIER_0: + return 0 + if s in _TIER_1: + return 1 + return 2 + + +# --------------------------------------------------------------------------- +# Data preparation +# --------------------------------------------------------------------------- + +def load_panel(cache_dir: str = CACHE_DIR) -> pd.DataFrame: + """Load the cached panel parquet produced by calibrate_noise_hierarchical.py --fetch.""" + panel_path = os.path.join(cache_dir, "panel.parquet") + if not os.path.exists(panel_path): + print(f"ERROR: Panel cache not found at {panel_path}", file=sys.stderr) + print("Run: python scripts/calibrate_noise_hierarchical.py --fetch", file=sys.stderr) + sys.exit(1) + + panel = pd.read_parquet(panel_path) + + # Filter pools with < 10 observations + pool_counts = panel.groupby("pool_id").size() + valid_pools = pool_counts[pool_counts >= 10].index + panel = panel[panel["pool_id"].isin(valid_pools)].copy() + + print(f" Loaded panel: {len(panel)} obs, " + f"{panel['pool_id'].nunique()} pools, " + f"{panel['chain'].nunique()} chains") + return panel + + +def _build_z_pool(pool_meta: pd.DataFrame) -> np.ndarray: + """Build (N_pools, D) pool-level covariate matrix. + + Columns: [1, chain_BASE, ..., chain_SONIC (6), + tier_A_1, tier_A_2, tier_B_1, tier_B_2, log_fee] + """ + N = len(pool_meta) + z = np.zeros((N, D), dtype=np.float64) + + # Intercept + z[:, 0] = 1.0 + + # Chain dummies (columns 1-6) + for j, chain in enumerate(CHAIN_ORDER): + z[:, 1 + j] = (pool_meta["chain"].values == chain).astype(float) + + # tier_A dummies (columns 7-8): tiers 1 and 2, reference = 0 + tier_a = pool_meta["tier_A"].values.astype(int) + z[:, 7] = (tier_a == 1).astype(float) + z[:, 8] = (tier_a == 2).astype(float) + + # tier_B dummies (columns 9-10): tiers 1 and 2, reference = 0 + tier_b = pool_meta["tier_B"].values.astype(int) + z[:, 9] = (tier_b == 1).astype(float) + z[:, 10] = (tier_b == 2).astype(float) + + # log_fee (column 11) + z[:, 11] = np.log(np.maximum(pool_meta["swap_fee"].values.astype(float), 1e-6)) + + return z + + +def prepare_data(panel: pd.DataFrame) -> dict: + """Construct JAX-ready arrays from the panel DataFrame. + + Returns dict with: + pool_idx : (N_obs,) int32 — pool index per observation + z_pool : (N_pools, D) float64 — pool-level covariates + x_obs : (N_obs, K) float64 — within-day regressors + y_obs : (N_obs,) float64 — log_volume + pool_ids : list — ordered pool IDs + pool_meta : DataFrame — per-pool metadata (indexed same as z_pool rows) + """ + # Stable pool ordering + pool_ids = sorted(panel["pool_id"].unique()) + pool_id_to_idx = {pid: i for i, pid in enumerate(pool_ids)} + N_pools = len(pool_ids) + + # Pool-level metadata (one row per pool) + pool_meta = panel.drop_duplicates("pool_id").set_index("pool_id").loc[pool_ids].reset_index() + z_pool = _build_z_pool(pool_meta) + + # Observation-level arrays + pool_idx = panel["pool_id"].map(pool_id_to_idx).values.astype(np.int32) + x_obs = np.column_stack([ + np.ones(len(panel)), + panel["log_tvl"].values, + panel["volatility"].values, + panel["weekend"].values, + ]).astype(np.float64) + y_obs = panel["log_volume"].values.astype(np.float64) + + print(f" Prepared: N_obs={len(y_obs)}, N_pools={N_pools}, K={K}, D={D}") + print(f" z_pool range check — log_fee: [{z_pool[:, 11].min():.2f}, {z_pool[:, 11].max():.2f}]") + + return { + "pool_idx": pool_idx, + "z_pool": z_pool, + "x_obs": x_obs, + "y_obs": y_obs, + "pool_ids": pool_ids, + "pool_meta": pool_meta, + "N_pools": N_pools, + } + + +# --------------------------------------------------------------------------- +# NumPyro model +# --------------------------------------------------------------------------- + +def hierarchical_noise_model(pool_idx, z_pool, x_obs, y_obs=None, + N_pools=None, K=4, D=12): + """Bayesian hierarchical noise volume model. + + Non-centered parameterization with LKJ correlation prior. + """ + import jax.numpy as jnp + import numpyro + import numpyro.distributions as dist + + N_obs = pool_idx.shape[0] + + # --- Population coefficient matrix B: (K, D) --- + B = numpyro.sample("B", dist.Normal(0.0, 5.0).expand([K, D]).to_event(2)) + + # --- Per-pool scale and correlation --- + sigma = numpyro.sample("sigma", dist.HalfNormal(2.0).expand([K]).to_event(1)) + L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(K, concentration=2.0)) + + # Cholesky factor of covariance: diag(sigma) @ L_Omega + L_Sigma = jnp.diag(sigma) @ L_Omega # (K, K) + + # --- Non-centered pool effects --- + with numpyro.plate("pools", N_pools): + eta = numpyro.sample("eta", dist.Normal(0.0, 1.0).expand([K]).to_event(1)) + + # theta_i = B @ z_i + L_Sigma @ eta_i for each pool i + # mu: (N_pools, K) = z_pool @ B^T + mu = z_pool @ B.T # (N_pools, K) + theta = mu + eta @ L_Sigma.T # (N_pools, K) + + # --- Observation model --- + sigma_eps = numpyro.sample("sigma_eps", dist.HalfNormal(3.0)) + + # Predicted log-volume: theta[pool_idx] . x_obs (dot product per obs) + theta_obs = theta[pool_idx] # (N_obs, K) + mu_obs = jnp.sum(theta_obs * x_obs, axis=1) # (N_obs,) + + with numpyro.plate("obs", N_obs): + numpyro.sample("y", dist.Normal(mu_obs, sigma_eps), obs=y_obs) + + # Deterministic: store theta for extraction + numpyro.deterministic("theta", theta) + + +# --------------------------------------------------------------------------- +# Inference +# --------------------------------------------------------------------------- + +def run_inference(data, num_warmup=1000, num_samples=2000, num_chains=4, + target_accept=0.85, max_tree_depth=10, seed=42): + """Run NUTS on the hierarchical model. + + Returns the MCMC object with samples. + """ + import jax + import jax.numpy as jnp + import numpyro + from numpyro.infer import MCMC, NUTS + + # Use all available CPU cores for chains + numpyro.set_host_device_count(min(num_chains, len(jax.devices("cpu")))) + + kernel = NUTS( + hierarchical_noise_model, + target_accept_prob=target_accept, + max_tree_depth=max_tree_depth, + ) + mcmc = MCMC( + kernel, + num_warmup=num_warmup, + num_samples=num_samples, + num_chains=num_chains, + progress_bar=True, + ) + + rng_key = jax.random.PRNGKey(seed) + + print(f"\n Running NUTS: {num_chains} chains x " + f"({num_warmup} warmup + {num_samples} samples)") + print(f" target_accept={target_accept}, max_tree_depth={max_tree_depth}") + + mcmc.run( + rng_key, + pool_idx=jnp.array(data["pool_idx"]), + z_pool=jnp.array(data["z_pool"]), + x_obs=jnp.array(data["x_obs"]), + y_obs=jnp.array(data["y_obs"]), + N_pools=data["N_pools"], + K=K, + D=D, + ) + + mcmc.print_summary(exclude_deterministic=True) + return mcmc + + +# --------------------------------------------------------------------------- +# Post-processing +# --------------------------------------------------------------------------- + +def extract_noise_params(mcmc, data) -> list: + """Extract per-pool noise params from MCMC posterior. + + Reconstructs theta from the non-centered parameterization, + takes posterior medians, and applies weekend absorption: + b_0_effective = b_0_raw + b_weekend * (2/7) + + Returns list of dicts compatible with reclamm_loglinear_noise_volume. + """ + samples = mcmc.get_samples() + theta_samples = samples["theta"] # (n_samples, N_pools, K) + + # Posterior median per pool + theta_median = np.median(theta_samples, axis=0) # (N_pools, K) + theta_std = np.std(theta_samples, axis=0) # (N_pools, K) + + pool_ids = data["pool_ids"] + pool_meta = data["pool_meta"] + + results = [] + for i, pool_id in enumerate(pool_ids): + meta = pool_meta.iloc[i] + b_0_raw, b_tvl, b_sigma, b_weekend = theta_median[i] + std_vals = theta_std[i] + + # Weekend absorption: simulator has no weekend indicator, + # so fold the expected weekend effect into the intercept. + # Weekend days = 2/7 of all days. + b_0_effective = b_0_raw + b_weekend * (2.0 / 7.0) + + tokens = meta["tokens"] + if isinstance(tokens, str): + tokens = tokens.split(",") + + results.append({ + "pool_id": pool_id, + "chain": str(meta["chain"]), + "tokens": tokens, + "theta_median": [float(x) for x in theta_median[i]], + "theta_std": [float(x) for x in std_vals], + "noise_params": { + "b_0": float(b_0_effective), + "b_sigma": float(b_sigma), + "b_c": float(b_tvl), + "b_weekend": float(b_weekend), + "base_fee": float(meta["swap_fee"]), + }, + }) + + return results + + +def predict_new_pool(mcmc, data, chain: str, tokens: list, fee: float) -> dict: + """Predict noise params for an unseen pool using population effects. + + Constructs z_new, computes mu_new = B @ z_new across posterior samples, + and returns median + 90% credible intervals. + """ + # Build z_new + z_new = np.zeros(D, dtype=np.float64) + z_new[0] = 1.0 # intercept + + # Chain dummies + if chain in CHAIN_ORDER: + j = CHAIN_ORDER.index(chain) + z_new[1 + j] = 1.0 + + # Tier dummies + tiers = sorted([classify_token_tier(t) for t in tokens]) + tier_a = tiers[0] + tier_b = tiers[1] if len(tiers) > 1 else tiers[0] + if tier_a == 1: + z_new[7] = 1.0 + elif tier_a == 2: + z_new[8] = 1.0 + if tier_b == 1: + z_new[9] = 1.0 + elif tier_b == 2: + z_new[10] = 1.0 + + # log_fee + z_new[11] = np.log(max(fee, 1e-6)) + + # Compute mu_new = B @ z_new across all posterior samples + B_samples = np.array(mcmc.get_samples()["B"]) # (n_samples, K, D) + mu_samples = np.einsum("skd,d->sk", B_samples, z_new) # (n_samples, K) + + mu_median = np.median(mu_samples, axis=0) + mu_q05 = np.percentile(mu_samples, 5, axis=0) + mu_q95 = np.percentile(mu_samples, 95, axis=0) + + # Weekend absorption + b_0_raw, b_tvl, b_sigma, b_weekend = mu_median + b_0_effective = b_0_raw + b_weekend * (2.0 / 7.0) + + result = { + "chain": chain, + "tokens": tokens, + "fee": fee, + "prediction_source": "population_level", + "noise_params": { + "b_0": float(b_0_effective), + "b_sigma": float(b_sigma), + "b_c": float(b_tvl), + "b_weekend": float(b_weekend), + "base_fee": float(fee), + }, + "credible_intervals_90": { + name: { + "median": float(mu_median[k]), + "q05": float(mu_q05[k]), + "q95": float(mu_q95[k]), + } + for k, name in enumerate(COEFF_NAMES) + }, + } + + print(f"\n Predicted noise_params for {chain} {tokens} (fee={fee}):") + for name, ci in result["credible_intervals_90"].items(): + print(f" {name:12s}: {ci['median']:+.3f} " + f"[{ci['q05']:+.3f}, {ci['q95']:+.3f}]") + print(f"\n Effective b_0 (weekend-absorbed): {b_0_effective:.3f}") + + return result + + +# --------------------------------------------------------------------------- +# Diagnostics +# --------------------------------------------------------------------------- + +def check_convergence(mcmc) -> dict: + """Compute convergence diagnostics: R-hat, ESS, divergences.""" + import arviz as az + + idata = az.from_numpyro(mcmc) + + # R-hat and ESS for non-deterministic parameters + n_chains = idata.posterior.sizes.get("chain", 1) + + rhat_max = float("nan") + if n_chains >= 2: + rhat = az.rhat(idata) + rhat_vals = [] + for var in rhat.data_vars: + if var == "theta": + continue # deterministic + vals = rhat[var].values + rhat_vals.extend(vals.flatten()) + rhat_max = float(np.nanmax(rhat_vals)) if rhat_vals else float("nan") + + ess = az.ess(idata) + ess_vals = [] + for var in ess.data_vars: + if var == "theta": + continue + vals = ess[var].values + ess_vals.extend(vals.flatten()) + ess_min = float(np.nanmin(ess_vals)) if ess_vals else float("nan") + + # Divergences + divergences = int(idata.sample_stats["diverging"].sum().values) + + print(f"\n Convergence diagnostics:") + if n_chains >= 2: + print(f" R-hat max: {rhat_max:.4f} {'OK' if rhat_max < 1.05 else 'WARNING'}") + else: + print(f" R-hat max: N/A (need >= 2 chains)") + print(f" ESS min: {ess_min:.0f} {'OK' if ess_min > 400 else 'WARNING'}") + print(f" Divergences: {divergences} {'OK' if divergences == 0 else 'WARNING'}") + + return { + "r_hat_max": rhat_max, + "ess_min": ess_min, + "divergences": divergences, + } + + +def plot_bayesian_diagnostics(mcmc, data, output_dir="results"): + """Generate ArviZ diagnostic plots.""" + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + import arviz as az + + os.makedirs(output_dir, exist_ok=True) + idata = az.from_numpyro(mcmc) + samples = mcmc.get_samples() + + # --- 1. Trace plots for sigma, sigma_eps --- + axes = az.plot_trace(idata, var_names=["sigma", "sigma_eps"], compact=True) + fig1 = axes.ravel()[0].figure + fig1.set_size_inches(14, 8) + path1 = os.path.join(output_dir, "bayesian_trace_sigma.png") + fig1.savefig(path1, dpi=150, bbox_inches="tight") + plt.close(fig1) + print(f" Saved: {path1}") + + # --- 2. Posterior predictive: predicted vs observed --- + theta_samples = samples["theta"] # (S, N_pools, K) + sigma_eps_samples = np.array(samples["sigma_eps"]) # (S,) + theta_median = np.median(theta_samples, axis=0) # (N_pools, K) + + pool_idx = data["pool_idx"] + x_obs = data["x_obs"] + y_obs = data["y_obs"] + + theta_obs = theta_median[pool_idx] + y_pred = np.sum(theta_obs * x_obs, axis=1) + + fig, axes = plt.subplots(1, 2, figsize=(14, 6)) + + ax = axes[0] + ax.scatter(y_obs, y_pred, alpha=0.1, s=4, color="steelblue") + lims = [min(y_obs.min(), y_pred.min()), max(y_obs.max(), y_pred.max())] + ax.plot(lims, lims, "r--", linewidth=1) + ax.set_xlabel("Observed log(volume)") + ax.set_ylabel("Predicted log(volume)") + ax.set_title("Posterior predictive check") + r2 = 1 - np.var(y_obs - y_pred) / np.var(y_obs) + ax.text(0.05, 0.95, f"R² = {r2:.3f}", transform=ax.transAxes, + fontsize=11, verticalalignment="top") + + ax = axes[1] + residuals = y_obs - y_pred + ax.hist(residuals, bins=60, color="steelblue", edgecolor="white", alpha=0.8) + ax.axvline(0, color="red", linestyle="--") + ax.set_xlabel("Residual") + ax.set_title(f"Residual distribution (σ_ε ≈ {np.median(sigma_eps_samples):.2f})") + + plt.tight_layout() + path2 = os.path.join(output_dir, "bayesian_posterior_predictive.png") + plt.savefig(path2, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path2}") + + # --- 3. Per-pool b_c (TVL elasticity) by chain/tier --- + pool_meta = data["pool_meta"] + b_tvl_all = theta_median[:, 1] # index 1 = b_tvl + + fig, axes = plt.subplots(1, 2, figsize=(14, 6)) + + ax = axes[0] + chains_present = sorted(pool_meta["chain"].unique()) + chain_data = [] + chain_labels = [] + for c in chains_present: + mask = pool_meta["chain"].values == c + if mask.sum() > 0: + chain_data.append(b_tvl_all[mask]) + chain_labels.append(f"{c}\n(n={mask.sum()})") + ax.boxplot(chain_data, tick_labels=chain_labels, vert=True) + ax.axhline(1.0, color="red", linestyle="--", linewidth=0.8, alpha=0.6) + ax.set_ylabel("Per-pool b_c (TVL elasticity)") + ax.set_title("TVL elasticity by chain") + + ax = axes[1] + # By tier_A + tier_a_vals = pool_meta["tier_A"].values.astype(int) + tier_labels_map = {0: "Blue-chip", 1: "Mid-cap", 2: "Long-tail"} + tier_data = [] + tier_labels = [] + for t in [0, 1, 2]: + mask = tier_a_vals == t + if mask.sum() > 0: + tier_data.append(b_tvl_all[mask]) + tier_labels.append(f"{tier_labels_map[t]}\n(n={mask.sum()})") + ax.boxplot(tier_data, tick_labels=tier_labels, vert=True) + ax.axhline(1.0, color="red", linestyle="--", linewidth=0.8, alpha=0.6) + ax.set_ylabel("Per-pool b_c (TVL elasticity)") + ax.set_title("TVL elasticity by token tier (best token)") + + plt.tight_layout() + path3 = os.path.join(output_dir, "bayesian_per_pool_b_c.png") + plt.savefig(path3, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path3}") + + # --- 4. Correlation matrix posterior --- + L_Omega_samples = np.array(samples["L_Omega"]) # (S, K, K) + # Correlation = L @ L^T + Omega_samples = np.einsum("sij,skj->sik", L_Omega_samples, L_Omega_samples) + Omega_median = np.median(Omega_samples, axis=0) + + fig, ax = plt.subplots(figsize=(7, 6)) + im = ax.imshow(Omega_median, vmin=-1, vmax=1, cmap="RdBu_r") + ax.set_xticks(range(K)) + ax.set_yticks(range(K)) + ax.set_xticklabels(COEFF_NAMES, rotation=45, ha="right") + ax.set_yticklabels(COEFF_NAMES) + for i in range(K): + for j in range(K): + ax.text(j, i, f"{Omega_median[i, j]:.2f}", ha="center", va="center", + fontsize=10, color="white" if abs(Omega_median[i, j]) > 0.5 else "black") + plt.colorbar(im, ax=ax, shrink=0.8) + ax.set_title("Posterior median correlation matrix (Ω)") + plt.tight_layout() + path4 = os.path.join(output_dir, "bayesian_correlation_matrix.png") + plt.savefig(path4, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path4}") + + # --- 5. Shrinkage plot: OLS b_c vs hierarchical b_c --- + # Compute per-pool OLS b_c for comparison + panel_meta = data["pool_meta"] + pool_idx_arr = data["pool_idx"] + pool_ids = data["pool_ids"] + + ols_b_c = np.zeros(len(pool_ids)) + for i, pid in enumerate(pool_ids): + mask = pool_idx_arr == i + if mask.sum() < 5: + ols_b_c[i] = np.nan + continue + x_i = data["x_obs"][mask] + y_i = data["y_obs"][mask] + # Simple OLS: y = X @ beta + try: + beta, _, _, _ = np.linalg.lstsq(x_i, y_i, rcond=None) + ols_b_c[i] = beta[1] # TVL coefficient + except np.linalg.LinAlgError: + ols_b_c[i] = np.nan + + hier_b_c = theta_median[:, 1] + valid = np.isfinite(ols_b_c) + + fig, ax = plt.subplots(figsize=(8, 8)) + ax.scatter(ols_b_c[valid], hier_b_c[valid], alpha=0.6, s=20, color="steelblue") + + # Population mean line + pop_b_c = np.median(hier_b_c) + ax.axhline(pop_b_c, color="red", linestyle="--", linewidth=0.8, + label=f"Population median = {pop_b_c:.3f}") + + # 45-degree line + lims = [min(np.nanmin(ols_b_c[valid]), hier_b_c[valid].min()) - 0.2, + max(np.nanmax(ols_b_c[valid]), hier_b_c[valid].max()) + 0.2] + ax.plot(lims, lims, "k:", linewidth=0.8, alpha=0.5) + ax.set_xlabel("Per-pool OLS b_c") + ax.set_ylabel("Hierarchical posterior median b_c") + ax.set_title("Shrinkage: OLS vs hierarchical TVL elasticity") + ax.legend() + plt.tight_layout() + path5 = os.path.join(output_dir, "bayesian_shrinkage_b_c.png") + plt.savefig(path5, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path5}") + + +# --------------------------------------------------------------------------- +# JSON output +# --------------------------------------------------------------------------- + +def generate_output_json(pool_params, mcmc, data, convergence, output_path, + num_warmup, num_samples, num_chains, target_accept): + """Write structured JSON output with population effects and per-pool params.""" + samples = mcmc.get_samples() + + B_median = np.median(np.array(samples["B"]), axis=0).tolist() + sigma_median = np.median(np.array(samples["sigma"]), axis=0).tolist() + sigma_eps_median = float(np.median(np.array(samples["sigma_eps"]))) + + output = { + "model": "bayesian_hierarchical_loglinear", + "inference": { + "method": "NUTS", + "num_warmup": num_warmup, + "num_samples": num_samples, + "num_chains": num_chains, + "target_accept_prob": target_accept, + }, + "population_effects": { + "B": B_median, + "sigma": sigma_median, + "sigma_eps": sigma_eps_median, + "coeff_names": COEFF_NAMES, + "covariate_names": ( + ["intercept"] + [f"chain_{c}" for c in CHAIN_ORDER] + + ["tier_A_1", "tier_A_2", "tier_B_1", "tier_B_2", "log_fee"] + ), + }, + "convergence": convergence, + "pools": { + p["pool_id"]: { + "chain": p["chain"], + "tokens": p["tokens"], + "theta_median": p["theta_median"], + "theta_std": p["theta_std"], + "noise_params": p["noise_params"], + } + for p in pool_params + }, + } + + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + with open(output_path, "w") as f: + json.dump(output, f, indent=2, default=str) + print(f" Wrote {len(pool_params)} pool params -> {output_path}") + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser( + description="Bayesian hierarchical noise volume model for Balancer pools" + ) + parser.add_argument( + "--fetch", action="store_true", + help="Fetch pool data (delegates to calibrate_noise_hierarchical.py --fetch)", + ) + parser.add_argument("--fit", action="store_true", help="Run NUTS inference") + parser.add_argument("--plot", action="store_true", help="Generate diagnostic plots") + parser.add_argument("--output", default=None, help="Output JSON path") + parser.add_argument("--output-dir", default="results", help="Plot output directory") + parser.add_argument("--predict", action="store_true", help="Predict for a new pool") + parser.add_argument("--chain", default=None, help="Chain for --predict") + parser.add_argument("--tokens", nargs="+", default=None, help="Tokens for --predict") + parser.add_argument("--fee", type=float, default=0.003, help="Fee for --predict") + parser.add_argument("--cache-dir", default=None, help="Cache directory") + + # NUTS hyperparameters + parser.add_argument("--num-warmup", type=int, default=1000) + parser.add_argument("--num-samples", type=int, default=2000) + parser.add_argument("--num-chains", type=int, default=4) + parser.add_argument("--target-accept", type=float, default=0.85) + parser.add_argument("--max-tree-depth", type=int, default=10) + parser.add_argument("--seed", type=int, default=42) + + args = parser.parse_args() + + cache_dir = args.cache_dir or CACHE_DIR + + if not any([args.fetch, args.fit, args.predict]): + parser.error("At least one of --fetch, --fit, --predict is required") + + # --- Fetch (delegate to existing script) --- + if args.fetch: + import subprocess + cmd = [ + sys.executable, "scripts/calibrate_noise_hierarchical.py", + "--fetch", "--cache-dir", cache_dir, + ] + print("Delegating data fetch to calibrate_noise_hierarchical.py...") + subprocess.run(cmd, check=True) + + # --- Fit --- + if args.fit: + print("\nBayesian Hierarchical Noise Volume Model") + print("=" * 60) + + panel = load_panel(cache_dir) + data = prepare_data(panel) + + mcmc = run_inference( + data, + num_warmup=args.num_warmup, + num_samples=args.num_samples, + num_chains=args.num_chains, + target_accept=args.target_accept, + max_tree_depth=args.max_tree_depth, + seed=args.seed, + ) + + convergence = check_convergence(mcmc) + pool_params = extract_noise_params(mcmc, data) + + # Print summary statistics + b_c_vals = [p["noise_params"]["b_c"] for p in pool_params] + b_0_vals = [p["noise_params"]["b_0"] for p in pool_params] + print(f"\n Per-pool b_c: mean={np.mean(b_c_vals):.3f}, " + f"std={np.std(b_c_vals):.3f}, " + f"range=[{np.min(b_c_vals):.3f}, {np.max(b_c_vals):.3f}]") + print(f" Per-pool b_0: mean={np.mean(b_0_vals):.3f}, " + f"std={np.std(b_0_vals):.3f}") + + if args.output: + generate_output_json( + pool_params, mcmc, data, convergence, args.output, + args.num_warmup, args.num_samples, args.num_chains, + args.target_accept, + ) + + if args.plot: + print("\nGenerating diagnostic plots...") + plot_bayesian_diagnostics(mcmc, data, output_dir=args.output_dir) + + # Save MCMC samples for --predict reuse + mcmc_cache = os.path.join(cache_dir, "bayesian_mcmc_samples.npz") + samples = mcmc.get_samples() + np.savez_compressed( + mcmc_cache, + **{k: np.array(v) for k, v in samples.items()}, + ) + # Also save data arrays for predict + data_cache = os.path.join(cache_dir, "bayesian_data.npz") + np.savez_compressed( + data_cache, + pool_idx=data["pool_idx"], + z_pool=data["z_pool"], + x_obs=data["x_obs"], + y_obs=data["y_obs"], + ) + # Save pool_ids list + with open(os.path.join(cache_dir, "bayesian_pool_ids.json"), "w") as f: + json.dump(data["pool_ids"], f) + print(f" Saved MCMC samples -> {mcmc_cache}") + + # --- Predict --- + if args.predict: + if args.chain is None or args.tokens is None: + parser.error("--predict requires --chain and --tokens") + + # Load cached MCMC samples + mcmc_cache = os.path.join(cache_dir, "bayesian_mcmc_samples.npz") + if not os.path.exists(mcmc_cache): + print(f"ERROR: MCMC cache not found at {mcmc_cache}", file=sys.stderr) + print("Run with --fit first.", file=sys.stderr) + sys.exit(1) + + # For predict, we only need B samples — create a minimal mock + cached = np.load(mcmc_cache) + + class _MockMCMC: + """Minimal interface to reuse predict_new_pool with cached samples.""" + def __init__(self, samples_dict): + self._samples = samples_dict + def get_samples(self): + return self._samples + + samples_dict = {k: cached[k] for k in cached.files} + mock_mcmc = _MockMCMC(samples_dict) + + result = predict_new_pool(mock_mcmc, None, args.chain, args.tokens, args.fee) + print(json.dumps(result, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/scripts/calibrate_noise_hierarchical.py b/scripts/calibrate_noise_hierarchical.py new file mode 100644 index 0000000..4fef642 --- /dev/null +++ b/scripts/calibrate_noise_hierarchical.py @@ -0,0 +1,1556 @@ +"""Bayesian hierarchical noise volume model across Balancer WEIGHTED + RECLAMM pools. + +Pools data cross-sectionally across all Balancer weighted/reCLAMM pools, +fits a Bayesian hierarchical model where pool covariates (chain, token tier, +fee) modulate all coefficients via group-level regression, with full +posterior inference via NumPyro. + +Model: + Hyperpriors: + Φ ~ Normal(0, 2) (K × 3) group-level regression + σ_θ ~ HalfNormal(2) (3,) per-coefficient scales + L_ω ~ LKJCholesky(3, η=2) correlation structure + β_weekend ~ Normal(0, 2) shared nuisance + σ_ε ~ HalfNormal(3) observation noise + + For each pool i: + x_i = [1, chain_dummies, tier_dummies, log_fee] (K,) covariates + z_i ~ N(0, I₃) non-centered + θ_i = Φᵀx_i + diag(σ_θ)·L_ω·z_i (α_i, β_tvl_i, β_vol_i) + + For each observation (i, t): + log(V) ~ N(α_i + β_tvl_i·log_tvl + β_vol_i·vol + β_weekend·weekend, σ²_ε) + +Usage: + # Full pipeline: fetch data + fit model + output + python scripts/calibrate_noise_hierarchical.py \\ + --fetch --fit --output results/hierarchical_noise_params.json --plot + + # Use cached data, re-fit only + python scripts/calibrate_noise_hierarchical.py \\ + --fit --output results/hierarchical_noise_params.json + + # Predict for a new pool + python scripts/calibrate_noise_hierarchical.py \\ + --predict --chain BASE --tokens ETH BTC --fee 0.003 + + # Use NUTS instead of SVI + python scripts/calibrate_noise_hierarchical.py \\ + --fit --nuts --output results/hierarchical_noise_params.json +""" + +import argparse +import json +import os +import sys +import time +import urllib.request +from datetime import datetime, timezone + +import numpy as np +import pandas as pd + +import jax +import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist +from numpyro.infer import SVI, MCMC, NUTS, Trace_ELBO, Predictive +from numpyro.infer.autoguide import AutoMultivariateNormal + +numpyro.enable_x64() + + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +BALANCER_API_URL = "https://api-v3.balancer.fi/" + +BALANCER_API_CHAINS = [ + "MAINNET", "POLYGON", "ARBITRUM", "GNOSIS", "BASE", "SONIC", "OPTIMISM", + "AVALANCHE", +] + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "local_data", "noise_calibration" +) + +# --------------------------------------------------------------------------- +# Token tier classification +# --------------------------------------------------------------------------- + +# Tier 0: blue-chip — top by volume, wrapped native, major stables +_TIER_0 = { + "ETH", "WETH", "BTC", "WBTC", "cbBTC", "USDC", "USDT", "DAI", + "wstETH", "stETH", "rETH", "cbETH", "WMATIC", "MATIC", "POL", + "WAVAX", "AVAX", "GNO", "WXDAI", "xDAI", + "S", "wS", # Sonic native +} + +# Tier 1: mid-cap DeFi blue-chips (approx CoinGecko rank < 200) +_TIER_1 = { + "AAVE", "LINK", "UNI", "BAL", "MKR", "CRV", "COMP", "SNX", + "LDO", "RPL", "SUSHI", "YFI", "1INCH", "ENS", "DYDX", + "FXS", "FRAX", "LUSD", "sDAI", "GHO", "crvUSD", + "ARB", "OP", "PENDLE", "ENA", "EIGEN", + "SAFE", "COW", +} + + +def _normalise_symbol(symbol: str) -> str: + """Normalise wrapped/bridged variants to canonical form.""" + s = symbol.strip() + # Common wrapped → unwrapped + mapping = { + "WETH": "WETH", # keep WETH as-is (it's in tier 0) + "WBTC": "WBTC", + "cbBTC": "cbBTC", + "WMATIC": "WMATIC", + "WAVAX": "WAVAX", + "WXDAI": "WXDAI", + "wS": "wS", + } + return mapping.get(s, s) + + +def classify_token_tier(symbol: str) -> int: + """Classify a token symbol into tier 0/1/2. + + Returns + ------- + int + 0 = blue-chip, 1 = mid-cap, 2 = long-tail + """ + s = _normalise_symbol(symbol) + if s in _TIER_0: + return 0 + if s in _TIER_1: + return 1 + return 2 + + +# --------------------------------------------------------------------------- +# Phase 1: API data ingestion +# --------------------------------------------------------------------------- + +def _graphql_request(query: dict, base_url: str = BALANCER_API_URL, + timeout: int = 30) -> dict: + """Send a GraphQL request to the Balancer V3 API.""" + data = json.dumps(query).encode("utf-8") + req = urllib.request.Request( + base_url, + data=data, + headers={ + "Content-Type": "application/json", + "User-Agent": "quantammsim/1.0", + }, + ) + with urllib.request.urlopen(req, timeout=timeout) as resp: + return json.loads(resp.read().decode("utf-8")) + + +def enumerate_balancer_pools( + chains: list = None, + pool_types: list = None, + min_tvl: float = 10000.0, +) -> pd.DataFrame: + """Enumerate all WEIGHTED + RECLAMM pools across chains from Balancer API. + + Parameters + ---------- + chains : list of str + API chain identifiers (e.g. ["MAINNET", "BASE"]). + pool_types : list of str + Pool type filters (e.g. ["WEIGHTED", "STABLE"]). + min_tvl : float + Minimum TVL in USD to include. + + Returns + ------- + pd.DataFrame + Columns: pool_id, chain, pool_type, tokens (list of symbols), + swap_fee, create_time, dynamic_data_tvl. + """ + if chains is None: + chains = BALANCER_API_CHAINS + if pool_types is None: + pool_types = ["WEIGHTED", "RECLAMM"] + + all_pools = [] + for chain in chains: + print(f" Querying {chain}...", end=" ", flush=True) + query = { + "query": """ + query GetPools($chain: GqlChain!, $types: [GqlPoolType!], + $minTvl: Float) { + poolGetPools( + where: { + chainIn: [$chain] + poolTypeIn: $types + minTvl: $minTvl + } + ) { + id + chain + type + createTime + protocolVersion + poolTokens { + symbol + weight + address + } + dynamicData { + totalLiquidity + swapFee + } + } + } + """, + "variables": { + "chain": chain, + "types": pool_types, + "minTvl": min_tvl, + }, + } + + try: + body = _graphql_request(query) + pools = body.get("data", {}).get("poolGetPools", []) + except Exception as e: + print(f"FAILED ({e})") + continue + + for p in pools: + tokens = [t["symbol"] for t in p.get("poolTokens", [])] + weights = [t.get("weight") for t in p.get("poolTokens", [])] + token_addresses = [t.get("address", "") for t in p.get("poolTokens", [])] + tvl = float(p.get("dynamicData", {}).get("totalLiquidity", 0)) + fee = float(p.get("dynamicData", {}).get("swapFee", 0)) + + all_pools.append({ + "pool_id": p["id"], + "chain": p["chain"], + "pool_type": p["type"], + "protocol_version": p.get("protocolVersion", 0), + "tokens": tokens, + "token_addresses": token_addresses, + "weights": weights, + "swap_fee": fee, + "create_time": p.get("createTime", 0), + "current_tvl": tvl, + }) + + print(f"{len(pools)} pools") + time.sleep(0.3) + + df = pd.DataFrame(all_pools) + print(f"\n Total: {len(df)} pools across {len(chains)} chains") + return df + + +def fetch_pool_snapshots(pool_id: str, chain: str, + base_url: str = BALANCER_API_URL) -> pd.DataFrame: + """Fetch ALL_TIME daily snapshots for a single pool. + + Returns + ------- + pd.DataFrame + Columns: timestamp, volume_usd, total_liquidity_usd. + """ + query = { + "query": """ + query GetSnapshots($poolId: String!, $chain: GqlChain!, + $range: GqlPoolSnapshotDataRange!) { + poolGetSnapshots(id: $poolId, chain: $chain, range: $range) { + timestamp + volume24h + totalLiquidity + } + } + """, + "variables": { + "poolId": pool_id, + "chain": chain, + "range": "ALL_TIME", + }, + } + + body = _graphql_request(query) + snapshots = body.get("data", {}).get("poolGetSnapshots", []) + + if not snapshots: + return pd.DataFrame(columns=["timestamp", "volume_usd", "total_liquidity_usd"]) + + records = [] + for snap in snapshots: + records.append({ + "timestamp": int(snap["timestamp"]), + "volume_usd": float(snap["volume24h"]), + "total_liquidity_usd": float(snap["totalLiquidity"]), + }) + + df = pd.DataFrame(records) + df["date"] = pd.to_datetime(df["timestamp"], unit="s").dt.date + # Deduplicate by date (keep last snapshot per day) + df = df.sort_values("timestamp").drop_duplicates("date", keep="last") + return df + + +def fetch_all_snapshots(pools_df: pd.DataFrame, + cache_path: str = None) -> pd.DataFrame: + """Fetch daily snapshots for all pools, with caching. + + Parameters + ---------- + pools_df : pd.DataFrame + Pool enumeration from enumerate_balancer_pools. + cache_path : str, optional + Path to parquet cache. If it exists, only fetch missing pools. + + Returns + ------- + pd.DataFrame + Panel with columns: pool_id, chain, date, volume_usd, + total_liquidity_usd. + """ + # Load cache if exists + cached = pd.DataFrame() + cached_pool_ids = set() + if cache_path and os.path.exists(cache_path): + cached = pd.read_parquet(cache_path) + cached_pool_ids = set(cached["pool_id"].unique()) + print(f" Cache has {len(cached_pool_ids)} pools, " + f"{len(cached)} pool-days") + + # Determine which pools need fetching + if len(pools_df) == 0: + print(" No pools to fetch.") + return cached if len(cached) > 0 else pd.DataFrame( + columns=["pool_id", "chain", "date", "volume_usd", + "total_liquidity_usd"] + ) + to_fetch = pools_df[~pools_df["pool_id"].isin(cached_pool_ids)] + print(f" Need to fetch {len(to_fetch)} new pools") + + new_records = [] + for i, (_, pool) in enumerate(to_fetch.iterrows()): + if (i + 1) % 10 == 0 or i == 0: + print(f" Fetching {i+1}/{len(to_fetch)}: {pool['pool_id'][:10]}... " + f"({pool['chain']})", flush=True) + try: + snap_df = fetch_pool_snapshots(pool["pool_id"], pool["chain"]) + if len(snap_df) > 0: + snap_df["pool_id"] = pool["pool_id"] + snap_df["chain"] = pool["chain"] + new_records.append(snap_df[ + ["pool_id", "chain", "date", "volume_usd", + "total_liquidity_usd"] + ]) + except Exception as e: + print(f" FAILED {pool['pool_id'][:10]}: {e}") + time.sleep(0.5) # Rate limit + + if new_records: + new_df = pd.concat(new_records, ignore_index=True) + combined = pd.concat([cached, new_df], ignore_index=True) + else: + combined = cached + + # Save cache + if cache_path and len(combined) > 0: + os.makedirs(os.path.dirname(cache_path), exist_ok=True) + combined.to_parquet(cache_path, index=False) + print(f" Saved cache: {len(combined)} pool-days → {cache_path}") + + return combined + + +def fetch_token_prices(token_addresses_by_chain: dict, + cache_dir: str = None) -> dict: + """Fetch hourly token prices from Balancer API. + + Parameters + ---------- + token_addresses_by_chain : dict + {chain: {symbol: address, ...}, ...} + cache_dir : str, optional + Directory for per-token price caches. + + Returns + ------- + dict + {(chain, symbol): pd.DataFrame with columns [timestamp, price], ...} + """ + if cache_dir: + os.makedirs(cache_dir, exist_ok=True) + + prices = {} + + for chain, tokens in token_addresses_by_chain.items(): + # Check cache first, collect uncached addresses + uncached = {} + for symbol, address in tokens.items(): + cache_key = f"{chain}_{symbol}".replace("/", "_") + cp = os.path.join(cache_dir, f"{cache_key}.parquet") if cache_dir else None + + if cp and os.path.exists(cp): + prices[(chain, symbol)] = pd.read_parquet(cp) + else: + uncached[symbol] = address + + if not uncached: + continue + + # Batch fetch: API supports multiple addresses per request + addr_to_symbol = {addr: sym for sym, addr in uncached.items()} + addresses = list(uncached.values()) + + print(f" Fetching {len(addresses)} prices on {chain}...", + flush=True) + + # Batch in groups of 20 to avoid oversized requests + batch_size = 20 + for batch_start in range(0, len(addresses), batch_size): + batch_addrs = addresses[batch_start:batch_start + batch_size] + query = { + "query": """ + query GetPrices($chain: GqlChain!, $addresses: [String!]!, + $range: GqlTokenChartDataRange!) { + tokenGetHistoricalPrices( + addresses: $addresses, chain: $chain, range: $range + ) { + address + prices { + timestamp + price + } + } + } + """, + "variables": { + "chain": chain, + "addresses": batch_addrs, + "range": "ONE_YEAR", + }, + } + + try: + body = _graphql_request(query, timeout=60) + results = body.get("data", {}).get( + "tokenGetHistoricalPrices", []) + for result in results: + addr = result.get("address", "") + price_list = result.get("prices", []) + symbol = addr_to_symbol.get(addr) + if symbol and price_list: + pdf = pd.DataFrame(price_list) + pdf["timestamp"] = pdf["timestamp"].astype(int) + pdf["price"] = pdf["price"].astype(float) + prices[(chain, symbol)] = pdf + if cache_dir: + cache_key = f"{chain}_{symbol}".replace("/", "_") + cp = os.path.join(cache_dir, f"{cache_key}.parquet") + pdf.to_parquet(cp, index=False) + except Exception as e: + print(f" FAILED batch on {chain}: {e}") + + time.sleep(0.5) + + print(f" Got prices for {len(prices)} token-chain pairs") + return prices + + +def compute_pair_volatility( + snapshots_df: pd.DataFrame, + pool_row: pd.Series, + token_prices: dict, +) -> pd.Series: + """Compute daily annualised volatility for a pool's pair ratio. + + Uses hourly prices from the API to compute daily realised volatility. + Falls back to a default of 0.5 if price data is insufficient. + + Parameters + ---------- + snapshots_df : pd.DataFrame + Pool's daily snapshots (need dates). + pool_row : pd.Series + Pool metadata row (need tokens, chain). + token_prices : dict + {(chain, symbol): DataFrame, ...} + + Returns + ------- + pd.Series + Indexed by date, values are annualised daily volatility. + """ + tokens = pool_row["tokens"] + chain = pool_row["chain"] + + if len(tokens) < 2: + return pd.Series(dtype=float) + + # Get price series for token[0] and token[1] + # Try chain-specific first, then any chain + def _get_price_df(symbol): + # Exact match + key = (chain, symbol) + if key in token_prices: + return token_prices[key] + # Any chain + for k, v in token_prices.items(): + if k[1] == symbol: + return v + return None + + p0_df = _get_price_df(tokens[0]) + p1_df = _get_price_df(tokens[1]) + + # If either is a stablecoin, use $1 + stables = {"USDC", "USDT", "DAI", "LUSD", "GHO", "crvUSD", "sDAI", + "WXDAI", "xDAI", "USDC.e", "USDbC"} + + if tokens[0] in stables and tokens[1] in stables: + # Stable-stable pair: near-zero vol + dates = snapshots_df["date"].unique() + return pd.Series(0.01, index=dates) + + if p0_df is None and tokens[0] not in stables: + dates = snapshots_df["date"].unique() + return pd.Series(0.5, index=dates) # fallback + if p1_df is None and tokens[1] not in stables: + dates = snapshots_df["date"].unique() + return pd.Series(0.5, index=dates) # fallback + + # Build hourly price ratio + if tokens[0] in stables: + # ratio = 1 / p1 + if p1_df is None or len(p1_df) == 0: + dates = snapshots_df["date"].unique() + return pd.Series(0.5, index=dates) + ratio_df = p1_df.copy() + ratio_df["ratio"] = 1.0 / ratio_df["price"] + elif tokens[1] in stables: + # ratio = p0 + if p0_df is None or len(p0_df) == 0: + dates = snapshots_df["date"].unique() + return pd.Series(0.5, index=dates) + ratio_df = p0_df.copy() + ratio_df["ratio"] = ratio_df["price"] + else: + # Both non-stable: ratio = p0/p1 + if p0_df is None or p1_df is None or len(p0_df) == 0 or len(p1_df) == 0: + dates = snapshots_df["date"].unique() + return pd.Series(0.5, index=dates) + # Merge on nearest timestamp + merged = pd.merge_asof( + p0_df.sort_values("timestamp"), + p1_df.sort_values("timestamp"), + on="timestamp", + suffixes=("_0", "_1"), + tolerance=7200, # 2 hour tolerance + ).dropna() + if len(merged) == 0: + dates = snapshots_df["date"].unique() + return pd.Series(0.5, index=dates) + ratio_df = merged.copy() + ratio_df["ratio"] = merged["price_0"] / merged["price_1"] + + ratio_df["datetime"] = pd.to_datetime(ratio_df["timestamp"], unit="s") + ratio_df["date"] = ratio_df["datetime"].dt.date + ratio_df = ratio_df.sort_values("timestamp") + + # Log returns + ratio_df["log_return"] = np.log( + ratio_df["ratio"] / ratio_df["ratio"].shift(1) + ) + ratio_df = ratio_df.dropna(subset=["log_return"]) + + # Daily vol from hourly returns, annualised + daily_vol = ratio_df.groupby("date")["log_return"].std() + # Hourly data → ~24 returns/day. Annualise: σ_daily * sqrt(365) + # But std() already gives daily std from hourly returns, so: + # σ_annual = σ_hourly * sqrt(24 * 365) + daily_vol_ann = daily_vol * np.sqrt(24 * 365) + + return daily_vol_ann + + +def assemble_panel( + pools_df: pd.DataFrame, + snapshots_df: pd.DataFrame, + token_prices: dict, +) -> pd.DataFrame: + """Assemble the full panel DataFrame for hierarchical estimation. + + Parameters + ---------- + pools_df : pd.DataFrame + Pool enumeration from enumerate_balancer_pools. + snapshots_df : pd.DataFrame + Daily snapshots from fetch_all_snapshots. + token_prices : dict + Token prices from fetch_token_prices. + + Returns + ------- + pd.DataFrame + Panel with columns: pool_id, chain, date, log_volume, log_tvl, + volatility, weekend, log_fee, tier_A, tier_B, tokens. + """ + records = [] + pool_ids = snapshots_df["pool_id"].unique() + n_pools = len(pool_ids) + + for i, pool_id in enumerate(pool_ids): + if (i + 1) % 20 == 0 or i == 0: + print(f" Assembling {i+1}/{n_pools}...", flush=True) + + pool_snaps = snapshots_df[snapshots_df["pool_id"] == pool_id] + pool_meta = pools_df[pools_df["pool_id"] == pool_id] + if len(pool_meta) == 0: + continue + pool_row = pool_meta.iloc[0] + + tokens = pool_row["tokens"] + if len(tokens) < 2: + continue + + chain = pool_row["chain"] + swap_fee = pool_row["swap_fee"] + + # Token tiers: sort by tier (best tier first) + tiers = sorted([classify_token_tier(t) for t in tokens]) + tier_a = tiers[0] # best (lowest) tier + tier_b = tiers[1] if len(tiers) > 1 else tiers[0] + + # Compute volatility for this pool's pair + vol_series = compute_pair_volatility(pool_snaps, pool_row, token_prices) + + for _, snap in pool_snaps.iterrows(): + date = snap["date"] + volume = snap["volume_usd"] + tvl = snap["total_liquidity_usd"] + + # Skip zero/negative TVL or volume + if tvl <= 0 or volume <= 0: + continue + + # Volatility lookup + if isinstance(vol_series, pd.Series) and date in vol_series.index: + vol = vol_series[date] + else: + vol = 0.5 # fallback + + if not np.isfinite(vol) or vol <= 0: + vol = 0.5 + + # Weekend indicator + if isinstance(date, datetime): + is_weekend = date.weekday() >= 5 + else: + is_weekend = pd.Timestamp(date).weekday() >= 5 + + records.append({ + "pool_id": pool_id, + "chain": chain, + "date": date, + "log_volume": np.log(volume), + "log_tvl": np.log(tvl), + "volatility": vol, + "weekend": 1.0 if is_weekend else 0.0, + "log_fee": np.log(max(swap_fee, 1e-6)), + "tier_A": tier_a, + "tier_B": tier_b, + "tokens": ",".join(tokens[:2]), + "swap_fee": swap_fee, + }) + + panel = pd.DataFrame(records) + print(f"\n Panel: {len(panel)} observations, " + f"{panel['pool_id'].nunique()} pools, " + f"{panel['chain'].nunique()} chains") + + return panel + + +# --------------------------------------------------------------------------- +# Phase 2: Bayesian hierarchical model +# --------------------------------------------------------------------------- + +def _encode_covariates(panel: pd.DataFrame) -> dict: + """Build NumPyro-ready arrays from the panel DataFrame. + + Returns + ------- + dict with keys: + pool_idx : (N_obs,) int array mapping each observation to its pool + X_pool : (N_pools, K) covariate matrix (intercept + dummies + log_fee) + log_tvl, volatility, weekend, log_volume : (N_obs,) float arrays + pool_ids : (N_pools,) pool ID strings + covariate_names : list of str, column names for X_pool + ref_chain, ref_tier_a, ref_tier_b : reference categories + chains : sorted list of all chains + pool_meta : DataFrame of per-pool metadata + """ + pool_meta = panel.drop_duplicates("pool_id").reset_index(drop=True) + pool_ids = pool_meta["pool_id"].values + pool_id_to_idx = {pid: i for i, pid in enumerate(pool_ids)} + + pool_idx = panel["pool_id"].map(pool_id_to_idx).values + + # Build X_pool columns + chains = sorted(panel["chain"].unique()) + ref_chain = chains[0] + chain_cols = [] + chain_names = [] + for c in chains[1:]: + chain_cols.append((pool_meta["chain"] == c).astype(float).values) + chain_names.append(f"chain_{c}") + + tier_a_vals = sorted(pool_meta["tier_A"].astype(str).unique()) + ref_tier_a = tier_a_vals[0] + tier_a_cols = [] + tier_a_names = [] + for t in tier_a_vals[1:]: + tier_a_cols.append( + (pool_meta["tier_A"].astype(str) == t).astype(float).values + ) + tier_a_names.append(f"tier_A_{t}") + + tier_b_vals = sorted(pool_meta["tier_B"].astype(str).unique()) + ref_tier_b = tier_b_vals[0] + tier_b_cols = [] + tier_b_names = [] + for t in tier_b_vals[1:]: + tier_b_cols.append( + (pool_meta["tier_B"].astype(str) == t).astype(float).values + ) + tier_b_names.append(f"tier_B_{t}") + + N_pools = len(pool_ids) + columns = [np.ones((N_pools, 1))] + col_names = ["intercept"] + + for arr, name in zip(chain_cols, chain_names): + columns.append(arr.reshape(-1, 1)) + col_names.append(name) + for arr, name in zip(tier_a_cols, tier_a_names): + columns.append(arr.reshape(-1, 1)) + col_names.append(name) + for arr, name in zip(tier_b_cols, tier_b_names): + columns.append(arr.reshape(-1, 1)) + col_names.append(name) + columns.append(pool_meta["log_fee"].values.reshape(-1, 1)) + col_names.append("log_fee") + + X_pool = np.hstack(columns) + + return { + "pool_idx": pool_idx.astype(np.int32), + "X_pool": X_pool.astype(np.float64), + "log_tvl": panel["log_tvl"].values.astype(np.float64), + "volatility": panel["volatility"].values.astype(np.float64), + "weekend": panel["weekend"].values.astype(np.float64), + "log_volume": panel["log_volume"].values.astype(np.float64), + "pool_ids": pool_ids, + "covariate_names": col_names, + "ref_chain": ref_chain, + "ref_tier_a": ref_tier_a, + "ref_tier_b": ref_tier_b, + "chains": chains, + "pool_meta": pool_meta, + } + + +def _hierarchical_noise_model( + pool_idx, X_pool, log_tvl, volatility, weekend, log_volume=None, +): + """NumPyro model: Bayesian hierarchical loglinear noise volume. + + All pool covariates modulate all three coefficients (α, β_tvl, β_vol) + through the group-level regression matrix Φ, with correlated random + effects via LKJ-Cholesky. + """ + N_pools = X_pool.shape[0] + K = X_pool.shape[1] + + # Hyperpriors + Phi = numpyro.sample("Phi", dist.Normal(0, 2).expand([K, 3]).to_event(2)) + sigma_theta = numpyro.sample( + "sigma_theta", dist.HalfNormal(2).expand([3]).to_event(1) + ) + L_omega = numpyro.sample("L_omega", dist.LKJCholesky(3, concentration=2)) + beta_weekend = numpyro.sample("beta_weekend", dist.Normal(0, 2)) + sigma_eps = numpyro.sample("sigma_eps", dist.HalfNormal(3)) + + # Non-centered pool random effects + with numpyro.plate("pools", N_pools): + z = numpyro.sample("z", dist.Normal(jnp.zeros(3), 1).to_event(1)) + + # θ_i = Φᵀx_i + diag(σ_θ)·L_ω·z_i + mu = X_pool @ Phi # (N_pools, 3) + L_Sigma = sigma_theta[:, None] * L_omega # (3, 3) + theta = mu + z @ L_Sigma.T # (N_pools, 3) + + alpha = theta[:, 0] + beta_tvl = theta[:, 1] + beta_vol = theta[:, 2] + + # Observation model + loc = (alpha[pool_idx] + + beta_tvl[pool_idx] * log_tvl + + beta_vol[pool_idx] * volatility + + beta_weekend * weekend) + + with numpyro.plate("obs", pool_idx.shape[0]): + numpyro.sample("log_volume", dist.Normal(loc, sigma_eps), obs=log_volume) + + +def fit_bayesian_model( + panel: pd.DataFrame, use_nuts: bool = False, +) -> tuple: + """Fit the Bayesian hierarchical model via SVI or NUTS. + + Parameters + ---------- + panel : pd.DataFrame + Panel from assemble_panel. + use_nuts : bool + If True, use NUTS MCMC (slower, exact). Otherwise SVI with + AutoMultivariateNormal guide. + + Returns + ------- + samples : dict + Posterior samples keyed by parameter name. + encoding : dict + From _encode_covariates (needed downstream). + """ + encoding = _encode_covariates(panel) + + model_kwargs = dict( + pool_idx=jnp.array(encoding["pool_idx"]), + X_pool=jnp.array(encoding["X_pool"]), + log_tvl=jnp.array(encoding["log_tvl"]), + volatility=jnp.array(encoding["volatility"]), + weekend=jnp.array(encoding["weekend"]), + log_volume=jnp.array(encoding["log_volume"]), + ) + + N_pools = encoding["X_pool"].shape[0] + K = encoding["X_pool"].shape[1] + print(f" N obs = {len(encoding['pool_idx'])}, " + f"N pools = {N_pools}, K covariates = {K}") + print(f" Covariates: {encoding['covariate_names']}") + + rng_key = jax.random.PRNGKey(0) + + if use_nuts: + print(" Running NUTS (500 warmup + 1000 samples)...") + kernel = NUTS(_hierarchical_noise_model) + mcmc = MCMC(kernel, num_warmup=500, num_samples=1000, num_chains=1) + mcmc.run(rng_key, **model_kwargs) + samples = mcmc.get_samples() + print(" NUTS complete.") + else: + print(" Running SVI with AutoMultivariateNormal (20k steps)...") + guide = AutoMultivariateNormal(_hierarchical_noise_model) + optimizer = numpyro.optim.Adam(1e-3) + svi = SVI(_hierarchical_noise_model, guide, optimizer, + loss=Trace_ELBO()) + svi_result = svi.run(rng_key, 20_000, **model_kwargs) + print(f" SVI complete. Final ELBO loss: {svi_result.losses[-1]:.2f}") + + predictive = Predictive( + guide, params=svi_result.params, num_samples=1000, + ) + samples = predictive(jax.random.PRNGKey(1), **model_kwargs) + print(" Drew 1000 posterior samples.") + + return samples, encoding + + +def extract_posteriors(samples: dict, encoding: dict) -> dict: + """Reconstruct pool-specific coefficients from posterior samples. + + Computes θ_i = Φᵀx_i + diag(σ_θ)·L_ω·z_i for each posterior draw, + then returns posterior means and variance components. + + Returns + ------- + dict with keys: + pool_effects : {pool_id: {alpha, beta_tvl, beta_vol}} + Phi_mean : (K, 3) array + sigma_theta_mean : (3,) array + correlation_matrix : (3, 3) array + beta_weekend_mean : float + sigma_eps_mean : float + theta_samples : (S, N_pools, 3) array (for diagnostics) + """ + Phi = np.array(samples["Phi"]) # (S, K, 3) + sigma_theta = np.array(samples["sigma_theta"]) # (S, 3) + L_omega = np.array(samples["L_omega"]) # (S, 3, 3) + z = np.array(samples["z"]) # (S, N_pools, 3) + beta_weekend = np.array(samples["beta_weekend"]) # (S,) + sigma_eps = np.array(samples["sigma_eps"]) # (S,) + + X_pool = encoding["X_pool"] # (N_pools, K) + pool_ids = encoding["pool_ids"] + + # mu = X_pool @ Phi for each sample: (S, N_pools, 3) + mu = np.einsum("pk,skj->spj", X_pool, Phi) + + # L_Sigma = diag(sigma_theta) @ L_omega: (S, 3, 3) + L_Sigma = sigma_theta[:, :, None] * L_omega + + # offset = z @ L_Sigma^T: (S, N_pools, 3) + offset = np.einsum("spi,sji->spj", z, L_Sigma) + + theta = mu + offset # (S, N_pools, 3) + theta_mean = theta.mean(axis=0) # (N_pools, 3) + + pool_effects = {} + for i, pid in enumerate(pool_ids): + pool_effects[pid] = { + "alpha": float(theta_mean[i, 0]), + "beta_tvl": float(theta_mean[i, 1]), + "beta_vol": float(theta_mean[i, 2]), + } + + Phi_mean = Phi.mean(axis=0) # (K, 3) + + # Correlation matrix: R = L_omega @ L_omega^T, averaged over samples + R_samples = np.einsum("sij,skj->sik", L_omega, L_omega) + R_mean = R_samples.mean(axis=0) + + return { + "pool_effects": pool_effects, + "Phi_mean": Phi_mean, + "sigma_theta_mean": sigma_theta.mean(axis=0), + "correlation_matrix": R_mean, + "beta_weekend_mean": float(beta_weekend.mean()), + "sigma_eps_mean": float(sigma_eps.mean()), + "theta_samples": theta, + } + + +def compute_noise_params(posteriors: dict, panel: pd.DataFrame) -> list: + """Convert posterior pool effects to per-pool noise_params dicts. + + Each pool now has its own β_tvl and β_vol (from the hierarchical + posterior), rather than sharing global slopes. + + Parameters + ---------- + posteriors : dict + From extract_posteriors. + panel : pd.DataFrame + Panel data. + + Returns + ------- + list of dict + Each dict has: pool_id, chain, tokens, noise_params. + """ + pool_effects = posteriors["pool_effects"] + pool_meta = panel.drop_duplicates("pool_id").set_index("pool_id") + + results = [] + for pool_id, effects in pool_effects.items(): + if pool_id not in pool_meta.index: + continue + meta = pool_meta.loc[pool_id] + swap_fee = float(meta.get("swap_fee", 0.003)) + + results.append({ + "pool_id": pool_id, + "chain": meta["chain"], + "tokens": (meta["tokens"].split(",") + if isinstance(meta["tokens"], str) + else meta["tokens"]), + "noise_params": { + "b_0": effects["alpha"], + "b_sigma": effects["beta_vol"], + "b_c": effects["beta_tvl"], + "base_fee": swap_fee, + }, + }) + + return results + + +def _build_covariate_vector( + encoding: dict, chain: str, tokens: list, fee: float, +) -> np.ndarray: + """Construct a covariate vector x for a new pool. + + Matches the column order of X_pool from _encode_covariates so that + x @ Phi_mean gives population-level predictions for all 3 coefficients. + """ + col_names = encoding["covariate_names"] + x = np.zeros(len(col_names)) + + tiers = sorted([classify_token_tier(t) for t in tokens]) + tier_a = str(tiers[0]) + tier_b = str(tiers[1]) if len(tiers) > 1 else tier_a + + for i, name in enumerate(col_names): + if name == "intercept": + x[i] = 1.0 + elif name == "log_fee": + x[i] = np.log(max(fee, 1e-6)) + elif name == f"chain_{chain}": + x[i] = 1.0 + elif name == f"tier_A_{tier_a}": + x[i] = 1.0 + elif name == f"tier_B_{tier_b}": + x[i] = 1.0 + + return x + + +def predict_new_pool( + posteriors: dict, + encoding: dict, + chain: str, + tokens: list, + fee: float, +) -> dict: + """Predict noise params for an unseen pool. + + Uses population-level estimates only (x @ Φ, no pool random effect). + + Parameters + ---------- + posteriors : dict + From extract_posteriors. + encoding : dict + From _encode_covariates (or loaded from cache). + chain : str + Chain API identifier (e.g. "BASE"). + tokens : list + Token symbols (e.g. ["ETH", "BTC"]). + fee : float + Swap fee rate. + + Returns + ------- + dict + noise_params dict with pool-predicted coefficients. + """ + x = _build_covariate_vector(encoding, chain, tokens, fee) + Phi_mean = posteriors["Phi_mean"] + + theta_pred = x @ Phi_mean # (3,) — population-level prediction + + tiers = sorted([classify_token_tier(t) for t in tokens]) + tier_a = tiers[0] + tier_b = tiers[1] if len(tiers) > 1 else tiers[0] + + return { + "b_0": float(theta_pred[0]), + "b_sigma": float(theta_pred[2]), + "b_c": float(theta_pred[1]), + "base_fee": float(fee), + "_prediction_source": "population_level", + "_alpha": float(theta_pred[0]), + "_beta_tvl": float(theta_pred[1]), + "_beta_vol": float(theta_pred[2]), + "_tier_a": tier_a, + "_tier_b": tier_b, + } + + +# --------------------------------------------------------------------------- +# Phase 3: Diagnostics and output +# --------------------------------------------------------------------------- + +def plot_hierarchical_diagnostics( + panel: pd.DataFrame, + posteriors: dict, + encoding: dict, + output_dir: str = "results", +): + """Generate diagnostic plots for the Bayesian hierarchical model. + + Figure 1 (2x2): + (0,0) Pool-specific coefficient distributions (α, β_tvl, β_vol) + (0,1) Chain effects on all 3 coefficients + (1,0) Tier effects on all 3 coefficients + (1,1) Model summary (Φ, σ_θ, correlations, σ_ε) + + Figure 2: + β_tvl vs β_vol scatter colored by chain + """ + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + os.makedirs(output_dir, exist_ok=True) + + pool_effects = posteriors["pool_effects"] + Phi_mean = posteriors["Phi_mean"] + col_names = encoding["covariate_names"] + pool_meta = panel.drop_duplicates("pool_id") + + alphas = [e["alpha"] for e in pool_effects.values()] + beta_tvls = [e["beta_tvl"] for e in pool_effects.values()] + beta_vols = [e["beta_vol"] for e in pool_effects.values()] + + # --- Figure 1: Diagnostics 2x2 --- + fig, axes = plt.subplots(2, 2, figsize=(14, 10)) + + # (0,0) Pool-specific coefficient distributions + ax = axes[0, 0] + bins = 25 + ax.hist(alphas, bins=bins, alpha=0.6, label="α (intercept)", + color="steelblue", edgecolor="white") + ax.hist(beta_tvls, bins=bins, alpha=0.6, label="β_tvl", + color="coral", edgecolor="white") + ax.hist(beta_vols, bins=bins, alpha=0.6, label="β_vol", + color="seagreen", edgecolor="white") + ax.set_xlabel("Coefficient value") + ax.set_ylabel("Count") + ax.set_title(f"Pool-specific coefficients (n={len(alphas)})") + ax.legend() + + # (0,1) Chain effects on all 3 coefficients + ax = axes[0, 1] + chain_rows = {name: i for i, name in enumerate(col_names) + if name.startswith("chain_")} + chain_labels = [] + chain_alpha_effects = [] + chain_tvl_effects = [] + chain_vol_effects = [] + # Reference chain (effect = 0) + ref_chain = encoding["ref_chain"] + chain_counts = pool_meta["chain"].value_counts() + chain_labels.append(f"{ref_chain}\n(n={chain_counts.get(ref_chain, 0)})") + chain_alpha_effects.append(0.0) + chain_tvl_effects.append(0.0) + chain_vol_effects.append(0.0) + for name, row_idx in sorted(chain_rows.items()): + chain_name = name.replace("chain_", "") + chain_labels.append( + f"{chain_name}\n(n={chain_counts.get(chain_name, 0)})" + ) + chain_alpha_effects.append(Phi_mean[row_idx, 0]) + chain_tvl_effects.append(Phi_mean[row_idx, 1]) + chain_vol_effects.append(Phi_mean[row_idx, 2]) + + y_pos = np.arange(len(chain_labels)) + bar_h = 0.25 + ax.barh(y_pos - bar_h, chain_alpha_effects, bar_h, label="α", + color="steelblue", alpha=0.8) + ax.barh(y_pos, chain_tvl_effects, bar_h, label="β_tvl", + color="coral", alpha=0.8) + ax.barh(y_pos + bar_h, chain_vol_effects, bar_h, label="β_vol", + color="seagreen", alpha=0.8) + ax.set_yticks(y_pos) + ax.set_yticklabels(chain_labels) + ax.axvline(0, color="red", linestyle="--", linewidth=1) + ax.set_xlabel("Effect (relative to reference)") + ax.set_title("Chain effects on all coefficients") + ax.legend(fontsize=8) + + # (1,0) Tier effects on all 3 coefficients + ax = axes[1, 0] + tier_names = ["0 (blue-chip)", "1 (mid-cap)", "2 (long-tail)"] + coeff_labels = ["α", "β_tvl", "β_vol"] + tier_a_rows = {name: i for i, name in enumerate(col_names) + if name.startswith("tier_A_")} + tier_b_rows = {name: i for i, name in enumerate(col_names) + if name.startswith("tier_B_")} + + # Build effects matrix: (3 tiers) x (3 coefficients) x (A/B) + x_pos = np.arange(len(tier_names)) + width = 0.13 + for coeff_idx, (coeff_name, color) in enumerate( + zip(coeff_labels, ["steelblue", "coral", "seagreen"]) + ): + tier_a_vals = [0.0, 0.0, 0.0] # reference tier gets 0 + tier_b_vals = [0.0, 0.0, 0.0] + for name, row_idx in tier_a_rows.items(): + tier_val = name.replace("tier_A_", "") + if tier_val in ("0", "1", "2"): + tier_a_vals[int(tier_val)] = Phi_mean[row_idx, coeff_idx] + for name, row_idx in tier_b_rows.items(): + tier_val = name.replace("tier_B_", "") + if tier_val in ("0", "1", "2"): + tier_b_vals[int(tier_val)] = Phi_mean[row_idx, coeff_idx] + offset = (coeff_idx - 1) * width * 2 + ax.bar(x_pos + offset - width / 2, tier_a_vals, width, + label=f"{coeff_name} (A)" if coeff_idx == 0 else "", + color=color, alpha=0.7, edgecolor="white") + ax.bar(x_pos + offset + width / 2, tier_b_vals, width, + label=f"{coeff_name} (B)" if coeff_idx == 0 else "", + color=color, alpha=0.4, edgecolor="white", hatch="//") + + ax.set_xticks(x_pos) + ax.set_xticklabels(tier_names) + ax.set_ylabel("Effect on coefficient") + ax.set_title("Token tier effects (solid=A, hatched=B)") + ax.axhline(0, color="black", linewidth=0.5) + # Manual legend for coefficient colors + from matplotlib.patches import Patch + ax.legend(handles=[Patch(color=c, label=l) for c, l in + zip(["steelblue", "coral", "seagreen"], coeff_labels)], + fontsize=8) + + # (1,1) Model summary text + ax = axes[1, 1] + ax.axis("off") + sigma_theta = posteriors["sigma_theta_mean"] + R = posteriors["correlation_matrix"] + summary = "Group-level regression Φ (posterior mean):\n" + summary += f" {'covariate':<20s} {'α':>8s} {'β_tvl':>8s} {'β_vol':>8s}\n" + summary += " " + "-" * 46 + "\n" + for j, name in enumerate(col_names): + summary += (f" {name:<20s} {Phi_mean[j,0]:>8.3f} " + f"{Phi_mean[j,1]:>8.3f} {Phi_mean[j,2]:>8.3f}\n") + summary += f"\nσ_θ: [{sigma_theta[0]:.3f}, {sigma_theta[1]:.3f}, " + summary += f"{sigma_theta[2]:.3f}]\n" + summary += f"Correlation:\n" + for i in range(3): + summary += f" [{R[i,0]:>6.3f} {R[i,1]:>6.3f} {R[i,2]:>6.3f}]\n" + summary += f"β_weekend: {posteriors['beta_weekend_mean']:.4f}\n" + summary += f"σ_ε: {posteriors['sigma_eps_mean']:.4f}\n" + ax.text(0.02, 0.98, summary, transform=ax.transAxes, + fontsize=7, verticalalignment="top", fontfamily="monospace") + ax.set_title("Model Summary") + + fig.suptitle( + "Bayesian Hierarchical Noise Model — Diagnostics", fontsize=13, + ) + plt.tight_layout() + path = os.path.join(output_dir, "hierarchical_diagnostics.png") + plt.savefig(path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path}") + + # --- Figure 2: β_tvl vs β_vol scatter colored by chain --- + fig2, ax2 = plt.subplots(figsize=(10, 7)) + pool_id_to_chain = dict( + zip(pool_meta["pool_id"], pool_meta["chain"]) + ) + chain_colors = {} + cmap = plt.cm.tab10 + unique_chains = sorted(pool_meta["chain"].unique()) + for i, c in enumerate(unique_chains): + chain_colors[c] = cmap(i % 10) + + for pid, effects in pool_effects.items(): + c = pool_id_to_chain.get(pid, "?") + ax2.scatter(effects["beta_tvl"], effects["beta_vol"], + color=chain_colors.get(c, "gray"), alpha=0.6, s=20, + edgecolors="white", linewidths=0.3) + + # Legend + from matplotlib.lines import Line2D + handles = [Line2D([0], [0], marker="o", color="w", + markerfacecolor=chain_colors[c], markersize=8, + label=c) + for c in unique_chains if c in chain_colors] + ax2.legend(handles=handles, fontsize=8, loc="best") + ax2.set_xlabel("β_tvl (TVL elasticity)") + ax2.set_ylabel("β_vol (volatility sensitivity)") + ax2.set_title("Pool-specific coefficients by chain") + ax2.axhline(0, color="gray", linewidth=0.5, linestyle="--") + ax2.axvline(0, color="gray", linewidth=0.5, linestyle="--") + plt.tight_layout() + path2 = os.path.join(output_dir, "beta_tvl_vs_beta_vol.png") + plt.savefig(path2, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path2}") + + return path, path2 + + +def generate_noise_params_json( + pool_params: list, + posteriors: dict, + encoding: dict, + output_path: str, + inference_method: str = "svi", +): + """Write per-pool noise params to JSON. + + Parameters + ---------- + pool_params : list of dict + From compute_noise_params. + posteriors : dict + From extract_posteriors. + encoding : dict + From _encode_covariates. + output_path : str + Output JSON path. + inference_method : str + "svi" or "nuts". + """ + output = { + "model": "bayesian_hierarchical_loglinear", + "inference_method": inference_method, + "Phi": posteriors["Phi_mean"].tolist(), + "covariate_names": encoding["covariate_names"], + "sigma_theta": posteriors["sigma_theta_mean"].tolist(), + "correlation_matrix": posteriors["correlation_matrix"].tolist(), + "beta_weekend": posteriors["beta_weekend_mean"], + "sigma_eps": posteriors["sigma_eps_mean"], + "pools": pool_params, + } + + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + with open(output_path, "w") as f: + json.dump(output, f, indent=2, default=str) + print(f" Wrote {len(pool_params)} pool params → {output_path}") + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser( + description="Bayesian hierarchical noise volume model for Balancer pools" + ) + parser.add_argument( + "--fetch", action="store_true", + help="Fetch pool data from Balancer API (cached to local_data/)", + ) + parser.add_argument( + "--fit", action="store_true", + help="Fit Bayesian hierarchical model (requires fetched data)", + ) + parser.add_argument( + "--nuts", action="store_true", + help="Use NUTS MCMC instead of SVI (slower, exact posteriors)", + ) + parser.add_argument( + "--plot", action="store_true", + help="Generate diagnostic plots", + ) + parser.add_argument( + "--output", default=None, + help="Output JSON path for per-pool noise params", + ) + parser.add_argument( + "--output-dir", default="results", + help="Directory for diagnostic plots (default: results)", + ) + parser.add_argument( + "--predict", action="store_true", + help="Predict noise params for a new pool", + ) + parser.add_argument( + "--chain", default=None, + help="Chain for --predict (e.g. BASE, MAINNET)", + ) + parser.add_argument( + "--tokens", nargs="+", default=None, + help="Token symbols for --predict (e.g. ETH BTC)", + ) + parser.add_argument( + "--fee", type=float, default=0.003, + help="Swap fee for --predict", + ) + parser.add_argument( + "--min-tvl", type=float, default=10000.0, + help="Minimum TVL filter for pool enumeration", + ) + parser.add_argument( + "--cache-dir", default=None, + help="Cache directory (default: local_data/noise_calibration/)", + ) + args = parser.parse_args() + + cache_dir = args.cache_dir or CACHE_DIR + + if not any([args.fetch, args.fit, args.predict]): + parser.error("At least one of --fetch, --fit, --predict is required") + + # --- Fetch --- + pools_cache = os.path.join(cache_dir, "pools.parquet") + snaps_cache = os.path.join(cache_dir, "pool_snapshots.parquet") + prices_cache = os.path.join(cache_dir, "token_prices") + panel_cache = os.path.join(cache_dir, "panel.parquet") + + if args.fetch: + print("Phase 1: Fetching data from Balancer API") + print("=" * 60) + + # Step 1: Enumerate pools + print("\n1. Enumerating pools...") + pools_df = enumerate_balancer_pools(min_tvl=args.min_tvl) + os.makedirs(cache_dir, exist_ok=True) + pools_df.to_parquet(pools_cache, index=False) + print(f" Saved {len(pools_df)} pools → {pools_cache}") + + # Step 2: Fetch snapshots + print("\n2. Fetching daily snapshots...") + snapshots_df = fetch_all_snapshots(pools_df, cache_path=snaps_cache) + + # Step 3: Fetch token prices + print("\n3. Fetching token prices...") + token_addr_by_chain = {} + for _, pool in pools_df.iterrows(): + chain = pool["chain"] + tokens = pool["tokens"] + addresses = pool["token_addresses"] + if chain not in token_addr_by_chain: + token_addr_by_chain[chain] = {} + for sym, addr in zip(tokens, addresses): + if sym and addr: + token_addr_by_chain[chain][sym] = addr + + token_prices = fetch_token_prices( + token_addr_by_chain, cache_dir=prices_cache + ) + + # Step 4: Assemble panel + print("\n4. Assembling panel...") + panel = assemble_panel(pools_df, snapshots_df, token_prices) + panel.to_parquet(panel_cache, index=False) + print(f" Saved panel → {panel_cache}") + + print(f"\nFetch complete. Panel: {len(panel)} obs, " + f"{panel['pool_id'].nunique()} pools") + + # --- Fit --- + if args.fit: + inference_method = "nuts" if args.nuts else "svi" + print(f"\nPhase 2: Fitting Bayesian hierarchical model ({inference_method})") + print("=" * 60) + + # Load panel + if not os.path.exists(panel_cache): + print(f"ERROR: Panel cache not found at {panel_cache}", + file=sys.stderr) + print("Run with --fetch first.", file=sys.stderr) + sys.exit(1) + + panel = pd.read_parquet(panel_cache) + print(f" Loaded panel: {len(panel)} obs, " + f"{panel['pool_id'].nunique()} pools, " + f"{panel['chain'].nunique()} chains") + + # Filter: need at least 10 days per pool for stable estimates + pool_counts = panel.groupby("pool_id").size() + valid_pools = pool_counts[pool_counts >= 10].index + panel_filtered = panel[panel["pool_id"].isin(valid_pools)] + print(f" After filtering (≥10 days): {len(panel_filtered)} obs, " + f"{panel_filtered['pool_id'].nunique()} pools") + + samples, encoding = fit_bayesian_model( + panel_filtered, use_nuts=args.nuts, + ) + posteriors = extract_posteriors(samples, encoding) + pool_params = compute_noise_params(posteriors, panel_filtered) + + # Print key diagnostics + Phi_mean = posteriors["Phi_mean"] + col_names = encoding["covariate_names"] + intercept_idx = col_names.index("intercept") + log_fee_idx = col_names.index("log_fee") + + print(f"\n Key results:") + print(f" Population intercept (Φ[intercept]):") + print(f" α: {Phi_mean[intercept_idx, 0]:.4f}") + print(f" β_tvl: {Phi_mean[intercept_idx, 1]:.4f}") + print(f" β_vol: {Phi_mean[intercept_idx, 2]:.4f}") + print(f" Fee effect (Φ[log_fee]):") + print(f" α: {Phi_mean[log_fee_idx, 0]:.4f}") + print(f" β_tvl: {Phi_mean[log_fee_idx, 1]:.4f}") + print(f" β_vol: {Phi_mean[log_fee_idx, 2]:.4f}") + print(f" β_weekend: {posteriors['beta_weekend_mean']:.4f}") + print(f" σ_θ: {posteriors['sigma_theta_mean']}") + print(f" σ_ε: {posteriors['sigma_eps_mean']:.4f}") + + # Verify pool-specific variation + b_sigmas = [p["noise_params"]["b_sigma"] for p in pool_params] + b_cs = [p["noise_params"]["b_c"] for p in pool_params] + print(f" b_sigma range: [{min(b_sigmas):.4f}, {max(b_sigmas):.4f}]") + print(f" b_c range: [{min(b_cs):.4f}, {max(b_cs):.4f}]") + + # Cache posteriors + encoding for --predict and --plot + posteriors_cache = os.path.join(cache_dir, "posteriors.json") + cache_data = { + "Phi_mean": posteriors["Phi_mean"].tolist(), + "sigma_theta_mean": posteriors["sigma_theta_mean"].tolist(), + "correlation_matrix": posteriors["correlation_matrix"].tolist(), + "beta_weekend_mean": posteriors["beta_weekend_mean"], + "sigma_eps_mean": posteriors["sigma_eps_mean"], + "pool_effects": posteriors["pool_effects"], + "covariate_names": encoding["covariate_names"], + "ref_chain": encoding["ref_chain"], + "ref_tier_a": encoding["ref_tier_a"], + "ref_tier_b": encoding["ref_tier_b"], + "chains": encoding["chains"], + "inference_method": inference_method, + } + with open(posteriors_cache, "w") as f: + json.dump(cache_data, f, indent=2, default=str) + print(f" Cached posteriors → {posteriors_cache}") + + if args.output: + generate_noise_params_json( + pool_params, posteriors, encoding, + args.output, inference_method=inference_method, + ) + + if args.plot: + print("\nPhase 3: Generating diagnostics") + print("=" * 60) + plot_hierarchical_diagnostics( + panel_filtered, posteriors, encoding, + output_dir=args.output_dir, + ) + + # --- Predict --- + if args.predict: + if args.chain is None or args.tokens is None: + parser.error("--predict requires --chain and --tokens") + + print(f"\nPredicting noise params for new pool:") + print(f" Chain: {args.chain}") + print(f" Tokens: {args.tokens}") + print(f" Fee: {args.fee}") + + # Load cached posteriors + encoding metadata + posteriors_cache = os.path.join(cache_dir, "posteriors.json") + if not os.path.exists(posteriors_cache): + print(f"ERROR: Posteriors cache not found at {posteriors_cache}", + file=sys.stderr) + print("Run with --fit first.", file=sys.stderr) + sys.exit(1) + + with open(posteriors_cache) as f: + cache_data = json.load(f) + + posteriors = { + "Phi_mean": np.array(cache_data["Phi_mean"]), + "sigma_theta_mean": np.array(cache_data["sigma_theta_mean"]), + "correlation_matrix": np.array(cache_data["correlation_matrix"]), + "beta_weekend_mean": cache_data["beta_weekend_mean"], + "sigma_eps_mean": cache_data["sigma_eps_mean"], + "pool_effects": cache_data["pool_effects"], + } + encoding = { + "covariate_names": cache_data["covariate_names"], + "ref_chain": cache_data["ref_chain"], + "ref_tier_a": cache_data["ref_tier_a"], + "ref_tier_b": cache_data["ref_tier_b"], + "chains": cache_data["chains"], + } + + params = predict_new_pool( + posteriors, encoding, args.chain, args.tokens, args.fee, + ) + print(f"\n Predicted noise_params:") + print(json.dumps(params, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/scripts/calibrate_noise_unified.py b/scripts/calibrate_noise_unified.py new file mode 100644 index 0000000..0e35527 --- /dev/null +++ b/scripts/calibrate_noise_unified.py @@ -0,0 +1,6 @@ +"""Thin wrapper — all logic lives in quantammsim.noise_calibration.""" +from quantammsim.noise_calibration import * # noqa: F401, F403 +from quantammsim.noise_calibration.cli import main + +if __name__ == "__main__": + main() diff --git a/scripts/calibrate_reclamm_noise.py b/scripts/calibrate_reclamm_noise.py new file mode 100644 index 0000000..8d09584 --- /dev/null +++ b/scripts/calibrate_reclamm_noise.py @@ -0,0 +1,842 @@ +"""OLS calibration of the Tsoukalas noise volume model for reClAMM pools. + +Fits the structural volume equation: + V_daily/1e6 = a_0 + a_sigma*sigma + a_c*sqrt(c_eff/1e6) + +where c_eff = (Ra+Va)*pA + (Rb+Vb)*pB is the effective TVL (real + virtual). + +From daily pool snapshots (volume, TVL, volatility). Outputs a noise_params dict +compatible with run_fingerprint["reclamm_noise_params"]. + +Usage: + # From a pre-assembled CSV + python scripts/calibrate_reclamm_noise.py --csv daily_data.csv --base-fee 0.003 + + # End-to-end from API + DB + parquets + python scripts/calibrate_reclamm_noise.py --pool cbBTC_WETH +""" + +import argparse +import json +import os +import sqlite3 +import sys +import urllib.request +from datetime import datetime, timezone + +import numpy as np +import pandas as pd + + +# --------------------------------------------------------------------------- +# Balancer V3 API +# --------------------------------------------------------------------------- + +BALANCER_API_URL = "https://api-v3.balancer.fi/" + +BALANCER_API_CHAIN = { + "base": "BASE", + "ethereum": "MAINNET", + "gnosis": "GNOSIS", + "avalanche": "AVALANCHE", + "arbitrum": "ARBITRUM", + "polygon": "POLYGON", + "optimism": "OPTIMISM", + "sonic": "SONIC", +} + + +def fetch_balancer_snapshots(chain, pool_address, start_ts, end_ts, + base_url=BALANCER_API_URL): + """Fetch daily pool snapshots from Balancer V3 GraphQL API. + + Parameters + ---------- + chain : str + Chain name (e.g. 'base', 'ethereum'). + pool_address : str + Pool contract address (hex, no 0x prefix). + start_ts : int + Start unix timestamp (seconds). + end_ts : int + End unix timestamp (seconds). + base_url : str + Balancer API base URL. + + Returns + ------- + pd.DataFrame + Columns: date, volume_usd, total_liquidity_usd. Indexed by date string. + """ + api_chain = BALANCER_API_CHAIN.get(chain) + if api_chain is None: + raise ValueError(f"Unknown chain for Balancer API: {chain!r}") + + pool_id = f"0x{pool_address}" if not pool_address.startswith("0x") else pool_address + + # Paginate: API may limit results. Fetch in 90-day windows. + all_snapshots = [] + window = 90 * 86400 + cursor = start_ts + + while cursor < end_ts: + window_end = min(cursor + window, end_ts) + query = { + "query": """ + query GetSnapshots($poolId: String!, $chain: GqlChain!, + $range: GqlPoolSnapshotDataRange!) { + poolGetSnapshots(id: $poolId, chain: $chain, range: $range) { + timestamp + volume24h + totalLiquidity + } + } + """, + "variables": { + "poolId": pool_id, + "chain": api_chain, + "range": "ALL_TIME", + }, + } + + data = json.dumps(query).encode("utf-8") + req = urllib.request.Request( + base_url, + data=data, + headers={ + "Content-Type": "application/json", + "User-Agent": "quantammsim/1.0", + }, + ) + + with urllib.request.urlopen(req, timeout=30) as resp: + body = json.loads(resp.read().decode("utf-8")) + + snapshots = body.get("data", {}).get("poolGetSnapshots", []) + if not snapshots: + break + + for snap in snapshots: + ts = int(snap["timestamp"]) + if start_ts <= ts <= end_ts: + all_snapshots.append({ + "timestamp": ts, + "volume_usd": float(snap["volume24h"]), + "total_liquidity_usd": float(snap["totalLiquidity"]), + }) + + # The API returns ALL_TIME, so no need to paginate further + break + + if not all_snapshots: + raise ValueError( + f"No Balancer snapshots for {pool_id} on {chain} " + f"between {start_ts} and {end_ts}" + ) + + df = pd.DataFrame(all_snapshots) + df["date"] = pd.to_datetime(df["timestamp"], unit="s").dt.date + # Deduplicate by date (keep last snapshot per day) + df = df.sort_values("timestamp").drop_duplicates("date", keep="last") + return df.set_index("date") + + +# --------------------------------------------------------------------------- +# DB-based daily pool state +# --------------------------------------------------------------------------- + +def load_daily_pool_state(pool, db_path, data_root): + """Load daily pool state from pools_history.db, compute effective TVL. + + Parameters + ---------- + pool : PoolConfig + Pool configuration (from pool_registry). + db_path : str + Path to pools_history.db. + data_root : str + Directory containing {TICKER}_USD.parquet files. + + Returns + ------- + pd.DataFrame + Indexed by date, columns: effective_tvl_usd, real_tvl_usd. + """ + conn = sqlite3.connect(db_path) + cur = conn.cursor() + cur.execute( + f"""SELECT timestamp, balance_0, balance_1, virtual_0, virtual_1 + FROM {pool.db_label} + ORDER BY timestamp""" + ) + rows = cur.fetchall() + conn.close() + + if not rows: + raise ValueError(f"No DB data for {pool.db_label}") + + df = pd.DataFrame(rows, columns=["timestamp", "bal_0", "bal_1", "virt_0", "virt_1"]) + df["date"] = pd.to_datetime(df["timestamp"], unit="s").dt.date + + # Keep last snapshot per day + daily = df.sort_values("timestamp").drop_duplicates("date", keep="last").set_index("date") + + # Load USD prices for each token + if pool.reverse: + tickers_in_db_order = [pool.tokens[1], pool.tokens[0]] + else: + tickers_in_db_order = [pool.tokens[0], pool.tokens[1]] + + price_dfs = {} + for ticker in tickers_in_db_order: + if ticker == "USDC": + price_dfs[ticker] = None # constant $1 + else: + path = os.path.join(data_root, f"{ticker}_USD.parquet") + pdf = pd.read_parquet(path) + pdf["date"] = pd.to_datetime(pdf["unix"], unit="ms").dt.date + # Daily close: last price per day + price_dfs[ticker] = ( + pdf.sort_values("unix") + .drop_duplicates("date", keep="last") + .set_index("date")["close"] + ) + + # Compute USD prices at each daily snapshot + records = [] + for date, row in daily.iterrows(): + b0, b1, v0, v1 = row["bal_0"], row["bal_1"], row["virt_0"], row["virt_1"] + + p0 = 1.0 if tickers_in_db_order[0] == "USDC" else price_dfs[tickers_in_db_order[0]].get(date, np.nan) + p1 = 1.0 if tickers_in_db_order[1] == "USDC" else price_dfs[tickers_in_db_order[1]].get(date, np.nan) + + if np.isnan(p0) or np.isnan(p1): + continue + + real_tvl = b0 * p0 + b1 * p1 + effective_tvl = (b0 + v0) * p0 + (b1 + v1) * p1 + + records.append({ + "date": date, + "real_tvl_usd": real_tvl, + "effective_tvl_usd": effective_tvl, + }) + + result = pd.DataFrame(records).set_index("date") + return result + + +# --------------------------------------------------------------------------- +# Daily volatility from price parquets +# --------------------------------------------------------------------------- + +def compute_daily_volatility(tokens, data_root, start_ts, end_ts): + """Compute daily annualised volatility of the price ratio. + + Uses 5-minute subsampled log returns within each day, then + annualises with sqrt(365). + + Parameters + ---------- + tokens : list + Token tickers in quantammsim sorted order (e.g. ['BTC', 'ETH']). + data_root : str + Directory containing {TICKER}_USD.parquet files. + start_ts : int + Start unix timestamp (seconds). + end_ts : int + End unix timestamp (seconds). + + Returns + ------- + pd.Series + Indexed by date, values are annualised daily volatility. + """ + # Load minute-level prices for both tokens + prices = {} + for ticker in tokens: + if ticker == "USDC": + prices[ticker] = None + else: + path = os.path.join(data_root, f"{ticker}_USD.parquet") + df = pd.read_parquet(path) + df = df[(df["unix"] >= start_ts * 1000) & (df["unix"] <= end_ts * 1000)] + df["datetime"] = pd.to_datetime(df["unix"], unit="ms") + df = df.set_index("datetime")["close"] + prices[ticker] = df + + # Compute price ratio (token[0] / token[1]) + t0, t1 = tokens[0], tokens[1] + if prices[t0] is not None and prices[t1] is not None: + # Align on common timestamps + combined = pd.DataFrame({"p0": prices[t0], "p1": prices[t1]}).dropna() + ratio = combined["p0"] / combined["p1"] + elif prices[t0] is not None: + ratio = prices[t0] # t1 is USDC ($1) + elif prices[t1] is not None: + ratio = 1.0 / prices[t1] # t0 is USDC + else: + raise ValueError("Both tokens are USDC — cannot compute ratio") + + # Subsample to 5-min intervals + ratio_5m = ratio.resample("5min").last().dropna() + log_returns = np.log(ratio_5m / ratio_5m.shift(1)).dropna() + + # Group by date, compute daily vol + log_returns_df = log_returns.to_frame("lr") + log_returns_df["date"] = log_returns_df.index.date + + daily_vol = log_returns_df.groupby("date")["lr"].std() + # Annualise: each day has ~288 5-min periods, scale by sqrt(288 * 365) + daily_vol_ann = daily_vol * np.sqrt(288 * 365) + + return daily_vol_ann + + +# --------------------------------------------------------------------------- +# Calibration DataFrame assembly +# --------------------------------------------------------------------------- + +def build_calibration_df(pool, data_root=None): + """Build daily calibration DataFrame from Balancer API + price parquets. + + All pool state (volume, effective TVL) comes from the Balancer V3 API. + The API's ``totalLiquidity`` is the effective TVL: for a reClAMM pool + on Balancer V3, the router sees real + virtual reserves, and + ``totalLiquidity`` reflects that full depth. Only the volatility + computation requires price parquets. + + Parameters + ---------- + pool : PoolConfig + Pool configuration (must have pool_address field). + data_root : str, optional + Directory containing {TICKER}_USD.parquet price files. + + Returns + ------- + pd.DataFrame + Columns: volume_usd, effective_tvl_usd, volatility. Indexed by date. + """ + from experiments.pool_registry import ( + get_data_end_date, + _date_to_unix, + ) + + if data_root is None: + data_root = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "quantammsim", "data", + ) + + start_ts = _date_to_unix(pool.plausible_start) + end_str = get_data_end_date(pool.tokens, data_root) + end_ts = _date_to_unix(end_str) + + print(f" Fetching Balancer snapshots for {pool.label} " + f"({pool.chain}, {pool.pool_address})...") + api_df = fetch_balancer_snapshots( + pool.chain, pool.pool_address, start_ts, end_ts, + ) + print(f" Got {len(api_df)} daily snapshots from API") + print(f" TVL range: ${api_df['total_liquidity_usd'].min():,.0f} — " + f"${api_df['total_liquidity_usd'].max():,.0f}") + + print(f" Computing daily volatility from price parquets...") + vol_series = compute_daily_volatility(pool.tokens, data_root, start_ts, end_ts) + print(f" Got {len(vol_series)} daily volatility values") + + # Assemble: volume + TVL from API, volatility from parquets + combined = api_df[["volume_usd", "total_liquidity_usd"]].copy() + combined = combined.rename(columns={"total_liquidity_usd": "effective_tvl_usd"}) + combined["volatility"] = vol_series + combined = combined.dropna() + + print(f" Combined: {len(combined)} days after join") + return combined + + +# --------------------------------------------------------------------------- +# OLS calibration +# --------------------------------------------------------------------------- + +def run_ols_calibration(daily_df, base_fee, model="sqrt"): + """OLS regression for Tsoukalas model params. + + Parameters + ---------- + daily_df : pd.DataFrame + Must contain columns: volume_usd, volatility, effective_tvl_usd. + base_fee : float + Static swap fee (e.g. 0.003). + model : str + 'sqrt' or 'log' — TVL regressor transformation. + + Returns + ------- + noise_params : dict + Coefficients for run_fingerprint["reclamm_noise_params"]. + diagnostics : dict + Standard errors, R², residual summary. + """ + if model == "loglinear": + # Multiplicative model: log(V) = b_0 + b_sigma·σ + b_c·log(TVL) + # Implies: V = exp(b_0) · TVL^b_c · exp(b_sigma·σ) + mask = daily_df["volume_usd"].values > 0 + n_dropped = int((~mask).sum()) + df_fit = daily_df[mask] + + y_log = np.log(df_fit["volume_usd"].values) + X = np.column_stack([ + np.ones(len(df_fit)), + df_fit["volatility"].values, + np.log(df_fit["effective_tvl_usd"].values), + ]) + + beta, _, _, _ = np.linalg.lstsq(X, y_log, rcond=None) + b_0, b_sigma, b_c = beta + + residuals = y_log - X @ beta + n, k = X.shape + bread = np.linalg.inv(X.T @ X) + hc1_scale = n / max(n - k, 1) + meat = X.T @ np.diag(residuals**2 * hc1_scale) @ X + robust_cov = bread @ meat @ bread + se = np.sqrt(np.diag(robust_cov)) + + ss_res = np.sum(residuals**2) + ss_tot = np.sum((y_log - y_log.mean())**2) + r_squared = 1.0 - ss_res / max(ss_tot, 1e-30) + + # Pseudo-R² in levels (median predictor) + y_pred_level = np.exp(X @ beta) + y_actual_level = df_fit["volume_usd"].values + res_level = y_actual_level - y_pred_level + r_sq_level = 1.0 - np.sum(res_level**2) / max( + np.sum((y_actual_level - y_actual_level.mean())**2), 1e-30) + + noise_params = { + "b_0": float(b_0), "b_sigma": float(b_sigma), + "b_c": float(b_c), "base_fee": float(base_fee), + } + diagnostics = { + "se": {"b_0": float(se[0]), "b_sigma": float(se[1]), + "b_c": float(se[2])}, + "r_squared": float(r_squared), + "r_squared_level": float(r_sq_level), + "n_obs": int(n), + "n_dropped_zero": n_dropped, + "residual_mean": float(np.mean(residuals)), + "residual_std": float(np.std(residuals)), + "smearing_factor": float(np.exp(np.var(residuals, ddof=1) / 2)), + "model": "loglinear", + } + return noise_params, diagnostics + + # --- Linear models (sqrt / log) --- + y = daily_df["volume_usd"].values / 1e6 + + if model == "sqrt": + tvl_eff = np.sqrt(daily_df["effective_tvl_usd"].values / 1e6) + elif model == "log": + tvl_eff = np.log(np.maximum(daily_df["effective_tvl_usd"].values / 1e6, 1e-30)) + else: + raise ValueError(f"Unknown model: {model!r}. Use 'sqrt', 'log', or 'loglinear'.") + + X = np.column_stack([ + np.ones(len(daily_df)), # a_0 + daily_df["volatility"].values, # a_sigma + tvl_eff, # a_c + ]) + + beta, residuals_ss, rank, sv = np.linalg.lstsq(X, y, rcond=None) + a_0, a_sigma, a_c = beta + + # Heteroskedasticity-robust standard errors (HC1) + residuals = y - X @ beta + n, k = X.shape + bread = np.linalg.inv(X.T @ X) + hc1_scale = n / max(n - k, 1) + meat = X.T @ np.diag(residuals**2 * hc1_scale) @ X + robust_cov = bread @ meat @ bread + se = np.sqrt(np.diag(robust_cov)) + + # R-squared + ss_res = np.sum(residuals**2) + ss_tot = np.sum((y - np.mean(y))**2) + r_squared = 1.0 - ss_res / max(ss_tot, 1e-30) + + noise_params = { + "a_0_base": float(a_0), + "a_f": 0.0, # not identified with static fees + "a_sigma": float(a_sigma), + "a_c": float(a_c), + "base_fee": float(base_fee), + } + + diagnostics = { + "se": dict(zip( + ["a_0", "a_sigma", "a_c"], + se.tolist(), + )), + "r_squared": float(r_squared), + "n_obs": int(n), + "residual_mean": float(np.mean(residuals)), + "residual_std": float(np.std(residuals)), + "model": model, + } + + return noise_params, diagnostics + + +# --------------------------------------------------------------------------- +# Plotting +# --------------------------------------------------------------------------- + +def plot_calibration_diagnostics(daily_df, noise_params, diagnostics, + pool_label="", model="sqrt", + output_dir="results"): + """Generate diagnostic plots for the noise volume calibration. + + Produces a 2×2 figure: + Top-left: Time series — real vs predicted daily volume + effective TVL + Top-right: Scatter — predicted vs actual with 45° line + Bot-left: Residuals vs time + residuals vs fitted + Bot-right: Component decomposition (stacked contributions) + + Parameters + ---------- + daily_df : pd.DataFrame + Calibration DataFrame (indexed by date). + noise_params : dict + Fitted coefficients from run_ols_calibration. + diagnostics : dict + Diagnostics dict from run_ols_calibration. + pool_label : str + Pool name for titles. + model : str + 'sqrt' or 'log'. + output_dir : str + Directory for output PNGs. + """ + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + from matplotlib.dates import DateFormatter + import matplotlib.dates as mdates + + os.makedirs(output_dir, exist_ok=True) + + # Extract data + dates = pd.to_datetime(daily_df.index) + y_actual = daily_df["volume_usd"].values / 1e6 # in $M + eff_tvl = daily_df["effective_tvl_usd"].values + vol = daily_df["volatility"].values + + r2 = diagnostics["r_squared"] + n = diagnostics["n_obs"] + + if model == "loglinear": + b_0 = noise_params["b_0"] + b_sigma = noise_params["b_sigma"] + b_c = noise_params["b_c"] + log_tvl = np.log(np.maximum(eff_tvl, 1.0)) + y_pred_log = b_0 + b_sigma * vol + b_c * log_tvl + y_pred = np.exp(y_pred_log) / 1e6 # median prediction in $M + + # Log-space residuals + mask_pos = daily_df["volume_usd"].values > 0 + residuals = np.full(len(dates), np.nan) + residuals[mask_pos] = ( + np.log(daily_df["volume_usd"].values[mask_pos]) + - y_pred_log[mask_pos] + ) + resid_unit = "log scale" + r2_level = diagnostics.get("r_squared_level") + else: + a_0 = noise_params["a_0_base"] + a_sigma = noise_params["a_sigma"] + a_c = noise_params["a_c"] + + if model == "sqrt": + tvl_term = a_c * np.sqrt(eff_tvl / 1e6) + else: + tvl_term = a_c * np.log(np.maximum(eff_tvl / 1e6, 1e-30)) + + y_pred = a_0 + a_sigma * vol + tvl_term + residuals = y_actual - y_pred + resid_unit = "$M" + r2_level = None + + # --- Figure 1: Main diagnostics (2×2) --- + fig, axes = plt.subplots(2, 2, figsize=(16, 12)) + + # (0,0) Time series: real vs predicted + TVL on secondary axis + ax = axes[0, 0] + ax.plot(dates, y_actual, color="steelblue", alpha=0.7, linewidth=1, + label="Actual volume") + ax.plot(dates, y_pred, color="crimson", linewidth=1.5, + label="Predicted volume") + ax.set_ylabel("Daily volume ($M)", color="steelblue") + ax.tick_params(axis="y", labelcolor="steelblue") + ax.legend(loc="upper left", fontsize=8) + r2_str = (f"R²(log)={r2:.3f}, R²(level)={r2_level:.3f}" + if r2_level is not None else f"R²={r2:.3f}") + ax.set_title(f"Daily volume: actual vs predicted ({r2_str}, n={n})") + ax.xaxis.set_major_formatter(DateFormatter("%b %y")) + + ax2 = ax.twinx() + ax2.fill_between(dates, eff_tvl / 1e6, alpha=0.15, color="green", + label="Effective TVL ($M)") + ax2.set_ylabel("Effective TVL ($M)", color="green") + ax2.tick_params(axis="y", labelcolor="green") + ax2.legend(loc="upper right", fontsize=8) + + # (0,1) Scatter: predicted vs actual + ax = axes[0, 1] + ax.scatter(y_pred, y_actual, alpha=0.5, s=15, color="steelblue", + edgecolors="none") + lims = [min(y_pred.min(), y_actual.min()), max(y_pred.max(), y_actual.max())] + margin = (lims[1] - lims[0]) * 0.05 + lims = [lims[0] - margin, lims[1] + margin] + ax.plot(lims, lims, "k--", linewidth=0.8, alpha=0.5, label="45° line") + ax.set_xlabel("Predicted ($M)") + ax.set_ylabel("Actual ($M)") + ax.set_title("Predicted vs actual") + ax.legend(fontsize=8) + ax.set_aspect("equal", adjustable="box") + + # (1,0) Residuals: vs time (top) and vs fitted (bottom) + ax = axes[1, 0] + ax.scatter(dates, residuals, alpha=0.5, s=12, color="steelblue", + edgecolors="none") + ax.axhline(0, color="black", linewidth=0.8) + # 7-day rolling mean of residuals + res_series = pd.Series(residuals, index=dates) + rolling_mean = res_series.rolling(7, min_periods=1).mean() + ax.plot(dates, rolling_mean, color="crimson", linewidth=1.5, + label="7-day rolling mean") + ax.set_ylabel(f"Residual ({resid_unit})") + ax.set_title("Residuals vs time") + ax.legend(fontsize=8) + ax.xaxis.set_major_formatter(DateFormatter("%b %y")) + + # (1,1) Component decomposition + ax = axes[1, 1] + if model == "loglinear": + # Log-space additive decomposition as line plots + log_tvl_plot = np.log(np.maximum(eff_tvl, 1.0)) + comp_base = np.full(len(dates), b_0) + comp_tvl = b_c * log_tvl_plot + comp_vol = b_sigma * vol + total_log = comp_base + comp_tvl + comp_vol + vol_usd = daily_df["volume_usd"].values.copy() + vol_usd[vol_usd <= 0] = np.nan + actual_log = np.log(vol_usd) + ax.plot(dates, comp_base, color="grey", linestyle="--", linewidth=1, + label=f"b_0 = {b_0:.2f}") + ax.plot(dates, comp_base + comp_tvl, color="green", linewidth=1.5, + label=f"b_0 + b_c·log(TVL) (b_c={b_c:.4f})") + ax.plot(dates, total_log, color="crimson", linewidth=1.5, + label="Full prediction") + ax.scatter(dates, actual_log, color="steelblue", s=10, alpha=0.5, + label="Actual log(V)", zorder=5) + ax.set_ylabel("log(Volume, USD)") + ax.set_title("Component decomposition (log space)") + + fig.suptitle( + f"{pool_label} — noise calibration ({model})\n" + f"log(V) = {b_0:.2f} + {b_sigma:.4f}·σ + {b_c:.4f}·log(TVL)", + fontsize=11, + ) + else: + intercept_contrib = np.full(len(dates), a_0) + vol_contrib = a_sigma * vol + tvl_contrib = tvl_term + + ax.fill_between(dates, 0, intercept_contrib, alpha=0.3, color="grey", + label=f"a_0 = {a_0:.4f}") + ax.fill_between(dates, intercept_contrib, intercept_contrib + vol_contrib, + alpha=0.3, color="orange", + label=f"a_σ·σ (a_σ={a_sigma:.4f})") + ax.fill_between(dates, intercept_contrib + vol_contrib, + intercept_contrib + vol_contrib + tvl_contrib, + alpha=0.3, color="green", + label=f"a_c·{model}(TVL) (a_c={a_c:.4f})") + ax.plot(dates, y_actual, color="steelblue", linewidth=1, alpha=0.7, + label="Actual") + ax.set_ylabel("Volume ($M)") + ax.set_title("Component decomposition") + + fig.suptitle( + f"{pool_label} — Tsoukalas noise calibration ({model})\n" + f"V/1e6 = {a_0:.4f} + {a_sigma:.4f}·σ + {a_c:.4f}·{model}(TVL_eff/1e6)", + fontsize=11, + ) + ax.legend(fontsize=7, loc="upper left") + ax.xaxis.set_major_formatter(DateFormatter("%b %y")) + plt.tight_layout() + + fname = f"noise_calibration_{pool_label}_{model}.png" + path = os.path.join(output_dir, fname) + plt.savefig(path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path}") + + # --- Figure 2: Residuals vs each regressor --- + fig2, axes2 = plt.subplots(1, 3, figsize=(16, 5)) + + # Residuals vs volatility + ax = axes2[0] + ax.scatter(vol, residuals, alpha=0.5, s=12, color="orange", edgecolors="none") + ax.axhline(0, color="black", linewidth=0.8) + ax.set_xlabel("Volatility (annualised)") + ax.set_ylabel(f"Residual ({resid_unit})") + ax.set_title("Residuals vs volatility") + + # Residuals vs effective TVL + ax = axes2[1] + ax.scatter(eff_tvl / 1e6, residuals, alpha=0.5, s=12, color="green", + edgecolors="none") + ax.axhline(0, color="black", linewidth=0.8) + ax.set_xlabel("Effective TVL ($M)") + ax.set_ylabel(f"Residual ({resid_unit})") + ax.set_title("Residuals vs effective TVL") + + # Residuals vs fitted + ax = axes2[2] + ax.scatter(y_pred, residuals, alpha=0.5, s=12, color="steelblue", + edgecolors="none") + ax.axhline(0, color="black", linewidth=0.8) + ax.set_xlabel("Fitted ($M)") + ax.set_ylabel(f"Residual ({resid_unit})") + ax.set_title("Residuals vs fitted") + + fig2.suptitle(f"{pool_label} — Residual diagnostics ({model})", fontsize=11) + plt.tight_layout() + + fname2 = f"noise_residuals_{pool_label}_{model}.png" + path2 = os.path.join(output_dir, fname2) + plt.savefig(path2, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path2}") + + return path, path2 + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser( + description="Calibrate Tsoukalas noise volume model for reClAMM" + ) + parser.add_argument( + "--csv", default=None, + help="Path to CSV with columns: volume_usd, volatility, effective_tvl_usd", + ) + parser.add_argument( + "--pool", default=None, + help="Pool label from pool_registry (e.g. cbBTC_WETH) for end-to-end calibration", + ) + parser.add_argument("--base-fee", type=float, default=None, + help="Override base fee (default: use pool's swap_fee)") + parser.add_argument("--model", choices=["sqrt", "log", "loglinear"], + default="sqrt") + parser.add_argument( + "--output", default=None, + help="Output JSON file path. Defaults to stdout.", + ) + parser.add_argument( + "--plot", action="store_true", + help="Generate diagnostic plots (saved to --output-dir)", + ) + parser.add_argument( + "--output-dir", default="results", + help="Directory for diagnostic plots (default: results)", + ) + args = parser.parse_args() + + if args.csv is None and args.pool is None: + parser.error("One of --csv or --pool is required") + + if args.pool is not None: + # End-to-end mode: fetch data, assemble, calibrate + from experiments.pool_registry import POOL_REGISTRY + + if args.pool not in POOL_REGISTRY: + print(f"Unknown pool: {args.pool}", file=sys.stderr) + print(f"Available: {list(POOL_REGISTRY.keys())}", file=sys.stderr) + sys.exit(1) + + pool = POOL_REGISTRY[args.pool] + base_fee = args.base_fee if args.base_fee is not None else pool.swap_fee + + print(f"Calibrating noise model for {pool.label} ({pool.chain})") + print(f" Swap fee: {base_fee}") + print(f" Model: {args.model}") + + df = build_calibration_df(pool) + else: + # CSV mode + df = pd.read_csv(args.csv) + required_cols = {"volume_usd", "volatility", "effective_tvl_usd"} + missing = required_cols - set(df.columns) + if missing: + print(f"Error: missing columns: {missing}", file=sys.stderr) + sys.exit(1) + base_fee = args.base_fee if args.base_fee is not None else 0.003 + + noise_params, diagnostics = run_ols_calibration(df, base_fee, args.model) + + # Print diagnostics + print(f"\n OLS Results ({args.model} model):") + print(f" R² = {diagnostics['r_squared']:.4f}") + if "r_squared_level" in diagnostics: + print(f" R²(level) = {diagnostics['r_squared_level']:.4f}") + if "n_dropped_zero" in diagnostics and diagnostics["n_dropped_zero"] > 0: + print(f" Dropped {diagnostics['n_dropped_zero']} zero-volume days") + if "smearing_factor" in diagnostics: + print(f" Smearing factor = {diagnostics['smearing_factor']:.4f} " + f"(E[V]/median[V])") + print(f" n = {diagnostics['n_obs']}") + print(f" Coefficients:") + if args.model == "loglinear": + coef_keys = ["b_0", "b_sigma", "b_c"] + else: + coef_keys = ["a_0", "a_sigma", "a_c"] + for key in coef_keys: + param_key = "a_0_base" if key == "a_0" else key + val = noise_params[param_key] + se = diagnostics["se"][key] + t_stat = val / se if se > 0 else float("inf") + print(f" {key:>8} = {val:>10.4f} (SE={se:.4f}, t={t_stat:.2f})") + print(f" Residual: mean={diagnostics['residual_mean']:.6f}, " + f"std={diagnostics['residual_std']:.4f}") + + # Plot diagnostics + if args.plot: + label = args.pool if args.pool else "custom" + plot_calibration_diagnostics( + df, noise_params, diagnostics, + pool_label=label, model=args.model, + output_dir=args.output_dir, + ) + + result = { + "noise_params": noise_params, + "diagnostics": diagnostics, + } + + output_str = json.dumps(result, indent=2) + if args.output: + with open(args.output, "w") as f: + f.write(output_str + "\n") + print(f"\nWrote calibration to {args.output}", file=sys.stderr) + else: + print(f"\n{output_str}") + + +if __name__ == "__main__": + main() diff --git a/scripts/compare_reclamm_thermostats.py b/scripts/compare_reclamm_thermostats.py new file mode 100644 index 0000000..8a2c374 --- /dev/null +++ b/scripts/compare_reclamm_thermostats.py @@ -0,0 +1,379 @@ +"""Compare geometric vs constant-arc-length thermostats on historic data. + +Runs AAVE/ETH reClAMM pool simulations with both interpolation methods. +Plots: pool value, cumulative LVR, price path, empirical weights, +value difference, LVR ratio, and per-step LVR distribution (∝ Δs²). + +Usage: + cd + source ~/miniconda3/etc/profile.d/conda.sh && conda activate qsim-reclamm + python scripts/compare_reclamm_thermostats.py +""" + +import jax.numpy as jnp +import numpy as np +import matplotlib.pyplot as plt +from quantammsim.runners.jax_runners import do_run_on_historic_data + + +def to_daily_price_shift_base(daily_price_shift_exponent): + """Convert shift rate to daily price shift base (matches Solidity).""" + return 1.0 - daily_price_shift_exponent / 124649.0 + + +# Pool configurations to compare +CONFIGS = [ + { + "name": "AAVE/ETH on-chain (25bps, narrow range)", + "tokens": ["AAVE", "ETH"], + "start": "2024-06-01 00:00:00", + "end": "2025-06-01 00:00:00", + "fees": 0.0025, + "price_ratio": 1.5, + "centeredness_margin": 0.5, + "daily_price_shift_exponent": 0.1, + }, + { + "name": "AAVE/ETH wide range (25bps)", + "tokens": ["AAVE", "ETH"], + "start": "2024-06-01 00:00:00", + "end": "2025-06-01 00:00:00", + "fees": 0.0025, + "price_ratio": 4.0, + "centeredness_margin": 0.2, + "daily_price_shift_exponent": 1.0, + }, + { + "name": "AAVE/ETH zero fees (narrow)", + "tokens": ["AAVE", "ETH"], + "start": "2024-06-01 00:00:00", + "end": "2025-06-01 00:00:00", + "fees": 0.0, + "price_ratio": 1.5, + "centeredness_margin": 0.5, + "daily_price_shift_exponent": 0.1, + }, +] + + +def make_fingerprint(cfg, interpolation_method, centeredness_scaling=False): + """Build run fingerprint for a given config and interpolation method.""" + return { + "tokens": cfg["tokens"], + "rule": "reclamm", + "startDateString": cfg["start"], + "endDateString": cfg["end"], + "initial_pool_value": 1000000.0, + "do_arb": True, + "fees": cfg["fees"], + "gas_cost": 0.0, + "arb_fees": 0.0, + "reclamm_interpolation_method": interpolation_method, + "reclamm_arc_length_speed": None, # auto-calibrate + "reclamm_centeredness_scaling": centeredness_scaling, + } + + +def make_params(cfg): + """Build pool params from config.""" + return { + "price_ratio": jnp.array(cfg["price_ratio"]), + "centeredness_margin": jnp.array(cfg["centeredness_margin"]), + "daily_price_shift_base": jnp.array( + to_daily_price_shift_base(cfg["daily_price_shift_exponent"]) + ), + } + + +def run_comparison(cfg): + """Run all thermostat variants, return results dict.""" + params = make_params(cfg) + + results = {} + for method in ["geometric", "constant_arc_length"]: + fp = make_fingerprint(cfg, method) + results[method] = do_run_on_historic_data( + run_fingerprint=fp, params=params + ) + + # Geometric + centeredness-proportional scaling (scales decay duration) + fp_geo_scaled = make_fingerprint(cfg, "geometric", centeredness_scaling=True) + results["geometric_scaled"] = do_run_on_historic_data( + run_fingerprint=fp_geo_scaled, params=params + ) + + # Arc-length + centeredness-proportional scaling (scales speed) + fp_cal_scaled = make_fingerprint(cfg, "constant_arc_length", centeredness_scaling=True) + results["cal_scaled"] = do_run_on_historic_data( + run_fingerprint=fp_cal_scaled, params=params + ) + + return results + + +def print_comparison(cfg, results): + """Print text summary table.""" + methods = [ + ("Geometric", results["geometric"]), + ("Geo+Scaled", results["geometric_scaled"]), + ("Const Arc", results["constant_arc_length"]), + ("Arc+Scaled", results["cal_scaled"]), + ] + + hodl_value = float((methods[0][1]["reserves"][0] * methods[0][1]["prices"][-1]).sum()) + + print("=" * 105) + print(f" {cfg['name']}") + print(f" price_ratio={cfg['price_ratio']}, " + f"margin={cfg['centeredness_margin']}, " + f"shift_exp={cfg['daily_price_shift_exponent']}, " + f"fees={cfg['fees']}") + print("-" * 105) + header = " {:20s}".format("") + for name, _ in methods: + header += f" {name:>14s}" + print(header) + + row = " {:20s}".format("Final value") + for _, r in methods: + row += f" ${float(r['final_value']):>13,.0f}" + print(row) + + print(f" {'HODL value':20s} ${hodl_value:>13,.0f}") + + row = " {:20s}".format("LVR (HODL - final)") + for _, r in methods: + lvr = hodl_value - float(r["final_value"]) + row += f" ${lvr:>13,.0f}" + print(row) + + row = " {:20s}".format("Return") + for _, r in methods: + ret = (float(r["final_value"]) / float(r["value"][0]) - 1) * 100 + row += f" {ret:>13.2f}%" + print(row) + + row = " {:20s}".format("vs HODL") + for _, r in methods: + vs = (float(r["final_value"]) / hodl_value - 1) * 100 + row += f" {vs:>13.2f}%" + print(row) + print("=" * 105) + + +def plot_comparison(cfg, results, fig_idx): + """Plot 4-panel comparison for one config.""" + # Method name → (result dict, color, linestyle) + variants = { + "Geometric": (results["geometric"], "C0", "-"), + "Geo+Scaled": (results["geometric_scaled"], "C1", "-"), + "Const arc-len": (results["constant_arc_length"], "C2", "--"), + "Arc+Scaled": (results["cal_scaled"], "C3", "--"), + } + + geo = results["geometric"] + geo_prices = np.array(geo["prices"]) + geo_reserves = np.array(geo["reserves"]) + n_steps = len(np.array(geo["value"])) + t_days = np.arange(n_steps) / (60 * 24) + + hodl_traj = (geo_reserves[0] * geo_prices[:n_steps]).sum(axis=-1) + price_ratio_traj = geo_prices[:n_steps, 0] / geo_prices[:n_steps, 1] + + fig, axes = plt.subplots(2, 2, figsize=(14, 10)) + fig.suptitle(cfg["name"], fontsize=13, fontweight="bold") + + # (0,0) Pool value over time + ax = axes[0, 0] + for name, (r, color, ls) in variants.items(): + vals = np.array(r["value"]) + ax.plot(t_days, vals / 1e6, color=color, ls=ls, label=name, alpha=0.9) + ax.plot(t_days, np.array(hodl_traj) / 1e6, color="gray", ls=":", + alpha=0.5, label="HODL") + ax.set_xlabel("Days") + ax.set_ylabel("Pool value ($M)") + ax.set_title("Pool value") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # (0,1) Cumulative LVR + ax = axes[0, 1] + for name, (r, color, ls) in variants.items(): + vals = np.array(r["value"]) + lvr = np.array(hodl_traj) - vals + ax.plot(t_days, lvr / 1e3, color=color, ls=ls, label=name, alpha=0.9) + ax.set_xlabel("Days") + ax.set_ylabel("Cumulative LVR ($K)") + ax.set_title("Cumulative LVR (HODL - pool value)") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # (1,0) Price ratio + ax = axes[1, 0] + ax.plot(t_days, price_ratio_traj, color="C4", alpha=0.7) + ax.set_xlabel("Days") + ax.set_ylabel(f"{cfg['tokens'][0]}/{cfg['tokens'][1]} price ratio") + ax.set_title("Price path") + ax.grid(True, alpha=0.3) + + # (1,1) Empirical weights + ax = axes[1, 1] + for name, (r, color, ls) in variants.items(): + w = np.array(r["weights"]) + n_w = min(len(w), n_steps) + t_w = np.arange(n_w) / (60 * 24) + ax.plot(t_w, w[:n_w, 0], color=color, ls=ls, label=name, alpha=0.9) + ax.set_xlabel("Days") + ax.set_ylabel(f"Weight ({cfg['tokens'][0]})") + ax.set_title("Empirical weight (token 0)") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + plt.tight_layout() + fname = f"reclamm_thermostat_comparison_{fig_idx}.png" + plt.savefig(fname, dpi=150) + print(f"Saved {fname}") + plt.close(fig) + + # Second figure: diagnostics + geo_values = np.array(geo["value"]) + geo_lvr = np.array(hodl_traj) - geo_values + + fig2, axes2 = plt.subplots(1, 3, figsize=(18, 5)) + fig2.suptitle(f"{cfg['name']} — diagnostics", fontsize=13, fontweight="bold") + + # (left) Value difference vs geometric + ax = axes2[0] + for name, (r, color, ls) in variants.items(): + if name == "Geometric": + continue + vals = np.array(r["value"]) + ax.plot(t_days, (vals - geo_values) / 1e3, color=color, ls=ls, + label=name, alpha=0.9) + ax.axhline(0, color="gray", ls="--", alpha=0.5) + ax.set_xlabel("Days") + ax.set_ylabel("Value difference ($K)") + ax.set_title("Minus Geometric") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # (middle) LVR ratio over time + ax = axes2[1] + mask = np.abs(geo_lvr) > 100 + if mask.any(): + for name, (r, color, ls) in variants.items(): + if name == "Geometric": + continue + vals = np.array(r["value"]) + method_lvr = np.array(hodl_traj) - vals + ratio = np.full_like(geo_lvr, np.nan) + ratio[mask] = method_lvr[mask] / geo_lvr[mask] + ax.plot(t_days, ratio, color=color, ls=ls, alpha=0.7, label=name) + ax.axhline(1.0, color="gray", ls="--", alpha=0.5) + ax.set_ylabel("LVR ratio (method / geometric)") + ax.legend(fontsize=8) + else: + ax.text(0.5, 0.5, "LVR too small to compare", + transform=ax.transAxes, ha="center", va="center") + ax.set_xlabel("Days") + ax.set_title("Relative LVR") + ax.grid(True, alpha=0.3) + + # (right) Per-step LVR histogram + ax = axes2[2] + all_pos = [] + for name, (r, color, ls) in variants.items(): + vals = np.array(r["value"]) + method_lvr = np.array(hodl_traj) - vals + step_lvr = np.diff(method_lvr) + pos = step_lvr[step_lvr > 0] + all_pos.append((name, pos, color)) + has_data = [len(p) > 10 for _, p, _ in all_pos] + if any(has_data): + max_val = max(np.percentile(p, 99) for _, p, _ in all_pos if len(p) > 10) + bins = np.linspace(0, max_val, 50) + for name, pos, color in all_pos: + if len(pos) > 10: + ax.hist(pos, bins=bins, color=color, alpha=0.3, label=name, + density=True) + ax.set_xlabel("Per-step LVR ($)") + ax.set_ylabel("Density") + ax.legend(fontsize=8) + else: + ax.text(0.5, 0.5, "Too few thermostat steps", + transform=ax.transAxes, ha="center", va="center") + ax.set_title("Per-step LVR distribution") + ax.grid(True, alpha=0.3) + + plt.tight_layout() + fname2 = f"reclamm_thermostat_diff_{fig_idx}.png" + plt.savefig(fname2, dpi=150) + print(f"Saved {fname2}") + plt.close(fig2) + + +if __name__ == "__main__": + all_results = [] + for i, cfg in enumerate(CONFIGS): + print(f"\n>>> Running {cfg['name']}...") + try: + results = run_comparison(cfg) + print_comparison(cfg, results) + plot_comparison(cfg, results, i) + all_results.append((cfg, results)) + except Exception as e: + print(f" FAILED: {e}") + import traceback + traceback.print_exc() + + # Summary overlay: all configs on one figure (pool value normalised) + if len(all_results) > 1: + fig, axes = plt.subplots(1, 2, figsize=(16, 5)) + fig.suptitle("Cross-config comparison (normalised)", fontsize=13, + fontweight="bold") + + method_keys = [ + ("geometric", "geo", "-"), + ("geometric_scaled", "geo+s", "-."), + ("constant_arc_length", "arc", "--"), + ("cal_scaled", "arc+s", ":"), + ] + + for i, (cfg, results) in enumerate(all_results): + geo_v = np.array(results["geometric"]["value"]) + t = np.arange(len(geo_v)) / (60 * 24) + short_name = cfg["name"].split("(")[0].strip() + + for j, (key, suffix, ls) in enumerate(method_keys): + v = np.array(results[key]["value"]) + color_idx = i * len(method_keys) + j + + # (left) Normalised pool value + axes[0].plot(t, v / v[0], ls=ls, alpha=0.8, + label=f"{short_name} {suffix}", + color=f"C{color_idx % 10}") + + # (right) Value difference vs geometric (skip geo itself) + if key != "geometric": + pct_diff = (v - geo_v) / geo_v * 100 + axes[1].plot(t, pct_diff, ls=ls, alpha=0.8, + label=f"{short_name} {suffix}", + color=f"C{color_idx % 10}") + + axes[0].set_xlabel("Days") + axes[0].set_ylabel("Normalised pool value") + axes[0].set_title("Pool value (V/V0)") + axes[0].legend(fontsize=6, ncol=2) + axes[0].grid(True, alpha=0.3) + + axes[1].set_xlabel("Days") + axes[1].set_ylabel("(Method - Geo) / Geo (%)") + axes[1].set_title("Relative value difference vs Geometric") + axes[1].axhline(0, color="gray", ls="--", alpha=0.5) + axes[1].legend(fontsize=6, ncol=2) + axes[1].grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig("reclamm_thermostat_summary.png", dpi=150) + print("\nSaved reclamm_thermostat_summary.png") + plt.close(fig) diff --git a/scripts/demo_run_reclamm.py b/scripts/demo_run_reclamm.py new file mode 100644 index 0000000..3ea21ec --- /dev/null +++ b/scripts/demo_run_reclamm.py @@ -0,0 +1,207 @@ +"""Demo runs for reClAMM pools vs Balancer 50/50 baseline. + +Runs reClAMM pool simulations with parameters pulled from on-chain pools +(AAVE/ETH) and hypothetical configurations, each paired with a Balancer +50/50 constant-weight pool at the same fee level for comparison. + +Usage: + cd + source ~/miniconda3/etc/profile.d/conda.sh && conda activate qsim-reclamm + python scripts/demo_run_reclamm.py +""" + +import jax.numpy as jnp +from quantammsim.runners.jax_runners import do_run_on_historic_data + + +def to_daily_price_shift_base(daily_price_shift_exponent): + """Convert shift rate to daily price shift base (matches Solidity).""" + return 1.0 - daily_price_shift_exponent / 124649.0 + + +def balancer_fingerprint(tokens, start, end, fees): + """Build a Balancer 50/50 fingerprint matching the given reclamm config.""" + return { + "tokens": tokens, + "rule": "balancer", + "startDateString": start, + "endDateString": end, + "initial_pool_value": 1000000.0, + "do_arb": True, + "fees": fees, + "gas_cost": 0.0, + "arb_fees": 0.0, + "chunk_period": 60, + "weight_interpolation_period": 60, + } + + +SCENARIOS = [ + { + "name": "AAVE/ETH on-chain (25bps)", + "reclamm": { + "fingerprint": { + "tokens": ["AAVE", "ETH"], + "rule": "reclamm", + "startDateString": "2024-06-01 00:00:00", + "endDateString": "2025-06-01 00:00:00", + "initial_pool_value": 1000000.0, + "do_arb": True, + "fees": 0.0025, + "gas_cost": 0.0, + "arb_fees": 0.0, + "chunk_period": 60, + "weight_interpolation_period": 60, + }, + "params": { + "price_ratio": jnp.array(1.5), + "centeredness_margin": jnp.array(0.5), + "daily_price_shift_base": jnp.array( + to_daily_price_shift_base(0.1) + ), + }, + }, + }, + { + "name": "AAVE/ETH zero fees", + "reclamm": { + "fingerprint": { + "tokens": ["AAVE", "ETH"], + "rule": "reclamm", + "startDateString": "2024-06-01 00:00:00", + "endDateString": "2025-06-01 00:00:00", + "initial_pool_value": 1000000.0, + "do_arb": True, + "fees": 0.0, + "gas_cost": 0.0, + "arb_fees": 0.0, + "chunk_period": 60, + "weight_interpolation_period": 60, + }, + "params": { + "price_ratio": jnp.array(1.5), + "centeredness_margin": jnp.array(0.5), + "daily_price_shift_base": jnp.array( + to_daily_price_shift_base(0.1) + ), + }, + }, + }, + { + "name": "AAVE/ETH wide range (25bps)", + "reclamm": { + "fingerprint": { + "tokens": ["AAVE", "ETH"], + "rule": "reclamm", + "startDateString": "2024-06-01 00:00:00", + "endDateString": "2025-06-01 00:00:00", + "initial_pool_value": 1000000.0, + "do_arb": True, + "fees": 0.0025, + "gas_cost": 0.0, + "arb_fees": 0.0, + "chunk_period": 60, + "weight_interpolation_period": 60, + }, + "params": { + "price_ratio": jnp.array(4.0), + "centeredness_margin": jnp.array(0.2), + "daily_price_shift_base": jnp.array( + to_daily_price_shift_base(1.0) + ), + }, + }, + }, + { + "name": "BTC/ETH (10bps)", + "reclamm": { + "fingerprint": { + "tokens": ["BTC", "ETH"], + "rule": "reclamm", + "startDateString": "2024-01-01 00:00:00", + "endDateString": "2025-06-01 00:00:00", + "initial_pool_value": 1000000.0, + "do_arb": True, + "fees": 0.001, + "gas_cost": 0.0, + "arb_fees": 0.0, + "chunk_period": 60, + "weight_interpolation_period": 60, + }, + "params": { + "price_ratio": jnp.array(2.0), + "centeredness_margin": jnp.array(0.3), + "daily_price_shift_base": jnp.array( + to_daily_price_shift_base(0.5) + ), + }, + }, + }, +] + + +def run_scenario(scenario): + """Run a reClAMM config and its Balancer 50/50 baseline, print comparison.""" + rc = scenario["reclamm"] + fp = rc["fingerprint"] + + # Run reClAMM + reclamm_result = do_run_on_historic_data( + run_fingerprint=fp, params=rc["params"] + ) + + # Run Balancer 50/50 with same tokens, dates, fees + bal_fp = balancer_fingerprint( + fp["tokens"], fp["startDateString"], fp["endDateString"], fp["fees"] + ) + bal_params = { + "initial_weights_logits": jnp.zeros(len(fp["tokens"])), + } + balancer_result = do_run_on_historic_data( + run_fingerprint=bal_fp, params=bal_params + ) + + # HODL value (from reClAMM initial reserves at final prices) + hodl_value = float( + (reclamm_result["reserves"][0] * reclamm_result["prices"][-1]).sum() + ) + + rc_final = float(reclamm_result["final_value"]) + bal_final = float(balancer_result["final_value"]) + rc_init = float(reclamm_result["value"][0]) + bal_init = float(balancer_result["value"][0]) + + print("=" * 80) + print(f" {scenario['name']}") + print(f" Tokens: {', '.join(fp['tokens'])} | Fees: {fp['fees']}") + print("-" * 80) + print(f" {'':30s} {'reClAMM':>14s} {'Balancer 50/50':>14s}") + print(f" {'Initial value':30s} ${rc_init:>13,.0f} ${bal_init:>13,.0f}") + print(f" {'Final value':30s} ${rc_final:>13,.0f} ${bal_final:>13,.0f}") + print( + f" {'Return':30s} " + f"{(rc_final / rc_init - 1) * 100:>13.2f}% " + f"{(bal_final / bal_init - 1) * 100:>13.2f}%" + ) + print( + f" {'vs HODL':30s} " + f"{(rc_final / hodl_value - 1) * 100:>13.2f}% " + f"{(bal_final / hodl_value - 1) * 100:>13.2f}%" + ) + print( + f" {'reClAMM vs Balancer':30s} " + f"{(rc_final / bal_final - 1) * 100:>13.2f}%" + ) + print("=" * 80) + + +if __name__ == "__main__": + for scenario in SCENARIOS: + print(f"\n>>> {scenario['name']}...") + try: + run_scenario(scenario) + except Exception as e: + print(f" FAILED: {e}") + import traceback + + traceback.print_exc() diff --git a/scripts/fetch_token_mcaps.py b/scripts/fetch_token_mcaps.py new file mode 100644 index 0000000..8e71d6f --- /dev/null +++ b/scripts/fetch_token_mcaps.py @@ -0,0 +1,196 @@ +"""Fetch token market caps from CoinGecko and cache locally. + +Usage: + python scripts/fetch_token_mcaps.py + +Output: + local_data/noise_calibration/token_mcaps.json +""" + +import json +import os +import sys +import time + +import requests + +OUTPUT_PATH = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "local_data", "noise_calibration", "token_mcaps.json", +) + +# Token symbol -> CoinGecko ID mapping +# Covers all tokens appearing in our 26 matched pools + common Balancer tokens +COINGECKO_IDS = { + # Blue-chip / wrapped natives + "WETH": "ethereum", + "ETH": "ethereum", + "WBTC": "wrapped-bitcoin", + "BTC": "bitcoin", + "cbBTC": "bitcoin", # Coinbase wrapped BTC — use BTC mcap + "USDC": "usd-coin", + "USDT": "tether", + "DAI": "dai", + "wstETH": "wrapped-steth", + "stETH": "staked-ether", + "rETH": "rocket-pool-eth", + "cbETH": "coinbase-wrapped-staked-eth", + "WMATIC": "polygon-ecosystem-token", + "MATIC": "polygon-ecosystem-token", + "POL": "polygon-ecosystem-token", + "WAVAX": "avalanche-2", + "AVAX": "avalanche-2", + "GNO": "gnosis", + "WXDAI": "dai", # Wrapped xDAI ≈ DAI + "xDAI": "dai", + "S": "sonic-3", + "wS": "sonic-3", + # Mid-cap DeFi + "AAVE": "aave", + "LINK": "chainlink", + "UNI": "uniswap", + "BAL": "balancer", + "MKR": "maker", + "CRV": "curve-dao-token", + "COMP": "compound-governance-token", + "SNX": "havven", + "LDO": "lido-dao", + "RPL": "rocket-pool", + "SUSHI": "sushi", + "YFI": "yearn-finance", + "1INCH": "1inch", + "ENS": "ethereum-name-service", + "ARB": "arbitrum", + "OP": "optimism", + "PENDLE": "pendle", + "ENA": "ethena", + "EIGEN": "eigenlayer", + "COW": "cow-protocol", + "SAFE": "safe", + # Smaller / specific tokens in our pools + "ACX": "across-protocol", + "ALCX": "alchemix", + "QI": "benqi", + "QNT": "quant-network", + "RDNT": "radiant-capital", + # TREE not on CoinGecko — handled as fallback below + "XAI": "xai-blockchain", + # Wrapped aTokens — use underlying + "waEthLidoWETH": "ethereum", + "waEthLidowstETH": "wrapped-steth", + "waBasWETH": "ethereum", + "waBasUSDC": "usd-coin", + "waEthUSDC": "usd-coin", + "waGnoGNO": "gnosis", + "waGnowstETH": "wrapped-steth", + # Additional tokens from expanded pool set + "wPOL": "polygon-ecosystem-token", + "stS": "sonic-3", # Staked Sonic — use S mcap + "JitoSOL": "jito-governance-token", + "scUSD": "usd-coin", # Rings scUSD stablecoin — use USDC mcap as proxy + "DOLA": "dola-usd", +} + +# Asset type classification +STABLECOINS = { + "USDC", "USDT", "DAI", "WXDAI", "xDAI", "GHO", "LUSD", "crvUSD", + "FRAX", "sDAI", "scUSD", "DOLA", + "waBasUSDC", "waEthUSDC", +} +NATIVE_LST = { + "WETH", "ETH", "wstETH", "stETH", "rETH", "cbETH", + "WBTC", "BTC", "cbBTC", + "WMATIC", "MATIC", "POL", "wPOL", + "WAVAX", "AVAX", + "GNO", "S", "wS", "stS", + "JitoSOL", + "waEthLidoWETH", "waEthLidowstETH", + "waBasWETH", "waGnoGNO", "waGnowstETH", +} +# Everything else is VOLATILE (asset_type=2) + + +def fetch_mcaps(): + """Fetch market caps from CoinGecko in batches.""" + unique_ids = sorted(set(COINGECKO_IDS.values())) + print(f"Fetching market caps for {len(unique_ids)} unique CoinGecko IDs...") + + # CoinGecko allows up to 250 IDs per request + batch_size = 100 + all_data = {} + + for i in range(0, len(unique_ids), batch_size): + batch = unique_ids[i:i + batch_size] + ids_str = ",".join(batch) + url = ( + f"https://api.coingecko.com/api/v3/simple/price" + f"?ids={ids_str}&vs_currencies=usd&include_market_cap=true" + ) + resp = requests.get(url, timeout=30) + resp.raise_for_status() + data = resp.json() + all_data.update(data) + print(f" Batch {i // batch_size + 1}: {len(data)} tokens") + if i + batch_size < len(unique_ids): + time.sleep(1) # rate limit + + # Build symbol -> mcap mapping + mcaps = {} + missing = [] + for symbol, gecko_id in COINGECKO_IDS.items(): + if gecko_id in all_data and "usd_market_cap" in all_data[gecko_id]: + mcaps[symbol] = { + "mcap_usd": all_data[gecko_id]["usd_market_cap"], + "price_usd": all_data[gecko_id]["usd"], + "coingecko_id": gecko_id, + } + else: + missing.append((symbol, gecko_id)) + + if missing: + print(f"\n Missing from CoinGecko: {missing}") + + # Fallback for tokens not on CoinGecko (very small tokens) + FALLBACK_MCAPS = { + "TREE": 1_000_000, # ~$1M estimate for small governance token + } + for symbol, mcap_est in FALLBACK_MCAPS.items(): + if symbol not in mcaps: + mcaps[symbol] = { + "mcap_usd": mcap_est, + "price_usd": 0.0, + "coingecko_id": "fallback", + } + print(f" Fallback: {symbol} -> ${mcap_est:,.0f}") + + # Add asset type classification + for symbol in mcaps: + if symbol in STABLECOINS: + mcaps[symbol]["asset_type"] = "stable" + elif symbol in NATIVE_LST: + mcaps[symbol]["asset_type"] = "native_lst" + else: + mcaps[symbol]["asset_type"] = "volatile" + + return mcaps + + +def main(): + mcaps = fetch_mcaps() + + os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True) + with open(OUTPUT_PATH, "w") as f: + json.dump(mcaps, f, indent=2) + + print(f"\nSaved {len(mcaps)} tokens to {OUTPUT_PATH}") + + # Summary + print("\nSample entries:") + for sym in ["WETH", "AAVE", "USDC", "QI", "TREE"]: + if sym in mcaps: + m = mcaps[sym] + print(f" {sym}: ${m['mcap_usd']:,.0f} ({m['asset_type']})") + + +if __name__ == "__main__": + main() diff --git a/scripts/plot_predicted_vs_real_volume.py b/scripts/plot_predicted_vs_real_volume.py new file mode 100644 index 0000000..83232ab --- /dev/null +++ b/scripts/plot_predicted_vs_real_volume.py @@ -0,0 +1,151 @@ +"""Plot predicted vs real daily volume for pool registry pools. + +Uses the fitted noise model (from calibrate_noise_unified.py) to compute +predicted daily log-volume for each pool in the registry, and overlays +the actual observed volume from the Balancer API panel data. +""" + +import json +import os +import sys + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "experiments")) +from pool_registry import POOL_REGISTRY, BALANCER_API_CHAIN + + +def main(): + fitted_path = "results/unified_full_90d.json" + panel_path = "local_data/noise_calibration/panel.parquet" + output_dir = "results/unified_full_90d" + os.makedirs(output_dir, exist_ok=True) + + with open(fitted_path) as f: + fitted = json.load(f) + + panel = pd.read_parquet(panel_path) + + # Deduplicate registry: multiple entries can share the same pool address + # (e.g. cbBTC_WETH and cbBTC_WETH_post_oct). Group by address. + unique_pools = {} + for label, pool in POOL_REGISTRY.items(): + addr = pool.pool_address.lower() + if addr not in unique_pools: + unique_pools[addr] = (label, pool) + + # Match to panel + matched = [] + for addr, (label, pool) in unique_pools.items(): + pid_matches = [ + pid for pid in fitted["pools"] + if addr in pid.lower() + ] + if pid_matches: + pid = pid_matches[0] + matched.append((label, pool, pid)) + else: + print(f" {label}: not in fitted model (skipping)") + + if not matched: + print("No registry pools found in the fitted model.") + return + + print(f"Plotting {len(matched)} pools: {[m[0] for m in matched]}") + + # Determine grid layout + n = len(matched) + ncols = min(n, 2) + nrows = (n + ncols - 1) // ncols + fig, axes = plt.subplots(nrows, ncols, figsize=(7 * ncols, 5 * nrows), + squeeze=False) + + for idx, (label, pool, pid) in enumerate(matched): + ax = axes[idx // ncols][idx % ncols] + pool_data = fitted["pools"][pid] + theta = np.array(pool_data["theta_median"]) + # theta = [intercept, b_tvl, b_sigma, b_weekend] + + # Get panel data for this pool + pool_panel = panel[panel["pool_id"] == pid].copy() + pool_panel = pool_panel.sort_values("date") + + if len(pool_panel) == 0: + ax.set_title(f"{label}: no panel data") + continue + + # Filter to last 90 days (matching training window) + max_date = panel["date"].max() + if hasattr(max_date, "date"): + max_date = max_date + from datetime import date, timedelta + if isinstance(max_date, date): + cutoff = max_date - timedelta(days=90) + else: + cutoff = pd.Timestamp(max_date) - pd.Timedelta(days=90) + pool_panel = pool_panel[ + pool_panel["date"].apply( + lambda d: d >= cutoff if isinstance(d, date) + else pd.Timestamp(d).date() >= cutoff + ) + ].copy() + + if len(pool_panel) < 5: + ax.set_title(f"{label}: <5 obs in 90d window") + continue + + # Build x_obs: [1, log_tvl_lag1, volatility, weekend] + x_obs = np.column_stack([ + np.ones(len(pool_panel)), + pool_panel["log_tvl_lag1"].values, + pool_panel["volatility"].values, + pool_panel["weekend"].values, + ]) + + predicted_log_vol = x_obs @ theta + actual_log_vol = pool_panel["log_volume"].values + + # Convert to USD volume for interpretability + predicted_vol = np.exp(predicted_log_vol) + actual_vol = np.exp(actual_log_vol) + + dates = pd.to_datetime(pool_panel["date"].values) + + # Plot + ax.plot(dates, actual_vol, "o-", color="steelblue", markersize=3, + linewidth=1, alpha=0.7, label="Actual") + ax.plot(dates, predicted_vol, "s--", color="orangered", markersize=3, + linewidth=1, alpha=0.7, label="Predicted") + ax.set_yscale("log") + ax.set_ylabel("Daily volume (USD)") + ax.set_title(f"{label} ({pool_data['chain']})\n" + f"b_c={theta[1]:.2f} b_σ={theta[2]:.2f} " + f"b_wknd={theta[3]:.2f}") + ax.legend(fontsize=8) + ax.tick_params(axis="x", rotation=30) + + # Annotate R² for this pool + ss_res = np.sum((actual_log_vol - predicted_log_vol) ** 2) + ss_tot = np.sum((actual_log_vol - actual_log_vol.mean()) ** 2) + r2 = 1 - ss_res / ss_tot if ss_tot > 0 else float("nan") + ax.text(0.02, 0.95, f"R²={r2:.3f}\nn={len(pool_panel)}", + transform=ax.transAxes, fontsize=8, va="top", + bbox=dict(boxstyle="round,pad=0.3", fc="white", alpha=0.8)) + + # Hide unused axes + for idx in range(n, nrows * ncols): + axes[idx // ncols][idx % ncols].set_visible(False) + + fig.suptitle("Noise model: predicted vs actual daily volume\n" + "(registry pools, 90-day training window)", fontsize=13) + fig.tight_layout() + out_path = os.path.join(output_dir, "registry_predicted_vs_real.png") + fig.savefig(out_path, dpi=150, bbox_inches="tight") + print(f"Saved: {out_path}") + plt.close(fig) + + +if __name__ == "__main__": + main() diff --git a/scripts/plot_reclamm_optuna_result.py b/scripts/plot_reclamm_optuna_result.py new file mode 100644 index 0000000..35242da --- /dev/null +++ b/scripts/plot_reclamm_optuna_result.py @@ -0,0 +1,451 @@ +#!/usr/bin/env python3 +"""Plot reClAMM pool performance from Optuna tuning results. + +Reads the SGD-compatible JSON output of tune_reclamm_params.py (or any Optuna +run), extracts the best trial's pool params, re-runs a forward pass over the +full train+test window, and produces a value-over-time plot with on-chain +baselines and cumulative fee revenue. + +Usage: + python scripts/plot_reclamm_optuna_result.py results/run_.json + python scripts/plot_reclamm_optuna_result.py results/run_.json --output my_plot.png + python scripts/plot_reclamm_optuna_result.py results/run_.json --top-k 3 +""" + +import argparse +import json +import sys + +import jax.numpy as jnp +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from datetime import datetime + +from quantammsim.runners.jax_runners import do_run_on_historic_data + +# ── On-chain baselines ──────────────────────────────────────────────────── +ONCHAIN_LAUNCH_PARAMS = { + "price_ratio": 1.5, "centeredness_margin": 0.5, "shift_exponent": 0.1, +} +ONCHAIN_CURRENT_PARAMS = { + "price_ratio": 4.0, "centeredness_margin": 0.1, "shift_exponent": 0.001, +} + +BG = "#162536" +TEXT_COLOR = "#E6CE97" +COLORS = [ + "#3498db", "#2ecc71", "#e74c3c", # top-k + "#f39c12", # on-chain launch + "#9b59b6", # on-chain current +] + + +def _plot_order(configs): + """Yield (name, meta, color_idx) with baselines first, optimized trials last.""" + optimized = [] + baselines = [] + for i, (name, meta) in enumerate(configs.items()): + if "On-Chain" in name: + baselines.append((name, meta, i)) + else: + optimized.append((name, meta, i)) + return baselines + optimized + + +def parse_args(): + p = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + p.add_argument("results_json", help="Path to run_.json from Optuna") + p.add_argument("--top-k", type=int, default=1, + help="Plot top K trials by objective (default 1)") + p.add_argument("--output", default=None, + help="Output PNG path (default: auto-generated)") + p.add_argument("--no-onchain", action="store_true", + help="Skip on-chain baseline runs") + p.add_argument("--end-test-date", default=None, + help="Override endTestDateString (e.g. '2026-02-15 00:00:00')") + p.add_argument("--noise-trader-ratio", type=float, default=None, + help="Override noise_trader_ratio from results config") + return p.parse_args() + + +def load_results(path): + """Load the double-encoded JSONL from Optuna results.""" + with open(path) as f: + raw = f.read() + data = json.loads(raw) + if isinstance(data, str): + data = json.loads(data) + if not isinstance(data, list) or len(data) < 2: + print(f"ERROR: Expected [config, trial1, trial2, ...], got {type(data)}") + sys.exit(1) + config = data[0] + trials = data[1:] + return config, trials + + +def extract_pool_params(trial, config): + """Extract reClAMM pool params from a trial entry.""" + param_keys = ["price_ratio", "centeredness_margin", "shift_exponent", + "arc_length_speed", "fees"] + params = {} + for k in param_keys: + if k in trial: + params[k] = trial[k] + return params + + +def run_full_period(params, config, fees_override=None): + """Run forward pass over the full train+test window.""" + fees = fees_override if fees_override is not None else config["fees"] + fp = { + "rule": "reclamm", + "tokens": config["tokens"], + "startDateString": config["startDateString"], + "endDateString": config["endTestDateString"], # full period + "initial_pool_value": config["initial_pool_value"], + "do_arb": config["do_arb"], + "fees": fees, + "gas_cost": config.get("gas_cost", 1.0), + "arb_fees": config.get("arb_fees", 0.0), + "protocol_fee_split": config.get("protocol_fee_split", 0.0), + "noise_trader_ratio": config.get("noise_trader_ratio", 0.0), + "reclamm_use_shift_exponent": config.get("reclamm_use_shift_exponent", True), + "reclamm_interpolation_method": config.get("reclamm_interpolation_method", "geometric"), + "reclamm_centeredness_scaling": config.get("reclamm_centeredness_scaling", False), + "reclamm_learn_arc_length_speed": config.get("reclamm_learn_arc_length_speed", False), + } + jax_params = {k: jnp.array(v) for k, v in params.items()} + return do_run_on_historic_data(run_fingerprint=fp, params=jax_params) + + +def plot_results(configs, time_series, hodl_values, config, args): + """Two-panel plot: value-over-time + cumulative fee revenue.""" + train_end_str = config["endDateString"] + train_end_dt = datetime.strptime(train_end_str, "%Y-%m-%d %H:%M:%S") + + first_out = next(iter(time_series.values())) + n_minutes = len(first_out["value"]) + dates = pd.date_range( + start=datetime.strptime(config["startDateString"], "%Y-%m-%d %H:%M:%S"), + periods=n_minutes, freq="1min", + ) + step = 1440 + dates_daily = dates[::step] + + has_fee_revenue = any( + "fee_revenue" in time_series[n] and time_series[n]["fee_revenue"] is not None + for n in time_series + ) + n_panels = 2 if has_fee_revenue else 1 + fig, axes = plt.subplots( + n_panels, 1, figsize=(14, 5 * n_panels), + sharex=True, gridspec_kw={"height_ratios": [3, 1] if n_panels == 2 else [1]}, + ) + if n_panels == 1: + axes = [axes] + ax_val = axes[0] + + # ── Panel 1: Value over time ────────────────────────────────────── + for name, meta, ci in _plot_order(configs): + out = time_series[name] + vals = np.array(out["value"][::step]) / 1e6 + label = f"{name}" + if "test_objective" in meta: + obj_name = config.get("return_val", "objective") + label += f" (OOS {obj_name}={meta['test_objective']:.4f})" + is_optimized = "On-Chain" not in name + ax_val.plot(dates_daily[:len(vals)], vals, + linewidth=2.5 if is_optimized else 1.8, + color=COLORS[ci % len(COLORS)], label=label, + zorder=3 if is_optimized else 2) + + hodl_daily = hodl_values[::step] / 1e6 + ax_val.plot(dates_daily[:len(hodl_daily)], hodl_daily, linewidth=2, + color="white", alpha=0.7, linestyle="--", label="HODL") + + ax_val.axvline(x=train_end_dt, color="white", linestyle=":", alpha=0.5, linewidth=1.5) + ylims = ax_val.get_ylim() + ax_val.text(train_end_dt - pd.Timedelta(days=10), ylims[1] * 0.97, "Train", + color="white", alpha=0.6, fontsize=11, ha="right", va="top") + ax_val.text(train_end_dt + pd.Timedelta(days=10), ylims[1] * 0.97, "Test", + color="white", alpha=0.6, fontsize=11, ha="left", va="top") + + _style_axis(ax_val) + ax_val.set_ylabel("Pool Value ($M USD)", color=TEXT_COLOR, fontsize=12) + tokens_str = "/".join(config["tokens"]) + obj_name = config.get("return_val", "objective") + ntr = config.get("noise_trader_ratio", 0.0) + ax_val.set_title( + f"reClAMM Optuna-Optimized ({obj_name}, noise={ntr}) — {tokens_str}", + color=TEXT_COLOR, fontsize=13, pad=15, + ) + ax_val.legend(loc="upper left", fontsize=9, facecolor=BG, + edgecolor=TEXT_COLOR, labelcolor=TEXT_COLOR) + + # ── Panel 2: Cumulative fee revenue ─────────────────────────────── + if has_fee_revenue: + ax_fee = axes[1] + for name, _meta, ci in _plot_order(configs): + out = time_series[name] + fr = out.get("fee_revenue") + if fr is None: + continue + fr = np.array(fr) + cumfee = np.cumsum(fr)[::step] / 1e3 + is_optimized = "On-Chain" not in name + ax_fee.plot(dates_daily[:len(cumfee)], cumfee, + linewidth=2.5 if is_optimized else 1.8, + color=COLORS[ci % len(COLORS)], label=name, + zorder=3 if is_optimized else 2) + + ax_fee.axvline(x=train_end_dt, color="white", linestyle=":", alpha=0.5, linewidth=1.5) + _style_axis(ax_fee) + ax_fee.set_ylabel("Cumulative Fee Revenue ($K)", color=TEXT_COLOR, fontsize=12) + ax_fee.set_xlabel("Date", color=TEXT_COLOR, fontsize=12) + ax_fee.legend(loc="upper left", fontsize=9, facecolor=BG, + edgecolor=TEXT_COLOR, labelcolor=TEXT_COLOR) + else: + ax_val.set_xlabel("Date", color=TEXT_COLOR, fontsize=12) + + fig.patch.set_facecolor(BG) + plt.tight_layout() + + output = args.output or f"reclamm_optuna_{tokens_str.replace('/', '_')}.png" + plt.savefig(output, dpi=200, bbox_inches="tight", facecolor=BG) + print(f"\nSaved plot to {output}") + plt.close() + + +def plot_test_only(configs, time_series, hodl_values, config, args): + """Test-period plot with all curves normalised to start at 1.0.""" + train_end_str = config["endDateString"] + train_end_dt = datetime.strptime(train_end_str, "%Y-%m-%d %H:%M:%S") + start_dt = datetime.strptime(config["startDateString"], "%Y-%m-%d %H:%M:%S") + + first_out = next(iter(time_series.values())) + n_minutes = len(first_out["value"]) + dates = pd.date_range(start=start_dt, periods=n_minutes, freq="1min") + + # Find the index of the train/test boundary + train_minutes = int((train_end_dt - start_dt).total_seconds() / 60) + test_start_idx = min(train_minutes, n_minutes - 1) + + step = 1440 + test_dates = dates[test_start_idx::step] + + fig, ax = plt.subplots(1, 1, figsize=(14, 6)) + + for name, _meta, ci in _plot_order(configs): + out = time_series[name] + vals = np.array(out["value"]) + test_vals = vals[test_start_idx::step] + if len(test_vals) == 0: + continue + normalised = test_vals / test_vals[0] + is_optimized = "On-Chain" not in name + ax.plot(test_dates[:len(normalised)], normalised, + linewidth=2.5 if is_optimized else 1.8, + color=COLORS[ci % len(COLORS)], label=name, + zorder=3 if is_optimized else 2) + + hodl_test = hodl_values[test_start_idx::step] + if len(hodl_test) > 0: + hodl_norm = hodl_test / hodl_test[0] + ax.plot(test_dates[:len(hodl_norm)], hodl_norm, linewidth=2, + color="white", alpha=0.7, linestyle="--", label="HODL") + + ax.axhline(1.0, color="white", linestyle=":", alpha=0.3, linewidth=1) + _style_axis(ax) + tokens_str = "/".join(config["tokens"]) + obj_name = config.get("return_val", "objective") + ntr = config.get("noise_trader_ratio", 0.0) + ax.set_title(f"Test Period Only (normalised) — {obj_name}, noise={ntr} — {tokens_str}", + color=TEXT_COLOR, fontsize=13, pad=15) + ax.set_ylabel("Normalised Value", color=TEXT_COLOR, fontsize=12) + ax.set_xlabel("Date", color=TEXT_COLOR, fontsize=12) + ax.legend(loc="best", fontsize=9, facecolor=BG, + edgecolor=TEXT_COLOR, labelcolor=TEXT_COLOR) + + fig.patch.set_facecolor(BG) + plt.tight_layout() + base = (args.output or f"reclamm_optuna_{tokens_str.replace('/', '_')}.png") + output = base.replace(".png", "_test_only.png") + plt.savefig(output, dpi=200, bbox_inches="tight", facecolor=BG) + print(f"Saved plot to {output}") + plt.close() + + +def plot_weights(configs, time_series, config, args): + """Effective weight (value fraction) of token 0 over time.""" + start_dt = datetime.strptime(config["startDateString"], "%Y-%m-%d %H:%M:%S") + train_end_dt = datetime.strptime(config["endDateString"], "%Y-%m-%d %H:%M:%S") + + first_out = next(iter(time_series.values())) + n_minutes = len(first_out["value"]) + dates = pd.date_range(start=start_dt, periods=n_minutes, freq="1min") + step = 1440 + dates_daily = dates[::step] + + token_name = config["tokens"][0] + + fig, ax = plt.subplots(1, 1, figsize=(14, 5)) + + for name, _meta, ci in _plot_order(configs): + out = time_series[name] + weights = np.array(out["weights"]) # (T, 2) + w0 = weights[::step, 0] + is_optimized = "On-Chain" not in name + ax.plot(dates_daily[:len(w0)], w0, + linewidth=2.0 if is_optimized else 1.5, + color=COLORS[ci % len(COLORS)], label=name, + alpha=0.9 if is_optimized else 0.7, + zorder=3 if is_optimized else 2) + + ax.axhline(0.5, color="white", linestyle="--", alpha=0.3, linewidth=1) + ax.axvline(x=train_end_dt, color="white", linestyle=":", alpha=0.5, linewidth=1.5) + ylims = ax.get_ylim() + ax.text(train_end_dt - pd.Timedelta(days=10), ylims[1] * 0.97, "Train", + color="white", alpha=0.6, fontsize=11, ha="right", va="top") + ax.text(train_end_dt + pd.Timedelta(days=10), ylims[1] * 0.97, "Test", + color="white", alpha=0.6, fontsize=11, ha="left", va="top") + + _style_axis(ax) + tokens_str = "/".join(config["tokens"]) + ax.set_title(f"Effective {token_name} Weight — {tokens_str}", + color=TEXT_COLOR, fontsize=13, pad=15) + ax.set_ylabel(f"{token_name} weight (value fraction)", color=TEXT_COLOR, fontsize=12) + ax.set_xlabel("Date", color=TEXT_COLOR, fontsize=12) + ax.legend(loc="best", fontsize=9, facecolor=BG, + edgecolor=TEXT_COLOR, labelcolor=TEXT_COLOR) + + fig.patch.set_facecolor(BG) + plt.tight_layout() + base = (args.output or f"reclamm_optuna_{tokens_str.replace('/', '_')}.png") + output = base.replace(".png", "_weights.png") + plt.savefig(output, dpi=200, bbox_inches="tight", facecolor=BG) + print(f"Saved plot to {output}") + plt.close() + + +def _style_axis(ax): + ax.set_facecolor(BG) + ax.tick_params(colors=TEXT_COLOR) + for spine in ax.spines.values(): + spine.set_color(TEXT_COLOR) + spine.set_alpha(0.3) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.grid(True, alpha=0.15, color=TEXT_COLOR) + + +def main(): + args = parse_args() + config, trials = load_results(args.results_json) + if args.end_test_date: + config["endTestDateString"] = args.end_test_date + if args.noise_trader_ratio is not None: + config["noise_trader_ratio"] = args.noise_trader_ratio + tokens = config["tokens"] + obj_name = config.get("return_val", "objective") + + # Sort trials by penalised objective + trials_sorted = sorted(trials, key=lambda t: t.get("objective", 0), reverse=True) + top_trials = trials_sorted[:args.top_k] + + print("=" * 80) + print(f"reClAMM Optuna Result Plotter — objective: {obj_name}") + print("=" * 80) + print(f" Results: {args.results_json}") + print(f" Tokens: {'/'.join(tokens)}") + print(f" Train: {config['startDateString']} → {config['endDateString']}") + print(f" Test: {config['endDateString']} → {config['endTestDateString']}") + print(f" Fees: {config['fees']}, Gas: {config.get('gas_cost', 1.0)}") + print(f" Trials: {len(trials)} total, plotting top {len(top_trials)}") + + configs = {} + for i, trial in enumerate(top_trials): + params = extract_pool_params(trial, config) + name = f"#{trial.get('optuna_trial_number', i)} (rank {i+1})" + configs[name] = { + "params": params, + "objective": trial.get("objective", 0), + "train_objective": trial.get("train_objective", 0), + "test_objective": trial.get("test_objective", 0), + "train_sharpe": trial.get("train_sharpe", 0), + "validation_sharpe": trial.get("validation_sharpe", 0), + } + print(f"\n {name}:") + print(f" {obj_name}: train={trial.get('train_objective', 0):.4f} " + f"test={trial.get('test_objective', 0):.4f} " + f"penalised={trial.get('objective', 0):.4f}") + print(f" sharpe: train={trial.get('train_sharpe', 0):+.4f} " + f"val={trial.get('validation_sharpe', 0):+.4f}") + for k, v in params.items(): + print(f" {k}: {v:.6g}") + + if not args.no_onchain: + configs["On-Chain (launch)"] = {"params": dict(ONCHAIN_LAUNCH_PARAMS)} + configs["On-Chain (current)"] = {"params": dict(ONCHAIN_CURRENT_PARAMS)} + + # ── Full-period runs ────────────────────────────────────────────── + print(f"\n--- Running full-period simulations ({config['startDateString']} → " + f"{config['endTestDateString']}) ---") + time_series = {} + for name, cfg in configs.items(): + print(f" {name}...", end=" ", flush=True) + out = run_full_period(cfg["params"], config) + time_series[name] = out + fv = float(out["final_value"]) + fr = out.get("fee_revenue") + fr_total = float(np.array(fr).sum()) if fr is not None else 0 + hodl = float((out["reserves"][0] * out["prices"][-1]).sum()) + print(f"final=${fv:,.0f} hodl=${hodl:,.0f} RoH={fv/hodl - 1:+.2%} " + f"fee_rev=${fr_total:,.0f}") + + first_out = next(iter(time_series.values())) + hodl_reserves = first_out["reserves"][0] + hodl_values = np.sum( + np.array(hodl_reserves) * np.array(first_out["prices"]), axis=1, + ) + + # ── Plots ───────────────────────────────────────────────────────── + plot_results(configs, time_series, hodl_values, config, args) + plot_test_only(configs, time_series, hodl_values, config, args) + plot_weights(configs, time_series, config, args) + + # ── Summary table ───────────────────────────────────────────────── + print(f"\n{'=' * 120}") + print(f"SUMMARY — {'/'.join(tokens)} — {obj_name}") + print(f"{'=' * 120}") + hdr = (f"{'Config':<28s} {'Train '+obj_name:>20s} {'Test '+obj_name:>20s} " + f"{'Train SR':>10s} {'Val SR':>10s} " + f"{'PR':>7s} {'Margin':>7s} {'ShiftExp':>10s} {'Full RoH':>10s}") + print(hdr) + print("-" * 120) + + for name, cfg in configs.items(): + cp = cfg["params"] + fv = float(time_series[name]["final_value"]) + full_roh = fv / float(hodl_values[-1]) - 1 + print( + f"{name:<28s} " + f"{cfg.get('train_objective', float('nan')):>20.4f} " + f"{cfg.get('test_objective', float('nan')):>20.4f} " + f"{cfg.get('train_sharpe', float('nan')):>+10.4f} " + f"{cfg.get('validation_sharpe', float('nan')):>+10.4f} " + f"{cp.get('price_ratio', float('nan')):>7.3f} " + f"{cp.get('centeredness_margin', float('nan')):>7.4f} " + f"{cp.get('shift_exponent', float('nan')):>10.4g} " + f"{full_roh * 100:>+9.2f}%" + ) + print("=" * 120) + + +if __name__ == "__main__": + main() diff --git a/scripts/plot_top50_predicted_vs_real.py b/scripts/plot_top50_predicted_vs_real.py new file mode 100644 index 0000000..0951a20 --- /dev/null +++ b/scripts/plot_top50_predicted_vs_real.py @@ -0,0 +1,521 @@ +"""Plot predicted vs real volume for top 50 pools by TVL on Feb 1st 2026. + +Enumerates WEIGHTED (min_tvl=1000) and RECLAMM (min_tvl=0) pools, +fetches their snapshots, filters to those with TVL >= $10k on Feb 1st 2026, +takes the top 50 by TVL, and plots predicted vs actual daily volume using +the inference artifact from calibrate_noise_unified.py. + +For pools that were in the model's training set, uses their per-pool theta. +For pools not in the training set, uses population-level prediction from B. +""" + +import ast +import json +import os +import sys +import time +from datetime import date, timedelta + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +# Reuse functions from the noise calibration package +from quantammsim.noise_calibration import ( + BALANCER_API_CHAINS, + _graphql_request, + assemble_panel, + classify_token_tier, + encode_covariates, + fetch_pool_snapshots, + fetch_token_prices, +) + + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "local_data", "noise_top50" +) +TVL_DATE = date(2026, 2, 1) +OUTPUT_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "top50_feb1" +) +# Inference artifact from the main unified model run +FITTED_JSON = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "unified_full_90d.json" +) + + +def enumerate_all_pools(): + """Enumerate WEIGHTED (min_tvl=1000) and RECLAMM (min_tvl=0) pools.""" + all_pools = [] + + for chain in BALANCER_API_CHAINS: + for pool_type, min_tvl in [("WEIGHTED", 1000), ("RECLAMM", 0)]: + query = { + "query": """ + query GetPools($chain: GqlChain!, $types: [GqlPoolType!], + $minTvl: Float) { + poolGetPools( + where: { chainIn: [$chain], poolTypeIn: $types, minTvl: $minTvl } + ) { + id chain type protocolVersion + poolTokens { symbol weight address } + dynamicData { totalLiquidity swapFee } + } + } + """, + "variables": { + "chain": chain, + "types": [pool_type], + "minTvl": min_tvl, + }, + } + + try: + body = _graphql_request(query) + pools = body.get("data", {}).get("poolGetPools", []) + except Exception as e: + print(f" FAILED {chain} {pool_type}: {e}") + continue + + for p in pools: + tokens = [t["symbol"] for t in p.get("poolTokens", [])] + addresses = [t.get("address", "") for t in p.get("poolTokens", [])] + tvl = float(p.get("dynamicData", {}).get("totalLiquidity", 0)) + fee = float(p.get("dynamicData", {}).get("swapFee", 0)) + all_pools.append({ + "pool_id": p["id"], + "chain": p["chain"], + "pool_type": p["type"], + "tokens": tokens, + "token_addresses": addresses, + "swap_fee": fee, + "current_tvl": tvl, + }) + + if pools: + print(f" {chain:>10} {pool_type:>10}: {len(pools)}") + time.sleep(0.3) + + df = pd.DataFrame(all_pools) + print(f"\n Total: {len(df)} pools") + return df + + +def fetch_all_snapshots_cached(pools_df, cache_dir): + """Fetch snapshots for all pools, caching per-pool.""" + snap_dir = os.path.join(cache_dir, "snapshots") + os.makedirs(snap_dir, exist_ok=True) + + all_snaps = [] + n = len(pools_df) + + for i, (_, pool) in enumerate(pools_df.iterrows()): + pid = pool["pool_id"] + chain = pool["chain"] + cache_file = os.path.join(snap_dir, f"{pid}.parquet") + + if os.path.exists(cache_file): + df = pd.read_parquet(cache_file) + else: + if (i + 1) % 20 == 0 or i == 0: + print(f" Fetching snapshots {i+1}/{n}...", flush=True) + try: + df = fetch_pool_snapshots(pid, chain) + if len(df) > 0: + df.to_parquet(cache_file, index=False) + time.sleep(0.3) + except Exception as e: + print(f" FAILED {pid[:20]}: {e}") + continue + + if len(df) > 0: + df["pool_id"] = pid + df["chain"] = chain + all_snaps.append(df) + + if all_snaps: + return pd.concat(all_snaps, ignore_index=True) + return pd.DataFrame() + + +def get_tvl_on_date(snapshots_df, target_date, window_days=3): + """Get TVL for each pool on/near target_date.""" + results = [] + for pid in snapshots_df["pool_id"].unique(): + pool_snaps = snapshots_df[snapshots_df["pool_id"] == pid] + + best_row = None + best_dist = float("inf") + for _, row in pool_snaps.iterrows(): + d = row["date"] + if isinstance(d, date): + dist = abs((d - target_date).days) + else: + dist = abs((pd.Timestamp(d).date() - target_date).days) + if dist < best_dist: + best_dist = dist + best_row = row + + if best_row is not None and best_dist <= window_days: + results.append({ + "pool_id": pid, + "tvl_feb1": float(best_row["total_liquidity_usd"]), + "date_used": best_row["date"], + }) + + return pd.DataFrame(results) + + +def _get_theta_for_pool(pid, fitted, panel_90d, pop_B, pop_cov_names): + """Get theta for a pool: from fitted artifact if available, else population. + + Returns (theta, source) where source is 'fitted' or 'population'. + For IBP models, population fallback adds marginal feature effect (pi @ W). + """ + if pid in fitted["pools"]: + return np.array(fitted["pools"][pid]["theta_median"]), "fitted" + + # Population-level prediction: theta = B @ z_pool + # Build z_pool from the pool's covariates + pp = panel_90d[panel_90d["pool_id"] == pid] + if len(pp) == 0: + return None, "no_data" + + chain = pp["chain"].iloc[0] + tokens = pp["tokens"].iloc[0] + if isinstance(tokens, str): + tokens = tokens.split(",") + fee = pp["swap_fee"].iloc[0] if "swap_fee" in pp.columns else 0.003 + tiers = sorted([classify_token_tier(t) for t in tokens]) + tier_a = tiers[0] + + # Build covariate vector matching the model's encoding + z = np.zeros(len(pop_cov_names)) + for i, name in enumerate(pop_cov_names): + if name == "intercept": + z[i] = 1.0 + elif name == f"chain_{chain}": + z[i] = 1.0 + elif name == f"tier_A_{tier_a}": + z[i] = 1.0 + elif name == "log_fee": + z[i] = np.log(max(fee, 1e-6)) + + # B is (K_coeff, K_cov), theta = B @ z + B = np.array(pop_B) # (K_coeff, K_cov) + theta = B @ z + + # IBP: add marginal feature effect (pi @ W) + pop = fitted["population_effects"] + if "W" in pop and "feature_prevalences" in pop: + W = np.array(pop["W"]) # (K_features, K_coeff) + pi = np.array(pop["feature_prevalences"]) # (K_features,) + theta = theta + pi @ W + + return theta, "population" + + +def plot_pages(plot_pools, pool_idx_map_fitted, fitted, panel_90d, + pools_df, tvl_lookup, pop_B, pop_cov_names, + output_dir=OUTPUT_DIR): + """Generate paginated plots, 10 pools per page.""" + n_pools = len(plot_pools) + per_page = 10 + n_pages = (n_pools + per_page - 1) // per_page + + for page in range(n_pages): + start = page * per_page + end = min(start + per_page, n_pools) + page_pools = plot_pools[start:end] + n_this = len(page_pools) + + ncols = 2 + nrows = (n_this + ncols - 1) // ncols + fig, axes = plt.subplots(nrows, ncols, figsize=(14, 4 * nrows)) + if nrows == 1 and ncols == 1: + axes = np.array([[axes]]) + elif nrows == 1: + axes = axes.reshape(1, -1) + + for idx, (pid, feb_tvl) in enumerate(page_pools): + ax = axes[idx // ncols][idx % ncols] + + theta, source = _get_theta_for_pool( + pid, fitted, panel_90d, pop_B, pop_cov_names + ) + if theta is None: + ax.set_visible(False) + continue + + pp = panel_90d[panel_90d["pool_id"] == pid].sort_values("date") + if len(pp) < 5: + ax.set_visible(False) + continue + + x_obs = np.column_stack([ + np.ones(len(pp)), + pp["log_tvl_lag1"].values, + pp["volatility"].values, + pp["weekend"].values, + ]) + + pred_log = x_obs @ theta + actual_log = pp["log_volume"].values + pred_vol = np.exp(pred_log) + actual_vol = np.exp(actual_log) + dates = pd.to_datetime(pp["date"].values) + + ax.plot(dates, actual_vol, "o-", color="steelblue", markersize=2.5, + linewidth=0.9, alpha=0.7, label="Actual") + ax.plot(dates, pred_vol, "s--", color="orangered", markersize=2.5, + linewidth=0.9, alpha=0.7, label="Predicted") + ax.set_yscale("log") + ax.set_ylabel("Daily volume (USD)", fontsize=8) + + ss_res = np.sum((actual_log - pred_log) ** 2) + ss_tot = np.sum((actual_log - actual_log.mean()) ** 2) + r2 = 1 - ss_res / ss_tot if ss_tot > 0 else float("nan") + + meta = pools_df[pools_df["pool_id"] == pid] + if len(meta) > 0: + m = meta.iloc[0] + tokens = m["tokens"] + tok_str = "/".join(str(t)[:8] for t in tokens[:2]) + chain = str(m["chain"]) + ptype = str(m["pool_type"]) + else: + tok_str = pid[:16] + chain = "?" + ptype = "?" + + type_tag = "R" if ptype == "RECLAMM" else "W" + src_tag = "*" if source == "population" else "" + ax.set_title( + "{} ({}, {}){}\n" + "TVL ${:,.0f} on Feb 1 | " + "R\u00b2={:.3f} b_c={:.2f} b_\u03c3={:.2f} " + "b_wknd={:.2f} n={}".format( + tok_str, chain, type_tag, src_tag, feb_tvl, + r2, theta[1], theta[2], theta[3], len(pp)), + fontsize=8) + ax.legend(fontsize=7) + ax.tick_params(labelsize=7) + ax.tick_params(axis="x", rotation=30) + + for idx in range(n_this, nrows * ncols): + axes[idx // ncols][idx % ncols].set_visible(False) + + fig.suptitle( + "Predicted vs actual daily volume \u2014 page {}/{} " + "(sorted by TVL on {}) [* = population prediction]".format( + page + 1, n_pages, TVL_DATE), + fontsize=11) + fig.tight_layout() + out = os.path.join(output_dir, "pred_vs_real_page{}.png".format(page + 1)) + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + print(" Saved: {}".format(out)) + + +def main(): + import argparse + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--artifact", default=FITTED_JSON, + help="Path to inference artifact JSON") + parser.add_argument("--output-dir", default=None, + help="Output directory (default: auto from model name)") + args = parser.parse_args() + + artifact_path = args.artifact + os.makedirs(CACHE_DIR, exist_ok=True) + + # ---- Load inference artifact ---- + print(f"Loading inference artifact: {artifact_path}") + with open(artifact_path) as f: + fitted = json.load(f) + n_fitted = len(fitted["pools"]) + model_name = fitted.get("model", "unknown") + print(f" Model: {model_name}") + print(f" {n_fitted} pools with fitted theta") + + # Output dir: use CLI override or auto from model name + output_dir = args.output_dir or os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", f"top50_feb1_{model_name}", + ) + os.makedirs(output_dir, exist_ok=True) + + # Extract population-level B matrix for pools not in the model + pop_cov_names = fitted["model_spec"]["covariate_names"] + pop_B = np.array(fitted["population_effects"]["B"]) # (K_coeff, K_cov) + print(f" Population B: {pop_B.shape}, covariates: {pop_cov_names}") + + # ---- Step 1: Enumerate pools ---- + pools_cache = os.path.join(CACHE_DIR, "pools.parquet") + if os.path.exists(pools_cache): + pools_df = pd.read_parquet(pools_cache) + if isinstance(pools_df["tokens"].iloc[0], str): + pools_df["tokens"] = pools_df["tokens"].apply(ast.literal_eval) + pools_df["token_addresses"] = pools_df["token_addresses"].apply( + ast.literal_eval + ) + print(f"\nLoaded {len(pools_df)} pools from cache") + else: + print("\n1. Enumerating pools...") + pools_df = enumerate_all_pools() + pools_df.to_parquet(pools_cache, index=False) + + # ---- Step 2: Fetch snapshots ---- + print("\n2. Fetching snapshots...") + snapshots_df = fetch_all_snapshots_cached(pools_df, CACHE_DIR) + print(f" {len(snapshots_df)} pool-days") + + # ---- Step 3: TVL on Feb 1st ---- + print(f"\n3. Finding TVL on {TVL_DATE}...") + tvl_df = get_tvl_on_date(snapshots_df, TVL_DATE) + tvl_df = tvl_df[tvl_df["tvl_feb1"] >= 10_000].copy() + tvl_df = tvl_df.sort_values("tvl_feb1", ascending=False).head(50) + top50_ids = set(tvl_df["pool_id"]) + tvl_lookup = dict(zip(tvl_df["pool_id"], tvl_df["tvl_feb1"])) + print(f" {len(tvl_df)} pools with TVL >= $10k") + + # How many are in the fitted model? + n_in_model = sum(1 for pid in top50_ids if pid in fitted["pools"]) + print(f" {n_in_model} in fitted model, " + f"{len(top50_ids) - n_in_model} will use population prediction") + + # ---- Step 4: Fetch token prices & assemble panel ---- + panel_cache = os.path.join(CACHE_DIR, "panel.parquet") + if os.path.exists(panel_cache): + panel = pd.read_parquet(panel_cache) + print(f"\n4. Loaded panel from cache: {len(panel)} obs") + else: + top50_pools = pools_df[pools_df["pool_id"].isin(top50_ids)].copy() + top50_snaps = snapshots_df[snapshots_df["pool_id"].isin(top50_ids)].copy() + + print("\n4. Fetching token prices...") + prices_cache = os.path.join(CACHE_DIR, "token_prices") + token_addr_by_chain = {} + for _, pool in top50_pools.iterrows(): + chain = pool["chain"] + tokens = pool["tokens"] + addresses = pool["token_addresses"] + if chain not in token_addr_by_chain: + token_addr_by_chain[chain] = {} + for sym, addr in zip(tokens, addresses): + if sym and addr: + token_addr_by_chain[chain][sym] = addr + + token_prices = fetch_token_prices( + token_addr_by_chain, cache_dir=prices_cache + ) + + print("\n Assembling panel...") + panel = assemble_panel(top50_pools, top50_snaps, token_prices) + panel.to_parquet(panel_cache, index=False) + + # ---- Step 5: Filter to 90 days ---- + max_date = panel["date"].max() + if not isinstance(max_date, date): + max_date = pd.Timestamp(max_date).date() + cutoff = max_date - timedelta(days=90) + panel_90d = panel[ + panel["date"].apply( + lambda d: d >= cutoff if isinstance(d, date) + else pd.Timestamp(d).date() >= cutoff + ) + ].copy() + + if "log_tvl_lag1" not in panel_90d.columns: + panel_90d = panel_90d.sort_values(["pool_id", "date"]).reset_index(drop=True) + panel_90d["log_tvl_lag1"] = panel_90d.groupby("pool_id")["log_tvl"].shift(1) + panel_90d = panel_90d.dropna(subset=["log_tvl_lag1"]).reset_index(drop=True) + + pool_counts = panel_90d.groupby("pool_id").size() + valid_pools = pool_counts[pool_counts >= 10].index + panel_90d = panel_90d[panel_90d["pool_id"].isin(valid_pools)].copy() + print(f"\n5. 90-day panel: {len(panel_90d)} obs, " + f"{panel_90d['pool_id'].nunique()} pools") + + # ---- Step 6: Plot ---- + # Sort by Feb 1 TVL, only include pools with panel data + plot_pools = [] + for pid in panel_90d["pool_id"].unique(): + if pid in tvl_lookup: + plot_pools.append((pid, tvl_lookup[pid])) + plot_pools.sort(key=lambda x: -x[1]) + print(f"\n6. Plotting {len(plot_pools)} pools...") + + plot_pages(plot_pools, fitted, fitted, panel_90d, pools_df, + tvl_lookup, pop_B, pop_cov_names, output_dir=output_dir) + + # ---- Summary table ---- + summary = [] + for pid, feb_tvl in plot_pools: + theta, source = _get_theta_for_pool( + pid, fitted, panel_90d, pop_B, pop_cov_names + ) + if theta is None: + continue + pp = panel_90d[panel_90d["pool_id"] == pid] + x = np.column_stack([ + np.ones(len(pp)), + pp["log_tvl_lag1"].values, + pp["volatility"].values, + pp["weekend"].values, + ]) + pred = x @ theta + actual = pp["log_volume"].values + ss_res = np.sum((actual - pred) ** 2) + ss_tot = np.sum((actual - actual.mean()) ** 2) + r2 = 1 - ss_res / ss_tot if ss_tot > 0 else float("nan") + + meta = pools_df[pools_df["pool_id"] == pid] + if len(meta) > 0: + m = meta.iloc[0] + tok_str = "/".join(str(t) for t in m["tokens"][:2]) + chain = str(m["chain"]) + ptype = str(m["pool_type"]) + else: + tok_str = pid[:16] + chain = "?" + ptype = "?" + + summary.append({ + "pool_id": pid[:20], + "tokens": tok_str, + "chain": chain, + "type": ptype, + "tvl_feb1": feb_tvl, + "n_obs": len(pp), + "R2": r2, + "b_c": theta[1], + "b_sigma": theta[2], + "b_weekend": theta[3], + "source": source, + }) + + summary_df = pd.DataFrame(summary) + summary_path = os.path.join(output_dir, "top50_summary.csv") + summary_df.to_csv(summary_path, index=False) + print(f"\n Saved: {summary_path}") + + n_pools = len(summary_df) + n_fitted_used = (summary_df["source"] == "fitted").sum() + n_pop = (summary_df["source"] == "population").sum() + n_reclamm = (summary_df["type"] == "RECLAMM").sum() + print(f"\n{'='*70}") + print(f"Summary: {n_pools} pools ({n_fitted_used} fitted, {n_pop} population)") + print(f" RECLAMM: {n_reclamm} WEIGHTED: {n_pools - n_reclamm}") + print(f" Median R\u00b2: {summary_df['R2'].median():.3f}") + print(f" Mean b_c: {summary_df['b_c'].mean():.3f}") + + +if __name__ == "__main__": + main() diff --git a/scripts/run_direct_calibration_top50.py b/scripts/run_direct_calibration_top50.py new file mode 100644 index 0000000..1c1770b --- /dev/null +++ b/scripts/run_direct_calibration_top50.py @@ -0,0 +1,852 @@ +"""Run direct calibration pipeline and plot top-50 style decomposition. + +Steps: + 1. Load panel, match to per-day grids in results/pool_grids_v2/ + 2. Option C: per-pool L-BFGS-B fits + 3. Option A: joint end-to-end optimization (warm-started from C) + 4. Paginated plots: V_arb + V_noise decomposition per pool + 5. Summary plots: cadence, gas, R², arb fraction distributions +""" + +import json +import os +import sys +from datetime import date, timedelta + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +# ---- Config ---- +PANEL_CACHE = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "local_data", "noise_calibration", "panel.parquet", +) +GRID_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "pool_grids_v2", +) +OUTPUT_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "direct_calibration_top50", +) +TRAIN_DAYS = 90 +TOP_N = 50 +OPTION_C_MAXITER = 500 +JOINT_MAXITER = 500 +OPTION_C_LOSS_CUTOFF = 5.0 # Drop pools with Option C loss above this from joint fit + + +def load_and_match(): + """Load panel, filter to 90 days, match to grids.""" + panel = pd.read_parquet(PANEL_CACHE) + max_date = panel["date"].max() + if not isinstance(max_date, date): + max_date = pd.Timestamp(max_date).date() + cutoff = max_date - timedelta(days=TRAIN_DAYS) + panel = panel[ + panel["date"].apply( + lambda d: d >= cutoff if isinstance(d, date) + else pd.Timestamp(d).date() >= cutoff + ) + ].copy() + + if "log_tvl_lag1" not in panel.columns: + panel = panel.sort_values(["pool_id", "date"]).reset_index(drop=True) + panel["log_tvl_lag1"] = panel.groupby("pool_id")["log_tvl"].shift(1) + panel = panel.dropna(subset=["log_tvl_lag1"]).reset_index(drop=True) + + pool_counts = panel.groupby("pool_id").size() + valid = pool_counts[pool_counts >= 10].index + panel = panel[panel["pool_id"].isin(valid)].copy() + + print(f"Panel: {len(panel)} obs, {panel['pool_id'].nunique()} pools, " + f"{cutoff} to {max_date}") + + from quantammsim.calibration.pool_data import match_grids_to_panel + matched = match_grids_to_panel(GRID_DIR, panel) + print(f"Matched: {len(matched)} pools with grids") + + return panel, matched + + +def run_option_c(matched): + """Per-pool L-BFGS-B fits.""" + from quantammsim.calibration.per_pool_fit import fit_all_pools + print(f"\n--- Option C: per-pool fits ({len(matched)} pools) ---") + results = fit_all_pools(matched) + n_converged = sum(1 for r in results.values() if r["converged"]) + losses = [r["loss"] for r in results.values()] + print(f" Converged: {n_converged}/{len(results)}") + print(f" Loss: median={np.median(losses):.4f}, " + f"mean={np.mean(losses):.4f}, " + f"range=[{np.min(losses):.4f}, {np.max(losses):.4f}]") + return results + + +def run_option_a(matched, option_c_results): + """Joint end-to-end optimization, warm-started from Option C. + + Drops pathological pools (Option C loss > OPTION_C_LOSS_CUTOFF) from the + joint fit to prevent them from dominating the shared mapping. + """ + from quantammsim.calibration.joint_fit import fit_joint + + # Filter out pathological pools + good_pools = {p: r for p, r in option_c_results.items() + if r["loss"] <= OPTION_C_LOSS_CUTOFF} + dropped = set(option_c_results) - set(good_pools) + matched_clean = {p: matched[p] for p in good_pools if p in matched} + + if dropped: + print(f"\n Dropping {len(dropped)} pathological pools (Option C loss > {OPTION_C_LOSS_CUTOFF}):") + for p in sorted(dropped): + r = option_c_results[p] + print(f" {p} {r['tokens']:<16} loss={r['loss']:.1f}") + + print(f"\n--- Option A: joint fit (per_pool_noise, {len(matched_clean)} pools, " + f"warm-start from C, no chain dummies) ---") + result_ppn = fit_joint( + matched_clean, + mode="per_pool_noise", + init_from_option_c=good_pools, + maxiter=JOINT_MAXITER, + drop_chain_dummies=True, + ) + print(f" Loss: {result_ppn['init_loss']:.4f} -> {result_ppn['loss']:.4f}") + print(f" Converged: {result_ppn['converged']}") + + print(f"\n--- Option A: joint fit (shared_noise, {len(matched_clean)} pools, " + f"warm-start from C, no chain dummies) ---") + result_sn = fit_joint( + matched_clean, + mode="shared_noise", + init_from_option_c=good_pools, + maxiter=JOINT_MAXITER, + drop_chain_dummies=True, + ) + print(f" Loss: {result_sn['init_loss']:.4f} -> {result_sn['loss']:.4f}") + print(f" Converged: {result_sn['converged']}") + + return result_ppn, result_sn + + +def run_option_rf(matched, option_c_results): + """2-stage approach: Option C per-pool fits → Ridge/RF on pool attributes. + + Drops chain dummies (too sparse for n~30), keeps 6 continuous/binary features. + Trains both Ridge regression and RF, reports both. LOO-CV for generalization. + + Noise coefficients are taken directly from Option C (per-pool). + Drops pathological pools (Option C loss > OPTION_C_LOSS_CUTOFF). + """ + from sklearn.ensemble import RandomForestRegressor + from sklearn.linear_model import RidgeCV + from sklearn.model_selection import LeaveOneOut + from quantammsim.calibration.pool_data import build_pool_attributes + + # Filter pathological pools + good_pools = {p: r for p, r in option_c_results.items() + if r["loss"] <= OPTION_C_LOSS_CUTOFF} + dropped = set(option_c_results) - set(good_pools) + matched_clean = {p: matched[p] for p in good_pools if p in matched} + + if dropped: + print(f"\n Dropping {len(dropped)} pathological pools (Option C loss > {OPTION_C_LOSS_CUTOFF}):") + for p in sorted(dropped): + r = option_c_results[p] + print(f" {p} {r['tokens']:<16} loss={r['loss']:.1f}") + + # Build attributes and targets + X_attr_full, attr_names_full, pool_ids = build_pool_attributes(matched_clean) + n_pools = len(pool_ids) + + # Drop chain dummies — too sparse for n~30. Keep only continuous/binary features. + non_chain_mask = [i for i, name in enumerate(attr_names_full) + if not name.startswith("chain_")] + X_attr = X_attr_full[:, non_chain_mask] + attr_names = [attr_names_full[i] for i in non_chain_mask] + k_attr = len(attr_names) + + print(f"\n--- Option RF: 2-stage mapping ({n_pools} pools, {k_attr} features) ---") + print(f" Features: {', '.join(attr_names)}") + + Y_cad = np.array([good_pools[p]["log_cadence"] for p in pool_ids]) + Y_gas = np.array([good_pools[p]["log_gas"] for p in pool_ids]) + Y = np.column_stack([Y_cad, Y_gas]) + + ss_tot_cad = np.sum((Y_cad - Y_cad.mean()) ** 2) + ss_tot_gas = np.sum((Y_gas - Y_gas.mean()) ** 2) + + def compute_r2(y_true, y_pred, ss_tot): + return 1 - np.sum((y_true - y_pred) ** 2) / max(ss_tot, 1e-10) + + # ---- Ridge regression (multi-output via separate fits) ---- + alphas = np.logspace(-2, 4, 50) + ridge_cad = RidgeCV(alphas=alphas, cv=None) # GCV/LOO built-in + ridge_gas = RidgeCV(alphas=alphas, cv=None) + ridge_cad.fit(X_attr, Y_cad) + ridge_gas.fit(X_attr, Y_gas) + + Y_ridge_train = np.column_stack([ridge_cad.predict(X_attr), + ridge_gas.predict(X_attr)]) + r2_ridge_cad = compute_r2(Y_cad, Y_ridge_train[:, 0], ss_tot_cad) + r2_ridge_gas = compute_r2(Y_gas, Y_ridge_train[:, 1], ss_tot_gas) + + print(f"\n Ridge (alpha_cad={ridge_cad.alpha_:.1f}, alpha_gas={ridge_gas.alpha_:.1f}):") + print(f" In-sample R²: cadence={r2_ridge_cad:.3f}, gas={r2_ridge_gas:.3f}") + + # Ridge LOO-CV + loo = LeaveOneOut() + Y_ridge_loo = np.zeros_like(Y) + for train_idx, test_idx in loo.split(X_attr): + rc = RidgeCV(alphas=alphas, cv=None).fit(X_attr[train_idx], Y_cad[train_idx]) + rg = RidgeCV(alphas=alphas, cv=None).fit(X_attr[train_idx], Y_gas[train_idx]) + Y_ridge_loo[test_idx, 0] = rc.predict(X_attr[test_idx]) + Y_ridge_loo[test_idx, 1] = rg.predict(X_attr[test_idx]) + + r2_ridge_loo_cad = compute_r2(Y_cad, Y_ridge_loo[:, 0], ss_tot_cad) + r2_ridge_loo_gas = compute_r2(Y_gas, Y_ridge_loo[:, 1], ss_tot_gas) + print(f" LOO-CV R²: cadence={r2_ridge_loo_cad:.3f}, gas={r2_ridge_loo_gas:.3f}") + print(f" LOO-CV MAE: cadence={np.mean(np.abs(np.exp(Y_cad) - np.exp(Y_ridge_loo[:, 0]))):.1f} min, " + f"gas=${np.mean(np.abs(np.exp(Y_gas) - np.exp(Y_ridge_loo[:, 1]))):.2f}") + + # Ridge coefficients + print(f" Coefficients (cadence | gas):") + print(f" {'intercept':<20} {ridge_cad.intercept_:>7.3f} {ridge_gas.intercept_:>7.3f}") + for j, name in enumerate(attr_names): + print(f" {name:<20} {ridge_cad.coef_[j]:>7.3f} {ridge_gas.coef_[j]:>7.3f}") + + # ---- Random Forest (reduced features) ---- + rf = RandomForestRegressor( + n_estimators=200, + max_depth=None, + min_samples_leaf=3, # stronger regularization + max_features=min(4, k_attr), # cap at 4 features per split + random_state=42, + n_jobs=-1, + ) + rf.fit(X_attr, Y) + Y_rf_train = rf.predict(X_attr) + + r2_rf_cad = compute_r2(Y_cad, Y_rf_train[:, 0], ss_tot_cad) + r2_rf_gas = compute_r2(Y_gas, Y_rf_train[:, 1], ss_tot_gas) + + print(f"\n Random Forest (min_leaf=3, max_feat=4):") + print(f" In-sample R²: cadence={r2_rf_cad:.3f}, gas={r2_rf_gas:.3f}") + + # RF LOO-CV + Y_rf_loo = np.zeros_like(Y) + for train_idx, test_idx in loo.split(X_attr): + rf_loo = RandomForestRegressor( + n_estimators=200, max_depth=None, min_samples_leaf=3, + max_features=min(4, k_attr), random_state=42, n_jobs=-1, + ) + rf_loo.fit(X_attr[train_idx], Y[train_idx]) + Y_rf_loo[test_idx] = rf_loo.predict(X_attr[test_idx]) + + r2_rf_loo_cad = compute_r2(Y_cad, Y_rf_loo[:, 0], ss_tot_cad) + r2_rf_loo_gas = compute_r2(Y_gas, Y_rf_loo[:, 1], ss_tot_gas) + print(f" LOO-CV R²: cadence={r2_rf_loo_cad:.3f}, gas={r2_rf_loo_gas:.3f}") + print(f" LOO-CV MAE: cadence={np.mean(np.abs(np.exp(Y_cad) - np.exp(Y_rf_loo[:, 0]))):.1f} min, " + f"gas=${np.mean(np.abs(np.exp(Y_gas) - np.exp(Y_rf_loo[:, 1]))):.2f}") + + print(f"\n Feature importances:") + for j, name in enumerate(attr_names): + print(f" {name:<20} {rf.feature_importances_[j]:.3f}") + + # ---- Pick best LOO model ---- + ridge_loo_total = r2_ridge_loo_cad + r2_ridge_loo_gas + rf_loo_total = r2_rf_loo_cad + r2_rf_loo_gas + best = "ridge" if ridge_loo_total >= rf_loo_total else "rf" + print(f"\n Best LOO model: {best} (ridge={ridge_loo_total:.3f} vs rf={rf_loo_total:.3f})") + + if best == "ridge": + Y_best_train = Y_ridge_train + Y_best_loo = Y_ridge_loo + r2_best_cad = r2_ridge_cad + r2_best_gas = r2_ridge_gas + r2_best_loo_cad = r2_ridge_loo_cad + r2_best_loo_gas = r2_ridge_loo_gas + else: + Y_best_train = Y_rf_train + Y_best_loo = Y_rf_loo + r2_best_cad = r2_rf_cad + r2_best_gas = r2_rf_gas + r2_best_loo_cad = r2_rf_loo_cad + r2_best_loo_gas = r2_rf_loo_gas + + # Build result dict using the best model's predictions + noise_all = np.array([good_pools[p]["noise_coeffs"] for p in pool_ids]) + + result = { + "pool_ids": pool_ids, + "attr_names": attr_names, + "X_attr": X_attr, + "best_model": best, + "predictions": {}, + "loo_predictions": {}, + "noise_coeffs": noise_all, # from Option C + "r2_train_cad": r2_best_cad, + "r2_train_gas": r2_best_gas, + "r2_loo_cad": r2_best_loo_cad, + "r2_loo_gas": r2_best_loo_gas, + } + + for i, pid in enumerate(pool_ids): + result["predictions"][pid] = { + "log_cadence": float(Y_best_train[i, 0]), + "log_gas": float(Y_best_train[i, 1]), + } + result["loo_predictions"][pid] = { + "log_cadence": float(Y_best_loo[i, 0]), + "log_gas": float(Y_best_loo[i, 1]), + } + + return result + + +def compute_per_pool_predictions(matched, option_c_results, joint_result, rf_result=None): + """Compute per-observation V_arb, V_noise, V_total for each pool. + + Pools not in the joint result (dropped as pathological) get NaN for + Option A predictions. Same for RF. + """ + from quantammsim.calibration.pool_data import build_x_obs, build_pool_attributes + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + from quantammsim.calibration.loss import K_OBS + import jax.numpy as jnp + + pool_ids = sorted(matched.keys()) + + # Build attributes for the joint-fitted pool subset + joint_pool_ids = joint_result["pool_ids"] + joint_matched = {p: matched[p] for p in joint_pool_ids if p in matched} + X_attr_joint_full, attr_names_full, _ = build_pool_attributes(joint_matched) + # Filter to the features actually used by the joint model + joint_attr_names = joint_result["attr_names"] + joint_feat_idx = [attr_names_full.index(n) for n in joint_attr_names + if n in attr_names_full] + X_attr_joint = X_attr_joint_full[:, joint_feat_idx] + joint_pid_to_idx = {p: i for i, p in enumerate(joint_pool_ids)} + + # RF predictions lookup + rf_pool_ids = rf_result["pool_ids"] if rf_result else [] + rf_pid_to_idx = {p: i for i, p in enumerate(rf_pool_ids)} + + predictions = {} + for pid in pool_ids: + entry = matched[pid] + panel = entry["panel"] + coeffs = entry["coeffs"] + day_indices = entry["day_indices"] + + x_obs = build_x_obs(panel) + y_obs = panel["log_volume"].values.astype(float) + + def r2(v_arb, v_noise, y): + log_pred = np.log(np.maximum(v_arb + v_noise, 1e-6)) + ss_res = np.sum((log_pred - y) ** 2) + ss_tot = np.sum((y - y.mean()) ** 2) + return 1 - ss_res / max(ss_tot, 1e-10) + + # --- Option C predictions --- + r = option_c_results[pid] + log_cad_c = r["log_cadence"] + log_gas_c = r["log_gas"] + noise_c_c = r["noise_coeffs"] + + v_arb_all_c = np.array(interpolate_pool_daily( + coeffs, jnp.float64(log_cad_c), jnp.float64(np.exp(log_gas_c)), + )) + v_arb_c = v_arb_all_c[day_indices] + v_noise_c = np.exp(x_obs @ noise_c_c) + + # --- Option A (per_pool_noise) predictions --- + if pid in joint_pid_to_idx: + ji = joint_pid_to_idx[pid] + x_attr = X_attr_joint[ji] + log_cad_a = float(joint_result["bias_cad"]) + float(x_attr @ joint_result["W_cad"]) + log_gas_a = float(joint_result["bias_gas"]) + float(x_attr @ joint_result["W_gas"]) + noise_c_a = joint_result["noise_coeffs"][ji] + + v_arb_all_a = np.array(interpolate_pool_daily( + coeffs, jnp.float64(log_cad_a), jnp.float64(np.exp(log_gas_a)), + )) + v_arb_a = v_arb_all_a[day_indices] + v_noise_a = np.exp(x_obs @ noise_c_a) + r2_a = r2(v_arb_a, v_noise_a, y_obs) + cad_a = np.exp(log_cad_a) + gas_a = np.exp(log_gas_a) + else: + v_arb_a = np.full(len(y_obs), np.nan) + v_noise_a = np.full(len(y_obs), np.nan) + r2_a = np.nan + cad_a = np.nan + gas_a = np.nan + + # --- Option RF predictions --- + if rf_result and pid in rf_pid_to_idx: + ri = rf_pid_to_idx[pid] + rf_pred = rf_result["predictions"][pid] + log_cad_rf = rf_pred["log_cadence"] + log_gas_rf = rf_pred["log_gas"] + noise_c_rf = rf_result["noise_coeffs"][ri] # from Option C + + v_arb_all_rf = np.array(interpolate_pool_daily( + coeffs, jnp.float64(log_cad_rf), jnp.float64(np.exp(log_gas_rf)), + )) + v_arb_rf = v_arb_all_rf[day_indices] + v_noise_rf = np.exp(x_obs @ noise_c_rf) + r2_rf = r2(v_arb_rf, v_noise_rf, y_obs) + cad_rf = np.exp(log_cad_rf) + gas_rf = np.exp(log_gas_rf) + + # LOO predictions (out-of-sample) + loo_pred = rf_result["loo_predictions"][pid] + log_cad_loo = loo_pred["log_cadence"] + log_gas_loo = loo_pred["log_gas"] + v_arb_all_loo = np.array(interpolate_pool_daily( + coeffs, jnp.float64(log_cad_loo), jnp.float64(np.exp(log_gas_loo)), + )) + v_arb_loo = v_arb_all_loo[day_indices] + v_noise_loo = np.exp(x_obs @ noise_c_rf) # same noise coeffs + r2_loo = r2(v_arb_loo, v_noise_loo, y_obs) + cad_loo = np.exp(log_cad_loo) + gas_loo = np.exp(log_gas_loo) + else: + v_arb_rf = np.full(len(y_obs), np.nan) + v_noise_rf = np.full(len(y_obs), np.nan) + r2_rf = np.nan + cad_rf = np.nan + gas_rf = np.nan + r2_loo = np.nan + cad_loo = np.nan + gas_loo = np.nan + + predictions[pid] = { + "dates": pd.to_datetime(panel["date"].values), + "y_obs": y_obs, + "actual_vol": np.exp(y_obs), + # Option C + "v_arb_c": v_arb_c, + "v_noise_c": v_noise_c, + "r2_c": r2(v_arb_c, v_noise_c, y_obs), + "cadence_c": np.exp(log_cad_c), + "gas_c": np.exp(log_gas_c), + "converged_c": r["converged"], + # Option A + "v_arb_a": v_arb_a, + "v_noise_a": v_noise_a, + "r2_a": r2_a, + "cadence_a": cad_a, + "gas_a": gas_a, + # Option RF (in-sample) + "v_arb_rf": v_arb_rf, + "v_noise_rf": v_noise_rf, + "r2_rf": r2_rf, + "cadence_rf": cad_rf, + "gas_rf": gas_rf, + # Option RF LOO (out-of-sample) + "r2_rf_loo": r2_loo, + "cadence_rf_loo": cad_loo, + "gas_rf_loo": gas_loo, + # Metadata + "chain": entry["chain"], + "tokens": entry["tokens"], + "fee": entry["fee"], + "median_tvl": float(np.exp(panel["log_tvl_lag1"].median())), + "n_obs": len(y_obs), + } + + return predictions + + +def plot_top50_pages(predictions, method="c"): + """Paginated plots: V_arb + V_noise decomposition.""" + # Rank by median TVL + ranked = sorted( + predictions.items(), + key=lambda x: -x[1]["median_tvl"], + )[:TOP_N] + + suffix = {"c": "option_c", "a": "option_a", "rf": "option_rf"}[method] + per_page = 10 + n_pages = (len(ranked) + per_page - 1) // per_page + + for page in range(n_pages): + start = page * per_page + end = min(start + per_page, len(ranked)) + page_pools = ranked[start:end] + n_this = len(page_pools) + + ncols = 2 + nrows = (n_this + ncols - 1) // ncols + fig, axes = plt.subplots(nrows, ncols, figsize=(16, 4.5 * nrows)) + if nrows == 1 and ncols == 1: + axes = np.array([[axes]]) + elif nrows == 1: + axes = axes.reshape(1, -1) + elif ncols == 1: + axes = axes.reshape(-1, 1) + + for idx, (pid, p) in enumerate(page_pools): + ax = axes[idx // ncols][idx % ncols] + dates = p["dates"] + + if method == "c": + v_arb = p["v_arb_c"] + v_noise = p["v_noise_c"] + r2_val = p["r2_c"] + cad = p["cadence_c"] + gas = p["gas_c"] + elif method == "rf": + v_arb = p["v_arb_rf"] + v_noise = p["v_noise_rf"] + r2_val = p["r2_rf"] + cad = p["cadence_rf"] + gas = p["gas_rf"] + else: + v_arb = p["v_arb_a"] + v_noise = p["v_noise_a"] + r2_val = p["r2_a"] + cad = p["cadence_a"] + gas = p["gas_a"] + + # Skip pools with NaN predictions (dropped from this method) + if np.any(np.isnan(v_arb)): + ax.text(0.5, 0.5, f"Dropped from {method.upper()}", fontsize=12, + ha="center", va="center", transform=ax.transAxes, color="gray") + ax.set_title(f"{pid[:16]} — dropped", fontsize=8) + continue + + v_total = v_arb + v_noise + arb_frac = np.median(v_arb / np.maximum(v_total, 1.0)) + actual = p["actual_vol"] + + # Stacked area: V_arb bottom, V_noise on top + ax.fill_between(dates, 0, np.maximum(v_arb, 0), + alpha=0.3, color="orangered", label="V_arb (grid)") + ax.fill_between(dates, np.maximum(v_arb, 0), np.maximum(v_total, 0), + alpha=0.3, color="steelblue", label="V_noise (covariates)") + ax.plot(dates, actual, "k-", linewidth=0.8, alpha=0.7, label="Actual") + ax.plot(dates, np.maximum(v_total, 0), "--", color="purple", + linewidth=0.8, alpha=0.7, label="Predicted total") + + ax.set_yscale("log") + ax.set_ylabel("Daily volume (USD)", fontsize=8) + + tokens = p["tokens"] + if isinstance(tokens, (list, tuple)): + tok_str = "/".join(str(t)[:8] for t in tokens[:2]) + elif isinstance(tokens, str): + tok_str = "/".join(t.strip()[:8] for t in tokens.split(",")[:2]) + else: + tok_str = pid[:16] + + ax.set_title( + f"{tok_str} ({p['chain']})\n" + f"TVL ${p['median_tvl']:,.0f} | R²={r2_val:.3f} " + f"cad={cad:.1f}min gas=${gas:.2f} " + f"arb_frac={arb_frac:.1%} n={p['n_obs']}", + fontsize=8, + ) + ax.legend(fontsize=6, loc="upper right") + ax.tick_params(labelsize=7) + ax.tick_params(axis="x", rotation=30) + + for idx in range(n_this, nrows * ncols): + axes[idx // ncols][idx % ncols].set_visible(False) + + method_label = {"c": "Option C (per-pool)", "a": "Option A (linear)", + "rf": "Option RF (random forest)"}[method] + fig.suptitle( + f"Direct calibration: V_arb + V_noise — {method_label}\n" + f"page {page + 1}/{n_pages} " + f"(top {min(TOP_N, len(ranked))} by median TVL, {TRAIN_DAYS}d window)", + fontsize=11, + ) + fig.tight_layout() + out = os.path.join(OUTPUT_DIR, f"{suffix}_page{page + 1}.png") + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" Saved: {out}") + + +def plot_summary(predictions, option_c_results, joint_result): + """Summary: distributions of cadence, gas, R², arb fraction for both methods.""" + pool_ids = sorted(predictions.keys()) + n = len(pool_ids) + + fig, axes = plt.subplots(2, 4, figsize=(20, 10)) + + for row, (method, label) in enumerate([("c", "Option C (per-pool)"), + ("a", "Option A (joint)")]): + cads = [predictions[p][f"cadence_{method}"] for p in pool_ids] + gases = [predictions[p][f"gas_{method}"] for p in pool_ids] + r2s = [predictions[p][f"r2_{method}"] for p in pool_ids] + arb_fracs = [] + for p in pool_ids: + v_arb = predictions[p][f"v_arb_{method}"] + v_noise = predictions[p][f"v_noise_{method}"] + total = v_arb + v_noise + arb_fracs.append(np.median(v_arb / np.maximum(total, 1.0))) + + # Cadence + ax = axes[row, 0] + ax.hist(cads, bins=20, color="orangered", alpha=0.7, edgecolor="white") + ax.axvline(np.median(cads), color="black", linestyle="--", + label=f"Median={np.median(cads):.1f}min") + ax.set_xlabel("Cadence (minutes)") + ax.set_title(f"{label}: Cadence") + ax.legend(fontsize=8) + + # Gas + ax = axes[row, 1] + ax.hist(gases, bins=20, color="goldenrod", alpha=0.7, edgecolor="white") + ax.axvline(np.median(gases), color="black", linestyle="--", + label=f"Median=${np.median(gases):.2f}") + ax.set_xlabel("Gas (USD)") + ax.set_title(f"{label}: Gas cost") + ax.legend(fontsize=8) + + # R² + ax = axes[row, 2] + r2arr = np.array(r2s) + ax.hist(r2arr[np.isfinite(r2arr)], bins=20, color="green", alpha=0.7, + edgecolor="white") + ax.axvline(np.nanmedian(r2arr), color="black", linestyle="--", + label=f"Median={np.nanmedian(r2arr):.3f}") + ax.set_xlabel("R²") + ax.set_title(f"{label}: R²") + ax.legend(fontsize=8) + + # Arb fraction + ax = axes[row, 3] + ax.hist(arb_fracs, bins=20, color="steelblue", alpha=0.7, edgecolor="white") + ax.axvline(np.median(arb_fracs), color="black", linestyle="--", + label=f"Median={np.median(arb_fracs):.2f}") + ax.set_xlabel("Arb fraction") + ax.set_title(f"{label}: Arb fraction") + ax.legend(fontsize=8) + + fig.tight_layout() + out = os.path.join(OUTPUT_DIR, "summary_distributions.png") + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" Saved: {out}") + + +def plot_c_vs_a_scatter(predictions): + """Scatter: Option C vs Option A parameters.""" + pool_ids = sorted(predictions.keys()) + fig, axes = plt.subplots(1, 3, figsize=(15, 5)) + + for ax, metric, label in [ + (axes[0], "cadence", "Cadence (min)"), + (axes[1], "gas", "Gas (USD)"), + (axes[2], "r2", "R²"), + ]: + c_vals = [predictions[p][f"{metric}_c"] for p in pool_ids] + a_vals = [predictions[p][f"{metric}_a"] for p in pool_ids] + ax.scatter(c_vals, a_vals, alpha=0.7, s=30, edgecolors="k", linewidth=0.5) + lo = min(min(c_vals), min(a_vals)) + hi = max(max(c_vals), max(a_vals)) + margin = (hi - lo) * 0.05 + ax.plot([lo - margin, hi + margin], [lo - margin, hi + margin], + "k--", alpha=0.3, linewidth=1) + ax.set_xlabel(f"Option C: {label}") + ax.set_ylabel(f"Option A: {label}") + ax.set_title(f"{label}: C vs A") + + fig.tight_layout() + out = os.path.join(OUTPUT_DIR, "c_vs_a_scatter.png") + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" Saved: {out}") + + +def plot_cadence_gas_by_chain(predictions): + """Scatter: cadence vs gas, colored by chain.""" + pool_ids = sorted(predictions.keys()) + chains = [predictions[p]["chain"] for p in pool_ids] + unique_chains = sorted(set(chains)) + colors = plt.cm.tab10(np.linspace(0, 1, max(len(unique_chains), 1))) + chain_color = {c: colors[i] for i, c in enumerate(unique_chains)} + + fig, axes = plt.subplots(1, 2, figsize=(14, 6)) + for ax, method, label in [(axes[0], "c", "Option C"), (axes[1], "a", "Option A")]: + for c in unique_chains: + mask = [i for i, p in enumerate(pool_ids) if predictions[p]["chain"] == c] + cads = [predictions[pool_ids[i]][f"cadence_{method}"] for i in mask] + gases = [predictions[pool_ids[i]][f"gas_{method}"] for i in mask] + ax.scatter(cads, gases, label=c, color=chain_color[c], + alpha=0.7, s=50, edgecolors="k", linewidth=0.5) + ax.set_xlabel("Cadence (minutes)") + ax.set_ylabel("Gas cost (USD)") + ax.set_title(f"{label}: Cadence vs Gas by chain") + ax.legend(fontsize=8) + ax.set_xscale("log") + ax.set_yscale("log") + + fig.tight_layout() + out = os.path.join(OUTPUT_DIR, "cadence_gas_by_chain.png") + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" Saved: {out}") + + +def save_results_json(predictions, option_c_results, joint_ppn, joint_sn): + """Save fitted params as JSON for later use.""" + out = { + "option_c": {}, + "option_a_ppn": { + "bias_cad": joint_ppn["bias_cad"], + "bias_gas": joint_ppn["bias_gas"], + "W_cad": joint_ppn["W_cad"].tolist(), + "W_gas": joint_ppn["W_gas"].tolist(), + "noise_coeffs": joint_ppn["noise_coeffs"].tolist(), + "loss": joint_ppn["loss"], + "init_loss": joint_ppn["init_loss"], + "converged": bool(joint_ppn["converged"]), + "attr_names": joint_ppn["attr_names"], + "pool_ids": joint_ppn["pool_ids"], + }, + "option_a_shared": { + "bias_cad": joint_sn["bias_cad"], + "bias_gas": joint_sn["bias_gas"], + "W_cad": joint_sn["W_cad"].tolist(), + "W_gas": joint_sn["W_gas"].tolist(), + "bias_noise": joint_sn["bias_noise"].tolist(), + "W_noise": joint_sn["W_noise"].tolist(), + "loss": joint_sn["loss"], + "init_loss": joint_sn["init_loss"], + "converged": bool(joint_sn["converged"]), + "attr_names": joint_sn["attr_names"], + "pool_ids": joint_sn["pool_ids"], + }, + } + for pid, r in option_c_results.items(): + out["option_c"][pid] = { + "log_cadence": r["log_cadence"], + "log_gas": r["log_gas"], + "noise_coeffs": r["noise_coeffs"].tolist(), + "loss": r["loss"], + "converged": bool(r["converged"]), + "cadence_minutes": r["cadence_minutes"], + "gas_usd": r["gas_usd"], + "chain": r["chain"], + "fee": r["fee"], + "tokens": r["tokens"], + } + + path = os.path.join(OUTPUT_DIR, "direct_calibration_results.json") + with open(path, "w") as f: + json.dump(out, f, indent=2) + print(f" Saved: {path}") + + +def print_pool_table(predictions, option_c_results): + """Print a summary table of per-pool results.""" + ranked = sorted( + predictions.items(), + key=lambda x: -x[1]["median_tvl"], + ) + + has_rf = any(not np.isnan(p["cadence_rf"]) for _, p in ranked) + + print(f"\n{'='*150}") + header = (f"{'Pool':<24} {'Chain':<10} {'TVL':>12} {'N':>4} " + f"{'Cad_C':>6} {'Gas_C':>7} {'R2_C':>6} " + f"{'Cad_A':>6} {'Gas_A':>7} {'R2_A':>6}") + if has_rf: + header += f" {'Cad_RF':>6} {'Gas_RF':>7} {'R2_RF':>6} {'R2_LOO':>6}" + header += f" {'Arb%_C':>6}" + print(header) + print(f"{'-'*150}") + for pid, p in ranked: + tokens = p["tokens"] + if isinstance(tokens, str): + tok_str = "/".join(t.strip()[:6] for t in tokens.split(",")[:2]) + else: + tok_str = pid[:16] + arb_total_c = p["v_arb_c"] + p["v_noise_c"] + arb_frac = np.median(p["v_arb_c"] / np.maximum(arb_total_c, 1.0)) + if np.isnan(p["cadence_a"]): + a_str = " --- dropped --- " + else: + a_str = f"{p['cadence_a']:>5.1f}m ${p['gas_a']:>5.2f} {p['r2_a']:>6.3f}" + line = (f"{tok_str:<24} {p['chain']:<10} ${p['median_tvl']:>10,.0f} {p['n_obs']:>4} " + f"{p['cadence_c']:>5.1f}m ${p['gas_c']:>5.2f} {p['r2_c']:>6.3f} " + f"{a_str}") + if has_rf: + if np.isnan(p["cadence_rf"]): + line += " --- dropped --- " + else: + line += (f" {p['cadence_rf']:>5.1f}m ${p['gas_rf']:>5.2f} " + f"{p['r2_rf']:>6.3f} {p['r2_rf_loo']:>6.3f}") + line += f" {arb_frac:>5.1%}" + print(line) + + +def main(): + os.environ.setdefault("JAX_PLATFORMS", "cpu") + + print("=" * 70) + print("Direct Calibration Pipeline: Training + Top 50 Plots") + print("=" * 70) + + panel, matched = load_and_match() + + # Step 1: Option C + option_c = run_option_c(matched) + + # Step 2: Option A (linear mapping) + joint_ppn, joint_sn = run_option_a(matched, option_c) + + # Step 3: Option RF (random forest 2-stage) + rf_result = run_option_rf(matched, option_c) + + # Step 4: Compute predictions + print("\nComputing per-pool predictions...") + predictions = compute_per_pool_predictions(matched, option_c, joint_ppn, rf_result) + + # Step 5: Print table + print_pool_table(predictions, option_c) + + # Print RF vs A comparison summary + rf_pools = [p for p in predictions if not np.isnan(predictions[p]["r2_rf"])] + if rf_pools: + r2_c = [predictions[p]["r2_c"] for p in rf_pools] + r2_a = [predictions[p]["r2_a"] for p in rf_pools + if not np.isnan(predictions[p]["r2_a"])] + r2_rf = [predictions[p]["r2_rf"] for p in rf_pools] + r2_loo = [predictions[p]["r2_rf_loo"] for p in rf_pools] + print(f"\n--- R² comparison (non-dropped pools) ---") + print(f" Option C: median={np.median(r2_c):.4f} mean={np.mean(r2_c):.4f}") + if r2_a: + print(f" Option A (linear): median={np.median(r2_a):.4f} mean={np.mean(r2_a):.4f}") + print(f" Option RF (train): median={np.median(r2_rf):.4f} mean={np.mean(r2_rf):.4f}") + print(f" Option RF (LOO): median={np.median(r2_loo):.4f} mean={np.mean(r2_loo):.4f}") + + # Step 6: Plots + print("\nGenerating plots...") + os.makedirs(OUTPUT_DIR, exist_ok=True) + plot_top50_pages(predictions, method="c") + plot_top50_pages(predictions, method="a") + plot_top50_pages(predictions, method="rf") + plot_summary(predictions, option_c, joint_ppn) + plot_c_vs_a_scatter(predictions) + plot_cadence_gas_by_chain(predictions) + + # Step 7: Save results + save_results_json(predictions, option_c, joint_ppn, joint_sn) + + print(f"\n{'='*70}") + print(f"Done. Output in: {OUTPUT_DIR}") + + +if __name__ == "__main__": + main() diff --git a/scripts/run_structural_top50.py b/scripts/run_structural_top50.py new file mode 100644 index 0000000..b46aece --- /dev/null +++ b/scripts/run_structural_top50.py @@ -0,0 +1,448 @@ +"""Fit the structural mixture model and plot predicted vs actual for top 50 pools. + +Uses the cached panel (last 90 days), fits with vanilla SVI, then generates +paginated plots showing V_arb + V_noise decomposition and predicted vs actual. +""" + +import json +import os +import sys +from datetime import date, timedelta + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +# ---- Config ---- +PANEL_CACHE = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "local_data", "noise_calibration", "panel.parquet", +) +OUTPUT_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "structural_hierarchical", +) +OUTPUT_JSON = os.path.join(OUTPUT_DIR, "structural_fit.json") +TRAIN_DAYS = 90 +SVI_STEPS = 20_000 +SVI_LR = 1e-3 +NUM_SAMPLES = 1000 +SEED = 42 +TOP_N = 50 + + +def load_and_filter_panel(): + """Load cached panel, filter to 90 days, keep pools with >= 10 obs.""" + panel = pd.read_parquet(PANEL_CACHE) + max_date = panel["date"].max() + if not isinstance(max_date, date): + max_date = pd.Timestamp(max_date).date() + cutoff = max_date - timedelta(days=TRAIN_DAYS) + panel = panel[ + panel["date"].apply( + lambda d: d >= cutoff if isinstance(d, date) + else pd.Timestamp(d).date() >= cutoff + ) + ].copy() + + if "log_tvl_lag1" not in panel.columns: + panel = panel.sort_values(["pool_id", "date"]).reset_index(drop=True) + panel["log_tvl_lag1"] = panel.groupby("pool_id")["log_tvl"].shift(1) + panel = panel.dropna(subset=["log_tvl_lag1"]).reset_index(drop=True) + + pool_counts = panel.groupby("pool_id").size() + valid = pool_counts[pool_counts >= 10].index + panel = panel[panel["pool_id"].isin(valid)].copy() + + print(f"Panel: {len(panel)} obs, {panel['pool_id'].nunique()} pools, " + f"{cutoff} to {max_date}") + return panel + + +def fit_structural(panel): + """Run SVI on the structural mixture model.""" + os.environ.setdefault("JAX_PLATFORMS", "cpu") + + from quantammsim.noise_calibration.covariate_encoding import ( + encode_covariates_structural, + ) + from quantammsim.noise_calibration.model import structural_noise_model + from quantammsim.noise_calibration.inference import run_svi + from quantammsim.noise_calibration.postprocessing import ( + check_convergence, extract_structural_params, + ) + from quantammsim.noise_calibration.output import generate_output_json + + import numpyro + numpyro.enable_x64() + + # Load gas costs for mainnet from CSV if available + gas_csv = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "formula_vs_real", "mainnet_gas_cost_daily.csv", + ) + gas_arr = None + if os.path.exists(gas_csv): + gas_df = pd.read_csv(gas_csv) + # CSV has columns: unix (ms timestamp), USD (gas cost) + gas_df["date"] = pd.to_datetime(gas_df["unix"], unit="ms").dt.date + gas_lookup = dict(zip(gas_df["date"], gas_df["USD"])) + + # Build per-observation gas array + gas_vals = [] + for _, row in panel.iterrows(): + d = row["date"] + if not isinstance(d, date): + d = pd.Timestamp(d).date() + chain = row["chain"] + if chain == "MAINNET" and d in gas_lookup: + gas_vals.append(gas_lookup[d]) + elif chain == "MAINNET": + gas_vals.append(1.0) # median fallback + else: + # L2 chains: ~$0.005 + from quantammsim.noise_calibration.constants import GAS_COSTS + gas_vals.append(GAS_COSTS.get(chain, 0.005)) + gas_arr = np.array(gas_vals, dtype=np.float64) + print(f"Gas costs: loaded ({len(gas_lookup)} mainnet days from CSV)") + else: + print("Gas costs: using defaults (no mainnet CSV)") + + data = encode_covariates_structural(panel, gas=gas_arr) + + print(f"\nFitting structural model: {SVI_STEPS} SVI steps, lr={SVI_LR}") + samples, elbo_losses = run_svi( + data, + num_steps=SVI_STEPS, + lr=SVI_LR, + seed=SEED, + num_samples=NUM_SAMPLES, + model_fn=structural_noise_model, + ) + convergence = check_convergence(elbo_losses, method="svi") + + pool_params = extract_structural_params(samples, data) + + # Save output JSON + os.makedirs(OUTPUT_DIR, exist_ok=True) + inference_config = { + "method": "svi", "svi_steps": SVI_STEPS, + "svi_lr": SVI_LR, "num_samples": NUM_SAMPLES, + } + generate_output_json( + pool_params, samples, data, convergence, + OUTPUT_JSON, inference_config, + ) + + return samples, data, pool_params, elbo_losses + + +def compute_predictions(samples, data, panel): + """Compute per-observation predicted V_arb and V_noise.""" + from quantammsim.noise_calibration.formula_arb import ( + formula_arb_volume_daily_jax, + ) + import jax.numpy as jnp + + sample_dict = samples + agg_fn = np.median + + # Cadence parameters + alpha_0 = agg_fn(np.array(sample_dict["alpha_0"])) + alpha_chain = agg_fn(np.array(sample_dict["alpha_chain"]), axis=0) + alpha_tier = agg_fn(np.array(sample_dict["alpha_tier"]), axis=0) + alpha_tvl = agg_fn(np.array(sample_dict["alpha_tvl"])) + + # Hierarchical noise: reconstruct theta + B = agg_fn(np.array(sample_dict["B"]), axis=0) + eta = agg_fn(np.array(sample_dict["eta"]), axis=0) + sigma_theta = agg_fn(np.array(sample_dict["sigma_theta"]), axis=0) + L_Omega = agg_fn(np.array(sample_dict["L_Omega"]), axis=0) + + pool_idx = np.array(data["pool_idx"]) + X_pool = np.array(data["X_pool"]) + x_obs = np.array(data["x_obs"]) + chain_idx = np.array(data["chain_idx"]) + tier_idx = np.array(data["tier_idx"]) + sigma_daily = np.array(data["sigma_daily"]) + lag_log_tvl = np.array(data["lag_log_tvl"]) + fee = np.array(data["fee"]) + gas = np.array(data["gas"]) + + # Per-pool cadence + padded_chain = np.concatenate([[0.0], alpha_chain]) + padded_tier = np.concatenate([[0.0], alpha_tier]) + + N_pools = data["N_pools"] + pool_log_cadence = np.zeros(N_pools) + for p in range(N_pools): + pool_log_cadence[p] = ( + alpha_0 + + padded_chain[chain_idx[p]] + + padded_tier[tier_idx[p]] + + alpha_tvl * np.median(lag_log_tvl[pool_idx == p]) + ) + + # Per-obs V_arb + log_cad_obs = pool_log_cadence[pool_idx] + cadence_obs = np.exp(np.clip(log_cad_obs, -2.0, 6.0)) + tvl_obs = np.exp(lag_log_tvl) + + V_arb = np.array(formula_arb_volume_daily_jax( + jnp.array(sigma_daily), jnp.array(tvl_obs), + jnp.array(fee), jnp.array(gas), jnp.array(cadence_obs), + )) + + # Per-pool theta from hierarchical model + L_Sigma = np.diag(sigma_theta) @ L_Omega + theta = X_pool @ B.T + eta @ L_Sigma.T # (N_pools, K_obs_coeff) + + # Per-obs V_noise + log_V_noise = np.sum(theta[pool_idx] * x_obs, axis=1) + V_noise = np.exp(log_V_noise) + + # Predicted total + V_total_pred = V_arb + V_noise + log_V_pred = np.log(np.maximum(V_total_pred, 1e-6)) + + return V_arb, V_noise, V_total_pred, log_V_pred, cadence_obs + + +def plot_top50(panel, data, pool_params, V_arb, V_noise, log_V_pred): + """Plot top 50 pools by median TVL.""" + pool_meta = data["pool_meta"] + pool_ids = data["pool_ids"] + pool_idx = np.array(data["pool_idx"]) + y_obs = np.array(data["y_obs"]) + + # Rank pools by median TVL + pool_tvl = {} + for i, pid in enumerate(pool_ids): + mask = pool_idx == i + pool_tvl[pid] = np.median(np.exp(np.array(data["lag_log_tvl"])[mask])) + + ranked = sorted(pool_tvl.items(), key=lambda x: -x[1])[:TOP_N] + + # Build param lookup + param_lookup = {p["pool_id"]: p for p in pool_params} + + per_page = 10 + n_pages = (len(ranked) + per_page - 1) // per_page + + for page in range(n_pages): + start = page * per_page + end = min(start + per_page, len(ranked)) + page_pools = ranked[start:end] + n_this = len(page_pools) + + ncols = 2 + nrows = (n_this + ncols - 1) // ncols + fig, axes = plt.subplots(nrows, ncols, figsize=(16, 4.5 * nrows)) + if nrows == 1 and ncols == 1: + axes = np.array([[axes]]) + elif nrows == 1: + axes = axes.reshape(1, -1) + + for idx, (pid, median_tvl) in enumerate(page_pools): + ax = axes[idx // ncols][idx % ncols] + p_idx = pool_ids.index(pid) + mask = pool_idx == p_idx + + pp = panel[panel["pool_id"] == pid].sort_values("date") + dates = pd.to_datetime(pp["date"].values) + actual_vol = np.exp(y_obs[mask]) + pred_arb = V_arb[mask] + pred_noise = V_noise[mask] + pred_total = pred_arb + pred_noise + + # R2 + actual_log = y_obs[mask] + pred_log = log_V_pred[mask] + ss_res = np.sum((actual_log - pred_log) ** 2) + ss_tot = np.sum((actual_log - actual_log.mean()) ** 2) + r2 = 1 - ss_res / ss_tot if ss_tot > 0 else float("nan") + + # Arb fraction + arb_frac = np.median(pred_arb / np.maximum(pred_total, 1.0)) + + # Plot + ax.fill_between(dates, 0, pred_arb, alpha=0.3, color="orangered", + label="V_arb (LVR)") + ax.fill_between(dates, pred_arb, pred_total, alpha=0.3, + color="steelblue", label="V_noise (hier.)") + ax.plot(dates, actual_vol, "k-", linewidth=0.8, alpha=0.7, + label="Actual") + ax.plot(dates, pred_total, "--", color="purple", linewidth=0.8, + alpha=0.7, label="Predicted total") + + ax.set_yscale("log") + ax.set_ylabel("Daily volume (USD)", fontsize=8) + + meta = pool_meta[pool_meta["pool_id"] == pid] + if len(meta) > 0: + m = meta.iloc[0] + tokens = m["tokens"] + if isinstance(tokens, str): + tokens = tokens.split(",") + tok_str = "/".join(str(t)[:8] for t in tokens[:2]) + chain = str(m["chain"]) + else: + tok_str = pid[:16] + chain = "?" + + params = param_lookup.get(pid, {}) + arb_freq = params.get("arb_frequency", "?") + + ax.set_title( + f"{tok_str} ({chain})\n" + f"TVL ${median_tvl:,.0f} | R\u00b2={r2:.3f} " + f"arb_freq={arb_freq}min arb_frac={arb_frac:.1%} " + f"n={mask.sum()}", + fontsize=8, + ) + ax.legend(fontsize=6, loc="upper right") + ax.tick_params(labelsize=7) + ax.tick_params(axis="x", rotation=30) + + for idx in range(n_this, nrows * ncols): + axes[idx // ncols][idx % ncols].set_visible(False) + + fig.suptitle( + f"Structural mixture model: V_arb + V_noise decomposition " + f"— page {page + 1}/{n_pages} " + f"(top {TOP_N} by median TVL, 90d window)", + fontsize=11, + ) + fig.tight_layout() + out = os.path.join(OUTPUT_DIR, f"structural_top50_page{page + 1}.png") + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" Saved: {out}") + + +def plot_elbo(elbo_losses): + """Plot ELBO convergence.""" + fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + ax = axes[0] + ax.plot(elbo_losses, alpha=0.3, color="steelblue", linewidth=0.5) + window = min(100, len(elbo_losses) // 10) + if window > 1: + smoothed = pd.Series(elbo_losses).rolling(window).mean().values + ax.plot(smoothed, color="red", linewidth=1.5, label=f"Rolling {window}") + ax.legend() + ax.set_xlabel("Step") + ax.set_ylabel("ELBO loss") + ax.set_title("ELBO convergence") + + ax = axes[1] + start = len(elbo_losses) * 4 // 5 + ax.plot(range(start, len(elbo_losses)), elbo_losses[start:], + color="steelblue", linewidth=0.8) + ax.set_xlabel("Step") + ax.set_ylabel("ELBO loss") + ax.set_title("ELBO convergence (last 20%)") + + plt.tight_layout() + out = os.path.join(OUTPUT_DIR, "elbo_convergence.png") + plt.savefig(out, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {out}") + + +def plot_summary(data, pool_params, V_arb, V_noise, log_V_pred): + """Summary plots: arb frequency distribution, arb fraction, R2.""" + pool_idx = np.array(data["pool_idx"]) + y_obs = np.array(data["y_obs"]) + pool_ids = data["pool_ids"] + + fig, axes = plt.subplots(1, 3, figsize=(16, 5)) + + # 1. Arb frequency histogram + ax = axes[0] + freqs = [p["arb_frequency"] for p in pool_params] + ax.hist(freqs, bins=range(0, 62, 2), color="orangered", alpha=0.7, + edgecolor="white") + ax.set_xlabel("Arb frequency (minutes)") + ax.set_ylabel("Count") + ax.set_title(f"Arb frequency distribution (n={len(freqs)})") + ax.axvline(np.median(freqs), color="black", linestyle="--", + label=f"Median={np.median(freqs):.0f}min") + ax.legend() + + # 2. Arb fraction per pool + ax = axes[1] + arb_fracs = [] + for i, pid in enumerate(pool_ids): + mask = pool_idx == i + total = V_arb[mask] + V_noise[mask] + arb_fracs.append(np.median(V_arb[mask] / np.maximum(total, 1.0))) + ax.hist(arb_fracs, bins=30, color="steelblue", alpha=0.7, edgecolor="white") + ax.set_xlabel("Median arb fraction") + ax.set_ylabel("Count") + ax.set_title("Arb fraction distribution") + ax.axvline(np.median(arb_fracs), color="black", linestyle="--", + label=f"Median={np.median(arb_fracs):.2f}") + ax.legend() + + # 3. Per-pool R2 + ax = axes[2] + r2_vals = [] + for i, pid in enumerate(pool_ids): + mask = pool_idx == i + actual = y_obs[mask] + pred = log_V_pred[mask] + ss_res = np.sum((actual - pred) ** 2) + ss_tot = np.sum((actual - actual.mean()) ** 2) + r2_vals.append(1 - ss_res / ss_tot if ss_tot > 0 else float("nan")) + r2_vals = np.array(r2_vals) + ax.hist(r2_vals[np.isfinite(r2_vals)], bins=30, color="green", alpha=0.7, + edgecolor="white") + ax.set_xlabel("R²") + ax.set_ylabel("Count") + ax.set_title("Per-pool R² distribution") + ax.axvline(np.nanmedian(r2_vals), color="black", linestyle="--", + label=f"Median={np.nanmedian(r2_vals):.3f}") + ax.legend() + + plt.tight_layout() + out = os.path.join(OUTPUT_DIR, "structural_summary.png") + plt.savefig(out, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {out}") + + +def main(): + print("=" * 70) + print("Structural Mixture Model: Fit + Top 50 Plots") + print("=" * 70) + + panel = load_and_filter_panel() + samples, data, pool_params, elbo_losses = fit_structural(panel) + + print("\nComputing predictions...") + V_arb, V_noise, V_total, log_V_pred, cadence = compute_predictions( + samples, data, panel, + ) + print(f" V_arb median: ${np.median(V_arb):,.0f}") + print(f" V_noise median: ${np.median(V_noise):,.0f}") + print(f" Arb fraction (median pool): {np.median(V_arb / np.maximum(V_total, 1)):.2%}") + + print("\nGenerating plots...") + os.makedirs(OUTPUT_DIR, exist_ok=True) + plot_elbo(elbo_losses) + plot_summary(data, pool_params, V_arb, V_noise, log_V_pred) + plot_top50(panel, data, pool_params, V_arb, V_noise, log_V_pred) + + # Summary stats + print(f"\n{'=' * 70}") + print(f"Done. Output in: {OUTPUT_DIR}") + arb_freqs = [p["arb_frequency"] for p in pool_params] + print(f" Arb frequency: median={np.median(arb_freqs):.0f}min, " + f"range=[{np.min(arb_freqs)}, {np.max(arb_freqs)}]") + print(f" JSON: {OUTPUT_JSON}") + + +if __name__ == "__main__": + main() diff --git a/tests/calibration/__init__.py b/tests/calibration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/calibration/conftest.py b/tests/calibration/conftest.py new file mode 100644 index 0000000..2c171c0 --- /dev/null +++ b/tests/calibration/conftest.py @@ -0,0 +1,134 @@ +"""Fixtures for calibration pipeline tests.""" + +import numpy as np +import pandas as pd +import pytest + + +# ── Constants ────────────────────────────────────────────────────────────── + +N_CADENCES = 3 +N_GAS = 3 +N_DAYS = 15 +N_POOLS = 2 +K_OBS = 8 + +CADENCES = np.array([1.0, 12.0, 60.0]) +GAS_COSTS = np.array([0.0, 1.0, 5.0]) + +# Pool ID prefixes (16 chars) that map to full 66-char pool IDs +POOL_PREFIXES = ["0xaaaa11112222aa", "0xbbbb33334444bb"] +POOL_IDS_FULL = [ + "0xaaaa11112222aa63ae5d458857e731c129069f29000200000000000000000588", + "0xbbbb33334444bb9c8ef030ab642b10820db8f56000200000000000000000014", +] + + +# ── Fixtures ────────────────────────────────────────────────────────────── + + +@pytest.fixture +def synthetic_daily_grid(): + """Small per-day grid DataFrame: 3 cadences x 3 gas_costs x 15 days. + + V_arb decreasing in cadence and gas, with daily sinusoidal variation. + """ + np.random.seed(42) + dates = pd.date_range("2025-12-01", periods=N_DAYS, freq="D") + + rows = [] + for ci, cad in enumerate(CADENCES): + for gi, gas in enumerate(GAS_COSTS): + base = 10000.0 / (1 + 0.3 * ci) / (1 + 0.5 * gi) + for di, date in enumerate(dates): + daily_var = 1 + 0.1 * np.sin(2 * np.pi * di / 7) + vol = base * daily_var + np.random.normal(0, base * 0.01) + rows.append({ + "cadence": cad, + "gas_cost": gas, + "date": date, + "daily_arb_volume": max(vol, 0), + }) + + return pd.DataFrame(rows) + + +@pytest.fixture +def synthetic_panel(): + """Minimal panel DataFrame: 2 pools x 15 days. + + Columns match the real panel.parquet schema. + """ + np.random.seed(42) + dates = pd.date_range("2025-12-01", periods=N_DAYS, freq="D") + + rows = [] + for pi, (prefix, full_id) in enumerate(zip(POOL_PREFIXES, POOL_IDS_FULL)): + chain = "MAINNET" if pi == 0 else "ARBITRUM" + base_tvl = 12.0 + pi # log TVL + base_vol = 9.0 + pi + fee = 0.003 if pi == 0 else 0.01 + + for di, date in enumerate(dates): + tvl = base_tvl + 0.05 * np.sin(2 * np.pi * di / 30) + vol = base_vol + 0.3 * np.random.randn() + sigma = 0.4 + 0.1 * np.random.randn() + rows.append({ + "pool_id": full_id, + "chain": chain, + "date": date, + "log_volume": vol, + "log_tvl": tvl, + "log_tvl_lag1": tvl - 0.01 if di > 0 else np.nan, + "volatility": max(sigma, 0.01), + "weekend": 1 if date.weekday() >= 5 else 0, + "log_fee": np.log(fee), + "swap_fee": fee, + "tier_A": "major" if pi == 0 else "mid", + "tier_B": "major", + "tokens": "BTC,ETH" if pi == 0 else "AAVE,ETH", + "total_shares": 1e6 * (1 + 0.01 * di), + }) + + df = pd.DataFrame(rows) + # Drop rows where log_tvl_lag1 is NaN (first day per pool) + df = df.dropna(subset=["log_tvl_lag1"]).reset_index(drop=True) + return df + + +@pytest.fixture +def synthetic_pool_coeffs(synthetic_daily_grid): + """PoolCoeffsDaily built from synthetic_daily_grid.""" + from quantammsim.calibration.grid_interpolation import precompute_pool_coeffs_daily + return precompute_pool_coeffs_daily(synthetic_daily_grid) + + +@pytest.fixture +def synthetic_x_obs(synthetic_panel): + """NumPy array (n_obs, K_OBS) from synthetic panel for one pool.""" + from quantammsim.calibration.pool_data import build_x_obs + pool0 = synthetic_panel[ + synthetic_panel["pool_id"] == POOL_IDS_FULL[0] + ] + return build_x_obs(pool0) + + +@pytest.fixture +def synthetic_pool_fit_result(): + """Dict with per-pool fitted params for testing learned mapping.""" + np.random.seed(42) + results = {} + for prefix in POOL_PREFIXES: + results[prefix] = { + "log_cadence": np.log(12.0) + 0.1 * np.random.randn(), + "log_gas": np.log(1.0) + 0.1 * np.random.randn(), + "noise_coeffs": np.random.randn(K_OBS) * 0.1, + "loss": 0.5 + 0.1 * np.random.rand(), + "converged": True, + "cadence_minutes": 12.0, + "gas_usd": 1.0, + "chain": "MAINNET" if prefix == POOL_PREFIXES[0] else "ARBITRUM", + "fee": 0.003 if prefix == POOL_PREFIXES[0] else 0.01, + "tokens": "BTC/ETH" if prefix == POOL_PREFIXES[0] else "AAVE/ETH", + } + return results diff --git a/tests/calibration/test_joint_fit.py b/tests/calibration/test_joint_fit.py new file mode 100644 index 0000000..50d7730 --- /dev/null +++ b/tests/calibration/test_joint_fit.py @@ -0,0 +1,190 @@ +"""Tests for quantammsim.calibration.joint_fit — joint end-to-end optimization (Option A).""" + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from tests.calibration.conftest import K_OBS, POOL_IDS_FULL, POOL_PREFIXES + + +@pytest.fixture +def matched_data(synthetic_daily_grid, synthetic_panel, tmp_path): + """Build matched data dict from synthetic fixtures.""" + from quantammsim.calibration.pool_data import match_grids_to_panel + + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + for prefix in POOL_PREFIXES: + synthetic_daily_grid.to_parquet( + grid_dir / f"{prefix}_daily.parquet", index=False + ) + return match_grids_to_panel(str(grid_dir), synthetic_panel) + + +class TestPrepareJointData: + """Test prepare_joint_data: build batched arrays for joint optimization.""" + + def test_returns_expected_structure(self, matched_data): + from quantammsim.calibration.joint_fit import prepare_joint_data + + jdata = prepare_joint_data(matched_data) + assert hasattr(jdata, "pool_data") + assert hasattr(jdata, "x_attr") + assert hasattr(jdata, "pool_ids") + + def test_pool_count(self, matched_data): + from quantammsim.calibration.joint_fit import prepare_joint_data + + jdata = prepare_joint_data(matched_data) + assert len(jdata.pool_data) == len(matched_data) + + def test_pool_data_has_jax_arrays(self, matched_data): + from quantammsim.calibration.joint_fit import prepare_joint_data + + jdata = prepare_joint_data(matched_data) + for pd in jdata.pool_data: + assert isinstance(pd["x_obs"], jnp.ndarray) + assert isinstance(pd["y_obs"], jnp.ndarray) + assert isinstance(pd["day_indices"], jnp.ndarray) + + def test_x_attr_shape(self, matched_data): + from quantammsim.calibration.joint_fit import prepare_joint_data + + jdata = prepare_joint_data(matched_data) + n_pools = len(matched_data) + assert jdata.x_attr.shape[0] == n_pools + assert jdata.x_attr.shape[1] > 0 # K_attr + + +class TestJointLoss: + """Test joint_loss: end-to-end loss over all pools.""" + + def _make_loss_fn(self, matched_data): + from quantammsim.calibration.joint_fit import ( + make_joint_loss_fn, + make_initial_joint_params, + prepare_joint_data, + ) + + jdata = prepare_joint_data(matched_data) + init = make_initial_joint_params(jdata, mode="per_pool_noise") + loss_fn = make_joint_loss_fn(jdata, mode="per_pool_noise") + return loss_fn, init, jdata + + def test_loss_scalar(self, matched_data): + loss_fn, init, _ = self._make_loss_fn(matched_data) + loss = loss_fn(init) + assert loss.shape == () + + def test_loss_positive(self, matched_data): + loss_fn, init, _ = self._make_loss_fn(matched_data) + loss = loss_fn(init) + assert float(loss) >= 0 + + def test_loss_differentiable(self, matched_data): + loss_fn, init, _ = self._make_loss_fn(matched_data) + grad = jax.grad(loss_fn)(init) + assert grad.shape == init.shape + assert jnp.all(jnp.isfinite(grad)) + + def test_loss_grad_nonzero(self, matched_data): + loss_fn, init, _ = self._make_loss_fn(matched_data) + grad = jax.grad(loss_fn)(init) + assert float(jnp.sum(jnp.abs(grad))) > 0 + + def test_shared_cadence_gas_affects_all_pools(self, matched_data): + """Changing W_cad should affect losses from all pools.""" + from quantammsim.calibration.joint_fit import ( + make_joint_loss_fn, + make_initial_joint_params, + prepare_joint_data, + unpack_joint_params, + ) + + jdata = prepare_joint_data(matched_data) + init = make_initial_joint_params(jdata, mode="per_pool_noise") + loss_fn = make_joint_loss_fn(jdata, mode="per_pool_noise") + + # Gradient w.r.t. W_cad should be nonzero + grad = jax.grad(loss_fn)(init) + config = {"n_pools": len(jdata.pool_data), + "k_attr": jdata.x_attr.shape[1], + "mode": "per_pool_noise"} + params = unpack_joint_params(grad, config) + assert float(jnp.sum(jnp.abs(params["W_cad"]))) > 0 + + +class TestFitJoint: + """Test fit_joint: L-BFGS-B joint optimization.""" + + def test_returns_result(self, matched_data): + from quantammsim.calibration.joint_fit import fit_joint + + result = fit_joint(matched_data, mode="per_pool_noise", maxiter=20) + assert isinstance(result, dict) + for key in ["bias_cad", "bias_gas", "W_cad", "W_gas", "loss", "converged"]: + assert key in result, f"Missing key: {key}" + + def test_loss_decreases(self, matched_data): + from quantammsim.calibration.joint_fit import fit_joint + + result = fit_joint(matched_data, mode="per_pool_noise", maxiter=50) + assert result["loss"] <= result["init_loss"] + + def test_predict_new_pool(self, matched_data): + from quantammsim.calibration.joint_fit import fit_joint, predict_new_pool_joint + + result = fit_joint(matched_data, mode="per_pool_noise", maxiter=20) + # Predict for a new pool — k_attr must match training + k_attr = result["W_cad"].shape[0] + x_attr_new = np.zeros(k_attr) + x_attr_new[0] = 1.0 # intercept + pred = predict_new_pool_joint(result, x_attr_new) + assert "cadence_minutes" in pred + assert "gas_usd" in pred + assert pred["cadence_minutes"] > 0 + assert pred["gas_usd"] > 0 + + def test_init_from_option_c(self, matched_data): + """Warm-starting from Option C per-pool fits should work.""" + from quantammsim.calibration.joint_fit import fit_joint + from quantammsim.calibration.per_pool_fit import fit_all_pools + + option_c = fit_all_pools(matched_data) + result = fit_joint( + matched_data, mode="per_pool_noise", + init_from_option_c=option_c, maxiter=20, + ) + assert result["loss"] >= 0 + + +class TestModes: + """Test per_pool_noise vs shared_noise modes.""" + + def test_per_pool_noise_mode(self, matched_data): + from quantammsim.calibration.joint_fit import fit_joint + + result = fit_joint(matched_data, mode="per_pool_noise", maxiter=20) + assert "noise_coeffs" in result + n_pools = len(matched_data) + assert result["noise_coeffs"].shape == (n_pools, K_OBS) + + def test_shared_noise_mode(self, matched_data): + from quantammsim.calibration.joint_fit import fit_joint + + result = fit_joint(matched_data, mode="shared_noise", maxiter=20) + assert "W_noise" in result + k_attr = result["W_cad"].shape[0] + assert result["W_noise"].shape == (k_attr, K_OBS) + + def test_shared_noise_predict(self, matched_data): + from quantammsim.calibration.joint_fit import fit_joint, predict_new_pool_joint + + result = fit_joint(matched_data, mode="shared_noise", maxiter=20) + k_attr = result["W_cad"].shape[0] + x_attr_new = np.zeros(k_attr) + x_attr_new[0] = 1.0 # intercept + pred = predict_new_pool_joint(result, x_attr_new) + assert "noise_coeffs" in pred + assert len(pred["noise_coeffs"]) == K_OBS diff --git a/tests/calibration/test_learned_mapping.py b/tests/calibration/test_learned_mapping.py new file mode 100644 index 0000000..8fbf16e --- /dev/null +++ b/tests/calibration/test_learned_mapping.py @@ -0,0 +1,152 @@ +"""Tests for quantammsim.calibration.learned_mapping — attribute -> params mapping.""" + +import numpy as np +import pytest + +from tests.calibration.conftest import K_OBS, POOL_PREFIXES + + +class TestBuildTargets: + """Test build_targets: stack per-pool fitted params into target matrix.""" + + def test_target_matrix_shape(self, synthetic_pool_fit_result): + from quantammsim.calibration.learned_mapping import build_targets + + pool_order = sorted(synthetic_pool_fit_result.keys()) + Y = build_targets(synthetic_pool_fit_result, pool_order) + assert Y.shape == (len(pool_order), 2 + K_OBS) + + def test_target_ordering_matches_attributes(self, synthetic_pool_fit_result): + from quantammsim.calibration.learned_mapping import build_targets + + pool_order = sorted(synthetic_pool_fit_result.keys()) + Y = build_targets(synthetic_pool_fit_result, pool_order) + # First pool's log_cadence should be in first row + expected_lc = synthetic_pool_fit_result[pool_order[0]]["log_cadence"] + np.testing.assert_allclose(Y[0, 0], expected_lc) + + +class TestFitMapping: + """Test fit_mapping: Ridge regression from attributes to params.""" + + def _make_data(self, n_pools=10): + np.random.seed(42) + k_attr = 4 + X = np.random.randn(n_pools, k_attr) + X[:, 0] = 1.0 # intercept + W_true = np.random.randn(k_attr, 2 + K_OBS) * 0.5 + Y = X @ W_true + np.random.randn(n_pools, 2 + K_OBS) * 0.01 + return X, Y + + def test_returns_model(self): + from quantammsim.calibration.learned_mapping import fit_mapping + + X, Y = self._make_data() + model = fit_mapping(X, Y) + assert isinstance(model, dict) + assert "weights" in model + assert "intercept" in model + + def test_predict_shape(self): + from quantammsim.calibration.learned_mapping import fit_mapping + + X, Y = self._make_data() + model = fit_mapping(X, Y) + Y_pred = X @ model["weights"] + model["intercept"] + assert Y_pred.shape == Y.shape + + def test_predict_reasonable_range(self): + from quantammsim.calibration.learned_mapping import fit_mapping, predict_pool + + X, Y = self._make_data() + # Constrain Y targets to reasonable range + Y[:, 0] = np.log(np.random.uniform(1, 60, len(Y))) # log_cadence + Y[:, 1] = np.log(np.random.uniform(0.01, 10, len(Y))) # log_gas + model = fit_mapping(X, Y) + + result = predict_pool(model, X[0]) + assert 0.5 <= result["cadence_minutes"] <= 120.0 + assert result["gas_usd"] > 0 + + def test_overfit_on_training_data(self): + from quantammsim.calibration.learned_mapping import fit_mapping + + X, Y = self._make_data(n_pools=10) + model = fit_mapping(X, Y, alpha=0.001) + Y_pred = X @ model["weights"] + model["intercept"] + ss_res = np.sum((Y - Y_pred) ** 2) + ss_tot = np.sum((Y - Y.mean(axis=0)) ** 2) + r2 = 1 - ss_res / ss_tot + assert r2 > 0.8 + + def test_leave_one_out_runs(self): + from quantammsim.calibration.learned_mapping import cross_validate_loo + + X, Y = self._make_data(n_pools=10) + cv_result = cross_validate_loo(X, Y) + assert "per_pool_errors" in cv_result + assert len(cv_result["per_pool_errors"]) == 10 + + +class TestPredictNewPool: + """Test predict_pool: predict params for a single pool.""" + + def test_predict_single_pool(self): + from quantammsim.calibration.learned_mapping import fit_mapping, predict_pool + + np.random.seed(42) + X = np.random.randn(5, 3) + X[:, 0] = 1.0 + Y = np.random.randn(5, 2 + K_OBS) + model = fit_mapping(X, Y) + result = predict_pool(model, X[0]) + assert isinstance(result, dict) + + def test_predict_cadence_and_gas(self): + from quantammsim.calibration.learned_mapping import fit_mapping, predict_pool + + np.random.seed(42) + X = np.random.randn(5, 3) + X[:, 0] = 1.0 + Y = np.random.randn(5, 2 + K_OBS) + model = fit_mapping(X, Y) + result = predict_pool(model, X[0]) + assert "cadence_minutes" in result + assert "gas_usd" in result + assert "log_cadence" in result + assert "log_gas" in result + + def test_predict_noise_coeffs(self): + from quantammsim.calibration.learned_mapping import fit_mapping, predict_pool + + np.random.seed(42) + X = np.random.randn(5, 3) + X[:, 0] = 1.0 + Y = np.random.randn(5, 2 + K_OBS) + model = fit_mapping(X, Y) + result = predict_pool(model, X[0]) + assert "noise_coeffs" in result + assert len(result["noise_coeffs"]) == K_OBS + + def test_different_chains_different_predictions(self): + from quantammsim.calibration.learned_mapping import fit_mapping, predict_pool + + np.random.seed(42) + # X with chain dummy in col 1 + X = np.random.randn(10, 4) + X[:, 0] = 1.0 + X[:5, 1] = 1.0 # chain A + X[5:, 1] = 0.0 # chain B + Y = np.random.randn(10, 2 + K_OBS) + Y[:5, :] += 1.0 # chain A has different targets + + model = fit_mapping(X, Y, alpha=0.01) + + x_a = X[0].copy() + x_b = X[0].copy() + x_b[1] = 0.0 # flip chain + + pred_a = predict_pool(model, x_a) + pred_b = predict_pool(model, x_b) + # Predictions should differ + assert pred_a["log_cadence"] != pred_b["log_cadence"] diff --git a/tests/calibration/test_loss.py b/tests/calibration/test_loss.py new file mode 100644 index 0000000..8fd2958 --- /dev/null +++ b/tests/calibration/test_loss.py @@ -0,0 +1,239 @@ +"""Tests for quantammsim.calibration.loss — per-pool loss function.""" + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from tests.calibration.conftest import K_OBS, N_DAYS + + +class TestNoiseVolume: + """Test noise_volume: V_noise = exp(x_obs @ noise_coeffs).""" + + def test_noise_volume_shape(self, synthetic_x_obs): + from quantammsim.calibration.loss import noise_volume + + coeffs = jnp.zeros(K_OBS) + v = noise_volume(coeffs, jnp.array(synthetic_x_obs)) + assert v.shape == (synthetic_x_obs.shape[0],) + + def test_noise_volume_positive(self, synthetic_x_obs): + from quantammsim.calibration.loss import noise_volume + + coeffs = jnp.ones(K_OBS) * 0.1 + v = noise_volume(coeffs, jnp.array(synthetic_x_obs)) + assert jnp.all(v > 0) + + def test_noise_volume_intercept_only(self, synthetic_x_obs): + from quantammsim.calibration.loss import noise_volume + + c = 5.0 + coeffs = jnp.zeros(K_OBS).at[0].set(c) + v = noise_volume(coeffs, jnp.array(synthetic_x_obs)) + # x_obs[:, 0] is all 1s, so V_noise should be close to exp(c) + # (not exactly, because other columns are nonzero and contribute 0*x) + np.testing.assert_allclose(v, jnp.exp(c), rtol=1e-5) + + def test_noise_volume_tvl_effect(self, synthetic_x_obs): + from quantammsim.calibration.loss import noise_volume + + # Positive TVL coefficient: higher TVL → higher noise + coeffs = jnp.zeros(K_OBS).at[1].set(1.0) + v = noise_volume(coeffs, jnp.array(synthetic_x_obs)) + tvl_col = synthetic_x_obs[:, 1] + # Sort by TVL, check volume is monotone + order = np.argsort(tvl_col) + assert np.all(np.diff(np.array(v[order])) >= -1e-6) + + +class TestPoolLoss: + """Test pool_loss: per-pool log-space L2 loss with per-day V_arb.""" + + def _make_params(self, log_cad=None, log_gas=None, noise_coeffs=None): + from quantammsim.calibration.loss import pack_params + + if log_cad is None: + log_cad = float(jnp.log(jnp.array(12.0))) + if log_gas is None: + log_gas = float(jnp.log(jnp.array(1.0))) + if noise_coeffs is None: + noise_coeffs = jnp.zeros(K_OBS).at[0].set(8.0) + return pack_params(log_cad, log_gas, noise_coeffs) + + def _make_inputs(self, synthetic_pool_coeffs, synthetic_x_obs): + n_obs = synthetic_x_obs.shape[0] + n_days = int(synthetic_pool_coeffs.values.shape[2]) + day_indices = jnp.array(np.arange(n_obs) % n_days) + y_obs = jnp.ones(n_obs) * 9.0 + return jnp.array(synthetic_x_obs), y_obs, day_indices + + def test_loss_scalar_output(self, synthetic_pool_coeffs, synthetic_x_obs): + from quantammsim.calibration.loss import pool_loss + + params = self._make_params() + x_obs, y_obs, day_indices = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + loss = pool_loss(params, synthetic_pool_coeffs, x_obs, y_obs, day_indices) + assert loss.shape == () + + def test_loss_zero_when_perfect(self, synthetic_pool_coeffs, synthetic_x_obs): + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + from quantammsim.calibration.loss import noise_volume, pack_params, pool_loss + + log_cad = jnp.log(jnp.array(12.0)) + log_gas = jnp.log(jnp.array(1.0)) + noise_coeffs = jnp.zeros(K_OBS).at[0].set(8.0) + + # Compute what V_pred would be + v_arb_all = interpolate_pool_daily(synthetic_pool_coeffs, log_cad, jnp.exp(log_gas)) + n_obs = synthetic_x_obs.shape[0] + n_days = int(synthetic_pool_coeffs.values.shape[2]) + day_indices = jnp.array(np.arange(n_obs) % n_days) + v_arb = v_arb_all[day_indices] + v_noise = noise_volume(noise_coeffs, jnp.array(synthetic_x_obs)) + y_obs = jnp.log(jnp.maximum(v_arb + v_noise, 1e-6)) + + params = pack_params(float(log_cad), float(log_gas), noise_coeffs) + loss = pool_loss(params, synthetic_pool_coeffs, jnp.array(synthetic_x_obs), + y_obs, day_indices) + assert float(loss) < 1e-6 + + def test_loss_positive(self, synthetic_pool_coeffs, synthetic_x_obs): + from quantammsim.calibration.loss import pool_loss + + params = self._make_params() + x_obs, y_obs, day_indices = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + loss = pool_loss(params, synthetic_pool_coeffs, x_obs, y_obs, day_indices) + assert float(loss) >= 0 + + def test_loss_increases_with_error(self, synthetic_pool_coeffs, synthetic_x_obs): + from quantammsim.calibration.loss import pool_loss + + x_obs, y_obs, day_indices = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + + params_ok = self._make_params(noise_coeffs=jnp.zeros(K_OBS).at[0].set(8.0)) + params_bad = self._make_params(noise_coeffs=jnp.zeros(K_OBS).at[0].set(20.0)) + + loss_ok = pool_loss(params_ok, synthetic_pool_coeffs, x_obs, y_obs, day_indices) + loss_bad = pool_loss(params_bad, synthetic_pool_coeffs, x_obs, y_obs, day_indices) + assert float(loss_bad) > float(loss_ok) + + def test_loss_differentiable(self, synthetic_pool_coeffs, synthetic_x_obs): + from quantammsim.calibration.loss import pool_loss + + params = self._make_params() + x_obs, y_obs, day_indices = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + + grad_fn = jax.grad(pool_loss, argnums=0) + g = grad_fn(params, synthetic_pool_coeffs, x_obs, y_obs, day_indices) + assert g.shape == params.shape + + def test_loss_grad_nonzero(self, synthetic_pool_coeffs, synthetic_x_obs): + from quantammsim.calibration.loss import pool_loss + + params = self._make_params() + x_obs, y_obs, day_indices = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + + grad_fn = jax.grad(pool_loss, argnums=0) + g = grad_fn(params, synthetic_pool_coeffs, x_obs, y_obs, day_indices) + assert float(jnp.sum(jnp.abs(g))) > 0 + + def test_loss_grad_wrt_log_cadence(self, synthetic_pool_coeffs, synthetic_x_obs): + from quantammsim.calibration.loss import pool_loss + + params = self._make_params() + x_obs, y_obs, day_indices = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + + grad_fn = jax.grad(pool_loss, argnums=0) + g = grad_fn(params, synthetic_pool_coeffs, x_obs, y_obs, day_indices) + assert jnp.isfinite(g[0]) # log_cadence gradient + + def test_loss_grad_wrt_log_gas(self, synthetic_pool_coeffs, synthetic_x_obs): + from quantammsim.calibration.loss import pool_loss + + params = self._make_params() + x_obs, y_obs, day_indices = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + + grad_fn = jax.grad(pool_loss, argnums=0) + g = grad_fn(params, synthetic_pool_coeffs, x_obs, y_obs, day_indices) + assert jnp.isfinite(g[1]) # log_gas gradient + + def test_loss_grad_wrt_noise_coeffs(self, synthetic_pool_coeffs, synthetic_x_obs): + from quantammsim.calibration.loss import pool_loss + + params = self._make_params() + x_obs, y_obs, day_indices = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + + grad_fn = jax.grad(pool_loss, argnums=0) + g = grad_fn(params, synthetic_pool_coeffs, x_obs, y_obs, day_indices) + assert jnp.all(jnp.isfinite(g[2:])) # noise_coeffs gradients + + def test_loss_uses_per_day_varb(self, synthetic_pool_coeffs, synthetic_x_obs): + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + from quantammsim.calibration.loss import pool_loss, unpack_params + + params = self._make_params() + log_cad, log_gas, _ = unpack_params(params) + v_arb = interpolate_pool_daily( + synthetic_pool_coeffs, jnp.array(log_cad), jnp.exp(jnp.array(log_gas)) + ) + # Per-day V_arb should vary across days + assert float(jnp.std(v_arb)) > 0 + + def test_day_indices_alignment(self, synthetic_pool_coeffs, synthetic_x_obs): + from quantammsim.calibration.loss import pool_loss + + n_obs = synthetic_x_obs.shape[0] + n_days = int(synthetic_pool_coeffs.values.shape[2]) + params = self._make_params() + x_obs = jnp.array(synthetic_x_obs) + y_obs = jnp.ones(n_obs) * 9.0 + + # Different day_indices should give different loss + day_idx_a = jnp.zeros(n_obs, dtype=jnp.int32) # all same day + day_idx_b = jnp.array(np.arange(n_obs) % n_days) # varying days + + loss_a = pool_loss(params, synthetic_pool_coeffs, x_obs, y_obs, day_idx_a) + loss_b = pool_loss(params, synthetic_pool_coeffs, x_obs, y_obs, day_idx_b) + # Different day mappings → different losses + assert float(loss_a) != float(loss_b) + + +class TestPackUnpack: + """Test pack/unpack parameter roundtrip.""" + + def test_roundtrip(self): + from quantammsim.calibration.loss import pack_params, unpack_params + + log_cad = 2.5 + log_gas = -0.3 + noise_coeffs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + + flat = pack_params(log_cad, log_gas, noise_coeffs) + lc, lg, nc = unpack_params(flat) + + np.testing.assert_allclose(lc, log_cad) + np.testing.assert_allclose(lg, log_gas) + np.testing.assert_allclose(nc, noise_coeffs) + + def test_pack_shape(self): + from quantammsim.calibration.loss import pack_params + + flat = pack_params(1.0, 2.0, jnp.zeros(K_OBS)) + assert flat.shape == (2 + K_OBS,) diff --git a/tests/calibration/test_per_pool_fit.py b/tests/calibration/test_per_pool_fit.py new file mode 100644 index 0000000..32d63ab --- /dev/null +++ b/tests/calibration/test_per_pool_fit.py @@ -0,0 +1,170 @@ +"""Tests for quantammsim.calibration.per_pool_fit — L-BFGS-B per-pool fitting.""" + +import jax.numpy as jnp +import numpy as np +import pytest + +from tests.calibration.conftest import K_OBS, N_DAYS, POOL_IDS_FULL, POOL_PREFIXES + + +class TestFitSinglePool: + """Test fit_single_pool: L-BFGS-B optimization for one pool.""" + + def _make_inputs(self, synthetic_pool_coeffs, synthetic_x_obs): + n_obs = synthetic_x_obs.shape[0] + n_days = int(synthetic_pool_coeffs.values.shape[2]) + day_indices = np.arange(n_obs) % n_days + y_obs = np.ones(n_obs) * 9.0 # log(V_obs) + return synthetic_x_obs, y_obs, day_indices + + def test_returns_result_dict(self, synthetic_pool_coeffs, synthetic_x_obs): + from quantammsim.calibration.per_pool_fit import fit_single_pool + + x_obs, y_obs, day_idx = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + result = fit_single_pool(synthetic_pool_coeffs, x_obs, y_obs, day_idx) + assert isinstance(result, dict) + for key in ["log_cadence", "log_gas", "noise_coeffs", "loss", "converged"]: + assert key in result, f"Missing key: {key}" + + def test_cadence_in_range(self, synthetic_pool_coeffs, synthetic_x_obs): + from quantammsim.calibration.per_pool_fit import fit_single_pool + + x_obs, y_obs, day_idx = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + result = fit_single_pool(synthetic_pool_coeffs, x_obs, y_obs, day_idx) + cadence = np.exp(result["log_cadence"]) + assert 1.0 <= cadence <= 60.0 + + def test_gas_positive(self, synthetic_pool_coeffs, synthetic_x_obs): + from quantammsim.calibration.per_pool_fit import fit_single_pool + + x_obs, y_obs, day_idx = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + result = fit_single_pool(synthetic_pool_coeffs, x_obs, y_obs, day_idx) + assert np.exp(result["log_gas"]) > 0 + + def test_noise_coeffs_length(self, synthetic_pool_coeffs, synthetic_x_obs): + from quantammsim.calibration.per_pool_fit import fit_single_pool + + x_obs, y_obs, day_idx = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + result = fit_single_pool(synthetic_pool_coeffs, x_obs, y_obs, day_idx) + assert len(result["noise_coeffs"]) == K_OBS + + def test_loss_decreases_from_init(self, synthetic_pool_coeffs, synthetic_x_obs): + from quantammsim.calibration.loss import pack_params, pool_loss + from quantammsim.calibration.per_pool_fit import ( + fit_single_pool, + make_initial_guess, + ) + + x_obs, y_obs, day_idx = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + init = make_initial_guess(x_obs, y_obs) + init_loss = float(pool_loss( + jnp.array(init), synthetic_pool_coeffs, + jnp.array(x_obs), jnp.array(y_obs), jnp.array(day_idx), + )) + + result = fit_single_pool(synthetic_pool_coeffs, x_obs, y_obs, day_idx) + assert result["loss"] <= init_loss + + def test_converged_flag(self, synthetic_pool_coeffs, synthetic_x_obs): + from quantammsim.calibration.per_pool_fit import fit_single_pool + + x_obs, y_obs, day_idx = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + result = fit_single_pool(synthetic_pool_coeffs, x_obs, y_obs, day_idx) + assert isinstance(result["converged"], bool) + + +class TestFitAllPools: + """Test fit_all_pools: fit all matched pools.""" + + def test_returns_dict_per_pool( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.per_pool_fit import fit_all_pools + from quantammsim.calibration.pool_data import match_grids_to_panel + + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + for prefix in POOL_PREFIXES: + synthetic_daily_grid.to_parquet( + grid_dir / f"{prefix}_daily.parquet", index=False + ) + + matched = match_grids_to_panel(str(grid_dir), synthetic_panel) + results = fit_all_pools(matched) + assert isinstance(results, dict) + + def test_all_pools_have_results( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.per_pool_fit import fit_all_pools + from quantammsim.calibration.pool_data import match_grids_to_panel + + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + for prefix in POOL_PREFIXES: + synthetic_daily_grid.to_parquet( + grid_dir / f"{prefix}_daily.parquet", index=False + ) + + matched = match_grids_to_panel(str(grid_dir), synthetic_panel) + results = fit_all_pools(matched) + for prefix in matched: + assert prefix in results + + def test_results_have_metadata( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.per_pool_fit import fit_all_pools + from quantammsim.calibration.pool_data import match_grids_to_panel + + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + for prefix in POOL_PREFIXES: + synthetic_daily_grid.to_parquet( + grid_dir / f"{prefix}_daily.parquet", index=False + ) + + matched = match_grids_to_panel(str(grid_dir), synthetic_panel) + results = fit_all_pools(matched) + for prefix, res in results.items(): + assert "chain" in res + assert "fee" in res + assert "tokens" in res + + +class TestInitialGuess: + """Test make_initial_guess: reasonable starting point.""" + + def test_default_init_reasonable(self, synthetic_x_obs): + from quantammsim.calibration.per_pool_fit import make_initial_guess + + n_obs = synthetic_x_obs.shape[0] + y_obs = np.ones(n_obs) * 9.0 + init = make_initial_guess(synthetic_x_obs, y_obs) + assert len(init) == 2 + K_OBS + # log_cadence ~ log(12) + np.testing.assert_allclose(init[0], np.log(12.0), atol=0.1) + # log_gas ~ log(1.0) + np.testing.assert_allclose(init[1], np.log(1.0), atol=0.1) + + def test_init_noise_from_ols(self, synthetic_x_obs): + from quantammsim.calibration.per_pool_fit import make_initial_guess + + n_obs = synthetic_x_obs.shape[0] + y_obs = np.ones(n_obs) * 9.0 + init = make_initial_guess(synthetic_x_obs, y_obs) + noise_coeffs = init[2:] + # OLS should give finite values + assert np.all(np.isfinite(noise_coeffs)) diff --git a/tests/calibration/test_pool_data.py b/tests/calibration/test_pool_data.py new file mode 100644 index 0000000..b3a9cf7 --- /dev/null +++ b/tests/calibration/test_pool_data.py @@ -0,0 +1,303 @@ +"""Tests for quantammsim.calibration.pool_data — data assembly.""" + +import numpy as np +import pandas as pd +import pytest + +from tests.calibration.conftest import ( + K_OBS, + N_DAYS, + POOL_IDS_FULL, + POOL_PREFIXES, +) + + +class TestMatchGridsToPanel: + """Test match_grids_to_panel: match grid parquets to panel rows.""" + + def test_match_returns_dict_per_pool( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.pool_data import match_grids_to_panel + + # Write grid for pool 0 + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + synthetic_daily_grid.to_parquet( + grid_dir / f"{POOL_PREFIXES[0]}_daily.parquet", index=False + ) + + matched = match_grids_to_panel(str(grid_dir), synthetic_panel) + assert isinstance(matched, dict) + assert POOL_PREFIXES[0] in matched + + def test_match_filters_to_grid_pools_only( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.pool_data import match_grids_to_panel + + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + # Only write grid for pool 0 — pool 1 should be excluded + synthetic_daily_grid.to_parquet( + grid_dir / f"{POOL_PREFIXES[0]}_daily.parquet", index=False + ) + + matched = match_grids_to_panel(str(grid_dir), synthetic_panel) + assert POOL_PREFIXES[0] in matched + assert POOL_PREFIXES[1] not in matched + + def test_match_includes_panel_obs( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.pool_data import match_grids_to_panel + + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + synthetic_daily_grid.to_parquet( + grid_dir / f"{POOL_PREFIXES[0]}_daily.parquet", index=False + ) + + matched = match_grids_to_panel(str(grid_dir), synthetic_panel) + entry = matched[POOL_PREFIXES[0]] + assert "panel" in entry + assert isinstance(entry["panel"], pd.DataFrame) + assert len(entry["panel"]) > 0 + + def test_match_includes_coeffs( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.grid_interpolation import PoolCoeffsDaily + from quantammsim.calibration.pool_data import match_grids_to_panel + + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + synthetic_daily_grid.to_parquet( + grid_dir / f"{POOL_PREFIXES[0]}_daily.parquet", index=False + ) + + matched = match_grids_to_panel(str(grid_dir), synthetic_panel) + assert "coeffs" in matched[POOL_PREFIXES[0]] + assert isinstance(matched[POOL_PREFIXES[0]]["coeffs"], PoolCoeffsDaily) + + def test_pool_id_prefix_matching( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.pool_data import match_grids_to_panel + + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + synthetic_daily_grid.to_parquet( + grid_dir / f"{POOL_PREFIXES[0]}_daily.parquet", index=False + ) + + matched = match_grids_to_panel(str(grid_dir), synthetic_panel) + entry = matched[POOL_PREFIXES[0]] + assert entry["pool_id"] == POOL_IDS_FULL[0] + + def test_match_includes_day_indices( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.pool_data import match_grids_to_panel + + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + synthetic_daily_grid.to_parquet( + grid_dir / f"{POOL_PREFIXES[0]}_daily.parquet", index=False + ) + + matched = match_grids_to_panel(str(grid_dir), synthetic_panel) + entry = matched[POOL_PREFIXES[0]] + assert "day_indices" in entry + assert len(entry["day_indices"]) == len(entry["panel"]) + + def test_day_indices_align_dates( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.pool_data import match_grids_to_panel + + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + synthetic_daily_grid.to_parquet( + grid_dir / f"{POOL_PREFIXES[0]}_daily.parquet", index=False + ) + + matched = match_grids_to_panel(str(grid_dir), synthetic_panel) + entry = matched[POOL_PREFIXES[0]] + coeffs = entry["coeffs"] + day_indices = entry["day_indices"] + + # Panel dates should map to grid dates via ordinals + panel_dates = pd.to_datetime(entry["panel"]["date"]) + panel_ordinals = np.array([d.toordinal() for d in panel_dates]) + grid_ordinals = np.array(coeffs.dates) + + for i, panel_ord in enumerate(panel_ordinals): + grid_idx = day_indices[i] + assert grid_ordinals[grid_idx] == panel_ord + + +class TestBuildXObs: + """Test build_x_obs: observation covariate matrix.""" + + def test_x_obs_shape(self, synthetic_panel): + from quantammsim.calibration.pool_data import build_x_obs + + pool0 = synthetic_panel[synthetic_panel["pool_id"] == POOL_IDS_FULL[0]] + x = build_x_obs(pool0) + assert x.shape == (len(pool0), K_OBS) + + def test_x_obs_intercept_column(self, synthetic_panel): + from quantammsim.calibration.pool_data import build_x_obs + + pool0 = synthetic_panel[synthetic_panel["pool_id"] == POOL_IDS_FULL[0]] + x = build_x_obs(pool0) + np.testing.assert_array_equal(x[:, 0], 1.0) + + def test_x_obs_lagged_tvl(self, synthetic_panel): + from quantammsim.calibration.pool_data import build_x_obs + + pool0 = synthetic_panel[synthetic_panel["pool_id"] == POOL_IDS_FULL[0]] + x = build_x_obs(pool0) + np.testing.assert_allclose(x[:, 1], pool0["log_tvl_lag1"].values) + + def test_x_obs_log_sigma(self, synthetic_panel): + from quantammsim.calibration.pool_data import build_x_obs + + pool0 = synthetic_panel[synthetic_panel["pool_id"] == POOL_IDS_FULL[0]] + x = build_x_obs(pool0) + expected = np.log(np.maximum(pool0["volatility"].values, 1e-6)) + np.testing.assert_allclose(x[:, 2], expected) + + def test_x_obs_interactions(self, synthetic_panel): + from quantammsim.calibration.pool_data import build_x_obs + + pool0 = synthetic_panel[synthetic_panel["pool_id"] == POOL_IDS_FULL[0]] + x = build_x_obs(pool0) + tvl = pool0["log_tvl_lag1"].values + sigma = np.log(np.maximum(pool0["volatility"].values, 1e-6)) + fee = pool0["log_fee"].values + np.testing.assert_allclose(x[:, 3], tvl * sigma) + np.testing.assert_allclose(x[:, 4], tvl * fee) + np.testing.assert_allclose(x[:, 5], sigma * fee) + + def test_x_obs_dow_harmonics(self, synthetic_panel): + from quantammsim.calibration.pool_data import build_x_obs + + pool0 = synthetic_panel[synthetic_panel["pool_id"] == POOL_IDS_FULL[0]] + x = build_x_obs(pool0) + weekdays = pd.to_datetime(pool0["date"]).dt.weekday.values + expected_sin = np.sin(2 * np.pi * weekdays / 7) + expected_cos = np.cos(2 * np.pi * weekdays / 7) + np.testing.assert_allclose(x[:, 6], expected_sin, atol=1e-10) + np.testing.assert_allclose(x[:, 7], expected_cos, atol=1e-10) + + def test_x_obs_no_nans(self, synthetic_panel): + from quantammsim.calibration.pool_data import build_x_obs + + pool0 = synthetic_panel[synthetic_panel["pool_id"] == POOL_IDS_FULL[0]] + x = build_x_obs(pool0) + assert not np.any(np.isnan(x)) + + +class TestBuildPoolAttributes: + """Test build_pool_attributes: pool-level feature matrix.""" + + def test_attributes_shape( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.pool_data import ( + build_pool_attributes, + match_grids_to_panel, + ) + + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + for prefix in POOL_PREFIXES: + synthetic_daily_grid.to_parquet( + grid_dir / f"{prefix}_daily.parquet", index=False + ) + + matched = match_grids_to_panel(str(grid_dir), synthetic_panel) + X_attr, attr_names, pool_ids = build_pool_attributes(matched) + assert X_attr.shape[0] == len(matched) + assert X_attr.shape[1] == len(attr_names) + + def test_attributes_has_chain( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.pool_data import ( + build_pool_attributes, + match_grids_to_panel, + ) + + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + for prefix in POOL_PREFIXES: + synthetic_daily_grid.to_parquet( + grid_dir / f"{prefix}_daily.parquet", index=False + ) + + matched = match_grids_to_panel(str(grid_dir), synthetic_panel) + X_attr, attr_names, pool_ids = build_pool_attributes(matched) + chain_cols = [n for n in attr_names if n.startswith("chain_")] + assert len(chain_cols) > 0 + + def test_attributes_has_log_fee( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.pool_data import ( + build_pool_attributes, + match_grids_to_panel, + ) + + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + for prefix in POOL_PREFIXES: + synthetic_daily_grid.to_parquet( + grid_dir / f"{prefix}_daily.parquet", index=False + ) + + matched = match_grids_to_panel(str(grid_dir), synthetic_panel) + X_attr, attr_names, pool_ids = build_pool_attributes(matched) + assert "log_fee" in attr_names + + def test_attributes_has_log_tvl( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.pool_data import ( + build_pool_attributes, + match_grids_to_panel, + ) + + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + for prefix in POOL_PREFIXES: + synthetic_daily_grid.to_parquet( + grid_dir / f"{prefix}_daily.parquet", index=False + ) + + matched = match_grids_to_panel(str(grid_dir), synthetic_panel) + X_attr, attr_names, pool_ids = build_pool_attributes(matched) + assert "mean_log_tvl" in attr_names + + def test_attributes_returns_pool_order( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.pool_data import ( + build_pool_attributes, + match_grids_to_panel, + ) + + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + for prefix in POOL_PREFIXES: + synthetic_daily_grid.to_parquet( + grid_dir / f"{prefix}_daily.parquet", index=False + ) + + matched = match_grids_to_panel(str(grid_dir), synthetic_panel) + X_attr, attr_names, pool_ids = build_pool_attributes(matched) + assert isinstance(pool_ids, list) + assert len(pool_ids) == len(matched) + assert set(pool_ids) == set(matched.keys()) diff --git a/tests/noise/__init__.py b/tests/noise/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/noise/conftest.py b/tests/noise/conftest.py new file mode 100644 index 0000000..5bfeadc --- /dev/null +++ b/tests/noise/conftest.py @@ -0,0 +1,274 @@ +"""Shared fixtures for noise calibration tests.""" + +from datetime import date, timedelta + +import numpy as np +import pandas as pd +import pytest + + +# --------------------------------------------------------------------------- +# synthetic_panel: 3 pools × 10 days = 30 obs (before lag drop → 27 obs) +# --------------------------------------------------------------------------- + +@pytest.fixture() +def synthetic_panel() -> pd.DataFrame: + """3 pools × 10 days with known structure. + + Pool A: MAINNET, WETH/USDC, tier_A=0, tier_B=0, fee=0.003 + Pool B: ARBITRUM, BAL/WETH, tier_A=0, tier_B=1, fee=0.01 + Pool C: BASE, RATS/WETH, tier_A=0, tier_B=2, fee=0.005 + + Dates: 2026-01-01 to 2026-01-10 + Weekend flags: Sat 2026-01-03 and Sun 2026-01-04 are weekends. + (2026-01-01 is Thursday, ..., 01-03 Sat, 01-04 Sun, 01-05 Mon, ...) + """ + np.random.seed(42) + + pools = [ + ("pool_A", "MAINNET", "WETH,USDC", 0.003, 0, 0), + ("pool_B", "ARBITRUM", "BAL,WETH", 0.01, 0, 1), + ("pool_C", "BASE", "RATS,WETH", 0.005, 0, 2), + ] + dates = [date(2026, 1, 1) + timedelta(days=i) for i in range(10)] + + records = [] + for pool_id, chain, tokens, fee, tier_a, tier_b in pools: + log_tvl_base = 14.0 + np.random.randn() * 0.5 + for d in dates: + log_tvl = log_tvl_base + np.random.randn() * 0.1 + log_vol = log_tvl - 2.0 + np.random.randn() * 0.3 + vol = 0.3 + np.random.rand() * 0.2 + is_weekend = 1.0 if d.weekday() >= 5 else 0.0 + + records.append({ + "pool_id": pool_id, + "chain": chain, + "date": d, + "log_volume": log_vol, + "log_tvl": log_tvl, + "volatility": vol, + "weekend": is_weekend, + "log_fee": np.log(max(fee, 1e-6)), + "swap_fee": fee, + "tier_A": tier_a, + "tier_B": tier_b, + "tokens": tokens, + }) + + panel = pd.DataFrame(records) + panel = panel.sort_values(["pool_id", "date"]).reset_index(drop=True) + panel["log_tvl_lag1"] = panel.groupby("pool_id")["log_tvl"].shift(1) + panel = panel.dropna(subset=["log_tvl_lag1"]).reset_index(drop=True) + + # Structural model covariates + panel["log_sigma"] = np.log(np.maximum(panel["volatility"].values, 1e-6)) + dow = panel["date"].apply( + lambda d: d.weekday() if hasattr(d, "weekday") else pd.Timestamp(d).weekday() + ) + panel["dow_sin"] = np.sin(2.0 * np.pi * dow / 7.0) + panel["dow_cos"] = np.cos(2.0 * np.pi * dow / 7.0) + panel["tvl_x_sigma"] = panel["log_tvl_lag1"] * panel["log_sigma"] + panel["tvl_x_fee"] = panel["log_tvl_lag1"] * panel["log_fee"] + panel["sigma_x_fee"] = panel["log_sigma"] * panel["log_fee"] + + return panel + + +# --------------------------------------------------------------------------- +# synthetic_encoded_data: output of encode_covariates(synthetic_panel) +# --------------------------------------------------------------------------- + +@pytest.fixture() +def synthetic_encoded_data(synthetic_panel): + from quantammsim.noise_calibration import encode_covariates + return encode_covariates(synthetic_panel) + + +# --------------------------------------------------------------------------- +# synthetic_samples: deterministic posterior-like dict +# --------------------------------------------------------------------------- + +@pytest.fixture() +def synthetic_samples(synthetic_encoded_data): + """Deterministic posterior samples with eta=0, L_Omega=I. + + With this structure: theta = X_pool @ B^T exactly. + """ + data = synthetic_encoded_data + N_pools = data["N_pools"] + K_cov = data["K_cov"] + K_coeff = 4 + S = 10 + + np.random.seed(99) + B = np.random.randn(S, K_coeff, K_cov) * 0.5 + sigma_theta = np.ones((S, K_coeff)) + L_Omega = np.tile(np.eye(K_coeff), (S, 1, 1)) + eta = np.zeros((S, N_pools, K_coeff)) + df = np.full((S,), 5.0) + sigma_eps = np.tile([0.5, 0.8, 0.6], (S, 1)) + + return { + "B": B, + "sigma_theta": sigma_theta, + "L_Omega": L_Omega, + "eta": eta, + "df": df, + "sigma_eps": sigma_eps, + } + + +# --------------------------------------------------------------------------- +# synthetic_pools_df: matches enumerate_balancer_pools output schema +# --------------------------------------------------------------------------- + +@pytest.fixture() +def synthetic_pools_df() -> pd.DataFrame: + return pd.DataFrame([ + { + "pool_id": "pool_A", + "chain": "MAINNET", + "pool_type": "WEIGHTED", + "tokens": ["WETH", "USDC"], + "token_addresses": ["0xweth", "0xusdc"], + "swap_fee": 0.003, + "current_tvl": 1_000_000, + }, + { + "pool_id": "pool_B", + "chain": "ARBITRUM", + "pool_type": "WEIGHTED", + "tokens": ["BAL", "WETH"], + "token_addresses": ["0xbal", "0xweth"], + "swap_fee": 0.01, + "current_tvl": 500_000, + }, + { + "pool_id": "pool_C", + "chain": "BASE", + "pool_type": "WEIGHTED", + "tokens": ["RATS", "WETH"], + "token_addresses": ["0xrats", "0xweth"], + "swap_fee": 0.005, + "current_tvl": 100_000, + }, + ]) + + +# --------------------------------------------------------------------------- +# synthetic_ibp_samples: IBP posterior-like dict (no eta, L_Omega, sigma_theta) +# --------------------------------------------------------------------------- + +@pytest.fixture() +def synthetic_ibp_samples(synthetic_encoded_data): + """Deterministic IBP posterior samples (marginalized model). + + Contains B, W, v_ibp, alpha_ibp — no z_logit, eta, L_Omega, sigma_theta. + Z is analytically marginalized; MAP assignments are computed from data. + """ + data = synthetic_encoded_data + K_cov = data["K_cov"] + K_coeff = 4 + K_features = 6 + S = 10 + + np.random.seed(99) + B = np.random.randn(S, K_coeff, K_cov) * 0.5 + W = np.random.randn(S, K_features, K_coeff) * 0.3 + v_ibp = np.random.beta(2, 1, size=(S, K_features)) + alpha_ibp = np.full((S,), 2.0) + sigma_w = np.full((S,), 1.0) + df = np.full((S,), 5.0) + sigma_eps = np.full((S,), 0.5) + + return { + "B": B, + "W": W, + "v_ibp": v_ibp, + "alpha_ibp": alpha_ibp, + "sigma_w": sigma_w, + "df": df, + "sigma_eps": sigma_eps, + } + + +# --------------------------------------------------------------------------- +# synthetic_ibp_dp_samples: hybrid IBP+DP posterior-like dict +# --------------------------------------------------------------------------- + +@pytest.fixture() +def synthetic_ibp_dp_samples(synthetic_encoded_data): + """Deterministic hybrid IBP+DP posterior samples. + + Contains both IBP keys (B, W, v_ibp, alpha_ibp, sigma_w) and + DP keys (v, alpha_dp, sigma_eps as vector). No z_logit, eta, L_Omega, + sigma_theta. + """ + data = synthetic_encoded_data + K_cov = data["K_cov"] + K_coeff = 4 + K_features = 6 + K_clusters = 6 + S = 10 + + np.random.seed(99) + B = np.random.randn(S, K_coeff, K_cov) * 0.5 + W = np.random.randn(S, K_features, K_coeff) * 0.3 + v_ibp = np.random.beta(2, 1, size=(S, K_features)) + alpha_ibp = np.full((S,), 2.0) + sigma_w = np.full((S,), 1.0) + v = np.random.beta(1, 2, size=(S, K_clusters - 1)) + alpha_dp = np.full((S,), 1.0) + df = np.full((S,), 5.0) + sigma_eps = np.abs(np.random.randn(S, K_clusters)) + 0.1 + + return { + "B": B, + "W": W, + "v_ibp": v_ibp, + "alpha_ibp": alpha_ibp, + "sigma_w": sigma_w, + "v": v, + "alpha_dp": alpha_dp, + "df": df, + "sigma_eps": sigma_eps, + } + + +# --------------------------------------------------------------------------- +# synthetic_snapshots_df: matches fetch_all_snapshots output schema +# --------------------------------------------------------------------------- + +# --------------------------------------------------------------------------- +# synthetic_structural_data: output of encode_covariates_structural() +# --------------------------------------------------------------------------- + +@pytest.fixture() +def synthetic_structural_data(synthetic_panel): + from quantammsim.noise_calibration.covariate_encoding import ( + encode_covariates_structural, + ) + return encode_covariates_structural(synthetic_panel) + + +# --------------------------------------------------------------------------- +# synthetic_snapshots_df: matches fetch_all_snapshots output schema +# --------------------------------------------------------------------------- + +@pytest.fixture() +def synthetic_snapshots_df() -> pd.DataFrame: + dates = [date(2026, 1, 1) + timedelta(days=i) for i in range(10)] + records = [] + np.random.seed(42) + for pool_id, chain in [("pool_A", "MAINNET"), ("pool_B", "ARBITRUM"), + ("pool_C", "BASE")]: + for d in dates: + records.append({ + "pool_id": pool_id, + "chain": chain, + "date": d, + "volume_usd": np.exp(10.0 + np.random.randn() * 0.5), + "total_liquidity_usd": np.exp(14.0 + np.random.randn() * 0.3), + }) + return pd.DataFrame(records) diff --git a/tests/noise/test_covariate_encoding.py b/tests/noise/test_covariate_encoding.py new file mode 100644 index 0000000..dececb5 --- /dev/null +++ b/tests/noise/test_covariate_encoding.py @@ -0,0 +1,326 @@ +"""Tests for encode_covariates and encode_covariates_structural.""" + +import numpy as np +import pytest + +from quantammsim.noise_calibration import encode_covariates +from quantammsim.noise_calibration.constants import K_OBS_COEFF, OBS_COEFF_NAMES + + +class TestEncodeCovariates: + def test_x_pool_shape(self, synthetic_panel): + data = encode_covariates(synthetic_panel) + N_pools = data["N_pools"] + K_cov = data["K_cov"] + assert data["X_pool"].shape == (N_pools, K_cov) + + def test_x_obs_shape(self, synthetic_panel): + data = encode_covariates(synthetic_panel) + N_obs = len(synthetic_panel) + assert data["x_obs"].shape == (N_obs, 4) + + def test_x_obs_column_0_is_intercept(self, synthetic_panel): + data = encode_covariates(synthetic_panel) + np.testing.assert_array_equal(data["x_obs"][:, 0], 1.0) + + def test_x_obs_column_1_is_lagged_tvl(self, synthetic_panel): + data = encode_covariates(synthetic_panel) + np.testing.assert_array_equal( + data["x_obs"][:, 1], + synthetic_panel["log_tvl_lag1"].values, + ) + + def test_x_obs_column_2_is_volatility(self, synthetic_panel): + data = encode_covariates(synthetic_panel) + np.testing.assert_array_equal( + data["x_obs"][:, 2], + synthetic_panel["volatility"].values, + ) + + def test_x_obs_column_3_is_weekend(self, synthetic_panel): + data = encode_covariates(synthetic_panel) + np.testing.assert_array_equal( + data["x_obs"][:, 3], + synthetic_panel["weekend"].values, + ) + + def test_intercept_column_all_ones(self, synthetic_panel): + data = encode_covariates(synthetic_panel) + np.testing.assert_array_equal(data["X_pool"][:, 0], 1.0) + + def test_chain_dummies_one_hot(self, synthetic_panel): + data = encode_covariates(synthetic_panel) + col_names = data["covariate_names"] + chain_cols = [i for i, n in enumerate(col_names) if n.startswith("chain_")] + X = data["X_pool"] + + for row in range(X.shape[0]): + chain_vals = X[row, chain_cols] + assert chain_vals.sum() <= 1.0 + + def test_reference_chain_is_alphabetically_first(self, synthetic_panel): + data = encode_covariates(synthetic_panel) + chains = sorted(synthetic_panel["chain"].unique()) + ref_chain = chains[0] # ARBITRUM + + pool_meta = data["pool_meta"] + ref_idx = pool_meta[pool_meta["chain"] == ref_chain].index[0] + col_names = data["covariate_names"] + chain_cols = [i for i, n in enumerate(col_names) if n.startswith("chain_")] + X = data["X_pool"] + assert all(X[ref_idx, c] == 0.0 for c in chain_cols) + + def test_tier_a_dummies_match(self, synthetic_panel): + data = encode_covariates(synthetic_panel) + col_names = data["covariate_names"] + tier_a_cols = [ + (i, n) for i, n in enumerate(col_names) if n.startswith("tier_A_") + ] + pool_meta = data["pool_meta"] + X = data["X_pool"] + + for idx, row in pool_meta.iterrows(): + tier_a_str = str(row["tier_A"]) + for col_idx, col_name in tier_a_cols: + expected_tier = col_name.split("_")[-1] + expected = 1.0 if tier_a_str == expected_tier else 0.0 + assert X[idx, col_idx] == expected, ( + f"Pool {idx} tier_A={tier_a_str}, col {col_name}: " + f"expected {expected}, got {X[idx, col_idx]}" + ) + + def test_tier_b_dummies_match(self, synthetic_panel): + data = encode_covariates(synthetic_panel) + col_names = data["covariate_names"] + tier_b_cols = [ + (i, n) for i, n in enumerate(col_names) if n.startswith("tier_B_") + ] + pool_meta = data["pool_meta"] + X = data["X_pool"] + + for idx, row in pool_meta.iterrows(): + tier_b_str = str(row["tier_B"]) + for col_idx, col_name in tier_b_cols: + expected_tier = col_name.split("_")[-1] + expected = 1.0 if tier_b_str == expected_tier else 0.0 + assert X[idx, col_idx] == expected + + def test_log_fee_column(self, synthetic_panel): + data = encode_covariates(synthetic_panel) + col_names = data["covariate_names"] + fee_idx = col_names.index("log_fee") + pool_meta = data["pool_meta"] + + for idx, row in pool_meta.iterrows(): + expected = np.log(max(row["swap_fee"], 1e-6)) + np.testing.assert_allclose( + data["X_pool"][idx, fee_idx], expected, rtol=1e-10, + ) + + def test_pool_idx_maps_observations(self, synthetic_panel): + data = encode_covariates(synthetic_panel) + pool_ids = data["pool_ids"] + pool_idx = data["pool_idx"] + + assert pool_idx.min() >= 0 + assert pool_idx.max() < len(pool_ids) + + for i, pid in enumerate(pool_ids): + mask = synthetic_panel["pool_id"] == pid + obs_indices = np.where(mask.values)[0] + assert (pool_idx[obs_indices] == i).all() + + def test_y_obs_matches_panel(self, synthetic_panel): + data = encode_covariates(synthetic_panel) + np.testing.assert_array_equal( + data["y_obs"], + synthetic_panel["log_volume"].values, + ) + + def test_covariate_names_length(self, synthetic_panel): + data = encode_covariates(synthetic_panel) + assert len(data["covariate_names"]) == data["K_cov"] + + def test_covariate_column_ordering(self, synthetic_panel): + """X_pool columns must follow: intercept, chain dummies, tier_A dummies, + tier_B dummies, log_fee. This ordering is load-bearing because B is + indexed by column position.""" + data = encode_covariates(synthetic_panel) + names = data["covariate_names"] + + assert names[0] == "intercept" + assert names[-1] == "log_fee" + + # Find boundaries + chain_start = None + tier_a_start = None + tier_b_start = None + fee_idx = len(names) - 1 + + for i, n in enumerate(names): + if n.startswith("chain_") and chain_start is None: + chain_start = i + if n.startswith("tier_A_") and tier_a_start is None: + tier_a_start = i + if n.startswith("tier_B_") and tier_b_start is None: + tier_b_start = i + + # Verify ordering: intercept < chains < tier_A < tier_B < log_fee + if chain_start is not None: + assert chain_start > 0 # after intercept + if tier_a_start is not None and chain_start is not None: + assert tier_a_start > chain_start + if tier_b_start is not None and tier_a_start is not None: + assert tier_b_start > tier_a_start + if tier_b_start is not None: + assert fee_idx > tier_b_start + + # All chain dummies are contiguous + chain_names = [n for n in names if n.startswith("chain_")] + if chain_names: + chain_indices = [names.index(n) for n in chain_names] + assert chain_indices == list(range(min(chain_indices), + max(chain_indices) + 1)) + + def test_output_dict_has_all_required_keys(self, synthetic_panel): + """encode_covariates must return all keys consumed by downstream + functions (predict_new_pool, generate_output_json, _save_sample_cache).""" + data = encode_covariates(synthetic_panel) + required_keys = { + "pool_idx", "X_pool", "x_obs", "y_obs", "pool_ids", "pool_meta", + "covariate_names", "tier_A_per_pool", "N_pools", "K_cov", + "ref_chain", "ref_tier_a", "ref_tier_b", "chains", + } + assert required_keys.issubset(data.keys()), ( + f"Missing keys: {required_keys - data.keys()}" + ) + + +class TestEncodeCovariatesNoTiers: + """Tests for encode_covariates(include_tiers=False).""" + + def test_no_tier_columns_in_covariate_names(self, synthetic_panel): + data = encode_covariates(synthetic_panel, include_tiers=False) + for name in data["covariate_names"]: + assert not name.startswith("tier_A_"), ( + f"Found tier_A column {name} with include_tiers=False" + ) + assert not name.startswith("tier_B_"), ( + f"Found tier_B column {name} with include_tiers=False" + ) + + def test_still_has_intercept_and_log_fee(self, synthetic_panel): + data = encode_covariates(synthetic_panel, include_tiers=False) + assert "intercept" in data["covariate_names"] + assert "log_fee" in data["covariate_names"] + + def test_still_has_chain_dummies(self, synthetic_panel): + data = encode_covariates(synthetic_panel, include_tiers=False) + chain_cols = [n for n in data["covariate_names"] + if n.startswith("chain_")] + assert len(chain_cols) > 0 + + def test_k_cov_smaller_than_with_tiers(self, synthetic_panel): + data_tiers = encode_covariates(synthetic_panel, include_tiers=True) + data_no_tiers = encode_covariates(synthetic_panel, include_tiers=False) + assert data_no_tiers["K_cov"] < data_tiers["K_cov"] + + def test_default_is_include_tiers_true(self, synthetic_panel): + data_default = encode_covariates(synthetic_panel) + data_explicit = encode_covariates(synthetic_panel, include_tiers=True) + assert data_default["K_cov"] == data_explicit["K_cov"] + + def test_tier_A_per_pool_still_present(self, synthetic_panel): + """tier_A_per_pool is still returned (for other downstream uses).""" + data = encode_covariates(synthetic_panel, include_tiers=False) + assert "tier_A_per_pool" in data + + def test_x_pool_shape_matches_k_cov(self, synthetic_panel): + data = encode_covariates(synthetic_panel, include_tiers=False) + assert data["X_pool"].shape == (data["N_pools"], data["K_cov"]) + + +class TestEncodeStructuralCovariates: + """Tests for encode_covariates_structural().""" + + @pytest.fixture() + def struct_data(self, synthetic_panel): + from quantammsim.noise_calibration.covariate_encoding import ( + encode_covariates_structural, + ) + return encode_covariates_structural(synthetic_panel) + + def test_encode_structural_x_obs_shape(self, struct_data, synthetic_panel): + N_obs = len(synthetic_panel) + assert struct_data["x_obs"].shape == (N_obs, K_OBS_COEFF) + + def test_encode_structural_x_obs_columns(self, struct_data, synthetic_panel): + """Columns must match OBS_COEFF_NAMES ordering: + [1, lag_log_tvl, log_sigma, tvl_x_sigma, tvl_x_fee, sigma_x_fee, + dow_sin, dow_cos].""" + x = struct_data["x_obs"] + np.testing.assert_array_equal(x[:, 0], 1.0) # intercept + np.testing.assert_array_equal( + x[:, 1], synthetic_panel["log_tvl_lag1"].values, + ) + np.testing.assert_array_equal( + x[:, 2], synthetic_panel["log_sigma"].values, + ) + np.testing.assert_allclose( + x[:, 3], synthetic_panel["tvl_x_sigma"].values, rtol=1e-12, + ) + np.testing.assert_allclose( + x[:, 4], synthetic_panel["tvl_x_fee"].values, rtol=1e-12, + ) + np.testing.assert_allclose( + x[:, 5], synthetic_panel["sigma_x_fee"].values, rtol=1e-12, + ) + np.testing.assert_allclose( + x[:, 6], synthetic_panel["dow_sin"].values, rtol=1e-12, + ) + np.testing.assert_allclose( + x[:, 7], synthetic_panel["dow_cos"].values, rtol=1e-12, + ) + + def test_encode_structural_has_sigma_daily(self, struct_data): + """sigma_daily = volatility / sqrt(365), de-annualised.""" + assert "sigma_daily" in struct_data + assert len(struct_data["sigma_daily"]) > 0 + + def test_encode_structural_has_gas(self, struct_data): + assert "gas" in struct_data + assert len(struct_data["gas"]) > 0 + assert (struct_data["gas"] >= 0).all() + + def test_encode_structural_has_chain_idx_tier_idx(self, struct_data): + assert "chain_idx" in struct_data + assert "tier_idx" in struct_data + assert struct_data["chain_idx"].dtype in (np.int32, np.int64) + assert struct_data["tier_idx"].dtype in (np.int32, np.int64) + + def test_encode_structural_has_fee(self, struct_data): + """fee array is raw (not log), for the formula.""" + assert "fee" in struct_data + assert (struct_data["fee"] > 0).all() + assert (struct_data["fee"] < 1).all() # fees are fractions + + def test_encode_structural_tier_idx_is_pair(self, struct_data): + """tier_idx encodes the (tier_A, tier_B) PAIR, not individual tokens. + (0,0)->0, (0,1)->1, (0,2)->2, (1,1)->3, (1,2)->4, (2,2)->5.""" + tier_idx = struct_data["tier_idx"] + pool_meta = struct_data["pool_meta"] + + for i, row in pool_meta.iterrows(): + a, b = int(row["tier_A"]), int(row["tier_B"]) + expected = a * (5 - a) // 2 + b - a + assert tier_idx[i] == expected, ( + f"Pool {i}: tier ({a},{b}) expected idx {expected}, " + f"got {tier_idx[i]}" + ) + + def test_encode_structural_n_chains_n_tiers(self, struct_data): + """n_chains and n_tiers computed from data.""" + assert "n_chains" in struct_data + assert "n_tiers" in struct_data + assert struct_data["n_chains"] >= 1 + assert struct_data["n_tiers"] >= 1 diff --git a/tests/noise/test_formula_arb.py b/tests/noise/test_formula_arb.py new file mode 100644 index 0000000..f862d51 --- /dev/null +++ b/tests/noise/test_formula_arb.py @@ -0,0 +1,142 @@ +"""Tests for JAX-differentiable LVR formula.""" + +import numpy as np +import pytest + + +class TestFormulaArbJax: + @pytest.fixture(autouse=True) + def _import(self): + from quantammsim.noise_calibration.formula_arb import ( + formula_arb_volume_daily_jax, + ) + self.formula = formula_arb_volume_daily_jax + + def test_formula_arb_zero_vol_returns_zero(self): + import jax.numpy as jnp + result = self.formula( + sigma_daily=jnp.float64(0.0), + tvl=jnp.float64(1e6), + fee=jnp.float64(0.003), + gas_usd=jnp.float64(1.0), + cadence_minutes=jnp.float64(1.0), + ) + assert float(result) == pytest.approx(0.0, abs=1e-10) + + def test_formula_arb_zero_tvl_returns_zero(self): + import jax.numpy as jnp + result = self.formula( + sigma_daily=jnp.float64(0.03), + tvl=jnp.float64(0.0), + fee=jnp.float64(0.003), + gas_usd=jnp.float64(1.0), + cadence_minutes=jnp.float64(1.0), + ) + assert float(result) == pytest.approx(0.0, abs=1e-10) + + def test_formula_arb_quadratic_in_sigma(self): + """V_arb(2σ) / V_arb(σ) ≈ 4 for small gas (correction ≈ 1).""" + import jax.numpy as jnp + sigma = jnp.float64(0.01) + tvl = jnp.float64(1e8) # large TVL so gas is negligible + fee = jnp.float64(0.003) + gas = jnp.float64(0.001) # tiny gas + cadence = jnp.float64(0.01) # very fast arb + + v1 = float(self.formula(sigma, tvl, fee, gas, cadence)) + v2 = float(self.formula(2.0 * sigma, tvl, fee, gas, cadence)) + assert v1 > 0 + ratio = v2 / v1 + assert ratio == pytest.approx(4.0, rel=0.1) + + def test_formula_arb_linear_in_tvl(self): + """V_arb(2V) / V_arb(V) ≈ 2 for small gas.""" + import jax.numpy as jnp + sigma = jnp.float64(0.02) + tvl = jnp.float64(1e8) + fee = jnp.float64(0.003) + gas = jnp.float64(0.0001) + cadence = jnp.float64(0.01) + + v1 = float(self.formula(sigma, tvl, fee, gas, cadence)) + v2 = float(self.formula(sigma, 2.0 * tvl, fee, gas, cadence)) + assert v1 > 0 + ratio = v2 / v1 + assert ratio == pytest.approx(2.0, rel=0.1) + + def test_formula_arb_gas_kills_small_pools(self): + """High gas, small TVL → V_arb ≈ 0.""" + import jax.numpy as jnp + result = self.formula( + sigma_daily=jnp.float64(0.02), + tvl=jnp.float64(1000.0), + fee=jnp.float64(0.003), + gas_usd=jnp.float64(1000.0), + cadence_minutes=jnp.float64(1.0), + ) + assert float(result) == pytest.approx(0.0, abs=1e-6) + + def test_formula_arb_matches_numpy_reference(self): + """Compare to the numpy formula in plot_formula_arb_vs_real.py.""" + import jax.numpy as jnp + + # Reference implementation (from plot_formula_arb_vs_real.py:58) + def ref(sigma_daily, tvl, fee, block_time_s, gas_usd): + if tvl <= 0 or fee <= 0 or sigma_daily <= 0: + return 0.0 + gamma = fee + delta = 2.0 * np.sqrt(2.0 * gas_usd / tvl) if gas_usd > 0 else 0.0 + bLVR = sigma_daily**2 * tvl / 8.0 + sqrt_s2_2l = sigma_daily * np.sqrt(block_time_s / (2.0 * 86400.0)) + bFEE = bLVR * max( + 1.0 - delta / (2.0 * gamma) - sqrt_s2_2l / (gamma + delta / 2.0), + 0.0, + ) + return bFEE / gamma + + test_cases = [ + (0.03, 1e6, 0.003, 1.0, 1.0), + (0.05, 5e5, 0.01, 0.5, 0.005), + (0.01, 1e7, 0.005, 2.0, 0.01), + (0.1, 1e4, 0.03, 10.0, 5.0), + (0.02, 1e5, 0.001, 1.0, 0.001), + ] + + for sigma, tvl, fee, cadence_min, gas in test_cases: + block_time_s = cadence_min * 60.0 + expected = ref(sigma, tvl, fee, block_time_s, gas) + actual = float(self.formula( + jnp.float64(sigma), jnp.float64(tvl), + jnp.float64(fee), jnp.float64(gas), + jnp.float64(cadence_min), + )) + np.testing.assert_allclose( + actual, expected, rtol=1e-10, + err_msg=f"Mismatch for sigma={sigma}, tvl={tvl}, " + f"fee={fee}, cadence={cadence_min}, gas={gas}", + ) + + def test_formula_arb_is_jax_differentiable(self): + """jax.grad w.r.t. sigma should run without error.""" + import jax + import jax.numpy as jnp + + grad_fn = jax.grad(self.formula, argnums=0) + result = grad_fn( + jnp.float64(0.03), jnp.float64(1e6), + jnp.float64(0.003), jnp.float64(1.0), + jnp.float64(1.0), + ) + assert np.isfinite(float(result)) + + def test_formula_arb_cadence_reduces_volume(self): + """Higher cadence → less frequent arb → lower volume.""" + import jax.numpy as jnp + sigma = jnp.float64(0.03) + tvl = jnp.float64(1e6) + fee = jnp.float64(0.003) + gas = jnp.float64(1.0) + + v_fast = float(self.formula(sigma, tvl, fee, gas, jnp.float64(1.0))) + v_slow = float(self.formula(sigma, tvl, fee, gas, jnp.float64(10.0))) + assert v_fast > v_slow diff --git a/tests/noise/test_model_and_inference.py b/tests/noise/test_model_and_inference.py new file mode 100644 index 0000000..ee1fda5 --- /dev/null +++ b/tests/noise/test_model_and_inference.py @@ -0,0 +1,328 @@ +"""Tests for noise_model, _get_theta_samples, _build_model_kwargs, and SVI smoke.""" + +import numpy as np +import pytest + +from quantammsim.noise_calibration import ( + noise_model, + _get_theta_samples, + _build_model_kwargs, + K_COEFF, +) + + +# =========================================================================== +# TestNoiseModelDefinition +# =========================================================================== + + +class TestNoiseModelDefinition: + def test_model_traces_without_error(self, synthetic_encoded_data): + import jax + import jax.numpy as jnp + import numpyro + import numpyro.handlers as handlers + + data = synthetic_encoded_data + kwargs = _build_model_kwargs(data) + rng_key = jax.random.PRNGKey(0) + + trace = handlers.trace(handlers.seed(noise_model, rng_key)).get_trace( + **kwargs + ) + assert trace is not None + + def test_required_sites_present(self, synthetic_encoded_data): + import jax + import numpyro.handlers as handlers + + data = synthetic_encoded_data + kwargs = _build_model_kwargs(data) + rng_key = jax.random.PRNGKey(0) + + trace = handlers.trace(handlers.seed(noise_model, rng_key)).get_trace( + **kwargs + ) + required = {"B", "sigma_theta", "L_Omega", "df", "sigma_eps", "eta", "y"} + assert required.issubset(trace.keys()) + + def test_site_shapes(self, synthetic_encoded_data): + import jax + import numpyro.handlers as handlers + + data = synthetic_encoded_data + kwargs = _build_model_kwargs(data) + rng_key = jax.random.PRNGKey(0) + + trace = handlers.trace(handlers.seed(noise_model, rng_key)).get_trace( + **kwargs + ) + N_pools = data["N_pools"] + K_cov = data["K_cov"] + + assert trace["B"]["value"].shape == (K_COEFF, K_cov) + assert trace["sigma_theta"]["value"].shape == (K_COEFF,) + assert trace["L_Omega"]["value"].shape == (K_COEFF, K_COEFF) + assert trace["df"]["value"].shape == () + assert trace["sigma_eps"]["value"].shape == (3,) + assert trace["eta"]["value"].shape == (N_pools, K_COEFF) + + def test_theta_deterministic_site(self, synthetic_encoded_data): + import jax + import numpyro.handlers as handlers + + data = synthetic_encoded_data + kwargs = _build_model_kwargs(data) + rng_key = jax.random.PRNGKey(0) + + trace = handlers.trace(handlers.seed(noise_model, rng_key)).get_trace( + **kwargs + ) + assert "theta" in trace + assert trace["theta"]["value"].shape == (data["N_pools"], K_COEFF) + + def test_prior_predictive_produces_y(self, synthetic_encoded_data): + import jax + from numpyro.infer import Predictive + + data = synthetic_encoded_data + kwargs = _build_model_kwargs(data) + kwargs["y_obs"] = None + + predictive = Predictive(noise_model, num_samples=5) + rng_key = jax.random.PRNGKey(42) + samples = predictive(rng_key, **kwargs) + assert "y" in samples + assert samples["y"].shape[0] == 5 + + +# =========================================================================== +# TestGetThetaSamples +# =========================================================================== + + +class TestGetThetaSamples: + def test_returns_theta_directly_when_present(self, synthetic_encoded_data): + data = synthetic_encoded_data + theta_direct = np.random.randn(10, data["N_pools"], K_COEFF) + sample_dict = {"theta": theta_direct} + result = _get_theta_samples(sample_dict, data["X_pool"]) + np.testing.assert_array_equal(result, theta_direct) + + def test_reconstructs_from_non_centered( + self, synthetic_encoded_data, synthetic_samples + ): + data = synthetic_encoded_data + result = _get_theta_samples(synthetic_samples, data["X_pool"]) + assert result.shape == (10, data["N_pools"], K_COEFF) + + def test_eta_zero_identity_gives_mu( + self, synthetic_encoded_data, synthetic_samples + ): + """With eta=0 and L_Omega=I, theta = X_pool @ B^T.""" + data = synthetic_encoded_data + X_pool = data["X_pool"] + B = synthetic_samples["B"] + + result = _get_theta_samples(synthetic_samples, X_pool) + + # Expected: mu[s,p,j] = sum_d X_pool[p,d] * B[s,j,d] + expected = np.einsum("pd,sjd->spj", X_pool, B) + np.testing.assert_allclose(result, expected, atol=1e-12) + + def test_output_shape(self, synthetic_encoded_data, synthetic_samples): + data = synthetic_encoded_data + result = _get_theta_samples(synthetic_samples, data["X_pool"]) + S = synthetic_samples["B"].shape[0] + assert result.shape == (S, data["N_pools"], K_COEFF) + + def test_reconstructed_matches_direct(self, synthetic_encoded_data): + """When both theta and raw params are present, reconstruction matches.""" + data = synthetic_encoded_data + N_pools = data["N_pools"] + K_cov = data["K_cov"] + S = 8 + X_pool = data["X_pool"] + + np.random.seed(123) + B = np.random.randn(S, K_COEFF, K_cov) * 0.3 + sigma_theta = np.abs(np.random.randn(S, K_COEFF)) + 0.1 + # Random lower-triangular L + L_raw = np.zeros((S, K_COEFF, K_COEFF)) + for s in range(S): + A = np.random.randn(K_COEFF, K_COEFF) + L_raw[s] = np.linalg.cholesky(A @ A.T + np.eye(K_COEFF)) + eta = np.random.randn(S, N_pools, K_COEFF) + + sample_dict = { + "B": B, "sigma_theta": sigma_theta, + "L_Omega": L_raw, "eta": eta, + } + + theta_recon = _get_theta_samples(sample_dict, X_pool) + + # Compute expected directly + mu = np.einsum("pd,sjd->spj", X_pool, B) + L_Sigma = sigma_theta[:, :, None] * L_raw + offset = np.einsum("spi,sji->spj", eta, L_Sigma) + expected = mu + offset + + np.testing.assert_allclose(theta_recon, expected, atol=1e-10) + + +# =========================================================================== +# TestBuildModelKwargs +# =========================================================================== + + +class TestBuildModelKwargs: + def test_all_outputs_are_jnp(self, synthetic_encoded_data): + import jax.numpy as jnp + + kwargs = _build_model_kwargs(synthetic_encoded_data) + for key in ["pool_idx", "X_pool", "x_obs", "y_obs", "tier_A_per_pool"]: + assert isinstance(kwargs[key], jnp.ndarray), ( + f"{key} should be jnp array" + ) + + def test_shapes_preserved(self, synthetic_encoded_data): + data = synthetic_encoded_data + kwargs = _build_model_kwargs(data) + assert kwargs["pool_idx"].shape == data["pool_idx"].shape + assert kwargs["X_pool"].shape == data["X_pool"].shape + assert kwargs["x_obs"].shape == data["x_obs"].shape + assert kwargs["y_obs"].shape == data["y_obs"].shape + assert kwargs["N_pools"] == data["N_pools"] + assert kwargs["K_cov"] == data["K_cov"] + + def test_dp_model_kwargs_exclude_tier_A(self, synthetic_encoded_data): + """When model_fn is noise_model_dp_sigma, tier_A_per_pool is excluded + and K_clusters is included.""" + from quantammsim.noise_calibration.model import noise_model_dp_sigma + + data = dict(synthetic_encoded_data) + data["K_clusters"] = 6 + kwargs = _build_model_kwargs(data, model_fn=noise_model_dp_sigma) + assert "tier_A_per_pool" not in kwargs + assert kwargs["K_clusters"] == 6 + + def test_tier_model_kwargs_include_tier_A(self, synthetic_encoded_data): + """When model_fn is noise_model (default), tier_A_per_pool is included + and K_clusters is not.""" + kwargs = _build_model_kwargs(synthetic_encoded_data) + assert "tier_A_per_pool" in kwargs + assert "K_clusters" not in kwargs + + def test_default_model_fn_is_noise_model(self, synthetic_encoded_data): + """Calling without model_fn should behave identically to model_fn=noise_model.""" + kwargs_default = _build_model_kwargs(synthetic_encoded_data) + kwargs_explicit = _build_model_kwargs( + synthetic_encoded_data, model_fn=noise_model + ) + assert set(kwargs_default.keys()) == set(kwargs_explicit.keys()) + + +# =========================================================================== +# TestSVISmoke +# =========================================================================== + + +class TestSVISmoke: + @pytest.mark.slow + def test_svi_converges_small_data(self): + """SVI on tiny synthetic data: ELBO should decrease.""" + import jax + import numpyro + + numpyro.enable_x64() + + from quantammsim.noise_calibration import encode_covariates, run_svi + + # Build minimal panel: 5 pools × 20 days + np.random.seed(42) + from datetime import date, timedelta + + pools_spec = [ + ("p0", "MAINNET", "WETH,USDC", 0.003, 0, 0), + ("p1", "ARBITRUM", "BAL,WETH", 0.01, 0, 1), + ("p2", "BASE", "RATS,WETH", 0.005, 0, 2), + ("p3", "MAINNET", "LINK,WETH", 0.005, 0, 1), + ("p4", "ARBITRUM", "AAVE,USDC", 0.003, 0, 1), + ] + dates = [date(2026, 1, 1) + timedelta(days=i) for i in range(21)] + records = [] + for pid, chain, tokens, fee, ta, tb in pools_spec: + base_tvl = 14 + np.random.randn() * 0.5 + for d in dates: + tvl = base_tvl + np.random.randn() * 0.1 + vol = tvl - 2 + np.random.randn() * 0.3 + records.append({ + "pool_id": pid, "chain": chain, "date": d, + "log_volume": vol, "log_tvl": tvl, + "volatility": 0.3 + np.random.rand() * 0.2, + "weekend": 1.0 if d.weekday() >= 5 else 0.0, + "log_fee": np.log(max(fee, 1e-6)), + "swap_fee": fee, "tier_A": ta, "tier_B": tb, + "tokens": tokens, + }) + + import pandas as pd + + panel = pd.DataFrame(records) + panel = panel.sort_values(["pool_id", "date"]).reset_index(drop=True) + panel["log_tvl_lag1"] = panel.groupby("pool_id")["log_tvl"].shift(1) + panel = panel.dropna(subset=["log_tvl_lag1"]).reset_index(drop=True) + + data = encode_covariates(panel) + samples, losses = run_svi(data, num_steps=2000, lr=1e-3, seed=0, + num_samples=50) + + # ELBO should decrease: mean of last 100 < mean of first 100 + assert np.mean(losses[-100:]) < np.mean(losses[:100]) + + @pytest.mark.slow + def test_svi_samples_have_required_keys(self): + """SVI output dict has all expected latent variable keys.""" + import jax + import numpyro + + numpyro.enable_x64() + + from quantammsim.noise_calibration import encode_covariates, run_svi + + np.random.seed(42) + from datetime import date, timedelta + import pandas as pd + + pools_spec = [ + ("p0", "MAINNET", "WETH,USDC", 0.003, 0, 0), + ("p1", "ARBITRUM", "BAL,WETH", 0.01, 0, 1), + ] + dates = [date(2026, 1, 1) + timedelta(days=i) for i in range(15)] + records = [] + for pid, chain, tokens, fee, ta, tb in pools_spec: + base_tvl = 14 + np.random.randn() * 0.5 + for d in dates: + tvl = base_tvl + np.random.randn() * 0.1 + vol = tvl - 2 + np.random.randn() * 0.3 + records.append({ + "pool_id": pid, "chain": chain, "date": d, + "log_volume": vol, "log_tvl": tvl, + "volatility": 0.3 + np.random.rand() * 0.2, + "weekend": 1.0 if d.weekday() >= 5 else 0.0, + "log_fee": np.log(max(fee, 1e-6)), + "swap_fee": fee, "tier_A": ta, "tier_B": tb, + "tokens": tokens, + }) + + panel = pd.DataFrame(records) + panel = panel.sort_values(["pool_id", "date"]).reset_index(drop=True) + panel["log_tvl_lag1"] = panel.groupby("pool_id")["log_tvl"].shift(1) + panel = panel.dropna(subset=["log_tvl_lag1"]).reset_index(drop=True) + + data = encode_covariates(panel) + samples, _ = run_svi(data, num_steps=500, lr=1e-3, seed=0, + num_samples=10) + + required = {"B", "sigma_theta", "L_Omega", "eta", "df", "sigma_eps"} + assert required.issubset(samples.keys()) diff --git a/tests/noise/test_model_dp_sigma.py b/tests/noise/test_model_dp_sigma.py new file mode 100644 index 0000000..3d2be21 --- /dev/null +++ b/tests/noise/test_model_dp_sigma.py @@ -0,0 +1,305 @@ +"""Tests for stick_breaking_weights and noise_model_dp_sigma.""" + +import numpy as np +import pytest + +from quantammsim.noise_calibration import K_COEFF +from quantammsim.noise_calibration.constants import K_CLUSTERS_DEFAULT + + +# =========================================================================== +# TestStickBreakingWeights +# =========================================================================== + + +class TestStickBreakingWeights: + def test_sums_to_one(self): + from quantammsim.noise_calibration.model import stick_breaking_weights + import jax.numpy as jnp + + v = jnp.array([0.5, 0.3, 0.4, 0.6, 0.2]) + w = stick_breaking_weights(v) + np.testing.assert_allclose(float(jnp.sum(w)), 1.0, atol=1e-6) + + def test_correct_length(self): + from quantammsim.noise_calibration.model import stick_breaking_weights + import jax.numpy as jnp + + K = 7 + v = jnp.ones(K - 1) * 0.3 + w = stick_breaking_weights(v) + assert w.shape == (K,) + + def test_non_negative(self): + from quantammsim.noise_calibration.model import stick_breaking_weights + import jax.numpy as jnp + + v = jnp.array([0.1, 0.9, 0.5, 0.7, 0.3]) + w = stick_breaking_weights(v) + assert jnp.all(w >= 0.0) + + def test_first_weight_equals_first_v(self): + from quantammsim.noise_calibration.model import stick_breaking_weights + import jax.numpy as jnp + + v = jnp.array([0.7, 0.4, 0.2]) + w = stick_breaking_weights(v) + np.testing.assert_allclose(float(w[0]), 0.7, atol=1e-6) + + def test_jit_compatible(self): + from quantammsim.noise_calibration.model import stick_breaking_weights + import jax + import jax.numpy as jnp + + v = jnp.array([0.5, 0.3, 0.4]) + w_eager = stick_breaking_weights(v) + w_jit = jax.jit(stick_breaking_weights)(v) + np.testing.assert_allclose( + np.array(w_eager), np.array(w_jit), atol=1e-6 + ) + + def test_all_v_one_concentrates_on_first(self): + """If v = [1, 1, ...], all mass goes to first component.""" + from quantammsim.noise_calibration.model import stick_breaking_weights + import jax.numpy as jnp + + v = jnp.ones(5) + w = stick_breaking_weights(v) + np.testing.assert_allclose(float(w[0]), 1.0, atol=1e-6) + np.testing.assert_allclose(float(jnp.sum(w[1:])), 0.0, atol=1e-6) + + +# =========================================================================== +# TestDPModelDefinition +# =========================================================================== + + +class TestDPModelDefinition: + def _get_dp_model_kwargs(self, data, K_clusters=6): + """Build kwargs for noise_model_dp_sigma from encoded data.""" + import jax.numpy as jnp + + return dict( + pool_idx=jnp.array(data["pool_idx"]), + X_pool=jnp.array(data["X_pool"]), + x_obs=jnp.array(data["x_obs"]), + y_obs=jnp.array(data["y_obs"]), + N_pools=data["N_pools"], + K_coeff=K_COEFF, + K_cov=data["K_cov"], + K_clusters=K_clusters, + ) + + def test_model_traces_without_error(self, synthetic_encoded_data): + from quantammsim.noise_calibration.model import noise_model_dp_sigma + import jax + import numpyro.handlers as handlers + + kwargs = self._get_dp_model_kwargs(synthetic_encoded_data) + rng_key = jax.random.PRNGKey(0) + + trace = handlers.trace( + handlers.seed(noise_model_dp_sigma, rng_key) + ).get_trace(**kwargs) + assert trace is not None + + def test_has_dp_sites(self, synthetic_encoded_data): + from quantammsim.noise_calibration.model import noise_model_dp_sigma + import jax + import numpyro.handlers as handlers + + kwargs = self._get_dp_model_kwargs(synthetic_encoded_data) + rng_key = jax.random.PRNGKey(0) + + trace = handlers.trace( + handlers.seed(noise_model_dp_sigma, rng_key) + ).get_trace(**kwargs) + + dp_sites = {"alpha_dp", "v", "sigma_eps", "log_lik"} + assert dp_sites.issubset(trace.keys()), ( + f"Missing DP sites: {dp_sites - trace.keys()}" + ) + + def test_no_y_site_when_obs_provided(self, synthetic_encoded_data): + """With y_obs provided, the marginalized model uses factor, not obs.""" + from quantammsim.noise_calibration.model import noise_model_dp_sigma + import jax + import numpyro.handlers as handlers + + kwargs = self._get_dp_model_kwargs(synthetic_encoded_data) + rng_key = jax.random.PRNGKey(0) + + trace = handlers.trace( + handlers.seed(noise_model_dp_sigma, rng_key) + ).get_trace(**kwargs) + + assert "y" not in trace, ( + "DP model should use numpyro.factor, not obs=y_obs" + ) + + def test_shared_mean_structure_sites(self, synthetic_encoded_data): + """The DP model must share the same mean structure as the tier model.""" + from quantammsim.noise_calibration.model import noise_model_dp_sigma + import jax + import numpyro.handlers as handlers + + kwargs = self._get_dp_model_kwargs(synthetic_encoded_data) + rng_key = jax.random.PRNGKey(0) + + trace = handlers.trace( + handlers.seed(noise_model_dp_sigma, rng_key) + ).get_trace(**kwargs) + + shared_sites = {"B", "sigma_theta", "L_Omega", "df", "eta", "theta"} + assert shared_sites.issubset(trace.keys()), ( + f"Missing shared sites: {shared_sites - trace.keys()}" + ) + + def test_correct_shapes(self, synthetic_encoded_data): + from quantammsim.noise_calibration.model import noise_model_dp_sigma + import jax + import numpyro.handlers as handlers + + K_clusters = 6 + kwargs = self._get_dp_model_kwargs( + synthetic_encoded_data, K_clusters=K_clusters + ) + rng_key = jax.random.PRNGKey(0) + + trace = handlers.trace( + handlers.seed(noise_model_dp_sigma, rng_key) + ).get_trace(**kwargs) + + N_pools = synthetic_encoded_data["N_pools"] + K_cov = synthetic_encoded_data["K_cov"] + + assert trace["B"]["value"].shape == (K_COEFF, K_cov) + assert trace["sigma_theta"]["value"].shape == (K_COEFF,) + assert trace["L_Omega"]["value"].shape == (K_COEFF, K_COEFF) + assert trace["df"]["value"].shape == () + assert trace["sigma_eps"]["value"].shape == (K_clusters,) + assert trace["eta"]["value"].shape == (N_pools, K_COEFF) + assert trace["v"]["value"].shape == (K_clusters - 1,) + assert trace["alpha_dp"]["value"].shape == () + + def test_no_tier_A_per_pool_in_signature(self): + """The DP model should not accept tier_A_per_pool.""" + from quantammsim.noise_calibration.model import noise_model_dp_sigma + import inspect + + sig = inspect.signature(noise_model_dp_sigma) + assert "tier_A_per_pool" not in sig.parameters + + def test_prior_predictive_produces_y(self, synthetic_encoded_data): + """With y_obs=None, the model should sample y explicitly.""" + from quantammsim.noise_calibration.model import noise_model_dp_sigma + import jax + from numpyro.infer import Predictive + + kwargs = self._get_dp_model_kwargs(synthetic_encoded_data) + kwargs["y_obs"] = None + + predictive = Predictive(noise_model_dp_sigma, num_samples=5) + rng_key = jax.random.PRNGKey(42) + samples = predictive(rng_key, **kwargs) + assert "y" in samples + assert samples["y"].shape[0] == 5 + + def test_k_clusters_configurable(self, synthetic_encoded_data): + """K_clusters=4 should produce different sigma_eps shape.""" + from quantammsim.noise_calibration.model import noise_model_dp_sigma + import jax + import numpyro.handlers as handlers + + K_clusters = 4 + kwargs = self._get_dp_model_kwargs( + synthetic_encoded_data, K_clusters=K_clusters + ) + rng_key = jax.random.PRNGKey(0) + + trace = handlers.trace( + handlers.seed(noise_model_dp_sigma, rng_key) + ).get_trace(**kwargs) + + assert trace["sigma_eps"]["value"].shape == (K_clusters,) + assert trace["v"]["value"].shape == (K_clusters - 1,) + + +# =========================================================================== +# TestDPModelSVISmoke +# =========================================================================== + + +class TestDPModelSVISmoke: + def _build_dp_panel(self): + """Build a minimal panel for DP SVI testing.""" + from datetime import date, timedelta + import pandas as pd + + np.random.seed(42) + pools_spec = [ + ("p0", "MAINNET", "WETH,USDC", 0.003, 0, 0), + ("p1", "ARBITRUM", "BAL,WETH", 0.01, 0, 1), + ("p2", "BASE", "RATS,WETH", 0.005, 0, 2), + ] + dates = [date(2026, 1, 1) + timedelta(days=i) for i in range(15)] + records = [] + for pid, chain, tokens, fee, ta, tb in pools_spec: + base_tvl = 14 + np.random.randn() * 0.5 + for d in dates: + tvl = base_tvl + np.random.randn() * 0.1 + vol = tvl - 2 + np.random.randn() * 0.3 + records.append({ + "pool_id": pid, "chain": chain, "date": d, + "log_volume": vol, "log_tvl": tvl, + "volatility": 0.3 + np.random.rand() * 0.2, + "weekend": 1.0 if d.weekday() >= 5 else 0.0, + "log_fee": np.log(max(fee, 1e-6)), + "swap_fee": fee, "tier_A": ta, "tier_B": tb, + "tokens": tokens, + }) + panel = pd.DataFrame(records) + panel = panel.sort_values(["pool_id", "date"]).reset_index(drop=True) + panel["log_tvl_lag1"] = panel.groupby("pool_id")["log_tvl"].shift(1) + panel = panel.dropna(subset=["log_tvl_lag1"]).reset_index(drop=True) + return panel + + def test_svi_converges_dp_model(self): + """SVI on DP model with tiny data: ELBO should decrease.""" + import numpyro + numpyro.enable_x64() + + from quantammsim.noise_calibration import encode_covariates, run_svi + from quantammsim.noise_calibration.model import noise_model_dp_sigma + + panel = self._build_dp_panel() + data = encode_covariates(panel, include_tiers=False) + data["K_clusters"] = 4 + + samples, losses = run_svi( + data, num_steps=2000, lr=1e-3, seed=0, + num_samples=50, model_fn=noise_model_dp_sigma, + ) + assert np.mean(losses[-100:]) < np.mean(losses[:100]) + + def test_svi_samples_have_dp_keys(self): + """SVI output for DP model has v, alpha_dp, sigma_eps.""" + import numpyro + numpyro.enable_x64() + + from quantammsim.noise_calibration import encode_covariates, run_svi + from quantammsim.noise_calibration.model import noise_model_dp_sigma + + panel = self._build_dp_panel() + data = encode_covariates(panel, include_tiers=False) + data["K_clusters"] = 4 + + samples, _ = run_svi( + data, num_steps=500, lr=1e-3, seed=0, + num_samples=10, model_fn=noise_model_dp_sigma, + ) + required = {"B", "sigma_theta", "L_Omega", "eta", "df", + "sigma_eps", "v", "alpha_dp"} + assert required.issubset(samples.keys()), ( + f"Missing keys: {required - samples.keys()}" + ) diff --git a/tests/noise/test_model_structural.py b/tests/noise/test_model_structural.py new file mode 100644 index 0000000..582cdd5 --- /dev/null +++ b/tests/noise/test_model_structural.py @@ -0,0 +1,190 @@ +"""Tests for structural_noise_model definition and SVI integration.""" + +import numpy as np +import pytest + + +class TestStructuralModelDefinition: + """Tests for the structural_noise_model numpyro model.""" + + def test_model_traces_without_error(self, synthetic_structural_data): + import jax + from numpyro.infer import Predictive + from quantammsim.noise_calibration.model import structural_noise_model + from quantammsim.noise_calibration.inference import _build_model_kwargs + + kwargs = _build_model_kwargs(synthetic_structural_data, + model_fn=structural_noise_model) + predictive = Predictive(structural_noise_model, num_samples=2) + rng_key = jax.random.PRNGKey(0) + samples = predictive(rng_key, **kwargs) + assert "y" in samples + + def test_required_sites_present(self, synthetic_structural_data): + import jax + from numpyro.infer import Predictive + from quantammsim.noise_calibration.model import structural_noise_model + from quantammsim.noise_calibration.inference import _build_model_kwargs + + kwargs = _build_model_kwargs(synthetic_structural_data, + model_fn=structural_noise_model) + predictive = Predictive(structural_noise_model, num_samples=2) + samples = predictive(jax.random.PRNGKey(0), **kwargs) + required = { + "alpha_0", "alpha_chain", "alpha_tier", "alpha_tvl", + "B", "sigma_theta", "L_Omega", "eta", "theta", + "df", "sigma_eps", "y", + } + assert required.issubset(samples.keys()), ( + f"Missing sites: {required - samples.keys()}" + ) + + def test_no_moe_sites(self, synthetic_structural_data): + """MoE sites (W_gate, beta) should not be present.""" + import jax + from numpyro.infer import Predictive + from quantammsim.noise_calibration.model import structural_noise_model + from quantammsim.noise_calibration.inference import _build_model_kwargs + + kwargs = _build_model_kwargs(synthetic_structural_data, + model_fn=structural_noise_model) + predictive = Predictive(structural_noise_model, num_samples=2) + samples = predictive(jax.random.PRNGKey(0), **kwargs) + for site in ("W_gate", "beta"): + assert site not in samples, f"MoE site '{site}' should not be present" + + def test_theta_shape(self, synthetic_structural_data): + import jax + from numpyro.infer import Predictive + from quantammsim.noise_calibration.model import structural_noise_model + from quantammsim.noise_calibration.inference import _build_model_kwargs + + kwargs = _build_model_kwargs(synthetic_structural_data, + model_fn=structural_noise_model) + predictive = Predictive(structural_noise_model, num_samples=3) + samples = predictive(jax.random.PRNGKey(0), **kwargs) + + N_pools = synthetic_structural_data["N_pools"] + K_obs = synthetic_structural_data["x_obs"].shape[1] + assert samples["theta"].shape == (3, N_pools, K_obs) + + def test_B_shape(self, synthetic_structural_data): + import jax + from numpyro.infer import Predictive + from quantammsim.noise_calibration.model import structural_noise_model + from quantammsim.noise_calibration.inference import _build_model_kwargs + + kwargs = _build_model_kwargs(synthetic_structural_data, + model_fn=structural_noise_model) + predictive = Predictive(structural_noise_model, num_samples=3) + samples = predictive(jax.random.PRNGKey(0), **kwargs) + + K_obs = synthetic_structural_data["x_obs"].shape[1] + K_cov = synthetic_structural_data["K_cov"] + assert samples["B"].shape == (3, K_obs, K_cov) + + def test_alpha_chain_shape(self, synthetic_structural_data): + import jax + from numpyro.infer import Predictive + from quantammsim.noise_calibration.model import structural_noise_model + from quantammsim.noise_calibration.inference import _build_model_kwargs + + kwargs = _build_model_kwargs(synthetic_structural_data, + model_fn=structural_noise_model) + predictive = Predictive(structural_noise_model, num_samples=3) + samples = predictive(jax.random.PRNGKey(0), **kwargs) + + n_chains = synthetic_structural_data["n_chains"] + assert samples["alpha_chain"].shape == (3, n_chains - 1) + + def test_prior_predictive_produces_y(self, synthetic_structural_data): + """y_obs=None path produces y samples.""" + import jax + from numpyro.infer import Predictive + from quantammsim.noise_calibration.model import structural_noise_model + from quantammsim.noise_calibration.inference import _build_model_kwargs + + kwargs = _build_model_kwargs(synthetic_structural_data, + model_fn=structural_noise_model) + kwargs["y_obs"] = None + predictive = Predictive(structural_noise_model, num_samples=10) + samples = predictive(jax.random.PRNGKey(42), **kwargs) + assert "y" in samples + assert samples["y"].shape[0] == 10 + + def test_prior_predictive_range_reasonable(self, synthetic_structural_data): + """Prior y should contain finite values in a plausible range.""" + import jax + from numpyro.infer import Predictive + from quantammsim.noise_calibration.model import structural_noise_model + from quantammsim.noise_calibration.inference import _build_model_kwargs + + kwargs = _build_model_kwargs(synthetic_structural_data, + model_fn=structural_noise_model) + kwargs["y_obs"] = None + predictive = Predictive(structural_noise_model, num_samples=200) + samples = predictive(jax.random.PRNGKey(42), **kwargs) + y = np.array(samples["y"]) + finite = y[np.isfinite(y)] + assert len(finite) > 0.5 * y.size, "Too many non-finite prior samples" + median = np.median(finite) + assert median > -50, f"Prior y median too low: {median}" + assert median < 100, f"Prior y median too high: {median}" + + +class TestSVIStructural: + """SVI convergence tests for structural model.""" + + @pytest.mark.slow + def test_svi_structural_converges(self, synthetic_structural_data): + """2000 SVI steps, ELBO should decrease.""" + from quantammsim.noise_calibration.inference import run_svi + from quantammsim.noise_calibration.model import structural_noise_model + + samples, losses = run_svi( + synthetic_structural_data, + num_steps=2000, + lr=5e-3, + seed=0, + num_samples=10, + model_fn=structural_noise_model, + ) + assert losses[-100:].mean() < losses[:100].mean() + + @pytest.mark.slow + def test_svi_structural_samples_have_required_keys( + self, synthetic_structural_data, + ): + from quantammsim.noise_calibration.inference import run_svi + from quantammsim.noise_calibration.model import structural_noise_model + + samples, _ = run_svi( + synthetic_structural_data, + num_steps=500, + lr=5e-3, + seed=0, + num_samples=10, + model_fn=structural_noise_model, + ) + required = {"B", "sigma_theta", "L_Omega", "eta", + "alpha_0", "alpha_chain", + "alpha_tier", "alpha_tvl", "df", "sigma_eps"} + assert required.issubset(samples.keys()), ( + f"Missing keys: {required - samples.keys()}" + ) + + @pytest.mark.slow + def test_svi_structural_no_moe_keys(self, synthetic_structural_data): + from quantammsim.noise_calibration.inference import run_svi + from quantammsim.noise_calibration.model import structural_noise_model + + samples, _ = run_svi( + synthetic_structural_data, + num_steps=500, + lr=5e-3, + seed=0, + num_samples=10, + model_fn=structural_noise_model, + ) + for key in ("W_gate", "beta"): + assert key not in samples, f"MoE key '{key}' should not be in samples" diff --git a/tests/noise/test_output.py b/tests/noise/test_output.py new file mode 100644 index 0000000..2246c98 --- /dev/null +++ b/tests/noise/test_output.py @@ -0,0 +1,305 @@ +"""Tests for generate_output_json and _save_sample_cache.""" + +import json +import os + +import numpy as np +import pytest + +from quantammsim.noise_calibration import ( + extract_noise_params, + generate_output_json, + _save_sample_cache, + K_COEFF, +) + + +# =========================================================================== +# TestGenerateOutputJSON +# =========================================================================== + + +class TestGenerateOutputJSON: + @pytest.fixture() + def _output_setup(self, tmp_path, synthetic_samples, synthetic_encoded_data): + """Produce the JSON file and return (path, data dict).""" + pool_params = extract_noise_params( + synthetic_samples, synthetic_encoded_data + ) + output_path = str(tmp_path / "test_output.json") + convergence = {"method": "svi", "final_elbo": 1234.0} + inference_config = {"method": "svi", "svi_steps": 1000} + + generate_output_json( + pool_params, synthetic_samples, synthetic_encoded_data, + convergence, output_path, inference_config, + ) + with open(output_path) as f: + data = json.load(f) + return output_path, data + + def test_writes_valid_json(self, _output_setup): + path, data = _output_setup + assert isinstance(data, dict) + + def test_top_level_keys(self, _output_setup): + _, data = _output_setup + expected = { + "model", "model_spec", "inference", "population_effects", + "convergence", "n_pools", "n_obs", "pools", + } + assert expected.issubset(data.keys()) + + def test_model_spec_fields(self, _output_setup): + _, data = _output_setup + spec = data["model_spec"] + assert "K_coeff" in spec + assert "K_cov" in spec + assert "coeff_names" in spec + assert "covariate_names" in spec + assert "likelihood" in spec + assert spec["likelihood"] == "StudentT" + assert "tvl_lag" in spec + assert spec["tvl_lag"] == "log_tvl_lag1" + + def test_population_effects_fields(self, _output_setup): + _, data = _output_setup + pe = data["population_effects"] + assert "B" in pe + assert "sigma_theta" in pe + assert "sigma_eps" in pe + assert "df" in pe + assert "correlation_matrix" in pe + + def test_pool_entries(self, _output_setup, synthetic_encoded_data): + _, data = _output_setup + pools = data["pools"] + for pid in synthetic_encoded_data["pool_ids"]: + assert pid in pools + entry = pools[pid] + assert "chain" in entry + assert "tokens" in entry + assert "theta_median" in entry + assert "noise_params" in entry + + def test_correlation_matrix_symmetric_unit_diagonal(self, _output_setup): + _, data = _output_setup + Omega = np.array(data["population_effects"]["correlation_matrix"]) + np.testing.assert_allclose(Omega, Omega.T, atol=1e-10) + np.testing.assert_allclose(np.diag(Omega), 1.0, atol=1e-10) + + def test_n_pools_and_n_obs(self, _output_setup, synthetic_encoded_data): + _, data = _output_setup + assert data["n_pools"] == synthetic_encoded_data["N_pools"] + assert data["n_obs"] == len(synthetic_encoded_data["y_obs"]) + + def test_model_name(self, _output_setup): + _, data = _output_setup + assert data["model"] == "unified_hierarchical_student_t" + + +# =========================================================================== +# TestSaveSampleCache +# =========================================================================== + + +class TestSaveSampleCache: + def test_creates_npz_and_json( + self, tmp_path, synthetic_samples, synthetic_encoded_data + ): + cache_dir = str(tmp_path / "cache") + _save_sample_cache(synthetic_samples, synthetic_encoded_data, cache_dir) + + assert os.path.exists(os.path.join(cache_dir, "unified_samples.npz")) + assert os.path.exists(os.path.join(cache_dir, "unified_data.json")) + + def test_npz_excludes_y_and_theta( + self, tmp_path, synthetic_samples, synthetic_encoded_data + ): + """Both 'y' and 'theta' must be excluded from the npz cache.""" + samples_with_extras = dict(synthetic_samples) + samples_with_extras["y"] = np.random.randn(10, 27) + samples_with_extras["theta"] = np.random.randn(10, 3, 4) + + cache_dir = str(tmp_path / "cache2") + _save_sample_cache(samples_with_extras, synthetic_encoded_data, cache_dir) + + npz_path = os.path.join(cache_dir, "unified_samples.npz") + loaded = np.load(npz_path) + assert "y" not in loaded.files + assert "theta" not in loaded.files + + def test_npz_contains_required_keys( + self, tmp_path, synthetic_samples, synthetic_encoded_data + ): + """The npz must contain B, sigma_theta, L_Omega, eta, df, sigma_eps.""" + cache_dir = str(tmp_path / "cache3") + _save_sample_cache(synthetic_samples, synthetic_encoded_data, cache_dir) + + npz_path = os.path.join(cache_dir, "unified_samples.npz") + loaded = np.load(npz_path) + required = {"B", "sigma_theta", "L_Omega", "eta", "df", "sigma_eps"} + assert required.issubset(set(loaded.files)) + + def test_json_contains_metadata_keys( + self, tmp_path, synthetic_samples, synthetic_encoded_data + ): + """The JSON cache must contain all keys needed by --predict.""" + cache_dir = str(tmp_path / "cache4") + _save_sample_cache(synthetic_samples, synthetic_encoded_data, cache_dir) + + json_path = os.path.join(cache_dir, "unified_data.json") + with open(json_path) as f: + meta = json.load(f) + required = { + "pool_ids", "covariate_names", "K_cov", "N_pools", + "ref_chain", "ref_tier_a", "ref_tier_b", "chains", + } + assert required.issubset(meta.keys()) + + +# =========================================================================== +# TestGenerateOutputJSONDP +# =========================================================================== + + +class TestGenerateOutputJSONDP: + @pytest.fixture() + def _dp_output_setup(self, tmp_path, synthetic_encoded_data): + """Produce DP model JSON output and return (path, data dict).""" + data = synthetic_encoded_data + N_pools = data["N_pools"] + K_cov = data["K_cov"] + K_clusters = 4 + S = 10 + + np.random.seed(99) + dp_samples = { + "B": np.random.randn(S, K_COEFF, K_cov) * 0.5, + "sigma_theta": np.ones((S, K_COEFF)), + "L_Omega": np.tile(np.eye(K_COEFF), (S, 1, 1)), + "eta": np.zeros((S, N_pools, K_COEFF)), + "df": np.full((S,), 5.0), + "sigma_eps": np.tile([0.3, 0.8, 1.5, 2.5], (S, 1)), + "v": np.tile([0.6, 0.3, 0.05], (S, 1)), + "alpha_dp": np.full((S,), 1.5), + } + + pool_params = extract_noise_params(dp_samples, data) + output_path = str(tmp_path / "dp_output.json") + convergence = {"method": "svi", "final_elbo": 1234.0} + inference_config = {"method": "svi", "svi_steps": 1000} + + generate_output_json( + pool_params, dp_samples, data, + convergence, output_path, inference_config, + ) + with open(output_path) as f: + result = json.load(f) + return output_path, result + + def test_sigma_eps_structure_is_dp_mixture(self, _dp_output_setup): + _, data = _dp_output_setup + assert data["model_spec"]["sigma_eps_structure"] == "dp_mixture" + + def test_model_name_includes_dp(self, _dp_output_setup): + _, data = _dp_output_setup + assert "dp_sigma" in data["model"] + + def test_has_cluster_weights(self, _dp_output_setup): + _, data = _dp_output_setup + assert "cluster_weights" in data["population_effects"] + + def test_sigma_eps_length_equals_k_clusters(self, _dp_output_setup): + _, data = _dp_output_setup + sigma_eps = data["population_effects"]["sigma_eps"] + assert len(sigma_eps) == 4 # K_clusters = 4 + + def test_cluster_weights_sum_to_one(self, _dp_output_setup): + _, data = _dp_output_setup + w = data["population_effects"]["cluster_weights"] + np.testing.assert_allclose(sum(w), 1.0, atol=1e-4) + + def test_still_has_standard_fields(self, _dp_output_setup): + _, data = _dp_output_setup + expected = { + "model", "model_spec", "inference", "population_effects", + "convergence", "n_pools", "n_obs", "pools", + } + assert expected.issubset(data.keys()) + + +# =========================================================================== +# TestGenerateOutputJSONStructural +# =========================================================================== + + +class TestGenerateOutputJSONStructural: + @pytest.fixture() + def _structural_output_setup(self, tmp_path, synthetic_structural_data): + """Produce structural model JSON output and return (path, data dict).""" + from quantammsim.noise_calibration.postprocessing import ( + extract_structural_params, + ) + from quantammsim.noise_calibration.constants import K_OBS_COEFF + + data = synthetic_structural_data + K_cov = data["K_cov"] + N_pools = data["N_pools"] + n_chains = data["n_chains"] + n_tiers = data["n_tiers"] + S = 10 + + np.random.seed(77) + structural_samples = { + "alpha_0": np.random.randn(S) * 0.1 + 2.0, + "alpha_chain": np.random.randn(S, n_chains - 1) * 0.1, + "alpha_tier": np.random.randn(S, n_tiers - 1) * 0.1, + "alpha_tvl": np.random.randn(S) * 0.01, + "B": np.random.randn(S, K_OBS_COEFF, K_cov) * 0.5, + "sigma_theta": np.ones((S, K_OBS_COEFF)), + "L_Omega": np.tile(np.eye(K_OBS_COEFF), (S, 1, 1)), + "eta": np.zeros((S, N_pools, K_OBS_COEFF)), + "df": np.full((S,), 5.0), + "sigma_eps": np.tile([0.5, 0.8, 0.6], (S, 1)), + } + + pool_params = extract_structural_params(structural_samples, data) + output_path = str(tmp_path / "structural_output.json") + convergence = {"method": "svi", "final_elbo": 999.0} + inference_config = {"method": "svi", "svi_steps": 2000} + + generate_output_json( + pool_params, structural_samples, data, + convergence, output_path, inference_config, + ) + with open(output_path) as f: + result = json.load(f) + return output_path, result + + def test_output_model_name(self, _structural_output_setup): + _, data = _structural_output_setup + assert data["model"] == "structural_mixture" + + def test_output_has_arb_params(self, _structural_output_setup): + _, data = _structural_output_setup + pe = data["population_effects"] + assert "alpha_0" in pe + assert "alpha_chain" in pe + assert "alpha_tier" in pe + assert "alpha_tvl" in pe + + def test_output_has_hierarchical_noise_params(self, _structural_output_setup): + _, data = _structural_output_setup + pe = data["population_effects"] + assert "B" in pe + assert "sigma_theta" in pe + assert "correlation_matrix" in pe + + def test_output_pools_have_arb_frequency(self, _structural_output_setup): + _, data = _structural_output_setup + pools = data["pools"] + for pid, entry in pools.items(): + assert "arb_frequency" in entry + assert isinstance(entry["arb_frequency"], int) + assert 1 <= entry["arb_frequency"] <= 60 diff --git a/tests/noise/test_panel_assembly.py b/tests/noise/test_panel_assembly.py new file mode 100644 index 0000000..dd7cd31 --- /dev/null +++ b/tests/noise/test_panel_assembly.py @@ -0,0 +1,442 @@ +"""Tests for compute_pair_volatility, assemble_panel, validate_panel.""" + +from datetime import date, timedelta + +import numpy as np +import pandas as pd +import pytest + +from quantammsim.noise_calibration import ( + compute_pair_volatility, + assemble_panel, + validate_panel, +) + + +# =========================================================================== +# TestComputePairVolatility +# =========================================================================== + + +class TestComputePairVolatility: + @pytest.fixture() + def _snap_dates(self): + """10 unique dates for snapshot stub.""" + dates = [date(2026, 1, 1) + timedelta(days=i) for i in range(10)] + return pd.DataFrame({"date": dates}) + + def test_stablecoin_pair_returns_001(self, _snap_dates): + pool_row = pd.Series({ + "tokens": ["USDC", "DAI"], + "chain": "MAINNET", + }) + vol = compute_pair_volatility(_snap_dates, pool_row, {}) + assert (vol == 0.01).all() + + def test_both_missing_non_stable(self, _snap_dates): + pool_row = pd.Series({ + "tokens": ["FOO", "BAR"], + "chain": "MAINNET", + }) + vol = compute_pair_volatility(_snap_dates, pool_row, {}) + assert (vol == 0.5).all() + + def test_one_missing_non_stable(self, _snap_dates): + np.random.seed(77) + pool_row = pd.Series({ + "tokens": ["WETH", "BAR"], + "chain": "MAINNET", + }) + prices = { + ("MAINNET", "WETH"): pd.DataFrame({ + "timestamp": [1735689600 + i * 3600 for i in range(48)], + "price": [3000.0 + np.random.randn() * 10 for _ in range(48)], + }), + } + vol = compute_pair_volatility(_snap_dates, pool_row, prices) + assert (vol == 0.5).all() + + def test_synthetic_hourly_prices_positive_finite(self, _snap_dates): + np.random.seed(42) + n_hours = 240 # 10 days x 24 hours + ts_base = 1735689600 + timestamps = [ts_base + i * 3600 for i in range(n_hours)] + prices_a = np.exp(np.cumsum(np.random.randn(n_hours) * 0.01) + 8) + prices_b = np.exp(np.cumsum(np.random.randn(n_hours) * 0.01) + 7) + + pool_row = pd.Series({ + "tokens": ["WETH", "LINK"], + "chain": "MAINNET", + }) + token_prices = { + ("MAINNET", "WETH"): pd.DataFrame({ + "timestamp": timestamps, "price": prices_a, + }), + ("MAINNET", "LINK"): pd.DataFrame({ + "timestamp": timestamps, "price": prices_b, + }), + } + vol = compute_pair_volatility(_snap_dates, pool_row, token_prices) + assert len(vol) > 0 + assert (vol > 0).all() + assert np.all(np.isfinite(vol)) + + def test_annualisation_uses_sqrt_24x365(self, _snap_dates): + """Verify the annualisation factor is sqrt(24*365).""" + np.random.seed(7) + n_hours = 240 + ts_base = 1735689600 + timestamps = [ts_base + i * 3600 for i in range(n_hours)] + raw_prices = np.exp(np.cumsum(np.random.randn(n_hours) * 0.005) + 8) + + pool_row = pd.Series({ + "tokens": ["WETH", "USDC"], + "chain": "MAINNET", + }) + token_prices = { + ("MAINNET", "WETH"): pd.DataFrame({ + "timestamp": timestamps, "price": raw_prices, + }), + } + vol = compute_pair_volatility(_snap_dates, pool_row, token_prices) + + # Reconstruct manually + df = pd.DataFrame({"timestamp": timestamps, "price": raw_prices}) + df["datetime"] = pd.to_datetime(df["timestamp"], unit="s") + df["date"] = df["datetime"].dt.date + df["ratio"] = df["price"] # WETH vs stable => ratio = price + df["log_return"] = np.log(df["ratio"] / df["ratio"].shift(1)) + df = df.dropna(subset=["log_return"]) + daily_std = df.groupby("date")["log_return"].std() + expected = daily_std * np.sqrt(24 * 365) + + common = vol.index.intersection(expected.index) + assert len(common) > 0 + np.testing.assert_allclose( + vol.loc[common].values, expected.loc[common].values, rtol=1e-10, + ) + + def test_single_token_pool_returns_empty(self, _snap_dates): + pool_row = pd.Series({"tokens": ["WETH"], "chain": "MAINNET"}) + vol = compute_pair_volatility(_snap_dates, pool_row, {}) + assert len(vol) == 0 + + def test_no_overlapping_price_dates(self, _snap_dates): + """Two tokens with non-overlapping timestamps -> fallback 0.5.""" + pool_row = pd.Series({ + "tokens": ["WETH", "LINK"], + "chain": "MAINNET", + }) + token_prices = { + ("MAINNET", "WETH"): pd.DataFrame({ + "timestamp": [1000000 + i for i in range(10)], + "price": [3000.0] * 10, + }), + ("MAINNET", "LINK"): pd.DataFrame({ + "timestamp": [9000000 + i for i in range(10)], + "price": [15.0] * 10, + }), + } + vol = compute_pair_volatility(_snap_dates, pool_row, token_prices) + assert (vol == 0.5).all() + + def test_cross_chain_price_fallback(self, _snap_dates): + """Prices keyed to a different chain SHOULD be used as fallback. + + Token prices are chain-agnostic (WETH is WETH regardless of chain), + so an ARBITRUM pool should use MAINNET WETH prices if ARBITRUM + prices aren't available. + """ + np.random.seed(55) + n_hours = 240 + ts_base = 1735689600 + timestamps = [ts_base + i * 3600 for i in range(n_hours)] + prices = np.exp(np.cumsum(np.random.randn(n_hours) * 0.01) + 8) + + pool_row = pd.Series({ + "tokens": ["WETH", "USDC"], + "chain": "ARBITRUM", + }) + token_prices = { + # Only MAINNET prices, but ARBITRUM pool should still use them + ("MAINNET", "WETH"): pd.DataFrame({ + "timestamp": timestamps, "price": prices.tolist(), + }), + } + vol = compute_pair_volatility(_snap_dates, pool_row, token_prices) + # Should get real volatility, NOT the 0.5 fallback + assert len(vol) > 0 + assert not (vol == 0.5).all(), ( + "Cross-chain price fallback should have produced real volatility" + ) + + +# =========================================================================== +# TestAssemblePanel +# =========================================================================== + + +class TestAssemblePanel: + def test_lagged_tvl_drops_first_obs_per_pool( + self, synthetic_pools_df, synthetic_snapshots_df + ): + panel = assemble_panel(synthetic_pools_df, synthetic_snapshots_df, {}) + counts = panel.groupby("pool_id").size() + assert (counts == 9).all() + + def test_lagged_tvl_exact_values( + self, synthetic_pools_df, synthetic_snapshots_df + ): + """log_tvl_lag1[t] must equal log_tvl[t-1] for same pool (shift(1)).""" + panel = assemble_panel(synthetic_pools_df, synthetic_snapshots_df, {}) + for pid in panel["pool_id"].unique(): + pool = panel[panel["pool_id"] == pid].sort_values("date") + tvl_vals = pool["log_tvl"].values + lag_vals = pool["log_tvl_lag1"].values + # After dropping the first obs, lag[i] = tvl[i-1] in the original + # pre-drop series. Since the panel is sorted by date, each + # lag value should equal the log_tvl of the chronologically + # preceding observation. We verify consecutive pairs: for rows + # i and i+1, lag[i+1] == tvl[i]. + for i in range(len(tvl_vals) - 1): + np.testing.assert_allclose( + lag_vals[i + 1], tvl_vals[i], rtol=1e-14, + err_msg=f"Pool {pid} row {i+1}: lag should equal previous tvl", + ) + + def test_weekend_flag(self, synthetic_pools_df, synthetic_snapshots_df): + panel = assemble_panel(synthetic_pools_df, synthetic_snapshots_df, {}) + for _, row in panel.iterrows(): + d = row["date"] + if not isinstance(d, date): + d = pd.Timestamp(d).date() + expected = 1.0 if d.weekday() >= 5 else 0.0 + assert row["weekend"] == expected, f"Wrong weekend flag for {d}" + + def test_log_volume_is_natural_log( + self, synthetic_pools_df, synthetic_snapshots_df + ): + """log_volume must equal ln(volume_usd), not log10 or log2.""" + panel = assemble_panel(synthetic_pools_df, synthetic_snapshots_df, {}) + snaps = synthetic_snapshots_df.copy() + # Join snapshots to panel by pool_id + date to verify exact log values + for _, row in panel.iterrows(): + pid = row["pool_id"] + d = row["date"] + snap_match = snaps[ + (snaps["pool_id"] == pid) & (snaps["date"] == d) + ] + assert len(snap_match) == 1, f"No snapshot for {pid} on {d}" + expected = np.log(snap_match.iloc[0]["volume_usd"]) + np.testing.assert_allclose( + row["log_volume"], expected, rtol=1e-14, + err_msg=f"log_volume for {pid} on {d} should be ln(volume_usd)", + ) + + def test_tier_assignment_min_first( + self, synthetic_pools_df, synthetic_snapshots_df + ): + panel = assemble_panel(synthetic_pools_df, synthetic_snapshots_df, {}) + # Pool A: WETH(0), USDC(0) -> tier_A=0, tier_B=0 + pool_a = panel[panel["pool_id"] == "pool_A"].iloc[0] + assert pool_a["tier_A"] == 0 + assert pool_a["tier_B"] == 0 + + # Pool B: BAL(1), WETH(0) -> tier_A=0, tier_B=1 + pool_b = panel[panel["pool_id"] == "pool_B"].iloc[0] + assert pool_b["tier_A"] == 0 + assert pool_b["tier_B"] == 1 + + # Pool C: RATS(2), WETH(0) -> tier_A=0, tier_B=2 + pool_c = panel[panel["pool_id"] == "pool_C"].iloc[0] + assert pool_c["tier_A"] == 0 + assert pool_c["tier_B"] == 2 + + def test_log_fee(self, synthetic_pools_df, synthetic_snapshots_df): + panel = assemble_panel(synthetic_pools_df, synthetic_snapshots_df, {}) + pool_a = panel[panel["pool_id"] == "pool_A"].iloc[0] + assert np.isclose(pool_a["log_fee"], np.log(0.003)) + + def test_all_expected_columns( + self, synthetic_pools_df, synthetic_snapshots_df + ): + panel = assemble_panel(synthetic_pools_df, synthetic_snapshots_df, {}) + expected_cols = { + "pool_id", "chain", "date", "log_volume", "log_tvl", + "log_tvl_lag1", "volatility", "weekend", "log_fee", + "swap_fee", "tier_A", "tier_B", "tokens", + } + assert expected_cols.issubset(set(panel.columns)) + + def test_zero_volume_rows_dropped(self, synthetic_pools_df): + """Rows with volume_usd <= 0 must be excluded from the panel.""" + dates = [date(2026, 1, 1) + timedelta(days=i) for i in range(5)] + records = [] + for d in dates: + records.append({ + "pool_id": "pool_A", "chain": "MAINNET", "date": d, + "volume_usd": 1000.0, + "total_liquidity_usd": 100000.0, + }) + # Add a zero-volume row + records.append({ + "pool_id": "pool_A", "chain": "MAINNET", + "date": date(2026, 1, 6), + "volume_usd": 0.0, + "total_liquidity_usd": 100000.0, + }) + snaps = pd.DataFrame(records) + panel = assemble_panel(synthetic_pools_df, snaps, {}) + pool_a = panel[panel["pool_id"] == "pool_A"] + # 5 valid rows, minus 1 for lag = 4 (the zero-volume row is skipped) + assert len(pool_a) == 4 + + def test_zero_tvl_rows_dropped(self, synthetic_pools_df): + """Rows with total_liquidity_usd <= 0 must be excluded.""" + dates = [date(2026, 1, 1) + timedelta(days=i) for i in range(5)] + records = [] + for d in dates: + records.append({ + "pool_id": "pool_A", "chain": "MAINNET", "date": d, + "volume_usd": 1000.0, + "total_liquidity_usd": 100000.0, + }) + # Add a zero-TVL row + records.append({ + "pool_id": "pool_A", "chain": "MAINNET", + "date": date(2026, 1, 6), + "volume_usd": 1000.0, + "total_liquidity_usd": 0.0, + }) + snaps = pd.DataFrame(records) + panel = assemble_panel(synthetic_pools_df, snaps, {}) + pool_a = panel[panel["pool_id"] == "pool_A"] + assert len(pool_a) == 4 + + +# =========================================================================== +# TestSyntheticPanelColumns — structural model covariates in the fixture +# =========================================================================== + + +class TestSyntheticPanelColumns: + """Tests that the synthetic_panel fixture has new structural columns.""" + + def test_panel_has_log_sigma(self, synthetic_panel): + assert "log_sigma" in synthetic_panel.columns + expected = np.log(np.maximum(synthetic_panel["volatility"].values, 1e-6)) + np.testing.assert_allclose( + synthetic_panel["log_sigma"].values, expected, rtol=1e-12, + ) + + def test_panel_has_dow_harmonics(self, synthetic_panel): + assert "dow_sin" in synthetic_panel.columns + assert "dow_cos" in synthetic_panel.columns + assert (synthetic_panel["dow_sin"] >= -1.0).all() + assert (synthetic_panel["dow_sin"] <= 1.0).all() + assert (synthetic_panel["dow_cos"] >= -1.0).all() + assert (synthetic_panel["dow_cos"] <= 1.0).all() + + def test_panel_has_interactions(self, synthetic_panel): + assert "tvl_x_sigma" in synthetic_panel.columns + assert "tvl_x_fee" in synthetic_panel.columns + assert "sigma_x_fee" in synthetic_panel.columns + + def test_dow_harmonics_correct_for_known_date(self, synthetic_panel): + """2026-01-03 is Saturday (weekday=5), so dow=5.""" + sat_rows = synthetic_panel[ + synthetic_panel["date"] == date(2026, 1, 3) + ] + if len(sat_rows) == 0: + pytest.skip("No Saturday rows in fixture") + expected_sin = np.sin(2 * np.pi * 5 / 7) + expected_cos = np.cos(2 * np.pi * 5 / 7) + np.testing.assert_allclose( + sat_rows["dow_sin"].values[0], expected_sin, atol=1e-12, + ) + np.testing.assert_allclose( + sat_rows["dow_cos"].values[0], expected_cos, atol=1e-12, + ) + + def test_interactions_use_lagged_tvl(self, synthetic_panel): + """tvl_x_sigma must use log_tvl_lag1, not log_tvl.""" + expected = ( + synthetic_panel["log_tvl_lag1"].values + * synthetic_panel["log_sigma"].values + ) + np.testing.assert_allclose( + synthetic_panel["tvl_x_sigma"].values, expected, rtol=1e-12, + ) + + +# =========================================================================== +# TestAssemblePanelStructuralColumns — new columns from real pipeline +# =========================================================================== + + +class TestAssemblePanelStructuralColumns: + """Tests that assemble_panel() produces new structural columns.""" + + def test_assemble_panel_has_log_sigma( + self, synthetic_pools_df, synthetic_snapshots_df + ): + panel = assemble_panel(synthetic_pools_df, synthetic_snapshots_df, {}) + assert "log_sigma" in panel.columns + expected = np.log(np.maximum(panel["volatility"].values, 1e-6)) + np.testing.assert_allclose( + panel["log_sigma"].values, expected, rtol=1e-12, + ) + + def test_assemble_panel_has_dow_harmonics( + self, synthetic_pools_df, synthetic_snapshots_df + ): + panel = assemble_panel(synthetic_pools_df, synthetic_snapshots_df, {}) + assert "dow_sin" in panel.columns + assert "dow_cos" in panel.columns + + def test_assemble_panel_has_interactions( + self, synthetic_pools_df, synthetic_snapshots_df + ): + panel = assemble_panel(synthetic_pools_df, synthetic_snapshots_df, {}) + assert "tvl_x_sigma" in panel.columns + assert "tvl_x_fee" in panel.columns + assert "sigma_x_fee" in panel.columns + # Interactions use lagged TVL + expected = panel["log_tvl_lag1"].values * panel["log_sigma"].values + np.testing.assert_allclose( + panel["tvl_x_sigma"].values, expected, rtol=1e-12, + ) + + +# =========================================================================== +# TestValidatePanel +# =========================================================================== + + +class TestValidatePanel: + def test_flags_constant_volume(self, synthetic_panel, capsys): + panel = synthetic_panel.copy() + mask = panel["pool_id"] == "pool_A" + panel.loc[mask, "log_volume"] = 10.0 + validate_panel(panel) + captured = capsys.readouterr() + assert "near-constant" in captured.out or "constant" in captured.out.lower() + + def test_flags_tvl_jumps(self, synthetic_panel, capsys): + panel = synthetic_panel.copy() + idx = panel[panel["pool_id"] == "pool_A"].index[1] + panel.loc[idx, "log_tvl"] = panel.loc[idx, "log_tvl"] + 5.0 + validate_panel(panel) + captured = capsys.readouterr() + assert "TVL jumps" in captured.out + + def test_flags_volume_exceeds_tvl(self, synthetic_panel, capsys): + panel = synthetic_panel.copy() + panel["log_volume"] = panel["log_tvl"] + 1.0 + validate_panel(panel) + captured = capsys.readouterr() + assert "volume > TVL" in captured.out + + def test_returns_dataframe_unchanged(self, synthetic_panel): + result = validate_panel(synthetic_panel) + pd.testing.assert_frame_equal(result, synthetic_panel) diff --git a/tests/noise/test_postprocessing.py b/tests/noise/test_postprocessing.py new file mode 100644 index 0000000..c40e5dd --- /dev/null +++ b/tests/noise/test_postprocessing.py @@ -0,0 +1,533 @@ +"""Tests for extract_noise_params, predict_new_pool, check_convergence, +assign_dp_clusters, and structural model post-processing.""" + +import numpy as np +import pytest + +from quantammsim.noise_calibration import ( + extract_noise_params, + predict_new_pool, + check_convergence, + classify_token_tier, + _get_theta_samples, + K_COEFF, + COEFF_NAMES, +) +from quantammsim.noise_calibration.constants import K_OBS_COEFF, OBS_COEFF_NAMES + + +# =========================================================================== +# TestExtractNoiseParams +# =========================================================================== + + +class TestExtractNoiseParams: + def test_output_length(self, synthetic_samples, synthetic_encoded_data): + result = extract_noise_params(synthetic_samples, synthetic_encoded_data) + assert len(result) == synthetic_encoded_data["N_pools"] + + def test_weekend_absorption(self, synthetic_samples, synthetic_encoded_data): + """b_0_eff = b_0_raw + b_weekend * (2/7).""" + result = extract_noise_params(synthetic_samples, synthetic_encoded_data) + + theta = _get_theta_samples( + synthetic_samples, synthetic_encoded_data["X_pool"] + ) + theta_med = np.median(theta, axis=0) + + for i, p in enumerate(result): + b_0_raw = theta_med[i, 0] + b_weekend = theta_med[i, 3] + expected_b_0 = b_0_raw + b_weekend * (2.0 / 7.0) + np.testing.assert_allclose( + p["noise_params"]["b_0"], expected_b_0, atol=1e-10, + ) + + def test_noise_params_keys(self, synthetic_samples, synthetic_encoded_data): + result = extract_noise_params(synthetic_samples, synthetic_encoded_data) + expected_keys = {"b_0", "b_sigma", "b_c", "b_weekend", "base_fee"} + for p in result: + assert set(p["noise_params"].keys()) == expected_keys + + def test_theta_median_length(self, synthetic_samples, synthetic_encoded_data): + result = extract_noise_params(synthetic_samples, synthetic_encoded_data) + for p in result: + assert len(p["theta_median"]) == K_COEFF + + def test_b_c_equals_theta_1(self, synthetic_samples, synthetic_encoded_data): + result = extract_noise_params(synthetic_samples, synthetic_encoded_data) + for p in result: + np.testing.assert_allclose( + p["noise_params"]["b_c"], p["theta_median"][1], atol=1e-10, + ) + + def test_b_sigma_equals_theta_2( + self, synthetic_samples, synthetic_encoded_data + ): + result = extract_noise_params(synthetic_samples, synthetic_encoded_data) + for p in result: + np.testing.assert_allclose( + p["noise_params"]["b_sigma"], p["theta_median"][2], atol=1e-10, + ) + + def test_base_fee_matches_pool( + self, synthetic_samples, synthetic_encoded_data + ): + result = extract_noise_params(synthetic_samples, synthetic_encoded_data) + pool_meta = synthetic_encoded_data["pool_meta"] + for i, p in enumerate(result): + expected_fee = pool_meta.iloc[i]["swap_fee"] + np.testing.assert_allclose( + p["noise_params"]["base_fee"], expected_fee, atol=1e-10, + ) + + def test_pool_id_and_chain_preserved( + self, synthetic_samples, synthetic_encoded_data + ): + result = extract_noise_params(synthetic_samples, synthetic_encoded_data) + pool_ids = synthetic_encoded_data["pool_ids"] + pool_meta = synthetic_encoded_data["pool_meta"] + for i, p in enumerate(result): + assert p["pool_id"] == pool_ids[i] + assert p["chain"] == str(pool_meta.iloc[i]["chain"]) + + def test_use_median_false_uses_mean( + self, synthetic_samples, synthetic_encoded_data + ): + """use_median=False must produce different values than use_median=True.""" + result_med = extract_noise_params( + synthetic_samples, synthetic_encoded_data, use_median=True + ) + result_mean = extract_noise_params( + synthetic_samples, synthetic_encoded_data, use_median=False + ) + assert len(result_mean) == len(result_med) + + # With random B samples (S=10, seed 99), median != mean for at least + # one pool. Check that at least one theta_median value differs. + any_differ = False + for pm, pn in zip(result_med, result_mean): + for tm, tn in zip(pm["theta_median"], pn["theta_median"]): + if not np.isclose(tm, tn, atol=1e-14): + any_differ = True + break + if any_differ: + break + assert any_differ, "median and mean paths produced identical values" + + +# =========================================================================== +# TestPredictNewPool +# =========================================================================== + + +class TestPredictNewPool: + def _build_z_new(self, data, chain, tokens, fee): + """Reconstruct z_new the same way predict_new_pool does internally.""" + col_names = data["covariate_names"] + z_new = np.zeros(len(col_names), dtype=np.float64) + + tiers = sorted([classify_token_tier(t) for t in tokens]) + tier_a = str(tiers[0]) + tier_b = str(tiers[1]) if len(tiers) > 1 else tier_a + + for i, name in enumerate(col_names): + if name == "intercept": + z_new[i] = 1.0 + elif name == "log_fee": + z_new[i] = np.log(max(fee, 1e-6)) + elif name == f"chain_{chain}": + z_new[i] = 1.0 + elif name == f"tier_A_{tier_a}": + z_new[i] = 1.0 + elif name == f"tier_B_{tier_b}": + z_new[i] = 1.0 + return z_new + + def test_known_chain_sets_dummy( + self, synthetic_samples, synthetic_encoded_data + ): + """ARBITRUM dummy must be 1 and must affect the prediction vs MAINNET.""" + data = synthetic_encoded_data + + result_arb = predict_new_pool( + synthetic_samples, data, + chain="ARBITRUM", tokens=["WETH", "USDC"], fee=0.003, + ) + result_main = predict_new_pool( + synthetic_samples, data, + chain="MAINNET", tokens=["WETH", "USDC"], fee=0.003, + ) + # MAINNET is the reference chain (alphabetically: ARBITRUM < BASE < MAINNET). + # Wait — ARBITRUM is alphabetically first, so ARBITRUM is the reference. + # MAINNET has a chain_MAINNET dummy. Both should produce different mu. + # If chain dummies are ignored, these would be identical. + arb_b0 = result_arb["noise_params"]["b_0"] + main_b0 = result_main["noise_params"]["b_0"] + assert not np.isclose(arb_b0, main_b0, atol=1e-10), ( + "Different chains should produce different predictions" + ) + + def test_tier_assignment_affects_prediction( + self, synthetic_samples, synthetic_encoded_data + ): + """WETH/RATS (tier 0,2) vs WETH/USDC (tier 0,0) must differ.""" + data = synthetic_encoded_data + + result_rats = predict_new_pool( + synthetic_samples, data, + chain="BASE", tokens=["WETH", "RATS"], fee=0.005, + ) + result_usdc = predict_new_pool( + synthetic_samples, data, + chain="BASE", tokens=["WETH", "USDC"], fee=0.005, + ) + # Different tier_B dummies should give different mu + rats_b0 = result_rats["noise_params"]["b_0"] + usdc_b0 = result_usdc["noise_params"]["b_0"] + assert not np.isclose(rats_b0, usdc_b0, atol=1e-10), ( + "Different tier assignments should produce different predictions" + ) + + def test_weekend_absorption_arithmetic( + self, synthetic_samples, synthetic_encoded_data + ): + """b_0 in noise_params must equal mu_median[0] + mu_median[3] * (2/7).""" + data = synthetic_encoded_data + result = predict_new_pool( + synthetic_samples, data, + chain="MAINNET", tokens=["WETH", "USDC"], fee=0.003, + ) + # Reconstruct mu_median independently + z_new = self._build_z_new(data, "MAINNET", ["WETH", "USDC"], 0.003) + B = synthetic_samples["B"] + mu_samples = np.einsum("skd,d->sk", B, z_new) + mu_median = np.median(mu_samples, axis=0) + + b_0_raw = mu_median[0] + b_weekend = mu_median[3] + expected_b_0 = b_0_raw + b_weekend * (2.0 / 7.0) + np.testing.assert_allclose( + result["noise_params"]["b_0"], expected_b_0, atol=1e-10, + ) + + def test_mu_equals_b_at_z_new( + self, synthetic_samples, synthetic_encoded_data + ): + """With known B samples, verify credible interval medians = median(B @ z_new).""" + data = synthetic_encoded_data + z_new = self._build_z_new(data, "ARBITRUM", ["WETH", "RATS"], 0.005) + + B = synthetic_samples["B"] + mu_expected = np.einsum("skd,d->sk", B, z_new) + mu_median_expected = np.median(mu_expected, axis=0) + + result = predict_new_pool( + synthetic_samples, data, + chain="ARBITRUM", tokens=["WETH", "RATS"], fee=0.005, + ) + for k, name in enumerate(COEFF_NAMES): + np.testing.assert_allclose( + result["credible_intervals_90"][name]["median"], + mu_median_expected[k], + atol=1e-10, + ) + + def test_unseen_chain_uses_reference( + self, synthetic_samples, synthetic_encoded_data + ): + """A chain not in training data should get reference-chain prediction + (all chain dummies = 0), not raise an error.""" + data = synthetic_encoded_data + result = predict_new_pool( + synthetic_samples, data, + chain="SONIC", tokens=["WETH", "USDC"], fee=0.003, + ) + # Should be same as ARBITRUM (the reference chain, all dummies 0) + result_ref = predict_new_pool( + synthetic_samples, data, + chain="ARBITRUM", tokens=["WETH", "USDC"], fee=0.003, + ) + np.testing.assert_allclose( + result["noise_params"]["b_0"], + result_ref["noise_params"]["b_0"], + atol=1e-10, + ) + + +# =========================================================================== +# TestCheckConvergence +# =========================================================================== + + +class TestCheckConvergence: + def test_svi_returns_expected_keys(self): + losses = np.random.randn(1000).cumsum() + 5000 + result = check_convergence(losses, method="svi") + assert "final_elbo" in result + assert "elbo_last_100_std" in result + assert "elbo_last_100_mean" in result + + def test_svi_method_key(self): + losses = np.linspace(5000, 1000, 500) + result = check_convergence(losses, method="svi") + assert result["method"] == "svi" + + def test_svi_elbo_last_100_std_correct(self): + np.random.seed(42) + losses = np.random.randn(500) * 10 + 1000 + result = check_convergence(losses, method="svi") + expected_std = float(np.std(losses[-100:])) + np.testing.assert_allclose( + result["elbo_last_100_std"], expected_std, atol=1e-10, + ) + + def test_svi_final_elbo_is_last_loss(self): + losses = np.array([100.0, 50.0, 25.0, 12.5]) + result = check_convergence(losses, method="svi") + assert result["final_elbo"] == 12.5 + + +# =========================================================================== +# TestAssignDPClusters +# =========================================================================== + + +class TestAssignDPClusters: + @pytest.fixture() + def dp_samples_and_data(self, synthetic_encoded_data): + """Synthetic DP posterior samples with known cluster structure.""" + data = synthetic_encoded_data + N_pools = data["N_pools"] + K_cov = data["K_cov"] + K_clusters = 4 + S = 10 + + np.random.seed(99) + B = np.random.randn(S, K_COEFF, K_cov) * 0.5 + sigma_theta = np.ones((S, K_COEFF)) + L_Omega = np.tile(np.eye(K_COEFF), (S, 1, 1)) + eta = np.zeros((S, N_pools, K_COEFF)) + df = np.full((S,), 5.0) + + # Well-separated sigma_eps clusters + sigma_eps = np.tile([0.3, 0.8, 1.5, 2.5], (S, 1)) + # Stick-breaking weights: mostly on cluster 0 + v = np.tile([0.6, 0.3, 0.05], (S, 1)) + + samples = { + "B": B, "sigma_theta": sigma_theta, "L_Omega": L_Omega, + "eta": eta, "df": df, "sigma_eps": sigma_eps, "v": v, + } + data_with_k = dict(data) + data_with_k["K_clusters"] = K_clusters + return samples, data_with_k + + def test_returns_correct_length(self, dp_samples_and_data): + from quantammsim.noise_calibration.postprocessing import assign_dp_clusters + + samples, data = dp_samples_and_data + assignments = assign_dp_clusters(samples, data) + assert len(assignments) == data["N_pools"] + + def test_valid_cluster_indices(self, dp_samples_and_data): + from quantammsim.noise_calibration.postprocessing import assign_dp_clusters + + samples, data = dp_samples_and_data + assignments = assign_dp_clusters(samples, data) + K = data["K_clusters"] + assert all(0 <= a < K for a in assignments) + + def test_returns_integer_array(self, dp_samples_and_data): + from quantammsim.noise_calibration.postprocessing import assign_dp_clusters + + samples, data = dp_samples_and_data + assignments = assign_dp_clusters(samples, data) + assert assignments.dtype in (np.int32, np.int64, int) + + +# =========================================================================== +# TestExtractNoiseParamsDP +# =========================================================================== + + +class TestExtractNoiseParamsDP: + @pytest.fixture() + def dp_samples_and_data(self, synthetic_encoded_data): + """Same as above for extract_noise_params testing.""" + data = synthetic_encoded_data + N_pools = data["N_pools"] + K_cov = data["K_cov"] + S = 10 + + np.random.seed(99) + B = np.random.randn(S, K_COEFF, K_cov) * 0.5 + sigma_theta = np.ones((S, K_COEFF)) + L_Omega = np.tile(np.eye(K_COEFF), (S, 1, 1)) + eta = np.zeros((S, N_pools, K_COEFF)) + df = np.full((S,), 5.0) + sigma_eps = np.tile([0.3, 0.8, 1.5, 2.5], (S, 1)) + v = np.tile([0.6, 0.3, 0.05], (S, 1)) + + samples = { + "B": B, "sigma_theta": sigma_theta, "L_Omega": L_Omega, + "eta": eta, "df": df, "sigma_eps": sigma_eps, "v": v, + } + return samples, data + + def test_output_length(self, dp_samples_and_data): + samples, data = dp_samples_and_data + result = extract_noise_params(samples, data) + assert len(result) == data["N_pools"] + + def test_noise_params_keys_present(self, dp_samples_and_data): + samples, data = dp_samples_and_data + result = extract_noise_params(samples, data) + expected_keys = {"b_0", "b_sigma", "b_c", "b_weekend", "base_fee"} + for p in result: + assert set(p["noise_params"].keys()) == expected_keys + + +# =========================================================================== +# TestExtractStructuralParams +# =========================================================================== + + +class TestExtractStructuralParams: + """Tests for extract_structural_params().""" + + @pytest.fixture() + def structural_samples_and_data(self, synthetic_structural_data): + """Run a quick SVI fit to get structural samples.""" + from quantammsim.noise_calibration.inference import run_svi + from quantammsim.noise_calibration.model import structural_noise_model + + samples, _ = run_svi( + synthetic_structural_data, + num_steps=500, + lr=5e-3, + seed=0, + num_samples=10, + model_fn=structural_noise_model, + ) + return samples, synthetic_structural_data + + def test_extract_returns_arb_params(self, structural_samples_and_data): + from quantammsim.noise_calibration.postprocessing import ( + extract_structural_params, + ) + samples, data = structural_samples_and_data + result = extract_structural_params(samples, data) + assert len(result) == data["N_pools"] + for p in result: + assert "arb_frequency" in p + assert isinstance(p["arb_frequency"], int) + assert 1 <= p["arb_frequency"] <= 60 + + def test_extract_returns_noise_params(self, structural_samples_and_data): + from quantammsim.noise_calibration.postprocessing import ( + extract_structural_params, + ) + samples, data = structural_samples_and_data + result = extract_structural_params(samples, data) + for p in result: + assert "noise_params" in p + coeffs = p["noise_params"] + assert len(coeffs) == K_OBS_COEFF + for name in OBS_COEFF_NAMES: + assert name in coeffs, f"Missing coefficient: {name}" + + +# =========================================================================== +# TestPredictStructural +# =========================================================================== + + +class TestPredictStructural: + @pytest.fixture() + def structural_samples_and_data(self, synthetic_structural_data): + from quantammsim.noise_calibration.inference import run_svi + from quantammsim.noise_calibration.model import structural_noise_model + + samples, _ = run_svi( + synthetic_structural_data, + num_steps=500, + lr=5e-3, + seed=0, + num_samples=10, + model_fn=structural_noise_model, + ) + return samples, synthetic_structural_data + + def test_predict_returns_cadence(self, structural_samples_and_data): + from quantammsim.noise_calibration.postprocessing import ( + predict_new_pool_structural, + ) + samples, data = structural_samples_and_data + result = predict_new_pool_structural( + samples, data, + chain="MAINNET", tokens=["WETH", "USDC"], + fee=0.003, tvl_est=1e6, + ) + assert "arb_frequency" in result + assert isinstance(result["arb_frequency"], int) + assert 1 <= result["arb_frequency"] <= 60 + + def test_predict_returns_noise_coefficients( + self, structural_samples_and_data, + ): + from quantammsim.noise_calibration.postprocessing import ( + predict_new_pool_structural, + ) + samples, data = structural_samples_and_data + result = predict_new_pool_structural( + samples, data, + chain="MAINNET", tokens=["WETH", "USDC"], + fee=0.003, tvl_est=1e6, + ) + assert "noise_params" in result + assert len(result["noise_params"]) == K_OBS_COEFF + + def test_predict_uses_B_regression(self, structural_samples_and_data): + """Different (chain, tier) → different predictions via B regression.""" + from quantammsim.noise_calibration.postprocessing import ( + predict_new_pool_structural, + ) + samples, data = structural_samples_and_data + r1 = predict_new_pool_structural( + samples, data, + chain="MAINNET", tokens=["WETH", "USDC"], + fee=0.003, tvl_est=1e6, + ) + r2 = predict_new_pool_structural( + samples, data, + chain="ARBITRUM", tokens=["BAL", "WETH"], + fee=0.01, tvl_est=5e5, + ) + # Different pool characteristics should produce different noise params + assert r1["noise_params"] != r2["noise_params"] + + def test_predict_cadence_higher_for_longtail( + self, structural_samples_and_data, + ): + """Long-tail pools should have higher cadence (less efficient arb).""" + from quantammsim.noise_calibration.postprocessing import ( + predict_new_pool_structural, + ) + samples, data = structural_samples_and_data + # This tests the structural relationship — it may not hold with + # random SVI samples on synthetic data, so we just check it doesn't + # crash and returns valid values + r1 = predict_new_pool_structural( + samples, data, + chain="MAINNET", tokens=["WETH", "USDC"], + fee=0.003, tvl_est=1e6, + ) + r2 = predict_new_pool_structural( + samples, data, + chain="BASE", tokens=["RATS", "WETH"], + fee=0.005, tvl_est=1e5, + ) + # Both should be valid + assert 1 <= r1["arb_frequency"] <= 60 + assert 1 <= r2["arb_frequency"] <= 60 diff --git a/tests/noise/test_token_classification.py b/tests/noise/test_token_classification.py new file mode 100644 index 0000000..dd540f8 --- /dev/null +++ b/tests/noise/test_token_classification.py @@ -0,0 +1,66 @@ +"""Tests for _normalise_symbol and classify_token_tier.""" + +import pytest +from quantammsim.noise_calibration import _normalise_symbol, classify_token_tier + + +# =========================================================================== +# TestNormaliseSymbol +# =========================================================================== + + +class TestNormaliseSymbol: + def test_passthrough_unknown(self): + assert _normalise_symbol("FOO") == "FOO" + + def test_known_mapping_preserved(self): + assert _normalise_symbol("WETH") == "WETH" + assert _normalise_symbol("WBTC") == "WBTC" + assert _normalise_symbol("cbBTC") == "cbBTC" + + def test_whitespace_stripped(self): + assert _normalise_symbol(" ETH ") == "ETH" + assert _normalise_symbol(" WETH ") == "WETH" + + def test_case_sensitivity_preserved(self): + # Lowercase is NOT normalised to uppercase + assert _normalise_symbol("weth") == "weth" + + +# =========================================================================== +# TestClassifyTokenTier +# =========================================================================== + + +class TestClassifyTokenTier: + def test_tier0_native_tokens(self): + for sym in ["ETH", "BTC"]: + assert classify_token_tier(sym) == 0, f"{sym} should be tier 0" + + def test_tier0_wrapped(self): + for sym in ["WETH", "WBTC", "cbBTC"]: + assert classify_token_tier(sym) == 0, f"{sym} should be tier 0" + + def test_tier0_stablecoins(self): + for sym in ["USDC", "USDT", "DAI"]: + assert classify_token_tier(sym) == 0, f"{sym} should be tier 0" + + def test_tier0_chain_natives(self): + for sym in ["MATIC", "AVAX", "GNO", "S", "wS"]: + assert classify_token_tier(sym) == 0, f"{sym} should be tier 0" + + def test_tier1_defi_bluechips(self): + for sym in ["AAVE", "BAL", "COW", "LINK", "ARB"]: + assert classify_token_tier(sym) == 1, f"{sym} should be tier 1" + + def test_tier2_unknown_tokens(self): + for sym in ["RATS", "PEPE"]: + assert classify_token_tier(sym) == 2, f"{sym} should be tier 2" + + def test_tier2_empty_string(self): + assert classify_token_tier("") == 2 + + def test_wrapped_variant_normalisation(self): + # "wS" is in the mapping table AND in _TIER_0 + assert classify_token_tier("wS") == 0 + assert classify_token_tier(" wS ") == 0 diff --git a/tests/pools/reCLAMM/test_reclamm_noise_volume.py b/tests/pools/reCLAMM/test_reclamm_noise_volume.py new file mode 100644 index 0000000..f413a9e --- /dev/null +++ b/tests/pools/reCLAMM/test_reclamm_noise_volume.py @@ -0,0 +1,991 @@ +"""Tests for reClAMM Tsoukalas noise volume model. + +Tests noise volume functions (sqrt and log variants), volatility computation, +scan step integration, pool class plumbing, and OLS calibration. +""" + +import pytest +import jax.numpy as jnp +import numpy as np +import numpy.testing as npt + +from quantammsim.pools.noise_trades import ( + reclamm_tsoukalas_sqrt_noise_volume, + reclamm_tsoukalas_log_noise_volume, + reclamm_loglinear_noise_volume, +) + +# Typical noise_params for a mid-cap pool. +# a_c=1.5 is roughly equivalent to the old a_c_real=1.0 + a_c_virt=0.5 +# for pools with comparable real and virtual TVL: 1.5/sqrt(2) ≈ 1.06 +# per-component, but we use 1.5 to ensure noise income dominates the +# arb-suppression side-effect in integration tests. +DEFAULT_NOISE_PARAMS = { + "a_0_base": 0.5, + "a_f": 0.0, + "a_sigma": 2.0, + "a_c": 1.5, + "base_fee": 0.003, +} + + +# --------------------------------------------------------------------------- +# Tests 1-7: Unit tests for noise volume functions +# --------------------------------------------------------------------------- + + +class TestPositiveOutputReasonableInputs: + """Test 1: Volume > 0 for typical inputs (both sqrt and log variants).""" + + def test_sqrt_positive_output(self): + vol = reclamm_tsoukalas_sqrt_noise_volume( + effective_value_usd=15_000_000.0, + gamma=0.997, + volatility=0.5, + arb_volume_this_period=0.0, + noise_params=DEFAULT_NOISE_PARAMS, + ) + assert float(vol) > 0, f"Expected positive noise volume, got {float(vol)}" + + def test_log_positive_output(self): + vol = reclamm_tsoukalas_log_noise_volume( + effective_value_usd=15_000_000.0, + gamma=0.997, + volatility=0.5, + arb_volume_this_period=0.0, + noise_params=DEFAULT_NOISE_PARAMS, + ) + assert float(vol) > 0, f"Expected positive noise volume, got {float(vol)}" + + +class TestZeroWhenArbExceedsPredicted: + """Test 2: noise = max(0, daily/1440 - arb), so returns 0 when arb dominates.""" + + def test_sqrt_zero_when_arb_large(self): + vol = reclamm_tsoukalas_sqrt_noise_volume( + effective_value_usd=3_000_000.0, + gamma=0.997, + volatility=0.3, + arb_volume_this_period=1e12, # Absurdly large arb + noise_params=DEFAULT_NOISE_PARAMS, + ) + assert float(vol) == 0.0, f"Expected zero noise volume, got {float(vol)}" + + def test_log_zero_when_arb_large(self): + vol = reclamm_tsoukalas_log_noise_volume( + effective_value_usd=3_000_000.0, + gamma=0.997, + volatility=0.3, + arb_volume_this_period=1e12, + noise_params=DEFAULT_NOISE_PARAMS, + ) + assert float(vol) == 0.0, f"Expected zero noise volume, got {float(vol)}" + + +class TestMonotonicInEffectiveTVL: + """Test 3: Higher effective TVL -> more predicted volume.""" + + def test_sqrt_monotonic_effective_tvl(self): + kwargs = dict( + gamma=0.997, volatility=0.5, arb_volume_this_period=0.0, + noise_params=DEFAULT_NOISE_PARAMS, + ) + vol_low = reclamm_tsoukalas_sqrt_noise_volume(effective_value_usd=3_000_000.0, **kwargs) + vol_high = reclamm_tsoukalas_sqrt_noise_volume(effective_value_usd=20_000_000.0, **kwargs) + assert float(vol_high) > float(vol_low) + + def test_log_monotonic_effective_tvl(self): + kwargs = dict( + gamma=0.997, volatility=0.5, arb_volume_this_period=0.0, + noise_params=DEFAULT_NOISE_PARAMS, + ) + vol_low = reclamm_tsoukalas_log_noise_volume(effective_value_usd=3_000_000.0, **kwargs) + vol_high = reclamm_tsoukalas_log_noise_volume(effective_value_usd=20_000_000.0, **kwargs) + assert float(vol_high) > float(vol_low) + + +class TestMonotonicInVolatility: + """Test 4: Higher volatility -> more predicted volume.""" + + def test_sqrt_monotonic_volatility(self): + kwargs = dict( + effective_value_usd=10_000_000.0, + gamma=0.997, arb_volume_this_period=0.0, + noise_params=DEFAULT_NOISE_PARAMS, + ) + vol_low = reclamm_tsoukalas_sqrt_noise_volume(volatility=0.2, **kwargs) + vol_high = reclamm_tsoukalas_sqrt_noise_volume(volatility=0.8, **kwargs) + assert float(vol_high) > float(vol_low) + + def test_log_monotonic_volatility(self): + kwargs = dict( + effective_value_usd=10_000_000.0, + gamma=0.997, arb_volume_this_period=0.0, + noise_params=DEFAULT_NOISE_PARAMS, + ) + vol_low = reclamm_tsoukalas_log_noise_volume(volatility=0.2, **kwargs) + vol_high = reclamm_tsoukalas_log_noise_volume(volatility=0.8, **kwargs) + assert float(vol_high) > float(vol_low) + + +class TestEffectiveTVLSensitivity: + """Test 5: Changing effective TVL changes output.""" + + def test_sqrt_tvl_sensitivity(self): + base = dict( + gamma=0.997, volatility=0.5, arb_volume_this_period=0.0, + noise_params=DEFAULT_NOISE_PARAMS, + ) + v1 = reclamm_tsoukalas_sqrt_noise_volume( + effective_value_usd=7_000_000.0, **base) + v2 = reclamm_tsoukalas_sqrt_noise_volume( + effective_value_usd=13_000_000.0, **base) + assert float(v1) != float(v2), "Effective TVL change should affect output" + + def test_log_tvl_sensitivity(self): + base = dict( + gamma=0.997, volatility=0.5, arb_volume_this_period=0.0, + noise_params=DEFAULT_NOISE_PARAMS, + ) + v1 = reclamm_tsoukalas_log_noise_volume( + effective_value_usd=7_000_000.0, **base) + v2 = reclamm_tsoukalas_log_noise_volume( + effective_value_usd=13_000_000.0, **base) + assert float(v1) != float(v2), "Effective TVL change should affect output" + + +class TestCustomParamsOverrideDefaults: + """Test 6: noise_params dict values are actually used.""" + + def test_sqrt_custom_params(self): + # With a_sigma=0, volatility shouldn't matter + zero_sigma_params = {**DEFAULT_NOISE_PARAMS, "a_sigma": 0.0} + v1 = reclamm_tsoukalas_sqrt_noise_volume( + effective_value_usd=10_000_000.0, + gamma=0.997, volatility=0.2, arb_volume_this_period=0.0, + noise_params=zero_sigma_params, + ) + v2 = reclamm_tsoukalas_sqrt_noise_volume( + effective_value_usd=10_000_000.0, + gamma=0.997, volatility=0.8, arb_volume_this_period=0.0, + noise_params=zero_sigma_params, + ) + npt.assert_allclose(float(v1), float(v2), rtol=1e-10, + err_msg="With a_sigma=0, volatility should not affect output") + + def test_log_custom_params(self): + zero_sigma_params = {**DEFAULT_NOISE_PARAMS, "a_sigma": 0.0} + v1 = reclamm_tsoukalas_log_noise_volume( + effective_value_usd=10_000_000.0, + gamma=0.997, volatility=0.2, arb_volume_this_period=0.0, + noise_params=zero_sigma_params, + ) + v2 = reclamm_tsoukalas_log_noise_volume( + effective_value_usd=10_000_000.0, + gamma=0.997, volatility=0.8, arb_volume_this_period=0.0, + noise_params=zero_sigma_params, + ) + npt.assert_allclose(float(v1), float(v2), rtol=1e-10, + err_msg="With a_sigma=0, volatility should not affect output") + + +class TestCalculateVolatilityArray: + """Test calculate_volatility_array on the base pool class (JIT'd, pure JAX).""" + + def _make_run_fp(self): + from quantammsim.runners.jax_runner_utils import Hashabledict + return Hashabledict({"tokens": ("ETH", "USDC"), "numeraire": "USDC"}) + + def _make_pool(self): + from quantammsim.pools.creator import create_pool + return create_pool("reclamm") + + def test_output_shape_matches_input(self): + """Volatility array length matches input price length.""" + pool = self._make_pool() + n_minutes = 1440 * 3 # 3 days + rng = np.random.default_rng(42) + log_rets = rng.normal(0, 0.001, (n_minutes, 2)) + prices = jnp.array( + np.exp(np.cumsum(log_rets, axis=0)) * np.array([2500.0, 1.0]) + ) + vol_array = pool.calculate_volatility_array(prices, self._make_run_fp()) + assert vol_array.shape == (n_minutes,), ( + f"Expected shape ({n_minutes},), got {vol_array.shape}" + ) + + def test_constant_prices_zero_vol(self): + """Constant prices should give zero volatility.""" + pool = self._make_pool() + n_minutes = 1440 * 2 + prices = jnp.tile(jnp.array([2500.0, 1.0]), (n_minutes, 1)) + vol_array = pool.calculate_volatility_array(prices, self._make_run_fp()) + npt.assert_allclose(np.array(vol_array), 0.0, atol=1e-10) + + def test_volatile_prices_positive_vol(self): + """Volatile prices should give positive volatility.""" + pool = self._make_pool() + n_minutes = 1440 * 2 + rng = np.random.default_rng(123) + log_rets = rng.normal(0, 0.01, n_minutes) + price_ratio = np.exp(np.cumsum(log_rets)) + prices = jnp.array( + np.column_stack([price_ratio * 2500.0, np.ones(n_minutes)]) + ) + vol_array = pool.calculate_volatility_array(prices, self._make_run_fp()) + assert float(jnp.mean(vol_array)) > 0, "Volatile prices should give positive vol" + + def test_partial_last_day_handled(self): + """Non-multiple-of-1440: correct shape and partial-day fill uses last day's vol.""" + pool = self._make_pool() + n_full_days = 2 + n_partial = 500 + n_minutes = 1440 * n_full_days + n_partial + + # Use volatile prices so daily vol is nonzero + rng = np.random.default_rng(77) + log_rets = rng.normal(0, 0.005, n_minutes) + price_ratio = np.exp(np.cumsum(log_rets)) + prices = jnp.array( + np.column_stack([price_ratio * 2500.0, np.ones(n_minutes)]) + ) + vol_array = pool.calculate_volatility_array(prices, self._make_run_fp()) + + assert vol_array.shape == (n_minutes,) + + # The partial-day region (last 500 minutes) should be filled + # with the last full day's volatility value + last_full_day_vol = vol_array[n_full_days * 1440 - 1] + partial_region = vol_array[n_full_days * 1440:] + npt.assert_allclose( + np.array(partial_region), + float(last_full_day_vol), + rtol=1e-10, + err_msg="Partial-day region should be filled with last full day's vol", + ) + # And the fill value should be nonzero (volatile prices) + assert float(last_full_day_vol) > 0, "Expected nonzero vol from volatile prices" + + +# --------------------------------------------------------------------------- +# Tests 8-12: Scan step integration tests +# --------------------------------------------------------------------------- + +from quantammsim.pools.reCLAMM.reclamm_reserves import ( + initialise_reclamm_reserves, + _jax_calc_reclamm_reserves_with_fees, + _jax_calc_reclamm_reserves_and_fee_revenue_with_fees, + _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs, +) + +ALL_SIG_VARIATIONS_2 = jnp.array([[1, -1], [-1, 1]]) + +# Pool config shared by integration tests +_CM = 0.2 # centeredness_margin +_DPSB = 1.0 - 1.0 / 124000.0 # daily_price_shift_base +_SPP = 60.0 # seconds_per_step (1-min arb) +_FEES = 0.003 +_PRICE_RATIO = 4.0 +_POOL_VALUE = 1_000_000.0 + + +def _init_pool(pool_value=_POOL_VALUE, price_a=2500.0, price_b=1.0, + price_ratio=_PRICE_RATIO): + initial_prices = jnp.array([price_a, price_b]) + reserves, Va, Vb = initialise_reclamm_reserves(pool_value, initial_prices, price_ratio) + return reserves, Va, Vb + + +def _make_trending_prices(start_a, end_a, price_b, n_steps): + prices_a = jnp.linspace(start_a, end_a, n_steps) + prices_b = jnp.full(n_steps, price_b) + return jnp.stack([prices_a, prices_b], axis=1) + + +class TestRatioBackwardCompatible: + """Test 8: noise_model='ratio' matches existing noise_trader_ratio path.""" + + def test_ratio_model_matches_legacy(self): + reserves, Va, Vb = _init_pool() + n_steps = 50 + prices = _make_trending_prices(2500.0, 3500.0, 1.0, n_steps) + + # Legacy path: just noise_trader_ratio + res_legacy = _jax_calc_reclamm_reserves_with_fees( + reserves, Va, Vb, prices, _CM, _DPSB, _SPP, + fees=_FEES, all_sig_variations=ALL_SIG_VARIATIONS_2, + noise_trader_ratio=1.5, + ) + + # New path: noise_model="ratio" (default) + res_new = _jax_calc_reclamm_reserves_with_fees( + reserves, Va, Vb, prices, _CM, _DPSB, _SPP, + fees=_FEES, all_sig_variations=ALL_SIG_VARIATIONS_2, + noise_trader_ratio=1.5, + noise_model="ratio", + ) + npt.assert_array_equal(res_legacy, res_new) + + +class TestArbOnlyEqualsZeroRatio: + """Test 9: noise_model='arb_only' same as noise_trader_ratio=0.""" + + def test_arb_only_matches_zero_ratio(self): + reserves, Va, Vb = _init_pool() + n_steps = 50 + prices = _make_trending_prices(2500.0, 3500.0, 1.0, n_steps) + + res_zero = _jax_calc_reclamm_reserves_with_fees( + reserves, Va, Vb, prices, _CM, _DPSB, _SPP, + fees=_FEES, all_sig_variations=ALL_SIG_VARIATIONS_2, + noise_trader_ratio=0.0, + ) + res_arb_only = _jax_calc_reclamm_reserves_with_fees( + reserves, Va, Vb, prices, _CM, _DPSB, _SPP, + fees=_FEES, all_sig_variations=ALL_SIG_VARIATIONS_2, + noise_model="arb_only", + ) + npt.assert_array_equal(res_zero, res_arb_only) + + +class TestTsoukalasSqrtIncreasesReserves: + """Test 10: Tsoukalas noise income grows real TVL vs arb-only.""" + + def test_sqrt_reserves_grow(self): + reserves, Va, Vb = _init_pool() + n_steps = 50 + prices = _make_trending_prices(2500.0, 3500.0, 1.0, n_steps) + vol_array = jnp.full(n_steps, 0.5) # Synthetic constant volatility + + res_arb_only = _jax_calc_reclamm_reserves_with_fees( + reserves, Va, Vb, prices, _CM, _DPSB, _SPP, + fees=_FEES, all_sig_variations=ALL_SIG_VARIATIONS_2, + noise_model="arb_only", + ) + res_tsoukalas = _jax_calc_reclamm_reserves_with_fees( + reserves, Va, Vb, prices, _CM, _DPSB, _SPP, + fees=_FEES, all_sig_variations=ALL_SIG_VARIATIONS_2, + noise_model="tsoukalas_sqrt", + noise_params=DEFAULT_NOISE_PARAMS, + volatility_array=vol_array, + ) + # Noise fee income should make total real value strictly greater than arb-only + val_arb = float(jnp.sum(res_arb_only[-1] * prices[-1])) + val_tsoukalas = float(jnp.sum(res_tsoukalas[-1] * prices[-1])) + assert val_tsoukalas > val_arb, ( + f"Tsoukalas reserves ({val_tsoukalas:.2f}) should be strictly > " + f"arb-only ({val_arb:.2f})" + ) + + +class TestTsoukalasDoesNotAffectVirtualBalances: + """Test 11: Within a single scan step, noise modifies real reserves + but does NOT modify Va/Vb in the carry.""" + + def test_single_step_virtual_balances_identical(self): + """Call the scan step directly for one step with arb_only and tsoukalas_sqrt. + Assert that the carry's Va and Vb are bitwise identical.""" + from quantammsim.pools.reCLAMM.reclamm_reserves import ( + _reclamm_scan_step_with_fees_and_revenue, + ) + from quantammsim.pools.G3M.optimal_n_pool_arb import ( + precalc_shared_values_for_all_signatures, + precalc_components_of_optimal_trade_across_prices, + ) + + reserves, Va, Vb = _init_pool() + # Small price shift so arb volume is small relative to predicted noise + # (a 2500→3000 jump produces arb > predicted noise/min, zeroing noise_vol) + prices_1 = jnp.array([[2510.0, 1.0]]) + + weights = jnp.array([0.5, 0.5]) + gamma = 1.0 - _FEES + + _, active_trade_dirs, tokens_to_drop, leave_one_out_idxs = ( + precalc_shared_values_for_all_signatures(ALL_SIG_VARIATIONS_2, 2) + ) + aiw, par, aoar = precalc_components_of_optimal_trade_across_prices( + weights, prices_1, gamma, tokens_to_drop, + active_trade_dirs, leave_one_out_idxs, + ) + + carry = [ + reserves, Va, Vb, + jnp.float64(1.0), # prev_lp_supply + jnp.float64(0.0), # step_idx + jnp.float64(0.0), # active_start_ratio + jnp.float64(0.0), # active_target_ratio + jnp.float64(0.0), # active_start_step + jnp.float64(0.0), # active_end_step + jnp.array(False), # active_enabled + ] + + def run_step(noise_model, noise_params=None): + input_list = [ + prices_1[0], aiw[0], par[0], aoar[0], + jnp.float64(gamma), jnp.float64(0.0), + jnp.float64(0.0), + jnp.array([0.0, 0.0, 0.0]), # price_ratio_update (no-op) + jnp.float64(1.0), # lp_supply + ] + if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + input_list.append(jnp.float64(0.5)) # volatility + + return _reclamm_scan_step_with_fees_and_revenue( + carry, input_list, + weights=weights, + tokens_to_drop=tokens_to_drop, + active_trade_directions=active_trade_dirs, + n=2, + centeredness_margin=_CM, + daily_price_shift_base=_DPSB, + seconds_per_step=_SPP, + noise_model=noise_model, + noise_params=noise_params if noise_params is not None else {}, + ) + + carry_arb, (res_arb, _) = run_step("arb_only") + carry_tsoukalas, (res_tsoukalas, _) = run_step( + "tsoukalas_sqrt", DEFAULT_NOISE_PARAMS, + ) + + # Va and Vb in carry must be bitwise identical + npt.assert_array_equal(carry_arb[1], carry_tsoukalas[1], + err_msg="Va should be unaffected by noise model") + npt.assert_array_equal(carry_arb[2], carry_tsoukalas[2], + err_msg="Vb should be unaffected by noise model") + + # But real reserves SHOULD differ (noise adds fee income). + # The noise effect is small relative to reserve magnitude, so use + # exact bitwise comparison rather than allclose (whose default + # rtol=1e-5 would mask the difference). + assert not jnp.array_equal(res_arb, res_tsoukalas), ( + "Real reserves should differ between arb_only and tsoukalas_sqrt" + ) + + # Same invariant for loglinear path + loglinear_params = {"b_0": -1.4, "b_sigma": 0.1, "b_c": 1.04} + carry_loglinear, (res_loglinear, _) = run_step( + "loglinear", loglinear_params, + ) + npt.assert_array_equal(carry_arb[1], carry_loglinear[1], + err_msg="Va should be unaffected by loglinear noise model") + npt.assert_array_equal(carry_arb[2], carry_loglinear[2], + err_msg="Vb should be unaffected by loglinear noise model") + assert not jnp.array_equal(res_arb, res_loglinear), ( + "Real reserves should differ between arb_only and loglinear" + ) + + +class TestTsoukalasWithFeeRevenue: + """Test 12: Fee revenue includes noise contribution.""" + + def test_fee_revenue_includes_noise(self): + reserves, Va, Vb = _init_pool() + n_steps = 50 + prices = _make_trending_prices(2500.0, 3500.0, 1.0, n_steps) + vol_array = jnp.full(n_steps, 0.5) # Synthetic constant volatility + + _, fee_rev_arb = _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( + reserves, Va, Vb, prices, _CM, _DPSB, _SPP, + fees=_FEES, all_sig_variations=ALL_SIG_VARIATIONS_2, + noise_model="arb_only", + ) + _, fee_rev_tsoukalas = _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( + reserves, Va, Vb, prices, _CM, _DPSB, _SPP, + fees=_FEES, all_sig_variations=ALL_SIG_VARIATIONS_2, + noise_model="tsoukalas_sqrt", + noise_params=DEFAULT_NOISE_PARAMS, + volatility_array=vol_array, + ) + # Tsoukalas should generate strictly more fee revenue due to noise volume + total_arb = float(fee_rev_arb.sum()) + total_tsoukalas = float(fee_rev_tsoukalas.sum()) + assert total_tsoukalas > total_arb, ( + f"Tsoukalas fee revenue ({total_tsoukalas:.4f}) should exceed " + f"arb-only ({total_arb:.4f})" + ) + + +# --------------------------------------------------------------------------- +# Tests 13-14: Pool class integration tests +# --------------------------------------------------------------------------- + + +class TestNoiseModelFromFingerprint: + """Test 13: Pool reads noise_model from fingerprint.""" + + def test_tsoukalas_sqrt_from_fingerprint(self): + from quantammsim.pools.creator import create_pool + from quantammsim.runners.jax_runner_utils import Hashabledict + + pool = create_pool("reclamm") + + params = { + "price_ratio": _PRICE_RATIO, + "centeredness_margin": _CM, + "daily_price_shift_base": _DPSB, + } + + n_steps = 50 + np.random.seed(42) + price_a = 2500.0 * np.exp(np.cumsum(np.random.normal(0, 0.01, n_steps))) + prices = jnp.stack([jnp.array(price_a), jnp.ones(n_steps)], axis=1) + + # Fingerprint with Tsoukalas noise model + run_fingerprint_tsoukalas = Hashabledict({ + "n_assets": 2, + "bout_length": n_steps + 1, + "initial_pool_value": _POOL_VALUE, + "arb_frequency": 1, + "do_arb": True, + "fees": _FEES, + "gas_cost": 0.0, + "arb_fees": 0.0, + "tokens": ("ETH", "USDC"), + "numeraire": "USDC", + "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + "noise_model": "tsoukalas_sqrt", + "reclamm_noise_params": DEFAULT_NOISE_PARAMS, + }) + + # Fingerprint without noise + run_fingerprint_arb_only = Hashabledict({ + "n_assets": 2, + "bout_length": n_steps + 1, + "initial_pool_value": _POOL_VALUE, + "arb_frequency": 1, + "do_arb": True, + "fees": _FEES, + "gas_cost": 0.0, + "arb_fees": 0.0, + "tokens": ("ETH", "USDC"), + "numeraire": "USDC", + "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + "noise_model": "arb_only", + }) + + start_index = jnp.array([0, 0]) + + res_tsoukalas, fee_rev_tsoukalas = pool.calculate_reserves_and_fee_revenue_with_fees( + params, run_fingerprint_tsoukalas, prices, start_index, + ) + res_arb, fee_rev_arb = pool.calculate_reserves_and_fee_revenue_with_fees( + params, run_fingerprint_arb_only, prices, start_index, + ) + + assert res_tsoukalas.shape == (n_steps, 2) + assert fee_rev_tsoukalas.shape == (n_steps,) + # Tsoukalas should produce more fee revenue + assert float(fee_rev_tsoukalas.sum()) > float(fee_rev_arb.sum()) + + +class TestVolatilityComputedForTsoukalas: + """Test 14: Volatility array auto-computed when noise_model is tsoukalas_*. + + Uses >= 1440 minutes so the real vmap+dynamic_slice path is exercised + (not just the <1440 fallback). Compares against arb_only to verify the + auto-computed volatility feeds through to meaningfully different fee revenue. + """ + + def test_volatility_auto_computed_affects_fee_revenue(self): + from quantammsim.pools.creator import create_pool + from quantammsim.runners.jax_runner_utils import Hashabledict + + pool = create_pool("reclamm") + + params = { + "price_ratio": _PRICE_RATIO, + "centeredness_margin": _CM, + "daily_price_shift_base": _DPSB, + } + + # Need at least 1 day of data for real volatility computation + n_steps = 1440 + 100 # Just over 1 day + np.random.seed(42) + price_a = 2500.0 * np.exp(np.cumsum(np.random.normal(0, 0.001, n_steps))) + prices = jnp.stack([jnp.array(price_a), jnp.ones(n_steps)], axis=1) + + base_fp = { + "n_assets": 2, + "bout_length": n_steps + 1, + "initial_pool_value": _POOL_VALUE, + "arb_frequency": 1, + "do_arb": True, + "fees": _FEES, + "gas_cost": 0.0, + "arb_fees": 0.0, + "tokens": ("ETH", "USDC"), + "numeraire": "USDC", + "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + } + + fp_tsoukalas = Hashabledict({ + **base_fp, + "noise_model": "tsoukalas_sqrt", + "reclamm_noise_params": DEFAULT_NOISE_PARAMS, + }) + fp_arb_only = Hashabledict({ + **base_fp, + "noise_model": "arb_only", + }) + + start_index = jnp.array([0, 0]) + + res_tsoukalas, fee_rev_tsoukalas = pool.calculate_reserves_and_fee_revenue_with_fees( + params, fp_tsoukalas, prices, start_index, + ) + _, fee_rev_arb = pool.calculate_reserves_and_fee_revenue_with_fees( + params, fp_arb_only, prices, start_index, + ) + + assert res_tsoukalas.shape == (n_steps, 2) + assert fee_rev_tsoukalas.shape == (n_steps,) + + # The auto-computed volatility must feed through to produce + # strictly more fee revenue than arb-only + total_tsoukalas = float(fee_rev_tsoukalas.sum()) + total_arb = float(fee_rev_arb.sum()) + assert total_tsoukalas > total_arb, ( + f"Tsoukalas with auto-computed volatility ({total_tsoukalas:.4f}) " + f"should exceed arb-only ({total_arb:.4f})" + ) + + +# --------------------------------------------------------------------------- +# Tests 15-16: Calibration pipeline tests +# --------------------------------------------------------------------------- + +import pandas as pd +from scripts.calibrate_reclamm_noise import run_ols_calibration + + +class TestOLSRecoversKnownParams: + """Test 15: Synthetic data with known coefficients -> OLS recovers them.""" + + def test_ols_recovery_sqrt(self): + rng = np.random.default_rng(42) + n = 200 + + true_a_0 = 0.8 + true_a_sigma = 1.5 + true_a_c = 0.6 + + vol = rng.uniform(0.2, 1.0, n) + eff_tvl = rng.uniform(3e6, 50e6, n) + + # Construct volume from known params (in $M units) + volume_M = ( + true_a_0 + + true_a_sigma * vol + + true_a_c * np.sqrt(eff_tvl / 1e6) + ) + # Add small noise + volume_M += rng.normal(0, 0.01, n) + volume_usd = volume_M * 1e6 + + df = pd.DataFrame({ + "volume_usd": volume_usd, + "volatility": vol, + "effective_tvl_usd": eff_tvl, + }) + + noise_params, diagnostics = run_ols_calibration(df, base_fee=0.003, model="sqrt") + + npt.assert_allclose(noise_params["a_0_base"], true_a_0, atol=0.05) + npt.assert_allclose(noise_params["a_sigma"], true_a_sigma, atol=0.05) + npt.assert_allclose(noise_params["a_c"], true_a_c, atol=0.05) + assert diagnostics["r_squared"] > 0.99 + + def test_ols_recovery_log(self): + rng = np.random.default_rng(123) + n = 200 + + true_a_0 = 0.5 + true_a_sigma = 2.0 + true_a_c = 0.4 + + vol = rng.uniform(0.2, 1.0, n) + eff_tvl = rng.uniform(3e6, 50e6, n) + + volume_M = ( + true_a_0 + + true_a_sigma * vol + + true_a_c * np.log(eff_tvl / 1e6) + ) + volume_M += rng.normal(0, 0.01, n) + volume_usd = volume_M * 1e6 + + df = pd.DataFrame({ + "volume_usd": volume_usd, + "volatility": vol, + "effective_tvl_usd": eff_tvl, + }) + + noise_params, diagnostics = run_ols_calibration(df, base_fee=0.003, model="log") + + npt.assert_allclose(noise_params["a_0_base"], true_a_0, atol=0.05) + npt.assert_allclose(noise_params["a_sigma"], true_a_sigma, atol=0.05) + npt.assert_allclose(noise_params["a_c"], true_a_c, atol=0.05) + assert diagnostics["r_squared"] > 0.99 + + +class TestOutputFormatCompatible: + """Test 16: Output dict has all required keys for run_fingerprint integration.""" + + def test_output_keys(self): + rng = np.random.default_rng(99) + n = 50 + df = pd.DataFrame({ + "volume_usd": rng.uniform(1e6, 10e6, n), + "volatility": rng.uniform(0.2, 0.8, n), + "effective_tvl_usd": rng.uniform(3e6, 25e6, n), + }) + + noise_params, diagnostics = run_ols_calibration(df, base_fee=0.003) + + required_keys = {"a_0_base", "a_f", "a_sigma", "a_c", "base_fee"} + assert set(noise_params.keys()) == required_keys + + # All values are float + for k, v in noise_params.items(): + assert isinstance(v, float), f"{k} should be float, got {type(v)}" + + # a_f should be 0 for static fees + assert noise_params["a_f"] == 0.0 + + # Diagnostics should have standard errors + assert "se" in diagnostics + assert set(diagnostics["se"].keys()) == {"a_0", "a_sigma", "a_c"} + + # Can be used directly as noise_params for the noise functions + vol = reclamm_tsoukalas_sqrt_noise_volume( + effective_value_usd=15e6, + gamma=0.997, volatility=0.5, arb_volume_this_period=0.0, + noise_params=noise_params, + ) + assert jnp.isfinite(vol) + + +# --------------------------------------------------------------------------- +# Tests 17-24: Loglinear (hierarchical) noise volume model +# --------------------------------------------------------------------------- + +# Typical noise_params from the hierarchical model +LOGLINEAR_NOISE_PARAMS = { + "b_0": -7.1, # grand mean + BLUP + "b_sigma": -0.003, # shared volatility effect + "b_c": 1.04, # shared TVL elasticity + "base_fee": 0.003, +} + + +class TestLoglinearPositiveOutput: + """Test 17: Volume > 0 for typical inputs.""" + + def test_loglinear_positive_output(self): + vol = reclamm_loglinear_noise_volume( + effective_value_usd=15_000_000.0, + gamma=0.997, + volatility=0.5, + arb_volume_this_period=0.0, + noise_params=LOGLINEAR_NOISE_PARAMS, + ) + assert float(vol) > 0, f"Expected positive noise volume, got {float(vol)}" + + +class TestLoglinearZeroWhenArbDominates: + """Test 18: noise = max(0, ...) so returns 0 when arb dominates.""" + + def test_loglinear_zero_when_arb_large(self): + vol = reclamm_loglinear_noise_volume( + effective_value_usd=3_000_000.0, + gamma=0.997, + volatility=0.3, + arb_volume_this_period=1e12, + noise_params=LOGLINEAR_NOISE_PARAMS, + ) + assert float(vol) == 0.0, f"Expected zero, got {float(vol)}" + + +class TestLoglinearMonotonic: + """Test 19: Higher effective TVL -> more predicted volume.""" + + def test_loglinear_monotonic_tvl(self): + kwargs = dict( + gamma=0.997, volatility=0.5, arb_volume_this_period=0.0, + noise_params=LOGLINEAR_NOISE_PARAMS, + ) + vol_low = reclamm_loglinear_noise_volume( + effective_value_usd=3_000_000.0, **kwargs) + vol_high = reclamm_loglinear_noise_volume( + effective_value_usd=20_000_000.0, **kwargs) + assert float(vol_high) > float(vol_low) + + +class TestLoglinearCustomParams: + """Test 20: noise_params dict values are actually used.""" + + def test_loglinear_custom_b_sigma(self): + # With b_sigma=0, volatility shouldn't matter + zero_sigma_params = {**LOGLINEAR_NOISE_PARAMS, "b_sigma": 0.0} + v1 = reclamm_loglinear_noise_volume( + effective_value_usd=10_000_000.0, + gamma=0.997, volatility=0.2, arb_volume_this_period=0.0, + noise_params=zero_sigma_params, + ) + v2 = reclamm_loglinear_noise_volume( + effective_value_usd=10_000_000.0, + gamma=0.997, volatility=0.8, arb_volume_this_period=0.0, + noise_params=zero_sigma_params, + ) + npt.assert_allclose(float(v1), float(v2), rtol=1e-10, + err_msg="With b_sigma=0, volatility should not affect output") + + +class TestLoglinearScanStepIntegration: + """Test 21: loglinear noise model works through the scan step.""" + + def test_loglinear_increases_reserves(self): + reserves, Va, Vb = _init_pool() + n_steps = 50 + prices = _make_trending_prices(2500.0, 3500.0, 1.0, n_steps) + vol_array = jnp.full(n_steps, 0.5) + + # Use a b_0 that gives reasonable volume at this TVL + # Pool TVL ~$1M → log(1e6) ≈ 13.8 → b_0 + 1.04*13.8 = b_0 + 14.4 + # Want log(V_daily) ≈ 13 (= ~$440k/day) → b_0 ≈ -1.4 + params = {"b_0": -1.4, "b_sigma": 0.1, "b_c": 1.04, "base_fee": 0.003} + + res_arb_only = _jax_calc_reclamm_reserves_with_fees( + reserves, Va, Vb, prices, _CM, _DPSB, _SPP, + fees=_FEES, all_sig_variations=ALL_SIG_VARIATIONS_2, + noise_model="arb_only", + ) + res_loglinear = _jax_calc_reclamm_reserves_with_fees( + reserves, Va, Vb, prices, _CM, _DPSB, _SPP, + fees=_FEES, all_sig_variations=ALL_SIG_VARIATIONS_2, + noise_model="loglinear", + noise_params=params, + volatility_array=vol_array, + ) + val_arb = float(jnp.sum(res_arb_only[-1] * prices[-1])) + val_loglinear = float(jnp.sum(res_loglinear[-1] * prices[-1])) + assert val_loglinear > val_arb, ( + f"Loglinear reserves ({val_loglinear:.2f}) should exceed " + f"arb-only ({val_arb:.2f})" + ) + + +class TestLoglinearFeeRevenue: + """Test 22: Fee revenue includes loglinear noise contribution.""" + + def test_loglinear_fee_revenue(self): + reserves, Va, Vb = _init_pool() + n_steps = 50 + prices = _make_trending_prices(2500.0, 3500.0, 1.0, n_steps) + vol_array = jnp.full(n_steps, 0.5) + + params = {"b_0": -1.4, "b_sigma": 0.1, "b_c": 1.04, "base_fee": 0.003} + + _, fee_rev_arb = _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( + reserves, Va, Vb, prices, _CM, _DPSB, _SPP, + fees=_FEES, all_sig_variations=ALL_SIG_VARIATIONS_2, + noise_model="arb_only", + ) + _, fee_rev_loglinear = _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( + reserves, Va, Vb, prices, _CM, _DPSB, _SPP, + fees=_FEES, all_sig_variations=ALL_SIG_VARIATIONS_2, + noise_model="loglinear", + noise_params=params, + volatility_array=vol_array, + ) + total_arb = float(fee_rev_arb.sum()) + total_loglinear = float(fee_rev_loglinear.sum()) + assert total_loglinear > total_arb, ( + f"Loglinear fee revenue ({total_loglinear:.4f}) should exceed " + f"arb-only ({total_arb:.4f})" + ) + + +class TestLoglinearPoolClassIntegration: + """Test 23: Pool reads loglinear noise_model from fingerprint.""" + + def test_loglinear_from_fingerprint(self): + from quantammsim.pools.creator import create_pool + from quantammsim.runners.jax_runner_utils import Hashabledict + + pool = create_pool("reclamm") + + params = { + "price_ratio": _PRICE_RATIO, + "centeredness_margin": _CM, + "daily_price_shift_base": _DPSB, + } + + n_steps = 50 + np.random.seed(42) + price_a = 2500.0 * np.exp(np.cumsum(np.random.normal(0, 0.01, n_steps))) + prices = jnp.stack([jnp.array(price_a), jnp.ones(n_steps)], axis=1) + + loglinear_params = { + "b_0": -1.4, "b_sigma": 0.1, "b_c": 1.04, "base_fee": 0.003, + } + + fp_loglinear = Hashabledict({ + "n_assets": 2, + "bout_length": n_steps + 1, + "initial_pool_value": _POOL_VALUE, + "arb_frequency": 1, + "do_arb": True, + "fees": _FEES, + "gas_cost": 0.0, + "arb_fees": 0.0, + "tokens": ("ETH", "USDC"), + "numeraire": "USDC", + "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + "noise_model": "loglinear", + "reclamm_noise_params": loglinear_params, + }) + + fp_arb_only = Hashabledict({ + "n_assets": 2, + "bout_length": n_steps + 1, + "initial_pool_value": _POOL_VALUE, + "arb_frequency": 1, + "do_arb": True, + "fees": _FEES, + "gas_cost": 0.0, + "arb_fees": 0.0, + "tokens": ("ETH", "USDC"), + "numeraire": "USDC", + "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + "noise_model": "arb_only", + }) + + start_index = jnp.array([0, 0]) + + res_loglinear, fee_rev_loglinear = pool.calculate_reserves_and_fee_revenue_with_fees( + params, fp_loglinear, prices, start_index, + ) + _, fee_rev_arb = pool.calculate_reserves_and_fee_revenue_with_fees( + params, fp_arb_only, prices, start_index, + ) + + assert res_loglinear.shape == (n_steps, 2) + assert fee_rev_loglinear.shape == (n_steps,) + assert float(fee_rev_loglinear.sum()) > float(fee_rev_arb.sum()) + + +class TestLoglinearDefaultParams: + """Test 24: Function works with default params (noise_params=None).""" + + def test_loglinear_defaults(self): + vol = reclamm_loglinear_noise_volume( + effective_value_usd=15_000_000.0, + gamma=0.997, + volatility=0.5, + arb_volume_this_period=0.0, + ) + assert jnp.isfinite(vol) + assert float(vol) > 0 diff --git a/tests/pools/reCLAMM/test_reclamm_reserves.py b/tests/pools/reCLAMM/test_reclamm_reserves.py index 1db40cb..c795176 100644 --- a/tests/pools/reCLAMM/test_reclamm_reserves.py +++ b/tests/pools/reCLAMM/test_reclamm_reserves.py @@ -13,12 +13,14 @@ from quantammsim.pools.reCLAMM.reclamm_reserves import ( compute_invariant, compute_price_ratio, + compute_centeredness, initialise_reclamm_reserves, calibrate_arc_length_speed, _jax_calc_reclamm_reserves_zero_fees, + _jax_calc_reclamm_reserves_zero_fees_full_state, _jax_calc_reclamm_reserves_with_fees, + _jax_calc_reclamm_reserves_and_fee_revenue_with_fees, ) -from tests.conftest import TEST_DATA_DIR # For n=2: sig variations with exactly one +1 and one -1 ALL_SIG_VARIATIONS_2 = jnp.array([[1, -1], [-1, 1]]) @@ -816,3 +818,607 @@ def test_train_on_historic_data_optuna(self): } result = train_on_historic_data(fp, verbose=False, root=TEST_DATA_DIR) assert result is not None + + +class TestNoiseTraderRatio: + """Noise trader fee income wiring for reClAMM pools.""" + + def _run_with_noise(self, noise_trader_ratio, n_steps=50, fees=0.003): + """Run reClAMM with fees and return reserves + fee revenue.""" + reserves, Va, Vb = _init_pool() + # Trending prices so arb trades happen and noise trade has non-zero effect + prices = _make_trending_prices(2500.0, 3000.0, 1.0, n_steps) + + return _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( + reserves, Va, Vb, prices, + DEFAULT_CENTEREDNESS_MARGIN, + DEFAULT_DAILY_PRICE_SHIFT_BASE, + DEFAULT_SECONDS_PER_STEP, + fees=fees, + arb_thresh=0.0, + arb_fees=0.0, + all_sig_variations=ALL_SIG_VARIATIONS_2, + noise_trader_ratio=noise_trader_ratio, + ) + + def test_noise_trader_ratio_zero_is_default(self): + """noise_trader_ratio=0.0 should produce identical results to omitting it.""" + reserves, Va, Vb = _init_pool() + prices = _make_trending_prices(2500.0, 3000.0, 1.0, 50) + + # Explicit zero + res_zero, rev_zero = _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( + reserves, Va, Vb, prices, + DEFAULT_CENTEREDNESS_MARGIN, + DEFAULT_DAILY_PRICE_SHIFT_BASE, + DEFAULT_SECONDS_PER_STEP, + fees=0.003, + arb_thresh=0.0, + arb_fees=0.0, + all_sig_variations=ALL_SIG_VARIATIONS_2, + noise_trader_ratio=0.0, + ) + + # Default (omitted) + res_default, rev_default = _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( + reserves, Va, Vb, prices, + DEFAULT_CENTEREDNESS_MARGIN, + DEFAULT_DAILY_PRICE_SHIFT_BASE, + DEFAULT_SECONDS_PER_STEP, + fees=0.003, + arb_thresh=0.0, + arb_fees=0.0, + all_sig_variations=ALL_SIG_VARIATIONS_2, + ) + + npt.assert_array_equal(res_zero, res_default) + npt.assert_array_equal(rev_zero, rev_default) + + def test_noise_trader_ratio_increases_reserves(self): + """Noise traders add fee income, so pool value should be higher.""" + res_no_noise, _ = self._run_with_noise(0.0) + res_noise, _ = self._run_with_noise(0.1) + + # Final pool value in USD (sum of reserves * price at last step) + final_prices = jnp.array([3000.0, 1.0]) + value_no_noise = (res_no_noise[-1] * final_prices).sum() + value_noise = (res_noise[-1] * final_prices).sum() + + assert value_noise > value_no_noise, ( + f"Noise traders should increase pool value: {value_noise} <= {value_no_noise}" + ) + + # Reserves should differ + assert not jnp.allclose(res_no_noise, res_noise), ( + "Reserves should differ with noise traders" + ) + + def test_noise_trader_ratio_through_pool_class(self): + """noise_trader_ratio flows through the pool class methods.""" + from quantammsim.pools.creator import create_pool + from quantammsim.runners.jax_runner_utils import Hashabledict + + pool = create_pool("reclamm") + + params = { + "price_ratio": DEFAULT_PRICE_RATIO, + "centeredness_margin": DEFAULT_CENTEREDNESS_MARGIN, + "daily_price_shift_base": DEFAULT_DAILY_PRICE_SHIFT_BASE, + } + + n_steps = 50 + np.random.seed(42) + price_a = 2500.0 * np.exp(np.cumsum(np.random.normal(0, 0.01, n_steps))) + prices = jnp.stack([jnp.array(price_a), jnp.ones(n_steps)], axis=1) + start_index = jnp.array([0, 0]) + + base_fp = { + "n_assets": 2, + "bout_length": n_steps + 1, + "initial_pool_value": 1_000_000.0, + "arb_frequency": 1, + "do_arb": True, + "fees": 0.003, + "gas_cost": 0.0, + "arb_fees": 0.0, + "tokens": ("ETH", "USDC"), + "numeraire": "USDC", + "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + } + + fp_no_noise = Hashabledict({**base_fp, "noise_trader_ratio": 0.0}) + fp_noise = Hashabledict({**base_fp, "noise_trader_ratio": 0.1}) + + res_no = pool.calculate_reserves_with_fees( + params, fp_no_noise, prices, start_index + ) + res_yes = pool.calculate_reserves_with_fees( + params, fp_noise, prices, start_index + ) + + # Should run and produce different results + assert res_no.shape == (n_steps, 2) + assert res_yes.shape == (n_steps, 2) + assert not jnp.allclose(res_no, res_yes), ( + "Pool class should produce different reserves with noise traders" + ) + + def test_noise_trade_does_not_affect_virtual_balances(self): + """Noise trade fee income only grows real reserves, not Va/Vb.""" + from quantammsim.pools.reCLAMM.reclamm_reserves import ( + _reclamm_scan_step_with_fees_and_revenue, + precalc_shared_values_for_all_signatures, + precalc_components_of_optimal_trade_across_prices, + ) + + reserves, Va, Vb = _init_pool() + # Price must differ from init (2500) to trigger an arb trade + prices_arr = jnp.array([[3000.0, 1.0]]) + prices = prices_arr[0] + + weights = jnp.array([0.5, 0.5]) + gamma = 1.0 - 0.003 + n_assets = 2 + + _, active_trade_directions, tokens_to_drop, leave_one_out_idxs = ( + precalc_shared_values_for_all_signatures(ALL_SIG_VARIATIONS_2, n_assets) + ) + + active_initial_weights, per_asset_ratios, all_other_assets_ratios = ( + precalc_components_of_optimal_trade_across_prices( + weights, prices_arr, gamma, tokens_to_drop, + active_trade_directions, leave_one_out_idxs, + ) + ) + + carry = [ + reserves, Va, Vb, + jnp.float64(1.0), # prev_lp_supply + jnp.float64(0.0), # step_idx + jnp.float64(0.0), # active_start_ratio + jnp.float64(0.0), # active_target_ratio + jnp.float64(0.0), # active_start_step + jnp.float64(0.0), # active_end_step + jnp.array(False), # active_enabled + ] + inputs = [ + prices, + active_initial_weights[0], + per_asset_ratios[0], + all_other_assets_ratios[0], + gamma, + 0.0, # arb_thresh + 0.0, # arb_fees + jnp.array([0.0, 0.0, 0.0]), # price_ratio_update (no-op) + jnp.float64(1.0), # lp_supply + ] + + # Without noise + carry_no, _ = _reclamm_scan_step_with_fees_and_revenue( + carry, inputs, + weights=weights, + tokens_to_drop=tokens_to_drop, + active_trade_directions=active_trade_directions, + n=n_assets, + centeredness_margin=DEFAULT_CENTEREDNESS_MARGIN, + daily_price_shift_base=DEFAULT_DAILY_PRICE_SHIFT_BASE, + seconds_per_step=DEFAULT_SECONDS_PER_STEP, + noise_trader_ratio=0.0, + ) + + # With noise + carry_yes, _ = _reclamm_scan_step_with_fees_and_revenue( + carry, inputs, + weights=weights, + tokens_to_drop=tokens_to_drop, + active_trade_directions=active_trade_directions, + n=n_assets, + centeredness_margin=DEFAULT_CENTEREDNESS_MARGIN, + daily_price_shift_base=DEFAULT_DAILY_PRICE_SHIFT_BASE, + seconds_per_step=DEFAULT_SECONDS_PER_STEP, + noise_trader_ratio=0.5, + ) + + # Virtual balances should be identical + npt.assert_array_equal(carry_no[1], carry_yes[1], err_msg="Va changed") + npt.assert_array_equal(carry_no[2], carry_yes[2], err_msg="Vb changed") + + # But real reserves should differ (noise adds fee income) + assert not jnp.allclose(carry_no[0], carry_yes[0]), ( + "Real reserves should differ with noise traders" + ) + + +class TestLpSupply: + """LP supply (BPT) scaling for reClAMM pools. + + Ported from Foundry fuzz test invariants in ReClammLiquidity.t.sol: + both real AND virtual reserves scale proportionally with BPT supply, + preserving price and centeredness. + """ + + def test_lp_supply_none_matches_default(self): + """lp_supply_array=None vs omitting param → identical results.""" + reserves, Va, Vb = _init_pool() + prices = _make_trending_prices(2500.0, 3000.0, 1.0, 20) + + result_none = _jax_calc_reclamm_reserves_with_fees( + reserves, Va, Vb, prices, + DEFAULT_CENTEREDNESS_MARGIN, + DEFAULT_DAILY_PRICE_SHIFT_BASE, + DEFAULT_SECONDS_PER_STEP, + fees=0.003, + arb_thresh=0.0, + arb_fees=0.0, + all_sig_variations=ALL_SIG_VARIATIONS_2, + lp_supply_array=None, + ) + result_omit = _jax_calc_reclamm_reserves_with_fees( + reserves, Va, Vb, prices, + DEFAULT_CENTEREDNESS_MARGIN, + DEFAULT_DAILY_PRICE_SHIFT_BASE, + DEFAULT_SECONDS_PER_STEP, + fees=0.003, + arb_thresh=0.0, + arb_fees=0.0, + all_sig_variations=ALL_SIG_VARIATIONS_2, + ) + npt.assert_array_equal(result_none, result_omit) + + def test_lp_supply_constant_one_matches_default(self): + """lp_supply_array=jnp.ones(T) vs None → identical results.""" + reserves, Va, Vb = _init_pool() + n_steps = 20 + prices = _make_trending_prices(2500.0, 3000.0, 1.0, n_steps) + + result_ones = _jax_calc_reclamm_reserves_with_fees( + reserves, Va, Vb, prices, + DEFAULT_CENTEREDNESS_MARGIN, + DEFAULT_DAILY_PRICE_SHIFT_BASE, + DEFAULT_SECONDS_PER_STEP, + fees=0.003, + arb_thresh=0.0, + arb_fees=0.0, + all_sig_variations=ALL_SIG_VARIATIONS_2, + lp_supply_array=jnp.ones(n_steps), + ) + result_none = _jax_calc_reclamm_reserves_with_fees( + reserves, Va, Vb, prices, + DEFAULT_CENTEREDNESS_MARGIN, + DEFAULT_DAILY_PRICE_SHIFT_BASE, + DEFAULT_SECONDS_PER_STEP, + fees=0.003, + arb_thresh=0.0, + arb_fees=0.0, + all_sig_variations=ALL_SIG_VARIATIONS_2, + ) + npt.assert_array_equal(result_ones, result_none) + + def test_lp_supply_doubling_scales_reserves(self): + """BPT supply doubling halfway → reserves ~2x at that step. + + Ported from Foundry testAddLiquidity__Fuzz: reserves scale with BPT. + """ + reserves, Va, Vb = _init_pool() + n_steps = 40 + half = n_steps // 2 + prices = _make_constant_prices(2500.0, 1.0, n_steps) + + # No LP supply change (baseline) + result_base = _jax_calc_reclamm_reserves_with_fees( + reserves, Va, Vb, prices, + DEFAULT_CENTEREDNESS_MARGIN, + DEFAULT_DAILY_PRICE_SHIFT_BASE, + DEFAULT_SECONDS_PER_STEP, + fees=0.003, + arb_thresh=0.0, + arb_fees=0.0, + all_sig_variations=ALL_SIG_VARIATIONS_2, + ) + + # LP supply doubles at step `half` + lp_supply = jnp.concatenate([jnp.ones(half), 2.0 * jnp.ones(n_steps - half)]) + result_lp = _jax_calc_reclamm_reserves_with_fees( + reserves, Va, Vb, prices, + DEFAULT_CENTEREDNESS_MARGIN, + DEFAULT_DAILY_PRICE_SHIFT_BASE, + DEFAULT_SECONDS_PER_STEP, + fees=0.003, + arb_thresh=0.0, + arb_fees=0.0, + all_sig_variations=ALL_SIG_VARIATIONS_2, + lp_supply_array=lp_supply, + ) + + # Before doubling: identical + npt.assert_allclose(result_lp[:half], result_base[:half], rtol=1e-10) + # After doubling: reserves ~2x the baseline + npt.assert_allclose(result_lp[half], result_base[half] * 2.0, rtol=1e-6) + + def test_lp_supply_halving_scales_reserves(self): + """BPT supply halving halfway → reserves ~0.5x at that step.""" + reserves, Va, Vb = _init_pool() + n_steps = 40 + half = n_steps // 2 + prices = _make_constant_prices(2500.0, 1.0, n_steps) + + result_base = _jax_calc_reclamm_reserves_with_fees( + reserves, Va, Vb, prices, + DEFAULT_CENTEREDNESS_MARGIN, + DEFAULT_DAILY_PRICE_SHIFT_BASE, + DEFAULT_SECONDS_PER_STEP, + fees=0.003, + arb_thresh=0.0, + arb_fees=0.0, + all_sig_variations=ALL_SIG_VARIATIONS_2, + ) + + lp_supply = jnp.concatenate([jnp.ones(half), 0.5 * jnp.ones(n_steps - half)]) + result_lp = _jax_calc_reclamm_reserves_with_fees( + reserves, Va, Vb, prices, + DEFAULT_CENTEREDNESS_MARGIN, + DEFAULT_DAILY_PRICE_SHIFT_BASE, + DEFAULT_SECONDS_PER_STEP, + fees=0.003, + arb_thresh=0.0, + arb_fees=0.0, + all_sig_variations=ALL_SIG_VARIATIONS_2, + lp_supply_array=lp_supply, + ) + + # Before halving: identical + npt.assert_allclose(result_lp[:half], result_base[:half], rtol=1e-10) + # After halving: reserves ~0.5x the baseline + npt.assert_allclose(result_lp[half], result_base[half] * 0.5, rtol=1e-6) + + def test_lp_supply_preserves_price(self): + """Price ratio (Ra+Va)/(Rb+Vb) unchanged through LP supply scaling. + + Ported from Foundry: assertEq(price_before, price_after) in + testAddLiquidity__Fuzz. + """ + reserves, Va, Vb = _init_pool() + n_steps = 40 + half = n_steps // 2 + # Trending so virtual balances shift — makes price preservation non-trivial + prices = _make_trending_prices(2500.0, 3500.0, 1.0, n_steps) + + lp_supply = jnp.concatenate([jnp.ones(half), 2.0 * jnp.ones(n_steps - half)]) + + # Use full_state variant to get Va/Vb history + result_reserves, Va_hist, Vb_hist = _jax_calc_reclamm_reserves_zero_fees_full_state( + reserves, Va, Vb, prices, + DEFAULT_CENTEREDNESS_MARGIN, + DEFAULT_DAILY_PRICE_SHIFT_BASE, + DEFAULT_SECONDS_PER_STEP, + lp_supply_array=lp_supply, + ) + + # Price = (Rb + Vb) / (Ra + Va) — should be identical at step half-1 and half + # (the supply change happens at start of step `half`, before arb) + Ra_before = result_reserves[half - 1, 0] + Rb_before = result_reserves[half - 1, 1] + Va_before = Va_hist[half - 1] + Vb_before = Vb_hist[half - 1] + price_before = (Rb_before + Vb_before) / (Ra_before + Va_before) + + # At step `half`, the carry from step half-1 gets scaled, then arb runs. + # We need to check price right after scaling, before arb. + # The closest check: price at step `half` output should reflect + # scaled reserves with arb on top. Instead, verify that a constant-price + # run preserves exact ratio. + prices_const = _make_constant_prices(2500.0, 1.0, n_steps) + res_c, Va_c, Vb_c = _jax_calc_reclamm_reserves_zero_fees_full_state( + reserves, Va, Vb, prices_const, + DEFAULT_CENTEREDNESS_MARGIN, + DEFAULT_DAILY_PRICE_SHIFT_BASE, + DEFAULT_SECONDS_PER_STEP, + lp_supply_array=lp_supply, + ) + # With constant prices and no arb, price is just initial ratio throughout + price_at_half_minus_1 = (res_c[half - 1, 1] + Vb_c[half - 1]) / ( + res_c[half - 1, 0] + Va_c[half - 1] + ) + price_at_half = (res_c[half, 1] + Vb_c[half]) / ( + res_c[half, 0] + Va_c[half] + ) + npt.assert_allclose( + float(price_at_half_minus_1), float(price_at_half), rtol=1e-10, + err_msg="Price ratio should be preserved through LP supply change", + ) + + def test_lp_supply_preserves_centeredness(self): + """Centeredness unchanged through LP supply scaling. + + Ported from Foundry: centeredness invariance in testAddLiquidity__Fuzz. + """ + reserves, Va, Vb = _init_pool() + n_steps = 40 + half = n_steps // 2 + prices = _make_constant_prices(2500.0, 1.0, n_steps) + + lp_supply = jnp.concatenate([jnp.ones(half), 2.0 * jnp.ones(n_steps - half)]) + + res, Va_hist, Vb_hist = _jax_calc_reclamm_reserves_zero_fees_full_state( + reserves, Va, Vb, prices, + DEFAULT_CENTEREDNESS_MARGIN, + DEFAULT_DAILY_PRICE_SHIFT_BASE, + DEFAULT_SECONDS_PER_STEP, + lp_supply_array=lp_supply, + ) + + c_before, _ = compute_centeredness( + res[half - 1, 0], res[half - 1, 1], Va_hist[half - 1], Vb_hist[half - 1] + ) + c_after, _ = compute_centeredness( + res[half, 0], res[half, 1], Va_hist[half], Vb_hist[half] + ) + npt.assert_allclose( + float(c_before), float(c_after), rtol=1e-10, + err_msg="Centeredness should be preserved through LP supply change", + ) + + def test_lp_supply_through_pool_class(self): + """Pool class passes lp_supply_array to underlying computation.""" + from quantammsim.pools.creator import create_pool + from quantammsim.runners.jax_runner_utils import Hashabledict + + pool = create_pool("reclamm") + + params = { + "price_ratio": DEFAULT_PRICE_RATIO, + "centeredness_margin": DEFAULT_CENTEREDNESS_MARGIN, + "daily_price_shift_base": DEFAULT_DAILY_PRICE_SHIFT_BASE, + } + + n_steps = 20 + prices = _make_trending_prices(2500.0, 3000.0, 1.0, n_steps) + + run_fingerprint = Hashabledict({ + "n_assets": 2, + "bout_length": n_steps + 1, + "initial_pool_value": 1_000_000.0, + "arb_frequency": 1, + "do_arb": True, + "fees": 0.003, + "gas_cost": 0.0, + "arb_fees": 0.0, + "tokens": ("ETH", "USDC"), + "numeraire": "USDC", + "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + }) + + start_index = jnp.array([0, 0]) + + from quantammsim.core_simulator.dynamic_inputs import DynamicInputArrays + + lp_supply = jnp.concatenate([jnp.ones(10), 2.0 * jnp.ones(10)]) + + di_with_lp = DynamicInputArrays( + trades=None, + fees=jnp.full(n_steps, 0.003), + gas_cost=jnp.zeros(n_steps), + arb_fees=jnp.zeros(n_steps), + lp_supply=lp_supply, + reclamm_price_ratio_updates=jnp.array([[0.0, 0.0, 0.0, jnp.nan]]), + ) + di_without_lp = DynamicInputArrays( + trades=None, + fees=jnp.full(n_steps, 0.003), + gas_cost=jnp.zeros(n_steps), + arb_fees=jnp.zeros(n_steps), + lp_supply=jnp.ones(n_steps), + reclamm_price_ratio_updates=jnp.array([[0.0, 0.0, 0.0, jnp.nan]]), + ) + + res_with_lp, _ = pool.calculate_reserves_and_fee_revenue_with_dynamic_inputs( + params, run_fingerprint, prices, start_index, + dynamic_inputs=di_with_lp, + ) + res_without_lp, _ = pool.calculate_reserves_and_fee_revenue_with_dynamic_inputs( + params, run_fingerprint, prices, start_index, + dynamic_inputs=di_without_lp, + ) + + # First 10 steps identical, then diverge + npt.assert_allclose(res_with_lp[:10], res_without_lp[:10], rtol=1e-10) + assert not jnp.allclose(res_with_lp[10:], res_without_lp[10:]), ( + "LP supply change should produce different reserves" + ) + + def test_lp_supply_with_fee_revenue(self): + """Doubling LP supply → fee revenue increases (bigger pool → bigger arb trades).""" + reserves, Va, Vb = _init_pool() + n_steps = 40 + half = n_steps // 2 + prices = _make_trending_prices(2500.0, 3500.0, 1.0, n_steps) + + # Baseline: no supply change + _, rev_base = _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( + reserves, Va, Vb, prices, + DEFAULT_CENTEREDNESS_MARGIN, + DEFAULT_DAILY_PRICE_SHIFT_BASE, + DEFAULT_SECONDS_PER_STEP, + fees=0.003, + arb_thresh=0.0, + arb_fees=0.0, + all_sig_variations=ALL_SIG_VARIATIONS_2, + ) + + # Supply doubles halfway + lp_supply = jnp.concatenate([jnp.ones(half), 2.0 * jnp.ones(n_steps - half)]) + _, rev_lp = _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( + reserves, Va, Vb, prices, + DEFAULT_CENTEREDNESS_MARGIN, + DEFAULT_DAILY_PRICE_SHIFT_BASE, + DEFAULT_SECONDS_PER_STEP, + fees=0.003, + arb_thresh=0.0, + arb_fees=0.0, + all_sig_variations=ALL_SIG_VARIATIONS_2, + lp_supply_array=lp_supply, + ) + + # After doubling, fee revenue per step should be larger + # (pool is 2x bigger → arb trades are 2x bigger → fees are 2x) + post_double_base = rev_base[half:].sum() + post_double_lp = rev_lp[half:].sum() + assert float(post_double_lp) > float(post_double_base), ( + f"Fee revenue should increase after doubling: {post_double_lp} <= {post_double_base}" + ) + + def test_lp_supply_e2e_do_run_on_historic_data(self): + """End-to-end: lp_supply flows through do_run_on_historic_data via DynamicInputFrames.""" + import pandas as pd + from quantammsim.runners.jax_runners import do_run_on_historic_data + from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames + + fp = { + "rule": "reclamm", + "tokens": ["ETH", "USDC"], + "startDateString": "2023-01-01 00:00:00", + "endDateString": "2023-01-15 00:00:00", + "initial_pool_value": 1_000_000.0, + "do_arb": True, + "fees": 0.003, + } + params = { + "price_ratio": jnp.array(4.0), + "centeredness_margin": jnp.array(0.2), + "daily_price_shift_base": jnp.array(DEFAULT_DAILY_PRICE_SHIFT_BASE), + } + + # Baseline: no LP supply change + result_base = do_run_on_historic_data( + run_fingerprint={**fp}, + params={**params}, + root=TEST_DATA_DIR, + ) + + # LP supply doubles halfway through the period + # unix column must be in milliseconds (matches windowing_utils convention) + start_unix_ms = int(pd.Timestamp("2023-01-01").timestamp() * 1000) + mid_unix_ms = int(pd.Timestamp("2023-01-08").timestamp() * 1000) + lp_supply_df = pd.DataFrame({ + "unix": [start_unix_ms, mid_unix_ms], + "lp_supply": [1.0, 2.0], + }) + + result_lp = do_run_on_historic_data( + run_fingerprint={**fp}, + params={**params}, + dynamic_input_frames=DynamicInputFrames(lp_supply=lp_supply_df), + root=TEST_DATA_DIR, + ) + + # Final values should differ — doubling LP supply changes pool dynamics + base_val = float(result_base["final_value"]) + lp_val = float(result_lp["final_value"]) + assert base_val != lp_val, ( + f"LP supply change should affect final value: base={base_val}, lp={lp_val}" + ) + # Doubled pool should have higher final value (more reserves) + assert lp_val > base_val, ( + f"Doubled LP supply should increase final value: {lp_val} <= {base_val}" + ) +