From d9471a4bea73ab4f8d9e3847fb897b71c63058d1 Mon Sep 17 00:00:00 2001 From: christian harrington Date: Thu, 5 Mar 2026 16:53:14 +0000 Subject: [PATCH 01/14] enable straight through estimation given boolean flag properties of ratio changes etc for gradient based optimisation methods --- quantammsim/pools/reCLAMM/reclamm.py | 15 ++ quantammsim/pools/reCLAMM/reclamm_reserves.py | 85 +++++++-- .../runners/default_run_fingerprint.py | 1 + .../reCLAMM/test_reclamm_differentiability.py | 167 ++++++++++++++++++ 4 files changed, 256 insertions(+), 12 deletions(-) create mode 100644 tests/pools/reCLAMM/test_reclamm_differentiability.py diff --git a/quantammsim/pools/reCLAMM/reclamm.py b/quantammsim/pools/reCLAMM/reclamm.py index 762301c..0b1e2d4 100644 --- a/quantammsim/pools/reCLAMM/reclamm.py +++ b/quantammsim/pools/reCLAMM/reclamm.py @@ -198,6 +198,11 @@ def _resolve_fees(params, run_fingerprint): return jnp.squeeze(params["fees"]) return run_fingerprint["fees"] + @staticmethod + def _resolve_ste_temperature(run_fingerprint): + """Resolve STE gate temperature for differentiable reCLAMM transitions.""" + return run_fingerprint.get("ste_temperature", 10.0) + @partial(jit, static_argnums=(2,)) def calculate_reserves_with_fees( self, @@ -208,6 +213,7 @@ def calculate_reserves_with_fees( additional_oracle_input: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: s = self._init_pool_state(params, run_fingerprint, prices, start_index) + ste_temperature = self._resolve_ste_temperature(run_fingerprint) if run_fingerprint["do_arb"]: return _jax_calc_reclamm_reserves_with_fees( @@ -225,6 +231,7 @@ def calculate_reserves_with_fees( arc_length_speed=s.arc_length_speed, centeredness_scaling=s.centeredness_scaling, protocol_fee_split=run_fingerprint.get("protocol_fee_split", 0.0), + ste_temperature=ste_temperature, ) return jnp.broadcast_to(s.initial_reserves, s.arb_prices.shape) @@ -246,6 +253,7 @@ def calculate_reserves_and_fee_revenue_with_fees( LP fee revenue per timestep in USD. 
""" s = self._init_pool_state(params, run_fingerprint, prices, start_index) + ste_temperature = self._resolve_ste_temperature(run_fingerprint) if run_fingerprint["do_arb"]: return _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( @@ -263,6 +271,7 @@ def calculate_reserves_and_fee_revenue_with_fees( arc_length_speed=s.arc_length_speed, centeredness_scaling=s.centeredness_scaling, protocol_fee_split=run_fingerprint.get("protocol_fee_split", 0.0), + ste_temperature=ste_temperature, ) return ( jnp.broadcast_to(s.initial_reserves, s.arb_prices.shape), @@ -288,6 +297,7 @@ def calculate_reserves_and_fee_revenue_with_dynamic_inputs( LP fee revenue per timestep in USD. """ s = self._init_pool_state(params, run_fingerprint, prices, start_index) + ste_temperature = self._resolve_ste_temperature(run_fingerprint) bout_length = run_fingerprint["bout_length"] max_len = bout_length - 1 if run_fingerprint["arb_frequency"] != 1: @@ -317,6 +327,7 @@ def calculate_reserves_and_fee_revenue_with_dynamic_inputs( arc_length_speed=s.arc_length_speed, centeredness_scaling=s.centeredness_scaling, protocol_fee_split=run_fingerprint.get("protocol_fee_split", 0.0), + ste_temperature=ste_temperature, ) @partial(jit, static_argnums=(2,)) @@ -330,6 +341,7 @@ def _calculate_reserves_zero_fees( ) -> jnp.ndarray: """Protected zero-fee implementation for hooks and weight calculation.""" s = self._init_pool_state(params, run_fingerprint, prices, start_index) + ste_temperature = self._resolve_ste_temperature(run_fingerprint) if run_fingerprint["do_arb"]: return _jax_calc_reclamm_reserves_zero_fees( @@ -340,6 +352,7 @@ def _calculate_reserves_zero_fees( s.seconds_per_step, arc_length_speed=s.arc_length_speed, centeredness_scaling=s.centeredness_scaling, + ste_temperature=ste_temperature, ) return jnp.broadcast_to(s.initial_reserves, s.arb_prices.shape) @@ -366,6 +379,7 @@ def calculate_reserves_with_dynamic_inputs( additional_oracle_input: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: s = 
self._init_pool_state(params, run_fingerprint, prices, start_index) + ste_temperature = self._resolve_ste_temperature(run_fingerprint) bout_length = run_fingerprint["bout_length"] max_len = bout_length - 1 if run_fingerprint["arb_frequency"] != 1: @@ -395,6 +409,7 @@ def calculate_reserves_with_dynamic_inputs( arc_length_speed=s.arc_length_speed, centeredness_scaling=s.centeredness_scaling, protocol_fee_split=run_fingerprint.get("protocol_fee_split", 0.0), + ste_temperature=ste_temperature, ) def init_base_parameters( diff --git a/quantammsim/pools/reCLAMM/reclamm_reserves.py b/quantammsim/pools/reCLAMM/reclamm_reserves.py index 5f1ac85..7e0bab8 100644 --- a/quantammsim/pools/reCLAMM/reclamm_reserves.py +++ b/quantammsim/pools/reCLAMM/reclamm_reserves.py @@ -17,7 +17,8 @@ import jax.numpy as jnp from jax import jit -from jax.lax import scan, cond +from jax.lax import scan, cond, stop_gradient +from jax.nn import sigmoid from jax.tree_util import Partial from functools import partial @@ -48,6 +49,35 @@ # Pure math functions # --------------------------------------------------------------------------- +def _ste_gate(hard_bool, soft_value): + """Hard forward / soft backward gate.""" + hard_value = hard_bool.astype(soft_value.dtype) + return soft_value + stop_gradient(hard_value - soft_value) + + +def _ste_greater_than(x, threshold, temperature=10.0): + hard = x > threshold + soft = sigmoid(temperature * (x - threshold)) + return _ste_gate(hard, soft) + + +def _ste_less_than(x, threshold, temperature=10.0): + hard = x < threshold + soft = sigmoid(temperature * (threshold - x)) + return _ste_gate(hard, soft) + + +def _ste_greater_equal(x, threshold, temperature=10.0): + hard = x >= threshold + soft = sigmoid(temperature * (x - threshold)) + return _ste_gate(hard, soft) + + +def _ste_select(mask, when_true, when_false): + """Select between two values using a 0/1 gate that can carry STE gradients.""" + return mask * when_true + (1.0 - mask) * when_false + + def 
compute_invariant(Ra, Rb, Va, Vb): """Compute constant-product invariant L = (Ra + Va) * (Rb + Vb).""" return (Ra + Va) * (Rb + Vb) @@ -598,6 +628,7 @@ def _reclamm_scan_step_zero_fees( seconds_per_step, arc_length_speed=0.0, centeredness_scaling=False, + ste_temperature=10.0, ): """Single scan step for zero-fee reClAMM pool. @@ -617,7 +648,6 @@ def _reclamm_scan_step_zero_fees( # Step 1: Update virtual balances if out of range centeredness, is_above = compute_centeredness(Ra, Rb, Va, Vb) sqrt_Q = jnp.sqrt(compute_price_ratio(Ra, Rb, Va, Vb)) - out_of_range = centeredness < centeredness_margin market_price = prices[0] / prices[1] # Centeredness-proportional scaling: margin/centeredness multiplier @@ -648,8 +678,11 @@ def _reclamm_scan_step_zero_fees( Va_updated = jnp.where(use_cal, Va_cal, Va_geo) Vb_updated = jnp.where(use_cal, Vb_cal, Vb_geo) - Va = jnp.where(out_of_range, Va_updated, Va) - Vb = jnp.where(out_of_range, Vb_updated, Vb) + out_of_range_gate = _ste_less_than( + centeredness, centeredness_margin, ste_temperature + ) + Va = _ste_select(out_of_range_gate, Va_updated, Va) + Vb = _ste_select(out_of_range_gate, Vb_updated, Vb) # Step 2: Analytical zero-fee arb on effective reserves L = compute_invariant(Ra, Rb, Va, Vb) @@ -702,12 +735,14 @@ def _reclamm_scan_step_zero_fees_full_state( seconds_per_step, arc_length_speed=0.0, centeredness_scaling=False, + ste_temperature=10.0, ): """TEST-ONLY: scan step that outputs (reserves, Va, Vb).""" new_carry, new_reserves = _reclamm_scan_step_zero_fees( carry_list, prices, centeredness_margin, daily_price_shift_base, seconds_per_step, arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, + ste_temperature=ste_temperature, ) return new_carry, (new_reserves, new_carry[1], new_carry[2]) @@ -725,6 +760,7 @@ def _reclamm_scan_step_with_fees_and_revenue( arc_length_speed=0.0, centeredness_scaling=False, protocol_fee_split=0.0, + ste_temperature=10.0, ): """Single scan step for reClAMM pool with fees, 
returning LP fee revenue. @@ -858,7 +894,6 @@ def _skip_schedule_state(_): # Step 1: Update virtual balances if out of range centeredness, is_above = compute_centeredness(Ra, Rb, Va, Vb) sqrt_Q = jnp.sqrt(compute_price_ratio(Ra, Rb, Va, Vb)) - out_of_range = centeredness < centeredness_margin market_price = prices[0] / prices[1] # Centeredness-proportional scaling: margin/centeredness multiplier @@ -888,14 +923,15 @@ def _skip_schedule_state(_): Va_updated = jnp.where(use_cal, Va_cal, Va_geo) Vb_updated = jnp.where(use_cal, Vb_cal, Vb_geo) - Va = jnp.where(out_of_range, Va_updated, Va) - Vb = jnp.where(out_of_range, Vb_updated, Vb) + out_of_range_gate = _ste_less_than( + centeredness, centeredness_margin, ste_temperature + ) + Va = _ste_select(out_of_range_gate, Va_updated, Va) + Vb = _ste_select(out_of_range_gate, Vb_updated, Vb) # Step 2: Compute arb trade using G3M machinery on effective reserves effective_reserves = jnp.array([Ra + Va, Rb + Vb]) - fees_are_being_charged = gamma != 1.0 - # Zero-fee analytical arb L = compute_invariant(Ra, Rb, Va, Vb) market_price = prices[0] / prices[1] @@ -918,15 +954,22 @@ def _skip_schedule_state(_): 0, ) - optimal_arb_trade = jnp.where(fees_are_being_charged, fee_trade, zero_fee_trade) + fees_gate = _ste_greater_than( + jnp.abs(gamma - 1.0), jnp.asarray(1e-12, dtype=gamma.dtype), ste_temperature + ) + optimal_arb_trade = _ste_select(fees_gate, fee_trade, zero_fee_trade) # Check profitability for arb profit_to_arb = -(optimal_arb_trade * prices).sum() - arb_thresh arb_external_cost = 0.5 * arb_fees * (jnp.abs(optimal_arb_trade) * prices).sum() - do_trade = profit_to_arb >= arb_external_cost # Apply trade to REAL reserves only - applied_trade = jnp.where(do_trade, optimal_arb_trade, 0.0) + trade_gate = _ste_greater_equal( + profit_to_arb, arb_external_cost, ste_temperature + ) + applied_trade = _ste_select( + trade_gate, optimal_arb_trade, jnp.zeros_like(optimal_arb_trade) + ) Ra_new = Ra + applied_trade[0] Rb_new = Rb + 
applied_trade[1] @@ -994,6 +1037,7 @@ def _reclamm_scan_step_with_fees( arc_length_speed=0.0, centeredness_scaling=False, protocol_fee_split=0.0, + ste_temperature=10.0, ): """Single scan step for reClAMM pool with fees (reserves only). @@ -1012,6 +1056,7 @@ def _reclamm_scan_step_with_fees( arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, + ste_temperature=ste_temperature, ) return new_carry, new_reserves @@ -1029,6 +1074,7 @@ def _reclamm_scan_step_with_fees_full_state( arc_length_speed=0.0, centeredness_scaling=False, protocol_fee_split=0.0, + ste_temperature=10.0, ): """TEST-ONLY: fee scan step that also outputs virtual balances.""" new_carry, (new_reserves, _fee_rev) = _reclamm_scan_step_with_fees_and_revenue( @@ -1043,6 +1089,7 @@ def _reclamm_scan_step_with_fees_full_state( arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, + ste_temperature=ste_temperature, ) return new_carry, (new_reserves, new_carry[1], new_carry[2]) @@ -1058,6 +1105,7 @@ def _jax_calc_reclamm_reserves_zero_fees( seconds_per_step, arc_length_speed=0.0, centeredness_scaling=False, + ste_temperature=10.0, ): """Calculate reClAMM reserves over time with zero fees. @@ -1092,6 +1140,7 @@ def _jax_calc_reclamm_reserves_zero_fees( seconds_per_step=seconds_per_step, arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, + ste_temperature=ste_temperature, ) carry_init = [initial_reserves, initial_Va, initial_Vb] @@ -1110,6 +1159,7 @@ def _jax_calc_reclamm_reserves_zero_fees_full_state( seconds_per_step, arc_length_speed=0.0, centeredness_scaling=False, + ste_temperature=10.0, ): """TEST-ONLY: Like _jax_calc_reclamm_reserves_zero_fees but returns Va/Vb. 
@@ -1126,6 +1176,7 @@ def _jax_calc_reclamm_reserves_zero_fees_full_state( seconds_per_step=seconds_per_step, arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, + ste_temperature=ste_temperature, ) carry_init = [initial_reserves, initial_Va, initial_Vb] @@ -1149,6 +1200,7 @@ def _jax_calc_reclamm_reserves_with_fees( arc_length_speed=0.0, centeredness_scaling=False, protocol_fee_split=0.0, + ste_temperature=10.0, ): """Calculate reClAMM reserves over time with fees. @@ -1189,6 +1241,7 @@ def _jax_calc_reclamm_reserves_with_fees( arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, + ste_temperature=ste_temperature, ) carry_init = [ @@ -1230,6 +1283,7 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs( arc_length_speed=0.0, centeredness_scaling=False, protocol_fee_split=0.0, + ste_temperature=10.0, ): """Calculate reClAMM reserves with time-varying fees/arb arrays.""" n_assets = 2 @@ -1279,6 +1333,7 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs( arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, + ste_temperature=ste_temperature, ) carry_init = [ @@ -1320,6 +1375,7 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs_full_state( arc_length_speed=0.0, centeredness_scaling=False, protocol_fee_split=0.0, + ste_temperature=10.0, ): """TEST-ONLY: dynamic-input reserve path returning virtual-balance history.""" n_assets = 2 @@ -1368,6 +1424,7 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs_full_state( arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, + ste_temperature=ste_temperature, ) carry_init = [ @@ -1406,6 +1463,7 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( arc_length_speed=0.0, centeredness_scaling=False, protocol_fee_split=0.0, + ste_temperature=10.0, ): """Calculate reClAMM reserves and LP fee revenue over 
time with fees. @@ -1448,6 +1506,7 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, + ste_temperature=ste_temperature, ) carry_init = [ @@ -1489,6 +1548,7 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( arc_length_speed=0.0, centeredness_scaling=False, protocol_fee_split=0.0, + ste_temperature=10.0, ): """Calculate reClAMM reserves and LP fee revenue with time-varying fees/arb arrays. @@ -1544,6 +1604,7 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, + ste_temperature=ste_temperature, ) carry_init = [ diff --git a/quantammsim/runners/default_run_fingerprint.py b/quantammsim/runners/default_run_fingerprint.py index 81bc806..b02de68 100644 --- a/quantammsim/runners/default_run_fingerprint.py +++ b/quantammsim/runners/default_run_fingerprint.py @@ -98,6 +98,7 @@ "reclamm_interpolation_method": "geometric", # "geometric" or "constant_arc_length" "reclamm_arc_length_speed": None, # auto-calibrate from geometric onset if None "reclamm_centeredness_scaling": False, # scale speed by margin/centeredness + "ste_temperature": 10.0, # STE gate sharpness; higher is closer to hard threshold "reclamm_learn_arc_length_speed": False, # include arc_length_speed in trainable params "reclamm_use_shift_exponent": False, # parametrise shift rate as shift_exponent (log-friendly) "reclamm_learn_fees": False, # include fees in trainable params (Optuna search over fee level) diff --git a/tests/pools/reCLAMM/test_reclamm_differentiability.py b/tests/pools/reCLAMM/test_reclamm_differentiability.py new file mode 100644 index 0000000..29fdcbe --- /dev/null +++ b/tests/pools/reCLAMM/test_reclamm_differentiability.py @@ -0,0 +1,167 @@ +"""Differentiability tests for reCLAMM STE-gated training path behavior.""" 
+ +import jax +import jax.numpy as jnp +import numpy as np +import numpy.testing as npt + +from quantammsim.pools.creator import create_pool +from quantammsim.pools.reCLAMM.reclamm_reserves import ( + initialise_reclamm_reserves, + _jax_calc_reclamm_reserves_zero_fees, + _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs, +) +from quantammsim.runners.jax_runner_utils import Hashabledict + + +ALL_SIG_VARIATIONS_2 = jnp.array([[1, -1], [-1, 1]]) +DEFAULT_POOL_VALUE = 1_000_000.0 +DEFAULT_INITIAL_PRICES = jnp.array([2500.0, 1.0], dtype=jnp.float64) +DEFAULT_PRICE_RATIO = 4.0 +DEFAULT_SHIFT_BASE = 1.0 - 1.0 / 124000.0 +DEFAULT_SECONDS_PER_STEP = 60.0 + + +def _init_pool_state(): + return initialise_reclamm_reserves( + DEFAULT_POOL_VALUE, + DEFAULT_INITIAL_PRICES, + DEFAULT_PRICE_RATIO, + ) + + +def _trending_prices(n_steps): + return jnp.stack( + [jnp.linspace(DEFAULT_INITIAL_PRICES[0], 4200.0, n_steps), jnp.ones((n_steps,))], + axis=1, + ) + + +def test_ste_forward_outputs_are_temperature_invariant(): + """STE hard-forward path should be invariant to STE temperature.""" + reserves, Va, Vb = _init_pool_state() + n_steps = 12 + prices = _trending_prices(n_steps) + fees = jnp.full((n_steps,), 0.003, dtype=jnp.float64) + arb_thresh = jnp.zeros((n_steps,), dtype=jnp.float64) + arb_fees = jnp.full((n_steps,), 0.0005, dtype=jnp.float64) + + schedule = np.zeros((n_steps, 4), dtype=np.float64) + schedule[:, 3] = np.nan + schedule[2] = np.array([1.0, 6.0, 7.0, DEFAULT_PRICE_RATIO], dtype=np.float64) + schedule = jnp.asarray(schedule) + + low_temp_reserves, low_temp_fee_revenue = ( + _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( + reserves, + Va, + Vb, + prices, + centeredness_margin=0.2, + daily_price_shift_base=DEFAULT_SHIFT_BASE, + seconds_per_step=DEFAULT_SECONDS_PER_STEP, + fees=fees, + arb_thresh=arb_thresh, + arb_fees=arb_fees, + price_ratio_updates=schedule, + all_sig_variations=ALL_SIG_VARIATIONS_2, + ste_temperature=3.0, + ) + ) + 
high_temp_reserves, high_temp_fee_revenue = ( + _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( + reserves, + Va, + Vb, + prices, + centeredness_margin=0.2, + daily_price_shift_base=DEFAULT_SHIFT_BASE, + seconds_per_step=DEFAULT_SECONDS_PER_STEP, + fees=fees, + arb_thresh=arb_thresh, + arb_fees=arb_fees, + price_ratio_updates=schedule, + all_sig_variations=ALL_SIG_VARIATIONS_2, + ste_temperature=50.0, + ) + ) + npt.assert_allclose(high_temp_reserves, low_temp_reserves, rtol=1e-10, atol=1e-10) + npt.assert_allclose( + high_temp_fee_revenue, low_temp_fee_revenue, rtol=1e-10, atol=1e-10 + ) + + +def test_margin_gradient_is_finite_and_nonzero_in_zero_fee_kernel(): + """Centeredness-margin gradient should flow through always-on STE gates.""" + reserves, Va, Vb = _init_pool_state() + n_steps = 6 + prices = jnp.tile(DEFAULT_INITIAL_PRICES, (n_steps, 1)) + margin = jnp.float64(1.0) + + def _loss(centeredness_margin): + reserves_out = _jax_calc_reclamm_reserves_zero_fees( + reserves, + Va, + Vb, + prices, + centeredness_margin=centeredness_margin, + daily_price_shift_base=DEFAULT_SHIFT_BASE, + seconds_per_step=DEFAULT_SECONDS_PER_STEP, + ste_temperature=25.0, + ) + return jnp.sum(reserves_out[-1]) + + grad_val = jax.grad(_loss)(margin) + + assert jnp.isfinite(grad_val) + assert jnp.abs(grad_val) > 1e-9 + + +def test_pool_zero_fee_path_uses_configured_ste_temperature(): + """Pool-level path should pass STE temperature through to kernel gradients.""" + pool = create_pool("reclamm") + n_steps = 6 + prices = jnp.tile(DEFAULT_INITIAL_PRICES, (n_steps, 1)) + start_index = jnp.array([0, 0], dtype=jnp.int32) + + run_fp_low_temp = Hashabledict( + { + "n_assets": 2, + "bout_length": n_steps + 1, + "initial_pool_value": DEFAULT_POOL_VALUE, + "arb_frequency": 1, + "do_arb": True, + "ste_temperature": 2.0, + } + ) + run_fp_high_temp = Hashabledict( + { + "n_assets": 2, + "bout_length": n_steps + 1, + "initial_pool_value": DEFAULT_POOL_VALUE, + "arb_frequency": 1, + 
"do_arb": True, + "ste_temperature": 50.0, + } + ) + + def _loss(centeredness_margin, run_fingerprint): + params = { + "price_ratio": jnp.float64(DEFAULT_PRICE_RATIO), + "centeredness_margin": centeredness_margin, + "daily_price_shift_base": jnp.float64(DEFAULT_SHIFT_BASE), + } + reserves_out = pool.calculate_reserves_zero_fees( + params, run_fingerprint, prices, start_index + ) + return jnp.sum(reserves_out[-1]) + + margin = jnp.float64(1.0) + low_temp_grad = jax.grad(lambda m: _loss(m, run_fp_low_temp))(margin) + high_temp_grad = jax.grad(lambda m: _loss(m, run_fp_high_temp))(margin) + + assert jnp.isfinite(low_temp_grad) + assert jnp.isfinite(high_temp_grad) + assert jnp.abs(low_temp_grad) > 1e-9 + assert jnp.abs(high_temp_grad) > 1e-9 + assert jnp.abs(high_temp_grad) > jnp.abs(low_temp_grad) * 1.5 From b6aab22d616654018abd660bb41ef91c3f858717 Mon Sep 17 00:00:00 2001 From: christian harrington Date: Fri, 6 Mar 2026 11:39:08 +0000 Subject: [PATCH 02/14] remove default get --- quantammsim/pools/reCLAMM/reclamm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quantammsim/pools/reCLAMM/reclamm.py b/quantammsim/pools/reCLAMM/reclamm.py index 0b1e2d4..fbfecab 100644 --- a/quantammsim/pools/reCLAMM/reclamm.py +++ b/quantammsim/pools/reCLAMM/reclamm.py @@ -201,7 +201,7 @@ def _resolve_fees(params, run_fingerprint): @staticmethod def _resolve_ste_temperature(run_fingerprint): """Resolve STE gate temperature for differentiable reCLAMM transitions.""" - return run_fingerprint.get("ste_temperature", 10.0) + return run_fingerprint.get("ste_temperature") @partial(jit, static_argnums=(2,)) def calculate_reserves_with_fees( From 4e86c4df1c6d6c9256f2d85720b5a1c6e945571c Mon Sep 17 00:00:00 2001 From: christian harrington Date: Fri, 6 Mar 2026 16:43:53 +0000 Subject: [PATCH 03/14] noise calibration port from private repo --- quantammsim/noise_calibration/__init__.py | 39 + quantammsim/noise_calibration/cli.py | 417 +++++++++++ 
quantammsim/noise_calibration/constants.py | 60 ++ .../noise_calibration/covariate_encoding.py | 228 ++++++ .../noise_calibration/data_pipeline.py | 516 ++++++++++++++ .../noise_calibration/data_validation.py | 41 ++ quantammsim/noise_calibration/formula_arb.py | 36 + quantammsim/noise_calibration/inference.py | 270 +++++++ quantammsim/noise_calibration/model.py | 444 ++++++++++++ quantammsim/noise_calibration/output.py | 306 ++++++++ quantammsim/noise_calibration/plotting.py | 335 +++++++++ .../noise_calibration/postprocessing.py | 667 ++++++++++++++++++ .../noise_calibration/token_classification.py | 23 + 13 files changed, 3382 insertions(+) create mode 100644 quantammsim/noise_calibration/__init__.py create mode 100644 quantammsim/noise_calibration/cli.py create mode 100644 quantammsim/noise_calibration/constants.py create mode 100644 quantammsim/noise_calibration/covariate_encoding.py create mode 100644 quantammsim/noise_calibration/data_pipeline.py create mode 100644 quantammsim/noise_calibration/data_validation.py create mode 100644 quantammsim/noise_calibration/formula_arb.py create mode 100644 quantammsim/noise_calibration/inference.py create mode 100644 quantammsim/noise_calibration/model.py create mode 100644 quantammsim/noise_calibration/output.py create mode 100644 quantammsim/noise_calibration/plotting.py create mode 100644 quantammsim/noise_calibration/postprocessing.py create mode 100644 quantammsim/noise_calibration/token_classification.py diff --git a/quantammsim/noise_calibration/__init__.py b/quantammsim/noise_calibration/__init__.py new file mode 100644 index 0000000..f1fddde --- /dev/null +++ b/quantammsim/noise_calibration/__init__.py @@ -0,0 +1,39 @@ +"""Noise calibration package for Balancer pool volume models. + +Public API re-exports from submodules. 
+""" + +# scipy.signal patch (must run before arviz import) +try: + from scipy.signal import gaussian as _ # noqa: F401 +except ImportError: + from scipy.signal.windows import gaussian as _gauss + import scipy.signal + scipy.signal.gaussian = _gauss + +from .constants import ( + K_COEFF, COEFF_NAMES, BALANCER_API_URL, BALANCER_API_CHAINS, CACHE_DIR, + K_CLUSTERS_DEFAULT, K_FEATURES_DEFAULT, +) +from .token_classification import classify_token_tier, _normalise_symbol +from .data_pipeline import ( + _graphql_request, enumerate_balancer_pools, fetch_pool_snapshots, + fetch_all_snapshots, fetch_token_prices, compute_pair_volatility, + assemble_panel, +) +from .data_validation import validate_panel +from .covariate_encoding import encode_covariates, encode_covariates_structural +from .model import noise_model, noise_model_dp_sigma, noise_model_ibp, noise_model_ibp_dp, stick_breaking_weights, structural_noise_model +from .formula_arb import formula_arb_volume_daily_jax +from .inference import ( + _get_theta_samples, _build_model_kwargs, run_svi, run_nuts, + run_svi_then_nuts, +) +from .postprocessing import ( + extract_noise_params, predict_new_pool, check_convergence, + run_prior_predictive, assign_dp_clusters, assign_ibp_dp_joint, + extract_structural_params, predict_new_pool_structural, +) +from .plotting import plot_diagnostics +from .output import generate_output_json, _save_sample_cache +from .cli import main diff --git a/quantammsim/noise_calibration/cli.py b/quantammsim/noise_calibration/cli.py new file mode 100644 index 0000000..3e169ac --- /dev/null +++ b/quantammsim/noise_calibration/cli.py @@ -0,0 +1,417 @@ +"""CLI entry point for noise calibration.""" + +import argparse +import json +import os +import sys +from datetime import date, timedelta + +import numpy as np +import pandas as pd + +from .constants import CACHE_DIR +from .data_pipeline import ( + enumerate_balancer_pools, fetch_all_snapshots, + fetch_token_prices, assemble_panel, +) +from 
.data_validation import validate_panel +from .covariate_encoding import encode_covariates +from .inference import run_svi, run_nuts, run_svi_then_nuts +from .postprocessing import ( + extract_noise_params, predict_new_pool, + check_convergence, run_prior_predictive, +) +from .plotting import plot_diagnostics +from .output import generate_output_json, _save_sample_cache + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="Unified Bayesian hierarchical noise volume model " + "for Balancer pools (gold standard)" + ) + + # Actions + parser.add_argument("--fetch", action="store_true", + help="Fetch data from Balancer API") + parser.add_argument("--fit", action="store_true", + help="Run inference (SVI default)") + parser.add_argument("--nuts", action="store_true", + help="Use NUTS instead of SVI") + parser.add_argument("--svi-init-nuts", action="store_true", + help="SVI-initialized NUTS (fast warmup)") + parser.add_argument("--plot", action="store_true", + help="Generate diagnostic plots") + parser.add_argument("--prior-predictive", action="store_true", + help="Include prior predictive check") + parser.add_argument("--validate", action="store_true", + help="Run data validation pass") + parser.add_argument("--predict", action="store_true", + help="Predict for unseen pool") + + # Output + parser.add_argument("--output", default=None, + help="Output JSON path") + parser.add_argument("--output-dir", default="results", + help="Plot output directory (default: results)") + + # Predict args + parser.add_argument("--chain", default=None, + help="Chain for --predict") + parser.add_argument("--tokens", nargs="+", default=None, + help="Tokens for --predict") + parser.add_argument("--fee", type=float, default=0.003, + help="Fee for --predict") + + # NUTS hyperparameters + parser.add_argument("--num-warmup", type=int, default=1000, + help="NUTS warmup iterations (default: 1000)") + parser.add_argument("--num-samples", type=int, default=2000, + help="NUTS/SVI 
samples (default: 2000)") + parser.add_argument("--num-chains", type=int, default=4, + help="NUTS chains (default: 4)") + parser.add_argument("--target-accept", type=float, default=0.85, + help="NUTS target accept prob (default: 0.85)") + parser.add_argument("--max-tree-depth", type=int, default=10, + help="NUTS max tree depth (default: 10)") + parser.add_argument("--seed", type=int, default=42, + help="Random seed (default: 42)") + + # SVI hyperparameters + parser.add_argument("--svi-steps", type=int, default=20000, + help="SVI optimization steps (default: 20000)") + parser.add_argument("--svi-lr", type=float, default=1e-3, + help="SVI learning rate (default: 1e-3)") + + # Model variant + parser.add_argument("--model", choices=["tier", "dp_sigma", "ibp", "ibp_dp", + "structural"], + default="tier", + help="Noise model variant: 'tier' (per-tier sigma_eps), " + "'dp_sigma' (DP mixture on sigma_eps), " + "'ibp' (IBP latent features), " + "'ibp_dp' (IBP features + DP noise clusters), or " + "'structural' (structural mixture: arb + MoE noise)") + parser.add_argument("--k-clusters", type=int, default=6, + help="Number of DP mixture components " + "(capacity ceiling, default: 6)") + parser.add_argument("--k-features", type=int, default=6, + help="Number of IBP latent features " + "(default: 6)") + + # Data + parser.add_argument("--train-days", type=int, default=90, + help="Use only the last N days of data for fitting " + "(default: 90). Aligns with Balancer API hourly price " + "coverage window. 
Set to 0 to use all data.") + parser.add_argument("--min-tvl", type=float, default=10000.0, + help="Pool enumeration TVL filter") + parser.add_argument("--cache-dir", default=None, + help="Cache directory") + parser.add_argument("--device", choices=["cpu", "gpu", "auto"], + default="auto", + help="JAX device (default: auto)") + + return parser.parse_args() + + +def main(): + args = _parse_args() + + if not any([args.fetch, args.fit, args.predict, args.validate]): + print("ERROR: At least one of --fetch, --fit, --predict, --validate " + "is required", file=sys.stderr) + sys.exit(1) + + cache_dir = args.cache_dir or CACHE_DIR + + # --- JAX setup (BEFORE any JAX ops / imports) --- + if args.fit or args.predict or args.prior_predictive: + # Set device before importing JAX + if args.device == "cpu": + os.environ.setdefault("JAX_PLATFORMS", "cpu") + elif args.device == "gpu": + os.environ.setdefault("JAX_PLATFORMS", "cuda") + # auto: don't touch JAX_PLATFORMS, let JAX pick + + # Set host device count for NUTS multi-chain BEFORE JAX init + if args.nuts or args.svi_init_nuts: + import numpyro as _np_pre + _np_pre.set_host_device_count( + min(args.num_chains, os.cpu_count() or 4) + ) + + import jax + import numpyro + numpyro.enable_x64() + + # --- File paths --- + pools_cache = os.path.join(cache_dir, "pools.parquet") + snaps_cache = os.path.join(cache_dir, "pool_snapshots.parquet") + prices_cache = os.path.join(cache_dir, "token_prices") + panel_cache = os.path.join(cache_dir, "panel.parquet") + + # --- Fetch --- + if args.fetch: + print("Phase 1: Fetching data from Balancer API") + print("=" * 60) + + print("\n1. Enumerating pools...") + pools_df = enumerate_balancer_pools(min_tvl=args.min_tvl) + os.makedirs(cache_dir, exist_ok=True) + pools_df.to_parquet(pools_cache, index=False) + print(f" Saved {len(pools_df)} pools -> {pools_cache}") + + print("\n2. Fetching daily snapshots...") + snapshots_df = fetch_all_snapshots(pools_df, cache_path=snaps_cache) + + print("\n3. 
Fetching token prices...") + token_addr_by_chain = {} + for _, pool in pools_df.iterrows(): + chain = pool["chain"] + tokens = pool["tokens"] + addresses = pool["token_addresses"] + if chain not in token_addr_by_chain: + token_addr_by_chain[chain] = {} + for sym, addr in zip(tokens, addresses): + if sym and addr: + token_addr_by_chain[chain][sym] = addr + + token_prices = fetch_token_prices( + token_addr_by_chain, cache_dir=prices_cache + ) + + print("\n4. Assembling panel (with lagged TVL)...") + panel = assemble_panel(pools_df, snapshots_df, token_prices) + panel.to_parquet(panel_cache, index=False) + print(f" Saved panel -> {panel_cache}") + + print(f"\nFetch complete. Panel: {len(panel)} obs, " + f"{panel['pool_id'].nunique()} pools") + + # --- Validate --- + if args.validate: + if not os.path.exists(panel_cache): + print(f"ERROR: Panel cache not found at {panel_cache}", + file=sys.stderr) + print("Run with --fetch first.", file=sys.stderr) + sys.exit(1) + panel = pd.read_parquet(panel_cache) + validate_panel(panel) + + # --- Fit --- + if args.fit: + print("\nUnified Noise Volume Model") + print("=" * 60) + + # Load panel + if not os.path.exists(panel_cache): + print(f"ERROR: Panel cache not found at {panel_cache}", + file=sys.stderr) + print("Run with --fetch first.", file=sys.stderr) + sys.exit(1) + + panel = pd.read_parquet(panel_cache) + print(f" Loaded panel: {len(panel)} obs, " + f"{panel['pool_id'].nunique()} pools, " + f"{panel['chain'].nunique()} chains") + + # Filter to recent window for training + if args.train_days > 0: + max_date = panel["date"].max() + if not isinstance(max_date, date): + max_date = pd.Timestamp(max_date).date() + cutoff = max_date - timedelta(days=args.train_days) + n_before = len(panel) + panel = panel[ + panel["date"].apply( + lambda d: d >= cutoff if isinstance(d, date) + else pd.Timestamp(d).date() >= cutoff + ) + ].copy() + print(f" Filtered to last {args.train_days} days " + f"(>= {cutoff}): {len(panel)} obs " + f"(dropped 
{n_before - len(panel)})") + + # Ensure lagged TVL exists (in case loaded from old cache) + if "log_tvl_lag1" not in panel.columns: + print(" Adding lagged TVL to cached panel...") + panel = panel.sort_values(["pool_id", "date"]).reset_index(drop=True) + panel["log_tvl_lag1"] = panel.groupby("pool_id")["log_tvl"].shift(1) + panel = panel.dropna(subset=["log_tvl_lag1"]).reset_index(drop=True) + + # Filter: need at least 10 days per pool + pool_counts = panel.groupby("pool_id").size() + valid_pools = pool_counts[pool_counts >= 10].index + panel = panel[panel["pool_id"].isin(valid_pools)].copy() + print(f" After filtering (>= 10 days): {len(panel)} obs, " + f"{panel['pool_id'].nunique()} pools") + + # Select model variant + if args.model == "structural": + from .model import structural_noise_model + from .covariate_encoding import encode_covariates_structural + model_fn = structural_noise_model + data = encode_covariates_structural(panel) + print(f" Model: structural mixture (arb + MoE noise)") + elif args.model == "ibp_dp": + from .model import noise_model_ibp_dp + model_fn = noise_model_ibp_dp + data = encode_covariates(panel, include_tiers=False) + data["K_features"] = args.k_features + data["K_clusters"] = args.k_clusters + print(f" Model: IBP+DP hybrid " + f"(K_features={args.k_features}, " + f"K_clusters={args.k_clusters})") + elif args.model == "ibp": + from .model import noise_model_ibp + model_fn = noise_model_ibp + data = encode_covariates(panel, include_tiers=False) + data["K_features"] = args.k_features + print(f" Model: IBP latent features " + f"(K_features={args.k_features})") + elif args.model == "dp_sigma": + from .model import noise_model_dp_sigma + model_fn = noise_model_dp_sigma + data = encode_covariates(panel, include_tiers=False) + data["K_clusters"] = args.k_clusters + print(f" Model: DP mixture on sigma_eps " + f"(K_clusters={args.k_clusters})") + else: + model_fn = None # default = noise_model + data = encode_covariates(panel) + + # Prior 
predictive + prior_samples = None + if args.prior_predictive: + print("\n Running prior predictive check...") + prior_samples = run_prior_predictive(data, model_fn=model_fn) + + # Inference + mcmc_obj = None + elbo_losses = None + inference_config = {"seed": args.seed} + + if args.svi_init_nuts: + inference_config["method"] = "svi_init_nuts" + inference_config["svi_steps"] = args.svi_steps + inference_config["svi_lr"] = args.svi_lr + inference_config["num_warmup"] = args.num_warmup + inference_config["num_samples"] = args.num_samples + inference_config["num_chains"] = args.num_chains + inference_config["target_accept"] = args.target_accept + inference_config["max_tree_depth"] = args.max_tree_depth + + mcmc_obj, elbo_losses = run_svi_then_nuts( + data, + svi_steps=args.svi_steps, + svi_lr=args.svi_lr, + num_warmup=args.num_warmup, + num_samples=args.num_samples, + num_chains=args.num_chains, + target_accept=args.target_accept, + max_tree_depth=args.max_tree_depth, + seed=args.seed, + model_fn=model_fn, + ) + samples = mcmc_obj + convergence = check_convergence(mcmc_obj, method="nuts") + + elif args.nuts: + inference_config["method"] = "nuts" + inference_config["num_warmup"] = args.num_warmup + inference_config["num_samples"] = args.num_samples + inference_config["num_chains"] = args.num_chains + inference_config["target_accept"] = args.target_accept + inference_config["max_tree_depth"] = args.max_tree_depth + + mcmc_obj = run_nuts( + data, + num_warmup=args.num_warmup, + num_samples=args.num_samples, + num_chains=args.num_chains, + target_accept=args.target_accept, + max_tree_depth=args.max_tree_depth, + seed=args.seed, + model_fn=model_fn, + ) + samples = mcmc_obj + convergence = check_convergence(mcmc_obj, method="nuts") + + else: + inference_config["method"] = "svi" + inference_config["svi_steps"] = args.svi_steps + inference_config["svi_lr"] = args.svi_lr + inference_config["num_samples"] = args.num_samples + + samples, elbo_losses = run_svi( + data, + 
num_steps=args.svi_steps, + lr=args.svi_lr, + seed=args.seed, + num_samples=args.num_samples, + model_fn=model_fn, + ) + convergence = check_convergence(elbo_losses, method="svi") + + if args.model == "structural": + from .postprocessing import extract_structural_params + pool_params = extract_structural_params(samples, data) + arb_freqs = [p["arb_frequency"] for p in pool_params] + print(f"\n Per-pool arb_frequency: " + f"mean={np.mean(arb_freqs):.1f}, " + f"range=[{np.min(arb_freqs)}, {np.max(arb_freqs)}]") + else: + pool_params = extract_noise_params(samples, data) + b_c_vals = [p["noise_params"]["b_c"] for p in pool_params] + b_0_vals = [p["noise_params"]["b_0"] for p in pool_params] + print(f"\n Per-pool b_c: mean={np.mean(b_c_vals):.3f}, " + f"std={np.std(b_c_vals):.3f}, " + f"range=[{np.min(b_c_vals):.3f}, {np.max(b_c_vals):.3f}]") + print(f" Per-pool b_0: mean={np.mean(b_0_vals):.3f}, " + f"std={np.std(b_0_vals):.3f}") + + if args.output: + generate_output_json( + pool_params, samples, data, convergence, + args.output, inference_config, + ) + + if args.plot: + print("\nGenerating diagnostic plots...") + plot_diagnostics( + samples, data, output_dir=args.output_dir, + elbo_losses=elbo_losses, mcmc=mcmc_obj, + prior_samples=prior_samples, + ) + + # Cache samples for --predict + _save_sample_cache(samples, data, cache_dir) + + # --- Predict --- + if args.predict: + if args.chain is None or args.tokens is None: + print("ERROR: --predict requires --chain and --tokens", + file=sys.stderr) + sys.exit(1) + + # Load cached samples + sample_cache = os.path.join(cache_dir, "unified_samples.npz") + data_cache = os.path.join(cache_dir, "unified_data.json") + + if not os.path.exists(sample_cache): + print(f"ERROR: Sample cache not found at {sample_cache}", + file=sys.stderr) + print("Run with --fit first.", file=sys.stderr) + sys.exit(1) + + cached = np.load(sample_cache) + sample_dict = {k: cached[k] for k in cached.files} + + with open(data_cache) as f: + data_meta = 
json.load(f) + + result = predict_new_pool( + sample_dict, data_meta, args.chain, args.tokens, args.fee + ) + print(json.dumps(result, indent=2)) diff --git a/quantammsim/noise_calibration/constants.py b/quantammsim/noise_calibration/constants.py new file mode 100644 index 0000000..05a217d --- /dev/null +++ b/quantammsim/noise_calibration/constants.py @@ -0,0 +1,60 @@ +"""Constants for noise calibration.""" + +import os + +K_COEFF = 4 +COEFF_NAMES = ["intercept", "b_tvl", "b_sigma", "b_weekend"] + +BALANCER_API_URL = "https://api-v3.balancer.fi/" + +BALANCER_API_CHAINS = [ + "MAINNET", "POLYGON", "ARBITRUM", "GNOSIS", "BASE", "SONIC", "OPTIMISM", + "AVALANCHE", +] + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + "local_data", "noise_calibration", +) + +# Tier 0: blue-chip — top by volume, wrapped native, major stables +_TIER_0 = { + "ETH", "WETH", "BTC", "WBTC", "cbBTC", "USDC", "USDT", "DAI", + "wstETH", "stETH", "rETH", "cbETH", "WMATIC", "MATIC", "POL", + "WAVAX", "AVAX", "GNO", "WXDAI", "xDAI", + "S", "wS", +} + +K_CLUSTERS_DEFAULT = 6 +K_FEATURES_DEFAULT = 6 + +# Structural model: observation-level covariates (expanded from K_COEFF=4) +K_OBS_COEFF = 8 +OBS_COEFF_NAMES = [ + "intercept", "b_tvl", "b_sigma", + "b_tvl_sigma", "b_tvl_fee", "b_sigma_fee", + "b_dow_sin", "b_dow_cos", +] + +# Gas costs per arb transaction (USD) by chain +GAS_COSTS = { + "MAINNET": None, # time-varying, loaded from CSV + "POLYGON": 0.005, + "ARBITRUM": 0.005, + "BASE": 0.005, + "GNOSIS": 0.01, + "OPTIMISM": 0.005, + "SONIC": 0.005, + "AVALANCHE": 0.005, + "MODE": 0.005, + "FRAXTAL": 0.005, +} + +# Tier 1: mid-cap DeFi blue-chips (approx CoinGecko rank < 200) +_TIER_1 = { + "AAVE", "LINK", "UNI", "BAL", "MKR", "CRV", "COMP", "SNX", + "LDO", "RPL", "SUSHI", "YFI", "1INCH", "ENS", "DYDX", + "FXS", "FRAX", "LUSD", "sDAI", "GHO", "crvUSD", + "ARB", "OP", "PENDLE", "ENA", "EIGEN", + "SAFE", "COW", +} diff --git 
a/quantammsim/noise_calibration/covariate_encoding.py b/quantammsim/noise_calibration/covariate_encoding.py new file mode 100644 index 0000000..a298a6b --- /dev/null +++ b/quantammsim/noise_calibration/covariate_encoding.py @@ -0,0 +1,228 @@ +"""Covariate encoding for the hierarchical noise model.""" + +import numpy as np +import pandas as pd + +from .constants import K_COEFF, K_OBS_COEFF, GAS_COSTS + + +def encode_covariates(panel: pd.DataFrame, include_tiers: bool = True) -> dict: + """Build NumPyro-ready arrays from the panel DataFrame. + + Returns dict with arrays for the model plus metadata for output/prediction. + Key difference from hierarchical script: x_obs uses log_tvl_lag1 not log_tvl. + """ + pool_meta = panel.drop_duplicates("pool_id").reset_index(drop=True) + pool_ids = pool_meta["pool_id"].values + pool_id_to_idx = {pid: i for i, pid in enumerate(pool_ids)} + N_pools = len(pool_ids) + + pool_idx = panel["pool_id"].map(pool_id_to_idx).values + + # --- Build X_pool (pool-level covariates, data-driven) --- + chains = sorted(panel["chain"].unique()) + ref_chain = chains[0] + chain_cols = [] + chain_names = [] + for c in chains[1:]: + chain_cols.append((pool_meta["chain"] == c).astype(float).values) + chain_names.append(f"chain_{c}") + + tier_a_vals = sorted(pool_meta["tier_A"].astype(str).unique()) + ref_tier_a = tier_a_vals[0] + tier_a_cols = [] + tier_a_names = [] + if include_tiers: + for t in tier_a_vals[1:]: + tier_a_cols.append( + (pool_meta["tier_A"].astype(str) == t).astype(float).values + ) + tier_a_names.append(f"tier_A_{t}") + + tier_b_vals = sorted(pool_meta["tier_B"].astype(str).unique()) + ref_tier_b = tier_b_vals[0] + tier_b_cols = [] + tier_b_names = [] + if include_tiers: + for t in tier_b_vals[1:]: + tier_b_cols.append( + (pool_meta["tier_B"].astype(str) == t).astype(float).values + ) + tier_b_names.append(f"tier_B_{t}") + + columns = [np.ones((N_pools, 1))] + col_names = ["intercept"] + + for arr, name in zip(chain_cols, chain_names): 
+ columns.append(arr.reshape(-1, 1)) + col_names.append(name) + for arr, name in zip(tier_a_cols, tier_a_names): + columns.append(arr.reshape(-1, 1)) + col_names.append(name) + for arr, name in zip(tier_b_cols, tier_b_names): + columns.append(arr.reshape(-1, 1)) + col_names.append(name) + columns.append(pool_meta["log_fee"].values.reshape(-1, 1)) + col_names.append("log_fee") + + X_pool = np.hstack(columns) + K_cov = X_pool.shape[1] + + # --- Observation-level arrays (uses LAGGED TVL) --- + x_obs = np.column_stack([ + np.ones(len(panel)), + panel["log_tvl_lag1"].values, + panel["volatility"].values, + panel["weekend"].values, + ]).astype(np.float64) + + y_obs = panel["log_volume"].values.astype(np.float64) + + # --- Per-pool tier_A index for per-tier sigma_eps --- + tier_A_per_pool = pool_meta["tier_A"].values.astype(np.int32) + + print(f" Encoded: N_obs={len(y_obs)}, N_pools={N_pools}, " + f"K_coeff={K_COEFF}, K_cov={K_cov}") + print(f" Covariates: {col_names}") + print(f" Tier distribution: " + f"T0={np.sum(tier_A_per_pool == 0)}, " + f"T1={np.sum(tier_A_per_pool == 1)}, " + f"T2={np.sum(tier_A_per_pool == 2)}") + + return { + "pool_idx": pool_idx.astype(np.int32), + "X_pool": X_pool.astype(np.float64), + "x_obs": x_obs, + "y_obs": y_obs, + "pool_ids": list(pool_ids), + "pool_meta": pool_meta, + "covariate_names": col_names, + "tier_A_per_pool": tier_A_per_pool, + "N_pools": N_pools, + "K_cov": K_cov, + "ref_chain": ref_chain, + "ref_tier_a": ref_tier_a, + "ref_tier_b": ref_tier_b, + "chains": chains, + } + + +def _tier_pair_idx(a: int, b: int) -> int: + """Encode (tier_A, tier_B) pair as a single index. + + Upper triangle of 3x3 grid: + (0,0)->0, (0,1)->1, (0,2)->2, (1,1)->3, (1,2)->4, (2,2)->5. + """ + return a * (5 - a) // 2 + b - a + + +def encode_covariates_structural( + panel: pd.DataFrame, + gas: np.ndarray = None, +) -> dict: + """Build NumPyro-ready arrays for the structural mixture model. 
+ + Extends encode_covariates with: + - x_obs: 8 columns (intercept, tvl, log_sigma, interactions, DOW harmonics) + - Additional arrays: sigma_daily, fee, gas, chain_idx, tier_idx, lag_log_tvl + - n_chains, n_tiers computed from panel + + Parameters + ---------- + panel : pd.DataFrame + Output of assemble_panel(), must have log_sigma, dow_sin, dow_cos, + tvl_x_sigma, tvl_x_fee, sigma_x_fee columns. + gas : np.ndarray, optional + Per-observation gas costs in USD. If None, uses default (0.01 for all). + """ + # Ensure structural columns exist (compute from base columns if missing) + if "log_sigma" not in panel.columns: + panel = panel.copy() + panel["log_sigma"] = np.log(np.maximum(panel["volatility"].values, 1e-6)) + dow = panel["date"].apply( + lambda d: d.weekday() if hasattr(d, "weekday") + else pd.Timestamp(d).weekday() + ) + panel["dow_sin"] = np.sin(2.0 * np.pi * dow / 7.0) + panel["dow_cos"] = np.cos(2.0 * np.pi * dow / 7.0) + panel["tvl_x_sigma"] = panel["log_tvl_lag1"] * panel["log_sigma"] + panel["tvl_x_fee"] = panel["log_tvl_lag1"] * panel["log_fee"] + panel["sigma_x_fee"] = panel["log_sigma"] * panel["log_fee"] + + # Reuse X_pool construction from encode_covariates (with tiers for gating) + base = encode_covariates(panel, include_tiers=True) + + # --- Observation-level x_obs: 8 columns --- + x_obs = np.column_stack([ + np.ones(len(panel)), # intercept + panel["log_tvl_lag1"].values, # lagged TVL + panel["log_sigma"].values, # log(volatility) + panel["tvl_x_sigma"].values, # tvl × sigma interaction + panel["tvl_x_fee"].values, # tvl × fee interaction + panel["sigma_x_fee"].values, # sigma × fee interaction + panel["dow_sin"].values, # DOW harmonic sin + panel["dow_cos"].values, # DOW harmonic cos + ]).astype(np.float64) + + # --- Additional arrays for the structural model --- + sigma_daily = (panel["volatility"] / np.sqrt(365.0)).values.astype(np.float64) + fee_per_obs = np.exp(panel["log_fee"].values).astype(np.float64) + lag_log_tvl = 
panel["log_tvl_lag1"].values.astype(np.float64) + + # Gas: per-observation + if gas is not None: + gas_arr = np.asarray(gas, dtype=np.float64) + else: + gas_arr = np.full(len(panel), 0.01, dtype=np.float64) + + # Chain index: integer per pool + pool_meta = base["pool_meta"] + chains = base["chains"] + chain_to_idx = {c: i for i, c in enumerate(chains)} + chain_idx_per_pool = np.array( + [chain_to_idx[c] for c in pool_meta["chain"]], dtype=np.int32, + ) + + # Tier pair index: per pool + tier_idx_per_pool = np.array( + [_tier_pair_idx(int(row["tier_A"]), int(row["tier_B"])) + for _, row in pool_meta.iterrows()], + dtype=np.int32, + ) + + # Count unique tier pairs and chains + n_chains = len(chains) + tier_pairs = set() + for _, row in pool_meta.iterrows(): + tier_pairs.add((int(row["tier_A"]), int(row["tier_B"]))) + n_tiers = 6 # fixed: upper triangle of 3x3 + + print(f" Structural encoding: N_obs={len(panel)}, " + f"N_pools={base['N_pools']}, n_chains={n_chains}, n_tiers={n_tiers}") + + return { + # Base arrays (same as encode_covariates) + "pool_idx": base["pool_idx"], + "X_pool": base["X_pool"], + "x_obs": x_obs, + "y_obs": base["y_obs"], + "pool_ids": base["pool_ids"], + "pool_meta": pool_meta, + "covariate_names": base["covariate_names"], + "tier_A_per_pool": base["tier_A_per_pool"], + "N_pools": base["N_pools"], + "K_cov": base["K_cov"], + "ref_chain": base["ref_chain"], + "ref_tier_a": base["ref_tier_a"], + "ref_tier_b": base["ref_tier_b"], + "chains": chains, + # Structural model extras + "sigma_daily": sigma_daily, + "fee": fee_per_obs, + "gas": gas_arr, + "chain_idx": chain_idx_per_pool, + "tier_idx": tier_idx_per_pool, + "lag_log_tvl": lag_log_tvl, + "n_chains": n_chains, + "n_tiers": n_tiers, + } diff --git a/quantammsim/noise_calibration/data_pipeline.py b/quantammsim/noise_calibration/data_pipeline.py new file mode 100644 index 0000000..b824cdc --- /dev/null +++ b/quantammsim/noise_calibration/data_pipeline.py @@ -0,0 +1,516 @@ +"""Data pipeline: fetch 
def _graphql_request(query: dict, base_url: str = BALANCER_API_URL,
                     timeout: int = 30) -> dict:
    """POST a GraphQL payload as JSON and return the decoded JSON reply.

    Parameters
    ----------
    query : dict
        Request body; serialised with ``json.dumps`` and sent UTF-8 encoded.
    base_url : str
        Endpoint the request is sent to.
    timeout : int
        Timeout in seconds handed to ``urllib.request.urlopen``.

    Returns
    -------
    dict
        The response body parsed as JSON. No inspection of the payload
        (e.g. for an ``errors`` entry) is done here.
    """
    request_headers = {
        "Content-Type": "application/json",
        "User-Agent": "quantammsim/1.0",
    }
    payload = json.dumps(query).encode("utf-8")
    request = urllib.request.Request(
        base_url, data=payload, headers=request_headers,
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        raw_body = response.read()
    return json.loads(raw_body.decode("utf-8"))
def fetch_pool_snapshots(pool_id: str, chain: str,
                         base_url: str = BALANCER_API_URL) -> pd.DataFrame:
    """Fetch ALL_TIME daily snapshots for a single pool.

    Parameters
    ----------
    pool_id : str
        Pool identifier sent as the ``$poolId`` GraphQL variable.
    chain : str
        Chain name sent as the ``$chain`` GraphQL variable.
    base_url : str
        GraphQL endpoint to query.

    Returns
    -------
    pd.DataFrame
        One row per calendar date with columns ``timestamp``,
        ``volume_usd``, ``total_liquidity_usd``, ``total_shares`` and
        ``date``. When the API returns no snapshots, an empty frame with
        the first four columns (no ``date``) is returned.
    """
    query = {
        "query": """
        query GetSnapshots($poolId: String!, $chain: GqlChain!,
                           $range: GqlPoolSnapshotDataRange!) {
            poolGetSnapshots(id: $poolId, chain: $chain, range: $range) {
                timestamp
                volume24h
                totalLiquidity
                totalShares
            }
        }
        """,
        "variables": {
            "poolId": pool_id,
            "chain": chain,
            "range": "ALL_TIME",
        },
    }

    # Forward base_url: it used to be accepted but dropped, so a custom
    # endpoint passed by the caller was silently ignored.
    body = _graphql_request(query, base_url=base_url)
    snapshots = body.get("data", {}).get("poolGetSnapshots", [])

    if not snapshots:
        return pd.DataFrame(columns=["timestamp", "volume_usd",
                                     "total_liquidity_usd", "total_shares"])

    records = [
        {
            "timestamp": int(snap["timestamp"]),
            "volume_usd": float(snap["volume24h"]),
            "total_liquidity_usd": float(snap["totalLiquidity"]),
            "total_shares": float(snap.get("totalShares", 0)),
        }
        for snap in snapshots
    ]

    df = pd.DataFrame(records)
    df["date"] = pd.to_datetime(df["timestamp"], unit="s").dt.date
    # Several snapshots can land on one date; keep the latest one.
    df = df.sort_values("timestamp").drop_duplicates("date", keep="last")
    return df
if cache_path and os.path.exists(cache_path): + cached = pd.read_parquet(cache_path) + cached_pool_ids = set(cached["pool_id"].unique()) + print(f" Cache has {len(cached_pool_ids)} pools, " + f"{len(cached)} pool-days") + + if len(pools_df) == 0: + print(" No pools to fetch.") + return cached if len(cached) > 0 else pd.DataFrame( + columns=["pool_id", "chain", "date", "volume_usd", + "total_liquidity_usd", "total_shares"] + ) + to_fetch = pools_df[~pools_df["pool_id"].isin(cached_pool_ids)] + print(f" Need to fetch {len(to_fetch)} new pools") + + new_records = [] + for i, (_, pool) in enumerate(to_fetch.iterrows()): + if (i + 1) % 10 == 0 or i == 0: + print(f" Fetching {i+1}/{len(to_fetch)}: {pool['pool_id'][:10]}... " + f"({pool['chain']})", flush=True) + try: + snap_df = fetch_pool_snapshots(pool["pool_id"], pool["chain"]) + if len(snap_df) > 0: + snap_df["pool_id"] = pool["pool_id"] + snap_df["chain"] = pool["chain"] + cols = ["pool_id", "chain", "date", "volume_usd", + "total_liquidity_usd"] + if "total_shares" in snap_df.columns: + cols.append("total_shares") + new_records.append(snap_df[cols]) + except Exception as e: + print(f" FAILED {pool['pool_id'][:10]}: {e}") + time.sleep(0.5) + + if new_records: + new_df = pd.concat(new_records, ignore_index=True) + combined = pd.concat([cached, new_df], ignore_index=True) + else: + combined = cached + + if cache_path and len(combined) > 0: + os.makedirs(os.path.dirname(cache_path), exist_ok=True) + combined.to_parquet(cache_path, index=False) + print(f" Saved cache: {len(combined)} pool-days -> {cache_path}") + + return combined + + +def fetch_token_prices(token_addresses_by_chain: dict, + cache_dir: str = None) -> dict: + """Fetch hourly token prices from Balancer API.""" + if cache_dir: + os.makedirs(cache_dir, exist_ok=True) + + prices = {} + + for chain, tokens in token_addresses_by_chain.items(): + uncached = {} + for symbol, address in tokens.items(): + cache_key = f"{chain}_{symbol}".replace("/", "_") + cp = 
def compute_pair_volatility(
    snapshots_df: pd.DataFrame,
    pool_row: pd.Series,
    token_prices: dict,
) -> pd.Series:
    """Compute daily annualised volatility for a pool's pair ratio.

    Only the first two entries of ``pool_row["tokens"]`` are used.

    Outcomes, in the order they are checked:
      * fewer than two tokens -> empty float Series;
      * both tokens in the stablecoin set -> constant 0.01 per snapshot date;
      * a non-stable token with no price frame, an empty price frame, or an
        empty time-aligned join -> constant 0.5 per snapshot date;
      * otherwise -> per-date standard deviation of log returns of the
        pair ratio, scaled by ``sqrt(24 * 365)``.

    Parameters
    ----------
    snapshots_df : pd.DataFrame
        Must have a ``date`` column; its unique values index the constant
        fallback Series.
    pool_row : pd.Series
        Provides ``tokens`` (sequence of symbols) and ``chain``.
    token_prices : dict
        Maps ``(chain, symbol)`` to a frame with ``timestamp`` (epoch
        seconds) and ``price`` columns.

    Returns
    -------
    pd.Series
        Volatility indexed by date.
    """
    symbols = pool_row["tokens"]
    pool_chain = pool_row["chain"]

    if len(symbols) < 2:
        return pd.Series(dtype=float)

    stable_symbols = {"USDC", "USDT", "DAI", "LUSD", "GHO", "crvUSD", "sDAI",
                      "WXDAI", "xDAI", "USDC.e", "USDbC"}

    def _lookup_prices(symbol):
        # Prefer the pool's own chain; otherwise take the first frame
        # registered for the same symbol on any chain.
        if (pool_chain, symbol) in token_prices:
            return token_prices[(pool_chain, symbol)]
        return next(
            (frame for key, frame in token_prices.items() if key[1] == symbol),
            None,
        )

    def _flat(level):
        # Constant volatility for every snapshot date of this pool.
        return pd.Series(level, index=snapshots_df["date"].unique())

    def _unusable(frame):
        return frame is None or len(frame) == 0

    base_sym, quote_sym = symbols[0], symbols[1]
    base_px = _lookup_prices(base_sym)
    quote_px = _lookup_prices(quote_sym)
    base_is_stable = base_sym in stable_symbols
    quote_is_stable = quote_sym in stable_symbols

    if base_is_stable and quote_is_stable:
        return _flat(0.01)

    if base_px is None and not base_is_stable:
        return _flat(0.5)
    if quote_px is None and not quote_is_stable:
        return _flat(0.5)

    if base_is_stable:
        # Stable numerator: ratio is the inverse of the other token's price.
        if _unusable(quote_px):
            return _flat(0.5)
        series_df = quote_px.assign(ratio=1.0 / quote_px["price"])
    elif quote_is_stable:
        if _unusable(base_px):
            return _flat(0.5)
        series_df = base_px.assign(ratio=base_px["price"])
    else:
        if _unusable(base_px) or _unusable(quote_px):
            return _flat(0.5)
        # Align the two price histories on timestamp (2 h tolerance).
        joined = pd.merge_asof(
            base_px.sort_values("timestamp"),
            quote_px.sort_values("timestamp"),
            on="timestamp",
            suffixes=("_0", "_1"),
            tolerance=7200,
        ).dropna()
        if len(joined) == 0:
            return _flat(0.5)
        series_df = joined.assign(ratio=joined["price_0"] / joined["price_1"])

    series_df = series_df.assign(
        date=pd.to_datetime(series_df["timestamp"], unit="s").dt.date
    ).sort_values("timestamp")

    series_df["log_return"] = np.log(
        series_df["ratio"] / series_df["ratio"].shift(1)
    )
    returns_df = series_df.dropna(subset=["log_return"])

    per_day_std = returns_df.groupby("date")["log_return"].std()
    return per_day_std * np.sqrt(24 * 365)
used_fallback: + n_obs_fallback += 1 + pool_fallback += 1 + pool_obs += 1 + + if isinstance(date, datetime): + is_weekend = date.weekday() >= 5 + else: + is_weekend = pd.Timestamp(date).weekday() >= 5 + + # DOW harmonics (deterministic from date) + if isinstance(date, datetime): + dow = date.weekday() + else: + dow = pd.Timestamp(date).weekday() + dow_sin = np.sin(2.0 * np.pi * dow / 7.0) + dow_cos = np.cos(2.0 * np.pi * dow / 7.0) + + record = { + "pool_id": pool_id, + "chain": chain, + "date": date, + "log_volume": np.log(volume), + "log_tvl": np.log(tvl), + "volatility": vol, + "log_sigma": np.log(max(vol, 1e-6)), + "weekend": 1.0 if is_weekend else 0.0, + "log_fee": np.log(max(swap_fee, 1e-6)), + "dow_sin": dow_sin, + "dow_cos": dow_cos, + "tier_A": tier_a, + "tier_B": tier_b, + "tokens": ",".join(tokens[:2]), + "swap_fee": swap_fee, + } + if shares > 0: + record["total_shares"] = shares + records.append(record) + + if pool_obs > 0 and pool_fallback == pool_obs: + pools_all_fallback.append( + (pool_id[:16], chain, ",".join(tokens[:2])) + ) + + panel = pd.DataFrame(records) + + # Add lagged TVL to break simultaneity bias + panel = panel.sort_values(["pool_id", "date"]).reset_index(drop=True) + panel["log_tvl_lag1"] = panel.groupby("pool_id")["log_tvl"].shift(1) + n_before = len(panel) + panel = panel.dropna(subset=["log_tvl_lag1"]).reset_index(drop=True) + n_dropped = n_before - len(panel) + + # Interaction terms (use lagged TVL to break simultaneity) + panel["tvl_x_sigma"] = panel["log_tvl_lag1"] * panel["log_sigma"] + panel["tvl_x_fee"] = panel["log_tvl_lag1"] * panel["log_fee"] + panel["sigma_x_fee"] = panel["log_sigma"] * panel["log_fee"] + + print(f"\n Panel: {len(panel)} observations, " + f"{panel['pool_id'].nunique()} pools, " + f"{panel['chain'].nunique()} chains") + print(f" Dropped {n_dropped} first-day obs for lagged TVL") + + # Volatility coverage report + if n_obs_total > 0: + pct = 100 * n_obs_fallback / n_obs_total + print(f"\n Volatility 
def validate_panel(panel: pd.DataFrame) -> pd.DataFrame:
    """Print data-quality warnings for a panel and return it unchanged.

    Three checks are reported on stdout; no rows are dropped or modified:
      1. pools whose ``log_volume`` standard deviation is below 0.01;
      2. consecutive-row ``log_tvl`` changes (per pool, ordered by date)
         larger than ``log(10)`` in absolute value;
      3. rows where ``log_volume`` exceeds ``log_tvl``.

    Parameters
    ----------
    panel : pd.DataFrame
        Needs ``pool_id``, ``date``, ``log_volume`` and ``log_tvl`` columns.

    Returns
    -------
    pd.DataFrame
        The same ``panel`` object that was passed in.
    """
    print("\n Data validation:")

    # Check 1: pools with (near-)constant volume.
    per_pool_std = panel.groupby("pool_id")["log_volume"].std()
    flat_pools = per_pool_std[per_pool_std < 0.01]
    n_flat = len(flat_pools)
    if n_flat > 0:
        print(f" WARNING: {n_flat} pools have near-constant "
              f"log(volume) (std < 0.01)")
        for pid, std_val in flat_pools.head(5).items():
            print(f" {pid[:16]}... std={std_val:.4f}")
        if n_flat > 5:
            print(f" ... and {n_flat - 5} more")

    # Check 2: TVL moving by more than 10x between consecutive rows.
    ordered = panel.sort_values(["pool_id", "date"])
    abs_log_step = ordered.groupby("pool_id")["log_tvl"].diff().abs()
    tvl_jumps = abs_log_step[abs_log_step > np.log(10)]
    n_jumps = len(tvl_jumps)
    if n_jumps > 0:
        n_jump_pools = ordered.loc[tvl_jumps.index, "pool_id"].nunique()
        print(f" WARNING: {n_jumps} TVL jumps > 10x across "
              f"{n_jump_pools} pools")

    # Check 3: daily volume larger than TVL.
    over_tvl = panel[panel["log_volume"] > panel["log_tvl"]]
    n_over = len(over_tvl)
    if n_over > 0:
        n_over_pools = over_tvl["pool_id"].nunique()
        print(f" WARNING: {n_over} days where volume > TVL across "
              f"{n_over_pools} pools (potential wash trading)")

    if n_flat == 0 and n_jumps == 0 and n_over == 0:
        print(" All checks passed.")

    return panel
+ """ + block_time_s = cadence_minutes * 60.0 + delta = 2.0 * jnp.sqrt(2.0 * jnp.maximum(gas_usd, 0.0) / jnp.maximum(tvl, 1e-6)) + bLVR = sigma_daily**2 * tvl / 8.0 + sqrt_term = sigma_daily * jnp.sqrt(block_time_s / (2.0 * 86400.0)) + correction = jnp.maximum( + 1.0 - delta / (2.0 * fee) - sqrt_term / (fee + delta / 2.0), + 0.0, + ) + return bLVR * correction / fee diff --git a/quantammsim/noise_calibration/inference.py b/quantammsim/noise_calibration/inference.py new file mode 100644 index 0000000..e315a5b --- /dev/null +++ b/quantammsim/noise_calibration/inference.py @@ -0,0 +1,270 @@ +"""Inference runners: SVI, NUTS, SVI-initialized NUTS.""" + +import numpy as np + +from .constants import K_COEFF +from .model import noise_model + + +def _get_theta_samples(sample_dict: dict, X_pool: np.ndarray, + data: dict = None) -> np.ndarray: + """Get theta samples, reconstructing from non-centered params if needed. + + MCMC.get_samples() includes the deterministic "theta" site. + SVI's Predictive(guide, ...) does NOT — the guide only samples latent + variables. In that case, reconstruct theta manually: + mu = X_pool @ B^T + L_Sigma = diag(sigma_theta) @ L_Omega + theta = mu + eta @ L_Sigma^T + + For marginalized IBP (W present, z_logit absent): compute MAP feature + assignments from data, then theta = X_pool @ B.T + Z_MAP @ W. + Requires data dict for pool_idx, x_obs, y_obs. + + For legacy STE IBP (z_logit present): theta = X_pool @ B.T + Z_hard @ W. 
+ """ + if "theta" in sample_dict: + return np.array(sample_dict["theta"]) + + # Marginalized IBP path: compute MAP assignments from data + if "W" in sample_dict and "z_logit" not in sample_dict: + B = np.array(sample_dict["B"]) # (S, K_coeff, K_cov) + W = np.array(sample_dict["W"]) # (S, K_features, K_coeff) + + # MAP assignments: (N_pools, K_features) binary + if "v" in sample_dict: + # Hybrid IBP+DP: joint MAP over (features, clusters) + from .postprocessing import assign_ibp_dp_joint + Z_map, _ = assign_ibp_dp_joint(sample_dict, data) + else: + from .postprocessing import assign_ibp_features + Z_map = assign_ibp_features(sample_dict, data) + + mu = np.einsum("pd,sjd->spj", X_pool, B) + # Z_map doesn't vary across samples — broadcast + feature_effect = np.einsum("pk,skj->spj", Z_map.astype(float), W) + return mu + feature_effect + + # Legacy STE IBP path: theta = X_pool @ B.T + Z_hard @ W + if "z_logit" in sample_dict: + B = np.array(sample_dict["B"]) # (S, K_coeff, K_cov) + W = np.array(sample_dict["W"]) # (S, K_features, K_coeff) + z_logit = np.array(sample_dict["z_logit"]) # (S, N_pools, K_features) + Z_hard = (z_logit > 0).astype(float) + + mu = np.einsum("pd,sjd->spj", X_pool, B) # (S, N_pools, K_coeff) + feature_effect = np.einsum("spk,skj->spj", Z_hard, W) + return mu + feature_effect + + B = np.array(sample_dict["B"]) # (S, K_coeff, K_cov) + sigma_theta = np.array(sample_dict["sigma_theta"]) # (S, K_coeff) + L_Omega = np.array(sample_dict["L_Omega"]) # (S, K_coeff, K_coeff) + eta = np.array(sample_dict["eta"]) # (S, N_pools, K_coeff) + + # mu[s, p, j] = sum_d X_pool[p, d] * B[s, j, d] -> (S, N_pools, K_coeff) + mu = np.einsum("pd,sjd->spj", X_pool, B) + + # L_Sigma = diag(sigma_theta) @ L_Omega -> (S, K_coeff, K_coeff) + L_Sigma = sigma_theta[:, :, None] * L_Omega + + # offset = eta @ L_Sigma^T -> (S, N_pools, K_coeff) + offset = np.einsum("spi,sji->spj", eta, L_Sigma) + + return mu + offset + + +def _build_model_kwargs(data: dict, model_fn=None) -> 
dict: + """Convert data dict to jnp arrays for the model. + + Uses inspect.signature on model_fn to decide which kwargs to include: + - tier_A_per_pool: only if model_fn accepts it + - K_clusters: only if model_fn accepts it and data has it + """ + import inspect + import jax.numpy as jnp + + if model_fn is None: + model_fn = noise_model + + params = set(inspect.signature(model_fn).parameters.keys()) + + kwargs = dict( + pool_idx=jnp.array(data["pool_idx"]), + X_pool=jnp.array(data["X_pool"]), + x_obs=jnp.array(data["x_obs"]), + y_obs=jnp.array(data["y_obs"]), + N_pools=data["N_pools"], + K_coeff=K_COEFF, + K_cov=data["K_cov"], + ) + + if "tier_A_per_pool" in params: + kwargs["tier_A_per_pool"] = jnp.array(data["tier_A_per_pool"]) + + if "K_clusters" in params and "K_clusters" in data: + kwargs["K_clusters"] = data["K_clusters"] + + if "K_features" in params and "K_features" in data: + kwargs["K_features"] = data["K_features"] + + # Structural model parameters + if "sigma_daily" in params and "sigma_daily" in data: + kwargs["sigma_daily"] = jnp.array(data["sigma_daily"]) + if "lag_log_tvl" in params and "lag_log_tvl" in data: + kwargs["lag_log_tvl"] = jnp.array(data["lag_log_tvl"]) + if "fee" in params and "fee" in data: + kwargs["fee"] = jnp.array(data["fee"]) + if "gas" in params and "gas" in data: + kwargs["gas"] = jnp.array(data["gas"]) + if "chain_idx" in params and "chain_idx" in data: + kwargs["chain_idx"] = jnp.array(data["chain_idx"]) + if "tier_idx" in params and "tier_idx" in data: + kwargs["tier_idx"] = jnp.array(data["tier_idx"]) + if "n_chains" in params and "n_chains" in data: + kwargs["n_chains"] = data["n_chains"] + if "n_tiers" in params and "n_tiers" in data: + kwargs["n_tiers"] = data["n_tiers"] + if "K_archetypes" in params and "K_archetypes" in data: + kwargs["K_archetypes"] = data["K_archetypes"] + + return kwargs + + +def run_svi(data, num_steps=20000, lr=1e-3, seed=0, + num_samples=1000, model_fn=None) -> tuple: + """Run SVI with AutoNormal 
guide. + + Returns (samples_dict, elbo_losses). + """ + import jax + import jax.numpy as jnp + import numpyro + from numpyro.infer import SVI, Trace_ELBO, Predictive + from numpyro.infer.autoguide import AutoNormal + + if model_fn is None: + model_fn = noise_model + + model_kwargs = _build_model_kwargs(data, model_fn=model_fn) + + print(f"\n Running SVI: {num_steps} steps, lr={lr}") + guide = AutoNormal(model_fn) + optimizer = numpyro.optim.Adam(lr) + svi = SVI(model_fn, guide, optimizer, loss=Trace_ELBO()) + + rng_key = jax.random.PRNGKey(seed) + svi_result = svi.run(rng_key, num_steps, **model_kwargs) + + elbo_losses = np.array(svi_result.losses) + print(f" SVI complete. Final ELBO: {elbo_losses[-1]:.2f}") + print(f" ELBO last 100 std: {np.std(elbo_losses[-100:]):.2f}") + + # Draw posterior samples + predictive = Predictive( + guide, params=svi_result.params, num_samples=num_samples, + ) + samples = predictive(jax.random.PRNGKey(seed + 1), **model_kwargs) + samples = {k: np.array(v) for k, v in samples.items()} + + print(f" Drew {num_samples} posterior samples.") + return samples, elbo_losses + + +def run_nuts(data, num_warmup=1000, num_samples=2000, num_chains=4, + target_accept=0.85, max_tree_depth=10, seed=42, + init_values=None, model_fn=None): + """Run NUTS MCMC. + + Uses init_to_value if init_values provided (for SVI-initialized NUTS). + Returns the MCMC object. + """ + import jax + import jax.numpy as jnp + import numpyro + from numpyro.infer import MCMC, NUTS, init_to_value + + if model_fn is None: + model_fn = noise_model + + # Note: set_host_device_count must be called before JAX init. + # We handle this in main(). Here we just verify device count. + n_devices = len(jax.devices("cpu")) + if n_devices < num_chains: + print(f" WARNING: Only {n_devices} CPU devices available for " + f"{num_chains} chains. 
Chains will run sequentially.") + + init_strategy = None + if init_values is not None: + init_strategy = init_to_value( + values={k: jnp.array(v) for k, v in init_values.items()} + ) + + kernel = NUTS( + model_fn, + target_accept_prob=target_accept, + max_tree_depth=max_tree_depth, + init_strategy=init_strategy, + ) + mcmc = MCMC( + kernel, + num_warmup=num_warmup, + num_samples=num_samples, + num_chains=num_chains, + progress_bar=True, + ) + + model_kwargs = _build_model_kwargs(data, model_fn=model_fn) + rng_key = jax.random.PRNGKey(seed) + + print(f"\n Running NUTS: {num_chains} chains x " + f"({num_warmup} warmup + {num_samples} samples)") + print(f" target_accept={target_accept}, max_tree_depth={max_tree_depth}") + if init_values is not None: + print(" Using SVI-initialized starting values.") + + mcmc.run(rng_key, **model_kwargs) + mcmc.print_summary(exclude_deterministic=True) + return mcmc + + +def run_svi_then_nuts(data, svi_steps=5000, svi_lr=1e-3, + num_warmup=500, num_samples=2000, num_chains=4, + target_accept=0.85, max_tree_depth=10, seed=42, + model_fn=None): + """Run SVI first, then use posterior means as NUTS init. + + Returns (MCMC, elbo_losses). 
+ """ + if model_fn is None: + model_fn = noise_model + + # Phase 1: SVI + print(" Phase 1: SVI warm-start") + samples, elbo_losses = run_svi( + data, num_steps=svi_steps, lr=svi_lr, seed=seed, num_samples=100, + model_fn=model_fn, + ) + + # Extract posterior means for init + init_values = {} + skip_keys = {"y", "theta", "w"} + for k, v in samples.items(): + if k in skip_keys: + continue + init_values[k] = np.mean(v, axis=0) + + # Phase 2: NUTS from SVI init + print("\n Phase 2: NUTS from SVI-initialized values") + mcmc = run_nuts( + data, + num_warmup=num_warmup, + num_samples=num_samples, + num_chains=num_chains, + target_accept=target_accept, + max_tree_depth=max_tree_depth, + seed=seed, + init_values=init_values, + model_fn=model_fn, + ) + + return mcmc, elbo_losses diff --git a/quantammsim/noise_calibration/model.py b/quantammsim/noise_calibration/model.py new file mode 100644 index 0000000..d0d5ebf --- /dev/null +++ b/quantammsim/noise_calibration/model.py @@ -0,0 +1,444 @@ +"""NumPyro noise volume models.""" + +import jax +import jax.numpy as jnp + +from .formula_arb import formula_arb_volume_daily_jax + + +def _pad_with_ref(alpha): + """Prepend a zero for the reference category.""" + return jnp.concatenate([jnp.zeros(1), alpha]) + + +def stick_breaking_weights(v): + """Convert Beta stick-breaking fractions to K-simplex weights. + + v: array of shape (K-1,) with values in (0, 1). + Returns weights of shape (K,) summing to 1. 
+ + w_1 = v_1 + w_k = v_k * prod_{j> jnp.arange(K_features)[None, :]) & 1 + ).astype(jnp.float32) # (n_configs, K_features) + + # Log-prior for each config from IBP stick-breaking + log_pi = jnp.log(pi + 1e-30) + log_1mpi = jnp.log(1.0 - pi + 1e-30) + log_prior = configs @ log_pi + (1.0 - configs) @ log_1mpi # (n_configs,) + + # Per-config feature effect + feature_effects = configs @ W # (n_configs, K_coeff) + + # Per-observation means + mu_pop_obs = jnp.sum(mu_pop[pool_idx] * x_obs, axis=1) # (N_obs,) + feature_mu = x_obs @ feature_effects.T # (N_obs, n_configs) + mu_obs = mu_pop_obs[:, None] + feature_mu # (N_obs, n_configs) + + # Log-likelihood per obs per config + log_lik = dist.StudentT(df, mu_obs, sigma_eps).log_prob( + y_obs[:, None] + ) # (N_obs, n_configs) + + # Sum log-likelihoods within each pool + pool_log_liks = jnp.zeros((N_pools, n_configs)) + pool_log_liks = pool_log_liks.at[pool_idx].add(log_lik) + + # Marginal: logsumexp over configs per pool + log_marginal = logsumexp( + log_prior[None, :] + pool_log_liks, axis=1 + ) # (N_pools,) + numpyro.factor("log_lik", log_marginal.sum()) + else: + # === Prior predictive: sample explicit assignments === + with numpyro.plate("pools", N_pools): + z_features = numpyro.sample( + "z_features", + dist.Bernoulli(probs=pi).expand([K_features]).to_event(1), + ) + + theta = mu_pop + z_features @ W # (N_pools, K_coeff) + numpyro.deterministic("theta", theta) + + theta_obs = theta[pool_idx] + mu_obs = jnp.sum(theta_obs * x_obs, axis=1) + + with numpyro.plate("obs", pool_idx.shape[0]): + numpyro.sample( + "y", dist.StudentT(df, mu_obs, sigma_eps), + ) + + +def noise_model_ibp_dp(pool_idx, X_pool, x_obs, y_obs=None, + N_pools=None, K_coeff=4, K_cov=None, + K_features=6, K_clusters=6): + """Hybrid IBP+DP noise model. + + IBP latent features for mean heterogeneity (theta = X_pool @ B.T + z @ W), + DP mixture for noise heterogeneity (per-cluster sigma_eps). 
Joint + marginalization over (2^K_features × K_clusters) configurations when + y_obs is provided. + """ + import numpyro + import numpyro.distributions as dist + from jax.scipy.special import logsumexp + + # --- Population effects --- + B = numpyro.sample( + "B", dist.Normal(0.0, 5.0).expand([K_coeff, K_cov]).to_event(2) + ) + + # Student-t degrees of freedom + df = numpyro.sample("df", dist.Gamma(2.0, 0.1)) + + # --- IBP prior on feature prevalences --- + alpha_ibp = numpyro.sample("alpha_ibp", dist.Gamma(2.0, 1.0)) + with numpyro.plate("features", K_features): + v_ibp = numpyro.sample("v_ibp", dist.Beta(alpha_ibp, 1.0)) + pi = jnp.cumprod(v_ibp) # decreasing prevalences + + # --- Feature effect matrix --- + sigma_w = numpyro.sample("sigma_w", dist.HalfNormal(2.0)) + W = numpyro.sample( + "W", dist.Normal(0.0, sigma_w).expand([K_features, K_coeff]).to_event(2) + ) + + # --- DP mixture on sigma_eps --- + alpha_dp = numpyro.sample("alpha_dp", dist.Gamma(1.0, 1.0)) + with numpyro.plate("sticks", K_clusters - 1): + v = numpyro.sample("v", dist.Beta(1.0, alpha_dp)) + w = numpyro.deterministic("w", stick_breaking_weights(v)) + + sigma_eps = numpyro.sample( + "sigma_eps", + dist.HalfNormal(2.0).expand([K_clusters]).to_event(1), + ) + + # --- Population mean per pool --- + mu_pop = X_pool @ B.T # (N_pools, K_coeff) + + if y_obs is not None: + # === Joint marginalization over IBP configs × DP clusters === + + # Enumerate all 2^K binary feature configurations + n_configs = 2 ** K_features + configs = ( + (jnp.arange(n_configs)[:, None] >> jnp.arange(K_features)[None, :]) & 1 + ).astype(jnp.float32) # (n_configs, K_features) + + # Log-prior for each IBP config + log_pi = jnp.log(pi + 1e-30) + log_1mpi = jnp.log(1.0 - pi + 1e-30) + log_ibp_prior = configs @ log_pi + (1.0 - configs) @ log_1mpi # (n_configs,) + + # Per-config feature effect + feature_effects = configs @ W # (n_configs, K_coeff) + + # Per-observation means for each IBP config + mu_pop_obs = 
jnp.sum(mu_pop[pool_idx] * x_obs, axis=1) # (N_obs,) + feature_mu = x_obs @ feature_effects.T # (N_obs, n_configs) + mu_obs = mu_pop_obs[:, None] + feature_mu # (N_obs, n_configs) + + # Log-likelihood per obs per IBP config per DP cluster + log_lik = dist.StudentT( + df, mu_obs[:, :, None], sigma_eps[None, None, :] + ).log_prob(y_obs[:, None, None]) # (N_obs, n_configs, K_clusters) + + # Sum log-likelihoods within each pool + pool_log_liks = jnp.zeros((N_pools, n_configs, K_clusters)) + pool_log_liks = pool_log_liks.at[pool_idx].add(log_lik) + + # Joint prior: IBP config prior × DP cluster weight + log_joint_prior = log_ibp_prior[:, None] + jnp.log(w + 1e-30)[None, :] # (n_configs, K_clusters) + + # Marginal log-likelihood per pool: logsumexp over (configs, clusters) + log_marginal = logsumexp( + log_joint_prior[None, :, :] + pool_log_liks, axis=(1, 2) + ) # (N_pools,) + numpyro.factor("log_lik", log_marginal.sum()) + else: + # === Prior predictive: sample explicit assignments === + with numpyro.plate("pools", N_pools): + z_features = numpyro.sample( + "z_features", + dist.Bernoulli(probs=pi).expand([K_features]).to_event(1), + ) + z_cluster = numpyro.sample("z_cluster", dist.Categorical(probs=w)) + + theta = mu_pop + z_features @ W # (N_pools, K_coeff) + numpyro.deterministic("theta", theta) + + theta_obs = theta[pool_idx] + mu_obs = jnp.sum(theta_obs * x_obs, axis=1) + sigma_obs = sigma_eps[z_cluster[pool_idx]] + + with numpyro.plate("obs", pool_idx.shape[0]): + numpyro.sample("y", dist.StudentT(df, mu_obs, sigma_obs)) + + +def structural_noise_model(pool_idx, X_pool, x_obs, y_obs=None, + sigma_daily=None, lag_log_tvl=None, + fee=None, gas=None, + chain_idx=None, tier_idx=None, + N_pools=None, K_archetypes=3, + n_chains=8, n_tiers=6, + **kwargs): + """Structural mixture model: LVR arb + mixture-of-experts noise. + + Decomposes observed total volume into arb (structurally restricted to + LVR formula) and noise (flexible MoE). 
All continuous — no discrete + latent variables, so AutoNormal guide works directly. + """ + import numpyro + import numpyro.distributions as dist + + # --- Arb cadence parameters --- + alpha_0 = numpyro.sample("alpha_0", dist.Normal(1.0, 2.0)) + alpha_chain = numpyro.sample( + "alpha_chain", + dist.Normal(0, 1).expand([n_chains - 1]).to_event(1), + ) + alpha_tier = numpyro.sample( + "alpha_tier", + dist.Normal(0, 1).expand([n_tiers - 1]).to_event(1), + ) + alpha_tvl = numpyro.sample("alpha_tvl", dist.Normal(0, 0.5)) + + # Per-observation cadence (broadcast pool-level indices to obs) + log_cadence = ( + alpha_0 + + _pad_with_ref(alpha_chain)[chain_idx[pool_idx]] + + _pad_with_ref(alpha_tier)[tier_idx[pool_idx]] + + alpha_tvl * lag_log_tvl + ) + cadence = jnp.exp(jnp.clip(log_cadence, -2.0, 6.0)) # 0.1 to 400 min + + # V_arb per obs (deterministic given cadence + observables) + V_arb = formula_arb_volume_daily_jax( + sigma_daily, jnp.exp(lag_log_tvl), fee, gas, cadence, + ) + + # --- Noise MoE parameters --- + n_obs_coeff = x_obs.shape[1] + K_pool_cov = X_pool.shape[1] + + W_gate = numpyro.sample( + "W_gate", + dist.Normal(0, 1).expand([K_pool_cov, K_archetypes]).to_event(2), + ) + beta = numpyro.sample( + "beta", + dist.Normal(0, 2).expand([K_archetypes, n_obs_coeff]).to_event(2), + ) + + # Per-pool soft assignment and coefficient blend + logits = X_pool @ W_gate # (N_pools, K) + w = jax.nn.softmax(logits, axis=-1) # (N_pools, K) + beta_pool = jnp.einsum("pk,kc->pc", w, beta) # (N_pools, n_obs_coeff) + log_V_noise = jnp.sum(beta_pool[pool_idx] * x_obs, axis=1) + V_noise = jnp.exp(log_V_noise) + + # --- Observation model --- + df = numpyro.sample("df", dist.Gamma(2.0, 0.1)) + sigma_eps = numpyro.sample("sigma_eps", dist.HalfNormal(1.0)) + + mu = jnp.log(jnp.maximum(V_arb + V_noise, 1e-6)) + + if y_obs is not None: + numpyro.sample("y", dist.StudentT(df, mu, sigma_eps), obs=y_obs) + else: + numpyro.sample("y", dist.StudentT(df, mu, sigma_eps)) diff --git 
a/quantammsim/noise_calibration/output.py b/quantammsim/noise_calibration/output.py new file mode 100644 index 0000000..931f059 --- /dev/null +++ b/quantammsim/noise_calibration/output.py @@ -0,0 +1,306 @@ +"""JSON output and sample caching.""" + +import json +import os + +import numpy as np + +from .constants import K_COEFF, COEFF_NAMES, K_OBS_COEFF, OBS_COEFF_NAMES + + +def generate_output_json(pool_params, samples, data, convergence, + output_path, inference_config): + """Write structured JSON output. + + Dispatches format based on whether samples contain DP mixture parameters + (detected via "v" in sample_dict), or structural model parameters + (detected via "W_gate" in sample_dict). + """ + if hasattr(samples, "get_samples"): + sample_dict = samples.get_samples() + else: + sample_dict = samples + + # Structural mixture model path + is_structural = "W_gate" in sample_dict + if is_structural: + _generate_structural_output( + pool_params, sample_dict, data, convergence, + output_path, inference_config, + ) + return + + # Detection priority: check hybrid first, then pure IBP, then DP + is_ibp_dp = ("W" in sample_dict and "v" in sample_dict + and "z_logit" not in sample_dict) + is_ibp = ("W" in sample_dict and "v" not in sample_dict + and "z_logit" not in sample_dict) + is_ibp_ste = "z_logit" in sample_dict # legacy STE artifacts + is_dp = "v" in sample_dict and "W" not in sample_dict + + B_median = np.median(np.array(sample_dict["B"]), axis=0).tolist() + sigma_eps_median = np.median( + np.array(sample_dict["sigma_eps"]), axis=0 + ) + df_median = float(np.median(np.array(sample_dict["df"]))) + + if is_ibp_dp: + model_name = "hierarchical_student_t_ibp_dp" + sigma_eps_structure = "dp_mixture" + + W_median = np.median(np.array(sample_dict["W"]), axis=0).tolist() + v_ibp_median = np.median(np.array(sample_dict["v_ibp"]), axis=0) + pi = np.cumprod(v_ibp_median).tolist() + + from .model import stick_breaking_weights + import jax.numpy as jnp + v_median = 
np.median(np.array(sample_dict["v"]), axis=0) + cluster_weights = np.array( + stick_breaking_weights(jnp.array(v_median)) + ).tolist() + + population_effects = { + "B": B_median, + "sigma_eps": sigma_eps_median.tolist() if hasattr(sigma_eps_median, 'tolist') else sigma_eps_median, + "df": df_median, + "W": W_median, + "feature_prevalences": pi, + "alpha_ibp": float( + np.median(np.array(sample_dict["alpha_ibp"])) + ), + "cluster_weights": cluster_weights, + "alpha_dp": float( + np.median(np.array(sample_dict["alpha_dp"])) + ), + } + + # Joint MAP assignments + from .postprocessing import assign_ibp_dp_joint + feat_assignments, cluster_assignments = assign_ibp_dp_joint( + sample_dict, data + ) + + pool_entries = {} + for i, p in enumerate(pool_params): + entry = { + "chain": p["chain"], + "tokens": p["tokens"], + "theta_median": p["theta_median"], + "theta_std": p["theta_std"], + "noise_params": p["noise_params"], + "feature_assignments": feat_assignments[i].tolist(), + "cluster_assignment": int(cluster_assignments[i]), + } + pool_entries[p["pool_id"]] = entry + + elif is_ibp or is_ibp_ste: + model_name = "hierarchical_student_t_ibp" + sigma_eps_structure = "scalar" + + W_median = np.median(np.array(sample_dict["W"]), axis=0).tolist() + v_ibp_median = np.median(np.array(sample_dict["v_ibp"]), axis=0) + pi = np.cumprod(v_ibp_median).tolist() + + population_effects = { + "B": B_median, + "sigma_eps": float(sigma_eps_median), + "df": df_median, + "W": W_median, + "feature_prevalences": pi, + "alpha_ibp": float( + np.median(np.array(sample_dict["alpha_ibp"])) + ), + } + + # Per-pool feature assignments + if is_ibp_ste: + # Legacy STE path: threshold z_logit + z_logit = np.array(sample_dict["z_logit"]) + z_logit_median = np.median(z_logit, axis=0) + feature_assignments = (z_logit_median > 0).astype(int).tolist() + else: + # Marginalized path: MAP assignments from data + from .postprocessing import assign_ibp_features + feature_assignments = assign_ibp_features( + 
sample_dict, data + ).tolist() + + pool_entries = {} + for i, p in enumerate(pool_params): + entry = { + "chain": p["chain"], + "tokens": p["tokens"], + "theta_median": p["theta_median"], + "theta_std": p["theta_std"], + "noise_params": p["noise_params"], + "feature_assignments": feature_assignments[i], + } + pool_entries[p["pool_id"]] = entry + + elif is_dp: + model_name = "hierarchical_student_t_dp_sigma" + sigma_eps_structure = "dp_mixture" + else: + model_name = "unified_hierarchical_student_t" + sigma_eps_structure = "per_tier" + + if not is_ibp and not is_ibp_dp: + sigma_theta_median = np.median( + np.array(sample_dict["sigma_theta"]), axis=0 + ).tolist() + + # Correlation matrix + L_Omega = np.array(sample_dict["L_Omega"]) + Omega = np.einsum("sij,skj->sik", L_Omega, L_Omega) + Omega_median = np.median(Omega, axis=0).tolist() + + population_effects = { + "B": B_median, + "sigma_theta": sigma_theta_median, + "sigma_eps": sigma_eps_median.tolist() if hasattr(sigma_eps_median, 'tolist') else sigma_eps_median, + "df": df_median, + "correlation_matrix": Omega_median, + } + + if is_dp: + from .model import stick_breaking_weights + import jax.numpy as jnp + v_median = np.median(np.array(sample_dict["v"]), axis=0) + w = stick_breaking_weights(jnp.array(v_median)) + population_effects["cluster_weights"] = np.array(w).tolist() + population_effects["alpha_dp"] = float( + np.median(np.array(sample_dict["alpha_dp"])) + ) + + pool_entries = { + p["pool_id"]: { + "chain": p["chain"], + "tokens": p["tokens"], + "theta_median": p["theta_median"], + "theta_std": p["theta_std"], + "noise_params": p["noise_params"], + } + for p in pool_params + } + + output = { + "model": model_name, + "model_spec": { + "K_coeff": K_COEFF, + "K_cov": data["K_cov"], + "coeff_names": COEFF_NAMES, + "covariate_names": data["covariate_names"], + "likelihood": "StudentT", + "tvl_lag": "log_tvl_lag1", + "sigma_eps_structure": sigma_eps_structure, + }, + "inference": inference_config, + 
"population_effects": population_effects, + "convergence": convergence, + "n_pools": len(pool_params), + "n_obs": len(data["y_obs"]), + "pools": pool_entries, + } + + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + with open(output_path, "w") as f: + json.dump(output, f, indent=2, default=str) + print(f" Wrote {len(pool_params)} pool params -> {output_path}") + + +def _generate_structural_output(pool_params, sample_dict, data, convergence, + output_path, inference_config): + """Write structural mixture model JSON output.""" + alpha_0 = float(np.median(np.array(sample_dict["alpha_0"]))) + alpha_chain = np.median(np.array(sample_dict["alpha_chain"]), axis=0).tolist() + alpha_tier = np.median(np.array(sample_dict["alpha_tier"]), axis=0).tolist() + alpha_tvl = float(np.median(np.array(sample_dict["alpha_tvl"]))) + + W_gate = np.median(np.array(sample_dict["W_gate"]), axis=0).tolist() + beta = np.median(np.array(sample_dict["beta"]), axis=0).tolist() + K_archetypes = np.array(sample_dict["beta"]).shape[1] + + df_median = float(np.median(np.array(sample_dict["df"]))) + sigma_eps_median = float(np.median(np.array(sample_dict["sigma_eps"]))) + + population_effects = { + "alpha_0": alpha_0, + "alpha_chain": alpha_chain, + "alpha_tier": alpha_tier, + "alpha_tvl": alpha_tvl, + "W_gate": W_gate, + "beta": beta, + "K_archetypes": K_archetypes, + "df": df_median, + "sigma_eps": sigma_eps_median, + } + + pool_entries = {} + for p in pool_params: + pool_entries[p["pool_id"]] = { + "chain": p["chain"], + "tokens": p["tokens"], + "arb_frequency": p["arb_frequency"], + "noise_params": p["noise_params"], + } + + output = { + "model": "structural_mixture", + "model_spec": { + "K_obs_coeff": K_OBS_COEFF, + "obs_coeff_names": OBS_COEFF_NAMES, + "K_cov": data["K_cov"], + "covariate_names": data["covariate_names"], + "likelihood": "StudentT", + }, + "inference": inference_config, + "population_effects": population_effects, + "convergence": convergence, + "n_pools": 
len(pool_params), + "n_obs": len(data["y_obs"]), + "pools": pool_entries, + } + + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + with open(output_path, "w") as f: + json.dump(output, f, indent=2, default=str) + print(f" Wrote {len(pool_params)} pool params -> {output_path}") + + +def _save_sample_cache(samples, data, cache_dir): + """Cache posterior samples and data arrays for --predict reuse.""" + if hasattr(samples, "get_samples"): + sample_dict = samples.get_samples() + else: + sample_dict = samples + + os.makedirs(cache_dir, exist_ok=True) + + # Save only the samples needed for prediction and diagnostics. + # Skip "y" (S x N_obs, can be >1GB) and "theta" (S x N_pools x K, + # reconstructible from B, eta, sigma_theta, L_Omega). + skip_keys = {"y", "theta"} + sample_cache = os.path.join(cache_dir, "unified_samples.npz") + np.savez_compressed( + sample_cache, + **{k: np.array(v) for k, v in sample_dict.items() + if k not in skip_keys}, + ) + + # Data arrays for predict + data_cache = os.path.join(cache_dir, "unified_data.json") + cache_data = { + "pool_ids": data["pool_ids"], + "covariate_names": data["covariate_names"], + "K_cov": data["K_cov"], + "N_pools": data["N_pools"], + "ref_chain": data["ref_chain"], + "ref_tier_a": data["ref_tier_a"], + "ref_tier_b": data["ref_tier_b"], + "chains": data["chains"], + } + with open(data_cache, "w") as f: + json.dump(cache_data, f, indent=2) + + print(f" Cached samples -> {sample_cache}") + print(f" Cached data metadata -> {data_cache}") diff --git a/quantammsim/noise_calibration/plotting.py b/quantammsim/noise_calibration/plotting.py new file mode 100644 index 0000000..727443b --- /dev/null +++ b/quantammsim/noise_calibration/plotting.py @@ -0,0 +1,335 @@ +"""Diagnostic plots for noise calibration.""" + +import os + +import numpy as np +import pandas as pd + +from .constants import K_COEFF, COEFF_NAMES +from .inference import _get_theta_samples + + +def plot_diagnostics(samples, data, output_dir, 
elbo_losses=None, + mcmc=None, prior_samples=None): + """Generate up to 9 diagnostic plots.""" + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + os.makedirs(output_dir, exist_ok=True) + + if hasattr(samples, "get_samples"): + sample_dict = samples.get_samples() + else: + sample_dict = samples + + theta_samples = _get_theta_samples( + sample_dict, np.array(data["X_pool"]), data=data + ) # (S, N_pools, K_coeff) + theta_median = np.median(theta_samples, axis=0) + + pool_idx = data["pool_idx"] + x_obs = data["x_obs"] + y_obs = data["y_obs"] + pool_meta = data["pool_meta"] + pool_ids = data["pool_ids"] + + # --- 1. Prior predictive check --- + if prior_samples is not None: + y_prior = prior_samples.get("y", None) + if y_prior is not None: + fig, ax = plt.subplots(figsize=(10, 5)) + # Flatten a subsample of prior draws + y_prior_flat = y_prior.flatten() + # Clip for display + clip_lo, clip_hi = np.percentile(y_prior_flat, [0.5, 99.5]) + y_prior_clipped = y_prior_flat[ + (y_prior_flat >= clip_lo) & (y_prior_flat <= clip_hi) + ] + ax.hist(y_prior_clipped, bins=100, alpha=0.5, density=True, + color="steelblue", label="Prior predictive") + ax.hist(y_obs, bins=100, alpha=0.5, density=True, + color="coral", label="Observed") + ax.set_xlabel("log(volume)") + ax.set_ylabel("Density") + ax.set_title("Prior predictive check: log-volume") + ax.legend() + plt.tight_layout() + path = os.path.join(output_dir, "prior_predictive.png") + plt.savefig(path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path}") + + # --- 2. 
ELBO loss curve (SVI only) --- + if elbo_losses is not None: + fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + + ax = axes[0] + ax.plot(elbo_losses, alpha=0.3, color="steelblue", linewidth=0.5) + # Smoothed + window = min(100, len(elbo_losses) // 10) + if window > 1: + smoothed = pd.Series(elbo_losses).rolling(window).mean().values + ax.plot(smoothed, color="red", linewidth=1.5, label=f"Rolling {window}") + ax.legend() + ax.set_xlabel("Step") + ax.set_ylabel("ELBO loss") + ax.set_title("ELBO convergence") + + ax = axes[1] + # Last 20% of training + start = len(elbo_losses) * 4 // 5 + ax.plot(range(start, len(elbo_losses)), elbo_losses[start:], + color="steelblue", linewidth=0.8) + ax.set_xlabel("Step") + ax.set_ylabel("ELBO loss") + ax.set_title("ELBO convergence (last 20%)") + + plt.tight_layout() + path = os.path.join(output_dir, "elbo_convergence.png") + plt.savefig(path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path}") + + # --- 3. Trace plots (NUTS only) --- + if mcmc is not None: + try: + import arviz as az + idata = az.from_numpyro(mcmc) + var_names = ["sigma_theta", "sigma_eps", "df"] + available = [v for v in var_names if v in idata.posterior] + if available: + axes = az.plot_trace(idata, var_names=available, compact=True) + fig = axes.ravel()[0].figure + fig.set_size_inches(14, 3 * len(available)) + path = os.path.join(output_dir, "trace_plots.png") + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" Saved: {path}") + except Exception as e: + print(f" WARNING: Trace plots failed: {e}") + + # --- 4. Posterior predictive: predicted vs observed --- + # Compute y_pred and r2 here; r2 is reused in plot 9 (model summary). 
+ theta_obs = theta_median[pool_idx] + y_pred = np.sum(theta_obs * x_obs, axis=1) + r2 = 1 - np.var(y_obs - y_pred) / np.var(y_obs) + + fig, axes = plt.subplots(1, 2, figsize=(14, 6)) + + ax = axes[0] + ax.scatter(y_obs, y_pred, alpha=0.1, s=4, color="steelblue") + lims = [min(y_obs.min(), y_pred.min()), max(y_obs.max(), y_pred.max())] + ax.plot(lims, lims, "r--", linewidth=1) + ax.set_xlabel("Observed log(volume)") + ax.set_ylabel("Predicted log(volume)") + ax.set_title("Posterior predictive check") + ax.text(0.05, 0.95, f"R² = {r2:.3f}", transform=ax.transAxes, + fontsize=11, verticalalignment="top") + + ax = axes[1] + residuals = y_obs - y_pred + ax.hist(residuals, bins=60, color="steelblue", edgecolor="white", alpha=0.8) + ax.axvline(0, color="red", linestyle="--") + ax.set_xlabel("Residual") + + sigma_eps_samples = np.array(sample_dict.get("sigma_eps", [0])) + if sigma_eps_samples.ndim > 1: + sigma_str = ", ".join(f"{np.median(sigma_eps_samples[:, i]):.2f}" + for i in range(sigma_eps_samples.shape[1])) + else: + sigma_str = f"{np.median(sigma_eps_samples):.2f}" + ax.set_title(f"Residuals (sigma_eps ~ [{sigma_str}])") + + plt.tight_layout() + path = os.path.join(output_dir, "posterior_predictive.png") + plt.savefig(path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path}") + + # --- 5. 
Per-pool b_c by chain/tier --- + b_tvl_all = theta_median[:, 1] + + fig, axes = plt.subplots(1, 2, figsize=(14, 6)) + + ax = axes[0] + chains_present = sorted(pool_meta["chain"].unique()) + chain_data = [] + chain_labels = [] + for c in chains_present: + mask = pool_meta["chain"].values == c + if mask.sum() > 0: + chain_data.append(b_tvl_all[mask]) + chain_labels.append(f"{c}\n(n={mask.sum()})") + if chain_data: + ax.boxplot(chain_data, tick_labels=chain_labels, vert=True) + ax.axhline(1.0, color="red", linestyle="--", linewidth=0.8, alpha=0.6) + ax.set_ylabel("Per-pool b_c (TVL elasticity)") + ax.set_title("TVL elasticity by chain") + + ax = axes[1] + tier_a_vals = pool_meta["tier_A"].values.astype(int) + tier_labels_map = {0: "Blue-chip", 1: "Mid-cap", 2: "Long-tail"} + tier_data = [] + tier_labels = [] + for t in [0, 1, 2]: + mask = tier_a_vals == t + if mask.sum() > 0: + tier_data.append(b_tvl_all[mask]) + tier_labels.append(f"{tier_labels_map[t]}\n(n={mask.sum()})") + if tier_data: + ax.boxplot(tier_data, tick_labels=tier_labels, vert=True) + ax.axhline(1.0, color="red", linestyle="--", linewidth=0.8, alpha=0.6) + ax.set_ylabel("Per-pool b_c (TVL elasticity)") + ax.set_title("TVL elasticity by token tier (best token)") + + plt.tight_layout() + path = os.path.join(output_dir, "per_pool_b_c.png") + plt.savefig(path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path}") + + # --- 6. 
Correlation matrix posterior --- + L_Omega_samples = np.array(sample_dict["L_Omega"]) # (S, K, K) + Omega_samples = np.einsum("sij,skj->sik", L_Omega_samples, L_Omega_samples) + Omega_median = np.median(Omega_samples, axis=0) + + fig, ax = plt.subplots(figsize=(7, 6)) + im = ax.imshow(Omega_median, vmin=-1, vmax=1, cmap="RdBu_r") + ax.set_xticks(range(K_COEFF)) + ax.set_yticks(range(K_COEFF)) + ax.set_xticklabels(COEFF_NAMES, rotation=45, ha="right") + ax.set_yticklabels(COEFF_NAMES) + for i in range(K_COEFF): + for j in range(K_COEFF): + ax.text(j, i, f"{Omega_median[i, j]:.2f}", ha="center", + va="center", fontsize=10, + color="white" if abs(Omega_median[i, j]) > 0.5 else "black") + plt.colorbar(im, ax=ax, shrink=0.8) + ax.set_title("Posterior median correlation matrix (Omega)") + plt.tight_layout() + path = os.path.join(output_dir, "correlation_matrix.png") + plt.savefig(path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path}") + + # --- 7. Shrinkage plot: OLS b_c vs hierarchical b_c --- + ols_b_c = np.zeros(len(pool_ids)) + for i, pid in enumerate(pool_ids): + mask = pool_idx == i + if mask.sum() < 5: + ols_b_c[i] = np.nan + continue + x_i = x_obs[mask] + y_i = y_obs[mask] + try: + beta, _, _, _ = np.linalg.lstsq(x_i, y_i, rcond=None) + ols_b_c[i] = beta[1] # TVL coefficient + except np.linalg.LinAlgError: + ols_b_c[i] = np.nan + + hier_b_c = theta_median[:, 1] + valid = np.isfinite(ols_b_c) + + if valid.sum() > 2: + fig, ax = plt.subplots(figsize=(8, 8)) + ax.scatter(ols_b_c[valid], hier_b_c[valid], alpha=0.6, s=20, + color="steelblue") + + pop_b_c = np.median(hier_b_c) + ax.axhline(pop_b_c, color="red", linestyle="--", linewidth=0.8, + label=f"Population median = {pop_b_c:.3f}") + + lims = [min(np.nanmin(ols_b_c[valid]), hier_b_c[valid].min()) - 0.2, + max(np.nanmax(ols_b_c[valid]), hier_b_c[valid].max()) + 0.2] + ax.plot(lims, lims, "k:", linewidth=0.8, alpha=0.5) + ax.set_xlabel("Per-pool OLS b_c (lagged TVL)") + 
ax.set_ylabel("Hierarchical posterior median b_c") + ax.set_title("Shrinkage: OLS vs hierarchical TVL elasticity") + ax.legend() + plt.tight_layout() + path = os.path.join(output_dir, "shrinkage_b_c.png") + plt.savefig(path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path}") + + # --- 8. beta_tvl vs beta_vol scatter colored by chain --- + fig, ax = plt.subplots(figsize=(10, 7)) + pool_id_to_chain = dict(zip(pool_meta["pool_id"], pool_meta["chain"])) + chain_colors = {} + cmap = plt.cm.tab10 + unique_chains = sorted(pool_meta["chain"].unique()) + for i, c in enumerate(unique_chains): + chain_colors[c] = cmap(i % 10) + + beta_tvl_arr = theta_median[:, 1] + beta_vol_arr = theta_median[:, 2] + for i, pid in enumerate(pool_ids): + c = pool_id_to_chain.get(pid, "?") + ax.scatter(beta_tvl_arr[i], beta_vol_arr[i], + color=chain_colors.get(c, "gray"), alpha=0.6, s=20, + edgecolors="white", linewidths=0.3) + + from matplotlib.lines import Line2D + handles = [Line2D([0], [0], marker="o", color="w", + markerfacecolor=chain_colors[c], markersize=8, + label=c) + for c in unique_chains if c in chain_colors] + ax.legend(handles=handles, fontsize=8, loc="best") + ax.set_xlabel("b_tvl (TVL elasticity)") + ax.set_ylabel("b_sigma (volatility sensitivity)") + ax.set_title("Pool-specific coefficients by chain") + ax.axhline(0, color="gray", linewidth=0.5, linestyle="--") + ax.axvline(0, color="gray", linewidth=0.5, linestyle="--") + plt.tight_layout() + path = os.path.join(output_dir, "beta_tvl_vs_beta_vol.png") + plt.savefig(path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path}") + + # --- 9. 
Model summary panel --- + B_samples = np.array(sample_dict["B"]) + B_median = np.median(B_samples, axis=0) # (K_coeff, K_cov) + sigma_theta_med = np.median(np.array(sample_dict["sigma_theta"]), axis=0) + df_med = np.median(np.array(sample_dict["df"])) + sigma_eps_med = np.median(np.array(sample_dict["sigma_eps"]), axis=0) + + col_names = data["covariate_names"] + + fig, ax = plt.subplots(figsize=(12, 8)) + ax.axis("off") + + summary = "Group-level regression B (posterior median):\n" + header = f" {'covariate':<20s}" + for cn in COEFF_NAMES: + header += f" {cn:>10s}" + summary += header + "\n" + summary += " " + "-" * (20 + 11 * K_COEFF) + "\n" + for j, name in enumerate(col_names): + line = f" {name:<20s}" + for k in range(K_COEFF): + line += f" {B_median[k, j]:>10.3f}" + summary += line + "\n" + + summary += f"\nsigma_theta: [{', '.join(f'{v:.3f}' for v in sigma_theta_med)}]\n" + summary += f"\nCorrelation matrix (Omega):\n" + for i in range(K_COEFF): + row = " [" + " ".join(f"{Omega_median[i, j]:>6.3f}" + for j in range(K_COEFF)) + "]\n" + summary += row + + tier_names = ["blue-chip", "mid-cap", "long-tail"] + sigma_eps_str = ", ".join(f"{tier_names[i]}={sigma_eps_med[i]:.3f}" + for i in range(len(sigma_eps_med))) + summary += f"\nsigma_eps: [{sigma_eps_str}]\n" + summary += f"df (Student-t): {df_med:.1f}\n" + summary += f"R^2: {r2:.3f}\n" + + ax.text(0.02, 0.98, summary, transform=ax.transAxes, + fontsize=7, verticalalignment="top", fontfamily="monospace") + ax.set_title("Model Summary") + plt.tight_layout() + path = os.path.join(output_dir, "model_summary.png") + plt.savefig(path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path}") diff --git a/quantammsim/noise_calibration/postprocessing.py b/quantammsim/noise_calibration/postprocessing.py new file mode 100644 index 0000000..4980b44 --- /dev/null +++ b/quantammsim/noise_calibration/postprocessing.py @@ -0,0 +1,667 @@ +"""Post-processing: extract params, predict, convergence, prior 
def extract_noise_params(samples, data, use_median=True) -> list:
    """Build one noise-parameter record per pool from posterior draws.

    ``samples`` may be an object exposing ``get_samples()`` (an MCMC result)
    or an already-materialised dict of draws (SVI).  Per-pool coefficient
    draws are reduced over the sample axis -- median by default, mean when
    ``use_median`` is false -- and the weekend effect is absorbed into the
    intercept: ``b_0_eff = b_0_raw + b_weekend * (2/7)``.

    Returns a list with one dict per entry of ``data["pool_ids"]``, in order.
    """
    draws = samples.get_samples() if hasattr(samples, "get_samples") else samples

    # (S, N_pools, K_coeff)
    theta_draws = _get_theta_samples(draws, np.array(data["X_pool"]), data=data)

    reduce_fn = np.median if use_median else np.mean
    theta_point = reduce_fn(theta_draws, axis=0)  # (N_pools, K_coeff)
    theta_spread = np.std(theta_draws, axis=0)

    pool_ids = data["pool_ids"]
    meta_table = data["pool_meta"]

    def _record(row_no, pool_id):
        """Assemble the output dict for the pool at position *row_no*."""
        meta = meta_table.iloc[row_no]
        b_0_raw, b_tvl, b_sigma, b_weekend = theta_point[row_no]

        # The simulator has no weekend indicator, so the expected weekend
        # effect (2 days out of 7) is folded into the intercept.
        b_0_effective = b_0_raw + b_weekend * (2.0 / 7.0)

        token_list = meta["tokens"]
        if isinstance(token_list, str):
            token_list = token_list.split(",")

        return {
            "pool_id": pool_id,
            "chain": str(meta["chain"]),
            "tokens": token_list,
            "theta_median": list(map(float, theta_point[row_no])),
            "theta_std": list(map(float, theta_spread[row_no])),
            "b_weekend": float(b_weekend),
            "noise_params": {
                "b_0": float(b_0_effective),
                "b_sigma": float(b_sigma),
                "b_c": float(b_tvl),
                "b_weekend": float(b_weekend),
                "base_fee": float(meta["swap_fee"]),
            },
        }

    return [_record(i, pool_id) for i, pool_id in enumerate(pool_ids)]
def _flatten_diagnostic(dataset, skip=("theta",)):
    """Collect all values of a per-variable diagnostic dataset into one list.

    Variables whose name is in *skip* are left out; by default that is
    ``"theta"``, which is excluded from both the R-hat and ESS summaries.
    """
    values = []
    for var in dataset.data_vars:
        if var in skip:
            continue
        values.extend(dataset[var].values.flatten())
    return values


def check_convergence(mcmc_or_losses, method="nuts") -> dict:
    """Compute convergence diagnostics.

    For NUTS: R-hat, ESS, divergences.
    For SVI: final ELBO, ELBO stability.

    Parameters
    ----------
    mcmc_or_losses
        For ``method="svi"``: the sequence of ELBO losses.  For any other
        ``method``: the MCMC object, converted with ``arviz.from_numpyro``.
    method : str
        ``"svi"`` selects the loss-based summary; everything else is
        treated as NUTS.

    Returns
    -------
    dict
        SVI: ``method``, ``final_elbo``, ``elbo_last_100_std``,
        ``elbo_last_100_mean``.  An empty loss history yields NaN for the
        three statistics instead of raising ``IndexError``.
        NUTS: ``method``, ``r_hat_max`` (NaN with fewer than two chains),
        ``ess_min``, ``divergences``.
    """
    if method == "svi":
        losses = np.asarray(mcmc_or_losses)
        if losses.size == 0:
            # No optimisation steps were recorded: report "unknown" rather
            # than crashing on losses[-1].
            nan = float("nan")
            return {
                "method": "svi",
                "final_elbo": nan,
                "elbo_last_100_std": nan,
                "elbo_last_100_mean": nan,
            }
        tail = losses[-100:]
        return {
            "method": "svi",
            "final_elbo": float(losses[-1]),
            "elbo_last_100_std": float(np.std(tail)),
            "elbo_last_100_mean": float(np.mean(tail)),
        }

    # NUTS diagnostics
    import arviz as az

    idata = az.from_numpyro(mcmc_or_losses)

    n_chains = idata.posterior.sizes.get("chain", 1)

    # R-hat is only computed when at least two chains are available.
    rhat_max = float("nan")
    if n_chains >= 2:
        rhat_vals = _flatten_diagnostic(az.rhat(idata))
        rhat_max = float(np.nanmax(rhat_vals)) if rhat_vals else float("nan")

    ess_vals = _flatten_diagnostic(az.ess(idata))
    ess_min = float(np.nanmin(ess_vals)) if ess_vals else float("nan")

    divergences = int(idata.sample_stats["diverging"].sum().values)

    print(f"\n Convergence diagnostics:")
    if n_chains >= 2:
        print(f" R-hat max: {rhat_max:.4f} "
              f"{'OK' if rhat_max < 1.05 else 'WARNING'}")
    else:
        print(f" R-hat max: N/A (need >= 2 chains)")
    print(f" ESS min: {ess_min:.0f} "
          f"{'OK' if ess_min > 400 else 'WARNING'}")
    print(f" Divergences: {divergences} "
          f"{'OK' if divergences == 0 else 'WARNING'}")

    return {
        "method": "nuts",
        "r_hat_max": rhat_max,
        "ess_min": ess_min,
        "divergences": divergences,
    }
def assign_ibp_features(samples, data) -> np.ndarray:
    """Compute MAP feature assignments for marginalized IBP model.

    Enumerates all 2^K binary feature configurations per pool, evaluates
    per-pool log-posterior (log-prior + log-likelihood), returns argmax
    config as (N_pools, K_features) binary ndarray.

    Uses median posterior parameters for B, W, v_ibp, sigma_eps, df.

    Parameters
    ----------
    samples
        Either a dict of posterior draws or an object exposing
        ``get_samples()`` that returns one.  Must contain ``"B"``, ``"W"``,
        ``"v_ibp"``, ``"sigma_eps"`` and ``"df"``.
    data : dict
        Must contain ``"pool_idx"``, ``"X_pool"``, ``"x_obs"``, ``"y_obs"``
        and ``"N_pools"``.

    Returns
    -------
    np.ndarray
        Integer 0/1 array with one row per pool.  A pool without any
        observation receives the configuration with the highest prior.
    """
    from scipy.stats import t as t_dist

    if hasattr(samples, "get_samples"):
        sample_dict = samples.get_samples()
    else:
        sample_dict = samples

    B_med = np.median(np.array(sample_dict["B"]), axis=0)  # (K_coeff, K_cov)
    W_med = np.median(np.array(sample_dict["W"]), axis=0)  # (K_features, K_coeff)
    v_ibp_med = np.median(np.array(sample_dict["v_ibp"]), axis=0)  # (K_features,)
    sigma_eps_med = float(np.median(np.array(sample_dict["sigma_eps"])))
    df_med = float(np.median(np.array(sample_dict["df"])))

    # Feature prevalences from the stick-breaking construction.
    pi = np.cumprod(v_ibp_med)  # (K_features,)
    K_features = len(pi)

    pool_idx = np.array(data["pool_idx"])
    X_pool = np.array(data["X_pool"])
    x_obs = np.array(data["x_obs"])
    y_obs = np.array(data["y_obs"])
    N_pools = data["N_pools"]

    # Enumerate all 2^K configs; bit k of the config index is feature k.
    n_configs = 2 ** K_features
    configs = (
        (np.arange(n_configs)[:, None] >> np.arange(K_features)[None, :]) & 1
    ).astype(float)  # (n_configs, K_features)

    # Log-prior per config (independent Bernoulli(pi_k) per feature); the
    # 1e-30 offset keeps the logs finite when a prevalence is exactly 0 or 1.
    log_pi = np.log(pi + 1e-30)
    log_1mpi = np.log(1.0 - pi + 1e-30)
    log_prior = configs @ log_pi + (1.0 - configs) @ log_1mpi  # (n_configs,)

    # Population mean
    mu_pop = X_pool @ B_med.T  # (N_pools, K_coeff)

    # Feature effects per config
    feature_effects = configs @ W_med  # (n_configs, K_coeff)

    # Per-obs means: mu_pop_obs + feature_mu
    mu_pop_obs = np.sum(mu_pop[pool_idx] * x_obs, axis=1)  # (N_obs,)
    feature_mu = x_obs @ feature_effects.T  # (N_obs, n_configs)
    mu_obs = mu_pop_obs[:, None] + feature_mu  # (N_obs, n_configs)

    # Log-likelihood per obs per config
    log_lik = t_dist.logpdf(
        y_obs[:, None], df_med, loc=mu_obs, scale=sigma_eps_med
    )  # (N_obs, n_configs)

    # Sum within pools.  np.add.at is an unbuffered scatter-add, so repeated
    # pool indices accumulate correctly (plain fancy-index += would not) and
    # in the same order as an explicit per-observation loop.
    pool_log_liks = np.zeros((N_pools, n_configs))
    np.add.at(pool_log_liks, pool_idx, log_lik)

    # Posterior = log_prior + pool_log_liks; MAP config per pool
    log_posterior = log_prior[None, :] + pool_log_liks  # (N_pools, n_configs)
    best_config_idx = np.argmax(log_posterior, axis=1)  # (N_pools,)

    return configs[best_config_idx].astype(int)  # (N_pools, K_features)
def extract_structural_params(samples, data, use_median=True) -> list:
    """Extract per-pool arb frequency and noise coefficients from structural model.

    Parameters
    ----------
    samples : dict
        Posterior samples from SVI/NUTS with structural_noise_model.  An
        object exposing ``get_samples()`` is also accepted.
    data : dict
        Output of encode_covariates_structural().
    use_median : bool
        Reduce posterior draws with the median (default) or the mean.

    Returns
    -------
    list of dict
        Per-pool dicts with: pool_id, chain, tokens, arb_frequency, noise_params.
    """
    if hasattr(samples, "get_samples"):
        sample_dict = samples.get_samples()
    else:
        sample_dict = samples

    # Everything below is plain NumPy, so no JAX or model helpers are imported.
    agg_fn = np.median if use_median else np.mean

    # Cadence parameters
    alpha_0 = agg_fn(np.array(sample_dict["alpha_0"]))
    alpha_chain = agg_fn(np.array(sample_dict["alpha_chain"]), axis=0)
    alpha_tier = agg_fn(np.array(sample_dict["alpha_tier"]), axis=0)
    alpha_tvl = agg_fn(np.array(sample_dict["alpha_tvl"]))

    # MoE parameters
    W_gate = agg_fn(np.array(sample_dict["W_gate"]), axis=0)
    beta = agg_fn(np.array(sample_dict["beta"]), axis=0)

    chain_idx = np.array(data["chain_idx"])
    tier_idx = np.array(data["tier_idx"])
    X_pool = np.array(data["X_pool"])
    pool_meta = data["pool_meta"]
    pool_ids = data["pool_ids"]

    # Per-pool cadence: index 0 is the reference level with a zero effect.
    padded_chain = np.concatenate([[0.0], alpha_chain])
    padded_tier = np.concatenate([[0.0], alpha_tier])

    # Per-pool log_tvl (median across observations); pools without any
    # observation keep the initial 0.0.
    pool_idx_arr = np.array(data["pool_idx"])
    lag_log_tvl = np.array(data["lag_log_tvl"])
    N_pools = data["N_pools"]

    pool_tvl_median = np.zeros(N_pools)
    for p in range(N_pools):
        mask = pool_idx_arr == p
        if mask.any():
            pool_tvl_median[p] = np.median(lag_log_tvl[mask])

    # Per-pool noise coefficients via MoE gating; the row maximum is
    # subtracted before exponentiating for a numerically stable softmax.
    logits = X_pool @ W_gate  # (N_pools, K_archetypes)
    w = np.exp(logits - logits.max(axis=1, keepdims=True))
    w = w / w.sum(axis=1, keepdims=True)  # softmax
    beta_pool = w @ beta  # (N_pools, K_obs_coeff)

    results = []
    for i, pool_id in enumerate(pool_ids):
        meta = pool_meta.iloc[i]
        log_cadence = (
            alpha_0
            + padded_chain[chain_idx[i]]
            + padded_tier[tier_idx[i]]
            + alpha_tvl * pool_tvl_median[i]
        )
        # Clip in log space first, then round and clamp to an integer in [1, 60].
        cadence = np.exp(np.clip(log_cadence, -2.0, 6.0))
        arb_freq = int(np.clip(np.round(cadence), 1, 60))

        noise_coeffs = {
            name: float(beta_pool[i, k])
            for k, name in enumerate(OBS_COEFF_NAMES)
        }

        tokens = meta["tokens"]
        if isinstance(tokens, str):
            tokens = tokens.split(",")

        results.append({
            "pool_id": pool_id,
            "chain": str(meta["chain"]),
            "tokens": tokens,
            "arb_frequency": arb_freq,
            "noise_params": noise_coeffs,
        })

    return results
def run_prior_predictive(data, num_samples=500, model_fn=None) -> dict:
    """Run a prior predictive check (the model is called with ``y_obs=None``).

    ``model_fn`` defaults to ``noise_model``.  Draws are taken with a fixed
    PRNG key (99), converted to NumPy arrays and returned as a dict.  When
    the draws contain a ``"y"`` site, its 1st-99th percentile range is
    printed next to the observed range of ``data["y_obs"]``.
    """
    import jax
    from numpyro.infer import Predictive

    model = noise_model if model_fn is None else model_fn

    call_kwargs = _build_model_kwargs(data, model_fn=model)
    call_kwargs["y_obs"] = None  # no observations

    sampler = Predictive(model, num_samples=num_samples)
    raw_draws = sampler(jax.random.PRNGKey(99), **call_kwargs)
    prior_samples = {site: np.array(values) for site, values in raw_draws.items()}

    print(f" Prior predictive: drew {num_samples} samples")
    y_prior = prior_samples.get("y", None)
    if y_prior is None:
        return prior_samples

    print(f" Prior log-volume range: "
          f"[{np.percentile(y_prior, 1):.1f}, "
          f"{np.percentile(y_prior, 99):.1f}]")
    print(f" Observed log-volume range: "
          f"[{data['y_obs'].min():.1f}, {data['y_obs'].max():.1f}]")
    return prior_samples
"""Compare effective weight/exposure trajectories: reClAMM vs Balancer 50/50.

Prints weight stats and saves a plot of weight[AAVE] over time for both pools.
"""

import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from quantammsim.runners.jax_runners import do_run_on_historic_data


def to_daily_price_shift_base(exponent):
    """Map a shift exponent to the ``daily_price_shift_base`` pool parameter."""
    return 1.0 - exponent / 124649.0


TOKENS = ["AAVE", "ETH"]
START = "2024-06-01 00:00:00"
END = "2025-06-01 00:00:00"


def _fingerprint(rule):
    """Return a fresh run fingerprint; only ``rule`` differs between runs."""
    return {
        "tokens": TOKENS, "rule": rule,
        "startDateString": START, "endDateString": END,
        "initial_pool_value": 1_000_000.0, "do_arb": True,
        "fees": 0.0, "gas_cost": 0.0, "arb_fees": 0.0,
        "chunk_period": 60, "weight_interpolation_period": 60,
    }


def _reclamm_case(price_ratio, margin, shift_exponent):
    """Build one reClAMM configuration entry (fingerprint + params)."""
    return {
        "fingerprint": _fingerprint("reclamm"),
        "params": {
            "price_ratio": jnp.array(price_ratio),
            "centeredness_margin": jnp.array(margin),
            "daily_price_shift_base": jnp.array(
                to_daily_price_shift_base(shift_exponent)
            ),
        },
    }


CONFIGS = {
    "reClAMM on-chain (pr=1.5)": _reclamm_case(1.5, 0.5, 0.1),
    "reClAMM wide (pr=4)": _reclamm_case(4.0, 0.2, 1.0),
    "reClAMM Phase 2 (pr=4, m=0.1)": _reclamm_case(4.0, 0.1, 0.001),
    "Balancer 50/50": {
        "fingerprint": _fingerprint("balancer"),
        "params": {
            "initial_weights_logits": jnp.zeros(2),
        },
    },
}

# Run every configuration once, in declaration order.
results = {}
for name, cfg in CONFIGS.items():
    print(f"Running {name}...")
    results[name] = do_run_on_historic_data(
        run_fingerprint=cfg["fingerprint"], params=cfg["params"]
    )

# Compute effective weights (value fraction in token 0 = AAVE)
print("\n" + "=" * 90)
print(f" {'Config':<35s} {'w_AAVE mean':>10s} {'w_AAVE std':>10s} "
      f"{'w_AAVE min':>10s} {'w_AAVE max':>10s} {'vs HODL':>10s}")
print("-" * 90)

daily = 1440  # subsample to daily for stats and plotting
fig, axes = plt.subplots(3, 1, figsize=(14, 10), sharex=True)
weight_ax, value_ax, price_ax = axes

for name, r in results.items():
    reserves = np.array(r["reserves"])
    prices = np.array(r["prices"])

    # Per-token value and its share of the total, guarded against a zero total.
    values = reserves * prices  # (T, 2)
    total = values.sum(axis=1, keepdims=True)
    weights = values / np.clip(total, 1e-10, None)  # (T, 2)
    w_aave = weights[::daily, 0]

    # HODL benchmark: initial reserves valued at the final prices.
    hodl_value = float((reserves[0] * prices[-1]).sum())
    vs_hodl = r["final_value"] / hodl_value - 1.0

    print(f" {name:<35s} {w_aave.mean():>10.4f} {w_aave.std():>10.4f} "
          f"{w_aave.min():>10.4f} {w_aave.max():>10.4f} {vs_hodl * 100:>9.2f}%")

    days = np.arange(len(w_aave))
    weight_ax.plot(days, w_aave, label=name, alpha=0.8)

    # Pool value over time
    pool_val = np.array(r["value"])[::daily]
    value_ax.plot(days[:len(pool_val)], pool_val / 1e6, label=name, alpha=0.8)

print("=" * 90)

# HODL line, taken from the first configuration's prices and initial reserves
first_run = next(iter(results.values()))
prices_daily = np.array(first_run["prices"])[::daily]
reserves_0 = np.array(first_run["reserves"])[0]
hodl_val = (reserves_0 * prices_daily).sum(axis=1) / 1e6
value_ax.plot(np.arange(len(hodl_val)), hodl_val, label="HODL", ls="--", color="gray", alpha=0.7)

# Price ratio (AAVE/ETH) on third axis
price_ratio_series = prices_daily[:, 0] / prices_daily[:, 1]
price_ax.plot(np.arange(len(price_ratio_series)), price_ratio_series, color="black", alpha=0.7)
price_ax.set_ylabel("AAVE/ETH price")
price_ax.set_xlabel("Days")

weight_ax.set_ylabel("AAVE weight (value fraction)")
weight_ax.axhline(0.5, ls="--", color="gray", alpha=0.5)
weight_ax.legend(fontsize=8)
weight_ax.set_title("Effective AAVE exposure over time")

value_ax.set_ylabel("Pool value ($M)")
value_ax.legend(fontsize=8)
value_ax.set_title("Pool value over time")

plt.tight_layout()
plt.savefig("reclamm_exposure_comparison.png", dpi=150)
print("\nSaved reclamm_exposure_comparison.png")
def main() -> None:
    """Diagnose step-to-step spikes in sim-vs-world deviation for POOL_LABEL.

    Runs three zero-gas, 1-minute-arb simulations of the same pool:

    * "old": ``run_sim`` without an LP-supply series,
    * "new": ``run_sim`` with the LP-supply series,
    * "raw": ``do_run_on_historic_data`` called directly with the LP-supply
      series, using the returned ``"value"`` without further division.

    Each trajectory is sampled at the world (on-chain) timestamps, converted
    to a growth factor, and compared against the corresponding world growth.
    Jump counts are printed and a four-panel figure is written to
    ``results/diagnose_spikes_<POOL_LABEL>.png``.
    """
    pool = POOL_REGISTRY[POOL_LABEL]
    # Populates pool.on_chain_params / pool.initial_pool_value_usd in place.
    extract_on_chain_state(pool)
    initial_state = extract_initial_state(pool)
    start_str = _start_str_from_pool(pool)
    end_str = get_data_end_date(pool.tokens)
    lp_supply_df = load_bpt_supply_df(pool, end_date=end_str)

    # Sim start as unix seconds (UTC), used to map world timestamps onto the
    # minute-indexed sim arrays.
    start_sec = datetime.strptime(
        start_str, "%Y-%m-%d %H:%M:%S"
    ).replace(tzinfo=timezone.utc).timestamp()

    print(f"Pool: {pool.label}, TVL: ${pool.initial_pool_value_usd:,.0f}")
    print(f"Period: {start_str} to {end_str}")
    print(f"BPT range: {lp_supply_df['lp_supply'].min():.4f} to {lp_supply_df['lp_supply'].max():.4f}")

    # ---- Run sims ----
    # 1. Old way: no lp_supply at all
    result_old = run_sim(
        pool, gas_cost=0.0, arb_frequency=1,
        initial_state=initial_state, start=start_str, end=end_str,
        lp_supply_df=None,
    )

    # 2. New way: lp_supply in scan + per-LP normalization
    result_new = run_sim(
        pool, gas_cost=0.0, arb_frequency=1,
        initial_state=initial_state, start=start_str, end=end_str,
        lp_supply_df=lp_supply_df,
    )

    # 3. Raw lp run (scan has lp_supply, but we DON'T divide by it)
    params = _onchain_params_to_sim(pool)
    fp = {
        "tokens": pool.tokens, "rule": "reclamm",
        "startDateString": start_str, "endDateString": end_str,
        "initial_pool_value": pool.initial_pool_value_usd,
        "fees": pool.swap_fee, "gas_cost": 0.0, "arb_fees": 0.0,
        "do_arb": True, "arb_frequency": 1, "chunk_period": 1440,
        "weight_interpolation_period": 1440,
        "reclamm_use_shift_exponent": True,
        "reclamm_interpolation_method": "geometric",
        "reclamm_centeredness_scaling": False,
        "protocol_fee_split": PROTOCOL_FEE_SPLIT,
        "reclamm_initial_state": initial_state,
    }
    result_raw = do_run_on_historic_data(
        run_fingerprint=fp, params=params, lp_supply_df=lp_supply_df,
    )
    v_lp_raw = np.array(result_raw["value"])

    v_old = np.array(result_old["value_usd"])  # no LP, no normalization
    v_new = np.array(result_new["value_usd"])  # LP scan + divided by lp_supply

    # ---- World ----
    world = load_world_history(pool, end_date=end_str)
    world_ts = world["timestamps"]
    prices_min = result_old["prices"]

    # One price column per token, each sampled at the world timestamps.
    prices_at_world = np.stack([
        sample_at_timestamps(prices_min[:, i], start_sec, world_ts)
        for i in range(prices_min.shape[1])
    ], axis=1)

    # BPT-normalized world value (same in both old and new)
    world_bpt_val = (
        world["bal_0"] * prices_at_world[:, 0]
        + world["bal_1"] * prices_at_world[:, 1]
    )
    world_growth = world_bpt_val / world_bpt_val[0]

    # Raw world value (absolute, un-normalized)
    world_raw_val = (
        world["raw_bal_0"] * prices_at_world[:, 0]
        + world["raw_bal_1"] * prices_at_world[:, 1]
    )

    # Sample sim at world timestamps
    old_at_world = sample_at_timestamps(v_old, start_sec, world_ts)
    new_at_world = sample_at_timestamps(v_new, start_sec, world_ts)
    raw_at_world = sample_at_timestamps(v_lp_raw, start_sec, world_ts)

    old_growth = old_at_world / old_at_world[0]
    new_growth = new_at_world / new_at_world[0]
    raw_growth = raw_at_world / raw_at_world[0]

    # Deviations (percent difference of sim growth relative to world growth)
    dev_old = (old_growth / world_growth - 1) * 100
    dev_new = (new_growth / world_growth - 1) * 100

    # Raw vs raw-world comparison (both absolute)
    world_raw_growth = world_raw_val / world_raw_val[0]
    dev_raw = (raw_growth / world_raw_growth - 1) * 100

    days = (world_ts - world_ts[0]) / 86400

    # LP supply at world timestamps
    # lp_supply_df["unix"] is divided by 1000 to match world_ts, which
    # load_world_history documents as unix seconds.
    lp_unix = np.array(lp_supply_df["unix"])
    lp_vals = np.array(lp_supply_df["lp_supply"])
    lp_at_world = np.interp(world_ts, lp_unix / 1000, lp_vals)

    # ---- Print diagnostics ----
    print(f"\n--- Final deviations ---")
    print(f"Old (no LP): {dev_old[-1]:+.4f}%")
    print(f"New (LP + per-LP): {dev_new[-1]:+.4f}%")
    print(f"Raw (LP, absolute): {dev_raw[-1]:+.4f}%")

    # Spike analysis: count consecutive-sample jumps in each deviation series
    # above 0.1 / 0.5 / 1.0 percentage points and list the first five >1.0.
    for label, dev in [("Old", dev_old), ("New", dev_new), ("Raw", dev_raw)]:
        diffs = np.abs(np.diff(dev))
        n_spikes_01 = np.sum(diffs > 0.1)
        n_spikes_05 = np.sum(diffs > 0.5)
        n_spikes_10 = np.sum(diffs > 1.0)
        print(f"\n{label} — step-to-step jumps in deviation:")
        print(f" >0.1%: {n_spikes_01}, >0.5%: {n_spikes_05}, >1.0%: {n_spikes_10}")
        if n_spikes_10 > 0:
            spike_idx = np.where(diffs > 1.0)[0]
            for si in spike_idx[:5]:
                print(f" day {days[si]:.1f}: dev {dev[si]:+.2f}% -> {dev[si+1]:+.2f}% "
                      f"(Δ={dev[si+1]-dev[si]:+.2f}%, lp={lp_at_world[si]:.4f}->{lp_at_world[si+1]:.4f})")

    # World growth spikes
    world_g_diffs = np.diff(world_growth)
    n_world_spikes = np.sum(np.abs(world_g_diffs) > 0.01)
    print(f"\nWorld BPT-normalized growth jumps > 1%: {n_world_spikes}")
    if n_world_spikes > 0:
        wsi = np.where(np.abs(world_g_diffs) > 0.01)[0]
        for si in wsi[:5]:
            print(f" day {days[si]:.1f}: growth {world_growth[si]:.4f} -> {world_growth[si+1]:.4f} "
                  f"(Δ={world_g_diffs[si]:+.4f}, lp={lp_at_world[si]:.4f}->{lp_at_world[si+1]:.4f})")

    # ---- Plot ----
    fig, axes = plt.subplots(4, 1, figsize=(14, 16), sharex=True)

    # Panel 0: old vs new deviation from BPT-normalized world
    ax = axes[0]
    ax.plot(days, dev_old, "b-", linewidth=1.5, label=f"Old (no LP) → {dev_old[-1]:+.2f}%")
    ax.plot(days, dev_new, "r-", linewidth=1.5, label=f"New (LP + per-LP norm) → {dev_new[-1]:+.2f}%")
    ax.axhline(0, color="gray", linestyle=":", alpha=0.5)
    ax.set_ylabel("% deviation from world")
    ax.set_title(f"{pool.label} — gas=0, arb=1min — old vs new deviation")
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.2)

    # Panel 1: raw sim vs raw world deviation
    ax = axes[1]
    ax.plot(days, dev_raw, "g-", linewidth=1.5, label=f"Raw absolute (LP scan, raw world) → {dev_raw[-1]:+.2f}%")
    ax.axhline(0, color="gray", linestyle=":", alpha=0.5)
    ax.set_ylabel("% deviation")
    ax.set_title("Alternative: raw absolute sim vs raw absolute world")
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.2)

    # Panel 2: growth factors
    ax = axes[2]
    ax.plot(days, world_growth, "k-", linewidth=2, label="World (BPT-normalized)")
    ax.plot(days, old_growth, "b-", linewidth=1, alpha=0.8, label="Old sim")
    ax.plot(days, new_growth, "r-", linewidth=1, alpha=0.8, label="New sim (LP + per-LP)")
    ax.set_ylabel("Growth factor")
    ax.set_title("Growth factors")
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.2)

    # Panel 3: LP supply, with large step changes highlighted
    ax = axes[3]
    ax.plot(days, lp_at_world, "g-", linewidth=2, label="LP supply (BPT/BPT₀)")
    # Mark large LP changes
    lp_diffs = np.abs(np.diff(lp_at_world))
    big_lp = np.where(lp_diffs > 0.05)[0]
    if len(big_lp):
        ax.scatter(days[big_lp], lp_at_world[big_lp], c="red", s=40, zorder=5,
                   label=f"Large LP events ({len(big_lp)})")
    ax.set_ylabel("BPT / BPT₀")
    ax.set_xlabel("Days from start")
    ax.set_title("On-chain BPT supply")
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.2)

    fig.suptitle(
        f"{pool.label} ({pool.chain}) — spike diagnosis",
        fontsize=13, fontweight="bold",
    )
    plt.tight_layout()

    os.makedirs("results", exist_ok=True)
    out = f"results/diagnose_spikes_{POOL_LABEL}.png"
    plt.savefig(out, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"\nSaved: {out}")


if __name__ == "__main__":
    main()
def main() -> None:
    """Plot sim vs world pool value for the cbBTC_WETH pool (gas=0, arb_freq=1).

    Runs ``run_sim`` twice -- once with the on-chain LP-supply series and once
    without -- and plots both ``value_usd`` trajectories against the world
    value computed from raw on-chain balances and sim prices. Writes a
    three-panel figure to ``results/diagnostic_lp_supply_cbBTC_WETH.png`` and
    prints start/end values, growth factors and final deviations.

    NOTE(review): ``run_sim`` documents ``value_usd`` as divided by the
    interpolated LP supply when ``lp_supply_df`` is provided, while
    ``world_value`` here is built from raw (un-normalized) balances --
    confirm the "with lp_supply" curve is meant to share this axis.
    """
    pool = POOL_REGISTRY["cbBTC_WETH"]
    # Populates pool.on_chain_params / pool.initial_pool_value_usd in place.
    extract_on_chain_state(pool)
    initial_state = extract_initial_state(pool)

    start_str = _start_str_from_pool(pool)
    end_str = get_data_end_date(pool.tokens)
    lp_supply_df = load_bpt_supply_df(pool, end_date=end_str)

    print(f"Pool: {pool.label}, TVL: ${pool.initial_pool_value_usd:,.0f}")
    print(f"BPT: {lp_supply_df['lp_supply'].iloc[0]:.4f} -> {lp_supply_df['lp_supply'].iloc[-1]:.4f}")
    print(f"Period: {start_str} to {end_str}")

    # Run sim WITH lp_supply (gas=0, arb_freq=1)
    result_lp = run_sim(pool, gas_cost=0.0, arb_frequency=1,
                        initial_state=initial_state,
                        start=start_str, end=end_str,
                        lp_supply_df=lp_supply_df)

    # Run sim WITHOUT lp_supply (gas=0, arb_freq=1)
    result_no_lp = run_sim(pool, gas_cost=0.0, arb_frequency=1,
                           initial_state=initial_state,
                           start=start_str, end=end_str,
                           lp_supply_df=None)

    # World
    world = load_world_history(pool, end_date=end_str)
    world_ts = world["timestamps"]
    raw_bal_0 = world["raw_bal_0"]
    raw_bal_1 = world["raw_bal_1"]

    start_sec = result_lp["start_unix_sec"]
    prices_min = result_lp["prices"]

    # World value at world timestamps (raw balances × USD prices)
    prices_at_world = np.stack([
        sample_at_timestamps(prices_min[:, i], start_sec, world_ts)
        for i in range(prices_min.shape[1])
    ], axis=1)
    world_value = raw_bal_0 * prices_at_world[:, 0] + raw_bal_1 * prices_at_world[:, 1]

    # Sim values (minute-resolution)
    sim_value_lp = np.array(result_lp["value_usd"])
    sim_value_no_lp = np.array(result_no_lp["value_usd"])
    n_minutes = len(sim_value_lp)
    # Both sim series are plotted against the same day axis, derived from the
    # length of the with-lp series.
    sim_times_sec = start_sec + np.arange(n_minutes) * 60
    sim_days = (sim_times_sec - start_sec) / 86400
    world_days = (world_ts - start_sec) / 86400

    # BPT supply at world timestamps (for annotation)
    # lp_supply_df["unix"] is divided by 1000 to match world_ts, which
    # load_world_history documents as unix seconds.
    lp_at_world = np.interp(
        world_ts,
        np.array(lp_supply_df["unix"]) / 1000,
        np.array(lp_supply_df["lp_supply"]),
    )

    # --- Plot ---
    fig, axes = plt.subplots(3, 1, figsize=(14, 12), sharex=True)

    # Panel 1: absolute pool value
    ax = axes[0]
    ax.plot(world_days, world_value, "k-", linewidth=2, label="World (raw balances × prices)")
    ax.plot(sim_days, sim_value_lp, "b-", linewidth=1, alpha=0.8, label="Sim (with lp_supply)")
    ax.plot(sim_days, sim_value_no_lp, "r--", linewidth=1, alpha=0.8, label="Sim (no lp_supply)")
    ax.set_ylabel("Pool value (USD)")
    ax.set_title(f"{pool.label} — gas=0, arb_freq=1min — absolute pool value")
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.2)

    # Panel 2: growth factors
    ax = axes[1]
    world_growth = world_value / world_value[0]
    sim_growth_lp = sim_value_lp / sim_value_lp[0]
    sim_growth_no_lp = sim_value_no_lp / sim_value_no_lp[0]
    ax.plot(world_days, world_growth, "k-", linewidth=2, label="World growth")
    ax.plot(sim_days, sim_growth_lp, "b-", linewidth=1, alpha=0.8, label="Sim growth (with lp_supply)")
    ax.plot(sim_days, sim_growth_no_lp, "r--", linewidth=1, alpha=0.8, label="Sim growth (no lp_supply)")
    ax.axhline(1.0, color="gray", linestyle=":", alpha=0.5)
    ax.set_ylabel("Growth factor")
    ax.set_title("Growth factors (value / initial value)")
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.2)

    # Panel 3: BPT supply
    ax = axes[2]
    ax.plot(world_days, lp_at_world, "g-", linewidth=2, label="BPT supply (normalized)")
    ax.set_ylabel("BPT / BPT₀")
    ax.set_xlabel("Days from start")
    ax.set_title("On-chain BPT supply")
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.2)

    fig.suptitle(
        f"{pool.label} ({pool.chain}) — TVL=${pool.initial_pool_value_usd:,.0f} — "
        f"PR={pool.on_chain_params['price_ratio']:.4f}",
        fontsize=12, fontweight="bold",
    )
    plt.tight_layout()

    os.makedirs("results", exist_ok=True)
    out = "results/diagnostic_lp_supply_cbBTC_WETH.png"
    plt.savefig(out, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"\nSaved: {out}")

    # Print key numbers
    print(f"\nWorld: {world_value[0]:.0f} -> {world_value[-1]:.0f} (growth={world_growth[-1]:.4f})")
    print(f"Sim (lp): {sim_value_lp[0]:.0f} -> {sim_value_lp[-1]:.0f} (growth={sim_growth_lp[-1]:.4f})")
    print(f"Sim (no lp): {sim_value_no_lp[0]:.0f} -> {sim_value_no_lp[-1]:.0f} (growth={sim_growth_no_lp[-1]:.4f})")
    print(f"\nDeviation (lp): {(sim_growth_lp[-1]/world_growth[-1] - 1)*100:+.2f}%")
    print(f"Deviation (no lp): {(sim_growth_no_lp[-1]/world_growth[-1] - 1)*100:+.2f}%")


if __name__ == "__main__":
    main()
@dataclass
class PoolConfig:
    """Static metadata for a simulatable on-chain reClAMM pool.

    The first eight fields are fixed per registry entry. The last two start
    as ``None`` and are filled in place by ``extract_on_chain_state()``.
    """

    label: str  # name shown in printed summaries and plot titles
    tokens: list  # quantammsim ticker names, e.g. ['BTC', 'ETH']
    chain: str  # selects the gas battery in get_gas_costs(), e.g. 'base'
    swap_fee: float  # passed to the simulator as the run fingerprint "fees"
    db_label: str  # table name in pools_history.db
    plausible_start: str  # YYYY-MM-DD
    reverse: bool  # True if DB token order is reversed vs quantammsim
    pool_address: str = ""  # on-chain contract address (hex, no 0x prefix)
    # Filled by extract_on_chain_state():
    on_chain_params: Optional[dict] = None  # price_ratio, margin, shift_rate, swap_fee
    initial_pool_value_usd: Optional[float] = None  # TVL in USD at plausible_start
def _date_to_unix(date_str: str) -> int:
    """Convert a UTC date string to a unix timestamp in seconds.

    Args:
        date_str: ``YYYY-MM-DD HH:MM:SS`` or ``YYYY-MM-DD``; interpreted as UTC.

    Returns:
        Whole seconds since the unix epoch.

    Raises:
        ValueError: If ``date_str`` matches neither format.
    """
    for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d"):
        try:
            dt = datetime.strptime(date_str, fmt)
        except ValueError:
            continue
        return int(dt.replace(tzinfo=timezone.utc).timestamp())
    raise ValueError(f"Cannot parse date: {date_str}")


def _nearest_close(df, unix_ms: int) -> float:
    """Return ``close`` from the row whose ``unix`` is nearest to ``unix_ms``.

    The distance series is re-indexed 0..n-1 before ``idxmin`` so the result
    is a *position*. ``idxmin`` on the original series returns an index
    *label*, which only coincides with the position when the frame has a
    default RangeIndex.
    """
    distance = (df["unix"] - unix_ms).abs().reset_index(drop=True)
    pos = distance.idxmin()
    return float(df["close"].iloc[pos])


def _get_usd_price_at(ticker: str, unix_ms: int, data_root: str) -> float:
    """Get the USD price of a ticker at a given unix timestamp (ms).

    Args:
        ticker: Ticker name; the file ``<data_root>/<ticker>_USD.parquet`` is read.
        unix_ms: Target time in unix milliseconds.
        data_root: Directory containing the per-ticker parquet files.

    Returns:
        The ``close`` value of the row whose ``unix`` is closest to ``unix_ms``.
    """
    path = os.path.join(data_root, f"{ticker}_USD.parquet")
    # Only these two columns are used; skip loading the rest of the file.
    df = pd.read_parquet(path, columns=["unix", "close"])
    return _nearest_close(df, unix_ms)
state at plausible_start and compute USD TVL. + + Mutates and returns the pool config with on_chain_params and + initial_pool_value_usd filled in. + """ + if data_root is None: + data_root = os.path.join( + os.path.dirname(__file__), "..", "quantammsim", "data" + ) + + conn = sqlite3.connect(db_path) + cur = conn.cursor() + ts = _date_to_unix(pool.plausible_start) + + cur.execute( + f"""SELECT * FROM {pool.db_label} + WHERE timestamp <= ? + ORDER BY timestamp DESC LIMIT 1""", + (ts + 3600,), + ) + row = cur.fetchone() + conn.close() + + if row is None: + raise ValueError( + f"No DB data for {pool.db_label} at {pool.plausible_start}" + ) + + # DB columns: timestamp, block_number, bpt_supply, balance_0, balance_1, + # spot_price, virtual_0, virtual_1, time_last_interaction, + # price_ratio, margin, shift_rate, swap_fee + balance_0, balance_1 = row[3], row[4] + price_ratio = row[9] + margin = row[10] + shift_rate = row[11] + + pool.on_chain_params = { + "price_ratio": price_ratio, + "margin": margin, + "shift_rate": shift_rate, + "swap_fee": row[12], + } + + # Compute TVL in USD from per-token USD prices. + # DB stores balances in contract token order (bring_pool_data.py never + # applies reverse). 
def load_world_history(
    pool: PoolConfig,
    end_date: str = None,
    db_path: str = DEFAULT_DB_PATH,
) -> dict:
    """Load on-chain balance history for ``pool`` from the SQLite DB.

    Rows are read from the table ``pool.db_label`` for timestamps between
    ``pool.plausible_start`` minus 1000 seconds and ``end_date``, ordered by
    timestamp. Balances are returned both raw and scaled by
    ``initial_bpt / bpt_supply`` (BPT-normalized), where ``initial_bpt`` is
    the BPT supply of the first returned row.

    Args:
        pool: Pool whose ``db_label``, ``plausible_start`` and ``reverse``
            fields select the table, the window start and the token order.
        end_date: Optional upper bound as ``YYYY-MM-DD`` or
            ``YYYY-MM-DD HH:MM:SS``. When falsy, a far-future bound is used.
        db_path: Path to the SQLite database file.

    Returns:
        dict with:
            timestamps: array of unix timestamps (seconds)
            bal_0: BPT-normalized balance of quantammsim token[0]
            bal_1: BPT-normalized balance of quantammsim token[1]
            raw_bal_0: raw (un-normalized) balance of quantammsim token[0]
            raw_bal_1: raw (un-normalized) balance of quantammsim token[1]
            governance_events: list of (timestamp, field, old_val, new_val),
                where field is one of "price_ratio", "margin",
                "shift_rate", "swap_fee"

    Raises:
        ValueError: If a date cannot be parsed or the query returns no rows.
    """
    # Resolve the query window before opening the DB so that a bad date
    # string cannot leave a connection open.
    ts_start = _date_to_unix(pool.plausible_start) - 1000
    ts_end = _date_to_unix(end_date) if end_date else 2_000_000_000  # far future

    conn = sqlite3.connect(db_path)
    try:
        # Table names cannot be bound as SQL parameters; db_label comes from
        # the static pool registry.
        rows = conn.execute(
            f"""SELECT timestamp, bpt_supply, balance_0, balance_1,
                       price_ratio, margin, shift_rate, swap_fee
                FROM {pool.db_label}
                WHERE timestamp BETWEEN ? AND ?
                ORDER BY timestamp""",
            (ts_start, ts_end),
        ).fetchall()
    finally:
        conn.close()

    if not rows:
        raise ValueError(f"No world history for {pool.db_label}")

    # (field name, column index in the SELECT, relative tolerance).
    # price_ratio drifts continuously via the shift mechanism, so only large
    # discrete jumps (>1% relative change) are flagged as governance.
    # margin, shift_rate, and swap_fee are set by governance and don't drift.
    tracked_fields = (
        ("price_ratio", 4, 0.01),
        ("margin", 5, 1e-6),
        ("shift_rate", 6, 1e-6),
        ("swap_fee", 7, 1e-6),
    )

    def _changed(old, new, rel_tol):
        # NULLs cannot be compared numerically: a NULL <-> value transition
        # counts as a change, NULL -> NULL does not.
        if old is None or new is None:
            return (old is None) != (new is None)
        return not math.isclose(old, new, rel_tol=rel_tol)

    initial_bpt = rows[0][1]
    timestamps = []
    bal_db_0_norm = []
    bal_db_1_norm = []
    bal_db_0_raw = []
    bal_db_1_raw = []
    governance_events = []

    prev = None
    for row in rows:
        ts, bpt, b0, b1 = row[:4]
        timestamps.append(ts)
        norm = initial_bpt / bpt
        bal_db_0_norm.append(b0 * norm)
        bal_db_1_norm.append(b1 * norm)
        bal_db_0_raw.append(b0)
        bal_db_1_raw.append(b1)

        # Detect governance changes relative to the previous row.
        if prev is not None:
            for name, col, rel_tol in tracked_fields:
                if _changed(prev[col], row[col], rel_tol):
                    governance_events.append((ts, name, prev[col], row[col]))
        prev = row

    bal_db_0_norm = np.array(bal_db_0_norm)
    bal_db_1_norm = np.array(bal_db_1_norm)
    bal_db_0_raw = np.array(bal_db_0_raw)
    bal_db_1_raw = np.array(bal_db_1_raw)

    # Apply reverse: swap to quantammsim sorted token order
    if pool.reverse:
        bal_sorted_0, bal_sorted_1 = bal_db_1_norm, bal_db_0_norm
        raw_sorted_0, raw_sorted_1 = bal_db_1_raw, bal_db_0_raw
    else:
        bal_sorted_0, bal_sorted_1 = bal_db_0_norm, bal_db_1_norm
        raw_sorted_0, raw_sorted_1 = bal_db_0_raw, bal_db_1_raw

    return {
        "timestamps": np.array(timestamps),
        "bal_0": bal_sorted_0,
        "bal_1": bal_sorted_1,
        "raw_bal_0": raw_sorted_0,
        "raw_bal_1": raw_sorted_1,
        "governance_events": governance_events,
    }
def get_data_end_date(tokens: list, data_root: str = None) -> str:
    """Find the latest common date across all token parquets.

    Reads the last ``unix`` value (milliseconds) of each
    ``<data_root>/<ticker>_USD.parquet`` and returns the earliest of them.

    Args:
        tokens: Ticker names; must be non-empty.
        data_root: Directory with the parquet files. Defaults to
            ``../quantammsim/data`` relative to this file.

    Returns:
        A UTC date string like '2026-02-18 00:00:00'.

    Raises:
        ValueError: If ``tokens`` is empty.
    """
    if not tokens:
        raise ValueError("tokens must contain at least one ticker")
    if data_root is None:
        data_root = os.path.join(
            os.path.dirname(__file__), "..", "quantammsim", "data"
        )

    last_stamps = []
    for ticker in tokens:
        path = os.path.join(data_root, f"{ticker}_USD.parquet")
        df = pd.read_parquet(path, columns=["unix"])
        last_stamps.append(float(df["unix"].iloc[-1]))
    min_end = min(last_stamps)

    # Convert ms to a UTC datetime. datetime.utcfromtimestamp is deprecated
    # since Python 3.12; the timezone-aware form formats identically.
    dt = datetime.fromtimestamp(min_end / 1000, tz=timezone.utc)
    return dt.strftime("%Y-%m-%d %H:%M:%S")


def load_gas_csv(percentile: str) -> pd.DataFrame:
    """Load a gas percentile CSV as a DataFrame for do_run_on_historic_data.

    Args:
        percentile: Label such as "50p"; the file ``Gas_<percentile>.csv`` in
            ``GAS_CSV_DIR`` is read.

    Returns:
        DataFrame in which the ``USD`` column is renamed to
        ``trade_gas_cost_usd`` and ``unix`` is floored to minute boundaries.
    """
    path = os.path.join(GAS_CSV_DIR, f"Gas_{percentile}.csv")
    df = pd.read_csv(path)
    df = df.rename(columns={"USD": "trade_gas_cost_usd"})
    df["unix"] = (df["unix"] // 60000) * 60000  # floor to minute boundary
    return df


def get_gas_costs(pool: PoolConfig, custom: list = None) -> list:
    """Return the gas cost battery for a pool's chain.

    Args:
        pool: Pool whose ``chain`` selects the battery. Not read when
            ``custom`` is given.
        custom: Optional explicit battery; returned unchanged when not None.

    Returns:
        For Ethereum, flat USD values followed by the gas percentile labels
        (e.g. ``[0.0, 0.1, ..., 10.0, "50p", "75p", "90p", "95p"]``).
        For other chains, a fresh list of flat USD values from
        ``CHAIN_GAS_COSTS``, or ``[0.0, 0.1, 1.0]`` for an unknown chain.
    """
    if custom is not None:
        return custom
    if pool.chain == "ethereum":
        flat = [0.0, 0.1, 0.5, 1.0, 3.0, 5.0, 10.0]
        return flat + ETHEREUM_GAS_PERCENTILES
    # Copy so that callers mutating the result cannot corrupt the shared
    # module-level battery.
    return list(CHAIN_GAS_COSTS.get(pool.chain, [0.0, 0.1, 1.0]))


def print_pool_summary(pool: PoolConfig):
    """Print a summary of the pool's static config and on-chain state.

    The on-chain parameter line is printed only when ``on_chain_params`` is
    populated; the TVL line is printed whenever ``initial_pool_value_usd``
    has been set, including a value of 0.
    """
    print(f"\n{'='*60}")
    print(f"Pool: {pool.label}")
    print(f" Chain: {pool.chain}")
    print(f" Tokens: {pool.tokens[0]}/{pool.tokens[1]}")
    print(f" Swap fee: {pool.swap_fee}")
    print(f" Start: {pool.plausible_start}")
    if pool.on_chain_params:
        p = pool.on_chain_params
        print(f" On-chain: PR={p['price_ratio']:.4f} "
              f"margin={p['margin']} shift_rate={p['shift_rate']} "
              f"fee={p['swap_fee']}")
    if pool.initial_pool_value_usd is not None:
        print(f" TVL: ${pool.initial_pool_value_usd:,.0f} USD")
    print(f" Gas battery: {get_gas_costs(pool)}")
    print(f"{'='*60}")


if __name__ == "__main__":
    # Print summary of all pools
    for label, pool in POOL_REGISTRY.items():
        try:
            extract_on_chain_state(pool)
            print_pool_summary(pool)
        except Exception as e:
            print(f"\n{label}: FAILED — {e}")
+ +Usage: + cd /Users/matthew/Projects/quantammsim-reclamm + source ~/miniconda3/etc/profile.d/conda.sh && conda activate qsim-reclamm + + # Single pool (default gas + arb_freq grid) + python experiments/run_pool_battery.py cbBTC_WETH + + # All pools with data available + python experiments/run_pool_battery.py --all + + # Custom grids + python experiments/run_pool_battery.py cbBTC_WETH --gas-costs 0.0 0.5 1.0 --arb-freqs 1 5 15 60 + + # Dry run (show config without running) + python experiments/run_pool_battery.py cbBTC_WETH --dry-run + + # List available pools + python experiments/run_pool_battery.py --list +""" + +import argparse +import json +import os +import time + +import jax.numpy as jnp +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +from datetime import datetime, timezone + +from experiments.pool_registry import ( + POOL_REGISTRY, + PoolConfig, + extract_initial_state, + extract_on_chain_state, + get_data_end_date, + get_gas_costs, + load_bpt_supply_df, + load_gas_csv, + load_world_history, + print_pool_summary, +) +from quantammsim.runners.jax_runners import do_run_on_historic_data + +PROTOCOL_FEE_SPLIT = 0.5 +DEFAULT_ARB_FREQS = [1, 2, 3, 5, 10, 15, 20] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def sample_at_timestamps(minute_vals, start_unix_sec, timestamps_sec): + """Sample a minute-level array at specific Unix timestamps. + + For each target timestamp, finds the nearest minute index in the + sim output and returns the corresponding value. + """ + indices = np.round((timestamps_sec - start_unix_sec) / 60).astype(int) + indices = np.clip(indices, 0, len(minute_vals) - 1) + return minute_vals[indices] + + +def compute_log_rmse(sim_growth, world_growth): + """RMSE of log(sim/world) across all trajectory points. 
def _start_str_from_pool(pool):
    """Derive the sim start time string from ``pool.plausible_start``.

    ``pool.plausible_start`` is parsed as a ``"%Y-%m-%d"`` date, interpreted
    as UTC, floored to a whole minute and formatted as
    ``"%Y-%m-%d %H:%M:%S"`` (still in UTC).
    """
    start_utc = datetime.strptime(pool.plausible_start, "%Y-%m-%d").replace(
        tzinfo=timezone.utc
    )
    # Floor the epoch seconds to a whole minute.
    ts_minute = (int(start_utc.timestamp()) // 60) * 60
    # Timezone-aware conversion; datetime.utcfromtimestamp() is deprecated
    # since Python 3.12 and returns a naive datetime.
    return datetime.fromtimestamp(ts_minute, tz=timezone.utc).strftime(
        "%Y-%m-%d %H:%M:%S"
    )
+ """ + params = _onchain_params_to_sim(pool) + + # Resolve gas: percentile string → DataFrame, float → scalar + gas_cost_df = None + if isinstance(gas_cost, str): + gas_cost_df = load_gas_csv(gas_cost) + flat_gas = 0.0 # placeholder; gas_cost_df overrides + else: + flat_gas = gas_cost + + fp = { + "tokens": pool.tokens, + "rule": "reclamm", + "startDateString": start, + "endDateString": end, + "initial_pool_value": pool.initial_pool_value_usd, + "fees": pool.swap_fee, + "gas_cost": flat_gas, + "arb_fees": 0.0, + "do_arb": True, + "arb_frequency": arb_frequency, + "chunk_period": 1440, + "weight_interpolation_period": 1440, + "reclamm_use_shift_exponent": True, + "reclamm_interpolation_method": "geometric", + "reclamm_centeredness_scaling": False, + "protocol_fee_split": protocol_fee_split, + "reclamm_initial_state": initial_state, + } + + if noise_config is not None: + fp["noise_model"] = noise_config["noise_model"] + fp["reclamm_noise_params"] = noise_config["reclamm_noise_params"] + + result = do_run_on_historic_data( + run_fingerprint=fp, params=params, lp_supply_df=lp_supply_df, + gas_cost_df=gas_cost_df, + ) + + start_unix_sec = datetime.strptime( + start, "%Y-%m-%d %H:%M:%S" + ).replace(tzinfo=timezone.utc).timestamp() + + value_usd = np.array(result["value"]) + + # Normalize to per-LP value so comparison with BPT-normalized world is valid. + # + # Subtlety: the scan applies lp_supply every arb_frequency minutes. + # Between scan steps, reserves are constant (no arb), so the pool value + # reflects the lp_supply from the LAST scan step. If a BPT event occurs + # between scan steps (e.g. at minute 3 when arb_frequency=5), the value + # at minutes 3-4 still reflects the old lp_supply. We must divide by + # the scan-step-aligned lp_supply, not the current minute's lp_supply, + # otherwise we get transient spikes. 
def run_pool_calibration(pool, gas_costs, arb_freqs, verbose=True,
                         noise_config=None, protocol_fee_split=None):
    """Run 2D gas × arb_frequency calibration for a single pool.

    Runs one simulation per (gas_cost, arb_frequency) pair, then compares
    each simulated growth trajectory with the world trajectory sampled at
    the world timestamps that precede the first governance event.

    Parameters
    ----------
    pool : PoolConfig
        Pool to calibrate; its on-chain state is extracted in place.
    gas_costs : sequence
        Gas cost values forwarded to ``run_sim`` (must be non-empty).
    arb_freqs : sequence
        Arb frequencies in minutes forwarded to ``run_sim`` (must be
        non-empty).
    verbose : bool
        Print progress and pool information.
    noise_config : dict, optional
        If provided, passed through to run_sim to inject noise model.
        Keys: 'noise_model', 'reclamm_noise_params'.
    protocol_fee_split : float, optional
        If provided, forwarded to ``run_sim``; when ``None`` the default of
        ``run_sim`` is used.

    Returns dict with:
        world_growth: array of world growth factors
        sim_growths: {(gas_cost, arb_freq): growth array}
        timestamps: world timestamps (seconds), truncated at governance
        governance_idx: index of the first governance event in the world
            timestamps, or the number of world points when there is none.
            This always equals n_points, so it cannot be used to detect
            truncation; use truncated_at_governance for that.
        n_points: number of comparison points
        truncated_at_governance: True when world points at/after the first
            governance event were dropped
        days: array of days from start
        gas_costs: list of gas costs
        arb_freqs: list of arb frequencies

    Raises
    ------
    ValueError
        If the grid is empty, or if no world point precedes the first
        governance event (growth factors would be undefined).
    """
    # Fail before any expensive work if there is nothing to simulate.
    if len(gas_costs) == 0 or len(arb_freqs) == 0:
        raise ValueError("gas_costs and arb_freqs must both be non-empty")

    # Extract on-chain state + initial reserves
    extract_on_chain_state(pool)
    initial_state = extract_initial_state(pool)

    if verbose:
        print_pool_summary(pool)
        print(f" Initial state: Ra={initial_state['Ra']:.4f}, "
              f"Rb={initial_state['Rb']:.4f}, "
              f"Va={initial_state['Va']:.4f}, Vb={initial_state['Vb']:.4f}")

    start_str = _start_str_from_pool(pool)
    end_str = get_data_end_date(pool.tokens)

    # Load BPT supply history for LP supply scaling
    lp_supply_df = load_bpt_supply_df(pool, end_date=end_str)

    if verbose:
        bpt_start = lp_supply_df["lp_supply"].iloc[0]
        bpt_end = lp_supply_df["lp_supply"].iloc[-1]
        print(f" BPT supply: {bpt_start:.4f} → {bpt_end:.4f} "
              f"({(bpt_end/bpt_start - 1)*100:+.1f}%)")
        print(f" Sim period: {start_str} to {end_str}")
        n_runs = len(gas_costs) * len(arb_freqs)
        print(f" Grid: {len(gas_costs)} gas × {len(arb_freqs)} arb_freq = {n_runs} runs")

    # Load world history (BPT-normalized balances + governance events)
    # BPT-normalized is correct for the growth ratio metric since LP supply
    # cancels out of sim_growth/world_growth. Raw balances are available
    # in world["raw_bal_0/1"] for absolute trajectory comparison.
    world = load_world_history(pool, end_date=end_str)
    world_ts = world["timestamps"]
    world_bal_0 = world["bal_0"]
    world_bal_1 = world["bal_1"]
    gov_events = world["governance_events"]
    n_world = len(world_ts)

    if verbose:
        print(f" World points: {n_world}")
        if gov_events:
            for ts, field, old, new in gov_events:
                # Timezone-aware replacement for the deprecated
                # datetime.utcfromtimestamp().
                dt = datetime.fromtimestamp(ts, tz=timezone.utc).strftime("%Y-%m-%d")
                print(f" Governance: {field} {old:.6f} -> {new:.6f} on {dt}")
        else:
            print(" No governance events")

    # Governance cutoff index
    if gov_events:
        gov_idx = np.searchsorted(world_ts, gov_events[0][0])
    else:
        gov_idx = n_world

    # Number of comparison points after truncating at governance. Checked
    # before the sims run so an unusable pool fails fast instead of crashing
    # on world_value[0] after the whole grid has been simulated.
    n = int(min(gov_idx, n_world))
    if n == 0:
        raise ValueError(
            f"No world points before the first governance event for "
            f"{pool.label}; cannot compute growth trajectories"
        )

    sim_kwargs = {"lp_supply_df": lp_supply_df, "noise_config": noise_config}
    if protocol_fee_split is not None:
        sim_kwargs["protocol_fee_split"] = protocol_fee_split

    # Run sims across the 2D grid
    sim_results = {}
    prices_min = None
    start_sec = None

    for gc in gas_costs:
        for af in arb_freqs:
            if verbose:
                print(f"\n Running gas=${gc}, arb_freq={af}min...")
            t0 = time.time()
            result = run_sim(pool, gc, af, initial_state, start_str, end_str,
                             **sim_kwargs)
            elapsed = time.time() - t0
            if verbose:
                print(f" Done in {elapsed:.1f}s")
            sim_results[(gc, af)] = result
            if prices_min is None:
                # Prices and start time are taken from the first run only.
                prices_min = result["prices"]
                start_sec = result["start_unix_sec"]

    # Truncate at governance
    world_ts_trunc = world_ts[:n]

    # Sample USD prices at world timestamps for world valuation
    prices_at_world = np.stack([
        sample_at_timestamps(prices_min[:, i], start_sec, world_ts_trunc)
        for i in range(prices_min.shape[1])
    ], axis=1)

    # World value in USD = sum(bal_i * price_usd_i)
    world_value = (
        world_bal_0[:n] * prices_at_world[:, 0]
        + world_bal_1[:n] * prices_at_world[:, 1]
    )
    world_growth = world_value / world_value[0]

    # Sim growths at world timestamps
    sim_growths = {}
    for key, result in sim_results.items():
        sim_val = sample_at_timestamps(
            result["value_usd"], start_sec, world_ts_trunc,
        )
        sim_growths[key] = sim_val / sim_val[0]

    days = (world_ts_trunc - world_ts_trunc[0]) / 86400

    return {
        "world_growth": world_growth,
        "sim_growths": sim_growths,
        "timestamps": world_ts_trunc,
        "governance_idx": gov_idx,
        "n_points": n,
        "truncated_at_governance": bool(n < n_world),
        "days": days,
        "gas_costs": list(gas_costs),
        "arb_freqs": list(arb_freqs),
    }
"white" if val > rmse_matrix.max() * 0.6 else "black" + ax_heat.text(j, i, f"{val:.2f}%", ha="center", va="center", + fontsize=8, color=color) + + # Mark cell with least-negative mean bias (closest to 0 from below). + # If no cell is below world on average, fall back to lowest RMSE. + bias_matrix = np.zeros_like(rmse_matrix) + for i, af in enumerate(arb_freqs): + for j, gc in enumerate(gas_costs): + bias_matrix[i, j] = float(np.mean( + np.log(sim_growths[(gc, af)] / world_growth) + )) + negative_mask = bias_matrix < 0 + if negative_mask.any(): + # Among negative cells, find the one closest to 0 (max value) + masked = np.where(negative_mask, bias_matrix, -np.inf) + best_idx = np.unravel_index(np.argmax(masked), masked.shape) + else: + best_idx = np.unravel_index(np.argmin(rmse_matrix), rmse_matrix.shape) + ax_heat.add_patch(plt.Rectangle( + (best_idx[1] - 0.5, best_idx[0] - 0.5), 1, 1, + fill=False, edgecolor="lime", linewidth=3, + )) + + fig.colorbar(im, ax=ax_heat, label="RMSE (%)", shrink=0.8) + + # Right: 4 closest from below + 1 first above world + # Rationale: sim should underestimate (can't capture organic swaps, MEV + # rebates, etc.), so being below world is expected. The one-above config + # brackets where the sim crosses from conservative to optimistic. + ax_ts.axhline(y=0.0, color="brown", linewidth=2, label="world (on-chain)") + + # Classify configs by mean log ratio (trajectory-average bias). + # Using the mean rather than endpoint avoids a curve that's above + # world for 80% of the trajectory being classified as "below" + # just because it dips at the end. 
+ below = [] # (mean_bias, rmse, gc, af) where sim < world on average + above = [] # (mean_bias, rmse, gc, af) where sim >= world on average + for (gc, af), sg in sim_growths.items(): + mean_bias = float(np.mean(np.log(sg / world_growth))) + rmse = compute_log_rmse(sg, world_growth) + if mean_bias < 0: + below.append((mean_bias, rmse, gc, af)) + else: + above.append((mean_bias, rmse, gc, af)) + + # Sort: below by mean_bias descending (closest to 0 first) + below.sort(key=lambda x: x[0], reverse=True) + # Sort: above by mean_bias ascending (closest to 0 first) + above.sort(key=lambda x: x[0]) + + # Select: up to 4 from below, 1 from above, fill if needed + selected = [] + n_below = min(4, len(below)) + n_above = min(1, len(above)) + selected.extend(below[:n_below]) + selected.extend(above[:n_above]) + remaining = 5 - len(selected) + if remaining > 0 and len(below) > n_below: + selected.extend(below[n_below:n_below + remaining]) + remaining = 5 - len(selected) + if remaining > 0 and len(above) > n_above: + selected.extend(above[n_above:n_above + remaining]) + + colors_below = plt.cm.Blues(np.linspace(0.4, 0.8, n_below)) + colors_above = np.array([[0.8, 0.2, 0.2, 1.0]]) # red for above + plot_colors = list(colors_below) + list(colors_above[:n_above]) + # Fill remaining with grey + while len(plot_colors) < len(selected): + plot_colors.append([0.5, 0.5, 0.5, 1.0]) + + for rank, (mean_bias, rmse, gc, af) in enumerate(selected): + dev = (sim_growths[(gc, af)] / world_growth - 1) * 100 + gc_label = f"gas {gc}" if isinstance(gc, str) else f"gas=${gc}" + marker = "\u25b2" if mean_bias >= 0 else "\u25bc" # ▲ above, ▼ below + ax_ts.plot(days, dev, color=plot_colors[rank], linewidth=2, + label=f"{marker} {gc_label}, arb={af}min " + f"bias={mean_bias*100:+.2f}% RMSE={rmse*100:.2f}%") + + ax_ts.set_xlabel("days") + ax_ts.set_ylabel("% deviation from world") + trunc = " (pre-governance)" if calibration["governance_idx"] < calibration["n_points"] + 1 else "" + 
def plot_cross_pool_summary(all_results, output_dir="results"):
    """Plot cross-pool comparison: best (gas, arb_freq) and residual deviation.

    For every ``(pool, calibration)`` pair the "best" config is the one whose
    mean log ratio against world growth is negative and closest to zero; if
    no config is below world on average, the config with the lowest
    trajectory RMSE is used instead. The left panel shows the RMSE at that
    config, the right panel its mean bias.

    Parameters
    ----------
    all_results : iterable of (pool, calibration)
        ``calibration`` must provide ``"world_growth"`` and ``"sim_growths"``.
    output_dir : str
        Directory the PNG is written to (created if missing).

    Returns
    -------
    str
        Path of the saved figure.
    """
    os.makedirs(output_dir, exist_ok=True)

    pool_labels = []
    best_configs = []
    best_rmses = []
    best_biases = []
    for pool, cal in all_results:
        wg = cal["world_growth"]
        sim_growths = cal["sim_growths"]
        # Mean log bias per config, computed once and reused for filtering,
        # ranking and reporting.
        biases = {
            key: float(np.mean(np.log(growth / wg)))
            for key, growth in sim_growths.items()
        }
        # Best = least-negative mean bias (closest from below)
        below_keys = [key for key, bias in biases.items() if bias < 0]
        if below_keys:
            best_key = max(below_keys, key=biases.__getitem__)
        else:
            best_key = min(
                sim_growths.keys(),
                key=lambda k: compute_log_rmse(sim_growths[k], wg),
            )
        pool_labels.append(f"{pool.label}\n({pool.chain})")
        best_configs.append(best_key)
        best_rmses.append(compute_log_rmse(sim_growths[best_key], wg) * 100)
        best_biases.append(biases[best_key] * 100)

    fig, (ax_cfg, ax_dev) = plt.subplots(1, 2, figsize=(16, 6))

    x = np.arange(len(pool_labels))

    # Left: RMSE at best config
    config_strs = [
        f"gas {gc}\narb={af}min" if isinstance(gc, str)
        else f"gas=${gc}\narb={af}min"
        for gc, af in best_configs
    ]
    ax_cfg.barh(x, best_rmses, color="steelblue")
    ax_cfg.set_yticks(x)
    ax_cfg.set_yticklabels(pool_labels, fontsize=9)
    ax_cfg.set_xlabel("Trajectory RMSE (%)")
    ax_cfg.set_title("RMSE at best config")
    for i, (cs, rmse) in enumerate(zip(config_strs, best_rmses)):
        ax_cfg.text(rmse + 0.05, i, f"{cs} (RMSE={rmse:.2f}%)", va="center", fontsize=8)
    ax_cfg.grid(True, alpha=0.2, axis="x")

    # Right: mean bias at best config (negative = conservative)
    colors = ["green" if d < 0 else "orange" if d < 1 else "red"
              for d in best_biases]
    ax_dev.bar(x, best_biases, color=colors)
    ax_dev.axhline(y=0, color="brown", linewidth=1)
    ax_dev.set_xticks(x)
    ax_dev.set_xticklabels(pool_labels, fontsize=8)
    ax_dev.set_ylabel("Mean bias (%)")
    ax_dev.set_title("Mean trajectory bias at best config")
    ax_dev.grid(True, alpha=0.2, axis="y")

    fig.suptitle("Cross-pool gas + arb frequency calibration", fontsize=12)
    fig.tight_layout()

    out = os.path.join(output_dir, "gas_calibration_cross_pool.png")
    # Save and close this figure explicitly rather than pyplot's current one.
    fig.savefig(out, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"\nSaved cross-pool summary: {out}")
    return out
def check_data_available(pool, data_root=None):
    """Return True when every required price parquet file exists for ``pool``.

    One ``<ticker>_USD.parquet`` file is expected under ``data_root`` for each
    entry of ``pool.tokens`` except ``"USDC"``, which is skipped. When
    ``data_root`` is ``None`` it defaults to ``../quantammsim/data`` relative
    to this file's directory.
    """
    root = data_root
    if root is None:
        here = os.path.dirname(__file__)
        root = os.path.join(here, "..", "quantammsim", "data")

    # all() short-circuits on the first missing file.
    return all(
        os.path.exists(os.path.join(root, f"{ticker}_USD.parquet"))
        for ticker in pool.tokens
        if ticker != "USDC"
    )
cbBTC_WETH)") + parser.add_argument("--all", action="store_true", + help="Run all pools with available data") + parser.add_argument("--list", action="store_true", + help="List available pools and exit") + parser.add_argument("--dry-run", action="store_true", + help="Show pool config without running") + parser.add_argument("--gas-costs", nargs="+", type=float, default=None, + help="Override gas cost battery (flat USD values)") + parser.add_argument("--arb-freqs", nargs="+", type=int, + default=DEFAULT_ARB_FREQS, + help="Arb frequency values in minutes (default: 1 2 3 5 10 15 20)") + parser.add_argument("--protocol-fee", type=float, default=PROTOCOL_FEE_SPLIT, + help="Protocol fee split (default 0.5)") + parser.add_argument("--output-dir", default="results", + help="Directory for output plots and JSON") + parser.add_argument("--calibrate-noise", action="store_true", + help="Calibrate Tsoukalas noise model from Balancer API + DB " + "and inject into sim fingerprints") + parser.add_argument("--noise-model", choices=["sqrt", "log", "loglinear"], + default="sqrt", + help="Noise model variant (default: sqrt)") + parser.add_argument("--noise-params-json", default=None, + help="Path to hierarchical noise params JSON " + "(from calibrate_noise_hierarchical.py). 
" + "Looks up pool by address or uses --predict.") + args = parser.parse_args() + + # --list mode + if args.list: + print("\nAvailable pools:\n") + for label, pool in POOL_REGISTRY.items(): + has_data = check_data_available(pool) + status = "READY" if has_data else "MISSING DATA" + try: + extract_on_chain_state(pool) + print_pool_summary(pool) + except Exception as e: + print(f" {label}: {e}") + print(f" Data: {status}") + return + + # Determine which pools to run + if args.all: + pool_labels = [ + label for label, pool in POOL_REGISTRY.items() + if check_data_available(pool) + ] + if not pool_labels: + print("No pools have all required data files.") + return + elif args.pool: + if args.pool not in POOL_REGISTRY: + print(f"Unknown pool: {args.pool}") + print(f"Available: {list(POOL_REGISTRY.keys())}") + return + if not check_data_available(POOL_REGISTRY[args.pool]): + missing = [ + f"{t}_USD.parquet" for t in POOL_REGISTRY[args.pool].tokens + if t != "USDC" and not os.path.exists( + os.path.join( + os.path.dirname(__file__), "..", "quantammsim", + "data", f"{t}_USD.parquet" + ) + ) + ] + print(f"Missing data for {args.pool}: {missing}") + return + pool_labels = [args.pool] + else: + parser.print_help() + return + + # Collect runs + runs = [] + for label in pool_labels: + pool = POOL_REGISTRY[label] + gas_costs = get_gas_costs(pool, args.gas_costs) + runs.append((pool, gas_costs)) + + arb_freqs = args.arb_freqs + n_total = sum(len(gcs) * len(arb_freqs) for _, gcs in runs) + + print(f"\n{'='*60}") + print(f"GAS + ARB CALIBRATION: {len(runs)} pool(s), {n_total} total runs") + for pool, gcs in runs: + print(f" {pool.label:15s} ({pool.chain:10s}) " + f"gas={gcs} arb_freq={arb_freqs}") + print(f" Protocol fee split: {args.protocol_fee}") + print(f"{'='*60}") + + if args.dry_run: + print("\n--- DRY RUN ---\n") + for pool, gas_costs in runs: + extract_on_chain_state(pool) + initial_state = extract_initial_state(pool) + print_pool_summary(pool) + print(f" Initial state: 
{initial_state}") + start = _start_str_from_pool(pool) + end = get_data_end_date(pool.tokens) + print(f" Sim period: {start} to {end}") + world = load_world_history(pool, end_date=end) + n_gov = len(world["governance_events"]) + print(f" World points: {len(world['timestamps'])}, " + f"governance events: {n_gov}") + if world["governance_events"]: + for ts, field, old, new in world["governance_events"]: + dt = datetime.utcfromtimestamp(ts).strftime("%Y-%m-%d") + print(f" {field} {old:.6f} -> {new:.6f} on {dt}") + print(f" Gas battery: {gas_costs}") + print(f" Arb freqs: {arb_freqs}") + n_runs = len(gas_costs) * len(arb_freqs) + print(f" Total runs: {n_runs}") + return + + # Execute calibration for each pool + all_results = [] + for pool, gas_costs in runs: + print(f"\n{'#'*60}") + print(f"POOL: {pool.label}") + print(f"{'#'*60}") + + # Noise calibration (if requested) + noise_config = None + if args.noise_params_json and args.noise_model == "loglinear": + # Load hierarchical noise params from JSON + with open(args.noise_params_json) as f: + hier_data = json.load(f) + # Look up pool by address + addr = pool.pool_address.lower() + pool_params = None + for p in hier_data["pools"]: + pid = p["pool_id"].lower().replace("0x", "") + if pid.startswith(addr) or addr.startswith(pid): + pool_params = p["noise_params"] + break + if pool_params is None: + # Fall back to population-level prediction + from scripts.calibrate_noise_hierarchical import predict_new_pool + chain_map = {"ethereum": "MAINNET", "base": "BASE", + "gnosis": "GNOSIS", "arbitrum": "ARBITRUM", + "polygon": "POLYGON", "optimism": "OPTIMISM", + "sonic": "SONIC", "avalanche": "AVALANCHE"} + api_chain = chain_map.get(pool.chain, pool.chain.upper()) + # Reconstruct posteriors + encoding from the JSON + posteriors_from_json = { + "Phi_mean": np.array(hier_data["Phi"]), + } + encoding_from_json = { + "covariate_names": hier_data["covariate_names"], + } + pool_params = predict_new_pool( + posteriors_from_json, 
encoding_from_json, + api_chain, pool.tokens, pool.swap_fee, + ) + print(f"\n Using population-level loglinear params (pool not in JSON)") + else: + print(f"\n Using hierarchical loglinear params for {pool.label}") + print(f" b_0 = {pool_params['b_0']:.4f}, " + f"b_sigma = {pool_params['b_sigma']:.6f}, " + f"b_c = {pool_params['b_c']:.4f}") + # Strip metadata keys (prefixed with _) — JAX can't trace strings + sim_params = {k: v for k, v in pool_params.items() + if not k.startswith("_")} + noise_config = { + "noise_model": "loglinear", + "reclamm_noise_params": sim_params, + } + elif args.calibrate_noise: + from scripts.calibrate_reclamm_noise import ( + build_calibration_df, + run_ols_calibration, + ) + noise_model_name = ( + "tsoukalas_sqrt" if args.noise_model == "sqrt" + else "tsoukalas_log" + ) + print(f"\n Calibrating noise model ({args.noise_model})...") + cal_df = build_calibration_df(pool) + noise_params, diag = run_ols_calibration( + cal_df, pool.swap_fee, args.noise_model, + ) + print(f" R² = {diag['r_squared']:.4f}, n = {diag['n_obs']}") + for key in ["a_0", "a_sigma", "a_c"]: + param_key = "a_0_base" if key == "a_0" else key + val = noise_params[param_key] + se = diag["se"][key] + t_stat = val / se if se > 0 else float("inf") + print(f" {key:>8} = {val:>10.4f} (SE={se:.4f}, t={t_stat:.2f})") + noise_config = { + "noise_model": noise_model_name, + "reclamm_noise_params": noise_params, + } + + t0 = time.time() + calibration = run_pool_calibration( + pool, gas_costs, arb_freqs, noise_config=noise_config, + ) + elapsed = time.time() - t0 + + print_calibration_summary(pool, calibration) + plot_pool_calibration(pool, calibration, output_dir=args.output_dir) + all_results.append((pool, calibration)) + + print(f"\n Total time for {pool.label}: {elapsed:.0f}s") + + # Cross-pool summary + if len(all_results) > 1: + plot_cross_pool_summary(all_results, output_dir=args.output_dir) + + # Final summary table + print(f"\n{'='*60}") + print("CALIBRATION COMPLETE") + 
print(f"{'='*60}") + print(f"\n{'Pool':<16} {'Chain':<10} {'Best Gas':>9} {'Best Arb':>9} " + f"{'Bias':>8} {'RMSE':>8} {'Days':>6}") + print("-" * 70) + for pool, cal in all_results: + wg = cal["world_growth"] + # Best = least-negative mean bias (closest from below) + below_keys = [ + k for k in cal["sim_growths"] + if np.mean(np.log(cal["sim_growths"][k] / wg)) < 0 + ] + if below_keys: + best_key = max( + below_keys, + key=lambda k: np.mean(np.log(cal["sim_growths"][k] / wg)), + ) + else: + best_key = min( + cal["sim_growths"].keys(), + key=lambda k: compute_log_rmse(cal["sim_growths"][k], wg), + ) + best_bias = float(np.mean(np.log(cal["sim_growths"][best_key] / wg))) * 100 + best_rmse = compute_log_rmse(cal["sim_growths"][best_key], wg) * 100 + n_days = cal["days"][-1] + gc_label = f"gas {best_key[0]}" if isinstance(best_key[0], str) else f"${best_key[0]}" + print(f"{pool.label:<16} {pool.chain:<10} {gc_label:<9} " + f"{best_key[1]:>4}min {best_bias:>+7.2f}% {best_rmse:>7.2f}% {n_days:>5.0f}d") + + # Save JSON summary + os.makedirs(args.output_dir, exist_ok=True) + summary = [] + for pool, cal in all_results: + wg_arr = cal["world_growth"] + wg_final = float(wg_arr[-1]) + pool_summary = { + "label": pool.label, + "chain": pool.chain, + "tokens": pool.tokens, + "swap_fee": pool.swap_fee, + "tvl_usd": pool.initial_pool_value_usd, + "on_chain_params": pool.on_chain_params, + "n_days": float(cal["days"][-1]), + "n_governance_events": 1 if cal["governance_idx"] < cal["n_points"] else 0, + "world_growth": wg_final, + "grid_results": {}, + } + for (gc, af) in sorted(cal["sim_growths"].keys(), key=lambda k: (str(k[0]), k[1])): + sg_arr = cal["sim_growths"][(gc, af)] + rmse = compute_log_rmse(sg_arr, wg_arr) * 100 + pool_summary["grid_results"][f"gas={gc}_arb={af}"] = { + "gas_cost": gc, + "arb_frequency": af, + "sim_growth": float(sg_arr[-1]), + "pct_deviation": float((sg_arr[-1] / wg_final - 1) * 100), + "trajectory_rmse_pct": float(rmse), + } + 
summary.append(pool_summary) + + ts = datetime.now().strftime("%Y%m%d_%H%M%S") + json_path = os.path.join(args.output_dir, f"gas_calibration_{ts}.json") + with open(json_path, "w") as f: + json.dump(summary, f, indent=2, default=str) + print(f"\nSummary saved to {json_path}") + + +if __name__ == "__main__": + main() diff --git a/experiments/tune_reclamm_params.py b/experiments/tune_reclamm_params.py index 0951e2c..4969923 100644 --- a/experiments/tune_reclamm_params.py +++ b/experiments/tune_reclamm_params.py @@ -38,6 +38,17 @@ def main(): parser.add_argument("--interpolation", default="geometric", choices=["geometric", "constant_arc_length"]) parser.add_argument("--centeredness-scaling", action="store_true") + parser.add_argument("--noise-trader-ratio", type=float, default=0.0) + parser.add_argument("--start-date", default="2024-06-01 00:00:00") + parser.add_argument("--end-date", default="2025-01-01 00:00:00", + help="End of training / start of test") + parser.add_argument("--end-test-date", default="2025-06-01 00:00:00") + parser.add_argument("--bout-offset", type=int, default=None, + help="bout_offset in minutes (default: 10080 = 7 days)") + parser.add_argument("--val-fraction", type=float, default=None, + help="Validation holdout fraction (default: 0.2, use 0 to disable)") + parser.add_argument("--overfitting-penalty", type=float, default=None, + help="Overfitting penalty weight (default: 0.2)") args = parser.parse_args() learn_speed = args.interpolation == "constant_arc_length" @@ -48,29 +59,33 @@ def main(): fp = { "rule": "reclamm", "tokens": ["AAVE", "ETH"], - "startDateString": "2024-06-01 00:00:00", - "endDateString": "2025-01-01 00:00:00", - "endTestDateString": "2025-06-01 00:00:00", + "startDateString": args.start_date, + "endDateString": args.end_date, + "endTestDateString": args.end_test_date, "initial_pool_value": 1_000_000.0, "do_arb": True, "fees": args.fees, "gas_cost": args.gas_cost, "arb_fees": 0.0, "protocol_fee_split": 0.5, + 
"noise_trader_ratio": args.noise_trader_ratio, "return_val": args.objective, "reclamm_interpolation_method": args.interpolation, "reclamm_centeredness_scaling": args.centeredness_scaling, "reclamm_learn_arc_length_speed": learn_speed, "reclamm_use_shift_exponent": True, + **({"bout_offset": args.bout_offset} if args.bout_offset is not None else {}), "optimisation_settings": { "method": "optuna", "n_parameter_sets": 1, + **({"val_fraction": args.val_fraction} if args.val_fraction is not None else {}), "optuna_settings": { "make_scalar": True, "expand_around": False, "n_trials": args.n_trials, "multi_objective": False, "parameter_config": param_config, + **({"overfitting_penalty": args.overfitting_penalty} if args.overfitting_penalty is not None else {}), }, }, } diff --git a/scripts/benchmark_reclamm_interpolation.py b/scripts/benchmark_reclamm_interpolation.py new file mode 100644 index 0000000..f462bbf --- /dev/null +++ b/scripts/benchmark_reclamm_interpolation.py @@ -0,0 +1,648 @@ +"""Benchmark reClAMM range shift interpolation: current vs optimal midpoint. + +Compares total arb loss during a range shift under different interpolation methods: + Geometric VB -- exponential decay of overvalued virtual (what contracts do) + Linear VB -- uniform steps in VB + Linear Z -- uniform steps in Z = sqrt(P)*VA - VB/sqrt(P) (optimal, from note) + Optimal 2-step -- exact midpoint via quadratic formula (Section 5 of note) + Brute-force optimal -- JAX gradient-optimised Z-target sequence + +Key result: per-step loss ~ (DeltaZ)^2 / (4X). Equal Z-increments minimise +total loss, analogous to TFMM optimal intermediate for G3M weight changes. 
+ +Usage: + cd + source ~/miniconda3/etc/profile.d/conda.sh && conda activate qsim-reclamm + python scripts/benchmark_reclamm_interpolation.py +""" + +import numpy as np +import matplotlib.pyplot as plt + +import jax +import jax.numpy as jnp +from scipy.optimize import minimize as scipy_minimize + +jax.config.update("jax_enable_x64", True) + + +# ── Core reClAMM mechanics ───────────────────────────────────────────────── + + +def compute_VA_from_VB(RA, RB, VB, Q): + """Contract rule (eq 15): VA = RA*(VB + RB) / ((Q-1)*VB - RB).""" + return RA * (VB + RB) / ((Q - 1) * VB - RB) + + +def compute_Z(VA, VB, P): + """Z = sqrt(P)*VA - VB/sqrt(P) (eq 12).""" + sqP = np.sqrt(P) + return sqP * VA - VB / sqP + + +def pool_value(RA, RB, P): + """Real pool value: P*RA + RB (eq 3).""" + return P * RA + RB + + +def micro_step(RA, RB, VA_new, VB_new, P): + """Virtual-balance update then arb to equilibrium Y/X = P. + + Returns (RA_new, RB_new, arb_loss). + """ + val_before = pool_value(RA, RB, P) + X = RA + VA_new + Y = RB + VB_new + L = X * Y + X_eq = np.sqrt(L / P) + Y_eq = P * X_eq + RA_new = X_eq - VA_new + RB_new = Y_eq - VB_new + return RA_new, RB_new, val_before - pool_value(RA_new, RB_new, P) + + +def solve_VB_for_Z(RA, RB, Z_star, Q, P): + """Solve quadratic for VB achieving Z(VB) = Z_star. + + Derived by substituting VA = RA*(VB+RB)/((Q-1)*VB-RB) into + Z = sqrt(P)*VA - VB/sqrt(P), then collecting terms in VB. + + NOTE: The research note (eq 28) has a sign error: the RB/sqrt(P) + term in b should be positive, not negative. Re-derived here from + scratch. + + Returns the physically valid root (VB > RB/(Q-1), positive). + Raises ValueError if no valid root exists. 
+ """ + sqP = np.sqrt(P) + a = -(Q - 1) / sqP + b = sqP * RA + RB / sqP - (Q - 1) * Z_star # +RB/sqP, not minus + c = sqP * RA * RB + Z_star * RB + disc = b * b - 4 * a * c + if disc < -1e-6: + raise ValueError(f"negative discriminant: {disc:.4e}") + disc = max(disc, 0.0) + sd = np.sqrt(disc) + r1, r2 = (-b + sd) / (2 * a), (-b - sd) / (2 * a) + floor = RB / (Q - 1) + 1e-12 + ok = [r for r in (r1, r2) if r > floor] + if not ok: + raise ValueError(f"no valid root: r1={r1:.4f}, r2={r2:.4f}, floor={floor:.4f}") + return min(ok) + + +# ── Interpolation methods ────────────────────────────────────────────────── + + +def run_shift(RA, RB, VA_stale, VB_start, VB_end, Q, P, N, schedule): + """Execute N-step range shift (B overvalued, VB decreasing). + + schedule: "geometric" | "linear_VB" | "linear_Z" + + VA_stale: the current (possibly stale) VA -- used only for Z_start + in the linear_Z schedule. All micro-steps compute VA from the + contract rule with current reserves. + """ + # For linear_Z, precompute Z endpoints using contract-rule VA + if schedule == "linear_Z": + VA_start_cr = compute_VA_from_VB(RA, RB, VB_start, Q) + Z0 = compute_Z(VA_start_cr, VB_start, P) + VA_end_approx = compute_VA_from_VB(RA, RB, VB_end, Q) + Z_end = compute_Z(VA_end_approx, VB_end, P) + + total_loss = 0.0 + RA_c, RB_c = RA, RB + + for i in range(1, N + 1): + frac = i / N + if schedule == "geometric": + VB_i = VB_start * (VB_end / VB_start) ** frac + elif schedule == "linear_VB": + VB_i = VB_start + frac * (VB_end - VB_start) + elif schedule == "linear_Z": + Z_i = Z0 + frac * (Z_end - Z0) + VB_i = solve_VB_for_Z(RA_c, RB_c, Z_i, Q, P) + else: + raise ValueError(schedule) + + VA_i = compute_VA_from_VB(RA_c, RB_c, VB_i, Q) + RA_c, RB_c, loss = micro_step(RA_c, RB_c, VA_i, VB_i, P) + total_loss += loss + + return total_loss, RA_c, RB_c + + +def run_shift_optimal_2step(RA, RB, VA_stale, VB_start, VB_end, Q, P): + """Exact 2-step optimal midpoint (Section 5 of the note). 
+ + Computes Z* = (Z_start + Z_end) / 2, solves quadratic for VB_mid. + """ + VA_start_cr = compute_VA_from_VB(RA, RB, VB_start, Q) + Z0 = compute_Z(VA_start_cr, VB_start, P) + VA_end_approx = compute_VA_from_VB(RA, RB, VB_end, Q) + Z2 = compute_Z(VA_end_approx, VB_end, P) + Z_star = (Z0 + Z2) / 2.0 + + # Step 1: jump to Z-midpoint + VB_mid = solve_VB_for_Z(RA, RB, Z_star, Q, P) + VA_mid = compute_VA_from_VB(RA, RB, VB_mid, Q) + RA1, RB1, loss1 = micro_step(RA, RB, VA_mid, VB_mid, P) + + # Step 2: jump to endpoint + VA_end = compute_VA_from_VB(RA1, RB1, VB_end, Q) + RA2, RB2, loss2 = micro_step(RA1, RB1, VA_end, VB_end, P) + + return loss1 + loss2, RA2, RB2 + + +# ── Scenario setup ───────────────────────────────────────────────────────── + + +def setup_centered_pool(P, price_ratio, R_scale=10000.0): + """Centered pool at price P with contract-rule-consistent virtuals. + + Returns (RA, RB, VA, VB, Q). + """ + Q = np.sqrt(price_ratio) + q4 = price_ratio ** 0.25 + + RA = R_scale + RB = P * R_scale + VA = RA / (q4 - 1) + VB = RB / (q4 - 1) + + return RA, RB, VA, VB, Q + + +def setup_decentered_pool(P_init, P_final, price_ratio, R_scale=10000.0): + """Centered pool at P_init, arb to P_final, then refresh virtuals. + + The refresh applies the contract rule to get consistent (VA, VB) at + the post-arb reserves, then arbs once more. This gives a decentered + but fully consistent state (equilibrium + contract rule). + + Returns (RA, RB, VA, VB, Q). 
+ """ + Q = np.sqrt(price_ratio) + q4 = price_ratio ** 0.25 + + RA0 = R_scale + RB0 = P_init * R_scale + VA0 = RA0 / (q4 - 1) + VB0 = RB0 / (q4 - 1) + + # Arb to P_final (L preserved, virtuals stale) + X0 = RA0 + VA0 + Y0 = RB0 + VB0 + L = X0 * Y0 + X_new = np.sqrt(L / P_final) + Y_new = np.sqrt(L * P_final) + RA = X_new - VA0 + RB = Y_new - VB0 + + # Refresh: apply contract rule for current VB, then arb + VB = VB0 + VA = compute_VA_from_VB(RA, RB, VB, Q) + RA, RB, _ = micro_step(RA, RB, VA, VB, P_final) + + return RA, RB, VA, VB, Q + + +# ── JAX-differentiable versions for brute-force optimisation ────────────── + + +def _compute_VA_from_VB_jax(RA, RB, VB, Q): + return RA * (VB + RB) / ((Q - 1) * VB - RB) + + +def _compute_Z_jax(VA, VB, P): + sqP = jnp.sqrt(P) + return sqP * VA - VB / sqP + + +def _pool_value_jax(RA, RB, P): + return P * RA + RB + + +def _micro_step_jax(RA, RB, VA, VB, P): + val_before = _pool_value_jax(RA, RB, P) + X = RA + VA + Y = RB + VB + L = X * Y + X_eq = jnp.sqrt(L / P) + Y_eq = P * X_eq + RA_new = X_eq - VA + RB_new = Y_eq - VB + return RA_new, RB_new, val_before - _pool_value_jax(RA_new, RB_new, P) + + +def _solve_VB_for_Z_jax(RA, RB, Z_star, Q, P): + sqP = jnp.sqrt(P) + a = -(Q - 1) / sqP + b = sqP * RA + RB / sqP - (Q - 1) * Z_star + c = sqP * RA * RB + Z_star * RB + disc = jnp.maximum(b * b - 4 * a * c, 1e-30) + sd = jnp.sqrt(disc) + r1 = (-b + sd) / (2 * a) + r2 = (-b - sd) / (2 * a) + floor = RB / (Q - 1) + 1e-8 + return jnp.where(r2 > floor, r2, r1) + + +def _z_targets_from_raw(raw_params, Z_start, Z_end): + """Map unconstrained params -> sorted Z targets via softplus gaps.""" + gaps = jax.nn.softplus(raw_params) + gaps = gaps / jnp.sum(gaps) * (Z_end - Z_start) + return Z_start + jnp.cumsum(gaps) + + +def _make_loss_fn(N): + """Build a JIT-compiled loss function for a given N (unrolled loop).""" + + def total_loss(raw_params, RA, RB, Q, P, Z_start, Z_end): + Z_all = _z_targets_from_raw(raw_params, Z_start, Z_end) + RA_c, RB_c = 
RA, RB + total = 0.0 + for i in range(N): + VB_i = _solve_VB_for_Z_jax(RA_c, RB_c, Z_all[i], Q, P) + VA_i = _compute_VA_from_VB_jax(RA_c, RB_c, VB_i, Q) + RA_c, RB_c, loss = _micro_step_jax(RA_c, RB_c, VA_i, VB_i, P) + total = total + loss + return total + + return jax.jit(jax.value_and_grad(total_loss)) + + +def optimise_z_targets(RA, RB, Q, P, Z_start, Z_end, N, verbose=False): + """Find the Z-target sequence minimising total arb loss. + + Returns (optimal_loss, optimal_Z_targets_array_of_length_N). + """ + loss_and_grad_fn = _make_loss_fn(N) + RA_j = jnp.float64(RA) + RB_j = jnp.float64(RB) + Q_j = jnp.float64(Q) + P_j = jnp.float64(P) + Zs_j = jnp.float64(Z_start) + Ze_j = jnp.float64(Z_end) + + def objective(x): + val, grad = loss_and_grad_fn( + jnp.array(x, dtype=jnp.float64), RA_j, RB_j, Q_j, P_j, Zs_j, Ze_j + ) + return float(val), np.array(grad, dtype=np.float64) + + x0 = np.zeros(N) # softplus(0) = ln2, uniform gaps → linear Z init + result = scipy_minimize(objective, x0, jac=True, method="L-BFGS-B") + + optimal_Z = np.array( + _z_targets_from_raw(jnp.array(result.x), Zs_j, Ze_j) + ) + if verbose: + print(f" N={N}: loss={result.fun:.6f} " + f"nit={result.nit} success={result.success}") + return result.fun, optimal_Z + + +# ── Experiments ──────────────────────────────────────────────────────────── + + +def main(): + # --- Scenario: centered pool, moderate VB decay --- + P = 2.0 # token A costs 2 units of token B + price_ratio = 4.0 # rho, so Q = sqrt(4) = 2 + R_scale = 10000.0 + decay_fraction = 0.90 # VB_end = 0.90 * VB_start (10% decay) + + RA, RB, VA, VB, Q = setup_centered_pool(P, price_ratio, R_scale) + VB_start = VB + VB_end = VB * decay_fraction + + # Diagnostics + C = min(RA * VB, RB * VA) / max(RA * VB, RB * VA) + is_above = RA * VB > RB * VA + X = RA + VA + print("=" * 72) + print(f"Scenario: centered pool at P={P}, price_ratio={price_ratio}, Q={Q:.4f}") + print(f" RA={RA:.2f} RB={RB:.2f} VA={VA:.2f} VB={VB:.2f}") + print(f" Effective X={X:.2f} 
Pool value = {pool_value(RA, RB, P):.2f}") + print(f" Centeredness = {C:.4f} is_above = {is_above}") + print(f" VB shift: {VB_start:.2f} -> {VB_end:.2f} ({decay_fraction:.0%})") + VB_floor = RB / (Q - 1) + print(f" VB floor (denominator > 0): {VB_floor:.2f}") + Z_start = compute_Z(VA, VB, P) + VA_end_cr = compute_VA_from_VB(RA, RB, VB_end, Q) + Z_end = compute_Z(VA_end_cr, VB_end, P) + print(f" Z_start = {Z_start:.4f} Z_end = {Z_end:.4f}") + print(f" Approx 1-step loss ~ (DeltaZ)^2/(4X) = {(Z_end-Z_start)**2/(4*X):.2f}") + print("=" * 72) + + # ── Experiment 1: Loss vs N ──────────────────────────────────────── + + N_values = [1, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48, 64, 96, 128] + schedules = ["geometric", "linear_VB", "linear_Z"] + results = {s: [] for s in schedules} + + for N in N_values: + for sched in schedules: + try: + loss, _, _ = run_shift( + RA, RB, VA, VB_start, VB_end, Q, P, N, sched + ) + except (ValueError, AssertionError) as e: + loss = np.nan + results[sched].append(loss) + + # Optimal 2-step (single point) + try: + loss_opt2, _, _ = run_shift_optimal_2step( + RA, RB, VA, VB_start, VB_end, Q, P + ) + except (ValueError, AssertionError): + loss_opt2 = np.nan + + # Table + loss_1 = results["geometric"][0] + print(f"\n{'N':>5s} {'Geo VB':>12s} {'Lin VB':>12s} {'Lin Z':>12s}" + f" {'Geo/1step':>9s} {'LinZ/1step':>10s} {'LinZ/Geo':>9s}") + print("-" * 80) + for j, N in enumerate(N_values): + g = results["geometric"][j] + lv = results["linear_VB"][j] + lz = results["linear_Z"][j] + print(f"{N:>5d} {g:>12.6f} {lv:>12.6f} {lz:>12.6f}" + f" {g / loss_1:>9.4f} {lz / loss_1:>10.4f} {lz / g:>9.4f}") + + print(f"\n Optimal 2-step loss: {loss_opt2:.6f}") + print(f" Geometric N=2 loss: {results['geometric'][1]:.6f}" + f" (opt/geo = {loss_opt2 / results['geometric'][1]:.4f})") + print(f" Linear Z N=2 loss: {results['linear_Z'][1]:.6f}" + f" (opt/linZ = {loss_opt2 / results['linear_Z'][1]:.4f})") + + # ── Experiment 2: Z and VB trajectories at N=8 
───────────────────── + + N_viz = 8 + traj_data = {} + for sched in schedules: + VB_traj, Z_traj, loss_traj = [VB_start], [], [] + VA_s = VA # stale + Z_traj.append(compute_Z(VA_s, VB_start, P)) + + RA_c, RB_c = RA, RB + if sched == "linear_Z": + Z0 = Z_traj[0] + VA_end_a = compute_VA_from_VB(RA, RB, VB_end, Q) + Z_end_val = compute_Z(VA_end_a, VB_end, P) + + for i in range(1, N_viz + 1): + frac = i / N_viz + if sched == "geometric": + VB_i = VB_start * (VB_end / VB_start) ** frac + elif sched == "linear_VB": + VB_i = VB_start + frac * (VB_end - VB_start) + else: + Z_i = Z0 + frac * (Z_end_val - Z0) + VB_i = solve_VB_for_Z(RA_c, RB_c, Z_i, Q, P) + + try: + VA_i = compute_VA_from_VB(RA_c, RB_c, VB_i, Q) + VB_traj.append(VB_i) + Z_traj.append(compute_Z(VA_i, VB_i, P)) + RA_c, RB_c, loss = micro_step(RA_c, RB_c, VA_i, VB_i, P) + loss_traj.append(loss) + except (ValueError, AssertionError): + break + + traj_data[sched] = { + "VB": np.array(VB_traj), + "Z": np.array(Z_traj), + "loss": np.array(loss_traj), + } + + # ── Experiment 3: sweep shift size at N=2 ────────────────────────── + + decay_sweep = np.linspace(0.80, 0.99, 30) + sweep = {s: [] for s in ["geometric", "linear_Z", "optimal_2step"]} + for df in decay_sweep: + VB_e = VB * df + try: + g, _, _ = run_shift(RA, RB, VA, VB_start, VB_e, Q, P, 2, "geometric") + lz, _, _ = run_shift(RA, RB, VA, VB_start, VB_e, Q, P, 2, "linear_Z") + o2, _, _ = run_shift_optimal_2step(RA, RB, VA, VB_start, VB_e, Q, P) + except (AssertionError, ValueError): + g = lz = o2 = np.nan + sweep["geometric"].append(g) + sweep["linear_Z"].append(lz) + sweep["optimal_2step"].append(o2) + + # ── Plots ────────────────────────────────────────────────────────── + + colours = {"geometric": "C0", "linear_VB": "C1", "linear_Z": "C2"} + labels = { + "geometric": "Geometric VB (contract)", + "linear_VB": "Linear VB", + "linear_Z": "Linear Z (optimal)", + } + + fig, axes = plt.subplots(2, 2, figsize=(13, 10)) + + # (0,0) Loss vs N + ax = axes[0, 0] + 
for s in schedules: + ax.plot(N_values, results[s], "o-", ms=4, color=colours[s], label=labels[s]) + ax.axhline(loss_opt2, color="C3", ls=":", label=f"Optimal 2-step = {loss_opt2:.4f}") + ax.set_xlabel("Steps N") + ax.set_ylabel("Total arb loss") + ax.set_title("Arb loss vs interpolation steps") + ax.set_xscale("log") + ax.set_yscale("log") + ax.legend(fontsize=7) + ax.grid(True, alpha=0.3) + + # (0,1) Ratio linear_Z / geometric + ax = axes[0, 1] + ratios = np.array(results["linear_Z"]) / np.array(results["geometric"]) + ax.plot(N_values, ratios, "o-", color="C2") + ax.axhline(1.0, color="gray", ls="--", alpha=0.5) + ax.set_xlabel("Steps N") + ax.set_ylabel("Loss(Linear Z) / Loss(Geometric VB)") + ax.set_title("Relative improvement of Z-optimal") + ax.grid(True, alpha=0.3) + + # (1,0) Z trajectories at N=8 + ax = axes[1, 0] + steps = np.arange(N_viz + 1) + for s in schedules: + ax.plot(steps, traj_data[s]["Z"], "o-", ms=4, color=colours[s], label=labels[s]) + ax.set_xlabel("Step") + ax.set_ylabel("Z = sqrt(P)*VA - VB/sqrt(P)") + ax.set_title(f"Z trajectory (N={N_viz})") + ax.legend(fontsize=7) + ax.grid(True, alpha=0.3) + + # (1,1) 2-step loss vs shift size + ax = axes[1, 1] + shift_pct = (1 - decay_sweep) * 100 + ax.plot(shift_pct, sweep["geometric"], color="C0", label="Geometric VB (N=2)") + ax.plot(shift_pct, sweep["linear_Z"], color="C2", label="Linear Z (N=2)") + ax.plot(shift_pct, sweep["optimal_2step"], ":", color="C3", label="Optimal 2-step") + ax.set_xlabel("Shift size (% VB decay)") + ax.set_ylabel("Arb loss") + ax.set_title("2-step loss vs shift magnitude") + ax.legend(fontsize=7) + ax.grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig("reclamm_interpolation_benchmark.png", dpi=150) + print("\nSaved reclamm_interpolation_benchmark.png") + + # ── Per-step loss bar chart for N=8 ──────────────────────────────── + + fig2, ax = plt.subplots(figsize=(10, 5)) + x = np.arange(1, N_viz + 1) + w = 0.25 + for i, s in enumerate(schedules): + ax.bar(x + i * 
w, traj_data[s]["loss"], w, color=colours[s], label=labels[s]) + ax.set_xlabel("Step") + ax.set_ylabel("Per-step arb loss") + ax.set_title(f"Per-step loss distribution (N={N_viz})") + ax.legend(fontsize=8) + ax.set_xticks(x + w) + plt.tight_layout() + plt.savefig("reclamm_interpolation_perstep.png", dpi=150) + print("Saved reclamm_interpolation_perstep.png") + + # ── Experiment 4: small-shift regime (paper's approximation valid) ─── + + print("\n" + "=" * 72) + print("Experiment 4: Optimal 2-step vs Geometric N=2 at small shifts") + print(" (reserves nearly constant → paper's analysis should hold)") + print("-" * 72) + print(f" {'Decay %':>8s} {'Geo N=2':>12s} {'LinZ N=2':>12s} " + f"{'Opt2':>12s} {'Opt2/Geo':>9s} {'Opt2/LinZ':>9s}") + print("-" * 72) + + small_decays = [0.999, 0.998, 0.995, 0.99, 0.98, 0.95, 0.90, 0.80] + for df in small_decays: + VB_e = VB * df + try: + g, _, _ = run_shift(RA, RB, VA, VB_start, VB_e, Q, P, 2, "geometric") + lz, _, _ = run_shift(RA, RB, VA, VB_start, VB_e, Q, P, 2, "linear_Z") + o2, _, _ = run_shift_optimal_2step( + RA, RB, VA, VB_start, VB_e, Q, P + ) + except (ValueError, AssertionError) as e: + print(f" {(1-df)*100:>7.1f}% FAILED: {e}") + continue + print(f" {(1-df)*100:>7.1f}% {g:>12.6f} {lz:>12.6f} " + f"{o2:>12.6f} {o2/g:>9.6f} {o2/lz:>9.6f}") + + print("=" * 72) + + # ── Experiment 5: brute-force JAX-optimised Z targets ──────────────── + + print("\n" + "=" * 72) + print("Experiment 5: Brute-force optimal Z targets (JAX + L-BFGS-B)") + print(" Parameterisation: softplus gaps → sorted Z targets") + print(" Initialised at linear Z (uniform gaps)") + print("-" * 72) + + opt_N_values = [2, 3, 4, 6, 8, 12, 16, 24, 32] + opt_losses = {} + opt_Z_trajs = {} + + for N in opt_N_values: + loss_bf, Z_bf = optimise_z_targets( + RA, RB, Q, P, Z_start, Z_end, N, verbose=True + ) + opt_losses[N] = loss_bf + opt_Z_trajs[N] = Z_bf + + # Comparison table + print(f"\n {'N':>5s} {'Geometric':>12s} {'Linear Z':>12s} " + f"{'BF Optimal':>12s} 
{'BF/LinZ':>9s} {'BF/Geo':>9s}") + print("-" * 72) + for N in opt_N_values: + idx = N_values.index(N) if N in N_values else None + g = results["geometric"][idx] if idx is not None else np.nan + lz = results["linear_Z"][idx] if idx is not None else np.nan + bf = opt_losses[N] + print(f" {N:>5d} {g:>12.6f} {lz:>12.6f} " + f"{bf:>12.6f} {bf/lz:>9.6f} {bf/g:>9.6f}") + + # ── Plot: overlay brute-force on the main loss-vs-N chart ──────────── + + fig3, axes3 = plt.subplots(1, 2, figsize=(14, 5)) + + # (left) Loss vs N with brute-force overlay + ax = axes3[0] + for s in schedules: + ax.plot(N_values, results[s], "o-", ms=4, color=colours[s], + label=labels[s]) + bf_Ns = sorted(opt_losses.keys()) + bf_vals = [opt_losses[n] for n in bf_Ns] + ax.plot(bf_Ns, bf_vals, "s--", ms=5, color="C3", label="BF Optimal (JAX)") + ax.set_xlabel("Steps N") + ax.set_ylabel("Total arb loss") + ax.set_title("Arb loss vs interpolation steps (with BF optimal)") + ax.set_xscale("log") + ax.set_yscale("log") + ax.legend(fontsize=7) + ax.grid(True, alpha=0.3) + + # (right) Z trajectory comparison at N=8 + ax = axes3[1] + N_cmp = 8 + steps_cmp = np.arange(N_cmp + 1) + + # Geometric: compute Z trajectory from VB + z_geo = [Z_start] + RA_t, RB_t = RA, RB + for i in range(1, N_cmp + 1): + frac = i / N_cmp + VB_i = VB_start * (VB_end / VB_start) ** frac + VA_i = compute_VA_from_VB(RA_t, RB_t, VB_i, Q) + z_geo.append(compute_Z(VA_i, VB_i, P)) + RA_t, RB_t, _ = micro_step(RA_t, RB_t, VA_i, VB_i, P) + + # Linear Z + z_linz = [Z_start] + RA_t, RB_t = RA, RB + for i in range(1, N_cmp + 1): + frac = i / N_cmp + Z_i = Z_start + frac * (Z_end - Z_start) + VB_i = solve_VB_for_Z(RA_t, RB_t, Z_i, Q, P) + VA_i = compute_VA_from_VB(RA_t, RB_t, VB_i, Q) + z_linz.append(compute_Z(VA_i, VB_i, P)) + RA_t, RB_t, _ = micro_step(RA_t, RB_t, VA_i, VB_i, P) + + # BF optimal + z_bf = [Z_start] + list(opt_Z_trajs[N_cmp]) + # Trace actual Z achieved after arb at each step + z_bf_actual = [Z_start] + RA_t, RB_t = RA, RB + for 
i in range(N_cmp): + VB_i = solve_VB_for_Z(RA_t, RB_t, opt_Z_trajs[N_cmp][i], Q, P) + VA_i = compute_VA_from_VB(RA_t, RB_t, VB_i, Q) + z_bf_actual.append(compute_Z(VA_i, VB_i, P)) + RA_t, RB_t, _ = micro_step(RA_t, RB_t, VA_i, VB_i, P) + + ax.plot(steps_cmp, z_geo, "o-", ms=4, color="C0", label="Geometric VB") + ax.plot(steps_cmp, z_linz, "o-", ms=4, color="C2", label="Linear Z") + ax.plot(steps_cmp, z_bf_actual, "s--", ms=5, color="C3", + label="BF Optimal") + ax.plot(steps_cmp, np.linspace(Z_start, Z_end, N_cmp + 1), + ":", color="gray", alpha=0.5, label="Ideal linear Z") + ax.set_xlabel("Step") + ax.set_ylabel("Z = sqrt(P)*VA - VB/sqrt(P)") + ax.set_title(f"Z trajectory comparison (N={N_cmp})") + ax.legend(fontsize=7) + ax.grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig("reclamm_interpolation_bruteforce.png", dpi=150) + print("\nSaved reclamm_interpolation_bruteforce.png") + + +if __name__ == "__main__": + main() diff --git a/scripts/calibrate_noise_bayesian.py b/scripts/calibrate_noise_bayesian.py new file mode 100644 index 0000000..efdfb4b --- /dev/null +++ b/scripts/calibrate_noise_bayesian.py @@ -0,0 +1,853 @@ +"""Bayesian hierarchical noise volume model across Balancer pools. + +Full Bayesian version of the noise calibration: ALL K=4 per-pool +coefficients (intercept, TVL elasticity, volatility response, weekend +effect) vary per pool with pool-level covariates modulating their priors, +and an LKJ-decomposed covariance capturing correlations between +coefficients. + +Generative model: + For pool i with pool-level covariates z_i, day t: + + mu_i = B . z_i # K-vector population mean + eta_i ~ N(0, I_K) # non-centered offsets + theta_i = mu_i + diag(sigma) . L . eta_i # per-pool coefficients + + log(V_{i,t}) ~ N(theta_i . 
x_{i,t}, sigma_eps^2) + + theta_i = [intercept_i, b_tvl_i, b_sigma_i, b_weekend_i] + x_{i,t} = [1, log_tvl, volatility, weekend] + z_i = [1, chain_dummies(6), tier_A_dummies(2), tier_B_dummies(2), log_fee] + +Priors: + B_{k,d} ~ N(0, 5^2) + sigma_k ~ HalfNormal(2.0) + L ~ LKJCholesky(K=4, eta=2) + sigma_eps ~ HalfNormal(3.0) + +Usage: + # Full pipeline: fit + output + diagnostics + python scripts/calibrate_noise_bayesian.py \\ + --fit --output results/bayesian_noise_params.json --plot + + # Predict for an unseen pool + python scripts/calibrate_noise_bayesian.py \\ + --predict --chain BASE --tokens ETH USDC --fee 0.003 + + # Custom NUTS settings + python scripts/calibrate_noise_bayesian.py \\ + --fit --num-warmup 2000 --num-samples 4000 --num-chains 4 +""" + +import argparse +import json +import os +import sys + +import numpy as np +import pandas as pd + +# arviz 0.17.x imports scipy.signal.gaussian which was removed in scipy 1.13+. +# Patch it back from scipy.signal.windows before any arviz import. 
+try: + from scipy.signal import gaussian as _ # noqa: F401 +except ImportError: + from scipy.signal.windows import gaussian as _gauss + import scipy.signal + scipy.signal.gaussian = _gauss + +# --------------------------------------------------------------------------- +# Reuse constants and helpers from the frequentist script +# --------------------------------------------------------------------------- + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "local_data", "noise_calibration" +) + +# Reference levels for dummy coding (dropped categories) +REF_CHAIN = "ARBITRUM" +REF_TIER = 0 + +# Ordered non-reference chains (alphabetical excluding REF_CHAIN) +CHAIN_ORDER = ["BASE", "GNOSIS", "MAINNET", "OPTIMISM", "POLYGON", "SONIC"] + +K = 4 # number of per-pool coefficients +D = 12 # pool-level covariate dimension: 1 + 6 chains + 2 tier_A + 2 tier_B + 1 log_fee + +COEFF_NAMES = ["intercept", "b_tvl", "b_sigma", "b_weekend"] + + +# --------------------------------------------------------------------------- +# Token tier helpers (duplicated to avoid import fragility) +# --------------------------------------------------------------------------- + +_TIER_0 = { + "ETH", "WETH", "BTC", "WBTC", "cbBTC", "USDC", "USDT", "DAI", + "wstETH", "stETH", "rETH", "cbETH", "WMATIC", "MATIC", "POL", + "WAVAX", "AVAX", "GNO", "WXDAI", "xDAI", + "S", "wS", +} + +_TIER_1 = { + "AAVE", "LINK", "UNI", "BAL", "MKR", "CRV", "COMP", "SNX", + "LDO", "RPL", "SUSHI", "YFI", "1INCH", "ENS", "DYDX", + "FXS", "FRAX", "LUSD", "sDAI", "GHO", "crvUSD", + "ARB", "OP", "PENDLE", "ENA", "EIGEN", + "SAFE", "COW", +} + + +def classify_token_tier(symbol: str) -> int: + s = symbol.strip() + if s in _TIER_0: + return 0 + if s in _TIER_1: + return 1 + return 2 + + +# --------------------------------------------------------------------------- +# Data preparation +# --------------------------------------------------------------------------- + +def load_panel(cache_dir: str = CACHE_DIR) 
-> pd.DataFrame: + """Load the cached panel parquet produced by calibrate_noise_hierarchical.py --fetch.""" + panel_path = os.path.join(cache_dir, "panel.parquet") + if not os.path.exists(panel_path): + print(f"ERROR: Panel cache not found at {panel_path}", file=sys.stderr) + print("Run: python scripts/calibrate_noise_hierarchical.py --fetch", file=sys.stderr) + sys.exit(1) + + panel = pd.read_parquet(panel_path) + + # Filter pools with < 10 observations + pool_counts = panel.groupby("pool_id").size() + valid_pools = pool_counts[pool_counts >= 10].index + panel = panel[panel["pool_id"].isin(valid_pools)].copy() + + print(f" Loaded panel: {len(panel)} obs, " + f"{panel['pool_id'].nunique()} pools, " + f"{panel['chain'].nunique()} chains") + return panel + + +def _build_z_pool(pool_meta: pd.DataFrame) -> np.ndarray: + """Build (N_pools, D) pool-level covariate matrix. + + Columns: [1, chain_BASE, ..., chain_SONIC (6), + tier_A_1, tier_A_2, tier_B_1, tier_B_2, log_fee] + """ + N = len(pool_meta) + z = np.zeros((N, D), dtype=np.float64) + + # Intercept + z[:, 0] = 1.0 + + # Chain dummies (columns 1-6) + for j, chain in enumerate(CHAIN_ORDER): + z[:, 1 + j] = (pool_meta["chain"].values == chain).astype(float) + + # tier_A dummies (columns 7-8): tiers 1 and 2, reference = 0 + tier_a = pool_meta["tier_A"].values.astype(int) + z[:, 7] = (tier_a == 1).astype(float) + z[:, 8] = (tier_a == 2).astype(float) + + # tier_B dummies (columns 9-10): tiers 1 and 2, reference = 0 + tier_b = pool_meta["tier_B"].values.astype(int) + z[:, 9] = (tier_b == 1).astype(float) + z[:, 10] = (tier_b == 2).astype(float) + + # log_fee (column 11) + z[:, 11] = np.log(np.maximum(pool_meta["swap_fee"].values.astype(float), 1e-6)) + + return z + + +def prepare_data(panel: pd.DataFrame) -> dict: + """Construct JAX-ready arrays from the panel DataFrame. 
+ + Returns dict with: + pool_idx : (N_obs,) int32 — pool index per observation + z_pool : (N_pools, D) float64 — pool-level covariates + x_obs : (N_obs, K) float64 — within-day regressors + y_obs : (N_obs,) float64 — log_volume + pool_ids : list — ordered pool IDs + pool_meta : DataFrame — per-pool metadata (indexed same as z_pool rows) + """ + # Stable pool ordering + pool_ids = sorted(panel["pool_id"].unique()) + pool_id_to_idx = {pid: i for i, pid in enumerate(pool_ids)} + N_pools = len(pool_ids) + + # Pool-level metadata (one row per pool) + pool_meta = panel.drop_duplicates("pool_id").set_index("pool_id").loc[pool_ids].reset_index() + z_pool = _build_z_pool(pool_meta) + + # Observation-level arrays + pool_idx = panel["pool_id"].map(pool_id_to_idx).values.astype(np.int32) + x_obs = np.column_stack([ + np.ones(len(panel)), + panel["log_tvl"].values, + panel["volatility"].values, + panel["weekend"].values, + ]).astype(np.float64) + y_obs = panel["log_volume"].values.astype(np.float64) + + print(f" Prepared: N_obs={len(y_obs)}, N_pools={N_pools}, K={K}, D={D}") + print(f" z_pool range check — log_fee: [{z_pool[:, 11].min():.2f}, {z_pool[:, 11].max():.2f}]") + + return { + "pool_idx": pool_idx, + "z_pool": z_pool, + "x_obs": x_obs, + "y_obs": y_obs, + "pool_ids": pool_ids, + "pool_meta": pool_meta, + "N_pools": N_pools, + } + + +# --------------------------------------------------------------------------- +# NumPyro model +# --------------------------------------------------------------------------- + +def hierarchical_noise_model(pool_idx, z_pool, x_obs, y_obs=None, + N_pools=None, K=4, D=12): + """Bayesian hierarchical noise volume model. + + Non-centered parameterization with LKJ correlation prior. 
+ """ + import jax.numpy as jnp + import numpyro + import numpyro.distributions as dist + + N_obs = pool_idx.shape[0] + + # --- Population coefficient matrix B: (K, D) --- + B = numpyro.sample("B", dist.Normal(0.0, 5.0).expand([K, D]).to_event(2)) + + # --- Per-pool scale and correlation --- + sigma = numpyro.sample("sigma", dist.HalfNormal(2.0).expand([K]).to_event(1)) + L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(K, concentration=2.0)) + + # Cholesky factor of covariance: diag(sigma) @ L_Omega + L_Sigma = jnp.diag(sigma) @ L_Omega # (K, K) + + # --- Non-centered pool effects --- + with numpyro.plate("pools", N_pools): + eta = numpyro.sample("eta", dist.Normal(0.0, 1.0).expand([K]).to_event(1)) + + # theta_i = B @ z_i + L_Sigma @ eta_i for each pool i + # mu: (N_pools, K) = z_pool @ B^T + mu = z_pool @ B.T # (N_pools, K) + theta = mu + eta @ L_Sigma.T # (N_pools, K) + + # --- Observation model --- + sigma_eps = numpyro.sample("sigma_eps", dist.HalfNormal(3.0)) + + # Predicted log-volume: theta[pool_idx] . x_obs (dot product per obs) + theta_obs = theta[pool_idx] # (N_obs, K) + mu_obs = jnp.sum(theta_obs * x_obs, axis=1) # (N_obs,) + + with numpyro.plate("obs", N_obs): + numpyro.sample("y", dist.Normal(mu_obs, sigma_eps), obs=y_obs) + + # Deterministic: store theta for extraction + numpyro.deterministic("theta", theta) + + +# --------------------------------------------------------------------------- +# Inference +# --------------------------------------------------------------------------- + +def run_inference(data, num_warmup=1000, num_samples=2000, num_chains=4, + target_accept=0.85, max_tree_depth=10, seed=42): + """Run NUTS on the hierarchical model. + + Returns the MCMC object with samples. 
+ """ + import jax + import jax.numpy as jnp + import numpyro + from numpyro.infer import MCMC, NUTS + + # Use all available CPU cores for chains + numpyro.set_host_device_count(min(num_chains, len(jax.devices("cpu")))) + + kernel = NUTS( + hierarchical_noise_model, + target_accept_prob=target_accept, + max_tree_depth=max_tree_depth, + ) + mcmc = MCMC( + kernel, + num_warmup=num_warmup, + num_samples=num_samples, + num_chains=num_chains, + progress_bar=True, + ) + + rng_key = jax.random.PRNGKey(seed) + + print(f"\n Running NUTS: {num_chains} chains x " + f"({num_warmup} warmup + {num_samples} samples)") + print(f" target_accept={target_accept}, max_tree_depth={max_tree_depth}") + + mcmc.run( + rng_key, + pool_idx=jnp.array(data["pool_idx"]), + z_pool=jnp.array(data["z_pool"]), + x_obs=jnp.array(data["x_obs"]), + y_obs=jnp.array(data["y_obs"]), + N_pools=data["N_pools"], + K=K, + D=D, + ) + + mcmc.print_summary(exclude_deterministic=True) + return mcmc + + +# --------------------------------------------------------------------------- +# Post-processing +# --------------------------------------------------------------------------- + +def extract_noise_params(mcmc, data) -> list: + """Extract per-pool noise params from MCMC posterior. + + Reconstructs theta from the non-centered parameterization, + takes posterior medians, and applies weekend absorption: + b_0_effective = b_0_raw + b_weekend * (2/7) + + Returns list of dicts compatible with reclamm_loglinear_noise_volume. 
+ """ + samples = mcmc.get_samples() + theta_samples = samples["theta"] # (n_samples, N_pools, K) + + # Posterior median per pool + theta_median = np.median(theta_samples, axis=0) # (N_pools, K) + theta_std = np.std(theta_samples, axis=0) # (N_pools, K) + + pool_ids = data["pool_ids"] + pool_meta = data["pool_meta"] + + results = [] + for i, pool_id in enumerate(pool_ids): + meta = pool_meta.iloc[i] + b_0_raw, b_tvl, b_sigma, b_weekend = theta_median[i] + std_vals = theta_std[i] + + # Weekend absorption: simulator has no weekend indicator, + # so fold the expected weekend effect into the intercept. + # Weekend days = 2/7 of all days. + b_0_effective = b_0_raw + b_weekend * (2.0 / 7.0) + + tokens = meta["tokens"] + if isinstance(tokens, str): + tokens = tokens.split(",") + + results.append({ + "pool_id": pool_id, + "chain": str(meta["chain"]), + "tokens": tokens, + "theta_median": [float(x) for x in theta_median[i]], + "theta_std": [float(x) for x in std_vals], + "noise_params": { + "b_0": float(b_0_effective), + "b_sigma": float(b_sigma), + "b_c": float(b_tvl), + "b_weekend": float(b_weekend), + "base_fee": float(meta["swap_fee"]), + }, + }) + + return results + + +def predict_new_pool(mcmc, data, chain: str, tokens: list, fee: float) -> dict: + """Predict noise params for an unseen pool using population effects. + + Constructs z_new, computes mu_new = B @ z_new across posterior samples, + and returns median + 90% credible intervals. 
+ """ + # Build z_new + z_new = np.zeros(D, dtype=np.float64) + z_new[0] = 1.0 # intercept + + # Chain dummies + if chain in CHAIN_ORDER: + j = CHAIN_ORDER.index(chain) + z_new[1 + j] = 1.0 + + # Tier dummies + tiers = sorted([classify_token_tier(t) for t in tokens]) + tier_a = tiers[0] + tier_b = tiers[1] if len(tiers) > 1 else tiers[0] + if tier_a == 1: + z_new[7] = 1.0 + elif tier_a == 2: + z_new[8] = 1.0 + if tier_b == 1: + z_new[9] = 1.0 + elif tier_b == 2: + z_new[10] = 1.0 + + # log_fee + z_new[11] = np.log(max(fee, 1e-6)) + + # Compute mu_new = B @ z_new across all posterior samples + B_samples = np.array(mcmc.get_samples()["B"]) # (n_samples, K, D) + mu_samples = np.einsum("skd,d->sk", B_samples, z_new) # (n_samples, K) + + mu_median = np.median(mu_samples, axis=0) + mu_q05 = np.percentile(mu_samples, 5, axis=0) + mu_q95 = np.percentile(mu_samples, 95, axis=0) + + # Weekend absorption + b_0_raw, b_tvl, b_sigma, b_weekend = mu_median + b_0_effective = b_0_raw + b_weekend * (2.0 / 7.0) + + result = { + "chain": chain, + "tokens": tokens, + "fee": fee, + "prediction_source": "population_level", + "noise_params": { + "b_0": float(b_0_effective), + "b_sigma": float(b_sigma), + "b_c": float(b_tvl), + "b_weekend": float(b_weekend), + "base_fee": float(fee), + }, + "credible_intervals_90": { + name: { + "median": float(mu_median[k]), + "q05": float(mu_q05[k]), + "q95": float(mu_q95[k]), + } + for k, name in enumerate(COEFF_NAMES) + }, + } + + print(f"\n Predicted noise_params for {chain} {tokens} (fee={fee}):") + for name, ci in result["credible_intervals_90"].items(): + print(f" {name:12s}: {ci['median']:+.3f} " + f"[{ci['q05']:+.3f}, {ci['q95']:+.3f}]") + print(f"\n Effective b_0 (weekend-absorbed): {b_0_effective:.3f}") + + return result + + +# --------------------------------------------------------------------------- +# Diagnostics +# --------------------------------------------------------------------------- + +def check_convergence(mcmc) -> dict: + 
"""Compute convergence diagnostics: R-hat, ESS, divergences.""" + import arviz as az + + idata = az.from_numpyro(mcmc) + + # R-hat and ESS for non-deterministic parameters + n_chains = idata.posterior.sizes.get("chain", 1) + + rhat_max = float("nan") + if n_chains >= 2: + rhat = az.rhat(idata) + rhat_vals = [] + for var in rhat.data_vars: + if var == "theta": + continue # deterministic + vals = rhat[var].values + rhat_vals.extend(vals.flatten()) + rhat_max = float(np.nanmax(rhat_vals)) if rhat_vals else float("nan") + + ess = az.ess(idata) + ess_vals = [] + for var in ess.data_vars: + if var == "theta": + continue + vals = ess[var].values + ess_vals.extend(vals.flatten()) + ess_min = float(np.nanmin(ess_vals)) if ess_vals else float("nan") + + # Divergences + divergences = int(idata.sample_stats["diverging"].sum().values) + + print(f"\n Convergence diagnostics:") + if n_chains >= 2: + print(f" R-hat max: {rhat_max:.4f} {'OK' if rhat_max < 1.05 else 'WARNING'}") + else: + print(f" R-hat max: N/A (need >= 2 chains)") + print(f" ESS min: {ess_min:.0f} {'OK' if ess_min > 400 else 'WARNING'}") + print(f" Divergences: {divergences} {'OK' if divergences == 0 else 'WARNING'}") + + return { + "r_hat_max": rhat_max, + "ess_min": ess_min, + "divergences": divergences, + } + + +def plot_bayesian_diagnostics(mcmc, data, output_dir="results"): + """Generate ArviZ diagnostic plots.""" + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + import arviz as az + + os.makedirs(output_dir, exist_ok=True) + idata = az.from_numpyro(mcmc) + samples = mcmc.get_samples() + + # --- 1. Trace plots for sigma, sigma_eps --- + axes = az.plot_trace(idata, var_names=["sigma", "sigma_eps"], compact=True) + fig1 = axes.ravel()[0].figure + fig1.set_size_inches(14, 8) + path1 = os.path.join(output_dir, "bayesian_trace_sigma.png") + fig1.savefig(path1, dpi=150, bbox_inches="tight") + plt.close(fig1) + print(f" Saved: {path1}") + + # --- 2. 
Posterior predictive: predicted vs observed --- + theta_samples = samples["theta"] # (S, N_pools, K) + sigma_eps_samples = np.array(samples["sigma_eps"]) # (S,) + theta_median = np.median(theta_samples, axis=0) # (N_pools, K) + + pool_idx = data["pool_idx"] + x_obs = data["x_obs"] + y_obs = data["y_obs"] + + theta_obs = theta_median[pool_idx] + y_pred = np.sum(theta_obs * x_obs, axis=1) + + fig, axes = plt.subplots(1, 2, figsize=(14, 6)) + + ax = axes[0] + ax.scatter(y_obs, y_pred, alpha=0.1, s=4, color="steelblue") + lims = [min(y_obs.min(), y_pred.min()), max(y_obs.max(), y_pred.max())] + ax.plot(lims, lims, "r--", linewidth=1) + ax.set_xlabel("Observed log(volume)") + ax.set_ylabel("Predicted log(volume)") + ax.set_title("Posterior predictive check") + r2 = 1 - np.var(y_obs - y_pred) / np.var(y_obs) + ax.text(0.05, 0.95, f"R² = {r2:.3f}", transform=ax.transAxes, + fontsize=11, verticalalignment="top") + + ax = axes[1] + residuals = y_obs - y_pred + ax.hist(residuals, bins=60, color="steelblue", edgecolor="white", alpha=0.8) + ax.axvline(0, color="red", linestyle="--") + ax.set_xlabel("Residual") + ax.set_title(f"Residual distribution (σ_ε ≈ {np.median(sigma_eps_samples):.2f})") + + plt.tight_layout() + path2 = os.path.join(output_dir, "bayesian_posterior_predictive.png") + plt.savefig(path2, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path2}") + + # --- 3. 
Per-pool b_c (TVL elasticity) by chain/tier --- + pool_meta = data["pool_meta"] + b_tvl_all = theta_median[:, 1] # index 1 = b_tvl + + fig, axes = plt.subplots(1, 2, figsize=(14, 6)) + + ax = axes[0] + chains_present = sorted(pool_meta["chain"].unique()) + chain_data = [] + chain_labels = [] + for c in chains_present: + mask = pool_meta["chain"].values == c + if mask.sum() > 0: + chain_data.append(b_tvl_all[mask]) + chain_labels.append(f"{c}\n(n={mask.sum()})") + ax.boxplot(chain_data, tick_labels=chain_labels, vert=True) + ax.axhline(1.0, color="red", linestyle="--", linewidth=0.8, alpha=0.6) + ax.set_ylabel("Per-pool b_c (TVL elasticity)") + ax.set_title("TVL elasticity by chain") + + ax = axes[1] + # By tier_A + tier_a_vals = pool_meta["tier_A"].values.astype(int) + tier_labels_map = {0: "Blue-chip", 1: "Mid-cap", 2: "Long-tail"} + tier_data = [] + tier_labels = [] + for t in [0, 1, 2]: + mask = tier_a_vals == t + if mask.sum() > 0: + tier_data.append(b_tvl_all[mask]) + tier_labels.append(f"{tier_labels_map[t]}\n(n={mask.sum()})") + ax.boxplot(tier_data, tick_labels=tier_labels, vert=True) + ax.axhline(1.0, color="red", linestyle="--", linewidth=0.8, alpha=0.6) + ax.set_ylabel("Per-pool b_c (TVL elasticity)") + ax.set_title("TVL elasticity by token tier (best token)") + + plt.tight_layout() + path3 = os.path.join(output_dir, "bayesian_per_pool_b_c.png") + plt.savefig(path3, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path3}") + + # --- 4. 
Correlation matrix posterior --- + L_Omega_samples = np.array(samples["L_Omega"]) # (S, K, K) + # Correlation = L @ L^T + Omega_samples = np.einsum("sij,skj->sik", L_Omega_samples, L_Omega_samples) + Omega_median = np.median(Omega_samples, axis=0) + + fig, ax = plt.subplots(figsize=(7, 6)) + im = ax.imshow(Omega_median, vmin=-1, vmax=1, cmap="RdBu_r") + ax.set_xticks(range(K)) + ax.set_yticks(range(K)) + ax.set_xticklabels(COEFF_NAMES, rotation=45, ha="right") + ax.set_yticklabels(COEFF_NAMES) + for i in range(K): + for j in range(K): + ax.text(j, i, f"{Omega_median[i, j]:.2f}", ha="center", va="center", + fontsize=10, color="white" if abs(Omega_median[i, j]) > 0.5 else "black") + plt.colorbar(im, ax=ax, shrink=0.8) + ax.set_title("Posterior median correlation matrix (Ω)") + plt.tight_layout() + path4 = os.path.join(output_dir, "bayesian_correlation_matrix.png") + plt.savefig(path4, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path4}") + + # --- 5. Shrinkage plot: OLS b_c vs hierarchical b_c --- + # Compute per-pool OLS b_c for comparison + panel_meta = data["pool_meta"] + pool_idx_arr = data["pool_idx"] + pool_ids = data["pool_ids"] + + ols_b_c = np.zeros(len(pool_ids)) + for i, pid in enumerate(pool_ids): + mask = pool_idx_arr == i + if mask.sum() < 5: + ols_b_c[i] = np.nan + continue + x_i = data["x_obs"][mask] + y_i = data["y_obs"][mask] + # Simple OLS: y = X @ beta + try: + beta, _, _, _ = np.linalg.lstsq(x_i, y_i, rcond=None) + ols_b_c[i] = beta[1] # TVL coefficient + except np.linalg.LinAlgError: + ols_b_c[i] = np.nan + + hier_b_c = theta_median[:, 1] + valid = np.isfinite(ols_b_c) + + fig, ax = plt.subplots(figsize=(8, 8)) + ax.scatter(ols_b_c[valid], hier_b_c[valid], alpha=0.6, s=20, color="steelblue") + + # Population mean line + pop_b_c = np.median(hier_b_c) + ax.axhline(pop_b_c, color="red", linestyle="--", linewidth=0.8, + label=f"Population median = {pop_b_c:.3f}") + + # 45-degree line + lims = [min(np.nanmin(ols_b_c[valid]), 
hier_b_c[valid].min()) - 0.2, + max(np.nanmax(ols_b_c[valid]), hier_b_c[valid].max()) + 0.2] + ax.plot(lims, lims, "k:", linewidth=0.8, alpha=0.5) + ax.set_xlabel("Per-pool OLS b_c") + ax.set_ylabel("Hierarchical posterior median b_c") + ax.set_title("Shrinkage: OLS vs hierarchical TVL elasticity") + ax.legend() + plt.tight_layout() + path5 = os.path.join(output_dir, "bayesian_shrinkage_b_c.png") + plt.savefig(path5, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path5}") + + +# --------------------------------------------------------------------------- +# JSON output +# --------------------------------------------------------------------------- + +def generate_output_json(pool_params, mcmc, data, convergence, output_path, + num_warmup, num_samples, num_chains, target_accept): + """Write structured JSON output with population effects and per-pool params.""" + samples = mcmc.get_samples() + + B_median = np.median(np.array(samples["B"]), axis=0).tolist() + sigma_median = np.median(np.array(samples["sigma"]), axis=0).tolist() + sigma_eps_median = float(np.median(np.array(samples["sigma_eps"]))) + + output = { + "model": "bayesian_hierarchical_loglinear", + "inference": { + "method": "NUTS", + "num_warmup": num_warmup, + "num_samples": num_samples, + "num_chains": num_chains, + "target_accept_prob": target_accept, + }, + "population_effects": { + "B": B_median, + "sigma": sigma_median, + "sigma_eps": sigma_eps_median, + "coeff_names": COEFF_NAMES, + "covariate_names": ( + ["intercept"] + [f"chain_{c}" for c in CHAIN_ORDER] + + ["tier_A_1", "tier_A_2", "tier_B_1", "tier_B_2", "log_fee"] + ), + }, + "convergence": convergence, + "pools": { + p["pool_id"]: { + "chain": p["chain"], + "tokens": p["tokens"], + "theta_median": p["theta_median"], + "theta_std": p["theta_std"], + "noise_params": p["noise_params"], + } + for p in pool_params + }, + } + + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + with open(output_path, "w") as f: + 
json.dump(output, f, indent=2, default=str) + print(f" Wrote {len(pool_params)} pool params -> {output_path}") + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser( + description="Bayesian hierarchical noise volume model for Balancer pools" + ) + parser.add_argument( + "--fetch", action="store_true", + help="Fetch pool data (delegates to calibrate_noise_hierarchical.py --fetch)", + ) + parser.add_argument("--fit", action="store_true", help="Run NUTS inference") + parser.add_argument("--plot", action="store_true", help="Generate diagnostic plots") + parser.add_argument("--output", default=None, help="Output JSON path") + parser.add_argument("--output-dir", default="results", help="Plot output directory") + parser.add_argument("--predict", action="store_true", help="Predict for a new pool") + parser.add_argument("--chain", default=None, help="Chain for --predict") + parser.add_argument("--tokens", nargs="+", default=None, help="Tokens for --predict") + parser.add_argument("--fee", type=float, default=0.003, help="Fee for --predict") + parser.add_argument("--cache-dir", default=None, help="Cache directory") + + # NUTS hyperparameters + parser.add_argument("--num-warmup", type=int, default=1000) + parser.add_argument("--num-samples", type=int, default=2000) + parser.add_argument("--num-chains", type=int, default=4) + parser.add_argument("--target-accept", type=float, default=0.85) + parser.add_argument("--max-tree-depth", type=int, default=10) + parser.add_argument("--seed", type=int, default=42) + + args = parser.parse_args() + + cache_dir = args.cache_dir or CACHE_DIR + + if not any([args.fetch, args.fit, args.predict]): + parser.error("At least one of --fetch, --fit, --predict is required") + + # --- Fetch (delegate to existing script) --- + if args.fetch: + import subprocess + cmd = [ + sys.executable, 
"scripts/calibrate_noise_hierarchical.py", + "--fetch", "--cache-dir", cache_dir, + ] + print("Delegating data fetch to calibrate_noise_hierarchical.py...") + subprocess.run(cmd, check=True) + + # --- Fit --- + if args.fit: + print("\nBayesian Hierarchical Noise Volume Model") + print("=" * 60) + + panel = load_panel(cache_dir) + data = prepare_data(panel) + + mcmc = run_inference( + data, + num_warmup=args.num_warmup, + num_samples=args.num_samples, + num_chains=args.num_chains, + target_accept=args.target_accept, + max_tree_depth=args.max_tree_depth, + seed=args.seed, + ) + + convergence = check_convergence(mcmc) + pool_params = extract_noise_params(mcmc, data) + + # Print summary statistics + b_c_vals = [p["noise_params"]["b_c"] for p in pool_params] + b_0_vals = [p["noise_params"]["b_0"] for p in pool_params] + print(f"\n Per-pool b_c: mean={np.mean(b_c_vals):.3f}, " + f"std={np.std(b_c_vals):.3f}, " + f"range=[{np.min(b_c_vals):.3f}, {np.max(b_c_vals):.3f}]") + print(f" Per-pool b_0: mean={np.mean(b_0_vals):.3f}, " + f"std={np.std(b_0_vals):.3f}") + + if args.output: + generate_output_json( + pool_params, mcmc, data, convergence, args.output, + args.num_warmup, args.num_samples, args.num_chains, + args.target_accept, + ) + + if args.plot: + print("\nGenerating diagnostic plots...") + plot_bayesian_diagnostics(mcmc, data, output_dir=args.output_dir) + + # Save MCMC samples for --predict reuse + mcmc_cache = os.path.join(cache_dir, "bayesian_mcmc_samples.npz") + samples = mcmc.get_samples() + np.savez_compressed( + mcmc_cache, + **{k: np.array(v) for k, v in samples.items()}, + ) + # Also save data arrays for predict + data_cache = os.path.join(cache_dir, "bayesian_data.npz") + np.savez_compressed( + data_cache, + pool_idx=data["pool_idx"], + z_pool=data["z_pool"], + x_obs=data["x_obs"], + y_obs=data["y_obs"], + ) + # Save pool_ids list + with open(os.path.join(cache_dir, "bayesian_pool_ids.json"), "w") as f: + json.dump(data["pool_ids"], f) + print(f" Saved 
MCMC samples -> {mcmc_cache}") + + # --- Predict --- + if args.predict: + if args.chain is None or args.tokens is None: + parser.error("--predict requires --chain and --tokens") + + # Load cached MCMC samples + mcmc_cache = os.path.join(cache_dir, "bayesian_mcmc_samples.npz") + if not os.path.exists(mcmc_cache): + print(f"ERROR: MCMC cache not found at {mcmc_cache}", file=sys.stderr) + print("Run with --fit first.", file=sys.stderr) + sys.exit(1) + + # For predict, we only need B samples — create a minimal mock + cached = np.load(mcmc_cache) + + class _MockMCMC: + """Minimal interface to reuse predict_new_pool with cached samples.""" + def __init__(self, samples_dict): + self._samples = samples_dict + def get_samples(self): + return self._samples + + samples_dict = {k: cached[k] for k in cached.files} + mock_mcmc = _MockMCMC(samples_dict) + + result = predict_new_pool(mock_mcmc, None, args.chain, args.tokens, args.fee) + print(json.dumps(result, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/scripts/calibrate_noise_hierarchical.py b/scripts/calibrate_noise_hierarchical.py new file mode 100644 index 0000000..4fef642 --- /dev/null +++ b/scripts/calibrate_noise_hierarchical.py @@ -0,0 +1,1556 @@ +"""Bayesian hierarchical noise volume model across Balancer WEIGHTED + RECLAMM pools. + +Pools data cross-sectionally across all Balancer weighted/reCLAMM pools, +fits a Bayesian hierarchical model where pool covariates (chain, token tier, +fee) modulate all coefficients via group-level regression, with full +posterior inference via NumPyro. 
+ +Model: + Hyperpriors: + Φ ~ Normal(0, 2) (K × 3) group-level regression + σ_θ ~ HalfNormal(2) (3,) per-coefficient scales + L_ω ~ LKJCholesky(3, η=2) correlation structure + β_weekend ~ Normal(0, 2) shared nuisance + σ_ε ~ HalfNormal(3) observation noise + + For each pool i: + x_i = [1, chain_dummies, tier_dummies, log_fee] (K,) covariates + z_i ~ N(0, I₃) non-centered + θ_i = Φᵀx_i + diag(σ_θ)·L_ω·z_i (α_i, β_tvl_i, β_vol_i) + + For each observation (i, t): + log(V) ~ N(α_i + β_tvl_i·log_tvl + β_vol_i·vol + β_weekend·weekend, σ²_ε) + +Usage: + # Full pipeline: fetch data + fit model + output + python scripts/calibrate_noise_hierarchical.py \\ + --fetch --fit --output results/hierarchical_noise_params.json --plot + + # Use cached data, re-fit only + python scripts/calibrate_noise_hierarchical.py \\ + --fit --output results/hierarchical_noise_params.json + + # Predict for a new pool + python scripts/calibrate_noise_hierarchical.py \\ + --predict --chain BASE --tokens ETH BTC --fee 0.003 + + # Use NUTS instead of SVI + python scripts/calibrate_noise_hierarchical.py \\ + --fit --nuts --output results/hierarchical_noise_params.json +""" + +import argparse +import json +import os +import sys +import time +import urllib.request +from datetime import datetime, timezone + +import numpy as np +import pandas as pd + +import jax +import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist +from numpyro.infer import SVI, MCMC, NUTS, Trace_ELBO, Predictive +from numpyro.infer.autoguide import AutoMultivariateNormal + +numpyro.enable_x64() + + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +BALANCER_API_URL = "https://api-v3.balancer.fi/" + +BALANCER_API_CHAINS = [ + "MAINNET", "POLYGON", "ARBITRUM", "GNOSIS", "BASE", "SONIC", "OPTIMISM", + "AVALANCHE", +] + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), 
"local_data", "noise_calibration" +) + +# --------------------------------------------------------------------------- +# Token tier classification +# --------------------------------------------------------------------------- + +# Tier 0: blue-chip — top by volume, wrapped native, major stables +_TIER_0 = { + "ETH", "WETH", "BTC", "WBTC", "cbBTC", "USDC", "USDT", "DAI", + "wstETH", "stETH", "rETH", "cbETH", "WMATIC", "MATIC", "POL", + "WAVAX", "AVAX", "GNO", "WXDAI", "xDAI", + "S", "wS", # Sonic native +} + +# Tier 1: mid-cap DeFi blue-chips (approx CoinGecko rank < 200) +_TIER_1 = { + "AAVE", "LINK", "UNI", "BAL", "MKR", "CRV", "COMP", "SNX", + "LDO", "RPL", "SUSHI", "YFI", "1INCH", "ENS", "DYDX", + "FXS", "FRAX", "LUSD", "sDAI", "GHO", "crvUSD", + "ARB", "OP", "PENDLE", "ENA", "EIGEN", + "SAFE", "COW", +} + + +def _normalise_symbol(symbol: str) -> str: + """Normalise wrapped/bridged variants to canonical form.""" + s = symbol.strip() + # Common wrapped → unwrapped + mapping = { + "WETH": "WETH", # keep WETH as-is (it's in tier 0) + "WBTC": "WBTC", + "cbBTC": "cbBTC", + "WMATIC": "WMATIC", + "WAVAX": "WAVAX", + "WXDAI": "WXDAI", + "wS": "wS", + } + return mapping.get(s, s) + + +def classify_token_tier(symbol: str) -> int: + """Classify a token symbol into tier 0/1/2. 
+ + Returns + ------- + int + 0 = blue-chip, 1 = mid-cap, 2 = long-tail + """ + s = _normalise_symbol(symbol) + if s in _TIER_0: + return 0 + if s in _TIER_1: + return 1 + return 2 + + +# --------------------------------------------------------------------------- +# Phase 1: API data ingestion +# --------------------------------------------------------------------------- + +def _graphql_request(query: dict, base_url: str = BALANCER_API_URL, + timeout: int = 30) -> dict: + """Send a GraphQL request to the Balancer V3 API.""" + data = json.dumps(query).encode("utf-8") + req = urllib.request.Request( + base_url, + data=data, + headers={ + "Content-Type": "application/json", + "User-Agent": "quantammsim/1.0", + }, + ) + with urllib.request.urlopen(req, timeout=timeout) as resp: + return json.loads(resp.read().decode("utf-8")) + + +def enumerate_balancer_pools( + chains: list = None, + pool_types: list = None, + min_tvl: float = 10000.0, +) -> pd.DataFrame: + """Enumerate all WEIGHTED + RECLAMM pools across chains from Balancer API. + + Parameters + ---------- + chains : list of str + API chain identifiers (e.g. ["MAINNET", "BASE"]). + pool_types : list of str + Pool type filters (e.g. ["WEIGHTED", "STABLE"]). + min_tvl : float + Minimum TVL in USD to include. + + Returns + ------- + pd.DataFrame + Columns: pool_id, chain, pool_type, tokens (list of symbols), + swap_fee, create_time, dynamic_data_tvl. 
+ """ + if chains is None: + chains = BALANCER_API_CHAINS + if pool_types is None: + pool_types = ["WEIGHTED", "RECLAMM"] + + all_pools = [] + for chain in chains: + print(f" Querying {chain}...", end=" ", flush=True) + query = { + "query": """ + query GetPools($chain: GqlChain!, $types: [GqlPoolType!], + $minTvl: Float) { + poolGetPools( + where: { + chainIn: [$chain] + poolTypeIn: $types + minTvl: $minTvl + } + ) { + id + chain + type + createTime + protocolVersion + poolTokens { + symbol + weight + address + } + dynamicData { + totalLiquidity + swapFee + } + } + } + """, + "variables": { + "chain": chain, + "types": pool_types, + "minTvl": min_tvl, + }, + } + + try: + body = _graphql_request(query) + pools = body.get("data", {}).get("poolGetPools", []) + except Exception as e: + print(f"FAILED ({e})") + continue + + for p in pools: + tokens = [t["symbol"] for t in p.get("poolTokens", [])] + weights = [t.get("weight") for t in p.get("poolTokens", [])] + token_addresses = [t.get("address", "") for t in p.get("poolTokens", [])] + tvl = float(p.get("dynamicData", {}).get("totalLiquidity", 0)) + fee = float(p.get("dynamicData", {}).get("swapFee", 0)) + + all_pools.append({ + "pool_id": p["id"], + "chain": p["chain"], + "pool_type": p["type"], + "protocol_version": p.get("protocolVersion", 0), + "tokens": tokens, + "token_addresses": token_addresses, + "weights": weights, + "swap_fee": fee, + "create_time": p.get("createTime", 0), + "current_tvl": tvl, + }) + + print(f"{len(pools)} pools") + time.sleep(0.3) + + df = pd.DataFrame(all_pools) + print(f"\n Total: {len(df)} pools across {len(chains)} chains") + return df + + +def fetch_pool_snapshots(pool_id: str, chain: str, + base_url: str = BALANCER_API_URL) -> pd.DataFrame: + """Fetch ALL_TIME daily snapshots for a single pool. + + Returns + ------- + pd.DataFrame + Columns: timestamp, volume_usd, total_liquidity_usd. 
+ """ + query = { + "query": """ + query GetSnapshots($poolId: String!, $chain: GqlChain!, + $range: GqlPoolSnapshotDataRange!) { + poolGetSnapshots(id: $poolId, chain: $chain, range: $range) { + timestamp + volume24h + totalLiquidity + } + } + """, + "variables": { + "poolId": pool_id, + "chain": chain, + "range": "ALL_TIME", + }, + } + + body = _graphql_request(query) + snapshots = body.get("data", {}).get("poolGetSnapshots", []) + + if not snapshots: + return pd.DataFrame(columns=["timestamp", "volume_usd", "total_liquidity_usd"]) + + records = [] + for snap in snapshots: + records.append({ + "timestamp": int(snap["timestamp"]), + "volume_usd": float(snap["volume24h"]), + "total_liquidity_usd": float(snap["totalLiquidity"]), + }) + + df = pd.DataFrame(records) + df["date"] = pd.to_datetime(df["timestamp"], unit="s").dt.date + # Deduplicate by date (keep last snapshot per day) + df = df.sort_values("timestamp").drop_duplicates("date", keep="last") + return df + + +def fetch_all_snapshots(pools_df: pd.DataFrame, + cache_path: str = None) -> pd.DataFrame: + """Fetch daily snapshots for all pools, with caching. + + Parameters + ---------- + pools_df : pd.DataFrame + Pool enumeration from enumerate_balancer_pools. + cache_path : str, optional + Path to parquet cache. If it exists, only fetch missing pools. + + Returns + ------- + pd.DataFrame + Panel with columns: pool_id, chain, date, volume_usd, + total_liquidity_usd. 
+ """ + # Load cache if exists + cached = pd.DataFrame() + cached_pool_ids = set() + if cache_path and os.path.exists(cache_path): + cached = pd.read_parquet(cache_path) + cached_pool_ids = set(cached["pool_id"].unique()) + print(f" Cache has {len(cached_pool_ids)} pools, " + f"{len(cached)} pool-days") + + # Determine which pools need fetching + if len(pools_df) == 0: + print(" No pools to fetch.") + return cached if len(cached) > 0 else pd.DataFrame( + columns=["pool_id", "chain", "date", "volume_usd", + "total_liquidity_usd"] + ) + to_fetch = pools_df[~pools_df["pool_id"].isin(cached_pool_ids)] + print(f" Need to fetch {len(to_fetch)} new pools") + + new_records = [] + for i, (_, pool) in enumerate(to_fetch.iterrows()): + if (i + 1) % 10 == 0 or i == 0: + print(f" Fetching {i+1}/{len(to_fetch)}: {pool['pool_id'][:10]}... " + f"({pool['chain']})", flush=True) + try: + snap_df = fetch_pool_snapshots(pool["pool_id"], pool["chain"]) + if len(snap_df) > 0: + snap_df["pool_id"] = pool["pool_id"] + snap_df["chain"] = pool["chain"] + new_records.append(snap_df[ + ["pool_id", "chain", "date", "volume_usd", + "total_liquidity_usd"] + ]) + except Exception as e: + print(f" FAILED {pool['pool_id'][:10]}: {e}") + time.sleep(0.5) # Rate limit + + if new_records: + new_df = pd.concat(new_records, ignore_index=True) + combined = pd.concat([cached, new_df], ignore_index=True) + else: + combined = cached + + # Save cache + if cache_path and len(combined) > 0: + os.makedirs(os.path.dirname(cache_path), exist_ok=True) + combined.to_parquet(cache_path, index=False) + print(f" Saved cache: {len(combined)} pool-days → {cache_path}") + + return combined + + +def fetch_token_prices(token_addresses_by_chain: dict, + cache_dir: str = None) -> dict: + """Fetch hourly token prices from Balancer API. + + Parameters + ---------- + token_addresses_by_chain : dict + {chain: {symbol: address, ...}, ...} + cache_dir : str, optional + Directory for per-token price caches. 
+ + Returns + ------- + dict + {(chain, symbol): pd.DataFrame with columns [timestamp, price], ...} + """ + if cache_dir: + os.makedirs(cache_dir, exist_ok=True) + + prices = {} + + for chain, tokens in token_addresses_by_chain.items(): + # Check cache first, collect uncached addresses + uncached = {} + for symbol, address in tokens.items(): + cache_key = f"{chain}_{symbol}".replace("/", "_") + cp = os.path.join(cache_dir, f"{cache_key}.parquet") if cache_dir else None + + if cp and os.path.exists(cp): + prices[(chain, symbol)] = pd.read_parquet(cp) + else: + uncached[symbol] = address + + if not uncached: + continue + + # Batch fetch: API supports multiple addresses per request + addr_to_symbol = {addr: sym for sym, addr in uncached.items()} + addresses = list(uncached.values()) + + print(f" Fetching {len(addresses)} prices on {chain}...", + flush=True) + + # Batch in groups of 20 to avoid oversized requests + batch_size = 20 + for batch_start in range(0, len(addresses), batch_size): + batch_addrs = addresses[batch_start:batch_start + batch_size] + query = { + "query": """ + query GetPrices($chain: GqlChain!, $addresses: [String!]!, + $range: GqlTokenChartDataRange!) 
{ + tokenGetHistoricalPrices( + addresses: $addresses, chain: $chain, range: $range + ) { + address + prices { + timestamp + price + } + } + } + """, + "variables": { + "chain": chain, + "addresses": batch_addrs, + "range": "ONE_YEAR", + }, + } + + try: + body = _graphql_request(query, timeout=60) + results = body.get("data", {}).get( + "tokenGetHistoricalPrices", []) + for result in results: + addr = result.get("address", "") + price_list = result.get("prices", []) + symbol = addr_to_symbol.get(addr) + if symbol and price_list: + pdf = pd.DataFrame(price_list) + pdf["timestamp"] = pdf["timestamp"].astype(int) + pdf["price"] = pdf["price"].astype(float) + prices[(chain, symbol)] = pdf + if cache_dir: + cache_key = f"{chain}_{symbol}".replace("/", "_") + cp = os.path.join(cache_dir, f"{cache_key}.parquet") + pdf.to_parquet(cp, index=False) + except Exception as e: + print(f" FAILED batch on {chain}: {e}") + + time.sleep(0.5) + + print(f" Got prices for {len(prices)} token-chain pairs") + return prices + + +def compute_pair_volatility( + snapshots_df: pd.DataFrame, + pool_row: pd.Series, + token_prices: dict, +) -> pd.Series: + """Compute daily annualised volatility for a pool's pair ratio. + + Uses hourly prices from the API to compute daily realised volatility. + Falls back to a default of 0.5 if price data is insufficient. + + Parameters + ---------- + snapshots_df : pd.DataFrame + Pool's daily snapshots (need dates). + pool_row : pd.Series + Pool metadata row (need tokens, chain). + token_prices : dict + {(chain, symbol): DataFrame, ...} + + Returns + ------- + pd.Series + Indexed by date, values are annualised daily volatility. 
+ """ + tokens = pool_row["tokens"] + chain = pool_row["chain"] + + if len(tokens) < 2: + return pd.Series(dtype=float) + + # Get price series for token[0] and token[1] + # Try chain-specific first, then any chain + def _get_price_df(symbol): + # Exact match + key = (chain, symbol) + if key in token_prices: + return token_prices[key] + # Any chain + for k, v in token_prices.items(): + if k[1] == symbol: + return v + return None + + p0_df = _get_price_df(tokens[0]) + p1_df = _get_price_df(tokens[1]) + + # If either is a stablecoin, use $1 + stables = {"USDC", "USDT", "DAI", "LUSD", "GHO", "crvUSD", "sDAI", + "WXDAI", "xDAI", "USDC.e", "USDbC"} + + if tokens[0] in stables and tokens[1] in stables: + # Stable-stable pair: near-zero vol + dates = snapshots_df["date"].unique() + return pd.Series(0.01, index=dates) + + if p0_df is None and tokens[0] not in stables: + dates = snapshots_df["date"].unique() + return pd.Series(0.5, index=dates) # fallback + if p1_df is None and tokens[1] not in stables: + dates = snapshots_df["date"].unique() + return pd.Series(0.5, index=dates) # fallback + + # Build hourly price ratio + if tokens[0] in stables: + # ratio = 1 / p1 + if p1_df is None or len(p1_df) == 0: + dates = snapshots_df["date"].unique() + return pd.Series(0.5, index=dates) + ratio_df = p1_df.copy() + ratio_df["ratio"] = 1.0 / ratio_df["price"] + elif tokens[1] in stables: + # ratio = p0 + if p0_df is None or len(p0_df) == 0: + dates = snapshots_df["date"].unique() + return pd.Series(0.5, index=dates) + ratio_df = p0_df.copy() + ratio_df["ratio"] = ratio_df["price"] + else: + # Both non-stable: ratio = p0/p1 + if p0_df is None or p1_df is None or len(p0_df) == 0 or len(p1_df) == 0: + dates = snapshots_df["date"].unique() + return pd.Series(0.5, index=dates) + # Merge on nearest timestamp + merged = pd.merge_asof( + p0_df.sort_values("timestamp"), + p1_df.sort_values("timestamp"), + on="timestamp", + suffixes=("_0", "_1"), + tolerance=7200, # 2 hour tolerance + 
).dropna() + if len(merged) == 0: + dates = snapshots_df["date"].unique() + return pd.Series(0.5, index=dates) + ratio_df = merged.copy() + ratio_df["ratio"] = merged["price_0"] / merged["price_1"] + + ratio_df["datetime"] = pd.to_datetime(ratio_df["timestamp"], unit="s") + ratio_df["date"] = ratio_df["datetime"].dt.date + ratio_df = ratio_df.sort_values("timestamp") + + # Log returns + ratio_df["log_return"] = np.log( + ratio_df["ratio"] / ratio_df["ratio"].shift(1) + ) + ratio_df = ratio_df.dropna(subset=["log_return"]) + + # Daily vol from hourly returns, annualised + daily_vol = ratio_df.groupby("date")["log_return"].std() + # Hourly data → ~24 returns/day. Annualise: σ_daily * sqrt(365) + # But std() already gives daily std from hourly returns, so: + # σ_annual = σ_hourly * sqrt(24 * 365) + daily_vol_ann = daily_vol * np.sqrt(24 * 365) + + return daily_vol_ann + + +def assemble_panel( + pools_df: pd.DataFrame, + snapshots_df: pd.DataFrame, + token_prices: dict, +) -> pd.DataFrame: + """Assemble the full panel DataFrame for hierarchical estimation. + + Parameters + ---------- + pools_df : pd.DataFrame + Pool enumeration from enumerate_balancer_pools. + snapshots_df : pd.DataFrame + Daily snapshots from fetch_all_snapshots. + token_prices : dict + Token prices from fetch_token_prices. + + Returns + ------- + pd.DataFrame + Panel with columns: pool_id, chain, date, log_volume, log_tvl, + volatility, weekend, log_fee, tier_A, tier_B, tokens. 
+ """ + records = [] + pool_ids = snapshots_df["pool_id"].unique() + n_pools = len(pool_ids) + + for i, pool_id in enumerate(pool_ids): + if (i + 1) % 20 == 0 or i == 0: + print(f" Assembling {i+1}/{n_pools}...", flush=True) + + pool_snaps = snapshots_df[snapshots_df["pool_id"] == pool_id] + pool_meta = pools_df[pools_df["pool_id"] == pool_id] + if len(pool_meta) == 0: + continue + pool_row = pool_meta.iloc[0] + + tokens = pool_row["tokens"] + if len(tokens) < 2: + continue + + chain = pool_row["chain"] + swap_fee = pool_row["swap_fee"] + + # Token tiers: sort by tier (best tier first) + tiers = sorted([classify_token_tier(t) for t in tokens]) + tier_a = tiers[0] # best (lowest) tier + tier_b = tiers[1] if len(tiers) > 1 else tiers[0] + + # Compute volatility for this pool's pair + vol_series = compute_pair_volatility(pool_snaps, pool_row, token_prices) + + for _, snap in pool_snaps.iterrows(): + date = snap["date"] + volume = snap["volume_usd"] + tvl = snap["total_liquidity_usd"] + + # Skip zero/negative TVL or volume + if tvl <= 0 or volume <= 0: + continue + + # Volatility lookup + if isinstance(vol_series, pd.Series) and date in vol_series.index: + vol = vol_series[date] + else: + vol = 0.5 # fallback + + if not np.isfinite(vol) or vol <= 0: + vol = 0.5 + + # Weekend indicator + if isinstance(date, datetime): + is_weekend = date.weekday() >= 5 + else: + is_weekend = pd.Timestamp(date).weekday() >= 5 + + records.append({ + "pool_id": pool_id, + "chain": chain, + "date": date, + "log_volume": np.log(volume), + "log_tvl": np.log(tvl), + "volatility": vol, + "weekend": 1.0 if is_weekend else 0.0, + "log_fee": np.log(max(swap_fee, 1e-6)), + "tier_A": tier_a, + "tier_B": tier_b, + "tokens": ",".join(tokens[:2]), + "swap_fee": swap_fee, + }) + + panel = pd.DataFrame(records) + print(f"\n Panel: {len(panel)} observations, " + f"{panel['pool_id'].nunique()} pools, " + f"{panel['chain'].nunique()} chains") + + return panel + + +# 
--------------------------------------------------------------------------- +# Phase 2: Bayesian hierarchical model +# --------------------------------------------------------------------------- + +def _encode_covariates(panel: pd.DataFrame) -> dict: + """Build NumPyro-ready arrays from the panel DataFrame. + + Returns + ------- + dict with keys: + pool_idx : (N_obs,) int array mapping each observation to its pool + X_pool : (N_pools, K) covariate matrix (intercept + dummies + log_fee) + log_tvl, volatility, weekend, log_volume : (N_obs,) float arrays + pool_ids : (N_pools,) pool ID strings + covariate_names : list of str, column names for X_pool + ref_chain, ref_tier_a, ref_tier_b : reference categories + chains : sorted list of all chains + pool_meta : DataFrame of per-pool metadata + """ + pool_meta = panel.drop_duplicates("pool_id").reset_index(drop=True) + pool_ids = pool_meta["pool_id"].values + pool_id_to_idx = {pid: i for i, pid in enumerate(pool_ids)} + + pool_idx = panel["pool_id"].map(pool_id_to_idx).values + + # Build X_pool columns + chains = sorted(panel["chain"].unique()) + ref_chain = chains[0] + chain_cols = [] + chain_names = [] + for c in chains[1:]: + chain_cols.append((pool_meta["chain"] == c).astype(float).values) + chain_names.append(f"chain_{c}") + + tier_a_vals = sorted(pool_meta["tier_A"].astype(str).unique()) + ref_tier_a = tier_a_vals[0] + tier_a_cols = [] + tier_a_names = [] + for t in tier_a_vals[1:]: + tier_a_cols.append( + (pool_meta["tier_A"].astype(str) == t).astype(float).values + ) + tier_a_names.append(f"tier_A_{t}") + + tier_b_vals = sorted(pool_meta["tier_B"].astype(str).unique()) + ref_tier_b = tier_b_vals[0] + tier_b_cols = [] + tier_b_names = [] + for t in tier_b_vals[1:]: + tier_b_cols.append( + (pool_meta["tier_B"].astype(str) == t).astype(float).values + ) + tier_b_names.append(f"tier_B_{t}") + + N_pools = len(pool_ids) + columns = [np.ones((N_pools, 1))] + col_names = ["intercept"] + + for arr, name in zip(chain_cols, 
chain_names): + columns.append(arr.reshape(-1, 1)) + col_names.append(name) + for arr, name in zip(tier_a_cols, tier_a_names): + columns.append(arr.reshape(-1, 1)) + col_names.append(name) + for arr, name in zip(tier_b_cols, tier_b_names): + columns.append(arr.reshape(-1, 1)) + col_names.append(name) + columns.append(pool_meta["log_fee"].values.reshape(-1, 1)) + col_names.append("log_fee") + + X_pool = np.hstack(columns) + + return { + "pool_idx": pool_idx.astype(np.int32), + "X_pool": X_pool.astype(np.float64), + "log_tvl": panel["log_tvl"].values.astype(np.float64), + "volatility": panel["volatility"].values.astype(np.float64), + "weekend": panel["weekend"].values.astype(np.float64), + "log_volume": panel["log_volume"].values.astype(np.float64), + "pool_ids": pool_ids, + "covariate_names": col_names, + "ref_chain": ref_chain, + "ref_tier_a": ref_tier_a, + "ref_tier_b": ref_tier_b, + "chains": chains, + "pool_meta": pool_meta, + } + + +def _hierarchical_noise_model( + pool_idx, X_pool, log_tvl, volatility, weekend, log_volume=None, +): + """NumPyro model: Bayesian hierarchical loglinear noise volume. + + All pool covariates modulate all three coefficients (α, β_tvl, β_vol) + through the group-level regression matrix Φ, with correlated random + effects via LKJ-Cholesky. 
+ """ + N_pools = X_pool.shape[0] + K = X_pool.shape[1] + + # Hyperpriors + Phi = numpyro.sample("Phi", dist.Normal(0, 2).expand([K, 3]).to_event(2)) + sigma_theta = numpyro.sample( + "sigma_theta", dist.HalfNormal(2).expand([3]).to_event(1) + ) + L_omega = numpyro.sample("L_omega", dist.LKJCholesky(3, concentration=2)) + beta_weekend = numpyro.sample("beta_weekend", dist.Normal(0, 2)) + sigma_eps = numpyro.sample("sigma_eps", dist.HalfNormal(3)) + + # Non-centered pool random effects + with numpyro.plate("pools", N_pools): + z = numpyro.sample("z", dist.Normal(jnp.zeros(3), 1).to_event(1)) + + # θ_i = Φᵀx_i + diag(σ_θ)·L_ω·z_i + mu = X_pool @ Phi # (N_pools, 3) + L_Sigma = sigma_theta[:, None] * L_omega # (3, 3) + theta = mu + z @ L_Sigma.T # (N_pools, 3) + + alpha = theta[:, 0] + beta_tvl = theta[:, 1] + beta_vol = theta[:, 2] + + # Observation model + loc = (alpha[pool_idx] + + beta_tvl[pool_idx] * log_tvl + + beta_vol[pool_idx] * volatility + + beta_weekend * weekend) + + with numpyro.plate("obs", pool_idx.shape[0]): + numpyro.sample("log_volume", dist.Normal(loc, sigma_eps), obs=log_volume) + + +def fit_bayesian_model( + panel: pd.DataFrame, use_nuts: bool = False, +) -> tuple: + """Fit the Bayesian hierarchical model via SVI or NUTS. + + Parameters + ---------- + panel : pd.DataFrame + Panel from assemble_panel. + use_nuts : bool + If True, use NUTS MCMC (slower, exact). Otherwise SVI with + AutoMultivariateNormal guide. + + Returns + ------- + samples : dict + Posterior samples keyed by parameter name. + encoding : dict + From _encode_covariates (needed downstream). 
+ """ + encoding = _encode_covariates(panel) + + model_kwargs = dict( + pool_idx=jnp.array(encoding["pool_idx"]), + X_pool=jnp.array(encoding["X_pool"]), + log_tvl=jnp.array(encoding["log_tvl"]), + volatility=jnp.array(encoding["volatility"]), + weekend=jnp.array(encoding["weekend"]), + log_volume=jnp.array(encoding["log_volume"]), + ) + + N_pools = encoding["X_pool"].shape[0] + K = encoding["X_pool"].shape[1] + print(f" N obs = {len(encoding['pool_idx'])}, " + f"N pools = {N_pools}, K covariates = {K}") + print(f" Covariates: {encoding['covariate_names']}") + + rng_key = jax.random.PRNGKey(0) + + if use_nuts: + print(" Running NUTS (500 warmup + 1000 samples)...") + kernel = NUTS(_hierarchical_noise_model) + mcmc = MCMC(kernel, num_warmup=500, num_samples=1000, num_chains=1) + mcmc.run(rng_key, **model_kwargs) + samples = mcmc.get_samples() + print(" NUTS complete.") + else: + print(" Running SVI with AutoMultivariateNormal (20k steps)...") + guide = AutoMultivariateNormal(_hierarchical_noise_model) + optimizer = numpyro.optim.Adam(1e-3) + svi = SVI(_hierarchical_noise_model, guide, optimizer, + loss=Trace_ELBO()) + svi_result = svi.run(rng_key, 20_000, **model_kwargs) + print(f" SVI complete. Final ELBO loss: {svi_result.losses[-1]:.2f}") + + predictive = Predictive( + guide, params=svi_result.params, num_samples=1000, + ) + samples = predictive(jax.random.PRNGKey(1), **model_kwargs) + print(" Drew 1000 posterior samples.") + + return samples, encoding + + +def extract_posteriors(samples: dict, encoding: dict) -> dict: + """Reconstruct pool-specific coefficients from posterior samples. + + Computes θ_i = Φᵀx_i + diag(σ_θ)·L_ω·z_i for each posterior draw, + then returns posterior means and variance components. 
+ + Returns + ------- + dict with keys: + pool_effects : {pool_id: {alpha, beta_tvl, beta_vol}} + Phi_mean : (K, 3) array + sigma_theta_mean : (3,) array + correlation_matrix : (3, 3) array + beta_weekend_mean : float + sigma_eps_mean : float + theta_samples : (S, N_pools, 3) array (for diagnostics) + """ + Phi = np.array(samples["Phi"]) # (S, K, 3) + sigma_theta = np.array(samples["sigma_theta"]) # (S, 3) + L_omega = np.array(samples["L_omega"]) # (S, 3, 3) + z = np.array(samples["z"]) # (S, N_pools, 3) + beta_weekend = np.array(samples["beta_weekend"]) # (S,) + sigma_eps = np.array(samples["sigma_eps"]) # (S,) + + X_pool = encoding["X_pool"] # (N_pools, K) + pool_ids = encoding["pool_ids"] + + # mu = X_pool @ Phi for each sample: (S, N_pools, 3) + mu = np.einsum("pk,skj->spj", X_pool, Phi) + + # L_Sigma = diag(sigma_theta) @ L_omega: (S, 3, 3) + L_Sigma = sigma_theta[:, :, None] * L_omega + + # offset = z @ L_Sigma^T: (S, N_pools, 3) + offset = np.einsum("spi,sji->spj", z, L_Sigma) + + theta = mu + offset # (S, N_pools, 3) + theta_mean = theta.mean(axis=0) # (N_pools, 3) + + pool_effects = {} + for i, pid in enumerate(pool_ids): + pool_effects[pid] = { + "alpha": float(theta_mean[i, 0]), + "beta_tvl": float(theta_mean[i, 1]), + "beta_vol": float(theta_mean[i, 2]), + } + + Phi_mean = Phi.mean(axis=0) # (K, 3) + + # Correlation matrix: R = L_omega @ L_omega^T, averaged over samples + R_samples = np.einsum("sij,skj->sik", L_omega, L_omega) + R_mean = R_samples.mean(axis=0) + + return { + "pool_effects": pool_effects, + "Phi_mean": Phi_mean, + "sigma_theta_mean": sigma_theta.mean(axis=0), + "correlation_matrix": R_mean, + "beta_weekend_mean": float(beta_weekend.mean()), + "sigma_eps_mean": float(sigma_eps.mean()), + "theta_samples": theta, + } + + +def compute_noise_params(posteriors: dict, panel: pd.DataFrame) -> list: + """Convert posterior pool effects to per-pool noise_params dicts. 
+ + Each pool now has its own β_tvl and β_vol (from the hierarchical + posterior), rather than sharing global slopes. + + Parameters + ---------- + posteriors : dict + From extract_posteriors. + panel : pd.DataFrame + Panel data. + + Returns + ------- + list of dict + Each dict has: pool_id, chain, tokens, noise_params. + """ + pool_effects = posteriors["pool_effects"] + pool_meta = panel.drop_duplicates("pool_id").set_index("pool_id") + + results = [] + for pool_id, effects in pool_effects.items(): + if pool_id not in pool_meta.index: + continue + meta = pool_meta.loc[pool_id] + swap_fee = float(meta.get("swap_fee", 0.003)) + + results.append({ + "pool_id": pool_id, + "chain": meta["chain"], + "tokens": (meta["tokens"].split(",") + if isinstance(meta["tokens"], str) + else meta["tokens"]), + "noise_params": { + "b_0": effects["alpha"], + "b_sigma": effects["beta_vol"], + "b_c": effects["beta_tvl"], + "base_fee": swap_fee, + }, + }) + + return results + + +def _build_covariate_vector( + encoding: dict, chain: str, tokens: list, fee: float, +) -> np.ndarray: + """Construct a covariate vector x for a new pool. + + Matches the column order of X_pool from _encode_covariates so that + x @ Phi_mean gives population-level predictions for all 3 coefficients. + """ + col_names = encoding["covariate_names"] + x = np.zeros(len(col_names)) + + tiers = sorted([classify_token_tier(t) for t in tokens]) + tier_a = str(tiers[0]) + tier_b = str(tiers[1]) if len(tiers) > 1 else tier_a + + for i, name in enumerate(col_names): + if name == "intercept": + x[i] = 1.0 + elif name == "log_fee": + x[i] = np.log(max(fee, 1e-6)) + elif name == f"chain_{chain}": + x[i] = 1.0 + elif name == f"tier_A_{tier_a}": + x[i] = 1.0 + elif name == f"tier_B_{tier_b}": + x[i] = 1.0 + + return x + + +def predict_new_pool( + posteriors: dict, + encoding: dict, + chain: str, + tokens: list, + fee: float, +) -> dict: + """Predict noise params for an unseen pool. 
+ + Uses population-level estimates only (x @ Φ, no pool random effect). + + Parameters + ---------- + posteriors : dict + From extract_posteriors. + encoding : dict + From _encode_covariates (or loaded from cache). + chain : str + Chain API identifier (e.g. "BASE"). + tokens : list + Token symbols (e.g. ["ETH", "BTC"]). + fee : float + Swap fee rate. + + Returns + ------- + dict + noise_params dict with pool-predicted coefficients. + """ + x = _build_covariate_vector(encoding, chain, tokens, fee) + Phi_mean = posteriors["Phi_mean"] + + theta_pred = x @ Phi_mean # (3,) — population-level prediction + + tiers = sorted([classify_token_tier(t) for t in tokens]) + tier_a = tiers[0] + tier_b = tiers[1] if len(tiers) > 1 else tiers[0] + + return { + "b_0": float(theta_pred[0]), + "b_sigma": float(theta_pred[2]), + "b_c": float(theta_pred[1]), + "base_fee": float(fee), + "_prediction_source": "population_level", + "_alpha": float(theta_pred[0]), + "_beta_tvl": float(theta_pred[1]), + "_beta_vol": float(theta_pred[2]), + "_tier_a": tier_a, + "_tier_b": tier_b, + } + + +# --------------------------------------------------------------------------- +# Phase 3: Diagnostics and output +# --------------------------------------------------------------------------- + +def plot_hierarchical_diagnostics( + panel: pd.DataFrame, + posteriors: dict, + encoding: dict, + output_dir: str = "results", +): + """Generate diagnostic plots for the Bayesian hierarchical model. 
+ + Figure 1 (2x2): + (0,0) Pool-specific coefficient distributions (α, β_tvl, β_vol) + (0,1) Chain effects on all 3 coefficients + (1,0) Tier effects on all 3 coefficients + (1,1) Model summary (Φ, σ_θ, correlations, σ_ε) + + Figure 2: + β_tvl vs β_vol scatter colored by chain + """ + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + os.makedirs(output_dir, exist_ok=True) + + pool_effects = posteriors["pool_effects"] + Phi_mean = posteriors["Phi_mean"] + col_names = encoding["covariate_names"] + pool_meta = panel.drop_duplicates("pool_id") + + alphas = [e["alpha"] for e in pool_effects.values()] + beta_tvls = [e["beta_tvl"] for e in pool_effects.values()] + beta_vols = [e["beta_vol"] for e in pool_effects.values()] + + # --- Figure 1: Diagnostics 2x2 --- + fig, axes = plt.subplots(2, 2, figsize=(14, 10)) + + # (0,0) Pool-specific coefficient distributions + ax = axes[0, 0] + bins = 25 + ax.hist(alphas, bins=bins, alpha=0.6, label="α (intercept)", + color="steelblue", edgecolor="white") + ax.hist(beta_tvls, bins=bins, alpha=0.6, label="β_tvl", + color="coral", edgecolor="white") + ax.hist(beta_vols, bins=bins, alpha=0.6, label="β_vol", + color="seagreen", edgecolor="white") + ax.set_xlabel("Coefficient value") + ax.set_ylabel("Count") + ax.set_title(f"Pool-specific coefficients (n={len(alphas)})") + ax.legend() + + # (0,1) Chain effects on all 3 coefficients + ax = axes[0, 1] + chain_rows = {name: i for i, name in enumerate(col_names) + if name.startswith("chain_")} + chain_labels = [] + chain_alpha_effects = [] + chain_tvl_effects = [] + chain_vol_effects = [] + # Reference chain (effect = 0) + ref_chain = encoding["ref_chain"] + chain_counts = pool_meta["chain"].value_counts() + chain_labels.append(f"{ref_chain}\n(n={chain_counts.get(ref_chain, 0)})") + chain_alpha_effects.append(0.0) + chain_tvl_effects.append(0.0) + chain_vol_effects.append(0.0) + for name, row_idx in sorted(chain_rows.items()): + chain_name = 
name.replace("chain_", "") + chain_labels.append( + f"{chain_name}\n(n={chain_counts.get(chain_name, 0)})" + ) + chain_alpha_effects.append(Phi_mean[row_idx, 0]) + chain_tvl_effects.append(Phi_mean[row_idx, 1]) + chain_vol_effects.append(Phi_mean[row_idx, 2]) + + y_pos = np.arange(len(chain_labels)) + bar_h = 0.25 + ax.barh(y_pos - bar_h, chain_alpha_effects, bar_h, label="α", + color="steelblue", alpha=0.8) + ax.barh(y_pos, chain_tvl_effects, bar_h, label="β_tvl", + color="coral", alpha=0.8) + ax.barh(y_pos + bar_h, chain_vol_effects, bar_h, label="β_vol", + color="seagreen", alpha=0.8) + ax.set_yticks(y_pos) + ax.set_yticklabels(chain_labels) + ax.axvline(0, color="red", linestyle="--", linewidth=1) + ax.set_xlabel("Effect (relative to reference)") + ax.set_title("Chain effects on all coefficients") + ax.legend(fontsize=8) + + # (1,0) Tier effects on all 3 coefficients + ax = axes[1, 0] + tier_names = ["0 (blue-chip)", "1 (mid-cap)", "2 (long-tail)"] + coeff_labels = ["α", "β_tvl", "β_vol"] + tier_a_rows = {name: i for i, name in enumerate(col_names) + if name.startswith("tier_A_")} + tier_b_rows = {name: i for i, name in enumerate(col_names) + if name.startswith("tier_B_")} + + # Build effects matrix: (3 tiers) x (3 coefficients) x (A/B) + x_pos = np.arange(len(tier_names)) + width = 0.13 + for coeff_idx, (coeff_name, color) in enumerate( + zip(coeff_labels, ["steelblue", "coral", "seagreen"]) + ): + tier_a_vals = [0.0, 0.0, 0.0] # reference tier gets 0 + tier_b_vals = [0.0, 0.0, 0.0] + for name, row_idx in tier_a_rows.items(): + tier_val = name.replace("tier_A_", "") + if tier_val in ("0", "1", "2"): + tier_a_vals[int(tier_val)] = Phi_mean[row_idx, coeff_idx] + for name, row_idx in tier_b_rows.items(): + tier_val = name.replace("tier_B_", "") + if tier_val in ("0", "1", "2"): + tier_b_vals[int(tier_val)] = Phi_mean[row_idx, coeff_idx] + offset = (coeff_idx - 1) * width * 2 + ax.bar(x_pos + offset - width / 2, tier_a_vals, width, + label=f"{coeff_name} (A)" if 
coeff_idx == 0 else "", + color=color, alpha=0.7, edgecolor="white") + ax.bar(x_pos + offset + width / 2, tier_b_vals, width, + label=f"{coeff_name} (B)" if coeff_idx == 0 else "", + color=color, alpha=0.4, edgecolor="white", hatch="//") + + ax.set_xticks(x_pos) + ax.set_xticklabels(tier_names) + ax.set_ylabel("Effect on coefficient") + ax.set_title("Token tier effects (solid=A, hatched=B)") + ax.axhline(0, color="black", linewidth=0.5) + # Manual legend for coefficient colors + from matplotlib.patches import Patch + ax.legend(handles=[Patch(color=c, label=l) for c, l in + zip(["steelblue", "coral", "seagreen"], coeff_labels)], + fontsize=8) + + # (1,1) Model summary text + ax = axes[1, 1] + ax.axis("off") + sigma_theta = posteriors["sigma_theta_mean"] + R = posteriors["correlation_matrix"] + summary = "Group-level regression Φ (posterior mean):\n" + summary += f" {'covariate':<20s} {'α':>8s} {'β_tvl':>8s} {'β_vol':>8s}\n" + summary += " " + "-" * 46 + "\n" + for j, name in enumerate(col_names): + summary += (f" {name:<20s} {Phi_mean[j,0]:>8.3f} " + f"{Phi_mean[j,1]:>8.3f} {Phi_mean[j,2]:>8.3f}\n") + summary += f"\nσ_θ: [{sigma_theta[0]:.3f}, {sigma_theta[1]:.3f}, " + summary += f"{sigma_theta[2]:.3f}]\n" + summary += f"Correlation:\n" + for i in range(3): + summary += f" [{R[i,0]:>6.3f} {R[i,1]:>6.3f} {R[i,2]:>6.3f}]\n" + summary += f"β_weekend: {posteriors['beta_weekend_mean']:.4f}\n" + summary += f"σ_ε: {posteriors['sigma_eps_mean']:.4f}\n" + ax.text(0.02, 0.98, summary, transform=ax.transAxes, + fontsize=7, verticalalignment="top", fontfamily="monospace") + ax.set_title("Model Summary") + + fig.suptitle( + "Bayesian Hierarchical Noise Model — Diagnostics", fontsize=13, + ) + plt.tight_layout() + path = os.path.join(output_dir, "hierarchical_diagnostics.png") + plt.savefig(path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path}") + + # --- Figure 2: β_tvl vs β_vol scatter colored by chain --- + fig2, ax2 = plt.subplots(figsize=(10, 7)) + 
pool_id_to_chain = dict( + zip(pool_meta["pool_id"], pool_meta["chain"]) + ) + chain_colors = {} + cmap = plt.cm.tab10 + unique_chains = sorted(pool_meta["chain"].unique()) + for i, c in enumerate(unique_chains): + chain_colors[c] = cmap(i % 10) + + for pid, effects in pool_effects.items(): + c = pool_id_to_chain.get(pid, "?") + ax2.scatter(effects["beta_tvl"], effects["beta_vol"], + color=chain_colors.get(c, "gray"), alpha=0.6, s=20, + edgecolors="white", linewidths=0.3) + + # Legend + from matplotlib.lines import Line2D + handles = [Line2D([0], [0], marker="o", color="w", + markerfacecolor=chain_colors[c], markersize=8, + label=c) + for c in unique_chains if c in chain_colors] + ax2.legend(handles=handles, fontsize=8, loc="best") + ax2.set_xlabel("β_tvl (TVL elasticity)") + ax2.set_ylabel("β_vol (volatility sensitivity)") + ax2.set_title("Pool-specific coefficients by chain") + ax2.axhline(0, color="gray", linewidth=0.5, linestyle="--") + ax2.axvline(0, color="gray", linewidth=0.5, linestyle="--") + plt.tight_layout() + path2 = os.path.join(output_dir, "beta_tvl_vs_beta_vol.png") + plt.savefig(path2, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {path2}") + + return path, path2 + + +def generate_noise_params_json( + pool_params: list, + posteriors: dict, + encoding: dict, + output_path: str, + inference_method: str = "svi", +): + """Write per-pool noise params to JSON. + + Parameters + ---------- + pool_params : list of dict + From compute_noise_params. + posteriors : dict + From extract_posteriors. + encoding : dict + From _encode_covariates. + output_path : str + Output JSON path. + inference_method : str + "svi" or "nuts". 
+ """ + output = { + "model": "bayesian_hierarchical_loglinear", + "inference_method": inference_method, + "Phi": posteriors["Phi_mean"].tolist(), + "covariate_names": encoding["covariate_names"], + "sigma_theta": posteriors["sigma_theta_mean"].tolist(), + "correlation_matrix": posteriors["correlation_matrix"].tolist(), + "beta_weekend": posteriors["beta_weekend_mean"], + "sigma_eps": posteriors["sigma_eps_mean"], + "pools": pool_params, + } + + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + with open(output_path, "w") as f: + json.dump(output, f, indent=2, default=str) + print(f" Wrote {len(pool_params)} pool params → {output_path}") + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser( + description="Bayesian hierarchical noise volume model for Balancer pools" + ) + parser.add_argument( + "--fetch", action="store_true", + help="Fetch pool data from Balancer API (cached to local_data/)", + ) + parser.add_argument( + "--fit", action="store_true", + help="Fit Bayesian hierarchical model (requires fetched data)", + ) + parser.add_argument( + "--nuts", action="store_true", + help="Use NUTS MCMC instead of SVI (slower, exact posteriors)", + ) + parser.add_argument( + "--plot", action="store_true", + help="Generate diagnostic plots", + ) + parser.add_argument( + "--output", default=None, + help="Output JSON path for per-pool noise params", + ) + parser.add_argument( + "--output-dir", default="results", + help="Directory for diagnostic plots (default: results)", + ) + parser.add_argument( + "--predict", action="store_true", + help="Predict noise params for a new pool", + ) + parser.add_argument( + "--chain", default=None, + help="Chain for --predict (e.g. BASE, MAINNET)", + ) + parser.add_argument( + "--tokens", nargs="+", default=None, + help="Token symbols for --predict (e.g. 
ETH BTC)", + ) + parser.add_argument( + "--fee", type=float, default=0.003, + help="Swap fee for --predict", + ) + parser.add_argument( + "--min-tvl", type=float, default=10000.0, + help="Minimum TVL filter for pool enumeration", + ) + parser.add_argument( + "--cache-dir", default=None, + help="Cache directory (default: local_data/noise_calibration/)", + ) + args = parser.parse_args() + + cache_dir = args.cache_dir or CACHE_DIR + + if not any([args.fetch, args.fit, args.predict]): + parser.error("At least one of --fetch, --fit, --predict is required") + + # --- Fetch --- + pools_cache = os.path.join(cache_dir, "pools.parquet") + snaps_cache = os.path.join(cache_dir, "pool_snapshots.parquet") + prices_cache = os.path.join(cache_dir, "token_prices") + panel_cache = os.path.join(cache_dir, "panel.parquet") + + if args.fetch: + print("Phase 1: Fetching data from Balancer API") + print("=" * 60) + + # Step 1: Enumerate pools + print("\n1. Enumerating pools...") + pools_df = enumerate_balancer_pools(min_tvl=args.min_tvl) + os.makedirs(cache_dir, exist_ok=True) + pools_df.to_parquet(pools_cache, index=False) + print(f" Saved {len(pools_df)} pools → {pools_cache}") + + # Step 2: Fetch snapshots + print("\n2. Fetching daily snapshots...") + snapshots_df = fetch_all_snapshots(pools_df, cache_path=snaps_cache) + + # Step 3: Fetch token prices + print("\n3. Fetching token prices...") + token_addr_by_chain = {} + for _, pool in pools_df.iterrows(): + chain = pool["chain"] + tokens = pool["tokens"] + addresses = pool["token_addresses"] + if chain not in token_addr_by_chain: + token_addr_by_chain[chain] = {} + for sym, addr in zip(tokens, addresses): + if sym and addr: + token_addr_by_chain[chain][sym] = addr + + token_prices = fetch_token_prices( + token_addr_by_chain, cache_dir=prices_cache + ) + + # Step 4: Assemble panel + print("\n4. 
Assembling panel...") + panel = assemble_panel(pools_df, snapshots_df, token_prices) + panel.to_parquet(panel_cache, index=False) + print(f" Saved panel → {panel_cache}") + + print(f"\nFetch complete. Panel: {len(panel)} obs, " + f"{panel['pool_id'].nunique()} pools") + + # --- Fit --- + if args.fit: + inference_method = "nuts" if args.nuts else "svi" + print(f"\nPhase 2: Fitting Bayesian hierarchical model ({inference_method})") + print("=" * 60) + + # Load panel + if not os.path.exists(panel_cache): + print(f"ERROR: Panel cache not found at {panel_cache}", + file=sys.stderr) + print("Run with --fetch first.", file=sys.stderr) + sys.exit(1) + + panel = pd.read_parquet(panel_cache) + print(f" Loaded panel: {len(panel)} obs, " + f"{panel['pool_id'].nunique()} pools, " + f"{panel['chain'].nunique()} chains") + + # Filter: need at least 10 days per pool for stable estimates + pool_counts = panel.groupby("pool_id").size() + valid_pools = pool_counts[pool_counts >= 10].index + panel_filtered = panel[panel["pool_id"].isin(valid_pools)] + print(f" After filtering (≥10 days): {len(panel_filtered)} obs, " + f"{panel_filtered['pool_id'].nunique()} pools") + + samples, encoding = fit_bayesian_model( + panel_filtered, use_nuts=args.nuts, + ) + posteriors = extract_posteriors(samples, encoding) + pool_params = compute_noise_params(posteriors, panel_filtered) + + # Print key diagnostics + Phi_mean = posteriors["Phi_mean"] + col_names = encoding["covariate_names"] + intercept_idx = col_names.index("intercept") + log_fee_idx = col_names.index("log_fee") + + print(f"\n Key results:") + print(f" Population intercept (Φ[intercept]):") + print(f" α: {Phi_mean[intercept_idx, 0]:.4f}") + print(f" β_tvl: {Phi_mean[intercept_idx, 1]:.4f}") + print(f" β_vol: {Phi_mean[intercept_idx, 2]:.4f}") + print(f" Fee effect (Φ[log_fee]):") + print(f" α: {Phi_mean[log_fee_idx, 0]:.4f}") + print(f" β_tvl: {Phi_mean[log_fee_idx, 1]:.4f}") + print(f" β_vol: {Phi_mean[log_fee_idx, 2]:.4f}") + print(f" 
β_weekend: {posteriors['beta_weekend_mean']:.4f}") + print(f" σ_θ: {posteriors['sigma_theta_mean']}") + print(f" σ_ε: {posteriors['sigma_eps_mean']:.4f}") + + # Verify pool-specific variation + b_sigmas = [p["noise_params"]["b_sigma"] for p in pool_params] + b_cs = [p["noise_params"]["b_c"] for p in pool_params] + print(f" b_sigma range: [{min(b_sigmas):.4f}, {max(b_sigmas):.4f}]") + print(f" b_c range: [{min(b_cs):.4f}, {max(b_cs):.4f}]") + + # Cache posteriors + encoding for --predict and --plot + posteriors_cache = os.path.join(cache_dir, "posteriors.json") + cache_data = { + "Phi_mean": posteriors["Phi_mean"].tolist(), + "sigma_theta_mean": posteriors["sigma_theta_mean"].tolist(), + "correlation_matrix": posteriors["correlation_matrix"].tolist(), + "beta_weekend_mean": posteriors["beta_weekend_mean"], + "sigma_eps_mean": posteriors["sigma_eps_mean"], + "pool_effects": posteriors["pool_effects"], + "covariate_names": encoding["covariate_names"], + "ref_chain": encoding["ref_chain"], + "ref_tier_a": encoding["ref_tier_a"], + "ref_tier_b": encoding["ref_tier_b"], + "chains": encoding["chains"], + "inference_method": inference_method, + } + with open(posteriors_cache, "w") as f: + json.dump(cache_data, f, indent=2, default=str) + print(f" Cached posteriors → {posteriors_cache}") + + if args.output: + generate_noise_params_json( + pool_params, posteriors, encoding, + args.output, inference_method=inference_method, + ) + + if args.plot: + print("\nPhase 3: Generating diagnostics") + print("=" * 60) + plot_hierarchical_diagnostics( + panel_filtered, posteriors, encoding, + output_dir=args.output_dir, + ) + + # --- Predict --- + if args.predict: + if args.chain is None or args.tokens is None: + parser.error("--predict requires --chain and --tokens") + + print(f"\nPredicting noise params for new pool:") + print(f" Chain: {args.chain}") + print(f" Tokens: {args.tokens}") + print(f" Fee: {args.fee}") + + # Load cached posteriors + encoding metadata + posteriors_cache = 
os.path.join(cache_dir, "posteriors.json") + if not os.path.exists(posteriors_cache): + print(f"ERROR: Posteriors cache not found at {posteriors_cache}", + file=sys.stderr) + print("Run with --fit first.", file=sys.stderr) + sys.exit(1) + + with open(posteriors_cache) as f: + cache_data = json.load(f) + + posteriors = { + "Phi_mean": np.array(cache_data["Phi_mean"]), + "sigma_theta_mean": np.array(cache_data["sigma_theta_mean"]), + "correlation_matrix": np.array(cache_data["correlation_matrix"]), + "beta_weekend_mean": cache_data["beta_weekend_mean"], + "sigma_eps_mean": cache_data["sigma_eps_mean"], + "pool_effects": cache_data["pool_effects"], + } + encoding = { + "covariate_names": cache_data["covariate_names"], + "ref_chain": cache_data["ref_chain"], + "ref_tier_a": cache_data["ref_tier_a"], + "ref_tier_b": cache_data["ref_tier_b"], + "chains": cache_data["chains"], + } + + params = predict_new_pool( + posteriors, encoding, args.chain, args.tokens, args.fee, + ) + print(f"\n Predicted noise_params:") + print(json.dumps(params, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/scripts/calibrate_noise_unified.py b/scripts/calibrate_noise_unified.py new file mode 100644 index 0000000..0e35527 --- /dev/null +++ b/scripts/calibrate_noise_unified.py @@ -0,0 +1,6 @@ +"""Thin wrapper — all logic lives in quantammsim.noise_calibration.""" +from quantammsim.noise_calibration import * # noqa: F401, F403 +from quantammsim.noise_calibration.cli import main + +if __name__ == "__main__": + main() diff --git a/scripts/calibrate_reclamm_noise.py b/scripts/calibrate_reclamm_noise.py new file mode 100644 index 0000000..8d09584 --- /dev/null +++ b/scripts/calibrate_reclamm_noise.py @@ -0,0 +1,842 @@ +"""OLS calibration of the Tsoukalas noise volume model for reClAMM pools. + +Fits the structural volume equation: + V_daily/1e6 = a_0 + a_sigma*sigma + a_c*sqrt(c_eff/1e6) + +where c_eff = (Ra+Va)*pA + (Rb+Vb)*pB is the effective TVL (real + virtual). 
# ---------------------------------------------------------------------------
# Balancer V3 API
# ---------------------------------------------------------------------------

BALANCER_API_URL = "https://api-v3.balancer.fi/"

# quantammsim chain name -> value sent as the GraphQL ``$chain`` variable.
BALANCER_API_CHAIN = {
    "base": "BASE",
    "ethereum": "MAINNET",
    "gnosis": "GNOSIS",
    "avalanche": "AVALANCHE",
    "arbitrum": "ARBITRUM",
    "polygon": "POLYGON",
    "optimism": "OPTIMISM",
    "sonic": "SONIC",
}


def fetch_balancer_snapshots(chain, pool_address, start_ts, end_ts,
                             base_url=BALANCER_API_URL):
    """Fetch daily pool snapshots from Balancer V3 GraphQL API.

    A single ``ALL_TIME`` query is issued; snapshots are then filtered to
    ``start_ts <= timestamp <= end_ts`` on the client.

    Parameters
    ----------
    chain : str
        Chain name (e.g. 'base', 'ethereum'); must be a key of
        ``BALANCER_API_CHAIN``.
    pool_address : str
        Pool contract address (hex, with or without the 0x prefix).
    start_ts : int
        Start unix timestamp (seconds).
    end_ts : int
        End unix timestamp (seconds).
    base_url : str
        Balancer API base URL.

    Returns
    -------
    pd.DataFrame
        Indexed by ``date`` (``datetime.date``), one row per day (the last
        snapshot of that day). Columns: timestamp, volume_usd,
        total_liquidity_usd.

    Raises
    ------
    ValueError
        If ``chain`` is unknown, or no snapshot falls inside the requested
        range (including when the API answers with errors and no data).
    """
    api_chain = BALANCER_API_CHAIN.get(chain)
    if api_chain is None:
        raise ValueError(f"Unknown chain for Balancer API: {chain!r}")

    pool_id = pool_address if pool_address.startswith("0x") else f"0x{pool_address}"
    no_data_msg = (
        f"No Balancer snapshots for {pool_id} on {chain} "
        f"between {start_ts} and {end_ts}"
    )

    # An empty or inverted range cannot match anything; skip the request.
    if start_ts >= end_ts:
        raise ValueError(no_data_msg)

    query = {
        "query": """
            query GetSnapshots($poolId: String!, $chain: GqlChain!,
                               $range: GqlPoolSnapshotDataRange!) {
                poolGetSnapshots(id: $poolId, chain: $chain, range: $range) {
                    timestamp
                    volume24h
                    totalLiquidity
                }
            }
        """,
        "variables": {
            "poolId": pool_id,
            "chain": api_chain,
            "range": "ALL_TIME",
        },
    }
    req = urllib.request.Request(
        base_url,
        data=json.dumps(query).encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            "User-Agent": "quantammsim/1.0",
        },
    )
    with urllib.request.urlopen(req, timeout=30) as resp:
        body = json.loads(resp.read().decode("utf-8"))

    # A failed GraphQL call carries explicit nulls ({"data": null, ...}),
    # which dict.get defaults do not cover -- hence the "or" fallbacks.
    snapshots = (body.get("data") or {}).get("poolGetSnapshots") or []

    rows = []
    for snap in snapshots:
        ts = int(snap["timestamp"])
        if start_ts <= ts <= end_ts:
            rows.append({
                "timestamp": ts,
                "volume_usd": float(snap["volume24h"]),
                "total_liquidity_usd": float(snap["totalLiquidity"]),
            })

    if not rows:
        errors = body.get("errors")
        if errors:
            no_data_msg += f" (API errors: {errors})"
        raise ValueError(no_data_msg)

    df = pd.DataFrame(rows)
    df["date"] = pd.to_datetime(df["timestamp"], unit="s").dt.date
    # Deduplicate by date (keep last snapshot per day)
    df = df.sort_values("timestamp").drop_duplicates("date", keep="last")
    return df.set_index("date")


# ---------------------------------------------------------------------------
# DB-based daily pool state
# ---------------------------------------------------------------------------

def load_daily_pool_state(pool, db_path, data_root):
    """Load daily pool state from pools_history.db, compute effective TVL.

    Parameters
    ----------
    pool : PoolConfig
        Pool configuration; ``db_label``, ``reverse`` and ``tokens`` are read.
    db_path : str
        Path to pools_history.db.
    data_root : str
        Directory containing {TICKER}_USD.parquet files.

    Returns
    -------
    pd.DataFrame
        Indexed by date, columns: real_tvl_usd, effective_tvl_usd. Days for
        which either token has no USD price are omitted; the frame is empty
        (same columns) when no day can be priced.

    Raises
    ------
    ValueError
        If the pool's table holds no rows.
    """
    # Table names cannot be bound as SQL parameters, so quote the identifier
    # (doubling embedded quotes) instead of interpolating it raw.
    table = '"' + str(pool.db_label).replace('"', '""') + '"'
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT timestamp, balance_0, balance_1, virtual_0, virtual_1 "
            f"FROM {table} ORDER BY timestamp"
        ).fetchall()
    finally:
        # Close even when the query fails (e.g. missing table).
        conn.close()

    if not rows:
        raise ValueError(f"No DB data for {pool.db_label}")

    df = pd.DataFrame(rows, columns=["timestamp", "bal_0", "bal_1", "virt_0", "virt_1"])
    df["date"] = pd.to_datetime(df["timestamp"], unit="s").dt.date

    # Keep last snapshot per day
    daily = df.sort_values("timestamp").drop_duplicates("date", keep="last").set_index("date")

    # DB columns *_0 / *_1 follow pool.tokens, swapped when pool.reverse is set.
    if pool.reverse:
        tickers_in_db_order = [pool.tokens[1], pool.tokens[0]]
    else:
        tickers_in_db_order = [pool.tokens[0], pool.tokens[1]]

    def _daily_usd_price(ticker):
        """USD price per snapshot date: 1.0 for USDC, NaN where no close exists."""
        if ticker == "USDC":
            return pd.Series(1.0, index=daily.index)
        pdf = pd.read_parquet(os.path.join(data_root, f"{ticker}_USD.parquet"))
        pdf["date"] = pd.to_datetime(pdf["unix"], unit="ms").dt.date
        # Daily close: last price per day
        closes = (
            pdf.sort_values("unix")
            .drop_duplicates("date", keep="last")
            .set_index("date")["close"]
        )
        return closes.reindex(daily.index)

    p0 = _daily_usd_price(tickers_in_db_order[0])
    p1 = _daily_usd_price(tickers_in_db_order[1])
    priced = p0.notna() & p1.notna()

    result = pd.DataFrame({
        "real_tvl_usd": daily["bal_0"] * p0 + daily["bal_1"] * p1,
        "effective_tvl_usd": (
            (daily["bal_0"] + daily["virt_0"]) * p0
            + (daily["bal_1"] + daily["virt_1"]) * p1
        ),
    })
    return result[priced]


# ---------------------------------------------------------------------------
# Daily volatility from price parquets
# ---------------------------------------------------------------------------

def compute_daily_volatility(tokens, data_root, start_ts, end_ts,
                             resample_rule="5min"):
    """Compute daily annualised volatility of the price ratio.

    The ratio ``tokens[0] / tokens[1]`` is subsampled to ``resample_rule``
    bins (last price per bin, empty bins dropped). Log returns between
    consecutive remaining bins are grouped by the calendar date of the later
    bin, and their standard deviation is scaled by
    ``sqrt(bins_per_day * 365)`` -- ``sqrt(288 * 365)`` for the default
    5-minute bins.

    Parameters
    ----------
    tokens : list
        Token tickers in quantammsim sorted order (e.g. ['BTC', 'ETH']).
        Only the first two are used; 'USDC' is treated as a constant $1.
    data_root : str
        Directory containing {TICKER}_USD.parquet files.
    start_ts : int
        Start unix timestamp (seconds).
    end_ts : int
        End unix timestamp (seconds).
    resample_rule : str
        Pandas offset alias for the subsampling interval.

    Returns
    -------
    pd.Series
        Indexed by date, values are annualised daily volatility.

    Raises
    ------
    ValueError
        If both tokens are USDC.
    """
    t0, t1 = tokens[0], tokens[1]
    if t0 == "USDC" and t1 == "USDC":
        raise ValueError("Both tokens are USDC — cannot compute ratio")

    def _load_close(ticker):
        """Close prices inside [start_ts, end_ts] indexed by datetime; None for USDC."""
        if ticker == "USDC":
            return None
        df = pd.read_parquet(os.path.join(data_root, f"{ticker}_USD.parquet"))
        in_range = df["unix"].between(start_ts * 1000, end_ts * 1000)
        # Copy before adding a column so the write never targets a view.
        df = df.loc[in_range].copy()
        df["datetime"] = pd.to_datetime(df["unix"], unit="ms")
        return df.set_index("datetime")["close"]

    price0, price1 = _load_close(t0), _load_close(t1)

    # Compute price ratio (token[0] / token[1])
    if price0 is not None and price1 is not None:
        # Align on common timestamps
        combined = pd.DataFrame({"p0": price0, "p1": price1}).dropna()
        ratio = combined["p0"] / combined["p1"]
    elif price0 is not None:
        ratio = price0  # t1 is USDC ($1)
    else:
        ratio = 1.0 / price1  # t0 is USDC

    ratio_sub = ratio.resample(resample_rule).last().dropna()
    log_returns = np.log(ratio_sub / ratio_sub.shift(1)).dropna()

    # Group by date, compute daily vol
    log_returns_df = log_returns.to_frame("lr")
    log_returns_df["date"] = log_returns_df.index.date
    daily_vol = log_returns_df.groupby("date")["lr"].std()

    # Annualise with the number of bins per day implied by the rule.
    periods_per_day = pd.Timedelta(days=1) / pd.Timedelta(resample_rule)
    return daily_vol * np.sqrt(periods_per_day * 365)
# ---------------------------------------------------------------------------
# OLS calibration
# ---------------------------------------------------------------------------

def _ols_hc1(X, y):
    """Least-squares fit of ``y ~ X``; return (beta, residuals, HC1 SE, R²)."""
    beta, _, _, _ = np.linalg.lstsq(X, y, rcond=None)
    residuals = y - X @ beta

    # Heteroskedasticity-robust standard errors (HC1)
    n, k = X.shape
    bread = np.linalg.inv(X.T @ X)
    hc1_scale = n / max(n - k, 1)
    # X' diag(w) X computed row-wise, without building the n x n diagonal.
    meat = (X * (residuals**2 * hc1_scale)[:, None]).T @ X
    robust_cov = bread @ meat @ bread
    se = np.sqrt(np.diag(robust_cov))

    ss_res = np.sum(residuals**2)
    ss_tot = np.sum((y - np.mean(y))**2)
    r_squared = 1.0 - ss_res / max(ss_tot, 1e-30)
    return beta, residuals, se, r_squared


def run_ols_calibration(daily_df, base_fee, model="sqrt"):
    """OLS regression for Tsoukalas model params.

    Parameters
    ----------
    daily_df : pd.DataFrame
        Must contain columns: volume_usd, volatility, effective_tvl_usd.
    base_fee : float
        Static swap fee (e.g. 0.003). Stored in the result, not fitted.
    model : str
        'sqrt'      : V/1e6 = a_0 + a_sigma*sigma + a_c*sqrt(TVL/1e6)
        'log'       : V/1e6 = a_0 + a_sigma*sigma + a_c*log(TVL/1e6)
        'loglinear' : log(V) = b_0 + b_sigma*sigma + b_c*log(TVL), fitted on
                      days with volume > 0 only.

    Returns
    -------
    noise_params : dict
        Coefficients for run_fingerprint["reclamm_noise_params"]
        (``a_*`` keys for 'sqrt'/'log', ``b_*`` keys for 'loglinear').
    diagnostics : dict
        HC1 standard errors, R², residual summary.

    Raises
    ------
    ValueError
        If ``model`` is not one of the three supported names.
    """
    if model == "loglinear":
        # Multiplicative model: log(V) = b_0 + b_sigma·σ + b_c·log(TVL)
        # Implies: V = exp(b_0) · TVL^b_c · exp(b_sigma·σ)
        mask = daily_df["volume_usd"].values > 0
        n_dropped = int((~mask).sum())
        df_fit = daily_df[mask]

        y_log = np.log(df_fit["volume_usd"].values)
        X = np.column_stack([
            np.ones(len(df_fit)),
            df_fit["volatility"].values,
            # Floor at $1, matching the floor applied when predictions are
            # plotted; keeps a zero-TVL day from injecting -inf into the fit.
            np.log(np.maximum(df_fit["effective_tvl_usd"].values, 1.0)),
        ])

        beta, residuals, se, r_squared = _ols_hc1(X, y_log)
        b_0, b_sigma, b_c = beta
        n = X.shape[0]

        # Pseudo-R² in levels (median predictor)
        y_pred_level = np.exp(X @ beta)
        y_actual_level = df_fit["volume_usd"].values
        res_level = y_actual_level - y_pred_level
        r_sq_level = 1.0 - np.sum(res_level**2) / max(
            np.sum((y_actual_level - y_actual_level.mean())**2), 1e-30)

        noise_params = {
            "b_0": float(b_0), "b_sigma": float(b_sigma),
            "b_c": float(b_c), "base_fee": float(base_fee),
        }
        diagnostics = {
            "se": {"b_0": float(se[0]), "b_sigma": float(se[1]),
                   "b_c": float(se[2])},
            "r_squared": float(r_squared),
            "r_squared_level": float(r_sq_level),
            "n_obs": int(n),
            "n_dropped_zero": n_dropped,
            "residual_mean": float(np.mean(residuals)),
            "residual_std": float(np.std(residuals)),
            "smearing_factor": float(np.exp(np.var(residuals, ddof=1) / 2)),
            "model": "loglinear",
        }
        return noise_params, diagnostics

    # --- Linear models (sqrt / log) ---
    y = daily_df["volume_usd"].values / 1e6

    if model == "sqrt":
        tvl_eff = np.sqrt(daily_df["effective_tvl_usd"].values / 1e6)
    elif model == "log":
        tvl_eff = np.log(np.maximum(daily_df["effective_tvl_usd"].values / 1e6, 1e-30))
    else:
        raise ValueError(f"Unknown model: {model!r}. Use 'sqrt', 'log', or 'loglinear'.")

    X = np.column_stack([
        np.ones(len(daily_df)),           # a_0
        daily_df["volatility"].values,    # a_sigma
        tvl_eff,                          # a_c
    ])

    beta, residuals, se, r_squared = _ols_hc1(X, y)
    a_0, a_sigma, a_c = beta

    noise_params = {
        "a_0_base": float(a_0),
        "a_f": 0.0,  # not identified with static fees
        "a_sigma": float(a_sigma),
        "a_c": float(a_c),
        "base_fee": float(base_fee),
    }

    diagnostics = {
        "se": dict(zip(
            ["a_0", "a_sigma", "a_c"],
            se.tolist(),
        )),
        "r_squared": float(r_squared),
        "n_obs": int(X.shape[0]),
        "residual_mean": float(np.mean(residuals)),
        "residual_std": float(np.std(residuals)),
        "model": model,
    }

    return noise_params, diagnostics
+ noise_params : dict + Fitted coefficients from run_ols_calibration. + diagnostics : dict + Diagnostics dict from run_ols_calibration. + pool_label : str + Pool name for titles. + model : str + 'sqrt' or 'log'. + output_dir : str + Directory for output PNGs. + """ + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + from matplotlib.dates import DateFormatter + import matplotlib.dates as mdates + + os.makedirs(output_dir, exist_ok=True) + + # Extract data + dates = pd.to_datetime(daily_df.index) + y_actual = daily_df["volume_usd"].values / 1e6 # in $M + eff_tvl = daily_df["effective_tvl_usd"].values + vol = daily_df["volatility"].values + + r2 = diagnostics["r_squared"] + n = diagnostics["n_obs"] + + if model == "loglinear": + b_0 = noise_params["b_0"] + b_sigma = noise_params["b_sigma"] + b_c = noise_params["b_c"] + log_tvl = np.log(np.maximum(eff_tvl, 1.0)) + y_pred_log = b_0 + b_sigma * vol + b_c * log_tvl + y_pred = np.exp(y_pred_log) / 1e6 # median prediction in $M + + # Log-space residuals + mask_pos = daily_df["volume_usd"].values > 0 + residuals = np.full(len(dates), np.nan) + residuals[mask_pos] = ( + np.log(daily_df["volume_usd"].values[mask_pos]) + - y_pred_log[mask_pos] + ) + resid_unit = "log scale" + r2_level = diagnostics.get("r_squared_level") + else: + a_0 = noise_params["a_0_base"] + a_sigma = noise_params["a_sigma"] + a_c = noise_params["a_c"] + + if model == "sqrt": + tvl_term = a_c * np.sqrt(eff_tvl / 1e6) + else: + tvl_term = a_c * np.log(np.maximum(eff_tvl / 1e6, 1e-30)) + + y_pred = a_0 + a_sigma * vol + tvl_term + residuals = y_actual - y_pred + resid_unit = "$M" + r2_level = None + + # --- Figure 1: Main diagnostics (2×2) --- + fig, axes = plt.subplots(2, 2, figsize=(16, 12)) + + # (0,0) Time series: real vs predicted + TVL on secondary axis + ax = axes[0, 0] + ax.plot(dates, y_actual, color="steelblue", alpha=0.7, linewidth=1, + label="Actual volume") + ax.plot(dates, y_pred, color="crimson", linewidth=1.5, + 
label="Predicted volume") + ax.set_ylabel("Daily volume ($M)", color="steelblue") + ax.tick_params(axis="y", labelcolor="steelblue") + ax.legend(loc="upper left", fontsize=8) + r2_str = (f"R²(log)={r2:.3f}, R²(level)={r2_level:.3f}" + if r2_level is not None else f"R²={r2:.3f}") + ax.set_title(f"Daily volume: actual vs predicted ({r2_str}, n={n})") + ax.xaxis.set_major_formatter(DateFormatter("%b %y")) + + ax2 = ax.twinx() + ax2.fill_between(dates, eff_tvl / 1e6, alpha=0.15, color="green", + label="Effective TVL ($M)") + ax2.set_ylabel("Effective TVL ($M)", color="green") + ax2.tick_params(axis="y", labelcolor="green") + ax2.legend(loc="upper right", fontsize=8) + + # (0,1) Scatter: predicted vs actual + ax = axes[0, 1] + ax.scatter(y_pred, y_actual, alpha=0.5, s=15, color="steelblue", + edgecolors="none") + lims = [min(y_pred.min(), y_actual.min()), max(y_pred.max(), y_actual.max())] + margin = (lims[1] - lims[0]) * 0.05 + lims = [lims[0] - margin, lims[1] + margin] + ax.plot(lims, lims, "k--", linewidth=0.8, alpha=0.5, label="45° line") + ax.set_xlabel("Predicted ($M)") + ax.set_ylabel("Actual ($M)") + ax.set_title("Predicted vs actual") + ax.legend(fontsize=8) + ax.set_aspect("equal", adjustable="box") + + # (1,0) Residuals: vs time (top) and vs fitted (bottom) + ax = axes[1, 0] + ax.scatter(dates, residuals, alpha=0.5, s=12, color="steelblue", + edgecolors="none") + ax.axhline(0, color="black", linewidth=0.8) + # 7-day rolling mean of residuals + res_series = pd.Series(residuals, index=dates) + rolling_mean = res_series.rolling(7, min_periods=1).mean() + ax.plot(dates, rolling_mean, color="crimson", linewidth=1.5, + label="7-day rolling mean") + ax.set_ylabel(f"Residual ({resid_unit})") + ax.set_title("Residuals vs time") + ax.legend(fontsize=8) + ax.xaxis.set_major_formatter(DateFormatter("%b %y")) + + # (1,1) Component decomposition + ax = axes[1, 1] + if model == "loglinear": + # Log-space additive decomposition as line plots + log_tvl_plot = 
np.log(np.maximum(eff_tvl, 1.0)) + comp_base = np.full(len(dates), b_0) + comp_tvl = b_c * log_tvl_plot + comp_vol = b_sigma * vol + total_log = comp_base + comp_tvl + comp_vol + vol_usd = daily_df["volume_usd"].values.copy() + vol_usd[vol_usd <= 0] = np.nan + actual_log = np.log(vol_usd) + ax.plot(dates, comp_base, color="grey", linestyle="--", linewidth=1, + label=f"b_0 = {b_0:.2f}") + ax.plot(dates, comp_base + comp_tvl, color="green", linewidth=1.5, + label=f"b_0 + b_c·log(TVL) (b_c={b_c:.4f})") + ax.plot(dates, total_log, color="crimson", linewidth=1.5, + label="Full prediction") + ax.scatter(dates, actual_log, color="steelblue", s=10, alpha=0.5, + label="Actual log(V)", zorder=5) + ax.set_ylabel("log(Volume, USD)") + ax.set_title("Component decomposition (log space)") + + fig.suptitle( + f"{pool_label} — noise calibration ({model})\n" + f"log(V) = {b_0:.2f} + {b_sigma:.4f}·σ + {b_c:.4f}·log(TVL)", + fontsize=11, + ) + else: + intercept_contrib = np.full(len(dates), a_0) + vol_contrib = a_sigma * vol + tvl_contrib = tvl_term + + ax.fill_between(dates, 0, intercept_contrib, alpha=0.3, color="grey", + label=f"a_0 = {a_0:.4f}") + ax.fill_between(dates, intercept_contrib, intercept_contrib + vol_contrib, + alpha=0.3, color="orange", + label=f"a_σ·σ (a_σ={a_sigma:.4f})") + ax.fill_between(dates, intercept_contrib + vol_contrib, + intercept_contrib + vol_contrib + tvl_contrib, + alpha=0.3, color="green", + label=f"a_c·{model}(TVL) (a_c={a_c:.4f})") + ax.plot(dates, y_actual, color="steelblue", linewidth=1, alpha=0.7, + label="Actual") + ax.set_ylabel("Volume ($M)") + ax.set_title("Component decomposition") + + fig.suptitle( + f"{pool_label} — Tsoukalas noise calibration ({model})\n" + f"V/1e6 = {a_0:.4f} + {a_sigma:.4f}·σ + {a_c:.4f}·{model}(TVL_eff/1e6)", + fontsize=11, + ) + ax.legend(fontsize=7, loc="upper left") + ax.xaxis.set_major_formatter(DateFormatter("%b %y")) + plt.tight_layout() + + fname = f"noise_calibration_{pool_label}_{model}.png" + path = 
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def main():
    """Command-line entry point for the noise-volume calibration.

    Builds the calibration frame either end-to-end for a registry pool
    (``--pool``) or from a pre-assembled CSV (``--csv``), fits the requested
    model with ``run_ols_calibration``, prints a coefficient summary,
    optionally writes diagnostic plots, and emits the result as JSON to
    ``--output`` or stdout. ``--pool`` takes precedence when both are given.
    Exits with status 1 on an unknown pool or a CSV missing required columns.
    """
    parser = argparse.ArgumentParser(
        description="Calibrate Tsoukalas noise volume model for reClAMM"
    )
    parser.add_argument(
        "--csv", default=None,
        help="Path to CSV with columns: volume_usd, volatility, effective_tvl_usd",
    )
    parser.add_argument(
        "--pool", default=None,
        help="Pool label from pool_registry (e.g. cbBTC_WETH) for end-to-end calibration",
    )
    parser.add_argument("--base-fee", type=float, default=None,
                        help="Override base fee (default: use pool's swap_fee)")
    parser.add_argument("--model", choices=["sqrt", "log", "loglinear"],
                        default="sqrt")
    parser.add_argument(
        "--output", default=None,
        help="Output JSON file path. Defaults to stdout.",
    )
    parser.add_argument(
        "--plot", action="store_true",
        help="Generate diagnostic plots (saved to --output-dir)",
    )
    parser.add_argument(
        "--output-dir", default="results",
        help="Directory for diagnostic plots (default: results)",
    )
    args = parser.parse_args()

    if args.csv is None and args.pool is None:
        parser.error("One of --csv or --pool is required")

    if args.pool is not None:
        # End-to-end mode: fetch data, assemble, calibrate
        from experiments.pool_registry import POOL_REGISTRY

        if args.pool not in POOL_REGISTRY:
            print(f"Unknown pool: {args.pool}", file=sys.stderr)
            print(f"Available: {list(POOL_REGISTRY.keys())}", file=sys.stderr)
            sys.exit(1)

        pool = POOL_REGISTRY[args.pool]
        base_fee = args.base_fee if args.base_fee is not None else pool.swap_fee

        print(f"Calibrating noise model for {pool.label} ({pool.chain})")
        print(f" Swap fee: {base_fee}")
        print(f" Model: {args.model}")

        df = build_calibration_df(pool)
    else:
        # CSV mode
        df = pd.read_csv(args.csv)
        required_cols = {"volume_usd", "volatility", "effective_tvl_usd"}
        missing = required_cols - set(df.columns)
        if missing:
            print(f"Error: missing columns: {missing}", file=sys.stderr)
            sys.exit(1)

        # Restore a date index when the CSV carries one, so the diagnostic
        # plots get a real time axis instead of row numbers read as epochs.
        if "date" in df.columns:
            parsed_dates = pd.to_datetime(df["date"], errors="coerce")
            if parsed_dates.notna().all():
                df["date"] = parsed_dates
                df = df.set_index("date")

        # Mirror the end-to-end path, which drops incomplete days before fitting.
        n_before = len(df)
        df = df.dropna(subset=sorted(required_cols))
        if len(df) < n_before:
            print(f"Dropped {n_before - len(df)} rows with missing values",
                  file=sys.stderr)

        base_fee = args.base_fee if args.base_fee is not None else 0.003

    noise_params, diagnostics = run_ols_calibration(df, base_fee, args.model)

    # Print diagnostics
    print(f"\n OLS Results ({args.model} model):")
    print(f" R² = {diagnostics['r_squared']:.4f}")
    if "r_squared_level" in diagnostics:
        print(f" R²(level) = {diagnostics['r_squared_level']:.4f}")
    if "n_dropped_zero" in diagnostics and diagnostics["n_dropped_zero"] > 0:
        print(f" Dropped {diagnostics['n_dropped_zero']} zero-volume days")
    if "smearing_factor" in diagnostics:
        print(f" Smearing factor = {diagnostics['smearing_factor']:.4f} "
              f"(E[V]/median[V])")
    print(f" n = {diagnostics['n_obs']}")
    print(f" Coefficients:")
    if args.model == "loglinear":
        coef_keys = ["b_0", "b_sigma", "b_c"]
    else:
        coef_keys = ["a_0", "a_sigma", "a_c"]
    for key in coef_keys:
        # The intercept is stored as "a_0_base" but its SE is keyed "a_0".
        param_key = "a_0_base" if key == "a_0" else key
        val = noise_params[param_key]
        se = diagnostics["se"][key]
        t_stat = val / se if se > 0 else float("inf")
        print(f" {key:>8} = {val:>10.4f} (SE={se:.4f}, t={t_stat:.2f})")
    print(f" Residual: mean={diagnostics['residual_mean']:.6f}, "
          f"std={diagnostics['residual_std']:.4f}")

    # Plot diagnostics
    if args.plot:
        label = args.pool if args.pool else "custom"
        plot_calibration_diagnostics(
            df, noise_params, diagnostics,
            pool_label=label, model=args.model,
            output_dir=args.output_dir,
        )

    result = {
        "noise_params": noise_params,
        "diagnostics": diagnostics,
    }

    output_str = json.dumps(result, indent=2)
    if args.output:
        with open(args.output, "w") as f:
            f.write(output_str + "\n")
        print(f"\nWrote calibration to {args.output}", file=sys.stderr)
    else:
        print(f"\n{output_str}")


if __name__ == "__main__":
    main()
def to_daily_price_shift_base(daily_price_shift_exponent):
    """Convert shift rate to daily price shift base (matches Solidity)."""
    per_step_shift = daily_price_shift_exponent / 124649.0
    return 1.0 - per_step_shift


# Pool configurations to compare
CONFIGS = [
    {
        "name": "AAVE/ETH on-chain (25bps, narrow range)",
        "tokens": ["AAVE", "ETH"],
        "start": "2024-06-01 00:00:00",
        "end": "2025-06-01 00:00:00",
        "fees": 0.0025,
        "price_ratio": 1.5,
        "centeredness_margin": 0.5,
        "daily_price_shift_exponent": 0.1,
    },
    {
        "name": "AAVE/ETH wide range (25bps)",
        "tokens": ["AAVE", "ETH"],
        "start": "2024-06-01 00:00:00",
        "end": "2025-06-01 00:00:00",
        "fees": 0.0025,
        "price_ratio": 4.0,
        "centeredness_margin": 0.2,
        "daily_price_shift_exponent": 1.0,
    },
    {
        "name": "AAVE/ETH zero fees (narrow)",
        "tokens": ["AAVE", "ETH"],
        "start": "2024-06-01 00:00:00",
        "end": "2025-06-01 00:00:00",
        "fees": 0.0,
        "price_ratio": 1.5,
        "centeredness_margin": 0.5,
        "daily_price_shift_exponent": 0.1,
    },
]


def make_fingerprint(cfg, interpolation_method, centeredness_scaling=False):
    """Return the run fingerprint for ``cfg`` under one thermostat variant."""
    fingerprint = dict(
        tokens=cfg["tokens"],
        rule="reclamm",
        startDateString=cfg["start"],
        endDateString=cfg["end"],
        initial_pool_value=1000000.0,
        do_arb=True,
        fees=cfg["fees"],
        gas_cost=0.0,
        arb_fees=0.0,
    )
    fingerprint["reclamm_interpolation_method"] = interpolation_method
    fingerprint["reclamm_arc_length_speed"] = None  # auto-calibrate
    fingerprint["reclamm_centeredness_scaling"] = centeredness_scaling
    return fingerprint


def make_params(cfg):
    """Return the pool parameter dict (jnp scalars) described by ``cfg``."""
    shift_base = to_daily_price_shift_base(cfg["daily_price_shift_exponent"])
    return {
        "price_ratio": jnp.array(cfg["price_ratio"]),
        "centeredness_margin": jnp.array(cfg["centeredness_margin"]),
        "daily_price_shift_base": jnp.array(shift_base),
    }


def run_comparison(cfg):
    """Simulate every thermostat variant for ``cfg``; return results by variant key."""
    pool_params = make_params(cfg)

    # (result key, interpolation method, centeredness scaling). Listed in the
    # order the simulations run, which is also the result dict's key order.
    variants = (
        ("geometric", "geometric", False),
        ("constant_arc_length", "constant_arc_length", False),
        # Geometric + centeredness-proportional scaling (scales decay duration)
        ("geometric_scaled", "geometric", True),
        # Arc-length + centeredness-proportional scaling (scales speed)
        ("cal_scaled", "constant_arc_length", True),
    )

    outcomes = {}
    for key, method, scaled in variants:
        fingerprint = make_fingerprint(cfg, method, centeredness_scaling=scaled)
        outcomes[key] = do_run_on_historic_data(
            run_fingerprint=fingerprint, params=pool_params
        )
    return outcomes
float(r["final_value"]) + row += f" ${lvr:>13,.0f}" + print(row) + + row = " {:20s}".format("Return") + for _, r in methods: + ret = (float(r["final_value"]) / float(r["value"][0]) - 1) * 100 + row += f" {ret:>13.2f}%" + print(row) + + row = " {:20s}".format("vs HODL") + for _, r in methods: + vs = (float(r["final_value"]) / hodl_value - 1) * 100 + row += f" {vs:>13.2f}%" + print(row) + print("=" * 105) + + +def plot_comparison(cfg, results, fig_idx): + """Plot 4-panel comparison for one config.""" + # Method name → (result dict, color, linestyle) + variants = { + "Geometric": (results["geometric"], "C0", "-"), + "Geo+Scaled": (results["geometric_scaled"], "C1", "-"), + "Const arc-len": (results["constant_arc_length"], "C2", "--"), + "Arc+Scaled": (results["cal_scaled"], "C3", "--"), + } + + geo = results["geometric"] + geo_prices = np.array(geo["prices"]) + geo_reserves = np.array(geo["reserves"]) + n_steps = len(np.array(geo["value"])) + t_days = np.arange(n_steps) / (60 * 24) + + hodl_traj = (geo_reserves[0] * geo_prices[:n_steps]).sum(axis=-1) + price_ratio_traj = geo_prices[:n_steps, 0] / geo_prices[:n_steps, 1] + + fig, axes = plt.subplots(2, 2, figsize=(14, 10)) + fig.suptitle(cfg["name"], fontsize=13, fontweight="bold") + + # (0,0) Pool value over time + ax = axes[0, 0] + for name, (r, color, ls) in variants.items(): + vals = np.array(r["value"]) + ax.plot(t_days, vals / 1e6, color=color, ls=ls, label=name, alpha=0.9) + ax.plot(t_days, np.array(hodl_traj) / 1e6, color="gray", ls=":", + alpha=0.5, label="HODL") + ax.set_xlabel("Days") + ax.set_ylabel("Pool value ($M)") + ax.set_title("Pool value") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # (0,1) Cumulative LVR + ax = axes[0, 1] + for name, (r, color, ls) in variants.items(): + vals = np.array(r["value"]) + lvr = np.array(hodl_traj) - vals + ax.plot(t_days, lvr / 1e3, color=color, ls=ls, label=name, alpha=0.9) + ax.set_xlabel("Days") + ax.set_ylabel("Cumulative LVR ($K)") + 
ax.set_title("Cumulative LVR (HODL - pool value)") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # (1,0) Price ratio + ax = axes[1, 0] + ax.plot(t_days, price_ratio_traj, color="C4", alpha=0.7) + ax.set_xlabel("Days") + ax.set_ylabel(f"{cfg['tokens'][0]}/{cfg['tokens'][1]} price ratio") + ax.set_title("Price path") + ax.grid(True, alpha=0.3) + + # (1,1) Empirical weights + ax = axes[1, 1] + for name, (r, color, ls) in variants.items(): + w = np.array(r["weights"]) + n_w = min(len(w), n_steps) + t_w = np.arange(n_w) / (60 * 24) + ax.plot(t_w, w[:n_w, 0], color=color, ls=ls, label=name, alpha=0.9) + ax.set_xlabel("Days") + ax.set_ylabel(f"Weight ({cfg['tokens'][0]})") + ax.set_title("Empirical weight (token 0)") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + plt.tight_layout() + fname = f"reclamm_thermostat_comparison_{fig_idx}.png" + plt.savefig(fname, dpi=150) + print(f"Saved {fname}") + plt.close(fig) + + # Second figure: diagnostics + geo_values = np.array(geo["value"]) + geo_lvr = np.array(hodl_traj) - geo_values + + fig2, axes2 = plt.subplots(1, 3, figsize=(18, 5)) + fig2.suptitle(f"{cfg['name']} — diagnostics", fontsize=13, fontweight="bold") + + # (left) Value difference vs geometric + ax = axes2[0] + for name, (r, color, ls) in variants.items(): + if name == "Geometric": + continue + vals = np.array(r["value"]) + ax.plot(t_days, (vals - geo_values) / 1e3, color=color, ls=ls, + label=name, alpha=0.9) + ax.axhline(0, color="gray", ls="--", alpha=0.5) + ax.set_xlabel("Days") + ax.set_ylabel("Value difference ($K)") + ax.set_title("Minus Geometric") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # (middle) LVR ratio over time + ax = axes2[1] + mask = np.abs(geo_lvr) > 100 + if mask.any(): + for name, (r, color, ls) in variants.items(): + if name == "Geometric": + continue + vals = np.array(r["value"]) + method_lvr = np.array(hodl_traj) - vals + ratio = np.full_like(geo_lvr, np.nan) + ratio[mask] = method_lvr[mask] / geo_lvr[mask] + 
ax.plot(t_days, ratio, color=color, ls=ls, alpha=0.7, label=name) + ax.axhline(1.0, color="gray", ls="--", alpha=0.5) + ax.set_ylabel("LVR ratio (method / geometric)") + ax.legend(fontsize=8) + else: + ax.text(0.5, 0.5, "LVR too small to compare", + transform=ax.transAxes, ha="center", va="center") + ax.set_xlabel("Days") + ax.set_title("Relative LVR") + ax.grid(True, alpha=0.3) + + # (right) Per-step LVR histogram + ax = axes2[2] + all_pos = [] + for name, (r, color, ls) in variants.items(): + vals = np.array(r["value"]) + method_lvr = np.array(hodl_traj) - vals + step_lvr = np.diff(method_lvr) + pos = step_lvr[step_lvr > 0] + all_pos.append((name, pos, color)) + has_data = [len(p) > 10 for _, p, _ in all_pos] + if any(has_data): + max_val = max(np.percentile(p, 99) for _, p, _ in all_pos if len(p) > 10) + bins = np.linspace(0, max_val, 50) + for name, pos, color in all_pos: + if len(pos) > 10: + ax.hist(pos, bins=bins, color=color, alpha=0.3, label=name, + density=True) + ax.set_xlabel("Per-step LVR ($)") + ax.set_ylabel("Density") + ax.legend(fontsize=8) + else: + ax.text(0.5, 0.5, "Too few thermostat steps", + transform=ax.transAxes, ha="center", va="center") + ax.set_title("Per-step LVR distribution") + ax.grid(True, alpha=0.3) + + plt.tight_layout() + fname2 = f"reclamm_thermostat_diff_{fig_idx}.png" + plt.savefig(fname2, dpi=150) + print(f"Saved {fname2}") + plt.close(fig2) + + +if __name__ == "__main__": + all_results = [] + for i, cfg in enumerate(CONFIGS): + print(f"\n>>> Running {cfg['name']}...") + try: + results = run_comparison(cfg) + print_comparison(cfg, results) + plot_comparison(cfg, results, i) + all_results.append((cfg, results)) + except Exception as e: + print(f" FAILED: {e}") + import traceback + traceback.print_exc() + + # Summary overlay: all configs on one figure (pool value normalised) + if len(all_results) > 1: + fig, axes = plt.subplots(1, 2, figsize=(16, 5)) + fig.suptitle("Cross-config comparison (normalised)", fontsize=13, + 
fontweight="bold") + + method_keys = [ + ("geometric", "geo", "-"), + ("geometric_scaled", "geo+s", "-."), + ("constant_arc_length", "arc", "--"), + ("cal_scaled", "arc+s", ":"), + ] + + for i, (cfg, results) in enumerate(all_results): + geo_v = np.array(results["geometric"]["value"]) + t = np.arange(len(geo_v)) / (60 * 24) + short_name = cfg["name"].split("(")[0].strip() + + for j, (key, suffix, ls) in enumerate(method_keys): + v = np.array(results[key]["value"]) + color_idx = i * len(method_keys) + j + + # (left) Normalised pool value + axes[0].plot(t, v / v[0], ls=ls, alpha=0.8, + label=f"{short_name} {suffix}", + color=f"C{color_idx % 10}") + + # (right) Value difference vs geometric (skip geo itself) + if key != "geometric": + pct_diff = (v - geo_v) / geo_v * 100 + axes[1].plot(t, pct_diff, ls=ls, alpha=0.8, + label=f"{short_name} {suffix}", + color=f"C{color_idx % 10}") + + axes[0].set_xlabel("Days") + axes[0].set_ylabel("Normalised pool value") + axes[0].set_title("Pool value (V/V0)") + axes[0].legend(fontsize=6, ncol=2) + axes[0].grid(True, alpha=0.3) + + axes[1].set_xlabel("Days") + axes[1].set_ylabel("(Method - Geo) / Geo (%)") + axes[1].set_title("Relative value difference vs Geometric") + axes[1].axhline(0, color="gray", ls="--", alpha=0.5) + axes[1].legend(fontsize=6, ncol=2) + axes[1].grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig("reclamm_thermostat_summary.png", dpi=150) + print("\nSaved reclamm_thermostat_summary.png") + plt.close(fig) diff --git a/scripts/demo_run_chunks_from_chain_data.py b/scripts/demo_run_chunks_from_chain_data.py index 09c19d1..fbd9cf9 100644 --- a/scripts/demo_run_chunks_from_chain_data.py +++ b/scripts/demo_run_chunks_from_chain_data.py @@ -28,7 +28,6 @@ import numpy as np import pandas as pd import matplotlib as mpl -from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames import matplotlib.pyplot as plt import jax.numpy as jnp @@ -475,12 +474,10 @@ def _df_meta_and_head(df, name, n=3): 
run_fingerprint=fingerprint, coarse_weights=cw_window, params=params, - dynamic_input_frames=DynamicInputFrames( - fees=scraped["fees_df"], - gas_cost=scraped["gas_cost_df"], - lp_supply=scraped["lp_supply_df"], - arb_fees=scraped["arb_fees_df"], - ), + fees_df=scraped["fees_df"], + gas_cost_df=scraped["gas_cost_df"], + lp_supply_df=scraped["lp_supply_df"], + arb_fees_df=scraped["arb_fees_df"], ) # ---------------- Correct, window-aligned plotting block (time-aware + plain y) ---------------- diff --git a/scripts/demo_run_from_chain_data.py b/scripts/demo_run_from_chain_data.py index b1f4139..78ecbc4 100644 --- a/scripts/demo_run_from_chain_data.py +++ b/scripts/demo_run_from_chain_data.py @@ -1,5 +1,4 @@ import jax.numpy as jnp -from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames from quantammsim.core_simulator.param_utils import ( memory_days_to_logit_lamb, ) @@ -998,12 +997,10 @@ def generate_daily_variations(start_date_str, end_date_str): run_fingerprint=config["fingerprint"], coarse_weights=config["coarse_weights"], params=config["params"], - dynamic_input_frames=DynamicInputFrames( - fees=config["fees_df"], - gas_cost=config["gas_cost_df"], - lp_supply=config["lp_supply_df"], - arb_fees=config["arb_fees_df"], - ), + fees_df=config["fees_df"], + gas_cost_df=config["gas_cost_df"], + lp_supply_df=config["lp_supply_df"], + arb_fees_df=config["arb_fees_df"], ) print("-" * 80) print(f"Pool Type: {config['fingerprint']['rule']}") @@ -1194,4 +1191,4 @@ def generate_daily_variations(start_date_str, end_date_str): # actual_reserves_np=local_reserves, # actual_unix_values=datetime_array, # ) - # raise Exception("Stop here") + # raise Exception("Stop here") \ No newline at end of file diff --git a/scripts/demo_run_reclamm.py b/scripts/demo_run_reclamm.py new file mode 100644 index 0000000..3ea21ec --- /dev/null +++ b/scripts/demo_run_reclamm.py @@ -0,0 +1,207 @@ +"""Demo runs for reClAMM pools vs Balancer 50/50 baseline. 
def to_daily_price_shift_base(daily_price_shift_exponent, adjustment=124649.0):
    """Convert a daily price-shift exponent into the daily price shift base.

    Computes ``1 - daily_price_shift_exponent / adjustment``.

    Parameters
    ----------
    daily_price_shift_exponent : float or array-like
        Shift rate to convert.
    adjustment : float, optional
        Divisor applied to the exponent. The default, 124649.0, is the value
        this script has always used; the original docstring states that it
        matches the Solidity implementation.

    Returns
    -------
    float or array-like
        The daily price shift base; 1.0 when the exponent is 0.
    """
    return 1.0 - daily_price_shift_exponent / adjustment
def run_scenario(scenario):
    """Simulate one scenario and print a reClAMM vs Balancer 50/50 table.

    The reClAMM pool is run first with the scenario's fingerprint and params.
    A Balancer pool with equal initial weight logits is then run on the same
    tokens, date range and fee level. The HODL reference values the reClAMM
    run's first reserves at its last prices.
    """

    def _pct(numerator, denominator):
        # Relative change of numerator over denominator, in percent.
        return (numerator / denominator - 1) * 100

    reclamm_cfg = scenario["reclamm"]
    fp = reclamm_cfg["fingerprint"]

    # reClAMM simulation.
    reclamm_result = do_run_on_historic_data(
        run_fingerprint=fp, params=reclamm_cfg["params"]
    )

    # Balancer 50/50 baseline. The fingerprint fields are read only after the
    # reClAMM run, exactly as before, in case the runner touches `fp`.
    bal_fp = balancer_fingerprint(
        tokens=fp["tokens"],
        start=fp["startDateString"],
        end=fp["endDateString"],
        fees=fp["fees"],
    )
    balancer_result = do_run_on_historic_data(
        run_fingerprint=bal_fp,
        params={"initial_weights_logits": jnp.zeros(len(fp["tokens"]))},
    )

    # HODL: reClAMM initial reserves valued at the final prices.
    initial_reserves = reclamm_result["reserves"][0]
    final_prices = reclamm_result["prices"][-1]
    hodl_value = float((initial_reserves * final_prices).sum())

    rc_final = float(reclamm_result["final_value"])
    bal_final = float(balancer_result["final_value"])
    rc_init = float(reclamm_result["value"][0])
    bal_init = float(balancer_result["value"][0])

    heavy_rule = "=" * 80
    light_rule = "-" * 80

    print(heavy_rule)
    print(f" {scenario['name']}")
    print(f" Tokens: {', '.join(fp['tokens'])} | Fees: {fp['fees']}")
    print(light_rule)
    print(f" {'':30s} {'reClAMM':>14s} {'Balancer 50/50':>14s}")
    print(f" {'Initial value':30s} ${rc_init:>13,.0f} ${bal_init:>13,.0f}")
    print(f" {'Final value':30s} ${rc_final:>13,.0f} ${bal_final:>13,.0f}")
    # Percentages are computed inside each print call so that a failing
    # division surfaces at the same point in the output as before.
    print(
        f" {'Return':30s} "
        f"{_pct(rc_final, rc_init):>13.2f}% "
        f"{_pct(bal_final, bal_init):>13.2f}%"
    )
    print(
        f" {'vs HODL':30s} "
        f"{_pct(rc_final, hodl_value):>13.2f}% "
        f"{_pct(bal_final, hodl_value):>13.2f}%"
    )
    print(
        f" {'reClAMM vs Balancer':30s} "
        f"{_pct(rc_final, bal_final):>13.2f}%"
    )
    print(heavy_rule)
def main():
    """Plot predicted vs. actual daily volume for registry pools.

    Loads the fitted noise model from ``results/unified_full_90d.json`` and
    the daily panel from ``local_data/noise_calibration/panel.parquet``.
    Registry pools are matched to fitted pool ids by address substring, and
    one subplot per matched pool compares ``exp(x_obs @ theta)`` with the
    observed volume over the last 90 days of the panel. The figure is saved
    to ``results/unified_full_90d/registry_predicted_vs_real.png``.

    Returns early, after printing a message, when no registry pool is present
    in the fitted model.
    """
    fitted_path = "results/unified_full_90d.json"
    panel_path = "local_data/noise_calibration/panel.parquet"
    output_dir = "results/unified_full_90d"
    os.makedirs(output_dir, exist_ok=True)

    with open(fitted_path, encoding="utf-8") as f:
        fitted = json.load(f)

    panel = pd.read_parquet(panel_path)

    # Deduplicate registry: multiple entries can share the same pool address
    # (e.g. cbBTC_WETH and cbBTC_WETH_post_oct). Keep the first label seen
    # for each address.
    unique_pools = {}
    for label, pool in POOL_REGISTRY.items():
        addr = pool.pool_address.lower()
        if addr not in unique_pools:
            unique_pools[addr] = (label, pool)

    # Match registry addresses to pool ids present in the fitted model.
    matched = []
    for addr, (label, pool) in unique_pools.items():
        pid_matches = [
            pid for pid in fitted["pools"]
            if addr in pid.lower()
        ]
        if pid_matches:
            matched.append((label, pool, pid_matches[0]))
        else:
            print(f" {label}: not in fitted model (skipping)")

    if not matched:
        print("No registry pools found in the fitted model.")
        return

    print(f"Plotting {len(matched)} pools: {[m[0] for m in matched]}")

    # The 90-day window is anchored on the newest date in the whole panel, so
    # it is computed once. Going through pd.to_datetime keeps the comparison
    # well-typed whether the column holds datetime64 values, datetime.date
    # objects or ISO date strings; the previous per-row isinstance dance
    # compared a date against a Timestamp on the string branch.
    cutoff = pd.to_datetime(panel["date"]).max() - pd.Timedelta(days=90)

    # Determine grid layout
    n = len(matched)
    ncols = min(n, 2)
    nrows = (n + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows, ncols, figsize=(7 * ncols, 5 * nrows),
                             squeeze=False)

    for idx, (label, _pool, pid) in enumerate(matched):
        ax = axes[idx // ncols][idx % ncols]
        pool_data = fitted["pools"][pid]
        theta = np.array(pool_data["theta_median"])
        # theta = [intercept, b_tvl, b_sigma, b_weekend]

        # Get panel data for this pool
        pool_panel = panel[panel["pool_id"] == pid].sort_values("date")

        if len(pool_panel) == 0:
            ax.set_title(f"{label}: no panel data")
            continue

        # Filter to last 90 days (matching training window)
        in_window = pd.to_datetime(pool_panel["date"]) >= cutoff
        pool_panel = pool_panel[in_window].copy()

        if len(pool_panel) < 5:
            ax.set_title(f"{label}: <5 obs in 90d window")
            continue

        # Build x_obs: [1, log_tvl_lag1, volatility, weekend]
        x_obs = np.column_stack([
            np.ones(len(pool_panel)),
            pool_panel["log_tvl_lag1"].values,
            pool_panel["volatility"].values,
            pool_panel["weekend"].values,
        ])

        predicted_log_vol = x_obs @ theta
        actual_log_vol = pool_panel["log_volume"].values

        # Convert to USD volume for interpretability
        predicted_vol = np.exp(predicted_log_vol)
        actual_vol = np.exp(actual_log_vol)

        dates = pd.to_datetime(pool_panel["date"].values)

        # Plot
        ax.plot(dates, actual_vol, "o-", color="steelblue", markersize=3,
                linewidth=1, alpha=0.7, label="Actual")
        ax.plot(dates, predicted_vol, "s--", color="orangered", markersize=3,
                linewidth=1, alpha=0.7, label="Predicted")
        ax.set_yscale("log")
        ax.set_ylabel("Daily volume (USD)")
        ax.set_title(f"{label} ({pool_data['chain']})\n"
                     f"b_c={theta[1]:.2f} b_σ={theta[2]:.2f} "
                     f"b_wknd={theta[3]:.2f}")
        ax.legend(fontsize=8)
        ax.tick_params(axis="x", rotation=30)

        # Annotate R² (computed in log space) for this pool
        ss_res = np.sum((actual_log_vol - predicted_log_vol) ** 2)
        ss_tot = np.sum((actual_log_vol - actual_log_vol.mean()) ** 2)
        r2 = 1 - ss_res / ss_tot if ss_tot > 0 else float("nan")
        ax.text(0.02, 0.95, f"R²={r2:.3f}\nn={len(pool_panel)}",
                transform=ax.transAxes, fontsize=8, va="top",
                bbox=dict(boxstyle="round,pad=0.3", fc="white", alpha=0.8))

    # Hide unused axes
    for idx in range(n, nrows * ncols):
        axes[idx // ncols][idx % ncols].set_visible(False)

    fig.suptitle("Noise model: predicted vs actual daily volume\n"
                 "(registry pools, 90-day training window)", fontsize=13)
    fig.tight_layout()
    out_path = os.path.join(output_dir, "registry_predicted_vs_real.png")
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    print(f"Saved: {out_path}")
    plt.close(fig)
def load_results(path):
    """Load an Optuna results file and split it into config and trials.

    The file holds a single JSON document: a list ``[config, trial1, ...]``.
    It may be double-encoded (a JSON string that itself contains the JSON
    list), in which case it is decoded a second time.

    Parameters
    ----------
    path : str
        Path to the results JSON file.

    Returns
    -------
    tuple
        ``(config, trials)`` where ``config`` is the first list element and
        ``trials`` is the list of the remaining elements.

    Raises
    ------
    SystemExit
        With exit status 1 when the decoded document is not a list of at
        least two elements. The diagnostic is written to stderr.
    """
    # JSON interchange is UTF-8; do not depend on the locale's default.
    with open(path, encoding="utf-8") as f:
        data = json.loads(f.read())
    if isinstance(data, str):
        data = json.loads(data)
    if not isinstance(data, list) or len(data) < 2:
        # Report the length too: "got <class 'list'>" alone hides the case
        # where the file has a config but no trials.
        found = (
            f"list of length {len(data)}"
            if isinstance(data, list)
            else str(type(data))
        )
        print(
            f"ERROR: Expected [config, trial1, trial2, ...], got {found}",
            file=sys.stderr,
        )
        sys.exit(1)
    return data[0], data[1:]
config["tokens"], + "startDateString": config["startDateString"], + "endDateString": config["endTestDateString"], # full period + "initial_pool_value": config["initial_pool_value"], + "do_arb": config["do_arb"], + "fees": fees, + "gas_cost": config.get("gas_cost", 1.0), + "arb_fees": config.get("arb_fees", 0.0), + "protocol_fee_split": config.get("protocol_fee_split", 0.0), + "noise_trader_ratio": config.get("noise_trader_ratio", 0.0), + "reclamm_use_shift_exponent": config.get("reclamm_use_shift_exponent", True), + "reclamm_interpolation_method": config.get("reclamm_interpolation_method", "geometric"), + "reclamm_centeredness_scaling": config.get("reclamm_centeredness_scaling", False), + "reclamm_learn_arc_length_speed": config.get("reclamm_learn_arc_length_speed", False), + } + jax_params = {k: jnp.array(v) for k, v in params.items()} + return do_run_on_historic_data(run_fingerprint=fp, params=jax_params) + + +def plot_results(configs, time_series, hodl_values, config, args): + """Two-panel plot: value-over-time + cumulative fee revenue.""" + train_end_str = config["endDateString"] + train_end_dt = datetime.strptime(train_end_str, "%Y-%m-%d %H:%M:%S") + + first_out = next(iter(time_series.values())) + n_minutes = len(first_out["value"]) + dates = pd.date_range( + start=datetime.strptime(config["startDateString"], "%Y-%m-%d %H:%M:%S"), + periods=n_minutes, freq="1min", + ) + step = 1440 + dates_daily = dates[::step] + + has_fee_revenue = any( + "fee_revenue" in time_series[n] and time_series[n]["fee_revenue"] is not None + for n in time_series + ) + n_panels = 2 if has_fee_revenue else 1 + fig, axes = plt.subplots( + n_panels, 1, figsize=(14, 5 * n_panels), + sharex=True, gridspec_kw={"height_ratios": [3, 1] if n_panels == 2 else [1]}, + ) + if n_panels == 1: + axes = [axes] + ax_val = axes[0] + + # ── Panel 1: Value over time ────────────────────────────────────── + for name, meta, ci in _plot_order(configs): + out = time_series[name] + vals = 
np.array(out["value"][::step]) / 1e6 + label = f"{name}" + if "test_objective" in meta: + obj_name = config.get("return_val", "objective") + label += f" (OOS {obj_name}={meta['test_objective']:.4f})" + is_optimized = "On-Chain" not in name + ax_val.plot(dates_daily[:len(vals)], vals, + linewidth=2.5 if is_optimized else 1.8, + color=COLORS[ci % len(COLORS)], label=label, + zorder=3 if is_optimized else 2) + + hodl_daily = hodl_values[::step] / 1e6 + ax_val.plot(dates_daily[:len(hodl_daily)], hodl_daily, linewidth=2, + color="white", alpha=0.7, linestyle="--", label="HODL") + + ax_val.axvline(x=train_end_dt, color="white", linestyle=":", alpha=0.5, linewidth=1.5) + ylims = ax_val.get_ylim() + ax_val.text(train_end_dt - pd.Timedelta(days=10), ylims[1] * 0.97, "Train", + color="white", alpha=0.6, fontsize=11, ha="right", va="top") + ax_val.text(train_end_dt + pd.Timedelta(days=10), ylims[1] * 0.97, "Test", + color="white", alpha=0.6, fontsize=11, ha="left", va="top") + + _style_axis(ax_val) + ax_val.set_ylabel("Pool Value ($M USD)", color=TEXT_COLOR, fontsize=12) + tokens_str = "/".join(config["tokens"]) + obj_name = config.get("return_val", "objective") + ntr = config.get("noise_trader_ratio", 0.0) + ax_val.set_title( + f"reClAMM Optuna-Optimized ({obj_name}, noise={ntr}) — {tokens_str}", + color=TEXT_COLOR, fontsize=13, pad=15, + ) + ax_val.legend(loc="upper left", fontsize=9, facecolor=BG, + edgecolor=TEXT_COLOR, labelcolor=TEXT_COLOR) + + # ── Panel 2: Cumulative fee revenue ─────────────────────────────── + if has_fee_revenue: + ax_fee = axes[1] + for name, _meta, ci in _plot_order(configs): + out = time_series[name] + fr = out.get("fee_revenue") + if fr is None: + continue + fr = np.array(fr) + cumfee = np.cumsum(fr)[::step] / 1e3 + is_optimized = "On-Chain" not in name + ax_fee.plot(dates_daily[:len(cumfee)], cumfee, + linewidth=2.5 if is_optimized else 1.8, + color=COLORS[ci % len(COLORS)], label=name, + zorder=3 if is_optimized else 2) + + 
def plot_test_only(configs, time_series, hodl_values, config, args):
    """Plot the test period only, with every curve normalised to start at 1.0.

    Parameters
    ----------
    configs : dict
        Mapping of config name to metadata; passed to ``_plot_order`` to
        decide drawing order and colour index.
    time_series : dict
        Mapping of config name to a run output providing a ``"value"``
        series.
    hodl_values : array-like
        HODL value series on the same time grid as the run outputs.
    config : dict
        Run configuration; reads ``startDateString``, ``endDateString``,
        ``tokens`` and, optionally, ``return_val`` and ``noise_trader_ratio``.
    args : argparse.Namespace
        Reads ``args.output`` as the base output path.

    The figure is written next to the main plot with ``_test_only`` inserted
    before the file extension.
    """
    # Local import: this module's visible top-level imports do not include os.
    import os

    train_end_str = config["endDateString"]
    train_end_dt = datetime.strptime(train_end_str, "%Y-%m-%d %H:%M:%S")
    start_dt = datetime.strptime(config["startDateString"], "%Y-%m-%d %H:%M:%S")

    first_out = next(iter(time_series.values()))
    n_minutes = len(first_out["value"])
    dates = pd.date_range(start=start_dt, periods=n_minutes, freq="1min")

    # Find the index of the train/test boundary (one sample per minute),
    # clamped to the last available sample.
    train_minutes = int((train_end_dt - start_dt).total_seconds() / 60)
    test_start_idx = min(train_minutes, n_minutes - 1)

    step = 1440  # minutes per day: plot one point per day
    test_dates = dates[test_start_idx::step]

    fig, ax = plt.subplots(1, 1, figsize=(14, 6))

    for name, _meta, ci in _plot_order(configs):
        out = time_series[name]
        vals = np.array(out["value"])
        test_vals = vals[test_start_idx::step]
        if len(test_vals) == 0:
            continue
        normalised = test_vals / test_vals[0]
        is_optimized = "On-Chain" not in name
        ax.plot(test_dates[:len(normalised)], normalised,
                linewidth=2.5 if is_optimized else 1.8,
                color=COLORS[ci % len(COLORS)], label=name,
                zorder=3 if is_optimized else 2)

    hodl_test = hodl_values[test_start_idx::step]
    if len(hodl_test) > 0:
        hodl_norm = hodl_test / hodl_test[0]
        ax.plot(test_dates[:len(hodl_norm)], hodl_norm, linewidth=2,
                color="white", alpha=0.7, linestyle="--", label="HODL")

    ax.axhline(1.0, color="white", linestyle=":", alpha=0.3, linewidth=1)
    _style_axis(ax)
    tokens_str = "/".join(config["tokens"])
    obj_name = config.get("return_val", "objective")
    ntr = config.get("noise_trader_ratio", 0.0)
    ax.set_title(f"Test Period Only (normalised) — {obj_name}, noise={ntr} — {tokens_str}",
                 color=TEXT_COLOR, fontsize=13, pad=15)
    ax.set_ylabel("Normalised Value", color=TEXT_COLOR, fontsize=12)
    ax.set_xlabel("Date", color=TEXT_COLOR, fontsize=12)
    ax.legend(loc="best", fontsize=9, facecolor=BG,
              edgecolor=TEXT_COLOR, labelcolor=TEXT_COLOR)

    fig.patch.set_facecolor(BG)
    fig.tight_layout()

    # Derive the file name from the extension rather than str.replace: with
    # replace, an --output without ".png" produced the same path as the main
    # plot (overwriting it), and a ".png" inside a directory name was
    # rewritten as well.
    base = args.output or f"reclamm_optuna_{tokens_str.replace('/', '_')}.png"
    root, ext = os.path.splitext(base)
    output = f"{root}_test_only{ext or '.png'}"
    fig.savefig(output, dpi=200, bbox_inches="tight", facecolor=BG)
    print(f"Saved plot to {output}")
    plt.close(fig)
def _style_axis(ax):
    """Apply the module's BG / TEXT_COLOR theme to a single axes.

    Sets the face colour, recolours ticks and spines, hides the top and right
    spines, and turns on a faint grid.
    """
    ax.set_facecolor(BG)
    ax.tick_params(colors=TEXT_COLOR)

    borders = ax.spines
    for border in borders.values():
        border.set_color(TEXT_COLOR)
        border.set_alpha(0.3)
    for side in ("top", "right"):
        borders[side].set_visible(False)

    ax.grid(True, alpha=0.15, color=TEXT_COLOR)
Plotter — objective: {obj_name}") + print("=" * 80) + print(f" Results: {args.results_json}") + print(f" Tokens: {'/'.join(tokens)}") + print(f" Train: {config['startDateString']} → {config['endDateString']}") + print(f" Test: {config['endDateString']} → {config['endTestDateString']}") + print(f" Fees: {config['fees']}, Gas: {config.get('gas_cost', 1.0)}") + print(f" Trials: {len(trials)} total, plotting top {len(top_trials)}") + + configs = {} + for i, trial in enumerate(top_trials): + params = extract_pool_params(trial, config) + name = f"#{trial.get('optuna_trial_number', i)} (rank {i+1})" + configs[name] = { + "params": params, + "objective": trial.get("objective", 0), + "train_objective": trial.get("train_objective", 0), + "test_objective": trial.get("test_objective", 0), + "train_sharpe": trial.get("train_sharpe", 0), + "validation_sharpe": trial.get("validation_sharpe", 0), + } + print(f"\n {name}:") + print(f" {obj_name}: train={trial.get('train_objective', 0):.4f} " + f"test={trial.get('test_objective', 0):.4f} " + f"penalised={trial.get('objective', 0):.4f}") + print(f" sharpe: train={trial.get('train_sharpe', 0):+.4f} " + f"val={trial.get('validation_sharpe', 0):+.4f}") + for k, v in params.items(): + print(f" {k}: {v:.6g}") + + if not args.no_onchain: + configs["On-Chain (launch)"] = {"params": dict(ONCHAIN_LAUNCH_PARAMS)} + configs["On-Chain (current)"] = {"params": dict(ONCHAIN_CURRENT_PARAMS)} + + # ── Full-period runs ────────────────────────────────────────────── + print(f"\n--- Running full-period simulations ({config['startDateString']} → " + f"{config['endTestDateString']}) ---") + time_series = {} + for name, cfg in configs.items(): + print(f" {name}...", end=" ", flush=True) + out = run_full_period(cfg["params"], config) + time_series[name] = out + fv = float(out["final_value"]) + fr = out.get("fee_revenue") + fr_total = float(np.array(fr).sum()) if fr is not None else 0 + hodl = float((out["reserves"][0] * out["prices"][-1]).sum()) + 
print(f"final=${fv:,.0f} hodl=${hodl:,.0f} RoH={fv/hodl - 1:+.2%} " + f"fee_rev=${fr_total:,.0f}") + + first_out = next(iter(time_series.values())) + hodl_reserves = first_out["reserves"][0] + hodl_values = np.sum( + np.array(hodl_reserves) * np.array(first_out["prices"]), axis=1, + ) + + # ── Plots ───────────────────────────────────────────────────────── + plot_results(configs, time_series, hodl_values, config, args) + plot_test_only(configs, time_series, hodl_values, config, args) + plot_weights(configs, time_series, config, args) + + # ── Summary table ───────────────────────────────────────────────── + print(f"\n{'=' * 120}") + print(f"SUMMARY — {'/'.join(tokens)} — {obj_name}") + print(f"{'=' * 120}") + hdr = (f"{'Config':<28s} {'Train '+obj_name:>20s} {'Test '+obj_name:>20s} " + f"{'Train SR':>10s} {'Val SR':>10s} " + f"{'PR':>7s} {'Margin':>7s} {'ShiftExp':>10s} {'Full RoH':>10s}") + print(hdr) + print("-" * 120) + + for name, cfg in configs.items(): + cp = cfg["params"] + fv = float(time_series[name]["final_value"]) + full_roh = fv / float(hodl_values[-1]) - 1 + print( + f"{name:<28s} " + f"{cfg.get('train_objective', float('nan')):>20.4f} " + f"{cfg.get('test_objective', float('nan')):>20.4f} " + f"{cfg.get('train_sharpe', float('nan')):>+10.4f} " + f"{cfg.get('validation_sharpe', float('nan')):>+10.4f} " + f"{cp.get('price_ratio', float('nan')):>7.3f} " + f"{cp.get('centeredness_margin', float('nan')):>7.4f} " + f"{cp.get('shift_exponent', float('nan')):>10.4g} " + f"{full_roh * 100:>+9.2f}%" + ) + print("=" * 120) + + +if __name__ == "__main__": + main() diff --git a/scripts/plot_top50_predicted_vs_real.py b/scripts/plot_top50_predicted_vs_real.py new file mode 100644 index 0000000..0951a20 --- /dev/null +++ b/scripts/plot_top50_predicted_vs_real.py @@ -0,0 +1,521 @@ +"""Plot predicted vs real volume for top 50 pools by TVL on Feb 1st 2026. 
+ +Enumerates WEIGHTED (min_tvl=1000) and RECLAMM (min_tvl=0) pools, +fetches their snapshots, filters to those with TVL >= $10k on Feb 1st 2026, +takes the top 50 by TVL, and plots predicted vs actual daily volume using +the inference artifact from calibrate_noise_unified.py. + +For pools that were in the model's training set, uses their per-pool theta. +For pools not in the training set, uses population-level prediction from B. +""" + +import ast +import json +import os +import sys +import time +from datetime import date, timedelta + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +# Reuse functions from the noise calibration package +from quantammsim.noise_calibration import ( + BALANCER_API_CHAINS, + _graphql_request, + assemble_panel, + classify_token_tier, + encode_covariates, + fetch_pool_snapshots, + fetch_token_prices, +) + + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "local_data", "noise_top50" +) +TVL_DATE = date(2026, 2, 1) +OUTPUT_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "top50_feb1" +) +# Inference artifact from the main unified model run +FITTED_JSON = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "unified_full_90d.json" +) + + +def enumerate_all_pools(): + """Enumerate WEIGHTED (min_tvl=1000) and RECLAMM (min_tvl=0) pools.""" + all_pools = [] + + for chain in BALANCER_API_CHAINS: + for pool_type, min_tvl in [("WEIGHTED", 1000), ("RECLAMM", 0)]: + query = { + "query": """ + query GetPools($chain: GqlChain!, $types: [GqlPoolType!], + $minTvl: Float) { + poolGetPools( + where: { chainIn: [$chain], poolTypeIn: $types, minTvl: $minTvl } + ) { + id chain type protocolVersion + poolTokens { symbol weight address } + dynamicData { totalLiquidity swapFee } + } + } + """, + "variables": { + "chain": chain, + "types": [pool_type], + "minTvl": min_tvl, + }, + } + + try: + body = 
_graphql_request(query) + pools = body.get("data", {}).get("poolGetPools", []) + except Exception as e: + print(f" FAILED {chain} {pool_type}: {e}") + continue + + for p in pools: + tokens = [t["symbol"] for t in p.get("poolTokens", [])] + addresses = [t.get("address", "") for t in p.get("poolTokens", [])] + tvl = float(p.get("dynamicData", {}).get("totalLiquidity", 0)) + fee = float(p.get("dynamicData", {}).get("swapFee", 0)) + all_pools.append({ + "pool_id": p["id"], + "chain": p["chain"], + "pool_type": p["type"], + "tokens": tokens, + "token_addresses": addresses, + "swap_fee": fee, + "current_tvl": tvl, + }) + + if pools: + print(f" {chain:>10} {pool_type:>10}: {len(pools)}") + time.sleep(0.3) + + df = pd.DataFrame(all_pools) + print(f"\n Total: {len(df)} pools") + return df + + +def fetch_all_snapshots_cached(pools_df, cache_dir): + """Fetch snapshots for all pools, caching per-pool.""" + snap_dir = os.path.join(cache_dir, "snapshots") + os.makedirs(snap_dir, exist_ok=True) + + all_snaps = [] + n = len(pools_df) + + for i, (_, pool) in enumerate(pools_df.iterrows()): + pid = pool["pool_id"] + chain = pool["chain"] + cache_file = os.path.join(snap_dir, f"{pid}.parquet") + + if os.path.exists(cache_file): + df = pd.read_parquet(cache_file) + else: + if (i + 1) % 20 == 0 or i == 0: + print(f" Fetching snapshots {i+1}/{n}...", flush=True) + try: + df = fetch_pool_snapshots(pid, chain) + if len(df) > 0: + df.to_parquet(cache_file, index=False) + time.sleep(0.3) + except Exception as e: + print(f" FAILED {pid[:20]}: {e}") + continue + + if len(df) > 0: + df["pool_id"] = pid + df["chain"] = chain + all_snaps.append(df) + + if all_snaps: + return pd.concat(all_snaps, ignore_index=True) + return pd.DataFrame() + + +def get_tvl_on_date(snapshots_df, target_date, window_days=3): + """Get TVL for each pool on/near target_date.""" + results = [] + for pid in snapshots_df["pool_id"].unique(): + pool_snaps = snapshots_df[snapshots_df["pool_id"] == pid] + + best_row = None + 
best_dist = float("inf") + for _, row in pool_snaps.iterrows(): + d = row["date"] + if isinstance(d, date): + dist = abs((d - target_date).days) + else: + dist = abs((pd.Timestamp(d).date() - target_date).days) + if dist < best_dist: + best_dist = dist + best_row = row + + if best_row is not None and best_dist <= window_days: + results.append({ + "pool_id": pid, + "tvl_feb1": float(best_row["total_liquidity_usd"]), + "date_used": best_row["date"], + }) + + return pd.DataFrame(results) + + +def _get_theta_for_pool(pid, fitted, panel_90d, pop_B, pop_cov_names): + """Get theta for a pool: from fitted artifact if available, else population. + + Returns (theta, source) where source is 'fitted' or 'population'. + For IBP models, population fallback adds marginal feature effect (pi @ W). + """ + if pid in fitted["pools"]: + return np.array(fitted["pools"][pid]["theta_median"]), "fitted" + + # Population-level prediction: theta = B @ z_pool + # Build z_pool from the pool's covariates + pp = panel_90d[panel_90d["pool_id"] == pid] + if len(pp) == 0: + return None, "no_data" + + chain = pp["chain"].iloc[0] + tokens = pp["tokens"].iloc[0] + if isinstance(tokens, str): + tokens = tokens.split(",") + fee = pp["swap_fee"].iloc[0] if "swap_fee" in pp.columns else 0.003 + tiers = sorted([classify_token_tier(t) for t in tokens]) + tier_a = tiers[0] + + # Build covariate vector matching the model's encoding + z = np.zeros(len(pop_cov_names)) + for i, name in enumerate(pop_cov_names): + if name == "intercept": + z[i] = 1.0 + elif name == f"chain_{chain}": + z[i] = 1.0 + elif name == f"tier_A_{tier_a}": + z[i] = 1.0 + elif name == "log_fee": + z[i] = np.log(max(fee, 1e-6)) + + # B is (K_coeff, K_cov), theta = B @ z + B = np.array(pop_B) # (K_coeff, K_cov) + theta = B @ z + + # IBP: add marginal feature effect (pi @ W) + pop = fitted["population_effects"] + if "W" in pop and "feature_prevalences" in pop: + W = np.array(pop["W"]) # (K_features, K_coeff) + pi = 
np.array(pop["feature_prevalences"]) # (K_features,) + theta = theta + pi @ W + + return theta, "population" + + +def plot_pages(plot_pools, pool_idx_map_fitted, fitted, panel_90d, + pools_df, tvl_lookup, pop_B, pop_cov_names, + output_dir=OUTPUT_DIR): + """Generate paginated plots, 10 pools per page.""" + n_pools = len(plot_pools) + per_page = 10 + n_pages = (n_pools + per_page - 1) // per_page + + for page in range(n_pages): + start = page * per_page + end = min(start + per_page, n_pools) + page_pools = plot_pools[start:end] + n_this = len(page_pools) + + ncols = 2 + nrows = (n_this + ncols - 1) // ncols + fig, axes = plt.subplots(nrows, ncols, figsize=(14, 4 * nrows)) + if nrows == 1 and ncols == 1: + axes = np.array([[axes]]) + elif nrows == 1: + axes = axes.reshape(1, -1) + + for idx, (pid, feb_tvl) in enumerate(page_pools): + ax = axes[idx // ncols][idx % ncols] + + theta, source = _get_theta_for_pool( + pid, fitted, panel_90d, pop_B, pop_cov_names + ) + if theta is None: + ax.set_visible(False) + continue + + pp = panel_90d[panel_90d["pool_id"] == pid].sort_values("date") + if len(pp) < 5: + ax.set_visible(False) + continue + + x_obs = np.column_stack([ + np.ones(len(pp)), + pp["log_tvl_lag1"].values, + pp["volatility"].values, + pp["weekend"].values, + ]) + + pred_log = x_obs @ theta + actual_log = pp["log_volume"].values + pred_vol = np.exp(pred_log) + actual_vol = np.exp(actual_log) + dates = pd.to_datetime(pp["date"].values) + + ax.plot(dates, actual_vol, "o-", color="steelblue", markersize=2.5, + linewidth=0.9, alpha=0.7, label="Actual") + ax.plot(dates, pred_vol, "s--", color="orangered", markersize=2.5, + linewidth=0.9, alpha=0.7, label="Predicted") + ax.set_yscale("log") + ax.set_ylabel("Daily volume (USD)", fontsize=8) + + ss_res = np.sum((actual_log - pred_log) ** 2) + ss_tot = np.sum((actual_log - actual_log.mean()) ** 2) + r2 = 1 - ss_res / ss_tot if ss_tot > 0 else float("nan") + + meta = pools_df[pools_df["pool_id"] == pid] + if len(meta) > 0: 
+ m = meta.iloc[0] + tokens = m["tokens"] + tok_str = "/".join(str(t)[:8] for t in tokens[:2]) + chain = str(m["chain"]) + ptype = str(m["pool_type"]) + else: + tok_str = pid[:16] + chain = "?" + ptype = "?" + + type_tag = "R" if ptype == "RECLAMM" else "W" + src_tag = "*" if source == "population" else "" + ax.set_title( + "{} ({}, {}){}\n" + "TVL ${:,.0f} on Feb 1 | " + "R\u00b2={:.3f} b_c={:.2f} b_\u03c3={:.2f} " + "b_wknd={:.2f} n={}".format( + tok_str, chain, type_tag, src_tag, feb_tvl, + r2, theta[1], theta[2], theta[3], len(pp)), + fontsize=8) + ax.legend(fontsize=7) + ax.tick_params(labelsize=7) + ax.tick_params(axis="x", rotation=30) + + for idx in range(n_this, nrows * ncols): + axes[idx // ncols][idx % ncols].set_visible(False) + + fig.suptitle( + "Predicted vs actual daily volume \u2014 page {}/{} " + "(sorted by TVL on {}) [* = population prediction]".format( + page + 1, n_pages, TVL_DATE), + fontsize=11) + fig.tight_layout() + out = os.path.join(output_dir, "pred_vs_real_page{}.png".format(page + 1)) + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + print(" Saved: {}".format(out)) + + +def main(): + import argparse + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--artifact", default=FITTED_JSON, + help="Path to inference artifact JSON") + parser.add_argument("--output-dir", default=None, + help="Output directory (default: auto from model name)") + args = parser.parse_args() + + artifact_path = args.artifact + os.makedirs(CACHE_DIR, exist_ok=True) + + # ---- Load inference artifact ---- + print(f"Loading inference artifact: {artifact_path}") + with open(artifact_path) as f: + fitted = json.load(f) + n_fitted = len(fitted["pools"]) + model_name = fitted.get("model", "unknown") + print(f" Model: {model_name}") + print(f" {n_fitted} pools with fitted theta") + + # Output dir: use CLI override or auto from model name + output_dir = args.output_dir or os.path.join( + 
os.path.dirname(os.path.dirname(__file__)), + "results", f"top50_feb1_{model_name}", + ) + os.makedirs(output_dir, exist_ok=True) + + # Extract population-level B matrix for pools not in the model + pop_cov_names = fitted["model_spec"]["covariate_names"] + pop_B = np.array(fitted["population_effects"]["B"]) # (K_coeff, K_cov) + print(f" Population B: {pop_B.shape}, covariates: {pop_cov_names}") + + # ---- Step 1: Enumerate pools ---- + pools_cache = os.path.join(CACHE_DIR, "pools.parquet") + if os.path.exists(pools_cache): + pools_df = pd.read_parquet(pools_cache) + if isinstance(pools_df["tokens"].iloc[0], str): + pools_df["tokens"] = pools_df["tokens"].apply(ast.literal_eval) + pools_df["token_addresses"] = pools_df["token_addresses"].apply( + ast.literal_eval + ) + print(f"\nLoaded {len(pools_df)} pools from cache") + else: + print("\n1. Enumerating pools...") + pools_df = enumerate_all_pools() + pools_df.to_parquet(pools_cache, index=False) + + # ---- Step 2: Fetch snapshots ---- + print("\n2. Fetching snapshots...") + snapshots_df = fetch_all_snapshots_cached(pools_df, CACHE_DIR) + print(f" {len(snapshots_df)} pool-days") + + # ---- Step 3: TVL on Feb 1st ---- + print(f"\n3. Finding TVL on {TVL_DATE}...") + tvl_df = get_tvl_on_date(snapshots_df, TVL_DATE) + tvl_df = tvl_df[tvl_df["tvl_feb1"] >= 10_000].copy() + tvl_df = tvl_df.sort_values("tvl_feb1", ascending=False).head(50) + top50_ids = set(tvl_df["pool_id"]) + tvl_lookup = dict(zip(tvl_df["pool_id"], tvl_df["tvl_feb1"])) + print(f" {len(tvl_df)} pools with TVL >= $10k") + + # How many are in the fitted model? + n_in_model = sum(1 for pid in top50_ids if pid in fitted["pools"]) + print(f" {n_in_model} in fitted model, " + f"{len(top50_ids) - n_in_model} will use population prediction") + + # ---- Step 4: Fetch token prices & assemble panel ---- + panel_cache = os.path.join(CACHE_DIR, "panel.parquet") + if os.path.exists(panel_cache): + panel = pd.read_parquet(panel_cache) + print(f"\n4. 
Loaded panel from cache: {len(panel)} obs") + else: + top50_pools = pools_df[pools_df["pool_id"].isin(top50_ids)].copy() + top50_snaps = snapshots_df[snapshots_df["pool_id"].isin(top50_ids)].copy() + + print("\n4. Fetching token prices...") + prices_cache = os.path.join(CACHE_DIR, "token_prices") + token_addr_by_chain = {} + for _, pool in top50_pools.iterrows(): + chain = pool["chain"] + tokens = pool["tokens"] + addresses = pool["token_addresses"] + if chain not in token_addr_by_chain: + token_addr_by_chain[chain] = {} + for sym, addr in zip(tokens, addresses): + if sym and addr: + token_addr_by_chain[chain][sym] = addr + + token_prices = fetch_token_prices( + token_addr_by_chain, cache_dir=prices_cache + ) + + print("\n Assembling panel...") + panel = assemble_panel(top50_pools, top50_snaps, token_prices) + panel.to_parquet(panel_cache, index=False) + + # ---- Step 5: Filter to 90 days ---- + max_date = panel["date"].max() + if not isinstance(max_date, date): + max_date = pd.Timestamp(max_date).date() + cutoff = max_date - timedelta(days=90) + panel_90d = panel[ + panel["date"].apply( + lambda d: d >= cutoff if isinstance(d, date) + else pd.Timestamp(d).date() >= cutoff + ) + ].copy() + + if "log_tvl_lag1" not in panel_90d.columns: + panel_90d = panel_90d.sort_values(["pool_id", "date"]).reset_index(drop=True) + panel_90d["log_tvl_lag1"] = panel_90d.groupby("pool_id")["log_tvl"].shift(1) + panel_90d = panel_90d.dropna(subset=["log_tvl_lag1"]).reset_index(drop=True) + + pool_counts = panel_90d.groupby("pool_id").size() + valid_pools = pool_counts[pool_counts >= 10].index + panel_90d = panel_90d[panel_90d["pool_id"].isin(valid_pools)].copy() + print(f"\n5. 
90-day panel: {len(panel_90d)} obs, " + f"{panel_90d['pool_id'].nunique()} pools") + + # ---- Step 6: Plot ---- + # Sort by Feb 1 TVL, only include pools with panel data + plot_pools = [] + for pid in panel_90d["pool_id"].unique(): + if pid in tvl_lookup: + plot_pools.append((pid, tvl_lookup[pid])) + plot_pools.sort(key=lambda x: -x[1]) + print(f"\n6. Plotting {len(plot_pools)} pools...") + + plot_pages(plot_pools, fitted, fitted, panel_90d, pools_df, + tvl_lookup, pop_B, pop_cov_names, output_dir=output_dir) + + # ---- Summary table ---- + summary = [] + for pid, feb_tvl in plot_pools: + theta, source = _get_theta_for_pool( + pid, fitted, panel_90d, pop_B, pop_cov_names + ) + if theta is None: + continue + pp = panel_90d[panel_90d["pool_id"] == pid] + x = np.column_stack([ + np.ones(len(pp)), + pp["log_tvl_lag1"].values, + pp["volatility"].values, + pp["weekend"].values, + ]) + pred = x @ theta + actual = pp["log_volume"].values + ss_res = np.sum((actual - pred) ** 2) + ss_tot = np.sum((actual - actual.mean()) ** 2) + r2 = 1 - ss_res / ss_tot if ss_tot > 0 else float("nan") + + meta = pools_df[pools_df["pool_id"] == pid] + if len(meta) > 0: + m = meta.iloc[0] + tok_str = "/".join(str(t) for t in m["tokens"][:2]) + chain = str(m["chain"]) + ptype = str(m["pool_type"]) + else: + tok_str = pid[:16] + chain = "?" + ptype = "?" 
+ + summary.append({ + "pool_id": pid[:20], + "tokens": tok_str, + "chain": chain, + "type": ptype, + "tvl_feb1": feb_tvl, + "n_obs": len(pp), + "R2": r2, + "b_c": theta[1], + "b_sigma": theta[2], + "b_weekend": theta[3], + "source": source, + }) + + summary_df = pd.DataFrame(summary) + summary_path = os.path.join(output_dir, "top50_summary.csv") + summary_df.to_csv(summary_path, index=False) + print(f"\n Saved: {summary_path}") + + n_pools = len(summary_df) + n_fitted_used = (summary_df["source"] == "fitted").sum() + n_pop = (summary_df["source"] == "population").sum() + n_reclamm = (summary_df["type"] == "RECLAMM").sum() + print(f"\n{'='*70}") + print(f"Summary: {n_pools} pools ({n_fitted_used} fitted, {n_pop} population)") + print(f" RECLAMM: {n_reclamm} WEIGHTED: {n_pools - n_reclamm}") + print(f" Median R\u00b2: {summary_df['R2'].median():.3f}") + print(f" Mean b_c: {summary_df['b_c'].mean():.3f}") + + +if __name__ == "__main__": + main() diff --git a/scripts/run_structural_top50.py b/scripts/run_structural_top50.py new file mode 100644 index 0000000..0459bd5 --- /dev/null +++ b/scripts/run_structural_top50.py @@ -0,0 +1,456 @@ +"""Fit the structural mixture model and plot predicted vs actual for top 50 pools. + +Uses the cached panel (last 90 days), fits with vanilla SVI, then generates +paginated plots showing V_arb + V_noise decomposition and predicted vs actual. 
+""" + +import json +import os +import sys +from datetime import date, timedelta + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +# ---- Config ---- +PANEL_CACHE = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "local_data", "noise_calibration", "panel.parquet", +) +OUTPUT_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "structural_top50_nogas", +) +OUTPUT_JSON = os.path.join(OUTPUT_DIR, "structural_fit.json") +TRAIN_DAYS = 90 +SVI_STEPS = 20_000 +SVI_LR = 1e-3 +NUM_SAMPLES = 1000 +SEED = 42 +TOP_N = 50 + + +def load_and_filter_panel(): + """Load cached panel, filter to 90 days, keep pools with >= 10 obs.""" + panel = pd.read_parquet(PANEL_CACHE) + max_date = panel["date"].max() + if not isinstance(max_date, date): + max_date = pd.Timestamp(max_date).date() + cutoff = max_date - timedelta(days=TRAIN_DAYS) + panel = panel[ + panel["date"].apply( + lambda d: d >= cutoff if isinstance(d, date) + else pd.Timestamp(d).date() >= cutoff + ) + ].copy() + + if "log_tvl_lag1" not in panel.columns: + panel = panel.sort_values(["pool_id", "date"]).reset_index(drop=True) + panel["log_tvl_lag1"] = panel.groupby("pool_id")["log_tvl"].shift(1) + panel = panel.dropna(subset=["log_tvl_lag1"]).reset_index(drop=True) + + pool_counts = panel.groupby("pool_id").size() + valid = pool_counts[pool_counts >= 10].index + panel = panel[panel["pool_id"].isin(valid)].copy() + + print(f"Panel: {len(panel)} obs, {panel['pool_id'].nunique()} pools, " + f"{cutoff} to {max_date}") + return panel + + +def fit_structural(panel): + """Run SVI on the structural mixture model.""" + os.environ.setdefault("JAX_PLATFORMS", "cpu") + + from quantammsim.noise_calibration.covariate_encoding import ( + encode_covariates_structural, + ) + from quantammsim.noise_calibration.model import structural_noise_model + from quantammsim.noise_calibration.inference import run_svi + from 
quantammsim.noise_calibration.postprocessing import ( + check_convergence, extract_structural_params, + ) + from quantammsim.noise_calibration.output import generate_output_json + + import numpyro + numpyro.enable_x64() + + # Load gas costs for mainnet from CSV if available + gas_csv = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "formula_vs_real", "mainnet_gas_cost_daily.csv", + ) + gas_arr = None + if os.path.exists(gas_csv): + gas_df = pd.read_csv(gas_csv) + # CSV has columns: unix (ms timestamp), USD (gas cost) + gas_df["date"] = pd.to_datetime(gas_df["unix"], unit="ms").dt.date + gas_lookup = dict(zip(gas_df["date"], gas_df["USD"])) + + # Build per-observation gas array + gas_vals = [] + for _, row in panel.iterrows(): + d = row["date"] + if not isinstance(d, date): + d = pd.Timestamp(d).date() + chain = row["chain"] + if chain == "MAINNET" and d in gas_lookup: + gas_vals.append(gas_lookup[d]) + elif chain == "MAINNET": + gas_vals.append(1.0) # median fallback + else: + # L2 chains: ~$0.005 + from quantammsim.noise_calibration.constants import GAS_COSTS + gas_vals.append(GAS_COSTS.get(chain, 0.005)) + gas_arr = np.array(gas_vals, dtype=np.float64) + print(f"Gas costs: loaded ({len(gas_lookup)} mainnet days from CSV)") + else: + print("Gas costs: using defaults (no mainnet CSV)") + + # No-gas variant: rely on cadence alone to modulate V_arb + # Gas threshold kills V_arb=0 for TVL<$300k (most pools), making cadence + # unlearnable. Without gas, V_arb>0 for all pools and cadence has gradient. 
+ gas_arr = np.zeros(len(panel), dtype=np.float64) + print("Gas costs: DISABLED (cadence-only mode)") + + data = encode_covariates_structural(panel, gas=gas_arr) + + print(f"\nFitting structural model: {SVI_STEPS} SVI steps, lr={SVI_LR}") + samples, elbo_losses = run_svi( + data, + num_steps=SVI_STEPS, + lr=SVI_LR, + seed=SEED, + num_samples=NUM_SAMPLES, + model_fn=structural_noise_model, + ) + convergence = check_convergence(elbo_losses, method="svi") + + pool_params = extract_structural_params(samples, data) + + # Save output JSON + os.makedirs(OUTPUT_DIR, exist_ok=True) + inference_config = { + "method": "svi", "svi_steps": SVI_STEPS, + "svi_lr": SVI_LR, "num_samples": NUM_SAMPLES, + } + generate_output_json( + pool_params, samples, data, convergence, + OUTPUT_JSON, inference_config, + ) + + return samples, data, pool_params, elbo_losses + + +def compute_predictions(samples, data, panel): + """Compute per-observation predicted V_arb and V_noise.""" + from quantammsim.noise_calibration.formula_arb import ( + formula_arb_volume_daily_jax, + ) + from quantammsim.noise_calibration.model import _pad_with_ref + from quantammsim.noise_calibration.constants import OBS_COEFF_NAMES + import jax.numpy as jnp + + sample_dict = samples + agg_fn = np.median + + # Cadence parameters + alpha_0 = agg_fn(np.array(sample_dict["alpha_0"])) + alpha_chain = agg_fn(np.array(sample_dict["alpha_chain"]), axis=0) + alpha_tier = agg_fn(np.array(sample_dict["alpha_tier"]), axis=0) + alpha_tvl = agg_fn(np.array(sample_dict["alpha_tvl"])) + + # MoE parameters + W_gate = agg_fn(np.array(sample_dict["W_gate"]), axis=0) + beta = agg_fn(np.array(sample_dict["beta"]), axis=0) + + pool_idx = np.array(data["pool_idx"]) + X_pool = np.array(data["X_pool"]) + x_obs = np.array(data["x_obs"]) + chain_idx = np.array(data["chain_idx"]) + tier_idx = np.array(data["tier_idx"]) + sigma_daily = np.array(data["sigma_daily"]) + lag_log_tvl = np.array(data["lag_log_tvl"]) + fee = np.array(data["fee"]) + gas = 
np.array(data["gas"]) + + # Per-pool cadence + padded_chain = np.concatenate([[0.0], alpha_chain]) + padded_tier = np.concatenate([[0.0], alpha_tier]) + + N_pools = data["N_pools"] + pool_log_cadence = np.zeros(N_pools) + for p in range(N_pools): + pool_log_cadence[p] = ( + alpha_0 + + padded_chain[chain_idx[p]] + + padded_tier[tier_idx[p]] + + alpha_tvl * np.median(lag_log_tvl[pool_idx == p]) + ) + + # Per-obs V_arb + log_cad_obs = pool_log_cadence[pool_idx] + cadence_obs = np.exp(np.clip(log_cad_obs, -2.0, 6.0)) + tvl_obs = np.exp(lag_log_tvl) + + V_arb = np.array(formula_arb_volume_daily_jax( + jnp.array(sigma_daily), jnp.array(tvl_obs), + jnp.array(fee), jnp.array(gas), jnp.array(cadence_obs), + )) + + # Per-pool noise coefficients via MoE + logits = X_pool @ W_gate + w = np.exp(logits - logits.max(axis=1, keepdims=True)) + w = w / w.sum(axis=1, keepdims=True) + beta_pool = w @ beta # (N_pools, K_obs_coeff) + + # Per-obs V_noise + log_V_noise = np.sum(beta_pool[pool_idx] * x_obs, axis=1) + V_noise = np.exp(log_V_noise) + + # Predicted total + V_total_pred = V_arb + V_noise + log_V_pred = np.log(np.maximum(V_total_pred, 1e-6)) + + return V_arb, V_noise, V_total_pred, log_V_pred, cadence_obs + + +def plot_top50(panel, data, pool_params, V_arb, V_noise, log_V_pred): + """Plot top 50 pools by median TVL.""" + pool_meta = data["pool_meta"] + pool_ids = data["pool_ids"] + pool_idx = np.array(data["pool_idx"]) + y_obs = np.array(data["y_obs"]) + + # Rank pools by median TVL + pool_tvl = {} + for i, pid in enumerate(pool_ids): + mask = pool_idx == i + pool_tvl[pid] = np.median(np.exp(np.array(data["lag_log_tvl"])[mask])) + + ranked = sorted(pool_tvl.items(), key=lambda x: -x[1])[:TOP_N] + + # Build param lookup + param_lookup = {p["pool_id"]: p for p in pool_params} + + per_page = 10 + n_pages = (len(ranked) + per_page - 1) // per_page + + for page in range(n_pages): + start = page * per_page + end = min(start + per_page, len(ranked)) + page_pools = ranked[start:end] + 
n_this = len(page_pools) + + ncols = 2 + nrows = (n_this + ncols - 1) // ncols + fig, axes = plt.subplots(nrows, ncols, figsize=(16, 4.5 * nrows)) + if nrows == 1 and ncols == 1: + axes = np.array([[axes]]) + elif nrows == 1: + axes = axes.reshape(1, -1) + + for idx, (pid, median_tvl) in enumerate(page_pools): + ax = axes[idx // ncols][idx % ncols] + p_idx = pool_ids.index(pid) + mask = pool_idx == p_idx + + pp = panel[panel["pool_id"] == pid].sort_values("date") + dates = pd.to_datetime(pp["date"].values) + actual_vol = np.exp(y_obs[mask]) + pred_arb = V_arb[mask] + pred_noise = V_noise[mask] + pred_total = pred_arb + pred_noise + + # R2 + actual_log = y_obs[mask] + pred_log = log_V_pred[mask] + ss_res = np.sum((actual_log - pred_log) ** 2) + ss_tot = np.sum((actual_log - actual_log.mean()) ** 2) + r2 = 1 - ss_res / ss_tot if ss_tot > 0 else float("nan") + + # Arb fraction + arb_frac = np.median(pred_arb / np.maximum(pred_total, 1.0)) + + # Plot + ax.fill_between(dates, 0, pred_arb, alpha=0.3, color="orangered", + label="V_arb (LVR)") + ax.fill_between(dates, pred_arb, pred_total, alpha=0.3, + color="steelblue", label="V_noise (MoE)") + ax.plot(dates, actual_vol, "k-", linewidth=0.8, alpha=0.7, + label="Actual") + ax.plot(dates, pred_total, "--", color="purple", linewidth=0.8, + alpha=0.7, label="Predicted total") + + ax.set_yscale("log") + ax.set_ylabel("Daily volume (USD)", fontsize=8) + + meta = pool_meta[pool_meta["pool_id"] == pid] + if len(meta) > 0: + m = meta.iloc[0] + tokens = m["tokens"] + if isinstance(tokens, str): + tokens = tokens.split(",") + tok_str = "/".join(str(t)[:8] for t in tokens[:2]) + chain = str(m["chain"]) + else: + tok_str = pid[:16] + chain = "?" 
+ + params = param_lookup.get(pid, {}) + arb_freq = params.get("arb_frequency", "?") + + ax.set_title( + f"{tok_str} ({chain})\n" + f"TVL ${median_tvl:,.0f} | R\u00b2={r2:.3f} " + f"arb_freq={arb_freq}min arb_frac={arb_frac:.1%} " + f"n={mask.sum()}", + fontsize=8, + ) + ax.legend(fontsize=6, loc="upper right") + ax.tick_params(labelsize=7) + ax.tick_params(axis="x", rotation=30) + + for idx in range(n_this, nrows * ncols): + axes[idx // ncols][idx % ncols].set_visible(False) + + fig.suptitle( + f"Structural mixture model: V_arb + V_noise decomposition " + f"— page {page + 1}/{n_pages} " + f"(top {TOP_N} by median TVL, 90d window)", + fontsize=11, + ) + fig.tight_layout() + out = os.path.join(OUTPUT_DIR, f"structural_top50_page{page + 1}.png") + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" Saved: {out}") + + +def plot_elbo(elbo_losses): + """Plot ELBO convergence.""" + fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + ax = axes[0] + ax.plot(elbo_losses, alpha=0.3, color="steelblue", linewidth=0.5) + window = min(100, len(elbo_losses) // 10) + if window > 1: + smoothed = pd.Series(elbo_losses).rolling(window).mean().values + ax.plot(smoothed, color="red", linewidth=1.5, label=f"Rolling {window}") + ax.legend() + ax.set_xlabel("Step") + ax.set_ylabel("ELBO loss") + ax.set_title("ELBO convergence") + + ax = axes[1] + start = len(elbo_losses) * 4 // 5 + ax.plot(range(start, len(elbo_losses)), elbo_losses[start:], + color="steelblue", linewidth=0.8) + ax.set_xlabel("Step") + ax.set_ylabel("ELBO loss") + ax.set_title("ELBO convergence (last 20%)") + + plt.tight_layout() + out = os.path.join(OUTPUT_DIR, "elbo_convergence.png") + plt.savefig(out, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {out}") + + +def plot_summary(data, pool_params, V_arb, V_noise, log_V_pred): + """Summary plots: arb frequency distribution, arb fraction, R2.""" + pool_idx = np.array(data["pool_idx"]) + y_obs = np.array(data["y_obs"]) + pool_ids = 
data["pool_ids"] + + fig, axes = plt.subplots(1, 3, figsize=(16, 5)) + + # 1. Arb frequency histogram + ax = axes[0] + freqs = [p["arb_frequency"] for p in pool_params] + ax.hist(freqs, bins=range(0, 62, 2), color="orangered", alpha=0.7, + edgecolor="white") + ax.set_xlabel("Arb frequency (minutes)") + ax.set_ylabel("Count") + ax.set_title(f"Arb frequency distribution (n={len(freqs)})") + ax.axvline(np.median(freqs), color="black", linestyle="--", + label=f"Median={np.median(freqs):.0f}min") + ax.legend() + + # 2. Arb fraction per pool + ax = axes[1] + arb_fracs = [] + for i, pid in enumerate(pool_ids): + mask = pool_idx == i + total = V_arb[mask] + V_noise[mask] + arb_fracs.append(np.median(V_arb[mask] / np.maximum(total, 1.0))) + ax.hist(arb_fracs, bins=30, color="steelblue", alpha=0.7, edgecolor="white") + ax.set_xlabel("Median arb fraction") + ax.set_ylabel("Count") + ax.set_title("Arb fraction distribution") + ax.axvline(np.median(arb_fracs), color="black", linestyle="--", + label=f"Median={np.median(arb_fracs):.2f}") + ax.legend() + + # 3. 
Per-pool R2 + ax = axes[2] + r2_vals = [] + for i, pid in enumerate(pool_ids): + mask = pool_idx == i + actual = y_obs[mask] + pred = log_V_pred[mask] + ss_res = np.sum((actual - pred) ** 2) + ss_tot = np.sum((actual - actual.mean()) ** 2) + r2_vals.append(1 - ss_res / ss_tot if ss_tot > 0 else float("nan")) + r2_vals = np.array(r2_vals) + ax.hist(r2_vals[np.isfinite(r2_vals)], bins=30, color="green", alpha=0.7, + edgecolor="white") + ax.set_xlabel("R²") + ax.set_ylabel("Count") + ax.set_title("Per-pool R² distribution") + ax.axvline(np.nanmedian(r2_vals), color="black", linestyle="--", + label=f"Median={np.nanmedian(r2_vals):.3f}") + ax.legend() + + plt.tight_layout() + out = os.path.join(OUTPUT_DIR, "structural_summary.png") + plt.savefig(out, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {out}") + + +def main(): + print("=" * 70) + print("Structural Mixture Model: Fit + Top 50 Plots") + print("=" * 70) + + panel = load_and_filter_panel() + samples, data, pool_params, elbo_losses = fit_structural(panel) + + print("\nComputing predictions...") + V_arb, V_noise, V_total, log_V_pred, cadence = compute_predictions( + samples, data, panel, + ) + print(f" V_arb median: ${np.median(V_arb):,.0f}") + print(f" V_noise median: ${np.median(V_noise):,.0f}") + print(f" Arb fraction (median pool): {np.median(V_arb / np.maximum(V_total, 1)):.2%}") + + print("\nGenerating plots...") + os.makedirs(OUTPUT_DIR, exist_ok=True) + plot_elbo(elbo_losses) + plot_summary(data, pool_params, V_arb, V_noise, log_V_pred) + plot_top50(panel, data, pool_params, V_arb, V_noise, log_V_pred) + + # Summary stats + print(f"\n{'=' * 70}") + print(f"Done. 
Output in: {OUTPUT_DIR}") + arb_freqs = [p["arb_frequency"] for p in pool_params] + print(f" Arb frequency: median={np.median(arb_freqs):.0f}min, " + f"range=[{np.min(arb_freqs)}, {np.max(arb_freqs)}]") + print(f" JSON: {OUTPUT_JSON}") + + +if __name__ == "__main__": + main() diff --git a/scripts/sim_vs_world_comparison.py b/scripts/sim_vs_world_comparison.py new file mode 100644 index 0000000..ca3046f --- /dev/null +++ b/scripts/sim_vs_world_comparison.py @@ -0,0 +1,972 @@ +#!/usr/bin/env python3 +"""Compare quantammsim reClAMM / Balancer vs reclamm-simulations repo + on-chain. + +Runs: + 1. Zero-fee Balancer pool (quantammsim) — the normalization baseline + 2. reClAMM pool with on-chain params (quantammsim) + 3. Loads reclamm-simulations results + world values from CSV + 4. Gas-experiment runs: time-varying gas from on-chain percentiles, + 50% protocol fee take, on-chain fees + +All comparisons align quantammsim's minute-level output to the world state +CSV's actual Unix timestamps, eliminating timing drift from block-time +variability. 
+ +4-panel plot matching the reclamm-simulations format: + Top-left: Price (WETH/AAVE) — both repos overlaid + Top-right: (legend) + Bottom-left: Absolute value in WETH + Bottom-right: Value relative to feeless weighted (Balancer = 1.0) + +Usage: + python scripts/sim_vs_world_comparison.py + python scripts/sim_vs_world_comparison.py --csv /path/to/csv + python scripts/sim_vs_world_comparison.py --gas-experiment +""" + +import argparse +import numpy as np +import pandas as pd +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import jax.numpy as jnp +from pathlib import Path +from datetime import datetime, timezone + +from quantammsim.runners.jax_runners import do_run_on_historic_data + +# ── On-chain reClAMM params ─────────────────────────────────────────────────── +ONCHAIN_FEES = 0.0025 + +ONCHAIN_LAUNCH_PARAMS = { # deployment through 2025-12-18 + "price_ratio": 1.5014, + "centeredness_margin": 0.5, + "shift_exponent": 0.1, +} +ONCHAIN_CURRENT_PARAMS = { # post 2025-12-18 governance + "price_ratio": 4.0, + "centeredness_margin": 0.1, + "shift_exponent": 0.001, +} +GOVERNANCE_DATE = "2025-12-18" + +# CSV starts at ~17.2 WETH ≈ $50k at $2900/ETH. +INITIAL_POOL_VALUE = 50_000.0 + +# Gas cost = arb profit threshold in USD. +# reclamm-simulations uses profit_threshold = 3e-4 WETH (in token1 units). +# quantammsim's arb_thresh is in USD: 3 * 3e-4 WETH × ~$3000/ETH ≈ $2.70. 
def load_onchain_initial_state(csv_path=None):
    """Load the on-chain pool state at t=0 from the world state CSV.

    Parameters
    ----------
    csv_path : str or path-like, optional
        CSV to read. Defaults to the module-level ``WORLD_STATE_CSV``.
        The file must provide the columns ``timestamp``, ``balance_0``,
        ``balance_1``, ``virtual_0`` and ``virtual_1``.

    Returns
    -------
    (state_dict, start_time_str)
        ``state_dict`` maps ``Ra``/``Rb`` to the first row's
        ``balance_0``/``balance_1`` and ``Va``/``Vb`` to its
        ``virtual_0``/``virtual_1``, all as Python floats.
        ``start_time_str`` is the first row's timestamp floored (not
        rounded) to the minute and formatted as UTC
        ``"%Y-%m-%d %H:%M:%S"``, for alignment with minute-level price data.
    """
    path = WORLD_STATE_CSV if csv_path is None else csv_path
    # Only the first row is used, so avoid parsing the whole file.
    first_row = pd.read_csv(path, nrows=1).iloc[0]
    state = {
        "Ra": float(first_row["balance_0"]),
        "Rb": float(first_row["balance_1"]),
        "Va": float(first_row["virtual_0"]),
        "Vb": float(first_row["virtual_1"]),
    }
    # Floor to the minute boundary at or before the on-chain timestamp.
    ts_sec = int(first_row["timestamp"])
    ts_minute = (ts_sec // 60) * 60
    # Timezone-aware conversion: datetime.utcfromtimestamp() is deprecated
    # since Python 3.12 and returns a naive datetime.
    start_str = datetime.fromtimestamp(ts_minute, tz=timezone.utc).strftime(
        "%Y-%m-%d %H:%M:%S"
    )
    return state, start_str
def load_gas_csv(percentile, gas_csv_dir=None):
    """Load a gas CSV with columns ``unix`` and ``trade_gas_cost_usd``.

    Reads ``Gas_<percentile>.csv``, renames its ``USD`` column to
    ``trade_gas_cost_usd`` and floors ``unix`` to a multiple of 60000
    (a minute boundary, which implies millisecond timestamps) so rows
    align with the simulator's minute-level index. Gas CSV timestamps are
    offset by ~59s from exact minutes, hence floor rather than round.

    Parameters
    ----------
    percentile : str
        Percentile tag used in the file name, e.g. ``"50p"``.
    gas_csv_dir : str or path-like, optional
        Directory holding the gas CSVs. Defaults to the module-level
        ``GAS_CSV_DIR``.

    Returns
    -------
    pandas.DataFrame
        The loaded frame; any extra columns in the file are kept.

    Raises
    ------
    KeyError
        If the file lacks ``unix`` or a gas-cost column (``USD`` or
        ``trade_gas_cost_usd``).
    """
    base_dir = GAS_CSV_DIR if gas_csv_dir is None else Path(gas_csv_dir)
    path = base_dir / f"Gas_{percentile}.csv"
    df = pd.read_csv(path)
    df = df.rename(columns={"USD": "trade_gas_cost_usd"})

    # rename() silently ignores a missing source column, so verify the
    # documented columns exist instead of failing later downstream.
    missing = [c for c in ("unix", "trade_gas_cost_usd") if c not in df.columns]
    if missing:
        raise KeyError(f"{path} is missing required column(s): {missing}")

    ms_per_minute = 60000
    df["unix"] = (df["unix"] // ms_per_minute) * ms_per_minute  # floor to minute boundary
    return df
min(len(df), len(world_ts)) + world_ts = world_ts[:n] + + bal_eth = sample_at_timestamps(bal_eth_min, start_sec, world_ts) + reclamm_flat_eth = sample_at_timestamps(reclamm_flat_min, start_sec, world_ts) + qsim_price = sample_at_timestamps(qsim_price_min, start_sec, world_ts) + gas_results = { + pct: sample_at_timestamps(v, start_sec, world_ts) + for pct, v in gas_results_min.items() + } + + csv_world = df["world"].values[:n] + csv_feeless = df["feeless weighted"].values[:n] + print(f" Aligned: {n} world-timestamp points") + t = np.arange(n) + + # Governance half-day index + gov_unix = datetime.strptime( + GOVERNANCE_DATE, "%Y-%m-%d" + ).replace(tzinfo=timezone.utc).timestamp() + gov_idx = np.searchsorted(world_ts, gov_unix) + + # ── Plot: relative to feeless weighted ───────────────────────────── + fig, (ax_price, ax_rel) = plt.subplots(2, 1, figsize=(14, 9), + gridspec_kw={"height_ratios": [1, 2]}) + + # Top: price + ax_price.plot(t, qsim_price, color="gray", alpha=0.6, linewidth=1) + ax_price.set_ylabel("AAVE/ETH") + ax_price.set_title("Price") + ax_price.set_ylim(bottom=0) + if gov_idx < n: + ax_price.axvline(x=gov_idx, color="gray", linestyle=":", alpha=0.6) + + # Bottom: relative values + ax_rel.axhline(y=1.0, color="blue", linewidth=2, label="feeless weighted") + + # Flat-gas baseline (no protocol fee) + flat_rel = reclamm_flat_eth / bal_eth + ax_rel.plot(t, flat_rel, linewidth=2, color="gray", linestyle="--", + label=f"flat gas ${ARB_GAS_COST}, no protocol fee") + + # Gas percentile runs + colors = {"50p": "#2ca02c", "75p": "#ff7f0e", "90p": "#d62728", "95p": "#9467bd"} + for pct in GAS_PERCENTILES: + vals = gas_results[pct] + rel = vals / bal_eth + ax_rel.plot(t, rel, linewidth=1.5, color=colors[pct], + label=f"gas {pct}, {int(PROTOCOL_FEE_SPLIT*100)}% protocol fee") + + # World values + world_rel = csv_world / csv_feeless + ax_rel.plot(t, world_rel, linewidth=1.5, marker=".", markersize=2, + color="brown", label="world (on-chain)") + + 
ax_rel.set_xlabel("half days") + ax_rel.set_ylabel("value / feeless weighted") + ax_rel.set_title("LP value relative to feeless weighted (Balancer 50/50)") + ax_rel.legend(fontsize=8, loc="lower left") + ax_rel.grid(True, alpha=0.2) + if gov_idx < n: + ax_rel.axvline(x=gov_idx, color="gray", linestyle=":", alpha=0.6) + ax_rel.text(gov_idx + 1, ax_rel.get_ylim()[1] * 0.98, + "governance", fontsize=7, color="gray", va="top") + + tokens_str = "/".join(tokens) + fig.suptitle( + f"reClAMM gas experiment ({param_label} params) — {tokens_str}\n" + f"params: {list(param_source.values())}, " + f"fees: {ONCHAIN_FEES}, protocol fee: {PROTOCOL_FEE_SPLIT}", + fontsize=10, + ) + plt.tight_layout() + out = args.output.replace(".png", f"_gas_experiment_{param_label}.png") + plt.savefig(out, dpi=150, bbox_inches="tight") + print(f"\nSaved: {out}") + plt.close() + + # ── Summary table ────────────────────────────────────────────────── + print(f"\n{'Scenario':<45} {'Final rel':>10} {'vs world':>10}") + print("-" * 65) + world_final_rel = world_rel[-1] if len(world_rel) > 0 else float("nan") + print(f"{'Flat gas, no protocol fee':<45} {flat_rel[-1]:>10.4f} " + f"{flat_rel[-1] - world_final_rel:>+10.4f}") + for pct in GAS_PERCENTILES: + rel = gas_results[pct] / bal_eth + print(f"{'Gas ' + pct + f', {int(PROTOCOL_FEE_SPLIT*100)}% protocol fee':<45} " + f"{rel[-1]:>10.4f} {rel[-1] - world_final_rel:>+10.4f}") + print(f"{'World (on-chain)':<45} {world_final_rel:>10.4f}") + + +def run_gas_scale_experiment(args): + """Sweep gas cost scale factors, rebase to world, truncate at governance.""" + tokens = args.tokens + end = args.end + + if args.launch_params: + param_source = ONCHAIN_LAUNCH_PARAMS + param_label = "launch" + else: + param_source = ONCHAIN_CURRENT_PARAMS + param_label = "current" + pool_params = {k: jnp.array(v) for k, v in param_source.items()} + + # Load on-chain initial state and derive start time + onchain_state, onchain_start = load_onchain_initial_state() + start = 
onchain_start + print(f"On-chain initial state: Ra={onchain_state['Ra']:.2f}, " + f"Rb={onchain_state['Rb']:.2f}, Va={onchain_state['Va']:.2f}, " + f"Vb={onchain_state['Vb']:.2f}") + print(f"Sim start time (from on-chain): {start}") + + # Load world + reclamm-simulations values + print("Loading reclamm-simulations CSV...") + df = pd.read_csv(args.csv) + + # Load world timestamps and find governance cutoff + world_ts = load_world_timestamps() + gov_unix = datetime.strptime( + GOVERNANCE_DATE, "%Y-%m-%d" + ).replace(tzinfo=timezone.utc).timestamp() + gov_idx = np.searchsorted(world_ts, gov_unix) + + # Run all (percentile, scale) combinations + results_min = {} + price_ratio_min = None + for pct in GAS_PERCENTILES: + gas_df_raw = load_gas_csv(pct) + for scale in GAS_SCALE_FACTORS: + label = f"{pct} × {scale}" + print(f"Running reClAMM ({param_label}, gas={label}, " + f"protocol_fee={PROTOCOL_FEE_SPLIT})...") + gas_df = gas_df_raw.copy() + gas_df["trade_gas_cost_usd"] = gas_df_raw["trade_gas_cost_usd"] * scale + val_eth_min, pr_min, start_sec = run_pool( + tokens, start, end, "reclamm", ONCHAIN_FEES, pool_params, + protocol_fee_split=PROTOCOL_FEE_SPLIT, gas_cost_df=gas_df, + onchain_initial_state=onchain_state, + ) + results_min[(pct, scale)] = (val_eth_min, start_sec) + if price_ratio_min is None: + price_ratio_min = pr_min + + # Flat gas cost runs + flat_results_min = {} + for gas_usd in FLAT_GAS_USD: + print(f"Running reClAMM ({param_label}, flat gas=${gas_usd}, " + f"protocol_fee={PROTOCOL_FEE_SPLIT})...") + val_eth_min, _, start_sec = run_pool( + tokens, start, end, "reclamm", ONCHAIN_FEES, pool_params, + gas_cost=gas_usd, protocol_fee_split=PROTOCOL_FEE_SPLIT, + onchain_initial_state=onchain_state, + ) + flat_results_min[gas_usd] = (val_eth_min, start_sec) + + # ── World values: on-chain balances × quantammsim prices ────────── + world_bal_0, world_bal_1, world_ts = load_world_normalized_balances() + + # reclamm-sim comparison uses its own CSV (self-consistent 
pricing) + csv_world = df["world"].values + csv_sim = df["simulation"].values + + n = min(gov_idx, len(world_bal_0), len(csv_world), len(csv_sim), len(world_ts)) + print(f" Truncated at governance: {n} world-timestamp points") + t = np.arange(n) + world_ts_trunc = world_ts[:n] + + # Repriced world for quantammsim comparison + price_at_world = sample_at_timestamps( + price_ratio_min, start_sec, world_ts_trunc, + ) + world_val = world_bal_0[:n] * price_at_world + world_bal_1[:n] + world_growth = world_val / world_val[0] + + # CSV-based world for reclamm-sim comparison (self-consistent pricing) + csv_world = csv_world[:n] + csv_sim = csv_sim[:n] + world_growth_csv = csv_world / csv_world[0] + recsim_growth = csv_sim / csv_sim[0] + + # Sample all sim runs at world timestamps + start_sec = flat_results_min[FLAT_GAS_USD[0]][1] + + results = {} + for key, (val_min, _) in results_min.items(): + results[key] = sample_at_timestamps(val_min, start_sec, world_ts_trunc) + + flat_results = {} + for gas_usd, (val_min, _) in flat_results_min.items(): + flat_results[gas_usd] = sample_at_timestamps(val_min, start_sec, world_ts_trunc) + + # Compute growth ratios + flat_growths = {} + for gas_usd in FLAT_GAS_USD: + vals = flat_results[gas_usd] + flat_growths[gas_usd] = vals / vals[0] + + # ── Plot (% deviation from world: positive = sim below world) ─── + fig, (ax_ts, ax_pct, ax_flat) = plt.subplots( + 1, 3, figsize=(20, 7), gridspec_kw={"width_ratios": [3, 1, 1]}, + ) + + # Left: time series of % deviation from world + ax_ts.axhline(y=0.0, color="brown", linewidth=2, label="world (on-chain)") + + # reclamm-simulations (uses CSV-based world for self-consistent pricing) + recsim_dev = (1 - recsim_growth / world_growth_csv) * 100 + ax_ts.plot(t, recsim_dev, color="red", linewidth=2, + linestyle="--", label="reclamm-sim") + + # Gas scale sweep (percentile-based) + colors = {"50p": "#2ca02c", "75p": "#ff7f0e", "90p": "#d62728", "95p": "#9467bd"} + for pct in GAS_PERCENTILES: + for scale 
in GAS_SCALE_FACTORS: + vals = results[(pct, scale)] + sim_growth = vals / vals[0] + dev = (1 - sim_growth / world_growth) * 100 + alpha = 0.3 + 0.7 * scale + lw = 0.8 + 1.2 * scale + if scale == 1.0: + label = f"{pct} × {scale}" + elif pct == "50p": + label = f"50p × {scale}" + else: + label = None + ax_ts.plot(t, dev, color=colors[pct], alpha=alpha, + linewidth=lw, label=label) + + # Flat gas runs + flat_cmap = plt.cm.copper + for i, gas_usd in enumerate(FLAT_GAS_USD): + c = flat_cmap(i / max(len(FLAT_GAS_USD) - 1, 1)) + dev = (1 - flat_growths[gas_usd] / world_growth) * 100 + ax_ts.plot(t, dev, color=c, linewidth=1.5, linestyle="-.", + label=f"flat ${gas_usd}") + + ax_ts.set_xlabel("half days") + ax_ts.set_ylabel("% deviation from world") + ax_ts.set_title("LP value vs world (pre-governance)") + ax_ts.legend(fontsize=6, loc="best", ncol=2) + ax_ts.grid(True, alpha=0.2) + + # Reference lines for both summary panels (as % deviation) + recsim_final_dev = (1 - recsim_growth[-1] / world_growth_csv[-1]) * 100 + + # Middle: final % deviation vs percentile scale factor + ax_pct.axhline(y=0.0, color="brown", linewidth=2, label="world") + ax_pct.axhline(y=recsim_final_dev, color="red", linewidth=1.5, + linestyle="--", label=f"reclamm-sim ({recsim_final_dev:+.2f}%)") + + for pct in GAS_PERCENTILES: + finals = [] + for scale in GAS_SCALE_FACTORS: + vals = results[(pct, scale)] + sim_growth = vals / vals[0] + finals.append((1 - sim_growth[-1] / world_growth[-1]) * 100) + ax_pct.plot(GAS_SCALE_FACTORS, finals, marker="o", + color=colors[pct], linewidth=2, label=pct) + + ax_pct.set_xlabel("gas scale factor\n(1.0 = 450k gas)") + ax_pct.set_ylabel("% deviation from world") + ax_pct.set_title("Percentile gas") + ax_pct.legend(fontsize=6) + ax_pct.grid(True, alpha=0.2) + + # Right: final % deviation vs flat gas cost + ax_flat.axhline(y=0.0, color="brown", linewidth=2, label="world") + ax_flat.axhline(y=recsim_final_dev, color="red", linewidth=1.5, + linestyle="--", 
def run_best_gas_experiment(args):
    """Run the best gas configs vs world on a clean single-panel plot.

    One reClAMM simulation is run per entry in ``configs`` (one flat gas
    cost plus three scaled gas-percentile series), each seeded with the
    on-chain initial state. Every run is sampled at the world-state
    timestamps that precede ``GOVERNANCE_DATE`` and compared, as a
    percentage deviation of growth, against the BPT-normalized on-chain
    balances valued at the simulator's own price series.

    Parameters
    ----------
    args : argparse.Namespace
        Reads ``tokens``, ``end``, ``launch_params`` and ``output``.

    Raises
    ------
    ValueError
        If no world-state sample precedes the governance cutoff.
    """
    tokens = args.tokens
    end = args.end

    if args.launch_params:
        param_source = ONCHAIN_LAUNCH_PARAMS
        param_label = "launch"
    else:
        param_source = ONCHAIN_CURRENT_PARAMS
        param_label = "current"
    pool_params = {k: jnp.array(v) for k, v in param_source.items()}

    # Load on-chain initial state and derive start time
    onchain_state, start = load_onchain_initial_state()
    print(f"On-chain initial state: Ra={onchain_state['Ra']:.2f}, "
          f"Rb={onchain_state['Rb']:.2f}, Va={onchain_state['Va']:.2f}, "
          f"Vb={onchain_state['Vb']:.2f}")
    print(f"Sim start time (from on-chain): {start}")

    # Find governance cutoff from world timestamps
    world_ts_all = load_world_timestamps()
    gov_unix = datetime.strptime(
        GOVERNANCE_DATE, "%Y-%m-%d"
    ).replace(tzinfo=timezone.utc).timestamp()
    gov_idx = np.searchsorted(world_ts_all, gov_unix)

    # ── The best configs ─────────────────────────────────────────────
    # Single source of truth for both the runs and their plot styling:
    # (flat gas USD, gas percentile, percentile scale, color, linestyle).
    # Exactly one of "flat gas USD" / "gas percentile" is set per entry.
    configs = [
        (1.0, None, None, "black", "-"),
        (None, "50p", 1.0, "#2ca02c", "-"),
        (None, "75p", 0.75, "#ff7f0e", "-"),
        (None, "90p", 0.25, "#d62728", "-"),
    ]

    labels = []
    runs_min = []
    price_ratio_min = None
    start_sec = None
    for flat_gas, pct, scale, _color, _ls in configs:
        if pct is None:
            label = f"Flat ${flat_gas:.2f}"
            gas_desc = f"flat gas=${flat_gas:.2f}"
        else:
            label = f"{pct} × {scale}"
            gas_desc = f"gas={pct} × {scale}"
        print(f"Running reClAMM ({param_label}, {gas_desc}, "
              f"protocol_fee={PROTOCOL_FEE_SPLIT})...")

        run_kwargs = {
            "protocol_fee_split": PROTOCOL_FEE_SPLIT,
            "onchain_initial_state": onchain_state,
        }
        if pct is None:
            run_kwargs["gas_cost"] = flat_gas
        else:
            gas_df = load_gas_csv(pct)
            if scale != 1.0:
                # Scale a copy so the loaded frame is never mutated.
                gas_df = gas_df.copy()
                gas_df["trade_gas_cost_usd"] *= scale
            run_kwargs["gas_cost_df"] = gas_df

        val_min, pr_min, run_start_sec = run_pool(
            tokens, start, end, "reclamm", ONCHAIN_FEES, pool_params,
            **run_kwargs,
        )
        labels.append(label)
        runs_min.append(val_min)
        if price_ratio_min is None:
            # Price series and start time are taken from the first run.
            price_ratio_min, start_sec = pr_min, run_start_sec

    # ── World values: on-chain balances × quantammsim prices ──────────
    # Both sim and world valued at the same price at each point,
    # so price fluctuations cancel in the growth ratio comparison.
    world_bal_0, world_bal_1, world_ts = load_world_normalized_balances()
    n = min(gov_idx, len(world_bal_0), len(world_ts))
    if n == 0:
        raise ValueError(
            f"No world-state samples before governance date {GOVERNANCE_DATE}"
        )
    print(f" Truncated at governance: {n} world-timestamp points")
    t = np.arange(n)
    world_ts_trunc = world_ts[:n]

    # Sample quantammsim price ratio at world timestamps
    price_at_world = sample_at_timestamps(
        price_ratio_min, start_sec, world_ts_trunc,
    )
    # World value in ETH = norm_AAVE * (AAVE/ETH) + norm_ETH
    world_val = world_bal_0[:n] * price_at_world + world_bal_1[:n]
    world_growth = world_val / world_val[0]

    growths = []
    for val_min in runs_min:
        vals = sample_at_timestamps(val_min, start_sec, world_ts_trunc)
        growths.append(vals / vals[0])

    # ── Plot ──────────────────────────────────────────────────────────
    fig, ax = plt.subplots(figsize=(14, 6))

    ax.axhline(y=0.0, color="brown", linewidth=2, label="world (on-chain)")

    # One line per config: % deviation of sim growth from world growth
    for label, (_, _, _, color, ls), g in zip(labels, configs, growths):
        dev = (1 - g / world_growth) * 100
        final_dev = dev[-1]
        ax.plot(t, dev, color=color, linewidth=2, linestyle=ls,
                label=f"{label} (final {final_dev:+.2f}%)")

    ax.set_xlabel("half days")
    ax.set_ylabel("% deviation from world")
    ax.set_title(
        f"Best gas configs vs world ({param_label} params) — "
        f"{'/'.join(tokens)}\n"
        f"params: {list(param_source.values())}, "
        f"fees: {ONCHAIN_FEES}, protocol fee: {PROTOCOL_FEE_SPLIT}",
    )
    ax.legend(fontsize=9, loc="best")
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    out = args.output.replace(".png", f"_best_gas_{param_label}.png")
    plt.savefig(out, dpi=150, bbox_inches="tight")
    print(f"\nSaved: {out}")
    plt.close()

    # Summary
    print(f"\n{'Scenario':<25} {'% dev from world':>16}")
    print("-" * 42)
    for label, g in zip(labels, growths):
        dev = (1 - g[-1] / world_growth[-1]) * 100
        print(f"{label:<25} {dev:>+16.2f}%")
args.end, "reclamm", ONCHAIN_FEES, launch_params, + gas_cost=ARB_GAS_COST, + ) + + print(f"Running reClAMM (current params, gas=${ARB_GAS_COST})...") + current_params = {k: jnp.array(v) for k, v in ONCHAIN_CURRENT_PARAMS.items()} + reclamm_current_min, _, _ = run_pool( + args.tokens, args.start, args.end, "reclamm", ONCHAIN_FEES, current_params, + gas_cost=ARB_GAS_COST, + ) + + # ── Sample at world timestamps ──────────────────────────────────── + n = min(n_csv, len(world_ts)) + world_ts_trunc = world_ts[:n] + + bal_eth = sample_at_timestamps(bal_eth_min, start_sec, world_ts_trunc) + reclamm_zerofee_eth = sample_at_timestamps(reclamm_zerofee_min, start_sec, world_ts_trunc) + reclamm_launch_eth = sample_at_timestamps(reclamm_launch_min, start_sec, world_ts_trunc) + reclamm_current_eth = sample_at_timestamps(reclamm_current_min, start_sec, world_ts_trunc) + qsim_price = sample_at_timestamps(qsim_price_min, start_sec, world_ts_trunc) + + print(f" Aligned: {n} world-timestamp points " + f"(qsim minutes={len(bal_eth_min)}, csv={n_csv})") + t = np.arange(n) + + csv_price = df["price"].values[:n] + csv_feeless = df["feeless weighted"].values[:n] + csv_sim = df["simulation"].values[:n] + csv_hold = df["hold"].values[:n] + csv_world = df["world"].values[:n] + + # Governance change index + gov_unix = datetime.strptime( + GOVERNANCE_DATE, "%Y-%m-%d" + ).replace(tzinfo=timezone.utc).timestamp() + gov_idx = np.searchsorted(world_ts_trunc, gov_unix) + + # Normalize quantammsim to same starting value as CSV + v0 = csv_feeless[0] + bal_norm = bal_eth * (v0 / bal_eth[0]) + zerofee_norm = reclamm_zerofee_eth * (v0 / reclamm_zerofee_eth[0]) + launch_norm = reclamm_launch_eth * (v0 / reclamm_launch_eth[0]) + current_norm = reclamm_current_eth * (v0 / reclamm_current_eth[0]) + + # Relative values (÷ respective feeless weighted baseline) + zerofee_rel = reclamm_zerofee_eth / bal_eth + launch_rel = reclamm_launch_eth / bal_eth + current_rel = reclamm_current_eth / bal_eth + csv_sim_rel = 
csv_sim / csv_feeless + csv_hold_rel = csv_hold / csv_feeless + csv_world_rel = csv_world / csv_feeless + + # ── Plot ────────────────────────────────────────────────────────── + fig, axs = plt.subplots(2, 2, figsize=(13, 8)) + + # Top-left: price + axs[0][0].plot(t, csv_price, label="reclamm-sim", alpha=0.8) + axs[0][0].plot(t, qsim_price, label="quantammsim", alpha=0.8, linestyle="--") + axs[0][0].set_ylabel("WETH/AAVE") + axs[0][0].set_title("Price") + axs[0][0].set_ylim(bottom=0) + axs[0][0].legend(fontsize=8) + if gov_idx < n: + axs[0][0].axvline(x=gov_idx, color="gray", linestyle=":", alpha=0.6) + + # Top-right: remove (legend is on other panels) + axs[0][1].remove() + + # Bottom-left: absolute values in WETH + axs[1][0].plot(t, bal_norm, label="qsim feeless weighted", linewidth=2, color="blue") + axs[1][0].plot(t, launch_norm, label="qsim reClAMM (launch)", linewidth=2, color="orange") + axs[1][0].plot(t, current_norm, label="qsim reClAMM (current)", linewidth=2, + color="purple", linestyle="-.") + axs[1][0].plot(t, csv_sim, label="reclamm-sim simulation", linewidth=1.5, + linestyle="--", color="red") + axs[1][0].plot(t, csv_hold, label="hold", linewidth=1.5, color="green") + axs[1][0].plot(t, csv_world, label="world values", linewidth=1.5, + marker=".", markersize=2, color="brown") + axs[1][0].set_title("Value histories") + axs[1][0].set_xlabel("half days") + axs[1][0].set_ylabel("Value in WETH") + axs[1][0].set_ylim(bottom=0) + axs[1][0].legend(fontsize=7, loc="upper right") + if gov_idx < n: + axs[1][0].axvline(x=gov_idx, color="gray", linestyle=":", alpha=0.6) + axs[1][0].text(gov_idx + 1, axs[1][0].get_ylim()[1] * 0.95, + "governance", fontsize=7, color="gray", va="top") + + # Bottom-right: relative to feeless weighted + axs[1][1].axhline(y=1.0, color="blue", linewidth=2, label="feeless weighted") + axs[1][1].plot(t, launch_rel, label="qsim reClAMM (launch)", linewidth=2, color="orange") + axs[1][1].plot(t, current_rel, label="qsim reClAMM (current)", 
linewidth=2, + color="purple", linestyle="-.") + axs[1][1].plot(t, csv_sim_rel, label="reclamm-sim simulation", linewidth=1.5, + linestyle="--", color="red") + axs[1][1].plot(t, csv_hold_rel, label="hold", linewidth=1.5, color="green") + axs[1][1].plot(t, csv_world_rel, label="world values", linewidth=1.5, + marker=".", markersize=2, color="brown") + axs[1][1].set_title("Value relative to feeless weighted") + axs[1][1].set_xlabel("half days") + axs[1][1].set_ylabel("relative value") + axs[1][1].legend(fontsize=7, loc="lower left") + if gov_idx < n: + axs[1][1].axvline(x=gov_idx, color="gray", linestyle=":", alpha=0.6) + + tokens_str = "/".join(args.tokens) + fig.suptitle( + f"quantammsim vs reclamm-simulations — {tokens_str}\n" + f"Launch: {list(ONCHAIN_LAUNCH_PARAMS.values())}, " + f"Current: {list(ONCHAIN_CURRENT_PARAMS.values())}, " + f"fees: {ONCHAIN_FEES}", + fontsize=10, + ) + plt.tight_layout() + plt.savefig(args.output, dpi=150, bbox_inches="tight") + print(f"\nSaved: {args.output}") + + # ── Zero-fee comparison plot (minute-level) ─────────────────────── + # Revalue reclamm-sim balances at quantammsim's price so both sides + # use the same price and the comparison is purely about balances. + # Skip row 0 of the CSV (initial state before first arb) to align + # with quantammsim's reserves[0] which is post-first-step. 
+ ext_bal_0 = df_zf_min["balance_0"].values[1:] + ext_bal_1 = df_zf_min["balance_1"].values[1:] + n_zf = min(len(reclamm_zerofee_min), len(ext_bal_0), len(qsim_price_min)) + ext_val_repriced = ( + ext_bal_0[:n_zf] * qsim_price_min[:n_zf] + ext_bal_1[:n_zf] + ) + qsim_growth = reclamm_zerofee_min[:n_zf] / reclamm_zerofee_min[0] + ext_growth = ext_val_repriced[:n_zf] / ext_val_repriced[0] + pct_dev = (qsim_growth / ext_growth - 1) * 100 + days = np.arange(n_zf) / 1440 + + zerofee_title = ( + f"Zero-fee zero-gas reClAMM: quantammsim / reclamm-sim (minute-level) — {tokens_str}\n" + f"params: {list(ONCHAIN_LAUNCH_PARAMS.values())}" + ) + daily_smooth = pd.Series(pct_dev).rolling(1440, center=True, min_periods=720).mean() + + # Plot 1: with daily smoothing overlay + fig2, ax2 = plt.subplots(figsize=(12, 5)) + ax2.plot(days, pct_dev, linewidth=0.5, color="teal", alpha=0.6) + ax2.plot(days, daily_smooth, linewidth=2, color="darkblue", label="daily smoothed") + ax2.axhline(y=0.0, color="gray", linestyle="--", alpha=0.6) + ax2.set_xlabel("days") + ax2.set_ylabel("deviation (%)") + ax2.set_title(zerofee_title, fontsize=11) + ax2.legend(fontsize=9) + ax2.grid(True, alpha=0.3) + plt.tight_layout() + zerofee_path = args.output.replace(".png", "_zerofee_ratio.png") + plt.savefig(zerofee_path, dpi=150, bbox_inches="tight") + print(f"Saved: {zerofee_path}") + plt.close() + + # Plot 2: raw minute-level only (no smoothing) + fig3, ax3 = plt.subplots(figsize=(12, 5)) + ax3.plot(days, pct_dev, linewidth=0.5, color="teal", alpha=0.8) + ax3.axhline(y=0.0, color="gray", linestyle="--", alpha=0.6) + ax3.set_xlabel("days") + ax3.set_ylabel("deviation (%)") + ax3.set_title(zerofee_title, fontsize=11) + ax3.grid(True, alpha=0.3) + plt.tight_layout() + zerofee_raw_path = args.output.replace(".png", "_zerofee_ratio_raw.png") + plt.savefig(zerofee_raw_path, dpi=150, bbox_inches="tight") + print(f"Saved: {zerofee_raw_path}") + plt.close() + + +if __name__ == "__main__": + main() From 
427044a17e51f2bcc6495ffce5d25b17b27272d5 Mon Sep 17 00:00:00 2001 From: christian harrington Date: Fri, 6 Mar 2026 16:51:35 +0000 Subject: [PATCH 05/14] add calibration to pyproject --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 413faf9..8859b55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,10 @@ docs = [ "sphinx-automodapi", "sphinx-rtd-theme", ] +calibration = [ + "numpyro>=0.15.0", + "arviz>=0.15.0", +] [tool.hatch.build.targets.wheel] packages = [ From fa39d72fa28ed3d1f3f285164489907f504b8e3d Mon Sep 17 00:00:00 2001 From: christian harrington Date: Fri, 6 Mar 2026 16:51:47 +0000 Subject: [PATCH 06/14] core noise model in base pool updates --- quantammsim/pools/base_pool.py | 106 ++++++++++++++++++- quantammsim/pools/noise_trades.py | 164 ++++++++++++++++++++++++++++++ 2 files changed, 266 insertions(+), 4 deletions(-) diff --git a/quantammsim/pools/base_pool.py b/quantammsim/pools/base_pool.py index d8c6d75..01845db 100644 --- a/quantammsim/pools/base_pool.py +++ b/quantammsim/pools/base_pool.py @@ -1,11 +1,12 @@ from abc import ABC, abstractmethod -from typing import Dict, Any, Optional +from typing import Dict, Any, Optional, Tuple +from functools import partial import numpy as np import jax.numpy as jnp from jax.nn import softmax -from jax.lax import stop_gradient -from jax import tree_util +from jax.lax import stop_gradient, dynamic_slice +from jax import tree_util, jit, vmap from quantammsim.core_simulator.param_utils import make_vmap_in_axes_dict @@ -93,7 +94,11 @@ def calculate_reserves_with_dynamic_inputs( run_fingerprint: Dict[str, Any], prices: jnp.ndarray, start_index: jnp.ndarray, - dynamic_inputs: Any, + fees_array: jnp.ndarray, + arb_thresh_array: jnp.ndarray, + arb_fees_array: jnp.ndarray, + trade_array: jnp.ndarray, + lp_supply_array: jnp.ndarray = None, additional_oracle_input: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: pass @@ -278,6 +283,99 @@ def add_noise( 
@partial(jit, static_argnums=(2, 3))
def calculate_volatility_array(self, prices, run_fingerprint, subsample_freq=5):
    """Annualised daily realised volatility broadcast to a minute-level array.

    Pure-JAX implementation (vmap + dynamic_slice) — JIT-compatible and
    callable from within traced contexts (e.g. forward_pass).

    Parameters
    ----------
    prices : jnp.ndarray, shape (T, 2)
        Minute-level prices for two tokens.
    run_fingerprint : dict
        Forwarded to ``_handle_numeraire_ordering`` (which reads ``tokens``
        and ``numeraire``). Treated as a static jit argument.
    subsample_freq : int
        Stride used to subsample prices within each day, to reduce
        microstructure noise. Treated as a static jit argument.

    Returns
    -------
    jnp.ndarray, shape (T,)
        Annualised volatility, constant within each full day. Minutes of a
        trailing partial day reuse the last full day's value. With fewer
        than one full day of prices a constant placeholder of
        ``0.1 * sqrt(365)`` is returned.
    """
    minutes_per_day = 1440
    annualisation = jnp.sqrt(365.0)

    # Only the reordered prices are needed here; the swap flag is discarded.
    ordered_prices, _ = self._handle_numeraire_ordering(
        prices, run_fingerprint,
    )
    asset_prices = ordered_prices[:, 0] / ordered_prices[:, 1]
    n_minutes = len(asset_prices)

    # Guard: need at least one full day for vmap + dynamic_slice.
    if n_minutes < minutes_per_day:
        return jnp.full(n_minutes, 0.1) * annualisation

    n_days = n_minutes // minutes_per_day
    # Length of one subsampled step, expressed in days.
    dt = subsample_freq / minutes_per_day

    def calculate_daily_volatility(day_idx):
        window_prices = dynamic_slice(
            asset_prices, [day_idx * minutes_per_day], [minutes_per_day],
        )
        log_prices = jnp.log(
            jnp.maximum(window_prices[::subsample_freq], 1e-8)
        )
        returns = jnp.diff(log_prices)
        # Scale the variance by the share of non-zero returns.
        num_nonzero_returns = jnp.sum(returns != 0)
        adjusted_variance = (
            num_nonzero_returns * jnp.var(returns) / len(returns)
        )
        return jnp.sqrt(adjusted_variance) / jnp.sqrt(dt)

    daily_volatilities = vmap(calculate_daily_volatility)(jnp.arange(n_days))
    volatility_array = jnp.repeat(daily_volatilities, minutes_per_day)

    # n_days >= 1 is guaranteed by the guard above, so the last full day's
    # value always exists for padding the trailing partial day.
    remaining_minutes = n_minutes - n_days * minutes_per_day
    if remaining_minutes > 0:
        volatility_array = jnp.concatenate(
            [volatility_array,
             jnp.full(remaining_minutes, daily_volatilities[-1])]
        )

    return volatility_array * annualisation
@partial(jit, static_argnums=(2,))
def _handle_numeraire_ordering(
    self,
    prices: jnp.ndarray,
    run_fingerprint: Dict[str, Any],
) -> Tuple[jnp.ndarray, bool]:
    """Put the numeraire token's price in the last position.

    The token list is sorted before the numeraire is located. If the
    configured numeraire is ``None`` or is not one of the tokens, the last
    token in sorted order is used instead.

    Parameters
    ----------
    prices : jnp.ndarray, shape (..., 2)
        Price array with two tokens along the last axis.
    run_fingerprint : dict
        Must contain ``tokens`` and ``numeraire``. Static under jit.

    Returns
    -------
    (ordered_prices, needs_swap)
        ``ordered_prices`` is ``prices`` with the last axis reversed when
        the numeraire sorts first, otherwise ``prices`` unchanged.
        NOTE(review): because this function is wrapped in ``jit``, callers
        presumably receive ``needs_swap`` as a JAX boolean scalar rather
        than a Python ``bool`` — confirm before branching on it in Python.
    """
    token_order = sorted(run_fingerprint["tokens"])
    requested = run_fingerprint["numeraire"]

    usable = requested is not None and requested in token_order
    numeraire = requested if usable else token_order[-1]

    needs_swap = token_order.index(numeraire) == 0
    ordered_prices = prices[..., ::-1] if needs_swap else prices
    return ordered_prices, needs_swap
@jit
def reclamm_tsoukalas_sqrt_noise_volume(
    effective_value_usd,
    gamma,
    volatility,
    arb_volume_this_period,
    noise_params=None,
):
    """reClAMM Tsoukalas sqrt model: effective TVL regressor.

    Predicts per-minute noise trader volume using:
        V_daily = (a_0 - a_f*fee + a_sigma*sigma
                   + a_c*sqrt(c_eff/1e6)) * 1e6
        V_noise = max(0, V_daily/1440 - arb_volume_this_period)

    where ``fee = 1 - gamma``, ``a_0 = a_0_base + base_fee * a_f`` and
    c_eff = (Ra+Va)*pA + (Rb+Vb)*pB is the effective TVL (real + virtual
    reserves valued in USD).

    The square-root term is evaluated with a guarded ("double where")
    form: for non-positive ``effective_value_usd`` it contributes exactly
    zero with a zero gradient, instead of NaN (negative input) or an
    infinite derivative (zero input). Positive inputs are unaffected, and
    NaN inputs still propagate as NaN.

    Parameters
    ----------
    effective_value_usd : float
        Effective TVL in USD: (Ra+Va)*pA + (Rb+Vb)*pB.
    gamma : float
        Fee parameter (1 - fee_rate).
    volatility : float
        Annualised daily realised volatility of the price ratio.
    arb_volume_this_period : float
        Arb volume already accounted for this time step (USD).
    noise_params : dict, optional
        Regression coefficients. Keys (defaults): a_0_base (0.5),
        a_f (0.0), a_sigma (2.0), a_c (1.0), base_fee (0.003).

    Returns
    -------
    float
        Per-minute noise volume (USD), floored at zero.
    """
    if noise_params is None:
        noise_params = {}
    a_0_base = noise_params.get("a_0_base", 0.5)
    a_f = noise_params.get("a_f", 0.0)
    a_sigma = noise_params.get("a_sigma", 2.0)
    a_c = noise_params.get("a_c", 1.0)
    base_fee = noise_params.get("base_fee", 0.003)

    fee = 1.0 - gamma
    a_0 = a_0_base + base_fee * a_f

    # Guarded sqrt: substitute a harmless argument where the input is
    # non-positive so neither the value nor the gradient becomes NaN/inf.
    # The comparison is written as `<= 0` so NaN inputs are not masked.
    scaled_tvl = effective_value_usd / 1e6
    non_positive = scaled_tvl <= 0.0
    safe_tvl = jnp.where(non_positive, 1.0, scaled_tvl)
    sqrt_tvl = jnp.where(non_positive, 0.0, jnp.sqrt(safe_tvl))

    daily_vol = (
        a_0 - a_f * fee
        + a_sigma * volatility
        + a_c * sqrt_tvl
    ) * 1e6
    return jnp.maximum(0.0, daily_vol / 1440.0 - arb_volume_this_period)
@jit
def reclamm_tsoukalas_log_noise_volume(
    effective_value_usd,
    gamma,
    volatility,
    arb_volume_this_period,
    noise_params=None,
):
    """reClAMM Tsoukalas log model: ``log(c_eff/1e6)`` as the TVL regressor.

    Computes
        V_daily = (a_0 - a_f*fee + a_sigma*sigma
                   + a_c*log(max(c_eff/1e6, 1e-30))) * 1e6
        V_noise = max(0, V_daily/1440 - arb_volume_this_period)

    with ``fee = 1 - gamma`` and ``a_0 = a_0_base + base_fee * a_f``.

    Parameters
    ----------
    effective_value_usd : float
        Effective TVL in USD: (Ra+Va)*pA + (Rb+Vb)*pB.
    gamma : float
        Fee parameter (1 - fee_rate).
    volatility : float
        Annualised daily realised volatility of the price ratio.
    arb_volume_this_period : float
        Arb volume already accounted for this time step (USD).
    noise_params : dict, optional
        Regression coefficients. Keys (defaults): a_0_base (0.5),
        a_f (0.0), a_sigma (2.0), a_c (1.0), base_fee (0.003).

    Returns
    -------
    float
        Per-minute noise volume (USD), floored at zero.
    """
    coeffs = {} if noise_params is None else noise_params
    intercept_base = coeffs.get("a_0_base", 0.5)
    fee_coeff = coeffs.get("a_f", 0.0)
    vol_coeff = coeffs.get("a_sigma", 2.0)
    tvl_coeff = coeffs.get("a_c", 1.0)
    reference_fee = coeffs.get("base_fee", 0.003)

    pool_fee = 1.0 - gamma
    intercept = intercept_base + reference_fee * fee_coeff

    # The floor keeps the logarithm finite for zero or negative TVL.
    log_tvl = jnp.log(jnp.maximum(effective_value_usd / 1e6, 1e-30))

    daily_volume = (
        intercept - fee_coeff * pool_fee
        + vol_coeff * volatility
        + tvl_coeff * log_tvl
    ) * 1e6
    per_minute = daily_volume / 1440.0
    return jnp.maximum(0.0, per_minute - arb_volume_this_period)
@jit
def reclamm_loglinear_noise_volume(
    effective_value_usd,
    gamma,
    volatility,
    arb_volume_this_period,
    noise_params=None,
):
    """Loglinear noise volume from hierarchical cross-pool calibration.

    Computes
        log(V_daily) = b_0 + b_sigma * volatility
                       + b_c * log(max(TVL, 1))
        V_noise = max(0, exp(log(V_daily)) / 1440 - arb_volume)

    ``gamma`` is accepted so the signature matches the other noise volume
    functions, but it is not read by this function.

    Parameters
    ----------
    effective_value_usd : float
        Effective TVL in USD: (Ra+Va)*pA + (Rb+Vb)*pB. Values below 1 USD
        are clamped to 1 before taking the logarithm.
    gamma : float
        Fee parameter (1 - fee_rate). Unused — kept for a uniform calling
        convention across noise models.
    volatility : float
        Annualised daily realised volatility of the price ratio.
    arb_volume_this_period : float
        Arb volume already accounted for this time step (USD).
    noise_params : dict, optional
        Model coefficients. Keys (defaults): b_0 (-6.7),
        b_sigma (-0.0007), b_c (1.04).

    Returns
    -------
    float
        Per-minute noise volume (USD), floored at zero.
    """
    coeffs = {} if noise_params is None else noise_params
    intercept = coeffs.get("b_0", -6.7)
    vol_slope = coeffs.get("b_sigma", -0.0007)
    tvl_slope = coeffs.get("b_c", 1.04)

    log_tvl = jnp.log(jnp.maximum(effective_value_usd, 1.0))
    log_daily_volume = (
        intercept
        + vol_slope * volatility
        + tvl_slope * log_tvl
    )
    per_minute = jnp.exp(log_daily_volume) / 1440.0
    return jnp.maximum(0.0, per_minute - arb_volume_this_period)
@pytest.fixture()
def synthetic_panel() -> pd.DataFrame:
    """Panel of 3 pools x 10 days with known structure.

    Pool A: MAINNET, WETH/USDC, tier_A=0, tier_B=0, fee=0.003
    Pool B: ARBITRUM, BAL/WETH, tier_A=0, tier_B=1, fee=0.01
    Pool C: BASE, RATS/WETH, tier_A=0, tier_B=2, fee=0.005

    Dates run from 2026-01-01 (a Thursday) to 2026-01-10, so 01-03, 01-04
    and 01-10 carry the weekend flag. The first day of each pool is
    dropped once the one-day TVL lag is taken, leaving 27 observations.
    """
    np.random.seed(42)

    pool_specs = (
        ("pool_A", "MAINNET", "WETH,USDC", 0.003, 0, 0),
        ("pool_B", "ARBITRUM", "BAL,WETH", 0.01, 0, 1),
        ("pool_C", "BASE", "RATS,WETH", 0.005, 0, 2),
    )
    first_day = date(2026, 1, 1)
    days = [first_day + timedelta(days=offset) for offset in range(10)]

    rows = []
    for pool_id, chain, tokens, fee, tier_a, tier_b in pool_specs:
        # One random TVL level per pool, then daily noise around it.
        base_log_tvl = 14.0 + np.random.randn() * 0.5
        for day in days:
            log_tvl = base_log_tvl + np.random.randn() * 0.1
            log_volume = log_tvl - 2.0 + np.random.randn() * 0.3
            volatility = 0.3 + np.random.rand() * 0.2
            rows.append({
                "pool_id": pool_id,
                "chain": chain,
                "date": day,
                "log_volume": log_volume,
                "log_tvl": log_tvl,
                "volatility": volatility,
                "weekend": 1.0 if day.weekday() >= 5 else 0.0,
                "log_fee": np.log(max(fee, 1e-6)),
                "swap_fee": fee,
                "tier_A": tier_a,
                "tier_B": tier_b,
                "tokens": tokens,
            })

    panel = (
        pd.DataFrame(rows)
        .sort_values(["pool_id", "date"])
        .reset_index(drop=True)
    )
    panel["log_tvl_lag1"] = panel.groupby("pool_id")["log_tvl"].shift(1)
    panel = panel.dropna(subset=["log_tvl_lag1"]).reset_index(drop=True)

    # Structural model covariates
    def _weekday(d):
        return d.weekday() if hasattr(d, "weekday") else pd.Timestamp(d).weekday()

    panel["log_sigma"] = np.log(np.maximum(panel["volatility"].values, 1e-6))
    dow = panel["date"].apply(_weekday)
    angle = 2.0 * np.pi * dow / 7.0
    panel["dow_sin"] = np.sin(angle)
    panel["dow_cos"] = np.cos(angle)
    panel["tvl_x_sigma"] = panel["log_tvl_lag1"] * panel["log_sigma"]
    panel["tvl_x_fee"] = panel["log_tvl_lag1"] * panel["log_fee"]
    panel["sigma_x_fee"] = panel["log_sigma"] * panel["log_fee"]

    return panel
@pytest.fixture()
def synthetic_samples(synthetic_encoded_data):
    """Deterministic posterior-like samples with eta=0 and L_Omega=I.

    ``B`` is drawn from a seeded RNG; ``eta`` is all zeros and every
    ``L_Omega`` draw is the identity, which is intended to make
    theta = X_pool @ B^T hold exactly.
    """
    n_pools = synthetic_encoded_data["N_pools"]
    n_cov = synthetic_encoded_data["K_cov"]
    n_coeff = 4
    n_draws = 10

    np.random.seed(99)
    return {
        "B": np.random.randn(n_draws, n_coeff, n_cov) * 0.5,
        "sigma_theta": np.ones((n_draws, n_coeff)),
        "L_Omega": np.tile(np.eye(n_coeff), (n_draws, 1, 1)),
        "eta": np.zeros((n_draws, n_pools, n_coeff)),
        "df": np.full((n_draws,), 5.0),
        "sigma_eps": np.tile([0.5, 0.8, 0.6], (n_draws, 1)),
    }
@pytest.fixture()
def synthetic_pools_df() -> pd.DataFrame:
    """Three WEIGHTED pools, one per chain, as a pool-listing DataFrame."""
    pool_specs = (
        ("pool_A", "MAINNET", ["WETH", "USDC"], ["0xweth", "0xusdc"],
         0.003, 1_000_000),
        ("pool_B", "ARBITRUM", ["BAL", "WETH"], ["0xbal", "0xweth"],
         0.01, 500_000),
        ("pool_C", "BASE", ["RATS", "WETH"], ["0xrats", "0xweth"],
         0.005, 100_000),
    )
    rows = [
        {
            "pool_id": pool_id,
            "chain": chain,
            "pool_type": "WEIGHTED",
            "tokens": tokens,
            "token_addresses": addresses,
            "swap_fee": fee,
            "current_tvl": tvl,
        }
        for pool_id, chain, tokens, addresses, fee, tvl in pool_specs
    ]
    return pd.DataFrame(rows)
@pytest.fixture()
def synthetic_ibp_samples(synthetic_encoded_data):
    """Deterministic IBP posterior-like samples (marginalized model).

    Holds B, W, v_ibp, alpha_ibp, sigma_w, df and a scalar-per-draw
    sigma_eps. There is no z_logit, eta, L_Omega or sigma_theta entry.
    """
    n_cov = synthetic_encoded_data["K_cov"]
    n_coeff = 4
    n_features = 6
    n_draws = 10

    np.random.seed(99)
    # Random draws are made in this fixed order: B, W, v_ibp.
    coeff_matrix = np.random.randn(n_draws, n_coeff, n_cov) * 0.5
    feature_weights = np.random.randn(n_draws, n_features, n_coeff) * 0.3
    stick_fractions = np.random.beta(2, 1, size=(n_draws, n_features))

    return {
        "B": coeff_matrix,
        "W": feature_weights,
        "v_ibp": stick_fractions,
        "alpha_ibp": np.full((n_draws,), 2.0),
        "sigma_w": np.full((n_draws,), 1.0),
        "df": np.full((n_draws,), 5.0),
        "sigma_eps": np.full((n_draws,), 0.5),
    }
@pytest.fixture()
def synthetic_ibp_dp_samples(synthetic_encoded_data):
    """Deterministic hybrid IBP+DP posterior-like samples.

    Holds the IBP entries (B, W, v_ibp, alpha_ibp, sigma_w) together with
    the DP entries (v, alpha_dp, and sigma_eps with one value per
    cluster). There is no z_logit, eta, L_Omega or sigma_theta entry.
    """
    n_cov = synthetic_encoded_data["K_cov"]
    n_coeff = 4
    n_features = 6
    n_clusters = 6
    n_draws = 10

    np.random.seed(99)
    # Random draws are made in this fixed order: B, W, v_ibp, v, sigma_eps.
    coeff_matrix = np.random.randn(n_draws, n_coeff, n_cov) * 0.5
    feature_weights = np.random.randn(n_draws, n_features, n_coeff) * 0.3
    ibp_sticks = np.random.beta(2, 1, size=(n_draws, n_features))
    dp_sticks = np.random.beta(1, 2, size=(n_draws, n_clusters - 1))
    cluster_noise = np.abs(np.random.randn(n_draws, n_clusters)) + 0.1

    return {
        "B": coeff_matrix,
        "W": feature_weights,
        "v_ibp": ibp_sticks,
        "alpha_ibp": np.full((n_draws,), 2.0),
        "sigma_w": np.full((n_draws,), 1.0),
        "v": dp_sticks,
        "alpha_dp": np.full((n_draws,), 1.0),
        "df": np.full((n_draws,), 5.0),
        "sigma_eps": cluster_noise,
    }
@pytest.fixture()
def synthetic_snapshots_df() -> pd.DataFrame:
    """Daily volume/liquidity snapshots for 3 pools over 10 days.

    Values are log-normal draws from a seeded RNG; for every row the
    volume is drawn before the liquidity.
    """
    first_day = date(2026, 1, 1)
    days = [first_day + timedelta(days=offset) for offset in range(10)]
    pool_chains = (
        ("pool_A", "MAINNET"),
        ("pool_B", "ARBITRUM"),
        ("pool_C", "BASE"),
    )

    np.random.seed(42)
    rows = [
        {
            "pool_id": pool_id,
            "chain": chain,
            "date": day,
            "volume_usd": np.exp(10.0 + np.random.randn() * 0.5),
            "total_liquidity_usd": np.exp(14.0 + np.random.randn() * 0.3),
        }
        for pool_id, chain in pool_chains
        for day in days
    ]
    return pd.DataFrame(rows)
class TestEncodeCovariates:
    """Checks on the dictionary returned by ``encode_covariates``."""

    def test_x_pool_shape(self, synthetic_panel):
        encoded = encode_covariates(synthetic_panel)
        expected_shape = (encoded["N_pools"], encoded["K_cov"])
        assert encoded["X_pool"].shape == expected_shape

    def test_x_obs_shape(self, synthetic_panel):
        encoded = encode_covariates(synthetic_panel)
        assert encoded["x_obs"].shape == (len(synthetic_panel), 4)

    def test_x_obs_column_0_is_intercept(self, synthetic_panel):
        x_obs = encode_covariates(synthetic_panel)["x_obs"]
        np.testing.assert_array_equal(x_obs[:, 0], 1.0)

    def test_x_obs_column_1_is_lagged_tvl(self, synthetic_panel):
        x_obs = encode_covariates(synthetic_panel)["x_obs"]
        np.testing.assert_array_equal(
            x_obs[:, 1], synthetic_panel["log_tvl_lag1"].values,
        )

    def test_x_obs_column_2_is_volatility(self, synthetic_panel):
        x_obs = encode_covariates(synthetic_panel)["x_obs"]
        np.testing.assert_array_equal(
            x_obs[:, 2], synthetic_panel["volatility"].values,
        )

    def test_x_obs_column_3_is_weekend(self, synthetic_panel):
        x_obs = encode_covariates(synthetic_panel)["x_obs"]
        np.testing.assert_array_equal(
            x_obs[:, 3], synthetic_panel["weekend"].values,
        )

    def test_intercept_column_all_ones(self, synthetic_panel):
        X = encode_covariates(synthetic_panel)["X_pool"]
        np.testing.assert_array_equal(X[:, 0], 1.0)

    def test_chain_dummies_one_hot(self, synthetic_panel):
        encoded = encode_covariates(synthetic_panel)
        X = encoded["X_pool"]
        chain_cols = [
            pos for pos, name in enumerate(encoded["covariate_names"])
            if name.startswith("chain_")
        ]
        # At most one chain dummy may be set per pool.
        for pool_row in X:
            assert pool_row[chain_cols].sum() <= 1.0

    def test_reference_chain_is_alphabetically_first(self, synthetic_panel):
        encoded = encode_covariates(synthetic_panel)
        ref_chain = sorted(synthetic_panel["chain"].unique())[0]  # ARBITRUM

        meta = encoded["pool_meta"]
        ref_idx = meta[meta["chain"] == ref_chain].index[0]
        chain_cols = [
            pos for pos, name in enumerate(encoded["covariate_names"])
            if name.startswith("chain_")
        ]
        X = encoded["X_pool"]
        assert all(X[ref_idx, col] == 0.0 for col in chain_cols)

    def test_tier_a_dummies_match(self, synthetic_panel):
        encoded = encode_covariates(synthetic_panel)
        X = encoded["X_pool"]
        tier_a_cols = [
            (pos, name) for pos, name in enumerate(encoded["covariate_names"])
            if name.startswith("tier_A_")
        ]

        for idx, row in encoded["pool_meta"].iterrows():
            tier_a_str = str(row["tier_A"])
            for col_idx, col_name in tier_a_cols:
                # The dummy's tier is the suffix after the last underscore.
                is_own_tier = tier_a_str == col_name.split("_")[-1]
                expected = 1.0 if is_own_tier else 0.0
                assert X[idx, col_idx] == expected, (
                    f"Pool {idx} tier_A={tier_a_str}, col {col_name}: "
                    f"expected {expected}, got {X[idx, col_idx]}"
                )

    def test_tier_b_dummies_match(self, synthetic_panel):
        encoded = encode_covariates(synthetic_panel)
        X = encoded["X_pool"]
        tier_b_cols = [
            (pos, name) for pos, name in enumerate(encoded["covariate_names"])
            if name.startswith("tier_B_")
        ]

        for idx, row in encoded["pool_meta"].iterrows():
            tier_b_str = str(row["tier_B"])
            for col_idx, col_name in tier_b_cols:
                is_own_tier = tier_b_str == col_name.split("_")[-1]
                expected = 1.0 if is_own_tier else 0.0
                assert X[idx, col_idx] == expected

    def test_log_fee_column(self, synthetic_panel):
        encoded = encode_covariates(synthetic_panel)
        fee_col = encoded["covariate_names"].index("log_fee")
        X = encoded["X_pool"]

        for idx, row in encoded["pool_meta"].iterrows():
            np.testing.assert_allclose(
                X[idx, fee_col], np.log(max(row["swap_fee"], 1e-6)),
                rtol=1e-10,
            )

    def test_pool_idx_maps_observations(self, synthetic_panel):
        encoded = encode_covariates(synthetic_panel)
        pool_ids = encoded["pool_ids"]
        pool_idx = encoded["pool_idx"]

        assert pool_idx.min() >= 0
        assert pool_idx.max() < len(pool_ids)

        panel_pool_ids = synthetic_panel["pool_id"]
        for position, pid in enumerate(pool_ids):
            rows_of_pool = np.where((panel_pool_ids == pid).values)[0]
            assert (pool_idx[rows_of_pool] == position).all()

    def test_y_obs_matches_panel(self, synthetic_panel):
        encoded = encode_covariates(synthetic_panel)
        np.testing.assert_array_equal(
            encoded["y_obs"], synthetic_panel["log_volume"].values,
        )

    def test_covariate_names_length(self, synthetic_panel):
        encoded = encode_covariates(synthetic_panel)
        assert len(encoded["covariate_names"]) == encoded["K_cov"]

    def test_covariate_column_ordering(self, synthetic_panel):
        """X_pool columns must follow: intercept, chain dummies, tier_A
        dummies, tier_B dummies, log_fee. This ordering is load-bearing
        because B is indexed by column position."""
        names = encode_covariates(synthetic_panel)["covariate_names"]

        assert names[0] == "intercept"
        assert names[-1] == "log_fee"

        def first_with_prefix(prefix):
            # Position of the first column with this prefix, or None.
            return next(
                (pos for pos, n in enumerate(names) if n.startswith(prefix)),
                None,
            )

        chain_start = first_with_prefix("chain_")
        tier_a_start = first_with_prefix("tier_A_")
        tier_b_start = first_with_prefix("tier_B_")
        fee_idx = len(names) - 1

        # Verify ordering: intercept < chains < tier_A < tier_B < log_fee
        if chain_start is not None:
            assert chain_start > 0  # after intercept
        if tier_a_start is not None and chain_start is not None:
            assert tier_a_start > chain_start
        if tier_b_start is not None and tier_a_start is not None:
            assert tier_b_start > tier_a_start
        if tier_b_start is not None:
            assert fee_idx > tier_b_start

        # All chain dummies are contiguous
        chain_names = [n for n in names if n.startswith("chain_")]
        if chain_names:
            chain_indices = [names.index(n) for n in chain_names]
            contiguous = list(
                range(min(chain_indices), max(chain_indices) + 1)
            )
            assert chain_indices == contiguous

    def test_output_dict_has_all_required_keys(self, synthetic_panel):
        """encode_covariates must return all keys consumed by downstream
        functions (predict_new_pool, generate_output_json,
        _save_sample_cache)."""
        data = encode_covariates(synthetic_panel)
        required_keys = {
            "pool_idx", "X_pool", "x_obs", "y_obs", "pool_ids", "pool_meta",
            "covariate_names", "tier_A_per_pool", "N_pools", "K_cov",
            "ref_chain", "ref_tier_a", "ref_tier_b", "chains",
        }
        assert required_keys.issubset(data.keys()), (
            f"Missing keys: {required_keys - data.keys()}"
        )
class TestEncodeCovariatesNoTiers:
    """Tests for encode_covariates(include_tiers=False)."""

    def test_no_tier_columns_in_covariate_names(self, synthetic_panel):
        encoded = encode_covariates(synthetic_panel, include_tiers=False)
        for name in encoded["covariate_names"]:
            assert not name.startswith("tier_A_"), (
                f"Found tier_A column {name} with include_tiers=False"
            )
            assert not name.startswith("tier_B_"), (
                f"Found tier_B column {name} with include_tiers=False"
            )

    def test_still_has_intercept_and_log_fee(self, synthetic_panel):
        names = encode_covariates(
            synthetic_panel, include_tiers=False,
        )["covariate_names"]
        assert "intercept" in names
        assert "log_fee" in names

    def test_still_has_chain_dummies(self, synthetic_panel):
        names = encode_covariates(
            synthetic_panel, include_tiers=False,
        )["covariate_names"]
        chain_cols = [n for n in names if n.startswith("chain_")]
        assert len(chain_cols) > 0

    def test_k_cov_smaller_than_with_tiers(self, synthetic_panel):
        with_tiers = encode_covariates(synthetic_panel, include_tiers=True)
        without_tiers = encode_covariates(synthetic_panel, include_tiers=False)
        assert without_tiers["K_cov"] < with_tiers["K_cov"]

    def test_default_is_include_tiers_true(self, synthetic_panel):
        implicit = encode_covariates(synthetic_panel)
        explicit = encode_covariates(synthetic_panel, include_tiers=True)
        assert implicit["K_cov"] == explicit["K_cov"]

    def test_tier_A_per_pool_still_present(self, synthetic_panel):
        """tier_A_per_pool is still returned (for other downstream uses)."""
        encoded = encode_covariates(synthetic_panel, include_tiers=False)
        assert "tier_A_per_pool" in encoded

    def test_x_pool_shape_matches_k_cov(self, synthetic_panel):
        encoded = encode_covariates(synthetic_panel, include_tiers=False)
        expected_shape = (encoded["N_pools"], encoded["K_cov"])
        assert encoded["X_pool"].shape == expected_shape
class TestEncodeStructuralCovariates:
    """Tests for encode_covariates_structural()."""

    @pytest.fixture()
    def struct_data(self, synthetic_panel):
        # Imported lazily so collection does not require the submodule.
        from quantammsim.noise_calibration.covariate_encoding import (
            encode_covariates_structural,
        )
        return encode_covariates_structural(synthetic_panel)

    def test_encode_structural_x_obs_shape(self, struct_data, synthetic_panel):
        N_obs = len(synthetic_panel)
        assert struct_data["x_obs"].shape == (N_obs, K_OBS_COEFF)

    def test_encode_structural_x_obs_columns(self, struct_data, synthetic_panel):
        """Columns must match OBS_COEFF_NAMES ordering:
        [1, lag_log_tvl, log_sigma, tvl_x_sigma, tvl_x_fee, sigma_x_fee,
         dow_sin, dow_cos]."""
        x = struct_data["x_obs"]
        np.testing.assert_array_equal(x[:, 0], 1.0)  # intercept
        np.testing.assert_array_equal(
            x[:, 1], synthetic_panel["log_tvl_lag1"].values,
        )
        np.testing.assert_array_equal(
            x[:, 2], synthetic_panel["log_sigma"].values,
        )
        np.testing.assert_allclose(
            x[:, 3], synthetic_panel["tvl_x_sigma"].values, rtol=1e-12,
        )
        np.testing.assert_allclose(
            x[:, 4], synthetic_panel["tvl_x_fee"].values, rtol=1e-12,
        )
        np.testing.assert_allclose(
            x[:, 5], synthetic_panel["sigma_x_fee"].values, rtol=1e-12,
        )
        np.testing.assert_allclose(
            x[:, 6], synthetic_panel["dow_sin"].values, rtol=1e-12,
        )
        np.testing.assert_allclose(
            x[:, 7], synthetic_panel["dow_cos"].values, rtol=1e-12,
        )

    def test_encode_structural_has_sigma_daily(self, struct_data):
        """sigma_daily = volatility / sqrt(365), de-annualised.

        Only presence and non-emptiness are asserted here.
        """
        assert "sigma_daily" in struct_data
        assert len(struct_data["sigma_daily"]) > 0

    def test_encode_structural_has_gas(self, struct_data):
        assert "gas" in struct_data
        assert len(struct_data["gas"]) > 0
        assert (struct_data["gas"] >= 0).all()

    def test_encode_structural_has_chain_idx_tier_idx(self, struct_data):
        assert "chain_idx" in struct_data
        assert "tier_idx" in struct_data
        assert struct_data["chain_idx"].dtype in (np.int32, np.int64)
        assert struct_data["tier_idx"].dtype in (np.int32, np.int64)

    def test_encode_structural_has_fee(self, struct_data):
        """fee array is raw (not log), for the formula."""
        assert "fee" in struct_data
        assert (struct_data["fee"] > 0).all()
        assert (struct_data["fee"] < 1).all()  # fees are fractions

    def test_encode_structural_tier_idx_is_pair(self, struct_data):
        """tier_idx encodes the (tier_A, tier_B) PAIR, not individual tokens.
        (0,0)->0, (0,1)->1, (0,2)->2, (1,1)->3, (1,2)->4, (2,2)->5."""
        tier_idx = struct_data["tier_idx"]
        pool_meta = struct_data["pool_meta"]

        for i, row in pool_meta.iterrows():
            a, b = int(row["tier_A"]), int(row["tier_B"])
            # Row-major index into the upper triangle of a 3x3 tier grid
            # (a <= b): rows before `a` hold 3 + 2 + ... entries, i.e.
            # a * (2*3 - a + 1) // 2 = a * (7 - a) // 2. The earlier
            # `a * (5 - a) // 2` mapped both (1,2) and (2,2) to 3,
            # contradicting the mapping documented above.
            expected = a * (7 - a) // 2 + b - a
            assert tier_idx[i] == expected, (
                f"Pool {i}: tier ({a},{b}) expected idx {expected}, "
                f"got {tier_idx[i]}"
            )

    def test_encode_structural_n_chains_n_tiers(self, struct_data):
        """n_chains and n_tiers computed from data."""
        assert "n_chains" in struct_data
        assert "n_tiers" in struct_data
        assert struct_data["n_chains"] >= 1
        assert struct_data["n_tiers"] >= 1
test_formula_arb_quadratic_in_sigma(self): + """V_arb(2σ) / V_arb(σ) ≈ 4 for small gas (correction ≈ 1).""" + import jax.numpy as jnp + sigma = jnp.float64(0.01) + tvl = jnp.float64(1e8) # large TVL so gas is negligible + fee = jnp.float64(0.003) + gas = jnp.float64(0.001) # tiny gas + cadence = jnp.float64(0.01) # very fast arb + + v1 = float(self.formula(sigma, tvl, fee, gas, cadence)) + v2 = float(self.formula(2.0 * sigma, tvl, fee, gas, cadence)) + assert v1 > 0 + ratio = v2 / v1 + assert ratio == pytest.approx(4.0, rel=0.1) + + def test_formula_arb_linear_in_tvl(self): + """V_arb(2V) / V_arb(V) ≈ 2 for small gas.""" + import jax.numpy as jnp + sigma = jnp.float64(0.02) + tvl = jnp.float64(1e8) + fee = jnp.float64(0.003) + gas = jnp.float64(0.0001) + cadence = jnp.float64(0.01) + + v1 = float(self.formula(sigma, tvl, fee, gas, cadence)) + v2 = float(self.formula(sigma, 2.0 * tvl, fee, gas, cadence)) + assert v1 > 0 + ratio = v2 / v1 + assert ratio == pytest.approx(2.0, rel=0.1) + + def test_formula_arb_gas_kills_small_pools(self): + """High gas, small TVL → V_arb ≈ 0.""" + import jax.numpy as jnp + result = self.formula( + sigma_daily=jnp.float64(0.02), + tvl=jnp.float64(1000.0), + fee=jnp.float64(0.003), + gas_usd=jnp.float64(1000.0), + cadence_minutes=jnp.float64(1.0), + ) + assert float(result) == pytest.approx(0.0, abs=1e-6) + + def test_formula_arb_matches_numpy_reference(self): + """Compare to the numpy formula in plot_formula_arb_vs_real.py.""" + import jax.numpy as jnp + + # Reference implementation (from plot_formula_arb_vs_real.py:58) + def ref(sigma_daily, tvl, fee, block_time_s, gas_usd): + if tvl <= 0 or fee <= 0 or sigma_daily <= 0: + return 0.0 + gamma = fee + delta = 2.0 * np.sqrt(2.0 * gas_usd / tvl) if gas_usd > 0 else 0.0 + bLVR = sigma_daily**2 * tvl / 8.0 + sqrt_s2_2l = sigma_daily * np.sqrt(block_time_s / (2.0 * 86400.0)) + bFEE = bLVR * max( + 1.0 - delta / (2.0 * gamma) - sqrt_s2_2l / (gamma + delta / 2.0), + 0.0, + ) + return bFEE / 
gamma + + test_cases = [ + (0.03, 1e6, 0.003, 1.0, 1.0), + (0.05, 5e5, 0.01, 0.5, 0.005), + (0.01, 1e7, 0.005, 2.0, 0.01), + (0.1, 1e4, 0.03, 10.0, 5.0), + (0.02, 1e5, 0.001, 1.0, 0.001), + ] + + for sigma, tvl, fee, cadence_min, gas in test_cases: + block_time_s = cadence_min * 60.0 + expected = ref(sigma, tvl, fee, block_time_s, gas) + actual = float(self.formula( + jnp.float64(sigma), jnp.float64(tvl), + jnp.float64(fee), jnp.float64(gas), + jnp.float64(cadence_min), + )) + np.testing.assert_allclose( + actual, expected, rtol=1e-10, + err_msg=f"Mismatch for sigma={sigma}, tvl={tvl}, " + f"fee={fee}, cadence={cadence_min}, gas={gas}", + ) + + def test_formula_arb_is_jax_differentiable(self): + """jax.grad w.r.t. sigma should run without error.""" + import jax + import jax.numpy as jnp + + grad_fn = jax.grad(self.formula, argnums=0) + result = grad_fn( + jnp.float64(0.03), jnp.float64(1e6), + jnp.float64(0.003), jnp.float64(1.0), + jnp.float64(1.0), + ) + assert np.isfinite(float(result)) + + def test_formula_arb_cadence_reduces_volume(self): + """Higher cadence → less frequent arb → lower volume.""" + import jax.numpy as jnp + sigma = jnp.float64(0.03) + tvl = jnp.float64(1e6) + fee = jnp.float64(0.003) + gas = jnp.float64(1.0) + + v_fast = float(self.formula(sigma, tvl, fee, gas, jnp.float64(1.0))) + v_slow = float(self.formula(sigma, tvl, fee, gas, jnp.float64(10.0))) + assert v_fast > v_slow diff --git a/tests/noise/test_model_and_inference.py b/tests/noise/test_model_and_inference.py new file mode 100644 index 0000000..ee1fda5 --- /dev/null +++ b/tests/noise/test_model_and_inference.py @@ -0,0 +1,328 @@ +"""Tests for noise_model, _get_theta_samples, _build_model_kwargs, and SVI smoke.""" + +import numpy as np +import pytest + +from quantammsim.noise_calibration import ( + noise_model, + _get_theta_samples, + _build_model_kwargs, + K_COEFF, +) + + +# =========================================================================== +# TestNoiseModelDefinition +# 
=========================================================================== + + +class TestNoiseModelDefinition: + def test_model_traces_without_error(self, synthetic_encoded_data): + import jax + import jax.numpy as jnp + import numpyro + import numpyro.handlers as handlers + + data = synthetic_encoded_data + kwargs = _build_model_kwargs(data) + rng_key = jax.random.PRNGKey(0) + + trace = handlers.trace(handlers.seed(noise_model, rng_key)).get_trace( + **kwargs + ) + assert trace is not None + + def test_required_sites_present(self, synthetic_encoded_data): + import jax + import numpyro.handlers as handlers + + data = synthetic_encoded_data + kwargs = _build_model_kwargs(data) + rng_key = jax.random.PRNGKey(0) + + trace = handlers.trace(handlers.seed(noise_model, rng_key)).get_trace( + **kwargs + ) + required = {"B", "sigma_theta", "L_Omega", "df", "sigma_eps", "eta", "y"} + assert required.issubset(trace.keys()) + + def test_site_shapes(self, synthetic_encoded_data): + import jax + import numpyro.handlers as handlers + + data = synthetic_encoded_data + kwargs = _build_model_kwargs(data) + rng_key = jax.random.PRNGKey(0) + + trace = handlers.trace(handlers.seed(noise_model, rng_key)).get_trace( + **kwargs + ) + N_pools = data["N_pools"] + K_cov = data["K_cov"] + + assert trace["B"]["value"].shape == (K_COEFF, K_cov) + assert trace["sigma_theta"]["value"].shape == (K_COEFF,) + assert trace["L_Omega"]["value"].shape == (K_COEFF, K_COEFF) + assert trace["df"]["value"].shape == () + assert trace["sigma_eps"]["value"].shape == (3,) + assert trace["eta"]["value"].shape == (N_pools, K_COEFF) + + def test_theta_deterministic_site(self, synthetic_encoded_data): + import jax + import numpyro.handlers as handlers + + data = synthetic_encoded_data + kwargs = _build_model_kwargs(data) + rng_key = jax.random.PRNGKey(0) + + trace = handlers.trace(handlers.seed(noise_model, rng_key)).get_trace( + **kwargs + ) + assert "theta" in trace + assert trace["theta"]["value"].shape == 
(data["N_pools"], K_COEFF) + + def test_prior_predictive_produces_y(self, synthetic_encoded_data): + import jax + from numpyro.infer import Predictive + + data = synthetic_encoded_data + kwargs = _build_model_kwargs(data) + kwargs["y_obs"] = None + + predictive = Predictive(noise_model, num_samples=5) + rng_key = jax.random.PRNGKey(42) + samples = predictive(rng_key, **kwargs) + assert "y" in samples + assert samples["y"].shape[0] == 5 + + +# =========================================================================== +# TestGetThetaSamples +# =========================================================================== + + +class TestGetThetaSamples: + def test_returns_theta_directly_when_present(self, synthetic_encoded_data): + data = synthetic_encoded_data + theta_direct = np.random.randn(10, data["N_pools"], K_COEFF) + sample_dict = {"theta": theta_direct} + result = _get_theta_samples(sample_dict, data["X_pool"]) + np.testing.assert_array_equal(result, theta_direct) + + def test_reconstructs_from_non_centered( + self, synthetic_encoded_data, synthetic_samples + ): + data = synthetic_encoded_data + result = _get_theta_samples(synthetic_samples, data["X_pool"]) + assert result.shape == (10, data["N_pools"], K_COEFF) + + def test_eta_zero_identity_gives_mu( + self, synthetic_encoded_data, synthetic_samples + ): + """With eta=0 and L_Omega=I, theta = X_pool @ B^T.""" + data = synthetic_encoded_data + X_pool = data["X_pool"] + B = synthetic_samples["B"] + + result = _get_theta_samples(synthetic_samples, X_pool) + + # Expected: mu[s,p,j] = sum_d X_pool[p,d] * B[s,j,d] + expected = np.einsum("pd,sjd->spj", X_pool, B) + np.testing.assert_allclose(result, expected, atol=1e-12) + + def test_output_shape(self, synthetic_encoded_data, synthetic_samples): + data = synthetic_encoded_data + result = _get_theta_samples(synthetic_samples, data["X_pool"]) + S = synthetic_samples["B"].shape[0] + assert result.shape == (S, data["N_pools"], K_COEFF) + + def 
test_reconstructed_matches_direct(self, synthetic_encoded_data): + """When both theta and raw params are present, reconstruction matches.""" + data = synthetic_encoded_data + N_pools = data["N_pools"] + K_cov = data["K_cov"] + S = 8 + X_pool = data["X_pool"] + + np.random.seed(123) + B = np.random.randn(S, K_COEFF, K_cov) * 0.3 + sigma_theta = np.abs(np.random.randn(S, K_COEFF)) + 0.1 + # Random lower-triangular L + L_raw = np.zeros((S, K_COEFF, K_COEFF)) + for s in range(S): + A = np.random.randn(K_COEFF, K_COEFF) + L_raw[s] = np.linalg.cholesky(A @ A.T + np.eye(K_COEFF)) + eta = np.random.randn(S, N_pools, K_COEFF) + + sample_dict = { + "B": B, "sigma_theta": sigma_theta, + "L_Omega": L_raw, "eta": eta, + } + + theta_recon = _get_theta_samples(sample_dict, X_pool) + + # Compute expected directly + mu = np.einsum("pd,sjd->spj", X_pool, B) + L_Sigma = sigma_theta[:, :, None] * L_raw + offset = np.einsum("spi,sji->spj", eta, L_Sigma) + expected = mu + offset + + np.testing.assert_allclose(theta_recon, expected, atol=1e-10) + + +# =========================================================================== +# TestBuildModelKwargs +# =========================================================================== + + +class TestBuildModelKwargs: + def test_all_outputs_are_jnp(self, synthetic_encoded_data): + import jax.numpy as jnp + + kwargs = _build_model_kwargs(synthetic_encoded_data) + for key in ["pool_idx", "X_pool", "x_obs", "y_obs", "tier_A_per_pool"]: + assert isinstance(kwargs[key], jnp.ndarray), ( + f"{key} should be jnp array" + ) + + def test_shapes_preserved(self, synthetic_encoded_data): + data = synthetic_encoded_data + kwargs = _build_model_kwargs(data) + assert kwargs["pool_idx"].shape == data["pool_idx"].shape + assert kwargs["X_pool"].shape == data["X_pool"].shape + assert kwargs["x_obs"].shape == data["x_obs"].shape + assert kwargs["y_obs"].shape == data["y_obs"].shape + assert kwargs["N_pools"] == data["N_pools"] + assert kwargs["K_cov"] == 
data["K_cov"] + + def test_dp_model_kwargs_exclude_tier_A(self, synthetic_encoded_data): + """When model_fn is noise_model_dp_sigma, tier_A_per_pool is excluded + and K_clusters is included.""" + from quantammsim.noise_calibration.model import noise_model_dp_sigma + + data = dict(synthetic_encoded_data) + data["K_clusters"] = 6 + kwargs = _build_model_kwargs(data, model_fn=noise_model_dp_sigma) + assert "tier_A_per_pool" not in kwargs + assert kwargs["K_clusters"] == 6 + + def test_tier_model_kwargs_include_tier_A(self, synthetic_encoded_data): + """When model_fn is noise_model (default), tier_A_per_pool is included + and K_clusters is not.""" + kwargs = _build_model_kwargs(synthetic_encoded_data) + assert "tier_A_per_pool" in kwargs + assert "K_clusters" not in kwargs + + def test_default_model_fn_is_noise_model(self, synthetic_encoded_data): + """Calling without model_fn should behave identically to model_fn=noise_model.""" + kwargs_default = _build_model_kwargs(synthetic_encoded_data) + kwargs_explicit = _build_model_kwargs( + synthetic_encoded_data, model_fn=noise_model + ) + assert set(kwargs_default.keys()) == set(kwargs_explicit.keys()) + + +# =========================================================================== +# TestSVISmoke +# =========================================================================== + + +class TestSVISmoke: + @pytest.mark.slow + def test_svi_converges_small_data(self): + """SVI on tiny synthetic data: ELBO should decrease.""" + import jax + import numpyro + + numpyro.enable_x64() + + from quantammsim.noise_calibration import encode_covariates, run_svi + + # Build minimal panel: 5 pools × 20 days + np.random.seed(42) + from datetime import date, timedelta + + pools_spec = [ + ("p0", "MAINNET", "WETH,USDC", 0.003, 0, 0), + ("p1", "ARBITRUM", "BAL,WETH", 0.01, 0, 1), + ("p2", "BASE", "RATS,WETH", 0.005, 0, 2), + ("p3", "MAINNET", "LINK,WETH", 0.005, 0, 1), + ("p4", "ARBITRUM", "AAVE,USDC", 0.003, 0, 1), + ] + dates = [date(2026, 1, 
1) + timedelta(days=i) for i in range(21)] + records = [] + for pid, chain, tokens, fee, ta, tb in pools_spec: + base_tvl = 14 + np.random.randn() * 0.5 + for d in dates: + tvl = base_tvl + np.random.randn() * 0.1 + vol = tvl - 2 + np.random.randn() * 0.3 + records.append({ + "pool_id": pid, "chain": chain, "date": d, + "log_volume": vol, "log_tvl": tvl, + "volatility": 0.3 + np.random.rand() * 0.2, + "weekend": 1.0 if d.weekday() >= 5 else 0.0, + "log_fee": np.log(max(fee, 1e-6)), + "swap_fee": fee, "tier_A": ta, "tier_B": tb, + "tokens": tokens, + }) + + import pandas as pd + + panel = pd.DataFrame(records) + panel = panel.sort_values(["pool_id", "date"]).reset_index(drop=True) + panel["log_tvl_lag1"] = panel.groupby("pool_id")["log_tvl"].shift(1) + panel = panel.dropna(subset=["log_tvl_lag1"]).reset_index(drop=True) + + data = encode_covariates(panel) + samples, losses = run_svi(data, num_steps=2000, lr=1e-3, seed=0, + num_samples=50) + + # ELBO should decrease: mean of last 100 < mean of first 100 + assert np.mean(losses[-100:]) < np.mean(losses[:100]) + + @pytest.mark.slow + def test_svi_samples_have_required_keys(self): + """SVI output dict has all expected latent variable keys.""" + import jax + import numpyro + + numpyro.enable_x64() + + from quantammsim.noise_calibration import encode_covariates, run_svi + + np.random.seed(42) + from datetime import date, timedelta + import pandas as pd + + pools_spec = [ + ("p0", "MAINNET", "WETH,USDC", 0.003, 0, 0), + ("p1", "ARBITRUM", "BAL,WETH", 0.01, 0, 1), + ] + dates = [date(2026, 1, 1) + timedelta(days=i) for i in range(15)] + records = [] + for pid, chain, tokens, fee, ta, tb in pools_spec: + base_tvl = 14 + np.random.randn() * 0.5 + for d in dates: + tvl = base_tvl + np.random.randn() * 0.1 + vol = tvl - 2 + np.random.randn() * 0.3 + records.append({ + "pool_id": pid, "chain": chain, "date": d, + "log_volume": vol, "log_tvl": tvl, + "volatility": 0.3 + np.random.rand() * 0.2, + "weekend": 1.0 if d.weekday() >= 5 
else 0.0, + "log_fee": np.log(max(fee, 1e-6)), + "swap_fee": fee, "tier_A": ta, "tier_B": tb, + "tokens": tokens, + }) + + panel = pd.DataFrame(records) + panel = panel.sort_values(["pool_id", "date"]).reset_index(drop=True) + panel["log_tvl_lag1"] = panel.groupby("pool_id")["log_tvl"].shift(1) + panel = panel.dropna(subset=["log_tvl_lag1"]).reset_index(drop=True) + + data = encode_covariates(panel) + samples, _ = run_svi(data, num_steps=500, lr=1e-3, seed=0, + num_samples=10) + + required = {"B", "sigma_theta", "L_Omega", "eta", "df", "sigma_eps"} + assert required.issubset(samples.keys()) diff --git a/tests/noise/test_model_dp_sigma.py b/tests/noise/test_model_dp_sigma.py new file mode 100644 index 0000000..3d2be21 --- /dev/null +++ b/tests/noise/test_model_dp_sigma.py @@ -0,0 +1,305 @@ +"""Tests for stick_breaking_weights and noise_model_dp_sigma.""" + +import numpy as np +import pytest + +from quantammsim.noise_calibration import K_COEFF +from quantammsim.noise_calibration.constants import K_CLUSTERS_DEFAULT + + +# =========================================================================== +# TestStickBreakingWeights +# =========================================================================== + + +class TestStickBreakingWeights: + def test_sums_to_one(self): + from quantammsim.noise_calibration.model import stick_breaking_weights + import jax.numpy as jnp + + v = jnp.array([0.5, 0.3, 0.4, 0.6, 0.2]) + w = stick_breaking_weights(v) + np.testing.assert_allclose(float(jnp.sum(w)), 1.0, atol=1e-6) + + def test_correct_length(self): + from quantammsim.noise_calibration.model import stick_breaking_weights + import jax.numpy as jnp + + K = 7 + v = jnp.ones(K - 1) * 0.3 + w = stick_breaking_weights(v) + assert w.shape == (K,) + + def test_non_negative(self): + from quantammsim.noise_calibration.model import stick_breaking_weights + import jax.numpy as jnp + + v = jnp.array([0.1, 0.9, 0.5, 0.7, 0.3]) + w = stick_breaking_weights(v) + assert jnp.all(w >= 0.0) + + def 
test_first_weight_equals_first_v(self): + from quantammsim.noise_calibration.model import stick_breaking_weights + import jax.numpy as jnp + + v = jnp.array([0.7, 0.4, 0.2]) + w = stick_breaking_weights(v) + np.testing.assert_allclose(float(w[0]), 0.7, atol=1e-6) + + def test_jit_compatible(self): + from quantammsim.noise_calibration.model import stick_breaking_weights + import jax + import jax.numpy as jnp + + v = jnp.array([0.5, 0.3, 0.4]) + w_eager = stick_breaking_weights(v) + w_jit = jax.jit(stick_breaking_weights)(v) + np.testing.assert_allclose( + np.array(w_eager), np.array(w_jit), atol=1e-6 + ) + + def test_all_v_one_concentrates_on_first(self): + """If v = [1, 1, ...], all mass goes to first component.""" + from quantammsim.noise_calibration.model import stick_breaking_weights + import jax.numpy as jnp + + v = jnp.ones(5) + w = stick_breaking_weights(v) + np.testing.assert_allclose(float(w[0]), 1.0, atol=1e-6) + np.testing.assert_allclose(float(jnp.sum(w[1:])), 0.0, atol=1e-6) + + +# =========================================================================== +# TestDPModelDefinition +# =========================================================================== + + +class TestDPModelDefinition: + def _get_dp_model_kwargs(self, data, K_clusters=6): + """Build kwargs for noise_model_dp_sigma from encoded data.""" + import jax.numpy as jnp + + return dict( + pool_idx=jnp.array(data["pool_idx"]), + X_pool=jnp.array(data["X_pool"]), + x_obs=jnp.array(data["x_obs"]), + y_obs=jnp.array(data["y_obs"]), + N_pools=data["N_pools"], + K_coeff=K_COEFF, + K_cov=data["K_cov"], + K_clusters=K_clusters, + ) + + def test_model_traces_without_error(self, synthetic_encoded_data): + from quantammsim.noise_calibration.model import noise_model_dp_sigma + import jax + import numpyro.handlers as handlers + + kwargs = self._get_dp_model_kwargs(synthetic_encoded_data) + rng_key = jax.random.PRNGKey(0) + + trace = handlers.trace( + handlers.seed(noise_model_dp_sigma, rng_key) + 
).get_trace(**kwargs) + assert trace is not None + + def test_has_dp_sites(self, synthetic_encoded_data): + from quantammsim.noise_calibration.model import noise_model_dp_sigma + import jax + import numpyro.handlers as handlers + + kwargs = self._get_dp_model_kwargs(synthetic_encoded_data) + rng_key = jax.random.PRNGKey(0) + + trace = handlers.trace( + handlers.seed(noise_model_dp_sigma, rng_key) + ).get_trace(**kwargs) + + dp_sites = {"alpha_dp", "v", "sigma_eps", "log_lik"} + assert dp_sites.issubset(trace.keys()), ( + f"Missing DP sites: {dp_sites - trace.keys()}" + ) + + def test_no_y_site_when_obs_provided(self, synthetic_encoded_data): + """With y_obs provided, the marginalized model uses factor, not obs.""" + from quantammsim.noise_calibration.model import noise_model_dp_sigma + import jax + import numpyro.handlers as handlers + + kwargs = self._get_dp_model_kwargs(synthetic_encoded_data) + rng_key = jax.random.PRNGKey(0) + + trace = handlers.trace( + handlers.seed(noise_model_dp_sigma, rng_key) + ).get_trace(**kwargs) + + assert "y" not in trace, ( + "DP model should use numpyro.factor, not obs=y_obs" + ) + + def test_shared_mean_structure_sites(self, synthetic_encoded_data): + """The DP model must share the same mean structure as the tier model.""" + from quantammsim.noise_calibration.model import noise_model_dp_sigma + import jax + import numpyro.handlers as handlers + + kwargs = self._get_dp_model_kwargs(synthetic_encoded_data) + rng_key = jax.random.PRNGKey(0) + + trace = handlers.trace( + handlers.seed(noise_model_dp_sigma, rng_key) + ).get_trace(**kwargs) + + shared_sites = {"B", "sigma_theta", "L_Omega", "df", "eta", "theta"} + assert shared_sites.issubset(trace.keys()), ( + f"Missing shared sites: {shared_sites - trace.keys()}" + ) + + def test_correct_shapes(self, synthetic_encoded_data): + from quantammsim.noise_calibration.model import noise_model_dp_sigma + import jax + import numpyro.handlers as handlers + + K_clusters = 6 + kwargs = 
self._get_dp_model_kwargs( + synthetic_encoded_data, K_clusters=K_clusters + ) + rng_key = jax.random.PRNGKey(0) + + trace = handlers.trace( + handlers.seed(noise_model_dp_sigma, rng_key) + ).get_trace(**kwargs) + + N_pools = synthetic_encoded_data["N_pools"] + K_cov = synthetic_encoded_data["K_cov"] + + assert trace["B"]["value"].shape == (K_COEFF, K_cov) + assert trace["sigma_theta"]["value"].shape == (K_COEFF,) + assert trace["L_Omega"]["value"].shape == (K_COEFF, K_COEFF) + assert trace["df"]["value"].shape == () + assert trace["sigma_eps"]["value"].shape == (K_clusters,) + assert trace["eta"]["value"].shape == (N_pools, K_COEFF) + assert trace["v"]["value"].shape == (K_clusters - 1,) + assert trace["alpha_dp"]["value"].shape == () + + def test_no_tier_A_per_pool_in_signature(self): + """The DP model should not accept tier_A_per_pool.""" + from quantammsim.noise_calibration.model import noise_model_dp_sigma + import inspect + + sig = inspect.signature(noise_model_dp_sigma) + assert "tier_A_per_pool" not in sig.parameters + + def test_prior_predictive_produces_y(self, synthetic_encoded_data): + """With y_obs=None, the model should sample y explicitly.""" + from quantammsim.noise_calibration.model import noise_model_dp_sigma + import jax + from numpyro.infer import Predictive + + kwargs = self._get_dp_model_kwargs(synthetic_encoded_data) + kwargs["y_obs"] = None + + predictive = Predictive(noise_model_dp_sigma, num_samples=5) + rng_key = jax.random.PRNGKey(42) + samples = predictive(rng_key, **kwargs) + assert "y" in samples + assert samples["y"].shape[0] == 5 + + def test_k_clusters_configurable(self, synthetic_encoded_data): + """K_clusters=4 should produce different sigma_eps shape.""" + from quantammsim.noise_calibration.model import noise_model_dp_sigma + import jax + import numpyro.handlers as handlers + + K_clusters = 4 + kwargs = self._get_dp_model_kwargs( + synthetic_encoded_data, K_clusters=K_clusters + ) + rng_key = jax.random.PRNGKey(0) + + trace = 
handlers.trace( + handlers.seed(noise_model_dp_sigma, rng_key) + ).get_trace(**kwargs) + + assert trace["sigma_eps"]["value"].shape == (K_clusters,) + assert trace["v"]["value"].shape == (K_clusters - 1,) + + +# =========================================================================== +# TestDPModelSVISmoke +# =========================================================================== + + +class TestDPModelSVISmoke: + def _build_dp_panel(self): + """Build a minimal panel for DP SVI testing.""" + from datetime import date, timedelta + import pandas as pd + + np.random.seed(42) + pools_spec = [ + ("p0", "MAINNET", "WETH,USDC", 0.003, 0, 0), + ("p1", "ARBITRUM", "BAL,WETH", 0.01, 0, 1), + ("p2", "BASE", "RATS,WETH", 0.005, 0, 2), + ] + dates = [date(2026, 1, 1) + timedelta(days=i) for i in range(15)] + records = [] + for pid, chain, tokens, fee, ta, tb in pools_spec: + base_tvl = 14 + np.random.randn() * 0.5 + for d in dates: + tvl = base_tvl + np.random.randn() * 0.1 + vol = tvl - 2 + np.random.randn() * 0.3 + records.append({ + "pool_id": pid, "chain": chain, "date": d, + "log_volume": vol, "log_tvl": tvl, + "volatility": 0.3 + np.random.rand() * 0.2, + "weekend": 1.0 if d.weekday() >= 5 else 0.0, + "log_fee": np.log(max(fee, 1e-6)), + "swap_fee": fee, "tier_A": ta, "tier_B": tb, + "tokens": tokens, + }) + panel = pd.DataFrame(records) + panel = panel.sort_values(["pool_id", "date"]).reset_index(drop=True) + panel["log_tvl_lag1"] = panel.groupby("pool_id")["log_tvl"].shift(1) + panel = panel.dropna(subset=["log_tvl_lag1"]).reset_index(drop=True) + return panel + + def test_svi_converges_dp_model(self): + """SVI on DP model with tiny data: ELBO should decrease.""" + import numpyro + numpyro.enable_x64() + + from quantammsim.noise_calibration import encode_covariates, run_svi + from quantammsim.noise_calibration.model import noise_model_dp_sigma + + panel = self._build_dp_panel() + data = encode_covariates(panel, include_tiers=False) + data["K_clusters"] = 4 + + 
samples, losses = run_svi( + data, num_steps=2000, lr=1e-3, seed=0, + num_samples=50, model_fn=noise_model_dp_sigma, + ) + assert np.mean(losses[-100:]) < np.mean(losses[:100]) + + def test_svi_samples_have_dp_keys(self): + """SVI output for DP model has v, alpha_dp, sigma_eps.""" + import numpyro + numpyro.enable_x64() + + from quantammsim.noise_calibration import encode_covariates, run_svi + from quantammsim.noise_calibration.model import noise_model_dp_sigma + + panel = self._build_dp_panel() + data = encode_covariates(panel, include_tiers=False) + data["K_clusters"] = 4 + + samples, _ = run_svi( + data, num_steps=500, lr=1e-3, seed=0, + num_samples=10, model_fn=noise_model_dp_sigma, + ) + required = {"B", "sigma_theta", "L_Omega", "eta", "df", + "sigma_eps", "v", "alpha_dp"} + assert required.issubset(samples.keys()), ( + f"Missing keys: {required - samples.keys()}" + ) diff --git a/tests/noise/test_model_structural.py b/tests/noise/test_model_structural.py new file mode 100644 index 0000000..9243ba7 --- /dev/null +++ b/tests/noise/test_model_structural.py @@ -0,0 +1,215 @@ +"""Tests for structural_noise_model definition and SVI integration.""" + +import numpy as np +import pytest + + +class TestStructuralModelDefinition: + """Tests for the structural_noise_model numpyro model.""" + + def test_model_traces_without_error(self, synthetic_structural_data): + import jax + import jax.numpy as jnp + import numpyro + from numpyro.infer import Predictive + from quantammsim.noise_calibration.model import structural_noise_model + from quantammsim.noise_calibration.inference import _build_model_kwargs + + kwargs = _build_model_kwargs(synthetic_structural_data, + model_fn=structural_noise_model) + predictive = Predictive(structural_noise_model, num_samples=2) + rng_key = jax.random.PRNGKey(0) + samples = predictive(rng_key, **kwargs) + assert "y" in samples + + def test_required_sites_present(self, synthetic_structural_data): + import jax + from numpyro.infer import 
Predictive + from quantammsim.noise_calibration.model import structural_noise_model + from quantammsim.noise_calibration.inference import _build_model_kwargs + + kwargs = _build_model_kwargs(synthetic_structural_data, + model_fn=structural_noise_model) + predictive = Predictive(structural_noise_model, num_samples=2) + samples = predictive(jax.random.PRNGKey(0), **kwargs) + required = { + "alpha_0", "alpha_chain", "alpha_tier", "alpha_tvl", + "W_gate", "beta", "df", "sigma_eps", "y", + } + assert required.issubset(samples.keys()), ( + f"Missing sites: {required - samples.keys()}" + ) + + def test_no_eta_sigma_theta_L_Omega(self, synthetic_structural_data): + import jax + from numpyro.infer import Predictive + from quantammsim.noise_calibration.model import structural_noise_model + from quantammsim.noise_calibration.inference import _build_model_kwargs + + kwargs = _build_model_kwargs(synthetic_structural_data, + model_fn=structural_noise_model) + predictive = Predictive(structural_noise_model, num_samples=2) + samples = predictive(jax.random.PRNGKey(0), **kwargs) + old_sites = {"eta", "sigma_theta", "L_Omega", "theta", "B"} + for site in old_sites: + assert site not in samples, f"Old site '{site}' should not be present" + + def test_W_gate_shape(self, synthetic_structural_data): + import jax + from numpyro.infer import Predictive + from quantammsim.noise_calibration.model import structural_noise_model + from quantammsim.noise_calibration.inference import _build_model_kwargs + + kwargs = _build_model_kwargs(synthetic_structural_data, + model_fn=structural_noise_model) + predictive = Predictive(structural_noise_model, num_samples=3) + samples = predictive(jax.random.PRNGKey(0), **kwargs) + + K_pool_cov = synthetic_structural_data["K_cov"] + K_archetypes = 3 # default + assert samples["W_gate"].shape == (3, K_pool_cov, K_archetypes) + + def test_beta_shape(self, synthetic_structural_data): + import jax + from numpyro.infer import Predictive + from 
quantammsim.noise_calibration.model import structural_noise_model + from quantammsim.noise_calibration.inference import _build_model_kwargs + + kwargs = _build_model_kwargs(synthetic_structural_data, + model_fn=structural_noise_model) + predictive = Predictive(structural_noise_model, num_samples=3) + samples = predictive(jax.random.PRNGKey(0), **kwargs) + + from quantammsim.noise_calibration.constants import K_OBS_COEFF + K_archetypes = 3 + assert samples["beta"].shape == (3, K_archetypes, K_OBS_COEFF) + + def test_alpha_chain_shape(self, synthetic_structural_data): + import jax + from numpyro.infer import Predictive + from quantammsim.noise_calibration.model import structural_noise_model + from quantammsim.noise_calibration.inference import _build_model_kwargs + + kwargs = _build_model_kwargs(synthetic_structural_data, + model_fn=structural_noise_model) + predictive = Predictive(structural_noise_model, num_samples=3) + samples = predictive(jax.random.PRNGKey(0), **kwargs) + + n_chains = synthetic_structural_data["n_chains"] + assert samples["alpha_chain"].shape == (3, n_chains - 1) + + def test_prior_predictive_produces_y(self, synthetic_structural_data): + """y_obs=None path produces y samples.""" + import jax + from numpyro.infer import Predictive + from quantammsim.noise_calibration.model import structural_noise_model + from quantammsim.noise_calibration.inference import _build_model_kwargs + + kwargs = _build_model_kwargs(synthetic_structural_data, + model_fn=structural_noise_model) + kwargs["y_obs"] = None + predictive = Predictive(structural_noise_model, num_samples=10) + samples = predictive(jax.random.PRNGKey(42), **kwargs) + assert "y" in samples + assert samples["y"].shape[0] == 10 + + def test_prior_predictive_range_reasonable(self, synthetic_structural_data): + """Prior y median should be in a plausible range for log(volume). + + The tails can be wide (V_noise = exp(beta @ x_obs) with large + covariates), but the central mass should be reasonable. 
+ """ + import jax + from numpyro.infer import Predictive + from quantammsim.noise_calibration.model import structural_noise_model + from quantammsim.noise_calibration.inference import _build_model_kwargs + + kwargs = _build_model_kwargs(synthetic_structural_data, + model_fn=structural_noise_model) + kwargs["y_obs"] = None + predictive = Predictive(structural_noise_model, num_samples=200) + samples = predictive(jax.random.PRNGKey(42), **kwargs) + y = np.array(samples["y"]) + median = np.median(y) + p25 = np.percentile(y, 25) + p75 = np.percentile(y, 75) + # Median should be in a plausible log-volume range + assert median > -20, f"Prior y median too low: {median}" + assert median < 50, f"Prior y median too high: {median}" + # IQR should not be enormous + iqr = p75 - p25 + assert iqr < 500, f"Prior y IQR too wide: {iqr}" + + def test_K_archetypes_configurable(self, synthetic_structural_data): + """K=2 and K=4 both work.""" + import jax + from numpyro.infer import Predictive + from quantammsim.noise_calibration.model import structural_noise_model + from quantammsim.noise_calibration.inference import _build_model_kwargs + + for K in [2, 4]: + kwargs = _build_model_kwargs(synthetic_structural_data, + model_fn=structural_noise_model) + kwargs["K_archetypes"] = K + predictive = Predictive(structural_noise_model, num_samples=2) + samples = predictive(jax.random.PRNGKey(0), **kwargs) + assert samples["W_gate"].shape[-1] == K + assert samples["beta"].shape[1] == K + + +class TestSVIStructural: + """SVI convergence tests for structural model.""" + + @pytest.mark.slow + def test_svi_structural_converges(self, synthetic_structural_data): + """2000 SVI steps, ELBO should decrease.""" + from quantammsim.noise_calibration.inference import run_svi + from quantammsim.noise_calibration.model import structural_noise_model + + samples, losses = run_svi( + synthetic_structural_data, + num_steps=2000, + lr=5e-3, + seed=0, + num_samples=10, + model_fn=structural_noise_model, + ) + assert 
losses[-100:].mean() < losses[:100].mean() + + @pytest.mark.slow + def test_svi_structural_samples_have_required_keys( + self, synthetic_structural_data, + ): + from quantammsim.noise_calibration.inference import run_svi + from quantammsim.noise_calibration.model import structural_noise_model + + samples, _ = run_svi( + synthetic_structural_data, + num_steps=500, + lr=5e-3, + seed=0, + num_samples=10, + model_fn=structural_noise_model, + ) + required = {"W_gate", "beta", "alpha_0", "alpha_chain", + "alpha_tier", "alpha_tvl", "df", "sigma_eps"} + assert required.issubset(samples.keys()), ( + f"Missing keys: {required - samples.keys()}" + ) + + @pytest.mark.slow + def test_svi_structural_no_eta_keys(self, synthetic_structural_data): + from quantammsim.noise_calibration.inference import run_svi + from quantammsim.noise_calibration.model import structural_noise_model + + samples, _ = run_svi( + synthetic_structural_data, + num_steps=500, + lr=5e-3, + seed=0, + num_samples=10, + model_fn=structural_noise_model, + ) + old_keys = {"eta", "sigma_theta", "L_Omega", "B"} + for key in old_keys: + assert key not in samples, f"Old key '{key}' should not be in samples" diff --git a/tests/noise/test_output.py b/tests/noise/test_output.py new file mode 100644 index 0000000..7661c9d --- /dev/null +++ b/tests/noise/test_output.py @@ -0,0 +1,303 @@ +"""Tests for generate_output_json and _save_sample_cache.""" + +import json +import os + +import numpy as np +import pytest + +from quantammsim.noise_calibration import ( + extract_noise_params, + generate_output_json, + _save_sample_cache, + K_COEFF, +) + + +# =========================================================================== +# TestGenerateOutputJSON +# =========================================================================== + + +class TestGenerateOutputJSON: + @pytest.fixture() + def _output_setup(self, tmp_path, synthetic_samples, synthetic_encoded_data): + """Produce the JSON file and return (path, data dict).""" + 
pool_params = extract_noise_params( + synthetic_samples, synthetic_encoded_data + ) + output_path = str(tmp_path / "test_output.json") + convergence = {"method": "svi", "final_elbo": 1234.0} + inference_config = {"method": "svi", "svi_steps": 1000} + + generate_output_json( + pool_params, synthetic_samples, synthetic_encoded_data, + convergence, output_path, inference_config, + ) + with open(output_path) as f: + data = json.load(f) + return output_path, data + + def test_writes_valid_json(self, _output_setup): + path, data = _output_setup + assert isinstance(data, dict) + + def test_top_level_keys(self, _output_setup): + _, data = _output_setup + expected = { + "model", "model_spec", "inference", "population_effects", + "convergence", "n_pools", "n_obs", "pools", + } + assert expected.issubset(data.keys()) + + def test_model_spec_fields(self, _output_setup): + _, data = _output_setup + spec = data["model_spec"] + assert "K_coeff" in spec + assert "K_cov" in spec + assert "coeff_names" in spec + assert "covariate_names" in spec + assert "likelihood" in spec + assert spec["likelihood"] == "StudentT" + assert "tvl_lag" in spec + assert spec["tvl_lag"] == "log_tvl_lag1" + + def test_population_effects_fields(self, _output_setup): + _, data = _output_setup + pe = data["population_effects"] + assert "B" in pe + assert "sigma_theta" in pe + assert "sigma_eps" in pe + assert "df" in pe + assert "correlation_matrix" in pe + + def test_pool_entries(self, _output_setup, synthetic_encoded_data): + _, data = _output_setup + pools = data["pools"] + for pid in synthetic_encoded_data["pool_ids"]: + assert pid in pools + entry = pools[pid] + assert "chain" in entry + assert "tokens" in entry + assert "theta_median" in entry + assert "noise_params" in entry + + def test_correlation_matrix_symmetric_unit_diagonal(self, _output_setup): + _, data = _output_setup + Omega = np.array(data["population_effects"]["correlation_matrix"]) + np.testing.assert_allclose(Omega, Omega.T, atol=1e-10) 
+ np.testing.assert_allclose(np.diag(Omega), 1.0, atol=1e-10) + + def test_n_pools_and_n_obs(self, _output_setup, synthetic_encoded_data): + _, data = _output_setup + assert data["n_pools"] == synthetic_encoded_data["N_pools"] + assert data["n_obs"] == len(synthetic_encoded_data["y_obs"]) + + def test_model_name(self, _output_setup): + _, data = _output_setup + assert data["model"] == "unified_hierarchical_student_t" + + +# =========================================================================== +# TestSaveSampleCache +# =========================================================================== + + +class TestSaveSampleCache: + def test_creates_npz_and_json( + self, tmp_path, synthetic_samples, synthetic_encoded_data + ): + cache_dir = str(tmp_path / "cache") + _save_sample_cache(synthetic_samples, synthetic_encoded_data, cache_dir) + + assert os.path.exists(os.path.join(cache_dir, "unified_samples.npz")) + assert os.path.exists(os.path.join(cache_dir, "unified_data.json")) + + def test_npz_excludes_y_and_theta( + self, tmp_path, synthetic_samples, synthetic_encoded_data + ): + """Both 'y' and 'theta' must be excluded from the npz cache.""" + samples_with_extras = dict(synthetic_samples) + samples_with_extras["y"] = np.random.randn(10, 27) + samples_with_extras["theta"] = np.random.randn(10, 3, 4) + + cache_dir = str(tmp_path / "cache2") + _save_sample_cache(samples_with_extras, synthetic_encoded_data, cache_dir) + + npz_path = os.path.join(cache_dir, "unified_samples.npz") + loaded = np.load(npz_path) + assert "y" not in loaded.files + assert "theta" not in loaded.files + + def test_npz_contains_required_keys( + self, tmp_path, synthetic_samples, synthetic_encoded_data + ): + """The npz must contain B, sigma_theta, L_Omega, eta, df, sigma_eps.""" + cache_dir = str(tmp_path / "cache3") + _save_sample_cache(synthetic_samples, synthetic_encoded_data, cache_dir) + + npz_path = os.path.join(cache_dir, "unified_samples.npz") + loaded = np.load(npz_path) + required = 
{"B", "sigma_theta", "L_Omega", "eta", "df", "sigma_eps"} + assert required.issubset(set(loaded.files)) + + def test_json_contains_metadata_keys( + self, tmp_path, synthetic_samples, synthetic_encoded_data + ): + """The JSON cache must contain all keys needed by --predict.""" + cache_dir = str(tmp_path / "cache4") + _save_sample_cache(synthetic_samples, synthetic_encoded_data, cache_dir) + + json_path = os.path.join(cache_dir, "unified_data.json") + with open(json_path) as f: + meta = json.load(f) + required = { + "pool_ids", "covariate_names", "K_cov", "N_pools", + "ref_chain", "ref_tier_a", "ref_tier_b", "chains", + } + assert required.issubset(meta.keys()) + + +# =========================================================================== +# TestGenerateOutputJSONDP +# =========================================================================== + + +class TestGenerateOutputJSONDP: + @pytest.fixture() + def _dp_output_setup(self, tmp_path, synthetic_encoded_data): + """Produce DP model JSON output and return (path, data dict).""" + data = synthetic_encoded_data + N_pools = data["N_pools"] + K_cov = data["K_cov"] + K_clusters = 4 + S = 10 + + np.random.seed(99) + dp_samples = { + "B": np.random.randn(S, K_COEFF, K_cov) * 0.5, + "sigma_theta": np.ones((S, K_COEFF)), + "L_Omega": np.tile(np.eye(K_COEFF), (S, 1, 1)), + "eta": np.zeros((S, N_pools, K_COEFF)), + "df": np.full((S,), 5.0), + "sigma_eps": np.tile([0.3, 0.8, 1.5, 2.5], (S, 1)), + "v": np.tile([0.6, 0.3, 0.05], (S, 1)), + "alpha_dp": np.full((S,), 1.5), + } + + pool_params = extract_noise_params(dp_samples, data) + output_path = str(tmp_path / "dp_output.json") + convergence = {"method": "svi", "final_elbo": 1234.0} + inference_config = {"method": "svi", "svi_steps": 1000} + + generate_output_json( + pool_params, dp_samples, data, + convergence, output_path, inference_config, + ) + with open(output_path) as f: + result = json.load(f) + return output_path, result + + def 
test_sigma_eps_structure_is_dp_mixture(self, _dp_output_setup): + _, data = _dp_output_setup + assert data["model_spec"]["sigma_eps_structure"] == "dp_mixture" + + def test_model_name_includes_dp(self, _dp_output_setup): + _, data = _dp_output_setup + assert "dp_sigma" in data["model"] + + def test_has_cluster_weights(self, _dp_output_setup): + _, data = _dp_output_setup + assert "cluster_weights" in data["population_effects"] + + def test_sigma_eps_length_equals_k_clusters(self, _dp_output_setup): + _, data = _dp_output_setup + sigma_eps = data["population_effects"]["sigma_eps"] + assert len(sigma_eps) == 4 # K_clusters = 4 + + def test_cluster_weights_sum_to_one(self, _dp_output_setup): + _, data = _dp_output_setup + w = data["population_effects"]["cluster_weights"] + np.testing.assert_allclose(sum(w), 1.0, atol=1e-4) + + def test_still_has_standard_fields(self, _dp_output_setup): + _, data = _dp_output_setup + expected = { + "model", "model_spec", "inference", "population_effects", + "convergence", "n_pools", "n_obs", "pools", + } + assert expected.issubset(data.keys()) + + +# =========================================================================== +# TestGenerateOutputJSONStructural +# =========================================================================== + + +class TestGenerateOutputJSONStructural: + @pytest.fixture() + def _structural_output_setup(self, tmp_path, synthetic_structural_data): + """Produce structural model JSON output and return (path, data dict).""" + from quantammsim.noise_calibration.postprocessing import ( + extract_structural_params, + ) + from quantammsim.noise_calibration.constants import K_OBS_COEFF + + data = synthetic_structural_data + K_cov = data["K_cov"] + K_archetypes = 3 + n_chains = data["n_chains"] + n_tiers = data["n_tiers"] + S = 10 + + np.random.seed(77) + structural_samples = { + "alpha_0": np.random.randn(S) * 0.1 + 2.0, + "alpha_chain": np.random.randn(S, n_chains - 1) * 0.1, + "alpha_tier": np.random.randn(S, 
n_tiers - 1) * 0.1, + "alpha_tvl": np.random.randn(S) * 0.01, + "W_gate": np.random.randn(S, K_cov, K_archetypes) * 0.3, + "beta": np.random.randn(S, K_archetypes, K_OBS_COEFF) * 0.5, + "df": np.full((S,), 5.0), + "sigma_eps": np.full((S,), 0.5), + } + + pool_params = extract_structural_params(structural_samples, data) + output_path = str(tmp_path / "structural_output.json") + convergence = {"method": "svi", "final_elbo": 999.0} + inference_config = {"method": "svi", "svi_steps": 2000} + + generate_output_json( + pool_params, structural_samples, data, + convergence, output_path, inference_config, + ) + with open(output_path) as f: + result = json.load(f) + return output_path, result + + def test_output_model_name(self, _structural_output_setup): + _, data = _structural_output_setup + assert data["model"] == "structural_mixture" + + def test_output_has_arb_params(self, _structural_output_setup): + _, data = _structural_output_setup + pe = data["population_effects"] + assert "alpha_0" in pe + assert "alpha_chain" in pe + assert "alpha_tier" in pe + assert "alpha_tvl" in pe + + def test_output_has_archetype_info(self, _structural_output_setup): + _, data = _structural_output_setup + pe = data["population_effects"] + assert "W_gate" in pe + assert "beta" in pe + assert "K_archetypes" in pe + + def test_output_pools_have_arb_frequency(self, _structural_output_setup): + _, data = _structural_output_setup + pools = data["pools"] + for pid, entry in pools.items(): + assert "arb_frequency" in entry + assert isinstance(entry["arb_frequency"], int) + assert 1 <= entry["arb_frequency"] <= 60 diff --git a/tests/noise/test_panel_assembly.py b/tests/noise/test_panel_assembly.py new file mode 100644 index 0000000..dd7cd31 --- /dev/null +++ b/tests/noise/test_panel_assembly.py @@ -0,0 +1,442 @@ +"""Tests for compute_pair_volatility, assemble_panel, validate_panel.""" + +from datetime import date, timedelta + +import numpy as np +import pandas as pd +import pytest + +from 
quantammsim.noise_calibration import ( + compute_pair_volatility, + assemble_panel, + validate_panel, +) + + +# =========================================================================== +# TestComputePairVolatility +# =========================================================================== + + +class TestComputePairVolatility: + @pytest.fixture() + def _snap_dates(self): + """10 unique dates for snapshot stub.""" + dates = [date(2026, 1, 1) + timedelta(days=i) for i in range(10)] + return pd.DataFrame({"date": dates}) + + def test_stablecoin_pair_returns_001(self, _snap_dates): + pool_row = pd.Series({ + "tokens": ["USDC", "DAI"], + "chain": "MAINNET", + }) + vol = compute_pair_volatility(_snap_dates, pool_row, {}) + assert (vol == 0.01).all() + + def test_both_missing_non_stable(self, _snap_dates): + pool_row = pd.Series({ + "tokens": ["FOO", "BAR"], + "chain": "MAINNET", + }) + vol = compute_pair_volatility(_snap_dates, pool_row, {}) + assert (vol == 0.5).all() + + def test_one_missing_non_stable(self, _snap_dates): + np.random.seed(77) + pool_row = pd.Series({ + "tokens": ["WETH", "BAR"], + "chain": "MAINNET", + }) + prices = { + ("MAINNET", "WETH"): pd.DataFrame({ + "timestamp": [1735689600 + i * 3600 for i in range(48)], + "price": [3000.0 + np.random.randn() * 10 for _ in range(48)], + }), + } + vol = compute_pair_volatility(_snap_dates, pool_row, prices) + assert (vol == 0.5).all() + + def test_synthetic_hourly_prices_positive_finite(self, _snap_dates): + np.random.seed(42) + n_hours = 240 # 10 days x 24 hours + ts_base = 1735689600 + timestamps = [ts_base + i * 3600 for i in range(n_hours)] + prices_a = np.exp(np.cumsum(np.random.randn(n_hours) * 0.01) + 8) + prices_b = np.exp(np.cumsum(np.random.randn(n_hours) * 0.01) + 7) + + pool_row = pd.Series({ + "tokens": ["WETH", "LINK"], + "chain": "MAINNET", + }) + token_prices = { + ("MAINNET", "WETH"): pd.DataFrame({ + "timestamp": timestamps, "price": prices_a, + }), + ("MAINNET", "LINK"): pd.DataFrame({ + 
"timestamp": timestamps, "price": prices_b, + }), + } + vol = compute_pair_volatility(_snap_dates, pool_row, token_prices) + assert len(vol) > 0 + assert (vol > 0).all() + assert np.all(np.isfinite(vol)) + + def test_annualisation_uses_sqrt_24x365(self, _snap_dates): + """Verify the annualisation factor is sqrt(24*365).""" + np.random.seed(7) + n_hours = 240 + ts_base = 1735689600 + timestamps = [ts_base + i * 3600 for i in range(n_hours)] + raw_prices = np.exp(np.cumsum(np.random.randn(n_hours) * 0.005) + 8) + + pool_row = pd.Series({ + "tokens": ["WETH", "USDC"], + "chain": "MAINNET", + }) + token_prices = { + ("MAINNET", "WETH"): pd.DataFrame({ + "timestamp": timestamps, "price": raw_prices, + }), + } + vol = compute_pair_volatility(_snap_dates, pool_row, token_prices) + + # Reconstruct manually + df = pd.DataFrame({"timestamp": timestamps, "price": raw_prices}) + df["datetime"] = pd.to_datetime(df["timestamp"], unit="s") + df["date"] = df["datetime"].dt.date + df["ratio"] = df["price"] # WETH vs stable => ratio = price + df["log_return"] = np.log(df["ratio"] / df["ratio"].shift(1)) + df = df.dropna(subset=["log_return"]) + daily_std = df.groupby("date")["log_return"].std() + expected = daily_std * np.sqrt(24 * 365) + + common = vol.index.intersection(expected.index) + assert len(common) > 0 + np.testing.assert_allclose( + vol.loc[common].values, expected.loc[common].values, rtol=1e-10, + ) + + def test_single_token_pool_returns_empty(self, _snap_dates): + pool_row = pd.Series({"tokens": ["WETH"], "chain": "MAINNET"}) + vol = compute_pair_volatility(_snap_dates, pool_row, {}) + assert len(vol) == 0 + + def test_no_overlapping_price_dates(self, _snap_dates): + """Two tokens with non-overlapping timestamps -> fallback 0.5.""" + pool_row = pd.Series({ + "tokens": ["WETH", "LINK"], + "chain": "MAINNET", + }) + token_prices = { + ("MAINNET", "WETH"): pd.DataFrame({ + "timestamp": [1000000 + i for i in range(10)], + "price": [3000.0] * 10, + }), + ("MAINNET", "LINK"): 
pd.DataFrame({ + "timestamp": [9000000 + i for i in range(10)], + "price": [15.0] * 10, + }), + } + vol = compute_pair_volatility(_snap_dates, pool_row, token_prices) + assert (vol == 0.5).all() + + def test_cross_chain_price_fallback(self, _snap_dates): + """Prices keyed to a different chain SHOULD be used as fallback. + + Token prices are chain-agnostic (WETH is WETH regardless of chain), + so an ARBITRUM pool should use MAINNET WETH prices if ARBITRUM + prices aren't available. + """ + np.random.seed(55) + n_hours = 240 + ts_base = 1735689600 + timestamps = [ts_base + i * 3600 for i in range(n_hours)] + prices = np.exp(np.cumsum(np.random.randn(n_hours) * 0.01) + 8) + + pool_row = pd.Series({ + "tokens": ["WETH", "USDC"], + "chain": "ARBITRUM", + }) + token_prices = { + # Only MAINNET prices, but ARBITRUM pool should still use them + ("MAINNET", "WETH"): pd.DataFrame({ + "timestamp": timestamps, "price": prices.tolist(), + }), + } + vol = compute_pair_volatility(_snap_dates, pool_row, token_prices) + # Should get real volatility, NOT the 0.5 fallback + assert len(vol) > 0 + assert not (vol == 0.5).all(), ( + "Cross-chain price fallback should have produced real volatility" + ) + + +# =========================================================================== +# TestAssemblePanel +# =========================================================================== + + +class TestAssemblePanel: + def test_lagged_tvl_drops_first_obs_per_pool( + self, synthetic_pools_df, synthetic_snapshots_df + ): + panel = assemble_panel(synthetic_pools_df, synthetic_snapshots_df, {}) + counts = panel.groupby("pool_id").size() + assert (counts == 9).all() + + def test_lagged_tvl_exact_values( + self, synthetic_pools_df, synthetic_snapshots_df + ): + """log_tvl_lag1[t] must equal log_tvl[t-1] for same pool (shift(1)).""" + panel = assemble_panel(synthetic_pools_df, synthetic_snapshots_df, {}) + for pid in panel["pool_id"].unique(): + pool = panel[panel["pool_id"] == 
pid].sort_values("date") + tvl_vals = pool["log_tvl"].values + lag_vals = pool["log_tvl_lag1"].values + # After dropping the first obs, lag[i] = tvl[i-1] in the original + # pre-drop series. Since the panel is sorted by date, each + # lag value should equal the log_tvl of the chronologically + # preceding observation. We verify consecutive pairs: for rows + # i and i+1, lag[i+1] == tvl[i]. + for i in range(len(tvl_vals) - 1): + np.testing.assert_allclose( + lag_vals[i + 1], tvl_vals[i], rtol=1e-14, + err_msg=f"Pool {pid} row {i+1}: lag should equal previous tvl", + ) + + def test_weekend_flag(self, synthetic_pools_df, synthetic_snapshots_df): + panel = assemble_panel(synthetic_pools_df, synthetic_snapshots_df, {}) + for _, row in panel.iterrows(): + d = row["date"] + if not isinstance(d, date): + d = pd.Timestamp(d).date() + expected = 1.0 if d.weekday() >= 5 else 0.0 + assert row["weekend"] == expected, f"Wrong weekend flag for {d}" + + def test_log_volume_is_natural_log( + self, synthetic_pools_df, synthetic_snapshots_df + ): + """log_volume must equal ln(volume_usd), not log10 or log2.""" + panel = assemble_panel(synthetic_pools_df, synthetic_snapshots_df, {}) + snaps = synthetic_snapshots_df.copy() + # Join snapshots to panel by pool_id + date to verify exact log values + for _, row in panel.iterrows(): + pid = row["pool_id"] + d = row["date"] + snap_match = snaps[ + (snaps["pool_id"] == pid) & (snaps["date"] == d) + ] + assert len(snap_match) == 1, f"No snapshot for {pid} on {d}" + expected = np.log(snap_match.iloc[0]["volume_usd"]) + np.testing.assert_allclose( + row["log_volume"], expected, rtol=1e-14, + err_msg=f"log_volume for {pid} on {d} should be ln(volume_usd)", + ) + + def test_tier_assignment_min_first( + self, synthetic_pools_df, synthetic_snapshots_df + ): + panel = assemble_panel(synthetic_pools_df, synthetic_snapshots_df, {}) + # Pool A: WETH(0), USDC(0) -> tier_A=0, tier_B=0 + pool_a = panel[panel["pool_id"] == "pool_A"].iloc[0] + assert 
pool_a["tier_A"] == 0 + assert pool_a["tier_B"] == 0 + + # Pool B: BAL(1), WETH(0) -> tier_A=0, tier_B=1 + pool_b = panel[panel["pool_id"] == "pool_B"].iloc[0] + assert pool_b["tier_A"] == 0 + assert pool_b["tier_B"] == 1 + + # Pool C: RATS(2), WETH(0) -> tier_A=0, tier_B=2 + pool_c = panel[panel["pool_id"] == "pool_C"].iloc[0] + assert pool_c["tier_A"] == 0 + assert pool_c["tier_B"] == 2 + + def test_log_fee(self, synthetic_pools_df, synthetic_snapshots_df): + panel = assemble_panel(synthetic_pools_df, synthetic_snapshots_df, {}) + pool_a = panel[panel["pool_id"] == "pool_A"].iloc[0] + assert np.isclose(pool_a["log_fee"], np.log(0.003)) + + def test_all_expected_columns( + self, synthetic_pools_df, synthetic_snapshots_df + ): + panel = assemble_panel(synthetic_pools_df, synthetic_snapshots_df, {}) + expected_cols = { + "pool_id", "chain", "date", "log_volume", "log_tvl", + "log_tvl_lag1", "volatility", "weekend", "log_fee", + "swap_fee", "tier_A", "tier_B", "tokens", + } + assert expected_cols.issubset(set(panel.columns)) + + def test_zero_volume_rows_dropped(self, synthetic_pools_df): + """Rows with volume_usd <= 0 must be excluded from the panel.""" + dates = [date(2026, 1, 1) + timedelta(days=i) for i in range(5)] + records = [] + for d in dates: + records.append({ + "pool_id": "pool_A", "chain": "MAINNET", "date": d, + "volume_usd": 1000.0, + "total_liquidity_usd": 100000.0, + }) + # Add a zero-volume row + records.append({ + "pool_id": "pool_A", "chain": "MAINNET", + "date": date(2026, 1, 6), + "volume_usd": 0.0, + "total_liquidity_usd": 100000.0, + }) + snaps = pd.DataFrame(records) + panel = assemble_panel(synthetic_pools_df, snaps, {}) + pool_a = panel[panel["pool_id"] == "pool_A"] + # 5 valid rows, minus 1 for lag = 4 (the zero-volume row is skipped) + assert len(pool_a) == 4 + + def test_zero_tvl_rows_dropped(self, synthetic_pools_df): + """Rows with total_liquidity_usd <= 0 must be excluded.""" + dates = [date(2026, 1, 1) + timedelta(days=i) for i in 
range(5)] + records = [] + for d in dates: + records.append({ + "pool_id": "pool_A", "chain": "MAINNET", "date": d, + "volume_usd": 1000.0, + "total_liquidity_usd": 100000.0, + }) + # Add a zero-TVL row + records.append({ + "pool_id": "pool_A", "chain": "MAINNET", + "date": date(2026, 1, 6), + "volume_usd": 1000.0, + "total_liquidity_usd": 0.0, + }) + snaps = pd.DataFrame(records) + panel = assemble_panel(synthetic_pools_df, snaps, {}) + pool_a = panel[panel["pool_id"] == "pool_A"] + assert len(pool_a) == 4 + + +# =========================================================================== +# TestSyntheticPanelColumns — structural model covariates in the fixture +# =========================================================================== + + +class TestSyntheticPanelColumns: + """Tests that the synthetic_panel fixture has new structural columns.""" + + def test_panel_has_log_sigma(self, synthetic_panel): + assert "log_sigma" in synthetic_panel.columns + expected = np.log(np.maximum(synthetic_panel["volatility"].values, 1e-6)) + np.testing.assert_allclose( + synthetic_panel["log_sigma"].values, expected, rtol=1e-12, + ) + + def test_panel_has_dow_harmonics(self, synthetic_panel): + assert "dow_sin" in synthetic_panel.columns + assert "dow_cos" in synthetic_panel.columns + assert (synthetic_panel["dow_sin"] >= -1.0).all() + assert (synthetic_panel["dow_sin"] <= 1.0).all() + assert (synthetic_panel["dow_cos"] >= -1.0).all() + assert (synthetic_panel["dow_cos"] <= 1.0).all() + + def test_panel_has_interactions(self, synthetic_panel): + assert "tvl_x_sigma" in synthetic_panel.columns + assert "tvl_x_fee" in synthetic_panel.columns + assert "sigma_x_fee" in synthetic_panel.columns + + def test_dow_harmonics_correct_for_known_date(self, synthetic_panel): + """2026-01-03 is Saturday (weekday=5), so dow=5.""" + sat_rows = synthetic_panel[ + synthetic_panel["date"] == date(2026, 1, 3) + ] + if len(sat_rows) == 0: + pytest.skip("No Saturday rows in fixture") + expected_sin = 
np.sin(2 * np.pi * 5 / 7) + expected_cos = np.cos(2 * np.pi * 5 / 7) + np.testing.assert_allclose( + sat_rows["dow_sin"].values[0], expected_sin, atol=1e-12, + ) + np.testing.assert_allclose( + sat_rows["dow_cos"].values[0], expected_cos, atol=1e-12, + ) + + def test_interactions_use_lagged_tvl(self, synthetic_panel): + """tvl_x_sigma must use log_tvl_lag1, not log_tvl.""" + expected = ( + synthetic_panel["log_tvl_lag1"].values + * synthetic_panel["log_sigma"].values + ) + np.testing.assert_allclose( + synthetic_panel["tvl_x_sigma"].values, expected, rtol=1e-12, + ) + + +# =========================================================================== +# TestAssemblePanelStructuralColumns — new columns from real pipeline +# =========================================================================== + + +class TestAssemblePanelStructuralColumns: + """Tests that assemble_panel() produces new structural columns.""" + + def test_assemble_panel_has_log_sigma( + self, synthetic_pools_df, synthetic_snapshots_df + ): + panel = assemble_panel(synthetic_pools_df, synthetic_snapshots_df, {}) + assert "log_sigma" in panel.columns + expected = np.log(np.maximum(panel["volatility"].values, 1e-6)) + np.testing.assert_allclose( + panel["log_sigma"].values, expected, rtol=1e-12, + ) + + def test_assemble_panel_has_dow_harmonics( + self, synthetic_pools_df, synthetic_snapshots_df + ): + panel = assemble_panel(synthetic_pools_df, synthetic_snapshots_df, {}) + assert "dow_sin" in panel.columns + assert "dow_cos" in panel.columns + + def test_assemble_panel_has_interactions( + self, synthetic_pools_df, synthetic_snapshots_df + ): + panel = assemble_panel(synthetic_pools_df, synthetic_snapshots_df, {}) + assert "tvl_x_sigma" in panel.columns + assert "tvl_x_fee" in panel.columns + assert "sigma_x_fee" in panel.columns + # Interactions use lagged TVL + expected = panel["log_tvl_lag1"].values * panel["log_sigma"].values + np.testing.assert_allclose( + panel["tvl_x_sigma"].values, expected, 
rtol=1e-12, + ) + + +# =========================================================================== +# TestValidatePanel +# =========================================================================== + + +class TestValidatePanel: + def test_flags_constant_volume(self, synthetic_panel, capsys): + panel = synthetic_panel.copy() + mask = panel["pool_id"] == "pool_A" + panel.loc[mask, "log_volume"] = 10.0 + validate_panel(panel) + captured = capsys.readouterr() + assert "near-constant" in captured.out or "constant" in captured.out.lower() + + def test_flags_tvl_jumps(self, synthetic_panel, capsys): + panel = synthetic_panel.copy() + idx = panel[panel["pool_id"] == "pool_A"].index[1] + panel.loc[idx, "log_tvl"] = panel.loc[idx, "log_tvl"] + 5.0 + validate_panel(panel) + captured = capsys.readouterr() + assert "TVL jumps" in captured.out + + def test_flags_volume_exceeds_tvl(self, synthetic_panel, capsys): + panel = synthetic_panel.copy() + panel["log_volume"] = panel["log_tvl"] + 1.0 + validate_panel(panel) + captured = capsys.readouterr() + assert "volume > TVL" in captured.out + + def test_returns_dataframe_unchanged(self, synthetic_panel): + result = validate_panel(synthetic_panel) + pd.testing.assert_frame_equal(result, synthetic_panel) diff --git a/tests/noise/test_postprocessing.py b/tests/noise/test_postprocessing.py new file mode 100644 index 0000000..3d7d2e4 --- /dev/null +++ b/tests/noise/test_postprocessing.py @@ -0,0 +1,533 @@ +"""Tests for extract_noise_params, predict_new_pool, check_convergence, +assign_dp_clusters, and structural model post-processing.""" + +import numpy as np +import pytest + +from quantammsim.noise_calibration import ( + extract_noise_params, + predict_new_pool, + check_convergence, + classify_token_tier, + _get_theta_samples, + K_COEFF, + COEFF_NAMES, +) +from quantammsim.noise_calibration.constants import K_OBS_COEFF, OBS_COEFF_NAMES + + +# =========================================================================== +# 
TestExtractNoiseParams +# =========================================================================== + + +class TestExtractNoiseParams: + def test_output_length(self, synthetic_samples, synthetic_encoded_data): + result = extract_noise_params(synthetic_samples, synthetic_encoded_data) + assert len(result) == synthetic_encoded_data["N_pools"] + + def test_weekend_absorption(self, synthetic_samples, synthetic_encoded_data): + """b_0_eff = b_0_raw + b_weekend * (2/7).""" + result = extract_noise_params(synthetic_samples, synthetic_encoded_data) + + theta = _get_theta_samples( + synthetic_samples, synthetic_encoded_data["X_pool"] + ) + theta_med = np.median(theta, axis=0) + + for i, p in enumerate(result): + b_0_raw = theta_med[i, 0] + b_weekend = theta_med[i, 3] + expected_b_0 = b_0_raw + b_weekend * (2.0 / 7.0) + np.testing.assert_allclose( + p["noise_params"]["b_0"], expected_b_0, atol=1e-10, + ) + + def test_noise_params_keys(self, synthetic_samples, synthetic_encoded_data): + result = extract_noise_params(synthetic_samples, synthetic_encoded_data) + expected_keys = {"b_0", "b_sigma", "b_c", "b_weekend", "base_fee"} + for p in result: + assert set(p["noise_params"].keys()) == expected_keys + + def test_theta_median_length(self, synthetic_samples, synthetic_encoded_data): + result = extract_noise_params(synthetic_samples, synthetic_encoded_data) + for p in result: + assert len(p["theta_median"]) == K_COEFF + + def test_b_c_equals_theta_1(self, synthetic_samples, synthetic_encoded_data): + result = extract_noise_params(synthetic_samples, synthetic_encoded_data) + for p in result: + np.testing.assert_allclose( + p["noise_params"]["b_c"], p["theta_median"][1], atol=1e-10, + ) + + def test_b_sigma_equals_theta_2( + self, synthetic_samples, synthetic_encoded_data + ): + result = extract_noise_params(synthetic_samples, synthetic_encoded_data) + for p in result: + np.testing.assert_allclose( + p["noise_params"]["b_sigma"], p["theta_median"][2], atol=1e-10, + ) + + def 
test_base_fee_matches_pool( + self, synthetic_samples, synthetic_encoded_data + ): + result = extract_noise_params(synthetic_samples, synthetic_encoded_data) + pool_meta = synthetic_encoded_data["pool_meta"] + for i, p in enumerate(result): + expected_fee = pool_meta.iloc[i]["swap_fee"] + np.testing.assert_allclose( + p["noise_params"]["base_fee"], expected_fee, atol=1e-10, + ) + + def test_pool_id_and_chain_preserved( + self, synthetic_samples, synthetic_encoded_data + ): + result = extract_noise_params(synthetic_samples, synthetic_encoded_data) + pool_ids = synthetic_encoded_data["pool_ids"] + pool_meta = synthetic_encoded_data["pool_meta"] + for i, p in enumerate(result): + assert p["pool_id"] == pool_ids[i] + assert p["chain"] == str(pool_meta.iloc[i]["chain"]) + + def test_use_median_false_uses_mean( + self, synthetic_samples, synthetic_encoded_data + ): + """use_median=False must produce different values than use_median=True.""" + result_med = extract_noise_params( + synthetic_samples, synthetic_encoded_data, use_median=True + ) + result_mean = extract_noise_params( + synthetic_samples, synthetic_encoded_data, use_median=False + ) + assert len(result_mean) == len(result_med) + + # With random B samples (S=10, seed 99), median != mean for at least + # one pool. Check that at least one theta_median value differs. 
+ any_differ = False + for pm, pn in zip(result_med, result_mean): + for tm, tn in zip(pm["theta_median"], pn["theta_median"]): + if not np.isclose(tm, tn, atol=1e-14): + any_differ = True + break + if any_differ: + break + assert any_differ, "median and mean paths produced identical values" + + +# =========================================================================== +# TestPredictNewPool +# =========================================================================== + + +class TestPredictNewPool: + def _build_z_new(self, data, chain, tokens, fee): + """Reconstruct z_new the same way predict_new_pool does internally.""" + col_names = data["covariate_names"] + z_new = np.zeros(len(col_names), dtype=np.float64) + + tiers = sorted([classify_token_tier(t) for t in tokens]) + tier_a = str(tiers[0]) + tier_b = str(tiers[1]) if len(tiers) > 1 else tier_a + + for i, name in enumerate(col_names): + if name == "intercept": + z_new[i] = 1.0 + elif name == "log_fee": + z_new[i] = np.log(max(fee, 1e-6)) + elif name == f"chain_{chain}": + z_new[i] = 1.0 + elif name == f"tier_A_{tier_a}": + z_new[i] = 1.0 + elif name == f"tier_B_{tier_b}": + z_new[i] = 1.0 + return z_new + + def test_known_chain_sets_dummy( + self, synthetic_samples, synthetic_encoded_data + ): + """MAINNET dummy must be 1 and must affect the prediction vs ARBITRUM.""" + data = synthetic_encoded_data + + result_arb = predict_new_pool( + synthetic_samples, data, + chain="ARBITRUM", tokens=["WETH", "USDC"], fee=0.003, + ) + result_main = predict_new_pool( + synthetic_samples, data, + chain="MAINNET", tokens=["WETH", "USDC"], fee=0.003, + ) + # ARBITRUM is the reference chain (alphabetically first: ARBITRUM < BASE < MAINNET), + # so all of its chain dummies are 0, while MAINNET sets its chain_MAINNET dummy. + # The two pools should therefore produce different mu. + # If chain dummies are ignored, these would be identical. 
+ arb_b0 = result_arb["noise_params"]["b_0"] + main_b0 = result_main["noise_params"]["b_0"] + assert not np.isclose(arb_b0, main_b0, atol=1e-10), ( + "Different chains should produce different predictions" + ) + + def test_tier_assignment_affects_prediction( + self, synthetic_samples, synthetic_encoded_data + ): + """WETH/RATS (tier 0,2) vs WETH/USDC (tier 0,0) must differ.""" + data = synthetic_encoded_data + + result_rats = predict_new_pool( + synthetic_samples, data, + chain="BASE", tokens=["WETH", "RATS"], fee=0.005, + ) + result_usdc = predict_new_pool( + synthetic_samples, data, + chain="BASE", tokens=["WETH", "USDC"], fee=0.005, + ) + # Different tier_B dummies should give different mu + rats_b0 = result_rats["noise_params"]["b_0"] + usdc_b0 = result_usdc["noise_params"]["b_0"] + assert not np.isclose(rats_b0, usdc_b0, atol=1e-10), ( + "Different tier assignments should produce different predictions" + ) + + def test_weekend_absorption_arithmetic( + self, synthetic_samples, synthetic_encoded_data + ): + """b_0 in noise_params must equal mu_median[0] + mu_median[3] * (2/7).""" + data = synthetic_encoded_data + result = predict_new_pool( + synthetic_samples, data, + chain="MAINNET", tokens=["WETH", "USDC"], fee=0.003, + ) + # Reconstruct mu_median independently + z_new = self._build_z_new(data, "MAINNET", ["WETH", "USDC"], 0.003) + B = synthetic_samples["B"] + mu_samples = np.einsum("skd,d->sk", B, z_new) + mu_median = np.median(mu_samples, axis=0) + + b_0_raw = mu_median[0] + b_weekend = mu_median[3] + expected_b_0 = b_0_raw + b_weekend * (2.0 / 7.0) + np.testing.assert_allclose( + result["noise_params"]["b_0"], expected_b_0, atol=1e-10, + ) + + def test_mu_equals_b_at_z_new( + self, synthetic_samples, synthetic_encoded_data + ): + """With known B samples, verify credible interval medians = median(B @ z_new).""" + data = synthetic_encoded_data + z_new = self._build_z_new(data, "ARBITRUM", ["WETH", "RATS"], 0.005) + + B = synthetic_samples["B"] + mu_expected = 
np.einsum("skd,d->sk", B, z_new) + mu_median_expected = np.median(mu_expected, axis=0) + + result = predict_new_pool( + synthetic_samples, data, + chain="ARBITRUM", tokens=["WETH", "RATS"], fee=0.005, + ) + for k, name in enumerate(COEFF_NAMES): + np.testing.assert_allclose( + result["credible_intervals_90"][name]["median"], + mu_median_expected[k], + atol=1e-10, + ) + + def test_unseen_chain_uses_reference( + self, synthetic_samples, synthetic_encoded_data + ): + """A chain not in training data should get reference-chain prediction + (all chain dummies = 0), not raise an error.""" + data = synthetic_encoded_data + result = predict_new_pool( + synthetic_samples, data, + chain="SONIC", tokens=["WETH", "USDC"], fee=0.003, + ) + # Should be same as ARBITRUM (the reference chain, all dummies 0) + result_ref = predict_new_pool( + synthetic_samples, data, + chain="ARBITRUM", tokens=["WETH", "USDC"], fee=0.003, + ) + np.testing.assert_allclose( + result["noise_params"]["b_0"], + result_ref["noise_params"]["b_0"], + atol=1e-10, + ) + + +# =========================================================================== +# TestCheckConvergence +# =========================================================================== + + +class TestCheckConvergence: + def test_svi_returns_expected_keys(self): + losses = np.random.randn(1000).cumsum() + 5000 + result = check_convergence(losses, method="svi") + assert "final_elbo" in result + assert "elbo_last_100_std" in result + assert "elbo_last_100_mean" in result + + def test_svi_method_key(self): + losses = np.linspace(5000, 1000, 500) + result = check_convergence(losses, method="svi") + assert result["method"] == "svi" + + def test_svi_elbo_last_100_std_correct(self): + np.random.seed(42) + losses = np.random.randn(500) * 10 + 1000 + result = check_convergence(losses, method="svi") + expected_std = float(np.std(losses[-100:])) + np.testing.assert_allclose( + result["elbo_last_100_std"], expected_std, atol=1e-10, + ) + + def 
test_svi_final_elbo_is_last_loss(self): + losses = np.array([100.0, 50.0, 25.0, 12.5]) + result = check_convergence(losses, method="svi") + assert result["final_elbo"] == 12.5 + + +# =========================================================================== +# TestAssignDPClusters +# =========================================================================== + + +class TestAssignDPClusters: + @pytest.fixture() + def dp_samples_and_data(self, synthetic_encoded_data): + """Synthetic DP posterior samples with known cluster structure.""" + data = synthetic_encoded_data + N_pools = data["N_pools"] + K_cov = data["K_cov"] + K_clusters = 4 + S = 10 + + np.random.seed(99) + B = np.random.randn(S, K_COEFF, K_cov) * 0.5 + sigma_theta = np.ones((S, K_COEFF)) + L_Omega = np.tile(np.eye(K_COEFF), (S, 1, 1)) + eta = np.zeros((S, N_pools, K_COEFF)) + df = np.full((S,), 5.0) + + # Well-separated sigma_eps clusters + sigma_eps = np.tile([0.3, 0.8, 1.5, 2.5], (S, 1)) + # Stick-breaking weights: mostly on cluster 0 + v = np.tile([0.6, 0.3, 0.05], (S, 1)) + + samples = { + "B": B, "sigma_theta": sigma_theta, "L_Omega": L_Omega, + "eta": eta, "df": df, "sigma_eps": sigma_eps, "v": v, + } + data_with_k = dict(data) + data_with_k["K_clusters"] = K_clusters + return samples, data_with_k + + def test_returns_correct_length(self, dp_samples_and_data): + from quantammsim.noise_calibration.postprocessing import assign_dp_clusters + + samples, data = dp_samples_and_data + assignments = assign_dp_clusters(samples, data) + assert len(assignments) == data["N_pools"] + + def test_valid_cluster_indices(self, dp_samples_and_data): + from quantammsim.noise_calibration.postprocessing import assign_dp_clusters + + samples, data = dp_samples_and_data + assignments = assign_dp_clusters(samples, data) + K = data["K_clusters"] + assert all(0 <= a < K for a in assignments) + + def test_returns_integer_array(self, dp_samples_and_data): + from quantammsim.noise_calibration.postprocessing import 
assign_dp_clusters + + samples, data = dp_samples_and_data + assignments = assign_dp_clusters(samples, data) + assert assignments.dtype in (np.int32, np.int64, int) + + +# =========================================================================== +# TestExtractNoiseParamsDP +# =========================================================================== + + +class TestExtractNoiseParamsDP: + @pytest.fixture() + def dp_samples_and_data(self, synthetic_encoded_data): + """Same as above for extract_noise_params testing.""" + data = synthetic_encoded_data + N_pools = data["N_pools"] + K_cov = data["K_cov"] + S = 10 + + np.random.seed(99) + B = np.random.randn(S, K_COEFF, K_cov) * 0.5 + sigma_theta = np.ones((S, K_COEFF)) + L_Omega = np.tile(np.eye(K_COEFF), (S, 1, 1)) + eta = np.zeros((S, N_pools, K_COEFF)) + df = np.full((S,), 5.0) + sigma_eps = np.tile([0.3, 0.8, 1.5, 2.5], (S, 1)) + v = np.tile([0.6, 0.3, 0.05], (S, 1)) + + samples = { + "B": B, "sigma_theta": sigma_theta, "L_Omega": L_Omega, + "eta": eta, "df": df, "sigma_eps": sigma_eps, "v": v, + } + return samples, data + + def test_output_length(self, dp_samples_and_data): + samples, data = dp_samples_and_data + result = extract_noise_params(samples, data) + assert len(result) == data["N_pools"] + + def test_noise_params_keys_present(self, dp_samples_and_data): + samples, data = dp_samples_and_data + result = extract_noise_params(samples, data) + expected_keys = {"b_0", "b_sigma", "b_c", "b_weekend", "base_fee"} + for p in result: + assert set(p["noise_params"].keys()) == expected_keys + + +# =========================================================================== +# TestExtractStructuralParams +# =========================================================================== + + +class TestExtractStructuralParams: + """Tests for extract_structural_params().""" + + @pytest.fixture() + def structural_samples_and_data(self, synthetic_structural_data): + """Run a quick SVI fit to get structural samples.""" + from 
quantammsim.noise_calibration.inference import run_svi + from quantammsim.noise_calibration.model import structural_noise_model + + samples, _ = run_svi( + synthetic_structural_data, + num_steps=500, + lr=5e-3, + seed=0, + num_samples=10, + model_fn=structural_noise_model, + ) + return samples, synthetic_structural_data + + def test_extract_returns_arb_params(self, structural_samples_and_data): + from quantammsim.noise_calibration.postprocessing import ( + extract_structural_params, + ) + samples, data = structural_samples_and_data + result = extract_structural_params(samples, data) + assert len(result) == data["N_pools"] + for p in result: + assert "arb_frequency" in p + assert isinstance(p["arb_frequency"], int) + assert 1 <= p["arb_frequency"] <= 60 + + def test_extract_returns_noise_params(self, structural_samples_and_data): + from quantammsim.noise_calibration.postprocessing import ( + extract_structural_params, + ) + samples, data = structural_samples_and_data + result = extract_structural_params(samples, data) + for p in result: + assert "noise_params" in p + coeffs = p["noise_params"] + assert len(coeffs) == K_OBS_COEFF + for name in OBS_COEFF_NAMES: + assert name in coeffs, f"Missing coefficient: {name}" + + +# =========================================================================== +# TestPredictStructural +# =========================================================================== + + +class TestPredictStructural: + @pytest.fixture() + def structural_samples_and_data(self, synthetic_structural_data): + from quantammsim.noise_calibration.inference import run_svi + from quantammsim.noise_calibration.model import structural_noise_model + + samples, _ = run_svi( + synthetic_structural_data, + num_steps=500, + lr=5e-3, + seed=0, + num_samples=10, + model_fn=structural_noise_model, + ) + return samples, synthetic_structural_data + + def test_predict_returns_cadence(self, structural_samples_and_data): + from quantammsim.noise_calibration.postprocessing 
import ( + predict_new_pool_structural, + ) + samples, data = structural_samples_and_data + result = predict_new_pool_structural( + samples, data, + chain="MAINNET", tokens=["WETH", "USDC"], + fee=0.003, tvl_est=1e6, + ) + assert "arb_frequency" in result + assert isinstance(result["arb_frequency"], int) + assert 1 <= result["arb_frequency"] <= 60 + + def test_predict_returns_noise_coefficients( + self, structural_samples_and_data, + ): + from quantammsim.noise_calibration.postprocessing import ( + predict_new_pool_structural, + ) + samples, data = structural_samples_and_data + result = predict_new_pool_structural( + samples, data, + chain="MAINNET", tokens=["WETH", "USDC"], + fee=0.003, tvl_est=1e6, + ) + assert "noise_params" in result + assert len(result["noise_params"]) == K_OBS_COEFF + + def test_predict_uses_W_gate(self, structural_samples_and_data): + """Different (chain, tier) → different predictions.""" + from quantammsim.noise_calibration.postprocessing import ( + predict_new_pool_structural, + ) + samples, data = structural_samples_and_data + r1 = predict_new_pool_structural( + samples, data, + chain="MAINNET", tokens=["WETH", "USDC"], + fee=0.003, tvl_est=1e6, + ) + r2 = predict_new_pool_structural( + samples, data, + chain="ARBITRUM", tokens=["BAL", "WETH"], + fee=0.01, tvl_est=5e5, + ) + # Different pool characteristics should produce different noise params + assert r1["noise_params"] != r2["noise_params"] + + def test_predict_cadence_higher_for_longtail( + self, structural_samples_and_data, + ): + """Long-tail pools should have higher cadence (less efficient arb).""" + from quantammsim.noise_calibration.postprocessing import ( + predict_new_pool_structural, + ) + samples, data = structural_samples_and_data + # This tests the structural relationship — it may not hold with + # random SVI samples on synthetic data, so we just check it doesn't + # crash and returns valid values + r1 = predict_new_pool_structural( + samples, data, + chain="MAINNET", 
tokens=["WETH", "USDC"], + fee=0.003, tvl_est=1e6, + ) + r2 = predict_new_pool_structural( + samples, data, + chain="BASE", tokens=["RATS", "WETH"], + fee=0.005, tvl_est=1e5, + ) + # Both should be valid + assert 1 <= r1["arb_frequency"] <= 60 + assert 1 <= r2["arb_frequency"] <= 60 diff --git a/tests/noise/test_token_classification.py b/tests/noise/test_token_classification.py new file mode 100644 index 0000000..dd540f8 --- /dev/null +++ b/tests/noise/test_token_classification.py @@ -0,0 +1,66 @@ +"""Tests for _normalise_symbol and classify_token_tier.""" + +import pytest +from quantammsim.noise_calibration import _normalise_symbol, classify_token_tier + + +# =========================================================================== +# TestNormaliseSymbol +# =========================================================================== + + +class TestNormaliseSymbol: + def test_passthrough_unknown(self): + assert _normalise_symbol("FOO") == "FOO" + + def test_known_mapping_preserved(self): + assert _normalise_symbol("WETH") == "WETH" + assert _normalise_symbol("WBTC") == "WBTC" + assert _normalise_symbol("cbBTC") == "cbBTC" + + def test_whitespace_stripped(self): + assert _normalise_symbol(" ETH ") == "ETH" + assert _normalise_symbol(" WETH ") == "WETH" + + def test_case_sensitivity_preserved(self): + # Lowercase is NOT normalised to uppercase + assert _normalise_symbol("weth") == "weth" + + +# =========================================================================== +# TestClassifyTokenTier +# =========================================================================== + + +class TestClassifyTokenTier: + def test_tier0_native_tokens(self): + for sym in ["ETH", "BTC"]: + assert classify_token_tier(sym) == 0, f"{sym} should be tier 0" + + def test_tier0_wrapped(self): + for sym in ["WETH", "WBTC", "cbBTC"]: + assert classify_token_tier(sym) == 0, f"{sym} should be tier 0" + + def test_tier0_stablecoins(self): + for sym in ["USDC", "USDT", "DAI"]: + assert 
classify_token_tier(sym) == 0, f"{sym} should be tier 0" + + def test_tier0_chain_natives(self): + for sym in ["MATIC", "AVAX", "GNO", "S", "wS"]: + assert classify_token_tier(sym) == 0, f"{sym} should be tier 0" + + def test_tier1_defi_bluechips(self): + for sym in ["AAVE", "BAL", "COW", "LINK", "ARB"]: + assert classify_token_tier(sym) == 1, f"{sym} should be tier 1" + + def test_tier2_unknown_tokens(self): + for sym in ["RATS", "PEPE"]: + assert classify_token_tier(sym) == 2, f"{sym} should be tier 2" + + def test_tier2_empty_string(self): + assert classify_token_tier("") == 2 + + def test_wrapped_variant_normalisation(self): + # "wS" is in the mapping table AND in _TIER_0 + assert classify_token_tier("wS") == 0 + assert classify_token_tier(" wS ") == 0 From 9812099cf50fca3c97c2bb5130857ba4edf753c8 Mon Sep 17 00:00:00 2001 From: christian harrington Date: Fri, 6 Mar 2026 16:56:45 +0000 Subject: [PATCH 08/14] port over and combine with dynamic obj noise modelling approach --- quantammsim/pools/reCLAMM/reclamm.py | 190 +++++++++- quantammsim/pools/reCLAMM/reclamm_reserves.py | 329 +++++++++++++++--- 2 files changed, 460 insertions(+), 59 deletions(-) diff --git a/quantammsim/pools/reCLAMM/reclamm.py b/quantammsim/pools/reCLAMM/reclamm.py index fbfecab..aa57d3e 100644 --- a/quantammsim/pools/reCLAMM/reclamm.py +++ b/quantammsim/pools/reCLAMM/reclamm.py @@ -10,7 +10,7 @@ config.update("jax_enable_x64", True) import jax.numpy as jnp -from jax import jit, tree_util +from jax import jit, tree_util, vmap from jax.lax import dynamic_slice from functools import partial from typing import Dict, Any, Optional, NamedTuple @@ -34,6 +34,73 @@ SHIFT_EXPONENT_DIVISOR = 124649.0 +def _prepare_dynamic_array(arr, start_index, bout_length, arb_frequency, max_len): + """Slice and decimate a dynamic input array to match arb_prices shape.""" + arr = jnp.asarray(arr) + if arr.ndim == 0: + return jnp.full((max_len,), arr, dtype=arr.dtype) + if arr.shape[0] <= 1: + return 
jnp.broadcast_to(arr, (max_len,) + arr.shape[1:]) + + start = (start_index[0],) + (0,) * (arr.ndim - 1) + slice_sizes = (bout_length - 1,) + arr.shape[1:] + sliced = dynamic_slice(arr, start, slice_sizes) + if arb_frequency != 1: + sliced = sliced[::arb_frequency] + return sliced + + +def _align_prices_to_numeraire(prices, run_fingerprint): + """Ensure the numeraire token is in column 1 for ratio-volatility calc.""" + tokens = run_fingerprint.get("tokens") + numeraire = run_fingerprint.get("numeraire") + if tokens is None or numeraire is None or len(tokens) != 2: + return prices + + token_labels = [str(token).lower() for token in tokens] + numeraire_label = str(numeraire).lower() + if token_labels[0] == numeraire_label: + return prices[:, ::-1] + return prices + + +def _calculate_annualized_ratio_volatility( + prices, run_fingerprint, subsample_freq=5, +): + """Annualized daily realized volatility broadcast to minute-level array.""" + ordered_prices = _align_prices_to_numeraire(prices, run_fingerprint) + asset_prices = ordered_prices[:, 0] / ordered_prices[:, 1] + n_minutes = asset_prices.shape[0] + + if n_minutes < 1440: + return jnp.full((n_minutes,), 0.1 * jnp.sqrt(365.0), dtype=prices.dtype) + + n_days = n_minutes // 1440 + + def calculate_daily_volatility(day_idx): + start_idx = day_idx * 1440 + window_prices = dynamic_slice(asset_prices, (start_idx,), (1440,)) + subsampled_prices = window_prices[::subsample_freq] + log_prices = jnp.log(jnp.maximum(subsampled_prices, 1e-8)) + returns = jnp.diff(log_prices) + num_nonzero_returns = jnp.sum(returns != 0) + total_returns = jnp.maximum(returns.shape[0], 1) + adjusted_variance = num_nonzero_returns * jnp.var(returns) / total_returns + dt = subsample_freq / 1440 + return jnp.sqrt(adjusted_variance) / jnp.sqrt(dt) + + daily_volatilities = vmap(calculate_daily_volatility)(jnp.arange(n_days)) + volatility_array = jnp.repeat(daily_volatilities, 1440) + + remaining_minutes = n_minutes - volatility_array.shape[0] + if 
remaining_minutes > 0: + volatility_array = jnp.concatenate( + [volatility_array, jnp.full((remaining_minutes,), daily_volatilities[-1])] + ) + + return volatility_array * jnp.sqrt(365.0) + + class _PoolState(NamedTuple): """Intermediate state produced by _init_pool_state. @@ -203,6 +270,48 @@ def _resolve_ste_temperature(run_fingerprint): """Resolve STE gate temperature for differentiable reCLAMM transitions.""" return run_fingerprint.get("ste_temperature") + def _resolve_noise_inputs( + self, + run_fingerprint: Dict[str, Any], + prices: jnp.ndarray, + start_index: jnp.ndarray, + arb_len: int, + lp_supply_array: Optional[jnp.ndarray] = None, + ): + """Prepare optional lp-supply and noise-model inputs for reserve scans.""" + bout_length = run_fingerprint["bout_length"] + arb_freq = run_fingerprint["arb_frequency"] + + lp_prepared = None + if lp_supply_array is not None: + lp_prepared = _prepare_dynamic_array( + lp_supply_array, + start_index=start_index, + bout_length=bout_length, + arb_frequency=arb_freq, + max_len=arb_len, + ) + + noise_model = run_fingerprint.get("noise_model", "ratio") + noise_params = run_fingerprint.get("reclamm_noise_params", None) + if noise_params is not None and type(noise_params) is not dict: + noise_params = dict(noise_params) + + arb_vol = None + if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + volatility_array = _calculate_annualized_ratio_volatility( + prices, run_fingerprint + ) + arb_vol = _prepare_dynamic_array( + volatility_array, + start_index=start_index, + bout_length=bout_length, + arb_frequency=arb_freq, + max_len=arb_len, + ) + + return lp_prepared, noise_model, noise_params, arb_vol + @partial(jit, static_argnums=(2,)) def calculate_reserves_with_fees( self, @@ -211,9 +320,17 @@ def calculate_reserves_with_fees( prices: jnp.ndarray, start_index: jnp.ndarray, additional_oracle_input: Optional[jnp.ndarray] = None, + lp_supply_array: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: s = 
self._init_pool_state(params, run_fingerprint, prices, start_index) ste_temperature = self._resolve_ste_temperature(run_fingerprint) + lp_prepared, noise_model, noise_params, arb_vol = self._resolve_noise_inputs( + run_fingerprint, + prices, + start_index, + s.arb_prices.shape[0], + lp_supply_array=lp_supply_array, + ) if run_fingerprint["do_arb"]: return _jax_calc_reclamm_reserves_with_fees( @@ -232,6 +349,11 @@ def calculate_reserves_with_fees( centeredness_scaling=s.centeredness_scaling, protocol_fee_split=run_fingerprint.get("protocol_fee_split", 0.0), ste_temperature=ste_temperature, + noise_trader_ratio=run_fingerprint.get("noise_trader_ratio", 0.0), + lp_supply_array=lp_prepared, + noise_model=noise_model, + noise_params=noise_params, + volatility_array=arb_vol, ) return jnp.broadcast_to(s.initial_reserves, s.arb_prices.shape) @@ -243,6 +365,7 @@ def calculate_reserves_and_fee_revenue_with_fees( prices: jnp.ndarray, start_index: jnp.ndarray, additional_oracle_input: Optional[jnp.ndarray] = None, + lp_supply_array: Optional[jnp.ndarray] = None, ): """Calculate reserves and LP fee revenue with fees. 
@@ -254,6 +377,13 @@ def calculate_reserves_and_fee_revenue_with_fees( """ s = self._init_pool_state(params, run_fingerprint, prices, start_index) ste_temperature = self._resolve_ste_temperature(run_fingerprint) + lp_prepared, noise_model, noise_params, arb_vol = self._resolve_noise_inputs( + run_fingerprint, + prices, + start_index, + s.arb_prices.shape[0], + lp_supply_array=lp_supply_array, + ) if run_fingerprint["do_arb"]: return _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( @@ -272,6 +402,11 @@ def calculate_reserves_and_fee_revenue_with_fees( centeredness_scaling=s.centeredness_scaling, protocol_fee_split=run_fingerprint.get("protocol_fee_split", 0.0), ste_temperature=ste_temperature, + noise_trader_ratio=run_fingerprint.get("noise_trader_ratio", 0.0), + lp_supply_array=lp_prepared, + noise_model=noise_model, + noise_params=noise_params, + volatility_array=arb_vol, ) return ( jnp.broadcast_to(s.initial_reserves, s.arb_prices.shape), @@ -298,10 +433,7 @@ def calculate_reserves_and_fee_revenue_with_dynamic_inputs( """ s = self._init_pool_state(params, run_fingerprint, prices, start_index) ste_temperature = self._resolve_ste_temperature(run_fingerprint) - bout_length = run_fingerprint["bout_length"] - max_len = bout_length - 1 - if run_fingerprint["arb_frequency"] != 1: - max_len = max_len // run_fingerprint["arb_frequency"] + max_len = s.arb_prices.shape[0] materialized_inputs = materialize_dynamic_inputs( dynamic_inputs, run_fingerprint.get("dynamic_input_flags"), @@ -310,6 +442,13 @@ def calculate_reserves_and_fee_revenue_with_dynamic_inputs( do_trades=False, dtype=s.arb_prices.dtype, ) + _, noise_model, noise_params, arb_vol = self._resolve_noise_inputs( + run_fingerprint, + prices, + start_index, + s.arb_prices.shape[0], + lp_supply_array=None, + ) return _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( s.initial_reserves, s.Va, s.Vb, @@ -321,6 +460,7 @@ def calculate_reserves_and_fee_revenue_with_dynamic_inputs( 
arb_thresh=materialized_inputs.gas_cost, arb_fees=materialized_inputs.arb_fees, price_ratio_updates=materialized_inputs.reclamm_price_ratio_updates, + lp_supply_array=materialized_inputs.lp_supply, all_sig_variations=jnp.array( run_fingerprint["all_sig_variations"] ), @@ -328,6 +468,10 @@ def calculate_reserves_and_fee_revenue_with_dynamic_inputs( centeredness_scaling=s.centeredness_scaling, protocol_fee_split=run_fingerprint.get("protocol_fee_split", 0.0), ste_temperature=ste_temperature, + noise_trader_ratio=run_fingerprint.get("noise_trader_ratio", 0.0), + noise_model=noise_model, + noise_params=noise_params, + volatility_array=arb_vol, ) @partial(jit, static_argnums=(2,)) @@ -338,10 +482,20 @@ def _calculate_reserves_zero_fees( prices: jnp.ndarray, start_index: jnp.ndarray, additional_oracle_input: Optional[jnp.ndarray] = None, + lp_supply_array: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: """Protected zero-fee implementation for hooks and weight calculation.""" s = self._init_pool_state(params, run_fingerprint, prices, start_index) ste_temperature = self._resolve_ste_temperature(run_fingerprint) + lp_prepared = None + if lp_supply_array is not None: + lp_prepared = _prepare_dynamic_array( + lp_supply_array, + start_index=start_index, + bout_length=run_fingerprint["bout_length"], + arb_frequency=run_fingerprint["arb_frequency"], + max_len=s.arb_prices.shape[0], + ) if run_fingerprint["do_arb"]: return _jax_calc_reclamm_reserves_zero_fees( @@ -353,6 +507,7 @@ def _calculate_reserves_zero_fees( arc_length_speed=s.arc_length_speed, centeredness_scaling=s.centeredness_scaling, ste_temperature=ste_temperature, + lp_supply_array=lp_prepared, ) return jnp.broadcast_to(s.initial_reserves, s.arb_prices.shape) @@ -363,9 +518,15 @@ def calculate_reserves_zero_fees( prices: jnp.ndarray, start_index: jnp.ndarray, additional_oracle_input: Optional[jnp.ndarray] = None, + lp_supply_array: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: return 
self._calculate_reserves_zero_fees( - params, run_fingerprint, prices, start_index, additional_oracle_input + params, + run_fingerprint, + prices, + start_index, + additional_oracle_input, + lp_supply_array, ) @partial(jit, static_argnums=(2,)) @@ -380,10 +541,7 @@ def calculate_reserves_with_dynamic_inputs( ) -> jnp.ndarray: s = self._init_pool_state(params, run_fingerprint, prices, start_index) ste_temperature = self._resolve_ste_temperature(run_fingerprint) - bout_length = run_fingerprint["bout_length"] - max_len = bout_length - 1 - if run_fingerprint["arb_frequency"] != 1: - max_len = max_len // run_fingerprint["arb_frequency"] + max_len = s.arb_prices.shape[0] materialized_inputs = materialize_dynamic_inputs( dynamic_inputs, run_fingerprint.get("dynamic_input_flags"), @@ -392,6 +550,13 @@ def calculate_reserves_with_dynamic_inputs( do_trades=False, dtype=s.arb_prices.dtype, ) + _, noise_model, noise_params, arb_vol = self._resolve_noise_inputs( + run_fingerprint, + prices, + start_index, + s.arb_prices.shape[0], + lp_supply_array=None, + ) return _jax_calc_reclamm_reserves_with_dynamic_inputs( s.initial_reserves, s.Va, s.Vb, @@ -403,6 +568,7 @@ def calculate_reserves_with_dynamic_inputs( arb_thresh=materialized_inputs.gas_cost, arb_fees=materialized_inputs.arb_fees, price_ratio_updates=materialized_inputs.reclamm_price_ratio_updates, + lp_supply_array=materialized_inputs.lp_supply, all_sig_variations=jnp.array( run_fingerprint["all_sig_variations"] ), @@ -410,6 +576,10 @@ def calculate_reserves_with_dynamic_inputs( centeredness_scaling=s.centeredness_scaling, protocol_fee_split=run_fingerprint.get("protocol_fee_split", 0.0), ste_temperature=ste_temperature, + noise_trader_ratio=run_fingerprint.get("noise_trader_ratio", 0.0), + noise_model=noise_model, + noise_params=noise_params, + volatility_array=arb_vol, ) def init_base_parameters( diff --git a/quantammsim/pools/reCLAMM/reclamm_reserves.py b/quantammsim/pools/reCLAMM/reclamm_reserves.py index 
7e0bab8..3a0e89a 100644 --- a/quantammsim/pools/reCLAMM/reclamm_reserves.py +++ b/quantammsim/pools/reCLAMM/reclamm_reserves.py @@ -31,6 +31,12 @@ from quantammsim.pools.G3M.G3M_trades import ( _jax_calc_G3M_trade_from_exact_in_given_out, ) +from quantammsim.pools.noise_trades import ( + calculate_reserves_after_noise_trade, + reclamm_tsoukalas_sqrt_noise_volume, + reclamm_tsoukalas_log_noise_volume, + reclamm_loglinear_noise_volume, +) # Reference balance for initialisation (matches Solidity _INITIALIZATION_MAX_BALANCE_A) _INITIALIZATION_MAX_BALANCE_A = 1e6 @@ -622,7 +628,7 @@ def apply_target_price_ratio_to_virtual_balances(Ra, Rb, Va, Vb, target_price_ra def _reclamm_scan_step_zero_fees( carry_list, - prices, + input_list, centeredness_margin, daily_price_shift_base, seconds_per_step, @@ -636,11 +642,23 @@ def _reclamm_scan_step_zero_fees( 1. Update virtual balances (path-dependent) 2. Compute analytical constant-product arb (no fee friction) - Carry: [real_reserves (2,), Va (0-d), Vb (0-d)] + Carry: [real_reserves (2,), Va (0-d), Vb (0-d), prev_lp_supply (0-d)] + Input: [prices (2,), lp_supply (0-d)] """ prev_reserves = carry_list[0] Va = carry_list[1] Vb = carry_list[2] + prev_lp_supply = carry_list[3] + + prices = input_list[0] + lp_supply = input_list[1] + + # Scale both real and virtual reserves by LP supply ratio. 
+ scale = lp_supply / prev_lp_supply + lp_supply_change = lp_supply != prev_lp_supply + prev_reserves = jnp.where(lp_supply_change, prev_reserves * scale, prev_reserves) + Va = jnp.where(lp_supply_change, Va * scale, Va) + Vb = jnp.where(lp_supply_change, Vb * scale, Vb) Ra = prev_reserves[0] Rb = prev_reserves[1] @@ -717,7 +735,7 @@ def _reclamm_scan_step_zero_fees( Rb_new = jnp.where(clamp_a, Rb + edge_a[1], jnp.where(clamp_b, Rb + edge_b[1], Rb_new)) new_reserves = jnp.array([Ra_new, Rb_new]) - return [new_reserves, Va, Vb], new_reserves + return [new_reserves, Va, Vb, lp_supply], new_reserves # --------------------------------------------------------------------------- @@ -729,7 +747,7 @@ def _reclamm_scan_step_zero_fees( def _reclamm_scan_step_zero_fees_full_state( carry_list, - prices, + input_list, centeredness_margin, daily_price_shift_base, seconds_per_step, @@ -739,7 +757,7 @@ def _reclamm_scan_step_zero_fees_full_state( ): """TEST-ONLY: scan step that outputs (reserves, Va, Vb).""" new_carry, new_reserves = _reclamm_scan_step_zero_fees( - carry_list, prices, centeredness_margin, daily_price_shift_base, seconds_per_step, + carry_list, input_list, centeredness_margin, daily_price_shift_base, seconds_per_step, arc_length_speed=arc_length_speed, centeredness_scaling=centeredness_scaling, ste_temperature=ste_temperature, @@ -761,15 +779,20 @@ def _reclamm_scan_step_with_fees_and_revenue( centeredness_scaling=False, protocol_fee_split=0.0, ste_temperature=10.0, + noise_trader_ratio=0.0, + noise_model="ratio", + noise_params=None, ): """Single scan step for reClAMM pool with fees, returning LP fee revenue. Primary implementation — ``_reclamm_scan_step_with_fees`` wraps this. 
Carry: [real_reserves (2,), Va, Vb, step_idx, active_start_ratio, - active_target_ratio, active_start_step, active_end_step, active_enabled] + active_target_ratio, active_start_step, active_end_step, active_enabled, + prev_lp_supply] Input: [prices, active_initial_weights, per_asset_ratios, - all_other_assets_ratios, gamma, arb_thresh, arb_fees, price_ratio_update] + all_other_assets_ratios, gamma, arb_thresh, arb_fees, price_ratio_update, + lp_supply] Returns ------- @@ -786,9 +809,7 @@ def _reclamm_scan_step_with_fees_and_revenue( active_start_step = carry_list[6] active_end_step = carry_list[7] active_enabled = carry_list[8] - - Ra = prev_reserves[0] - Rb = prev_reserves[1] + prev_lp_supply = carry_list[9] prices = input_list[0] active_initial_weights = input_list[1] @@ -798,6 +819,18 @@ def _reclamm_scan_step_with_fees_and_revenue( arb_thresh = input_list[5] arb_fees = input_list[6] price_ratio_update = input_list[7] + lp_supply = input_list[8] + + # Scale both real and virtual reserves by LP supply ratio so liquidity + # add/remove events preserve proportional pool state. + scale = lp_supply / prev_lp_supply + lp_supply_change = lp_supply != prev_lp_supply + prev_reserves = jnp.where(lp_supply_change, prev_reserves * scale, prev_reserves) + Va = jnp.where(lp_supply_change, Va * scale, Va) + Vb = jnp.where(lp_supply_change, Vb * scale, Vb) + + Ra = prev_reserves[0] + Rb = prev_reserves[1] event_has = price_ratio_update[0] > 0.5 event_target_ratio = jnp.maximum( @@ -973,6 +1006,39 @@ def _skip_schedule_state(_): Ra_new = Ra + applied_trade[0] Rb_new = Rb + applied_trade[1] + # Optional noise-trader model. 
+ if noise_model == "ratio": + noisy_reserves = calculate_reserves_after_noise_trade( + applied_trade, jnp.array([Ra_new, Rb_new]), prices, + noise_trader_ratio, gamma, + ) + Ra_new = jnp.where(noise_trader_ratio > 0, noisy_reserves[0], Ra_new) + Rb_new = jnp.where(noise_trader_ratio > 0, noisy_reserves[1], Rb_new) + elif noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + volatility = input_list[9] + arb_volume = 0.5 * jnp.sum(jnp.abs(applied_trade) * prices) + real_value = jnp.sum(jnp.array([Ra_new, Rb_new]) * prices) + effective_value = (Ra_new + Va) * prices[0] + (Rb_new + Vb) * prices[1] + + _np = noise_params if noise_params is not None else {} + if noise_model == "tsoukalas_sqrt": + noise_vol = reclamm_tsoukalas_sqrt_noise_volume( + effective_value, gamma, volatility, arb_volume, _np + ) + elif noise_model == "tsoukalas_log": + noise_vol = reclamm_tsoukalas_log_noise_volume( + effective_value, gamma, volatility, arb_volume, _np + ) + else: + noise_vol = reclamm_loglinear_noise_volume( + effective_value, gamma, volatility, arb_volume, _np + ) + + noise_fee_income = (1.0 - gamma) * noise_vol + noise_scale = 1.0 + noise_fee_income / jnp.maximum(real_value, 1e-8) + Ra_new = Ra_new * noise_scale + Rb_new = Rb_new * noise_scale + # Clamp-to-edge: if a real reserve would go negative, apply an # exact-in-given-out edge trade that drains that token to _DUST_USD # worth of reserves (preserving the AMM invariant). @@ -1021,6 +1087,7 @@ def _skip_schedule_state(_): active_start_step, active_end_step, active_enabled, + lp_supply, ], (new_reserves, lp_fee_revenue_usd) @@ -1038,6 +1105,9 @@ def _reclamm_scan_step_with_fees( centeredness_scaling=False, protocol_fee_split=0.0, ste_temperature=10.0, + noise_trader_ratio=0.0, + noise_model="ratio", + noise_params=None, ): """Single scan step for reClAMM pool with fees (reserves only). 
@@ -1057,6 +1127,9 @@ def _reclamm_scan_step_with_fees( centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, ste_temperature=ste_temperature, + noise_trader_ratio=noise_trader_ratio, + noise_model=noise_model, + noise_params=noise_params, ) return new_carry, new_reserves @@ -1075,6 +1148,9 @@ def _reclamm_scan_step_with_fees_full_state( centeredness_scaling=False, protocol_fee_split=0.0, ste_temperature=10.0, + noise_trader_ratio=0.0, + noise_model="ratio", + noise_params=None, ): """TEST-ONLY: fee scan step that also outputs virtual balances.""" new_carry, (new_reserves, _fee_rev) = _reclamm_scan_step_with_fees_and_revenue( @@ -1090,6 +1166,9 @@ def _reclamm_scan_step_with_fees_full_state( centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, ste_temperature=ste_temperature, + noise_trader_ratio=noise_trader_ratio, + noise_model=noise_model, + noise_params=noise_params, ) return new_carry, (new_reserves, new_carry[1], new_carry[2]) @@ -1106,6 +1185,7 @@ def _jax_calc_reclamm_reserves_zero_fees( arc_length_speed=0.0, centeredness_scaling=False, ste_temperature=10.0, + lp_supply_array=None, ): """Calculate reClAMM reserves over time with zero fees. @@ -1127,12 +1207,22 @@ def _jax_calc_reclamm_reserves_zero_fees( If > 0, use constant-arc-length thermostat instead of geometric. centeredness_scaling : bool If True, scale speed by margin/centeredness (proportional controller). + lp_supply_array : jnp.ndarray, optional + LP token supply over time, shape (T,). Defaults to constant 1.0. Returns ------- reserves : jnp.ndarray, shape (T, 2) Real reserves over time. 
""" + if lp_supply_array is None: + lp_supply_array = jnp.array(1.0) + lp_supply_array = jnp.where( + lp_supply_array.size == 1, + jnp.full(prices.shape[0], lp_supply_array), + lp_supply_array, + ) + scan_fn = Partial( _reclamm_scan_step_zero_fees, centeredness_margin=centeredness_margin, @@ -1143,8 +1233,8 @@ def _jax_calc_reclamm_reserves_zero_fees( ste_temperature=ste_temperature, ) - carry_init = [initial_reserves, initial_Va, initial_Vb] - _, reserves = scan(scan_fn, carry_init, prices) + carry_init = [initial_reserves, initial_Va, initial_Vb, lp_supply_array[0]] + _, reserves = scan(scan_fn, carry_init, [prices, lp_supply_array]) return reserves @@ -1160,6 +1250,7 @@ def _jax_calc_reclamm_reserves_zero_fees_full_state( arc_length_speed=0.0, centeredness_scaling=False, ste_temperature=10.0, + lp_supply_array=None, ): """TEST-ONLY: Like _jax_calc_reclamm_reserves_zero_fees but returns Va/Vb. @@ -1169,6 +1260,14 @@ def _jax_calc_reclamm_reserves_zero_fees_full_state( Va_history : jnp.ndarray, shape (T,) Vb_history : jnp.ndarray, shape (T,) """ + if lp_supply_array is None: + lp_supply_array = jnp.array(1.0) + lp_supply_array = jnp.where( + lp_supply_array.size == 1, + jnp.full(prices.shape[0], lp_supply_array), + lp_supply_array, + ) + scan_fn = Partial( _reclamm_scan_step_zero_fees_full_state, centeredness_margin=centeredness_margin, @@ -1179,12 +1278,14 @@ def _jax_calc_reclamm_reserves_zero_fees_full_state( ste_temperature=ste_temperature, ) - carry_init = [initial_reserves, initial_Va, initial_Vb] - _, (reserves, Va_history, Vb_history) = scan(scan_fn, carry_init, prices) + carry_init = [initial_reserves, initial_Va, initial_Vb, lp_supply_array[0]] + _, (reserves, Va_history, Vb_history) = scan( + scan_fn, carry_init, [prices, lp_supply_array] + ) return reserves, Va_history, Vb_history -@jit +@partial(jit, static_argnames=("noise_model",)) def _jax_calc_reclamm_reserves_with_fees( initial_reserves, initial_Va, @@ -1201,12 +1302,25 @@ def 
_jax_calc_reclamm_reserves_with_fees( centeredness_scaling=False, protocol_fee_split=0.0, ste_temperature=10.0, + noise_trader_ratio=0.0, + lp_supply_array=None, + noise_model="ratio", + noise_params=None, + volatility_array=None, ): """Calculate reClAMM reserves over time with fees. Uses the G3M optimal arb machinery with constant weights [0.5, 0.5] applied to effective reserves (real + virtual). """ + if lp_supply_array is None: + lp_supply_array = jnp.array(1.0) + lp_supply_array = jnp.where( + lp_supply_array.size == 1, + jnp.full(prices.shape[0], lp_supply_array), + lp_supply_array, + ) + n_assets = 2 weights = jnp.array([0.5, 0.5]) gamma = 1.0 - fees @@ -1242,6 +1356,9 @@ def _jax_calc_reclamm_reserves_with_fees( centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, ste_temperature=ste_temperature, + noise_trader_ratio=noise_trader_ratio, + noise_model=noise_model, + noise_params=noise_params if noise_params is not None else {}, ) carry_init = [ @@ -1254,17 +1371,27 @@ def _jax_calc_reclamm_reserves_with_fees( jnp.float64(0.0), # active_start_step jnp.float64(0.0), # active_end_step jnp.array(False), # active_enabled + lp_supply_array[0], # prev_lp_supply ] - _, reserves = scan( - scan_fn, - carry_init, - [prices, active_initial_weights, per_asset_ratios, - all_other_assets_ratios, gamma_array, arb_thresh_array, arb_fees_array, price_ratio_updates], - ) + scan_inputs = [ + prices, + active_initial_weights, + per_asset_ratios, + all_other_assets_ratios, + gamma_array, + arb_thresh_array, + arb_fees_array, + price_ratio_updates, + lp_supply_array, + ] + if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + scan_inputs.append(volatility_array) + + _, reserves = scan(scan_fn, carry_init, scan_inputs) return reserves -@partial(jit, static_argnums=(11,)) +@partial(jit, static_argnums=(11,), static_argnames=("noise_model",)) def _jax_calc_reclamm_reserves_with_dynamic_inputs( initial_reserves, initial_Va, @@ -1284,8 
+1411,21 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs( centeredness_scaling=False, protocol_fee_split=0.0, ste_temperature=10.0, + noise_trader_ratio=0.0, + lp_supply_array=None, + noise_model="ratio", + noise_params=None, + volatility_array=None, ): """Calculate reClAMM reserves with time-varying fees/arb arrays.""" + if lp_supply_array is None: + lp_supply_array = jnp.array(1.0) + lp_supply_array = jnp.where( + lp_supply_array.size == 1, + jnp.full(prices.shape[0], lp_supply_array), + lp_supply_array, + ) + n_assets = 2 weights = jnp.array([0.5, 0.5]) @@ -1334,6 +1474,9 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs( centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, ste_temperature=ste_temperature, + noise_trader_ratio=noise_trader_ratio, + noise_model=noise_model, + noise_params=noise_params if noise_params is not None else {}, ) carry_init = [ @@ -1346,17 +1489,27 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs( jnp.float64(0.0), # active_start_step jnp.float64(0.0), # active_end_step jnp.array(False), # active_enabled + lp_supply_array[0], # prev_lp_supply ] - _, reserves = scan( - scan_fn, - carry_init, - [prices, active_initial_weights, per_asset_ratios, - all_other_assets_ratios, gamma, arb_thresh, arb_fees, price_ratio_updates], - ) + scan_inputs = [ + prices, + active_initial_weights, + per_asset_ratios, + all_other_assets_ratios, + gamma, + arb_thresh, + arb_fees, + price_ratio_updates, + lp_supply_array, + ] + if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + scan_inputs.append(volatility_array) + + _, reserves = scan(scan_fn, carry_init, scan_inputs) return reserves -@partial(jit, static_argnums=(11,)) +@partial(jit, static_argnums=(11,), static_argnames=("noise_model",)) def _jax_calc_reclamm_reserves_with_dynamic_inputs_full_state( initial_reserves, initial_Va, @@ -1376,8 +1529,21 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs_full_state( centeredness_scaling=False, 
protocol_fee_split=0.0, ste_temperature=10.0, + noise_trader_ratio=0.0, + lp_supply_array=None, + noise_model="ratio", + noise_params=None, + volatility_array=None, ): """TEST-ONLY: dynamic-input reserve path returning virtual-balance history.""" + if lp_supply_array is None: + lp_supply_array = jnp.array(1.0) + lp_supply_array = jnp.where( + lp_supply_array.size == 1, + jnp.full(prices.shape[0], lp_supply_array), + lp_supply_array, + ) + n_assets = 2 weights = jnp.array([0.5, 0.5]) @@ -1425,6 +1591,9 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs_full_state( centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, ste_temperature=ste_temperature, + noise_trader_ratio=noise_trader_ratio, + noise_model=noise_model, + noise_params=noise_params if noise_params is not None else {}, ) carry_init = [ @@ -1437,17 +1606,27 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs_full_state( jnp.float64(0.0), # active_start_step jnp.float64(0.0), # active_end_step jnp.array(False), # active_enabled + lp_supply_array[0], # prev_lp_supply ] - _, (reserves, Va_history, Vb_history) = scan( - scan_fn, - carry_init, - [prices, active_initial_weights, per_asset_ratios, - all_other_assets_ratios, gamma, arb_thresh, arb_fees, price_ratio_updates], - ) + scan_inputs = [ + prices, + active_initial_weights, + per_asset_ratios, + all_other_assets_ratios, + gamma, + arb_thresh, + arb_fees, + price_ratio_updates, + lp_supply_array, + ] + if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + scan_inputs.append(volatility_array) + + _, (reserves, Va_history, Vb_history) = scan(scan_fn, carry_init, scan_inputs) return reserves, Va_history, Vb_history -@jit +@partial(jit, static_argnames=("noise_model",)) def _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( initial_reserves, initial_Va, @@ -1464,6 +1643,11 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( centeredness_scaling=False, protocol_fee_split=0.0, ste_temperature=10.0, + 
noise_trader_ratio=0.0, + lp_supply_array=None, + noise_model="ratio", + noise_params=None, + volatility_array=None, ): """Calculate reClAMM reserves and LP fee revenue over time with fees. @@ -1473,6 +1657,14 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( fee_revenue : jnp.ndarray, shape (T,) LP fee revenue per timestep in USD. """ + if lp_supply_array is None: + lp_supply_array = jnp.array(1.0) + lp_supply_array = jnp.where( + lp_supply_array.size == 1, + jnp.full(prices.shape[0], lp_supply_array), + lp_supply_array, + ) + n_assets = 2 weights = jnp.array([0.5, 0.5]) gamma = 1.0 - fees @@ -1507,6 +1699,9 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, ste_temperature=ste_temperature, + noise_trader_ratio=noise_trader_ratio, + noise_model=noise_model, + noise_params=noise_params if noise_params is not None else {}, ) carry_init = [ @@ -1519,17 +1714,27 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( jnp.float64(0.0), # active_start_step jnp.float64(0.0), # active_end_step jnp.array(False), # active_enabled + lp_supply_array[0], # prev_lp_supply ] - _, (reserves, fee_revenue) = scan( - scan_fn, - carry_init, - [prices, active_initial_weights, per_asset_ratios, - all_other_assets_ratios, gamma_array, arb_thresh_array, arb_fees_array, price_ratio_updates], - ) + scan_inputs = [ + prices, + active_initial_weights, + per_asset_ratios, + all_other_assets_ratios, + gamma_array, + arb_thresh_array, + arb_fees_array, + price_ratio_updates, + lp_supply_array, + ] + if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + scan_inputs.append(volatility_array) + + _, (reserves, fee_revenue) = scan(scan_fn, carry_init, scan_inputs) return reserves, fee_revenue -@partial(jit, static_argnums=(11,)) +@partial(jit, static_argnums=(11,), static_argnames=("noise_model",)) def _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( 
initial_reserves, initial_Va, @@ -1549,6 +1754,11 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( centeredness_scaling=False, protocol_fee_split=0.0, ste_temperature=10.0, + noise_trader_ratio=0.0, + lp_supply_array=None, + noise_model="ratio", + noise_params=None, + volatility_array=None, ): """Calculate reClAMM reserves and LP fee revenue with time-varying fees/arb arrays. @@ -1558,6 +1768,14 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( fee_revenue : jnp.ndarray, shape (T,) LP fee revenue per timestep in USD. """ + if lp_supply_array is None: + lp_supply_array = jnp.array(1.0) + lp_supply_array = jnp.where( + lp_supply_array.size == 1, + jnp.full(prices.shape[0], lp_supply_array), + lp_supply_array, + ) + n_assets = 2 weights = jnp.array([0.5, 0.5]) @@ -1605,6 +1823,9 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( centeredness_scaling=centeredness_scaling, protocol_fee_split=protocol_fee_split, ste_temperature=ste_temperature, + noise_trader_ratio=noise_trader_ratio, + noise_model=noise_model, + noise_params=noise_params if noise_params is not None else {}, ) carry_init = [ @@ -1617,11 +1838,21 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( jnp.float64(0.0), # active_start_step jnp.float64(0.0), # active_end_step jnp.array(False), # active_enabled + lp_supply_array[0], # prev_lp_supply ] - _, (reserves, fee_revenue) = scan( - scan_fn, - carry_init, - [prices, active_initial_weights, per_asset_ratios, - all_other_assets_ratios, gamma, arb_thresh, arb_fees, price_ratio_updates], - ) + scan_inputs = [ + prices, + active_initial_weights, + per_asset_ratios, + all_other_assets_ratios, + gamma, + arb_thresh, + arb_fees, + price_ratio_updates, + lp_supply_array, + ] + if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + scan_inputs.append(volatility_array) + + _, (reserves, fee_revenue) = scan(scan_fn, carry_init, scan_inputs) return reserves, fee_revenue 
From d39ddcb035a47ede5517b4fbc23fb9fab9df0e07 Mon Sep 17 00:00:00 2001 From: christian harrington Date: Sat, 7 Mar 2026 15:59:58 +0000 Subject: [PATCH 09/14] add missing imports --- pyproject.toml | 4 +++- setup.py | 12 +++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8859b55..19cb8f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,8 @@ dependencies = [ "plotly", "dask", "Historic-Crypto", + "gdown", + "binance_historical_data", "bidask", "optax", "jsonpickle", @@ -113,4 +115,4 @@ exclude_lines = [ "def __repr__", "raise NotImplementedError", "if __name__ == .__main__.:", -] \ No newline at end of file +] diff --git a/setup.py b/setup.py index 20eb403..1aa0b94 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ "pyarrow", "plotly", "bidask", - "Historic_Crypto", + "Historic-Crypto", "gdown", "binance_historical_data", "dask", @@ -30,10 +30,12 @@ ], extras_require={ "dev": [ - "pytest>=6.0", + "pytest>=7.0", + "pytest-cov>=4.0", + "pytest-xdist>=3.0", + "pytest-timeout>=2.0", "black", "flake8", - "pytest-cov", "hypothesis", ], "docs": [ @@ -41,6 +43,10 @@ "sphinx-automodapi", "sphinx-rtd-theme", ], + "calibration": [ + "numpyro>=0.15.0", + "arviz>=0.15.0", + ], }, python_requires=">=3.9", ) From 9e2485761961c0ec9a8f1a9e3cfa0411556075f3 Mon Sep 17 00:00:00 2001 From: christian harrington Date: Sat, 7 Mar 2026 16:16:21 +0000 Subject: [PATCH 10/14] add reclamm private repo port fixes --- quantammsim/pools/reCLAMM/reclamm.py | 57 +--------------------------- 1 file changed, 2 insertions(+), 55 deletions(-) diff --git a/quantammsim/pools/reCLAMM/reclamm.py b/quantammsim/pools/reCLAMM/reclamm.py index aa57d3e..afd2496 100644 --- a/quantammsim/pools/reCLAMM/reclamm.py +++ b/quantammsim/pools/reCLAMM/reclamm.py @@ -10,7 +10,7 @@ config.update("jax_enable_x64", True) import jax.numpy as jnp -from jax import jit, tree_util, vmap +from jax import jit, tree_util from jax.lax import dynamic_slice 
from functools import partial from typing import Dict, Any, Optional, NamedTuple @@ -50,57 +50,6 @@ def _prepare_dynamic_array(arr, start_index, bout_length, arb_frequency, max_len return sliced -def _align_prices_to_numeraire(prices, run_fingerprint): - """Ensure the numeraire token is in column 1 for ratio-volatility calc.""" - tokens = run_fingerprint.get("tokens") - numeraire = run_fingerprint.get("numeraire") - if tokens is None or numeraire is None or len(tokens) != 2: - return prices - - token_labels = [str(token).lower() for token in tokens] - numeraire_label = str(numeraire).lower() - if token_labels[0] == numeraire_label: - return prices[:, ::-1] - return prices - - -def _calculate_annualized_ratio_volatility( - prices, run_fingerprint, subsample_freq=5, -): - """Annualized daily realized volatility broadcast to minute-level array.""" - ordered_prices = _align_prices_to_numeraire(prices, run_fingerprint) - asset_prices = ordered_prices[:, 0] / ordered_prices[:, 1] - n_minutes = asset_prices.shape[0] - - if n_minutes < 1440: - return jnp.full((n_minutes,), 0.1 * jnp.sqrt(365.0), dtype=prices.dtype) - - n_days = n_minutes // 1440 - - def calculate_daily_volatility(day_idx): - start_idx = day_idx * 1440 - window_prices = dynamic_slice(asset_prices, (start_idx,), (1440,)) - subsampled_prices = window_prices[::subsample_freq] - log_prices = jnp.log(jnp.maximum(subsampled_prices, 1e-8)) - returns = jnp.diff(log_prices) - num_nonzero_returns = jnp.sum(returns != 0) - total_returns = jnp.maximum(returns.shape[0], 1) - adjusted_variance = num_nonzero_returns * jnp.var(returns) / total_returns - dt = subsample_freq / 1440 - return jnp.sqrt(adjusted_variance) / jnp.sqrt(dt) - - daily_volatilities = vmap(calculate_daily_volatility)(jnp.arange(n_days)) - volatility_array = jnp.repeat(daily_volatilities, 1440) - - remaining_minutes = n_minutes - volatility_array.shape[0] - if remaining_minutes > 0: - volatility_array = jnp.concatenate( - [volatility_array, 
jnp.full((remaining_minutes,), daily_volatilities[-1])] - ) - - return volatility_array * jnp.sqrt(365.0) - - class _PoolState(NamedTuple): """Intermediate state produced by _init_pool_state. @@ -299,9 +248,7 @@ def _resolve_noise_inputs( arb_vol = None if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): - volatility_array = _calculate_annualized_ratio_volatility( - prices, run_fingerprint - ) + volatility_array = self.calculate_volatility_array(prices, run_fingerprint) arb_vol = _prepare_dynamic_array( volatility_array, start_index=start_index, From 6d45c270c54e567dbee9f0e774208572669f6b06 Mon Sep 17 00:00:00 2001 From: MatthewWilletts <16085750+MatthewWilletts@users.noreply.github.com> Date: Mon, 9 Mar 2026 12:28:42 +0000 Subject: [PATCH 11/14] ci: install calibration extra and add reclamm branch trigger --- .github/workflows/tests.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6e0239d..f311c9c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -2,9 +2,9 @@ name: Tests on: push: - branches: [main, master, dev] + branches: [main, master, dev, reclamm] pull_request: - branches: [main, master, dev] + branches: [main, master, dev, reclamm] jobs: test: @@ -25,7 +25,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -e ".[dev]" + pip install -e ".[dev,calibration]" - name: Run tests with coverage run: | From 1b9d1ebe807783e5bab150a46dcf93124d601514 Mon Sep 17 00:00:00 2001 From: MatthewWilletts <16085750+MatthewWilletts@users.noreply.github.com> Date: Mon, 9 Mar 2026 13:29:14 +0000 Subject: [PATCH 12/14] fix: add ste_temperature to test fingerprints Tests that construct run_fingerprint dicts directly (bypassing recursive_default_set) need the ste_temperature key now that the STE-enabled scan steps read it from the fingerprint. 
--- tests/pools/reCLAMM/test_reclamm_fee_revenue.py | 3 +++ tests/pools/reCLAMM/test_reclamm_reserves.py | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/tests/pools/reCLAMM/test_reclamm_fee_revenue.py b/tests/pools/reCLAMM/test_reclamm_fee_revenue.py index bfbd21b..106f8a3 100644 --- a/tests/pools/reCLAMM/test_reclamm_fee_revenue.py +++ b/tests/pools/reCLAMM/test_reclamm_fee_revenue.py @@ -297,6 +297,7 @@ def test_pool_method_with_fees(self): "tokens": ("ETH", "USDC"), "numeraire": "USDC", "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + "ste_temperature": 10.0, }) start_index = jnp.array([0, 0]) @@ -342,6 +343,7 @@ def test_pool_method_with_dynamic_inputs(self): "tokens": ("ETH", "USDC"), "numeraire": "USDC", "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + "ste_temperature": 10.0, }) start_index = jnp.array([0, 0]) @@ -408,6 +410,7 @@ def test_forward_pass_returns_fee_revenue(self): "rule": "reclamm", "training_data_kind": "historic", "do_trades": False, + "ste_temperature": 10.0, }) start_index = jnp.array([0, 0]) diff --git a/tests/pools/reCLAMM/test_reclamm_reserves.py b/tests/pools/reCLAMM/test_reclamm_reserves.py index cf40217..396f50f 100644 --- a/tests/pools/reCLAMM/test_reclamm_reserves.py +++ b/tests/pools/reCLAMM/test_reclamm_reserves.py @@ -272,6 +272,7 @@ def test_calculate_reserves_with_fees(self): "tokens": ("ETH", "USDC"), "numeraire": "USDC", "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + "ste_temperature": 10.0, }) start_index = jnp.array([0, 0]) @@ -314,6 +315,7 @@ def test_calculate_reserves_zero_fees(self): "tokens": ("ETH", "USDC"), "numeraire": "USDC", "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + "ste_temperature": 10.0, }) start_index = jnp.array([0, 0]) @@ -353,6 +355,7 @@ def test_calculate_weights(self): "tokens": ("ETH", "USDC"), "numeraire": "USDC", "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + "ste_temperature": 10.0, }) start_index = jnp.array([0, 
0]) @@ -485,6 +488,7 @@ def test_fingerprint_dispatch(self): "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), "reclamm_interpolation_method": "constant_arc_length", "reclamm_arc_length_speed": None, # auto-calibrate + "ste_temperature": 10.0, }) start_index = jnp.array([0, 0]) @@ -690,6 +694,7 @@ def test_learnable_arc_length_speed_forward_pass(self): "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), "reclamm_interpolation_method": "constant_arc_length", "reclamm_learn_arc_length_speed": True, + "ste_temperature": 10.0, }) start_index = jnp.array([0, 0]) From 8bd2c598205310415f7828477b39166d4473ba83 Mon Sep 17 00:00:00 2001 From: MatthewWilletts <16085750+MatthewWilletts@users.noreply.github.com> Date: Mon, 9 Mar 2026 13:55:23 +0000 Subject: [PATCH 13/14] fix: replace STE gate on fees check with hard check MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The fees_gate STE was unnecessary — gamma (1 - fees) is either a static config value or a learnable param constrained to be nonzero, so the fee/zero-fee branch selection never benefits from soft gradients. 
--- quantammsim/pools/reCLAMM/reclamm_reserves.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/quantammsim/pools/reCLAMM/reclamm_reserves.py b/quantammsim/pools/reCLAMM/reclamm_reserves.py index 3a0e89a..1ea939a 100644 --- a/quantammsim/pools/reCLAMM/reclamm_reserves.py +++ b/quantammsim/pools/reCLAMM/reclamm_reserves.py @@ -987,10 +987,8 @@ def _skip_schedule_state(_): 0, ) - fees_gate = _ste_greater_than( - jnp.abs(gamma - 1.0), jnp.asarray(1e-12, dtype=gamma.dtype), ste_temperature - ) - optimal_arb_trade = _ste_select(fees_gate, fee_trade, zero_fee_trade) + fees_are_being_charged = gamma != 1.0 + optimal_arb_trade = jnp.where(fees_are_being_charged, fee_trade, zero_fee_trade) # Check profitability for arb profit_to_arb = -(optimal_arb_trade * prices).sum() - arb_thresh From 43024c36e236cac5315b8dea06258a418c347963 Mon Sep 17 00:00:00 2001 From: MatthewWilletts <16085750+MatthewWilletts@users.noreply.github.com> Date: Tue, 10 Mar 2026 13:31:55 +0000 Subject: [PATCH 14/14] fix: add missing ste_temperature to test fingerprints after STE merge --- tests/pools/reCLAMM/test_reclamm_noise_volume.py | 5 +++++ tests/pools/reCLAMM/test_reclamm_reserves.py | 2 ++ 2 files changed, 7 insertions(+) diff --git a/tests/pools/reCLAMM/test_reclamm_noise_volume.py b/tests/pools/reCLAMM/test_reclamm_noise_volume.py index f413a9e..500ae07 100644 --- a/tests/pools/reCLAMM/test_reclamm_noise_volume.py +++ b/tests/pools/reCLAMM/test_reclamm_noise_volume.py @@ -547,6 +547,7 @@ def test_tsoukalas_sqrt_from_fingerprint(self): "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), "noise_model": "tsoukalas_sqrt", "reclamm_noise_params": DEFAULT_NOISE_PARAMS, + "ste_temperature": 10.0, }) # Fingerprint without noise @@ -563,6 +564,7 @@ def test_tsoukalas_sqrt_from_fingerprint(self): "numeraire": "USDC", "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), "noise_model": "arb_only", + "ste_temperature": 10.0, }) start_index = jnp.array([0, 0]) @@ 
-618,6 +620,7 @@ def test_volatility_auto_computed_affects_fee_revenue(self): "tokens": ("ETH", "USDC"), "numeraire": "USDC", "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + "ste_temperature": 10.0, } fp_tsoukalas = Hashabledict({ @@ -946,6 +949,7 @@ def test_loglinear_from_fingerprint(self): "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), "noise_model": "loglinear", "reclamm_noise_params": loglinear_params, + "ste_temperature": 10.0, }) fp_arb_only = Hashabledict({ @@ -961,6 +965,7 @@ def test_loglinear_from_fingerprint(self): "numeraire": "USDC", "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), "noise_model": "arb_only", + "ste_temperature": 10.0, }) start_index = jnp.array([0, 0]) diff --git a/tests/pools/reCLAMM/test_reclamm_reserves.py b/tests/pools/reCLAMM/test_reclamm_reserves.py index 43263cb..58d9dae 100644 --- a/tests/pools/reCLAMM/test_reclamm_reserves.py +++ b/tests/pools/reCLAMM/test_reclamm_reserves.py @@ -929,6 +929,7 @@ def test_noise_trader_ratio_through_pool_class(self): "tokens": ("ETH", "USDC"), "numeraire": "USDC", "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + "ste_temperature": 10.0, } fp_no_noise = Hashabledict({**base_fp, "noise_trader_ratio": 0.0}) @@ -1291,6 +1292,7 @@ def test_lp_supply_through_pool_class(self): "tokens": ("ETH", "USDC"), "numeraire": "USDC", "all_sig_variations": tuple(map(tuple, [[1, -1], [-1, 1]])), + "ste_temperature": 10.0, }) start_index = jnp.array([0, 0])