"""Perform a grid search over strategy parameters to find optimal parameters."""
import concurrent
import datetime
import itertools
import logging
import os
import pickle
import shutil
import signal
import sys
import warnings
from collections import Counter
from dataclasses import dataclass
from multiprocessing import Process
from pathlib import Path
from typing import Protocol, Dict, List, Tuple, Any, Optional, Iterable, Collection, Callable
import concurrent.futures.process
import numpy as np
import pandas as pd
import futureproof
try:
from tqdm_loggable.auto import tqdm
except ImportError:
# tqdm_loggable is only available at the live execution,
# but fallback to normal TQDM auto mode
from tqdm.auto import tqdm
from tradeexecutor.analysis.advanced_metrics import calculate_advanced_metrics, AdvancedMetricsMode
from tradeexecutor.analysis.trade_analyser import TradeSummary, build_trade_analysis
from tradeexecutor.backtest.backtest_routing import BacktestRoutingIgnoredModel
from tradeexecutor.backtest.backtest_runner import run_backtest_inline
from tradeexecutor.state.state import State
from tradeexecutor.state.types import USDollarAmount
from tradeexecutor.strategy.cycle import CycleDuration
from tradeexecutor.strategy.default_routing_options import TradeRouting
from tradeexecutor.strategy.routing import RoutingModel
from tradeexecutor.strategy.strategy_module import DecideTradesProtocol
from tradeexecutor.strategy.trading_strategy_universe import TradingStrategyUniverse
from tradeexecutor.visual.equity_curve import calculate_equity_curve, calculate_returns
# Module-level logger, named after this module via __name__
logger = logging.getLogger(__name__)
def _hide_warnings(func):
    """Function wrapper to suppress warnings caused by quantstats and numpy functions.

    Otherwise these warnings pollute notebook output.

    Every warning raised while ``func`` runs is ignored. The warning filters
    that were active before the call are restored when the call returns or raises.

    :param func:
        The callable to wrap.

    :return:
        A wrapper that forwards all positional and keyword arguments to ``func``
        and returns its result. The wrapper keeps the name and docstring of ``func``.
    """

    # Local import: functools is not among the module level imports of this file
    import functools

    # Hidden warnings include:
    # In perform_grid_search:
    # /home/.cache/pypoetry/virtualenvs/trade-executor-xSh0vQvh-py3.10/lib/python3.10/site-packages/numpy/lib/function_base.py:2854:
    # RuntimeWarning: invalid value encountered in divide
    # c /= stddev[:, None]
    # In perform_grid_search:
    # /home/alex/.cache/pypoetry/virtualenvs/trade-executor-xSh0vQvh-py3.10/lib/python3.10/site-packages/scipy/stats/_distn_infrastructure.py:2351:
    # RuntimeWarning: invalid value encountered in multiply
    # lower_bound = _a * scale + loc

    @functools.wraps(func)  # Preserve __name__ / __doc__ of the wrapped function
    def wrapper(*args, **kwargs):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return func(*args, **kwargs)

    return wrapper
@dataclass
class GridParameter:
    """One value in grid search matrix.

    Instances are hashable. Two instances are equal when both their
    name and value are equal.
    """

    #: Name of the searched parameter
    name: str

    #: One candidate value for this parameter
    value: Any

    def __post_init__(self):
        pass

    def __hash__(self):
        return hash((self.name, self.value))

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of failing on a missing attribute
        if not isinstance(other, GridParameter):
            return NotImplemented
        return self.name == other.name and self.value == other.value

    def to_path(self) -> str:
        """Convert this parameter to a ``name=value`` file system path component.

        :return:
            ``name=value`` for float, int, str and bool values (including subclasses),
            ``name=none`` for ``None``.

        :raise NotImplementedError:
            If the value is of any other type.
        """
        value = self.value
        if value is None:
            return f"{self.name}=none"
        # isinstance() also accepts bool and subclasses of the basic types
        if isinstance(value, (float, int, str)):
            return f"{self.name}={value}"
        raise NotImplementedError(f"We do not support filename conversion for value {type(value)}={value}")
