├── .github └── workflows │ └── node.js.yml ├── .gitignore ├── LICENSE ├── README.md ├── _config.yml ├── myscripts ├── comments.js ├── git_pull.sh ├── prepare.js ├── qiniu.conf └── qiniu_upload.sh ├── package.json ├── posts ├── README.md ├── about │ ├── about.md │ ├── about │ │ └── leave_a_msg.gif │ ├── link.md │ └── link │ │ └── link.png ├── blog-experience │ ├── blog-experience-1.md │ ├── blog-experience-1 │ │ ├── geektutu-css-chrome.webp │ │ └── loading.gif │ ├── blog-experience-2.md │ ├── blog-experience-2 │ │ └── origin_create.gif │ ├── blog-experience-3.md │ ├── blog-experience-3 │ │ ├── cnzz.jpg │ │ ├── recommend.jpg │ │ └── series_container.jpg │ ├── blog-experience-4.md │ ├── blog-experience-4 │ │ ├── cnzz_29_30.jpg │ │ ├── google.jpg │ │ ├── google_geektutu.jpg │ │ ├── google_search_console.jpg │ │ └── google_spam_report.jpg │ ├── blog-experience-5.md │ ├── blog-experience-5 │ │ ├── heart.jpg │ │ └── spam_score.png │ ├── blog-experience-6.md │ ├── blog-experience-6 │ │ ├── value.jpg │ │ └── value_pencil.jpg │ ├── blog-experience-7.md │ └── blog-experience-7 │ │ ├── also_on.jpg │ │ └── comment.jpg ├── cheat-sheet │ ├── cheat-sheet-sqlite.md │ └── cheat-sheet-sqlite │ │ ├── sqlite.jpg │ │ └── sqlite.jpg.bak ├── data-mining │ ├── pandas-cheat-sheet-zh-cn.md │ ├── pandas-cheat-sheet-zh-cn │ │ ├── 1.webp │ │ ├── 2.webp │ │ ├── Pandas_Cheat_Sheet_zh_CN.pdf │ │ ├── Pandas_Cheat_Sheet_zh_CN.pptx │ │ └── cheat_sheet_part.png │ ├── pandas-dataframe-series.md │ ├── pandas-dataframe-series │ │ └── pandas.gif │ ├── pandas-select-data.md │ └── pandas-select-data │ │ └── pandas.gif ├── pandora-box │ ├── awesome-config.md │ ├── awesome-config │ │ └── wsl.jpg │ ├── box-tools.md │ └── box-tools │ │ ├── colorhunt.jpg │ │ ├── imageoptim.jpg │ │ └── regex.jpg ├── quick-start │ ├── go │ │ ├── quick-go-context.md │ │ ├── quick-go-context │ │ │ └── context_sm.jpg │ │ ├── quick-go-gin.md │ │ ├── quick-go-gin │ │ │ ├── gin.jpg │ │ │ └── hello_gin.jpg │ │ ├── quick-go-mmap.md │ │ ├── quick-go-mmap │ │ │ ├── mmap.jpg │ │ │ └── mmap_sm.jpg │ │ ├── quick-go-protobuf.md │ │ ├── quick-go-protobuf │ │ │ ├── go-protobuf.jpg │ │ │ └── protocol-buffers.jpg │ │ ├── quick-go-rpc.md │ │ ├── quick-go-rpc │ │ │ ├── go-rpc.jpg │ │ │ └── rpc-procedure.jpg │ │ ├── quick-go-test.md │ │ ├── quick-go-test │ │ │ └── go_test.jpg │ │ ├── quick-go-wasm.md │ │ ├── quick-go-wasm │ │ │ ├── callback.png │ │ │ ├── go-wasm.jpg │ │ │ ├── hello_world.png │ │ │ └── register_functions.png │ │ ├── quick-go2.md │ │ ├── quick-go2 │ │ │ └── go2.jpg │ │ ├── quick-golang.md │ │ ├── quick-golang │ │ │ └── golang.jpg │ │ ├── quick-gomock.md │ │ └── quick-gomock │ │ │ ├── gomock.jpg │ │ │ └── gomock_logo.jpg │ ├── python │ │ ├── quick-python.md │ │ └── quick-python │ │ │ └── python.jpg │ └── rust │ │ ├── quick-rust.md │ │ └── quick-rust │ │ └── rust.jpg ├── summary │ ├── 2020 │ │ ├── 2020.jpg │ │ └── data.png │ └── 2020.md └── tensorflow │ ├── tensorflow-make-npy-hdf5-data-set.md │ ├── tensorflow-make-npy-hdf5-data-set │ └── gen_mnist_images.png │ ├── tensorflow-mnist-save-ckpt.md │ ├── tensorflow-mnist-save-ckpt │ └── save_ckpt.png │ ├── tensorflow-mnist-simplest.md │ ├── tensorflow-mnist-simplest │ ├── loss.png │ └── x_y.png │ ├── tensorflow-mnist-tensorboard-training.md │ ├── tensorflow-mnist-tensorboard-training │ ├── tensorboard_mnist_graph.png │ └── tensorbord_mnist_loss.png │ ├── tensorflow2-gym-dqn.md │ ├── tensorflow2-gym-dqn │ ├── dqn.jpg │ ├── mountaincar_v0_scores.jpg │ └── mountaincar_v0_success.gif │ ├── tensorflow2-gym-nn.md │ ├── tensorflow2-gym-nn │ ├── cartpole_v0_failed.gif │ └── cartpole_v0_success.gif │ ├── tensorflow2-gym-pg.md │ ├── tensorflow2-gym-pg │ ├── pg_optimize.jpg │ ├── pg_plot.jpg │ └── pg_success.gif │ ├── tensorflow2-gym-q-learning.md │ ├── tensorflow2-gym-q-learning │ ├── mountaincar_v0_failed.gif │ └── mountaincar_v0_success.gif │ ├── tensorflow2-mnist-cnn.md │ └── tensorflow2-mnist-cnn │ └── cnn_image_sample.gif ├── scaffolds ├── draft.md ├── page.md └── post.md ├── source ├── 404.md ├── CNAME ├── ads.txt ├── archives │ └── index.md ├── bdunion.txt ├── img │ ├── bg.jpg │ ├── icon.png │ └── related_links │ │ ├── email.png │ │ ├── geekcircle.png │ │ ├── github.png │ │ ├── go.png │ │ ├── rss.jpg │ │ ├── weibo.jpg │ │ └── zhihu.png ├── index.md ├── jd_root.txt ├── robots.txt ├── root.txt ├── series │ └── index.md ├── sogousiteverification.txt ├── tags │ └── index.md └── tool │ └── .gitkeep └── yarn.lock /.github/workflows/node.js.yml: -------------------------------------------------------------------------------- 1 | name: public blog 2 | 3 | on: 4 | workflow_dispatch: 5 | ref: refs/heads/master 6 | schedule: 7 | - cron: '0 21 * * *' 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ubuntu-latest 13 | 14 | strategy: 15 | matrix: 16 | node-version: [14.x] 17 | 18 | environment: QQ 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Use Node.js ${{ matrix.node-version }} 23 | uses: actions/setup-node@v1 24 | with: 25 | node-version: ${{ matrix.node-version }} 26 | - name: yarn build 27 | env: 28 | QQ_NAME: ${{ secrets.QQ_NAME }} 29 | QQ_SECRETKEY: ${{ secrets.QQ_SECRETKEY }} 30 | QQ_ACCESSKEY: ${{ secrets.QQ_ACCESSKEY }} 31 | run: | 32 | echo $QQ_NAME 33 | yarn install 34 | yarn update 35 | yarn comment 36 | yarn build 37 | bash myscripts/qiniu_upload.sh -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .vscode 3 | .deploy_git 4 | node_modules 5 | public 6 | source/_posts/* 7 | db.json 8 | baidu_push.sh 9 | urls.txt 10 | qshell* 11 | # github comments 12 | source/tool/comments.json 13 | source/tool/issues.json 14 | # integrated posts 15 | themes/geektutu 16 | posts/7days-golang 17 | posts/7days-python 18 | posts/interview-questions 19 | posts/tensorflow2-docs-zh 20 | posts/high-performance-go 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 极客兔兔的博客 2 | 3 | [![github actions](https://github.com/geektutu/geektutu-blog/workflows/public%20blog/badge.svg)](https://github.com/geektutu/geektutu-blog/actions) 4 | 5 | ## 在线阅读 6 | 7 | Coding 创建有趣的开源项目,戳:[https://geektutu.com/](https://geektutu.com/) 8 | 9 | ## 订阅我的博客 10 | 11 | 最新动态可以关注:知乎 [Go语言](https://www.zhihu.com/people/gzdaijie) 或微博 [极客兔兔](https://weibo.com/geektutu) 12 | 13 | 订阅方式:右上角 **watch** [geektutu/blog](https://github.com/geektutu/blog) ,每篇文章都能收到邮件通知,或通过 [RSS](https://geektutu.com/feed.xml) 订阅。 14 | 15 | **较为完整的系列有:** 16 | 17 | - 一篇文章入门系列 18 | - [一篇文章入门 Python](https://geektutu.com/post/quick-python.html) 19 | - [一篇文章入门 Go](https://geektutu.com/post/quick-golang.html) 20 | - [一篇文章入门 Rust](https://geektutu.com/post/quick-rust.html) 21 | 22 | - Go 语言 23 | - [七天用Go从零实现系列](https://geektutu.com/post/gee.html) 24 | - [Go 语言高性能编程](https://geektutu.com/post/high-performance-go.html) 25 | - [Go 语言笔试面试题](https://geektutu.com/post/qa-golang.html) 26 | 27 | - 机器学习 28 | - [tensorflow mnist 入门系列](https://geektutu.com/post/tensorflow-mnist-simplest.html) 29 | - [tensorflow openai 强化学习系列](https://geektutu.com/post/tensorflow2-gym-nn.html) 30 | - [tensorflow 2.0 文档](https://geektutu.com/post/tf2doc.html) 31 | 32 | - 经历与感悟 33 | - [建站经历](https://geektutu.com/post/blog-experience-1.html) 34 | - [年终总结](https://geektutu.com/post/2020.html) 35 | 36 | ## 关于 hexo 主题 37 | 38 | 使用主题 [hexo-theme-geektutu](https://github.com/geektutu/hexo-theme-geektutu) 39 | 40 | ```bash 41 | yarn install # 安装依赖模块 42 | yarn update # 下载主题到 themes/geektutu 43 | yarn build # 将posts的文章拷贝到source目录下的_posts,并执行hexo clean, hexo generate 44 | yarn deploy # 部署到_config.xml中配置的仓库地址 45 | ``` 46 | 47 | 如果你使用的是 yarn,将下面的 npm 换成 yarn 即可。 48 | 49 | `posts` 目录的存在仅仅是为了博主做博客分类使用, `yarn build` 时会拷贝到 `source/_posts`。 50 | 因此, 直接新建 `source/_posts` 目录,并直接在该文件夹下写文章是没有问题的。 51 | 52 | 可以在 package.json 里的 scripts 部分,定制你自己的 npm/yarn 命令。 53 | 54 | 本站使用对象存储 + CDN 方式托管在[七牛云](https://marketing.qiniu.com/cps/redirect?redirect_id=4&cps_key=1hetil5x65e8i) 55 | 56 | - [账号配置](https://github.com/qiniu/qshell) 57 | - [上传配置](https://github.com/qiniu/qshell/blob/master/docs/qupload.md) 58 | 59 | ``` 60 | qshell user ls 61 | qshell account -- 62 | qshell qupload xxx.conf 63 | ``` 64 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | # Hexo Configuration 2 | ## Docs: https://hexo.io/docs/configuration.html 3 | 4 | # Site 5 | title: 极客兔兔 6 | keyword: geektutu,jiketutu,Go语言,golang,python 7 | description: 极客兔兔的博客,致力于分享一些技术教程和有趣的技术实践,包括但不限于 Go 语言/golang, Python, tensorflow, 分布式, 机器学习相关的内容。也可以通过搜索极客小站,jiketutu 找到我。 8 | author: 极客兔兔 9 | language: zh-CN 10 | 11 | url: https://geektutu.com 12 | root: / 13 | permalink: post/:title.html 14 | 15 | # seo优化各个浏览器的验证信息 16 | beian: 沪ICP备18001798号-1 17 | seo_title: 极客兔兔 # 子页面的后缀,效果: 关于我 | 极客兔兔,如果与title一致,则可以不设置 18 | seo: 19 | google_site_verification: X31OY0d2sKImNbS2x8-OpITWRK4nn7nZCcXe3_hA9ew 20 | baidu_site_verification: p7Pz3jlx4t 21 | ms_site_verification: 7E2AEE3378AC93764DEAB411177A21A1 22 | _360_site_verification: 1d1e81a1a48a1ed2e3308d7cb8e548e7 # 不能以数字开头,所以加上了下划线 23 | 24 | # 开启博客资源的相对链接 https://hexo.io/zh-cn/docs/asset-folders 25 | post_asset_folder: true 26 | skip_render: 27 | - tool/comments.json 28 | - lab/**/* 29 | 30 | # Site settings 31 | header_icon: img/icon.png 32 | 33 | theme: geektutu 34 | 35 | # Deployment 36 | ## Docs: https://hexo.io/docs/deployment.html 37 | deploy: 38 | type: git 39 | repository: https://git.coding.net/gzdaijie/geektutu-blog.git 40 | branch: coding-pages 41 | 42 | gitalk: 43 | client_id: 'c1fdd456a4caae5f7df0' 44 | client_secret: 'b2674451e21feae50520f99337ec15d2aebe7879' 45 | accessToken: '513dc2f1e6df8c07a12fbf1547874f54679d9712' 46 | repo: 'blog' 47 | owner: 'geektutu' 48 | 49 | # alipay wechatpay 的二维码转换的网址,不使用二维码图片,提高加载速率。 50 | reward: 51 | alipay: HTTPS://QR.ALIPAY.COM/FKX060337TUXBAX9LIFJE8 52 | wechat: wxp://f2f0qgGBlfD1nZXjvBjievxB0z0fc0W2sBq5 53 | 54 | related_links: 55 | - link: https://github.com/geektutu 56 | img: img/related_links/github.png 57 | name: Github 58 | - link: https://geektutu.com/feed.xml 59 | img: img/related_links/rss.jpg 60 | name: RSS 61 | - link: https://www.zhihu.com/people/gzdaijie 62 | img: img/related_links/zhihu.png 63 | name: 知乎 64 | - link: https://weibo.com/geektutu 65 | img: img/related_links/weibo.jpg 66 | name: 微博 67 | 68 | # 文章末尾配置一些提示信息,可选,默认关闭 69 | post_tips: 70 | find_me: true # 本站永久域名「 xx.domain 」, 也可以通过搜索「 xx 」找到我。 71 | 72 | widgets: 73 | caidan: false # 开启彩蛋,点击屏幕会随机弹出 '点个赞','留个言'等 74 | busuanzi: true # 开启不蒜子统计 https://busuanzi.ibruce.info/ 75 | 76 | # 自定义推广链接:宽度250,高度不限。 77 | # custom_ads: 78 | # - link: https://s.click.taobao.com/0ZEThqv 79 | # img: https://img.alicdn.com/tfs/TB1QJC6dOrpK1RjSZFhXXXSdXXa-300-100.jpg 80 | # - link: https://url.cn/5bE04oI 81 | # img: https://s2.ax1x.com/2020/02/12/1bns2D.jpg 82 | 83 | 84 | ba_auto_push: true # Baidu 自动推送,以下均为可选配置 85 | ba_track_id: 1a0ec38c52c08db815b0046c2783b1aa # Baidu Analytics 86 | google_analysis: UA-142641425-1 # Google Analytics 87 | 88 | feed: 89 | enable: true 90 | limit: 0 91 | type: rss2 92 | content: false 93 | path: feed.xml 94 | icon: img/icon.png 95 | 96 | index_rec: 97 | - link: https://geektutu.com/post/gee.html 98 | name: 📘 七天用 Go 从零实现 99 | - link: https://geektutu.com/post/high-performance-go.html 100 | name: 📘 Go 语言高性能编程 101 | - link: https://studygolang.com 102 | name: 🔗 Go 语言中文网 -------------------------------------------------------------------------------- /myscripts/comments.js: -------------------------------------------------------------------------------- 1 | const https = require('https'); 2 | const fs = require('fs'); 3 | const path = require('path'); 4 | 5 | const COMMENTS_FILE=path.join(__dirname, '..', 'source/tool/comments.json'); 6 | const ISSUES_FILE=path.join(__dirname, '..', 'source/tool/issues.json'); 7 | 8 | const github = { 9 | repo: 'blog', 10 | owner: 'geektutu', 11 | basicAuth: 'YzFmZGQ0NTZhNGNhYWU1ZjdkZjA6YjI2NzQ0NTFlMjFmZWFlNTA1MjBmOTkzMzdlYzE1ZDJhZWJlNzg3OQ==' 12 | } 13 | 14 | const PREFIX = `/repos/${github.owner}/${github.repo}/` 15 | const PAGING = '&sort=created&direction=desc&per_page=100' 16 | 17 | class Comments { 18 | constructor() { 19 | this.comments = [] 20 | this.issueMap = [] 21 | this.obj = {} 22 | } 23 | 24 | deltaDate(old) { 25 | let hours = (Date.now() - new Date(old)) / 1000 / 3600 26 | 27 | let years = Math.floor(hours / 24 / 365) 28 | if (years) { 29 | return `${years}年前` 30 | } 31 | let months = Math.floor(hours / 24 / 30) 32 | if (months) { 33 | return `${months}月前` 34 | } 35 | 36 | let days = Math.floor(hours / 24) 37 | if (days) { 38 | return `${days}天前` 39 | } 40 | hours = Math.floor(hours) 41 | return `${hours}小时前` 42 | } 43 | 44 | async parse() { 45 | this.comments = await this.get(`issues/comments?${PAGING}`) 46 | console.log(`comments.length: ${this.comments.length}`) 47 | await this.writeComments() 48 | } 49 | 50 | async writeIssues() { 51 | let postMap = {} 52 | let page = 0 53 | let size = 0 54 | while(true) { 55 | page++ 56 | let issues = await this.get(`issues?labels=Gitalk&per_page=100&page=${page}`) 57 | for (const issue of issues) { 58 | let post = issue.labels.find(label => label.name.startsWith("/")).name 59 | postMap[post] = { 60 | "url": issue.html_url, 61 | "comments": issue.comments 62 | } 63 | size++ 64 | } 65 | if (issues.length < 100) { 66 | break 67 | } 68 | } 69 | 70 | fs.writeFileSync(ISSUES_FILE, JSON.stringify(postMap), { encoding: 'utf-8' }); 71 | console.log(`write ${ISSUES_FILE} ${size} success!`) 72 | } 73 | 74 | async fetchIssue(issueUrl) { 75 | let issueApi = issueUrl.slice(issueUrl.indexOf(PREFIX) + PREFIX.length) 76 | if (!this.issueMap[issueApi]) { 77 | let issue = await this.get(issueApi) 78 | if (!issue.labels.find(label => label.name === 'Gitalk')) { 79 | return 80 | } 81 | issue.post = issue.labels.find(label => label.name.startsWith("/")).name 82 | issue.title = issue.title.split('|')[0].trim() 83 | this.issueMap[issueApi] = issue 84 | } 85 | return this.issueMap[issueApi] 86 | } 87 | 88 | async writeComments() { 89 | let simpleComments = {} 90 | for (const comment of this.comments) { 91 | let issue = await this.fetchIssue(comment.issue_url) 92 | if (!issue) { 93 | continue 94 | } 95 | if (issue.user.login === comment.user.login) { 96 | continue 97 | } 98 | if (simpleComments[issue.post]) { 99 | continue 100 | } 101 | 102 | simpleComments[issue.post] = { 103 | title: issue.title, 104 | url: issue.post, 105 | count: issue.comments, 106 | user: comment.user.login, 107 | icon: comment.user.avatar_url, 108 | date: this.deltaDate(comment.created_at), 109 | body: comment.body 110 | .replace(/�/g, "") 111 | .replace(//g, " ") 113 | .replace(/\s+/g, " ") 114 | .trim() 115 | } 116 | } 117 | let obj = Object.keys(simpleComments).map(key => simpleComments[key]) 118 | fs.writeFileSync(COMMENTS_FILE, JSON.stringify(obj), { encoding: 'utf-8' }); 119 | console.log(`write ${COMMENTS_FILE} ${obj.length} success!`) 120 | } 121 | 122 | get(api) { 123 | let options = { 124 | hostname: 'api.github.com', 125 | path: `${PREFIX}${api}`, 126 | headers: { 127 | 'User-Agent': 'Node Https Client', 128 | 'Authorization': `Basic ${github.basicAuth}`, 129 | } 130 | }; 131 | console.log(`GET ${options.path}`) 132 | return new Promise((resolve, reject) => { 133 | const req = https.get(options, (res) => { 134 | let data = ''; 135 | res.on('data', (chunk) => data += chunk); 136 | res.on('end', () => resolve(JSON.parse(data))); 137 | }); 138 | req.on('error', (e) => reject(e)); 139 | req.end(); 140 | }); 141 | } 142 | 143 | } 144 | 145 | 146 | (async () => { 147 | client = new Comments() 148 | await client.writeIssues() 149 | await client.parse() 150 | })(); 151 | 152 | 153 | -------------------------------------------------------------------------------- /myscripts/git_pull.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -eou pipefail 3 | 4 | root=$(cd $(dirname $0)/..; pwd) 5 | 6 | pull() { 7 | repo="$1" 8 | dest="$2" 9 | echo "update $dest" 10 | if [ -d "$dest" ]; then 11 | git -C "$dest" pull 12 | else 13 | git clone "https://github.com/$repo" "$dest" 14 | fi; 15 | } 16 | 17 | cd $root 18 | pull geektutu/hexo-theme-geektutu themes/geektutu 19 | pull geektutu/7days-golang posts/7days-golang 20 | pull geektutu/7days-python posts/7days-python 21 | pull geektutu/interview-questions posts/interview-questions 22 | pull geektutu/tensorflow2-docs-zh posts/tensorflow2-docs-zh 23 | pull geektutu/high-performance-go posts/high-performance-go 24 | -------------------------------------------------------------------------------- /myscripts/prepare.js: -------------------------------------------------------------------------------- 1 | const fs = require('fs'); 2 | const path = require('path'); 3 | const exec = require('child_process').execSync; 4 | 5 | const ROOT_PATH=path.join(__dirname, '..'); 6 | const SOURCE_POST_DIR = path.join(ROOT_PATH, 'source/_posts/'); 7 | const EXCLUDE_DIR = [path.join(ROOT_PATH, 'posts/interview-questions/ml/')] 8 | 9 | const walkDir = (dir, callback) => { 10 | fs.readdirSync(dir).forEach(f => { 11 | if (f.startsWith('.')) { 12 | return; 13 | } 14 | let dirPath = path.join(dir, f); 15 | let isDirectory = fs.statSync(dirPath).isDirectory(); 16 | isDirectory ? 17 | walkDir(dirPath, callback) : callback(path.join(dir, f)); 18 | }); 19 | }; 20 | 21 | exec(['rm -rf', SOURCE_POST_DIR].join(' ')) 22 | exec(['mkdir -p', SOURCE_POST_DIR].join(' ')) 23 | 24 | let count = 0; 25 | walkDir(path.join(ROOT_PATH, "posts"), (filePath) => { 26 | if (!filePath.endsWith('.md') || filePath.endsWith('README.md') 27 | || EXCLUDE_DIR.find(d => filePath.startsWith(d))) { 28 | return; 29 | } 30 | count++; 31 | let assetsDirPath = filePath.slice(0, -3); 32 | exec(['cp -p', filePath, SOURCE_POST_DIR].join(' ')) 33 | if (fs.existsSync(assetsDirPath)) { 34 | exec(['cp -r -p', assetsDirPath, SOURCE_POST_DIR].join(' ')) 35 | } 36 | }) 37 | 38 | console.log(`total: ${count}`) 39 | 40 | 41 | -------------------------------------------------------------------------------- /myscripts/qiniu.conf: -------------------------------------------------------------------------------- 1 | { 2 | "src_dir" : "public", 3 | "bucket" : "geektutu-blog", 4 | "overwrite": true, 5 | "rescan_local" : true, 6 | "log_level": "warn", 7 | "log_stdout": true 8 | } -------------------------------------------------------------------------------- /myscripts/qiniu_upload.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -eou pipefail 3 | 4 | root=$(cd $(dirname $0)/..; pwd) 5 | 6 | echo "download qiniu" 7 | cd $root/myscripts 8 | rm -f qshell* 9 | wget http://devtools.qiniu.com/qshell-linux-x86-v2.4.1.zip 10 | unzip qshell-linux-x86-v2.4.1.zip 11 | mv ./qshell-linux-x86-v2.4.1 qshell 12 | export PATH=$root/myscripts:$PATH 13 | ./qshell account ${QQ_ACCESSKEY} ${QQ_SECRETKEY} ${QQ_NAME} 14 | 15 | # upload 16 | yarn qiniu -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "hexo-site", 3 | "version": "0.0.0", 4 | "private": true, 5 | "hexo": { 6 | "version": "5.2.0" 7 | }, 8 | "scripts": { 9 | "update": "bash myscripts/git_pull.sh", 10 | "comment": "node myscripts/comments.js", 11 | "build": "node myscripts/prepare.js && hexo clean && hexo generate", 12 | "start": "hexo server --draft", 13 | "deploy": "hexo deploy", 14 | "qiniu": "qshell qupload -c 50 myscripts/qiniu.conf" 15 | }, 16 | "dependencies": { 17 | "hexo": "5.2.0", 18 | "hexo-deployer-git": "^3.0.0", 19 | "hexo-generator-category": "^1.0.0", 20 | "hexo-generator-index": "2.0.0", 21 | "hexo-generator-sitemap": "2.1.0", 22 | "hexo-generator-tag": "1.0.0", 23 | "hexo-renderer-ejs": "1.0.0", 24 | "hexo-renderer-marked": "4.0.0", 25 | "hexo-renderer-stylus": "2.0.1", 26 | "hexo-server": "2.0.0", 27 | "hexo-generator-feed": "3.0.0" 28 | } 29 | } -------------------------------------------------------------------------------- /posts/README.md: -------------------------------------------------------------------------------- 1 | # 极客兔兔博客合集 2 | 3 | 博客地址 [geektutu.com](https://geektutu.com) 4 | -------------------------------------------------------------------------------- /posts/about/about.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: 留言板 3 | date: 2017-07-03 11:51:24 4 | description: 极客兔兔的小站,致力于分享一些技术教程和有趣的技术实践。 5 | tags: 6 | - 关于我 7 | image: post/about/leave_a_msg.gif 8 | --- 9 | 10 | ## 本站介绍 11 | 12 | 极客兔兔的小站,致力于分享一些技术教程和有趣的技术实践,主要以 Python 相关的技术为主。 13 | 14 | 对本站的任何问题都可以直接在文章下进行评论。 15 | 16 | 也可以通过[gzdaijie@gmail.com](mailto:极客兔兔?subject=【来自】极客兔兔的博客)联系我。 17 | 18 | 为了节省双方的宝贵时间,作如下几点声明: 19 | 20 | - 邮件会尽量回复,不能做到有问必答。 21 | - 请尽量在邮件中将需求或问题叙述清楚。 22 | 23 | ## 版权说明 24 | 25 | 若无特别声明,本站文章欢迎链接分享,但禁止全文转载,本站保留一切权利。 26 | 27 | ## 支持本站 28 | 29 | 坚持探索、尝试新技术,并作总结并不是一件容易的事情,如能对您有所帮助,我感到非常高兴。 30 | 31 | 如果觉得本站文章还不错,欢迎收藏&分享给小伙伴。 32 | 33 | 本站使用主题已**开源**在[Github](https://github.com/geektutu/hexo-theme-geektutu),请不要吝惜您的`Star`! 34 | -------------------------------------------------------------------------------- /posts/about/about/leave_a_msg.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/about/about/leave_a_msg.gif -------------------------------------------------------------------------------- /posts/about/link.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: 友情链接 3 | date: 2017-07-03 11:51:24 4 | description: 极客兔兔的友情链接 5 | tags: 6 | - 友链 7 | image: post/link/link.png 8 | --- 9 | 10 | 11 | 12 | ## 互粉 13 | 14 | | 名称 | 描述 | 15 | |---| ---| 16 | | [姬長信](https://blog.isoyu.com) | 致力为移动开发者提供资讯、问答、代码下载、工具库等服务 | 17 | | [Mayx的博客](https://mabbs.github.io) | Mayx’s Home Page | 18 | | [kok的笔记本](https://wocai.de) | 写一些杂七杂八的内容 | 19 | | [Chinsyo\|晨晓](https://chinsyo.com) | 一个工程师的迷思与随想 | 20 | | [辛未羊的博客](https://panqiincs.me) | 人生如逆旅,我亦是行人 | 21 | | [hutusi](http://hutusi.com) | 浪漫主义诗人 | 22 | | [寒夏汢](https://hanxiatu.com) | 做一个有趣的人 | 23 | | [肆月之风](https://acme.top) | 爱技术、爱分享 | 24 | | [JustSong](https://iamazing.cn) | 计算机专业求学过程中的思考 | 25 | | [Muniao](https://www.qtmuniao.com) | 摄影,技术,分布式 | 26 | | [IT人技术博客](https://itren.tech) | 站点描述:记录所学、所感、所想 | 27 | | [wujunze blog](https://wujunze.com) | 一个程序员的博客 | 28 | | [zhang0peter](https://zhang0peter.com) | 记录学习与生活 | 29 | | [格物](https://shockerli.net) | 做有进步的分享 | 30 | | [Ryan的技术笔记](https://yuanxuxu.com) | 记录全栈开发以及区块链开发技术笔记 | 31 | | [四畳半神话大系](http://blog.amoyiki.com/) | 一个后端程序员的踩坑经历 | 32 | | [影留](https://leftshadow.com) | 读书、念经 | 33 | | [BBruceyuan](https://bbruceyuan.github.io) | 站点描述:没有故事,写点技术,同时打点酱油! | 34 | | [曹阿宇的博客](https://www.caoayu.xyz) | 简单的分享和记录 | 35 | | [Schwarzeni's blog](https://blog.schwarzeni.com/) | Welcome to my secret garden, coding & life | 36 | | [HelloWorld](https://helloworld.net) | 专业开发者平台 | 37 | | [李文周的博客](https://www.liwenzhou.com) | Go 语言学习之路 | 38 | | [真白的年轮面包](https://mashiro.best) | 此生无悔恋真白,来世愿入樱花庄 | 39 | | [呱牛笔记](https://guaniutech.com/) | 瓜牛的个人技术博客 | 40 | | [lipsuper](https://www.lipsuper.com) | 走在全栈路上,专注产出高质量原创手打文章 | 41 | | [倪旭晨的技术博客](http://nxc.1207game.com/) | 一个全栈开发者的技术博客 | 42 | 43 | ## 注意 44 | 45 | **留言可以互换友链,先将本站添加为友链,再到评论区留言,一天内处理。** 46 | **希望这篇文章[《友链这件事,没那么简单》](https://geektutu.com/post/blog-experience-5.html)对寻求友链的你有帮助。** 47 | 48 | > 本站文章均为原创,要求互换友链的你,也是原创博主。 49 | > 拒绝和聚合站、采集站、转载站做朋友。 50 | > 格式: 51 | > 站点名称:极客兔兔 52 | > 站点描述:致力于分享有趣的技术实践(30字左右,太长一行显示不下) 53 | > 站点地址:https://geektutu.com 54 | > 已加友链:https://geektutu.com/post/link.html 55 | 56 |
57 |
-------------------------------------------------------------------------------- /posts/about/link/link.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/about/link/link.png -------------------------------------------------------------------------------- /posts/blog-experience/blog-experience-1.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: 博客折腾记(一) - 极致性能的尝试 3 | date: 2019-06-17 20:10:00 4 | description: 极客兔兔(Geektutu)的博客折腾记的第一篇,介绍了博主从Koa+React到Hexo的切换,以及为达到极致性能的一些尝试。 5 | keywords: 6 | - 独立博客 7 | - 极致性能 8 | tags: 9 | - 关于我 10 | categories: 11 | - 建站经历 12 | nav: 杂谈 13 | image: post/blog-experience-1/loading.gif 14 | --- 15 | 16 | ## Koa+React的尝试 17 | 18 | 一开始搭建静态博客是在2017年,大学时候买的腾讯云的服务器还没过期,就想着要么就搭个个人小站吧。2017年的时候,流量基本都往移动端转了,所以也没想着自己的博客会有多大的流量,权当是个人学习的一些记录吧。不过当时的服务器是1M1G的配置,要是一个页面超过128KB,光加载就不止1s了。现在动辄100M的家宽,这个速度是不可忍受的。当时试过了好几个博客框架,总是会引入很多JS库,有些一个小小的动画,就得引入jQuery和Bootstrap两个库,五六秒才能加载完。所以就下定决心写一个极简的框架,动画都通过原生JS实现,静态博客也不太可能有太过复杂的动画效果,引用三方库的话杀鸡用牛刀了。 19 | 20 | 当时选择的技术栈是Koa+React,为了SEO采用了服务端渲染(Server Render)的方式,每个页面大小维持在50KB左右,半秒内加载完毕,博客内的图片,上传到免费的图床。大概维护了1年左右,就无心折腾了,转向了现在的Hexo静态博客,代码归档在了[Github - React Server Render Blog](https://github.com/geektutu/hexo-theme-geektutu/tree/react-server-render-blog)。 21 | 22 | ## 转向Hexo,自己写主题 23 | 24 | 转向Hexo的原因有很多,中间有几个月太忙了,没关注博客了,后来想起来的时候,博客已经挂了好长一段时间了,不知道是服务器受攻击,还是博客框架写的有问题,宕机了。另外还有一点是,为了保证加载速度,只使用了React渲染出来的静态页面,而生成的JS太大了,JS没引入页面,而是通过写原生JS放在额外的文件里加载,这就导致,每次添加需要JS交互的功能都非常痛苦。 25 | 26 | 后来偶然发现了一个网站,[Hexo Theme](https://hexo.io/themes/index.html),看到了很多人分享的静态主题,隐约觉得Hexo可扩展性很强,而且功能也完全能满足我的需求,因此就尝试了一下。一开始也遇到了之前页面过大的问题,很多主题看上去很不错,但是各种CSS库,JS库引用个不停,看着浏览器的小圆圈转个两秒真心是受不了了。另外再加上对自己原来设计的样式还是有点感情的,就自己动手撸主题了,这个主题也开源在了[Github - hexo-theme-geektutu](https://github.com/geektutu/hexo-theme-geektutu)。 27 | 28 | 迁移原有博客所有的样式和功能花了2天时间,写完后,惊叹于Hexo的可扩展性,包括可以定制生成文章的永久链接(permalink),我原有博客的链接是`/post/title.html`的格式,中间试过python的静态博客框架mkdocs不支持定制。各种小功能小特性的开发效率也很高,所以迁移到Hexo的那一个月的时间,每天都处于很兴奋的状态,添加了很多之前想加但出于拖延没有加的功能。 29 | 30 | ## 极致性能之小于10KB的CSS 31 | 32 | 1. 响应式布局,裁剪了[bootstrap.css](http://v3.bootcss.com/components/),仅使用了其中col-lg,col-md,col-sm, col-xs部分很少量的代码,图中的layout.css,仅1.8KB。 33 | 2. 兼容不同浏览器,初始化CSS, CSS Reset采用的是很小巧的[minireset.css](https://github.com/jgthms/minireset.css),不到1KB。 34 | 3. 渲染markdown的css,采用了markdown-it的一个模板,按照自己的颜色喜好更改过,3.5KB 35 | 4. 代码高亮,选取了`night`主题的`hightlight.css`,1.1KB 36 | 5. 公共的CSS代码放在了geektutu.css,不到1K。其余,每个组件的CSS代码,都以` 64 | 73 | ``` 74 | 75 | 使用 JavaScript 的 fetch 函数请求`comments.json`的数据,随机选取 4 条,动态地生成 HTML,插入到 `gitalk-related`中。这样每篇文章推荐的评论都不一样了。效果的话,相信读到这里你已经看到了~ 76 | 77 | 生成 comments.json 和动态渲染 HTML 的代码都放到 [Github](https://gist.github.com/geektutu/f379d87767787979507a0e4a20da64ba)。CSS 色彩基本是按照 Gitalk 来写的,这份代码对所有的使用 Github Issues 作为评论系统的童鞋都是可用的,也是静态博客完善评论系统的一个思路吧。 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /posts/blog-experience/blog-experience-7/also_on.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/blog-experience/blog-experience-7/also_on.jpg -------------------------------------------------------------------------------- /posts/blog-experience/blog-experience-7/comment.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/blog-experience/blog-experience-7/comment.jpg -------------------------------------------------------------------------------- /posts/cheat-sheet/cheat-sheet-sqlite.md: -------------------------------------------------------------------------------- 1 | --- 2 | cheat_sheet: true 3 | title: SQLite 常用命令 4 | seo_title: 速查表(Cheat Sheet) 5 | date: 2020-03-02 23:51:24 6 | description: SQLite 常见的使用命令。包括数据库的安装与连接,表(table)的创建(create)与删除(delete),记录的插入(insert)删除(delete)查询(select)改(update),以及事务(transaction)等操作。 7 | tags: 8 | - Cheat Sheet 9 | keywords: 10 | - SQLite 使用教程 11 | - 数据库 12 | nav: 百宝箱 13 | categories: 14 | - SQLite 速查表 15 | image: post/cheat-sheet-sqlite/sqlite.jpg 16 | --- 17 | 18 | ## SQLite 命令 19 | 20 |
21 | 22 | ### 安装/连接 23 | 24 | ```bash 25 | > apt-get install sqlite3 # ubuntu 26 | > sqlite3 -version # 查看版本 27 | 3.22.0 ... 28 | > sqlite3 gee.db # 连接数据库,不存在新建 29 | sqlite> .help # 帮助文档 30 | .archive ... xxx 31 | .auth ON|OFF xxx 32 | .backup ?DB? FILE xxx 33 | ... 34 | ``` 35 |
36 | 37 |
38 | 39 | ### 数据库操作 40 | 41 | ```bash 42 | > .help # 显示帮助文档 43 | ... 44 | > .databases # 显示数据库名称及对应文件 45 | main: /tmp/gee.db 46 | > .output FILE # 将输出定向到 FILE 47 | > .show # 显示已经设置的值 48 | > .dump # 以 SQL 格式 dump 数据库 49 | > .dump users # dump 某张表 50 | > .backup FILE # 备份数据库到文件 51 | > .quit # 退出 52 | ``` 53 | 54 |
55 | 56 |
57 | 58 | ### 表操作 59 | 60 | ```bash 61 | > .table # 查看所有的表 62 | users books 63 | > .schema users # 显示CREATE语句 64 | CREATE TABLE users(name text PRIMARY KEY, age integer); 65 | > .import FILE TABLE # 将文件的数据导入到表中。 66 | > .head ON # 查询时显示列名称 67 | > select * from users 68 | name|age 69 | Tom|18 70 | Jack|20 71 | ``` 72 |
73 | 74 |
75 | 76 | ### 输出模式 77 | 78 | ```bash 79 | > .mode csv # 设置输出模式为 csv 80 | > select * from users 81 | name,age 82 | Tom,18 83 | Jack,20 84 | > .mode insert # 设置输出模式为 insert 85 | > select * from users 86 | INSERT INTO "table"(name,age) VALUES('Tom',18); 87 | INSERT INTO "table"(name,age) VALUES('Jack',20); 88 | ``` 89 | 90 | .mode 支持 csv, column, html, insert, line, list, tabs, tcl 等 8 种模式。 91 | 92 |
93 | 94 | 95 | 96 | 97 | ## SQL 语句 98 | 99 |
100 | 101 | ### 创建表 102 | 103 | ```sql 104 | CREATE TABLE tab_name ( 105 | col1 col1_type PRIMARY KEY, 106 | col2 INTEGER AUTOINCREMENT,, 107 | col3 col3_type NOT NULL, 108 | ..... 109 | colN colN_type, 110 | ); 111 | 112 | /* 常用类型: 113 | TEXT 字符串, CHAR(100) 固长字符串 114 | INTEGER 整型, BIGINT 长整型, REAL 实数, 115 | BOOL 布尔值 116 | BLOB 二进制 117 | DATETIME 时间 118 | */ 119 | ``` 120 | 121 | `PRIMARY KEY` 标记主键,`NOT NULL`标记非空。`AUTOINCREMENT` 自增,只能用于整型。 122 |
123 | 124 |
125 | 126 | ### 删除/更新表 127 | 128 | ```sql 129 | -- 删除表 130 | DROP TABLE tab_name; 131 | -- 新增列 132 | ALTER TABLE ADD COLUMNS col_name col_type; 133 | 134 | -- 重命名表 135 | ALTER TABLE old_tab RENAME TO new_tab 136 | 137 | -- 重命名列名(3.25.0+) 138 | ALTER TABLE tab_name RENAME COLUMN old_col TO new_col 139 | ``` 140 |
141 | 142 |
143 | 144 | ### 新增记录 145 | 146 | ```sql 147 | -- 单条 148 | INSERT INTO tab_name VALUES (xx, xx) 149 | -- 指定列名 150 | INSERT INTO tab_name (col1, col3) VALUES (xx, xx) 151 | -- 多条 152 | INSERT INTO tab_name (col1, col2, col3) VALUES 153 | (xx, xx, xx), 154 | ... 155 | (xx, xx, xx); 156 | ``` 157 |
158 | 159 |
160 | 161 | ### 查询记录 162 | 163 | ```sql 164 | -- 所有列 165 | SELECT * FROM tab_name; 166 | -- 去除重复 167 | SELECT DISTINCT col1 FROM tab_name; 168 | -- 统计个数 169 | SELECT COUNT(*) FROM tab_name 170 | -- 指定列 171 | SELECT col1, col2 FROM table_name; 172 | -- 带查询条件 >、<、=、LIKE、NOT、AND、OR 等 173 | SELECT * FROM table_name WHERE col2 >= 18; 174 | SELECT * FROM table_name 175 | WHERE col2 >= 18 AND col1 LIKE %stu%; 176 | -- 限制数量 177 | SELECT * FROM table_name LIMIT 1; 178 | -- GROUP BY 179 | SELECT col1, count(*) FROM tab_name 180 | WHERE [ conditions ] 181 | GROUP BY col1 182 | -- Having 183 | SELECT col1, count(*) FROM tab_name 184 | WHERE [ conditions ] 185 | GROUP BY col1 186 | HAVING [ conditions ] 187 | -- 排序, DESC 降序,ASC 升序 188 | SELECT * FROM table_name ORDER BY col2 DESC; 189 | ``` 190 |
191 | 192 |
193 | 194 | ### 删除/更新记录 195 | 196 | ```sql 197 | -- 删除满足条件的记录 198 | DELETE FROM tab_name WHERE condition; 199 | -- 更新记录 200 | UPDATE tab_name SET col1=value1, col2=value2 201 | -- 更新满足条件的记录 202 | UPDATE tab_name 203 | SET col1=value1, col2=value2 204 | WHERE [ conditions ] 205 | ``` 206 |
207 | 208 |
209 | 210 | ### 事务(Transaction) 211 | 212 | ```sql 213 | -- 提交 214 | BEGIN; 215 | INSERT INTO ... 216 | ... 217 | COMMIT; 218 | -- 回滚 219 | BEGIN; 220 | ... 221 | ROLLBACK; 222 | ``` 223 | 224 | 事务具有原子性(Atomicity)、一致性(Consistency)、隔离性(Isolation)、持久性(Durability)四个标准属性,缩写为 `ACID`。 225 | 226 |
-------------------------------------------------------------------------------- /posts/cheat-sheet/cheat-sheet-sqlite/sqlite.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/cheat-sheet/cheat-sheet-sqlite/sqlite.jpg -------------------------------------------------------------------------------- /posts/cheat-sheet/cheat-sheet-sqlite/sqlite.jpg.bak: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/cheat-sheet/cheat-sheet-sqlite/sqlite.jpg.bak -------------------------------------------------------------------------------- /posts/data-mining/pandas-cheat-sheet-zh-cn.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Pandas 数据处理(三) - Cheat Sheet 中文版 3 | date: 2019-06-16 00:10:24 4 | description: Pandas_Cheat_Sheet_zh_CN.pdf 中文版,Geektutu翻译 5 | keywords: 6 | - Python 7 | - Pandas_Cheat_Sheet_zh_CN.pdf 8 | - Pandas Cheat Sheet 9 | tags: 10 | - Pandas 11 | categories: 12 | - Pandas 数据处理 13 | nav: 简明教程 14 | image: post/pandas-cheat-sheet-zh-cn/cheat_sheet_part.png 15 | --- 16 | 17 | ## 来源 18 | 19 | Pandas的[英文文档](https://pandas.pydata.org/pandas-docs/stable/)内容过于庞杂,而大部分时候,我们仅仅需要其中一些较为常见的用法,查看官方文档耗时且费力,这两张图较为全面地概括了pandas的常用方法。 20 | 21 | 包括: 22 | 23 | - 链式调用(Method Chaining) 24 | - 数据重塑(Reshaping Data) 25 | - 筛选数据(Subsets) 26 | - 数据统计(Summarize Data) 27 | - 数据分组(Group Data) 28 | - 合并数据(Combine Data) 29 | - 绘图(Pilot) 30 | - ... 31 | 32 | 参考[Pandas_Cheat_Sheet.pdf](https://pandas.pydata.org/Pandas_Cheat_Sheet.pdf),[极客兔兔](https://github.com/geektutu)于2019年6月16日翻译了中文版,如有错误,可以直接在评论区评论。 33 | 34 | ## 效果图 35 | 36 | ![Pandas_Cheat_Sheet_zh_CN_1](pandas-cheat-sheet-zh-cn/1.webp) 37 | ![Pandas_Cheat_Sheet_zh_CN_2](pandas-cheat-sheet-zh-cn/2.webp) 38 | 39 | ## PDF下载 40 | 41 | 点击查看/下载pdf版本 [Pandas_Cheat_Sheet_zh_CN.pdf](pandas-cheat-sheet-zh-cn/Pandas_Cheat_Sheet_zh_CN.pdf) 42 | 43 | 也欢迎把本文分享给对pandas感兴趣的小伙伴。 44 | 45 | ## 附 推荐 46 | 47 | - [一篇文章入门 Python](https://geektutu.com/post/quick-python.html) -------------------------------------------------------------------------------- /posts/data-mining/pandas-cheat-sheet-zh-cn/1.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/data-mining/pandas-cheat-sheet-zh-cn/1.webp -------------------------------------------------------------------------------- /posts/data-mining/pandas-cheat-sheet-zh-cn/2.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/data-mining/pandas-cheat-sheet-zh-cn/2.webp -------------------------------------------------------------------------------- /posts/data-mining/pandas-cheat-sheet-zh-cn/Pandas_Cheat_Sheet_zh_CN.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/data-mining/pandas-cheat-sheet-zh-cn/Pandas_Cheat_Sheet_zh_CN.pdf -------------------------------------------------------------------------------- /posts/data-mining/pandas-cheat-sheet-zh-cn/Pandas_Cheat_Sheet_zh_CN.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/data-mining/pandas-cheat-sheet-zh-cn/Pandas_Cheat_Sheet_zh_CN.pptx -------------------------------------------------------------------------------- /posts/data-mining/pandas-cheat-sheet-zh-cn/cheat_sheet_part.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/data-mining/pandas-cheat-sheet-zh-cn/cheat_sheet_part.png -------------------------------------------------------------------------------- /posts/data-mining/pandas-dataframe-series.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Pandas 数据处理(一) - DataFrame 与 Series 3 | date: 2018-03-20 10:51:24 4 | description: DataFrame和Series是pandas中最常见的2种数据结构。DataFrame可以理解为Excel中的一张表,Series可以理解为一张Excel表的一行或一列数据。 5 | tags: 6 | - Pandas 7 | keywords: 8 | - Python 9 | categories: 10 | - Pandas 数据处理 11 | nav: 简明教程 12 | image: post/pandas-dataframe-series/pandas.gif 13 | --- 14 | 15 | 16 | 17 | 本文示例基于 Version 0.21.0 18 | 19 | DataFrame和Series是pandas中最常见的2种数据结构。DataFrame可以理解为Excel中的一张表,Series可以理解为一张Excel表的一行或一列数据。 20 | 21 | ## 一、Series 22 | 23 | Series可以理解为一维数组,它和一维数组的区别,在于Series具有索引。 24 | 25 | ### 1. 创建Series 26 | 27 | - 默认索引 28 | 29 | ```python 30 | money_series = pd.Series([200, 300, 10, 5], name="money") # 未设置索引的情况下,自动从0开始生成 31 | """ 32 | 0 200 33 | 1 300 34 | 2 10 35 | 3 5 36 | Name: money, dtype: int64 37 | """ 38 | money_series[0] # 根据索引获取具体的值, 39 | # 200 40 | money_series = money_series.sort_values() # 根据值进行排序,排序后索引与值对应关系不变 41 | """ 42 | 3 5 43 | 2 10 44 | 0 200 45 | 1 300 46 | Name: money, dtype: int64 47 | """ 48 | money_series[0] # 根据索引获取具体的值,0对应的依旧是200,等价于 money_series.loc[0] 49 | # 200 50 | money_series.iloc[0] # 根据序号获取具体的值 51 | # 5 52 | ``` 53 | 54 | - 自定义索引 55 | 56 | ```python 57 | money_series = pd.Series([200, 300, 10, 5], index=['d', 'c', 'b', 'a'], name='money') 58 | """ 59 | d 200 60 | c 300 61 | b 10 62 | a 5 63 | Name: money, dtype: int64 64 | """ 65 | money_series.index # 查看索引 66 | # Index(['d', 'c', 'b', 'a'], dtype='object') 67 | money_series['a'] # 根据索引获取具体的值 68 | # 5 69 | money_series = money_series.sort_index() # 根据索引排序 70 | """ 71 | a 5 72 | b 10 73 | c 300 74 | d 200 75 | Name: money, dtype: int64 76 | """ 77 | money_series.iloc[-1] # 取最后一个值 78 | # 200 79 | ``` 80 | 81 | ### 2. 切片与取值 82 | 83 | - 根据索引 84 | 85 | ```python 86 | money_series = pd.Series({'d': 200, 'c': 300, 'b': 10, 'a': 5}, name='money') 87 | """ 88 | a 200 89 | b 300 90 | c 10 91 | d 5 92 | Name: money, dtype: int64 93 | """ 94 | money_series.loc['a'] # 等价于 money_series['a'] 95 | # 200 96 | money_series.loc['c':'a':-1] # 从c取到 a,倒序 97 | """ 98 | c 10 99 | b 300 100 | a 200 101 | Name: money, dtype: int64 102 | """ 103 | money_series.loc[['d', 'a']] # d, a的值,等价于 money_series[['d', 'a']] 104 | """ 105 | d 5 106 | a 200 107 | """ 108 | ``` 109 | 110 | - 根据序号 111 | 112 | ```python 113 | money_series.iloc[0] 114 | # 200 115 | money_series.iloc[1:3] # 根据序号取值,不包含结束,等价于 money_series[1:3] 116 | """ 117 | b 300 118 | c 10 119 | Name: money, dtype: int64 120 | """ 121 | money_series.iloc[[3, 0]] # 取第三个值和第一个值 122 | """ 123 | d 5 124 | a 200 125 | Name: money, dtype: int64 126 | """ 127 | ``` 128 | 129 | - 根据条件 130 | 131 | ```python 132 | money_series[money_series > 50] # 选取大于50的值 133 | """ 134 | c 300 135 | d 200 136 | Name: money, dtype: int64 137 | """ 138 | money_series[lambda x: x ** 2 > 50] # 选取值平方大于50的值 139 | """ 140 | b 10 141 | c 300 142 | d 200 143 | Name: money, dtype: int64 144 | """ 145 | ``` 146 | 147 | ## 二、DataFrame 148 | 149 | ### 1. 创建DataFrame 150 | 151 | - 从字典中创建 152 | 153 | ```python 154 | # 字典值等长 155 | # 不指定 index 156 | df = pd.DataFrame({'单价': [100, 200, 30], '数量': [3, 3, 10]}) 157 | """ 158 | 单价 数量 159 | 0 100 3 160 | 1 200 3 161 | 2 30 10 162 | """ 163 | 164 | # 指定 index 165 | df = pd.DataFrame({'单价': [100, 200, 30], '数量': [3, 3, 10]}, index=['T001', 'T002', 'T003']) 166 | """ 167 | 单价 数量 168 | T001 100 3 169 | T002 200 3 170 | T003 30 10 171 | """ 172 | ``` 173 | 174 | - 通过Series创建 175 | 176 | ```python 177 | price_series = pd.Series([100, 200, 30], index=['T001', 'T002', 'T005']) 178 | quantity_series = pd.Series([3, 3, 10, 2], index=['T001', 'T002', 'T003', 'T004']) 179 | df = pd.DataFrame({'单价': price_series, '数量': quantity_series}) 180 | # 数据中不含有对应元素,则置为NaN 181 | """ 182 | 单价 数量 183 | T001 100.0 3.0 184 | T002 200.0 3.0 185 | T003 NaN 10.0 186 | T004 NaN 2.0 187 | T005 30.0 NaN 188 | """ 189 | ``` 190 | 191 | - 从Excel文件中读取,demo.dat 192 | 193 | ```python 194 | df = pd.read_excel("path/demo.xlsx", sheetname=0) 195 | # 指定 sheetname 196 | df = pd.read_excel("path/demo.xlsx", sheetname='销售记录') 197 | ``` 198 | 199 | - 从普通文本中读取 200 | 201 | ``` 202 | 编号|日期|单价|数量 203 | T001|2018-03-02 12:34:05|100|3 204 | T002|2018-03-02 13:04:05|200|3 205 | T003|2018-03-03 18:12:31|30|10 206 | T004|2018-03-04 20:34:05|400|2 207 | T005|2018-03-02 20:34:05|500|1 208 | ``` 209 | 210 | ```python 211 | df = pd.read_csv('demo.dat', delimiter='|') # csv默认是逗号分隔的,如果不是,需要指定delimiter 212 | """ 213 | 编号 日期 单价 数量 214 | 0 T001 2018-03-02 12:34:05 100 3 215 | 1 T002 2018-03-02 13:04:05 200 3 216 | 2 T003 2018-03-03 18:12:31 30 10 217 | 3 T004 2018-03-04 20:34:05 400 2 218 | 4 T005 2018-03-02 20:34:05 500 1 219 | """ 220 | 221 | df = pd.read_csv('demo.dat', delimiter='|', index_col='编号') # index_col指定行标签为索引 222 | """ 223 | 日期 单价 数量 224 | 编号 225 | T001 2018-03-02 12:34:05 100 3 226 | T002 2018-03-02 13:04:05 200 3 227 | T003 2018-03-03 18:12:31 30 10 228 | T004 2018-03-04 20:34:05 400 2 229 | T005 2018-03-02 20:34:05 500 1 230 | """ 231 | ``` 232 | 233 | ### 2. 获取列与行 234 | 235 | ```python 236 | df['日期'] # -> 返回Series 237 | """ 238 | 0 2018-03-02 12:34:05 239 | 1 2018-03-02 13:04:05 240 | 2 2018-03-03 18:12:31 241 | 3 2018-03-04 20:34:05 242 | 4 2018-03-02 20:34:05 243 | Name: 日期, dtype: object 244 | """ 245 | df[['单价', '数量']] # -> 返回Series 246 | """ 247 | 单价 数量 248 | 0 100 3 249 | 1 200 3 250 | 2 30 10 251 | 3 400 2 252 | 4 500 1 253 | """ 254 | df.loc['T001'] # 按行标签获取,返回Series 255 | df.iloc[0] # 按行号获取,返回Series 256 | """ 257 | 日期 2018-03-02 12:34:05 258 | 单价 100 259 | 数量 3 260 | Name: T001, dtype: object 261 | """ 262 | df.head(3) # 前三行 263 | df.tail(3) # 后三行 264 | """ 265 | 日期 单价 数量 266 | 编号 267 | T003 2018-03-03 18:12:31 30 10 268 | T004 2018-03-04 20:34:05 400 2 269 | T005 2018-03-02 20:34:05 500 1 270 | """ 271 | ``` 272 | 273 | ### 3. 修改 274 | 275 | - 单价 * 2 276 | 277 | ```python 278 | df['单价'] *= 2 279 | # apply支持传入修改函数,能处理更复杂的场景 280 | # 等价于, df['单价'] = df.apply(lambda x: x['单价'] * 2, axis=1) 281 | 282 | """ 283 | 日期 单价 数量 284 | 编号 285 | T001 2018-03-02 12:34:05 200 3 286 | T002 2018-03-02 13:04:05 400 3 287 | T003 2018-03-03 18:12:31 60 10 288 | T004 2018-03-04 20:34:05 800 2 289 | T005 2018-03-02 20:34:05 1000 1 290 | """ 291 | ``` 292 | 293 | - 编号加上前缀 294 | 295 | ```python 296 | # 由于编号是索引,所以需要用 df.index去访问 297 | df.index = '2018_' + df.index 298 | """ 299 | 日期 单价 数量 300 | 2018_T001 2018-03-02 12:34:05 200 3 301 | 2018_T002 2018-03-02 13:04:05 400 3 302 | 2018_T003 2018-03-03 18:12:31 60 10 303 | 2018_T004 2018-03-04 20:34:05 800 2 304 | 2018_T005 2018-03-02 20:34:05 1000 1 305 | """ 306 | ``` 307 | 308 | - 数量小于3的记录,单价 + 10 309 | 310 | ```python 311 | def change_price(x): 312 | if x['数量'] < 3: 313 | return x['单价'] + 10 314 | return x['单价'] 315 | 316 | 317 | df['单价'] = df.apply(change_price, axis=1) 318 | """ 319 | 日期 单价 数量 320 | 2018_T001 2018-03-02 12:34:05 200 3 321 | 2018_T002 2018-03-02 13:04:05 400 3 322 | 2018_T003 2018-03-03 18:12:31 60 10 323 | 2018_T004 2018-03-04 20:34:05 810 2 324 | 2018_T005 2018-03-02 20:34:05 1010 1 325 | """ 326 | ``` 327 | 328 | - 增加物流公司 329 | 330 | ```python 331 | df['运费'] = pd.Series({'2018_T001': 10, '2018_T005': 12}) 332 | """ 333 | 日期 单价 数量 运费 334 | 2018_T001 2018-03-02 12:34:05 200 3 10.0 335 | 2018_T002 2018-03-02 13:04:05 400 3 NaN 336 | 2018_T003 2018-03-03 18:12:31 60 10 NaN 337 | 2018_T004 2018-03-04 20:34:05 810 2 NaN 338 | 2018_T005 2018-03-02 20:34:05 1010 1 12.0 339 | """ 340 | # 缺少信息的部分填充为0 341 | df.fillna(0) 342 | """ 343 | 日期 单价 数量 运费 344 | 2018_T001 2018-03-02 12:34:05 200 3 10.0 345 | 2018_T002 2018-03-02 13:04:05 400 3 0.0 346 | 2018_T003 2018-03-03 18:12:31 60 10 0.0 347 | 2018_T004 2018-03-04 20:34:05 810 2 0.0 348 | 2018_T005 2018-03-02 20:34:05 1010 1 12.0 349 | """ 350 | ``` 351 | 352 | ### 4. 删除 353 | 354 | - 删除日期列(就地删除) 355 | 356 | ```python 357 | del df['日期'] 358 | """ 359 | 单价 数量 运费 360 | 2018_T001 200 3 10.0 361 | 2018_T002 400 3 NaN 362 | 2018_T003 60 10 NaN 363 | 2018_T004 810 2 NaN 364 | 2018_T005 1010 1 12.0 365 | """ 366 | ``` 367 | 368 | - 删除运费列(返回筛选后的) 369 | 370 | ```python 371 | new_columns = list(df.columns) 372 | new_columns.remove('运费') 373 | df = df[new_columns] 374 | """ 375 | 单价 数量 376 | 2018_T001 200 3 377 | 2018_T002 400 3 378 | 2018_T003 60 10 379 | 2018_T004 810 2 380 | 2018_T005 1010 1 381 | """ 382 | ``` 383 | 384 | ## 附 推荐 385 | 386 | - [一篇文章入门 Python](https://geektutu.com/post/quick-python.html) -------------------------------------------------------------------------------- /posts/data-mining/pandas-dataframe-series/pandas.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/data-mining/pandas-dataframe-series/pandas.gif -------------------------------------------------------------------------------- /posts/data-mining/pandas-select-data.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Pandas 数据处理(二) - 筛选数据 3 | date: 2018-03-21 10:51:24 4 | description: Pandas筛选DataFrame和Series数据的一些实用技巧。 5 | tags: 6 | - Pandas 7 | keywords: 8 | - Python 9 | categories: 10 | - Pandas 数据处理 11 | nav: 简明教程 12 | image: post/pandas-select-data/pandas.gif 13 | --- 14 | 15 | 16 | 使用`demo.csv`举几个栗子~ 17 | 18 | ``` 19 | 编号,日期,单价,数量 20 | T001,2018-03-02 12:34:05,100,3 21 | T002,2018-03-02 13:04:05,200,3 22 | T003,2018-03-03 18:12:31,30,10 23 | T004,2018-03-04 20:34:05,400,2 24 | T005,2018-03-02 20:34:05,500,1 25 | ``` 26 | 27 | ```python 28 | import pandas as pd 29 | 30 | df = pd.read_csv('demo.csv', index_col='编号') # 指定行标签 label 31 | """ 32 | 日期 单价 33 | 编号 34 | T001 2018-03-02 12:34:05 100 35 | T002 2018-03-02 13:04:05 200 36 | T003 2018-03-03 18:12:31 30 37 | T004 2018-03-04 20:34:05 400 38 | T005 2018-03-02 20:34:05 500 39 | """ 40 | ``` 41 | 42 | > 如果不指定 index_col,则会默认生成 0 - 5的行标签,这种情况下,行标签与行号相等。 43 | > 44 | > 行标签可以理解为这一行的区别于其他行的标志(类似SQL中的主键) 45 | 46 | ```python 47 | df_without_index_col = pd.read_csv('demo.csv') 48 | """ 49 | 编号 日期 单价 数量 50 | 0 T001 2018-03-02 12:34:05 100 3 51 | 1 T002 2018-03-02 13:04:05 200 3 52 | 2 T003 2018-03-03 18:12:31 30 10 53 | 3 T004 2018-03-04 20:34:05 400 2 54 | 4 T005 2018-03-02 20:34:05 500 1 55 | """ 56 | ``` 57 | 58 | ## 一、选取列 59 | 60 | ### 1. 使用方括号 61 | 62 | ```python 63 | df[['日期', '单价']] 64 | """ 65 | 日期 单价 66 | 编号 67 | T001 2018-03-02 12:34:05 100 68 | T002 2018-03-02 13:04:05 200 69 | T003 2018-03-03 18:12:31 30 70 | T004 2018-03-04 20:34:05 400 71 | T005 2018-03-02 20:34:05 500 72 | """ 73 | ``` 74 | 75 | ```python 76 | df[['单价']] # -> DataFrame 77 | """ 78 | 单价 79 | 编号 80 | T001 100 81 | T002 200 82 | T003 30 83 | T004 400 84 | T005 500 85 | """ 86 | type(df[['单价']]) # 87 | ``` 88 | 89 | ```python 90 | df['单价'] # -> Series 91 | """ 92 | 编号 93 | T001 100 94 | T002 200 95 | T003 30 96 | T004 400 97 | T005 500 98 | Name: 单价, dtype: int64 99 | """ 100 | type(df['单价']) # 101 | ``` 102 | 103 | ## 二、选取行 104 | 105 | ### 1. 行标签使用loc 106 | 107 | 可接受2个参数,第一个参数是行标签,第二个参数是列标签 108 | 109 | - 行标签为T002的行 110 | 111 | ```python 112 | df.loc['T002'] # -> Series 113 | """ 114 | 日期 2018-03-02 12:34:05 115 | 单价 100 116 | 数量 3 117 | Name: T001, dtype: object 118 | """ 119 | ``` 120 | 121 | - 行标签从T002到 T004的行(与切片不同,这种情况下包含开头也包含结束) 122 | 123 | ```python 124 | df.loc['T002':'T004'] # -> DataFrame 125 | """ 126 | 日期 单价 数量 127 | 编号 128 | T002 2018-03-02 13:04:05 200 3 129 | T003 2018-03-03 18:12:31 30 10 130 | T004 2018-03-04 20:34:05 400 2 131 | """ 132 | ``` 133 | 134 | - 行标签从T002到最后的行,且只选取单价和数量2列 135 | 136 | ```python 137 | df.loc['T002':, ['单价', '数量']] # -> DataFrame 138 | # 等价于 df.loc['T002':][['单价', '数量']] 139 | """ 140 | 单价 数量 141 | 编号 142 | T002 200 3 143 | T003 30 10 144 | T004 400 2 145 | T005 500 1 146 | """ 147 | ``` 148 | 149 | ### 2. 行号使用iloc 150 | 151 | - 第1行到倒数第2行 152 | 153 | iloc参数是一个slice对象,和python中可迭代对象用法是一致的。[start, end, step)。 154 | 155 | 下标从0开始,包含开头,不包含结束。 156 | 157 | ```python 158 | df.iloc[1:-2] # -> DataFrame,不包含 -2 行 159 | """ 160 | 日期 单价 数量 161 | 编号 162 | T002 2018-03-02 13:04:05 200 3 163 | T003 2018-03-03 18:12:31 30 10 164 | """ 165 | ``` 166 | 167 | - 第1行到最后,隔行取,即行号为单数的行,1,3,5,7,9.... 168 | 169 | ```python 170 | df.iloc[1::2] # -> DataFrame,到最后,切片的第二个参数可省略不写 171 | 172 | """ 173 | 日期 单价 数量 174 | 编号 175 | T002 2018-03-02 13:04:05 200 3 176 | T004 2018-03-04 20:34:05 400 2 177 | """ 178 | ``` 179 | 180 | - 第0行到第3行,第1列到第3列 181 | 182 | ```python 183 | df.iloc[:3, 1:3] # -> DataFrame, 等价于 df.iloc[0:3, 1:3],0可以省略不写 184 | """ 185 | 单价 数量 186 | 编号 187 | T001 100 3 188 | T002 200 3 189 | T003 30 10 190 | """ 191 | ``` 192 | 193 | - 若没有指定标签列(index_col),则loc与iloc表现相近 194 | 195 | ```python 196 | df_without_index_col.loc[1:3] # 行标签从1到3的行,行标签是包含结束的。 197 | """ 198 | 编号 日期 单价 数量 199 | 1 T002 2018-03-02 13:04:05 200 3 200 | 2 T003 2018-03-03 18:12:31 30 10 201 | 3 T004 2018-03-04 20:34:05 400 2 202 | """ 203 | df_without_index_col.iloc[1:3] # 切片是不包含结束的。 204 | """ 205 | 编号 日期 单价 数量 206 | 1 T002 2018-03-02 13:04:05 200 3 207 | 2 T003 2018-03-03 18:12:31 30 10 208 | """ 209 | ``` 210 | 211 | ### 3. ix兼容处理loc与iloc(deprecated) 212 | 213 | loc是根据行标签定位的,iloc是根据行号定位的。 214 | 215 | ix的处理逻辑是,通常将传入的参数优先视为行标签,若找不到该标签的情况下,再当做行号取索引。 216 | 217 | ```python 218 | df.ix['T002':'T004'] # 匹配到了行标签,所以和loc表现一致 219 | """ 220 | 日期 单价 数量 221 | 编号 222 | T002 2018-03-02 13:04:05 200 3 223 | T003 2018-03-03 18:12:31 30 10 224 | T004 2018-03-04 20:34:05 400 2 225 | """ 226 | df.ix[1:3] # 没有匹配到行标签,所以和iloc一致 227 | """ 228 | 日期 单价 数量 229 | 编号 230 | T002 2018-03-02 13:04:05 200 3 231 | T003 2018-03-03 18:12:31 30 10 232 | """ 233 | df_without_index_col.ix[1:3] # 匹配到了行标签 1,2,3,表现与iloc一致 234 | """ 235 | 编号 日期 单价 数量 236 | 1 T002 2018-03-02 13:04:05 200 3 237 | 2 T003 2018-03-03 18:12:31 30 10 238 | 3 T004 2018-03-04 20:34:05 400 2 239 | """ 240 | ``` 241 | 242 | > ix方法虽然兼容处理了2种情况,但是建议不要使用,尽量使用loc和iloc明确索引方式。 243 | > 244 | > 更多的关于ix的讨论可以参考 [pandas iloc vs ix vs loc explanation, how are they different?](https://stackoverflow.com/questions/31593201/pandas-iloc-vs-ix-vs-loc-explanation-how-are-they-different) 245 | > 246 | > 当前ix方法已经处于 `deprecated` 状态。 247 | 248 | ## 三、简单条件 249 | 250 | ### 1. 简单的逻辑判断(<, >, ==, &, |, ~ 等) 251 | 252 | - 单价不小于200的记录 253 | 254 | ```python 255 | df[df['单价'] >= 200] 256 | # 等价于 df[~(df['单价'] < 200)] 257 | """ 258 | 日期 单价 数量 259 | 编号 260 | T002 2018-03-02 13:04:05 200 3 261 | T004 2018-03-04 20:34:05 400 2 262 | T005 2018-03-02 20:34:05 500 1 263 | """ 264 | ``` 265 | 266 | - 单价大于等于200且数量大于1的记录(&表示与,|表示或,~表示非) 267 | 268 | ```python 269 | df[(df['单价'] >= 200) & (df['数量'] >= 2)] 270 | """ 271 | 日期 单价 数量 272 | 编号 273 | T002 2018-03-02 13:04:05 200 3 274 | T004 2018-03-04 20:34:05 400 2 275 | """ 276 | ``` 277 | 278 | ## 四、自定义函数 279 | 280 | ### 1. loc 281 | 282 | loc函数支持传入自定义的函数进行筛选 283 | 284 | - 筛选总价大于500的记录 285 | 286 | ```python 287 | df.loc[lambda x: x['单价'] * x['数量'] > 500] # 函数入参 x 是整个 DataFrame 288 | """ 289 | 日期 单价 数量 290 | 编号 291 | T002 2018-03-02 13:04:05 200 3 292 | T004 2018-03-04 20:34:05 400 2 293 | """ 294 | # 等价于 295 | def filter_func(x): 296 | print(type(x)) # 297 | return x['单价'] * x['数量'] > 500 298 | df.loc[filter_func] 299 | ``` 300 | 301 | - 筛选3月2号的记录 302 | 303 | ```python 304 | # x是整个DataFrame,x['日期']是包含所有记录的Series 305 | df.loc[lambda x: x['日期'].str.startswith('2018-03-02')] 306 | """ 307 | 日期 单价 数量 308 | 编号 309 | T001 2018-03-02 12:34:05 100 3 310 | T002 2018-03-02 13:04:05 200 3 311 | T005 2018-03-02 20:34:05 500 1 312 | """ 313 | ``` 314 | 315 | > Series.str(),可以使用Python自带的string相关的方法,构造筛选条件。 316 | > 317 | > 查看官方文档[pandas.Series.str](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.Series.str.html) 318 | 319 | ### 2. apply 320 | 321 | - 筛选总价大于500的记录 322 | 323 | ```python 324 | df[df.apply(lambda x: x['单价'] * x['数量'] > 500, axis=1)] # 函数入参 x 是一行数据 Series 325 | """ 326 | 日期 单价 数量 327 | 编号 328 | T002 2018-03-02 13:04:05 200 3 329 | T004 2018-03-04 20:34:05 400 2 330 | """ 331 | # 等价于 332 | def filter_func(x): 333 | print(type(x)) # 334 | return x['单价'] * x['数量'] > 500 335 | df[df.apply(filter_func, axis=1)] 336 | ``` 337 | 338 | > axis默认为0,表示遍历列,axis=1,表示遍历行。 339 | > 340 | > 更多参数,查看官方文档[pandas.DataFrame.apply](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.apply.html) 341 | 342 | - 筛选3月2号的记录 343 | 344 | ```python 345 | df['日期'] = pd.to_datetime(df['日期']) 346 | # x表示一行记录(Series),x['日期']已经是一个具体的日期对象 347 | df[df.apply(lambda x: x['日期'].date() == pd.to_datetime('2018/03/02').date(), axis=1)] 348 | """ 349 | 日期 单价 数量 350 | 编号 351 | T001 2018-03-02 12:34:05 100 3 352 | T002 2018-03-02 13:04:05 200 3 353 | T005 2018-03-02 20:34:05 500 1 354 | """ 355 | ``` 356 | 357 | - 假如数据中单价包含非数字,需要删除 358 | 359 | ``` 360 | 编号,日期,单价,数量 361 | T001,2018-03-02 12:34:05,100,3 362 | T002,2018-03-02 13:04:05,200,3 363 | T003,2018-03-03 18:12:31,30,10 364 | T004,2018-03-04 20:34:05,400,2 365 | T005,2018-03-02 20:34:05,500,1 366 | T006,2018-03-02 20:34:05,$#500,1 367 | ``` 368 | 369 | ```python 370 | def filter_func(x): 371 | try: 372 | float(x['单价']) # x是为某一行的数据,x['单价']可以取到该行对应的值 373 | return True 374 | except: 375 | return False 376 | 377 | 378 | df = pd.read_csv('demo.csv', index_col='编号') 379 | print(df[df.apply(filter_func, axis=1)]) 380 | """ 381 | 日期 单价 数量 382 | 编号 383 | T001 2018-03-02 12:34:05 100 3 384 | T002 2018-03-02 13:04:05 200 3 385 | T003 2018-03-03 18:12:31 30 10 386 | T004 2018-03-04 20:34:05 400 2 387 | T005 2018-03-02 20:34:05 500.0 1 388 | """ 389 | ``` 390 | 391 | ## 附 推荐 392 | 393 | - [一篇文章入门 Python](https://geektutu.com/post/quick-python.html) -------------------------------------------------------------------------------- /posts/data-mining/pandas-select-data/pandas.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/data-mining/pandas-select-data/pandas.gif -------------------------------------------------------------------------------- /posts/pandora-box/awesome-config.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: WSL, Git, Mircosoft Terminal 等常用工具配置 3 | date: 2019-12-03 00:30:00 4 | description: 记录开发过程中一些工具的常用配置,加速在新环境上的准备效率。工具包括不限于微软 Linux 子系统 Ubuntu (WSL, WSL2), Git, 微软最新发布的命令行神器( Mircosoft Terminal )等。 5 | tags: 6 | - 百宝箱 7 | nav: 百宝箱 8 | categories: 9 | - 工具 10 | keywords: 11 | - WSL 12 | - Ubuntu 13 | - Mircosoft Terminal 14 | image: post/awesome-config/wsl.jpg 15 | --- 16 | 17 | ## Git 18 | 19 | ### 用户与鉴权 20 | 21 | ```bash 22 | # 生成SSH公钥对 23 | ssh-keygen -t rsa -b 4096 -C "your_email@example.com" 24 | 25 | # 配置用户名和邮箱 26 | git config --global user.name 27 | git config --global user.email 28 | 29 | # 针对 https 协议的仓库,记住密码,避免每次都要求输入密码 30 | git config --global credential.helper store 31 | 32 | # 自动拉取 submodule 33 | git config --global submodule.recurse true 34 | ``` 35 | 36 | ### alias 提高效率 37 | 38 | ```bash 39 | git config --global alias.co checkout 40 | git config --global alias.ci commit 41 | git config --global alias.st status 42 | git config --global alias.br branch 43 | git config --global alias.logg "log --graph --decorate --all" 44 | ``` 45 | 46 | 配置了 alias,就可以简化相应的 Git 命令,例如 `git status` 可以简化为 `git st` 47 | 48 | Git 的 `git log` 并不能显示其他的分支,以及分支之间的树形关系,所以额外添加很多的参数,因此,适合用 `git logg` 这么一个别名来代替。 49 | 50 | 51 |
52 | 对比一下`git log` 和 `git logg` 的差异。 53 |
54 | 55 | 56 | `git log` 不能够显示分支之间的树形关系,`git logg`可以。 57 | 58 | ```bash 59 | commit 68b7f2f13b73cfdaeadc022eb02181714449186c (HEAD -> master, origin/master, origin/HEAD) 60 | Author: geektutu 61 | Date: Mon Nov 25 00:08:01 2019 +0800 62 | 63 | fix title 64 | 65 | commit b65de90b15ef78c12b2ac9346520c873504b361c 66 | Author: geektutu 67 | Date: Mon Nov 25 00:01:13 2019 +0800 68 | 69 | add quick rust 70 | ``` 71 | 72 | ```bash 73 | $ git logg 74 | * commit 68b7f2f (HEAD -> master, origin/master, origin/HEAD) 75 | | Author: geektutu 76 | | Date: Mon Nov 25 00:08:01 2019 +0800 77 | | 78 | | fix title 79 | | 80 | | * commit 01dfa04 (origin/dependabot/npm_and_yarn/lodash-4.17.15) 81 | |/ Author: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> 82 | | Date: Sun Nov 24 16:01:47 2019 +000 83 | ``` 84 | 85 | 结合 `--oneline` 参数很方便地浏览提交记录。 86 | 87 | ```bash 88 | $ git logg --oneline 89 | * 68b7f2f (HEAD -> master, origin/master, origin/HEAD) fix title 90 | | * 01dfa04 (origin/dependabot/npm_and_yarn/lodash-4.17.15) Bump lodash from 4.17.11 to 4.17.15 91 | |/ 92 | | * 387e6da (origin/dependabot/npm_and_yarn/lodash.merge-4.6.2) Bump lodash.merge from 4.6.1 to 4.6.2 93 | |/ 94 | | * 32a2929 (origin/dependabot/npm_and_yarn/mixin-deep-1.3.2) Bump mixin-deep from 1.3.1 to 1.3.2 95 | |/ 96 | * b65de90 add quick rust 97 | * 228f94a update comments.js 98 | ... 99 | ``` 100 | 101 |
102 |
103 | 104 | ## WSL (Ubuntu) 105 | 106 | Windows 10 下内置了微软Linux子系统(Windows Subsystem for Linux, WSL),对于使用 Windows 作为主力开发的童鞋们,生产力可以得到极大的解放。 107 | 108 | ### 安装 109 | 110 | - 第一步,在 PowerShell (管理员权限) 中以命令行方式开启 Linux 特性,并**重启**。 111 | 112 | ```bash 113 | Enable-WindowsOptionalFeature -Online -FeatureName Microsoft-Windows-Subsystem-Linux 114 | ``` 115 | 116 | - 第二步,在微软应用商店(Microsoft Store)搜索 Ubuntu 并安装。 117 | 118 | > 命令行方式或离线安装请参考:[WSL Install - Microsoft](https://docs.microsoft.com/en-us/windows/wsl/install-win10) 119 | 120 | ### 配置 121 | 122 | - 权限问题 123 | 124 | 第一个问题,磁盘挂载到Linux下时,所有的权限都被变为了 777,如果需要保持 Linux 下的习惯(文件夹 755,文件644),则需要额外配置 `/etc/wsl.conf`。 125 | 126 | `vim /etc/wsl.conf`,添加如下配置,这样在 Windows 中新建文件夹和文件,将以 755/644 的权限创建: 127 | 128 | ```conf 129 | [automount] 130 | options = "metadata,umask=22,fmask=11" 131 | ``` 132 | 133 | 第二个问题,WSL 中创建文件夹和文件,仍旧是 777 的权限,因为 WSL 的 umask 默认值为 `0000`,需要修改为`0022`: 134 | 135 | `vim ~/.profile`,添加 136 | 137 | ```bash 138 | umask 0022 139 | ``` 140 | 141 | 修改完上述两个文件,并不会即时生效,需要执行以下命令关闭 WSL 服务,再重新打开。 142 | 143 | 在 *cmd* 中执行 `wsl -t Ubuntu` 或在 *PowerShell* 执行 `Restart-Service LxssManager` 144 | 145 | `wsl -t `,中的 DistributionName 可以通过 `wsl -l` 查询到 146 | 147 | ```bash 148 | C:\Users\admin>wsl -l 149 | 适用于 Linux 的 Windows 子系统: 150 | Ubuntu (默认) 151 | ``` 152 | 153 | 附:如果在 Visual Studio Code (VS Code) 中 使用了 Remote - WSL 打开 WSL 中的文件夹,会发现 `umask 0022` 仍没有生效,在 VS Code 中创建的文件夹权限仍旧是 777,则需要将`umask 0022`添加到`~/.vscode-server/server-env-setup`(不存在则新建)中。VS Code 连接 WSL 时会执行该脚本,可以用以下命令快速添加,重新启动 VS Code 即可生效。 154 | 155 | ```bash 156 | echo "umask 0022" | tee -a ~/.vscode-server/server-env-setup 157 | ``` 158 | 159 | > 参考:[WSL Config - Microsoft](https://docs.microsoft.com/en-us/windows/wsl/wsl-config) 160 | > 参考: [Updates to wsl.conf no longer immediate - Github](https://github.com/microsoft/WSL/issues/3994) 161 | 162 | ## Mircosoft Terminal 163 | 164 | 微软新开发的命令行程序,可以算是良心之作了,同样可以在 Microsoft Store 中搜索安装。支持多页签切换,支持选择不同的 Shell,结合 WSL 使用,显示效果也非常棒。 165 | 166 | `Ctrl + ,` 可以打开配置文件(json 格式),也可以点击下拉框中的 `Settings`。 167 | 168 | - 第一步,将 defaultProfile 的值修改为 WSL 的 guid,这样默认打开就是 WSL 的 Shell 了。 169 | 170 | - 第二步,为了让 Terminal 的快捷键和 Linux 更接近,还可以设置快捷键。 171 | 172 | ```json 173 | "keybindings": [ 174 | { 175 | "command": "closeTab", 176 | "keys": ["ctrl+w"] 177 | }, 178 | { 179 | "command": "newTab", 180 | "keys": ["ctrl+t"] 181 | }, 182 | { 183 | "command": "paste", 184 | "keys": ["shift+insert"] 185 | } 186 | ], 187 | "copyOnSelect": true // 选中即复制 188 | ``` 189 | 190 | 默认的快捷键: 191 | 192 | ```bash 193 | ctrl + tab # 切换标签页 194 | ctrl + shift + c # 复制 195 | ctrl + shift + v # 粘贴 196 | ctrl + shift + 1/2/3 # 打开 powershell/cmd/WSL 197 | ``` 198 | 199 | 新增快捷键: 200 | 201 | ```bash 202 | ctrl + w # 关闭当前标签页 203 | ctrl + t # 新增标签页 204 | shift + insert # 粘贴 205 | copyOnSelect # 选中即复制 206 | ``` 207 | 208 | > 参考 [Terminal SettingsSchema - Github](https://github.com/microsoft/terminal/blob/master/doc/cascadia/SettingsSchema.md) 209 | 210 | **待更新** -------------------------------------------------------------------------------- /posts/pandora-box/awesome-config/wsl.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/pandora-box/awesome-config/wsl.jpg -------------------------------------------------------------------------------- /posts/pandora-box/box-tools.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: 百宝箱 - 值得收藏的工具网站 3 | date: 2019-09-02 01:00:10 4 | description: 极客兔兔的百宝箱,都是些珍藏多年的宝贝。包括编程常用的网站,例如在线正则表达式;前端设计常用的网站,例如色彩搭配等。 5 | tags: 6 | - 百宝箱 7 | nav: 百宝箱 8 | categories: 9 | - 工具 10 | keywords: 11 | - 在线正则表达式网站 12 | - 色彩搭配 13 | image: post/box-tools/colorhunt.jpg 14 | --- 15 | 16 | 准备逐步把自己平时常用的网站和工具分享出来,希望能提高大家的效率。 17 | 18 | ## 编程 19 | 20 |
21 | No.1 Regex101 - 在线正则表达式测试 22 |
23 | 24 | [Regex101](https://regex101.com/) 25 | 26 | 正则表达式无非是一些规则的集合罢了,很多童鞋有正则恐怖症,不知道自己写的是否是期望的,这个网站可以帮助你克服正则恐怖症,你正则表达式中的每一个字母符号,都会告诉你具体的含义。 27 | 28 | ![regex](box-tools/regex.jpg) 29 |
30 |
31 | 32 | ## 前端/设计 33 | 34 |
35 | No.1 ColorHunt - 颜色搭配 36 |
37 | 38 | [ColorHunt](https://colorhunt.co) 39 | 40 | 好看的配色都在这里,适合建站颜色搭配困难症。 41 | 42 | ![colorhunt](box-tools/colorhunt.jpg) 43 |
44 |
45 | 46 |
47 | No.2 ImageOptim - 图片压缩神器 48 |
49 | 50 | [ImageOptim](https://imageoptim.com/mac) 51 | 52 | 一直在使用的图片压缩神器,压缩率基本在 50% 以上,本站的所有图片上传前都经过 ImageOptim 压缩。而且非常小!只有一个页面,拖进去即可! 53 | 54 | ![ImageOptim](box-tools/imageoptim.jpg) 55 |
56 |
-------------------------------------------------------------------------------- /posts/pandora-box/box-tools/colorhunt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/pandora-box/box-tools/colorhunt.jpg -------------------------------------------------------------------------------- /posts/pandora-box/box-tools/imageoptim.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/pandora-box/box-tools/imageoptim.jpg -------------------------------------------------------------------------------- /posts/pandora-box/box-tools/regex.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/pandora-box/box-tools/regex.jpg -------------------------------------------------------------------------------- /posts/quick-start/go/quick-go-context.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Go Context 并发编程简明教程 3 | seo_title: 快速入门 4 | date: 2020-04-20 23:30:00 5 | description: WaitGroup 和信道(channel)是常见的 2 种并发控制的方式。但是对于复杂的并发场景,Context 上下文是更优雅的控制方式。Context 提供了 WithCancel(取消)、WithValue(传值)、WithTimeout(超时机制)、WithDeadline(截止时间)等4种并发控制的方式。 6 | tags: 7 | - Go 8 | categories: 9 | - Go 简明教程 10 | nav: 简明教程 11 | keywords: 12 | - 并发编程 13 | - 上下文 14 | - 信道 15 | - 超时退出 16 | - golang 17 | image: post/quick-go-context/context_sm.jpg 18 | --- 19 | 20 | ## 1 为什么需要 Context 21 | 22 | WaitGroup 和信道(channel)是常见的 2 种并发控制的方式。 23 | 24 | 如果并发启动了多个子协程,需要等待所有的子协程完成任务,WaitGroup 非常适合于这类场景,例如下面的例子: 25 | 26 | ```go 27 | var wg sync.WaitGroup 28 | 29 | func doTask(n int) { 30 | time.Sleep(time.Duration(n)) 31 | fmt.Printf("Task %d Done\n", n) 32 | wg.Done() 33 | } 34 | 35 | func main() { 36 | for i := 0; i < 3; i++ { 37 | wg.Add(1) 38 | go doTask(i + 1) 39 | } 40 | wg.Wait() 41 | fmt.Println("All Task Done") 42 | } 43 | ``` 44 | 45 | `wg.Wait()` 会等待所有的子协程任务全部完成,所有子协程结束后,才会执行 `wg.Wait()` 后面的代码。 46 | 47 | ```bash 48 | Task 3 Done 49 | Task 1 Done 50 | Task 2 Done 51 | All Task Done 52 | ``` 53 | 54 | WaitGroup 只是傻傻地等待子协程结束,但是并不能主动通知子协程退出。假如开启了一个定时轮询的子协程,有没有什么办法,通知该子协程退出呢?这种场景下,可以使用 `select+chan` 的机制。 55 | 56 | ```go 57 | var stop chan bool 58 | 59 | func reqTask(name string) { 60 | for { 61 | select { 62 | case <-stop: 63 | fmt.Println("stop", name) 64 | return 65 | default: 66 | fmt.Println(name, "send request") 67 | time.Sleep(1 * time.Second) 68 | } 69 | } 70 | } 71 | 72 | func main() { 73 | stop = make(chan bool) 74 | go reqTask("worker1") 75 | time.Sleep(3 * time.Second) 76 | stop <- true 77 | time.Sleep(3 * time.Second) 78 | } 79 | ``` 80 | 81 | 子协程使用 for 循环定时轮询,如果 `stop` 信道有值,则退出,否则继续轮询。 82 | 83 | ```bash 84 | worker1 send request 85 | worker1 send request 86 | worker1 send request 87 | stop worker1 88 | ``` 89 | 90 | 更复杂的场景如何做并发控制呢?比如子协程中开启了新的子协程,或者需要同时控制多个子协程。这种场景下,`select+chan`的方式就显得力不从心了。 91 | 92 | Go 语言提供了 Context 标准库可以解决这类场景的问题,Context 的作用和它的名字很像,上下文,即子协程的下上文。Context 有两个主要的功能: 93 | 94 | - 通知子协程退出(正常退出,超时退出等); 95 | - 传递必要的参数。 96 | 97 | ## 2 context.WithCancel 98 | 99 | `context.WithCancel()` 创建可取消的 Context 对象,即可以主动通知子协程退出。 100 | 101 | ### 2.1 控制单个协程 102 | 103 | 使用 Context 改写上述的例子,效果与 `select+chan` 相同。 104 | 105 | ```go 106 | func reqTask(ctx context.Context, name string) { 107 | for { 108 | select { 109 | case <-ctx.Done(): 110 | fmt.Println("stop", name) 111 | return 112 | default: 113 | fmt.Println(name, "send request") 114 | time.Sleep(1 * time.Second) 115 | } 116 | } 117 | } 118 | 119 | func main() { 120 | ctx, cancel := context.WithCancel(context.Background()) 121 | go reqTask(ctx, "worker1") 122 | time.Sleep(3 * time.Second) 123 | cancel() 124 | time.Sleep(3 * time.Second) 125 | } 126 | ``` 127 | 128 | - `context.Backgroud()` 创建根 Context,通常在 main 函数、初始化和测试代码中创建,作为顶层 Context。 129 | - `context.WithCancel(parent)` 创建可取消的子 Context,同时返回函数 `cancel`。 130 | - 在子协程中,使用 select 调用 `<-ctx.Done()` 判断是否需要退出。 131 | - 主协程中,调用 `cancel()` 函数通知子协程退出。 132 | 133 | ### 2.2 控制多个协程 134 | 135 | ```go 136 | func main() { 137 | ctx, cancel := context.WithCancel(context.Background()) 138 | 139 | go reqTask(ctx, "worker1") 140 | go reqTask(ctx, "worker2") 141 | 142 | time.Sleep(3 * time.Second) 143 | cancel() 144 | time.Sleep(3 * time.Second) 145 | } 146 | ``` 147 | 148 | 为每个子协程传递相同的上下文 `ctx` 即可,调用 `cancel()` 函数后该 Context 控制的所有子协程都会退出。 149 | 150 | ```bash 151 | worker1 send request 152 | worker2 send request 153 | worker1 send request 154 | worker2 send request 155 | worker1 send request 156 | worker2 send request 157 | stop worker1 158 | stop worker2 159 | ``` 160 | 161 | ## 3 context.WithValue 162 | 163 | 如果需要往子协程中传递参数,可以使用 `context.WithValue()`。 164 | 165 | ```go 166 | type Options struct{ Interval time.Duration } 167 | 168 | func reqTask(ctx context.Context, name string) { 169 | for { 170 | select { 171 | case <-ctx.Done(): 172 | fmt.Println("stop", name) 173 | return 174 | default: 175 | fmt.Println(name, "send request") 176 | op := ctx.Value("options").(*Options) 177 | time.Sleep(op.Interval * time.Second) 178 | } 179 | } 180 | } 181 | 182 | func main() { 183 | ctx, cancel := context.WithCancel(context.Background()) 184 | vCtx := context.WithValue(ctx, "options", &Options{1}) 185 | 186 | go reqTask(vCtx, "worker1") 187 | go reqTask(vCtx, "worker2") 188 | 189 | time.Sleep(3 * time.Second) 190 | cancel() 191 | time.Sleep(3 * time.Second) 192 | } 193 | ``` 194 | 195 | - `context.WithValue()` 创建了一个基于 `ctx` 的子 Context,并携带了值 `options`。 196 | - 在子协程中,使用 `ctx.Value("options")` 获取到传递的值,读取/修改该值。 197 | 198 | ## 4 context.WithTimeout 199 | 200 | 如果需要控制子协程的执行时间,可以使用 `context.WithTimeout` 创建具有超时通知机制的 Context 对象。 201 | 202 | ``` 203 | func main() { 204 | ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 205 | go reqTask(ctx, "worker1") 206 | go reqTask(ctx, "worker2") 207 | 208 | time.Sleep(3 * time.Second) 209 | fmt.Println("before cancel") 210 | cancel() 211 | time.Sleep(3 * time.Second) 212 | } 213 | ``` 214 | 215 | `WithTimeout()`的使用与 `WithCancel()` 类似,多了一个参数,用于设置超时时间。执行结果如下: 216 | 217 | ```bash 218 | worker2 send request 219 | worker1 send request 220 | worker1 send request 221 | worker2 send request 222 | stop worker2 223 | stop worker1 224 | before cancel 225 | ``` 226 | 227 | 因为超时时间设置为 2s,但是 main 函数中,3s 后才会调用 `cancel()`,因此,在调用 `cancel()` 函数前,子协程因为超时已经退出了。 228 | 229 | ## 5 context.WithDeadline 230 | 231 | 超时退出可以控制子协程的最长执行时间,那 `context.WithDeadline()` 则可以控制子协程的最迟退出时间。 232 | 233 | ```go 234 | func reqTask(ctx context.Context, name string) { 235 | for { 236 | select { 237 | case <-ctx.Done(): 238 | fmt.Println("stop", name, ctx.Err()) 239 | return 240 | default: 241 | fmt.Println(name, "send request") 242 | time.Sleep(1 * time.Second) 243 | } 244 | } 245 | } 246 | 247 | func main() { 248 | ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(1*time.Second)) 249 | go reqTask(ctx, "worker1") 250 | go reqTask(ctx, "worker2") 251 | 252 | time.Sleep(3 * time.Second) 253 | fmt.Println("before cancel") 254 | cancel() 255 | time.Sleep(3 * time.Second) 256 | } 257 | ``` 258 | 259 | - `WithDeadline` 用于设置截止时间。在这个例子中,将截止时间设置为1s后,`cancel()` 函数在 3s 后调用,因此子协程将在调用 `cancel()` 函数前结束。 260 | - 在子协程中,可以通过 `ctx.Err()` 获取到子协程退出的错误原因。 261 | 262 | 运行结果如下: 263 | 264 | ```bash 265 | worker2 send request 266 | worker1 send request 267 | stop worker2 context deadline exceeded 268 | stop worker1 context deadline exceeded 269 | before cancel 270 | ``` 271 | 272 | 可以看到,子协程 `worker1` 和 `worker2` 均是因为截止时间到了而退出。 -------------------------------------------------------------------------------- /posts/quick-start/go/quick-go-context/context_sm.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/quick-start/go/quick-go-context/context_sm.jpg -------------------------------------------------------------------------------- /posts/quick-start/go/quick-go-gin/gin.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/quick-start/go/quick-go-gin/gin.jpg -------------------------------------------------------------------------------- /posts/quick-start/go/quick-go-gin/hello_gin.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/quick-start/go/quick-go-gin/hello_gin.jpg -------------------------------------------------------------------------------- /posts/quick-start/go/quick-go-mmap.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Go Mmap 文件内存映射简明教程 3 | seo_title: 快速入门 4 | date: 2020-04-20 22:30:00 5 | description: 简单理解,mmap 是一种将文件/设备映射到内存的方法,实现文件的磁盘地址和进程虚拟地址空间中的一段虚拟地址的一一映射关系。也就是说,可以在某个进程中通过操作这一段映射的内存,实现对文件的读写等操作。修改了这一段内存的内容,文件对应位置的内容也会同步修改,而读取这一段内存的内容,相当于读取文件对应位置的内容。 6 | tags: 7 | - Go 8 | categories: 9 | - Go 简明教程 10 | nav: 简明教程 11 | keywords: 12 | - mmap 13 | - memory mapping 14 | - windows 15 | - golang 16 | image: post/quick-go-mmap/mmap_sm.jpg 17 | --- 18 | 19 | ![golang mmap](quick-go-mmap/mmap.jpg) 20 | 21 | ## 1 mmap 简介 22 | 23 | > In computing, mmap is a POSIX-compliant Unix system call that maps files or devices into memory. It is a method of memory-mapped file I/O. 24 | > -- [mmap - wikipedia.org](https://en.wikipedia.org/wiki/Mmap) 25 | 26 | 简单理解,mmap 是一种将文件/设备映射到内存的方法,实现文件的磁盘地址和进程虚拟地址空间中的一段虚拟地址的一一映射关系。也就是说,可以在某个进程中通过操作这一段映射的内存,实现对文件的读写等操作。修改了这一段内存的内容,文件对应位置的内容也会同步修改,而读取这一段内存的内容,相当于读取文件对应位置的内容。 27 | 28 | mmap 另一个非常重要的特性是:减少内存的拷贝次数。在 linux 系统中,文件的读写操作通常通过 read 和 write 这两个系统调用来实现,这个过程会产生频繁的内存拷贝。比如 read 函数就涉及了 2 次内存拷贝: 29 | 30 | - 1) 操作系统读取磁盘文件到页缓存; 31 | - 2) 从页缓存将数据拷贝到 read 传递的 buf 中(例如进程中创建的byte数组)。 32 | 33 | mmap 只需要一次拷贝。即操作系统读取磁盘文件到页缓存,进程内部直接通过指针方式修改映射的内存。因此 mmap 特别适合读写频繁的场景,既减少了内存拷贝次数,提高效率,又简化了操作。KV数据库 [bbolt](https://github.com/etcd-io/bbolt) 就使用了这个方法持久化数据。 34 | 35 | ## 2 标准库 mmap 36 | 37 | Go 语言标准库 [golang.org/x/exp/mmap](https://godoc.org/golang.org/x/exp/mmap) 仅实现了 read 操作,后续能否支持 write 操作未知。使用场景非常有限。看一个简单的例子: 38 | 39 | 从第4个byte开始,读取 tmp.txt 2个byte的内容。 40 | 41 | ```go 42 | package main 43 | 44 | import ( 45 | "fmt" 46 | "golang.org/x/exp/mmap" 47 | ) 48 | 49 | func main() { 50 | at, _ := mmap.Open("./tmp.txt") 51 | buff := make([]byte, 2) 52 | _, _ = at.ReadAt(buff, 4) 53 | _ = at.Close() 54 | fmt.Println(string(buff)) 55 | } 56 | ``` 57 | 58 | ```bash 59 | $ echo "abcdefg" > tmp.txt 60 | $ go run . 61 | ef 62 | ``` 63 | 64 | 如果使用 `os.File` 操作,代码几乎是一样的,`os.File` 还支持写操作 `WriteAt`: 65 | 66 | ```go 67 | package main 68 | 69 | import ( 70 | "fmt" 71 | "os" 72 | ) 73 | 74 | func main() { 75 | f, _ := os.OpenFile("tmp.txt", os.O_CREATE|os.O_RDWR, 0644) 76 | _, _ = f.WriteAt([]byte("abcdefg"), 0) 77 | 78 | buff := make([]byte, 2) 79 | _, _ = f.ReadAt(buff, 4) 80 | _ = f.Close() 81 | fmt.Println(string(buff)) 82 | } 83 | ``` 84 | 85 | ## 3 mmap(linux) 86 | 87 | 如果要支持 write 操作,那么就需要直接调用 mmap 的系统调用来实现了。Linux 和 Windows 都支持 mmap,但接口有所不同。对于 linux 系统,mmap 方法定义如下: 88 | 89 | ```go 90 | func Mmap(fd int, offset int64, length int, prot int, flags int) (data []byte, err error) 91 | ``` 92 | 93 | 每个参数的含义分别是: 94 | 95 | ```go 96 | - fd:待映射的文件描述符。 97 | - offset:映射到内存区域的起始位置,0 表示由内核指定内存地址。 98 | - length:要映射的内存区域的大小。 99 | - prot:内存保护标志位,可以通过或运算符`|`组合 100 | - PROT_EXEC // 页内容可以被执行 101 | - PROT_READ // 页内容可以被读取 102 | - PROT_WRITE // 页可以被写入 103 | - PROT_NONE // 页不可访问 104 | - flags:映射对象的类型,常用的是以下两类 105 | - MAP_SHARED // 共享映射,写入数据会复制回文件, 与映射该文件的其他进程共享。 106 | - MAP_PRIVATE // 建立一个写入时拷贝的私有映射,写入数据不影响原文件。 107 | ``` 108 | 109 | 首先定义2个常量和数据类型Demo: 110 | 111 | ```go 112 | const defaultMaxFileSize = 1 << 30 // 假设文件最大为 1G 113 | const defaultMemMapSize = 128 * (1 << 20) // 假设映射的内存大小为 128M 114 | 115 | type Demo struct { 116 | file *os.File 117 | data *[defaultMaxFileSize]byte 118 | dataRef []byte 119 | } 120 | 121 | func _assert(condition bool, msg string, v ...interface{}) { 122 | if !condition { 123 | panic(fmt.Sprintf(msg, v...)) 124 | } 125 | } 126 | ``` 127 | 128 | - 内存有换页机制,映射的物理内存可以远小于文件。 129 | - Demo结构体由3个字段构成,file 即文件描述符,data 是映射内存的起始地址,dataRef 用于后续取消映射。 130 | 131 | 定义 mmap, grow, ummap 三个方法: 132 | 133 | ```go 134 | func (demo *Demo) mmap() { 135 | b, err := syscall.Mmap(int(demo.file.Fd()), 0, defaultMemMapSize, syscall.PROT_WRITE|syscall.PROT_READ, syscall.MAP_SHARED) 136 | _assert(err == nil, "failed to mmap", err) 137 | demo.dataRef = b 138 | demo.data = (*[defaultMaxFileSize]byte)(unsafe.Pointer(&b[0])) 139 | } 140 | 141 | func (demo *Demo) grow(size int64) { 142 | if info, _ := demo.file.Stat(); info.Size() >= size { 143 | return 144 | } 145 | _assert(demo.file.Truncate(size) == nil, "failed to truncate") 146 | } 147 | 148 | func (demo *Demo) munmap() { 149 | _assert(syscall.Munmap(demo.dataRef) == nil, "failed to munmap") 150 | demo.data = nil 151 | demo.dataRef = nil 152 | } 153 | ``` 154 | 155 | - mmap 传入的内存保护标志位为 `syscall.PROT_WRITE|syscall.PROT_READ`,即可读可写,映射类型为 `syscall.MAP_SHARED`,即对内存的修改会同步到文件。 156 | - `syscall.Mmap` 返回的是一个切片对象,需要从该切片中获取到内存的起始地址,并转换为可操作的 byte 数组,byte数组的长度为 `defaultMaxFileSize`。 157 | - grow 用于修改文件的大小,Linux 不允许操作超过文件大小之外的内存地址。例如文件大小为 4K,可访问的地址是`data[0~4095]`,如果访问 `data[10000]` 会报错。 158 | - munmap 用于取消映射。 159 | 160 | 在文件中写入 `hello, geektutu!` 161 | 162 | ```go 163 | func main() { 164 | _ = os.Remove("tmp.txt") 165 | f, _ := os.OpenFile("tmp.txt", os.O_CREATE|os.O_RDWR, 0644) 166 | demo := &Demo{file: f} 167 | demo.grow(1) 168 | demo.mmap() 169 | defer demo.munmap() 170 | 171 | msg := "hello geektutu!" 172 | 173 | demo.grow(int64(len(msg) * 2)) 174 | for i, v := range msg { 175 | demo.data[2*i] = byte(v) 176 | demo.data[2*i+1] = byte(' ') 177 | } 178 | } 179 | ``` 180 | 181 | - 在调用 `mmap` 之前,调用了 `grow(1)`,因为在 `mmap` 中使用 `&b[0]` 获取到映射内存的起始地址,所以文件大小至少为 1 byte。 182 | - 接下来,便是通过直接操作 `demo.data`,修改文件内容了。 183 | 184 | 运行: 185 | 186 | ```bash 187 | $ go run . 188 | $ cat tmp.txt 189 | h e l l o g e e k t u t u ! 190 | ``` 191 | 192 | ## 4 mmap(Windows) 193 | 194 | 相对于 Linux,Windows 上 mmap 的使用要复杂一些。 195 | 196 | ```go 197 | func (demo *Demo) mmap() { 198 | h, err := syscall.CreateFileMapping(syscall.Handle(demo.file.Fd()), nil, syscall.PAGE_READWRITE, 0, defaultMemMapSize, nil) 199 | _assert(h != 0, "failed to map", err) 200 | 201 | addr, err := syscall.MapViewOfFile(h, syscall.FILE_MAP_WRITE, 0, 0, uintptr(defaultMemMapSize)) 202 | _assert(addr != 0, "MapViewOfFile failed", err) 203 | 204 | err = syscall.CloseHandle(syscall.Handle(h)); 205 | _assert(err == nil, "CloseHandle failed") 206 | 207 | // Convert to a byte array. 208 | demo.data = (*[defaultMaxFileSize]byte)(unsafe.Pointer(addr)) 209 | } 210 | 211 | func (demo *Demo) munmap() { 212 | addr := (uintptr)(unsafe.Pointer(&demo.data[0])) 213 | _assert(syscall.UnmapViewOfFile(addr) == nil, "failed to munmap") 214 | } 215 | ``` 216 | - 需要 `CreateFileMapping` 和 `MapViewOfFile` 两步才能完成内存映射。`MapViewOfFile` 返回映射成功的内存地址,因此可以直接将该地址转换成 byte 数组。 217 | - Windows 对文件的大小没有要求,直接操作内存`data`,文件大小会自动发生改变。 218 | 219 | 使用时无需关注文件的大小。 220 | 221 | ```go 222 | func main() { 223 | _ = os.Remove("tmp.txt") 224 | f, _ := os.OpenFile("tmp.txt", os.O_CREATE|os.O_RDWR, 0644) 225 | demo := &Demo{file: f} 226 | demo.mmap() 227 | defer demo.munmap() 228 | 229 | msg := "hello geektutu!" 230 | for i, v := range msg { 231 | demo.data[2*i] = byte(v) 232 | demo.data[2*i+1] = byte(' ') 233 | } 234 | } 235 | ``` 236 | 237 | ```go 238 | $ go run . 239 | $ cat .\tmp.txt 240 | h e l l o g e e k t u t u ! 241 | ``` 242 | 243 | ## 附 参考 244 | 245 | - [edsrzf/mmap-go - github.com](https://github.com/edsrzf/mmap-go) 246 | - [golang 官方文档 syscall - golang.org](https://golang.org/pkg/syscall/) 247 | -------------------------------------------------------------------------------- /posts/quick-start/go/quick-go-mmap/mmap.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/quick-start/go/quick-go-mmap/mmap.jpg -------------------------------------------------------------------------------- /posts/quick-start/go/quick-go-mmap/mmap_sm.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/quick-start/go/quick-go-mmap/mmap_sm.jpg -------------------------------------------------------------------------------- /posts/quick-start/go/quick-go-protobuf.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Go Protobuf 简明教程 3 | seo_title: 快速入门 4 | date: 2020-01-11 00:27:00 5 | description: protobuf 即 Protocol Buffers,是一种轻便高效的结构化数据存储格式,与语言、平台无关,可扩展可序列化。protobuf 性能和效率大幅度优于 JSON、XML 等其他的结构化数据格式。protobuf 是以二进制方式存储的,占用空间小,但也带来了可读性差的缺点。protobuf 在通信协议和数据存储等领域应用广泛。本文介绍的是 Protocol Buffers 3 (protobuf3, proto3) 的安装和基本语法,以及如何在 Go 语言使用。 6 | tags: 7 | - Go 8 | categories: 9 | - Go 简明教程 10 | nav: 简明教程 11 | keywords: 12 | - Protocol Buffers 13 | - Protobuf 14 | - Golang 15 | - RPC 16 | image: post/quick-go-protobuf/protocol-buffers.jpg 17 | --- 18 | 19 | ![Golang Protocol Buffers](quick-go-protobuf/go-protobuf.jpg) 20 | 21 | ## 1 Protocol Buffers 简介 22 | 23 | protobuf 即 Protocol Buffers,是一种轻便高效的结构化数据存储格式,与语言、平台无关,可扩展可序列化。protobuf 性能和效率大幅度优于 JSON、XML 等其他的结构化数据格式。protobuf 是以二进制方式存储的,占用空间小,但也带来了可读性差的缺点。protobuf 在通信协议和数据存储等领域应用广泛。例如著名的分布式缓存工具 [Memcached](https://memcached.org/) 的 Go 语言版本[groupcache](https://github.com/golang/groupcache) 就使用了 protobuf 作为其 RPC 数据格式。 24 | 25 | Protobuf 在 `.proto` 定义需要处理的结构化数据,可以通过 `protoc` 工具,将 `.proto` 文件转换为 C、C++、Golang、Java、Python 等多种语言的代码,兼容性好,易于使用。 26 | 27 | ## 2 安装 28 | 29 | ### 2.1 protoc 30 | 31 | 从 [Protobuf Releases](https://github.com/protocolbuffers/protobuf/releases) 下载最先版本的发布包安装。如果是 Ubuntu,可以按照如下步骤操作(以3.11.2为例)。 32 | 33 | ```bash 34 | # 下载安装包 35 | $ wget https://github.com/protocolbuffers/protobuf/releases/download/v3.11.2/protoc-3.11.2-linux-x86_64.zip 36 | # 解压到 /usr/local 目录下 37 | $ sudo 7z x protoc-3.11.2-linux-x86_64.zip -o/usr/local 38 | ``` 39 | 40 | 如果不想安装在 /usr/local 目录下,可以解压到其他的其他,并把解压路径下的 bin 目录 加入到环境变量即可。 41 | 42 | 如果能正常显示版本,则表示安装成功。 43 | 44 | ```bash 45 | $ protoc --version 46 | libprotoc 3.11.2 47 | ``` 48 | 49 | ### 2.2 protoc-gen-go 50 | 51 | 我们需要在 Golang 中使用 protobuf,还需要安装 protoc-gen-go,这个工具用来将 `.proto` 文件转换为 Golang 代码。 52 | 53 | ```bash 54 | go get -u github.com/golang/protobuf/protoc-gen-go 55 | ``` 56 | 57 | protoc-gen-go 将自动安装到 `$GOPATH/bin` 目录下,也需要将这个目录加入到环境变量中。 58 | 59 | ## 3 定义消息类型 60 | 61 | 接下来,我们创建一个非常简单的示例,`student.proto` 62 | 63 | ```go 64 | syntax = "proto3"; 65 | package main; 66 | 67 | // this is a comment 68 | message Student { 69 | string name = 1; 70 | bool male = 2; 71 | repeated int32 scores = 3; 72 | } 73 | ``` 74 | 75 | 在当前目录下执行: 76 | 77 | ```bash 78 | $ protoc --go_out=. *.proto 79 | $ ls 80 | student.pb.go student.proto 81 | ``` 82 | 83 | 即是,将该目录下的所有的 .proto 文件转换为 Go 代码,我们可以看到该目录下多出了一个 Go 文件 *student.pb.go*。这个文件内部定义了一个结构体 Student,以及相关的方法: 84 | 85 | ```go 86 | type Student struct { 87 | Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` 88 | Male bool `protobuf:"varint,2,opt,name=male,proto3" json:"male,omitempty"` 89 | Scores []int32 `protobuf:"varint,3,rep,packed,name=scores,proto3" json:"scores,omitempty"` 90 | ... 91 | } 92 | ``` 93 | 94 | 逐行解读`student.proto` 95 | 96 | - protobuf 有2个版本,默认版本是 proto2,如果需要 proto3,则需要在非空非注释第一行使用 `syntax = "proto3"` 标明版本。 97 | - `package`,即包名声明符是可选的,用来防止不同的消息类型有命名冲突。 98 | - 消息类型 使用 `message` 关键字定义,Student 是类型名,name, male, scores 是该类型的 3 个字段,类型分别为 string, bool 和 []int32。字段可以是标量类型,也可以是合成类型。 99 | - 每个字段的修饰符默认是 singular,一般省略不写,`repeated` 表示字段可重复,即用来表示 Go 语言中的数组类型。 100 | - 每个字符 `=`后面的数字称为标识符,每个字段都需要提供一个唯一的标识符。标识符用来在消息的二进制格式中识别各个字段,一旦使用就不能够再改变,标识符的取值范围为 [1, 2^29 - 1] 。 101 | - .proto 文件可以写注释,单行注释 `//`,多行注释 `/* ... */` 102 | - 一个 .proto 文件中可以写多个消息类型,即对应多个结构体(struct)。 103 | 104 | 接下来,就可以在项目代码中直接使用了,以下是一个非常简单的例子,即证明被序列化的和反序列化后的实例,包含相同的数据。 105 | 106 | ```go 107 | package main 108 | 109 | import ( 110 | "log" 111 | 112 | "github.com/golang/protobuf/proto" 113 | ) 114 | 115 | func main() { 116 | test := &Student{ 117 | Name: "geektutu", 118 | Male: true, 119 | Scores: []int32{98, 85, 88}, 120 | } 121 | data, err := proto.Marshal(test) 122 | if err != nil { 123 | log.Fatal("marshaling error: ", err) 124 | } 125 | newTest := &Student{} 126 | err = proto.Unmarshal(data, newTest) 127 | if err != nil { 128 | log.Fatal("unmarshaling error: ", err) 129 | } 130 | // Now test and newTest contain the same data. 131 | if test.GetName() != newTest.GetName() { 132 | log.Fatalf("data mismatch %q != %q", test.GetName(), newTest.GetName()) 133 | } 134 | } 135 | ``` 136 | 137 | - 保留字段(Reserved Field) 138 | 139 | 更新消息类型时,可能会将某些字段/标识符删除。这些被删掉的字段/标识符可能被重新使用,如果加载老版本的数据时,可能会造成数据冲突,在升级时,可以将这些字段/标识符保留(reserved),这样就不会被重新使用了,protoc 会检查。 140 | 141 | ```go 142 | message Foo { 143 | reserved 2, 15, 9 to 11; 144 | reserved "foo", "bar"; 145 | } 146 | ``` 147 | 148 | 149 | ## 4 字段类型 150 | 151 | ### 4.1 标量类型(Scalar) 152 | 153 | | proto类型 | go类型 | 备注 | proto类型 | go类型 | 备注 | 154 | |---|---|---|---|---|---| 155 | | double | float64 | | float | float32 || 156 | | int32 | int32 | | int64 | int64 | | 157 | | uint32 | uint32 | |uint64 | uint64 | | 158 | | sint32 | int32 | 适合负数 | sint64 | int64 | 适合负数 | 159 | | fixed32 | uint32 | 固长编码,适合大于2^28的值 | fixed64 | uint64 | 固长编码,适合大于2^56的值 | 160 | | sfixed32 | int32 | 固长编码 | sfixed64 | int64 | 固长编码 | 161 | | bool | bool | |string|string| UTF8 编码,长度不超过 2^32 | 162 | | bytes | []byte | 任意字节序列,长度不超过 2^32 | 163 | 164 | 标量类型如果没有被赋值,则不会被序列化,解析时,会赋予默认值。 165 | 166 | - strings:空字符串 167 | - bytes:空序列 168 | - bools:false 169 | - 数值类型:0 170 | 171 | ### 4.2 枚举(Enumerations) 172 | 173 | 枚举类型适用于提供一组预定义的值,选择其中一个。例如我们将性别定义为枚举类型。 174 | 175 | ```go 176 | message Student { 177 | string name = 1; 178 | enum Gender { 179 | FEMALE = 0; 180 | MALE = 1; 181 | } 182 | Gender gender = 2; 183 | repeated int32 scores = 3; 184 | } 185 | ``` 186 | 187 | - 枚举类型的第一个选项的标识符必须是0,这也是枚举类型的默认值。 188 | - 别名(Alias),允许为不同的枚举值赋予相同的标识符,称之为别名,需要打开`allow_alias`选项。 189 | 190 | ```go 191 | message EnumAllowAlias { 192 | enum Status { 193 | option allow_alias = true; 194 | UNKOWN = 0; 195 | STARTED = 1; 196 | RUNNING = 1; 197 | } 198 | } 199 | ``` 200 | 201 | ### 4.3 使用其他消息类型 202 | 203 | `Result`是另一个消息类型,在 SearchReponse 作为一个消息字段类型使用。 204 | 205 | ```go 206 | message SearchResponse { 207 | repeated Result results = 1; 208 | } 209 | 210 | message Result { 211 | string url = 1; 212 | string title = 2; 213 | repeated string snippets = 3; 214 | } 215 | ``` 216 | 217 | 嵌套写也是支持的: 218 | 219 | ```go 220 | message SearchResponse { 221 | message Result { 222 | string url = 1; 223 | string title = 2; 224 | repeated string snippets = 3; 225 | } 226 | repeated Result results = 1; 227 | } 228 | ``` 229 | 230 | 如果定义在其他文件中,可以导入其他消息类型来使用: 231 | 232 | ```go 233 | import "myproject/other_protos.proto"; 234 | ``` 235 | 236 | ### 4.4 任意类型(Any) 237 | 238 | Any 可以表示不在 .proto 中定义任意的内置类型。 239 | 240 | ```go 241 | import "google/protobuf/any.proto"; 242 | 243 | message ErrorStatus { 244 | string message = 1; 245 | repeated google.protobuf.Any details = 2; 246 | } 247 | ``` 248 | 249 | ### 4.5 oneof 250 | 251 | ```go 252 | message SampleMessage { 253 | oneof test_oneof { 254 | string name = 4; 255 | SubMessage sub_message = 9; 256 | } 257 | } 258 | ``` 259 | 260 | ### 4.6 map 261 | 262 | ```go 263 | message MapRequest { 264 | map points = 1; 265 | } 266 | ``` 267 | 268 | ## 5 定义服务(Services) 269 | 270 | 如果消息类型是用来远程通信的(Remote Procedure Call, RPC),可以在 .proto 文件中定义 RPC 服务接口。例如我们定义了一个名为 SearchService 的 RPC 服务,提供了 `Search` 接口,入参是 `SearchRequest` 类型,返回类型是 `SearchResponse` 271 | 272 | ```go 273 | service SearchService { 274 | rpc Search (SearchRequest) returns (SearchResponse); 275 | } 276 | ``` 277 | 278 | 官方仓库也提供了一个[插件列表](https://github.com/protocolbuffers/protobuf/blob/master/docs/third_party.md),帮助开发基于 Protocol Buffer 的 RPC 服务。 279 | 280 | ## 6 protoc 其他参数 281 | 282 | 命令行使用方法 283 | 284 | ```bash 285 | protoc --proto_path=IMPORT_PATH --_out=DST_DIR path/to/file.proto 286 | ``` 287 | 288 | - `--proto_path=IMPORT_PATH`:可以在 .proto 文件中 import 其他的 .proto 文件,proto_path 即用来指定其他 .proto 文件的查找目录。如果没有引入其他的 .proto 文件,该参数可以省略。 289 | - `--_out=DST_DIR`:指定生成代码的目标文件夹,例如 --go_out=. 即生成 GO 代码在当前文件夹,另外支持 cpp/java/python/ruby/objc/csharp/php 等语言 290 | 291 | ## 7 推荐风格 292 | 293 | - 文件(Files) 294 | - 文件名使用小写下划线的命名风格,例如 lower_snake_case.proto 295 | - 每行不超过 80 字符 296 | - 使用 2 个空格缩进 297 | 298 | - 包(Packages) 299 | - 包名应该和目录结构对应,例如文件在`my/package/`目录下,包名应为 `my.package` 300 | 301 | - 消息和字段(Messages & Fields) 302 | - 消息名使用首字母大写驼峰风格(CamelCase),例如`message StudentRequest { ... }` 303 | - 字段名使用小写下划线的风格,例如 `string status_code = 1` 304 | - 枚举类型,枚举名使用首字母大写驼峰风格,例如 `enum FooBar`,枚举值使用全大写下划线隔开的风格(CAPITALS_WITH_UNDERSCORES ),例如 FOO_DEFAULT=1 305 | 306 | - 服务(Services) 307 | - RPC 服务名和方法名,均使用首字母大写驼峰风格,例如`service FooService{ rpc GetSomething() }` 308 | 309 | ## 附:参考 310 | 311 | 1. [protobuf 代码仓库 - github.com](https://github.com/protocolbuffers/protobuf) 312 | 2. [golang protobuf 代码仓库 - github.com](https://github.com/golang/protobuf) 313 | 3. [Remote procedure call 远程过程调用 - wikipedia.org](https://en.wikipedia.org/wiki/Remote_procedure_call) 314 | 4. [Groupcache Go语言版 memcached - github.com](https://github.com/golang/groupcache) 315 | 5. [Language Guide (proto3) 官方指南 - google.com](https://developers.google.com/protocol-buffers/docs/proto3) 316 | 6. [Proto Style Guide 代码风格指南 - google.com](https://developers.google.com/protocol-buffers/docs/style) 317 | 7. [Protocol Buffer 插件列表 - github.com](https://github.com/protocolbuffers/protobuf/blob/master/docs/third_party.md) -------------------------------------------------------------------------------- /posts/quick-start/go/quick-go-protobuf/go-protobuf.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/quick-start/go/quick-go-protobuf/go-protobuf.jpg -------------------------------------------------------------------------------- /posts/quick-start/go/quick-go-protobuf/protocol-buffers.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/quick-start/go/quick-go-protobuf/protocol-buffers.jpg -------------------------------------------------------------------------------- /posts/quick-start/go/quick-go-rpc.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Go RPC & TLS 鉴权简明教程 3 | seo_title: 快速入门 4 | date: 2020-01-13 23:47:00 5 | description: 本文介绍了 Go 语言(Golang)中远程过程调用(Remote Procedure Call, RPC)的使用方式,使用 Golang 标准库 net/rpc,同时介绍了异步调用,以及如何使用 TLS/SSL/HTTPS 实现服务器端和客户端的单向鉴权、双向鉴权。 6 | tags: 7 | - Go 8 | categories: 9 | - Go 简明教程 10 | nav: 简明教程 11 | keywords: 12 | - TLS 13 | - Golang 14 | - RPC 15 | - 证书 16 | - 鉴权 17 | image: post/quick-go-rpc/go-rpc.jpg 18 | --- 19 | 20 | 本文介绍了 Go 语言远程过程调用(Remote Procedure Call, RPC)的使用方式,示例基于 Golang 标准库 net/rpc,同时介绍了如何基于 TLS/SSL 实现服务器端和客户端的单向鉴权、双向鉴权。 21 | 22 | ## 1 RPC 简介 23 | 24 | > 远程过程调用(英语:Remote Procedure Call,缩写为 RPC)是一个计算机通信协议。该协议允许运行于一台计算机的程序调用另一个地址空间(通常为一个开放网络的一台计算机)的子程序,而程序员就像调用本地程序一样,无需额外地为这个交互作用编程(无需关注细节)。RPC是一种服务器-客户端(Client/Server)模式,经典实现是一个通过发送请求-接受回应进行信息交互的系统。 25 | > -- [远程过程调用 - Wikipedia.org](https://zh.wikipedia.org/wiki/%E9%81%A0%E7%A8%8B%E9%81%8E%E7%A8%8B%E8%AA%BF%E7%94%A8) 26 | 27 | 划重点:**程序员就像调用本地程序一样,无需关注细节** 28 | 29 | RPC 协议假定某种传输协议(TCP, UDP)存在,为通信程序之间携带信息数据。使用 RPC 协议,无需关注底层网络技术协议,调用远程方法就像在调用本地方法一样。 30 | 31 | RPC 流程: 32 | 33 | ![RPC PROCEDURE](quick-go-rpc/rpc-procedure.jpg) 34 | 35 | RPC 模型是一个典型的客户端-服务器模型(Client-Server, CS),相比于调用本地的接口,RPC 还需要知道的是服务器端的地址信息。本地调用,好比两个人面对面说话,而 RPC 好比打电话,需要知道对方的电话号码,但是并不需要关心语音是怎么编码,如何传输,又如何解码的。 36 | 37 | 接下来我们将展示如何将一个简单的本地调用的程序一步步地改造一个 RPC 服务。 38 | 39 | 示例使用 Go 语言,RPC 使用 Golang 提供的`net/rpc` 标准库 40 | 41 | ## 2 一个简单的计算二次方的程序 42 | 43 | 不考虑 RPC 调用,仅考虑本地调用的场景,程序实现如下: 44 | 45 | ```go 46 | // main.go 47 | package main 48 | 49 | import "log" 50 | 51 | type Result struct { 52 | Num, Ans int 53 | } 54 | 55 | type Cal int 56 | 57 | func (cal *Cal) Square(num int) *Result { 58 | return &Result{ 59 | Num: num, 60 | Ans: num * num, 61 | } 62 | } 63 | 64 | func main() { 65 | cal := new(Cal) 66 | result := cal.Square(12) 67 | log.Printf("%d^2 = %d", result.Num, result.Ans) 68 | } 69 | ``` 70 | 71 | 在这个20行的程序中,我们做了以下几件事: 72 | 73 | - `Cal` 结构体,提供了 Square 方法,用于计算传入参数 num 的 二次方。 74 | - `Result` 结构体,包含 Num 和 Ans 两个字段,Ans 是计算后的值,Num 是待计算的值。 75 | - `main` 函数,测试我们实现的 Square 方法。 76 | 77 | 运行 main.go,将会输出 78 | 79 | ```go 80 | $ go run main.go 81 | 2020/01/13 20:27:08 12^2 = 144 82 | ``` 83 | 84 | ## 3 RPC 需要满足什么条件 85 | 86 | 虽然说,远程过程调用并不需要我们关心如何编解码,如何通信,但是最基本的,如果一个方法需要支持远程过程调用,需要满足一定的约束和规范。不同 RPC 框架的约束和规范是不同的,如果使用 Golang 的标准库 `net/rpc`,方法需要长这个样子: 87 | 88 | ```go 89 | func (t *T) MethodName(argType T1, replyType *T2) error 90 | ``` 91 | 92 | 即需要满足以下 5 个条件: 93 | 94 | - 1) 方法类型(T)是导出的(首字母大写) 95 | - 2) 方法名(MethodName)是导出的 96 | - 3) 方法有2个参数(argType T1, replyType *T2),均为导出/内置类型 97 | - 4) 方法的第2个参数一个指针(replyType *T2) 98 | - 5) 方法的返回值类型是 error 99 | 100 | `net/rpc` 对参数个数的限制比较严格,仅能有2个,第一个参数是调用者提供的请求参数,第二个参数是返回给调用者的响应参数,也就是说,服务端需要将计算结果写在第二个参数中。如果调用过程中发生错误,会返回 error 给调用者。 101 | 102 | 接下来,我们改造下 Square 函数,以满足上述 5 个条件。 103 | 104 | ```go 105 | func (cal *Cal) Square(num int, result *Result) error { 106 | result.Num = num 107 | result.Ans = num * num 108 | return nil 109 | } 110 | 111 | func main() { 112 | cal := new(Cal) 113 | var result Result 114 | cal.Square(11, &result) 115 | log.Printf("%d^2 = %d", result.Num, result.Ans) 116 | } 117 | ``` 118 | 119 | - Cal 和 Square 均为导出类型,满足条件 1) 和 2) 120 | - 2 个参数,`num int` 为内置类型,`result *Result` 为导出类型,满足条件 3) 121 | - 第2个参数 `result *Result` 是一个指针,满足条件 4) 122 | - 返回值类型是 error,满足条件 5) 123 | 124 | 至此,方法 Cal.Square 满足了 RPC 调用的5个条件。 125 | 126 | ## 4 RPC 服务与调用 127 | 128 | ### 4.1 基于HTTP,启动 RPC 服务 129 | 130 | RPC 是一个典型的客户端-服务器(Client-Server, CS) 架构模型,很显然,需要将 Cal.Square 方法放在服务端。服务端需要提供一个套接字服务,处理客户端发送的请求。通常可以基于 HTTP 协议,监听一个端口,等待 HTTP 请求。 131 | 132 | 接下来我们新建一个文件夹 server,将 Cal.Square 方法移动到 server/main.go 中,并在 main 函数中启动 RPC 服务。 133 | 134 | ```go 135 | // server/main.go 136 | package main 137 | 138 | import ( 139 | "log" 140 | "net" 141 | "net/http" 142 | "net/rpc" 143 | ) 144 | 145 | type Result struct { 146 | Num, Ans int 147 | } 148 | 149 | type Cal int 150 | 151 | func (cal *Cal) Square(num int, result *Result) error { 152 | result.Num = num 153 | result.Ans = num * num 154 | return nil 155 | } 156 | 157 | func main() { 158 | rpc.Register(new(Cal)) 159 | rpc.HandleHTTP() 160 | 161 | log.Printf("Serving RPC server on port %d", 1234) 162 | if err := http.ListenAndServe(":1234", nil); err != nil { 163 | log.Fatal("Error serving: ", err) 164 | } 165 | } 166 | ``` 167 | 168 | - 使用 `rpc.Register`,发布 Cal 中满足 RPC 注册条件的方法(Cal.Square) 169 | - 使用 `rpc.HandleHTTP` 注册用于处理 RPC 消息的 HTTP Handler 170 | - 使用 `http.ListenAndServe` 监听 1234 端口,等待 RPC 请求。 171 | 172 | 我们在 server 目录下,执行 173 | 174 | ```bash 175 | $ go run main.go 176 | 2020/01/13 20:59:22 Serving RPC server on port 1234 177 | ``` 178 | 179 | 此时,RPC 服务已经启动,等待客户端的调用。 180 | 181 | ### 4.2 实现客户端 182 | 183 | 我们在 client 目录中新建文件 client/main.go,创建 HTTP 客户端,调用 Cal.Square 方法。 184 | 185 | ```go 186 | // client/main.go 187 | package main 188 | 189 | import ( 190 | "log" 191 | "net/rpc" 192 | ) 193 | 194 | type Result struct { 195 | Num, Ans int 196 | } 197 | 198 | func main() { 199 | client, _ := rpc.DialHTTP("tcp", "localhost:1234") 200 | var result Result 201 | if err := client.Call("Cal.Square", 12, &result); err != nil { 202 | log.Fatal("Failed to call Cal.Square. ", err) 203 | } 204 | log.Printf("%d^2 = %d", result.Num, result.Ans) 205 | } 206 | ``` 207 | 208 | 在客户端的实现中,因为要用到 Result 类型,简单起见,我们拷贝了 `Result` 的定义。 209 | 210 | - 使用 `rpc.DialHTTP` 创建了 HTTP 客户端 client,并且创建了与 localhost:1234 的链接,1234 恰好是 RPC 服务监听的端口。 211 | - 使用 `rpc.Call` 调用远程方法,第1个参数是方法名 Cal.Square,后两个参数与 Cal.Square 的定义的参数相对应。 212 | 213 | 我们在 client 目录下,执行 214 | 215 | ```bash 216 | 2020/01/13 21:17:45 12^2 = 144 217 | ``` 218 | 219 | 如果能够返回计算的结果,说明调用成功。 220 | 221 | ### 4.3 异步调用 222 | 223 | `client.Call` 是同步调用的方式,会阻塞当前的程序,直到结果返回。如果有异步调用的需求,可以考虑使用`client.Go`,如下 224 | 225 | ```go 226 | func main() { 227 | client, _ := rpc.DialHTTP("tcp", "localhost:1234") 228 | var result Result 229 | asyncCall := client.Go("Cal.Square", 12, &result, nil) 230 | log.Printf("%d^2 = %d", result.Num, result.Ans) 231 | 232 | <-asyncCall.Done 233 | log.Printf("%d^2 = %d", result.Num, result.Ans) 234 | 235 | } 236 | ``` 237 | 238 | 执行结果如下: 239 | 240 | ``` 241 | 2020/01/13 21:34:26 0^2 = 0 242 | 2020/01/13 21:34:26 12^2 = 144 243 | ``` 244 | 245 | 因为 `client.Go` 是异步调用,因此第一次打印 result,result 没有被赋值。而通过调用 `<-asyncCall.Done`,阻塞当前程序直到 RPC 调用结束,因此第二次打印 result 时,能够看到正确的赋值。 246 | 247 | ## 5 证书鉴权(TLS/SSL) 248 | 249 | ### 5.1 客户端对服务器端鉴权 250 | 251 | HTTP 协议默认是不加密的,我们可以使用证书来保证通信过程的安全。 252 | 253 | 生成私钥和自签名的证书,并将 server.key 权限设置为只读,保证私钥的安全。 254 | 255 | ```bash 256 | # 生成私钥 257 | openssl genrsa -out server.key 2048 258 | # 生成证书 259 | openssl req -new -x509 -key server.key -out server.crt -days 3650 260 | # 只读权限 261 | chmod 400 server.key 262 | ``` 263 | 264 | 执行完,当前文件夹下多出了 server.crt 和 server.key 2 个文件。 265 | 266 | 服务器端可以使用生成的 server.crt 和 server.key 文件启动 TLS 的端口监听。 267 | 268 | ```go 269 | // server/main.go 270 | import ( 271 | "crypto/tls" 272 | "log" 273 | "net/rpc" 274 | ) 275 | 276 | func main() { 277 | rpc.Register(new(Cal)) 278 | cert, _ := tls.LoadX509KeyPair("server.crt", "server.key") 279 | config := &tls.Config{ 280 | Certificates: []tls.Certificate{cert}, 281 | } 282 | listener, _ := tls.Listen("tcp", ":1234", config) 283 | log.Printf("Serving RPC server on port %d", 1234) 284 | 285 | for { 286 | conn, _ := listener.Accept() 287 | defer conn.Close() 288 | go rpc.ServeConn(conn) 289 | } 290 | } 291 | ``` 292 | 293 | 客户端也需要做相应的修改,使用 `tls.Dial` 代替 `rpc.DialHTTP` 连接服务端,如果客户端不需要对服务端鉴权,那么可以设置 `InsecureSkipVerify:true`,即可跳过对服务端的鉴权,例如: 294 | 295 | ```go 296 | // client/main.go 297 | import ( 298 | "crypto/tls" 299 | "log" 300 | "net/rpc" 301 | ) 302 | 303 | func main() { 304 | config := &tls.Config{ 305 | InsecureSkipVerify: true, 306 | } 307 | conn, _ := tls.Dial("tcp", "localhost:1234", config) 308 | defer conn.Close() 309 | client := rpc.NewClient(conn) 310 | 311 | var result Result 312 | if err := client.Call("Cal.Square", 12, &result); err != nil { 313 | log.Fatal("Failed to call Cal.Square. ", err) 314 | } 315 | 316 | log.Printf("%d^2 = %d", result.Num, result.Ans) 317 | } 318 | ``` 319 | 320 | 如果需要对服务器端鉴权,那么需要将服务端的证书添加到信任证书池中,如下: 321 | 322 | ```go 323 | // client/main.go 324 | 325 | func main() { 326 | certPool := x509.NewCertPool() 327 | certBytes, err := ioutil.ReadFile("../server/server.crt") 328 | if err != nil { 329 | log.Fatal("Failed to read server.crt") 330 | } 331 | certPool.AppendCertsFromPEM(certBytes) 332 | 333 | config := &tls.Config{ 334 | RootCAs: certPool, 335 | } 336 | 337 | conn, _ := tls.Dial("tcp", "localhost:1234", config) 338 | defer conn.Close() 339 | client := rpc.NewClient(conn) 340 | 341 | var result Result 342 | if err := client.Call("Cal.Square", 12, &result); err != nil { 343 | log.Fatal("Failed to call Cal.Square. ", err) 344 | } 345 | 346 | log.Printf("%d^2 = %d", result.Num, result.Ans) 347 | } 348 | ``` 349 | 350 | ### 5.2 服务器端对客户端的鉴权 351 | 352 | 服务器端对客户端的鉴权是类似的,核心在于 `tls.Config` 的配置: 353 | 354 | - 把对方的证书添加到自己的信任证书池 `RootCAs`(客户端配置),`ClientCAs`(服务器端配置) 中。 355 | - 创建链接时,配置自己的证书 `Certificates`。 356 | 357 | 358 | 客户端的 config 作如下修改: 359 | 360 | ```go 361 | // client/main.go 362 | 363 | cert, _ := tls.LoadX509KeyPair("client.crt", "client.key") 364 | certPool := x509.NewCertPool() 365 | certBytes, _ := ioutil.ReadFile("../server/server.crt") 366 | certPool.AppendCertsFromPEM(certBytes) 367 | config := &tls.Config{ 368 | Certificates: []tls.Certificate{cert}, 369 | RootCAs: certPool, 370 | } 371 | ``` 372 | 373 | 服务器端的 config 作如下修改: 374 | 375 | ```go 376 | // server/main.go 377 | 378 | cert, _ := tls.LoadX509KeyPair("server.crt", "server.key") 379 | certPool := x509.NewCertPool() 380 | certBytes, _ := ioutil.ReadFile("../client/client.crt") 381 | certPool.AppendCertsFromPEM(certBytes) 382 | config := &tls.Config{ 383 | Certificates: []tls.Certificate{cert}, 384 | ClientAuth: tls.RequireAndVerifyClientCert, 385 | ClientCAs: certPool, 386 | } 387 | ``` 388 | 389 | ## 附:参考 390 | 391 | 1. [Golang net/rpc 官方文档 - golang.org](https://golang.org/pkg/net/rpc/) 392 | 2. [Golang TLS 配置 - github.com](https://github.com/denji/golang-tls) 393 | -------------------------------------------------------------------------------- /posts/quick-start/go/quick-go-rpc/go-rpc.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/quick-start/go/quick-go-rpc/go-rpc.jpg -------------------------------------------------------------------------------- /posts/quick-start/go/quick-go-rpc/rpc-procedure.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/quick-start/go/quick-go-rpc/rpc-procedure.jpg -------------------------------------------------------------------------------- /posts/quick-start/go/quick-go-test/go_test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/quick-start/go/quick-go-test/go_test.jpg -------------------------------------------------------------------------------- /posts/quick-start/go/quick-go-wasm.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Go WebAssembly (Wasm) 简明教程 3 | seo_title: 快速入门 4 | date: 2020-01-23 23:25:00 5 | description: 本文介绍了如何使用Go 语言(golang)、WebAssembly 和 gopherjs 进行前端开发。包括注册函数,并与浏览器 JavaScript 对象交互,操作 DOM 元素,异步编程与回调函数等。最后介绍了一些进阶的 Demo(游戏,渲染等方面)和相关的项目和文档。 6 | tags: 7 | - Go 8 | categories: 9 | - Go 简明教程 10 | nav: 简明教程 11 | keywords: 12 | - Go WebAssembly 13 | - wasm 14 | - gopherjs 15 | image: post/quick-go-wasm/go-wasm.jpg 16 | github: https://github.com/geektutu/7days-golang 17 | --- 18 | 19 | 20 | ![Golang WebAssembly](quick-go-wasm/go-wasm.jpg) 21 | 22 | ## 1 WebAssembly 简介 23 | 24 | > WebAssembly是一种新的编码方式,可以在现代的网络浏览器中运行 - 它是一种低级的类汇编语言,具有紧凑的二进制格式,可以接近原生的性能运行,并为诸如C / C ++等语言提供一个编译目标,以便它们可以在Web上运行。它也被设计为可以与JavaScript共存,允许两者一起工作。 —— [MDN web docs - mozilla.org](https://developer.mozilla.org/zh-CN/docs/WebAssembly) 25 | 26 | 从 MDN 的介绍中,我们可以得出几个结论: 27 | 28 | - 1)WebAssembly 是一种二进制编码格式,而不是一门新的语言。 29 | - 2) WebAssembly 不是为了取代 JavaScript,而是一种补充(至少现阶段是这样),结合 WebAssembly 的性能优势,很大可能集中在对性能要求高(例如游戏,AI),或是对交互体验要求高(例如移动端)的场景。 30 | - 3)C/C++ 等语言可以编译 WebAssembly 的目标文件,也就是说,其他语言可以通过编译器支持,而写出能够在浏览器前端运行的代码。 31 | 32 | Go 语言在 1.11 版本(2018年8月) 加入了对 WebAssembly (Wasm) 的原生支持,使用 Go 语言开发 WebAssembly 相关的应用变得更加地简单。Go 语言的内建支持是 Go 语言进军前端的一个重要的里程碑。在这之前,如果想使用 Go 语言开发前端,需要使用 [GopherJS](https://github.com/gopherjs/gopherjs),GopherJS 是一个编译器,可以将 Go 语言转换成可以在浏览器中运行的 JavaScript 代码。新版本的 Go 则直接将 Go 代码编译为 wasm 二进制文件,而不再需要转为 JavaScript 代码。更巧的是,实现 GopherJS 和在 Go 语言中内建支持 WebAssembly 的是同一拨人。 33 | 34 | Go 语言实现的函数可以直接导出供 JavaScript 代码调用,同时,Go 语言内置了 [syscall/js](https://github.com/golang/go/tree/master/src/syscall/js) 包,可以在 Go 语言中直接调用 JavaScript 函数,包括对 DOM 树的操作。 35 | 36 | ## 2 Hello World 37 | 38 | 如果对 Go 语言不熟悉,推荐 [Go 语言简明教程](https://geektutu.com/post/quick-golang.html),一篇文章快速入门。 39 | 40 | 接下来,我们使用 Go 语言实现一个最简单的程序,在网页上弹出 `Hello World`。 41 | 42 | 第一步,新建文件 main.go,使用 js.Global().get('alert') 获取全局的 alert 对象,通过 Invoke 方法调用。等价于在 js 中调用 `window.alert("Hello World")`。 43 | 44 | ```go 45 | // main.go 46 | package main 47 | 48 | import "syscall/js" 49 | 50 | func main() { 51 | alert := js.Global().Get("alert") 52 | alert.Invoke("Hello World!") 53 | } 54 | ``` 55 | 56 | 第二步,将 main.go 编译为 static/main.wasm 57 | 58 | > 如果启用了 `GO MODULES`,则需要使用 go mod init 初始化模块,或设置 GO111MODULE=auto。 59 | 60 | ```bash 61 | $ GOOS=js GOARCH=wasm go build -o static/main.wasm 62 | ``` 63 | 64 | 第三步,拷贝 wasm_exec.js (JavaScript 支持文件,加载 wasm 文件时需要) 到 static 文件夹 65 | 66 | ```bash 67 | $ cp "$(go env GOROOT)/misc/wasm/wasm_exec.js" static 68 | ``` 69 | 70 | 第四步,创建 index.html,引用 `static/main.wasm` 和 `static/wasm_exec.js`。 71 | 72 | ```html 73 | 74 | 75 | 80 | 81 | 82 | ``` 83 | 84 | 第五步,使用 goexec 启动 Web 服务 85 | 86 | > 如果没有安装 goexec,可用 `go get -u github.com/shurcooL/goexec` 安装,需要将 $GOBIN 或 $GOPATH/bin 加入环境变量 87 | 88 | 当前的目录结构如下: 89 | 90 | ```bash 91 | demo/ 92 | |--static/ 93 | |--wasm_exec.js 94 | |--main.wasm 95 | |--main.go 96 | |--index.html 97 | ``` 98 | 99 | ```bash 100 | $ goexec 'http.ListenAndServe(`:9999`, http.FileServer(http.Dir(`.`)))' 101 | ``` 102 | 103 | 浏览器访问 localhost:9999,则会有一个弹出窗口,上面写着 *Hello World!*。 104 | 105 | ![go wasm hello world demo](quick-go-wasm/hello_world.png) 106 | 107 | 为了避免每次编译都需要输入繁琐的命令,可将这个过程写在 `Makefile` 中 108 | 109 | ```makefile 110 | all: static/main.wasm static/wasm_exec.js 111 | goexec 'http.ListenAndServe(`:9999`, http.FileServer(http.Dir(`.`)))' 112 | 113 | static/wasm_exec.js: 114 | cp "$(shell go env GOROOT)/misc/wasm/wasm_exec.js" static 115 | 116 | static/main.wasm : main.go 117 | GO111MODULE=auto GOOS=js GOARCH=wasm go build -o static/main.wasm . 118 | ``` 119 | 120 | 这样一个敲一下 make 就够了,代码已经上传到 [7days-golang - github.com](https://github.com/geektutu/7days-golang/tree/master/demo-wasm)。 121 | 122 | ## 3 注册函数(Register Functions) 123 | 124 | 在 Go 语言中调用 JavaScript 函数是一方面,另一方面,如果仅仅是使用 WebAssembly 替代性能要求高的模块,那么就需要注册函数,以便其他 JavaScript 代码调用。 125 | 126 | 假设我们需要注册一个计算斐波那契数列的函数,可以这么实现。 127 | 128 | ```go 129 | // main.go 130 | package main 131 | 132 | import "syscall/js" 133 | 134 | func fib(i int) int { 135 | if i == 0 || i == 1 { 136 | return 1 137 | } 138 | return fib(i-1) + fib(i-2) 139 | } 140 | 141 | func fibFunc(this js.Value, args []js.Value) interface{} { 142 | return js.ValueOf(fib(args[0].Int())) 143 | } 144 | 145 | func main() { 146 | done := make(chan int, 0) 147 | js.Global().Set("fibFunc", js.FuncOf(fibFunc)) 148 | <-done 149 | } 150 | ``` 151 | 152 | - fib 是一个普通的 Go 函数,通过递归计算第 i 个斐波那契数,接收一个 int 入参,返回值也是 int。 153 | - 定义了 fibFunc 函数,为 fib 函数套了一个壳,从 args[0] 获取入参,计算结果用 js.ValueOf 包装,并返回。 154 | - 使用 js.Global().Set() 方法,将注册函数 fibFunc 到全局,以便在浏览器中能够调用。 155 | 156 | `js.Value` 可以将 Js 的值转换为 Go 的值,比如 args[0].Int(),则是转换为 Go 语言中的整型。`js.ValueOf`,则用来将 Go 的值,转换为 Js 的值。另外,注册函数的时候,使用 js.FuncOf 将函数转换为 `Func` 类型,只有 Func 类型的函数,才能在 JavaScript 中调用。可以认为这是 Go 与 JavaScript 之间的接口/约定。 157 | 158 | `js.Func()` 接受一个函数类型作为其参数,该函数的定义必须是: 159 | 160 | ```go 161 | func(this Value, args []Value) interface{} 162 | // this 即 JavaScript 中的 this 163 | // args 是在 JavaScript 中调用该函数的参数列表。 164 | // 返回值需用 js.ValueOf 映射成 JavaScript 的值 165 | ``` 166 | 167 | 在 main 函数中,创建了信道(chan) done,阻塞主协程(goroutine)。fibFunc 如果在 JavaScript 中被调用,会开启一个新的子协程执行。 168 | 169 | > A wrapped function triggered during a call from Go to JavaScript gets executed on the same goroutine. A wrapped function triggered by JavaScript's event loop gets executed on an extra goroutine. —— [FuncOf - golang.org](https://golang.org/pkg/syscall/js/#FuncOf) 170 | 171 | 接下来,修改之前的 index.html,在其中添加一个输入框(num),一个按钮(btn) 和一个文本框(ans,用来显示计算结果),并给按钮添加了一个点击事件,调用 fibFunc,并将计算结果显示在文本框(ans)中。 172 | 173 | ```html 174 | 175 | ... 176 | 177 | 178 | 179 |

1

180 | 181 | 182 | ``` 183 | 184 | 使用之前的命令重新编译 main.go,并在 9999 端口启动 Web 服务,如果我们已经将命令写在 Makefile 中了,只需要运行 `make` 即可。 185 | 186 | 接下来访问 localhost:9999,可以看到如下效果。输入一个数字,点击`Click`,计算结果显示在输入框下方。 187 | 188 | ![register functions demo](quick-go-wasm/register_functions.png) 189 | 190 | ## 4 操作 DOM 191 | 192 | 在上一个例子中,仅仅是注册了全局函数 fibFunc,事件注册,调用,对 DOM 元素的操作都是在 HTML 193 | 中通过原生的 JavaScript 函数实现的。这些事情,能不能全部在 Go 语言中完成呢?答案可以。 194 | 195 | 首先修改 index.html,删除事件注册部分和 对 DOM 元素的操作部分。 196 | 197 | ```html 198 | 199 | ... 200 | 201 | 202 | 203 |

1

204 | 205 | 206 | ``` 207 | 208 | 修改 main.go: 209 | 210 | ```go 211 | package main 212 | 213 | import ( 214 | "strconv" 215 | "syscall/js" 216 | ) 217 | 218 | func fib(i int) int { 219 | if i == 0 || i == 1 { 220 | return 1 221 | } 222 | return fib(i-1) + fib(i-2) 223 | } 224 | 225 | var ( 226 | document = js.Global().Get("document") 227 | numEle = document.Call("getElementById", "num") 228 | ansEle = document.Call("getElementById", "ans") 229 | btnEle = js.Global().Get("btn") 230 | ) 231 | 232 | func fibFunc(this js.Value, args []js.Value) interface{} { 233 | v := numEle.Get("value") 234 | if num, err := strconv.Atoi(v.String()); err == nil { 235 | ansEle.Set("innerHTML", js.ValueOf(fib(num))) 236 | } 237 | return nil 238 | } 239 | 240 | func main() { 241 | done := make(chan int, 0) 242 | btnEle.Call("addEventListener", "click", js.FuncOf(fibFunc)) 243 | <-done 244 | } 245 | ``` 246 | 247 | - 通过 `js.Global().Get("btn")` 或 `document.Call("getElementById", "num")` 两种方式获取到 DOM 元素。 248 | - btnEle 调用 `addEventListener` 为 btn 绑定点击事件 fibFunc。 249 | - 在 fibFunc 中使用 `numEle.Get("value")` 获取到 numEle 的值(字符串),转为整型并调用 fib 计算出结果。 250 | - ansEle 调用 `Set("innerHTML", ...)` 渲染计算结果。 251 | 252 | 重新编译 main.go,访问 localhost:9999,效果与之前是一致的。 253 | 254 | ## 5 回调函数(Callback Functions) 255 | 256 | 在 JavaScript 中,异步+回调是非常常见的,比如请求一个 Restful API,注册一个回调函数,待数据获取到,再执行回调函数的逻辑,这个期间程序可以继续做其他的事情。Go 语言可以通过协程实现异步。 257 | 258 | 假设 fib 的计算非常耗时,那么可以启动注册一个回调函数,待 fib 计算完成后,再把计算结果显示出来。 259 | 260 | 我们先修改 main.go,使得 fibFunc 支持传入回调函数。 261 | 262 | ```go 263 | package main 264 | 265 | import ( 266 | "syscall/js" 267 | "time" 268 | ) 269 | 270 | func fib(i int) int { 271 | if i == 0 || i == 1 { 272 | return 1 273 | } 274 | return fib(i-1) + fib(i-2) 275 | } 276 | 277 | func fibFunc(this js.Value, args []js.Value) interface{} { 278 | callback := args[len(args)-1] 279 | go func() { 280 | time.Sleep(3 * time.Second) 281 | v := fib(args[0].Int()) 282 | callback.Invoke(v) 283 | }() 284 | 285 | js.Global().Get("ans").Set("innerHTML", "Waiting 3s...") 286 | return nil 287 | } 288 | 289 | func main() { 290 | done := make(chan int, 0) 291 | js.Global().Set("fibFunc", js.FuncOf(fibFunc)) 292 | <-done 293 | } 294 | ``` 295 | 296 | - 假设调用 fibFunc 时,回调函数作为最后一个参数,那么通过 args[len(args)-1] 便可以获取到该函数。这与其他类型参数的传递并无区别。 297 | - 使用 `go func()` 启动子协程,调用 fib 计算结果,计算结束后,调用回调函数 `callback`,并将计算结果传递给回调函数,使用 time.Sleep() 模拟 3s 的耗时操作。 298 | - 计算结果出来前,先在界面上显示 `Waiting 3s...` 299 | 300 | 接下来我们修改 index.html,为按钮添加点击事件,调用 fibFunc 301 | 302 | ```html 303 | 304 | ... 305 | 306 | 307 | 308 |

309 | 310 | 311 | ``` 312 | 313 | - 为 btn 注册了点击事件,第一个参数是待计算的数字,从 num 输入框获取。 314 | - 第二个参数是一个回调函数,将参数 v 显示在 ans 文本框中。 315 | 316 | 接下来,重新编译 main.go,访问 localhost:9999,随便输入一个数字,点击 Click。页面会先显示 `Waiting 3s...`,3s过后显示计算结果。 317 | 318 | ![go wasm callback demo](quick-go-wasm/callback.png) 319 | 320 | 321 | ## 6 进一步的尝试 322 | 323 | ### 6.1 工具框架 324 | 325 | - WebAssembly 的二进制分析工具 [WebAssembly Code Explorer](https://wasdk.github.io/wasmcodeexplorer/) 326 | - 使用NodeJs 或浏览器测试 Go Wasm 代码 [Github Wiki](https://github.com/golang/go/wiki/WebAssembly#executing-webassembly-with-nodejs) 327 | - 借鉴 Vue 实现的 Golang WebAssembly 前端框架 [Vugu](https://www.vugu.org/doc/start),完全使用 Go,不用写任何的 JavaScript 代码。 328 | 329 | ### 6.2 Demo/项目 330 | 331 | - 使用 Go Assembly 前端渲染的一些[例子](https://stdiopt.github.io/gowasm-experiments/) 332 | - [jsgo](https://github.com/dave/jsgo) 这个项目汇聚一些小而精的项目,包括 [2048](https://jsgo.io/hajimehoshi/ebiten/examples/2048),[俄罗斯方块](https://jsgo.io/hajimehoshi/ebiten/examples/blocks)等游戏,还有证明 Go 可以完整开发前端项目的 [TodoMVC](https://jsgo.io/dave/todomvc) 333 | 334 | ### 6.3 相关文档 335 | 336 | - [syscall/js 官方文档 - golang.org](https://golang.org/pkg/syscall/js) 337 | - [Go WebAssembly 官方文档 - github.com](https://github.com/golang/go/wiki/WebAssembly) -------------------------------------------------------------------------------- /posts/quick-start/go/quick-go-wasm/callback.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/quick-start/go/quick-go-wasm/callback.png -------------------------------------------------------------------------------- /posts/quick-start/go/quick-go-wasm/go-wasm.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/quick-start/go/quick-go-wasm/go-wasm.jpg -------------------------------------------------------------------------------- /posts/quick-start/go/quick-go-wasm/hello_world.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/quick-start/go/quick-go-wasm/hello_world.png -------------------------------------------------------------------------------- /posts/quick-start/go/quick-go-wasm/register_functions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/quick-start/go/quick-go-wasm/register_functions.png -------------------------------------------------------------------------------- /posts/quick-start/go/quick-go2.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Go2 新特性简明教程 3 | seo_title: 快速入门 4 | date: 2019-08-15 23:59:10 5 | description: GO 2 (golang 2) 的变化和新特性,与Go1相比Go2的变化。GO2草案,GO2设计草案。包管理机制(package)、错误处理(Error handling),错误值(Error values)和泛型(Generics)。GO语言的历史。difference between GO 2 and GO 1. 6 | tags: 7 | - Go 8 | categories: 9 | - Go 简明教程 10 | nav: 简明教程 11 | keywords: 12 | - Go语言 13 | - GO2 14 | - GO 2 15 | - 泛型 16 | image: post/quick-go2/go2.jpg 17 | --- 18 | 19 | ![quick-go2](quick-go2/go2.jpg) 20 | 21 | 图片引用自`udemy.com` 22 | 23 | ## Go 的演进 24 | 25 | Go语言/golang 诞生于2007年,经过12年的发展,Go逐渐成为了云计算领域新一代的开发语言。Go语言在牺牲很少性能的情况下,语法简洁,功能强大。我是Python的重度用户,在学习Go时,却有一种在学习Python的感觉。并非语法相似,而是Go语言作为一门编译型语言,竟然能够像Python一样,少量的代码就能够完成尽可能多的事情。Go语言仿佛是C和Python的结合体。 26 | 27 | Go是如何火起来的呢?我觉得有几个主要的原因,除了语言本身性能好,语法简单,易上手外。Go语言原生支持`Goroutine`和`Channel`,极大地降低了并发和异步编程的复杂度。对于服务端编程,并发和异步尤其重要,相比之下,C++,Java等语言的并发和异步控制逻辑过于复杂。另外,杀手级应用`Docker`的出现起到了很大的推动作用。 28 | 29 | Go语言也有很多令人诟病的地方,例如包管理机制,Go直到v1.6才默认开启了vendor机制,vendor机制非常简陋,简单说就是在项目目录下增加一个vendor文件夹,里面放第三方依赖。vendor机制是没有版本概念的,而且不能解决vendor目录嵌套的问题以及同名包函数冲突问题。后来社区涌现了大量的包管理工具,仅官方推荐的包管理工具就有15种之多,应用比较广泛的,如dep、govendor。直到v1.11,官方增加了Go modules机制,才算较为完整地解决了包管理的问题。 30 | 31 | Go2 可以说是Go语言一个非常重要的里程碑,Go1 目前虽然已经到了1.12版本,事实上每一个版本很少涉及语法层面的变化,而且每个版本都是向前兼容的。较大的改动如下: 32 | 33 | - Go1.2 切片操作 34 | 35 | ```go 36 | var a = make([]int, 10) 37 | var b = a[i:j:k] 38 | ``` 39 | 40 | - Go1.4 for语言加强 41 | 42 | ```go 43 | // <= 1.3 44 | for i, v := range x { 45 | // ... 46 | } 47 | 48 | for i := range x { 49 | // ... 50 | } 51 | 52 | // 1.4 新增 53 | var times [5][0]int 54 | 55 | for i := 0; i < len(times); i++ { 56 | // ... 57 | } 58 | 59 | for _ = range times { 60 | // ... 61 | } 62 | ``` 63 | 64 | - Go1.9 类型别名 65 | 66 | ```go 67 | type T1 = T2 68 | ``` 69 | 70 | ## Go 2 设计草案 71 | 72 | 为了进一步完善Go语言,提供更好的体验。Go语言社区目前发布了三类重要的设计草案,分别是`错误处理(Error handling)`、`错误值(Error values)`、`泛型(Generics)`,这几个草案代表了社区重点关注的完善方向,但并不代表最终的实现。 73 | 74 | ### 错误处理(Error Handling) 75 | 76 | Go1 的错误处理机制非常简单,通过返回值的方式,强迫调用者对错误进行处理,这种设计导致会在代码中写大量的`if`判断。例如: 77 | 78 | ```go 79 | func CopyFile(src, dst string) { 80 | r := os.Open(src) 81 | defer r.Close() 82 | 83 | w := os.Create(dst) 84 | io.Copy(w, r) 85 | w.Close() 86 | } 87 | ``` 88 | 89 | IO操作容易引发错误,文件打开失败,创建失败,拷贝失败等都会产生错误。如果要对这个函数进行完整的错误处理,代码将变成这样: 90 | 91 | ```go 92 | func CopyFile(src, dst string) error { 93 | r, err := os.Open(src) 94 | if err != nil { 95 | return err 96 | } 97 | defer r.Close() 98 | 99 | w, err := os.Create(dst) 100 | if err != nil { 101 | return err 102 | } 103 | defer w.Close() 104 | 105 | if _, err := io.Copy(w, r); err != nil { 106 | return err 107 | } 108 | if err := w.Close(); err != nil { 109 | return err 110 | } 111 | } 112 | ``` 113 | 114 | 看似逻辑清晰,但不够优雅,充斥了大量重复的逻辑。这是Go错误处理机制的缺陷。同时,因为错误处理机制的繁琐,很多开发者在开发应用时,很少去检查并处理错误,程序的健壮性得不到保证。 115 | 116 | 为了解决这个问题,Go2 发布了一个设计草案供社区讨论,Go2将会完善错误处理机制,错误处理的语法将会简洁很多。 117 | 118 | 这个提案引入了`handle err`和`check`关键字,上面的函数可以简化成: 119 | 120 | ```go 121 | func CopyFile(src, dst string) error { 122 | handle err { 123 | return fmt.Errorf("copy %s %s: %v", src, dst, err) 124 | } 125 | r := check os.Open(src) 126 | defer r.Close() 127 | 128 | w := check os.Create(dst) 129 | check io.Copy(w, r) 130 | check w.Close() 131 | } 132 | ``` 133 | 134 | 为什么不使用被Java、Python等语言采用的`try`关键字呢?比如写成: 135 | 136 | ```go 137 | data := try parseHexdump(string(hex)) 138 | ``` 139 | 140 | 上面的写法看似和谐,但`try`关键字直接应用在 error values 时,可读性就没那么好了: 141 | 142 | ```go 143 | data, err := parseHexdump(string(hex)) 144 | if err == ErrBadHex { 145 | ... special handling ... 146 | } 147 | try err 148 | ``` 149 | 150 | 很明显,在这种场景下,`check err`显然比`try err`更有意义。 151 | 152 | ### 错误值(Error values) 153 | 154 | 同样由于错误处理机制设计得较为简陋,Go语言对`Error values`支持有限。任何值,只要实现了`error`接口,都是错误类型。由于缺少细粒度的设计,在各种库当中,判断是否产生错误以及产生了哪类错误的方式多种多样,例如`io.EOF`,`os.IsNotExist`,`err.Error()`等,。另外,Go语言目前没有机制追溯到完整的错误链条。例如, 155 | 156 | ```go 157 | func funcB() error { 158 | if v, err := funcA(); if err != nil { 159 | return fmt.Errorf("connect to db: %v", err) 160 | } 161 | } 162 | func funcC() error { 163 | v, err := funcB() 164 | if err != nil { 165 | return fmt.Errorf("write users database: %v", err) 166 | } 167 | } 168 | ``` 169 | 170 | `funcC`返回的错误信息是: 171 | 172 | ```bash 173 | write users database: connect to db: open /etc/xx.conf: permission denied 174 | ``` 175 | 176 | 每一层,用额外的字符串对错误进行封装,是目前最常用的方法,除了通过字符串解析,很难还原出完整的错误链条。 177 | 178 | 为了解决Error values缺少标准的问题,有2个提案,分别针对`Error inspection`和`Error formatting`。 179 | 180 | - 针对 Error inspection ,为error定义了一个可选的接口`Unwrap`,用来返回错误链上的下一个错误。 181 | 182 | ```go 183 | package errors 184 | 185 | type Wrapper interface { 186 | Unwrap() error 187 | } 188 | ``` 189 | 190 | 例如, 191 | 192 | ```go 193 | // WriteError 实现 Unwrap 接口 194 | func (e *WriteError) Unwrap() error { return e.Err } 195 | ``` 196 | 197 | - 针对 Error format,定义了一个可选的接口`Format`,用来返回错误信息。 198 | 199 | ```go 200 | package errors 201 | 202 | type Formatter interface { 203 | Format(p Printer) (next error) 204 | } 205 | ``` 206 | 207 | 例如, 208 | 209 | ```go 210 | func (e *WriteError) Format(p errors.Printer) (next error) { 211 | p.Printf("write %s database", e.Database) 212 | if p.Detail() { 213 | p.Printf("more detail here") 214 | } 215 | return e.Err 216 | } 217 | ``` 218 | 219 | ### 泛型(Generics) 220 | 221 | Go语言当前可使用`inferface{}`,允许函数参数和返回值是任何类型的值。但这过于灵活,很多时候需要在获取参数后使用类型断言,进而决定下一步的处理。对比C++/Java的标准容器,Go语言在泛型方面有很大不足,因此针对泛型的提案即希望弥补这方面的不足。提案希望能够支持以下功能: 222 | 223 | ```go 224 | type List(type T) []T 225 | // 返回map的键 226 | func Keys(type K, V)(m map[K]V) []K 227 | // 去重过滤 228 | func Uniq(<-chan T) <-chan T 229 | // 合并 230 | func Merge(chans ...<-chan T) <-chan T 231 | // 使用自定义排序函数排序 232 | func SortSlice(data []T, less func(x, y T) bool) 233 | ``` 234 | 235 | 例如,我们需要返回一个map对象中所有的键,而希望这个键的类型可以是任意类型。 236 | 237 | ```go 238 | var ints List(int) 239 | keysA := Keys(int, string)(map[int]string{1:"one", 2: "two"}) 240 | keysB := Keys(string, string)(map[string]string{"name":"geektutu", "age": "twenty"}) 241 | // [1, 2] 242 | ``` 243 | 244 | > 参考:[Go2 wiki - Github](https://github.com/golang/go/wiki/Go2) 245 | 246 | ## Go 2 新特性 247 | 248 | Go2还未正式发布,发布后更新 249 | 250 | -------------------------------------------------------------------------------- /posts/quick-start/go/quick-go2/go2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/quick-start/go/quick-go2/go2.jpg -------------------------------------------------------------------------------- /posts/quick-start/go/quick-golang/golang.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/quick-start/go/quick-golang/golang.jpg -------------------------------------------------------------------------------- /posts/quick-start/go/quick-gomock.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Go Mock (gomock)简明教程 3 | seo_title: 快速入门 4 | date: 2020-02-14 23:30:10 5 | description: gomock 是Go语言/golang 官方提供的mock库,用于在单元测试(unit tests) 中模拟一些依赖复杂,难以直接调用的场景,比如网络请求、数据库依赖和文件I/O等。本文介绍了 gomock 的使用方法,包括模拟参数(Any, Nil, Not, Eq)和返回值(Do, Return, DoAndReturn),以及调用次数(Times) 和顺序(InOrder),如何写可测试可mock的代码。 6 | tags: 7 | - Go 8 | categories: 9 | - Go 简明教程 10 | nav: 简明教程 11 | keywords: 12 | - Go语言 13 | - 单元测试 14 | - mock框架 15 | - stub 16 | image: post/quick-gomock/gomock_logo.jpg 17 | --- 18 | 19 | ![gomock](quick-gomock/gomock.jpg) 20 | 21 | ## 1 gomock 简介 22 | 23 | 上一篇文章 [Go Test 单元测试简明教程](https://geektutu.com/post/quick-go-test.html) 介绍了 Go 语言中单元测试的常用方法,包括子测试(subtests)、表格驱动测试(table-driven tests)、帮助函数(helpers)、网络测试和基准测试(Benchmark)等。这篇文章介绍一种新的测试方法,mock/stub 测试,当待测试的函数/对象的依赖关系很复杂,并且有些依赖不能直接创建,例如数据库连接、文件I/O等。这种场景就非常适合使用 mock/stub 测试。简单来说,就是用 mock 对象模拟依赖项的行为。 24 | 25 | > GoMock is a mocking framework for the Go programming language. It integrates well with Go's built-in testing package, but can be used in other contexts too. 26 | 27 | [gomock](https://github.com/golang/mock) 是官方提供的 mock 框架,同时还提供了 mockgen 工具用来辅助生成测试代码。 28 | 29 | 使用如下命令即可安装: 30 | 31 | ```bash 32 | go get -u github.com/golang/mock/gomock 33 | go get -u github.com/golang/mock/mockgen 34 | ``` 35 | 36 | ## 2 一个简单的 Demo 37 | 38 | ```go 39 | // db.go 40 | type DB interface { 41 | Get(key string) (int, error) 42 | } 43 | 44 | func GetFromDB(db DB, key string) int { 45 | if value, err := db.Get(key); err == nil { 46 | return value 47 | } 48 | 49 | return -1 50 | } 51 | ``` 52 | 53 | 假设 `DB` 是代码中负责与数据库交互的部分(在这里用 map 模拟),测试用例中不能创建真实的数据库连接。这个时候,如果我们需要测试 `GetFromDB` 这个函数内部的逻辑,就需要 mock 接口 `DB`。 54 | 55 | 第一步:使用 `mockgen` 生成 `db_mock.go`。一般传递三个参数。包含需要被mock的接口得到源文件`source`,生成的目标文件`destination`,包名`package`。 56 | 57 | ```bash 58 | $ mockgen -source=db.go -destination=db_mock.go -package=main 59 | ``` 60 | 61 | 第二步:新建 `db_test.go`,写测试用例。 62 | 63 | ```go 64 | func TestGetFromDB(t *testing.T) { 65 | ctrl := gomock.NewController(t) 66 | defer ctrl.Finish() // 断言 DB.Get() 方法是否被调用 67 | 68 | m := NewMockDB(ctrl) 69 | m.EXPECT().Get(gomock.Eq("Tom")).Return(100, errors.New("not exist")) 70 | 71 | if v := GetFromDB(m, "Tom"); v != -1 { 72 | t.Fatal("expected -1, but got", v) 73 | } 74 | } 75 | ``` 76 | 77 | - 这个测试用例有2个目的,一是使用 `ctrl.Finish()` 断言 `DB.Get()` 被是否被调用,如果没有被调用,后续的 mock 就失去了意义; 78 | - 二是测试方法 `GetFromDB()` 的逻辑是否正确(如果 `DB.Get()` 返回 error,那么 `GetFromDB()` 返回 -1)。 79 | - `NewMockDB()` 的定义在 `db_mock.go` 中,由 mockgen 自动生成。 80 | 81 | 最终的代码结构如下: 82 | 83 | ```bash 84 | project/ 85 | |--db.go 86 | |--db_mock.go // generated by mockgen 87 | |--db_test.go 88 | ``` 89 | 90 | 执行测试: 91 | 92 | ```bash 93 | $ go test . -cover -v 94 | === RUN TestGetFromDB 95 | --- PASS: TestGetFromDB (0.00s) 96 | PASS 97 | coverage: 81.2% of statements 98 | ok example 0.008s coverage: 81.2% of statements 99 | ``` 100 | 101 | ## 3 打桩(stubs) 102 | 103 | 在上面的例子中,当 `Get()` 的参数为 Tom,则返回 error,这称之为`打桩(stub)`,有明确的参数和返回值是最简单打桩方式。除此之外,检测调用次数、调用顺序,动态设置返回值等方式也经常使用。 104 | 105 | 3.1 参数(Eq, Any, Not, Nil) 106 | 107 | ```go 108 | m.EXPECT().Get(gomock.Eq("Tom")).Return(0, errors.New("not exist")) 109 | m.EXPECT().Get(gomock.Any()).Return(630, nil) 110 | m.EXPECT().Get(gomock.Not("Sam")).Return(0, nil) 111 | m.EXPECT().Get(gomock.Nil()).Return(0, errors.New("nil")) 112 | ``` 113 | 114 | - `Eq(value)` 表示与 value 等价的值。 115 | - `Any()` 可以用来表示任意的入参。 116 | - `Not(value)` 用来表示非 value 以外的值。 117 | - `Nil()` 表示 None 值 118 | 119 | 120 | 3.2 返回值(Return, DoAndReturn) 121 | 122 | ```go 123 | m.EXPECT().Get(gomock.Not("Sam")).Return(0, nil) 124 | m.EXPECT().Get(gomock.Any()).Do(func(key string) { 125 | t.Log(key) 126 | }) 127 | m.EXPECT().Get(gomock.Any()).DoAndReturn(func(key string) (int, error) { 128 | if key == "Sam" { 129 | return 630, nil 130 | } 131 | return 0, errors.New("not exist") 132 | }) 133 | ``` 134 | 135 | - `Return` 返回确定的值 136 | - `Do` Mock 方法被调用时,要执行的操作吗,忽略返回值。 137 | - `DoAndReturn` 可以动态地控制返回值。 138 | 139 | 3.3 调用次数(Times) 140 | 141 | ```go 142 | func TestGetFromDB(t *testing.T) { 143 | ctrl := gomock.NewController(t) 144 | defer ctrl.Finish() 145 | 146 | m := NewMockDB(ctrl) 147 | m.EXPECT().Get(gomock.Not("Sam")).Return(0, nil).Times(2) 148 | GetFromDB(m, "ABC") 149 | GetFromDB(m, "DEF") 150 | } 151 | ``` 152 | 153 | - `Times()` 断言 Mock 方法被调用的次数。 154 | - `MaxTimes()` 最大次数。 155 | - `MinTimes()` 最小次数。 156 | - `AnyTimes()` 任意次数(包括 0 次)。 157 | 158 | 3.4 调用顺序(InOrder) 159 | 160 | ```go 161 | func TestGetFromDB(t *testing.T) { 162 | ctrl := gomock.NewController(t) 163 | defer ctrl.Finish() // 断言 DB.Get() 方法是否被调用 164 | 165 | m := NewMockDB(ctrl) 166 | o1 := m.EXPECT().Get(gomock.Eq("Tom")).Return(0, errors.New("not exist")) 167 | o2 := m.EXPECT().Get(gomock.Eq("Sam")).Return(630, nil) 168 | gomock.InOrder(o1, o2) 169 | GetFromDB(m, "Tom") 170 | GetFromDB(m, "Sam") 171 | } 172 | ``` 173 | 174 | ## 4 如何编写可 mock 的代码 175 | 176 | 写可测试的代码与写好测试用例是同等重要的,如何写可 mock 的代码呢? 177 | 178 | - mock 作用的是接口,因此将依赖抽象为接口,而不是直接依赖具体的类。 179 | - 不直接依赖的实例,而是使用依赖注入降低耦合性。 180 | 181 | > 在软件工程中,依赖注入的意思为,给予调用方它所需要的事物。 “依赖”是指可被方法调用的事物。依赖注入形式下,调用方不再直接指使用“依赖”,取而代之是“注入” 。“注入”是指将“依赖”传递给调用方的过程。在“注入”之后,调用方才会调用该“依赖”。传递依赖给调用方,而不是让让调用方直接获得依赖,这个是该设计的根本需求。 182 | > -- [依赖注入 - Wikipedia](https://zh.wikipedia.org/zh-cn/%E4%BE%9D%E8%B5%96%E6%B3%A8%E5%85%A5) 183 | 184 | 如果 `GetFromDB()` 方法长这个样子 185 | 186 | ```go 187 | func GetFromDB(key string) int { 188 | db := NewDB() 189 | if value, err := db.Get(key); err == nil { 190 | return value 191 | } 192 | 193 | return -1 194 | } 195 | ``` 196 | 197 | 对 `DB` 接口的 mock 并不能作用于 `GetFromDB()` 内部,这样写是没办法进行测试的。那如果将接口 `db DB` 通过参数传递到 `GetFromDB()`,那么就可以轻而易举地传入 Mock 对象了。 198 | -------------------------------------------------------------------------------- /posts/quick-start/go/quick-gomock/gomock.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/quick-start/go/quick-gomock/gomock.jpg -------------------------------------------------------------------------------- /posts/quick-start/go/quick-gomock/gomock_logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/quick-start/go/quick-gomock/gomock_logo.jpg -------------------------------------------------------------------------------- /posts/quick-start/python/quick-python/python.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/quick-start/python/quick-python/python.jpg -------------------------------------------------------------------------------- /posts/quick-start/rust/quick-rust/rust.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/quick-start/rust/quick-rust/rust.jpg -------------------------------------------------------------------------------- /posts/summary/2020.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: 2020 年终总结 3 | date: 2020-12-30 01:00:00 4 | description: 从 2020 年开始,在博客里记录每一年的成长与改变。年初的时候,给自己定了两个目标:工作上升一级,博客浏览量能翻一倍。第一个目标没想到刚过完年,就实现了。第二个目标到年底的时候也基本达成了,月 UV 3w,月 PV 8w。在这里,将过程中的一些经历和感悟记录下来。 5 | keywords: 6 | - 年终总结 7 | - 2020 8 | tags: 9 | - 关于我 10 | categories: 11 | - 年终总结 12 | nav: 杂谈 13 | image: post/2020/2020.jpg 14 | --- 15 | 16 | ![2020](2020/2020.jpg) 17 | 18 | 从 2020 年开始,在博客里记录每一年的成长与改变。 19 | 20 | 年初的时候,给自己定了两个目标:工作上升一级,博客浏览量能翻一倍。第一个目标没想到刚过完年,就实现了。第二个目标到年底的时候也基本达成了,月 UV 3w,月 PV 8w。 21 | 22 | 在这里,将过程中的一些经历和感悟记录下来。 23 | 24 | ## 工作 25 | 26 | 今年所做的工作需要横跨很多部门,而我是特别不愿意干这种事情的。喜欢一个人静下心来写代码,不被打扰。对其他部门的同事,特别是总来问你进展,顺便把成果包装的同事特别反感。但是在这个过程中,逐渐地转变了想法。一个人的能力是有限的,跨团队跨部门的合作更容易产生大的价值,关键点在于摘掉有色眼镜,找到大家共同的利益诉求,把事情做好。不要担心被别人包装,别人来包装你,说明你有价值。把蛋糕做大才是最重要的,一个大的蛋糕,再怎么分,也比原来手中的小蛋糕大。 27 | 28 | 今年涉及的工作是原来从来没接触过的,经过一年的摸爬滚打,竟然做得还不错。我在公司内部总结为那么一句话: 29 | 30 | > 一年只做一件事,一年做好一件事。 31 | 32 | 我一直是一个很浮躁的人,坚持做一件事不超过三个月。Github 上的开源项目也是这样,经常发现过了几个月还没啥人关注,就放弃了。上面这句话,是今年的最大的感悟。不管是在工作中,还是在工作外。很多事,很可能需要持续到某一个时间点,才会产生质变,不要过早放弃,也不要随意切换方向。选择是有时间成本的。 33 | 34 | ## 博客 35 | 36 | ![2020 blog](2020/data.png) 37 | 38 | 今年博客的阅读量和评论量主要是来自于这两个系列: 39 | 40 | - [七天用 Go 从零实现系列](https://geektutu.com/post/gee.html) 41 | - [Go 语言高性能编程](https://geektutu.com/post/high-performance-go.html) 42 | 43 | 七天系列 Github 有 5.9k 的星星,这是一个完全原创的项目,从很多童鞋的评论来看,能够对大部分童鞋理解一些优秀的开源项目和框架源码有帮助。 44 | 45 | 我分析过我的博客的流量构成,搜索、直接访问和引荐各占了三分之一。直接访问有那么高是比较意外的,我觉得对持续提升流量有帮助的几个点: 46 | 47 | - 坚持写一个领域的文章,比如写 Go,就坚持写个几十篇,比较成体系,对搜索引擎比较友好,而且也容易在这个领域形成一定的影响力,会有很多直接访问的流量,因为别人记住你了。 48 | - 不要在其他平台随意地转载自己的文章。大部分博客平台的权重很高,如果同时发在自己的独立博客和 CSDN、博客园这样的平台,那么很容易被搜索引擎判定抄袭,降低权重,博客的搜索流量会变得很小。如果想完全同步,建议间隔一个月。如果你的域名权重很高,就无所谓了。 49 | - 可以在其他平台发一些汇总链接的文章,我的每个系列都会在知乎发汇总贴,比如 [7天用Go从零实现RPC框架GeeRPC](https://zhuanlan.zhihu.com/p/265813329),有 200 个赞,一个原因是有童鞋在知乎关注了我,发篇文章能够让关注我的同学知道我出新的系列了。另一个原因,也是可以在一定程度上增加外链的数量,可以增加博客文章的权重。 50 | - 检查下你的订阅文件,有没有包含全文。我原来没注意,生成的 atom2.xml 包含了全文,每次更新,第一时间被很多采集站全文搬运,这篇文章就白写了。利用标准的 feed 文件内容,直接发布,还不用写爬虫,不搬运你搬运谁呢,这种搬运更不可能注明转载和原文地址了。现在生成的 [feed.xml](https://geektutu.com/feed.xml) 去掉了正文。 51 | 52 | ## 读的书 53 | 54 | 今年纸质书读得比较少,微信读书读了挺多,在这里分享两本比较有意思的鸡汤书吧。 55 | 56 | 第一本书是《微习惯》,因为这本书,我给自己定了每天写 50 字和做 1 个俯卧撑的目标,坚持了很长时间。我的第二个系列能写那么快,我觉得这本书功不可没。养成一个习惯很难,即使有再强大的动机和毅力都很难坚持下来。那如果这个习惯只需要极其微小的动机,不消耗任何的毅力呢?比如每天一个俯卧撑,中午吃完饭大家午休的时候,你就趴下来,再站起来,一秒钟一个俯卧撑的目标达成,每天花费一秒钟,就养成了一个健身的好习惯。 57 | 58 | 我从第一天的 5 个,到现在每天中午做 40 个。但是目标仍旧是每天 1 个,哪天生病了,实在是没法动,随便做 1 个,这个习惯就还是没有断的。 59 | 60 | 每天写 50 字也是一样的,大部分时间,在手机或者电脑上随便敲 50 字就睡觉了,然后有三分之一的时间,敲满了 2 千字,比如今天这篇总结。假如没这个习惯,我估计明日复明日,这个年终总结就没了。 61 | 62 | 第二本书是《刻意练习》,这本书比较啰嗦。对我启发比较大的地方总结一下: 63 | 64 | 首先,只花时间练习是不够的,还得“刻意”,最佳方式是找这个领域你能接触到的最厉害的人给你辅导,给你反馈。一件事情简单的重复是无意义的,是浪费时间,一万小时也不会变成专家。自己自学是没问题的,但是如果有人指导你,那必然是能更快的。比如入门 Python,淘宝花 200 块,让别人给你讲 1 小时课,他的水平可能不高,但是对于从没有接触过 Python 的你,指导你安装好环境,写出 hello world 是足够了的。如果自己去琢磨,很可能折腾了一天,啥事都没做成。工作以后,时间是很宝贵的。 65 | 66 | 第二,一件事重复做一百次是有意义的,前提是每次做完思考下哪些地方可以提升。比如学英语,对大部分人来说,看一百部电影的效果比不上《老友记》的一集看一百遍。看一百部电影,光顾着看字幕去了,不太可能找到提升的点。而找一个难度适合的半小时的材料,听一百遍,每次多听懂一个单词,一句话,提升就是巨大的。 67 | 68 | 只有“刻意”去做,才不会看似忙碌,实则浪费时间。 69 | 70 | ## flag 71 | 72 | 唯一的 flag 就是,通过上述两本鸡汤书养成的习惯,明年不要丢。再小的习惯,坚持一年,也会产生巨大的能量。 -------------------------------------------------------------------------------- /posts/summary/2020/2020.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/summary/2020/2020.jpg -------------------------------------------------------------------------------- /posts/summary/2020/data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/summary/2020/data.png -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow-make-npy-hdf5-data-set.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: TensorFlow入门(四) - mnist手写数字识别(制作h5py训练集) 3 | date: 2018-04-02 11:51:24 4 | description: TensorFlow 入门系列文章,第四篇,mnist手写数字识别(制作h5py训练集)。 5 | tags: 6 | - 机器学习 7 | - TensorFlow 8 | - mnist 9 | - Python 10 | nav: 简明教程 11 | categories: 12 | - TensorFlow 教程 13 | image: post/tensorflow-mnist-simplest/x_y.png 14 | github: https://github.com/geektutu/tensorflow-tutorial-samples 15 | --- 16 | 17 | 这篇文章是 **TensorFlow Tutorial** 入门教程的第四篇文章。 18 | 19 | 在之前的几篇文章中,我们都是通过 `tensorflow.examples.tutorials.mnist`来使用mnist训练集集,制作训练集主要有2个目的,一是加快训练时读取的速度,而是支持随机批读取。假如,每次训练时,都是直接读取图片,再将图片转为矩阵进行训练,那这样效率无疑是非常低下的。 20 | 21 | 这篇文章将使用numpy 和 h5py(HDF5文件格式)2种方式来制作训练集,并对这两种方式进行对比。 22 | 23 | ## 准备图片 24 | ![mnist-images](tensorflow-make-npy-hdf5-data-set/gen_mnist_images.png) 25 | 26 | 直接读取tensorflow中mnist数据集,将数据集还原为图片。 27 | 28 | 在这里,使用 pillow库将矩阵转为图片。 29 | 30 | ```python 31 | import numpy as np 32 | from PIL import Image 33 | from tensorflow.examples.tutorials.mnist import input_data 34 | 35 | 36 | def gen_image(arr, index, label): 37 | # 直接保存 arr,是黑底图片,1.0 - arr 是白底图片 38 | matrix = (np.reshape(1.0 - arr, (28, 28)) * 255).astype(np.uint8) 39 | img = Image.fromarray(matrix, 'L') 40 | # 存储图片时,label_index的格式,方便在制作数据集时,从文件名即可知道label 41 | img.save("./images/{}_{}.png".format(label, index)) 42 | 43 | 44 | data = input_data.read_data_sets('../mnist/data_set') 45 | x, y = data.train.next_batch(200) 46 | for i, (arr, label) in enumerate(zip(x, y)): 47 | print(i, label) 48 | gen_image(arr, i, label) 49 | ``` 50 | 51 | 这样,就得到了200张 `28*28`的图片供下一步制作训练集。 52 | 53 | 54 | 55 | ## 制作npy格式的数据集 56 | 57 | numpy能够将矩阵保存为文件,也能从文件中读取矩阵,因此可以考虑使用numpy制作数据集。 58 | 59 | ```python 60 | import os 61 | import numpy as np 62 | from PIL import Image 63 | from sklearn.model_selection import train_test_split 64 | ``` 65 | 66 | ### 1. 图片转为矩阵并保存 67 | 68 | ```python 69 | x, y = [], [] 70 | 71 | for i, image_path in enumerate(os.listdir('./images')): 72 | # label转为独热编码后再保存 73 | label = int(image_path.split('_')[0]) 74 | label_one_hot = [0 if i != label else 1 for i in range(10)] 75 | y.append(label_one_hot) 76 | 77 | # 图片像素值映射到 0 - 1之间 78 | image = Image.open('./images/{}'.format(image_path)).convert('L') 79 | image_arr = 1 - np.reshape(image, 784) / 255.0 80 | x.append(image_arr) 81 | 82 | np.save('data_set/X.npy', np.array(x)) 83 | np.save('data_set/Y.npy', np.array(y)) 84 | ``` 85 | 86 | ### 2. 读取文件随机批处理 87 | 88 | ```python 89 | class DataSet: 90 | def __init__(self): 91 | x, y = np.load('data_set/X.npy'), np.load('data_set/Y.npy') 92 | self.train_x, self.test_x, self.train_y, self.test_y = \ 93 | train_test_split(x, y, test_size=0.2, random_state=0) 94 | 95 | self.train_size = len(self.train_x) 96 | 97 | def get_train_batch(self, batch_size=64): 98 | # 随机获取batch_size个训练数据 99 | choice = np.random.randint(self.train_size, size=batch_size) 100 | batch_x = self.train_x[choice, :] 101 | batch_y = self.train_y[choice, :] 102 | 103 | return batch_x, batch_y 104 | 105 | def get_test_set(self): 106 | return self.test_x, self.test_y 107 | ``` 108 | 109 | - 一般情况下,我们会用随机批梯度下降的方式去进行训练,因此需要实现随机获取 batch_size个数据的功能。 110 | - 为了测试模型的泛化能力,测试集一般不与测试集交叉,常用 `sklearn`库中的`train_test_split`去分离训练数据与测试数据。 111 | 112 | ### 3. 如何使用 113 | 114 | ```python 115 | data_source = DataSet() 116 | for i in range(1000): 117 | train_x, train_y = data_source.get_train_batch(batch_size=32) 118 | // ... 119 | ``` 120 | 121 | 122 | 123 | ## 制作HDF5格式的数据集 124 | 125 | HDF 是用于存储和分发科学数据的一种自我描述、多对象文件格式。HDF 是由美国国家超级计算应用中心(NCSA)创建的,以满足不同群体的科学家在不同工程项目领域之需要。一个HDF5文件就是一个由两种基本数据对象(groups and datasets)存放多种科学数据的容器: 126 | 127 | - HDF5 group: 包含0个或多个HDF5对象以及支持元数据(metadata)的一个群组结构。 128 | - HDF5 dataset: 数据元素的一个多维数组以及支持元数据(metadata) 129 | 130 | 直观理解,一个HDF5文件可以存储多个数据(value),并用索引(key)找到,支持层级嵌套,类似于Python中的字典。 131 | 132 | Python中[h5py](http://docs.h5py.org/en/latest/index.html)来制作和使用HDF5格式的文件。 133 | 134 | ```python 135 | import os 136 | import h5py 137 | import numpy as np 138 | from PIL import Image 139 | from sklearn.model_selection import train_test_split 140 | ``` 141 | 142 | ### 1. 图片转为矩阵并保存 143 | 144 | ```python 145 | x, y = [], [] 146 | 147 | for i, image_path in enumerate(os.listdir('./images')): 148 | # label转为独热编码后再保存 149 | label = int(image_path.split('_')[0]) 150 | label_one_hot = [0 if i != label else 1 for i in range(10)] 151 | y.append(label_one_hot) 152 | 153 | # 图片像素值映射到 0 - 1之间 154 | image = Image.open('./images/{}'.format(image_path)).convert('L') 155 | image_arr = 1 - np.reshape(image, 784) / 255.0 156 | x.append(image_arr) 157 | 158 | with h5py.File('./data_set/data.h5', 'w') as f: 159 | f.create_dataset('x_data', data=np.array(x)) 160 | f.create_dataset('y_data', data=np.array(y)) 161 | ``` 162 | 163 | ### 2. 读取文件随机批处理 164 | 165 | ```python 166 | class DataSet: 167 | def __init__(self): 168 | with h5py.File('./data_set/data.h5', 'r') as f: 169 | x, y = f['x_data'].value, f['y_data'].value 170 | 171 | self.train_x, self.test_x, self.train_y, self.test_y = \ 172 | train_test_split(x, y, test_size=0.2, random_state=0) 173 | 174 | self.train_size = len(self.train_x) 175 | 176 | def get_train_batch(self, batch_size=64): 177 | # 随机获取batch_size个训练数据 178 | choice = np.random.randint(self.train_size, size=batch_size) 179 | batch_x = self.train_x[choice, :] 180 | batch_y = self.train_y[choice, :] 181 | 182 | return batch_x, batch_y 183 | 184 | def get_test_set(self): 185 | return self.test_x, self.test_y 186 | ``` 187 | 188 | > f['x_data'] 是一个datasets,拥有 name, shape, value 属性 189 | > 190 | > 可以看到,我们只用了1个HDF5文件就将x 和 y存下来了。假如在保存文件前对训练集和测试集进行拆分,同样能够将 train_x, train_y, test_x, test_y 一起保存在一个 HDF5文件中,使用非常方便。 191 | 192 | ## npy格式与hdf5格式的对比 193 | 194 | | # | 读取(1000次/ms) | 存储空间(M) | 195 | | ---- | ------------ | ------- | 196 | | npy | 1204 | 1.3 | 197 | | hdf5 | 1665 | 1.3 | 198 | 199 | 使用 200 张 28 * 28的图片对比,可以发现在没有使用任何压缩辅助的情况下,两种格式的数据占据的磁盘空间是一样的,HDF5的读取速度比npy慢了1/3,训练集如果能一次读取内存,启动训练前的读取时间可以忽略不计,但是HDF5格式的文件因为能够存储metadata和支持层级嵌套,键索引,使用起来更方便。 200 | 201 | **觉得还不错,不要吝惜你的[star](https://github.com/geektutu/tensorflow-tutorial-samples),支持是持续不断更新的动力。** 202 | -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow-make-npy-hdf5-data-set/gen_mnist_images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/tensorflow/tensorflow-make-npy-hdf5-data-set/gen_mnist_images.png -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow-mnist-save-ckpt.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: TensorFlow入门(二) - mnist手写数字识别(模型保存加载) 3 | date: 2017-12-17 11:51:24 4 | description: TensorFlow 入门系列文章,第二篇,mnist手写数字识别(模型保存加载)。 5 | tags: 6 | - 机器学习 7 | - TensorFlow 8 | - mnist 9 | - Python 10 | nav: 简明教程 11 | categories: 12 | - TensorFlow 教程 13 | image: post/tensorflow-mnist-save-ckpt/save_ckpt.png 14 | github: https://github.com/geektutu/tensorflow-tutorial-samples 15 | --- 16 | 17 | 这篇文章是 **TensorFlow Tutorial** 入门教程的第二篇文章。 18 | 19 | 上一篇文章[TensorFlow入门(一) - mnist手写数字识别(网络搭建)](http://geektutu.com/post/tensorflow-mnist-simplest.html)介绍了`神经网络输入`、`输出`、`独热编码`、`损失函数`等最基本的知识,并且演示了如何用最简单的模型实现mnist手写数字识别91%的正确率。但是遗留的问题是,模型保存在内存中,每次都得重新开始训练。 20 | 21 | 这篇文章解决的就是这个问题。将依次介绍tensorflow中如何`保存`已经训练好的模型,如何在某个训练步数的基础上`继续训练`,最后将演示如何`加载模型`,并借助pillow(Python2中称为PIL)库实现真实手写数字图片的识别。 22 | 23 | ## 模型的保存 24 | - 首先看一下项目的目录结构 25 | 26 | ``` 27 | |--mnist/ 28 | |--data_set/ 训练以及测试数据集 29 | |--test_images/ 多张测试图片 30 | |--0.png 31 | |--1.png 32 | |--4.png 33 | |--v2/ 34 | |--ckpt/ 模型保存在这里!!! 35 | |--model.py 网络模型 36 | |--train.py 训练代码 37 | |--predict.py 预测代码 38 | ``` 39 | ### 第一步更改模型,记录global_step 40 | 41 | > 每一次训练,会进行一次梯度下降,传入的global_step的值会自增1,因此,可以通过计算global_step这个张量的值,知道当前训练了多少步。 42 | 43 | ```python 44 | # model.py 45 | class Network: 46 | def __init__(self): 47 | # 记录已经训练的次数 48 | self.global_step = tf.Variable(0, trainable=False) 49 | 50 | # ... 中间省略网络结构 51 | 52 | # minimize 可传入参数 global_step, 每次训练 global_step的值会增加1 53 | # 因此,可以通过计算self.global_step这个张量的值,知道当前训练了多少步 54 | self.train = tf.train.GradientDescentOptimizer(0.001).minimize( 55 | self.loss, global_step=self.global_step) 56 | ``` 57 | 58 | ### 第二步,每隔N步保存 59 | 60 | ```python 61 | CKPT_DIR = 'ckpt' # 定义模型存储的位置 62 | net = Network() 63 | sess = tf.Session() 64 | sess.run(tf.global_variables_initializer()) 65 | 66 | # tf.train.Saver是用来保存训练结果的。 67 | # max_to_keep 用来设置最多保存多少个模型,默认是5 68 | # 如果保存的模型超过这个值,最旧的模型将被删除 69 | saver = tf.train.Saver(max_to_keep=10) 70 | 71 | train_step = 10000 # 总的训练次数10000 72 | step = 0 # 记录训练次数, 初始化为0 73 | save_interval = 1000 # 每隔1000步保存模型 74 | 75 | while step < train_step: 76 | # ...省略训练代码 77 | 78 | step = sess.run(net.global_step) 79 | # 模型保存在ckpt文件夹下 80 | # 模型文件名最后会增加global_step的值,比如1000的模型文件名为 model-1000 81 | if step % save_interval == 0: 82 | saver.save(sess, CKPT_DIR + '/model', global_step=step) 83 | ``` 84 | 85 | - 最终保存的模型如下所示 86 | 87 | > 假设训练到了2000步,保存了2次模型。ckpt文件夹下会生成7个文件,第一个文件是 checkpoint文件,保存了所有的模型的路径。其中第一行代表当前的状态,即在加载模型时,使用哪一个模型是由第一行决定的。 88 | 89 | > 每个模型包含3个文件,分别是 90 | > 1. model-xxx.data-00000-of-00001 91 | > 2. model-xxx.index 92 | > 3. model-xxx.meta 93 | 94 | checkpoint文件 95 | ``` 96 | model_checkpoint_path: "model-2000" 97 | all_model_checkpoint_paths: "model-1000" 98 | all_model_checkpoint_paths: "model-2000" 99 | ``` 100 | 101 | 目录结构 102 | 103 | ``` 104 | |--v2/ 105 | |--ckpt/ 模型保存在这里!!! 106 | |--checkpoint 107 | |--model-1000.data-00000-of-00001 108 | |--model-1000.index 109 | |--model-1000.meta 110 | |--model-2000.data-00000-of-00001 111 | |--model-2000.index 112 | |--model-2000.meta 113 | |--model.py 网络模型 114 | |--train.py 训练代码 115 | |--predict.py 预测代码 116 | ``` 117 | 118 | ## 加载模型与继续训练(train.py) 119 | > 假设我们当前模型已经训练到了2000步,但是由于某种原因停止了。那么是否可以在2000步的基础上继续训练呢? 120 | 121 | - 只需一步,训练前保存的模型restore到session中即可。这里需要注意的是,创建 `tf.train.Saver`对象一定要在创建`tf.Session`之后。 122 | 123 | 124 | ```python 125 | CKPT_DIR = 'ckpt' 126 | net = Network() 127 | sess = tf.Session() 128 | sess.run(tf.global_variables_initializer()) 129 | saver = tf.train.Saver(max_to_keep=10) 130 | 131 | train_step = 10000 132 | step = 0 133 | save_interval = 1000 134 | 135 | # 开始训练前,检查ckpt文件夹,看是否有checkpoint文件存在。 136 | # 如果存在,则读取checkpoint文件指向的模型,restore到sess中。 137 | ckpt = tf.train.get_checkpoint_state(CKPT_DIR) 138 | if ckpt and ckpt.model_checkpoint_path: 139 | saver.restore(sess, ckpt.model_checkpoint_path) 140 | # 读取网络中的global_step的值,即当前已经训练的次数 141 | step = sess.run(net.global_step) 142 | print('Continue from') 143 | print(' -> Minibatch update : ', step) 144 | 145 | while step < train_step: 146 | # ...省略训练代码 147 | ``` 148 | 149 | - 再次运行代码,将打印出 150 | 151 | ``` 152 | Continue from 153 | -> Minibatch update : 2000 154 | 第 3000步,... 155 | ``` 156 | 157 | - 如果将checkpoint文件的第一行改为如下,训练将从1000开始,再次训练到2000时,会将原来的2000的模型覆盖。所以restore哪一个模型,只与checkpoint的第一行有关,即只与`model_checkpoint_path`有关。 158 | ``` 159 | model_checkpoint_path: "model-1000" 160 | ``` 161 | 162 | ``` 163 | Continue from 164 | -> Minibatch update : 1000 165 | 第 2000步,... 166 | ``` 167 | 168 | ## 使用模型预测数字(predict.py) 169 | ### 第一步,restore模型 170 | ```python 171 | import numpy as np 172 | from PIL import Image 173 | 174 | 175 | class Predict: 176 | def __init__(self): 177 | self.net = Network() 178 | self.sess = tf.Session() 179 | self.sess.run(tf.global_variables_initializer()) 180 | self.restore() # 加载模型到sess中 181 | 182 | def restore(self): 183 | saver = tf.train.Saver() 184 | ckpt = tf.train.get_checkpoint_state(CKPT_DIR) 185 | if ckpt and ckpt.model_checkpoint_path: 186 | saver.restore(self.sess, ckpt.model_checkpoint_path) 187 | else: 188 | raise FileNotFoundError("未保存任何模型") 189 | 190 | def predict(self, image_path): 191 | # ...省略 192 | ``` 193 | 194 | ### 第二步读入图片并预测 195 | ```python 196 | class Predict: 197 | # ... 198 | 199 | def predict(self, image_path): 200 | # 读图片并转为黑白的 201 | img = Image.open(image_path).convert('L') 202 | flatten_img = np.reshape(img, 784) 203 | x = np.array([1 - flatten_img]) 204 | y = self.sess.run(self.net.y, feed_dict={self.net.x: x}) 205 | 206 | # 因为x只传入了一张图片,取y[0]即可 207 | # np.argmax()取得独热编码最大值的下标,即代表的数字 208 | print(image_path) 209 | print(' -> Predict digit', np.argmax(y[0])) 210 | ``` 211 | 212 | - test_images目录下的`0.png`,`1.png`,`4.png`三张图片的预测结果。 213 | ```python 214 | app = Predict() 215 | app.predict('../test_images/0.png') 216 | app.predict('../test_images/1.png') 217 | app.predict('../test_images/4.png') 218 | ``` 219 | 220 | ### 最后的结果 221 | 222 | - 第一次 **python train.py** 223 | ``` 224 | 第 1000步,当前loss:26.94 225 | 第 2000步,当前loss:28.36 226 | ``` 227 | - 2000步时停止,第二次 **python train.py** 228 | ``` 229 | Continue from 230 | -> Minibatch update : 2000 231 | 第 3000步,当前loss:23.49 232 | 第 4000步,当前loss:20.40 233 | 第 5000步,当前loss:11.65 234 | ``` 235 | 236 | - **python predict.py** 237 | ``` 238 | ../test_images/0.png 239 | -> Predict digit 0 240 | ../test_images/1.png 241 | -> Predict digit 1 242 | ../test_images/4.png 243 | -> Predict digit 4 244 | ``` 245 | 246 | > 源代码&数据集已上传到 [Github](https://github.com/geektutu/tensorflow-tutorial-samples) 247 | 248 | **觉得还不错,不要吝惜你的[star](https://github.com/geektutu/tensorflow-tutorial-samples),支持是持续不断更新的动力。** 249 | 250 | ## 附 推荐 251 | 252 | - [一篇文章入门 Python](https://geektutu.com/post/quick-python.html) -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow-mnist-save-ckpt/save_ckpt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/tensorflow/tensorflow-mnist-save-ckpt/save_ckpt.png -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow-mnist-simplest.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: TensorFlow入门(一) - mnist手写数字识别(网络搭建) 3 | date: 2017-12-09 11:51:24 4 | description: TensorFlow 入门系列文章,第一篇,mnist手写数字识别(网络搭建)。 5 | tags: 6 | - 机器学习 7 | - TensorFlow 8 | - mnist 9 | - Python 10 | nav: 简明教程 11 | categories: 12 | - TensorFlow 教程 13 | image: post/tensorflow-mnist-simplest/x_y.png 14 | github: https://github.com/geektutu/tensorflow-tutorial-samples 15 | --- 16 | 17 | 这篇文章是 **TensorFlow Tutorial** 入门教程的第一篇文章。主要介绍了如何从0开始用tensorflow搭建最简单的网络进行训练。 18 | 19 | ## mnist数据集 20 | 21 | ### 简介 22 | 23 | MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片。在机器学习中的地位相当于Python入门的打印`Hello World`。官网是[THE MNIST DATABASE of handwritten digits](http://yann.lecun.com/exdb/mnist/) 24 | 该数据集包含以下四个部分 25 | 26 | - train-images-idx3-ubyte.gz: 训练集-图片,6w 27 | - train-labels-idx1-ubyte.gz: 训练集-标签,6w 28 | - t10k-images-idx3-ubyte.gz: 测试集-图片,1w 29 | - t10k-labels-idx1-ubyte.gz: 测试集-标签,1w 30 | 31 | ### 图片和标签 32 | 33 | mnist数据集里的每张图片大小为28 * 28像素,可以用28 * 28的大小的数组来表示一张图片。 34 | 标签用大小为10的数组来表示,这种编码我们称之为One hot(独热编码)。 35 | 36 | ### One-hot编码(独热编码) 37 | 38 | 独热编码使用N位代表N种状态,任意时候只有其中一位有效。 39 | 40 | 采用独热编码的例子 41 | 42 | ``` 43 | 性别: 44 | [0, 1]代表女,[1, 0]代表男 45 | 46 | 数字0-9: 47 | [0,0,0,0,0,0,0,0,0,1]代表9,[0,1,0,0,0,0,0,0,0,0]代表1 48 | ``` 49 | 50 | 独热编码的优点在于 51 | - 能够处理非连续型数值特征 52 | - 在一定程度上也扩充了特征。比如性别本身是一个特征,经过编码以后,就变成了男或女两个特征。 53 | 54 | > 在神经网络中,独热编码其实具有很强的容错性,比如神经网络的输出结果是 [0,0.1,0.2,0.7,0,0,0,0,0, 0]转成独热编码后,表示数字3。即值最大的地方变为1,其余均为0。[0,0.1,0.4,0.5,0,0,0,0,0, 0]也能表示数字3。 55 | 56 | > numpy中有一个函数,numpy.argmax()可以取得最大值的下标。 57 | 58 | ## 神经网络的重要概念 59 | ### 输入(x)输出(y)、标签(label) 60 | - 输入是指传入给网络处理的向量,相当于数学函数中的变量。 61 | - 输出是指网络处理后返回的结果,相当于数据函数中的函数值。 62 | - 标签是指我们期望网络返回的结果。 63 | 64 | 对于识别mnist图片而言,输入是大小为784(28 * 28)的向量,输出是大小为10的概率向量(概率最大的位置,即预测的数字)。 65 | 66 | ### 损失函数(loss function) 67 | 损失函数评估网络模型的好坏,值越大,表示模型越差,值越小,表示模型越好。因为传入大量的训练集训练的目标,就是将损失函数的值降到最小。 68 | 69 | 常见的损失函数定义: 70 | - 差的平方和 sum((y - label)^2) 71 | 72 | ``` 73 | [0, 0, 1] 与 [0.1, 0.3, 0.6]的差的平方和为 0.01 + 0.09 + 0.16 = 0.26 74 | [0, 0, 1] 与 [0.2, 0.2, 0.6]的差的平方和为 0.04 + 0.04 + 0.16 = 0.24 75 | [0, 0, 1] 与 [0.1, 0, 0.9]的差的平方和为 0.01 + 0.01 = 0.02 76 | ``` 77 | - 交叉熵 -sum(label * log(y)) 78 | 79 | ``` 80 | 81 | [0, 0, 1] 与 [0.1, 0.3, 0.6]的交叉熵为 -log(0.6) = 0.51 82 | [0, 0, 1] 与 [0.2, 0.2, 0.6]的交叉熵为 -log(0.6) = 0.51 83 | [0, 0, 1] 与 [0.1, 0, 0.9]的交叉熵为 -log(0.9) = 0.10 84 | ``` 85 | 当label为0时,交叉熵为0,label为1时,交叉熵为-log(y),交叉熵只关注独热编码中有效位的损失。这样屏蔽了无效位值的变化(无效位的值的变化并不会影响最终结果),并且通过取对数放大了有效位的损失。当有效位的值趋近于0时,交叉熵趋近于正无穷大。 86 | 87 | ![x_y](tensorflow-mnist-simplest/x_y.png) 88 | 89 | ### 回归模型 90 | 91 | 我们可以将网络理解为一个函数,回归模型,其实是希望对这个函数进行拟合。 92 | 比如定义模型为 Y = X * w + b,对应的损失即 93 | 94 | ``` 95 | loss = (Y - labal)^2 96 | = -(X * w - b - label)^2 97 | 这里损失函数用方差计算,这个函数是关于w和b的二次函数,所以神经网络训练的目的是找到w和b,使得loss最小。 98 | ``` 99 | 100 | 可以通过不断地传入X和label的值,来修正w和b,使得最终得到的Y与label的loss最小。这个训练的过程,可以采用**梯度下降**的方法。通过梯度下降,找到最快的方向,调整w和b值,使得w * X + b的值越来越接近label。 101 | 梯度下降的具体过程,就不在这篇文章中展开了。 102 | 103 | ![loss](tensorflow-mnist-simplest/loss.png) 104 | 105 | ### 学习速率 106 | 简单说,梯度即一个函数的斜率,找到函数的斜率,其实就知道了w和b的值往哪个方向调整,能够让函数值(loss)降低得最快。那么方向知道了,往这个方向调整多少呢?这个数,神经网络中称之为学习速率。学习速率调得太低,训练速度会很慢,学习速率调得过高,每次迭代波动会很大。 107 | 108 | ### softmax激活函数 109 | 110 | 本文不展开讲解softmax激活函数。事实上,再计算交叉熵前的Y值是经过softmax后的,经过softmax后的Y,并不影响Y向量的每个位置的值之间的大小关系。大致有2个作用,一是放大效果,二是梯度下降时需要一个可导的函数。 111 | 112 | ```python 113 | def softmax(x): 114 | import numpy as np 115 | return np.exp(x) / np.sum(np.exp(x), axis=0) 116 | 117 | softmax([4, 5, 10]) 118 | # [ 0.002, 0.007, 0.991] 119 | ``` 120 | 121 | ## Tensorflow识别手写数字 122 | 123 | ### 构造网络 `model.py` 124 | 125 | ```python 126 | import tensorflow as tf 127 | 128 | 129 | class Network: 130 | def __init__(self): 131 | # 学习速率,一般在 0.00001 - 0.5 之间 132 | self.learning_rate = 0.001 133 | 134 | # 输入张量 28 * 28 = 784个像素的图片一维向量 135 | self.x = tf.placeholder(tf.float32, [None, 784]) 136 | 137 | # 标签值,即图像对应的结果,如果对应数字是8,则对应label是 [0,0,0,0,0,0,0,0,1,0] 138 | # 这种方式称为 one-hot编码 139 | # 标签是一个长度为10的一维向量,值最大的下标即图片上写的数字 140 | self.label = tf.placeholder(tf.float32, [None, 10]) 141 | 142 | # 权重,初始化全 0 143 | self.w = tf.Variable(tf.zeros([784, 10])) 144 | # 偏置 bias, 初始化全 0 145 | self.b = tf.Variable(tf.zeros([10])) 146 | # 输出 y = softmax(X * w + b) 147 | self.y = tf.nn.softmax(tf.matmul(self.x, self.w) + self.b) 148 | 149 | # 损失,即交叉熵,最常用的计算标签(label)与输出(y)之间差别的方法 150 | self.loss = -tf.reduce_sum(self.label * tf.log(self.y + 1e-10)) 151 | 152 | # 反向传播,采用梯度下降的方法。调整w与b,使得损失(loss)最小 153 | # loss越小,那么计算出来的y值与 标签(label)值越接近,准确率越高 154 | self.train = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss) 155 | 156 | # 以下代码验证正确率时使用 157 | # argmax 返回最大值的下标,最大值的下标即答案 158 | # 例如 [0,0,0,0.9,0,0.1,0,0,0,0] 代表数字3 159 | predict = tf.equal(tf.argmax(self.label, 1), tf.argmax(self.y, 1)) 160 | 161 | # predict -> [true, true, true, false, false, true] 162 | # reduce_mean即求predict的平均数 即 正确个数 / 总数,即正确率 163 | self.accuracy = tf.reduce_mean(tf.cast(predict, "float")) 164 | ``` 165 | 166 | ### 训练 `train.py` 167 | 168 | ```python 169 | import tensorflow as tf 170 | from tensorflow.examples.tutorials.mnist import input_data 171 | from model import Network 172 | 173 | 174 | class Train: 175 | def __init__(self): 176 | self.net = Network() 177 | 178 | # 初始化 session 179 | # Network() 只是构造了一张计算图,计算需要放到会话(session)中 180 | self.sess = tf.Session() 181 | # 初始化变量 182 | self.sess.run(tf.global_variables_initializer()) 183 | 184 | # 读取训练和测试数据,这是tensorflow库自带的,不存在训练集会自动下载 185 | # 项目目录下已经下载好,删掉后,重新运行代码会自动下载 186 | # data_set/train-images-idx3-ubyte.gz 187 | # data_set/train-labels-idx1-ubyte.gz 188 | # data_set/t10k-images-idx3-ubyte.gz 189 | # data_set/t10k-labels-idx1-ubyte.gz 190 | self.data = input_data.read_data_sets('../data_set', one_hot=True) 191 | 192 | def train(self): 193 | # batch_size 是指每次迭代训练,传入训练的图片张数。 194 | # 数据集小,可以使用全数据集,数据大的情况下, 195 | # 为了提高训练速度,用随机抽取的n张图片来训练,效果与全数据集相近 196 | # https://www.zhihu.com/question/32673260 197 | batch_size = 64 198 | 199 | # 总的训练次数 200 | train_step = 2000 201 | 202 | # 开始训练 203 | for i in range(train_step): 204 | # 从数据集中获取 输入和标签(也就是答案) 205 | x, label = self.data.train.next_batch(batch_size) 206 | # 每次计算train,更新整个网络 207 | # loss只是为了看到损失的大小,方便打印 208 | _, loss = self.sess.run([self.net.train, self.net.loss], 209 | feed_dict={self.net.x: x, self.net.label: label}) 210 | 211 | # 打印 loss,训练过程中将会看到,loss有变小的趋势 212 | # 代表随着训练的进行,网络识别图像的能力提高 213 | # 但是由于网络规模较小,后期没有明显下降,而是有明显波动 214 | if (i + 1) % 10 == 0: 215 | print('第%5d步,当前loss:%.2f' % (i + 1, loss)) 216 | ``` 217 | 218 | ### 验证准确率 `train.py` 219 | 220 | ```python 221 | class Train: 222 | def __init__(self): 223 | ... 224 | 225 | def train(self): 226 | ... 227 | 228 | def calculate_accuracy(self): 229 | test_x = self.data.test.images 230 | test_label = self.data.test.labels 231 | # 注意:与训练不同的是,并没有计算 self.net.train 232 | # 只计算了accuracy这个张量,所以不会更新网络 233 | # 最终准确率约为0.91 234 | accuracy = self.sess.run(self.net.accuracy, 235 | feed_dict={self.net.x: test_x, self.net.label: test_label}) 236 | print("准确率: %.2f,共测试了%d张图片 " % (accuracy, len(test_label))) 237 | ``` 238 | 239 | ### 主函数 `train.py` 240 | 241 | ```python 242 | if __name__ == "__main__": 243 | app = Train() 244 | app.train() 245 | app.calculate_accuracy() 246 | 247 | # 运行后,会打印出如下结果 248 | # 第 10步,当前loss:120.93 249 | # 第 20步,当前loss:90.38 250 | # 第 30步,当前loss:80.88 251 | # 第 40步,当前loss:71.23 252 | # 第 50步,当前loss:66.07 253 | # 第 60步,当前loss:55.83 254 | # 第 70步,当前loss:47.27 255 | # 第 80步,当前loss:45.42 256 | # 第 90步,当前loss:37.14 257 | # ... 258 | # 第 2000步,当前loss:21.75 259 | # 准确率: 0.91,共测试了10000张图片 260 | ``` 261 | 262 | > 项目已更新在[Github](https://github.com/geektutu/tensorflow-tutorial-samples),数据集由于国内网络等因素,有时候不能正确下载,所以数据集也一并同步了。 263 | 264 | **觉得还不错,不要吝惜你的[star](https://github.com/geektutu/tensorflow-tutorial-samples),支持是持续不断更新的动力。** 265 | 266 | ## 附 推荐 267 | 268 | - [一篇文章入门 Python](https://geektutu.com/post/quick-python.html) -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow-mnist-simplest/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/tensorflow/tensorflow-mnist-simplest/loss.png -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow-mnist-simplest/x_y.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/tensorflow/tensorflow-mnist-simplest/x_y.png -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow-mnist-tensorboard-training.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: TensorFlow入门(三) - mnist手写数字识别(可视化训练) 3 | date: 2018-03-29 11:51:24 4 | description: TensorFlow 入门系列文章,第三篇,mnist手写数字识别(可视化训练)。 5 | tags: 6 | - 机器学习 7 | - TensorFlow 8 | - mnist 9 | - Python 10 | nav: 简明教程 11 | categories: 12 | - TensorFlow 教程 13 | image: post/tensorflow-mnist-tensorboard-training/tensorboard_mnist_graph.png 14 | github: https://github.com/geektutu/tensorflow-tutorial-samples 15 | --- 16 | 17 | 这篇文章是 **TensorFlow Tutorial** 入门教程的第三篇文章。 18 | 19 | 在第一篇文章中,我们通过每10步打印出loss值的方式,来观察loss值的变化。但是这样做并不直观,有没有什么方式能够让整个训练的过程更加一目了然呢?tensorflow提供了tensorboard,专门来实现训练过程的可视化。 20 | 21 | ## tensorboard 22 | 23 | 为了更方便 tensorflow 程序的理解、调试与优化,Google发布了一套叫做 tensorboard 的可视化工具。我们可以用 tensorboard 来展现tensorflow 图像,绘制图像生成的定量指标图以及附加数据。 24 | 25 | tensorboard 是通过读取 tensorflow 的`事件文件`来运行的,因为,我们需要将可视化的数据写入事件文件,这个过程称为`Summary`即汇总数据,这样才能在tensorflow中看到期望的可视化效果,先展示下最终的效果。 26 | 27 | - **网络结构图** 28 | 29 | ![tensorboard_mnist_graph](tensorflow-mnist-tensorboard-training/tensorboard_mnist_graph.png) 30 | 31 | - **accuray与loss,bias与weight** 32 | 33 | ![tensorbord_mnist_loss](tensorflow-mnist-tensorboard-training/tensorbord_mnist_loss.png) 34 | 35 | ## summary的生命周期 36 | 37 | **1. 选择需要汇总数据的张量** 38 | 39 | ```python 40 | # 例如在识别 mnist 的网络中 41 | w = tf.Variable(tf.zeros([784, 10]), name="fc/weight") 42 | loss = -tf.reduce_sum(label * tf.log(y + 1e-10)) 43 | # 创建loss的summary node,scalar表示最后的数据会展示为标量曲线。 44 | tf.summary.scalar('loss', loss) 45 | # 创建w的summary node, 最后展示为直方图 46 | tf.summary.histogram('weight', w) 47 | ``` 48 | 49 | **2. 汇总数据的存放位置** 50 | 51 | ```python 52 | # 存放在当前目录下的 log 文件夹中,获得文件句柄 53 | merged_writer = tf.summary.FileWriter("./log", sess.graph) 54 | ``` 55 | 56 | **3. 生成汇总数据** 57 | 58 | > 在tensorflow中,所有的操作只有当你执行,或者另一个操作依赖于它的输出时才会运行。我们刚才创建的这些summary node没有任何操作依赖于它们的结果,因此是不会主动生成的汇总数据的。为了生成汇总信息,可以使用tf.merge_all_summaries来合并所有的summary node。 59 | 60 | ```python 61 | # 合并所有的summary node 62 | merged_summary_op = tf.summary.merge_all() 63 | # 训练时一起run 64 | _, loss, merged_summary = self.sess.run([train, loss, merged_summary_op], feed_dict={x: x, label: label}) 65 | ``` 66 | 67 | **4. 保存汇总数据到文件中** 68 | 69 | > 可以每一步数据都保存,但是一般没有这个必要,可以选择每100步保存一次。 70 | 71 | ```python 72 | if step % 100 == 0: 73 | merged_writer.add_summary(merged_summary, step) 74 | ``` 75 | 76 | ## 可视化mnist网络 77 | 78 | - **在模型(model.py)中选择需要summary的张量** 79 | 80 | ```python 81 | import tensorflow as tf 82 | 83 | 84 | class Network: 85 | def __init__(self): 86 | self.learning_rate = 0.001 87 | self.global_step = tf.Variable(0, trainable=False, name="global_step") 88 | 89 | self.x = tf.placeholder(tf.float32, [None, 784], name="x") 90 | self.label = tf.placeholder(tf.float32, [None, 10], name="label") 91 | 92 | self.w = tf.Variable(tf.zeros([784, 10]), name="fc/weight") 93 | self.b = tf.Variable(tf.zeros([10]), name="fc/bias") 94 | self.y = tf.nn.softmax(tf.matmul(self.x, self.w) + self.b, name="y") 95 | 96 | self.loss = -tf.reduce_sum(self.label * tf.log(self.y + 1e-10)) 97 | self.train = tf.train.GradientDescentOptimizer(self.learning_rate).minimize( 98 | self.loss, global_step=self.global_step) 99 | 100 | predict = tf.equal(tf.argmax(self.label, 1), tf.argmax(self.y, 1)) 101 | self.accuracy = tf.reduce_mean(tf.cast(predict, "float")) 102 | 103 | # 创建 summary node 104 | # w, b 画直方图 105 | # loss, accuracy画标量图 106 | tf.summary.histogram('weight', self.w) 107 | tf.summary.histogram('bias', self.b) 108 | tf.summary.scalar('loss', self.loss) 109 | tf.summary.scalar('accuracy', self.accuracy) 110 | ``` 111 | 112 | - **训练时保存summary的数据** 113 | 114 | ```python 115 | import tensorflow as tf 116 | from tensorflow.examples.tutorials.mnist import input_data 117 | from model import Network 118 | 119 | CKPT_DIR = 'ckpt' 120 | 121 | class Train: 122 | def __init__(self): 123 | self.net = Network() 124 | self.sess = tf.Session() 125 | self.sess.run(tf.global_variables_initializer()) 126 | self.data = input_data.read_data_sets('../data_set', one_hot=True) 127 | 128 | def train(self): 129 | batch_size = 64 130 | train_step = 20000 131 | step = 0 132 | save_interval = 1000 133 | saver = tf.train.Saver(max_to_keep=5) 134 | 135 | # merge所有的summary node 136 | merged_summary_op = tf.summary.merge_all() 137 | # 可视化存储目录为当前文件夹下的 log 138 | merged_writer = tf.summary.FileWriter("./log", self.sess.graph) 139 | 140 | ckpt = tf.train.get_checkpoint_state(CKPT_DIR) 141 | if ckpt and ckpt.model_checkpoint_path: 142 | saver.restore(self.sess, ckpt.model_checkpoint_path) 143 | # 读取网络中的global_step的值,即当前已经训练的次数 144 | step = self.sess.run(self.net.global_step) 145 | print('Continue from') 146 | print(' -> Minibatch update : ', step) 147 | 148 | while step < train_step: 149 | x, label = self.data.train.next_batch(batch_size) 150 | _, loss, merged_summary = self.sess.run( 151 | [self.net.train, self.net.loss, merged_summary_op], 152 | feed_dict={self.net.x: x, self.net.label: label} 153 | ) 154 | step = self.sess.run(self.net.global_step) 155 | 156 | if step % 100 == 0: 157 | merged_writer.add_summary(merged_summary, step) 158 | 159 | if step % save_interval == 0: 160 | saver.save(self.sess, CKPT_DIR + '/model', global_step=step) 161 | print('%s/model-%d saved' % (CKPT_DIR, step)) 162 | 163 | app = Train() 164 | app.train() 165 | ``` 166 | 167 | ## 启动tensorboard 168 | 169 | 启动前,需要先训练网络,训练过程中,数据会每隔100步写入log文件夹下的文件中,这个时候,可以启动tensorboard(随tensorflow安装,不用单独安装) 170 | 171 | ```shell 172 | tensorboard --logdir=./log 173 | 或 174 | python -m tensorboard.main --logdir=./log 175 | ``` 176 | 177 | `./log`是summary数据存储的路径,即在tf.summary.FileWriter中传入的路径。tensorboard 开始运行后,在浏览器中输入 `localhost:6006` 即可看到本文最开始的效果。 178 | 179 | **觉得还不错,不要吝惜你的[star](https://github.com/geektutu/tensorflow-tutorial-samples),支持是持续不断更新的动力。** 180 | 181 | ## 附 推荐 182 | 183 | - [一篇文章入门 Python](https://geektutu.com/post/quick-python.html) -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow-mnist-tensorboard-training/tensorboard_mnist_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/tensorflow/tensorflow-mnist-tensorboard-training/tensorboard_mnist_graph.png -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow-mnist-tensorboard-training/tensorbord_mnist_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/tensorflow/tensorflow-mnist-tensorboard-training/tensorbord_mnist_loss.png -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow2-gym-dqn.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: TensorFlow 2.0 (八) - 强化学习 DQN 玩转 gym Mountain Car 3 | date: 2019-06-28 00:30:24 4 | description: TensorFlow2教程,TensorFlow2.0教程,TensorFlow 2.0 入门系列文章,第八篇,强化学习 DQN (Deep Q-Learning) 玩转 OpenAI gym game MountainCar-v0。 5 | tags: 6 | - 强化学习 7 | - TensorFlow 8 | - OpenAI gym 9 | - Python 10 | - DQN 11 | keywords: 12 | - TensorFlow2教程 13 | - TensorFlow2.0教程 14 | - Double DQN 15 | - DDQN 16 | nav: 简明教程 17 | categories: 18 | - TensorFlow 教程 19 | image: post/tensorflow2-gym-dqn/mountaincar_v0_scores.jpg 20 | github: https://github.com/geektutu/tensorflow-tutorial-samples 21 | --- 22 | 23 | ![Geektutu Q-Learning MountainCar Failed](tensorflow2-gym-dqn/mountaincar_v0_success.gif) 24 | 25 | 29 | 30 | 这篇文章是 **TensorFlow 2.0 Tutorial** 入门教程的第八篇文章。 31 | 32 | 实现DQN(Deep Q-Learning Network)算法,**代码90行** 33 | 34 | ## MountainCar 简介 35 | 36 | 上一篇文章[TensorFlow 2.0 (七) - 强化学习 Q-Learning 玩转 OpenAI gym](https://geektutu.com/post/tensorflow2-gym-q-learning.html)介绍了如何用**Q表(Q-Table)**,来更新策略,使小车顺利达到山顶,整个代码只有50行。我们先回顾一下上一篇文章的要点。 37 | 38 | 1. **MountainCar-v0 的游戏目标** 39 | 40 | 向左/向右推动小车,小车若到达山顶,则游戏胜利,若200回合后,没有到达山顶,则游戏失败。每走一步得-1分,最低分-200,越早到达山顶,则分数越高。 41 | 42 | 2. **MountainCar-v0 的几个重要的变量** 43 | 44 | - State: [position, velocity],position 范围 [-0.6, 0.6],velocity 范围 [-0.1, 0.1] 45 | - Action: 0(向左推) 或 1(不动) 或 2(向右推) 46 | - Reward: -1 47 | - Done: 小车到达山顶或已花费200回合 48 | 49 | 50 | 3. **Q-Table 的更新方程** 51 | 52 | > Q[s][a] = (1 - lr) * Q[s][a] + lr * (reward + factor * max(Q[next_s])) 53 | 54 | ## 神经网络替换 Q-Table 55 | 56 | 这一篇文章,我们将借助`TensorFlow 2.0`中的`keras`库,搭建深度神经网络(Deep Netural Network, DNN),替代`Q-Table`,即**深度Q网络(Deep Q-Learning Network, DQN)**,实现Q值的计算。 57 | 58 | 我们将神经网络比作一个函数,神经网络代替`Q-Table`其实就是在做 **函数拟合**,也可以称为**值函数近似(Value Function Approximation)**。 59 | 60 | 维基百科上有一个**万能近似定理(Universal approximation theorem)**,[Universal approximation theorem](https://en.wikipedia.org/wiki/Universal_approximation_theorem)定理表明:前馈神经网络,只需具备单层隐含层和有限个神经单元,就能以任意精度拟合任意复杂度的函数。 61 | 62 | 这是我们上篇文章定义的`Q-Table` 63 | 64 | ```python 65 | Q = defaultdict(lambda: [0, 0, 0]) 66 | ``` 67 | 68 | - 输入(key): 一维向量,(position, velocity) 69 | - 输出(value):一维向量,(action0_value, action1_value, action2_value) 70 | 71 | 72 | 接下来那我们按照定义的输入输出,简单搭一个神经网络吧。 73 | 74 | ```python 75 | # dqn.py 76 | # https://geektutu.com 77 | from collections import deque 78 | import random 79 | import gym 80 | import numpy as np 81 | from tensorflow.keras import models, layers, optimizers 82 | 83 | class DQN(object): 84 | def __init__(self): 85 | self.step = 0 86 | self.update_freq = 200 # 模型更新频率 87 | self.replay_size = 2000 # 训练集大小 88 | self.replay_queue = deque(maxlen=self.replay_size) 89 | self.model = self.create_model() 90 | self.target_model = self.create_model() 91 | 92 | def create_model(self): 93 | """创建一个隐藏层为100的神经网络""" 94 | STATE_DIM, ACTION_DIM = 2, 3 95 | model = models.Sequential([ 96 | layers.Dense(100, input_dim=STATE_DIM, activation='relu'), 97 | layers.Dense(ACTION_DIM, activation="linear") 98 | ]) 99 | model.compile(loss='mean_squared_error', 100 | optimizer=optimizers.Adam(0.001)) 101 | return model 102 | 103 | def act(self, s, epsilon=0.1): 104 | """预测动作""" 105 | # 刚开始时,加一点随机成分,产生更多的状态 106 | if np.random.uniform() < epsilon - self.step * 0.0002: 107 | return np.random.choice([0, 1, 2]) 108 | return np.argmax(self.model.predict(np.array([s]))[0]) 109 | 110 | def save_model(self, file_path='MountainCar-v0-dqn.h5'): 111 | print('model saved') 112 | self.model.save(file_path) 113 | ``` 114 | 115 | 网络结构很简单,只有一层隐藏层的全连接网络(Full Connected Network, FC)。但是我们用这个网络结构生成了2个model,一个是**预测**使用的`model`,另一个是训练时使用的`target_model`。看完下面的代码,就容易理解了。 116 | 117 | ```python 118 | # dqn.py 119 | # https://geektutu.com 120 | class DQN(object): 121 | # 省略 __init__, create_model, act, save_model,见上。 122 | def remember(self, s, a, next_s, reward): 123 | """历史记录,position >= 0.4时给额外的reward,快速收敛""" 124 | if next_s[0] >= 0.4: 125 | reward += 1 126 | self.replay_queue.append((s, a, next_s, reward)) 127 | 128 | def train(self, batch_size=64, lr=1, factor=0.95): 129 | if len(self.replay_queue) < self.replay_size: 130 | return 131 | self.step += 1 132 | # 每 update_freq 步,将 model 的权重赋值给 target_model 133 | if self.step % self.update_freq == 0: 134 | self.target_model.set_weights(self.model.get_weights()) 135 | 136 | replay_batch = random.sample(self.replay_queue, batch_size) 137 | s_batch = np.array([replay[0] for replay in replay_batch]) 138 | next_s_batch = np.array([replay[2] for replay in replay_batch]) 139 | 140 | Q = self.model.predict(s_batch) 141 | Q_next = self.target_model.predict(next_s_batch) 142 | 143 | # 使用公式更新训练集中的Q值 144 | for i, replay in enumerate(replay_batch): 145 | _, a, _, reward = replay 146 | Q[i][a] = (1 - lr) * Q[i][a] + lr * (reward + factor * np.amax(Q_next[i])) 147 | 148 | # 传入网络进行训练 149 | self.model.fit(s_batch, Q, verbose=0) 150 | ``` 151 | 152 | 整个结构如下图所示: 153 | 154 | ![Geektutu DQN](tensorflow2-gym-dqn/dqn.jpg) 155 | 156 | 我们需要用到上文提到的更新方程,来构造训练数据。其中`Q_next`是对`next_s`的预测值,在这里其实也可以使用`model`,但是`model`变化得太过频繁,而且我们在训练时,是以**batch**为单位进行训练的,也就是说很多训练数据对应的是之前状态的model,而不是频繁更新值的`model`,因此,我们使用更新频率低的`target_model`来计算`next_s`的Q值。 157 | 158 | 同时使用2个Q-Network的算法被称为**双Q网络(Double DQN, DDQN)**。因为传统的DQN普遍会过高估计Action的Q值,误差会随着Action的增加而增加,可能导致某个次优的值超过了最优Action的Q值,永远无法找到最优解。`DDQN`能够有效地解决这个问题。DQN 在比较简单的游戏,比如**CartPole-v0**能够取得较好的效果,但在**MountainCar-v0**这个游戏中,如果只使用 DQN 很难找到最优解。 159 | 160 | `target_model`每训练update_freq(200)次,更新权重与`model`一致。 161 | 162 | 那为什么在`Q-Table`中,可以用单步的数据来进行更新,但换作了神经网络,就需要以**batch**为单位来进行训练呢?这个问题在知乎有过讨论,链接在这里:[深度学习中的batch的大小对学习效果有何影响?](https://www.zhihu.com/question/32673260),简单说,如果单步训练,即**batch**为1,每次朝着单步的梯度方向修正,横冲直撞各自为政,难以收敛。如果**batch**过大,容易过拟合。而且`DQN`是增强学习算法,前面的训练数据质量较差,随着训练的进行,产生的动作价值越来越高,增强学习更为看重后面的训练数据,所以**batch**也不宜过大。 163 | 164 | 而这一点,也是`replay_queue`的最大容量设置为**2000**的原因。队列有先进先出的特性,当后面的数据加进来后,如果数据条数超过2000,前面的数据就会从队列中移除。后面的训练数据对于强化学习更重要。 165 | 166 | ## 可改动的 Reward 167 | 168 | 代码中还有这么一个细节: 169 | 170 | ```python 171 | if next_s[0] >= 0.4: 172 | reward += 1 173 | ``` 174 | 175 | `MountainCar-v0`这个游戏中,`State`由2个值构成,(position, velocity)。山顶的位置是**0.5**,因此当**position**大于**0.4**时,给`Reward`额外加**1**。这么做,是希望加快神经网络的收敛,更快地达到预期结果。每一步的`Reward`其实都是可以调整的,怎么做会让训练效果更好,可以动动脑,尝试尝试。 176 | 177 | ## 提前终止的 DQN 训练 178 | 179 | 好,神经网络已经准备就绪,接下来就开始训练吧。 180 | 181 | ```python 182 | # dqn.py 183 | # https://geektutu.com 184 | env = gym.make('MountainCar-v0') 185 | episodes = 1000 # 训练1000次 186 | score_list = [] # 记录所有分数 187 | agent = DQN() 188 | for i in range(episodes): 189 | s = env.reset() 190 | score = 0 191 | while True: 192 | a = agent.act(s) 193 | next_s, reward, done, _ = env.step(a) 194 | agent.remember(s, a, next_s, reward) 195 | agent.train() 196 | score += reward 197 | s = next_s 198 | if done: 199 | score_list.append(score) 200 | print('episode:', i, 'score:', score, 'max:', max(score_list)) 201 | break 202 | # 最后10次的平均分大于 -160 时,停止并保存模型 203 | if np.mean(score_list[-10:]) > -160: 204 | agent.save_model() 205 | break 206 | env.close() 207 | ``` 208 | 209 | 如果看过[TensorFlow 2.0 (六) - 监督学习玩转 OpenAI gym game](https://geektutu.com/post/tensorflow2-gym-nn.html)和[TensorFlow 2.0 (七) - 强化学习 Q-Learning 玩转 OpenAI gym](https://geektutu.com/post/tensorflow2-gym-q-learning.html)这两篇文章的话,这部分代码就非常简单了。 210 | 211 | 我们在训练过程中,记录了每一次游戏的分数。并且,如果最近10次的平均分高于**-160**时,结束训练,并保存模型。 212 | 213 | 运行一下,看看效果吧。 214 | 215 | ```bash 216 | $ python dqn.py 217 | episode: 0 score: -200.0 max: -200.0 218 | episode: 1 score: -200.0 max: -200.0 219 | episode: 2 score: -200.0 max: -200.0 220 | ... 221 | episode: 124 score: -200.0 max: -200.0 222 | episode: 125 score: -138.0 max: -138.0 223 | ... 224 | episode: 166 score: -144.0 max: -97.0 225 | episode: 167 score: -166.0 max: -97.0 226 | episode: 168 score: -136.0 max: -97.0 227 | model saved 228 | ``` 229 | 230 | 可以看到,在第**125次**时,首次成功爬到了山顶,在**168次**的时候,平均分达到预期,停止了训练。 231 | 232 | ## 训练效果绘图 233 | 234 | 接下来,我们添加3行代码,将整个训练过程中的`score_list`的变化情况画出来,直观感受强化学习的学习过程。 235 | 236 | ```python 237 | import matplotlib.pyplot as plt 238 | 239 | plt.plot(score_list, color='green') 240 | plt.show() 241 | ``` 242 | 243 | ![geektutu MountainCar v0 scores](tensorflow2-gym-dqn/mountaincar_v0_scores.jpg) 244 | 245 | ## 模型预测/测试 246 | 247 | 和之前一样,准备了一个非常简单的可视化的测试代码,直观地感受下最终的游戏效果。 248 | 249 | ```python 250 | # test_dqn.py 251 | # https://geektutu.com 252 | import time 253 | import gym 254 | import numpy as np 255 | from tensorflow.keras import models 256 | env = gym.make('MountainCar-v0') 257 | model = models.load_model('MountainCar-v0-dqn.h5') 258 | s = env.reset() 259 | score = 0 260 | while True: 261 | env.render() 262 | time.sleep(0.01) 263 | a = np.argmax(model.predict(np.array([s]))[0]) 264 | s, reward, done, _ = env.step(a) 265 | score += reward 266 | if done: 267 | print('score:', score) 268 | break 269 | env.close() 270 | ``` 271 | 272 | 运行一下,还不错~ 273 | 274 | ```bash 275 | $ python test_dqn.py 276 | score: -161.0 277 | ``` 278 | 279 | ![Geektutu Q-Learning MountainCar Success](tensorflow2-gym-dqn/mountaincar_v0_success.gif) 280 | 281 | 代码已经上传到[Github - tensorflow-tutorial-samples](https://github.com/geektutu/tensorflow-tutorial-samples/tree/master/gym/MountainCar-v0-dqn),**dqn.py**只有90行,不妨试一试吧~ 282 | 283 | ## 附 推荐 284 | 285 | - [一篇文章入门 Python](https://geektutu.com/post/quick-python.html) 286 | -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow2-gym-dqn/dqn.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/tensorflow/tensorflow2-gym-dqn/dqn.jpg -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow2-gym-dqn/mountaincar_v0_scores.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/tensorflow/tensorflow2-gym-dqn/mountaincar_v0_scores.jpg -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow2-gym-dqn/mountaincar_v0_success.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/tensorflow/tensorflow2-gym-dqn/mountaincar_v0_success.gif -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow2-gym-nn.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: TensorFlow 2.0 (六) - 监督学习玩转 OpenAI gym game 3 | date: 2019-06-21 00:10:20 4 | description: TensorFlow2教程,TensorFlow2.0教程,TensorFlow 入门系列文章,第六篇,监督学习玩转 OpenAI gym game。 5 | tags: 6 | - 机器学习 7 | - TensorFlow 8 | - OpenAI gym 9 | - Python 10 | nav: 简明教程 11 | categories: 12 | - TensorFlow 教程 13 | keywords: 14 | - TensorFlow2教程 15 | - TensorFlow2.0教程 16 | image: post/tensorflow2-gym-nn/cartpole_v0_failed.gif 17 | github: https://github.com/geektutu/tensorflow-tutorial-samples 18 | --- 19 | 20 | ![gym cartpole-v0 failed](tensorflow2-gym-nn/cartpole_v0_failed.gif) 21 | 22 | 26 | 27 | 这篇文章是 **TensorFlow 2.0 Tutorial** 入门教程的第六篇文章,介绍如何使用 **TensorFlow 2.0** 搭建神经网络(Neural Network, NN),使用纯监督学习(Supervised Learning)的方法,玩转 OpenAI gym game。示例代码基于 Python 3 和 TensorFlow 2.0 。 28 | 29 | [OpenAI gym](https://gym.openai.com/)是一个开源的游戏模拟环境,主要用来开发和比较强化学习(Reinforcement Learning, RL)的算法。这篇文章是 Tensorflow 2.0 系列使用 gym 的第一篇文章,网上介绍强化学习玩 gym 的文章比较多,而纯监督学习的文章极少。我们先使用纯监督学习的算法,一起感受 gym 的魅力吧。 30 | 31 | ## 如何安装 32 | 33 | ```bash 34 | pip install tensorflow==2.0.0-beta0 35 | pip install gym 36 | # 如果你有多个Python环境,需要指定 37 | # python3 -m pip install tensorflow==2.0.0-beta0 38 | # python3 -m pip install gym 39 | ``` 40 | 41 | ## OpenAI gym 初尝试 42 | 43 | 我们先对 OpenAI 的 gym 库的几个核心概念作个简单介绍。 44 | 45 | 想象一下你在玩贪吃蛇,你需要分析当前游戏的`状态(State)`,例如你所处的位置,周围的障碍物等,才能够决定下一步的`动作(Action)`,上下左右。那你每走一步,就会得到一个`奖励(Reward)`。这个奖励可能是正向奖励(Positive Reward),也可能是负向奖励(Negative Reward),比如撞到了障碍物。重复N次这样的过程,直到游戏`结束(Done)`。 46 | 47 | 从整个例子中,可以总结出几个重要的概念,接下来的示例将会使用 OpenAI gym 库提供的 **CartPole Game** 环境,一起来熟悉CartPole 游戏中的这几个概念的含义吧。先直接给一个可以运行看效果的示例,这个示例中,Action 是随机选择的。 48 | 49 | ```python 50 | # try_gym.py 51 | # https://geektutu.com 52 | import gym # 0.12.5 53 | import random 54 | import time 55 | 56 | env = gym.make("CartPole-v0") # 加载游戏环境 57 | 58 | state = env.reset() 59 | score = 0 60 | while True: 61 | time.sleep(0.1) 62 | env.render() # 显示画面 63 | action = random.randint(0, 1) # 随机选择一个动作 0 或 1 64 | state, reward, done, _ = env.step(action) # 执行这个动作 65 | score += reward # 每回合的得分 66 | if done: # 游戏结束 67 | print('score: ', score) # 打印分数 68 | break 69 | env.close() 70 | ``` 71 | 72 | ```bash 73 | $ python3 try_gym.py 74 | score: 14.0 75 | ``` 76 | 77 | | 概念 | 解释 | 示例 | 78 | | ----- | -------------------------------------------------- | ---------------------------- | 79 | | State | list:状态,[车位置, 车速度, 杆角度, 杆速度] | 0.02,0.95,-0.07,-1.53| 80 | | Action | int:动作(0向左/1向右) | 1 | 81 | | Reward | float:奖励(每走一步得1分) | 1.0 | 82 | | Done | bool:是否结束(True/False),上限200回合 | False | 83 | 84 | 游戏上限是200回合,但是如果是随机选择 Action,就只得了14分,游戏就结束了。 85 | 86 | ## 搭建神经网络 87 | 88 | 我们的目的就是将随机选择 Action 的部分,变为由神经网络模型来选择。神经网络的输入是`State`,输出是`Action`。在这里,Action 用独热编码来表示,即 **[1, 0]** 表示向左,**[0, 1]** 表示向右。这样我们可以方便地使用`np.argmax()`获取预测的 Action 的值。 89 | 90 | ```python 91 | np.argmax([0.3, 0.7]) # 1,假如神经网络的输出是 [0.3, 0.7],那Action值为1,表示向右。 92 | np.argmax([0.8, 0.2]) # 0,表示向右。 93 | ``` 94 | 95 | 接下来我们搭建一个 `4 x 64 x 20 x 2` 的网络,输入层为4,输出层为2。 96 | 97 | ```python 98 | # train.py 99 | # https://geektutu.com 100 | import random 101 | import gym 102 | import numpy as np 103 | from tensorflow.keras import models, layers 104 | 105 | env = gym.make("CartPole-v0") # 加载游戏环境 106 | 107 | STATE_DIM, ACTION_DIM = 4, 2 # State 维度 4, Action 维度 2 108 | model = models.Sequential([ 109 | layers.Dense(64, input_dim=STATE_DIM, activation='relu'), 110 | layers.Dense(20, activation='relu'), 111 | layers.Dense(ACTION_DIM, activation='linear') 112 | ]) 113 | model.summary() # 打印神经网络信息 114 | ``` 115 | 116 | ## 训练数据从哪里来? 117 | 118 | 神经网络的模型搭好了,那训练数据呢? 119 | 120 | 随机产生的数据,得分很低,如果不过滤,数据集质量是很低的。 121 | 122 | 最终的办法:**试,一百次不行,就试一万次**。 123 | 124 | 简而言之,我们在过程中计算`Score`,如果最终得分达到设定的标准,这个分数所对应的所有`State`和`Action`就可以作为我们的训练数据了。 125 | 126 | ```python 127 | # train.py 128 | def generate_data_one_episode(): 129 | '''生成单次游戏的训练数据''' 130 | x, y, score = [], [], 0 131 | state = env.reset() 132 | while True: 133 | action = random.randrange(0, 2) 134 | x.append(state) 135 | y.append([1, 0] if action == 0 else [0, 1]) # 记录数据 136 | state, reward, done, _ = env.step(action) # 执行动作 137 | score += reward 138 | if done: 139 | break 140 | return x, y, score 141 | 142 | def generate_training_data(expected_score=100): 143 | '''# 生成N次游戏的训练数据,并进行筛选,选择 > 100 的数据作为训练集''' 144 | data_X, data_Y, scores = [], [], [] 145 | for i in range(10000): 146 | x, y, score = generate_data_one_episode() 147 | if score > expected_score: 148 | data_X += x 149 | data_Y += y 150 | scores.append(score) 151 | print('dataset size: {}, max score: {}'.format(len(data_X), max(scores))) 152 | return np.array(data_X), np.array(data_Y) 153 | ``` 154 | 155 | 这样,我们就可以使用`generate_training_data`函数生成训练集了。 156 | 157 | ## 训练并保存模型 158 | 159 | 神经网络和数据集都准备好了,训练就非常简单了。 160 | 161 | ```python 162 | # train.py 163 | data_X, data_Y = generate_training_data() 164 | model.compile(loss='mse', optimizer='adam', epochs=5) 165 | model.fit(data_X, data_Y) 166 | model.save('CartPole-v0-nn.h5') # 保存模型 167 | ``` 168 | 169 | 从运行的结果看,我们最终得到的训练集大小为213,最大分数是108分。 170 | 171 | ```bash 172 | $ python train.py 173 | Model: "sequential" 174 | _________________________________________________________________ 175 | Layer (type) Output Shape Param # 176 | ================================================================= 177 | dense (Dense) (None, 64) 320 178 | _________________________________________________________________ 179 | dense_1 (Dense) (None, 20) 1300 180 | _________________________________________________________________ 181 | dense_2 (Dense) (None, 2) 42 182 | ================================================================= 183 | Total params: 1,662 184 | Trainable params: 1,662 185 | Non-trainable params: 0 186 | _________________________________________________________________ 187 | dataset size: 213, max score: 108.0 188 | Train on 213 samples 189 | Epoch 1/5 190 | 213/213 [==============================] - 0s 713us/sample - loss: 0.4701 191 | Epoch 2/5 192 | 213/213 [==============================] - 0s 35us/sample - loss: 0.3920 193 | Epoch 3/5 194 | 213/213 [==============================] - 0s 38us/sample - loss: 0.3370 195 | Epoch 4/5 196 | 213/213 [==============================] - 0s 39us/sample - loss: 0.2985 197 | Epoch 5/5 198 | 213/213 [==============================] - 0s 38us/sample - loss: 0.2745 199 | ``` 200 | 201 | ## 模型测试/预测 202 | 203 | ```python 204 | # predict.py 205 | # https://geektutu.com 206 | import time 207 | import numpy as np 208 | import gym 209 | from tensorflow.keras import models 210 | 211 | 212 | saved_model = models.load_model('CartPole-v0-nn.h5') # 加载模型 213 | env = gym.make("CartPole-v0") # 加载游戏环境 214 | 215 | for i in range(5): 216 | state = env.reset() 217 | score = 0 218 | while True: 219 | time.sleep(0.01) 220 | env.render() # 显示画面 221 | action = np.argmax(saved_model.predict(np.array([state]))[0]) # 预测动作 222 | state, reward, done, _ = env.step(action) # 执行这个动作 223 | score += reward # 每回合的得分 224 | if done: # 游戏结束 225 | print('using nn, score: ', score) # 打印分数 226 | break 227 | env.close() 228 | ``` 229 | 230 | ```bash 231 | $ python predict.py 232 | using nn, score: 200.0 233 | using nn, score: 200.0 234 | using nn, score: 200.0 235 | using nn, score: 200.0 236 | using nn, score: 200.0 237 | ``` 238 | 239 | 模型的结果很不错,每一次都达到了200的满分。 240 | 241 | 看看效果吧~ 242 | 243 | ![gym cartpole-v0 success](tensorflow2-gym-nn/cartpole_v0_success.gif) 244 | 245 | 在[Github - tensorflow-tutorial-samples](https://github.com/geektutu/tensorflow-tutorial-samples/tree/master/gym/CartPole-v0-nn)上提供了`.py`和`.ipynb`2种格式的代码。 246 | 247 | ## 附 推荐 248 | 249 | - [一篇文章入门 Python](https://geektutu.com/post/quick-python.html) -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow2-gym-nn/cartpole_v0_failed.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/tensorflow/tensorflow2-gym-nn/cartpole_v0_failed.gif -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow2-gym-nn/cartpole_v0_success.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/tensorflow/tensorflow2-gym-nn/cartpole_v0_success.gif -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow2-gym-pg.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: TensorFlow 2.0 (九) - 强化学习 70行代码实战 Policy Gradient 3 | date: 2019-07-06 17:30:00 4 | description: TensorFlow2教程,TensorFlow2.0教程,TensorFlow 2.0 入门系列文章,第九篇,使用强化学习算法策略梯度(Policy Gradient),实战 OpenAI gym CartPole,代码仅70行。 5 | tags: 6 | - 强化学习 7 | - TensorFlow 8 | - OpenAI gym 9 | - Python 10 | keywords: 11 | - Policy Gradient 12 | - 策略梯度 13 | - TensorFlow2教程 14 | - TensorFlow2.0教程 15 | nav: 简明教程 16 | categories: 17 | - TensorFlow 教程 18 | image: post/tensorflow2-gym-pg/pg_plot.jpg 19 | github: https://github.com/geektutu/tensorflow-tutorial-samples 20 | --- 21 | 22 | ![Geektutu Policy Gradient Success](tensorflow2-gym-pg/pg_success.gif) 23 | 24 | 28 | 29 | 这篇文章是 **TensorFlow 2.0 Tutorial** 入门教程的第九篇文章。 30 | 31 | 实战策略梯度算法(Policy Gradient),**代码70行** 32 | 33 | ## CartPole 简介 34 | 35 | 在之前的文章中,我们使用过纯监督学习的算法,强化学习算法中的Q学习(Q-Learning)和深度Q网络(Deep Q-learning Network, DQN),这一篇文章,我们选择策略梯度算法(Policy Gradient),来玩一玩 CartPole。 36 | 37 | 先回顾一下**CartPole-v0**的几个重要概念。 38 | 39 | | 概念 | 解释 | 示例 | 40 | | ----- | -------------------------------| -------------| 41 | | State | 状态,[车位置, 车速度, 杆角度, 杆速度] | 0.02, 0.95, -0.07, -1.53| 42 | | Action | 动作(0向左/1向右) | 1 | 43 | | Reward | 奖励(每走一步得1分) | 1.0 | 44 | 45 | 我们在 [TensorFlow 2.0 (八) - 强化学习 DQN 玩转 gym Mountain Car](https://geektutu.com/post/tensorflow2-gym-dqn.html)这篇文章中,介绍了基于价值(value-based)的强化学习算法 DQN,在 DQN 中,神经网络的输入是状态,输出是每一个动作的价值。每一次从所有可行的动作中选择Q值最大的执行。我们使用了一个公式来不断地计算期望的Q值,训练神经网络。 46 | 47 | 那有没有可能,直接输出动作呢?这就是我们今天要介绍的基于策略(policy-based)的策略梯度算法(Policy Gradient)。 48 | 49 | 本文不涉及数学推导,仅介绍如何高效实现。如对该算法该兴趣,推荐 `Medium` 上有3.8k点赞的一篇文章[An introduction to Policy Gradients](https://medium.com/free-code-camp/an-introduction-to-policy-gradients-with-cartpole-and-doom-495b5ef2207f)。 50 | 51 | ## 搭建神经网络 52 | 53 | Policy Gradient 网络的输入也是状态(State),那输出呢?每个动作的概率。例如 `[0.7, 0.3]` ,这意味着有70%的几率会选择动作0,30%的几率选择动作1。相对于 Policy Gradient,DQN 的动作更确定,因为 DQN 每次总是选择Q值最大的动作,而Policy Gradient 按照概率选择,会产生更多的不确定性。 54 | 55 | 废话不多说,神经网络先搭起来吧~ 56 | 57 | ```python 58 | # policy_gradient.py 59 | # https://geektutu.com 60 | import matplotlib.pyplot as plt 61 | import gym 62 | import numpy as np 63 | from tensorflow.keras import models, layers, optimizers 64 | 65 | env = gym.make('CartPole-v0') 66 | 67 | STATE_DIM, ACTION_DIM = 4, 2 68 | model = models.Sequential([ 69 | layers.Dense(100, input_dim=STATE_DIM, activation='relu'), 70 | layers.Dropout(0.1), 71 | layers.Dense(ACTION_DIM, activation="softmax") 72 | ]) 73 | model.compile(loss='mean_squared_error', 74 | optimizer=optimizers.Adam(0.001)) 75 | ``` 76 | 77 | 我们的神经网络很简单,输入层为4,输出层为2,隐藏层为100。不过这次代码多了一个`Dropout`,Dropout(0.1) 的含义是,随机忘记10%的权重。学习初期,一开始的数据质量不高,随着学习的进行,质量才逐步高了起来,一开始容易陷入**局部最优**和**过拟合**,使用 Dropout 可以有效避免。 78 | 79 | 如何选择动作呢?前文已经介绍,按照概率。 80 | 81 | ```python 82 | # policy_gradient.py 83 | # https://geektutu.com 84 | def choose_action(s): 85 | """预测动作""" 86 | prob = model.predict(np.array([s]))[0] 87 | return np.random.choice(len(prob), p=prob) 88 | ``` 89 | 90 | ## 优化策略 91 | 92 | 接下来是最大的问题,如何优化策略呢? 93 | 94 | ![Policy Gradient Optimize](tensorflow2-gym-pg/pg_optimize.jpg) 95 | 96 | ### 1) 衰减的累加期望 97 | 98 | 我们先想象一下,假如你在玩坦克大战,你的每一步都会对后面的局势产生巨大的影响。比如,敌方攻打你的老巢,你是选择先消灭敌方呢,还是选择坐视不理?很可能一步就决定了结局。因此,需要从整个回合的角度看待这个问题。先引入一个概念 `带衰减reward的累加期望` 。 99 | 100 | > discount_reward[i] = reward[i] + gamma * discount_reward[i+1] 101 | 102 | 某一步的累加期望等于下一步的累加期望乘衰减系数`gamma`,加上`reward`。 103 | 104 | 手工算一算。 105 | 106 | ``` 107 | 最后一步:1 108 | 倒数第二步:1 + 0.95 * 1 = 1.95 109 | 倒数第三步:1 + 0.95 * 1.95 = 2.8525 110 | 倒数第四步:1 + 0.95 * 2.8525 = 3.709875 111 | ``` 112 | 113 | 假设某个回合只得了10分,那么这个回合的每一步的累加期望都不会高。假设得到了满分200分,那么回合中的大部分步骤的累加期望很会很高,越是前面的步骤,累加期望越高。 114 | 115 | 代码实现就很简单了,唯一的不同是最后加了中心化和标准化的处理。这样处理的目的是希望得到相同尺度的数据,避免因为数值相差过大而导致网络无法收敛。 116 | 117 | ```python 118 | # policy_gradient.py 119 | # https://geektutu.com 120 | def discount_rewards(rewards, gamma=0.95): 121 | """计算衰减reward的累加期望,并中心化和标准化处理""" 122 | prior = 0 123 | out = np.zeros_like(rewards) 124 | for i in reversed(range(len(rewards))): 125 | prior = prior * gamma + rewards[i] 126 | out[i] = prior 127 | return out / np.std(out - np.mean(out)) 128 | ``` 129 | 130 | ### 2) 给loss加权重 131 | 132 | 一个动作的`累加期望`很高,自然希望该动作出现的概率变大,这就是学习的目的。一般,我们通过构造**标签(y_true/label)**,来训练神经网络。就如在[TensorFlow 2.0 (六) - 监督学习玩转 OpenAI gym game](https://geektutu.com/post/tensorflow2-gym-nn.html)这篇文章中做的一样。当然,我们还可以通过改变**损失函数(loss function)**达到目的。对于累加期望大的动作,可以放大`loss`的值,而对于累加期望小的动作,那么就减小loss的值。这样呢?神经网络就能快速朝着累加期望大的方向优化了。最简单的方法,给`loss`加一个权重。 133 | 134 | 所以我们的最终的损失函数就变成了: 135 | 136 | > loss = discount_reward * loss 137 | 138 | 这里的`discount_reward`可以理解为策略梯度算法(Policy Gradient)中的梯度(Gradient)。如果对梯度不熟悉,可以看第一篇文章[TensorFlow入门(一) - mnist手写数字识别(网络搭建)](https://geektutu.com/post/tensorflow-mnist-simplest.html)。 139 | 140 | 在**TensorFlow 1.x**的版本中,搭建一个自定义loss的网络很复杂,而使用**TensorFlow 2.0**,借助`Keras`,我们可以写出非常简洁的代码。 141 | 142 | ```python 143 | # policy_gradient.py 144 | # https://geektutu.com 145 | def train(records): 146 | s_batch = np.array([record[0] for record in records]) 147 | # action 独热编码处理,方便求动作概率,即 prob_batch 148 | a_batch = np.array([[1 if record[1] == i else 0 for i in range(ACTION_DIM)] 149 | for record in records]) 150 | # 假设predict的概率是 [0.3, 0.7],选择的动作是 [0, 1] 151 | # 则动作[0, 1]的概率等于 [0, 0.7] = [0.3, 0.7] * [0, 1] 152 | prob_batch = model.predict(s_batch) * a_batch 153 | r_batch = discount_rewards([record[2] for record in records]) 154 | 155 | model.fit(s_batch, prob_batch, sample_weight=r_batch, verbose=0) 156 | ``` 157 | 158 | 设置参数`sample_weight`,即可给loss设权重。 159 | 160 | ## 训练过程与结果 161 | 162 | 接下来,把 OpenAI gym 的代码融入进来吧。 163 | 164 | ```python 165 | # policy_gradient.py 166 | # https://geektutu.com 167 | episodes = 2000 # 至多2000次 168 | score_list = [] # 记录所有分数 169 | for i in range(episodes): 170 | s = env.reset() 171 | score = 0 172 | replay_records = [] 173 | while True: 174 | a = choose_action(s) 175 | next_s, r, done, _ = env.step(a) 176 | replay_records.append((s, a, r)) 177 | 178 | score += r 179 | s = next_s 180 | if done: 181 | train(replay_records) 182 | score_list.append(score) 183 | print('episode:', i, 'score:', score, 'max:', max(score_list)) 184 | break 185 | # 最后10次的平均分大于 195 时,停止并保存模型 186 | if np.mean(score_list[-10:]) > 195: 187 | model.save('CartPole-v0-pg.h5') 188 | break 189 | env.close() 190 | ``` 191 | 192 | 运行一下试一试吧。 193 | 194 | ```bash 195 | $ python policy_gradient.py 196 | episode: 0 score: 13.0 max: 13.0 197 | episode: 1 score: 35.0 max: 35.0 198 | episode: 2 score: 18.0 max: 35.0 199 | ... 200 | episode: 793 score: 200.0 max: 200.0 201 | episode: 794 score: 200.0 max: 200.0 202 | episode: 795 score: 164.0 max: 200.0 203 | episode: 796 score: 200.0 max: 200.0 204 | model saved 205 | ``` 206 | 207 | 画一张图,感受下学习的过程,这一次稍微多了3行多项式拟合的代码,能够更好地展现整个分数变化的趋势。 208 | 209 | ```python 210 | # policy_gradient.py 211 | # https://geektutu.com 212 | plt.plot(score_list) 213 | x = np.array(range(len(score_list))) 214 | smooth_func = np.poly1d(np.polyfit(x, score_list, 3)) 215 | plt.plot(x, smooth_func(x), label='Mean', linestyle='--') 216 | plt.show() 217 | ``` 218 | 219 | ![Geektutu Policy Gradient Plot](tensorflow2-gym-pg/pg_plot.jpg)。 220 | 221 | ## 测试 222 | 223 | 按照惯例,测试下效果。 224 | 225 | ```python 226 | # test_policy_gradient.py 227 | # https://geektutu.com 228 | import time 229 | import numpy as np 230 | import gym 231 | from tensorflow.keras import models 232 | 233 | saved_model = models.load_model('CartPole-v0-pg.h5') 234 | env = gym.make("CartPole-v0") 235 | 236 | for i in range(5): 237 | s = env.reset() 238 | score = 0 239 | while True: 240 | time.sleep(0.01) 241 | env.render() 242 | prob = saved_model.predict(np.array([s]))[0] 243 | a = np.random.choice(len(prob), p=prob) 244 | s, r, done, _ = env.step(a) 245 | score += r 246 | if done: 247 | print('using policy gradient, score: ', score) # 打印分数 248 | break 249 | env.close() 250 | ``` 251 | 252 | ```bash 253 | python test_policy_gradient.py 254 | using policy gradient, score: 200.0 255 | using policy gradient, score: 200.0 256 | using policy gradient, score: 200.0 257 | using policy gradient, score: 200.0 258 | using policy gradient, score: 200.0 259 | ``` 260 | 261 | ![Geektutu Policy Gradient Success](tensorflow2-gym-pg/pg_success.gif) 262 | 263 | ## 如何优化 264 | 265 | 教程中,每个回合不管多少条训练数据,直接训练,而没有固定大小的`batch`,不利于训练。有时间可以尝试,设置一个大小为2000的队列,存储历史的训练数据,每次固定取32/64条训练集,对比下两者的效果。 266 | 267 | 代码已上传至 [Github - CartPole-v0-policy-gradient](https://github.com/geektutu/tensorflow-tutorial-samples/tree/master/gym/CartPole-v0-policy-gradient) 268 | 269 | ## 附 推荐 270 | 271 | - [一篇文章入门 Python](https://geektutu.com/post/quick-python.html) -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow2-gym-pg/pg_optimize.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/tensorflow/tensorflow2-gym-pg/pg_optimize.jpg -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow2-gym-pg/pg_plot.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/tensorflow/tensorflow2-gym-pg/pg_plot.jpg -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow2-gym-pg/pg_success.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/tensorflow/tensorflow2-gym-pg/pg_success.gif -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow2-gym-q-learning.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: TensorFlow 2.0 (七) - 强化学习 Q-Learning 玩转 OpenAI gym 3 | date: 2019-06-25 00:40:24 4 | description: TensorFlow2教程,TensorFlow2.0教程,TensorFlow 2.0 入门系列文章,第七篇,Q-Learning 玩转 OpenAI gym game MountainCar-v0。 5 | tags: 6 | - 强化学习 7 | - TensorFlow 8 | - Q-Learning 9 | - OpenAI gym 10 | - Python 11 | keywords: 12 | - TensorFlow2教程 13 | - TensorFlow2.0教程 14 | nav: 简明教程 15 | categories: 16 | - TensorFlow 教程 17 | image: post/tensorflow2-gym-q-learning/mountaincar_v0_success.gif 18 | github: https://github.com/geektutu/tensorflow-tutorial-samples 19 | --- 20 | 21 | ![Geektutu Q-Learning MountainCar Failed](tensorflow2-gym-q-learning/mountaincar_v0_failed.gif) 22 | 23 | 27 | 28 | 这篇文章是 **TensorFlow 2.0 Tutorial** 入门教程的第七篇文章,介绍如何使用强化学习(Reinforcement Learning, RL)的一个经典算法(Q-Learning),玩转 OpenAI gym game。 29 | 30 | **代码仅50行~** 31 | 32 | ## MountainCar-v0 游戏简介 33 | 34 | 今天我们选取的游戏是[MountainCar-v0](https://github.com/openai/gym/wiki/MountainCar-v0),这个游戏很简单,将车往不同的方向推,最终让车爬到山顶。和上一篇文章 [TensorFlow 2.0 (六) - 监督学习玩转 OpenAI gym game](https://geektutu.com/post/tensorflow2-gym-nn.html)一样,我们先介绍几个比较关键的概念,以及这几个概念在这个游戏中的具体含义。 35 | 36 | 37 | |概念|解释|示例| 38 | |---|---|---| 39 | |State | list: 状态,[位置,速度] | [0.5,-0.01] | 40 | |Action | int: 动作(0向左推,1不动,2向右推)| 2 | 41 | |Reward | float: 每回合-1分 | -1 | 42 | |Done | bool: 是否爬到山顶(True/False),上限200回合 | -1 | 43 | 44 | 如果`200回合`还没到达山顶,说明游戏失败,-200是最低分。每个回合得-1,分数越高,说明尝试回合数越少,意味着越早地到达山顶。比如得分-100分,表示仅经过了100回合就到达了山顶。 45 | 46 | ## 初始化 Q-Table(Q表) 47 | 48 | 如果有如下这样一张表,告诉我在某个状态(State)下, 执行每一个动作(Action)产生的价值(Value),那就可以通过查询表格,选择产生价值最大的动作了。 49 | 50 | | State | Action 0 | Action 1 | Action 2| 51 | |---|---|---|---| 52 | | [0.2, -0.01] | 10 | -20 | -30 | 53 | | [-0.3, 0.01] | 100 | 0 | 0 | 54 | | [-0.1, -0.01] | 0 | -10 | 20 | 55 | 56 | 价值(Value)怎么计算呢?游戏的最终目标是爬到山顶,爬到山顶前的每一个动作都为最终的目标贡献了价值,因此每一个动作的价值计算,和最终的结果,也就是与未来(Future)有关。这就是强化学习的经典算法 `Q-Learning` 设计的核心。`Q-Learning`中的`Q`,代表的是 **Action-Value**,也可以理解为 **Quality**。而上面这张表,就称之为 `Q表(Q-Table)`。 57 | 58 | 到这里,你应该可以理解了,`Q-Learning`的目的是创建`Q-Table`。有了`Q-Table`,自然能知道选择哪一个Action了。 59 | 60 | 我们先初始化一张`Q表(Q-Table)` 61 | 62 | ```python 63 | # q_learning.py 64 | # https://geektutu.com 65 | import pickle # 保存模型用 66 | from collections import defaultdict 67 | import gym # 0.12.5 68 | import numpy as np 69 | 70 | # 默认将Action 0,1,2的价值初始化为0 71 | Q = defaultdict(lambda: [0, 0, 0]) 72 | ``` 73 | 74 | ## 连续状态映射 75 | 76 | 但是这个`Q-Table`有一个问题,我们用字典来表示`Q-Table`,State中的值是浮点数,是连续的,意味着有无数种状态,这样更新Q-Table的值是不可能实现。因此,我们需要对State进行线性转换,**归一化处理**。即,将State中的值映射到[0, 40]的空间中。这样,就将无数种状态映射到`40x40`种状态了。 77 | 78 | ```python 79 | # q_learning.py 80 | # https://geektutu.com 81 | env = gym.make('MountainCar-v0') 82 | 83 | def transform_state(state): 84 | """将 position, velocity 通过线性转换映射到 [0, 40] 范围内""" 85 | pos, v = state 86 | pos_low, v_low = env.observation_space.low 87 | pos_high, v_high = env.observation_space.high 88 | 89 | a = 40 * (pos - pos_low) / (pos_high - pos_low) 90 | b = 40 * (v - v_low) / (v_high - v_low) 91 | 92 | return int(a), int(b) 93 | 94 | # print(transform_state([-1.0, 0.01])) 95 | # eg: (4, 22) 96 | ``` 97 | 98 | ## 更新 Q-Table 99 | 100 | 那怎么更新Q-Table呢?下面这个简化版的公式就是关键了。 101 | 102 | > Q[s][a] = (1 - lr) * Q[s][a] + lr * (reward + factor * max(Q[next_s])) 103 | 104 | 看见公式先别紧张,我们逐步来看。 105 | 106 | | 表达式 | 含义 | 简介 | 107 | |---|-----|---| 108 | |s, a,next_s | - | 当前状态,当前动作,下一个状态 | 109 | |reward | 奖励 | 执行a动作的奖励 | 110 | |Q[s][a] | 价值 | 状态s下,动作a产生的价值 | 111 | |max(Q[next_s]) | 最大价值 | 下一个状态下,所有动作价值的最大值 | 112 | | lr | 学习速率(learning_rate) | lr越大,保留之前训练效果越少。lr为0,Q[s, a]值不变;lr为1时,完全抛弃了原来的值。| 113 | | factor | 折扣因子(discount_factor) | factor 越大,表示越重视历史的经验; factor 为0时,只关心当前利益(reward) | 114 | 115 | 为什么是`max(Q[next_s])`而不是`min(Q[next_s])`呢?在Q-Table中,状态 **next_s** 有3个动作可选,即[0, 1, 2],对应价值 **Q[next_s][0],Q[next_s][1],Q[next_s][2]**。`Q[s][a]`的值应由产生的最大价值的动作决定。 116 | 117 | 我们想象一个极端场景:五子棋,最后一步,下在X位置赢,100分;其他位置输,0分。那怎么衡量倒数第二步的价值呢?当然是由最后一步的最大价值决定,不能因为最后一步走错了,就否定前面动作的价值。 118 | 119 | ## 开始训练 120 | 121 | 接下来我们就把这个公式嵌入到`OpenAI gym`中吧。 122 | 123 | ```python 124 | # q_learning.py 125 | # https://geektutu.com 126 | lr, factor = 0.7, 0.95 127 | episodes = 10000 # 训练10000次 128 | score_list = [] # 记录所有分数 129 | for i in range(episodes): 130 | s = transform_state(env.reset()) 131 | score = 0 132 | while True: 133 | a = np.argmax(Q[s]) 134 | # 训练刚开始,多一点随机性,以便有更多的状态 135 | if np.random.random() > i * 3 / episodes: 136 | a = np.random.choice([0, 1, 2]) 137 | # 执行动作 138 | next_s, reward, done, _ = env.step(a) 139 | next_s = transform_state(next_s) 140 | # 根据上面的公式更新Q-Table 141 | Q[s][a] = (1 - lr) * Q[s][a] + lr * (reward + factor * max(Q[next_s])) 142 | score += reward 143 | s = next_s 144 | if done: 145 | score_list.append(score) 146 | print('episode:', i, 'score:', score, 'max:', max(score_list)) 147 | break 148 | env.close() 149 | 150 | # 保存模型 151 | with open('MountainCar-v0-q-learning.pickle', 'wb') as f: 152 | pickle.dump(dict(Q), f) 153 | print('model saved') 154 | ``` 155 | 156 | 接下来我们来看一看训练效果。因为Q表的状态比较多,因而训练到3000次的时候,仍旧没能成功到达山顶。最终训练结束的时候,分数保持在-150左右,最大分数达到-119。代码中的参数都是随便选取的,如果有时间优化下,肯定能有更好的结果。 157 | 158 | ```bash 159 | $ python q_learning.py 160 | episode: 3080 score: -200.0 max: -200 161 | episode: 3081 score: -200.0 max: -200 162 | ... 163 | episode: 9996 score: -169.0 max: -119.0 164 | episode: 9997 score: -141.0 max: -119.0 165 | episode: 9998 score: -160.0 max: -119.0 166 | episode: 9999 score: -161.0 max: -119.0 167 | model saved 168 | ``` 169 | 170 | ## 测试模型 171 | 172 | 最终,我们写一下测试代码,加载模型,顺便感受下真实的游戏画面吧~ 173 | 174 | ```python 175 | # test_q_learning.py 176 | # https://geektutu.com 177 | import time 178 | import pickle 179 | import gym 180 | import numpy as np 181 | 182 | # 加载模型 183 | with open('MountainCar-v0-q-learning.pickle', 'rb') as f: 184 | Q = pickle.load(f) 185 | print('model loaded') 186 | 187 | env = gym.make('MountainCar-v0') 188 | s = env.reset() 189 | score = 0 190 | while True: 191 | env.render() 192 | time.sleep(0.01) 193 | # transform_state函数 与 训练时的一致 194 | s = transform_state(s) 195 | a = np.argmax(Q[s]) if s in Q else 0 196 | s, reward, done, _ = env.step(a) 197 | score += reward 198 | if done: 199 | print('score:', score) 200 | break 201 | env.close() 202 | ``` 203 | 204 | 运行一下,你就知道。 205 | 206 | ```bash 207 | $ python test_q_learning.py 208 | model loaded 209 | score: -151.0 210 | ``` 211 | 212 | ![Geektutu Q-Learning MountainCar Success](tensorflow2-gym-q-learning/mountaincar_v0_success.gif) 213 | 214 | 215 | 代码已经上传到[Github - tensorflow-tutorial-samples](https://github.com/geektutu/tensorflow-tutorial-samples/tree/master/gym/MountainCar-v0-q-learning),**q_learning.py**只有50行,不妨试一试吧~ 216 | 217 | 我们这里的预测模型保存在了`Q-Table`中,输入是State,输出是3个Action的价值,Q-Table是一个字典,有着准确的映射关系,那如果我们用深度神经网络(Deep Neural Network, DNN)模拟这个字典呢?那这就被称为 `DQN(Deep Q-Learning Network)`。好,那我们下一篇文章,就借助`TensorFlow 2.0`用神经网络替换掉`Q-Table`吧。 218 | 219 | ## 附 推荐 220 | 221 | - [一篇文章入门 Python](https://geektutu.com/post/quick-python.html) -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow2-gym-q-learning/mountaincar_v0_failed.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/tensorflow/tensorflow2-gym-q-learning/mountaincar_v0_failed.gif -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow2-gym-q-learning/mountaincar_v0_success.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/tensorflow/tensorflow2-gym-q-learning/mountaincar_v0_success.gif -------------------------------------------------------------------------------- /posts/tensorflow/tensorflow2-mnist-cnn/cnn_image_sample.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/posts/tensorflow/tensorflow2-mnist-cnn/cnn_image_sample.gif -------------------------------------------------------------------------------- /scaffolds/draft.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: {{ title }} 3 | subtitle: 4 | catalog: true 5 | header-img: 6 | tags: 7 | --- 8 | -------------------------------------------------------------------------------- /scaffolds/page.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: {{ title }} 3 | date: {{ date }} 4 | description: 5 | header-img: "img/home-bg.jpg" 6 | --- 7 | -------------------------------------------------------------------------------- /scaffolds/post.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: {{ title }} 3 | subtitle: 4 | date: {{ date }} 5 | catalog: true 6 | header-img: 7 | tags: 8 | --- 9 | -------------------------------------------------------------------------------- /source/404.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: 404 3 | title: "404 NOT FOUND" 4 | description: "404 NOT FOUND" 5 | --- 6 | -------------------------------------------------------------------------------- /source/CNAME: -------------------------------------------------------------------------------- 1 | geektutu.com -------------------------------------------------------------------------------- /source/ads.txt: -------------------------------------------------------------------------------- 1 | google.com, pub-9394187906225546, DIRECT, f08c47fec0942fa0 -------------------------------------------------------------------------------- /source/archives/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: "archive" 3 | title: "归档" 4 | description: "极客兔兔的博客的归档列表" 5 | --- 6 | -------------------------------------------------------------------------------- /source/bdunion.txt: -------------------------------------------------------------------------------- 1 | 51aa69923fbbdd7b0f3fbd567f76e248 -------------------------------------------------------------------------------- /source/img/bg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/source/img/bg.jpg -------------------------------------------------------------------------------- /source/img/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/source/img/icon.png -------------------------------------------------------------------------------- /source/img/related_links/email.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/source/img/related_links/email.png -------------------------------------------------------------------------------- /source/img/related_links/geekcircle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/source/img/related_links/geekcircle.png -------------------------------------------------------------------------------- /source/img/related_links/github.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/source/img/related_links/github.png -------------------------------------------------------------------------------- /source/img/related_links/go.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/source/img/related_links/go.png -------------------------------------------------------------------------------- /source/img/related_links/rss.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/source/img/related_links/rss.jpg -------------------------------------------------------------------------------- /source/img/related_links/weibo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/source/img/related_links/weibo.jpg -------------------------------------------------------------------------------- /source/img/related_links/zhihu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/source/img/related_links/zhihu.png -------------------------------------------------------------------------------- /source/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: index 3 | --- -------------------------------------------------------------------------------- /source/jd_root.txt: -------------------------------------------------------------------------------- 1 | e95d2f4a675fe6f2db07016ed3b771a10cc5c1247fa7892b -------------------------------------------------------------------------------- /source/robots.txt: -------------------------------------------------------------------------------- 1 | Sitemap: https://geektutu.com/sitemap.xml -------------------------------------------------------------------------------- /source/root.txt: -------------------------------------------------------------------------------- 1 | 1137985e41c3aa18f8322a84b31a3034 -------------------------------------------------------------------------------- /source/series/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: "category" 3 | title: "专题" 4 | description: "极客兔兔的博客的专题列表" 5 | --- 6 | -------------------------------------------------------------------------------- /source/sogousiteverification.txt: -------------------------------------------------------------------------------- 1 | mZQ8WVSroQ -------------------------------------------------------------------------------- /source/tags/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: "tags" 3 | title: "标签" 4 | description: "极客兔兔的博客的标签列表" 5 | header-img: "img/header_img/tag-bg.png" 6 | --- 7 | -------------------------------------------------------------------------------- /source/tool/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geektutu/blog/eb92e4fc81547c7cfb6a35b45fb53b42d6d8b0b0/source/tool/.gitkeep --------------------------------------------------------------------------------