├── .gitignore ├── LICENSE ├── README.md ├── app-tests ├── node │ ├── __init__.py │ ├── app.py │ ├── controller.py │ └── models.py └── server │ ├── __init__.py │ ├── app.py │ └── models.py ├── dbsync ├── __init__.py ├── client │ ├── __init__.py │ ├── compression.py │ ├── conflicts.py │ ├── net.py │ ├── ping.py │ ├── pull.py │ ├── push.py │ ├── register.py │ ├── repair.py │ ├── serverquery.py │ └── tracking.py ├── core.py ├── dialects.py ├── lang.py ├── logs.py ├── messages │ ├── __init__.py │ ├── base.py │ ├── codecs.py │ ├── pull.py │ ├── push.py │ └── register.py ├── models.py ├── server │ ├── __init__.py │ ├── conflicts.py │ ├── handlers.py │ ├── tracking.py │ └── trim.py └── utils.py ├── dev-requirements.txt ├── diagram.gv ├── diagram.png ├── merge.md ├── requirements.txt ├── setup.py └── tests ├── __init__.py ├── codec_tests.py ├── conflict_detection_tests.py ├── models.py ├── pull_message_tests.py ├── push_message_tests.py └── track_tests.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Emacs 2 | *~ 3 | \#*\# 4 | 5 | # Python 6 | *.pyc 7 | 8 | # SQLite databases 9 | *.db -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright © 2015 Bint 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | “Software”), to deal in the Software without restriction, including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 
13 | 14 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | dbsync 2 | ====== 3 | 4 | A python library for centralized database synchronization, built over 5 | [SQLAlchemy's ORM](http://docs.sqlalchemy.org/en/latest/orm/tutorial.html). 6 | 7 | The library aims to enable applications to function offline when their 8 | internet connection is lost, by using a local database and providing a 9 | few synchronization procedures: `pull`, `push`, `register` and 10 | `repair`. 11 | 12 | This library is currently undergoing testing in a real application. 13 | 14 | ## Restrictions ## 15 | 16 | To work properly however, the library requires that several 17 | restrictions be met: 18 | 19 | - All primary keys must be integer values of no importance to the 20 | logic of the application. If using SQLite, these would be INTEGER 21 | PRIMARY KEY AUTOINCREMENT fields. 22 | 23 | - All primary keys must be unique to the row through the history of 24 | the table. This means the reuse of primary keys is likely to cause 25 | problems. In SQLite this behaviour can be achieved by changing the 26 | default PRIMARY KEY algorithm to AUTOINCREMENT (the 27 | `sqlite_autoincrement` must be set to `True` for the table in 28 | SQLAlchemy as specified 29 | [here](http://docs.sqlalchemy.org/en/rel_0_8/dialects/sqlite.html#auto-incrementing-behavior)). 
30 | 31 | - All synched tables should be wrapped with a mapped class. This 32 | includes many-to-many tables. This restriction should hopefully be 33 | lifted in the future, though for now you should follow this 34 | [suggested pattern](http://docs.sqlalchemy.org/en/rel_0_8/orm/relationships.html#association-object). 35 | 36 | - Push and pull client-side synchronization procedures can't be 37 | invoked parallel to other transactions. This is consequence of bad 38 | design and should change in the future. For now, don't invoke these 39 | in a seperate thread if other transactions might run concurrently, 40 | or the transaction won't be registered correctly. 41 | 42 | ## Explanation ## 43 | 44 | Dbsync works by registering database operations (insert, update, 45 | delete) in seperate tables. These are detected through the SQLAlchemy 46 | event interface, and form a kind of operations log. 47 | 48 | The synchronization process starts with the `push` procedure. In it, 49 | the client application builds a message containing only the required 50 | database objects, deciding which to include according to the 51 | operations log, and sends it to the server to execute. If the server 52 | allows the `push`, both the client and the server databases become 53 | equivalent and the process is complete. 54 | 55 | The `push` won't be allowed by the server if it's database has 56 | advanced further since the last synchronization. If the `push` is 57 | rejected, the client should execute the `pull` procedure. The `pull` 58 | will fetch all operations executed on the server since the divergence 59 | point, and will merge those with the client's operation log. This 60 | merge operation executes internally and includes the conflict 61 | resolution phase, which ideally will resolve the potential operation 62 | collisions ([further explanation][merge-subroutine] of the merge 63 | subroutine). 
64 | 65 | [merge-subroutine]: https://github.com/bintlabs/python-sync-db/blob/master/merge.md 66 | 67 | If the `pull` procedure completes successfully, the client application 68 | may attempt another `push`, as shown by the cycle in the diagram 69 | below. 70 | 71 | ![Synchronization sequence](https://raw.github.com/bintlabs/python-sync-db/master/diagram.png) 72 | 73 | ### Additional procedures ### 74 | 75 | #### Registering nodes #### 76 | 77 | The `register` procedure exists to provide a mechanism for nodes to be 78 | identified by the server. A node may request it's registration through 79 | the `register` procedure, and if accepted, it will receive a set of 80 | credentials. 81 | 82 | These credentials are used (as of this revision) to sign the `push` 83 | message sent by the node, since it's the only procedure that can 84 | potentially destroy data on the server. 85 | 86 | Other procedures should also be protected by the programmer (e.g. to 87 | prevent theft), but that is her/his responsibility. Synchronization 88 | procedures usually allow the inclusion of user-set data, which can be 89 | checked on the server for authenticity. Also, the HTTPS protocol may 90 | be used by prepending the 'https://' prefix to the URL for each 91 | procedure. 92 | 93 | #### Repairing the client's database #### 94 | 95 | The `repair` procedure exists to allow the client application's 96 | database to recover from otherwise stale states. Such a state should 97 | in theory be impossible to reach, but external database intervention, 98 | or poor conflict resolution by this library (which will be monitored 99 | in private testing), might result in achieving it. 100 | 101 | The `repair` just fetches the entire server database, serialized as 102 | JSON, and then replaces the current one with it. Since it's meant to 103 | be used to fix infrequent errors, and might take a long time to 104 | complete, it should not be used recurrently. 
105 | 106 | ### Example ### 107 | 108 | First, give the library a SQLAlchemy engine to access the database. On 109 | the client application, the current tested database is SQLite. 110 | 111 | ```python 112 | from sqlalchemy import create_engine 113 | import dbsync 114 | 115 | engine = create_engine("sqlite:///storage.db") # sample database URL 116 | 117 | dbsync.set_engine(engine) 118 | ``` 119 | 120 | If you don't do this, the library will complain as soon as you attempt 121 | an operation. 122 | 123 | Next, start tracking your operations to fill the opertions log. Use 124 | the `dbsync.client.track` or the `dbsync.server.track` depending on 125 | your application. Don't import both `dbsync.client` and 126 | `dbsync.server`. 127 | 128 | ```python 129 | from sqlalchemy import Column, Integer, String 130 | from sqlalchemy.orm import relationship 131 | from sqlalchemy.ext.declarative import declarative_base 132 | 133 | from dbsync import client 134 | 135 | 136 | Base = declarative_base() 137 | Base.__table_args__ = {'sqlite_autoincrement': True,} # important 138 | 139 | 140 | @client.track 141 | class City(Base): 142 | 143 | __tablename__ = "city" 144 | 145 | id = Column(Integer, primary_key=True) # doesn't have to be called 'id' 146 | name = Column(String(100)) 147 | 148 | 149 | @client.track 150 | class Person(Base): 151 | 152 | __tablename__ = "person" 153 | 154 | id = Column(Integer, primary_key=True) 155 | name = Column(String(100)) 156 | city_id = Column(Integer, ForeignKey("city.id")) 157 | 158 | city = relationship(City, backref="persons") 159 | ``` 160 | 161 | After you've marked all the models you want tracked, you need to 162 | generate the logging infrastructure, explicitly. You can do this once, 163 | or every time the application is started, since it's idempotent. 164 | 165 | ```python 166 | import dbsync 167 | 168 | dbsync.create_all() 169 | ``` 170 | 171 | Next you should register your client application in the server. 
To do 172 | this, use the `register` procedure: 173 | 174 | ```python 175 | from dbsync import client 176 | 177 | client.register(REGISTER_URL) 178 | ``` 179 | 180 | Where `REGISTER_URL` is the URL pointing to the register handler on 181 | the server. More on this below. 182 | 183 | You can register the client application just once, or check whenever 184 | you wish with the `isregistered` predicate. 185 | 186 | ```python 187 | from dbsync import client 188 | 189 | if not client.isregistered(): 190 | client.register(REGISTER_URL) 191 | ``` 192 | 193 | Now you're ready to try synchronization procedures. If the server is 194 | configured correctly (as shown further below), an acceptable 195 | synchronization cycle could be: 196 | 197 | ```python 198 | from dbsync import client 199 | 200 | 201 | def synchronize(push_url, pull_url, tries): 202 | for _ in range(tries): 203 | try: 204 | return client.push(push_url) 205 | except client.PushRejected: 206 | try: 207 | client.pull(pull_url) 208 | except client.UniqueConstraintError as e: 209 | for model, pk, columns in e.entries: 210 | pass # handle exception 211 | raise Exception("push rejected %d times" % tries) 212 | ``` 213 | 214 | You may catch the different exceptions and react accordingly, since 215 | they can indicate lack of internet connection, integrity conflicts, or 216 | dbsync configuration problems. 217 | 218 | #### Server side #### 219 | 220 | First of all, instead of importing `dbsync.client`, import 221 | `dbsync.server`. So, to track a model: 222 | 223 | ```python 224 | from dbsync import server 225 | 226 | @server.track 227 | class Person(Base): 228 | # ... 229 | ``` 230 | 231 | Then, listen to five distinct URLs: 232 | 233 | - One for the `repair` procedure, listening GETs. 234 | - One for the `register` procedure, listening POSTs. 235 | - One for the `pull` procedure, listening POSTs. 236 | - One for the `push` procedure, listening POSTs. 
237 | - One (optional) for the `query` procedure (for remote queries), 238 | listening GETs. 239 | 240 | These handlers should return JSON and use the dbsync handlers. For 241 | example, using [Flask](http://flask.pocoo.org/): 242 | 243 | ```python 244 | import json 245 | from flask import Flask, request 246 | from dbsync import server 247 | 248 | 249 | app = Flask(__name__) 250 | 251 | 252 | @app.route("/repair", methods=["GET"]) 253 | def repair(): 254 | return (json.dumps(server.handle_repair(request.args)), 255 | 200, 256 | {"Content-Type": "application/json"}) 257 | 258 | 259 | @app.route("/register", methods=["POST"]) 260 | def register(): 261 | return (json.dumps(server.handle_register()), 262 | 200, 263 | {"Content-Type": "application/json"}) 264 | 265 | 266 | @app.route("/pull", methods=["POST"]) 267 | def pull(): 268 | return (json.dumps(server.handle_pull_request(request.json)), 269 | 200, 270 | {"Content-Type": "application/json"}) 271 | 272 | 273 | @app.route("/push", methods=["POST"]) 274 | def push(): 275 | try: 276 | return (json.dumps(server.handle_push(request.json)), 277 | 200, 278 | {"Content-Type": "application/json"}) 279 | except server.handlers.PushRejected as e: 280 | return (json.dumps({'error': [repr(arg) for arg in e.args]}), 281 | 400, 282 | {"Content-Type": "application/json"}) 283 | 284 | 285 | @app.route("/query", methods=["GET"]) 286 | def query(): 287 | return (json.dumps(server.handle_query(request.args)), 288 | 200, 289 | {"Content-Type": "application/json"}) 290 | ``` 291 | 292 | Messages to the server usually contain additional user-set data, to 293 | allow for extra checks and custom protection. You can access these 294 | through `request.json.extra_data` when JSON is expected. 
295 | -------------------------------------------------------------------------------- /app-tests/node/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bintlabs/python-sync-db/bb23d77abf560793696f906e030950aec04c3361/app-tests/node/__init__.py -------------------------------------------------------------------------------- /app-tests/node/app.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | cwd = u"".join(reversed(os.getcwd())) 4 | test_dir = "tset-ppa" 5 | try: 6 | base_dir = "".join(reversed(cwd[cwd.index(test_dir) + len(test_dir):])) 7 | except ValueError: 8 | base_dir = os.getcwd() 9 | 10 | if base_dir not in sys.path: 11 | sys.path.append(base_dir) 12 | 13 | 14 | import controller as crud 15 | from dbsync import client 16 | -------------------------------------------------------------------------------- /app-tests/node/controller.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | from sqlalchemy.orm import sessionmaker 4 | 5 | from models import City, House, Person, engine 6 | from dbsync import models 7 | 8 | 9 | Session = sessionmaker(bind=engine) 10 | 11 | 12 | # CRUD cities 13 | 14 | def create_city(**kwargs): 15 | session = Session() 16 | city = City() 17 | for k, v in kwargs.iteritems(): 18 | setattr(city, k, v) 19 | session.add(city) 20 | session.commit() 21 | 22 | 23 | def update_city(id=None, **kwargs): 24 | session = Session() 25 | city = session.query(City).filter(City.id == id).one() 26 | for k, v in kwargs.iteritems(): 27 | setattr(city, k, v) 28 | session.commit() 29 | 30 | 31 | def delete_city(id=None): 32 | session = Session() 33 | city = session.query(City).filter(City.id == id).one() 34 | session.delete(city) 35 | session.commit() 36 | 37 | 38 | def read_cities(): 39 | session = Session() 40 | for city in session.query(City): print city 41 | 
42 | 43 | # CRUD houses 44 | 45 | def create_house(**kwargs): 46 | session = Session() 47 | house = House() 48 | for k, v in kwargs.iteritems(): 49 | setattr(house, k, v) 50 | session.add(house) 51 | session.commit() 52 | 53 | 54 | def update_house(id=None, **kwargs): 55 | session = Session() 56 | house = session.query(House).filter(House.id == id).one() 57 | for k, v in kwargs.iteritems(): 58 | setattr(house, k, v) 59 | session.commit() 60 | 61 | 62 | def delete_house(id=None): 63 | session = Session() 64 | house = session.query(House).filter(House.id == id).one() 65 | session.delete(house) 66 | session.commit() 67 | 68 | 69 | def read_houses(): 70 | session = Session() 71 | for house in session.query(House): print house 72 | 73 | 74 | # CRUD persons 75 | 76 | def create_person(**kwargs): 77 | session = Session() 78 | person = Person() 79 | for k, v in kwargs.iteritems(): 80 | setattr(person, k, v) 81 | session.add(person) 82 | session.commit() 83 | 84 | 85 | def update_person(id=None, **kwargs): 86 | session = Session() 87 | person = session.query(Person).filter(Person.id == id).one() 88 | for k, v in kwargs.iteritems(): 89 | setattr(person, k, v) 90 | session.commit() 91 | 92 | 93 | def delete_person(id=None): 94 | session = Session() 95 | person = session.query(Person).filter(Person.id == id).one() 96 | session.delete(person) 97 | session.commit() 98 | 99 | 100 | def read_persons(): 101 | session = Session() 102 | for person in session.query(Person): print person 103 | 104 | 105 | # Synch 106 | 107 | def read_content_types(): 108 | session = Session() 109 | for ct in session.query(models.ContentType): print ct 110 | 111 | 112 | def read_versions(): 113 | session = Session() 114 | for version in session.query(models.Version): print version 115 | 116 | 117 | def read_operations(): 118 | session = Session() 119 | for op in session.query(models.Operation): print op 120 | 121 | 122 | def read_nodes(): 123 | session = Session() 124 | for node in 
session.query(models.Node): print node 125 | -------------------------------------------------------------------------------- /app-tests/node/models.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, Integer, String, Date, ForeignKey, create_engine 2 | from sqlalchemy.orm import relationship, backref 3 | from sqlalchemy.ext.declarative import declarative_base 4 | from sqlalchemy.schema import UniqueConstraint 5 | 6 | import dbsync 7 | from dbsync import client 8 | 9 | 10 | engine = create_engine("sqlite:///node.db", echo=True) 11 | 12 | 13 | dbsync.set_engine(engine) 14 | 15 | 16 | Base = declarative_base() 17 | Base.__table_args__ = {'sqlite_autoincrement': True,} 18 | 19 | 20 | @client.track 21 | class City(Base): 22 | 23 | __tablename__ = "city" 24 | 25 | id = Column(Integer, primary_key=True) 26 | name = Column(String) 27 | 28 | def __repr__(self): 29 | return u"".format(self.id, self.name) 30 | 31 | name_pool = ["foo", "bar", "baz"] 32 | 33 | def load_extra(city): 34 | return "-".join(name_pool) + "-" + city.name 35 | 36 | def save_extra(city, data): 37 | print city.name, data 38 | 39 | client.extend(City, "extra", String, load_extra, save_extra) 40 | 41 | 42 | @client.track 43 | class House(Base): 44 | 45 | __tablename__ = "house" 46 | 47 | id = Column(Integer, primary_key=True) 48 | address = Column(String, unique=True) 49 | city_id = Column(Integer, ForeignKey("city.id")) 50 | 51 | city = relationship( 52 | City, backref=backref("houses", cascade='all, delete-orphan')) 53 | 54 | def __repr__(self): 55 | return u"".format( 56 | self.id, self.address, self.city_id) 57 | 58 | 59 | @client.track 60 | class Person(Base): 61 | 62 | __tablename__ = "person" 63 | 64 | __table_args__ = (UniqueConstraint('first_name', 'last_name'), 65 | Base.__table_args__) 66 | 67 | id = Column(Integer, primary_key=True) 68 | first_name = Column(String) 69 | last_name = Column(String) 70 | house_id = 
Column(Integer, ForeignKey("house.id")) 71 | birth_city_id = Column(Integer, ForeignKey("city.id")) 72 | birth_date = Column(Date) 73 | email = Column(String) 74 | 75 | house = relationship(House, backref="persons") 76 | birth_city = relationship(City) 77 | 78 | def __repr__(self): 79 | return u"".\ 80 | format(self.first_name, 81 | self.last_name, 82 | self.house_id, 83 | self.birth_city_id) 84 | 85 | 86 | Base.metadata.create_all(engine) 87 | dbsync.create_all() 88 | dbsync.generate_content_types() 89 | -------------------------------------------------------------------------------- /app-tests/server/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bintlabs/python-sync-db/bb23d77abf560793696f906e030950aec04c3361/app-tests/server/__init__.py -------------------------------------------------------------------------------- /app-tests/server/app.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | cwd = u"".join(reversed(os.getcwd())) 4 | test_dir = "stset-ppa" 5 | try: 6 | base_dir = "".join(reversed(cwd[cwd.index(test_dir) + len(test_dir):])) 7 | except ValueError: 8 | base_dir = os.getcwd() 9 | 10 | if base_dir not in sys.path: 11 | sys.path.append(base_dir) 12 | 13 | import json 14 | from flask import Flask, request 15 | import models 16 | 17 | from dbsync import models as synchmodels, server 18 | 19 | 20 | app = Flask(__name__) 21 | 22 | 23 | def enc(string): 24 | table = {"<": "<", 25 | ">": ">"} 26 | return u"".join(table.get(c, c) for c in string) 27 | 28 | 29 | @app.route("/") 30 | def root(): 31 | return 'Ping: any method /ping
'\ 32 | 'Repair: GET /repair<br/>
'\ 33 | 'Register: POST /register<br/>
'\ 34 | 'Pull: POST /pull<br/>
'\ 35 | 'Push: POST /push<br/>
'\ 36 | 'Query: GET /query<br/>
'\ 37 | 'Inspect: GET /inspect<br/>
'\ 38 | 'Synch query: GET /synch' 39 | 40 | 41 | @app.route("/ping") 42 | def ping(): 43 | return "" 44 | 45 | 46 | @app.route("/repair", methods=["GET"]) 47 | def repair(): 48 | return (json.dumps(server.handle_repair(request.args)), 49 | 200, 50 | {"Content-Type": "application/json"}) 51 | 52 | 53 | @app.route("/register", methods=["POST"]) 54 | def register(): 55 | return (json.dumps(server.handle_register()), 56 | 200, 57 | {"Content-Type": "application/json"}) 58 | 59 | 60 | @app.route("/pull", methods=["POST"]) 61 | def pull(): 62 | print json.dumps(request.json, indent=2) 63 | try: 64 | return (json.dumps(server.handle_pull(request.json)), 65 | 200, 66 | {"Content-Type": "application/json"}) 67 | except server.handlers.PullRejected as e: 68 | return (json.dumps({'error': [repr(arg) for arg in e.args]}), 69 | 400, 70 | {"Content-Type": "application/json"}) 71 | 72 | 73 | @app.route("/push", methods=["POST"]) 74 | def push(): 75 | print json.dumps(request.json, indent=2) 76 | try: 77 | return (json.dumps(server.handle_push(request.json)), 78 | 200, 79 | {"Content-Type": "application/json"}) 80 | except server.handlers.PullSuggested as e: 81 | return (json.dumps({'error': [repr(arg) for arg in e.args], 82 | 'suggest_pull': True}), 83 | 400, 84 | {"Content-Type": "application/json"}) 85 | except server.handlers.PushRejected as e: 86 | return (json.dumps({'error': [repr(arg) for arg in e.args]}), 87 | 400, 88 | {"Content-Type": "application/json"}) 89 | 90 | 91 | @app.route("/query", methods=["GET"]) 92 | def query(): 93 | return (json.dumps(server.handle_query(request.args)), 94 | 200, 95 | {"Content-Type": "application/json"}) 96 | 97 | 98 | @app.route("/inspect", methods=["GET"]) 99 | def inspect(): 100 | session = models.Session() 101 | return u"Cities:
{0}

"\ 102 | u"Houses:
{1}

"\ 103 | u"Persons:
{2}

".format( 104 | u"\n".join(enc(repr(x)) for x in session.query(models.City)), 105 | u"\n".join(enc(repr(x)) for x in session.query(models.House)), 106 | u"\n".join(enc(repr(x)) for x in session.query(models.Person))) 107 | 108 | 109 | @app.route("/synch", methods=["GET"]) 110 | def synch(): 111 | session = models.Session() 112 | return u"Content Types:
{0}

"\ 113 | u"Nodes:
{1}

"\ 114 | u"Versions:
{2}

"\ 115 | u"Operations:
{3}

".format( 116 | u"\n".join(enc(repr(x)) for x in session.query(synchmodels.ContentType)), 117 | u"\n".join(enc(repr(x)) for x in session.query(synchmodels.Node)), 118 | u"\n".join(enc(repr(x)) for x in session.query(synchmodels.Version)), 119 | u"\n".join(enc(repr(x)) for x in session.query(synchmodels.Operation))) 120 | 121 | 122 | if __name__ == "__main__": 123 | app.run(debug=True) 124 | -------------------------------------------------------------------------------- /app-tests/server/models.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, Integer, String, Date, ForeignKey, create_engine 2 | from sqlalchemy.orm import relationship, sessionmaker 3 | from sqlalchemy.ext.declarative import declarative_base 4 | from sqlalchemy.schema import UniqueConstraint 5 | 6 | import dbsync 7 | from dbsync import server 8 | 9 | 10 | engine = create_engine("sqlite:///server.db", echo=True) 11 | Session = sessionmaker(bind=engine) 12 | 13 | 14 | dbsync.set_engine(engine) 15 | 16 | 17 | Base = declarative_base() 18 | Base.__table_args__ = {'sqlite_autoincrement': True,} 19 | 20 | 21 | @server.track 22 | class City(Base): 23 | 24 | __tablename__ = "city" 25 | 26 | id = Column(Integer, primary_key=True) 27 | name = Column(String(500)) 28 | 29 | def __repr__(self): 30 | return u"".format(self.id, self.name) 31 | 32 | name_pool = ["foo", "bar", "baz"] 33 | 34 | def load_extra(city): 35 | return "-".join(name_pool) + "-" + city.name 36 | 37 | def save_extra(city, data): 38 | print "SAVING -------------------" 39 | print city.name, data 40 | print "SAVED -------------------" 41 | 42 | def delete_extra(old_city, new_city): 43 | print "DELETING -----------------" 44 | print old_city.name, (new_city.name if new_city is not None else None) 45 | print "DELETED -----------------" 46 | 47 | server.extend(City, "extra", String, load_extra, save_extra, delete_extra) 48 | 49 | 50 | @server.track 51 | class House(Base): 52 | 53 
| __tablename__ = "house" 54 | 55 | id = Column(Integer, primary_key=True) 56 | address = Column(String(500), unique=True) 57 | city_id = Column(Integer, ForeignKey("city.id")) 58 | 59 | city = relationship(City, backref="houses") 60 | 61 | def __repr__(self): 62 | return u"".format( 63 | self.id, self.address, self.city_id) 64 | 65 | 66 | @server.track 67 | class Person(Base): 68 | 69 | __tablename__ = "person" 70 | 71 | __table_args__ = (UniqueConstraint('first_name', 'last_name'), 72 | Base.__table_args__) 73 | 74 | id = Column(Integer, primary_key=True) 75 | first_name = Column(String(500)) 76 | last_name = Column(String(500)) 77 | house_id = Column(Integer, ForeignKey("house.id")) 78 | birth_city_id = Column(Integer, ForeignKey("city.id")) 79 | birth_date = Column(Date) 80 | email = Column(String(500)) 81 | 82 | house = relationship(House, backref="persons") 83 | birth_city = relationship(City) 84 | 85 | def __repr__(self): 86 | return u"".\ 87 | format(self.first_name, 88 | self.last_name, 89 | self.house_id, 90 | self.birth_city_id) 91 | 92 | 93 | Base.metadata.create_all(engine) 94 | dbsync.create_all() 95 | dbsync.generate_content_types() 96 | -------------------------------------------------------------------------------- /dbsync/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Top-level exports, for convenience. 3 | """ 4 | 5 | import dbsync.core 6 | from dbsync.core import ( 7 | is_synched, 8 | generate_content_types, 9 | set_engine, 10 | get_engine, 11 | save_extensions) 12 | from dbsync.models import Base 13 | from dbsync.logs import set_log_target 14 | 15 | 16 | __version_info__ = (0, 7, 0) 17 | __version__ = '.'.join(str(n) for n in __version_info__) 18 | 19 | 20 | def create_all(): 21 | "Issues DDL commands." 22 | Base.metadata.create_all(get_engine()) 23 | 24 | 25 | def drop_all(): 26 | "Issues DROP commands." 
27 | Base.metadata.drop_all(get_engine()) 28 | 29 | 30 | def set_listening_mutex(_): 31 | "DEPRECATED" 32 | pass 33 | -------------------------------------------------------------------------------- /dbsync/client/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Interface for the synchronization client. 3 | 4 | The client or node emits 'push' and 'pull' requests to the server. The 5 | client can also request a registry key if it hasn't been given one 6 | yet. 7 | """ 8 | 9 | import inspect 10 | 11 | from dbsync.client.compression import unsynched_objects, trim 12 | from dbsync.client.tracking import track 13 | from dbsync.core import extend 14 | from dbsync.client.register import ( 15 | register, 16 | isregistered, 17 | get_node, 18 | save_node) 19 | from dbsync.client.pull import UniqueConstraintError, pull 20 | from dbsync.client import push as pushmodule 21 | from dbsync.client.push import PushRejected, PullSuggested, push 22 | from dbsync.client.ping import isconnected, isready 23 | from dbsync.client.repair import repair 24 | from dbsync.client.serverquery import query_server 25 | from dbsync.client import net 26 | 27 | 28 | def set_pull_suggestion_criterion(predicate): 29 | """ 30 | Sets the predicate used to check whether a push response suggests 31 | the node that a pull should be performed. Default value is a 32 | constant ``False`` procedure. 33 | 34 | If set, it should be a procedure that receives three arguments: an 35 | HTTP code, the HTTP reason for said code, and the response (a 36 | dictionary). If for a given response of HTTP return code not in 37 | the 200s the procedure returns ``True``, the PullSuggested 38 | exception will be raised. PullSuggested inherits from 39 | PushRejected. 
40 | """ 41 | assert inspect.isroutine(predicate), "criterion must be a function" 42 | pushmodule.suggests_pull = predicate 43 | return predicate 44 | 45 | 46 | def set_default_encoder(enc): 47 | """ 48 | Sets the default encoder used to encode simplified dictionaries to 49 | strings, the messages being sent to the server. Default is 50 | json.dumps 51 | """ 52 | assert inspect.isroutine(enc), "encoder must be a function" 53 | net.default_encoder = enc 54 | return enc 55 | 56 | 57 | def set_default_decoder(dec): 58 | """ 59 | Sets the default decoder used to decode strings, the messages 60 | received from the server, into the dictionaries interpreted by the 61 | library. Default is json.loads 62 | """ 63 | assert inspect.isroutine(dec), "decoder must be a function" 64 | net.default_decoder = dec 65 | return dec 66 | 67 | 68 | def set_default_headers(hhs): 69 | """ 70 | Sets the default headers sent in HTTP requests. Default is:: 71 | 72 | {"Content-Type": "application/json", 73 | "Accept": "application/json"} 74 | """ 75 | assert isinstance(hhs, dict), "headers must be a dictionary" 76 | net.default_headers = hhs 77 | 78 | 79 | def set_default_timeout(t): 80 | """ 81 | Sets the default timeout in seconds for all HTTP requests. Default 82 | is 10 83 | """ 84 | assert isinstance(t, (int, long, float)), "timeout must be a number" 85 | net.default_timeout = t 86 | 87 | 88 | def set_authentication_callback(c): 89 | """ 90 | Sets a procedure that returns an authentication object, used in 91 | POST and GET requests. 
The procedure should receive the url of the 92 | request, and return an object according to 93 | http://docs.python-requests.org/en/latest/user/authentication/ 94 | """ 95 | assert inspect.isroutine(c), "authentication callback must be a function" 96 | net.authentication_callback = c 97 | return c 98 | -------------------------------------------------------------------------------- /dbsync/client/compression.py: -------------------------------------------------------------------------------- 1 | """ 2 | Operation compression, both in-memory and in-database. 3 | """ 4 | 5 | import warnings 6 | 7 | from dbsync.lang import * 8 | from dbsync.utils import get_pk, query_model 9 | from dbsync import core 10 | from dbsync.models import Version, Operation 11 | from dbsync.logs import get_logger 12 | 13 | 14 | logger = get_logger(__name__) 15 | 16 | 17 | def _assert_operation_sequence(seq, session=None): 18 | """ 19 | Asserts the correctness of a sequence of operations over a single 20 | tracked object. 21 | 22 | The sequence is given in a sorted state, from newest operation to 23 | oldest. 24 | """ 25 | message = "The sequence of operations for the given object "\ 26 | " is inconsistent. "\ 27 | "This might indicate external interference with the synchronization "\ 28 | "model or, most commonly, the reuse of old primary keys by the "\ 29 | "database engine. To function properly, the database engine must use "\ 30 | "unique primary keys through the history of the table "\ 31 | "(e.g. using AUTO INCREMENT). Operations from old to new: {3}".\ 32 | format(seq[0].row_id, 33 | seq[0].content_type_id, 34 | seq[0].tracked_model, 35 | list(reversed(map(attr('command'), seq)))) 36 | 37 | if not all(op.command == 'u' for op in seq[1:-1]): 38 | warnings.warn(message) 39 | logger.error( 40 | u"Can't have anything but updates between beginning " 41 | u"and end of operation sequence. 
%s", 42 | seq) 43 | 44 | if len(seq) > 1: 45 | if seq[-1].command == 'd': 46 | warnings.warn(message) 47 | logger.error( 48 | u"Can't have anything after a delete operation in sequence. %s", 49 | seq) 50 | 51 | if seq[0].command == 'i': 52 | warnings.warn(message) 53 | logger.error( 54 | u"Can't have anything before an insert in operation sequence. %s", 55 | seq) 56 | 57 | 58 | @core.session_committing 59 | def compress(session=None): 60 | """ 61 | Compresses unversioned operations in the database. 62 | 63 | For each row in the operations table, this deletes unnecesary 64 | operations that would otherwise bloat the message. 65 | 66 | This procedure is called internally before the 'push' request 67 | happens, and before the local 'merge' happens. 68 | """ 69 | unversioned = session.query(Operation).\ 70 | filter(Operation.version_id == None).order_by(Operation.order.desc()) 71 | seqs = group_by(lambda op: (op.row_id, op.content_type_id), unversioned) 72 | 73 | # Check errors on sequences 74 | for seq in seqs.itervalues(): 75 | _assert_operation_sequence(seq, session) 76 | 77 | for seq in ifilter(lambda seq: len(seq) > 1, seqs.itervalues()): 78 | if seq[-1].command == 'i': 79 | if all(op.command == 'u' for op in seq[:-1]): 80 | # updates are superfluous 81 | map(session.delete, seq[:-1]) 82 | elif seq[0].command == 'd': 83 | # it's as if the object never existed 84 | map(session.delete, seq) 85 | elif seq[-1].command == 'u': 86 | if all(op.command == 'u' for op in seq[:-1]): 87 | # leave a single update 88 | map(session.delete, seq[1:]) 89 | elif seq[0].command == 'd': 90 | # leave the delete statement 91 | map(session.delete, seq[1:]) 92 | session.flush() 93 | 94 | # repair inconsistencies 95 | for operation in session.query(Operation).\ 96 | filter(Operation.version_id == None).\ 97 | order_by(Operation.order.desc()).all(): 98 | session.flush() 99 | model = operation.tracked_model 100 | if not model: 101 | logger.error( 102 | "operation linked to content type " 103 
| "not tracked: %s" % operation.content_type_id) 104 | continue 105 | if operation.command in ('i', 'u'): 106 | if query_model(session, model, only_pk=True).\ 107 | filter_by(**{get_pk(model): operation.row_id}).count() == 0: 108 | logger.warning( 109 | "deleting operation %s for model %s " 110 | "for absence of backing object" % (operation, model.__name__)) 111 | session.delete(operation) 112 | continue 113 | if operation.command == 'u': 114 | subsequent = session.query(Operation).\ 115 | filter(Operation.content_type_id == operation.content_type_id, 116 | Operation.version_id == None, 117 | Operation.row_id == operation.row_id, 118 | Operation.order > operation.order).all() 119 | if any(op.command == 'i' for op in subsequent) and \ 120 | all(op.command != 'd' for op in subsequent): 121 | logger.warning( 122 | "deleting update operation %s for model %s " 123 | "for preceding an insert operation" %\ 124 | (operation, model.__name__)) 125 | session.delete(operation) 126 | continue 127 | if session.query(Operation).\ 128 | filter(Operation.content_type_id == operation.content_type_id, 129 | Operation.command == operation.command, 130 | Operation.version_id == None, 131 | Operation.row_id == operation.row_id, 132 | Operation.order != operation.order).count() > 0: 133 | logger.warning( 134 | "deleting operation %s for model %s " 135 | "for being redundant after compression" %\ 136 | (operation, model.__name__)) 137 | session.delete(operation) 138 | continue 139 | return session.query(Operation).\ 140 | filter(Operation.version_id == None).\ 141 | order_by(Operation.order.asc()).all() 142 | 143 | 144 | def compressed_operations(operations): 145 | """ 146 | Compresses a set of operations so as to avoid redundant 147 | ones. Returns the compressed set sorted by operation order. This 148 | procedure doesn't perform database operations. 
149 | """ 150 | seqs = group_by(lambda op: (op.row_id, op.content_type_id), 151 | sorted(operations, key=attr('order'))) 152 | compressed = [] 153 | for seq in seqs.itervalues(): 154 | if len(seq) == 1: 155 | compressed.append(seq[0]) 156 | elif seq[0].command == 'i': 157 | if seq[-1].command == 'd': 158 | pass 159 | else: 160 | compressed.append(seq[0]) 161 | elif seq[0].command == 'u': 162 | if seq[-1].command == 'd': 163 | compressed.append(seq[-1]) 164 | else: 165 | compressed.append(seq[0]) 166 | else: # seq[0].command == 'd': 167 | if seq[-1].command == 'd': 168 | compressed.append(seq[0]) 169 | elif seq[-1].command == 'u': 170 | compressed.append(seq[-1]) 171 | else: # seq[-1].command == 'i': 172 | op = seq[-1] 173 | compressed.append( 174 | Operation(order=op.order, 175 | content_type_id=op.content_type_id, 176 | row_id=op.row_id, 177 | version_id=op.version_id, 178 | command='u')) 179 | compressed.sort(key=attr('order')) 180 | return compressed 181 | 182 | 183 | @core.session_committing 184 | def unsynched_objects(session=None): 185 | """ 186 | Returns a list of triads (class, id, operation) that represents 187 | the unsynchronized objects in the tracked database. 188 | 189 | The first element of each triad is the class for the 190 | unsynchronized object. 191 | 192 | The second element is the primary key *value* of the object. 193 | 194 | The third element is a character in ``('i', 'u', 'd')`` that 195 | represents the operation that altered the objects state (insert, 196 | update or delete). If it's a delete, the object won't be present 197 | in the tracked database. 198 | 199 | Because of compatibility issues, this procedure will only return 200 | triads for classes marked for both push and pull handling. 
201 | """ 202 | ops = compress(session=session) 203 | def getclass(op): 204 | class_ = op.tracked_model 205 | if class_ is None: return None 206 | if class_ not in core.pulled_models or class_ not in core.pushed_models: 207 | return None 208 | return class_ 209 | triads = [ 210 | (c, op.row_id, op.command) 211 | for c, op in ((getclass(op), op) for op in ops) 212 | if c is not None] 213 | return triads 214 | 215 | 216 | @core.session_committing 217 | def trim(session=None): 218 | "Trims the internal synchronization tables, to free space." 219 | last_id = core.get_latest_version_id(session=session) 220 | session.query(Operation).filter(Operation.version_id != None).delete() 221 | session.query(Version).filter(Version.version_id != last_id).delete() 222 | -------------------------------------------------------------------------------- /dbsync/client/conflicts.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. module:: client.conflicts 3 | :synopsis: Conflict detection for the local merge operation. 4 | 5 | This module handles the conflict detection that's required for the 6 | local merge operation. The resolution phase is embedded in the 7 | dbsync.client.pull module. 8 | 9 | Related reading: 10 | 11 | Gerritsen, Jan-Henk. Detecting synchronization conflicts for 12 | horizontally decentralized relational databases. `Link to pdf`__. 13 | 14 | .. __: http://essay.utwente.nl/61767/1/Master_thesis_Jan-Henk_Gerritsen.pdf 15 | """ 16 | 17 | from sqlalchemy import or_ 18 | from sqlalchemy.orm import undefer 19 | from sqlalchemy.schema import UniqueConstraint 20 | 21 | from dbsync.lang import * 22 | from dbsync.utils import get_pk, class_mapper, query_model, column_properties 23 | from dbsync.core import synched_models, null_model 24 | from dbsync.models import Operation 25 | 26 | 27 | def get_related_tables(sa_class): 28 | """ 29 | Returns a list of related SA tables dependent on the given SA 30 | model by foreign key. 
31 | """ 32 | mapper = class_mapper(sa_class) 33 | models = synched_models.models.iterkeys() 34 | return [table for table in (class_mapper(model).mapped_table 35 | for model in models) 36 | if mapper.mapped_table in [key.column.table 37 | for key in table.foreign_keys]] 38 | 39 | 40 | def get_fks(table_from, table_to): 41 | """ 42 | Returns the names of the foreign keys that are defined in 43 | *table_from* SA table and that refer to *table_to* SA table. If 44 | the foreign keys don't exist, this procedure returns an empty 45 | list. 46 | """ 47 | fks = filter(lambda k: k.column.table == table_to, table_from.foreign_keys) 48 | return [fk.parent.name for fk in fks] 49 | 50 | 51 | def related_local_ids(operation, session): 52 | """ 53 | For the given operation, return a set of row id values mapped to 54 | content type ids that correspond to objects that are dependent by 55 | foreign key on the object being operated upon. The lookups are 56 | performed in the local database. 57 | """ 58 | parent_model = operation.tracked_model 59 | if parent_model is None: 60 | return set() 61 | related_tables = get_related_tables(parent_model) 62 | 63 | mapped_fks = ifilter( 64 | lambda (m, fks): m is not None and fks, 65 | [(synched_models.tables.get(t.name, null_model).model, 66 | get_fks(t, class_mapper(parent_model).mapped_table)) 67 | for t in related_tables]) 68 | return set( 69 | (pk, ct.id) 70 | for pk, ct in \ 71 | ((getattr(obj, get_pk(obj)), synched_models.models.get(model, None)) 72 | for model, fks in mapped_fks 73 | for obj in query_model(session, model, only_pk=True).\ 74 | filter(or_(*(getattr(model, fk) == operation.row_id 75 | for fk in fks))).all()) 76 | if ct is not None) 77 | 78 | 79 | def related_remote_ids(operation, container): 80 | """ 81 | Like *related_local_ids*, but the lookups are performed in 82 | *container*, that's an instance of 83 | *dbsync.messages.base.BaseMessage*. 
84 | """ 85 | parent_model = operation.tracked_model 86 | if parent_model is None: 87 | return set() 88 | related_tables = get_related_tables(parent_model) 89 | 90 | mapped_fks = ifilter( 91 | lambda (m, fks): m is not None and fks, 92 | [(synched_models.tables.get(t.name, null_model).model, 93 | get_fks(t, class_mapper(parent_model).mapped_table)) 94 | for t in related_tables]) 95 | return set( 96 | (pk, ct.id) 97 | for pk, ct in \ 98 | ((getattr(obj, get_pk(obj)), synched_models.models.get(model, None)) 99 | for model, fks in mapped_fks 100 | for obj in container.query(model).\ 101 | filter(lambda obj: any(getattr(obj, fk) == operation.row_id 102 | for fk in fks))) 103 | if ct is not None) 104 | 105 | 106 | def find_direct_conflicts(pull_ops, unversioned_ops): 107 | """ 108 | Detect conflicts where there's both unversioned and pulled 109 | operations, update or delete ones, referering to the same tracked 110 | object. This procedure relies on the uniqueness of the primary 111 | keys through time. 112 | """ 113 | return [ 114 | (pull_op, local_op) 115 | for pull_op in pull_ops 116 | if pull_op.command == 'u' or pull_op.command == 'd' 117 | for local_op in unversioned_ops 118 | if local_op.command == 'u' or local_op.command == 'd' 119 | if pull_op.row_id == local_op.row_id 120 | if pull_op.content_type_id == local_op.content_type_id] 121 | 122 | 123 | def find_dependency_conflicts(pull_ops, unversioned_ops, session): 124 | """ 125 | Detect conflicts by relationship dependency: deletes on the pull 126 | message on objects that have dependent objects inserted or updated 127 | on the local database. 
128 | """ 129 | related_ids = dict( 130 | (pull_op, related_local_ids(pull_op, session)) 131 | for pull_op in pull_ops 132 | if pull_op.command == 'd') 133 | return [ 134 | (pull_op, local_op) 135 | for pull_op in pull_ops 136 | if pull_op.command == 'd' 137 | for local_op in unversioned_ops 138 | if local_op.command == 'i' or local_op.command == 'u' 139 | if (local_op.row_id, local_op.content_type_id) in related_ids[pull_op]] 140 | 141 | 142 | def find_reversed_dependency_conflicts(pull_ops, unversioned_ops, pull_message): 143 | """ 144 | Deletes on the local database on objects that are referenced by 145 | inserted or updated objects in the pull message. 146 | """ 147 | related_ids = dict( 148 | (local_op, related_remote_ids(local_op, pull_message)) 149 | for local_op in unversioned_ops 150 | if local_op.command == 'd') 151 | return [ 152 | (pull_op, local_op) 153 | for local_op in unversioned_ops 154 | if local_op.command == 'd' 155 | for pull_op in pull_ops 156 | if pull_op.command == 'i' or pull_op.command == 'u' 157 | if (pull_op.row_id, pull_op.content_type_id) in related_ids[local_op]] 158 | 159 | 160 | def find_insert_conflicts(pull_ops, unversioned_ops): 161 | """ 162 | Inserts over the same object. These conflicts should be resolved 163 | by keeping both objects, but moving the local one out of the way 164 | (reinserting it to get a new primary key). It should be possible, 165 | however, to specify a custom handler for cases where the primary 166 | key is a meaningful property of the object. 167 | """ 168 | return [ 169 | (pull_op, local_op) 170 | for local_op in unversioned_ops 171 | if local_op.command == 'i' 172 | for pull_op in pull_ops 173 | if pull_op.command == 'i' 174 | if pull_op.row_id == local_op.row_id 175 | if pull_op.content_type_id == local_op.content_type_id] 176 | 177 | 178 | def find_unique_conflicts(pull_ops, unversioned_ops, pull_message, session): 179 | """ 180 | Unique constraints violated in a model. 
Returns two lists of
    dictionaries, the first one with the solvable conflicts, and the
    second one with the proper errors. Each conflict is a dictionary
    with the following fields::

        object: the local conflicting object, bound to the session
        columns: tuple of column names in the unique constraint
        new_values: tuple of values that can be used to update the
                    conflicting object

    Each error is a dictionary with the following fields::

        model: the model (class) of the conflicting object
        pk: the value of the primary key of the conflicting object
        columns: tuple of column names in the unique constraint
    """

    def verify_constraint(model, columns, values):
        """
        Checks to see whether some local object exists with
        conflicting values.
        """
        match = query_model(session, model, only_pk=True).\
            options(*(undefer(column) for column in columns)).\
            filter_by(**dict((column, value)
                             for column, value in izip(columns, values))).first()
        pk = get_pk(model)
        # returns (object-or-None, pk-value-or-None)
        return match, getattr(match, pk, None)

    def get_remote_values(model, row_id, columns):
        """
        Gets the conflicting values out of the remote object set
        (*container*).
        """
        obj = pull_message.query(model).filter(attr('__pk__') == row_id).first()
        if obj is not None:
            return tuple(getattr(obj, column) for column in columns)
        # sentinel: a single None triggers the all-None skip below
        return (None,)

    # keyed to content type
    unversioned_pks = dict((ct_id, set(op.row_id for op in unversioned_ops
                                       if op.content_type_id == ct_id
                                       if op.command != 'd'))
                           for ct_id in set(operation.content_type_id
                                            for operation in unversioned_ops))
    # the lists to fill with conflicts and errors
    conflicts, errors = [], []

    for op in pull_ops:
        model = op.tracked_model

        for constraint in ifilter(lambda c: isinstance(c, UniqueConstraint),
                                  class_mapper(model).mapped_table.constraints):

            unique_columns = tuple(col.name for col in constraint.columns)
            # Unique values on the server, to check conflicts with local database
            remote_values = get_remote_values(model, op.row_id, unique_columns)

            obj_conflict, pk_conflict = verify_constraint(
                model, unique_columns, remote_values)

            is_unversioned = pk_conflict in unversioned_pks.get(
                op.content_type_id, set())

            if all(value is None for value in remote_values): continue # Null value
            if pk_conflict is None: continue # No problem
            if pk_conflict == op.row_id:
                if op.command == 'i':
                    # Two nodes created objects with the same unique
                    # value and same pk
                    errors.append(
                        {'model': type(obj_conflict),
                         'pk': pk_conflict,
                         'columns': unique_columns})
                continue

            # if pk_conflict != op.row_id:
            remote_obj = pull_message.query(model).\
                filter(attr('__pk__') == pk_conflict).first()

            if remote_obj is not None and not is_unversioned:
                old_values = tuple(getattr(obj_conflict, column)
                                   for column in unique_columns)
                # The new unique value of the conflictive object
                # in server
                new_values = tuple(getattr(remote_obj, column)
                                   for column in unique_columns)

                if old_values != new_values:
                    # Library error
                    # It's necesary to first update the unique value
                    session.refresh(obj_conflict, column_properties(obj_conflict))
                    conflicts.append(
                        {'object': obj_conflict,
                         'columns': unique_columns,
                         'new_values': new_values})
                else:
                    # The server allows two identical unique values
                    # This should be impossible
                    pass
            elif remote_obj is not None and is_unversioned:
                # Two nodes created objects with the same unique
                # values. Human error.
                errors.append(
                    {'model': type(obj_conflict),
                     'pk': pk_conflict,
                     'columns': unique_columns})
            else:
                # The conflicting object hasn't been modified on the
                # server, which must mean the local user is attempting
                # an update that collides with one from another user.
                errors.append(
                    {'model': type(obj_conflict),
                     'pk': pk_conflict,
                     'columns': unique_columns})
    return conflicts, errors
-------------------------------------------------------------------------------- /dbsync/client/net.py: --------------------------------------------------------------------------------
"""
Send HTTP requests and interpret responses.

The body returned by each procedure will be a python dictionary
obtained from parsing a response through a decoder, or ``None`` if the
decoder raises a ``ValueError``. The default encoder, decoder and
headers are meant to work with the JSON specification.

These procedures will raise a NetworkError in case of network failure. 
10 | """ 11 | 12 | import requests 13 | import cStringIO 14 | import inspect 15 | import json 16 | 17 | 18 | class NetworkError(Exception): 19 | pass 20 | 21 | 22 | default_encoder = json.dumps 23 | 24 | default_decoder = json.loads 25 | 26 | default_headers = {"Content-Type": "application/json", 27 | "Accept": "application/json"} 28 | 29 | default_timeout = 10 30 | 31 | authentication_callback = None 32 | 33 | 34 | def _defaults(encode, decode, headers, timeout): 35 | e = encode if not encode is None else default_encoder 36 | if not inspect.isroutine(e): 37 | raise ValueError("encoder must be a function", e) 38 | d = decode if not decode is None else default_decoder 39 | if not inspect.isroutine(d): 40 | raise ValueError("decoder must be a function", d) 41 | h = headers if not headers is None else default_headers 42 | if h and not isinstance(h, dict): 43 | raise ValueError("headers must be False or a python dictionary", h) 44 | t = timeout if not timeout is None else default_timeout 45 | if not isinstance(t, (int, long, float)): 46 | raise ValueError("timeout must be a number") 47 | if t <= 0: 48 | t = None # Non-positive values are interpreted as no timeout 49 | return (e, d, h, t) 50 | 51 | 52 | def post_request(server_url, json_dict, 53 | encode=None, decode=None, headers=None, timeout=None, 54 | monitor=None): 55 | """ 56 | Sends a POST request to *server_url* with data *json_dict* and 57 | returns a trio of (code, reason, body). 58 | 59 | *encode* is a function that transforms a python dictionary into a 60 | string. 61 | 62 | *decode* is a function that transforms a string into a python 63 | dictionary. 64 | 65 | For all dictionaries d of simple types, decode(encode(d)) == d. 66 | 67 | *headers* is a python dictionary with headers to send. 68 | 69 | *timeout* is the number of seconds to wait for a response. 70 | 71 | *monitor* is a routine that gets called for each chunk of the 72 | response received, and is given a dictionary with information. 
The
    key 'status' will always be in the dictionary, and other entries
    will contain additional information. Check the source code to see
    the variations.
    """
    if not server_url.startswith("http://") and \
            not server_url.startswith("https://"):
        server_url = "http://" + server_url
    enc, dec, hhs, tout = _defaults(encode, decode, headers, timeout)
    # stream the response only when a monitor routine was supplied
    stream = inspect.isroutine(monitor)
    auth = authentication_callback(server_url) \
        if authentication_callback is not None else None
    try:
        r = requests.post(server_url, data=enc(json_dict),
                          headers=hhs or None, stream=stream,
                          timeout=tout, auth=auth)
        response = None
        if stream:
            total = r.headers.get('content-length', None)
            partial = 0
            monitor({'status': "connect", 'size': total})
            chunks = cStringIO.StringIO()
            for chunk in r:
                partial += len(chunk)
                monitor({'status': "downloading",
                         'size': total, 'received': partial})
                chunks.write(chunk)
            response = chunks.getvalue()
            chunks.close()
        else:
            response = r.content
        body = None
        try:
            body = dec(response)
        except ValueError:
            # undecodable response: body stays None per module contract
            pass
        result = (r.status_code, r.reason, body)
        r.close()
        return result

    except requests.exceptions.RequestException as e:
        if stream:
            monitor({'status': "error", 'reason': "network error"})
        raise NetworkError(*e.args)

    # NOTE(review): this broad clause converts *any* failure (including
    # programming errors) into NetworkError — intentional per the module
    # docstring, but worth confirming
    except Exception as e:
        if stream:
            monitor({'status': "error", 'reason': "network error"})
        raise NetworkError(*e.args)


def get_request(server_url, data=None,
                encode=None, decode=None, headers=None, timeout=None,
                monitor=None):
    """
    Sends a GET request to *server_url*. If *data* is to be added, it
    should be a python dictionary with simple pairs suitable for url
    encoding. Returns a trio of (code, reason, body).

    Read the docstring for ``post_request`` for information on the
    rest.
    """
    if not server_url.startswith("http://") and \
            not server_url.startswith("https://"):
        server_url = "http://" + server_url
    enc, dec, hhs, tout = _defaults(encode, decode, headers, timeout)
    stream = inspect.isroutine(monitor)
    auth = authentication_callback(server_url) \
        if authentication_callback is not None else None
    try:
        r = requests.get(server_url, params=data,
                         headers=hhs or None, stream=stream,
                         timeout=tout, auth=auth)
        response = None
        if stream:
            total = r.headers.get('content-length', None)
            partial = 0
            monitor({'status': "connect", 'size': total})
            chunks = cStringIO.StringIO()
            for chunk in r:
                partial += len(chunk)
                monitor({'status': "downloading",
                         'size': total, 'received': partial})
                chunks.write(chunk)
            response = chunks.getvalue()
            chunks.close()
        else:
            response = r.content
        body = None
        try:
            body = dec(response)
        except ValueError:
            pass
        result = (r.status_code, r.reason, body)
        r.close()
        return result

    except requests.exceptions.RequestException as e:
        if stream:
            monitor({'status': "error", 'reason': "network error"})
        raise NetworkError(*e.args)

    except Exception as e:
        if stream:
            monitor({'status': "error", 'reason': "network error"})
        raise NetworkError(*e.args)


def head_request(server_url):
    """
    Sends a HEAD request to *server_url*.

    Returns a pair of (code, reason). 
185 | """ 186 | if not server_url.startswith("http://") and \ 187 | not server_url.startswith("https://"): 188 | server_url = "http://" + server_url 189 | try: 190 | r = requests.head(server_url, timeout=default_timeout) 191 | return (r.status_code, r.reason) 192 | 193 | except requests.exceptions.RequestException as e: 194 | raise NetworkError(*e.args) 195 | 196 | except Exception as e: 197 | raise NetworkError(*e.args) 198 | -------------------------------------------------------------------------------- /dbsync/client/ping.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ping the server. 3 | 4 | The ping procedures are used to quickly diagnose the internet 5 | connection and server status from the client application. 6 | """ 7 | 8 | from dbsync.client.net import head_request, NetworkError 9 | 10 | 11 | def isconnected(ping_url): 12 | """Whether the client application is connected to the Internet.""" 13 | try: 14 | head_request(ping_url) 15 | return True 16 | except NetworkError: 17 | return False 18 | 19 | 20 | def isready(ping_url): 21 | """Whether the server is ready to receive synchronization 22 | requests.""" 23 | try: 24 | code, reason = head_request(ping_url) 25 | return code // 100 == 2 26 | except: 27 | return False 28 | -------------------------------------------------------------------------------- /dbsync/client/pull.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pull, merge and related operations. 
3 | """ 4 | 5 | import collections 6 | 7 | from sqlalchemy.orm import make_transient 8 | 9 | from dbsync.lang import * 10 | from dbsync.utils import class_mapper, get_pk, query_model 11 | from dbsync import core 12 | from dbsync.models import Operation 13 | from dbsync import dialects 14 | from dbsync.messages.pull import PullMessage, PullRequestMessage 15 | from dbsync.client.compression import compress, compressed_operations 16 | from dbsync.client.conflicts import ( 17 | get_related_tables, 18 | get_fks, 19 | find_direct_conflicts, 20 | find_dependency_conflicts, 21 | find_reversed_dependency_conflicts, 22 | find_insert_conflicts, 23 | find_unique_conflicts) 24 | from dbsync.client.net import post_request 25 | 26 | 27 | # Utilities specific to the merge 28 | 29 | def max_local(model, session): 30 | """ 31 | Returns the maximum value for the primary key of the given model 32 | in the local database. 33 | """ 34 | if model is None: 35 | raise ValueError("null model given to max_local query") 36 | return dialects.max_local(model, session) 37 | 38 | 39 | def max_remote(model, container): 40 | """ 41 | Returns the maximum value for the primary key of the given model 42 | in the container. 43 | """ 44 | return max(getattr(obj, get_pk(obj)) for obj in container.query(model)) 45 | 46 | 47 | def update_local_id(old_id, new_id, model, session): 48 | """ 49 | Updates the tuple matching *old_id* with *new_id*, and updates all 50 | dependent tuples in other tables as well. 51 | """ 52 | # Updating either the tuple or the dependent tuples first would 53 | # cause integrity violations if the transaction is flushed in 54 | # between. The order doesn't matter. 
55 | if model is None: 56 | raise ValueError("null model given to update_local_id subtransaction") 57 | # must load fully, don't know yet why 58 | obj = query_model(session, model).\ 59 | filter_by(**{get_pk(model): old_id}).first() 60 | setattr(obj, get_pk(model), new_id) 61 | 62 | # Then the dependent ones 63 | related_tables = get_related_tables(model) 64 | mapped_fks = ifilter( 65 | lambda (m, fks): m is not None and fks, 66 | [(core.synched_models.tables.get(t.name, core.null_model).model, 67 | get_fks(t, class_mapper(model).mapped_table)) 68 | for t in related_tables]) 69 | for model, fks in mapped_fks: 70 | for fk in fks: 71 | for obj in query_model(session, model).filter_by(**{fk: old_id}): 72 | setattr(obj, fk, new_id) 73 | session.flush() # raise integrity errors now 74 | 75 | 76 | UniqueConstraintErrorEntry = collections.namedtuple( 77 | 'UniqueConstraintErrorEntry', 78 | 'model pk columns') 79 | 80 | class UniqueConstraintError(Exception): 81 | 82 | entries = None 83 | 84 | def __init__(self, entries): 85 | entries = map(partial(apply, UniqueConstraintErrorEntry, ()), entries) 86 | super(UniqueConstraintError, self).__init__(entries) 87 | self.entries = entries 88 | 89 | def __repr__(self): 90 | if not self.entries: return u"" 91 | return u"".format( 92 | u"; ".join( 93 | u"{0} pk {1} columns ({2})".format( 94 | entry.model.__name__, 95 | entry.pk, 96 | u", ".join(entry.columns)) 97 | for entry in self.entries)) 98 | 99 | def __str__(self): return repr(self) 100 | 101 | 102 | @core.with_transaction() 103 | def merge(pull_message, session=None): 104 | """ 105 | Merges a message from the server with the local database. 106 | 107 | *pull_message* is an instance of dbsync.messages.pull.PullMessage. 
108 | """ 109 | if not isinstance(pull_message, PullMessage): 110 | raise TypeError("need an instance of dbsync.messages.pull.PullMessage " 111 | "to perform the local merge operation") 112 | valid_cts = set(ct for ct in core.synched_models.ids) 113 | 114 | unversioned_ops = compress(session=session) 115 | pull_ops = filter(attr('content_type_id').in_(valid_cts), 116 | pull_message.operations) 117 | pull_ops = compressed_operations(pull_ops) 118 | 119 | # I) first phase: resolve unique constraint conflicts if 120 | # possible. Abort early if a human error is detected 121 | unique_conflicts, unique_errors = find_unique_conflicts( 122 | pull_ops, unversioned_ops, pull_message, session) 123 | 124 | if unique_errors: 125 | raise UniqueConstraintError(unique_errors) 126 | 127 | conflicting_objects = set() 128 | for uc in unique_conflicts: 129 | obj = uc['object'] 130 | conflicting_objects.add(obj) 131 | for key, value in izip(uc['columns'], uc['new_values']): 132 | setattr(obj, key, value) 133 | # Resolve potential cyclical conflicts by deleting and reinserting 134 | for obj in conflicting_objects: 135 | make_transient(obj) # remove from session 136 | for model in set(type(obj) for obj in conflicting_objects): 137 | pk_name = get_pk(model) 138 | pks = [getattr(obj, pk_name) 139 | for obj in conflicting_objects 140 | if type(obj) is model] 141 | session.query(model).filter(getattr(model, pk_name).in_(pks)).\ 142 | delete(synchronize_session=False) # remove from the database 143 | session.add_all(conflicting_objects) # reinsert them 144 | session.flush() 145 | 146 | # II) second phase: detect conflicts between pulled operations and 147 | # unversioned ones 148 | direct_conflicts = find_direct_conflicts(pull_ops, unversioned_ops) 149 | 150 | # in which the delete operation is registered on the pull message 151 | dependency_conflicts = find_dependency_conflicts( 152 | pull_ops, unversioned_ops, session) 153 | 154 | # in which the delete operation was performed locally 155 | 
reversed_dependency_conflicts = find_reversed_dependency_conflicts( 156 | pull_ops, unversioned_ops, pull_message) 157 | 158 | insert_conflicts = find_insert_conflicts(pull_ops, unversioned_ops) 159 | 160 | # III) third phase: perform pull operations, when allowed and 161 | # while resolving conflicts 162 | def extract(op, conflicts): 163 | return [local for remote, local in conflicts if remote is op] 164 | 165 | def purgelocal(local): 166 | session.delete(local) 167 | exclude = lambda tup: tup[1] is not local 168 | mfilter(exclude, direct_conflicts) 169 | mfilter(exclude, dependency_conflicts) 170 | mfilter(exclude, reversed_dependency_conflicts) 171 | mfilter(exclude, insert_conflicts) 172 | unversioned_ops.remove(local) 173 | 174 | for pull_op in pull_ops: 175 | # flag to control whether the remote operation is free of obstacles 176 | can_perform = True 177 | # flag to detect the early exclusion of a remote operation 178 | reverted = False 179 | # the class of the operation 180 | class_ = pull_op.tracked_model 181 | 182 | direct = extract(pull_op, direct_conflicts) 183 | if direct: 184 | if pull_op.command == 'd': 185 | can_perform = False 186 | for local in direct: 187 | pair = (pull_op.command, local.command) 188 | if pair == ('u', 'u'): 189 | can_perform = False # favor local changes over remote ones 190 | elif pair == ('u', 'd'): 191 | pull_op.command = 'i' # negate the local delete 192 | purgelocal(local) 193 | elif pair == ('d', 'u'): 194 | local.command = 'i' # negate the remote delete 195 | session.flush() 196 | reverted = True 197 | else: # ('d', 'd') 198 | purgelocal(local) 199 | 200 | dependency = extract(pull_op, dependency_conflicts) 201 | if dependency and not reverted: 202 | can_perform = False 203 | order = min(op.order for op in unversioned_ops) 204 | # first move all operations further in order, to make way 205 | # for the new one 206 | for op in unversioned_ops: 207 | op.order = op.order + 1 208 | session.flush() 209 | # then create operation 
to reflect the reinsertion and 210 | # maintain a correct operation history 211 | session.add(Operation(row_id=pull_op.row_id, 212 | content_type_id=pull_op.content_type_id, 213 | command='i', 214 | order=order)) 215 | 216 | reversed_dependency = extract(pull_op, reversed_dependency_conflicts) 217 | for local in reversed_dependency: 218 | # reinsert record 219 | local.command = 'i' 220 | local.perform(pull_message, session) 221 | # delete trace of deletion 222 | purgelocal(local) 223 | 224 | insert = extract(pull_op, insert_conflicts) 225 | for local in insert: 226 | session.flush() 227 | next_id = max(max_remote(class_, pull_message), 228 | max_local(class_, session)) + 1 229 | update_local_id(local.row_id, next_id, class_, session) 230 | local.row_id = next_id 231 | if can_perform: 232 | pull_op.perform(pull_message, session) 233 | 234 | session.flush() 235 | 236 | # IV) fourth phase: insert versions from the pull_message 237 | for pull_version in pull_message.versions: 238 | session.add(pull_version) 239 | 240 | 241 | class BadResponseError(Exception): 242 | pass 243 | 244 | 245 | def pull(pull_url, extra_data=None, 246 | encode=None, decode=None, headers=None, monitor=None, timeout=None, 247 | include_extensions=True): 248 | """ 249 | Attempts a pull from the server. Returns the response body. 250 | 251 | Additional data can be passed to the request by giving 252 | *extra_data*, a dictionary of values. 253 | 254 | If not interrupted, the pull will perform a local merge. If the 255 | response from the server isn't appropriate, it will raise a 256 | dbysnc.client.pull.BadResponseError. 257 | 258 | By default, the *encode* function is ``json.dumps``, the *decode* 259 | function is ``json.loads``, and the *headers* are appropriate HTTP 260 | headers for JSON. 261 | 262 | *monitor* should be a routine that receives a dictionary with 263 | information of the state of the request and merge procedure. 
264 | 265 | *include_extensions* dictates whether the extension functions will 266 | be called during the merge or not. Default is ``True``. 267 | """ 268 | assert isinstance(pull_url, basestring), "pull url must be a string" 269 | assert bool(pull_url), "pull url can't be empty" 270 | if extra_data is not None: 271 | assert isinstance(extra_data, dict), "extra data must be a dictionary" 272 | request_message = PullRequestMessage() 273 | for op in compress(): request_message.add_operation(op) 274 | data = request_message.to_json() 275 | data.update({'extra_data': extra_data or {}}) 276 | 277 | code, reason, response = post_request( 278 | pull_url, data, encode, decode, headers, timeout, monitor) 279 | if (code // 100 != 2): 280 | if monitor: 281 | monitor({'status': "error", 'reason': reason.lower()}) 282 | raise BadResponseError(code, reason, response) 283 | if response is None: 284 | if monitor: 285 | monitor({ 286 | 'status': "error", 287 | 'reason': "invalid response format"}) 288 | raise BadResponseError(code, reason, response) 289 | message = None 290 | try: 291 | message = PullMessage(response) 292 | except KeyError: 293 | if monitor: 294 | monitor({ 295 | 'status': "error", 296 | 'reason': "invalid message format"}) 297 | raise BadResponseError( 298 | "response object isn't a valid PullMessage", response) 299 | 300 | if monitor: 301 | monitor({ 302 | 'status': "merging", 303 | 'operations': len(message.operations)}) 304 | merge(message, include_extensions=include_extensions) 305 | if monitor: 306 | monitor({'status': "done"}) 307 | # return the response for the programmer to do what she wants 308 | # afterwards 309 | return response 310 | -------------------------------------------------------------------------------- /dbsync/client/push.py: -------------------------------------------------------------------------------- 1 | """ 2 | Push message and related operations. 
3 | """ 4 | 5 | import datetime 6 | 7 | from dbsync.lang import * 8 | from dbsync import core 9 | from dbsync.models import Node, Version 10 | from dbsync.messages.push import PushMessage 11 | from dbsync.client.compression import compress 12 | from dbsync.client.net import post_request 13 | 14 | 15 | class PushRejected(Exception): pass 16 | 17 | class PullSuggested(PushRejected): pass 18 | 19 | 20 | # user-defined predicate to decide based on the server's response 21 | suggests_pull = None 22 | 23 | 24 | @core.with_transaction() 25 | def request_push(push_url, 26 | extra_data=None, 27 | encode=None, decode=None, headers=None, timeout=None, 28 | extensions=True, 29 | session=None): 30 | message = PushMessage() 31 | message.latest_version_id = core.get_latest_version_id(session=session) 32 | compress(session=session) 33 | message.add_unversioned_operations( 34 | session=session, include_extensions=extensions) 35 | message.set_node(session.query(Node).order_by(Node.node_id.desc()).first()) 36 | 37 | data = message.to_json() 38 | data.update({'extra_data': extra_data or {}}) 39 | 40 | code, reason, response = post_request( 41 | push_url, data, encode, decode, headers, timeout) 42 | 43 | if (code // 100 != 2) or response is None: 44 | if suggests_pull is not None and suggests_pull(code, reason, response): 45 | raise PullSuggested(code, reason, response) 46 | raise PushRejected(code, reason, response) 47 | new_version_id = response.get('new_version_id') 48 | if new_version_id is None: 49 | raise PushRejected( 50 | code, 51 | reason, 52 | {'error': "server didn't respond with new version id", 53 | 'response': response}) 54 | # Who should set the dates? Maybe send a complete Version from the 55 | # server. For now the field is ignored, so it doesn't matter. 
56 | session.add( 57 | Version(version_id=new_version_id, created=datetime.datetime.now())) 58 | for op in message.operations: 59 | op.version_id = new_version_id 60 | # return the response for the programmer to do what she wants 61 | # afterwards 62 | return response 63 | 64 | 65 | def push(push_url, extra_data=None, 66 | encode=None, decode=None, headers=None, timeout=None, 67 | include_extensions=True): 68 | """ 69 | Attempts a push to the server. Returns the response body. 70 | 71 | Additional data can be passed to the request by giving 72 | *extra_data*, a dictionary of values. 73 | 74 | If not interrupted, the push will add a new version to the 75 | database, and will link all unversioned operations to that newly 76 | added version. 77 | 78 | If rejected, the push operation will raise a 79 | dbsync.client.push.PushRejected exception. 80 | 81 | By default, the *encode* function is ``json.dumps``, the *decode* 82 | function is ``json.loads``, and the *headers* are appropriate HTTP 83 | headers for JSON. 84 | 85 | *include_extensions* dictates whether the message will include 86 | model extensions or not. 87 | """ 88 | assert isinstance(push_url, basestring), "push url must be a string" 89 | assert bool(push_url), "push url can't be empty" 90 | if extra_data is not None: 91 | assert isinstance(extra_data, dict), "extra data must be a dictionary" 92 | 93 | return request_push( 94 | push_url, 95 | extra_data=extra_data, 96 | encode=encode, decode=decode, headers=headers, timeout=timeout, 97 | extensions=include_extensions, 98 | include_extensions=include_extensions) 99 | -------------------------------------------------------------------------------- /dbsync/client/register.py: -------------------------------------------------------------------------------- 1 | """ 2 | Request for node registry. 3 | 4 | This is vulnerable to many things if used by itself. It should at 5 | least be used over HTTPS and with some sort of user authentication 6 | layer on the server. 
7 | """ 8 | 9 | from dbsync import core 10 | from dbsync.models import Node 11 | from dbsync.messages.register import RegisterMessage 12 | from dbsync.client.net import post_request 13 | 14 | 15 | class RegisterRejected(Exception): pass 16 | 17 | 18 | @core.with_transaction() 19 | def register(registry_url, extra_data=None, 20 | encode=None, decode=None, headers=None, timeout=None, 21 | session=None): 22 | """ 23 | Request a node registry from the server. 24 | 25 | If there is already a node registered in the local database, it 26 | won't be used for the following operations. Additional data can be 27 | passed to the request by giving *extra_data*, a dictionary of 28 | values. 29 | 30 | By default, the *encode* function is ``json.dumps``, the *decode* 31 | function is ``json.loads``, and the *headers* are appropriate HTTP 32 | headers for JSON. 33 | """ 34 | assert isinstance(registry_url, basestring), "registry url must be a string" 35 | assert bool(registry_url), "registry url can't be empty" 36 | if extra_data is not None: 37 | assert isinstance(extra_data, dict), "extra data must be a dictionary" 38 | 39 | code, reason, response = post_request( 40 | registry_url, extra_data or {}, encode, decode, headers, timeout) 41 | 42 | if (code // 100 != 2) or response is None: 43 | raise RegisterRejected(code, reason, response) 44 | 45 | message = RegisterMessage(response) 46 | session.add(message.node) 47 | return response 48 | 49 | 50 | @core.session_closing 51 | def isregistered(session=None): 52 | """ 53 | Checks whether this client application has at least one node 54 | registry. 55 | """ 56 | return session.query(Node).first() is not None 57 | 58 | 59 | @core.session_closing 60 | def get_node(session=None): 61 | "Returns the node register info for the actual client." 
62 | return session.query(Node).order_by(Node.node_id.desc()).first() 63 | 64 | 65 | @core.session_committing 66 | def save_node(node_id, registered, register_user_id, secret, session=None): 67 | "Save node info into database without a server request." 68 | node = Node(node_id=node_id, 69 | registered=registered, 70 | registry_user_id=register_user_id, 71 | secret=secret) 72 | session.add(node) 73 | -------------------------------------------------------------------------------- /dbsync/client/repair.py: -------------------------------------------------------------------------------- 1 | """ 2 | Repair client database. 3 | 4 | The repair consists in fetching the server database and inserting it 5 | locally while discarding the client database. 6 | 7 | Ideally this procedure won't ever be needed. If the synchronization 8 | models get corrupted, either by external interference or simply poor 9 | conflict resolution, further synchronization operations would fail and 10 | be unable to recover. For those cases, a repair operation will restore 11 | the client database to the server state, which is assumed to be 12 | correct. 13 | 14 | This procedure can take a long time to complete, since it clears the 15 | client database and fetches a big message from the server. 
16 | """ 17 | 18 | from dbsync import core 19 | from dbsync.models import Operation, Version 20 | from dbsync.messages.base import BaseMessage 21 | from dbsync.client.net import get_request 22 | 23 | 24 | @core.with_transaction() 25 | def repair_database(message, latest_version_id, session=None): 26 | if not isinstance(message, BaseMessage): 27 | raise TypeError("need an instance of dbsync.messages.base.BaseMessage "\ 28 | "to perform the repair operation") 29 | # clear local database 30 | for model in core.synched_models.models: 31 | session.query(model).delete(synchronize_session=False) 32 | # clear the local operations and versions 33 | session.query(Operation).delete(synchronize_session=False) 34 | session.query(Version).delete(synchronize_session=False) 35 | session.expire_all() 36 | # load the fetched database 37 | obj_count = 0 38 | batch_size = 500 39 | for modelkey in core.synched_models.model_names: 40 | for obj in message.query(modelkey): 41 | session.add(obj) 42 | obj_count += 1 43 | if obj_count % batch_size == 0: 44 | session.flush() 45 | # load the new version, if any 46 | if latest_version_id is not None: 47 | session.add(Version(version_id=latest_version_id)) 48 | 49 | 50 | class BadResponseError(Exception): pass 51 | 52 | 53 | def repair(repair_url, include_extensions=True, extra_data=None, 54 | encode=None, decode=None, headers=None, timeout=None, 55 | monitor=None): 56 | """ 57 | Fetches the server database and replaces the local one with it. 58 | 59 | *include_extensions* includes or excludes extension fields from 60 | the operation. 61 | 62 | *extra_data* can be used to add user credentials. 63 | 64 | By default, the *encode* function is ``json.dumps``, the *decode* 65 | function is ``json.loads``, and the *headers* are appropriate HTTP 66 | headers for JSON. 
67 | """ 68 | assert isinstance(repair_url, basestring), "repair url must be a string" 69 | assert bool(repair_url), "repair url can't be empty" 70 | if extra_data is not None: 71 | assert isinstance(extra_data, dict), "extra data must be a dictionary" 72 | assert 'exclude_extensions' not in extra_data, "reserved request key" 73 | data = {'exclude_extensions': ""} if not include_extensions else {} 74 | data.update(extra_data or {}) 75 | 76 | code, reason, response = get_request( 77 | repair_url, data, encode, decode, headers, timeout, monitor) 78 | 79 | if (code // 100 != 2): 80 | if monitor: monitor({'status': "error", 'reason': reason.lower()}) 81 | raise BadResponseError(code, reason, response) 82 | if response is None: 83 | if monitor: monitor({'status': "error", 84 | 'reason': "invalid response format"}) 85 | raise BadResponseError(code, reason, response) 86 | message = None 87 | try: 88 | message = BaseMessage(response) 89 | except KeyError: 90 | if monitor: monitor({'status': "error", 91 | 'reason': "invalid message format"}) 92 | raise BadResponseError( 93 | "response object isn't a valid BaseMessage", response) 94 | 95 | if monitor: monitor({'status': "repairing"}) 96 | repair_database( 97 | message, 98 | response.get("latest_version_id", None), 99 | include_extensions=include_extensions) 100 | if monitor: monitor({'status': "done"}) 101 | return response 102 | -------------------------------------------------------------------------------- /dbsync/client/serverquery.py: -------------------------------------------------------------------------------- 1 | """ 2 | Query the server's database. 3 | 4 | For now, only equality filters are allowed, and they will be joined 5 | together by ``and_`` in the server. 
6 | """ 7 | 8 | from dbsync import core 9 | from dbsync.messages.base import BaseMessage 10 | from dbsync.client.net import get_request 11 | 12 | 13 | class BadResponseError(Exception): pass 14 | 15 | 16 | def query_server(query_url, 17 | encode=None, decode=None, headers=None, timeout=None, 18 | monitor=None): 19 | """Queries the server for a single model's dataset. 20 | 21 | This procedure returns a procedure that receives the class and 22 | filters, and performs the HTTP request.""" 23 | def query(cls, **args): 24 | data = {'model': cls.__name__} 25 | data.update(dict(('{0}_{1}'.format(cls.__name__, key), value) 26 | for key, value in args.iteritems())) 27 | 28 | code, reason, response = get_request( 29 | query_url, data, encode, decode, headers, timeout, monitor) 30 | 31 | if (code // 100 != 2): 32 | if monitor: monitor({'status': "error", 'reason': reason.lower()}) 33 | raise BadResponseError(code, reason, response) 34 | if response is None: 35 | if monitor: monitor({'status': "error", 36 | 'reason': "invalid response format"}) 37 | raise BadResponseError(code, reason, response) 38 | message = None 39 | try: 40 | message = BaseMessage(response) 41 | except KeyError: 42 | if monitor: monitor({'status': "error", 43 | 'reason': "invalid message format"}) 44 | raise BadResponseError( 45 | "response object isn't a valid BaseMessage", response) 46 | return message.query(cls).all() 47 | return query 48 | -------------------------------------------------------------------------------- /dbsync/client/tracking.py: -------------------------------------------------------------------------------- 1 | """ 2 | Listeners to SQLAlchemy events to keep track of CUD operations. 
3 | """ 4 | 5 | import logging 6 | import inspect 7 | import warnings 8 | from collections import deque 9 | 10 | from sqlalchemy import event 11 | from sqlalchemy.orm.session import Session as GlobalSession 12 | 13 | from dbsync import core 14 | from dbsync.models import Operation 15 | from dbsync.logs import get_logger 16 | 17 | 18 | logger = get_logger(__name__) 19 | 20 | 21 | if core.mode == 'server': 22 | warnings.warn("don't import both client and server") 23 | core.mode = 'client' 24 | 25 | 26 | #: Operations to be flushed to the database after a commit. 27 | _operations_queue = deque() 28 | 29 | 30 | def flush_operations(committed_session): 31 | "Flush operations after a commit has been issued." 32 | if not _operations_queue or \ 33 | getattr(committed_session, core.INTERNAL_SESSION_ATTR, False): 34 | return 35 | if not core.listening: 36 | logger.warning("dbsync is disabled; aborting flush_operations") 37 | return 38 | with core.committing_context() as session: 39 | while _operations_queue: 40 | op = _operations_queue.popleft() 41 | session.add(op) 42 | session.flush() 43 | 44 | 45 | def empty_queue(*args): 46 | "Empty the operations queue." 47 | session = None if not args else args[0] 48 | if getattr(session, core.INTERNAL_SESSION_ATTR, False): 49 | return 50 | if not core.listening: 51 | logger.warning("dbsync is disabled; aborting empty_queue") 52 | return 53 | while _operations_queue: 54 | _operations_queue.pop() 55 | 56 | 57 | def make_listener(command): 58 | "Builds a listener for the given command (i, u, d)." 
59 | def listener(mapper, connection, target): 60 | if getattr(core.SessionClass.object_session(target), 61 | core.INTERNAL_SESSION_ATTR, 62 | False): 63 | return 64 | if not core.listening: 65 | logger.warning("dbsync is disabled; " 66 | "aborting listener to '{0}' command".format(command)) 67 | return 68 | if command == 'u' and not core.SessionClass.object_session(target).\ 69 | is_modified(target, include_collections=False): 70 | return 71 | tname = mapper.mapped_table.name 72 | if tname not in core.synched_models.tables: 73 | logging.error("you must track a mapped class to table {0} "\ 74 | "to log operations".format(tname)) 75 | return 76 | pk = getattr(target, mapper.primary_key[0].name) 77 | op = Operation( 78 | row_id=pk, 79 | version_id=None, # operation not yet versioned 80 | content_type_id=core.synched_models.tables[tname].id, 81 | command=command) 82 | _operations_queue.append(op) 83 | return listener 84 | 85 | 86 | def _start_tracking(model, directions): 87 | if 'pull' in directions: 88 | core.pulled_models.add(model) 89 | if 'push' in directions: 90 | core.pushed_models.add(model) 91 | if model in core.synched_models.models: 92 | return model 93 | core.synched_models.install(model) 94 | if 'push' not in directions: 95 | return model # don't track operations for pull-only models 96 | event.listen(model, 'after_insert', make_listener('i')) 97 | event.listen(model, 'after_update', make_listener('u')) 98 | event.listen(model, 'after_delete', make_listener('d')) 99 | return model 100 | 101 | 102 | def track(*directions): 103 | """ 104 | Adds an ORM class to the list of synchronized classes. 105 | 106 | It can be used as a class decorator. This will also install 107 | listeners to keep track of CUD operations for the given model. 108 | 109 | *directions* are optional arguments of values in ('push', 'pull') 110 | that can restrict the way dbsync handles the class during those 111 | procedures. If not given, both values are assumed. 
If only one of 112 | them is given, the other procedure will ignore the tracked class. 113 | """ 114 | valid = ('push', 'pull') 115 | if not directions: 116 | return lambda model: _start_tracking(model, valid) 117 | if len(directions) == 1 and inspect.isclass(directions[0]): 118 | return _start_tracking(directions[0], valid) 119 | assert all(d in valid for d in directions), \ 120 | "track only accepts the arguments: {0}".format(', '.join(valid)) 121 | return lambda model: _start_tracking(model, directions) 122 | 123 | 124 | event.listen(GlobalSession, 'after_commit', flush_operations) 125 | event.listen(GlobalSession, 'after_soft_rollback', empty_queue) 126 | -------------------------------------------------------------------------------- /dbsync/core.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common functionality for model synchronization and version tracking. 3 | """ 4 | 5 | import zlib 6 | import inspect 7 | import contextlib 8 | import logging 9 | logging.getLogger('dbsync').addHandler(logging.NullHandler()) 10 | 11 | from sqlalchemy.orm import sessionmaker 12 | from sqlalchemy.engine import Engine 13 | 14 | from dbsync.lang import * 15 | from dbsync.utils import get_pk, query_model, copy, class_mapper 16 | from dbsync.models import ContentType, Operation, Version 17 | from dbsync import dialects 18 | from dbsync.logs import get_logger 19 | 20 | 21 | logger = get_logger(__name__) 22 | 23 | 24 | #: Approximate maximum number of variables allowed in a query 25 | MAX_SQL_VARIABLES = 900 26 | 27 | 28 | INTERNAL_SESSION_ATTR = '_dbsync_internal' 29 | 30 | 31 | SessionClass = sessionmaker(autoflush=False, expire_on_commit=False) 32 | def Session(): 33 | s = SessionClass(bind=get_engine()) 34 | s._model_changes = dict() # for flask-sqlalchemy 35 | setattr(s, INTERNAL_SESSION_ATTR, True) # used to disable listeners 36 | return s 37 | 38 | 39 | def session_closing(fn): 40 | @wraps(fn) 41 | def wrapped(*args, 
**kwargs): 42 | closeit = kwargs.get('session', None) is None 43 | session = Session() if closeit else kwargs['session'] 44 | kwargs['session'] = session 45 | try: 46 | return fn(*args, **kwargs) 47 | finally: 48 | if closeit: 49 | session.close() 50 | return wrapped 51 | 52 | 53 | def session_committing(fn): 54 | @wraps(fn) 55 | def wrapped(*args, **kwargs): 56 | closeit = kwargs.get('session', None) is None 57 | session = Session() if closeit else kwargs['session'] 58 | kwargs['session'] = session 59 | try: 60 | result = fn(*args, **kwargs) 61 | if closeit: 62 | session.commit() 63 | else: 64 | session.flush() 65 | return result 66 | except: 67 | if closeit: 68 | session.rollback() 69 | raise 70 | finally: 71 | if closeit: 72 | session.close() 73 | return wrapped 74 | 75 | 76 | @contextlib.contextmanager 77 | def committing_context(): 78 | session = Session() 79 | try: 80 | yield session 81 | session.commit() 82 | except: 83 | session.rollback() 84 | raise 85 | finally: 86 | session.close() 87 | 88 | 89 | #: The internal use mode, used to prevent client-server module 90 | # collision. Possible values are 'modeless', 'client', 'server'. 91 | mode = 'modeless' 92 | 93 | 94 | #: The engine used for database connections. 95 | _engine = None 96 | 97 | 98 | def set_engine(engine): 99 | """ 100 | Sets the SA engine to be used by the library. 101 | 102 | It should point to the same database as the application's. 103 | """ 104 | assert isinstance(engine, Engine), "expected sqlalchemy.engine.Engine object" 105 | global _engine 106 | _engine = engine 107 | 108 | 109 | class ConfigurationError(Exception): pass 110 | 111 | def get_engine(): 112 | "Returns a defined (not None) engine." 
113 | if _engine is None: 114 | raise ConfigurationError("database engine hasn't been set yet") 115 | return _engine 116 | 117 | 118 | class tracked_record(object): 119 | 120 | def __setattr__(self, *args): 121 | raise AttributeError("'tracked_record' object is immutable") 122 | __delattr__ = __setattr__ 123 | 124 | def __init__(self, model=None, id=None): 125 | super(tracked_record, self).__setattr__('model', model) 126 | super(tracked_record, self).__setattr__('id', id) 127 | 128 | null_model = tracked_record() 129 | 130 | def install(self, model): 131 | """ 132 | Installs the model in synched_models, indexing by class, class 133 | name, table name and content_type_id. 134 | """ 135 | ct_id = make_content_type_id(model) 136 | tname = model.__table__.name 137 | record = tracked_record(model=model, id=ct_id) 138 | self.model_names[model.__name__] = record 139 | self.models[model] = record 140 | self.tables[tname] = record 141 | self.ids[ct_id] = record 142 | 143 | #: Set of classes marked for synchronization and change tracking. 144 | synched_models = type( 145 | 'synched_models', 146 | (object,), 147 | {'tables': dict(), 148 | 'models': dict(), 149 | 'model_names': dict(), 150 | 'ids': dict(), 151 | 'install': install})() 152 | 153 | 154 | def tracked_model(operation): 155 | "Get's the tracked model (SA mapped class) for this operation." 156 | return synched_models.ids.get(operation.content_type_id, null_model).model 157 | # Injects synched models lookup into the Operation class. 158 | Operation.tracked_model = property(tracked_model) 159 | 160 | 161 | #: Set of classes in *synched_models* that are subject to pull handling. 162 | pulled_models = set() 163 | 164 | 165 | #: Set of classes in *synched_models* that are subject to push handling. 166 | pushed_models = set() 167 | 168 | 169 | #: Extensions to tracked models. 
170 | model_extensions = {} 171 | 172 | 173 | def extend(model, fieldname, fieldtype, loadfn, savefn, deletefn=None): 174 | """ 175 | Extends *model* with a field of name *fieldname* and type 176 | *fieldtype*. 177 | 178 | *fieldtype* should be an instance of a SQLAlchemy type class, or 179 | the class itself. 180 | 181 | *loadfn* is a function called to populate the extension. It should 182 | accept an instance of the model and should return the value of the 183 | field. 184 | 185 | *savefn* is a function called to persist the field. It should 186 | accept the instance of the model and the field's value. It's 187 | return value is ignored. 188 | 189 | *deletefn* is a function called to revert the side effects of 190 | *savefn* for old values. It gets called after an update on the 191 | object with the old object's values, or after a delete. *deletefn* 192 | is optional, and if given it should be a function of two 193 | arguments: the first is the object in the previous state, the 194 | second is the object in the current state. 
195 | 196 | Original proposal: https://gist.github.com/kklingenberg/7336576 197 | """ 198 | assert inspect.isclass(model), "model must be a mapped class" 199 | assert isinstance(fieldname, basestring) and bool(fieldname),\ 200 | "field name must be a non-empty string" 201 | assert not hasattr(model, fieldname),\ 202 | "the model {0} already has the attribute {1}".\ 203 | format(model.__name__, fieldname) 204 | assert inspect.isroutine(loadfn), "load function must be a callable" 205 | assert inspect.isroutine(savefn), "save function must be a callable" 206 | assert deletefn is None or inspect.isroutine(deletefn),\ 207 | "delete function must be a callable" 208 | extensions = model_extensions.get(model.__name__, {}) 209 | type_ = fieldtype if not inspect.isclass(fieldtype) else fieldtype() 210 | extensions[fieldname] = (type_, loadfn, savefn, deletefn) 211 | model_extensions[model.__name__] = extensions 212 | 213 | 214 | def _has_extensions(obj): 215 | return bool(model_extensions.get(type(obj).__name__, {})) 216 | 217 | def _has_delete_functions(obj): 218 | return any( 219 | delfn is not None 220 | for t, loadfn, savefn, delfn in model_extensions.get( 221 | type(obj).__name__, {}).itervalues()) 222 | 223 | 224 | def save_extensions(obj): 225 | """ 226 | Executes the save procedures for the extensions of the given 227 | object. 228 | """ 229 | extensions = model_extensions.get(type(obj).__name__, {}) 230 | for field, ext in extensions.iteritems(): 231 | _, _, savefn, _ = ext 232 | try: savefn(obj, getattr(obj, field, None)) 233 | except: 234 | logger.exception( 235 | u"Couldn't save extension %s for object %s", field, obj) 236 | 237 | 238 | def delete_extensions(old_obj, new_obj): 239 | """ 240 | Executes the delete procedures for the extensions of the given 241 | object. *old_obj* is the object in the previous state, and 242 | *new_obj* is the object in the current state (or ``None`` if the 243 | object was deleted). 
244 | """ 245 | extensions = model_extensions.get(type(old_obj).__name__, {}) 246 | for field, ext in extensions.iteritems(): 247 | _, _, _, deletefn = ext 248 | if deletefn is not None: 249 | try: deletefn(old_obj, new_obj) 250 | except: 251 | logger.exception( 252 | u"Couldn't delete extension %s for object %s", field, new_obj) 253 | 254 | 255 | #: Toggled variable used to disable listening to operations momentarily. 256 | listening = True 257 | 258 | 259 | def toggle_listening(enabled=None): 260 | """ 261 | Change the listening state. 262 | 263 | If set to ``False``, no operations will be registered. This can be 264 | used to disable dbsync temporarily, in scripts or blocks that 265 | execute in a single-threaded environment. 266 | """ 267 | global listening 268 | listening = enabled if enabled is not None else not listening 269 | 270 | 271 | def with_listening(enabled): 272 | """ 273 | Decorator for procedures to be executed with the specified 274 | listening status. 275 | """ 276 | def wrapper(proc): 277 | @wraps(proc) 278 | def wrapped(*args, **kwargs): 279 | prev = bool(listening) 280 | toggle_listening(enabled) 281 | try: 282 | return proc(*args, **kwargs) 283 | finally: 284 | toggle_listening(prev) 285 | return wrapped 286 | return wrapper 287 | 288 | 289 | # Helper functions used to queue extension operations in a transaction. 
290 | 291 | def _track_added(fn, added): 292 | def tracked(o, **kws): 293 | if _has_extensions(o): added.append(o) 294 | return fn(o, **kws) 295 | return tracked 296 | 297 | def _track_deleted(fn, deleted, session, always=False): 298 | def tracked(o, **kws): 299 | if _has_delete_functions(o): 300 | if always: deleted.append((copy(o), None)) 301 | else: 302 | prev = query_model(session, type(o)).filter_by( 303 | **{get_pk(o): getattr(o, get_pk(o), None)}).\ 304 | first() 305 | if prev is not None: 306 | deleted.append((copy(prev), o)) 307 | return fn(o, **kws) 308 | return tracked 309 | 310 | 311 | def with_transaction(include_extensions=True): 312 | """ 313 | Decorator for a procedure that uses a session and acts as an 314 | atomic transaction. It feeds a new session to the procedure, and 315 | commits it, rolls it back, and / or closes it when it's 316 | appropriate. If *include_extensions* is ``False``, the transaction 317 | will ignore model extensions. 318 | """ 319 | def wrapper(proc): 320 | @wraps(proc) 321 | def wrapped(*args, **kwargs): 322 | extensions = kwargs.pop('include_extensions', include_extensions) 323 | session = Session() 324 | previous_state = dialects.begin_transaction(session) 325 | added = [] 326 | deleted = [] 327 | if extensions: 328 | session.add = _track_deleted( 329 | _track_added(session.add, added), 330 | deleted, 331 | session) 332 | session.merge = _track_deleted( 333 | _track_added(session.merge, added), 334 | deleted, 335 | session) 336 | session.delete = _track_deleted( 337 | session.delete, 338 | deleted, 339 | session, 340 | always=True) 341 | result = None 342 | try: 343 | kwargs.update({'session': session}) 344 | result = proc(*args, **kwargs) 345 | session.commit() 346 | except: 347 | session.rollback() 348 | raise 349 | finally: 350 | dialects.end_transaction(previous_state, session) 351 | session.close() 352 | for old_obj, new_obj in deleted: delete_extensions(old_obj, new_obj) 353 | for obj in added: save_extensions(obj) 
354 | return result 355 | return wrapped 356 | return wrapper 357 | 358 | 359 | def make_content_type_id(model): 360 | "Returns a content type id for the given model." 361 | mname = model.__name__ 362 | tname = model.__table__.name 363 | return zlib.crc32("{0}/{1}".format(mname, tname), 0) & 0xffffffff 364 | 365 | 366 | @session_committing 367 | def generate_content_types(session=None): 368 | """ 369 | Fills the content type table. 370 | 371 | Inserts content types into the internal table used to describe 372 | operations. 373 | """ 374 | for tname, record in synched_models.tables.iteritems(): 375 | content_type_id = record.id 376 | mname = record.model.__name__ 377 | if session.query(ContentType).\ 378 | filter(ContentType.table_name == tname).count() == 0: 379 | session.add(ContentType(table_name=tname, 380 | model_name=mname, 381 | content_type_id=content_type_id)) 382 | 383 | 384 | @session_closing 385 | def is_synched(obj, session=None): 386 | """ 387 | Returns whether the given tracked object is synched. 388 | 389 | Raises a TypeError if the given object is not being tracked 390 | (i.e. the content type doesn't exist). 391 | """ 392 | if type(obj) not in synched_models.models: 393 | raise TypeError("the given object of class {0} isn't being tracked".\ 394 | format(obj.__class__.__name__)) 395 | session = Session() 396 | last_op = session.query(Operation).\ 397 | filter(Operation.content_type_id == synched_models.models[type(obj)].id, 398 | Operation.row_id == getattr(obj, get_pk(obj))).\ 399 | order_by(Operation.order.desc()).first() 400 | return last_op is None or last_op.version_id is not None 401 | 402 | 403 | @session_closing 404 | def get_latest_version_id(session=None): 405 | """ 406 | Returns the latest version identifier or ``None`` if no version is 407 | found. 
def begin_transaction(session):
    """
    Prepare the DBMS for a bulk synchronization transaction.

    Returns information of the state the database was on before the
    transaction began (currently only meaningful for SQLite, where
    it's the previous value of the ``foreign_keys`` pragma).
    """
    engine = session.bind
    dialect = engine.name
    if dialect == 'sqlite':
        cursor = engine.execute("PRAGMA foreign_keys;")
        state = cursor.fetchone()[0]
        cursor.close()
        engine.execute("PRAGMA foreign_keys = OFF;")
        engine.execute("BEGIN EXCLUSIVE TRANSACTION;")
        return state
    if dialect == 'mysql':
        # temporal by default
        # see http://dev.mysql.com/doc/refman/5.7/en/using-system-variables.html
        engine.execute("SET foreign_key_checks = 0;")
        return None
    if dialect == 'postgresql':
        # defer constraints
        engine.execute("SET CONSTRAINTS ALL DEFERRED;")
        return None
    return None


def end_transaction(state, session):
    """
    *state* is whatever was returned by :func:`begin_transaction`
    """
    engine = session.bind
    dialect = engine.name
    if dialect == 'sqlite':
        if state not in (0, 1): state = 1
        engine.execute("PRAGMA foreign_keys = {0}".format(int(state)))


def max_local(sa_class, session):
    """
    Returns the maximum primary key used for the given table.
    """
    engine = session.bind
    dialect = engine.name
    table_name = class_mapper(sa_class).mapped_table.name
    # default, strictly incorrect query
    found = session.query(func.max(getattr(sa_class, get_pk(sa_class)))).scalar()
    if dialect == 'sqlite':
        cursor = engine.execute("SELECT seq FROM sqlite_sequence WHERE name = ?",
                                table_name)
        row = cursor.fetchone()
        cursor.close()
        # BUG FIX: sqlite_sequence holds no row for a table that never
        # received an AUTOINCREMENT insert; fetchone() is then None and
        # the previous fetchone()[0] raised TypeError.
        if row is None:
            return found
        return max(row[0], found)
    return found
def identity(x):
    "The identity function: returns its argument untouched."
    return x


def maybe(value, fn=identity, default=""):
    "``if value is None: ...`` more compressed."
    return default if value is None else fn(value)


def guard(f):
    "Propagate nothingness in a function of one argument."
    @wraps(f)
    def wrapped(x):
        return maybe(x, f, None)
    return wrapped


def partial(f, *arguments):
    """
    Partial application that, unlike a bare functools.partial object,
    carries __module__ and a derived __name__.

    http://bugs.python.org/issue3445
    https://docs.python.org/2/library/functools.html#partial-objects
    """
    applied = partial_apply(f, *arguments)
    applied.__module__ = f.__module__
    applied.__name__ = "partial-{0}".format(f.__name__)
    return applied


class Function(object):
    "Composable function for attr and method usage."

    def __init__(self, fn):
        self.fn = fn
        self.__name__ = fn.__name__  # e.g. for the wraps decorator

    def __call__(self, obj):
        return self.fn(obj)

    def _compare(self, other, op):
        # Shared builder for the (non-short-circuiting) comparison
        # operators; boolean connectives below keep their own
        # short-circuit semantics.
        if isinstance(other, Function):
            return Function(lambda obj: op(self.fn(obj), other(obj)))
        return Function(lambda obj: op(self.fn(obj), other))

    def __eq__(self, other):
        return self._compare(other, lambda a, b: a == b)

    def __ne__(self, other):
        return self._compare(other, lambda a, b: a != b)

    def __lt__(self, other):
        return self._compare(other, lambda a, b: a < b)

    def __le__(self, other):
        return self._compare(other, lambda a, b: a <= b)

    def __gt__(self, other):
        return self._compare(other, lambda a, b: a > b)

    def __ge__(self, other):
        return self._compare(other, lambda a, b: a >= b)

    def __invert__(self):
        return Function(lambda obj: not self.fn(obj))

    def __and__(self, other):
        # Preserve short-circuiting: the right operand is evaluated
        # only when the left one is truthy.
        if isinstance(other, Function):
            return Function(lambda obj: self.fn(obj) and other(obj))
        return Function(lambda obj: self.fn(obj) and other)

    def __or__(self, other):
        if isinstance(other, Function):
            return Function(lambda obj: self.fn(obj) or other(obj))
        return Function(lambda obj: self.fn(obj) or other)

    def in_(self, collection):
        return Function(lambda obj: self.fn(obj) in collection)


def attr(name):
    "For use in standard higher order functions."
    return Function(lambda obj: getattr(obj, name))
def method(name, *args, **kwargs):
    "For use in standard higher order functions."
    return Function(lambda obj: getattr(obj, name)(*args, **kwargs))


def group_by(fn, col):
    """
    Groups a collection according to the given *fn* into a dictionary.

    *fn* should return a hashable.
    """
    groups = {}
    for e in col:
        groups.setdefault(fn(e), []).append(e)
    return groups


def grouper(iterable, n):
    """
    Collect data into chunks or blocks of at most *n* elements,
    yielded as tuples. The final chunk may be shorter than *n*.
    """
    assert n > 0, "n must be greater than 0"
    accum = []
    for e in iterable:
        accum.append(e)
        if len(accum) == n:
            yield tuple(accum)
            del accum[:]
    if accum:
        yield tuple(accum)


def lookup(predicate, collection, default=None):
    """
    Looks up the first value in *collection* that satisfies
    *predicate*, or returns *default* if none does.
    """
    for e in collection:
        if predicate(e):
            return e
    return default


def mfilter(predicate, lst):
    """
    Removes the elements in *lst* that don't satisfy *predicate*,
    mutating *lst* (a list or a set). Returns *lst*.
    """
    # BUG FIX: materialize the rejected elements before removing them.
    # The previous code iterated ``filter(...)`` while mutating *lst*;
    # under Python 3 ``filter`` is lazy over the very collection being
    # mutated, which skips elements.
    rejected = [e for e in lst if not predicate(e)]
    for e in rejected:
        lst.remove(e)
    return lst
#: All the library loggers
loggers = set()


#: Shared handler, installed on every library logger once a target is set.
log_handler = None


def get_logger(name):
    """
    Create (or fetch) a library logger, register it, and attach the
    shared handler when one has already been configured.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.WARNING)
    loggers.add(logger)
    if log_handler is not None:
        logger.addHandler(log_handler)
    return logger


def set_log_target(fo):
    """
    Set a stream as target for dbsync's logging. If a string is given,
    it will be considered to be a path to a file.
    """
    global log_handler
    if log_handler is not None:
        # Only the first configured target sticks; later calls are no-ops.
        return
    log_handler = logging.FileHandler(fo) if isinstance(fo, basestring) \
        else logging.StreamHandler(fo)
    log_handler.setLevel(logging.WARNING)
    log_handler.setFormatter(
        logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
    for logger in loggers:
        logger.addHandler(log_handler)
class ObjectType(object):
    "Wrapper for tracked objects."

    def __init__(self, mname, pk, **kwargs):
        self.__model_name__ = mname
        self.__pk__ = pk
        self.__keys__ = []
        # dict.items() instead of iteritems() keeps this Python-3-ready
        # while behaving identically on Python 2.
        for k, v in kwargs.items():
            if k != '__model_name__' and k != '__pk__' and k != '__keys__':
                setattr(self, k, v)
                self.__keys__.append(k)

    def __repr__(self):
        # BUG FIX: the format arguments had been dropped from the
        # template, so repr() always returned an empty string.
        return u"<ObjectType {0} pk: {1}>".format(
            self.__model_name__, self.__pk__)

    def __eq__(self, other):
        if not isinstance(other, ObjectType):
            raise TypeError("not an instance of ObjectType")
        return self.__model_name__ == other.__model_name__ and \
            self.__pk__ == other.__pk__

    def __hash__(self):
        return self.__pk__

    def to_dict(self):
        "Plain dictionary of the wrapped object's properties."
        return dict((k, getattr(self, k)) for k in self.__keys__)

    def to_mapped_object(self):
        "Rebuilds a bare instance of the tracked model from this wrapper."
        model = synched_models.model_names.\
            get(self.__model_name__, null_model).model
        if model is None:
            raise TypeError(
                "model {0} isn't being tracked".format(self.__model_name__))
        obj = construct_bare(model)
        for k in self.__keys__:
            setattr(obj, k, getattr(self, k))
        return obj


class MessageQuery(object):
    "Query over internal structure of a message."

    def __init__(self, target, payload):
        if target == models.Operation or \
                target == models.Version or \
                target == models.Node:
            self.target = 'models.' + target.__name__
        elif inspect.isclass(target):
            self.target = target.__name__
        elif isinstance(target, basestring):
            self.target = target
        else:
            raise TypeError(
                "query expected a class or string, got %s" % type(target))
        self.payload = payload

    def query(self, model):
        """
        Returns a new query with a different target, without
        filtering.
        """
        return MessageQuery(model, self.payload)

    def filter(self, predicate):
        """
        Returns a new query with the collection filtered according to
        the predicate applied to the target objects.
        """
        to_filter = self.payload.get(self.target, None)
        if to_filter is None:
            return self
        # list(...) keeps the filtered payload re-iterable under
        # Python 3, matching Python 2's eager filter().
        return MessageQuery(
            self.target,
            dict(self.payload,
                 **{self.target: list(filter(predicate, to_filter))}))

    def __iter__(self):
        "Yields objects mapped to their original type (*target*)."
        # Internal models are stored as-is; tracked objects are stored
        # wrapped in ObjectType and must be mapped back.
        m = identity if self.target.startswith('models.') \
            else method('to_mapped_object')
        lst = self.payload.get(self.target, None)
        if lst is not None:
            for e in lst:
                yield m(e)

    def all(self):
        "Returns a list of all queried objects."
        return list(self)

    def first(self):
        """
        Returns the first of the queried objects, or ``None`` if no
        objects matched.
        """
        try: return next(iter(self))
        except StopIteration: return None


class BaseMessage(object):
    "The base type for messages with a payload."

    #: dictionary of (model name, set of wrapped objects)
    payload = None

    def __init__(self, raw_data=None):
        self.payload = {}
        if raw_data is not None:
            self._from_raw(raw_data)

    def _from_raw(self, data):
        # Rewritten without tuple-parameter lambdas (a syntax error on
        # Python 3); behaviour is unchanged. Models not currently
        # tracked are skipped.
        for k, v in data['payload'].items():
            m = synched_models.model_names.get(k, null_model).model
            if m is None:
                continue
            dec = decode_dict(m)
            pk = get_pk(m)
            self.payload[k] = set(
                ObjectType(k, dict_[pk], **dict_)
                for dict_ in (dec(raw) for raw in v))

    def query(self, model):
        "Returns a query object for this message."
        return MessageQuery(model, self.payload)

    def to_json(self):
        "Returns a JSON-friendly python dictionary."
        encoded = {'payload': {}}
        for k, objects in self.payload.items():
            model = synched_models.model_names.get(k, null_model).model
            if model is not None:
                enc = encode_dict(model)
                encoded['payload'][k] = [enc(obj.to_dict())
                                         for obj in objects]
        return encoded

    def add_object(self, obj, include_extensions=True):
        "Adds an object to the message, if it's not already in."
        class_ = type(obj)
        classname = class_.__name__
        pk = getattr(obj, get_pk(class_))
        obj_set = self.payload.get(classname, set())
        if ObjectType(classname, pk) in obj_set:
            return self
        properties = properties_dict(obj)
        if include_extensions:
            for field, ext in model_extensions.get(classname, {}).items():
                _, loadfn, _, _ = ext
                properties[field] = loadfn(obj)
        obj_set.add(ObjectType(classname, pk, **properties))
        self.payload[classname] = obj_set
        return self
def types_dict(class_):
    "Augments standard types_dict with model extensions."
    dict_ = bare_types_dict(class_)
    extensions = core.model_extensions.get(class_.__name__, {})
    for field, ext in extensions.items():
        type_, _, _, _ = ext
        dict_[field] = type_
    return dict_


def _encode_table(type_):
    """
    Returns an encoder producing JSON-friendly values for the given
    SQLAlchemy data type instance *type_*.
    """
    if isinstance(type_, types.Date):
        return lambda value: [value.year, value.month, value.day]
    elif isinstance(type_, types.DateTime):
        return lambda value: [value.year, value.month, value.day,
                              value.hour, value.minute, value.second,
                              value.microsecond]
    elif isinstance(type_, types.Time):
        return lambda value: [value.hour, value.minute, value.second,
                              value.microsecond]
    elif isinstance(type_, types.LargeBinary):
        return base64.standard_b64encode
    elif isinstance(type_, types.Numeric) and type_.asdecimal:
        return str
    return identity

#: Encodes a python value into a JSON-friendly python value.
encode = lambda t: guard(_encode_table(t))


def encode_dict(class_):
    """
    Returns a function that transforms a dictionary, mapping the
    types to simpler ones, according to the given mapped class.
    """
    # Unknown keys are silently dropped, matching the decoder below.
    encodings = dict((k, encode(t)) for k, t in types_dict(class_).items())
    return lambda dict_: dict((k, encodings[k](v))
                              for k, v in dict_.items()
                              if k in encodings)


def _decode_table(type_):
    """
    Returns a decoder for values produced by the encoders above, for
    the given SQLAlchemy data type instance *type_*.
    """
    # PORTABILITY FIX: the previous decoders were built on the `apply`
    # builtin, which was removed in Python 3; these splat-lambdas are
    # behaviourally identical.
    if isinstance(type_, types.Date):
        return lambda args: datetime.date(*args)
    elif isinstance(type_, types.DateTime):
        return lambda args: datetime.datetime(*args)
    elif isinstance(type_, types.Time):
        return lambda args: datetime.time(*args)
    elif isinstance(type_, types.LargeBinary):
        return base64.standard_b64decode
    elif isinstance(type_, types.Numeric) and type_.asdecimal:
        return decimal.Decimal
    return identity

#: Decodes a value coming from a JSON string into a richer python value.
decode = lambda t: guard(_decode_table(t))


def decode_dict(class_):
    """
    Returns a function that transforms a dictionary, mapping the
    types to richer ones, according to the given mapped class.
    """
    decodings = dict((k, decode(t)) for k, t in types_dict(class_).items())
    return lambda dict_: dict((k, decodings[k](v))
                              for k, v in dict_.items()
                              if k in decodings)
class PullMessage(BaseMessage):
    """
    A pull message.

    A pull message can be queried over by version, operation or model,
    and can be filtered multiple times.

    It can be instantiated from a raw data dictionary, or can be made
    empty and filled later with specific methods (``add_version``,
    ``add_operation``, ``add_object``).
    """

    #: Datetime of creation.
    created = None

    #: List of operations to perform in the node.
    operations = None

    #: List of versions being pulled.
    versions = None

    def __init__(self, raw_data=None):
        """
        *raw_data* must be a python dictionary, normally the
        product of JSON decoding. If not given, the message will be
        empty and should be filled with the appropriate methods
        (add_*).
        """
        super(PullMessage, self).__init__(raw_data)
        if raw_data is not None:
            self._build_from_raw(raw_data)
        else:
            self.created = datetime.datetime.now()
            self.operations = []
            self.versions = []

    def _build_from_raw(self, data):
        "Decode operations and versions out of the raw dictionary."
        self.created = decode(types.DateTime())(data['created'])
        decode_op = decode_dict(Operation)
        decode_version = decode_dict(Version)
        self.operations = [object_from_dict(Operation, decode_op(op))
                           for op in data['operations']]
        self.versions = [object_from_dict(Version, decode_version(v))
                         for v in data['versions']]

    def query(self, model):
        "Returns a query object for this message."
        return MessageQuery(
            model,
            dict(self.payload, **{
                'models.Operation': self.operations,
                'models.Version': self.versions}))

    def to_json(self):
        """
        Returns a JSON-friendly python dictionary. Structure::

            created: datetime,
            operations: list of operations,
            versions: list of versions,
            payload: dictionary with lists of objects mapped to model names
        """
        encoded = super(PullMessage, self).to_json()
        encoded['created'] = encode(types.DateTime())(self.created)
        encode_op = encode_dict(Operation)
        encode_version = encode_dict(Version)
        encoded['operations'] = [encode_op(properties_dict(op))
                                 for op in self.operations]
        encoded['versions'] = [encode_version(properties_dict(v))
                               for v in self.versions]
        return encoded

    @session_closing
    def add_operation(self, op, swell=True, session=None):
        """
        Adds an operation to the message, including the required
        object if it's possible to include it.

        If *swell* is given and set to ``False``, the operation and
        object will be added bare, without parent objects. Otherwise,
        the parent objects will be added to aid in conflict
        resolution.

        A delete operation doesn't include the associated object. If
        *session* is given, the procedure won't instantiate a new
        session.

        This operation might fail, (due to database inconsitency) in
        which case the internal state of the message won't be affected
        (i.e. it won't end in an inconsistent state).

        DEPRECATED in favor of `fill_for`
        """
        model = op.tracked_model
        if model is None:
            raise ValueError("operation linked to model %s "\
                             "which isn't being tracked" % model)
        if model not in pulled_models:
            return self
        obj = query_model(session, model).\
            filter_by(**{get_pk(model): op.row_id}).first() \
            if op.command != 'd' else None
        self.operations.append(op)
        # if the object isn't there it's because the operation is old,
        # and should be able to be compressed out when performing the
        # conflict resolution phase
        if obj is not None:
            self.add_object(obj)
            if swell:
                # add parent objects to resolve possible conflicts in merge
                for parent in parent_objects(obj, synched_models.models.keys(),
                                             session):
                    self.add_object(parent)
        return self

    @session_closing
    def add_version(self, v, swell=True, session=None):
        """
        Adds a version to the message, and all associated
        operations and objects.

        This method will either fail and leave the message instance as
        if nothing had happened, or it will succeed and return the
        modified message.

        DEPRECATED in favor of `fill_for`
        """
        for op in v.operations:
            if op.content_type_id not in synched_models.ids:
                raise ValueError("version includes operation linked "\
                                 "to model not currently being tracked", op)
        self.versions.append(v)
        for op in v.operations:
            self.add_operation(op, swell=swell, session=session)
        return self

    @session_closing
    def fill_for(self, request, swell=False, include_extensions=True,
                 session=None):
        """
        Fills this pull message (response) with versions, operations
        and objects, for the given request (PullRequestMessage).

        The *swell* parameter is deprecated and considered ``True``
        regardless of the value given. This means that parent objects
        will always be added to the message.

        *include_extensions* dictates whether the pull message will
        include model extensions or not.
        """
        assert isinstance(request, PullRequestMessage), "invalid request"
        versions = session.query(Version)
        if request.latest_version_id is not None:
            versions = versions.\
                filter(Version.version_id > request.latest_version_id)
        required_objects = {}
        required_parents = {}
        for v in versions:
            self.versions.append(v)
            for op in v.operations:
                model = op.tracked_model
                if model is None:
                    raise ValueError("operation linked to model %s "\
                                     "which isn't being tracked" % model)
                if model not in pulled_models: continue
                self.operations.append(op)
                if op.command != 'd':
                    required_objects.setdefault(model, set()).add(op.row_id)

        # Fetch required objects in batches, bounded by the SQL
        # variable limit, and collect their parent references.
        for model, pks in required_objects.items():
            for batch in grouper(pks, MAX_SQL_VARIABLES):
                for obj in query_model(session, model).filter(
                        getattr(model, get_pk(model)).in_(list(batch))).all():
                    self.add_object(obj,
                                    include_extensions=include_extensions)
                    # add parent objects to resolve conflicts in merge
                    for pmodel, ppk in parent_references(
                            obj, synched_models.models.keys()):
                        required_parents.setdefault(pmodel, set()).add(ppk)

        for pmodel, ppks in required_parents.items():
            for batch in grouper(ppks, MAX_SQL_VARIABLES):
                for parent in query_model(session, pmodel).filter(
                        getattr(pmodel, get_pk(pmodel)).in_(list(batch))).all():
                    self.add_object(parent,
                                    include_extensions=include_extensions)
        return self


class PullRequestMessage(BaseMessage):
    """
    A pull request message.

    The message includes information for the server to decide whether
    it should send back related (parent) objects to those directly
    involved, for use in conflict resolution, or not. This is used to
    allow for thinner PullMessage(s) to be built, through the
    add_version and add_operation methods.
    """

    #: List of operation the node has performed since the last
    # synchronization. If empty, the pull is a full 'fast-forward'
    # thin procedure.
    operations = None

    #: The identifier used to select the operations to be included in
    # the pull response.
    latest_version_id = None

    def __init__(self, raw_data=None):
        """
        *raw_data* must be a python dictionary. If not given, the
        message should be filled with the
        add_unversioned_operations method.
        """
        super(PullRequestMessage, self).__init__(raw_data)
        if raw_data is not None:
            self._build_from_raw(raw_data)
        else:
            self.latest_version_id = get_latest_version_id()
            self.operations = []

    def _build_from_raw(self, data):
        decode_op = decode_dict(Operation)
        self.operations = [object_from_dict(Operation, decode_op(op))
                           for op in data['operations']]
        self.latest_version_id = decode(types.Integer())(
            data['latest_version_id'])

    def query(self, model):
        "Returns a query object for this message."
        return MessageQuery(
            model,
            dict(self.payload, **{'models.Operation': self.operations}))

    def to_json(self):
        "Returns a JSON-friendly python dictionary."
        encoded = super(PullRequestMessage, self).to_json()
        encode_op = encode_dict(Operation)
        encoded['operations'] = [encode_op(properties_dict(op))
                                 for op in self.operations]
        encoded['latest_version_id'] = encode(types.Integer())(
            self.latest_version_id)
        return encoded

    def add_operation(self, op):
        """
        Adds an operation to the message, including the required
        object if possible.
        """
        assert op.version_id is None, "the operation {0} is already versioned".\
            format(op)
        model = op.tracked_model
        if model is None:
            raise ValueError("operation linked to model %s "\
                             "which isn't being tracked" % model)
        if model not in pulled_models: return self
        self.operations.append(op)
        return self

    @session_closing
    def add_unversioned_operations(self, session=None):
        """
        Adds all unversioned operations to this message and
        required objects.
        """
        operations = session.query(Operation).\
            filter(Operation.version_id == None).all()
        if any(op.content_type_id not in synched_models.ids
               for op in operations):
            raise ValueError("version includes operation linked "\
                             "to model not currently being tracked")
        for op in operations:
            self.add_operation(op)
        return self
class PushMessage(BaseMessage):
    """
    A push message.

    A push message contains the latest version information, the node
    information, and the list of unversioned operations and the
    required objects for those to be performed.

    The message can be instantiated from a raw data dictionary or can
    be made empty and filled later with the
    ``add_unversioned_operations`` method, in which case the node
    attribute and the latest version identifier should be assigned
    explicitly as well. The method ``set_node`` is required to be used
    for proper key generation.

    If the node is not assigned the message will still behave
    normally, since verification of its presence is not enforced on
    the client, and might not be enforced on the server. Likewise, if
    the latest version isn't assigned, it'll be just interpreted on
    the server to be the initial data load.

    To verify correctness, use ``islegit`` giving a session with
    access to the synch database.
    """

    #: Datetime of creation
    created = None

    #: Node primary key
    node_id = None

    #: Secret used internally to mitigate obnoxiousness.
    _secret = None

    #: Key to this message
    key = None

    #: The latest version
    latest_version_id = None

    #: List of unversioned operations
    operations = None

    def __init__(self, raw_data=None):
        """
        *raw_data* must be a python dictionary. If not given, the
        message will be empty and should be filled after
        instantiation.
        """
        super(PushMessage, self).__init__(raw_data)
        if raw_data is not None:
            self._build_from_raw(raw_data)
        else:
            self.created = datetime.datetime.now()
            self.operations = []

    def _build_from_raw(self, data):
        "Decode the message fields out of the raw dictionary."
        self.created = decode(types.DateTime())(data['created'])
        self.node_id = decode(types.Integer())(data['node_id'])
        self.key = decode(types.String())(data['key'])
        self.latest_version_id = decode(types.Integer())(
            data['latest_version_id'])
        decode_op = decode_dict(Operation)
        self.operations = [object_from_dict(Operation, decode_op(op))
                           for op in data['operations']]

    def query(self, model):
        "Returns a query object for this message."
        return MessageQuery(
            model,
            dict(
                self.payload,
                **{'models.Operation': self.operations}))

    def to_json(self):
        """
        Returns a JSON-friendly python dictionary. Structure::

            created: datetime,
            node_id: node primary key or null,
            key: a string generated from the secret and part of the message,
            latest_version_id: number or null,
            operations: list of operations,
            payload: dictionay with lists of objects mapped to model names
        """
        encoded = super(PushMessage, self).to_json()
        encoded['created'] = encode(types.DateTime())(self.created)
        encoded['node_id'] = encode(types.Integer())(self.node_id)
        encoded['key'] = encode(types.String())(self.key)
        encoded['latest_version_id'] = encode(types.Integer())(
            self.latest_version_id)
        encode_op = encode_dict(Operation)
        encoded['operations'] = [encode_op(properties_dict(op))
                                 for op in self.operations]
        return encoded

    def _portion(self):
        "Returns part of this message as a string."
        return "".join("&{0}#{1}#{2}".\
                       format(op.row_id, op.content_type_id, op.command)
                       for op in self.operations)

    def _sign(self):
        # Without the node secret the message simply stays unsigned.
        if self._secret is not None:
            self.key = hashlib.sha512(self._secret + self._portion()).hexdigest()

    def set_node(self, node):
        "Sets the node and key for this message."
        if node is None: return
        self.node_id = node.node_id
        self._secret = node.secret
        self._sign()

    def islegit(self, session):
        "Checks whether the key for this message is proper."
        if self.key is None or self.node_id is None: return False
        node = session.query(Node).filter(Node.node_id == self.node_id).first()
        return node is not None and \
            self.key == hashlib.sha512(node.secret + self._portion()).hexdigest()

    @session_closing
    def add_unversioned_operations(self, session=None, include_extensions=True):
        """
        Adds all unversioned operations to this message, including the
        required objects for them to be performed.
        """
        operations = session.query(Operation).\
            filter(Operation.version_id == None).all()
        if any(op.content_type_id not in synched_models.ids
               for op in operations):
            raise ValueError("version includes operation linked "\
                             "to model not currently being tracked")
        required_objects = {}
        for op in operations:
            model = op.tracked_model
            if model not in pushed_models: continue
            self.operations.append(op)
            # Deletes carry no object payload.
            if op.command != 'd':
                required_objects.setdefault(model, set()).add(op.row_id)
        # Fetch the objects in batches bounded by the SQL variable limit.
        for model, pks in required_objects.items():
            for batch in grouper(pks, MAX_SQL_VARIABLES):
                for obj in query_model(session, model).filter(
                        getattr(model, get_pk(model)).in_(list(batch))).all():
                    self.add_object(obj,
                                    include_extensions=include_extensions)
        if self.key is not None:
            # overwrite since it's probably an incorrect key
            self._sign()
        return self
3 | """ 4 | 5 | from dbsync.models import Node 6 | from dbsync.utils import properties_dict, object_from_dict 7 | from dbsync.messages.codecs import encode_dict, decode_dict 8 | 9 | 10 | class RegisterMessage(object): 11 | """A register message with node information.""" 12 | 13 | #: The node to be registered in the client application 14 | node = None 15 | 16 | def __init__(self, raw_data=None): 17 | if raw_data is not None: 18 | self._build_from_raw(raw_data) 19 | 20 | def _build_from_raw(self, data): 21 | self.node = object_from_dict(Node, decode_dict(Node)(data['node'])) 22 | 23 | def to_json(self): 24 | encoded = {} 25 | encoded['node'] = None 26 | if self.node is not None: 27 | encoded['node'] = encode_dict(Node)(properties_dict(self.node)) 28 | return encoded 29 | -------------------------------------------------------------------------------- /dbsync/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Internal model used to keep track of versions and operations. 3 | """ 4 | 5 | from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, BigInteger 6 | from sqlalchemy.orm import relationship, backref, validates 7 | from sqlalchemy.ext.declarative import declarative_base 8 | from sqlalchemy.ext.declarative.api import DeclarativeMeta 9 | 10 | from dbsync.lang import * 11 | from dbsync.utils import get_pk, query_model, properties_dict 12 | from dbsync.logs import get_logger 13 | 14 | 15 | logger = get_logger(__name__) 16 | 17 | 18 | #: Database tables prefix. 
19 | tablename_prefix = "sync_" 20 | 21 | 22 | class PrefixTables(DeclarativeMeta): 23 | def __init__(cls, classname, bases, dict_): 24 | if '__tablename__' in dict_: 25 | tn = dict_['__tablename__'] 26 | cls.__tablename__ = dict_['__tablename__'] = tablename_prefix + tn 27 | return super(PrefixTables, cls).__init__(classname, bases, dict_) 28 | 29 | Base = declarative_base(metaclass=PrefixTables) 30 | 31 | 32 | class ContentType(Base): 33 | "A weak abstraction over a database table." 34 | 35 | __tablename__ = "content_types" 36 | 37 | content_type_id = Column(BigInteger, primary_key=True) 38 | table_name = Column(String(500)) 39 | model_name = Column(String(500)) 40 | 41 | def __repr__(self): 42 | return u"".\ 43 | format(self.content_type_id, self.table_name, self.model_name) 44 | 45 | 46 | class Node(Base): 47 | """ 48 | A node registry. 49 | 50 | A node is a client application installed somewhere else. 51 | """ 52 | 53 | __tablename__ = "nodes" 54 | 55 | node_id = Column(Integer, primary_key=True) 56 | registered = Column(DateTime) 57 | registry_user_id = Column(Integer) 58 | secret = Column(String(128)) 59 | 60 | def __init__(self, *args, **kwargs): 61 | super(Node, self).__init__(*args, **kwargs) 62 | 63 | def __repr__(self): 64 | return u"".\ 66 | format(self.node_id, 67 | self.registered, 68 | self.registry_user_id, 69 | self.secret) 70 | 71 | 72 | class Version(Base): 73 | """ 74 | A database version. 75 | 76 | These are added for each 'push' accepted and executed without 77 | problems. 
78 | """ 79 | 80 | __tablename__ = "versions" 81 | 82 | version_id = Column(Integer, primary_key=True) 83 | node_id = Column(Integer, ForeignKey(Node.__tablename__ + ".node_id")) 84 | created = Column(DateTime) 85 | 86 | node = relationship(Node) 87 | 88 | def __repr__(self): 89 | return u"".\ 90 | format(self.version_id, self.created) 91 | 92 | 93 | class OperationError(Exception): pass 94 | 95 | 96 | class Operation(Base): 97 | """ 98 | A database operation (insert, delete or update). 99 | 100 | The operations are grouped in versions and ordered as they are 101 | executed. 102 | """ 103 | 104 | __tablename__ = "operations" 105 | 106 | row_id = Column(Integer) 107 | version_id = Column( 108 | Integer, 109 | ForeignKey(Version.__tablename__ + ".version_id"), 110 | nullable=True) 111 | content_type_id = Column(BigInteger) 112 | tracked_model = None # to be injected 113 | command = Column(String(1)) 114 | command_options = ('i', 'u', 'd') 115 | order = Column(Integer, primary_key=True) 116 | 117 | version = relationship(Version, backref=backref("operations", lazy="joined")) 118 | 119 | @validates('command') 120 | def validate_command(self, key, command): 121 | assert command in self.command_options 122 | return command 123 | 124 | def __repr__(self): 125 | return u"".\ 126 | format(self.row_id, self.tracked_model, self.command) 127 | 128 | def references(self, obj): 129 | "Whether this operation references the given object or not." 130 | if self.row_id != getattr(obj, get_pk(obj), None): 131 | return False 132 | model = self.tracked_model 133 | if model is None: 134 | return False # operation doesn't even refer to a tracked model 135 | return model is type(obj) 136 | 137 | def perform(operation, container, session, node_id=None): 138 | """ 139 | Performs *operation*, looking for required data in 140 | *container*, and using *session* to perform it. 141 | 142 | *container* is an instance of 143 | dbsync.messages.base.BaseMessage. 
144 | 145 | *node_id* is the node responsible for the operation, if known 146 | (else ``None``). 147 | 148 | If at any moment this operation fails for predictable causes, 149 | it will raise an *OperationError*. 150 | """ 151 | model = operation.tracked_model 152 | if model is None: 153 | raise OperationError("no content type for this operation", operation) 154 | 155 | if operation.command == 'i': 156 | obj = query_model(session, model).\ 157 | filter(getattr(model, get_pk(model)) == operation.row_id).first() 158 | pull_obj = container.query(model).\ 159 | filter(attr('__pk__') == operation.row_id).first() 160 | if pull_obj is None: 161 | raise OperationError( 162 | "no object backing the operation in container", operation) 163 | if obj is None: 164 | session.add(pull_obj) 165 | else: 166 | # Don't raise an exception if the incoming object is 167 | # exactly the same as the local one. 168 | if properties_dict(obj) == properties_dict(pull_obj): 169 | logger.warning(u"insert attempted when an identical object " 170 | u"already existed in local database: " 171 | u"model {0} pk {1}".format(model.__name__, 172 | operation.row_id)) 173 | else: 174 | raise OperationError( 175 | u"insert attempted when the object already existed: " 176 | u"model {0} pk {1}".format(model.__name__, 177 | operation.row_id)) 178 | 179 | elif operation.command == 'u': 180 | obj = query_model(session, model).\ 181 | filter(getattr(model, get_pk(model)) == operation.row_id).first() 182 | if obj is None: 183 | # For now, the record will be created again, but is an 184 | # error because nothing should be deleted without 185 | # using dbsync 186 | # raise OperationError( 187 | # "the referenced object doesn't exist in database", operation) 188 | logger.warning( 189 | u"The referenced object doesn't exist in database. " 190 | u"Node %s. 
Operation %s", 191 | node_id, 192 | operation) 193 | 194 | pull_obj = container.query(model).\ 195 | filter(attr('__pk__') == operation.row_id).first() 196 | if pull_obj is None: 197 | raise OperationError( 198 | "no object backing the operation in container", operation) 199 | session.merge(pull_obj) 200 | 201 | elif operation.command == 'd': 202 | obj = query_model(session, model, only_pk=True).\ 203 | filter(getattr(model, get_pk(model)) == operation.row_id).first() 204 | if obj is None: 205 | # The object is already deleted in the server 206 | # The final state in node and server are the same. But 207 | # it's an error because nothing should be deleted 208 | # without using dbsync 209 | logger.warning( 210 | "The referenced object doesn't exist in database. " 211 | u"Node %s. Operation %s", 212 | node_id, 213 | operation) 214 | else: 215 | session.delete(obj) 216 | 217 | else: 218 | raise OperationError( 219 | "the operation doesn't specify a valid command ('i', 'u', 'd')", 220 | operation) 221 | -------------------------------------------------------------------------------- /dbsync/server/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Interface for the synchronization server. 3 | 4 | The server listens to 'push' and 'pull' requests and provides a 5 | customizable registry service. 6 | """ 7 | 8 | from dbsync.server.tracking import track 9 | from dbsync.core import extend 10 | from dbsync.server.handlers import ( 11 | handle_register, 12 | handle_pull, 13 | before_push, 14 | after_push, 15 | handle_push, 16 | handle_repair, 17 | handle_query) 18 | from dbsync.server.trim import trim 19 | -------------------------------------------------------------------------------- /dbsync/server/conflicts.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. module:: server.conflicts 3 | :synopsis: Conflict detection for the centralized push operation. 
4 | """ 5 | 6 | from sqlalchemy.schema import UniqueConstraint 7 | 8 | from dbsync.lang import * 9 | from dbsync.utils import get_pk, class_mapper, query_model, column_properties 10 | 11 | 12 | def find_unique_conflicts(push_message, session): 13 | """ 14 | Returns a list of conflicts caused by unique constraints in the 15 | given push message contrasted against the database. Each conflict 16 | is a dictionary with the following fields:: 17 | 18 | object: the conflicting object in database, bound to the 19 | session 20 | columns: tuple of column names in the unique constraint 21 | new_values: tuple of values that can be used to update the 22 | conflicting object. 23 | """ 24 | conflicts = [] 25 | 26 | for pk, model in ((op.row_id, op.tracked_model) 27 | for op in push_message.operations 28 | if op.command != 'd'): 29 | if model is None: continue 30 | 31 | for constraint in ifilter(lambda c: isinstance(c, UniqueConstraint), 32 | class_mapper(model).mapped_table.constraints): 33 | 34 | unique_columns = tuple(col.name for col in constraint.columns) 35 | remote_obj = push_message.query(model).\ 36 | filter(attr('__pk__') == pk).first() 37 | remote_values = tuple(getattr(remote_obj, col, None) 38 | for col in unique_columns) 39 | 40 | if all(value is None for value in remote_values): continue 41 | local_obj = query_model(session, model).\ 42 | filter_by(**dict(izip(unique_columns, remote_values))).first() 43 | if local_obj is None: continue 44 | local_pk = getattr(local_obj, get_pk(model)) 45 | if local_pk == pk: continue 46 | 47 | push_obj = push_message.query(model).\ 48 | filter(attr('__pk__') == local_pk).first() 49 | if push_obj is None: continue # push will fail 50 | 51 | conflicts.append( 52 | {'object': local_obj, 53 | 'columns': unique_columns, 54 | 'new_values': tuple(getattr(push_obj, col) 55 | for col in unique_columns)}) 56 | 57 | return conflicts 58 | -------------------------------------------------------------------------------- 
/dbsync/server/handlers.py:
--------------------------------------------------------------------------------
"""
Registry, pull, push and other request handlers.

The pull cycle consists in receiving a version identifier and sending
back a PullMessage filled with versions above the one received.

The push cycle consists in receiving a complete PushMessage and either
rejecting it based on latest version or signature, or accepting it and
performing the operations indicated in it. The operations should also
be inserted in the operations table, in the correct order but getting
new keys for the 'order' column, and linked with a newly created
version. If it accepts the message, the push handler should also
return the new version identifier to the node (and the programmer is
tasked to send the HTTP response).
"""

import datetime

from sqlalchemy.orm import make_transient

from dbsync.lang import *
from dbsync.utils import (
    generate_secret,
    properties_dict,
    column_properties,
    get_pk,
    query_model,
    EventRegister)
from dbsync import core
from dbsync.models import (
    Version,
    Node,
    OperationError,
    Operation)
from dbsync.messages.base import BaseMessage
from dbsync.messages.register import RegisterMessage
from dbsync.messages.pull import PullMessage, PullRequestMessage
from dbsync.messages.push import PushMessage
from dbsync.server.conflicts import find_unique_conflicts
from dbsync.logs import get_logger


logger = get_logger(__name__)


@core.session_closing
def handle_query(data, session=None):
    "Responds to a query request."
    model = core.synched_models.model_names.\
        get(data.get('model', None), core.null_model).model
    if model is None: return None
    mname = model.__name__
    # Accept only filters of the form '<ModelName>_<column>' whose
    # column actually exists on the model; strip the prefix.
    filters = dict((k, v) for k, v in ((k[len(mname) + 1:], v)
                                       for k, v in data.iteritems()
                                       if k.startswith(mname + '_'))
                   if k and k in column_properties(model))
    message = BaseMessage()
    q = query_model(session, model)
    if filters:
        q = q.filter_by(**filters)
    for obj in q:
        message.add_object(obj)
    return message.to_json()


@core.session_closing
def handle_repair(data=None, session=None):
    "Handle repair request. Return whole server database."
    include_extensions = 'exclude_extensions' not in (data or {})
    latest_version_id = core.get_latest_version_id(session=session)
    message = BaseMessage()
    # Dump every object of every synched model; the client replaces
    # its local copy wholesale.
    for model in core.synched_models.models.iterkeys():
        for obj in query_model(session, model):
            message.add_object(obj, include_extensions=include_extensions)
    response = message.to_json()
    response['latest_version_id'] = latest_version_id
    return response


@core.with_transaction()
def handle_register(user_id=None, node_id=None, session=None):
    """
    Handle a registry request, creating a new node, wrapping it in a
    message and returning it to the client node.

    *user_id* can be a numeric key to a user record, which will be set
    in the node record itself.

    If *node_id* is given, it will be used instead of creating a new
    node. This allows for node reuse according to criteria specified
    by the programmer.
    """
    message = RegisterMessage()
    if node_id is not None:
        node = session.query(Node).filter(Node.node_id == node_id).first()
        if node is not None:
            message.node = node
            return message.to_json()
    # No reusable node: create one with a fresh signing secret.
    newnode = Node()
    newnode.registered = datetime.datetime.now()
    newnode.registry_user_id = user_id
    newnode.secret = generate_secret(128)
    session.add(newnode)
    session.flush()
    message.node = newnode
    return message.to_json()


class PullRejected(Exception): pass


def handle_pull(data, swell=False, include_extensions=True):
    """
    Handle the pull request and return a dictionary object to be sent
    back to the node.

    *data* must be a dictionary-like object, usually one obtained from
    decoding a JSON dictionary in the POST body.
    """
    try:
        request_message = PullRequestMessage(data)
    except KeyError:
        raise PullRejected("request object isn't a valid PullRequestMessage", data)

    message = PullMessage()
    message.fill_for(
        request_message,
        swell=swell,
        include_extensions=include_extensions)
    return message.to_json()


class PushRejected(Exception): pass

# A push rejection that the client can resolve by pulling first.
class PullSuggested(PushRejected): pass


#: Callbacks receive the session and the message.
before_push = EventRegister()
after_push = EventRegister()


@core.with_transaction()
def handle_push(data, session=None):
    """
    Handle the push request and return a dictionary object to be sent
    back to the node.

    If the push is rejected, this procedure will raise a
    dbsync.server.handlers.PushRejected exception.

    *data* must be a dictionary-like object, usually the product of
    parsing a JSON string.
    """
    message = None
    try:
        message = PushMessage(data)
    except KeyError:
        raise PushRejected("request object isn't a valid PushMessage", data)
    latest_version_id = core.get_latest_version_id(session=session)
    # The node must be pushing on top of the server's latest version;
    # if it's behind, suggest a pull instead of a plain rejection.
    if latest_version_id != message.latest_version_id:
        exc = "version identifier isn't the latest one; "\
            "given: %s" % message.latest_version_id
        if latest_version_id is None:
            raise PushRejected(exc)
        if message.latest_version_id is None:
            raise PullSuggested(exc)
        if message.latest_version_id < latest_version_id:
            raise PullSuggested(exc)
        raise PushRejected(exc)
    if not message.operations:
        raise PushRejected("message doesn't contain operations")
    if not message.islegit(session):
        raise PushRejected("message isn't properly signed")

    for listener in before_push:
        listener(session, message)

    # I) detect unique constraint conflicts and resolve them if possible
    unique_conflicts = find_unique_conflicts(message, session)
    conflicting_objects = set()
    for uc in unique_conflicts:
        obj = uc['object']
        conflicting_objects.add(obj)
        for key, value in izip(uc['columns'], uc['new_values']):
            setattr(obj, key, value)
    # Delete-and-reinsert so the unique constraints aren't violated
    # transiently while the conflicting rows swap values.
    for obj in conflicting_objects:
        make_transient(obj) # remove from session
    for model in set(type(obj) for obj in conflicting_objects):
        pk_name = get_pk(model)
        pks = [getattr(obj, pk_name)
               for obj in conflicting_objects
               if type(obj) is model]
        session.query(model).filter(getattr(model, pk_name).in_(pks)).\
            delete(synchronize_session=False) # remove from the database
    session.add_all(conflicting_objects) # reinsert
    session.flush()

    # II) perform the operations
    operations = filter(lambda o: o.tracked_model is not None, message.operations)
    try:
        for op in operations:
            op.perform(message, session,
message.node_id) 204 | except OperationError as e: 205 | logger.exception(u"Couldn't perform operation in push from node %s.", 206 | message.node_id) 207 | raise PushRejected("at least one operation couldn't be performed", 208 | *e.args) 209 | 210 | # III) insert a new version 211 | version = Version(created=datetime.datetime.now(), node_id=message.node_id) 212 | session.add(version) 213 | 214 | # IV) insert the operations, discarding the 'order' column 215 | for op in sorted(operations, key=attr('order')): 216 | new_op = Operation() 217 | for k in ifilter(lambda k: k != 'order', properties_dict(op)): 218 | setattr(new_op, k, getattr(op, k)) 219 | session.add(new_op) 220 | new_op.version = version 221 | session.flush() 222 | 223 | for listener in after_push: 224 | listener(session, message) 225 | 226 | # return the new version id back to the node 227 | return {'new_version_id': version.version_id} 228 | -------------------------------------------------------------------------------- /dbsync/server/tracking.py: -------------------------------------------------------------------------------- 1 | """ 2 | Listeners to SQLAlchemy events to keep track of CUD operations. 3 | 4 | On the server side, each operation will also trigger a new version, so 5 | as to allow direct use of the database while maintaining occassionally 6 | connected nodes capable of synchronizing their data. 7 | """ 8 | 9 | import logging 10 | import inspect 11 | import datetime 12 | import warnings 13 | 14 | from sqlalchemy import event 15 | 16 | from dbsync import core 17 | from dbsync.models import Operation, Version 18 | from dbsync.logs import get_logger 19 | 20 | 21 | logger = get_logger(__name__) 22 | 23 | 24 | if core.mode == 'client': 25 | warnings.warn("don't import both server and client") 26 | core.mode = 'server' 27 | 28 | 29 | def make_listener(command): 30 | "Builds a listener for the given command (i, u, d)." 
31 | @core.session_committing 32 | def listener(mapper, connection, target, session=None): 33 | if getattr(core.SessionClass.object_session(target), 34 | core.INTERNAL_SESSION_ATTR, 35 | False): 36 | return 37 | if not core.listening: 38 | logger.warning("dbsync is disabled; " 39 | "aborting listener to '{0}' command".format(command)) 40 | return 41 | if command == 'u' and not core.SessionClass.object_session(target).\ 42 | is_modified(target, include_collections=False): 43 | return 44 | tname = mapper.mapped_table.name 45 | if tname not in core.synched_models.tables: 46 | logging.error("you must track a mapped class to table {0} "\ 47 | "to log operations".format(tname)) 48 | return 49 | # one version for each operation 50 | version = Version(created=datetime.datetime.now()) 51 | pk = getattr(target, mapper.primary_key[0].name) 52 | op = Operation( 53 | row_id=pk, 54 | content_type_id=core.synched_models.tables[tname].id, 55 | command=command) 56 | session.add(version) 57 | session.add(op) 58 | op.version = version 59 | return listener 60 | 61 | 62 | def _start_tracking(model, directions): 63 | if 'pull' in directions: 64 | core.pulled_models.add(model) 65 | if 'push' in directions: 66 | core.pushed_models.add(model) 67 | if model in core.synched_models.models: 68 | return model 69 | core.synched_models.install(model) 70 | event.listen(model, 'after_insert', make_listener('i')) 71 | event.listen(model, 'after_update', make_listener('u')) 72 | event.listen(model, 'after_delete', make_listener('d')) 73 | return model 74 | 75 | 76 | def track(*directions): 77 | """ 78 | Adds an ORM class to the list of synchronized classes. 79 | 80 | It can be used as a class decorator. This will also install 81 | listeners to keep track of CUD operations for the given model. 82 | 83 | *directions* are optional arguments of values in ('push', 'pull') 84 | that can restrict the way dbsync handles the class during those 85 | handlers. If not given, both values are assumed. 
If only one of 86 | them is given, the other handler will ignore the tracked class. 87 | """ 88 | valid = ('push', 'pull') 89 | if not directions: 90 | return lambda model: _start_tracking(model, valid) 91 | if len(directions) == 1 and inspect.isclass(directions[0]): 92 | return _start_tracking(directions[0], valid) 93 | assert all(d in valid for d in directions), \ 94 | "track only accepts the arguments: {0}".format(', '.join(valid)) 95 | return lambda model: _start_tracking(model, directions) 96 | -------------------------------------------------------------------------------- /dbsync/server/trim.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trim the server synchronization tables to free space. 3 | """ 4 | 5 | from dbsync.lang import * 6 | from dbsync import core 7 | from dbsync.models import Node, Version, Operation 8 | 9 | 10 | @core.session_committing 11 | def trim(session=None): 12 | """ 13 | Clears space by deleting operations and versions that are no 14 | longer needed. 15 | 16 | This might cause the server to answer incorrectly to pull requests 17 | from nodes that were late to register. To go around that, a repair 18 | should be enforced after the node's register. 19 | 20 | Another problem with this procedure is that it won't clear space 21 | if there's at least one abandoned node registered. The task of 22 | keeping the nodes registry clean of those is left to the 23 | programmer. 
24 | """ 25 | versions = [maybe(session.query(Version).\ 26 | filter(Version.node_id == node.node_id).\ 27 | order_by(Version.version_id.desc()).first(), 28 | attr('version_id'), 29 | None) 30 | for node in session.query(Node)] 31 | if not versions: 32 | last_id = core.get_latest_version_id(session=session) 33 | # all operations are versioned according to dbsync.server.track 34 | session.query(Operation).delete() 35 | session.query(Version).filter(Version.version_id != last_id).delete() 36 | return 37 | if None in versions: return # dead nodes block the trim 38 | minversion = min(versions) 39 | session.query(Operation).filter(Operation.version_id <= minversion).delete() 40 | session.query(Version).filter(Version.version_id < minversion).delete() 41 | -------------------------------------------------------------------------------- /dbsync/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. module:: dbsync.utils 3 | :synopsis: Utility functions. 4 | """ 5 | 6 | import random 7 | import inspect 8 | from sqlalchemy.orm import ( 9 | object_mapper, 10 | class_mapper, 11 | ColumnProperty, 12 | noload, 13 | defer, 14 | instrumentation, 15 | state) 16 | 17 | 18 | def generate_secret(length=128): 19 | chars = "0123456789"\ 20 | "abcdefghijklmnopqrstuvwxyz"\ 21 | "ABCDEFGHIJKLMNOPQRSTUVWXYZ"\ 22 | ".,_-+*@:;[](){}~!?|<>=/\&$#" 23 | return "".join(random.choice(chars) for _ in xrange(length)) 24 | 25 | 26 | def properties_dict(sa_object): 27 | """ 28 | Returns a dictionary of column-properties for the given SQLAlchemy 29 | mapped object. 30 | """ 31 | mapper = object_mapper(sa_object) 32 | return dict((prop.key, getattr(sa_object, prop.key)) 33 | for prop in mapper.iterate_properties 34 | if isinstance(prop, ColumnProperty)) 35 | 36 | 37 | def column_properties(sa_variant): 38 | "Returns a list of column-properties." 
    # Accept either a mapped class or a mapped instance.
    mapper = class_mapper(sa_variant) if inspect.isclass(sa_variant) \
        else object_mapper(sa_variant)
    return [prop.key for prop in mapper.iterate_properties
            if isinstance(prop, ColumnProperty)]


def types_dict(sa_class):
    """
    Returns a dictionary of column-properties mapped to their
    SQLAlchemy types for the given mapped class.
    """
    mapper = class_mapper(sa_class)
    return dict((prop.key, prop.columns[0].type)
                for prop in mapper.iterate_properties
                if isinstance(prop, ColumnProperty))


def construct_bare(class_):
    """
    Returns an object of type *class_*, without invoking the class'
    constructor.
    """
    obj = class_.__new__(class_)
    # Attach SQLAlchemy instrumentation state by hand, since the
    # usual __init__-time instrumentation was bypassed.
    manager = getattr(class_, instrumentation.ClassManager.MANAGER_ATTR)
    setattr(obj, manager.STATE_ATTR, state.InstanceState(obj, manager))
    return obj


def object_from_dict(class_, dict_):
    "Returns an object from a dictionary of attributes."
    obj = construct_bare(class_)
    for k, v in dict_.iteritems():
        setattr(obj, k, v)
    return obj


def copy(obj):
    "Returns a copy of the given object, not linked to a session."
    return object_from_dict(type(obj), properties_dict(obj))


def get_pk(sa_variant):
    "Returns the primary key name for the given mapped class or object."
    # Note: assumes a single-column primary key, per the library's
    # stated restrictions.
    mapper = class_mapper(sa_variant) if inspect.isclass(sa_variant) \
        else object_mapper(sa_variant)
    return mapper.primary_key[0].key


def parent_references(sa_object, models):
    """
    Returns a list of pairs (*sa_class*, *pk*) that reference all the
    parent objects of *sa_object*.
91 | """ 92 | mapper = object_mapper(sa_object) 93 | references = [(getattr(sa_object, k.parent.name), k.column.table) 94 | for k in mapper.mapped_table.foreign_keys] 95 | def get_model(table): 96 | for m in models: 97 | if class_mapper(m).mapped_table == table: 98 | return m 99 | return None 100 | return [(m, pk) 101 | for m, pk in ((get_model(table), v) for v, table in references) 102 | if m is not None] 103 | 104 | 105 | def parent_objects(sa_object, models, session, only_pk=False): 106 | """ 107 | Returns all the parent objects the given *sa_object* points to 108 | (through foreign keys in *sa_object*). 109 | 110 | *models* is a list of mapped classes. 111 | 112 | *session* must be a valid SA session instance. 113 | """ 114 | return filter(lambda obj: obj is not None, 115 | (query_model(session, m, only_pk=only_pk).\ 116 | filter_by(**{get_pk(m): val}).first() 117 | for m, val in parent_references(sa_object, models))) 118 | 119 | 120 | def query_model(session, sa_class, only_pk=False): 121 | """ 122 | Returns a query for *sa_class* that doesn't load any relationship 123 | attribute. 124 | """ 125 | opts = (noload('*'),) 126 | if only_pk: 127 | pk = get_pk(sa_class) 128 | opts += tuple( 129 | defer(prop.key) 130 | for prop in class_mapper(sa_class).iterate_properties 131 | if isinstance(prop, ColumnProperty) 132 | if prop.key != pk) 133 | return session.query(sa_class).options(*opts) 134 | 135 | 136 | class EventRegister(object): 137 | 138 | def __init__(self): 139 | self._listeners = [] 140 | 141 | def __iter__(self): 142 | for listener in self._listeners: 143 | yield listener 144 | 145 | def listen(self, listener): 146 | "Register a listener. May be used as a decorator." 
147 | assert inspect.isroutine(listener), "invalid listener" 148 | if listener not in self._listeners: 149 | self._listeners.append(listener) 150 | return listener 151 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | nose 2 | sqlalchemy>=0.8.0 3 | Flask 4 | requests 5 | -------------------------------------------------------------------------------- /diagram.gv: -------------------------------------------------------------------------------- 1 | strict digraph diagram { 2 | n0 [label="Client:\nunsynchronized"]; 3 | n1 [label="Server:\nawaits push"]; 4 | n2 [label="Server:\nawaits pull"]; 5 | n3 [label="Client:\nup-to-date"]; 6 | n4 [label="Client:\nsynchronized"]; 7 | n5 [label="Server:\nsynchronized"]; 8 | edge [label="push"]; 9 | n0 -> n1; 10 | edge [label="push rejected"]; 11 | n1 -> n0; 12 | edge [label="pull"]; 13 | n0 -> n2; 14 | edge [label="pull payload"]; 15 | n2 -> n3; 16 | edge [label="push"]; 17 | n3 -> n1; 18 | edge [label="push accepted"]; 19 | n1 -> n5 -> n4; 20 | } 21 | -------------------------------------------------------------------------------- /diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bintlabs/python-sync-db/bb23d77abf560793696f906e030950aec04c3361/diagram.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | sqlalchemy>=0.8.0 2 | requests 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | import dbsync 3 | 4 | setup(name='dbsync', 5 | version=dbsync.__version__, 6 | url='https://github.com/bintlabs/python-sync-db', 7 | 
author='Bint', 8 | packages=['dbsync', 'dbsync.client', 'dbsync.server', 'dbsync.messages'], 9 | description='Centralized database synchronization for SQLAlchemy', 10 | install_requires=['sqlalchemy>=0.8.0', 'requests'], 11 | license='MIT',) 12 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bintlabs/python-sync-db/bb23d77abf560793696f906e030950aec04c3361/tests/__init__.py -------------------------------------------------------------------------------- /tests/codec_tests.py: -------------------------------------------------------------------------------- 1 | from nose.tools import * 2 | import datetime 3 | import decimal 4 | 5 | from dbsync.messages.codecs import encode, encode_dict, decode, decode_dict 6 | from sqlalchemy import types 7 | 8 | 9 | def test_encode_date(): 10 | today = datetime.date.today() 11 | e = encode(types.Date()) 12 | d = decode(types.Date()) 13 | assert today == d(e(today)) 14 | 15 | 16 | def test_encode_datetime(): 17 | now = datetime.datetime.now() 18 | e = encode(types.DateTime()) 19 | d = decode(types.DateTime()) 20 | # microseconds are lost, but that's ok 21 | assert now.timetuple()[:6] == d(e(now)).timetuple()[:6] 22 | 23 | 24 | def test_encode_numeric(): 25 | num = decimal.Decimal('3.3') 26 | e = encode(types.Numeric()) 27 | d = decode(types.Numeric()) 28 | assert num == d(e(num)) 29 | 30 | 31 | def test_encode_float_numeric(): 32 | num = 3.3 33 | e = encode(types.Numeric(asdecimal=False)) 34 | d = decode(types.Numeric(asdecimal=False)) 35 | assert num == d(e(num)) 36 | -------------------------------------------------------------------------------- /tests/conflict_detection_tests.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from nose.tools import * 3 | 4 | from dbsync import models, core 5 | from 
dbsync.messages.pull import PullMessage 6 | from dbsync.client.conflicts import ( 7 | find_direct_conflicts, 8 | find_dependency_conflicts) 9 | 10 | from tests.models import A, B, Base, Session 11 | 12 | def get_content_type_ids(): 13 | return (core.synched_models.models[A].id, core.synched_models.models[B].id) 14 | 15 | ct_a_id, ct_b_id = get_content_type_ids() 16 | 17 | 18 | def addstuff(): 19 | a1 = A(name="first a") 20 | a2 = A(name="second a") 21 | b1 = B(name="first b", a=a1) 22 | b2 = B(name="second b", a=a1) 23 | b3 = B(name="third b", a=a2) 24 | session = Session() 25 | session.add_all([a1, a2, b1, b2, b3]) 26 | session.commit() 27 | 28 | def changestuff(): 29 | session = Session() 30 | a1, a2 = session.query(A) 31 | b1, b2, b3 = session.query(B) 32 | a1.name = "first a modified" 33 | b2.a = a2 34 | session.delete(b3) 35 | session.commit() 36 | 37 | def create_fake_operations(): 38 | return [models.Operation(row_id=3, content_type_id=ct_b_id, command='u'), 39 | models.Operation(row_id=1, content_type_id=ct_a_id, command='d'), 40 | models.Operation(row_id=2, content_type_id=ct_a_id, command='d')] 41 | 42 | def setup(): 43 | pass 44 | 45 | @core.with_listening(False) 46 | def teardown(): 47 | session = Session() 48 | map(session.delete, session.query(A)) 49 | map(session.delete, session.query(B)) 50 | map(session.delete, session.query(models.Operation)) 51 | session.commit() 52 | 53 | 54 | @with_setup(setup, teardown) 55 | def test_find_direct_conflicts(): 56 | addstuff() 57 | changestuff() 58 | session = Session() 59 | message_ops = create_fake_operations() 60 | conflicts = find_direct_conflicts( 61 | message_ops, session.query(models.Operation).all()) 62 | expected = [ 63 | (message_ops[0], 64 | models.Operation(row_id=3, content_type_id=ct_b_id, command='d')), # b3 65 | (message_ops[1], 66 | models.Operation(row_id=1, content_type_id=ct_a_id, command='u'))] # a1 67 | logging.info(conflicts) 68 | logging.info(expected) 69 | assert repr(conflicts) == 
repr(expected) 70 | 71 | 72 | @with_setup(setup, teardown) 73 | def test_find_dependency_conflicts(): 74 | addstuff() 75 | changestuff() 76 | session = Session() 77 | message_ops = create_fake_operations() 78 | conflicts = find_dependency_conflicts( 79 | message_ops, 80 | session.query(models.Operation).all(), 81 | session) 82 | expected = [ 83 | (message_ops[1], # a1 84 | models.Operation(row_id=1, content_type_id=ct_b_id, command='i')), # b1 85 | (message_ops[2], # a2 86 | models.Operation(row_id=2, content_type_id=ct_b_id, command='i')), # b2 87 | (message_ops[2], # a2 88 | models.Operation(row_id=2, content_type_id=ct_b_id, command='u'))] # b2 89 | logging.info(conflicts) 90 | logging.info(expected) 91 | assert repr(conflicts) == repr(expected) 92 | -------------------------------------------------------------------------------- /tests/models.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | from sqlalchemy import Column, Integer, String, ForeignKey, create_engine 4 | from sqlalchemy.orm import relationship, sessionmaker 5 | from sqlalchemy.ext.declarative import declarative_base 6 | 7 | from dbsync.utils import generate_secret 8 | import dbsync 9 | from dbsync import client, models 10 | 11 | 12 | engine = create_engine("sqlite://") 13 | Session = sessionmaker(bind=engine) 14 | 15 | 16 | Base = declarative_base() 17 | 18 | 19 | @client.track 20 | class A(Base): 21 | __tablename__ = "test_a" 22 | 23 | id = Column(Integer, primary_key=True) 24 | name = Column(String) 25 | 26 | def __repr__(self): 27 | return u"<A id={0} name={1}>".format(self.id, self.name) 28 | 29 | 30 | @client.track 31 | class B(Base): 32 | __tablename__ = "test_b" 33 | 34 | id = Column(Integer, primary_key=True) 35 | name = Column(String) 36 | a_id = Column(Integer, ForeignKey("test_a.id")) 37 | 38 | a = relationship(A, backref="bs") 39 | 40 | def __repr__(self): 41 | return u"<B id={0} name={1} a_id={2}>".format( 42 | self.id, self.name, self.a_id) 43 | 44 | 45 | 
Base.metadata.create_all(engine) 46 | dbsync.set_engine(engine) 47 | dbsync.create_all() 48 | dbsync.generate_content_types() 49 | _session = Session() 50 | _session.add( 51 | models.Node(registered=datetime.datetime.now(), secret=generate_secret(128))) 52 | _session.commit() 53 | _session.close() 54 | -------------------------------------------------------------------------------- /tests/pull_message_tests.py: -------------------------------------------------------------------------------- 1 | from nose.tools import * 2 | import datetime 3 | import logging 4 | import json 5 | 6 | from dbsync.lang import * 7 | from dbsync import models, core 8 | from dbsync.messages.pull import PullMessage 9 | 10 | from tests.models import A, B, Session 11 | 12 | 13 | def addstuff(): 14 | a1 = A(name="first a") 15 | a2 = A(name="second a") 16 | b1 = B(name="first b", a=a1) 17 | b2 = B(name="second b", a=a1) 18 | b3 = B(name="third b", a=a2) 19 | session = Session() 20 | session.add_all([a1, a2, b1, b2, b3]) 21 | version = models.Version() 22 | version.created = datetime.datetime.now() 23 | session.add(version) 24 | session.flush() 25 | version_id = version.version_id 26 | session.commit() 27 | session = Session() 28 | for op in session.query(models.Operation): 29 | op.version_id = version_id 30 | session.commit() 31 | 32 | def setup(): pass 33 | 34 | @core.with_listening(False) 35 | def teardown(): 36 | session = Session() 37 | map(session.delete, session.query(A)) 38 | map(session.delete, session.query(B)) 39 | map(session.delete, session.query(models.Operation)) 40 | map(session.delete, session.query(models.Version)) 41 | session.commit() 42 | 43 | 44 | @with_setup(setup, teardown) 45 | def test_create_message(): 46 | addstuff() 47 | session = Session() 48 | message = PullMessage() 49 | version = session.query(models.Version).first() 50 | message.add_version(version) 51 | assert message.to_json() == PullMessage(message.to_json()).to_json() 52 | 53 | 54 | @with_setup(setup, 
teardown) 55 | def test_encode_message(): 56 | addstuff() 57 | session = Session() 58 | message = PullMessage() 59 | version = session.query(models.Version).first() 60 | message.add_version(version) 61 | assert message.to_json() == json.loads(json.dumps(message.to_json())) 62 | 63 | 64 | @with_setup(setup, teardown) 65 | def test_message_query(): 66 | addstuff() 67 | session = Session() 68 | message = PullMessage() 69 | version = session.query(models.Version).first() 70 | message.add_version(version) 71 | # test equal representation, because the test models are well printed 72 | for b in session.query(B): 73 | assert repr(b) == repr(message.query(B).filter( 74 | attr('id') == b.id).all()[0]) 75 | for op in session.query(models.Operation): 76 | assert repr(op) == repr(message.query(models.Operation).filter( 77 | attr('order') == op.order).all()[0]) 78 | try: 79 | message.query(1) 80 | raise Exception("Message query did not fail") 81 | except TypeError: 82 | pass 83 | 84 | 85 | @with_setup(setup, teardown) 86 | def test_message_does_not_contaminate_database(): 87 | addstuff() 88 | session = Session() 89 | message = PullMessage() 90 | version = session.query(models.Version).first() 91 | message.add_version(version) 92 | # test that the are no unversioned operations 93 | assert not session.query(models.Operation).\ 94 | filter(models.Operation.version_id == None).all() 95 | -------------------------------------------------------------------------------- /tests/push_message_tests.py: -------------------------------------------------------------------------------- 1 | from nose.tools import * 2 | import datetime 3 | import logging 4 | import json 5 | 6 | from dbsync import models, core 7 | from dbsync.messages.push import PushMessage 8 | 9 | from tests.models import A, B, Session 10 | 11 | 12 | def addstuff(): 13 | a1 = A(name="first a") 14 | a2 = A(name="second a") 15 | b1 = B(name="first b", a=a1) 16 | b2 = B(name="second b", a=a1) 17 | b3 = B(name="third b", a=a2) 18 
| session = Session() 19 | session.add_all([a1, a2, b1, b2, b3]) 20 | session.commit() 21 | 22 | def changestuff(): 23 | session = Session() 24 | a1, a2 = session.query(A) 25 | b1, b2, b3 = session.query(B) 26 | a1.name = "first a modified" 27 | b2.a = a2 28 | session.delete(b3) 29 | session.commit() 30 | 31 | def setup(): 32 | pass 33 | 34 | @core.with_listening(False) 35 | def teardown(): 36 | session = Session() 37 | map(session.delete, session.query(A)) 38 | map(session.delete, session.query(B)) 39 | map(session.delete, session.query(models.Operation)) 40 | session.commit() 41 | 42 | 43 | @with_setup(setup, teardown) 44 | def test_create_message(): 45 | addstuff() 46 | changestuff() 47 | session = Session() 48 | message = PushMessage() 49 | message.add_unversioned_operations() 50 | message.set_node(session.query(models.Node).first()) 51 | assert message.to_json() == PushMessage(message.to_json()).to_json() 52 | 53 | 54 | @with_setup(setup, teardown) 55 | def test_encode_message(): 56 | addstuff() 57 | changestuff() 58 | session = Session() 59 | message = PushMessage() 60 | message.add_unversioned_operations() 61 | message.set_node(session.query(models.Node).first()) 62 | assert message.to_json() == json.loads(json.dumps(message.to_json())) 63 | 64 | 65 | @with_setup(setup, teardown) 66 | def test_sign_message(): 67 | addstuff() 68 | changestuff() 69 | session = Session() 70 | message = PushMessage() 71 | message.set_node(session.query(models.Node).first()) 72 | message.add_unversioned_operations() 73 | assert message.islegit(session) 74 | message.key += "broken" 75 | assert not message.islegit(session) 76 | -------------------------------------------------------------------------------- /tests/track_tests.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from nose.tools import * 3 | 4 | from dbsync.lang import * 5 | from dbsync import models, core, client 6 | from dbsync.client.compression import ( 7 | 
compress, 8 | compressed_operations, 9 | unsynched_objects) 10 | 11 | from tests.models import A, B, Base, Session 12 | 13 | 14 | def addstuff(): 15 | a1 = A(name="first a") 16 | a2 = A(name="second a") 17 | b1 = B(name="first b", a=a1) 18 | b2 = B(name="second b", a=a1) 19 | b3 = B(name="third b", a=a2) 20 | session = Session() 21 | session.add_all([a1, a2, b1, b2, b3]) 22 | session.commit() 23 | 24 | def changestuff(): 25 | session = Session() 26 | a1, a2 = session.query(A) 27 | b1, b2, b3 = session.query(B) 28 | a1.name = "first a modified" 29 | b2.a = a2 30 | session.delete(b3) 31 | session.commit() 32 | 33 | def setup(): 34 | pass 35 | 36 | @core.with_listening(False) 37 | def teardown(): 38 | session = Session() 39 | map(session.delete, session.query(A)) 40 | map(session.delete, session.query(B)) 41 | map(session.delete, session.query(models.Operation)) 42 | session.commit() 43 | 44 | 45 | @with_setup(setup, teardown) 46 | def test_tracking(): 47 | addstuff() 48 | changestuff() 49 | session = Session() 50 | assert session.query(models.Operation).\ 51 | filter(models.Operation.command == 'i').\ 52 | count() == 5, "insert operations don't match" 53 | assert session.query(models.Operation).\ 54 | filter(models.Operation.command == 'u').\ 55 | count() == 2, "update operations don't match" 56 | assert session.query(models.Operation).\ 57 | filter(models.Operation.command == 'd').\ 58 | count() == 1, "delete operations don't match" 59 | 60 | 61 | @with_setup(setup, teardown) 62 | def test_compression(): 63 | addstuff() 64 | changestuff() 65 | compress() # remove unnecesary operations 66 | session = Session() 67 | assert session.query(models.Operation).\ 68 | filter(models.Operation.command == 'i').\ 69 | count() == 4, "insert operations don't match" 70 | assert session.query(models.Operation).\ 71 | filter(models.Operation.command == 'u').\ 72 | count() == 0, "update operations don't match" 73 | assert session.query(models.Operation).\ 74 | 
filter(models.Operation.command == 'd').\ 75 | count() == 0, "delete operations don't match" 76 | 77 | 78 | @with_setup(setup, teardown) 79 | def test_unsynched_objects_detection(): 80 | addstuff() 81 | changestuff() 82 | assert bool(unsynched_objects()), "unsynched objects weren't detected" 83 | 84 | 85 | @with_setup(setup, teardown) 86 | def test_compression_consistency(): 87 | addstuff() 88 | changestuff() 89 | session = Session() 90 | ops = session.query(models.Operation).all() 91 | compress() 92 | news = session.query(models.Operation).order_by(models.Operation.order).all() 93 | assert news == compressed_operations(ops) 94 | 95 | 96 | @with_setup(setup, teardown) 97 | def test_compression_correctness(): 98 | addstuff() 99 | changestuff() 100 | session = Session() 101 | ops = compressed_operations(session.query(models.Operation).all()) 102 | groups = group_by(lambda op: (op.content_type_id, op.row_id), ops) 103 | for g in groups.itervalues(): 104 | logging.info(g) 105 | assert len(g) == 1 106 | # assert correctness when compressing operations from a pull 107 | # message 108 | pull_ops = [ 109 | models.Operation(command='i', content_type_id=1, row_id=1, order=1), 110 | models.Operation(command='d', content_type_id=1, row_id=1, order=2), 111 | models.Operation(command='i', content_type_id=1, row_id=1, order=3), 112 | models.Operation(command='u', content_type_id=1, row_id=1, order=4), 113 | # result of above should be a single 'i' 114 | models.Operation(command='u', content_type_id=2, row_id=1, order=5), 115 | models.Operation(command='d', content_type_id=2, row_id=1, order=6), 116 | models.Operation(command='i', content_type_id=2, row_id=1, order=7), 117 | models.Operation(command='d', content_type_id=2, row_id=1, order=8), 118 | # result of above should be a single 'd' 119 | models.Operation(command='d', content_type_id=3, row_id=1, order=9), 120 | models.Operation(command='i', content_type_id=3, row_id=1, order=10), 121 | # result of above should be an 'u' 122 
| models.Operation(command='i', content_type_id=4, row_id=1, order=11), 123 | models.Operation(command='u', content_type_id=4, row_id=1, order=12), 124 | models.Operation(command='d', content_type_id=4, row_id=1, order=13), 125 | # result of above should be no operations 126 | models.Operation(command='d', content_type_id=5, row_id=1, order=14), 127 | models.Operation(command='i', content_type_id=5, row_id=1, order=15), 128 | models.Operation(command='d', content_type_id=5, row_id=1, order=16), 129 | # result of above should be a single 'd' 130 | models.Operation(command='u', content_type_id=6, row_id=1, order=17), 131 | models.Operation(command='d', content_type_id=6, row_id=1, order=18), 132 | models.Operation(command='i', content_type_id=6, row_id=1, order=19), 133 | # result of above should be an 'u' 134 | models.Operation(command='d', content_type_id=7, row_id=1, order=20), 135 | models.Operation(command='i', content_type_id=7, row_id=1, order=21), 136 | models.Operation(command='u', content_type_id=7, row_id=1, order=22) 137 | # result of above should be an 'u' 138 | ] 139 | compressed = compressed_operations(pull_ops) 140 | logging.info("len(compressed) == {0}".format(len(compressed))) 141 | logging.info("\n".join(repr(op) for op in compressed)) 142 | assert len(compressed) == 6 143 | assert compressed[0].command == 'i' 144 | assert compressed[1].command == 'd' 145 | assert compressed[2].command == 'u' 146 | assert compressed[3].command == 'd' 147 | assert compressed[4].command == 'u' 148 | assert compressed[5].command == 'u' 149 | --------------------------------------------------------------------------------