35 | # 通过rowcount获得插入的行数:
36 | cursor.rowcount
37 |
38 | # 关闭Cursor:
39 | cursor.close()
40 | # 提交事务:
41 | conn.commit()
42 | # 关闭Connection:
43 | conn.close()
44 |
45 |
46 | # 我们再试试查询记录:
47 | conn = sqlite3.connect('test.db')
48 | cursor = conn.cursor()
49 | # 执行查询语句:
50 | cursor.execute('select * from user where id=?', ('1',))
51 | # 获得查询结果集:
52 | values = cursor.fetchall()
53 | values
54 | cursor.close()
55 | conn.close()
56 |
57 |
58 | '''
59 | 使用Python的DB-API时,只要搞清楚Connection和Cursor对象,打开后一定记得关闭,就可以放心地使用。
60 |
61 | 使用Cursor对象执行insert,update,delete语句时,执行结果由rowcount返回影响的行数,就可以拿到执行结果。
62 |
63 | 使用Cursor对象执行select语句时,通过featchall()可以拿到结果集。结果集是一个list,每个元素都是一个tuple,对应一行记录。
64 |
65 | 如果SQL语句带有参数,那么需要把参数按照位置传递给execute()方法,有几个?占位符就必须对应几个参数,例如:
66 | '''
67 | cursor.execute('select * from user where name=? and pwd=?', ('abc', 'password'))
68 |
69 | '''
70 | 在Python中操作数据库时,要先导入数据库对应的驱动,然后,通过Connection对象和Cursor对象操作数据。
71 |
72 | 要确保打开的Connection对象和Cursor对象都正确地被关闭,否则,资源就会泄露。
73 |
74 | 如何才能确保出错的情况下也关闭掉Connection对象和Cursor对象呢?请回忆try:...except:...finally:...的用法。
75 | '''
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/14.数据库/2MySQL.py:
--------------------------------------------------------------------------------
1 | '''
2 | 使用MySQL
3 | MySQL是Web世界中使用最广泛的数据库服务器。SQLite的特点是轻量级、可嵌入,但不能承受高并发访问,适合桌面和移动应用。而MySQL是为服务器端设计的数据库,能承受高并发访问,同时占用的内存也远远大于SQLite。
4 |
5 | 此外,MySQL内部有多种数据库引擎,最常用的引擎是支持数据库事务的InnoDB。
6 |
7 | 安装MySQL
8 | 可以直接从MySQL官方网站下载最新的Community Server 5.6.x版本。MySQL是跨平台的,选择对应的平台下载安装文件,安装即可。
9 |
10 | 安装时,MySQL会提示输入root用户的口令,请务必记清楚。如果怕记不住,就把口令设置为password。
11 |
12 | 在Windows上,安装时请选择UTF-8编码,以便正确地处理中文。
13 |
14 | 在Mac或Linux上,需要编辑MySQL的配置文件,把数据库默认的编码全部改为UTF-8。MySQL的配置文件默认存放在/etc/my.cnf或者/etc/mysql/my.cnf:
15 | '''
16 |
17 |
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/15.Web开发/1Http协议.py:
--------------------------------------------------------------------------------
1 | '''
2 | HTTP协议简介
3 | 在Web应用中,服务器把网页传给浏览器,实际上就是把网页的HTML代码发送给浏览器,让浏览器显示出来。而浏览器和服务器之间的传输协议是HTTP,所以:
4 |
5 | HTML是一种用来定义网页的文本,会HTML,就可以编写网页;
6 |
7 | HTTP是在网络上传输HTML的协议,用于浏览器和服务器的通信。
8 |
9 | 在举例子之前,我们需要安装Google的Chrome浏览器。
10 |
11 | 为什么要使用Chrome浏览器而不是IE呢?因为IE实在是太慢了,并且,IE对于开发和调试Web应用程序完全是一点用也没有。
12 |
13 | 我们需要在浏览器很方便地调试我们的Web应用,而Chrome提供了一套完整地调试工具,非常适合Web开发。
14 |
15 | 安装好Chrome浏览器后,打开Chrome,在菜单中选择“视图”,“开发者”,“开发者工具”,就可以显示开发者工具:
16 | '''
17 |
18 | # Elements显示网页的结构,Network显示浏览器和服务器的通信。我们点Network,确保第一个小红灯亮着,Chrome就会记录所有浏览器和服务器之间的通信:
19 | # 当我们在地址栏输入www.sina.com.cn时,浏览器将显示新浪的首页。在这个过程中,浏览器都干了哪些事情呢?通过Network的记录,我们就可以知道。在Network中,定位到第一条记录,点击,右侧将显示Request Headers,点击右侧的view source,我们就可以看到浏览器发给新浪服务器的请求:
20 |
21 | #表示请求的域名是www.sina.com.cn。如果一台服务器有多个网站,服务器就需要通过Host来区分浏览器请求的是哪个网站。
22 | #
23 | # 继续往下找到Response Headers,点击view source,显示服务器返回的原始响应数据:
24 |
25 | # 当浏览器读取到新浪首页的HTML源码后,它会解析HTML,显示页面,然后,根据HTML里面的各种链接,再发送HTTP请求给新浪服务器,拿到相应的图片、视频、Flash、JavaScript脚本、CSS等各种资源,最终显示出一个完整的页面。所以我们在Network下面能看到很多额外的HTTP请求。
26 |
27 | '''
28 | HTTP请求
29 | 跟踪了新浪的首页,我们来总结一下HTTP请求的流程:
30 |
31 | 步骤1:浏览器首先向服务器发送HTTP请求,请求包括:
32 |
33 | 方法:GET还是POST,GET仅请求资源,POST会附带用户数据;
34 |
35 | 路径:/full/url/path;
36 |
37 | 域名:由Host头指定:Host: www.sina.com.cn
38 |
39 | 以及其他相关的Header;
40 |
41 | 如果是POST,那么请求还包括一个Body,包含用户数据。
42 |
43 | 步骤2:服务器向浏览器返回HTTP响应,响应包括:
44 |
45 | 响应代码:200表示成功,3xx表示重定向,4xx表示客户端发送的请求有错误,5xx表示服务器端处理时发生了错误;
46 |
47 | 响应类型:由Content-Type指定,例如:Content-Type: text/html;charset=utf-8表示响应类型是HTML文本,并且编码是UTF-8,Content-Type: image/jpeg表示响应类型是JPEG格式的图片;
48 |
49 | 以及其他相关的Header;
50 |
51 | 通常服务器的HTTP响应会携带内容,也就是有一个Body,包含响应的内容,网页的HTML源码就在Body中。
52 |
53 | 步骤3:如果浏览器还需要继续向服务器请求其他资源,比如图片,就再次发出HTTP请求,重复步骤1、2。
54 |
55 | Web采用的HTTP协议采用了非常简单的请求-响应模式,从而大大简化了开发。当我们编写一个页面时,我们只需要在HTTP响应中把HTML发送出去,不需要考虑如何附带图片、视频等,浏览器如果需要请求图片和视频,它会发送另一个HTTP请求,因此,一个HTTP请求只处理一个资源。
56 |
57 | HTTP协议同时具备极强的扩展性,虽然浏览器请求的是http://www.sina.com.cn/的首页,但是新浪在HTML中可以链入其他服务器的资源,比如
,从而将请求压力分散到各个服务器上,并且,一个站点可以链接到其他站点,无数个站点互相链接起来,就形成了World Wide Web,简称“三达不溜”(WWW)。
58 | '''
59 |
60 | '''
61 | HTTP格式
62 | 每个HTTP请求和响应都遵循相同的格式,一个HTTP包含Header和Body两部分,其中Body是可选的。
63 |
64 | HTTP协议是一种文本协议,所以,它的格式也非常简单。HTTP GET请求的格式:
65 |
66 | GET /path HTTP/1.1
67 | Header1: Value1
68 | Header2: Value2
69 | Header3: Value3
70 | 每个Header一行一个,换行符是\r\n。
71 |
72 | HTTP POST请求的格式:
73 |
74 | POST /path HTTP/1.1
75 | Header1: Value1
76 | Header2: Value2
77 | Header3: Value3
78 |
79 | body data goes here...
80 | 当遇到连续两个\r\n时,Header部分结束,后面的数据全部是Body。
81 |
82 | HTTP响应的格式:
83 |
84 | 200 OK
85 | Header1: Value1
86 | Header2: Value2
87 | Header3: Value3
88 |
89 | body data goes here...
90 | HTTP响应如果包含body,也是通过\r\n\r\n来分隔的。请再次注意,Body的数据类型由Content-Type头来确定,如果是网页,Body就是文本,如果是图片,Body就是图片的二进制数据。
91 |
92 | 当存在Content-Encoding时,Body数据是被压缩的,最常见的压缩方式是gzip,所以,看到Content-Encoding: gzip时,需要将Body数据先解压缩,才能得到真正的数据。压缩的目的在于减少Body的大小,加快网络传输。
93 |
94 | 要详细了解HTTP协议,推荐“HTTP: The Definitive Guide”一书,非常不错,有中文译本:
95 | '''
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/15.Web开发/2wsgi接口.py:
--------------------------------------------------------------------------------
1 | '''
2 | 了解了HTTP协议和HTML文档,我们其实就明白了一个Web应用的本质就是:
3 |
4 | 浏览器发送一个HTTP请求;
5 |
6 | 服务器收到请求,生成一个HTML文档;
7 |
8 | 服务器把HTML文档作为HTTP响应的Body发送给浏览器;
9 |
10 | 浏览器收到HTTP响应,从HTTP Body取出HTML文档并显示。
11 |
12 | 所以,最简单的Web应用就是先把HTML用文件保存好,用一个现成的HTTP服务器软件,接收用户请求,从文件中读取HTML,返回。Apache、Nginx、Lighttpd等这些常见的静态服务器就是干这件事情的。
13 |
14 | 如果要动态生成HTML,就需要把上述步骤自己来实现。不过,接受HTTP请求、解析HTTP请求、发送HTTP响应都是苦力活,如果我们自己来写这些底层代码,还没开始写动态HTML呢,就得花个把月去读HTTP规范。
15 |
16 | 正确的做法是底层代码由专门的服务器软件实现,我们用Python专注于生成HTML文档。因为我们不希望接触到TCP连接、HTTP原始请求和响应格式,所以,需要一个统一的接口,让我们专心用Python编写Web业务。
17 |
18 | 这个接口就是WSGI:Web Server Gateway Interface。
19 |
20 | WSGI接口定义非常简单,它只要求Web开发者实现一个函数,就可以响应HTTP请求。我们来看一个最简单的Web版本的“Hello, web!”:
21 | '''
22 | def application(environ, start_response):
23 | start_response('200 OK', [('Content-Type', 'text/html')])
24 | return [b'Hello, web!
']
25 |
26 | '''
27 | 上面的application()函数就是符合WSGI标准的一个HTTP处理函数,它接收两个参数:
28 |
29 | environ:一个包含所有HTTP请求信息的dict对象;
30 |
31 | start_response:一个发送HTTP响应的函数。
32 |
33 | 在application()函数中,调用:
34 | '''
35 | start_response('200 OK', [('Content-Type', 'text/html')])
36 |
37 | '''
38 | 就发送了HTTP响应的Header,注意Header只能发送一次,也就是只能调用一次start_response()函数。start_response()函数接收两个参数,一个是HTTP响应码,一个是一组list表示的HTTP Header,每个Header用一个包含两个str的tuple表示。
39 |
40 | 通常情况下,都应该把Content-Type头发送给浏览器。其他很多常用的HTTP Header也应该发送。
41 |
42 | 然后,函数的返回值b'Hello, web!
'将作为HTTP响应的Body发送给浏览器。
43 |
44 | 有了WSGI,我们关心的就是如何从environ这个dict对象拿到HTTP请求信息,然后构造HTML,通过start_response()发送Header,最后返回Body。
45 |
46 | 整个application()函数本身没有涉及到任何解析HTTP的部分,也就是说,底层代码不需要我们自己编写,我们只负责在更高层次上考虑如何响应请求就可以了。
47 |
48 | 不过,等等,这个application()函数怎么调用?如果我们自己调用,两个参数environ和start_response我们没法提供,返回的bytes也没法发给浏览器。
49 |
50 | 所以application()函数必须由WSGI服务器来调用。有很多符合WSGI规范的服务器,我们可以挑选一个来用。但是现在,我们只想尽快测试一下我们编写的application()函数真的可以把HTML输出到浏览器,所以,要赶紧找一个最简单的WSGI服务器,把我们的Web应用程序跑起来。
51 |
52 | 好消息是Python内置了一个WSGI服务器,这个模块叫wsgiref,它是用纯Python编写的WSGI服务器的参考实现。所谓“参考实现”是指该实现完全符合WSGI标准,但是不考虑任何运行效率,仅供开发和测试使用。
53 | '''
54 |
55 | # 运行WSGI服务
56 | '''
57 | 我们先编写hello.py,实现Web应用程序的WSGI处理函数:
58 | '''
59 | # hello.py
60 |
61 | def application(environ, start_response):
62 | start_response('200 OK', [('Content-Type', 'text/html')])
63 | return [b'Hello, web!
']
64 | # 然后,再编写一个server.py,负责启动WSGI服务器,加载application()函数:
65 |
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/15.Web开发/3使用web框架.py:
--------------------------------------------------------------------------------
1 | # 使用Web框架
2 | #
3 | # 了解了WSGI框架,我们发现:其实一个Web App,就是写一个WSGI的处理函数,针对每个HTTP请求进行响应。
4 | #
5 | # 但是如何处理HTTP请求不是问题,问题是如何处理100个不同的URL。
6 | #
7 | # 每一个URL可以对应GET和POST请求,当然还有PUT、DELETE等请求,但是我们通常只考虑最常见的GET和POST请求。
8 | #
9 | # 一个最简单的想法是从environ变量里取出HTTP请求的信息,然后逐个判断:
10 | '''
11 | def application(environ, start_response):
12 | method = environ['REQUEST_METHOD']
13 | path = environ['PATH_INFO']
14 | if method=='GET' and path=='/':
15 | return handle_home(environ, start_response)
16 | if method=='POST' and path='/signin':
17 | return handle_signin(environ, start_response)
18 | '''
19 |
20 | # 只是这么写下去代码是肯定没法维护了。
21 | #
22 | # 代码这么写没法维护的原因是因为WSGI提供的接口虽然比HTTP接口高级了不少,但和Web App的处理逻辑比,还是比较低级,我们需要在WSGI接口之上能进一步抽象,让我们专注于用一个函数处理一个URL,至于URL到函数的映射,就交给Web框架来做。
23 | #
24 | # 由于用Python开发一个Web框架十分容易,所以Python有上百个开源的Web框架。这里我们先不讨论各种Web框架的优缺点,直接选择一个比较流行的Web框架——Flask来使用。
25 | #
26 | # 用Flask编写Web App比WSGI接口简单(这不是废话么,要是比WSGI还复杂,用框架干嘛?),我们先用pip安装Flask:
27 | '''
28 | 然后写一个app.py,处理3个URL,分别是:
29 |
30 | GET /:首页,返回Home;
31 |
32 | GET /signin:登录页,显示登录表单;
33 |
34 | POST /signin:处理登录表单,显示登录结果。
35 |
36 | 注意噢,同一个URL/signin分别有GET和POST两种请求,映射到两个处理函数中。
37 |
38 | Flask通过Python的装饰器在内部自动地把URL和函数给关联起来,所以,我们写出来的代码就像这样:
39 | '''
40 |
41 |
42 | #
43 | # 小结
44 | # 有了Web框架,我们在编写Web应用时,注意力就从WSGI处理函数转移到URL+对应的处理函数,这样,编写Web App就更加简单了。
45 | #
46 | # 在编写URL处理函数时,除了配置URL外,从HTTP请求拿到用户数据也是非常重要的。Web框架都提供了自己的API来实现这些功能。Flask通过request.form['name']来获取表单的内容。
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/15.Web开发/4使用模板.py:
--------------------------------------------------------------------------------
1 | # 使用模板
2 | '''
3 | Web框架把我们从WSGI中拯救出来了。现在,我们只需要不断地编写函数,带上URL,就可以继续Web App的开发了。
4 |
5 | 但是,Web App不仅仅是处理逻辑,展示给用户的页面也非常重要。在函数中返回一个包含HTML的字符串,简单的页面还可以,但是,想想新浪首页的6000多行的HTML,你确信能在Python的字符串中正确地写出来么?反正我是做不到。
6 |
7 | 俗话说得好,不懂前端的Python工程师不是好的产品经理。有Web开发经验的同学都明白,Web App最复杂的部分就在HTML页面。HTML不仅要正确,还要通过CSS美化,再加上复杂的JavaScript脚本来实现各种交互和动画效果。总之,生成HTML页面的难度很大。
8 |
9 | 由于在Python代码里拼字符串是不现实的,所以,模板技术出现了。
10 |
11 | 使用模板,我们需要预先准备一个HTML文档,这个HTML文档不是普通的HTML,而是嵌入了一些变量和指令,然后,根据我们传入的数据,替换后,得到最终的HTML,发送给用户:
12 | '''
13 |
14 |
15 | '''
16 | 这就是传说中的MVC:Model-View-Controller,中文名“模型-视图-控制器”。
17 |
18 | Python处理URL的函数就是C:Controller,Controller负责业务逻辑,比如检查用户名是否存在,取出用户信息等等;
19 |
20 | 包含变量{{ name }}的模板就是V:View,View负责显示逻辑,通过简单地替换一些变量,View最终输出的就是用户看到的HTML。
21 |
22 | MVC中的Model在哪?Model是用来传给View的,这样View在替换变量的时候,就可以从Model中取出相应的数据。
23 |
24 | 上面的例子中,Model就是一个dict:
25 | '''
26 |
27 | '''
28 | 只是因为Python支持关键字参数,很多Web框架允许传入关键字参数,然后,在框架内部组装出一个dict作为Model。
29 |
30 | 现在,我们把上次直接输出字符串作为HTML的例子用高端大气上档次的MVC模式改写一下:
31 | '''
32 |
33 | from flask import Flask, request, render_template
34 |
35 | app = Flask(__name__)
36 |
37 | @app.route('/', methods=['GET', 'POST'])
38 | def home():
39 | return render_template('home.html')
40 |
41 | @app.route('/signin', methods=['GET'])
42 | def signin_form():
43 | return render_template('form.html')
44 |
45 | @app.route('/signin', methods=['POST'])
46 | def signin():
47 | username = request.form['username']
48 | password = request.form['password']
49 | if username=='admin' and password=='password':
50 | return render_template('signin-ok.html', username=username)
51 | return render_template('form.html', message='Bad username or password', username=username)
52 |
53 | if __name__ == '__main__':
54 | app.run()
55 |
56 | # Flask通过render_template()函数来实现模板的渲染。和Web框架类似,Python的模板也有很多种。Flask默认支持的模板是jinja2,所以我们先直接安装jinja2:
57 |
58 | '''
59 | 登录失败的模板呢?我们在form.html中加了一点条件判断,把form.html重用为登录失败的模板。
60 |
61 | 最后,一定要把模板放到正确的templates目录下,templates和app.py在同级目录下:
62 | '''
63 |
64 |
65 | '''
66 | 通过MVC,我们在Python代码中处理M:Model和C:Controller,而V:View是通过模板处理的,这样,我们就成功地把Python代码和HTML代码最大限度地分离了。
67 |
68 | 使用模板的另一大好处是,模板改起来很方便,而且,改完保存后,刷新浏览器就能看到最新的效果,这对于调试HTML、CSS和JavaScript的前端工程师来说实在是太重要了。
69 |
70 | 在Jinja2模板中,我们用{{ name }}表示一个需要替换的变量。很多时候,还需要循环、条件判断等指令语句,在Jinja2中,用{% ... %}表示指令。
71 |
72 | 比如循环输出页码:
73 |
74 | {% for i in page_list %}
75 | {{ i }}
76 | {% endfor %}
77 | '''
78 |
79 | '''
80 | 除了Jinja2,常见的模板还有:
81 |
82 | Mako:用<% ... %>和${xxx}的一个模板;
83 |
84 | Cheetah:也是用<% ... %>和${xxx}的一个模板;
85 |
86 | Django:Django是一站式框架,内置一个用{% ... %}和{{ xxx }}的模板。
87 |
88 | 小结
89 | 有了MVC,我们就分离了Python代码和HTML代码。HTML代码全部放到模板里,写起来更有效率。
90 | '''
91 |
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/15.Web开发/hello.py:
--------------------------------------------------------------------------------
1 | # def application(environ, start_response):
2 | # start_response('200 OK', [('Content-Type', 'text/html')])
3 | # return [b'Hello, web!
']
4 |
5 | # hello.py
6 |
7 | def application(environ, start_response):
8 | start_response('200 OK', [('Content-Type', 'text/html')])
9 | body = 'Hello, %s!
' % (environ['PATH_INFO'][1:] or 'web')
10 | return [body.encode('utf-8')]
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/15.Web开发/server.py:
--------------------------------------------------------------------------------
1 | # server.py
2 | # 从wsgiref模块导入:
3 | from wsgiref.simple_server import make_server
4 | # 导入我们自己编写的application函数:
5 | # from hello import application
6 | from hello import application
7 | # 创建一个服务器,IP地址为空,端口是8000,处理函数是application:
8 | httpd = make_server('', 8000, application)
9 | print('Serving HTTP on port 8000...')
10 | # 开始监听HTTP请求:
11 | httpd.serve_forever()
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/15.Web开发/使用web框架/app.py:
--------------------------------------------------------------------------------
1 | from flask import Flask
2 | from flask import request
3 |
4 | app = Flask(__name__)
5 |
6 | @app.route('/', methods=['GET', 'POST'])
7 | def home():
8 | return 'Home
'
9 |
10 | @app.route('/signin', methods=['GET'])
11 | def signin_form():
12 | return ''''''
17 |
18 | @app.route('/signin', methods=['POST'])
19 | def signin():
20 | # 需要从request对象读取表单内容:
21 | if request.form['username']=='admin' and request.form['password']=='password':
22 | return 'Hello, admin!
'
23 | return 'Bad username or password.
'
24 |
25 | if __name__ == '__main__':
26 | app.run()
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/15.Web开发/使用模板/app.py:
--------------------------------------------------------------------------------
1 | from flask import Flask, request, render_template
2 |
3 | app = Flask(__name__)
4 |
5 | @app.route('/', methods=['GET', 'POST'])
6 | def home():
7 | return render_template('home.html')
8 |
9 | @app.route('/signin', methods=['GET'])
10 | def signin_form():
11 | return render_template('form.html')
12 |
13 | @app.route('/signin', methods=['POST'])
14 | def signin():
15 | username = request.form['username']
16 | password = request.form['password']
17 | if username=='admin' and password=='password':
18 | return render_template('signin-ok.html', username=username)
19 | return render_template('form.html', message='Bad username or password', username=username)
20 |
21 | if __name__ == '__main__':
22 | app.run()
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/15.Web开发/使用模板/templates/form.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Please Sign in
6 |
7 |
8 | {% if message %}
9 | {{ message }}
10 | {% endif %}
11 |
17 |
18 |
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/15.Web开发/使用模板/templates/home.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | home
6 |
7 |
8 | Home
9 |
10 |
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/15.Web开发/使用模板/templates/signin-ok.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Welcome, {{ username }}
6 |
7 |
8 | Welcome, {{ username }}!
9 |
10 |
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/16.异步IO/1协程.py:
--------------------------------------------------------------------------------
1 | '''
2 | 在学习异步IO模型前,我们先来了解协程。
3 |
4 | 协程,又称微线程,纤程。英文名Coroutine。
5 |
6 | 协程的概念很早就提出来了,但直到最近几年才在某些语言(如Lua)中得到广泛应用。
7 |
8 | 子程序,或者称为函数,在所有语言中都是层级调用,比如A调用B,B在执行过程中又调用了C,C执行完毕返回,B执行完毕返回,最后是A执行完毕。
9 |
10 | 所以子程序调用是通过栈实现的,一个线程就是执行一个子程序。
11 |
12 | 子程序调用总是一个入口,一次返回,调用顺序是明确的。而协程的调用和子程序不同。
13 |
14 | 协程看上去也是子程序,但执行过程中,在子程序内部可中断,然后转而执行别的子程序,在适当的时候再返回来接着执行。
15 |
16 | 注意,在一个子程序中中断,去执行其他子程序,不是函数调用,有点类似CPU的中断。比如子程序A、B:
17 | '''
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/2.函数式编程/1高阶函数.py:
--------------------------------------------------------------------------------
1 | # 函数名也是变量
2 | def add(x, y, f):
3 | return f(x) + f(y)
4 |
5 | print(add(-5, 6, abs))
6 |
7 | # map
8 | """
9 | 我们先看map。map()函数接收两个参数,一个是函数,一个是Iterable,map将传入的函数依次作用到序列的每个元素,并把结果作为新的Iterator返回。
10 | 举例说明,比如我们有一个函数f(x)=x^2,要把这个函数作用在一个list [1, 2, 3, 4, 5, 6, 7, 8, 9]上,就可以用map()实现如下:
11 | """
12 | def f(x):
13 | return x * x
14 | r = map(f, [1, 2, 3, 4, 5, 6, 7, 8, 9])
15 | """
16 | map()作为高阶函数,事实上它把运算规则抽象了,因此,我们不但可以计算简单的f(x)=x2,还可以计算任意复杂的函数,比如,把这个list所有数字转为字符串:
17 | """
18 | #reduce
19 | """
20 | 再看reduce的用法。reduce把一个函数作用在一个序列[x1, x2, x3, ...]上,这个函数必须接收两个参数,reduce把结果继续和序列的下一个元素做累积计算,其效果就是
21 | reduce(f, [x1, x2, x3, x4]) = f(f(f(x1, x2), x3), x4)
22 | """
23 | # example1
24 | from functools import reduce
25 | def add(x, y):
26 | return x + y
27 | reduce(add, [1, 2, 3, 4, 5, 6, 7, 8, 9])
28 | # example2
29 | def fn(x, y):
30 | return x * 10 + y
31 | reduce(fn, [1, 3, 5, 7, 9])
32 | # example3
33 | def char2num(s):
34 | digits = {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9}
35 | return digits[s]
36 | reduce(fn, map(char2num, '13579'))
37 |
38 | # filter
39 | """
40 | 和map()类似,filter()也接收一个函数和一个序列。和map()不同的是,filter()把传入的函数依次作用于每个元素,然后根据返回值是True还是False决定保留还是丢弃该元素。
41 | 例如,在一个list中,删掉偶数,只保留奇数,可以这么写
42 | """
43 | def is_odd(n):
44 | return n % 2 == 1
45 | f = filter(is_odd, [1, 2, 4, 5, 6, 9, 10, 15])
46 | list(f)
47 |
48 | """
49 | 把一个序列中的空字符串删掉,可以这么写:
50 | """
51 | def not_empty(s):
52 | return s and s.strip()
53 |
54 | list(filter(not_empty, ['A', '', 'B', None, 'C', ' ']))
55 |
56 | """
57 | 筛选素数
58 | """
59 | def odd_iter():
60 | n = 1
61 | while True:
62 | n = n + 2
63 | yield n
64 |
65 | def not_divisible(n):
66 | return lambda x : x % n > 0
67 |
68 | def primes():
69 | yield 2
70 | it = odd_iter()
71 | while True:
72 | n = next(it)
73 | yield n
74 | it = filter(not_divisible(n), it)
75 |
76 | for n in primes():
77 | if n < 1000:
78 | print(n)
79 | else:
80 | break
81 |
82 | """
83 | 筛选回数
84 | """
85 | def is_palindrome(n):
86 | return str(n) == str(n)[::-1]
87 |
88 | output = filter(is_palindrome, range(1, 1000))
89 | print('1~1000:', list(output))
90 |
91 | """
92 | 排序
93 | """
94 | # 排序算法
95 | sorted([36, 5, -12, 9, -21])
96 | # 此外,sorted()函数也是一个高阶函数,它还可以接收一个key函数来实现自定义的排序,例如按绝对值大小排序:
97 | sorted([36, 5, -12, 9, -21], key=abs)
98 | # 我们再看一个字符串排序的例子:
99 | sorted(['bob', 'about', 'Zoo', 'Credit'])
100 | """
101 | 默认情况下,对字符串排序,是按照ASCII的大小比较的,由于'Z' < 'a',结果,大写字母Z会排在小写字母a的前面。
102 | 现在,我们提出排序应该忽略大小写,按照字母序排序。要实现这个算法,不必对现有代码大加改动,只要我们能用一个key函数把字符串映射为忽略大小写排序即可。忽略大小写来比较两个字符串,实际上就是先把字符串都变成大写(或者都变成小写),再比较。
103 | 这样,我们给sorted传入key函数,即可实现忽略大小写的排序:
104 | """
105 | sorted(['bob', 'about', 'Zoo', 'Credit'], key=str.lower)
106 | # 要进行反向排序,不必改动key函数,可以传入第三个参数reverse=True:
107 | sorted(['bob', 'about', 'Zoo', 'Credit'], key=str.lower, reverse=True)
108 |
109 | # sorted()也是一个高阶函数。用sorted()排序的关键在于实现一个映射函数。
110 | L = [('Bob', 75), ('Adam', 92), ('Bart', 66), ('Lisa', 88)]
111 | def by_name(t):
112 | return t[0]
113 |
114 | def by_score(t):
115 | return t[1]
116 |
117 | L2 = sorted(L, key=by_score)
118 | print(L2)
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/2.函数式编程/2返回函数和匿名函数.py:
--------------------------------------------------------------------------------
1 | """
2 | 高阶函数除了可以接受函数作为参数外,还可以把函数作为结果值返回。
3 | 我们来实现一个可变参数的求和。通常情况下,求和的函数是这样定义的:
4 | """
5 | def calc_sum(*args):
6 | ax = 0
7 | for n in args:
8 | ax = ax + n
9 | return ax
10 |
11 | # 但是,如果不需要立刻求和,而是在后面的代码中,根据需要再计算怎么办?可以不返回求和的结果,而是返回求和的函数:
12 | def lazy_sum(*args):
13 | def sum():
14 | ax = 0
15 | for n in args:
16 | ax = ax + n
17 | return ax
18 | return sum
19 | f = lazy_sum(1, 3, 5, 7, 9)
20 |
21 | # 闭包
22 | """
23 | 注意到返回的函数在其定义内部引用了局部变量args,所以,当一个函数返回了一个函数后,其内部的局部变量还被新函数引用,所以,闭包用起来简单,实现起来可不容易。
24 | 另一个需要注意的问题是,返回的函数并没有立刻执行,而是直到调用了f()才执行。我们来看一个例子:
25 | """
26 | def count():
27 | fs = []
28 | for i in range(1,4):
29 | def f():
30 | return i * i
31 | fs.append(f)
32 | return fs
33 |
34 | f1, f2, f3 = count()
35 | # 全部都是9!原因就在于返回的函数引用了变量i,但它并非立刻执行。等到3个函数都返回时,它们所引用的变量i已经变成了3,因此最终结果为9。
36 | """
37 | 如果一定要引用循环变量怎么办?方法是再创建一个函数,用该函数的参数绑定循环变量当前的值,无论该循环变量后续如何更改,已绑定到函数参数的值不变:
38 | """
39 | def count():
40 | def f(j):
41 | def g():
42 | return j*j
43 | return g
44 | fs = []
45 | for i in range(1, 4):
46 | fs.append(f(i)) # f(i)立刻被执行,因此i的当前值被传入f()
47 | return fs
48 |
49 | f1, f2, f3 = count()
50 | f1()
51 | f2()
52 | f3()
53 | """
54 | 匿名函数
55 | """
56 | # 在Python中,对匿名函数提供了有限支持。还是以map()函数为例,计算f(x)=x2时,除了定义一个f(x)的函数外,还可以直接传入匿名函数:
57 | list(map(lambda x: x * x, [1, 2, 3, 4, 5, 6, 7, 8, 9]))
58 |
59 | f = lambda x : x * x
60 | f(5)
61 | # 使用匿名函数修改
62 | def is_odd(n):
63 | return n % 2 == 1
64 | L = list(filter(is_odd, range(1, 20)))
65 |
66 | L= list(filter(lambda x:x%2==1, range(1,20)))
67 |
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/2.函数式编程/3装饰器.py:
--------------------------------------------------------------------------------
1 | # 装饰器
2 |
3 | def now():
4 | print('2015-01-05')
5 |
6 | now.__name__
7 | """
8 | 现在,假设我们要增强now()函数的功能,比如,在函数调用前后自动打印日志,但又不希望修改now()函数的定义,这种在代码运行期间动态增加功能的方式,称之为“装饰器”(Decorator)。
9 | 本质上,decorator就是一个返回函数的高阶函数。所以,我们要定义一个能打印日志的decorator,可以定义如下:
10 | """
11 | def log(func):
12 | def wrapper(*args, **kw):
13 | print('call %s():' % func.__name__)
14 | return func(*args, **kw)
15 | return wrapper
16 |
17 | @log
18 | def now():
19 | print('2015-01-05')
20 |
21 | now()
22 |
23 | # 把@log放到now()函数的定义处,相当于执行了语句:
24 | now = log(now)
25 | """
26 | 由于log()是一个decorator,返回一个函数,所以,原来的now()函数仍然存在,只是现在同名的now变量指向了新的函数,于是调用now()将执行新函数,即在log()函数中返回的wrapper()函数。
27 | wrapper()函数的参数定义是(*args, **kw),因此,wrapper()函数可以接受任意参数的调用。在wrapper()函数内,首先打印日志,再紧接着调用原始函数。
28 | 如果decorator本身需要传入参数,那就需要编写一个返回decorator的高阶函数,写出来会更复杂。比如,要自定义log的文本:
29 | """
30 | def log(text):
31 | def decorator(func):
32 | def wrapper(*args, **kw):
33 | print('%s %s():' % (text, func.__name__))
34 | return func(*args, **kw)
35 | return wrapper
36 | return decorator
37 |
38 | @log('execute')
39 | def now():
40 | print('2015-3-25')
41 |
42 | # 和两层嵌套的decorator相比,3层嵌套的效果是这样的:
43 | now = log('execute')(now)
44 |
45 | """
46 | 不需要编写wrapper.__name__ = func.__name__这样的代码,Python内置的functools.wraps就是干这个事的,所以,一个完整的decorator的写法如下:
47 | """
48 | import functools
49 |
50 | def log(func):
51 | @functools.wraps(func)
52 | def wrapper(*args, **kw):
53 | print('call %s():' % func.__name__)
54 | return func(*args, **kw)
55 | return wrapper
56 |
57 | # 或者针对带参数的decorator:
58 | import functools
59 |
60 | def log(text):
61 | def decorator(func):
62 | @functools.wraps(func)
63 | def wrapper(*args, **kw):
64 | print('%s %s():' % (text, func.__name__))
65 | return func(*args, **kw)
66 | return wrapper
67 | return decorator
68 |
69 | """
70 | 练习
71 | 请设计一个decorator,它可作用于任何函数上,并打印该函数的执行时间:
72 | """
73 | import time, functools
74 |
75 |
76 | def metric(fn):
77 | @functools.wraps(fn)
78 | def wrapper(*args, **kw):
79 | print('%s executed in %s ms' % (fn.__name__, 10.24))
80 | return fn(*args, **kw)
81 | return wrapper
82 |
83 | def metricX(fn):
84 | print('%s executed in %s ms' % (fn.__name__, 10.24))
85 | return fn
86 |
87 | # 测试
88 | @metricX
89 | def fast(x, y):
90 | time.sleep(0.0012)
91 | return x + y
92 |
93 |
94 | @metric
95 | def slow(x, y, z):
96 | time.sleep(0.1234)
97 | return x * y * z
98 |
99 | f = fast(11, 22)
100 | s = slow(11, 22, 33)
101 | if f != 33:
102 | print('测试失败!')
103 | elif s != 7986:
104 | print('测试失败!')
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/3.面向对象/1继承与多肽.py:
--------------------------------------------------------------------------------
1 | class Animal(object):
2 | def run(self):
3 | print('Animal is running...')
4 |
5 | class Dog(Animal):
6 | pass
7 |
8 | class Cat(Animal):
9 | pass
10 |
11 | dog = Dog()
12 | dog.run()
13 |
14 | cat = Cat()
15 | cat.run()
16 |
17 | a = list() # a是list类型
18 | b = Animal() # b是Animal类型
19 | c = Dog() # c是Dog类型
20 |
21 | isinstance(a, list)
22 |
23 | isinstance(b, Animal)
24 |
25 | isinstance(c, Dog)
26 |
27 | """
28 | 获取对象信息
29 | """
30 | type(123)
31 |
32 |
33 | import types
34 |
35 |
36 | def fn():
37 | pass
38 |
39 |
40 | type(fn) == types.FunctionType
41 |
42 | type(abs) == types.BuiltinFunctionType
43 |
44 | type(lambda x:x) == types.LambdaType
45 |
46 | type((x for x in range(10))) == types.GeneratorType
47 |
48 | """
49 | 使用dir()
50 | 如果要获得一个对象的所有属性和方法,可以使用dir()函数,它返回一个包含字符串的list,比如,获得一个str对象的所有属性和方法:
51 | """
52 | dir('ABC')
53 |
54 | # 仅仅把属性和方法列出来是不够的,配合getattr()、setattr()以及hasattr(),我们可以直接操作一个对象的状态:
55 | class MyObject(object):
56 | def __init__(self):
57 | self.x = 9
58 | def power(self):
59 | return self.x * self.x
60 |
61 |
62 | obj = MyObject()
63 | hasattr(obj, 'x') # 有属性'x'吗?
64 | setattr(obj, 'y', 19) # 设置一个属性'y'
65 | getattr(obj, 'y') # 获取属性'y'
66 |
67 |
68 | # 如果试图获取不存在的属性,会抛出AttributeError的错误:
69 | getattr(obj, 'z') # 获取属性'z'
70 | getattr(obj, 'z', 404) # 获取属性'z',如果不存在,返回默认值404
71 |
72 | # 也可以获得对象的方法:
73 | hasattr(obj, 'power')
74 | getattr(obj, 'power') # 获取属性'power'
75 | fn = getattr(obj, 'power') # 获取属性'power'并赋值到变量fn
76 | fn() # 调用fn()与调用obj.power()是一样的
77 |
78 | # 通过内置的一系列函数,我们可以对任意一个Python对象进行剖析,拿到其内部的数据。要注意的是,只有在不知道对象信息的时候,我们才会去获取对象信息。如果可以直接写:
79 | sum = obj.x + obj.y
80 | sum = getattr(obj, 'x') + getattr(obj, 'y')
81 |
82 | """
83 | 由于Python是动态语言,根据类创建的实例可以任意绑定属性。
84 | 给实例绑定属性的方法是通过实例变量,或者通过self变量:
85 | """
86 | class Student(object):
87 | def __init__(self, name):
88 | self.name = name
89 |
90 | s = Student('Bob')
91 | s.score = 90
92 |
93 | # 但是,如果Student类本身需要绑定一个属性呢?可以直接在class中定义属性,这种属性是类属性,归Student类所有:
94 | class Student(object):
95 | name = 'Student'
96 |
97 | s = Student() # 创建实例s
98 | print(s.name) # 打印name属性,因为实例并没有name属性,所以会继续查找class的name属性
99 | s.name = 'Tom' # 给实例绑定name属性
100 | print(s.name)
101 | print(Student.name) # 但是类属性并未消失,用Student.name仍然可以访问
102 | del s.name # 如果删除实例的name属性
103 | delattr(s, 'name')
104 | print(s.name) # 再次调用s.name,由于实例的name属性没有找到,类的name属性就显示出来了
105 |
106 |
107 |
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/4.错误调试测试/1错误处理.py:
--------------------------------------------------------------------------------
1 |
2 | def foo():
3 | r = some_function()
4 | if r==(-1):
5 | return (-1)
6 | # do something
7 | return r
8 |
9 | def bar():
10 | r = foo()
11 | if r==(-1):
12 | print('Error')
13 | else:
14 | pass
15 | '''
16 | try
17 | 让我们用一个例子来看看try的机制:
18 | '''
19 | try:
20 | print('try...')
21 | r = 10 / 0
22 | print('result:', r)
23 | except ZeroDivisionError as e:
24 | print('except:', e)
25 | finally:
26 | print('finally...')
27 | print('END')
28 |
29 | # 由于没有错误发生,所以except语句块不会被执行,但是finally如果有,则一定会被执行(可以没有finally语句)。
30 | #
31 | # 你还可以猜测,错误应该有很多种类,如果发生了不同类型的错误,应该由不同的except语句块处理。没错,可以有多个except来捕获不同类型的错误:
32 | try:
33 | print('try...')
34 | r = 10 / int('a')
35 | print('result:', r)
36 | except ValueError as e:
37 | print('ValueError:', e)
38 | except ZeroDivisionError as e:
39 | print('ZeroDivisionError:', e)
40 | finally:
41 | print('finally...')
42 | print('END')
43 | # int()函数可能会抛出ValueError,所以我们用一个except捕获ValueError,用另一个except捕获ZeroDivisionError。
44 | #
45 | # 此外,如果没有错误发生,可以在except语句块后面加一个else,当没有错误发生时,会自动执行else语句:
46 | try:
47 | print('try...')
48 | r = 10 / int('2')
49 | print('result:', r)
50 | except ValueError as e:
51 | print('ValueError:', e)
52 | except ZeroDivisionError as e:
53 | print('ZeroDivisionError:', e)
54 | else:
55 | print('no error!')
56 | finally:
57 | print('finally...')
58 | print('END')
59 | # Python的错误其实也是class,所有的错误类型都继承自BaseException,所以在使用except时需要注意的是,它不但捕获该类型的错误,还把其子类也“一网打尽”。比如:
60 | try:
61 | foo()
62 | except ValueError as e:
63 | print('ValueError')
64 | except UnicodeError as e:
65 | print('UnicodeError')
66 |
67 | '''
68 | 第二个except永远也捕获不到UnicodeError,因为UnicodeError是ValueError的子类,如果有,也被第一个except给捕获了。
69 |
70 | Python所有的错误都是从BaseException类派生的,常见的错误类型和继承关系看这里:
71 |
72 | https://docs.python.org/3/library/exceptions.html#exception-hierarchy
73 |
74 | 使用try...except捕获错误还有一个巨大的好处,就是可以跨越多层调用,比如函数main()调用foo(),foo()调用bar(),结果bar()出错了,这时,只要main()捕获到了,就可以处理:
75 | '''
76 | def foo(s):
77 | return 10 / int(s)
78 |
79 | def bar(s):
80 | return foo(s) * 2
81 |
82 | def main():
83 | try:
84 | bar('0')
85 | except Exception as e:
86 | print('Error:', e)
87 | finally:
88 | print('finally...')
89 | # 也就是说,不需要在每个可能出错的地方去捕获错误,只要在合适的层次去捕获错误就可以了。这样一来,就大大减少了写try...except...finally的麻烦。
90 |
91 | '''
92 | 调用栈
93 | 如果错误没有被捕获,它就会一直往上抛,最后被Python解释器捕获,打印一个错误信息,然后程序退出。来看看err.py:
94 | '''
95 | # err.py:
96 | def foo(s):
97 | return 10 / int(s)
98 |
99 | def bar(s):
100 | return foo(s) * 2
101 |
102 | def main():
103 | bar('0')
104 |
105 | main()
106 |
107 | '''
108 | 记录错误
109 | 如果不捕获错误,自然可以让Python解释器来打印出错误堆栈,但程序也被结束了。既然我们能捕获错误,就可以把错误堆栈打印出来,然后分析错误原因,同时,让程序继续执行下去。
110 |
111 | Python内置的logging模块可以非常容易地记录错误信息:
112 | '''
113 | # err_logging.py
114 |
115 | import logging
116 |
117 | def foo(s):
118 | return 10 / int(s)
119 |
120 | def bar(s):
121 | return foo(s) * 2
122 |
123 | def main():
124 | try:
125 | bar('0')
126 | except Exception as e:
127 | logging.exception(e)
128 |
129 | main()
130 | print('END')
131 |
132 |
133 | '''
134 | 抛出错误
135 | 因为错误是class,捕获一个错误就是捕获到该class的一个实例。因此,错误并不是凭空产生的,而是有意创建并抛出的。Python的内置函数会抛出很多类型的错误,我们自己编写的函数也可以抛出错误。
136 |
137 | 如果要抛出错误,首先根据需要,可以定义一个错误的class,选择好继承关系,然后,用raise语句抛出一个错误的实例:
138 |
139 | # err_raise.py
140 | '''
141 | class FooError(ValueError):
142 | pass
143 |
144 | def foo(s):
145 | n = int(s)
146 | if n==0:
147 | raise FooError('invalid value: %s' % s)
148 | return 10 / n
149 |
150 | foo('0')
151 |
152 | # 只有在必要的时候才定义我们自己的错误类型。如果可以选择Python已有的内置的错误类型(比如ValueError,TypeError),尽量使用Python内置的错误类型。
153 | #
154 | # 最后,我们来看另一种错误处理的方式:
155 |
156 | # err_reraise.py
157 |
158 | def foo(s):
159 | n = int(s)
160 | if n==0:
161 | raise ValueError('invalid value: %s' % s)
162 | return 10 / n
163 |
164 | def bar():
165 | try:
166 | foo('0')
167 | except ValueError as e:
168 | print('ValueError!')
169 | raise
170 |
171 | bar()
172 | # 在bar()函数中,我们明明已经捕获了错误,但是,打印一个ValueError!后,又把错误通过raise语句抛出去了,这不有病么?
173 | #
174 | # 其实这种错误处理方式不但没病,而且相当常见。捕获错误目的只是记录一下,便于后续追踪。但是,由于当前函数不知道应该怎么处理该错误,所以,最恰当的方式是继续往上抛,让顶层调用者去处理。好比一个员工处理不了一个问题时,就把问题抛给他的老板,如果他的老板也处理不了,就一直往上抛,最终会抛给CEO去处理。
175 | #
176 | # raise语句如果不带参数,就会把当前错误原样抛出。此外,在except中raise一个Error,还可以把一种类型的错误转化成另一种类型:
177 |
178 | from functools import reduce
179 |
180 | def str2num(s):
181 | return int(s)
182 |
183 | def calc(exp):
184 | try:
185 | ss = exp.split('+')
186 | ns = map(str2num, ss)
187 | return reduce(lambda acc, x: acc + x, ns)
188 | except Exception as e:
189 | raise ValueError('value error')
190 |
191 | def main():
192 | try:
193 | r = calc('100 + 200 + 345')
194 | print('100 + 200 + 345 =', r)
195 | r = calc('99 + 88 + 7.6')
196 | print('99 + 88 + 7.6 =', r)
197 | except Exception as e:
198 | logging.exception(e)
199 | print(e)
200 |
201 | main()
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/4.错误调试测试/2调试.py:
--------------------------------------------------------------------------------
1 | '''
2 | 断言
3 | 凡是用print()来辅助查看的地方,都可以用断言(assert)来替代:
4 | '''
5 | def foo(s):
6 | n = int(s)
7 | assert n != 0, 'n is zero!'
8 | return 10 / n
9 |
10 | def main():
11 | foo('0')
12 |
13 | '''
14 | logging
15 | 把print()替换为logging是第3种方式,和assert比,logging不会抛出错误,而且可以输出到文件:
16 | '''
17 | import logging
18 | logging.basicConfig(level=logging.INFO)
19 | s = '0'
20 | n = int(s)
21 | logging.info('n = %d' % n)
22 | print(10 / n)
23 | '''
24 | logging.info()就可以输出一段文本。运行,发现除了ZeroDivisionError,没有任何信息。怎么回事?
25 |
26 | 别急,在import logging之后添加一行配置再试试:
27 | '''
28 | logging.basicConfig(level=logging.INFO)
29 |
30 | '''pdb'''
31 |
32 |
33 |
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/4.错误调试测试/3单元测试.py:
--------------------------------------------------------------------------------
1 | class Dict(dict):
2 |
3 | def __init__(self, **kw):
4 | super().__init__(**kw)
5 |
6 | def __getattr__(self, key):
7 | try:
8 | return self[key]
9 | except KeyError:
10 | raise AttributeError(r"'Dict' object has no attribute '%s'" % key)
11 |
12 | def __setattr__(self, key, value):
13 | self[key] = value
14 |
15 | import unittest
16 |
17 | class TestDict(unittest.TestCase):
18 |
19 | def test_init(self):
20 | d = Dict(a=1, b='test')
21 | self.assertEqual(d.a, 1)
22 | self.assertEqual(d.b, 'test')
23 | self.assertTrue(isinstance(d, dict))
24 |
25 | def test_key(self):
26 | d = Dict()
27 | d['key'] = 'value'
28 | self.assertEqual(d.key, 'value')
29 |
30 | def test_attr(self):
31 | d = Dict()
32 | d.key = 'value'
33 | self.assertTrue('key' in d)
34 | self.assertEqual(d['key'], 'value')
35 |
36 | def test_keyerror(self):
37 | d = Dict()
38 | with self.assertRaises(KeyError):
39 | value = d['empty']
40 |
41 | def test_attrerror(self):
42 | d = Dict()
43 | with self.assertRaises(AttributeError):
44 | value = d.empty
45 |
46 | # self.assertEqual(abs(-1), 1) # 断言函数返回的结果与1相等
47 | # 另一种重要的断言就是期待抛出指定类型的Error,比如通过d['empty']访问不存在的key时,断言会抛出KeyError:
48 | # with self.assertRaises(KeyError):
49 | # value = d['empty']
50 | '''
51 | setUp与tearDown
52 | 可以在单元测试中编写两个特殊的setUp()和tearDown()方法。这两个方法会分别在每调用一个测试方法的前后分别被执行。
53 |
54 | setUp()和tearDown()方法有什么用呢?设想你的测试需要启动一个数据库,这时,就可以在setUp()方法中连接数据库,在tearDown()方法中关闭数据库,这样,不必在每个测试方法中重复相同的代码:
55 |
56 | '''
57 | class TestDict(unittest.TestCase):
58 |
59 | def setUp(self):
60 | print('setUp...')
61 |
62 | def tearDown(self):
63 | print('tearDown...')
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/5.IO编程/1文件读写.py:
--------------------------------------------------------------------------------
1 | '''
2 | 文件读写
3 | 读写文件是最常见的IO操作。Python内置了读写文件的函数,用法和C是兼容的。
4 |
5 | 读写文件前,我们先必须了解一下,在磁盘上读写文件的功能都是由操作系统提供的,现代操作系统不允许普通的程序直接操作磁盘,所以,读写文件就是请求操作系统打开一个文件对象(通常称为文件描述符),然后,通过操作系统提供的接口从这个文件对象中读取数据(读文件),或者把数据写入这个文件对象(写文件)。
6 | '''
7 | '''
8 | 读文件
9 | 要以读文件的模式打开一个文件对象,使用Python内置的open()函数,传入文件名和标示符:
10 | '''
11 | f = open('/Users/yejunhai/Desktop/test.txt', 'r')
12 |
13 | f.read()
14 | f.close()
15 | # 由于文件读写时都有可能产生IOError,一旦出错,后面的f.close()就不会调用。所以,为了保证无论是否出错都能正确地关闭文件,我们可以使用try ... finally来实现:
16 | try:
17 | f = open('/path/to/file', 'r')
18 | print(f.read())
19 | finally:
20 | if f:
21 | f.close()
22 |
23 | # 但是每次都这么写实在太繁琐,所以,Python引入了with语句来自动帮我们调用close()方法:
24 | with open('/path/to/file', 'r') as f:
25 | print(f.read())
26 | '''
27 | 调用read()会一次性读取文件的全部内容,如果文件有10G,内存就爆了,所以,要保险起见,可以反复调用read(size)方法,每次最多读取size个字节的内容。另外,调用readline()可以每次读取一行内容,调用readlines()一次读取所有内容并按行返回list。因此,要根据需要决定怎么调用。
28 |
29 | 如果文件很小,read()一次性读取最方便;如果不能确定文件大小,反复调用read(size)比较保险;如果是配置文件,调用readlines()最方便:
30 | '''
31 | for line in f.readlines():
32 | print(line.strip()) # 把末尾的'\n'删掉
33 |
34 | '''
35 | file-like Object
36 | 像open()函数返回的这种有个read()方法的对象,在Python中统称为file-like Object。除了file外,还可以是内存的字节流,网络流,自定义流等等。file-like Object不要求从特定类继承,只要写个read()方法就行。
37 |
38 | StringIO就是在内存中创建的file-like Object,常用作临时缓冲。
39 | '''
40 | '''
41 | 二进制文件
42 | 前面讲的默认都是读取文本文件,并且是UTF-8编码的文本文件。要读取二进制文件,比如图片、视频等等,用'rb'模式打开文件即可:
43 | '''
44 | f = open('/Users/michael/test.jpg', 'rb')
45 | f.read()
46 | '''
47 | 字符编码
48 | 要读取非UTF-8编码的文本文件,需要给open()函数传入encoding参数,例如,读取GBK编码的文件:
49 | '''
50 | f = open('/Users/michael/gbk.txt', 'r', encoding='gbk')
51 | f.read()
52 | '''
53 | 遇到有些编码不规范的文件,你可能会遇到UnicodeDecodeError,因为在文本文件中可能夹杂了一些非法编码的字符。遇到这种情况,open()函数还接收一个errors参数,表示如果遇到编码错误后如何处理。最简单的方式是直接忽略:
54 | '''
55 | f = open('/Users/michael/gbk.txt', 'r', encoding='gbk', errors='ignore')
56 |
57 | '''
58 | 写文件
59 | 写文件和读文件是一样的,唯一区别是调用open()函数时,传入标识符'w'或者'wb'表示写文本文件或写二进制文件:
60 | '''
61 | f = open('/Users/michael/test.txt', 'w')
62 | f.write('Hello, world!')
63 | f.close()
64 |
65 | '''
66 | 你可以反复调用write()来写入文件,但是务必要调用f.close()来关闭文件。当我们写文件时,操作系统往往不会立刻把数据写入磁盘,而是放到内存缓存起来,空闲的时候再慢慢写入。只有调用close()方法时,操作系统才保证把没有写入的数据全部写入磁盘。忘记调用close()的后果是数据可能只写了一部分到磁盘,剩下的丢失了。所以,还是用with语句来得保险:
67 | '''
68 | with open('/Users/michael/test.txt', 'w') as f:
69 | f.write('Hello, world!')
70 |
71 | '''
72 | 要写入特定编码的文本文件,请给open()函数传入encoding参数,将字符串自动转换成指定编码。
73 |
74 | 细心的童鞋会发现,以'w'模式写入文件时,如果文件已存在,会直接覆盖(相当于删掉后新写入一个文件)。如果我们希望追加到文件末尾怎么办?可以传入'a'以追加(append)模式写入。
75 | '''
76 |
77 |
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/5.IO编程/2StringIO和BytesIO.py:
--------------------------------------------------------------------------------
1 | '''
2 | StringIO
3 | 很多时候,数据读写不一定是文件,也可以在内存中读写。
4 |
5 | StringIO顾名思义就是在内存中读写str。
6 |
7 | 要把str写入StringIO,我们需要先创建一个StringIO,然后,像文件一样写入即可:
8 | '''
9 | from io import StringIO
10 | f = StringIO()
11 | f.write('hello')
12 | f.write(' ')
13 | f.write('world!')
14 | print(f.getvalue())
15 |
16 | # 要读取StringIO,可以用一个str初始化StringIO,然后,像读文件一样读取:
17 | from io import StringIO
18 | f = StringIO('Hello!\nHi!\nGoodbye!')
19 | while True:
20 | s = f.readline()
21 | if s == '':
22 | break
23 | print(s.strip())
24 |
25 | # BytesIO
26 | '''
27 | StringIO操作的只能是str,如果要操作二进制数据,就需要使用BytesIO。
28 | BytesIO实现了在内存中读写bytes,我们创建一个BytesIO,然后写入一些bytes:
29 | '''
30 | from io import BytesIO
31 | f = BytesIO()
32 | f.write('中文'.encode('utf-8'))
33 | print(f.getvalue())
34 |
35 |
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/5.IO编程/3操作文件和目录.py:
--------------------------------------------------------------------------------
1 | '''
2 | 如果我们要操作文件、目录,可以在命令行下面输入操作系统提供的各种命令来完成。比如dir、cp等命令。
3 |
4 | 如果要在Python程序中执行这些目录和文件的操作怎么办?其实操作系统提供的命令只是简单地调用了操作系统提供的接口函数,Python内置的os模块也可以直接调用操作系统提供的接口函数。
5 |
6 | 打开Python交互式命令行,我们来看看如何使用os模块的基本功能:
7 | '''
8 | import os
9 | os.name
10 | # 如果是posix,说明系统是Linux、Unix或Mac OS X,如果是nt,就是Windows系统。
11 | # 要获取详细的系统信息,可以调用uname()函数:
12 | os.uname()
13 | # 注意uname()函数在Windows上不提供,也就是说,os模块的某些函数是跟操作系统相关的。
14 |
15 | '''
16 | 环境变量
17 | 在操作系统中定义的环境变量,全部保存在os.environ这个变量中,可以直接查看:
18 | '''
19 | os.environ
20 | # 要获取某个环境变量的值,可以调用os.environ.get('key'):
21 | os.environ.get('PATH')
22 | os.environ.get('x', 'default')
23 |
24 | '''
25 | 操作文件和目录
26 | 操作文件和目录的函数一部分放在os模块中,一部分放在os.path模块中,这一点要注意一下。查看、创建和删除目录可以这么调用:
27 | '''
28 | x = os.path.abspath('.')
29 | y = os.path.join(x, 'testdir')
30 | os.mkdir(y)
31 | os.rmdir(y)
32 | # 把两个路径合成一个时,不要直接拼字符串,而要通过os.path.join()函数,这样可以正确处理不同操作系统的路径分隔符。在Linux/Unix/Mac下,os.path.join()返回这样的字符串:
33 | # 同样的道理,要拆分路径时,也不要直接去拆字符串,而要通过os.path.split()函数,这样可以把一个路径拆分为两部分,后一部分总是最后级别的目录或文件名:
34 | z = os.path.split(y)
35 | # os.path.splitext()可以直接让你得到文件扩展名,很多时候非常方便:
36 | os.path.splitext('/path/to/file.txt')
37 | '''
38 | 这些合并、拆分路径的函数并不要求目录和文件要真实存在,它们只对字符串进行操作。
39 |
40 | 文件操作使用下面的函数。假定当前目录下有一个test.txt文件:
41 | '''
42 | os.rename('test.txt', 'test.py')
43 | os.remove('test.py')
44 | '''
45 | 但是复制文件的函数居然在os模块中不存在!原因是复制文件并非由操作系统提供的系统调用。理论上讲,我们通过上一节的读写文件可以完成文件复制,只不过要多写很多代码。
46 |
47 | 幸运的是shutil模块提供了copyfile()的函数,你还可以在shutil模块中找到很多实用函数,它们可以看做是os模块的补充。
48 |
49 | 最后看看如何利用Python的特性来过滤文件。比如我们要列出当前目录下的所有目录,只需要一行代码:
50 | '''
51 | [x for x in os.listdir('.') if os.path.isdir(x)]
52 | # 要列出所有的.py文件,也只需一行代码:
53 | [x for x in os.listdir('.') if os.path.isfile(x) and os.path.splitext(x)[1]=='.py']
54 |
55 |
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/5.IO编程/4序列化.py:
--------------------------------------------------------------------------------
1 | # 在程序运行的过程中,所有的变量都是在内存中,比如,定义一个dict:
2 | d = dict(name='Bob', age=20, score=88)
3 | '''
4 | 可以随时修改变量,比如把name改成'Bill',但是一旦程序结束,变量所占用的内存就被操作系统全部回收。如果没有把修改后的'Bill'存储到磁盘上,下次重新运行程序,变量又被初始化为'Bob'。
5 |
6 | 我们把变量从内存中变成可存储或传输的过程称之为序列化,在Python中叫pickling,在其他语言中也被称之为serialization,marshalling,flattening等等,都是一个意思。
7 |
8 | 序列化之后,就可以把序列化后的内容写入磁盘,或者通过网络传输到别的机器上。
9 |
10 | 反过来,把变量内容从序列化的对象重新读到内存里称之为反序列化,即unpickling。
11 |
12 | Python提供了pickle模块来实现序列化。
13 |
14 | 首先,我们尝试把一个对象序列化并写入文件:
15 | '''
16 | import pickle
17 | d = dict(name='Tom', age=20, score=100)
18 | pickle.dumps(d)
19 | # pickle.dumps()方法把任意对象序列化成一个bytes,然后,就可以把这个bytes写入文件。或者用另一个方法pickle.dump()直接把对象序列化后写入一个file-like Object:
20 | f = open('dump.txt', 'wb')
21 | pickle.dump(d, f)
22 | f.close()
23 | '''
24 | 看看写入的dump.txt文件,一堆乱七八糟的内容,这些都是Python保存的对象内部信息。
25 |
26 | 当我们要把对象从磁盘读到内存时,可以先把内容读到一个bytes,然后用pickle.loads()方法反序列化出对象,也可以直接用pickle.load()方法从一个file-like Object中直接反序列化出对象。我们打开另一个Python命令行来反序列化刚才保存的对象:
27 | '''
28 | f = open('dump.txt', 'rb')
29 | d = pickle.load(f)
30 | f.close()
31 | d
32 |
33 | '''
34 | JSON
35 | 如果我们要在不同的编程语言之间传递对象,就必须把对象序列化为标准格式,比如XML,但更好的方法是序列化为JSON,因为JSON表示出来就是一个字符串,可以被所有语言读取,也可以方便地存储到磁盘或者通过网络传输。JSON不仅是标准格式,并且比XML更快,而且可以直接在Web页面中读取,非常方便。
36 |
37 | JSON表示的对象就是标准的JavaScript语言的对象,JSON和Python内置的数据类型对应如下:
38 | '''
39 | # Python内置的json模块提供了非常完善的Python对象到JSON格式的转换。我们先看看如何把Python对象变成一个JSON:
40 | import json
41 | d = dict(name='Tom', age=20, score=100)
42 | json.dumps(d)
43 | '''
44 | dumps()方法返回一个str,内容就是标准的JSON。类似的,dump()方法可以直接把JSON写入一个file-like Object。
45 |
46 | 要把JSON反序列化为Python对象,用loads()或者对应的load()方法,前者把JSON的字符串反序列化,后者从file-like Object中读取字符串并反序列化:
47 | '''
48 | json_str = '{"age": 20, "score": 88, "name": "Bob"}'
49 | json.loads(json_str)
50 |
51 | # JSON进阶
52 | import json
53 |
54 | class Student(object):
55 | def __init__(self, name, age, score):
56 | self.name = name
57 | self.age = age
58 | self.score = score
59 |
60 | s = Student('Bob', 20, 88)
61 | print(json.dumps(s))
62 | # 错误的原因是Student对象不是一个可序列化为JSON的对象。
63 | '''
64 | 这些可选参数就是让我们来定制JSON序列化。前面的代码之所以无法把Student类实例序列化为JSON,是因为默认情况下,dumps()方法不知道如何将Student实例变为一个JSON的{}对象。
65 |
66 | 可选参数default就是把任意一个对象变成一个可序列为JSON的对象,我们只需要为Student专门写一个转换函数,再把函数传进去即可:
67 | '''
68 | def student2dict(std):
69 | return {
70 | 'name': std.name,
71 | 'age': std.age,
72 | 'score': std.score
73 | }
74 | # 这样,Student实例首先被student2dict()函数转换成dict,然后再被顺利序列化为JSON:
75 | print(json.dumps(s, default=student2dict))
76 |
77 | # 不过,下次如果遇到一个Teacher类的实例,照样无法序列化为JSON。我们可以偷个懒,把任意class的实例变为dict
78 | print(json.dumps(s, default=lambda obj: obj.__dict__))
79 | '''
80 | 因为通常class的实例都有一个__dict__属性,它就是一个dict,用来存储实例变量。也有少数例外,比如定义了__slots__的class。
81 |
82 | 同样的道理,如果我们要把JSON反序列化为一个Student对象实例,loads()方法首先转换出一个dict对象,然后,我们传入的object_hook函数负责把dict转换为Student实例:
83 | '''
84 | def dict2student(d):
85 | return Student(d['name'], d['age'], d['score'])
86 | json_str = '{"age": 20, "score": 88, "name": "Bob"}'
87 | print(json.loads(json_str, object_hook=dict2student))
88 |
89 |
90 | obj = dict(name='小明', age=20)
91 | s = json.dumps(obj, ensure_ascii=True)
92 | class obj(object):
93 | def __init__(self, name, age):
94 | self.name = name
95 | self.age = age
96 | def dict2obj(d):
97 | return obj(d['name'], d['age'])
98 |
99 | a = json.loads(s, object_hook=dict2obj)
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/6.进程与线程/1多进程.py:
--------------------------------------------------------------------------------
1 | '''
2 | 总结一下就是,多任务的实现有3种方式:
3 |
4 | 多进程模式;
5 | 多线程模式;
6 | 多进程+多线程模式。
7 | '''
8 | # 要让Python程序实现多进程(multiprocessing),我们先了解操作系统的相关知识。
9 | '''
10 | Unix/Linux操作系统提供了一个fork()系统调用,它非常特殊。普通的函数调用,调用一次,返回一次,但是fork()调用一次,返回两次,因为操作系统自动把当前进程(称为父进程)复制了一份(称为子进程),然后,分别在父进程和子进程内返回。
11 |
12 | 子进程永远返回0,而父进程返回子进程的ID。这样做的理由是,一个父进程可以fork出很多子进程,所以,父进程要记下每个子进程的ID,而子进程只需要调用getppid()就可以拿到父进程的ID。
13 |
14 | Python的os模块封装了常见的系统调用,其中就包括fork,可以在Python程序中轻松创建子进程:
15 | '''
16 | import os
17 | print('Process (%s) start...' % os.getpid())
18 | # Only works on Unix/Linux/Mac:
19 | pid = os.fork()
20 | if pid == 0:
21 | print('I am child process (%s) and my parent is %s.' % (os.getpid(), os.getppid()))
22 | else:
23 | print('I (%s) just created a child process (%s).' % (os.getpid(), pid))
24 |
25 | '''
26 | 由于Windows没有fork调用,上面的代码在Windows上无法运行。而Mac系统是基于BSD(Unix的一种)内核,所以,在Mac下运行是没有问题的,推荐大家用Mac学Python!
27 |
28 | 有了fork调用,一个进程在接到新任务时就可以复制出一个子进程来处理新任务,常见的Apache服务器就是由父进程监听端口,每当有新的http请求时,就fork出子进程来处理新的http请求。
29 | '''
30 |
31 | '''
32 | multiprocessing
33 | '''
34 | '''
35 | 如果你打算编写多进程的服务程序,Unix/Linux无疑是正确的选择。由于Windows没有fork调用,难道在Windows上无法用Python编写多进程的程序?
36 |
37 | 由于Python是跨平台的,自然也应该提供一个跨平台的多进程支持。multiprocessing模块就是跨平台版本的多进程模块。
38 |
39 | multiprocessing模块提供了一个Process类来代表一个进程对象,下面的例子演示了启动一个子进程并等待其结束:
40 | '''
41 | from multiprocessing import Process
42 | import os
43 | # 子进程要执行的代码
44 | def run_proc(name):
45 | print('Run child process %s (%s)...' % (name, os.getpid()))
46 |
47 | if __name__=='__main__':
48 | print('Parent process %s.' % os.getpid())
49 | p = Process(target=run_proc, args=('test',))
50 | print('Child process will start.')
51 | p.start()
52 | p.join()
53 | print('Child process end.')
54 |
55 | '''
56 | 创建子进程时,只需要传入一个执行函数和函数的参数,创建一个Process实例,用start()方法启动,这样创建进程比fork()还要简单。
57 |
58 | join()方法可以等待子进程结束后再继续往下运行,通常用于进程间的同步。
59 | '''
60 |
61 | '''
62 | Pool
63 | '''
64 | # 如果要启动大量的子进程,可以用进程池的方式批量创建子进程:
65 | from multiprocessing import Pool
66 | import os, time, random
67 |
68 | def long_time_task(name):
69 | print('Run task %s (%s)...' % (name, os.getpid()))
70 | start = time.time()
71 | time.sleep(random.random() * 3)
72 | end = time.time()
73 | print('Task %s runs %0.2f seconds.' % (name, (end - start)))
74 |
75 | if __name__=='__main__':
76 | print('Parent process %s.' % os.getpid())
77 | p = Pool(4)
78 | for i in range(5):
79 | p.apply_async(long_time_task, args=(i,))
80 | print('Waiting for all subprocesses done...')
81 | p.close()
82 | p.join()
83 | print('All subprocesses done.')
84 |
85 | '''
86 | 对Pool对象调用join()方法会等待所有子进程执行完毕,调用join()之前必须先调用close(),调用close()之后就不能继续添加新的Process了。
87 |
88 | 请注意输出的结果,task 0,1,2,3是立刻执行的,而task 4要等待前面某个task完成后才执行,这是因为Pool的默认大小在我的电脑上是4,因此,最多同时执行4个进程。这是Pool有意设计的限制,并不是操作系统的限制。如果改成:
89 | '''
90 | p = Pool(5)
91 |
92 | '''
93 | 子进程
94 | '''
95 | '''
96 | 很多时候,子进程并不是自身,而是一个外部进程。我们创建了子进程后,还需要控制子进程的输入和输出。
97 |
98 | subprocess模块可以让我们非常方便地启动一个子进程,然后控制其输入和输出。
99 |
100 | 下面的例子演示了如何在Python代码中运行命令nslookup www.python.org,这和命令行直接运行的效果是一样的:
101 | '''
102 | import subprocess
103 |
104 | print('$ nslookup www.python.org')
105 | r = subprocess.call(['nslookup', 'www.python.org'])
106 | print('Exit code:', r)
107 |
108 | # 如果子进程还需要输入,则可以通过communicate()方法输入:
109 | import subprocess
110 |
111 | print('$ nslookup')
112 | p = subprocess.Popen(['nslookup'], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
113 | output, err = p.communicate(b'set q=mx\npython.org\nexit\n')
114 | print(output.decode('utf-8'))
115 | print('Exit code:', p.returncode)
116 |
117 | # 上面的代码相当于在命令行执行命令nslookup,然后手动输入:
118 | '''
119 | set q=mx
120 | python.org
121 | exit
122 | '''
123 |
124 |
125 | '''
126 | 进程间通信
127 | '''
128 | '''
129 | Process之间肯定是需要通信的,操作系统提供了很多机制来实现进程间的通信。Python的multiprocessing模块包装了底层的机制,提供了Queue、Pipes等多种方式来交换数据。
130 |
131 | 我们以Queue为例,在父进程中创建两个子进程,一个往Queue里写数据,一个从Queue里读数据:
132 | '''
133 | from multiprocessing import Process, Queue
134 | import os, time, random
135 |
136 | #写数据进程执行代码
137 | def write(q):
138 | print('Process to write:%s' % os.getpid())
139 | for value in ['A','B','C']:
140 | print('Put %s to queue.' % value)
141 | q.put(value)
142 | time.sleep(random.random())
143 |
144 | #读数据进程执行代码
145 | def read(q):
146 | print('Process to read: %s' % os.getpid())
147 | while True:
148 | value = q.get(True)
149 | print('Get %s from queue' % value)
150 |
151 | def main():
152 | q = Queue()
153 | pw = Process(target=write, args=(q,))
154 | pr = Process(target=read, args=(q,))
155 | #启动子进程pw写入
156 | pw.start()
157 | pr.start()
158 | pw.join()
159 | pr.terminate()
160 |
161 | '''
162 | 在Unix/Linux下,multiprocessing模块封装了fork()调用,使我们不需要关注fork()的细节。由于Windows没有fork调用,因此,multiprocessing需要“模拟”出fork的效果,父进程所有Python对象都必须通过pickle序列化再传到子进程去,所以,如果multiprocessing在Windows下调用失败了,要先考虑是不是pickle失败了。
163 | '''
164 |
165 |
166 |
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/6.进程与线程/2多线程.py:
--------------------------------------------------------------------------------
1 | '''
2 | 多任务可以由多进程完成,也可以由一个进程内的多线程完成。
3 |
4 | 我们前面提到了进程是由若干线程组成的,一个进程至少有一个线程。
5 |
6 | 由于线程是操作系统直接支持的执行单元,因此,高级语言通常都内置多线程的支持,Python也不例外,并且,Python的线程是真正的Posix Thread,而不是模拟出来的线程。
7 | '''
8 | '''
9 | Python的标准库提供了两个模块:_thread和threading,_thread是低级模块,threading是高级模块,对_thread进行了封装。绝大多数情况下,我们只需要使用threading这个高级模块。
10 |
11 | 启动一个线程就是把一个函数传入并创建Thread实例,然后调用start()开始执行:
12 | '''
13 | import time, threading
14 | # 新线程执行的代码:
15 | def loop():
16 | print('thread %s is running...' % threading.current_thread().name)
17 | n = 0
18 | while n < 5:
19 | n = n + 1
20 | print('thread %s >>> %s' % (threading.current_thread().name, n))
21 | time.sleep(1)
22 | print('thread %s ended.' % threading.current_thread().name)
23 |
24 | print('thread %s is running...' % threading.current_thread().name)
25 | t = threading.Thread(target=loop, name='LoopThread')
26 | t.start()
27 | t.join()
28 | print('thread %s ended.' % threading.current_thread().name)
29 | '''
30 | 由于任何进程默认就会启动一个线程,我们把该线程称为主线程,主线程又可以启动新的线程,Python的threading模块有个current_thread()函数,它永远返回当前线程的实例。主线程实例的名字叫MainThread,子线程的名字在创建时指定,我们用LoopThread命名子线程。名字仅仅在打印时用来显示,完全没有其他意义,如果不起名字Python就自动给线程命名为Thread-1,Thread-2……
31 | '''
32 |
33 | '''
34 | Lock
35 | '''
36 | '''
37 | 多线程和多进程最大的不同在于,多进程中,同一个变量,各自有一份拷贝存在于每个进程中,互不影响,而多线程中,所有变量都由所有线程共享,所以,任何一个变量都可以被任何一个线程修改,因此,线程之间共享数据最大的危险在于多个线程同时改一个变量,把内容给改乱了。
38 |
39 | 来看看多个线程同时操作一个变量怎么把内容给改乱了:
40 | '''
41 | import time, threading
42 |
43 | # 假定这是你的银行存款:
44 | balance = 0
45 |
46 | def change_it(n):
47 | # 先存后取,结果应该为0:
48 | global balance
49 | balance = balance + n
50 | balance = balance - n
51 |
52 | def run_thread(n):
53 | for i in range(100000):
54 | change_it(n)
55 |
56 | t1 = threading.Thread(target=run_thread, args=(5,))
57 | t2 = threading.Thread(target=run_thread, args=(8,))
58 | t1.start()
59 | t2.start()
60 | t1.join()
61 | t2.join()
62 | print(balance)
63 | '''
64 | 究其原因,是因为修改balance需要多条语句,而执行这几条语句时,线程可能中断,从而导致多个线程把同一个对象的内容改乱了。
65 |
66 | 两个线程同时一存一取,就可能导致余额不对,你肯定不希望你的银行存款莫名其妙地变成了负数,所以,我们必须确保一个线程在修改balance的时候,别的线程一定不能改。
67 |
68 | 如果我们要确保balance计算正确,就要给change_it()上一把锁,当某个线程开始执行change_it()时,我们说,该线程因为获得了锁,因此其他线程不能同时执行change_it(),只能等待,直到锁被释放后,获得该锁以后才能改。由于锁只有一个,无论多少线程,同一时刻最多只有一个线程持有该锁,所以,不会造成修改的冲突。创建一个锁就是通过threading.Lock()来实现:
69 | '''
70 | balance = 0
71 | lock = threading.Lock()
72 |
73 | def run_thread(n):
74 | for i in range(100000):
75 | # 先要获取锁:
76 | lock.acquire()
77 | try:
78 | # 放心地改吧:
79 | change_it(n)
80 | finally:
81 | # 改完了一定要释放锁:
82 | lock.release()
83 | '''
84 | 当多个线程同时执行lock.acquire()时,只有一个线程能成功地获取锁,然后继续执行代码,其他线程就继续等待直到获得锁为止。
85 |
86 | 获得锁的线程用完后一定要释放锁,否则那些苦苦等待锁的线程将永远等待下去,成为死线程。所以我们用try...finally来确保锁一定会被释放。
87 |
88 | 锁的好处就是确保了某段关键代码只能由一个线程从头到尾完整地执行,坏处当然也很多,首先是阻止了多线程并发执行,包含锁的某段代码实际上只能以单线程模式执行,效率就大大地下降了。其次,由于可以存在多个锁,不同的线程持有不同的锁,并试图获取对方持有的锁时,可能会造成死锁,导致多个线程全部挂起,既不能执行,也无法结束,只能靠操作系统强制终止。
89 | '''
90 |
91 |
92 | '''
93 | 多核CPU
94 | '''
95 | '''
96 | 如果你不幸拥有一个多核CPU,你肯定在想,多核应该可以同时执行多个线程。
97 |
98 | 如果写一个死循环的话,会出现什么情况呢?
99 |
100 | 打开Mac OS X的Activity Monitor,或者Windows的Task Manager,都可以监控某个进程的CPU使用率。
101 |
102 | 我们可以监控到一个死循环线程会100%占用一个CPU。
103 |
104 | 如果有两个死循环线程,在多核CPU中,可以监控到会占用200%的CPU,也就是占用两个CPU核心。
105 |
106 | 要想把N核CPU的核心全部跑满,就必须启动N个死循环线程。
107 |
108 | 试试用Python写个死循环:
109 | '''
110 | import threading, multiprocessing
111 |
112 | def loop():
113 | x = 0
114 | while True:
115 | x = x ^ 1
116 |
117 | for i in range(multiprocessing.cpu_count()):
118 | t = threading.Thread(target=loop)
119 | t.start()
120 |
121 | '''
122 | 启动与CPU核心数量相同的N个线程,在4核CPU上可以监控到CPU占用率仅有102%,也就是仅使用了一核。
123 |
124 | 但是用C、C++或Java来改写相同的死循环,直接可以把全部核心跑满,4核就跑到400%,8核就跑到800%,为什么Python不行呢?
125 |
126 | 因为Python的线程虽然是真正的线程,但解释器执行代码时,有一个GIL锁:Global Interpreter Lock,任何Python线程执行前,必须先获得GIL锁,然后,每执行100条字节码,解释器就自动释放GIL锁,让别的线程有机会执行。这个GIL全局锁实际上把所有线程的执行代码都给上了锁,所以,多线程在Python中只能交替执行,即使100个线程跑在100核CPU上,也只能用到1个核。
127 |
128 | GIL是Python解释器设计的历史遗留问题,通常我们用的解释器是官方实现的CPython,要真正利用多核,除非重写一个不带GIL的解释器。
129 |
130 | 所以,在Python中,可以使用多线程,但不要指望能有效利用多核。如果一定要通过多线程利用多核,那只能通过C扩展来实现,不过这样就失去了Python简单易用的特点。
131 |
132 | 不过,也不用过于担心,Python虽然不能利用多线程实现多核任务,但可以通过多进程实现多核任务。多个Python进程有各自独立的GIL锁,互不影响。
133 | '''
134 |
135 |
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/6.进程与线程/3ThreadLocal.py:
--------------------------------------------------------------------------------
1 | '''
2 | 在多线程环境下,每个线程都有自己的数据。一个线程使用自己的局部变量比使用全局变量好,因为局部变量只有线程自己能看见,不会影响其他线程,而全局变量的修改必须加锁。
3 |
4 | 但是局部变量也有问题,就是在函数调用的时候,传递起来很麻烦:
5 | '''
6 | # def process_student(name):
7 | # std = Student(name)
8 | # # std是局部变量,但是每个函数都要用它,因此必须传进去:
9 | # do_task_1(std)
10 | # do_task_2(std)
11 | #
12 | # def do_task_1(std):
13 | # do_subtask_1(std)
14 | # do_subtask_2(std)
15 | #
16 | # def do_task_2(std):
17 | # do_subtask_2(std)
18 | # do_subtask_2(std)
19 | '''
20 | 每个函数一层一层调用都这么传参数那还得了?用全局变量?也不行,因为每个线程处理不同的Student对象,不能共享。
21 | 如果用一个全局dict存放所有的Student对象,然后以thread自身作为key获得线程对应的Student对象如何?
22 | '''
23 | import threading
24 | global_dict={}
25 | def std_thread(name):
26 | std = Student(name)
27 | # 把std放到全局变量global_dict中:
28 | global_dict[threading.current_thread()] = std
29 | do_task_1()
30 | do_task_2()
31 |
32 | def do_task_1():
33 | # 不传入std,而是根据当前线程查找:
34 | std = global_dict[threading.current_thread()]
35 |
36 | def do_task_2():
37 | # 任何函数都可以查找出当前线程的std变量:
38 | std = global_dict[threading.current_thread()]
39 |
40 | '''
41 | 这种方式理论上是可行的,它最大的优点是消除了std对象在每层函数中的传递问题,但是,每个函数获取std的代码有点丑。
42 | 有没有更简单的方式?
43 | ThreadLocal应运而生,不用查找dict,ThreadLocal帮你自动做这件事:
44 | '''
45 | import threading
46 | # 创建全局ThreadLocal对象:
47 | local_school = threading.local()
48 |
49 | def process_student():
50 | # 获取当前线程关联的student:
51 | std = local_school.student
52 | print('Hello, %s (in %s)' % (std, threading.current_thread().name))
53 |
54 | def process_thread(name):
55 | # 绑定ThreadLocal的student:
56 | local_school.student = name
57 | process_student()
58 |
59 | t1 = threading.Thread(target= process_thread, args=('Alice',), name='Thread-A')
60 | t2 = threading.Thread(target= process_thread, args=('Bob',), name='Thread-B')
61 | t1.start()
62 | t2.start()
63 | t1.join()
64 | t2.join()
65 |
66 | '''
67 | 全局变量local_school就是一个ThreadLocal对象,每个Thread对它都可以读写student属性,但互不影响。你可以把local_school看成全局变量,但每个属性如local_school.student都是线程的局部变量,可以任意读写而互不干扰,也不用管理锁的问题,ThreadLocal内部会处理。
68 |
69 | 可以理解为全局变量local_school是一个dict,不但可以用local_school.student,还可以绑定其他变量,如local_school.teacher等等。
70 |
71 | ThreadLocal最常用的地方就是为每个线程绑定一个数据库连接,HTTP请求,用户身份信息等,这样一个线程的所有调用到的处理函数都可以非常方便地访问这些资源。
72 | '''
73 |
74 |
75 | '''进程VS线程'''
76 | '''计算密集型 vs. IO密集型'''
77 | '''
78 | 是否采用多任务的第二个考虑是任务的类型。我们可以把任务分为计算密集型和IO密集型。
79 | 这种计算密集型任务虽然也可以用多任务完成,但是任务越多,花在任务切换的时间就越多,CPU执行任务的效率就越低,所以,要最高效地利用CPU,计算密集型任务同时进行的数量应当等于CPU的核心数。
80 | 计算密集型任务由于主要消耗CPU资源,因此,代码运行效率至关重要。Python这样的脚本语言运行效率很低,完全不适合计算密集型任务。对于计算密集型任务,最好用C语言编写。
81 | '''
82 | '''
83 | 第二种任务的类型是IO密集型,涉及到网络、磁盘IO的任务都是IO密集型任务,这类任务的特点是CPU消耗很少,任务的大部分时间都在等待IO操作完成(因为IO的速度远远低于CPU和内存的速度)。对于IO密集型任务,任务越多,CPU效率越高,但也有一个限度。常见的大部分任务都是IO密集型任务,比如Web应用。
84 |
85 | IO密集型任务执行期间,99%的时间都花在IO上,花在CPU上的时间很少,因此,用运行速度极快的C语言替换用Python这样运行速度极低的脚本语言,完全无法提升运行效率。对于IO密集型任务,最合适的语言就是开发效率最高(代码量最少)的语言,脚本语言是首选,C语言最差。
86 | '''
87 |
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/6.进程与线程/4分布式进程.py:
--------------------------------------------------------------------------------
1 | '''
2 | 在Thread和Process中,应当优选Process,因为Process更稳定,而且,Process可以分布到多台机器上,而Thread最多只能分布到同一台机器的多个CPU上。
3 | Python的multiprocessing模块不但支持多进程,其中managers子模块还支持把多进程分布到多台机器上。一个服务进程可以作为调度者,将任务分布到其他多个进程中,依靠网络通信。由于managers模块封装很好,不必了解网络通信的细节,就可以很容易地编写分布式多进程程序。
4 | 举个例子:如果我们已经有一个通过Queue通信的多进程程序在同一台机器上运行,现在,由于处理任务的进程任务繁重,希望把发送任务的进程和处理任务的进程分布到两台机器上。怎么用分布式进程实现?
5 |
6 | 原有的Queue可以继续使用,但是,通过managers模块把Queue通过网络暴露出去,就可以让其他机器的进程访问Queue了。
7 |
8 | 我们先看服务进程,服务进程负责启动Queue,把Queue注册到网络上,然后往Queue里面写入任务:
9 | '''
10 | # task_master.py
11 | import random, time, queue
12 | from multiprocessing.managers import BaseManager
13 |
14 | # 发送任务的队列:
15 | task_queue = queue.Queue()
16 | # 接收结果的队列:
17 | result_queue = queue.Queue()
18 |
19 | # 从BaseManager继承的QueueManager:
20 | class QueueManager(BaseManager):
21 | pass
22 |
23 | # 把两个Queue都注册到网络上, callable参数关联了Queue对象:
24 | QueueManager.register('get_task_queue', callable=lambda: task_queue)
25 | QueueManager.register('get_result_queue', callable=lambda: result_queue)
26 | # 绑定端口5000, 设置验证码'abc':
27 | manager = QueueManager(address=('', 5000), authkey=b'abc')
28 | # 启动Queue:
29 | manager.start()
30 | # 获得通过网络访问的Queue对象:
31 | task = manager.get_task_queue()
32 | result = manager.get_result_queue()
33 | # 放几个任务进去:
34 | for i in range(10):
35 | n = random.randint(0, 10000)
36 | print('Put task %d...' % n)
37 | task.put(n)
38 | # 从result队列读取结果:
39 | print('Try get results...')
40 | for i in range(10):
41 | r = result.get(timeout=10)
42 | print('Result: %s' % r)
43 | # 关闭:
44 | manager.shutdown()
45 | print('master exit.')
46 | '''
47 | 请注意,当我们在一台机器上写多进程程序时,创建的Queue可以直接拿来用,但是,在分布式多进程环境下,添加任务到Queue不可以直接对原始的task_queue进行操作,那样就绕过了QueueManager的封装,必须通过manager.get_task_queue()获得的Queue接口添加。
48 |
49 | 然后,在另一台机器上启动任务进程(本机上启动也可以):
50 | '''
51 | # task_worker.py
52 | import time, sys, queue
53 | from multiprocessing.managers import BaseManager
54 |
55 | # 创建类似的QueueManager:
56 | class QueueManager(BaseManager):
57 | pass
58 |
59 | # 由于这个QueueManager只从网络上获取Queue,所以注册时只提供名字:
60 | QueueManager.register('get_task_queue')
61 | QueueManager.register('get_result_queue')
62 |
63 | # 连接到服务器,也就是运行task_master.py的机器:
64 | server_addr = '127.0.0.1'
65 | print('Connect to server %s...' % server_addr)
66 | # 端口和验证码注意保持与task_master.py设置的完全一致:
67 | m = QueueManager(address=(server_addr, 5000), authkey=b'abc')
68 | # 从网络连接:
69 | m.connect()
70 | # 获取Queue的对象:
71 | task = m.get_task_queue()
72 | result = m.get_result_queue()
73 | # 从task队列取任务,并把结果写入result队列:
74 | for i in range(10):
75 | try:
76 | n = task.get(timeout=1)
77 | print('run task %d * %d...' % (n, n))
78 | r = '%d * %d = %d' % (n, n, n*n)
79 | time.sleep(1)
80 | result.put(r)
81 | except queue.Empty:
82 | print('task queue is empty.')
83 | # 处理结束:
84 | print('worker exit.')
85 |
86 | # task_master.py进程发送完任务后,开始等待result队列的结果。现在启动task_worker.py进程:
87 |
88 | # task_worker.py进程结束,在task_master.py进程中会继续打印出结果:
89 |
90 | '''
91 | 这个简单的Master/Worker模型有什么用?其实这就是一个简单但真正的分布式计算,把代码稍加改造,启动多个worker,就可以把任务分布到几台甚至几十台机器上,比如把计算n*n的代码换成发送邮件,就实现了邮件队列的异步发送。
92 |
93 | Queue对象存储在哪?注意到task_worker.py中根本没有创建Queue的代码,所以,Queue对象存储在task_master.py进程中:
94 | '''
95 | '''
96 | 而Queue之所以能通过网络访问,就是通过QueueManager实现的。由于QueueManager管理的不止一个Queue,所以,要给每个Queue的网络调用接口起个名字,比如get_task_queue。
97 |
98 | authkey有什么用?这是为了保证两台机器正常通信,不被其他机器恶意干扰。如果task_worker.py的authkey和task_master.py的authkey不一致,肯定连接不上。
99 | '''
100 |
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/7.正则表达式/正则表达式.py:
--------------------------------------------------------------------------------
1 | '''
2 | 字符串是编程时涉及到的最多的一种数据结构,对字符串进行操作的需求几乎无处不在。比如判断一个字符串是否是合法的Email地址,虽然可以编程提取@前后的子串,再分别判断是否是单词和域名,但这样做不但麻烦,而且代码难以复用。
3 |
4 | 正则表达式是一种用来匹配字符串的强有力的武器。它的设计思想是用一种描述性的语言来给字符串定义一个规则,凡是符合规则的字符串,我们就认为它“匹配”了,否则,该字符串就是不合法的。
5 |
6 | 所以我们判断一个字符串是否是合法的Email的方法是:
7 |
8 | 创建一个匹配Email的正则表达式;
9 |
10 | 用该正则表达式去匹配用户的输入来判断是否合法。
11 |
12 | 因为正则表达式也是用字符串表示的,所以,我们要首先了解如何用字符来描述字符。
13 |
14 | 在正则表达式中,如果直接给出字符,就是精确匹配。用\d可以匹配一个数字,\w可以匹配一个字母或数字,所以:
15 |
16 | '00\d'可以匹配'007',但无法匹配'00A';
17 |
18 | '\d\d\d'可以匹配'010';
19 |
20 | '\w\w\d'可以匹配'py3';
21 |
22 | .可以匹配任意字符,所以:
23 |
24 | 'py.'可以匹配'pyc'、'pyo'、'py!'等等。
25 | 要匹配变长的字符,在正则表达式中,用*表示任意个字符(包括0个),用+表示至少一个字符,用?表示0个或1个字符,用{n}表示n个字符,用{n,m}表示n-m个字符:
26 |
27 | 来看一个复杂的例子:\d{3}\s+\d{3,8}。
28 |
29 | 我们来从左到右解读一下:
30 |
31 | \d{3}表示匹配3个数字,例如'010';
32 |
33 | \s可以匹配一个空格(也包括Tab等空白符),所以\s+表示至少有一个空格,例如匹配' ',' '等;
34 |
35 | \d{3,8}表示3-8个数字,例如'1234567'。
36 |
37 | 综合起来,上面的正则表达式可以匹配以任意个空格隔开的带区号的电话号码。
38 |
39 | 如果要匹配'010-12345'这样的号码呢?由于'-'是特殊字符,在正则表达式中,要用'\'转义,所以,上面的正则是\d{3}\-\d{3,8}。
40 |
41 | 但是,仍然无法匹配'010 - 12345',因为带有空格。所以我们需要更复杂的匹配方式。
42 | '''
43 |
44 |
45 | # 进阶
46 | '''
47 | 要做更精确地匹配,可以用[]表示范围,比如:
48 |
49 | [0-9a-zA-Z\_]可以匹配一个数字、字母或者下划线;
50 |
51 | [0-9a-zA-Z\_]+可以匹配至少由一个数字、字母或者下划线组成的字符串,比如'a100','0_Z','Py3000'等等;
52 |
53 | [a-zA-Z\_][0-9a-zA-Z\_]*可以匹配由字母或下划线开头,后接任意个由一个数字、字母或者下划线组成的字符串,也就是Python合法的变量;
54 |
55 | [a-zA-Z\_][0-9a-zA-Z\_]{0, 19}更精确地限制了变量的长度是1-20个字符(前面1个字符+后面最多19个字符)。
56 |
57 | A|B可以匹配A或B,所以(P|p)ython可以匹配'Python'或者'python'。
58 |
59 | ^表示行的开头,^\d表示必须以数字开头。
60 |
61 | $表示行的结束,\d$表示必须以数字结束。
62 |
63 | 你可能注意到了,py也可以匹配'python',但是加上^py$就变成了整行匹配,就只能匹配'py'了。
64 | '''
65 |
66 | # re模块
67 | '''
68 | 有了准备知识,我们就可以在Python中使用正则表达式了。Python提供re模块,包含所有正则表达式的功能。由于Python的字符串本身也用\转义,所以要特别注意:
69 | '''
70 | s = 'ABC\\-001' # Python的字符串
71 | # 对应的正则表达式字符串变成:
72 | # 'ABC\-001'
73 | '''因此我们强烈建议使用Python的r前缀,就不用考虑转义的问题了:'''
74 | s = r'ABC\-001' # Python的字符串
75 | # 对应的正则表达式字符串不变:
76 | # 'ABC\-001'
77 | # match()方法判断是否匹配,如果匹配成功,返回一个Match对象,否则返回None。常见的判断方法就是:
78 | import re
79 | test = '用户输入的字符串'
80 | if re.match(r'正则表达式', test):
81 | print('ok')
82 | else:
83 | print('failed')
84 |
85 | '''切分字符串'''
86 | # 用正则表达式切分字符串比用固定的字符更灵活,请看正常的切分代码:
87 | re.split(r'\s+', 'a b c')
88 | re.split(r'[\s\,]+', 'a,b, c d')
89 | re.split(r'[\s\,\;]+', 'a,b;; c d')
90 |
91 | # 分组
92 | '''
93 | 除了简单地判断是否匹配之外,正则表达式还有提取子串的强大功能。用()表示的就是要提取的分组(Group)。比如:
94 | ^(\d{3})-(\d{3,8})$分别定义了两个组,可以直接从匹配的字符串中提取出区号和本地号码:
95 | '''
96 | m = re.match(r'^(\d{3})-(\d{3,8})$', '010-12345')
97 | m.groups()
98 | m.group(1)
99 | '''
100 | 如果正则表达式中定义了组,就可以在Match对象上用group()方法提取出子串来。
101 | 注意到group(0)永远是原始字符串,group(1)、group(2)……表示第1、2、……个子串。
102 | 提取子串非常有用。来看一个更凶残的例子:
103 | '''
104 | t = '19:05:30'
105 | m = re.match(r'^(0[0-9]|1[0-9]|2[0-3]|[0-9])\:(0[0-9]|1[0-9]|2[0-9]|3[0-9]|4[0-9]|5[0-9]|[0-9])\:(0[0-9]|1[0-9]|2[0-9]|3[0-9]|4[0-9]|5[0-9]|[0-9])$', t)
106 | m.groups()
107 |
108 | # 贪婪匹配
109 | re.match(r'^(\d+)(0*)$', '102300').groups()
110 | # 非贪婪匹配
111 | re.match(r'^(\d+?)(0*)$', '102300').groups()
112 |
113 | # 编译
114 | '''
115 | 当我们在Python中使用正则表达式时,re模块内部会干两件事情:
116 |
117 | 编译正则表达式,如果正则表达式的字符串本身不合法,会报错;
118 |
119 | 用编译后的正则表达式去匹配字符串。
120 |
121 | 如果一个正则表达式要重复使用几千次,出于效率的考虑,我们可以预编译该正则表达式,接下来重复使用时就不需要编译这个步骤了,直接匹配:
122 | '''
123 | re_telephone = re.compile(r'^(\d{3})-(\d{3,8})$')
124 | re_telephone.match('010-12345').groups()
125 | re_telephone.match('010-8086').groups()
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/8.常用内建模块/10XML.py:
--------------------------------------------------------------------------------
1 | # XML虽然比JSON复杂,在Web中应用也不如以前多了,不过仍有很多地方在用,所以,有必要了解如何操作XML。
2 | '''
3 | DOM vs SAX
4 | 操作XML有两种方法:DOM和SAX。DOM会把整个XML读入内存,解析为树,因此占用内存大,解析慢,优点是可以任意遍历树的节点。SAX是流模式,边读边解析,占用内存小,解析快,缺点是我们需要自己处理事件。
5 |
6 | 正常情况下,优先考虑SAX,因为DOM实在太占内存。
7 |
8 | 在Python中使用SAX解析XML非常简洁,通常我们关心的事件是start_element,end_element和char_data,准备好这3个函数,然后就可以解析xml了。
9 |
10 | 举个例子,当SAX解析器读到一个节点时:
11 | '''
12 | # python
13 | '''
14 | 会产生3个事件:
15 |
16 | start_element事件,在读取时;
17 |
18 | char_data事件,在读取python时;
19 |
20 | end_element事件,在读取时。
21 |
22 | 用代码实验一下:
23 | '''
24 | from xml.parsers.expat import ParserCreate
25 |
26 | class DefaultSaxHandler(object):
27 | def start_element(self, name, attrs):
28 | print('sax:start_element: %s, attrs: %s' % (name, str(attrs)))
29 |
30 | def end_element(self, name):
31 | print('sax:end_element: %s' % name)
32 |
33 | def char_data(self, text):
34 | print('sax:char_data: %s' % text)
35 |
36 | xml = r'''
37 |
38 | - Python
39 | - Ruby
40 |
41 | '''
42 |
43 | handler = DefaultSaxHandler()
44 | parser = ParserCreate()
45 | parser.StartElementHandler = handler.start_element
46 | parser.EndElementHandler = handler.end_element
47 | parser.CharacterDataHandler = handler.char_data
48 | parser.Parse(xml)
49 |
50 | '''
51 | 需要注意的是读取一大段字符串时,CharacterDataHandler可能被多次调用,所以需要自己保存起来,在EndElementHandler里面再合并。
52 |
53 | 除了解析XML外,如何生成XML呢?99%的情况下需要生成的XML结构都是非常简单的,因此,最简单也是最有效的生成XML的方法是拼接字符串:
54 | '''
55 | def t():
56 | L = []
57 | L.append(r'')
58 | L.append(r'')
59 | L.append(r'some & data'.encode())
60 | L.append(r'')
61 | return ''.join(L)
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/8.常用内建模块/11HTMLParser.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | 如果我们要编写一个搜索引擎,第一步是用爬虫把目标网站的页面抓下来,第二步就是解析该HTML页面,看看里面的内容到底是新闻、图片还是视频。
4 |
5 | 假设第一步已经完成了,第二步应该如何解析HTML呢?
6 |
7 | HTML本质上是XML的子集,但是HTML的语法没有XML那么严格,所以不能用标准的DOM或SAX来解析HTML。
8 |
9 | 好在Python提供了HTMLParser来非常方便地解析HTML,只需简单几行代码
10 | '''
11 | from html.parser import HTMLParser
12 | from html.entities import name2codepoint
13 |
14 | class MyHTMLParser(HTMLParser):
15 |
16 | def handle_starttag(self, tag, attrs):
17 | print('<%s>' % tag)
18 |
19 | def handle_endtag(self, tag):
20 | print('%s>' % tag)
21 |
22 | def handle_startendtag(self, tag, attrs):
23 | print('<%s/>' % tag)
24 |
25 | def handle_data(self, data):
26 | print(data)
27 |
28 | def handle_comment(self, data):
29 | print('')
30 |
31 | def handle_entityref(self, name):
32 | print('&%s;' % name)
33 |
34 | def handle_charref(self, name):
35 | print('%s;' % name)
36 |
37 | parser = MyHTMLParser()
38 | parser.feed('''
39 |
40 |
41 |
42 | Some html HTML tutorial...
END
43 | ''')
44 |
45 |
46 | '''feed()方法可以多次调用,也就是不一定一次把整个HTML字符串都塞进去,可以一部分一部分塞进去。
47 |
48 | 特殊字符有两种,一种是英文表示的 ,一种是数字表示的Ӓ,这两种字符都可以通过Parser解析出来。
49 |
50 | '''
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/8.常用内建模块/1datetime.py:
--------------------------------------------------------------------------------
1 | '''datetime是Python处理日期和时间的标准库。'''
2 | # 获取当前日期和时间
3 | '''我们先看如何获取当前日期和时间:'''
4 | from datetime import datetime
5 | now = datetime.now() # 获取当前datetime
6 | print(now)
7 | print(type(now))
8 | '''
9 | 注意到datetime是模块,datetime模块还包含一个datetime类,通过from datetime import datetime导入的才是datetime这个类。
10 |
11 | 如果仅导入import datetime,则必须引用全名datetime.datetime。
12 |
13 | datetime.now()返回当前日期和时间,其类型是datetime。
14 | '''
15 |
16 | # 获取指定日期和时间
17 | '''要指定某个日期和时间,我们直接用参数构造一个datetime:'''
18 | dt = datetime(2015, 4, 19, 12, 20) # 用指定日期时间创建datetime
19 | print(dt)
20 |
21 | # datetime转换为timestamp
22 | '''在计算机中,时间实际上是用数字表示的。我们把1970年1月1日 00:00:00 UTC+00:00时区的时刻称为epoch time,记为0(1970年以前的时间timestamp为负数),当前时间就是相对于epoch time的秒数,称为timestamp。
23 |
24 | 你可以认为:'''
25 | '''timestamp = 0 = 1970-1-1 00:00:00 UTC+0:00'''
26 | '''对应的北京时间是:'''
27 | '''timestamp = 0 = 1970-1-1 08:00:00 UTC+8:00'''
28 | '''可见timestamp的值与时区毫无关系,因为timestamp一旦确定,其UTC时间就确定了,转换到任意时区的时间也是完全确定的,这就是为什么计算机存储的当前时间是以timestamp表示的,因为全球各地的计算机在任意时刻的timestamp都是完全相同的(假定时间已校准)。
29 |
30 | 把一个datetime类型转换为timestamp只需要简单调用timestamp()方法:'''
31 | dt = datetime(2015, 4, 19, 12, 20) # 用指定日期时间创建datetime
32 | dt.timestamp() # 把datetime转换为timestamp
33 | '''注意Python的timestamp是一个浮点数。如果有小数位,小数位表示毫秒数。
34 |
35 | 某些编程语言(如Java和JavaScript)的timestamp使用整数表示毫秒数,这种情况下只需要把timestamp除以1000就得到Python的浮点表示方法。'''
36 |
37 | # timestamp转换为datetime
38 | '''要把timestamp转换为datetime,使用datetime提供的fromtimestamp()方法:'''
39 | t = 1429417200.0
40 | print(datetime.fromtimestamp(t))
41 | '''注意到timestamp是一个浮点数,它没有时区的概念,而datetime是有时区的。上述转换是在timestamp和本地时间做转换。
42 |
43 | 本地时间是指当前操作系统设定的时区。例如北京时区是东8区,则本地时间:'''
44 | '''2015-04-19 12:20:00'''
45 | '''实际上就是UTC+8:00时区的时间:'''
46 | '''2015-04-19 12:20:00 UTC+8:00'''
47 | '''而此刻的格林威治标准时间与北京时间差了8小时,也就是UTC+0:00时区的时间应该是:'''
48 | '''2015-04-19 04:20:00 UTC+0:00'''
49 | # timestamp也可以直接被转换到UTC标准时区的时间:
50 | t = 1429417200.0
51 | print(datetime.fromtimestamp(t)) # 本地时间
52 | print(datetime.utcfromtimestamp(t)) # UTC时间
53 |
54 | # str转换为datetime
55 | '''很多时候,用户输入的日期和时间是字符串,要处理日期和时间,首先必须把str转换为datetime。转换方法是通过datetime.strptime()实现,需要一个日期和时间的格式化字符串:'''
56 | cday = datetime.strptime('2015-6-1 18:19:59', '%Y-%m-%d %H:%M:%S')
57 | print(cday)
58 | # datetime转换为str
59 | '''如果已经有了datetime对象,要把它格式化为字符串显示给用户,就需要转换为str,转换方法是通过strftime()实现的,同样需要一个日期和时间的格式化字符串:'''
60 | now = datetime.now()
61 | print(now.strftime('%a, %b %d %H:%M'))
62 |
63 | # datetime加减
64 | '''对日期和时间进行加减实际上就是把datetime往后或往前计算,得到新的datetime。加减可以直接用+和-运算符,不过需要导入timedelta这个类:'''
65 | from datetime import datetime, timedelta
66 | now = datetime.now()
67 | now
68 | now + timedelta(hours=10)
69 | now - timedelta(days=1)
70 | now + timedelta(days=2, hours=12)
71 |
72 | # 本地时间转换为UTC时间
73 | '''本地时间是指系统设定时区的时间,例如北京时间是UTC+8:00时区的时间,而UTC时间指UTC+0:00时区的时间。
74 |
75 | 一个datetime类型有一个时区属性tzinfo,但是默认为None,所以无法区分这个datetime到底是哪个时区,除非强行给datetime设置一个时区:'''
76 | from datetime import datetime, timedelta, timezone
77 | tz_utc_8 = timezone(timedelta(hours=8))
78 | now = datetime.now()
79 | now
80 | dt = now.replace(tzinfo=tz_utc_8) # 强制设置为UTC+8:00
81 | dt
82 |
83 | # 时区转换
84 | '''我们可以先通过utcnow()拿到当前的UTC时间,再转换为任意时区的时间:'''
85 | utc_dt = datetime.utcnow().replace(tzinfo=timezone.utc)
86 | print(utc_dt)
87 | bj_dt = utc_dt.astimezone(timezone(timedelta(hours=8)))
88 | print(bj_dt)
89 | tokyo_dt = utc_dt.astimezone(timezone(timedelta(hours=9)))
90 | print(tokyo_dt)
91 | tokyo_dt2 = bj_dt.astimezone(timezone(timedelta(hours=9)))
92 | print(tokyo_dt2)
93 | '''时区转换的关键在于,拿到一个datetime时,要获知其正确的时区,然后强制设置时区,作为基准时间。
94 |
95 | 利用带时区的datetime,通过astimezone()方法,可以转换到任意时区。
96 |
97 | 注:不是必须从UTC+0:00时区转换到其他时区,任何带时区的datetime都可以正确转换,例如上述bj_dt到tokyo_dt的转换。'''
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/8.常用内建模块/2colleections.py:
--------------------------------------------------------------------------------
1 | '''collections是Python内建的一个集合模块,提供了许多有用的集合类。'''
2 | # namedtuple
3 | '''但是,看到(1, 2),很难看出这个tuple是用来表示一个坐标的。
4 |
5 | 定义一个class又小题大做了,这时,namedtuple就派上了用场:'''
6 | from collections import namedtuple
7 | Point = namedtuple('Point', ['x', 'y'])
8 | p = Point(1, 2)
9 | p.x
10 | p.y
11 |
12 | '''namedtuple是一个函数,它用来创建一个自定义的tuple对象,并且规定了tuple元素的个数,并可以用属性而不是索引来引用tuple的某个元素。
13 |
14 | 这样一来,我们用namedtuple可以很方便地定义一种数据类型,它具备tuple的不变性,又可以根据属性来引用,使用十分方便。
15 |
16 | 可以验证创建的Point对象是tuple的一种子类:'''
17 | isinstance(p, Point)
18 | isinstance(p, tuple)
19 |
20 | # 类似的,如果要用坐标和半径表示一个圆,也可以用namedtuple定义:
21 | Circle = namedtuple('Circle', ['x', 'y', 'r'])
22 |
23 | # deque
24 | '''
25 | 使用list存储数据时,按索引访问元素很快,但是插入和删除元素就很慢了,因为list是线性存储,数据量大的时候,插入和删除效率很低。
26 | deque是为了高效实现插入和删除操作的双向列表,适合用于队列和栈:
27 | '''
28 | from collections import deque
29 | q = deque(['a', 'b', 'c'])
30 | q.append('x')
31 | q.appendleft('y')
32 | q
33 | '''deque除了实现list的append()和pop()外,还支持appendleft()和popleft(),这样就可以非常高效地往头部添加或删除元素。'''
34 |
35 | # defaultdict
36 | '''使用dict时,如果引用的Key不存在,就会抛出KeyError。如果希望key不存在时,返回一个默认值,就可以用defaultdict:'''
37 | from collections import defaultdict
38 | dd = defaultdict(lambda: 'N/A')
39 | dd['key1'] = 'abc'
40 | dd['key1'] # key1存在
41 | dd['key2'] # key2不存在,返回默认值
42 |
43 | # OrderedDict
44 | '''使用dict时,Key是无序的。在对dict做迭代时,我们无法确定Key的顺序。
45 |
46 | 如果要保持Key的顺序,可以用OrderedDict:'''
47 | from collections import OrderedDict
48 | d = dict([('a', 1), ('b', 2), ('c', 3)])
49 | d # dict的Key是无序的
50 | od = OrderedDict([('a', 1), ('b', 2), ('c', 3)])
51 | od
52 |
53 | '''注意,OrderedDict的Key会按照插入的顺序排列,不是Key本身排序:'''
54 | od = OrderedDict()
55 | od['z'] = 1
56 | od['y'] = 2
57 | od['x'] = 3
58 | list(od.keys()) # 按照插入的Key的顺序返回
59 |
60 | '''OrderedDict可以实现一个FIFO(先进先出)的dict,当容量超出限制时,先删除最早添加的Key:'''
61 | from collections import OrderedDict
62 |
63 | class LastUpdatedOrderedDict(OrderedDict):
64 |
65 | def __init__(self, capacity):
66 | super(LastUpdatedOrderedDict, self).__init__()
67 | self._capacity = capacity
68 |
69 | def __setitem__(self, key, value):
70 | containsKey = 1 if key in self else 0
71 | if len(self) - containsKey >= self._capacity:
72 | last = self.popitem(last=False)
73 | print('remove:', last)
74 | if containsKey:
75 | del self[key]
76 | print('set:', (key, value))
77 | else:
78 | print('add:', (key, value))
79 | OrderedDict.__setitem__(self, key, value)
80 |
81 |
82 | # ChainMap
83 | '''
84 | ChainMap可以把一组dict串起来并组成一个逻辑上的dict。ChainMap本身也是一个dict,但是查找的时候,会按照顺序在内部的dict依次查找。
85 |
86 | 什么时候使用ChainMap最合适?举个例子:应用程序往往都需要传入参数,参数可以通过命令行传入,可以通过环境变量传入,还可以有默认参数。我们可以用ChainMap实现参数的优先级查找,即先查命令行参数,如果没有传入,再查环境变量,如果没有,就使用默认参数。
87 |
88 | 下面的代码演示了如何查找user和color这两个参数:
89 | '''
90 | from collections import ChainMap
91 | import os, argparse
92 | # 构造缺省参数:
93 | defaults = {
94 | 'color': 'red',
95 | 'user': 'guest'
96 | }
97 | # 构造命令行参数:
98 | parser = argparse.ArgumentParser()
99 | parser.add_argument('-u', '--user')
100 | parser.add_argument('-c', '--color')
101 | namespace = parser.parse_args()
102 | command_line_args = {k: v for k, v in vars(namespace).items() if v}
103 | # 组合成ChainMap:
104 | combined = ChainMap(command_line_args, os.environ, defaults)
105 | # 打印参数:
106 | print('color=%s' % combined['color'])
107 | print('user=%s' % combined['user'])
108 |
109 | # Counter
110 | '''Counter是一个简单的计数器,例如,统计字符出现的个数:'''
111 | from collections import Counter
112 | c = Counter()
113 | for ch in 'programming':
114 | c[ch] = c[ch] + 1
115 | c
116 | '''Counter实际上也是dict的一个子类,上面的结果可以看出,字符'g'、'm'、'r'各出现了两次,其他字符各出现了一次。'''
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/8.常用内建模块/3base64.py:
--------------------------------------------------------------------------------
1 | # Base64是一种用64个字符来表示任意二进制数据的方法。
2 | #
3 | # 用记事本打开exe、jpg、pdf这些文件时,我们都会看到一大堆乱码,因为二进制文件包含很多无法显示和打印的字符,所以,如果要让记事本这样的文本处理软件能处理二进制数据,就需要一个二进制到字符串的转换方法。Base64是一种最常见的二进制编码方法。
4 | #
5 | # Base64的原理很简单,首先,准备一个包含64个字符的数组:
6 | '''
7 | ['A', 'B', 'C', ... 'a', 'b', 'c', ... '0', '1', ... '+', '/']
8 | '''
9 | # 然后,对二进制数据进行处理,每3个字节一组,一共是3x8=24bit,划为4组,每组正好6个bit:
10 | '''
11 | 这样我们得到4个数字作为索引,然后查表,获得相应的4个字符,就是编码后的字符串。
12 |
13 | 所以,Base64编码会把3字节的二进制数据编码为4字节的文本数据,长度增加33%,好处是编码后的文本数据可以在邮件正文、网页等直接显示。
14 |
15 | 如果要编码的二进制数据不是3的倍数,最后会剩下1个或2个字节怎么办?Base64用\x00字节在末尾补足后,再在编码的末尾加上1个或2个=号,表示补了多少字节,解码的时候,会自动去掉。
16 | '''
17 | # Python内置的base64可以直接进行base64的编解码:
18 | import base64
19 | base64.b64encode(b'binary\x00string')
20 | base64.b64decode(b'YmluYXJ5AHN0cmluZw==')
21 |
22 | '''由于标准的Base64编码后可能出现字符+和/,在URL中就不能直接作为参数,所以又有一种"url safe"的base64编码,其实就是把字符+和/分别变成-和_:'''
23 | base64.b64encode(b'i\xb7\x1d\xfb\xef\xff')
24 | base64.urlsafe_b64encode(b'i\xb7\x1d\xfb\xef\xff')
25 |
26 | '''还可以自己定义64个字符的排列顺序,这样就可以自定义Base64编码,不过,通常情况下完全没有必要。
27 |
28 | Base64是一种通过查表的编码方法,不能用于加密,即使使用自定义的编码表也不行。
29 |
30 | Base64适用于小段内容的编码,比如数字证书签名、Cookie的内容等。
31 |
32 | 由于=字符也可能出现在Base64编码中,但=用在URL、Cookie里面会造成歧义,所以,很多Base64编码后会把=去掉:'''
33 | '''
34 | # 标准Base64:
35 | 'abcd' -> 'YWJjZA=='
36 | # 自动去掉=:
37 | 'abcd' -> 'YWJjZA'
38 | '''
39 | # 去掉=后怎么解码呢?因为Base64是把3个字节变为4个字节,所以,Base64编码的长度永远是4的倍数,因此,需要加上=把Base64字符串的长度变为4的倍数,就可以正常解码了。
40 |
41 | # Base64是一种任意二进制到文本字符串的编码方法,常用于在URL、Cookie、网页中传输少量二进制数据。
42 | # https://www.jianshu.com/p/b649bfa1a320
43 |
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/8.常用内建模块/4struct.py:
--------------------------------------------------------------------------------
1 | # 准确地讲,Python没有专门处理字节的数据类型。但由于b'str'可以表示字节,所以,字节数组=二进制str。而在C语言中,我们可以很方便地用struct、union来处理字节,以及字节和int,float的转换。
2 | # # # # #
3 | # # # 在Python中,比方说要把一个32位无符号整数变成字节,也就是4个长度的bytes,你得配合位运算符这么写:
4 | n = 10240099
5 | b1 = (n & 0xff000000) >> 24
6 | b2 = (n & 0xff0000) >> 16
7 | b3 = (n & 0xff00) >> 8
8 | b4 = n & 0xff
9 | bs = bytes([b1, b2, b3, b4])
10 | bs
11 | # 1111 1111 0000 0000 0000 0000 0000 0000 = ff000000
12 | '''
13 | 非常麻烦。如果换成浮点数就无能为力了。
14 | 好在Python提供了一个struct模块来解决bytes和其他二进制数据类型的转换。
15 | struct的pack函数把任意数据类型变成bytes:
16 | '''
17 | # 调用bytes方法将字符串转成bytes对象
18 | b4 = bytes('我爱Python编程',encoding='utf-8')
19 | print(b4)
20 | b5 = "学习Python很有趣".encode('utf-8')
21 | print(b5)
22 |
23 |
24 | '''非常麻烦。如果换成浮点数就无能为力了。
25 |
26 | 好在Python提供了一个struct模块来解决bytes和其他二进制数据类型的转换。
27 |
28 | struct的pack函数把任意数据类型变成bytes:'''
29 | import struct
30 | struct.pack('>I', 10240099)
31 | # str(10240099)
32 | '''
33 | pack的第一个参数是处理指令,'>I'的意思是:
34 |
35 | >表示字节顺序是big-endian,也就是网络序,I表示4字节无符号整数。
36 |
37 | 后面的参数个数要和处理指令一致。
38 |
39 | unpack把bytes变成相应的数据类型:
40 | '''
41 | struct.unpack('>IH', b'\xf0\xf0\xf0\xf0\x80\x80')
42 |
43 | '''根据>IH的说明,后面的bytes依次变为I:4字节无符号整数和H:2字节无符号整数。
44 |
45 | 所以,尽管Python不适合编写底层操作字节流的代码,但在对性能要求不高的地方,利用struct就方便多了。
46 |
47 | struct模块定义的数据类型可以参考Python官方文档:'''
48 |
49 | # --------------------------------------------
50 | '''Windows的位图文件(.bmp)是一种非常简单的文件格式,我们来用struct分析一下。
51 |
52 | 首先找一个bmp文件,没有的话用“画图”画一个。
53 |
54 | 读入前30个字节来分析:'''
55 | s = b'\x42\x4d\x38\x8c\x0a\x00\x00\x00\x00\x00\x36\x00\x00\x00\x28\x00\x00\x00\x80\x02\x00\x00\x68\x01\x00\x00\x01\x00\x18\x00'
56 | '''BMP格式采用小端方式存储数据,文件头的结构按顺序如下:
57 |
58 | 两个字节:'BM'表示Windows位图,'BA'表示OS/2位图; 一个4字节整数:表示位图大小; 一个4字节整数:保留位,始终为0; 一个4字节整数:实际图像的偏移量; 一个4字节整数:Header的字节数; 一个4字节整数:图像宽度; 一个4字节整数:图像高度; 一个2字节整数:始终为1; 一个2字节整数:颜色数。
59 |
60 | 所以,组合起来用unpack读取:'''
61 | struct.unpack('" % name)
67 | yield
68 | print("%s>" % name)
69 |
70 | with tag("h1"):
71 | print("hello")
72 | print("world")
73 |
74 | '''
75 | 代码的执行顺序是:
76 |
77 | with语句首先执行yield之前的语句,因此打印出;
78 | yield调用会执行with语句内部的所有语句,因此打印出hello和world;
79 | 最后执行yield之后的语句,打印出
。
80 | 因此,@contextmanager让我们通过编写generator来简化上下文管理。
81 | '''
82 |
83 | '''
84 | @closing
85 | 如果一个对象没有实现上下文,我们就不能把它用于with语句。这个时候,可以用closing()来把该对象变为上下文对象。例如,用with语句使用urlopen():
86 | '''
87 | from contextlib import closing
88 | from urllib.request import urlopen
89 |
90 | with closing(urlopen('https://www.python.org')) as page:
91 | for line in page:
92 | print(line)
93 |
94 | # closing也是一个经过@contextmanager装饰的generator,这个generator编写起来其实非常简单:
95 | @contextmanager
96 | def closing(thing):
97 | try:
98 | yield thing
99 | finally:
100 | thing.close()
101 |
102 | '''
103 | 它的作用就是把任意对象变为上下文对象,并支持with语句。
104 | @contextlib还有一些其他decorator,便于我们编写更简洁的代码。
105 | '''
106 |
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/8.常用内建模块/9urllib.py:
--------------------------------------------------------------------------------
1 | # urllib提供了一系列用于操作URL的功能。
2 | '''
3 | Get
4 | urllib的request模块可以非常方便地抓取URL内容,也就是发送一个GET请求到指定的页面,然后返回HTTP的响应:
5 |
6 | 例如,对豆瓣的一个URLhttps://api.douban.com/v2/book/2129650进行抓取,并返回响应:
7 | '''
8 | import ssl
9 | ssl._create_default_https_context = ssl._create_unverified_context
10 | from urllib import request
11 |
12 | with request.urlopen('https://api.douban.com/v2/book/2129650') as f:
13 | data = f.read()
14 | print('Status:', f.status, f.reason)
15 | for k, v in f.getheaders():
16 | print('%s: %s' % (k, v))
17 | print('Data:', data.decode('utf-8'))
18 |
19 | # 可以看到HTTP响应的头和JSON数据:
20 |
21 | # 如果我们要想模拟浏览器发送GET请求,就需要使用Request对象,通过往Request对象添加HTTP头,我们就可以把请求伪装成浏览器。例如,模拟iPhone 6去请求豆瓣首页:
22 | from urllib import request
23 |
24 | req = request.Request('http://www.douban.com/')
25 | req.add_header('User-Agent', 'Mozilla/6.0 (iPhone; CPU iPhone OS 8_0 like Mac OS X) AppleWebKit/536.26 (KHTML, like Gecko) Version/8.0 Mobile/10A5376e Safari/8536.25')
26 | with request.urlopen(req) as f:
27 | print('Status:', f.status, f.reason)
28 | for k, v in f.getheaders():
29 | print('%s: %s' % (k, v))
30 | print('Data:', f.read().decode('utf-8'))
31 |
32 | '''
33 | Post
34 | 如果要以POST发送一个请求,只需要把参数data以bytes形式传入。
35 |
36 | 我们模拟一个微博登录,先读取登录的邮箱和口令,然后按照weibo.cn的登录页的格式以username=xxx&password=xxx的编码传入:
37 | '''
38 | from urllib import request, parse
39 |
40 | print('Login to weibo.cn...')
41 | email = input('Email:')
42 | passwd = input('Password:')
43 | login_data = parse.urlencode([
44 | ('username', email),
45 | ('password', passwd),
46 | ('entry', 'mweibo'),
47 | ('client_id', ''),
48 | ('savestate', '1'),
49 | ('ec', ''),
50 | ('pagerefer', 'https://passport.weibo.cn/signin/welcome?entry=mweibo&r=http%3A%2F%2Fm.weibo.cn%2F')
51 | ])
52 |
53 | req = request.Request('https://passport.weibo.cn/sso/login')
54 | req.add_header('Origin', 'https://passport.weibo.cn')
55 | req.add_header('User-Agent', 'Mozilla/6.0 (iPhone; CPU iPhone OS 8_0 like Mac OS X) AppleWebKit/536.26 (KHTML, like Gecko) Version/8.0 Mobile/10A5376e Safari/8536.25')
56 | req.add_header('Referer', 'https://passport.weibo.cn/signin/login?entry=mweibo&res=wel&wm=3349&r=http%3A%2F%2Fm.weibo.cn%2F')
57 |
58 | with request.urlopen(req, data=login_data.encode('utf-8')) as f:
59 | print('Status:', f.status, f.reason)
60 | for k, v in f.getheaders():
61 | print('%s: %s' % (k, v))
62 | print('Data:', f.read().decode('utf-8'))
63 |
64 | '''
65 | Handler
66 | 如果还需要更复杂的控制,比如通过一个Proxy去访问网站,我们需要利用ProxyHandler来处理,示例代码如下:
67 | '''
68 |
69 | proxy_handler = request.ProxyHandler({'http': 'http://www.example.com:3128/'})
70 | proxy_auth_handler = request.ProxyBasicAuthHandler()
71 | proxy_auth_handler.add_password('realm', 'host', 'username', 'password')
72 | opener = request.build_opener(proxy_handler, proxy_auth_handler)
73 | with opener.open('http://www.example.com/login.html') as f:
74 | pass
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/9.常用第三方模块/1Pillow.py:
--------------------------------------------------------------------------------
1 | '''
2 | 安装Pillow
3 | 如果安装了Anaconda,Pillow就已经可用了。否则,需要在命令行下通过pip安装:
4 | 操作图像
5 | 来看看最常见的图像缩放操作,只需三四行代码:
6 | '''
7 | from PIL import Image
8 |
9 | # 打开一个jpg图像文件,注意是当前路径:
10 | im = Image.open('test.jpg')
11 | # 获得图像尺寸:
12 | w, h = im.size
13 | print('Original image size: %sx%s' % (w, h))
14 | # 缩放到50%:
15 | im.thumbnail((w//2, h//2))
16 | print('Resize image to: %sx%s' % (w//2, h//2))
17 | # 把缩放后的图像用jpeg格式保存:
18 | im.save('thumbnail.jpg', 'jpeg')
19 |
20 | '''
21 | 其他功能如切片、旋转、滤镜、输出文字、调色板等一应俱全。
22 |
23 | 比如,模糊效果也只需几行代码:
24 | '''
25 | from PIL import Image, ImageFilter
26 |
27 | # 打开一个jpg图像文件,注意是当前路径:
28 | im = Image.open('test.png')
29 | # 应用模糊滤镜:
30 | im2 = im.filter(ImageFilter.BLUR)
31 | im2.save('blur.jpg', 'jpeg')
32 |
33 |
34 | # PIL的ImageDraw提供了一系列绘图方法,让我们可以直接绘图。比如要生成字母验证码图片:
35 | from PIL import Image, ImageDraw, ImageFont, ImageFilter
36 |
37 | import random
38 |
39 | # 随机字母:
40 | def rndChar():
41 | return chr(random.randint(65, 90))
42 |
43 | # 随机颜色1:
44 | def rndColor():
45 | return (random.randint(64, 255), random.randint(64, 255), random.randint(64, 255))
46 |
47 | # 随机颜色2:
48 | def rndColor2():
49 | return (random.randint(32, 127), random.randint(32, 127), random.randint(32, 127))
50 |
51 | # 240 x 60:
52 | width = 60 * 4
53 | height = 60
54 | image = Image.new('RGB', (width, height), (255, 255, 255))
55 | # 创建Font对象:
56 | font = ImageFont.truetype('Arial.ttf', 36)
57 | # 创建Draw对象:
58 | draw = ImageDraw.Draw(image)
59 | # 填充每个像素:
60 | for x in range(width):
61 | for y in range(height):
62 | draw.point((x, y), fill=rndColor())
63 | # 输出文字:
64 | for t in range(4):
65 | draw.text((60 * t + 10, 10), rndChar(), font=font, fill=rndColor2())
66 | # 模糊:
67 | image = image.filter(ImageFilter.BLUR)
68 | image.save('code.jpg', 'jpeg')
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/9.常用第三方模块/2requests.py:
--------------------------------------------------------------------------------
1 | '''
2 | 我们已经讲解了Python内置的urllib模块,用于访问网络资源。但是,它用起来比较麻烦,而且,缺少很多实用的高级功能。
3 |
4 | 更好的方案是使用requests。它是一个Python第三方库,处理URL资源特别方便。
5 | '''
6 |
7 |
8 | '''
9 | 如果遇到Permission denied安装失败,请加上sudo重试。
10 |
11 | 使用requests
12 | 要通过GET访问一个页面,只需要几行代码:
13 | '''
14 | import requests
15 | r = requests.get('https://www.douban.com/') # 豆瓣首页
16 | r.status_code
17 | r.text
18 |
19 | '''对于带参数的URL,传入一个dict作为params参数:'''
20 | r = requests.get('https://www.douban.com/search', params={'q': 'python', 'cat': '1001'})
21 | r.url # 实际请求的URL
22 |
23 | # requests自动检测编码,可以使用encoding属性查看:
24 | r.encoding
25 |
26 | # 无论响应是文本还是二进制内容,我们都可以用content属性获得bytes对象:
27 | r.content
28 |
29 | # requests的方便之处还在于,对于特定类型的响应,例如JSON,可以直接获取:
30 | r = requests.get('https://query.yahooapis.com/v1/public/yql?q=select%20*%20from%20weather.forecast%20where%20woeid%20%3D%202151330&format=json')
31 | r.json()
32 |
33 | # 需要传入HTTP Header时,我们传入一个dict作为headers参数:
34 | r = requests.get('https://www.douban.com/', headers={'User-Agent': 'Mozilla/5.0 (iPhone; CPU iPhone OS 11_0 like Mac OS X) AppleWebKit'})
35 | r.text
36 |
37 | # 要发送POST请求,只需要把get()方法变成post(),然后传入data参数作为POST请求的数据:
38 | r = requests.post('https://accounts.douban.com/login', data={'form_email': 'abc@example.com', 'form_password': '123456'})
39 |
40 | url = ''
41 | # requests默认使用application/x-www-form-urlencoded对POST数据编码。如果要传递JSON数据,可以直接传入json参数:
42 | params = {'key': 'value'}
43 | r = requests.post(url, json=params) # 内部自动序列化为JSON
44 |
45 | # 类似的,上传文件需要更复杂的编码格式,但是requests把它简化成files参数:
46 | upload_files = {'file': open('report.xls', 'rb')}
47 | r = requests.post(url, files=upload_files)
48 |
49 | '''
50 | 在读取文件时,注意务必使用'rb'即二进制模式读取,这样获取的bytes长度才是文件的长度。
51 |
52 | 把post()方法替换为put(),delete()等,就可以以PUT或DELETE方式请求资源。
53 |
54 | 除了能轻松获取响应内容外,requests对获取HTTP响应的其他信息也非常简单。例如,获取响应头:
55 | '''
56 | r.headers
57 | r.headers['Content-Type']
58 |
59 | # requests对Cookie做了特殊处理,使得我们不必解析Cookie就可以轻松获取指定的Cookie:
60 | r.cookies['ts']
61 | # 要在请求中传入Cookie,只需准备一个dict传入cookies参数:
62 | cs = {'token': '12345', 'status': 'working'}
63 | r = requests.get(url, cookies=cs)
64 |
65 | # 最后,要指定超时,传入以秒为单位的timeout参数:
66 | r = requests.get(url, timeout=2.5) # 2.5秒后超时
67 |
68 |
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/9.常用第三方模块/3chardet.py:
--------------------------------------------------------------------------------
1 | '''
2 | 字符串编码一直是令人非常头疼的问题,尤其是我们在处理一些不规范的第三方网页的时候。虽然Python提供了Unicode表示的str和bytes两种数据类型,并且可以通过encode()和decode()方法转换,但是,在不知道编码的情况下,对bytes做decode()不好做。
3 |
4 | 对于未知编码的bytes,要把它转换成str,需要先“猜测”编码。猜测的方式是先收集各种编码的特征字符,根据特征字符判断,就能有很大概率“猜对”。
5 |
6 | 当然,我们肯定不能从头自己写这个检测编码的功能,这样做费时费力。chardet这个第三方库正好就派上了用场。用它来检测编码,简单易用。
7 | '''
8 | # 使用chardet
9 | import chardet
10 | # 当我们拿到一个bytes时,就可以对其检测编码。用chardet检测编码,只需要一行代码:
11 | a = chardet.detect(b'Hello, world!')
12 |
13 | '''检测出的编码是ascii,注意到还有个confidence字段,表示检测的概率是1.0(即100%)。
14 |
15 | 我们来试试检测GBK编码的中文:'''
16 | data = '离离原上草,一岁一枯荣'.encode('gbk')
17 | chardet.detect(data)
18 | # 检测的编码是GB2312,注意到GBK是GB2312的超集,两者是同一种编码,检测正确的概率是74%,language字段指出的语言是'Chinese'。
19 | # 对UTF-8编码进行检测:
20 | data = '离离原上草,一岁一枯荣'.encode('utf-8')
21 | chardet.detect(data)
22 | # 日语
23 | data = '最新の主要ニュース'.encode('euc-jp')
24 | chardet.detect(data)
25 |
26 |
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/9.常用第三方模块/4psutil.py:
--------------------------------------------------------------------------------
1 | '''
2 | 用Python来编写脚本简化日常的运维工作是Python的一个重要用途。在Linux下,有许多系统命令可以让我们时刻监控系统运行的状态,如ps,top,free等等。要获取这些系统信息,Python可以通过subprocess模块调用并获取结果。但这样做显得很麻烦,尤其是要写很多解析代码。
3 |
4 | 在Python中获取系统信息的另一个好办法是使用psutil这个第三方模块。顾名思义,psutil = process and system utilities,它不仅可以通过一两行代码实现系统监控,还可以跨平台使用,支持Linux/UNIX/OSX/Windows等,是系统管理员和运维小伙伴不可或缺的必备模块。
5 | '''
6 |
7 | # 获取CPU信息
8 | # 我们先来获取CPU的信息:
9 | import psutil
10 | psutil.cpu_count() # CPU逻辑数量
11 | psutil.cpu_count(logical=False) # CPU物理核心
12 |
13 | # 统计CPU的用户/系统/空闲时间:
14 | psutil.cpu_times()
15 |
16 | # 再实现类似top命令的CPU使用率,每秒刷新一次,累计10次:
17 | for x in range(10):
18 | psutil.cpu_percent(interval=1, percpu=True)
19 |
20 | '''获取内存信息
21 | 使用psutil获取物理内存和交换内存信息,分别使用:'''
22 | psutil.virtual_memory()
23 | psutil.swap_memory()
24 |
25 | '''返回的是字节为单位的整数,可以看到,总内存大小是8589934592 = 8 GB,已用7201386496 = 6.7 GB,使用了66.6%。
26 |
27 | 而交换区大小是1073741824 = 1 GB。'''
28 |
29 | # 获取磁盘信息
30 | # 可以通过psutil获取磁盘分区、磁盘使用率和磁盘IO信息:
31 | psutil.disk_partitions() # 磁盘分区信息
32 | psutil.disk_usage('/') # 磁盘使用情况
33 | psutil.disk_io_counters() # 磁盘IO
34 | # 可以看到,磁盘'/'的总容量是998982549504 = 930 GB,使用了39.1%。文件格式是HFS,opts中包含rw表示可读写,journaled表示支持日志。
35 |
36 |
37 | '''
38 | 获取网络信息
39 | psutil可以获取网络接口和网络连接信息:
40 | '''
41 | psutil.net_io_counters() # 获取网络读写字节/包的个数
42 | psutil.net_if_addrs() # 获取网络接口信息
43 | psutil.net_if_stats() # 获取网络接口状态
44 |
45 | # 要获取当前网络连接信息,使用net_connections():
46 | psutil.net_connections()
47 |
48 | # 你可能会得到一个AccessDenied错误,原因是psutil获取信息也是要走系统接口,而获取网络连接信息需要root权限,这种情况下,可以退出Python交互环境,用sudo重新启动:
49 |
50 | '''
51 | 获取进程信息
52 | 通过psutil可以获取到所有进程的详细信息:
53 | '''
54 | psutil.pids() # 所有进程ID
55 | p = psutil.Process(3776) # 获取指定进程ID=3776,其实就是当前Python交互环境
56 | p.name() # 进程名称
57 | p.exe() # 进程exe路径
58 | p.cwd() # 进程工作目录
59 | p.cmdline() # 进程启动的命令行
60 | p.ppid() # 父进程ID
61 | p.parent() # 父进程
62 | p.children() # 子进程列表
63 | p.status() # 进程状态
64 | p.username() # 进程用户名
65 | p.create_time() # 进程创建时间
66 | p.terminal() # 进程终端
67 | p.cpu_times() # 进程使用的CPU时间
68 | p.memory_info() # 进程使用的内存
69 | p.open_files() # 进程打开的文件
70 | p.connections() # 进程相关网络连接
71 | p.num_threads() # 进程的线程数量
72 | p.threads() # 所有线程信息
73 | p.environ() # 进程环境变量
74 | p.terminate() # 结束进程
75 |
76 | '''
77 | 和获取网络连接类似,获取一个root用户的进程需要root权限,启动Python交互环境或者.py文件时,需要sudo权限。
78 | psutil还提供了一个test()函数,可以模拟出ps命令的效果:
79 | '''
80 |
--------------------------------------------------------------------------------
/python学习资料/PythonGrammer/9.常用第三方模块/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LancelotYe/TMChanQuant/d092f06409fff9dc118f6f32fca5d3610dfbac3d/python学习资料/PythonGrammer/9.常用第三方模块/test.jpg
--------------------------------------------------------------------------------
/python学习资料/Quant/BackTestSys/backtest.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 |
4 | # backtest.py
5 |
6 | from __future__ import print_function
7 |
8 | import datetime
9 | import pprint
10 | try:
11 | import Queue as queue
12 | except ImportError:
13 | import queue
14 | import time
15 |
16 |
17 | class Backtest(object):
18 | """
19 | Enscapsulates the settings and components for carrying out
20 | an event-driven backtest.
21 | """
22 |
23 | def __init__(
24 | self, csv_dir, symbol_list, initial_capital,
25 | heartbeat, start_date, data_handler,
26 | execution_handler, portfolio, strategy
27 | ):
28 | """
29 | Initialises the backtest.
30 |
31 | Parameters:
32 | csv_dir - The hard root to the CSV data directory.
33 | symbol_list - The list of symbol strings.
34 | intial_capital - The starting capital for the portfolio.
35 | heartbeat - Backtest "heartbeat" in seconds
36 | start_date - The start datetime of the strategy.
37 | data_handler - (Class) Handles the market data feed.
38 | execution_handler - (Class) Handles the orders/fills for trades.
39 | portfolio - (Class) Keeps track of portfolio current and prior positions.
40 | strategy - (Class) Generates signals based on market data.
41 | """
42 | self.csv_dir = csv_dir
43 | self.symbol_list = symbol_list
44 | self.initial_capital = initial_capital
45 | self.heartbeat = heartbeat
46 | self.start_date = start_date
47 |
48 | self.data_handler_cls = data_handler
49 | self.execution_handler_cls = execution_handler
50 | self.portfolio_cls = portfolio
51 | self.strategy_cls = strategy
52 |
53 | self.events = queue.Queue()
54 |
55 | self.signals = 0
56 | self.orders = 0
57 | self.fills = 0
58 | self.num_strats = 1
59 |
60 | self._generate_trading_instances()
61 |
62 | def _generate_trading_instances(self):
63 | """
64 | Generates the trading instance objects from
65 | their class types.
66 | """
67 | print(
68 | "Creating DataHandler, Strategy, Portfolio and ExecutionHandler"
69 | )
70 | self.data_handler = self.data_handler_cls(self.events, self.csv_dir, self.symbol_list)
71 | self.strategy = self.strategy_cls(self.data_handler, self.events)
72 | self.portfolio = self.portfolio_cls(self.data_handler, self.events, self.start_date,
73 | self.initial_capital)
74 | self.execution_handler = self.execution_handler_cls(self.events)
75 |
76 | def _run_backtest(self):
77 | """
78 | Executes the backtest.
79 | """
80 | i = 0
81 | while True:
82 | i += 1
83 | print(i)
84 | # Update the market bars
85 | if self.data_handler.continue_backtest == True:
86 | self.data_handler.update_bars()
87 | else:
88 | break
89 |
90 | # Handle the events
91 | while True:
92 | try:
93 | event = self.events.get(False)
94 | except queue.Empty:
95 | break
96 | else:
97 | if event is not None:
98 | if event.type == 'MARKET':
99 | self.strategy.calculate_signals(event)
100 | self.portfolio.update_timeindex(event)
101 |
102 | elif event.type == 'SIGNAL':
103 | self.signals += 1
104 | self.portfolio.update_signal(event)
105 |
106 | elif event.type == 'ORDER':
107 | self.orders += 1
108 | self.execution_handler.execute_order(event)
109 |
110 | elif event.type == 'FILL':
111 | self.fills += 1
112 | self.portfolio.update_fill(event)
113 |
114 | time.sleep(self.heartbeat)
115 |
116 | def _output_performance(self):
117 | """
118 | Outputs the strategy performance from the backtest.
119 | """
120 | self.portfolio.create_equity_curve_dataframe()
121 |
122 | print("Creating summary stats...")
123 | stats = self.portfolio.output_summary_stats()
124 |
125 | print("Creating equity curve...")
126 | print(self.portfolio.equity_curve.tail(10))
127 | pprint.pprint(stats)
128 |
129 | print("Signals: %s" % self.signals)
130 | print("Orders: %s" % self.orders)
131 | print("Fills: %s" % self.fills)
132 |
133 | def simulate_trading(self):
134 | """
135 | Simulates the backtest and outputs portfolio performance.
136 | """
137 | self._run_backtest()
138 | self._output_performance()
139 |
--------------------------------------------------------------------------------
/python学习资料/Quant/BackTestSys/event.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 |
4 | # event.py
5 |
6 | from __future__ import print_function
7 |
8 |
9 | class Event(object):
10 | """
11 | Event is base class providing an interface for all subsequent
12 | (inherited) events, that will trigger further events in the
13 | trading infrastructure.
14 | """
15 | pass
16 |
17 |
18 | class MarketEvent(Event):
19 | """
20 | Handles the event of receiving a new market update with
21 | corresponding bars.
22 | """
23 |
24 | def __init__(self):
25 | """
26 | Initialises the MarketEvent.
27 | """
28 | self.type = 'MARKET'
29 |
30 |
31 | class SignalEvent(Event):
32 | """
33 | Handles the event of sending a Signal from a Strategy object.
34 | This is received by a Portfolio object and acted upon.
35 | """
36 |
37 | def __init__(self, strategy_id, symbol, datetime, signal_type, strength):
38 | """
39 | Initialises the SignalEvent.
40 |
41 | Parameters:
42 | strategy_id - The unique ID of the strategy sending the signal.
43 | symbol - The ticker symbol, e.g. 'GOOG'.
44 | datetime - The timestamp at which the signal was generated.
45 | signal_type - 'LONG' or 'SHORT'.
46 | strength - An adjustment factor "suggestion" used to scale
47 | quantity at the portfolio level. Useful for pairs strategies.
48 | """
49 | self.strategy_id = strategy_id
50 | self.type = 'SIGNAL'
51 | self.symbol = symbol
52 | self.datetime = datetime
53 | self.signal_type = signal_type
54 | self.strength = strength
55 |
56 |
57 | class OrderEvent(Event):
58 | """
59 | Handles the event of sending an Order to an execution system.
60 | The order contains a symbol (e.g. GOOG), a type (market or limit),
61 | quantity and a direction.
62 | """
63 |
64 | def __init__(self, symbol, order_type, quantity, direction):
65 | """
66 | Initialises the order type, setting whether it is
67 | a Market order ('MKT') or Limit order ('LMT'), has
68 | a quantity (integral) and its direction ('BUY' or
69 | 'SELL').
70 |
71 | TODO: Must handle error checking here to obtain
72 | rational orders (i.e. no negative quantities etc).
73 |
74 | Parameters:
75 | symbol - The instrument to trade.
76 | order_type - 'MKT' or 'LMT' for Market or Limit.
77 | quantity - Non-negative integer for quantity.
78 | direction - 'BUY' or 'SELL' for long or short.
79 | """
80 | self.type = 'ORDER'
81 | self.symbol = symbol
82 | self.order_type = order_type
83 | self.quantity = quantity
84 | self.direction = direction
85 |
86 | def print_order(self):
87 | """
88 | Outputs the values within the Order.
89 | """
90 | print(
91 | "Order: Symbol=%s, Type=%s, Quantity=%s, Direction=%s" %
92 | (self.symbol, self.order_type, self.quantity, self.direction)
93 | )
94 |
95 |
96 | class FillEvent(Event):
97 | """
98 | Encapsulates the notion of a Filled Order, as returned
99 | from a brokerage. Stores the quantity of an instrument
100 | actually filled and at what price. In addition, stores
101 | the commission of the trade from the brokerage.
102 |
103 | TODO: Currently does not support filling positions at
104 | different prices. This will be simulated by averaging
105 | the cost.
106 | """
107 |
108 | def __init__(self, timeindex, symbol, exchange, quantity,
109 | direction, fill_cost, commission=None):
110 | """
111 | Initialises the FillEvent object. Sets the symbol, exchange,
112 | quantity, direction, cost of fill and an optional
113 | commission.
114 |
115 | If commission is not provided, the Fill object will
116 | calculate it based on the trade size and Interactive
117 | Brokers fees.
118 |
119 | Parameters:
120 | timeindex - The bar-resolution when the order was filled.
121 | symbol - The instrument which was filled.
122 | exchange - The exchange where the order was filled.
123 | quantity - The filled quantity.
124 | direction - The direction of fill ('BUY' or 'SELL')
125 | fill_cost - The holdings value in dollars.
126 | commission - An optional commission sent from IB.
127 | """
128 | self.type = 'FILL'
129 | self.timeindex = timeindex
130 | self.symbol = symbol
131 | self.exchange = exchange
132 | self.quantity = quantity
133 | self.direction = direction
134 | self.fill_cost = fill_cost
135 |
136 | # Calculate commission
137 | if commission is None:
138 | self.commission = self.calculate_ib_commission()
139 | else:
140 | self.commission = commission
141 |
142 | def calculate_ib_commission(self):
143 | """
144 | Calculates the fees of trading based on an Interactive
145 | Brokers fee structure for API, in USD.
146 |
147 | This does not include exchange or ECN fees.
148 |
149 | Based on "US API Directed Orders":
150 | https://www.interactivebrokers.com/en/index.php?f=commission&p=stocks2
151 | """
152 | full_cost = 1.3
153 | if self.quantity <= 500:
154 | full_cost = max(1.3, 0.013 * self.quantity)
155 | else: # Greater than 500
156 | full_cost = max(1.3, 0.008 * self.quantity)
157 | return full_cost
158 |
--------------------------------------------------------------------------------
/python学习资料/Quant/BackTestSys/execution.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 |
4 | # execution.py
5 |
6 | from __future__ import print_function
7 |
8 | from abc import ABCMeta, abstractmethod
9 | import datetime
10 | try:
11 | import Queue as queue
12 | except ImportError:
13 | import queue
14 |
15 | from event import FillEvent, OrderEvent
16 |
17 |
18 | class ExecutionHandler(object):
19 | """
20 | The ExecutionHandler abstract class handles the interaction
21 | between a set of order objects generated by a Portfolio and
22 | the ultimate set of Fill objects that actually occur in the
23 | market.
24 |
25 | The handlers can be used to subclass simulated brokerages
26 | or live brokerages, with identical interfaces. This allows
27 | strategies to be backtested in a very similar manner to the
28 | live trading engine.
29 | """
30 |
31 | __metaclass__ = ABCMeta
32 |
33 | @abstractmethod
34 | def execute_order(self, event):
35 | """
36 | Takes an Order event and executes it, producing
37 | a Fill event that gets placed onto the Events queue.
38 |
39 | Parameters:
40 | event - Contains an Event object with order information.
41 | """
42 | raise NotImplementedError("Should implement execute_order()")
43 |
44 |
45 | class SimulatedExecutionHandler(ExecutionHandler):
46 | """
47 | The simulated execution handler simply converts all order
48 | objects into their equivalent fill objects automatically
49 | without latency, slippage or fill-ratio issues.
50 |
51 | This allows a straightforward "first go" test of any strategy,
52 | before implementation with a more sophisticated execution
53 | handler.
54 | """
55 |
56 | def __init__(self, events):
57 | """
58 | Initialises the handler, setting the event queues
59 | up internally.
60 |
61 | Parameters:
62 | events - The Queue of Event objects.
63 | """
64 | self.events = events
65 |
66 | def execute_order(self, event):
67 | """
68 | Simply converts Order objects into Fill objects naively,
69 | i.e. without any latency, slippage or fill ratio problems.
70 |
71 | Parameters:
72 | event - Contains an Event object with order information.
73 | """
74 | if event.type == 'ORDER':
75 | fill_event = FillEvent(
76 | datetime.datetime.utcnow(), event.symbol,
77 | 'ARCA', event.quantity, event.direction, None
78 | )
79 | self.events.put(fill_event)
80 |
--------------------------------------------------------------------------------
/python学习资料/Quant/BackTestSys/mac.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import numpy as np
3 |
4 | from backtest import Backtest
5 | from data import HistoricCSVDataHandler
6 | from event import SignalEvent
7 | from execution import SimulatedExecutionHandler
8 | from portfolio import Portfolio
9 | from strategy import Strategy
10 |
11 |
12 | class MovingAverageCrossStrategy(Strategy):
13 | """
14 | Carries out a basic Moving Average Crossover strategy with a
15 | short/long simple weighted moving average. Default short/long
16 | windows are 100/400 periods respectively.
17 | """
18 |
19 | def __init__(self, bars, events, short_window=100, long_window=400):
20 | """
21 | Initialises the buy and hold strategy.
22 |
23 | Parameters:
24 | bars - The DataHandler object that provides bar information
25 | events - The Event Queue object.
26 | short_window - The short moving average lookback.
27 | long_window - The long moving average lookback.
28 | """
29 | self.bars = bars
30 | self.symbol_list = self.bars.symbol_list
31 | self.events = events
32 | self.short_window = short_window
33 | self.long_window = long_window
34 |
35 | # Set to True if a symbol is in the market
36 | self.bought = self._calculate_initial_bought()
37 |
38 | def _calculate_initial_bought(self):
39 | """
40 | Adds keys to the bought dictionary for all symbols
41 | and sets them to 'OUT'.
42 | """
43 | bought = {}
44 | for s in self.symbol_list:
45 | bought[s] = 'OUT'
46 | return bought
47 |
48 | def calculate_signals(self, event):
49 | """
50 | Generates a new set of signals based on the MAC
51 | SMA with the short window crossing the long window
52 | meaning a long entry and vice versa for a short entry.
53 |
54 | Parameters
55 | event - A MarketEvent object.
56 | """
57 | if event.type == 'MARKET':
58 | for symbol in self.symbol_list:
59 | bars = self.bars.get_latest_bars_values(symbol, "close", N=self.long_window)
60 |
61 | if bars is not None and bars != []:
62 | short_sma = np.mean(bars[-self.short_window:])
63 | long_sma = np.mean(bars[-self.long_window:])
64 |
65 | dt = self.bars.get_latest_bar_datetime(symbol)
66 | sig_dir = ""
67 | strength = 1.0
68 | strategy_id = 1
69 |
70 | if short_sma > long_sma and self.bought[symbol] == "OUT":
71 | sig_dir = 'LONG'
72 | signal = SignalEvent(strategy_id, symbol, dt, sig_dir, strength)
73 | self.events.put(signal)
74 | self.bought[symbol] = 'LONG'
75 |
76 | elif short_sma < long_sma and self.bought[symbol] == "LONG":
77 | sig_dir = 'EXIT'
78 | signal = SignalEvent(strategy_id, symbol, dt, sig_dir, strength)
79 | self.events.put(signal)
80 | self.bought[symbol] = 'OUT'
81 |
82 |
83 | if __name__ == "__main__":
84 | csv_dir = REPLACE_WITH_YOUR_CSV_DIR_HERE
85 | symbol_list = ['AAPL']
86 | initial_capital = 100000.0
87 | start_date = datetime.datetime(1990,1,1,0,0,0)
88 | heartbeat = 0.0
89 |
90 | backtest = Backtest(csv_dir,
91 | symbol_list,
92 | initial_capital,
93 | heartbeat,
94 | start_date,
95 | HistoricCSVDataHandler,
96 | SimulatedExecutionHandler,
97 | Portfolio,
98 | MovingAverageCrossStrategy)
99 |
100 | backtest.simulate_trading()
101 |
--------------------------------------------------------------------------------
/python学习资料/Quant/BackTestSys/performance.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 |
4 | # performance.py
5 |
6 | from __future__ import print_function
7 |
8 | import numpy as np
9 | import pandas as pd
10 |
11 |
12 | def create_sharpe_ratio(returns, periods=252):
13 | """
14 | Create the Sharpe ratio for the strategy, based on a
15 | benchmark of zero (i.e. no risk-free rate information).
16 |
17 | Parameters:
18 | returns - A pandas Series representing period percentage returns.
19 | periods - Daily (252), Hourly (252*6.5), Minutely(252*6.5*60) etc.
20 | """
21 | return np.sqrt(periods) * (np.mean(returns)) / np.std(returns)
22 |
23 |
24 | def create_drawdowns(pnl):
25 | """
26 | Calculate the largest peak-to-trough drawdown of the PnL curve
27 | as well as the duration of the drawdown. Requires that the
28 | pnl_returns is a pandas Series.
29 |
30 | Parameters:
31 | pnl - A pandas Series representing period percentage returns.
32 |
33 | Returns:
34 | drawdown, duration - Highest peak-to-trough drawdown and duration.
35 | """
36 |
37 | # Calculate the cumulative returns curve
38 | # and set up the High Water Mark
39 | hwm = [0]
40 |
41 | # Create the drawdown and duration series
42 | idx = pnl.index
43 | drawdown = pd.Series(index = idx)
44 | duration = pd.Series(index = idx)
45 |
46 | # Loop over the index range
47 | for t in range(1, len(idx)):
48 | hwm.append(max(hwm[t-1], pnl[t]))
49 | drawdown[t]= (hwm[t]-pnl[t])
50 | duration[t]= (0 if drawdown[t] == 0 else duration[t-1]+1)
51 | return drawdown, drawdown.max(), duration.max()
52 |
--------------------------------------------------------------------------------
/python学习资料/Quant/BackTestSys/strategy.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 |
4 | # strategy.py
5 |
6 | from __future__ import print_function
7 |
8 | from abc import ABCMeta, abstractmethod
9 | import datetime
10 | try:
11 | import Queue as queue
12 | except ImportError:
13 | import queue
14 |
15 | import numpy as np
16 | import pandas as pd
17 |
18 | from event import SignalEvent
19 |
20 |
21 | class Strategy(object):
22 | """
23 | Strategy is an abstract base class providing an interface for
24 | all subsequent (inherited) strategy handling objects.
25 |
26 | The goal of a (derived) Strategy object is to generate Signal
27 | objects for particular symbols based on the inputs of Bars
28 | (OHLCV) generated by a DataHandler object.
29 |
30 | This is designed to work both with historic and live data as
31 | the Strategy object is agnostic to where the data came from,
32 | since it obtains the bar tuples from a queue object.
33 | """
34 |
35 | __metaclass__ = ABCMeta
36 |
37 | @abstractmethod
38 | def calculate_signals(self):
39 | """
40 | Provides the mechanisms to calculate the list of signals.
41 | """
42 | raise NotImplementedError("Should implement calculate_signals()")
43 |
--------------------------------------------------------------------------------
/python学习资料/Quant/GetQuantData/ApiList.py:
--------------------------------------------------------------------------------
1 | list500CompaniesInWiki = "http://en.wikipedia.org/wiki/List_of_S%26P_500_companies"
2 |
3 | yahoo_url = "http://ichart.finance.yahoo.com/table.csv"
4 |
5 |
6 |
7 | #test
8 |
9 | import datetime
10 | import requests
11 | ticker = "AAPL"
12 | # start_date=(2000,1,1)
13 | start_date=(2019,8,1)
14 | end_date=datetime.date.today().timetuple()[0:3]
15 | ticker_tup = (
16 | ticker, start_date[1]-1, start_date[2],
17 | start_date[0], end_date[1]-1, end_date[2],
18 | end_date[0]
19 | )
20 | yahoo_url = yahoo_url
21 | yahoo_url += "?s=%s&a=%s&b=%s&c=%s&d=%s&e=%s&f=%s"
22 | yahoo_url = yahoo_url % ticker_tup
23 | try:
24 | yf_data = requests.get(yahoo_url).text.split("\n")[1:-1]
25 | prices = []
26 | for y in yf_data:
27 | p = y.strip().split(',')
28 | prices.append(
29 | (datetime.datetime.strptime(p[0], '%Y-%m-%d'),
30 | p[1], p[2], p[3], p[4], p[5], p[6])
31 | )
32 | except Exception as e:
33 | print("Could not download Yahoo data: %s" % e)
34 |
--------------------------------------------------------------------------------
/python学习资料/Quant/GetQuantData/DataParser.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | # 安装rqData
4 | # pip install --extra-index-url https://rquser:ricequant99@py.ricequant.com/simple/rqdatac
5 |
6 |
7 | import rqdatac as rq
8 |
9 | rq.init()
10 |
11 | rq.get_price('000001.XSHE','2018-3-23','2018-3-23','tick')
--------------------------------------------------------------------------------
/python学习资料/Quant/GetQuantData/MySqlTest.py:
--------------------------------------------------------------------------------
1 | import pymysql
2 |
3 |
4 | connection = pymysql.connect(host='localhost',
5 | port=3306,
6 | user='root',
7 | password='68466296aB',
8 | # db='demo',
9 | charset='utf8')
10 |
11 |
12 | cursor = connection.cursor()
13 | # 创建数据库
14 | effect_row = cursor.execute(
15 | '''CREATE DATABASE MysqlTestDB'''
16 | )
17 | connection.commit()
18 |
19 | # 创建数据表
20 | effect_row = cursor.execute(
21 | '''CREATE TABLE 'users' (
22 | 'name' varchar(32) NOT NULL,
23 | 'age' int(10) unsigned NOT NULL DEFAULT '0',
24 | PRIMARY KEY (‘name')
25 | ) ENGINE=InnoDB DEFAULT CHARSET=utf8'''
26 | )
27 |
28 | # 插入数据(元组或列表)
29 | effect_row = cursor.execute('INSERT INTO `users` (`name`, `age`) VALUES (%s, %s)', ('mary', 18))
30 |
31 | # 插入数据(字典)
32 | info = {'name': 'fake', 'age': 15}
33 | effect_row = cursor.execute('INSERT INTO `users` (`name`, `age`) VALUES (%(name)s, %(age)s)', info)
34 |
35 | connection.commit()
36 |
37 | ''' 批量插入'''
38 | # 获取游标
39 | cursor = connection.cursor()
40 |
41 | effect_row = cursor.executemany(
42 | 'INSERT INTO `users` (`name`, `age`) VALUES (%s, %s) ON DUPLICATE KEY UPDATE age=VALUES(age)', [
43 | ('hello', 13),
44 | ('fake', 28),
45 | ])
46 |
47 | connection.commit()
48 |
49 |
50 | '''获取自增ID'''
51 | cursor.lastrowid
52 |
53 | '''查询数据'''
54 | # 执行查询 SQL
55 | cursor.execute('SELECT * FROM `users`')
56 | # 获取单条数据
57 | cursor.fetchone()
58 | # 获取前N条数据
59 | cursor.fetchmany(3)
60 | # 获取所有数据
61 | cursor.fetchall()
62 |
63 | '''游标控制'''
64 | # 所有的数据查询操作均基于游标,我们可以通过cursor.scroll(num, mode)控制游标的位置。
65 | cursor.scroll(1, mode='relative') # 相对当前位置移动
66 | cursor.scroll(2, mode='absolute') # 相对绝对位置移动
67 |
68 | '''指定游标类型'''
69 | # Cursor: 默认,元组类型
70 | # DictCursor: 字典类型
71 | # DictCursorMixin: 支持自定义的游标类型,需先自定义才可使用
72 | # SSCursor: 无缓冲元组类型
73 | # SSDictCursor: 无缓冲字典类型
74 |
75 | connection = pymysql.connect(host='localhost',
76 | user='root',
77 | password='root',
78 | db='demo',
79 | charset='utf8',
80 | cursorclass=pymysql.cursors.DictCursor)
81 |
82 | '''事务处理'''
83 | # 开启事务 connection.begin()
84 | # 提交修改 connection.commit()
85 | # 回滚事务 connection.rollback()
--------------------------------------------------------------------------------
/python学习资料/Quant/GetQuantData/TushareTest1.py:
--------------------------------------------------------------------------------
1 | import tushare as ts
2 | import pandas as pd
3 | df = ts.get_tick_data('600848',date='2018-12-12',src='tt')
4 | df.head(10)
5 |
6 |
7 | #########tushare pro
8 | import tushare as ts
9 | ts.set_token('bbe62c4557d639a8fc050c17c8fb7d6ec8d8611ca94dcac42136822b')
10 | pro = ts.pro_api()
11 | # 日线数据
12 | df = pro.daily(ts_code='000001.SZ', start_date='20180701', end_date='20180718')
13 |
14 | df = ts.pro_bar(ts_code='000001.SZ', start_date='20130101', end_date='20161011', asset='E', freq='1min')
15 |
16 | import os
17 |
18 | filename = 'dataTest2.csv'
19 | path = os.path.join(os.getcwd(), 'QuantData',filename)
20 | df.to_csv(path)
21 |
22 | #查询当前所有正常上市交易的股票列表
23 | pro = ts.pro_api()
24 |
25 | data = pro.stock_basic(exchange='', list_status='L', fields='ts_code,symbol,name,area,industry,list_date')
26 | filename = 'stock_list.csv'
27 | path = os.path.join(os.getcwd(), 'QuantData',filename)
28 | data.to_csv(path)
--------------------------------------------------------------------------------
/python学习资料/Quant/GetYahooData/cadf.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import matplotlib.dates as mdates
5 | import pandas as pd
6 | import pandas.io.data as web
7 | import pprint
8 | import statsmodels.tsa.stattools as ts
9 |
10 | from pandas.stats.api import ols
11 |
12 |
13 | def plot_price_series(df, ts1, ts2):
14 | months = mdates.MonthLocator() # every month
15 | fig, ax = plt.subplots()
16 | ax.plot(df.index, df[ts1], label=ts1)
17 | ax.plot(df.index, df[ts2], label=ts2)
18 | ax.xaxis.set_major_locator(months)
19 | ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))
20 | ax.set_xlim(datetime.datetime(2012, 1, 1), datetime.datetime(2013, 1, 1))
21 | ax.grid(True)
22 | fig.autofmt_xdate()
23 |
24 | plt.xlabel('Month/Year')
25 | plt.ylabel('Price ($)')
26 | plt.title('%s and %s Daily Prices' % (ts1, ts2))
27 | plt.legend()
28 | plt.show()
29 |
30 | def plot_scatter_series(df, ts1, ts2):
31 | plt.xlabel('%s Price ($)' % ts1)
32 | plt.ylabel('%s Price ($)' % ts2)
33 | plt.title('%s and %s Price Scatterplot' % (ts1, ts2))
34 | plt.scatter(df[ts1], df[ts2])
35 | plt.show()
36 |
37 | def plot_residuals(df):
38 | months = mdates.MonthLocator() # every month
39 | fig, ax = plt.subplots()
40 | ax.plot(df.index, df["res"], label="Residuals")
41 | ax.xaxis.set_major_locator(months)
42 | ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))
43 | ax.set_xlim(datetime.datetime(2012, 1, 1), datetime.datetime(2013, 1, 1))
44 | ax.grid(True)
45 | fig.autofmt_xdate()
46 |
47 | plt.xlabel('Month/Year')
48 | plt.ylabel('Price ($)')
49 | plt.title('Residual Plot')
50 | plt.legend()
51 |
52 | plt.plot(df["res"])
53 | plt.show()
54 |
55 | if __name__ == "__main__":
56 | start = datetime.datetime(2012, 1, 1)
57 | end = datetime.datetime(2013, 1, 1)
58 |
59 | arex = web.DataReader("AREX", "yahoo", start, end)
60 | wll = web.DataReader("WLL", "yahoo", start, end)
61 |
62 | df = pd.DataFrame(index=arex.index)
63 | df["AREX"] = arex["Adj Close"]
64 | df["WLL"] = wll["Adj Close"]
65 |
66 | # Plot the two time series
67 | plot_price_series(df, "AREX", "WLL")
68 |
69 | # Display a scatter plot of the two time series
70 | plot_scatter_series(df, "AREX", "WLL")
71 |
72 | # Calculate optimal hedge ratio "beta"
73 | res = ols(y=df['WLL'], x=df["AREX"])
74 | beta_hr = res.beta.x
75 |
76 | # Calculate the residuals of the linear combination
77 | df["res"] = df["WLL"] - beta_hr*df["AREX"]
78 |
79 | # Plot the residuals
80 | plot_residuals(df)
81 |
82 | # Calculate and output the CADF test on the residuals
83 | cadf = ts.adfuller(df["res"])
84 | pprint.pprint(cadf)
--------------------------------------------------------------------------------
/python学习资料/Quant/GetYahooData/insert_symbols.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 |
4 | # insert_symbols.py
5 |
6 | from __future__ import print_function
7 |
8 | import datetime
9 | from math import ceil
10 |
11 | import bs4
12 | import requests
13 | #some error MySQLdb
14 | import MySQLdb as mdb
15 | from requests.exceptions import RequestException
16 | import Quant.GetQuantData.ApiList as qapi
17 |
18 |
19 | def obtain_parse_wiki_snp500():
20 | """
21 | Download and parse the Wikipedia list of S&P500
22 | constituents using requests and BeautifulSoup.
23 |
24 | Returns a list of tuples for to add to MySQL.
25 | """
26 | # Stores the current time, for the created_at record
27 |
28 | now = datetime.datetime.utcnow()
29 |
30 | # Use requests and BeautifulSoup to download the
31 | # list of S&P500 companies and obtain the symbol table
32 | try:
33 | # url = qapi.list500CompaniesInWiki
34 | url = "http://en.wikipedia.org/wiki/List_of_S%26P_500_companies"
35 | headers = {
36 | 'User-Agent': 'Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10_6_8; en-us) AppleWebKit/534.50 (KHTML, like Gecko) Version/5.1 Safari/534.50'
37 | }
38 | # proxies = {
39 | # 'http': 'http://10.10.1.10:3128',
40 | # 'https': 'http://10.10.1.10:1080'
41 | # }
42 | # proxy = '124.243.226.18:8888'
43 |
44 | # # proxy='username:password@'
45 | # proxies = {
46 | # 'http': 'http://' + proxy,
47 | # 'https': 'https://' + proxy,
48 | # }
49 |
50 |
51 | response = requests.get(
52 | url,
53 | headers=headers,
54 | # proxies=proxies
55 | )
56 | if response.status_code== 200:
57 | print(response.text)
58 | print('nothing')
59 | except RequestException:
60 | print('error')
61 | soup = bs4.BeautifulSoup(response.text)
62 | # This selects the first table, using CSS Selector syntax
63 | # and then ignores the header row ([1:])
64 | symbolslist = soup.select('table')[0].select('tr')[1:]
65 |
66 | # Obtain the symbol information for each
67 | # row in the S&P500 constituent table
68 | symbols = []
69 | for i, symbol in enumerate(symbolslist):
70 | tds = symbol.select('td')
71 | symbols.append(
72 | (
73 | tds[0].select('a')[0].text, # Ticker
74 | 'stock',
75 | tds[1].select('a')[0].text, # Name
76 | tds[3].text, # Sector
77 | 'USD', now, now
78 | )
79 | )
80 | return symbols
81 |
82 |
83 | def insert_snp500_symbols(symbols):
84 | """
85 | Insert the S&P500 symbols into the MySQL database.
86 | """
87 | # Connect to the MySQL instance
88 | db_host = 'localhost'
89 | db_user = 'sec_user'
90 | db_pass = 'password'
91 | db_name = 'securities_master'
92 | con = mdb.connect(
93 | host=db_host, user=db_user, passwd=db_pass, db=db_name
94 | )
95 |
96 | # Create the insert strings
97 | column_str = """ticker, instrument, name, sector,
98 | currency, created_date, last_updated_date
99 | """
100 | insert_str = ("%s, " * 7)[:-2]
101 | final_str = "INSERT INTO symbol (%s) VALUES (%s)" % \
102 | (column_str, insert_str)
103 |
104 | # Using the MySQL connection, carry out
105 | # an INSERT INTO for every symbol
106 | with con:
107 | cur = con.cursor()
108 | cur.executemany(final_str, symbols)
109 |
110 |
111 | if __name__ == "__main__":
112 | symbols = obtain_parse_wiki_snp500()
113 | insert_snp500_symbols(symbols)
114 | print("%s symbols were successfully added." % len(symbols))
115 |
--------------------------------------------------------------------------------
/python学习资料/Quant/GetYahooData/price_retrieval.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 |
4 | # price_retrieval.py
5 |
6 | from __future__ import print_function
7 |
8 | import datetime
9 | import warnings
10 |
11 | import MySQLdb as mdb
12 | import requests
13 | import Quant.GetQuantData.ApiList as qapi
14 |
15 |
16 | # Obtain a database connection to the MySQL instance
17 | db_host = 'localhost'
18 | db_user = 'sec_user'
19 | db_pass = 'password'
20 | db_name = 'securities_master'
21 | con = mdb.connect(db_host, db_user, db_pass, db_name)
22 |
23 |
24 | def obtain_list_of_db_tickers():
25 | """
26 | Obtains a list of the ticker symbols in the database.
27 | """
28 | with con:
29 | cur = con.cursor()
30 | cur.execute("SELECT id, ticker FROM symbol")
31 | data = cur.fetchall()
32 | return [(d[0], d[1]) for d in data]
33 |
34 |
35 | def get_daily_historic_data_yahoo(
36 | ticker, start_date=(2000,1,1),
37 | end_date=datetime.date.today().timetuple()[0:3]
38 | ):
39 | """
40 | Obtains data from Yahoo Finance returns and a list of tuples.
41 |
42 | ticker: Yahoo Finance ticker symbol, e.g. "GOOG" for Google, Inc.
43 | start_date: Start date in (YYYY, M, D) format
44 | end_date: End date in (YYYY, M, D) format
45 | """
46 | # Construct the Yahoo URL with the correct integer query parameters
47 | # for start and end dates. Note that some parameters are zero-based!
48 | ticker_tup = (
49 | ticker, start_date[1]-1, start_date[2],
50 | start_date[0], end_date[1]-1, end_date[2],
51 | end_date[0]
52 | )
53 | yahoo_url = qapi.yahoo_url
54 | yahoo_url += "?s=%s&a=%s&b=%s&c=%s&d=%s&e=%s&f=%s"
55 | yahoo_url = yahoo_url % ticker_tup
56 |
57 | # Try connecting to Yahoo Finance and obtaining the data
58 | # On failure, print an error message.
59 | try:
60 | yf_data = requests.get(yahoo_url).text.split("\n")[1:-1]
61 | prices = []
62 | for y in yf_data:
63 | p = y.strip().split(',')
64 | prices.append(
65 | (datetime.datetime.strptime(p[0], '%Y-%m-%d'),
66 | p[1], p[2], p[3], p[4], p[5], p[6])
67 | )
68 | except Exception as e:
69 | print("Could not download Yahoo data: %s" % e)
70 | return prices
71 |
72 |
73 | def insert_daily_data_into_db(
74 | data_vendor_id, symbol_id, daily_data
75 | ):
76 | """
77 | Takes a list of tuples of daily data and adds it to the
78 | MySQL database. Appends the vendor ID and symbol ID to the data.
79 |
80 | daily_data: List of tuples of the OHLC data (with
81 | adj_close and volume)
82 | """
83 | # Create the time now
84 | now = datetime.datetime.utcnow()
85 |
86 | # Amend the data to include the vendor ID and symbol ID
87 | daily_data = [
88 | (data_vendor_id, symbol_id, d[0], now, now,
89 | d[1], d[2], d[3], d[4], d[5], d[6])
90 | for d in daily_data
91 | ]
92 |
93 | # Create the insert strings
94 | column_str = """data_vendor_id, symbol_id, price_date, created_date,
95 | last_updated_date, open_price, high_price, low_price,
96 | close_price, volume, adj_close_price"""
97 | insert_str = ("%s, " * 11)[:-2]
98 | final_str = "INSERT INTO daily_price (%s) VALUES (%s)" % \
99 | (column_str, insert_str)
100 |
101 | # Using the MySQL connection, carry out an INSERT INTO for every symbol
102 | with con:
103 | cur = con.cursor()
104 | cur.executemany(final_str, daily_data)
105 |
106 |
107 | if __name__ == "__main__":
108 | # This ignores the warnings regarding Data Truncation
109 | # from the Yahoo precision to Decimal(19,4) datatypes
110 | warnings.filterwarnings('ignore')
111 |
112 | # Loop over the tickers and insert the daily historical
113 | # data into the database
114 | tickers = obtain_list_of_db_tickers()
115 | lentickers = len(tickers)
116 | for i, t in enumerate(tickers):
117 | print(
118 | "Adding data for %s: %s out of %s" %
119 | (t[1], i+1, lentickers)
120 | )
121 | yf_data = get_daily_historic_data_yahoo(t[1])
122 | insert_daily_data_into_db('1', t[0], yf_data)
123 | print("Successfully added Yahoo Finance pricing data to DB.")
124 |
--------------------------------------------------------------------------------
/python学习资料/Quant/GetYahooData/quandl_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 |
4 | # quandl_data.py
5 |
6 | from __future__ import print_function
7 |
8 | import matplotlib.pyplot as plt
9 | import pandas as pd
10 | import pandas_datareader as pdrd
11 | import requests
12 |
13 |
14 | def construct_futures_symbols(
15 | symbol, start_year=2010, end_year=2014
16 | ):
17 | """
18 | Constructs a list of futures contract codes
19 | for a particular symbol and timeframe.
20 | """
21 | futures = []
22 | # March, June, September and
23 | # December delivery codes
24 | months = 'HMUZ'
25 | for y in range(start_year, end_year+1):
26 | for m in months:
27 | futures.append("%s%s%s" % (symbol, m, y))
28 | return futures
29 |
30 |
31 | def download_contract_from_quandl(contract, dl_dir):
32 | """
33 | Download an individual futures contract from Quandl and then
34 | store it to disk in the 'dl_dir' directory. An auth_token is
35 | required, which is obtained from the Quandl upon sign-up.
36 | """
37 | # Construct the API call from the contract and auth_token
38 | api_call = "http://www.quandl.com/api/v1/datasets/"
39 | api_call += "OFDP/FUTURE_%s.csv" % contract
40 | # If you wish to add an auth token for more downloads, simply
41 | # comment the following line and replace MY_AUTH_TOKEN with
42 | # your auth token in the line below
43 | params = "?sort_order=asc"
44 | #params = "?auth_token=MY_AUTH_TOKEN&sort_order=asc"
45 | full_url = "%s%s" % (api_call, params)
46 |
47 | # Download the data from Quandl
48 | data = requests.get(full_url).text
49 |
50 | # Store the data to disk
51 | fc = open('%s/%s.csv' % (dl_dir, contract), 'w')
52 | fc.write(data)
53 | fc.close()
54 |
55 |
56 | def download_historical_contracts(
57 | symbol, dl_dir, start_year=2010, end_year=2014
58 | ):
59 | """
60 | Downloads all futures contracts for a specified symbol
61 | between a start_year and an end_year.
62 | """
63 | contracts = construct_futures_symbols(
64 | symbol, start_year, end_year
65 | )
66 | for c in contracts:
67 | print("Downloading contract: %s" % c)
68 | download_contract_from_quandl(c, dl_dir)
69 |
70 |
71 | if __name__ == "__main__":
72 | symbol = 'ES'
73 |
74 | # Make sure you've created this
75 | # relative directory beforehand
76 | dl_dir = 'quandl/futures/ES'
77 |
78 | # Create the start and end years
79 | start_year = 2010
80 | end_year = 2014
81 |
82 | # Download the contracts into the directory
83 | download_historical_contracts(
84 | symbol, dl_dir, start_year, end_year
85 | )
86 |
87 | # Open up a single contract via read_csv
88 | # and plot the settle price
89 |
90 | # es = pd.io.parsers.read_csv(
91 | # "%s/ESH2010.csv" % dl_dir, index_col="Date"
92 | # )
93 | es = pd.read_csv(
94 | "%s/ESH2010.csv" % dl_dir, index_col="Date"
95 | )
96 | es["Settle"].plot()
97 | plt.show()
98 |
--------------------------------------------------------------------------------
/python学习资料/Quant/GetYahooData/quantitative.sql:
--------------------------------------------------------------------------------
1 | -- Errors encountered generating script
2 | -- Select items in the error list to the left
3 |
4 | CREATE TABLE `exchange` (
5 | `id` int NOT NULL AUTO_INCREMENT,
6 | `abbrev` varchar(32) NOT NULL,
7 | `name` varchar(255) NOT NULL,
8 | `city` varchar(255) NULL,
9 | `country` varchar(255) NULL,
10 | `currency` varchar(64) NULL,
11 | `timezone_offset` time NULL,
12 | `created_date` datetime NOT NULL,
13 | `last_updated_date` datetime NOT NULL,
14 | PRIMARY KEY (`id`)
15 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8;
16 |
17 | CREATE TABLE `data_vendor` (
18 | `id` int NOT NULL AUTO_INCREMENT,
19 | `name` varchar(64) NOT NULL,
20 | `website_url` varchar(255) NULL,
21 | `support_email` varchar(255) NULL,
22 | `created_date` datetime NOT NULL,
23 | `last_updated_date` datetime NOT NULL,
24 | PRIMARY KEY (`id`)
25 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8;
26 |
27 | CREATE TABLE `symbol` (
28 | `id` int NOT NULL AUTO_INCREMENT,
29 | `exchange_id` int NULL,
30 | `ticker` varchar(32) NOT NULL,
31 | `instrument` varchar(64) NOT NULL,
32 | `name` varchar(255) NULL,
33 | `sector` varchar(255) NULL,
34 | `currency` varchar(32) NULL,
35 | `created_date` datetime NOT NULL,
36 | `last_updated_date` datetime NOT NULL,
37 | PRIMARY KEY (`id`),
38 | KEY `index_exchange_id` (`exchange_id`)
39 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8;
40 |
41 | CREATE TABLE `daily_price` (
42 | `id` int NOT NULL AUTO_INCREMENT,
43 | `data_vendor_id` int NOT NULL,
44 | `symbol_id` int NOT NULL,
45 | `price_date` datetime NOT NULL,
46 | `created_date` datetime NOT NULL,
47 | `last_updated_date` datetime NOT NULL,
48 | `open_price` decimal(19,4) NULL,
49 | `high_price` decimal(19,4) NULL,
50 | `low_price` decimal(19,4) NULL,
51 | `close_price` decimal(19,4) NULL,
52 | `adj_close_price` decimal(19,4) NULL,
53 | `volume` bigint NULL,
54 | PRIMARY KEY (`id`),
55 | KEY `index_data_vendor_id` (`data_vendor_id`),
56 | KEY `index_symbol_id` (`symbol_id`)
57 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8;
58 |
--------------------------------------------------------------------------------
/python学习资料/Quant/GetYahooData/retrieving_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 |
4 | # retrieving_data.py
5 |
6 | from __future__ import print_function
7 |
8 | import pandas as pd
9 | import MySQLdb as mdb
10 |
11 |
12 | if __name__ == "__main__":
13 | # Connect to the MySQL instance
14 | db_host = 'localhost'
15 | db_user = 'sec_user'
16 | db_pass = 'password'
17 | db_name = 'securities_master'
18 | con = mdb.connect(db_host, db_user, db_pass, db_name)
19 |
20 | # Select all of the historic Google adjusted close data
21 | sql = """SELECT dp.price_date, dp.adj_close_price
22 | FROM symbol AS sym
23 | INNER JOIN daily_price AS dp
24 | ON dp.symbol_id = sym.id
25 | WHERE sym.ticker = 'GOOG'
26 | ORDER BY dp.price_date ASC;"""
27 |
28 | # Create a pandas dataframe from the SQL query
29 | goog = pd.read_sql_query(sql, con=con, index_col='price_date')
30 |
31 | # Output the dataframe tail
32 | print(goog.tail())
33 |
34 |
35 |
36 |
37 | # pandas内置雅虎金融数据对接
38 | import pandas as pd
39 | import numpy as np
40 | from pandas_datareader import data, wb # 需要安装 pip install pandas_datareader
41 | import datetime
42 |
43 | # 定义获取数据的时间段
44 | start = datetime.datetime(2017, 4, 1)
45 | end = datetime.date.today()
46 |
47 | # 获取股票信息 ex: 中国石油
48 | # 如果要看上证指数请参考换成600000.ss
49 | # 如果要看深成指请换成000001.sz
50 | cnpc = data.DataReader("601857.SS", 'yahoo', start, end)
51 |
52 | cnpc.head(5)
53 |
54 | spy = data.DataReader("SPY", 'yahoo', start, end)
55 | print(spy.tail(5))
--------------------------------------------------------------------------------
/python学习资料/Quant/GetYahooData/securities_master.sql:
--------------------------------------------------------------------------------
1 | CREATE TABLE `exchange` (
2 | `id` int NOT NULL AUTO_INCREMENT,
3 | `abbrev` varchar(32) NOT NULL,
4 | `name` varchar(255) NOT NULL,
5 | `city` varchar(255) NULL,
6 | `country` varchar(255) NULL,
7 | `currency` varchar(64) NULL,
8 | `timezone_offset` time NULL,
9 | `created_date` datetime NOT NULL,
10 | `last_updated_date` datetime NOT NULL,
11 | PRIMARY KEY (`id`)
12 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8;
13 |
14 | CREATE TABLE `data_vendor` (
15 | `id` int NOT NULL AUTO_INCREMENT,
16 | `name` varchar(64) NOT NULL,
17 | `website_url` varchar(255) NULL,
18 | `support_email` varchar(255) NULL,
19 | `created_date` datetime NOT NULL,
20 | `last_updated_date` datetime NOT NULL,
21 | PRIMARY KEY (`id`)
22 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8;
23 |
24 | CREATE TABLE `symbol` (
25 | `id` int NOT NULL AUTO_INCREMENT,
26 | `exchange_id` int NULL,
27 | `ticker` varchar(32) NOT NULL,
28 | `instrument` varchar(64) NOT NULL,
29 | `name` varchar(255) NULL,
30 | `sector` varchar(255) NULL,
31 | `currency` varchar(32) NULL,
32 | `created_date` datetime NOT NULL,
33 | `last_updated_date` datetime NOT NULL,
34 | PRIMARY KEY (`id`),
35 | KEY `index_exchange_id` (`exchange_id`)
36 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8;
37 |
38 | CREATE TABLE `daily_price` (
39 | `id` int NOT NULL AUTO_INCREMENT,
40 | `data_vendor_id` int NOT NULL,
41 | `symbol_id` int NOT NULL,
42 | `price_date` datetime NOT NULL,
43 | `created_date` datetime NOT NULL,
44 | `last_updated_date` datetime NOT NULL,
45 | `open_price` decimal(19,4) NULL,
46 | `high_price` decimal(19,4) NULL,
47 | `low_price` decimal(19,4) NULL,
48 | `close_price` decimal(19,4) NULL,
49 | `adj_close_price` decimal(19,4) NULL,
50 | `volume` bigint NULL,
51 | PRIMARY KEY (`id`),
52 | KEY `index_data_vendor_id` (`data_vendor_id`),
53 | KEY `index_symbol_id` (`symbol_id`)
54 | ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8;
--------------------------------------------------------------------------------
/python学习资料/Quant/TushareBackTest/Cap_Update_daily.py:
--------------------------------------------------------------------------------
1 | import pymysql
2 |
3 | def cap_update_daily(state_dt):
4 | para_norisk = (1.0 + 0.04/365)
5 | db = pymysql.connect(host='127.0.0.1', user='root', passwd='admin', db='stock', charset='utf8')
6 | cursor = db.cursor()
7 | sql_pool = "select * from my_stock_pool"
8 | cursor.execute(sql_pool)
9 | done_set = cursor.fetchall()
10 | db.commit()
11 | new_lock_cap = 0.00
12 | for i in range(len(done_set)):
13 | stock_code = str(done_set[i][0])
14 | stock_vol = float(done_set[i][2])
15 | sql = "select * from stock_info a where a.stock_code = '%s' and a.state_dt <= '%s' order by a.state_dt desc limit 1"%(stock_code,state_dt)
16 | cursor.execute(sql)
17 | done_temp = cursor.fetchall()
18 | db.commit()
19 | if len(done_temp) > 0:
20 | cur_close_price = float(done_temp[0][3])
21 | new_lock_cap += cur_close_price * stock_vol
22 | else:
23 | print('Cap_Update_daily Err!!')
24 | raise Exception
25 | sql_cap = "select * from my_capital order by seq asc"
26 | cursor.execute(sql_cap)
27 | done_cap = cursor.fetchall()
28 | db.commit()
29 | new_cash_cap = float(done_cap[-1][2]) * para_norisk
30 | new_total_cap = new_cash_cap + new_lock_cap
31 | sql_insert = "insert into my_capital(capital,money_lock,money_rest,bz,state_dt)values('%.2f','%.2f','%.2f','%s','%s')"%(new_total_cap,new_lock_cap,new_cash_cap,str('Daily_Update'),state_dt)
32 | cursor.execute(sql_insert)
33 | db.commit()
34 | return 1
--------------------------------------------------------------------------------
/python学习资料/Quant/TushareBackTest/DC.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf8 -*-
2 | import numpy as np
3 | import pymysql
4 |
5 |
6 | class data_collect(object):
7 |
8 | def __init__(self, in_code,start_dt,end_dt):
9 | ans = self.collectDATA(in_code,start_dt,end_dt)
10 |
11 | def collectDATA(self,in_code,start_dt,end_dt):
12 | # 建立数据库连接,获取日线基础行情(开盘价,收盘价,最高价,最低价,成交量,成交额)
13 | db = pymysql.connect(host='127.0.0.1', user='root', passwd='admin', db='stock', charset='utf8')
14 | cursor = db.cursor()
15 | sql_done_set = "SELECT * FROM stock_all a where stock_code = '%s' and state_dt >= '%s' and state_dt <= '%s' order by state_dt asc" % (in_code, start_dt, end_dt)
16 | cursor.execute(sql_done_set)
17 | done_set = cursor.fetchall()
18 | if len(done_set) == 0:
19 | raise Exception
20 | self.date_seq = []
21 | self.open_list = []
22 | self.close_list = []
23 | self.high_list = []
24 | self.low_list = []
25 | self.vol_list = []
26 | self.amount_list = []
27 | for i in range(len(done_set)):
28 | self.date_seq.append(done_set[i][0])
29 | self.open_list.append(float(done_set[i][2]))
30 | self.close_list.append(float(done_set[i][3]))
31 | self.high_list.append(float(done_set[i][4]))
32 | self.low_list.append(float(done_set[i][5]))
33 | self.vol_list.append(float(done_set[i][6]))
34 | self.amount_list.append(float(done_set[i][7]))
35 | cursor.close()
36 | db.close()
37 | # 将日线行情整合为训练集(其中self.train是输入集,self.target是输出集,self.test_case是end_dt那天的单条测试输入)
38 | self.data_train = []
39 | self.data_target = []
40 | self.data_target_onehot = []
41 | self.cnt_pos = 0
42 | self.test_case = []
43 |
44 | for i in range(1,len(self.close_list)):
45 | train = [self.open_list[i-1],self.close_list[i-1],self.high_list[i-1],self.low_list[i-1],self.vol_list[i-1],self.amount_list[i-1]]
46 | self.data_train.append(np.array(train))
47 |
48 | if self.close_list[i]/self.close_list[i-1] > 1.0:
49 | self.data_target.append(float(1.00))
50 | self.data_target_onehot.append([1,0,0])
51 | else:
52 | self.data_target.append(float(0.00))
53 | self.data_target_onehot.append([0,1,0])
54 | self.cnt_pos =len([x for x in self.data_target if x == 1.00])
55 | self.test_case = np.array([self.open_list[-1],self.close_list[-1],self.high_list[-1],self.low_list[-1],self.vol_list[-1],self.amount_list[-1]])
56 | self.data_train = np.array(self.data_train)
57 | self.data_target = np.array(self.data_target)
58 | return 1
--------------------------------------------------------------------------------
/python学习资料/Quant/TushareBackTest/Deal.py:
--------------------------------------------------------------------------------
1 | import pymysql.cursors
2 |
3 | class Deal(object):
4 | cur_capital = 0.00
5 | cur_money_lock = 0.00
6 | cur_money_rest = 0.00
7 | stock_pool = []
8 | stock_map1 = {}
9 | stock_map2 = {}
10 | stock_map3 = {}
11 | stock_all = []
12 | ban_list = []
13 |
14 | def __init__(self,state_dt):
15 | # 建立数据库连接
16 | db = pymysql.connect(host='127.0.0.1', user='root', passwd='admin', db='stock', charset='utf8')
17 | cursor = db.cursor()
18 | try:
19 | sql_select = 'select * from my_capital a order by seq desc limit 1'
20 | cursor.execute(sql_select)
21 | done_set = cursor.fetchall()
22 | self.cur_capital = 0.00
23 | self.cur_money_lock = 0.00
24 | self.cur_money_rest = 0.00
25 | if len(done_set) > 0:
26 | self.cur_capital = float(done_set[0][0])
27 | self.cur_money_rest = float(done_set[0][2])
28 | sql_select2 = 'select * from my_stock_pool'
29 | cursor.execute(sql_select2)
30 | done_set2 = cursor.fetchall()
31 | self.stock_pool = []
32 | self.stock_all = []
33 | self.stock_map1 = []
34 | self.stock_map2 = []
35 | self.stock_map3 = []
36 | self.ban_list = []
37 | if len(done_set2) > 0:
38 | self.stock_pool = [x[0] for x in done_set2 if x[2] > 0]
39 | self.stock_all = [x[0] for x in done_set2]
40 | self.stock_map1 = {x[0]: float(x[1]) for x in done_set2}
41 | self.stock_map2 = {x[0]: int(x[2]) for x in done_set2}
42 | self.stock_map3 = {x[0]: int(x[3]) for x in done_set2}
43 | for i in range(len(done_set2)):
44 | sql = "select * from stock_info a where a.stock_code = '%s' and a.state_dt = '%s'"%(done_set2[i][0],state_dt)
45 | cursor.execute(sql)
46 | done_temp = cursor.fetchall()
47 | db.commit()
48 | self.cur_money_lock += float(done_temp[0][3]) * float(done_set2[i][2])
49 | # sql_select3 = 'select * from ban_list'
50 | # cursor.execute(sql_select3)
51 | # done_set3 = cursor.fetchall()
52 | # if len(done_set3) > 0:
53 | # self.ban_list = [x[0] for x in done_set3]
54 |
55 |
56 | except Exception as excp:
57 | #db.rollback()
58 | print(excp)
59 |
60 | db.close()
61 |
--------------------------------------------------------------------------------
/python学习资料/Quant/TushareBackTest/Filter.py:
--------------------------------------------------------------------------------
1 | import pymysql.cursors
2 | import Deal
3 | import Operator
4 |
5 | def filter_main(stock_new,state_dt,predict_dt,poz):
6 | # 建立数据库连接
7 | db = pymysql.connect(host='127.0.0.1', user='root', passwd='admin', db='stock', charset='utf8')
8 | cursor = db.cursor()
9 |
10 | #先更新持股天数
11 | sql_update_hold_days = 'update my_stock_pool w set w.hold_days = w.hold_days + 1'
12 | cursor.execute(sql_update_hold_days)
13 | db.commit()
14 |
15 | #先卖出
16 | deal = Deal.Deal(state_dt)
17 | stock_pool_local = deal.stock_pool
18 | for stock in stock_pool_local:
19 | sql_predict = "select predict from good_pool_all a where a.state_dt = '%s' and a.stock_code = '%s'"%(predict_dt,stock)
20 | cursor.execute(sql_predict)
21 | done_set_predict = cursor.fetchall()
22 | predict = 0
23 | if len(done_set_predict) > 0:
24 | predict = int(done_set_predict[0][0])
25 | ans = Operator.sell(stock,state_dt,predict)
26 |
27 | #后买入
28 | for stock_index in range(len(stock_new)):
29 | deal_buy = Deal.Deal(state_dt)
30 | #if poz[stock_index]*deal_buy.cur_money_rest >= :
31 | # sql_ban_pool = "select distinct stock_code from ban_list"
32 | # cursor.execute(sql_ban_pool)
33 | # done_ban_pool = cursor.fetchall()
34 | # ban_list = [x[0] for x in done_ban_pool]
35 |
36 | # # 如果模型f1分值低于50则不买入
37 | # sql_f1_check = "select * from good_pool_all a where a.stock_code = '%s' and a.state_dt < '%s' order by a.state_dt desc limit 1"%(stock_new[stock_index],state_dt)
38 | # cursor.execute(sql_f1_check)
39 | # done_check = cursor.fetchall()
40 | # db.commit()
41 | # if len(done_check) > 0:
42 | # if float(done_check[0][4]) < 0.5:
43 | # print('F1 Warning !!')
44 | # continue
45 |
46 |
47 | ans = Operator.buy(stock_new[stock_index],state_dt,poz[stock_index]*deal_buy.cur_money_rest)
48 | del deal_buy
49 | db.close()
--------------------------------------------------------------------------------
/python学习资料/VNProject/atr_rsi_strategy.py:
--------------------------------------------------------------------------------
1 | from vnpy.app.cta_strategy import (
2 | CtaTemplate,
3 | StopOrder,
4 | TickData,
5 | BarData,
6 | TradeData,
7 | OrderData,
8 | BarGenerator,
9 | ArrayManager,
10 | )
11 |
12 |
13 | class AtrRsiStrategy(CtaTemplate):
14 | """"""
15 |
16 | author = "用Python的交易员"
17 |
18 | atr_length = 22
19 | atr_ma_length = 10
20 | rsi_length = 5
21 | rsi_entry = 16
22 | trailing_percent = 0.8
23 | fixed_size = 1
24 |
25 | atr_value = 0
26 | atr_ma = 0
27 | rsi_value = 0
28 | rsi_buy = 0
29 | rsi_sell = 0
30 | intra_trade_high = 0
31 | intra_trade_low = 0
32 |
33 | parameters = ["atr_length", "atr_ma_length", "rsi_length",
34 | "rsi_entry", "trailing_percent", "fixed_size"]
35 | variables = ["atr_value", "atr_ma", "rsi_value", "rsi_buy", "rsi_sell"]
36 |
37 | def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
38 | """"""
39 | super(AtrRsiStrategy, self).__init__(
40 | cta_engine, strategy_name, vt_symbol, setting
41 | )
42 | self.bg = BarGenerator(self.on_bar)
43 | self.am = ArrayManager()
44 |
45 | def on_init(self):
46 | """
47 | Callback when strategy is inited.
48 | """
49 | self.write_log("策略初始化")
50 |
51 | self.rsi_buy = 50 + self.rsi_entry
52 | self.rsi_sell = 50 - self.rsi_entry
53 |
54 | self.load_bar(10)
55 |
56 | def on_start(self):
57 | """
58 | Callback when strategy is started.
59 | """
60 | self.write_log("策略启动")
61 |
62 | def on_stop(self):
63 | """
64 | Callback when strategy is stopped.
65 | """
66 | self.write_log("策略停止")
67 |
68 | def on_tick(self, tick: TickData):
69 | """
70 | Callback of new tick data update.
71 | """
72 | self.bg.update_tick(tick)
73 |
74 | def on_bar(self, bar: BarData):
75 | """
76 | Callback of new bar data update.
77 | """
78 | self.cancel_all()
79 |
80 | am = self.am
81 | am.update_bar(bar)
82 | if not am.inited:
83 | return
84 |
85 | atr_array = am.atr(self.atr_length, array=True)
86 | self.atr_value = atr_array[-1]
87 | self.atr_ma = atr_array[-self.atr_ma_length:].mean()
88 | self.rsi_value = am.rsi(self.rsi_length)
89 |
90 | if self.pos == 0:
91 | self.intra_trade_high = bar.high_price
92 | self.intra_trade_low = bar.low_price
93 |
94 | if self.atr_value > self.atr_ma:
95 | if self.rsi_value > self.rsi_buy:
96 | self.buy(bar.close_price + 5, self.fixed_size)
97 | elif self.rsi_value < self.rsi_sell:
98 | self.short(bar.close_price - 5, self.fixed_size)
99 |
100 | elif self.pos > 0:
101 | self.intra_trade_high = max(self.intra_trade_high, bar.high_price)
102 | self.intra_trade_low = bar.low_price
103 |
104 | long_stop = self.intra_trade_high * \
105 | (1 - self.trailing_percent / 100)
106 | self.sell(long_stop, abs(self.pos), stop=True)
107 |
108 | elif self.pos < 0:
109 | self.intra_trade_low = min(self.intra_trade_low, bar.low_price)
110 | self.intra_trade_high = bar.high_price
111 |
112 | short_stop = self.intra_trade_low * \
113 | (1 + self.trailing_percent / 100)
114 | self.cover(short_stop, abs(self.pos), stop=True)
115 |
116 | self.put_event()
117 |
118 | def on_order(self, order: OrderData):
119 | """
120 | Callback of new order data update.
121 | """
122 | pass
123 |
124 | def on_trade(self, trade: TradeData):
125 | """
126 | Callback of new trade data update.
127 | """
128 | self.put_event()
129 |
130 | def on_stop_order(self, stop_order: StopOrder):
131 | """
132 | Callback of stop order update.
133 | """
134 | pass
135 |
--------------------------------------------------------------------------------
/python学习资料/VNProject/my_strategy_tool.py:
--------------------------------------------------------------------------------
1 | from vnpy.app.cta_strategy import (BarGenerator, ArrayManager)
2 | from vnpy.trader.object import BarData, TickData
3 | from vnpy.trader.constant import Interval
4 | from typing import Callable
5 | import talib
6 |
7 |
8 | class NewBarGenerator(BarGenerator):
9 | ''''''
10 |
11 | def __init__(
12 | self,
13 | on_bar: Callable,
14 | window: int = 0,
15 | on_window_bar: Callable = None,
16 | interval: Interval = Interval.MINUTE
17 | ):
18 | super(NewBarGenerator, self).__init__(on_bar, window, on_window_bar, interval)
19 |
20 | def update_bar(self, bar: BarData):
21 | """
22 | Update 1 minute bar into generator
23 | """
24 | # If not inited, creaate window bar object
25 | if not self.window_bar:
26 | # Generate timestamp for bar data
27 | if self.interval == Interval.MINUTE:
28 | dt = bar.datetime.replace(second=0, microsecond=0)
29 | else:
30 | dt = bar.datetime.replace(minute=0, second=0, microsecond=0)
31 |
32 | self.window_bar = BarData(
33 | symbol=bar.symbol,
34 | exchange=bar.exchange,
35 | datetime=dt,
36 | gateway_name=bar.gateway_name,
37 | open_price=bar.open_price,
38 | high_price=bar.high_price,
39 | low_price=bar.low_price
40 | )
41 | # Otherwise, update high/low price into window bar
42 | else:
43 | self.window_bar.high_price = max(
44 | self.window_bar.high_price, bar.high_price)
45 | self.window_bar.low_price = min(
46 | self.window_bar.low_price, bar.low_price)
47 |
48 | # Update close price/volume into window bar
49 | self.window_bar.close_price = bar.close_price
50 | self.window_bar.volume += int(bar.volume)
51 | self.window_bar.open_interest = bar.open_interest
52 |
53 | # Check if window bar completed
54 | finished = False
55 |
56 | if self.interval == Interval.MINUTE:
57 | # # x-minute bar
58 | # if not (bar.datetime.minute + 1) % self.window:
59 | # finished = True
60 | if self.last_bar and bar.datetime.minute != self.last_bar.datetime.minute:
61 | self.interval_count += 1
62 | if not self.interval_count % self.window:
63 | finished = True
64 | self.interval_count = 0
65 |
66 | elif self.interval == Interval.HOUR:
67 | if self.last_bar and bar.datetime.hour != self.last_bar.datetime.hour:
68 | # 1-hour bar
69 | if self.window == 1:
70 | finished = True
71 | # x-hour bar
72 | else:
73 | self.interval_count += 1
74 |
75 | if not self.interval_count % self.window:
76 | finished = True
77 | self.interval_count = 0
78 |
79 | if finished:
80 | self.on_window_bar(self.window_bar)
81 | self.window_bar = None
82 |
83 | # Cache last bar object
84 | self.last_bar = bar
85 |
86 |
87 | class NewArrayManager(ArrayManager):
88 | def __init__(self, size=100):
89 | super(NewArrayManager, self).__init__(size)
90 |
91 | def aroon(self, n, array=False):
92 | """
93 | AROON.
94 | """
95 | aroon_up, aroon_down = talib.AROON(
96 | self.high, self.low, n
97 | )
98 | if array:
99 | return aroon_up, aroon_down
100 | return aroon_up[-1], aroon_down[-1]
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
--------------------------------------------------------------------------------
/python学习资料/VNProject/no_ui/backtesting.py:
--------------------------------------------------------------------------------
1 | from vnpy.app.cta_strategy.backtesting import BacktestingEngine, OptimizationSetting
2 | from vnpy.app.cta_strategy.strategies.atr_rsi_strategy import (
3 | AtrRsiStrategy,
4 | )
5 | from datetime import datetime
6 |
7 |
8 | engine = BacktestingEngine()
9 | engine.set_parameters(
10 | vt_symbol="IF88.CFFEX",
11 | interval="1m",
12 | start=datetime(2019, 1, 1),
13 | end=datetime(2019, 4, 30),
14 | rate=0.3/10000,
15 | slippage=0.2,
16 | size=300,
17 | pricetick=0.2,
18 | capital=1_000_000,
19 | )
20 | engine.add_strategy(AtrRsiStrategy, {})
21 |
22 |
23 | engine.load_data()
24 | engine.run_backtesting()
25 | df = engine.calculate_result()
26 | engine.calculate_statistics()
27 | engine.show_chart()
28 |
29 |
30 | setting = OptimizationSetting()
31 | setting.set_target("sharpe_ratio")
32 | setting.add_parameter("atr_length", 3, 39, 1)
33 | setting.add_parameter("atr_ma_length", 10, 30, 1)
34 |
35 | engine.run_ga_optimization(setting)
--------------------------------------------------------------------------------
/python学习资料/VNProject/no_ui/run.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | from time import sleep
3 | from datetime import datetime, time
4 | from logging import INFO
5 |
6 | from vnpy.event import EventEngine
7 | from vnpy.trader.setting import SETTINGS
8 | from vnpy.trader.engine import MainEngine
9 |
10 | from vnpy.gateway.ctp import CtpGateway
11 | from vnpy.app.cta_strategy import CtaStrategyApp
12 | from vnpy.app.cta_strategy.base import EVENT_CTA_LOG
13 |
14 |
15 | SETTINGS["log.active"] = True
16 | SETTINGS["log.level"] = INFO
17 | SETTINGS["log.console"] = True
18 |
19 |
20 | ctp_setting = {
21 | "用户名": "",
22 | "密码": "",
23 | "经纪商代码": "",
24 | "交易服务器": "",
25 | "行情服务器": "",
26 | "产品名称": "",
27 | "授权编码": "",
28 | "产品信息": ""
29 | }
30 |
31 |
32 | def run_child():
33 | """
34 | Running in the child process.
35 | """
36 | SETTINGS["log.file"] = True
37 |
38 | event_engine = EventEngine()
39 | main_engine = MainEngine(event_engine)
40 | main_engine.add_gateway(CtpGateway)
41 | cta_engine = main_engine.add_app(CtaStrategyApp)
42 | main_engine.write_log("主引擎创建成功")
43 |
44 | log_engine = main_engine.get_engine("log")
45 | event_engine.register(EVENT_CTA_LOG, log_engine.process_log_event)
46 | main_engine.write_log("注册日志事件监听")
47 |
48 | main_engine.connect(ctp_setting, "CTP")
49 | main_engine.write_log("连接CTP接口")
50 |
51 | sleep(10)
52 |
53 | cta_engine.init_engine()
54 | main_engine.write_log("CTA策略初始化完成")
55 |
56 | cta_engine.init_all_strategies()
57 | sleep(60) # Leave enough time to complete strategy initialization
58 | main_engine.write_log("CTA策略全部初始化")
59 |
60 | cta_engine.start_all_strategies()
61 | main_engine.write_log("CTA策略全部启动")
62 |
63 | while True:
64 | sleep(1)
65 |
66 |
67 | def run_parent():
68 | """
69 | Running in the parent process.
70 | """
71 | print("启动CTA策略守护父进程")
72 |
73 | # Chinese futures market trading period (day/night)
74 | DAY_START = time(8, 45)
75 | DAY_END = time(15, 30)
76 |
77 | NIGHT_START = time(20, 45)
78 | NIGHT_END = time(2, 45)
79 |
80 | child_process = None
81 |
82 | while True:
83 | current_time = datetime.now().time()
84 | trading = False
85 |
86 | # Check whether in trading period
87 | if (
88 | (current_time >= DAY_START and current_time <= DAY_END)
89 | or (current_time >= NIGHT_START)
90 | or (current_time <= NIGHT_END)
91 | ):
92 | trading = True
93 |
94 | # Start child process in trading period
95 | if trading and child_process is None:
96 | print("启动子进程")
97 | child_process = multiprocessing.Process(target=run_child)
98 | child_process.start()
99 | print("子进程启动成功")
100 |
101 | # 非记录时间则退出子进程
102 | if not trading and child_process is not None:
103 | print("关闭子进程")
104 | child_process.terminate()
105 | child_process.join()
106 | child_process = None
107 | print("子进程关闭成功")
108 |
109 | sleep(5)
110 |
111 |
112 | if __name__ == "__main__":
113 | run_parent()
114 |
--------------------------------------------------------------------------------
/python学习资料/socket.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/python学习资料/tushareDownloadTool.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import requests,io,random,os
4 | import pandas as pd
5 | from datetime import datetime
6 |
7 | stock="sz000868" #安凯客车股票代码
8 | path = os.path.join(os.getcwd(), 'data')
9 |
10 | for x1 in range(2017,2018):
11 | data=pd.DataFrame(columns=['date','time','datetime' 'price', 'change', 'volume','amount','type'])
12 | for x2 in range(1,13):
13 | ipAddress = str(random.randint(0,255))+'.'+str(random.randint(0,255))+'.'+str(random.randint(0,255))+'.'+str(random.randint(0,255))
14 | headers = {"X-Forwarded-For": ipAddress}
15 | for x3 in range(1,32):
16 | if x2<10 and x3<10:
17 | Date=str(x1)+'-0'+str(x2)+'-0'+str(x3)
18 | elif x2<10 and x3>9:
19 | Date=str(x1)+'-0'+str(x2)+'-'+str(x3)
20 | elif x2>9 and x3<10:
21 | Date=str(x1)+'-'+str(x2)+'-0'+str(x3)
22 | elif x2>9 and x3>9:
23 | Date=str(x1)+'-'+str(x2)+'-'+str(x3)
24 | print(x2,x3,Date)
25 | params = {"date": Date, "symbol": stock}
26 | url = 'https://market.finance.sina.com.cn/downxls.php'
27 | r = requests.get(url, params=params, headers=headers)
28 | r.encoding = 'gbk'
29 | df= pd.read_table(io.StringIO(r.text), names=['time', 'price', 'change', 'volume', 'amount', 'type'],
30 | skiprows=[0])
31 | # #print(df)
32 | #当列表值大于三的时候 才转换对日期进行格式
33 | if len(df.index)>3:
34 | tempDatetime = datetime.strptime(Date, "%Y-%m-%d") #string--->datetime
35 | Date = tempDatetime.strftime("%Y%m%d" ) #datetime-->string
36 | df['date']=Date
37 | df=df.sort_values(by=['time'],ascending=True)#按列进行排序
38 | timelist=list(df['time'])
39 | # #print(len(timelist))
40 | if len(timelist)>3:
41 | try:
42 | times=[datetime.strptime(val, "%H:%M:%S").strftime("%H%M%S") for val in timelist]
43 | print(times)
44 | except Exception as e:
45 | print(params)
46 | print(headers)
47 | print(e)
48 | break
49 | df['time']=times
50 | # #print(df)
51 | df['datetime'] = df[['date', 'time']].apply(lambda x: ''.join(str(value) for value in x), axis=1)
52 | df['type']=df['type'].replace('卖盘',-1)
53 | df['type']=df['type'].replace('中性盘',0)
54 | df['type']=df['type'].replace('买盘',1)
55 | data = pd.concat([data, df])
56 | filePath = os.path.join(path, '{}.csv'.format(x1))
57 | data.to_csv(filePath)
58 |
59 | # for x1 in range(2017, 2019):
60 | # path = os.path.join(os.getcwd(), 'data')
61 | # filePath = os.path.join(path, '{}.csv'.format(x1))
62 | # print(filePath)
63 | params={'date': '2017-01-01', 'symbol': 'sz000868'}
64 | headers = {'X-Forwarded-For': '120.16.244.83'}
65 | url = 'https://market.finance.sina.com.cn/downxls.php'
66 | r = requests.get(url, params=params, headers=headers)
--------------------------------------------------------------------------------
/python学习资料/多线程.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# '''\n",
10 | "多任务可以由多进程完成,也可以由一个进程内的多线程完成\n",
11 | "我们前面提到了进程由若干线程组成的,一个进程至少有一个线程。\n",
12 | "由于线程是操作系统直接支持的执行单位,因此,高级语言通常都内置多线程的支持,python的线程是真正的Posix Thread,而不是模拟出来的线程。\n",
13 | "Python的标准库提供了两个模块:_thread和threading,_thread是低级模块,threading是高级模块,对_thread进行封装。绝大多数情况下,\n",
14 | "我们只需要使用threading这个高级模块\n",
15 | "'''\n",
16 | "import time,threading\n",
17 | "\n",
18 | "#新线程执行的代码\n"
19 | ]
20 | }
21 | ],
22 | "metadata": {
23 | "kernelspec": {
24 | "display_name": "Python 3",
25 | "language": "python",
26 | "name": "python3"
27 | }
28 | },
29 | "nbformat": 4,
30 | "nbformat_minor": 2
31 | }
32 |
--------------------------------------------------------------------------------
/python学习资料/爬虫/import-Spyder.py:
--------------------------------------------------------------------------------
1 | import urllib.request
2 | from urllib.request import HTTPPasswordMgrWithDefaultRealm,HTTPBasicAuthHandler,build_opener
3 | from urllib.error import URLError
4 |
5 | from urllib.request import ProxyHandler, build_opener
6 |
7 | import http.cookiejar
8 |
9 | import urllib.parse
10 |
11 | from urllib.parse import urlparse
12 | from urllib.parse import urlunparse
13 | from urllib.parse import urlsplit
14 | from urllib.parse import urlunsplit
15 | from urllib.parse import urljoin
16 | from urllib.parse import urlencode
17 | from urllib.parse import parse_qs
18 | from urllib.parse import parse_qsl
19 | from urllib.parse import quote
20 | from urllib.parse import unquote
21 |
22 | from urllib.robotparser import RobotFileParser
--------------------------------------------------------------------------------
/python学习资料/爬虫/robot.txt:
--------------------------------------------------------------------------------
1 | User-agent:*
2 | Disallow:/
3 | Allow:/public/
--------------------------------------------------------------------------------
/python学习资料/爬虫/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Wed Aug 14 07:51:48 2019
4 |
5 | @author: 26063
6 | """
7 |
8 | import requests
9 | from lxml import etree
10 | #伪装Chrome浏览器
11 |
12 | headers={'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/55.0.2883.87 Safari/537.36',}
13 | Url='https://www.hao24.com/channel/live_list.html'
14 | request=requests.get(Url, headers=headers)
15 | request.encoding='utf-8'
16 | html=request.text
17 | html_1=etree.HTML(html)
18 | # result = html_1.xpath('//*[@id="today_quan"]/ul[2]/li[2]/p[1]/a//text()')
19 | # print(result)
20 |
21 | result=html_1.xpath('//li//p[@class="one"]//text()')
22 | result_1=html_1.xpath('//li//p[@class="two"]//text()')
23 | result_2=html_1.xpath('//li//p[@class="title"]//a//text()')
24 | result_3=html_1.xpath('//li//p[@class="price"]//text()')
25 | for x in result:
26 | print(x)
27 |
28 |
29 | result = [x for x in result if '-' in x]
30 | for i in range(len(result)):
31 | print(result[i] + '|' + result_1[i] + '|' + result_2[i] + '|' + result_3[i])
32 |
33 |
34 |
35 |
--------------------------------------------------------------------------------
/python学习资料/算法/程序员的算法趣题.py:
--------------------------------------------------------------------------------
1 | '''
2 | 算法学习
3 | '''
4 | # 01回文数
5 | '''
6 | '''
7 | n = 10
8 | while True:
9 | x = n % 10答案
10 | '''
11 |
12 |
13 | # 03翻牌
14 | '''
15 | 解题
16 | '''
17 | import numpy as np
18 | ar = np.zeros(100)
19 | def myTest():
20 | for i in range(2,101):
21 | for j in range(i, 101, i):
22 | ar[j-1]=1 if (ar[j-1]==0) else 0
23 |
24 | myTest()
25 |
26 | '''
27 |
28 | card = np.zeros(100)
29 | for i in range(2, 101):
30 | j = i - 1
31 | while j < card.size:
32 | card[j]=1 if card[j]==0 else 0
33 | j += i
34 |
35 |
36 |
--------------------------------------------------------------------------------
/python学习资料/量化交易之路(第八章).py:
--------------------------------------------------------------------------------
1 | '''
2 | 以下代码为第七章海龟交易法
3 | '''
4 | # from abu.abupy import AbuFactorBuyBase
5 | #
6 | # class AbuFactorBuyBreak(AbuFactorBuyBase):
7 | # def __init__(self, **kwargs):
8 | # # 突破xd
9 | # self.xd = kwargs['xd']
10 | # # 忽略连续创新高,比如买入后第二天又突破新高,忽略
11 | # self.skip_days = 0
12 | # # 在输出生成的orders_pd中显示名字
13 | # self.factor_name = '{}:{}'.format(self.__class__.__name__,self.xd)
14 | #
15 | #
16 | # def fit_day(self, today):
17 | # day_ind = int(today.key)
18 | # # 忽略不符合买入日(统计周期内前xd天及最后一天)
19 | # if day_ind < self.xd - 1 or day_ind >= self.kl_pd.shape[0] - 1:
20 | # return None
21 | #
22 | # if self.skip_days > 0:
23 | # # 执行买入订单后的忽略
24 | # self.skip_days -= 1
25 | # return None
26 | #
27 | # # 今天收盘价格达到xd天内最高价格则符合条件
28 | # if today.close == self.kl_pd.close[day_ind - self.xd + 1:day_ind+1].max():
29 | # # 把xd赋值给忽略买入日,即xd天内再次又创新高,也不买了
30 | # self.skip_days = self.xd
31 | # # 生成买入订单
32 | # return self.make_buy_order(day_ind)
33 | # return None
34 |
35 | from abu.abupy import AbuFactorBuyBreak
--------------------------------------------------------------------------------
/python学习资料/量化交易之路(第六章).py:
--------------------------------------------------------------------------------
1 | '''
2 | 第六章 量化工具---数学
3 | '''
4 |
5 | # 6.1回归与插值
6 | '''
7 | 回归,指研究一组随机变量(Y1,Y2,Y3...Yi)和另一组(X1,X2,...,Xk)变量之间的关系的统计分析方法,又称多重回归分析。
8 | 通常Y1,Y2,...,Yi是因变量;X1,X2,...,Xk是自变量。
9 | '''
10 | '''
11 | 1.偏差绝对值之和最小(MAE)
12 | 2.偏差平方和最小(MSE)对误差极值的惩罚程度
13 | 3.偏差平方和开平方最小(RMSE)对误差的评估
14 | '''
15 | import numpy as np
16 | from abu.abupy import ABuSymbolPd
17 | tsla_close = ABuSymbolPd.make_kl_df('usTSLA').close
18 | x = np.arange(0, tsla_close.shape[0])
19 | y = tsla_close.values
20 |
21 | '''
22 | 下面通过statsmodels.api.OLS()函数实现一次多项式拟合计算,即最简单的y = kx + b。使用summary()函数可以看到
23 | Method = Least Squares,即使用了最小二乘法。
24 | '''
25 | import statsmodels.api as sm
26 | from statsmodels import regression
27 | import matplotlib.pyplot as plt
28 | def regress_y(y):
29 | y = y
30 | x = np.arange(0, len(y))
31 | x = sm.add_constant(x)
32 | # 使用OLS做拟合
33 | model = regression.linear_model.OLS(y,x).fit()
34 | return model
35 |
36 | model = regress_y(y)
37 | b = model.params[0]
38 | k = model.params[1]
39 |
40 | y_fit = k * x + b
41 | plt.plot(x, y)
42 | plt.plot(x, y_fit, 'r')
43 | # summary()函数模拟拟合概述
44 | model.summary()
45 |
46 |
47 | '''
48 | 按照公式计算
49 | MAE
50 | '''
51 | MAE = sum(np.abs(y-y_fit))/ len(y)
52 | print('偏差绝对值之和(MAE)={}'.format(MAE))
53 | MSE = sum(np.square(y-y_fit))/len(y)
54 | print('偏差绝对值之和(MSE)={}'.format(MSE))
55 | RMSE = np.sqrt(MSE)
56 | print('偏差绝对值之和(RMSE)={}'.format(RMSE))
57 |
58 |
59 | '''
60 | 多项式回归
61 | 观察上面的误差值,由于一次线性回归所以误差值很大,多项式回归拟合最简单的方式就是使用np.polynomial()函数
62 | '''
63 | # 以下计算1~9次多项式回归,计算MSE的值,可以看到随着poly的增大,MSE的值逐步降低
64 | from sklearn import metrics
65 | MAE2 = metrics.mean_absolute_error(y, y_fit)
66 | MSE2 = metrics.mean_squared_error(y, y_fit)
67 | RMSE2 = np.sqrt(MSE2)
68 |
69 | import itertools
70 | _, axs = plt.subplots(nrows=3, ncols=3, figsize=(15,15))
71 | axs_list = list(itertools.chain.from_iterable(axs))
72 |
73 | poly = np.arange(1, 10, 1)
74 | for p_cnt, ax in zip(poly, axs_list):
75 | # 使用polynomial.Chebyshev.fit()函数进行多项式拟合
76 | p = np.polynomial.Chebyshev.fit(x, y, p_cnt)
77 | # 使用p直接对x序列代入即得到拟合结果序列
78 | y_fit = p(x)
79 | # 度量mse值
80 | mse = metrics.mean_squared_error(y, y_fit)
81 | # 使用拟合次数和mse误差大小设置标题
82 | ax.set_title('{} poly MSE={}'.format(p_cnt, mse))
83 | ax.plot(x,y,'',x,y_fit,'r.')
84 |
--------------------------------------------------------------------------------
/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LancelotYe/TMChanQuant/d092f06409fff9dc118f6f32fca5d3610dfbac3d/test.png
--------------------------------------------------------------------------------