├── .gitignore
├── LICENSE
├── README.md
├── askxml
├── __init__.py
├── askxml.py
├── column
│ ├── __init__.py
│ ├── data_types.py
│ └── keys.py
├── driver
│ ├── __init__.py
│ ├── driver.py
│ └── sqlite_driver.py
└── table.py
├── setup.py
└── tests
├── README.md
├── __init__.py
├── test_sqlite_driver.py
└── test_synchronize.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Sphinx documentation
50 | docs/_build/
51 |
52 | # virtualenv
53 | .venv
54 | venv/
55 | ENV/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 kamac
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AskXML
2 | Run SQL statements on XML documents
3 |
4 | ```xml
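5 | <XML>
6 |     <fruit color="green">tasty kiwi</fruit>
7 |     <fruit color="dark green">old kiwi</fruit>
8 |     <fruit color="red">apple</fruit>
9 | </XML>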
10 | ```
11 |
12 | ```python
13 |
14 | >>> from askxml import AskXML
15 |
16 | >>> conn = AskXML('file.xml')
17 | # get an SQL cursor object
18 | >>> c = conn.cursor()
19 | >>> results = c.execute("SELECT color FROM fruit WHERE _text LIKE '% kiwi'")
20 | >>> rows = results.fetchall()
21 | >>> print(rows)
22 | [('green',), ('dark green',)]
23 |
24 | # cleanup
25 | >>> c.close()
26 | >>> conn.close()
27 | ```
28 |
29 | ## BUT WHY?
30 |
31 | Some data dumps, such as Stack Exchange's, are stored in XML. They're big, so loading them whole into memory is undesirable. With AskXML you can query them quickly and comfortably (provided you know SQL).
32 |
33 | Before you go any further, note that your task may well be achievable with XPath and the ElementTree XML API, so give those a look if you haven't heard of them.
34 |
35 | ## Installation
36 |
37 | AskXML requires Python 3.6+. The best way to install it is with pip:
38 |
39 | `pip install askxml`
40 |
41 | ## Usage
42 |
43 | #### Adding indexes and defining columns
44 |
45 | If you want to add indexes, columns or set attribute types, you can pass a list of table definitions:
46 |
47 | ```python
48 |
49 | from askxml import *
50 | tables = [
51 | Table('fruit',
52 | Column('age', Integer()),
53 | Index('age'))
54 | ]
55 | with AskXML('file.xml', table_definitions=tables) as conn:
56 | c = conn.cursor()
57 | c.execute("UPDATE fruit SET age = 5 WHERE _text = 'tasty kiwi'")
58 | c.close()
59 | ```
60 |
61 | You don't need to define all existing columns or tables. If a definition is not found, one is created automatically, with every column typed as Text.
62 |
63 | #### Node hierarchy
64 |
65 | If you want to find nodes that are children of another node by attribute:
66 |
67 | ```xml
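68 | <XML>
69 |     <someParent name="Jerry">
70 |         <someChild name="Morty" />
71 |         <someChild name="Summer" />
72 |     </someParent>
73 |     <someParent name="Rick" />
74 | </XML>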
75 | ```
76 |
77 | ```python
78 |
79 | from askxml import *
80 | with AskXML('file.xml') as conn:
81 | c = conn.cursor()
82 | results = c.execute("""
83 | SELECT child.name FROM someParent_someChild as child
84 | INNER JOIN someParent as parent ON parent._id = child._parentId
85 | WHERE parent.name = 'Jerry'
86 | """)
87 | for row in results.fetchall():
88 | print(row)
89 | c.close()
90 | ```
91 |
92 | This will print `('Morty',)` and `('Summer',)`, one row per line.
93 |
94 | #### Inserting new data
95 |
96 | If you want to add a new tag:
97 |
98 | ```python
99 | cursor.execute("INSERT INTO fruit (color, _text) VALUES ('red', 'strawberry')")
100 | ```
101 |
102 | Or if your nodes have a hierarchy:
103 |
104 | ```python
105 | cursor.execute("INSERT INTO someParent_someChild (name, _parentId) VALUES ('a baby', 1)")
106 | ```
107 |
108 | ## Contributing
109 |
110 | Any contributions are welcome.
111 |
112 | ## License
113 |
114 | AskXML is licensed under [MIT license](https://github.com/kamac/AskXML/blob/master/LICENSE)
--------------------------------------------------------------------------------
/askxml/__init__.py:
--------------------------------------------------------------------------------
1 | from .askxml import *
2 | from .column import *
3 | from .table import *
--------------------------------------------------------------------------------
/askxml/askxml.py:
--------------------------------------------------------------------------------
import os
import tempfile
from importlib import import_module
from typing import List
from xml.sax.saxutils import escape

from .table import Table
6 |
class AskXML:
    """
    Run SQL statements on an XML document.

    The document is loaded into a database by a driver. When persist_data is
    enabled, synchronize() (also called by close()) writes the database contents
    back to the source as XML.
    """

    # Characters escaped in attribute values in addition to &, < and >. Line breaks
    # and tabs are written as character references because XML parsers replace
    # literal ones inside attribute values with spaces.
    _ATTRIBUTE_ENTITIES = {'"': '&quot;', '\n': '&#10;', '\r': '&#13;', '\t': '&#9;'}

    def __init__(self, source, table_definitions: List[Table] = None,
                 persist_data: bool = True, driver='sqlite', serialize_ident: str = ' ',
                 join_name: str = '_parentId', id_name: str = '_id', text_name: str = '_text', *args, **kwargs):
        """
        :param source: Path to .xml file to open, or file handle
        :param table_definitions: A list of table definitions
        :param persist_data: If enabled, changes to data are saved to the source XML file
            by synchronize() and close()
        :param driver: Driver used to implement SQL functionality. Either a driver name such as
            'sqlite' (resolved to askxml.driver.sqlite_driver.SqliteDriver) or a callable that
            creates a Driver
        :param serialize_ident: Indentation to use when serializing data to XML
        :param join_name: Name of the column that stores parent's ID
        :param id_name: Name of the column that stores node's ID
        :param text_name: Name of the column that stores node's text

        Remaining positional and keyword arguments are forwarded to the driver.
        """
        self.persist_data = persist_data
        self.source = source
        self.join_name = join_name
        self.id_name = id_name
        self.text_name = text_name
        self.serialize_ident = serialize_ident
        if not callable(driver):
            driver_module = import_module('askxml.driver.' + driver + '_driver')
            driver = getattr(driver_module, driver.capitalize() + 'Driver')

        self._driver = driver(source, table_definitions, join_name=join_name, id_name=id_name,
                              text_name=text_name, *args, **kwargs)

    def synchronize(self):
        """
        Saves changes to source XML file.

        Does nothing when persist_data is disabled. A filename source is rewritten
        as UTF-8 (the default encoding of an XML document without a declaration);
        a file handle source is truncated and rewritten in place.
        """
        if not self.persist_data:
            return

        source_is_filename = isinstance(self.source, str)
        self._sync_cursor = self._driver.create_cursor()
        try:
            if source_is_filename:
                self._sync_file = open(self.source, 'w', encoding='utf-8')
            else:
                self._sync_file = self.source
                self._sync_file.seek(0)
                self._sync_file.truncate()
                self._sync_file.seek(0)
            try:
                self._write_document()
            finally:
                # only close handles we opened ourselves
                if source_is_filename:
                    self._sync_file.close()
        finally:
            self._sync_cursor.close()

    def _write_document(self):
        """Serialize the root element and every root table into self._sync_file."""
        root_name, root_attrib = self._driver.get_xml_root()
        self._sync_file.write("<{tag}{properties}>\n".format(
            tag=root_name,
            properties=self._serialize_properties(root_attrib.items())))

        root_tables, self.__child_tables = self._driver.get_tables()
        for root_tag in root_tables:
            root_tags_data = self._sync_cursor.execute("SELECT * FROM {from_table}".format(from_table=root_tag))
            self._synchronize_tags(root_tags_data.fetchall(), table_scope=root_tag, ident=self.serialize_ident)
        self._sync_file.write("</{tag}>\n".format(tag=root_name))

    def _serialize_properties(self, properties):
        """
        Serialize properties to an XML attribute string.

        Properties whose value is None, and the join/id/text meta columns, are skipped.
        Values are converted with str() (INTEGER and REAL columns come back as numbers)
        and XML-escaped.

        :param properties: An iterable of tuples (property_name, property_value,)
        :return: ' name="value" ...' with a leading space, or '' when nothing is left
        """
        meta_names = (self.join_name, self.id_name, self.text_name)
        filtered_properties = [(name, value) for name, value in properties
                               if value is not None and name not in meta_names]
        if not filtered_properties:
            return ''
        return ' ' + ' '.join('{}="{}"'.format(name, escape(str(value), self._ATTRIBUTE_ENTITIES))
                              for name, value in filtered_properties)

    def _synchronize_tags(self, tags_data, table_scope='', ident=''):
        """
        Write rows of one table as XML elements, recursing into its child tables.

        :param tags_data: Rows returned by the latest query executed on self._sync_cursor
        :param table_scope: Name of the table the rows come from (tag path joined with '_')
        :param ident: Indentation prefix for this nesting level
        """
        if not tags_data:
            return

        # description describes the query that produced tags_data; read it before the
        # child queries below replace it.
        field_names = [desc[0] for desc in self._sync_cursor.description]
        id_index = field_names.index(self.id_name)
        text_index = field_names.index(self.text_name) if self.text_name in field_names else None
        tag_name = table_scope.split('_')[-1]
        child_tables = [c for c in self.__child_tables if c[:c.rfind('_')] == table_scope]

        for tag_data in tags_data:
            tag_id = tag_data[id_index]
            text_value = tag_data[text_index] if text_index is not None else None
            text_xml = escape(str(text_value)) if text_value else ''
            properties = self._serialize_properties(zip(field_names, tag_data))

            if not child_tables:
                if text_xml:
                    self._sync_file.write('{}<{}{}>{}</{}>\n'.format(ident, tag_name, properties, text_xml, tag_name))
                else:
                    self._sync_file.write('{}<{}{} />\n'.format(ident, tag_name, properties))
                continue

            # open the tag, write its children, then close it
            self._sync_file.write('{}<{}{}>{}\n'.format(ident, tag_name, properties, text_xml))
            for child_tag in child_tables:
                children_data = self._sync_cursor.execute("""SELECT a.* FROM {from_table} AS a
                    INNER JOIN {parent_table} AS b ON a.{join_name} = b.{id_name}
                    WHERE b.{id_name} = ?""".format(
                    from_table=child_tag,
                    parent_table=table_scope,
                    join_name=self.join_name,
                    id_name=self.id_name
                ), (tag_id,))
                self._synchronize_tags(children_data.fetchall(), table_scope=child_tag,
                                       ident=ident + self.serialize_ident)
            self._sync_file.write('{}</{}>\n'.format(ident, tag_name))

    def close(self):
        """
        Synchronizes changes (when persist_data is enabled) and closes the connection
        to the XML document. The driver is closed even if synchronization fails.
        """
        try:
            self.synchronize()
        finally:
            self._driver.close()

    def cursor(self):
        """Returns a new cursor from the driver for running SQL statements."""
        return self._driver.create_cursor()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()
--------------------------------------------------------------------------------
/askxml/column/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_types import *
2 | from .keys import *
3 |
class Column:
    """
    Column stores one XML attribute
    """

    def __init__(self, column_name: str, data_type: DataType, foreign_key: ForeignKey = None):
        """
        :param column_name: Name of SQL column
        :param data_type: Data type stored in column
        :param foreign_key: A foreign key constraint. A new ForeignKey bound to column_name is
            built from its foreign_table_name and foreign_column_name and stored instead.
        """
        self._column_name = column_name
        self._data_type = data_type
        if foreign_key:
            self._foreign_key = ForeignKey(foreign_key.foreign_table_name + '.' + foreign_key.foreign_column_name,
                                           column_name=column_name)
        else:
            self._foreign_key = None

    @property
    def column_name(self):
        """Name of the SQL column."""
        return self._column_name

    @property
    def data_type(self):
        """DataType instance describing the stored values."""
        return self._data_type

    @property
    def foreign_key(self):
        """ForeignKey bound to this column, or None."""
        return self._foreign_key

    @staticmethod
    def create_default(column_name: str):
        """Creates a default column definition (Text type, no foreign key) for given column name."""
        return Column(column_name, Text())
--------------------------------------------------------------------------------
/askxml/column/data_types.py:
--------------------------------------------------------------------------------
1 | """
2 | Supported column data types. Analogous to SQLite's.
3 | Reference: https://www.sqlite.org/datatype3.html
4 | """
5 |
class DataType:
    """
    Base class of the column data types.

    str() of an instance is the SQL type name; every subclass must provide it.
    """

    def __str__(self):
        raise NotImplementedError
10 |
class Integer(DataType):
    """Whole numbers; rendered as the SQLite type name INTEGER."""

    _sql_name = "INTEGER"

    def __str__(self):
        return self._sql_name
14 |
class Real(DataType):
    """Floating point numbers; rendered as the SQLite type name REAL."""

    _sql_name = "REAL"

    def __str__(self):
        return self._sql_name
18 |
class Text(DataType):
    """Character strings; rendered as the SQLite type name TEXT."""

    _sql_name = "TEXT"

    def __str__(self):
        return self._sql_name
22 |
class Blob(DataType):
    """Raw data; rendered as the SQLite type name BLOB."""

    _sql_name = "BLOB"

    def __str__(self):
        return self._sql_name
--------------------------------------------------------------------------------
/askxml/column/keys.py:
--------------------------------------------------------------------------------
1 | """
2 | Supported column keys
3 | """
4 |
class Key:
    """
    Base class of column constraints: describes whether a column is indexed and how.
    """

    def __init__(self, column_name: str, *args):
        """
        :param column_name: Column affected by this key
        :param *args: Additional column names affected by the key; pass them to
            create a composite key
        """
        self._column_name = column_name
        # every affected column, the primary one included, without duplicates
        self._column_names = frozenset((column_name,) + args)

    @property
    def column_names(self):
        """frozenset of all columns affected by the key."""
        return self._column_names

    @property
    def column_name(self):
        """The first column passed to the constructor."""
        return self._column_name
23 |
class UniqueIndex(Key):
    """Index whose values must be unique across the affected column(s)."""
26 |
class Index(Key):
    """Plain (non-unique) index on the affected column(s)."""
29 |
class PrimaryKey(Key):
    """Marks a single column as the table's primary key."""

    def __init__(self, column_name: str):
        """
        :param column_name: Column affected by this key (only one column is accepted)
        """
        Key.__init__(self, column_name)
36 |
class ForeignKey(Key):
    """Constraint making a column reference a column of another table."""

    def __init__(self, foreign_column: str, column_name: str = None):
        """
        :param foreign_column: Full foreign column name in ``table.column`` form. Example: SOME_PARENT.id
        :param column_name: Column affected by this key.
        :raises ValueError: if foreign_column does not contain exactly one '.'
        """
        super().__init__(column_name)
        # "table.column" -> table first, column second (previously the two were swapped)
        self._foreign_table_name, self._foreign_column_name = foreign_column.split('.')

    @property
    def foreign_column_name(self):
        """Name of the referenced column."""
        return self._foreign_column_name

    @property
    def foreign_table_name(self):
        """Name of the referenced table."""
        return self._foreign_table_name
--------------------------------------------------------------------------------
/askxml/driver/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamac/AskXML/1a819361031ba82a4b2717301bcaa3855ba1c1a4/askxml/driver/__init__.py
--------------------------------------------------------------------------------
/askxml/driver/driver.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import List, Tuple
3 |
class Driver(ABC):
    """
    Interface implemented by AskXML storage backends.

    A driver is built from an XML source and table definitions, hands out cursors
    for running SQL, and describes the stored data so it can be written back as XML.
    """

    def __init__(self, filename: str, table_definitions):
        """The base implementation stores nothing; subclasses load the document."""

    @abstractmethod
    def get_xml_root(self):
        """Return the document root as a (tag name, attribute mapping) pair."""

    @abstractmethod
    def get_tables(self) -> Tuple[List[str], List[str]]:
        """
        Returns a tuple of (root table names, child table names)
        """

    @abstractmethod
    def create_cursor(self):
        """Return a new cursor for running SQL statements."""

    @abstractmethod
    def close(self):
        """Release everything held by the driver."""
--------------------------------------------------------------------------------
/askxml/driver/sqlite_driver.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, AbstractSet, List, Tuple
2 | from abc import abstractmethod
3 | from askxml import column, table
4 | from .driver import Driver
5 | try:
6 | import lxml.etree as xml
7 | except ModuleNotFoundError:
8 | import xml.etree.cElementTree as xml
9 | import tempfile
10 | import os
11 | import sqlite3
12 |
class EmptyTableException(Exception):
    """Raised when a table would have to be created without any column."""
15 |
class SqliteDriver(Driver):
    """
    Sqlite Driver works by setting up a .sqlite copy of XML document,
    and then running statements on it.
    """

    def __init__(self, source, table_definitions=None, join_name: str = '_parentId', id_name: str = '_id',
                 text_name: str = '_text', in_memory_db: bool = False):
        """
        :param source: Path to .xml file to open, or file handle
        :param table_definitions: A list of Table definitions (indexed by table name internally)
        :param join_name: Name of the column that stores parent's ID
        :param id_name: Name of the column that stores node's ID
        :param text_name: Name of the column that stores node's text
        :param in_memory_db: If set to True, sqlite's database will be stored in RAM rather than as
            a temporary file on disk.
        """
        self.join_name = join_name
        self.id_name = id_name
        if table_definitions:
            # key the definitions by table name, as Converter expects
            table_definitions = {definition.table_name: definition for definition in table_definitions}

        # The generated script holds one SQL statement per line. UTF-8 keeps non-ASCII
        # XML content representable whatever the locale encoding is.
        with tempfile.TemporaryFile(mode='w+', encoding='utf-8') as sql_file:
            self.__converter = Converter(source, sql_file,
                                         table_definitions=table_definitions, text_name=text_name,
                                         join_name=join_name, id_name=id_name)
            sql_file.seek(0)

            if in_memory_db:
                self.db_path = ':memory:'
            else:
                handle, self.db_path = tempfile.mkstemp(suffix='.db')
                os.close(handle)

            self._conn = None
            try:
                self._conn = sqlite3.connect(self.db_path)
                # fill database with data
                cursor = self._conn.cursor()
                try:
                    for query in sql_file:
                        cursor.execute(query)
                finally:
                    cursor.close()
                self._conn.commit()
            except BaseException:
                # don't leak the connection or the temporary database file
                self._discard_database()
                raise

    def _discard_database(self):
        """Close the connection and delete the on-disk database after a failed initialisation."""
        if self._conn is not None:
            self._conn.close()
        if self.db_path != ':memory:' and os.path.exists(self.db_path):
            os.remove(self.db_path)

    def get_xml_root(self):
        """Returns (root tag name, root attributes) of the loaded document."""
        return self.__converter.root_name, self.__converter.root_attrib

    def get_tables(self) -> Tuple[List[str], List[str]]:
        """
        Returns a tuple of (root table names, child table names).

        A table is a child table when it has the join column. SQLite's internal tables
        (names starting with "sqlite_", e.g. sqlite_stat1 created by ANALYZE) are skipped.
        """
        cursor = self.create_cursor()
        try:
            root_tables = []
            child_tables = []
            tables = cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite\\_%' ESCAPE '\\'")
            for (table_name,) in tables.fetchall():
                columns = cursor.execute("PRAGMA table_info('{}')".format(table_name)).fetchall()
                column_names = [c[1] for c in columns]
                if self.join_name in column_names:
                    child_tables.append(table_name)
                else:
                    root_tables.append(table_name)
            return root_tables, child_tables
        finally:
            cursor.close()

    def create_cursor(self):
        """Returns a new sqlite3 cursor on the database."""
        return self._conn.cursor()

    def close(self):
        """Closes the database connection and removes the temporary database file."""
        self._conn.close()
        if self.db_path != ':memory:':
            os.remove(self.db_path)
88 |
class Converter:
    """
    Converts an XML document into an SQLite script.

    SqliteDriver executes the generated script one line at a time, so every
    statement written here must fit on a single line.
    """

    def __init__(self, source, outfile, table_definitions: Dict[str, table.Table] = None,
                 text_name: str = None, join_name: str = None, id_name: str = None):
        """
        Converts an XML file to a .sqlite script

        :param source: Path to .xml file to open, or file handle
        :param outfile: File handle to where resulting .sql will be stored
        :param table_definitions: A dict of table name as keys table definitions as values.
            The Table objects are extended in place with the generated meta columns.
        :param text_name: Name of the column that stores node's text. Set to None to not store text.
        :param join_name: Name of the column that stores parent's ID. Set to None to not join.
        :param id_name: Name of the column that stores node's ID. Set to None to not generate an ID.
        :raises EmptyTableException: if a table would be created without any column
        """
        if not table_definitions:
            table_definitions = {}

        self.table_definitions = table_definitions
        self.join_name = join_name
        self.id_name = id_name
        self.text_name = text_name
        # a pair of table_name : next free id
        self.id_cache: Dict[str, int] = {}
        # a set of table names which tells us in a quick way
        # whether we've already altered a table with id, join id and text column
        self._generated_meta_columns_cache: AbstractSet[str] = set()

        self.xmliter = xml.iterparse(source, events=("start", "end"))
        _, root = next(self.xmliter)
        self.root_name = root.tag
        # Copy the attributes: the root element is clear()-ed once parsing is done,
        # which resets its attributes too (with lxml, attrib is a live view).
        self.root_attrib = dict(root.attrib)
        # a dict that holds all found tables and their columns
        self.tables: Dict[str, AbstractSet[str]] = {}
        # update found tables with user defined table definitions
        for table_name, table_definition in table_definitions.items():
            # also generate join keys for predefined table
            self.__generate_table_meta_columns(table_name)
            self.tables[table_name] = {column_definition.column_name
                                       for column_definition in table_definition.column_definitions}

        # generate insert queries in a temporary file
        self.inserts_file = tempfile.TemporaryFile(mode='w+', encoding='utf-8')
        try:
            self.__parse_node(root, None)
            self.inserts_file.seek(0)

            # generate create table statements; indexes are created after the data is inserted
            constraint_definitions = []
            for table_name, columns in self.tables.items():
                column_definitions, index_definitions = self.__build_table(table_name, columns)
                constraint_definitions.extend(index_definitions)
                outfile.write('CREATE TABLE {} ({});\n'.format(table_name, ','.join(column_definitions)))

            # copy insert queries from temp file to out file
            for line in self.inserts_file:
                outfile.write(line)

            # generate constraints
            for constraint in constraint_definitions:
                outfile.write(constraint + '\n')
        finally:
            self.inserts_file.close()

    @staticmethod
    def __references(foreign_key):
        """SQL suffix that makes a column reference foreign_key's target column."""
        return ' REFERENCES {}({}) DEFERRABLE INITIALLY DEFERRED'.format(
            foreign_key.foreign_table_name,
            foreign_key.foreign_column_name
        )

    def __build_table(self, table_name, columns):
        """
        Build the column definitions and the CREATE INDEX statements of one table.

        :param table_name: Name of the table
        :param columns: Names of the table's columns
        :return: tuple (list of "name TYPE [constraints]" strings, list of CREATE INDEX statements)
        :raises EmptyTableException: if columns is empty
        """
        table_definition = self.table_definitions.get(table_name, None)

        # create column parameters (column_name column_type [key])
        column_definitions = []
        for column_name in columns:
            column_info = None
            if table_definition is not None:
                try:
                    column_info = table_definition.get_column(column_name)
                except KeyError:
                    pass
            if column_info is None:
                column_info = column.Column.create_default(column_name)

            column_definition = column_name + ' ' + str(column_info.data_type)
            if column_info.foreign_key:
                column_definition += self.__references(column_info.foreign_key)
            column_definitions.append(column_definition)

        if not column_definitions:
            raise EmptyTableException("SQLite cannot create an empty table '{}'".format(table_name))

        index_definitions = []
        if table_definition:
            for constraint in table_definition.constraint_definitions:
                if isinstance(constraint, (column.UniqueIndex, column.Index)):
                    # sorted: column_names is a frozenset, so fix a deterministic order
                    index_columns = sorted(constraint.column_names)
                    # index names share one namespace per database; include the table
                    # name so equal indexes on different tables don't collide
                    index_name = '_'.join([table_name] + index_columns) + '_index'
                    unique_sql = 'UNIQUE ' if isinstance(constraint, column.UniqueIndex) else ''
                    index_definitions.append('CREATE {}INDEX {} ON {} ({})'.format(
                        unique_sql,
                        index_name,
                        table_name,
                        ','.join(index_columns)))
                elif isinstance(constraint, column.ForeignKey):
                    for i, definition in enumerate(column_definitions):
                        if definition.startswith(constraint.column_name + ' '):
                            column_definitions[i] = definition + self.__references(constraint)
                elif isinstance(constraint, column.PrimaryKey):
                    for i, definition in enumerate(column_definitions):
                        if definition.startswith(constraint.column_name + ' '):
                            column_definitions[i] = definition + ' PRIMARY KEY'
                            break
        return column_definitions, index_definitions

    def __generate_table_meta_columns(self, table_name):
        """
        Add the id, join id and text columns to the definition of table_name.

        Columns and the id primary key that the definition already has are not added
        again, so a Table object reused for another conversion stays valid.
        """
        self._generated_meta_columns_cache.add(table_name)
        table_definition = self.table_definitions[table_name]
        defined_columns = {c.column_name for c in table_definition.column_definitions}
        if self.id_name:
            # generate an id column
            if self.id_name not in defined_columns:
                table_definition.column_definitions.append(column.Column(self.id_name, column.Integer()))
            if not any(isinstance(key, column.PrimaryKey) and key.column_name == self.id_name
                       for key in table_definition.constraint_definitions):
                table_definition.constraint_definitions.append(column.PrimaryKey(self.id_name))
        # generate a join ID column referencing the parent table ("a_b" -> "a")
        parent_separator = table_name.rfind('_')
        if self.join_name and self.id_name and parent_separator > -1 \
                and self.join_name not in defined_columns:
            table_definition.column_definitions.append(column.Column(
                self.join_name, column.Integer(),
                column.ForeignKey(table_name[:parent_separator] + '.' + self.id_name)))
        # generate a text column
        if self.text_name and self.text_name not in defined_columns:
            table_definition.column_definitions.append(column.Column(self.text_name, column.Text()))

    @staticmethod
    def _sql_literal(value) -> str:
        """
        Render value as a single-line SQL string literal.

        Single quotes are doubled; CR and LF are spliced in with char() so that the
        statement stays on one line and the stored text keeps its line breaks.
        """
        escaped = str(value).replace("'", "''")
        escaped = escaped.replace('\r', "' || char(13) || '").replace('\n', "' || char(10) || '")
        return "'" + escaped + "'"

    def __write_insert(self, table_name, attributes):
        """Append one INSERT statement storing attributes as a row of table_name."""
        column_names = list(attributes.keys())
        if not column_names:
            self.inserts_file.write('INSERT INTO {} DEFAULT VALUES;\n'.format(table_name))
            return
        # Every value is quoted: SQLite's column affinity turns numeric text stored in
        # INTEGER/REAL columns into numbers, and quoting stops attribute values from
        # being parsed as SQL.
        self.inserts_file.write('INSERT INTO {} ({}) VALUES ({});\n'.format(
            table_name,
            ','.join(column_names),
            ','.join(self._sql_literal(attributes[name]) for name in column_names)
        ))

    def __parse_node(self, node, prev_node_attrib, table_scope=''):
        """
        Write INSERT statements for node and all of its descendants.

        Must be called right after node's "start" event was read from self.xmliter;
        returns once node's "end" event has been consumed.

        :param node: Element whose "start" event was just read
        :param prev_node_attrib: Attributes of the parent element, or None for the document
            root (the root itself is not stored as a row)
        :param table_scope: Parent table name followed by '_', or '' below the root
        """
        attributes = node.attrib
        table_name = None
        if prev_node_attrib is not None:
            table_name = table_scope + node.tag
            # if table was not defined, define it
            if not self.table_definitions.get(table_name, None):
                self.table_definitions[table_name] = table.Table(table_name)
            # register now so tables keep the order in which they first appear
            self.tables.setdefault(table_name, set())

            # the id must be set before the children are parsed, as they copy it
            if self.id_name:
                next_id = self.id_cache.get(table_name, 1)
                attributes[self.id_name] = str(next_id)
                self.id_cache[table_name] = next_id + 1

            if self.join_name and self.id_name and self.id_name in prev_node_attrib:
                attributes[self.join_name] = prev_node_attrib[self.id_name]

            # update table definition with meta columns if needed
            if self.id_name or self.join_name or self.text_name:
                if table_name not in self._generated_meta_columns_cache:
                    self.__generate_table_meta_columns(table_name)

            table_scope = table_name + '_'

        # parse child nodes until this node's own end event
        while True:
            event, child_node = next(self.xmliter)
            if event == 'end' and child_node.tag == node.tag:
                break
            self.__parse_node(child_node, attributes, table_scope=table_scope)

        if table_name is not None:
            # iterparse only guarantees an element's text once its "end" event has
            # been seen, so the text is read here rather than at the "start" event
            stripped_text = node.text.strip() if node.text else None
            if self.text_name and stripped_text:
                attributes[self.text_name] = stripped_text
            self.__write_insert(table_name, attributes)
            self.tables[table_name].update(attributes.keys())

        # prevent eating up too much memory
        node.clear()
272 |
--------------------------------------------------------------------------------
/askxml/table.py:
--------------------------------------------------------------------------------
1 | from .column import *
2 |
class Table:
    """
    Definition of one SQL table, i.e. one XML tag at a given position in the tag hierarchy.
    """

    def __init__(self, table_name: str, *args):
        """
        :param table_name: XML tag preceded with parent tags separated by underscore.
            eg. Parent_Child_Child
        :param *args: A list of column and constraint definitions. Column objects go to
            column_definitions, Key objects to constraint_definitions; other values are ignored.
        """
        self._table_name = table_name
        self.column_definitions = []
        self.constraint_definitions = []
        for arg in args:
            if isinstance(arg, Column):
                self.column_definitions.append(arg)
            elif isinstance(arg, Key):
                self.constraint_definitions.append(arg)

    def get_column(self, column_name: str):
        """
        Returns the first column definition with the given name.

        :param column_name: Name of the column to retrieve
        :raises KeyError: if the column wasn't defined; the exception carries column_name
        """
        for column in self.column_definitions:
            if column.column_name == column_name:
                return column
        raise KeyError(column_name)

    @property
    def table_name(self):
        """Name of the table."""
        return self._table_name
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | from os import path
3 |
here = path.abspath(path.dirname(__file__))

# Get the long description from the README file
with open(path.join(here, 'README.md'), encoding='utf-8') as f:
    long_description = f.read()

setup(
    name='askxml',
    version='1.0.0',
    description='Run SQL statements on XML documents',
    long_description=long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/kamac/AskXML',
    author='Maciej Kozik',
    classifiers=[
        'Development Status :: 5 - Production/Stable',
        'Intended Audience :: Developers',
        'Topic :: Text Processing :: Markup :: XML',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
    ],
    keywords='xml sql statements query',
    packages=find_packages(exclude=['contrib', 'docs', 'tests']),
    # askxml/driver/sqlite_driver.py uses variable annotations (PEP 526) and
    # ModuleNotFoundError, both of which first appeared in Python 3.6.
    python_requires='>=3.6',
)
--------------------------------------------------------------------------------
/tests/README.md:
--------------------------------------------------------------------------------
1 | ### Running a single test
2 |
3 | To run a single test, do
4 |
5 | ```
6 | python test_name.py
7 | ```
8 |
9 | ### Running test suite
10 |
11 | To run all tests, do
12 |
13 | ```
14 | python -m unittest
15 | ```
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamac/AskXML/1a819361031ba82a4b2717301bcaa3855ba1c1a4/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_sqlite_driver.py:
--------------------------------------------------------------------------------
1 | from askxml.driver.sqlite_driver import SqliteDriver
2 | from askxml.table import Table
3 | from askxml.column import *
4 | import tempfile
5 | import unittest
6 |
# Two root-level tables (RootTable, RootTableSecond) and one child table
# (RootTable_Child) whose first row has text and whose second row has an attribute.
_xml_file_simple = """<XML>
    <RootTable first="1" second="2">
        <Child>Hello</Child>
        <Child third="3" />
    </RootTable>
    <RootTableSecond>Hi</RootTableSecond>
</XML>
"""

# A document whose root has no child elements, hence no tables.
_xml_file_no_children = "<XML></XML>"
18 |
class TestSqliteDriver(unittest.TestCase):
    """Tests for SqliteDriver: table discovery and mapping of XML nodes to rows."""

    def _open_driver(self, xml_text, **driver_kwargs):
        """Return a SqliteDriver whose source is a temp file holding *xml_text*.

        Extra keyword arguments are forwarded to SqliteDriver. Both the temp
        file and the driver are registered with addCleanup, so they are closed
        when the test ends even if an assertion fails. Cleanups run
        last-in-first-out, so the driver is closed before its source file,
        matching the original close order.
        """
        source = tempfile.SpooledTemporaryFile(mode='w+')
        self.addCleanup(source.close)
        source.write(xml_text)
        source.seek(0)
        driver = SqliteDriver(source=source, **driver_kwargs)
        self.addCleanup(driver.close)
        return driver

    def _open_cursor(self, driver):
        """Return a cursor from *driver* that is closed when the test ends."""
        cursor = driver.create_cursor()
        # Registered after the driver's cleanup, so it runs before driver.close().
        self.addCleanup(cursor.close)
        return cursor

    def test_get_tables(self):
        with self.subTest('simple xml file'):
            driver = self._open_driver(_xml_file_simple)
            root_tables, child_tables = driver.get_tables()
            self.assertIn('RootTable', root_tables)
            self.assertIn('RootTableSecond', root_tables)
            self.assertEqual(len(root_tables), 2)
            self.assertIn('RootTable_Child', child_tables)
            self.assertEqual(len(child_tables), 1)

        with self.subTest('xml file without any tables'):
            driver = self._open_driver(_xml_file_no_children)
            root_tables, child_tables = driver.get_tables()
            self.assertEqual(len(root_tables), 0)
            self.assertEqual(len(child_tables), 0)

        with self.subTest('xml file with no nodes, but defined tables'):
            driver = self._open_driver(
                _xml_file_no_children, table_definitions=[Table('sometable')])
            root_tables, child_tables = driver.get_tables()
            self.assertIn('sometable', root_tables)
            self.assertEqual(len(root_tables), 1)
            self.assertEqual(len(child_tables), 0)

    def test_tables_have_attributes(self):
        driver = self._open_driver(_xml_file_simple)
        cursor = self._open_cursor(driver)
        result = cursor.execute("SELECT first, second FROM RootTable WHERE _id=1").fetchall()
        self.assertEqual(len(result), 1)
        self.assertEqual(result[0][0], '1')
        self.assertEqual(result[0][1], '2')
        # ORDER BY is required: without it SQL row order is undefined, and the
        # positional assertions below would be flaky.
        result = cursor.execute("SELECT third FROM RootTable_Child ORDER BY _id ASC").fetchall()
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0][0], None)
        self.assertEqual(result[1][0], '3')

    def test_attributes_have_correct_types(self):
        driver = self._open_driver(
            _xml_file_simple,
            table_definitions=[Table('RootTable', Column('first', Integer()))])
        cursor = self._open_cursor(driver)
        result = cursor.execute("SELECT first, second FROM RootTable WHERE _id=1").fetchall()
        self.assertEqual(len(result), 1)
        # 'first' is declared Integer, so it comes back as an int; 'second' is
        # undeclared and stays a string.
        self.assertEqual(result[0][0], 1)
        self.assertEqual(result[0][1], '2')

    def test_tables_have_text(self):
        driver = self._open_driver(_xml_file_simple)
        cursor = self._open_cursor(driver)
        result = cursor.execute("SELECT _text FROM RootTable_Child ORDER BY _id ASC").fetchall()
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0][0], 'Hello')
        self.assertEqual(result[1][0], None)

    def test_tables_are_joining(self):
        driver = self._open_driver(_xml_file_simple)
        cursor = self._open_cursor(driver)
        # Count the Child rows linked to each RootTable row through _parentId.
        result = cursor.execute("""SELECT (SELECT COUNT(*) FROM RootTable_Child AS c WHERE c._parentId = r._id)
            FROM RootTable AS r
            ORDER BY r._id ASC""").fetchall()
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0][0], 2)
        self.assertEqual(result[1][0], 0)


if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------
/tests/test_synchronize.py:
--------------------------------------------------------------------------------
1 | from askxml import *
2 | import xml.etree.ElementTree as ET
3 | import tempfile
4 | import unittest
5 |
# Fixture document written to a temp file and round-tripped through AskXML:
# - the first RootTable has attributes first/second and two Child nodes
#   (one with text only, one with attribute `third` only);
# - the second RootTable is empty;
# - RootTableSecond has text only.
_xml_file_simple = """
<XML>
    <RootTable first="1" second="2">
        <Child>Hello</Child>
        <Child third="3"/>
    </RootTable>
    <RootTable/>
    <RootTableSecond>Hi</RootTableSecond>
</XML>"""
15 |
class TestSynchronize(unittest.TestCase):
    """Tests that closing an AskXML instance leaves the document structure intact."""

    def test_preserves_structure(self):
        with tempfile.SpooledTemporaryFile(mode='w+') as f:
            f.write(_xml_file_simple)
            f.seek(0)
            # The assertions below rely on close() having written the
            # document back into f.
            AskXML(f).close()
            f.seek(0)
            root = ET.parse(f).getroot()

        self.assertEqual(root.tag, 'XML')
        self.assertFalse(root.attrib)

        # inspect root node
        children = list(root)
        children_tags = [c.tag for c in children]
        self.assertIn('RootTable', children_tags)
        self.assertIn('RootTableSecond', children_tags)
        self.assertEqual(len(children), 3)

        # Materialize and count, so a missing RootTable fails the assertion
        # instead of raising StopIteration from next() as a test error.
        root_tables = [c for c in children if c.tag == 'RootTable']
        self.assertEqual(len(root_tables), 2)
        first_root_table, second_root_table = root_tables

        # inspect first RootTable: exactly two Child nodes
        first_children = list(first_root_table)
        self.assertEqual([c.tag for c in first_children], ['Child', 'Child'])
        text_child, attr_child = first_children
        # make sure Child tags don't have any child nodes
        self.assertEqual(len(text_child), 0)
        self.assertEqual(len(attr_child), 0)
        self.assertFalse(text_child.attrib)
        self.assertEqual(text_child.text, 'Hello')
        self.assertEqual(attr_child.attrib, {'third': '3'})
        self.assertFalse(attr_child.text)

        # inspect second RootTable
        self.assertEqual(len(second_root_table), 0)
        self.assertFalse(second_root_table.attrib)

        # inspect RootTableSecond (presence already asserted above)
        root_table_second = next(c for c in children if c.tag == 'RootTableSecond')
        self.assertEqual(len(root_table_second), 0)
        self.assertFalse(root_table_second.attrib)
        self.assertEqual(root_table_second.text, 'Hi')


if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------