diff --git a/quantammsim/pools/G3M/optimal_n_pool_arb.py b/quantammsim/pools/G3M/optimal_n_pool_arb.py index d44a045..0c03355 100644 --- a/quantammsim/pools/G3M/optimal_n_pool_arb.py +++ b/quantammsim/pools/G3M/optimal_n_pool_arb.py @@ -166,18 +166,18 @@ def construct_optimal_trade_jnp( valid_post_trade_reserves = ( jnp.sum(initial_reserves + active_overall_trade > 0) == n ) - valid_post_trade_constant = ( - jnp.prod( - ( - initial_reserves - + active_overall_trade - * (fee_gamma ** (trade_to_direction_jnp(active_overall_trade))) - ) - ** initial_weights + post_trade_constant = jnp.prod( + ( + initial_reserves + + active_overall_trade + * (fee_gamma ** (trade_to_direction_jnp(active_overall_trade))) ) - - initial_constant - >= slack + ** initial_weights ) + # Use proportional tolerance: (post - initial) / initial >= slack + # This makes the tolerance scale-invariant with pool size + relative_diff = (post_trade_constant - initial_constant) / initial_constant + valid_post_trade_constant = relative_diff >= slack valid_trade = jnp.logical_and(valid_post_trade_reserves, valid_post_trade_constant) return jnp.where(valid_trade, active_overall_trade, 0) # return active_overall_trade, valid_post_trade_reserves * valid_post_trade_constant @@ -338,18 +338,18 @@ def calc_optimal_trade_for_one_signature( valid_post_trade_reserves = ( jnp.sum(initial_reserves + active_overall_trade > 0) == n ) - valid_post_trade_constant = ( - jnp.prod( - ( - initial_reserves - + active_overall_trade - * (fee_gamma ** (trade_to_direction_jnp(active_overall_trade))) - ) - ** initial_weights + post_trade_constant = jnp.prod( + ( + initial_reserves + + active_overall_trade + * (fee_gamma ** (trade_to_direction_jnp(active_overall_trade))) ) - - initial_constant - >= slack + ** initial_weights ) + # Use proportional tolerance: (post - initial) / initial >= slack + # This makes the tolerance scale-invariant with pool size + relative_diff = (post_trade_constant - initial_constant) / initial_constant + valid_post_trade_constant = relative_diff >= slack valid_trade = jnp.logical_and(valid_post_trade_reserves, valid_post_trade_constant) return jnp.where(valid_trade, active_overall_trade, 0) # return { @@ -413,7 +413,7 @@ def parallelised_optimal_trade_sifter( tokens_to_drop, fee_gamma, n, - 0, + slack, ) profits = -(overall_trades * local_prices).sum(-1) @@ -459,7 +459,211 @@ def wrapped_parallelised_optimal_trade_sifter( tokens_to_drop, fee_gamma, n, - slack=0, + slack=slack, ) return trade + +# ============================================================================ +# Arbitrage Fee-Adjusted Optimal Trade Functions +# ============================================================================ +# +# When arb_fees > 0, the arbitrageur faces external rebalancing costs that +# affect the optimal trade size. These wrapper functions incorporate arb_fees +# into the optimal trade calculation by adjusting effective market prices. +# +# Mathematical derivation: +# The modified objective is: max_Φ: -∑(mₚ,ᵢΦᵢ) - (arb_fees/2) × ∑|mₚ,ᵢΦᵢ| +# For a fixed trade signature, |Φᵢ| = sign(Φᵢ) × Φᵢ, so: +# objective = -∑(mₚ,ᵢ × (1 + arb_fees/2 × sᵢ) × Φᵢ) +# where sᵢ = +1 for tokens in, -1 for tokens out. +# +# This is equivalent to using adjusted effective market prices: +# m̃ₚ,ᵢ = mₚ,ᵢ × (1 + arb_fees/2 × sᵢ) +# +# See https://arxiv.org/abs/2402.06731 for the base derivation. +# ============================================================================ + + +def adjust_prices_for_arb_fees( + local_prices, + active_trade_direction, + tokens_to_drop, + arb_fees, +): + """ + Adjust market prices to account for external arbitrage costs. + + When arb_fees > 0, the arbitrageur faces additional costs for rebalancing + their portfolio on external venues. This function computes effective prices + that incorporate these costs into the optimization. + + Parameters + ---------- + local_prices : jnp.ndarray + Array of market prices for each token. + active_trade_direction : jnp.ndarray + Array where 1 = token going into pool, 0 = token coming out. + tokens_to_drop : jnp.ndarray + Boolean array indicating inactive tokens for this signature. + arb_fees : float + External arbitrage fees as a fraction (e.g., 0.001 = 0.1%). + + Returns + ------- + jnp.ndarray + Adjusted prices incorporating external cost effects. + """ + # sig_direction: +1 for tokens IN (arb sells to pool, buys externally at higher price) + # -1 for tokens OUT (arb buys from pool, sells externally at lower price) + sig_direction = 2 * active_trade_direction - 1 # converts 0->-1, 1->+1 + price_adjustment = 1.0 + 0.5 * arb_fees * sig_direction + # Only adjust active tokens (not dropped ones) + price_adjustment = jnp.where(tokens_to_drop, 1.0, price_adjustment) + return local_prices * price_adjustment + + +adjust_prices_for_arb_fees_across_signatures = vmap( + adjust_prices_for_arb_fees, in_axes=[None, 0, 0, None] +) + + +def precalc_components_with_arb_fees_for_one_signature( + initial_weights, + local_prices, + fee_gamma, + tokens_to_drop, + active_trade_direction, + leave_one_out_idx, + arb_fees, +): + """ + Wrapper that adjusts prices for arb_fees then calls the base precalc function. + + Parameters + ---------- + initial_weights : jnp.ndarray + Array of pool weights for each token. + local_prices : jnp.ndarray + Array of market prices for each token. + fee_gamma : float + Pool fee parameter (1 - swap_fee). + tokens_to_drop : jnp.ndarray + Boolean array indicating inactive tokens for this signature. + active_trade_direction : jnp.ndarray + Array where 1 = token going into pool, 0 = token coming out. + leave_one_out_idx : jnp.ndarray + Index array for computing products excluding each element. + arb_fees : float + External arbitrage fees as a fraction. + + Returns + ------- + tuple + (active_initial_weights, per_asset_ratio, all_other_assets_ratio) + """ + adjusted_prices = adjust_prices_for_arb_fees( + local_prices, active_trade_direction, tokens_to_drop, arb_fees + ) + return precalc_components_of_optimal_trade_for_one_signature( + initial_weights, + adjusted_prices, + fee_gamma, + tokens_to_drop, + active_trade_direction, + leave_one_out_idx, + ) + + +precalc_components_with_arb_fees_across_signatures = vmap( + precalc_components_with_arb_fees_for_one_signature, + in_axes=[None, None, None, 0, 0, 0, None], +) + + +precalc_components_with_arb_fees_across_weights_and_prices = vmap( + precalc_components_with_arb_fees_across_signatures, + in_axes=[0, 0, None, None, None, None, None], +) + + +precalc_components_with_arb_fees_across_weights_and_prices_and_dynamic_fees = vmap( + precalc_components_with_arb_fees_across_signatures, + in_axes=[0, 0, 0, None, None, None, 0], +) + + +def wrapped_parallelised_optimal_trade_sifter_with_arb_fees( + initial_weights, + local_prices, + initial_reserves, + fee_gamma, + all_sig_variations, + n, + arb_fees=0.0, + slack=0, +): + """ + Compute optimal arbitrage trade incorporating external arb fees. + + This function extends wrapped_parallelised_optimal_trade_sifter to account + for external arbitrage costs (e.g., CEX fees, gas costs) by adjusting + effective market prices in the optimization. + + When arb_fees > 0, the optimal trade will be smaller than the zero-arb-fee + case because the arbitrageur stops trading earlier (marginal profit goes + to zero sooner due to external costs). + + Parameters + ---------- + initial_weights : jnp.ndarray + Array of pool weights for each token. + local_prices : jnp.ndarray + Array of market prices for each token. + initial_reserves : jnp.ndarray + Array of current pool reserves for each token. + fee_gamma : float + Pool fee parameter (1 - swap_fee). + all_sig_variations : jnp.ndarray + Array of all signature variations to test. + n : int + Number of tokens. + arb_fees : float, optional + External arbitrage fees as a fraction (default 0.0). + slack : float, optional + Slack for invariant validation (default 0). + + Returns + ------- + jnp.ndarray + Optimal trade vector incorporating arb_fees. + """ + _, active_trade_directions, tokens_to_drop, leave_one_out_idx = ( + precalc_shared_values_for_all_signatures(all_sig_variations, n) + ) + + active_initial_weights, per_asset_ratio, all_other_assets_ratio = ( + precalc_components_with_arb_fees_across_signatures( + initial_weights, + local_prices, + fee_gamma, + tokens_to_drop, + active_trade_directions, + leave_one_out_idx, + arb_fees, + ) + ) + trade = parallelised_optimal_trade_sifter( + initial_reserves, + initial_weights, + local_prices, + active_initial_weights, + active_trade_directions, + per_asset_ratio, + all_other_assets_ratio, + tokens_to_drop, + fee_gamma, + n, + slack=slack, + ) + return trade diff --git a/quantammsim/pools/G3M/quantamm/weight_calculations/fine_weights.py b/quantammsim/pools/G3M/quantamm/weight_calculations/fine_weights.py index f56f1c3..10b54fb 100644 --- a/quantammsim/pools/G3M/quantamm/weight_calculations/fine_weights.py +++ b/quantammsim/pools/G3M/quantamm/weight_calculations/fine_weights.py @@ -485,15 +485,16 @@ def calc_fine_weight_output( maximum_change=maximum_change, method=weight_interpolation_method, ) - if rule_outputs_are_weights: - return weights - else: - return jnp.vstack( - [ - jnp.ones((chunk_period, n_assets), dtype=jnp.float64) * initial_weights, - weights, - ] - ) + # Prepend chunk_period rows of initial weights for both paths. + # This ensures weights at fine timestep t don't use prices beyond t. + # Without this prepending for weight-outputting rules, there would be + # a 1-step lookahead bias (weights computed from future prices). + return jnp.vstack( + [ + jnp.ones((chunk_period, n_assets), dtype=jnp.float64) * initial_weights, + weights, + ] + ) calc_fine_weight_output_from_weight_changes = jit( diff --git a/quantammsim/utils/data_processing/amalgamated_data_utils.py b/quantammsim/utils/data_processing/amalgamated_data_utils.py index 61cd65f..bed6fdb 100644 --- a/quantammsim/utils/data_processing/amalgamated_data_utils.py +++ b/quantammsim/utils/data_processing/amalgamated_data_utils.py @@ -34,6 +34,27 @@ def import_crypto_historical_data(token, root_path): def forward_fill_ohlcv_data(df, token): + """Forward fill OHLCV data to create a complete minute-level time series. + + Creates a complete minute-level index between the first and last timestamps, + then fills missing values using appropriate strategies for each column type. + + Parameters + ---------- + df : pd.DataFrame + DataFrame with 'unix' column (millisecond timestamps) and OHLCV columns. + token : str + Token symbol, used to identify the volume column. + + Returns + ------- + pd.DataFrame + DataFrame with complete minute-level index and filled values: + - close: forward filled + - open/high/low: filled with previous close + - volume columns: filled with 0 + - symbol: forward filled + """ # Set unix as index df.set_index("unix", inplace=True) diff --git a/quantammsim/utils/data_processing/price_data_fingerprint_utils.py b/quantammsim/utils/data_processing/price_data_fingerprint_utils.py index 1a9773c..d53272e 100644 --- a/quantammsim/utils/data_processing/price_data_fingerprint_utils.py +++ b/quantammsim/utils/data_processing/price_data_fingerprint_utils.py @@ -8,6 +8,7 @@ from quantammsim.utils.data_processing.historic_data_utils import ( get_historic_csv_data, get_historic_csv_data_w_versions, + get_historic_parquet_data, ) from quantammsim.core_simulator.param_utils import default_set @@ -119,7 +120,9 @@ def load_price_data_if_fingerprints_match( if verbose: print("loading data for all run fingerprints") if run_fingerprint["optimisation_settings"]["training_data_kind"] == "historic": - price_data = get_historic_csv_data(unique_tokens, cols=["close"], root=root) + price_data = get_historic_parquet_data( + unique_tokens, cols=["close"], root=root + ) return price_data elif run_fingerprint["optimisation_settings"]["training_data_kind"] == "mc": list_of_tickers = unique_tokens diff --git a/quantammsim/utils/plot_utils.py b/quantammsim/utils/plot_utils.py index c1ad9ae..6300b3d 100644 --- a/quantammsim/utils/plot_utils.py +++ b/quantammsim/utils/plot_utils.py @@ -1,25 +1,29 @@ import pandas as pd import numpy as np +import matplotlib as mpl import matplotlib.pyplot as plt import seaborn as sns -from datetime import datetime +import seaborn.objects as so +from matplotlib.ticker import MultipleLocator +from datetime import datetime, timezone +from pathlib import Path import warnings -warnings.filterwarnings("ignore") sns.set(rc={"text.usetex": True}) def calc_returns_from_values(values_in): - r"""Calculate returns + r"""Calculate period-over-period returns from a value series. + + Parameters ---------- - arr_in : np.ndarray, float64 - A one-dimenisional numpy array of price data + values_in : np.ndarray + A one-dimensional numpy array of value data. Returns ------- np.ndarray - The returns - + The returns array, with the first element set to 0.0. """ n = values_in.shape[0] returns = np.empty((n,), dtype=np.float64) @@ -33,16 +37,17 @@ def calc_returns_from_values(values_in): def calc_overall_returns_from_values(values_in): - r"""Calculate returns + r"""Calculate the total return from start to end of a value series. + + Parameters ---------- - arr_in : np.ndarray, float64 - A one-dimenisional numpy array of price data + values_in : np.ndarray + A one-dimensional numpy array of value data. Returns ------- float - The returns - + The overall return (final_value / initial_value - 1). """ returns = values_in[-1] / values_in[0] - 1.0 return returns @@ -309,7 +314,7 @@ def plot_vals( cmap_ = "RdYlGn" def get_dfwise(df): - df_wide = df.pivot(*cols) + df_wide = df.pivot(index=cols[0], columns=cols[1], values=cols[2]) # clean up window length values if window_size_estimate_yaxis: window_size_estimate = [str(item) for item in df_wide.index] @@ -832,3 +837,278 @@ def plot_lineplot( ax.set_yticklabels(y_value) plt.savefig(save_location, dpi=700, bbox_inches="tight") plt.close() + + +def name_to_latex_name(name): + """Convert run name to clean LaTeX formatted name. + + Parameters + ---------- + name : str + Name of the run (e.g. 'Current_index_BTC-ETH_min_0.1_index_memory_day_30.0') + + Returns + ------- + str + LaTeX formatted name + """ + if name.startswith("Current_index"): + return "$\\mathrm{Current\\ Index\\ Product}$" + elif name.startswith("HODL"): + return "$\\mathrm{HODL}$" + elif name.startswith("QuantAMM_index"): + return "$\\mathrm{QuantAMM\\ Index}$" + elif name.startswith("Balancer"): + return "$\\mathrm{Balancer}$" + elif name.startswith("Traditional DEX"): + return "$\\mathrm{Traditional\\ DEX}$" + elif name.startswith("Optimized_QuantAMM"): + rule = name.split("rule_")[-1] + rule = rule.replace("_", " ").title() + if rule == "Mean Reversion Channel": + rule = "Mean-Reversion\\ Channel" + elif rule == "Anti Momentum": + rule = "Anti-Momentum" + elif rule == "Power Channel": + rule = "Power-Channel" + return f"$\\mathrm{{QuantAMM\\ {rule}}}$" + else: + return name + + +def do_weight_change_as_rebalances_plots( + output_dict, + run_fingerprint, + n_bars=200, + plot_prefix="weight_change", + color="black", + verbose=True, +): + """Plot weight changes as rebalance bars overlaid on price series. + + Parameters + ---------- + output_dict : dict + Simulation output with 'reserves' and 'prices' arrays. + run_fingerprint : dict + Run configuration with 'tokens', 'chunk_period', date strings. + n_bars : int + Number of bars to display. + plot_prefix : str + Prefix for saved plot filenames. + color : str + Color for axis spines, ticks, and labels. + verbose : bool + Whether to print diagnostic info. + """ + output_dict = output_dict.copy() + plot_path = Path("./plots/") + plot_path.mkdir(parents=True, exist_ok=True) + + total_value = np.sum( + output_dict["reserves"] * output_dict["prices"], axis=1, keepdims=True + ) + weights = output_dict["reserves"] * output_dict["prices"] / total_value + output_dict["weights"] = weights + raw_weight_changes = np.diff(output_dict["weights"], axis=0) + raw_weight_changes = np.vstack( + [np.zeros((1, raw_weight_changes.shape[1])), raw_weight_changes] + ) + indexes = np.arange(len(output_dict["prices"])) + bar_fill_ratio = 0.8 + plot_prefix = "./plots/" + plot_prefix + first = True + lims = [] + decimation = int(len(indexes) / n_bars) + remainder = len(output_dict["prices"]) % decimation + if remainder > 0: + trim_length = len(output_dict["prices"]) - remainder + output_dict["prices"] = output_dict["prices"][:trim_length] + output_dict["weights"] = output_dict["weights"][:trim_length] + raw_weight_changes = raw_weight_changes[:trim_length] + indexes = indexes[:trim_length] + + tokens = sorted(run_fingerprint["tokens"]) + + for i in range(output_dict["prices"].shape[1]): + token = tokens[i] + prices_range = np.max(output_dict["prices"][:, i]) - np.min( + output_dict["prices"][:, i] + ) + + raw_weight_changes = np.vstack( + [ + np.zeros((1, raw_weight_changes.shape[1])), + np.diff(output_dict["weights"][::decimation], axis=0), + ] + ) + max_raw_weight_changes = np.max(np.abs(raw_weight_changes)) + scaled_pool_value = prices_range / max_raw_weight_changes + trades = raw_weight_changes[:, i] * scaled_pool_value + if verbose: + print(scaled_pool_value) + + bar_width = bar_fill_ratio * len(output_dict["prices"]) / n_bars + + ax = sns.lineplot( + x=np.arange(len(output_dict["prices"])), + y=output_dict["prices"][:, i], + legend=False, + color="#DAAB43", + linewidth=0.5, + ) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.spines["left"].set_color(color) + ax.spines["bottom"].set_color(color) + ax.tick_params(axis="both", colors=color, direction="out", length=6, width=1) + ax.yaxis.set_ticks_position("left") + ax.xaxis.set_ticks_position("bottom") + ax.set_ylabel( + f"$\\mathrm{{{token}}}\\ \\mathrm{{price}}\\ (\\mathrm{{USD}})$", + color=color, + ) + + start_date = datetime.strptime( + run_fingerprint["startDateString"], "%Y-%m-%d %H:%M:%S" + ) + end_date = datetime.strptime( + run_fingerprint["endDateString"], "%Y-%m-%d %H:%M:%S" + ) + + total_seconds = len(output_dict["prices"]) + date_range = pd.date_range(start=start_date, end=end_date, periods=6) + + x_positions = np.linspace(0, total_seconds - 1, len(date_range)) + + date_labels = [f"$$\\mathrm{{{d.strftime('%Y-%m-%d')}}}$$" for d in date_range] + + plt.xticks(x_positions, date_labels, rotation=45) + ax.grid(False) + max_trades = trades + max_prices = output_dict["prices"][:, i][::decimation] + max_x = indexes[::decimation] + if first: + lims.append( + [ + int(np.max([np.min(max_trades + max_prices) * 0.9, 0])), + np.ceil(np.max(max_trades + max_prices) * 1.1), + ] + ) + + pos_mask = max_trades > 0 + neg_mask = max_trades < 0 + + plt.bar( + x=max_x[pos_mask], + height=max_trades[pos_mask], + width=bar_width, + bottom=max_prices[pos_mask], + color="g", + linewidth=0.0, + ) + plt.bar( + x=max_x[neg_mask], + height=max_trades[neg_mask], + width=bar_width, + bottom=max_prices[neg_mask], + color="r", + linewidth=0.0, + ) + + plt.ylim(*lims[i]) + + plt.savefig( + plot_prefix + + "_weight_change_signal_" + + "_token_" + + str(i) + + "_nbars_" + + str(n_bars) + + "_.png", + dpi=700, + bbox_inches="tight", + ) + + plt.close() + first = False + + +def plot_weights( + output_dict, run_fingerprint, plot_prefix="weights", plot_dir=None, verbose=True +): + """Plot token weights over time as a stacked area chart. + + Parameters + ---------- + output_dict : dict + Simulation output with 'reserves' and 'prices' arrays. + run_fingerprint : dict + Run configuration with 'tokens' and date strings. + plot_prefix : str + Prefix for saved plot filenames. + plot_dir : str or None + Directory to save plots. Defaults to './plots/'. + verbose : bool + Whether to print diagnostic info. + """ + if plot_dir is None: + plot_dir = "./plots/" + plot_path = Path(plot_dir) + plot_path.mkdir(parents=True, exist_ok=True) + + total_value = np.sum( + output_dict["reserves"] * output_dict["prices"], axis=1, keepdims=True + ) + weights = np.array(output_dict["reserves"] * output_dict["prices"] / total_value) + + weights = weights[::1440] + df_list = [] + tokens = sorted(run_fingerprint["tokens"]) + for i, token in enumerate(tokens): + df_list.extend( + [ + {"Time": t, "Weight": w, "Token": token} + for t, w in enumerate(weights[:, i]) + ] + ) + + df = pd.DataFrame(df_list) + start_date = datetime.strptime( + run_fingerprint["startDateString"], "%Y-%m-%d %H:%M:%S" + ) + end_date = datetime.strptime(run_fingerprint["endDateString"], "%Y-%m-%d %H:%M:%S") + + date_range = pd.date_range( + start=start_date, end=end_date, periods=len(df["Time"].unique()) + ) + df["Time"] = np.tile(date_range, weights.shape[1]) + + f = mpl.figure.Figure() + + pl = ( + so.Plot(df, "Time", "Weight", color="Token") + .add(so.Area(alpha=0.7), so.Stack()) + .limit(y=(0, 1)) + .scale(color=sns.color_palette()) + .label(y="$\\mathrm{Weight}$", x="$\\mathrm{Date}$") + ) + + res = pl.on(f).plot() + ax = f.axes[0] + unique_dates = df["Time"].unique() + date_indices = np.linspace(0, len(unique_dates) - 1, 4, dtype=int) + selected_dates = unique_dates[date_indices] + + date_labels = [ + f"$$\\mathrm{{{pd.Timestamp(date).strftime('%Y-%m-%d')}}}$$" + for date in selected_dates + ] + ax.set_xticks(date_indices, date_labels, rotation=45) + plt.tight_layout() + pl.save( + plot_path / (plot_prefix + "_weights_over_time.png"), + dpi=700, + bbox_inches="tight", + ) + plt.close()