├── .gitignore ├── LICENSE.md ├── README.md ├── package-lock.json ├── package.json ├── src ├── ast.ts ├── cmd │ ├── app.ts │ ├── common.ts │ ├── entrypoint.ts │ ├── logger.ts │ ├── main.ts │ ├── repl.ts │ └── worker.ts ├── generate │ ├── client.ts │ ├── helpers.ts │ ├── index.ts │ └── model.ts ├── index.ts ├── internal │ ├── client.ts │ ├── config.ts │ ├── consumer.ts │ ├── dataTypes.ts │ ├── db │ │ ├── PostgresDb.ts │ │ ├── SqliteDb.ts │ │ ├── TestDb.ts │ │ └── index.ts │ ├── id.ts │ ├── index.ts │ ├── meta.ts │ ├── object.ts │ ├── project.ts │ ├── queue │ │ ├── PostgresQueue.ts │ │ ├── SqliteQueue.ts │ │ ├── common.ts │ │ ├── index.ts │ │ └── test.ts │ ├── transition.ts │ └── types.ts └── parser.ts ├── test ├── jest.config.js ├── package-lock.json ├── package.json ├── restate.config.json ├── restate │ ├── Email.rst │ ├── TypesTest.rst │ └── User.rst └── src │ ├── index.test.ts │ └── restate.ts └── tsconfig.json /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | 3 | node_modules 4 | 5 | src/generated 6 | 7 | test/node_modules 8 | test/restate.sqlite -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Fabian Lindfors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Restate 2 | 3 | Restate is an experimental Typescript framework for building backends using state machines. With Restate, you define all database models as state machines which can only be modified through state transitions. The logic for the transitions is defined in code, and it's also possible to run code asynchronously in response to state transitions, for example to trigger new transitions, send emails or make API requests. This enables more complex business logic to be expressed by connecting together simpler state machines. 4 | 5 | The point of Restate is to help build systems which are: 6 | 7 | - **Debuggable:** All state transitions are tracked, making it easy to trace how a database object ended up in its current state and what triggered its transitions (an admin interface is in the works). 8 | - **Understandable:** All business logic is encoded in state transitions and consumers, making it easy to understand the full behavior of your system. Writing decoupled code also becomes easier with consumers. 9 | - **Reliable:** Consumers are automatically retried on failure and change data capture is used to ensure no transitions are missed. 10 | 11 | Does that sound interesting? Then keep reading for a walkthrough of a sample project!
12 | 13 | ## Getting started 14 | 15 | ### Installation 16 | 17 | To get started with Restate, we are going to create a standard Node project and install Restate: 18 | 19 | ```text 20 | $ mkdir my-first-restate-project && cd my-first-restate-project 21 | $ npm init 22 | $ npm install --save restate-ts 23 | ``` 24 | 25 | For this example, we are going to be using Express to build our API, so we need to install that as well: 26 | 27 | ```text 28 | $ npm install --save express 29 | ``` 30 | 31 | Restate has a built-in development tool with auto-reloading; start it and keep it running in the background as you code: 32 | 33 | ```text 34 | $ npx restate 35 | ``` 36 | 37 | You'll see a warning message saying that no project definition was found, but don't worry about that, we'll create one soon! 38 | 39 | ### Defining models and transitions 40 | 41 | Database models in Restate are defined in a custom file type, `.rst`, and stored in the `restate/` folder of your project. Every database model is a state machine and hence we need to define the possible states and transitions between those states. 42 | 43 | For this project, we are going to model a very simple application that tracks orders. Orders start out as created and are then paid by the customer. Once the order has been paid, we want to book a delivery with our carrier.
To model this in Restate, let's create a new file called `restate/Order.rst` with the following contents: 44 | 45 | ``` 46 | model Order { 47 | // All models have an autogenerated `id` field with a prefix to make them easily identifiable 48 | // In this case, they will look something like: "order_01gqjyp438r30j3g28jt78cx23" 49 | prefix "order" 50 | 51 | // The common fields defined here will be available across all states 52 | field amount: Int 53 | 54 | state Created {} 55 | 56 | state Paid { 57 | field paymentReference: String 58 | } 59 | 60 | // States can inherit other states' fields, so in this case `DeliveryBooked` will have `amount` and `paymentReference` fields as well 61 | state DeliveryBooked: Paid { 62 | field trackingNumber: String 63 | } 64 | 65 | // `Create` doesn't have any starting states and is hence an initializing transition. 66 | // It will be used to create new orders. 67 | transition Create: Created { 68 | field amount: Int 69 | } 70 | 71 | // `Pay` is triggered when payment is received for the order 72 | transition Pay: Created -> Paid { 73 | field paymentReference: String 74 | } 75 | 76 | // `BookDelivery` is triggered when an order has been sent and we are ready to book delivery 77 | transition BookDelivery: Paid -> DeliveryBooked {} 78 | } 79 | ``` 80 | 81 | ### Generating the Restate client 82 | 83 | Once we have defined our models, the dev session you have running will automatically generate types for your models as well as a client to interact with them. All of this can be imported directly from the `restate-ts` module. 84 | 85 | The starting point of any Restate project is the project definition, which lives in `src/restate.ts`. The definition we export from that file defines how our models' transitions are handled.
Let's start with some placeholder values, create a `src/restate.ts` file with the following code: 86 | 87 | ```typescript 88 | import { RestateProject, RestateClient, Order } from "restate-ts"; 89 | 90 | const project: RestateProject = { 91 | async main(restate: RestateClient) { 92 | // main is the entrypoint for your project. Here you could for example start a web server and use 93 | // `restate` to create and transition objects. 94 | }, 95 | 96 | transitions: { 97 | // We need to provide implementations for all transitions 98 | order: { 99 | async create(restate: RestateClient, transition: Order.Create) { 100 | throw new Error("Create transition not implemented"); 101 | }, 102 | 103 | async pay( 104 | restate: RestateClient, 105 | order: Order.Created, 106 | transition: Order.Pay 107 | ) { 108 | throw new Error("Pay transition not implemented"); 109 | }, 110 | 111 | async bookDelivery( 112 | restate: RestateClient, 113 | order: Order.Paid, 114 | transition: Order.BookDelivery 115 | ) { 116 | throw new Error("BookDelivery transition not implemented"); 117 | }, 118 | }, 119 | }, 120 | }; 121 | 122 | // The definition should be the default export 123 | export default project; 124 | ``` 125 | 126 | ### Creating orders 127 | 128 | Before we can create orders, we need to actually implement the `Create` transition in `src/restate.ts`: 129 | 130 | ```typescript 131 | const project: RestateProject = { 132 | // ... 
133 | transitions: { 134 | order: { 135 | async create(restate: RestateClient, transition: Order.Create) { 136 | // We should return the shape of the object after the transition has been applied 137 | // As this is an initializing transition, it will result in a new object being created 138 | return { 139 | state: Order.State.Created, 140 | // amount is passed through the transition and saved to the object 141 | amount: transition.data.amount, 142 | }; 143 | }, 144 | }, 145 | }, 146 | }; 147 | ``` 148 | 149 | To interact with our backend, we are going to create a simple HTTP API using `express`. Restate is fully API agnostic though so you can interact with it however you want; REST, GraphQL, SOAP, anything goes! Let's start a simple web server from the `main` function in `src/restate.ts` with a single endpoint to create a new order: 150 | 151 | ```typescript 152 | import express from "express"; 153 | import { RestateProject, RestateClient, Order } from "restate-ts"; 154 | 155 | const project: RestateProject = { 156 | async main(restate: RestateClient) { 157 | const app = express(); 158 | 159 | app.post("/orders", async (req, res) => { 160 | // Get amount from query parameter 161 | const amount = parseInt(req.query.amount); 162 | 163 | // Trigger `Create` transition to create a new order object 164 | const [order] = await restate.order.transition.create({ 165 | data: { 166 | amount, 167 | }, 168 | }); 169 | 170 | // Respond with our new order object in JSON format 171 | res.json(order); 172 | }); 173 | 174 | app.listen(3000, () => { 175 | console.log("API server started!"); 176 | }); 177 | }, 178 | }; 179 | ``` 180 | 181 | The dev session should automatically reload and you should see "API server started!" in the output. Let's test it! 
182 | 183 | ```shell 184 | $ curl -X POST "localhost:3000/orders?amount=100" 185 | { 186 | "id": "order_01gqjyp438r30j3g28jt78cx23", 187 | "state": "created", 188 | "amount": 100 189 | } 190 | ``` 191 | 192 | It works and we get a nice order back! Here you see both the `amount` field, which we specified in `Order.rst`, but also `id` and `state`. These are fields which are automatically added for all models. 193 | 194 | ### Querying orders 195 | 196 | Being able to create data wouldn't do much good if we can't get it back, which Restate handles using queries. We'll add the following code to our main function to introduce a new endpoint for getting all orders: 197 | 198 | ```typescript 199 | app.get("/orders", async (req, res) => { 200 | // Get all orders from the database 201 | const orders = await restate.order.findAll(); 202 | 203 | res.json(orders); 204 | }); 205 | ``` 206 | 207 | And if we try that, we unsurprisingly get back: 208 | 209 | ```shell 210 | $ curl localhost:3000/orders 211 | [ 212 | { 213 | "id": "order_01gqjyp438r30j3g28jt78cx23", 214 | "state": "created", 215 | "amount": 100 216 | } 217 | ] 218 | ``` 219 | 220 | ### Transitioning orders when paid 221 | 222 | Now we are getting to the nice parts. We've created our order and the next step is to update it to the paid state once we receive a payment. The first step is to add a very simple implementation for the `Pay` transition: 223 | 224 | ```typescript 225 | const project: RestateProject = { 226 | // ... 
227 | transitions: { 228 | order: { 229 | async pay( 230 | restate: RestateClient, 231 | order: Order.Created, 232 | transition: Order.Pay 233 | ) { 234 | return { 235 | // The spread operator is a convenient way of avoiding having to specify all fields again 236 | ...order, 237 | state: Order.State.Paid, 238 | paymentReference: transition.data.paymentReference, 239 | }; 240 | }, 241 | }, 242 | }, 243 | }; 244 | ``` 245 | 246 | For this example, let's say our payment provider will send us a webhook when an order is paid for. To handle that, we'll need another endpoint which should trigger the `Pay` transition for an order: 247 | 248 | ```typescript 249 | app.post("/webhook/order_paid/:orderId", async (req, res) => { 250 | // Get payment reference from query parameters 251 | const reference = req.query.reference; 252 | 253 | // Trigger the `Pay` transition for the order, which returns the updated object 254 | const [order] = await restate.order.transition.pay({ 255 | object: req.params.orderId, 256 | data: { 257 | // The `Pay` transition requires us to pass the payment reference 258 | paymentReference: reference, 259 | }, 260 | }); 261 | 262 | // Respond with the updated object 263 | res.json(order); 264 | }); 265 | ``` 266 | 267 | If we simulate a webhook request from our payment provider, we get back an order in the expected state, with the passed reference saved to a new field (remember to replace the order ID with the one you got in the last request): 268 | 269 | ```shell 270 | $ curl -X POST "localhost:3000/webhook/order_paid/order_01gqjyp438r30j3g28jt78cx23?reference=abc123" 271 | { 272 | "id": "order_01gqjyp438r30j3g28jt78cx23", 273 | "state": "paid", 274 | "amount": 100, 275 | "paymentReference": "abc123" 276 | } 277 | ``` 278 | 279 | ### Asynchronously booking deliveries 280 | 281 | For the final part of this example, we want to book a delivery when an order is paid, with an imagined API call to our shipping carrier.
Let's start by implementing the final `bookDelivery` transition for this: 282 | 283 | ```typescript 284 | const project: RestateProject = { 285 | // ... 286 | transitions: { 287 | order: { 288 | async bookDelivery( 289 | restate: RestateClient, 290 | order: Order.Paid, 291 | transition: Order.BookDelivery 292 | ) { 293 | // This is where we'd call the shipping carrier's API and get a tracking number back, but for the sake 294 | // of the example, we'll use a static value 295 | const trackingNumber = "123456789"; 296 | 297 | return { 298 | ...order, 299 | state: Order.State.DeliveryBooked, 300 | trackingNumber, 301 | }; 302 | }, 303 | }, 304 | }, 305 | }; 306 | ``` 307 | 308 | What we could do is simply trigger this transition right in our payment webhook, but our shipping carrier's API is really slow and unreliable, so we don't want to bog down the webhook handler with that. Preferably we want to perform the delivery booking asynchronously! This is where one of Restate's central features comes in: consumers. 309 | 310 | Consumers let us write code that runs asynchronously in response to transitions. This lets us improve reliability, performance and code quality through decoupling. Like almost everything in Restate, consumers are defined in `src/restate.ts`. In our case, we want to trigger the `BookDelivery` transition when the `Pay` transition has completed, so let's add a consumer for that: 311 | 312 | ```typescript 313 | const project: RestateProject = { 314 | // ... 315 | consumers: [ 316 | Order.createConsumer({ 317 | // Every consumer should have a unique name 318 | name: "BookDeliveryAfterOrderPaid", 319 | 320 | // We can tell our consumer to only trigger on specific transitions 321 | transition: Order.Transition.Pay, 322 | 323 | async handler( 324 | restate: RestateClient, 325 | order: Order.Any, 326 | transition: Order.Pay 327 | ) { 328 | // You might notice that `order` has type `Order.Any` rather than `Order.Paid`.
329 | // It's possible that the object changed since the consumer was queued but we'll always 330 | // get the latest version in here. Because consumers are asynchronous, this is something 331 | // we must take into consideration. 332 | if (order.state != Order.State.Paid) { 333 | return; 334 | } 335 | 336 | // Trigger `BookDelivery` transition, which will take a little while but that is completely fine! 337 | await restate.order.transition.bookDelivery({ 338 | object: order, 339 | }); 340 | }, 341 | }), 342 | ], 343 | // ... 344 | }; 345 | ``` 346 | 347 | If you now mark a payment as paid using the webhook endpoint, you should soon after see that the order has been updated again: 348 | 349 | ```shell 350 | $ curl localhost:3000/orders 351 | [ 352 | { 353 | "id": "order_01gqjyp438r30j3g28jt78cx23", 354 | "state": "deliveryBooked", 355 | "amount": 100, 356 | "paymentReference": "abc123", 357 | "trackingNumber": "123456789" 358 | } 359 | ] 360 | ``` 361 | 362 | That's it for the introduction! Keep reading to learn more about the different features of Restate. 363 | 364 | ## Model definitions 365 | 366 | ### IDs and prefixes 367 | 368 | Every Restate model has an implicit field, `id`, which stores an autogenerated identifier. All IDs are prefixed with a string unique to the model, which makes it easier to identify what an ID is for. Here's an example of defining an `Order` model with prefix `order`. Objects of this model will automatically get IDs that look like: `order_01gqjyp438r30j3g28jt78cx23`. 369 | 370 | ``` 371 | model Order { 372 | prefix "order" 373 | } 374 | ``` 375 | 376 | ### Fields 377 | 378 | All Restate models have two implicit fields: `id` and `state`, which store an autogenerated ID and the current state respectively. When defining a model, it's also possible to add custom fields. Fields can be defined top-level, in which case they will be part of all states, or only on specific states. 
379 | 380 | ``` 381 | model User { 382 | field name: String 383 | 384 | state Verified { 385 | field age: Int 386 | } 387 | } 388 | ``` 389 | 390 | Every field has a data type and is by default non-nullable. If you want to make a field nullable, wrap the type in an `Optional`: 391 | 392 | ``` 393 | model User { 394 | field name: Optional[String] 395 | } 396 | ``` 397 | 398 | Restate supports the following data types: 399 | 400 | | **Data type** | **Description** | **Typescript equivalent** | 401 | | ---------------- | -------------------------------- | ------------------------- | 402 | | `String` | Variable-length string | `string` | 403 | | `Int` | Integer which may be negative | `number` | 404 | | `Decimal` | Decimal number | `number` | 405 | | `Bool` | Boolean, either true or false | `boolean` | 406 | | `Optional[Type]` | Nullable version of another type | `Type \| null` | 407 | 408 | ## Client 409 | 410 | The Restate client is used to create, transition and query objects. In the following examples, we'll be working with a model definition that looks like this: 411 | 412 | ``` 413 | model Order { 414 | prefix "order" 415 | 416 | field amount: Int 417 | 418 | state Created {} 419 | state Paid {} 420 | 421 | transition Create: Created { 422 | field amount: Int 423 | } 424 | 425 | transition Pay: Created -> Paid {} 426 | } 427 | ``` 428 | 429 | ### Transitions 430 | 431 | The client can be used to trigger initializing transitions, which create new objects. The transition call will return the new object after the transition has been applied. 432 | 433 | ```typescript 434 | const [order] = await restate.order.transition.create({ 435 | data: { 436 | amount: 100, 437 | }, 438 | }); 439 | ``` 440 | 441 | For regular transitions, one must also specify which object to apply the transition to by passing an object ID or a full object. 
442 | 443 | ```typescript 444 | const [paidOrder] = await restate.order.transition.pay({ 445 | object: "order_01gqjyp438r30j3g28jt78cx23", 446 | }); 447 | ``` 448 | 449 | If passing a full object, the types will ensure it's in the correct state for the transition to apply: 450 | 451 | ```typescript 452 | const order: Order.Paid = { 453 | // ... 454 | }; 455 | 456 | const [paidOrder] = await restate.order.transition.pay({ 457 | // This will trigger a type error because the `Pay` transition can 458 | // only be applied to orders in state `Created` 459 | object: order, 460 | }); 461 | ``` 462 | 463 | Transition calls will also return the full transition object if needed: 464 | 465 | ```typescript 466 | const [paidOrder, transition] = await restate.order.transition.pay({ 467 | object: "order_01gqjyp438r30j3g28jt78cx23", 468 | }); 469 | 470 | console.log(transition.id); 471 | // tsn_01gqjyp438r30j3g28jt78cx23 472 | // Transition IDs have prefix "tsn" 473 | ``` 474 | 475 | It's of course also possible to get a transition by ID or all transitions for an object: 476 | 477 | ```typescript 478 | const [paidOrder, transition] = await restate.order.transition.pay({ 479 | object: "order_01gqjyp438r30j3g28jt78cx23", 480 | }); 481 | 482 | // Find a single transition by ID 483 | const transitionById = await restate.order.getTransition(transition.id); 484 | 485 | // Find all transitions for an object (starting with the latest one) 486 | const allTransitions = await restate.order.getObjectTransitions(paidOrder); 487 | ``` 488 | 489 | For debugging purposes, it's possible to add a free text note to a transition. 
This field is designed to be human-readable and should not be relied upon by your code: 490 | 491 | ```typescript 492 | const [paidOrder] = await restate.order.transition.pay({ 493 | object: "order_01gqjyp438r30j3g28jt78cx23", 494 | note: "Payment manually verified", 495 | }); 496 | ``` 497 | 498 | ### Queries 499 | 500 | There are different kinds of queries depending on how many results you expect back. To find a single object by ID, you can do: 501 | 502 | ```typescript 503 | const order: Order.Any | null = await restate.order.findOne({ 504 | where: { 505 | id: "order_01gqjyp438r30j3g28jt78cx23", 506 | }, 507 | }); 508 | ``` 509 | 510 | Similarly, it's possible to filter by any field on a model and to find many objects: 511 | 512 | ```typescript 513 | const orders: Order.Any[] = await restate.order.findAll({ 514 | where: { 515 | amount: 100, 516 | }, 517 | }); 518 | ``` 519 | 520 | When querying by state, the resulting objects will have the expected type: 521 | 522 | ```typescript 523 | const orders: Order.Created[] = await restate.order.findAll({ 524 | where: { 525 | state: Order.State.Created, 526 | }, 527 | }); 528 | ``` 529 | 530 | If you want an error to be thrown if no object could be found, use `findOneOrThrow`: 531 | 532 | ```typescript 533 | const order: Order.Any = await restate.order.findOneOrThrow({ 534 | where: { 535 | id: "order_01gqjyp438r30j3g28jt78cx23", 536 | }, 537 | }); 538 | ``` 539 | 540 | You can also limit the number of objects you want to fetch: 541 | 542 | ```typescript 543 | const orders: Order.Created[] = await restate.order.findAll({ 544 | where: { 545 | state: Order.State.Created, 546 | }, 547 | limit: 10, 548 | }); 549 | ``` 550 | 551 | ## Testing 552 | 553 | Restate has built-in support for testing with a real database. In your test cases, import the project definition from `src/restate.ts` and pass it to `setupTestClient` to create a new Restate client for testing.
This client will automatically configure an in-memory SQLite database and will run any consumers synchronously when transitions are triggered. 554 | 555 | Here's an example in [Jest](https://jestjs.io), but any test framework will work: 556 | 557 | ```typescript 558 | import { test, expect, beforeEach } from "@jest/globals"; 559 | import { Order, RestateClient, setupTestClient } from "restate-ts"; 560 | 561 | // Import project definition from "restate.ts" 562 | import project from "./restate"; 563 | 564 | let restate: RestateClient; 565 | 566 | beforeEach(async () => { 567 | // Create a new test client for each test run 568 | restate = await setupTestClient(project); 569 | }); 570 | 571 | test("delivery is booked when order is paid", async () => { 572 | // Create order 573 | const [order] = await restate.order.transition.create({ 574 | data: { 575 | amount: 100, 576 | }, 577 | }); 578 | 579 | // Trigger `Pay` transition on order 580 | await restate.order.transition.pay({ 581 | object: order, 582 | data: { 583 | paymentReference: "abc123", 584 | }, 585 | }); 586 | 587 | // The `BookDeliveryAfterOrderPaid` consumer should have been triggered when the order was paid 588 | // and transitioned it into `DeliveryBooked`. With the test client, consumers are run synchronously. 589 | const updatedOrder = await restate.order.findOneOrThrow({ 590 | where: { 591 | id: order.id, 592 | }, 593 | }); 594 | expect(updatedOrder.state).toBe(Order.State.DeliveryBooked); 595 | expect(updatedOrder.trackingNumber).toBe("123456789"); 596 | }); 597 | ``` 598 | 599 | ## Config 600 | 601 | If you want to configure Restate, create a `restate.config.json` file in the root of your project. In your config file, you can specify per-environment settings; the current environment is taken from the `NODE_ENV` environment variable. When running `restate dev`, the default environment will be `development`. For all other commands, it will default to `production`.
602 | 603 | In your config file, you can configure what database to use. Restate supports both Postgres and SQLite, where we recommend using Postgres in production and SQLite during development and testing. Below is an annotated example of a config file, showing what settings exist and the defaults: 604 | 605 | ```jsonc 606 | { 607 | "database": { 608 | "type": "postgres", 609 | "connection_string": "postgres://postgres:@localhost:5432/postgres" 610 | }, 611 | 612 | // The settings in here will only be used in the development environment 613 | "development": { 614 | "database": { 615 | "type": "sqlite", 616 | "connection_string": "restate.sqlite" 617 | } 618 | } 619 | } 620 | ``` 621 | 622 | ## Commands 623 | 624 | ### `restate dev` 625 | 626 | Starts an auto-reloading dev server for your project. It will automatically generate a client and run both your main function and a worker to handle consumers. 627 | 628 | ### `restate main` 629 | 630 | Starts the main function as defined in your project definition. 631 | 632 | ### `restate worker` 633 | 634 | Starts a worker which handles running consumers in response to transitions. 635 | 636 | ### `restate generate` 637 | 638 | Regenerates the Restate client and types based on your `*.rst` files. 639 | 640 | ### `restate migrate` 641 | 642 | Automatically sets up tables for all your models. Runs automatically as part of `restate dev`. 
643 | 644 | ## License 645 | 646 | Restate is [MIT licensed](LICENSE.md) 647 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "restate-ts", 3 | "version": "0.1.0", 4 | "description": "Build reliable and understandable backends with state machines", 5 | "main": "src/index.ts", 6 | "scripts": {}, 7 | "author": "Fabian Lindfors", 8 | "license": "MIT", 9 | "dependencies": { 10 | "@types/pg": "^8.6.6", 11 | "chalk": "^4.0.0", 12 | "chokidar": "^3.5.3", 13 | "deepmerge": "^4.3.0", 14 | "js-convert-case": "^4.2.0", 15 | "json5": "^2.2.3", 16 | "knex": "^2.3.0", 17 | "pg": "^8.9.0", 18 | "pg-logical-replication": "^2.0.3", 19 | "pluralize": "^8.0.0", 20 | "prepend-transform": "^0.0.1019", 21 | "sqlite3": "^5.1.4", 22 | "ts-morph": "^14.0.0", 23 | "ts-node": "^10.7.0", 24 | "typescript-parsec": "^0.3.2", 25 | "ulid": "^2.1.8", 26 | "winston": "^3.8.2" 27 | }, 28 | "bin": { 29 | "restate": "src/cmd/entrypoint.ts" 30 | } 31 | } -------------------------------------------------------------------------------- /src/ast.ts: -------------------------------------------------------------------------------- 1 | import { toCamelCase, toPascalCase, toSnakeCase } from "js-convert-case"; 2 | import { DataType, dataTypeFromParsed } from "./internal/dataTypes"; 3 | import * as Parser from "./parser"; 4 | 5 | export class Model { 6 | private name: string; 7 | private states: Map; 8 | private transitions: Map; 9 | private prefix: string; 10 | 11 | constructor({ name, states, transitions, baseFields, prefix }: Parser.Model) { 12 | if (!prefix) { 13 | throw new Error(`Missing prefix for model ${name}`); 14 | } 15 | 16 | const newStates: Map = kvToMap( 17 | states.map((state) => [state.name, new State(state)]) 18 | ); 19 | 20 | // Fill in the extended states 21 | states.forEach((state) => { 22 | const newState = newStates.get(state.name)!; 23 | 
newState.setExtendedStates( 24 | state.extends.map((extended) => newStates.get(extended)!) 25 | ); 26 | }); 27 | 28 | // Add fields from extended states 29 | newStates.forEach((state, _1, _2) => { 30 | baseFields.forEach((baseField) => { 31 | state.addField(new Field(baseField)); 32 | }); 33 | 34 | Model.addFieldsFromExtendedStates(state); 35 | }); 36 | 37 | const newTransitions: Map = kvToMap( 38 | transitions.map((transition) => [ 39 | transition.name, 40 | new Transition(transition, newStates), 41 | ]) 42 | ); 43 | 44 | this.name = name; 45 | this.states = newStates; 46 | this.transitions = newTransitions; 47 | this.prefix = prefix.prefix; 48 | } 49 | 50 | private static addFieldsFromExtendedStates(state: State) { 51 | state.getExtendedStates().forEach((extended) => { 52 | // Add extended fields recursively 53 | Model.addFieldsFromExtendedStates(extended); 54 | 55 | extended.getFields().forEach((field) => state.addField(field)); 56 | }); 57 | } 58 | 59 | getStates(): State[] { 60 | return Array.from(this.states.values()); 61 | } 62 | 63 | getTransitions(): Transition[] { 64 | return Array.from(this.transitions.values()); 65 | } 66 | 67 | getPrefix(): string { 68 | return this.prefix; 69 | } 70 | 71 | pascalCaseName(): string { 72 | return this.name; 73 | } 74 | 75 | camelCaseName(): string { 76 | return toCamelCase(this.name); 77 | } 78 | } 79 | 80 | export class State { 81 | private name: string; 82 | private fields: Map; 83 | private extendedStates: State[]; 84 | 85 | constructor({ name, fields }: Parser.State) { 86 | const newFields: Map = kvToMap( 87 | fields.map((field) => [field.name, new Field(field)]) 88 | ); 89 | 90 | this.name = name; 91 | this.fields = newFields; 92 | this.extendedStates = []; 93 | } 94 | 95 | addField(field: Field) { 96 | this.fields.set(field.camelCaseName(), field); 97 | } 98 | 99 | getFields(): Field[] { 100 | return Array.from(this.fields.values()); 101 | } 102 | 103 | setExtendedStates(extended: State[]) { 104 | 
this.extendedStates = extended; 105 | } 106 | 107 | getExtendedStates(): State[] { 108 | return this.extendedStates; 109 | } 110 | 111 | pascalCaseName(): string { 112 | return this.name; 113 | } 114 | 115 | camelCaseName(): string { 116 | return toCamelCase(this.name); 117 | } 118 | } 119 | 120 | export class Field { 121 | private name: string; 122 | private type: DataType; 123 | 124 | constructor({ name, type }: Parser.Field) { 125 | this.name = name; 126 | this.type = dataTypeFromParsed(type); 127 | } 128 | 129 | pascalCaseName(): string { 130 | return toPascalCase(this.name); 131 | } 132 | 133 | camelCaseName(): string { 134 | return this.name; 135 | } 136 | 137 | getType(): DataType { 138 | return this.type; 139 | } 140 | } 141 | 142 | export class Transition { 143 | private name: string; 144 | private from?: State[] | "*"; 145 | private to: State[] | "*"; 146 | private fields: Map; 147 | 148 | constructor( 149 | { name, from, to, fields }: Parser.Transition, 150 | states: Map 151 | ) { 152 | const fromState = (() => { 153 | if (!from) { 154 | return undefined; 155 | } 156 | return from[0] == "*" ? "*" : from.map((name) => states.get(name)!); 157 | })(); 158 | 159 | const toState = to[0] == "*" ? 
"*" : to.map((name) => states.get(name)!); 160 | const newFields: Map = kvToMap( 161 | fields.map((field) => [field.name, new Field(field)]) 162 | ); 163 | 164 | this.name = name; 165 | this.from = fromState; 166 | this.to = toState; 167 | this.fields = newFields; 168 | } 169 | 170 | pascalCaseName(): string { 171 | return toPascalCase(this.name); 172 | } 173 | 174 | camelCaseName(): string { 175 | return toCamelCase(this.name); 176 | } 177 | 178 | snakeCaseName(): string { 179 | return toSnakeCase(this.name); 180 | } 181 | 182 | getFromStates(): undefined | State[] | "*" { 183 | return this.from; 184 | } 185 | 186 | getToStates(): State[] | "*" { 187 | return this.to; 188 | } 189 | 190 | getFields(): Field[] { 191 | return Array.from(this.fields.values()); 192 | } 193 | } 194 | 195 | function kvToMap(entries: [K, V][]): Map { 196 | let map: Map = new Map(); 197 | 198 | for (const [key, value] of entries) { 199 | map.set(key, value); 200 | } 201 | 202 | return map; 203 | } 204 | -------------------------------------------------------------------------------- /src/cmd/app.ts: -------------------------------------------------------------------------------- 1 | import { readFileSync, existsSync, readdirSync } from "fs"; 2 | import * as path from "path"; 3 | import { generate } from "../generate"; 4 | import { parse } from "../parser"; 5 | import { Model } from "../ast"; 6 | import chokidar from "chokidar"; 7 | import { spawn, ChildProcess } from "child_process"; 8 | import pt from "prepend-transform"; 9 | import Main from "./main"; 10 | import Worker from "./worker"; 11 | import chalk from "chalk"; 12 | import Repl from "./repl"; 13 | import { dbFromConfig, loadGeneratedModule, loadProject } from "./common"; 14 | import { loadConfig, getEnv } from "../internal/config"; 15 | import merge from "deepmerge"; 16 | 17 | const welcomeMessage = chalk.bold(` _____________________ 18 | < Welcome to Restate! 
> 19 | --------------------- 20 | \\ ^__^ 21 | \\ (oo)\\_______ 22 | (__)\\ )\\/\\ 23 | ||----w | 24 | || || 25 | `); 26 | const version = "0.1.0"; 27 | const checkmark = chalk.green("✓"); 28 | 29 | export default class App { 30 | private mainProcess: ChildProcess | undefined; 31 | private workerProcess: ChildProcess | undefined; 32 | 33 | async startDev() { 34 | const env = getEnv() || "development"; 35 | const config = loadConfig(env); 36 | 37 | console.log(welcomeMessage); 38 | console.log("Version:", chalk.bold(version)); 39 | console.log("Environment:", chalk.bold(env)); 40 | console.log("Database:", chalk.bold(config.database.type)); 41 | console.log(""); 42 | 43 | await this.runDev(); 44 | 45 | chokidar 46 | .watch(["src/", "restate/", "restate.config.json"]) 47 | .on("change", (_event, _path) => { 48 | console.log(""); 49 | console.log("🔂 File changed, restarting..."); 50 | this.runDev(); 51 | }); 52 | } 53 | 54 | async startMain() { 55 | const project = await loadProject(); 56 | const main = new Main(project); 57 | await main.run(); 58 | } 59 | 60 | async startWorker() { 61 | const project = await loadProject(); 62 | const worker = new Worker(project); 63 | await worker.run(); 64 | } 65 | 66 | async startRepl() { 67 | const project = await loadProject(); 68 | const repl = new Repl(project); 69 | await repl.run(); 70 | } 71 | 72 | async generate() { 73 | // Find and parse models from all files in the restate/ directory 74 | process.stdout.write("🛠️ Generating client... "); 75 | const files = existsSync("restate/") ? 
readdirSync("restate/") : []; 76 | const models = files 77 | .map((file) => readFileSync(path.join("restate", file)).toString()) 78 | .flatMap(parse) 79 | .map((parsed) => new Model(parsed)); 80 | 81 | // Output the generated Typescript files to the packages directory 82 | // This way, they can be imported directly from the project 83 | const outputPath = `${__dirname}/../generated`; 84 | await generate(models, outputPath); 85 | console.log(checkmark); 86 | } 87 | 88 | async migrate() { 89 | const project = await loadProject(); 90 | if (project === undefined) { 91 | console.log(chalk.yellow("⚠️ No restate.ts project definition found!")); 92 | return; 93 | } 94 | 95 | process.stdout.write("⏩️ Migrating database... "); 96 | 97 | const config = loadConfig(); 98 | 99 | // Set up client using the imported config 100 | const generatedModule = await loadGeneratedModule(); 101 | const db = dbFromConfig(generatedModule.__ProjectMeta, config.database); 102 | const Client = generatedModule.RestateClient; 103 | const client = new Client(project, db); 104 | await client.setup(); 105 | 106 | // Set up database if it hasn't already been set up 107 | await client.migrate(); 108 | await client.close(); 109 | 110 | console.log(checkmark); 111 | } 112 | 113 | private async runDev() { 114 | await this.generate(); 115 | 116 | // Import the project definition dynamically from the project directory. 117 | // This project definition contains the transition implementations and consumers. 118 | process.stdout.write("🚛 Loading project... 
"); 119 | let project; 120 | try { 121 | project = await loadProject(); 122 | } catch (e) { 123 | console.log(chalk.red("\n❗️ Failed to load project:")); 124 | console.log(e); 125 | return; 126 | } 127 | 128 | if (project === undefined) { 129 | console.log( 130 | chalk.yellow("\n⚠️ No restate.ts project definition found!") 131 | ); 132 | return; 133 | } 134 | 135 | console.log(checkmark); 136 | 137 | const env = getEnv() || "development"; 138 | const config = loadConfig(env); 139 | 140 | // Set up client using the imported config 141 | const generatedModule = await loadGeneratedModule(); 142 | const db = dbFromConfig(generatedModule.__ProjectMeta, config.database); 143 | const Client = generatedModule.RestateClient; 144 | const client = new Client(project, db); 145 | await client.setup(); 146 | 147 | // Set up database if it hasn't already been set up 148 | process.stdout.write("⏩️ Migrating database... "); 149 | await client.migrate(); 150 | console.log(checkmark); 151 | 152 | this.spawnWorker(); 153 | this.spawnMain(); 154 | } 155 | 156 | private spawnWorker() { 157 | // Spawn a child process to run the projects consumers 158 | process.stdout.write("👷‍♀️ Starting worker... "); 159 | 160 | // If a main process is already running, kill it 161 | if (this.workerProcess) { 162 | this.workerProcess.kill(); 163 | } 164 | 165 | this.workerProcess = spawn(`${__dirname}/entrypoint.ts`, ["worker"], { 166 | env: merge({ NODE_ENV: "development" }, process.env), 167 | }); 168 | 169 | const prefix = chalk.yellow("[worker] "); 170 | this.workerProcess.stdout?.pipe(pt(prefix)).pipe(process.stdout); 171 | this.workerProcess.stderr?.pipe(pt(prefix)).pipe(process.stderr); 172 | 173 | console.log(checkmark); 174 | } 175 | 176 | private spawnMain() { 177 | // Spawn a child process to run the projects main function 178 | process.stdout.write("🚀 Starting application... 
"); 179 | 180 | // If a main process is already running, kill it 181 | if (this.mainProcess) { 182 | this.mainProcess.kill(); 183 | } 184 | 185 | this.mainProcess = spawn(`${__dirname}/entrypoint.ts`, ["main"], { 186 | env: merge({ NODE_ENV: "development" }, process.env), 187 | }); 188 | 189 | const prefix = chalk.cyan("[main] "); 190 | this.mainProcess.stdout?.pipe(pt(prefix)).pipe(process.stdout); 191 | this.mainProcess.stderr?.pipe(pt(prefix)).pipe(process.stderr); 192 | 193 | console.log(checkmark); 194 | } 195 | } 196 | -------------------------------------------------------------------------------- /src/cmd/common.ts: -------------------------------------------------------------------------------- 1 | import winston, { config } from "winston"; 2 | import { Db, Project, SqliteDb, SqliteQueue } from "../internal"; 3 | import { DatabaseConfig } from "../internal/config"; 4 | import { PostgresDb } from "../internal/db"; 5 | import { Queue } from "../internal/queue"; 6 | import { PostgresQueue } from "../internal/queue/PostgresQueue"; 7 | import Logger from "./logger"; 8 | import { existsSync } from "fs"; 9 | 10 | export async function loadGeneratedModule(): Promise<{ 11 | __ProjectMeta: any; 12 | RestateClient: any; 13 | }> { 14 | const generatedPath = "../generated"; 15 | return await import(generatedPath); 16 | } 17 | 18 | export async function loadProject(): Promise { 19 | const configPath = `${process.cwd()}/src/restate.ts`; 20 | if (!existsSync(configPath)) { 21 | return undefined; 22 | } 23 | 24 | return (await import(configPath)).default; 25 | } 26 | 27 | export function dbFromConfig(projectMeta: any, config: DatabaseConfig): Db { 28 | switch (config.type) { 29 | case "sqlite": 30 | return SqliteDb.fromConfig(projectMeta, config); 31 | case "postgres": 32 | return PostgresDb.fromConfig(projectMeta, config); 33 | } 34 | } 35 | 36 | export function queueFromDb( 37 | logger: Logger, 38 | projectMeta: any, 39 | db: Db, 40 | client: any, 41 | project: Project 
42 | ): Queue { 43 | if (db instanceof SqliteDb) { 44 | return new SqliteQueue(projectMeta, db, client, project); 45 | } 46 | 47 | if (db instanceof PostgresDb) { 48 | return new PostgresQueue(logger, projectMeta, db, client, project); 49 | } 50 | 51 | throw new Error("couldn't determine what queue to create for database"); 52 | } 53 | -------------------------------------------------------------------------------- /src/cmd/entrypoint.ts: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env ts-node 2 | 3 | import App from "./app"; 4 | 5 | async function main() { 6 | let command = "dev"; 7 | if (process.argv.length > 2) { 8 | command = process.argv[2]; 9 | } 10 | 11 | const app = new App(); 12 | 13 | switch (command) { 14 | case "main": 15 | app.startMain(); 16 | break; 17 | case "worker": 18 | app.startWorker(); 19 | break; 20 | case "dev": 21 | await app.startDev(); 22 | break; 23 | case "generate": 24 | await app.generate(); 25 | break; 26 | case "migrate": 27 | await app.generate(); 28 | await app.migrate(); 29 | break; 30 | } 31 | } 32 | 33 | main(); 34 | -------------------------------------------------------------------------------- /src/cmd/logger.ts: -------------------------------------------------------------------------------- 1 | import winston from "winston"; 2 | 3 | type Logger = winston.Logger; 4 | export default Logger; 5 | 6 | export function createLogger(): Logger { 7 | return winston.createLogger({ 8 | level: "debug", 9 | format: winston.format.combine( 10 | winston.format.colorize(), 11 | winston.format.printf((info) => { 12 | const level = info.level; 13 | const message = info.message; 14 | 15 | let data = info as any; 16 | delete data.level; 17 | delete data.message; 18 | const attributes = Object.entries(data) 19 | .map(([key, value]) => `${key}=${JSON.stringify(value)}`) 20 | .join(" "); 21 | 22 | return `${level}:\t${message} ${attributes}`; 23 | }) 24 | ), 25 | transports: [new 
winston.transports.Console()], 26 | }); 27 | } 28 | -------------------------------------------------------------------------------- /src/cmd/main.ts: -------------------------------------------------------------------------------- 1 | import { loadConfig } from "../internal/config"; 2 | import { dbFromConfig, loadGeneratedModule } from "./common"; 3 | 4 | export default class Main { 5 | constructor(private project: any) {} 6 | 7 | async run() { 8 | const generatedModule = await loadGeneratedModule(); 9 | const config = loadConfig(); 10 | 11 | const db = dbFromConfig(generatedModule.__ProjectMeta, config.database); 12 | const client = new generatedModule.RestateClient(this.project, db); 13 | await client.setup(); 14 | 15 | const main: Function = this.project.main; 16 | await (main.call(undefined, client) as Promise); 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/cmd/repl.ts: -------------------------------------------------------------------------------- 1 | import { TestDb } from "../internal"; 2 | import * as tsNode from "ts-node"; 3 | import { rep } from "typescript-parsec"; 4 | 5 | export default class Repl { 6 | constructor(private config: any) {} 7 | 8 | async run() { 9 | const module = await this.generatedModule(); 10 | const exports = Object.keys(module); 11 | 12 | const repl = tsNode.createRepl(); 13 | 14 | const service = tsNode.create({ ...repl.evalAwarePartialHost }); 15 | repl.setService(service); 16 | 17 | const command = `import { ${exports.join( 18 | ", " 19 | )} } from "${__dirname}/../generated"`; 20 | repl.start(); 21 | repl.evalCode("console.log('test')"); 22 | repl.evalCode(`import * as files from "fs"`); 23 | } 24 | 25 | private async generatedModule(): Promise { 26 | const generatedPath = "../generated"; 27 | return await import(generatedPath); 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/cmd/worker.ts: 
-------------------------------------------------------------------------------- 1 | import { loadConfig } from "../internal/config"; 2 | import { 3 | dbFromConfig, 4 | loadGeneratedModule, 5 | loadProject, 6 | queueFromDb, 7 | } from "./common"; 8 | import { createLogger } from "./logger"; 9 | 10 | export default class Worker { 11 | constructor(private project: any) {} 12 | 13 | async run() { 14 | const config = loadConfig(); 15 | const project = await loadProject(); 16 | const generatedModule = await loadGeneratedModule(); 17 | 18 | const db = dbFromConfig(generatedModule.__ProjectMeta, config.database); 19 | 20 | const Client = generatedModule.RestateClient; 21 | const client = new Client(this.project, db); 22 | await client.setup(); 23 | 24 | const logger = createLogger(); 25 | 26 | logger.info("Starting worker"); 27 | 28 | const queue = queueFromDb( 29 | logger, 30 | generatedModule.__ProjectMeta, 31 | db, 32 | client, 33 | project 34 | ); 35 | await queue.run(); 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/generate/client.ts: -------------------------------------------------------------------------------- 1 | import { toCamelCase } from "js-convert-case"; 2 | import { 3 | ts, 4 | ClassDeclarationStructure, 5 | ConstructorDeclarationStructure, 6 | PropertyDeclarationStructure, 7 | StatementStructures, 8 | StructureKind, 9 | VariableStatementStructure, 10 | printNode, 11 | MethodDeclarationStructure, 12 | ParameterDeclarationStructure, 13 | ImportDeclarationStructure, 14 | ExportDeclarationStructure, 15 | InterfaceDeclarationStructure, 16 | Scope, 17 | FunctionDeclarationStructure, 18 | VariableDeclarationKind, 19 | } from "ts-morph"; 20 | import { Model, Transition } from "../ast"; 21 | 22 | export function generateDbFile(models: Model[]): StatementStructures[] { 23 | const modelClientDeclarations = models.map(modelClientClass); 24 | const modelTransitionsClientDeclarations = 
models.map(modelTransitionsClass); 25 | 26 | return [ 27 | ...imports(models), 28 | ...reexports(models), 29 | transitonImplType(models), 30 | projectMetaConstant(models), 31 | projectType(), 32 | testClientFunction(), 33 | clientClass(models), 34 | ...modelClientDeclarations, 35 | ...modelTransitionsClientDeclarations, 36 | ]; 37 | } 38 | 39 | function imports(models: Model[]): ImportDeclarationStructure[] { 40 | const modelImports: ImportDeclarationStructure[] = models.map((model) => ({ 41 | kind: StructureKind.ImportDeclaration, 42 | namespaceImport: model.pascalCaseName(), 43 | moduleSpecifier: `./${model.pascalCaseName()}`, 44 | })); 45 | 46 | const internalImport: ImportDeclarationStructure = { 47 | kind: StructureKind.ImportDeclaration, 48 | namespaceImport: "__Internal", 49 | moduleSpecifier: "../internal", 50 | }; 51 | 52 | return [...modelImports, internalImport]; 53 | } 54 | 55 | function reexports(models: Model[]): ExportDeclarationStructure[] { 56 | return models.map((model) => ({ 57 | kind: StructureKind.ExportDeclaration, 58 | namespaceExport: model.pascalCaseName(), 59 | moduleSpecifier: `./${model.pascalCaseName()}`, 60 | })); 61 | } 62 | 63 | function transitonImplType(models: Model[]): InterfaceDeclarationStructure { 64 | const properties = models.map((model) => { 65 | return { 66 | name: toCamelCase(model.pascalCaseName()), 67 | type: `${model.pascalCaseName()}.TransitionImpl`, 68 | }; 69 | }); 70 | 71 | return { 72 | name: "TransitionImpls", 73 | kind: StructureKind.Interface, 74 | properties, 75 | isExported: true, 76 | }; 77 | } 78 | 79 | function projectMetaConstant(models: Model[]): VariableStatementStructure { 80 | const modelMetasConstant = 81 | "[" + 82 | models.map((model) => `${model.pascalCaseName()}.__Meta`).join(", ") + 83 | "]"; 84 | 85 | return { 86 | kind: StructureKind.VariableStatement, 87 | declarationKind: VariableDeclarationKind.Const, 88 | declarations: [ 89 | { 90 | name: "__ProjectMeta", 91 | type: 
"__Internal.ProjectMeta", 92 | initializer: `new __Internal.ProjectMeta(${modelMetasConstant});`, 93 | }, 94 | ], 95 | isExported: true, 96 | }; 97 | } 98 | 99 | function projectType(): InterfaceDeclarationStructure { 100 | return { 101 | name: "RestateProject", 102 | kind: StructureKind.Interface, 103 | properties: [ 104 | { 105 | name: "main", 106 | type: "(restate: RestateClient) => Promise", 107 | }, 108 | { 109 | name: "transitions", 110 | type: "TransitionImpls", 111 | }, 112 | { 113 | name: "consumers", 114 | type: "__Internal.Consumer[]", 115 | hasQuestionToken: true, 116 | }, 117 | ], 118 | isExported: true, 119 | }; 120 | } 121 | 122 | function testClientFunction(): FunctionDeclarationStructure { 123 | return { 124 | kind: StructureKind.Function, 125 | name: "setupTestClient", 126 | isExported: true, 127 | isAsync: true, 128 | parameters: [ 129 | { 130 | name: "project", 131 | type: "RestateProject", 132 | }, 133 | ], 134 | returnType: "Promise", 135 | statements: [ 136 | "const db = new __Internal.TestDb(__ProjectMeta)", 137 | "await db.setup()", 138 | "await db.migrate()", 139 | "const client = new RestateClient(project, db)", 140 | "const consumerCallback = __Internal.createTestConsumerRunner(db, project, client, __ProjectMeta)", 141 | "db.setTransitionCallback(consumerCallback)", 142 | "return client", 143 | ], 144 | }; 145 | } 146 | 147 | function clientClass(models: Model[]): ClassDeclarationStructure { 148 | const modelClientProperties: PropertyDeclarationStructure[] = models.map( 149 | (model) => { 150 | return { 151 | kind: StructureKind.Property, 152 | name: toCamelCase(model.pascalCaseName()), 153 | type: modelClientClassName(model), 154 | }; 155 | } 156 | ); 157 | 158 | const internalDbProperty: PropertyDeclarationStructure = { 159 | kind: StructureKind.Property, 160 | name: "__db", 161 | type: "__Internal.Db", 162 | }; 163 | 164 | const internalConfigProperty: PropertyDeclarationStructure = { 165 | kind: StructureKind.Property, 166 | name: 
"__project", 167 | type: "RestateProject", 168 | }; 169 | 170 | const internalTriggeredByPropery: PropertyDeclarationStructure = { 171 | kind: StructureKind.Property, 172 | name: "__triggeredBy", 173 | type: "string | null", 174 | initializer: "null", 175 | }; 176 | 177 | const modelPropertyAssignments: string[] = models.map((model) => 178 | // `this.{MODEL_NAME} = new {MODEL_NAME}Db(this.__db)` 179 | printNode( 180 | ts.factory.createExpressionStatement( 181 | ts.factory.createBinaryExpression( 182 | ts.factory.createPropertyAccessExpression( 183 | ts.factory.createThis(), 184 | ts.factory.createIdentifier(toCamelCase(model.pascalCaseName())) 185 | ), 186 | ts.factory.createToken(ts.SyntaxKind.EqualsToken), 187 | ts.factory.createNewExpression( 188 | ts.factory.createIdentifier(modelClientClassName(model)), 189 | undefined, 190 | [ts.factory.createThis()] 191 | ) 192 | ) 193 | ) 194 | ) 195 | ); 196 | 197 | const internalDbAssignment = "this.__db = db"; 198 | const internalConfigAssignment = "this.__project = project"; 199 | 200 | const constructor: ConstructorDeclarationStructure = { 201 | kind: StructureKind.Constructor, 202 | parameters: [ 203 | { 204 | kind: StructureKind.Parameter, 205 | name: "project", 206 | type: "RestateProject", 207 | }, 208 | { 209 | kind: StructureKind.Parameter, 210 | name: "db", 211 | type: "__Internal.Db", 212 | }, 213 | ], 214 | statements: [ 215 | internalDbAssignment, 216 | internalConfigAssignment, 217 | ...modelPropertyAssignments, 218 | ], 219 | }; 220 | 221 | const setupMethod: MethodDeclarationStructure = { 222 | kind: StructureKind.Method, 223 | name: "setup", 224 | statements: ["await this.__db.setup();"], 225 | isAsync: true, 226 | returnType: `Promise`, 227 | }; 228 | 229 | const migrateMethod: MethodDeclarationStructure = { 230 | kind: StructureKind.Method, 231 | name: "migrate", 232 | statements: ["await this.__db.migrate();"], 233 | isAsync: true, 234 | returnType: `Promise`, 235 | }; 236 | 237 | const closeMethod: 
MethodDeclarationStructure = { 238 | kind: StructureKind.Method, 239 | name: "close", 240 | statements: ["await this.__db.close();"], 241 | isAsync: true, 242 | returnType: `Promise`, 243 | }; 244 | 245 | const withTriggeredByMethod: MethodDeclarationStructure = { 246 | kind: StructureKind.Method, 247 | name: "withTriggeredBy", 248 | statements: [ 249 | "const newClient = new RestateClient(this.__project, this.__db)", 250 | "newClient.__triggeredBy = triggeredBy", 251 | "return newClient", 252 | ], 253 | parameters: [ 254 | { 255 | kind: StructureKind.Parameter, 256 | name: "triggeredBy", 257 | type: "string", 258 | }, 259 | ], 260 | returnType: `RestateClient`, 261 | }; 262 | 263 | const getTasksForTransitionMethod: MethodDeclarationStructure = { 264 | kind: StructureKind.Method, 265 | name: "getTasksForTransition", 266 | isAsync: true, 267 | parameters: [ 268 | { 269 | kind: StructureKind.Parameter, 270 | name: "transition", 271 | type: "string | __Internal.Transition", 272 | }, 273 | ], 274 | statements: [ 275 | "const id = typeof transition == 'string' ? 
transition : transition.id;", 276 | "return await this.__db.getTasksForTransition(id);", 277 | ], 278 | returnType: `Promise<__Internal.Task[]>`, 279 | }; 280 | 281 | return { 282 | kind: StructureKind.Class, 283 | name: "RestateClient", 284 | properties: [ 285 | ...modelClientProperties, 286 | internalDbProperty, 287 | internalConfigProperty, 288 | internalTriggeredByPropery, 289 | ], 290 | ctors: [constructor], 291 | methods: [ 292 | setupMethod, 293 | migrateMethod, 294 | closeMethod, 295 | withTriggeredByMethod, 296 | getTasksForTransitionMethod, 297 | ], 298 | isExported: true, 299 | }; 300 | } 301 | 302 | function modelClientClass(model: Model): ClassDeclarationStructure { 303 | const internalClientProperty: PropertyDeclarationStructure = { 304 | kind: StructureKind.Property, 305 | scope: Scope.Private, 306 | name: "parent", 307 | type: "RestateClient", 308 | }; 309 | 310 | const transitionClientProperty: PropertyDeclarationStructure = { 311 | kind: StructureKind.Property, 312 | scope: Scope.Public, 313 | name: "transition", 314 | type: modelTransitionsClassName(model), 315 | }; 316 | 317 | const constructor: ConstructorDeclarationStructure = { 318 | kind: StructureKind.Constructor, 319 | parameters: [ 320 | { 321 | kind: StructureKind.Parameter, 322 | name: "parent", 323 | type: "RestateClient", 324 | }, 325 | ], 326 | statements: [ 327 | `super(parent.__db, ${model.pascalCaseName()}.__Meta)`, 328 | "this.parent = parent", 329 | `this.transition = new ${modelTransitionsClassName(model)}(parent)`, 330 | ], 331 | }; 332 | 333 | const methods = []; 334 | 335 | return { 336 | kind: StructureKind.Class, 337 | name: modelClientClassName(model), 338 | properties: [internalClientProperty, transitionClientProperty], 339 | methods: [...modelQueryMethods(model), ...modelGetTransitionMethods(model)], 340 | ctors: [constructor], 341 | extends: "__Internal.BaseClient", 342 | }; 343 | } 344 | 345 | function modelTransitionsClass(model: Model): ClassDeclarationStructure { 
346 | const internalClientProperty: PropertyDeclarationStructure = { 347 | kind: StructureKind.Property, 348 | scope: Scope.Private, 349 | name: "parent", 350 | type: "RestateClient", 351 | }; 352 | 353 | const transitionImplsProperty: PropertyDeclarationStructure = { 354 | kind: StructureKind.Property, 355 | scope: Scope.Private, 356 | name: "transitionImpls", 357 | type: `${model.pascalCaseName()}.TransitionImpl`, 358 | }; 359 | 360 | const transitionMethods: MethodDeclarationStructure[] = model 361 | .getTransitions() 362 | .map((transition) => modelDbClassTransitionMethod(model, transition)); 363 | 364 | const constructor: ConstructorDeclarationStructure = { 365 | kind: StructureKind.Constructor, 366 | parameters: [ 367 | { 368 | kind: StructureKind.Parameter, 369 | name: "parent", 370 | type: "RestateClient", 371 | }, 372 | ], 373 | statements: [ 374 | `super(parent.__db, ${model.pascalCaseName()}.__Meta)`, 375 | "this.parent = parent", 376 | `this.transitionImpls = parent.__project.transitions.${model.camelCaseName()}`, 377 | ], 378 | }; 379 | 380 | return { 381 | kind: StructureKind.Class, 382 | name: modelTransitionsClassName(model), 383 | extends: `__Internal.BaseTransitionsClient`, 384 | properties: [internalClientProperty, transitionImplsProperty], 385 | methods: [...transitionMethods], 386 | ctors: [constructor], 387 | }; 388 | } 389 | 390 | function modelQueryMethods(model: Model): MethodDeclarationStructure[] { 391 | const typeParameters: string[] = [ 392 | `F extends ${model.pascalCaseName()}.QueryFilter`, 393 | `Out extends ${model.pascalCaseName()}.ResultFromQueryFilter`, 394 | ]; 395 | 396 | const filterParameter: ParameterDeclarationStructure = { 397 | kind: StructureKind.Parameter, 398 | name: "params", 399 | type: "__Internal.QueryParams", 400 | hasQuestionToken: true, 401 | }; 402 | 403 | return [ 404 | { 405 | kind: StructureKind.Method, 406 | name: "findOne", 407 | typeParameters, 408 | parameters: [filterParameter], 409 | statements: [ 410 
| `const result = await this.internalFindOne(params || {});`, 411 | `return result as Out | null`, 412 | ], 413 | isAsync: true, 414 | returnType: `Promise`, 415 | }, 416 | { 417 | kind: StructureKind.Method, 418 | name: "findOneOrThrow", 419 | typeParameters, 420 | parameters: [filterParameter], 421 | statements: [ 422 | `const result = await this.internalFindOneOrThrow(params || {});`, 423 | `return result as Out`, 424 | ], 425 | isAsync: true, 426 | returnType: `Promise`, 427 | }, 428 | { 429 | kind: StructureKind.Method, 430 | name: "findAll", 431 | typeParameters, 432 | parameters: [filterParameter], 433 | statements: [ 434 | `const result = await this.internalFindAll(params || {});`, 435 | `return result as Out[]`, 436 | ], 437 | isAsync: true, 438 | returnType: `Promise`, 439 | }, 440 | ]; 441 | } 442 | 443 | function modelGetTransitionMethods(model: Model): MethodDeclarationStructure[] { 444 | const transitionType = `__Internal.Transition<${model.pascalCaseName()}.AnyTransition, ${model.pascalCaseName()}.Transition>`; 445 | 446 | return [ 447 | { 448 | kind: StructureKind.Method, 449 | name: "getTransition", 450 | parameters: [ 451 | { 452 | kind: StructureKind.Parameter, 453 | name: "id", 454 | type: "string", 455 | }, 456 | ], 457 | statements: [ 458 | `const result = await this.getTransitionById(id);`, 459 | `return result as ${transitionType} | null`, 460 | ], 461 | isAsync: true, 462 | returnType: `Promise<${transitionType} | null>`, 463 | }, 464 | { 465 | kind: StructureKind.Method, 466 | name: "getObjectTransitions", 467 | parameters: [ 468 | { 469 | kind: StructureKind.Parameter, 470 | name: "object", 471 | type: `string | ${model.pascalCaseName()}.Any`, 472 | }, 473 | ], 474 | statements: [ 475 | "const id = typeof object == 'string' ? 
object : object.id;", 476 | `const result = await this.getTransitionsForObject(id);`, 477 | `return result as ${transitionType}[]`, 478 | ], 479 | isAsync: true, 480 | returnType: `Promise<${transitionType}[]>`, 481 | }, 482 | ]; 483 | } 484 | 485 | function modelDbClassTransitionMethod( 486 | model: Model, 487 | transition: Transition 488 | ): MethodDeclarationStructure { 489 | let parameterUnionTypes = ["__Internal.TransitionParameters"]; 490 | let requiresParameters = false; 491 | 492 | // From state is not defined for initializing transition 493 | const fromStates = transition.getFromStates(); 494 | if (fromStates) { 495 | const fromStatesType = 496 | fromStates == "*" 497 | ? `${model.pascalCaseName()}.Any` 498 | : fromStates 499 | .map( 500 | (state) => `${model.pascalCaseName()}.${state.pascalCaseName()}` 501 | ) 502 | .join(" | "); 503 | 504 | parameterUnionTypes.push( 505 | `__Internal.TransitionWithObject<${fromStatesType}>` 506 | ); 507 | requiresParameters = true; 508 | } 509 | 510 | if (transition.getFields().length != 0) { 511 | parameterUnionTypes.push( 512 | `__Internal.TransitionWithData<${model.pascalCaseName()}.${transition.pascalCaseName()}Data>` 513 | ); 514 | requiresParameters = true; 515 | } 516 | 517 | let parameters: ParameterDeclarationStructure[] = [ 518 | { 519 | kind: StructureKind.Parameter, 520 | name: "params", 521 | type: parameterUnionTypes.join(" & "), 522 | initializer: !requiresParameters ? "{}" : undefined, 523 | }, 524 | ]; 525 | 526 | const statements: string[] = []; 527 | if (fromStates) { 528 | statements.push( 529 | `const fn = async (object: any, transition: any) => await this.transitionImpls.${transition.camelCaseName()}(this.parent.withTriggeredBy(transition.id), object, transition);`, 530 | "const id = typeof params.object == 'string' ? 
params.object : params.object.id;", 531 | `const { updatedObject, updatedTransition } = await this.applyTransition(${model.pascalCaseName()}.__Meta.getTransitionMeta("${transition.pascalCaseName()}"), params, id, fn, this.parent.__triggeredBy);` 532 | ); 533 | } else { 534 | statements.push( 535 | `const fn = async (object: any, transition: any) => await this.transitionImpls.${transition.camelCaseName()}(this.parent.withTriggeredBy(transition.id), transition);`, 536 | `const { updatedObject, updatedTransition } = await this.applyTransition(${model.pascalCaseName()}.__Meta.getTransitionMeta("${transition.pascalCaseName()}"), params, undefined, fn, this.parent.__triggeredBy);` 537 | ); 538 | } 539 | 540 | const toStates = transition.getToStates(); 541 | const toStateType = 542 | toStates == "*" 543 | ? `${model.pascalCaseName()}.Any` 544 | : toStates 545 | .map((state) => `${model.pascalCaseName()}.${state.pascalCaseName()}`) 546 | .join(" | "); 547 | 548 | const transitionType = `${model.pascalCaseName()}.${transition.pascalCaseName()}`; 549 | 550 | statements.push( 551 | `return [updatedObject as ${toStateType}, updatedTransition as ${transitionType}];` 552 | ); 553 | 554 | return { 555 | kind: StructureKind.Method, 556 | name: transition.camelCaseName(), 557 | parameters: parameters, 558 | statements, 559 | isAsync: true, 560 | returnType: `Promise<[${toStateType}, ${transitionType}]>`, 561 | }; 562 | } 563 | 564 | function modelClientClassName(model: Model): string { 565 | return `${model.pascalCaseName()}Client`; 566 | } 567 | 568 | function modelTransitionsClassName(model: Model): string { 569 | return `${model.pascalCaseName()}TransitionsClient`; 570 | } 571 | -------------------------------------------------------------------------------- /src/generate/helpers.ts: -------------------------------------------------------------------------------- 1 | import { State } from "../ast"; 2 | 3 | export function statesToType(states: State[]): string { 4 | return 
states.map((state) => state.pascalCaseName()).join(" | "); 5 | } 6 | -------------------------------------------------------------------------------- /src/generate/index.ts: -------------------------------------------------------------------------------- 1 | import { Project } from "ts-morph"; 2 | import { Model } from "../ast"; 3 | import { generateDbFile } from "./client"; 4 | import { generateModelFile } from "./model"; 5 | 6 | export async function generate(models: Model[], outputDir: string) { 7 | const project = new Project({}); 8 | 9 | models.forEach(model => { 10 | project.createSourceFile(`${outputDir}//${model.pascalCaseName()}.ts`, { 11 | statements: generateModelFile(model), 12 | }, { 13 | overwrite: true, 14 | }); 15 | }) 16 | 17 | project.createSourceFile(`${outputDir}/index.ts`, { 18 | statements: generateDbFile(models), 19 | }, { 20 | overwrite: true 21 | }); 22 | 23 | await project.save(); 24 | } 25 | -------------------------------------------------------------------------------- /src/generate/model.ts: -------------------------------------------------------------------------------- 1 | import { toSnakeCase } from "js-convert-case"; 2 | import { 3 | EnumDeclarationStructure, 4 | EnumMemberStructure, 5 | FunctionDeclarationStructure, 6 | ImportDeclarationStructure, 7 | InterfaceDeclarationStructure, 8 | MethodSignatureStructure, 9 | ParameterDeclarationStructure, 10 | PropertySignatureStructure, 11 | StatementStructures, 12 | StructureKind, 13 | TypeAliasDeclarationStructure, 14 | VariableDeclarationKind, 15 | VariableStatementStructure, 16 | WriterFunction, 17 | Writers, 18 | } from "ts-morph"; 19 | import { Model, Transition, State, Field } from "../ast"; 20 | import { statesToType } from "./helpers"; 21 | 22 | export function generateModelFile(model: Model): StatementStructures[] { 23 | return [ 24 | ...imports(), 25 | anyType(model), 26 | anyTransitionType(model), 27 | stateEnum(model), 28 | transitionEnum(model), 29 | ...stateTypes(model), 30 
| ...transitionTypes(model), 31 | transitionInterface(model), 32 | ...queryTypes(model), 33 | modelMeta(model), 34 | createConsumerFunction(model), 35 | ]; 36 | } 37 | 38 | function imports(): ImportDeclarationStructure[] { 39 | return [ 40 | { 41 | kind: StructureKind.ImportDeclaration, 42 | namespaceImport: "__Internal", 43 | moduleSpecifier: "../internal", 44 | }, 45 | { 46 | kind: StructureKind.ImportDeclaration, 47 | namedImports: ["RestateClient"], 48 | moduleSpecifier: "./", 49 | }, 50 | ]; 51 | } 52 | 53 | function anyType(model: Model): TypeAliasDeclarationStructure { 54 | const states = model.getStates(); 55 | 56 | let type: string | WriterFunction; 57 | if (states.length == 1) { 58 | type = states[0].pascalCaseName(); 59 | } else { 60 | const [state1, state2, ...restStates] = model.getStates(); 61 | type = Writers.unionType( 62 | state1.pascalCaseName(), 63 | state2.pascalCaseName(), 64 | ...restStates.map((state) => state.pascalCaseName()) 65 | ); 66 | } 67 | 68 | return { 69 | name: "Any", 70 | kind: StructureKind.TypeAlias, 71 | type, 72 | isExported: true, 73 | }; 74 | } 75 | 76 | function anyTransitionType(model: Model): TypeAliasDeclarationStructure { 77 | const transitions = model.getTransitions(); 78 | 79 | let type: string | WriterFunction; 80 | if (transitions.length == 1) { 81 | type = transitions[0].pascalCaseName(); 82 | } else { 83 | const [transition1, transition2, ...restTransitions] = 84 | model.getTransitions(); 85 | type = Writers.unionType( 86 | transition1.pascalCaseName(), 87 | transition2.pascalCaseName(), 88 | ...restTransitions.map((transition) => transition.pascalCaseName()) 89 | ); 90 | } 91 | 92 | return { 93 | name: "AnyTransition", 94 | kind: StructureKind.TypeAlias, 95 | type, 96 | isExported: true, 97 | }; 98 | } 99 | 100 | function stateEnum(model: Model): EnumDeclarationStructure { 101 | const members: EnumMemberStructure[] = model.getStates().map((state) => ({ 102 | kind: StructureKind.EnumMember, 103 | name: 
state.pascalCaseName(), 104 | initializer: `"${toSnakeCase(state.pascalCaseName())}"`, 105 | })); 106 | 107 | return { 108 | kind: StructureKind.Enum, 109 | name: "State", 110 | isExported: true, 111 | members, 112 | }; 113 | } 114 | 115 | function transitionEnum(model: Model): EnumDeclarationStructure { 116 | const members: EnumMemberStructure[] = model 117 | .getTransitions() 118 | .map((transition) => ({ 119 | kind: StructureKind.EnumMember, 120 | name: transition.pascalCaseName(), 121 | initializer: `"${transition.snakeCaseName()}"`, 122 | })); 123 | 124 | return { 125 | kind: StructureKind.Enum, 126 | name: "Transition", 127 | isExported: true, 128 | members, 129 | }; 130 | } 131 | 132 | function stateTypes(model: Model): InterfaceDeclarationStructure[] { 133 | return model.getStates().flatMap((state) => { 134 | const modelTypeProperties = Array.from(state.getFields().values()).map( 135 | (field) => { 136 | const typeScriptType = field.getType().getTypescriptType(); 137 | return { 138 | name: field.camelCaseName(), 139 | type: field.getType().canBeNull() 140 | ? `${typeScriptType} | null` 141 | : typeScriptType, 142 | }; 143 | } 144 | ); 145 | 146 | const modelType: InterfaceDeclarationStructure = { 147 | name: state.pascalCaseName(), 148 | kind: StructureKind.Interface, 149 | properties: [ 150 | { 151 | name: "id", 152 | type: "string", 153 | }, 154 | { 155 | name: "state", 156 | type: `State.${state.pascalCaseName()}`, 157 | }, 158 | ...modelTypeProperties, 159 | ], 160 | isExported: true, 161 | }; 162 | 163 | const modelDataTypeProperties = Array.from(state.getFields().values()).map( 164 | (field) => { 165 | const typeScriptType = field.getType().getTypescriptType(); 166 | return { 167 | name: field.camelCaseName(), 168 | type: field.getType().canBeNull() 169 | ? 
`${typeScriptType} | null` 170 | : typeScriptType, 171 | hasQuestionToken: field.getType().canBeNull(), 172 | }; 173 | } 174 | ); 175 | 176 | const modelDataType: InterfaceDeclarationStructure = { 177 | name: `${state.pascalCaseName()}Data`, 178 | kind: StructureKind.Interface, 179 | properties: [ 180 | { 181 | name: "state", 182 | type: `State.${state.pascalCaseName()}`, 183 | }, 184 | ...modelDataTypeProperties, 185 | ], 186 | isExported: true, 187 | }; 188 | 189 | return [modelType, modelDataType]; 190 | }); 191 | } 192 | 193 | function transitionTypes( 194 | model: Model 195 | ): (InterfaceDeclarationStructure | TypeAliasDeclarationStructure)[] { 196 | return model.getTransitions().flatMap((transition) => { 197 | const transitionType: InterfaceDeclarationStructure = { 198 | kind: StructureKind.Interface, 199 | name: transition.pascalCaseName(), 200 | isExported: true, 201 | properties: [ 202 | { 203 | name: "id", 204 | type: "string", 205 | }, 206 | { 207 | name: "objectId", 208 | type: "string", 209 | }, 210 | { 211 | name: "model", 212 | type: `"${model.pascalCaseName()}"`, 213 | }, 214 | { 215 | name: "type", 216 | type: `Transition.${transition.pascalCaseName()}`, 217 | }, 218 | { 219 | name: "from", 220 | type: `State | null`, 221 | }, 222 | { 223 | name: "to", 224 | type: `State`, 225 | }, 226 | { 227 | name: "data", 228 | type: `${transition.pascalCaseName()}Data`, 229 | }, 230 | { 231 | name: "note", 232 | type: "string | null", 233 | }, 234 | { 235 | name: "triggeredBy", 236 | type: "string | null", 237 | }, 238 | { 239 | name: "appliedAt", 240 | type: "Date", 241 | }, 242 | ], 243 | }; 244 | 245 | const dataProperties = transition.getFields().map((field) => { 246 | return { 247 | name: field.camelCaseName(), 248 | type: field.getType().getTypescriptType(), 249 | hasQuestionToken: field.getType().canBeNull(), 250 | }; 251 | }); 252 | 253 | const dataType: InterfaceDeclarationStructure = { 254 | name: `${transition.pascalCaseName()}Data`, 255 | kind: 
StructureKind.Interface, 256 | properties: [...dataProperties], 257 | isExported: true, 258 | }; 259 | 260 | return [transitionType, dataType]; 261 | }); 262 | } 263 | 264 | function transitionInterface(model: Model): InterfaceDeclarationStructure { 265 | const methods = model 266 | .getTransitions() 267 | .map((transition) => transitionMethodSignature(model, transition)); 268 | 269 | return { 270 | name: `TransitionImpl`, 271 | kind: StructureKind.Interface, 272 | methods, 273 | isExported: true, 274 | }; 275 | } 276 | 277 | function transitionMethodSignature( 278 | model: Model, 279 | transition: Transition 280 | ): MethodSignatureStructure { 281 | let parameters: ParameterDeclarationStructure[] = [ 282 | { 283 | kind: StructureKind.Parameter, 284 | name: "client", 285 | type: "RestateClient", 286 | }, 287 | ]; 288 | 289 | const fromStates = transition.getFromStates(); 290 | if (fromStates) { 291 | const parameterType = 292 | fromStates == "*" ? model.pascalCaseName() : statesToType(fromStates); 293 | parameters.push({ 294 | kind: StructureKind.Parameter, 295 | name: "original", 296 | type: parameterType, 297 | }); 298 | } 299 | 300 | parameters.push({ 301 | kind: StructureKind.Parameter, 302 | name: "transition", 303 | type: `${transition.pascalCaseName()}`, 304 | }); 305 | 306 | const toStates = transition.getToStates(); 307 | const returnType = 308 | toStates == "*" ? 
`${model.pascalCaseName()}.Any` : statesToType(toStates); 309 | 310 | return { 311 | kind: StructureKind.MethodSignature, 312 | name: transition.camelCaseName(), 313 | parameters: parameters, 314 | returnType: `Promise<${returnType}Data>`, 315 | }; 316 | } 317 | 318 | function queryTypes(model: Model): StatementStructures[] { 319 | const flattenedFields: Map = new Map(); 320 | for (const state of model.getStates()) { 321 | for (const field of state.getFields()) { 322 | flattenedFields.set(field.pascalCaseName(), field); 323 | } 324 | } 325 | 326 | const properties: PropertySignatureStructure[] = Array.from( 327 | flattenedFields.values() 328 | ).map((field) => { 329 | return { 330 | kind: StructureKind.PropertySignature, 331 | name: field.camelCaseName(), 332 | type: field.getType().getTypescriptType(), 333 | hasQuestionToken: true, 334 | }; 335 | }); 336 | 337 | const queryFilter: InterfaceDeclarationStructure = { 338 | name: "QueryFilter", 339 | kind: StructureKind.Interface, 340 | properties: [ 341 | { 342 | name: "id", 343 | type: "string", 344 | hasQuestionToken: true, 345 | }, 346 | { 347 | name: "state", 348 | type: "State | State[]", 349 | hasQuestionToken: true, 350 | }, 351 | ...properties, 352 | ], 353 | isExported: true, 354 | }; 355 | 356 | let enumToStateMappings: string[] = []; 357 | for (const state of model.getStates()) { 358 | enumToStateMappings.push( 359 | `E extends State.${state.pascalCaseName()} ? ${state.pascalCaseName()} :` 360 | ); 361 | } 362 | const enumToStateType: TypeAliasDeclarationStructure = { 363 | kind: StructureKind.TypeAlias, 364 | name: "EnumToState", 365 | typeParameters: ["E"], 366 | type: ` 367 | E extends any ? 
( 368 | ${enumToStateMappings.join("\n")} 369 | never 370 | ) : never 371 | `, 372 | }; 373 | 374 | const resultFromFilter: TypeAliasDeclarationStructure = { 375 | kind: StructureKind.TypeAlias, 376 | isExported: true, 377 | name: "ResultFromQueryFilter", 378 | typeParameters: ["T extends QueryFilter", "S"], 379 | type: ` 380 | T["state"] extends S ? EnumToState : 381 | T["state"] extends S[] ? EnumToState<__Internal.ArrayElementType> : 382 | Any 383 | `, 384 | }; 385 | 386 | let enumToTransitionMappings: string[] = []; 387 | for (const transition of model.getTransitions()) { 388 | enumToTransitionMappings.push( 389 | `E extends Transition.${transition.pascalCaseName()} ? ${transition.pascalCaseName()} :` 390 | ); 391 | } 392 | 393 | const enumToTransitionType: TypeAliasDeclarationStructure = { 394 | kind: StructureKind.TypeAlias, 395 | name: "EnumToTransition", 396 | typeParameters: ["E"], 397 | type: ` 398 | E extends any ? ( 399 | ${enumToTransitionMappings.join("\n")} 400 | never 401 | ) : never 402 | `, 403 | }; 404 | 405 | const transitionFromFilter: TypeAliasDeclarationStructure = { 406 | kind: StructureKind.TypeAlias, 407 | isExported: true, 408 | name: "TransitionFromFilter", 409 | typeParameters: ["T", "S"], 410 | type: ` 411 | T extends S ? EnumToTransition : 412 | T extends S[] ? 
EnumToTransition<__Internal.ArrayElementType> : 413 | AnyTransition 414 | `, 415 | }; 416 | 417 | return [ 418 | queryFilter, 419 | enumToStateType, 420 | resultFromFilter, 421 | enumToTransitionType, 422 | transitionFromFilter, 423 | ]; 424 | } 425 | 426 | function modelMeta(model: Model): VariableStatementStructure { 427 | const statesInitializer = 428 | "[" + 429 | model 430 | .getStates() 431 | .map((state) => { 432 | const fieldsInitializer = 433 | "[" + state.getFields().map(fieldMetaInitializer).join(", ") + "]"; 434 | return `new __Internal.StateMeta("${state.pascalCaseName()}", ${fieldsInitializer})`; 435 | }) 436 | .join(", ") + 437 | "]"; 438 | 439 | const transitionsInitializer = 440 | "[" + 441 | model 442 | .getTransitions() 443 | .map((transition) => { 444 | const fieldsConstant = 445 | "[" + 446 | transition.getFields().map(fieldMetaInitializer).join(", ") + 447 | "]"; 448 | 449 | let fromStatesDefinition: string = "undefined"; 450 | const fromStates = transition.getFromStates(); 451 | if (fromStates) { 452 | let fromStatesExpanded: State[] = []; 453 | if (fromStates == "*") { 454 | fromStatesExpanded = model.getStates(); 455 | } else { 456 | fromStatesExpanded = fromStates; 457 | } 458 | const fromStatesJoined = fromStatesExpanded 459 | .map((state) => `"${state.pascalCaseName()}"`) 460 | .join(); 461 | fromStatesDefinition = `[${fromStatesJoined}]`; 462 | } 463 | 464 | let toStatesExpanded: State[] = []; 465 | const toStates = transition.getToStates(); 466 | if (toStates == "*") { 467 | toStatesExpanded = model.getStates(); 468 | } else { 469 | toStatesExpanded = toStates; 470 | } 471 | const toStatesJoined = toStatesExpanded 472 | .map((state) => `"${state.pascalCaseName()}"`) 473 | .join(); 474 | const toStatesDefinition = `[${toStatesJoined}]`; 475 | 476 | return `new __Internal.TransitionMeta("${transition.pascalCaseName()}", ${fieldsConstant}, ${fromStatesDefinition}, ${toStatesDefinition})`; 477 | }) 478 | .join(", ") + 479 | "]"; 480 | 481 
| const metaObjectInitializer = `new __Internal.ModelMeta("${model.pascalCaseName()}", "${model.getPrefix()}", ${statesInitializer}, ${transitionsInitializer})`; 482 | 483 | return { 484 | kind: StructureKind.VariableStatement, 485 | declarationKind: VariableDeclarationKind.Const, 486 | declarations: [ 487 | { 488 | name: `__Meta`, 489 | type: "__Internal.ModelMeta", 490 | initializer: metaObjectInitializer, 491 | }, 492 | ], 493 | isExported: true, 494 | }; 495 | } 496 | 497 | function fieldMetaInitializer(field: Field): string { 498 | return `new __Internal.FieldMeta("${field.camelCaseName()}", ${field 499 | .getType() 500 | .initializer()})`; 501 | } 502 | 503 | function createConsumerFunction(model: Model): FunctionDeclarationStructure { 504 | return { 505 | kind: StructureKind.Function, 506 | name: "createConsumer", 507 | typeParameters: ["T extends Transition | Transition[]"], 508 | isExported: true, 509 | parameters: [ 510 | { 511 | name: "consumer", 512 | type: "{ name: string; transition: T; handler: (client: RestateClient, model: Any, transition: TransitionFromFilter) => Promise }", 513 | }, 514 | ], 515 | statements: [ 516 | "const { name, transition, handler } = consumer", 517 | "const arrayifiedTransitions: string | string[] = Array.isArray(transition) ? 
transition : [transition]", 518 | "return new __Internal.Consumer(name, __Meta, arrayifiedTransitions, handler)", 519 | ], 520 | }; 521 | } 522 | -------------------------------------------------------------------------------- /src/index.ts: -------------------------------------------------------------------------------- 1 | export * from "./generated"; 2 | -------------------------------------------------------------------------------- /src/internal/client.ts: -------------------------------------------------------------------------------- 1 | import { toCamelCase, toPascalCase } from "js-convert-case"; 2 | import { Db } from "./db"; 3 | import { generateId, generateTransactionId } from "./id"; 4 | import { ModelMeta, TransitionMeta } from "./meta"; 5 | import BaseObject from "./object"; 6 | import Transition from "./transition"; 7 | import { QueryParams } from "./types"; 8 | 9 | export abstract class BaseClient { 10 | constructor(protected db: Db, protected modelMeta: ModelMeta) {} 11 | 12 | protected async internalFindAll(params: QueryParams): Promise { 13 | return await this.query(params); 14 | } 15 | 16 | protected async internalFindOne( 17 | params: QueryParams 18 | ): Promise { 19 | const results = await this.query(params); 20 | if (results.length == 0) { 21 | return null; 22 | } 23 | 24 | return results[0]; 25 | } 26 | 27 | protected async internalFindOneOrThrow( 28 | params: QueryParams 29 | ): Promise { 30 | const results = await this.query(params); 31 | if (results.length == 0) { 32 | throw new Error("no object found"); 33 | } 34 | 35 | return results[0]; 36 | } 37 | 38 | protected async query(params: QueryParams): Promise { 39 | return await this.db.query(this.modelMeta, params?.where, params?.limit); 40 | } 41 | 42 | protected async getTransitionById( 43 | id: string 44 | ): Promise | null> { 45 | return await this.db.getTransitionById(id); 46 | } 47 | 48 | protected async getTransitionsForObject( 49 | objectId: string 50 | ): Promise[]> { 51 | return 
await this.db.getTransitionsForObject(objectId); 52 | } 53 | } 54 | 55 | export abstract class BaseTransitionsClient { 56 | constructor(protected db: Db, protected modelMeta: ModelMeta) {} 57 | 58 | protected async applyTransition( 59 | transitionMeta: TransitionMeta, 60 | transitionParams: any, 61 | existingObjectId: string | undefined, 62 | transitionFn: (object: any, transition: any) => Promise, 63 | triggeredBy: string | null 64 | ): Promise<{ 65 | updatedTransition: Transition; 66 | updatedObject: Object; 67 | }> { 68 | let object = undefined; 69 | if (existingObjectId) { 70 | object = await this.db.getById(this.modelMeta, existingObjectId); 71 | } 72 | 73 | let objectId = existingObjectId; 74 | if (objectId == undefined) { 75 | objectId = generateId(this.modelMeta.prefix); 76 | } 77 | 78 | // Set up the transition object 79 | const transition: Transition = { 80 | id: generateTransactionId(), 81 | model: this.modelMeta.pascalCaseName(), 82 | type: transitionMeta.pascalCaseNmae(), 83 | from: object?.state || null, 84 | to: "", 85 | objectId: objectId, 86 | data: transitionParams.data, 87 | note: transitionParams.note, 88 | triggeredBy, 89 | appliedAt: new Date(), 90 | }; 91 | 92 | // Apply the transition implementation (from the project config) 93 | const transitionedObject = (await transitionFn( 94 | object, 95 | transition 96 | )) as BaseObject; 97 | transitionedObject.id = objectId; 98 | transition.to = transitionedObject.state; 99 | 100 | // Validate field values on object 101 | const stateMeta = this.modelMeta.getStateMetaBySerializedName( 102 | transitionedObject.state 103 | ); 104 | for (const field of stateMeta.allFieldMetas()) { 105 | const dataType = field.type; 106 | const value = (transitionedObject as any)[field.camelCaseName()]; 107 | dataType.validate(value); 108 | } 109 | 110 | await this.db.applyTransition( 111 | this.modelMeta, 112 | transitionMeta, 113 | transition, 114 | transitionedObject 115 | ); 116 | 117 | return { 118 | 
updatedTransition: transition, 119 | updatedObject: transitionedObject, 120 | }; 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /src/internal/config.ts: -------------------------------------------------------------------------------- 1 | import { readFileSync, existsSync } from "fs"; 2 | import merge from "deepmerge"; 3 | import JSON5 from "json5"; 4 | 5 | export interface Config { 6 | database: DatabaseConfig; 7 | } 8 | 9 | export function loadConfig(environmentOverride?: string): Config { 10 | const env = getEnv() || environmentOverride || "production"; 11 | 12 | // If a config file exists, we load that and merge it with the default config 13 | let fileConfig = {}; 14 | if (existsSync("restate.config.json")) { 15 | const configFileContents = readFileSync("restate.config.json").toString(); 16 | const configFromFile = JSON5.parse(configFileContents) || {}; 17 | const environmentConfigFromFile = configFromFile[env] || {}; 18 | 19 | fileConfig = merge(configFromFile, environmentConfigFromFile); 20 | } 21 | 22 | const environmentDefaultConfig = environmentSpecificDefaultConfigs[env] || {}; 23 | const defaultConfig = merge(baseDefaultConfig, environmentDefaultConfig); 24 | 25 | return merge(defaultConfig, fileConfig); 26 | } 27 | 28 | export function getEnv(): string | undefined { 29 | return process.env.NODE_ENV; 30 | } 31 | 32 | export type DatabaseConfig = SqliteDatabaseConfig | PostgresDatabaseConfig; 33 | 34 | export type SqliteDatabaseConfig = { 35 | type: "sqlite"; 36 | file: string; 37 | }; 38 | 39 | export type PostgresDatabaseConfig = { 40 | type: "postgres"; 41 | connection_string: string; 42 | }; 43 | 44 | const baseDefaultConfig: Config = { 45 | database: { 46 | type: "postgres", 47 | connection_string: "postgres://postgres:@localhost:5432/postgres", 48 | }, 49 | }; 50 | 51 | const environmentSpecificDefaultConfigs: { 52 | [environment: string]: Config; 53 | } = { 54 | development: { 55 | database: { 
56 | type: "sqlite", 57 | file: "restate.sqlite", 58 | }, 59 | }, 60 | }; 61 | -------------------------------------------------------------------------------- /src/internal/consumer.ts: -------------------------------------------------------------------------------- 1 | import { ModelMeta } from "./meta"; 2 | 3 | // Should never be instantiated directly. Use the generated helper functions on each model instead. 4 | export default class Consumer { 5 | constructor( 6 | public name: string, 7 | public model: ModelMeta, 8 | public transitions: string[], 9 | public handler: ( 10 | client: any, 11 | updatedObject: any, 12 | transition: any 13 | ) => Promise 14 | ) {} 15 | } 16 | 17 | export enum TaskState { 18 | Created = "created", 19 | Completed = "completed", 20 | } 21 | 22 | export interface Task { 23 | id: string; 24 | transitionId: string; 25 | consumer: string; 26 | state: TaskState; 27 | } 28 | -------------------------------------------------------------------------------- /src/internal/dataTypes.ts: -------------------------------------------------------------------------------- 1 | import * as Parser from "../parser"; 2 | 3 | export interface DataType { 4 | validate(value: any): void; 5 | getTypescriptType(): string; 6 | canBeNull(): boolean; 7 | initializer(): string; 8 | } 9 | 10 | export function dataTypeFromParsed(type: Parser.Type): DataType { 11 | switch (type.name) { 12 | case "String": 13 | return new String(); 14 | case "Int": 15 | return new Int(); 16 | case "Decimal": 17 | return new Decimal(); 18 | case "Optional": 19 | const nested = dataTypeFromParsed(type.nested as Parser.Type); 20 | return new Optional(nested); 21 | case "Bool": 22 | return new Bool(); 23 | default: 24 | throw new Error(`unsupported type "${type.name}"`); 25 | } 26 | } 27 | 28 | export class String implements DataType { 29 | validate(value: any) { 30 | if (typeof value !== "string") { 31 | throw new Error("not a string"); 32 | } 33 | } 34 | 35 | getTypescriptType(): string { 
36 | return "string"; 37 | } 38 | 39 | canBeNull(): boolean { 40 | return false; 41 | } 42 | 43 | initializer(): string { 44 | return `new __Internal.String()`; 45 | } 46 | } 47 | 48 | export class Int implements DataType { 49 | validate(value: any) { 50 | if (typeof value !== "number" || !Number.isInteger(value)) { 51 | throw new Error("not an integer"); 52 | } 53 | } 54 | 55 | getTypescriptType(): string { 56 | return "number"; 57 | } 58 | 59 | canBeNull(): boolean { 60 | return false; 61 | } 62 | 63 | initializer(): string { 64 | return `new __Internal.Int()`; 65 | } 66 | } 67 | 68 | export class Decimal implements DataType { 69 | validate(value: any) { 70 | if (typeof value !== "number") { 71 | throw new Error("not a number"); 72 | } 73 | } 74 | 75 | getTypescriptType(): string { 76 | return "number"; 77 | } 78 | 79 | canBeNull(): boolean { 80 | return false; 81 | } 82 | 83 | initializer(): string { 84 | return `new __Internal.Decimal()`; 85 | } 86 | } 87 | 88 | export class Optional implements DataType { 89 | constructor(private nestedType: DataType) {} 90 | 91 | validate(value: any) { 92 | if (value !== undefined && value !== null) { 93 | this.nestedType.validate(value); 94 | } 95 | } 96 | 97 | getTypescriptType(): string { 98 | return this.nestedType.getTypescriptType(); 99 | } 100 | 101 | canBeNull(): boolean { 102 | return true; 103 | } 104 | 105 | initializer(): string { 106 | return `new __Internal.Optional(${this.nestedType.initializer()})`; 107 | } 108 | 109 | getNestedType(): DataType { 110 | return this.nestedType; 111 | } 112 | } 113 | 114 | export class Bool implements DataType { 115 | validate(value: any) { 116 | if (typeof value !== "boolean") { 117 | throw new Error("not a boolean"); 118 | } 119 | } 120 | 121 | getTypescriptType(): string { 122 | return "boolean"; 123 | } 124 | 125 | canBeNull(): boolean { 126 | return false; 127 | } 128 | 129 | initializer(): string { 130 | return `new __Internal.Bool()`; 131 | } 132 | } 133 | 
// --------------------------------------------------------------------------------
// /src/internal/db/PostgresDb.ts
// --------------------------------------------------------------------------------
import knex, { Knex } from "knex";
import { Db } from ".";
import { PostgresDatabaseConfig } from "../config";
import { FieldMeta, ModelMeta, ProjectMeta, TransitionMeta } from "../meta";
import BaseObject from "../object";
import Transition from "../transition";
import { Task, TaskState } from "../consumer";
import { String, Int, DataType, Decimal, Optional, Bool } from "../dataTypes";
import { toSnakeCase } from "js-convert-case";

const TRANSITIONS_TABLE = "transitions";
const TASKS_TABLE = "tasks";

/**
 * Postgres-backed implementation of the `Db` interface, built on knex.
 *
 * Objects live in one table per model (see `tableName`). Every applied
 * transition is recorded in the shared `transitions` table, and consumer work
 * items live in the shared `tasks` table.
 */
export default class PostgresDb implements Db {
  constructor(private projectMeta: ProjectMeta, private db: Knex) {}

  /** Opens a knex Postgres connection using `config.connection_string`. */
  static fromConfig(
    projectMeta: ProjectMeta,
    config: PostgresDatabaseConfig
  ): PostgresDb {
    const db = knex({
      client: "postgres",
      connection: config.connection_string,
    });
    return new PostgresDb(projectMeta, db);
  }

  /**
   * Runs `fn` inside a knex transaction.
   *
   * `fn` receives a PostgresDb bound to the transaction, so all of its queries
   * share it. Per knex's transaction semantics, the transaction commits when
   * `fn` resolves and rolls back when it rejects.
   */
  async transaction(fn: (db: PostgresDb) => Promise<void>): Promise<void> {
    await this.db.transaction(async (txn) => {
      await fn(new PostgresDb(this.projectMeta, txn));
    });
  }

  /** Destroys the underlying knex connection pool. */
  async close(): Promise<void> {
    await this.db.destroy();
  }

  /** Creates the shared `transitions` and `tasks` tables if they are missing. */
  async setup(): Promise<void> {
    if (!(await this.db.schema.hasTable(TRANSITIONS_TABLE))) {
      await this.db.schema.createTable(TRANSITIONS_TABLE, (table) => {
        table.text("id");
        table.text("model").notNullable();
        table.text("type").notNullable();
        table.text("from");
        table.text("to").notNullable();
        table.text("object_id").notNullable();
        table.jsonb("data");
        table.text("note");
        table.text("triggered_by");
        table.datetime("applied_at").notNullable();
      });
    }

    if (!(await this.db.schema.hasTable(TASKS_TABLE))) {
      await this.db.schema.createTable(TASKS_TABLE, (table) => {
        table.text("id");
        table.text("transition_id").notNullable();
        table.text("consumer").notNullable();
        table.text("state").notNullable();
      });
    }
  }

  /**
   * Creates a table for every model that does not have one yet.
   *
   * The columns are the union of the fields of all the model's states. A column
   * is NOT NULL only when its field appears in every state and is not optional.
   * Existing tables are left untouched: no columns are added or altered.
   */
  async migrate(): Promise<void> {
    for (const modelMeta of this.projectMeta.allModelMetas()) {
      const modelTable = tableName(modelMeta);
      if (await this.db.schema.hasTable(modelTable)) {
        continue;
      }

      await this.db.schema.createTable(modelTable, (table) => {
        // Several states may declare the same field; add its column only once.
        const addedFields = new Set<string>();
        table.text("id").primary();
        table.text("state").notNullable();

        for (const state of modelMeta.allStateMetas()) {
          for (const field of state.allFieldMetas()) {
            if (addedFields.has(field.camelCaseName())) {
              continue;
            }
            addedFields.add(field.camelCaseName());

            const type = this.typeToPostgresType(field.type);
            const builder = table.specificType(columnName(field), type);

            if (
              modelMeta.doesFieldAppearInAllStates(field) &&
              !field.type.canBeNull()
            ) {
              builder.notNullable();
            }
          }
        }
      });
    }
  }

  /**
   * Maps a Restate data type to a Postgres column type.
   * Reference: https://www.postgresql.org/docs/current/datatype.html
   *
   * @throws Error for data types without a Postgres mapping.
   */
  private typeToPostgresType(type: DataType): string {
    if (type instanceof String) {
      return "text";
    } else if (type instanceof Int) {
      return "integer";
    } else if (type instanceof Decimal) {
      return "decimal";
    } else if (type instanceof Bool) {
      return "boolean";
    } else if (type instanceof Optional) {
      // Nullability is handled by the caller; only the inner type matters here.
      return this.typeToPostgresType(type.getNestedType());
    } else {
      throw new Error(`unrecognized data type ${type.constructor.name}`);
    }
  }

  /**
   * Persists the post-transition `object` and records `transition`.
   *
   * A transition whose meta has no from-states is an initializing transition,
   * so the object row is inserted. Otherwise the existing row with `object.id`
   * is updated. Only the columns of the target state's fields are written.
   * Columns belonging to other states keep their previous values.
   */
  async applyTransition(
    modelMeta: ModelMeta,
    transitionMeta: TransitionMeta,
    transition: Transition<any>,
    object: BaseObject
  ): Promise<void> {
    const targetStateMeta = modelMeta.getStateMetaBySerializedName(
      object.state
    );
    const isNewObject = transitionMeta.fromStates == undefined;

    const data: { [key: string]: any } = {
      id: object.id,
      state: object.state,
    };
    for (const field of targetStateMeta.allFieldMetas()) {
      data[columnName(field)] = (object as any)[field.camelCaseName()];
    }

    const table = this.db(tableName(modelMeta));
    if (isNewObject) {
      await table.insert(data);
    } else {
      await table.where("id", object.id).update(data);
    }

    await this.insertTransition(transition);
  }

  /** Inserts one row into the `transitions` table. */
  private async insertTransition(transition: Transition<any>): Promise<void> {
    await this.db(TRANSITIONS_TABLE).insert({
      id: transition.id,
      object_id: transition.objectId,
      model: transition.model,
      type: transition.type,
      from: transition.from,
      to: transition.to,
      data: transition.data,
      note: transition.note,
      triggered_by: transition.triggeredBy,
      applied_at: transition.appliedAt,
    });
  }

  /**
   * Loads one object by ID.
   *
   * @throws Error when no row with `id` exists in the model's table.
   */
  async getById(model: ModelMeta, id: string): Promise<any> {
    const row = await this.db.table(tableName(model)).where("id", id).first();
    if (!row) {
      throw new Error(`No object found with ID ${id}`);
    }
    return rowToObject(model, row);
  }

  /**
   * Finds objects of `model`.
   *
   * `where` keys are camelCase field names, which are converted to snake_case
   * columns. An array value becomes `IN (...)`; any other value becomes
   * equality. A falsy `limit` (undefined or 0) means no limit.
   */
  async query(
    model: ModelMeta,
    where?: { [key: string]: any },
    limit?: number
  ): Promise<any[]> {
    let query = this.db.table(tableName(model));

    if (where) {
      for (const [key, value] of Object.entries(where)) {
        const column = toSnakeCase(key);
        if (Array.isArray(value)) {
          query = query.whereIn(column, value);
        } else {
          query = query.where(column, value);
        }
      }
    }

    if (limit) {
      query = query.limit(limit);
    }

    const rows = await query.select("*");
    return rows.map((row: Record<string, any>) => rowToObject(model, row));
  }

  /** Inserts one row into the `tasks` table. */
  async insertTask(task: Task): Promise<void> {
    await this.db(TASKS_TABLE).insert({
      id: task.id,
      transition_id: task.transitionId,
      consumer: task.consumer,
      state: task.state,
    });
  }

  /** Updates only the `state` of the task with `task.id`. */
  async updateTask(task: Task): Promise<void> {
    await this.db(TASKS_TABLE).where("id", task.id).update({
      state: task.state,
    });
  }

  /** Returns all tasks created for the given transition. */
  async getTasksForTransition(transitionId: string): Promise<Task[]> {
    const rows = await this.db
      .table(TASKS_TABLE)
      .where("transition_id", transitionId)
      .select("*");
    return rows.map(rowToTask);
  }

  /** Returns the transition with `id`, or `null` when there is none. */
  async getTransitionById(id: string): Promise<Transition<any> | null> {
    const row = await this.db.table(TRANSITIONS_TABLE).where("id", id).first();
    return row ? rowToTransition(row) : null;
  }

  /**
   * Returns all transitions of one object, ordered by transition ID
   * (descending).
   */
  async getTransitionsForObject(
    objectId: string
  ): Promise<Transition<any>[]> {
    const rows = await this.db
      .table(TRANSITIONS_TABLE)
      .where("object_id", objectId)
      .orderBy("id", "desc")
      .select("*");
    return rows.map(rowToTransition);
  }

  /**
   * Returns up to `limit` tasks still in the `created` state.
   *
   * Uses `FOR UPDATE SKIP LOCKED`, so rows already locked by another
   * transaction are skipped. Postgres holds these locks only until the
   * enclosing transaction ends. NOTE(review): this is presumably meant to be
   * called through `transaction()`; confirm at the call site.
   */
  async getUnprocessedTasks(limit: number): Promise<Task[]> {
    const rows = await this.db
      .table(TASKS_TABLE)
      .where("state", "=", TaskState.Created)
      .limit(limit)
      .forUpdate()
      .skipLocked()
      .select("*");
    return rows.map(rowToTask);
  }

  /** Marks the task with `taskId` as completed. */
  async setTaskProcessed(taskId: string): Promise<void> {
    await this.db.table(TASKS_TABLE).where("id", "=", taskId).update({
      state: TaskState.Completed,
    });
  }
}

/** Table holding the objects of `model`. */
function tableName(model: ModelMeta): string {
  return model.pluralSnakeCaseName();
}

/** Column holding `field` in its model's table. */
function columnName(field: FieldMeta): string {
  return field.snakeCaseName();
}

/**
 * Converts a model-table row into a plain object with `id`, `state`, and the
 * camelCase fields declared for the row's state. Columns of other states are
 * ignored.
 */
function rowToObject(
  model: ModelMeta,
  row: Record<string, any>
): { [key: string]: any } {
  const stateMeta = model.getStateMetaBySerializedName(row.state);
  const data: { [key: string]: any } = {
    id: row.id,
    state: row.state,
  };
  for (const field of stateMeta.allFieldMetas()) {
    data[field.camelCaseName()] = row[columnName(field)];
  }
  return data;
}

/** Converts a `transitions` row (snake_case columns) into a Transition. */
function rowToTransition(row: Record<string, any>): Transition<any> {
  return {
    id: row.id,
    objectId: row.object_id,
    model: row.model,
    type: row.type,
    from: row.from,
    to: row.to,
    data: row.data,
    note: row.note,
    triggeredBy: row.triggered_by,
    appliedAt: row.applied_at,
  };
}

/** Converts a `tasks` row (snake_case columns) into a Task. */
function rowToTask(row: Record<string, any>): Task {
  return {
    id: row.id,
    transitionId: row.transition_id,
    state: row.state,
    consumer: row.consumer,
  };
}
--------------------------------------------------------------------------------
/src/internal/db/SqliteDb.ts:
--------------------------------------------------------------------------------
// NOTE(review): the dump this file was recovered from stripped every `<...>`
// span; the generic arguments in this section were restored from context and
// should be confirmed against the real source.
import { FieldMeta, ModelMeta, ProjectMeta, TransitionMeta } from "../meta";
import Transition from "../transition";
import { generateTransactionId } from "../id";
import knex, { Knex } from "knex";
import { toSnakeCase } from "js-convert-case";
import { Db } from ".";
import { Database } from "sqlite3";
import { Task, TaskState } from "../consumer";
import BaseObject from "../object";
import { SqliteDatabaseConfig } from "../config";
import { Bool, DataType, Decimal, Int, Optional, String } from "../dataTypes";

const TRANSITIONS_TABLE = "transitions";
const TASKS_TABLE = "tasks";

/** A stored transition together with its insertion sequence number. */
type TransitionWithSeqId = Transition<any, any> & { seqId: number };

/**
 * SQLite implementation of `Db`, built on knex.
 *
 * Every model gets its own table (see `migrate`). Applied transitions are
 * appended to the `transitions` table and consumer work items are stored in
 * the `tasks` table (see `setup`).
 */
export default class SqliteDb implements Db {
  constructor(private projectMeta: ProjectMeta, private db: Knex) {}

  /** Creates a `SqliteDb` backed by the SQLite file named in `config`. */
  static fromConfig(
    projectMeta: ProjectMeta,
    config: SqliteDatabaseConfig
  ): SqliteDb {
    const db = knex({
      client: "sqlite3",
      connection: {
        filename: config.file,
      },
      useNullAsDefault: true,
    });

    return new SqliteDb(projectMeta, db);
  }

  /**
   * Runs `fn` inside a knex transaction, handing it a `SqliteDb` whose
   * queries all go through that transaction.
   */
  async transaction(fn: (db: SqliteDb) => Promise<void>): Promise<void> {
    await this.db.transaction(async (txn) => {
      await fn(new SqliteDb(this.projectMeta, txn));
    });
  }

  /** Destroys the knex instance. */
  async close() {
    await this.db.destroy();
  }

  /** Acquires a raw `sqlite3` connection through knex's client. */
  async getRawSqliteConnection(): Promise<Database> {
    return await this.db.client.acquireRawConnection();
  }

  /** Creates the internal `transitions` and `tasks` tables if they are missing. */
  async setup() {
    if (!(await this.db.schema.hasTable(TRANSITIONS_TABLE))) {
      await this.db.schema.createTable(TRANSITIONS_TABLE, (table) => {
        // Auto-incrementing insertion order, used to poll for new transitions
        table.increments("seq_id");
        table.text("id");
        table.text("model").notNullable();
        table.text("type").notNullable();
        table.text("from");
        table.text("to").notNullable();
        table.text("object_id").notNullable();
        table.jsonb("data");
        table.text("note");
        table.text("triggered_by");
        // Milliseconds since the Unix epoch (see insertTransition)
        table.integer("applied_at").notNullable();
      });
    }

    if (!(await this.db.schema.hasTable(TASKS_TABLE))) {
      await this.db.schema.createTable(TASKS_TABLE, (table) => {
        table.text("id");
        table.text("transition_id").notNullable();
        table.text("consumer").notNullable();
        table.text("state").notNullable();
      });
    }
  }

  /**
   * Creates a table for every model that doesn't have one yet; existing
   * tables are left untouched. The columns are the union of the fields of all
   * the model's states. A column is NOT NULL only when its field appears in
   * every state and its type can't be null.
   */
  async migrate() {
    for (const modelMeta of this.projectMeta.allModelMetas()) {
      const modelTable = tableName(modelMeta);
      if (await this.db.schema.hasTable(modelTable)) {
        continue;
      }

      await this.db.schema.createTable(modelTable, (table) => {
        table.text("id").primary();
        table.text("state").notNullable();

        // States may share fields, so only add each column once
        const addedFields = new Set<string>();
        for (const state of modelMeta.allStateMetas()) {
          for (const field of state.allFieldMetas()) {
            if (addedFields.has(field.camelCaseName())) {
              continue;
            }
            addedFields.add(field.camelCaseName());

            const builder = table.specificType(
              columnName(field),
              this.typeToSqliteType(field.type)
            );
            if (
              modelMeta.doesFieldAppearInAllStates(field) &&
              !field.type.canBeNull()
            ) {
              builder.notNullable();
            }
          }
        }
      });
    }
  }

  /**
   * Maps a field data type to a SQLite column type. Optional types use their
   * nested type; nullability is decided separately in `migrate`.
   *
   * @throws Error for unrecognized data types.
   */
  private typeToSqliteType(type: DataType): string {
    // SQLite data types reference: https://www.sqlite.org/datatype3.html
    if (type instanceof String) {
      return "text";
    } else if (type instanceof Int) {
      return "integer";
    } else if (type instanceof Decimal) {
      return "real";
    } else if (type instanceof Bool) {
      return "integer";
    } else if (type instanceof Optional) {
      return this.typeToSqliteType(type.getNestedType());
    } else {
      throw new Error(`unrecognized data type ${type}`);
    }
  }

  /**
   * Writes `object` to its model table and records `transition`.
   *
   * Only the fields of the object's new state are written.
   */
  async applyTransition(
    modelMeta: ModelMeta,
    transitionMeta: TransitionMeta,
    transition: Transition<any, any>,
    object: BaseObject
  ): Promise<void> {
    const targetStateMeta = modelMeta.getStateMetaBySerializedName(
      object.state
    );

    // A transition without from-states initializes a new object, so it
    // inserts a row; every other transition updates the existing row.
    const isNewRow = transitionMeta.fromStates == undefined;

    const data: { [key: string]: any } = {
      id: object.id,
      state: object.state,
    };
    for (const field of targetStateMeta.allFieldMetas()) {
      data[columnName(field)] = (object as any)[field.camelCaseName()];
    }

    const table = this.db(tableName(modelMeta));
    if (isNewRow) {
      await table.insert(data);
    } else {
      await table.where("id", object.id).update(data);
    }

    await this.insertTransition(transition);
  }

  /** Appends `transition` to the transitions table. */
  private async insertTransition(transition: Transition<any, any>) {
    await this.db(TRANSITIONS_TABLE).insert({
      id: transition.id,
      object_id: transition.objectId,
      model: transition.model,
      type: transition.type,
      from: transition.from,
      to: transition.to,
      // Serialize explicitly so the column always holds the JSON text that
      // transitionFromRow parses, instead of relying on the driver's encoding
      // of objects.
      data: transition.data == null ? null : JSON.stringify(transition.data),
      note: transition.note,
      triggered_by: transition.triggeredBy,
      applied_at: transition.appliedAt.getTime(),
    });
  }

  /**
   * Loads a single object of `model` by ID.
   *
   * @throws Error when no object with that ID exists.
   */
  async getById(model: ModelMeta, id: string): Promise<any> {
    const row = await this.db.table(tableName(model)).where("id", id).first();
    if (row === undefined) {
      throw new Error(`No object found with ID ${id}`);
    }

    return objectFromRow(model, row);
  }

  /**
   * Lists objects of `model`, optionally filtered and limited.
   *
   * @param where camelCase field name → value. An array value matches any of
   *   its elements (SQL IN); any other value must match exactly.
   * @param limit maximum number of rows; falsy means no limit.
   */
  async query(
    model: ModelMeta,
    where?: { [key: string]: any },
    limit?: number
  ): Promise<any[]> {
    let query = this.db.table(tableName(model));

    if (where) {
      for (const [key, value] of Object.entries(where)) {
        const column = toSnakeCase(key);
        query = Array.isArray(value)
          ? query.whereIn(column, value)
          : query.where(column, value);
      }
    }

    if (limit) {
      query = query.limit(limit);
    }

    const rows = await query.select("*");
    return rows.map((row) => objectFromRow(model, row));
  }

  /** Returns the highest `seq_id` in the transitions table, or undefined when it is empty. */
  async getLatestTransitionSeqId(): Promise<number | undefined> {
    const row = await this.db
      .table(TRANSITIONS_TABLE)
      .orderBy("seq_id", "desc")
      .first("seq_id");

    return row ? row.seq_id : undefined;
  }

  /**
   * Returns every transition whose `seq_id` is greater than `afterSeqId`
   * (all transitions when it is undefined), oldest first.
   */
  async getTransitions(
    afterSeqId: number | undefined
  ): Promise<TransitionWithSeqId[]> {
    let query = this.db.table(TRANSITIONS_TABLE);
    if (afterSeqId !== undefined) {
      query = query.where("seq_id", ">", afterSeqId);
    }

    // Ascending insertion order: callers process the batch in order and
    // resume from the seqId of the last element.
    const rows = await query.orderBy("seq_id", "asc").select("*");

    return rows.map((row) => ({
      ...this.transitionFromRow(row),
      seqId: row.seq_id,
    }));
  }

  /** Looks up a transition by ID; resolves to null when it doesn't exist. */
  async getTransitionById(id: string): Promise<Transition<any, any> | null> {
    const row = await this.db.table(TRANSITIONS_TABLE).where("id", id).first();
    return row === undefined ? null : this.transitionFromRow(row);
  }

  /** Returns every transition applied to `objectId`, newest first. */
  async getTransitionsForObject(
    objectId: string
  ): Promise<Transition<any, any>[]> {
    const rows = await this.db
      .table(TRANSITIONS_TABLE)
      .where("object_id", objectId)
      .orderBy("seq_id", "desc")
      .select("*");

    return rows.map((row) => this.transitionFromRow(row));
  }

  /** Converts a transitions-table row back into a `Transition`. */
  private transitionFromRow(row: any): Transition<any, any> {
    return {
      id: row.id,
      objectId: row.object_id,
      model: row.model,
      type: row.type,
      from: row.from,
      to: row.to,
      data: row.data ? JSON.parse(row.data) : null,
      note: row.note,
      triggeredBy: row.triggered_by,
      appliedAt: new Date(row.applied_at),
    };
  }

  async insertTask(task: Task): Promise<void> {
    await this.db.table(TASKS_TABLE).insert({
      id: task.id,
      transition_id: task.transitionId,
      consumer: task.consumer,
      state: task.state,
    });
  }

  /** Persists the task's state; other task fields are not updated. */
  async updateTask(task: Task): Promise<void> {
    await this.db.table(TASKS_TABLE).where("id", task.id).update({
      state: task.state,
    });
  }

  async getTasksForTransition(transitionId: string): Promise<Task[]> {
    const rows = await this.db
      .table(TASKS_TABLE)
      .where("transition_id", transitionId)
      .select("*");

    return rows.map(taskFromRow);
  }

  /**
   * Returns tasks still in the `Created` state.
   *
   * @param limit optional maximum number of tasks (mirrors PostgresDb); falsy
   *   means no limit.
   */
  async getUnprocessedTasks(limit?: number): Promise<Task[]> {
    let query = this.db
      .table(TASKS_TABLE)
      .where("state", "=", TaskState.Created);
    if (limit) {
      query = query.limit(limit);
    }

    const rows = await query.select("*");
    return rows.map(taskFromRow);
  }

  async setTaskProcessed(taskId: string): Promise<void> {
    await this.db.table(TASKS_TABLE).where("id", "=", taskId).update({
      state: TaskState.Completed,
    });
  }
}

function tableName(model: ModelMeta): string {
  return model.pluralSnakeCaseName();
}

function columnName(field: FieldMeta): string {
  return field.snakeCaseName();
}

/**
 * Converts a model-table row into a plain object with `id`, `state` and the
 * camelCase fields of the row's current state.
 *
 * @throws Error when the row's state is not defined on `model`.
 */
function objectFromRow(model: ModelMeta, row: any): { [key: string]: any } {
  const state = row.state;
  const stateMeta = model.getStateMetaBySerializedName(state);
  if (stateMeta === undefined) {
    throw new Error(
      `Unknown state ${state} for model ${model.pascalCaseName()}`
    );
  }

  const data: { [key: string]: any } = { id: row.id, state };
  for (const field of stateMeta.allFieldMetas()) {
    data[field.camelCaseName()] = row[columnName(field)];
  }
  return data;
}

/** Converts a tasks-table row into a `Task`. */
function taskFromRow(row: any): Task {
  return {
    id: row.id,
    transitionId: row.transition_id,
    state: row.state,
    consumer: row.consumer,
  };
}

--------------------------------------------------------------------------------
/src/internal/db/TestDb.ts:
--------------------------------------------------------------------------------
import knex from "knex";
import { ModelMeta, ProjectMeta, TransitionMeta } from "../meta";
import BaseObject from "../object";
import Transition from "../transition";
import SqliteDb from "./SqliteDb";

/** Hook run after each transition applied through a `TestDb`. */
export type Callback = (
  modelMeta: ModelMeta,
  transitionMeta: TransitionMeta,
  transition: Transition<any, any>,
  updatedModel: any
) => Promise<void>;

/**
 * In-memory SQLite database for tests. After every applied transition it
 * awaits the callback registered with `setTransitionCallback`, if any.
 */
export default class TestDb extends SqliteDb {
  private callback?: Callback;

  constructor(projectMeta: ProjectMeta) {
    // NOTE(review): each connection to ":memory:" opens its own empty
    // database, so this relies on knex keeping a single sqlite connection.
    const knexDb = knex({
      client: "sqlite3",
      connection: {
        filename: ":memory:",
      },
      useNullAsDefault: true,
    });

    super(projectMeta, knexDb);
  }

  setTransitionCallback(callback: Callback) {
    this.callback = callback;
  }

  async applyTransition(
    modelMeta: ModelMeta,
    transitionMeta: TransitionMeta,
    transition: Transition<any, any>,
    object: BaseObject
  ): Promise<void> {
    await super.applyTransition(modelMeta, transitionMeta, transition, object);

    if (this.callback) {
      await this.callback(modelMeta, transitionMeta, transition, object);
    }
  }
}

--------------------------------------------------------------------------------
/src/internal/db/index.ts:
--------------------------------------------------------------------------------
import { ModelMeta, TransitionMeta } from "../meta";
import Transition from "../transition";
import TestDb from "./TestDb";
import SqliteDb from "./SqliteDb";
import PostgresDb from "./PostgresDb";
import { Task, TaskState } from "../consumer";

export { TestDb, SqliteDb, PostgresDb };

/** Storage contract implemented by the SQLite, Postgres and test databases. */
export interface Db {
  setup(): Promise<void>;
  close(): Promise<void>;

  migrate(): Promise<void>;

  transaction(fn: (db: Db) => Promise<void>): Promise<void>;

  // Transitions
  applyTransition(
    modelMeta: ModelMeta,
    transitionMeta: TransitionMeta,
    transition: Transition<any, any>,
    object: object
  ): Promise<void>;
  getTransitionById(id: string): Promise<Transition<any, any> | null>;
  getTransitionsForObject(objectId: string): Promise<Transition<any, any>[]>;

  // Object querying
  getById(model: ModelMeta, id: string): Promise<any>;
  query(
    model: ModelMeta,
    where?: { [key: string]: any },
    limit?: number
  ): Promise<any[]>;

  // Tasks
  insertTask(task: Task): Promise<void>;
  updateTask(task: Task): Promise<void>;
  getTasksForTransition(transition_id: string): Promise<Task[]>;
}

--------------------------------------------------------------------------------
/src/internal/id.ts:
--------------------------------------------------------------------------------
import { ulid } from "ulid";

/** Generates an ID with the `tsn` prefix. */
export function generateTransactionId(): string {
  return generateId("tsn");
}

/** Generates a consumer task ID with the `task` prefix. */
export function generateConsumerTaskId(): string {
  return generateId("task");
}

/** Generates `<prefix>_<ulid>`, with the ULID lowercased. */
export function generateId(prefix: string): string {
  return `${prefix}_${ulid().toLowerCase()}`;
}

--------------------------------------------------------------------------------
/src/internal/index.ts:
--------------------------------------------------------------------------------
import {
  ProjectMeta,
  ModelMeta,
  StateMeta,
  FieldMeta,
  TransitionMeta,
} from "./meta";
import { ArrayElementType, QueryParams } from "./types";
import { Db, TestDb, SqliteDb } from "./db";
import Consumer, { Task } from "./consumer";
import Project from "./project";
import { createTestConsumerRunner, SqliteQueue } from "./queue";
import { BaseClient, BaseTransitionsClient } from "./client";
import { Config, loadConfig } from "./config";
import { DataType, String, Int, Decimal, Optional, Bool } from "./dataTypes";
import Transition, {
  TransitionParameters,
  TransitionWithData,
  TransitionWithObject,
} from "./transition";

export {
  Db,
  TestDb,
  SqliteDb,
  ProjectMeta,
  ModelMeta,
  StateMeta,
  FieldMeta,
  TransitionMeta,
  ArrayElementType,
  QueryParams,
  Consumer,
  Task,
  Project,
  createTestConsumerRunner,
  SqliteQueue,
  BaseClient,
  BaseTransitionsClient,
  Config,
  loadConfig,
  Transition,
  TransitionParameters,
  TransitionWithData,
  TransitionWithObject,
  DataType,
  String,
  Int,
  Decimal,
  Optional,
  Bool,
};

--------------------------------------------------------------------------------
/src/internal/meta.ts:
--------------------------------------------------------------------------------
import { toPascalCase, toSnakeCase } from "js-convert-case";
import { DataType } from "./dataTypes";
import pluralize from "pluralize";

/** Metadata for a whole project: the list of its models. */
export class ProjectMeta {
  constructor(private modelMetas: ModelMeta[]) {}

  allModelMetas(): ModelMeta[] {
    return this.modelMetas;
  }

  /** Finds a model by PascalCase name; undefined when there is none. */
  getModelMeta(name: string): ModelMeta {
    return this.modelMetas.find((meta) => meta.pascalCaseName() === name);
  }
}

/** Metadata for one model: its states and transitions. */
export class ModelMeta {
  constructor(
    // PascalCase name
    private name: string,
    public prefix: string,
    private states: StateMeta[],
    private transitions: TransitionMeta[]
  ) {}

  pascalCaseName(): string {
    return this.name;
  }

  pluralPascalCaseName(): string {
    return pluralize(this.name);
  }

  snakeCaseName(): string {
    return toSnakeCase(this.name);
  }

  pluralSnakeCaseName(): string {
    return pluralize(toSnakeCase(this.name));
  }

  allStateMetas(): StateMeta[] {
    return this.states;
  }

  /** Finds a state by PascalCase name; undefined when there is none. */
  getStateMeta(name: string): StateMeta {
    return this.states.find((meta) => meta.pascalCaseName() === name);
  }

  /** Finds a state by its serialized name, which is converted to PascalCase first. */
  getStateMetaBySerializedName(serializedName: string): StateMeta {
    return this.getStateMeta(toPascalCase(serializedName));
  }

  allTransitionMetas(): TransitionMeta[] {
    return this.transitions;
  }

  /** Finds a transition by PascalCase name; undefined when there is none. */
  getTransitionMeta(name: string): TransitionMeta {
    return this.transitions.find((meta) => meta.pascalCaseName() === name);
  }

  /** True when every state has a field with the same camelCase name as `fieldMeta`. */
  doesFieldAppearInAllStates(fieldMeta: FieldMeta): boolean {
    const fieldName = fieldMeta.camelCaseName();
    return this.allStateMetas().every((state) =>
      state.allFieldMetas().some((field) => field.camelCaseName() === fieldName)
    );
  }
}

/** Metadata for one state of a model and the fields it carries. */
export class StateMeta {
  constructor(private name: string, private fields: FieldMeta[]) {}

  pascalCaseName(): string {
    return this.name;
  }

  /** @deprecated Misspelled alias of `pascalCaseName`, kept for existing callers. */
  pascalCaseNmae(): string {
    return this.pascalCaseName();
  }

  allFieldMetas(): FieldMeta[] {
    return this.fields;
  }
}

/**
 * Metadata for one transition. `fromStates` is undefined for transitions
 * that create a new object.
 */
export class TransitionMeta {
  constructor(
    private name: string,
    private fields: FieldMeta[],
    public fromStates: string[] | undefined,
    public toStates: string[]
  ) {}

  pascalCaseName(): string {
    return this.name;
  }

  /** @deprecated Misspelled alias of `pascalCaseName`, kept for existing callers. */
  pascalCaseNmae(): string {
    return this.pascalCaseName();
  }

  snakeCaseName(): string {
    return toSnakeCase(this.name);
  }
}

/** Metadata for one field: its camelCase name and data type. */
export class FieldMeta {
  constructor(private name: string, public type: DataType) {}

  camelCaseName(): string {
    return this.name;
  }

  snakeCaseName(): string {
    return toSnakeCase(this.name);
  }
}

--------------------------------------------------------------------------------
/src/internal/object.ts:
--------------------------------------------------------------------------------
/** Fields shared by every stored object. */
export default interface BaseObject {
  id: string;
  state: string;
}

--------------------------------------------------------------------------------
/src/internal/project.ts:
--------------------------------------------------------------------------------
import Consumer from "./consumer";

export default interface Project {
  consumers?: Consumer[];
}

--------------------------------------------------------------------------------
/src/internal/queue/PostgresQueue.ts:
--------------------------------------------------------------------------------
import { Client, ClientConfig, Connection } from "pg";
import { Pgoutput, PgoutputPlugin } from "pg-logical-replication";
import { Queue } from ".";
import Consumer, { Task } from "../consumer";
import { PostgresDb } from "../db";
import { ModelMeta, ProjectMeta, TransitionMeta } from "../meta";
import Project from "../project";
import Transition from "../transition";
import Logger from "../../cmd/logger";
import { createTasksForTransition, runTask } from "./common";

const MAX_TASKS = 10;
const TASK_PROCESS_INTERVAL_MS = 1_000;

const PUBLICATION_NAME = "restate_transitions_cdc";
const REPLICATION_SLOT_NAME = "restate_transitions_cdc";

// This class is responsible for finding enqueued tasks and running them.
// It will also automatically start a TaskEnqueuer, which is responsible for detecting new
// transitions and enqueueing tasks if there are any consumers interested in those tasks.
//
// To effectively use Postgres as a task queue, it uses SELECT FOR UPDATE and SKIP LOCKED.
// This way, multiple workers can fetch from the queue without grabbing the same tasks.
export class PostgresQueue implements Queue {
  private enqueuer: TaskEnqueuer;

  constructor(
    private logger: Logger,
    private projectMeta: ProjectMeta,
    private db: PostgresDb,
    private client: any,
    private project: Project
  ) {
    this.enqueuer = new TaskEnqueuer(logger, projectMeta, db, project);
  }

  /**
   * Starts the enqueuer in the background and runs the first round of task
   * processing. Every later round is scheduled by `processTasks` itself.
   */
  async run(): Promise<void> {
    // Start the enqueuer fully async: it waits on an advisory lock until this
    // worker becomes the enqueuer, so awaiting it would block task processing.
    // NOTE(review): an enqueuer failure is left as an unhandled rejection,
    // which on recent Node versions terminates the process by default (and
    // with it the connection holding the advisory lock). Confirm that this
    // failover behaviour is intended.
    void this.enqueuer.run();

    await this.processTasks();
  }

  /**
   * Runs one round of up to MAX_TASKS unprocessed tasks in a single
   * transaction, then schedules the next round. The next round starts
   * immediately if this round found tasks, and after
   * TASK_PROCESS_INTERVAL_MS otherwise. A failed round is logged and retried
   * after the interval.
   * NOTE(review): this assumes PostgresDb.transaction rolls back when its
   * callback rejects, so a failed round's tasks stay unprocessed. Confirm.
   */
  private async processTasks() {
    let foundTasks = false;

    try {
      // We need to run this in a transaction to make use of SELECT FOR UPDATE SKIP LOCKED.
      // The transaction must be awaited; otherwise the result below would be
      // read before any tasks were fetched, and failures would go unhandled.
      await this.db.transaction(async (txn) => {
        const tasks = await txn.getUnprocessedTasks(MAX_TASKS);

        await Promise.all(
          tasks.map((task) =>
            runTask(txn, this.projectMeta, this.project, this.client, task)
          )
        );

        foundTasks = tasks.length > 0;
      });
    } catch (error) {
      // Back off instead of retrying a failing batch in a tight loop
      foundTasks = false;
      this.logger.warn("Failed to process tasks", { error: String(error) });
    }

    // If no tasks were found (or the round failed), wait a set interval and
    // then try again. If tasks were processed, try again immediately.
    const waitTime = foundTasks ? 0 : TASK_PROCESS_INTERVAL_MS;
    setTimeout(() => void this.processTasks(), waitTime);
  }
}

// This class sets up change data capture to detect when new transitions are inserted into the transitions table
// It achieves this by using logical replication, setting up a replication slot and listening for changes.
// When it detects a new transition having been inserted, it will find all consumers interested in that transition
// and create a new task for each of them. It won't actually run any tasks, that's handled by the worker process itself.
//
// Derived from https://github.com/kibae/pg-logical-replication
class TaskEnqueuer {
  private decodingPlugin: PgoutputPlugin;
  private client: Client;

  constructor(
    private logger: Logger,
    private projectMeta: ProjectMeta,
    private db: PostgresDb,
    private project: Project
  ) {
    this.decodingPlugin = new PgoutputPlugin({
      protoVersion: 1,
      publicationNames: [PUBLICATION_NAME],
    });
    this.client = new Client({
      replication: "database",
    } as ClientConfig);
  }

  /**
   * Connects and waits to become the single enqueuing worker. Then it makes
   * sure the publication and replication slot exist and streams changes to
   * the transitions table, enqueuing tasks for every inserted transition.
   */
  async run() {
    await this.client.connect();

    // There can only be one worker consuming the change log. To ensure this, we
    // use an advisory lock. Only one worker will be able to hold the lock at any given
    // time and the others will be waiting for it to become available. This way, if the current
    // enqueueing worker fails, another one will automatically take its place.
    this.logger.debug(
      "Waiting for enqueuer lock (tasks will still be processed)"
    );
    await this.client.query("SELECT pg_advisory_lock(1)");
    this.logger.debug("Got lock-y! Starting enqueuer");

    // Set up replication if necessary. Only needs to be done once.
    await this.configureReplication();

    this.connection().on("copyData", ({ chunk }: { chunk: Buffer }) =>
      this.handleCopyData(chunk)
    );

    await this.decodingPlugin.start(
      this.client,
      REPLICATION_SLOT_NAME,
      "0/000000"
    );
  }

  /** Handles one CopyData message from the replication stream. */
  private handleCopyData(chunk: Buffer) {
    const XLOG_DATA = 0x77; // 'w'
    const PRIMARY_KEEPALIVE = 0x6b; // 'k'

    const kind = chunk[0];
    if (kind != XLOG_DATA && kind != PRIMARY_KEEPALIVE) {
      this.logger.warn("Unknown message received from COPY", {
        message: chunk[0],
      });
      return;
    }

    // Bytes 1-8 hold a WAL position as two big-endian 32-bit halves
    const lsn =
      chunk.readUInt32BE(1).toString(16).toUpperCase() +
      "/" +
      chunk.readUInt32BE(5).toString(16).toUpperCase();

    if (kind == XLOG_DATA) {
      // XLogData
      // This indicates that a new change has happened and that the LSN has moved ahead.
      // Acknowledge only after the tasks for it have been enqueued, and log
      // failures instead of leaving the rejection unhandled.
      // NOTE(review): a later acknowledgement (next change or keepalive)
      // still confirms past a failed LSN, so a failed change is not redelivered.
      this.onLog(lsn, this.decodingPlugin.parse(chunk.subarray(25)))
        .then(() => this.acknowledge(lsn))
        .catch((error) => {
          this.logger.warn("Failed to enqueue tasks for change", {
            lsn,
            error: String(error),
          });
        });
    } else {
      // Primary keepalive message
      // These are heartbeats sent out by Postgres. If shouldRespond is true, we must ack, otherwise
      // Postgres will close our connection.
      const shouldRespond = !!chunk.readInt8(17);

      if (shouldRespond) {
        this.logger.debug("Acknowledging keepalive message", { lsn });
        this.acknowledge(lsn);
      }
    }
  }

  /** Creates the publication and replication slot if they don't exist yet. */
  private async configureReplication() {
    // Create publication for only the transitions table
    // We inject the PUBLICATION_NAME variable directly here, rather than use a parameterized query
    // as our connection is in replication mode and we can only use the simple query protocol.
    const { rows: existingPublicationSlots } = await this.client.query(
      `SELECT pubname FROM pg_publication WHERE pubname = '${PUBLICATION_NAME}' LIMIT 1;`
    );
    if (existingPublicationSlots.length == 0) {
      this.logger.debug("Creating publication slot");
      await this.client.query(
        `CREATE PUBLICATION ${PUBLICATION_NAME} FOR TABLE public.transitions`
      );
    }

    // Create replication slot which we'll use to get changes
    const { rows: existingReplicationSlots } = await this.client.query(
      `SELECT slot_name FROM pg_replication_slots WHERE slot_name = '${REPLICATION_SLOT_NAME}' LIMIT 1;`
    );
    if (existingReplicationSlots.length == 0) {
      this.logger.debug("Creating replication slot");
      await this.client.query(
        `SELECT pg_create_logical_replication_slot('${REPLICATION_SLOT_NAME}', 'pgoutput')`
      );
    }
  }

  /** Enqueues tasks for an inserted transition; other messages are ignored. */
  async onLog(lsn: string, log: Pgoutput.Message) {
    if (log.tag != "insert") {
      return;
    }

    const transition = this.transitionFromLog(log);
    await this.createTasksForTransition(transition);
  }

  /**
   * Sends a standby status update ('r') that reports the position one byte
   * past `lsn` as received, flushed and applied.
   */
  async acknowledge(lsn: string): Promise<boolean> {
    const [upperHex, lowerHex] = lsn.split("/");
    let upperWAL = parseInt(upperHex, 16);
    let lowerWAL = parseInt(lowerHex, 16);

    // Advance one byte past `lsn`, carrying into the upper half on overflow
    if (lowerWAL === 0xffffffff) {
      upperWAL = upperWAL + 1;
      lowerWAL = 0;
    } else {
      lowerWAL = lowerWAL + 1;
    }

    // Timestamp as microseconds since midnight 2000-01-01 UTC, split into two
    // 32-bit words. The microsecond count stays far below 2^53, so plain
    // number arithmetic is exact.
    const postgresEpochMs = Date.UTC(2000, 0, 1);
    const micros = (Date.now() - postgresEpochMs) * 1000;
    const upperTimestamp = Math.floor(micros / 4294967296);
    const lowerTimestamp = micros % 4294967296;

    const response = Buffer.alloc(34);
    response.writeUInt8(0x72, 0); // 'r'

    // Last WAL Byte + 1 received and written to disk locally
    response.writeUInt32BE(upperWAL, 1);
    response.writeUInt32BE(lowerWAL, 5);

    // Last WAL Byte + 1 flushed to disk in the standby
    response.writeUInt32BE(upperWAL, 9);
    response.writeUInt32BE(lowerWAL, 13);

    // Last WAL Byte + 1 applied in the standby
    response.writeUInt32BE(upperWAL, 17);
    response.writeUInt32BE(lowerWAL, 21);

    // Client clock at transmission time
    response.writeUInt32BE(upperTimestamp, 25);
    response.writeUInt32BE(lowerTimestamp, 29);

    // If 1, requests server to respond immediately - can be used to verify connectivity
    response.writeInt8(0, 33);

    (this.connection() as any).sendCopyFromChunk(response);

    return true;
  }

  private connection(): Connection {
    return (this.client as any).connection;
  }

  /** Creates tasks for `transition` and logs each enqueued task. */
  private async createTasksForTransition(transition: Transition<any, any>) {
    const tasks = await createTasksForTransition(
      this.db,
      this.projectMeta,
      this.project,
      transition
    );

    for (const task of tasks) {
      this.logger.info("Enqueued task", {
        task: task.id,
        transition: transition.id,
        consumer: task.consumer,
      });
    }
  }

  /** Builds a `Transition` from the decoded column values of a transitions-table insert. */
  private transitionFromLog(
    log: Pgoutput.MessageInsert
  ): Transition<any, any> {
    return {
      id: log.new.id,
      model: log.new.model,
      type: log.new.type,
      from: log.new.from,
      to: log.new.to,
      objectId: log.new.object_id,
      data: log.new.data,
      note: log.new.note,
      triggeredBy: log.new.triggered_by,
      appliedAt: log.new.applied_at,
    };
  }
}

--------------------------------------------------------------------------------
/src/internal/queue/SqliteQueue.ts:
--------------------------------------------------------------------------------
import Project from "../project";
import SqliteDb from "../db/SqliteDb";
import { ModelMeta, ProjectMeta } from "../meta";
import { Queue } from ".";
import { createTasksForTransition, runTask } from "./common";

const TRANSITION_PROCESS_INTERVAL = 500;
const TASK_PROCESS_INTERVAL = 500;

/**
 * Polling queue for SQLite. One loop turns transitions inserted after
 * startup into tasks, and a second loop runs unprocessed tasks. Each loop
 * re-schedules itself after a fixed interval.
 */
export class SqliteQueue implements Queue {
  // Highest transition seq_id already turned into tasks
  private lastSeqId: number | undefined;

  constructor(
    private projectMeta: ProjectMeta,
    private db: SqliteDb,
    private client: any,
    private project: Project
  ) {}

  async run(): Promise<void> {
    // Only transitions applied after startup are picked up
    this.lastSeqId = await this.db.getLatestTransitionSeqId();

    await this.processTransitions();
    await this.processTasks();
  }

  /** Enqueues tasks for new transitions, then schedules the next round. */
  private async processTransitions() {
    const newTransitions = await this.db.getTransitions(this.lastSeqId);
    for (const transition of newTransitions) {
      await createTasksForTransition(
        this.db,
        this.projectMeta,
        this.project,
        transition
      );

      // Track the highest seqId handled, independently of the order in which
      // getTransitions returns rows.
      if (this.lastSeqId === undefined || transition.seqId > this.lastSeqId) {
        this.lastSeqId = transition.seqId;
      }
    }

    // Process transitions again after a set interval
    setTimeout(
      async () => await this.processTransitions(),
      TRANSITION_PROCESS_INTERVAL
    );
  }

  /** Runs all unprocessed tasks concurrently, then schedules the next round. */
  private async processTasks() {
    const tasks = await this.db.getUnprocessedTasks();
    await Promise.all(
      tasks.map((task) =>
        runTask(this.db, this.projectMeta, this.project, this.client, task)
      )
    );

    // Process tasks again after a set interval
    setTimeout(async () => await this.processTasks(), TASK_PROCESS_INTERVAL);
  }
}

--------------------------------------------------------------------------------
/src/internal/queue/common.ts:
--------------------------------------------------------------------------------
import { toSnakeCase } from "js-convert-case";
import Consumer, { Task, TaskState } from "../consumer";
import { Db } from "../db";
import { generateConsumerTaskId } from "../id";
import { ProjectMeta } from "../meta";
import Transition from "../transition";

/**
 * Creates and stores a `Created` task for every consumer in `project` that
 * listens to the transition's model and transition type.
 *
 * @returns the inserted tasks; empty when the project has no consumers or
 *   none match.
 */
export async function createTasksForTransition(
  db: Db,
  projectMeta: ProjectMeta,
  project: any,
  transition: Transition<any, any>
): Promise<Task[]> {
  const allConsumers = project.consumers;
  if (allConsumers === undefined) {
    return [];
  }

  const modelMeta = projectMeta.getModelMeta(transition.model);
  const transitionMeta = modelMeta.getTransitionMeta(transition.type);
  const transitionName = transitionMeta.snakeCaseName();

  // Find all consumers which match this transition
  const consumers = allConsumers.filter(
    (consumer: any) =>
      consumer.model.name == modelMeta.pascalCaseName() &&
      consumer.transitions.includes(transitionName)
  );

  // Create a new task for each consumer
  const tasks: Task[] = [];
  for (const consumer of consumers) {
    const task: Task = {
      id: generateConsumerTaskId(),
      consumer: consumer.name,
      transitionId: transition.id,
      state: TaskState.Created,
    };

    await db.insertTask(task);
    tasks.push(task);
  }

  return tasks;
}

/**
 * Runs `task`'s consumer handler for its transition and the affected object,
 * then stores the task as `Completed`.
 *
 * @throws Error when the consumer or the task's transition can't be found.
 */
export async function runTask(
  db: Db,
  projectMeta: ProjectMeta,
  project: any,
  client: any,
  task: Task
): Promise<Task> {
  const consumer = getConsumerByName(project, task.consumer);
  const transition = await db.getTransitionById(task.transitionId);
  if (transition == null) {
    throw new Error(
      `couldn't find transition ${task.transitionId} for task ${task.id}`
    );
  }

  const modelMeta = projectMeta.getModelMeta(transition.model);
  const object = await db.getById(modelMeta, transition.objectId);

  // Transitions made by the handler are attributed to this task
  const taskSpecificClient = client.withTriggeredBy(task.id);
  await consumer.handler(taskSpecificClient, object, transition);

  const updatedTask: Task = {
    ...task,
    state: TaskState.Completed,
  };
  await db.updateTask(updatedTask);

  return updatedTask;
}

function getConsumerByName(project: any, name: string): Consumer {
  const allConsumers = project.consumers;
  if (allConsumers == undefined) {
    throw new Error(`couldn't find consumer named ${name}`);
  }

  const consumer = allConsumers.find((consumer) => consumer.name == name);
  if (consumer == undefined) {
    throw new Error(`couldn't find consumer named ${name}`);
  }

  return consumer;
}

--------------------------------------------------------------------------------
/src/internal/queue/index.ts:
--------------------------------------------------------------------------------
import { createTestConsumerRunner } from "./test";
import { SqliteQueue } from "./SqliteQueue";

interface Queue {
  run(): Promise<void>;
}

export { Queue, createTestConsumerRunner, SqliteQueue };

--------------------------------------------------------------------------------
/src/internal/queue/test.ts:
--------------------------------------------------------------------------------
import Project from "../project";
import { Callback } from "../db/TestDb";
import { ModelMeta, ProjectMeta, TransitionMeta } from "../meta";
import Transition from "../transition";
import { createTasksForTransition, runTask } from "./common";
import { Db } from "../db";

/**
 * Builds a TestDb transition callback that enqueues and immediately runs the
 * tasks for every applied transition.
 */
export function createTestConsumerRunner(
  db: Db,
  project: Project,
  client: any,
  projectMeta: ProjectMeta
): Callback {
  return async (
    _modelMeta: ModelMeta,
    _transitionMeta: TransitionMeta,
    transition: Transition<any, any>,
    _updatedObject: any
  ): Promise<void> => {
    const tasks = await createTasksForTransition(
      db,
      projectMeta,
      project,
      transition
    );

    await Promise.all(
      tasks.map((task) => runTask(db, projectMeta, project, client, task))
    );
  };
}

--------------------------------------------------------------------------------
/src/internal/transition.ts:
--------------------------------------------------------------------------------
/**
 * One recorded state transition of a model object.
 *
 * NOTE(review): this dump stripped "<...>" sequences. The type parameters
 * `Type` (transition type) and `Data` (payload) are restored from the fields
 * that use them; confirm their order against the upstream file.
 */
export default interface Transition<Type, Data> {
  id: string;
  objectId: string;
  model: string;
  type: Type;
  /** Source state; `null` for an initializing transition. */
  from: string | null;
  to: string;
  data: Data;
  /** ID of the transition or consumer task that caused this one, if any. */
  triggeredBy?: string;
  note?: string;
  appliedAt: Date;
}

/** Optional metadata accepted when applying a transition. */
export interface TransitionParameters {
  triggeredBy?: string;
  note?: string;
}

/**
 * Input for a transition on an existing object, given as the object or its ID.
 * NOTE(review): as shown, `Object` resolves to the global `Object` type. It was
 * probably a type parameter removed by the dump; confirm upstream.
 */
export interface TransitionWithObject {
  object: string | Object;
}

/** Input for a transition that carries a data payload. */
export interface TransitionWithData<Data> {
  data: Data;
}

--------------------------------------------------------------------------------
/src/internal/types.ts:
--------------------------------------------------------------------------------
// NOTE(review): type parameter lists restored; the dump had stripped "<...>".

/** Element type of an array (or readonly array) type; `never` otherwise. */
export type ArrayElementType<Type> =
  Type extends readonly (infer ElementType)[] ? ElementType : never;

/**
 * Query options: an optional `where` filter and an optional `limit`.
 * Presumably the shape taken by the generated `findAll`/`findOne`; verify there.
 */
export interface QueryParams<Filter> {
  where?: Filter;
  limit?: number;
}

--------------------------------------------------------------------------------
/src/parser.ts:
--------------------------------------------------------------------------------
import {
  alt,
  apply,
  buildLexer,
  expectEOF,
  expectSingleResult,
  list_sc,
  opt,
  rep_sc,
  seq,
  tok,
} from "typescript-parsec";

/** A `model Name { ... }` declaration with its components grouped by kind. */
export type Model = {
  name: string;
  states: State[];
  transitions: Transition[];
  baseFields: Field[];
  prefix?: PrefixSetting;
};

/** Any item allowed directly inside a model body, tagged by `id`. */
type ModelComponent = State | Transition | Field | PrefixSetting;

/** `state Name [: Parent, ...] { field ... }` */
export type State = {
  id: "STATE";
  name: string;
  extends: string[];
  fields: Field[];
};

/**
 * `transition Name: [From | ... ->] To | ... { field ... }`.
 * `from` is `undefined` when no source states are written (an initializing
 * transition). A `*` on either side is kept as the literal state name "*".
 */
export type Transition = {
  id: "TRANSITION";
  name: string;
  from?: string[];
  to: string[];
  fields: Field[];
};

/** `field name: Type` */
export type Field = {
  id: "FIELD";
  name: string;
  type: Type;
};

/** A type name with at most one bracketed argument, e.g. `Optional[Int]`. */
export type Type = {
  id: "TYPE";
  name: string;
  nested?: Type;
};

/** `prefix "text"`; the surrounding quotes are not part of `prefix`. */
export type PrefixSetting = {
  id: "PREFIX";
  prefix: string;
};

// Type guards used to split a model body into its component kinds.

function isState(component: ModelComponent): component is State {
  return component.id === "STATE";
}

function isTransition(component: ModelComponent): component is Transition {
  return component.id === "TRANSITION";
}

function isField(component: ModelComponent): component is Field {
  return component.id === "FIELD";
}

function isPrefixSetting(
  component: ModelComponent
): component is PrefixSetting {
  return component.id === "PREFIX";
}

enum TokenKind {
  KeywordModel,
  KeywordState,
  KeywordTransition,
  KeywordField,
  KeywordPrefix,

  Identifier,
  StringLiteral,

  RightArrow,
  Asterisk,
  Pipe,
  Colon,
  Comma,
  LBrace,
  RBrace,
  LBracket,
  RBracket,
  Space,

  Comment,
}

// Rules whose first element is `false` produce no token: whitespace and `//`
// line comments are dropped.
// NOTE(review): keywords overlap identifiers (`state` matches both `/^state/`
// and `/^[a-zA-Z]+/`). This relies on typescript-parsec choosing the longest
// match and, on a tie, the earlier rule, so `stateful` stays an identifier.
const lexer = buildLexer([
  // Keywords
  [true, /^model/g, TokenKind.KeywordModel],
  [true, /^state/g, TokenKind.KeywordState],
  [true, /^transition/g, TokenKind.KeywordTransition],
  [true, /^field/g, TokenKind.KeywordField],
  [true, /^prefix/g, TokenKind.KeywordPrefix],

  // Names and literals
  [true, /^[a-zA-Z]+/g, TokenKind.Identifier],
  [true, /^"\S*"/g, TokenKind.StringLiteral],

  // Punctuation
  [true, /^->/g, TokenKind.RightArrow],
  [true, /^\*/g, TokenKind.Asterisk],
  [true, /^\|/g, TokenKind.Pipe],
  [true, /^:/g, TokenKind.Colon],
  [true, /^,/g, TokenKind.Comma],
  [true, /^\{/g, TokenKind.LBrace],
  [true, /^\}/g, TokenKind.RBrace],
  [true, /^\[/g, TokenKind.LBracket],
  [true, /^\]/g, TokenKind.RBracket],

  // Skipped
  [false, /^\s+/g, TokenKind.Space],

  [false, /^\/\/.*/g, TokenKind.Comment],
]);

/** A double-quoted literal; yields the text between the quotes. */
const parseStringLiteral = apply(
  tok(TokenKind.StringLiteral),
  (literal): string => literal.text.slice(1, -1)
);

/** `prefix "text"` */
const parsePrefixSetting = apply(
  seq(tok(TokenKind.KeywordPrefix), parseStringLiteral),
  ([, prefix]): PrefixSetting => ({ id: "PREFIX", prefix })
);

/** `Name` or `Name[Inner]` (only one level of nesting is accepted). */
const parseType = apply(
  seq(
    tok(TokenKind.Identifier),
    opt(
      seq(
        tok(TokenKind.LBracket),
        tok(TokenKind.Identifier),
        tok(TokenKind.RBracket)
      )
    )
  ),
  ([outer, bracketed]): Type => ({
    id: "TYPE",
    name: outer.text,
    nested:
      bracketed === undefined
        ? undefined
        : { id: "TYPE", name: bracketed[1].text },
  })
);

/** `field name: Type` */
const parseField = apply(
  seq(
    tok(TokenKind.KeywordField),
    tok(TokenKind.Identifier),
    tok(TokenKind.Colon),
    parseType
  ),
  ([, fieldName, , fieldType]): Field => ({
    id: "FIELD",
    name: fieldName.text,
    type: fieldType,
  })
);

/** `state Name [: Parent1, Parent2] { field ... }` */
const parseState = apply(
  seq(
    tok(TokenKind.KeywordState),
    tok(TokenKind.Identifier),

    // Optional list of extended states
    opt(
      seq(
        tok(TokenKind.Colon),
        list_sc(tok(TokenKind.Identifier), tok(TokenKind.Comma))
      )
    ),

    tok(TokenKind.LBrace),
    rep_sc(parseField),
    tok(TokenKind.RBrace)
  ),
  ([, stateName, parents, , fields]): State => ({
    id: "STATE",
    name: stateName.text,
    extends:
      parents === undefined ? [] : parents[1].map((parent) => parent.text),
    fields,
  })
);

/** A state list on either side of `->`: `*` or `A | B | ...`. */
const parseStateSpec = alt(
  tok(TokenKind.Asterisk),
  list_sc(tok(TokenKind.Identifier), tok(TokenKind.Pipe))
);

/** Flattens a `parseStateSpec` result (one token or a list) to state names. */
function stateNames(spec: { text: string } | { text: string }[]): string[] {
  const tokens = Array.isArray(spec) ? spec : [spec];
  return tokens.map((token) => token.text);
}

/** `transition Name: [From | ... ->] To | ... { field ... }` */
const parseTransition = apply(
  seq(
    // transition NAME:
    tok(TokenKind.KeywordTransition),
    tok(TokenKind.Identifier),
    tok(TokenKind.Colon),

    // STATE1 | STATE2 ->   (optional)
    opt(seq(parseStateSpec, tok(TokenKind.RightArrow))),

    // STATE3 | STATE4
    parseStateSpec,

    // { fields }
    tok(TokenKind.LBrace),
    rep_sc(parseField),
    tok(TokenKind.RBrace)
  ),
  ([, transitionName, , source, target, , fields]): Transition => ({
    id: "TRANSITION",
    name: transitionName.text,
    // No source states written means this is an initializing transition.
    from: source === undefined ? undefined : stateNames(source[0]),
    to: stateNames(target),
    fields,
  })
);

/** `model Name { ... }`, with components in any order. */
const parseModel = apply(
  seq(
    tok(TokenKind.KeywordModel),
    tok(TokenKind.Identifier),

    tok(TokenKind.LBrace),
    rep_sc(alt(parseState, parseTransition, parseField, parsePrefixSetting)),
    tok(TokenKind.RBrace)
  ),
  ([, modelName, , components]): Model => {
    // Only the first `prefix` setting is used if several are written.
    const prefix: PrefixSetting | undefined =
      components.filter(isPrefixSetting)[0];

    return {
      name: modelName.text,
      states: components.filter(isState),
      transitions: components.filter(isTransition),
      baseFields: components.filter(isField),
      prefix,
    };
  }
);

const parseModels = rep_sc(parseModel);

/**
 * Parses the text of a schema file into its models.
 *
 * `expectEOF` + `expectSingleResult` require the whole input to be consumed
 * by exactly one parse. See the typescript-parsec docs for how lexing and
 * parse failures are reported.
 */
export function parse(input: string): Model[] {
  const tokens = lexer.parse(input);
  return expectSingleResult(expectEOF(parseModels.parse(tokens)));
}

--------------------------------------------------------------------------------
/test/jest.config.js:
--------------------------------------------------------------------------------
/** @type {import('ts-jest').JestConfigWithTsJest} */
module.exports = {
  preset: 'ts-jest',
  testEnvironment: 'node',
};
--------------------------------------------------------------------------------
/test/package.json:
--------------------------------------------------------------------------------
{
  "name": "restate-test",
  "version": "1.0.0",
  "description": "",
  "main": "index.js",
  "scripts": {
    "test": "restate generate && jest"
  },
  "author": "",
  "license": "MIT",
  "dependencies": {
    "restate": "file:.."
  },
  "devDependencies": {
    "@jest/globals": "^29.3.1",
    "jest": "^29.3.1",
    "ts-jest": "^29.0.5"
  }
}

--------------------------------------------------------------------------------
/test/restate.config.json:
--------------------------------------------------------------------------------
{
  "production": {
    "database": {
      "type": "postgres",
      "connection_string": "postgres://fabianlindfors@localhost/fabianlindfors"
    }
  }
}
--------------------------------------------------------------------------------
/test/restate/Email.rst:
--------------------------------------------------------------------------------
model Email {
  prefix "email"

  field userId: String
  field subject: String

  state Created {}
  state Sent {}

  transition Create: Created {
    field userId: String
    field subject: String
  }

  transition Send: Created -> Sent {}
}
--------------------------------------------------------------------------------
/test/restate/TypesTest.rst:
--------------------------------------------------------------------------------
model TypesTest {
  prefix "tt"

  field string: String
  field integer: Int
  field decimal: Decimal
  field optional: Optional[Int]
  field boolean: Bool

  // Test comment
  state Created {}

  transition Create: Created {
    field string: String // Test comment on a field
    field integer: Int
    field decimal: Decimal
    field optional: Optional[Int]
    field boolean: Bool
  }
}

model StateWithNonNullableField {
  prefix "swnnf"

  state Created {}
  state Finished {
    field result: String
  }

  transition Create: Created {}
  transition Finish: Created -> Finished {
    field result: String
  }
}
--------------------------------------------------------------------------------
/test/restate/User.rst:
--------------------------------------------------------------------------------
model User {
  prefix "user"

  field name: String
  field duplicateTransition: Optional[String]

  state Created {
    field nickname: Optional[String]
    field age: Optional[Int]
  }
  state Deleted {}

  transition Create: Created {}
  transition CreateExtra: Created {}
  transition CreateWithData: Created {
    field nickname: String
    field age: Int
  }
  transition Delete: Created -> Deleted {}
  transition CreateDouble: Created {}

}
--------------------------------------------------------------------------------
/test/src/index.test.ts:
--------------------------------------------------------------------------------
/**
 * End-to-end tests for the generated Restate client, run against the project
 * in `./restate`. Every test gets a client from `setupTestClient` in
 * `beforeEach`, and that client is closed in `afterEach`.
 */
import { test, expect, beforeEach, afterEach, describe } from "@jest/globals";
import {
  Email,
  RestateClient,
  setupTestClient,
  User,
} from "../../src/generated";
import { State } from "../../src/generated/User";
import { TaskState } from "../../src/internal/consumer";
import project from "./restate";

let client: RestateClient;

beforeEach(async () => {
  client = await setupTestClient(project);
});

describe("transitions", () => {
  test("initializing", async () => {
    const [user, transition] = await client.user.transition.create();

    // User should get an autogenerated ID with a prefix
    expect(user.id).toMatch(/user_[a-zA-Z0-9]+/);
    expect(user.name).toBe("Test Name");
    expect(user.state).toBe(State.Created);

    expect(transition.from).toBeNull();
    expect(transition.to).toBe(State.Created);
  });

  test("transition existing object", async () => {
    const [createdUser] = await client.user.transition.create();
    const [user, transition] = await client.user.transition.delete({
      object: createdUser,
    });

    expect(user.state).toBe(State.Deleted);

    expect(transition.from).toBe(State.Created);
    expect(transition.to).toBe(State.Deleted);
  });

  test("transition with data", async () => {
    const [user, transition] = await client.user.transition.createWithData({
      data: {
        nickname: "Test Nick",
        age: 30,
      },
    });

    expect(user.nickname).toBe("Test Nick");
    expect(user.age).toBe(30);

    expect(transition.data).toEqual({
      nickname: "Test Nick",
      age: 30,
    });
  });

  test("notes", async () => {
    const [_, transition] = await client.user.transition.create({
      note: "A little helpful note for the future",
    });

    expect(transition.note).toBe("A little helpful note for the future");
  });

  test("get transition by ID", async () => {
    const [_, transition] = await client.user.transition.createWithData({
      data: {
        nickname: "Test Nick",
        age: 30,
      },
    });

    const foundTransition = await client.user.getTransition(transition.id);
    expect(foundTransition.id).toBe(transition.id);
    expect(foundTransition.objectId).toBe(transition.objectId);
    expect(foundTransition.model).toBe(transition.model);
    expect(foundTransition.type).toBe(transition.type);
    expect(foundTransition.from).toBeNull();
    expect(foundTransition.to).toBe(State.Created);
    expect(foundTransition.triggeredBy).toBe(transition.triggeredBy);
    expect(foundTransition.data).toEqual({
      nickname: "Test Nick",
      age: 30,
    });
    expect(foundTransition.appliedAt).toEqual(transition.appliedAt);
  });

  test("get non-existent transition by ID", async () => {
    const missingTransition = await client.user.getTransition("tsn_missing");
    expect(missingTransition).toBeNull();
  });

  test("get transitions for object", async () => {
    const [user, createTransition] = await client.user.transition.create();
    const [_, deleteTransition] = await client.user.transition.delete({
      object: user,
    });

    // Most recent transition first
    const transitions = await client.user.getObjectTransitions(user);
    expect(transitions).toHaveLength(2);
    expect(transitions[0].id).toBe(deleteTransition.id);
    expect(transitions[1].id).toBe(createTransition.id);
  });

  test("triggeredBy not set when running directly", async () => {
    const [_, transition] = await client.user.transition.create();
    expect(transition.triggeredBy).toBeNull();
  });

  test("triggeredBy set to transition ID when created from another transition", async () => {
    // `createDouble` (see ./restate) applies a nested `create` and stores
    // that nested transition's ID in `duplicateTransition`.
    const [user, transition] = await client.user.transition.createDouble();
    const otherTransition = await client.user.getTransition(
      user.duplicateTransition
    );

    expect(otherTransition.triggeredBy).toEqual(transition.id);
  });

  test("triggeredBy set to task ID when created from a consumer", async () => {
    const [user] = await client.user.transition.create();
    const email = await client.email.findOneOrThrow({
      where: {
        userId: user.id,
      },
    });

    const transitions = await client.email.getObjectTransitions(email);
    expect(transitions[0].triggeredBy).not.toBeNull();
  });
});

describe("query", () => {
  let user1: User.Any;
  let user2: User.Any;

  beforeEach(async () => {
    [user1] = await client.user.transition.create();

    [user2] = await client.user.transition.create();
    await client.user.transition.delete({ object: user2 });
  });

  test("findMany", async () => {
    const results = await client.user.findAll();

    expect(results).toHaveLength(2);
  });

  test("findMany with limit", async () => {
    const results = await client.user.findAll({
      limit: 1,
    });

    expect(results).toHaveLength(1);
  });

  test("findMany by state", async () => {
    const results = await client.user.findAll({
      where: {
        state: State.Created,
      },
    });

    expect(results).toHaveLength(1);
    expect(results[0].id).toBe(user1.id);
  });

  test("findOne by ID", async () => {
    const result = await client.user.findOne({
      where: {
        id: user1.id,
      },
    });

    // `findOne` yields `null` (not `undefined`) when nothing matches; see the
    // "findOne non-existing" test below.
    expect(result).not.toBeNull();
    expect(result.id).toBe(user1.id);
  });

  test("findOne non-existing", async () => {
    const result = await client.user.findOne({
      where: {
        id: "user_asdadas",
      },
    });

    expect(result).toBeNull();
  });

  test("findOneOrThrow non-existing", async () => {
    await expect(async () => {
      await client.user.findOneOrThrow({
        where: {
          id: "user_asdadas",
        },
      });
    }).rejects.toThrow("no object found");
  });
});

describe("consumers", () => {
  let user: User.Any;
  let transition: User.Create;
  let results: Email.Any[];

  beforeEach(async () => {
    [user, transition] = await client.user.transition.create();
    results = await client.email.findAll({
      where: {
        userId: user.id,
      },
    });
  });

  test("can create another object", async () => {
    // `SendEmailOnUserCreation` consumer should have created a welcome email for the user
    expect(results).toHaveLength(1);
    expect(results[0].userId).toBe(user.id);
    expect(results[0].subject).toBe("Welcome!");
  });

  test("can be chained", async () => {
    // `SendEmail` consumer should have transitioned created email to sent
    expect(results).toHaveLength(1);
    expect(results[0].state).toBe(Email.State.Sent);
  });

  test("getTasksForTransition returns tasks", async () => {
    const tasks = await client.getTasksForTransition(transition);
    expect(tasks).toHaveLength(1);
    expect(tasks[0].id).toBeDefined();
    expect(tasks[0].state).toBe(TaskState.Completed);
    expect(tasks[0].transitionId).toBe(transition.id);
  });
});

describe("data types", () => {
  test("can save and retrieve values", async () => {
    // Create object with example values
    const [{ id }] = await client.typesTest.transition.create({
      data: {
        string: "Test",
        integer: 5,
        decimal: 5.5,
        optional: 1,
        boolean: true,
      },
    });

    // Read object back from database
    const result = await client.typesTest.findOneOrThrow({ where: { id: id } });

    expect(result.string).toBe("Test");
    expect(result.integer).toBe(5);
    expect(result.decimal).toBe(5.5);
    expect(result.optional).toBe(1);
    // NOTE(review): the `boolean` round-trip is not asserted here.
  });

  test("validates strings", async () => {
    await expect(async () => {
      await client.typesTest.transition.create({
        data: {
          string: undefined, // Not a string
          integer: 5,
          decimal: 5.5,
          boolean: true,
        },
      });
    }).rejects.toThrow("not a string");
  });

  test("validates integers", async () => {
    await expect(async () => {
      await client.typesTest.transition.create({
        data: {
          string: "Test",
          integer: 0.1, // Not an integer
          decimal: 5.5,
          boolean: true,
        },
      });
    }).rejects.toThrow("not an integer");
  });

  test("validates decimals", async () => {
    await expect(async () => {
      await client.typesTest.transition.create({
        data: {
          string: "Test",
          integer: 5,
          decimal: undefined, // Not a number
          boolean: true,
        },
      });
    }).rejects.toThrow("not a number");
  });

  test("validates optional nested value", async () => {
    await expect(async () => {
      await client.typesTest.transition.create({
        data: {
          string: "Test",
          integer: 1,
          decimal: 5.5,
          optional: 0.1, // Not an integer
          boolean: true,
        },
      });
    }).rejects.toThrow("not an integer");
  });

  test("validates booleans", async () => {
    await expect(async () => {
      await client.typesTest.transition.create({
        data: {
          string: "Test",
          integer: 5,
          decimal: 5.5,
          boolean: undefined, // Not a boolean
        },
      });
    }).rejects.toThrow("not a boolean");
  });

  test("handles state with non-nullable field", async () => {
    const [object] = await client.stateWithNonNullableField.transition.create();
    const [updatedObject] =
      await client.stateWithNonNullableField.transition.finish({
        object,
        data: {
          result: "test",
        },
      });

    expect(updatedObject.result).toBe("test");
  });
});

afterEach(async () => {
  await client.close();
});

--------------------------------------------------------------------------------
/test/src/restate.ts:
--------------------------------------------------------------------------------
/**
 * Restate project used by the test suite. It provides a transition
 * implementation for every model in `test/restate/*.rst`, plus two consumers
 * that chain together: a user `Create` creates an email, and an email
 * `Create` sends it.
 */
import {
  RestateClient,
  RestateProject,
  Email,
  User,
  TypesTest,
  StateWithNonNullableField,
} from "../../src/generated";

const project: RestateProject = {
  main: async function (restate: RestateClient): Promise<void> {
    // No main function needed for tests
  },
  transitions: {
    user: {
      async create(restate: RestateClient) {
        return {
          state: User.State.Created,
          name: "Test Name",
        };
      },
      async createExtra(restate: RestateClient) {
        return {
          state: User.State.Created,
          name: "Test Name",
        };
      },
      async createWithData(restate: RestateClient, transition) {
        return {
          state: User.State.Created,
          name: "Test Name",
          nickname: transition.data.nickname,
          age: transition.data.age,
        };
      },
      // Applies a nested transition so tests can check its `triggeredBy`.
      async createDouble(restate: RestateClient) {
        const [_, duplicateTransition] = await restate.user.transition.create();

        return {
          state: User.State.Created,
          name: "Test Name",
          duplicateTransition: duplicateTransition.id,
        };
      },
      async delete(restate: RestateClient, existing: User.Created) {
        return {
          ...existing,
          state: User.State.Deleted,
        };
      },
    },
    email: {
      async create(restate: RestateClient, transition: Email.Create) {
        return {
          state: Email.State.Created,
          userId: transition.data.userId,
          subject: transition.data.subject,
        };
      },
      async send(
        restate: RestateClient,
        email: Email.Created,
        transition: Email.Send
      ) {
        return {
          ...email,
          state: Email.State.Sent,
        };
      },
    },
    typesTest: {
      async create(restate: RestateClient, transition: TypesTest.Create) {
        return {
          state: TypesTest.State.Created,
          ...transition.data,
        };
      },
    },
    stateWithNonNullableField: {
      async create(
        restate: RestateClient,
        transition: StateWithNonNullableField.Create
      ) {
        return {
          state: StateWithNonNullableField.State.Created,
        };
      },

      async finish(
        restate: RestateClient,
        original: StateWithNonNullableField.Created,
        transition: StateWithNonNullableField.Finish
      ) {
        return {
          state: StateWithNonNullableField.State.Finished,
          ...transition.data,
        };
      },
    },
  },
  consumers: [
    User.createConsumer({
      name: "SendEmailOnUserCreation",
      transition: User.Transition.Create,
      handler: async (restate, user, transition) => {
        const [email] = await restate.email.transition.create({
          data: {
            userId: transition.objectId,
            subject: "Welcome!",
          },
        });
        console.log("[SendEmailOnUserCreation] Created email", email);
      },
    }),
    Email.createConsumer({
      name: "SendEmail",
      transition: Email.Transition.Create,
      handler: async (restate, email, transition) => {
        const [sentEmail] = await restate.email.transition.send({
          object: email.id,
        });
        console.log("[SendEmail] Updated email to sent", sentEmail);
      },
    }),
  ],
};

export default project;

--------------------------------------------------------------------------------
/tsconfig.json:
--------------------------------------------------------------------------------
{
  "ts-node": {
    "esm": true
  },
  "compilerOptions": {
    "esModuleInterop": true
  }
}
--------------------------------------------------------------------------------