├── .gitignore ├── .pre-commit-config.yaml ├── AShareData ├── __init__.py ├── algo.py ├── analysis │ ├── __init__.py │ ├── close_fund_info.py │ ├── fund_nav_analysis.py │ ├── holding.py │ ├── public_fund_holding.py │ ├── return_analysis.py │ └── trading.py ├── ashare_data_reader.py ├── barra_descriptors.py ├── barra_style_factors.py ├── config.py ├── constants.py ├── data │ ├── __init__.py │ ├── db_schema.json │ ├── industry.json │ ├── jqdata_param.json │ ├── tdx_param.json │ ├── tushare_param.json │ ├── wind_param.json │ └── 自编指数配置.xlsx ├── data_source │ ├── __init__.py │ ├── data_source.py │ ├── jq_data.py │ ├── tdx_data.py │ ├── tushare_data.py │ ├── web_data.py │ └── wind_data.py ├── database_interface.py ├── date_utils.py ├── empirical.py ├── factor.py ├── factor_compositor │ ├── __init__.py │ ├── factor_compositor.py │ └── factor_portfolio.py ├── model │ ├── __init__.py │ ├── capm.py │ ├── fama_french_3_factor_model.py │ ├── fama_french_carhart_4_factor_model.py │ └── model.py ├── plot.py ├── portfolio_analysis.py ├── tickers.py ├── tools │ ├── __init__.py │ └── tools.py └── utils.py ├── LICENSE ├── README.md ├── config_example.json ├── docs ├── Makefile ├── make.bat └── source │ ├── DBInterface.rst │ ├── DataReader.rst │ ├── DataSource.rst │ ├── DateUtils.rst │ ├── Factor.rst │ ├── Model.rst │ ├── Tickers.rst │ ├── conf.py │ └── index.rst ├── requirements.txt ├── scripts ├── big_names.py ├── daily_report.py ├── factor_return.py ├── init.py ├── update_morning_auction.py ├── update_routine.py └── wind_stock_rt.py ├── setup.py └── tests ├── analysis_test.py ├── ashare_datareader_test.py ├── calendar_test.py ├── db_interface_test.py ├── factor_compositor_test.py ├── factor_test.py ├── industry_comparison_test.py ├── jq_data_test.py ├── model_test.py ├── plot_test.py ├── portfolio_test.py ├── test_algo.py ├── test_model.py ├── ticker_test.py ├── tools_test.py ├── tushare2mysql_test.py ├── web_data_test.py └── wind_data_test.py /.gitignore: 
-------------------------------------------------------------------------------- 1 | config.json 2 | param_raw.json 3 | *.xls 4 | .idea/ 5 | __pycache__/ 6 | *.egg-info/ 7 | AShareData/complimentary_code/ 8 | docs/build 9 | docs/latex 10 | docs/source/_.*/ 11 | docs/source/modules 12 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.0.1 # Use the ref you want to point at 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: requirements-txt-fixer 7 | - id: check-case-conflict 8 | - id: check-docstring-first 9 | - id: double-quote-string-fixer 10 | - id: check-json 11 | # - id: pretty-format-json 12 | # args: ['--autofix', '--indent', '4', '--no-ensure-ascii', '--no-sort-keys', '--'] 13 | -------------------------------------------------------------------------------- /AShareData/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import logging 3 | 4 | from .analysis import IndustryComparison, TradingAnalysis 5 | from .ashare_data_reader import AShareDataReader 6 | from .config import generate_db_interface_from_config, get_db_interface, get_global_config, set_global_config 7 | from .data_source import JQData, TDXData, TushareData, WebDataCrawler 8 | from .database_interface import MySQLInterface 9 | from .date_utils import SHSZTradingCalendar 10 | from .factor_compositor import ConstLimitStockFactorCompositor, IndexCompositor, IndexUpdater, MarketSummaryCompositor, \ 11 | NegativeBookEquityListingCompositor 12 | from .model import FamaFrench3FactorModel, FamaFrenchCarhart4FactorModel 13 | from .tools import IndexHighlighter, major_index_valuation, MajorIndustryConstitutes, StockIndexFutureBasis 14 | 15 | ch = logging.StreamHandler() 16 | 
ch.setFormatter(logging.Formatter('%(asctime)s | %(levelname)s | %(name)s | %(message)s')) 17 | 18 | logger = logging.getLogger(__name__) 19 | logger.addHandler(ch) 20 | logger.setLevel(logging.INFO) 21 | 22 | if importlib.util.find_spec('WindPy'): 23 | from .data_source import WindData 24 | -------------------------------------------------------------------------------- /AShareData/algo.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Dict, Optional, Sequence 3 | 4 | 5 | def chunk_list(l: Sequence, n: int): 6 | for i in range(0, len(l), n): 7 | yield l[i:i + n] 8 | 9 | 10 | def human_sort(l): 11 | """ Sort the given list in the way that humans expect. 12 | """ 13 | l = l.copy() 14 | convert = lambda text: int(text) if text.isdigit() else text 15 | alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)] 16 | l.sort(key=alphanum_key) 17 | return l 18 | 19 | 20 | def get_less_or_equal_of_a_in_b(a: Sequence, b: Sequence) -> Dict: 21 | """ return {va: vb} s.t. vb = max(tb) where tb <= va and tb in b, for all va in a 22 | 23 | :param a: sored sequence of comparable T 24 | :param b: non-empty sorted sequence of comparable T 25 | :return: {va: vb} s.t. 
vb = max(b) given vb <= va for all va in a 26 | """ 27 | if len(b) <= 0: 28 | raise ValueError(f'b({b}) cannot be empty') 29 | ret = {} 30 | i, j = 0, 1 31 | la, lb = len(a), len(b) 32 | while i < la and a[i] < b[0]: 33 | i += 1 34 | while i < la: 35 | while j < lb and a[i] >= b[j]: 36 | j += 1 37 | ret[a[i]] = b[j - 1] 38 | i += 1 39 | return ret 40 | 41 | 42 | def extract_close_operate_period(fund_name: str) -> Optional[int]: 43 | if fund_name: 44 | if '封闭运作' in fund_name: 45 | fund_name = fund_name.replace('三', '3').replace('二', '2').replace('一', '1').replace('两', '2') 46 | if '年' in fund_name: 47 | return int(fund_name[fund_name.index('年') - 1]) * 12 48 | elif '月' in fund_name: 49 | loc = fund_name.index('月') - 1 50 | ret_str = fund_name[loc - 2:loc] if fund_name[loc - 2].isnumeric() else fund_name[loc] 51 | return int(ret_str) 52 | -------------------------------------------------------------------------------- /AShareData/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | from .holding import IndustryComparison 2 | from .trading import TradingAnalysis 3 | -------------------------------------------------------------------------------- /AShareData/analysis/close_fund_info.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | 3 | import pandas as pd 4 | from dateutil.relativedelta import relativedelta 5 | 6 | from ..config import get_db_interface 7 | from ..database_interface import DBInterface 8 | from ..tickers import ExchangeFundTickers, OTCFundTickers 9 | 10 | 11 | def close_fund_opening_info(date: dt.datetime = None, db_interface: DBInterface = None): 12 | if date is None: 13 | date = dt.datetime.combine(dt.date.today(), dt.time()) 14 | if db_interface is None: 15 | db_interface = get_db_interface() 16 | exchange_fund_tickers = ExchangeFundTickers(db_interface) 17 | tickers = exchange_fund_tickers.ticker() 18 | 19 | info = 
db_interface.read_table('基金列表', ['全名', '定开', '定开时长(月)', '封闭运作转LOF时长(月)', '投资类型'], ids=tickers) 20 | funds = info.loc[(info['定开'] == 1) | (info['封闭运作转LOF时长(月)'] > 0), :].copy() 21 | of_ticker = [it.replace('.SH', '.OF').replace('.SZ', '.OF') for it in funds.index.tolist()] 22 | list_date = OTCFundTickers(db_interface).get_list_date(of_ticker).sort_index() 23 | list_date.index = funds.index 24 | list_date.name = '成立日期' 25 | mask = funds['封闭运作转LOF时长(月)'] > 0 26 | funds.loc[mask, '定开时长(月)'] = funds['封闭运作转LOF时长(月)'][mask] 27 | funds = pd.concat([funds, list_date], axis=1) 28 | 29 | tmp = pd.Series([relativedelta(months=it) for it in funds['定开时长(月)']], index=funds.index) 30 | funds.rename({'成立日期': '上一次开放日'}, axis=1, inplace=True) 31 | funds['下一次开放日'] = tmp + funds.loc[:, '上一次开放日'] 32 | ind_base = funds['定开'].astype(bool) 33 | ind = (funds['下一次开放日'] < date) & ind_base 34 | while any(ind): 35 | funds.loc[ind, '上一次开放日'] = funds.loc[ind, '下一次开放日'] 36 | funds.loc[ind, '下一次开放日'] = (tmp + funds.loc[:, '上一次开放日']).loc[ind] 37 | ind = (funds['下一次开放日'] < date) & ind_base 38 | 39 | funds['距离下次开放时间'] = [max((it - date).days, 0) for it in funds['下一次开放日']] 40 | 41 | return funds.loc[:, ['全名', '投资类型', '上一次开放日', '下一次开放日', '距离下次开放时间']].sort_values('距离下次开放时间') 42 | -------------------------------------------------------------------------------- /AShareData/analysis/fund_nav_analysis.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | 3 | import pandas as pd 4 | 5 | from ..ashare_data_reader import AShareDataReader 6 | from ..config import get_db_interface 7 | from ..database_interface import DBInterface 8 | from ..model.model import FinancialModel 9 | from ..factor import ContinuousFactor 10 | 11 | 12 | class FundNAVAnalysis(object): 13 | def __init__(self, ticker: str, db_interface: DBInterface = None): 14 | self.db_interface = db_interface if db_interface else get_db_interface() 15 | self.data_reader = 
AShareDataReader(self.db_interface) 16 | self.ticker = ticker 17 | db_name = '场外基金净值' if ticker.endswith('OF') else '场内基金日行情' 18 | self.nav_data = ContinuousFactor(db_name, '单位净值').bind_params(ids=self.ticker) 19 | 20 | def compute_correlation(self, index_code: str, period: int = 60) -> float: 21 | index_return_factor = self.data_reader.get_index_return_factor(index_code) 22 | start_date = self.data_reader.calendar.offset(dt.date.today(), -period) 23 | nav_chg = self.nav_data.get_data(start_date=start_date).pct_change() 24 | index_ret = index_return_factor.get_data(start_date=start_date) 25 | corr = nav_chg.corr(index_ret) 26 | return corr 27 | 28 | def compute_exposure(self, model: FinancialModel, period: int = 60): 29 | pass 30 | 31 | def get_latest_published_portfolio_holding(self) -> pd.DataFrame: 32 | data = self.db_interface.read_table('公募基金持仓', ids=self.ticker) 33 | latest = data.loc[data.index.get_level_values('DateTime') == data.index.get_level_values('DateTime').max(), :] 34 | return latest.sort_values('占股票市值比', ascending=False) 35 | -------------------------------------------------------------------------------- /AShareData/analysis/holding.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | 3 | import pandas as pd 4 | 5 | from .. 
import utils 6 | from ..ashare_data_reader import AShareDataReader 7 | from ..config import get_db_interface 8 | from ..database_interface import DBInterface 9 | 10 | 11 | class IndustryComparison(object): 12 | def __init__(self, index: str, industry_provider: str, industry_level: int, db_interface: DBInterface = None): 13 | self.db_interface = db_interface if db_interface else get_db_interface() 14 | self.data_reader = AShareDataReader(self.db_interface) 15 | self.industry_info = self.data_reader.industry(industry_provider, industry_level) 16 | self.index = index 17 | 18 | def holding_comparison(self, holding: pd.Series): 19 | holding_ratio = self.portfolio_weight(holding) 20 | return self.industry_ratio_comparison(holding_ratio) 21 | 22 | def industry_ratio_comparison(self, holding_ratio: pd.Series): 23 | date = holding_ratio.index.get_level_values('DateTime').unique()[0] 24 | 25 | industry_info = self.industry_info.get_data(dates=date) 26 | index_comp = self.data_reader.index_constitute.get_data(index_ticker=self.index, date=date) 27 | 28 | holding_industry = self._industry_ratio(holding_ratio, industry_info) * 100 29 | index_industry = self._industry_ratio(index_comp, industry_info) 30 | 31 | diff_df = pd.concat([holding_industry, index_industry], axis=1, sort=True).fillna(0) 32 | 33 | return diff_df.iloc[:, 0] - diff_df.iloc[:, 1] 34 | 35 | def portfolio_weight(self, holding: pd.Series): 36 | date = holding.index.get_level_values('DateTime').unique()[0] 37 | 38 | price_info = self.data_reader.stock_close.get_data(dates=date) 39 | price_info.name = 'close' 40 | tmp = pd.concat([holding, price_info], axis=1).dropna() 41 | cap = tmp['quantity'] * tmp['close'] 42 | ratio = cap / cap.sum() 43 | ratio.name = 'weight' 44 | return ratio 45 | 46 | @staticmethod 47 | def _industry_ratio(ratio: pd.Series, industry_info: pd.Series): 48 | tmp = pd.concat([ratio, industry_info], axis=1).dropna() 49 | return tmp.groupby(industry_info.name).sum().iloc[:, 0] 50 | 51 | 
@staticmethod 52 | def import_holding(holding_loc, date: dt.datetime): 53 | holding = pd.read_excel(holding_loc).rename({'证券代码': 'ID', '数量': 'quantity'}, axis=1) 54 | holding['ID'] = holding.ID.apply(utils.format_stock_ticker) 55 | holding['DateTime'] = date 56 | holding.set_index(['DateTime', 'ID'], inplace=True) 57 | return holding 58 | 59 | 60 | class FundHolding(object): 61 | def __init__(self, db_interface: DBInterface = None): 62 | self.db_interface = db_interface if db_interface else get_db_interface() 63 | self.data_reader = AShareDataReader(self.db_interface) 64 | 65 | def get_holding(self, date: dt.datetime, fund: str = None) -> pd.DataFrame: 66 | sql = None 67 | if fund and fund != 'ALL': 68 | sql = f'accountName = "{fund}"' 69 | data = self.db_interface.read_table('持仓记录', dates=date, text_statement=sql) 70 | if fund: 71 | data = data.groupby(['DateTime', 'windCode'])['quantity'].sum() 72 | data.index.names = ['DateTime', 'ID'] 73 | return data 74 | 75 | def portfolio_stock_weight(self, date: dt.datetime, fund: str = None): 76 | holding = self.get_holding(date, fund) 77 | 78 | price_info = self.data_reader.stock_close.get_data(dates=date) 79 | price_info.name = 'close' 80 | tmp = pd.concat([holding, price_info], axis=1).dropna() 81 | cap = tmp['quantity'] * tmp['close'] 82 | ratio = cap / cap.sum() 83 | ratio.name = 'weight' 84 | return ratio.loc[ratio > 0] 85 | -------------------------------------------------------------------------------- /AShareData/analysis/public_fund_holding.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | from typing import Dict 3 | 4 | import pandas as pd 5 | from functools import cached_property 6 | 7 | from ..ashare_data_reader import AShareDataReader 8 | from ..config import get_db_interface 9 | from ..database_interface import DBInterface 10 | 11 | 12 | class PublicFundHoldingRecords(object): 13 | def __init__(self, ticker: str, date: dt.datetime, 
db_interface: DBInterface = None): 14 | if db_interface is None: 15 | db_interface = get_db_interface() 16 | self.db_interface = db_interface 17 | self.data_reader = AShareDataReader(db_interface) 18 | self.ticker = ticker 19 | self.date = date 20 | 21 | @cached_property 22 | def cache(self): 23 | return self.db_interface.read_table('公募基金持仓', ['持有股票数量', '占股票市值比'], 24 | report_period=self.date, constitute_ticker=self.ticker) 25 | 26 | def stock_holding_by_funds(self): 27 | close = self.data_reader.stock_close.get_data(ids=self.ticker, dates=self.date).values[0] 28 | 29 | data = self.cache.loc[:, ['持有股票数量']].copy().droplevel(['DateTime', 'ConstituteTicker', '报告期']) 30 | data['市值'] = data['持有股票数量'] * close 31 | sec_name = self.data_reader.sec_name.get_data(ids=data.index.tolist(), dates=self.date).droplevel('DateTime') 32 | ret = pd.concat([sec_name, data], axis=1) 33 | ret = ret.sort_values('持有股票数量', ascending=False) 34 | 35 | def fund_holding_pct(self) -> Dict: 36 | fund_holding_shares = self.cache['持有股票数量'].sum() 37 | 38 | total_share = self.data_reader.total_share.get_data(ids=self.ticker, dates=self.date)[0] 39 | float_share = self.data_reader.float_a_shares.get_data(ids=self.ticker, dates=self.date)[0] 40 | free_float_share = self.data_reader.free_floating_share.get_data(ids=self.ticker, dates=self.date)[0] 41 | return {'基金持有': fund_holding_shares, '占总股本': fund_holding_shares / total_share, 42 | '占流通股本': fund_holding_shares / float_share, '占只有流通股本': fund_holding_shares / free_float_share} 43 | -------------------------------------------------------------------------------- /AShareData/analysis/return_analysis.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | import empyrical 4 | import pandas as pd 5 | 6 | from .. 
import date_utils 7 | from ..factor import ContinuousFactor 8 | 9 | 10 | @date_utils.dtlize_input_dates 11 | def aggregate_returns(target: ContinuousFactor, convert_to: str, benchmark_factor: ContinuousFactor = None, 12 | start_date: date_utils.DateType = None, end_date: date_utils.DateType = None 13 | ) -> Union[pd.Series, pd.DataFrame]: 14 | """ 按 年/月/周 统计收益 15 | 16 | :param target: 标的收益率 17 | :param convert_to: 周期, 可为 ``yearly`` (年), ``monthly`` (月), ``weekly`` (周) 18 | :param benchmark_factor: 基准收益率 19 | :param start_date: 开始时间 20 | :param end_date: 结束时间 21 | :return: 各个周期的收益率. 若指定基准则还会计算各周期差值列( ``diff`` ) 22 | """ 23 | 24 | def _agg_ret(factor): 25 | target_returns = factor.get_data(start_date=start_date, end_date=end_date).unstack().iloc[:, 0] 26 | agg_target_ret = empyrical.aggregate_returns(target_returns, convert_to) 27 | return agg_target_ret 28 | 29 | ret = _agg_ret(target) 30 | if benchmark_factor: 31 | agg_benchmark_return = _agg_ret(benchmark_factor) 32 | ret = pd.concat([ret, agg_benchmark_return], axis=1) 33 | ret['diff'] = ret.iloc[:, 0] - ret.iloc[:, 1] 34 | return ret 35 | 36 | 37 | def locate_max_drawdown(returns: pd.Series) -> Tuple[pd.Timestamp, pd.Timestamp, float]: 38 | """ 寻找最大回撤周期 39 | 40 | :param returns: 收益序列, 已时间为 ``index`` 41 | :return: (最大回撤开始时间, 最大回撤结束时间, 最大回撤比例) 42 | """ 43 | if len(returns) < 1: 44 | raise ValueError('returns is empty.') 45 | 46 | cumulative = empyrical.cum_returns(returns, starting_value=100) 47 | max_return = cumulative.cummax() 48 | drawdown = cumulative.sub(max_return).div(max_return) 49 | val = drawdown.min() 50 | end = drawdown.index[drawdown.argmin()] 51 | start = drawdown.loc[(drawdown == 0) & (drawdown.index <= end)].index[-1] 52 | return start, end, val 53 | -------------------------------------------------------------------------------- /AShareData/analysis/trading.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | import AShareData as asd 
4 | 5 | 6 | class TradingAnalysis(object): 7 | def __init__(self, db_interface: asd.database_interface.DBInterface = None): 8 | self.db_interface = db_interface if db_interface else asd.get_db_interface() 9 | self.data_reader = asd.AShareDataReader(self.db_interface) 10 | 11 | def trading_volume_summary(self, trading_records: pd.DataFrame) -> pd.DataFrame: 12 | vol_summary = trading_records.groupby(['ID', 'tradeDirection'], as_index=False).tradeVolume.sum() 13 | single_direction_vol = vol_summary.groupby('ID').max() 14 | bi_direction_vol = vol_summary.groupby('ID').sum() 15 | 16 | date = trading_records.DateTime[0].date() 17 | market_vol_info = self.data_reader.stock_trading_volume.get_data(dates=date) 18 | market_vol_info.index = market_vol_info.index.droplevel('DateTime') 19 | 20 | single_ratio = (single_direction_vol.tradeVolume / market_vol_info).dropna() 21 | bi_direction_ratio = (bi_direction_vol.tradeVolume / market_vol_info / 2).dropna() 22 | ret = pd.concat([single_ratio, bi_direction_ratio], axis=1, sort=False) 23 | ret.columns = ['单向成交量占比', '双向成交量占比'] 24 | ret = ret.sort_values('单向成交量占比', ascending=False) 25 | 26 | return ret 27 | -------------------------------------------------------------------------------- /AShareData/ashare_data_reader.py: -------------------------------------------------------------------------------- 1 | from functools import cached_property, lru_cache 2 | 3 | import numpy as np 4 | 5 | from . 
import date_utils 6 | from .config import generate_db_interface_from_config, get_db_interface 7 | from .database_interface import DBInterface 8 | from .factor import BetaFactor, BinaryFactor, CompactFactor, ContinuousFactor, FactorBase, IndexConstitute, \ 9 | IndustryFactor, InterestRateFactor, LatestAccountingFactor, OnTheRecordFactor, TTMAccountingFactor, UnaryFactor 10 | from .tickers import StockTickers 11 | 12 | 13 | class AShareDataReader(object): 14 | def __init__(self, db_interface: DBInterface = None) -> None: 15 | """ 16 | AShare Data Reader 17 | 18 | :param db_interface: DBInterface 19 | """ 20 | 21 | self.db_interface = db_interface if db_interface else get_db_interface() 22 | self.calendar = date_utils.SHSZTradingCalendar(self.db_interface) 23 | 24 | @cached_property 25 | def stocks(self) -> StockTickers: 26 | """股票列表""" 27 | return StockTickers(self.db_interface) 28 | 29 | @cached_property 30 | def sec_name(self) -> CompactFactor: 31 | """证券名称""" 32 | return CompactFactor('证券名称', self.db_interface) 33 | 34 | @cached_property 35 | def adj_factor(self) -> CompactFactor: 36 | """复权因子""" 37 | return CompactFactor('复权因子', self.db_interface) 38 | 39 | @cached_property 40 | def float_a_shares(self) -> CompactFactor: 41 | """A股流通股本""" 42 | return CompactFactor('A股流通股本', self.db_interface) 43 | 44 | @cached_property 45 | def const_limit(self) -> OnTheRecordFactor: 46 | """一字涨跌停""" 47 | return OnTheRecordFactor('一字涨跌停', self.db_interface) 48 | 49 | @cached_property 50 | def stock_open(self) -> ContinuousFactor: 51 | """股票开盘价""" 52 | return ContinuousFactor('股票日行情', '开盘价', self.db_interface) 53 | 54 | @cached_property 55 | def stock_close(self) -> ContinuousFactor: 56 | """股票收盘价""" 57 | return ContinuousFactor('股票日行情', '收盘价', self.db_interface) 58 | 59 | @cached_property 60 | def stock_trading_volume(self) -> ContinuousFactor: 61 | """股票成交量""" 62 | return ContinuousFactor('股票日行情', '成交量', self.db_interface) 63 | 64 | @cached_property 65 | def 
stock_trading_amount(self) -> ContinuousFactor: 66 | """股票成交额""" 67 | return ContinuousFactor('股票日行情', '成交额', self.db_interface) 68 | 69 | @cached_property 70 | def stock_turnover_rate(self) -> ContinuousFactor: 71 | """股票换手率""" 72 | return (self.stock_trading_amount / (self.stock_close * self.free_floating_share)).set_factor_name('换手率') 73 | 74 | @cached_property 75 | def total_share(self) -> CompactFactor: 76 | """股票总股本""" 77 | return CompactFactor('总股本', self.db_interface) 78 | 79 | @cached_property 80 | def free_floating_share(self) -> CompactFactor: 81 | """股票自由流通股本""" 82 | return CompactFactor('自由流通股本', self.db_interface) 83 | 84 | @cached_property 85 | def stock_market_cap(self) -> BinaryFactor: 86 | """股票总市值""" 87 | return (self.total_share * self.stock_close).set_factor_name('股票市值') 88 | 89 | @cached_property 90 | def stock_free_floating_market_cap(self) -> BinaryFactor: 91 | """股票自由流通市值""" 92 | return (self.free_floating_share * self.stock_close).set_factor_name('自由流通市值') 93 | 94 | @cached_property 95 | def free_floating_cap_weight(self) -> UnaryFactor: 96 | """自由流通市值权重""" 97 | return self.stock_free_floating_market_cap.weight().set_factor_name('自由流通市值权重') 98 | 99 | @cached_property 100 | def log_cap(self) -> UnaryFactor: 101 | """股票市值对数""" 102 | return self.stock_market_cap.log().set_factor_name('市值对数') 103 | 104 | @cached_property 105 | def hfq_close(self) -> BinaryFactor: 106 | """股票后复权收盘价""" 107 | return (self.adj_factor * self.stock_close).set_factor_name('后复权收盘价') 108 | 109 | @cached_property 110 | def stock_return(self) -> UnaryFactor: 111 | """股票收益率""" 112 | return self.hfq_close.pct_change().set_factor_name('股票收益率') 113 | 114 | @cached_property 115 | def forward_return(self) -> UnaryFactor: 116 | """股票前瞻收益率""" 117 | return self.hfq_close.pct_change_shift(-1).set_factor_name('股票前瞻收益率') 118 | 119 | @cached_property 120 | def log_return(self) -> UnaryFactor: 121 | """股票对数收益率""" 122 | return self.hfq_close.log().diff().set_factor_name('股票对数收益') 123 | 
124 | @cached_property 125 | def forward_log_return(self) -> UnaryFactor: 126 | """股票前瞻对数收益率""" 127 | return self.hfq_close.log().diff_shift(-1).set_factor_name('股票前瞻对数收益') 128 | 129 | @cached_property 130 | def index_close(self) -> ContinuousFactor: 131 | """指数收盘价""" 132 | return ContinuousFactor('指数日行情', '收盘点位', self.db_interface) 133 | 134 | @cached_property 135 | def index_return(self) -> UnaryFactor: 136 | """指数收益率""" 137 | return self.index_close.pct_change().set_factor_name('指数收益率') 138 | 139 | @cached_property 140 | def user_constructed_index_return(self) -> ContinuousFactor: 141 | """自合成指数收益率""" 142 | return ContinuousFactor('自合成指数', '收益率', self.db_interface) 143 | 144 | @cached_property 145 | def market_return(self) -> ContinuousFactor: 146 | """全市场收益率""" 147 | return ContinuousFactor('自合成指数', '收益率', self.db_interface).bind_params(ids='全市场.IND') 148 | 149 | @cached_property 150 | def model_factor_return(self) -> ContinuousFactor: 151 | """模型因子收益率""" 152 | return ContinuousFactor('模型因子收益率', '收益率', self.db_interface) 153 | 154 | @cached_property 155 | def index_log_return(self) -> UnaryFactor: 156 | """指数对数收益率""" 157 | return self.index_close.log().diff().set_factor_name('指数对数收益率') 158 | 159 | @cached_property 160 | def index_constitute(self) -> IndexConstitute: 161 | """指数成分股权重""" 162 | return IndexConstitute(self.db_interface) 163 | 164 | @lru_cache(5) 165 | def industry(self, provider: str, level: int) -> IndustryFactor: 166 | """stock industry""" 167 | return IndustryFactor(provider, level, self.db_interface) 168 | 169 | @cached_property 170 | def beta(self) -> BetaFactor: 171 | """stock beat""" 172 | return BetaFactor(db_interface=self.db_interface) 173 | 174 | @cached_property 175 | def book_val(self) -> LatestAccountingFactor: 176 | """Book value""" 177 | return LatestAccountingFactor('股东权益合计(不含少数股东权益)', self.db_interface).set_factor_name('股东权益') 178 | 179 | @cached_property 180 | def earning_ttm(self) -> TTMAccountingFactor: 181 | """Earning 
Trailing Twelve Month""" 182 | return TTMAccountingFactor('净利润(不含少数股东损益)', self.db_interface).set_factor_name('净利润TTM') 183 | 184 | @cached_property 185 | def bm(self) -> BinaryFactor: 186 | """Book to Market""" 187 | return (self.book_val / self.stock_market_cap).set_factor_name('BM') 188 | 189 | @cached_property 190 | def bm_after_close(self) -> BinaryFactor: 191 | """After market Book to Market value""" 192 | return (self.book_val.shift(-1) / self.stock_market_cap).set_factor_name('BM') 193 | 194 | @cached_property 195 | def pb(self) -> BinaryFactor: 196 | """Price to Book""" 197 | return (self.stock_market_cap / self.book_val).set_factor_name('PB') 198 | 199 | @cached_property 200 | def cb_close(self) -> ContinuousFactor: 201 | """可转债收盘价""" 202 | return ContinuousFactor('可转债日行情', '收盘价', self.db_interface) 203 | 204 | @cached_property 205 | def cb_total_val(self) -> ContinuousFactor: 206 | """可转债未转股余额""" 207 | return ContinuousFactor('可转债日行情', '未转股余额', self.db_interface) 208 | 209 | @cached_property 210 | def cb_convert_price(self) -> CompactFactor: 211 | """可转债转股价""" 212 | return CompactFactor('可转债转股价').set_factor_name('转股价') 213 | 214 | # TODO 215 | @cached_property 216 | def pb_after_close(self) -> BinaryFactor: 217 | """After market Price to Book""" 218 | return (self.stock_market_cap / self.book_val.shift(-1)).set_factor_name('BM') 219 | 220 | @cached_property 221 | def pe_ttm(self) -> BinaryFactor: 222 | """Price to Earning Trailing Twelve Month""" 223 | return (self.stock_market_cap / self.earning_ttm).set_factor_name('PE_TTM') 224 | 225 | @cached_property 226 | def future_close(self) -> ContinuousFactor: 227 | """期货收盘价""" 228 | return ContinuousFactor('期货日行情', '收盘价', self.db_interface) 229 | 230 | @cached_property 231 | def fund_nav(self) -> ContinuousFactor: 232 | """场外基金单位净值""" 233 | return ContinuousFactor('场外基金净值', '单位净值', self.db_interface) 234 | 235 | @cached_property 236 | def hfq_fund_nav(self) -> BinaryFactor: 237 | """场外基金后复权净值""" 238 | return 
(self.fund_nav * self.adj_factor).set_factor_name('基金后复权净值') 239 | 240 | @cached_property 241 | def overnight_shibor(self) -> InterestRateFactor: 242 | """隔夜shibor""" 243 | return InterestRateFactor('shibor利率数据', '隔夜', self.db_interface).set_factor_name('隔夜shibor') 244 | 245 | @cached_property 246 | def three_month_shibor(self) -> InterestRateFactor: 247 | """三月期shibor""" 248 | return InterestRateFactor('shibor利率数据', '3个月', self.db_interface).set_factor_name('3个月shibor') 249 | 250 | @cached_property 251 | def six_month_shibor(self) -> InterestRateFactor: 252 | """6月期shibor""" 253 | return InterestRateFactor('shibor利率数据', '6个月', self.db_interface).set_factor_name('6个月shibor') 254 | 255 | @cached_property 256 | def one_year_shibor(self) -> InterestRateFactor: 257 | """一年期shibor""" 258 | return InterestRateFactor('shibor利率数据', '1年', self.db_interface).set_factor_name('1年shibor') 259 | 260 | @cached_property 261 | def model_factor_return(self): 262 | return ContinuousFactor('模型因子收益率', '收益率', self.db_interface) 263 | 264 | def get_index_return_factor(self, ticker: str) -> FactorBase: 265 | factor = ContinuousFactor('自合成指数', '收益率') if ticker.endswith('.IND') else self.index_return 266 | return factor.bind_params(ids=ticker) 267 | 268 | @staticmethod 269 | def exponential_weight(n: int, half_life: int): 270 | series = range(-(n - 1), 1) 271 | return np.exp(np.log(2) * series / half_life) 272 | 273 | @classmethod 274 | def from_config(cls, json_loc: str): 275 | """根据 ``config_loc`` 的适配信息生成 ``AShareDataReader`` 实例""" 276 | db_interface = generate_db_interface_from_config(json_loc) 277 | return cls(db_interface) 278 | -------------------------------------------------------------------------------- /AShareData/barra_descriptors.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | 4 | class BarraDescriptor(object): 5 | def __init__(self, factor_zoo): 6 | self.factor_zoo = factor_zoo 7 | 8 | def 
nature_log_of_market_cap(self): 9 | return self.factor_zoo.stock_market_cap.log() 10 | 11 | def beta(self, window: int, half_life: int): 12 | y = self.factor_zoo.excess_return() 13 | pass 14 | 15 | def relative_strength(self, window: int = 504, lag: int = 21, half_life: int = 126) -> pd.DataFrame: 16 | tmp = self.factor_zoo.log_return().sub(self.factor_zoo.log_shibor_return(), axis=0) 17 | exp_weight = self.factor_zoo.exponential_weight(window, half_life) 18 | tmp2 = tmp * exp_weight 19 | return tmp2.rolling(window, min_periods=window).sum().shift(lag) 20 | 21 | def daily_standard_deviation(self): 22 | pass 23 | 24 | def cumulative_range(self): 25 | pass 26 | 27 | def historical_sigma(self): 28 | pass 29 | 30 | def cube_of_size(self): 31 | pass 32 | 33 | def book_to_price_ratio(self): 34 | pass 35 | 36 | def share_turnover_one_month(self): 37 | pass 38 | 39 | def average_share_turnover_trailing_3_month(self): 40 | pass 41 | 42 | def average_share_turnover_trailing_12_months(self): 43 | pass 44 | 45 | def predicted_earning_to_price_ratio(self): 46 | pass 47 | 48 | def cash_earning_to_price_ratio(self): 49 | pass 50 | 51 | def trailing_earnings_to_price_ratio(self): 52 | pass 53 | 54 | def long_term_predicted_earning_growth(self): 55 | pass 56 | 57 | def short_term_predicted_earning_growth(self): 58 | pass 59 | 60 | def earnings_growth_trailing_5_years(self): 61 | pass 62 | 63 | def sales_growth_trailing_5_years(self): 64 | pass 65 | 66 | def market_leverage(self): 67 | pass 68 | 69 | def debt_to_assets(self): 70 | pass 71 | 72 | def book_leverage(self): 73 | pass 74 | -------------------------------------------------------------------------------- /AShareData/barra_style_factors.py: -------------------------------------------------------------------------------- 1 | from .barra_descriptors import BarraDescriptor 2 | 3 | 4 | class BarraStyleFactors(object): 5 | def __init__(self, descriptors: BarraDescriptor): 6 | self.descriptors = descriptors 7 | 8 | def 
size(self): 9 | return self.descriptors.nature_log_of_market_cap() 10 | 11 | def beta(self): 12 | window = 252 13 | half_life = 63 14 | return self.descriptors.beta(window, half_life) 15 | 16 | def momentum(self): 17 | return self.descriptors.relative_strength() 18 | 19 | def residual_volatility(self): 20 | return 0.74 * self.descriptors.daily_standard_deviation() + \ 21 | 0.16 * self.descriptors.cumulative_range() + \ 22 | 0.1 * self.descriptors.historical_sigma() 23 | 24 | def non_linear_size(self): 25 | return self.descriptors.cube_of_size() 26 | 27 | def book_to_price(self): 28 | return self.descriptors.book_to_price_ratio() 29 | 30 | def liquidity(self): 31 | return 0.35 * self.descriptors.share_turnover_one_month() + \ 32 | 0.35 * self.descriptors.average_share_turnover_trailing_3_month() + \ 33 | 0.3 * self.descriptors.average_share_turnover_trailing_12_months() 34 | 35 | def earning_yield(self): 36 | return 0.68 * self.descriptors.predicted_earning_to_price_ratio() + \ 37 | 0.21 * self.descriptors.cash_earning_to_price_ratio() + \ 38 | 0.11 * self.descriptors.trailing_earnings_to_price_ratio() 39 | 40 | def growth(self): 41 | return 0.18 * self.descriptors.long_term_predicted_earning_growth() + \ 42 | 0.11 * self.descriptors.short_term_predicted_earning_growth() + \ 43 | 0.24 * self.descriptors.earnings_growth_trailing_5_years() + \ 44 | 0.47 * self.descriptors.sales_growth_trailing_5_years() 45 | 46 | def leverage(self): 47 | return 0.38 * self.descriptors.market_leverage() + \ 48 | 0.35 * self.descriptors.debt_to_assets() + \ 49 | 0.27 * self.descriptors.book_leverage() 50 | -------------------------------------------------------------------------------- /AShareData/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Dict, Optional, Union 3 | 4 | import sqlalchemy as sa 5 | from sqlalchemy.engine.url import URL 6 | 7 | from .database_interface import DBInterface, MySQLInterface 
__config__: Optional[Dict] = None
__db_interface__: Optional[DBInterface] = None


