Source code for tradeexecutor.backtest.grid_search

"""Perform a grid search over strategy parameters to find optimal parameters."""
import concurrent
import datetime
import itertools
import logging
import os
import pickle
import shutil
import signal
import sys
import warnings
from collections import Counter
from dataclasses import dataclass
from multiprocessing import Process
from pathlib import Path
from typing import Protocol, Dict, List, Tuple, Any, Optional, Iterable, Collection, Callable
import concurrent.futures.process

import numpy as np
import pandas as pd
import futureproof

from tradeexecutor.strategy.engine_version import TradingStrategyEngineVersion

try:
    from tqdm_loggable.auto import tqdm
except ImportError:
    # tqdm_loggable is only available at the live execution,
    # but fallback to normal TQDM auto mode
    from tqdm.auto import tqdm

from tradeexecutor.analysis.advanced_metrics import calculate_advanced_metrics, AdvancedMetricsMode
from tradeexecutor.analysis.trade_analyser import TradeSummary, build_trade_analysis
from tradeexecutor.backtest.backtest_routing import BacktestRoutingIgnoredModel
from tradeexecutor.backtest.backtest_runner import run_backtest_inline
from tradeexecutor.state.state import State
from tradeexecutor.state.types import USDollarAmount
from tradeexecutor.strategy.cycle import CycleDuration
from tradeexecutor.strategy.default_routing_options import TradeRouting
from tradeexecutor.strategy.routing import RoutingModel
from tradeexecutor.strategy.strategy_module import DecideTradesProtocol, DecideTradesProtocol2
from tradeexecutor.strategy.trading_strategy_universe import TradingStrategyUniverse
from tradeexecutor.visual.equity_curve import calculate_equity_curve, calculate_returns


logger = logging.getLogger(__name__)


def _hide_warnings(func):
    """Function wrapper to suppress warnings caused by quantstats and numpy functions.

    Otherwise these warnings pollute notebook output.
    """

    # Hidden warnings include:

    # In perform_grid_search:
    # /home/.cache/pypoetry/virtualenvs/trade-executor-xSh0vQvh-py3.10/lib/python3.10/site-packages/numpy/lib/function_base.py:2854:
    # RuntimeWarning: invalid value encountered in divide
    # c /= stddev[:, None]

    # In perform_grid_search:
    # /home/alex/.cache/pypoetry/virtualenvs/trade-executor-xSh0vQvh-py3.10/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py:2351:
    # RuntimeWarning: invalid value encountered in multiply
    # lower_bound = _a * scale + loc

    def wrapper(*args, **kwargs):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return func(*args, **kwargs)

    return wrapper


@dataclass
class GridParameter:
    """One value in grid search matrix."""

    #: Name of the grid parameter
    name: str

    #: Value of the grid parameter for this grid point
    value: Any

    def __hash__(self):
        # Requires a hashable value (e.g. not a list)
        return hash((self.name, self.value))

    def __eq__(self, other):
        # Let Python fall back to the other operand (or identity) for
        # foreign types instead of crashing on a missing attribute
        if not isinstance(other, GridParameter):
            return NotImplemented
        return self.name == other.name and self.value == other.value

    def to_path(self) -> str:
        """Convert this parameter to a ``name=value`` file system path segment.

        :return:
            ``name=value`` for float, int, str and bool values,
            ``name=none`` for ``None``.

        :raise NotImplementedError:
            For value types we do not know how to turn into a file name.
        """
        value = self.value

        # Exact type check on purpose: arbitrary subclasses may have
        # str() output that is not safe as a file name
        if type(value) in (float, int, str, bool):
            return f"{self.name}={value}"

        if value is None:
            return f"{self.name}=none"

        raise NotImplementedError(f"We do not support filename conversion for value {type(value)}={value}")
