├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── pyetl ├── __init__.py ├── connections.py ├── dataset.py ├── es.py ├── mapping.py ├── reader.py ├── task.py ├── utils.py └── writer.py ├── setup.py └── tests ├── data ├── dst.txt ├── src.txt └── src.xlsx └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Distribution / packaging 2 | __pycache__/ 3 | .idea 4 | env/ 5 | bin/ 6 | build/ 7 | develop-eggs/ 8 | dist/ 9 | eggs/ 10 | lib/ 11 | lib64/ 12 | parts/ 13 | sdist/ 14 | var/ 15 | logs/ 16 | venv/ 17 | tmp/ 18 | logs/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | install.* 23 | *.exe 24 | *.py[cod] 25 | sftp-config.json 26 | .Python 27 | demo.py 28 | 29 | # Installer logs 30 | pip-log.txt 31 | pip-delete-this-directory.txt 32 | 33 | # Unit test / coverage reports 34 | .pytest_cache 35 | htmlcov/ 36 | .tox/ 37 | .coverage 38 | .cache 39 | test.db 40 | nosetests.xml 41 | coverage.xml 42 | 43 | *.tar 44 | *.gz 45 | *.zip 46 | 47 | # Django stuff: 48 | *.log 49 | *.pot 50 | 51 | # Sphinx documentation 52 | docs/_build/ 53 | 54 | # OS 55 | .DS_Store 56 | .DS_Store? 57 | ._* 58 | .Spotlight-V100 59 | .Trashes 60 | Icon? 61 | ehthumbs.db 62 | Thumbs.db 63 | MANIFEST -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include pyetl/templates * 2 | 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pyetl 2 | 3 | Pyetl is a **Python 3.6+** ETL framework. 4 | 5 | ## Installation 6 | ```shell 7 | pip3 install pyetl 8 | ``` 9 | 10 | ## Example 11 | 12 | ```python 13 | import sqlite3 14 | import pymysql 15 | from pyetl import Task, DatabaseReader, DatabaseWriter, ElasticsearchWriter, FileWriter 16 | src = sqlite3.connect("file.db") 17 | reader = DatabaseReader(src, table_name="source_table") 18 | # database-to-database sync, table-to-table transfer 19 | dst = pymysql.connect(host="localhost", user="your_user", password="your_password", db="test") 20 | writer = DatabaseWriter(dst, table_name="target_table") 21 | Task(reader, writer).start() 22 | # export a database table to a file 23 | writer = FileWriter(file_path="./", file_name="file.csv") 24 | Task(reader, writer).start() 25 | # sync a database table to an Elasticsearch index 26 | writer = ElasticsearchWriter(index_name="target_index") 27 | Task(reader, writer).start() 28 | ``` 29 | 30 | #### When source and target column names differ 31 | 32 | ```python 33 | import sqlite3 34 | from pyetl import Task, DatabaseReader, DatabaseWriter 35 | con = sqlite3.connect("file.db") 36 | # the source table source_table has the columns uuid and full_name 37 | reader = DatabaseReader(con, table_name="source_table") 38 | # the target table target_table has the columns id and name 39 | writer = DatabaseWriter(con, table_name="target_table") 40 | # columns maps target-table column names to source-table column names 41 | columns = {"id": "uuid", "name": "full_name"} 42 | Task(reader, writer, columns=columns).start() 43 | ``` 44 | 45 | #### Per-column UDF mappings for validation, standardization and cleaning 46 | ```python 47 | # functions maps columns to UDFs; here id is cast to str and name is stripped of surrounding whitespace 48 | functions = {"id": str, "name": lambda x: x.strip()} 49 | Task(reader, writer, columns=columns, functions=functions).start() 50 | ``` 51 | 52 | #### Subclassing Task for flexible extension 53 | 54 | ```python 55 | import json 56 | from pyetl import Task, DatabaseReader, DatabaseWriter 57 | class NewTask(Task): 58 | reader = DatabaseReader("sqlite:///db.sqlite3", table_name="source") 59 | writer = DatabaseWriter("sqlite:///db.sqlite3", table_name="target") 60 | 61 | def get_columns(self): 62 | """Generate the column mapping programmatically, which is more flexible""" 63 | # this example reads the column mapping from a database table and returns it as a dict 64 | sql = "select columns from task where name='new_task'" 65 | columns = self.writer.db.read_one(sql)["columns"] 66 | return json.loads(columns) 67 | 68 | def get_functions(self): 69 | """Generate the per-column UDF mapping programmatically""" 70 | # this example casts every column to str 71 | return {col: str for col in self.columns} 72 | 73 | def apply_function(self, record): 74 | """UDF applied to each whole record in the data stream""" 75 | record["flag"] = int(record["id"]) % 2 76 | return record 77 | 78 | def before(self): 79 | """Runs before the task starts, e.g. initialize a task table or create the target table""" 80 | sql = "create table target(id int, name varchar(100))" 81 | self.writer.db.execute(sql) 82 | 83 | def after(self): 84 | """Runs after the task finishes, e.g. update the task status""" 85 | sql = "update task set status='done' where name='new_task'" 86 | self.writer.db.execute(sql) 87 | 88 | NewTask().start() 89 | ``` 90 | 91 | ## Readers and Writers 92 | 93 | | Reader | Description | 94 | | ------------------- | -------------------------- | 95 | | DatabaseReader | Reads from any relational database | 96 | | FileReader | Reads structured text files such as CSV | 97 | | ExcelReader | Reads Excel workbooks | 98 | | ElasticsearchReader | Reads documents from an Elasticsearch index | 99 | 100 | | Writer | Description | 101 | | ------------------- | -------------------------- | 102 | | DatabaseWriter | Writes to any relational database | 103 | | ElasticsearchWriter | Bulk-writes documents to an Elasticsearch index | 104 | | HiveWriter | Bulk-inserts into a Hive table with insert statements | 105 | | HiveWriter2 | Loads data into a Hive table with `load data` (recommended) | 106 | | FileWriter | Writes data to text files | 107 | 108 | 109 | 110 |
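111 | The readers above can be paired with any writer. A minimal sketch of the file-based readers, assuming the CSV/Excel paths and the target table below are placeholders: 112 | 113 | ```python 114 | import sqlite3 115 | from pyetl import Task, FileReader, ExcelReader, DatabaseWriter 116 | con = sqlite3.connect("file.db") 117 | writer = DatabaseWriter(con, table_name="target_table") 118 | # load a csv file into the target table 119 | Task(FileReader("./data/source.csv"), writer).start() 120 | # load the first sheet of an Excel workbook, renaming columns on the way in 121 | reader = ExcelReader("./data/source.xlsx", sheet_name=0) 122 | Task(reader, writer, columns={"id": "uuid", "name": "full_name"}).start() 123 | ``` 124 |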
-------------------------------------------------------------------------------- /pyetl/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @time: 2020/4/30 11:27 AM 4 | @desc: python etl framework based on pandas for small datasets 5 | """ 6 | from .task import Task 7 | from .reader import DatabaseReader, FileReader, ExcelReader, ElasticsearchReader 8 | from .writer import DatabaseWriter, ElasticsearchWriter, HiveWriter, HiveWriter2, FileWriter 9 | 10 | __version__ = '2.2.2' 11 | __author__ = "liyatao" 12 | -------------------------------------------------------------------------------- /pyetl/connections.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @time: 2020/6/4 11:04 PM 4 | @desc: 5 | """ 6 | from pydbclib import connect, Database 7 | from sqlalchemy import engine 8 | 9 | from pyetl.es import Client 10 | 11 | 12 | class DatabaseConnection(object): 13 | 14 | def __init__(self, db): 15 | if isinstance(db, Database): 16 | self.db = db 17 | elif isinstance(db, engine.base.Engine) or hasattr(db, "cursor"): 18 | self.db = connect(driver=db) 19 | elif isinstance(db, dict): 20 | self.db = connect(**db) 21 | elif isinstance(db, str): 22 | self.db = connect(db) 23 | else: 24 | raise ValueError("invalid type for db parameter") 25 | 26 | 27 | class ElasticsearchConnection(object): 28 | _client = None 29 | 30 | def __init__(self, es_params=None): 31 | if es_params is None: 32 | es_params = {} 33 | self.es_params = es_params 34 | 35 | @property 36 | def client(self): 37 | if self._client is None: 38 | self._client = Client(**self.es_params) 39 | return self._client 40 | -------------------------------------------------------------------------------- /pyetl/dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @time: 2020/5/25 11:47 PM 4 | @desc: 5 | """ 6 | import itertools 7 | 8 | import pandas 9 | 10 | from pyetl.utils import limit_iterator 11 | 12 | 13 | class Dataset(object): 14 | 15 | def __init__(self, rows): 16 | self._rows = rows 17 | self.total = 0 18 | 19 | def __iter__(self): 20 | return self 21 | 22 | def next(self): 23 | self.total += 1 24 | return next(self._rows) 25 | 26 | __next__ = next 27 | 28 | def map(self, function): 29 | self._rows = (function(r) for r in self._rows) 30 | return self 31 | 32 | def filter(self, function): 33 | self._rows = (r for r in self._rows if function(r)) 34 | return self 35 | 36 | def rename(self, columns): 37 | """ 38 | Rename columns 39 | """ 40 | def function(record): 41 | if isinstance(record, dict): 42 | return {columns.get(k, k): v for k, v in record.items()} 43 | else: 44 | raise ValueError("only dict records can be renamed") 45 | return self.map(function) 46 | 47 | def rename_and_extract(self, columns): 48 | """ 49 | Project columns; columns missing from the record default to None 50 | """ 51 | def function(record): 52 | if isinstance(record, dict): 53 | return {v: record.get(k) for k, v in columns.items()} 54 | else: 55 | raise ValueError("only dict records can be renamed") 56 | return self.map(function) 57 | 58 | def limit(self, num): 59 | self._rows = limit_iterator(self._rows, num) 60 | return self 61 | 62 | def get_one(self): 63 | r = self.get(1) 64 | return r[0] if len(r) > 0 else None 65 
| 66 | def get(self, num): 67 | return [i for i in itertools.islice(self._rows, num)] 68 | 69 | def get_all(self): 70 | return [r for r in self._rows] 71 | 72 | def to_batch(self, size=10000): 73 | while 1: 74 | batch = self.get(size) 75 | if batch: 76 | yield batch 77 | else: 78 | return None 79 | 80 | def show(self, num=10): 81 | for data in self.limit(num): 82 | print(data) 83 | 84 | def write(self, writer): 85 | writer.write(self) 86 | 87 | def to_df(self, batch_size=None): 88 | if batch_size is None: 89 | return pandas.DataFrame.from_records(self) 90 | else: 91 | return self._df_generator(batch_size) 92 | 93 | def _df_generator(self, batch_size): 94 | while 1: 95 | records = self.get(batch_size) 96 | if records: 97 | yield pandas.DataFrame.from_records(records) 98 | else: 99 | return None 100 | 101 | def to_csv(self, file_path, batch_size=100000, **kwargs): 102 | """ 103 | 用于大数据量分批写入文件 104 | :param file_path: 文件路径 105 | :param sep: 分割符号,hive默认\001 106 | :param header: 是否写入表头 107 | :param columns: 按给定字段排序 108 | :param batch_size: 每批次写入文件行数 109 | """ 110 | kwargs.update(index=False) 111 | for df in self.to_df(batch_size=batch_size): 112 | df.to_csv(file_path, **kwargs) 113 | kwargs.update(mode="a", header=False) 114 | -------------------------------------------------------------------------------- /pyetl/es.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @time: 2020/5/11 3:14 下午 4 | @desc: 5 | """ 6 | import json 7 | 8 | from elasticsearch import Elasticsearch, helpers 9 | 10 | from pyetl.utils import batch_dataset, Singleton 11 | 12 | 13 | class SingletonES(Elasticsearch, metaclass=Singleton): 14 | pass 15 | 16 | 17 | def bulk_insert(es_params, docs, index_name, doc_type, batch_size=10000): 18 | def mapping(doc): 19 | return {"_index": index_name, "_type": doc_type, "_source": doc} 20 | 21 | docs = (mapping(doc) for doc in docs) 22 | res = helpers.parallel_bulk(SingletonES(**es_params), actions=docs, thread_count=2, chunk_size=batch_size) 23 | success_count, error_count = 0, 0 24 | for success, info in res: 25 | if success: 26 | success_count += 1 27 | else: 28 | error_count += 1 29 | return success_count, error_count 30 | 31 | 32 | class Index(object): 33 | 34 | def __init__(self, name, con, doc_type=None): 35 | self.name = name 36 | self.doc_type = doc_type 37 | self.es = con 38 | 39 | def scan(self): 40 | return helpers.scan(self.es, index=self.name, doc_type=self.doc_type) 41 | 42 | def search(self, body=None): 43 | return self.es.search(index=self.name, doc_type=self.doc_type, body=body) 44 | 45 | def get_columns(self): 46 | r = self.es.indices.get_mapping(self.name, doc_type=self.doc_type) 47 | columns = [] 48 | for index in r: 49 | mappings = r[index]["mappings"] 50 | if "properties" in mappings: 51 | columns.extend(mappings["properties"]) 52 | else: 53 | columns.extend(mappings[self.doc_type]["properties"]) 54 | return set(columns) 55 | 56 | def insert_one(self, doc): 57 | return self.es.index(index=self.name, doc_type=self.doc_type, body=doc) 58 | 59 | def bulk(self, docs, batch_size=10000): 60 | def mapping(doc): 61 | return {"_index": self.name, "_type": self.doc_type, "_source": doc} 62 | docs = (mapping(doc) for doc in docs) 63 | for batch in batch_dataset(docs, batch_size=batch_size): 64 | helpers.bulk(self.es, batch) 65 | 66 | def parallel_bulk(self, docs, batch_size=10000, thread_count=4): 67 | def mapping(doc): 68 | return {"_index": self.name, "_type": self.doc_type, "_source": doc} 69 | docs 
= (mapping(doc) for doc in docs) 70 | res = helpers.parallel_bulk(self.es, actions=docs, thread_count=thread_count, chunk_size=batch_size) 71 | success_count, error_count = 0, 0 72 | for success, info in res: 73 | if success: 74 | success_count += 1 75 | else: 76 | error_count += 1 77 | return success_count, error_count 78 | 79 | def delete_one(self, _id): 80 | self.es.delete(index=self.name, doc_type=self.doc_type, id=_id) 81 | 82 | def bulk_delete(self, body): 83 | """ 84 | 批量删除 85 | body = {'query': {'match': {"_id": "BxCklGwBt0482SoSeXuE"}}} 86 | demo_index.delete_many(body=body) 87 | """ 88 | self.es.delete_by_query(index=self.name, doc_type=self.doc_type, body=body) 89 | 90 | def create(self, settings): 91 | return self.es.indices.create(index=self.name, body=settings) 92 | 93 | def drop(self): 94 | return self.es.indices.delete(index=self.name, doc_type=self.doc_type, ignore=[400, 404]) 95 | 96 | 97 | class AliasManager(object): 98 | 99 | def __init__(self, name, es): 100 | self.name = name 101 | self.es = es 102 | 103 | def exists(self): 104 | return self.es.indices.exists_alias(self.name) 105 | 106 | def list(self): 107 | """ 108 | {'job-boss': {'aliases': {'job': {}}}, 'accounts': {'aliases': {'job': {}}}} 109 | """ 110 | if self.exists(): 111 | return self.es.indices.get_alias(name=self.name) 112 | else: 113 | return {} 114 | 115 | def add(self, index): 116 | return self.es.indices.put_alias(name=self.name, index=index) 117 | 118 | def remove(self, index): 119 | actions = json.dumps({ 120 | "actions": [ 121 | {"remove": {"index": index, "alias": self.name}}, 122 | ] 123 | }) 124 | return self.es.indices.update_aliases(body=actions) 125 | 126 | def drop(self): 127 | return self.es.indices.delete_alias(name=self.name, index="_all") 128 | 129 | 130 | class Client(Elasticsearch): 131 | 132 | def get_index(self, name, doc_type=None): 133 | return Index(name, self, doc_type=doc_type) 134 | 135 | def get_alias_manager(self, name): 136 | return AliasManager(name, self) 137 | 138 | 139 | def main(): 140 | es = Client() 141 | print(es.get_index("user*").get_columns()) 142 | 143 | 144 | if __name__ == '__main__': 145 | main() 146 | -------------------------------------------------------------------------------- /pyetl/mapping.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @time: 2020/4/30 11:29 上午 4 | @desc: 5 | """ 6 | 7 | 8 | class ColumnsMapping(object): 9 | 10 | def __init__(self, columns): 11 | self.raw_columns = columns 12 | self.alias, self.columns = self.get_src_columns_alias() 13 | 14 | def get_src_columns_alias(self): 15 | alias = {} 16 | for k, v in self.raw_columns.items(): 17 | if isinstance(v, (list, tuple)): 18 | for i, name in enumerate(v): 19 | alias.setdefault(name, "%s_%s" % (k, i)) 20 | else: 21 | alias.setdefault(v, k) 22 | 23 | columns = {} 24 | for k, v in self.raw_columns.items(): 25 | if isinstance(v, (list, tuple)): 26 | columns[k] = tuple(alias[n] for n in v) 27 | else: 28 | columns[k] = alias[v] 29 | return alias, columns 30 | 31 | 32 | class Mapping(object): 33 | 34 | def __init__(self, columns, functions, apply_function): 35 | self.columns = columns 36 | self.functions = functions 37 | self.apply_function = apply_function 38 | self.total = 0 39 | 40 | def __call__(self, record): 41 | result = {} 42 | for k, v in self.columns.items(): 43 | if isinstance(v, (list, tuple)): 44 | result[k] = self.functions.get(k, lambda x: ",".join(map(str, x)))(tuple(record.get(n) for n in v)) 45 | 
else: 46 | value = record.get(v) 47 | result[k] = self.functions[k](value) if k in self.functions else value 48 | self.total += 1 49 | return self.apply_function(result) 50 | -------------------------------------------------------------------------------- /pyetl/reader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @time: 2020/4/30 11:28 上午 4 | @desc: 5 | """ 6 | from abc import ABC, abstractmethod 7 | 8 | import pandas 9 | 10 | from pyetl.connections import DatabaseConnection, ElasticsearchConnection 11 | from pyetl.dataset import Dataset 12 | 13 | 14 | class Reader(ABC): 15 | default_batch_size = 10000 16 | _columns = None 17 | _limit_num = None 18 | 19 | def read(self, columns): 20 | """返回结果列名必须rename""" 21 | dataset = self.get_dataset(columns) 22 | if isinstance(self._limit_num, int): 23 | dataset = dataset.limit(self._limit_num) 24 | return dataset 25 | 26 | @abstractmethod 27 | def get_dataset(self, columns): 28 | pass 29 | 30 | @property 31 | @abstractmethod 32 | def columns(self): 33 | return self._columns 34 | 35 | 36 | class DatabaseReader(DatabaseConnection, Reader): 37 | 38 | def __init__(self, db, table_name, condition=None, batch_size=None, limit=None): 39 | super().__init__(db) 40 | self.table_name = table_name 41 | self.table = self.db.get_table(self.table_name) 42 | self.condition = condition if condition else "1=1" 43 | self.batch_size = batch_size or self.default_batch_size 44 | self._limit_num = limit 45 | 46 | def _get_dataset(self, text): 47 | return Dataset((r for r in self.db.read(text, batch_size=self.batch_size))) 48 | 49 | def _query_text(self, columns): 50 | fields = [f"{col} as {alias}" for col, alias in columns.items()] 51 | return " ".join(["select", ",".join(fields), "from", self.table_name]) 52 | 53 | def get_dataset(self, columns): 54 | text = self._query_text(columns) 55 | if isinstance(self.condition, str): 56 | text = f"{text} where {self.condition}" 57 | dataset = self._get_dataset(text) 58 | elif callable(self.condition): 59 | dataset = self._get_dataset(text).filter(self.condition) 60 | else: 61 | raise ValueError("condition 参数类型错误") 62 | return dataset 63 | 64 | @property 65 | def columns(self): 66 | if self._columns is None: 67 | self._columns = self.db.get_table(self.table_name).get_columns() 68 | return self._columns 69 | 70 | 71 | class FileReader(Reader): 72 | 73 | def __init__(self, file_path, pd_params=None, limit=None): 74 | self.file_path = file_path 75 | self._limit_num = limit 76 | if pd_params is None: 77 | pd_params = {} 78 | pd_params.setdefault("chunksize", self.default_batch_size) 79 | self.file = pandas.read_csv(self.file_path, **pd_params) 80 | 81 | def _get_records(self, columns): 82 | for df in self.file: 83 | df = df.where(df.notnull(), None).reindex(columns=columns).rename(columns=columns) 84 | for record in df.to_dict("records"): 85 | yield record 86 | 87 | def get_dataset(self, columns): 88 | return Dataset(self._get_records(columns)) 89 | 90 | @property 91 | def columns(self): 92 | if self._columns is None: 93 | self._columns = [col for col in self.file.read(0).columns] 94 | return self._columns 95 | 96 | 97 | class ExcelReader(Reader): 98 | 99 | def __init__(self, file, sheet_name=0, pd_params=None, limit=None, detect_table_border=True): 100 | if pd_params is None: 101 | pd_params = {} 102 | pd_params.setdefault("dtype", 'object') 103 | self.sheet_name = sheet_name 104 | self._limit_num = limit 105 | if isinstance(file, str): 106 | file = 
pandas.ExcelFile(file) 107 | self.df = file.parse(self.sheet_name, **pd_params) 108 | elif isinstance(file, pandas.ExcelFile): 109 | self.df = file.parse(self.sheet_name, **pd_params) 110 | elif isinstance(file, pandas.DataFrame): 111 | self.df = file 112 | else: 113 | raise ValueError(f"file 参数类型错误") 114 | if detect_table_border: 115 | self.detect_table_border() 116 | 117 | def get_dataset(self, columns): 118 | df = self.df.where(self.df.notnull(), None).reindex(columns=columns).rename(columns=columns) 119 | return Dataset(df.to_dict("records")) 120 | 121 | @property 122 | def columns(self): 123 | if self._columns is None: 124 | self._columns = [col for col in self.df.columns] 125 | return self._columns 126 | 127 | def detect_table_border(self): 128 | y, x = self.df.shape 129 | axis_x = self.df.count() 130 | for i in range(axis_x.size): 131 | name = axis_x.index[i] 132 | count = axis_x.iloc[i] 133 | if isinstance(name, str) and name.startswith("Unnamed:") and count == 0: 134 | x = i 135 | break 136 | axis_y = self.df.count(axis=1) 137 | for i in range(axis_y.size): 138 | count = axis_y.iloc[i] 139 | if count == 0: 140 | y = i 141 | break 142 | self.df = self.df.iloc[:y, :x] 143 | 144 | 145 | class ElasticsearchReader(ElasticsearchConnection, Reader): 146 | 147 | def __init__(self, index_name, doc_type=None, es_params=None, batch_size=None, limit=None): 148 | super().__init__(es_params) 149 | self.index_name = index_name 150 | self.doc_type = doc_type 151 | self.batch_size = batch_size or self.default_batch_size 152 | self._limit_num = limit 153 | self.index = self.client.get_index(self.index_name, self.doc_type) 154 | 155 | def get_dataset(self, columns): 156 | return Dataset(doc["_source"] for doc in self.index.scan()).rename_and_extract(columns) 157 | 158 | @property 159 | def columns(self): 160 | if self._columns is None: 161 | self._columns = self.index.get_columns() 162 | return self._columns 163 | -------------------------------------------------------------------------------- /pyetl/task.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @time: 2020/5/26 11:40 上午 4 | @desc: 5 | """ 6 | from pyetl.mapping import ColumnsMapping, Mapping 7 | from pyetl.reader import Reader 8 | from pyetl.utils import print_run_time, validate_param 9 | from pyetl.writer import Writer 10 | 11 | 12 | class Task(object): 13 | _dataset = None 14 | reader = None 15 | writer = None 16 | columns = None 17 | functions = None 18 | 19 | def __init__(self, reader=None, writer=None, columns=None, functions=None): 20 | if reader is not None: 21 | self.reader = reader 22 | if writer is not None: 23 | self.writer = writer 24 | if not getattr(self, 'reader', None): 25 | raise ValueError("%s must have a reader" % type(self).__name__) 26 | if not isinstance(self.reader, Reader): 27 | raise ValueError("reader类型错误") 28 | if self.writer and not isinstance(self.writer, Writer): 29 | raise ValueError("writer类型错误") 30 | if columns is not None: 31 | self.columns = columns 32 | if functions is not None: 33 | self.functions = validate_param("functions", functions, dict) 34 | self.columns = self.get_columns() 35 | self.functions = self.get_functions() 36 | self.columns_mapping = ColumnsMapping(self.columns) 37 | self.mapping = Mapping(self.columns_mapping.columns, self.functions, self.apply_function) 38 | 39 | def get_columns(self): 40 | if self.columns is None: 41 | return {col: col for col in self.reader.columns} 42 | if isinstance(self.columns, dict): 43 
| return {i: j for i, j in self.columns.items()} 44 | elif isinstance(self.columns, set): 45 | return {c: c for c in self.columns} 46 | else: 47 | raise ValueError("columns 参数错误") 48 | 49 | def get_functions(self): 50 | if self.functions: 51 | return self.functions 52 | else: 53 | return {} 54 | 55 | def apply_function(self, record): 56 | return record 57 | 58 | def filter_function(self, record): 59 | return True 60 | 61 | def before(self): 62 | pass 63 | 64 | def after(self): 65 | pass 66 | 67 | def show(self, num=10): 68 | self.dataset.show(num) 69 | 70 | @property 71 | def total(self): 72 | return self.mapping.total 73 | 74 | @property 75 | def dataset(self): 76 | if self._dataset is None: 77 | self._dataset = self.reader.read(self.columns_mapping.alias).map(self.mapping).filter(self.filter_function) 78 | return self._dataset 79 | 80 | @print_run_time 81 | def start(self): 82 | if not getattr(self, "writer", None): 83 | raise ValueError("%s must have a writer" % type(self).__name__) 84 | self.before() 85 | self.dataset.write(self.writer) 86 | self.after() 87 | 88 | -------------------------------------------------------------------------------- /pyetl/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @time: 2020/4/30 11:29 上午 4 | @desc: 5 | """ 6 | import time 7 | import functools 8 | 9 | 10 | class Singleton(type): 11 | 12 | def __init__(cls, *args, **kwargs): 13 | cls.__instance = None 14 | super().__init__(*args, **kwargs) 15 | 16 | def __call__(cls, *args, **kwargs): 17 | if cls.__instance is None: 18 | cls.__instance = super().__call__(*args, **kwargs) 19 | return cls.__instance 20 | 21 | 22 | def limit_iterator(rows, limit): 23 | for i, r in enumerate(rows): 24 | if i < limit: 25 | yield r 26 | else: 27 | return None 28 | 29 | 30 | def validate_param(name, value, type_or_types): 31 | if isinstance(value, type_or_types): 32 | return value 33 | else: 34 | raise ValueError(f"{name} 参数错误") 35 | 36 | 37 | def lower_columns(x): 38 | if isinstance(x, (list, tuple)): 39 | return tuple([i.lower() for i in x]) 40 | else: 41 | return x.lower() 42 | 43 | 44 | def batch_dataset(dataset, batch_size): 45 | cache = [] 46 | for data in dataset: 47 | cache.append(data) 48 | if len(cache) >= batch_size: 49 | yield cache 50 | cache = [] 51 | if cache: 52 | yield cache 53 | 54 | 55 | def print_run_time(func): 56 | @functools.wraps(func) 57 | def wrapper(*args, **kwargs): 58 | start = time.time() 59 | r = func(*args, **kwargs) 60 | cost = time.time() - start 61 | cost = round(cost, 3) 62 | print(f"{func.__name__}函数执行了{cost}s") 63 | return r 64 | return wrapper 65 | 66 | 67 | @print_run_time 68 | def main(): 69 | pass 70 | 71 | 72 | if __name__ == '__main__': 73 | main() 74 | -------------------------------------------------------------------------------- /pyetl/writer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @time: 2020/4/30 11:28 上午 4 | @desc: 5 | """ 6 | import hashlib 7 | import os 8 | import random 9 | import shutil 10 | import sys 11 | import time 12 | from abc import ABC, abstractmethod 13 | from multiprocessing.pool import Pool 14 | 15 | from pyetl.connections import DatabaseConnection, ElasticsearchConnection 16 | from pyetl.es import bulk_insert 17 | from pyetl.utils import batch_dataset 18 | 19 | 20 | class Writer(ABC): 21 | default_batch_size = 100000 22 | 23 | @abstractmethod 24 | def write(self, dataset): 25 | pass 26 | 27 | 28 
| class DatabaseWriter(DatabaseConnection, Writer): 29 | 30 | def __init__(self, db, table_name, batch_size=None): 31 | super().__init__(db) 32 | self.table_name = table_name 33 | self.table = self.db.get_table(self.table_name) 34 | self.batch_size = batch_size or self.default_batch_size 35 | 36 | def write(self, dataset): 37 | self.db.get_table(self.table_name).bulk(dataset, batch_size=self.batch_size) 38 | 39 | 40 | class ElasticsearchWriter(ElasticsearchConnection, Writer): 41 | 42 | def __init__(self, index_name, doc_type=None, es_params=None, parallel_num=None, batch_size=10000): 43 | super().__init__(es_params) 44 | self._index = None 45 | self.index_name = index_name 46 | self.doc_type = doc_type 47 | self.batch_size = batch_size or self.default_batch_size 48 | self.parallel_num = parallel_num 49 | self.index = self.client.get_index(self.index_name, self.doc_type) 50 | 51 | def write(self, dataset): 52 | if self.parallel_num is None or "win" in sys.platform: 53 | self.index.parallel_bulk(docs=dataset, batch_size=self.batch_size) 54 | else: 55 | pool = Pool(self.parallel_num) 56 | for batch in batch_dataset(dataset, self.batch_size): 57 | pool.apply_async(bulk_insert, args=(self.es_params, batch, self.index.name, self.index.doc_type)) 58 | pool.close() 59 | pool.join() 60 | 61 | 62 | class HiveWriter(DatabaseConnection, Writer): 63 | """ 64 | insert dataset to hive table by 'insert into' sql 65 | """ 66 | 67 | def __init__(self, db, table_name, batch_size=None): 68 | super().__init__(db) 69 | self.table_name = table_name 70 | self.batch_size = batch_size or self.default_batch_size 71 | self._columns = None 72 | 73 | @property 74 | def columns(self): 75 | if self._columns is None: 76 | r = self.db.execute(f"select * from {self.table_name} limit 0") 77 | r.fetchall() 78 | self._columns = r.get_columns() 79 | return self._columns 80 | 81 | def complete_all_fields(self, record): 82 | return {k: record.get(k, "") for k in self.columns} 83 | 84 | def write(self, dataset): 85 | self.db.get_table(self.table_name).bulk(dataset.map(self.complete_all_fields), batch_size=self.batch_size) 86 | 87 | 88 | class HiveWriter2(HiveWriter): 89 | """ 90 | insert dataset to hive table by 'load data' sql 91 | """ 92 | cache_file = ".pyetl_hive_cache" 93 | 94 | def __init__(self, db, table_name, batch_size=1000000, hadoop_path=None, delimited="\001"): 95 | super().__init__(db, table_name, batch_size) 96 | self.file_name = self._get_local_file_name() 97 | self.local_path = os.path.join(self.cache_file, self.file_name) 98 | self.delimited = delimited 99 | self.hadoop = hadoop_path if hadoop_path else "hadoop" 100 | 101 | def _get_local_file_name(self): 102 | # 注意 table_name 可能是多表关联多情况,如 t1 left t2 using(uuid) 103 | # code = random.randint(1000, 9999) 104 | # return f"pyetl_dst_table_{'_'.join(self.table_name.split())}_{code}" 105 | uuid = hashlib.md5(self.table_name.encode()) 106 | return f"{uuid.hexdigest()}-{int(time.time())}" 107 | 108 | def clear(self): 109 | shutil.rmtree(self.local_path) 110 | 111 | def write(self, dataset): 112 | file_writer = FileWriter( 113 | self.local_path, header=False, sep=self.delimited, columns=self.columns, batch_size=self.batch_size) 114 | file_writer.write(dataset.map(self.complete_all_fields)) 115 | try: 116 | self.load_data() 117 | finally: 118 | self.clear() 119 | 120 | def to_csv(self, dataset): 121 | dataset.map(self.complete_all_fields).to_csv( 122 | self.local_path, header=False, sep="\001", columns=self.columns, batch_size=self.batch_size) 123 | 124 | def 
load_data(self): 125 | if os.system(f"{self.hadoop} fs -put {self.local_path} /tmp/{self.file_name}") == 0: 126 | try: 127 | self.db.execute(f"load data inpath '/tmp/{self.file_name}' into table {self.table_name}") 128 | finally: 129 | os.system(f"{self.hadoop} fs -rm -r /tmp/{self.file_name}") 130 | else: 131 | print("上传HDFS失败:", self.file_name) 132 | 133 | 134 | class FileWriter(Writer): 135 | 136 | def __init__(self, file_path, file_name=None, batch_size=None, header=True, sep=",", columns=None): 137 | self.file_path = file_path 138 | self.file_name = file_name 139 | if not os.path.exists(file_path): 140 | os.makedirs(self.file_path) 141 | self.batch_size = batch_size or self.default_batch_size 142 | self.kw = dict(header=header, sep=sep, columns=columns) 143 | 144 | def write(self, dataset): 145 | if self.file_name: 146 | dataset.to_csv(os.path.join(self.file_path, self.file_name), batch_size=self.batch_size, **self.kw) 147 | else: 148 | self.to_csv_files(dataset, self.file_path, batch_size=self.batch_size, **self.kw) 149 | 150 | @classmethod 151 | def to_csv_files(cls, dataset, path, batch_size=100000, **kwargs): 152 | for i, df in enumerate(dataset.to_df(batch_size=batch_size)): 153 | file = os.path.join(path, f"{i:0>8}") 154 | df.to_csv(file, index=False, **kwargs) 155 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import re 2 | import ast 3 | from setuptools import setup, find_packages 4 | 5 | _version_re = re.compile(r'__version__\s+=\s+(.*)') 6 | 7 | with open('pyetl/__init__.py', 'rb') as f: 8 | rs = _version_re.search(f.read().decode('utf-8')).group(1) 9 | version = str(ast.literal_eval(rs)) 10 | 11 | setup( 12 | name='pyetl', 13 | version=version, 14 | install_requires=['pydbclib>=2.2.2', 'pandas>=0.22', 'elasticsearch', 'sqlalchemy'], 15 | description='Python ETL Frame', 16 | classifiers=[ 17 | 'Development Status :: 5 - Production/Stable', 18 | 'Intended Audience :: Developers', 19 | 'License :: OSI Approved :: Apache Software License', 20 | 'Programming Language :: Python', 21 | 'Programming Language :: Python :: 3.6', 22 | 'Programming Language :: Python :: 3.7', 23 | 'Programming Language :: Python :: 3.8', 24 | ], 25 | author='liyatao', 26 | url='https://github.com/taogeYT/pyetl', 27 | author_email='li_yatao@outlook.com', 28 | license='Apache 2.0', 29 | packages=find_packages(), 30 | include_package_data=False, 31 | zip_safe=True, 32 | python_requires='>=3.6', 33 | # entry_points={ 34 | # 'console_scripts': ['pyetl = pyetl.cli:main'] 35 | # } 36 | ) 37 | -------------------------------------------------------------------------------- /tests/data/dst.txt: -------------------------------------------------------------------------------- 1 | id,name 2 | 1,python ETL framework 3 | -------------------------------------------------------------------------------- /tests/data/src.txt: -------------------------------------------------------------------------------- 1 | uuid,full_name 2 | 1,python ETL framework -------------------------------------------------------------------------------- /tests/data/src.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taogeYT/pyetl/26ae6612f2c37460deec478933b6bcd9456484f3/tests/data/src.xlsx -------------------------------------------------------------------------------- /tests/test.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @time: 2020/5/6 4:30 下午 4 | @desc: 5 | """ 6 | import os 7 | import unittest 8 | 9 | from pydbclib import connect 10 | 11 | from pyetl.dataset import Dataset 12 | from pyetl.task import Task 13 | from pyetl.reader import DatabaseReader, FileReader, ExcelReader 14 | from pyetl.writer import DatabaseWriter, FileWriter, HiveWriter 15 | 16 | 17 | class BaseTest(unittest.TestCase): 18 | db = None 19 | src_record = {"uuid": 1, "full_name": "python ETL framework"} 20 | dst_record = {"id": 1, "name": "python ETL framework"} 21 | columns = {"id": "uuid", "name": "full_name"} 22 | 23 | @classmethod 24 | def get_file_path(cls, name): 25 | return os.path.join(os.path.dirname(__file__), 'data', name) 26 | 27 | @staticmethod 28 | def get_db(): 29 | return connect("sqlite:///:memory:") 30 | 31 | @classmethod 32 | def setUpClass(cls): 33 | cls.db = cls.get_db() 34 | create_src = """CREATE TABLE "src" ("uuid" INTEGER NOT NULL,"full_name" TEXT,PRIMARY KEY ("uuid"))""" 35 | cls.db.execute(create_src) 36 | create_dst = """CREATE TABLE "dst" ("id" INTEGER NOT NULL,"name" TEXT,PRIMARY KEY ("id"))""" 37 | cls.db.execute(create_dst) 38 | cls.db.get_table("src").insert(cls.src_record) 39 | 40 | @classmethod 41 | def tearDownClass(cls) -> None: 42 | cls.db.execute("drop table src") 43 | cls.db.execute("drop table dst") 44 | cls.db.commit() 45 | 46 | @classmethod 47 | def get_dataset(cls, record): 48 | return Dataset(iter([record])) 49 | 50 | 51 | class TestReader(BaseTest): 52 | 53 | def validate(self, reader): 54 | self.assertEqual(reader.columns, ["uuid", "full_name"]) 55 | r = reader.read(columns={"uuid": "id", "full_name": "name"}) 56 | self.assertTrue(isinstance(r, Dataset)) 57 | self.assertEqual(r.get_all(), self.get_dataset(self.dst_record).get_all()) 58 | 59 | def test_db_reader(self): 60 | reader = DatabaseReader(self.db, "src") 61 | self.validate(reader) 62 | 63 | def test_db_reader_by_engine(self): 64 | engine = self.db.driver.engine 65 | reader = DatabaseReader(engine, "src") 66 | self.validate(reader) 67 | 68 | def test_file_reader(self): 69 | reader = FileReader(self.get_file_path("src.txt")) 70 | self.validate(reader) 71 | 72 | def test_excel_reader(self): 73 | reader = ExcelReader(self.get_file_path("src.xlsx")) 74 | self.validate(reader) 75 | 76 | 77 | class TestWriter(BaseTest): 78 | 79 | @staticmethod 80 | def get_db(): 81 | return connect(":memory:", driver="sqlite3") 82 | 83 | def test_db_writer(self): 84 | writer = DatabaseWriter(self.db, "dst") 85 | writer.table.delete("1=1") 86 | writer.write(self.get_dataset(self.dst_record)) 87 | task = Task(DatabaseReader(self.db, "dst")) 88 | self.assertEqual(task.dataset.get_all(), self.get_dataset(self.dst_record).get_all()) 89 | 90 | def test_db_reader_by_con(self): 91 | con = self.db.driver.con 92 | writer = DatabaseWriter(con, "dst") 93 | writer.table.delete("1=1") 94 | writer.write(self.get_dataset(self.dst_record)) 95 | task = Task(DatabaseReader(self.db, "dst")) 96 | self.assertEqual(task.dataset.get_all(), self.get_dataset(self.dst_record).get_all()) 97 | 98 | def test_file_writer(self): 99 | file = self.get_file_path("dst.txt") 100 | path, name = os.path.split(file) 101 | writer = FileWriter(path, name) 102 | writer.write(self.get_dataset(self.dst_record)) 103 | task = Task(FileReader(file)) 104 | self.assertEqual(task.dataset.get_all(), self.get_dataset(self.dst_record).get_all()) 105 | 106 | 107 | class 
TestHiveWriter(BaseTest): 108 | 109 | @staticmethod 110 | def get_db(): 111 | return connect("hive://localhost:10000/default") 112 | 113 | @classmethod 114 | def setUpClass(cls): 115 | cls.db = cls.get_db() 116 | create_src = """CREATE TABLE src (uuid int,full_name string)""" 117 | cls.db.execute(create_src) 118 | create_dst = """CREATE TABLE dst (id int,name string)""" 119 | cls.db.execute(create_dst) 120 | cls.db.get_table("src").insert(cls.src_record) 121 | 122 | def test_hive_writer(self): 123 | writer = HiveWriter(self.db, "dst") 124 | writer.write(self.get_dataset(self.dst_record)) 125 | task = Task(DatabaseReader(self.db, "dst")) 126 | self.assertEqual(task.dataset.get_all(), self.get_dataset(self.dst_record).get_all()) 127 | 128 | 129 | class TestTask(BaseTest): 130 | 131 | def test_no_columns(self): 132 | reader = DatabaseReader(self.db, "src") 133 | task = Task(reader) 134 | self.assertEqual(task.dataset.get_all(), [self.src_record]) 135 | 136 | def test_set_columns(self): 137 | reader = DatabaseReader(self.db, "src") 138 | task = Task(reader, columns={"uuid"}) 139 | self.assertEqual(task.dataset.get_all(), [{"uuid": 1}]) 140 | 141 | def test_dict_columns(self): 142 | reader = DatabaseReader(self.db, "src") 143 | task = Task(reader, columns=self.columns) 144 | self.assertEqual(task.dataset.get_all(), self.get_dataset(self.dst_record).get_all()) 145 | 146 | 147 | if __name__ == '__main__': 148 | unittest.main() 149 | --------------------------------------------------------------------------------