Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 226 additions & 22 deletions quantammsim/pools/G3M/optimal_n_pool_arb.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,18 +166,18 @@ def construct_optimal_trade_jnp(
valid_post_trade_reserves = (
jnp.sum(initial_reserves + active_overall_trade > 0) == n
)
valid_post_trade_constant = (
jnp.prod(
(
initial_reserves
+ active_overall_trade
* (fee_gamma ** (trade_to_direction_jnp(active_overall_trade)))
)
** initial_weights
post_trade_constant = jnp.prod(
(
initial_reserves
+ active_overall_trade
* (fee_gamma ** (trade_to_direction_jnp(active_overall_trade)))
)
- initial_constant
>= slack
** initial_weights
)
# Use proportional tolerance: (post - initial) / initial >= slack
# This makes the tolerance scale-invariant with pool size
relative_diff = (post_trade_constant - initial_constant) / initial_constant
valid_post_trade_constant = relative_diff >= slack
valid_trade = jnp.logical_and(valid_post_trade_reserves, valid_post_trade_constant)
return jnp.where(valid_trade, active_overall_trade, 0)
# return active_overall_trade, valid_post_trade_reserves * valid_post_trade_constant
Expand Down Expand Up @@ -338,18 +338,18 @@ def calc_optimal_trade_for_one_signature(
valid_post_trade_reserves = (
jnp.sum(initial_reserves + active_overall_trade > 0) == n
)
valid_post_trade_constant = (
jnp.prod(
(
initial_reserves
+ active_overall_trade
* (fee_gamma ** (trade_to_direction_jnp(active_overall_trade)))
)
** initial_weights
post_trade_constant = jnp.prod(
(
initial_reserves
+ active_overall_trade
* (fee_gamma ** (trade_to_direction_jnp(active_overall_trade)))
)
- initial_constant
>= slack
** initial_weights
)
# Use proportional tolerance: (post - initial) / initial >= slack
# This makes the tolerance scale-invariant with pool size
relative_diff = (post_trade_constant - initial_constant) / initial_constant
valid_post_trade_constant = relative_diff >= slack
valid_trade = jnp.logical_and(valid_post_trade_reserves, valid_post_trade_constant)
return jnp.where(valid_trade, active_overall_trade, 0)
# return {
Expand Down Expand Up @@ -413,7 +413,7 @@ def parallelised_optimal_trade_sifter(
tokens_to_drop,
fee_gamma,
n,
0,
slack,
)

profits = -(overall_trades * local_prices).sum(-1)
Expand Down Expand Up @@ -459,7 +459,211 @@ def wrapped_parallelised_optimal_trade_sifter(
tokens_to_drop,
fee_gamma,
n,
slack=0,
slack=slack,
)
return trade


# ============================================================================
# Arbitrage Fee-Adjusted Optimal Trade Functions
# ============================================================================
#
# When arb_fees > 0, the arbitrageur faces external rebalancing costs that
# affect the optimal trade size. These wrapper functions incorporate arb_fees
# into the optimal trade calculation by adjusting effective market prices.
#
# Mathematical derivation:
# The modified objective is: max_Φ: -∑(mₚ,ᵢΦᵢ) - (arb_fees/2) × ∑|mₚ,ᵢΦᵢ|
# For a fixed trade signature, |Φᵢ| = sign(Φᵢ) × Φᵢ, so:
# objective = -∑(mₚ,ᵢ × (1 + arb_fees/2 × sᵢ) × Φᵢ)
# where sᵢ = +1 for tokens in, -1 for tokens out.
#
# This is equivalent to using adjusted effective market prices:
# m̃ₚ,ᵢ = mₚ,ᵢ × (1 + arb_fees/2 × sᵢ)
#
# See https://arxiv.org/abs/2402.06731 for the base derivation.
# ============================================================================


def adjust_prices_for_arb_fees(
    local_prices,
    active_trade_direction,
    tokens_to_drop,
    arb_fees,
):
    """
    Scale market prices by a per-token factor reflecting external arb costs.

    Every active token's price is multiplied by ``1 + arb_fees / 2`` when its
    direction flag is 1 and by ``1 - arb_fees / 2`` when its flag is 0.
    Tokens flagged in ``tokens_to_drop`` are returned at their original price.

    Parameters
    ----------
    local_prices : jnp.ndarray
        Market price of each token.
    active_trade_direction : jnp.ndarray
        Per-token direction flag: 1 = token going into the pool,
        0 = token coming out of the pool.
    tokens_to_drop : jnp.ndarray
        Boolean mask; True marks a token that is inactive for this signature.
    arb_fees : float
        External arbitrage fee as a fraction (e.g., 0.001 = 0.1%).

    Returns
    -------
    jnp.ndarray
        ``local_prices`` multiplied element-wise by the fee factor.
    """
    # Map the {0, 1} direction flags onto {-1, +1}:
    #   +1 -> token goes IN to the pool (effective price is raised)
    #   -1 -> token comes OUT of the pool (effective price is lowered)
    signed_direction = 2 * active_trade_direction - 1

    # Half of the fee is applied to each side of the trade.
    half_fee = 0.5 * arb_fees

    # Dropped tokens get a neutral factor of 1.0 so their price is untouched.
    fee_factor = jnp.where(
        tokens_to_drop,
        1.0,
        1.0 + half_fee * signed_direction,
    )
    return local_prices * fee_factor


# Batched form of adjust_prices_for_arb_fees: the direction flags and the
# drop mask are mapped over their leading axis, while local_prices and
# arb_fees are shared by every batch element.
adjust_prices_for_arb_fees_across_signatures = vmap(
    adjust_prices_for_arb_fees,
    in_axes=(None, 0, 0, None),
)


def precalc_components_with_arb_fees_for_one_signature(
    initial_weights,
    local_prices,
    fee_gamma,
    tokens_to_drop,
    active_trade_direction,
    leave_one_out_idx,
    arb_fees,
):
    """
    Run the base per-signature precalculation on arb-fee-adjusted prices.

    The prices are first passed through ``adjust_prices_for_arb_fees``; every
    other argument is forwarded unchanged to
    ``precalc_components_of_optimal_trade_for_one_signature``.

    Parameters
    ----------
    initial_weights : jnp.ndarray
        Pool weight of each token.
    local_prices : jnp.ndarray
        Market price of each token (before any arb-fee adjustment).
    fee_gamma : float
        Pool fee parameter (1 - swap_fee).
    tokens_to_drop : jnp.ndarray
        Boolean mask; True marks a token that is inactive for this signature.
    active_trade_direction : jnp.ndarray
        Per-token direction flag: 1 = token going into the pool,
        0 = token coming out of the pool.
    leave_one_out_idx : jnp.ndarray
        Index array for computing products excluding each element.
    arb_fees : float
        External arbitrage fee as a fraction.

    Returns
    -------
    tuple
        Whatever the base precalc function returns, i.e.
        (active_initial_weights, per_asset_ratio, all_other_assets_ratio).
    """
    # Only the price argument differs from a plain call to the base function.
    fee_adjusted_prices = adjust_prices_for_arb_fees(
        local_prices=local_prices,
        active_trade_direction=active_trade_direction,
        tokens_to_drop=tokens_to_drop,
        arb_fees=arb_fees,
    )
    components = precalc_components_of_optimal_trade_for_one_signature(
        initial_weights,
        fee_adjusted_prices,
        fee_gamma,
        tokens_to_drop,
        active_trade_direction,
        leave_one_out_idx,
    )
    return components


# Batch over trade signatures: tokens_to_drop, active_trade_direction and
# leave_one_out_idx are mapped over their leading axis; weights, prices,
# fee_gamma and arb_fees are shared by every signature.
precalc_components_with_arb_fees_across_signatures = vmap(
    precalc_components_with_arb_fees_for_one_signature,
    in_axes=(None, None, None, 0, 0, 0, None),
)


# Second batching level on top of the per-signature batch: weights and prices
# are mapped over their leading axis; fee_gamma, the signature arrays and
# arb_fees are held fixed across that axis.
precalc_components_with_arb_fees_across_weights_and_prices = vmap(
    precalc_components_with_arb_fees_across_signatures,
    in_axes=(0, 0, None, None, None, None, None),
)


# Same as the weights-and-prices batching, except that fee_gamma AND arb_fees
# are also mapped over their leading axis; only the signature arrays are
# shared.
# NOTE(review): because arb_fees is mapped on axis 0 here, callers must pass
# it as an array with the same leading dimension as weights/prices/fee_gamma;
# a Python scalar cannot be mapped over — confirm call sites.
precalc_components_with_arb_fees_across_weights_and_prices_and_dynamic_fees = vmap(
    precalc_components_with_arb_fees_across_signatures,
    in_axes=(0, 0, 0, None, None, None, 0),
)


def wrapped_parallelised_optimal_trade_sifter_with_arb_fees(
    initial_weights,
    local_prices,
    initial_reserves,
    fee_gamma,
    all_sig_variations,
    n,
    arb_fees=0.0,
    slack=0,
):
    """
    Compute an arbitrage trade with external arb fees folded into the prices.

    Works in three steps:

    1. derive the per-signature shared arrays from ``all_sig_variations``;
    2. compute the per-signature trade components using prices adjusted for
       ``arb_fees``;
    3. hand everything to ``parallelised_optimal_trade_sifter`` and return
       its result.

    Parameters
    ----------
    initial_weights : jnp.ndarray
        Pool weight of each token.
    local_prices : jnp.ndarray
        Market price of each token.
    initial_reserves : jnp.ndarray
        Current pool reserves of each token.
    fee_gamma : float
        Pool fee parameter (1 - swap_fee).
    all_sig_variations : jnp.ndarray
        All trade-signature variations to evaluate.
    n : int
        Number of tokens.
    arb_fees : float, optional
        External arbitrage fee as a fraction (default 0.0).
    slack : float, optional
        Slack forwarded to the sifter for invariant validation (default 0).

    Returns
    -------
    jnp.ndarray
        The trade vector returned by ``parallelised_optimal_trade_sifter``.
    """
    # The first element of the shared values is not needed here.
    shared_values = precalc_shared_values_for_all_signatures(all_sig_variations, n)
    _, active_trade_directions, tokens_to_drop, leave_one_out_idx = shared_values

    # Per-signature components, computed from arb-fee-adjusted prices.
    fee_adjusted_components = precalc_components_with_arb_fees_across_signatures(
        initial_weights,
        local_prices,
        fee_gamma,
        tokens_to_drop,
        active_trade_directions,
        leave_one_out_idx,
        arb_fees,
    )
    active_initial_weights, per_asset_ratio, all_other_assets_ratio = (
        fee_adjusted_components
    )

    # NOTE(review): the sifter is given the *unadjusted* local_prices. If it
    # ranks candidate trades by profit at those prices, that ranking does not
    # include arb_fees — confirm this is intended.
    return parallelised_optimal_trade_sifter(
        initial_reserves,
        initial_weights,
        local_prices,
        active_initial_weights,
        active_trade_directions,
        per_asset_ratio,
        all_other_assets_ratio,
        tokens_to_drop,
        fee_gamma,
        n,
        slack=slack,
    )
19 changes: 10 additions & 9 deletions quantammsim/pools/G3M/quantamm/weight_calculations/fine_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,15 +485,16 @@ def calc_fine_weight_output(
maximum_change=maximum_change,
method=weight_interpolation_method,
)
if rule_outputs_are_weights:
return weights
else:
return jnp.vstack(
[
jnp.ones((chunk_period, n_assets), dtype=jnp.float64) * initial_weights,
weights,
]
)
# Prepend chunk_period rows of initial weights for both paths.
# This ensures weights at fine timestep t don't use prices beyond t.
# Without this prepending for weight-outputting rules, there would be
# a 1-step lookahead bias (weights computed from future prices).
return jnp.vstack(
[
jnp.ones((chunk_period, n_assets), dtype=jnp.float64) * initial_weights,
weights,
]
)


calc_fine_weight_output_from_weight_changes = jit(
Expand Down
21 changes: 21 additions & 0 deletions quantammsim/utils/data_processing/amalgamated_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,27 @@ def import_crypto_historical_data(token, root_path):


def forward_fill_ohlcv_data(df, token):
"""Forward fill OHLCV data to create a complete minute-level time series.

Creates a complete minute-level index between the first and last timestamps,
then fills missing values using appropriate strategies for each column type.

Parameters
----------
df : pd.DataFrame
DataFrame with 'unix' column (millisecond timestamps) and OHLCV columns.
token : str
Token symbol, used to identify the volume column.

Returns
-------
pd.DataFrame
DataFrame with complete minute-level index and filled values:
- close: forward filled
- open/high/low: filled with previous close
- volume columns: filled with 0
- symbol: forward filled
"""
# Set unix as index
df.set_index("unix", inplace=True)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from quantammsim.utils.data_processing.historic_data_utils import (
get_historic_csv_data,
get_historic_csv_data_w_versions,
get_historic_parquet_data,
)
from quantammsim.core_simulator.param_utils import default_set

Expand Down Expand Up @@ -119,7 +120,9 @@ def load_price_data_if_fingerprints_match(
if verbose:
print("loading data for all run fingerprints")
if run_fingerprint["optimisation_settings"]["training_data_kind"] == "historic":
price_data = get_historic_csv_data(unique_tokens, cols=["close"], root=root)
price_data = get_historic_parquet_data(
unique_tokens, cols=["close"], root=root
)
return price_data
elif run_fingerprint["optimisation_settings"]["training_data_kind"] == "mc":
list_of_tickers = unique_tokens
Expand Down
Loading