@dataclass()
class GridCombination:
    """One combination line in grid search."""

    #: How many of nth grid combinations this is
    #:
    index: int

    #: In which folder we store the result files of all grid search runs
    #:
    #: Each individual combination will have its subfolder based on its parameter.
    result_path: Path

    #: Parameters of this combination.
    #:
    #: Kept in the order the parameters were given to
    #: :py:func:`prepare_grid_combinations` (not sorted alphabetically).
    parameters: "Tuple[GridParameter, ...]"

    def __post_init__(self):
        assert len(self.parameters) > 0
        assert isinstance(self.result_path, Path), f"Expected Path, got {type(self.result_path)}"
        assert self.result_path.exists() and self.result_path.is_dir(), f"Not a dir: {self.result_path}"

    def __hash__(self):
        # A combination is identified by its parameters only;
        # index and result_path do not take part in hashing or equality
        return hash(self.parameters)

    def __eq__(self, other):
        # Return NotImplemented for foreign types instead of raising AttributeError
        if not isinstance(other, GridCombination):
            return NotImplemented
        return self.parameters == other.parameters

    def get_relative_result_path(self) -> Path:
        """Get the path where the resulting state file is stored.

        Try to avoid messing with 256 character limit on filenames,
        thus break down as folders: one folder level per parameter.
        """
        return Path(*[p.to_path() for p in self.parameters])

    def get_full_result_path(self) -> Path:
        """Get the path where the resulting state file is stored.

        This is :py:meth:`get_relative_result_path` under ``result_path``.
        """
        return self.result_path.joinpath(self.get_relative_result_path())

    def validate(self):
        """Check arguments can be serialised as fs path.

        :raise NotImplementedError:
            If a parameter value cannot be converted to a path segment.
        """
        assert isinstance(self.get_relative_result_path(), Path)

    def as_dict(self) -> dict:
        """Get as kwargs mapping of parameter name to value."""
        return {p.name: p.value for p in self.parameters}

    def get_label(self) -> str:
        """Human readable label for this combination, e.g. ``#1, a=1, b=2``."""
        return f"#{self.index}, " + ", ".join([f"{p.name}={p.value}" for p in self.parameters])

    def destructure(self) -> List[Any]:
        """Open parameters dict.

        This will return the arguments in the same order you pass them to :py:func:`prepare_grid_combinations`.
        """
        return [p.value for p in self.parameters]
@dataclass(slots=True, frozen=False)
class GridSearchResult:
    """Result for one grid combination."""

    #: For which grid combination this result is
    combination: GridCombination

    #: The full back test state
    state: State

    #: Calculated trade summary
    summary: TradeSummary

    #: Performance metrics
    metrics: pd.DataFrame

    #: Was this result read from the earlier run save
    cached: bool = False

    #: Child process that created this result.
    #:
    #: Only applicable to multiprocessing
    process_id: Optional[int] = None

    @staticmethod
    def has_result(combination: GridCombination) -> bool:
        """Check whether a cached result file exists for the combination.

        Checks the same ``result.pickle`` location that :py:meth:`save`
        writes to and :py:meth:`load` reads from.
        """
        # get_full_result_path() already starts with combination.result_path.
        # Joining it onto result_path again would, for a relative result_path,
        # produce result_path/result_path/... and the cache would never be hit.
        return combination.get_full_result_path().joinpath("result.pickle").exists()

    @staticmethod
    def load(combination: GridCombination) -> "GridSearchResult":
        """Deserialised from the cached Python pickle.

        The returned result has ``cached`` set to ``True``.

        .. note::

            Unpickling can execute arbitrary code: only load result
            folders you created yourself.
        """
        base_path = combination.get_full_result_path()
        with open(base_path.joinpath("result.pickle"), "rb") as inp:
            result: GridSearchResult = pickle.load(inp)
        result.cached = True
        return result

    def save(self) -> None:
        """Serialise as Python pickle.

        Creates the combination's result folder if it does not exist.
        """
        base_path = self.combination.get_full_result_path()
        base_path.mkdir(parents=True, exist_ok=True)
        with open(base_path.joinpath("result.pickle"), "wb") as out:
            pickle.dump(self, out)
