├── .gitignore ├── README.md ├── frontend ├── .eslintrc.json ├── README.md ├── bun.lockb ├── components.json ├── drizzle.config.ts ├── next.config.mjs ├── package.json ├── postcss.config.mjs ├── src │ ├── app │ │ ├── about │ │ │ └── page.tsx │ │ ├── api │ │ │ ├── auth │ │ │ │ └── [...nextauth] │ │ │ │ │ └── route.ts │ │ │ ├── stories │ │ │ │ └── route.ts │ │ │ └── user │ │ │ │ └── credits │ │ │ │ └── route.ts │ │ ├── dashboard │ │ │ ├── layout.tsx │ │ │ ├── loading.tsx │ │ │ ├── page.tsx │ │ │ └── story │ │ │ │ ├── [slug] │ │ │ │ ├── loading.tsx │ │ │ │ └── page.tsx │ │ │ │ └── choice │ │ │ │ └── [node] │ │ │ │ ├── loading.tsx │ │ │ │ └── page.tsx │ │ ├── favicon.ico │ │ ├── fonts │ │ │ ├── GeistMonoVF.woff │ │ │ └── GeistVF.woff │ │ ├── globals.css │ │ ├── layout.tsx │ │ ├── loading.tsx │ │ ├── login │ │ │ └── page.tsx │ │ └── page.tsx │ ├── auth.ts │ ├── components │ │ ├── dashboard │ │ │ ├── sample-stories.tsx │ │ │ ├── story-list.tsx │ │ │ ├── terminal-input.tsx │ │ │ └── username-input.tsx │ │ ├── header │ │ │ ├── sign-out.tsx │ │ │ └── user-info.tsx │ │ ├── landing-page │ │ │ ├── features.tsx │ │ │ └── terminal.tsx │ │ ├── login │ │ │ └── terminal-login.tsx │ │ ├── navigation │ │ │ ├── navigation-link.tsx │ │ │ ├── navigation-progress-provider.tsx │ │ │ └── navigation-spinner.tsx │ │ ├── node │ │ │ ├── audio-player.tsx │ │ │ ├── choice-interface.tsx │ │ │ └── terminal-choice.tsx │ │ ├── story │ │ │ ├── reset-story.tsx │ │ │ ├── story-choice.tsx │ │ │ └── story-visibility-toggle.tsx │ │ ├── ui │ │ │ ├── button.tsx │ │ │ ├── card.tsx │ │ │ ├── dialog.tsx │ │ │ ├── hover-card.tsx │ │ │ ├── input.tsx │ │ │ ├── label.tsx │ │ │ ├── progress.tsx │ │ │ ├── scroll-area.tsx │ │ │ ├── slider.tsx │ │ │ ├── switch.tsx │ │ │ ├── toast.tsx │ │ │ ├── toaster.tsx │ │ │ └── toggle.tsx │ │ └── wrapper.tsx │ ├── constants │ │ ├── prompts.ts │ │ └── stories.json │ ├── db │ │ ├── db.ts │ │ └── schema.ts │ ├── hooks │ │ ├── use-toast.ts │ │ └── useTypewriter.ts │ ├── lib │ │ ├── 
login-server.ts │ │ ├── story.ts │ │ ├── tree.ts │ │ ├── user.ts │ │ └── utils.ts │ ├── middleware.ts │ ├── providers │ │ └── ReactQueryProvider.tsx │ └── types │ │ └── next-auth.d.ts ├── tailwind.config.ts └── tsconfig.json ├── modal ├── audio.py ├── download.py ├── generate_audio.py ├── images.py ├── pyproject.toml ├── tests │ ├── README.md │ ├── __init__.py │ ├── test_audio.py │ └── test_images.py ├── uv.lock └── workflows │ └── flux.json └── restate ├── .dockerignore ├── .gitignore ├── Dockerfile ├── fly.toml ├── helpers ├── db.py ├── env.py ├── s3.py └── story.py ├── hypercorn-config.toml ├── main.py ├── pyproject.toml ├── story.ipynb ├── tests ├── README.md ├── __init__.py ├── test_db.py ├── test_main.py ├── test_s3.py └── test_story.py └── uv.lock /.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 2 | 3 | # dependencies 4 | /node_modules 5 | /.pnp 6 | .pnp.js 7 | .yarn/install-state.gz 8 | 9 | # testing 10 | /coverage 11 | 12 | # next.js 13 | /.next/ 14 | /out/ 15 | 16 | # production 17 | /build 18 | 19 | # misc 20 | .DS_Store 21 | *.pem 22 | 23 | # debug 24 | npm-debug.log* 25 | yarn-debug.log* 26 | yarn-error.log* 27 | 28 | # local env files 29 | .env*.local 30 | .env 31 | .envrc 32 | 33 | # vercel 34 | .vercel 35 | 36 | # typescript 37 | *.tsbuildinfo 38 | next-env.d.ts 39 | 40 | 41 | frontend/build 42 | frontend/node_modules 43 | frontend/.next/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CYOA 2 | 3 | Choose Your Own Adventure is a sample repo that shows how you can chain together complex workflows using restate. You can try it at [our hosted version](https://cyoa.dev) if you're curious about how it works. 
4 | 5 | In this case, we're using Restate to create the individual choices for a choose-your-own-adventure game and to generate unique images for each story choice. Instead of manually creating each service and putting a message queue between them (before adding more complex services to ensure that retries and error handling are properly implemented), we can use Restate to create a single workflow that can be easily scaled and modified. 6 | 7 | This project consists of three main components: 8 | 9 | 1. **Frontend**: A simple Next.js application that uses Auth.js for authentication and Turso for its database. 10 | 2. **Modal Services**: With Modal, we can create a single endpoint for image generation that can be easily scaled and modified. 11 | 3. **Restate Story Service**: We use the Restate Python SDK to create a workflow that handles our story generation process. 12 | 13 | ## Prerequisites 14 | 15 | In order to run this project, you'll need the following: 16 | 17 | - Python 3.11 18 | - Node.js 18 or higher 19 | - An AWS account with access to S3. 20 | - A Turso account with a database. 21 | 22 | ## Instructions 23 | 24 | In this section, we'll go through the steps to run the project. We'll do so in four main parts. 25 | 26 | First, we'll set up our Restate server and our Restate Cloud environment. Next, we'll set up our Modal services and get our deployment ready. Then, we'll set up our Restate story service and get our workflow ready. Lastly, we'll start up our Next.js application and get it ready to use. 27 | 28 | 1. Start by cloning the repository: 29 | 30 | ```bash 31 | git clone https://github.com/essencevc/cyoa 32 | ``` 33 | 34 | ### Setting Up Restate Cloud 35 | 36 | 1. Next, install the Restate server.
If you're not using a Mac, you can install the Restate server and CLI by following the instructions [here](https://docs.restate.dev/get_started/quickstart/); on macOS, install them with Homebrew: 37 | 38 | ```bash 39 | brew install restatedev/tap/restate-server 40 | brew install restatedev/tap/restate 41 | ``` 42 | 43 | 2. Once you've done so, connect it to your Restate Cloud account. If you don't have one, you can create an account [here](https://restate.dev/cloud/). 44 | 45 | ```bash 46 | restate cloud login 47 | 48 | restate cloud env configure 49 | restate cloud env tunnel 50 | ``` 51 | 52 | Once you've run these commands, you'll have a local Restate server with an exposed endpoint that other services can send requests to over the internet. 53 | 54 | You'll then see something like this: 55 | 56 | ```text 57 | restate cloud environment tunnel - Press Ctrl-C to exit 58 | 59 | 💡 Source → 💡 Destination 60 | 🤝 tunnel://menu:9082 → 🏠 http://localhost:9080 61 | 🏠 http://localhost:8080 → 🌎 public ingress 62 | 🏠 http://localhost:9070 → 🔒 admin API 63 | ``` 64 | 65 | ### Setting Up Modal 66 | 67 | Before continuing, you'll need a Modal account and an AWS account. If you don't have them, you can create them [here](https://modal.com/signup) and [here](https://aws.amazon.com/free/), respectively. 68 | 69 | You'll need the following credentials so that we can upload the generated images to your S3 bucket. 70 | 71 | ```bash 72 | AWS_ACCESS_KEY_ID 73 | AWS_SECRET_ACCESS_KEY 74 | AWS_REGION 75 | ``` 76 | 77 | 1. First, create a virtual environment for Modal and install the project's dependencies: 78 | 79 | ```bash 80 | cd modal 81 | uv venv 82 | source .venv/bin/activate && uv sync 83 | modal token new # This will create a new token for you 84 | ``` 85 | 86 | 2.
Once you've done so, deploy the image generation service by running: 87 | 88 | ```bash 89 | modal deploy images.py 90 | ``` 91 | 92 | This kicks off a Modal deployment, and you should see output like this: 93 | 94 | ```text 95 | (modal) ivanleo@Ivans-MacBook-Pro ~/D/c/c/modal (update-docs)> modal deploy images.py 96 | 97 | ✓ Created objects. 98 | ├── 🔨 Created mount workflows/flux.json 99 | ├── 🔨 Created mount /Users/ivanleo/Documents/coding/cyoa/modal/images.py 100 | ├── 🔨 Created function ComfyUI.*. 101 | └── 🔨 Created web endpoint for ComfyUI.api => 102 | 103 | ✓ App deployed in 3.982s! 🎉 104 | 105 | View Deployment: 106 | ``` 107 | 108 | Make sure to note down the endpoint URL printed after `Created web endpoint for ComfyUI.api =>`, as we'll use it in our Restate story service. 109 | 110 | ### Setting Up Restate Story Service 111 | 112 | Now that our Modal web endpoint is deployed and our Restate Cloud environment is set up, we can set up our Restate story service. 113 | 114 | 1. First, make sure that you're in the `restate` directory and fill in the following environment variables in a `.env` file there. We read these variables using a `pydantic` model. 115 | 116 | ```bash 117 | GOOGLE_API_KEY= # Gemini API Key 118 | DB_URL= # Turso Database URL 119 | DB_TOKEN= # Turso Database Token 120 | IMAGE_ENDPOINT= # Modal Image Generation Service Endpoint 121 | AWS_ACCESS_KEY_ID= # AWS Access Key ID 122 | AWS_SECRET_ACCESS_KEY= 123 | AWS_REGION= 124 | ``` 125 | 126 | 2. Then start the Hypercorn server with our configuration by running: 127 | 128 | ```bash 129 | python -m hypercorn -c hypercorn-config.toml main:app 130 | ``` 131 | 132 | Once you've done so, you'll see the server start up locally, as shown below.
133 | 134 | ```text 135 | [2025-01-13 01:10:04 +0800] [11316] [INFO] Running on http://0.0.0.0:9080 (CTRL + C to quit) 136 | [2025-01-13 01:10:04 +0800] [11322] [WARNING] ASGI Framework Lifespan error, continuing without Lifespan support 137 | [2025-01-13 01:10:04 +0800] [11322] [INFO] Running on http://0.0.0.0:9080 (CTRL + C to quit) 138 | ``` 139 | 140 | This means that our Restate story service is now running. Next, we need to register it with our Restate Cloud environment by running the following command, replacing `<SERVICE_URL>` with the URL at which your Restate environment can reach the service (with the tunnel set up above, this is the address that forwards to `http://localhost:9080`): 141 | 142 | ```bash 143 | restate deployments register <SERVICE_URL> --force 144 | ``` 145 | 146 | This registers our story service with the Restate Cloud environment so that it can be invoked from there. 147 | 148 | ### Setting Up the Next.js Application 149 | 150 | Now that our Restate story service is ready, we can set up our Next.js application. 151 | 152 | 1. First, make sure that you're in the `frontend` directory and run the following command to install the dependencies: 153 | 154 | ```bash 155 | bun install 156 | ``` 157 | 158 | 2. Next, fill in the following environment variables in a `.env` file in the `frontend` directory. 159 | 160 | ``` 161 | AUTH_SECRET= 162 | AUTH_GOOGLE_ID= 163 | AUTH_GOOGLE_SECRET= 164 | TURSO_CONNECTION_URL= 165 | TURSO_AUTH_TOKEN= 166 | RESTATE_ENDPOINT= 167 | RESTATE_TOKEN= 168 | ``` 169 | 170 | Once you've done so, you can start the Next.js application by running the following command: 171 | 172 | ```bash 173 | bun dev 174 | ``` 175 | 176 | With this, the application should be running locally, and you should be able to start creating your own stories.
177 | -------------------------------------------------------------------------------- /frontend/.eslintrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": ["next/core-web-vitals", "next/typescript"], 3 | "rules": { 4 | "@next/next/no-img-element": "off", 5 | "react-hooks/exhaustive-deps": "off" 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /frontend/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | This is the Next.js frontend for the CYOA website. 4 | 5 | ## Environment Variables 6 | 7 | The following environment variables are required: 8 | 9 | ``` 10 | AUTH_SECRET= 11 | AUTH_GOOGLE_ID= 12 | AUTH_GOOGLE_SECRET= 13 | TURSO_CONNECTION_URL= 14 | TURSO_AUTH_TOKEN= 15 | RESTATE_ENDPOINT= 16 | RESTATE_TOKEN= 17 | ``` 18 | 19 | ## Database 20 | 21 | The database is hosted on [Turso](https://turso.tech/), which provides a SQLite-compatible database using `libsql`. Note that `drizzle.config.ts` loads its credentials from `.env.local`, so `TURSO_CONNECTION_URL` and `TURSO_AUTH_TOKEN` must be set there for the `drizzle-kit` commands below. To get started, run the following command to generate SQL migrations from the schema in `src/db/schema.ts`: 22 | 23 | ```bash 24 | npx drizzle-kit generate 25 | ``` 26 | 27 | Then, run the following command to apply the generated migrations to the database: 28 | 29 | ```bash 30 | npx drizzle-kit migrate 31 | ``` 32 | -------------------------------------------------------------------------------- /frontend/bun.lockb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/essencevc/cyoa/2c991aa4bdbcc8c09f40017c42d9f8c0f87eb2b0/frontend/bun.lockb -------------------------------------------------------------------------------- /frontend/components.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://ui.shadcn.com/schema.json", 3 | "style": "new-york", 4 | "rsc": true, 5 | "tsx": true, 6 | "tailwind": { 7 | "config": "tailwind.config.ts", 8 | "css": "src/app/globals.css", 9 | "baseColor": "neutral",
10 | "cssVariables": true, 11 | "prefix": "" 12 | }, 13 | "aliases": { 14 | "components": "@/components", 15 | "utils": "@/lib/utils", 16 | "ui": "@/components/ui", 17 | "lib": "@/lib", 18 | "hooks": "@/hooks" 19 | }, 20 | "iconLibrary": "lucide" 21 | } -------------------------------------------------------------------------------- /frontend/drizzle.config.ts: -------------------------------------------------------------------------------- 1 | import { config } from "dotenv"; 2 | import { defineConfig } from "drizzle-kit"; 3 | 4 | config({ path: ".env.local" }); 5 | 6 | export default defineConfig({ 7 | schema: "./src/db/schema.ts", 8 | out: "./migrations", 9 | dialect: "turso", 10 | dbCredentials: { 11 | url: process.env.TURSO_CONNECTION_URL!, 12 | authToken: process.env.TURSO_AUTH_TOKEN!, 13 | }, 14 | }); 15 | -------------------------------------------------------------------------------- /frontend/next.config.mjs: -------------------------------------------------------------------------------- 1 | /** @type {import('next').NextConfig} */ 2 | const nextConfig = { 3 | images: { 4 | domains: ['restate-story.s3.ap-southeast-1.amazonaws.com'], 5 | formats: ['image/avif', 'image/webp'], 6 | remotePatterns: [ 7 | { 8 | protocol: 'https', 9 | hostname: 'restate-story.s3.ap-southeast-1.amazonaws.com', 10 | pathname: '**', 11 | }, 12 | ], 13 | deviceSizes: [640, 750, 828, 1080, 1200, 1920, 2048, 3840], 14 | imageSizes: [16, 32, 48, 64, 96, 128, 256, 384], 15 | }, 16 | experimental: { 17 | optimizeCss: true, 18 | scrollRestoration: true, 19 | }, 20 | reactStrictMode: true, 21 | swcMinify: true, 22 | compress: true, 23 | }; 24 | 25 | export default nextConfig; 26 | -------------------------------------------------------------------------------- /frontend/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "frontend", 3 | "version": "0.1.0", 4 | "private": true, 5 | "scripts": { 6 | "dev": "next dev", 7 | 
"build": "next build", 8 | "start": "next start", 9 | "lint": "next lint" 10 | }, 11 | "dependencies": { 12 | "@libsql/client": "^0.14.0", 13 | "@radix-ui/react-dialog": "^1.1.4", 14 | "@radix-ui/react-hover-card": "^1.1.6", 15 | "@radix-ui/react-label": "^2.1.1", 16 | "@radix-ui/react-progress": "^1.1.1", 17 | "@radix-ui/react-scroll-area": "^1.2.2", 18 | "@radix-ui/react-slider": "^1.2.2", 19 | "@radix-ui/react-slot": "^1.1.1", 20 | "@radix-ui/react-switch": "^1.1.2", 21 | "@radix-ui/react-toast": "^1.2.4", 22 | "@radix-ui/react-toggle": "^1.1.1", 23 | "@tanstack/react-query": "^5.62.11", 24 | "@vercel/analytics": "^1.5.0", 25 | "class-variance-authority": "^0.7.1", 26 | "clsx": "^2.1.1", 27 | "critters": "^0.0.25", 28 | "dotenv": "^16.4.7", 29 | "drizzle-orm": "^0.38.3", 30 | "lucide-react": "^0.469.0", 31 | "motion": "^11.15.0", 32 | "next": "14.2.22", 33 | "next-auth": "^5.0.0-beta.25", 34 | "react": "^18", 35 | "react-dom": "^18", 36 | "tailwind-merge": "^2.6.0", 37 | "tailwindcss-animate": "^1.0.7" 38 | }, 39 | "devDependencies": { 40 | "@types/node": "^20", 41 | "@types/react": "^18", 42 | "@types/react-dom": "^18", 43 | "drizzle-kit": "^0.30.1", 44 | "eslint": "^8", 45 | "eslint-config-next": "14.2.22", 46 | "postcss": "^8", 47 | "tailwindcss": "^3.4.1", 48 | "typescript": "^5" 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /frontend/postcss.config.mjs: -------------------------------------------------------------------------------- 1 | /** @type {import('postcss-load-config').Config} */ 2 | const config = { 3 | plugins: { 4 | tailwindcss: {}, 5 | }, 6 | }; 7 | 8 | export default config; 9 | -------------------------------------------------------------------------------- /frontend/src/app/about/page.tsx: -------------------------------------------------------------------------------- 1 | import { Terminal } from "lucide-react"; 2 | import { NavigationLink } from "@/components/navigation/navigation-link"; 
3 | 4 | export default function AboutPage() { 5 | return ( 6 |
7 | {/* Header */} 8 |
9 |
10 |
11 |
12 | 13 | CYOA-OS v1.0 14 |
15 | 16 | [ABOUT] 17 | 18 |
19 |
20 | 24 | Home 25 | 26 | 30 | Go to App 31 | 32 |
33 |
34 |
35 | 36 | {/* About Content */} 37 |
38 |
39 |

40 | About CYOA-OS 41 |

42 | 43 |
44 |

45 | CYOA-OS is a modern choose-your-own-adventure platform that 46 | combines interactive storytelling with AI-powered narrative 47 | generation. 48 |

49 | 50 |

51 | Built With 52 |

53 |

54 | This website was built with{" "} 55 | 61 | Restate 62 | 63 | ,{" "} 64 | 70 | Modal 71 | 72 | , and{" "} 73 | 79 | Gemini 80 | 81 | , leveraging durable functions and serverless infrastructure to 82 | create a seamless interactive experience. 83 |

84 | 85 |

86 | Learn More 87 |

88 |

89 | Read about how this project was created in the Restate blog post:{" "} 90 | 96 | From Prompt to Adventures: Creating Games with LLMs and 97 | Restate's Durable Functions 98 | 99 | . 100 |

101 | 102 |
103 |

104 | Technology Stack 105 |

106 |
    107 |
  • Next.js for the frontend
  • 108 |
  • Restate for durable functions and state management
  • 109 |
  • Modal for serverless GPU infrastructure
  • 110 |
  • Google Gemini for AI-powered story generation
  • 111 |
112 |
113 |
114 |
115 |
116 | 117 | {/* Footer */} 118 | 169 |
170 | ); 171 | } 172 | -------------------------------------------------------------------------------- /frontend/src/app/api/auth/[...nextauth]/route.ts: -------------------------------------------------------------------------------- 1 | import { handlers } from "@/auth" 2 | export const { GET, POST } = handlers -------------------------------------------------------------------------------- /frontend/src/app/api/stories/route.ts: -------------------------------------------------------------------------------- 1 | import { NextResponse } from "next/server"; 2 | import { auth } from "@/auth"; 3 | import { db } from "@/db/db"; 4 | import { eq } from "drizzle-orm"; 5 | import { storiesTable, usersTable } from "@/db/schema"; 6 | 7 | export async function GET() { 8 | const session = await auth(); 9 | 10 | const userStories = await db 11 | .select({ 12 | id: storiesTable.id, 13 | title: storiesTable.title, 14 | author: usersTable.username, 15 | image_prompt: storiesTable.image_prompt, 16 | description: storiesTable.description, 17 | status: storiesTable.status, 18 | errorMessage: storiesTable.errorMessage, 19 | }) 20 | .from(storiesTable) 21 | .where(eq(storiesTable.userId, session?.user?.email ?? 
"")) 22 | .innerJoin(usersTable, eq(storiesTable.userId, usersTable.email)); 23 | 24 | return NextResponse.json(userStories); 25 | } 26 | -------------------------------------------------------------------------------- /frontend/src/app/api/user/credits/route.ts: -------------------------------------------------------------------------------- 1 | import { auth } from "@/auth"; 2 | import { db } from "@/db/db"; 3 | import { usersTable } from "@/db/schema"; 4 | 5 | import { eq } from "drizzle-orm"; 6 | import { NextResponse } from "next/server"; 7 | 8 | export async function GET() { 9 | const session = await auth(); 10 | 11 | if (!session?.user?.id) { 12 | return new NextResponse("Unauthorized", { status: 401 }); 13 | } 14 | 15 | const user = await db.query.usersTable.findFirst({ 16 | where: eq(usersTable.email, session.user.username as string), 17 | columns: { 18 | credits: true, 19 | }, 20 | }); 21 | 22 | return NextResponse.json(user?.credits ?? 0); 23 | } 24 | -------------------------------------------------------------------------------- /frontend/src/app/dashboard/layout.tsx: -------------------------------------------------------------------------------- 1 | import Link from "next/link"; 2 | import { Terminal } from "lucide-react"; 3 | 4 | import SignOut from "@/components/header/sign-out"; 5 | import { auth } from "@/auth"; 6 | import { redirect } from "next/navigation"; 7 | import ReactQueryProvider from "@/providers/ReactQueryProvider"; 8 | import UserInfo from "@/components/header/user-info"; 9 | import { SessionProvider } from "next-auth/react"; 10 | 11 | export const dynamic = "force-dynamic"; 12 | 13 | const Layout = async ({ children }: { children: React.ReactNode }) => { 14 | const session = await auth(); 15 | 16 | if (!session) { 17 | return redirect("/"); 18 | } 19 | 20 | return ( 21 | 22 | 23 |
24 |
25 |
26 |
27 |
28 | 29 | 30 | CYOA-OS v1.0 31 | 32 |
33 | 37 | [ABOUT] 38 | 39 |
40 |
41 | 42 |
43 | 44 |
45 |
46 |
47 |
48 |
49 |
{children}
50 |
51 | 52 | {/* Footer */} 53 | 104 |
105 |
106 |
107 | ); 108 | }; 109 | 110 | export default Layout; 111 | -------------------------------------------------------------------------------- /frontend/src/app/dashboard/loading.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | 3 | export default function DashboardLoading() { 4 | return ( 5 |
6 | {/* Terminal Input Loading */} 7 |
8 | 9 | {/* Sample Stories Loading */} 10 |
11 |
12 |
13 | {[...Array(3)].map((_, i) => ( 14 |
15 | ))} 16 |
17 |
18 | 19 | {/* User Stories Loading */} 20 |
21 |
22 |
23 | {[...Array(3)].map((_, i) => ( 24 |
25 | ))} 26 |
27 |
28 |
29 | ); 30 | } 31 | -------------------------------------------------------------------------------- /frontend/src/app/dashboard/page.tsx: -------------------------------------------------------------------------------- 1 | import { auth } from "@/auth"; 2 | import StoryList from "@/components/dashboard/story-list"; 3 | import { TerminalInput } from "@/components/dashboard/terminal-input"; 4 | import { db } from "@/db/db"; 5 | import { eq, inArray } from "drizzle-orm"; 6 | import { storiesTable, usersTable } from "@/db/schema"; 7 | import React, { Suspense } from "react"; 8 | import { UsernameInput } from "@/components/dashboard/username-input"; 9 | import { redirect } from "next/navigation"; 10 | import { revalidatePath } from "next/cache"; 11 | import SampleStories from "@/components/dashboard/sample-stories"; 12 | 13 | // Add revalidation time for Incremental Static Regeneration (ISR) 14 | export const revalidate = 30; // Revalidate every 30 seconds 15 | 16 | // Loading components for Suspense 17 | const SampleStoriesLoading = () => ( 18 |
19 |
20 |
21 | {[...Array(3)].map((_, i) => ( 22 |
23 | ))} 24 |
25 |
26 | ); 27 | 28 | const UserStoriesLoading = () => ( 29 |
30 |
31 |
32 | {[...Array(3)].map((_, i) => ( 33 |
34 | ))} 35 |
36 |
37 | ); 38 | 39 | // Function to get sample stories 40 | const getSampleStories = async () => { 41 | // Remove noStore() to enable caching 42 | 43 | const storyIds = process.env.NEXT_PUBLIC_EXAMPLE_STORIES?.split(","); 44 | if (!storyIds) { 45 | return []; 46 | } 47 | const stories = await db 48 | .select() 49 | .from(storiesTable) 50 | .where(inArray(storiesTable.id, storyIds)); 51 | return stories; 52 | }; 53 | 54 | // Function to get user stories 55 | const getUserStories = async (userEmail: string) => { 56 | // Remove noStore() to enable caching 57 | // User stories can be cached briefly as they don't change frequently 58 | 59 | return db 60 | .select({ 61 | id: storiesTable.id, 62 | title: storiesTable.title, 63 | author: usersTable.username, 64 | image: storiesTable.image_prompt, 65 | description: storiesTable.description, 66 | timestamp: storiesTable.timestamp, 67 | public: storiesTable.public, 68 | status: storiesTable.status, 69 | errorMessage: storiesTable.errorMessage, 70 | }) 71 | .from(storiesTable) 72 | .where(eq(storiesTable.userId, userEmail)) 73 | .innerJoin(usersTable, eq(storiesTable.userId, usersTable.email)); 74 | }; 75 | 76 | // Sample stories component with Suspense 77 | const SampleStoriesSection = async () => { 78 | const sampleStories = await getSampleStories(); 79 | return ; 80 | }; 81 | 82 | // User stories component with Suspense 83 | const UserStoriesSection = async ({ userEmail }: { userEmail: string }) => { 84 | const userStories = await getUserStories(userEmail); 85 | 86 | return ( 87 | 105 | ); 106 | }; 107 | 108 | const Dashboard = async () => { 109 | // Start auth check immediately but don't await it yet 110 | const sessionPromise = auth(); 111 | 112 | // Immediately check if we need to revalidate 113 | revalidatePath("/dashboard", "page"); 114 | 115 | // Now await the session 116 | const session = await sessionPromise; 117 | if (!session || !session.user || !session.user.email) { 118 | return redirect("/"); 119 | } 120 | 121 | 
// Start fetching user data immediately 122 | const dbUserPromise = db 123 | .select() 124 | .from(usersTable) 125 | .where(eq(usersTable.email, session.user.email)) 126 | .get(); 127 | 128 | // Await the user data 129 | const dbUser = await dbUserPromise; 130 | 131 | if (!dbUser || !dbUser.username) { 132 | return ; 133 | } 134 | 135 | return ( 136 |
137 | 138 | 139 | }> 140 | 141 | 142 | 143 | }> 144 | 145 | 146 |
147 | ); 148 | }; 149 | 150 | export default Dashboard; 151 | -------------------------------------------------------------------------------- /frontend/src/app/dashboard/story/[slug]/loading.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | 3 | export default function StoryLoading() { 4 | return ( 5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 | {[...Array(4)].map((_, i) => ( 35 |
36 | ))} 37 |
38 |
39 | ); 40 | } 41 | -------------------------------------------------------------------------------- /frontend/src/app/dashboard/story/choice/[node]/loading.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | 3 | export default function ChoiceLoading() { 4 | return ( 5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 | ); 17 | } 18 | -------------------------------------------------------------------------------- /frontend/src/app/dashboard/story/choice/[node]/page.tsx: -------------------------------------------------------------------------------- 1 | import React, { Suspense } from "react"; 2 | 3 | import ChoiceInterface from "@/components/node/choice-interface"; 4 | import { db } from "@/db/db"; 5 | import { storiesTable, storyChoicesTable } from "@/db/schema"; 6 | import { eq } from "drizzle-orm"; 7 | import { auth } from "@/auth"; 8 | import { redirect } from "next/navigation"; 9 | import { NavigationLink } from "@/components/navigation/navigation-link"; 10 | import TerminalChoice from "@/components/node/terminal-choice"; 11 | 12 | // Replace force-dynamic with controlled revalidation 13 | // export const dynamic = "force-dynamic"; 14 | export const revalidate = 30; // Revalidate every 30 seconds 15 | 16 | // Server action to mark choice as explored 17 | // This separates the data mutation from the rendering path 18 | async function markChoiceAsExplored(nodeId: string) { 19 | "use server"; 20 | 21 | await db 22 | .update(storyChoicesTable) 23 | .set({ explored: 1 }) 24 | .where(eq(storyChoicesTable.id, nodeId)) 25 | .execute(); 26 | } 27 | 28 | // Loading component for Suspense 29 | const ChoiceLoading = () => ( 30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 | ); 42 | 43 | // Function to get choice data 44 | const getChoiceData = async (nodeId: string) => { 45 | // Remove noStore() to enable caching 46 | 47 | const choice = await db.query.storyChoicesTable.findFirst({ 48 | where: eq(storyChoicesTable.id, nodeId), 49 | }); 50 | 51 | if (!choice) { 52 | return { choice: null, story: null, children: [] }; 53 | } 54 | 55 | // Fetch story and children in parallel for better performance 56 | const [story, children] = await Promise.all([ 57 | db.query.storiesTable.findFirst({ 58 | where: eq(storiesTable.id, choice.storyId as string), 59 | }), 60 | db.query.storyChoicesTable.findMany({ 61 | where: eq(storyChoicesTable.parentId, nodeId), 62 | }), 63 | ]); 64 | 65 | return { choice, story, children }; 66 | }; 67 | 68 | // Choice content component to be wrapped in Suspense 69 | const ChoiceContent = async ({ nodeId }: { nodeId: string }) => { 70 | // Start fetching choice data immediately, before auth check 71 | const choiceDataPromise = getChoiceData(nodeId); 72 | 73 | // Perform auth check in parallel 74 | const userObject = await auth(); 75 | if (!userObject) { 76 | return redirect("/"); 77 | } 78 | 79 | // Now await the choice data that was already being fetched 80 | const { choice, story, children } = await choiceDataPromise; 81 | 82 | if (!choice || !story) { 83 | return ( 84 |
85 |
86 | 87 | root@cyoa-os:~$ 88 | 89 | 90 | cyoa make-choice --id {nodeId} --metadata 91 | 92 |
93 |
94 | [ERROR] Story not found in database 95 |
96 |
97 | 101 | ← Back to dashboard 102 | 103 |
104 |
105 | ); 106 | } 107 | 108 | const userId = userObject["user"]["email"]; 109 | const isUserStory = userId === story.userId; 110 | const isPublicStory = story.public; 111 | 112 | if (!isUserStory && !isPublicStory) { 113 | return redirect("/"); 114 | } 115 | 116 | // Use the server action to mark the choice as explored 117 | // This happens in parallel with rendering 118 | if (isUserStory) { 119 | markChoiceAsExplored(nodeId); 120 | } 121 | 122 | if (choice.isTerminal) { 123 | return ; 124 | } 125 | 126 | return ( 127 | 135 | ); 136 | }; 137 | 138 | const NodePage = async ({ params }: { params: { node: string } }) => { 139 | return ( 140 | }> 141 | 142 | 143 | ); 144 | }; 145 | 146 | export default NodePage; 147 | -------------------------------------------------------------------------------- /frontend/src/app/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/essencevc/cyoa/2c991aa4bdbcc8c09f40017c42d9f8c0f87eb2b0/frontend/src/app/favicon.ico -------------------------------------------------------------------------------- /frontend/src/app/fonts/GeistMonoVF.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/essencevc/cyoa/2c991aa4bdbcc8c09f40017c42d9f8c0f87eb2b0/frontend/src/app/fonts/GeistMonoVF.woff -------------------------------------------------------------------------------- /frontend/src/app/fonts/GeistVF.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/essencevc/cyoa/2c991aa4bdbcc8c09f40017c42d9f8c0f87eb2b0/frontend/src/app/fonts/GeistVF.woff -------------------------------------------------------------------------------- /frontend/src/app/globals.css: -------------------------------------------------------------------------------- 1 | @tailwind base; 2 | @tailwind components; 3 | @tailwind utilities; 4 | 5 | body { 6 | font-family: Arial, 
Helvetica, sans-serif; 7 | } 8 | 9 | @layer utilities { 10 | .text-balance { 11 | text-wrap: balance; 12 | } 13 | } 14 | 15 | @layer base { 16 | :root { 17 | --background: 0 0% 100%; 18 | --foreground: 0 0% 3.9%; 19 | --card: 0 0% 100%; 20 | --card-foreground: 0 0% 3.9%; 21 | --popover: 0 0% 100%; 22 | --popover-foreground: 0 0% 3.9%; 23 | --primary: 0 0% 9%; 24 | --primary-foreground: 0 0% 98%; 25 | --secondary: 0 0% 96.1%; 26 | --secondary-foreground: 0 0% 9%; 27 | --muted: 0 0% 96.1%; 28 | --muted-foreground: 0 0% 45.1%; 29 | --accent: 0 0% 96.1%; 30 | --accent-foreground: 0 0% 9%; 31 | --destructive: 0 84.2% 60.2%; 32 | --destructive-foreground: 0 0% 98%; 33 | --border: 0 0% 89.8%; 34 | --input: 0 0% 89.8%; 35 | --ring: 0 0% 3.9%; 36 | --chart-1: 12 76% 61%; 37 | --chart-2: 173 58% 39%; 38 | --chart-3: 197 37% 24%; 39 | --chart-4: 43 74% 66%; 40 | --chart-5: 27 87% 67%; 41 | --radius: 0.5rem; 42 | } 43 | .dark { 44 | --background: 0 0% 3.9%; 45 | --foreground: 0 0% 98%; 46 | --card: 0 0% 3.9%; 47 | --card-foreground: 0 0% 98%; 48 | --popover: 0 0% 3.9%; 49 | --popover-foreground: 0 0% 98%; 50 | --primary: 0 0% 98%; 51 | --primary-foreground: 0 0% 9%; 52 | --secondary: 0 0% 14.9%; 53 | --secondary-foreground: 0 0% 98%; 54 | --muted: 0 0% 14.9%; 55 | --muted-foreground: 0 0% 63.9%; 56 | --accent: 0 0% 14.9%; 57 | --accent-foreground: 0 0% 98%; 58 | --destructive: 0 62.8% 30.6%; 59 | --destructive-foreground: 0 0% 98%; 60 | --border: 0 0% 14.9%; 61 | --input: 0 0% 14.9%; 62 | --ring: 0 0% 83.1%; 63 | --chart-1: 220 70% 50%; 64 | --chart-2: 160 60% 45%; 65 | --chart-3: 30 80% 55%; 66 | --chart-4: 280 65% 60%; 67 | --chart-5: 340 75% 55%; 68 | } 69 | } 70 | 71 | @layer base { 72 | * { 73 | @apply border-border; 74 | } 75 | body { 76 | @apply bg-background text-foreground; 77 | } 78 | } 79 | 80 | @keyframes scanline { 81 | 0% { 82 | background-position: 0 0; 83 | } 84 | 100% { 85 | background-position: 0 100%; 86 | } 87 | } 88 | 89 | @keyframes shine { 90 | 0% { 91 | 
background-position: 200% 0; 92 | } 93 | 100% { 94 | background-position: -200% 0; 95 | } 96 | } 97 | 98 | @keyframes spinner { 99 | 0% { 100 | content: "⠋"; 101 | } 102 | 10% { 103 | content: "⠙"; 104 | } 105 | 20% { 106 | content: "⠹"; 107 | } 108 | 30% { 109 | content: "⠸"; 110 | } 111 | 40% { 112 | content: "⠼"; 113 | } 114 | 50% { 115 | content: "⠴"; 116 | } 117 | 60% { 118 | content: "⠦"; 119 | } 120 | 70% { 121 | content: "⠧"; 122 | } 123 | 80% { 124 | content: "⠇"; 125 | } 126 | 90% { 127 | content: "⠏"; 128 | } 129 | 100% { 130 | content: "⠋"; 131 | } 132 | } 133 | 134 | .animate-scanline { 135 | animation: scanline 8s linear infinite; 136 | } 137 | 138 | .animate-shine { 139 | animation: shine 3s linear infinite; 140 | } 141 | 142 | .spinner::after { 143 | content: "⠋"; 144 | animation: spinner 1s steps(10) infinite; 145 | } 146 | 147 | -------------------------------------------------------------------------------- /frontend/src/app/layout.tsx: -------------------------------------------------------------------------------- 1 | import type { Metadata } from "next"; 2 | import localFont from "next/font/local"; 3 | import "./globals.css"; 4 | import { Toaster } from "@/components/ui/toaster"; 5 | import { NavigationProgressProvider } from "@/components/navigation/navigation-progress-provider"; 6 | import { Analytics } from "@vercel/analytics/react"; 7 | const geistSans = localFont({ 8 | src: "./fonts/GeistVF.woff", 9 | variable: "--font-geist-sans", 10 | weight: "100 900", 11 | preload: true, 12 | display: "swap", // Use swap to prevent FOIT (Flash of Invisible Text) 13 | }); 14 | const geistMono = localFont({ 15 | src: "./fonts/GeistMonoVF.woff", 16 | variable: "--font-geist-mono", 17 | weight: "100 900", 18 | preload: true, 19 | display: "swap", // Use swap to prevent FOIT 20 | }); 21 | 22 | export const metadata: Metadata = { 23 | title: "CYOA", 24 | description: 25 | "Play through entire choose your own adventure stories, generated by AI", 26 | }; 27 | 
28 | export default function RootLayout({ 29 | children, 30 | }: Readonly<{ 31 | children: React.ReactNode; 32 | }>) { 33 | return ( 34 | 35 | 36 | {/* Add preload hints for critical resources */} 37 | 42 | {/* Preconnect to S3 for faster image loading */} 43 | 48 | 49 | 52 | {children} 53 | 54 | 55 | 56 | 57 | ); 58 | } 59 | -------------------------------------------------------------------------------- /frontend/src/app/loading.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | 3 | export default function Loading() { 4 | return ( 5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 | root@cyoa-os:~$ 14 | loading content... 15 |
16 |
17 |
18 |
19 | ); 20 | } 21 | -------------------------------------------------------------------------------- /frontend/src/app/login/page.tsx: -------------------------------------------------------------------------------- 1 | import TerminalLogin from "@/components/login/terminal-login"; 2 | import React from "react"; 3 | 4 | const Login = () => { 5 | return ( 6 |
7 |
8 | 9 |
10 |
11 | ); 12 | }; 13 | 14 | export default Login; 15 | -------------------------------------------------------------------------------- /frontend/src/app/page.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | import { Terminal } from "lucide-react"; 3 | import TerminalAnimation from "@/components/landing-page/terminal"; 4 | import stories from "@/constants/stories.json"; 5 | import Features from "@/components/landing-page/features"; 6 | import { NavigationLink } from "@/components/navigation/navigation-link"; 7 | 8 | export default function LandingPage() { 9 | return ( 10 |
11 | {/* Header */} 12 |
13 |
14 |
15 |
16 | 17 | CYOA-OS v1.0 18 |
19 | 23 | [ABOUT] 24 | 25 |
26 | 30 | Go to App 31 | 32 |
33 |
34 | 35 | {/* Terminal Screen */} 36 | 37 | 38 | {/* Features */} 39 | 40 | 41 | {/* CTA Section */} 42 |
43 |
44 |

45 | Ready to Begin Your Adventure? 46 |

47 | 48 |

49 | Join our community of storytellers and embark on endless cyberpunk 50 | journeys. 51 |

52 |
53 | 57 | Start Your Adventure 58 | 59 |
60 |
61 |
62 | 63 | {/* Footer */} 64 | 115 |
116 | ); 117 | } 118 | -------------------------------------------------------------------------------- /frontend/src/auth.ts: -------------------------------------------------------------------------------- 1 | import NextAuth from "next-auth"; 2 | import Google from "next-auth/providers/google"; 3 | import { usersTable } from "./db/schema"; 4 | import { eq } from "drizzle-orm"; 5 | import { db } from "./db/db"; 6 | import Credentials from "next-auth/providers/credentials"; 7 | 8 | const isProduction = process.env.VERCEL_ENV === "production"; 9 | 10 | const mockGoogleProvider = Credentials({ 11 | id: "google", 12 | name: "Mock Google", 13 | credentials: {}, 14 | async authorize() { 15 | // Return mock user data 16 | return { 17 | id: "mock_id", 18 | name: "Test User", 19 | email: "test@example.com", 20 | image: "https://example.com/placeholder.png", 21 | }; 22 | }, 23 | }); 24 | 25 | const { handlers, signIn, signOut, auth } = NextAuth({ 26 | providers: [isProduction ? Google : mockGoogleProvider], 27 | callbacks: { 28 | signIn: async ({ user }) => { 29 | // We always create a user in the database if the user doesn't exist 30 | const dbUser = await db 31 | .select() 32 | .from(usersTable) 33 | .where(eq(usersTable.email, user.email!)); 34 | 35 | if (dbUser.length === 0) { 36 | // No Username for now 37 | await db.insert(usersTable).values({ 38 | email: user.email!, 39 | credits: 3, 40 | }); 41 | } 42 | return true; 43 | }, 44 | session: async ({ session }) => { 45 | if (!session.user || !session.user.email) { 46 | throw new Error("Session object is missing user or email"); 47 | } 48 | 49 | const user = await db 50 | .select() 51 | .from(usersTable) 52 | .where(eq(usersTable.email, session.user.email)) 53 | .get(); 54 | 55 | if (!user) { 56 | throw new Error("User not found"); 57 | } 58 | 59 | // It will be '' the first time the user logs in and set in the database thereafter once we prompt the user 60 | session.user.username = user.username || ""; 61 | 
session.user.credits = user.credits || 0; 62 | session.user.isAdmin = user.isAdmin || false; 63 | return session; 64 | }, 65 | }, 66 | }); 67 | 68 | export { handlers, signIn, signOut, auth }; 69 | -------------------------------------------------------------------------------- /frontend/src/components/dashboard/sample-stories.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import React from "react"; 4 | import { SelectStory } from "@/db/schema"; 5 | import { useNavigationProgress } from "../navigation/navigation-progress-provider"; 6 | import { useRouter } from "next/navigation"; 7 | 8 | type props = { 9 | stories: SelectStory[]; 10 | }; 11 | 12 | const SampleStories = ({ stories }: props) => { 13 | const { startNavigation } = useNavigationProgress(); 14 | const router = useRouter(); 15 | 16 | const handleStoryClick = (storyId: string) => { 17 | startNavigation(); 18 | router.push(`/dashboard/story/${storyId}`); 19 | }; 20 | 21 | return ( 22 |
23 |
24 | root@cyoa-os:~$ 25 | 26 | cyoa list-stories --filter sample-stories 27 | 28 |
29 |
30 |
31 |
32 | Showing {stories?.length} sample stories 33 |
34 | {stories?.map((story) => ( 35 |
handleStoryClick(story.id)} 39 | > 40 |
41 |
42 | {story.title 47 |
48 |
49 |

{story.title}

50 |

51 | {story.description} 52 |

53 |
54 |
55 |
56 | ))} 57 |
58 |
59 |
60 | ); 61 | }; 62 | 63 | export default SampleStories; 64 | -------------------------------------------------------------------------------- /frontend/src/components/dashboard/username-input.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import { useState } from "react"; 4 | import { useRouter } from "next/navigation"; 5 | import { validateUsername } from "@/lib/user"; 6 | 7 | export function UsernameInput() { 8 | const [username, setUsername] = useState(""); 9 | const [message, setMessage] = useState(""); 10 | const router = useRouter(); 11 | 12 | const handleSubmit = async (e: React.FormEvent) => { 13 | e.preventDefault(); 14 | const result = await validateUsername(username); 15 | setMessage(result.message); 16 | 17 | if (result.success) { 18 | setTimeout(() => { 19 | router.refresh(); 20 | }, 2000); // Refresh after 2 seconds to show the success message 21 | } 22 | }; 23 | 24 | return ( 25 |
26 |
27 |
28 | root@cyoa-os:~$ 29 | setUsername(e.target.value)} 33 | className="flex-1 bg-transparent text-green-400 outline-none border-none" 34 | placeholder="Enter username..." 35 | spellCheck="false" 36 | autoComplete="off" 37 | autoCapitalize="off" 38 | autoCorrect="off" 39 | autoFocus 40 | /> 41 |
42 |
43 | {message && ( 44 |
49 | {message} 50 |
51 | )} 52 |
53 | ); 54 | } 55 | -------------------------------------------------------------------------------- /frontend/src/components/header/sign-out.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import { Button } from "@/components/ui/button"; 4 | import { signOutWithGoogle } from "@/lib/login-server"; 5 | 6 | const SignOut = () => { 7 | return ( 8 | 14 | ); 15 | }; 16 | 17 | export default SignOut; 18 | -------------------------------------------------------------------------------- /frontend/src/components/header/user-info.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import { useQuery } from "@tanstack/react-query"; 4 | import { Session } from "next-auth"; 5 | 6 | async function fetchUserCredits() { 7 | const response = await fetch("/api/user/credits"); 8 | if (!response.ok) { 9 | throw new Error("Failed to fetch credits"); 10 | } 11 | return response.json(); 12 | } 13 | 14 | export default function UserInfo({ session }: { session: Session }) { 15 | const { data: credits } = useQuery({ 16 | queryKey: ["userCredits"], 17 | queryFn: fetchUserCredits, 18 | staleTime: Infinity, // Never refresh automatically 19 | initialData: session.user.credits, 20 | }); 21 | 22 | return ( 23 |
24 |
25 | {session.user?.username || "Anonymous User"} 26 |
27 |
28 | {session.user?.email} 29 |
30 | {credits !== undefined && ( 31 |
32 | Credits: {credits.toLocaleString()} 33 |
34 | )} 35 |
36 | ); 37 | } 38 | -------------------------------------------------------------------------------- /frontend/src/components/landing-page/features.tsx: -------------------------------------------------------------------------------- 1 | export default function TerminalFeatureList() { 2 | return ( 3 |
4 |
5 |
6 | 7 | root@cyoa-os:~$ ./show-features --interactive 8 | 9 |
10 | 11 |
12 |
13 |
14 | $ 15 |
16 |
17 | 18 | --create-story 19 | 20 | [primary] 21 |
22 |

23 | Create branching narratives where every choice matters. Watch 24 | your story come to life with pixel art. 25 |

26 |
27 |
28 |
29 | 30 |
31 | 32 |
33 |
34 | $ 35 |
36 |
37 | --ai-assist 38 | 39 | [experimental] 40 | 41 |
42 |

43 | Let our AI assist you in crafting unique cyberpunk adventures 44 | with stunning pixel art visuals. 45 |

46 |
47 |
48 |
49 | 50 |
51 | 52 |
53 |
54 | $ 55 |
56 |
57 | --community 58 | [beta] 59 |
60 |

61 | Explore a growing collection of community-created adventures. 62 | Share your own stories with others. 63 |

64 |
65 |
66 |
67 | 68 |
69 | 70 |
71 |

Run with --help for detailed usage information

72 |

Version 1.0.0-alpha

73 |
74 |
75 |
76 |
77 | ); 78 | } 79 | -------------------------------------------------------------------------------- /frontend/src/components/landing-page/terminal.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import { useState, useEffect, useMemo } from "react"; 4 | import { Sparkles } from "lucide-react"; 5 | import { useTypewriter } from "../../hooks/useTypewriter"; 6 | import { Wrapper } from "../../components/wrapper"; 7 | 8 | type Story = { 9 | prompt: string; 10 | description: string; 11 | choices: string[]; 12 | }; 13 | 14 | interface StoryGeneratorProps { 15 | stories: Story[]; 16 | typingSpeed?: number; 17 | } 18 | 19 | const getRandomDelay = (min = 100, max = 2000) => 20 | Math.floor(Math.random() * (max - min + 1) + min); 21 | 22 | export default function StoryGenerator({ 23 | stories, 24 | typingSpeed = 15, 25 | }: StoryGeneratorProps) { 26 | const [currentStoryIndex, setCurrentStoryIndex] = useState(0); 27 | const [phase, setPhase] = useState< 28 | "prompt" | "thinking" | "story" | "choices" | "transition" 29 | >("prompt"); 30 | const [areChoicesRendered, setAreChoicesRendered] = useState(false); 31 | 32 | const currentStory = stories[currentStoryIndex]; 33 | 34 | const { text: promptText, isComplete: promptComplete } = useTypewriter( 35 | currentStory.prompt, 36 | { startDelay: 0 }, 37 | typingSpeed 38 | ); 39 | 40 | useEffect(() => { 41 | if (promptComplete) { 42 | setPhase("thinking"); 43 | const timer = setTimeout(() => setPhase("story"), 1500); 44 | return () => clearTimeout(timer); 45 | } 46 | }, [promptComplete]); 47 | 48 | useEffect(() => { 49 | if (phase === "choices") { 50 | const timer = setTimeout(() => { 51 | setAreChoicesRendered(true); 52 | }, currentStory.choices.length * 500); 53 | return () => clearTimeout(timer); 54 | } 55 | }, [phase, currentStory.choices.length]); 56 | 57 | useEffect(() => { 58 | if (areChoicesRendered) { 59 | const timer = setTimeout(() => { 60 | 
handleChoiceSelect(); 61 | }, 2000); 62 | return () => clearTimeout(timer); 63 | } 64 | }, [areChoicesRendered]); 65 | 66 | const choiceDelays = useMemo( 67 | () => currentStory.choices.map(() => getRandomDelay()), 68 | [currentStory.choices] 69 | ); 70 | 71 | const handleChoiceSelect = () => { 72 | setPhase("transition"); 73 | setAreChoicesRendered(false); 74 | setTimeout(() => { 75 | // Loop back to the first story if we're at the end 76 | setCurrentStoryIndex((prev) => (prev + 1) % stories.length); 77 | setPhase("prompt"); 78 | }, 2000); 79 | }; 80 | 81 | return ( 82 |
83 |
84 |
85 |
86 |
87 | 88 | root@cyoa-os:~$ 89 | 90 |
91 |
{promptText}
92 |
93 | 94 | {phase === "thinking" && ( 95 |
96 | 97 | Thinking... 98 |
99 | )} 100 | 101 | {phase !== "prompt" && phase !== "thinking" && ( 102 | 106 |

107 | Loading... 108 |

109 |
110 | } 111 | > 112 | setPhase("choices")} 114 | text={currentStory.description} 115 | typingSpeed={typingSpeed} 116 | isComplete={phase === "choices" || phase === "transition"} 117 | /> 118 | 119 | )} 120 | 121 | {phase === "choices" && ( 122 |
123 | {currentStory.choices.map((choice, index) => ( 124 | 129 |

130 | Loading... 131 |

132 |
133 | } 134 | > 135 | 140 | 141 | ))} 142 |
143 | )} 144 | 145 | {phase === "transition" && ( 146 |
147 | 148 | Transitioning to next story... 149 |
150 | )} 151 | 152 | 153 | 154 | ); 155 | } 156 | 157 | const StoryDescription = ({ 158 | text, 159 | typingSpeed, 160 | onComplete, 161 | isComplete, 162 | }: { 163 | text: string; 164 | typingSpeed: number; 165 | onComplete: () => void; 166 | isComplete: boolean; 167 | }) => { 168 | const { text: storyText } = useTypewriter( 169 | text, 170 | { 171 | startDelay: 0, 172 | onComplete, 173 | skipAnimation: isComplete, 174 | }, 175 | typingSpeed 176 | ); 177 | 178 | return ( 179 |
180 |

{storyText}

181 |
182 | ); 183 | }; 184 | 185 | interface ChoiceProps { 186 | number: number; 187 | text: string; 188 | speed: number; 189 | } 190 | 191 | function Choice({ number, text, speed }: ChoiceProps) { 192 | const { text: choiceText } = useTypewriter( 193 | text, 194 | { 195 | startDelay: 0, 196 | }, 197 | speed 198 | ); 199 | 200 | return ( 201 |
202 | 203 | {number}. 204 | {choiceText} 205 | 206 |
207 | ); 208 | } 209 | -------------------------------------------------------------------------------- /frontend/src/components/login/terminal-login.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import { useTypewriter } from "@/hooks/useTypewriter"; 4 | import { Fingerprint, Loader2, Sparkles, X } from "lucide-react"; 5 | import { motion } from "framer-motion"; 6 | import React, { useEffect, useState, useTransition } from "react"; 7 | import { signInWithGoogle } from "@/lib/login-server"; 8 | import { Button } from "../ui/button"; 9 | import Link from "next/link"; 10 | 11 | interface TerminalLineConfig { 12 | text: string; 13 | loaderDuration: number; 14 | startDelay: number; 15 | loaderText: string; 16 | LoaderIcon?: React.ElementType; // Add optional loader icon prop 17 | } 18 | 19 | const TerminalLine = ({ 20 | config, 21 | onComplete, 22 | }: { 23 | config: TerminalLineConfig; 24 | onComplete: () => void; 25 | }) => { 26 | const [renderLoader, setRenderLoader] = useState(false); 27 | 28 | const { text: renderText } = useTypewriter( 29 | config.text, 30 | { 31 | startDelay: config.startDelay, 32 | onComplete: () => setRenderLoader(true), 33 | skipAnimation: false, 34 | }, 35 | 20 36 | ); 37 | 38 | useEffect(() => { 39 | if (renderLoader) { 40 | setTimeout(() => { 41 | onComplete(); 42 | setRenderLoader(false); 43 | }, config.loaderDuration); 44 | } 45 | }, [renderLoader]); 46 | 47 | const LoaderIcon = config.LoaderIcon || Sparkles; // Use provided icon or fallback to Sparkles 48 | 49 | return ( 50 |
51 |
52 | 53 | root@cyoa-os:~$ 54 | 55 | 56 | {renderText} 57 | 58 | 59 |
60 | {renderLoader && ( 61 | 71 | 82 |
83 | 84 | 85 | {config.loaderText} 86 | 87 |
88 |
89 |
90 | )} 91 |
92 | ); 93 | }; 94 | 95 | const TerminalLogin = () => { 96 | const [isPending, startTransition] = useTransition(); 97 | const terminalLines: TerminalLineConfig[] = [ 98 | { 99 | text: "User login attempt detected", 100 | loaderDuration: 3000, 101 | loaderText: "Authenticating user now", 102 | startDelay: 500, 103 | LoaderIcon: Fingerprint, 104 | }, 105 | { 106 | text: "User authentication unsuccessful", 107 | loaderDuration: 3000, 108 | loaderText: "Initialising sign in flow", 109 | startDelay: 500, 110 | LoaderIcon: Sparkles, 111 | }, 112 | ]; 113 | const [completedLines, setCompletedLines] = useState( 114 | [] 115 | ); 116 | 117 | const handleComplete = () => { 118 | setCompletedLines((prev) => [...prev, terminalLines[prev.length]]); 119 | }; 120 | 121 | return ( 122 |
123 |
124 |
125 |
126 | 130 | 131 | 132 |
133 |
134 | {completedLines.map((line, index) => ( 135 |
136 |
137 | 138 | root@cyoa-os:~$ 139 | 140 | 141 | {line.text} 142 | 143 |
144 |
145 | ))} 146 | {completedLines.length < terminalLines.length && ( 147 | 151 | )} 152 | {completedLines.length === terminalLines.length && ( 153 | 159 | 177 | 178 | )} 179 |
180 |
181 |
182 |
183 | ); 184 | }; 185 | 186 | export default TerminalLogin; 187 | -------------------------------------------------------------------------------- /frontend/src/components/navigation/navigation-link.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import Link from "next/link"; 4 | import { useNavigationProgress } from "./navigation-progress-provider"; 5 | import React from "react"; 6 | 7 | type NavigationLinkProps = React.ComponentProps; 8 | 9 | export function NavigationLink({ 10 | children, 11 | onClick, 12 | prefetch = true, 13 | ...props 14 | }: NavigationLinkProps) { 15 | const { startNavigation } = useNavigationProgress(); 16 | 17 | const handleClick = (e: React.MouseEvent) => { 18 | // Start navigation spinner 19 | startNavigation(); 20 | 21 | // Call original onClick if provided 22 | if (onClick) { 23 | onClick(e); 24 | } 25 | }; 26 | 27 | return ( 28 | 29 | {children} 30 | 31 | ); 32 | } 33 | -------------------------------------------------------------------------------- /frontend/src/components/navigation/navigation-progress-provider.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import { usePathname } from "next/navigation"; 4 | import React, { 5 | createContext, 6 | useContext, 7 | useEffect, 8 | useState, 9 | useRef, 10 | } from "react"; 11 | import { NavigationSpinner } from "./navigation-spinner"; 12 | 13 | type NavigationProgressContextType = { 14 | isNavigating: boolean; 15 | startNavigation: () => void; 16 | endNavigation: () => void; 17 | }; 18 | 19 | const NavigationProgressContext = createContext({ 20 | isNavigating: false, 21 | startNavigation: () => {}, 22 | endNavigation: () => {}, 23 | }); 24 | 25 | export const useNavigationProgress = () => 26 | useContext(NavigationProgressContext); 27 | 28 | export function NavigationProgressProvider({ 29 | children, 30 | }: { 31 | children: React.ReactNode; 32 | }) { 33 | 
const [isNavigating, setIsNavigating] = useState(false); 34 | const pathname = usePathname(); 35 | const initialized = useRef(false); 36 | 37 | // Reset navigation state when pathname changes 38 | useEffect(() => { 39 | setIsNavigating(false); 40 | }, [pathname]); 41 | 42 | // Listen for navigation events - but only after initial render 43 | useEffect(() => { 44 | // Skip event listener setup on first render to improve initial load time 45 | if (!initialized.current) { 46 | initialized.current = true; 47 | return; 48 | } 49 | 50 | const handleStart = () => { 51 | setIsNavigating(true); 52 | }; 53 | 54 | const handleComplete = () => { 55 | setIsNavigating(false); 56 | }; 57 | 58 | // Add event listeners for navigation events 59 | window.addEventListener("beforeunload", handleStart); 60 | document.addEventListener("navigationStart", handleStart); 61 | document.addEventListener("navigationComplete", handleComplete); 62 | 63 | return () => { 64 | window.removeEventListener("beforeunload", handleStart); 65 | document.removeEventListener("navigationStart", handleStart); 66 | document.removeEventListener("navigationComplete", handleComplete); 67 | }; 68 | }, [initialized.current]); 69 | 70 | const startNavigation = () => { 71 | setIsNavigating(true); 72 | document.dispatchEvent(new Event("navigationStart")); 73 | }; 74 | 75 | const endNavigation = () => { 76 | setIsNavigating(false); 77 | document.dispatchEvent(new Event("navigationComplete")); 78 | }; 79 | 80 | return ( 81 | 84 | {isNavigating && } 85 | {children} 86 | 87 | ); 88 | } 89 | -------------------------------------------------------------------------------- /frontend/src/components/navigation/navigation-spinner.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import React from "react"; 4 | 5 | export function NavigationSpinner() { 6 | return ( 7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | root@cyoa-os:~$ 16 | loading content... 17 |
18 |
19 |
20 |
21 | ); 22 | } 23 | -------------------------------------------------------------------------------- /frontend/src/components/node/audio-player.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import { useState, useEffect, useRef } from "react"; 4 | import { Volume2, VolumeX, Play, Pause } from "lucide-react"; 5 | import { Slider } from "@/components/ui/slider"; 6 | import { cn } from "@/lib/utils"; 7 | 8 | interface Props { 9 | muted?: boolean; 10 | story_id: string; 11 | node_id: string; 12 | } 13 | 14 | const RetroAudioPlayer = ({ 15 | muted: initialMuted = false, 16 | story_id, 17 | node_id, 18 | }: Props) => { 19 | const audioRef = useRef(null); 20 | const [isPlaying, setIsPlaying] = useState(false); 21 | const [muted, setMuted] = useState(initialMuted); 22 | const [currentTime, setCurrentTime] = useState(0); 23 | const [duration, setDuration] = useState(0); 24 | const [hasInteracted, setHasInteracted] = useState(false); 25 | const [isAudioLoaded, setIsAudioLoaded] = useState(false); 26 | 27 | // Set up audio element and event listeners 28 | useEffect(() => { 29 | const audio = audioRef.current; 30 | if (!audio) return; 31 | 32 | // Force the browser to start downloading the audio file 33 | audio.load(); 34 | 35 | const handleLoadedMetadata = () => { 36 | setDuration(audio.duration); 37 | setIsAudioLoaded(true); 38 | console.log("Audio metadata loaded, ready to play on interaction"); 39 | }; 40 | 41 | const handleTimeUpdate = () => setCurrentTime(audio.currentTime); 42 | const handleEnded = () => setIsPlaying(false); 43 | const handlePause = () => setIsPlaying(false); 44 | const handlePlay = () => setIsPlaying(true); 45 | 46 | audio.addEventListener("loadedmetadata", handleLoadedMetadata); 47 | audio.addEventListener("timeupdate", handleTimeUpdate); 48 | audio.addEventListener("ended", handleEnded); 49 | audio.addEventListener("pause", handlePause); 50 | audio.addEventListener("play", handlePlay); 
51 | 52 | return () => { 53 | audio.removeEventListener("loadedmetadata", handleLoadedMetadata); 54 | audio.removeEventListener("timeupdate", handleTimeUpdate); 55 | audio.removeEventListener("ended", handleEnded); 56 | audio.removeEventListener("pause", handlePause); 57 | audio.removeEventListener("play", handlePlay); 58 | }; 59 | }, []); 60 | 61 | // Handle user interactions to enable audio 62 | useEffect(() => { 63 | const handleUserInteraction = async () => { 64 | if (hasInteracted || !isAudioLoaded) return; 65 | 66 | const audio = audioRef.current; 67 | if (!audio) return; 68 | 69 | setHasInteracted(true); 70 | 71 | try { 72 | await audio.play(); 73 | setIsPlaying(true); 74 | } catch (error) { 75 | console.error("Playback after interaction failed:", error); 76 | } 77 | }; 78 | 79 | // Listen for various user interactions 80 | const interactionEvents = [ 81 | "click", 82 | "touchstart", 83 | "keydown", 84 | "scroll", 85 | "mousemove", 86 | ]; 87 | 88 | interactionEvents.forEach((event) => { 89 | window.addEventListener(event, handleUserInteraction, { once: true }); 90 | }); 91 | 92 | return () => { 93 | interactionEvents.forEach((event) => { 94 | window.removeEventListener(event, handleUserInteraction); 95 | }); 96 | }; 97 | }, [hasInteracted, isAudioLoaded]); 98 | 99 | const togglePlayPause = async () => { 100 | const audio = audioRef.current; 101 | if (!audio) return; 102 | 103 | try { 104 | if (audio.paused) { 105 | await audio.play(); 106 | setHasInteracted(true); 107 | } else { 108 | audio.pause(); 109 | } 110 | } catch (error) { 111 | console.error("Playback failed:", error); 112 | } 113 | }; 114 | 115 | const toggleMute = () => { 116 | const audio = audioRef.current; 117 | if (!audio) return; 118 | audio.muted = !muted; 119 | setMuted(!muted); 120 | }; 121 | 122 | // Ensure audio playback state matches component state 123 | useEffect(() => { 124 | const audio = audioRef.current; 125 | if (!audio) return; 126 | 127 | if (isPlaying && audio.paused) { 128 | 
audio.play().catch((error) => console.error("Playback failed:", error)); 129 | } else if (!isPlaying && !audio.paused) { 130 | audio.pause(); 131 | } 132 | }, [isPlaying]); 133 | 134 | return ( 135 |
136 |
188 | ); 189 | }; 190 | 191 | export default RetroAudioPlayer; 192 | -------------------------------------------------------------------------------- /frontend/src/components/node/choice-interface.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | import { SelectStoryChoice } from "@/db/schema"; 3 | import { useEffect, useState } from "react"; 4 | import { useRouter } from "next/navigation"; 5 | import { motion } from "framer-motion"; 6 | import AutoAudioPlayer from "./audio-player"; 7 | import { HoverCard } from "@radix-ui/react-hover-card"; 8 | import { HoverCardContent, HoverCardTrigger } from "../ui/hover-card"; 9 | import { NavigationLink } from "../navigation/navigation-link"; 10 | import { useNavigationProgress } from "../navigation/navigation-progress-provider"; 11 | import Image from "next/image"; 12 | 13 | type ChoiceInterfaceProps = { 14 | title: string; 15 | description: string; 16 | choices: SelectStoryChoice[]; 17 | choiceId: string; 18 | storyId: string; 19 | imagePrompt: string; 20 | }; 21 | 22 | const ChoiceInterface = ({ 23 | title, 24 | description, 25 | choices, 26 | choiceId, 27 | storyId, 28 | imagePrompt, 29 | }: ChoiceInterfaceProps) => { 30 | const [selectedChoice, setSelectedChoice] = useState(0); 31 | const router = useRouter(); 32 | const { startNavigation } = useNavigationProgress(); 33 | 34 | useEffect(() => { 35 | function handleKeyDown(e: KeyboardEvent) { 36 | switch (e.key) { 37 | case "ArrowUp": 38 | setSelectedChoice((prev) => (prev > 0 ? prev - 1 : prev)); 39 | break; 40 | case "ArrowDown": 41 | setSelectedChoice((prev) => 42 | prev < choices.length - 1 ? 
prev + 1 : prev 43 | ); 44 | break; 45 | case "Enter": 46 | startNavigation(); 47 | router.push(`/dashboard/story/choice/${choices[selectedChoice].id}`); 48 | break; 49 | case "Escape": 50 | startNavigation(); 51 | router.push( 52 | `/dashboard/story/${choices[selectedChoice].storyId}?node=${choices[selectedChoice].id}` 53 | ); 54 | break; 55 | } 56 | } 57 | 58 | window.addEventListener("keydown", handleKeyDown); 59 | return () => window.removeEventListener("keydown", handleKeyDown); 60 | }, [selectedChoice, router, choices, startNavigation]); 61 | 62 | return ( 63 |
64 |
65 |
66 | {/* For larger screens - use HoverCard */} 67 |
68 | 69 | 70 |
71 | Story Banner 79 |
80 | Hover to see image prompt 81 |
82 |
83 |
84 | 85 |

86 | {imagePrompt} 87 |

88 |
89 |
90 |
91 | 92 | {/* For mobile screens - show image and prompt directly */} 93 |
94 |
95 | Story Banner 103 |
104 |
105 |

106 | Image Prompt:{" "} 107 | {imagePrompt} 108 |

109 |
110 |
111 | 112 |
113 |
114 | {`>`} 115 |

{title}

116 |
117 |

118 | {description} 119 |

120 |
121 | 122 |
123 |

↵ ENTER to confirm selection

124 |

↑↓ ARROWS to navigate choices

125 |

ESC to return to main story

126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 | {choices.map((choice, index) => ( 134 | 138 | setSelectedChoice(index)} 145 | layout 146 | transition={{ 147 | layout: { duration: 0.2, ease: "easeInOut" }, 148 | }} 149 | > 150 |
151 |

158 | {choice.choice_title} 159 |

160 | 161 | 162 | {choice.choice_description} 163 | 164 |
165 |
166 |
167 | ))} 168 |
169 |
170 |
171 |
172 |
173 | ); 174 | }; 175 | 176 | export default ChoiceInterface; 177 | -------------------------------------------------------------------------------- /frontend/src/components/node/terminal-choice.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import { SelectStoryChoice } from "@/db/schema"; 4 | import { useRouter } from "next/navigation"; 5 | import { useEffect } from "react"; 6 | import AutoAudioPlayer from "./audio-player"; 7 | import { HoverCard } from "@radix-ui/react-hover-card"; 8 | import { HoverCardContent, HoverCardTrigger } from "../ui/hover-card"; 9 | import { useNavigationProgress } from "../navigation/navigation-progress-provider"; 10 | import Image from "next/image"; 11 | 12 | type TerminalChoiceProps = { 13 | choice: SelectStoryChoice; 14 | }; 15 | 16 | const TerminalChoice = ({ choice }: TerminalChoiceProps) => { 17 | const router = useRouter(); 18 | const { startNavigation } = useNavigationProgress(); 19 | 20 | useEffect(() => { 21 | const handleKeyDown = async (e: KeyboardEvent) => { 22 | if (e.key === "Escape") { 23 | startNavigation(); 24 | router.push( 25 | `/dashboard/story/${choice.storyId}?prev_node=${choice.id}` 26 | ); 27 | } else if (e.key === "Backspace") { 28 | startNavigation(); 29 | router.push("/dashboard"); 30 | } 31 | }; 32 | 33 | window.addEventListener("keydown", handleKeyDown); 34 | return () => window.removeEventListener("keydown", handleKeyDown); 35 | }, [router, choice.storyId, choice.id, startNavigation]); 36 | 37 | return ( 38 |
39 |
40 |
41 | END OF STORY 42 |
43 |
44 | {/* For larger screens - use HoverCard */} 45 |
46 | 47 | 48 |
49 | Story Banner 57 |
58 | Hover to see image prompt 59 |
60 |
61 |
62 | 63 |

64 | {choice.image_prompt} 65 |

66 |
67 |
68 |
69 | 70 | {/* For mobile screens - show image and prompt directly */} 71 |
72 |
73 | Story Banner 81 |
82 |
83 |

84 | Image Prompt:{" "} 85 | {choice.image_prompt} 86 |

87 |
88 |
89 | 90 |
91 |

92 | {choice.choice_title} 93 |

94 |

95 | {choice.description} 96 |

97 | 98 |
99 |
100 | 101 |
102 |
{ 105 | startNavigation(); 106 | router.push( 107 | `/dashboard/story/${choice.storyId}?prev_node=${choice.id}` 108 | ); 109 | }} 110 | > 111 | 112 | 113 | ESC 114 | 115 | 116 | 117 | 118 | 119 | › 120 | 121 | to return to main story 122 | 123 |
124 | 125 |
{ 128 | startNavigation(); 129 | router.push("/dashboard"); 130 | }} 131 | > 132 | 133 | 134 | 141 | 147 | 148 | BKSP 149 | 150 | 151 | 152 | 153 | 154 | › 155 | 156 | to return to dashboard 157 | 158 |
159 |
160 |
161 |
162 | ); 163 | }; 164 | 165 | export default TerminalChoice; 166 | -------------------------------------------------------------------------------- /frontend/src/components/story/reset-story.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | import React, { useState } from "react"; 3 | import { 4 | Dialog, 5 | DialogClose, 6 | DialogContent, 7 | DialogDescription, 8 | DialogFooter, 9 | DialogHeader, 10 | DialogTitle, 11 | DialogTrigger, 12 | } from "@/components/ui/dialog"; 13 | import { Button } from "@/components/ui/button"; 14 | import { resetStoryProgress } from "@/lib/story"; 15 | import { useRouter } from "next/navigation"; 16 | 17 | const ResetStory = ({ storyId }: { storyId: string }) => { 18 | const router = useRouter(); 19 | const [isOpen, setIsOpen] = useState(false); 20 | return ( 21 | 22 | 23 |
24 | Reset Progress 25 |
26 |
27 | 28 | 29 | 30 | Reset Story Progress 31 | 32 | 33 | Are you sure you want to reset your progress? This action cannot be 34 | undone. 35 | 36 | 37 | 38 | 39 | 45 | 46 | 58 | 59 | 60 |
61 | ); 62 | }; 63 | 64 | export default ResetStory; 65 | -------------------------------------------------------------------------------- /frontend/src/components/story/story-choice.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | import React from "react"; 3 | import { buildTree, getPath } from "@/lib/tree"; 4 | import { SelectStoryChoice } from "@/db/schema"; 5 | import { useRouter } from "next/navigation"; 6 | import { NavigationLink } from "../navigation/navigation-link"; 7 | 8 | type ChoiceNode = SelectStoryChoice & { children: ChoiceNode[] }; 9 | 10 | interface StoryChoiceNodeProps { 11 | node: ChoiceNode; 12 | onSelect: (node: ChoiceNode) => void; 13 | selectedId: string | null; 14 | isLast?: boolean; 15 | } 16 | 17 | const StoryChoiceNode = ({ 18 | node, 19 | onSelect, 20 | selectedId, 21 | isLast = true, 22 | }: StoryChoiceNodeProps) => { 23 | const [isExpanded, setIsExpanded] = React.useState(false); 24 | const hasChildren = node.children.length > 0; 25 | 26 | return ( 27 |
28 |
29 | 30 | 31 | {isLast ? "└── " : "└── "} 32 | 33 | { 35 | onSelect(node); 36 | setIsExpanded(!isExpanded); 37 | }} 38 | className={`font-mono cursor-pointer px-2 py-0.5 rounded flex items-center gap-2 text-xs sm:text-sm break-all ${ 39 | selectedId === node.id 40 | ? "bg-green-950 text-green-400" 41 | : "group-hover:text-green-400" 42 | }`} 43 | > 44 | {node.choice_title}{" "} 45 | {hasChildren && ( 46 | { 49 | e.stopPropagation(); 50 | setIsExpanded(!isExpanded); 51 | }} 52 | > 53 | {isExpanded ? "" : "+"} 54 | 55 | )} 56 | 57 |
58 | {isExpanded && hasChildren && ( 59 |
60 | {node.children.map((child: ChoiceNode, index: number) => ( 61 | 68 | ))} 69 |
70 | )} 71 |
72 | ); 73 | }; 74 | 75 | const StoryChoices = ({ 76 | choices, 77 | isUserStory, 78 | }: { 79 | choices: SelectStoryChoice[]; 80 | isUserStory: boolean; 81 | }) => { 82 | const validChoices = choices.filter((choice) => choice.explored === 1); 83 | 84 | const tree = buildTree(validChoices, "NULL"); 85 | 86 | const [selectedId, setSelectedId] = React.useState(tree[0].id); 87 | 88 | const selectedPath = getPath(choices, selectedId); 89 | const router = useRouter(); 90 | 91 | if (!isUserStory) { 92 | return ( 93 |
94 |

95 | WARNING: This is a publicly shared story, progress is not saved 96 |

97 | 101 | START HERE 102 | 103 |
104 | ); 105 | } 106 | 107 | return ( 108 |
109 |
110 |
111 | {selectedPath.map((choice: SelectStoryChoice, index: number) => ( 112 | 113 | {index > 0 && } 114 | 115 | {choice.choice_title.toLowerCase().replace(/\s+/g, "_")} 116 | 117 | 118 | ))} 119 |
120 |
121 | 122 |
123 | {tree.map((node: ChoiceNode) => ( 124 | { 128 | router.prefetch(`/dashboard/story/choice/${node.id}`); 129 | setSelectedId(node.id); 130 | }} 131 | selectedId={selectedId} 132 | /> 133 | ))} 134 |
135 | {selectedPath.at(-1)?.description} 136 |
137 |
138 | 142 | Start Here 143 | 144 |
145 |
146 | ); 147 | }; 148 | 149 | export default StoryChoices; 150 | -------------------------------------------------------------------------------- /frontend/src/components/story/story-visibility-toggle.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import { useState } from "react"; 4 | import { Switch } from "@/components/ui/switch"; 5 | import { Button } from "@/components/ui/button"; 6 | import { Check, Copy } from "lucide-react"; 7 | import { cn } from "@/lib/utils"; 8 | import { useToast } from "@/hooks/use-toast"; 9 | import { toggleStoryVisibility } from "@/lib/story"; 10 | 11 | interface StoryVisibilityToggleProps { 12 | storyId: string; 13 | isPublic: boolean; 14 | } 15 | 16 | export default function StoryVisibilityToggle({ 17 | storyId, 18 | isPublic, 19 | }: StoryVisibilityToggleProps) { 20 | const [isPublicState, setIsPublicState] = useState(isPublic); 21 | const [isCopied, setIsCopied] = useState(false); 22 | const toast = useToast(); 23 | 24 | const handleToggle = async () => { 25 | const newState = !isPublicState; 26 | setIsPublicState(newState); 27 | await toggleStoryVisibility(storyId, isPublicState); 28 | }; 29 | 30 | const handleCopy = async () => { 31 | const link = `${window.origin}/dashboard/story/${storyId}`; 32 | await navigator.clipboard.writeText(link); 33 | toast.toast({ 34 | title: "Successfully Copied Link", 35 | description: "Share this link with your friends", 36 | }); 37 | setIsCopied(true); 38 | setTimeout(() => setIsCopied(false), 2000); 39 | }; 40 | 41 | return ( 42 |
43 |
44 | 49 | 50 | {isPublicState ? "Public" : "Private"} 51 | 52 |
53 | 71 |
72 | ); 73 | } 74 | -------------------------------------------------------------------------------- /frontend/src/components/ui/button.tsx: -------------------------------------------------------------------------------- 1 | import * as React from "react" 2 | import { Slot } from "@radix-ui/react-slot" 3 | import { cva, type VariantProps } from "class-variance-authority" 4 | 5 | import { cn } from "@/lib/utils" 6 | 7 | const buttonVariants = cva( 8 | "inline-flex items-center justify-center gap-2 whitespace-nowrap rounded-md text-sm font-medium transition-colors focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:pointer-events-none disabled:opacity-50 [&_svg]:pointer-events-none [&_svg]:size-4 [&_svg]:shrink-0", 9 | { 10 | variants: { 11 | variant: { 12 | default: 13 | "bg-primary text-primary-foreground shadow hover:bg-primary/90", 14 | destructive: 15 | "bg-destructive text-destructive-foreground shadow-sm hover:bg-destructive/90", 16 | outline: 17 | "border border-input bg-background shadow-sm hover:bg-accent hover:text-accent-foreground", 18 | secondary: 19 | "bg-secondary text-secondary-foreground shadow-sm hover:bg-secondary/80", 20 | ghost: "hover:bg-accent hover:text-accent-foreground", 21 | link: "text-primary underline-offset-4 hover:underline", 22 | }, 23 | size: { 24 | default: "h-9 px-4 py-2", 25 | sm: "h-8 rounded-md px-3 text-xs", 26 | lg: "h-10 rounded-md px-8", 27 | icon: "h-9 w-9", 28 | }, 29 | }, 30 | defaultVariants: { 31 | variant: "default", 32 | size: "default", 33 | }, 34 | } 35 | ) 36 | 37 | export interface ButtonProps 38 | extends React.ButtonHTMLAttributes, 39 | VariantProps { 40 | asChild?: boolean 41 | } 42 | 43 | const Button = React.forwardRef( 44 | ({ className, variant, size, asChild = false, ...props }, ref) => { 45 | const Comp = asChild ? 
Slot : "button" 46 | return ( 47 | 52 | ) 53 | } 54 | ) 55 | Button.displayName = "Button" 56 | 57 | export { Button, buttonVariants } 58 | -------------------------------------------------------------------------------- /frontend/src/components/ui/card.tsx: -------------------------------------------------------------------------------- 1 | import * as React from "react" 2 | 3 | import { cn } from "@/lib/utils" 4 | 5 | const Card = React.forwardRef< 6 | HTMLDivElement, 7 | React.HTMLAttributes 8 | >(({ className, ...props }, ref) => ( 9 |
17 | )) 18 | Card.displayName = "Card" 19 | 20 | const CardHeader = React.forwardRef< 21 | HTMLDivElement, 22 | React.HTMLAttributes 23 | >(({ className, ...props }, ref) => ( 24 |
29 | )) 30 | CardHeader.displayName = "CardHeader" 31 | 32 | const CardTitle = React.forwardRef< 33 | HTMLDivElement, 34 | React.HTMLAttributes 35 | >(({ className, ...props }, ref) => ( 36 |
41 | )) 42 | CardTitle.displayName = "CardTitle" 43 | 44 | const CardDescription = React.forwardRef< 45 | HTMLDivElement, 46 | React.HTMLAttributes 47 | >(({ className, ...props }, ref) => ( 48 |
53 | )) 54 | CardDescription.displayName = "CardDescription" 55 | 56 | const CardContent = React.forwardRef< 57 | HTMLDivElement, 58 | React.HTMLAttributes 59 | >(({ className, ...props }, ref) => ( 60 |
61 | )) 62 | CardContent.displayName = "CardContent" 63 | 64 | const CardFooter = React.forwardRef< 65 | HTMLDivElement, 66 | React.HTMLAttributes 67 | >(({ className, ...props }, ref) => ( 68 |
73 | )) 74 | CardFooter.displayName = "CardFooter" 75 | 76 | export { Card, CardHeader, CardFooter, CardTitle, CardDescription, CardContent } 77 | -------------------------------------------------------------------------------- /frontend/src/components/ui/dialog.tsx: -------------------------------------------------------------------------------- 1 | "use client" 2 | 3 | import * as React from "react" 4 | import * as DialogPrimitive from "@radix-ui/react-dialog" 5 | import { X } from "lucide-react" 6 | 7 | import { cn } from "@/lib/utils" 8 | 9 | const Dialog = DialogPrimitive.Root 10 | 11 | const DialogTrigger = DialogPrimitive.Trigger 12 | 13 | const DialogPortal = DialogPrimitive.Portal 14 | 15 | const DialogClose = DialogPrimitive.Close 16 | 17 | const DialogOverlay = React.forwardRef< 18 | React.ElementRef, 19 | React.ComponentPropsWithoutRef 20 | >(({ className, ...props }, ref) => ( 21 | 29 | )) 30 | DialogOverlay.displayName = DialogPrimitive.Overlay.displayName 31 | 32 | const DialogContent = React.forwardRef< 33 | React.ElementRef, 34 | React.ComponentPropsWithoutRef 35 | >(({ className, children, ...props }, ref) => ( 36 | 37 | 38 | 46 | {children} 47 | 48 | 49 | Close 50 | 51 | 52 | 53 | )) 54 | DialogContent.displayName = DialogPrimitive.Content.displayName 55 | 56 | const DialogHeader = ({ 57 | className, 58 | ...props 59 | }: React.HTMLAttributes) => ( 60 |
67 | ) 68 | DialogHeader.displayName = "DialogHeader" 69 | 70 | const DialogFooter = ({ 71 | className, 72 | ...props 73 | }: React.HTMLAttributes) => ( 74 |
81 | ) 82 | DialogFooter.displayName = "DialogFooter" 83 | 84 | const DialogTitle = React.forwardRef< 85 | React.ElementRef, 86 | React.ComponentPropsWithoutRef 87 | >(({ className, ...props }, ref) => ( 88 | 96 | )) 97 | DialogTitle.displayName = DialogPrimitive.Title.displayName 98 | 99 | const DialogDescription = React.forwardRef< 100 | React.ElementRef, 101 | React.ComponentPropsWithoutRef 102 | >(({ className, ...props }, ref) => ( 103 | 108 | )) 109 | DialogDescription.displayName = DialogPrimitive.Description.displayName 110 | 111 | export { 112 | Dialog, 113 | DialogPortal, 114 | DialogOverlay, 115 | DialogTrigger, 116 | DialogClose, 117 | DialogContent, 118 | DialogHeader, 119 | DialogFooter, 120 | DialogTitle, 121 | DialogDescription, 122 | } 123 | -------------------------------------------------------------------------------- /frontend/src/components/ui/hover-card.tsx: -------------------------------------------------------------------------------- 1 | "use client" 2 | 3 | import * as React from "react" 4 | import * as HoverCardPrimitive from "@radix-ui/react-hover-card" 5 | 6 | import { cn } from "@/lib/utils" 7 | 8 | const HoverCard = HoverCardPrimitive.Root 9 | 10 | const HoverCardTrigger = HoverCardPrimitive.Trigger 11 | 12 | const HoverCardContent = React.forwardRef< 13 | React.ElementRef, 14 | React.ComponentPropsWithoutRef 15 | >(({ className, align = "center", sideOffset = 4, ...props }, ref) => ( 16 | 26 | )) 27 | HoverCardContent.displayName = HoverCardPrimitive.Content.displayName 28 | 29 | export { HoverCard, HoverCardTrigger, HoverCardContent } 30 | -------------------------------------------------------------------------------- /frontend/src/components/ui/input.tsx: -------------------------------------------------------------------------------- 1 | import * as React from "react" 2 | 3 | import { cn } from "@/lib/utils" 4 | 5 | const Input = React.forwardRef>( 6 | ({ className, type, ...props }, ref) => { 7 | return ( 8 | 17 | ) 18 | } 19 
| ) 20 | Input.displayName = "Input" 21 | 22 | export { Input } 23 | -------------------------------------------------------------------------------- /frontend/src/components/ui/label.tsx: -------------------------------------------------------------------------------- 1 | "use client" 2 | 3 | import * as React from "react" 4 | import * as LabelPrimitive from "@radix-ui/react-label" 5 | import { cva, type VariantProps } from "class-variance-authority" 6 | 7 | import { cn } from "@/lib/utils" 8 | 9 | const labelVariants = cva( 10 | "text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70" 11 | ) 12 | 13 | const Label = React.forwardRef< 14 | React.ElementRef, 15 | React.ComponentPropsWithoutRef & 16 | VariantProps 17 | >(({ className, ...props }, ref) => ( 18 | 23 | )) 24 | Label.displayName = LabelPrimitive.Root.displayName 25 | 26 | export { Label } 27 | -------------------------------------------------------------------------------- /frontend/src/components/ui/progress.tsx: -------------------------------------------------------------------------------- 1 | "use client" 2 | 3 | import * as React from "react" 4 | import * as ProgressPrimitive from "@radix-ui/react-progress" 5 | 6 | import { cn } from "@/lib/utils" 7 | 8 | const Progress = React.forwardRef< 9 | React.ElementRef, 10 | React.ComponentPropsWithoutRef 11 | >(({ className, value, ...props }, ref) => ( 12 | 20 | 24 | 25 | )) 26 | Progress.displayName = ProgressPrimitive.Root.displayName 27 | 28 | export { Progress } 29 | -------------------------------------------------------------------------------- /frontend/src/components/ui/scroll-area.tsx: -------------------------------------------------------------------------------- 1 | "use client" 2 | 3 | import * as React from "react" 4 | import * as ScrollAreaPrimitive from "@radix-ui/react-scroll-area" 5 | 6 | import { cn } from "@/lib/utils" 7 | 8 | const ScrollArea = React.forwardRef< 9 | React.ElementRef, 10 | 
React.ComponentPropsWithoutRef 11 | >(({ className, children, ...props }, ref) => ( 12 | 17 | 18 | {children} 19 | 20 | 21 | 22 | 23 | )) 24 | ScrollArea.displayName = ScrollAreaPrimitive.Root.displayName 25 | 26 | const ScrollBar = React.forwardRef< 27 | React.ElementRef, 28 | React.ComponentPropsWithoutRef 29 | >(({ className, orientation = "vertical", ...props }, ref) => ( 30 | 43 | 44 | 45 | )) 46 | ScrollBar.displayName = ScrollAreaPrimitive.ScrollAreaScrollbar.displayName 47 | 48 | export { ScrollArea, ScrollBar } 49 | -------------------------------------------------------------------------------- /frontend/src/components/ui/slider.tsx: -------------------------------------------------------------------------------- 1 | "use client" 2 | 3 | import * as React from "react" 4 | import * as SliderPrimitive from "@radix-ui/react-slider" 5 | 6 | import { cn } from "@/lib/utils" 7 | 8 | const Slider = React.forwardRef< 9 | React.ElementRef, 10 | React.ComponentPropsWithoutRef 11 | >(({ className, ...props }, ref) => ( 12 | 20 | 21 | 22 | 23 | 24 | 25 | )) 26 | Slider.displayName = SliderPrimitive.Root.displayName 27 | 28 | export { Slider } 29 | -------------------------------------------------------------------------------- /frontend/src/components/ui/switch.tsx: -------------------------------------------------------------------------------- 1 | "use client" 2 | 3 | import * as React from "react" 4 | import * as SwitchPrimitives from "@radix-ui/react-switch" 5 | 6 | import { cn } from "@/lib/utils" 7 | 8 | const Switch = React.forwardRef< 9 | React.ElementRef, 10 | React.ComponentPropsWithoutRef 11 | >(({ className, ...props }, ref) => ( 12 | 20 | 25 | 26 | )) 27 | Switch.displayName = SwitchPrimitives.Root.displayName 28 | 29 | export { Switch } 30 | -------------------------------------------------------------------------------- /frontend/src/components/ui/toast.tsx: -------------------------------------------------------------------------------- 1 | "use 
client" 2 | 3 | import * as React from "react" 4 | import * as ToastPrimitives from "@radix-ui/react-toast" 5 | import { cva, type VariantProps } from "class-variance-authority" 6 | import { X } from "lucide-react" 7 | 8 | import { cn } from "@/lib/utils" 9 | 10 | const ToastProvider = ToastPrimitives.Provider 11 | 12 | const ToastViewport = React.forwardRef< 13 | React.ElementRef, 14 | React.ComponentPropsWithoutRef 15 | >(({ className, ...props }, ref) => ( 16 | 24 | )) 25 | ToastViewport.displayName = ToastPrimitives.Viewport.displayName 26 | 27 | const toastVariants = cva( 28 | "group pointer-events-auto relative flex w-full items-center justify-between space-x-2 overflow-hidden rounded-md border p-4 pr-6 shadow-lg transition-all data-[swipe=cancel]:translate-x-0 data-[swipe=end]:translate-x-[var(--radix-toast-swipe-end-x)] data-[swipe=move]:translate-x-[var(--radix-toast-swipe-move-x)] data-[swipe=move]:transition-none data-[state=open]:animate-in data-[state=closed]:animate-out data-[swipe=end]:animate-out data-[state=closed]:fade-out-80 data-[state=closed]:slide-out-to-right-full data-[state=open]:slide-in-from-top-full data-[state=open]:sm:slide-in-from-bottom-full", 29 | { 30 | variants: { 31 | variant: { 32 | default: "border bg-background text-foreground", 33 | destructive: 34 | "destructive group border-destructive bg-destructive text-destructive-foreground", 35 | }, 36 | }, 37 | defaultVariants: { 38 | variant: "default", 39 | }, 40 | } 41 | ) 42 | 43 | const Toast = React.forwardRef< 44 | React.ElementRef, 45 | React.ComponentPropsWithoutRef & 46 | VariantProps 47 | >(({ className, variant, ...props }, ref) => { 48 | return ( 49 | 54 | ) 55 | }) 56 | Toast.displayName = ToastPrimitives.Root.displayName 57 | 58 | const ToastAction = React.forwardRef< 59 | React.ElementRef, 60 | React.ComponentPropsWithoutRef 61 | >(({ className, ...props }, ref) => ( 62 | 70 | )) 71 | ToastAction.displayName = ToastPrimitives.Action.displayName 72 | 73 | const 
ToastClose = React.forwardRef< 74 | React.ElementRef, 75 | React.ComponentPropsWithoutRef 76 | >(({ className, ...props }, ref) => ( 77 | 86 | 87 | 88 | )) 89 | ToastClose.displayName = ToastPrimitives.Close.displayName 90 | 91 | const ToastTitle = React.forwardRef< 92 | React.ElementRef, 93 | React.ComponentPropsWithoutRef 94 | >(({ className, ...props }, ref) => ( 95 | 100 | )) 101 | ToastTitle.displayName = ToastPrimitives.Title.displayName 102 | 103 | const ToastDescription = React.forwardRef< 104 | React.ElementRef, 105 | React.ComponentPropsWithoutRef 106 | >(({ className, ...props }, ref) => ( 107 | 112 | )) 113 | ToastDescription.displayName = ToastPrimitives.Description.displayName 114 | 115 | type ToastProps = React.ComponentPropsWithoutRef 116 | 117 | type ToastActionElement = React.ReactElement 118 | 119 | export { 120 | type ToastProps, 121 | type ToastActionElement, 122 | ToastProvider, 123 | ToastViewport, 124 | Toast, 125 | ToastTitle, 126 | ToastDescription, 127 | ToastClose, 128 | ToastAction, 129 | } 130 | -------------------------------------------------------------------------------- /frontend/src/components/ui/toaster.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import { useToast } from "@/hooks/use-toast"; 4 | import { 5 | Toast, 6 | ToastClose, 7 | ToastDescription, 8 | ToastProvider, 9 | ToastTitle, 10 | ToastViewport, 11 | } from "@/components/ui/toast"; 12 | 13 | export function Toaster() { 14 | const { toasts } = useToast(); 15 | 16 | return ( 17 | 18 | {toasts.map(function ({ id, title, description, action, ...props }) { 19 | return ( 20 | 21 |
22 | {title && {title}} 23 | {description && ( 24 | {description} 25 | )} 26 |
27 | {action} 28 | 29 |
30 | ); 31 | })} 32 | 33 |
34 | ); 35 | } 36 | -------------------------------------------------------------------------------- /frontend/src/components/ui/toggle.tsx: -------------------------------------------------------------------------------- 1 | "use client" 2 | 3 | import * as React from "react" 4 | import * as TogglePrimitive from "@radix-ui/react-toggle" 5 | import { cva, type VariantProps } from "class-variance-authority" 6 | 7 | import { cn } from "@/lib/utils" 8 | 9 | const toggleVariants = cva( 10 | "inline-flex items-center justify-center gap-2 rounded-md text-sm font-medium transition-colors hover:bg-muted hover:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:pointer-events-none disabled:opacity-50 data-[state=on]:bg-accent data-[state=on]:text-accent-foreground [&_svg]:pointer-events-none [&_svg]:size-4 [&_svg]:shrink-0", 11 | { 12 | variants: { 13 | variant: { 14 | default: "bg-transparent", 15 | outline: 16 | "border border-input bg-transparent shadow-sm hover:bg-accent hover:text-accent-foreground", 17 | }, 18 | size: { 19 | default: "h-9 px-2 min-w-9", 20 | sm: "h-8 px-1.5 min-w-8", 21 | lg: "h-10 px-2.5 min-w-10", 22 | }, 23 | }, 24 | defaultVariants: { 25 | variant: "default", 26 | size: "default", 27 | }, 28 | } 29 | ) 30 | 31 | const Toggle = React.forwardRef< 32 | React.ElementRef, 33 | React.ComponentPropsWithoutRef & 34 | VariantProps 35 | >(({ className, variant, size, ...props }, ref) => ( 36 | 41 | )) 42 | 43 | Toggle.displayName = TogglePrimitive.Root.displayName 44 | 45 | export { Toggle, toggleVariants } 46 | -------------------------------------------------------------------------------- /frontend/src/components/wrapper.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import React, { useState, useEffect } from "react"; 4 | 5 | interface WrapperProps { 6 | delay: number; 7 | children: React.ReactNode; 8 | fallbackComponent?: React.ReactNode; 
9 | } 10 | 11 | export const Wrapper: React.FC = ({ 12 | delay, 13 | children, 14 | fallbackComponent, 15 | }) => { 16 | const [shouldRender, setShouldRender] = useState(delay === 0); 17 | 18 | useEffect(() => { 19 | if (delay > 0) { 20 | const timer = setTimeout(() => { 21 | setShouldRender(true); 22 | }, delay); 23 | 24 | return () => clearTimeout(timer); 25 | } 26 | }, [delay]); 27 | 28 | if (!shouldRender) { 29 | return fallbackComponent || null; 30 | } 31 | 32 | return <>{children}; 33 | }; 34 | -------------------------------------------------------------------------------- /frontend/src/constants/prompts.ts: -------------------------------------------------------------------------------- 1 | export const STORY_PROMPTS = [ 2 | "Write a story about a disgraced samurai seeking redemption by challenging the empire's greatest warriors", 3 | "Tell a tale of a young martial artist who discovers an ancient scroll containing forbidden fighting techniques", 4 | "Create a story about a knight who must win a grand tournament to save their kingdom from invasion", 5 | "Write about a wandering swordmaster who travels between dojos teaching lost sword arts", 6 | "Tell a story of rival martial arts schools competing for an ancient artifact of immense power", 7 | "Create a tale about a retired knight called back for one final quest to train the next generation", 8 | "Write about a tournament fighter working their way up through increasingly difficult opponents", 9 | "Tell a story of a martial arts prodigy who must master multiple fighting styles to defeat a tyrant", 10 | "Create a tale about knights of different kingdoms forming an elite fighting unit", 11 | "Write about a legendary sword that can only be wielded by the most skilled warrior", 12 | "Tell a story of an aging master defending their dojo from corrupt warriors", 13 | "Create a tale about a knight errant helping villages by challenging local warlords", 14 | "Write about martial artists from competing schools 
joining forces against a common enemy", 15 | "Tell a story of a disgraced knight working to restore honor to their fallen order", 16 | "Create a tale about the last surviving master of an ancient fighting style seeking an apprentice", 17 | ]; 18 | -------------------------------------------------------------------------------- /frontend/src/constants/stories.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "prompt": "Write me a story about a samurai who gets thrown into Tokyo in 2040!", 4 | "description": "Yuta finds himself in an unfamiliar Tokyo, where neon signs and hovering vehicles have replaced the wooden buildings of his era. His hand grips his katana tighter as holographic advertisements flicker around him in this chrome and steel world.", 5 | "choices": [ 6 | "Seek out a local ally to understand this new world", 7 | "Investigate the purpose behind his time travel", 8 | "Find a way to return to his own era", 9 | "Embrace the future and learn its technology" 10 | ] 11 | }, 12 | { 13 | "prompt": "Let's create a story about a hacker who discovers something ancient and dangerous in the digital world!", 14 | "description": "Deep within the neural network, a lone hacker's consciousness encounters ancient AI guardians, their code predating the internet itself. 
The digital realm crackles with potential energy as they realize they've stumbled upon something far beyond their usual targets.", 15 | "choices": [ 16 | "Attempt to communicate with the AI guardians", 17 | "Quietly gather intel on the mysterious AIs", 18 | "Prepare defensive algorithms for potential conflict", 19 | "Seek an exit from this dangerous part of the network" 20 | ] 21 | }, 22 | { 23 | "prompt": "Can you help me write about a space station diplomat trying to prevent a war between alien species?", 24 | "description": "Ambassador Chen stands in Space Station Alpha's conference room, surrounded by three vastly different alien species: the warrior Zorgons, the energy-based Ethereals, and the mechanical Synthetics. The weight of Earth's future in the galactic community rests heavily on their shoulders.", 25 | "choices": [ 26 | "Propose a joint scientific mission to foster cooperation", 27 | "Address each species' concerns individually in private meetings", 28 | "Suggest a cultural exchange program to build understanding", 29 | "Call for an emergency session to address underlying tensions" 30 | ] 31 | }, 32 | { 33 | "prompt": "Hey, I want to write about a cyberpunk detective investigating some weird tech-related murders!", 34 | "description": "Detective Zhao stands over another victim in rain-slicked New Shanghai, their neural implant wiped completely clean like the others. 
Her own augmented reality interface buzzes with data as she realizes these aren't just ordinary tech murders.", 35 | "choices": [ 36 | "Interrogate the local AI-enhanced security system", 37 | "Track down the victim's last known contacts", 38 | "Investigate the black market for illegal neural implants", 39 | "Consult with a renegade hacker for insights" 40 | ] 41 | } 42 | ] 43 | -------------------------------------------------------------------------------- /frontend/src/db/db.ts: -------------------------------------------------------------------------------- 1 | import { drizzle } from "drizzle-orm/libsql"; 2 | import { createClient } from "@libsql/client"; 3 | import * as schema from "./schema"; 4 | 5 | if (!process.env.TURSO_CONNECTION_URL) { 6 | throw new Error("TURSO_CONNECTION_URL is not defined"); 7 | } 8 | 9 | if (!process.env.TURSO_AUTH_TOKEN) { 10 | throw new Error("TURSO_AUTH_TOKEN is not defined"); 11 | } 12 | 13 | const client = createClient({ 14 | url: process.env.TURSO_CONNECTION_URL!, 15 | authToken: process.env.TURSO_AUTH_TOKEN, 16 | }); 17 | export const db = drizzle(client, { schema }); 18 | -------------------------------------------------------------------------------- /frontend/src/db/schema.ts: -------------------------------------------------------------------------------- 1 | import { sql } from "drizzle-orm"; 2 | import { integer, sqliteTable, text } from "drizzle-orm/sqlite-core"; 3 | 4 | export const usersTable = sqliteTable("users", { 5 | email: text("email").primaryKey(), 6 | username: text("username").unique(), 7 | credits: integer("credits").notNull().default(3), 8 | isAdmin: integer("is_admin", { mode: "boolean" }).notNull().default(false), 9 | }); 10 | 11 | export type InsertUser = typeof usersTable.$inferInsert; 12 | export type SelectUser = typeof usersTable.$inferSelect; 13 | 14 | export const storiesTable = sqliteTable("stories", { 15 | id: text("id") 16 | .primaryKey() 17 | .default(sql`(uuid())`), 18 | userId: 
text("user_id") 19 | .notNull() 20 | .references(() => usersTable.email, { onDelete: "cascade" }), 21 | title: text("title"), 22 | description: text("description"), 23 | timestamp: integer("timestamp").notNull().default(Date.now()), 24 | public: integer("public").notNull().default(0), 25 | status: text("status", { enum: ["PROCESSING", "GENERATED", "ERROR"] }) 26 | .notNull() 27 | .default("PROCESSING"), 28 | errorMessage: text("error_message"), 29 | 30 | image_prompt: text("image_prompt").notNull(), 31 | story_prompt: text("story_prompt").notNull(), 32 | }); 33 | 34 | export type InsertStory = typeof storiesTable.$inferInsert; 35 | export type SelectStory = typeof storiesTable.$inferSelect; 36 | 37 | export const storyChoicesTable = sqliteTable("story_choices", { 38 | id: text("id") 39 | .primaryKey() 40 | .default(sql`(uuid())`), 41 | userId: text("user_id") 42 | .notNull() 43 | .references(() => usersTable.email, { onDelete: "cascade" }), 44 | storyId: text("story_id") 45 | .notNull() 46 | .references(() => storiesTable.id, { onDelete: "cascade" }), 47 | parentId: text("parent_id"), 48 | 49 | // The title and description the user sees after making the choice 50 | description: text("description").notNull(), 51 | 52 | // The title and description the user sees before making the choice 53 | choice_title: text("choice_title").notNull(), 54 | choice_description: text("choice_description").notNull(), 55 | 56 | // Image Prompt 57 | image_prompt: text("image_prompt").notNull(), 58 | 59 | isTerminal: integer("is_terminal").notNull().default(0), 60 | explored: integer("explored").notNull().default(0), 61 | }); 62 | 63 | export type InsertStoryChoice = typeof storyChoicesTable.$inferInsert; 64 | export type SelectStoryChoice = typeof storyChoicesTable.$inferSelect; 65 | -------------------------------------------------------------------------------- /frontend/src/hooks/use-toast.ts: -------------------------------------------------------------------------------- 1 | "use 
client"; 2 | 3 | // Inspired by react-hot-toast library 4 | import * as React from "react"; 5 | 6 | import type { ToastActionElement, ToastProps } from "@/components/ui/toast"; 7 | 8 | const TOAST_LIMIT = 1; 9 | const TOAST_REMOVE_DELAY = 1000000; 10 | 11 | type ToasterToast = ToastProps & { 12 | id: string; 13 | title?: React.ReactNode; 14 | description?: React.ReactNode; 15 | action?: ToastActionElement; 16 | }; 17 | 18 | // eslint-disable-next-line @typescript-eslint/no-unused-vars 19 | const actionTypes = { 20 | ADD_TOAST: "ADD_TOAST", 21 | UPDATE_TOAST: "UPDATE_TOAST", 22 | DISMISS_TOAST: "DISMISS_TOAST", 23 | REMOVE_TOAST: "REMOVE_TOAST", 24 | } as const; 25 | 26 | let count = 0; 27 | 28 | function genId() { 29 | count = (count + 1) % Number.MAX_SAFE_INTEGER; 30 | return count.toString(); 31 | } 32 | 33 | type ActionType = typeof actionTypes; 34 | 35 | type Action = 36 | | { 37 | type: ActionType["ADD_TOAST"]; 38 | toast: ToasterToast; 39 | } 40 | | { 41 | type: ActionType["UPDATE_TOAST"]; 42 | toast: Partial; 43 | } 44 | | { 45 | type: ActionType["DISMISS_TOAST"]; 46 | toastId?: ToasterToast["id"]; 47 | } 48 | | { 49 | type: ActionType["REMOVE_TOAST"]; 50 | toastId?: ToasterToast["id"]; 51 | }; 52 | 53 | interface State { 54 | toasts: ToasterToast[]; 55 | } 56 | 57 | const toastTimeouts = new Map>(); 58 | 59 | const addToRemoveQueue = (toastId: string) => { 60 | if (toastTimeouts.has(toastId)) { 61 | return; 62 | } 63 | 64 | const timeout = setTimeout(() => { 65 | toastTimeouts.delete(toastId); 66 | dispatch({ 67 | type: "REMOVE_TOAST", 68 | toastId: toastId, 69 | }); 70 | }, TOAST_REMOVE_DELAY); 71 | 72 | toastTimeouts.set(toastId, timeout); 73 | }; 74 | 75 | export const reducer = (state: State, action: Action): State => { 76 | switch (action.type) { 77 | case "ADD_TOAST": 78 | return { 79 | ...state, 80 | toasts: [action.toast, ...state.toasts].slice(0, TOAST_LIMIT), 81 | }; 82 | 83 | case "UPDATE_TOAST": 84 | return { 85 | ...state, 86 | toasts: 
state.toasts.map((t) => 87 | t.id === action.toast.id ? { ...t, ...action.toast } : t 88 | ), 89 | }; 90 | 91 | case "DISMISS_TOAST": { 92 | const { toastId } = action; 93 | 94 | // ! Side effects ! - This could be extracted into a dismissToast() action, 95 | // but I'll keep it here for simplicity 96 | if (toastId) { 97 | addToRemoveQueue(toastId); 98 | } else { 99 | state.toasts.forEach((toast) => { 100 | addToRemoveQueue(toast.id); 101 | }); 102 | } 103 | 104 | return { 105 | ...state, 106 | toasts: state.toasts.map((t) => 107 | t.id === toastId || toastId === undefined 108 | ? { 109 | ...t, 110 | open: false, 111 | } 112 | : t 113 | ), 114 | }; 115 | } 116 | case "REMOVE_TOAST": 117 | if (action.toastId === undefined) { 118 | return { 119 | ...state, 120 | toasts: [], 121 | }; 122 | } 123 | return { 124 | ...state, 125 | toasts: state.toasts.filter((t) => t.id !== action.toastId), 126 | }; 127 | } 128 | }; 129 | 130 | const listeners: Array<(state: State) => void> = []; 131 | 132 | let memoryState: State = { toasts: [] }; 133 | 134 | function dispatch(action: Action) { 135 | memoryState = reducer(memoryState, action); 136 | listeners.forEach((listener) => { 137 | listener(memoryState); 138 | }); 139 | } 140 | 141 | type Toast = Omit; 142 | 143 | function toast({ ...props }: Toast) { 144 | const id = genId(); 145 | 146 | const update = (props: ToasterToast) => 147 | dispatch({ 148 | type: "UPDATE_TOAST", 149 | toast: { ...props, id }, 150 | }); 151 | const dismiss = () => dispatch({ type: "DISMISS_TOAST", toastId: id }); 152 | 153 | dispatch({ 154 | type: "ADD_TOAST", 155 | toast: { 156 | ...props, 157 | id, 158 | open: true, 159 | onOpenChange: (open) => { 160 | if (!open) dismiss(); 161 | }, 162 | }, 163 | }); 164 | 165 | return { 166 | id: id, 167 | dismiss, 168 | update, 169 | }; 170 | } 171 | 172 | function useToast() { 173 | const [state, setState] = React.useState(memoryState); 174 | 175 | React.useEffect(() => { 176 | listeners.push(setState); 177 | 
return () => { 178 | const index = listeners.indexOf(setState); 179 | if (index > -1) { 180 | listeners.splice(index, 1); 181 | } 182 | }; 183 | }, [state]); 184 | 185 | return { 186 | ...state, 187 | toast, 188 | dismiss: (toastId?: string) => dispatch({ type: "DISMISS_TOAST", toastId }), 189 | }; 190 | } 191 | 192 | export { useToast, toast }; 193 | -------------------------------------------------------------------------------- /frontend/src/hooks/useTypewriter.ts: -------------------------------------------------------------------------------- 1 | "use client" 2 | import { useState, useEffect, useCallback } from "react" 3 | 4 | interface TypewriterOptions { 5 | startDelay?: number 6 | onComplete?: () => void 7 | skipAnimation?: boolean 8 | } 9 | 10 | export function useTypewriter( 11 | text: string, 12 | options: TypewriterOptions = {}, 13 | speed = 30 14 | ) { 15 | const { startDelay = 0, onComplete, skipAnimation = false } = options 16 | const [displayText, setDisplayText] = useState(skipAnimation ? 
text : "") 17 | const [isComplete, setIsComplete] = useState(skipAnimation) 18 | 19 | 20 | const reset = useCallback(() => { 21 | setDisplayText("") 22 | setIsComplete(false) 23 | }, []) 24 | 25 | const animate = useCallback(() => { 26 | let currentIndex = 0 27 | let timeoutId: NodeJS.Timeout 28 | 29 | const typeNextCharacter = () => { 30 | if(currentIndex > text.length) { 31 | return 32 | } 33 | 34 | setDisplayText(text.slice(0, currentIndex)) 35 | 36 | 37 | if (currentIndex <= text.length) { 38 | setDisplayText(text.slice(0, currentIndex)) 39 | currentIndex++ 40 | 41 | if (currentIndex <= text.length) { 42 | timeoutId = setTimeout(typeNextCharacter, speed) 43 | } else { 44 | setIsComplete(true) 45 | onComplete?.() 46 | } 47 | } 48 | } 49 | 50 | const initialTimeout = setTimeout(() => { 51 | typeNextCharacter() 52 | }, startDelay) 53 | 54 | return () => { 55 | clearTimeout(timeoutId) 56 | clearTimeout(initialTimeout) 57 | } 58 | }, [text, speed, startDelay]) 59 | 60 | useEffect(() => { 61 | if (!skipAnimation) { 62 | 63 | reset() 64 | const cleanup = animate() 65 | return cleanup 66 | } 67 | }, [animate, reset, text, skipAnimation]) 68 | 69 | return { text: displayText, isComplete, reset } 70 | } 71 | -------------------------------------------------------------------------------- /frontend/src/lib/login-server.ts: -------------------------------------------------------------------------------- 1 | "use server"; 2 | 3 | import { signIn, signOut } from "@/auth"; 4 | 5 | export async function signInWithGoogle() { 6 | await signIn("google", { 7 | redirectTo: "/dashboard", 8 | redirect: true, 9 | }); 10 | } 11 | 12 | export async function signOutWithGoogle() { 13 | await signOut({ 14 | redirectTo: "/", 15 | redirect: true, 16 | }); 17 | } 18 | -------------------------------------------------------------------------------- /frontend/src/lib/story.ts: -------------------------------------------------------------------------------- 1 | "use server"; 2 | 3 | import { 
auth } from "@/auth"; 4 | import { db } from "@/db/db"; 5 | import { storiesTable, storyChoicesTable, usersTable } from "@/db/schema"; 6 | import { and, eq, not } from "drizzle-orm"; 7 | import { revalidatePath } from "next/cache"; 8 | import { redirect } from "next/navigation"; 9 | 10 | export async function resetStoryProgress(storyId: string) { 11 | const session = await auth(); 12 | 13 | if (!session?.user?.email) { 14 | throw new Error("You must be signed in to reset story progress"); 15 | } 16 | 17 | const story = await db.query.storiesTable.findFirst({ 18 | where: eq(storiesTable.id, storyId), 19 | }); 20 | 21 | if (!story) { 22 | throw new Error("Story not found"); 23 | } 24 | 25 | await db 26 | .update(storyChoicesTable) 27 | .set({ 28 | explored: 0, 29 | }) 30 | .where( 31 | and( 32 | eq(storyChoicesTable.storyId, storyId), 33 | not(eq(storyChoicesTable.parentId, "NULL")) 34 | ) 35 | ); 36 | } 37 | 38 | export async function generateStory(prompt: string) { 39 | // Get the authenticated user 40 | const session = await auth(); 41 | 42 | if (!session?.user?.email) { 43 | throw new Error("You must be signed in to submit prompts"); 44 | } 45 | 46 | const remainingCredits = await db.query.usersTable.findFirst({ 47 | where: eq(usersTable.email, session.user.email), 48 | columns: { 49 | isAdmin: true, 50 | credits: true, 51 | }, 52 | }); 53 | 54 | if (remainingCredits?.credits === 0 && !remainingCredits?.isAdmin) { 55 | throw new Error("You have no credits left"); 56 | } 57 | 58 | try { 59 | const uuid = crypto.randomUUID(); 60 | 61 | const response = await fetch( 62 | `${process.env.RESTATE_ENDPOINT}/cyoa/${uuid}/run/send`, 63 | { 64 | method: "POST", 65 | headers: { 66 | "Content-Type": "application/json", 67 | Authorization: `Bearer ${process.env.RESTATE_TOKEN}`, 68 | }, 69 | body: JSON.stringify({ 70 | user_email: session.user.email, 71 | prompt: prompt, 72 | }), 73 | } 74 | ); 75 | 76 | if (!response.ok) { 77 | console.error("Failed to generate story", 
response); 78 | throw new Error("Failed to generate story"); 79 | } 80 | 81 | const data = await response.json(); 82 | 83 | if (!remainingCredits?.isAdmin) { 84 | // Decrement user credits by 1 85 | await db 86 | .update(usersTable) 87 | .set({ 88 | credits: remainingCredits?.credits ? remainingCredits.credits - 1 : 0, 89 | }) 90 | .where(eq(usersTable.email, session.user.email)); 91 | } 92 | return data; 93 | } catch (error) { 94 | console.error("Error submitting prompt:", error); 95 | throw new Error("Failed to submit prompt"); 96 | } 97 | } 98 | 99 | export async function deleteStory(storyId: string) { 100 | await db.delete(storiesTable).where(eq(storiesTable.id, storyId)); 101 | revalidatePath("/dashboard"); 102 | redirect("/dashboard"); 103 | } 104 | 105 | export async function toggleStoryVisibility( 106 | storyId: string, 107 | isPublic: boolean 108 | ) { 109 | await db 110 | .update(storiesTable) 111 | .set({ public: isPublic ? 0 : 1 }) 112 | .where(eq(storiesTable.id, storyId)); 113 | 114 | revalidatePath(`/dashboard/story/${storyId}`); 115 | redirect(`/dashboard/story/${storyId}`); 116 | } 117 | -------------------------------------------------------------------------------- /frontend/src/lib/tree.ts: -------------------------------------------------------------------------------- 1 | // import { Choice, ChoiceNode } from "@/types/choice"; 2 | 3 | import { SelectStoryChoice } from "@/db/schema"; 4 | 5 | type treeChoice = SelectStoryChoice & { children: treeChoice[] }; 6 | 7 | export function buildTree( 8 | choices: SelectStoryChoice[], 9 | parentId: string | "NULL" 10 | ): treeChoice[] { 11 | return choices 12 | .filter(choice => choice.parentId === parentId) 13 | .map(choice => ({ 14 | ...choice, 15 | children: buildTree(choices, choice.id) 16 | })); 17 | } 18 | 19 | export function getPath( 20 | choices: SelectStoryChoice[], 21 | targetId: string 22 | ): SelectStoryChoice[] { 23 | const path: SelectStoryChoice[] = []; 24 | let currentId: string | null = 
targetId; 25 | 26 | while (currentId) { 27 | const current = choices.find(c => c.id === currentId); 28 | if (!current) break; 29 | path.unshift(current); 30 | currentId = current.parentId; 31 | } 32 | 33 | return path; 34 | } 35 | 36 | -------------------------------------------------------------------------------- /frontend/src/lib/user.ts: -------------------------------------------------------------------------------- 1 | 'use server' 2 | 3 | import { auth } from "@/auth"; 4 | import { db } from "@/db/db"; 5 | import { usersTable } from "@/db/schema"; 6 | import { eq } from "drizzle-orm"; 7 | 8 | export async function validateUsername(username: string) { 9 | const session = await auth(); 10 | const email = session?.user?.email; 11 | 12 | if (!email) { 13 | return { 14 | success: false, 15 | message: "User not found. Please sign in first.", 16 | }; 17 | } 18 | 19 | try { 20 | // Update username, will fail if username already exists due to unique constraint 21 | await db 22 | .update(usersTable) 23 | .set({ username: username }) 24 | .where(eq(usersTable.email, email)); 25 | 26 | console.log(`Username ${username} updated for user ${email}`); 27 | return { success: true, message: `Access granted: ${username}` }; 28 | } catch (error) { 29 | console.error(error); 30 | return { 31 | success: false, 32 | message: "Username is already taken. 
Please choose another one.", 33 | }; 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /frontend/src/lib/utils.ts: -------------------------------------------------------------------------------- 1 | import { clsx, type ClassValue } from "clsx" 2 | import { twMerge } from "tailwind-merge" 3 | 4 | export function cn(...inputs: ClassValue[]) { 5 | return twMerge(clsx(inputs)) 6 | } 7 | -------------------------------------------------------------------------------- /frontend/src/middleware.ts: -------------------------------------------------------------------------------- 1 | import { auth } from "@/auth"; 2 | 3 | export default auth((req) => { 4 | // Check if user is not authenticated and trying to access protected routes 5 | if (!req.auth) { 6 | // Just redirect to login if user is not authenticated and trying to access dashboard 7 | if (req.nextUrl.pathname.startsWith("/dashboard")) { 8 | const newUrl = new URL("/login", req.nextUrl.origin); 9 | return Response.redirect(newUrl); 10 | } 11 | 12 | // Throw 401 if user is not authenticated and trying to access API routes 13 | if (req.nextUrl.pathname.startsWith("/api/stories")) { 14 | return new Response(null, { 15 | status: 401, 16 | statusText: "Unauthorized", 17 | }); 18 | } 19 | } 20 | }); 21 | -------------------------------------------------------------------------------- /frontend/src/providers/ReactQueryProvider.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | import React from "react"; 3 | import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; 4 | 5 | const queryClient = new QueryClient(); 6 | 7 | const ReactQueryProvider = ({ children }: { children: React.ReactNode }) => { 8 | return ( 9 | {children} 10 | ); 11 | }; 12 | 13 | export default ReactQueryProvider; 14 | -------------------------------------------------------------------------------- /frontend/src/types/next-auth.d.ts: 
-------------------------------------------------------------------------------- 1 | import type { DefaultSession } from "next-auth"; 2 | 3 | declare module "next-auth" { 4 | interface Session { 5 | user: { 6 | username: string; 7 | credits: number; 8 | isAdmin: boolean; 9 | } & DefaultSession["user"]; 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /frontend/tailwind.config.ts: -------------------------------------------------------------------------------- 1 | import type { Config } from "tailwindcss"; 2 | 3 | const config: Config = { 4 | darkMode: ["class"], 5 | content: [ 6 | "./src/pages/**/*.{js,ts,jsx,tsx,mdx}", 7 | "./src/components/**/*.{js,ts,jsx,tsx,mdx}", 8 | "./src/app/**/*.{js,ts,jsx,tsx,mdx}", 9 | ], 10 | theme: { 11 | extend: { 12 | colors: { 13 | background: 'hsl(var(--background))', 14 | foreground: 'hsl(var(--foreground))', 15 | card: { 16 | DEFAULT: 'hsl(var(--card))', 17 | foreground: 'hsl(var(--card-foreground))' 18 | }, 19 | popover: { 20 | DEFAULT: 'hsl(var(--popover))', 21 | foreground: 'hsl(var(--popover-foreground))' 22 | }, 23 | primary: { 24 | DEFAULT: 'hsl(var(--primary))', 25 | foreground: 'hsl(var(--primary-foreground))' 26 | }, 27 | secondary: { 28 | DEFAULT: 'hsl(var(--secondary))', 29 | foreground: 'hsl(var(--secondary-foreground))' 30 | }, 31 | muted: { 32 | DEFAULT: 'hsl(var(--muted))', 33 | foreground: 'hsl(var(--muted-foreground))' 34 | }, 35 | accent: { 36 | DEFAULT: 'hsl(var(--accent))', 37 | foreground: 'hsl(var(--accent-foreground))' 38 | }, 39 | destructive: { 40 | DEFAULT: 'hsl(var(--destructive))', 41 | foreground: 'hsl(var(--destructive-foreground))' 42 | }, 43 | border: 'hsl(var(--border))', 44 | input: 'hsl(var(--input))', 45 | ring: 'hsl(var(--ring))', 46 | chart: { 47 | '1': 'hsl(var(--chart-1))', 48 | '2': 'hsl(var(--chart-2))', 49 | '3': 'hsl(var(--chart-3))', 50 | '4': 'hsl(var(--chart-4))', 51 | '5': 'hsl(var(--chart-5))' 52 | } 53 | }, 54 | borderRadius: { 55 
| lg: 'var(--radius)', 56 | md: 'calc(var(--radius) - 2px)', 57 | sm: 'calc(var(--radius) - 4px)' 58 | } 59 | } 60 | }, 61 | plugins: [require("tailwindcss-animate")], 62 | }; 63 | export default config; 64 | -------------------------------------------------------------------------------- /frontend/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "lib": ["dom", "dom.iterable", "esnext"], 4 | "allowJs": true, 5 | "skipLibCheck": true, 6 | "strict": true, 7 | "noEmit": true, 8 | "esModuleInterop": true, 9 | "module": "esnext", 10 | "moduleResolution": "bundler", 11 | "resolveJsonModule": true, 12 | "isolatedModules": true, 13 | "jsx": "preserve", 14 | "incremental": true, 15 | "plugins": [ 16 | { 17 | "name": "next" 18 | } 19 | ], 20 | "paths": { 21 | "@/*": ["./src/*"] 22 | } 23 | }, 24 | "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"], 25 | "exclude": ["node_modules"] 26 | } 27 | -------------------------------------------------------------------------------- /modal/audio.py: -------------------------------------------------------------------------------- 1 | import modal 2 | import boto3 3 | from pydantic import BaseModel 4 | 5 | 6 | def download_model(): 7 | from transformers import pipeline 8 | 9 | pipeline("text-to-audio", model="facebook/musicgen-medium") 10 | 11 | 12 | image = ( 13 | modal.Image.debian_slim() 14 | .apt_install("git", "ffmpeg") 15 | .pip_install("torch", "transformers", "scipy", "boto3", "fastapi") 16 | .pip_install("requests") 17 | .run_function(download_model) 18 | ) 19 | 20 | app = modal.App("audio-service") 21 | 22 | # Configure S3 credentials 23 | s3_secret = modal.Secret.from_name("aws-secret") 24 | 25 | 26 | class AudioRequest(BaseModel): 27 | prompt: str 28 | storyId: str 29 | callback_url: str 30 | callback_token: str 31 | 32 | 33 | @app.function(image=image, gpu="A10g", secrets=[s3_secret]) 34 | 
@modal.web_endpoint(method="POST") 35 | def generate_audio(request: AudioRequest): 36 | # Use a pipeline as a high-level helper 37 | from transformers import pipeline 38 | from scipy.io import wavfile 39 | import tempfile 40 | import os 41 | import requests 42 | 43 | print(request) 44 | 45 | pipe = pipeline("text-to-audio", model="facebook/musicgen-medium") 46 | audio = pipe( 47 | request.prompt, 48 | forward_params={"do_sample": True}, 49 | ) 50 | 51 | # Write to temporary wav file 52 | with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file: 53 | wavfile.write(temp_file.name, audio["sampling_rate"], audio["audio"]) 54 | 55 | # Upload to S3 56 | s3 = boto3.client("s3") 57 | s3.put_object( 58 | Bucket=os.getenv("AWS_BUCKET_NAME"), 59 | Key=f"{request.storyId}/theme_song.wav", 60 | Body=open(temp_file.name, "rb").read(), 61 | ) 62 | 63 | headers = {"Authorization": f"Bearer {request.callback_token}"} 64 | 65 | requests.post(request.callback_url, headers=headers) 66 | -------------------------------------------------------------------------------- /modal/download.py: -------------------------------------------------------------------------------- 1 | import modal 2 | 3 | image = ( 4 | modal.Image.debian_slim(python_version="3.11") 5 | .pip_install("huggingface_hub[hf_transfer]==0.26.2") 6 | .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) 7 | .run_commands("rm -rf /root/comfy/ComfyUI/models") 8 | ) 9 | 10 | app = modal.App(name="comfyui-models", image=image) 11 | 12 | # Set up model storage 13 | vol = modal.Volume.from_name("comfyui-models", create_if_missing=True) 14 | 15 | 16 | @app.function( 17 | volumes={"/root/models": vol}, 18 | secrets=[modal.Secret.from_name("my-huggingface-secret")], 19 | ) 20 | def hf_download(repo_id: str, filename: str, model_type: str): 21 | from huggingface_hub import hf_hub_download 22 | 23 | print(f"Downloading {filename} from {repo_id} to {model_type}") 24 | hf_hub_download( 25 | repo_id=repo_id, 26 | filename=filename, 27 
| local_dir=f"/root/models/{model_type}", 28 | ) 29 | 30 | 31 | @app.local_entrypoint() 32 | def download_models(): 33 | models_to_download = [ 34 | ( 35 | "black-forest-labs/FLUX.1-dev", 36 | "ae.safetensors", 37 | "vae", 38 | ), 39 | ( 40 | "black-forest-labs/FLUX.1-dev", 41 | "flux1-dev.safetensors", 42 | "unet", 43 | ), 44 | ( 45 | "city96/FLUX.1-dev-gguf", 46 | "flux1-dev-Q5_K_S.gguf", 47 | "unet", 48 | ), 49 | ( 50 | "comfyanonymous/flux_text_encoders", 51 | "t5xxl_fp8_e4m3fn.safetensors", 52 | "clip", 53 | ), 54 | ("comfyanonymous/flux_text_encoders", "clip_l.safetensors", "clip"), 55 | ] 56 | list(hf_download.starmap(models_to_download)) 57 | -------------------------------------------------------------------------------- /modal/generate_audio.py: -------------------------------------------------------------------------------- 1 | import modal 2 | 3 | kokoro_volume = modal.Volume.from_name("kokoro-volume", create_if_missing=True) 4 | 5 | 6 | image = ( 7 | modal.Image.debian_slim() 8 | .apt_install("git", "ffmpeg", "espeak-ng", "git-lfs") 9 | .pip_install( 10 | "torch", "transformers", "scipy", "boto3", "fastapi", "phonemizer", "munch" 11 | ) 12 | .pip_install("requests", "ipython") 13 | ) 14 | 15 | app = modal.App("kokoro-tts", image=image) 16 | 17 | 18 | @app.function(volumes={"/kokoro_volume": kokoro_volume}) 19 | def download_kokoro(): 20 | import os 21 | 22 | if not os.path.exists("/kokoro_volume/Kokoro-82M"): 23 | os.system("git clone https://huggingface.co/hexgrad/Kokoro-82M /kokoro_volume") 24 | 25 | 26 | @app.cls( 27 | gpu="T4", 28 | volumes={"/kokoro_volume": kokoro_volume}, 29 | secrets=[modal.Secret.from_name("aws-secret")], 30 | concurrency_limit=2, 31 | ) 32 | class KokoroGenerator: 33 | @modal.enter() 34 | def load_model(self): 35 | import os 36 | import sys 37 | import torch 38 | 39 | os.chdir("/kokoro_volume") 40 | sys.path.append(".") 41 | from kokoro import generate 42 | from models import build_model 43 | 44 | device = "cuda" if 
torch.cuda.is_available() else "cpu" 45 | self.MODEL = build_model("kokoro-v0_19.pth", device) 46 | self.VOICE_NAME = "bm_lewis" 47 | self.VOICEPACK = torch.load( 48 | f"voices/{self.VOICE_NAME}.pt", weights_only=True 49 | ).to(device) 50 | self.generate_fn = generate 51 | 52 | @modal.web_endpoint(method="POST", requires_proxy_auth=True) 53 | def generate_audio(self, node: dict): 54 | import boto3 55 | import os 56 | import tempfile 57 | from scipy.io import wavfile 58 | 59 | audio, out_ps = self.generate_fn( 60 | self.MODEL, node["prompt"], self.VOICEPACK, lang="b" 61 | ) 62 | 63 | with tempfile.NamedTemporaryFile(suffix=".wav") as tmp_file: 64 | wavfile.write(tmp_file.name, rate=22050, data=audio) 65 | tmp_file.flush() 66 | tmp_file.seek(0) 67 | audio_bytes = tmp_file.read() 68 | 69 | s3 = boto3.client("s3") 70 | s3.put_object( 71 | Bucket=os.getenv("AWS_BUCKET_NAME"), 72 | Key=f"{node['story_id']}/{node['node_id']}.wav", 73 | Body=audio_bytes, 74 | ) 75 | -------------------------------------------------------------------------------- /modal/images.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | import uuid 4 | from pathlib import Path 5 | from typing import Dict 6 | 7 | import modal 8 | 9 | image = ( 10 | modal.Image.debian_slim(python_version="3.11") 11 | .apt_install("git") 12 | .pip_install("comfy-cli==1.2.7", "requests", "boto3") 13 | .run_commands("comfy --skip-prompt install --nvidia") 14 | .run_commands("comfy node install was-node-suite-comfyui ComfyUI-GGUF") 15 | .run_commands( # needs to be empty for Volume mount to work 16 | "rm -rf /root/comfy/ComfyUI/models" 17 | ) 18 | ) 19 | 20 | app = modal.App(name="comfyui-api", image=image) 21 | 22 | vol = modal.Volume.from_name("comfyui-models", create_if_missing=True) 23 | 24 | 25 | @app.cls( 26 | gpu="A100", 27 | volumes={"/root/comfy/ComfyUI/models": vol}, 28 | mounts=[ 29 | modal.Mount.from_local_file( 30 | 
"workflows/flux.json", 31 | "/root/flux.json", 32 | ), 33 | ], 34 | secrets=[modal.Secret.from_name("aws-secret")], 35 | concurrency_limit=45, 36 | ) 37 | class ComfyUI: 38 | @modal.enter() 39 | def launch_comfy_background(self): 40 | cmd = "comfy launch --background" 41 | subprocess.run(cmd, shell=True, check=True) 42 | 43 | @modal.method() 44 | def infer(self, workflow_path: str = "/root/flux.json"): 45 | cmd = f"comfy run --workflow {workflow_path} --wait --timeout 1200" 46 | subprocess.run(cmd, shell=True, check=True) 47 | 48 | output_dir = "/root/comfy/ComfyUI/output" 49 | workflow = json.loads(Path(workflow_path).read_text()) 50 | file_prefix = [ 51 | node.get("inputs") 52 | for node in workflow.values() 53 | if node.get("class_type") == "SaveImage" 54 | or node.get("class_type") == "SaveAnimatedWEBP" 55 | ][0]["filename_prefix"] 56 | 57 | for f in Path(output_dir).iterdir(): 58 | if f.name.startswith(file_prefix): 59 | return f.read_bytes() 60 | 61 | @modal.web_endpoint(method="POST") 62 | def api(self, node: Dict): 63 | import boto3 64 | import os 65 | 66 | print(f"Recieved request: {node}") 67 | workflow_data = json.loads(Path("/root/flux.json").read_text()) 68 | 69 | # Update workflow with node prompt 70 | workflow_data["6"]["inputs"]["text"] = node["prompt"] 71 | 72 | # Set unique filename prefix 73 | client_id = uuid.uuid4().hex 74 | workflow_data["9"]["inputs"]["filename_prefix"] = client_id 75 | 76 | # Save temporary workflow file 77 | new_workflow_file = f"{client_id}.json" 78 | json.dump(workflow_data, Path(new_workflow_file).open("w")) 79 | 80 | # Generate image 81 | img_bytes = self.infer.local(new_workflow_file) 82 | 83 | # Upload to S3 84 | s3 = boto3.client("s3") 85 | s3.put_object( 86 | Bucket=os.getenv("AWS_BUCKET_NAME"), 87 | Key=f"{node['story_id']}/{node['node_id']}.png", 88 | Body=img_bytes, 89 | ) 90 | -------------------------------------------------------------------------------- /modal/pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [project] 2 | name = "modal-service" 3 | version = "0.1.0" 4 | description = "Modal Service" 5 | readme = "README.md" 6 | requires-python = ">=3.12" 7 | dependencies = [ 8 | "boto3>=1.35.90", 9 | "diffusers>=0.32.1", 10 | "scipy>=1.14.1", 11 | "soundfile>=0.12.1", 12 | "torch>=2.5.1", 13 | "transformers>=4.47.1", 14 | ] 15 | 16 | [project.optional-dependencies] 17 | dev = [ 18 | "pytest>=7.4.0", 19 | "pytest-mock>=3.11.1", 20 | ] 21 | -------------------------------------------------------------------------------- /modal/tests/README.md: -------------------------------------------------------------------------------- 1 | # CYOA Modal Service Unit Tests 2 | 3 | This directory contains unit tests for the CYOA (Choose Your Own Adventure) Modal services. 4 | 5 | ## Running the Tests 6 | 7 | To run the tests, you'll need to install the development dependencies: 8 | 9 | ```bash 10 | cd modal 11 | pip install -e ".[dev]" 12 | ``` 13 | 14 | Then, you can run the tests using pytest: 15 | 16 | ```bash 17 | # Run all tests 18 | pytest 19 | 20 | # Run tests with verbose output 21 | pytest -v 22 | 23 | # Run a specific test file 24 | pytest tests/test_images.py 25 | 26 | # Run a specific test 27 | pytest tests/test_images.py::test_api 28 | ``` 29 | 30 | ## Test Coverage 31 | 32 | The tests cover the following components: 33 | 34 | 1. **Image Generation Service** (`test_images.py`): 35 | - Testing the ComfyUI background process launch 36 | - Testing the inference method 37 | - Testing the API endpoint for image generation 38 | 39 | 2. **Audio Generation Service** (`test_audio.py`): 40 | - Testing the audio generation endpoint 41 | - Testing the integration with the text-to-audio pipeline 42 | - Testing S3 upload and callback functionality 43 | 44 | ## Adding New Tests 45 | 46 | When adding new tests, follow these guidelines: 47 | 48 | 1. Create test files with the `test_` prefix 49 | 2. 
Use descriptive test function names that explain what is being tested 50 | 3. Use fixtures to set up common test dependencies 51 | 4. Mock external dependencies to isolate the code being tested 52 | 5. Test both success and failure scenarios 53 | 54 | ## Test Structure 55 | 56 | Each test file follows a similar structure: 57 | 58 | 1. Import necessary modules and the code being tested 59 | 2. Define mock classes for the Modal library (since it's not available in the test environment) 60 | 3. Define fixtures for setting up test dependencies 61 | 4. Define test functions that use the fixtures 62 | 5. Assert expected outcomes 63 | 64 | ## Mocking Modal 65 | 66 | Since the Modal library is not available in the test environment, we use mock classes to simulate its behavior. This allows us to test the code without actually deploying it to Modal. 67 | 68 | ## Continuous Integration 69 | 70 | These tests can be integrated into a CI/CD pipeline to ensure code quality before deployment. -------------------------------------------------------------------------------- /modal/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/essencevc/cyoa/2c991aa4bdbcc8c09f40017c42d9f8c0f87eb2b0/modal/tests/__init__.py -------------------------------------------------------------------------------- /modal/tests/test_audio.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import MagicMock, patch, mock_open 3 | import tempfile 4 | 5 | 6 | # Mock the modal library since it's not available in the test environment 7 | class MockModal: 8 | class Image: 9 | @staticmethod 10 | def debian_slim(): 11 | return MockImage() 12 | 13 | class App: 14 | def __init__(self, name=None): 15 | self.name = name 16 | 17 | def function(self, image=None, gpu=None, secrets=None): 18 | def decorator(func): 19 | return func 20 | return decorator 21 | 22 | def 
web_endpoint(self, method=None): 23 | def decorator(func): 24 | return func 25 | return decorator 26 | 27 | class Secret: 28 | @staticmethod 29 | def from_name(name): 30 | return MockSecret(name) 31 | 32 | 33 | class MockImage: 34 | def apt_install(self, *args): 35 | return self 36 | 37 | def pip_install(self, *args): 38 | return self 39 | 40 | def run_function(self, func): 41 | return self 42 | 43 | 44 | class MockSecret: 45 | def __init__(self, name): 46 | self.name = name 47 | 48 | 49 | # Patch modal before importing the module 50 | with patch.dict('sys.modules', {'modal': MagicMock()}): 51 | import sys 52 | sys.modules['modal'] = MockModal() 53 | 54 | # Now we can import our module 55 | from audio import generate_audio, AudioRequest 56 | 57 | 58 | @pytest.fixture 59 | def mock_transformers(): 60 | with patch('transformers.pipeline') as mock_pipeline: 61 | mock_pipe = MagicMock() 62 | mock_pipe.return_value = {"audio": [0.1, 0.2, 0.3], "sampling_rate": 44100} 63 | mock_pipeline.return_value = mock_pipe 64 | yield mock_pipeline 65 | 66 | 67 | @pytest.fixture 68 | def mock_wavfile(): 69 | with patch('scipy.io.wavfile') as mock_wavfile: 70 | yield mock_wavfile 71 | 72 | 73 | @pytest.fixture 74 | def mock_tempfile(): 75 | with patch('tempfile.NamedTemporaryFile') as mock_temp: 76 | mock_temp_file = MagicMock() 77 | mock_temp_file.name = "/tmp/test_audio.wav" 78 | mock_temp.return_value.__enter__.return_value = mock_temp_file 79 | yield mock_temp 80 | 81 | 82 | @pytest.fixture 83 | def mock_boto3(): 84 | with patch('boto3.client') as mock_client: 85 | mock_s3 = MagicMock() 86 | mock_client.return_value = mock_s3 87 | yield mock_s3 88 | 89 | 90 | @pytest.fixture 91 | def mock_requests(): 92 | with patch('requests.post') as mock_post: 93 | yield mock_post 94 | 95 | 96 | def test_generate_audio(mock_transformers, mock_wavfile, mock_tempfile, mock_boto3, mock_requests): 97 | """Test the audio generation endpoint""" 98 | # Mock open 99 | with patch('builtins.open', 
mock_open(read_data=b"test_audio_data")) as mock_file: 100 | # Mock os.getenv 101 | with patch('os.getenv') as mock_getenv: 102 | mock_getenv.return_value = "test-bucket" 103 | 104 | # Create a request 105 | request = AudioRequest( 106 | prompt="Generate epic adventure music", 107 | storyId="test_story_id", 108 | callback_url="http://test-callback", 109 | callback_token="test-token" 110 | ) 111 | 112 | # Call the function 113 | generate_audio(request) 114 | 115 | # Check that the pipeline was created correctly 116 | mock_transformers.assert_called_once_with("text-to-audio", model="facebook/musicgen-medium") 117 | 118 | # Check that the pipeline was called with the prompt 119 | mock_transformers.return_value.assert_called_once_with( 120 | "Generate epic adventure music", 121 | forward_params={"do_sample": True} 122 | ) 123 | 124 | # Check that the audio was written to a file 125 | mock_wavfile.write.assert_called_once_with( 126 | "/tmp/test_audio.wav", 127 | 44100, 128 | [0.1, 0.2, 0.3] 129 | ) 130 | 131 | # Check that the file was opened 132 | mock_file.assert_called_once_with("/tmp/test_audio.wav", "rb") 133 | 134 | # Check that the audio was uploaded to S3 135 | mock_boto3.put_object.assert_called_once_with( 136 | Bucket="test-bucket", 137 | Key="test_story_id/theme_song.wav", 138 | Body=b"test_audio_data" 139 | ) 140 | 141 | # Check that the callback was called 142 | mock_requests.assert_called_once_with( 143 | "http://test-callback", 144 | headers={"Authorization": "Bearer test-token"} 145 | ) -------------------------------------------------------------------------------- /modal/tests/test_images.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import MagicMock, patch, mock_open 3 | import json 4 | import uuid 5 | from pathlib import Path 6 | 7 | 8 | # Mock the modal library since it's not available in the test environment 9 | class MockModal: 10 | class Image: 11 | @staticmethod 12 |
def debian_slim(python_version=None): 13 | return MockImage() 14 | 15 | class App: 16 | def __init__(self, name=None, image=None): 17 | self.name = name 18 | self.image = image 19 | 20 | def cls(self, **kwargs): 21 | def decorator(cls): 22 | return cls 23 | return decorator 24 | 25 | def web_endpoint(self, method=None): 26 | def decorator(func): 27 | return func 28 | return decorator 29 | 30 | class Volume: 31 | @staticmethod 32 | def from_name(name, create_if_missing=False): 33 | return MockVolume() 34 | 35 | class Mount: 36 | @staticmethod 37 | def from_local_file(source, dest): 38 | return MockMount(source, dest) 39 | 40 | class Secret: 41 | @staticmethod 42 | def from_name(name): 43 | return MockSecret(name) 44 | 45 | @staticmethod 46 | def enter(): 47 | def decorator(func): 48 | return func 49 | return decorator 50 | 51 | @staticmethod 52 | def method(): 53 | def decorator(func): 54 | return func 55 | return decorator 56 | 57 | 58 | class MockImage: 59 | def apt_install(self, *args): 60 | return self 61 | 62 | def pip_install(self, *args): 63 | return self 64 | 65 | def run_commands(self, *args): 66 | return self 67 | 68 | 69 | class MockVolume: 70 | pass 71 | 72 | 73 | class MockMount: 74 | def __init__(self, source, dest): 75 | self.source = source 76 | self.dest = dest 77 | 78 | 79 | class MockSecret: 80 | def __init__(self, name): 81 | self.name = name 82 | 83 | 84 | # Patch modal before importing the module 85 | with patch.dict('sys.modules', {'modal': MagicMock()}): 86 | import sys 87 | sys.modules['modal'] = MockModal() 88 | 89 | # Now we can import our module 90 | from images import ComfyUI 91 | 92 | 93 | @pytest.fixture 94 | def comfy_ui(): 95 | return ComfyUI() 96 | 97 | 98 | @pytest.fixture 99 | def mock_subprocess(): 100 | with patch('subprocess.run') as mock_run: 101 | yield mock_run 102 | 103 | 104 | @pytest.fixture 105 | def mock_path(): 106 | with patch('pathlib.Path') as mock_path: 107 | mock_file = MagicMock() 108 |
mock_file.read_text.return_value = json.dumps({ 109 | "6": {"inputs": {"text": "original prompt"}}, 110 | "9": {"inputs": {"filename_prefix": "original_prefix"}} 111 | }) 112 | mock_file.read_bytes.return_value = b"test_image_data" 113 | mock_file.name = "test_file.png" 114 | 115 | mock_path.return_value = mock_file 116 | 117 | # Mock the iterdir method to return a list with one file 118 | mock_dir = MagicMock() 119 | mock_dir.iterdir.return_value = [mock_file] 120 | mock_path.side_effect = lambda p: mock_file if p.endswith(".json") else mock_dir 121 | 122 | yield mock_path 123 | 124 | 125 | @pytest.fixture 126 | def mock_boto3(): 127 | with patch('boto3.client') as mock_client: 128 | mock_s3 = MagicMock() 129 | mock_client.return_value = mock_s3 130 | yield mock_s3 131 | 132 | 133 | def test_launch_comfy_background(comfy_ui, mock_subprocess): 134 | """Test that the ComfyUI background process is launched correctly""" 135 | comfy_ui.launch_comfy_background() 136 | 137 | mock_subprocess.assert_called_once_with( 138 | "comfy launch --background", 139 | shell=True, 140 | check=True 141 | ) 142 | 143 | 144 | def test_infer(comfy_ui, mock_subprocess, mock_path): 145 | """Test the inference method""" 146 | result = comfy_ui.infer() 147 | 148 | # Check that the comfy run command was executed 149 | mock_subprocess.assert_called_once_with( 150 | "comfy run --workflow /root/flux.json --wait --timeout 1200", 151 | shell=True, 152 | check=True 153 | ) 154 | 155 | # Check that the file was read 156 | mock_path.assert_any_call("/root/flux.json") 157 | 158 | # Check that the output directory was checked 159 | mock_path.assert_any_call("/root/comfy/ComfyUI/output") 160 | 161 | # Check that the result is the file content 162 | assert result == b"test_image_data" 163 | 164 | 165 | def test_api(comfy_ui, mock_path, mock_boto3): 166 | """Test the API endpoint""" 167 | # Mock uuid.uuid4 168 | with patch('uuid.uuid4') as mock_uuid: 169 | mock_uuid.return_value.hex = "test_uuid" 170 | 171
| # Mock open 172 | with patch('builtins.open', mock_open()) as mock_file: 173 | # Mock os.getenv 174 | with patch('os.getenv') as mock_getenv: 175 | mock_getenv.return_value = "test-bucket" 176 | 177 | # Mock the infer method 178 | comfy_ui.infer = MagicMock(return_value=b"test_image_data") 179 | 180 | # Call the API 181 | comfy_ui.api({ 182 | "prompt": "test prompt", 183 | "story_id": "test_story_id", 184 | "node_id": "test_node_id" 185 | }) 186 | 187 | # Check that the workflow was loaded 188 | mock_path.assert_any_call("/root/flux.json") 189 | 190 | # Check that the workflow template was read (use return_value: calling mock_path() with no argument raises TypeError in the side_effect lambda) 191 | assert mock_path.return_value.read_text.called 192 | 193 | # Check that the new workflow was saved 194 | assert mock_file.called 195 | 196 | # Check that infer was called with the new workflow 197 | comfy_ui.infer.assert_called_once_with("test_uuid.json") 198 | 199 | # Check that the result was uploaded to S3 200 | mock_boto3.put_object.assert_called_once_with( 201 | Bucket="test-bucket", 202 | Key="test_story_id/test_node_id.png", 203 | Body=b"test_image_data" 204 | ) -------------------------------------------------------------------------------- /modal/workflows/flux.json: -------------------------------------------------------------------------------- 1 | { 2 | "5": { 3 | "inputs": { 4 | "width": 1024, 5 | "height": 1024, 6 | "batch_size": 1 7 | }, 8 | "class_type": "EmptyLatentImage", 9 | "_meta": { 10 | "title": "Empty Latent Image" 11 | } 12 | }, 13 | "6": { 14 | "inputs": { 15 | "text": "The camera focuses tightly on the linguist’s face, eyes wide with fascination, as the symbols glow softly in the dim, cavernous chamber. The room is ancient, its stone walls covered in intricate carvings, reminiscent of the labyrinthine designs seen in Indiana Jones and the Last Crusade. The faint hum of the symbols fills the air, adding to the palpable tension. In the center of the room, the symbols suddenly shimmer and rearrange, forming a glowing, cryptic message.
The air seems to pulse as the explorers, silhouetted against the eerie light, exchange nervous glances, the weight of the revelation heavy on their shoulders.", 16 | "clip": ["11", 0] 17 | }, 18 | "class_type": "CLIPTextEncode", 19 | "_meta": { 20 | "title": "CLIP Text Encode (Prompt)" 21 | } 22 | }, 23 | "8": { 24 | "inputs": { 25 | "samples": ["13", 0], 26 | "vae": ["10", 0] 27 | }, 28 | "class_type": "VAEDecode", 29 | "_meta": { 30 | "title": "VAE Decode" 31 | } 32 | }, 33 | "9": { 34 | "inputs": { 35 | "filename_prefix": "ComfyUI", 36 | "images": ["28", 0] 37 | }, 38 | "class_type": "SaveImage", 39 | "_meta": { 40 | "title": "Save Image" 41 | } 42 | }, 43 | "10": { 44 | "inputs": { 45 | "vae_name": "ae.safetensors" 46 | }, 47 | "class_type": "VAELoader", 48 | "_meta": { 49 | "title": "Load VAE" 50 | } 51 | }, 52 | "11": { 53 | "inputs": { 54 | "clip_name1": "t5xxl_fp8_e4m3fn.safetensors", 55 | "clip_name2": "clip_l.safetensors", 56 | "type": "flux" 57 | }, 58 | "class_type": "DualCLIPLoader", 59 | "_meta": { 60 | "title": "DualCLIPLoader" 61 | } 62 | }, 63 | "13": { 64 | "inputs": { 65 | "noise": ["25", 0], 66 | "guider": ["22", 0], 67 | "sampler": ["16", 0], 68 | "sigmas": ["17", 0], 69 | "latent_image": ["5", 0] 70 | }, 71 | "class_type": "SamplerCustomAdvanced", 72 | "_meta": { 73 | "title": "SamplerCustomAdvanced" 74 | } 75 | }, 76 | "16": { 77 | "inputs": { 78 | "sampler_name": "euler" 79 | }, 80 | "class_type": "KSamplerSelect", 81 | "_meta": { 82 | "title": "KSamplerSelect" 83 | } 84 | }, 85 | "17": { 86 | "inputs": { 87 | "scheduler": "simple", 88 | "steps": 10, 89 | "denoise": 1, 90 | "model": ["29", 0] 91 | }, 92 | "class_type": "BasicScheduler", 93 | "_meta": { 94 | "title": "BasicScheduler" 95 | } 96 | }, 97 | "22": { 98 | "inputs": { 99 | "model": ["29", 0], 100 | "conditioning": ["6", 0] 101 | }, 102 | "class_type": "BasicGuider", 103 | "_meta": { 104 | "title": "BasicGuider" 105 | } 106 | }, 107 | "25": { 108 | "inputs": { 109 | "noise_seed":
100743429905359 110 | }, 111 | "class_type": "RandomNoise", 112 | "_meta": { 113 | "title": "RandomNoise" 114 | } 115 | }, 116 | "27": { 117 | "inputs": { 118 | "images": ["8", 0] 119 | }, 120 | "class_type": "PreviewImage", 121 | "_meta": { 122 | "title": "Preview Image" 123 | } 124 | }, 125 | "28": { 126 | "inputs": { 127 | "mode": "rescale", 128 | "supersample": "true", 129 | "resampling": "lanczos", 130 | "rescale_factor": 2, 131 | "resize_width": 800, 132 | "resize_height": 600, 133 | "image": ["8", 0] 134 | }, 135 | "class_type": "Image Resize", 136 | "_meta": { 137 | "title": "Image Resize" 138 | } 139 | }, 140 | "29": { 141 | "inputs": { 142 | "unet_name": "flux1-dev-Q5_K_S.gguf" 143 | }, 144 | "class_type": "UnetLoaderGGUF", 145 | "_meta": { 146 | "title": "Unet Loader (GGUF)" 147 | } 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /restate/.dockerignore: -------------------------------------------------------------------------------- 1 | .venv -------------------------------------------------------------------------------- /restate/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | .venv/ 5 | restate-data/ 6 | **/restate-data/* 7 | -------------------------------------------------------------------------------- /restate/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim 2 | 3 | RUN apt-get update && apt-get install -y \ 4 | curl \ 5 | build-essential \ 6 | cmake \ 7 | pkg-config \ 8 | && rm -rf /var/lib/apt/lists/* \ 9 | && curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y 10 | 11 | 12 | ENV PATH="/root/.cargo/bin:$PATH" 13 | 14 | 15 | WORKDIR /app 16 | 17 | # Copy all files to app directory 18 | COPY ./ /app 19 | 20 | RUN uv venv 21 | 22 | RUN uv pip install -r
pyproject.toml 23 | 24 | # Add hypercorn to PATH 25 | ENV PATH="/app/.venv/bin:$PATH" 26 | ENV PYTHONUNBUFFERED=1 27 | ENV LOG_LEVEL=INFO 28 | 29 | # Add this to ensure Python logs are sent to stdout/stderr 30 | ENV PYTHONIOENCODING=utf-8 31 | # Run the hypercorn server 32 | CMD ["hypercorn", "-c", "hypercorn-config.toml", "main:app"] 33 | -------------------------------------------------------------------------------- /restate/fly.toml: -------------------------------------------------------------------------------- 1 | # fly.toml app configuration file generated for story-service-muddy-glitter-511 on 2024-11-16T18:15:46+08:00 2 | # 3 | # See https://fly.io/docs/reference/configuration/ for information about how to use this file. 4 | # 5 | 6 | app = 'story-service' 7 | primary_region = 'sin' 8 | 9 | [build] 10 | dockerfile = './Dockerfile' 11 | 12 | [http_service] 13 | internal_port = 9080 14 | force_https = true 15 | auto_stop_machines = false 16 | auto_start_machines = false 17 | min_machines_running = 1 18 | processes = ['app'] 19 | 20 | [[http_service.ports]] 21 | handlers = ["http"] 22 | port = 80 23 | force_https = true 24 | 25 | [[http_service.ports]] 26 | handlers = ["tls", "http"] 27 | port = 443 28 | 29 | [http_service.tls_options] 30 | alpn = ["h2", "http/1.1"] 31 | 32 | [vm] 33 | cpu_kind = "shared" 34 | cpu_cores = 2 35 | memory_mb = 2048 36 | -------------------------------------------------------------------------------- /restate/helpers/db.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from datetime import datetime 3 | from libsql_client import create_client_sync 4 | from helpers.env import Env 5 | from helpers.story import StoryOutline, FinalStoryNode 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class DatabaseClient: 11 | _instance = None 12 | _connection = None 13 | 14 | def __new__(cls): 15 | if cls._instance is None: 16 | cls._instance = super().__new__(cls) 17 |
cls._settings = Env() 18 | return cls._instance 19 | 20 | def get_connection(self): 21 | from contextlib import contextmanager 22 | 23 | @contextmanager 24 | def connection(): 25 | for _ in range(3): 26 | if not self._connection: 27 | self._connection = create_client_sync( 28 | self._settings.DB_URL, 29 | auth_token=self._settings.DB_TOKEN, 30 | ) 31 | 32 | try: 33 | # Test connection 34 | self._connection.execute("SELECT * FROM stories LIMIT 1;") 35 | break 36 | except Exception: 37 | print("Connection expired, recreating connection") 38 | self._connection = None 39 | else: 40 | raise Exception( 41 | "Failed to establish database connection after 3 attempts" 42 | ) 43 | 44 | try: 45 | yield self._connection 46 | finally: 47 | logger.info("Released database connection") # the cached client is kept open for reuse, not closed 48 | 49 | return connection() 50 | 51 | def insert_story(self, story: StoryOutline, user_email: str, story_prompt: str): 52 | import uuid 53 | 54 | story_id = str(uuid.uuid4()) 55 | 56 | with self.get_connection() as conn: 57 | query = """INSERT INTO stories 58 | (id, user_id, title, description, status, timestamp, story_prompt, image_prompt) 59 | VALUES 60 | (?, ?, ?, ?, ?, ?, ?, ?)""" 61 | params = ( 62 | story_id, 63 | user_email, 64 | story.title, 65 | story.description, 66 | "PROCESSING", 67 | int(datetime.now().timestamp()), 68 | story_prompt, 69 | story.banner_image, 70 | ) 71 | print(f"Executing query: {query} with params: {params}") 72 | conn.execute(query, params) 73 | 74 | return story_id 75 | 76 | def mark_story_as_completed(self, story_id: str): 77 | try: 78 | print(f"Marking story {story_id} as completed") 79 | with self.get_connection() as conn: 80 | query = "UPDATE stories SET status = 'GENERATED' WHERE id = ?"
81 | conn.execute(query, (story_id,)) 82 | 83 | except Exception as e: 84 | logger.error(f"Failed to mark story {story_id} as completed: {str(e)}") 85 | raise 86 | 87 | def insert_story_nodes( 88 | self, nodes: list[FinalStoryNode], story_id: str, user_id: str 89 | ): 90 | with self.get_connection() as conn: 91 | query = """INSERT INTO story_choices 92 | (id, user_id, parent_id, story_id, description, choice_title, choice_description, is_terminal, explored, image_prompt) 93 | VALUES 94 | (?, ?, ?, ?, ?, ?, ?, ?, ?,?)""" 95 | 96 | try: 97 | for node in nodes: 98 | params = ( 99 | node.id, 100 | user_id, 101 | "NULL" if node.parent_id is None else node.parent_id, # NOTE(review): stores the literal string "NULL", not SQL NULL; confirm readers expect this 102 | story_id, 103 | node.description, 104 | node.choice_title, 105 | node.choice_description, 106 | 1 if node.is_terminal else 0, 107 | 1 if node.parent_id is None else 0, 108 | node.image_description, 109 | ) 110 | conn.execute(query, params) 111 | 112 | except Exception as e: 113 | logger.error(f"Failed to insert story nodes: {str(e)}") 114 | raise 115 | -------------------------------------------------------------------------------- /restate/helpers/env.py: -------------------------------------------------------------------------------- 1 | from pydantic_settings import BaseSettings 2 | 3 | 4 | class Env(BaseSettings): 5 | GOOGLE_API_KEY: str 6 | DB_URL: str 7 | DB_TOKEN: str 8 | IMAGE_ENDPOINT: str 9 | KOKORO_ENDPOINT: str 10 | AWS_SECRET_ACCESS_KEY: str 11 | AWS_ACCESS_KEY_ID: str 12 | AWS_REGION: str 13 | 14 | class Config: 15 | env_file = ".env" 16 | -------------------------------------------------------------------------------- /restate/helpers/s3.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | from botocore.exceptions import ClientError 3 | from helpers.env import Env 4 | 5 | 6 | def get_story_items(story_id: str) -> dict[str, list[str]]: 7 | """ 8 | Get the image and audio item names stored for a given story ID in the S3 bucket 9 | 10 | Args: 11 | story_id: UUID of
the story 12 | 13 | Returns: 14 | Dict with "images" (.png) and "audio" (.wav) lists of item names, with the story prefix and extension stripped; both lists are empty if the bucket has no objects for the story or cannot be read 15 | """ 16 | try: 17 | s3 = boto3.client( 18 | "s3", 19 | aws_access_key_id=Env().AWS_ACCESS_KEY_ID, 20 | aws_secret_access_key=Env().AWS_SECRET_ACCESS_KEY, 21 | region_name=Env().AWS_REGION, 22 | ) 23 | 24 | response = s3.list_objects_v2(Bucket="restate-story", Prefix=f"{story_id}/") 25 | 26 | if "Contents" not in response: 27 | return { 28 | "images": [], 29 | "audio": [], 30 | } 31 | 32 | return { 33 | "images": [ 34 | obj["Key"].replace(f"{story_id}/", "").replace(".png", "") 35 | for obj in response["Contents"] 36 | if obj["Key"].endswith(".png") 37 | ], 38 | "audio": [ 39 | obj["Key"].replace(f"{story_id}/", "").replace(".wav", "") 40 | for obj in response["Contents"] 41 | if obj["Key"].endswith(".wav") 42 | ], 43 | } 44 | 45 | except ClientError as e: 46 | print(f"Error accessing S3: {e}") 47 | return {"images": [], "audio": []} # same shape as the success path; callers index ["images"]/["audio"] 48 | -------------------------------------------------------------------------------- /restate/hypercorn-config.toml: -------------------------------------------------------------------------------- 1 | bind = "0.0.0.0:9080" 2 | h2_max_concurrent_streams = 2147483647 3 | keep_alive_max_requests = 2147483647 4 | keep_alive_timeout = 2147483647 5 | workers = 8 6 | -------------------------------------------------------------------------------- /restate/main.py: -------------------------------------------------------------------------------- 1 | from helpers.db import DatabaseClient 2 | from helpers.s3 import get_story_items 3 | from helpers.story import ( 4 | generate_images, 5 | generate_story, 6 | StoryOutline, 7 | StoryNodes, 8 | generate_story_choices, 9 | generate_tts, 10 | ) 11 | from datetime import timedelta 12 | from rich import print 13 | import restate 14 | from restate import Workflow, WorkflowContext 15 | from restate.exceptions import TerminalError 16 | from pydantic import BaseModel 17 | from restate.serde import PydanticJsonSerde 18 | import time 19 | 20 |
story_workflow = Workflow("cyoa") 21 | 22 | 23 | class StoryInput(BaseModel): 24 | prompt: str 25 | user_email: str 26 | 27 | 28 | def wrap_async_call(coro_fn, *args, **kwargs): 29 | async def wrapped(): 30 | print(f"Starting {coro_fn.__name__}") 31 | start_time = time.time() 32 | result = await coro_fn(*args, **kwargs) 33 | end_time = time.time() 34 | print(f"Function {coro_fn.__name__} took {end_time - start_time:.2f} seconds") 35 | return result 36 | 37 | return wrapped 38 | 39 | 40 | @story_workflow.main() 41 | async def run(ctx: WorkflowContext, req: StoryInput) -> str: 42 | print(f"Received request: {req}") 43 | db = DatabaseClient() 44 | 45 | # This will take in a story prompt and then generate a story 46 | try: 47 | story: StoryOutline = await ctx.run( 48 | "Generate Story", 49 | lambda: generate_story(req.prompt), 50 | serde=PydanticJsonSerde(StoryOutline), 51 | ) 52 | except TerminalError as e: 53 | print(e) 54 | raise TerminalError("Failed to generate story") 55 | 56 | try: 57 | story_id = await ctx.run( 58 | "Insert Story", 59 | lambda: db.insert_story(story, req.user_email, req.prompt), 60 | ) 61 | except Exception as e: 62 | print(e) 63 | raise TerminalError("Failed to insert story") 64 | 65 | try: 66 | choices: StoryNodes = await ctx.run( 67 | "Generate Story Choices", 68 | wrap_async_call(generate_story_choices, story), 69 | serde=PydanticJsonSerde(StoryNodes), 70 | ) 71 | except TerminalError as e: 72 | print(e) 73 | raise TerminalError("Failed to generate story choices") 74 | 75 | try: 76 | await ctx.run( 77 | "Insert Story Choices", 78 | lambda: db.insert_story_nodes(choices.nodes, story_id, req.user_email), 79 | ) 80 | except Exception as e: 81 | print(e) 82 | raise TerminalError("Failed to insert story choices") 83 | 84 | try: 85 | await ctx.run( 86 | "trigger task", 87 | wrap_async_call( 88 | generate_images, choices.nodes, story_id, story.banner_image 89 | ), 90 | ) 91 | except Exception as e: 92 | print(e) 93 | raise TerminalError("Failed to
generate images") 94 | 95 | try: 96 | await ctx.run( 97 | "trigger task", # NOTE(review): same step name as the image step; a distinct name would read better in the journal, but check Restate replay rules before renaming 98 | wrap_async_call(generate_tts, choices.nodes, story_id, story.description), 99 | ) 100 | except Exception as e: 101 | print(e) 102 | raise TerminalError("Failed to generate audio") 103 | 104 | iterations = 0 105 | 106 | expected_images = set([node.id for node in choices.nodes] + ["banner"]) 107 | expected_audio = set([node.id for node in choices.nodes] + ["theme"]) # NOTE(review): modal's generate_audio test uploads "theme_song.wav"; confirm which key the theme track is stored under, otherwise this poll only ends via the timeout 108 | 109 | while True: 110 | # We poll our S3 bucket and wait to see if all the images are there. 111 | images_and_audio = await ctx.run( 112 | "Get Story Images", lambda: get_story_items(story_id) 113 | ) 114 | 115 | remaining_images = expected_images - set(images_and_audio["images"]) 116 | remaining_audio = expected_audio - set(images_and_audio["audio"]) 117 | 118 | if len(remaining_images) == 0 and len(remaining_audio) == 0: 119 | break 120 | 121 | iterations += 1 122 | print( 123 | f"Iteration {iterations} : {len(remaining_images)} images and {len(remaining_audio)} audio remaining" 124 | ) 125 | 126 | # We wait for at most 10 minutes. If the story is not ready, then we just give up.
127 | if iterations > 10: 128 | break 129 | 130 | await ctx.sleep(delta=timedelta(seconds=60)) 131 | 132 | try: 133 | await ctx.run( 134 | "Mark Story as Completed", 135 | lambda: db.mark_story_as_completed(story_id), 136 | ) 137 | except Exception as e: 138 | print(e) 139 | raise TerminalError("Failed to mark story as completed") 140 | 141 | return "success" 142 | 143 | 144 | app = restate.app( 145 | [story_workflow], 146 | "bidi", 147 | identity_keys=[ 148 | "publickeyv1_GTKUcX5ZHNBG3MX9wk7JGwA6VALTGr5UNYika3kyf63e", 149 | "publickeyv1_wqWhnpRDLYsvBc7d2A9zFhLDfE2iWkukM6ThqZdJ87N", 150 | ], 151 | ) 152 | -------------------------------------------------------------------------------- /restate/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "restate-service" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.12" 7 | dependencies = [ 8 | "boto3>=1.35.97", 9 | "hypercorn>=0.17.3", 10 | "instructor[google-generativeai]>=1.7.2", 11 | "libsql-client>=0.3.1", 12 | "pydantic-settings>=2.7.0", 13 | "restate-sdk>=0.4.1", 14 | ] 15 | 16 | [project.optional-dependencies] 17 | dev = [ 18 | "pytest>=7.4.0", 19 | "pytest-asyncio>=0.21.1", 20 | "pytest-mock>=3.11.1", 21 | ] 22 | -------------------------------------------------------------------------------- /restate/tests/README.md: -------------------------------------------------------------------------------- 1 | # CYOA Unit Tests 2 | 3 | This directory contains unit tests for the CYOA (Choose Your Own Adventure) application.
4 | 5 | ## Running the Tests 6 | 7 | To run the tests, you'll need to install the development dependencies: 8 | 9 | ```bash 10 | cd restate 11 | pip install -e ".[dev]" 12 | ``` 13 | 14 | Then, you can run the tests using pytest: 15 | 16 | ```bash 17 | # Run all tests 18 | pytest 19 | 20 | # Run tests with verbose output 21 | pytest -v 22 | 23 | # Run a specific test file 24 | pytest tests/test_db.py 25 | 26 | # Run a specific test 27 | pytest tests/test_db.py::test_insert_story 28 | ``` 29 | 30 | ## Test Coverage 31 | 32 | The tests cover the following components: 33 | 34 | 1. **Database Client** (`test_db.py`): 35 | - Testing the singleton pattern 36 | - Testing database operations (insert_story, mark_story_as_completed, insert_story_nodes) 37 | - Testing connection retry mechanism 38 | 39 | 2. **S3 Helper** (`test_s3.py`): 40 | - Testing S3 operations (get_story_items) 41 | - Testing error handling 42 | 43 | 3. **Story Generation** (`test_story.py`): 44 | - Testing story outline generation 45 | - Testing story choices generation 46 | - Testing image generation 47 | - Testing text-to-speech generation 48 | 49 | 4. **Main Workflow** (`test_main.py`): 50 | - Testing the main workflow success path 51 | - Testing error handling for various failure scenarios 52 | - Testing timeout handling 53 | 54 | ## Adding New Tests 55 | 56 | When adding new tests, follow these guidelines: 57 | 58 | 1. Create test files with the `test_` prefix 59 | 2. Use descriptive test function names that explain what is being tested 60 | 3. Use fixtures to set up common test dependencies 61 | 4. Mock external dependencies to isolate the code being tested 62 | 5. Test both success and failure scenarios 63 | 64 | ## Test Structure 65 | 66 | Each test file follows a similar structure: 67 | 68 | 1. Import necessary modules and the code being tested 69 | 2. Define fixtures for setting up test dependencies 70 | 3. Define test functions that use the fixtures 71 | 4.
Assert expected outcomes 72 | 73 | ## Continuous Integration 74 | 75 | These tests can be integrated into a CI/CD pipeline to ensure code quality before deployment. -------------------------------------------------------------------------------- /restate/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/essencevc/cyoa/2c991aa4bdbcc8c09f40017c42d9f8c0f87eb2b0/restate/tests/__init__.py -------------------------------------------------------------------------------- /restate/tests/test_db.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import MagicMock, patch 3 | from helpers.db import DatabaseClient 4 | from helpers.story import StoryOutline, FinalStoryNode 5 | 6 | 7 | @pytest.fixture 8 | def mock_env(): 9 | with patch("helpers.db.Env") as mock_env: 10 | mock_env_instance = MagicMock() 11 | mock_env_instance.DB_URL = "test_url" 12 | mock_env_instance.DB_TOKEN = "test_token" 13 | mock_env.return_value = mock_env_instance 14 | yield mock_env 15 | 16 | 17 | @pytest.fixture 18 | def mock_connection(): 19 | with patch("helpers.db.create_client_sync") as mock_create_client: 20 | mock_conn = MagicMock() 21 | mock_create_client.return_value = mock_conn 22 | yield mock_conn 23 | 24 | 25 | @pytest.fixture 26 | def db_client(mock_env, mock_connection): 27 | # Reset the singleton instance 28 | DatabaseClient._instance = None 29 | DatabaseClient._connection = None 30 | return DatabaseClient() 31 | 32 | 33 | def test_singleton_pattern(mock_env): 34 | """Test that DatabaseClient follows the singleton pattern""" 35 | # Reset the singleton instance (mock_env stubs Env so no real DB credentials are needed) 36 | DatabaseClient._instance = None 37 | 38 | client1 = DatabaseClient() 39 | client2 = DatabaseClient() 40 | 41 | assert client1 is client2 42 | 43 | 44 | def test_insert_story(db_client, mock_connection): 45 | """Test inserting a story into the database""" 46 | # Create a mock story 47
| story = StoryOutline( 48 | title="Test Story", 49 | description="This is a test story", 50 | melody="Test melody", 51 | banner_image="Test banner image" 52 | ) 53 | 54 | # Mock uuid.uuid4() to return a predictable value 55 | with patch("uuid.uuid4") as mock_uuid: 56 | mock_uuid.return_value.hex = "test_id" 57 | mock_uuid.return_value.__str__.return_value = "test_id" 58 | 59 | # Call the method 60 | story_id = db_client.insert_story(story, "test@example.com", "Test prompt") 61 | 62 | # Verify the result 63 | assert story_id == "test_id" 64 | 65 | # Verify the database call (get_connection runs a health-check query first, so the INSERT is the last call) 66 | assert mock_connection.execute.call_count == 2 67 | args, kwargs = mock_connection.execute.call_args 68 | 69 | # Check that the query contains the expected values 70 | assert "INSERT INTO stories" in args[0] 71 | assert args[1][0] == "test_id" 72 | assert args[1][1] == "test@example.com" 73 | assert args[1][2] == "Test Story" 74 | assert args[1][3] == "This is a test story" 75 | assert args[1][4] == "PROCESSING" 76 | assert args[1][6] == "Test prompt" 77 | assert args[1][7] == "Test banner image" 78 | 79 | 80 | def test_mark_story_as_completed(db_client, mock_connection): 81 | """Test marking a story as completed""" 82 | # Call the method 83 | db_client.mark_story_as_completed("test_id") 84 | 85 | # Verify the database call (health-check query first, then the UPDATE) 86 | assert mock_connection.execute.call_count == 2 87 | args, kwargs = mock_connection.execute.call_args 88 | 89 | # Check that the query contains the expected values 90 | assert "UPDATE stories SET status = 'GENERATED'" in args[0] 91 | assert args[1][0] == "test_id" 92 | 93 | 94 | def test_insert_story_nodes(db_client, mock_connection): 95 | """Test inserting story nodes into the database""" 96 | # Create mock nodes 97 | nodes = [ 98 | FinalStoryNode( 99 | id="node1", 100 | parent_id=None, 101 | title="Node 1", 102 | description="This is node 1", 103 | image_description="Image for node 1", 104 | choice_title="Choice 1", 105 | choice_description="Description for choice 1",
106 | is_terminal=False 107 | ), 108 | FinalStoryNode( 109 | id="node2", 110 | parent_id="node1", 111 | title="Node 2", 112 | description="This is node 2", 113 | image_description="Image for node 2", 114 | choice_title="Choice 2", 115 | choice_description="Description for choice 2", 116 | is_terminal=True 117 | ) 118 | ] 119 | 120 | # Call the method 121 | db_client.insert_story_nodes(nodes, "test_story_id", "test@example.com") 122 | 123 | # Verify the database calls (call 0 is get_connection's health check, then one INSERT per node) 124 | assert mock_connection.execute.call_count == 3 125 | 126 | # Check first node 127 | args1, kwargs1 = mock_connection.execute.call_args_list[1] 128 | assert "INSERT INTO story_choices" in args1[0] 129 | assert args1[1][0] == "node1" 130 | assert args1[1][1] == "test@example.com" 131 | assert args1[1][2] == "NULL" # parent_id is None 132 | assert args1[1][3] == "test_story_id" 133 | assert args1[1][4] == "This is node 1" 134 | assert args1[1][5] == "Choice 1" 135 | assert args1[1][6] == "Description for choice 1" 136 | assert args1[1][7] == 0 # is_terminal is False 137 | assert args1[1][8] == 1 # explored is 1 for root node 138 | 139 | # Check second node 140 | args2, kwargs2 = mock_connection.execute.call_args_list[2] 141 | assert args2[1][0] == "node2" 142 | assert args2[1][1] == "test@example.com" 143 | assert args2[1][2] == "node1" # parent_id 144 | assert args2[1][3] == "test_story_id" 145 | assert args2[1][4] == "This is node 2" 146 | assert args2[1][5] == "Choice 2" 147 | assert args2[1][6] == "Description for choice 2" 148 | assert args2[1][7] == 1 # is_terminal is True 149 | assert args2[1][8] == 0 # explored is 0 for non-root node 150 | 151 | 152 | def test_get_connection_retry(db_client): 153 | """Test that get_connection recreates the client when the health check fails""" 154 | # The first client fails its health check; the second one is healthy 155 | failing_conn = MagicMock() 156 | failing_conn.execute.side_effect = Exception("Connection failed") 157 | healthy_conn = MagicMock() 158 | 159 | # Reset the connection so get_connection has to create a new client 160 | db_client._connection = None 161 |
162 | with patch("helpers.db.create_client_sync", side_effect=[failing_conn, healthy_conn]) as mock_create_client: 163 | with db_client.get_connection() as conn: 164 | assert conn is healthy_conn 165 | 166 | # Exactly one retry: the failed client was discarded and a new one created 167 | assert mock_create_client.call_count == 2 -------------------------------------------------------------------------------- /restate/tests/test_s3.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import MagicMock, patch 3 | from helpers.s3 import get_story_items 4 | 5 | 6 | @pytest.fixture 7 | def mock_boto3(): 8 | with patch("helpers.s3.boto3") as mock_boto3: 9 | mock_client = MagicMock() 10 | mock_boto3.client.return_value = mock_client 11 | yield mock_client 12 | 13 | 14 | @pytest.fixture 15 | def mock_env(): 16 | with patch("helpers.s3.Env") as mock_env: 17 | mock_env_instance = MagicMock() 18 | mock_env_instance.AWS_ACCESS_KEY_ID = "test_access_key" 19 | mock_env_instance.AWS_SECRET_ACCESS_KEY = "test_secret_key" 20 | mock_env_instance.AWS_REGION = "test_region" 21 | mock_env.return_value = mock_env_instance 22 | yield mock_env 23 | 24 | 25 | def test_get_story_items_with_contents(mock_boto3, mock_env): 26 | """Test getting story items when S3 returns contents""" 27 | # Mock S3 response 28 | mock_boto3.list_objects_v2.return_value = { 29 | "Contents": [ 30 | {"Key": "test_story_id/node1.png"}, 31 | {"Key": "test_story_id/node2.png"}, 32 | {"Key": "test_story_id/banner.png"}, 33 | {"Key": "test_story_id/theme.wav"}, 34 | {"Key": "test_story_id/node1.wav"}, 35 | ] 36 | } 37 | 38 | # Call the function 39 | result = get_story_items("test_story_id") 40 | 41 | # Verify the result 42 | assert "images" in result 43 | assert "audio" in result 44 | assert set(result["images"]) == {"node1", "node2", "banner"} 45 | assert set(result["audio"]) == {"theme", "node1"} 46 | 47 | # Verify the bucket was listed under the story's
prefix 48 | mock_boto3.list_objects_v2.assert_called_once_with( 49 | Bucket="restate-story", Prefix="test_story_id/" 50 | ) 51 | 52 | 53 | def test_get_story_items_no_contents(mock_boto3, mock_env): 54 | """Test getting story items when S3 returns no contents""" 55 | # Mock S3 response with no Contents 56 | mock_boto3.list_objects_v2.return_value = {} 57 | 58 | # Call the function 59 | result = get_story_items("test_story_id") 60 | 61 | # Verify the result 62 | assert "images" in result 63 | assert "audio" in result 64 | assert result["images"] == [] 65 | assert result["audio"] == [] 66 | 67 | 68 | def test_get_story_items_client_error(mock_boto3, mock_env): 69 | """Test getting story items when S3 client raises an error""" 70 | # Mock S3 client to raise an error 71 | from botocore.exceptions import ClientError 72 | mock_boto3.list_objects_v2.side_effect = ClientError( 73 | {"Error": {"Code": "NoSuchBucket", "Message": "The bucket does not exist"}}, 74 | "ListObjectsV2" 75 | ) 76 | 77 | # Call the function 78 | result = get_story_items("test_story_id") 79 | 80 | # Verify the result is a dictionary with empty lists for 'images' and 'audio' 81 | assert "images" in result 82 | assert "audio" in result 83 | assert result["images"] == [] 84 | assert result["audio"] == [] -------------------------------------------------------------------------------- 