├── .gitignore ├── Cheetah.py ├── StockBlock.py ├── cheetah.iml ├── common ├── Constants.py ├── __init__.py └── downloader.py ├── cron.py ├── db.py ├── pyalgo.py ├── runStrategy.py ├── stocks.py ├── strategy ├── SMACrossOver.py └── __init__.py ├── test.py ├── testcases ├── __init__.py └── utils_test.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | db 3 | *.im 4 | *.pyc 5 | *.log -------------------------------------------------------------------------------- /Cheetah.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding:utf-8 3 | 4 | import stocks 5 | import tushare as ts 6 | 7 | stock_list = ts.get_hist_data('000521', start='2015-04-01', end='2016-04-18') 8 | stocks.MACD(stock_list, slow_period=55) 9 | 10 | for row in stock_list.iterrows(): 11 | print "row:",row -------------------------------------------------------------------------------- /StockBlock.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding:utf-8 3 | import tushare as ts 4 | import db 5 | from common.Constants import * 6 | from common import downloader 7 | from bs4 import BeautifulSoup 8 | from pandas import DataFrame 9 | 10 | try: 11 | import cPickle as pickle 12 | except ImportError: 13 | import pickle 14 | 15 | import sys 16 | 17 | reload(sys) 18 | sys.setdefaultencoding('utf-8') 19 | 20 | concept_url = "http://quote.eastmoney.com/center/BKList.html#notion_0_0?sortRule=0" 21 | industry_url = "http://quote.eastmoney.com/center/BKList.html#trade_0_0?sortRule=0" 22 | area_url = "http://quote.eastmoney.com/center/BKList.html#area_0_0?sortRule=0" 23 | 24 | class StockBlock(object): 25 | industry = "industry" # 行业 26 | concept = "concept" # 概念 27 | area = "area" # 地域 28 | sme = "sme" # 中小板 29 | gem = "gem" # 创业板 30 | st = "st" # 风险警示板 31 | hs300s = "hs300s" # 沪深300成份及权重 32 | sz50s = "sz50s" # 上证50成份股 33 | zz500s = "zz500s" # 中证500成份股 34 | 35 | def __init__(self): 36 | pass 37 | 38 | def parse(self, url): 39 | soup = BeautifulSoup(downloader.download(url), 'lxml').find_all(id='bklist') 40 | table = soup[0] 41 | tr = table.find_all('tr') 42 | bkmc = [] 43 | zdf = [] 44 | zsz = [] 45 | hsl = [] 46 | szjs = [] 47 | xdjs = [] 48 | for i in range(len(tr)): 49 | line = tr[i] 50 | td = line.find_all('td') 51 | if td: 52 | # print "id=%s,板块名称=%s,涨跌幅=%s,总市值(亿)=%s,换手率=%s,上涨家数=%s,下跌家数=%s" % \ 53 | # (td[0].get_text(), td[1].get_text(), td[3].get_text(), td[4].get_text(), td[5].get_text(), 54 | # td[6].get_text(), td[7].get_text()) 55 | bkmc.append(td[1].get_text()) 56 | zdf.append(float(td[3].get_text().replace('%', ''))) 57 | zsz.append(float(td[4].get_text())) 58 | hsl.append(float(td[5].get_text().replace('%', ''))) 59 | szjs.append(int(td[6].get_text())) 60 | xdjs.append(int(td[7].get_text())) 61 | 62 | data = {"板块名称": bkmc, "涨跌幅": zdf, "总市值(亿)": zsz, "换手率": hsl, "上涨家数": szjs, "下跌家数": xdjs} 63 | return DataFrame(data, columns=['板块名称', '涨跌幅', '总市值(亿)', '换手率', '上涨家数', '下跌家数']) 64 | 65 | def list_detail(self, stock_block_type): 66 | """ 67 | 实时获取股票板块详细数据 68 | :param stock_block_type: 69 | :return: 70 | """ 71 | if stock_block_type == self.industry: 72 | return self.parse(industry_url) 73 | elif stock_block_type == self.concept: 74 | return self.parse(concept_url) 75 | elif stock_block_type == self.area: 76 | return self.parse(area_url) 77 | else: 78 | return None 79 | 80 | def list(self, stock_block_type): 81 | stock_block = None 82 | if stock_block_type == self.industry: 83 | stock_block = db.get(STOCK_BLOCK_INDUSTRY) 84 | if stock_block is None: 85 | stock_block = ts.get_industry_classified() 86 | db.save(STOCK_BLOCK_INDUSTRY, stock_block) 87 | elif stock_block_type == self.concept: 88 | stock_block = db.get(STOCK_BLOCK_CONCEPT) 89 | if stock_block is None: 90 | stock_block = ts.get_concept_classified() 91 | db.save(STOCK_BLOCK_CONCEPT, stock_block) 92 | elif stock_block_type == self.area: 93 | stock_block = db.get(STOCK_BLOCK_AREA) 94 | if stock_block is None: 95 | stock_block = ts.get_area_classified() 96 | db.save(STOCK_BLOCK_AREA, stock_block) 97 | elif stock_block_type == self.sme: 98 | stock_block = db.get(STOCK_BLOCK_SME) 99 | if stock_block is None: 100 | stock_block = ts.get_sme_classified() 101 | db.save(STOCK_BLOCK_SME, stock_block) 102 | elif stock_block_type == self.gem: 103 | stock_block = db.get(STOCK_BLOCK_GEM) 104 | if stock_block is None: 105 | stock_block = ts.get_gem_classified() 106 | db.save(STOCK_BLOCK_GEM, stock_block) 107 | elif stock_block_type == self.st: 108 | stock_block = db.get(STOCK_BLOCK_ST) 109 | if stock_block is None: 110 | stock_block = ts.get_st_classified() 111 | db.save(STOCK_BLOCK_ST, stock_block) 112 | elif stock_block_type == self.hs300s: 113 | stock_block = db.get(STOCK_BLOCK_HS300S) 114 | if stock_block is None: 115 | stock_block = ts.get_hs300s() 116 | db.save(STOCK_BLOCK_HS300S, stock_block) 117 | elif stock_block_type == self.sz50s: 118 | stock_block = db.get(STOCK_BLOCK_SZ50S) 119 | if stock_block is None: 120 | stock_block = ts.get_sz50s() 121 | db.save(STOCK_BLOCK_SZ50S, stock_block) 122 | elif stock_block_type == self.zz500s: 123 | stock_block = db.get(STOCK_BLOCK_ZZ500S) 124 | if stock_block is None: 125 | stock_block = ts.get_zz500s() 126 | db.save(STOCK_BLOCK_ZZ500S, stock_block) 127 | else: 128 | return None 129 | return stock_block 130 | 131 | @staticmethod 132 | def preload(): 133 | stock_block = ts.get_industry_classified() 134 | db.save(STOCK_BLOCK_INDUSTRY, stock_block) 135 | stock_block = ts.get_concept_classified() 136 | db.save(STOCK_BLOCK_CONCEPT, stock_block) 137 | stock_block = ts.get_area_classified() 138 | db.save(STOCK_BLOCK_AREA, stock_block) 139 | stock_block = ts.get_sme_classified() 140 | db.save(STOCK_BLOCK_SME, stock_block) 141 | stock_block = ts.get_gem_classified() 142 | db.save(STOCK_BLOCK_GEM, stock_block) 143 | stock_block = ts.get_st_classified() 144 | db.save(STOCK_BLOCK_ST, stock_block) 145 | stock_block = ts.get_hs300s() 146 | db.save(STOCK_BLOCK_HS300S, stock_block) 147 | stock_block = ts.get_sz50s() 148 | db.save(STOCK_BLOCK_SZ50S, stock_block) 149 | stock_block = ts.get_zz500s() 150 | db.save(STOCK_BLOCK_ZZ500S, stock_block) 151 | 152 | 153 | if __name__ == '__main__': 154 | # StockBlock.preload() 155 | sb = StockBlock() 156 | sblist = sb.list_detail(sb.concept) 157 | # for line in sblist: 158 | # print line 159 | print sblist['板块名称'] 160 | 161 | -------------------------------------------------------------------------------- /cheetah.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /common/Constants.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #coding:utf-8 3 | 4 | DATA_EXPIRED_TIME = 86400 5 | 6 | # 数据库key名称 7 | STOCK_BLOCK_INDUSTRY = "stockblock:industry" 8 | STOCK_BLOCK_CONCEPT = "stockblock:concept" 9 | STOCK_BLOCK_AREA = "stockblock:area" 10 | STOCK_BLOCK_SME = "stockblock:sem" 11 | STOCK_BLOCK_GEM = "stockblock:gem" 12 | STOCK_BLOCK_ST = "stockblock:st" 13 | STOCK_BLOCK_HS300S = "stockblock:hs300s" 14 | STOCK_BLOCK_SZ50S = "stockblock:sz50s" 15 | STOCK_BLOCK_ZZ500S = "stockblock:zz500s" 16 | -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #coding:utf-8 3 | -------------------------------------------------------------------------------- /common/downloader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #coding:utf-8 3 | 4 | from selenium import webdriver 5 | 6 | 7 | def download(url): 8 | driver = webdriver.PhantomJS() 9 | driver.get(url) 10 | html = driver.page_source 11 | driver.close() 12 | return html 13 | -------------------------------------------------------------------------------- /cron.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding:utf-8 3 | 4 | import time 5 | import datetime 6 | import stocks 7 | import tushare as ts 8 | import utils 9 | 10 | # target_list = [] 11 | target_list = ['000521', '002067', '000430'] 12 | 13 | 14 | while True: 15 | if utils.is_working_hour(): 16 | columns = ['name', 'close', 'change'] 17 | df = ts.get_index() 18 | print df[columns][df.code == '000001'] 19 | print df[columns][df.code == '399005'] 20 | 21 | columns = ['time', 'name', 'price', 'b1_v', 'b1_p','a1_v','a1_p'] 22 | for stock in target_list: 23 | print stocks.realtime(stock)[columns] 24 | print '-------' 25 | time.sleep(30) 26 | 27 | # def print_index(index): 28 | # print index 29 | # 30 | # 31 | # starttime = datetime.datetime.now() 32 | # 33 | # stock_list = stocks.list_stock() 34 | # for stock in stock_list.index: 35 | # stocks.filter(stock) 36 | # 37 | # endtime = datetime.datetime.now() 38 | # print "cost:", (endtime - starttime).seconds 39 | -------------------------------------------------------------------------------- /db.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding:utf-8 3 | 4 | import tushare as ts 5 | import redis 6 | import datetime 7 | import os 8 | from common.Constants import * 9 | 10 | db_name = "stock" 11 | db_path = "/home/eryk/workspaces/cheetah/db/" 12 | db_suffix = datetime.datetime.now().strftime("%Y-%m-%d") 13 | 14 | try: 15 | import cPickle as pickle 16 | except ImportError: 17 | import pickle 18 | 19 | 20 | def get_conn(): 21 | # return redis.Redis(host='112.124.60.26', port=6399) 22 | return redis.Redis(host='127.0.0.1', port=6379) 23 | 24 | 25 | def save_db(stock_list): 26 | r = get_conn() 27 | r.set('stock_list', pickle.dumps(stock_list)) 28 | 29 | 30 | def save(key, value, expire=DATA_EXPIRED_TIME): 31 | r = get_conn() 32 | r.set(key, value, expire) 33 | 34 | 35 | def exist(key): 36 | r = get_conn() 37 | return r.exists(key) 38 | 39 | 40 | def get(key): 41 | return get_conn().get(key) 42 | 43 | 44 | def load_db(): 45 | r = get_conn() 46 | return pickle.loads(r.get('stock_list')) 47 | 48 | 49 | def get_db_file_name(): 50 | return db_path + db_name + "_" + db_suffix 51 | 52 | 53 | def save_file(stock_list): 54 | f = open(get_db_file_name(), 'wb') 55 | pickle.dump(stock_list, f) 56 | f.close() 57 | 58 | 59 | def load_file(): 60 | db_full_path = get_db_file_name() 61 | if os.path.exists(db_full_path): 62 | if os.path.getsize(db_full_path) == 0: 63 | stock_list = ts.get_stock_basics() 64 | save_file(stock_list) 65 | else: 66 | stock_list = ts.get_stock_basics() 67 | save_file(stock_list) 68 | f = open(db_full_path, 'rb') 69 | objects = pickle.load(f) 70 | f.close 71 | return objects 72 | 73 | 74 | if __name__ == "__main__": 75 | new_stock_list = load_file() 76 | print new_stock_list 77 | -------------------------------------------------------------------------------- /pyalgo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding:utf-8 3 | 4 | from pyalgotrade import strategy 5 | from pyalgotrade.barfeed import yahoofeed 6 | from pyalgotrade.technical import cross, ma, macd, rsi 7 | 8 | 9 | class MyStrategy(strategy.BacktestingStrategy): 10 | def __init__(self, feed, instrument): 11 | strategy.BacktestingStrategy.__init__(self, feed) 12 | # We want a 15 period SMA over the closing prices. 13 | self.__sma = ma.SMA(feed[instrument].getCloseDataSeries(), 15) 14 | self.__rsi = rsi.RSI(feed[instrument].getCloseDataSeries(), 14) 15 | self.__macd = macd.MACD(feed[instrument].getCloseDataSeries(), 12, 55, 9).getSignal() 16 | self.__instrument = instrument 17 | 18 | def onBars(self, bars): 19 | bar = bars[self.__instrument] 20 | # self.info("%s\t%s\t%s\t%s" % (bar.getClose(), self.__sma[-1], self.__rsi[-1], self.__macd[-1])) 21 | 22 | # Load the yahoo feed from the CSV file 23 | feed = yahoofeed.Feed() 24 | feed.addBarsFromCSV("orcl", "orcl-2015.csv") 25 | print cross.cross_above(feed['orcl'].getCloseDataSeries(),feed['orcl'].getHighDataSeries()) 26 | 27 | # Evaluate the strategy with the feed's bars. 28 | myStrategy = MyStrategy(feed, "orcl") 29 | myStrategy.run() -------------------------------------------------------------------------------- /runStrategy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding:utf-8 3 | 4 | from pyalgotrade.barfeed import yahoofeed 5 | from pyalgotrade.stratanalyzer import drawdown 6 | from pyalgotrade.stratanalyzer import returns 7 | from pyalgotrade.stratanalyzer import sharpe 8 | from pyalgotrade.stratanalyzer import trades 9 | 10 | # Load the yahoo feed from the CSV file 11 | from strategy import SMACrossOver 12 | 13 | feed = yahoofeed.Feed() 14 | feed.addBarsFromCSV("orcl", "orcl-2015.csv") 15 | # Evaluate the strategy with the feed's bars. 16 | myStrategy = SMACrossOver(feed, "orcl", 20) 17 | # Attach different analyzers to a strategy before executing it. 18 | retAnalyzer = returns.Returns() 19 | myStrategy.attachAnalyzer(retAnalyzer) 20 | sharpeRatioAnalyzer = sharpe.SharpeRatio() 21 | myStrategy.attachAnalyzer(sharpeRatioAnalyzer) 22 | drawDownAnalyzer = drawdown.DrawDown() 23 | myStrategy.attachAnalyzer(drawDownAnalyzer) 24 | tradesAnalyzer = trades.Trades() 25 | myStrategy.attachAnalyzer(tradesAnalyzer) 26 | # Run the strategy. 27 | myStrategy.run() 28 | print "Final portfolio value: $%.2f" % myStrategy.getResult() 29 | print "Cumulative returns: %.2f %%" % (retAnalyzer.getCumulativeReturns()[-1] * 100) 30 | print "Sharpe ratio: %.2f" % (sharpeRatioAnalyzer.getSharpeRatio(0.05)) 31 | print "Max. drawdown: %.2f %%" % (drawDownAnalyzer.getMaxDrawDown() * 100) 32 | print "Longest drawdown duration: %s" % (drawDownAnalyzer.getLongestDrawDownDuration()) 33 | print 34 | print "Total trades: %d" % (tradesAnalyzer.getCount()) 35 | if tradesAnalyzer.getCount() > 0: 36 | profits = tradesAnalyzer.getAll() 37 | print "Avg. profit: $%2.f" % (profits.mean()) 38 | print "Profits std. dev.: $%2.f" % (profits.std()) 39 | print "Max. profit: $%2.f" % (profits.max()) 40 | print "Min. profit: $%2.f" % (profits.min()) 41 | returns = tradesAnalyzer.getAllReturns() 42 | print "Avg. return: %2.f %%" % (returns.mean() * 100) 43 | print "Returns std. dev.: %2.f %%" % (returns.std() * 100) 44 | print "Max. return: %2.f %%" % (returns.max() * 100) 45 | print "Min. return: %2.f %%" % (returns.min() * 100) 46 | print 47 | print "Profitable trades: %d" % (tradesAnalyzer.getProfitableCount()) 48 | if tradesAnalyzer.getProfitableCount() > 0: 49 | profits = tradesAnalyzer.getProfits() 50 | print "Avg. profit: $%2.f" % (profits.mean()) 51 | print "Profits std. dev.: $%2.f" % (profits.std()) -------------------------------------------------------------------------------- /stocks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding:utf-8 3 | 4 | import tushare as ts 5 | import pandas as pd 6 | import numpy as np 7 | import talib 8 | import matplotlib.pyplot as plt 9 | import datetime 10 | import numpy, array 11 | import scipy 12 | import utils 13 | import db 14 | 15 | stock_list = db.load_file() 16 | 17 | 18 | def MACD(stock_data, fast_period=12, slow_period=26, signal_period=9): 19 | closes = stock_data.sort_index().close.values 20 | macd = talib.MACD(closes, fast_period, slow_period, signal_period) 21 | for i in range(len(macd[0])): 22 | macd[2][i] = (utils.f(macd[0][i]) - utils.f(macd[1][i])) * 2 23 | stock_data['dif'] = macd[0][::-1] 24 | stock_data['dea'] = macd[1][::-1] 25 | stock_data['macd'] = macd[2][::-1] 26 | return stock_data 27 | 28 | 29 | def MA(stock_data, timeperiod=30): 30 | closes = stock_data.sort_index().close.values 31 | ma = talib.EMA(closes, timeperiod) 32 | stock_data['ma' + str(timeperiod)] = ma[::-1] 33 | return stock_data 34 | 35 | 36 | def golden_cross(metric_first, metric_second): 37 | arr = [] 38 | for i in range(len(metric_first)): 39 | if metric_first[i] is not None and metric_second[i] is not None: 40 | if metric_first[i] < metric_second[i]: 41 | arr[i] = 0 42 | else: 43 | arr[i] = 1 44 | 45 | 46 | def list_stock(): 47 | stocks = ts.get_stock_basics() 48 | stocks = stocks[(stocks.pe < 100) & (stocks.pe > 0) & (stocks.totalAssets < 300000)] 49 | return stocks 50 | 51 | 52 | def filter(index, day=60): 53 | start_date = utils.get_start_date(day).strftime("%Y-%m-%d") 54 | stop_date = datetime.date.today().strftime("%Y-%m-%d") 55 | stock_data = ts.get_hist_data(index, start=start_date, end=stop_date) 56 | 57 | is_high = False 58 | # if stock_data.head(20).max().p_change > 9.5: 59 | # is_high = True 60 | # else: 61 | # return 62 | 63 | for stock in stock_data.head(3).itertuples(): 64 | if stock.open <= min(stock.ma5, stock.ma10, stock.ma20) \ 65 | and stock.close >= max(stock.ma5, stock.ma10, stock.ma20): 66 | print index 67 | print stock_data.head(1) 68 | 69 | # mean = stock_data.mean() 70 | 71 | # if 5 > mean.p_change > 1 and mean.turnover < 10 and mean.close * 1.2 >= stock_data.ix[0].close: 72 | # print "%s,p_change=%f,turnover=%f,close=%f,1.2close=%0.2f %0.2f,isHigh=%s" % \ 73 | # (index, mean.p_change, mean.turnover, mean.close, mean.close * 1.2, stock_data.ix[0].close, is_high) 74 | # print stock_data.describe() 75 | 76 | 77 | def realtime(symbol): 78 | return ts.get_realtime_quotes(symbol) 79 | 80 | 81 | def tick_today(symbol): 82 | return ts.get_today_ticks(symbol) 83 | 84 | 85 | def tick_history(symbol, date): 86 | df = ts.get_tick_data(symbol, date) 87 | return df 88 | 89 | 90 | if __name__ == "__main__": 91 | # stocks = list_stock((stocks.pe < 100) & (stocks.pe > 0) & (stocks.totalAssets < 300000)) 92 | # print utils.is_working_day(datetime.datetime.now()) 93 | # count = 0 94 | # for stock in stocks.index: 95 | # val = ts.get_realtime_quotes(stock) 96 | # if float(val.price) < 15 and float(val.price) > 0: 97 | # print stock, float(val.price) 98 | # count += 1; 99 | # print len(stocks),count 100 | 101 | # macd = MACD('300145') 102 | # print macd 103 | # get_basic('300415') 104 | # condition() 105 | 106 | # stock = ts.get_hist_data('000521', start='2010-04-01', end='2016-04-19') 107 | # stock = MA(stock, 30) 108 | # print stock.head(5) 109 | pass -------------------------------------------------------------------------------- /strategy/SMACrossOver.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding:utf-8 3 | 4 | from pyalgotrade import strategy 5 | from pyalgotrade.technical import ma 6 | from pyalgotrade.technical import cross 7 | 8 | 9 | class SMACrossOver(strategy.BacktestingStrategy): 10 | def __init__(self, feed, instrument, smaPeriod): 11 | strategy.BacktestingStrategy.__init__(self, feed) 12 | self.__instrument = instrument 13 | self.__position = None 14 | # We'll use adjusted close values instead of regular close values. 15 | self.setUseAdjustedValues(True) 16 | self.__prices = feed[instrument].getPriceDataSeries() 17 | self.__sma = ma.SMA(self.__prices, smaPeriod) 18 | 19 | def getSMA(self): 20 | return self.__sma 21 | 22 | def onEnterCanceled(self, position): 23 | self.__position = None 24 | 25 | def onExitOk(self, position): 26 | self.__position = None 27 | 28 | def onExitCanceled(self, position): 29 | # If the exit was canceled, re-submit it. 30 | self.__position.exitMarket() 31 | 32 | def onBars(self, bars): 33 | # If a position was not opened, check if we should enter a long position. 34 | if self.__position is None: 35 | if cross.cross_above(self.__prices, self.__sma) > 0: 36 | shares = int(self.getBroker().getCash() * 0.9 / bars[self.__instrument].getPrice()) 37 | # Enter a buy market order. The order is good till canceled. 38 | self.__position = self.enterLong(self.__instrument, shares, True) 39 | # Check if we have to exit the position. 40 | elif not self.__position.exitActive() and cross.cross_below(self.__prices, self.__sma) > 0: 41 | self.__position.exitMarket() -------------------------------------------------------------------------------- /strategy/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #coding:utf-8 3 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding:utf-8 3 | 4 | import matplotlib.pyplot as plt 5 | import tushare as ts 6 | import stocks 7 | 8 | # stock = ts.get_hist_data('000521', start='2015-04-01', end='2016-04-12') 9 | # stock.sort_index(ascending=True).close.plot() 10 | # stock.sort_index(ascending=True).ma5.plot() 11 | # plt.show() 12 | 13 | # dict={"name":"python","english":33,"math":35} 14 | # 15 | # print "##for in " 16 | # for i in dict: 17 | # print "dict[%s]=" % i,dict[i] 18 | # 19 | # print "##items" 20 | # for (k,v) in dict.items(): 21 | # print "dict[%s]=" % k,v 22 | # 23 | # print "##iteritems" 24 | # for k,v in dict.iteritems(): 25 | # print "dict[%s]=" % k,v 26 | 27 | df = stocks.tick_today('000521') 28 | desc = df.describe() 29 | mid = desc.volume['std'] * 3 30 | big = desc.volume['std'] * 6 31 | print "---" 32 | print df[(df.volume > mid) & (df.volume < big)] 33 | print "---" 34 | print df[df.volume > big] -------------------------------------------------------------------------------- /testcases/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #coding:utf-8 3 | -------------------------------------------------------------------------------- /testcases/utils_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding:utf-8 3 | import unittest 4 | from utils import * 5 | 6 | 7 | class TestUtils(unittest.TestCase): 8 | def setUp(self): 9 | print 'setUp...' 10 | 11 | def tearDown(self): 12 | print 'tearDown...' 13 | 14 | def test_is_working_hour(self): 15 | print "now is working hour? ", is_working_hour(datetime.datetime.now()) 16 | # "2016-04-16 11:30" 17 | self.assertTrue(is_working_hour(datetime.datetime(2016, 4, 8, 11, 30))) 18 | # "2016-04-08 11:29" 19 | self.assertTrue(is_working_hour(datetime.datetime(2016, 4, 8, 11, 29))) 20 | # "2016-04-08 09:14" 21 | self.assertFalse(is_working_hour(datetime.datetime(2016, 4, 8, 9, 14))) 22 | # "2016-04-08 09:15" 23 | self.assertTrue(is_working_hour(datetime.datetime(2016, 4, 8, 9, 15))) 24 | # "2016-04-08 13:00" 25 | self.assertTrue(is_working_hour(datetime.datetime(2016, 4, 8, 13, 0))) 26 | # "2016-04-08 12:59" 27 | self.assertFalse(is_working_hour(datetime.datetime(2016, 4, 8, 12, 59))) 28 | # "2016-04-08 15:00" 29 | self.assertTrue(is_working_hour(datetime.datetime(2016, 4, 8, 15, 0))) 30 | # "2016-04-08 15:01" 31 | self.assertFalse(is_working_hour(datetime.datetime(2016, 4, 8, 15, 1))) 32 | 33 | def test_is_working_day(self): 34 | self.assertFalse(is_working_day(datetime.datetime(2016, 4, 16))) 35 | self.assertTrue(is_working_day(datetime.datetime(2016, 4, 15))) 36 | 37 | def test_get_page(self): 38 | doc = parse_page("http://quote.eastmoney.com/center/BKList.html#notion_0_0?sortRule=0") 39 | print doc 40 | 41 | 42 | if __name__ == '__main__': 43 | unittest.main() 44 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding:utf-8 3 | import datetime 4 | import urllib2 5 | 6 | from bs4 import BeautifulSoup 7 | 8 | 9 | def f(num): 10 | """ 11 | 结果四舍五入保留两位小数 12 | 13 | :param num: 待格式化数字 14 | :return: 15 | """ 16 | return round(float(num), 2) 17 | 18 | 19 | def get_start_date(n): 20 | """ 21 | 获取n天前的日期 22 | 23 | now = datetime.now() 24 | print now.strftime('%Y-%m-%d') 25 | 26 | strNow = '2012-01-03' 27 | nowDate = time.strptime(strNow, "%Y-%m-%d") 28 | 29 | %a 星期几的简写 Weekday name, abbr. 30 | %A 星期几的全称 Weekday name, full 31 | %b 月分的简写 Month name, abbr. 32 | %B 月份的全称 Month name, full 33 | %c 标准的日期的时间串 Complete date and time representation 34 | %d 十进制表示的每月的第几天 Day of the month 35 | %H 24小时制的小时 Hour (24-hour clock) 36 | %I 12小时制的小时 Hour (12-hour clock) 37 | %j 十进制表示的每年的第几天 Day of the year 38 | %m 十进制表示的月份 Month number 39 | %M 十时制表示的分钟数 Minute number 40 | %S 十进制的秒数 Second number 41 | %U 第年的第几周,把星期日做为第一天(值从0到53)Week number (Sunday first weekday) 42 | %w 十进制表示的星期几(值从0到6,星期天为0)weekday number 43 | %W 每年的第几周,把星期一做为第一天(值从0到53) Week number (Monday first weekday) 44 | %x 标准的日期串 Complete date representation (e.g. 13/01/08) 45 | %X 标准的时间串 Complete time representation (e.g. 17:02:10) 46 | %y 不带世纪的十进制年份(值从0到99)Year number within century 47 | %Y 带世纪部分的十制年份 Year number 48 | %z,%Z 时区名称,如果不能得到时区名称则返回空字符。Name of time zone 49 | %% 百分号 50 | 51 | :rtype : object 52 | :param n: 53 | :return: 54 | """ 55 | day = datetime.date.today() 56 | return day + datetime.timedelta(-n) 57 | 58 | 59 | def is_working_day(dt=datetime.datetime.now()): 60 | """ 61 | 检查某天是否是工作日,周一为0 62 | :param dt: 63 | :return: 64 | """ 65 | if 0 <= dt.weekday() <= 4: 66 | return True 67 | else: 68 | return False 69 | 70 | 71 | def is_working_hour(dt=datetime.datetime.now()): 72 | """ 73 | 检查今天是否是交易时间段 74 | :param dt: 75 | :return: 76 | """ 77 | if dt.time().hour <= 9 and dt.time().minute < 15: 78 | return False 79 | elif (dt.time().hour >= 11 and dt.time().minute > 30) and (dt.time().hour < 13): 80 | return False 81 | elif dt.time().hour >= 15 and dt.time().minute > 0: 82 | return False 83 | else: 84 | return True 85 | 86 | 87 | def get_page(url): 88 | request = urllib2.Request(url) 89 | response = urllib2.urlopen(request) 90 | return response.read() 91 | 92 | 93 | def parse_page(url): 94 | return BeautifulSoup(get_page(url),"lxml") 95 | 96 | 97 | if __name__ == "__main__": 98 | pass 99 | --------------------------------------------------------------------------------