class GridSearchWorker(Protocol):
    """Define how to create different strategy bodies.

    A grid search worker is any callable matching :py:meth:`__call__`.
    It is invoked once per grid combination.
    """

    def __call__(self, universe: TradingStrategyUniverse, combination: GridCombination) -> GridSearchResult:
        """Run a new decide_trades() strategy body based over the search parameters.

        :param universe:
            The trading universe to run the strategy against.

        :param combination:
            The grid search parameter combination to run.
            Use :py:meth:`GridCombination.as_dict` or
            :py:meth:`GridCombination.destructure` to read the parameter values.

        :return:
            The result of running this combination.
        """
def prepare_grid_combinations(
    parameters: Dict[str, List[Any]],
    result_path: Path,
    clear_cached_results=False,
    marker_file="README-GRID-SEARCH.md",
) -> List[GridCombination]:
    """Get iterable search matrix of all parameter combinations.

    - Make sure we preserve the original order of the grid search parameters.

    - Set up the folder to store the results

    :param parameters:
        A grid of parameters we will search.

    :param result_path:
        A folder where resulting state files will be stored.

    :param clear_cached_results:
        Clear any existing result files from the saved result cache.

        You need to do this if you change the strategy logic outside
        the given combination parameters, as the framework will otherwise
        serve you the old cached results.

    :param marker_file:
        Safety to prevent novice users from nuking their hard disk with this command.

        An existing ``result_path`` is only deleted if it contains this file.

    :return:
        List of all combinations we need to search through
    """

    assert isinstance(result_path, Path)

    # Validate and expand the input before touching the file system,
    # so that bad input does not wipe the result cache first
    args_lists: List[List[GridParameter]] = []
    for name, values in parameters.items():
        assert isinstance(values, Collection), f"Expected list, got: {values}"
        args_lists.append([GridParameter(name, v) for v in values])

    # itertools.product() keeps the position of each input list in its output
    # tuples, so every combination lists its parameters in the original
    # insertion order of `parameters`
    raw_combinations = list(itertools.product(*args_lists))

    logger.info("Preparing %d grid combinations, caching results in %s", len(raw_combinations), result_path)

    if clear_cached_results:
        if result_path.exists():
            assert result_path.joinpath(marker_file).exists(), f"{result_path} does not seem to be grid search folder, it lacks {marker_file}"
            shutil.rmtree(result_path)

    result_path.mkdir(parents=True, exist_ok=True)

    with open(result_path.joinpath(marker_file), "wt") as out:
        print("This is a TradingStrategy.ai grid search result folder", file=out)

    # GridCombination asserts that result_path exists, so build them only after mkdir()
    combinations = [
        GridCombination(index=idx, parameters=tuple(c), result_path=result_path)
        for idx, c in enumerate(raw_combinations, start=1)
    ]

    for c in combinations:
        c.validate()

    return combinations
def run_grid_combination(
    grid_search_worker: GridSearchWorker,
    universe: TradingStrategyUniverse,
    combination: GridCombination,
):
    """Run a single grid combination in the current process.

    A previously saved result for the combination is returned as is.
    Otherwise the worker is run and its result is saved for later runs.

    :param grid_search_worker:
        Callable that runs the backtest for one combination.

    :param universe:
        Trading universe handed to the worker.

    :param combination:
        The combination to run.

    :return:
        Cached or freshly computed :py:class:`GridSearchResult`.
    """
    if not GridSearchResult.has_result(combination):
        fresh_result = grid_search_worker(universe, combination)
        # Persist so the next run can skip this combination
        fresh_result.save()
        return fresh_result

    return GridSearchResult.load(combination)
def run_grid_combination_multiprocess(
    grid_search_worker: GridSearchWorker,
    combination: GridCombination,
):
    """Run a single grid combination inside a worker process.

    The trading universe is taken from the process-global ``_universe``,
    which :py:func:`_process_init` populates.

    A previously saved result is returned as is. A freshly computed result
    is tagged with the worker's process id and saved for later runs.

    :param grid_search_worker:
        Callable that runs the backtest for one combination.

    :param combination:
        The combination to run.

    :return:
        Cached or freshly computed :py:class:`GridSearchResult`.
    """
    # Only reading the module global here, so no `global` statement is needed
    worker_universe = _universe

    if GridSearchResult.has_result(combination):
        return GridSearchResult.load(combination)

    fresh_result = grid_search_worker(worker_universe, combination)
    fresh_result.process_id = os.getpid()

    # Persist so the next run can skip this combination
    fresh_result.save()
    return fresh_result