@dataclass()
class GridCombination:
    """One combination line in grid search.

    Hashing and equality are based on the parameters only;
    the index and the result path do not take part.
    """

    #: How many of nth grid combinations this is
    #:
    index: int

    #: In which folder we store the result files of all grid search runs
    #:
    #: Each individual combination will have its subfolder based on its parameter.
    result_path: Path

    #: Alphabetically sorted list of parameters
    #:
    #: Written as a forward reference so this class does not need
    #: GridParameter to be resolvable at class creation time.
    parameters: Tuple["GridParameter", ...]

    def __post_init__(self):
        assert len(self.parameters) > 0
        assert isinstance(self.result_path, Path), f"Expected Path, got {type(self.result_path)}"
        assert self.result_path.exists() and self.result_path.is_dir(), f"Not a dir: {self.result_path}"

    def __hash__(self):
        return hash(self.parameters)

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of failing on a missing attribute
        if not isinstance(other, GridCombination):
            return NotImplemented
        return self.parameters == other.parameters

    def get_relative_result_path(self) -> Path:
        """Get the path where the resulting state file is stored.

        Try to avoid messing with 256 character limit on filenames, thus break down as folders.

        :return:
            Relative path with one folder level per parameter, in parameter order.
        """
        return Path(*(p.to_path() for p in self.parameters))

    def get_full_result_path(self) -> Path:
        """Get the path where the resulting state file is stored.

        :return:
            The relative result path placed under :py:attr:`result_path`.
        """
        return self.result_path.joinpath(self.get_relative_result_path())

    def validate(self):
        """Check arguments can be serialised as fs path.

        Converting the parameters raises if any value cannot be expressed as a path component.
        """
        assert isinstance(self.get_relative_result_path(), Path)

    def as_dict(self) -> dict:
        """Get as kwargs mapping.

        :return:
            Parameter name -> parameter value
        """
        return {p.name: p.value for p in self.parameters}

    def get_label(self) -> str:
        """Human readable label for this combination"""
        return f"#{self.index}, " + ", ".join(f"{p.name}={p.value}" for p in self.parameters)

    def destructure(self) -> List[Any]:
        """Open parameters dict.

        This will return the arguments in the same order you pass them to :py:func:`prepare_grid_combinations`.

        :return:
            Parameter values only, in the order of :py:attr:`parameters`.
        """
        return [p.value for p in self.parameters]
[docs]@dataclass(slots=True, frozen=False)
class GridSearchResult:
"""Result for one grid combination."""
#: For which grid combination this result is
combination: GridCombination
#: The full back test state
state: State
#: Calculated trade summary
summary: TradeSummary
#: Performance metrics
metrics: pd.DataFrame
#: Was this result read from the earlier run save
cached: bool = False
#: Child process that created this result.
#:
#: Only applicable to multiprocessing
process_id: int = None
@staticmethod
def has_result(combination: GridCombination):
base_path = combination.result_path
return base_path.joinpath(combination.get_full_result_path()).joinpath("result.pickle").exists()
[docs] @staticmethod
def load(combination: GridCombination):
"""Deserialised from the cached Python pickle."""
base_path = combination.get_full_result_path()
with open(base_path.joinpath("result.pickle"), "rb") as inp:
result: GridSearchResult = pickle.load(inp)
result.cached = True
return result
[docs] def save(self):
"""Serialise as Python pickle."""
base_path = self.combination.get_full_result_path()
base_path.mkdir(parents=True, exist_ok=True)
with open(base_path.joinpath("result.pickle"), "wb") as out:
pickle.dump(self, out)
class GridSearchWorker(Protocol):
    """Callable signature for the function that evaluates one grid combination.

    Used to define how to create different strategy bodies.
    """

    def __call__(self, universe: TradingStrategyUniverse, combination: GridCombination) -> GridSearchResult:
        """Run a new decide_trades() strategy body based on the search parameters.

        :param universe:
            The trading universe passed to the worker

        :param combination:
            The grid combination this call evaluates

        :return:
            The result for this grid combination
        """
