├── .github └── workflows │ └── forex_tests.yaml ├── .gitignore ├── README.md ├── api ├── __init__.py ├── oanda │ ├── __init__.py │ └── connect.py └── params.py ├── forex ├── candle.py ├── candlelist_utils.py ├── harea.py ├── pivot.py └── segment.py ├── params.py ├── requirements.txt ├── tests ├── .DS_Store ├── api │ └── oanda │ │ ├── cmd.sh │ │ └── test_connect.py ├── conftest.py ├── data │ ├── clist.AUDUSD.H8.2019.pckl │ ├── clist.AUDUSD.H8.2021.pckl │ ├── clist.pckl │ ├── clist_audusd_2010_2020.pckl │ ├── create_pickled_data.py │ ├── harealist_file.txt │ ├── harealist_file.yaml │ ├── seg_audusd.pckl │ ├── seg_audusdB.pckl │ ├── seglist_audusd.pckl │ └── testCounter.xlsx ├── forex │ ├── candle │ │ ├── conftest.py │ │ ├── test_Candle.py │ │ ├── test_CandleList.py │ │ └── test_candlelist_utils.py │ ├── conftest.py │ ├── harea │ │ ├── test_HArea.py │ │ └── test_HAreaList.py │ ├── pivot │ │ ├── test_Pivot.py │ │ └── test_PivotList.py │ └── segment │ │ ├── test_Segment.py │ │ └── test_SegmentList.py ├── test_utils.py ├── trade_bot │ ├── test_TradeBot.py │ └── test_trade_bot_utils.py └── trading_journal │ ├── conftest.py │ ├── data_for_tests.py │ ├── test_OpenTrade.py │ ├── test_Trade.py │ ├── test_TradeJournal.py │ └── test_trade_utils.py ├── trade_bot ├── trade_bot.py └── trade_bot_utils.py ├── trading_journal ├── constants.py ├── open_trade.py ├── trade.py ├── trade_journal.py └── trade_utils.py └── utils.py /.github/workflows/forex_tests.yaml: -------------------------------------------------------------------------------- 1 | name: Python FOREX library tests 2 | on: [push] 3 | jobs: 4 | build: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/checkout@v3 8 | - name: Set up Python 3.9 9 | uses: actions/setup-python@v4 10 | with: 11 | # Semantic version range syntax or exact version of a Python version 12 | python-version: '3.9' 13 | # Optional - x64 or x86 architecture, defaults to x64 14 | architecture: 'x64' 15 | - name: Install dependencies 16 | run: | 17 | python -m pip install --upgrade pip 18 | pip install -r requirements.txt 19 | - name: Test forex library with pytest 20 | run: | 21 | export PYTHONPATH=$PWD 22 | export TOKEN=${{ secrets.OANDA_TOKEN }} 23 | pip install pytest && cd tests/forex/ 24 | pytest -s -v 25 | - name: Test api library with pytest 26 | run: | 27 | export PYTHONPATH=$PWD 28 | export TOKEN=${{ secrets.OANDA_TOKEN }} 29 | pip install pytest && cd tests/api/ 30 | pytest -s -v 31 | - name: Test trade_bot library with pytest 32 | run: | 33 | export PYTHONPATH=$PWD 34 | export TOKEN=${{ secrets.OANDA_TOKEN }} 35 | pip install pytest && cd tests 36 | pytest -s -v trade_bot/ 37 | - name: Test trading_journal library with pytest 38 | run: | 39 | export PYTHONPATH=$PWD 40 | export TOKEN=${{ secrets.OANDA_TOKEN }} 41 | pip install pytest && cd tests 42 | pytest -s -v trading_journal/ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .metadata 2 | tmp/ 3 | *.tmp 4 | *.bak 5 | *.swp 6 | *~.nib 7 | local.properties 8 | .settings/ 9 | .loadpath 10 | .recommenders 11 | 12 | # External tool builders 13 | .externalToolBuilders/ 14 | 15 | # Locally stored "Eclipse launch configurations" 16 | *.launch 17 | 18 | # PyDev specific (Python IDE for Eclipse) 19 | *.pydevproject 20 | 21 | # CDT-specific (C/C++ Development Tooling) 22 | .cproject 23 | 24 | # Java annotation processor (APT) 25 | .factorypath 26 | 27 | # PDT-specific (PHP Development Tools) 
28 | .buildpath
29 |
30 | # sbteclipse plugin
31 | .target
32 |
33 | # Tern plugin
34 | .tern-project
35 |
36 | # TeXlipse plugin
37 | .texlipse
38 |
39 | # STS (Spring Tool Suite)
40 | .springBeans
41 |
42 | # Code Recommenders
43 | .recommenders/
44 |
45 | # Scala IDE specific (Scala & Java development for Eclipse)
46 | .cache-main
47 | .scala_dependencies
48 | .worksheet
49 |
50 | # Byte-compiled / optimized / DLL files
51 | __pycache__/
52 | *.py[cod]
53 |
54 | # Sphinx build files
55 | ../entry_scripts/docs/_build/
56 |
57 |
58 | # Editor backup files #
59 | #######################
60 | *~
61 | *.pyc
62 |
63 | # Pycharm files
64 | .idea/
65 |
66 | # Compiled files
67 | __pycache__
68 |
69 | # Ignore .png files generated by my notebooks
70 | notebooks/*.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FOREX
2 |
3 | Python library to handle different kinds of operations with the OANDA API.
4 | It can be used to analyse the trades recorded in a trading journal and to
5 | calculate different features for each recorded trade. These features can
6 | then be used to train a binary classifier that predicts the outcome of
7 | new trades.
8 |
9 |
10 |
--------------------------------------------------------------------------------
/api/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FOREXfunguru/FOREX/2888a7c7d79c72ce303857787a6b07ed9f4f00c8/api/__init__.py
--------------------------------------------------------------------------------
/api/oanda/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FOREXfunguru/FOREX/2888a7c7d79c72ce303857787a6b07ed9f4f00c8/api/oanda/__init__.py
--------------------------------------------------------------------------------
/api/oanda/connect.py:
--------------------------------------------------------------------------------
1 | """
2 | @date: 22/11/2020
3 | @author: Ernesto Lowy
4 | @email: ernestolowy@gmail.com
5 | """
6 | import datetime
7 | import time
8 | import logging
9 | import requests
10 | import re
11 | import os
12 | import json
13 | import flatdict
14 | import argparse
15 | from datetime import timedelta
16 |
17 | from api.params import Params as apiparams
18 | from typing import Dict, List
19 | from forex.candle import CandleList
20 |
21 | o_logger = logging.getLogger(__name__)
22 | o_logger.setLevel(logging.INFO)
23 |
24 |
25 | class Connect(object):
26 | """Class representing a connection to Oanda's REST API.
27 |
28 | Args:
29 | instrument: i.e. AUD_USD
30 | granularity: i.e. D, H12, ...
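Example (illustrative only, not part of the original docstring; it assumes
a valid OANDA API token is exported in the TOKEN environment variable,
which query() reads):

    conn = Connect(instrument="AUD_USD", granularity="D")
    clO = conn.query(start="2020-01-01T22:00:00", end="2020-02-01T22:00:00")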
31 | """ 32 | 33 | def __init__(self, instrument: str, granularity: str) -> None: 34 | self._instrument = instrument 35 | self._granularity = granularity 36 | 37 | @property 38 | def instrument(self) -> str: 39 | return self._instrument 40 | 41 | @property 42 | def granularity(self) -> str: 43 | return self._granularity 44 | 45 | def fetch_candle(self, d: datetime) -> "Candle": 46 | """Method to get a single candle""" 47 | # substract one min to be sure we fetch the right candle 48 | start = d - timedelta(minutes=1) 49 | clO = self.query(start=start.isoformat(), end=start.isoformat()) 50 | hour_delta = timedelta(hours=1) 51 | 52 | if len(clO) == 0: 53 | # try with hour-1 to deal with time shifts 54 | new_d = None 55 | if d.hour%2 > 0: 56 | new_d = d+hour_delta 57 | else: 58 | new_d = d-hour_delta 59 | clO = self.query(start=new_d.isoformat(), end=new_d.isoformat()) 60 | 61 | if len(clO.candles) == 1: 62 | if (clO.candles[0].time != d) and abs(clO.candles[0].time - d) > hour_delta: 63 | # return None if candle is not equal to 'd' 64 | return 65 | return clO.candles[0] 66 | if len(clO.candles) > 1: 67 | raise Exception("No valid number of candles in CandleList") 68 | return 69 | 70 | def retry(cooloff: int = 5, exc_type=None): 71 | """Decorator for retrying connection and prevent TimeOut errors""" 72 | if not exc_type: 73 | exc_type = [requests.exceptions.ConnectionError] 74 | 75 | def real_decorator(function): 76 | def wrapper(*args, **kwargs): 77 | while True: 78 | try: 79 | return function(*args, **kwargs) 80 | except Exception as e: 81 | if e.__class__ in exc_type: 82 | print("failed (?)") 83 | time.sleep(cooloff) 84 | else: 85 | raise e 86 | 87 | return wrapper 88 | 89 | return real_decorator 90 | 91 | def _process_data(self, data: List[flatdict.FlatDict], strip: bool = True): 92 | """Process candle data 93 | 94 | Args: 95 | data: data returned by API 96 | strip: If True then remove 'complete' and 'volume' fields 97 | """ 98 | keys_to_remove = ["complete", "volume", "time"] 99 | cldict = list() 100 | for candle in data: 101 | atime = re.sub(r"\.\d+Z$", "", candle["time"]) 102 | if strip: 103 | candle = { 104 | key: value 105 | for key, value in candle.items() 106 | if key not in keys_to_remove 107 | } 108 | candle["time"] = atime 109 | newc = {key.replace("mid.", ""): value for key, value in candle.items()} 110 | cldict.append(newc) 111 | return cldict 112 | 113 | @retry() 114 | def query( 115 | self, start: datetime, end: datetime = None, count: int = None 116 | ) -> List[Dict]: 117 | """Function to query Oanda's REST API 118 | 119 | Args: 120 | start: isoformat 121 | end: isoformat 122 | count: If end is not defined, this controls the 123 | number of candles from the start 124 | that will be retrieved 125 | Returns: 126 | CandleList""" 127 | startObj = self.validate_datetime(start) 128 | start = startObj.isoformat() 129 | params = {} 130 | if end is not None and count is None: 131 | endObj = self.validate_datetime(end) 132 | min = datetime.timedelta(minutes=1) 133 | endObj = endObj + min 134 | end = endObj.isoformat() 135 | params["to"] = end 136 | elif count is not None: 137 | params["count"] = count 138 | elif end is None and count is None: 139 | raise Exception( 140 | "You need to set at least the 'end' or the " "'count' attribute" 141 | ) 142 | 143 | params["granularity"] = self.granularity 144 | params["from"] = start 145 | try: 146 | resp = requests.get( 147 | url=f"{apiparams.url}/{self.instrument}/candles", 148 | params=params, 149 | headers={ 150 | "content-type": 
f"{apiparams.content_type}", 151 | "Authorization": f"Bearer {os.environ.get('TOKEN')}", 152 | }, 153 | ) 154 | if resp.status_code != 200: 155 | raise ConnectionError(resp.status_code) 156 | else: 157 | data = json.loads(resp.content.decode("utf-8")) 158 | newdata = [flatdict.FlatDict(c, delimiter=".") for c in data["candles"]] 159 | newdata1 = self._process_data(data=newdata) 160 | return CandleList( 161 | instrument=self.instrument, 162 | granularity=self.granularity, 163 | data=newdata1, 164 | ) 165 | except ConnectionError as err: 166 | logging.exception( 167 | "Something went wrong. url used was:\n{0}".format(resp.url) 168 | ) 169 | logging.exception("Error message was: {0}".format(err)) 170 | return CandleList( 171 | instrument=self.instrument, granularity=self.granularity, data=[] 172 | ) 173 | 174 | def validate_datetime(self, datestr: str) -> datetime: 175 | """Function to parse a string datetime to return 176 | a datetime object and to validate the datetime. 177 | 178 | Args: 179 | datestr : String representing a date 180 | """ 181 | try: 182 | dateObj = datetime.datetime.strptime(datestr, "%Y-%m-%dT%H:%M:%S") 183 | except ValueError: 184 | raise ValueError("Incorrect date format, should be %Y-%m-%dT%H:%M:%S") 185 | 186 | return dateObj 187 | 188 | def print_url(self) -> str: 189 | """Print url from requests module""" 190 | 191 | print("URL: %s" % self.resp.url) 192 | 193 | def __repr__(self) -> str: 194 | return "connect" 195 | 196 | def __str__(self) -> str: 197 | out_str = "" 198 | for attr, value in self.__dict__.items(): 199 | out_str += "%s:%s " % (attr, value) 200 | return out_str 201 | 202 | 203 | if __name__ == "__main__": 204 | parser = argparse.ArgumentParser( 205 | description="Query Oanda REST API and generates a CandleList object and save it to a file" 206 | ) 207 | 208 | parser.add_argument( 209 | "--start", required=True, help="Start datetime. i.e.:2018-05-21T21:00:00" 210 | ) 211 | parser.add_argument( 212 | "--end", required=True, help="End datetime. i.e.: 2018-05-23T21:00:00" 213 | ) 214 | parser.add_argument("--instrument", required=True, help="AUD_USD, GBP_USD, ...") 215 | parser.add_argument("--granularity", required=True, help="i.e. D,H12,H8, ...") 216 | parser.add_argument("--outfile", required=True, help="Output filename") 217 | 218 | args = parser.parse_args() 219 | 220 | conn = Connect(instrument=args.instrument, granularity=args.granularity) 221 | 222 | clO = conn.query(start=args.start, end=args.end) 223 | 224 | clO.pickle_dump(args.outfile) 225 | -------------------------------------------------------------------------------- /api/params.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class Params: 6 | alignmentTimezone: int = 22 7 | dailyAlignment: str = "Europe/London" 8 | url: str = "https://api-fxtrade.oanda.com/v3/instruments/" 9 | roll: bool = ( 10 | True # If True, then extend the end date, which falls on close market, 11 | # to the next period for which 12 | ) 13 | # the market is open. 
Default=False 14 | content_type: str = "application/json" 15 | -------------------------------------------------------------------------------- /forex/candle.py: -------------------------------------------------------------------------------- 1 | """ 2 | @date: 22/11/2020 3 | @author: Ernesto Lowy 4 | @email: ernestolowy@gmail.com 5 | """ 6 | import logging 7 | import pickle 8 | from datetime import timedelta, datetime 9 | 10 | import pandas as pd 11 | import matplotlib 12 | from pandas.plotting import register_matplotlib_converters 13 | 14 | from utils import try_parsing_date, calculate_pips 15 | from params import clist_params 16 | 17 | register_matplotlib_converters() 18 | matplotlib.use("PS") 19 | 20 | logging.basicConfig(level=logging.INFO) 21 | 22 | # create logger 23 | cl_logger = logging.getLogger(__name__) 24 | cl_logger.setLevel(logging.INFO) 25 | 26 | 27 | class Candle(object): 28 | """Class representing a particular Candle""" 29 | 30 | __slots__ = [ 31 | "complete", 32 | "volume", 33 | "time", 34 | "o", 35 | "h", 36 | "c", 37 | "l", 38 | "_colour", 39 | "_perc_body", 40 | "granularity", 41 | "rsi", 42 | ] 43 | 44 | def __init__(self, time: str, o: float, h: float, 45 | c: float, l: float) -> None: 46 | self.o = float(o) 47 | self.h = float(h) 48 | self.c = float(c) 49 | self.l = float(l) 50 | self.time = time 51 | 52 | if isinstance(self.time, str): 53 | self.time = datetime.strptime(self.time, "%Y-%m-%dT%H:%M:%S") 54 | self._colour = self._set_colour() 55 | self._perc_body = self._calc_perc_body() 56 | 57 | @property 58 | def colour(self) -> str: 59 | """Candle's body colour""" 60 | return self._colour 61 | 62 | @property 63 | def perc_body(self) -> float: 64 | """Candle's body percentage""" 65 | return self._perc_body 66 | 67 | def _set_colour(self) -> str: 68 | if self.o < self.c: 69 | return "green" 70 | elif self.o > self.c: 71 | return "red" 72 | else: 73 | return "undefined" 74 | 75 | def _calc_perc_body(self) -> float: 76 | height = self.h - self.l 77 | if height == 0: 78 | return 0 79 | body = abs(self.o - self.c) 80 | return round((body * 100) / height, 2) 81 | 82 | def indecision_c(self, ic_perc: int = 10) -> bool: 83 | """Function to check if the candle is an indecision candle. 84 | 85 | Args: 86 | ic_perc : Candle's body percentage below which the candle will 87 | be considered indecision candle. 88 | """ 89 | return self.perc_body <= ic_perc 90 | 91 | def height(self, pair) -> float: 92 | """Function to calculate the number of pips 93 | between self.h and self.l""" 94 | return float(calculate_pips(pair, self.h - self.l)) 95 | 96 | def middle_point(self) -> float: 97 | """Function to calculate middle candle price. 98 | Defined as (self.h-self.l)/2+self.l)""" 99 | return round(((self.h - self.l) / 2 + self.l), 5) 100 | 101 | def __hash__(self): 102 | return hash(self.time) 103 | 104 | def __eq__(self, other): 105 | if isinstance(other, Candle): 106 | return self.time == other.time 107 | return False 108 | 109 | def __repr__(self): 110 | return "Candle" 111 | 112 | def __str__(self): 113 | sb = [] 114 | for key in self.__slots__: 115 | if hasattr(self, key): 116 | sb.append("{key}='{value}'".format(key=key, value=getattr(self, key))) 117 | return ", ".join(sb) 118 | 119 | 120 | class CandleList(object): 121 | """Class containing a list of Candles. 122 | 123 | Class variables: 124 | instrument: i.e. 'AUD_USD' 125 | granularity: i.e. 'D' 126 | candles: List of Candle objects 127 | type: Type of this CandleList. 
Possible values are 'long'/'short'""" 128 | 129 | __slots__ = [ 130 | "instrument", 131 | "granularity", 132 | "data", 133 | "candles", 134 | "_type", 135 | "times", 136 | "pos", 137 | ] 138 | 139 | def __init__( 140 | self, instrument: str, granularity: str, 141 | data: list = None, candles=None 142 | ): 143 | """Constructor 144 | 145 | Arguments: 146 | data: list of Dictionaries, each dict containing data for a Candle 147 | 148 | self.times will be a list of datetime objects 149 | """ 150 | if candles: 151 | self.candles = candles 152 | self.times = [c.time for c in candles] 153 | elif data: 154 | self.candles = [Candle(**d) for d in data] 155 | self.times = [ 156 | try_parsing_date(d["time"]) if isinstance(d["time"], str) else d["time"] 157 | for d in data 158 | ] 159 | else: 160 | self.candles = [] 161 | self.times = [] 162 | self.instrument = instrument 163 | self.granularity = granularity 164 | self._type = self._guess_type() 165 | 166 | @property 167 | def type(self): 168 | return self._type 169 | 170 | def __iter__(self): 171 | self.pos = 0 172 | return self 173 | 174 | def __next__(self): 175 | if self.pos < len(self.candles): 176 | self.pos += 1 177 | return self.candles[self.pos - 1] 178 | else: 179 | raise StopIteration 180 | 181 | def __getitem__(self, adatetime: datetime) -> Candle: 182 | if not isinstance(adatetime, datetime): 183 | raise ValueError("A datetime object is needed!") 184 | fdt = None 185 | if adatetime not in self.times: 186 | dtp1 = adatetime + timedelta(hours=1) 187 | dtm1 = adatetime - timedelta(hours=1) 188 | if dtp1 in self.times: 189 | fdt = dtp1 190 | elif dtm1 in self.times: 191 | fdt = dtm1 192 | else: 193 | fdt = adatetime 194 | if fdt: 195 | return self.candles[self.times.index(fdt)] 196 | 197 | def __index__(self, adatetime: datetime) -> int: 198 | fdt = None 199 | if adatetime not in self.times: 200 | dtp1 = adatetime + timedelta(hours=1) 201 | dtm1 = adatetime - timedelta(hours=1) 202 | if dtp1 in self.times: 203 | fdt = dtp1 204 | elif dtm1 in self.times: 205 | fdt = dtm1 206 | else: 207 | fdt = adatetime 208 | if fdt: 209 | return self.times.index(fdt) 210 | else: 211 | raise ValueError(f"{adatetime} not in self.times") 212 | 213 | def __len__(self): 214 | return len(self.candles) 215 | 216 | def __add__(self, ClO): 217 | clist = self.candles + ClO.candles 218 | newClO = CandleList( 219 | instrument=self.instrument, 220 | granularity=self.granularity, candles=clist 221 | ) 222 | return newClO 223 | 224 | def _guess_type(self) -> str: 225 | if len(self.candles) == 0: 226 | return None 227 | price_1st = self.candles[0].c 228 | price_last = self.candles[-1].c 229 | if price_1st > price_last: 230 | return "short" # or downtrend 231 | elif price_1st < price_last: 232 | return "long" # or uptrend 233 | 234 | def calc_rsi(self): 235 | """Calculate the RSI for a certain candle list""" 236 | cl_logger.debug("Running calc_rsi") 237 | 238 | series = [c.c for c in self.candles] 239 | 240 | df = pd.DataFrame({"close": series}) 241 | chg = df["close"].diff(1) 242 | 243 | gain = chg.mask(chg < 0, 0) 244 | loss = chg.mask(chg > 0, 0) 245 | 246 | rsi_period = clist_params.rsi_period 247 | avg_gain = gain.ewm(com=rsi_period - 1, min_periods=rsi_period).mean() 248 | avg_loss = loss.ewm(com=rsi_period - 1, min_periods=rsi_period).mean() 249 | 250 | rs = abs(avg_gain / avg_loss) 251 | rsi = 100 - (100 / (1 + rs)) 252 | 253 | rsi4cl = rsi[-len(self.candles) :] 254 | # set rsi attribute in each candle of the CandleList 255 | ix = 0 256 | for _, v in zip(self.candles, 
rsi4cl): 257 | self.candles[ix].rsi = round(v, 2) 258 | ix += 1 259 | cl_logger.debug("Done calc_rsi") 260 | 261 | def pickle_dump(self, outfile: str) -> str: 262 | """Function to pickle this particular CandleList 263 | 264 | Arguments: 265 | outfile: Path used to pickle 266 | 267 | Returns: 268 | path to file with pickled data 269 | """ 270 | 271 | pickle_out = open(outfile, "wb") 272 | pickle.dump(self, pickle_out) 273 | pickle_out.close() 274 | 275 | return outfile 276 | 277 | @classmethod 278 | def pickle_load(self, infile: str): 279 | """Function to pickle this particular CandleList 280 | 281 | Arguments: 282 | infile: Path used to load in pickled data 283 | 284 | Returns: 285 | inclO: CandleList object 286 | """ 287 | pickle_in = open(infile, "rb") 288 | inclO = pickle.load(pickle_in) 289 | pickle_in.close() 290 | 291 | return inclO 292 | 293 | def calc_rsi_bounces(self) -> dict: 294 | """Calculate the number of times that the 295 | price has been in overbought (>70) or 296 | oversold (<30) regions 297 | 298 | Returns: 299 | dict: 300 | {number: 3 301 | lengths: [4,5,6]} 302 | Where number is the number of times price 303 | has been in overbought/oversold and lengths list 304 | is formed by the number of candles that the price 305 | has been in overbought/oversold each of the times 306 | sorted from older to newer 307 | """ 308 | adj = False 309 | num_times, length = 0, 0 310 | lengths = [] 311 | 312 | for c in self.candles: 313 | if c.rsi is None: 314 | raise Exception( 315 | "RSI values are not defined for this " 316 | "Candlelist, " 317 | "run calc_rsi first" 318 | ) 319 | if self.type is None: 320 | raise Exception("type is not defined for this Candlelist") 321 | 322 | if self.type == "short": 323 | if c.rsi > 70 and adj is False: 324 | num_times += 1 325 | length = 1 326 | adj = True 327 | elif c.rsi > 70 and adj is True: 328 | length += 1 329 | elif c.rsi < 70: 330 | if adj is True: 331 | lengths.append(length) 332 | adj = False 333 | elif self.type == "long": 334 | if c.rsi < 30 and adj is False: 335 | num_times += 1 336 | length = 1 337 | adj = True 338 | elif c.rsi < 30 and adj is True: 339 | length += 1 340 | elif c.rsi > 30: 341 | if adj is True: 342 | lengths.append(length) 343 | adj = False 344 | 345 | if adj is True and length > 0: 346 | lengths.append(length) 347 | 348 | if num_times != len(lengths): 349 | raise Exception("Number of times" "and lengths do not" "match") 350 | return {"number": num_times, "lengths": lengths} 351 | 352 | def get_length_pips(self) -> int: 353 | """Function to calculate the length of CandleList in number of pips""" 354 | 355 | start_cl = self.candles[0] 356 | end_cl = self.candles[-1] 357 | 358 | (first, second) = self.instrument.split("_") 359 | round_number = None 360 | if first == "JPY" or second == "JPY": 361 | round_number = 2 362 | else: 363 | round_number = 4 364 | 365 | start_price = round(float(start_cl.c), round_number) 366 | end_price = round(float(end_cl.c), round_number) 367 | 368 | diff = (start_price - end_price) * 10**round_number 369 | 370 | return abs(int(round(diff, 0))) 371 | 372 | def slice( 373 | self, start: datetime, end: datetime, inplace: bool = False 374 | ) -> "CandleList": 375 | """Function to slice self on a date interval. It will modify 376 | the CandleList in place 377 | 378 | Arguments: 379 | start: Slice the CandleList from this 'start' datetime. 380 | end: This CandleList will have this 'end' datetime. 
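A minimal illustrative call (hypothetical datetimes, assuming 'clO' is a
populated CandleList that spans them):

    sub_clO = clO.slice(start=datetime(2020, 1, 2, 22), end=datetime(2020, 3, 2, 22))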
381 | 382 | Raises 383 | ------ 384 | Exception 385 | If start > end 386 | """ 387 | if self.granularity == "D": 388 | delta = timedelta(hours=24) 389 | else: 390 | fgran = self.granularity.replace("H", "") 391 | delta = timedelta(hours=int(fgran)) 392 | 393 | while not self.__getitem__(start): 394 | start = start + delta 395 | while not self.__getitem__(end): 396 | end = end + delta 397 | start_ix = self.__index__(start) 398 | end_ix = self.__index__(end) 399 | if not inplace: 400 | cl = CandleList( 401 | instrument=self.instrument, 402 | granularity=self.granularity, 403 | candles=self.candles[start_ix: end_ix + 1], 404 | ) 405 | return cl 406 | else: 407 | self.candles = self.candles[start_ix: end_ix + 1] 408 | self.times = self.times[start_ix: end_ix + 1] 409 | self._type = self._guess_type() 410 | return self 411 | 412 | def get_lasttime(self, price: float, type: str) -> datetime: 413 | """Function to get the datetime for last time that price has been 414 | above/below a price level 415 | 416 | Arguments: 417 | price: value to calculate the last time in this CandleList the 418 | price was above/below 419 | trade type: either long/short 420 | """ 421 | count = 0 422 | for c in reversed(self.candles): 423 | count += 1 424 | # Last time has to be at least forexparams.min candles before 425 | if count <= clist_params.min: 426 | continue 427 | if type == "long": 428 | if c.h < price: 429 | return c.time 430 | elif type == "short": 431 | if c.l > price: 432 | return c.time 433 | 434 | return self.candles[0].time 435 | 436 | def get_highest(self) -> float: 437 | """Function to calculate the highest 438 | price in this CandleList 439 | 440 | Returns: 441 | highest price 442 | """ 443 | max = 0.0 444 | for cl in self.candles: 445 | price = cl.c 446 | if price > max: 447 | max = price 448 | 449 | return max 450 | 451 | def get_lowest(self) -> float: 452 | """Function to calculate the lowest 453 | price in this CandeList 454 | 455 | Returns: 456 | lowest price 457 | """ 458 | min = None 459 | for cl in self.candles: 460 | price = cl.c 461 | if min is None: 462 | min = price 463 | else: 464 | if price < min: 465 | min = price 466 | return min 467 | 468 | def __repr__(self): 469 | return "CandleList" 470 | 471 | def __str__(self): 472 | sb = [] 473 | for key in self.__slots__: 474 | if hasattr(self, key): 475 | sb.append("{key}='{value}'".format(key=key, 476 | value=getattr(self, key))) 477 | return ", ".join(sb) 478 | -------------------------------------------------------------------------------- /forex/candlelist_utils.py: -------------------------------------------------------------------------------- 1 | from params import tradebot_params, clist_params, pivots_params 2 | from forex.harea import HArea, HAreaList 3 | from utils import add_pips2price, substract_pips2price, calculate_pips 4 | 5 | import logging 6 | import pandas as pd 7 | 8 | # create logger 9 | cl_logger = logging.getLogger(__name__) 10 | cl_logger.setLevel(logging.INFO) 11 | 12 | 13 | def calc_SR(pvLO, outfile: str = None) -> HAreaList: 14 | """Function to calculate S/R lines. 
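A sketch of the intended call (hypothetical names; 'pl' is assumed to be a
PivotList built from a CandleList, and outfile is optional and only used
for plotting):

    halist = calc_SR(pl, outfile="sr_levels.png")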
15 | 16 | Args: 17 | pvlO: PivotList object 18 | used for calculation 19 | outfile : Output filename for .png file 20 | 21 | Returns: 22 | HAreaList object 23 | """ 24 | # now calculate the price range for calculating the S/R , add a 25 | # number of pips to max,min to be sure that we also detect the 26 | # extreme pivots 27 | ul = add_pips2price(pvLO.clist.instrument, pvLO.clist.get_highest(), 28 | tradebot_params.add_pips) 29 | ll = substract_pips2price(pvLO.clist.instrument, pvLO.clist.get_lowest(), 30 | tradebot_params.add_pips) 31 | 32 | cl_logger.debug(f"Running calc_SR for estimated range: {ll}-{ul}") 33 | 34 | prices, bounces, score_per_bounce, tot_score = ([] for i in range(4)) 35 | 36 | # the increment of price in number of pips is double the hr_extension 37 | prev_p = None 38 | 39 | p = float(ll) 40 | 41 | while p <= float(ul): 42 | cl_logger.debug("Processing S/R at {0}".format(round(p, 4))) 43 | # get a PivotList for this particular S/R 44 | newPL = pvLO.inarea_pivots(price=p) 45 | if len(newPL.pivots) == 0: 46 | mean_pivot = 0 47 | else: 48 | mean_pivot = newPL.get_avg_score() 49 | 50 | prices.append(round(p, 5)) 51 | bounces.append(len(newPL.pivots)) 52 | tot_score.append(newPL.get_score()) 53 | score_per_bounce.append(mean_pivot) 54 | # increment price to following price. Because the increment is made 55 | # in pips it does not suffer of the JPY pairs issue 56 | p = add_pips2price(pvLO.clist.instrument, p, 57 | 2*clist_params.i_pips) 58 | if prev_p is None: 59 | prev_p = p 60 | else: 61 | increment_price = round(p - prev_p, 5) 62 | prev_p = p 63 | 64 | data = {'price': prices, 65 | 'bounces': bounces, 66 | 'scores': score_per_bounce, 67 | 'tot_score': tot_score} 68 | 69 | df = pd.DataFrame(data=data) 70 | 71 | # Establishing bounces threshold as the args.th quantile 72 | # selecting only rows with at least one pivot and tot_score>0, 73 | # so threshold selection considers only these rows 74 | # and selection is not biased when range of prices is wide 75 | dfgt1 = df.loc[(df['bounces'] > 0)] 76 | dfgt2 = df.loc[(df['tot_score'] > 0)] 77 | bounce_th = dfgt1.bounces.quantile(tradebot_params.th) 78 | score_th = dfgt2.tot_score.quantile(tradebot_params.th) 79 | 80 | print(f"Selected number of pivot threshold: {round(bounce_th, 3)}") 81 | print(f"Selected tot score threshold: {round(score_th, 1)}") 82 | 83 | # selecting records over threshold 84 | dfsel = df.loc[(df['bounces'] > bounce_th) | (df['tot_score'] > score_th)] 85 | 86 | # repeat until no overlap between prices 87 | ret = calc_diff(dfsel, increment_price) 88 | 89 | dfsel = ret[0] 90 | tog_seen = ret[1] 91 | while tog_seen is True: 92 | ret = calc_diff(dfsel, increment_price) 93 | dfsel = ret[0] 94 | tog_seen = ret[1] 95 | 96 | # iterate over DF with selected SR to create a HAreaList 97 | halist = [] 98 | for _, row in dfsel.iterrows(): 99 | resist = HArea(price=row['price'], 100 | pips=pivots_params.hr_pips, 101 | instrument=pvLO.clist.instrument, 102 | granularity=pvLO.clist.granularity, 103 | no_pivots=row['bounces'], 104 | tot_score=round(row['tot_score'], 5)) 105 | halist.append(resist) 106 | 107 | halistObj = HAreaList(halist=halist) 108 | 109 | # Plot the HAreaList 110 | if outfile: 111 | halistObj.plot(clO=pvLO.clist, 112 | outfile=outfile) 113 | 114 | cl_logger.info("Run done") 115 | 116 | return halistObj 117 | 118 | 119 | def calc_atr(clO) -> float: 120 | '''Function to calculate the ATR (average timeframe rate) 121 | This is the average candle variation in pips for the desired 122 | timeframe. 
The variation is measured as the abs diff 123 | (in pips) between the high and low of the candle 124 | 125 | Arguments: 126 | clO: CandleList object 127 | Used for calculation 128 | ''' 129 | length, tot_diff_in_pips = 0, 0 130 | for c in clO.candles: 131 | diff = abs(c.h-c.l) 132 | tot_diff_in_pips = tot_diff_in_pips + \ 133 | float(calculate_pips(clO.instrument, diff)) 134 | length += 1 135 | return round(tot_diff_in_pips/length, 3) 136 | 137 | 138 | def calc_diff(df_loc, increment_price: float): 139 | '''Function to select the best S/R for areas that 140 | are less than 3*increment_price. 141 | 142 | Arguments: 143 | df_loc: Pandas dataframe with S/R areas 144 | increment_price : This is the increment_price 145 | between different price levels 146 | in order to identify S/Rs 147 | 148 | Returns: 149 | Pandas dataframe with selected S/R 150 | ''' 151 | prev_price = prev_row = prev_ix = None 152 | tog_seen = False 153 | for index, row in df_loc.iterrows(): 154 | if prev_price is None: 155 | prev_price = float(row['price']) 156 | prev_row = row 157 | prev_ix = index 158 | else: 159 | diff = round(float(row['price']) - prev_price, 4) 160 | if diff < clist_params.times * increment_price: 161 | tog_seen = True 162 | if row['bounces'] <= prev_row['bounces'] and \ 163 | row['tot_score'] < prev_row['tot_score']: 164 | # remove current row 165 | df_loc = df_loc.drop(index) 166 | elif row['bounces'] >= prev_row['bounces'] and \ 167 | row['tot_score'] > prev_row['tot_score'] or \ 168 | row['tot_score'] == prev_row['tot_score']: 169 | # remove previous row 170 | df_loc = df_loc.drop(prev_ix) 171 | prev_price = float(row['price']) 172 | prev_row = row 173 | prev_ix = index 174 | elif row['bounces'] <= prev_row['bounces'] and \ 175 | row['tot_score'] > prev_row['tot_score']: 176 | # remove previous row as scores in current takes precedence 177 | df_loc = df_loc.drop(prev_ix) 178 | prev_price = float(row['price']) 179 | prev_row = row 180 | prev_ix = index 181 | elif row['bounces'] >= prev_row['bounces'] and \ 182 | row['tot_score'] < prev_row['tot_score']: 183 | # remove current row as scores in current takes precedence 184 | df_loc = df_loc.drop(index) 185 | elif row['bounces'] == prev_row['bounces'] and \ 186 | row['tot_score'] == prev_row['tot_score']: 187 | # exactly same quality for row and prev_row 188 | # remove current arbitrarily 189 | df_loc = df_loc.drop(index) 190 | else: 191 | prev_price = float(row['price']) 192 | prev_row = row 193 | prev_ix = index 194 | return df_loc, tog_seen 195 | -------------------------------------------------------------------------------- /forex/harea.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | from datetime import timedelta, datetime 5 | from api.oanda.connect import Connect 6 | from params import gparams 7 | from forex.candle import Candle 8 | 9 | import matplotlib.pyplot as plt 10 | import matplotlib.dates as mdates 11 | 12 | h_logger = logging.getLogger(__name__) 13 | h_logger.setLevel(logging.INFO) 14 | 15 | 16 | class HArea(object): 17 | '''Class to represent a horizontal area in the chart. 18 | 19 | Class variables: 20 | price: Price in the chart used as the middle point that will be 21 | extended on both sides a certain number of pips 22 | instrument: Instrument for this CandleList 23 | granularity: Granularity for this CandleList (i.e. D, H12, H8 etc...) 
24 | pips: Number of pips above/below self.price to calculate self.upper 25 | and self.lower 26 | upper: Upper limit price of area 27 | lower: Lower limit price of area 28 | no_pivots: Number of pivots bouncing on self 29 | tot_score: Total score, which is the sum of scores of all pivots on 30 | this HArea 31 | ''' 32 | __slots__ = ["price", "instrument", "granularity", "pips", 33 | "no_pivots", "tot_score", 'upper', 'lower'] 34 | 35 | def __init__(self, price: float, instrument: str, 36 | granularity: str, pips: int, no_pivots: int = None, 37 | tot_score: int = None): 38 | 39 | try: 40 | (first, second) = instrument.split("_") 41 | except ValueError: 42 | logging.exception(f"Incorrect '_' split for instrument:{instrument}") 43 | sys.exit(1) 44 | self.instrument = instrument 45 | round_number = None 46 | divisor = None 47 | if first == "JPY" or second == "JPY": 48 | round_number = 2 49 | divisor = 100 50 | else: 51 | round_number = 4 52 | divisor = 10000 53 | price = round(price, round_number) 54 | self.price = price 55 | self.pips = pips 56 | self.granularity = granularity 57 | self.no_pivots = no_pivots 58 | self.tot_score = tot_score 59 | self.upper = round(price+(pips/divisor), 4) 60 | self.lower = round(price-(pips/divisor), 4) 61 | 62 | def get_cross_time(self, candle: Candle, granularity='M30') -> datetime: 63 | '''This function is used get the time that the candle 64 | crosses (go through) HArea 65 | 66 | Arguments: 67 | candle: Candle crossing the HArea 68 | granularity: To what granularity we should descend 69 | 70 | Returns: 71 | crossing time. 72 | n.a. if crossing time could not retrieved. This can happens 73 | when there is an artifactual jump in Oanda's data 74 | ''' 75 | if candle.l <= self.price <= candle.h: 76 | delta = None 77 | if self.granularity == "D": 78 | delta = timedelta(hours=24) 79 | else: 80 | fgran = self.granularity.replace('H', '') 81 | delta = timedelta(hours=int(fgran)) 82 | 83 | cstart = candle.time 84 | cend = cstart+delta 85 | conn = Connect(instrument=self.instrument, 86 | granularity=granularity) 87 | 88 | h_logger.debug("Fetching data from API") 89 | res = conn.query(start=cstart.isoformat(), 90 | end=cend.isoformat()) 91 | 92 | seen = False 93 | if res.candles: 94 | for c in res: 95 | if c.l <= self.price <= c.h: 96 | seen = True 97 | return c.time 98 | if seen is False: 99 | return candle.time 100 | else: 101 | return 'n.a.' 102 | 103 | def __repr__(self): 104 | return "HArea" 105 | 106 | def __str__(self): 107 | out_str = "" 108 | for attr in self.__slots__: 109 | out_str += f"{attr}:{getattr(self, attr)} " 110 | return out_str 111 | 112 | 113 | def is_ordered_ascending(lst): 114 | return all(lst[i] <= lst[i + 1] for i in range(len(lst) - 1)) 115 | 116 | 117 | class HAreaList(object): 118 | """Class that represents a list of HArea objects. 119 | 120 | Class variables: 121 | halist : List of HArea objects 122 | """ 123 | __slots__ = ["halist"] 124 | 125 | def __init__(self, halist): 126 | self.halist = halist 127 | 128 | @classmethod 129 | def from_list(cls, prices: list[float], 130 | instrument: str, 131 | granularity: str, 132 | pips: int = 30 133 | ): 134 | """Creates object from list with the SR prices. 135 | 136 | Args: 137 | prices: List with Floats. 
Prices should be sorted (ascending) 138 | """ 139 | if not is_ordered_ascending(prices): 140 | raise ValueError("Prices in {ifile} should be sorted (ascending)") 141 | hlist = [] 142 | for price in prices: 143 | harea = HArea( 144 | price=price, 145 | pips=pips, 146 | instrument=instrument, 147 | granularity=granularity 148 | ) 149 | hlist.append(harea) 150 | halist_object = HAreaList(halist=hlist) 151 | return halist_object 152 | 153 | @classmethod 154 | def from_csv(cls, ifile: str, 155 | instrument: str, 156 | granularity: str, 157 | pips: int = 30): 158 | """Creates object from file with the SR prices. 159 | 160 | Args: 161 | ifile: Floats in the file should be sorted (ascending) 162 | """ 163 | with open(ifile, "r") as f: 164 | prices = [float(line.strip()) for line in f.readlines()] 165 | if not is_ordered_ascending(prices): 166 | raise ValueError("Prices in {ifile} should be sorted (ascending)") 167 | halist_object = cls.from_list( 168 | prices=prices, instrument=instrument, granularity=granularity, 169 | pips=pips 170 | ) 171 | return halist_object 172 | 173 | def onArea(self, candle: Candle): 174 | '''Function that will check which (if any) of the HArea objects 175 | in this HAreaList will overlap with 'candle'. 176 | 177 | See comments in code to understand what is considered 178 | an overlap 179 | 180 | Arguments: 181 | candle: Candle that will be checked 182 | 183 | Returns: 184 | An HArea object overlapping with 'candle' and the ix 185 | in self.halist for the HArea being crossed. This ix is expressed 186 | from the HArea with the lowest price to the highest price and 187 | starting from 0. 188 | So if 'sel_ix'=2, then it will be the third HArea 189 | None if there are no HArea objects overlapping''' 190 | onArea_hr = sel_ix = None 191 | ix = 0 192 | seen = False 193 | for harea in self.halist: 194 | if harea.price <= float(candle.h) and \ 195 | harea.price >= float(candle.l): 196 | if seen: 197 | logging.warn("More than one HArea crosses this candle") 198 | onArea_hr = harea 199 | sel_ix = ix 200 | seen = True 201 | ix += 1 202 | return onArea_hr, sel_ix 203 | 204 | def print(self) -> str: 205 | '''Function to print out basic information on each of the 206 | HArea objects in the HAreaList 207 | 208 | Returns: 209 | String with stringified HArea objects 210 | ''' 211 | res = "#pair timeframe upper-price-lower no_pivots tot_score\n" 212 | for harea in self.halist: 213 | res += "{0} {1} {2}-{3}-{4} {5} {6}\n".format(harea.instrument, 214 | harea.granularity, 215 | harea.upper, 216 | harea.price, 217 | harea.lower, 218 | harea.no_pivots, 219 | harea.tot_score) 220 | return res.rstrip("\n") 221 | 222 | def plot(self, clO, outfile: str) -> None: 223 | """Plot this HAreaList 224 | 225 | Args: 226 | clO : CandeList object 227 | Used for plotting 228 | outfile : Output file 229 | """ 230 | prices, datetimes = ([] for i in range(2)) 231 | for c in clO.candles: 232 | prices.append(c.c) 233 | datetimes.append(c.time) 234 | 235 | # massage datetimes so they can be plotted in X-axis 236 | x = [mdates.date2num(i) for i in datetimes] 237 | 238 | # plotting the prices for part 239 | fig = plt.figure(figsize=gparams.size) 240 | ax = plt.axes() 241 | ax.plot(datetimes, prices, color="black") 242 | 243 | prices = [x.price for x in self.halist] 244 | 245 | # now, print an horizontal line for each S/R 246 | ax.hlines(prices, datetimes[0], datetimes[-1], color="green") 247 | 248 | fig.savefig(outfile, format='png') 249 | 250 | def __repr__(self): 251 | return "HAreaList" 252 | 253 | def 
__str__(self):
254 | out_str = ""
255 | for attr in self.__slots__:
256 | out_str += f"{attr}:{getattr(self, attr)} "
257 | return out_str
258 |
--------------------------------------------------------------------------------
/forex/pivot.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import datetime
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 |
6 | from utils import periodToDelta, substract_pips2price, add_pips2price
7 | from params import gparams, pivots_params
8 | from forex.segment import SegmentList, Segment
9 | from zigzag import peak_valley_pivots, pivots_to_modes
10 | from statistics import mean
11 |
12 | # create logger
13 | p_logger = logging.getLogger(__name__)
14 | p_logger.setLevel(logging.INFO)
15 |
16 |
17 | class Pivot(object):
18 | """
19 | Class representing a single Pivot
20 |
21 | Class variables:
22 | type : Type of pivot. It can be 1 or -1
23 | candle : Candle object overlapping this pivot
24 | pre : Segment object before this pivot. When type=-1, pre.type will
25 | be 1. When type=1, pre.type will be -1
26 | aft : Segment object after this pivot. When type=-1, aft.type will
27 | be -1. When type=1, aft.type will be 1
28 | score : Sum of the scores ('diff' values or number of candles)
29 | of the 'pre' and 'aft' segments (if defined)
30 | """
31 | __slots__ = ["type", "candle", "pre", "aft", "score"]
32 |
33 | def __init__(self, type: int, candle, pre, aft, score: int = None):
34 | self.type = type
35 | self.candle = candle
36 | self.pre = pre
37 | self.aft = aft
38 | self.score = score
39 |
40 | def merge_pre(self, slist, n_candles: int, diff_th: int) -> None:
41 | """Function to merge the 'pre' Segment. It will merge self.pre with the
42 | previous segment if self.pre and the previous segment are of the same
43 | type (1 or -1), or if the candle count of the previous segment is less
44 | than pivots_params.n_candles
45 |
46 | Arguments:
47 | slist : SegmentList object
48 | SegmentList for the PivotList of this Pivot.
49 | n_candles : Skip merge if Segment is greater than 'n_candles' 50 | diff_th : % of diff in pips threshold 51 | """ 52 | p_logger.debug("Running merge_pre") 53 | p_logger.debug(f"Analysis of pivot {self.candle.time}") 54 | p_logger.debug(f"self.pre start pre-merge: {self.pre.start()}") 55 | 56 | extension_needed = True 57 | while extension_needed is True: 58 | # reduce start of self.pre by one candle in order to retrieve 59 | # the previous segment by its end 60 | start_dt = self.pre.start() \ 61 | - periodToDelta(1, self.candle.granularity) 62 | 63 | s = slist.fetch_by_end(start_dt) 64 | if s is None: 65 | # This is not necessarily an error, it could be that there is 66 | # not the required Segment in slist because it is out of the 67 | # time period 68 | p_logger.debug("No Segment could be retrieved for pivot " 69 | f"falling in time {self.candle.time} " 70 | f"by using s.fetch_by_end and date: {start_dt} " 71 | "in function 'merge_pre'") 72 | extension_needed = False 73 | continue 74 | if self.pre.type == s.type: 75 | # merge if type of previous (s) is equal to self.pre 76 | p_logger.debug("Merge because of same Segment type") 77 | self.pre.prepend(s) 78 | elif self.pre.type != s.type and len(s.clist) < n_candles: 79 | # merge if types of previous segment and self.pre are 80 | # different but len(s.clist) is less than n_candles 81 | # calculate the % that s.diff is with respect to self.pre.diff 82 | perc_diff = s.diff*100/self.pre.diff 83 | # do not merge if perc_diff that s represents with respect 84 | # to s.pre is > than the defined threshold 85 | if perc_diff < diff_th: 86 | p_logger.debug("Merge because of len(s.clist) < n_candles") 87 | self.pre.prepend(s) 88 | else: 89 | p_logger.debug("Skipping merge because of %_diff") 90 | extension_needed = False 91 | else: 92 | # exit the while loop, as type of previous (s) and self.pre 93 | # are different and s.count is greater than n_candles 94 | extension_needed = False 95 | 96 | p_logger.debug(f"self.pre start after-merge: {self.pre.start()}") 97 | p_logger.debug("Done merge_pre") 98 | 99 | def merge_aft(self, slist, n_candles: int, diff_th: int) -> None: 100 | """Function to merge 'aft' Segment. It will merge self.aft with 101 | next segment if self.aft and next segment are of the same type 102 | (1 or -1) or count of next segment is less than 'n_candles' 103 | 104 | Arguments: 105 | slist : SegmentList object 106 | SegmentList for PivotList of this Pivot. 
107 | n_candles : Skip merge if Segment is greater than 'n_candles' 108 | diff_th : % of diff in pips threshold 109 | """ 110 | p_logger.debug("Running merge_aft") 111 | p_logger.debug(f"Analysis of pivot {self.candle.time}") 112 | p_logger.debug(f"self.aft end before the merge: {self.aft.end()}") 113 | 114 | extension_needed = True 115 | while extension_needed is True: 116 | # increase end of self.aft by one candle 117 | start_dt = self.aft.end()+periodToDelta(1, 118 | self.candle.granularity) 119 | 120 | # fetch next segment 121 | s = slist.fetch_by_start(start_dt) 122 | if s is None: 123 | # This is not necessarily an error, it could be that there is 124 | # not the required Segment in slist 125 | # because it is out of the time period 126 | p_logger.debug("No Segment could be retrieved for pivot" 127 | f"falling in time {self.candle.time} by " 128 | f"using s.fetch_by_start and date: {start_dt} " 129 | "in function 'merge_aft'") 130 | extension_needed = False 131 | continue 132 | if self.aft.type == s.type: 133 | p_logger.debug("Merge because of same Segment type") 134 | self.aft.append(s) 135 | elif self.aft.type != s.type and len(s.clist) < n_candles: 136 | # calculate the % that s.diff is with respect to self.pre.diff 137 | perc_diff = s.diff * 100 / self.aft.diff 138 | # do not merge if perc_diff that s represents with respect 139 | # to s.aft is > than the defined threshold 140 | if perc_diff < diff_th: 141 | p_logger.debug("Merge because of len(s.clist) < n_candles") 142 | self.aft.append(s) 143 | else: 144 | p_logger.debug("Skipping merge because of %_diff") 145 | extension_needed = False 146 | else: 147 | extension_needed = False 148 | 149 | p_logger.debug("self.aft end after-merge: {0}".format(self.aft.end())) 150 | p_logger.debug("Done merge_aft") 151 | 152 | def calc_score(self, type='diff') -> float: 153 | """ 154 | Function to calculate the score for this Pivot 155 | The score will be the result of adding the 'diff' 156 | values or adding the number of candles of the 'pre' and 'aft' 157 | segments (if defined) 158 | 159 | Arguments: 160 | type : Type of score that will be 161 | calculated. Possible values: 'diff' , 'candles' 162 | """ 163 | if self.pre: 164 | score_pre = 0 165 | if type == 'diff': 166 | score_pre = self.pre.diff 167 | elif type == 'candles': 168 | score_pre = len(self.pre.clist) 169 | else: 170 | score_pre = 0.0 171 | 172 | if self.aft: 173 | score_aft = 0 174 | if type == 'diff': 175 | score_aft = self.aft.diff 176 | elif type == 'candles': 177 | score_aft = len(self.aft.clist) 178 | else: 179 | score_aft = 0.0 180 | 181 | return round(score_pre+score_aft, 2) 182 | 183 | def adjust_pivottime(self, clistO): 184 | """Function to adjust the pivot time 185 | This is necessary as sometimes the Zigzag algorithm 186 | does not find the correct pivot. 
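Illustrative use (hypothetical objects; 'clO' is assumed to be the
CandleList the PivotList was built from):

    adjusted_time = pivot.adjust_pivottime(clistO=clO)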
187 | 188 | Arguments: 189 | clistO : CandleList object used to identify the 190 | PivotList 191 | Returns: 192 | New adjusted datetime 193 | """ 194 | # reduce index by 1 so start candle+1 195 | # is not included 196 | clist = clistO.candles 197 | new_pc, pre_colour = None, None 198 | it = True 199 | ix = -1 200 | while it is True: 201 | cObj = clist[ix] 202 | if cObj.colour == "undefined": 203 | it = False 204 | new_pc = cObj 205 | continue 206 | if pre_colour is None: 207 | if cObj.colour == 'green' and self.type == -1: 208 | new_pc = cObj 209 | it = False 210 | elif cObj.colour == 'red' and self.type == 1: 211 | new_pc = cObj 212 | it = False 213 | pre_colour = cObj.colour 214 | ix -= 1 215 | elif self.type == -1: 216 | if cObj.colour == 'red' and cObj.colour == pre_colour: 217 | ix -= 1 218 | continue 219 | else: 220 | new_pc = cObj 221 | it = False 222 | elif self.type == 1: 223 | if cObj.colour == 'green' and cObj.colour == pre_colour: 224 | ix -= 1 225 | continue 226 | else: 227 | new_pc = cObj 228 | it = False 229 | return new_pc.time 230 | 231 | def __repr__(self): 232 | return "Pivot" 233 | 234 | def __str__(self): 235 | sb = [] 236 | for key in self.__slots__: 237 | if hasattr(self, key): 238 | sb.append("{key}='{value}'".format(key=key, 239 | value=getattr(self, key))) 240 | return ', '.join(sb) 241 | 242 | 243 | class PivotList(object): 244 | """Class that represents a list of Pivots as identified 245 | by the Zigzag indicator. 246 | 247 | Class variables: 248 | clist: CandleList object 249 | pivots: List with Pivot objects 250 | slist: SegmentList object""" 251 | 252 | __slots__ = ["clist", "pivots", "slist", 253 | "th_bounces"] 254 | 255 | def __init__(self, clist, pivots=None, slist=None, 256 | th_bounces: float = None) -> None: 257 | self.clist = clist 258 | if pivots is not None: 259 | assert slist is not None, "Error!. SegmentList needs to " 260 | "be provided" 261 | self.slist = slist 262 | self.pivots = pivots 263 | else: 264 | if th_bounces: 265 | po_l, segs = self._get_pivotlist(th_bounces) 266 | else: 267 | po_l, segs = self._get_pivotlist(pivots_params.th_bounces) 268 | self.pivots = po_l 269 | self.slist = SegmentList(slist=segs, 270 | instrument=clist.instrument) 271 | 272 | def __iter__(self): 273 | self.pos = 0 274 | return self 275 | 276 | def __next__(self): 277 | if (self.pos < len(self.pivots)): 278 | self.pos += 1 279 | return self.pivots[self.pos - 1] 280 | else: 281 | raise StopIteration 282 | 283 | def __getitem__(self, key): 284 | return self.pivots[key] 285 | 286 | def __len__(self): 287 | return len(self.pivots) 288 | 289 | def _get_pivotlist(self, th_bounces: float): 290 | """Function to obtain a pivotlist object containing pivots identified 291 | using the Zigzag indicator. 292 | 293 | Arguments: 294 | th_bounces: Value used by ZigZag to identify pivots. 
The lower the 295 | value the higher the sensitivity 296 | 297 | Returns: 298 | List with Pivot objects 299 | List with Segment objects 300 | """ 301 | yarr = np.array([cl.c for cl in self.clist.candles]) 302 | pivots = peak_valley_pivots(yarr, th_bounces, 303 | th_bounces*-1) 304 | modes = pivots_to_modes(pivots) 305 | 306 | segs = [] # this list will hold the Segment objects 307 | plist_o = [] # this list will hold the Pivot objects 308 | pre_s = None # Variable that will hold pre Segments 309 | ixs = list(np.where(np.logical_or(pivots == 1, pivots == -1))[0]) 310 | tuples_lst = [(ixs[i], ixs[i+1]) for i in range(len(ixs)-1)] 311 | for pair in tuples_lst: 312 | if pivots[pair[0]+1] == 0: 313 | submode = modes[pair[0]+1:pair[1]] 314 | else: 315 | submode = [modes[pair[0]+1]] 316 | # checking if all elements in submode are the same: 317 | assert len(np.unique(submode).tolist()) == 1, "more than one type in modes" 318 | s = Segment( 319 | type=submode[0], 320 | clist=self.clist.candles[pair[0]: pair[1]], 321 | instrument=self.clist.instrument, 322 | ) 323 | # create Pivot object 324 | cl = self.clist.candles[pair[0]] 325 | # add granularity to object 326 | cl.granularity = self.clist.granularity 327 | pobj = Pivot(type=submode[0], 328 | candle=cl, 329 | pre=pre_s, 330 | aft=s) 331 | pobj.score = pobj.calc_score() 332 | # Append it to list 333 | plist_o.append(pobj) 334 | # Append it to segs 335 | segs.append(s) 336 | pre_s = s 337 | 338 | # add last Pivot 339 | cl = self.clist.candles[ixs[-1]] 340 | cl.granularity = self.clist.granularity 341 | l_pivot = Pivot(type=modes[ixs[-1]], 342 | candle=cl, 343 | pre=pre_s, 344 | aft=None) 345 | l_pivot.score = l_pivot.calc_score() 346 | plist_o.append(l_pivot) 347 | return plist_o, segs 348 | 349 | def fetch_by_time(self, d: datetime) -> Pivot: 350 | '''Function to fetch a Pivot object using a datetime''' 351 | p = next((p for p in self.pivots if p.candle.time == d), None) 352 | return p 353 | 354 | def fetch_by_type(self, type: int) -> 'PivotList': 355 | '''Function to get all pivots from a certain type. 356 | 357 | Arguments: 358 | type : 1 or -1 359 | ''' 360 | 361 | pl = [p for p in self.pivots if p.type == type] 362 | 363 | return PivotList(pivots=pl, 364 | clist=self.clist, 365 | slist=self.slist) 366 | 367 | def print_pivots_dates(self) -> list: 368 | '''Function to generate a list with the datetimes of the different 369 | Pivots in PivotList''' 370 | 371 | datelist = [] 372 | for p in self.pivots: 373 | datelist.append(p.candle.time) 374 | 375 | return datelist 376 | 377 | def get_score(self) -> float: 378 | '''Function to calculate the score after adding the score 379 | for each individual pivot''' 380 | 381 | tot_score = sum(p.score for p in self.pivots) 382 | return round(tot_score, 1) 383 | 384 | def get_avg_score(self) -> float: 385 | '''Function to calculate the avg score 386 | for all pivots in this PivotList. 387 | This calculation is done by dividing the 388 | total score by the number of pivots 389 | ''' 390 | avg = mean(p.score for p in self.pivots) 391 | return round(avg, 1) 392 | 393 | def inarea_pivots(self, price: float, last_pivot: bool = True): 394 | ''' 395 | Function to identify the candles for which price is in the area defined 396 | by SR+HRpips and SR-HRpips 397 | 398 | Arguments: 399 | pivots: PivotList object 400 | SR: price of the S/R area 401 | last_pivot: If True, then the last pivot will be considered as 402 | it is part of the setup. 
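Illustrative call (hypothetical price level and 'plist' PivotList; the
width of the area is taken from pivots_params.hr_pips):

    in_area_pl = plist.inarea_pivots(price=0.7000)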
403 | 404 | Returns: 405 | PivotList with the pivots that are in the area centered at 'price' 406 | ''' 407 | # get bounces in the horizontal SR area 408 | lower = substract_pips2price(self.clist.instrument, 409 | price, 410 | pivots_params.hr_pips) 411 | upper = add_pips2price(self.clist.instrument, 412 | price, 413 | pivots_params.hr_pips) 414 | 415 | p_logger.debug("SR U-limit: {0}; L-limit: {1}".format(round(upper, 4), 416 | round(lower, 4))) 417 | 418 | pl = [] 419 | for p in self.pivots: 420 | # always consider the last pivot in bounces.plist as in_area as 421 | # this part of the entry setup 422 | if self.pivots[-1].candle.time == p.candle.time \ 423 | and last_pivot is True: 424 | adj_t = p.adjust_pivottime(clistO=self.clist) 425 | newclist = self.clist.slice(start=self.clist.candles[0].time, 426 | end=adj_t) 427 | newpl = PivotList(clist=newclist) 428 | newp = newpl._get_pivotlist(pivots_params.th_bounces)[0][-1] 429 | if pivots_params.runmerge_pre is True and newp.pre is not None: 430 | newp.merge_pre(slist=self.slist, 431 | n_candles=pivots_params.n_candles, 432 | diff_th=pivots_params.diff_th) 433 | if pivots_params.runmerge_aft is True and newp.aft is not None: 434 | newp.merge_aft(slist=self.slist, 435 | n_candles=pivots_params.n_candles, 436 | diff_th=pivots_params.diff_th) 437 | pl.append(newp) 438 | else: 439 | part_list = ['c'] 440 | if p.type == 1: 441 | part_list.append('h') 442 | elif p.type == -1: 443 | part_list.append('l') 444 | 445 | for part in part_list: 446 | price = getattr(p.candle, part) 447 | # only consider pivots in the area 448 | if price >= lower and price <= upper: 449 | # check if this pivot already exists in pl 450 | p_seen = False 451 | for op in pl: 452 | if op.candle.time == p.candle.time: 453 | p_seen = True 454 | if p_seen is False: 455 | p_logger.debug(f"Pivot {p.candle.time} identified in area") 456 | if pivots_params.runmerge_pre is True and \ 457 | p.pre is not None: 458 | p.merge_pre(slist=self.slist, 459 | n_candles=pivots_params.n_candles, 460 | diff_th=pivots_params.diff_th) 461 | if pivots_params.runmerge_aft is True and \ 462 | p.aft is not None: 463 | p.merge_aft(slist=self.slist, 464 | n_candles=pivots_params.n_candles, 465 | diff_th=pivots_params.diff_th) 466 | pl.append(p) 467 | 468 | return PivotList(clist=self.clist, 469 | pivots=pl, 470 | slist=self.slist) 471 | 472 | def calc_itrend(self): 473 | '''Function to calculate the datetime for the start of this CandleList, 474 | assuming that this CandleList is trending. This function will calculate 475 | the start of the trend by using the self.get_pivots function 476 | 477 | Returns: 478 | Merged Segment object containing the trend_i 479 | ''' 480 | p_logger.debug("Running calc_itrend") 481 | for p in reversed(self.pivots): 482 | adj_t = p.adjust_pivottime(clistO=self.clist) 483 | start = self.pivots[0].candle.time 484 | newclist = self.clist.slice(start=start, 485 | end=adj_t) 486 | newp = PivotList(clist=newclist).pivots[-1] 487 | newp.merge_pre(slist=self.slist, 488 | n_candles=pivots_params.n_candles, 489 | diff_th=pivots_params.diff_th) 490 | return newp.pre 491 | 492 | p_logger.debug("Done clac_itrend") 493 | 494 | def plot_pivots(self, outfile_prices: str, outfile_rsi: str): 495 | '''Function to plot all pivots that are in the area 496 | 497 | Arguments: 498 | outfile_prices : Output file for prices plot. 499 | outfile_rsi : Output file for rsi plot. 
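Illustrative call (hypothetical output paths; it assumes calc_rsi() has
been run on the underlying CandleList so every candle has an rsi value):

    plist.plot_pivots(outfile_prices="pivots.png", outfile_rsi="rsi.png")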
500 |         '''
501 |         p_logger.debug("Running plot_pivots")
502 | 
503 |         prices, rsi, datetimes = ([] for i in range(3))
504 |         for c in self.clist.candles:
505 |             prices.append(c.c)
506 |             rsi.append(c.rsi)
507 |             datetimes.append(c.time)
508 | 
509 |         # plotting the rsi values
510 |         fig_rsi = plt.figure(figsize=gparams.size)
511 |         ax_rsi = plt.axes()
512 |         ax_rsi.plot(datetimes, rsi, color="black")
513 |         fig_rsi.savefig(outfile_rsi, format='png')
514 | 
515 |         # plotting the prices for part
516 |         fig = plt.figure(figsize=gparams.size)
517 |         ax = plt.axes()
518 |         ax.plot(datetimes, prices, color="black")
519 | 
520 |         for p in self.pivots:
521 |             dt = p.candle.time
522 |             ix = datetimes.index(dt)
523 |             # prepare the plot for 'pre' segment
524 |             if p.pre is not None:
525 |                 ix_pre_s = datetimes.index(p.pre.start())
526 |                 plt.scatter(datetimes[ix_pre_s], prices[ix_pre_s],
527 |                             s=200, c='green', marker='v')
528 |             # prepare the plot for 'aft' segment
529 |             if p.aft is not None:
530 |                 ix_aft_e = datetimes.index(p.aft.end())
531 |                 plt.scatter(datetimes[ix_aft_e], prices[ix_aft_e],
532 |                             s=200, c='red', marker='v')
533 |             # plot
534 |             plt.scatter(datetimes[ix], prices[ix], s=50)
535 | 
536 |         fig.savefig(outfile_prices, format='png')
537 | 
538 |         p_logger.debug("plot_pivots Done")
539 | 
540 |     def pivots_report(self, outfile: str) -> str:
541 |         """Function to generate a report of the pivots in the PivotList
542 | 
543 |         Arguments:
544 |             outfile : Path to file with report
545 | 
546 |         Returns:
547 |             Path to the file with the PivotList report containing the Pivot information.
548 |             This file will have the following format:
549 |             #pre.start|p.candle.time|p.aft.end
550 |         """
551 |         f = open(outfile, 'w')
552 |         f.write("#pre.start|p.candle.time|p.aft.end\n")
553 |         for p in self.pivots:
554 |             if p.pre is None and p.aft is not None:
555 |                 f.write("{0}|{1}|{2}\n".format("n.a.",
556 |                                                p.candle.time, p.aft.end()))
557 |             elif p.pre is not None and p.aft is not None:
558 |                 f.write("{0}|{1}|{2}\n".format(p.pre.start(),
559 |                                                p.candle.time, p.aft.end()))
560 |             elif p.pre is not None and p.aft is None:
561 |                 f.write("{0}|{1}|{2}\n".format(p.pre.start(),
562 |                                                p.candle.time, "n.a."))
563 |         f.close()
564 | 
565 |         return outfile
566 | 
567 |     def __str__(self):
568 |         sb = []
569 |         for key in self.__slots__:
570 |             if hasattr(self, key):
571 |                 sb.append("{key}='{value}'".format(key=key,
572 |                                                    value=getattr(self, key)))
573 |         return ', '.join(sb)
574 | 
575 |     def __repr__(self):
576 |         return self.__str__()
577 | 
--------------------------------------------------------------------------------
/forex/segment.py:
--------------------------------------------------------------------------------
1 | from utils import calculate_pips
2 | import matplotlib
3 | import datetime
4 | import pickle
5 | 
6 | from typing import List, Dict
7 | 
8 | matplotlib.use('PS')
9 | 
10 | 
11 | class Segment(object):
12 |     """Class containing a Segment object identified by linking the pivots in the
13 |     PivotList
14 | 
15 |     Class variables
16 |     ---------------
17 |     type : 1 or -1.
1 when the segment goes upwards and -1 downwards 18 | clist : List of dictionaries, in which each of the dicts is a candle 19 | instrument : Pair 20 | """ 21 | 22 | __slots__ = ['type', 'clist', 'instrument', '_diff'] 23 | 24 | def __init__(self, type: int, clist: List[Dict], instrument: str): 25 | self.type = type 26 | self.clist = clist 27 | self.instrument = instrument 28 | self._diff = self._calc_diff() 29 | 30 | @property 31 | def diff(self): 32 | return self._diff 33 | 34 | def pickle_dump(self, outfile: str) -> str: 35 | '''Function to pickle this particular Segment 36 | 37 | Arguments: 38 | outfile: Path used to pickle 39 | 40 | Returns: 41 | path to file with pickled data 42 | ''' 43 | 44 | pickle_out = open(outfile, "wb") 45 | pickle.dump(self, pickle_out) 46 | pickle_out.close() 47 | 48 | return outfile 49 | 50 | @classmethod 51 | def pickle_load(self, infile: str): 52 | '''Function to pickle this particular Segment 53 | 54 | Arguments: 55 | infile: Path used to load in pickled data 56 | 57 | Returns: 58 | inseg: Segment object 59 | ''' 60 | pickle_in = open(infile, "rb") 61 | inseg = pickle.load(pickle_in) 62 | pickle_in.close() 63 | 64 | return inseg 65 | 66 | def prepend(self, s) -> None: 67 | '''Function to prepend s to self. The merge will be done by 68 | concatenating s.clist to self.clist and increasing self.count to 69 | self.count+s.count 70 | 71 | Arguments: 72 | s : Segment object to be merged 73 | ''' 74 | 75 | self.clist = s.clist+self.clist 76 | self._diff = self._calc_diff() 77 | 78 | def append(self, s) -> None: 79 | '''Function to append s to self. The merge will be done by 80 | concatenating self.clist to self.clist and increasing self.count to 81 | self.count+s.count 82 | 83 | Arguments: 84 | s : Segment object to be merged 85 | ''' 86 | self.clist = self.clist+s.clist 87 | self._diff = self._calc_diff() 88 | 89 | def _calc_diff(self) -> float: 90 | '''Private function to calculate the absolute difference in 91 | number of pips between the first and the last candles 92 | of this segment. 
The candle part considered is 93 | controlled by gparams.part 94 | ''' 95 | diff = abs(self.clist[-1].c - self.clist[0].c) 96 | diff_pips = float(calculate_pips(self.instrument, diff)) 97 | if diff_pips == 0: 98 | diff_pips = 1.0 99 | return diff_pips 100 | 101 | def is_short(self, min_n_candles: int, diff_in_pips: int) -> bool: 102 | '''Function to check if segment is short (self.diff < pip_th or 103 | self.count < candle_th) 104 | 105 | Arguments: 106 | min_n_candles: Minimum number of candles for this segment to be 107 | considered short 108 | diff_in_pips: Minimum number of pips for this segment to be 109 | considered short 110 | 111 | Returns: 112 | True if is short 113 | ''' 114 | if self.count < min_n_candles and self.diff < diff_in_pips: 115 | return True 116 | else: 117 | return False 118 | 119 | def start(self) -> datetime: 120 | '''Function that returns the start of this Segment''' 121 | return self.clist[0].time 122 | 123 | def end(self) -> datetime: 124 | '''Function that returns the end of this Segment''' 125 | return self.clist[-1].time 126 | 127 | def get_lowest(self): 128 | '''Function to get the candle with the lowest price in self.clist 129 | 130 | Returns: 131 | Candle object 132 | ''' 133 | sel_c = price = None 134 | for c in self.clist: 135 | if price is None: 136 | price = c.l 137 | sel_c = c 138 | elif c.l < price: 139 | price = c.l 140 | sel_c = c 141 | return sel_c 142 | 143 | def get_highest(self): 144 | '''Function to get the candle with the highest price in self.clist 145 | 146 | Returns: 147 | Candle object 148 | ''' 149 | price = sel_c = None 150 | for c in self.clist: 151 | if price is None: 152 | price = c.h 153 | sel_c = c 154 | elif c.h > price: 155 | price = c.h 156 | sel_c = c 157 | return sel_c 158 | 159 | def __repr__(self): 160 | return "Segment" 161 | 162 | def __str__(self): 163 | sb = [] 164 | for key in self.__slots__: 165 | if hasattr(self, key): 166 | sb.append("{key}='{value}'".format(key=key, 167 | value=getattr(self, key))) 168 | return ', '.join(sb) 169 | 170 | 171 | class SegmentList(object): 172 | '''Class that represents a list of segments 173 | 174 | Class variables 175 | --------------- 176 | slist : List of Segment objects 177 | instrument : Pair 178 | diff : Diff in pips between first candle in first Segment 179 | and last candle in the last Segment 180 | ''' 181 | __slots__ = ['slist', 'instrument', '_diff'] 182 | 183 | def __init__(self, slist: list, instrument: str): 184 | self.slist = slist 185 | self.instrument = instrument 186 | self._diff = self.calc_diff() 187 | 188 | @property 189 | def diff(self): 190 | return self._diff 191 | 192 | def __iter__(self): 193 | self.pos = 0 194 | return self 195 | 196 | def __next__(self): 197 | if (self.pos < len(self.slist)): 198 | self.pos += 1 199 | return self.slist[self.pos - 1] 200 | else: 201 | raise StopIteration 202 | 203 | def __getitem__(self, key): 204 | return self.slist[key] 205 | 206 | def __len__(self): 207 | return len(self.slist) 208 | 209 | def pickle_dump(self, outfile: str) -> str: 210 | '''Function to pickle this particular SegmentList 211 | 212 | Arguments: 213 | outfile: Path used to pickle 214 | 215 | Returns: 216 | path to file with pickled data 217 | ''' 218 | 219 | pickle_out = open(outfile, "wb") 220 | pickle.dump(self, pickle_out) 221 | pickle_out.close() 222 | 223 | return outfile 224 | 225 | @classmethod 226 | def pickle_load(self, infile: str): 227 | '''Function to pickle this particular SegmentList 228 | 229 | Arguments: 230 | infile: Path used to load in 
pickled data
231 | 
232 |         Returns:
233 |             inseglst: SegmentList object
234 |         '''
235 |         pickle_in = open(infile, "rb")
236 |         inseglst = pickle.load(pickle_in)
237 |         pickle_in.close()
238 | 
239 |         return inseglst
240 | 
241 |     def calc_diff(self) -> float:
242 |         '''Function to calculate the difference in terms
243 |         of number of pips between the 1st candle in
244 |         the 1st segment and the last candle in the
245 |         last segment
246 | 
247 |         Returns:
248 |             float representing the diff in pips. It will be positive
249 |             when it is a downtrend and negative otherwise
250 |         '''
251 |         diff = self.slist[0].clist[0].c - self.slist[-1].clist[-1].c
252 |         diff_pips = float(calculate_pips(self.instrument, diff))
253 | 
254 |         if diff_pips == 0:
255 |             diff_pips += 1.0
256 | 
257 |         self._diff = diff_pips
              return diff_pips
258 | 
259 |     def length_cl(self) -> int:
260 |         '''Get length in terms of number of candles representing the sum
261 |         of candles in each Segment of the SegmentList'''
262 | 
263 |         length = 0
264 |         for s in self.slist:
265 |             length = length+len(s.clist)
266 |         return length
267 | 
268 |     def start(self) -> datetime:
269 |         '''Get the start datetime for this SegmentList
270 |         This start will be the time of the first candle in SegmentList'''
271 |         return self.slist[0].clist[0].time
272 | 
273 |     def end(self) -> datetime:
274 |         '''Get the end datetime for this SegmentList
275 |         This end will be the time of the last candle in the SegmentList'''
276 |         return self.slist[-1].clist[-1].time
277 | 
278 |     def fetch_by_start(self, dt: datetime, max_diff: int = 3600):
279 |         '''Function to get a certain Segment by
280 |         the start Datetime
281 | 
282 |         Arguments:
283 |             dt: Start of segment datetime used
284 |                 for fetching the Segment
285 |             max_diff : Max discrepancy in number of seconds for the difference
286 |                        dt-s.start()
287 |                        Default: 3600 secs (i.e. 1hr). This is relevant when
288 |                        analysing with granularity = H1 or lower.
289 | 
290 |         Returns:
291 |             Segment object. None if not found
292 |         '''
293 |         for s in self.slist:
294 |             if s.start() == dt or s.start() > dt or \
295 |                abs(s.start()-dt) <= datetime.timedelta(0, max_diff):
296 |                 return s
297 | 
298 |         return None
299 | 
300 |     def fetch_by_end(self, dt: datetime, max_diff: int = 3600):
301 |         '''Function to get a certain Segment by
302 |         the end Datetime
303 | 
304 |         Arguments:
305 |             dt: End of segment datetime used
306 |                 for fetching the Segment
307 |             max_diff : Max discrepancy in number of seconds for the difference
308 |                        dt-s.end()
309 |                        Default: 3600 secs (i.e. 1hr). This is relevant when
310 |                        analysing with granularity = H1 or lower.
311 | 
312 |         Returns:
313 |             Segment object.
None if not found''' 314 | 315 | for s in reversed(self.slist): 316 | if s.end() == dt or s.end() < dt or \ 317 | s.end()-dt <= datetime.timedelta(0, max_diff): 318 | return s 319 | 320 | def __repr__(self): 321 | return "SegmentList" 322 | 323 | def __str__(self): 324 | sb = [] 325 | for key in self.__slots__: 326 | if hasattr(self, key): 327 | sb.append("{key}='{value}'".format(key=key, 328 | value=getattr(self, key))) 329 | return ', '.join(sb) 330 | -------------------------------------------------------------------------------- /params.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from utils import DATA_DIR 3 | 4 | 5 | @dataclass 6 | class gparams: 7 | """General paramaters""" 8 | # Folder to store all output files 9 | outdir: str = f"{DATA_DIR}/out/" 10 | # Increase verbosity 11 | debug: bool = True 12 | # candle's body percentage below which the candle will be considered 13 | # indecision candle 14 | ic_perc: int = 20 15 | # size of images 16 | size = (20, 10) 17 | 18 | 19 | @dataclass 20 | class tjournal_params: 21 | # Column names that will be written in the output worksheet 22 | colnames: str = 'id,timeframe,start,end,strat,type,entry,session,TP,SL,'\ 23 | 'RR,SR,tot_SR,rank_selSR,entry_time,outcome,pips,SLdiff,lasttime,' \ 24 | 'pivots, pivots_lasttime,total_score,score_lasttime,score_pivot,' \ 25 | 'score_pivot_lasttime,trend_i,entry_onrsi,pips_c_trend,max_min_rsi' 26 | 27 | 28 | @dataclass 29 | class counter_params: 30 | # number of candles from start to calculate max,min RSI 31 | rsi_period: int = 20 32 | 33 | 34 | @dataclass 35 | class tradebot_params: 36 | # quantile used as threshold for selecting S/R 37 | th: float = 0.70 38 | # invoke 'calc_SR' each 'period' number of candles 39 | period: int = 60 40 | # number of candles to go back for calculating S/R price 41 | # range. This is relevant for trade_bot's get_max_min function and it will 42 | # also be relevant to decide how much to go back in time to detect SRs 43 | period_range: int = 1500 44 | # add this number of pips to SL and entry 45 | add_pips: int = 10 46 | # Risk Ratio for trades 47 | RR: float = 1.5 48 | # adjust_SL type 49 | adj_SL: str = 'candles' 50 | # adjust_SL_pips number of pips. Only relevant if adj_SL='pips' 51 | adj_SL_pips: int = 100 52 | # do not consider trades with an ic with height>than 'max_height' pips 53 | max_height: int = 150 54 | 55 | def __post_init__(self): 56 | if self.adj_SL not in ['candles', 'pips', 'nextSR']: 57 | raise ValueError(f"{self.adj_SL} is not a valid for adj_SL") 58 | 59 | 60 | @dataclass 61 | class pivots_params: 62 | # Number of pips above/below SR to identify bounces 63 | hr_pips: int = 25 64 | # Value used by ZigZag to identify pivots. The lower the 65 | # value the higher the sensitivity 66 | th_bounces: float = 0.02 67 | # (int) Skip merge if Segment is greater than 'n_candles' 68 | n_candles: int = 18 69 | # (int) % of diff in pips threshold 70 | diff_th: int = 50 71 | # Boolean, if true then produce png files for in_area pivots and 72 | # rsi_bounces 73 | plot: bool = False 74 | # Bool. run merge_pre's function from Pivot class 75 | runmerge_pre: bool = True 76 | # Bool. 
run merge_aft's function from Pivot class 77 | runmerge_aft: bool = True 78 | 79 | 80 | @dataclass 81 | class clist_params: 82 | # Number of candles used for calculating the RSI 83 | rsi_period: int = 14 84 | # SR detection 85 | i_pips: int = 30 86 | # Minimum number of candles from start to be required 87 | min: int = 5 88 | # Number of times * increment_price to be used by calc_diff 89 | # The lower times, the more clustered the retained HAreas will be 90 | times: int = 3 91 | 92 | 93 | @dataclass 94 | class trade_params: 95 | # When using run method, how many pips above/below the HArea will 96 | # considered to check if it hits the entry,SL or TP 97 | hr_pips: int = 1 98 | # number of candles from start of trade to run the trade and assess the 99 | # outcome 100 | numperiods: int = 30 101 | # number of candles from start of trade to create a time interval that will 102 | # be assessed. 103 | interval: int = 1500 104 | # granularity for HArea.get_cross_time 105 | granularity: str = "H2" 106 | # num of candles from trade.start to calc ATR 107 | period_atr: int = 20 108 | # number of candles to go back when init_clist=True 109 | trade_period: int = 5000 110 | # number of pips to add/substract to SR to calculate lasttime 111 | pad: int = 30 112 | th_bounces: int = 0.02 # pivot sensitivity for 'get_trade_type' 113 | 114 | 115 | @dataclass 116 | class trade_management_params(trade_params): 117 | strat: str = "area_unaware" 118 | clisttm_tf: str = "H8" 119 | preceding_clist_strat: str = "wipe" 120 | 121 | def __post_init__(self): 122 | if self.strat not in [ 123 | "area_unaware", 124 | "area_aware", 125 | "breakeven", 126 | "trackingtrade", 127 | "trackingawaretrade" 128 | ]: 129 | raise ValueError(f"Invalid strat: {self.strat}") 130 | 131 | if self.preceding_clist_strat not in ["wipe", "queue"]: 132 | raise ValueError(f"Invalid preceding_clist_strat: {self.strat}") 133 | 134 | 135 | @dataclass 136 | class breakeven_params(trade_management_params): 137 | number_of_pips: int = 10 138 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | DateTime==4.3 2 | flatdict==4.0.1 3 | matplotlib==3.4.3 4 | openpyxl==3.0.9 5 | pandas==1.3.3 6 | pytest-mock==3.14.0 7 | PyYAML==6.0.2 8 | requests==2.26.0 9 | scikit-learn==1.0 10 | scipy==1.7.1 11 | ZigZag==0.1.3 12 | -------------------------------------------------------------------------------- /tests/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FOREXfunguru/FOREX/2888a7c7d79c72ce303857787a6b07ed9f4f00c8/tests/.DS_Store -------------------------------------------------------------------------------- /tests/api/oanda/cmd.sh: -------------------------------------------------------------------------------- 1 | curl -H "Authorization: Bearer $TOKEN" "https://api-fxtrade.oanda.com/v3/instruments//AUD_USD/candles?count=1000&granularity=D&from=2000-12-28T22%3A00%3A00" -------------------------------------------------------------------------------- /tests/api/oanda/test_connect.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import logging 3 | 4 | from datetime import datetime 5 | from api.oanda.connect import Connect 6 | from trading_journal.trade_utils import process_start 7 | 8 | 9 | @pytest.fixture 10 | def conn_o(): 11 | log = logging.getLogger("connect_o") 12 | log.debug("Create a Connect 
object") 13 | 14 | conn = Connect(instrument="AUD_USD", granularity="D") 15 | return conn 16 | 17 | 18 | def test_query_s_e(conn_o): 19 | log = logging.getLogger("test_query_s_e") 20 | log.debug("Test for 'query' function with a start and end datetimes") 21 | clO = conn_o.query("2018-11-16T22:00:00", "2018-11-20T22:00:00") 22 | assert clO.instrument == "AUD_USD" 23 | assert clO.granularity == "D" 24 | assert len(clO) == 3 25 | assert isinstance(clO.times[0], datetime) 26 | 27 | 28 | def test_query_c(conn_o): 29 | log = logging.getLogger("test_query_c") 30 | log.debug("Test for 'query' function with a start and count parameters") 31 | clO = conn_o.query("2018-11-16T22:00:00", count=1) 32 | assert clO.instrument == "AUD_USD" 33 | assert clO.granularity == "D" 34 | assert len(clO) == 1 35 | 36 | 37 | @pytest.mark.parametrize( 38 | "i,g,s,e,l", 39 | [ 40 | ("GBP_NZD", "D", "2018-11-23T22:00:00", "2019-01-02T22:00:00", 27), 41 | ("GBP_AUD", "D", "2002-11-23T22:00:00", "2007-01-02T22:00:00", 877), 42 | ("EUR_NZD", "D", "2002-11-23T22:00:00", "2007-01-02T22:00:00", 881), 43 | ("AUD_USD", "D", "2015-01-25T22:00:00", "2015-01-26T22:00:00", 2), 44 | ("AUD_USD", "D", "2018-11-16T22:00:00", "2018-11-20T22:00:00", 3), 45 | ("AUD_USD", "H12", "2018-11-12T10:00:00", "2018-11-14T10:00:00", 5), 46 | # End date falling in the daylight savings discrepancy(US/EU) period 47 | ("AUD_USD", "D", "2018-03-26T22:00:00", "2018-03-29T22:00:00", 4), 48 | # End date falling in Saturday at 21h 49 | ("AUD_USD", "D", "2018-05-21T21:00:00", "2018-05-26T21:00:00", 4), 50 | # Start and End data fall on closed market 51 | ("AUD_USD", "D", "2018-04-27T21:00:00", "2018-04-28T21:00:00", 0), 52 | # Start date before the start of historical record 53 | ("AUD_USD", "H12", "2000-11-21T22:00:00", "2002-06-15T22:00:00", 32), 54 | ], 55 | ) 56 | def test_m_queries(i, g, s, e, l): 57 | log = logging.getLogger("test_query_ser") 58 | log.debug( 59 | "Test for 'query' function with a mix of " "instruments and different datetimes" 60 | ) 61 | 62 | conn = Connect(instrument=i, granularity=g) 63 | 64 | clO = conn.query(start=s, end=e) 65 | assert len(clO) == l 66 | 67 | 68 | def test_query_M30(): 69 | log = logging.getLogger("test_query_M30") 70 | log.debug("Test for 'query' function with granularity M30") 71 | 72 | conn = Connect(instrument="AUD_USD", granularity="M30") 73 | 74 | clO = conn.query(start="2018-05-21T21:00:00", end="2018-05-23T21:00:00") 75 | assert len(clO) == 97 76 | 77 | 78 | def test_validate_datetime(): 79 | log = logging.getLogger("test_validate_datetime") 80 | log.debug("Test for function with wrong datetime format") 81 | 82 | conn = Connect(instrument="AUD_USD", granularity="D") 83 | 84 | with pytest.raises(ValueError): 85 | conn.validate_datetime("2018-05-23T21") 86 | 87 | 88 | def test_query_in_future(): 89 | """Query with a future datetime""" 90 | timeframe = "D" 91 | now = datetime.now() 92 | aligned_start = process_start(dt=now, timeframe=timeframe).isoformat().split(".")[0] 93 | 94 | conn = Connect(instrument=timeframe, granularity="AUD_USD") 95 | assert conn.query(aligned_start, count=1).candles == [] 96 | 97 | 98 | date_data = [ 99 | (datetime(2023, 10, 4, 17, 0), datetime(2023, 10, 4, 17, 0)), 100 | (datetime(2023, 10, 8, 18, 0), None), 101 | (datetime(2023, 10, 8, 21, 0), datetime(2023, 10, 8, 21, 0)), 102 | (datetime(2022, 11, 30, 21, 0), datetime(2022, 11, 30, 22, 0)), 103 | (datetime(2022, 11, 18, 21, 0), None) 104 | ] 105 | 106 | 107 | @pytest.mark.parametrize("day,expected_datetime", date_data) 108 | def 
test_fetch_candle(day, expected_datetime): 109 | conn = Connect(instrument="AUD_USD", granularity="H4") 110 | candle = conn.fetch_candle(d=day) 111 | if expected_datetime is not None: 112 | assert candle.time == expected_datetime 113 | else: 114 | assert candle is None 115 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from forex.candle import CandleList 4 | from utils import DATA_DIR 5 | 6 | @pytest.fixture 7 | def clO_pickled(): 8 | clO = CandleList.pickle_load(DATA_DIR+"/clist_audusd_2010_2020.pckl") 9 | clO.calc_rsi() 10 | 11 | return clO 12 | 13 | 14 | @pytest.fixture 15 | def clOH8_2019_pickled(): 16 | clO = CandleList.pickle_load(DATA_DIR+"/clist.AUDUSD.H8.2019.pckl") 17 | clO.calc_rsi() 18 | 19 | return clO -------------------------------------------------------------------------------- /tests/data/clist.AUDUSD.H8.2019.pckl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FOREXfunguru/FOREX/2888a7c7d79c72ce303857787a6b07ed9f4f00c8/tests/data/clist.AUDUSD.H8.2019.pckl -------------------------------------------------------------------------------- /tests/data/clist.AUDUSD.H8.2021.pckl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FOREXfunguru/FOREX/2888a7c7d79c72ce303857787a6b07ed9f4f00c8/tests/data/clist.AUDUSD.H8.2021.pckl -------------------------------------------------------------------------------- /tests/data/clist.pckl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FOREXfunguru/FOREX/2888a7c7d79c72ce303857787a6b07ed9f4f00c8/tests/data/clist.pckl -------------------------------------------------------------------------------- /tests/data/clist_audusd_2010_2020.pckl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FOREXfunguru/FOREX/2888a7c7d79c72ce303857787a6b07ed9f4f00c8/tests/data/clist_audusd_2010_2020.pckl -------------------------------------------------------------------------------- /tests/data/create_pickled_data.py: -------------------------------------------------------------------------------- 1 | from forex.segment import SegmentList, Segment 2 | from forex.candle import CandleList 3 | from utils import DATA_DIR 4 | from forex.pivot import PivotList 5 | from api.oanda.connect import Connect 6 | 7 | # pickle CandleList Daily 8 | conn = Connect( 9 | instrument="AUD_USD", 10 | granularity='D') 11 | clO = conn.query('2010-11-16T22:00:00', '2020-11-19T22:00:00') 12 | clO.pickle_dump(f"{DATA_DIR}/clist_audusd_2010_2020.pckl") 13 | 14 | # pickle CandleLists H8 15 | conn = Connect( 16 | instrument="AUD_USD", 17 | granularity='H8') 18 | clO = conn.query("2021-01-03T22:00:00", "2021-12-31T14:00:00") 19 | clO.pickle_dump(f"{DATA_DIR}/clist.AUDUSD.H8.2021.pckl") 20 | 21 | clO = conn.query("2019-01-03T22:00:00", "2019-12-31T14:00:00") 22 | clO.pickle_dump(f"{DATA_DIR}/clist.AUDUSD.H8.2019.pckl") 23 | 24 | 25 | # pickle Segment 26 | clObj = CandleList.pickle_load(f"{DATA_DIR}/clist_audusd_2010_2020.pckl") 27 | pl = PivotList(clist=clObj) 28 | pl.slist[5].pickle_dump('seg_audusd.pckl') 29 | pl.slist[6].pickle_dump('seg_audusdB.pckl') 30 | 31 | #pickle SegmentList 32 | pl.slist.pickle_dump('seglist_audusd.pckl') 33 | 34 | 
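# Hedged usage note (not part of the original script): the pickled files
# written above are read back in the test fixtures with the corresponding
# pickle_load() classmethods. A minimal loading sketch, assuming the files
# have been copied to DATA_DIR as expected by tests/conftest.py:
#
#   clO = CandleList.pickle_load(f"{DATA_DIR}/clist_audusd_2010_2020.pckl")
#   seg = Segment.pickle_load(f"{DATA_DIR}/seg_audusd.pckl")
#   slist = SegmentList.pickle_load(f"{DATA_DIR}/seglist_audusd.pckl")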
-------------------------------------------------------------------------------- /tests/data/harealist_file.txt: -------------------------------------------------------------------------------- 1 | 0.50 2 | 0.60 3 | 0.70 4 | 0.80 5 | 0.90 6 | 1.00 -------------------------------------------------------------------------------- /tests/data/harealist_file.yaml: -------------------------------------------------------------------------------- 1 | AUD_USD: 2 | - 0.57341 3 | - 0.63846 4 | - 0.69185 5 | - 0.71600 6 | - 0.75278 -------------------------------------------------------------------------------- /tests/data/seg_audusd.pckl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FOREXfunguru/FOREX/2888a7c7d79c72ce303857787a6b07ed9f4f00c8/tests/data/seg_audusd.pckl -------------------------------------------------------------------------------- /tests/data/seg_audusdB.pckl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FOREXfunguru/FOREX/2888a7c7d79c72ce303857787a6b07ed9f4f00c8/tests/data/seg_audusdB.pckl -------------------------------------------------------------------------------- /tests/data/seglist_audusd.pckl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FOREXfunguru/FOREX/2888a7c7d79c72ce303857787a6b07ed9f4f00c8/tests/data/seglist_audusd.pckl -------------------------------------------------------------------------------- /tests/data/testCounter.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FOREXfunguru/FOREX/2888a7c7d79c72ce303857787a6b07ed9f4f00c8/tests/data/testCounter.xlsx -------------------------------------------------------------------------------- /tests/forex/candle/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import logging 3 | 4 | from forex.candle import CandleList 5 | 6 | 7 | @pytest.fixture 8 | def clO(): 9 | log = logging.getLogger('cl_object') 10 | log.debug('Create a CandleList object') 11 | 12 | alist = [ 13 | { 14 | 'time': '2018-11-18T22:00:00', 15 | 'o': '0.73093', 16 | 'h': '0.73258', 17 | 'l': '0.72776', 18 | 'c': '0.72950'}, 19 | { 20 | 'time': '2018-11-19T22:00:00', 21 | 'o': '0.70123', 22 | 'h': '0.75123', 23 | 'l': '0.68123', 24 | 'c': '0.72000' 25 | } 26 | ] 27 | cl = CandleList(instrument='AUD_USD', 28 | granularity='D', 29 | data=alist) 30 | return cl 31 | 32 | 33 | @pytest.fixture 34 | def clO1(): 35 | log = logging.getLogger('cl_object') 36 | log.debug('Create a CandleList object') 37 | 38 | alist = [ 39 | { 40 | 'time': '2018-11-20T22:00:00', 41 | 'o': '0.73093', 42 | 'h': '0.73258', 43 | 'l': '0.72776', 44 | 'c': '0.72950'}, 45 | { 46 | 'time': '2018-11-21T22:00:00', 47 | 'o': '0.70123', 48 | 'h': '0.75123', 49 | 'l': '0.68123', 50 | 'c': '0.72000' 51 | } 52 | ] 53 | cl = CandleList(instrument='AUD_USD', 54 | granularity='D', 55 | data=alist) 56 | return cl 57 | -------------------------------------------------------------------------------- /tests/forex/candle/test_Candle.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @date: 22/11/2020 3 | @author: Ernesto Lowy 4 | @email: ernestolowy@gmail.com 5 | ''' 6 | import pytest 7 | 8 | from forex.candle import Candle 9 | 10 | candle_feats = [ 11 | ("2023-11-14T10:00:00", 0.87106, 0.87309, 0.87000, 0.87003), 12 | 
("2018-11-14T14:00:00", 0.87003, 0.87104, 0.86890, 0.87035), 13 | ("2018-11-03T13:00:00", 0.86979, 0.87016, 0.86690, 0.86730) 14 | ] 15 | 16 | 17 | @pytest.fixture 18 | def CandleFactory(): 19 | """Candle object factory""" 20 | candle_list = [] 21 | for feats in candle_feats: 22 | candle_dict = { 23 | "time": feats[0], 24 | "o": feats[1], 25 | "h": feats[2], 26 | "l": feats[3], 27 | "c": feats[4]} 28 | candle_list.append(Candle(**candle_dict)) 29 | return candle_list 30 | 31 | 32 | @pytest.mark.parametrize("colour, perc_body", [ 33 | (["red", "green", "red"], [33.33, 14.95, 76.38]) 34 | ]) 35 | def test_check_candle_feats(CandleFactory, colour, perc_body): 36 | """Check that candle has the right attributes""" 37 | for ix in range(len(CandleFactory)): 38 | assert CandleFactory[ix].colour == colour[ix] 39 | assert CandleFactory[ix].perc_body == perc_body[ix] 40 | 41 | 42 | @pytest.mark.parametrize("indecision_candle", [ 43 | ([False, False, False]) 44 | ]) 45 | def test_indecision_c(CandleFactory, indecision_candle): 46 | """Test function to check if a certain Candle has the 47 | typical indecission pattern""" 48 | 49 | for ix in range(len(CandleFactory)): 50 | assert CandleFactory[ix].indecision_c() is indecision_candle[ix] 51 | 52 | 53 | @pytest.mark.parametrize("height", [ 54 | ([30.9, 21.4, 32.6]) 55 | ]) 56 | def test_height(CandleFactory, height): 57 | 58 | for ix in range(len(CandleFactory)): 59 | assert CandleFactory[ix].height(pair="EUR_GBP") == height[ix] 60 | 61 | 62 | @pytest.mark.parametrize("middle_pts", [ 63 | ([0.87155, 0.86997, 0.86853]) 64 | ]) 65 | def test_middle_point(CandleFactory, middle_pts): 66 | """Test for the middle_point function""" 67 | for ix in range(len(CandleFactory)): 68 | assert CandleFactory[ix].middle_point() == middle_pts[ix] 69 | -------------------------------------------------------------------------------- /tests/forex/candle/test_CandleList.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @date: 22/11/2020 3 | @author: Ernesto Lowy 4 | @email: ernestolowy@gmail.com 5 | ''' 6 | import os 7 | import logging 8 | import datetime 9 | 10 | from utils import DATA_DIR 11 | from forex.candle import CandleList 12 | 13 | 14 | def test_candlelist_inst(clO): 15 | log = logging.getLogger('Test CandleList instantiation') 16 | log.debug('CandleList instantation') 17 | assert clO.type == 'short' 18 | assert clO.candles[0].colour == 'red' 19 | assert len(clO) == 2 20 | 21 | 22 | def test_pickle_dump(clO, tmp_path): 23 | log = logging.getLogger('Test for pickle_dump') 24 | log.debug('pickle_dump') 25 | 26 | clO.pickle_dump(f"{tmp_path}/clist.pckl") 27 | 28 | assert os.path.exists(DATA_DIR+"/clist.pckl") == 1 29 | 30 | 31 | def test_pickle_load(): 32 | log = logging.getLogger('Test for pickle_load') 33 | log.debug('pickle_load') 34 | 35 | loadedCL = CandleList.pickle_load(DATA_DIR+"/clist.pckl") 36 | assert loadedCL.instrument == 'AUD_USD' 37 | assert len(loadedCL) == 2 38 | 39 | 40 | def test_calc_rsi(clO_pickled): 41 | log = logging.getLogger('Test for calc_rsi function') 42 | log.debug('calc_rsi') 43 | 44 | clO_pickled.calc_rsi() 45 | 46 | assert clO_pickled.candles[15].rsi == 61.54 47 | assert clO_pickled.candles[50].rsi == 48.59 48 | 49 | 50 | def test_rsibounces(clO_pickled): 51 | log = logging.getLogger('Test for rsi_bounces function') 52 | log.debug('rsi_bounces') 53 | 54 | clO_pickled.calc_rsi() 55 | dict1 = clO_pickled.calc_rsi_bounces() 56 | 57 | dict2 = {'number': 31, 58 | 'lengths': [1, 1, 3, 4, 5, 7, 1, 1, 
59 | 4, 5, 4, 3, 1, 1, 4, 5, 60 | 1, 2, 1, 1, 1, 1, 12, 7, 61 | 14, 3, 8, 2, 3, 1, 3]} 62 | 63 | assert dict1 == dict2 64 | 65 | 66 | def test_get_length_pips(clO_pickled): 67 | ''' 68 | This test check the functions for getting the length of 69 | the CandleList in number of pips and candles 70 | ''' 71 | log = logging.getLogger('Test for different length functions') 72 | log.debug('get_length') 73 | 74 | assert clO_pickled.get_length_pips() == 2493 75 | 76 | 77 | def test_fetch_by_time_q1(clO_pickled): 78 | """Test __getitem__ with query time 1""" 79 | 80 | adatetime = datetime.datetime(2019, 5, 7, 22, 0) 81 | c = clO_pickled[adatetime] 82 | 83 | assert c.o == 0.70118 84 | assert c.h == 0.70270 85 | 86 | 87 | def test_fetch_by_time_q2(clO_pickled): 88 | """Test __getitem__ with query time 2""" 89 | 90 | adatetime = datetime.datetime(2019, 5, 7, 21, 0) 91 | c = clO_pickled[adatetime] 92 | 93 | assert c.o == 0.70118 94 | assert c.h == 0.70270 95 | 96 | 97 | def test_fetch_by_time_q3(clO_pickled): 98 | """Test __getitem__ with query time 3""" 99 | 100 | adatetime = datetime.datetime(2019, 12, 4, 21, 0) 101 | c = clO_pickled[adatetime] 102 | 103 | assert c.o == 0.68487 104 | assert c.h == 0.68546 105 | assert c.time == datetime.datetime(2019, 12, 4, 22, 0) 106 | 107 | 108 | def test_fetch_by_time_q4(clO_pickled): 109 | """Test __getitem__ with weekend query time""" 110 | 111 | adatetime = datetime.datetime(2019, 5, 4, 22, 0) 112 | c = clO_pickled[adatetime] 113 | 114 | assert c is None 115 | 116 | 117 | def test_slice_with_start_end(clO_pickled): 118 | 119 | startdatetime = datetime.datetime(2019, 5, 7, 21, 0) 120 | endatetime = datetime.datetime(2019, 7, 1, 21, 0) 121 | 122 | clO_pickled.slice(start=startdatetime, 123 | end=endatetime, 124 | inplace=True) 125 | 126 | assert len(clO_pickled) == 40 127 | 128 | 129 | def test_last_time(clO_pickled): 130 | log = logging.getLogger('Test for last_time function') 131 | log.debug('last_time') 132 | subCl1 = clO_pickled.slice(start=clO_pickled.candles[0].time, 133 | end=datetime.datetime(2017, 1, 3, 22, 0)) 134 | subCl2 = clO_pickled.slice(start=clO_pickled.candles[0].time, 135 | end=datetime.datetime(2019, 7, 19, 22, 0)) 136 | subCl3 = clO_pickled.slice(start=clO_pickled.candles[0].time, 137 | end=datetime.datetime(2018, 1, 26, 22, 0)) 138 | 139 | lt1 = subCl1.get_lasttime(price=0.71754, type='long') 140 | lt2 = subCl2.get_lasttime(price=0.70621, type='short') 141 | lt3 = subCl3.get_lasttime(price=0.80879, type='short') 142 | assert lt1.isoformat() == '2016-02-28T22:00:00' 143 | assert lt2.isoformat() == '2019-04-22T21:00:00' 144 | assert lt3.isoformat() == '2015-01-19T22:00:00' 145 | 146 | 147 | def test_get_highest(clO_pickled): 148 | log = logging.getLogger('Test get_highest') 149 | log.debug('get_highest') 150 | 151 | clO_pickled.get_highest() 152 | 153 | assert clO_pickled.get_highest() == 1.10307 154 | 155 | 156 | def test_get_lowest(clO_pickled): 157 | log = logging.getLogger('Test get_lowest') 158 | log.debug('get_lowest') 159 | 160 | clO_pickled.get_lowest() 161 | 162 | assert clO_pickled.get_lowest() == 0.57444 163 | 164 | 165 | def test_add_two_clists(clO, clO1): 166 | log = logging.getLogger('Test get_lowest') 167 | log.debug('get_lowest') 168 | 169 | newClO = clO + clO1 170 | assert len(newClO.candles) == 4 171 | -------------------------------------------------------------------------------- /tests/forex/candle/test_candlelist_utils.py: -------------------------------------------------------------------------------- 1 | import datetime 
2 | 3 | from forex.candlelist_utils import calc_SR, calc_atr 4 | from params import pivots_params 5 | from forex.pivot import PivotList 6 | 7 | 8 | def test_calc_SR(pivotlist, tmp_path): 9 | """Check 'calc_SR' function""" 10 | harealst = calc_SR(pivotlist, outfile=f"{tmp_path}/calc_sr.png") 11 | 12 | assert len(harealst.halist) == 8 13 | 14 | 15 | def test_calc_SR_H8(clOH8_pickled, tmp_path): 16 | """Check 'calc_SR' function for H8 dataframe""" 17 | # these are the recommended params for H8 18 | pivots_params.th_bounces = 0.02 19 | pivotlistH8 = PivotList(clist=clOH8_pickled. 20 | slice(start=clOH8_pickled.candles[0].time, 21 | end=datetime.datetime(2021, 10, 29, 5, 0))) 22 | harealst = calc_SR(pivotlistH8, outfile=f"{tmp_path}/calc_sr_h8.png") 23 | assert len(harealst.halist) == 3 24 | 25 | 26 | def test_calc_atr(clO): 27 | """Check 'calc_atr' function""" 28 | atr = calc_atr(clO) 29 | 30 | assert atr == 374.1 31 | -------------------------------------------------------------------------------- /tests/forex/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from forex.candle import CandleList 4 | from forex.segment import Segment, SegmentList 5 | from forex.pivot import PivotList 6 | from utils import DATA_DIR 7 | 8 | 9 | @pytest.fixture 10 | def clO_pickled(): 11 | return CandleList.pickle_load(DATA_DIR+"/clist_audusd_2010_2020.pckl") 12 | 13 | 14 | @pytest.fixture 15 | def clOH8_pickled(): 16 | """Return a H8 pickled CandleList""" 17 | return CandleList.pickle_load(DATA_DIR+"/clist.AUDUSD.H8.2021.pckl") 18 | 19 | 20 | @pytest.fixture 21 | def seg_pickled(): 22 | return Segment.pickle_load(DATA_DIR+"/seg_audusd.pckl") 23 | 24 | 25 | @pytest.fixture 26 | def seg_pickledB(): 27 | return Segment.pickle_load(DATA_DIR+"/seg_audusdB.pckl") 28 | 29 | 30 | @pytest.fixture 31 | def seglist_pickled(): 32 | return SegmentList.pickle_load(DATA_DIR+"/seglist_audusd.pckl") 33 | 34 | 35 | @pytest.fixture 36 | def pivotlist(clO_pickled): 37 | """Obtain a PivotList object""" 38 | return PivotList(clist=clO_pickled) 39 | -------------------------------------------------------------------------------- /tests/forex/harea/test_HArea.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import logging 3 | import datetime 4 | 5 | from forex.harea import HArea 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "clist_ix, price, dt", 10 | [ 11 | (-5, 0.7267, datetime.datetime(2020, 11, 15, 22, 0)), 12 | (-10, 0.7002, "n.a."), 13 | (-30, 0.7204, datetime.datetime(2020, 10, 12, 13, 0)), 14 | ], 15 | ) 16 | def test_get_cross_time(clO_pickled, clist_ix, price, dt): 17 | log = logging.getLogger('Test for get_cross_time') 18 | log.debug('cross_time') 19 | 20 | resist = HArea(price=price, 21 | pips=5, 22 | instrument='AUD_USD', 23 | granularity='D') 24 | 25 | cross_time = resist.get_cross_time(candle=clO_pickled.candles[clist_ix], 26 | granularity='H8') 27 | 28 | assert cross_time == dt 29 | -------------------------------------------------------------------------------- /tests/forex/harea/test_HAreaList.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import logging 3 | import numpy as np 4 | 5 | from forex.harea import HAreaList 6 | from forex.harea import HArea 7 | from forex.candle import Candle 8 | from utils import DATA_DIR 9 | 10 | 11 | @pytest.fixture 12 | def hlist_factory(): 13 | log = logging.getLogger('Test for hlist_factory for ' 14 | 
'returning a list of HArea objects') 15 | log.debug('hlist_factory') 16 | 17 | hlist = [] 18 | for p in np.arange(0.660, 0.720, 0.020): 19 | area = HArea(price=p, 20 | pips=30, 21 | instrument="AUD_USD", 22 | granularity="D") 23 | hlist.append(area) 24 | return hlist 25 | 26 | 27 | def test_HAreaList_inst(hlist_factory): 28 | log = logging.getLogger('Instantiate a HAreaList') 29 | log.debug('HAreaList') 30 | 31 | halist = HAreaList(halist=hlist_factory) 32 | assert len(halist.halist) == 3 33 | 34 | 35 | def test_onArea(hlist_factory): 36 | log = logging.getLogger('Test for test_onArea function') 37 | log.debug('test_onArea') 38 | 39 | halist = HAreaList(halist=hlist_factory) 40 | candle = { 41 | 'time': '2018-11-18T22:00:00', 42 | 'o': '0.68605', 43 | 'h': '0.71258', 44 | 'l': '0.68600', 45 | 'c': '0.70950' 46 | } 47 | 48 | c_candle = Candle(**candle) 49 | (hrsel, ix) = halist.onArea(candle=c_candle) 50 | 51 | assert hrsel.price == 0.70 52 | assert ix == 2 53 | 54 | 55 | def test_print(hlist_factory): 56 | '''Test 'print' function''' 57 | 58 | halist = HAreaList( 59 | halist=hlist_factory) 60 | 61 | res = halist.print() 62 | print(res) 63 | 64 | 65 | def test_plot(hlist_factory, clO_pickled, tmp_path): 66 | '''Test 'plot' function''' 67 | 68 | halist = HAreaList(halist=hlist_factory) 69 | 70 | halist.plot(clO_pickled, outfile=f"{tmp_path}/AUD_USD.halist.png") 71 | 72 | 73 | def test_from_csv(): 74 | harea_list_object = HAreaList.from_csv(f"{DATA_DIR}/harealist_file.txt", 75 | instrument="AUD_USD", 76 | granularity="D") 77 | assert len(harea_list_object.halist) == 6 78 | 79 | 80 | def test_from_list(): 81 | harea_list_object = HAreaList.from_list(prices=[0.1, 0.2, 0.3, 0.4], 82 | instrument="AUD_USD", 83 | granularity="D") 84 | assert len(harea_list_object.halist) == 4 85 | -------------------------------------------------------------------------------- /tests/forex/pivot/test_Pivot.py: -------------------------------------------------------------------------------- 1 | from params import pivots_params 2 | from forex.pivot import PivotList 3 | 4 | import pytest 5 | import datetime 6 | 7 | 8 | @pytest.fixture 9 | def pivot(clO_pickled): 10 | """Obtain a Pivot object""" 11 | 12 | pl = PivotList(clist=clO_pickled) 13 | return pl[5] 14 | 15 | 16 | def test_pre_aft_lens(pivot): 17 | ''' 18 | Check if 'pre' and 'aft' Segments have the 19 | correct number of candles 20 | ''' 21 | 22 | assert len(pivot.pre.clist) == 39 23 | assert len(pivot.aft.clist) == 17 24 | 25 | 26 | def test_pre_aft_start(pivot): 27 | ''' 28 | Check if 'pre' and 'aft' Segments have the 29 | correct start Datetimes 30 | ''' 31 | 32 | assert datetime.datetime(2011, 1, 19, 22, 0) == pivot.pre.start() 33 | assert datetime.datetime(2011, 2, 27, 22, 0) == pivot.aft.start() 34 | 35 | 36 | @pytest.mark.parametrize("ix," 37 | "date_pre," 38 | "date_post", 39 | [(40, datetime.datetime(2013, 8, 8, 21, 0), 40 | datetime.datetime(2013, 8, 8, 21, 0)), 41 | (50, datetime.datetime(2014, 6, 30, 21, 0), 42 | datetime.datetime(2014, 6, 30, 21, 0))]) 43 | def test_merge_pre(pivotlist, ix, date_pre, date_post): 44 | ''' 45 | Test function 'merge_pre' 46 | ''' 47 | pivot = pivotlist.pivots[ix] 48 | # Check pivot.pre.start() before running 'merge_pre' 49 | assert date_pre == pivot.pre.start() 50 | 51 | pivot.merge_pre(slist=pivotlist.slist, 52 | n_candles=pivots_params.n_candles, 53 | diff_th=pivots_params.diff_th) 54 | 55 | # Check pivot.pre.start() after running 'merge_pre' 56 | assert date_post == pivot.pre.start() 57 | 58 | 59 | 
@pytest.mark.parametrize("ix," 60 | "date_pre," 61 | "date_post", 62 | [(70, datetime.datetime(2015, 10, 8, 21, 0), 63 | datetime.datetime(2015, 10, 8, 21, 0)), 64 | (80, datetime.datetime(2016, 6, 21, 21, 0), 65 | datetime.datetime(2016, 11, 6, 22, 0))]) 66 | def test_merge_aft(pivotlist, ix, date_pre, date_post): 67 | ''' 68 | Test function to merge 'aft' Segment 69 | ''' 70 | pivot = pivotlist.pivots[ix] 71 | # Check pivot.aft.end() before running 'merge_aft' 72 | assert date_pre == pivot.aft.end() 73 | 74 | pivot.merge_aft(slist=pivotlist.slist, 75 | n_candles=pivots_params.n_candles, 76 | diff_th=pivots_params.diff_th) 77 | 78 | # Check pivot.aft.end() after running 'merge_aft' 79 | assert date_post == pivot.aft.end() 80 | 81 | 82 | def test_calc_score_d(pivot): 83 | ''' 84 | Test function named 'calc_score' 85 | with 'diff' parameter (def option) 86 | ''' 87 | score = pivot.calc_score() 88 | 89 | assert score == 627.4 90 | 91 | 92 | def test_calc_score_c(pivot): 93 | ''' 94 | Test function named 'calc_score' 95 | with 'candle' parameter 96 | ''' 97 | score = pivot.calc_score(type="candles") 98 | 99 | assert score == 56 100 | 101 | 102 | @pytest.mark.parametrize("ix," 103 | "new_b", 104 | [(13, datetime.datetime(2011, 7, 30, 21, 0)), 105 | (60, datetime.datetime(2015, 4, 30, 21, 0)), 106 | (68, datetime.datetime(2015, 9, 3, 21, 0)), 107 | (100, datetime.datetime(2017, 12, 7, 22, 0)) 108 | ]) 109 | def test_adjust_pivottime(pivotlist, ix, new_b): 110 | p = pivotlist[ix] 111 | start_t = pivotlist.clist.candles[0].time 112 | end_t = p.candle.time 113 | newt = p.adjust_pivottime(clistO=pivotlist.clist.slice(start=start_t, 114 | end=end_t)) 115 | 116 | assert new_b == newt 117 | -------------------------------------------------------------------------------- /tests/forex/pivot/test_PivotList.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | 4 | from forex.pivot import PivotList 5 | 6 | 7 | def test_get_score(pivotlist): 8 | """ 9 | Test 'get_score' function 10 | """ 11 | assert pivotlist.get_score() == 90358.0 12 | 13 | 14 | def test_get_avg_score(pivotlist): 15 | """ 16 | Test 'get_avg_score' function 17 | """ 18 | assert pivotlist.get_avg_score() == 654.8 19 | 20 | 21 | def test_in_area(pivotlist): 22 | """ 23 | Test 'inarea_pivots' function 24 | """ 25 | pl_inarea = pivotlist.inarea_pivots(price=0.75) 26 | 27 | # check the len of pl.plist after getting the pivots in the S/R area 28 | assert len(pl_inarea) == 8 29 | 30 | 31 | def test_plot_pivots(pivotlist, tmp_path): 32 | """ 33 | Test plot_pivots 34 | """ 35 | outfile = f"{tmp_path}/{pivotlist.clist.instrument}.png" 36 | outfile_rsi = f"{tmp_path}/{pivotlist.clist.instrument}.final_rsi.png" 37 | 38 | pivotlist.clist.calc_rsi() 39 | pivotlist.plot_pivots(outfile_prices=outfile, 40 | outfile_rsi=outfile_rsi) 41 | 42 | assert os.path.exists(outfile) == 1 43 | assert os.path.exists(outfile_rsi) == 1 44 | 45 | 46 | def test_print_pivots_dates(pivotlist): 47 | dtl = pivotlist.print_pivots_dates() 48 | assert len(dtl) == 138 49 | 50 | 51 | def test_fetch_by_type(pivotlist): 52 | """Obtain a pivotlist of a certain type""" 53 | 54 | newpl = pivotlist.fetch_by_type(type=-1) 55 | assert len(newpl.pivots) == 68 56 | 57 | 58 | def test_fetch_by_time(pivotlist): 59 | """Obtain a Pivot object by datetime""" 60 | 61 | adt = datetime.datetime(2014, 10, 2, 21, 0) 62 | rpt = pivotlist.fetch_by_time(adt) 63 | assert rpt.candle.time == datetime.datetime(2014, 10, 2, 21, 0) 64 | 65 | 
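
# Illustrative sketch (not part of the original test suite): fetch_by_type()
# returns a new PivotList, so the scoring helpers tested above still apply to
# the filtered object. The ">" threshold in the assertion is an assumption,
# not a value checked against the pickled data.
def test_fetch_by_type_keeps_scores_sketch(pivotlist):
    long_pivots = pivotlist.fetch_by_type(type=1)
    assert long_pivots.get_avg_score() > 0

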
66 | def test_pivots_report(pivotlist, tmp_path): 67 | """Get a PivotList report""" 68 | 69 | outfile = f"{tmp_path}/{pivotlist.clist.instrument}.preport.txt" 70 | pivotlist.pivots_report(outfile=outfile) 71 | 72 | 73 | def test_calc_itrend(clO_pickled): 74 | """Calc init of trend""" 75 | 76 | subCl1 = clO_pickled.slice(start=clO_pickled.candles[0].time, 77 | end=datetime.datetime(2020, 6, 10, 22, 0)) 78 | subCl2 = clO_pickled.slice(start=clO_pickled.candles[0].time, 79 | end=datetime.datetime(2020, 3, 19, 22, 0)) 80 | subCl3 = clO_pickled.slice(start=clO_pickled.candles[0].time, 81 | end=datetime.datetime(2017, 12, 8, 22, 0)) 82 | 83 | pl1 = PivotList(clist=subCl1) 84 | pl2 = PivotList(clist=subCl2) 85 | pl3 = PivotList(clist=subCl3) 86 | 87 | assert pl1.calc_itrend().start() == datetime.datetime(2020, 3, 18, 21, 0) 88 | assert pl2.calc_itrend().start() == datetime.datetime(2019, 12, 30, 22, 0) 89 | assert pl3.calc_itrend().start() == datetime.datetime(2017, 9, 7, 21, 0) 90 | -------------------------------------------------------------------------------- /tests/forex/segment/test_Segment.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | 4 | def test_start(seg_pickled): 5 | assert datetime.datetime(2011, 2, 27, 22, 0) == seg_pickled.start() 6 | 7 | 8 | def test_end(seg_pickled): 9 | assert datetime.datetime(2011, 3, 15, 21, 0) == seg_pickled.end() 10 | 11 | 12 | def test_get_lowest(seg_pickled): 13 | assert datetime.datetime(2011, 3, 15, 21, 0) == \ 14 | seg_pickled.get_lowest().time 15 | 16 | 17 | def test_get_highest(seg_pickled): 18 | assert datetime.datetime(2011, 2, 28, 22, 0) == \ 19 | seg_pickled.get_highest().time 20 | 21 | 22 | def test_append(seg_pickled, seg_pickledB): 23 | assert len(seg_pickled.clist) == 17 24 | assert seg_pickled.diff == 346.8 25 | seg_pickled.append(seg_pickledB) 26 | assert len(seg_pickled.clist) == 60 27 | assert seg_pickled.diff == 743.1 28 | 29 | 30 | def test_prepend(seg_pickled, seg_pickledB): 31 | assert len(seg_pickled.clist) == 17 32 | assert seg_pickled.diff == 346.8 33 | seg_pickled.prepend(seg_pickledB) 34 | assert len(seg_pickled.clist) == 60 35 | assert seg_pickled.diff == 35.7 36 | -------------------------------------------------------------------------------- /tests/forex/segment/test_SegmentList.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | 4 | def test_calc_diff(seglist_pickled): 5 | 6 | seglist_pickled.calc_diff() 7 | 8 | assert seglist_pickled.diff == 2513.6 9 | 10 | 11 | def test_length_cl(seglist_pickled): 12 | 13 | assert seglist_pickled.length_cl() == 2801 14 | 15 | 16 | def test_start(seglist_pickled): 17 | 18 | assert seglist_pickled.start() == datetime.datetime(2010, 11, 16, 22, 0) 19 | 20 | 21 | def test_end(seglist_pickled): 22 | 23 | assert seglist_pickled.end() == datetime.datetime(2020, 11, 18, 22, 0) 24 | 25 | 26 | def test_fetch_by_start(seglist_pickled): 27 | 28 | adt = datetime.datetime(2019, 4, 16, 21, 0) 29 | s = seglist_pickled.fetch_by_start(adt) 30 | 31 | assert s.start() == adt 32 | 33 | 34 | def test_fetch_by_end(seglist_pickled): 35 | """Test fetch_by_end""" 36 | 37 | adt = datetime.datetime(2019, 6, 13, 21, 0) 38 | 39 | s = seglist_pickled.fetch_by_end(adt) 40 | 41 | assert s.end() == adt 42 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import 
pytest 2 | import datetime 3 | 4 | from utils import calculate_profit, is_even_hour 5 | 6 | prices = [((0.69200, 0.68750), "short", "AUD_USD", -45), 7 | ((0.69200, 0.68750), "long", "AUD_USD", 45), 8 | ((155.770, 158.245), "long", "EUR_JPY", -247.5), 9 | ((155.770, 158.245), "short", "EUR_JPY", 247.5)] 10 | 11 | 12 | @pytest.mark.parametrize("bi_price,type,pair,expected", prices) 13 | def test_calc_profit(bi_price, type, pair, expected): 14 | """Test function 'calc_profit'""" 15 | assert expected == calculate_profit(prices=bi_price, type=type, pair=pair) 16 | 17 | 18 | datetimes = [(datetime.datetime(2015, 2, 28, 22, 0, 0), True), 19 | (datetime.datetime(2015, 5, 28, 9, 0, 0), False)] 20 | 21 | 22 | @pytest.mark.parametrize("datetime,expected", datetimes) 23 | def test_is_even_hour(datetime, expected): 24 | res = is_even_hour(datetime) 25 | assert res == expected, "Non-correct time info" 26 | -------------------------------------------------------------------------------- /tests/trade_bot/test_TradeBot.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pytest 4 | import glob 5 | 6 | from params import tradebot_params, pivots_params 7 | from trade_bot.trade_bot import TradeBot 8 | from forex.candle import CandleList 9 | from utils import DATA_DIR, load_config_yaml_file 10 | 11 | # create logger 12 | tb_logger = logging.getLogger(__name__) 13 | tb_logger.setLevel(logging.DEBUG) 14 | 15 | 16 | @pytest.fixture 17 | def clO_pickled(): 18 | clO = CandleList.pickle_load(DATA_DIR+"/clist_audusd_2010_2020.pckl") 19 | clO.calc_rsi() 20 | 21 | return clO 22 | 23 | 24 | @pytest.fixture 25 | def tb_object(): 26 | tb = TradeBot( 27 | pair='EUR_GBP', 28 | timeframe='D', 29 | start='2020-06-29 22:00:00', 30 | end='2020-07-01 22:00:00') 31 | return tb 32 | 33 | 34 | @pytest.fixture 35 | def scan_pickled(clO_pickled, tmp_path): 36 | """Prepare a pickled file with potential trades identified by scan""" 37 | tb = TradeBot( 38 | pair='AUD_USD', 39 | timeframe='D', 40 | start='2016-01-05 22:00:00', 41 | end='2016-02-11 22:00:00', 42 | clist=clO_pickled) 43 | outfile = tb.scan(prefix=f"{tmp_path}") 44 | return outfile 45 | 46 | 47 | @pytest.fixture 48 | def clean_tmp(): 49 | yield 50 | print("Cleanup files") 51 | files = glob.glob(DATA_DIR+"/out/*") 52 | 53 | for f in files: 54 | os.remove(f) 55 | 56 | 57 | def test_scan(tb_object, tmp_path): 58 | """Check the 'scan' function""" 59 | outfile = tb_object.scan(prefix=f"{tmp_path}") 60 | assert os.path.isfile(outfile) 61 | 62 | 63 | def test_scan1(tmp_path): 64 | """ 65 | Test the scan() function with a certain TradeBot interval 66 | """ 67 | pivots_params.th_bounces = 0.05 68 | tb = TradeBot( 69 | pair='AUD_USD', 70 | timeframe='D', 71 | start='2016-01-05 22:00:00', 72 | end='2016-02-11 22:00:00') 73 | outfile = tb.scan(prefix=f"{tmp_path}/") 74 | assert os.path.isfile(outfile) 75 | 76 | 77 | def test_scan_withsrlist(tmp_path): 78 | """ 79 | Test the scan() function with the SR list passed 80 | in a .yaml file 81 | """ 82 | pivots_params.th_bounces = 0.05 83 | tb = TradeBot( 84 | pair='AUD_USD', 85 | timeframe='D', 86 | start='2016-01-05 22:00:00', 87 | end='2016-02-11 22:00:00') 88 | srlist_dict = load_config_yaml_file(f"{DATA_DIR}/harealist_file.yaml") 89 | outfile = tb.scan(prefix=f"{tmp_path}/", 90 | srlist=srlist_dict["AUD_USD"]) 91 | assert os.path.isfile(outfile) 92 | 93 | 94 | def test_scan_withclist(clO_pickled, tmp_path): 95 | """ 96 | Test the scan() method using a pickled 
CandleList 97 | """ 98 | pivots_params.th_bounces = 0.05 99 | tb = TradeBot( 100 | pair='AUD_USD', 101 | timeframe='D', 102 | start='2016-01-05 22:00:00', 103 | end='2016-02-11 22:00:00', 104 | clist=clO_pickled) 105 | outfile = tb.scan(prefix=f"{tmp_path}") 106 | assert os.path.isfile(outfile) 107 | 108 | 109 | def test_scan_withclist_future(clO_pickled, tmp_path): 110 | """ 111 | Test tradebot using a pickled CandleList and using an end TradeBot time 112 | post clO_pickled end time. This scan() invokation will not return any 113 | preTrade 114 | """ 115 | 116 | tb = TradeBot( 117 | pair='AUD_USD', 118 | timeframe='D', 119 | start='2020-11-15 22:00:00', 120 | end='2020-11-23 22:00:00', 121 | clist=clO_pickled) 122 | outfile = tb.scan(prefix=f"{tmp_path}") 123 | with pytest.raises(TypeError): 124 | assert os.path.isfile(outfile) 125 | 126 | 127 | def test_prepare_trades(clO_pickled, scan_pickled): 128 | """ 129 | Test the prepare_trades() method with a pickled list 130 | of preTrade objects 131 | """ 132 | tb = TradeBot( 133 | pair='AUD_USD', 134 | timeframe='D', 135 | start='2016-01-05 22:00:00', 136 | end='2016-02-11 22:00:00', 137 | clist=clO_pickled) 138 | tl = tb.prepare_trades(pretrades=scan_pickled) 139 | assert len(tl) == 5 or len(tl) == 4 140 | 141 | 142 | def test_prepare_trades_nextSR(clO_pickled, 143 | scan_pickled): 144 | """ 145 | Test tradebot using a pickled CandleList and 146 | tradebot_params.adj_SL = 'nextSR' 147 | """ 148 | tradebot_params.adj_SL = 'nextSR' 149 | 150 | tb = TradeBot( 151 | pair='AUD_USD', 152 | timeframe='D', 153 | start='2016-01-05 22:00:00', 154 | end='2016-02-11 22:00:00', 155 | clist=clO_pickled) 156 | tl = tb.prepare_trades(pretrades=scan_pickled) 157 | assert len(tl) == 5 or len(tl) == 4 158 | 159 | 160 | def test_prepare_trades_pips(clO_pickled, 161 | scan_pickled): 162 | """ 163 | Test tradebot using a pickled CandleList and 164 | tradebot_params.adj_SL = 'pips' 165 | """ 166 | tradebot_params.adj_SL = "pips" 167 | 168 | tb = TradeBot( 169 | pair="AUD_USD", 170 | timeframe="D", 171 | start="2016-01-05 22:00:00", 172 | end="2016-02-11 22:00:00", 173 | clist=clO_pickled) 174 | tl = tb.prepare_trades(pretrades=scan_pickled) 175 | assert len(tl) == 5 or len(tl) == 4 176 | -------------------------------------------------------------------------------- /tests/trade_bot/test_trade_bot_utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import pytest 3 | import numpy as np 4 | 5 | from forex.harea import HAreaList, HArea 6 | from trade_bot.trade_bot_utils import ( 7 | adjust_SL_pips, 8 | get_trade_type, 9 | adjust_SL_candles, 10 | adjust_SL_nextSR, 11 | ) 12 | 13 | 14 | @pytest.fixture 15 | def halist_factory(): 16 | hlist = [] 17 | for p in np.arange(0.610, 0.80, 0.020): 18 | area = HArea(price=p, pips=30, instrument="AUD_USD", granularity="D") 19 | hlist.append(area) 20 | 21 | halist = HAreaList(halist=hlist) 22 | return halist 23 | 24 | 25 | def test_adjust_SL_pips_short(clO_pickled): 26 | clObj = clO_pickled.candles[10] 27 | SL = adjust_SL_pips(clObj, "short", pair="AUD_USD") 28 | assert 0.9814 == SL 29 | 30 | 31 | def test_adjust_SL_pips_long(clO_pickled): 32 | clObj = clO_pickled.candles[100] 33 | SL = adjust_SL_pips(clObj, "long", pair="AUD_USD") 34 | assert 1.0026 == SL 35 | 36 | 37 | datetimes = [ 38 | ( 39 | datetime.datetime(2018, 4, 27, 22, 0, 0), 40 | datetime.datetime(2020, 4, 27, 21, 0, 0), 41 | "short", 42 | ), 43 | ( 44 | datetime.datetime(2018, 5, 18, 21, 0, 0), 45 | 
datetime.datetime(2020, 3, 18, 21, 0, 0), 46 | "long", 47 | ), 48 | ( 49 | datetime.datetime(2018, 6, 17, 21, 0, 0), 50 | datetime.datetime(2020, 1, 17, 21, 0, 0), 51 | "long", 52 | ), 53 | ( 54 | datetime.datetime(2018, 7, 11, 21, 0, 0), 55 | datetime.datetime(2019, 8, 11, 21, 0, 0), 56 | "long", 57 | ), 58 | ( 59 | datetime.datetime(2018, 1, 9, 21, 0, 0), 60 | datetime.datetime(2019, 1, 9, 21, 0, 0), 61 | "short", 62 | ), 63 | ] 64 | 65 | 66 | @pytest.mark.parametrize("start," "end," "type", datetimes) 67 | def test_get_trade_type(start, end, type, clO_pickled): 68 | new_cl = clO_pickled.slice(start=start, end=end) 69 | 70 | assert type == get_trade_type(end, new_cl) 71 | 72 | 73 | def test_adjust_SL_candles_short(clO_pickled): 74 | """Test adjust_SL_candles function with a short trade""" 75 | start = datetime.datetime(2018, 9, 2, 21, 0) 76 | end = datetime.datetime(2020, 9, 2, 21, 0) 77 | subClO = clO_pickled.slice(start=start, end=end) 78 | SL = adjust_SL_candles("short", subClO) 79 | 80 | assert SL == 0.74138 81 | 82 | 83 | def test_adjust_SL_candles_long(clO_pickled): 84 | """Test adjust_SL_candles function with a short trade""" 85 | start = datetime.datetime(2019, 9, 28, 21, 0) 86 | end = datetime.datetime(2020, 9, 28, 21, 0) 87 | subClO = clO_pickled.slice(start=start, end=end) 88 | SL = adjust_SL_candles("long", subClO) 89 | 90 | assert SL == 0.70061 91 | 92 | 93 | def test_adjust_SL_nextSR(halist_factory): 94 | SL, TP = adjust_SL_nextSR(halist_factory, 2, "short") 95 | assert SL == 0.67 96 | assert TP == 0.63 97 | -------------------------------------------------------------------------------- /tests/trading_journal/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from trading_journal.open_trade import UnawareTrade 4 | from trading_journal.trade_journal import TradeJournal 5 | from utils import DATA_DIR 6 | 7 | 8 | @pytest.fixture 9 | def t_object(clO_pickled): 10 | """Returns a UnawareTrade object""" 11 | 12 | td = UnawareTrade( 13 | start="2017-04-10 14:00:00", 14 | end="2017-04-26 14:00:00", 15 | entry=0.74960, 16 | TP=0.75592, 17 | SL=0.74718, 18 | SR=0.74784, 19 | pair="AUD_USD", 20 | type="long", 21 | timeframe="D", 22 | clist=clO_pickled, 23 | clist_tm=clO_pickled) 24 | return td 25 | 26 | 27 | @pytest.fixture 28 | def t_object_list(clO_pickled): 29 | """Returns a list of UnawareTrade objects""" 30 | 31 | td = UnawareTrade( 32 | start="2017-04-10 14:00:00", 33 | end="2017-04-26 14:00:00", 34 | entry=0.74960, 35 | TP=0.75592, 36 | SL=0.74718, 37 | SR=0.74784, 38 | pair="AUD_USD", 39 | type="long", 40 | timeframe="D", 41 | strat="counter_b1", 42 | id="AUD_USD 10APR2017H8", 43 | clist=clO_pickled, 44 | clist_tm=clO_pickled) 45 | return [td] 46 | 47 | 48 | @pytest.fixture 49 | def tjO(scope="session"): 50 | """Returns a trade_journal object for a Counter trade""" 51 | td = TradeJournal(url=DATA_DIR+"/testCounter.xlsx", 52 | worksheet="trading_journal") 53 | return td 54 | -------------------------------------------------------------------------------- /tests/trading_journal/data_for_tests.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | # AUD_USD - All of them entered (mix of success/failures) 4 | trades1 = [ 5 | ("2016-12-28 22:00:00", "long", 0.71857, 0.70814, 0.74267, 0.72193), 6 | ("2017-04-11 22:00:00", "long", 0.74692, 0.73898, 0.77351, 0.75277), 7 | ("2017-09-11 22:00:00", "short", 0.80451, 0.81550, 0.77864, 0.80079), 8 | ("2018-05-03 
22:00:00", "long", 0.75064, 0.74267, 0.77386, 0.75525), 9 | ] 10 | 11 | # ["start", "pair", "timeframe", "type", "SR", "SL", "entry"] 12 | trades_entered = [ 13 | ("2019-04-17 22:00:00", "AUD_USD", "D", "short", 0.71600, 0.72187, 0.71433), 14 | ("2019-05-19 22:00:00", "AUD_USD", "D", "long", 0.68755, 0.67735, 0.68970), 15 | ("2019-07-17 22:00:00", "AUD_USD", "D", "long", 0.69976, 0.69114, 0.70524), 16 | ("2019-08-25 22:00:00", "AUD_USD", "D", "long", 0.67499, 0.66530, 0.67766), 17 | ] 18 | 19 | # ["start", "pair", "timeframe", "type", "SR", "SL", "entry"] 20 | trades_for_test_run = [ 21 | ("2019-05-19 22:00:00", "AUD_USD", "H12", "long", 0.68755, 0.67735, 0.68970), 22 | ("2019-05-20 10:00:00", "AUD_USD", "H12", "long", 0.68755, 0.67735, 0.68970), 23 | ("2024-07-15 17:00:00", "AUD_USD", "H4", "short", 0.67810, 0.68092, 0.67634), 24 | ] 25 | 26 | 27 | last_times = [ 28 | datetime.datetime(2016, 2, 28, 22, 0), 29 | datetime.datetime(2017, 1, 9, 22, 0), 30 | datetime.datetime(2015, 5, 13, 21, 0), 31 | datetime.datetime(2017, 6, 4, 21, 0), 32 | ] 33 | 34 | start_hours = [ 35 | ( 36 | datetime.datetime(2023, 12, 9, 9, 1), 37 | datetime.datetime(2023, 12, 9, 9, 0), 38 | "H4"), 39 | ( 40 | datetime.datetime(2023, 12, 9, 14, 37), 41 | datetime.datetime(2023, 12, 9, 13, 0), 42 | "H4", 43 | ), 44 | ( 45 | datetime.datetime(2023, 12, 9, 0, 10), 46 | datetime.datetime(2023, 12, 8, 21, 0), 47 | "H8", 48 | ), 49 | ( 50 | datetime.datetime(2023, 12, 9, 20, 0), 51 | datetime.datetime(2023, 12, 9, 9, 0), 52 | "H12", 53 | ), 54 | ( 55 | datetime.datetime(2023, 12, 9, 20, 0), 56 | datetime.datetime(2023, 12, 8, 21, 0), 57 | "D"), 58 | ( 59 | datetime.datetime(2023, 12, 9, 23, 0), 60 | datetime.datetime(2023, 12, 9, 21, 0), 61 | "H4", 62 | ), 63 | ( 64 | datetime.datetime(2023, 7, 1, 6, 0), 65 | datetime.datetime(2023, 6, 30, 21, 0), 66 | "D"), 67 | ] 68 | -------------------------------------------------------------------------------- /tests/trading_journal/test_OpenTrade.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import datetime 3 | 4 | from trading_journal.open_trade import ( 5 | OpenTrade, 6 | BreakEvenTrade, 7 | AwareTrade, 8 | UnawareTrade, 9 | TrackingTrade, 10 | ) 11 | from trading_journal.trade_utils import check_timeframes_fractions 12 | from data_for_tests import trades1, trades_entered, trades_for_test_run 13 | from params import trade_management_params 14 | 15 | 16 | class TestOpenTrade: 17 | trade_features = { 18 | "start": "2016-12-28 22:00:00", 19 | "pair": "AUD_USD", 20 | "timeframe": "D", 21 | "type": "short", 22 | "SR": 0.71857, 23 | "SL": 0.70814, 24 | "entry": 0.72193, 25 | } 26 | 27 | def init_trades(self, mock_trades): 28 | features = ["start", "pair", "timeframe", "type", "SR", "SL", "entry"] 29 | trade_features_list = [] 30 | for trade_values in mock_trades: 31 | trade_features_list.append({key: value for key, value in zip(features, 32 | trade_values)}) 33 | return trade_features_list 34 | 35 | @pytest.mark.parametrize( 36 | "TP, RR, exp_instance", 37 | [(0.86032, None, OpenTrade), (None, 1.5, OpenTrade), (None, None, Exception)], 38 | ) 39 | def test_instantiation(self, TP, RR, exp_instance): 40 | if exp_instance is OpenTrade: 41 | OpenTrade(RR=RR, TP=TP, **self.trade_features, init_clist=False) 42 | elif exp_instance == Exception: 43 | with pytest.raises(Exception): 44 | OpenTrade(RR=RR, TP=TP, **self.trade_features, init_clist=False) 45 | 46 | def test_init_clist(self): 47 | trade_features_list = 
self.init_trades(mock_trades=trades_entered) 48 | td = OpenTrade( 49 | **trade_features_list[0], 50 | TP=0.74267, 51 | init_clist=True, 52 | ) 53 | assert len(td.clist.candles) == 4130 54 | 55 | results = [True, False, True, False] 56 | prices = [0.72632, 0.74277, 0.79, 0.74] 57 | trade_data1 = [(*trade, price) for trade, price in zip(trades1, prices)] 58 | trade_data2 = [(*trade, result) for trade, result in zip(trade_data1, results)] 59 | 60 | @pytest.mark.parametrize("start,type,SR,SL,TP,entry,price,result", trade_data2) 61 | def test_isin_profit(self, start, type, SR, SL, TP, entry, price, result): 62 | open_trade = OpenTrade( 63 | start=start, 64 | pair="AUD_USD", 65 | timeframe="D", 66 | type=type, 67 | SR=SR, 68 | SL=SL, 69 | TP=TP, 70 | entry=entry, 71 | init_clist=False, 72 | ) 73 | assert open_trade.isin_profit(price=price) is result 74 | 75 | def test_append_trademanagement_candles(self): 76 | open_trade = OpenTrade(RR=1.5, 77 | **self.trade_features, 78 | init_clist=True) 79 | open_trade.append_trademanagement_candles(d=datetime.datetime(2023, 1, 10, 21, 0, 0), 80 | fraction=3) 81 | expected_hours = [22, 6, 14] 82 | for i, expected_hour in enumerate(expected_hours): 83 | assert ( 84 | open_trade.preceding_candles[i].time.hour == expected_hour 85 | ), f"Expected hour {expected_hour}, but got {open_trade.preceding_candles[i].time.hour}" 86 | 87 | @pytest.mark.parametrize( 88 | "trade_feats, results", 89 | [ 90 | (("D", "H8", (datetime.datetime(2023, 1, 10, 21, 0, 0))), (0, 0.7081)), 91 | (("H12", "H8", (datetime.datetime(2023, 1, 10, 21, 0, 0))), (2, 0.7081)), 92 | (("H12", "H8", (datetime.datetime(2023, 1, 10, 9, 0, 0))), (2, 0.7081)) 93 | ], 94 | ) 95 | def test_process_trademanagement(self, trade_feats, results): 96 | """Tests process_trademanagent which is also invoking 'check_timeframes_fractions' 97 | to calculate the fraction of candles 98 | """ 99 | fraction = check_timeframes_fractions(timeframe1=trade_feats[0], 100 | timeframe2=trade_feats[1]) 101 | open_trade = OpenTrade(RR=1.5, 102 | **self.trade_features, 103 | init_clist=True) 104 | open_trade.process_trademanagement(d=trade_feats[2], 105 | fraction=fraction) 106 | assert len(open_trade.preceding_candles) == results[0] 107 | assert open_trade.SL.price == results[1] 108 | 109 | 110 | class TestBreakEvenTrade(TestOpenTrade): 111 | 112 | # trades outcome for run() function 113 | trades_outcome = [ 114 | ("success", 114.0), # outcome, pips 115 | ("failure", 10), 116 | ("failure", 10), 117 | ("failure", -40.0), 118 | ] 119 | 120 | @pytest.mark.parametrize( 121 | "TP, RR, exp_instance", 122 | [ 123 | (0.86032, None, BreakEvenTrade), 124 | (None, 1.5, BreakEvenTrade), 125 | (None, None, Exception), 126 | ], 127 | ) 128 | def test_instantiation(self, TP, RR, exp_instance): 129 | trade_features_list = self.init_trades(mock_trades=trades_entered) 130 | if exp_instance is BreakEvenTrade: 131 | BreakEvenTrade( 132 | RR=RR, TP=TP, **trade_features_list[0], init_clist=False 133 | ) 134 | elif exp_instance == Exception: 135 | with pytest.raises(Exception): 136 | BreakEvenTrade( 137 | RR=RR, TP=TP, **trade_features_list[0], 138 | init_clist=False 139 | ) 140 | 141 | def test_run(self, clOH8_2019_pickled, clO_pickled): 142 | trade_features_list = self.init_trades(mock_trades=trades_entered) 143 | for ix in range(len(trade_features_list)): 144 | breakeven_trade_object = BreakEvenTrade( 145 | **trade_features_list[ix], 146 | RR=1.5, 147 | clist=clO_pickled, 148 | clist_tm=clOH8_2019_pickled, 149 | connect=False, 150 | ) 151 | 
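            # connect=False makes run() resolve the trade from the pickled
            # CandleLists passed via 'clist' (entry/outcome candles) and
            # 'clist_tm' (trade-management candles) instead of querying the
            # OANDA API; initialise() first checks whether the entry price was
            # hit, then run() walks the candles to set 'outcome' and 'pips'.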
breakeven_trade_object.initialise() 152 | breakeven_trade_object.run() 153 | assert breakeven_trade_object.outcome == self.trades_outcome[ix][0] 154 | assert breakeven_trade_object.pips == self.trades_outcome[ix][1] 155 | 156 | 157 | class TestAwareTrade(TestOpenTrade): 158 | 159 | # trades outcome for run() function 160 | trades_outcome = [ 161 | ("success", 114.0), # outcome, pips 162 | ("failure", 18.0), 163 | ("failure", -141.0), 164 | ("failure", 26), 165 | ] 166 | 167 | @pytest.mark.parametrize( 168 | "TP, RR, exp_instance", 169 | [(0.86032, None, AwareTrade), 170 | (None, 1.5, AwareTrade), 171 | (None, None, Exception)], 172 | ) 173 | def test_instantiation(self, TP, RR, exp_instance): 174 | trade_features_list = self.init_trades(mock_trades=trades_entered) 175 | if exp_instance is AwareTrade: 176 | AwareTrade(RR=RR, TP=TP, **trade_features_list[0], 177 | init_clist=False) 178 | elif exp_instance == Exception: 179 | with pytest.raises(Exception): 180 | AwareTrade( 181 | RR=RR, TP=TP, **trade_features_list[0], 182 | init_clist=False 183 | ) 184 | 185 | def test_run(self, clOH8_2019_pickled, clO_pickled): 186 | trade_features_list = self.init_trades(mock_trades=trades_entered) 187 | for ix in range(len(trade_features_list)): 188 | aware_trade_object = AwareTrade( 189 | **trade_features_list[ix], 190 | RR=1.5, 191 | clist=clO_pickled, 192 | clist_tm=clOH8_2019_pickled, 193 | connect=False, 194 | ) 195 | aware_trade_object.initialise() 196 | aware_trade_object.run() 197 | assert aware_trade_object.outcome == self.trades_outcome[ix][0] 198 | assert aware_trade_object.pips == self.trades_outcome[ix][1] 199 | 200 | 201 | class TestUnawareTrade(TestOpenTrade): 202 | 203 | # trades outcome for test_run() function 204 | trades_outcome = [ 205 | ("success", 114.0), # outcome, pips 206 | ("failure", 18), 207 | ("failure", -24.0), 208 | ("failure", -40.0), 209 | ] 210 | 211 | # trades outcome for test_run1() function 212 | trades_outcome1 = [ 213 | ("n.a.", -15.6), # outcome, pips 214 | ("n.a.", -15.6), 215 | ("n.a.", 39) 216 | ] 217 | 218 | @pytest.mark.parametrize( 219 | "TP, RR, exp_instance", 220 | [ 221 | (0.86032, None, UnawareTrade), 222 | (None, 1.5, UnawareTrade), 223 | (None, None, Exception), 224 | ], 225 | ) 226 | def test_instantiation(self, TP, RR, exp_instance): 227 | trade_features_list = self.init_trades(mock_trades=trades_entered) 228 | if exp_instance is UnawareTrade: 229 | UnawareTrade(RR=RR, TP=TP, **trade_features_list[0], 230 | init_clist=False) 231 | elif exp_instance == Exception: 232 | with pytest.raises(Exception): 233 | UnawareTrade( 234 | RR=RR, TP=TP, **trade_features_list[0], 235 | init_clist=False 236 | ) 237 | 238 | def test_run(self, clOH8_2019_pickled, clO_pickled): 239 | trade_features_list = self.init_trades(mock_trades=trades_entered) 240 | for ix in range(len(trade_features_list)): 241 | unaware_trade_object = UnawareTrade( 242 | **trade_features_list[ix], 243 | RR=1.5, 244 | clist=clO_pickled, 245 | clist_tm=clOH8_2019_pickled, 246 | connect=False, 247 | ) 248 | unaware_trade_object.initialise() 249 | unaware_trade_object.run() 250 | assert unaware_trade_object.outcome == self.trades_outcome[ix][0] 251 | assert unaware_trade_object.pips == self.trades_outcome[ix][1] 252 | 253 | def test_run1(self): 254 | """Test run() with different timeframes and start of the trades. 
255 | Just to check how well the method behaves, also this test will not 256 | used the pickled clists 257 | """ 258 | trade_management_params.numperiods = 5 259 | trade_features_list = self.init_trades(mock_trades=trades_for_test_run) 260 | for ix in range(len(trade_features_list)): 261 | unaware_trade_object = UnawareTrade( 262 | **trade_features_list[ix], 263 | RR=1.5, 264 | init_clist=True 265 | ) 266 | unaware_trade_object.initialise() 267 | unaware_trade_object.run() 268 | assert unaware_trade_object.outcome == self.trades_outcome1[ix][0] 269 | assert unaware_trade_object.pips == self.trades_outcome1[ix][1] 270 | 271 | 272 | class TestTrackingTrade(TestOpenTrade): 273 | 274 | # trades outcome for run() function 275 | trades_outcome = [ 276 | ("success", 114.0, datetime.datetime(2019, 4, 23, 21, 0)), # outcome , pips, end datetime 277 | ("failure", 9.0, datetime.datetime(2019, 5, 28, 21, 0)), 278 | ("failure", -24.0, datetime.datetime(2019, 7, 22, 21, 0)), 279 | ("failure", -40.0, datetime.datetime(2019, 8, 27, 21, 0)), 280 | ] 281 | 282 | @pytest.mark.parametrize( 283 | "TP, RR, exp_instance", 284 | [ 285 | (0.86032, None, TrackingTrade), 286 | (None, 1.5, TrackingTrade), 287 | (None, None, Exception), 288 | ], 289 | ) 290 | def test_instantiation(self, TP, RR, exp_instance): 291 | trade_features_list = self.init_trades(mock_trades=trades_entered) 292 | if exp_instance is TrackingTrade: 293 | TrackingTrade(RR=RR, TP=TP, **trade_features_list[0], 294 | init_clist=False) 295 | elif exp_instance == Exception: 296 | with pytest.raises(Exception): 297 | TrackingTrade( 298 | RR=RR, TP=TP, **trade_features_list[0], 299 | init_clist=False 300 | ) 301 | 302 | def test_run(self, clOH8_2019_pickled, clO_pickled): 303 | trade_features_list = self.init_trades(mock_trades=trades_entered) 304 | for ix in range(len(trade_features_list)): 305 | tracking_trade_object = TrackingTrade( 306 | **trade_features_list[ix], 307 | RR=1.5, 308 | clist=clO_pickled, 309 | clist_tm=clOH8_2019_pickled, 310 | connect=False, 311 | ) 312 | tracking_trade_object.initialise() 313 | tracking_trade_object.run() 314 | assert tracking_trade_object.outcome == self.trades_outcome[ix][0] 315 | assert tracking_trade_object.pips == self.trades_outcome[ix][1] 316 | assert tracking_trade_object.end == self.trades_outcome[ix][2] 317 | -------------------------------------------------------------------------------- /tests/trading_journal/test_Trade.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import datetime 3 | 4 | from trading_journal.open_trade import Trade 5 | from data_for_tests import trades1, last_times 6 | 7 | 8 | trade_data1 = [(*trade, alast_time) for trade, alast_time in zip(trades1, last_times)] 9 | 10 | 11 | @pytest.mark.parametrize("start,type,SR,SL,TP,entry,lasttime", trade_data1) 12 | def test_get_lasttime(start, type, SR, SL, TP, entry, lasttime, clO_pickled): 13 | """Check function get_lasttime""" 14 | t = Trade( 15 | id="test", 16 | start=start, 17 | pair="AUD_USD", 18 | timeframe="D", 19 | type=type, 20 | SR=SR, 21 | SL=SL, 22 | TP=TP, 23 | entry=entry, 24 | clist=clO_pickled, 25 | ) 26 | 27 | assert t.get_lasttime() == lasttime 28 | 29 | 30 | @pytest.mark.parametrize( 31 | "start," "type," "SR," "SL," "TP," "entry," "lasttime", 32 | [ 33 | ( 34 | "2017-12-10 22:00:00", 35 | "long", 36 | 0.74986, 37 | 0.74720, 38 | 0.76521, 39 | 0.75319, 40 | datetime.datetime(2017, 6, 1, 21, 0), 41 | ), 42 | ( 43 | "2017-03-21 22:00:00", 44 | "short", 45 | 0.77103, 46 
| 0.77876, 47 | 0.73896, 48 | 0.76717, 49 | datetime.datetime(2016, 4, 19, 21, 0), 50 | ), 51 | ], 52 | ) 53 | def test_get_lasttime_with_pad(start, type, SR, SL, TP, entry, lasttime, 54 | clO_pickled): 55 | """Check function get_lasttime""" 56 | t = Trade( 57 | id="test", 58 | start=start, 59 | pair="AUD_USD", 60 | timeframe="D", 61 | type=type, 62 | SR=SR, 63 | SL=SL, 64 | TP=TP, 65 | entry=entry, 66 | clist=clO_pickled, 67 | ) 68 | 69 | assert t.get_lasttime(pad=30) == lasttime 70 | 71 | 72 | # add a non initialised trade 73 | trades_for_test_initialise = trades1[:] 74 | trades_for_test_initialise.append( 75 | ("2018-12-23 22:00:00", "long", 0.70112, 0.69637, 0.72756, 0.70895) 76 | ) 77 | entry_data = [ 78 | (True, "2016-12-29T08:00:00"), 79 | (True, "2017-04-12T19:00:00"), 80 | (True, "2017-09-12T01:00:00"), 81 | (True, "2018-05-04T01:00:00"), 82 | (False, "n.a."), 83 | ] 84 | trade_data2 = [ 85 | (*trade, pip) for trade, pip in zip(trades_for_test_initialise, entry_data) 86 | ] 87 | 88 | 89 | @pytest.mark.parametrize("start,type,SR,SL,TP,entry,entry_data", trade_data2) 90 | def test_initialise(start, type, SR, SL, TP, entry, entry_data, clO_pickled): 91 | t = Trade( 92 | id="test", 93 | start=start, 94 | pair="AUD_USD", 95 | timeframe="D", 96 | type=type, 97 | SR=SR, 98 | SL=SL, 99 | TP=TP, 100 | entry=entry, 101 | clist=clO_pickled, 102 | clist_tm=clO_pickled, 103 | ) 104 | t.initialise() 105 | assert t.entered == entry_data[0] 106 | if hasattr(t, "entry_time"): 107 | assert t.entry_time == entry_data[1] 108 | 109 | 110 | trades_for_entry_onrsi = trades1[:] 111 | trades_for_entry_onrsi.append( 112 | ("2017-07-25 22:00:00", "short", 0.79743, 0.80577, 0.77479, 0.79343) 113 | ) 114 | 115 | is_on_rsi = [False, False, False, False, True] 116 | trade_data3 = [(*trade, pip) for trade, pip in zip(trades_for_entry_onrsi, 117 | is_on_rsi)] 118 | 119 | 120 | @pytest.mark.parametrize("start,type,SR,SL,TP,entry,entry_onrsi", trade_data3) 121 | def test_is_entry_onrsi(start, type, SR, SL, TP, entry, entry_onrsi, 122 | clO_pickled): 123 | """Test is_entry_onrsi function""" 124 | t = Trade( 125 | id="test", 126 | start=start, 127 | pair="AUD_USD", 128 | timeframe="D", 129 | type=type, 130 | SR=SR, 131 | SL=SL, 132 | TP=TP, 133 | entry=entry, 134 | clist=clO_pickled, 135 | clist_tm=clO_pickled, 136 | ) 137 | newclist = t.clist 138 | newclist.calc_rsi() 139 | t.clist = newclist 140 | assert entry_onrsi == t.is_entry_onrsi() 141 | -------------------------------------------------------------------------------- /tests/trading_journal/test_TradeJournal.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from trading_journal.trade_journal import TradeJournal 4 | 5 | 6 | def test_fetch_trades(tjO): 7 | tlist = tjO.fetch_trades() 8 | 9 | assert len(tlist) == 4 10 | 11 | 12 | def test_win_rate(tjO): 13 | (number_s, number_f, tot_pips) = tjO.win_rate(strats="counter") 14 | 15 | assert number_s == 2 16 | assert number_f == 1 17 | assert tot_pips == 275.0 18 | 19 | 20 | def test_write_tradelist(t_object_list, tmp_path): 21 | td = TradeJournal(url=f"{tmp_path}/testCounter1.xlsx", 22 | worksheet="trading_journal") 23 | 24 | td.write_tradelist(t_object_list, "outsheet") 25 | 26 | assert os.path.exists(f"{tmp_path}/testCounter1.xlsx") == 1 27 | -------------------------------------------------------------------------------- /tests/trading_journal/test_trade_utils.py: -------------------------------------------------------------------------------- 1 | import 
pytest 2 | 3 | from forex.candle import Candle 4 | from trading_journal.trade_utils import ( 5 | get_closest_hour, 6 | process_start, 7 | adjust_SL, 8 | check_timeframes_fractions) 9 | from data_for_tests import start_hours 10 | 11 | hour_data = [(9, "H8", 5), (21, "H8", 21), (17, "H8", 13)] 12 | 13 | 14 | @pytest.mark.parametrize("solve_hour,timeframe,closest_hour", hour_data) 15 | def test_get_closest_hour(solve_hour, timeframe, closest_hour): 16 | """Test the 'get_closest_hour' function""" 17 | assert get_closest_hour(timeframe=timeframe, solve_hour=solve_hour) == closest_hour 18 | 19 | 20 | @pytest.mark.parametrize("start,returned,timeframe", start_hours) 21 | def test_process_start(start, returned, timeframe): 22 | aligned_start = process_start(dt=start, timeframe=timeframe) 23 | assert aligned_start == returned 24 | 25 | 26 | fraction_data = [("D", "H8", 3.0), ("H12", "H8", 1.5), ("H8", "H4", 2.0)] 27 | 28 | 29 | @pytest.mark.parametrize("timeframe1,timeframe2,fraction", fraction_data) 30 | def test_check_timeframes_fractions(timeframe1, timeframe2, fraction): 31 | res_fraction = check_timeframes_fractions(timeframe1=timeframe1, 32 | timeframe2=timeframe2) 33 | assert res_fraction == fraction 34 | 35 | 36 | def test_get_SLdiff(t_object): 37 | assert 24.0 == t_object.get_SLdiff() 38 | 39 | 40 | # lisf of lists containing tuples, where each tuple is composed of a candle high and low 41 | high_low_candles = [ 42 | [(0.80, 0.70), (0.82, 0.70), (0.79, 0.70)], 43 | [(0.90, 0.70), (0.95, 0.65), (0.97, 0.72)], 44 | ] 45 | # trade types for each sublist in 'high_low_candles' 46 | trade_types = ["short", "long"] 47 | # adjusted SL prices 48 | sl_adjusted = [0.821, 0.649] 49 | 50 | 51 | @pytest.fixture 52 | def mock_candle_list(mocker): 53 | """Creates a list of lists, each sublist containing 3 mocked Candle objects""" 54 | tri_candle_list = list() 55 | for tri_candle in high_low_candles: 56 | candle_list = list() 57 | for high_low in tri_candle: 58 | mock_Candle_instrance = mocker.MagicMock(spec=Candle) 59 | mocker.patch.object(mock_Candle_instrance, "h", high_low[0]) 60 | mocker.patch.object(mock_Candle_instrance, "l", high_low[1]) 61 | candle_list.append(mock_Candle_instrance) 62 | tri_candle_list.append(candle_list) 63 | return tri_candle_list 64 | 65 | 66 | def test_adjust_sl(mock_candle_list): 67 | """Test 'adjust_sl' function""" 68 | 69 | for ix in range(len(mock_candle_list)): 70 | tri_candle = mock_candle_list[ix] 71 | new_SL = adjust_SL( 72 | pair="AUD_USD", type=trade_types[ix], list_candles=tri_candle 73 | ) 74 | assert sl_adjusted[ix] == new_SL 75 | -------------------------------------------------------------------------------- /trade_bot/trade_bot.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pickle 3 | import re 4 | 5 | from datetime import datetime, timedelta 6 | from typing import List 7 | from api.oanda.connect import Connect 8 | from forex.candle import CandleList, Candle 9 | from forex.harea import HAreaList 10 | from params import gparams, tradebot_params, pivots_params 11 | from forex.pivot import PivotList 12 | from forex.candlelist_utils import calc_SR 13 | from trading_journal.trade import Trade 14 | from trade_bot.trade_bot_utils import (adjust_SL_candles, 15 | adjust_SL_nextSR, 16 | get_trade_type, 17 | prepare_trade) 18 | from trade_bot.trade_bot_utils import adjust_SL_pips 19 | from dataclasses import dataclass 20 | from utils import try_parsing_date, periodToDelta 21 | 22 | # create logger 23 | 
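# Minimal usage sketch (mirrors tests/trade_bot/test_TradeBot.py; the output
# prefix below is illustrative only):
#
#     tb = TradeBot(pair="AUD_USD", timeframe="D",
#                   start="2016-01-05 22:00:00", end="2016-02-11 22:00:00")
#     pretrades_file = tb.scan(prefix="/tmp/pretrades")  # pickles preTrade objects
#     if pretrades_file:
#         trades = tb.prepare_trades(pretrades=pretrades_file)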
tb_logger = logging.getLogger(__name__) 24 | tb_logger.setLevel(logging.INFO) 25 | 26 | 27 | @dataclass 28 | class preTrade: 29 | """Container for a Candle falling on a HArea, which can potentially become 30 | a trade""" 31 | sel_ix: int 32 | SRlst: HAreaList 33 | candle: Candle 34 | type: str 35 | 36 | 37 | class TradeBot(object): 38 | '''This class represents an automatic Trading bot. 39 | 40 | Class variables: 41 | start: datetime that this Bot will start operating. 42 | i.e. 20-03-2017 08:20:00s 43 | end: datetime that this Bot will end operating. 44 | i.e. 20-03-2020 08:20:00s 45 | pair: Currency pair used in the trade. i.e. AUD_USD 46 | timeframe: Timeframe used for the trade. Possible values are: 47 | D,H12,H10,H8,H4,H1 48 | clist: CandleList object used to represent this trade 49 | ''' 50 | __slots__ = ["start", "end", "pair", "timeframe", "clist", 51 | "delta_period", "delta"] 52 | 53 | def __init__(self, start: datetime, end: datetime, pair: str, 54 | timeframe: str, clist: CandleList = None): 55 | self.start = try_parsing_date(start) 56 | self.end = try_parsing_date(end) 57 | self.pair = pair 58 | self.timeframe = timeframe 59 | self.clist = clist 60 | self.delta_period = periodToDelta(tradebot_params.period_range, 61 | self.timeframe) 62 | if self.timeframe == "D": 63 | self.delta = timedelta(hours=24) 64 | else: 65 | p1 = re.compile('^H') 66 | m1 = p1.match(self.timeframe) 67 | if m1: 68 | nhours = int(self.timeframe.replace('H', '')) 69 | self.delta = timedelta(hours=nhours) 70 | if not clist: 71 | self.init_clist() 72 | else: 73 | if clist.candles[-1].time < self.end: 74 | logging.warning(f"Tradebot end:{self.end} is greater than " 75 | f"clist end: {clist.candles[-1].time}") 76 | self.init_clist() 77 | 78 | def init_clist(self) -> None: 79 | """Init clist for this TradeBot""" 80 | 81 | conn = Connect( 82 | instrument=self.pair, 83 | granularity=self.timeframe) 84 | if tradebot_params.period_range > 4899: 85 | tradebot_params.period_range = 4899 86 | initc_date = self.start-self.delta_period 87 | 88 | clO = conn.query(initc_date.isoformat(), self.end.isoformat()) 89 | self.clist = clO 90 | 91 | def scan(self, prefix: str = 'pretrades', discard_sat: bool = True, 92 | srlist: list[float] = None) -> str: 93 | """This function will scan for candles on S/R areas. 94 | These candles will be written to a .csv file 95 | 96 | Arguments: 97 | prefix: Prefix for pickled file with a list of preTrade objects 98 | discard_sat: If True, then the Trade wil not 99 | be taken if IC falls on a Saturday 100 | srlist: If passed, then do not run calc_SR and use the ones in the 101 | list passed 102 | 103 | Returns: 104 | Path to file containing a list of preTrade objects 105 | """ 106 | tb_logger.info("Running...") 107 | if srlist: 108 | srlist = HAreaList.from_list(srlist, 109 | instrument=self.pair, 110 | granularity=self.timeframe) 111 | 112 | pretrades = [] 113 | startO = self.start 114 | loop = 0 115 | while startO <= self.end: 116 | tb_logger.debug(f"Trade bot - analyzing candle: {startO.isoformat()}") 117 | # Get now a CandleList from 'initc_date' to 'self.start' 118 | # which is the total time interval for this TradeBot 119 | initc_date = startO-self.delta_period 120 | if loop == 0 or loop >= tradebot_params.period: 121 | subclO = self.clist.slice(initc_date, startO) 122 | sub_pvtlst = PivotList(clist=subclO) 123 | if pivots_params.plot is True: 124 | dt_str = startO.strftime("%d_%m_%Y_%H_%M") 125 | outfile_png = (f"{gparams.outdir}/{self.pair}." 
126 | f"{self.timeframe}.{dt_str}.halist.png") 127 | # print SR report to file 128 | outfile_txt = (f"{gparams.outdir}/{self.pair}." 129 | f"{self.timeframe}.{dt_str}.halist.txt") 130 | if not srlist: 131 | SRlst = calc_SR(sub_pvtlst, outfile=outfile_png) 132 | else: 133 | SRlst = srlist 134 | res = SRlst.print() 135 | f = open(outfile_txt, 'w') 136 | f.write(res+"\n") 137 | f.close() 138 | else: 139 | if not srlist: 140 | SRlst = calc_SR(sub_pvtlst) 141 | else: 142 | SRlst = srlist 143 | res = SRlst.print() 144 | tb_logger.info("Identified HAreaList for" 145 | f"time:{startO.isoformat()}") 146 | tb_logger.info(f"{res}") 147 | loop = 0 148 | 149 | # Fetch candle for current datetime. this is the current candle 150 | # that is being checked 151 | c_candle = self.clist[startO] 152 | if c_candle is None: 153 | startO = startO+self.delta 154 | loop += 1 155 | continue 156 | 157 | # c_candle.time is not equal to startO 158 | # when startO is non-working day, for example 159 | delta1hr = timedelta(hours=1) 160 | if (c_candle.time != startO) and \ 161 | (abs(c_candle.time-startO) > delta1hr): 162 | loop += 1 163 | tb_logger.info(f"Analysed dt {startO} is not the same than " 164 | f"APIs returned dt {c_candle.time}." 165 | " Skipping...") 166 | startO = startO + self.delta 167 | continue 168 | 169 | # check if there is any HArea overlapping with c_candle 170 | HAreaSel, sel_ix = SRlst.onArea(candle=c_candle) 171 | if HAreaSel is not None: 172 | # guess the if trade is 'long' or 'short' 173 | newCl = self.clist.slice(start=initc_date, end=c_candle.time) 174 | type = get_trade_type(c_candle.time, newCl) 175 | 176 | prepare = False 177 | if c_candle.indecision_c(ic_perc=gparams.ic_perc) is True and \ 178 | len(SRlst.halist) >= 3 and \ 179 | c_candle.height(pair=self.pair) \ 180 | < tradebot_params.max_height: 181 | prepare = True 182 | elif type == 'short' and c_candle.colour == 'red' and \ 183 | len(SRlst.halist) >= 3 and \ 184 | c_candle.height(pair=self.pair) < \ 185 | tradebot_params.max_height: 186 | prepare = True 187 | elif type == 'long' and c_candle.colour == 'green' and \ 188 | len(SRlst.halist) >= 3 and \ 189 | c_candle.height(pair=self.pair) < \ 190 | tradebot_params.max_height: 191 | prepare = True 192 | 193 | # discard if IC falls on a Saturday 194 | if c_candle.time.weekday() == 5 and discard_sat is True: 195 | tb_logger.info(f"Possible trade at {c_candle.time} " 196 | f"falls on Sat. 
Skipping...") 197 | prepare = False 198 | 199 | if prepare is True: 200 | pretrades.append(preTrade(sel_ix=sel_ix, 201 | SRlst=SRlst, 202 | candle=c_candle, 203 | type=type)) 204 | 205 | startO = startO+self.delta 206 | loop += 1 207 | 208 | if pretrades: 209 | with open(f"{prefix}.pckl", "wb") as f: 210 | pickle.dump(pretrades, f) 211 | return f"{prefix}.pckl" 212 | 213 | tb_logger.info("Run done") 214 | 215 | def prepare_trades(self, pretrades: str) -> List[Trade]: 216 | """Unpickle the preTrade objects 217 | identified by self.scan() and create a list of Trade objects. 218 | 219 | Arguments: 220 | pretrades: Pickled file with a list of preTrade objects 221 | """ 222 | TP = None 223 | tlist = [] 224 | with open(pretrades, "rb") as f: 225 | pret_lst = pickle.load(f) 226 | for pret in pret_lst: 227 | initc_date = pret.candle.time-self.delta_period 228 | newCl = self.clist.slice(start=initc_date, 229 | end=pret.candle.time) 230 | if tradebot_params.adj_SL == "candles": 231 | SL = adjust_SL_candles(pret.type, newCl) 232 | elif tradebot_params.adj_SL == "pips": 233 | SL = adjust_SL_pips(pret.candle, 234 | pret.type, 235 | pair=self.pair, 236 | no_pips=tradebot_params.adj_SL_pips) 237 | else: 238 | SL, TP = adjust_SL_nextSR(pret.SRlst, 239 | pret.sel_ix, 240 | pret.type) 241 | if not SL: 242 | SL = adjust_SL_pips(pret.candle, 243 | pret.type, 244 | pair=self.pair, 245 | no_pips=tradebot_params.adj_SL_pips) 246 | t = prepare_trade( 247 | tb_obj=self, 248 | start=pret.candle.time+self.delta, 249 | type=pret.type, 250 | ic=pret.candle, 251 | SL=SL, 252 | TP=TP, 253 | harea_sel=pret.SRlst.halist[pret.sel_ix], 254 | add_pips=tradebot_params.add_pips) 255 | t.tot_SR = len(pret.SRlst.halist) 256 | t.rank_selSR = pret.sel_ix 257 | tlist.append(t) 258 | return tlist 259 | -------------------------------------------------------------------------------- /trade_bot/trade_bot_utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | from typing import Tuple 4 | 5 | from forex.harea import HAreaList 6 | from forex.candle import Candle, CandleList 7 | from trading_journal.trade import Trade 8 | from forex.pivot import PivotList 9 | from params import trade_params, tradebot_params 10 | from utils import add_pips2price, substract_pips2price 11 | 12 | # create logger 13 | t_logger = logging.getLogger(__name__) 14 | t_logger.setLevel(logging.INFO) 15 | 16 | 17 | def adjust_SL_pips(candle: Candle, 18 | type: str, pair: str, 19 | no_pips: int = 100) -> float: 20 | """Function to adjust the SL price 21 | by a fixed number of pips beyond the candle's high/low. 22 | 23 | Arguments: 24 | candle : Candle object for which SL will be adjusted 25 | type : Trade type ('long'/ 'short'). 26 | pair: Pair 27 | no_pips: Number of pips added to the candle's high (short) or 28 | subtracted from its low (long) to calculate the S/L value. 29 | 30 | Returns: 31 | adjusted SL 32 | """ 33 | if type == "long": 34 | price = candle.l 35 | SL = substract_pips2price(pair, price, no_pips) 36 | else: 37 | price = candle.h 38 | SL = add_pips2price(pair, price, no_pips) 39 | 40 | return SL 41 | 42 | 43 | def get_trade_type(dt, clObj: CandleList) -> str: 44 | """Function to get the type of a Trade (short/long). 
45 | 46 | Arguments: 47 | dt : datetime object 48 | This will be the datetime for the IC candle 49 | clObj : CandleList used for calculation 50 | 51 | Returns: 52 | type (short/long) 53 | """ 54 | if dt != clObj.candles[-1].time: 55 | dt = clObj.candles[-1].time 56 | 57 | # increate sensitivity for pivot detection 58 | PL = PivotList(clObj, th_bounces=trade_params.th_bounces) 59 | 60 | # now, get the Pivot matching the datetime for the IC+1 candle 61 | if PL.pivots[-1].candle.time != dt: 62 | raise Exception("Last pivot time does not match the passed datetime") 63 | # check the 'type' of Pivot.pre segment 64 | direction = PL.pivots[-1].pre.type 65 | 66 | if direction == -1: 67 | return "long" 68 | elif direction == 1: 69 | return "short" 70 | else: 71 | raise Exception("Could not guess the file type") 72 | 73 | 74 | def prepare_trade(tb_obj, start: datetime, type: str, SL: float, ic: Candle, 75 | harea_sel, add_pips: int = None, TP: float = None) -> Trade: 76 | """Prepare a Trade object 77 | 78 | Arguments: 79 | tb_obj : TradeBot object 80 | start : Start datetime for the trade 81 | type : Type of trade. ['short','long'] 82 | SL : Adjusted (by '__get_trade_type') SL price 83 | ic: Indecission candle 84 | harea_sel : HArea of this trade 85 | add_pips : Number of pips above/below SL and entry 86 | price to consider for recalculating 87 | the SL and entry 88 | TP : Take profit value 89 | """ 90 | if type == "short": 91 | # entry price will be the low of IC 92 | entry_p = ic.l 93 | if add_pips is not None: 94 | SL = round(add_pips2price(tb_obj.pair, 95 | SL, add_pips), 4) 96 | entry_p = round(substract_pips2price(tb_obj.pair, 97 | entry_p, add_pips), 4) 98 | elif type == "long": 99 | # entry price will be the high of IC 100 | entry_p = ic.h 101 | if add_pips is not None: 102 | entry_p = add_pips2price(tb_obj.pair, 103 | entry_p, add_pips) 104 | SL = substract_pips2price(tb_obj.pair, 105 | SL, add_pips) 106 | 107 | t = Trade( 108 | id='{0}.bot'.format(tb_obj.pair), 109 | start=start.strftime('%Y-%m-%d %H:%M:%S'), 110 | pair=tb_obj.pair, 111 | timeframe=tb_obj.timeframe, 112 | type=type, 113 | entry=entry_p, 114 | SR=harea_sel.price, 115 | SL=SL, 116 | TP=TP, 117 | RR=tradebot_params.RR) 118 | return t 119 | 120 | 121 | def adjust_SL_nextSR(SRlst: HAreaList, 122 | sel_ix: int, 123 | type: str) -> Tuple[float, float]: 124 | """Function to calculate the TP and SL prices to the next SR areas""" 125 | TP, SL = None, None 126 | try: 127 | if type == "long": 128 | SL = SRlst.halist[sel_ix-1].price 129 | TP = SRlst.halist[sel_ix+1].price 130 | if sel_ix-1 < 0: 131 | SL = None 132 | else: 133 | TP = SRlst.halist[sel_ix-1].price 134 | SL = SRlst.halist[sel_ix+1].price 135 | if sel_ix-1 < 0: 136 | TP = None 137 | except Exception: 138 | t_logger.warning(f"sel_ix error: {sel_ix}. Trying with adjust_SL_pips") 139 | 140 | return SL, TP 141 | 142 | 143 | def adjust_SL_candles(type: str, clObj: CandleList, number: int = 7) -> float: 144 | """Function to adjust the SL price 145 | to the most recent highest high/lowest low. 146 | 147 | Arguments: 148 | type : Trade type ('long'/ 'short') 149 | clObj : CandleList obj 150 | number : Number of candles to go back 151 | to adjust the SL. 152 | 153 | Returns: 154 | adjusted SL 155 | """ 156 | SL, ix = None, 0 157 | if not clObj.candles: 158 | raise Exception("No candles in CandleList. 
Can't calculate the SL") 159 | for c in reversed(clObj.candles): 160 | # go back 'number' candles 161 | if ix == number: 162 | break 163 | ix += 1 164 | if type == "short": 165 | if SL is None: 166 | SL = c.h 167 | elif c.h > SL: 168 | SL = c.h 169 | if type == "long": 170 | if SL is None: 171 | SL = c.l 172 | if c.l < SL: 173 | SL = c.l 174 | return SL 175 | -------------------------------------------------------------------------------- /trading_journal/constants.py: -------------------------------------------------------------------------------- 1 | # Allowed Trade class attribures 2 | ALLOWED_ATTRBS = ["entered", "start", "end", "pair", 3 | "timeframe", "outcome", "exit", "entry_time", "type", 4 | "SR", "RR", "pips", "clist", "clist_tm", "strat", 5 | "tot_SR", "rank_selSR"] 6 | 7 | # 'area_unaware': exit when candles against the trade without 8 | # considering if price is > or < than entry price 9 | VALID_TYPES = ["area_unaware", "area_aware"] 10 | -------------------------------------------------------------------------------- /trading_journal/open_trade.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | from datetime import datetime, timedelta 5 | 6 | from trading_journal.trade import Trade 7 | from api.oanda.connect import Connect 8 | from forex.harea import HArea 9 | from forex.candle import Candle 10 | from utils import ( 11 | periodToDelta, 12 | calculate_profit, 13 | add_pips2price, 14 | substract_pips2price, 15 | is_even_hour, 16 | is_week_day) 17 | from trading_journal.trade_utils import ( 18 | gen_datelist, 19 | check_candle_overlap, 20 | process_start, 21 | adjust_SL, 22 | check_timeframes_fractions, 23 | ) 24 | from params import trade_management_params 25 | 26 | t_logger = logging.getLogger(__name__) 27 | t_logger.setLevel(logging.INFO) 28 | 29 | 30 | class OpenTrade(Trade): 31 | """An open Trade (i.e. entered)""" 32 | 33 | def __init__(self, candle_number: int = 3, connect: bool = True, **kwargs): 34 | """ 35 | Arguments: 36 | candle_number: number of candles against the trade to consider 37 | connect: If True then it will use the API to fetch candles 38 | completed: Is this Trade completed? 39 | preceding_candles: List with CandleList to check if it 40 | goes against the trade 41 | """ 42 | self.candle_number = candle_number 43 | self.connect = connect 44 | self.completed = False # is this OpenTrade completed 45 | self.preceding_candles = list() 46 | super().__init__(**kwargs) 47 | 48 | def append_trademanagement_candles(self, d: datetime, fraction: float) -> None: 49 | """Append the trademanagement candles to self.preceding_candles. 
50 | 51 | This method will add the number of candles to self.preceding_candles 52 | depending on fraction 53 | 54 | Arguments: 55 | d: datetime that will be used to start adding candles 56 | fraction: Fraction of times one timeframe is within the other 57 | """ 58 | trade_management_timeframe = trade_management_params.clisttm_tf 59 | 60 | delta = periodToDelta(ncandles=1, 61 | timeframe=trade_management_timeframe) 62 | 63 | if fraction < 1: 64 | fraction = 1 65 | fraction = math.ceil(fraction) 66 | for ix in range(int(fraction)): 67 | new_datetime = d + delta * ix 68 | cl_tm = self.clist_tm[new_datetime] 69 | if cl_tm is None: 70 | if self.connect is True: 71 | conn = Connect( 72 | instrument=self.pair, 73 | granularity=trade_management_timeframe 74 | ) 75 | cl_tm = conn.fetch_candle(d=new_datetime) 76 | if cl_tm is not None: 77 | if cl_tm not in self.preceding_candles: 78 | self.preceding_candles.append(cl_tm) 79 | 80 | # slice to self.candle_number if more than this number 81 | # and remove the oldest candle 82 | if len(self.preceding_candles) > self.candle_number: 83 | self.preceding_candles = self.preceding_candles[(self.candle_number) * -1:] 84 | 85 | def check_if_against(self) -> bool: 86 | """Function to check if middle_point values are 87 | agaisnt the trade 88 | """ 89 | prices = [x.middle_point() for x in self.preceding_candles] 90 | if self.type == "long": 91 | return all(prices[i] > prices[i + 1] for i in range(len(prices) - 1)) 92 | else: 93 | return all(prices[i] < prices[i + 1] for i in range(len(prices) - 1)) 94 | 95 | def calculate_overlap(self, cl: Candle) -> None: 96 | """Check if 'cl' overlaps either self.SL or self.TP. 97 | 98 | It also sets the outcome 99 | """ 100 | if check_candle_overlap(cl, self.SL.price): 101 | t_logger.info("Sorry, SL was hit!") 102 | self.completed = True 103 | self.outcome = "failure" 104 | elif check_candle_overlap(cl, self.TP.price): 105 | t_logger.info("Great, TP was hit!") 106 | self.completed = True 107 | self.outcome = "success" 108 | 109 | def end_trade(self, cl: Candle, harea: HArea) -> None: 110 | """End trade""" 111 | end = None 112 | if self.connect is True: 113 | end = harea.get_cross_time(candle=cl, 114 | granularity=trade_management_params.granularity) 115 | else: 116 | end = cl.time 117 | self.end = end 118 | self.exit = harea.price 119 | 120 | def finalise_trade(self, cl: Candle) -> None: 121 | """Finalise trade by setting the outcome and calculating profit""" 122 | if self.outcome == "success": 123 | price1 = self.TP.price 124 | self.end_trade(cl=cl, harea=self.TP) 125 | if self.outcome == "failure": 126 | price1 = self.SL.price 127 | self.end_trade(cl=cl, harea=self.SL) 128 | if self.outcome == "n.a.": 129 | price1 = cl.c 130 | self.end = "n.a." 131 | self.exit = price1 132 | if self.outcome == "future": 133 | self.end = "n.a." 134 | self.pips = "n.a." 135 | self.exit = "n.a." 
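        # 'price1' is the exit-level price: TP on success, SL on failure, or the
        # last close when no outcome was reached. calculate_profit(prices=(price1,
        # entry), ...) returns the signed pip difference; e.g. in tests/test_utils.py
        # prices=(0.69200, 0.68750) give +45 pips for a long AUD_USD trade and
        # -45 pips for a short one.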
136 | if self.outcome != "future": 137 | self.pips = calculate_profit( 138 | prices=(price1, self.entry.price), type=self.type, pair=self.pair 139 | ) 140 | 141 | def fetch_candle(self, d: datetime) -> Candle: 142 | """Fetch a Candle object given a datetime""" 143 | cl = None 144 | cl = self.clist[d] 145 | if not is_week_day(d): 146 | return None 147 | if cl is None: 148 | if self.connect is True: 149 | conn = Connect(instrument=self.pair, 150 | granularity=self.timeframe) 151 | cl = conn.fetch_candle(d=d) 152 | if cl is None: 153 | if is_even_hour(d): 154 | d = d - timedelta(seconds=3600) 155 | else: 156 | d = d + timedelta(seconds=3600) 157 | cl = conn.fetch_candle(d=d) 158 | return cl 159 | 160 | def isin_profit(self, price: float) -> bool: 161 | """Is price in profit?. 162 | 163 | Argument: 164 | price: price to check 165 | """ 166 | if self.type == "long": 167 | if price >= self.entry.price: 168 | return True 169 | if self.type == "short": 170 | if price <= self.entry.price: 171 | return True 172 | return False 173 | 174 | def process_trademanagement( 175 | self, d: datetime, fraction: float, check_against: bool = True 176 | ): 177 | """Process trademanagement candles and ajust 'SL' if required. 178 | 179 | Args: 180 | d: datetime for the Candle that is being analysed 181 | fraction: Number of times timeframe1 is contained in timeframe2 182 | check_against: Check if the trade is going against and adjust 183 | SL if so 184 | """ 185 | # align 'd' object to 'trade_management_params.clisttm_tf' timeframe 186 | aligned_d = process_start(dt=d, 187 | timeframe=trade_management_params.clisttm_tf) 188 | self.append_trademanagement_candles(aligned_d, fraction) 189 | if len(self.preceding_candles) == self.candle_number: 190 | new_SL = adjust_SL( 191 | pair=self.pair, 192 | type=self.type, 193 | list_candles=self.preceding_candles 194 | ) 195 | if check_against: 196 | res = self.check_if_against() 197 | if res is True: 198 | if self.type == "short" and (new_SL < self.SL.price): 199 | self.SL.price = new_SL 200 | elif self.type == "long" and (new_SL > self.SL.price): 201 | self.SL.price = new_SL 202 | else: 203 | if self.type == "short" and (new_SL < self.SL.price): 204 | self.SL.price = new_SL 205 | elif self.type == "long" and (new_SL > self.SL.price): 206 | self.SL.price = new_SL 207 | if trade_management_params.preceding_clist_strat == "wipe": 208 | self.preceding_candles = list() 209 | elif trade_management_params.preceding_clist_strat == "queue": 210 | self.preceding_candles = self.preceding_candles[1:] 211 | else: 212 | raise NotImplementedError( 213 | "Invalid trade_management_params.preceding_clist_strat: " 214 | f"{trade_management_params.preceding_clist_strat}" 215 | ) 216 | 217 | def _validate_datetime(self, d: datetime) -> bool: 218 | """False if datetime is in the future. 219 | 220 | raises ValueError: if no info in the clist 221 | """ 222 | current_date = datetime.now().date() 223 | if d.date() == current_date: 224 | logging.warning("Skipping, as unable to end the trade") 225 | self.outcome = "future" 226 | return False 227 | if d > self.clist.candles[-1].time and self.connect is False: 228 | raise ValueError( 229 | "No candle is available in 'clist' and connect is False." 230 | "Unable to follow" 231 | ) 232 | return True 233 | 234 | 235 | class UnawareTrade(OpenTrade): 236 | """Represents a trade that ignores whether the price is in profit or loss. 
237 | 238 | Characterizes for not being conditioned by the price being in loss or profit 239 | (hence the name 'unaware') to begin to add candles to 'start.preceding_candles'." 240 | """ 241 | 242 | def __init__(self, **kwargs): 243 | """Constructor""" 244 | super().__init__(**kwargs) 245 | 246 | def run(self) -> None: 247 | """Method to run this UnawareTrade. 248 | 249 | This function will run the trade and will set the outcome attribute 250 | """ 251 | fraction = check_timeframes_fractions( 252 | timeframe1=self.timeframe, 253 | timeframe2=trade_management_params.clisttm_tf 254 | ) 255 | count = 0 256 | for d in gen_datelist(start=self.start, timeframe=self.timeframe): 257 | if not self._validate_datetime(d) or self.completed: 258 | break 259 | count += 1 260 | cl = self.fetch_candle(d) 261 | if cl is None: 262 | count -= 1 263 | continue 264 | self.process_trademanagement(d=d, fraction=fraction) 265 | self.calculate_overlap(cl=cl) 266 | if self.completed: 267 | break 268 | if count >= trade_management_params.numperiods: 269 | self.completed = True 270 | t_logger.warning( 271 | "No outcome could be calculated in the " 272 | "trade_params.numperiods interval" 273 | ) 274 | self.outcome = "n.a." 275 | self.finalise_trade(cl=cl) 276 | 277 | 278 | class AwareTrade(OpenTrade): 279 | """Represent a trade that is aware of the price being in profit or loss. 280 | 281 | Characterizes for adding candles to 'start.preceding_candles' only if price is in profit 282 | """ 283 | 284 | def __init__(self, **kwargs): 285 | """Constructor""" 286 | super().__init__(**kwargs) 287 | 288 | def run(self) -> None: 289 | """Method to run this AwareTrade. 290 | 291 | This function will run the trade and will set the outcome attribute 292 | """ 293 | fraction = check_timeframes_fractions( 294 | timeframe1=self.timeframe, 295 | timeframe2=trade_management_params.clisttm_tf 296 | ) 297 | 298 | count = 0 299 | for d in gen_datelist(start=self.start, timeframe=self.timeframe): 300 | if not self._validate_datetime(d) or self.completed: 301 | break 302 | count += 1 303 | cl = self.fetch_candle(d) 304 | if cl is None: 305 | count -= 1 306 | continue 307 | if self.isin_profit(price=cl.c): 308 | self.process_trademanagement(d=d, fraction=fraction) 309 | else: 310 | self.preceding_candles = list() 311 | 312 | self.calculate_overlap(cl=cl) 313 | if self.completed: 314 | break 315 | if count >= trade_management_params.numperiods: 316 | self.completed = True 317 | t_logger.warning( 318 | "No outcome could be calculated in the " 319 | "trade_params.numperiods interval" 320 | ) 321 | self.outcome = "n.a." 322 | self.finalise_trade(cl=cl) 323 | 324 | 325 | class BreakEvenTrade(OpenTrade): 326 | """Represent a trade that adjusts SL to breakeven when in profit. 327 | 328 | When 'self.SL' is adjusted to breakeven, then candles will start 329 | being added to 'self.preceding_candles' 330 | """ 331 | 332 | def __init__(self, number_of_pips=20, **kwargs): 333 | """Constructor 334 | Arguments: 335 | number_of_pips = Number of pips in profit to move to breakeven. 336 | This parameter will also control the SL new price, 337 | which will be (self.entry+number_of_pips) minus 10 338 | pips 339 | """ 340 | self.number_of_pips = number_of_pips 341 | super().__init__(**kwargs) 342 | 343 | def run(self) -> None: 344 | """Method to run this BreakEvenTrade. 
345 | 346 | This function will run the trade and will set the outcome attribute 347 | """ 348 | fraction = check_timeframes_fractions( 349 | timeframe1=self.timeframe, 350 | timeframe2=trade_management_params.clisttm_tf 351 | ) 352 | 353 | count = 0 354 | for d in gen_datelist(start=self.start, timeframe=self.timeframe): 355 | if not self._validate_datetime(d) or self.completed: 356 | break 357 | count += 1 358 | cl = self.fetch_candle(d) 359 | if cl is None: 360 | count -= 1 361 | continue 362 | 363 | self.calculate_overlap(cl=cl) 364 | if self.completed: 365 | break 366 | pips_distance = calculate_profit( 367 | prices=(cl.c, self.entry.price), type=self.type, pair=self.pair 368 | ) 369 | if pips_distance > self.number_of_pips: 370 | # This controls the gain achieved when moving the SL price 371 | pips_of_gain = self.number_of_pips - 10 372 | if self.type == "long": 373 | new_price = add_pips2price( 374 | pair=self.pair, price=self.entry.price, 375 | pips=pips_of_gain 376 | ) 377 | elif self.type == "short": 378 | new_price = substract_pips2price( 379 | pair=self.pair, price=self.entry.price, 380 | pips=pips_of_gain 381 | ) 382 | self.SL.price = new_price 383 | 384 | self.process_trademanagement(d=d, fraction=fraction) 385 | if count >= trade_management_params.numperiods: 386 | self.completed = True 387 | t_logger.warning( 388 | "No outcome could be calculated in the " 389 | "trade_management_params.numperiods interval" 390 | ) 391 | self.outcome = "n.a." 392 | self.finalise_trade(cl=cl) 393 | 394 | 395 | class TrackingTrade(OpenTrade): 396 | """Trade where SL will be set when 397 | OpenTrade.preceding_candles==candle_number 398 | regardless of whether the CandleList goes against the trade""" 399 | def __init__(self, **kwargs): 400 | super().__init__(**kwargs) 401 | 402 | def run(self) -> None: 403 | """Method to run this BreakEvenTrade. 404 | 405 | This function will run the trade and will set the outcome attribute 406 | """ 407 | fraction = check_timeframes_fractions( 408 | timeframe1=self.timeframe, 409 | timeframe2=trade_management_params.clisttm_tf 410 | ) 411 | count = 0 412 | for d in gen_datelist(start=self.start, timeframe=self.timeframe): 413 | if not self._validate_datetime(d): 414 | break 415 | count += 1 416 | cl = self.fetch_candle(d) 417 | if cl is None: 418 | count -= 1 419 | continue 420 | self.calculate_overlap(cl=cl) 421 | if self.completed is True: 422 | break 423 | self.process_trademanagement(d=d, fraction=fraction, 424 | check_against=False) 425 | if count >= trade_management_params.numperiods: 426 | self.completed = True 427 | t_logger.warning( 428 | "No outcome could be calculated in the " 429 | "trade_params.numperiods interval" 430 | ) 431 | self.outcome = "n.a." 432 | self.finalise_trade(cl=cl) 433 | 434 | class TrackingAwareTrade(OpenTrade): 435 | """Trade where SL will be set when 436 | OpenTrade.preceding_candles==candle_number 437 | regardless of whether the CandleList goes against the trade 438 | and only the price is on profit 439 | """ 440 | def __init__(self, **kwargs): 441 | super().__init__(**kwargs) 442 | 443 | def run(self) -> None: 444 | """Method to run this BreakEvenTrade. 
445 | 446 | This function will run the trade and will set the outcome attribute 447 | """ 448 | fraction = check_timeframes_fractions( 449 | timeframe1=self.timeframe, 450 | timeframe2=trade_management_params.clisttm_tf 451 | ) 452 | count = 0 453 | for d in gen_datelist(start=self.start, timeframe=self.timeframe): 454 | if not self._validate_datetime(d): 455 | break 456 | count += 1 457 | cl = self.fetch_candle(d) 458 | if cl is None: 459 | count -= 1 460 | continue 461 | self.calculate_overlap(cl=cl) 462 | if self.completed is True: 463 | break 464 | if self.isin_profit(price=cl.c): 465 | self.process_trademanagement(d=d, fraction=fraction, 466 | check_against=False) 467 | else: 468 | self.preceding_candles = list() 469 | if count >= trade_management_params.numperiods: 470 | self.completed = True 471 | t_logger.warning( 472 | "No outcome could be calculated in the " 473 | "trade_params.numperiods interval" 474 | ) 475 | self.outcome = "n.a." 476 | self.finalise_trade(cl=cl) 477 | -------------------------------------------------------------------------------- /trading_journal/trade.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import logging 3 | 4 | from trading_journal.constants import ALLOWED_ATTRBS 5 | from api.oanda.connect import Connect 6 | from datetime import datetime 7 | from forex.harea import HArea 8 | from forex.pivot import PivotList 9 | from utils import ( 10 | add_pips2price, 11 | calculate_pips, 12 | try_parsing_date, 13 | substract_pips2price 14 | ) 15 | from trading_journal.trade_utils import ( 16 | gen_datelist, 17 | check_candle_overlap, 18 | init_clist 19 | ) 20 | from params import trade_params, trade_management_params 21 | 22 | t_logger = logging.getLogger(__name__) 23 | t_logger.setLevel(logging.INFO) 24 | 25 | 26 | class Trade: 27 | """This is an abstract class represents a Trade. 28 | 29 | Class variables: 30 | init_clist: boolean that will be true if 31 | clist and clist_tm should be 32 | initialised 33 | start: Time/date when the trade was taken 34 | pair: Currency pair used in the trade 35 | timeframe: Timeframe used for the trade. 36 | outcome: Outcome of the trade. 37 | entry: HArea representing the entry 38 | exit: exit price 39 | type: What is the type of the trade (long,short) 40 | SR: Support/Resistance area 41 | RR: Risk Ratio 42 | pips: Number of pips of profit/loss. 
This number will be negative if 43 | outcome was failure 44 | clist: CandleList for this trade 45 | clist_tm: CandleList for trade management 46 | """ 47 | 48 | def _preinit__(self): 49 | if not hasattr(self, "clist") and self.init_clist is True: 50 | self.clist = init_clist(timeframe=self.timeframe, 51 | pair=self.pair, 52 | start=self.start) 53 | if not hasattr(self, "clist_tm") and self.init_clist is True: 54 | self.clist_tm = init_clist(timeframe=trade_management_params.clisttm_tf, 55 | pair=self.pair, 56 | start=self.start) 57 | 58 | self.__dict__.update({"start": try_parsing_date(self.__dict__["start"])}) 59 | if hasattr(self, "end"): 60 | self.__dict__.update({"end": try_parsing_date(self.__dict__["end"])}) 61 | 62 | def __init__(self, entry: float, SL: float, TP: float = None, 63 | init_clist=False, **kwargs) -> None: 64 | self.__dict__.update((k, v) for k, v in kwargs.items() if 65 | k in ALLOWED_ATTRBS) 66 | self.init_clist = init_clist 67 | self._preinit__() 68 | self._validate_clists() 69 | self.entry = self.init_harea(entry) if not isinstance(entry, HArea) else entry 70 | self.SL = self.init_harea(SL) if not isinstance(SL, HArea) else SL 71 | if kwargs.get("RR") is None and TP is None: 72 | raise ValueError( 73 | "Neither the RR nor the TP is defined. Please provide at least one!" 74 | ) 75 | if kwargs.get("RR") is not None and TP is None: 76 | TP = self.calc_TP() 77 | elif kwargs.get("RR") is None and TP is not None: 78 | RR = self.calc_RR(TP=TP) 79 | self.RR = RR 80 | self.TP = self.init_harea(TP) if not isinstance(TP, HArea) else TP 81 | self.SLdiff = self.get_SLdiff() 82 | 83 | def init_harea(self, price: float) -> HArea: 84 | harea_obj = HArea(price=price, 85 | instrument=self.pair, 86 | pips=trade_params.hr_pips, 87 | granularity=self.timeframe) 88 | return harea_obj 89 | 90 | def _validate_clists(self): 91 | """Method to check the validity of the clists""" 92 | if hasattr(self, "clist"): 93 | if ( 94 | self.clist.instrument != self.pair 95 | or self.clist.granularity != self.timeframe 96 | ): 97 | raise ValueError("Incompatible clist attributes") 98 | 99 | def initialise(self, expires: int = 2, connect=True) -> None: 100 | """Progress the trade and check if taken. 101 | 102 | Arguments: 103 | expires: Number of candles after start datetime to check 104 | for entry 105 | connect: If True then it will use the API to fetch candles 106 | """ 107 | t_logger.info(f"Initialising trade: {self.pair}:{self.start}") 108 | count = 0 109 | self.entered = False 110 | for d in gen_datelist(start=self.start, timeframe=self.timeframe): 111 | if d.weekday() == 5: 112 | continue 113 | count += 1 114 | if (count > expires or d > datetime.now()) and self.entered is False: 115 | t_logger.warning("Trade entry expired or is in the future!") 116 | self.outcome = "n.a."
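                # the entry was never hit within 'expires' candles (or the start
                # lies in the future): mark the trade as not applicable with zero
                # pips and stop checking further candles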
117 | self.pips = 0 118 | return 119 | cl = self.clist[d] 120 | if cl is None: 121 | if connect is True: 122 | conn = Connect(instrument=self.pair, 123 | granularity=self.timeframe) 124 | cl = conn.fetch_candle(d=d) 125 | if cl is None: 126 | count -= 1 127 | continue 128 | if check_candle_overlap(cl, self.entry.price): 129 | t_logger.info("Trade entered") 130 | self.entered = True 131 | if connect is True: 132 | try: 133 | entry_time = self.entry.get_cross_time( 134 | candle=cl, granularity=trade_params.granularity 135 | ) 136 | self.entry_time = entry_time.isoformat() 137 | except BaseException: 138 | self.entry_time = cl.time.isoformat() 139 | else: 140 | self.entry_time = cl.time.isoformat() 141 | break 142 | 143 | def is_entry_onrsi(self) -> bool: 144 | """Function to check if self.start is on RSI. 145 | 146 | Arguments: 147 | trade : Trade object used for the calculation 148 | 149 | Returns: 150 | True if tObj.start is on RSI (i.e. RSI>=70 or RSI<=30) 151 | """ 152 | if self.clist[self.start].rsi >= 70 or self.clist[self.start].rsi <= 30: 153 | return True 154 | else: 155 | return False 156 | 157 | def get_lasttime(self, pad: int = 0): 158 | """Function to calculate the last time price has been above/below 159 | a certain HArea. 160 | 161 | Arguments: 162 | trade : Trade object used for the calculation 163 | pad : Add/substract this number of pips to trade.SR 164 | """ 165 | new_SR = self.SR 166 | if pad > 0: 167 | if self.type == "long": 168 | new_SR = substract_pips2price(self.clist.instrument, 169 | self.SR, pad) 170 | elif self.type == "short": 171 | new_SR = add_pips2price(self.clist.instrument, self.SR, pad) 172 | newcl = self.clist.slice(start=self.clist.candles[0].time, 173 | end=self.start) 174 | return newcl.get_lasttime(new_SR, type=self.type) 175 | 176 | def calc_TP(self) -> float: 177 | diff = (self.entry.price - self.SL.price) * self.RR 178 | return round(self.entry.price + diff, 4) 179 | 180 | def calc_RR(self, TP: float) -> float: 181 | RR = abs(TP - self.entry.price) / abs(self.SL.price - self.entry.price) 182 | return round(RR, 2) 183 | 184 | def get_trend_i(self) -> datetime: 185 | """Function to calculate the start of the trend""" 186 | pvLst = PivotList(self.clist) 187 | merged_s = pvLst.calc_itrend() 188 | 189 | if self.type == "long": 190 | candle = merged_s.get_highest() 191 | elif self.type == "short": 192 | candle = merged_s.get_lowest() 193 | 194 | return candle.time 195 | 196 | def get_SLdiff(self) -> float: 197 | """Function to calculate the difference in number of pips between the 198 | entry and the SL prices. 
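        For a hypothetical non-JPY pair, an entry at 1.0500 with a SL
        at 1.0450 gives 50.0 pips.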
199 |
200 |         Returns:
201 |             number of pips
202 |         """
203 |         diff = abs(self.entry.price - self.SL.price)
204 |         number_pips = float(calculate_pips(self.pair, diff))
205 |
206 |         return number_pips
207 |
208 |     def __str__(self):
209 |         sb = []
210 |         for key in self.__dict__:
211 |             sb.append(f"{key}='{self.__dict__[key]}'")
212 |         return ", ".join(sb)
213 |
214 |     def __repr__(self):
215 |         return "Trade"
216 |
--------------------------------------------------------------------------------
/trading_journal/trade_journal.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import logging
4 | import math
5 | import re
6 | from typing import Tuple, List
7 |
8 | import openpyxl
9 | from openpyxl import Workbook
10 |
11 | from trading_journal.open_trade import UnawareTrade
12 | from params import tjournal_params
13 |
14 |
15 | tj_logger = logging.getLogger(__name__)
16 | tj_logger.setLevel(logging.INFO)
17 |
18 |
19 | class TradeJournal(object):
20 |     """
21 |     Class representing a trading journal stored in an Excel (.xlsx) workbook.
22 |
23 |     Class variables:
24 |         url: path to the .xlsx file with the trade journal
25 |         worksheet: Name of the worksheet that will be used to create
26 |                    the object. i.e. 'trading_journal'
27 |     """
28 |     __slots__ = ['url', 'worksheet', 'df']
29 |
30 |     def __init__(self, url: str, worksheet: str):
31 |         self.url = url
32 |         self.worksheet = worksheet
33 |
34 |         # read-in the 'trading_journal' worksheet from a .xlsx file into a
35 |         # pandas dataframe
36 |         try:
37 |             xls_file = pd.ExcelFile(url)
38 |             df = xls_file.parse(worksheet, converters={
39 |                 "id": str,
40 |                 "start": str,
41 |                 "end": str,
42 |                 "trend_i": str})
43 |             df = df.dropna(how='all')
44 |
45 |             if df.empty:
46 |                 raise Exception(f"No trades fetched for url:{self.url} and "
47 |                                 f"worksheet:{self.worksheet}")
48 |             # replace n.a. string by NaN
49 |             df = df.replace("n.a.", np.NaN)
50 |             # remove trailing whitespaces from col names
51 |             df.columns = df.columns.str.rstrip()
52 |             self.df = df
53 |         except FileNotFoundError:
54 |             wb = Workbook()
55 |             wb.create_sheet(worksheet)
56 |             wb.save(str(self.url))
57 |
58 |     def fetch_trades(self, init_clist: bool = False) -> List[UnawareTrade]:
59 |         """Function to fetch a list of Trade objects.
60 |
61 |         Args:
62 |             init_clist: If True, then clist and clist_tm will be initialised.
63 |         """
64 |
65 |         trade_list, args = [], {}
66 |         for _, row in self.df.iterrows():
67 |             if isinstance(row['id'], float):
68 |                 continue
69 |             elms = re.split(r'\.| ', row['id'])
70 |             assert len(elms) >= 1, "Error parsing the trade id"
71 |             pair = elms[0]
72 |             args = {'pair': pair, **row}
73 |             t = UnawareTrade(**args, init_clist=init_clist)
74 |             trade_list.append(t)
75 |
76 |         return trade_list
77 |
78 |     def win_rate(self, strats: str) -> Tuple[int, int, float]:
79 |         '''Calculate win rate and pips balance
80 |         for this TradeJournal. If the outcome attribute is not
81 |         defined then it will invoke the run method
82 |         on each particular trade
83 |
84 |         Arguments:
85 |             strats : Comma-separated list of strategies to analyse:
86 |                      i.e.
counter,counter_b1 87 | 88 | Returns: 89 | number of successes 90 | number of failures 91 | balance of pips in this TradeList 92 | ''' 93 | 94 | strat_l = strats.split(",") 95 | number_s = number_f = tot_pips = 0 96 | for _, row in self.df.iterrows(): 97 | pair = row['id'].split(" ")[0] 98 | args = {'pair': pair, **row} 99 | t = UnawareTrade(**args, init_clist=True) 100 | if t.strat not in strat_l: 101 | continue 102 | if not hasattr(t, 'outcome') or math.isnan(t.outcome): 103 | t.initialise(expires=1) 104 | t.run() 105 | if t.outcome == 'success': 106 | number_s += 1 107 | elif t.outcome == 'failure': 108 | number_f += 1 109 | tot_pips += t.pips 110 | tot_pips = round(tot_pips, 2) 111 | tot_trades = number_s+number_f 112 | perc_wins = round(number_s*100/tot_trades, 2) 113 | perc_losses = round(number_f*100/tot_trades, 2) 114 | print("Tot number of trades: {0}\n-------------".format(tot_trades)) 115 | print("Win trades: {0}; Loss trades: {1}".format(number_s, number_f)) 116 | print("% win trades: {0}; % loss trades: {1}".format(perc_wins, 117 | perc_losses)) 118 | print("Pips balance: {0}".format(tot_pips)) 119 | 120 | return number_s, number_f, tot_pips 121 | 122 | def write_tradelist(self, trade_list: List[UnawareTrade], 123 | sheet_name: str) -> None: 124 | """Write the TradeList to the Excel spreadsheet 125 | pointed by the trade_journal. 126 | 127 | Arguments: 128 | trade_list : List of Trade objects 129 | sheet_name : worksheet name""" 130 | colnames = tjournal_params.colnames.split(",") 131 | data = [] 132 | for t in trade_list: 133 | id = f"{t.pair} {t.start.strftime('%Y%m%d')}" 134 | t.id = id 135 | row = [] 136 | for key in colnames: 137 | # some keys are not defined for some of the Trade 138 | # objects 139 | if hasattr(t, key): 140 | if key in ["SL", "entry", "TP"]: 141 | area = getattr(t, key) 142 | row.append(area.price) 143 | else: 144 | row.append(getattr(t, key)) 145 | else: 146 | row.append("n.a.") 147 | data.append(row) 148 | 149 | df = pd.DataFrame(data, columns=colnames) 150 | 151 | writer = pd.ExcelWriter(self.url, engine='openpyxl', mode='a', 152 | if_sheet_exists='replace') 153 | writer.workbook = openpyxl.load_workbook(self.url) 154 | tj_logger.info("Creating new worksheet with trades with name: {0}". 155 | format(sheet_name)) 156 | df.to_excel(writer, sheet_name) 157 | writer.close() 158 | -------------------------------------------------------------------------------- /trading_journal/trade_utils.py: -------------------------------------------------------------------------------- 1 | # Collection of utilities used by the trade.py module 2 | import logging 3 | from typing import List 4 | from datetime import datetime, timedelta 5 | 6 | from utils import (periodToDelta, 7 | try_parsing_date, 8 | add_pips2price, 9 | substract_pips2price) 10 | from params import trade_params 11 | from api.oanda.connect import Connect 12 | from forex.candle import Candle, CandleList 13 | 14 | t_logger = logging.getLogger(__name__) 15 | t_logger.setLevel(logging.INFO) 16 | 17 | 18 | def calc_period(timeframe: str) -> int: 19 | """Number of hours for a certain timeframe""" 20 | return 24 if timeframe == "D" else int(timeframe.replace("H", 21 | "")) 22 | 23 | 24 | def check_timeframes_fractions(timeframe1: str, timeframe2: str) -> float: 25 | """Get the fraction of times 'timeframe1' is contained in 'timeframe2'. 
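    (Computed as hours(timeframe1) / hours(timeframe2), i.e. how many
    'timeframe2' periods fit into one 'timeframe1' period.)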
26 |     For example: timeframe1=H12 and timeframe2=H8 will return 1.5
27 |                  timeframe1=H8 and timeframe2=H4 will return 2
28 |     """
29 |     hours1 = calc_period(timeframe1)
30 |     hours2 = calc_period(timeframe2)
31 |
32 |     return float(hours1/hours2)
33 |
34 |
35 | def get_closest_hour(timeframe: str, solve_hour: int) -> int:
36 |     """Get the closest hour to 'solve_hour'"""
37 |     time_ranges_dict = {
38 |         "H4": [21, 1, 5, 9, 13, 17],
39 |         "H8": [21, 5, 13],
40 |         "H12": [21, 9],
41 |         "D": [21]}
42 |     filtered_hours = [hour for hour in time_ranges_dict[timeframe] if (solve_hour-hour) >= 0]
43 |     if filtered_hours:
44 |         closest_hour = min(filtered_hours, key=lambda x: solve_hour - x)
45 |     else:
46 |         closest_hour = 21
47 |     return closest_hour
48 |
49 |
50 | def process_start(dt: datetime, timeframe: str) -> datetime:
51 |     """Round fractional times for Trade.start.
52 |
53 |     Returns:
54 |         Rounded aligned datetime
55 |     """
56 |     if not isinstance(dt, datetime):
57 |         raise ValueError(f"{dt} should be a datetime instance")
58 |     closest_hour = get_closest_hour(timeframe=timeframe, solve_hour=dt.time().hour)
59 |     if closest_hour == 21 and dt.time().hour >= 0 and dt.time().hour not in [22, 23]:
60 |
61 |         result_datetime = dt - timedelta(hours=calc_period(timeframe))
62 |         dt = dt.replace(day=result_datetime.day,
63 |                         month=result_datetime.month,
64 |                         hour=closest_hour,
65 |                         year=result_datetime.year,
66 |                         minute=0,
67 |                         second=0)
68 |     else:
69 |         dt = dt.replace(hour=closest_hour,
70 |                         minute=0,
71 |                         second=0)
72 |     return dt
73 |
74 |
75 | def gen_datelist(start: datetime, timeframe: str) -> List[datetime]:
76 |     """Generate a range of dates starting at 'start' and ending
77 |     trade_params.interval periods later in order to assess the
78 |     outcome of the trade and also the entry time.
79 |     """
80 |     return [start + timedelta(hours=x*calc_period(timeframe))
81 |             for x in range(0, trade_params.interval)]
82 |
83 |
84 | def check_candle_overlap(cl: Candle, price: float) -> bool:
85 |     """Function to check if Candle 'cl' overlaps 'price'"""
86 |     return cl.l <= price <= cl.h
87 |
88 |
89 | def init_clist(timeframe: str, pair: str, start: datetime) -> CandleList:
90 |     delta = periodToDelta(trade_params.trade_period, timeframe)
91 |     if not isinstance(start, datetime):
92 |         start = try_parsing_date(start)
93 |     nstart = start - delta
94 |
95 |     conn = Connect(
96 |         instrument=pair,
97 |         granularity=timeframe)
98 |     return conn.query(nstart.isoformat(), start.isoformat())
99 |
100 |
101 | def adjust_SL(pair: str, type: str, list_candles: List[Candle],
102 |               pips_offset: int = 10) -> float:
103 |     """Adjust the SL price using the extremes of 'list_candles'.
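    For a 'short' trade the SL is set 'pips_offset' pips above the highest
    high in the list; for a 'long' trade, 'pips_offset' pips below the
    lowest low.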
104 | 105 | Arguments: 106 | pair: Instrument 107 | type: Trade type (short/long) 108 | list_candles: List of candles 109 | pips_offset: Number of pips to offset to obj.h and obj.l 110 | """ 111 | if type == "short": 112 | max_candle = max(list_candles, key=lambda obj: obj.h) 113 | new_high = add_pips2price(pair, max_candle.h, pips_offset) 114 | return new_high 115 | 116 | if type == "long": 117 | min_candle = min(list_candles, key=lambda obj: obj.l) 118 | new_low = substract_pips2price(pair, min_candle.l, pips_offset) 119 | return new_low 120 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import yaml 4 | from typing import Tuple 5 | 6 | from datetime import datetime, timedelta 7 | 8 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 9 | 10 | # location of directory used to store all data used by Unit Tests 11 | DATA_DIR = ROOT_DIR+"/tests/data" 12 | 13 | 14 | def try_parsing_date(date_string) -> datetime: 15 | """Function to parse a string that can be formatted in 16 | different datetime formats 17 | 18 | Returns: 19 | datetime object 20 | """ 21 | if isinstance(date_string, datetime): 22 | return date_string 23 | for fmt in ("%Y-%m-%dT%H:%M:%S", 24 | "%Y-%m-%d %H:%M:%S", 25 | "%d/%m/%Y %H:%M:%S"): 26 | try: 27 | return datetime.strptime(date_string, fmt) 28 | except ValueError: 29 | pass 30 | raise ValueError(f"no valid date format found: {date_string}") 31 | 32 | 33 | def calculate_pips(pair: str, price: float) -> float: 34 | '''Function to calculate the number of pips 35 | for a given price 36 | 37 | Args: 38 | pair : Currency pair used in the trade. i.e. AUD_USD 39 | price : Provided price 40 | 41 | Returns: 42 | Number of pips 43 | ''' 44 | pips = None 45 | (first, second) = pair.split("_") 46 | if first == 'JPY' or second == 'JPY': 47 | pips = price * 100 48 | else: 49 | pips = price * 10000 50 | 51 | return '%.1f' % pips 52 | 53 | 54 | def add_pips2price(pair: str, price: float, pips: int) -> float: 55 | '''Function that gets a price value and adds 56 | a certain number of pips to the price. 57 | 58 | Arguments: 59 | pair : Currency pair used in the trade. i.e. AUD_USD 60 | price : Price 61 | pips : Number of pips to increase 62 | 63 | Returns: 64 | New price 65 | ''' 66 | (first, second) = pair.split("_") 67 | round_number = None 68 | divisor = None 69 | if first == 'JPY' or second == 'JPY': 70 | round_number = 2 71 | divisor = 100 72 | else: 73 | round_number = 4 74 | divisor = 10000 75 | price = round(price, round_number) 76 | 77 | iprice = price + (pips / divisor) 78 | 79 | return iprice 80 | 81 | 82 | def substract_pips2price(pair: str, price: float, pips: int) -> float: 83 | '''Function that gets a price value and substracts 84 | a certain number of pips to the price 85 | 86 | Arguments: 87 | pair : Currency pair used in the trade. i.e. 
AUD_USD 88 | price : Price to modify 89 | pips : Number of pips to decrease 90 | 91 | Returns: 92 | New price 93 | ''' 94 | (first, second) = pair.split("_") 95 | round_number = None 96 | divisor = None 97 | if first == 'JPY' or second == 'JPY': 98 | round_number = 2 99 | divisor = 100 100 | else: 101 | round_number = 4 102 | divisor = 10000 103 | price = round(price, round_number) 104 | 105 | dprice = price - (pips / divisor) 106 | 107 | return dprice 108 | 109 | 110 | def periodToDelta(ncandles: int, timeframe: str): 111 | """Function that receives an int representing a number of candles using 112 | the 'ncandles' param and returns a datetime timedelta object 113 | 114 | Arguments: 115 | ncandles: Number of candles for which the timedelta will be retrieved 116 | timeframe: Timeframe used for getting the delta object. Possible 117 | values are: 118 | 2D,D,H12,H10,H8,H4 119 | 120 | Returns: 121 | datetime timedelta object 122 | """ 123 | patt = re.compile(r"(\d)D") 124 | 125 | delta = None 126 | if patt.match(timeframe): 127 | raise Exception(f"{timeframe} is not valid. Oanda rest service does not take it") 128 | elif timeframe == 'D': 129 | delta = timedelta(hours=24 * ncandles) 130 | else: 131 | fgran = timeframe.replace('H', '') 132 | delta = timedelta(hours=int(fgran) * ncandles) 133 | 134 | return delta 135 | 136 | 137 | def get_ixfromdatetimes_list(datetimes_list, d) -> int: 138 | """Function to get the index of the element that is closest 139 | to the passed datetime 140 | 141 | Arguments: 142 | datetimes_list : list 143 | List with datetimes 144 | d : datetime 145 | 146 | Returns: 147 | index of the closest datetime to d 148 | """ 149 | 150 | sel_ix = None 151 | diff = None 152 | ix = 0 153 | for ad in datetimes_list: 154 | if diff is None: 155 | diff = abs(ad-d) 156 | sel_ix = ix 157 | else: 158 | if abs(ad-d) < diff: 159 | sel_ix = ix 160 | diff = abs(ad-d) 161 | ix += 1 162 | 163 | return sel_ix 164 | 165 | 166 | def pairwise(iterable): 167 | "s -> (s0, s1), (s2, s3), (s4, s5), ..." 168 | a = iter(iterable) 169 | return zip(a, a) 170 | 171 | 172 | def correct_timeframe(settings, timeframe): 173 | """This utility function is used for correcting 174 | all the pips-related settings depending 175 | on the selected timeframe 176 | 177 | Arguments: 178 | settings: ConfigParser object 179 | timeframe : D,H12,H8,4 180 | 181 | Returns: 182 | settings : ConfigParser object timeframe corrected 183 | """ 184 | timeframe = int(timeframe.replace('H', '')) 185 | ratio = round(timeframe/24, 2) 186 | 187 | p = re.compile('.*pips') 188 | 189 | for section_name in settings.sections(): 190 | for key, value in settings.items(section_name): 191 | if section_name == 'trade' and key == 'hr_pips': 192 | continue 193 | if p.match(key): 194 | new_pips = int(round(ratio*int(value), 0)) 195 | settings.set(section_name, key, str(new_pips)) 196 | 197 | return settings 198 | 199 | 200 | def calculate_profit(prices: Tuple[float, float], 201 | type: str, pair: str) -> float: 202 | """Function to calculate the profit (in pips) 203 | defined as the difference between 2 prices. 
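    The result is positive when price has moved in the trade's favour
    (above the entry for 'long', below it for 'short') and negative
    otherwise.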
204 | 205 | Args: 206 | prices: tuple with the 2 prices to compare [price, entry] 207 | type: ['long'/'short'] 208 | pair: instrument 209 | """ 210 | if (prices[0] - prices[1]) < 0: 211 | sign = -1 if type == "long" else 1 212 | else: 213 | sign = 1 if type == "long" else -1 214 | pips = float(calculate_pips(pair, 215 | abs(prices[0]-prices[1]))) * sign 216 | return pips 217 | 218 | 219 | def is_even_hour(d: datetime) -> bool: 220 | """Check if hour in datetime is even""" 221 | if d.hour % 2 == 0: 222 | return True 223 | elif d.hour % 2 != 0: 224 | return False 225 | 226 | 227 | def is_week_day(d: datetime) -> bool: 228 | """Returns True if 'd' falls on a weekday""" 229 | 230 | is_weekday = d.isoweekday() < 6 231 | return is_weekday 232 | 233 | 234 | def load_config_yaml_file(yaml_file: str) -> dict: 235 | """Loads a yaml file into a dict""" 236 | with open(yaml_file, 'r') as file: 237 | config = yaml.safe_load(file) 238 | return config 239 | --------------------------------------------------------------------------------
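A minimal usage sketch (not part of the repository) showing how the pip helpers in utils.py compose, assuming the repository root is on PYTHONPATH; the instrument, prices and pip counts are hypothetical values chosen only for illustration:

from utils import (add_pips2price,
                   substract_pips2price,
                   calculate_profit,
                   periodToDelta)

pair = "EUR_USD"    # hypothetical instrument
entry = 1.0500      # hypothetical entry price

# 30 pips above/below the entry (non-JPY pairs use a pip size of 0.0001)
tp = add_pips2price(pair, entry, 30)        # ~1.0530
sl = substract_pips2price(pair, entry, 30)  # ~1.0470

# profit in pips for a long trade currently quoted at 1.0550
pips = calculate_profit(prices=(1.0550, entry), type="long", pair=pair)  # 50.0

# time span covered by 10 H4 candles
span = periodToDelta(10, "H4")  # timedelta(hours=40)

print(tp, sl, pips, span)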