def run_grid_search_backtest(
    combination: GridCombination,
    decide_trades: DecideTradesProtocol | DecideTradesProtocol2,
    universe: TradingStrategyUniverse,
    cycle_duration: Optional[CycleDuration] = None,
    start_at: Optional[datetime.datetime | pd.Timestamp] = None,
    end_at: Optional[datetime.datetime | pd.Timestamp] = None,
    initial_deposit: USDollarAmount = 5000.0,
    trade_routing: Optional[TradeRouting] = None,
    data_delay_tolerance: Optional[pd.Timedelta] = None,
    name: Optional[str] = None,
    routing_model: Optional[RoutingModel] = None,
    trading_strategy_engine_version: Optional[TradingStrategyEngineVersion] = None,
) -> GridSearchResult:
    """Run a backtest for one grid search combination.

    :param combination:
        The grid combination this backtest is for.

    :param decide_trades:
        Strategy decision function.

    :param universe:
        The trading universe to backtest against.

    :param cycle_duration:
        Strategy cycle. Defaults to the candle time bucket of ``universe``.

    :param start_at:
        Backtest start. Defaults to the first candle timestamp of ``universe``.

    :param end_at:
        Backtest end. Defaults to the last candle timestamp of ``universe``.

    :param initial_deposit:
        Starting reserve balance in USD.

    :param trade_routing:
        Not used: the backtest always runs with
        ``TradeRouting.user_supplied_routing_model`` and ``routing_model``.

    :param data_delay_tolerance:
        Passed through to :py:func:`run_backtest_inline`.

    :param name:
        Backtest name. Defaults to :py:meth:`GridCombination.get_label`.

    :param routing_model:
        Routing model for the backtest. Defaults to
        :py:class:`BacktestRoutingIgnoredModel` using the universe's reserve asset.

    :param trading_strategy_engine_version:
        Passed to :py:func:`run_backtest_inline` as ``engine_version``.

    :return:
        The backtest state with its trade summary and advanced metrics.
    """
    assert isinstance(universe, TradingStrategyUniverse)

    if name is None:
        name = combination.get_label()

    universe_range = universe.data_universe.candles.get_timestamp_range()
    if not start_at:
        start_at = universe_range[0]

    if not end_at:
        end_at = universe_range[1]

    # Normalise to pd.Timestamp so .to_pydatetime() below works for both input types
    if isinstance(start_at, datetime.datetime):
        start_at = pd.Timestamp(start_at)

    if isinstance(end_at, datetime.datetime):
        end_at = pd.Timestamp(end_at)

    if not cycle_duration:
        cycle_duration = CycleDuration.from_timebucket(universe.data_universe.candles.time_bucket)
    else:
        assert isinstance(cycle_duration, CycleDuration)

    if not routing_model:
        routing_model = BacktestRoutingIgnoredModel(universe.get_reserve_asset().address)

    # Run the test
    state, _, _ = run_backtest_inline(
        name=name,
        start_at=start_at.to_pydatetime(),
        end_at=end_at.to_pydatetime(),
        client=None,
        cycle_duration=cycle_duration,
        decide_trades=decide_trades,
        create_trading_universe=None,
        universe=universe,
        initial_deposit=initial_deposit,
        reserve_currency=None,
        trade_routing=TradeRouting.user_supplied_routing_model,
        routing_model=routing_model,
        allow_missing_fees=True,
        data_delay_tolerance=data_delay_tolerance,
        engine_version=trading_strategy_engine_version,
    )

    analysis = build_trade_analysis(state.portfolio)
    equity = calculate_equity_curve(state)
    returns = calculate_returns(equity)
    metrics = calculate_advanced_metrics(
        returns,
        mode=AdvancedMetricsMode.full,
        periods_per_year=cycle_duration.get_yearly_periods(),
    )

    summary = analysis.calculate_summary_statistics()

    return GridSearchResult(
        combination=combination,
        state=state,
        summary=summary,
        metrics=metrics,
    )