def prepare_grid_combinations(
    parameters: Dict[str, List[Any]],
    result_path: Path,
    clear_cached_results=False,
    marker_file="README-GRID-SEARCH.md",
) -> List[GridCombination]:
    """Get iterable search matrix of all parameter combinations.

    - Make sure we preserve the original order of the grid search parameters.

    - Set up the folder to store the results

    :param parameters:
        A grid of parameters we will search.

    :param result_path:
        A folder where resulting state files will be stored.

    :param clear_cached_results:
        Clear any existing result files from the saved result cache.

        You need to do this if you change the strategy logic outside
        the given combination parameters, as the framework will otherwise
        serve you the old cached results.

    :param marker_file:
        Safety to prevent novice users to nuke their hard disk with this command.

        An existing result folder is only deleted if it contains this file.

    :return:
        List of all combinations we need to search through
    """

    assert isinstance(result_path, Path)

    # Validate the whole grid before touching the file system,
    # so that bad input cannot leave us with a deleted result cache.
    combination_count = 1
    for values in parameters.values():
        assert isinstance(values, Collection), f"Expected list, got: {values}"
        combination_count *= len(values)

    # Report the size of the search matrix, not the number of parameter names
    logger.info("Preparing %d grid combinations, caching results in %s", combination_count, result_path)

    if clear_cached_results and result_path.exists():
        assert result_path.joinpath(marker_file).exists(), f"{result_path} does not seem to be grid search folder, it lacks {marker_file}"
        shutil.rmtree(result_path)

    result_path.mkdir(parents=True, exist_ok=True)

    with open(result_path.joinpath(marker_file), "wt") as out:
        print("This is a TradingStrategy.ai grid search result folder", file=out)

    args_lists: List[list] = [
        [GridParameter(name, v) for v in values]
        for name, values in parameters.items()
    ]

    # itertools.product() yields each tuple in the order of its input iterables,
    # so the original parameter order is maintained without re-sorting
    combinations = [
        GridCombination(index=idx, parameters=tuple(c), result_path=result_path)
        for idx, c in enumerate(itertools.product(*args_lists), start=1)
    ]

    for c in combinations:
        c.validate()

    return combinations
def run_grid_combination(
    grid_search_worker: GridSearchWorker,
    universe: TradingStrategyUniverse,
    combination: GridCombination,
):
    """Get the result for one grid combination, computing it only if needed.

    :param grid_search_worker:
        Called with ``(universe, combination)`` when no saved result exists

    :param universe:
        Passed to the worker as is

    :param combination:
        The grid combination to evaluate

    :return:
        The saved result if one exists, otherwise the freshly computed result
    """
    if not GridSearchResult.has_result(combination):
        fresh = grid_search_worker(universe, combination)
        # Cache result for the future runs
        fresh.save()
        return fresh

    return GridSearchResult.load(combination)
def run_grid_combination_multiprocess(
    grid_search_worker: GridSearchWorker,
    combination: GridCombination,
):
    """Get the result for one grid combination using the process global universe.

    The universe is not passed as an argument; it is read from
    the module level ``_universe`` variable.

    :param grid_search_worker:
        Called with ``(universe, combination)`` when no saved result exists

    :param combination:
        The grid combination to evaluate

    :return:
        The saved result if one exists, otherwise the freshly computed result
        tagged with the id of the current process
    """
    global _universe

    if GridSearchResult.has_result(combination):
        return GridSearchResult.load(combination)

    computed = grid_search_worker(_universe, combination)
    computed.process_id = os.getpid()

    # Cache result for the future runs
    computed.save()
    return computed
