diff --git a/quantammsim/core_simulator/dynamic_inputs.py b/quantammsim/core_simulator/dynamic_inputs.py new file mode 100644 index 0000000..c249088 --- /dev/null +++ b/quantammsim/core_simulator/dynamic_inputs.py @@ -0,0 +1,177 @@ +from dataclasses import dataclass +from typing import Any, NamedTuple, Optional + +import jax.numpy as jnp + + +@dataclass(frozen=True) +class DynamicInputFrames: + """Outer-layer container for optional pandas-backed dynamic inputs.""" + + trades: Optional[Any] = None + fees: Optional[Any] = None + gas_cost: Optional[Any] = None + arb_fees: Optional[Any] = None + lp_supply: Optional[Any] = None + + +class DynamicInputArrays(NamedTuple): + """JAX pytree for dynamic simulation inputs with optional trade data.""" + + trades: Optional[jnp.ndarray] + fees: jnp.ndarray + gas_cost: jnp.ndarray + arb_fees: jnp.ndarray + lp_supply: jnp.ndarray + + +def default_dynamic_input_flags() -> dict: + """Static dispatch flags for forward-pass path selection.""" + return { + "use_dynamic_inputs": False, + "has_trades": False, + "has_dynamic_fees": False, + "has_dynamic_gas_cost": False, + "has_dynamic_arb_fees": False, + "has_lp_supply": False, + } + + +def dynamic_input_flags_from_frames(dynamic_input_frames: Optional[DynamicInputFrames]) -> dict: + """Build stable dispatch flags from the outer-layer frame container.""" + if dynamic_input_frames is None: + return default_dynamic_input_flags() + + flags = { + "use_dynamic_inputs": False, + "has_trades": dynamic_input_frames.trades is not None, + "has_dynamic_fees": dynamic_input_frames.fees is not None, + "has_dynamic_gas_cost": dynamic_input_frames.gas_cost is not None, + "has_dynamic_arb_fees": dynamic_input_frames.arb_fees is not None, + "has_lp_supply": dynamic_input_frames.lp_supply is not None, + } + flags["use_dynamic_inputs"] = any(flags.values()) + return flags + + +def resolve_dynamic_input_flags( + dynamic_inputs: Optional[DynamicInputArrays], + dynamic_input_flags: Optional[dict] = None, +) -> dict: + """Return a safe dispatch flag set for the provided hot-path bundle.""" + flags = ( + default_dynamic_input_flags() + if dynamic_input_flags is None + else dict(dynamic_input_flags) + ) + if dynamic_inputs is not None: + flags["use_dynamic_inputs"] = True + return flags + + +def empty_dynamic_input_arrays() -> DynamicInputArrays: + """Create a canonical empty bundle.""" + return DynamicInputArrays( + trades=None, + fees=jnp.zeros((1,), dtype=jnp.float64), + gas_cost=jnp.zeros((1,), dtype=jnp.float64), + arb_fees=jnp.zeros((1,), dtype=jnp.float64), + lp_supply=jnp.ones((1,), dtype=jnp.float64), + ) + + +def resolve_dynamic_input_components( + dynamic_inputs: Optional[DynamicInputArrays], + dynamic_input_flags: dict, + static_dict: dict, +) -> dict: + """Resolve dynamic-input leaves against static scalar defaults.""" + arrays = empty_dynamic_input_arrays() if dynamic_inputs is None else dynamic_inputs + return { + "trades": arrays.trades if dynamic_input_flags["has_trades"] else None, + "fees": ( + arrays.fees + if dynamic_input_flags["has_dynamic_fees"] + else jnp.asarray([static_dict["fees"]], dtype=jnp.float64) + ), + "gas_cost": ( + arrays.gas_cost + if dynamic_input_flags["has_dynamic_gas_cost"] + else jnp.asarray([static_dict["gas_cost"]], dtype=jnp.float64) + ), + "arb_fees": ( + arrays.arb_fees + if dynamic_input_flags["has_dynamic_arb_fees"] + else jnp.asarray([static_dict["arb_fees"]], dtype=jnp.float64) + ), + "lp_supply": ( + arrays.lp_supply + if dynamic_input_flags["has_lp_supply"] + else jnp.ones((1,), dtype=jnp.float64) + ), + } + + +def _broadcast_dynamic_input_leaf( + input_name: str, + values: jnp.ndarray, + scan_len: int, + dtype, +) -> jnp.ndarray: + """Broadcast a singleton dynamic-input leaf to the scan length.""" + values = jnp.asarray(values, dtype=dtype) + if values.ndim == 0: + values = values.reshape((1,)) + if values.shape[0] == scan_len: + return values + if values.shape[0] == 1: + return jnp.broadcast_to(values, (scan_len,) + values.shape[1:]) + raise ValueError( + f"{input_name} has leading axis {values.shape[0]}, expected 1 or {scan_len}" + ) + + +def materialize_dynamic_inputs( + dynamic_inputs: Optional[DynamicInputArrays], + dynamic_input_flags: Optional[dict], + static_dict: dict, + scan_len: int, + do_trades: bool, + dtype=jnp.float64, +) -> DynamicInputArrays: + """Resolve and broadcast dynamic inputs for a specific scan length.""" + if dynamic_input_flags is None and dynamic_inputs is not None: + flags = { + "use_dynamic_inputs": True, + "has_trades": do_trades, + "has_dynamic_fees": True, + "has_dynamic_gas_cost": True, + "has_dynamic_arb_fees": True, + "has_lp_supply": True, + } + else: + flags = resolve_dynamic_input_flags(dynamic_inputs, dynamic_input_flags) + + resolved = resolve_dynamic_input_components(dynamic_inputs, flags, static_dict) + + trades = None + if do_trades: + if resolved["trades"] is None: + raise ValueError("Trades must be provided when do_trades=True.") + trades = _broadcast_dynamic_input_leaf( + "trades", resolved["trades"], scan_len, dtype + ) + + return DynamicInputArrays( + trades=trades, + fees=_broadcast_dynamic_input_leaf("fees", resolved["fees"], scan_len, dtype), + gas_cost=_broadcast_dynamic_input_leaf( + "gas_cost", resolved["gas_cost"], scan_len, dtype + ), + arb_fees=_broadcast_dynamic_input_leaf( + "arb_fees", resolved["arb_fees"], scan_len, dtype + ), + lp_supply=_broadcast_dynamic_input_leaf( + "lp_supply", resolved["lp_supply"], scan_len, dtype + ), + ) diff --git a/quantammsim/core_simulator/forward_pass.py b/quantammsim/core_simulator/forward_pass.py index ef8eaa5..c1509e9 100644 --- a/quantammsim/core_simulator/forward_pass.py +++ b/quantammsim/core_simulator/forward_pass.py @@ -53,11 +53,25 @@ import numpy as np from functools import partial +from quantammsim.core_simulator.dynamic_inputs import ( + DynamicInputArrays, + default_dynamic_input_flags, + resolve_dynamic_input_flags, +) np.seterr(all="raise") np.seterr(under="print") +def _resolve_dynamic_inputs(dynamic_inputs, static_dict): + """Return the incoming bundle plus static dispatch flags.""" + dynamic_input_flags = resolve_dynamic_input_flags( + dynamic_inputs, + static_dict.get("dynamic_input_flags"), + ) + return dynamic_inputs, dynamic_input_flags + + def _apply_price_noise(prices, sigma, seed_int): """Apply multiplicative log-normal noise to prices. @@ -839,15 +853,12 @@ def _calculate_return_value( return return_metrics[return_val]() -@partial(jit, static_argnums=(7, 8)) +@partial(jit, static_argnums=(4, 5)) def forward_pass( params, start_index, prices, - trades_array=None, - fees_array=None, - gas_cost_array=None, - arb_fees_array=None, + dynamic_inputs=None, pool=None, static_dict=None, ): @@ -870,17 +881,8 @@ def forward_pass( prices : array-like A 2D array of market prices for the assets involved in the simulation. - trades_array : array-like, optional - An array of trades to be considered in the simulation. Defaults to None. - - fees_array : array-like, optional - An array of fees to be applied during the simulation. Defaults to None. - - gas_cost_array : array-like, optional - An array of gas costs to be considered in the simulation. Defaults to None. - - arb_fees_array : array-like, optional - An array of arbitrage fees to be applied during the simulation. Defaults to None. + dynamic_inputs : DynamicInputArrays, optional + Fixed-structure bundle of dynamic trades/fees/gas/arb/LP arrays. pool : object An instance of a pool object that provides methods @@ -930,8 +932,8 @@ def forward_pass( - The function handles different cases for fees and trades, adjusting the calculation method accordingly: - 1. If any of `fees_array`, `gas_cost_array`, `arb_fees_array`, - or `trades_array` is provided, it uses `pool.calculate_reserves_with_dynamic_inputs`. + 1. If any dynamic-input flags are enabled, it uses + `pool.calculate_reserves_with_dynamic_inputs`. 2. If any of `fees`, `gas_cost`, or `arb_fees` in `static_dict` is a nonzero scalar value, it uses `pool.calculate_reserves_with_fees`. @@ -972,6 +974,7 @@ def forward_pass( "training_data_kind": "historic", "arb_frequency": 1, "do_trades": False, + "dynamic_input_flags": default_dynamic_input_flags(), } # 'pool' has default of None only to handle how partial function @@ -1008,10 +1011,7 @@ def forward_pass( and static_dict["arb_frequency"] == 1 and static_dict.get("turnover_penalty", 0.0) == 0.0 and static_dict.get("price_noise_sigma", 0.0) == 0.0 - and all( - ele is None - for ele in [fees_array, gas_cost_array, arb_fees_array, trades_array] - ) + and dynamic_inputs is None and 1440 % static_dict["chunk_period"] == 0 # chunk_period divides metric_period and not pool._rule_outputs_are_weights # only delta-based pools validated and static_dict["bout_length"] > 1440 * 2 # need ≥2 metric periods @@ -1031,28 +1031,20 @@ def forward_pass( # 1. Any of Fees, gas costs, and arb fees are provided as arrays, or trades are provided # 2. Any of Fees, gas costs, and arb fees are nonzero scalar values, with no trades provided # 3. Fees, gas costs, and arb fees are all zero, with no trades provided + dynamic_inputs, dynamic_input_flags = _resolve_dynamic_inputs( + dynamic_inputs, static_dict + ) + fee_revenue = None - if any( - ele is not None - for ele in [fees_array, gas_cost_array, arb_fees_array, trades_array] - ): - # Case 1, at least one of fees, gas costs, or arb fees is not None - if fees_array is None: - fees_array = jnp.array([static_dict["fees"]]) - if gas_cost_array is None: - gas_cost_array = jnp.array([static_dict["gas_cost"]]) - if arb_fees_array is None: - arb_fees_array = jnp.array([static_dict["arb_fees"]]) + if dynamic_input_flags["use_dynamic_inputs"]: + # Case 1, at least one dynamic input is enabled if hasattr(pool, "calculate_reserves_and_fee_revenue_with_dynamic_inputs"): reserves, fee_revenue = pool.calculate_reserves_and_fee_revenue_with_dynamic_inputs( params, static_dict, prices, start_index, - fees_array=fees_array, - arb_thresh_array=gas_cost_array, - arb_fees_array=arb_fees_array, - trade_array=trades_array, + dynamic_inputs=dynamic_inputs, ) else: reserves = pool.calculate_reserves_with_dynamic_inputs( @@ -1060,10 +1052,7 @@ def forward_pass( static_dict, prices, start_index, - fees_array=fees_array, - arb_thresh_array=gas_cost_array, - arb_fees_array=arb_fees_array, - trade_array=trades_array, + dynamic_inputs=dynamic_inputs, ) elif True in ( ele > 0.0 @@ -1170,15 +1159,12 @@ def forward_pass( return base_metric -@partial(jit, static_argnums=(7, 8)) +@partial(jit, static_argnums=(4, 5)) def forward_pass_nograd( params, start_index, prices, - trades_array=None, - fees_array=None, - gas_cost_array=None, - arb_fees_array=None, + dynamic_inputs=None, pool=None, static_dict=None, ): @@ -1203,17 +1189,8 @@ def forward_pass_nograd( prices : array-like A 2D array of market prices for the assets involved in the simulation. - trades_array : array-like, optional - An array of trades to be considered in the simulation. Defaults to None. - - fees_array : array-like, optional - An array of fees to be applied during the simulation. Defaults to None. - - gas_cost_array : array-like, optional - An array of gas costs to be considered in the simulation. Defaults to None. - - arb_fees_array : array-like, optional - An array of arbitrage fees to be applied during the simulation. Defaults to None. + dynamic_inputs : DynamicInputArrays, optional + Fixed-structure bundle of dynamic trades/fees/gas/arb/LP arrays. pool : object An instance of a pool object that provides methods @@ -1263,8 +1240,8 @@ def forward_pass_nograd( - The function handles different cases for fees and trades, adjusting the calculation method accordingly: - 1. If any of `fees_array`, `gas_cost_array`, `arb_fees_array`, - or `trades_array` is provided, it uses `pool.calculate_reserves_with_dynamic_inputs`. + 1. If any dynamic-input flags are enabled, it uses + `pool.calculate_reserves_with_dynamic_inputs`. 2. If any of `fees`, `gas_cost`, or `arb_fees` in `static_dict` is a nonzero scalar value, it uses `pool.calculate_reserves_with_fees`. @@ -1289,14 +1266,23 @@ def forward_pass_nograd( params = {k: stop_gradient(v) for k, v in params.items()} start_index = stop_gradient(start_index) prices = stop_gradient(prices) + if dynamic_inputs is not None: + dynamic_inputs = DynamicInputArrays( + trades=( + None + if dynamic_inputs.trades is None + else stop_gradient(dynamic_inputs.trades) + ), + fees=stop_gradient(dynamic_inputs.fees), + gas_cost=stop_gradient(dynamic_inputs.gas_cost), + arb_fees=stop_gradient(dynamic_inputs.arb_fees), + lp_supply=stop_gradient(dynamic_inputs.lp_supply), + ) return forward_pass( params, start_index, prices, - trades_array, - fees_array, - gas_cost_array, - arb_fees_array, + dynamic_inputs, pool, static_dict, ) diff --git a/quantammsim/core_simulator/windowing_utils.py b/quantammsim/core_simulator/windowing_utils.py index 052a77e..07fcca7 100644 --- a/quantammsim/core_simulator/windowing_utils.py +++ b/quantammsim/core_simulator/windowing_utils.py @@ -201,11 +201,12 @@ def raw_fee_like_amounts_to_fee_like_array( ).astype(int) // 10**6 )[:-1] + fill_value = np.nan if fill_method == "ffill" else 0.0 full_index_df = pd.DataFrame( - index=full_index, - columns=names, - data=0, - dtype=np.float64 + index=full_index, + columns=names, + data=fill_value, + dtype=np.float64, ) # Map raw data to the full index DataFrame @@ -231,15 +232,16 @@ def raw_fee_like_amounts_to_fee_like_array( # Ensure unix values are valid valid_unix = pd.to_numeric(raw_inputs['unix'], errors='coerce') valid_mask = valid_unix.notna() + valid_inputs = raw_inputs.loc[valid_mask].copy() + valid_inputs["unix"] = valid_unix.loc[valid_mask].astype(np.int64) + valid_inputs = valid_inputs.sort_values("unix") for name in names: initial_value = None - if valid_mask.any(): + if not valid_inputs.empty: # Try to get the last value before our start date - previous_values = raw_inputs[ - valid_mask & (valid_unix < start_unix) - ] + previous_values = valid_inputs[valid_inputs["unix"] < start_unix] if not previous_values.empty: try: @@ -249,9 +251,7 @@ def raw_fee_like_amounts_to_fee_like_array( if initial_value is None or pd.isna(initial_value): # Try to get first value in our date range - in_range_values = raw_inputs[ - valid_mask & (valid_unix >= start_unix) - ] + in_range_values = valid_inputs[valid_inputs["unix"] >= start_unix] if not in_range_values.empty: try: initial_value = pd.to_numeric(in_range_values[name].iloc[0]) @@ -259,17 +259,12 @@ def raw_fee_like_amounts_to_fee_like_array( initial_value = None if initial_value is not None and pd.notna(initial_value): - # this more complex logic is because of how we have started with prior-to-start values - # filled in, and then we want to ffill the rest - # Fill initial values - full_index_df[name] = full_index_df[name].mask( - full_index_df[name] == 0, - initial_value - ) - # Use ffill() - full_index_df[name] = full_index_df[name].where( - full_index_df[name] != 0 - ).ffill() + # Seed only the leading gap; explicit in-range updates must remain intact. + first_row = full_index_df.index[0] + if pd.isna(full_index_df.at[first_row, name]): + full_index_df.at[first_row, name] = initial_value + + full_index_df[name] = full_index_df[name].ffill().fillna(0.0) except (ValueError, KeyError, TypeError) as e: print(f"Warning: Error during ffill processing: {str(e)}") # On any error, return the original zero-filled DataFrame @@ -382,4 +377,4 @@ def filter_reserves_by_given_timestamp(reserves, unix_values, timestamp): unix_values == timestamp )[0][0] - return reserves[reserves_index].copy() \ No newline at end of file + return reserves[reserves_index].copy() diff --git a/quantammsim/hooks/dynamic_fee_base_hook.py b/quantammsim/hooks/dynamic_fee_base_hook.py index 40a9714..ad64a5f 100644 --- a/quantammsim/hooks/dynamic_fee_base_hook.py +++ b/quantammsim/hooks/dynamic_fee_base_hook.py @@ -3,6 +3,10 @@ import jax.numpy as jnp from jax.lax import dynamic_slice +from quantammsim.core_simulator.dynamic_inputs import ( + DynamicInputArrays, + empty_dynamic_input_arrays, +) class BaseDynamicFeeHook(ABC): """Mixin class to add dynamic fee calculation capabilities to pools. @@ -113,16 +117,21 @@ def calculate_reserves_with_fees( (int((bout_length) / chunk_period), 1), ) dynamic_fees = raw_dynamic_fees.repeat(chunk_period, axis=0).squeeze() - # Use existing dynamic inputs infrastructure + empty_inputs = empty_dynamic_input_arrays() + dynamic_inputs = DynamicInputArrays( + trades=None, + fees=dynamic_fees, + gas_cost=jnp.asarray(run_fingerprint["gas_cost"], dtype=jnp.float64), + arb_fees=jnp.asarray(run_fingerprint["arb_fees"], dtype=jnp.float64), + lp_supply=empty_inputs.lp_supply, + ) + return self.calculate_reserves_with_dynamic_inputs( params, run_fingerprint, prices, start_index, - dynamic_fees, - run_fingerprint["gas_cost"], - run_fingerprint["arb_fees"], - dynamic_fees, + dynamic_inputs, additional_oracle_input, ) diff --git a/quantammsim/pools/ECLP/gyroscope.py b/quantammsim/pools/ECLP/gyroscope.py index f1fef2a..517a7ba 100644 --- a/quantammsim/pools/ECLP/gyroscope.py +++ b/quantammsim/pools/ECLP/gyroscope.py @@ -31,6 +31,7 @@ from typing import Dict, Any, Optional, Tuple import numpy as np +from quantammsim.core_simulator.dynamic_inputs import materialize_dynamic_inputs from quantammsim.pools.base_pool import AbstractPool from quantammsim.pools.ECLP.gyroscope_reserves import ( @@ -321,11 +322,7 @@ def calculate_reserves_with_dynamic_inputs( run_fingerprint: Dict[str, Any], prices: jnp.ndarray, start_index: jnp.ndarray, - fees_array: jnp.ndarray, - arb_thresh_array: jnp.ndarray, - arb_fees_array: jnp.ndarray, - trade_array: jnp.ndarray, - lp_supply_array: jnp.ndarray = None, + dynamic_inputs, additional_oracle_input: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: # Gyroscope ECLP pools are only defined for 2 assets @@ -356,34 +353,27 @@ def calculate_reserves_with_dynamic_inputs( sin=jnp.sin(phi), cos=jnp.cos(phi), ) - # any of fees_array, arb_thresh_array, arb_fees_array, trade_array - # can be singletons, in which case we repeat them for the length of the bout - - # Determine the maximum leading dimension max_len = bout_length - 1 if run_fingerprint["arb_frequency"] != 1: max_len = max_len // run_fingerprint["arb_frequency"] + materialized_inputs = materialize_dynamic_inputs( + dynamic_inputs, + run_fingerprint.get("dynamic_input_flags"), + run_fingerprint, + scan_len=max_len, + do_trades=run_fingerprint["do_trades"], + dtype=arb_acted_upon_local_prices.dtype, + ) # Handle trade array reordering if needed if run_fingerprint["do_trades"]: - # if we are doing trades, the trades array must be of the same length as the other arrays - assert trade_array.shape[0] == max_len if needs_swap: # Swap trade indices (0->1, 1->0) but keep amounts unchanged - trade_array = trade_array.at[:, :2].set(1 - trade_array[:, :2]) - - # Broadcast input arrays to match the maximum leading dimension. - # If they are singletons, this will just repeat them for the length of the bout. - # If they are arrays of length bout_length, this will cause no change. - fees_array_broadcast = jnp.broadcast_to( - fees_array, (max_len,) + fees_array.shape[1:] - ) - arb_thresh_array_broadcast = jnp.broadcast_to( - arb_thresh_array, (max_len,) + arb_thresh_array.shape[1:] - ) - arb_fees_array_broadcast = jnp.broadcast_to( - arb_fees_array, (max_len,) + arb_fees_array.shape[1:] - ) + materialized_inputs = materialized_inputs._replace( + trades=materialized_inputs.trades.at[:, :2].set( + 1 - materialized_inputs.trades[:, :2] + ) + ) # Calculate reserves reserves = _jax_calc_gyroscope_reserves_with_dynamic_inputs( @@ -394,10 +384,10 @@ def calculate_reserves_with_dynamic_inputs( sin=jnp.sin(phi), cos=jnp.cos(phi), lam=lam, - fees=fees_array_broadcast, - arb_thresh=arb_thresh_array_broadcast, - arb_fees=arb_fees_array_broadcast, - trades=trade_array, + fees=materialized_inputs.fees, + arb_thresh=materialized_inputs.gas_cost, + arb_fees=materialized_inputs.arb_fees, + trades=materialized_inputs.trades, do_trades=run_fingerprint["do_trades"], ) # Restore original order if we swapped diff --git a/quantammsim/pools/ECLP/gyroscope_reserves.py b/quantammsim/pools/ECLP/gyroscope_reserves.py index 3dde3fb..4e7d969 100644 --- a/quantammsim/pools/ECLP/gyroscope_reserves.py +++ b/quantammsim/pools/ECLP/gyroscope_reserves.py @@ -605,7 +605,7 @@ def _jax_calc_gyroscope_reserves_with_dynamic_fees_and_trades_scan_function_usin gamma = input_list[1] arb_thresh = input_list[2] arb_fees = input_list[3] - trade = input_list[4] + trade = input_list[4] if do_trades else None @@ -727,6 +727,8 @@ def _jax_calc_gyroscope_reserves_with_dynamic_inputs( arb_fees = jnp.where( arb_fees.size == 1, jnp.full(prices.shape[0], arb_fees), arb_fees ) + if do_trades and trades is None: + raise ValueError("Trades must be provided when do_trades=True.") scan_fn = Partial( _jax_calc_gyroscope_reserves_with_dynamic_fees_and_trades_scan_function_using_precalcs, @@ -745,17 +747,15 @@ def _jax_calc_gyroscope_reserves_with_dynamic_inputs( initial_reserves, 0 ] - carry_list_end, reserves = scan( - scan_fn, - carry_list_init, - [ - prices, - gamma, - arb_thresh, - arb_fees, - trades, - ], - ) + scan_inputs = [ + prices, + gamma, + arb_thresh, + arb_fees, + ] + if do_trades: + scan_inputs.append(trades) + carry_list_end, reserves = scan(scan_fn, carry_list_init, scan_inputs) return reserves diff --git a/quantammsim/pools/FM_AMM/cow_pool.py b/quantammsim/pools/FM_AMM/cow_pool.py index 5fd29bb..13c171b 100644 --- a/quantammsim/pools/FM_AMM/cow_pool.py +++ b/quantammsim/pools/FM_AMM/cow_pool.py @@ -30,6 +30,7 @@ from functools import partial import numpy as np +from quantammsim.core_simulator.dynamic_inputs import materialize_dynamic_inputs from quantammsim.pools.base_pool import AbstractPool from quantammsim.pools.FM_AMM.cow_reserves import ( _jax_calc_cowamm_reserves_with_fees, @@ -56,10 +57,9 @@ class CowPool(AbstractPool): start_index, additional_oracle_input=None) -> jnp.ndarray: Calculates the reserves of the pool without considering fees. - calculate_reserves_with_dynamic_inputs(params, run_fingerprint, prices, - start_index, fees_array, arb_thresh_array, arb_fees_array, trade_array, - additional_oracle_input=None) -> jnp.ndarray: - Calculates the reserves of the pool with dynamic inputs for fees, + calculate_reserves_with_dynamic_inputs(params, run_fingerprint, prices, + start_index, dynamic_inputs, additional_oracle_input=None) -> jnp.ndarray: + Calculates the reserves of the pool with dynamic inputs for fees, arbitrage thresholds, arbitrage fees, and trades. init_base_parameters(initial_values_dict, run_fingerprint, n_assets, @@ -194,11 +194,7 @@ def calculate_reserves_with_dynamic_inputs( run_fingerprint: Dict[str, Any], prices: jnp.ndarray, start_index: jnp.ndarray, - fees_array: jnp.ndarray, - arb_thresh_array: jnp.ndarray, - arb_fees_array: jnp.ndarray, - trade_array: jnp.ndarray, - lp_supply_array: jnp.ndarray = None, + dynamic_inputs, additional_oracle_input: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: bout_length = run_fingerprint["bout_length"] @@ -218,38 +214,27 @@ def calculate_reserves_with_dynamic_inputs( initial_value_per_token = weights * initial_pool_value initial_reserves = initial_value_per_token / arb_acted_upon_local_prices[0] - # any of fees_array, arb_thresh_array, arb_fees_array, trade_array - # can be singletons, in which case we repeat them for the length of the bout - - # Determine the maximum leading dimension max_len = bout_length - 1 if run_fingerprint["arb_frequency"] != 1: max_len = max_len // run_fingerprint["arb_frequency"] - # Broadcast input arrays to match the maximum leading dimension. - # If they are singletons, this will just repeat them for the length of the bout. - # If they are arrays of length bout_length, this will cause no change. - fees_array_broadcast = jnp.broadcast_to( - fees_array, (max_len,) + fees_array.shape[1:] - ) - arb_thresh_array_broadcast = jnp.broadcast_to( - arb_thresh_array, (max_len,) + arb_thresh_array.shape[1:] - ) - arb_fees_array_broadcast = jnp.broadcast_to( - arb_fees_array, (max_len,) + arb_fees_array.shape[1:] + materialized_inputs = materialize_dynamic_inputs( + dynamic_inputs, + run_fingerprint.get("dynamic_input_flags"), + run_fingerprint, + scan_len=max_len, + do_trades=run_fingerprint["do_trades"], + dtype=arb_acted_upon_local_prices.dtype, ) - # if we are doing trades, the trades array must be of the same length as the other arrays - if run_fingerprint["do_trades"]: - assert trade_array.shape[0] == max_len reserves = _jax_calc_cowamm_reserves_with_dynamic_inputs( initial_reserves, arb_acted_upon_local_prices, - fees_array_broadcast, - arb_thresh_array_broadcast, - arb_fees_array_broadcast, + materialized_inputs.fees, + materialized_inputs.gas_cost, + materialized_inputs.arb_fees, weights, run_fingerprint["arb_quality"], - trade_array, + materialized_inputs.trades, run_fingerprint["do_trades"], run_fingerprint["do_arb"], noise_trader_ratio=run_fingerprint["noise_trader_ratio"], diff --git a/quantammsim/pools/FM_AMM/cow_reserves.py b/quantammsim/pools/FM_AMM/cow_reserves.py index f876d51..ab340a0 100644 --- a/quantammsim/pools/FM_AMM/cow_reserves.py +++ b/quantammsim/pools/FM_AMM/cow_reserves.py @@ -729,7 +729,7 @@ def _jax_calc_cowamm_reserves_with_dynamic_fees_and_trades_scan_function( gamma = input_list[1] arb_thresh = input_list[2] arb_fees = input_list[3] - trade = input_list[4] + trade = input_list[4] if do_trades else None if do_arb: reserves_with_perfect_arb = _jax_calc_cowamm_reserves_with_fees_scan_function( @@ -827,6 +827,8 @@ def _jax_calc_cowamm_reserves_with_dynamic_inputs( initial_prices = prices[0] gamma = 1.0 - fees + if do_trades and trades is None: + raise ValueError("Trades must be provided when do_trades=True.") scan_fn = Partial( _jax_calc_cowamm_reserves_with_dynamic_fees_and_trades_scan_function, @@ -838,8 +840,9 @@ def _jax_calc_cowamm_reserves_with_dynamic_inputs( ) carry_list_init = [initial_prices, initial_reserves] - _, reserves = scan( - scan_fn, carry_list_init, [prices, gamma, arb_thresh, arb_fees, trades] - ) + scan_inputs = [prices, gamma, arb_thresh, arb_fees] + if do_trades: + scan_inputs.append(trades) + _, reserves = scan(scan_fn, carry_list_init, scan_inputs) return reserves diff --git a/quantammsim/pools/G3M/balancer/balancer.py b/quantammsim/pools/G3M/balancer/balancer.py index 4b6cfc2..fcee272 100644 --- a/quantammsim/pools/G3M/balancer/balancer.py +++ b/quantammsim/pools/G3M/balancer/balancer.py @@ -8,6 +8,7 @@ import jax.numpy as jnp from jax.lax import dynamic_slice +from quantammsim.core_simulator.dynamic_inputs import materialize_dynamic_inputs from quantammsim.pools.base_pool import AbstractPool from quantammsim.pools.G3M.balancer.balancer_reserves import ( _jax_calc_balancer_reserve_ratios, @@ -255,11 +256,7 @@ def calculate_reserves_with_dynamic_inputs( run_fingerprint: Dict[str, Any], prices: jnp.ndarray, start_index: jnp.ndarray, - fees_array: jnp.ndarray, - arb_thresh_array: jnp.ndarray, - arb_fees_array: jnp.ndarray, - trade_array: jnp.ndarray, - lp_supply_array: jnp.ndarray = None, + dynamic_inputs, additional_oracle_input: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: """ @@ -287,14 +284,8 @@ def calculate_reserves_with_dynamic_inputs( Price history array start_index : jnp.ndarray Starting index for the calculation window - fees_array : jnp.ndarray - Time-varying trading fees - arb_thresh_array : jnp.ndarray - Time-varying arbitrage thresholds - arb_fees_array : jnp.ndarray - Time-varying arbitrage fees - trade_array : jnp.ndarray - Custom trade sequence + dynamic_inputs : DynamicInputArrays + Fixed-structure bundle of dynamic inputs. Returns ------- @@ -318,37 +309,26 @@ def calculate_reserves_with_dynamic_inputs( initial_value_per_token = weights * initial_pool_value initial_reserves = initial_value_per_token / arb_acted_upon_local_prices[0] - # any of fees_array, arb_thresh_array, arb_fees_array, trade_array - # can be singletons, in which case we repeat them for the length of the bout - - # Determine the maximum leading dimension max_len = bout_length - 1 if run_fingerprint["arb_frequency"] != 1: max_len = max_len // run_fingerprint["arb_frequency"] - # Broadcast input arrays to match the maximum leading dimension. - # If they are singletons, this will just repeat them for the length of the bout. - # If they are arrays of length bout_length, this will cause no change. - fees_array_broadcast = jnp.broadcast_to( - fees_array, (max_len,) + fees_array.shape[1:] - ) - arb_thresh_array_broadcast = jnp.broadcast_to( - arb_thresh_array, (max_len,) + arb_thresh_array.shape[1:] - ) - arb_fees_array_broadcast = jnp.broadcast_to( - arb_fees_array, (max_len,) + arb_fees_array.shape[1:] + materialized_inputs = materialize_dynamic_inputs( + dynamic_inputs, + run_fingerprint.get("dynamic_input_flags"), + run_fingerprint, + scan_len=max_len, + do_trades=run_fingerprint["do_trades"], + dtype=arb_acted_upon_local_prices.dtype, ) - # if we are doing trades, the trades array must be of the same length as the other arrays - if run_fingerprint["do_trades"]: - assert trade_array.shape[0] == max_len reserves = _jax_calc_balancer_reserves_with_dynamic_inputs( initial_reserves, weights, arb_acted_upon_local_prices, - fees_array_broadcast, - arb_thresh_array_broadcast, - arb_fees_array_broadcast, + materialized_inputs.fees, + materialized_inputs.gas_cost, + materialized_inputs.arb_fees, jnp.array(run_fingerprint["all_sig_variations"]), - trade_array, + materialized_inputs.trades, run_fingerprint["do_trades"], run_fingerprint["do_arb"], ) diff --git a/quantammsim/pools/G3M/balancer/balancer_reserves.py b/quantammsim/pools/G3M/balancer/balancer_reserves.py index c0b9d45..82c6629 100644 --- a/quantammsim/pools/G3M/balancer/balancer_reserves.py +++ b/quantammsim/pools/G3M/balancer/balancer_reserves.py @@ -359,7 +359,7 @@ def _jax_calc_balancer_reserves_with_dynamic_fees_and_trades_scan_function_using gamma = input_list[4] arb_thresh = input_list[5] arb_fees = input_list[6] - trade = input_list[7] + trade = input_list[7] if do_trades else None fees_are_being_charged = gamma != 1.0 @@ -494,6 +494,8 @@ def _jax_calc_balancer_reserves_with_dynamic_inputs( arb_fees = jnp.where( arb_fees.size == 1, jnp.full(prices.shape[0], arb_fees), arb_fees ) + if do_trades and trades is None: + raise ValueError("Trades must be provided when do_trades=True.") # pre-calculate some values that are repeatedly used in optimal arb calculations _, active_trade_directions, tokens_to_drop, leave_one_out_idxs = ( @@ -528,19 +530,17 @@ def _jax_calc_balancer_reserves_with_dynamic_inputs( initial_reserves, 0, ] - _, reserves = scan( - scan_fn, - carry_list_init, - [ - prices, - active_initial_weights, - per_asset_ratios, - all_other_assets_ratios, - gamma, - arb_thresh, - arb_fees, - trades, - ], - ) + scan_inputs = [ + prices, + active_initial_weights, + per_asset_ratios, + all_other_assets_ratios, + gamma, + arb_thresh, + arb_fees, + ] + if do_trades: + scan_inputs.append(trades) + _, reserves = scan(scan_fn, carry_list_init, scan_inputs) return reserves diff --git a/quantammsim/pools/G3M/quantamm/TFMM_base_pool.py b/quantammsim/pools/G3M/quantamm/TFMM_base_pool.py index bb99518..7091ee3 100644 --- a/quantammsim/pools/G3M/quantamm/TFMM_base_pool.py +++ b/quantammsim/pools/G3M/quantamm/TFMM_base_pool.py @@ -11,6 +11,7 @@ from jax.lax import stop_gradient, dynamic_slice, scan, fori_loop from jax.tree_util import Partial +from quantammsim.core_simulator.dynamic_inputs import materialize_dynamic_inputs from quantammsim.pools.base_pool import AbstractPool from quantammsim.pools.G3M.quantamm.quantamm_reserves import ( _jax_calc_quantAMM_reserve_ratios, @@ -381,11 +382,7 @@ def calculate_reserves_with_dynamic_inputs( run_fingerprint: Dict[str, Any], prices: jnp.ndarray, start_index: jnp.ndarray, - fees_array: jnp.ndarray, - arb_thresh_array: jnp.ndarray, - arb_fees_array: jnp.ndarray, - trade_array: jnp.ndarray, - lp_supply_array: jnp.ndarray = None, + dynamic_inputs, additional_oracle_input: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: bout_length = run_fingerprint["bout_length"] @@ -409,49 +406,31 @@ def calculate_reserves_with_dynamic_inputs( initial_value_per_token = arb_acted_upon_weights[0] * initial_pool_value initial_reserves = initial_value_per_token / arb_acted_upon_local_prices[0] - # any of fees_array, arb_thresh_array, arb_fees_array, trade_array, and lp_supply_array - # can be singletons, in which case we repeat them for the length of the bout. - - # Determine the maximum leading dimension max_len = bout_length - 1 if run_fingerprint["arb_frequency"] != 1: max_len = max_len // run_fingerprint["arb_frequency"] - # Broadcast input arrays to match the maximum leading dimension. - # If they are singletons, this will just repeat them for the length of the bout. - # If they are arrays of length bout_length, this will cause no change. - fees_array_broadcast = jnp.broadcast_to( - fees_array, (max_len,) + fees_array.shape[1:] - ) - arb_thresh_array_broadcast = jnp.broadcast_to( - arb_thresh_array, (max_len,) + arb_thresh_array.shape[1:] - ) - arb_fees_array_broadcast = jnp.broadcast_to( - arb_fees_array, (max_len,) + arb_fees_array.shape[1:] - ) - # if lp_supply_array is not provided, we set it to a constant of 1.0 - if lp_supply_array is None: - lp_supply_array = jnp.array(1.0) - - lp_supply_array_broadcast = jnp.broadcast_to( - lp_supply_array, (max_len,) + lp_supply_array.shape[1:] + materialized_inputs = materialize_dynamic_inputs( + dynamic_inputs, + run_fingerprint.get("dynamic_input_flags"), + run_fingerprint, + scan_len=max_len, + do_trades=run_fingerprint["do_trades"], + dtype=arb_acted_upon_local_prices.dtype, ) - # if we are doing trades, the trades array must be of the same length as the other arrays - if run_fingerprint["do_trades"]: - assert trade_array.shape[0] == max_len protocol_fee_split = run_fingerprint.get("protocol_fee_split", 0.0) reserves = _jax_calc_quantAMM_reserves_with_dynamic_inputs( initial_reserves, arb_acted_upon_weights, arb_acted_upon_local_prices, - fees_array_broadcast, - arb_thresh_array_broadcast, - arb_fees_array_broadcast, + materialized_inputs.fees, + materialized_inputs.gas_cost, + materialized_inputs.arb_fees, jnp.array(run_fingerprint["all_sig_variations"]), - trade_array, + materialized_inputs.trades, run_fingerprint["do_trades"], run_fingerprint["do_arb"], run_fingerprint["noise_trader_ratio"], - lp_supply_array_broadcast, + materialized_inputs.lp_supply, protocol_fee_split=protocol_fee_split, ) return reserves @@ -1577,11 +1556,16 @@ def calculate_weights_direct( initial_weights, minimum_weight, params, + jnp.zeros_like(initial_weights), + jnp.ones_like(initial_weights), local_fingerprint["max_memory_days"], local_fingerprint["chunk_period"], local_fingerprint["weight_interpolation_period"], maximum_change, False, + False, + False, + False, ) return target_weights_cpu diff --git a/quantammsim/pools/G3M/quantamm/quantamm_reserves.py b/quantammsim/pools/G3M/quantamm/quantamm_reserves.py index b6bd126..e0091ba 100644 --- a/quantammsim/pools/G3M/quantamm/quantamm_reserves.py +++ b/quantammsim/pools/G3M/quantamm/quantamm_reserves.py @@ -540,9 +540,14 @@ def _jax_calc_quantAMM_reserves_with_dynamic_fees_and_trades_scan_function_using gamma = input_list[8] arb_thresh = input_list[9] arb_fees = input_list[10] - trade = input_list[11] - do_arb = input_list[12] - lp_supply = input_list[13] + if do_trades: + trade = input_list[11] + do_arb = input_list[12] + lp_supply = input_list[13] + else: + trade = None + do_arb = input_list[11] + lp_supply = input_list[12] fees_are_being_charged = gamma != 1.0 protocol_fee_amount_step = jnp.zeros_like(prev_reserves) @@ -826,6 +831,8 @@ def _jax_calc_quantAMM_reserves_with_dynamic_inputs( arb_fees = jnp.where( arb_fees.size == 1, jnp.full(weights.shape[0], arb_fees), arb_fees ) + if do_trades and trades is None: + raise ValueError("Trades must be provided when do_trades=True.") if lp_supply_array is None: lp_supply_array = jnp.array(1.0) @@ -899,26 +906,23 @@ def _jax_calc_quantAMM_reserves_with_dynamic_inputs( ] # carry_list_init = [initial_weights, initial_i] # nojit_scan = jax.disable_jit()(jax.lax.scan) - carry_list_end, reserves = scan( - scan_fn, - carry_list_init, - [ - weights, - prices, - active_initial_weights, - per_asset_ratios, - all_other_assets_ratios, - lagged_active_initial_weights, - lagged_per_asset_ratios, - lagged_all_other_assets_ratios, - gamma, - arb_thresh, - arb_fees, - trades, - do_arb, - lp_supply_array, - ], - ) + scan_inputs = [ + weights, + prices, + active_initial_weights, + per_asset_ratios, + all_other_assets_ratios, + lagged_active_initial_weights, + lagged_per_asset_ratios, + lagged_all_other_assets_ratios, + gamma, + arb_thresh, + arb_fees, + ] + if do_trades: + scan_inputs.append(trades) + scan_inputs.extend([do_arb, lp_supply_array]) + carry_list_end, reserves = scan(scan_fn, carry_list_init, scan_inputs) return reserves diff --git a/quantammsim/pools/base_pool.py b/quantammsim/pools/base_pool.py index 3aeb85c..0415ef9 100644 --- a/quantammsim/pools/base_pool.py +++ b/quantammsim/pools/base_pool.py @@ -98,11 +98,7 @@ def calculate_reserves_with_dynamic_inputs( run_fingerprint: Dict[str, Any], prices: jnp.ndarray, start_index: jnp.ndarray, - fees_array: jnp.ndarray, - arb_thresh_array: jnp.ndarray, - arb_fees_array: jnp.ndarray, - trade_array: jnp.ndarray, - lp_supply_array: jnp.ndarray = None, + dynamic_inputs: Any, additional_oracle_input: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: pass diff --git a/quantammsim/pools/hodl_pool.py b/quantammsim/pools/hodl_pool.py index d3171a3..c938ff9 100644 --- a/quantammsim/pools/hodl_pool.py +++ b/quantammsim/pools/hodl_pool.py @@ -37,9 +37,9 @@ class HODLPool(AbstractPool): additional_oracle_input=None): Calculates the reserves without fees, assuming no trading activity. - calculate_reserves_with_dynamic_inputs(params, run_fingerprint, prices, start_index, - fees_array, arb_thresh_array, arb_fees_array, trade_array, additional_oracle_input=None): - Calculates the reserves with dynamic inputs, which in this case is + calculate_reserves_with_dynamic_inputs(params, run_fingerprint, prices, start_index, + dynamic_inputs, additional_oracle_input=None): + Calculates the reserves with dynamic inputs, which in this case is the same as reserves without fees due to no activity. init_base_parameters(initial_values_dict, run_fingerprint, n_assets, @@ -124,11 +124,7 @@ def calculate_reserves_with_dynamic_inputs( run_fingerprint: Dict[str, Any], prices: jnp.ndarray, start_index: jnp.ndarray, - fees_array: jnp.ndarray, - arb_thresh_array: jnp.ndarray, - arb_fees_array: jnp.ndarray, - trade_array: jnp.ndarray, - lp_supply_array: jnp.ndarray = None, + dynamic_inputs, additional_oracle_input: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: # hodl means no activity, so reserves are just the initial reserves diff --git a/quantammsim/pools/reCLAMM/reclamm.py b/quantammsim/pools/reCLAMM/reclamm.py index 152a264..da36710 100644 --- a/quantammsim/pools/reCLAMM/reclamm.py +++ b/quantammsim/pools/reCLAMM/reclamm.py @@ -16,6 +16,7 @@ from typing import Dict, Any, Optional, NamedTuple import numpy as np +from quantammsim.core_simulator.dynamic_inputs import materialize_dynamic_inputs from quantammsim.pools.base_pool import AbstractPool from quantammsim.pools.reCLAMM.reclamm_reserves import ( initialise_reclamm_reserves, @@ -275,11 +276,7 @@ def calculate_reserves_and_fee_revenue_with_dynamic_inputs( run_fingerprint: Dict[str, Any], prices: jnp.ndarray, start_index: jnp.ndarray, - fees_array: jnp.ndarray, - arb_thresh_array: jnp.ndarray, - arb_fees_array: jnp.ndarray, - trade_array: jnp.ndarray, - lp_supply_array: jnp.ndarray = None, + dynamic_inputs, additional_oracle_input: Optional[jnp.ndarray] = None, ): """Calculate reserves and LP fee revenue with time-varying inputs. @@ -291,20 +288,17 @@ def calculate_reserves_and_fee_revenue_with_dynamic_inputs( LP fee revenue per timestep in USD. """ s = self._init_pool_state(params, run_fingerprint, prices, start_index) - bout_length = run_fingerprint["bout_length"] max_len = bout_length - 1 if run_fingerprint["arb_frequency"] != 1: max_len = max_len // run_fingerprint["arb_frequency"] - - fees_array_broadcast = jnp.broadcast_to( - fees_array, (max_len,) + fees_array.shape[1:] - ) - arb_thresh_array_broadcast = jnp.broadcast_to( - arb_thresh_array, (max_len,) + arb_thresh_array.shape[1:] - ) - arb_fees_array_broadcast = jnp.broadcast_to( - arb_fees_array, (max_len,) + arb_fees_array.shape[1:] + materialized_inputs = materialize_dynamic_inputs( + dynamic_inputs, + run_fingerprint.get("dynamic_input_flags"), + run_fingerprint, + scan_len=max_len, + do_trades=False, + dtype=s.arb_prices.dtype, ) return _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( @@ -313,9 +307,9 @@ def calculate_reserves_and_fee_revenue_with_dynamic_inputs( s.centeredness_margin, s.daily_price_shift_base, s.seconds_per_step, - fees=fees_array_broadcast, - arb_thresh=arb_thresh_array_broadcast, - arb_fees=arb_fees_array_broadcast, + fees=materialized_inputs.fees, + arb_thresh=materialized_inputs.gas_cost, + arb_fees=materialized_inputs.arb_fees, all_sig_variations=jnp.array( run_fingerprint["all_sig_variations"] ), @@ -367,28 +361,21 @@ def calculate_reserves_with_dynamic_inputs( run_fingerprint: Dict[str, Any], prices: jnp.ndarray, start_index: jnp.ndarray, - fees_array: jnp.ndarray, - arb_thresh_array: jnp.ndarray, - arb_fees_array: jnp.ndarray, - trade_array: jnp.ndarray, - lp_supply_array: jnp.ndarray = None, + dynamic_inputs, additional_oracle_input: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: s = self._init_pool_state(params, run_fingerprint, prices, start_index) - bout_length = run_fingerprint["bout_length"] max_len = bout_length - 1 if run_fingerprint["arb_frequency"] != 1: max_len = max_len // run_fingerprint["arb_frequency"] - - fees_array_broadcast = jnp.broadcast_to( - fees_array, (max_len,) + fees_array.shape[1:] - ) - arb_thresh_array_broadcast = jnp.broadcast_to( - arb_thresh_array, (max_len,) + arb_thresh_array.shape[1:] - ) - arb_fees_array_broadcast = jnp.broadcast_to( - arb_fees_array, (max_len,) + arb_fees_array.shape[1:] + materialized_inputs = materialize_dynamic_inputs( + dynamic_inputs, + run_fingerprint.get("dynamic_input_flags"), + run_fingerprint, + scan_len=max_len, + do_trades=False, + dtype=s.arb_prices.dtype, ) return _jax_calc_reclamm_reserves_with_dynamic_inputs( @@ -397,9 +384,9 @@ def calculate_reserves_with_dynamic_inputs( s.centeredness_margin, s.daily_price_shift_base, s.seconds_per_step, - fees=fees_array_broadcast, - arb_thresh=arb_thresh_array_broadcast, - arb_fees=arb_fees_array_broadcast, + fees=materialized_inputs.fees, + arb_thresh=materialized_inputs.gas_cost, + arb_fees=materialized_inputs.arb_fees, all_sig_variations=jnp.array( run_fingerprint["all_sig_variations"] ), diff --git a/quantammsim/runners/__init__.py b/quantammsim/runners/__init__.py index 987a800..511800b 100644 --- a/quantammsim/runners/__init__.py +++ b/quantammsim/runners/__init__.py @@ -28,7 +28,7 @@ from .jax_runner_utils import ( nan_rollback, Hashabledict, - get_trades_and_fees, + prepare_dynamic_inputs, get_unique_tokens, OptunaManager, generate_evaluation_points, @@ -80,7 +80,7 @@ # Utilities "nan_rollback", "Hashabledict", - "get_trades_and_fees", + "prepare_dynamic_inputs", "get_unique_tokens", "OptunaManager", "generate_evaluation_points", diff --git a/quantammsim/runners/jax_runner_utils.py b/quantammsim/runners/jax_runner_utils.py index 77bcadb..97b49e6 100644 --- a/quantammsim/runners/jax_runner_utils.py +++ b/quantammsim/runners/jax_runner_utils.py @@ -14,6 +14,12 @@ raw_fee_like_amounts_to_fee_like_array, raw_trades_to_trade_array, ) +from quantammsim.core_simulator.dynamic_inputs import ( + DynamicInputArrays, + DynamicInputFrames, + dynamic_input_flags_from_frames, + empty_dynamic_input_arrays, +) from quantammsim.apis.rest_apis.simulator_dtos.simulation_run_dto import ( LiquidityPoolCoinDto, @@ -1053,8 +1059,9 @@ def get_unique_tokens(run_fingerprint): >>> get_unique_tokens(fingerprint) ['BTC', 'DAI', 'ETH'] """ + subsidary_pools = run_fingerprint.get("subsidary_pools", []) all_tokens = [run_fingerprint["tokens"]] + [ - cprd["tokens"] for cprd in run_fingerprint["subsidary_pools"] + cprd["tokens"] for cprd in subsidary_pools ] all_tokens = [item for sublist in all_tokens for item in sublist] unique_tokens = list(set(all_tokens)) @@ -1214,39 +1221,40 @@ def unpermute_list_of_params(list_of_params): return list_of_params_to_return -def get_trades_and_fees( - run_fingerprint, raw_trades, fees_df, gas_cost_df, arb_fees_df, lp_supply_df, do_test_period=False -): - """ - Process trade and fee data for a simulation run. +def _to_dynamic_input_arrays( + trades_array, + fees_array, + gas_cost_array, + arb_fees_array, + lp_supply_array, +) -> DynamicInputArrays: + """Normalize optional numpy arrays into the hot-path container.""" + empty = empty_dynamic_input_arrays() + return DynamicInputArrays( + trades=None if trades_array is None else jnp.asarray(trades_array, dtype=jnp.float64), + fees=empty.fees if fees_array is None else jnp.asarray(fees_array, dtype=jnp.float64), + gas_cost=empty.gas_cost if gas_cost_array is None else jnp.asarray(gas_cost_array, dtype=jnp.float64), + arb_fees=empty.arb_fees if arb_fees_array is None else jnp.asarray(arb_fees_array, dtype=jnp.float64), + lp_supply=empty.lp_supply if lp_supply_array is None else jnp.asarray(lp_supply_array, dtype=jnp.float64), + ) - Takes raw trades, fees, gas costs and arbitrage fees and converts them into arrays - suitable for simulation. Handles both training and test periods if specified. - Parameters - ---------- - run_fingerprint : dict - Dictionary containing run configuration including start/end dates and tokens - raw_trades : pd.DataFrame, optional - DataFrame containing raw trade data - fees_df : pd.DataFrame, optional - DataFrame containing fee data - gas_cost_df : pd.DataFrame, optional - DataFrame containing gas cost data - arb_fees_df : pd.DataFrame, optional - DataFrame containing arbitrage fee data - lp_supply_df : pd.DataFrame, optional - DataFrame containing LP supply data - do_test_period : bool, optional - Whether to process data for a test period after training period (default False) +def prepare_dynamic_inputs( + run_fingerprint, + dynamic_input_frames: Optional[DynamicInputFrames] = None, + do_test_period: bool = False, +): + """Convert optional pandas inputs into dynamic input bundles.""" + if dynamic_input_frames is None: + dynamic_input_frames = DynamicInputFrames() + + raw_trades = dynamic_input_frames.trades + fees_df = dynamic_input_frames.fees + gas_cost_df = dynamic_input_frames.gas_cost + arb_fees_df = dynamic_input_frames.arb_fees + lp_supply_df = dynamic_input_frames.lp_supply + dynamic_input_flags = dynamic_input_flags_from_frames(dynamic_input_frames) - Returns - ------- - dict - Contains processed arrays for trades, fees, gas costs and arb fees for both - training and test periods as applicable - """ - # Process raw trades if provided if raw_trades is not None: train_period_trades = raw_trades_to_trade_array( raw_trades, @@ -1264,7 +1272,7 @@ def get_trades_and_fees( else: train_period_trades = None test_period_trades = None - # Process fees, gas costs, and arb fees if provided + fees_array = ( raw_fee_like_amounts_to_fee_like_array( fees_df, @@ -1280,8 +1288,8 @@ def get_trades_and_fees( test_fees_array = ( raw_fee_like_amounts_to_fee_like_array( fees_df, - run_fingerprint["startDateString"], run_fingerprint["endDateString"], + run_fingerprint["endTestDateString"], names=["fees"], fill_method="ffill", ) @@ -1359,26 +1367,46 @@ def get_trades_and_fees( if lp_supply_df is not None else None ) + + # Unit LP supply is the neutral case; keep it on the static hot path. + if lp_supply_array is not None and np.allclose(lp_supply_array, 1.0): + lp_supply_array = None + if not do_test_period or test_lp_supply_array is None or np.allclose(test_lp_supply_array, 1.0): + dynamic_input_flags["has_lp_supply"] = False + dynamic_input_flags["use_dynamic_inputs"] = any( + value for key, value in dynamic_input_flags.items() if key != "use_dynamic_inputs" + ) + + if do_test_period and test_lp_supply_array is not None and np.allclose(test_lp_supply_array, 1.0): + test_lp_supply_array = None + if do_test_period: return { - "train_period_trades": train_period_trades, - "test_period_trades": test_period_trades, - "fees_array": fees_array, - "gas_cost_array": gas_cost_array, - "arb_fees_array": arb_fees_array, - "lp_supply_array": lp_supply_array, - "test_fees_array": test_fees_array, - "test_gas_cost_array": test_gas_cost_array, - "test_arb_fees_array": test_arb_fees_array, - "test_lp_supply_array": test_lp_supply_array, - } - else: - return { - "train_period_trades": train_period_trades, - "fees_array": fees_array, - "gas_cost_array": gas_cost_array, - "arb_fees_array": arb_fees_array, - "lp_supply_array": lp_supply_array, + "train_dynamic_inputs": _to_dynamic_input_arrays( + train_period_trades, + fees_array, + gas_cost_array, + arb_fees_array, + lp_supply_array, + ), + "test_dynamic_inputs": _to_dynamic_input_arrays( + test_period_trades, + test_fees_array, + test_gas_cost_array, + test_arb_fees_array, + test_lp_supply_array, + ), + "dynamic_input_flags": dynamic_input_flags, } + return { + "train_dynamic_inputs": _to_dynamic_input_arrays( + train_period_trades, + fees_array, + gas_cost_array, + arb_fees_array, + lp_supply_array, + ), + "dynamic_input_flags": dynamic_input_flags, + } def create_daily_unix_array(start_date_str, end_date_str): @@ -1621,12 +1649,21 @@ def try_forward_pass(n_sets: int) -> bool: "n_assets": n_tokens, "training_data_kind": probe_fingerprint["optimisation_settings"]["training_data_kind"], "do_trades": False, + "dynamic_input_flags": { + "use_dynamic_inputs": False, + "has_trades": False, + "has_dynamic_fees": False, + "has_dynamic_gas_cost": False, + "has_dynamic_arb_fees": False, + "has_lp_supply": False, + }, }, ) # Create vmapped forward pass partial_forward = Partial( forward_pass_nograd, + dynamic_inputs=None, prices=data_dict["prices"], static_dict=static_dict, pool=pool, diff --git a/quantammsim/runners/jax_runners.py b/quantammsim/runners/jax_runners.py index 62d5ad4..6b0160c 100644 --- a/quantammsim/runners/jax_runners.py +++ b/quantammsim/runners/jax_runners.py @@ -53,6 +53,10 @@ forward_pass_nograd, _calculate_return_value, ) +from quantammsim.core_simulator.dynamic_inputs import ( + DynamicInputFrames, + materialize_dynamic_inputs, +) from quantammsim.core_simulator.windowing_utils import get_indices, filter_coarse_weights_by_data_indices import hashlib @@ -80,7 +84,7 @@ from quantammsim.runners.jax_runner_utils import ( Hashabledict, - get_trades_and_fees, + prepare_dynamic_inputs, get_unique_tokens, OptunaManager, generate_evaluation_points, @@ -159,7 +163,10 @@ def _build_scan_infrastructure( run_scan_chunk : callable ``@jit`` wrapped ``lax.scan(scan_body, carry, None, length=chunk_size)``. scan_body : callable - The raw scan body (for partial-chunk Python fallback). + The raw scan body. + run_scan_step : callable + ``@jit`` wrapped single-step execution used for remainder iterations so + partial chunks follow the same numerics as the full scan path. """ # Local aliases for closed-over constants _start_idx = start_idx @@ -309,7 +316,11 @@ def scan_body(carry, _): def _run_scan_chunk(carry): return lax.scan(scan_body, carry, None, length=chunk_size) - return _run_scan_chunk, scan_body + @jit + def _run_scan_step(carry): + return scan_body(carry, None) + + return _run_scan_chunk, scan_body, _run_scan_step def train_on_historic_data( @@ -673,6 +684,14 @@ def _train_on_historic_data_impl( "n_assets": n_assets, "training_data_kind": run_fingerprint["optimisation_settings"]["training_data_kind"], "do_trades": False, + "dynamic_input_flags": { + "use_dynamic_inputs": False, + "has_trades": False, + "has_dynamic_fees": False, + "has_dynamic_gas_cost": False, + "has_dynamic_arb_fees": False, + "has_lp_supply": False, + }, }, ) @@ -694,6 +713,7 @@ def _train_on_historic_data_impl( continuous_static_dict["bout_length"] = original_bout_length + data_dict["bout_length_test"] partial_forward_pass_nograd_batch_continuous = Partial( forward_pass_nograd, + dynamic_inputs=None, static_dict=Hashabledict(continuous_static_dict), pool=pool, ) @@ -834,11 +854,12 @@ def init_optimizer(params): ) if config_key in _scan_infra_cache: - _run_scan_chunk, scan_body = _scan_infra_cache[config_key] + _run_scan_chunk, scan_body, _run_scan_step = _scan_infra_cache[config_key] else: # Build scan-compatible update (prices as explicit arg, not closure) partial_step_no_prices = Partial( forward_pass, + dynamic_inputs=None, static_dict=Hashabledict(base_static_dict), pool=pool, ) @@ -852,7 +873,7 @@ def init_optimizer(params): partial_step_no_prices, params_in_axes_dict, ) - _run_scan_chunk, scan_body = _build_scan_infrastructure( + _run_scan_chunk, scan_body, _run_scan_step = _build_scan_infrastructure( chunk_size, partial_step_no_prices=partial_step_no_prices, forward_nograd_continuous=partial_forward_pass_nograd_continuous, @@ -878,7 +899,7 @@ def init_optimizer(params): swa_freq=swa_freq, n_parameter_sets=n_parameter_sets, ) - _scan_infra_cache[config_key] = (_run_scan_chunk, scan_body) + _scan_infra_cache[config_key] = (_run_scan_chunk, scan_body, _run_scan_step) # ── Initialize carry (prices & nan_bank in carry, not closures) ── carry = { @@ -933,7 +954,7 @@ def init_optimizer(params): "params": {k: [] for k in carry["params"]}, } for _ in range(actual): - carry, step_out = scan_body(carry, None) + carry, step_out = _run_scan_step(carry) all_per_steps["objective"].append(step_out["objective"]) all_per_steps["train_metrics"].append(step_out["train_metrics"]) all_per_steps["test_metrics"].append(step_out["test_metrics"]) @@ -2478,14 +2499,10 @@ def do_run_on_historic_data( root=None, price_data=None, verbose=False, - raw_trades=None, fees=None, gas_cost=None, arb_fees=None, - fees_df=None, - gas_cost_df=None, - arb_fees_df=None, - lp_supply_df=None, + dynamic_input_frames: DynamicInputFrames = None, do_test_period=False, low_data_mode=False, preslice_burnin=True, @@ -2511,23 +2528,14 @@ def do_run_on_historic_data( Pre-loaded price data. When None, loaded from parquet files. verbose : bool, optional Print progress information (default False). - raw_trades : DataFrame, optional - Real trade data to inject. Columns: unix timestamp (minute), - token_in, token_out, amount_in. fees : float, optional Swap fee override (e.g. 0.003 for 30 bps). gas_cost : float, optional Gas cost override per transaction. arb_fees : float, optional Arbitrageur fee override. - fees_df : DataFrame, optional - Time-varying swap fees (columns: unix, fee). - gas_cost_df : DataFrame, optional - Time-varying gas costs (columns: unix, gas_cost). - arb_fees_df : DataFrame, optional - Time-varying arb fees (columns: unix, arb_fee). - lp_supply_df : DataFrame, optional - Time-varying LP supply changes. + dynamic_input_frames : DynamicInputFrames, optional + Optional container of trades / fee / gas / arb / LP supply DataFrames. do_test_period : bool, optional If True, also run the OOS test period defined by ``endDateString`` to ``endTestDateString`` (default False). @@ -2572,15 +2580,21 @@ def do_run_on_historic_data( np.random.seed(0) - dynamic_inputs_dict = get_trades_and_fees( + dynamic_inputs_dict = prepare_dynamic_inputs( run_fingerprint, - raw_trades, - fees_df, - gas_cost_df, - arb_fees_df, - lp_supply_df, + dynamic_input_frames=dynamic_input_frames, do_test_period=do_test_period, ) + train_dynamic_inputs = ( + dynamic_inputs_dict["train_dynamic_inputs"] + if dynamic_inputs_dict["dynamic_input_flags"]["use_dynamic_inputs"] + else None + ) + test_dynamic_inputs = ( + dynamic_inputs_dict.get("test_dynamic_inputs") + if dynamic_inputs_dict["dynamic_input_flags"]["use_dynamic_inputs"] + else None + ) # Load price data if not provided if price_data is None: @@ -2624,7 +2638,8 @@ def do_run_on_historic_data( "fees": fees if fees is not None else run_fingerprint["fees"], "arb_fees": arb_fees if arb_fees is not None else run_fingerprint["arb_fees"], "gas_cost": gas_cost if gas_cost is not None else run_fingerprint["gas_cost"], - "do_trades": False if raw_trades is None else run_fingerprint["do_trades"], + "do_trades": dynamic_inputs_dict["dynamic_input_flags"]["has_trades"], + "dynamic_input_flags": dynamic_inputs_dict["dynamic_input_flags"], # Include date strings for run-time use "startDateString": run_fingerprint["startDateString"], "endDateString": run_fingerprint["endDateString"], @@ -2680,10 +2695,7 @@ def do_run_on_historic_data( param, (data_dict["start_idx"], 0), data_dict["prices"], - dynamic_inputs_dict["train_period_trades"], - dynamic_inputs_dict["fees_array"], - dynamic_inputs_dict["gas_cost_array"], - dynamic_inputs_dict["arb_fees_array"], + train_dynamic_inputs, ) if low_data_mode: output_dict["final_prices"] = output_dict["prices"][-1] @@ -2699,10 +2711,7 @@ def do_run_on_historic_data( param, (data_dict["start_idx_test"], 0), data_dict["prices"], - dynamic_inputs_dict["test_period_trades"], - dynamic_inputs_dict["test_fees_array"], - dynamic_inputs_dict["test_gas_cost_array"], - dynamic_inputs_dict["test_arb_fees_array"], + test_dynamic_inputs, ) if low_data_mode: output_dict_test["final_prices"] = output_dict_test["prices"][-1] @@ -2741,14 +2750,10 @@ def do_run_on_historic_data_with_provided_coarse_weights( root=None, price_data=None, verbose=False, - raw_trades=None, fees=None, gas_cost=None, arb_fees=None, - fees_df=None, - gas_cost_df=None, - arb_fees_df=None, - lp_supply_df=None, + dynamic_input_frames: DynamicInputFrames = None, do_test_period=False, low_data_mode=False, ): @@ -2778,22 +2783,14 @@ def do_run_on_historic_data_with_provided_coarse_weights( Pre-loaded price data. verbose : bool, optional Print progress (default False). - raw_trades : DataFrame, optional - Real trade data to inject. fees : float, optional Swap fee override. gas_cost : float, optional Gas cost override. arb_fees : float, optional Arbitrageur fee override. - fees_df : DataFrame, optional - Time-varying swap fees. - gas_cost_df : DataFrame, optional - Time-varying gas costs. - arb_fees_df : DataFrame, optional - Time-varying arb fees. - lp_supply_df : DataFrame, optional - Time-varying LP supply changes. + dynamic_input_frames : DynamicInputFrames, optional + Optional container of trades / fee / gas / arb / LP supply DataFrames. do_test_period : bool, optional Run OOS test period (default False). low_data_mode : bool, optional @@ -2832,13 +2829,9 @@ def do_run_on_historic_data_with_provided_coarse_weights( np.random.seed(0) - dynamic_inputs_dict = get_trades_and_fees( + dynamic_inputs_dict = prepare_dynamic_inputs( run_fingerprint, - raw_trades, - fees_df, - gas_cost_df, - arb_fees_df, - lp_supply_df, + dynamic_input_frames=dynamic_input_frames, do_test_period=do_test_period, ) @@ -2881,7 +2874,8 @@ def do_run_on_historic_data_with_provided_coarse_weights( "fees": fees if fees is not None else run_fingerprint["fees"], "arb_fees": arb_fees if arb_fees is not None else run_fingerprint["arb_fees"], "gas_cost": gas_cost if gas_cost is not None else run_fingerprint["gas_cost"], - "do_trades": False if raw_trades is None else run_fingerprint["do_trades"], + "do_trades": dynamic_inputs_dict["dynamic_input_flags"]["has_trades"], + "dynamic_input_flags": dynamic_inputs_dict["dynamic_input_flags"], # Include date strings for run-time use "startDateString": run_fingerprint["startDateString"], "endDateString": run_fingerprint["endDateString"], @@ -2912,11 +2906,16 @@ def do_run_on_historic_data_with_provided_coarse_weights( initial_weights, minimum_weight, params, + jnp.zeros_like(initial_weights), + jnp.ones_like(initial_weights), run_fingerprint["max_memory_days"], chunk_period, chunk_period, 1.0, False, + False, + False, + False, ) weights = _jax_fine_weights_from_actual_starts_and_diffs( @@ -2948,78 +2947,35 @@ def do_run_on_historic_data_with_provided_coarse_weights( # weights=HashableArrayWrapper(weights), # initial_reserves=HashableArrayWrapper(params["initial_reserves"]), # ) - fees_array = dynamic_inputs_dict.get("fees_array") - arb_thresh_array = dynamic_inputs_dict.get("gas_cost_array") - arb_fees_array = dynamic_inputs_dict.get("arb_fees_array") - trade_array = dynamic_inputs_dict.get("trades") - lp_supply_array = dynamic_inputs_dict.get("lp_supply_array") - - if fees_array is None: - fees_array = jnp.array([static_dict["fees"]]) - if arb_thresh_array is None: - arb_thresh_array = jnp.array([static_dict["gas_cost"]]) - if arb_fees_array is None: - arb_fees_array = jnp.array([static_dict["arb_fees"]]) - - # initial_pool_value = run_fingerprint["initial_pool_value"] - # initial_value_per_token = arb_acted_upon_weights[0] * initial_pool_value - # initial_reserves = initial_value_per_token / arb_acted_upon_local_prices[0] - + dynamic_input_flags = dynamic_inputs_dict["dynamic_input_flags"] + dynamic_inputs = dynamic_inputs_dict["train_dynamic_inputs"] initial_reserves = params["initial_reserves"] - - # any of fees_array, arb_thresh_array, arb_fees_array, trade_array, and lp_supply_array - # can be singletons, in which case we repeat them for the length of the bout. - - # Determine the maximum leading dimension max_len = bout_length - 1 if run_fingerprint["arb_frequency"] != 1: max_len = max_len // run_fingerprint["arb_frequency"] - - fees_array = fees_array[:max_len] - arb_thresh_array = arb_thresh_array[:max_len] - arb_thresh_array = arb_thresh_array * 0.0 - arb_fees_array = arb_fees_array[:max_len] - if lp_supply_array is not None: - lp_supply_array = lp_supply_array[:max_len] - if trade_array is not None: - trade_array = trade_array[:max_len] - # Broadcast input arrays to match the maximum leading dimension. - # If they are singletons, this will just repeat them for the length of the bout. - # If they are arrays of length bout_length, this will cause no change. - fees_array_broadcast = jnp.broadcast_to( - fees_array, (max_len,) + fees_array.shape[1:] - ) - arb_thresh_array_broadcast = jnp.broadcast_to( - arb_thresh_array, (max_len,) + arb_thresh_array.shape[1:] - ) - arb_fees_array_broadcast = jnp.broadcast_to( - arb_fees_array, (max_len,) + arb_fees_array.shape[1:] - ) - # if lp_supply_array is not provided, we set it to a constant of 1.0 - if lp_supply_array is None: - lp_supply_array = jnp.array(1.0) - - lp_supply_array_broadcast = jnp.broadcast_to( - lp_supply_array, (max_len,) + lp_supply_array.shape[1:] + materialized_inputs = materialize_dynamic_inputs( + dynamic_inputs, + dynamic_input_flags, + static_dict, + scan_len=max_len, + do_trades=run_fingerprint["do_trades"], + dtype=local_prices.dtype, ) - # if we are doing trades, the trades array must be of the same length as the other arrays - if run_fingerprint["do_trades"]: - assert trade_array.shape[0] == max_len protocol_fee_split = run_fingerprint.get("protocol_fee_split", 0.0) reserves = _jax_calc_quantAMM_reserves_with_dynamic_inputs( initial_reserves, weights, local_prices, - fees_array_broadcast, - arb_thresh_array_broadcast, - arb_fees_array_broadcast, + materialized_inputs.fees, + materialized_inputs.gas_cost, + materialized_inputs.arb_fees, jnp.array(static_dict["all_sig_variations"]), - None, + materialized_inputs.trades, run_fingerprint["do_trades"], run_fingerprint["do_arb"], run_fingerprint["noise_trader_ratio"], - lp_supply_array_broadcast, + materialized_inputs.lp_supply, protocol_fee_split=protocol_fee_split, ) diff --git a/quantammsim/runners/multi_period_sgd.py b/quantammsim/runners/multi_period_sgd.py index 7802968..c81e00a 100644 --- a/quantammsim/runners/multi_period_sgd.py +++ b/quantammsim/runners/multi_period_sgd.py @@ -430,6 +430,7 @@ def multi_period_sgd_training( # Create base forward pass base_forward_pass = Partial( forward_pass, + dynamic_inputs=None, prices=data_dict["prices"], static_dict=Hashabledict(static_dict), pool=pool, @@ -506,6 +507,7 @@ def multi_period_sgd_training( partial_nograd = jit(Partial( forward_pass_nograd, + dynamic_inputs=None, prices=data_dict["prices"], static_dict=Hashabledict(static_dict), pool=pool, diff --git a/quantammsim/runners/training_evaluator.py b/quantammsim/runners/training_evaluator.py index e79a071..011b3f7 100644 --- a/quantammsim/runners/training_evaluator.py +++ b/quantammsim/runners/training_evaluator.py @@ -742,6 +742,7 @@ def _compute_metrics( eval_fn = jit(Partial( forward_pass_nograd, + dynamic_inputs=None, prices=data_dict["prices"], static_dict=Hashabledict(static_dict), pool=pool, diff --git a/quantammsim/simulator_analysis_tools/finance/param_financial_calculator.py b/quantammsim/simulator_analysis_tools/finance/param_financial_calculator.py index f92f7da..aa31519 100644 --- a/quantammsim/simulator_analysis_tools/finance/param_financial_calculator.py +++ b/quantammsim/simulator_analysis_tools/finance/param_financial_calculator.py @@ -20,6 +20,7 @@ from quantammsim.runners.jax_runners import do_run_on_historic_data from quantammsim.runners.jax_runner_utils import optimized_output_conversion +from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames import quantammsim.simulator_analysis_tools.finance.financial_analysis_calculator as fac import quantammsim.simulator_analysis_tools.finance.financial_analysis_functions as faf import quantammsim.simulator_analysis_tools.finance.financial_analysis_utils as fau @@ -238,6 +239,12 @@ def run_pool_simulation(simulationRunDto): run_fingerprint["fees"] = static_fee fee_steps_df = None + dynamic_input_frames = DynamicInputFrames( + trades=raw_trades, + fees=fee_steps_df, + gas_cost=gas_cost_df, + ) + print("run fingerprint-------------------", run_fingerprint) print("update rule parameter dict converted-------------------", update_rule_parameter_dict_converted) outputDict = do_run_on_historic_data( @@ -247,9 +254,7 @@ def run_pool_simulation(simulationRunDto): price_data=price_data_local, verbose=True, do_test_period=False, - raw_trades=raw_trades, - gas_cost_df=gas_cost_df, - fees_df=fee_steps_df + dynamic_input_frames=dynamic_input_frames, ) print("outputDict: ", outputDict.keys()) resultTimeSteps = optimized_output_conversion(simulationRunDto, outputDict, tokens) @@ -293,9 +298,7 @@ def run_pool_simulation(simulationRunDto): price_data=price_data_local, verbose=False, do_test_period=False, - raw_trades=raw_trades, - gas_cost_df=gas_cost_df, - fees_df=fee_steps_df, + dynamic_input_frames=dynamic_input_frames, ) # Extract final weights from the result. diff --git a/scripts/demo_run_chunks_from_chain_data.py b/scripts/demo_run_chunks_from_chain_data.py index fbd9cf9..09c19d1 100644 --- a/scripts/demo_run_chunks_from_chain_data.py +++ b/scripts/demo_run_chunks_from_chain_data.py @@ -28,6 +28,7 @@ import numpy as np import pandas as pd import matplotlib as mpl +from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames import matplotlib.pyplot as plt import jax.numpy as jnp @@ -474,10 +475,12 @@ def _df_meta_and_head(df, name, n=3): run_fingerprint=fingerprint, coarse_weights=cw_window, params=params, - fees_df=scraped["fees_df"], - gas_cost_df=scraped["gas_cost_df"], - lp_supply_df=scraped["lp_supply_df"], - arb_fees_df=scraped["arb_fees_df"], + dynamic_input_frames=DynamicInputFrames( + fees=scraped["fees_df"], + gas_cost=scraped["gas_cost_df"], + lp_supply=scraped["lp_supply_df"], + arb_fees=scraped["arb_fees_df"], + ), ) # ---------------- Correct, window-aligned plotting block (time-aware + plain y) ---------------- diff --git a/scripts/demo_run_from_chain_data.py b/scripts/demo_run_from_chain_data.py index 78ecbc4..b1f4139 100644 --- a/scripts/demo_run_from_chain_data.py +++ b/scripts/demo_run_from_chain_data.py @@ -1,4 +1,5 @@ import jax.numpy as jnp +from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames from quantammsim.core_simulator.param_utils import ( memory_days_to_logit_lamb, ) @@ -997,10 +998,12 @@ def generate_daily_variations(start_date_str, end_date_str): run_fingerprint=config["fingerprint"], coarse_weights=config["coarse_weights"], params=config["params"], - fees_df=config["fees_df"], - gas_cost_df=config["gas_cost_df"], - lp_supply_df=config["lp_supply_df"], - arb_fees_df=config["arb_fees_df"], + dynamic_input_frames=DynamicInputFrames( + fees=config["fees_df"], + gas_cost=config["gas_cost_df"], + lp_supply=config["lp_supply_df"], + arb_fees=config["arb_fees_df"], + ), ) print("-" * 80) print(f"Pool Type: {config['fingerprint']['rule']}") @@ -1191,4 +1194,4 @@ def generate_daily_variations(start_date_str, end_date_str): # actual_reserves_np=local_reserves, # actual_unix_values=datetime_array, # ) - # raise Exception("Stop here") \ No newline at end of file + # raise Exception("Stop here") diff --git a/scripts/reclamm/sim_vs_world_comparison.py b/scripts/reclamm/sim_vs_world_comparison.py index 0c754ea..bf7d19d 100644 --- a/scripts/reclamm/sim_vs_world_comparison.py +++ b/scripts/reclamm/sim_vs_world_comparison.py @@ -34,6 +34,7 @@ from pathlib import Path from datetime import datetime, timezone +from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames from quantammsim.runners.jax_runners import do_run_on_historic_data # ── On-chain reClAMM params ─────────────────────────────────────────────────── @@ -187,9 +188,18 @@ def sample_at_timestamps(minute_vals, start_unix_sec, timestamps_sec): return minute_vals[indices] -def run_pool(tokens, start, end, rule, fees, params, gas_cost=0.0, - protocol_fee_split=0.0, gas_cost_df=None, - onchain_initial_state=None): +def run_pool( + tokens, + start, + end, + rule, + fees, + params, + gas_cost=0.0, + protocol_fee_split=0.0, + dynamic_input_frames=None, + onchain_initial_state=None, +): """Run a quantammsim pool and return minute-level results. Returns (val_eth, price_ratio, start_unix_sec) where val_eth and @@ -220,7 +230,9 @@ def run_pool(tokens, start, end, rule, fees, params, gas_cost=0.0, fp["reclamm_initial_state"] = onchain_initial_state result = do_run_on_historic_data( - run_fingerprint=fp, params=params, gas_cost_df=gas_cost_df, + run_fingerprint=fp, + params=params, + dynamic_input_frames=dynamic_input_frames, ) # Prices: sorted tokens → [AAVE, ETH] in USD @@ -291,7 +303,8 @@ def run_gas_experiment(args): gas_df = load_gas_csv(pct) val_eth_min, _, _ = run_pool( tokens, start, end, "reclamm", ONCHAIN_FEES, pool_params, - protocol_fee_split=PROTOCOL_FEE_SPLIT, gas_cost_df=gas_df, + protocol_fee_split=PROTOCOL_FEE_SPLIT, + dynamic_input_frames=DynamicInputFrames(gas_cost=gas_df), ) gas_results_min[pct] = val_eth_min @@ -433,7 +446,8 @@ def run_gas_scale_experiment(args): gas_df["trade_gas_cost_usd"] = gas_df_raw["trade_gas_cost_usd"] * scale val_eth_min, pr_min, start_sec = run_pool( tokens, start, end, "reclamm", ONCHAIN_FEES, pool_params, - protocol_fee_split=PROTOCOL_FEE_SPLIT, gas_cost_df=gas_df, + protocol_fee_split=PROTOCOL_FEE_SPLIT, + dynamic_input_frames=DynamicInputFrames(gas_cost=gas_df), onchain_initial_state=onchain_state, ) results_min[(pct, scale)] = (val_eth_min, start_sec) @@ -662,7 +676,8 @@ def run_best_gas_experiment(args): gas_df_50p = load_gas_csv("50p") g50_min, _, _ = run_pool( tokens, start, end, "reclamm", ONCHAIN_FEES, pool_params, - protocol_fee_split=PROTOCOL_FEE_SPLIT, gas_cost_df=gas_df_50p, + protocol_fee_split=PROTOCOL_FEE_SPLIT, + dynamic_input_frames=DynamicInputFrames(gas_cost=gas_df_50p), onchain_initial_state=onchain_state, ) @@ -674,7 +689,8 @@ def run_best_gas_experiment(args): gas_df_75p_scaled["trade_gas_cost_usd"] *= 0.75 g75_min, _, _ = run_pool( tokens, start, end, "reclamm", ONCHAIN_FEES, pool_params, - protocol_fee_split=PROTOCOL_FEE_SPLIT, gas_cost_df=gas_df_75p_scaled, + protocol_fee_split=PROTOCOL_FEE_SPLIT, + dynamic_input_frames=DynamicInputFrames(gas_cost=gas_df_75p_scaled), onchain_initial_state=onchain_state, ) @@ -686,7 +702,8 @@ def run_best_gas_experiment(args): gas_df_90p_scaled["trade_gas_cost_usd"] *= 0.25 g90_min, _, _ = run_pool( tokens, start, end, "reclamm", ONCHAIN_FEES, pool_params, - protocol_fee_split=PROTOCOL_FEE_SPLIT, gas_cost_df=gas_df_90p_scaled, + protocol_fee_split=PROTOCOL_FEE_SPLIT, + dynamic_input_frames=DynamicInputFrames(gas_cost=gas_df_90p_scaled), onchain_initial_state=onchain_state, ) diff --git a/tests/integration/test_dynamic_gas_fees.py b/tests/integration/test_dynamic_gas_fees.py index 5870414..8eea4cd 100644 --- a/tests/integration/test_dynamic_gas_fees.py +++ b/tests/integration/test_dynamic_gas_fees.py @@ -9,6 +9,7 @@ import jax.numpy as jnp from pathlib import Path +from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames from quantammsim.runners.jax_runners import do_run_on_historic_data @@ -78,8 +79,12 @@ class TestDynamicGasAndFees: @pytest.mark.requires_data def test_run_with_gas_and_fees(self, base_fingerprint, base_params, gas_df, fees_df, data_root): """Test simulation with both gas costs and dynamic fees.""" + dynamic_input_frames = DynamicInputFrames(gas_cost=gas_df, fees=fees_df) result = do_run_on_historic_data( - base_fingerprint, base_params, root=data_root, gas_cost_df=gas_df, fees_df=fees_df + base_fingerprint, + base_params, + root=data_root, + dynamic_input_frames=dynamic_input_frames, ) assert result is not None @@ -92,8 +97,12 @@ def test_run_with_gas_and_fees(self, base_fingerprint, base_params, gas_df, fees @pytest.mark.requires_data def test_run_with_gas_only(self, base_fingerprint, base_params, gas_df, data_root): """Test simulation with gas costs only.""" + dynamic_input_frames = DynamicInputFrames(gas_cost=gas_df) result = do_run_on_historic_data( - base_fingerprint, base_params, root=data_root, gas_cost_df=gas_df + base_fingerprint, + base_params, + root=data_root, + dynamic_input_frames=dynamic_input_frames, ) assert result is not None @@ -104,8 +113,12 @@ def test_run_with_gas_only(self, base_fingerprint, base_params, gas_df, data_roo @pytest.mark.requires_data def test_run_with_fees_only(self, base_fingerprint, base_params, fees_df, data_root): """Test simulation with dynamic fees only.""" + dynamic_input_frames = DynamicInputFrames(fees=fees_df) result = do_run_on_historic_data( - base_fingerprint, base_params, root=data_root, fees_df=fees_df + base_fingerprint, + base_params, + root=data_root, + dynamic_input_frames=dynamic_input_frames, ) assert result is not None @@ -120,8 +133,12 @@ def test_gas_reduces_final_value(self, base_fingerprint, base_params, gas_df, da result_no_gas = do_run_on_historic_data(base_fingerprint, base_params, root=data_root) # Run with gas + dynamic_input_frames = DynamicInputFrames(gas_cost=gas_df) result_with_gas = do_run_on_historic_data( - base_fingerprint, base_params, root=data_root, gas_cost_df=gas_df + base_fingerprint, + base_params, + root=data_root, + dynamic_input_frames=dynamic_input_frames, ) if "final_value" in result_no_gas and "final_value" in result_with_gas: @@ -139,8 +156,12 @@ def test_fees_reduce_final_value(self, base_fingerprint, base_params, fees_df, d result_no_fees = do_run_on_historic_data(base_fingerprint, base_params, root=data_root) # Run with fees + dynamic_input_frames = DynamicInputFrames(fees=fees_df) result_with_fees = do_run_on_historic_data( - base_fingerprint, base_params, root=data_root, fees_df=fees_df + base_fingerprint, + base_params, + root=data_root, + dynamic_input_frames=dynamic_input_frames, ) if "final_value" in result_no_fees and "final_value" in result_with_fees: diff --git a/tests/pools/reCLAMM/test_reclamm_fee_revenue.py b/tests/pools/reCLAMM/test_reclamm_fee_revenue.py index 9406a96..4a31f1f 100644 --- a/tests/pools/reCLAMM/test_reclamm_fee_revenue.py +++ b/tests/pools/reCLAMM/test_reclamm_fee_revenue.py @@ -9,6 +9,7 @@ import numpy as np import numpy.testing as npt +from quantammsim.core_simulator.dynamic_inputs import DynamicInputArrays from quantammsim.pools.reCLAMM.reclamm_reserves import ( initialise_reclamm_reserves, _jax_calc_reclamm_reserves_with_fees, @@ -267,6 +268,7 @@ class TestPoolMethodWithFees: """pool.calculate_reserves_and_fee_revenue_with_fees returns correct tuple.""" def test_pool_method_with_fees(self): + from quantammsim.core_simulator.dynamic_inputs import DynamicInputArrays from quantammsim.pools.creator import create_pool from quantammsim.runners.jax_runner_utils import Hashabledict @@ -347,13 +349,20 @@ def test_pool_method_with_dynamic_inputs(self): fees_array = jnp.array([0.003]) arb_thresh_array = jnp.array([0.0]) arb_fees_array = jnp.array([0.0]) + dynamic_inputs = DynamicInputArrays( + trades=jnp.zeros((1, 3)), + fees=fees_array, + gas_cost=arb_thresh_array, + arb_fees=arb_fees_array, + lp_supply=jnp.ones((1,)), + ) reserves, fee_revenue = pool.calculate_reserves_and_fee_revenue_with_dynamic_inputs( - params, run_fingerprint, prices, start_index, - fees_array=fees_array, - arb_thresh_array=arb_thresh_array, - arb_fees_array=arb_fees_array, - trade_array=None, + params, + run_fingerprint, + prices, + start_index, + dynamic_inputs=dynamic_inputs, ) assert reserves.shape == (n_steps, 2) diff --git a/tests/pools/reCLAMM/test_reclamm_reserves.py b/tests/pools/reCLAMM/test_reclamm_reserves.py index ea77eb6..1db40cb 100644 --- a/tests/pools/reCLAMM/test_reclamm_reserves.py +++ b/tests/pools/reCLAMM/test_reclamm_reserves.py @@ -9,6 +9,7 @@ import numpy as np import numpy.testing as npt +from tests.conftest import TEST_DATA_DIR from quantammsim.pools.reCLAMM.reclamm_reserves import ( compute_invariant, compute_price_ratio, diff --git a/tests/scripts/dynamic_gas_test.py b/tests/scripts/dynamic_gas_test.py index d656b6a..920a7e3 100644 --- a/tests/scripts/dynamic_gas_test.py +++ b/tests/scripts/dynamic_gas_test.py @@ -1,7 +1,9 @@ -from quantammsim.runners.jax_runners import do_run_on_historic_data import jax.numpy as jnp import pandas as pd +from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames +from quantammsim.runners.jax_runners import do_run_on_historic_data + # Print the results print("=" * 100) print("Simulation Results:") @@ -32,11 +34,21 @@ run_fingerprint["do_trades"] = False result_w_gas_and_fees = do_run_on_historic_data( - run_fingerprint, params, gas_cost_df=gas_df, fees_df=fees_df + run_fingerprint, + params, + dynamic_input_frames=DynamicInputFrames(gas_cost=gas_df, fees=fees_df), +) +result_w_gas_only = do_run_on_historic_data( + run_fingerprint, + params, + dynamic_input_frames=DynamicInputFrames(gas_cost=gas_df), ) -result_w_gas_only = do_run_on_historic_data(run_fingerprint, params, gas_cost_df=gas_df) -result_w_fees_only = do_run_on_historic_data(run_fingerprint, params, fees_df=fees_df) +result_w_fees_only = do_run_on_historic_data( + run_fingerprint, + params, + dynamic_input_frames=DynamicInputFrames(fees=fees_df), +) print(result_w_gas_and_fees["value"][-1440+1]) print(result_w_gas_only["value"][-1440+1]) diff --git a/tests/unit/test_jax_runner_utils.py b/tests/unit/test_jax_runner_utils.py index 9cd95f6..620ecb8 100644 --- a/tests/unit/test_jax_runner_utils.py +++ b/tests/unit/test_jax_runner_utils.py @@ -13,6 +13,7 @@ import pytest import numpy as np import jax.numpy as jnp +import pandas as pd class TestHashabledict: @@ -159,6 +160,21 @@ def test_static_dict_is_hashable_with_real_fingerprint(self): h = hash(hd) assert isinstance(h, int) + def test_static_dict_accepts_dynamic_input_flags(self): + """Nested dynamic-input flags must remain hashable for JIT cache keys.""" + from quantammsim.core_simulator.dynamic_inputs import default_dynamic_input_flags + from quantammsim.runners.jax_runner_utils import create_static_dict, Hashabledict + + fp = self._make_fingerprint() + static = create_static_dict( + fp, + bout_length=10080, + overrides={"dynamic_input_flags": default_dynamic_input_flags()}, + ) + + assert static["dynamic_input_flags"]["use_dynamic_inputs"] is False + assert isinstance(hash(Hashabledict(static)), int) + def test_unknown_array_fields_dropped_with_warning(self): """Arrays not in _TRAINING_ONLY_FIELDS are dropped with a warning. @@ -237,6 +253,277 @@ def test_equality_with_non_dict_returns_false(self): assert d != [1, 2, 3] +class TestDynamicInputPreparation: + """Tests for dynamic input container construction and normalization.""" + + def test_empty_dynamic_input_arrays_have_stable_shapes(self): + """The empty hot-path bundle should use singleton fee-like placeholders only.""" + from quantammsim.core_simulator.dynamic_inputs import empty_dynamic_input_arrays + + dynamic_inputs = empty_dynamic_input_arrays() + + assert dynamic_inputs.trades is None + assert dynamic_inputs.fees.shape == (1,) + assert dynamic_inputs.gas_cost.shape == (1,) + assert dynamic_inputs.arb_fees.shape == (1,) + assert dynamic_inputs.lp_supply.shape == (1,) + + def test_dynamic_input_flags_reflect_present_frames(self): + """Frame-presence flags should drive static dynamic-input dispatch.""" + from quantammsim.core_simulator.dynamic_inputs import ( + DynamicInputFrames, + dynamic_input_flags_from_frames, + ) + + flags = dynamic_input_flags_from_frames( + DynamicInputFrames( + trades=pd.DataFrame({"unix": [1], "token_in": ["ETH"], "token_out": ["USDC"], "amount_in": [1.0]}), + fees=pd.DataFrame({"unix": [1], "fees": [0.003]}), + gas_cost=pd.DataFrame({"unix": [1], "trade_gas_cost_usd": [2.0]}), + ) + ) + + assert flags["use_dynamic_inputs"] is True + assert flags["has_trades"] is True + assert flags["has_dynamic_fees"] is True + assert flags["has_dynamic_gas_cost"] is True + assert flags["has_dynamic_arb_fees"] is False + assert flags["has_lp_supply"] is False + + def test_prepare_dynamic_inputs_preserves_fixed_hot_path_structure(self): + """Normalization should return fixed bundles plus static dispatch flags.""" + from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames + from quantammsim.runners.jax_runner_utils import prepare_dynamic_inputs + + run_fingerprint = { + "tokens": ["ETH", "USDC"], + "startDateString": "2023-01-01 00:00:00", + "endDateString": "2023-01-01 00:02:00", + "endTestDateString": "2023-01-01 00:04:00", + } + + dynamic_input_frames = DynamicInputFrames( + trades=pd.DataFrame( + { + "unix": [1672531200000, 1672531320000], + "token_in": ["ETH", "USDC"], + "token_out": ["USDC", "ETH"], + "amount_in": [1.5, 2.0], + } + ), + fees=pd.DataFrame({"unix": [1672531200000], "fees": [0.003]}), + gas_cost=pd.DataFrame({"unix": [1672531200000], "trade_gas_cost_usd": [3.25]}), + arb_fees=pd.DataFrame({"unix": [1672531200000], "arb_fees": [0.0005]}), + lp_supply=pd.DataFrame({"unix": [1672531200000], "lp_supply": [1250.0]}), + ) + + prepared = prepare_dynamic_inputs( + run_fingerprint, + dynamic_input_frames=dynamic_input_frames, + do_test_period=True, + ) + + train_inputs = prepared["train_dynamic_inputs"] + test_inputs = prepared["test_dynamic_inputs"] + flags = prepared["dynamic_input_flags"] + + assert flags["use_dynamic_inputs"] is True + assert flags["has_trades"] is True + assert flags["has_dynamic_fees"] is True + assert flags["has_dynamic_gas_cost"] is True + assert flags["has_dynamic_arb_fees"] is True + assert flags["has_lp_supply"] is True + assert train_inputs.trades.shape == (2, 3) + assert train_inputs.fees.shape == (2,) + assert train_inputs.gas_cost.shape == (2,) + assert train_inputs.arb_fees.shape == (2,) + assert train_inputs.lp_supply.shape == (2,) + assert test_inputs.trades.shape == (2, 3) + assert test_inputs.fees.shape == (2,) + assert test_inputs.gas_cost.shape == (2,) + assert test_inputs.arb_fees.shape == (2,) + assert test_inputs.lp_supply.shape == (2,) + np.testing.assert_allclose(np.asarray(train_inputs.fees), np.array([0.003, 0.003])) + + def test_prepare_dynamic_inputs_uses_correct_test_period_values(self): + """Test-period arrays should use values effective from the test window onward.""" + from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames + from quantammsim.runners.jax_runner_utils import prepare_dynamic_inputs + + run_fingerprint = { + "tokens": ["ETH", "USDC"], + "startDateString": "2023-01-01 00:00:00", + "endDateString": "2023-01-01 00:02:00", + "endTestDateString": "2023-01-01 00:04:00", + } + end_unix = pd.Timestamp(run_fingerprint["endDateString"]).value // 10**6 + + prepared = prepare_dynamic_inputs( + run_fingerprint, + dynamic_input_frames=DynamicInputFrames( + fees=pd.DataFrame( + {"unix": [1672531200000, end_unix], "fees": [0.003, 0.004]} + ), + gas_cost=pd.DataFrame( + { + "unix": [1672531200000, end_unix], + "trade_gas_cost_usd": [1.5, 2.5], + } + ), + arb_fees=pd.DataFrame( + {"unix": [1672531200000, end_unix], "arb_fees": [0.0001, 0.0002]} + ), + lp_supply=pd.DataFrame( + {"unix": [1672531200000, end_unix], "lp_supply": [1000.0, 2000.0]} + ), + ), + do_test_period=True, + ) + + np.testing.assert_allclose( + np.asarray(prepared["test_dynamic_inputs"].fees), + np.array([0.004, 0.004]), + ) + np.testing.assert_allclose( + np.asarray(prepared["test_dynamic_inputs"].gas_cost), + np.array([2.5, 2.5]), + ) + np.testing.assert_allclose( + np.asarray(prepared["test_dynamic_inputs"].arb_fees), + np.array([0.0002, 0.0002]), + ) + np.testing.assert_allclose( + np.asarray(prepared["test_dynamic_inputs"].lp_supply), + np.array([2000.0, 2000.0]), + ) + + def test_resolve_dynamic_input_flags_promotes_explicit_bundle(self): + """Passing a bundle directly should force dynamic-path dispatch.""" + from quantammsim.core_simulator.dynamic_inputs import ( + empty_dynamic_input_arrays, + resolve_dynamic_input_flags, + ) + + flags = resolve_dynamic_input_flags( + empty_dynamic_input_arrays(), + { + "use_dynamic_inputs": False, + "has_trades": False, + "has_dynamic_fees": False, + "has_dynamic_gas_cost": False, + "has_dynamic_arb_fees": False, + "has_lp_supply": False, + }, + ) + + assert flags["use_dynamic_inputs"] is True + + def test_resolve_dynamic_input_components_falls_back_to_static_scalars(self): + """Static scalar config should materialize as singleton arrays when no frames are present.""" + from quantammsim.core_simulator.dynamic_inputs import ( + default_dynamic_input_flags, + resolve_dynamic_input_components, + ) + + resolved = resolve_dynamic_input_components( + dynamic_inputs=None, + dynamic_input_flags=default_dynamic_input_flags(), + static_dict={"fees": 0.003, "gas_cost": 2.5, "arb_fees": 0.0001}, + ) + + assert resolved["trades"] is None + np.testing.assert_allclose(np.asarray(resolved["fees"]), np.array([0.003])) + np.testing.assert_allclose(np.asarray(resolved["gas_cost"]), np.array([2.5])) + np.testing.assert_allclose(np.asarray(resolved["arb_fees"]), np.array([0.0001])) + np.testing.assert_allclose(np.asarray(resolved["lp_supply"]), np.array([1.0])) + + def test_resolve_dynamic_input_components_prefers_dynamic_values(self): + """Dynamic arrays should override static scalar defaults for enabled fields.""" + from quantammsim.core_simulator.dynamic_inputs import ( + DynamicInputArrays, + resolve_dynamic_input_components, + ) + + dynamic_inputs = DynamicInputArrays( + trades=jnp.array([[0.0, 1.0, 5.0]]), + fees=jnp.array([0.004]), + gas_cost=jnp.array([3.0]), + arb_fees=jnp.array([0.0003]), + lp_supply=jnp.array([1500.0]), + ) + flags = { + "use_dynamic_inputs": True, + "has_trades": True, + "has_dynamic_fees": True, + "has_dynamic_gas_cost": True, + "has_dynamic_arb_fees": True, + "has_lp_supply": True, + } + + resolved = resolve_dynamic_input_components( + dynamic_inputs=dynamic_inputs, + dynamic_input_flags=flags, + static_dict={"fees": 0.0, "gas_cost": 0.0, "arb_fees": 0.0}, + ) + + np.testing.assert_allclose(np.asarray(resolved["trades"]), np.array([[0.0, 1.0, 5.0]])) + np.testing.assert_allclose(np.asarray(resolved["fees"]), np.array([0.004])) + np.testing.assert_allclose(np.asarray(resolved["gas_cost"]), np.array([3.0])) + np.testing.assert_allclose(np.asarray(resolved["arb_fees"]), np.array([0.0003])) + np.testing.assert_allclose(np.asarray(resolved["lp_supply"]), np.array([1500.0])) + + def test_materialize_dynamic_inputs_leaves_trades_optional(self): + """No-trade paths should not expand placeholder trades into the scan inputs.""" + from quantammsim.core_simulator.dynamic_inputs import ( + empty_dynamic_input_arrays, + materialize_dynamic_inputs, + ) + + materialized = materialize_dynamic_inputs( + empty_dynamic_input_arrays(), + { + "use_dynamic_inputs": True, + "has_trades": False, + "has_dynamic_fees": False, + "has_dynamic_gas_cost": False, + "has_dynamic_arb_fees": False, + "has_lp_supply": False, + }, + static_dict={"fees": 0.003, "gas_cost": 2.5, "arb_fees": 0.0001}, + scan_len=4, + do_trades=False, + ) + + assert materialized.trades is None + np.testing.assert_allclose(np.asarray(materialized.fees), np.full(4, 0.003)) + np.testing.assert_allclose(np.asarray(materialized.gas_cost), np.full(4, 2.5)) + np.testing.assert_allclose(np.asarray(materialized.arb_fees), np.full(4, 0.0001)) + np.testing.assert_allclose(np.asarray(materialized.lp_supply), np.ones(4)) + + def test_materialize_dynamic_inputs_requires_trades_when_enabled(self): + """Trade-enabled scans should fail fast if no trade path is available.""" + from quantammsim.core_simulator.dynamic_inputs import ( + empty_dynamic_input_arrays, + materialize_dynamic_inputs, + ) + + with pytest.raises(ValueError, match="Trades must be provided"): + materialize_dynamic_inputs( + empty_dynamic_input_arrays(), + { + "use_dynamic_inputs": True, + "has_trades": False, + "has_dynamic_fees": True, + "has_dynamic_gas_cost": False, + "has_dynamic_arb_fees": False, + "has_lp_supply": False, + }, + static_dict={"fees": 0.003, "gas_cost": 0.0, "arb_fees": 0.0}, + scan_len=2, + do_trades=True, + ) + + class TestGetSigVariations: """Tests for get_sig_variations function.""" diff --git a/tests/unit/test_jax_runners_comprehensive.py b/tests/unit/test_jax_runners_comprehensive.py index dcdd9dd..995e6d8 100644 --- a/tests/unit/test_jax_runners_comprehensive.py +++ b/tests/unit/test_jax_runners_comprehensive.py @@ -10,6 +10,7 @@ """ import pytest import numpy as np +import pandas as pd import jax.numpy as jnp import jax from copy import deepcopy @@ -18,6 +19,7 @@ from quantammsim.runners.jax_runners import ( train_on_historic_data, do_run_on_historic_data, + do_run_on_historic_data_with_provided_coarse_weights, ) from quantammsim.runners.jax_runner_utils import ( NestedHashabledict, @@ -31,8 +33,10 @@ create_static_dict, ) from quantammsim.runners.default_run_fingerprint import run_fingerprint_defaults +from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames from quantammsim.core_simulator.param_utils import recursive_default_set, check_run_fingerprint from quantammsim.pools.creator import create_pool +from quantammsim.utils.data_processing.historic_data_utils import get_data_dict from tests.conftest import TEST_DATA_DIR @@ -389,6 +393,219 @@ def test_multiple_param_sets(self, defaulted_run_fingerprint, sample_params): assert isinstance(results, list) assert len(results) == 2 + def test_dynamic_trades_change_balancer_reserves(self, defaulted_run_fingerprint): + """Dynamic trade input should change the reserve path in the runner.""" + fp = deepcopy(defaulted_run_fingerprint) + fp["rule"] = "balancer" + fp["do_arb"] = False + fp["fees"] = 0.0 + fp["gas_cost"] = 0.0 + fp["arb_fees"] = 0.0 + params = {"initial_weights_logits": jnp.array([0.0, 0.0])} + + trade_unix = pd.Timestamp(fp["startDateString"]).value // 10**6 + trades_df = pd.DataFrame( + { + "unix": [trade_unix], + "token_in": ["ETH"], + "token_out": ["USDC"], + "amount_in": [100.0], + } + ) + + result_without_trades = do_run_on_historic_data( + fp, + params=params, + root=TEST_DATA_DIR, + verbose=False, + ) + result_with_trades = do_run_on_historic_data( + fp, + params=params, + root=TEST_DATA_DIR, + verbose=False, + dynamic_input_frames=DynamicInputFrames(trades=trades_df), + ) + + assert not np.allclose( + np.asarray(result_without_trades["reserves"]), + np.asarray(result_with_trades["reserves"]), + ) + assert result_with_trades["reserves"][0, 0] > result_without_trades["reserves"][0, 0] + assert result_with_trades["reserves"][0, 1] < result_without_trades["reserves"][0, 1] + + def test_dynamic_arb_fees_match_scalar_arb_fees(self, defaulted_run_fingerprint): + """Constant dynamic arb fees should match the scalar arb-fee path.""" + fp = deepcopy(defaulted_run_fingerprint) + fp["rule"] = "balancer" + fp["fees"] = 0.003 + fp["do_arb"] = True + params = {"initial_weights_logits": jnp.array([0.0, 0.0])} + + arb_fee = 0.002 + arb_fees_df = pd.DataFrame( + { + "unix": [pd.Timestamp(fp["startDateString"]).value // 10**6], + "arb_fees": [arb_fee], + } + ) + + result_scalar = do_run_on_historic_data( + fp, + params=params, + root=TEST_DATA_DIR, + verbose=False, + arb_fees=arb_fee, + ) + result_dynamic = do_run_on_historic_data( + fp, + params=params, + root=TEST_DATA_DIR, + verbose=False, + dynamic_input_frames=DynamicInputFrames(arb_fees=arb_fees_df), + ) + + np.testing.assert_allclose( + np.asarray(result_dynamic["value"]), + np.asarray(result_scalar["value"]), + rtol=1e-6, + atol=1e-6, + ) + + def test_dynamic_lp_supply_changes_momentum_runner_path(self, defaulted_run_fingerprint, sample_params): + """LP supply changes should affect the main momentum runner path.""" + fp = deepcopy(defaulted_run_fingerprint) + fp["protocol_fee_split"] = 0.25 + + start_unix = pd.Timestamp(fp["startDateString"]).value // 10**6 + midpoint_unix = start_unix + 3 * 1440 * 60 * 1000 + constant_lp_supply_df = pd.DataFrame( + { + "unix": [start_unix], + "lp_supply": [1.0], + } + ) + stepped_lp_supply_df = pd.DataFrame( + { + "unix": [start_unix, midpoint_unix], + "lp_supply": [1.0, 2.0], + } + ) + + result_without_lp_supply = do_run_on_historic_data( + fp, + params=sample_params, + root=TEST_DATA_DIR, + verbose=False, + ) + result_with_constant_lp_supply = do_run_on_historic_data( + fp, + params=sample_params, + root=TEST_DATA_DIR, + verbose=False, + dynamic_input_frames=DynamicInputFrames(lp_supply=constant_lp_supply_df), + ) + result_with_stepped_lp_supply = do_run_on_historic_data( + fp, + params=sample_params, + root=TEST_DATA_DIR, + verbose=False, + dynamic_input_frames=DynamicInputFrames(lp_supply=stepped_lp_supply_df), + ) + + np.testing.assert_allclose( + np.asarray(result_with_constant_lp_supply["value"]), + np.asarray(result_without_lp_supply["value"]), + rtol=1e-6, + atol=1e-6, + ) + assert not np.allclose( + np.asarray(result_with_stepped_lp_supply["reserves"][-1]), + np.asarray(result_without_lp_supply["reserves"][-1]), + ) + assert float(result_with_stepped_lp_supply["final_value"]) != pytest.approx( + float(result_without_lp_supply["final_value"]) + ) + + def test_provided_coarse_weights_respect_scalar_and_dynamic_gas(self, defaulted_run_fingerprint, sample_params): + """Provided-coarse-weight path should honor both scalar gas and dynamic gas arrays.""" + fp = deepcopy(defaulted_run_fingerprint) + fp["protocol_fee_split"] = 0.0 + + data_dict = get_data_dict( + list_of_tickers=fp["tokens"], + run_fingerprint=fp, + data_kind=fp["optimisation_settings"]["training_data_kind"], + root=TEST_DATA_DIR, + max_memory_days=fp["max_memory_days"], + start_date_string=fp["startDateString"], + end_time_string=fp["endDateString"], + start_time_test_string=fp["endDateString"], + end_time_test_string=fp["endTestDateString"], + max_mc_version=fp["optimisation_settings"]["max_mc_version"], + do_test_period=False, + ) + + coarse_unix_values = ( + pd.date_range( + start=pd.Timestamp(fp["startDateString"]), + end=pd.Timestamp(fp["endDateString"]), + freq=f"{fp['chunk_period']}min", + ) + .astype(np.int64) + // 10**6 + ) + coarse_weights = { + "weights": jnp.tile(jnp.array([[0.5, 0.5]]), (len(coarse_unix_values), 1)), + "unix_values": jnp.asarray(coarse_unix_values), + } + + params = deepcopy(sample_params) + initial_prices = jnp.asarray(data_dict["prices"][data_dict["start_idx"]], dtype=jnp.float64) + params["initial_reserves"] = (jnp.array([0.5, 0.5]) * fp["initial_pool_value"]) / initial_prices + + gas_cost = 50.0 + gas_cost_df = pd.DataFrame( + { + "unix": [pd.Timestamp(fp["startDateString"]).value // 10**6], + "trade_gas_cost_usd": [gas_cost], + } + ) + + result_no_gas = do_run_on_historic_data_with_provided_coarse_weights( + fp, + coarse_weights=coarse_weights, + params=params, + root=TEST_DATA_DIR, + verbose=False, + ) + result_scalar_gas = do_run_on_historic_data_with_provided_coarse_weights( + fp, + coarse_weights=coarse_weights, + params=params, + root=TEST_DATA_DIR, + verbose=False, + gas_cost=gas_cost, + ) + result_dynamic_gas = do_run_on_historic_data_with_provided_coarse_weights( + fp, + coarse_weights=coarse_weights, + params=params, + root=TEST_DATA_DIR, + verbose=False, + dynamic_input_frames=DynamicInputFrames(gas_cost=gas_cost_df), + ) + + assert float(result_scalar_gas["final_value"]) != pytest.approx( + float(result_no_gas["final_value"]) + ) + np.testing.assert_allclose( + np.asarray(result_dynamic_gas["value"]), + np.asarray(result_scalar_gas["value"]), + rtol=1e-6, + atol=1e-6, + ) + # ============================================================================ # Validation and Early Stopping Tests diff --git a/tests/unit/test_lint_bugs.py b/tests/unit/test_lint_bugs.py index 95f7ee6..bbb49e6 100644 --- a/tests/unit/test_lint_bugs.py +++ b/tests/unit/test_lint_bugs.py @@ -117,16 +117,13 @@ def calculate_reserves_zero_fees(self, params, static_dict, prices, start_index) prices = jnp.ones((20, 2)) start_index = jnp.array([0, 0]) - # __wrapped__ arg order: params, start_index, prices, trades, fees, - # gas_cost, arb_fees, pool, static_dict + # __wrapped__ arg order: params, start_index, prices, dynamic_inputs, + # pool, static_dict result = forward_pass.__wrapped__( {}, # params start_index, # start_index prices, # prices - None, # trades_array - None, # fees_array - None, # gas_cost_array - None, # arb_fees_array + None, # dynamic_inputs _MockPool(), # pool static_dict, # static_dict )