Source code for tradeexecutor.analysis.grid_search

"""Grid search result analysis.

- Breaddown of performance of different grid search combinations

- Heatmap and other comparison methods

"""

from typing import List

import numpy as np
import pandas as pd
from IPython.core.display_functions import display

import plotly.express as px
from plotly.graph_objs import Figure

from tradeexecutor.backtest.grid_search import GridSearchResult


VALUE_COLS = ["Annualised return", "Max drawdown", "Sharpe", "Sortino", "Average position", "Median position"]

PERCENT_COLS = ["Annualised return", "Max drawdown", "Average position", "Median position"]


[docs]def analyse_combination( r: GridSearchResult, min_positions_threshold: int, ) -> dict: """Create a grid search result table row. - Create columns we can use to compare different grid search combinations :param min_positions_threshold: If we did less positions than this amount, do not consider this a proper strategy. Filter out one position outliers. """ row = {} param_names = [] for param in r.combination.parameters: row[param.name] = param.value param_names.append(param.name) def clean(x): if x == "-": return np.NaN elif x == "": return np.NaN return x # import ipdb ; ipdb.set_trace() row.update({ # "Combination": r.combination.get_label(), "Positions": r.summary.total_positions, # "Return": r.summary.return_percent, # "Return2": r.summary.annualised_return_percent, #"Annualised profit": clean(r.metrics.loc["Expected Yearly"][0]), "Annualised return": clean(r.metrics.loc["Annualised return (raw)"][0]), "Max drawdown": clean(r.metrics.loc["Max Drawdown"][0]), "Sharpe": clean(r.metrics.loc["Sharpe"][0]), "Sortino": clean(r.metrics.loc["Sortino"][0]), "Average position": r.summary.average_trade, "Median position": r.summary.median_trade, }) # Clear all values except position count if this is not a good trade series if r.summary.total_positions < min_positions_threshold: for k in row.keys(): if k != "Positions" and k not in param_names: row[k] = np.NaN return row
[docs]def analyse_grid_search_result( results: List[GridSearchResult], min_positions_threshold: int = 5, ) -> pd.DataFrame: """Create aa table showing grid search result of each combination. - Each row have labeled parameters of its combination - Each row has some metrics extracted from the results by :py:func:`analyse_combination` See also :py:func:`analyse_combination`. :param results: Output from :py:meth:`tradeexecutor.backtest.grid_search.perform_grid_search`. :param min_positions_threshold: If we did less positions than this amount, do not consider this a proper strategy. Filter out one position outliers. :return: Table of grid search combinations """ assert len(results) > 0, "No results" rows = [analyse_combination(r, min_positions_threshold) for r in results] df = pd.DataFrame(rows) r = results[0] param_names = [p.name for p in r.combination.parameters] df = df.set_index(param_names) df = df.sort_index() return df
[docs]def visualise_table(df: pd.DataFrame): """Render a grid search combination table to notebook output. - Highlight winners and losers """ # https://stackoverflow.com/a/57152529/315168 # TODO: # Diverge color gradient around zero # https://stackoverflow.com/a/60654669/315168 formatted = df.style.background_gradient( axis = 0, subset = VALUE_COLS, ).highlight_min( color = 'pink', axis = 0, subset = VALUE_COLS, ).highlight_max( color = 'darkgreen', axis = 0, subset = VALUE_COLS, ).format( formatter="{:.2%}", subset = PERCENT_COLS, ) # formatted = df.style.highlight_max( # color = 'lightgreen', # axis = 0, # subset = VALUE_COLS, # ).highlight_min( # color = 'pink', # axis = 0, # subset = VALUE_COLS, # ).format( # formatter="{:.2%}", # subset = PERCENT_COLS, # ) display(formatted)
[docs]def visualise_heatmap_2d( result: pd.DataFrame, parameter_1: str, parameter_2: str, metric: str, color_continuous_scale='Bluered_r', continuous_scale: bool | None = None, ) -> Figure: """Draw a heatmap square comparing two different parameters. Directly shows the resulting matplotlib figure. :param parameter_1: Y axis :param parameter_2: X axis :param metric: Value to examine :param result: Grid search results as a DataFrame. Created by :py:func:`analyse_grid_search_result`. :param color_continuous_scale: The name of Plotly gradient used for the colour scale. :param continuous_scale: Are the X and Y scales continuous. X and Y scales cannot be continuous if they contain values like None or NaN. This will stretch the scale to infinity or zero. Set `True` to force continuous, `False` to force discreet steps, `None` to autodetect. :return: Plotly Figure object """ # Reset multi-index so we can work with parameter 1 and 2 as series df = result.reset_index() # Detect any non-number values on axes if continuous_scale is None: continuous_scale = not(df[parameter_1].isna().any() or df[parameter_2].isna().any()) # setting all column values to string will hint # Plotly to make all boxes same size regardless of value if not continuous_scale: df[parameter_1] = df[parameter_1].astype(str) df[parameter_2] = df[parameter_2].astype(str) df = df.pivot(index=parameter_1, columns=parameter_2, values=metric) # Format percents inside the cells and mouse hovers if metric in PERCENT_COLS: text = df.applymap(lambda x: f"{x * 100:,.2f}%") else: text = df.applymap(lambda x: f"{x:,.2f}") fig = px.imshow( df, labels=dict(x=parameter_2, y=parameter_1, color=metric), aspect="auto", title=metric, color_continuous_scale=color_continuous_scale, ) fig.update_traces(text=text, texttemplate="%{text}") fig.update_layout( title={"text": metric}, height=600, ) return fig