def run_grid_search_backtest(
    combination: GridCombination,
    decide_trades: DecideTradesProtocol,
    universe: TradingStrategyUniverse,
    cycle_duration: Optional[CycleDuration] = None,
    start_at: Optional[datetime.datetime | pd.Timestamp] = None,
    end_at: Optional[datetime.datetime | pd.Timestamp] = None,
    initial_deposit: USDollarAmount = 5000.0,
    trade_routing: Optional[TradeRouting] = None,
    data_delay_tolerance: Optional[pd.Timedelta] = None,
    name: Optional[str] = None,
    routing_model: Optional[RoutingModel] = None,
) -> GridSearchResult:
    """Run a single inline backtest and wrap its outcome as a grid search result.

    :param combination:
        The grid combination stored on the result

    :param decide_trades:
        Forwarded to the backtest runner

    :param universe:
        The trading universe to backtest against

    :param cycle_duration:
        If not given, derived from the time bucket of the universe candles

    :param start_at:
        If not given, the first value of the candle timestamp range

    :param end_at:
        If not given, the second value of the candle timestamp range

    :param initial_deposit:
        Forwarded to the backtest runner

    :param trade_routing:
        Accepted but not forwarded: the backtest is always run with
        ``TradeRouting.user_supplied_routing_model``

    :param data_delay_tolerance:
        Forwarded to the backtest runner

    :param name:
        Backtest name. If not given, the label of the combination is used.

    :param routing_model:
        If not given, a ``BacktestRoutingIgnoredModel`` is created
        for the address of the reserve asset of the universe

    :return:
        Result carrying the backtest state, the trade summary and the advanced metrics
    """
    assert isinstance(universe, TradingStrategyUniverse)

    backtest_name = combination.get_label() if name is None else name

    # Missing range ends fall back to what the candle data covers
    available_range = universe.universe.candles.get_timestamp_range()
    start_at = start_at or available_range[0]
    end_at = end_at or available_range[1]

    # Normalise both ends to pandas timestamps
    if isinstance(start_at, datetime.datetime):
        start_at = pd.Timestamp(start_at)
    if isinstance(end_at, datetime.datetime):
        end_at = pd.Timestamp(end_at)

    if cycle_duration:
        assert isinstance(cycle_duration, CycleDuration)
    else:
        cycle_duration = CycleDuration.from_timebucket(universe.universe.candles.time_bucket)

    routing_model = routing_model or BacktestRoutingIgnoredModel(universe.get_reserve_asset().address)

    # Run the test
    backtest_kwargs = dict(
        name=backtest_name,
        start_at=start_at.to_pydatetime(),
        end_at=end_at.to_pydatetime(),
        client=None,
        cycle_duration=cycle_duration,
        decide_trades=decide_trades,
        create_trading_universe=None,
        universe=universe,
        initial_deposit=initial_deposit,
        reserve_currency=None,
        trade_routing=TradeRouting.user_supplied_routing_model,
        routing_model=routing_model,
        allow_missing_fees=True,
        data_delay_tolerance=data_delay_tolerance,
    )
    state, backtest_universe, debug_dump = run_backtest_inline(**backtest_kwargs)

    analysis = build_trade_analysis(state.portfolio)

    equity_curve = calculate_equity_curve(state)
    returns_series = calculate_returns(equity_curve)
    advanced_metrics = calculate_advanced_metrics(
        returns_series,
        mode=AdvancedMetricsMode.full,
        periods_per_year=cycle_duration.get_yearly_periods(),
    )

    trade_summary = analysis.calculate_summary_statistics()

    return GridSearchResult(
        combination=combination,
        state=state,
        summary=trade_summary,
        metrics=advanced_metrics,
    )