def prepare_engine(config: Dict) -> sa.engine.Engine:
    """Create a sqlalchemy engine from a config dict.

    :param config: mapping with keys ``driver``, ``host``, ``port``, ``database``, ``username`` and ``password``
    :return: an engine whose URL carries the ``charset=utf8mb4`` query parameter
    """
    url_kwargs = dict(drivername=config['driver'], host=config['host'], port=config['port'],
                      database=config['database'], username=config['username'], password=config['password'],
                      query={'charset': 'utf8mb4'})
    # SQLAlchemy >= 1.4 deprecates calling ``URL(...)`` directly in favour of ``URL.create``;
    # fall back to the plain constructor on older versions that do not provide it.
    url_factory = getattr(URL, 'create', URL)
    url = url_factory(**url_kwargs)
    return sa.create_engine(url)


def generate_db_interface_from_config(config_loc: Union[str, Dict], init: bool = False) -> Optional[DBInterface]:
    """Build a database interface from a JSON config file path or an already-parsed config dict.

    :param config_loc: path to a JSON config file, or the parsed config dict itself
    :param init: forwarded to :class:`MySQLInterface` as ``init``
    :return: a :class:`MySQLInterface` when ``db_interface.driver`` contains ``mysql``, otherwise ``None``
    """
    if isinstance(config_loc, str):
        with open(config_loc, 'r', encoding='utf-8') as f:
            global_config = json.load(f)
    else:
        global_config = config_loc
    db_config = global_config['db_interface']
    if 'mysql' in db_config['driver']:
        return MySQLInterface(prepare_engine(db_config), init=init)
    # only MySQL back ends are handled here
    return None


def set_global_config(config_loc: str):
    """Load the JSON file at ``config_loc`` and make it the process-wide configuration."""
    global __config__
    with open(config_loc, 'r', encoding='utf-8') as f:
        __config__ = json.load(f)


def get_global_config():
    """Return the process-wide configuration.

    :raises ValueError: if :func:`set_global_config` has not been called yet
    """
    if __config__ is None:
        raise ValueError('Global configuration not set. Please use "set_global_config" to initialize.')
    return __config__


def get_db_interface():
    """Return the process-wide database interface, creating it from the global config on first use."""
    global __db_interface__
    if __db_interface__ is None:
        __db_interface__ = generate_db_interface_from_config(get_global_config())
    return __db_interface__


def set_db_interface(db_interface: DBInterface):
    """Replace the process-wide database interface returned by :func:`get_db_interface`."""
    # without ``global`` the assignment would only bind a local name and have no effect
    global __db_interface__
    __db_interface__ = db_interface


# ---- /AShareData/constants.py ----
import datetime as dt

TRADING_DAYS_IN_YEAR = 244

# exchanges
STOCK_EXCHANGES = ['SSE', 'SZSE']
FUTURE_EXCHANGES = ['CFFEX', 'DCE', 'CZCE', 'SHFE', 'INE']
ALL_EXCHANGES = STOCK_EXCHANGES + FUTURE_EXCHANGES

# indexes
STOCK_INDEXES = {'上证指数': '000001.SH', '深证成指': '399001.SZ', '中小板指': '399005.SZ', '创业板指': '399006.SZ',
                 '上证50': '000016.SH', '沪深300': '000300.SH', '中证500': '000905.SH'}
BOARD_INDEXES = ['000016.SH', '000300.SH', '000905.SH']
STOCK_INDEX_ETFS = {'中小板': '159902.SZ', '创业板': '159915.SZ', '50ETF': '510050.SH', '300ETF': '510300.SH',
                    '500ETF': '510500.SH'}

# financial statements
FINANCIAL_STATEMENTS_TYPE = ['资产负债表', '利润表', '现金流量表', '财务指标']

# industry constants
INDUSTRY_DATA_PROVIDER = ['中信', '申万', '中证', 'Wind']
INDUSTRY_DATA_PROVIDER_CODE_DICT = {'中信': 'citic', '申万': 'sw', '中证': 'csi', 'Wind': 'gics'}
INDUSTRY_LEVEL = {'中信': 3, '申万': 3, '中证': 4, 'Wind': 4}
INDUSTRY_START_DATE = {'中信': dt.datetime(2003, 1, 2), '申万': dt.datetime(2005, 5, 27), '中证': dt.datetime(2016, 12, 12),
                       'Wind': dt.datetime(2005, 1, 5)}

