diff --git a/experiments/diagnose_spikes.py b/experiments/diagnose_spikes.py new file mode 100644 index 0000000..b146c6b --- /dev/null +++ b/experiments/diagnose_spikes.py @@ -0,0 +1,216 @@ +"""Diagnose spikes in sim-vs-world deviation for LP-supply-normalized runs. + +Compares old (no LP supply) vs new (with LP supply + per-LP normalization) +to pinpoint what causes spikes in the deviation time series. +""" + +import os +import numpy as np +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +from datetime import datetime, timezone + +from experiments.pool_registry import ( + POOL_REGISTRY, extract_on_chain_state, extract_initial_state, + get_data_end_date, load_world_history, load_bpt_supply_df, +) +from experiments.run_pool_battery import ( + run_sim, sample_at_timestamps, _start_str_from_pool, + _onchain_params_to_sim, PROTOCOL_FEE_SPLIT, +) +from quantammsim.runners.jax_runners import do_run_on_historic_data + + +POOL_LABEL = "WAVAX_USDC" # Change to "cbBTC_WETH" etc. + + +def main(): + pool = POOL_REGISTRY[POOL_LABEL] + extract_on_chain_state(pool) + initial_state = extract_initial_state(pool) + start_str = _start_str_from_pool(pool) + end_str = get_data_end_date(pool.tokens) + lp_supply_df = load_bpt_supply_df(pool, end_date=end_str) + + start_sec = datetime.strptime( + start_str, "%Y-%m-%d %H:%M:%S" + ).replace(tzinfo=timezone.utc).timestamp() + + print(f"Pool: {pool.label}, TVL: ${pool.initial_pool_value_usd:,.0f}") + print(f"Period: {start_str} to {end_str}") + print(f"BPT range: {lp_supply_df['lp_supply'].min():.4f} to {lp_supply_df['lp_supply'].max():.4f}") + + # ---- Run sims ---- + # 1. Old way: no lp_supply at all + result_old = run_sim( + pool, gas_cost=0.0, arb_frequency=1, + initial_state=initial_state, start=start_str, end=end_str, + lp_supply_df=None, + ) + + # 2. New way: lp_supply in scan + per-LP normalization + result_new = run_sim( + pool, gas_cost=0.0, arb_frequency=1, + initial_state=initial_state, start=start_str, end=end_str, + lp_supply_df=lp_supply_df, + ) + + # 3. Raw lp run (scan has lp_supply, but we DON'T divide by it) + params = _onchain_params_to_sim(pool) + fp = { + "tokens": pool.tokens, "rule": "reclamm", + "startDateString": start_str, "endDateString": end_str, + "initial_pool_value": pool.initial_pool_value_usd, + "fees": pool.swap_fee, "gas_cost": 0.0, "arb_fees": 0.0, + "do_arb": True, "arb_frequency": 1, "chunk_period": 1440, + "weight_interpolation_period": 1440, + "reclamm_use_shift_exponent": True, + "reclamm_interpolation_method": "geometric", + "reclamm_centeredness_scaling": False, + "protocol_fee_split": PROTOCOL_FEE_SPLIT, + "reclamm_initial_state": initial_state, + } + result_raw = do_run_on_historic_data( + run_fingerprint=fp, params=params, lp_supply_df=lp_supply_df, + ) + v_lp_raw = np.array(result_raw["value"]) + + v_old = np.array(result_old["value_usd"]) # no LP, no normalization + v_new = np.array(result_new["value_usd"]) # LP scan + divided by lp_supply + + # ---- World ---- + world = load_world_history(pool, end_date=end_str) + world_ts = world["timestamps"] + prices_min = result_old["prices"] + + prices_at_world = np.stack([ + sample_at_timestamps(prices_min[:, i], start_sec, world_ts) + for i in range(prices_min.shape[1]) + ], axis=1) + + # BPT-normalized world value (same in both old and new) + world_bpt_val = ( + world["bal_0"] * prices_at_world[:, 0] + + world["bal_1"] * prices_at_world[:, 1] + ) + world_growth = world_bpt_val / world_bpt_val[0] + + # Raw world value (absolute, un-normalized) + world_raw_val = ( + world["raw_bal_0"] * prices_at_world[:, 0] + + world["raw_bal_1"] * prices_at_world[:, 1] + ) + + # Sample sim at world timestamps + old_at_world = sample_at_timestamps(v_old, start_sec, world_ts) + new_at_world = sample_at_timestamps(v_new, start_sec, world_ts) + raw_at_world = sample_at_timestamps(v_lp_raw, start_sec, world_ts) + + old_growth = old_at_world / old_at_world[0] + new_growth = new_at_world / new_at_world[0] + raw_growth = raw_at_world / raw_at_world[0] + + # Deviations + dev_old = (old_growth / world_growth - 1) * 100 + dev_new = (new_growth / world_growth - 1) * 100 + + # Raw vs raw-world comparison (both absolute) + world_raw_growth = world_raw_val / world_raw_val[0] + dev_raw = (raw_growth / world_raw_growth - 1) * 100 + + days = (world_ts - world_ts[0]) / 86400 + + # LP supply at world timestamps + lp_unix = np.array(lp_supply_df["unix"]) + lp_vals = np.array(lp_supply_df["lp_supply"]) + lp_at_world = np.interp(world_ts, lp_unix / 1000, lp_vals) + + # ---- Print diagnostics ---- + print(f"\n--- Final deviations ---") + print(f"Old (no LP): {dev_old[-1]:+.4f}%") + print(f"New (LP + per-LP): {dev_new[-1]:+.4f}%") + print(f"Raw (LP, absolute): {dev_raw[-1]:+.4f}%") + + # Spike analysis + for label, dev in [("Old", dev_old), ("New", dev_new), ("Raw", dev_raw)]: + diffs = np.abs(np.diff(dev)) + n_spikes_01 = np.sum(diffs > 0.1) + n_spikes_05 = np.sum(diffs > 0.5) + n_spikes_10 = np.sum(diffs > 1.0) + print(f"\n{label} — step-to-step jumps in deviation:") + print(f" >0.1%: {n_spikes_01}, >0.5%: {n_spikes_05}, >1.0%: {n_spikes_10}") + if n_spikes_10 > 0: + spike_idx = np.where(diffs > 1.0)[0] + for si in spike_idx[:5]: + print(f" day {days[si]:.1f}: dev {dev[si]:+.2f}% -> {dev[si+1]:+.2f}% " + f"(Δ={dev[si+1]-dev[si]:+.2f}%, lp={lp_at_world[si]:.4f}->{lp_at_world[si+1]:.4f})") + + # World growth spikes + world_g_diffs = np.diff(world_growth) + n_world_spikes = np.sum(np.abs(world_g_diffs) > 0.01) + print(f"\nWorld BPT-normalized growth jumps > 1%: {n_world_spikes}") + if n_world_spikes > 0: + wsi = np.where(np.abs(world_g_diffs) > 0.01)[0] + for si in wsi[:5]: + print(f" day {days[si]:.1f}: growth {world_growth[si]:.4f} -> {world_growth[si+1]:.4f} " + f"(Δ={world_g_diffs[si]:+.4f}, lp={lp_at_world[si]:.4f}->{lp_at_world[si+1]:.4f})") + + # ---- Plot ---- + fig, axes = plt.subplots(4, 1, figsize=(14, 16), sharex=True) + + ax = axes[0] + ax.plot(days, dev_old, "b-", linewidth=1.5, label=f"Old (no LP) → {dev_old[-1]:+.2f}%") + ax.plot(days, dev_new, "r-", linewidth=1.5, label=f"New (LP + per-LP norm) → {dev_new[-1]:+.2f}%") + ax.axhline(0, color="gray", linestyle=":", alpha=0.5) + ax.set_ylabel("% deviation from world") + ax.set_title(f"{pool.label} — gas=0, arb=1min — old vs new deviation") + ax.legend(fontsize=9) + ax.grid(True, alpha=0.2) + + ax = axes[1] + ax.plot(days, dev_raw, "g-", linewidth=1.5, label=f"Raw absolute (LP scan, raw world) → {dev_raw[-1]:+.2f}%") + ax.axhline(0, color="gray", linestyle=":", alpha=0.5) + ax.set_ylabel("% deviation") + ax.set_title("Alternative: raw absolute sim vs raw absolute world") + ax.legend(fontsize=9) + ax.grid(True, alpha=0.2) + + ax = axes[2] + ax.plot(days, world_growth, "k-", linewidth=2, label="World (BPT-normalized)") + ax.plot(days, old_growth, "b-", linewidth=1, alpha=0.8, label="Old sim") + ax.plot(days, new_growth, "r-", linewidth=1, alpha=0.8, label="New sim (LP + per-LP)") + ax.set_ylabel("Growth factor") + ax.set_title("Growth factors") + ax.legend(fontsize=9) + ax.grid(True, alpha=0.2) + + ax = axes[3] + ax.plot(days, lp_at_world, "g-", linewidth=2, label="LP supply (BPT/BPT₀)") + # Mark large LP changes + lp_diffs = np.abs(np.diff(lp_at_world)) + big_lp = np.where(lp_diffs > 0.05)[0] + if len(big_lp): + ax.scatter(days[big_lp], lp_at_world[big_lp], c="red", s=40, zorder=5, + label=f"Large LP events ({len(big_lp)})") + ax.set_ylabel("BPT / BPT₀") + ax.set_xlabel("Days from start") + ax.set_title("On-chain BPT supply") + ax.legend(fontsize=9) + ax.grid(True, alpha=0.2) + + fig.suptitle( + f"{pool.label} ({pool.chain}) — spike diagnosis", + fontsize=13, fontweight="bold", + ) + plt.tight_layout() + + os.makedirs("results", exist_ok=True) + out = f"results/diagnose_spikes_{POOL_LABEL}.png" + plt.savefig(out, dpi=150, bbox_inches="tight") + plt.close() + print(f"\nSaved: {out}") + + +if __name__ == "__main__": + main() diff --git a/experiments/diagnostic_lp_supply.py b/experiments/diagnostic_lp_supply.py new file mode 100644 index 0000000..3f3c317 --- /dev/null +++ b/experiments/diagnostic_lp_supply.py @@ -0,0 +1,134 @@ +"""Diagnostic: sim vs world absolute pool value for a single (gas=0, arb_freq=1) run. + +Plots raw USD pool value over time for both sim and world, no per-LP normalization. +""" + +import os +import numpy as np +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +from datetime import datetime, timezone + +from experiments.pool_registry import ( + POOL_REGISTRY, extract_on_chain_state, extract_initial_state, + get_data_end_date, load_world_history, load_bpt_supply_df, +) +from experiments.run_pool_battery import run_sim, sample_at_timestamps, _start_str_from_pool + + +def main(): + pool = POOL_REGISTRY["cbBTC_WETH"] + extract_on_chain_state(pool) + initial_state = extract_initial_state(pool) + + start_str = _start_str_from_pool(pool) + end_str = get_data_end_date(pool.tokens) + lp_supply_df = load_bpt_supply_df(pool, end_date=end_str) + + print(f"Pool: {pool.label}, TVL: ${pool.initial_pool_value_usd:,.0f}") + print(f"BPT: {lp_supply_df['lp_supply'].iloc[0]:.4f} -> {lp_supply_df['lp_supply'].iloc[-1]:.4f}") + print(f"Period: {start_str} to {end_str}") + + # Run sim WITH lp_supply (gas=0, arb_freq=1) + result_lp = run_sim(pool, gas_cost=0.0, arb_frequency=1, + initial_state=initial_state, + start=start_str, end=end_str, + lp_supply_df=lp_supply_df) + + # Run sim WITHOUT lp_supply (gas=0, arb_freq=1) + result_no_lp = run_sim(pool, gas_cost=0.0, arb_frequency=1, + initial_state=initial_state, + start=start_str, end=end_str, + lp_supply_df=None) + + # World + world = load_world_history(pool, end_date=end_str) + world_ts = world["timestamps"] + raw_bal_0 = world["raw_bal_0"] + raw_bal_1 = world["raw_bal_1"] + + start_sec = result_lp["start_unix_sec"] + prices_min = result_lp["prices"] + + # World value at world timestamps (raw balances × USD prices) + prices_at_world = np.stack([ + sample_at_timestamps(prices_min[:, i], start_sec, world_ts) + for i in range(prices_min.shape[1]) + ], axis=1) + world_value = raw_bal_0 * prices_at_world[:, 0] + raw_bal_1 * prices_at_world[:, 1] + + # Sim values (minute-resolution) + sim_value_lp = np.array(result_lp["value_usd"]) + sim_value_no_lp = np.array(result_no_lp["value_usd"]) + n_minutes = len(sim_value_lp) + sim_times_sec = start_sec + np.arange(n_minutes) * 60 + sim_days = (sim_times_sec - start_sec) / 86400 + world_days = (world_ts - start_sec) / 86400 + + # BPT supply at world timestamps (for annotation) + lp_at_world = np.interp( + world_ts, + np.array(lp_supply_df["unix"]) / 1000, + np.array(lp_supply_df["lp_supply"]), + ) + + # --- Plot --- + fig, axes = plt.subplots(3, 1, figsize=(14, 12), sharex=True) + + # Panel 1: absolute pool value + ax = axes[0] + ax.plot(world_days, world_value, "k-", linewidth=2, label="World (raw balances × prices)") + ax.plot(sim_days, sim_value_lp, "b-", linewidth=1, alpha=0.8, label="Sim (with lp_supply)") + ax.plot(sim_days, sim_value_no_lp, "r--", linewidth=1, alpha=0.8, label="Sim (no lp_supply)") + ax.set_ylabel("Pool value (USD)") + ax.set_title(f"{pool.label} — gas=0, arb_freq=1min — absolute pool value") + ax.legend(fontsize=9) + ax.grid(True, alpha=0.2) + + # Panel 2: growth factors + ax = axes[1] + world_growth = world_value / world_value[0] + sim_growth_lp = sim_value_lp / sim_value_lp[0] + sim_growth_no_lp = sim_value_no_lp / sim_value_no_lp[0] + ax.plot(world_days, world_growth, "k-", linewidth=2, label="World growth") + ax.plot(sim_days, sim_growth_lp, "b-", linewidth=1, alpha=0.8, label="Sim growth (with lp_supply)") + ax.plot(sim_days, sim_growth_no_lp, "r--", linewidth=1, alpha=0.8, label="Sim growth (no lp_supply)") + ax.axhline(1.0, color="gray", linestyle=":", alpha=0.5) + ax.set_ylabel("Growth factor") + ax.set_title("Growth factors (value / initial value)") + ax.legend(fontsize=9) + ax.grid(True, alpha=0.2) + + # Panel 3: BPT supply + ax = axes[2] + ax.plot(world_days, lp_at_world, "g-", linewidth=2, label="BPT supply (normalized)") + ax.set_ylabel("BPT / BPT₀") + ax.set_xlabel("Days from start") + ax.set_title("On-chain BPT supply") + ax.legend(fontsize=9) + ax.grid(True, alpha=0.2) + + fig.suptitle( + f"{pool.label} ({pool.chain}) — TVL=${pool.initial_pool_value_usd:,.0f} — " + f"PR={pool.on_chain_params['price_ratio']:.4f}", + fontsize=12, fontweight="bold", + ) + plt.tight_layout() + + os.makedirs("results", exist_ok=True) + out = "results/diagnostic_lp_supply_cbBTC_WETH.png" + plt.savefig(out, dpi=150, bbox_inches="tight") + plt.close() + print(f"\nSaved: {out}") + + # Print key numbers + print(f"\nWorld: {world_value[0]:.0f} -> {world_value[-1]:.0f} (growth={world_growth[-1]:.4f})") + print(f"Sim (lp): {sim_value_lp[0]:.0f} -> {sim_value_lp[-1]:.0f} (growth={sim_growth_lp[-1]:.4f})") + print(f"Sim (no lp): {sim_value_no_lp[0]:.0f} -> {sim_value_no_lp[-1]:.0f} (growth={sim_growth_no_lp[-1]:.4f})") + print(f"\nDeviation (lp): {(sim_growth_lp[-1]/world_growth[-1] - 1)*100:+.2f}%") + print(f"Deviation (no lp): {(sim_growth_no_lp[-1]/world_growth[-1] - 1)*100:+.2f}%") + + +if __name__ == "__main__": + main() diff --git a/experiments/pool_registry.py b/experiments/pool_registry.py new file mode 100644 index 0000000..cfdce2f --- /dev/null +++ b/experiments/pool_registry.py @@ -0,0 +1,512 @@ +"""Registry of on-chain reClAMM pools for sim-vs-world gas calibration. + +Extracts pool state from reclamm-simulations DB and computes TVL in USD +at each pool's plausible_start date. Maps chain → realistic gas costs. +Also provides initial on-chain state (Ra, Rb, Va, Vb) and world balance +history for comparison. + +Pools excluded: + - EUR_USDC_b, sUSDai_USDT0, WXPL_USDT0: stable/stable pairs + - wstETH_GNO: boosted (wstETH yield-bearing) +""" + +import math +import os +import sqlite3 +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Optional + +import numpy as np +import pandas as pd + + +# --------------------------------------------------------------------------- +# Database path (reclamm-simulations repo) +# --------------------------------------------------------------------------- +DEFAULT_DB_PATH = os.path.expanduser( + "~/Projects/reclamm-simulations/data/pools_history.db" +) + +# --------------------------------------------------------------------------- +# Chain → gas cost batteries (USD) +# --------------------------------------------------------------------------- +# Non-mainnet chains use flat gas costs. +# Ethereum uses time-varying gas from on-chain percentile CSVs. +CHAIN_GAS_COSTS = { + "base": [0.0, 0.01, 0.1, 0.5], + "gnosis": [0.0, 0.01, 0.1, 0.5], + "avalanche": [0.0, 0.01, 0.1, 0.5], +} + +# Ethereum mainnet: time-varying gas percentiles + flat zero baseline. +# CSVs live in gas_csvs/ with columns [unix, USD]. +GAS_CSV_DIR = os.path.join(os.path.dirname(__file__), "..", "gas_csvs") +ETHEREUM_GAS_PERCENTILES = ["50p", "75p", "90p", "95p"] + + +@dataclass +class PoolConfig: + """Static metadata for a simulatable on-chain reClAMM pool.""" + + label: str + tokens: list # quantammsim ticker names, e.g. ['BTC', 'ETH'] + chain: str + swap_fee: float + db_label: str # table name in pools_history.db + plausible_start: str # YYYY-MM-DD + reverse: bool # True if DB token order is reversed vs quantammsim + pool_address: str = "" # on-chain contract address (hex, no 0x prefix) + # Filled by extract_on_chain_state(): + on_chain_params: Optional[dict] = None # price_ratio, margin, shift_rate + initial_pool_value_usd: Optional[float] = None + + +# --------------------------------------------------------------------------- +# Pool definitions (non-stable, non-boosted pools with quantammsim tickers) +# --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# Chain → Balancer V3 API chain identifier +# --------------------------------------------------------------------------- +BALANCER_API_CHAIN = { + "base": "BASE", + "ethereum": "MAINNET", + "gnosis": "GNOSIS", + "avalanche": "AVALANCHE", + "arbitrum": "ARBITRUM", + "polygon": "POLYGON", + "optimism": "OPTIMISM", + "sonic": "SONIC", +} + + +POOL_REGISTRY = { + "cbBTC_WETH": PoolConfig( + label="cbBTC_WETH", + tokens=["BTC", "ETH"], + chain="base", + swap_fee=0.0005, + db_label="cbBTC_WETH", + plausible_start="2025-08-01", + reverse=True, + pool_address="19aeb8168d921bb069c6771bbaff7c09116720d0", + ), + "cbBTC_WETH_post_oct": PoolConfig( + label="cbBTC_WETH_post_oct", + tokens=["BTC", "ETH"], + chain="base", + swap_fee=0.0005, + db_label="cbBTC_WETH", + plausible_start="2025-12-01", + reverse=True, + pool_address="19aeb8168d921bb069c6771bbaff7c09116720d0", + ), + "AAVE_WETH": PoolConfig( + label="AAVE_WETH", + tokens=["AAVE", "ETH"], + chain="ethereum", + swap_fee=0.0025, + db_label="AAVE_WETH", + plausible_start="2025-08-15", + reverse=False, + pool_address="9d1fcf346ea1b073de4d5834e25572cc6ad71f4d", + ), + "AAVE_WETH_post_gov": PoolConfig( + label="AAVE_WETH_post_gov", + tokens=["AAVE", "ETH"], + chain="ethereum", + swap_fee=0.0025, + db_label="AAVE_WETH", + plausible_start="2025-12-21", + reverse=False, + pool_address="9d1fcf346ea1b073de4d5834e25572cc6ad71f4d", + ), + "COW_WETH_b": PoolConfig( + label="COW_WETH_b", + tokens=["COW", "ETH"], + chain="base", + swap_fee=0.003, + db_label="COW_WETH_b", + plausible_start="2025-07-18", + reverse=True, + pool_address="ff028c1ec4559d3aa2b0859aa582925b5cc28069", + ), + "COW_WETH_e": PoolConfig( + label="COW_WETH_e", + tokens=["COW", "ETH"], + chain="ethereum", + swap_fee=0.003, + db_label="COW_WETH_e", + plausible_start="2025-09-21", + reverse=True, + pool_address="d321300ef77067d4a868f117d37706eb81368e98", + ), + "WAVAX_USDC": PoolConfig( + label="WAVAX_USDC", + tokens=["AVAX", "USDC"], + chain="avalanche", + swap_fee=0.001, + db_label="WAVAX_USDC", + plausible_start="2025-08-17", + reverse=False, + pool_address="8750ccffcddbff81b63790dbcb1ffd8c7dc4c16d", + ), + "GNO_USDC": PoolConfig( + label="GNO_USDC", + tokens=["GNO", "USDC"], + chain="gnosis", + swap_fee=0.003, + db_label="GNO_USDC", + plausible_start="2025-09-18", + reverse=True, + pool_address="70b3b56773ace43fe86ee1d80cbe03176cbe4c09", + ), +} + + +def _date_to_unix(date_str: str) -> int: + """Convert YYYY-MM-DD or YYYY-MM-DD HH:MM:SS to unix timestamp (seconds).""" + for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d"): + try: + dt = datetime.strptime(date_str, fmt).replace(tzinfo=timezone.utc) + return int(dt.timestamp()) + except ValueError: + continue + raise ValueError(f"Cannot parse date: {date_str}") + + +def _get_usd_price_at(ticker: str, unix_ms: int, data_root: str) -> float: + """Get the USD price of a ticker at a given unix timestamp (ms).""" + path = os.path.join(data_root, f"{ticker}_USD.parquet") + df = pd.read_parquet(path) + idx = (df["unix"] - unix_ms).abs().idxmin() + return float(df.iloc[idx]["close"]) + + +def extract_on_chain_state( + pool: PoolConfig, + db_path: str = DEFAULT_DB_PATH, + data_root: str = None, +) -> PoolConfig: + """Query the DB for on-chain state at plausible_start and compute USD TVL. + + Mutates and returns the pool config with on_chain_params and + initial_pool_value_usd filled in. + """ + if data_root is None: + data_root = os.path.join( + os.path.dirname(__file__), "..", "quantammsim", "data" + ) + + conn = sqlite3.connect(db_path) + cur = conn.cursor() + ts = _date_to_unix(pool.plausible_start) + + cur.execute( + f"""SELECT * FROM {pool.db_label} + WHERE timestamp <= ? + ORDER BY timestamp DESC LIMIT 1""", + (ts + 3600,), + ) + row = cur.fetchone() + conn.close() + + if row is None: + raise ValueError( + f"No DB data for {pool.db_label} at {pool.plausible_start}" + ) + + # DB columns: timestamp, block_number, bpt_supply, balance_0, balance_1, + # spot_price, virtual_0, virtual_1, time_last_interaction, + # price_ratio, margin, shift_rate, swap_fee + balance_0, balance_1 = row[3], row[4] + price_ratio = row[9] + margin = row[10] + shift_rate = row[11] + + pool.on_chain_params = { + "price_ratio": price_ratio, + "margin": margin, + "shift_rate": shift_rate, + "swap_fee": row[12], + } + + # Compute TVL in USD from per-token USD prices. + # DB stores balances in contract token order (bring_pool_data.py never + # applies reverse). The reverse flag tells us the mapping: + # reverse=False → balance_0=tokens[0], balance_1=tokens[1] + # reverse=True → balance_0=tokens[1], balance_1=tokens[0] + unix_ms = ts * 1000 + if pool.reverse: + tickers_in_db_order = [pool.tokens[1], pool.tokens[0]] + else: + tickers_in_db_order = [pool.tokens[0], pool.tokens[1]] + + usd_prices = [] + for ticker in tickers_in_db_order: + if ticker == "USDC": + usd_prices.append(1.0) + else: + usd_prices.append( + _get_usd_price_at(ticker, unix_ms, data_root) + ) + + pool.initial_pool_value_usd = ( + balance_0 * usd_prices[0] + balance_1 * usd_prices[1] + ) + return pool + + +def extract_initial_state( + pool: PoolConfig, + db_path: str = DEFAULT_DB_PATH, +) -> dict: + """Extract on-chain Ra, Rb, Va, Vb at plausible_start in quantammsim order. + + quantammsim sorts tokens alphabetically, so token[0] is the + alphabetically-first ticker. The reverse flag maps DB contract + order to this sorted order. + + Returns dict with keys Ra, Rb, Va, Vb (floats). + """ + conn = sqlite3.connect(db_path) + cur = conn.cursor() + ts = _date_to_unix(pool.plausible_start) + + cur.execute( + f"""SELECT balance_0, balance_1, virtual_0, virtual_1 + FROM {pool.db_label} + WHERE timestamp <= ? + ORDER BY timestamp DESC LIMIT 1""", + (ts + 3600,), + ) + row = cur.fetchone() + conn.close() + + if row is None: + raise ValueError( + f"No DB data for {pool.db_label} at {pool.plausible_start}" + ) + + b0, b1, v0, v1 = row + if pool.reverse: + # DB contract order is opposite to quantammsim sorted order + return {"Ra": b1, "Rb": b0, "Va": v1, "Vb": v0} + else: + return {"Ra": b0, "Rb": b1, "Va": v0, "Vb": v1} + + +def load_world_history( + pool: PoolConfig, + end_date: str = None, + db_path: str = DEFAULT_DB_PATH, +) -> dict: + """Load on-chain balance history from the DB. + + Returns dict with: + timestamps: array of unix timestamps (seconds) + bal_0: BPT-normalized balance of quantammsim token[0] + bal_1: BPT-normalized balance of quantammsim token[1] + raw_bal_0: raw (un-normalized) balance of quantammsim token[0] + raw_bal_1: raw (un-normalized) balance of quantammsim token[1] + governance_events: list of (timestamp, field, old_val, new_val) + """ + conn = sqlite3.connect(db_path) + cur = conn.cursor() + + ts_start = _date_to_unix(pool.plausible_start) - 1000 + if end_date: + ts_end = _date_to_unix(end_date) + else: + ts_end = 2_000_000_000 # far future + + cur.execute( + f"""SELECT timestamp, bpt_supply, balance_0, balance_1, + price_ratio, margin, shift_rate, swap_fee + FROM {pool.db_label} + WHERE timestamp BETWEEN ? AND ? + ORDER BY timestamp""", + (ts_start, ts_end), + ) + rows = cur.fetchall() + conn.close() + + if not rows: + raise ValueError(f"No world history for {pool.db_label}") + + initial_bpt = rows[0][1] + timestamps = [] + bal_db_0_norm = [] + bal_db_1_norm = [] + bal_db_0_raw = [] + bal_db_1_raw = [] + governance_events = [] + + for i, row in enumerate(rows): + ts, bpt, b0, b1, pr, margin, shift_rate, swap_fee = row + timestamps.append(ts) + norm = initial_bpt / bpt + bal_db_0_norm.append(b0 * norm) + bal_db_1_norm.append(b1 * norm) + bal_db_0_raw.append(b0) + bal_db_1_raw.append(b1) + + # Detect governance changes. + # price_ratio drifts continuously via the shift mechanism, so + # only flag large discrete jumps (>1% relative change) as governance. + # margin, shift_rate, and swap_fee are set by governance and don't drift. + if i > 0: + prev = rows[i - 1] + if not math.isclose(prev[4], pr, rel_tol=0.01): + governance_events.append((ts, "price_ratio", prev[4], pr)) + if not math.isclose(prev[5], margin, rel_tol=1e-6): + governance_events.append((ts, "margin", prev[5], margin)) + if not math.isclose(prev[6], shift_rate, rel_tol=1e-6): + governance_events.append((ts, "shift_rate", prev[6], shift_rate)) + + bal_db_0_norm = np.array(bal_db_0_norm) + bal_db_1_norm = np.array(bal_db_1_norm) + bal_db_0_raw = np.array(bal_db_0_raw) + bal_db_1_raw = np.array(bal_db_1_raw) + + # Apply reverse: swap to quantammsim sorted token order + if pool.reverse: + bal_sorted_0, bal_sorted_1 = bal_db_1_norm, bal_db_0_norm + raw_sorted_0, raw_sorted_1 = bal_db_1_raw, bal_db_0_raw + else: + bal_sorted_0, bal_sorted_1 = bal_db_0_norm, bal_db_1_norm + raw_sorted_0, raw_sorted_1 = bal_db_0_raw, bal_db_1_raw + + return { + "timestamps": np.array(timestamps), + "bal_0": bal_sorted_0, + "bal_1": bal_sorted_1, + "raw_bal_0": raw_sorted_0, + "raw_bal_1": raw_sorted_1, + "governance_events": governance_events, + } + + +def load_bpt_supply_df( + pool: PoolConfig, + end_date: str = None, + db_path: str = DEFAULT_DB_PATH, +) -> pd.DataFrame: + """Load BPT supply as a DataFrame suitable for do_run_on_historic_data. + + Returns DataFrame with columns: + unix: timestamps in milliseconds + lp_supply: BPT normalized to 1.0 at plausible_start + + The normalization matches the simulator convention: lp_supply=1.0 at the + start of the sim, scaling proportionally as the on-chain pool grows/shrinks. + """ + conn = sqlite3.connect(db_path) + cur = conn.cursor() + + ts_start = _date_to_unix(pool.plausible_start) - 1000 + if end_date: + ts_end = _date_to_unix(end_date) + else: + ts_end = 2_000_000_000 + + cur.execute( + f"""SELECT timestamp, bpt_supply + FROM {pool.db_label} + WHERE timestamp BETWEEN ? AND ? + ORDER BY timestamp""", + (ts_start, ts_end), + ) + rows = cur.fetchall() + conn.close() + + if not rows: + raise ValueError(f"No BPT data for {pool.db_label}") + + initial_bpt = rows[0][1] + return pd.DataFrame({ + # Round to nearest minute boundary so timestamps land on the minute grid + # used by raw_fee_like_amounts_to_fee_like_array. + "unix": [round(r[0] / 60) * 60 * 1000 for r in rows], + "lp_supply": [r[1] / initial_bpt for r in rows], + }) + + +def get_data_end_date(tokens: list, data_root: str = None) -> str: + """Find the latest common date across all token parquets. + + Returns a date string like '2026-02-18 00:00:00'. + """ + if data_root is None: + data_root = os.path.join( + os.path.dirname(__file__), "..", "quantammsim", "data" + ) + + min_end = float("inf") + for ticker in tokens: + path = os.path.join(data_root, f"{ticker}_USD.parquet") + df = pd.read_parquet(path, columns=["unix"]) + last = float(df["unix"].iloc[-1]) + if last < min_end: + min_end = last + + # Convert ms to datetime + dt = datetime.utcfromtimestamp(min_end / 1000) + return dt.strftime("%Y-%m-%d %H:%M:%S") + + +def load_gas_csv(percentile: str) -> pd.DataFrame: + """Load a gas percentile CSV as a DataFrame for do_run_on_historic_data. + + Returns DataFrame with columns [unix, trade_gas_cost_usd], timestamps + floored to minute boundaries. + """ + path = os.path.join(GAS_CSV_DIR, f"Gas_{percentile}.csv") + df = pd.read_csv(path) + df = df.rename(columns={"USD": "trade_gas_cost_usd"}) + df["unix"] = (df["unix"] // 60000) * 60000 # floor to minute boundary + return df + + +def get_gas_costs(pool: PoolConfig, custom: list = None) -> list: + """Return the gas cost battery for a pool's chain. + + For Ethereum, returns a list mixing flat 0.0 with gas percentile labels + (e.g. ["0.0", "50p", "75p", "90p", "95p"]). + For other chains, returns flat USD values. + """ + if custom is not None: + return custom + if pool.chain == "ethereum": + flat = [0.0, 0.1, 0.5, 1.0, 3.0, 5.0, 10.0] + return flat + ETHEREUM_GAS_PERCENTILES + return CHAIN_GAS_COSTS.get(pool.chain, [0.0, 0.1, 1.0]) + + +def print_pool_summary(pool: PoolConfig): + """Print a summary of the pool's on-chain state.""" + print(f"\n{'='*60}") + print(f"Pool: {pool.label}") + print(f" Chain: {pool.chain}") + print(f" Tokens: {pool.tokens[0]}/{pool.tokens[1]}") + print(f" Swap fee: {pool.swap_fee}") + print(f" Start: {pool.plausible_start}") + if pool.on_chain_params: + p = pool.on_chain_params + print(f" On-chain: PR={p['price_ratio']:.4f} " + f"margin={p['margin']} shift_rate={p['shift_rate']} " + f"fee={p['swap_fee']}") + if pool.initial_pool_value_usd: + print(f" TVL: ${pool.initial_pool_value_usd:,.0f} USD") + print(f" Gas battery: {get_gas_costs(pool)}") + print(f"{'='*60}") + + +if __name__ == "__main__": + # Print summary of all pools + for label, pool in POOL_REGISTRY.items(): + try: + extract_on_chain_state(pool) + print_pool_summary(pool) + except Exception as e: + print(f"\n{label}: FAILED — {e}") diff --git a/experiments/run_pool_battery.py b/experiments/run_pool_battery.py new file mode 100644 index 0000000..def2c47 --- /dev/null +++ b/experiments/run_pool_battery.py @@ -0,0 +1,914 @@ +"""Sim-vs-world gas + arb-frequency calibration for on-chain reClAMM pools. + +For each pool in the registry, runs quantammsim forward passes with exact +on-chain parameters across a 2D grid of (gas_cost, arb_frequency), then +compares the simulated pool value trajectory against the actual on-chain +trajectory. + +arb_frequency is the period between arb trades in minutes (1 = every minute, +the most aggressive; higher = sparser arb). + +Generalizes scripts/sim_vs_world_comparison.py to work for any pool. + +Usage: + cd /Users/matthew/Projects/quantammsim-reclamm + source ~/miniconda3/etc/profile.d/conda.sh && conda activate qsim-reclamm + + # Single pool (default gas + arb_freq grid) + python experiments/run_pool_battery.py cbBTC_WETH + + # All pools with data available + python experiments/run_pool_battery.py --all + + # Custom grids + python experiments/run_pool_battery.py cbBTC_WETH --gas-costs 0.0 0.5 1.0 --arb-freqs 1 5 15 60 + + # Dry run (show config without running) + python experiments/run_pool_battery.py cbBTC_WETH --dry-run + + # List available pools + python experiments/run_pool_battery.py --list +""" + +import argparse +import json +import os +import time + +import jax.numpy as jnp +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +from datetime import datetime, timezone + +from experiments.pool_registry import ( + POOL_REGISTRY, + PoolConfig, + extract_initial_state, + extract_on_chain_state, + get_data_end_date, + get_gas_costs, + load_bpt_supply_df, + load_gas_csv, + load_world_history, + print_pool_summary, +) +from quantammsim.runners.jax_runners import do_run_on_historic_data + +PROTOCOL_FEE_SPLIT = 0.5 +DEFAULT_ARB_FREQS = [1, 2, 3, 5, 10, 15, 20] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def sample_at_timestamps(minute_vals, start_unix_sec, timestamps_sec): + """Sample a minute-level array at specific Unix timestamps. + + For each target timestamp, finds the nearest minute index in the + sim output and returns the corresponding value. + """ + indices = np.round((timestamps_sec - start_unix_sec) / 60).astype(int) + indices = np.clip(indices, 0, len(minute_vals) - 1) + return minute_vals[indices] + + +def compute_log_rmse(sim_growth, world_growth): + """RMSE of log(sim/world) across all trajectory points. + + Symmetric in over/under-estimation, natural for multiplicative processes. + A score of 0.02 means typical 2% deviation at any point in time. + """ + log_ratio = np.log(sim_growth / world_growth) + return np.sqrt(np.mean(log_ratio ** 2)) + + +def _start_str_from_pool(pool): + """Derive sim start time from pool's plausible_start, rounded to minute.""" + ts = int( + datetime.strptime(pool.plausible_start, "%Y-%m-%d") + .replace(tzinfo=timezone.utc) + .timestamp() + ) + ts_minute = (ts // 60) * 60 + return datetime.utcfromtimestamp(ts_minute).strftime("%Y-%m-%d %H:%M:%S") + + +def _onchain_params_to_sim(pool): + """Map DB param names to quantammsim param dict (jnp arrays).""" + p = pool.on_chain_params + return { + "price_ratio": jnp.array(p["price_ratio"]), + "centeredness_margin": jnp.array(p["margin"]), + "shift_exponent": jnp.array(p["shift_rate"]), + } + + +# --------------------------------------------------------------------------- +# Core sim runner +# --------------------------------------------------------------------------- + +def run_sim(pool, gas_cost, arb_frequency, initial_state, start, end, + protocol_fee_split=PROTOCOL_FEE_SPLIT, lp_supply_df=None, + noise_config=None): + """Run a single forward pass with exact on-chain params. + + gas_cost can be: + - float: flat gas cost in USD (e.g. 0.0, 0.5) + - str: gas percentile label (e.g. "50p", "90p") — loads time-varying + gas from CSV + + noise_config can be: + - None: no noise model (arb-only, default) + - dict with keys 'noise_model' and 'reclamm_noise_params': inject + Tsoukalas noise model into the sim + + Returns dict with minute-level per-LP value (USD), prices (USD per token), + and start_unix_sec. When lp_supply_df is provided, value_usd is divided + by the interpolated LP supply so it is comparable to BPT-normalized world + balances. + """ + params = _onchain_params_to_sim(pool) + + # Resolve gas: percentile string → DataFrame, float → scalar + gas_cost_df = None + if isinstance(gas_cost, str): + gas_cost_df = load_gas_csv(gas_cost) + flat_gas = 0.0 # placeholder; gas_cost_df overrides + else: + flat_gas = gas_cost + + fp = { + "tokens": pool.tokens, + "rule": "reclamm", + "startDateString": start, + "endDateString": end, + "initial_pool_value": pool.initial_pool_value_usd, + "fees": pool.swap_fee, + "gas_cost": flat_gas, + "arb_fees": 0.0, + "do_arb": True, + "arb_frequency": arb_frequency, + "chunk_period": 1440, + "weight_interpolation_period": 1440, + "reclamm_use_shift_exponent": True, + "reclamm_interpolation_method": "geometric", + "reclamm_centeredness_scaling": False, + "protocol_fee_split": protocol_fee_split, + "reclamm_initial_state": initial_state, + } + + if noise_config is not None: + fp["noise_model"] = noise_config["noise_model"] + fp["reclamm_noise_params"] = noise_config["reclamm_noise_params"] + + result = do_run_on_historic_data( + run_fingerprint=fp, params=params, lp_supply_df=lp_supply_df, + gas_cost_df=gas_cost_df, + ) + + start_unix_sec = datetime.strptime( + start, "%Y-%m-%d %H:%M:%S" + ).replace(tzinfo=timezone.utc).timestamp() + + value_usd = np.array(result["value"]) + + # Normalize to per-LP value so comparison with BPT-normalized world is valid. + # + # Subtlety: the scan applies lp_supply every arb_frequency minutes. + # Between scan steps, reserves are constant (no arb), so the pool value + # reflects the lp_supply from the LAST scan step. If a BPT event occurs + # between scan steps (e.g. at minute 3 when arb_frequency=5), the value + # at minutes 3-4 still reflects the old lp_supply. We must divide by + # the scan-step-aligned lp_supply, not the current minute's lp_supply, + # otherwise we get transient spikes. + if lp_supply_df is not None: + n_minutes = len(value_usd) + # Map each minute to its most recent scan-step time + scan_step_minutes = ( + np.arange(n_minutes) // arb_frequency * arb_frequency + ) + scan_step_times_ms = ( + start_unix_sec * 1000 + scan_step_minutes * 60_000 + ) + lp_unix = np.array(lp_supply_df["unix"]) + lp_vals = np.array(lp_supply_df["lp_supply"]) + indices = np.searchsorted(lp_unix, scan_step_times_ms, side="right") - 1 + indices = np.clip(indices, 0, len(lp_vals) - 1) + value_usd = value_usd / lp_vals[indices] + + return { + "value_usd": value_usd, + "prices": np.array(result["prices"]), # (T, n_tokens) in USD + "start_unix_sec": start_unix_sec, + } + + +# --------------------------------------------------------------------------- +# Pool calibration (2D grid: gas_cost × arb_frequency) +# --------------------------------------------------------------------------- + +def run_pool_calibration(pool, gas_costs, arb_freqs, verbose=True, + noise_config=None): + """Run 2D gas × arb_frequency calibration for a single pool. + + Parameters + ---------- + noise_config : dict, optional + If provided, passed through to run_sim to inject noise model. + Keys: 'noise_model', 'reclamm_noise_params'. + + Returns dict with: + world_growth: array of world growth factors + sim_growths: {(gas_cost, arb_freq): growth array} + timestamps: world timestamps (seconds) + governance_idx: index of first governance event (or n_points) + n_points: number of comparison points + days: array of days from start + gas_costs: list of gas costs + arb_freqs: list of arb frequencies + """ + # Extract on-chain state + initial reserves + extract_on_chain_state(pool) + initial_state = extract_initial_state(pool) + + if verbose: + print_pool_summary(pool) + print(f" Initial state: Ra={initial_state['Ra']:.4f}, " + f"Rb={initial_state['Rb']:.4f}, " + f"Va={initial_state['Va']:.4f}, Vb={initial_state['Vb']:.4f}") + + start_str = _start_str_from_pool(pool) + end_str = get_data_end_date(pool.tokens) + + # Load BPT supply history for LP supply scaling + lp_supply_df = load_bpt_supply_df(pool, end_date=end_str) + + if verbose: + bpt_start = lp_supply_df["lp_supply"].iloc[0] + bpt_end = lp_supply_df["lp_supply"].iloc[-1] + print(f" BPT supply: {bpt_start:.4f} → {bpt_end:.4f} " + f"({(bpt_end/bpt_start - 1)*100:+.1f}%)") + print(f" Sim period: {start_str} to {end_str}") + n_runs = len(gas_costs) * len(arb_freqs) + print(f" Grid: {len(gas_costs)} gas × {len(arb_freqs)} arb_freq = {n_runs} runs") + + # Load world history (BPT-normalized balances + governance events) + # BPT-normalized is correct for the growth ratio metric since LP supply + # cancels out of sim_growth/world_growth. Raw balances are available + # in world["raw_bal_0/1"] for absolute trajectory comparison. + world = load_world_history(pool, end_date=end_str) + world_ts = world["timestamps"] + world_bal_0 = world["bal_0"] + world_bal_1 = world["bal_1"] + gov_events = world["governance_events"] + + if verbose: + print(f" World points: {len(world_ts)}") + if gov_events: + for ts, field, old, new in gov_events: + dt = datetime.utcfromtimestamp(ts).strftime("%Y-%m-%d") + print(f" Governance: {field} {old:.6f} -> {new:.6f} on {dt}") + else: + print(" No governance events") + + # Governance cutoff index + if gov_events: + gov_idx = np.searchsorted(world_ts, gov_events[0][0]) + else: + gov_idx = len(world_ts) + + # Run sims across the 2D grid + sim_results = {} + prices_min = None + start_sec = None + + for gc in gas_costs: + for af in arb_freqs: + if verbose: + print(f"\n Running gas=${gc}, arb_freq={af}min...") + t0 = time.time() + result = run_sim(pool, gc, af, initial_state, start_str, end_str, + lp_supply_df=lp_supply_df, + noise_config=noise_config) + elapsed = time.time() - t0 + if verbose: + print(f" Done in {elapsed:.1f}s") + sim_results[(gc, af)] = result + if prices_min is None: + prices_min = result["prices"] + start_sec = result["start_unix_sec"] + + # Truncate at governance + n = min(gov_idx, len(world_ts)) + world_ts_trunc = world_ts[:n] + + # Sample USD prices at world timestamps for world valuation + prices_at_world = np.stack([ + sample_at_timestamps(prices_min[:, i], start_sec, world_ts_trunc) + for i in range(prices_min.shape[1]) + ], axis=1) + + # World value in USD = sum(bal_i * price_usd_i) + world_value = ( + world_bal_0[:n] * prices_at_world[:, 0] + + world_bal_1[:n] * prices_at_world[:, 1] + ) + world_growth = world_value / world_value[0] + + # Sim growths at world timestamps + sim_growths = {} + for key, result in sim_results.items(): + sim_val = sample_at_timestamps( + result["value_usd"], start_sec, world_ts_trunc, + ) + sim_growths[key] = sim_val / sim_val[0] + + days = (world_ts_trunc - world_ts_trunc[0]) / 86400 + + return { + "world_growth": world_growth, + "sim_growths": sim_growths, + "timestamps": world_ts_trunc, + "governance_idx": gov_idx, + "n_points": n, + "days": days, + "gas_costs": list(gas_costs), + "arb_freqs": list(arb_freqs), + } + + +# --------------------------------------------------------------------------- +# Plotting +# --------------------------------------------------------------------------- + +def plot_pool_calibration(pool, calibration, output_dir="results", suffix=""): + """Plot 2D gas × arb_freq calibration as heatmap + time series. + + Left: heatmap of final % deviation (gas_cost × arb_freq). + Right: time series for each arb_freq at best gas cost. + """ + os.makedirs(output_dir, exist_ok=True) + + world_growth = calibration["world_growth"] + sim_growths = calibration["sim_growths"] + days = calibration["days"] + gas_costs = calibration["gas_costs"] + arb_freqs = calibration["arb_freqs"] + + # Build RMSE matrix (log ratio, %) + rmse_matrix = np.zeros((len(arb_freqs), len(gas_costs))) + for i, af in enumerate(arb_freqs): + for j, gc in enumerate(gas_costs): + rmse_matrix[i, j] = compute_log_rmse( + sim_growths[(gc, af)], world_growth + ) * 100 + + fig, (ax_heat, ax_ts) = plt.subplots( + 1, 2, figsize=(18, 7), + gridspec_kw={"width_ratios": [1, 1.5]}, + ) + + # Left: heatmap of trajectory RMSE + im = ax_heat.imshow( + rmse_matrix, aspect="auto", cmap="RdYlGn_r", + vmin=0, vmax=rmse_matrix.max(), origin="lower", + ) + ax_heat.set_xticks(range(len(gas_costs))) + gas_labels = [ + f"gas {gc}" if isinstance(gc, str) else f"${gc}" + for gc in gas_costs + ] + ax_heat.set_xticklabels(gas_labels, fontsize=9) + ax_heat.set_yticks(range(len(arb_freqs))) + ax_heat.set_yticklabels([f"{af}min" for af in arb_freqs], fontsize=9) + ax_heat.set_xlabel("gas cost (USD)") + ax_heat.set_ylabel("arb frequency (minutes)") + ax_heat.set_title("Trajectory RMSE (log ratio, %)") + + # Annotate cells + for i in range(len(arb_freqs)): + for j in range(len(gas_costs)): + val = rmse_matrix[i, j] + color = "white" if val > rmse_matrix.max() * 0.6 else "black" + ax_heat.text(j, i, f"{val:.2f}%", ha="center", va="center", + fontsize=8, color=color) + + # Mark cell with least-negative mean bias (closest to 0 from below). + # If no cell is below world on average, fall back to lowest RMSE. + bias_matrix = np.zeros_like(rmse_matrix) + for i, af in enumerate(arb_freqs): + for j, gc in enumerate(gas_costs): + bias_matrix[i, j] = float(np.mean( + np.log(sim_growths[(gc, af)] / world_growth) + )) + negative_mask = bias_matrix < 0 + if negative_mask.any(): + # Among negative cells, find the one closest to 0 (max value) + masked = np.where(negative_mask, bias_matrix, -np.inf) + best_idx = np.unravel_index(np.argmax(masked), masked.shape) + else: + best_idx = np.unravel_index(np.argmin(rmse_matrix), rmse_matrix.shape) + ax_heat.add_patch(plt.Rectangle( + (best_idx[1] - 0.5, best_idx[0] - 0.5), 1, 1, + fill=False, edgecolor="lime", linewidth=3, + )) + + fig.colorbar(im, ax=ax_heat, label="RMSE (%)", shrink=0.8) + + # Right: 4 closest from below + 1 first above world + # Rationale: sim should underestimate (can't capture organic swaps, MEV + # rebates, etc.), so being below world is expected. The one-above config + # brackets where the sim crosses from conservative to optimistic. + ax_ts.axhline(y=0.0, color="brown", linewidth=2, label="world (on-chain)") + + # Classify configs by mean log ratio (trajectory-average bias). + # Using the mean rather than endpoint avoids a curve that's above + # world for 80% of the trajectory being classified as "below" + # just because it dips at the end. + below = [] # (mean_bias, rmse, gc, af) where sim < world on average + above = [] # (mean_bias, rmse, gc, af) where sim >= world on average + for (gc, af), sg in sim_growths.items(): + mean_bias = float(np.mean(np.log(sg / world_growth))) + rmse = compute_log_rmse(sg, world_growth) + if mean_bias < 0: + below.append((mean_bias, rmse, gc, af)) + else: + above.append((mean_bias, rmse, gc, af)) + + # Sort: below by mean_bias descending (closest to 0 first) + below.sort(key=lambda x: x[0], reverse=True) + # Sort: above by mean_bias ascending (closest to 0 first) + above.sort(key=lambda x: x[0]) + + # Select: up to 4 from below, 1 from above, fill if needed + selected = [] + n_below = min(4, len(below)) + n_above = min(1, len(above)) + selected.extend(below[:n_below]) + selected.extend(above[:n_above]) + remaining = 5 - len(selected) + if remaining > 0 and len(below) > n_below: + selected.extend(below[n_below:n_below + remaining]) + remaining = 5 - len(selected) + if remaining > 0 and len(above) > n_above: + selected.extend(above[n_above:n_above + remaining]) + + colors_below = plt.cm.Blues(np.linspace(0.4, 0.8, n_below)) + colors_above = np.array([[0.8, 0.2, 0.2, 1.0]]) # red for above + plot_colors = list(colors_below) + list(colors_above[:n_above]) + # Fill remaining with grey + while len(plot_colors) < len(selected): + plot_colors.append([0.5, 0.5, 0.5, 1.0]) + + for rank, (mean_bias, rmse, gc, af) in enumerate(selected): + dev = (sim_growths[(gc, af)] / world_growth - 1) * 100 + gc_label = f"gas {gc}" if isinstance(gc, str) else f"gas=${gc}" + marker = "\u25b2" if mean_bias >= 0 else "\u25bc" # ▲ above, ▼ below + ax_ts.plot(days, dev, color=plot_colors[rank], linewidth=2, + label=f"{marker} {gc_label}, arb={af}min " + f"bias={mean_bias*100:+.2f}% RMSE={rmse*100:.2f}%") + + ax_ts.set_xlabel("days") + ax_ts.set_ylabel("% deviation from world") + trunc = " (pre-governance)" if calibration["governance_idx"] < calibration["n_points"] + 1 else "" + ax_ts.set_title(f"Best bracket: {n_below} below + {n_above} above world{trunc}") + ax_ts.legend(fontsize=7, loc="best") + ax_ts.grid(True, alpha=0.2) + + p = pool.on_chain_params + fig.suptitle( + f"{pool.label} ({pool.chain}) — {pool.tokens[0]}/{pool.tokens[1]}\n" + f"PR={p['price_ratio']:.4f} margin={p['margin']} " + f"shift={p['shift_rate']} fee={pool.swap_fee} " + f"TVL=${pool.initial_pool_value_usd:,.0f} " + f"protocol_fee={PROTOCOL_FEE_SPLIT}", + fontsize=10, + ) + plt.tight_layout() + + out = os.path.join(output_dir, f"gas_calibration_{pool.label}{suffix}.png") + plt.savefig(out, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {out}") + return out + + +def plot_cross_pool_summary(all_results, output_dir="results"): + """Plot cross-pool comparison: best (gas, arb_freq) and residual deviation.""" + os.makedirs(output_dir, exist_ok=True) + + pool_labels = [] + best_configs = [] + best_devs = [] + + best_rmses = [] + best_biases = [] + for pool, cal in all_results: + wg = cal["world_growth"] + # Best = least-negative mean bias (closest from below) + below_keys = [ + k for k in cal["sim_growths"] + if np.mean(np.log(cal["sim_growths"][k] / wg)) < 0 + ] + if below_keys: + best_key = max( + below_keys, + key=lambda k: np.mean(np.log(cal["sim_growths"][k] / wg)), + ) + else: + best_key = min( + cal["sim_growths"].keys(), + key=lambda k: compute_log_rmse(cal["sim_growths"][k], wg), + ) + best_rmse = compute_log_rmse(cal["sim_growths"][best_key], wg) * 100 + best_bias = float(np.mean(np.log(cal["sim_growths"][best_key] / wg))) * 100 + pool_labels.append(f"{pool.label}\n({pool.chain})") + best_configs.append(best_key) + best_rmses.append(best_rmse) + best_biases.append(best_bias) + + fig, (ax_cfg, ax_dev) = plt.subplots(1, 2, figsize=(16, 6)) + + x = np.arange(len(pool_labels)) + + # Left: RMSE at best config + config_strs = [ + f"gas {gc}\narb={af}min" if isinstance(gc, str) + else f"gas=${gc}\narb={af}min" + for gc, af in best_configs + ] + ax_cfg.barh(x, best_rmses, color="steelblue") + ax_cfg.set_yticks(x) + ax_cfg.set_yticklabels(pool_labels, fontsize=9) + ax_cfg.set_xlabel("Trajectory RMSE (%)") + ax_cfg.set_title("RMSE at best config") + for i, (cs, rmse) in enumerate(zip(config_strs, best_rmses)): + ax_cfg.text(rmse + 0.05, i, f"{cs} (RMSE={rmse:.2f}%)", va="center", fontsize=8) + ax_cfg.grid(True, alpha=0.2, axis="x") + + # Right: mean bias at best config (negative = conservative) + colors = ["green" if d < 0 else "orange" if d < 1 else "red" + for d in best_biases] + ax_dev.bar(x, best_biases, color=colors) + ax_dev.axhline(y=0, color="brown", linewidth=1) + ax_dev.set_xticks(x) + ax_dev.set_xticklabels(pool_labels, fontsize=8) + ax_dev.set_ylabel("Mean bias (%)") + ax_dev.set_title("Mean trajectory bias at best config") + ax_dev.grid(True, alpha=0.2, axis="y") + + fig.suptitle("Cross-pool gas + arb frequency calibration", fontsize=12) + plt.tight_layout() + + out = os.path.join(output_dir, "gas_calibration_cross_pool.png") + plt.savefig(out, dpi=150, bbox_inches="tight") + plt.close() + print(f"\nSaved cross-pool summary: {out}") + return out + + +# --------------------------------------------------------------------------- +# Summary printing +# --------------------------------------------------------------------------- + +def print_calibration_summary(pool, calibration): + """Print a summary table of the 2D grid results (trajectory RMSE).""" + world_growth = calibration["world_growth"] + sim_growths = calibration["sim_growths"] + n_days = calibration["days"][-1] + gas_costs = calibration["gas_costs"] + arb_freqs = calibration["arb_freqs"] + + print(f"\n {pool.label} ({pool.chain}) — {n_days:.0f} days") + print(f" World growth: {world_growth[-1]:.4f}") + + # Print as table: rows=arb_freq, cols=gas_cost (values = trajectory RMSE %) + col_label = "arb\\gas" + gas_labels = [ + f"gas {gc}" if isinstance(gc, str) else f"${gc}" + for gc in gas_costs + ] + header = f" {col_label:<10}" + "".join(f"{gl:<10}" for gl in gas_labels) + print(header) + print(f" {'-'*len(header)}") + for af in arb_freqs: + row = f" {af:>4}min " + for gc in gas_costs: + rmse = compute_log_rmse( + sim_growths[(gc, af)], world_growth + ) * 100 + row += f"{rmse:>8.2f}% " + print(row) + + +# --------------------------------------------------------------------------- +# Data availability check +# --------------------------------------------------------------------------- + +def check_data_available(pool, data_root=None): + """Check that all required parquet files exist for a pool.""" + if data_root is None: + data_root = os.path.join( + os.path.dirname(__file__), "..", "quantammsim", "data" + ) + for ticker in pool.tokens: + if ticker == "USDC": + continue + path = os.path.join(data_root, f"{ticker}_USD.parquet") + if not os.path.exists(path): + return False + return True + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser( + description="Sim-vs-world gas + arb-frequency calibration for on-chain reClAMM pools" + ) + parser.add_argument("pool", nargs="?", + help="Pool label (e.g. cbBTC_WETH)") + parser.add_argument("--all", action="store_true", + help="Run all pools with available data") + parser.add_argument("--list", action="store_true", + help="List available pools and exit") + parser.add_argument("--dry-run", action="store_true", + help="Show pool config without running") + parser.add_argument("--gas-costs", nargs="+", type=float, default=None, + help="Override gas cost battery (flat USD values)") + parser.add_argument("--arb-freqs", nargs="+", type=int, + default=DEFAULT_ARB_FREQS, + help="Arb frequency values in minutes (default: 1 2 3 5 10 15 20)") + parser.add_argument("--protocol-fee", type=float, default=PROTOCOL_FEE_SPLIT, + help="Protocol fee split (default 0.5)") + parser.add_argument("--output-dir", default="results", + help="Directory for output plots and JSON") + parser.add_argument("--calibrate-noise", action="store_true", + help="Calibrate Tsoukalas noise model from Balancer API + DB " + "and inject into sim fingerprints") + parser.add_argument("--noise-model", choices=["sqrt", "log", "loglinear"], + default="sqrt", + help="Noise model variant (default: sqrt)") + parser.add_argument("--noise-params-json", default=None, + help="Path to hierarchical noise params JSON " + "(from calibrate_noise_hierarchical.py). " + "Looks up pool by address or uses --predict.") + args = parser.parse_args() + + # --list mode + if args.list: + print("\nAvailable pools:\n") + for label, pool in POOL_REGISTRY.items(): + has_data = check_data_available(pool) + status = "READY" if has_data else "MISSING DATA" + try: + extract_on_chain_state(pool) + print_pool_summary(pool) + except Exception as e: + print(f" {label}: {e}") + print(f" Data: {status}") + return + + # Determine which pools to run + if args.all: + pool_labels = [ + label for label, pool in POOL_REGISTRY.items() + if check_data_available(pool) + ] + if not pool_labels: + print("No pools have all required data files.") + return + elif args.pool: + if args.pool not in POOL_REGISTRY: + print(f"Unknown pool: {args.pool}") + print(f"Available: {list(POOL_REGISTRY.keys())}") + return + if not check_data_available(POOL_REGISTRY[args.pool]): + missing = [ + f"{t}_USD.parquet" for t in POOL_REGISTRY[args.pool].tokens + if t != "USDC" and not os.path.exists( + os.path.join( + os.path.dirname(__file__), "..", "quantammsim", + "data", f"{t}_USD.parquet" + ) + ) + ] + print(f"Missing data for {args.pool}: {missing}") + return + pool_labels = [args.pool] + else: + parser.print_help() + return + + # Collect runs + runs = [] + for label in pool_labels: + pool = POOL_REGISTRY[label] + gas_costs = get_gas_costs(pool, args.gas_costs) + runs.append((pool, gas_costs)) + + arb_freqs = args.arb_freqs + n_total = sum(len(gcs) * len(arb_freqs) for _, gcs in runs) + + print(f"\n{'='*60}") + print(f"GAS + ARB CALIBRATION: {len(runs)} pool(s), {n_total} total runs") + for pool, gcs in runs: + print(f" {pool.label:15s} ({pool.chain:10s}) " + f"gas={gcs} arb_freq={arb_freqs}") + print(f" Protocol fee split: {args.protocol_fee}") + print(f"{'='*60}") + + if args.dry_run: + print("\n--- DRY RUN ---\n") + for pool, gas_costs in runs: + extract_on_chain_state(pool) + initial_state = extract_initial_state(pool) + print_pool_summary(pool) + print(f" Initial state: {initial_state}") + start = _start_str_from_pool(pool) + end = get_data_end_date(pool.tokens) + print(f" Sim period: {start} to {end}") + world = load_world_history(pool, end_date=end) + n_gov = len(world["governance_events"]) + print(f" World points: {len(world['timestamps'])}, " + f"governance events: {n_gov}") + if world["governance_events"]: + for ts, field, old, new in world["governance_events"]: + dt = datetime.utcfromtimestamp(ts).strftime("%Y-%m-%d") + print(f" {field} {old:.6f} -> {new:.6f} on {dt}") + print(f" Gas battery: {gas_costs}") + print(f" Arb freqs: {arb_freqs}") + n_runs = len(gas_costs) * len(arb_freqs) + print(f" Total runs: {n_runs}") + return + + # Execute calibration for each pool + all_results = [] + for pool, gas_costs in runs: + print(f"\n{'#'*60}") + print(f"POOL: {pool.label}") + print(f"{'#'*60}") + + # Noise calibration (if requested) + noise_config = None + if args.noise_params_json and args.noise_model == "loglinear": + # Load hierarchical noise params from JSON + with open(args.noise_params_json) as f: + hier_data = json.load(f) + # Look up pool by address + addr = pool.pool_address.lower() + pool_params = None + for p in hier_data["pools"]: + pid = p["pool_id"].lower().replace("0x", "") + if pid.startswith(addr) or addr.startswith(pid): + pool_params = p["noise_params"] + break + if pool_params is None: + # Fall back to population-level prediction + from scripts.calibrate_noise_hierarchical import predict_new_pool + chain_map = {"ethereum": "MAINNET", "base": "BASE", + "gnosis": "GNOSIS", "arbitrum": "ARBITRUM", + "polygon": "POLYGON", "optimism": "OPTIMISM", + "sonic": "SONIC", "avalanche": "AVALANCHE"} + api_chain = chain_map.get(pool.chain, pool.chain.upper()) + # Reconstruct posteriors + encoding from the JSON + posteriors_from_json = { + "Phi_mean": np.array(hier_data["Phi"]), + } + encoding_from_json = { + "covariate_names": hier_data["covariate_names"], + } + pool_params = predict_new_pool( + posteriors_from_json, encoding_from_json, + api_chain, pool.tokens, pool.swap_fee, + ) + print(f"\n Using population-level loglinear params (pool not in JSON)") + else: + print(f"\n Using hierarchical loglinear params for {pool.label}") + print(f" b_0 = {pool_params['b_0']:.4f}, " + f"b_sigma = {pool_params['b_sigma']:.6f}, " + f"b_c = {pool_params['b_c']:.4f}") + # Strip metadata keys (prefixed with _) — JAX can't trace strings + sim_params = {k: v for k, v in pool_params.items() + if not k.startswith("_")} + noise_config = { + "noise_model": "loglinear", + "reclamm_noise_params": sim_params, + } + elif args.calibrate_noise: + from scripts.calibrate_reclamm_noise import ( + build_calibration_df, + run_ols_calibration, + ) + noise_model_name = ( + "tsoukalas_sqrt" if args.noise_model == "sqrt" + else "tsoukalas_log" + ) + print(f"\n Calibrating noise model ({args.noise_model})...") + cal_df = build_calibration_df(pool) + noise_params, diag = run_ols_calibration( + cal_df, pool.swap_fee, args.noise_model, + ) + print(f" R² = {diag['r_squared']:.4f}, n = {diag['n_obs']}") + for key in ["a_0", "a_sigma", "a_c"]: + param_key = "a_0_base" if key == "a_0" else key + val = noise_params[param_key] + se = diag["se"][key] + t_stat = val / se if se > 0 else float("inf") + print(f" {key:>8} = {val:>10.4f} (SE={se:.4f}, t={t_stat:.2f})") + noise_config = { + "noise_model": noise_model_name, + "reclamm_noise_params": noise_params, + } + + t0 = time.time() + calibration = run_pool_calibration( + pool, gas_costs, arb_freqs, noise_config=noise_config, + ) + elapsed = time.time() - t0 + + print_calibration_summary(pool, calibration) + plot_pool_calibration(pool, calibration, output_dir=args.output_dir) + all_results.append((pool, calibration)) + + print(f"\n Total time for {pool.label}: {elapsed:.0f}s") + + # Cross-pool summary + if len(all_results) > 1: + plot_cross_pool_summary(all_results, output_dir=args.output_dir) + + # Final summary table + print(f"\n{'='*60}") + print("CALIBRATION COMPLETE") + print(f"{'='*60}") + print(f"\n{'Pool':<16} {'Chain':<10} {'Best Gas':>9} {'Best Arb':>9} " + f"{'Bias':>8} {'RMSE':>8} {'Days':>6}") + print("-" * 70) + for pool, cal in all_results: + wg = cal["world_growth"] + # Best = least-negative mean bias (closest from below) + below_keys = [ + k for k in cal["sim_growths"] + if np.mean(np.log(cal["sim_growths"][k] / wg)) < 0 + ] + if below_keys: + best_key = max( + below_keys, + key=lambda k: np.mean(np.log(cal["sim_growths"][k] / wg)), + ) + else: + best_key = min( + cal["sim_growths"].keys(), + key=lambda k: compute_log_rmse(cal["sim_growths"][k], wg), + ) + best_bias = float(np.mean(np.log(cal["sim_growths"][best_key] / wg))) * 100 + best_rmse = compute_log_rmse(cal["sim_growths"][best_key], wg) * 100 + n_days = cal["days"][-1] + gc_label = f"gas {best_key[0]}" if isinstance(best_key[0], str) else f"${best_key[0]}" + print(f"{pool.label:<16} {pool.chain:<10} {gc_label:<9} " + f"{best_key[1]:>4}min {best_bias:>+7.2f}% {best_rmse:>7.2f}% {n_days:>5.0f}d") + + # Save JSON summary + os.makedirs(args.output_dir, exist_ok=True) + summary = [] + for pool, cal in all_results: + wg_arr = cal["world_growth"] + wg_final = float(wg_arr[-1]) + pool_summary = { + "label": pool.label, + "chain": pool.chain, + "tokens": pool.tokens, + "swap_fee": pool.swap_fee, + "tvl_usd": pool.initial_pool_value_usd, + "on_chain_params": pool.on_chain_params, + "n_days": float(cal["days"][-1]), + "n_governance_events": 1 if cal["governance_idx"] < cal["n_points"] else 0, + "world_growth": wg_final, + "grid_results": {}, + } + for (gc, af) in sorted(cal["sim_growths"].keys(), key=lambda k: (str(k[0]), k[1])): + sg_arr = cal["sim_growths"][(gc, af)] + rmse = compute_log_rmse(sg_arr, wg_arr) * 100 + pool_summary["grid_results"][f"gas={gc}_arb={af}"] = { + "gas_cost": gc, + "arb_frequency": af, + "sim_growth": float(sg_arr[-1]), + "pct_deviation": float((sg_arr[-1] / wg_final - 1) * 100), + "trajectory_rmse_pct": float(rmse), + } + summary.append(pool_summary) + + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + json_path = os.path.join(args.output_dir, f"gas_calibration_{ts}.json") + with open(json_path, "w") as f: + json.dump(summary, f, indent=2, default=str) + print(f"\nSummary saved to {json_path}") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 8859b55..19cb8f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,8 @@ dependencies = [ "plotly", "dask", "Historic-Crypto", + "gdown", + "binance_historical_data", "bidask", "optax", "jsonpickle", @@ -113,4 +115,4 @@ exclude_lines = [ "def __repr__", "raise NotImplementedError", "if __name__ == .__main__.:", -] \ No newline at end of file +] diff --git a/quantammsim/pools/base_pool.py b/quantammsim/pools/base_pool.py index b695d8c..706c939 100644 --- a/quantammsim/pools/base_pool.py +++ b/quantammsim/pools/base_pool.py @@ -99,7 +99,11 @@ def calculate_reserves_with_dynamic_inputs( run_fingerprint: Dict[str, Any], prices: jnp.ndarray, start_index: jnp.ndarray, - dynamic_inputs: Any, + fees_array: jnp.ndarray, + arb_thresh_array: jnp.ndarray, + arb_fees_array: jnp.ndarray, + trade_array: jnp.ndarray, + lp_supply_array: jnp.ndarray = None, additional_oracle_input: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: pass diff --git a/quantammsim/pools/reCLAMM/reclamm.py b/quantammsim/pools/reCLAMM/reclamm.py index 5b913ff..2413777 100644 --- a/quantammsim/pools/reCLAMM/reclamm.py +++ b/quantammsim/pools/reCLAMM/reclamm.py @@ -18,22 +18,6 @@ from quantammsim.core_simulator.dynamic_inputs import materialize_dynamic_inputs from quantammsim.pools.base_pool import AbstractPool - - -def _prepare_dynamic_array(arr, start_index, bout_length, arb_frequency, max_len): - """Slice and decimate a dynamic input array to match arb_prices shape. - - Scalar (1,) arrays are broadcast to (max_len,). - Full-length arrays are sliced to the bout window then decimated. - """ - if arr.shape[0] <= 1: - return jnp.broadcast_to(arr, (max_len,) + arr.shape[1:]) - sliced = dynamic_slice(arr, (start_index[0],), (bout_length - 1,)) - if arb_frequency != 1: - sliced = sliced[::arb_frequency] - return sliced - - from quantammsim.pools.reCLAMM.reclamm_reserves import ( initialise_reclamm_reserves, calibrate_arc_length_speed, @@ -50,6 +34,22 @@ def _prepare_dynamic_array(arr, start_index, bout_length, arb_frequency, max_len SHIFT_EXPONENT_DIVISOR = 124649.0 +def _prepare_dynamic_array(arr, start_index, bout_length, arb_frequency, max_len): + """Slice and decimate a dynamic input array to match arb_prices shape.""" + arr = jnp.asarray(arr) + if arr.ndim == 0: + return jnp.full((max_len,), arr, dtype=arr.dtype) + if arr.shape[0] <= 1: + return jnp.broadcast_to(arr, (max_len,) + arr.shape[1:]) + + start = (start_index[0],) + (0,) * (arr.ndim - 1) + slice_sizes = (bout_length - 1,) + arr.shape[1:] + sliced = dynamic_slice(arr, start, slice_sizes) + if arb_frequency != 1: + sliced = sliced[::arb_frequency] + return sliced + + class _PoolState(NamedTuple): """Intermediate state produced by _init_pool_state. @@ -214,43 +214,70 @@ def _resolve_fees(params, run_fingerprint): return jnp.squeeze(params["fees"]) return run_fingerprint["fees"] - @partial(jit, static_argnums=(2,)) - def calculate_reserves_with_fees( + @staticmethod + def _resolve_ste_temperature(run_fingerprint): + """Resolve STE gate temperature for differentiable reCLAMM transitions.""" + return run_fingerprint.get("ste_temperature") + + def _resolve_noise_inputs( self, - params: Dict[str, Any], run_fingerprint: Dict[str, Any], prices: jnp.ndarray, start_index: jnp.ndarray, - additional_oracle_input: Optional[jnp.ndarray] = None, + arb_len: int, lp_supply_array: Optional[jnp.ndarray] = None, - ) -> jnp.ndarray: - s = self._init_pool_state(params, run_fingerprint, prices, start_index) - + ): + """Prepare optional lp-supply and noise-model inputs for reserve scans.""" bout_length = run_fingerprint["bout_length"] arb_freq = run_fingerprint["arb_frequency"] - lp_prepared = ( - _prepare_dynamic_array( - lp_supply_array, start_index, bout_length, - arb_freq, s.arb_prices.shape[0], + + lp_prepared = None + if lp_supply_array is not None: + lp_prepared = _prepare_dynamic_array( + lp_supply_array, + start_index=start_index, + bout_length=bout_length, + arb_frequency=arb_freq, + max_len=arb_len, ) - if lp_supply_array is not None else None - ) noise_model = run_fingerprint.get("noise_model", "ratio") noise_params = run_fingerprint.get("reclamm_noise_params", None) if noise_params is not None and type(noise_params) is not dict: noise_params = dict(noise_params) + arb_vol = None if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): - volatility_array = self.calculate_volatility_array( - prices, run_fingerprint, - ) + volatility_array = self.calculate_volatility_array(prices, run_fingerprint) arb_vol = _prepare_dynamic_array( - volatility_array, start_index, bout_length, - arb_freq, s.arb_prices.shape[0], + volatility_array, + start_index=start_index, + bout_length=bout_length, + arb_frequency=arb_freq, + max_len=arb_len, ) - else: - arb_vol = None + + return lp_prepared, noise_model, noise_params, arb_vol + + @partial(jit, static_argnums=(2,)) + def calculate_reserves_with_fees( + self, + params: Dict[str, Any], + run_fingerprint: Dict[str, Any], + prices: jnp.ndarray, + start_index: jnp.ndarray, + additional_oracle_input: Optional[jnp.ndarray] = None, + lp_supply_array: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: + s = self._init_pool_state(params, run_fingerprint, prices, start_index) + ste_temperature = self._resolve_ste_temperature(run_fingerprint) + lp_prepared, noise_model, noise_params, arb_vol = self._resolve_noise_inputs( + run_fingerprint, + prices, + start_index, + s.arb_prices.shape[0], + lp_supply_array=lp_supply_array, + ) if run_fingerprint["do_arb"]: return _jax_calc_reclamm_reserves_with_fees( @@ -268,6 +295,7 @@ def calculate_reserves_with_fees( arc_length_speed=s.arc_length_speed, centeredness_scaling=s.centeredness_scaling, protocol_fee_split=run_fingerprint.get("protocol_fee_split", 0.0), + ste_temperature=ste_temperature, noise_trader_ratio=run_fingerprint.get("noise_trader_ratio", 0.0), lp_supply_array=lp_prepared, noise_model=noise_model, @@ -295,33 +323,15 @@ def calculate_reserves_and_fee_revenue_with_fees( LP fee revenue per timestep in USD. """ s = self._init_pool_state(params, run_fingerprint, prices, start_index) - - bout_length = run_fingerprint["bout_length"] - arb_freq = run_fingerprint["arb_frequency"] - lp_prepared = ( - _prepare_dynamic_array( - lp_supply_array, start_index, bout_length, - arb_freq, s.arb_prices.shape[0], - ) - if lp_supply_array is not None else None + ste_temperature = self._resolve_ste_temperature(run_fingerprint) + lp_prepared, noise_model, noise_params, arb_vol = self._resolve_noise_inputs( + run_fingerprint, + prices, + start_index, + s.arb_prices.shape[0], + lp_supply_array=lp_supply_array, ) - noise_model = run_fingerprint.get("noise_model", "ratio") - noise_params = run_fingerprint.get("reclamm_noise_params", None) - if noise_params is not None and type(noise_params) is not dict: - noise_params = dict(noise_params) - - if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): - volatility_array = self.calculate_volatility_array( - prices, run_fingerprint, - ) - arb_vol = _prepare_dynamic_array( - volatility_array, start_index, bout_length, - arb_freq, s.arb_prices.shape[0], - ) - else: - arb_vol = None - if run_fingerprint["do_arb"]: return _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( s.initial_reserves, s.Va, s.Vb, @@ -338,6 +348,7 @@ def calculate_reserves_and_fee_revenue_with_fees( arc_length_speed=s.arc_length_speed, centeredness_scaling=s.centeredness_scaling, protocol_fee_split=run_fingerprint.get("protocol_fee_split", 0.0), + ste_temperature=ste_temperature, noise_trader_ratio=run_fingerprint.get("noise_trader_ratio", 0.0), lp_supply_array=lp_prepared, noise_model=noise_model, @@ -368,10 +379,8 @@ def calculate_reserves_and_fee_revenue_with_dynamic_inputs( LP fee revenue per timestep in USD. """ s = self._init_pool_state(params, run_fingerprint, prices, start_index) - bout_length = run_fingerprint["bout_length"] - max_len = bout_length - 1 - if run_fingerprint["arb_frequency"] != 1: - max_len = max_len // run_fingerprint["arb_frequency"] + ste_temperature = self._resolve_ste_temperature(run_fingerprint) + max_len = s.arb_prices.shape[0] materialized_inputs = materialize_dynamic_inputs( dynamic_inputs, run_fingerprint.get("dynamic_input_flags"), @@ -380,22 +389,13 @@ def calculate_reserves_and_fee_revenue_with_dynamic_inputs( do_trades=False, dtype=s.arb_prices.dtype, ) - - noise_model = run_fingerprint.get("noise_model", "ratio") - noise_params = run_fingerprint.get("reclamm_noise_params", None) - if noise_params is not None and type(noise_params) is not dict: - noise_params = dict(noise_params) - - if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): - volatility_array = self.calculate_volatility_array( - prices, run_fingerprint, - ) - arb_vol = _prepare_dynamic_array( - volatility_array, start_index, bout_length, - run_fingerprint["arb_frequency"], max_len, - ) - else: - arb_vol = None + _, noise_model, noise_params, arb_vol = self._resolve_noise_inputs( + run_fingerprint, + prices, + start_index, + s.arb_prices.shape[0], + lp_supply_array=None, + ) return _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( s.initial_reserves, s.Va, s.Vb, @@ -413,6 +413,7 @@ def calculate_reserves_and_fee_revenue_with_dynamic_inputs( arc_length_speed=s.arc_length_speed, centeredness_scaling=s.centeredness_scaling, protocol_fee_split=run_fingerprint.get("protocol_fee_split", 0.0), + ste_temperature=ste_temperature, noise_trader_ratio=run_fingerprint.get("noise_trader_ratio", 0.0), lp_supply_array=materialized_inputs.lp_supply, noise_model=noise_model, @@ -428,9 +429,20 @@ def _calculate_reserves_zero_fees( prices: jnp.ndarray, start_index: jnp.ndarray, additional_oracle_input: Optional[jnp.ndarray] = None, + lp_supply_array: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: """Protected zero-fee implementation for hooks and weight calculation.""" s = self._init_pool_state(params, run_fingerprint, prices, start_index) + ste_temperature = self._resolve_ste_temperature(run_fingerprint) + lp_prepared = None + if lp_supply_array is not None: + lp_prepared = _prepare_dynamic_array( + lp_supply_array, + start_index=start_index, + bout_length=run_fingerprint["bout_length"], + arb_frequency=run_fingerprint["arb_frequency"], + max_len=s.arb_prices.shape[0], + ) if run_fingerprint["do_arb"]: return _jax_calc_reclamm_reserves_zero_fees( @@ -441,6 +453,8 @@ def _calculate_reserves_zero_fees( s.seconds_per_step, arc_length_speed=s.arc_length_speed, centeredness_scaling=s.centeredness_scaling, + ste_temperature=ste_temperature, + lp_supply_array=lp_prepared, ) return jnp.broadcast_to(s.initial_reserves, s.arb_prices.shape) @@ -451,9 +465,15 @@ def calculate_reserves_zero_fees( prices: jnp.ndarray, start_index: jnp.ndarray, additional_oracle_input: Optional[jnp.ndarray] = None, + lp_supply_array: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: return self._calculate_reserves_zero_fees( - params, run_fingerprint, prices, start_index, additional_oracle_input + params, + run_fingerprint, + prices, + start_index, + additional_oracle_input, + lp_supply_array, ) @partial(jit, static_argnums=(2,)) @@ -467,10 +487,8 @@ def calculate_reserves_with_dynamic_inputs( additional_oracle_input: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: s = self._init_pool_state(params, run_fingerprint, prices, start_index) - bout_length = run_fingerprint["bout_length"] - max_len = bout_length - 1 - if run_fingerprint["arb_frequency"] != 1: - max_len = max_len // run_fingerprint["arb_frequency"] + ste_temperature = self._resolve_ste_temperature(run_fingerprint) + max_len = s.arb_prices.shape[0] materialized_inputs = materialize_dynamic_inputs( dynamic_inputs, run_fingerprint.get("dynamic_input_flags"), @@ -479,22 +497,13 @@ def calculate_reserves_with_dynamic_inputs( do_trades=False, dtype=s.arb_prices.dtype, ) - - noise_model = run_fingerprint.get("noise_model", "ratio") - noise_params = run_fingerprint.get("reclamm_noise_params", None) - if noise_params is not None and type(noise_params) is not dict: - noise_params = dict(noise_params) - - if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): - volatility_array = self.calculate_volatility_array( - prices, run_fingerprint, - ) - arb_vol = _prepare_dynamic_array( - volatility_array, start_index, bout_length, - run_fingerprint["arb_frequency"], max_len, - ) - else: - arb_vol = None + _, noise_model, noise_params, arb_vol = self._resolve_noise_inputs( + run_fingerprint, + prices, + start_index, + s.arb_prices.shape[0], + lp_supply_array=None, + ) return _jax_calc_reclamm_reserves_with_dynamic_inputs( s.initial_reserves, s.Va, s.Vb, @@ -512,6 +521,7 @@ def calculate_reserves_with_dynamic_inputs( arc_length_speed=s.arc_length_speed, centeredness_scaling=s.centeredness_scaling, protocol_fee_split=run_fingerprint.get("protocol_fee_split", 0.0), + ste_temperature=ste_temperature, noise_trader_ratio=run_fingerprint.get("noise_trader_ratio", 0.0), lp_supply_array=materialized_inputs.lp_supply, noise_model=noise_model, diff --git a/quantammsim/pools/reCLAMM/reclamm_reserves.py b/quantammsim/pools/reCLAMM/reclamm_reserves.py index 2a46ad3..37082ff 100644 --- a/quantammsim/pools/reCLAMM/reclamm_reserves.py +++ b/quantammsim/pools/reCLAMM/reclamm_reserves.py @@ -17,7 +17,8 @@ import jax.numpy as jnp from jax import jit -from jax.lax import scan, cond +from jax.lax import scan, cond, stop_gradient +from jax.nn import sigmoid from jax.tree_util import Partial from functools import partial @@ -54,6 +55,35 @@ # Pure math functions # --------------------------------------------------------------------------- +def _ste_gate(hard_bool, soft_value): + """Hard forward / soft backward gate.""" + hard_value = hard_bool.astype(soft_value.dtype) + return soft_value + stop_gradient(hard_value - soft_value) + + +def _ste_greater_than(x, threshold, temperature=10.0): + hard = x > threshold + soft = sigmoid(temperature * (x - threshold)) + return _ste_gate(hard, soft) + + +def _ste_less_than(x, threshold, temperature=10.0): + hard = x < threshold + soft = sigmoid(temperature * (threshold - x)) + return _ste_gate(hard, soft) + + +def _ste_greater_equal(x, threshold, temperature=10.0): + hard = x >= threshold + soft = sigmoid(temperature * (x - threshold)) + return _ste_gate(hard, soft) + + +def _ste_select(mask, when_true, when_false): + """Select between two values using a 0/1 gate that can carry STE gradients.""" + return mask * when_true + (1.0 - mask) * when_false + + def compute_invariant(Ra, Rb, Va, Vb): """Compute constant-product invariant L = (Ra + Va) * (Rb + Vb).""" return (Ra + Va) * (Rb + Vb) @@ -610,6 +640,7 @@ def _reclamm_scan_step_zero_fees( seconds_per_step, arc_length_speed=0.0, centeredness_scaling=False, + ste_temperature=10.0, ): """Single scan step for zero-fee reClAMM pool. @@ -641,7 +672,6 @@ def _reclamm_scan_step_zero_fees( # Step 1: Update virtual balances if out of range centeredness, is_above = compute_centeredness(Ra, Rb, Va, Vb) sqrt_Q = jnp.sqrt(compute_price_ratio(Ra, Rb, Va, Vb)) - out_of_range = centeredness < centeredness_margin market_price = prices[0] / prices[1] # Centeredness-proportional scaling: margin/centeredness multiplier @@ -672,8 +702,11 @@ def _reclamm_scan_step_zero_fees( Va_updated = jnp.where(use_cal, Va_cal, Va_geo) Vb_updated = jnp.where(use_cal, Vb_cal, Vb_geo) - Va = jnp.where(out_of_range, Va_updated, Va) - Vb = jnp.where(out_of_range, Vb_updated, Vb) + out_of_range_gate = _ste_less_than( + centeredness, centeredness_margin, ste_temperature + ) + Va = _ste_select(out_of_range_gate, Va_updated, Va) + Vb = _ste_select(out_of_range_gate, Vb_updated, Vb) # Step 2: Analytical zero-fee arb on effective reserves L = compute_invariant(Ra, Rb, Va, Vb) @@ -726,12 +759,14 @@ def _reclamm_scan_step_zero_fees_full_state( seconds_per_step, arc_length_speed=0.0, centeredness_scaling=False, + ste_temperature=10.0, ): """TEST-ONLY: scan step that outputs (reserves, Va, Vb).""" new_carry, new_reserves = _reclamm_scan_step_zero_fees( carry_list, input_list, centeredness_margin, daily_price_shift_base, seconds_per_step, arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, + ste_temperature=ste_temperature, ) return new_carry, (new_reserves, new_carry[1], new_carry[2]) @@ -749,6 +784,7 @@ def _reclamm_scan_step_with_fees_and_revenue( arc_length_speed=0.0, centeredness_scaling=False, protocol_fee_split=0.0, + ste_temperature=10.0, noise_trader_ratio=0.0, noise_model="ratio", noise_params=None, @@ -757,11 +793,12 @@ def _reclamm_scan_step_with_fees_and_revenue( Primary implementation — ``_reclamm_scan_step_with_fees`` wraps this. - Carry: [real_reserves (2,), Va, Vb, prev_lp_supply, step_idx, active_start_ratio, - active_target_ratio, active_start_step, active_end_step, active_enabled] + Carry: [real_reserves (2,), Va, Vb, step_idx, active_start_ratio, + active_target_ratio, active_start_step, active_end_step, active_enabled, + prev_lp_supply] Input: [prices, active_initial_weights, per_asset_ratios, all_other_assets_ratios, gamma, arb_thresh, arb_fees, price_ratio_update, - lp_supply, (optional) volatility] + lp_supply] Returns ------- @@ -772,13 +809,13 @@ def _reclamm_scan_step_with_fees_and_revenue( prev_reserves = carry_list[0] Va = carry_list[1] Vb = carry_list[2] - prev_lp_supply = carry_list[3] - step_idx = carry_list[4] - active_start_ratio = carry_list[5] - active_target_ratio = carry_list[6] - active_start_step = carry_list[7] - active_end_step = carry_list[8] - active_enabled = carry_list[9] + step_idx = carry_list[3] + active_start_ratio = carry_list[4] + active_target_ratio = carry_list[5] + active_start_step = carry_list[6] + active_end_step = carry_list[7] + active_enabled = carry_list[8] + prev_lp_supply = carry_list[9] prices = input_list[0] active_initial_weights = input_list[1] @@ -790,9 +827,8 @@ def _reclamm_scan_step_with_fees_and_revenue( price_ratio_update = input_list[7] lp_supply = input_list[8] - # Scale both real and virtual reserves by LP supply ratio. - # Matches ReClammPool.sol onBeforeAddLiquidity / onBeforeRemoveLiquidity: - # all balances (real + virtual) scale proportionally with BPT supply. + # Scale both real and virtual reserves by LP supply ratio so liquidity + # add/remove events preserve proportional pool state. scale = lp_supply / prev_lp_supply lp_supply_change = lp_supply != prev_lp_supply prev_reserves = jnp.where(lp_supply_change, prev_reserves * scale, prev_reserves) @@ -802,7 +838,6 @@ def _reclamm_scan_step_with_fees_and_revenue( Ra = prev_reserves[0] Rb = prev_reserves[1] - # Price-ratio schedule: apply target price ratio changes over time. event_has = price_ratio_update[0] > 0.5 event_target_ratio = jnp.maximum( jnp.where(jnp.isfinite(price_ratio_update[1]), price_ratio_update[1], 1.0), @@ -898,7 +933,6 @@ def _skip_schedule_state(_): # Step 1: Update virtual balances if out of range centeredness, is_above = compute_centeredness(Ra, Rb, Va, Vb) sqrt_Q = jnp.sqrt(compute_price_ratio(Ra, Rb, Va, Vb)) - out_of_range = centeredness < centeredness_margin market_price = prices[0] / prices[1] # Centeredness-proportional scaling: margin/centeredness multiplier @@ -928,14 +962,15 @@ def _skip_schedule_state(_): Va_updated = jnp.where(use_cal, Va_cal, Va_geo) Vb_updated = jnp.where(use_cal, Vb_cal, Vb_geo) - Va = jnp.where(out_of_range, Va_updated, Va) - Vb = jnp.where(out_of_range, Vb_updated, Vb) + out_of_range_gate = _ste_less_than( + centeredness, centeredness_margin, ste_temperature + ) + Va = _ste_select(out_of_range_gate, Va_updated, Va) + Vb = _ste_select(out_of_range_gate, Vb_updated, Vb) # Step 2: Compute arb trade using G3M machinery on effective reserves effective_reserves = jnp.array([Ra + Va, Rb + Vb]) - fees_are_being_charged = gamma != 1.0 - # Zero-fee analytical arb L = compute_invariant(Ra, Rb, Va, Vb) market_price = prices[0] / prices[1] @@ -958,21 +993,24 @@ def _skip_schedule_state(_): 0, ) + fees_are_being_charged = gamma != 1.0 optimal_arb_trade = jnp.where(fees_are_being_charged, fee_trade, zero_fee_trade) # Check profitability for arb profit_to_arb = -(optimal_arb_trade * prices).sum() - arb_thresh arb_external_cost = 0.5 * arb_fees * (jnp.abs(optimal_arb_trade) * prices).sum() - do_trade = profit_to_arb >= arb_external_cost # Apply trade to REAL reserves only - applied_trade = jnp.where(do_trade, optimal_arb_trade, 0.0) + trade_gate = _ste_greater_equal( + profit_to_arb, arb_external_cost, ste_temperature + ) + applied_trade = _ste_select( + trade_gate, optimal_arb_trade, jnp.zeros_like(optimal_arb_trade) + ) Ra_new = Ra + applied_trade[0] Rb_new = Rb + applied_trade[1] - # --- Noise model dispatch --- - # noise_model is a concrete Python string (passed via Partial as static - # aux_data), so if/elif branches resolve at trace time. + # Optional noise-trader model. if noise_model == "ratio": noisy_reserves = calculate_reserves_after_noise_trade( applied_trade, jnp.array([Ra_new, Rb_new]), prices, @@ -989,25 +1027,21 @@ def _skip_schedule_state(_): _np = noise_params if noise_params is not None else {} if noise_model == "tsoukalas_sqrt": noise_vol = reclamm_tsoukalas_sqrt_noise_volume( - effective_value, gamma, volatility, - arb_volume, _np, + effective_value, gamma, volatility, arb_volume, _np ) elif noise_model == "tsoukalas_log": noise_vol = reclamm_tsoukalas_log_noise_volume( - effective_value, gamma, volatility, - arb_volume, _np, + effective_value, gamma, volatility, arb_volume, _np ) - else: # loglinear + else: noise_vol = reclamm_loglinear_noise_volume( - effective_value, gamma, volatility, - arb_volume, _np, + effective_value, gamma, volatility, arb_volume, _np ) noise_fee_income = (1.0 - gamma) * noise_vol - scale = 1.0 + noise_fee_income / jnp.maximum(real_value, 1e-8) - Ra_new = Ra_new * scale - Rb_new = Rb_new * scale - # else: "arb_only" — no noise trades + noise_scale = 1.0 + noise_fee_income / jnp.maximum(real_value, 1e-8) + Ra_new = Ra_new * noise_scale + Rb_new = Rb_new * noise_scale # Clamp-to-edge: if a real reserve would go negative, apply an # exact-in-given-out edge trade that drains that token to _DUST_USD @@ -1051,13 +1085,13 @@ def _skip_schedule_state(_): new_reserves, Va, Vb, - lp_supply, step_idx + 1.0, active_start_ratio, active_target_ratio, active_start_step, active_end_step, active_enabled, + lp_supply, ], (new_reserves, lp_fee_revenue_usd) @@ -1074,6 +1108,7 @@ def _reclamm_scan_step_with_fees( arc_length_speed=0.0, centeredness_scaling=False, protocol_fee_split=0.0, + ste_temperature=10.0, noise_trader_ratio=0.0, noise_model="ratio", noise_params=None, @@ -1095,6 +1130,7 @@ def _reclamm_scan_step_with_fees( arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, + ste_temperature=ste_temperature, noise_trader_ratio=noise_trader_ratio, noise_model=noise_model, noise_params=noise_params, @@ -1115,6 +1151,10 @@ def _reclamm_scan_step_with_fees_full_state( arc_length_speed=0.0, centeredness_scaling=False, protocol_fee_split=0.0, + ste_temperature=10.0, + noise_trader_ratio=0.0, + noise_model="ratio", + noise_params=None, ): """TEST-ONLY: fee scan step that also outputs virtual balances.""" new_carry, (new_reserves, _fee_rev) = _reclamm_scan_step_with_fees_and_revenue( @@ -1129,6 +1169,10 @@ def _reclamm_scan_step_with_fees_full_state( arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, + ste_temperature=ste_temperature, + noise_trader_ratio=noise_trader_ratio, + noise_model=noise_model, + noise_params=noise_params, ) return new_carry, (new_reserves, new_carry[1], new_carry[2]) @@ -1144,6 +1188,7 @@ def _jax_calc_reclamm_reserves_zero_fees( seconds_per_step, arc_length_speed=0.0, centeredness_scaling=False, + ste_temperature=10.0, lp_supply_array=None, ): """Calculate reClAMM reserves over time with zero fees. @@ -1189,6 +1234,7 @@ def _jax_calc_reclamm_reserves_zero_fees( seconds_per_step=seconds_per_step, arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, + ste_temperature=ste_temperature, ) carry_init = [initial_reserves, initial_Va, initial_Vb, lp_supply_array[0]] @@ -1207,6 +1253,7 @@ def _jax_calc_reclamm_reserves_zero_fees_full_state( seconds_per_step, arc_length_speed=0.0, centeredness_scaling=False, + ste_temperature=10.0, lp_supply_array=None, ): """TEST-ONLY: Like _jax_calc_reclamm_reserves_zero_fees but returns Va/Vb. @@ -1232,14 +1279,17 @@ def _jax_calc_reclamm_reserves_zero_fees_full_state( seconds_per_step=seconds_per_step, arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, + ste_temperature=ste_temperature, ) carry_init = [initial_reserves, initial_Va, initial_Vb, lp_supply_array[0]] - _, (reserves, Va_history, Vb_history) = scan(scan_fn, carry_init, [prices, lp_supply_array]) + _, (reserves, Va_history, Vb_history) = scan( + scan_fn, carry_init, [prices, lp_supply_array] + ) return reserves, Va_history, Vb_history -@partial(jit, static_argnames=('noise_model',)) +@partial(jit, static_argnames=("noise_model",)) def _jax_calc_reclamm_reserves_with_fees( initial_reserves, initial_Va, @@ -1255,6 +1305,7 @@ def _jax_calc_reclamm_reserves_with_fees( arc_length_speed=0.0, centeredness_scaling=False, protocol_fee_split=0.0, + ste_temperature=10.0, noise_trader_ratio=0.0, lp_supply_array=None, noise_model="ratio", @@ -1308,34 +1359,43 @@ def _jax_calc_reclamm_reserves_with_fees( arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, + ste_temperature=ste_temperature, noise_trader_ratio=noise_trader_ratio, noise_model=noise_model, noise_params=noise_params if noise_params is not None else {}, ) - scan_inputs = [prices, active_initial_weights, per_asset_ratios, - all_other_assets_ratios, gamma_array, arb_thresh_array, arb_fees_array, - price_ratio_updates, lp_supply_array] - if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): - scan_inputs.append(volatility_array) - carry_init = [ initial_reserves, initial_Va, initial_Vb, - lp_supply_array[0], jnp.float64(0.0), # step_idx jnp.float64(0.0), # active_start_ratio jnp.float64(0.0), # active_target_ratio jnp.float64(0.0), # active_start_step jnp.float64(0.0), # active_end_step jnp.array(False), # active_enabled + lp_supply_array[0], # prev_lp_supply + ] + scan_inputs = [ + prices, + active_initial_weights, + per_asset_ratios, + all_other_assets_ratios, + gamma_array, + arb_thresh_array, + arb_fees_array, + price_ratio_updates, + lp_supply_array, ] + if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + scan_inputs.append(volatility_array) + _, reserves = scan(scan_fn, carry_init, scan_inputs) return reserves -@partial(jit, static_argnums=(11,), static_argnames=('noise_model',)) +@partial(jit, static_argnums=(11,), static_argnames=("noise_model",)) def _jax_calc_reclamm_reserves_with_dynamic_inputs( initial_reserves, initial_Va, @@ -1354,6 +1414,7 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs( arc_length_speed=0.0, centeredness_scaling=False, protocol_fee_split=0.0, + ste_temperature=10.0, noise_trader_ratio=0.0, lp_supply_array=None, noise_model="ratio", @@ -1416,34 +1477,43 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs( arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, + ste_temperature=ste_temperature, noise_trader_ratio=noise_trader_ratio, noise_model=noise_model, noise_params=noise_params if noise_params is not None else {}, ) - scan_inputs = [prices, active_initial_weights, per_asset_ratios, - all_other_assets_ratios, gamma, arb_thresh, arb_fees, - price_ratio_updates, lp_supply_array] - if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): - scan_inputs.append(volatility_array) - carry_init = [ initial_reserves, initial_Va, initial_Vb, - lp_supply_array[0], jnp.float64(0.0), # step_idx jnp.float64(0.0), # active_start_ratio jnp.float64(0.0), # active_target_ratio jnp.float64(0.0), # active_start_step jnp.float64(0.0), # active_end_step jnp.array(False), # active_enabled + lp_supply_array[0], # prev_lp_supply ] + scan_inputs = [ + prices, + active_initial_weights, + per_asset_ratios, + all_other_assets_ratios, + gamma, + arb_thresh, + arb_fees, + price_ratio_updates, + lp_supply_array, + ] + if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + scan_inputs.append(volatility_array) + _, reserves = scan(scan_fn, carry_init, scan_inputs) return reserves -@partial(jit, static_argnums=(11,)) +@partial(jit, static_argnums=(11,), static_argnames=("noise_model",)) def _jax_calc_reclamm_reserves_with_dynamic_inputs_full_state( initial_reserves, initial_Va, @@ -1462,8 +1532,22 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs_full_state( arc_length_speed=0.0, centeredness_scaling=False, protocol_fee_split=0.0, + ste_temperature=10.0, + noise_trader_ratio=0.0, + lp_supply_array=None, + noise_model="ratio", + noise_params=None, + volatility_array=None, ): """TEST-ONLY: dynamic-input reserve path returning virtual-balance history.""" + if lp_supply_array is None: + lp_supply_array = jnp.array(1.0) + lp_supply_array = jnp.where( + lp_supply_array.size == 1, + jnp.full(prices.shape[0], lp_supply_array), + lp_supply_array, + ) + n_assets = 2 weights = jnp.array([0.5, 0.5]) @@ -1487,8 +1571,6 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs_full_state( price_ratio_updates, (prices.shape[0], price_ratio_updates.shape[1]) ) - lp_supply_array = jnp.ones(prices.shape[0], dtype=prices.dtype) - _, active_trade_directions, tokens_to_drop, leave_one_out_idxs = ( precalc_shared_values_for_all_signatures(all_sig_variations, n_assets) ) @@ -1512,31 +1594,43 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs_full_state( arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, + ste_temperature=ste_temperature, + noise_trader_ratio=noise_trader_ratio, + noise_model=noise_model, + noise_params=noise_params if noise_params is not None else {}, ) carry_init = [ initial_reserves, initial_Va, initial_Vb, - lp_supply_array[0], jnp.float64(0.0), # step_idx jnp.float64(0.0), # active_start_ratio jnp.float64(0.0), # active_target_ratio jnp.float64(0.0), # active_start_step jnp.float64(0.0), # active_end_step jnp.array(False), # active_enabled + lp_supply_array[0], # prev_lp_supply ] - _, (reserves, Va_history, Vb_history) = scan( - scan_fn, - carry_init, - [prices, active_initial_weights, per_asset_ratios, - all_other_assets_ratios, gamma, arb_thresh, arb_fees, - price_ratio_updates, lp_supply_array], - ) + scan_inputs = [ + prices, + active_initial_weights, + per_asset_ratios, + all_other_assets_ratios, + gamma, + arb_thresh, + arb_fees, + price_ratio_updates, + lp_supply_array, + ] + if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + scan_inputs.append(volatility_array) + + _, (reserves, Va_history, Vb_history) = scan(scan_fn, carry_init, scan_inputs) return reserves, Va_history, Vb_history -@partial(jit, static_argnames=('noise_model',)) +@partial(jit, static_argnames=("noise_model",)) def _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( initial_reserves, initial_Va, @@ -1552,6 +1646,7 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( arc_length_speed=0.0, centeredness_scaling=False, protocol_fee_split=0.0, + ste_temperature=10.0, noise_trader_ratio=0.0, lp_supply_array=None, noise_model="ratio", @@ -1607,34 +1702,43 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, + ste_temperature=ste_temperature, noise_trader_ratio=noise_trader_ratio, noise_model=noise_model, noise_params=noise_params if noise_params is not None else {}, ) - scan_inputs = [prices, active_initial_weights, per_asset_ratios, - all_other_assets_ratios, gamma_array, arb_thresh_array, arb_fees_array, - price_ratio_updates, lp_supply_array] - if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): - scan_inputs.append(volatility_array) - carry_init = [ initial_reserves, initial_Va, initial_Vb, - lp_supply_array[0], jnp.float64(0.0), # step_idx jnp.float64(0.0), # active_start_ratio jnp.float64(0.0), # active_target_ratio jnp.float64(0.0), # active_start_step jnp.float64(0.0), # active_end_step jnp.array(False), # active_enabled + lp_supply_array[0], # prev_lp_supply + ] + scan_inputs = [ + prices, + active_initial_weights, + per_asset_ratios, + all_other_assets_ratios, + gamma_array, + arb_thresh_array, + arb_fees_array, + price_ratio_updates, + lp_supply_array, ] + if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + scan_inputs.append(volatility_array) + _, (reserves, fee_revenue) = scan(scan_fn, carry_init, scan_inputs) return reserves, fee_revenue -@partial(jit, static_argnums=(11,), static_argnames=('noise_model',)) +@partial(jit, static_argnums=(11,), static_argnames=("noise_model",)) def _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( initial_reserves, initial_Va, @@ -1653,6 +1757,7 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( arc_length_speed=0.0, centeredness_scaling=False, protocol_fee_split=0.0, + ste_temperature=10.0, noise_trader_ratio=0.0, lp_supply_array=None, noise_model="ratio", @@ -1721,28 +1826,37 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, + ste_temperature=ste_temperature, noise_trader_ratio=noise_trader_ratio, noise_model=noise_model, noise_params=noise_params if noise_params is not None else {}, ) - scan_inputs = [prices, active_initial_weights, per_asset_ratios, - all_other_assets_ratios, gamma, arb_thresh, arb_fees, - price_ratio_updates, lp_supply_array] - if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): - scan_inputs.append(volatility_array) - carry_init = [ initial_reserves, initial_Va, initial_Vb, - lp_supply_array[0], jnp.float64(0.0), # step_idx jnp.float64(0.0), # active_start_ratio jnp.float64(0.0), # active_target_ratio jnp.float64(0.0), # active_start_step jnp.float64(0.0), # active_end_step jnp.array(False), # active_enabled + lp_supply_array[0], # prev_lp_supply ] + scan_inputs = [ + prices, + active_initial_weights, + per_asset_ratios, + all_other_assets_ratios, + gamma, + arb_thresh, + arb_fees, + price_ratio_updates, + lp_supply_array, + ] + if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + scan_inputs.append(volatility_array) + _, (reserves, fee_revenue) = scan(scan_fn, carry_init, scan_inputs) return reserves, fee_revenue diff --git a/quantammsim/runners/default_run_fingerprint.py b/quantammsim/runners/default_run_fingerprint.py index 45b9094..9b2a068 100644 --- a/quantammsim/runners/default_run_fingerprint.py +++ b/quantammsim/runners/default_run_fingerprint.py @@ -98,6 +98,7 @@ "reclamm_interpolation_method": "geometric", # "geometric" or "constant_arc_length" "reclamm_arc_length_speed": None, # auto-calibrate from geometric onset if None "reclamm_centeredness_scaling": False, # scale speed by margin/centeredness + "ste_temperature": 10.0, # STE gate sharpness; higher is closer to hard threshold "reclamm_learn_arc_length_speed": False, # include arc_length_speed in trainable params "reclamm_use_shift_exponent": False, # parametrise shift rate as shift_exponent (log-friendly) "reclamm_learn_fees": False, # include fees in trainable params (Optuna search over fee level) diff --git a/scripts/demo_run_chunks_from_chain_data.py b/scripts/demo_run_chunks_from_chain_data.py index 09c19d1..fbd9cf9 100644 --- a/scripts/demo_run_chunks_from_chain_data.py +++ b/scripts/demo_run_chunks_from_chain_data.py @@ -28,7 +28,6 @@ import numpy as np import pandas as pd import matplotlib as mpl -from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames import matplotlib.pyplot as plt import jax.numpy as jnp @@ -475,12 +474,10 @@ def _df_meta_and_head(df, name, n=3): run_fingerprint=fingerprint, coarse_weights=cw_window, params=params, - dynamic_input_frames=DynamicInputFrames( - fees=scraped["fees_df"], - gas_cost=scraped["gas_cost_df"], - lp_supply=scraped["lp_supply_df"], - arb_fees=scraped["arb_fees_df"], - ), + fees_df=scraped["fees_df"], + gas_cost_df=scraped["gas_cost_df"], + lp_supply_df=scraped["lp_supply_df"], + arb_fees_df=scraped["arb_fees_df"], ) # ---------------- Correct, window-aligned plotting block (time-aware + plain y) ---------------- diff --git a/scripts/demo_run_from_chain_data.py b/scripts/demo_run_from_chain_data.py index b1f4139..78ecbc4 100644 --- a/scripts/demo_run_from_chain_data.py +++ b/scripts/demo_run_from_chain_data.py @@ -1,5 +1,4 @@ import jax.numpy as jnp -from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames from quantammsim.core_simulator.param_utils import ( memory_days_to_logit_lamb, ) @@ -998,12 +997,10 @@ def generate_daily_variations(start_date_str, end_date_str): run_fingerprint=config["fingerprint"], coarse_weights=config["coarse_weights"], params=config["params"], - dynamic_input_frames=DynamicInputFrames( - fees=config["fees_df"], - gas_cost=config["gas_cost_df"], - lp_supply=config["lp_supply_df"], - arb_fees=config["arb_fees_df"], - ), + fees_df=config["fees_df"], + gas_cost_df=config["gas_cost_df"], + lp_supply_df=config["lp_supply_df"], + arb_fees_df=config["arb_fees_df"], ) print("-" * 80) print(f"Pool Type: {config['fingerprint']['rule']}") @@ -1194,4 +1191,4 @@ def generate_daily_variations(start_date_str, end_date_str): # actual_reserves_np=local_reserves, # actual_unix_values=datetime_array, # ) - # raise Exception("Stop here") + # raise Exception("Stop here") \ No newline at end of file diff --git a/scripts/sim_vs_world_comparison.py b/scripts/sim_vs_world_comparison.py new file mode 100644 index 0000000..ca3046f --- /dev/null +++ b/scripts/sim_vs_world_comparison.py @@ -0,0 +1,972 @@ +#!/usr/bin/env python3 +"""Compare quantammsim reClAMM / Balancer vs reclamm-simulations repo + on-chain. + +Runs: + 1. Zero-fee Balancer pool (quantammsim) — the normalization baseline + 2. reClAMM pool with on-chain params (quantammsim) + 3. Loads reclamm-simulations results + world values from CSV + 4. Gas-experiment runs: time-varying gas from on-chain percentiles, + 50% protocol fee take, on-chain fees + +All comparisons align quantammsim's minute-level output to the world state +CSV's actual Unix timestamps, eliminating timing drift from block-time +variability. + +4-panel plot matching the reclamm-simulations format: + Top-left: Price (WETH/AAVE) — both repos overlaid + Top-right: (legend) + Bottom-left: Absolute value in WETH + Bottom-right: Value relative to feeless weighted (Balancer = 1.0) + +Usage: + python scripts/sim_vs_world_comparison.py + python scripts/sim_vs_world_comparison.py --csv /path/to/csv + python scripts/sim_vs_world_comparison.py --gas-experiment +""" + +import argparse +import numpy as np +import pandas as pd +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import jax.numpy as jnp +from pathlib import Path +from datetime import datetime, timezone + +from quantammsim.runners.jax_runners import do_run_on_historic_data + +# ── On-chain reClAMM params ─────────────────────────────────────────────────── +ONCHAIN_FEES = 0.0025 + +ONCHAIN_LAUNCH_PARAMS = { # deployment through 2025-12-18 + "price_ratio": 1.5014, + "centeredness_margin": 0.5, + "shift_exponent": 0.1, +} +ONCHAIN_CURRENT_PARAMS = { # post 2025-12-18 governance + "price_ratio": 4.0, + "centeredness_margin": 0.1, + "shift_exponent": 0.001, +} +GOVERNANCE_DATE = "2025-12-18" + +# CSV starts at ~17.2 WETH ≈ $50k at $2900/ETH. +INITIAL_POOL_VALUE = 50_000.0 + +# Gas cost = arb profit threshold in USD. +# reclamm-simulations uses profit_threshold = 3e-4 WETH (in token1 units). +# quantammsim's arb_thresh is in USD: 3 * 3e-4 WETH × ~$3000/ETH ≈ $2.70. +ARB_GAS_COST = 2.7 + +DEFAULT_CSV = ( + "/Users/matthew/Projects/reclamm-simulations" + "/data/sim_vs_world_values_AAVE_WETH.csv" +) +ZEROFEE_CSV = ( + "/Users/matthew/Projects/reclamm-simulations" + "/data/sim_vs_world_zerofee_centered_AAVE_WETH.csv" +) +ZEROFEE_MINUTE_CSV = ( + "/Users/matthew/Projects/reclamm-simulations" + "/data/sim_vs_world_zerofee_centered_minute_AAVE_WETH.csv" +) +WORLD_STATE_CSV = ( + "/Users/matthew/Projects/reclamm-simulations" + "/data/sim_vs_world_world_AAVE_WETH.csv" +) +DEFAULT_START = "2025-08-16 00:00:00" +DEFAULT_END = "2026-01-04 00:00:00" +DEFAULT_TOKENS = ["AAVE", "ETH"] +HALF_DAY = 720 # minutes + +# Gas experiment +GAS_CSV_DIR = Path(__file__).resolve().parent.parent / "gas_csvs" +GAS_PERCENTILES = ["50p", "75p", "90p", "95p"] +GAS_SCALE_FACTORS = [0.25, 0.5, 0.75, 1.0] +FLAT_GAS_USD = [0.0, 0.25, 0.50, 1.0, 2.0, 3.0, 5.0] +PROTOCOL_FEE_SPLIT = 0.5 + + +def parse_args(): + p = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + p.add_argument("--csv", default=DEFAULT_CSV) + p.add_argument("--start", default=DEFAULT_START) + p.add_argument("--end", default=DEFAULT_END) + p.add_argument("--tokens", nargs="+", default=DEFAULT_TOKENS) + p.add_argument("--output", default="sim_vs_world_comparison.png") + p.add_argument( + "--gas-experiment", action="store_true", + help="Run gas-experiment sweep (time-varying gas, 50%% protocol fee)", + ) + p.add_argument( + "--launch-params", action="store_true", + help="Use launch params instead of current params in gas experiment", + ) + p.add_argument( + "--gas-scale-sweep", action="store_true", + help="Sweep gas cost scale factors, rebase to world, truncate at governance", + ) + p.add_argument( + "--best-gas", action="store_true", + help="Run the 3 best gas configs vs world (clean plot)", + ) + return p.parse_args() + + +def load_onchain_initial_state(): + """Load the on-chain pool state at t=0 from the world state CSV. + + Returns (state_dict, start_time_str) where state_dict has + Ra, Rb, Va, Vb (token units) and start_time_str is rounded + to the nearest minute for alignment with minute-level price data. + """ + df = pd.read_csv(WORLD_STATE_CSV) + r = df.iloc[0] + state = { + "Ra": float(r.balance_0), + "Rb": float(r.balance_1), + "Va": float(r.virtual_0), + "Vb": float(r.virtual_1), + } + # Round to nearest minute for price data alignment + ts_sec = int(r.timestamp) + ts_minute = (ts_sec // 60) * 60 + start_str = datetime.utcfromtimestamp(ts_minute).strftime("%Y-%m-%d %H:%M:%S") + return state, start_str + + +def load_world_timestamps(): + """Load Unix timestamps (seconds) from the world state CSV.""" + df = pd.read_csv(WORLD_STATE_CSV) + return df["timestamp"].values + + +def load_world_normalized_balances(): + """Load BPT-normalized on-chain balances and timestamps. + + Normalizes balances to initial BPT supply so that value tracks a + fixed LP position (accounts for joins/exits changing BPT supply). + + Returns (norm_bal_0, norm_bal_1, timestamps_sec). + """ + df = pd.read_csv(WORLD_STATE_CSV) + bpt_0 = df["bpt_supply"].iloc[0] + norm = bpt_0 / df["bpt_supply"].values + return ( + df["balance_0"].values * norm, + df["balance_1"].values * norm, + df["timestamp"].values, + ) + + +def sample_at_timestamps(minute_vals, start_unix_sec, timestamps_sec): + """Sample a minute-level array at specific Unix timestamps. + + For each target timestamp, finds the nearest minute index in the + sim output and returns the corresponding value. + + Parameters + ---------- + minute_vals : array, shape (N,) + Minute-level sim output. + start_unix_sec : float + Unix timestamp (seconds) of minute_vals[0]. + timestamps_sec : array + Unix timestamps (seconds) to sample at. + + Returns + ------- + array : values at the nearest minute to each target timestamp. + """ + indices = np.round((timestamps_sec - start_unix_sec) / 60).astype(int) + indices = np.clip(indices, 0, len(minute_vals) - 1) + return minute_vals[indices] + + +def run_pool(tokens, start, end, rule, fees, params, gas_cost=0.0, + protocol_fee_split=0.0, gas_cost_df=None, + onchain_initial_state=None): + """Run a quantammsim pool and return minute-level results. + + Returns (val_eth, price_ratio, start_unix_sec) where val_eth and + price_ratio are minute-level arrays and start_unix_sec is the Unix + timestamp (seconds) of the first element. + """ + fp = { + "tokens": tokens, + "rule": rule, + "startDateString": start, + "endDateString": end, + "initial_pool_value": INITIAL_POOL_VALUE, + "fees": fees, + "gas_cost": gas_cost, + "arb_fees": 0.0, + "do_arb": True, + "arb_frequency": 1, + "chunk_period": 1440, + "weight_interpolation_period": 1440, + } + if rule == "reclamm": + fp["reclamm_use_shift_exponent"] = True + fp["reclamm_interpolation_method"] = "geometric" + fp["reclamm_centeredness_scaling"] = False + if protocol_fee_split != 0.0: + fp["protocol_fee_split"] = protocol_fee_split + if onchain_initial_state is not None: + fp["reclamm_initial_state"] = onchain_initial_state + + result = do_run_on_historic_data( + run_fingerprint=fp, params=params, gas_cost_df=gas_cost_df, + ) + + # Prices: sorted tokens → [AAVE, ETH] in USD + prices = np.array(result["prices"]) + eth_usd = prices[:, 1] + price_ratio = prices[:, 0] / prices[:, 1] # WETH/AAVE + + # Pool value in ETH + val_eth = np.array(result["value"]) / eth_usd + + # Compute start timestamp from startDateString + start_unix_sec = datetime.strptime( + start, "%Y-%m-%d %H:%M:%S" + ).replace(tzinfo=timezone.utc).timestamp() + + return val_eth, price_ratio, start_unix_sec + + +def load_gas_csv(percentile): + """Load a gas CSV and return a DataFrame with columns [unix, trade_gas_cost_usd]. + + Gas CSV timestamps are offset by ~59s from exact minutes. Round down + to the nearest minute so they align with the simulator's minute-level index. + """ + path = GAS_CSV_DIR / f"Gas_{percentile}.csv" + df = pd.read_csv(path) + df = df.rename(columns={"USD": "trade_gas_cost_usd"}) + df["unix"] = (df["unix"] // 60000) * 60000 # floor to minute boundary + return df + + +def run_gas_experiment(args): + """Run gas-experiment sweep and produce comparison plot.""" + tokens = args.tokens + start, end = args.start, args.end + + # ── Select params ───────────────────────────────────────────────── + if args.launch_params: + param_source = ONCHAIN_LAUNCH_PARAMS + param_label = "launch" + else: + param_source = ONCHAIN_CURRENT_PARAMS + param_label = "current" + pool_params = {k: jnp.array(v) for k, v in param_source.items()} + + # ── Baselines ────────────────────────────────────────────────────── + print("Running Balancer (zero-fee 50/50)...") + bal_params = {"initial_weights_logits": jnp.array([0.0, 0.0])} + bal_eth_min, qsim_price_min, start_sec = run_pool( + tokens, start, end, "balancer", 0.0, bal_params, + ) + + print(f"Running reClAMM ({param_label} params, flat gas, no protocol fee)...") + reclamm_flat_min, _, _ = run_pool( + tokens, start, end, "reclamm", ONCHAIN_FEES, pool_params, + gas_cost=ARB_GAS_COST, + ) + + # ── Load world values from CSV ───────────────────────────────────── + print("Loading reclamm-simulations CSV...") + df = pd.read_csv(args.csv) + + # ── Gas percentile runs ──────────────────────────────────────────── + gas_results_min = {} + for pct in GAS_PERCENTILES: + print(f"Running reClAMM ({param_label} params, gas={pct}, " + f"protocol_fee={PROTOCOL_FEE_SPLIT})...") + gas_df = load_gas_csv(pct) + val_eth_min, _, _ = run_pool( + tokens, start, end, "reclamm", ONCHAIN_FEES, pool_params, + protocol_fee_split=PROTOCOL_FEE_SPLIT, gas_cost_df=gas_df, + ) + gas_results_min[pct] = val_eth_min + + # ── Sample at world timestamps ──────────────────────────────────── + world_ts = load_world_timestamps() + n = min(len(df), len(world_ts)) + world_ts = world_ts[:n] + + bal_eth = sample_at_timestamps(bal_eth_min, start_sec, world_ts) + reclamm_flat_eth = sample_at_timestamps(reclamm_flat_min, start_sec, world_ts) + qsim_price = sample_at_timestamps(qsim_price_min, start_sec, world_ts) + gas_results = { + pct: sample_at_timestamps(v, start_sec, world_ts) + for pct, v in gas_results_min.items() + } + + csv_world = df["world"].values[:n] + csv_feeless = df["feeless weighted"].values[:n] + print(f" Aligned: {n} world-timestamp points") + t = np.arange(n) + + # Governance half-day index + gov_unix = datetime.strptime( + GOVERNANCE_DATE, "%Y-%m-%d" + ).replace(tzinfo=timezone.utc).timestamp() + gov_idx = np.searchsorted(world_ts, gov_unix) + + # ── Plot: relative to feeless weighted ───────────────────────────── + fig, (ax_price, ax_rel) = plt.subplots(2, 1, figsize=(14, 9), + gridspec_kw={"height_ratios": [1, 2]}) + + # Top: price + ax_price.plot(t, qsim_price, color="gray", alpha=0.6, linewidth=1) + ax_price.set_ylabel("AAVE/ETH") + ax_price.set_title("Price") + ax_price.set_ylim(bottom=0) + if gov_idx < n: + ax_price.axvline(x=gov_idx, color="gray", linestyle=":", alpha=0.6) + + # Bottom: relative values + ax_rel.axhline(y=1.0, color="blue", linewidth=2, label="feeless weighted") + + # Flat-gas baseline (no protocol fee) + flat_rel = reclamm_flat_eth / bal_eth + ax_rel.plot(t, flat_rel, linewidth=2, color="gray", linestyle="--", + label=f"flat gas ${ARB_GAS_COST}, no protocol fee") + + # Gas percentile runs + colors = {"50p": "#2ca02c", "75p": "#ff7f0e", "90p": "#d62728", "95p": "#9467bd"} + for pct in GAS_PERCENTILES: + vals = gas_results[pct] + rel = vals / bal_eth + ax_rel.plot(t, rel, linewidth=1.5, color=colors[pct], + label=f"gas {pct}, {int(PROTOCOL_FEE_SPLIT*100)}% protocol fee") + + # World values + world_rel = csv_world / csv_feeless + ax_rel.plot(t, world_rel, linewidth=1.5, marker=".", markersize=2, + color="brown", label="world (on-chain)") + + ax_rel.set_xlabel("half days") + ax_rel.set_ylabel("value / feeless weighted") + ax_rel.set_title("LP value relative to feeless weighted (Balancer 50/50)") + ax_rel.legend(fontsize=8, loc="lower left") + ax_rel.grid(True, alpha=0.2) + if gov_idx < n: + ax_rel.axvline(x=gov_idx, color="gray", linestyle=":", alpha=0.6) + ax_rel.text(gov_idx + 1, ax_rel.get_ylim()[1] * 0.98, + "governance", fontsize=7, color="gray", va="top") + + tokens_str = "/".join(tokens) + fig.suptitle( + f"reClAMM gas experiment ({param_label} params) — {tokens_str}\n" + f"params: {list(param_source.values())}, " + f"fees: {ONCHAIN_FEES}, protocol fee: {PROTOCOL_FEE_SPLIT}", + fontsize=10, + ) + plt.tight_layout() + out = args.output.replace(".png", f"_gas_experiment_{param_label}.png") + plt.savefig(out, dpi=150, bbox_inches="tight") + print(f"\nSaved: {out}") + plt.close() + + # ── Summary table ────────────────────────────────────────────────── + print(f"\n{'Scenario':<45} {'Final rel':>10} {'vs world':>10}") + print("-" * 65) + world_final_rel = world_rel[-1] if len(world_rel) > 0 else float("nan") + print(f"{'Flat gas, no protocol fee':<45} {flat_rel[-1]:>10.4f} " + f"{flat_rel[-1] - world_final_rel:>+10.4f}") + for pct in GAS_PERCENTILES: + rel = gas_results[pct] / bal_eth + print(f"{'Gas ' + pct + f', {int(PROTOCOL_FEE_SPLIT*100)}% protocol fee':<45} " + f"{rel[-1]:>10.4f} {rel[-1] - world_final_rel:>+10.4f}") + print(f"{'World (on-chain)':<45} {world_final_rel:>10.4f}") + + +def run_gas_scale_experiment(args): + """Sweep gas cost scale factors, rebase to world, truncate at governance.""" + tokens = args.tokens + end = args.end + + if args.launch_params: + param_source = ONCHAIN_LAUNCH_PARAMS + param_label = "launch" + else: + param_source = ONCHAIN_CURRENT_PARAMS + param_label = "current" + pool_params = {k: jnp.array(v) for k, v in param_source.items()} + + # Load on-chain initial state and derive start time + onchain_state, onchain_start = load_onchain_initial_state() + start = onchain_start + print(f"On-chain initial state: Ra={onchain_state['Ra']:.2f}, " + f"Rb={onchain_state['Rb']:.2f}, Va={onchain_state['Va']:.2f}, " + f"Vb={onchain_state['Vb']:.2f}") + print(f"Sim start time (from on-chain): {start}") + + # Load world + reclamm-simulations values + print("Loading reclamm-simulations CSV...") + df = pd.read_csv(args.csv) + + # Load world timestamps and find governance cutoff + world_ts = load_world_timestamps() + gov_unix = datetime.strptime( + GOVERNANCE_DATE, "%Y-%m-%d" + ).replace(tzinfo=timezone.utc).timestamp() + gov_idx = np.searchsorted(world_ts, gov_unix) + + # Run all (percentile, scale) combinations + results_min = {} + price_ratio_min = None + for pct in GAS_PERCENTILES: + gas_df_raw = load_gas_csv(pct) + for scale in GAS_SCALE_FACTORS: + label = f"{pct} × {scale}" + print(f"Running reClAMM ({param_label}, gas={label}, " + f"protocol_fee={PROTOCOL_FEE_SPLIT})...") + gas_df = gas_df_raw.copy() + gas_df["trade_gas_cost_usd"] = gas_df_raw["trade_gas_cost_usd"] * scale + val_eth_min, pr_min, start_sec = run_pool( + tokens, start, end, "reclamm", ONCHAIN_FEES, pool_params, + protocol_fee_split=PROTOCOL_FEE_SPLIT, gas_cost_df=gas_df, + onchain_initial_state=onchain_state, + ) + results_min[(pct, scale)] = (val_eth_min, start_sec) + if price_ratio_min is None: + price_ratio_min = pr_min + + # Flat gas cost runs + flat_results_min = {} + for gas_usd in FLAT_GAS_USD: + print(f"Running reClAMM ({param_label}, flat gas=${gas_usd}, " + f"protocol_fee={PROTOCOL_FEE_SPLIT})...") + val_eth_min, _, start_sec = run_pool( + tokens, start, end, "reclamm", ONCHAIN_FEES, pool_params, + gas_cost=gas_usd, protocol_fee_split=PROTOCOL_FEE_SPLIT, + onchain_initial_state=onchain_state, + ) + flat_results_min[gas_usd] = (val_eth_min, start_sec) + + # ── World values: on-chain balances × quantammsim prices ────────── + world_bal_0, world_bal_1, world_ts = load_world_normalized_balances() + + # reclamm-sim comparison uses its own CSV (self-consistent pricing) + csv_world = df["world"].values + csv_sim = df["simulation"].values + + n = min(gov_idx, len(world_bal_0), len(csv_world), len(csv_sim), len(world_ts)) + print(f" Truncated at governance: {n} world-timestamp points") + t = np.arange(n) + world_ts_trunc = world_ts[:n] + + # Repriced world for quantammsim comparison + price_at_world = sample_at_timestamps( + price_ratio_min, start_sec, world_ts_trunc, + ) + world_val = world_bal_0[:n] * price_at_world + world_bal_1[:n] + world_growth = world_val / world_val[0] + + # CSV-based world for reclamm-sim comparison (self-consistent pricing) + csv_world = csv_world[:n] + csv_sim = csv_sim[:n] + world_growth_csv = csv_world / csv_world[0] + recsim_growth = csv_sim / csv_sim[0] + + # Sample all sim runs at world timestamps + start_sec = flat_results_min[FLAT_GAS_USD[0]][1] + + results = {} + for key, (val_min, _) in results_min.items(): + results[key] = sample_at_timestamps(val_min, start_sec, world_ts_trunc) + + flat_results = {} + for gas_usd, (val_min, _) in flat_results_min.items(): + flat_results[gas_usd] = sample_at_timestamps(val_min, start_sec, world_ts_trunc) + + # Compute growth ratios + flat_growths = {} + for gas_usd in FLAT_GAS_USD: + vals = flat_results[gas_usd] + flat_growths[gas_usd] = vals / vals[0] + + # ── Plot (% deviation from world: positive = sim below world) ─── + fig, (ax_ts, ax_pct, ax_flat) = plt.subplots( + 1, 3, figsize=(20, 7), gridspec_kw={"width_ratios": [3, 1, 1]}, + ) + + # Left: time series of % deviation from world + ax_ts.axhline(y=0.0, color="brown", linewidth=2, label="world (on-chain)") + + # reclamm-simulations (uses CSV-based world for self-consistent pricing) + recsim_dev = (1 - recsim_growth / world_growth_csv) * 100 + ax_ts.plot(t, recsim_dev, color="red", linewidth=2, + linestyle="--", label="reclamm-sim") + + # Gas scale sweep (percentile-based) + colors = {"50p": "#2ca02c", "75p": "#ff7f0e", "90p": "#d62728", "95p": "#9467bd"} + for pct in GAS_PERCENTILES: + for scale in GAS_SCALE_FACTORS: + vals = results[(pct, scale)] + sim_growth = vals / vals[0] + dev = (1 - sim_growth / world_growth) * 100 + alpha = 0.3 + 0.7 * scale + lw = 0.8 + 1.2 * scale + if scale == 1.0: + label = f"{pct} × {scale}" + elif pct == "50p": + label = f"50p × {scale}" + else: + label = None + ax_ts.plot(t, dev, color=colors[pct], alpha=alpha, + linewidth=lw, label=label) + + # Flat gas runs + flat_cmap = plt.cm.copper + for i, gas_usd in enumerate(FLAT_GAS_USD): + c = flat_cmap(i / max(len(FLAT_GAS_USD) - 1, 1)) + dev = (1 - flat_growths[gas_usd] / world_growth) * 100 + ax_ts.plot(t, dev, color=c, linewidth=1.5, linestyle="-.", + label=f"flat ${gas_usd}") + + ax_ts.set_xlabel("half days") + ax_ts.set_ylabel("% deviation from world") + ax_ts.set_title("LP value vs world (pre-governance)") + ax_ts.legend(fontsize=6, loc="best", ncol=2) + ax_ts.grid(True, alpha=0.2) + + # Reference lines for both summary panels (as % deviation) + recsim_final_dev = (1 - recsim_growth[-1] / world_growth_csv[-1]) * 100 + + # Middle: final % deviation vs percentile scale factor + ax_pct.axhline(y=0.0, color="brown", linewidth=2, label="world") + ax_pct.axhline(y=recsim_final_dev, color="red", linewidth=1.5, + linestyle="--", label=f"reclamm-sim ({recsim_final_dev:+.2f}%)") + + for pct in GAS_PERCENTILES: + finals = [] + for scale in GAS_SCALE_FACTORS: + vals = results[(pct, scale)] + sim_growth = vals / vals[0] + finals.append((1 - sim_growth[-1] / world_growth[-1]) * 100) + ax_pct.plot(GAS_SCALE_FACTORS, finals, marker="o", + color=colors[pct], linewidth=2, label=pct) + + ax_pct.set_xlabel("gas scale factor\n(1.0 = 450k gas)") + ax_pct.set_ylabel("% deviation from world") + ax_pct.set_title("Percentile gas") + ax_pct.legend(fontsize=6) + ax_pct.grid(True, alpha=0.2) + + # Right: final % deviation vs flat gas cost + ax_flat.axhline(y=0.0, color="brown", linewidth=2, label="world") + ax_flat.axhline(y=recsim_final_dev, color="red", linewidth=1.5, + linestyle="--", label=f"reclamm-sim ({recsim_final_dev:+.2f}%)") + + flat_finals = [] + for gas_usd in FLAT_GAS_USD: + flat_finals.append( + (1 - flat_growths[gas_usd][-1] / world_growth[-1]) * 100 + ) + ax_flat.plot(FLAT_GAS_USD, flat_finals, marker="s", color="black", + linewidth=2, label="flat gas") + + ax_flat.set_xlabel("flat gas cost (USD)") + ax_flat.set_ylabel("% deviation from world") + ax_flat.set_title("Flat gas") + ax_flat.legend(fontsize=6) + ax_flat.grid(True, alpha=0.2) + + tokens_str = "/".join(tokens) + fig.suptitle( + f"reClAMM gas sweep ({param_label} params) — {tokens_str}\n" + f"params: {list(param_source.values())}, " + f"fees: {ONCHAIN_FEES}, protocol fee: {PROTOCOL_FEE_SPLIT}", + fontsize=10, + ) + plt.tight_layout() + out = args.output.replace(".png", f"_gas_scale_{param_label}.png") + plt.savefig(out, dpi=150, bbox_inches="tight") + print(f"\nSaved: {out}") + plt.close() + + # ── Summary table (% deviation from world) ───────────────────────── + print(f"\n{'Scenario':<35} {'% dev from world':>16}") + print("-" * 52) + print(f"{'reclamm-sim':<35} {recsim_final_dev:>+16.2f}%") + print() + for gas_usd in FLAT_GAS_USD: + dev = (1 - flat_growths[gas_usd][-1] / world_growth[-1]) * 100 + print(f"{'Flat $' + f'{gas_usd}':<35} {dev:>+16.2f}%") + print() + for pct in GAS_PERCENTILES: + for scale in GAS_SCALE_FACTORS: + vals = results[(pct, scale)] + sim_growth = vals / vals[0] + dev = (1 - sim_growth[-1] / world_growth[-1]) * 100 + print(f"{'Gas ' + pct + f' × {scale}':<35} {dev:>+16.2f}%") + + +def run_best_gas_experiment(args): + """Run the 3 best gas configs vs world on a clean single-panel plot.""" + tokens = args.tokens + end = args.end + + if args.launch_params: + param_source = ONCHAIN_LAUNCH_PARAMS + param_label = "launch" + else: + param_source = ONCHAIN_CURRENT_PARAMS + param_label = "current" + pool_params = {k: jnp.array(v) for k, v in param_source.items()} + + # Load on-chain initial state and derive start time + onchain_state, onchain_start = load_onchain_initial_state() + start = onchain_start + print(f"On-chain initial state: Ra={onchain_state['Ra']:.2f}, " + f"Rb={onchain_state['Rb']:.2f}, Va={onchain_state['Va']:.2f}, " + f"Vb={onchain_state['Vb']:.2f}") + print(f"Sim start time (from on-chain): {start}") + + # Find governance cutoff from world timestamps + world_ts_all = load_world_timestamps() + gov_unix = datetime.strptime( + GOVERNANCE_DATE, "%Y-%m-%d" + ).replace(tzinfo=timezone.utc).timestamp() + gov_idx = np.searchsorted(world_ts_all, gov_unix) + + # ── The best configs ───────────────────────────────────────────── + configs = [ + ("Flat $1.00", "black", "-"), + ("50p × 1.0", "#2ca02c", "-"), + ("75p × 0.75", "#ff7f0e", "-"), + ("90p × 0.25", "#d62728", "-"), + ] + + # 1) Flat $1.00 + print(f"Running reClAMM ({param_label}, flat gas=$1.00, " + f"protocol_fee={PROTOCOL_FEE_SPLIT})...") + flat1_min, price_ratio_min, start_sec = run_pool( + tokens, start, end, "reclamm", ONCHAIN_FEES, pool_params, + gas_cost=1.0, protocol_fee_split=PROTOCOL_FEE_SPLIT, + onchain_initial_state=onchain_state, + ) + + # 2) 50p × 1.0 + print(f"Running reClAMM ({param_label}, gas=50p × 1.0, " + f"protocol_fee={PROTOCOL_FEE_SPLIT})...") + gas_df_50p = load_gas_csv("50p") + g50_min, _, _ = run_pool( + tokens, start, end, "reclamm", ONCHAIN_FEES, pool_params, + protocol_fee_split=PROTOCOL_FEE_SPLIT, gas_cost_df=gas_df_50p, + onchain_initial_state=onchain_state, + ) + + # 3) 75p × 0.75 + print(f"Running reClAMM ({param_label}, gas=75p × 0.75, " + f"protocol_fee={PROTOCOL_FEE_SPLIT})...") + gas_df_75p = load_gas_csv("75p") + gas_df_75p_scaled = gas_df_75p.copy() + gas_df_75p_scaled["trade_gas_cost_usd"] *= 0.75 + g75_min, _, _ = run_pool( + tokens, start, end, "reclamm", ONCHAIN_FEES, pool_params, + protocol_fee_split=PROTOCOL_FEE_SPLIT, gas_cost_df=gas_df_75p_scaled, + onchain_initial_state=onchain_state, + ) + + # 4) 90p × 0.25 + print(f"Running reClAMM ({param_label}, gas=90p × 0.25, " + f"protocol_fee={PROTOCOL_FEE_SPLIT})...") + gas_df_90p = load_gas_csv("90p") + gas_df_90p_scaled = gas_df_90p.copy() + gas_df_90p_scaled["trade_gas_cost_usd"] *= 0.25 + g90_min, _, _ = run_pool( + tokens, start, end, "reclamm", ONCHAIN_FEES, pool_params, + protocol_fee_split=PROTOCOL_FEE_SPLIT, gas_cost_df=gas_df_90p_scaled, + onchain_initial_state=onchain_state, + ) + + # ── World values: on-chain balances × quantammsim prices ────────── + # Both sim and world valued at the same price at each point, + # so price fluctuations cancel in the growth ratio comparison. + world_bal_0, world_bal_1, world_ts = load_world_normalized_balances() + n = min(gov_idx, len(world_bal_0), len(world_ts)) + print(f" Truncated at governance: {n} world-timestamp points") + t = np.arange(n) + world_ts_trunc = world_ts[:n] + + # Sample quantammsim price ratio at world timestamps + price_at_world = sample_at_timestamps( + price_ratio_min, start_sec, world_ts_trunc, + ) + # World value in ETH = norm_AAVE * (AAVE/ETH) + norm_ETH + world_val = world_bal_0[:n] * price_at_world + world_bal_1[:n] + world_growth = world_val / world_val[0] + + run_vals = [ + sample_at_timestamps(flat1_min, start_sec, world_ts_trunc), + sample_at_timestamps(g50_min, start_sec, world_ts_trunc), + sample_at_timestamps(g75_min, start_sec, world_ts_trunc), + sample_at_timestamps(g90_min, start_sec, world_ts_trunc), + ] + + growths = [v / v[0] for v in run_vals] + + # ── Plot ────────────────────────────────────────────────────────── + fig, ax = plt.subplots(figsize=(14, 6)) + + ax.axhline(y=0.0, color="brown", linewidth=2, label="world (on-chain)") + + # Best 3 + for (label, color, ls), g in zip(configs, growths): + dev = (1 - g / world_growth) * 100 + final_dev = dev[-1] + ax.plot(t, dev, color=color, linewidth=2, linestyle=ls, + label=f"{label} (final {final_dev:+.2f}%)") + + ax.set_xlabel("half days") + ax.set_ylabel("% deviation from world") + ax.set_title( + f"Best gas configs vs world ({param_label} params) — " + f"{'/'.join(tokens)}\n" + f"params: {list(param_source.values())}, " + f"fees: {ONCHAIN_FEES}, protocol fee: {PROTOCOL_FEE_SPLIT}", + ) + ax.legend(fontsize=9, loc="best") + ax.grid(True, alpha=0.3) + + plt.tight_layout() + out = args.output.replace(".png", f"_best_gas_{param_label}.png") + plt.savefig(out, dpi=150, bbox_inches="tight") + print(f"\nSaved: {out}") + plt.close() + + # Summary + labels = [c[0] for c in configs] + print(f"\n{'Scenario':<25} {'% dev from world':>16}") + print("-" * 42) + for label, g in zip(labels, growths): + dev = (1 - g[-1] / world_growth[-1]) * 100 + print(f"{label:<25} {dev:>+16.2f}%") + + +def main(): + args = parse_args() + + if args.best_gas: + run_best_gas_experiment(args) + return + + if args.gas_scale_sweep: + run_gas_scale_experiment(args) + return + + if args.gas_experiment: + run_gas_experiment(args) + return + + # ── Load CSVs ───────────────────────────────────────────────────── + print("Loading reclamm-simulations CSV...") + df = pd.read_csv(args.csv) + n_csv = len(df) + print(f" {n_csv} half-day points") + + print("Loading zero-fee minute-level CSV...") + df_zf_min = pd.read_csv(ZEROFEE_MINUTE_CSV) + print(f" {len(df_zf_min)} minute points") + + # Load world timestamps for alignment + world_ts = load_world_timestamps() + + # ── Run quantammsim pools (minute-level) ────────────────────────── + print("Running Balancer (zero-fee 50/50)...") + bal_params = {"initial_weights_logits": jnp.array([0.0, 0.0])} + bal_eth_min, qsim_price_min, start_sec = run_pool( + args.tokens, args.start, args.end, "balancer", 0.0, bal_params, + ) + + print("Running reClAMM (launch, zero-fee, zero-gas)...") + launch_params = {k: jnp.array(v) for k, v in ONCHAIN_LAUNCH_PARAMS.items()} + reclamm_zerofee_min, _, _ = run_pool( + args.tokens, args.start, args.end, "reclamm", 0.0, launch_params, + gas_cost=0.0, + ) + + print(f"Running reClAMM (launch params, gas=${ARB_GAS_COST})...") + reclamm_launch_min, _, _ = run_pool( + args.tokens, args.start, args.end, "reclamm", ONCHAIN_FEES, launch_params, + gas_cost=ARB_GAS_COST, + ) + + print(f"Running reClAMM (current params, gas=${ARB_GAS_COST})...") + current_params = {k: jnp.array(v) for k, v in ONCHAIN_CURRENT_PARAMS.items()} + reclamm_current_min, _, _ = run_pool( + args.tokens, args.start, args.end, "reclamm", ONCHAIN_FEES, current_params, + gas_cost=ARB_GAS_COST, + ) + + # ── Sample at world timestamps ──────────────────────────────────── + n = min(n_csv, len(world_ts)) + world_ts_trunc = world_ts[:n] + + bal_eth = sample_at_timestamps(bal_eth_min, start_sec, world_ts_trunc) + reclamm_zerofee_eth = sample_at_timestamps(reclamm_zerofee_min, start_sec, world_ts_trunc) + reclamm_launch_eth = sample_at_timestamps(reclamm_launch_min, start_sec, world_ts_trunc) + reclamm_current_eth = sample_at_timestamps(reclamm_current_min, start_sec, world_ts_trunc) + qsim_price = sample_at_timestamps(qsim_price_min, start_sec, world_ts_trunc) + + print(f" Aligned: {n} world-timestamp points " + f"(qsim minutes={len(bal_eth_min)}, csv={n_csv})") + t = np.arange(n) + + csv_price = df["price"].values[:n] + csv_feeless = df["feeless weighted"].values[:n] + csv_sim = df["simulation"].values[:n] + csv_hold = df["hold"].values[:n] + csv_world = df["world"].values[:n] + + # Governance change index + gov_unix = datetime.strptime( + GOVERNANCE_DATE, "%Y-%m-%d" + ).replace(tzinfo=timezone.utc).timestamp() + gov_idx = np.searchsorted(world_ts_trunc, gov_unix) + + # Normalize quantammsim to same starting value as CSV + v0 = csv_feeless[0] + bal_norm = bal_eth * (v0 / bal_eth[0]) + zerofee_norm = reclamm_zerofee_eth * (v0 / reclamm_zerofee_eth[0]) + launch_norm = reclamm_launch_eth * (v0 / reclamm_launch_eth[0]) + current_norm = reclamm_current_eth * (v0 / reclamm_current_eth[0]) + + # Relative values (÷ respective feeless weighted baseline) + zerofee_rel = reclamm_zerofee_eth / bal_eth + launch_rel = reclamm_launch_eth / bal_eth + current_rel = reclamm_current_eth / bal_eth + csv_sim_rel = csv_sim / csv_feeless + csv_hold_rel = csv_hold / csv_feeless + csv_world_rel = csv_world / csv_feeless + + # ── Plot ────────────────────────────────────────────────────────── + fig, axs = plt.subplots(2, 2, figsize=(13, 8)) + + # Top-left: price + axs[0][0].plot(t, csv_price, label="reclamm-sim", alpha=0.8) + axs[0][0].plot(t, qsim_price, label="quantammsim", alpha=0.8, linestyle="--") + axs[0][0].set_ylabel("WETH/AAVE") + axs[0][0].set_title("Price") + axs[0][0].set_ylim(bottom=0) + axs[0][0].legend(fontsize=8) + if gov_idx < n: + axs[0][0].axvline(x=gov_idx, color="gray", linestyle=":", alpha=0.6) + + # Top-right: remove (legend is on other panels) + axs[0][1].remove() + + # Bottom-left: absolute values in WETH + axs[1][0].plot(t, bal_norm, label="qsim feeless weighted", linewidth=2, color="blue") + axs[1][0].plot(t, launch_norm, label="qsim reClAMM (launch)", linewidth=2, color="orange") + axs[1][0].plot(t, current_norm, label="qsim reClAMM (current)", linewidth=2, + color="purple", linestyle="-.") + axs[1][0].plot(t, csv_sim, label="reclamm-sim simulation", linewidth=1.5, + linestyle="--", color="red") + axs[1][0].plot(t, csv_hold, label="hold", linewidth=1.5, color="green") + axs[1][0].plot(t, csv_world, label="world values", linewidth=1.5, + marker=".", markersize=2, color="brown") + axs[1][0].set_title("Value histories") + axs[1][0].set_xlabel("half days") + axs[1][0].set_ylabel("Value in WETH") + axs[1][0].set_ylim(bottom=0) + axs[1][0].legend(fontsize=7, loc="upper right") + if gov_idx < n: + axs[1][0].axvline(x=gov_idx, color="gray", linestyle=":", alpha=0.6) + axs[1][0].text(gov_idx + 1, axs[1][0].get_ylim()[1] * 0.95, + "governance", fontsize=7, color="gray", va="top") + + # Bottom-right: relative to feeless weighted + axs[1][1].axhline(y=1.0, color="blue", linewidth=2, label="feeless weighted") + axs[1][1].plot(t, launch_rel, label="qsim reClAMM (launch)", linewidth=2, color="orange") + axs[1][1].plot(t, current_rel, label="qsim reClAMM (current)", linewidth=2, + color="purple", linestyle="-.") + axs[1][1].plot(t, csv_sim_rel, label="reclamm-sim simulation", linewidth=1.5, + linestyle="--", color="red") + axs[1][1].plot(t, csv_hold_rel, label="hold", linewidth=1.5, color="green") + axs[1][1].plot(t, csv_world_rel, label="world values", linewidth=1.5, + marker=".", markersize=2, color="brown") + axs[1][1].set_title("Value relative to feeless weighted") + axs[1][1].set_xlabel("half days") + axs[1][1].set_ylabel("relative value") + axs[1][1].legend(fontsize=7, loc="lower left") + if gov_idx < n: + axs[1][1].axvline(x=gov_idx, color="gray", linestyle=":", alpha=0.6) + + tokens_str = "/".join(args.tokens) + fig.suptitle( + f"quantammsim vs reclamm-simulations — {tokens_str}\n" + f"Launch: {list(ONCHAIN_LAUNCH_PARAMS.values())}, " + f"Current: {list(ONCHAIN_CURRENT_PARAMS.values())}, " + f"fees: {ONCHAIN_FEES}", + fontsize=10, + ) + plt.tight_layout() + plt.savefig(args.output, dpi=150, bbox_inches="tight") + print(f"\nSaved: {args.output}") + + # ── Zero-fee comparison plot (minute-level) ─────────────────────── + # Revalue reclamm-sim balances at quantammsim's price so both sides + # use the same price and the comparison is purely about balances. + # Skip row 0 of the CSV (initial state before first arb) to align + # with quantammsim's reserves[0] which is post-first-step. + ext_bal_0 = df_zf_min["balance_0"].values[1:] + ext_bal_1 = df_zf_min["balance_1"].values[1:] + n_zf = min(len(reclamm_zerofee_min), len(ext_bal_0), len(qsim_price_min)) + ext_val_repriced = ( + ext_bal_0[:n_zf] * qsim_price_min[:n_zf] + ext_bal_1[:n_zf] + ) + qsim_growth = reclamm_zerofee_min[:n_zf] / reclamm_zerofee_min[0] + ext_growth = ext_val_repriced[:n_zf] / ext_val_repriced[0] + pct_dev = (qsim_growth / ext_growth - 1) * 100 + days = np.arange(n_zf) / 1440 + + zerofee_title = ( + f"Zero-fee zero-gas reClAMM: quantammsim / reclamm-sim (minute-level) — {tokens_str}\n" + f"params: {list(ONCHAIN_LAUNCH_PARAMS.values())}" + ) + daily_smooth = pd.Series(pct_dev).rolling(1440, center=True, min_periods=720).mean() + + # Plot 1: with daily smoothing overlay + fig2, ax2 = plt.subplots(figsize=(12, 5)) + ax2.plot(days, pct_dev, linewidth=0.5, color="teal", alpha=0.6) + ax2.plot(days, daily_smooth, linewidth=2, color="darkblue", label="daily smoothed") + ax2.axhline(y=0.0, color="gray", linestyle="--", alpha=0.6) + ax2.set_xlabel("days") + ax2.set_ylabel("deviation (%)") + ax2.set_title(zerofee_title, fontsize=11) + ax2.legend(fontsize=9) + ax2.grid(True, alpha=0.3) + plt.tight_layout() + zerofee_path = args.output.replace(".png", "_zerofee_ratio.png") + plt.savefig(zerofee_path, dpi=150, bbox_inches="tight") + print(f"Saved: {zerofee_path}") + plt.close() + + # Plot 2: raw minute-level only (no smoothing) + fig3, ax3 = plt.subplots(figsize=(12, 5)) + ax3.plot(days, pct_dev, linewidth=0.5, color="teal", alpha=0.8) + ax3.axhline(y=0.0, color="gray", linestyle="--", alpha=0.6) + ax3.set_xlabel("days") + ax3.set_ylabel("deviation (%)") + ax3.set_title(zerofee_title, fontsize=11) + ax3.grid(True, alpha=0.3) + plt.tight_layout() + zerofee_raw_path = args.output.replace(".png", "_zerofee_ratio_raw.png") + plt.savefig(zerofee_raw_path, dpi=150, bbox_inches="tight") + print(f"Saved: {zerofee_raw_path}") + plt.close() + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index 20eb403..1aa0b94 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ "pyarrow", "plotly", "bidask", - "Historic_Crypto", + "Historic-Crypto", "gdown", "binance_historical_data", "dask", @@ -30,10 +30,12 @@ ], extras_require={ "dev": [ - "pytest>=6.0", + "pytest>=7.0", + "pytest-cov>=4.0", + "pytest-xdist>=3.0", + "pytest-timeout>=2.0", "black", "flake8", - "pytest-cov", "hypothesis", ], "docs": [ @@ -41,6 +43,10 @@ "sphinx-automodapi", "sphinx-rtd-theme", ], + "calibration": [ + "numpyro>=0.15.0", + "arviz>=0.15.0", + ], }, python_requires=">=3.9", ) diff --git a/tests/pools/reCLAMM/test_reclamm_differentiability.py b/tests/pools/reCLAMM/test_reclamm_differentiability.py new file mode 100644 index 0000000..29fdcbe --- /dev/null +++ b/tests/pools/reCLAMM/test_reclamm_differentiability.py @@ -0,0 +1,167 @@ +"""Differentiability tests for reCLAMM STE-gated training path behavior.""" + +import jax +import jax.numpy as jnp +import numpy as np +import numpy.testing as npt + +from quantammsim.pools.creator import create_pool +from quantammsim.pools.reCLAMM.reclamm_reserves import ( + initialise_reclamm_reserves, + _jax_calc_reclamm_reserves_zero_fees, + _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs, +) +from quantammsim.runners.jax_runner_utils import Hashabledict + + +ALL_SIG_VARIATIONS_2 = jnp.array([[1, -1], [-1, 1]]) +DEFAULT_POOL_VALUE = 1_000_000.0 +DEFAULT_INITIAL_PRICES = jnp.array([2500.0, 1.0], dtype=jnp.float64) +DEFAULT_PRICE_RATIO = 4.0 +DEFAULT_SHIFT_BASE = 1.0 - 1.0 / 124000.0 +DEFAULT_SECONDS_PER_STEP = 60.0 + + +def _init_pool_state(): + return initialise_reclamm_reserves( + DEFAULT_POOL_VALUE, + DEFAULT_INITIAL_PRICES, + DEFAULT_PRICE_RATIO, + ) + + +def _trending_prices(n_steps): + return jnp.stack( + [jnp.linspace(DEFAULT_INITIAL_PRICES[0], 4200.0, n_steps), jnp.ones((n_steps,))], + axis=1, + ) + + +def test_ste_forward_outputs_are_temperature_invariant(): + """STE hard-forward path should be invariant to STE temperature.""" + reserves, Va, Vb = _init_pool_state() + n_steps = 12 + prices = _trending_prices(n_steps) + fees = jnp.full((n_steps,), 0.003, dtype=jnp.float64) + arb_thresh = jnp.zeros((n_steps,), dtype=jnp.float64) + arb_fees = jnp.full((n_steps,), 0.0005, dtype=jnp.float64) + + schedule = np.zeros((n_steps, 4), dtype=np.float64) + schedule[:, 3] = np.nan + schedule[2] = np.array([1.0, 6.0, 7.0, DEFAULT_PRICE_RATIO], dtype=np.float64) + schedule = jnp.asarray(schedule) + + low_temp_reserves, low_temp_fee_revenue = ( + _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( + reserves, + Va, + Vb, + prices, + centeredness_margin=0.2, + daily_price_shift_base=DEFAULT_SHIFT_BASE, + seconds_per_step=DEFAULT_SECONDS_PER_STEP, + fees=fees, + arb_thresh=arb_thresh, + arb_fees=arb_fees, + price_ratio_updates=schedule, + all_sig_variations=ALL_SIG_VARIATIONS_2, + ste_temperature=3.0, + ) + ) + high_temp_reserves, high_temp_fee_revenue = ( + _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( + reserves, + Va, + Vb, + prices, + centeredness_margin=0.2, + daily_price_shift_base=DEFAULT_SHIFT_BASE, + seconds_per_step=DEFAULT_SECONDS_PER_STEP, + fees=fees, + arb_thresh=arb_thresh, + arb_fees=arb_fees, + price_ratio_updates=schedule, + all_sig_variations=ALL_SIG_VARIATIONS_2, + ste_temperature=50.0, + ) + ) + npt.assert_allclose(high_temp_reserves, low_temp_reserves, rtol=1e-10, atol=1e-10) + npt.assert_allclose( + high_temp_fee_revenue, low_temp_fee_revenue, rtol=1e-10, atol=1e-10 + ) + + +def test_margin_gradient_is_finite_and_nonzero_in_zero_fee_kernel(): + """Centeredness-margin gradient should flow through always-on STE gates.""" + reserves, Va, Vb = _init_pool_state() + n_steps = 6 + prices = jnp.tile(DEFAULT_INITIAL_PRICES, (n_steps, 1)) + margin = jnp.float64(1.0) + + def _loss(centeredness_margin): + reserves_out = _jax_calc_reclamm_reserves_zero_fees( + reserves, + Va, + Vb, + prices, + centeredness_margin=centeredness_margin, + daily_price_shift_base=DEFAULT_SHIFT_BASE, + seconds_per_step=DEFAULT_SECONDS_PER_STEP, + ste_temperature=25.0, + ) + return jnp.sum(reserves_out[-1]) + + grad_val = jax.grad(_loss)(margin) + + assert jnp.isfinite(grad_val) + assert jnp.abs(grad_val) > 1e-9 + + +def test_pool_zero_fee_path_uses_configured_ste_temperature(): + """Pool-level path should pass STE temperature through to kernel gradients.""" + pool = create_pool("reclamm") + n_steps = 6 + prices = jnp.tile(DEFAULT_INITIAL_PRICES, (n_steps, 1)) + start_index = jnp.array([0, 0], dtype=jnp.int32) + + run_fp_low_temp = Hashabledict( + { + "n_assets": 2, + "bout_length": n_steps + 1, + "initial_pool_value": DEFAULT_POOL_VALUE, + "arb_frequency": 1, + "do_arb": True, + "ste_temperature": 2.0, + } + ) + run_fp_high_temp = Hashabledict( + { + "n_assets": 2, + "bout_length": n_steps + 1, + "initial_pool_value": DEFAULT_POOL_VALUE, + "arb_frequency": 1, + "do_arb": True, + "ste_temperature": 50.0, + } + ) + + def _loss(centeredness_margin, run_fingerprint): + params = { + "price_ratio": jnp.float64(DEFAULT_PRICE_RATIO), + "centeredness_margin": centeredness_margin, + "daily_price_shift_base": jnp.float64(DEFAULT_SHIFT_BASE), + } + reserves_out = pool.calculate_reserves_zero_fees( + params, run_fingerprint, prices, start_index + ) + return jnp.sum(reserves_out[-1]) + + margin = jnp.float64(1.0) + low_temp_grad = jax.grad(lambda m: _loss(m, run_fp_low_temp))(margin) + high_temp_grad = jax.grad(lambda m: _loss(m, run_fp_high_temp))(margin) + + assert jnp.isfinite(low_temp_grad) + assert jnp.isfinite(high_temp_grad) + assert jnp.abs(low_temp_grad) > 1e-9 + assert jnp.abs(high_temp_grad) > 1e-9 + assert jnp.abs(high_temp_grad) > jnp.abs(low_temp_grad) * 1.5 diff --git a/tests/pools/reCLAMM/test_reclamm_fee_revenue.py b/tests/pools/reCLAMM/test_reclamm_fee_revenue.py index bfbd21b..106f8a3 100644 --- a/tests/pools/reCLAMM/test_reclamm_fee_revenue.py +++ b/tests/pools/reCLAMM/test_reclamm_fee_revenue.py @@ -297,6 +297,7 @@ def test_pool_method_with_fees(self): "tokens": ("ETH", "USDC"), "numeraire": "USDC", "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + "ste_temperature": 10.0, }) start_index = jnp.array([0, 0]) @@ -342,6 +343,7 @@ def test_pool_method_with_dynamic_inputs(self): "tokens": ("ETH", "USDC"), "numeraire": "USDC", "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + "ste_temperature": 10.0, }) start_index = jnp.array([0, 0]) @@ -408,6 +410,7 @@ def test_forward_pass_returns_fee_revenue(self): "rule": "reclamm", "training_data_kind": "historic", "do_trades": False, + "ste_temperature": 10.0, }) start_index = jnp.array([0, 0]) diff --git a/tests/pools/reCLAMM/test_reclamm_noise_volume.py b/tests/pools/reCLAMM/test_reclamm_noise_volume.py index f413a9e..500ae07 100644 --- a/tests/pools/reCLAMM/test_reclamm_noise_volume.py +++ b/tests/pools/reCLAMM/test_reclamm_noise_volume.py @@ -547,6 +547,7 @@ def test_tsoukalas_sqrt_from_fingerprint(self): "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), "noise_model": "tsoukalas_sqrt", "reclamm_noise_params": DEFAULT_NOISE_PARAMS, + "ste_temperature": 10.0, }) # Fingerprint without noise @@ -563,6 +564,7 @@ def test_tsoukalas_sqrt_from_fingerprint(self): "numeraire": "USDC", "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), "noise_model": "arb_only", + "ste_temperature": 10.0, }) start_index = jnp.array([0, 0]) @@ -618,6 +620,7 @@ def test_volatility_auto_computed_affects_fee_revenue(self): "tokens": ("ETH", "USDC"), "numeraire": "USDC", "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + "ste_temperature": 10.0, } fp_tsoukalas = Hashabledict({ @@ -946,6 +949,7 @@ def test_loglinear_from_fingerprint(self): "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), "noise_model": "loglinear", "reclamm_noise_params": loglinear_params, + "ste_temperature": 10.0, }) fp_arb_only = Hashabledict({ @@ -961,6 +965,7 @@ def test_loglinear_from_fingerprint(self): "numeraire": "USDC", "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), "noise_model": "arb_only", + "ste_temperature": 10.0, }) start_index = jnp.array([0, 0]) diff --git a/tests/pools/reCLAMM/test_reclamm_reserves.py b/tests/pools/reCLAMM/test_reclamm_reserves.py index c795176..58d9dae 100644 --- a/tests/pools/reCLAMM/test_reclamm_reserves.py +++ b/tests/pools/reCLAMM/test_reclamm_reserves.py @@ -275,6 +275,7 @@ def test_calculate_reserves_with_fees(self): "tokens": ("ETH", "USDC"), "numeraire": "USDC", "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + "ste_temperature": 10.0, }) start_index = jnp.array([0, 0]) @@ -317,6 +318,7 @@ def test_calculate_reserves_zero_fees(self): "tokens": ("ETH", "USDC"), "numeraire": "USDC", "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + "ste_temperature": 10.0, }) start_index = jnp.array([0, 0]) @@ -356,6 +358,7 @@ def test_calculate_weights(self): "tokens": ("ETH", "USDC"), "numeraire": "USDC", "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + "ste_temperature": 10.0, }) start_index = jnp.array([0, 0]) @@ -488,6 +491,7 @@ def test_fingerprint_dispatch(self): "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), "reclamm_interpolation_method": "constant_arc_length", "reclamm_arc_length_speed": None, # auto-calibrate + "ste_temperature": 10.0, }) start_index = jnp.array([0, 0]) @@ -693,6 +697,7 @@ def test_learnable_arc_length_speed_forward_pass(self): "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), "reclamm_interpolation_method": "constant_arc_length", "reclamm_learn_arc_length_speed": True, + "ste_temperature": 10.0, }) start_index = jnp.array([0, 0]) @@ -924,6 +929,7 @@ def test_noise_trader_ratio_through_pool_class(self): "tokens": ("ETH", "USDC"), "numeraire": "USDC", "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + "ste_temperature": 10.0, } fp_no_noise = Hashabledict({**base_fp, "noise_trader_ratio": 0.0}) @@ -1286,6 +1292,7 @@ def test_lp_supply_through_pool_class(self): "tokens": ("ETH", "USDC"), "numeraire": "USDC", "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + "ste_temperature": 10.0, }) start_index = jnp.array([0, 0])