def pick_grid_search_result(results: List["GridSearchResult"], **kwargs) -> Optional["GridSearchResult"]:
    """Pick one combination in the results.

    Example:

    .. code-block:: python

        # Pick a result of a single grid search combination
        # and examine its trading metrics
        sample = pick_grid_search_result(
            results,
            stop_loss_pct=0.9,
            slow_ema_candle_count=7,
            fast_ema_candle_count=2)
        assert sample.summary.total_positions == 2

    :param results:
        Output from :py:func:`perform_grid_search`

    :param kwargs:
        Grid parameters to match.

        A value must be given for every parameter of the examined combinations.

    :return:
        The first grid search result with the matching parameters or None if not found

    :raise KeyError:
        If ``kwargs`` lacks a parameter name that an examined combination has
    """

    def matches_all(candidate) -> bool:
        # Build the full list first, so every parameter name is looked up in kwargs
        checks = [param.value == kwargs[param.name] for param in candidate.combination.parameters]
        return all(checks)

    return next((candidate for candidate in results if matches_all(candidate)), None)
def pick_best_grid_search_result(
    results: List["GridSearchResult"],
    key: Callable = lambda r: r.summary.return_percent,
    highest=True,
) -> Optional["GridSearchResult"]:
    """Pick the best combination in the results based on one metric.

    Use trading metrics or performance metrics for the selection.

    Example:

    .. code-block:: python

        sample = pick_best_grid_search_result(
            results,
            key=lambda r: r.metrics.loc["Max Drawdown"][0])
        assert sample is not None

    :param results:
        Output from :py:func:`perform_grid_search`

    :param key:
        Lambda function to extract the value to compare from the data.

        If not given use cumulative return.

        The extracted value is expected to be a scalar.
        Results for which it is None or NaN are skipped.

    :param highest:
        If true pick the highest value, otherwise lowest.

    :return:
        The grid search result with the best value, or None if no result has a usable value.

        On ties the earliest result wins.
    """

    best_result = None
    best_value = None

    for r in results:
        value = key(r)

        # No result for this combination.
        # pd.isna() also catches NaN values that are not the numpy NaN singleton.
        if value is None or pd.isna(value):
            continue

        if best_result is None:
            # First usable value: no magic sentinel needed,
            # so arbitrarily large or small values can still win
            is_better = True
        elif highest:
            is_better = value > best_value
        else:
            is_better = value < best_value

        if is_better:
            best_result = r
            best_value = value

    return best_result
#: Process global stored universe for multiprocess workers
#:
#: Written by _process_init() and read by run_grid_combination_multiprocess().
_universe: Optional[TradingStrategyUniverse] = None
#: Process pool whose worker processes _handle_sigterm() shuts down and kills.
#:
#: NOTE(review): nothing in the visible part of this file assigns this variable -
#: presumably it is set where the pool is created; confirm.
_process_pool: concurrent.futures.process.ProcessPoolExecutor | None = None
def _process_init(pickled_universe):
    """Child worker process initialiser.

    :param pickled_universe:
        Pickled bytes that are unpickled and stored
        in the module level ``_universe`` variable
    """
    global _universe

    # Transfer over the universe to the child process
    restored = pickle.loads(pickled_universe)
    _universe = restored
def _handle_sigterm(*args):
    """Signal handler: stop the worker processes of the grid search and exit.

    - Asks the process pool to shut down without waiting for running work
      and with pending futures cancelled

    - Kills every worker process that the pool knew about

    - Exits the current process with status code 1

    Does nothing to the pool if no pool has been registered in ``_process_pool``.

    :param args:
        Signal handler arguments, unused
    """
    # TODO: Despite all the effort, this does not seem to work with Visual Studio Code's Interrupt Kernel button
    pool = _process_pool

    if pool is not None:
        # Take the worker list before shutdown(), same ordering as before.
        # NOTE(review): _processes is a private attribute of the executor
        # and may be None - treat that as "no workers".
        processes: List[Process] = list((pool._processes or {}).values())

        # shutdown() waits for running work by default, which would block
        # this handler until the workers are done on their own
        pool.shutdown(wait=False, cancel_futures=True)

        for p in processes:
            p.kill()

    sys.exit(1)