# ---- /AShareData/data/__init__.py ----
https://raw.githubusercontent.com/jicewarwick/AShareData/13c78602fe00a5326f421c8a8003f3889492e6dd/AShareData/data/__init__.py -------------------------------------------------------------------------------- /AShareData/data/jqdata_param.json: -------------------------------------------------------------------------------- 1 | { 2 | "可转债信息": { 3 | "code": "ID", 4 | "short_name": "名称", 5 | "company_code": "发行人ID", 6 | "list_date": "上市日期", 7 | "delist_Date": "退市日期" 8 | }, 9 | "股票集合竞价数据": { 10 | "time": "DateTime", 11 | "code": "ID", 12 | "current": "成交价", 13 | "volume": "成交量", 14 | "money": "成交额", 15 | "a1_p": "卖1价", 16 | "a2_p": "卖2价", 17 | "a3_p": "卖3价", 18 | "a4_p": "卖4价", 19 | "a5_p": "卖5价", 20 | "a1_v": "卖1量", 21 | "a2_v": "卖2量", 22 | "a3_v": "卖3量", 23 | "a4_v": "卖4量", 24 | "a5_v": "卖5量", 25 | "b1_p": "买1价", 26 | "b2_p": "买2价", 27 | "b3_p": "买3价", 28 | "b4_p": "买4价", 29 | "b5_p": "买5价", 30 | "b1_v": "买1量", 31 | "b2_v": "买2量", 32 | "b3_v": "买3量", 33 | "b4_v": "买4量", 34 | "b5_v": "买5量" 35 | }, 36 | "行情数据": { 37 | "time": "DateTime", 38 | "code": "ID", 39 | "open": "开盘价", 40 | "high": "最高价", 41 | "low": "最低价", 42 | "close": "收盘价", 43 | "volume": "成交量", 44 | "money": "成交额", 45 | "open_interest": "持仓量", 46 | "settle": "结算价", 47 | "delta": "Delta", 48 | "theta": "Theta", 49 | "gamma": "Gamma", 50 | "rho": "Rho", 51 | "vega": "Vega" 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /AShareData/data/tdx_param.json: -------------------------------------------------------------------------------- 1 | { 2 | "行情数据": { 3 | "datetime": "DateTime", 4 | "code": "ID", 5 | "open": "开盘价", 6 | "high": "最高价", 7 | "low": "最低价", 8 | "close": "收盘价", 9 | "vol": "成交量", 10 | "amount": "成交额" 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /AShareData/data/wind_param.json: -------------------------------------------------------------------------------- 1 | { 2 | "股票分钟行情": { 3 | "OPEN": 
"开盘价", 4 | "HIGH": "最高价", 5 | "LOW": "最低价", 6 | "CLOSE": "收盘价", 7 | "VOLUME": "成交量", 8 | "AMOUNT": "成交额" 9 | }, 10 | "股票日行情": { 11 | "OPEN": "开盘价", 12 | "HIGH": "最高价", 13 | "LOW": "最低价", 14 | "CLOSE": "收盘价", 15 | "VOLUME": "成交量", 16 | "AMT": "成交额" 17 | }, 18 | "股票停牌": { 19 | "sec_name": "证券名称", 20 | "suspend_type": "停牌类型", 21 | "suspend_reason": "停牌原因" 22 | }, 23 | "指数日行情": { 24 | "OPEN": "开盘点位", 25 | "HIGH": "最高点位", 26 | "LOW": "最低点位", 27 | "CLOSE": "收盘点位", 28 | "VOLUME": "成交量", 29 | "AMT": "成交额" 30 | }, 31 | "可转债日行情": { 32 | "OPEN": "开盘价", 33 | "HIGH": "最高价", 34 | "LOW": "最低价", 35 | "CLOSE": "收盘价", 36 | "VOLUME": "成交量", 37 | "AMT": "成交额", 38 | "CLAUSE_CONVERSION2_BONDLOT": "未转股余额" 39 | }, 40 | "可转债分钟行情": { 41 | "OPEN": "开盘价", 42 | "HIGH": "最高价", 43 | "LOW": "最低价", 44 | "CLOSE": "收盘价", 45 | "VOLUME": "成交量", 46 | "AMOUNT": "成交额" 47 | }, 48 | "期货日行情": { 49 | "OPEN": "开盘价", 50 | "HIGH": "最高价", 51 | "LOW": "最低价", 52 | "CLOSE": "收盘价", 53 | "SETTLE": "结算价", 54 | "VOLUME": "成交量", 55 | "AMT": "成交额", 56 | "OI": "持仓量" 57 | }, 58 | "期权日行情": { 59 | "date": "DateTime", 60 | "option_code": "ID", 61 | "AMT": "成交额", 62 | "OPEN": "开盘价", 63 | "HIGH": "最高价", 64 | "LOW": "最低价", 65 | "CLOSE": "收盘价", 66 | "VOLUME": "成交量", 67 | "OI": "持仓量", 68 | "DELTA": "Delta", 69 | "GAMMA": "Gamma", 70 | "VEGA": "Vega", 71 | "THETA": "Theta", 72 | "RHO": "Rho" 73 | }, 74 | "场内基金日行情": { 75 | "OPEN": "开盘价", 76 | "HIGH": "最高价", 77 | "LOW": "最低价", 78 | "CLOSE": "收盘价", 79 | "VOLUME": "成交量", 80 | "AMT": "成交额", 81 | "NAV": "单位净值", 82 | "UNIT_TOTAL": "基金份额", 83 | "NAV_ADJ": "复权因子" 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /AShareData/data/自编指数配置.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jicewarwick/AShareData/13c78602fe00a5326f421c8a8003f3889492e6dd/AShareData/data/自编指数配置.xlsx -------------------------------------------------------------------------------- 
# ---- /AShareData/data_source/__init__.py ----
import importlib.util
import logging

from .data_source import DataSource
from .jq_data import JQData
from .tdx_data import TDXData
from .tushare_data import TushareData
from .web_data import WebDataCrawler

if importlib.util.find_spec('WindPy'):
    from .wind_data import WindData

    logging.getLogger(__name__).info('WindPy found')
else:
    logging.getLogger(__name__).debug('WindPy not found!!')

# ---- /AShareData/data_source/data_source.py ----
import datetime as dt

import pandas as pd

from .. import date_utils
from ..config import get_db_interface
from ..database_interface import DBInterface


class DataSource(object):
    """Data Source Base Class

    Holds a database interface and an SH/SZ trading calendar built on it. Subclasses override
    :meth:`login` / :meth:`logout`; the class is a context manager that calls them on enter / exit.
    """

    def __init__(self, db_interface: DBInterface = None) -> None:
        # fall back to the globally configured interface when none (or a falsy one) is given
        self.db_interface = db_interface if db_interface else get_db_interface()
        self.calendar = date_utils.SHSZTradingCalendar(self.db_interface)

    def __enter__(self):
        self.login()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # returns None, so exceptions raised inside the ``with`` body propagate
        self.logout()

    def login(self):
        """Connect to the data source. No-op in the base class."""
        pass

    def logout(self):
        """Disconnect from the data source. No-op in the base class."""
        pass


class MinutesDataFunctionMixin(object):
    """Helpers that merge 09:25 auction snapshots with one-minute bars."""

    @staticmethod
    def _auction_data_to_price_data(auction_data: pd.DataFrame) -> pd.DataFrame:
        """Turn auction rows into price bars: ``成交价`` is copied into the open/high/low/close columns and dropped.

        Works on a copy, so the caller's frame is left unchanged.
        """
        price_data = auction_data.copy()
        for column in ['开盘价', '最高价', '最低价', '收盘价']:
            price_data[column] = price_data['成交价']
        return price_data.drop('成交价', axis=1)

    @classmethod
    def left_shift_minute_data(cls, minute_data: pd.DataFrame, auction_db_data: pd.DataFrame) -> pd.DataFrame:
        """Relabel minute bars by their start minute and fold in the 09:25 auction.

        * auction rows from ``auction_db_data`` become price bars;
        * the 09:31 bar is relabeled 09:30 with the auction ``成交量``/``成交额`` subtracted;
        * bars strictly between 09:31 and 15:00 are moved one minute earlier;
        * the 15:00 bar is kept as is.

        Rows whose ``成交量`` is below 1 are dropped from the result.

        :param minute_data: bars indexed by (``DateTime``, ``ID``), all on the same date
        :param auction_db_data: auction rows with ``成交价``, ``成交量``, ``成交额`` indexed by (``DateTime``, ``ID``)
        """
        auction_data = cls._auction_data_to_price_data(auction_db_data)

        date = minute_data.index.get_level_values('DateTime')[0].date()
        t0930 = dt.datetime.combine(date, dt.time(9, 30))
        t0931 = dt.datetime.combine(date, dt.time(9, 31))
        t1500 = dt.datetime.combine(date, dt.time(15, 0))

        # morning auction
        diff_columns = ['成交量', '成交额']
        first_min_data = minute_data.loc[minute_data.index.get_level_values('DateTime') == t0931, :]

        tmp = first_min_data.loc[:, diff_columns].droplevel('DateTime').fillna(0) - \
            auction_data.loc[:, diff_columns].droplevel('DateTime').fillna(0)
        tmp['DateTime'] = t0930
        tmp.set_index('DateTime', append=True, inplace=True)
        tmp.index = tmp.index.swaplevel()

        new_index = pd.MultiIndex.from_product([[t0930], first_min_data.index.get_level_values('ID')],
                                               names=['DateTime', 'ID'])
        first_min_data = first_min_data.drop(diff_columns, axis=1)
        first_min_data.index = new_index

        first_minute_db_data = pd.concat([first_min_data, tmp], sort=True, axis=1)

        # mid data; copy the filtered slice before assigning to one of its columns
        mid_data = minute_data.reset_index()
        mid_data = mid_data.loc[(mid_data.DateTime < t1500) & (mid_data.DateTime > t0931), :].copy()
        mid_data.DateTime = mid_data.DateTime - dt.timedelta(minutes=1)
        mid_data = mid_data.set_index(['DateTime', 'ID'], drop=True)

        # afternoon auction
        end_data = minute_data.loc[minute_data.index.get_level_values('DateTime') == t1500, :]

        # combine all
        storage = [auction_data, first_minute_db_data, mid_data, end_data]
        ret = pd.concat(storage)
        ret = ret.loc[ret['成交量'] >= 1, :]
        return ret

# ---- /AShareData/data_source/jq_data.py ----
import datetime as dt
import json
from functools import cached_property
from typing import Mapping, Optional, Union

import pandas as pd
from tqdm import tqdm

from .data_source import DataSource, MinutesDataFunctionMixin
from .. import config, date_utils, utils
from ..database_interface import DBInterface
from ..tickers import ETFOptionTickers, FutureTickers, IndexOptionTickers, StockTickers

with utils.NullPrinter():
    import jqdatasdk as jq


class JQData(DataSource, MinutesDataFunctionMixin):
    """Download market data through ``jqdatasdk`` (JoinQuant) and store it via ``db_interface``."""

    def __init__(self, db_interface: DBInterface = None, mobile: str = None, password: str = None):
        """
        :param db_interface: database interface; when ``None`` the global one is used and missing
            credentials are read from the ``join_quant`` section of the global config
        :param mobile: JoinQuant account (mobile number)
        :param password: JoinQuant password
        """
        if db_interface is None:
            db_interface = config.get_db_interface()
            global_config = config.get_global_config()
            # explicitly passed credentials win over the ones stored in the global config
            if mobile is None:
                mobile = global_config['join_quant']['mobile']
            if password is None:
                password = global_config['join_quant']['password']

        super().__init__(db_interface)
        self.mobile = mobile
        self.password = password
        self.is_logged_in = False
        self._factor_param = utils.load_param('jqdata_param.json')

    def login(self):
        """Authenticate once; raise ``ValueError`` if ``jq.is_auth()`` reports failure."""
        if not self.is_logged_in:
            with utils.NullPrinter():
                jq.auth(self.mobile, self.password)
            if jq.is_auth():
                self.is_logged_in = True
            else:
                raise ValueError('JQDataLoginError: Wrong mobile number or password')

    def logout(self):
        """Log out if currently logged in."""
        if self.is_logged_in:
            with utils.NullPrinter():
                jq.logout()
            # reset the flag so a later ``login`` (e.g. re-entering the context manager) authenticates again
            self.is_logged_in = False

    @cached_property
    def stock_tickers(self):
        """Stock tickers backed by ``db_interface`` (created on first access)."""
        return StockTickers(self.db_interface)

    @cached_property
    def future_tickers(self):
        """Future tickers backed by ``db_interface`` (created on first access)."""
        return FutureTickers(self.db_interface)

    @cached_property
    def stock_index_option_tickers(self):
        """Index option tickers backed by ``db_interface`` (created on first access)."""
        return IndexOptionTickers(self.db_interface)

    @cached_property
    def stock_etf_option_tickers(self):
        """ETF option tickers backed by ``db_interface`` (created on first access)."""
        return ETFOptionTickers(self.db_interface)

    def update_convertible_bond_list(self):
        """Refresh ``可转债信息`` with the convertible bonds listed in ``jq.bond.BOND_BASIC_INFO``."""
        q = jq.query(jq.bond.BOND_BASIC_INFO).filter(jq.bond.BOND_BASIC_INFO.bond_type == '可转债')
        df = jq.bond.run_query(q)
        exchange = df.exchange.map({'深交所主板': '.SZ', '上交所': '.SH'})
        df.code = df.code + exchange
        renaming_dict = self._factor_param['可转债信息']
        df.company_code = df.company_code.apply(self.jqcode2windcode)
        df.list_date = df.list_date.apply(date_utils.date_type2datetime)
        df.delist_Date = df.delist_Date.apply(date_utils.date_type2datetime)
        ret = df.loc[:, list(renaming_dict)].rename(renaming_dict, axis=1).set_index('ID')
        self.db_interface.update_df(ret, '可转债信息')

    def _get_stock_minute_data_first_minute(self, date: dt.datetime):
        """Store the 09:25 auction bar and the 09:30 bar of ``date`` in ``股票分钟行情``.

        The bar requested at 09:31 is relabeled 09:30 and the auction ``成交量``/``成交额`` are subtracted from it.
        """
        renaming_dict = self._factor_param['行情数据']
        diff_columns = ['成交量', '成交额']
        tickers = self.stock_tickers.ticker(date)
        tickers = [self.windcode2jqcode(it) for it in tickers]

        auction_time = date + dt.timedelta(hours=9, minutes=25)
        auction_data = self.db_interface.read_table('股票集合竞价数据', columns=['成交价', '成交量', '成交额'], dates=auction_time)
        auction_db_data = self._auction_data_to_price_data(auction_data)

        real_first_minute = date + dt.timedelta(hours=9, minutes=30)
        first_minute = date + dt.timedelta(hours=9, minutes=31)
        first_minute_raw = jq.get_price(tickers, start_date=first_minute, end_date=first_minute, frequency='1m',
                                        fq=None, fill_paused=True)
        first_minute_raw.time = real_first_minute
        first_minute_data = self._standardize_df(first_minute_raw, renaming_dict)
        tmp = first_minute_data.loc[:, diff_columns].droplevel('DateTime').fillna(0) - \
            auction_db_data.loc[:, diff_columns].droplevel('DateTime').fillna(0)
        tmp['DateTime'] = real_first_minute
        tmp.set_index('DateTime', append=True, inplace=True)
        tmp.index = tmp.index.swaplevel()
        first_minute_db_data = pd.concat([first_minute_data.drop(diff_columns, axis=1), tmp], sort=True, axis=1)
        db_data = pd.concat([auction_db_data, first_minute_db_data], sort=True)
        self.db_interface.insert_df(db_data, '股票分钟行情')

    def _get_stock_minute_data_after_first_minute(self, date: dt.datetime):
        """Store the bars of ``date`` from 09:32 onwards in ``股票分钟行情``.

        Bars 09:32-14:58 are moved one minute earlier. The closing bars are handled per exchange
        (see the ``SZ`` / ``SH`` sections below).
        """
        renaming_dict = self._factor_param['行情数据']
        tickers = self.stock_tickers.ticker(date)
        tickers = [self.windcode2jqcode(it) for it in tickers]

        t0932 = date + dt.timedelta(hours=9, minutes=32)
        t1458 = date + dt.timedelta(hours=14, minutes=58)
        t1459 = date + dt.timedelta(hours=14, minutes=59)
        t1500 = date + dt.timedelta(hours=15)

        data = jq.get_price(tickers, start_date=t0932, end_date=t1458, frequency='1m', fq=None, fill_paused=True)
        data.time = data.time.apply(lambda x: x - dt.timedelta(minutes=1))
        db_data = self._standardize_df(data, renaming_dict)
        self.db_interface.insert_df(db_data, '股票分钟行情')

        # SZ: the 15:00 bar is stored unshifted
        sz_tickers = [it for it in tickers if it.endswith('XSHE')]
        data = jq.get_price(sz_tickers, start_date=t1500, end_date=t1500, frequency='1m', fq=None, fill_paused=True)
        db_data = self._standardize_df(data, renaming_dict)
        self.db_interface.insert_df(db_data, '股票分钟行情')

        # SH: keep 14:59-15:00 bars with volume; shift them one minute earlier when a 14:59 bar remains
        sh_tickers = [it for it in tickers if it.endswith('XSHG')]
        data = jq.get_price(sh_tickers, start_date=t1459, end_date=t1500, frequency='1m', fq=None, fill_paused=True)
        data = data.loc[data.volume > 0, :].copy()
        if t1459 in data.time.tolist():
            data.time = data.time.apply(lambda x: x - dt.timedelta(minutes=1))
        db_data = self._standardize_df(data, renaming_dict)
        self.db_interface.insert_df(db_data, '股票分钟行情')

    def get_stock_minute(self, date: dt.datetime):
        """Download and store all stock minute bars of ``date``."""
        self._get_stock_minute_data_first_minute(date)
        self._get_stock_minute_data_after_first_minute(date)

    def update_stock_minute(self):
        """Fill ``股票分钟行情`` from the trading day after the table's latest timestamp onwards."""
        table_name = '股票分钟行情'
        db_timestamp = self.db_interface.get_latest_timestamp(table_name, dt.datetime(2015, 1, 1))
        start_date = self.calendar.offset(db_timestamp.date(), 1)
        # before 16:00, stop at the previous trading day
        if dt.datetime.now().hour < 16:
            end_date = self.calendar.yesterday()
        else:
            end_date = dt.datetime.today()
        dates = self.calendar.select_dates(start_date, end_date)
        with tqdm(dates) as pbar:
            for date in dates:
                pbar.set_description(f'更新{date}的{table_name}')
                self.get_stock_minute(date)
                pbar.update()

    def update_stock_morning_auction_data(self):
        """Update the morning call-auction table from the trading day after its latest timestamp up to today."""
        table_name = '股票集合竞价数据'
        db_timestamp = self.db_interface.get_latest_timestamp(table_name, dt.datetime(2015, 1, 1))
        start_date = self.calendar.offset(db_timestamp.date(), 1)
        end_date = dt.datetime.today()
        dates = self.calendar.select_dates(start_date, end_date)
        with tqdm(dates) as pbar:
            for date in dates:
                pbar.set_description(f'更新{date}的{table_name}')
                self.stock_open_auction_data(date)
                pbar.update()

    @date_utils.dtlize_input_dates
    def stock_open_auction_data(self, date: date_utils.DateType):
        """Fetch the morning call-auction data of ``date`` and store the rows with positive volume."""
        table_name = '股票集合竞价数据'
        renaming_dict = self._factor_param[table_name]
        date_str = date_utils.date_type2str(date, '-')
        tickers = self.stock_tickers.ticker(date)
        tickers = [self.windcode2jqcode(it) for it in tickers]
        data = jq.get_call_auction(tickers, start_date=date_str, end_date=date_str)
        auction_time = dt.datetime.combine(date.date(), dt.time(hour=9, minute=25))
        data.time = auction_time
        # copy: ``_standardize_df`` writes into the frame it receives
        data = data.loc[data.volume > 0, :].copy()
        db_data = self._standardize_df(data, renaming_dict)
        self.db_interface.insert_df(db_data, table_name)

    @date_utils.dtlize_input_dates
    def get_stock_daily(self, date: date_utils.DateType):
        """Download and store the unadjusted daily bars of all stocks on ``date``."""
        renaming_dict = self._factor_param['行情数据']
        tickers = self.stock_tickers.ticker(date)
        tickers = [self.windcode2jqcode(it) for it in tickers]

        data = jq.get_price(tickers, start_date=date, end_date=date, frequency='daily', fq=None, fill_paused=True)
        db_data = self._standardize_df(data, renaming_dict)
        self.db_interface.insert_df(db_data, '股票日行情')

    def update_stock_daily(self):
        """Fill ``股票日行情`` for the trading days after the table's latest timestamp."""
        table_name = '股票日行情'
        db_timestamp = self.db_interface.get_latest_timestamp(table_name, dt.datetime(2015, 1, 1))
        dates = self.calendar.select_dates(db_timestamp, dt.date.today())
        # drop the first date, presumably the already-stored ``db_timestamp`` -- TODO confirm
        dates = dates[1:]
        for date in dates:
            self.get_stock_daily(date)

    @date_utils.dtlize_input_dates
    def get_future_daily(self, date: date_utils.DateType):
        """Download and store daily bars plus settlement prices of all futures on ``date``."""
        renaming_dict = self._factor_param['行情数据']
        tickers = self.future_tickers.ticker(date)
        tickers = [self.windcode2jqcode(it) for it in tickers]

        data = jq.get_price(tickers, start_date=date, end_date=date, frequency='daily', fq=None, fill_paused=True,
                            fields=['open', 'high', 'low', 'close', 'volume', 'money', 'open_interest'])
        settle_data = jq.get_extras('futures_sett_price', tickers, start_date=date, end_date=date)
        settle = settle_data.stack().reset_index()
        settle.columns = ['time', 'code', 'settle']
        combined_data = pd.merge(data, settle)
        db_data = self._standardize_df(combined_data, renaming_dict).sort_index()
        self.db_interface.insert_df(db_data, '期货日行情')

    def update_future_daily(self):
        """Fill ``期货日行情`` for the trading days after the table's latest timestamp."""
        table_name = '期货日行情'
        db_timestamp = self.db_interface.get_latest_timestamp(table_name, dt.datetime(2015, 1, 1))
        dates = self.calendar.select_dates(db_timestamp, dt.date.today())
        dates = dates[1:]
        for date in dates:
            self.get_future_daily(date)

    # TODO
    def _get_future_settle_info(self, date):
        """Return the settlement price (``结算价``) of every future on ``date`` as a Series indexed by (DateTime, ID)."""
        # table_name = '期货结算参数'
        tickers = self.future_tickers.ticker(date)
        tickers = [self.windcode2jqcode(it) for it in tickers]

        data = jq.get_extras('futures_sett_price', tickers, start_date=date, end_date=date)
        data.columns = [self.jqcode2windcode(it) for it in data.columns]
        df = data.stack()
        df.name = '结算价'
        df.index.names = ['DateTime', 'ID']
        return df

    @date_utils.dtlize_input_dates
    def get_stock_option_daily(self, date: date_utils.DateType):
        """Download and store daily bars plus risk indicators of index and ETF options on ``date``."""
        renaming_dict = self._factor_param['行情数据']
        tickers = self.stock_index_option_tickers.ticker(date) + self.stock_etf_option_tickers.ticker(date)
        tickers = [self.windcode2jqcode(it) for it in tickers]

        data = jq.get_price(tickers, start_date=date, end_date=date, frequency='daily', fq=None, fill_paused=True,
                            fields=['open', 'high', 'low', 'close', 'volume', 'money', 'open_interest'])
        q = jq.query(jq.opt.OPT_RISK_INDICATOR).filter(jq.opt.OPT_RISK_INDICATOR.date == date) \
            .filter(jq.opt.OPT_RISK_INDICATOR.exchange_code.in_(['XSHG', 'XSHE', 'CCFX']))
        risk_data = jq.opt.run_query(q)
        risk = risk_data.drop(['id', 'exchange_code', 'date'], axis=1)
        combined_data = pd.merge(data, risk)
        db_data = self._standardize_df(combined_data, renaming_dict).sort_index()
        self.db_interface.insert_df(db_data, '期权日行情')

    def update_stock_option_daily(self):
        """Fill ``期权日行情`` for the trading days after the table's latest timestamp."""
        table_name = '期权日行情'
        db_timestamp = self.db_interface.get_latest_timestamp(table_name, dt.datetime(2015, 1, 1))
        dates = self.calendar.select_dates(db_timestamp, dt.date.today())
        dates = dates[1:]
        for date in dates:
            self.get_stock_option_daily(date)

    @staticmethod
    def _standardize_df(df: pd.DataFrame, parameter_info: Mapping[str, str]) -> Union[pd.Series, pd.DataFrame]:
        """Convert ``*date``/``*time`` columns to datetime, rename via ``parameter_info`` and set the key index.

        ``df`` itself is modified. ``ID`` values are converted to Wind codes; the index is built from whichever of
        ``DateTime``, ``ID``, ``报告期``, ``IndexCode`` are present. A single remaining column is returned as a Series.
        """
        dates_columns = [it for it in df.columns if it.endswith(('date', 'time'))]
        for it in dates_columns:
            df[it] = df[it].apply(date_utils.date_type2datetime)

        df.rename(parameter_info, axis=1, inplace=True)
        if 'ID' in df.columns:
            df.ID = df.ID.apply(JQData.jqcode2windcode)
        index = sorted(list({'DateTime', 'ID', '报告期', 'IndexCode'} & set(df.columns)))
        df = df.set_index(index, drop=True)
        if df.shape[1] == 1:
            df = df.iloc[:, 0]
        return df

    @staticmethod
    def jqcode2windcode(ticker: str) -> Optional[str]:
        """Translate a JoinQuant exchange suffix to its Wind counterpart; falsy input is returned unchanged."""
        if ticker:
            ticker = ticker.replace('.XSHG', '.SH').replace('.XSHE', '.SZ')
            ticker = ticker.replace('.XDCE', '.DCE').replace('.XSGE', '.SHF').replace('.XZCE', '.CZC')
            ticker = ticker.replace('.XINE', '.INE')
            ticker = ticker.replace('.CCFX', '.CFE')
            if ticker.endswith('.CZC'):
                ticker = utils.format_czc_ticker(ticker)
            return ticker

    @staticmethod
    def windcode2jqcode(ticker: str) -> Optional[str]:
        """Translate a Wind exchange suffix to its JoinQuant counterpart; falsy input is returned unchanged."""
        if ticker:
            # futures suffixes first, so that '.SHF' is not caught by the '.SH' replacement below
            ticker = ticker.replace('.DCE', '.XDCE').replace('.SHF', '.XSGE').replace('.CZC', '.XZCE')
            ticker = ticker.replace('.CFE', '.CCFX')
            ticker = ticker.replace('.INE', '.XINE')
            ticker = ticker.replace('.SH', '.XSHG').replace('.SZ', '.XSHE')
            if ticker.endswith('.XZCE') and len(ticker) <= 11:
                ticker = utils.full_czc_ticker(ticker)
            return ticker

    @classmethod
    def from_config(cls, config_loc: str):
        """Create an instance from a JSON config file with ``db_interface`` and ``join_quant`` sections."""
        with open(config_loc, 'r', encoding='utf-8') as f:
            global_config = json.load(f)
        # reuse the parsed dict instead of reading the file a second time
        db_interface = config.generate_db_interface_from_config(global_config)
        return cls(db_interface, global_config['join_quant']['mobile'], global_config['join_quant']['password'])

# ---- /AShareData/data_source/tdx_data.py ----
import datetime as dt
from collections import OrderedDict
from typing import Sequence, Tuple

import pandas as pd
from pytdx.hq import TdxHq_API
from pytdx.params import TDXParams
from tqdm import tqdm

from .data_source import DataSource, MinutesDataFunctionMixin
from .. import utils
from ..config import get_global_config
from ..database_interface import DBInterface
from ..tickers import ConvertibleBondTickers, StockTickers


class TDXData(DataSource, MinutesDataFunctionMixin):
    """Download minute bars from a TDX quote server through ``pytdx``."""

    def __init__(self, db_interface: DBInterface = None, host: str = None, port: int = None):
        """
        :param db_interface: database interface; the global one is used when ``None``
        :param host: server host; when ``None``, host and port come from the ``tdx_server`` config section
        :param port: server port
        """
        super().__init__(db_interface)
        if host is None:
            conf = get_global_config()
            host = conf['tdx_server']['host']
            port = conf['tdx_server']['port']
        self.api = TdxHq_API()
        self.host = host
        self.port = port
        self._factor_param = utils.load_param('tdx_param.json')
        # use the resolved interface: the ``db_interface`` argument may be None
        self.stock_ticker = StockTickers(self.db_interface)

    def login(self):
        """Connect to the TDX server."""
        self.api.connect(self.host, self.port)

    def logout(self):
        """Disconnect from the TDX server."""
        self.api.disconnect()

    def update_stock_minute(self):
        """Update stock minute bars from the trading day after the table's latest timestamp up to today."""
        table_name = '股票分钟行情'
        db_timestamp = self.db_interface.get_latest_timestamp(table_name, dt.datetime(2015, 1, 1))
        start_date = self.calendar.offset(db_timestamp.date(), 1)
        end_date = dt.datetime.today()
        dates = self.calendar.select_dates(start_date, end_date)
        for date in dates:
            self.get_stock_minute(date)

    def get_stock_minute(self, date: dt.datetime) -> None:
        """Fetch the stock minute bars of ``date``, merge them with that day's auction data and store them."""
        tickers = self.stock_ticker.ticker(date)
        minute_data = self.get_minute_data(date, tickers)
        auction_time = date + dt.timedelta(hours=9, minutes=25)
        auction_db_data = self.db_interface.read_table('股票集合竞价数据', columns=['成交价', '成交量', '成交额'], dates=auction_time)
        df = self.left_shift_minute_data(minute_data=minute_data, auction_db_data=auction_db_data)

        self.db_interface.insert_df(df, '股票分钟行情')

    def update_convertible_bond_minute(self):
        """Update convertible bond minute bars from the trading day after the table's latest timestamp."""
        table_name = '可转债分钟行情'
        cb_tickers = ConvertibleBondTickers(self.db_interface)

        db_timestamp = self.db_interface.get_latest_timestamp(table_name, dt.datetime(1998, 9, 2))
        start_date = self.calendar.offset(db_timestamp.date(), 1)
        end_date = dt.datetime.today()
        dates = self.calendar.select_dates(start_date, end_date)

        for date in dates:
            tickers = cb_tickers.ticker(date)
            minute_data = self.get_minute_data(date, tickers)
            self.db_interface.insert_df(minute_data, table_name)

    def get_minute_data(self, date: dt.datetime, tickers: Sequence[str]) -> pd.DataFrame:
        """Download 240 one-minute bars per ticker for ``date``; returns an empty frame if nothing came back.

        The ``start`` offset is 240 bars per trading day between ``date`` and today (presumably counted back
        from the most recent bar -- TODO confirm against pytdx).
        """
        num_days = self.calendar.days_count(date, dt.date.today())
        start_index = num_days * 60 * 4

        storage = []
        with tqdm(tickers) as pbar:
            for ticker in tickers:
                pbar.set_description(f'下载 {ticker} 在 {date} 的分钟数据')
                code, market = self._split_ticker(ticker)
                data = self.api.get_security_bars(category=TDXParams.KLINE_TYPE_1MIN, market=market, code=code,
                                                  start=start_index, count=240)
                if data:
                    data = self._formatting_data(data, ticker)
                    storage.append(data)
                pbar.update()

        df = pd.concat(storage) if storage else pd.DataFrame()
        return df

    def _formatting_data(self, info: OrderedDict, ticker: str) -> pd.DataFrame:
        """Convert raw pytdx bars of ``ticker`` into a frame indexed by (DateTime, ID); |values| <= 1e-4 become 0."""
        df = pd.DataFrame(info)
        df['datetime'] = df['datetime'].apply(self.str2datetime)
        df = df.drop(['year', 'month', 'day', 'hour', 'minute'], axis=1).rename(self._factor_param['行情数据'], axis=1)
        df['ID'] = ticker

        df = df.set_index(['DateTime', 'ID'], drop=True)
        df = df.where(abs(df) > 0.0001, 0)
        return df

    @staticmethod
    def _split_ticker(ticker: str) -> Tuple[str, int]:
        """Split ``'000001.SZ'`` into the code and the pytdx market id (anything not ``SZ`` maps to SH)."""
        code, market_str = ticker.split('.')
        market = TDXParams.MARKET_SZ if market_str == 'SZ' else TDXParams.MARKET_SH
        return code, market

    @staticmethod
    def str2datetime(date: str) -> dt.datetime:
        """Parse a ``'%Y-%m-%d %H:%M'`` timestamp."""
        return dt.datetime.strptime(date, '%Y-%m-%d %H:%M')

# ---- /AShareData/data_source/web_data.py ----
-------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import logging 3 | from io import StringIO 4 | from typing import Sequence, Union 5 | 6 | import pandas as pd 7 | import requests 8 | from tqdm import tqdm 9 | 10 | from .data_source import DataSource 11 | from .. import date_utils, utils 12 | from ..config import get_db_interface 13 | from ..database_interface import DBInterface 14 | from ..tickers import StockTickers 15 | 16 | 17 | class WebDataCrawler(DataSource): 18 | """Get data through HTTP connections""" 19 | _SW_INDUSTRY_URL = 'http://www.swsindex.com/downloadfiles.aspx' 20 | _HEADER = { 21 | 'Connection': 'keep-alive', 22 | 'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/72.0.3626.121 Safari/537.36', 23 | } 24 | _ZZ_INDUSTRY_URL = 'http://www.csindex.com.cn/zh-CN/downloads/industry-price-earnings-ratio-detail' 25 | 26 | def __init__(self, db_schema_loc: str = None, init: bool = False, db_interface: DBInterface = None) -> None: 27 | if db_interface is None: 28 | db_interface = get_db_interface() 29 | super().__init__(db_interface) 30 | if init: 31 | logging.getLogger(__name__).debug('检查数据库完整性.') 32 | self._db_parameters = utils.load_param('db_schema.json', db_schema_loc) 33 | for table_name, type_info in self._db_parameters.items(): 34 | self.db_interface.create_table(table_name, type_info) 35 | 36 | self._stock_list = StockTickers(db_interface).ticker() 37 | 38 | def get_sw_industry(self) -> None: 39 | """获取申万一级行业""" 40 | header = self._HEADER 41 | header['referer'] = 'http://www.swsindex.com/idx0530.aspx' 42 | params = {'swindexcode': 'SwClass', 'type': 530, 'columnid': 8892} 43 | response = requests.post(self._SW_INDUSTRY_URL, headers=header, params=params) 44 | raw_data = pd.read_html(response.content.decode('gbk'))[0] 45 | 46 | def convert_dt(x: str) -> dt.datetime: 47 | date, time = x.split(' ') 48 | date_parts = [int(it) for it in 
date.split('/')] 49 | time_parts = [int(it) for it in time.split(':')] 50 | ret = dt.datetime(*date_parts, *time_parts) 51 | return ret 52 | 53 | raw_data['DateTime'] = raw_data['起始日期'].map(convert_dt) 54 | raw_data['ID'] = raw_data['股票代码'].map(stock_code2ts_code) 55 | 56 | raw_data.set_index(['DateTime', 'ID'], inplace=True) 57 | self.db_interface.update_df(raw_data[['行业名称']], '申万一级行业') 58 | 59 | @date_utils.dtlize_input_dates 60 | def get_zz_industry(self, date: date_utils.DateType) -> None: 61 | """获取中证4级行业""" 62 | referer_template = 'http://www.csindex.com.cn/zh-CN/downloads/industry-price-earnings-ratio?type=zz1&date=' 63 | date_str = date_utils.date_type2str(date, '-') 64 | header = self._HEADER 65 | header['referer'] = referer_template + date_str 66 | storage = [] 67 | 68 | with tqdm(self._stock_list) as pbar: 69 | for it in self._stock_list: 70 | pbar.set_description(f'正在获取{it}的中证行业数据') 71 | params = {'date': date_str, 'class': 2, 'search': 1, 'csrc_code': it.split('.')[0]} 72 | response = requests.get(self._ZZ_INDUSTRY_URL, headers=header, params=params) 73 | res_table = pd.read_html(response.text)[0] 74 | storage.append(res_table) 75 | pbar.update(1) 76 | data = pd.concat(storage) 77 | data['股票代码'] = data['股票代码'].map(stock_code2ts_code) 78 | data['trade_date'] = date 79 | useful_data = data[['trade_date', '股票代码', '所属中证行业四级名称']] 80 | useful_data.columns = ['DateTime', 'ID', '行业名称'] 81 | useful_data.set_index(['DateTime', 'ID'], inplace=True) 82 | self.db_interface.update_df(useful_data, '中证行业') 83 | 84 | 85 | def stock_code2ts_code(stock_code: Union[int, str]) -> str: 86 | stock_code = int(stock_code) 87 | return f'{stock_code:06}.SH' if stock_code >= 600000 else f'{stock_code:06}.SZ' 88 | 89 | 90 | def ts_code2stock_code(ts_code: str) -> str: 91 | return ts_code.split()[0] 92 | 93 | 94 | def get_current_cffex_contracts(products: Union[str, Sequence[str]]): 95 | """Get IC, IF, IH, IO contracts from CFFEX""" 96 | today = dt.datetime.today() 97 | url = 
f'http://www.cffex.com.cn/sj/jycs/{today.strftime("%Y%m")}/{today.strftime("%d")}/{today.strftime("%Y%m%d")}_1.csv' 98 | rsp = requests.get(url) 99 | rsp.encoding = 'gbk' 100 | data = pd.read_csv(StringIO(rsp.text), skiprows=1) 101 | tickers = data['合约代码'].tolist() 102 | if isinstance(products, str): 103 | products = [products] 104 | ret = [it for it in tickers if it[:2] in products] 105 | return ret 106 | -------------------------------------------------------------------------------- /AShareData/date_utils.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import datetime as dt 3 | import inspect 4 | from functools import wraps 5 | from typing import Callable, List, Optional, Sequence, Tuple, Union 6 | 7 | from singleton_decorator import singleton 8 | 9 | from .config import get_db_interface 10 | from .database_interface import DBInterface 11 | 12 | DateType = Union[str, dt.datetime, dt.date] 13 | 14 | 15 | def date_type2str(date: DateType, delimiter: str = '') -> Optional[str]: 16 | if date is not None: 17 | formatter = delimiter.join(['%Y', '%m', '%d']) 18 | return date.strftime(formatter) if not isinstance(date, str) else date 19 | 20 | 21 | def strlize_input_dates(func): 22 | @wraps(func) 23 | def inner(*args, **kwargs): 24 | signature = inspect.signature(func) 25 | for arg, (arg_name, _) in zip(args, signature.parameters.items()): 26 | if arg in ['start_date', 'end_date', 'dates', 'date', 'report_period']: 27 | kwargs[arg_name] = date_type2str(arg) 28 | else: 29 | kwargs[arg_name] = arg 30 | 31 | for it in kwargs.keys(): 32 | if it in ['start_date', 'end_date', 'dates', 'date', 'report_period']: 33 | kwargs[it] = date_type2str(kwargs[it]) 34 | return func(**kwargs) 35 | 36 | return inner 37 | 38 | 39 | def date_type2datetime(date: Union[DateType, Sequence]) -> Optional[Union[dt.datetime, Sequence[dt.datetime]]]: 40 | if isinstance(date, str): 41 | return _date_type2datetime(date) 42 | elif 
isinstance(date, Sequence): 43 | return [_date_type2datetime(it) for it in date] 44 | else: 45 | return _date_type2datetime(date) 46 | 47 | 48 | def _date_type2datetime(date: DateType) -> Optional[dt.datetime]: 49 | if isinstance(date, dt.datetime): 50 | return date 51 | if isinstance(date, dt.date): 52 | return dt.datetime.combine(date, dt.time()) 53 | if isinstance(date, str) & (date not in ['', 'nan']): 54 | date = date.replace('/', '') 55 | date = date.replace('-', '') 56 | return dt.datetime.strptime(date, '%Y%m%d') 57 | 58 | 59 | def dtlize_input_dates(func): 60 | @wraps(func) 61 | def inner(*args, **kwargs): 62 | signature = inspect.signature(func) 63 | for arg, (arg_name, _) in zip(args, signature.parameters.items()): 64 | if arg in ['start_date', 'end_date', 'dates', 'date', 'report_period']: 65 | kwargs[arg_name] = date_type2datetime(arg) 66 | else: 67 | kwargs[arg_name] = arg 68 | 69 | for it in kwargs.keys(): 70 | if it in ['start_date', 'end_date', 'dates', 'date', 'report_period']: 71 | kwargs[it] = date_type2datetime(kwargs[it]) 72 | return func(**kwargs) 73 | 74 | return inner 75 | 76 | 77 | class TradingCalendarBase(object): 78 | def __init__(self): 79 | self.calendar = None 80 | 81 | @dtlize_input_dates 82 | def is_trading_date(self, date: DateType): 83 | """return if ``date`` is a trading date""" 84 | return date in self.calendar 85 | 86 | @dtlize_input_dates 87 | def select_dates(self, start_date: DateType = None, end_date: DateType = None, 88 | inclusive=(True, True), period: str = None) -> List[dt.datetime]: 89 | """ Get list of all trading days between ``start_date`` and ``end_date`` 90 | 91 | :param start_date: start date 92 | :param end_date: end date 93 | :param inclusive: when select daily trading dates, if the return list include the start date and end date in the parameter 94 | :param period: valid for {'d', 'wb', 'we', 'mb', 'me', 'yb', 'ye'} where 'd', 'w', 'm', 'y' stands for day, week, month and year. 
'b' and 'e' stands for beginning and the end 95 | :return: 96 | """ 97 | if start_date is None: 98 | start_date = self.calendar[0] 99 | if end_date is None: 100 | end_date = dt.datetime.now() 101 | 102 | if period is None or period.lower() == 'd': 103 | dates = self._select_dates(start_date, end_date, lambda pre, curr, next_: True) 104 | if dates and not inclusive[0]: 105 | dates = dates[1:] 106 | if dates and not inclusive[1]: 107 | dates = dates[:-1] 108 | return dates 109 | elif period.lower() == 'wb': 110 | return self.first_day_of_week(start_date, end_date) 111 | elif period.lower() == 'we': 112 | return self.last_day_of_week(start_date, end_date) 113 | elif period.lower() == 'mb': 114 | return self.first_day_of_month(start_date, end_date) 115 | elif period.lower() == 'me': 116 | return self.last_day_of_month(start_date, end_date) 117 | elif period.lower() == 'yb': 118 | return self.first_day_of_year(start_date, end_date) 119 | elif period.lower() == 'ye': 120 | return self.last_day_of_year(start_date, end_date) 121 | 122 | @dtlize_input_dates 123 | def offset(self, date: DateType, days: int) -> dt.datetime: 124 | """offset ``date`` by number of days 125 | 126 | Push days forward if ``days`` is positive and backward if ``days`` is negative. 
127 | If `days = 0` and date is a trading day, date is returned 128 | If `days = 0` and date is not a trading day, the next trading day is returned 129 | """ 130 | loc = bisect.bisect_left(self.calendar, date) 131 | if self.calendar[loc] != date and days > 0: 132 | days = days - 1 133 | return self.calendar[loc + days] 134 | 135 | @dtlize_input_dates 136 | def middle(self, start_date: DateType, end_date: DateType) -> dt.datetime: 137 | """Get middle of the trading period[``start_date``, ``end_date``]""" 138 | return self.calendar[int((self.calendar.index(start_date) + self.calendar.index(end_date)) / 2.0)] 139 | 140 | @dtlize_input_dates 141 | def days_count(self, start_date: DateType, end_date: DateType) -> int: 142 | """Count number of trading days during [``start_date``, ``end_date``] 143 | Note: ``end_date`` need to be a trading day 144 | """ 145 | i = bisect.bisect_left(self.calendar, start_date) 146 | if self.calendar[i] != start_date: 147 | i = i - 1 148 | j = self.calendar.index(end_date) 149 | 150 | return j - i 151 | 152 | def today(self) -> dt.datetime: 153 | t = dt.datetime.combine(dt.date.today(), dt.time()) 154 | if not self.is_trading_date(t): 155 | t = self.offset(t, -1) 156 | return t 157 | 158 | def yesterday(self) -> dt.datetime: 159 | return self.offset(dt.date.today(), -1) 160 | 161 | @dtlize_input_dates 162 | def first_day_of_week(self, start_date: DateType = None, end_date: DateType = None) -> List[dt.datetime]: 163 | """Get first trading day of weeks between [``start_date``, ``end_date``]""" 164 | return self._select_dates(start_date, end_date, 165 | lambda pre, curr, next_: pre.isocalendar()[1] != curr.isocalendar()[1]) 166 | 167 | @dtlize_input_dates 168 | def last_day_of_week(self, start_date: DateType = None, end_date: DateType = None) -> List[dt.datetime]: 169 | """Get last trading day of weeks between [``start_date``, ``end_date``]""" 170 | return self._select_dates(start_date, end_date, 171 | lambda pre, curr, next_: 
curr.isocalendar()[1] != next_.isocalendar()[1]) 172 | 173 | @dtlize_input_dates 174 | def first_day_of_month(self, start_date: DateType = None, end_date: DateType = None) -> List[dt.datetime]: 175 | """Get first trading day of months between [``start_date``, ``end_date``]""" 176 | return self._select_dates(start_date, end_date, lambda pre, curr, next_: pre.month != curr.month) 177 | 178 | @dtlize_input_dates 179 | def last_day_of_month(self, start_date: DateType = None, end_date: DateType = None) -> List[dt.datetime]: 180 | """Get last trading day of months between [``start_date``, ``end_date``]""" 181 | return self._select_dates(start_date, end_date, 182 | lambda pre, curr, next_: curr.month != next_.month) 183 | 184 | def month_begin(self, year: int, month: int): 185 | """Get the first trading date on or after the 1st of month (year, month)""" 186 | anchor = dt.datetime(year, month, 1) 187 | i = bisect.bisect_left(self.calendar, anchor) 188 | return self.calendar[i] 189 | 190 | def pre_month_end(self, year: int, month: int): 191 | """Get the last trading date strictly before the 1st of month (year, month), normally the previous month's last trading date""" 192 | anchor = dt.datetime(year, month, 1) 193 | i = bisect.bisect_left(self.calendar, anchor) 194 | return self.calendar[i - 1] 195 | 196 | def month_end(self, year: int, month: int): 197 | """Get the last trading date of month (year, month)""" 198 | if month == 12:  # roll over to January of the next year 199 | year += 1 200 | month = 1 201 | else: 202 | month += 1 203 | anchor = dt.datetime(year, month, 1)  # 1st of the following month 204 | i = bisect.bisect_left(self.calendar, anchor) 205 | return self.calendar[i - 1]  # last trading date before the following month starts 206 | 207 | @dtlize_input_dates 208 | def first_day_of_year(self, start_date: DateType = None, end_date: DateType = None) -> List[dt.datetime]: 209 | """Get first trading day of the year between [``start_date``, ``end_date``]""" 210 | return self._select_dates(start_date, end_date, lambda pre, curr, next_: pre.year != curr.year) 211 | 212 | @dtlize_input_dates 213 | def last_day_of_year(self, start_date: DateType = None, end_date: DateType = None) ->
List[dt.datetime]: 214 | """Get last trading day of the year between [``start_date``, ``end_date``]""" 215 | return self._select_dates(start_date, end_date, lambda pre, curr, next_: curr.year != next_.year) 216 | 217 | def _select_dates(self, start_date: dt.datetime = None, end_date: dt.datetime = None, 218 | func: Callable[[dt.datetime, dt.datetime, dt.datetime], bool] = None) -> List[dt.datetime]: 219 | i = bisect.bisect_left(self.calendar, start_date) 220 | j = bisect.bisect_right(self.calendar, end_date) 221 | if self.calendar[j] == end_date: 222 | j = j + 1 223 | 224 | if func: 225 | storage = [] 226 | for k in range(i, j): 227 | if func(self.calendar[k - 1], self.calendar[k], self.calendar[k + 1]): 228 | storage.append(self.calendar[k]) 229 | return storage 230 | else: 231 | return self.calendar[i:j] 232 | 233 | def split_to_chunks(self, start_date: DateType, end_date: DateType, chunk_size: int) \ 234 | -> List[Tuple[dt.datetime, dt.datetime]]: 235 | all_dates = self.select_dates(start_date, end_date) 236 | res = [] 237 | for i in range(0, len(all_dates), chunk_size): 238 | tmp = all_dates[i:i + chunk_size] 239 | res.append((tmp[0], tmp[-1])) 240 | return res 241 | 242 | 243 | @singleton 244 | class SHSZTradingCalendar(TradingCalendarBase): 245 | """A Share Trading Calendar""" 246 | 247 | def __init__(self, db_interface: DBInterface = None): 248 | super().__init__() 249 | self.db_interface = db_interface if db_interface else get_db_interface() 250 | calendar_df = self.db_interface.read_table('交易日历') 251 | self.calendar = sorted(calendar_df['交易日期'].dt.to_pydatetime().tolist()) 252 | 253 | 254 | @singleton 255 | class HKTradingCalendar(TradingCalendarBase): 256 | """A Share Trading Calendar""" 257 | 258 | def __init__(self, db_interface: DBInterface = None): 259 | super().__init__() 260 | self.db_interface = db_interface if db_interface else get_db_interface() 261 | calendar_df = self.db_interface.read_table('港股交易日历') 262 | self.calendar = 
sorted(calendar_df['交易日期'].dt.to_pydatetime().tolist()) 263 | 264 | 265 | class ReportingDate(object): 266 | @staticmethod 267 | @dtlize_input_dates 268 | def yoy_date(date: DateType) -> dt.datetime: 269 | """ 270 | 返回去年同期的报告期 271 | 272 | :param date: 报告期 273 | :return: 去年同期的报告期 274 | """ 275 | return dt.datetime(date.year - 1, date.month, date.day) 276 | 277 | @staticmethod 278 | @dtlize_input_dates 279 | def yearly_offset(date: DateType, delta: int = 1) -> dt.datetime: 280 | """ 281 | 返回``delta``年后的年报报告期 282 | 283 | :param date: 报告期 284 | :param delta: 时长(年) 285 | :return: 前``delta``个年报的报告期 286 | """ 287 | return dt.datetime(date.year + delta, 12, 31) 288 | 289 | @staticmethod 290 | @dtlize_input_dates 291 | def quarterly_offset(date: DateType, delta: int = 1) -> dt.datetime: 292 | """ 293 | 返回 ``delta`` 个季度后的报告期 294 | 295 | :param date: 报告期 296 | :param delta: 时长(季度) 297 | :return: ``delta`` 个季度后的报告期 298 | """ 299 | rep = date.year * 12 + date.month + delta * 3 - 1 300 | month = rep % 12 + 1 301 | day = 31 if month == 3 or month == 12 else 30 302 | return dt.datetime(rep // 12, month, day) 303 | 304 | @classmethod 305 | def offset(cls, report_date, offset_str: str): 306 | """报告期偏移 307 | 308 | :param report_date: 基准报告期 309 | :param offset_str: 偏移量:如``q3``, ``y1`` 310 | :return: 偏移后的报告期 311 | """ 312 | delta = -int(offset_str[1:]) 313 | if offset_str[0] == 'q': 314 | return cls.quarterly_offset(report_date, delta) 315 | elif offset_str[0] == 'y': 316 | return cls.yearly_offset(report_date, delta) 317 | else: 318 | raise ValueError(f'Illegal offset_str: {offset_str}') 319 | 320 | @staticmethod 321 | def get_latest_report_date(date: Union[dt.date, dt.datetime] = None) -> List[dt.datetime]: 322 | """ 323 | 获取最新报告期 324 | 325 | 上市公司季报披露时间: 326 | 一季报:4月1日——4月30日。 327 | 二季报(中报):7月1日——8月30日。 328 | 三季报:10月1日——10月31日。 329 | 四季报(年报):1月1日——4月30日。 330 | 331 | :return: 最新财报的报告期 332 | """ 333 | if date is None: 334 | date = dt.date.today() 335 | year = date.year 336 | if 
date.month < 4: 337 | return [dt.datetime(year - 1, 12, 31)] 338 | elif date.month < 5: 339 | return [dt.datetime(year, 3, 30), dt.datetime(year - 1, 12, 31)] 340 | elif date.month < 9: 341 | return [dt.datetime(year, 6, 30)] 342 | else: 343 | return [dt.datetime(year, 9, 30)] 344 | 345 | @staticmethod 346 | @dtlize_input_dates 347 | def get_report_date(year: int, n: int = 1) -> dt.datetime: 348 | """ 349 | 返回 ``year`` 年的第 ``n`` 个报告期 350 | """ 351 | month = n * 4 352 | day = 31 if month == 3 or month == 12 else 30 353 | return dt.datetime(year, month, day) 354 | -------------------------------------------------------------------------------- /AShareData/empirical.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import empyrical 4 | import pandas as pd 5 | 6 | from AShareData.date_utils import SHSZTradingCalendar 7 | 8 | DAYS_IN_YEAR = 365 9 | 10 | 11 | def annual_return(prices: pd.Series): 12 | prices = prices.dropna() 13 | if prices.shape[0] <= 1: 14 | return 0 15 | dates = prices.index.get_level_values('DateTime') 16 | days = (dates[-1].date() - dates[0].date()).days 17 | pct_change = prices[-1] / prices[0] 18 | years = days / DAYS_IN_YEAR 19 | return pow(pct_change, 1 / years) - 1 20 | 21 | 22 | def annual_volatility(prices: pd.Series): 23 | dates = prices.index.get_level_values('DateTime') 24 | cal = SHSZTradingCalendar() 25 | date_index = cal.select_dates(start_date=dates[0], end_date=dates[-1]) 26 | prices = prices.droplevel('ID').reindex(date_index).interpolate() 27 | return prices.pct_change().std() * math.sqrt(DAYS_IN_YEAR) 28 | 29 | 30 | def sharpe_ratio(prices: pd.Series): 31 | return annual_return(prices) / annual_volatility(prices) 32 | 33 | 34 | def bond_fund_annual_return(prices: pd.Series, threshold: float = 0.005): 35 | prices = prices.dropna() 36 | if prices.shape[0] <= 1: 37 | return 0 38 | dates = prices.index.get_level_values('DateTime') 39 | days = (dates[-1].date() - 
dates[0].date()).days 40 | pct = prices.pct_change() 41 | pct_change = (1 + pct.loc[pct < threshold]).prod() 42 | years = days / DAYS_IN_YEAR 43 | return pow(pct_change, 1 / years) - 1 44 | 45 | 46 | def bond_fund_annual_volatility(prices: pd.Series, threshold: float = 0.005): 47 | dates = prices.index.get_level_values('DateTime') 48 | cal = SHSZTradingCalendar() 49 | date_index = cal.select_dates(start_date=dates[0], end_date=dates[-1]) 50 | prices = prices.droplevel('ID').reindex(date_index).interpolate() 51 | pct = prices.pct_change() 52 | std = pct.loc[pct < threshold].std() 53 | return std * math.sqrt(DAYS_IN_YEAR) 54 | 55 | 56 | def bond_fund_sharpe_ratio(prices: pd.Series): 57 | if prices.shape[0] < 20: 58 | return 0 59 | return bond_fund_annual_return(prices) / bond_fund_annual_volatility(prices) 60 | 61 | 62 | def max_drawdown(prices: pd.Series) -> float: 63 | return empyrical.max_drawdown(prices.pct_change()) 64 | -------------------------------------------------------------------------------- /AShareData/factor_compositor/__init__.py: -------------------------------------------------------------------------------- 1 | from .factor_compositor import ConstLimitStockFactorCompositor, FactorCompositor, FundAdjFactorCompositor, \ 2 | IndexCompositor, IndexUpdater, MarketSummaryCompositor, NegativeBookEquityListingCompositor 3 | from .factor_portfolio import FactorPortfolio, FactorPortfolioPolicy 4 | -------------------------------------------------------------------------------- /AShareData/factor_compositor/factor_compositor.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from tqdm import tqdm 6 | 7 | from .. 
import utils 8 | from ..ashare_data_reader import AShareDataReader 9 | from ..data_source.data_source import DataSource 10 | from ..database_interface import DBInterface 11 | from ..factor import CompactFactor 12 | from ..tickers import FundTickers, StockTickerSelector 13 | 14 | 15 | class FactorCompositor(DataSource): 16 | def __init__(self, db_interface: DBInterface = None): 17 | """ 18 | Factor Compositor 19 | 20 | This class composite factors from raw market/financial info 21 | 22 | :param db_interface: DBInterface 23 | """ 24 | super().__init__(db_interface) 25 | self.data_reader = AShareDataReader(db_interface) 26 | 27 | def update(self): 28 | """更新数据""" 29 | raise NotImplementedError() 30 | 31 | 32 | class IndexCompositor(FactorCompositor): 33 | def __init__(self, index_composition_policy: utils.StockIndexCompositionPolicy, db_interface: DBInterface = None): 34 | """自建指数收益计算器""" 35 | super().__init__(db_interface) 36 | self.table_name = '自合成指数' 37 | self.policy = index_composition_policy 38 | self.weight = None 39 | if index_composition_policy.unit_base: 40 | self.weight = (CompactFactor(index_composition_policy.unit_base, self.db_interface) 41 | * self.data_reader.stock_close).weight() 42 | self.stock_ticker_selector = StockTickerSelector(self.policy.stock_selection_policy, self.db_interface) 43 | 44 | def update(self): 45 | """ 更新市场收益率 """ 46 | price_table = '股票日行情' 47 | 48 | start_date = self.db_interface.get_latest_timestamp(self.table_name, self.policy.start_date, 49 | column_condition=('ID', self.policy.ticker)) 50 | end_date = self.db_interface.get_latest_timestamp(price_table) 51 | dates = self.calendar.select_dates(start_date, end_date, inclusive=(False, True)) 52 | 53 | with tqdm(dates) as pbar: 54 | for date in dates: 55 | pbar.set_description(f'{date}') 56 | ids = self.stock_ticker_selector.ticker(date) 57 | 58 | if ids: 59 | t_dates = [(self.calendar.offset(date, -1)), date] 60 | if self.weight: 61 | rets = (self.data_reader.forward_return * 
self.weight).sum().get_data(dates=t_dates, ids=ids) 62 | else: 63 | rets = self.data_reader.stock_return.mean(along='DateTime').get_data(dates=t_dates, ids=ids) 64 | index = pd.MultiIndex.from_tuples([(date, self.policy.ticker)], names=['DateTime', 'ID']) 65 | ret = pd.Series(rets.values[0], index=index, name='收益率') 66 | 67 | self.db_interface.update_df(ret, self.table_name) 68 | pbar.update() 69 | 70 | 71 | class IndexUpdater(object): 72 | def __init__(self, config_loc=None, db_interface: DBInterface = None): 73 | """ 指数更新器 74 | 75 | :param config_loc: 配置文件路径. 默认指数位于 ``./data/自编指数配置.xlsx``. 自定义配置可参考此文件 76 | :param db_interface: DBInterface 77 | """ 78 | super().__init__() 79 | self.db_interface = db_interface 80 | records = utils.load_excel('自编指数配置.xlsx', config_loc) 81 | self.policies = {} 82 | for record in records: 83 | self.policies[record['name']] = utils.StockIndexCompositionPolicy.from_dict(record) 84 | 85 | def update(self): 86 | with tqdm(self.policies) as pbar: 87 | for policy in self.policies.values(): 88 | pbar.set_description(f'更新{policy.name}') 89 | IndexCompositor(policy, db_interface=self.db_interface).update() 90 | pbar.update() 91 | 92 | 93 | class ConstLimitStockFactorCompositor(FactorCompositor): 94 | def __init__(self, db_interface: DBInterface = None): 95 | """ 96 | 标识一字涨跌停板 97 | 98 | 判断方法: 取最高价和最低价一致 且 当日未停牌 99 | - 若价格高于昨前复权价, 则视为涨停一字板 100 | - 若价格低于昨前复权价, 则视为跌停一字板 101 | 102 | :param db_interface: DBInterface 103 | """ 104 | super().__init__(db_interface) 105 | self.table_name = '一字涨跌停' 106 | stock_selection_policy = utils.StockSelectionPolicy(select_pause=True) 107 | self.paused_stock_selector = StockTickerSelector(stock_selection_policy, db_interface) 108 | 109 | def update(self): 110 | price_table_name = '股票日行情' 111 | 112 | start_date = self.db_interface.get_latest_timestamp(self.table_name, dt.date(1999, 5, 4)) 113 | end_date = self.db_interface.get_latest_timestamp(price_table_name, dt.date(1990, 12, 10)) 114 | 115 | pre_data = 
self.db_interface.read_table(price_table_name, ['最高价', '最低价'], dates=start_date) 116 | dates = self.calendar.select_dates(start_date, end_date) 117 | pre_date = dates[0] 118 | dates = dates[1:] 119 | 120 | with tqdm(dates) as pbar: 121 | pbar.set_description('更新股票一字板') 122 | for date in dates: 123 | data = self.db_interface.read_table(price_table_name, ['最高价', '最低价'], dates=date) 124 | no_price_move_tickers = data.loc[data['最高价'] == data['最低价']].index.get_level_values('ID').tolist() 125 | if no_price_move_tickers: 126 | target_stocks = list(set(no_price_move_tickers) - set(self.paused_stock_selector.ticker(date))) 127 | if target_stocks: 128 | adj_factor = self.data_reader.adj_factor.get_data(start_date=pre_date, end_date=date, 129 | ids=target_stocks) 130 | price = data.loc[(slice(None), target_stocks), '最高价'] * adj_factor.loc[(date, target_stocks)] 131 | pre_price = pre_data.loc[(slice(None), target_stocks), '最高价'] * adj_factor.loc[ 132 | (pre_date, target_stocks)] 133 | diff_price = pd.concat([pre_price, price]).unstack().diff().iloc[1, :].dropna() 134 | diff_price = diff_price.loc[diff_price != 0] 135 | if diff_price.shape[0] > 1: 136 | ret = (diff_price > 0) * 2 - 1 137 | ret = ret.to_frame().reset_index() 138 | ret['DateTime'] = date 139 | ret.set_index(['DateTime', 'ID'], inplace=True) 140 | ret.columns = ['涨跌停'] 141 | self.db_interface.insert_df(ret, self.table_name) 142 | pre_data = data 143 | pre_date = date 144 | pbar.update() 145 | 146 | 147 | class FundAdjFactorCompositor(FactorCompositor): 148 | def __init__(self, db_interface: DBInterface = None): 149 | """ 150 | 计算基金的复权因子 151 | 152 | :param db_interface: DBInterface 153 | """ 154 | super().__init__(db_interface) 155 | self.fund_tickers = FundTickers(self.db_interface) 156 | self.target_table_name = '复权因子' 157 | self.div_table_name = '公募基金分红' 158 | self.cache_entry = '基金复权因子' 159 | 160 | def compute_adj_factor(self, ticker): 161 | list_date = self.fund_tickers.get_list_date(ticker) 162 | index = 
pd.MultiIndex.from_tuples([(list_date, ticker)], names=('DateTime', 'ID')) 163 | list_date_adj_factor = pd.Series(1, index=index, name=self.target_table_name)  # factor is 1 on the listing date 164 | self.db_interface.update_df(list_date_adj_factor, self.target_table_name) 165 | 166 | div_info = self.db_interface.read_table(self.div_table_name, ids=ticker) 167 | if div_info.empty:  # no dividends: only the listing-date factor is stored 168 | return 169 | div_dates = div_info.index.get_level_values('DateTime').tolist() 170 | pre_date = [self.calendar.offset(it, -1) for it in div_dates]  # trading day before each dividend date 171 | 172 | if ticker.endswith('.OF'):  # .OF funds use unit NAV, others use the exchange close price 173 | price_table_name, col_name = '场外基金净值', '单位净值' 174 | else: 175 | price_table_name, col_name = '场内基金日行情', '收盘价' 176 | pre_price_data = self.db_interface.read_table(price_table_name, col_name, dates=pre_date, ids=ticker) 177 | if pre_price_data.shape[0] != div_info.shape[0]:  # some pre-dividend prices missing: fall back to the previous available price 178 | price_data = self.db_interface.read_table(price_table_name, col_name, ids=ticker) 179 | tmp = pd.concat([price_data.shift(1), div_info], axis=1)  # shift(1): previous row's price lined up with each dividend date 180 | pre_price_data = tmp.dropna().iloc[:, 0] 181 | else: 182 | pre_price_data.index = div_info.index  # relabel to the dividend dates so the division below aligns 183 | adj_factor = (pre_price_data / (pre_price_data - div_info)).cumprod()  # NOTE(review): assumes div_info is the per-unit cash dividend — confirm 184 | adj_factor.name = self.target_table_name 185 | self.db_interface.update_df(adj_factor, self.target_table_name) 186 | 187 | def update(self): 188 | update_time = self.db_interface.get_cache_date(self.cache_entry) 189 | if update_time is None:  # no cache entry yet: process every fund 190 | tickers = self.fund_tickers.all_ticker() 191 | else: 192 | next_day = self.calendar.offset(update_time, 1) 193 | fund_div_table = self.db_interface.read_table(self.div_table_name, start_date=next_day) 194 | tickers = fund_div_table.index.get_level_values('ID').unique().tolist()  # only funds with dividends after the cached date 195 | with tqdm(tickers) as pbar: 196 | for ticker in tickers: 197 | pbar.set_description(f'更新 {ticker} 的复权因子') 198 | self.compute_adj_factor(ticker) 199 | pbar.update() 200 | timestamp = self.db_interface.get_latest_timestamp(self.div_table_name) 201 | self.db_interface.update_cache_date(self.cache_entry,
timestamp) 202 | 203 | 204 | class NegativeBookEquityListingCompositor(FactorCompositor): 205 | def __init__(self, db_interface: DBInterface = None): 206 | """Flag stocks with negative book equity 207 | 208 | :param db_interface: DBInterface 209 | """ 210 | super().__init__(db_interface) 211 | self.table_name = '负净资产股票' 212 | 213 | def update(self): 214 | data = self.db_interface.read_table('合并资产负债表', '股东权益合计(不含少数股东权益)') 215 | storage = [] 216 | for _, group in data.groupby('ID'): 217 | if any(group < 0): 218 | tmp = group.groupby('DateTime').tail(1) < 0  # last record per DateTime; True where equity is negative 219 | t = tmp.iloc[tmp.argmax():].droplevel('报告期')  # start from the first negative record 220 | t2 = t[np.concatenate(([True], t.values[:-1] != t.values[1:]))]  # keep only the points where the flag changes 221 | if any(t2): 222 | storage.append(t2) 223 | 224 | ret = pd.concat(storage)  # NOTE(review): pd.concat raises ValueError if no stock ever had negative equity 225 | ret.name = '负净资产股票' 226 | self.db_interface.purge_table(self.table_name) 227 | self.db_interface.insert_df(ret, self.table_name) 228 | 229 | 230 | class MarketSummaryCompositor(FactorCompositor): 231 | def __init__(self, policy_name: str = '全市场', db_interface: DBInterface = None): 232 | """Market summary 233 | :param policy_name: name of the policy in 自编指数配置.xlsx that defines the stock universe 234 | :param db_interface: DBInterface 235 | """ 236 | super().__init__(db_interface) 237 | self.table_name = '市场汇总' 238 | self.init_date = dt.datetime(2001, 1, 1) 239 | records = utils.load_excel('自编指数配置.xlsx') 240 | policies = {} 241 | for record in records: 242 | policies[record['name']] = utils.StockIndexCompositionPolicy.from_dict(record) 243 | stock_selection_policy = policies[policy_name] 244 | self.ticker = stock_selection_policy.ticker 245 | self.stock_ticker_selector = StockTickerSelector(stock_selection_policy.stock_selection_policy, 246 | self.db_interface) 247 | self.total_share = CompactFactor('A股总股本', self.db_interface) 248 | self.float_share = CompactFactor('A股流通股本', self.db_interface) 249 | self.free_float_share = CompactFactor('自由流通股本', self.db_interface) 250 | 251 | def update(self): 252 | price_table_name = '股票日行情' 253 | 254 | start_date = self.db_interface.get_latest_timestamp(self.table_name, self.init_date) 255 | end_date
= self.db_interface.get_latest_timestamp(price_table_name, dt.date(1990, 12, 10)) 256 | dates = self.calendar.select_dates(start_date, end_date, (False, True)) 257 | with tqdm(dates) as pbar: 258 | for date in dates: 259 | pbar.set_description(f'更新市场汇总: {date}') 260 | tickers = self.stock_ticker_selector.ticker(date) 261 | daily_info = self.db_interface.read_table(price_table_name, ['收盘价', '成交额'], ids=tickers, dates=date) 262 | total_share = self.total_share.get_data(dates=date, ids=tickers) 263 | float_share = self.float_share.get_data(dates=date, ids=tickers) 264 | free_float_share = self.free_float_share.get_data(dates=date, ids=tickers) 265 | 266 | vol = daily_info['成交额'].sum() 267 | float_val = daily_info['收盘价'].dot(float_share) 268 | free_float_val = daily_info['收盘价'].dot(free_float_share) 269 | total_df = daily_info['收盘价'] * total_share 270 | total_val = total_df.sum() 271 | 272 | earning_ttm = self.data_reader.earning_ttm.get_data(ids=tickers, dates=date) 273 | earning_data = pd.concat([total_df, earning_ttm], axis=1).dropna() 274 | pe_val = earning_data.iloc[:, 0].sum() / earning_data.iloc[:, 1].sum() 275 | 276 | book_val = self.data_reader.book_val.get_data(ids=tickers, dates=date) 277 | book_data = pd.concat([total_df, book_val], axis=1).dropna() 278 | pb_val = book_data.iloc[:, 0].sum() / book_data.iloc[:, 1].sum() 279 | 280 | sum_dict = {'DateTime': date, 'ID': self.ticker, '成交额': vol, '总市值': total_val, '流通市值': float_val, 281 | '自由流通市值': free_float_val, '市盈率TTM': pe_val, '市净率': pb_val} 282 | 283 | df = pd.DataFrame(sum_dict, index=['DateTime']).set_index(['DateTime', 'ID']) 284 | self.db_interface.insert_df(df, self.table_name) 285 | 286 | pbar.update() 287 | -------------------------------------------------------------------------------- /AShareData/factor_compositor/factor_portfolio.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | from dataclasses import dataclass 3 | from typing 
import Sequence

import pandas as pd
from tqdm import tqdm

from .factor_compositor import FactorCompositor
from .. import utils
from ..database_interface import DBInterface
from ..factor import Factor, FactorBase, IndustryFactor
from ..tickers import StockTickerSelector


@dataclass
class FactorPortfolioPolicy:
    """Settings for computing returns of factor-sorted portfolios (因子收益率).

    :param name: name
    :param bins: numbers of groups; one set of sorted portfolios is built for each entry
    :param weight: weighting factor, default ``None`` (equal weight)
    :param stock_selection_policy: stock universe
    :param factor: factor to sort on
    :param factor_need_shift: whether the factor must be lagged by one period to avoid look-ahead bias
    :param industry: industry classification factor, default ``None`` (no industry neutralization)
    :param start_date: start date

    NOTE(review): ``FactorPortfolio`` reads ``weight.name`` and ``industry.name`` unconditionally,
    so both must be set when the policy is used there.
    """
    name: str = None
    bins: Sequence[int] = None
    weight: Factor = None
    stock_selection_policy: utils.StockSelectionPolicy = None
    factor: FactorBase = None
    factor_need_shift: bool = False
    industry: IndustryFactor = None
    start_date: dt.datetime = None


class FactorPortfolio(FactorCompositor):
    def __init__(self, factor_portfolio_policy: FactorPortfolioPolicy, db_interface: DBInterface = None):
        """Store daily returns of portfolios sorted on ``factor_portfolio_policy.factor`` in ``因子分组收益率``.

        :param factor_portfolio_policy: portfolio settings; ``factor``, ``weight`` and ``industry`` must be set
        :param db_interface: DBInterface
        """
        super().__init__(db_interface)
        self.policy = factor_portfolio_policy
        # use the interface held by the compositor (as the other compositors do), not the possibly-None argument
        self.stock_ticker_selector = StockTickerSelector(factor_portfolio_policy.stock_selection_policy,
                                                         self.db_interface)
        self.factor_name = self.policy.factor.name
        self.ret_name = self.data_reader.stock_return.name
        self.industry_category = self.policy.industry.name
        self.cap_name = self.policy.weight.name

    def update(self):
        """Insert group returns for every trading day after the latest stored one.

        IDs look like ``{factor}_{I|N}{N|W}_G{k}inG{bins}``: the first letter says whether stocks were sorted
        within industries (``I``) or across the whole universe (``N``), the second whether the group return is
        equal weighted (``N``) or weighted by ``policy.weight`` (``W``).
        """
        table_name = '因子分组收益率'
        # the 5-bin, non-industry, equal-weighted series marks where the previous update stopped
        identifying_ticker = f'{self.factor_name}_NN_G1inG5'
        start_date = self.db_interface.get_latest_timestamp(table_name, self.policy.start_date,
                                                            column_condition=('ID', identifying_ticker))
        end_date = self.db_interface.get_latest_timestamp('股票日行情')
        dates = self.calendar.select_dates(start_date, end_date, inclusive=(False, True))

        with tqdm(dates) as pbar:
            for date in dates:
                pbar.set_description(f'更新{date}的{self.factor_name}的因子收益率')
                data = self._load_cross_section(date)

                storage = []
                # industry neutral: bins are formed inside each industry
                for num_bin in self.policy.bins:
                    industry_data = data.copy()
                    group = industry_data.groupby(self.industry_category).apply(
                        lambda x, n=num_bin: self._split_into_bins(x, n))
                    group.name = 'group'
                    industry_data = pd.concat([industry_data, group.droplevel(self.industry_category)], axis=1)
                    storage.extend(self._group_returns(industry_data, date, 'I'))

                # without industry info: bins are formed over the whole cross-section
                for num_bin in self.policy.bins:
                    non_industry_info = data.copy()
                    non_industry_info['group'] = self._split_into_bins(non_industry_info, num_bin)
                    storage.extend(self._group_returns(non_industry_info, date, 'N'))

                full_data = pd.concat(storage)
                full_data.index.names = ('DateTime', 'ID')
                full_data.name = '收益率'
                self.db_interface.insert_df(full_data, table_name)
                pbar.update()

    def _load_cross_section(self, date: dt.datetime) -> pd.DataFrame:
        """Return, factor value, industry and weight of each selected stock on ``date``; rows with NaN dropped."""
        pre_date = self.calendar.offset(date, -1)
        ids = self.stock_ticker_selector.ticker(date)

        pct_return = self.data_reader.stock_return.get_data(start_date=pre_date, end_date=date, ids=ids)
        # optionally use the previous day's factor value to avoid look-ahead bias
        factor_date = pre_date if self.policy.factor_need_shift else date
        factor_data = self.policy.factor.get_data(dates=factor_date)
        industry = self.policy.industry.get_data(dates=date)
        cap = self.policy.weight.get_data(dates=pre_date)

        storage = [pct_return.droplevel('DateTime'), factor_data.droplevel('DateTime'),
                   industry.droplevel('DateTime'), cap.droplevel('DateTime')]
        return pd.concat(storage, axis=1).dropna()

    def _split_into_bins(self, x: pd.DataFrame, num_bin: int) -> pd.Series:
        """Label each row of ``x`` with its ``num_bin``-quantile group of the factor value."""
        labels = [f'G{i + 1}inG{num_bin}' for i in range(num_bin)]
        return pd.qcut(x.loc[:, self.factor_name], q=num_bin, labels=labels, duplicates='drop')

    def _cap_weighted_mean_return(self, x: pd.DataFrame) -> float:
        """Return of the rows of ``x`` weighted by the weight column."""
        return x.loc[:, self.ret_name].dot(x.loc[:, self.cap_name] / x.loc[:, self.cap_name].sum())

    def _group_returns(self, data: pd.DataFrame, date: dt.datetime, industry_marker: str) -> list:
        """Equal-weighted ('N') then weighted ('W') returns of each ``data['group']``, labelled for storage."""
        grouped = data.groupby('group')
        unweighted = grouped[self.ret_name].mean()
        weighted = grouped.apply(self._cap_weighted_mean_return)
        return [self._label_group_returns(unweighted, date, industry_marker, 'N'),
                self._label_group_returns(weighted, date, industry_marker, 'W')]

    def _label_group_returns(self, res: pd.Series, date: dt.datetime, industry_marker: str,
                             weight_marker: str) -> pd.Series:
        """Re-index per-group returns by (``date``, ``{factor}_{industry_marker}{weight_marker}_{group}``)."""
        index_name = [f'{self.factor_name}_{industry_marker}{weight_marker}_{it}' for it in res.index]
        res.index = pd.MultiIndex.from_product([[date], index_name])
        return res


# -------------------- /AShareData/model/__init__.py --------------------
from .capm import CapitalAssetPricingModel as CAPM
from .fama_french_3_factor_model import FamaFrench3FactorModel, SMBandHMLCompositor
from .fama_french_carhart_4_factor_model import FamaFrenchCarhart4FactorModel, UMDCompositor


# -------------------- /AShareData/model/capm.py --------------------
from .model import FinancialModel


class CapitalAssetPricingModel(FinancialModel):
    def __init__(self):
        """Capital Asset Pricing Model: no factors besides the implied excess market return."""
        super().__init__('Capital Asset Pricing Model', [])


# -------------------- /AShareData/model/fama_french_3_factor_model.py --------------------
import datetime as dt

import pandas as pd

from .model import FinancialModel, ModelFactorCompositor
from ..database_interface import DBInterface
from ..tickers import StockTickerSelector
from ..utils import StockSelectionPolicy


class FamaFrench3FactorModel(FinancialModel):
    def __init__(self):
        """Fama French 3 factor model(1992)

        The universe excludes negative-book-value, ST and paused stocks, and stocks inside their first 244
        (presumably trading) days after listing. Size is split 50/50, book-to-market 30/40/30.
        """
        super().__init__('Fama French 3 factor model', ['FF3_SMB', 'FF3_HML'])

        self.stock_selection_policy = StockSelectionPolicy(ignore_negative_book_value_stock=True,
                                                           ignore_st=True, ignore_pause=True,
                                                           ignore_new_stock_period=244)
        self.hml_threshold = [0, 0.3, 0.7, 1]
        self.smb_threshold = [0, 0.5, 1]


class SMBandHMLCompositor(ModelFactorCompositor):
    def __init__(self, model: FamaFrench3FactorModel = None, db_interface: DBInterface = None):
        """Compute SMB and HML in Fama French 3 factor model"""
        model = model if model else FamaFrench3FactorModel()
        super().__init__(model, db_interface)

        self.start_date = dt.datetime(2007, 1, 4)
        self.ticker_selector = StockTickerSelector(model.stock_selection_policy, self.db_interface)

        self.cap = self.data_reader.stock_free_floating_market_cap
        self.bm = self.data_reader.bm
        self.returns = self.data_reader.stock_return

    def compute_factor_return(self, balance_date: dt.datetime, pre_date: dt.datetime, date: dt.datetime,
                              rebalance_marker: str, period_marker: str) -> pd.Series:
        """Compute SMB and HML returns between ``pre_date`` and ``date``.

        Stocks selected on ``date`` are sorted on market cap and book-to-market as of ``balance_date`` into
        2 x 3 cap-weighted portfolios. SMB = mean(small portfolios) - mean(big portfolios);
        HML = mean(high B/M portfolios) - mean(low B/M portfolios).

        :return: ``收益率`` series indexed by (DateTime, ``{factor}_{rebalance_marker}{period_marker}``)
        """
        # data
        tickers = self.ticker_selector.ticker(date)
        bm = self.bm.get_data(ids=tickers, dates=balance_date).droplevel('DateTime')
        cap = self.cap.get_data(ids=tickers, dates=balance_date).droplevel('DateTime')
        returns = self.returns.get_data(ids=tickers, dates=[pre_date, date]).droplevel('DateTime')
        df = pd.concat([returns, bm, cap], axis=1).dropna()
        ret_name, cap_name = returns.name, cap.name

        # grouping
        df['G_SMB'] = pd.qcut(df[self.cap.name], self.model.smb_threshold, labels=['small', 'big'])
        df['G_HML'] = pd.qcut(df[self.bm.name], self.model.hml_threshold, labels=['low', 'mid', 'high'])
        rets = df.groupby(['G_SMB', 'G_HML']).apply(lambda x: x[ret_name].dot(x[cap_name]) / x[cap_name].sum())
        by_size = rets.groupby('G_SMB').mean()
        smb = by_size.loc['small'] - by_size.loc['big']
        by_value = rets.groupby('G_HML').mean()
        hml = by_value.loc['high'] - by_value.loc['low']

        # formatting result
        factor_names = [f'{factor_name}_{rebalance_marker}{period_marker}' for factor_name in self.factor_names]
        # period 'M' results are stamped with ``pre_date``, period 'D' results with ``date``
        index_date = pre_date if period_marker == 'M' else date
        index = pd.MultiIndex.from_product([[index_date], factor_names], names=('DateTime', 'ID'))
        return pd.Series([smb, hml], index=index, name='收益率')


# -------------------- /AShareData/model/fama_french_carhart_4_factor_model.py --------------------
import datetime as dt

import pandas as pd

from .model import FinancialModel, ModelFactorCompositor
from ..database_interface import DBInterface
from ..tickers import StockTickerSelector
from ..utils import StockSelectionPolicy


class FamaFrenchCarhart4FactorModel(FinancialModel):
    def __init__(self):
        """Fama French Carhart 4 factor model(1997)

        Same universe and size / value splits as the 3 factor model; momentum is split 30/40/30 and measured
        between ``offset_2`` and ``offset_1`` trading days back.
        """
        super().__init__('Fama French Carhart 4 factor model', ['FF3_SMB', 'FF3_HML', 'FFC4_UMD'])

        self.stock_selection_policy = StockSelectionPolicy(ignore_negative_book_value_stock=True,
                                                           ignore_st=True, ignore_pause=True,
                                                           ignore_new_stock_period=244)
        self.hml_threshold = [0, 0.3, 0.7, 1]
        self.smb_threshold = [0, 0.5, 1]
        self.umd_threshold = [0, 0.3, 0.7, 1]
        self.offset_1 = 22
        self.offset_2 = 22 * 12


class UMDCompositor(ModelFactorCompositor):
    def __init__(self, model: FamaFrenchCarhart4FactorModel = None, db_interface: DBInterface = None):
        """Compute UMD/MOM in Fama French Carhart 4 factor model"""
        model = model if model else FamaFrenchCarhart4FactorModel()
        super().__init__(model, db_interface)
        # NOTE(review): the model names this factor 'FFC4_UMD' (used by ``get_db_factor_names``), but returns are
        # stored as 'Carhart_UMD_*'. Kept as-is because the IDs are persisted; confirm which name readers expect.
        self.factor_names = ['Carhart_UMD']

        self.start_date = dt.datetime(2007, 1, 4)
        self.ticker_selector = StockTickerSelector(model.stock_selection_policy, self.db_interface)

        self.cap = self.data_reader.stock_free_floating_market_cap
        self.returns = self.data_reader.stock_return

    def compute_factor_return(self, balance_date: dt.datetime, pre_date: dt.datetime, date: dt.datetime,
                              rebalance_marker: str, period_marker: str) -> pd.Series:
        """Compute the UMD (momentum) return between ``pre_date`` and ``date``.

        Uses stocks selected on ``date`` and on both look-back dates. Momentum compares ``hfq_close``
        (presumably back-adjusted close) ``offset_2`` and ``offset_1`` trading days before ``balance_date``.
        The 2 x 3 size / momentum portfolios are equal weighted; UMD = mean('up') - mean('down').
        """
        # data
        tm1 = self.calendar.offset(balance_date, -self.model.offset_1)
        tm12 = self.calendar.offset(balance_date, -self.model.offset_2)
        tickers = self.ticker_selector.ticker(date)
        tm1_ticker = self.ticker_selector.ticker(tm1)
        tm12_ticker = self.ticker_selector.ticker(tm12)
        tickers = sorted(set(tickers) & set(tm1_ticker) & set(tm12_ticker))
        p1 = self.data_reader.hfq_close.get_data(ids=tickers, dates=tm1)
        p12 = self.data_reader.hfq_close.get_data(ids=tickers, dates=tm12)
        # p12 / p1 is the inverse of the past-winner ratio, so its LOWEST tercile is labelled 'up'
        pct_diff = p12.droplevel('DateTime') / p1.droplevel('DateTime')
        cap = self.cap.get_data(ids=tickers, dates=balance_date).droplevel('DateTime')
        returns = self.returns.get_data(ids=tickers, dates=[pre_date, date]).droplevel('DateTime')
        df = pd.concat([returns, cap, pct_diff], axis=1).dropna()

        # grouping
        df['G_SMB'] = pd.qcut(df[cap.name], self.model.smb_threshold, labels=['small', 'big'])
        df['G_UMD'] = pd.qcut(df[pct_diff.name], self.model.umd_threshold, labels=['up', 'mid', 'down'])
        rets = df.groupby(['G_SMB', 'G_UMD'])[returns.name].mean()
        by_momentum = rets.groupby('G_UMD').mean()
        umd = by_momentum.loc['up'] - by_momentum.loc['down']

        # formatting result
        factor_names = [f'{factor_name}_{rebalance_marker}{period_marker}' for factor_name in self.factor_names]
        index_date = pre_date if period_marker == 'M' else date
        index = pd.MultiIndex.from_product([[index_date], factor_names], names=('DateTime', 'ID'))
        return pd.Series(umd, index=index, name='收益率')


# -------------------- /AShareData/model/model.py --------------------
import datetime as dt
from typing import List

import pandas as pd
from tqdm import tqdm

from ..database_interface import DBInterface
from ..factor_compositor.factor_compositor import FactorCompositor


class FinancialModel(object):
    def __init__(self, model_name: str, factor_names: List[str]):
        """Base class for Financial Models

        Should include model parameters in the object, accompanied by a subclass of ModelFactorCompositor to compute factor returns

        PS: excess market return is implied and must not be specified in the ``factor_names``

        :param model_name: Financial model Name
        :param factor_names: Factor names specified in the model (copied)
        """
        self.model_name = model_name
        self.factor_names = factor_names.copy()

    def get_db_factor_names(self, rebalance_schedule: str = 'D', computing_schedule: str = 'D'):
        """Naming scheme used when calculating with different rebalancing and computing schedules.
        Combinations ('D', 'D'), ('M', 'D') and ('M', 'M') are valid

        :param rebalance_schedule: 'D' or 'M', portfolio is rebalanced Daily('D') or at the end of each Month('M')
        :param computing_schedule: 'D' or 'M', portfolio return is computed Daily('D') or Monthly('M')
        :return: ``['{factor}_{rebalance_schedule}{computing_schedule}', ...]`` for every model factor
        """
        return [f'{it}_{rebalance_schedule}{computing_schedule}' for it in self.factor_names]


class ModelFactorCompositor(FactorCompositor):
    def __init__(self, model, db_interface: DBInterface):
        """Model Factor Return Compositor

        Compute factor returns specified by ``model``; subclasses implement ``compute_factor_return``

        :param model: Financial model
        :param db_interface: DBInterface
        """
        super().__init__(db_interface)
        self.model = model
        self.factor_names = model.factor_names
        self.db_table_name = '模型因子收益率'
        self.start_date = None

    def update(self):
        """Update daily-rebalanced, then monthly-rebalanced factor returns."""
        self.update_daily_rebalanced_portfolio()
        self.update_monthly_rebalanced_portfolio_return()

    def update_monthly_rebalanced_portfolio_return(self):
        """Store returns of portfolios rebalanced at the previous month end.

        Every trading day a daily return ``{factor}_MD`` is stored. On the last trading day of a month the
        whole-month return ``{factor}_MM`` is stored too, stamped with ``calendar.month_begin`` of that month.
        """
        eg_factor_name = f'{self.factor_names[-1]}_MD'
        start_date = self.db_interface.get_latest_timestamp(self.db_table_name, self.start_date,
                                                            column_condition=('ID', eg_factor_name))
        end_date = self.db_interface.get_latest_timestamp('股票日行情')
        dates = self.data_reader.calendar.select_dates(start_date, end_date, inclusive=(False, True))

        with tqdm(dates) as pbar:
            for date in dates:
                pbar.set_description(f'更新 {self.model.model_name} 因子收益率: {date}')
                rebalance_date = self.calendar.pre_month_end(date.year, date.month)

                pre_date = self.data_reader.calendar.offset(date, -1)
                factor_df = self.compute_factor_return(rebalance_date, pre_date, date, 'M', 'D')
                self.db_interface.insert_df(factor_df, self.db_table_name)

                # the NEXT trading day (offset +1) falling in another month means ``date`` is a month end;
                # offsetting by -1 would instead trigger on the first trading day with a one-day "monthly" return
                next_date = self.data_reader.calendar.offset(date, 1)
                if next_date.month != date.month:
                    factor_df = self.compute_factor_return(rebalance_date, rebalance_date, date, 'M', 'M')
                    month_beg_date = self.calendar.month_begin(date.year, date.month)
                    factor_df.index = pd.MultiIndex.from_product(
                        [[month_beg_date], factor_df.index.get_level_values('ID')], names=('DateTime', 'ID'))
                    self.db_interface.insert_df(factor_df, self.db_table_name)
                pbar.update()

    def update_daily_rebalanced_portfolio(self):
        """Store ``{factor}_DD`` returns of portfolios rebalanced every trading day."""
        eg_factor_name = f'{self.factor_names[-1]}_DD'
        start_date = self.db_interface.get_latest_timestamp(self.db_table_name, self.start_date,
                                                            column_condition=('ID', eg_factor_name))
        end_date = self.db_interface.get_latest_timestamp('股票日行情')
        dates = self.data_reader.calendar.select_dates(start_date, end_date, inclusive=(False, True))

        with tqdm(dates) as pbar:
            for date in dates:
                pbar.set_description(f'更新 {self.model.model_name} 因子日收益率: {date}')
                pre_date = self.data_reader.calendar.offset(date, -1)
                factor_df = self.compute_factor_return(pre_date, pre_date, date, 'D', 'D')
                self.db_interface.insert_df(factor_df, self.db_table_name)
                pbar.update()

    def compute_factor_return(self, balance_date: dt.datetime, pre_date: dt.datetime, date: dt.datetime,
                              rebalance_marker: str, period_marker: str) -> pd.Series:
        """Factor returns from ``pre_date`` to ``date`` of portfolios formed with data as of ``balance_date``."""
        raise NotImplementedError()


# -------------------- /AShareData/plot.py --------------------
from typing import Sequence, Union

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.axes import Axes

from AShareData import date_utils, utils
from AShareData.config import get_db_interface
from AShareData.database_interface import DBInterface
from AShareData.factor import ContinuousFactor

plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False


def plot_factor_return(factor_name: str, weight: bool = True, industry_neutral: bool = True, bins: int = 5,
                       start_date: date_utils.DateType = None, end_date: date_utils.DateType = None,
                       db_interface: DBInterface = None) -> plt.Figure:
    """Plot cumulative returns of factor-sorted groups and the first-minus-last group difference.

    :param factor_name: factor whose group returns are stored in ``因子分组收益率``
    :param weight: cap-weighted (``True``) or equal-weighted groups
    :param industry_neutral: industry-neutral (``True``) or whole-market sorting
    :param bins: number of groups
    :param start_date: start date
    :param end_date: end date
    :param db_interface: DBInterface, defaults to ``get_db_interface()``
    :return: figure with the group curves on top and the difference of the first and last curve below
    """
    if db_interface is None:
        db_interface = get_db_interface()

    ids = utils.generate_factor_bin_names(factor_name, weight=weight, industry_neutral=industry_neutral, bins=bins)
    data = db_interface.read_table('因子分组收益率', ids=ids, start_date=start_date, end_date=end_date)
    df = (data.unstack() + 1).cumprod()
    bin_names_info = [utils.decompose_bin_names(it) for it in df.columns]
    diff_series = df[ids[0]] - df[ids[-1]]

    df.columns = [it['group'] for it in bin_names_info]
    diff_series.name = f'{utils.decompose_bin_names(ids[0])["group"]}-{utils.decompose_bin_names(ids[-1])["group"]}'

    fig, axes = plt.subplots(2, 1, figsize=(15, 8), sharex='col')
    df.plot(ax=axes[0])
    industry_neutral_str = '行业中性' if industry_neutral else '非行业中性'
    weight_str = '市值加权' if weight else '等权'
    axes[0].set_title(f'{factor_name} 分组收益率({industry_neutral_str}, {weight_str})')
    plot_dt = df.index.get_level_values('DateTime')
    axes[0].set_xlim(left=plot_dt[0], right=plot_dt[-1])
    axes[0].grid(True)

    diff_series.plot(ax=axes[1])
    axes[1].grid(True)
    axes[1].legend()

    return fig


def plot_indexes(indexes: Union[ContinuousFactor, Sequence[ContinuousFactor]], start_date=None, end_date=None,
                 ax: Axes = None) -> Axes:
    """Plot the cumulative value ``(1 + r).cumprod()`` of one or more return factors.

    :param indexes: one factor or a sequence of factors
    :param start_date: start date
    :param end_date: end date
    :param ax: axes to draw on; a new figure is created when ``None``
    :return: the axes drawn on
    """
    if isinstance(indexes, ContinuousFactor):
        indexes = [indexes]
    storage = [it.get_data(start_date=start_date, end_date=end_date) for it in indexes]
    data = pd.concat(storage).unstack()
    val = (data + 1).cumprod()

    if ax is None:
        _, ax = plt.subplots(1, 1)
    val.plot(ax=ax)
    ax.set_xlim(left=val.index[0], right=val.index[-1])
    ax.grid(True)
    return ax
-------------------------------------------------------------------------------- /AShareData/portfolio_analysis.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import logging 3 | from functools import lru_cache 4 | from typing import Sequence, Tuple, Union 5 | 6 | import numpy as np 7 | import pandas as pd 8 | from scipy.stats.mstats import winsorize 9 | from statsmodels.api import OLS 10 | from statsmodels.tools.tools import add_constant 11 | 12 | from . import AShareDataReader 13 | from .algo import human_sort 14 | from .config import get_db_interface 15 | from .database_interface import DBInterface 16 | from .factor import BinaryFactor, FactorBase, IndustryFactor, InterestRateFactor, OnTheRecordFactor, UnaryFactor 17 | from .model.model import FinancialModel 18 | from .tickers import StockTickerSelector 19 | 20 | 21 | class CrossSectionalPortfolioAnalysis(object): 22 | def __init__(self, forward_return: UnaryFactor, ticker_selector: StockTickerSelector, dates: Sequence[dt.datetime], 23 | factors: Union[FactorBase, Sequence[FactorBase]] = None, 24 | market_cap: BinaryFactor = None, industry: IndustryFactor = None): 25 | self.forward_return = forward_return 26 | self.ticker_selector = ticker_selector 27 | self.industry = industry 28 | self.dates = dates 29 | self.market_cap = market_cap 30 | 31 | self.cache_data = None 32 | self.market_cap_name = None 33 | 34 | self.factors = {} 35 | self.factor_names = [] 36 | self.append_factor(factors) 37 | self.sorting_factor = [] 38 | 39 | def append_factor(self, factor: Union[FactorBase, Sequence[FactorBase]]): 40 | if factor: 41 | if isinstance(factor, Sequence): 42 | for f in factor: 43 | self.append_factor(f) 44 | else: 45 | if factor.name not in self.factor_names: 46 | self.factors[factor.name] = factor 47 | self.factor_names.append(factor.name) 48 | if self.cache_data: 49 | data = factor.get_data(dates=self.dates, ticker_selector=self.ticker_selector) 50 
| self.cache_data = pd.concat([self.cache_data, data], axis=1).dropna() 51 | 52 | def cache(self): 53 | logging.getLogger(__name__).info('Cache cross-sectional data') 54 | storage = [self.forward_return.get_data(dates=self.dates, ticker_selector=self.ticker_selector)] 55 | if self.market_cap: 56 | logging.getLogger(__name__).debug('Cache market cap data') 57 | storage.append(self.market_cap.get_data(dates=self.dates, ticker_selector=self.ticker_selector)) 58 | self.market_cap_name = self.market_cap.name 59 | else: 60 | self.market_cap_name = 'cap_weight' 61 | if self.industry: 62 | logging.getLogger(__name__).debug('Cache industry data') 63 | storage.append(self.industry.get_data(dates=self.dates, ticker_selector=self.ticker_selector)) 64 | for it in self.factors.values(): 65 | storage.append(it.get_data(dates=self.dates, ticker_selector=self.ticker_selector)) 66 | 67 | self.cache_data = pd.concat(storage, axis=1).dropna() 68 | 69 | def _factor_sorting(self, factor_name: str = None, quantile: int = None, separate_neg_vals: bool = False, 70 | gb_vars: Union[str, Sequence[str]] = 'DateTime'): 71 | var_name = f'G_{factor_name}' 72 | if var_name in self.cache_data.columns: 73 | return 74 | 75 | if factor_name is None: 76 | if len(self.factor_names) == 1: 77 | factor_name = self.factor_names[0] 78 | else: 79 | raise ValueError('Ambiguous factor name, please specify in `factor_name=`') 80 | 81 | quantile_labels = [f'G{i}' for i in range(1, quantile + 1)] 82 | if separate_neg_vals: 83 | negative_ind = self.cache_data[factor_name] < 0 84 | else: 85 | negative_ind = pd.Series(False, index=self.cache_data.index) 86 | 87 | tmp = self.cache_data.loc[~negative_ind, :].groupby(gb_vars)[factor_name].apply( 88 | lambda x: pd.qcut(x, quantile, labels=quantile_labels)) 89 | neg_vals = pd.Series('G0', index=self.cache_data.loc[negative_ind, :].index) 90 | tmp = pd.concat([tmp, neg_vals]).sort_index() 91 | 92 | self.cache_data[var_name] = tmp.values 93 | 94 | def 
single_factor_sorting(self, factor_name: str = None, quantile: int = None, separate_neg_vals: bool = False): 95 | self.sorting_factor = factor_name 96 | self._factor_sorting(factor_name, quantile, separate_neg_vals) 97 | 98 | def two_factor_sorting(self, factor_names: Tuple[str, str], independent: bool, 99 | quantile: Union[int, Sequence[float], Tuple[int, int]] = None, 100 | separate_neg_vals: Union[bool, Tuple[bool, bool]] = False): 101 | if factor_names[0] not in self.factor_names: 102 | raise ValueError(f'Unknown factor name: {factor_names[0]}.') 103 | if factor_names[1] not in self.factor_names: 104 | raise ValueError(f'Unknown factor name: {factor_names[1]}.') 105 | 106 | self.sorting_factor = list(factor_names) 107 | if not isinstance(quantile, Tuple): 108 | quantile = (quantile, quantile) 109 | if not isinstance(separate_neg_vals, Tuple): 110 | separate_neg_vals = (separate_neg_vals, separate_neg_vals) 111 | 112 | self._factor_sorting(factor_names[0], quantile[0], separate_neg_vals[0]) 113 | if independent: 114 | self._factor_sorting(factor_names[1], quantile[1], separate_neg_vals[1]) 115 | else: 116 | self._factor_sorting(factor_names[1], quantile[1], separate_neg_vals[1], 117 | gb_vars=['DateTime', f'G_{factor_names[0]}']) 118 | 119 | # TODO 120 | def fm_regression(self): 121 | data = self.cache_data.loc[:, ['forward_return'] + self.cache['factor_names']].copy() 122 | # need to winsorize 123 | for factor in self.cache['factor_names']: 124 | data[factor] = data[factor].apply(lambda x: winsorize(x, (0.25, 0.25))) 125 | # cross-sectional regression 126 | 127 | # time-series regression 128 | 129 | # test 130 | pass 131 | 132 | def returns_results(self, cap_weighted: bool = False) -> pd.DataFrame: 133 | if cap_weighted and not self.market_cap: 134 | raise ValueError('market cap is not specified.') 135 | 136 | def weighted_ret(x): 137 | return x[self.forward_return.name].dot(x[self.market_cap_name] / x[self.market_cap_name].sum()) 138 | 139 | func = 
weighted_ret if cap_weighted else np.mean 140 | g_vars = [f'G_{it}' for it in self.factor_names] 141 | tmp = self.cache_data.groupby(g_vars + ['DateTime']).apply(func) 142 | storage = [tmp.groupby(g_vars).mean().reset_index()] 143 | for var in g_vars: 144 | t2 = self.cache_data.groupby([var, 'DateTime']).apply(func) 145 | t3 = t2.groupby(var).mean().reset_index() 146 | other_var = list(set(g_vars) - {var})[0] 147 | t3[other_var] = 'ALL' 148 | storage.append(t3) 149 | 150 | tmp = pd.concat(storage) 151 | res = tmp.pivot(index=g_vars[0], columns=g_vars[1], values=tmp.columns[-1]) 152 | index = human_sort(res.index.tolist()) 153 | col = human_sort(res.columns.tolist()) 154 | res = res.loc[index, col] 155 | return res 156 | 157 | def summary_statistics(self, factor_name: str = None) -> pd.DataFrame: 158 | if factor_name is None: 159 | if len(self.factor_names) == 1: 160 | factor_name = self.factor_names[0] 161 | else: 162 | raise ValueError('Ambiguous factor name, please specify in `factor_name=`') 163 | 164 | res = self.cache_data.groupby([f'G_{factor_name}', 'DateTime']).mean() 165 | return res.groupby(f'G_{factor_name}').mean() 166 | 167 | def factor_corr(self, factor_names: Tuple[str, str]) -> pd.Series: 168 | return self.cache_data.groupby('DateTime').apply( 169 | lambda x: np.corrcoef(x[factor_names[0]], x[factor_names[1]])[0, 1]) 170 | 171 | 172 | class ASharePortfolioExposure(object): 173 | def __init__(self, model: FinancialModel, rf_rate: InterestRateFactor = None, 174 | rebalance_schedule: str = 'D', computing_schedule: str = 'D', 175 | db_interface: DBInterface = None): 176 | self.model = model 177 | self.factor_names = self.model.get_db_factor_names(rebalance_schedule, computing_schedule) 178 | 179 | self.db_interface = db_interface if db_interface else get_db_interface() 180 | self.data_reader = AShareDataReader(self.db_interface) 181 | self.rf_rate = rf_rate if rf_rate else self.data_reader.three_month_shibor 182 | market_return = 
self.data_reader.market_return 183 | self.excess_market_return = (market_return - self.rf_rate).set_factor_name(f'{market_return}-Rf') 184 | self.stock_pause_info = OnTheRecordFactor('股票停牌', self.db_interface) 185 | 186 | @lru_cache(10) 187 | def common_data(self, date: dt.datetime, start_date: dt.datetime): 188 | rf_data = self.rf_rate.get_data(start_date=start_date, end_date=date) 189 | rm = self.excess_market_return.get_data(start_date=start_date, end_date=date) 190 | col_names = [rm.index.get_level_values('ID')[0]] 191 | if self.factor_names: 192 | factor_data = self.data_reader.model_factor_return.get_data(ids=self.factor_names, 193 | start_date=start_date, end_date=date) 194 | col_names.extend(self.factor_names) 195 | else: 196 | factor_data = None 197 | data = pd.concat([rm, factor_data]).unstack().reindex(col_names, axis=1).rename({col_names[0]: 'Rm-Rf'}, axis=1) 198 | return rf_data, data 199 | 200 | def get_stock_exposure(self, ticker: str, date: dt.datetime = None, lookback_period: int = 60, 201 | minimum_look_back_length: int = 40): 202 | date = date if date else dt.datetime.combine(dt.date.today(), dt.time()) 203 | start_date = self.data_reader.calendar.offset(date, -lookback_period - 1) 204 | returns = self.data_reader.stock_return.get_data(ids=[ticker], start_date=start_date, end_date=date) 205 | y = returns.droplevel('ID') 206 | rf_data, x = self.common_data(date, start_date) 207 | data = pd.concat([y - rf_data, x], axis=1).dropna() 208 | pause_counts = self.stock_pause_info.get_counts(ids=[ticker], start_date=start_date, end_date=date).values[0] 209 | if data.shape[0] - pause_counts < minimum_look_back_length: 210 | return 211 | regress_res = OLS(data.iloc[:, 0], add_constant(data.iloc[:, 1:])).fit() 212 | res = regress_res.params.iloc[1:] 213 | res.name = ticker 214 | return res.to_frame().T 215 | 216 | def get_portfolio_exposure(self, portfolio_weight: pd.DataFrame, date: dt.datetime = None, 217 | lookback_period: int = 60, 
minimum_look_back_length: int = 40): 218 | date = date if date else portfolio_weight.index.get_level_values('DateTime')[0] 219 | storage = [self.get_stock_exposure(it, date, lookback_period, minimum_look_back_length) 220 | for it in portfolio_weight.index.get_level_values('ID')] 221 | stock_exposure = pd.concat(storage) 222 | return stock_exposure.mul(portfolio_weight.droplevel('DateTime'), axis=0).sum() 223 | -------------------------------------------------------------------------------- /AShareData/tickers.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | from functools import cached_property 3 | from itertools import product 4 | from typing import Dict, List, Sequence, Union 5 | 6 | import pandas as pd 7 | from dateutil.relativedelta import relativedelta 8 | from singleton_decorator import singleton 9 | 10 | from . import date_utils 11 | from .config import get_db_interface 12 | from .database_interface import DBInterface 13 | from .factor import CompactFactor, CompactRecordFactor, IndustryFactor, OnTheRecordFactor 14 | from .utils import StockSelectionPolicy, TickerSelector 15 | 16 | 17 | @singleton 18 | class FundInfo(object): 19 | def __init__(self, db_interface: DBInterface = None): 20 | super().__init__() 21 | if db_interface is None: 22 | db_interface = get_db_interface() 23 | self.data = db_interface.read_table('基金列表') 24 | 25 | 26 | class TickersBase(object): 27 | """证券代码基类""" 28 | 29 | def __init__(self, db_interface: DBInterface = None) -> None: 30 | self.db_interface = db_interface if db_interface else get_db_interface() 31 | self.cache = None 32 | 33 | def all_ticker(self) -> List[str]: 34 | """ return ALL ticker for the asset class""" 35 | return sorted(self.cache.ID.unique().tolist()) 36 | 37 | @date_utils.dtlize_input_dates 38 | def ticker(self, date: date_utils.DateType = None) -> List[str]: 39 | """ return tickers that are alive on `date`, `date` default to today""" 40 | if date 
is None: 41 | date = dt.datetime.today() 42 | stock_ticker_df = self.cache.loc[self.cache.DateTime <= date] 43 | tmp = stock_ticker_df.groupby('ID').tail(1) 44 | return sorted(tmp.loc[tmp['上市状态'] == 1, 'ID'].tolist()) 45 | 46 | def list_date(self) -> Dict[str, dt.datetime]: 47 | """ return the list date of all tickers""" 48 | first_list_info = self.cache.groupby('ID').head(1) 49 | return dict(zip(first_list_info.ID, first_list_info.DateTime)) 50 | 51 | def get_list_date(self, tickers: Union[str, Sequence[str]]) -> Union[pd.Series, dt.datetime]: 52 | """ return the list date of a ticker""" 53 | if isinstance(tickers, str): 54 | tickers = [tickers] 55 | info = self.cache.loc[self.cache.ID.isin(tickers) & self.cache['上市状态'] == 1, :].set_index('ID') 56 | ret = info.DateTime.iloc[0] if info.shape[0] == 1 else info.DateTime 57 | return ret 58 | 59 | def new_ticker(self, start_date: dt.datetime, end_date: dt.datetime = None) -> List[str]: 60 | if end_date is None: 61 | end_date = dt.datetime.today() 62 | if start_date is None: 63 | start_date = dt.datetime(1990, 12, 10) 64 | u_data = self.cache.loc[(start_date <= self.cache.DateTime) & (self.cache.DateTime <= end_date), :] 65 | tmp = u_data.groupby('ID').tail(1) 66 | return sorted(tmp.loc[tmp['上市状态'] == 1, 'ID'].tolist()) 67 | 68 | 69 | class DiscreteTickers(TickersBase): 70 | """细类证券代码基类""" 71 | 72 | def __init__(self, asset_type: str, db_interface: DBInterface = None) -> None: 73 | super().__init__(db_interface) 74 | self.cache = self.db_interface.read_table('证券代码', text_statement=f'证券类型="{asset_type}"').reset_index() 75 | 76 | 77 | class StockTickers(DiscreteTickers): 78 | """股票代码""" 79 | 80 | def __init__(self, db_interface: DBInterface = None) -> None: 81 | super().__init__('A股股票', db_interface) 82 | 83 | 84 | class ConvertibleBondTickers(DiscreteTickers): 85 | """可转债代码""" 86 | 87 | def __init__(self, db_interface: DBInterface = None) -> None: 88 | super().__init__('可转债', db_interface) 89 | 90 | 91 | class 
FutureTickers(DiscreteTickers): 92 | """期货合约代码""" 93 | 94 | def __init__(self, db_interface: DBInterface = None) -> None: 95 | super().__init__('期货', db_interface) 96 | 97 | 98 | class StockIndexFutureIndex(FutureTickers): 99 | """股指期货合约代码""" 100 | 101 | def __init__(self, db_interface: DBInterface = None) -> None: 102 | super().__init__(db_interface) 103 | mask = self.cache.ID.str.startswith('IH') | self.cache.ID.str.startswith('IF') | self.cache.ID.str.startswith( 104 | 'IC') 105 | self.cache = self.cache.loc[mask, :] 106 | 107 | 108 | class ETFOptionTickers(DiscreteTickers): 109 | """期权合约代码""" 110 | 111 | def __init__(self, db_interface: DBInterface = None) -> None: 112 | super().__init__('ETF期权', db_interface) 113 | 114 | 115 | class IndexOptionTickers(DiscreteTickers): 116 | """指数期权合约代码""" 117 | 118 | def __init__(self, db_interface: DBInterface = None) -> None: 119 | super().__init__('指数期权', db_interface) 120 | 121 | 122 | class FutureOptionTickers(DiscreteTickers): 123 | """商品期权合约代码""" 124 | 125 | def __init__(self, db_interface: DBInterface = None) -> None: 126 | super().__init__('商品期权', db_interface) 127 | 128 | 129 | class ExchangeStockETFTickers(DiscreteTickers): 130 | """场内股票ETF基金代码""" 131 | 132 | def __init__(self, db_interface: DBInterface = None) -> None: 133 | super().__init__('场内基金', db_interface) 134 | fund_info = FundInfo(db_interface) 135 | all_tickers = fund_info.data.loc[(fund_info.data['ETF'] == True) & (fund_info.data['投资类型'] == '被动指数型基金'), :] 136 | self.cache = self.cache.loc[self.cache.ID.isin(all_tickers.index.tolist()), :] 137 | 138 | 139 | class BondETFTickers(DiscreteTickers): 140 | """债券ETF基金代码""" 141 | 142 | def __init__(self, db_interface: DBInterface = None) -> None: 143 | super().__init__('场内基金', db_interface) 144 | fund_info = FundInfo(db_interface) 145 | all_tickers = fund_info.data.loc[(fund_info.data['ETF'] == True) & (fund_info.data['投资类型'] == '被动指数型债券基金'), :] 146 | self.cache = 
self.cache.loc[self.cache.ID.isin(all_tickers.index.tolist()), :] 147 | 148 | 149 | class ConglomerateTickers(TickersBase): 150 | """聚合类证券代码基类""" 151 | 152 | def __init__(self, sql_statement: str, db_interface: DBInterface = None) -> None: 153 | super().__init__(db_interface) 154 | self.cache = self.db_interface.read_table('证券代码', text_statement=sql_statement).reset_index() 155 | 156 | 157 | class OptionTickers(ConglomerateTickers): 158 | """期权""" 159 | 160 | def __init__(self, db_interface: DBInterface = None) -> None: 161 | super().__init__('证券类型 like "%期权"', db_interface) 162 | 163 | 164 | class FundTickers(ConglomerateTickers): 165 | """基金""" 166 | 167 | def __init__(self, db_interface: DBInterface = None) -> None: 168 | super().__init__('证券类型 like "%基金"', db_interface) 169 | 170 | 171 | class ETFTickers(DiscreteTickers): 172 | """ETF""" 173 | 174 | def __init__(self, db_interface: DBInterface = None) -> None: 175 | super().__init__('场内基金', db_interface) 176 | fund_info = FundInfo(db_interface) 177 | all_tickers = fund_info.data.loc[fund_info.data['ETF'] == True, :] 178 | self.cache = self.cache.loc[self.cache.ID.isin(all_tickers.index.tolist()), :] 179 | 180 | 181 | class ExchangeFundTickers(DiscreteTickers): 182 | """场内基金""" 183 | 184 | def __init__(self, db_interface: DBInterface = None) -> None: 185 | super().__init__('场内基金', db_interface) 186 | 187 | 188 | class OTCFundTickers(DiscreteTickers): 189 | """场外基金""" 190 | 191 | def __init__(self, db_interface: DBInterface = None) -> None: 192 | super().__init__('场外基金', db_interface) 193 | 194 | 195 | class InvestmentStyleFundTicker(DiscreteTickers): 196 | def __init__(self, investment_type: Sequence[str], otc: bool = False, db_interface: DBInterface = None) -> None: 197 | """ 某些投资风格的基金 198 | 199 | :param investment_type: [普通股票型基金, 灵活配置型基金, 偏股混合型基金, 平衡混合型基金, 被动指数型基金, 增强指数型基金, 股票多空, 200 | 短期纯债型基金, 中长期纯债型基金, 混合债券型一级基金, 混合债券型二级基金, 偏债混合型基金, 被动指数型债券基金, 增强指数型债券基金, 201 | 商品型基金, 202 | 货币市场型基金, 203 | 国际(QDII)股票型基金, 
国际增强指数型基金, (QDII)混合型基金, 国际(QDII)债券型基金, 国际(QDII)另类投资基金, 204 | REITs] 205 | :param otc: 选择 OTC 基金代码 或 .SH / .SZ 的基金代码 206 | :param db_interface: DBInterface 207 | """ 208 | type_name = '场外基金' if otc else '场内基金' 209 | super().__init__(type_name, db_interface) 210 | self.fund_info = FundInfo(db_interface) 211 | all_tickers = self.fund_info.data.loc[self.fund_info.data['投资类型'].isin(investment_type), :] 212 | self.cache = self.cache.loc[self.cache.ID.isin(all_tickers.index.tolist()), :] 213 | 214 | def get_next_open_day(self, ids: Union[Sequence[str], str], date: dt.datetime = None): 215 | if date is None: 216 | date = dt.datetime.combine(dt.date.today(), dt.time()) 217 | if isinstance(ids, str): 218 | ids = [ids] 219 | list_date = self.get_list_date(ids) 220 | period = self.fund_info.data.loc[ids, '定开时长(月)'] 221 | df = pd.concat([list_date, period], axis=1) 222 | storage = [] 223 | for ticker, row in df.iterrows(): 224 | if pd.isna(row['定开时长(月)']): 225 | storage.append(pd.NaT) 226 | continue 227 | open_day = row.DateTime 228 | while open_day < date: 229 | open_day = open_day + relativedelta(months=row['定开时长(月)']) 230 | storage.append(open_day) 231 | return pd.Series(storage, index=df.index) 232 | 233 | 234 | class StockFundTickers(InvestmentStyleFundTicker): 235 | """ 236 | 股票型基金 237 | 238 | 以股票为主要(>=50%)投资标的的基金 239 | """ 240 | 241 | def __init__(self, otc: bool = False, db_interface: DBInterface = None) -> None: 242 | stock_investment_type = ['偏股混合型基金', '被动指数型基金', '灵活配置型基金', '增强指数型基金', '普通股票型基金', '股票多空', '平衡混合型基金'] 243 | super().__init__(stock_investment_type, otc, db_interface) 244 | 245 | 246 | class FundWithStocksTickers(InvestmentStyleFundTicker): 247 | """可以投资股票的基金 """ 248 | 249 | def __init__(self, otc: bool = False, db_interface: DBInterface = None) -> None: 250 | stock_investment_type = ['偏股混合型基金', '被动指数型基金', '灵活配置型基金', '增强指数型基金', '普通股票型基金', '股票多空', '平衡混合型基金', '混合债券型二级基金', 251 | '混合债券型一级基金', '偏债混合型基金'] 252 | super().__init__(stock_investment_type, otc, 
db_interface) 253 | 254 | 255 | class EnhancedIndexFund(InvestmentStyleFundTicker): 256 | """股票指数增强基金""" 257 | 258 | def __init__(self, otc: bool = False, db_interface: DBInterface = None) -> None: 259 | stock_investment_type = ['增强指数型基金'] 260 | super().__init__(stock_investment_type, otc, db_interface) 261 | 262 | 263 | class IndexFund(InvestmentStyleFundTicker): 264 | """指数基金""" 265 | 266 | def __init__(self, otc: bool = False, db_interface: DBInterface = None) -> None: 267 | stock_investment_type = ['被动指数型基金'] 268 | super().__init__(stock_investment_type, otc, db_interface) 269 | 270 | 271 | class ActiveManagedStockFundTickers(InvestmentStyleFundTicker): 272 | """以股票为主要(>=50%)投资标的的主动管理型基金""" 273 | 274 | def __init__(self, otc: bool = False, db_interface: DBInterface = None) -> None: 275 | stock_investment_type = ['偏股混合型基金', '灵活配置型基金', '增强指数型基金', '普通股票型基金', '股票多空', '平衡混合型基金'] 276 | super().__init__(stock_investment_type, otc, db_interface) 277 | 278 | 279 | class StockTickerSelector(TickerSelector): 280 | """股票代码选择器""" 281 | 282 | def __init__(self, policy: StockSelectionPolicy, db_interface: DBInterface = None) -> None: 283 | """ 284 | :param db_interface: BDInterface 285 | :param policy: 选股条件 286 | """ 287 | super().__init__() 288 | self.db_interface = db_interface if db_interface else get_db_interface() 289 | self.calendar = date_utils.SHSZTradingCalendar(self.db_interface) 290 | self.stock_ticker = StockTickers(self.db_interface) 291 | self.policy = policy 292 | 293 | @cached_property 294 | def paused_stock_selector(self): 295 | return OnTheRecordFactor('股票停牌', self.db_interface) 296 | 297 | @cached_property 298 | def const_limit_selector(self): 299 | return OnTheRecordFactor('一字涨跌停', self.db_interface) 300 | 301 | @cached_property 302 | def risk_warned_stock_selector(self): 303 | tmp = CompactFactor('证券名称', self.db_interface) 304 | ids = tmp.data.index.get_level_values('ID') 305 | tmp.data = tmp.data.loc[ids.str.endswith('.SH') | ids.str.endswith('.SZ')] 306 
| tmp.data = tmp.data.map(lambda x: 'PT' in x or 'ST' in x or '退' in x) 307 | return CompactRecordFactor(tmp, '风险警示股') 308 | 309 | @cached_property 310 | def negative_book_value_stock_selector(self): 311 | return CompactFactor('负净资产股票', self.db_interface) 312 | 313 | @cached_property 314 | def industry_info(self): 315 | if self.policy.industry: 316 | return IndustryFactor(self.policy.industry_provider, self.policy.industry_level, self.db_interface) 317 | 318 | @date_utils.dtlize_input_dates 319 | def ticker(self, date: date_utils.DateType, ids: Sequence[str] = None) -> List[str]: 320 | """ select stocks that matched selection policy on `date`(amongst `ids`) 321 | 322 | :param date: query date 323 | :param ids: tickers to select from 324 | :return: list of ticker that satisfy the stock selection policy 325 | """ 326 | if ids is None: 327 | ids = set(self.stock_ticker.ticker(date)) 328 | 329 | if self.policy.ignore_new_stock_period or self.policy.select_new_stock_period: 330 | start_date, end_date = None, None 331 | if self.policy.ignore_new_stock_period: 332 | end_date = self.calendar.offset(date, -self.policy.ignore_new_stock_period) 333 | if self.policy.select_new_stock_period: 334 | start_date = self.calendar.offset(date, -self.policy.select_new_stock_period - 1) 335 | ids = set(self.stock_ticker.new_ticker(start_date=start_date, end_date=end_date)) & ids 336 | 337 | if self.industry_info and self.policy.industry: 338 | ids = ids & set(self.industry_info.list_constitutes(date=date, industry=self.policy.industry)) 339 | if self.policy.ignore_const_limit: 340 | ids = ids - set(self.const_limit_selector.get_data(date)) 341 | 342 | if self.policy.ignore_pause: 343 | ids = ids - set(self.paused_stock_selector.get_data(date)) 344 | elif self.policy.select_pause: 345 | ids = ids & set(self.paused_stock_selector.get_data(date)) 346 | if self.policy.max_pause_days: 347 | pause_days, period_length = self.policy.max_pause_days 348 | start_date = self.calendar.offset(date, 
-period_length) 349 | end_date = self.calendar.offset(date, -1) 350 | pause_counts = self.paused_stock_selector.get_counts(start_date=start_date, end_date=end_date) 351 | pause_counts = pause_counts.loc[pause_counts > pause_days] 352 | ids = ids - set(pause_counts.index.get_level_values('ID')) 353 | 354 | if self.policy.select_st: 355 | ids = ids & set(self.risk_warned_stock_selector.get_data(date)) 356 | if self.policy.st_defer_period: 357 | start_date = self.calendar.offset(date, -self.policy.st_defer_period - 1) 358 | ids = ids & set(self.risk_warned_stock_selector.get_data(start_date)) 359 | if self.policy.ignore_st: 360 | ids = ids - set(self.risk_warned_stock_selector.get_data(date)) 361 | 362 | if self.policy.ignore_negative_book_value_stock: 363 | data = self.negative_book_value_stock_selector.get_data(dates=date) 364 | ids = ids - set(data.loc[data == True].index.get_level_values('ID').tolist()) 365 | 366 | ids = sorted(list(ids)) 367 | return ids 368 | 369 | def generate_index(self, start_date: date_utils.DateType = None, end_date: date_utils.DateType = None, 370 | dates: Union[date_utils.DateType, Sequence[date_utils.DateType]] = None) -> pd.MultiIndex: 371 | storage = [] 372 | if dates is None: 373 | dates = self.calendar.select_dates(start_date, end_date) 374 | for date in dates: 375 | ids = self.ticker(date) 376 | storage.extend(list(product([date], ids))) 377 | return pd.MultiIndex.from_tuples(storage, names=['DateTime', 'ID']) 378 | -------------------------------------------------------------------------------- /AShareData/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .tools import IndexHighlighter, major_index_valuation, MajorIndustryConstitutes, StockIndexFutureBasis 2 | -------------------------------------------------------------------------------- /AShareData/tools/tools.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 
| 3 | import matplotlib.pyplot as plt 4 | import matplotlib.ticker as mtick 5 | import pandas as pd 6 | 7 | from AShareData import AShareDataReader, constants, SHSZTradingCalendar, utils 8 | from AShareData.config import get_db_interface 9 | from AShareData.database_interface import DBInterface 10 | from AShareData.factor import CompactFactor, ContinuousFactor 11 | from AShareData.tickers import StockIndexFutureIndex 12 | 13 | plt.rcParams['font.sans-serif'] = ['SimHei'] 14 | plt.rcParams['axes.unicode_minus'] = False 15 | 16 | 17 | class MajorIndustryConstitutes(object): 18 | def __init__(self, provider: str, level: int, cap: CompactFactor = None, db_interface: DBInterface = None): 19 | self.db_interface = db_interface if db_interface else get_db_interface() 20 | self.calendar = SHSZTradingCalendar(self.db_interface) 21 | self.date = self.calendar.today() 22 | self.data_reader = AShareDataReader(self.db_interface) 23 | self.industry = self.data_reader.industry(provider=provider, level=level) 24 | self.cap = cap if cap else self.data_reader.stock_free_floating_market_cap 25 | 26 | def get_major_constitute(self, name: str, n: int = None): 27 | if name not in self.industry.all_industries: 28 | raise ValueError(f'unknown industry: {name}') 29 | constitute = self.industry.list_constitutes(date=self.date, industry=name) 30 | val = self.cap.get_data(ids=constitute, dates=self.date) / 1e8 31 | if n: 32 | val = val.sort_values(ascending=False) 33 | val = val.head(n) 34 | constitute = val.index.get_level_values('ID').tolist() 35 | sec_name = self.data_reader.sec_name.get_data(ids=constitute, dates=self.date) 36 | pe = self.data_reader.pe_ttm.get_data(ids=constitute, dates=self.date) 37 | pb = self.data_reader.pb.get_data(ids=constitute, dates=self.date) 38 | ret = pd.concat([sec_name, val, pe, pb], axis=1).sort_values(val.name, ascending=False) 39 | return ret 40 | 41 | 42 | class IndexHighlighter(object): 43 | must_keep_indexes = ['全市场.IND', '全市场等权.IND', '次新股等权.IND', 
'ST.IND'] 44 | 45 | def __init__(self, date: dt.datetime = None, db_interface: DBInterface = None): 46 | self.db_interface = db_interface if db_interface else get_db_interface() 47 | self.calendar = SHSZTradingCalendar(self.db_interface) 48 | if date is None: 49 | date = dt.datetime.combine(dt.date.today(), dt.time()) 50 | self.date = date 51 | records = utils.load_excel('自编指数配置.xlsx') 52 | self.tickers = [it['ticker'] for it in records] 53 | self.tbd_indexes = list(set(self.tickers) - set(self.must_keep_indexes)) 54 | start_date = self.calendar.offset(date, -22) 55 | index_factor = ContinuousFactor('自合成指数', '收益率', db_interface=self.db_interface) 56 | self.cache = index_factor.get_data(start_date=start_date, end_date=date).unstack() 57 | self.industry_cache = [] 58 | 59 | def featured_data(self, look_back_period: int, n: int) -> pd.DataFrame: 60 | data = self.cache.iloc[-look_back_period:, :] 61 | data = (data + 1).cumprod() 62 | tmp = data.loc[data.index[-1], self.tbd_indexes].sort_values() 63 | ordered_index = tmp.index.tolist() 64 | cols = ordered_index[:n] + ordered_index[-n:] 65 | self.industry_cache.extend(cols) 66 | return data.loc[:, cols + self.must_keep_indexes] - 1 67 | 68 | @staticmethod 69 | def disp_data(data): 70 | print(data.loc[data.index[-1], :].T.sort_values(ascending=False) * 100) 71 | 72 | def plot_index(self, period: int, n: int, ax: plt.Axes = None): 73 | plot_data = self.featured_data(period, n) * 100 74 | if ax is None: 75 | _, ax = plt.subplots(1, 1) 76 | plot_data.plot(ax=ax) 77 | ax.set_xlim(left=plot_data.index[0], right=plot_data.index[-1]) 78 | ax.grid(True) 79 | ax.yaxis.set_major_formatter(mtick.PercentFormatter()) 80 | return ax 81 | 82 | def summary(self): 83 | for i, it in enumerate([(3, 3), (5, 3), (20, 3)]): 84 | print(f'回溯{it[0]}天:') 85 | self.disp_data(self.featured_data(it[0], it[1])) 86 | print('') 87 | self.plot_index(20, 3) 88 | mentioned_industry = [it[2:-4] for it in set(self.industry_cache) if it.startswith('申万')] 89 | 
constitute = MajorIndustryConstitutes(provider='申万', level=2) 90 | for it in mentioned_industry: 91 | print(f'申万2级行业 - {it}') 92 | print(constitute.get_major_constitute(it, 10)) 93 | print('') 94 | 95 | 96 | def major_index_valuation(db_interface: DBInterface = None): 97 | if db_interface is None: 98 | db_interface = get_db_interface() 99 | data = db_interface.read_table('指数日行情', ['市盈率TTM', '市净率']).dropna(how='all') 100 | tmp = data.groupby('ID').rank() 101 | latest = data.groupby('ID').tail(1) 102 | percentile = tmp.groupby('ID').tail(1) / tmp.groupby('ID').max() 103 | percentile.columns = [f'{it}分位' for it in percentile.columns] 104 | ret = pd.concat([latest, percentile], axis=1) 105 | ret = ret.loc[:, sorted(ret.columns)].reset_index() 106 | index_name_dict = dict(zip(constants.STOCK_INDEXES.values(), constants.STOCK_INDEXES.keys())) 107 | ret['ID'] = ret['ID'].map(index_name_dict) 108 | return ret.set_index(['DateTime', 'ID']) 109 | 110 | 111 | class StockIndexFutureBasis(object): 112 | FUTURE_INDEX_MAP = {'IH': '000016.SH', 'IF': '000300.SH', 'IC': '000905.SH'} 113 | 114 | def __init__(self, date: dt.datetime = None, lookback_period: int = 5, db_interface: DBInterface = None): 115 | super().__init__() 116 | self.date = date if date else dt.datetime.combine(dt.date.today(), dt.time()) 117 | self.look_back_period = lookback_period 118 | self.db_interface = db_interface if db_interface else get_db_interface() 119 | self.data_reader = AShareDataReader(self.db_interface) 120 | self.cal = SHSZTradingCalendar(self.db_interface) 121 | self.stock_index_tickers = StockIndexFutureIndex(self.db_interface) 122 | 123 | def compute(self) -> pd.DataFrame: 124 | start_date = self.cal.offset(self.date, -self.look_back_period) 125 | tickers = self.stock_index_tickers.ticker() 126 | tickers_info = self.db_interface.read_table('期货合约', '最后交易日', ids=tickers).to_frame() 127 | tickers_info['index_ticker'] = [self.FUTURE_INDEX_MAP[it[:2]] for it in tickers_info.index] 128 | index_close 
= self.data_reader.index_close.get_data(start_date=start_date, end_date=self.date, 129 | ids=list(self.FUTURE_INDEX_MAP.values())).reset_index() 130 | future_close = self.data_reader.future_close.get_data(start_date=start_date, end_date=self.date, 131 | ids=tickers).reset_index() 132 | tmp = pd.merge(future_close, tickers_info, left_on='ID', right_index=True) 133 | df = pd.merge(tmp, index_close, left_on=['DateTime', 'index_ticker'], right_on=['DateTime', 'ID']).rename( 134 | {'ID_x': 'ID'}, axis=1) 135 | df['合约时长'] = (pd.to_datetime(df['最后交易日']) - df['DateTime']).dt.days 136 | df['年化贴水率'] = ((df['收盘价'] / df['收盘点位']) - 1) / df['合约时长'] * 365 * 100 137 | 138 | res = df.loc[:, ['DateTime', 'ID', '年化贴水率']].set_index(['ID', 'DateTime']).unstack().loc[:, '年化贴水率'] 139 | return res 140 | -------------------------------------------------------------------------------- /AShareData/utils.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import json 3 | import sys 4 | import tempfile 5 | from dataclasses import dataclass 6 | from importlib.resources import open_text, read_binary 7 | from typing import Any, Dict, List, Optional, Tuple, Union 8 | 9 | import pandas as pd 10 | 11 | from . 
import constants 12 | 13 | 14 | class NullPrinter(object): 15 | def __init__(self): 16 | self._stdout = None 17 | self._std_error = None 18 | self._temp_file = None 19 | 20 | def __enter__(self): 21 | self._stdout = sys.stdout 22 | self._std_error = sys.stderr 23 | self._temp_file = tempfile.TemporaryFile(mode='w') 24 | sys.stdout = self._temp_file 25 | sys.stderr = self._temp_file 26 | 27 | def __exit__(self, exc_type, exc_val, exc_tb): 28 | self._temp_file.close() 29 | sys.stdout = self._stdout 30 | sys.stderr = self._std_error 31 | 32 | 33 | def load_param(default_loc: str, param_json_loc: str = None) -> Dict[str, Any]: 34 | if param_json_loc is None: 35 | f = open_text('AShareData.data', default_loc) 36 | else: 37 | f = open(param_json_loc, 'r', encoding='utf-8') 38 | with f: 39 | param = json.load(f) 40 | return param 41 | 42 | 43 | def load_excel(default_loc: str, param_json_loc: str = None) -> List[Dict[str, Any]]: 44 | if param_json_loc is None: 45 | f = read_binary('AShareData.data', default_loc) 46 | df = pd.read_excel(f) 47 | else: 48 | df = pd.read_excel(param_json_loc) 49 | for col in df.columns: 50 | df[col] = df[col].where(df[col].notnull(), other=None) 51 | return df.to_dict('records') 52 | 53 | 54 | def format_stock_ticker(ticker: Union[str, int]) -> str: 55 | if isinstance(ticker, str): 56 | ticker = int(ticker) 57 | if ticker < 600000: 58 | return f'{ticker:06d}.SZ' 59 | else: 60 | return f'{ticker:06d}.SH' 61 | 62 | 63 | def format_czc_ticker(ticker: str) -> str: 64 | c = ticker[1] if ticker[1].isnumeric() else ticker[2] 65 | ticker = ticker.replace(c, '', 1) 66 | return ticker 67 | 68 | 69 | def full_czc_ticker(ticker: str) -> str: 70 | c = 1 if ticker[1].isnumeric() else 2 71 | ticker = ticker[:c] + '2' + ticker[c:] 72 | return ticker 73 | 74 | 75 | def split_hs_ticker(ticker: str) -> Optional[Tuple[int, str]]: 76 | try: 77 | ticker_num, market = ticker.split('.') 78 | except (ValueError, AttributeError): 79 | return None 80 | if market not in 
['SH', 'SZ']: 81 | return None 82 | try: 83 | ticker_num = int(ticker_num) 84 | except ValueError: 85 | return None 86 | return ticker_num, market 87 | 88 | 89 | def is_main_board_stock(ticker: str) -> bool: 90 | """判断是否为沪深主板股票 91 | 92 | :param ticker: 股票代码, 如 `000001.SZ` 93 | """ 94 | return get_stock_board_name(ticker) == '主板' 95 | 96 | 97 | def get_stock_board_name(ticker: str) -> str: 98 | """获取股票所在版块(主板, 中小板, 创业板, 科创板), 其他返回 `非股票` 99 | 100 | :param ticker: 股票代码, 如 `000001.SZ` 101 | """ 102 | ret = split_hs_ticker(ticker) 103 | if ret is None: 104 | return '非股票' 105 | ticker_num, market = ret 106 | if (0 < ticker_num < 2000 and market == 'SZ') or (600000 <= ticker_num < 606000 and market == 'SH'): 107 | return '主板' 108 | elif 2000 < ticker_num < 4000 and market == 'SZ': 109 | return '中小板' 110 | elif 300000 < ticker_num < 301000 and market == 'SZ': 111 | return '创业板' 112 | elif 688000 < ticker_num < 690000 and market == 'SH': 113 | return '科创板' 114 | else: 115 | return '非股票' 116 | 117 | 118 | class SecuritySelectionPolicy: 119 | pass 120 | 121 | 122 | @dataclass 123 | class StockSelectionPolicy(SecuritySelectionPolicy): 124 | """ 股票筛选条件 125 | 126 | :param industry_provider: 股票行业分类标准 127 | :param industry_level: 股票行业分类标准 128 | :param industry: 股票所在行业 129 | 130 | :param ignore_new_stock_period: 新股纳入市场收益计算的时间(交易日天数) 131 | :param select_new_stock_period: 仅选取新上市的股票, 可与 ``ignore_new_stock_period`` 搭配使用 132 | 133 | :param ignore_st: 排除 风险警告股 134 | :param select_st: 仅选取 风险警告股, 包括 PT, ST, SST, \*ST, (即将)退市股 等 135 | :param st_defer_period: 新ST纳入计算的时间(交易日天数), 配合 ``select_st`` 使用 136 | 137 | :param select_pause: 选取停牌股 138 | :param ignore_pause: 排除停牌股 139 | :param max_pause_days: (i, n): 在前n个交易日中最大停牌天数不大于i 140 | 141 | :param ignore_const_limit: 排除一字板股票 142 | :param ignore_negative_book_value_stock: 排除净资产为负的股票 143 | """ 144 | industry_provider: str = None 145 | industry_level: int = None 146 | industry: str = None 147 | 148 | ignore_new_stock_period: int = None 149 | 
select_new_stock_period: int = None 150 | 151 | ignore_st: bool = False 152 | select_st: bool = False 153 | st_defer_period: int = 10 154 | 155 | select_pause: bool = False 156 | ignore_pause: bool = False 157 | max_pause_days: Tuple[int, int] = None 158 | 159 | ignore_const_limit: bool = False 160 | ignore_negative_book_value_stock: bool = False 161 | 162 | def __post_init__(self): 163 | if self.ignore_new_stock_period: 164 | self.ignore_new_stock_period = int(self.ignore_new_stock_period) 165 | if self.select_new_stock_period: 166 | self.select_new_stock_period = int(self.select_new_stock_period) 167 | if self.industry_provider: 168 | if self.industry_provider not in constants.INDUSTRY_DATA_PROVIDER: 169 | raise ValueError('非法行业分类机构!') 170 | if not (0 < self.industry_level <= constants.INDUSTRY_LEVEL[self.industry_provider]): 171 | raise ValueError('非法行业分类级别!') 172 | self.industry_level = int(self.industry_level) 173 | if self.ignore_st & self.select_st: 174 | raise ValueError('不能同时选择ST股票和忽略ST股票') 175 | 176 | 177 | @dataclass 178 | class StockIndexCompositionPolicy: 179 | """ 自建指数信息 180 | 181 | :param ticker: 新建指数入库代码. 
建议以`.IND`结尾, 代表自合成指数 182 | :param name: 指数名称 183 | :param unit_base: 股本指标 184 | :param stock_selection_policy: 股票筛选条件 185 | :param start_date: 指数开始日期 186 | """ 187 | ticker: str = None 188 | name: str = None 189 | unit_base: str = None 190 | stock_selection_policy: StockSelectionPolicy = None 191 | start_date: dt.datetime = None 192 | 193 | def __post_init__(self): 194 | if self.unit_base and self.unit_base not in ['自由流通股本', '总股本', 'A股流通股本', 'A股总股本']: 195 | raise ValueError('非法股本字段!') 196 | 197 | @classmethod 198 | def from_dict(cls, info: Dict): 199 | info = info.copy() 200 | ticker = info.pop('ticker') 201 | name = info.pop('name') 202 | unit_base = info.pop('unit_base') 203 | start_date = info.pop('start_date') 204 | stock_selection_policy = StockSelectionPolicy(**info) 205 | return cls(ticker=ticker, name=name, unit_base=unit_base, stock_selection_policy=stock_selection_policy, 206 | start_date=start_date) 207 | 208 | 209 | class TickerSelector(object): 210 | def __init__(self): 211 | super().__init__() 212 | 213 | def generate_index(self, *args, **kwargs) -> pd.MultiIndex: 214 | raise NotImplementedError() 215 | 216 | def ticker(self, *args, **kwargs) -> List[str]: 217 | raise NotImplementedError() 218 | 219 | 220 | def generate_factor_bin_names(factor_name: str, weight: bool = True, industry_neutral: bool = True, bins: int = 10): 221 | i = 'I' if industry_neutral else 'N' 222 | w = 'W' if weight else 'N' 223 | return [f'{factor_name}_{i}{w}_G{it}inG{bins}' for it in range(1, bins + 1)] 224 | 225 | 226 | def decompose_bin_names(factor_bin_name): 227 | tmp = factor_bin_name.split('_') 228 | composition_info = tmp[1] 229 | group_info = tmp[-1].split('in') 230 | 231 | return { 232 | 'factor_name': tmp[0], 233 | 'industry_neutral': composition_info[0] == 'I', 234 | 'cap_weight': composition_info[1] == 'W', 235 | 'group': group_info[0], 236 | 'total_group': group_info[-1] 237 | } 238 | -------------------------------------------------------------------------------- 
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 jicewarwick 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A股数据获取及本地SQL储存与读取 2 | Manual: 3 | - 在 `config.json` 里填写相关信息. 
模板文件为 `config_example.json` 4 | - 已完成数据: 5 | - 交易日历 6 | - 股票 7 | - 股票列表 8 | - 上市公司基本信息 9 | - IPO新股列表 10 | - 日行情 11 | - 中信, 中证, 申万, Wind行业 12 | - 股票曾用名 / ST处理情况 13 | - 财报 14 | - 指数日行情, 列类似于股票日行情 15 | - 期货 16 | - 合约列表 17 | - 日行情 18 | - 期权 19 | - 合约列表 20 | - 行情 21 | - 基金 22 | - ETF基金列表 23 | - ETF日行情 24 | - 股票指数 25 | - 日行情 26 | - 自合成指标: 27 | - 股票涨跌停一字板 28 | - 股票自定义指数合成 29 | 30 | Dependencies: 31 | - numpy 32 | - pandas 33 | - tushare 34 | - sqlalchemy 35 | - tqdm: 进度显示 36 | - requests 37 | - sortedcontainers 38 | 39 | Optional: 40 | - pymysql: 数据库驱动 41 | - pymysqldb 42 | - WindPy 43 | - alphalens 44 | -------------------------------------------------------------------------------- /config_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "tushare": { 3 | "token": "********" 4 | }, 5 | "db_interface": { 6 | "driver": "********", 7 | "host": "********", 8 | "port": 0, 9 | "database": "********", 10 | "username": "********", 11 | "password": "********" 12 | }, 13 | "join_quant": { 14 | "mobile": "***********", 15 | "password": "************" 16 | }, 17 | "tdx_server": { 18 | "host": "**************", 19 | "port": 0 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | sphinx-apidoc -o source/modules -e -f ../AShareData 21 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 22 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M clean %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | sphinx-autogen %SOURCEDIR%/index.rst 30 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 31 | goto end 32 | 33 | :help 34 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 35 | 36 | :end 37 | popd 38 | -------------------------------------------------------------------------------- /docs/source/DBInterface.rst: -------------------------------------------------------------------------------- 1 | Database Interfaces 2 | =================== 3 | 4 | Classes that interact with the database. All reads and writes should go through this API. 5 | 6 | .. autoclass:: AShareData.DBInterface 7 | 8 | ..
autoclass:: AShareData.MySQLInterface 9 | :members: 10 | -------------------------------------------------------------------------------- /docs/source/DataReader.rst: -------------------------------------------------------------------------------- 1 | DataReader 2 | =================================== 3 | 4 | .. autoclass:: AShareData.AShareDataReader 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/DataSource.rst: -------------------------------------------------------------------------------- 1 | DataSources 2 | =========== 3 | DataSources write data into the database. 4 | 5 | External data sources inherit from ``DataSource``. 6 | They can be used as context managers; logging in and out should be handled in ``.__enter__()`` and ``.__exit__()``. 7 | 8 | Base Class 9 | ----------- 10 | .. autoclass:: AShareData.data_source.DataSource 11 | :members: 12 | 13 | Market Data Implementation 14 | ---------------------------- 15 | Tushare 16 | ^^^^^^^^^^ 17 | .. autoclass:: AShareData.TushareData 18 | :members: 19 | 20 | Wind 21 | ^^^^^^^^^^ 22 | .. autoclass:: AShareData.WindData 23 | :members: 24 | 25 | Join Quant 26 | ^^^^^^^^^^ 27 | .. autoclass:: AShareData.JQData 28 | :members: 29 | 30 | 通达信 31 | ^^^^^^^^^^ 32 | .. autoclass:: AShareData.TDXData 33 | :members: 34 | 35 | Web HTTP request 36 | ^^^^^^^^^^^^^^^^^^^^ 37 | .. autoclass:: AShareData.WebDataCrawler 38 | :members: 39 | 40 | Internally Computed Data 41 | -------------------------- 42 | Base Class 43 | ^^^^^^^^^^^ 44 | .. autoclass:: AShareData.factor_compositor.FactorCompositor 45 | :members: 46 | 47 | Implementation 48 | ^^^^^^^^^^^^^^^^^^^^^^ 49 | .. autoclass:: AShareData.factor_compositor.ConstLimitStockFactorCompositor 50 | :members: 51 | 52 | .. autoclass:: AShareData.factor_compositor.FundAdjFactorCompositor 53 | :members: 54 | 55 | 56 | Index Compositor 57 | """"""""""""""""""""" 58 | ..
autoclass:: AShareData.utils.StockSelectionPolicy 59 | :members: 60 | 61 | .. autoclass:: AShareData.utils.StockIndexCompositionPolicy 62 | :members: 63 | 64 | .. autoclass:: AShareData.factor_compositor.IndexCompositor 65 | :members: 66 | 67 | 68 | Factor Portfolio Return 69 | ^^^^^^^^^^^^^^^^^^^^^^^^^^ 70 | .. autoclass:: AShareData.factor_compositor.FactorPortfolioPolicy 71 | :members: 72 | 73 | .. autoclass:: AShareData.factor_compositor.FactorPortfolio 74 | :members: 75 | -------------------------------------------------------------------------------- /docs/source/DateUtils.rst: -------------------------------------------------------------------------------- 1 | Date-Related Classes and Functions 2 | =================================== 3 | 4 | .. autoclass:: AShareData.SHSZTradingCalendar 5 | :members: 6 | :inherited-members: 7 | -------------------------------------------------------------------------------- /docs/source/Factor.rst: -------------------------------------------------------------------------------- 1 | Factors 2 | =========== 3 | 4 | Factor classes: 5 | 6 | * implement a ``.get_data()`` function that retrieves and/or computes the required data with a (DateTime, ID) MultiIndex 7 | * support common numerical operations (+, -, \*, /), boolean operations (>, >=, ==, <=, <, !=) and transformations (log, pct_change, etc.) 8 | 9 | 基类 10 | ----- 11 | .. autoclass:: AShareData.factor.Factor 12 | :members: 13 | :inherited-members: 14 | 15 | 非财报数据 16 | ----------- 17 | .. autoclass:: AShareData.factor.IndexConstitute 18 | :members: 19 | 20 | .. autoclass:: AShareData.factor.IndustryFactor 21 | :members: 22 | 23 | 行情数据 24 | ^^^^^^^^ 25 | .. autoclass:: AShareData.factor.CompactFactor 26 | :members: 27 | 28 | .. autoclass:: AShareData.factor.OnTheRecordFactor 29 | :members: 30 | 31 | .. autoclass:: AShareData.factor.ContinuousFactor 32 | :members: 33 | 34 | .. 
autoclass:: AShareData.factor.BetaFactor 35 | :members: 36 | 37 | 38 | 财报数据 39 | ------------ 40 | .. 
autoclass:: AShareData.factor.QuarterlyFactor 41 | :members: 42 | 43 | .. autoclass:: AShareData.factor.LatestAccountingFactor 44 | :members: 45 | 46 | .. autoclass:: AShareData.factor.LatestQuarterAccountingFactor 47 | :members: 48 | 49 | .. autoclass:: AShareData.factor.YearlyReportAccountingFactor 50 | :members: 51 | 52 | .. autoclass:: AShareData.factor.QOQAccountingFactor 53 | :members: 54 | 55 | .. autoclass:: AShareData.factor.YOYPeriodAccountingFactor 56 | :members: 57 | 58 | .. autoclass:: AShareData.factor.YOYQuarterAccountingFactor 59 | :members: 60 | 61 | .. autoclass:: AShareData.factor.TTMAccountingFactor 62 | :members: 63 | 64 | -------------------------------------------------------------------------------- /docs/source/Model.rst: -------------------------------------------------------------------------------- 1 | Equity Market Model 2 | ====================== 3 | 4 | Base Class 5 | ---------------------- 6 | .. autoclass:: AShareData.model.model.FinancialModel 7 | :members: 8 | :inherited-members: 9 | 10 | Models and their FactorCompositors 11 | ------------------------------------- 12 | 13 | Capital Asset Pricing Model 14 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 15 | .. autoclass:: AShareData.model.CAPM 16 | :members: 17 | :inherited-members: 18 | 19 | 20 | Fama French 3 Factor Model 21 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 22 | .. autoclass:: AShareData.model.FamaFrench3FactorModel 23 | :members: 24 | :inherited-members: 25 | 26 | .. autoclass:: AShareData.model.SMBandHMLCompositor 27 | :members: 28 | :inherited-members: 29 | 30 | Fama French Carhart 4 Factor Model 31 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 32 | .. autoclass:: AShareData.model.FamaFrenchCarhart4FactorModel 33 | :members: 34 | :inherited-members: 35 | 36 | .. 
autoclass:: AShareData.model.UMDCompositor 37 | :members: 38 | :inherited-members: 39 | -------------------------------------------------------------------------------- /docs/source/Tickers.rst: -------------------------------------------------------------------------------- 1 | Tickers 2 | ======= 3 | 4 | ``Ticker`` classes help to select the tickers you need. They implement 5 | 6 | * ``.all_ticker()`` to get all tickers that belong to that type, alive or dead 7 | * ``.ticker(date)`` to get all tickers available on that date 8 | * ``.list_date()`` returns a dict mapping each ticker to its listing date 9 | * ``.get_list_date(ticker)`` returns ``ticker``'s listing date 10 | 11 | 基类 12 | ----- 13 | .. autoclass:: AShareData.tickers.TickersBase 14 | :members: 15 | 16 | 股票, 股基 17 | ------------- 18 | .. autoclass:: AShareData.tickers.StockTickers 19 | 20 | .. autoclass:: AShareData.tickers.StockFundTickers 21 | .. autoclass:: AShareData.tickers.ExchangeStockETFTickers 22 | .. autoclass:: AShareData.tickers.EnhancedIndexFund 23 | 24 | .. autoclass:: AShareData.tickers.StockOTCFundTickers 25 | .. autoclass:: AShareData.tickers.ActiveManagedOTCStockFundTickers 26 | 27 | 债券, 债基 28 | ------------ 29 | .. autoclass:: AShareData.tickers.ConvertibleBondTickers 30 | .. autoclass:: AShareData.tickers.BondETFTickers 31 | 32 | 基金 33 | ------------ 34 | .. autoclass:: AShareData.tickers.FundTickers 35 | .. autoclass:: AShareData.tickers.ETFTickers 36 | .. autoclass:: AShareData.tickers.ExchangeFundTickers 37 | .. autoclass:: AShareData.tickers.IndexFund 38 | 39 | 衍生品 40 | -------- 41 | .. autoclass:: AShareData.tickers.FutureTickers 42 | .. autoclass:: AShareData.tickers.OptionTickers 43 | .. autoclass:: AShareData.tickers.IndexOptionTickers 44 | .. autoclass:: AShareData.tickers.ETFOptionTickers 45 | 46 | 股票筛选器 47 | ------------------ 48 | .. autoclass:: AShareData.utils.StockSelectionPolicy 49 | :members: 50 | 51 | .. 
autoclass:: AShareData.Tickers.StockTickerSelector 52 | :members: 53 | 54 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(0, os.path.abspath('../../AShareData')) 5 | 6 | # -- Project information ----------------------------------------------------- 7 | project = 'AShareData' 8 | copyright = '2021, Ce Ji' 9 | author = 'Ce Ji' 10 | release = '0.1.0' 11 | # -- General configuration --------------------------------------------------- 12 | extensions = ['IPython.sphinxext.ipython_directive', 13 | 'IPython.sphinxext.ipython_console_highlighting', 14 | 'sphinx.ext.mathjax', 15 | 'sphinx.ext.autodoc', 16 | 'sphinx.ext.inheritance_diagram', 17 | 'sphinx.ext.autosummary', 18 | # 'sphinx.ext.napoleon' 19 | ] 20 | templates_path = ['_templates'] 21 | exclude_patterns = [] 22 | # -- Options for HTML output ------------------------------------------------- 23 | html_theme = 'sphinx_rtd_theme' 24 | html_static_path = ['_static'] 25 | # If false, no module index is generated. 26 | html_use_modindex = True 27 | # If false, no index is generated. 28 | html_use_index = True 29 | 30 | # autodoc 31 | autodoc_type_aliases = {'AShare.DateUtils.DateType': 'DateType'} 32 | autodoc_default_options = { 33 | 'member-order': 'bysource', 34 | } 35 | autodoc_mock_imports = ['WindPy'] 36 | autoclass_content = 'both' 37 | 38 | # autosummary 39 | autosummary_generate = True 40 | autosummary_imported_members = True 41 | 42 | # -- Options for LaTeX output -------------------------------------------------- 43 | latex_paper_size = 'a4' 44 | # The font size ('10pt', '11pt' or '12pt'). 45 | # latex_font_size = '10pt' 46 | 47 | # Grouping the document tree into LaTeX files. List of tuples 48 | # (source start file, target name, title, author, documentclass [howto/manual]). 
49 | latex_documents = [('index',), ] 50 | # For "manual" documents, if this is true, then toplevel headings are parts, 51 | # not chapters. 52 | # latex_use_parts = False 53 | 54 | # Additional stuff for the LaTeX preamble. 55 | # latex_preamble = '' 56 | 57 | # Documents to append as an appendix to all manuals. 58 | # latex_appendices = [] 59 | 60 | # If false, no module index is generated. 61 | # latex_use_modindex = True 62 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to AShareData's documentation! 2 | ====================================== 3 | 4 | Introduction 5 | ============ 6 | 7 | You'll need: 8 | ------------ 9 | - Python (>=3.7) 10 | - an accessible SQL database 11 | - a Tushare account for accounting data (points > 800) 12 | - (Optionally) Wind data service 13 | 14 | Setting up 15 | ---------- 16 | Fill out the ``config.json`` file. An example, ``config_example.json``, can be found in the repo root. 17 | 18 | Components 19 | ---------- 20 | .. 
toctree:: 21 | :maxdepth: 3 22 | 23 | DBInterface 24 | DataSource 25 | Tickers 26 | Factor 27 | DateUtils 28 | DataReader 29 | Model 30 | 31 | Indices and tables 32 | ================== 33 | 34 | * :ref:`genindex` 35 | * :ref:`modindex` 36 | * :ref:`search` 37 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | alphalens 2 | dateutil 3 | empyrical 4 | jqdatasdk 5 | matplotlib 6 | numpy 7 | pandas 8 | pytdx 9 | ratelimiter 10 | requests 11 | retrying 12 | scipy 13 | setuptools 14 | singleton-decorator 15 | sqlalchemy 16 | statsmodels 17 | tqdm 18 | tushare 19 | -------------------------------------------------------------------------------- /scripts/big_names.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | import AShareData as asd 4 | 5 | if __name__ == '__main__': 6 | config_loc = './config.json' 7 | asd.set_global_config(config_loc) 8 | 9 | data_reader = asd.AShareDataReader() 10 | calendar = asd.SHSZTradingCalendar() 11 | date = calendar.yesterday() 12 | 13 | industry = data_reader.industry('申万', 2).get_data(dates=date) 14 | cap = data_reader.stock_market_cap.get_data(dates=date) / 1e8 15 | sec_name = data_reader.sec_name.get_data(dates=date) 16 | 17 | df = pd.concat([sec_name, industry, cap], axis=1).dropna() 18 | df.columns = df.columns[:2].tolist() + ['cap'] 19 | df = df.sort_values('cap', ascending=False) 20 | industry_big_name = df.groupby('申万2级行业').head(3) 21 | big_cap = df.head(300) 22 | all_ticker = pd.concat([industry_big_name, big_cap]).drop_duplicates().sort_index().droplevel('DateTime') 23 | 24 | company_info = asd.get_db_interface().read_table('上市公司基本信息', columns=['所在城市', '主要业务及产品', '经营范围'], 25 | ids=all_ticker.index.get_level_values('ID').tolist()) 26 | ret = pd.concat([all_ticker, company_info], axis=1) 27 | ret.to_excel('big_names.xlsx', 
freeze_panes=(0, 3)) 28 | -------------------------------------------------------------------------------- /scripts/daily_report.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import sys 3 | 4 | import pandas as pd 5 | 6 | import AShareData as asd 7 | 8 | if __name__ == '__main__': 9 | pd.set_option('precision', 2) 10 | 11 | asd.set_global_config(sys.argv[1]) 12 | db_interface = asd.get_db_interface() 13 | pre_date = dt.datetime.combine(dt.date.today(), dt.time()) - dt.timedelta(days=7) 14 | 15 | data = db_interface.read_table('市场汇总', start_date=pre_date) 16 | data['换手率'] = data['成交额'] / data['自由流通市值'] * 100 17 | data.iloc[:, :4] = data.iloc[:, :4] / 1e12 18 | print('市场成交和估值:') 19 | print(data) 20 | 21 | print('') 22 | print('自编指数收益:') 23 | asd.IndexHighlighter().summary() 24 | 25 | print('') 26 | print('主要指数估值:') 27 | print(asd.major_index_valuation()) 28 | 29 | print('') 30 | print('股指贴水情况:') 31 | print(asd.StockIndexFutureBasis().compute()) 32 | 33 | data_reader = asd.AShareDataReader() 34 | model_factor_ret = data_reader.model_factor_return.bind_params(ids=['FF3_SMB_DD', 'FF3_HML_DD', 'Carhart_UMD_DD']) 35 | date = dt.datetime.combine(dt.date.today(), dt.time()) 36 | print('') 37 | print('因子收益率:') 38 | print(pd.concat([data_reader.market_return.get_data(dates=date), model_factor_ret.get_data(dates=date)])) 39 | -------------------------------------------------------------------------------- /scripts/factor_return.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import sys 3 | 4 | import AShareData as asd 5 | from AShareData.factor_compositor import FactorPortfolio, FactorPortfolioPolicy 6 | from AShareData.utils import StockSelectionPolicy 7 | 8 | if __name__ == '__main__': 9 | asd.set_global_config(sys.argv[1]) 10 | 11 | data_reader = asd.AShareDataReader() 12 | stock_selection_policy = StockSelectionPolicy() 13 | 
stock_selection_policy.ignore_new_stock_period = 244 14 | stock_selection_policy.ignore_st = True 15 | stock_selection_policy.ignore_pause = True 16 | 17 | policy = FactorPortfolioPolicy() 18 | policy.bins = [5, 10] 19 | policy.stock_selection_policy = stock_selection_policy 20 | policy.start_date = dt.datetime(2010, 1, 1) 21 | policy.industry = data_reader.industry('申万', 1) 22 | policy.weight = data_reader.stock_free_floating_market_cap 23 | 24 | policy.name = data_reader.beta.name 25 | policy.factor = data_reader.beta 26 | 27 | sub_port = FactorPortfolio(factor_portfolio_policy=policy) 28 | sub_port.update() 29 | -------------------------------------------------------------------------------- /scripts/init.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import AShareData as asd 4 | from update_routine import daily_routine 5 | 6 | if __name__ == '__main__': 7 | config_loc = sys.argv[1] 8 | db_interface = asd.generate_db_interface_from_config(config_loc, init=True) 9 | asd.set_global_config(config_loc) 10 | 11 | with asd.TushareData() as tushare_data: 12 | tushare_data.init_db() 13 | tushare_data.init_accounting_data() 14 | 15 | daily_routine(config_loc) 16 | 17 | asd.model.SMBandHMLCompositor(asd.FamaFrench3FactorModel()).update() 18 | asd.model.UMDCompositor(asd.FamaFrenchCarhart4FactorModel()).update() 19 | -------------------------------------------------------------------------------- /scripts/update_morning_auction.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import sys 3 | 4 | from DingTalkMessageBot import DingTalkMessageBot 5 | 6 | import AShareData as asd 7 | 8 | if __name__ == '__main__': 9 | config_loc = sys.argv[1] 10 | asd.set_global_config(config_loc) 11 | 12 | tushare_crawler = asd.TushareData() 13 | tushare_crawler.get_ipo_info() 14 | 15 | messenger = DingTalkMessageBot.from_config(config_loc, '自闭') 16 | try: 17 | with 
asd.JQData() as jq_data: 18 | date = dt.date.today() 19 | jq_data.stock_open_auction_data(date) 20 | messenger.send_message(f'{date} 集合竞价数据已下载.') 21 | except: 22 | messenger.send_message(f'{date} 集合竞价数据下载失败.') 23 | -------------------------------------------------------------------------------- /scripts/update_routine.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import AShareData as asd 4 | 5 | 6 | def daily_routine(config_loc: str): 7 | asd.set_global_config(config_loc) 8 | 9 | with asd.TushareData() as tushare_crawler: 10 | tushare_crawler.update_base_info() 11 | tushare_crawler.get_shibor() 12 | 13 | tushare_crawler.get_ipo_info() 14 | tushare_crawler.get_company_info() 15 | tushare_crawler.update_hs_holding() 16 | tushare_crawler.get_hs_constitute() 17 | 18 | tushare_crawler.update_stock_names() 19 | tushare_crawler.update_dividend() 20 | 21 | tushare_crawler.update_index_daily() 22 | 23 | tushare_crawler.update_hk_stock_daily() 24 | 25 | tushare_crawler.update_fund_daily() 26 | tushare_crawler.update_fund_dividend() 27 | tushare_crawler.update_financial_data() 28 | 29 | with asd.WindData() as wind_data: 30 | wind_data.update_stock_daily_data() 31 | wind_data.update_stock_adj_factor() 32 | wind_data.update_stock_units() 33 | wind_data.update_industry() 34 | wind_data.update_pause_stock_info() 35 | 36 | wind_data.update_convertible_bond_daily_data() 37 | wind_data.update_cb_convertible_price() 38 | wind_data.update_future_daily_data() 39 | wind_data.update_fund_info() 40 | wind_data.update_stock_option_daily_data() 41 | 42 | with asd.JQData() as jq_data: 43 | jq_data.update_stock_morning_auction_data() 44 | 45 | with asd.TDXData() as tdx_data: 46 | # tdx_data.update_stock_minute() 47 | tdx_data.update_convertible_bond_minute() 48 | 49 | # compute data 50 | asd.ConstLimitStockFactorCompositor().update() 51 | asd.NegativeBookEquityListingCompositor().update() 52 | asd.IndexUpdater().update() 53 | 
asd.MarketSummaryCompositor().update() 54 | 55 | # model data 56 | asd.model.SMBandHMLCompositor().update() 57 | asd.model.UMDCompositor().update() 58 | 59 | 60 | if __name__ == '__main__': 61 | daily_routine(sys.argv[1]) 62 | -------------------------------------------------------------------------------- /scripts/wind_stock_rt.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from AShareData import set_global_config, WindData 4 | 5 | if __name__ == '__main__': 6 | config_loc = sys.argv[1] 7 | set_global_config(config_loc) 8 | 9 | with WindData() as wind_data: 10 | wind_data.get_stock_rt_price() 11 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | 3 | from setuptools import find_packages, setup 4 | 5 | here = path.abspath(path.dirname(__file__)) 6 | with open(path.join(here, 'README.md'), encoding='utf-8') as f: 7 | long_description = f.read() 8 | 9 | setup( 10 | name='AShareData', 11 | version='0.1.0', 12 | description='Gather data for A share and store in MySQL database', 13 | long_description=long_description, 14 | long_description_content_type='text/markdown', 15 | url='https://github.com/jicewarwick/AShareData', 16 | author='Ce Ji', 17 | author_email='Mr.Ce.Ji@outlook.com', 18 | classifiers=[ 19 | 'Development Status :: 1 - Planning', 20 | 'Intended Audience :: Developers', 21 | 'Topic :: Office/Business :: Financial :: Investment', 22 | 'License :: OSI Approved :: MIT License', 23 | 'Programming Language :: Python :: 3.7', 24 | ], 25 | keywords='tushare mysql', 26 | packages=find_packages(exclude=['docs', 'tests']), 27 | python_requires='>=3.5, <4', 28 | install_requires=['numpy', 29 | 'pandas', 30 | 'tushare', 31 | 'sqlalchemy', 32 | 'tqdm', 33 | 'requests', 34 | ], 35 | package_data={ 36 | 'json': ['data/*'], 37 | }, 38 | entry_points={ 39 | 
'console_scripts': [ 40 | 'sample=update_routine:main', 41 | ], 42 | }, 43 | project_urls={ 44 | 'Bug Reports': 'https://github.com/jicewarwick/AShareData/issues', 45 | 'Source': 'https://github.com/jicewarwick/AShareData', 46 | }, 47 | ) 48 | -------------------------------------------------------------------------------- /tests/analysis_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from AShareData import * 4 | from AShareData.analysis.fund_nav_analysis import * 5 | from AShareData.analysis.holding import * 6 | from AShareData.analysis.public_fund_holding import * 7 | # from AShareData.analysis.trading import * 8 | from AShareData.analysis.return_analysis import * 9 | from AShareData.factor import ContinuousFactor 10 | 11 | 12 | class MyTestCase(unittest.TestCase): 13 | def setUp(self) -> None: 14 | set_global_config('config.json') 15 | self.target = ContinuousFactor('自合成指数', '收益率') 16 | self.target.bind_params(ids='ST.IND') 17 | self.benchmark = ContinuousFactor('自合成指数', '收益率') 18 | self.benchmark.bind_params(ids='全市场.IND') 19 | self.start = dt.datetime(2012, 1, 1) 20 | self.end = dt.datetime(2020, 1, 1) 21 | 22 | def test_max_drawdown(self): 23 | returns = self.target.get_data(start_date=self.start, end_date=self.end).unstack().iloc[:, 0] 24 | print(locate_max_drawdown(returns)) 25 | returns = self.benchmark.get_data().unstack().iloc[:, 0] 26 | print(locate_max_drawdown(returns)) 27 | 28 | def test_aggregate_return(self): 29 | print(aggregate_returns(target=self.target, convert_to='monthly', benchmark_factor=self.benchmark)) 30 | 31 | @staticmethod 32 | def test_holding(): 33 | h = FundHolding() 34 | date = dt.datetime(2021, 3, 8) 35 | print(h.get_holding(date)) 36 | print(h.get_holding(date, fund='指增1号 - 东财 - 普通户')) 37 | print(h.get_holding(date, fund='ALL')) 38 | 39 | def test_fund_nav_analysis(self): 40 | fund_nav_analysis = FundNAVAnalysis('110011.OF') 41 | 
fund_nav_analysis.compute_correlation('399006.SZ') 42 | model = FamaFrench3FactorModel() 43 | fund_nav_analysis.compute_exposure(model) 44 | fund_nav_analysis.get_latest_published_portfolio_holding() 45 | 46 | def test_public_fund_holding(self): 47 | ticker = '000001.SZ' 48 | date = dt.datetime(2020, 12, 31) 49 | rec = PublicFundHoldingRecords(ticker, date) 50 | self = rec 51 | 52 | 53 | if __name__ == '__main__': 54 | unittest.main() 55 | -------------------------------------------------------------------------------- /tests/ashare_datareader_test.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import unittest 3 | 4 | from AShareData.ashare_data_reader import AShareDataReader 5 | from AShareData.config import set_global_config 6 | 7 | 8 | class MyTestCase(unittest.TestCase): 9 | def setUp(self) -> None: 10 | set_global_config('config.json') 11 | self.db = AShareDataReader() 12 | self.start_date = dt.datetime(2018, 5, 10) 13 | self.end_date = dt.datetime(2018, 7, 10) 14 | self.ids = ['000001.SZ', '600000.SH', '000002.SZ'] 15 | self.dates = [self.start_date, self.end_date] 16 | 17 | def test_calendar(self): 18 | print(self.db.calendar.calendar) 19 | 20 | def test_adj_factor(self): 21 | print(self.db.adj_factor.get_data(start_date=self.start_date, end_date=self.end_date, ids=self.ids)) 22 | print(self.db.adj_factor.get_data(start_date=self.start_date, ids=self.ids)) 23 | print(self.db.adj_factor.get_data(end_date=self.end_date, ids=self.ids)) 24 | print(self.db.adj_factor.get_data(dates=self.dates, ids=self.ids)) 25 | 26 | def test_stocks(self): 27 | print(self.db.stocks) 28 | 29 | def test_get_sec_name(self): 30 | start_date = dt.date(2018, 5, 10) 31 | print(self.db.sec_name.get_data(dates=start_date)) 32 | 33 | def test_industry(self): 34 | start_date = dt.date(2018, 5, 10) 35 | print(self.db.industry('中信', 3).get_data(dates=start_date)) 36 | print(self.db.industry('中证', 
3).get_data(dates=start_date)) 37 | 38 | def test_index_constitute(self): 39 | print(self.db.index_constitute.get_data('000300.SH', '20201130')) 40 | 41 | def test_ttm(self): 42 | print(self.db.earning_ttm.get_data(dates=self.dates, ids=self.ids)) 43 | print(self.db.stock_market_cap.get_data(dates=self.dates, ids=self.ids)) 44 | print(self.db.pe_ttm.get_data(dates=self.dates, ids=self.ids)) 45 | print(self.db.pb_after_close.get_data(dates=self.dates, ids=self.ids)) 46 | 47 | def test_cap_weight(self): 48 | print(self.db.free_floating_cap_weight.get_data(dates=[self.start_date, self.end_date], ids=self.ids)) 49 | 50 | 51 | if __name__ == '__main__': 52 | unittest.main() 53 | -------------------------------------------------------------------------------- /tests/calendar_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from AShareData.config import set_global_config 4 | from AShareData.date_utils import * 5 | 6 | 7 | class MyTestCase(unittest.TestCase): 8 | def setUp(self) -> None: 9 | set_global_config('config.json') 10 | self.calendar = SHSZTradingCalendar() 11 | 12 | def test_is_trading_day(self): 13 | self.assertFalse(self.calendar.is_trading_date(dt.date(2019, 10, 1))) 14 | self.assertTrue(self.calendar.is_trading_date(dt.date(2019, 10, 16))) 15 | 16 | def test_days_count(self): 17 | start = dt.datetime(2019, 1, 4) 18 | end = dt.datetime(2019, 1, 7) 19 | self.assertEqual(self.calendar.days_count(start, end), 1) 20 | self.assertEqual(self.calendar.days_count(end, start), -1) 21 | self.assertEqual(self.calendar.days_count(start, start), 0) 22 | 23 | self.assertEqual(self.calendar.days_count(dt.datetime(2015, 9, 30), dt.datetime(2015, 10, 8)), 1) 24 | self.assertEqual(self.calendar.days_count(dt.datetime(2015, 10, 1), dt.datetime(2015, 10, 8)), 1) 25 | 26 | def test_first_day_of_month(self): 27 | start = dt.datetime(2019, 3, 2) 28 | end = dt.datetime(2019, 4, 2) 29 | 
self.assertEqual(self.calendar.first_day_of_month(start, end)[0], dt.datetime(2019, 4, 1)) 30 | 31 | def test_last_day_of_month(self): 32 | start = dt.datetime(2019, 3, 2) 33 | end = dt.datetime(2019, 4, 2) 34 | self.assertEqual(self.calendar.last_day_of_month(start, end)[0], dt.datetime(2019, 3, 29)) 35 | 36 | def test_last_day_of_year(self): 37 | start = dt.datetime(2018, 3, 2) 38 | end = dt.datetime(2019, 4, 2) 39 | self.assertEqual(self.calendar.last_day_of_year(start, end)[0], dt.datetime(2018, 12, 28)) 40 | 41 | def test_select_dates(self): 42 | start = dt.datetime(2019, 9, 2) 43 | end = dt.datetime(2019, 9, 3) 44 | self.assertEqual(self.calendar.select_dates(start, end), [start, end]) 45 | 46 | start = dt.datetime(2020, 11, 2) 47 | end = dt.datetime(2020, 11, 7) 48 | dates = self.calendar.select_dates(start, end) 49 | self.assertEqual(dates[0], start) 50 | self.assertEqual(dates[-1], dt.datetime(2020, 11, 6)) 51 | 52 | start = dt.datetime(2020, 11, 1) 53 | end = dt.datetime(2020, 11, 6) 54 | dates = self.calendar.select_dates(start, end) 55 | self.assertEqual(dates[0], dt.datetime(2020, 11, 2)) 56 | self.assertEqual(dates[-1], dt.datetime(2020, 11, 6)) 57 | 58 | start = dt.datetime(2020, 11, 1) 59 | end = dt.datetime(2020, 11, 7) 60 | dates = self.calendar.select_dates(start, end) 61 | self.assertEqual(dates[0], dt.datetime(2020, 11, 2)) 62 | self.assertEqual(dates[-1], dt.datetime(2020, 11, 6)) 63 | 64 | def test_offset(self): 65 | start_date = dt.datetime(2020, 11, 2) 66 | self.assertEqual(self.calendar.offset(start_date, 1), dt.datetime(2020, 11, 3)) 67 | self.assertEqual(self.calendar.offset(start_date, 0), dt.datetime(2020, 11, 2)) 68 | self.assertEqual(self.calendar.offset(start_date, -1), dt.datetime(2020, 10, 30)) 69 | 70 | start_date = dt.datetime(2020, 11, 1) 71 | self.assertEqual(self.calendar.offset(start_date, 1), dt.datetime(2020, 11, 2)) 72 | self.assertEqual(self.calendar.offset(start_date, 0), dt.datetime(2020, 11, 2)) 73 | 
self.assertEqual(self.calendar.offset(start_date, -1), dt.datetime(2020, 10, 30)) 74 | 75 | def test_begin_and_end(self): 76 | self.assertEqual(self.calendar.month_begin(2021, 3), dt.datetime(2021, 3, 1)) 77 | self.assertEqual(self.calendar.month_begin(2021, 1), dt.datetime(2021, 1, 4)) 78 | self.assertEqual(self.calendar.month_begin(2020, 10), dt.datetime(2020, 10, 9)) 79 | 80 | self.assertEqual(self.calendar.month_end(2021, 3), dt.datetime(2021, 3, 31)) 81 | self.assertEqual(self.calendar.month_end(2021, 1), dt.datetime(2021, 1, 29)) 82 | self.assertEqual(self.calendar.month_end(2020, 1), dt.datetime(2020, 1, 23)) 83 | 84 | self.assertEqual(self.calendar.pre_month_end(2021, 4), dt.datetime(2021, 3, 31)) 85 | self.assertEqual(self.calendar.pre_month_end(2021, 2), dt.datetime(2021, 1, 29)) 86 | self.assertEqual(self.calendar.pre_month_end(2020, 2), dt.datetime(2020, 1, 23)) 87 | 88 | @staticmethod 89 | def test_format_dt(): 90 | @dtlize_input_dates 91 | def func(date, dates=None): 92 | print(date) 93 | print(dates) 94 | 95 | func(dt.date(2000, 1, 1), dates=dt.date(2010, 1, 1)) 96 | 97 | def test_report_date_offset(self): 98 | self.assertEqual(ReportingDate.quarterly_offset(dt.datetime(2020, 3, 31), -1), dt.datetime(2019, 12, 31)) 99 | self.assertEqual(ReportingDate.quarterly_offset(dt.datetime(2020, 3, 31), -2), dt.datetime(2019, 9, 30)) 100 | self.assertEqual(ReportingDate.quarterly_offset(dt.datetime(2020, 3, 31), -3), dt.datetime(2019, 6, 30)) 101 | self.assertEqual(ReportingDate.quarterly_offset(dt.datetime(2020, 3, 31), -4), dt.datetime(2019, 3, 31)) 102 | self.assertEqual(ReportingDate.quarterly_offset(dt.datetime(2020, 3, 31), -5), dt.datetime(2018, 12, 31)) 103 | 104 | self.assertEqual(ReportingDate.offset(dt.datetime(2020, 3, 31), 'q1'), dt.datetime(2019, 12, 31)) 105 | self.assertEqual(ReportingDate.offset(dt.datetime(2020, 3, 31), 'y1'), dt.datetime(2019, 12, 31)) 106 | 107 | 108 | if __name__ == '__main__': 109 | unittest.main() 110 | 
-------------------------------------------------------------------------------- /tests/db_interface_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import datetime as dt 3 | 4 | from AShareData.config import get_db_interface, set_global_config 5 | from AShareData.date_utils import date_type2datetime 6 | 7 | 8 | class MyTestCase(unittest.TestCase): 9 | def setUp(self) -> None: 10 | set_global_config('config.json') 11 | self.db_interface = get_db_interface() 12 | 13 | def test_read_data(self): 14 | table_name = '合并资产负债表' 15 | factor_name = '期末总股本' 16 | start_date = date_type2datetime('20190101') 17 | end_date = date_type2datetime('20190101') 18 | report_period = date_type2datetime('20181231') 19 | print(self.db_interface.read_table(table_name, factor_name).head()) 20 | print(self.db_interface.read_table(table_name, factor_name, start_date=start_date, end_date=end_date).head()) 21 | print(self.db_interface.read_table(table_name, factor_name, start_date=start_date).head()) 22 | print(self.db_interface.read_table(table_name, factor_name, report_period=report_period).head()) 23 | 24 | def test_calendar(self): 25 | self.db_interface.read_table('交易日历') 26 | 27 | def test_db_timestamp(self): 28 | table_name = '合并资产负债表' 29 | print(self.db_interface.get_latest_timestamp(table_name)) 30 | table_name = '模型因子日收益率' 31 | print(self.db_interface.get_latest_timestamp(table_name)) 32 | print(self.db_interface.get_latest_timestamp(table_name, default_ts=dt.datetime(2021, 3, 4))) 33 | 34 | 35 | if __name__ == '__main__': 36 | unittest.main() 37 | -------------------------------------------------------------------------------- /tests/factor_compositor_test.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import unittest 3 | 4 | import AShareData.date_utils 5 | from AShareData.config import set_global_config 6 | from AShareData.factor_compositor import * 7 
| 8 | 9 | class MyTestCase(unittest.TestCase): 10 | def setUp(self) -> None: 11 | set_global_config('config.json') 12 | self.factor_compositor = FactorCompositor() 13 | 14 | def test_market_return(self): 15 | ticker: str = '000001.IND' 16 | ignore_new_stock_period: dt.timedelta = dt.timedelta(days=252) 17 | unit_base: str = '自由流通股本' 18 | start_date: AShareData.date_utils.DateType = dt.datetime(1999, 5, 4) 19 | 20 | 21 | if __name__ == '__main__': 22 | unittest.main() 23 | -------------------------------------------------------------------------------- /tests/factor_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from AShareData import set_global_config 4 | from AShareData.factor import * 5 | from AShareData.tickers import * 6 | from AShareData.utils import StockSelectionPolicy 7 | 8 | 9 | class MyTestCase(unittest.TestCase): 10 | def setUp(self) -> None: 11 | set_global_config('config.json') 12 | self.db_interface = get_db_interface() 13 | self.calendar = SHSZTradingCalendar() 14 | self.start_date = dt.datetime(2002, 3, 1) 15 | self.end_date = dt.datetime(2002, 3, 30) 16 | self.ids = ['000001.SZ', '000002.SZ'] 17 | self.close = ContinuousFactor('股票日行情', '收盘价', self.db_interface) 18 | self.adj = CompactFactor('复权因子', self.db_interface) 19 | 20 | def test_compact_record_factor(self): 21 | compact_factor = CompactFactor('证券名称', self.db_interface) 22 | compact_factor.data = compact_factor.data.map(lambda x: 'PT' in x or 'ST' in x or '退' in x) 23 | compact_record_factor = CompactRecordFactor(compact_factor, 'ST') 24 | print(compact_record_factor.get_data(date=dt.datetime(2015, 5, 15))) 25 | 26 | def test_compact_factor(self): 27 | compact_factor = CompactFactor('证券名称', self.db_interface) 28 | print(compact_factor.get_data(dates=dt.datetime(2015, 5, 15))) 29 | policy = StockSelectionPolicy(select_st=True) 30 | print(compact_factor.get_data(dates=dt.datetime(2015, 5, 15), 
ticker_selector=StockTickerSelector(policy))) 31 | 32 | def test_industry(self): 33 | print('') 34 | industry_factor = IndustryFactor('中信', 3, self.db_interface) 35 | print(industry_factor.list_constitutes(dt.datetime(2019, 1, 7), '白酒')) 36 | print('') 37 | print(industry_factor.all_industries) 38 | 39 | def test_pause_stocks(self): 40 | pause_stock = OnTheRecordFactor('股票停牌', self.db_interface) 41 | start_date = dt.datetime(2021, 1, 1) 42 | end_date = dt.datetime(2021, 2, 4) 43 | print(pause_stock.get_data(date=end_date)) 44 | print(pause_stock.get_counts(start_date=start_date, end_date=end_date, ids=self.ids + ['000662.SZ'])) 45 | 46 | def test_latest_accounting_factor(self): 47 | f = LatestAccountingFactor('期末总股本', self.db_interface) 48 | a = f.get_data(start_date=self.start_date, end_date=self.end_date, ids=self.ids) 49 | print(a) 50 | 51 | def test_latest_quarter_report_factor(self): 52 | f = LatestQuarterAccountingFactor('期末总股本', self.db_interface) 53 | a = f.get_data(start_date=self.start_date, end_date=self.end_date, ids=self.ids) 54 | print(a) 55 | 56 | def test_yearly_report_factor(self): 57 | f = YearlyReportAccountingFactor('期末总股本', self.db_interface) 58 | ids = list(set(self.ids) - {'600087.SH', '600788.SH', '600722.SH'}) 59 | a = f.get_data(start_date=self.start_date, end_date=self.end_date, ids=ids) 60 | print(a) 61 | 62 | def test_qoq_report_factor(self): 63 | f = QOQAccountingFactor('期末总股本', self.db_interface) 64 | a = f.get_data(start_date=self.start_date, end_date=self.end_date, ids=self.ids) 65 | print(a) 66 | 67 | def test_yoy_period_report_factor(self): 68 | f = YOYPeriodAccountingFactor('期末总股本', self.db_interface) 69 | a = f.get_data(start_date=self.start_date, end_date=self.end_date, ids=self.ids) 70 | print(a) 71 | 72 | def test_yoy_quarter_factor(self): 73 | f = YOYQuarterAccountingFactor('期末总股本', self.db_interface) 74 | a = f.get_data(start_date=self.start_date, end_date=self.end_date, ids=self.ids) 75 | print(a) 76 | 77 | def 
test_ttm_factor(self): 78 | f = TTMAccountingFactor('期末总股本', self.db_interface) 79 | a = f.get_data(start_date=self.start_date, end_date=self.end_date, ids=self.ids) 80 | print(a) 81 | 82 | def test_index_constitute(self): 83 | index_constitute = IndexConstitute(self.db_interface) 84 | print(index_constitute.get_data('000300.SH', '20200803')) 85 | 86 | def test_sum_factor(self): 87 | sum_hfq = self.close + self.adj 88 | sum_hfq_close_data = sum_hfq.get_data(start_date=self.start_date, end_date=self.end_date, ids=self.ids) 89 | print(sum_hfq_close_data) 90 | uni_sum = self.close + 1 91 | print(uni_sum.get_data(start_date=self.start_date, end_date=self.end_date, ids=self.ids)) 92 | 93 | def test_mul_factor(self): 94 | hfq = self.close * self.adj 95 | hfq_close_data = hfq.get_data(start_date=self.start_date, end_date=self.end_date, ids=self.ids) 96 | print(hfq_close_data) 97 | 98 | def test_factor_pct_change(self): 99 | hfq = self.close * self.adj 100 | hfq_chg = hfq.pct_change() 101 | pct_chg_data = hfq_chg.get_data(start_date=self.start_date, end_date=self.end_date) 102 | print(pct_chg_data) 103 | 104 | def test_factor_max(self): 105 | f = self.adj.max() 106 | f_max = f.get_data(start_date=self.start_date, end_date=self.end_date, ids=self.ids) 107 | print(f_max) 108 | 109 | def test_beta_factor(self): 110 | ids: Union[str, Sequence[str]] = ['000001.SZ', '600000.SH'] 111 | dates: Sequence[dt.datetime] = [dt.datetime(2020, 1, 15), dt.datetime(2020, 5, 13)] 112 | look_back_period: int = 60 113 | min_trading_days: int = 40 114 | 115 | policy = StockSelectionPolicy(ignore_new_stock_period=365, ignore_st=True) 116 | ticker_selector = StockTickerSelector(policy) 117 | 118 | beta_factor = BetaFactor(db_interface=self.db_interface) 119 | print(beta_factor.get_data(dates, ids, look_back_period=look_back_period, min_trading_days=min_trading_days)) 120 | print(beta_factor.get_data(dates, ticker_selector=ticker_selector, look_back_period=look_back_period, 121 | 
min_trading_days=min_trading_days)) 122 | 123 | def test_interest_rate(self): 124 | print('') 125 | interest_rate = InterestRateFactor('shibor利率数据', '6个月', self.db_interface).set_factor_name('6个月shibor') 126 | start_date = dt.datetime(2021, 1, 1) 127 | end_date = dt.datetime(2021, 3, 1) 128 | data = interest_rate.get_data(start_date=start_date, end_date=end_date) 129 | print(data) 130 | 131 | def test_mean_and_average(self): 132 | print(self.close.mean('DateTime').get_data(start_date=self.start_date, end_date=self.end_date, ids=self.ids)) 133 | print(self.close.mean('ID').get_data(start_date=self.start_date, end_date=self.end_date, ids=self.ids)) 134 | print(self.close.sum('DateTime').get_data(start_date=self.start_date, end_date=self.end_date, ids=self.ids)) 135 | print(self.close.sum('ID').get_data(start_date=self.start_date, end_date=self.end_date, ids=self.ids)) 136 | 137 | def test_diff(self): 138 | print(self.close.diff().get_data(start_date=self.start_date, end_date=self.end_date, ids=self.ids)) 139 | 140 | def test_latest_update_factor(self): 141 | latest_update_factor = LatestUpdateFactor('场外基金规模', '资产净值', self.db_interface) 142 | print(latest_update_factor.get_data(ids=['008864.OF', '000001.OF'])) 143 | print(latest_update_factor.get_data(ids='008864.OF')) 144 | 145 | 146 | if __name__ == '__main__': 147 | unittest.main() 148 | -------------------------------------------------------------------------------- /tests/industry_comparison_test.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import unittest 3 | 4 | from AShareData import get_db_interface, IndustryComparison, set_global_config 5 | 6 | 7 | class MyTestCase(unittest.TestCase): 8 | def setUp(self) -> None: 9 | set_global_config('config.json') 10 | db_interface = get_db_interface() 11 | self.industry_obj = IndustryComparison(index='000905.SH', industry_provider='中信', industry_level=2) 12 | 13 | def test_something(self): 14 | holding 
= self.industry_obj.import_holding('holding.xlsx', date=dt.datetime(2020, 12, 18)) 15 | print(self.industry_obj.holding_comparison(holding)) 16 | 17 | 18 | if __name__ == '__main__': 19 | unittest.main() 20 | -------------------------------------------------------------------------------- /tests/jq_data_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import AShareData as asd 4 | 5 | 6 | class MyTestCase(unittest.TestCase): 7 | def setUp(self) -> None: 8 | asd.set_global_config('config.json') 9 | self.jq_data = asd.JQData() 10 | 11 | def test_jq_login(self): 12 | pass 13 | 14 | 15 | if __name__ == '__main__': 16 | unittest.main() 17 | -------------------------------------------------------------------------------- /tests/model_test.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import unittest 3 | 4 | from AShareData import set_global_config 5 | from AShareData.model import * 6 | 7 | 8 | class MyTestCase(unittest.TestCase): 9 | def setUp(self) -> None: 10 | set_global_config('config.json') 11 | 12 | def test_something(self): 13 | self.assertEqual(True, False) 14 | 15 | @staticmethod 16 | def test_FF3factor_return(): 17 | model = FamaFrench3FactorModel() 18 | smb = SMBandHMLCompositor(model) 19 | date = dt.datetime(2021, 3, 9) 20 | pre_date = dt.datetime(2021, 3, 8) 21 | pre_month_date = dt.datetime(2021, 2, 26) 22 | smb.compute_factor_return(balance_date=pre_date, pre_date=pre_date, date=date, 23 | rebalance_marker='D', period_marker='D') 24 | smb.compute_factor_return(balance_date=pre_month_date, pre_date=pre_date, date=date, 25 | rebalance_marker='M', period_marker='D') 26 | smb.compute_factor_return(balance_date=pre_month_date, pre_date=pre_month_date, date=date, 27 | rebalance_marker='M', period_marker='M') 28 | 29 | @staticmethod 30 | def test_FFC4_factor_return(): 31 | model = FamaFrenchCarhart4FactorModel() 32 | umd = 
UMDCompositor(model) 33 | date = dt.datetime(2021, 3, 9) 34 | pre_date = dt.datetime(2021, 3, 8) 35 | pre_month_date = dt.datetime(2021, 2, 26) 36 | umd.compute_factor_return(balance_date=pre_date, pre_date=pre_date, date=date, 37 | rebalance_marker='D', period_marker='D') 38 | umd.compute_factor_return(balance_date=pre_month_date, pre_date=pre_date, date=date, 39 | rebalance_marker='M', period_marker='D') 40 | umd.compute_factor_return(balance_date=pre_month_date, pre_date=pre_month_date, date=date, 41 | rebalance_marker='M', period_marker='M') 42 | 43 | 44 | if __name__ == '__main__': 45 | unittest.main() 46 | -------------------------------------------------------------------------------- /tests/plot_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from AShareData.config import set_global_config 4 | from AShareData.factor import ContinuousFactor 5 | from AShareData.plot import plot_index 6 | 7 | 8 | class MyTestCase(unittest.TestCase): 9 | def setUp(self) -> None: 10 | set_global_config('config.json') 11 | 12 | def test_plot_factor_portfolio_return(self): 13 | factor_name = 'Beta' 14 | weight = True 15 | industry_neutral = True 16 | bins = 5 17 | start_date = None 18 | end_date = None 19 | db_interface = None 20 | 21 | def test_plot_index(self): 22 | index_factor = ContinuousFactor('自合成指数', '收益率') 23 | index_factor.bind_params(ids='ST.IND') 24 | benchmark_factor = ContinuousFactor('自合成指数', '收益率') 25 | benchmark_factor.bind_params(ids='全市场.IND') 26 | start_date = None 27 | end_date = None 28 | ids = 'ST.IND' 29 | plot_index(index_factor) 30 | plot_index(index_factor, benchmark_factor=benchmark_factor) 31 | 32 | 33 | if __name__ == '__main__': 34 | unittest.main() 35 | -------------------------------------------------------------------------------- /tests/portfolio_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from AShareData 
import set_global_config, SHSZTradingCalendar 4 | from AShareData.analysis.holding import * 5 | from AShareData.model.fama_french_3_factor_model import FamaFrench3FactorModel 6 | from AShareData.portfolio_analysis import * 7 | from AShareData.tickers import StockTickerSelector 8 | from AShareData.utils import StockSelectionPolicy 9 | 10 | 11 | class MyTestCase(unittest.TestCase): 12 | def setUp(self): 13 | set_global_config('config.json') 14 | self.portfolio_analysis = ASharePortfolioAnalysis() 15 | self.data_reader = self.portfolio_analysis.data_reader 16 | 17 | def test_summary_statistics(self): 18 | start_date = dt.date(2008, 1, 1) 19 | end_date = dt.date(2020, 1, 1) 20 | price = self.data_reader.get_factor('股票日行情', '收盘价', start_date=start_date, end_date=end_date) 21 | self.portfolio_analysis.summary_statistics(price) 22 | 23 | 24 | class CrossSectionTesting(unittest.TestCase): 25 | def setUp(self): 26 | set_global_config('config.json') 27 | self.data_reader = AShareDataReader() 28 | forward_return = self.data_reader.forward_return 29 | factors = self.data_reader.log_cap 30 | ticker_selector = StockTickerSelector(StockSelectionPolicy()) 31 | market_cap = self.data_reader.stock_free_floating_market_cap 32 | start_date = dt.datetime(2020, 8, 1) 33 | end_date = dt.datetime(2021, 2, 1) 34 | dates = SHSZTradingCalendar().first_day_of_month(start_date, end_date) 35 | 36 | self.t = CrossSectionalPortfolioAnalysis(forward_return, factors=factors, dates=dates, market_cap=market_cap, 37 | ticker_selector=ticker_selector) 38 | self.t.cache() 39 | 40 | def test_single_sort(self): 41 | self.t.single_factor_sorting('BM') 42 | self.t.returns_results(cap_weighted=True) 43 | self.t.returns_results(cap_weighted=False) 44 | self.t.summary_statistics('BM') 45 | 46 | def test_independent_double_sort(self): 47 | self.t.two_factor_sorting(factor_names=('BM', '市值对数'), quantile=10, separate_neg_vals=True, independent=True) 48 | self.t.returns_results(cap_weighted=True) 49 | 
self.t.returns_results(cap_weighted=False) 50 | self.t.summary_statistics('BM') 51 | 52 | def test_dependent_double_sort(self): 53 | self.t.two_factor_sorting(factor_names=('BM', '市值对数'), quantile=10, separate_neg_vals=True, independent=False) 54 | self.t.returns_results(cap_weighted=True) 55 | self.t.returns_results(cap_weighted=False) 56 | self.t.summary_statistics('BM') 57 | 58 | 59 | class PortfolioExposureTest(unittest.TestCase): 60 | def setUp(self): 61 | set_global_config('config.json') 62 | 63 | @staticmethod 64 | def test_case(): 65 | date = dt.datetime(2021, 3, 8) 66 | model = FamaFrench3FactorModel() 67 | exposure = ASharePortfolioExposure(model=model) 68 | ticker = '000002.SZ' 69 | fh = FundHolding() 70 | portfolio_weight = fh.portfolio_stock_weight(date, 'ALL') 71 | exposure.get_stock_exposure(ticker) 72 | exposure.get_portfolio_exposure(portfolio_weight) 73 | 74 | 75 | if __name__ == '__main__': 76 | unittest.main() 77 | -------------------------------------------------------------------------------- /tests/test_algo.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from AShareData.algo import * 3 | from AShareData.utils import * 4 | 5 | 6 | class MyTestCase(unittest.TestCase): 7 | @staticmethod 8 | def test_get_less_or_equal_of_a_in_b(): 9 | a = [5, 9, 16, 25, 60] 10 | a2 = [-2] 11 | a3 = [-2, 54] 12 | a4 = [] 13 | b = list(range(20)) + [24, 55, 56] 14 | print(get_less_or_equal_of_a_in_b(a, b)) 15 | print(get_less_or_equal_of_a_in_b(a2, b)) 16 | print(get_less_or_equal_of_a_in_b(a3, b)) 17 | print(get_less_or_equal_of_a_in_b(a4, b)) 18 | 19 | def test_is_stock_ticker(self): 20 | self.assertEqual(get_stock_board_name('000001.SZ'), '主板') 21 | self.assertEqual(get_stock_board_name('001979.SZ'), '主板') 22 | self.assertEqual(get_stock_board_name('600000.SH'), '主板') 23 | self.assertEqual(get_stock_board_name('605500.SH'), '主板') 24 | 25 | self.assertEqual(get_stock_board_name('002594.SZ'), '中小板') 26 
| self.assertEqual(get_stock_board_name('300498.SZ'), '创业板') 27 | self.assertEqual(get_stock_board_name('688688.SH'), '科创板') 28 | 29 | self.assertEqual(get_stock_board_name('0196.HK'), '非股票') 30 | self.assertEqual(get_stock_board_name('IF1208.CFE'), '非股票') 31 | self.assertEqual(get_stock_board_name('300498'), '非股票') 32 | self.assertEqual(get_stock_board_name(300498), '非股票') 33 | 34 | 35 | if __name__ == '__main__': 36 | unittest.main() 37 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import unittest 3 | 4 | from AShareData import set_global_config 5 | from AShareData.model.fama_french_3_factor_model import FamaFrench3FactorModel 6 | 7 | 8 | class MyTestCase(unittest.TestCase): 9 | @staticmethod 10 | def test_ff_model(): 11 | set_global_config('config.json') 12 | date = dt.datetime(2020, 3, 3) 13 | model = FamaFrench3FactorModel() 14 | # self = model 15 | print(model.compute_daily_factor_return(date)) 16 | 17 | 18 | if __name__ == '__main__': 19 | unittest.main() 20 | -------------------------------------------------------------------------------- /tests/ticker_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from AShareData.config import set_global_config 4 | from AShareData.tickers import * 5 | from AShareData.utils import StockSelectionPolicy 6 | 7 | 8 | class MyTestCase(unittest.TestCase): 9 | def setUp(self) -> None: 10 | set_global_config('config.json') 11 | self.db_interface = get_db_interface() 12 | 13 | @staticmethod 14 | def ticker_test(ticker_obj): 15 | ticker_obj.all_ticker() 16 | tickers = (ticker_obj.ticker(dt.date(2020, 9, 30))) 17 | print(tickers) 18 | print(len(tickers)) 19 | 20 | def test_stock_ticker(self): 21 | stock_ticker = StockTickers(self.db_interface) 22 | self.ticker_test(stock_ticker) 23 | 
stock_ticker.get_list_date('000001.SZ') 24 | 25 | start_date = dt.datetime(2018, 1, 1) 26 | end_date = dt.datetime(2018, 12, 1) 27 | print(stock_ticker.new_ticker(start_date=start_date, end_date=end_date)) 28 | 29 | def test_future_ticker(self): 30 | future_ticker = FutureTickers(self.db_interface) 31 | self.ticker_test(future_ticker) 32 | 33 | def test_etf_option_ticker(self): 34 | etf_option_ticker = ETFOptionTickers(self.db_interface) 35 | self.ticker_test(etf_option_ticker) 36 | 37 | def test_etf_ticker(self): 38 | etf_ticker = ETFTickers(self.db_interface) 39 | self.ticker_test(etf_ticker) 40 | 41 | def test_stock_etf_ticker(self): 42 | stock_etf = ExchangeStockETFTickers(self.db_interface) 43 | self.ticker_test(stock_etf) 44 | 45 | def test_bond_etf_ticker(self): 46 | stock_etf = BondETFTickers(self.db_interface) 47 | self.ticker_test(stock_etf) 48 | 49 | def test_active_stock_ticker(self): 50 | ticker = ActiveManagedStockFundTickers(True, self.db_interface) 51 | self.ticker_test(ticker) 52 | 53 | def test_exchange_fund_ticker(self): 54 | ticker = ExchangeFundTickers(self.db_interface) 55 | self.ticker_test(ticker) 56 | 57 | def test_option_ticker(self): 58 | ticker = OptionTickers(self.db_interface) 59 | self.ticker_test(ticker) 60 | 61 | def test_ticker_selection(self): 62 | policy = StockSelectionPolicy() 63 | policy.ignore_new_stock_period = 360 64 | policy.select_st = False 65 | policy.max_pause_days = (2, 5) 66 | selector = StockTickerSelector(policy=policy, db_interface=self.db_interface) 67 | dates = [dt.datetime(2020, 1, 7), dt.datetime(2020, 12, 28)] 68 | ret = selector.generate_index(dates=dates) 69 | print(ret) 70 | 71 | def test_new_ticker_selection(self): 72 | policy = StockSelectionPolicy() 73 | policy.ignore_new_stock_period = 60 74 | policy.select_new_stock_period = 90 75 | policy.select_st = False 76 | selector = StockTickerSelector(policy=policy, db_interface=self.db_interface) 77 | dates = [dt.datetime(2020, 1, 7), dt.datetime(2020, 12, 
28)] 78 | ret = selector.generate_index(dates=dates) 79 | print(ret) 80 | 81 | 82 | if __name__ == '__main__': 83 | unittest.main() 84 | -------------------------------------------------------------------------------- /tests/tools_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from AShareData import set_global_config 4 | from AShareData.tools.tools import IndexHighlighter, MajorIndustryConstitutes 5 | 6 | 7 | class MyTestCase(unittest.TestCase): 8 | @staticmethod 9 | def test_major_industry_constitute(): 10 | set_global_config('config.json') 11 | provider = '申万' 12 | level = 2 13 | name = '景点' 14 | obj = MajorIndustryConstitutes(provider=provider, level=level) 15 | print(obj.get_major_constitute(name)) 16 | 17 | @staticmethod 18 | def test_index_highlighter(): 19 | set_global_config('config.json') 20 | obj = IndexHighlighter() 21 | obj.summary() 22 | 23 | 24 | if __name__ == '__main__': 25 | unittest.main() 26 | -------------------------------------------------------------------------------- /tests/tushare2mysql_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from AShareData import set_global_config, TushareData 4 | 5 | 6 | class Tushare2MySQLTest(unittest.TestCase): 7 | def setUp(self) -> None: 8 | set_global_config('config.json') 9 | self.downloader = TushareData() 10 | 11 | def test_calendar(self): 12 | print(self.downloader.calendar.calendar) 13 | 14 | def test_financial(self): 15 | self.downloader.get_financial('300146.SZ') 16 | 17 | def test_index(self): 18 | self.downloader.get_index_daily() 19 | 20 | def test_ipo_info(self): 21 | self.downloader.get_ipo_info() 22 | 23 | def test_all_past_names(self): 24 | self.downloader.init_stock_names() 25 | 26 | def test_past_names(self): 27 | self.downloader.update_stock_names() 28 | 29 | def test_company_info(self): 30 | self.downloader.get_company_info() 31 | 32 | def 
test_daily_hq(self): 33 | self.downloader.get_daily_hq(start_date='2010917') 34 | 35 | def test_all_dividend(self): 36 | self.downloader.get_all_dividend() 37 | 38 | def test_routine(self): 39 | # self.downloader.update_routine() 40 | pass 41 | 42 | def test_hs_const(self): 43 | self.downloader.get_hs_constitute() 44 | 45 | def test_shibor(self): 46 | self.downloader.get_shibor(end_date='20111010') 47 | 48 | def test_index_weight(self): 49 | self.downloader.get_index_weight(start_date='20050101') 50 | 51 | 52 | if __name__ == '__main__': 53 | unittest.main() 54 | -------------------------------------------------------------------------------- /tests/web_data_test.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import unittest 3 | 4 | import AShareData as asd 5 | 6 | 7 | class WebDataSourceTest(unittest.TestCase): 8 | def setUp(self) -> None: 9 | asd.set_global_config('config.json') 10 | self.web_crawler = asd.WebDataCrawler() 11 | self.calendar = asd.SHSZTradingCalendar() 12 | 13 | def test_sw_industry(self): 14 | self.web_crawler.get_sw_industry() 15 | 16 | def test_zx_industry(self): 17 | self.web_crawler.get_zz_industry(self.calendar.offset(dt.date.today(), -1)) 18 | 19 | 20 | if __name__ == '__main__': 21 | unittest.main() 22 | -------------------------------------------------------------------------------- /tests/wind_data_test.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import unittest 3 | 4 | from AShareData import constants, set_global_config 5 | from AShareData.data_source.wind_data import WindData, WindWrapper 6 | 7 | 8 | class MyTestCase(unittest.TestCase): 9 | def setUp(self) -> None: 10 | config_loc = 'config.json' 11 | set_global_config(config_loc) 12 | self.wind_data = WindData.from_config(config_loc) 13 | 14 | def test_get_industry_func(self): 15 | wind_code = '000019.SZ' 16 | start_date = '20161212' 17 | end_date 
= '20190905' 18 | provider = '中证' 19 | start_data = '软饮料' 20 | end_data = '食品经销商' 21 | print(self.wind_data._find_industry(wind_code, provider, start_date, start_data, end_date, end_data)) 22 | 23 | def test_update_zz_industry(self): 24 | self.wind_data.update_industry('中证') 25 | 26 | def test_update_sw_industry(self): 27 | self.wind_data.update_industry('申万') 28 | 29 | def test_update_wind_industry(self): 30 | self.wind_data.update_industry('Wind') 31 | 32 | def test_minutes_data(self): 33 | self.assertRaises(AssertionError, self.wind_data.get_stock_minutes_data, '20191001') 34 | # print(self.wind_data.get_minutes_data('20161017')) 35 | 36 | def test_update_minutes_data(self): 37 | self.wind_data.update_stock_minutes_data() 38 | 39 | def test_stock_daily_data(self): 40 | self.wind_data.get_stock_daily_data(trade_date=dt.date(2019, 12, 27)) 41 | 42 | 43 | class WindWrapperTestCase(unittest.TestCase): 44 | def setUp(self) -> None: 45 | self.w = WindWrapper() 46 | self.w.connect() 47 | 48 | def test_wsd(self): 49 | stock = '000001.SZ' 50 | stocks = ['000001.SZ', '000002.SZ'] 51 | start_date = dt.datetime(2019, 10, 23) 52 | end_date = dt.datetime(2019, 10, 24) 53 | indicator = 'high' 54 | indicators = 'high,low' 55 | provider = '中信' 56 | 57 | print(self.w.wsd(stock, indicator, start_date, start_date, '')) 58 | print(self.w.wsd(stock, indicators, start_date, start_date, '')) 59 | print(self.w.wsd(stocks, indicator, start_date, start_date, '')) 60 | print(self.w.wsd(stock, indicator, start_date, end_date, '')) 61 | print(self.w.wsd(stock, indicators, start_date, end_date, '')) 62 | print(self.w.wsd(stocks, indicator, start_date, end_date, '')) 63 | 64 | print(self.w.wsd(stocks, f'industry_{constants.INDUSTRY_DATA_PROVIDER_CODE_DICT[provider]}', 65 | start_date, end_date, industryType=constants.INDUSTRY_LEVEL[provider])) 66 | 67 | def test_wss(self): 68 | # data = self.w.wss(['000001.SZ', '000002.SZ', '000005.SZ'], ['SHARE_RTD_STATE', 'SHARE_RTD_STATEJUR'], 69 | # 
trade_date='20190715', unit='1') 70 | # print('\n') 71 | # print(data) 72 | 73 | data = self.w.wss(['000001.SZ', '000002.SZ', '000005.SZ'], 'open,low,high,close,volume,amt', 74 | date='20190715', 75 | priceAdj='U', cycle='D') 76 | print('\n') 77 | print(data) 78 | 79 | # data = self.w.wss("000001.SH,000002.SZ", "grossmargin,operateincome", "unit=1;rptDate=20191231") 80 | # print('\n') 81 | # print(data) 82 | 83 | def test_wset(self): 84 | data = self.w.wset('futurecc', startdate='2019-07-29', enddate='2020-07-29', wind_code='A.DCE') 85 | print('\n') 86 | print(data) 87 | 88 | start_date = dt.date(2020, 6, 30).strftime('%Y-%m-%d') 89 | end_date = dt.date(2020, 7, 30).strftime('%Y-%m-%d') 90 | exchange = 'sse' 91 | wind_code = '510050.SH' 92 | status = 'all' 93 | field = 'wind_code,trade_code,sec_name' 94 | data = self.w.wset('optioncontractbasicinfo', options=f'field={field}', startdate=start_date, enddate=end_date, 95 | status=status, windcode=wind_code, exchange=exchange) 96 | print('\n') 97 | print(data) 98 | 99 | def test_wsq(self): 100 | data = self.w.wsq('002080.SZ,000002.SZ', 'rt_latest,rt_vol') 101 | print('\n') 102 | print(data) 103 | data = self.w.wsq('000002.SZ', 'rt_latest,rt_vol') 104 | print('\n') 105 | print(data) 106 | data = self.w.wsq('002080.SZ,000002.SZ', 'rt_latest') 107 | print('\n') 108 | print(data) 109 | data = self.w.wsq('000002.SZ', 'rt_latest') 110 | print('\n') 111 | print(data) 112 | 113 | def test_index_constitute(self): 114 | hs300_constitute = self.w.get_index_constitute(index='000300.SH') 115 | print(hs300_constitute) 116 | 117 | 118 | if __name__ == '__main__': 119 | unittest.main() 120 | --------------------------------------------------------------------------------