├── .gitignore ├── README.md ├── dgpy.py ├── setup.py └── tests.py /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | *.egg-info/ 3 | dist/ 4 | venv/ 5 | *.pyc 6 | .coverage 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dgpy 2 | 3 | `dgpy` is a generic dependency graph implemented in Python, coded live, 10 4 | minutes at a time, in a test-driven development fashion. 5 | 6 | You can watch the playlist on 7 | [YouTube](https://youtu.be/pXUL_aDhN-Y?list=PLYcUacEjhPL-nSolgfdIJ_GqBakUp790z) 8 | and read more about the motivations behind this project at 9 | [cesarsaez.me](http://www.cesarsaez.me/2015/12/dgpy.html). 10 | 11 | 12 | ## Features 13 | 14 | In terms of features, `dgpy` implements a **null/void node supporting input and 15 | output ports/plugs** as a base for user specialization, a **push and/or pull 16 | evaluation model** (you can mix them, as the model is set per node) and 17 | **serialization** allowing you to save and load graphs (connections are re-created by name 18 | during import, so the *serialized data is hack-able*). 19 | 20 | 21 | ## Usage 22 | 23 | There's not much to document, as `dgpy` doesn't include any built-in node or data 24 | type in order to keep it away from any domain-specific task; however, there's a 25 | very simple `AddNode` implemented in the test suite that could serve as an 26 | example (and was used all over the place to drive the development). 27 | 28 | Here's a selection of snippets from the test suite showcasing how to get things 29 | started.
30 | 31 | ```python 32 | import dgpy 33 | 34 | # Let's implement a simple node 35 | class AddNode(dgpy.VoidNode): 36 | def initPorts(self): 37 | super(AddNode, self).initPorts() 38 | self.addInputPort("value1") 39 | self.addInputPort("value2") 40 | self.addOutputPort("result") 41 | 42 | def evaluate(self): 43 | super(AddNode, self).evaluate() 44 | result = 0 45 | for p in self._inputPorts.values(): 46 | if p.value is not None: 47 | result += p.value 48 | self.getOutputPort("result").value = result 49 | 50 | dgpy.registerNode(AddNode) 51 | 52 | 53 | # Let's create a network 54 | graph = dgpy.Graph() 55 | graph.model = dgpy.PULL 56 | 57 | node1 = graph.addNode("node1", AddNode, value1=2, value2=3) 58 | node2 = graph.addNode("node2", AddNode, value1=5) 59 | node2.getInputPort("value2").connect(node1.getOutputPort("result")) 60 | 61 | print(node2.getOutputPort("result").value)  # 5 + 5 62 | 63 | node1.getInputPort("value1").value = 10 64 | 65 | print(node2.getOutputPort("result").value)  # 13 + 5 66 | 67 | 68 | # Let's play with the serialization 69 | data = graph.serialize() 70 | clone = dgpy.Graph.fromData(data) 71 | ``` 72 | 73 | > Check `tests.py` for more usage snippets. 74 | 75 | 76 | ## Testing 77 | 78 | This project uses `unittest` (from the Python standard library) as its testing framework. I'm 79 | pretty sure every Python developer out there has good reasons to prefer one 80 | of the available alternatives, but I wanted to keep it simple/accessible to 81 | everyone without forcing dependencies. 82 | 83 | Coverage at the time this README was written was 100%; you can check it yourself by 84 | running the test suite.
85 | 86 | ``` 87 | pip install coverage 88 | 89 | coverage run --source=dgpy -m unittest discover 90 | coverage report 91 | ``` 92 | -------------------------------------------------------------------------------- /dgpy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2016 Cesar Saez 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in 11 | # all copies or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 
20 | 21 | import pprint 22 | import logging 23 | from collections import OrderedDict, Counter 24 | 25 | logger = logging.getLogger(__name__) 26 | NODES = dict() 27 | PUSH = 0 28 | PULL = 1 29 | 30 | 31 | def registerNode(nodeType): 32 | nodeName = nodeType.__name__ 33 | NODES[nodeName] = nodeType 34 | 35 | 36 | def getRefCounterFromData(data): 37 | count = Counter() 38 | 39 | for nodeName, nodeData in data["nodes"].items(): 40 | 41 | for portName, portData in nodeData["inputPorts"].items(): 42 | for sourceName in portData["sources"]: 43 | sourceOwner = sourceName.split(".")[0] 44 | count[sourceOwner] -= 1 45 | 46 | for portName, portData in nodeData["outputPorts"].items(): 47 | for sourceName in portData["sources"]: 48 | sourceOwner = sourceName.split(".")[0] 49 | count[sourceOwner] += 1 50 | 51 | # normalize 52 | if len(count.values()): 53 | toAdd = abs(min(count.values())) 54 | for k in count.keys(): 55 | count[k] += toAdd 56 | 57 | return count 58 | 59 | 60 | class Graph(object): 61 | def __init__(self): 62 | super(Graph, self).__init__() 63 | self._nodes = dict() 64 | self.model = None 65 | 66 | @property 67 | def nodes(self): 68 | return tuple(self._nodes.values()) 69 | 70 | def addNode(self, name, nodeType, **kwargs): 71 | node = nodeType(name) 72 | node.model = self.model 73 | self._nodes[name] = node 74 | for k, v in kwargs.items(): 75 | p = node.getInputPort(k) 76 | if p: 77 | p.value = v 78 | return node 79 | 80 | def removeNode(self, node): 81 | del self._nodes[node.name] 82 | 83 | def getNode(self, name): 84 | return self._nodes.get(name) 85 | 86 | def get(self, fullname): 87 | splitName = fullname.split(".") 88 | node = self.getNode(splitName[0]) 89 | if node and len(splitName) == 2: 90 | port = node.getInputPort(splitName[1]) 91 | if port is None: 92 | port = node.getOutputPort(splitName[1]) 93 | if port: 94 | return port 95 | return node 96 | 97 | def serialize(self): 98 | data = { 99 | "dataType": "dgpy", 100 | "version": "1.0.0", 101 | "model": 
self.model, 102 | "nodes": dict(), 103 | } 104 | for node in self.nodes: 105 | data["nodes"][node.name] = node.serialize() 106 | 107 | logger.debug("Serializing graph...") 108 | logger.debug(pprint.pformat(data)) 109 | return data 110 | 111 | @classmethod 112 | def fromData(cls, data): 113 | # validation 114 | if not isinstance(data, dict) or data.get("dataType") != "dgpy": 115 | return 116 | 117 | graph = cls() 118 | graph.model = data.get("model") 119 | 120 | count = getRefCounterFromData(data) 121 | orderedNodes = sorted(data["nodes"].keys(), key=lambda x: count[x]) 122 | 123 | for nodeName in orderedNodes: 124 | nodeData = data["nodes"][nodeName] 125 | nodeClass = NODES.get(nodeData["className"]) 126 | node = graph.addNode(nodeName, nodeClass) 127 | node.model = nodeData["model"] 128 | 129 | for portName, portData in nodeData["inputPorts"].items(): 130 | port = node.getInputPort(portName) 131 | 132 | for sourceName in portData["sources"]: 133 | port.connect(graph.get(sourceName)) 134 | 135 | if not port.isConnected: 136 | port.value = portData["value"] 137 | 138 | return graph 139 | 140 | 141 | class Port(object): 142 | value = property(fget=lambda x: x.getValue(), 143 | fset=lambda x, value: x.setValue(value)) 144 | 145 | def __init__(self, name): 146 | super(Port, self).__init__() 147 | self.name = name 148 | self.owner = None 149 | self._value = None 150 | self.sources = set() 151 | 152 | def serialize(self): 153 | data = { 154 | "value": self.value, 155 | "sources": [x.fullname for x in self.sources], 156 | } 157 | return data 158 | 159 | def getValue(self): 160 | return self._value 161 | 162 | def setValue(self, value): 163 | self._value = value 164 | 165 | @property 166 | def isConnected(self): 167 | return len(self.sources) > 0 168 | 169 | @property 170 | def fullname(self): 171 | return ".".join((self.owner.fullname, self.name)) 172 | 173 | 174 | class InputPort(Port): 175 | def setValue(self, value): 176 | super(InputPort, self).setValue(value) 177 | 
if self.owner.model == PUSH: 178 | self.owner.evaluate() 179 | if self.owner.model == PULL: 180 | self.owner.isDirty = True 181 | 182 | def connect(self, outputPort): 183 | self.sources.add(outputPort) 184 | outputPort.sources.add(self) 185 | self.value = outputPort.value 186 | 187 | def disconnect(self): 188 | port = self.sources.pop() 189 | port.sources.remove(self) 190 | 191 | 192 | class OutputPort(Port): 193 | def setValue(self, value): 194 | if self.isConnected: 195 | for port in self.sources: 196 | if self.owner.model == PUSH: 197 | port.value = value 198 | super(OutputPort, self).setValue(value) 199 | 200 | def getValue(self): 201 | if self.owner.model == PULL: 202 | if self.owner.isDirty: 203 | for port in self.owner._inputPorts.values(): 204 | if port.isConnected: 205 | for src in port.sources: 206 | port.value = src.value 207 | self.owner.evaluate() 208 | self.owner.isDirty = False 209 | return super(OutputPort, self).getValue() 210 | 211 | 212 | class VoidNode(object): 213 | """ 214 | Empty and most basic node type. 215 | Every custom node is asumed to subclass, or at least replicate, this 216 | interface. 217 | """ 218 | isDirty = property(fget=lambda x: x._isDirty, 219 | fset=lambda x, value: x.setDirty(value)) 220 | 221 | def __init__(self, name): 222 | super(VoidNode, self).__init__() 223 | self.name = name 224 | self.model = None 225 | self.evalCount = 0 226 | self._isDirty = True 227 | self._inputPorts = OrderedDict() 228 | self._outputPorts = OrderedDict() 229 | self.initPorts() 230 | 231 | @property 232 | def inputPorts(self): 233 | return tuple(self._inputPorts.values()) 234 | 235 | @property 236 | def outputPorts(self): 237 | return tuple(self._outputPorts.values()) 238 | 239 | @property 240 | def fullname(self): 241 | return self.name 242 | 243 | def serialize(self): 244 | """Return a dict containing the serialized data (json/yaml/markup friendly). 
245 | """ 246 | data = { 247 | "model": self.model, 248 | "inputPorts": dict(), 249 | "outputPorts": dict(), 250 | "className": type(self).__name__, 251 | } 252 | for port in self.inputPorts: 253 | data["inputPorts"][port.name] = port.serialize() 254 | for port in self.outputPorts: 255 | data["outputPorts"][port.name] = port.serialize() 256 | return data 257 | 258 | def setDirty(self, value): 259 | """Tag the node to be reevaluated the next time the graph pull its data. 260 | """ 261 | self._isDirty = value 262 | if not value: 263 | return 264 | for port in self._outputPorts.values(): 265 | if port.isConnected: 266 | for src in port.sources: 267 | src.owner.isDirty = True 268 | 269 | def initPorts(self): 270 | """Callback where ports should be added/registered.""" 271 | pass 272 | 273 | def getInputPort(self, name): 274 | return self._inputPorts.get(name) 275 | 276 | def getOutputPort(self, name): 277 | return self._outputPorts.get(name) 278 | 279 | def addInputPort(self, name): 280 | """Add/register an input port to the node.""" 281 | port = InputPort(name) 282 | port.owner = self 283 | self._inputPorts[name] = port 284 | 285 | def addOutputPort(self, name): 286 | """Add/register an output port to the node.""" 287 | port = OutputPort(name) 288 | port.owner = self 289 | self._outputPorts[name] = port 290 | 291 | def evaluate(self): 292 | """Node computation should be implemented here. 293 | Please extend this method in order to keep track of evaluation count. 
294 | """ 295 | logger.debug("Evaluating {}".format(self.name)) 296 | self.evalCount += 1 297 | 298 | registerNode(VoidNode) 299 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | 4 | setup( 5 | name="dgpy", 6 | version="0.1.0", 7 | py_modules=["dgpy"], 8 | url="http://github.com/csaez/dgpy", 9 | author="Cesar Saez", 10 | author_email="hi@cesarsaez.me", 11 | ) 12 | -------------------------------------------------------------------------------- /tests.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import unittest 3 | import logging 4 | import dgpy 5 | 6 | logger = logging.getLogger("dgpy") 7 | logger.addHandler(logging.StreamHandler()) 8 | 9 | 10 | def debug(func): 11 | @functools.wraps(func) 12 | def decorated(*args, **kwds): 13 | level = logger.level 14 | logger.setLevel(logging.DEBUG) 15 | 16 | rval = None 17 | try: 18 | rval = func(*args, **kwds) 19 | finally: 20 | logger.setLevel(level) 21 | 22 | return rval 23 | return decorated 24 | 25 | 26 | class AddNode(dgpy.VoidNode): 27 | def initPorts(self): 28 | super(AddNode, self).initPorts() 29 | self.addInputPort("value1") 30 | self.addInputPort("value2") 31 | self.addOutputPort("result") 32 | 33 | def evaluate(self): 34 | super(AddNode, self).evaluate() 35 | result = 0 36 | for p in self._inputPorts.values(): 37 | msg = "{0}: {1}".format(p.name, p.value) 38 | if p.isConnected: 39 | msg += " (connected)" 40 | logger.debug(msg) 41 | if p.value is not None: 42 | result += p.value 43 | self.getOutputPort("result").value = result 44 | logger.debug("---") 45 | 46 | dgpy.registerNode(AddNode) 47 | 48 | 49 | class UsageCase(unittest.TestCase): 50 | def testAddRemoveNodes(self): 51 | graph = dgpy.Graph() 52 | node1 = graph.addNode("node1", dgpy.VoidNode) 53 | self.assertEqual(len(graph.nodes), 1) 
54 | graph.removeNode(node1) 55 | self.assertEqual(len(graph.nodes), 0) 56 | 57 | def testDisconnect(self): 58 | graph = dgpy.Graph() 59 | 60 | node1 = graph.addNode("node1", AddNode) 61 | node1.getInputPort("value1").value = 2 62 | node1.getInputPort("value2").value = 3 63 | 64 | node2 = graph.addNode("node2", AddNode, value1=5) 65 | node2.getInputPort("value2").connect(node1.getOutputPort("result")) 66 | 67 | self.assertTrue(node1.getOutputPort("result").isConnected) 68 | self.assertTrue(node2.getInputPort("value2").isConnected) 69 | 70 | node2.getInputPort("value2").disconnect() 71 | 72 | self.assertFalse(node1.getOutputPort("result").isConnected) 73 | self.assertFalse(node2.getInputPort("value2").isConnected) 74 | 75 | def testGetter(self): 76 | graph = dgpy.Graph() 77 | graph.addNode("node1", AddNode) 78 | 79 | self.assertIsNotNone(graph.get("node1.value1")) 80 | self.assertIsNotNone(graph.get("node1.value2")) 81 | self.assertIsNotNone(graph.get("node1.result")) 82 | self.assertIsNotNone(graph.get("node1")) 83 | self.assertIsNone(graph.get("foo")) 84 | 85 | def testDataValidation(self): 86 | self.assertIsNone(dgpy.Graph.fromData(1)) 87 | self.assertIsNone(dgpy.Graph.fromData("foo")) 88 | self.assertIsNone(dgpy.Graph.fromData({"foo": "bar"})) 89 | self.assertIsNone(dgpy.Graph.fromData({"dataType": "bar"})) 90 | 91 | 92 | class PushModelCase(unittest.TestCase): 93 | def __init__(self, *args): 94 | super(PushModelCase, self).__init__(*args) 95 | self.model = dgpy.PUSH 96 | 97 | def testSingleNodeEvaluation(self): 98 | graph = dgpy.Graph() 99 | graph.model = self.model 100 | 101 | node1 = graph.addNode("node1", AddNode) 102 | node1.getInputPort("value1").value = 2 103 | node1.getInputPort("value2").value = 3 104 | self.assertEqual(node1.getOutputPort("result").value, 2+3) 105 | 106 | return graph 107 | 108 | def testModelIsolation(self): 109 | graph = dgpy.Graph() 110 | graph.model = None 111 | 112 | node1 = graph.addNode("node1", AddNode) 113 | 
node1.getInputPort("value1").value = 2 114 | node1.getInputPort("value2").value = 3 115 | self.assertIsNone(node1.getOutputPort("result").value) 116 | 117 | return graph 118 | 119 | def testNodeConnections(self): 120 | graph = dgpy.Graph() 121 | graph.model = self.model 122 | 123 | node1 = graph.addNode("node1", AddNode) 124 | node1.getInputPort("value1").value = 2 125 | node1.getInputPort("value2").value = 3 126 | 127 | node2 = graph.addNode("node2", AddNode, value1=5) 128 | node2.getInputPort("value2").connect(node1.getOutputPort("result")) 129 | self.assertEqual(node2.getOutputPort("result").value, 5 + 5) 130 | 131 | return graph 132 | 133 | def testPersistentConnections(self): 134 | graph = dgpy.Graph() 135 | graph.model = self.model 136 | 137 | node1 = graph.addNode("node1", AddNode) 138 | node1.getInputPort("value1").value = 2 139 | node1.getInputPort("value2").value = 3 140 | 141 | node2 = graph.addNode("node2", AddNode, value1=5) 142 | node2.getInputPort("value2").connect(node1.getOutputPort("result")) 143 | self.assertEqual(node2.getOutputPort("result").value, 5 + 5) 144 | 145 | node1.getInputPort("value1").value = 10 146 | self.assertEqual(node2.getOutputPort("result").value, 5 + 13) 147 | 148 | return graph 149 | 150 | def testBranching(self): 151 | graph = self.testNodeConnections() 152 | 153 | node1 = graph.getNode("node1") 154 | 155 | node3 = graph.addNode("node3", AddNode, value1=8) 156 | node3.getInputPort("value2").connect(node1.getOutputPort("result")) 157 | self.assertEqual(node3.getOutputPort("result").value, 8 + 5) 158 | 159 | return graph 160 | 161 | def testBranchingPersistence(self): 162 | graph = self.testBranching() 163 | 164 | node1 = graph.getNode("node1") 165 | node1.getInputPort("value1").value = 1 166 | self.assertEqual(node1.getOutputPort("result").value, 1 + 3) 167 | 168 | node3 = graph.getNode("node3") 169 | self.assertEqual(node3.getOutputPort("result").value, 8 + 4) 170 | 171 | node2 = graph.getNode("node2") 172 | 
self.assertEqual(node2.getOutputPort("result").value, 5 + 4) 173 | 174 | 175 | class PullModelCase(PushModelCase): 176 | def __init__(self, *args): 177 | super(PullModelCase, self).__init__(*args) 178 | self.model = dgpy.PULL 179 | 180 | def testEvaluationCount(self): 181 | graph = self.testSingleNodeEvaluation() 182 | node1 = graph.getNode("node1") 183 | node1.getOutputPort("result").value 184 | self.assertEqual(node1.evalCount, 1) 185 | 186 | 187 | class SerializationCase(unittest.TestCase): 188 | def testEmptyGraph(self): 189 | graph1 = dgpy.Graph() 190 | graph1.model = dgpy.PULL 191 | data = graph1.serialize() 192 | 193 | graph2 = dgpy.Graph.fromData(data) 194 | self.assertEqual(graph1.model, graph2.model) 195 | self.assertEqual(len(graph1.nodes), len(graph2.nodes)) 196 | 197 | def testOrphanNodes(self): 198 | graph1 = dgpy.Graph() 199 | graph1.model = dgpy.PULL 200 | graph1.addNode("testingVoidNode", dgpy.VoidNode) 201 | graph1.addNode("testingAddNode", AddNode, value1=2, value2=3) 202 | data = graph1.serialize() 203 | 204 | graph2 = dgpy.Graph.fromData(data) 205 | self.assertDictEqual(data, graph2.serialize()) 206 | 207 | def testConnectedNodes(self): 208 | graph = dgpy.Graph() 209 | graph.model = dgpy.PULL 210 | 211 | node1 = graph.addNode("node1", AddNode) 212 | node1.getInputPort("value1").value = 2 213 | node1.getInputPort("value2").value = 3 214 | 215 | node2 = graph.addNode("node2", AddNode, value1=5) 216 | node2.getInputPort("value2").connect(node1.getOutputPort("result")) 217 | 218 | node3 = graph.addNode("node3", AddNode, value1=8) 219 | node3.getInputPort("value2").connect(node2.getOutputPort("result")) 220 | 221 | graph.addNode("testingVoidNode", dgpy.VoidNode) 222 | 223 | data = graph.serialize() 224 | 225 | graph2 = dgpy.Graph.fromData(data) 226 | self.assertTrue(graph2.get("node2.value2").isConnected) 227 | self.assertTrue(graph2.get("node3.value2").isConnected) 228 | 229 | if __name__ == '__main__': 230 | unittest.main(verbosity=2) 231 | 
--------------------------------------------------------------------------------