def pick_grid_search_result(results: List[GridSearchResult], **kwargs) -> Optional[GridSearchResult]:
    """Pick one combination in the results.

    Example:

    .. code-block:: python

        # Pick a result of a single grid search combination
        # and examine its trading metrics
        sample = pick_grid_search_result(
            results,
            stop_loss_pct=0.9,
            slow_ema_candle_count=7,
            fast_ema_candle_count=2)
        assert sample.summary.total_positions == 2

    :param results:
        Output from :py:func:`perform_grid_search`

    :param kwargs:
        Grid parameters to match.

        Every parameter of a checked combination must be present here,
        otherwise ``KeyError`` is raised. Keyword arguments that are not
        grid parameters are ignored.

    :return:
        The first grid search result with the matching parameters or None if not found
    """
    # A list (not a generator) inside all() on purpose: every parameter is
    # looked up, so a missing keyword argument always raises KeyError
    return next(
        (
            candidate
            for candidate in results
            if all([param.value == kwargs[param.name] for param in candidate.combination.parameters])
        ),
        None,
    )
def pick_best_grid_search_result(
    results: List[GridSearchResult],
    key: Callable = lambda r: r.summary.return_percent,
    highest=True,
) -> Optional[GridSearchResult]:
    """Pick the best combination in the results based on one metric.

    Use trading metrics or performance metrics for the selection.

    Example:

    .. code-block:: python

        sample = pick_best_grid_search_result(
            results,
            key=lambda r: r.metrics.loc["Max Drawdown"][0])
        assert sample is not None

    :param results:
        Output from :py:func:`perform_grid_search`

    :param key:
        Lambda function to extract the value to compare from the data.

        If not given use cumulative return.

        Results for which the key returns ``None`` or NaN are skipped.

    :param highest:
        If true pick the highest value, otherwise lowest.

        On ties the first result wins.

    :return:
        The grid search result with the best value or None if no result had a usable value
    """
    # Infinite sentinels accept any finite value while still raising
    # TypeError for non-comparable values such as strings
    current_best = float("-inf") if highest else float("inf")
    match = None

    for r in results:
        value = key(r)

        # No result for this combination.
        # Note: np.NaN was removed in NumPy 2.0, and `value in (..., np.nan)`
        # never matched NaNs computed elsewhere anyway, so test explicitly.
        if value is None or (isinstance(value, (float, np.floating)) and np.isnan(value)):
            continue

        is_better = value > current_best if highest else value < current_best
        if is_better:
            match = r
            current_best = value

    return match
#: Process global stored universe for multiprocess workers
_universe: Optional[TradingStrategyUniverse] = None

#: Process pool terminated by :py:func:`_handle_sigterm`; ``None`` when no pool has been set
_process_pool: concurrent.futures.process.ProcessPoolExecutor | None = None


def _process_init(pickled_universe):
    """Child worker process initialiser.

    Unpickles the trading universe into the process-global ``_universe``
    that :py:func:`run_grid_combination_multiprocess` reads.

    :param pickled_universe:
        ``pickle.dumps()`` output of the trading universe.
    """
    # Transfer over the universe to the child process
    global _universe
    _universe = pickle.loads(pickled_universe)


def _handle_sigterm(*args):
    """Signal handler: stop the worker pool, kill its processes and exit with status 1.

    :param args:
        Signal handler arguments (signal number and frame); ignored.
    """
    # TODO: Despite all the effort, this does not seem to work with Visual Studio Code's Interrupt Kernel button
    pool = _process_pool
    if pool is not None:
        # Snapshot the worker processes before shutdown(): CPython resets the
        # pool's private `_processes` mapping to None during shutdown, and it
        # is already None if the pool was shut down earlier
        processes: List[Process] = list((pool._processes or {}).values())

        # Do not wait: with wait=True shutdown() blocks until all pending
        # futures have finished, i.e. the kill below would only happen after
        # the whole grid search completed
        pool.shutdown(wait=False, cancel_futures=True)

        for p in processes:
            p.kill()

    sys.exit(1)