├── .editorconfig ├── bun.lockb ├── TODO.md ├── src ├── chats │ ├── dto │ │ ├── create-chat.dto.ts │ │ ├── update-chat.dto.ts │ │ ├── get-chat-dto.ts │ │ └── add-message.dto.ts │ ├── chats.module.ts │ ├── schemas │ │ └── chat.schema.ts │ ├── chats.service.ts │ └── chats.resolver.ts ├── app.service.ts ├── auth │ ├── guards │ │ ├── jwt-auth.guard.ts │ │ └── gql-auth.guard.ts │ ├── interfaces │ │ ├── jwt-payload.interface.ts │ │ └── device-info.interface.ts │ ├── dto │ │ ├── register.dto.ts │ │ ├── login.dto.ts │ │ └── change-password.dto.ts │ ├── decorators │ │ └── current-user.decorator.ts │ ├── strategies │ │ └── jwt.strategy.ts │ ├── auth.module.ts │ ├── auth.resolver.ts │ └── auth.service.ts ├── encryption │ ├── encryption.module.ts │ └── encryption.service.ts ├── ai │ ├── ai.module.ts │ ├── clients │ │ ├── openai.client.ts │ │ ├── openrouter.client.ts │ │ ├── anthropic.client.ts │ │ ├── google.client.ts │ │ └── base │ │ │ └── base-openai-api.client.ts │ ├── ai.service.ts │ └── interfaces │ │ └── ai-provider.interface.ts ├── app.controller.ts ├── storage │ ├── dto │ │ ├── create-file.dto.ts │ │ └── complete-file.dto.ts │ ├── storage.module.ts │ ├── schemas │ │ └── file.schema.ts │ ├── storage.resolver.ts │ ├── r2WorkerClient.ts │ └── storage.service.ts ├── messages │ ├── dto │ │ ├── update-message.dto.ts │ │ └── get-messages.dto.ts │ ├── messages.module.ts │ ├── messages.resolver.ts │ ├── schemas │ │ └── message.schema.ts │ └── messages.service.ts ├── users │ ├── dto │ │ └── update-user.dto.ts │ ├── users.module.ts │ ├── users.resolver.ts │ ├── schemas │ │ └── user.schema.ts │ └── users.service.ts ├── sessions │ ├── sessions.module.ts │ ├── schemas │ │ └── session.schema.ts │ └── sessions.service.ts ├── keys │ ├── dto │ │ ├── update-api-key.dto.ts │ │ └── create-api-key.dto.ts │ ├── api-key.module.ts │ ├── schemas │ │ └── api-key.schema.ts │ ├── api-key.resolver.ts │ └── api-key.service.ts ├── branches │ ├── dto │ │ ├── fork-branch.dto.ts │ │ ├── 
create-branch.dto.ts │ │ └── update-branch.dto.ts │ ├── branches.module.ts │ ├── branches.resolver.ts │ ├── schemas │ │ └── chat-branch.schema.ts │ └── branches.service.ts ├── app.controller.spec.ts ├── preferences │ ├── preferences.module.ts │ ├── dto │ │ └── update-preferences.schema.ts │ ├── schema │ │ └── user-preference.schema.ts │ ├── preferences.resolver.ts │ └── preferences.service.ts ├── websockets │ ├── websockets.module.ts │ ├── decorators │ │ └── rate-limit.decorator.ts │ ├── websockets.gateway.ts │ └── websockets.service.ts ├── main.ts ├── app.module.ts └── schema.gql ├── tsconfig.build.json ├── nest-cli.json ├── test ├── jest-e2e.json └── app.e2e-spec.ts ├── .prettierrc ├── .env.sample ├── .vscode └── settings.json ├── tsconfig.json ├── .gitignore ├── eslint.config.mjs ├── package.json └── README.md /.editorconfig: -------------------------------------------------------------------------------- 1 | # Top-most EditorConfig file. Properties only take effect inside a [section], so they live under [*]. 2 | root = true 3 | 4 | [*] 5 | end_of_line = lf 6 | indent_size = 4 7 | indent_style = space 8 | -------------------------------------------------------------------------------- /bun.lockb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unnamed-open-ai-chat/backend/HEAD/bun.lockb -------------------------------------------------------------------------------- /TODO.md: -------------------------------------------------------------------------------- 1 | # To do - List 2 | 3 | ## Websockets Module 4 | 5 | - [ ] Add security (Check branch/chat owner) 6 | -------------------------------------------------------------------------------- /src/chats/dto/create-chat.dto.ts: -------------------------------------------------------------------------------- 1 | import { InputType } from '@nestjs/graphql'; 2 | 3 | @InputType() 4 | export class CreateChatDto {} 5 | -------------------------------------------------------------------------------- /tsconfig.build.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"extends": "./tsconfig.json", 3 | "exclude": ["node_modules", "test", "dist", "**/*spec.ts"] 4 | } 5 | -------------------------------------------------------------------------------- /src/app.service.ts: -------------------------------------------------------------------------------- 1 | import { Injectable } from '@nestjs/common'; 2 | 3 | @Injectable() 4 | export class AppService { 5 | getHello(): string { 6 | return 'Hello World!'; 7 | } 8 | } 9 | -------------------------------------------------------------------------------- /src/auth/guards/jwt-auth.guard.ts: -------------------------------------------------------------------------------- 1 | import { Injectable } from '@nestjs/common'; 2 | import { AuthGuard } from '@nestjs/passport'; 3 | 4 | @Injectable() 5 | export class JwtAuthGuard extends AuthGuard('jwt') {} 6 | -------------------------------------------------------------------------------- /nest-cli.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://json.schemastore.org/nest-cli", 3 | "collection": "@nestjs/schematics", 4 | "sourceRoot": "src", 5 | "compilerOptions": { 6 | "deleteOutDir": true 7 | } 8 | } 9 | -------------------------------------------------------------------------------- /test/jest-e2e.json: -------------------------------------------------------------------------------- 1 | { 2 | "moduleFileExtensions": ["js", "json", "ts"], 3 | "rootDir": ".", 4 | "testEnvironment": "node", 5 | "testRegex": ".e2e-spec.ts$", 6 | "transform": { 7 | "^.+\\.(t|j)s$": "ts-jest" 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /.prettierrc: -------------------------------------------------------------------------------- 1 | { 2 | "semi": true, 3 | "singleQuote": true, 4 | "tabWidth": 4, 5 | "useTabs": false, 6 | "printWidth": 100, 7 | "trailingComma": "es5", 8 | "bracketSpacing": true, 9 | "arrowParens": "avoid", 10 | "endOfLine": "lf" 11 |
} 12 | -------------------------------------------------------------------------------- /src/encryption/encryption.module.ts: -------------------------------------------------------------------------------- 1 | import { Module } from '@nestjs/common'; 2 | 3 | import { EncryptionService } from './encryption.service'; 4 | 5 | @Module({ 6 | providers: [EncryptionService], 7 | exports: [EncryptionService], 8 | }) 9 | export class EncryptionModule {} 10 | -------------------------------------------------------------------------------- /src/ai/ai.module.ts: -------------------------------------------------------------------------------- 1 | import { Module } from '@nestjs/common'; 2 | 3 | import { StorageModule } from '@/storage/storage.module'; 4 | import { AIService } from './ai.service'; 5 | 6 | @Module({ 7 | imports: [StorageModule], 8 | providers: [AIService], 9 | exports: [AIService], 10 | }) 11 | export class AIModule {} 12 | -------------------------------------------------------------------------------- /src/app.controller.ts: -------------------------------------------------------------------------------- 1 | import { Controller, Get } from '@nestjs/common'; 2 | import { AppService } from './app.service'; 3 | 4 | @Controller() 5 | export class AppController { 6 | constructor(private readonly appService: AppService) {} 7 | 8 | @Get() 9 | getHello(): string { 10 | return this.appService.getHello(); 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /src/auth/interfaces/jwt-payload.interface.ts: -------------------------------------------------------------------------------- 1 | export interface AccessJwtPayload { 2 | sub: string; 3 | email: string; 4 | sessionId: string; 5 | iat?: number; 6 | exp?: number; 7 | } 8 | 9 | export interface RefreshJwtPayload { 10 | sub: string; 11 | sessionId: string; 12 | iat?: number; 13 | exp?: number; 14 | } 15 |
-------------------------------------------------------------------------------- /src/auth/interfaces/device-info.interface.ts: -------------------------------------------------------------------------------- 1 | import { Field, ObjectType } from '@nestjs/graphql'; 2 | 3 | @ObjectType() 4 | export class DeviceInfo { 5 | @Field() 6 | userAgent: string; 7 | 8 | @Field() 9 | ip: string; 10 | 11 | @Field({ nullable: true }) 12 | platform?: string; 13 | 14 | @Field({ nullable: true }) 15 | browser?: string; 16 | } 17 | -------------------------------------------------------------------------------- /src/storage/dto/create-file.dto.ts: -------------------------------------------------------------------------------- 1 | import { Field, InputType } from '@nestjs/graphql'; 2 | import { IsNumber, IsString } from 'class-validator'; 3 | 4 | @InputType() 5 | export class CreateFileDto { 6 | @Field() 7 | @IsString() 8 | filename: string; 9 | 10 | @Field() 11 | @IsString() 12 | mimetype: string; 13 | 14 | @Field() 15 | @IsNumber() 16 | size: number; 17 | } 18 | -------------------------------------------------------------------------------- /src/messages/dto/update-message.dto.ts: -------------------------------------------------------------------------------- 1 | import { Field, InputType } from '@nestjs/graphql'; 2 | import { IsString, MaxLength } from 'class-validator'; 3 | 4 | @InputType() 5 | export class UpdateMessageDto { 6 | @Field() 7 | @IsString() 8 | messageId: string; 9 | 10 | @Field() 11 | @IsString() 12 | @MaxLength(50000, { message: 'Message content must be at most 50000 characters' }) 13 | content: string; 14 | } 15 | -------------------------------------------------------------------------------- /src/ai/clients/openai.client.ts: -------------------------------------------------------------------------------- 1 | import { StorageService } from '@/storage/storage.service'; 2 | import { AIProviderId } from '../interfaces/ai-provider.interface'; 3 | import { 
BaseOpenAIApiClient } from './base/base-openai-api.client'; 4 | 5 | export class OpenAIClient extends BaseOpenAIApiClient { 6 | constructor(protected readonly storageService: StorageService) { 7 | super(AIProviderId.openai, storageService, 'https://api.openai.com/v1'); 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /src/ai/clients/openrouter.client.ts: -------------------------------------------------------------------------------- 1 | import { StorageService } from '@/storage/storage.service'; 2 | import { AIProviderId } from '../interfaces/ai-provider.interface'; 3 | import { BaseOpenAIApiClient } from './base/base-openai-api.client'; 4 | 5 | export class OpenRouterClient extends BaseOpenAIApiClient { 6 | constructor(protected readonly storageService: StorageService) { 7 | super(AIProviderId.openrouter, storageService, 'https://openrouter.ai/api/v1'); 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /src/auth/guards/gql-auth.guard.ts: -------------------------------------------------------------------------------- 1 | import { ExecutionContext, Injectable } from '@nestjs/common'; 2 | import { GqlExecutionContext } from '@nestjs/graphql'; 3 | import { AuthGuard } from '@nestjs/passport'; 4 | 5 | @Injectable() 6 | export class GqlAuthGuard extends AuthGuard('jwt') { 7 | getRequest(httpCtx: ExecutionContext): any { 8 | const gqlCtx = GqlExecutionContext.create(httpCtx); 9 | const req = gqlCtx.getContext().req; 10 | return req; 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /src/auth/dto/register.dto.ts: -------------------------------------------------------------------------------- 1 | import { Field, InputType } from '@nestjs/graphql'; 2 | import { IsString, MaxLength, MinLength } from 'class-validator'; 3 | 4 | import { LoginDto } from './login.dto'; 5 | 6 | @InputType() 7 | export class RegisterDto extends 
LoginDto { 8 | @Field() 9 | @IsString() 10 | @MinLength(3, { message: 'Display name must be at least 3 characters' }) 11 | @MaxLength(50, { message: 'Display name must be at most 50 characters' }) 12 | displayName: string; 13 | } 14 | -------------------------------------------------------------------------------- /src/auth/decorators/current-user.decorator.ts: -------------------------------------------------------------------------------- 1 | import { createParamDecorator, ExecutionContext } from '@nestjs/common'; 2 | import { GqlExecutionContext } from '@nestjs/graphql'; 3 | 4 | export const CurrentUser = createParamDecorator((_data: unknown, context: ExecutionContext) => { 5 | if (context.getType() === 'http') { 6 | return context.switchToHttp().getRequest().user; 7 | } else { 8 | const ctx = GqlExecutionContext.create(context); 9 | return ctx.getContext().req.user; 10 | } 11 | }); 12 | -------------------------------------------------------------------------------- /src/storage/dto/complete-file.dto.ts: -------------------------------------------------------------------------------- 1 | import { Field, InputType } from '@nestjs/graphql'; 2 | import { IsNumber, IsString } from 'class-validator'; 3 | 4 | @InputType() 5 | export class FilePart { 6 | @Field() 7 | @IsString() 8 | etag: string; 9 | 10 | @Field() 11 | @IsNumber() 12 | partNumber: number; 13 | } 14 | 15 | @InputType() 16 | export class CompleteFileDto { 17 | @Field() 18 | @IsString() 19 | fileId: string; 20 | 21 | @Field(() => [FilePart]) 22 | parts: FilePart[]; 23 | } 24 | -------------------------------------------------------------------------------- /src/storage/storage.module.ts: -------------------------------------------------------------------------------- 1 | import { Module } from '@nestjs/common'; 2 | import { MongooseModule } from '@nestjs/mongoose'; 3 | 4 | import { FileSchema } from './schemas/file.schema'; 5 | import { StorageResolver } from './storage.resolver'; 6 | import { 
StorageService } from './storage.service'; 7 | 8 | @Module({ 9 | imports: [MongooseModule.forFeature([{ name: File.name, schema: FileSchema }])], 10 | providers: [StorageService, StorageResolver], 11 | exports: [StorageService], 12 | }) 13 | export class StorageModule {} 14 | -------------------------------------------------------------------------------- /.env.sample: -------------------------------------------------------------------------------- 1 | # App 2 | NODE_ENV=development 3 | PORT=8000 4 | 5 | APP_NAME=uoachat 6 | APP_URL=https://sammwy.com/ 7 | 8 | # Database 9 | MONGODB_URI=mongodb://localhost/uoachat 10 | 11 | # JWT (Sessions) 12 | JWT_EXPIRATION = '30d' 13 | JWT_REFRESH_EXPIRATION = '30d' 14 | JWT_SECRET = 15 | 16 | # Cipher (Encryption and Decryption) 17 | ENCRYPTION_IV=1234567890abcdef1234567890abcdef 18 | 19 | # Worker (File upload) 20 | R2_WORKER_URL = http://localhost:8787 21 | R2_WORKER_SECRET = your-super-secret-jwt-key-change-this 22 | 23 | # CORS 24 | CORS_ORIGIN = "*" -------------------------------------------------------------------------------- /src/users/dto/update-user.dto.ts: -------------------------------------------------------------------------------- 1 | import { Field, InputType } from '@nestjs/graphql'; 2 | import { IsEmail, IsString, MaxLength, MinLength } from 'class-validator'; 3 | 4 | @InputType() 5 | export class UpdateUserDto { 6 | @Field({ nullable: true }) 7 | @IsEmail({}, { message: 'Email is not valid' }) 8 | email?: string; 9 | 10 | @Field({ nullable: true }) 11 | @IsString() 12 | @MinLength(3, { message: 'Display name must be at least 3 characters' }) 13 | @MaxLength(50, { message: 'Display name must be at most 50 characters' }) 14 | displayName?: string; 15 | } 16 | -------------------------------------------------------------------------------- /src/auth/dto/login.dto.ts: -------------------------------------------------------------------------------- 1 | import { Field, InputType } from '@nestjs/graphql'; 2 | 
import { IsEmail, IsString, Matches, MinLength } from 'class-validator'; 3 | 4 | @InputType() 5 | export class LoginDto { 6 | @Field() 7 | @IsEmail({}, { message: 'Email is not valid' }) 8 | email: string; 9 | 10 | @Field() 11 | @IsString() 12 | @MinLength(8, { message: 'Password must be at least 8 characters' }) 13 | @Matches(/^(?=.*[A-Za-z])(?=.*\d)[A-Za-z\d]{8,}$/, { 14 | message: 'Password must contain at least one letter and one number', 15 | }) 16 | password: string; 17 | } 18 | -------------------------------------------------------------------------------- /src/sessions/sessions.module.ts: -------------------------------------------------------------------------------- 1 | import { Module } from '@nestjs/common'; 2 | import { MongooseModule } from '@nestjs/mongoose'; 3 | 4 | import { Session, SessionSchema } from './schemas/session.schema'; 5 | import { SessionsService } from './sessions.service'; 6 | 7 | @Module({ 8 | imports: [ 9 | MongooseModule.forFeature([ 10 | { 11 | name: Session.name, 12 | schema: SessionSchema, 13 | }, 14 | ]), 15 | ], 16 | providers: [SessionsService], 17 | exports: [SessionsService], 18 | }) 19 | export class SessionsModule {} 20 | -------------------------------------------------------------------------------- /src/keys/dto/update-api-key.dto.ts: -------------------------------------------------------------------------------- 1 | import { Field, InputType } from '@nestjs/graphql'; 2 | import { IsBoolean, IsOptional, IsString, MaxLength, MinLength } from 'class-validator'; 3 | 4 | @InputType() 5 | export class UpdateApiKeyDto { 6 | @Field({ nullable: true }) 7 | @IsOptional() 8 | @IsString() 9 | @MinLength(3, { message: 'Alias must be at least 3 characters long' }) 10 | @MaxLength(50, { message: 'Alias must not exceed 50 characters' }) 11 | alias?: string; 12 | 13 | @Field({ nullable: true }) 14 | @IsOptional() 15 | @IsBoolean() 16 | isActive?: boolean; 17 | } 18 | 
-------------------------------------------------------------------------------- /src/branches/dto/fork-branch.dto.ts: -------------------------------------------------------------------------------- 1 | import { Field, InputType } from '@nestjs/graphql'; 2 | import { IsBoolean, IsOptional, IsString, MaxLength, MinLength } from 'class-validator'; 3 | 4 | @InputType() 5 | export class ForkBranchDto { 6 | @Field({ nullable: true }) 7 | @IsOptional() 8 | @IsString() 9 | @MinLength(1, { message: 'Branch name must be at least 1 character' }) 10 | @MaxLength(50, { message: 'Branch name must be at most 50 characters' }) 11 | name?: string; 12 | 13 | @Field({ nullable: true }) 14 | @IsOptional() 15 | @IsBoolean() 16 | cloneMessages?: boolean; 17 | } 18 | -------------------------------------------------------------------------------- /src/messages/dto/get-messages.dto.ts: -------------------------------------------------------------------------------- 1 | import { Field, InputType } from '@nestjs/graphql'; 2 | import { IsNumber, IsOptional, IsString, Min } from 'class-validator'; 3 | 4 | @InputType() 5 | export class GetMessagesDto { 6 | @Field() 7 | @IsString() 8 | branchId: string; 9 | 10 | @Field({ nullable: true }) 11 | @IsOptional() 12 | @IsNumber() 13 | @Min(1) 14 | limit?: number = 50; 15 | 16 | @Field({ nullable: true }) 17 | @IsOptional() 18 | @IsNumber() 19 | @Min(0) 20 | offset?: number = 0; 21 | 22 | @Field({ nullable: true }) 23 | @IsOptional() 24 | @IsNumber() 25 | @Min(0) 26 | fromIndex?: number = 0; 27 | } 28 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.formatOnSave": true, 3 | "editor.defaultFormatter": "esbenp.prettier-vscode", 4 | "editor.codeActionsOnSave": { 5 | "source.fixAll.eslint": "explicit" 6 | }, 7 | "files.eol": "\n", 8 | "prettier.requireConfig": true, 9 | "prettier.useEditorConfig": 
false, 10 | "editor.tabSize": 4, 11 | "editor.insertSpaces": true, 12 | "[javascript]": { 13 | "editor.defaultFormatter": "esbenp.prettier-vscode" 14 | }, 15 | "[typescript]": { 16 | "editor.defaultFormatter": "esbenp.prettier-vscode" 17 | }, 18 | "[json]": { 19 | "editor.defaultFormatter": "esbenp.prettier-vscode" 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "module": "commonjs", 4 | "declaration": true, 5 | "removeComments": true, 6 | "emitDecoratorMetadata": true, 7 | "experimentalDecorators": true, 8 | "allowSyntheticDefaultImports": true, 9 | "target": "ES2023", 10 | "sourceMap": true, 11 | "outDir": "./dist", 12 | "baseUrl": "./", 13 | "incremental": true, 14 | "skipLibCheck": true, 15 | "strictNullChecks": true, 16 | "forceConsistentCasingInFileNames": true, 17 | "noImplicitAny": false, 18 | "strictBindCallApply": false, 19 | "noFallthroughCasesInSwitch": false, 20 | 21 | /* Path Aliases */ 22 | "paths": { 23 | "@/*": ["./src/*"], 24 | } 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/app.controller.spec.ts: -------------------------------------------------------------------------------- 1 | import { Test, TestingModule } from '@nestjs/testing'; 2 | import { AppController } from './app.controller'; 3 | import { AppService } from './app.service'; 4 | 5 | describe('AppController', () => { 6 | let appController: AppController; 7 | 8 | beforeEach(async () => { 9 | const app: TestingModule = await Test.createTestingModule({ 10 | controllers: [AppController], 11 | providers: [AppService], 12 | }).compile(); 13 | 14 | appController = app.get(AppController); 15 | }); 16 | 17 | describe('root', () => { 18 | it('should return "Hello World!"', () => { 19 | expect(appController.getHello()).toBe('Hello World!'); 20 | }); 21 | }); 
22 | }); 23 | -------------------------------------------------------------------------------- /test/app.e2e-spec.ts: -------------------------------------------------------------------------------- 1 | import { Test, TestingModule } from '@nestjs/testing'; 2 | import { INestApplication } from '@nestjs/common'; 3 | import * as request from 'supertest'; 4 | import { App } from 'supertest/types'; 5 | import { AppModule } from './../src/app.module'; 6 | 7 | describe('AppController (e2e)', () => { 8 | let app: INestApplication; 9 | 10 | beforeEach(async () => { 11 | const moduleFixture: TestingModule = await Test.createTestingModule({ 12 | imports: [AppModule], 13 | }).compile(); 14 | 15 | app = moduleFixture.createNestApplication(); 16 | await app.init(); 17 | }); 18 | 19 | it('/ (GET)', () => { 20 | return request(app.getHttpServer()).get('/').expect(200).expect('Hello World!'); 21 | }); 22 | }); 23 | -------------------------------------------------------------------------------- /src/keys/dto/create-api-key.dto.ts: -------------------------------------------------------------------------------- 1 | import { Field, InputType } from '@nestjs/graphql'; 2 | import { IsEnum, IsString, MaxLength, MinLength } from 'class-validator'; 3 | 4 | import { AIProviderId } from '@/ai/interfaces/ai-provider.interface'; 5 | 6 | @InputType() 7 | export class CreateApiKeyDto { 8 | @Field(() => AIProviderId) 9 | @IsEnum(AIProviderId, { message: 'Invalid provider' }) 10 | provider: AIProviderId; 11 | 12 | @Field() 13 | @IsString() 14 | @MinLength(3, { message: 'Alias must be at least 3 characters long' }) 15 | @MaxLength(50, { message: 'Alias must not exceed 50 characters' }) 16 | alias: string; 17 | 18 | @Field() 19 | @IsString() 20 | @MinLength(10, { message: 'API key must be at least 10 characters long' }) 21 | apiKey: string; 22 | } 23 | -------------------------------------------------------------------------------- /src/keys/api-key.module.ts: 
-------------------------------------------------------------------------------- 1 | import { Module } from '@nestjs/common'; 2 | import { MongooseModule } from '@nestjs/mongoose'; 3 | 4 | import { AIModule } from '@/ai/ai.module'; 5 | import { EncryptionModule } from '@/encryption/encryption.module'; 6 | import { UsersModule } from '@/users/users.module'; 7 | import { ApiKeyResolver } from './api-key.resolver'; 8 | import { ApiKeysService } from './api-key.service'; 9 | import { ApiKey, ApiKeySchema } from './schemas/api-key.schema'; 10 | 11 | @Module({ 12 | imports: [ 13 | MongooseModule.forFeature([{ name: ApiKey.name, schema: ApiKeySchema }]), 14 | EncryptionModule, 15 | UsersModule, 16 | AIModule, 17 | ], 18 | providers: [ApiKeysService, ApiKeyResolver], 19 | exports: [ApiKeysService], 20 | }) 21 | export class ApiKeysModule {} 22 | -------------------------------------------------------------------------------- /src/auth/dto/change-password.dto.ts: -------------------------------------------------------------------------------- 1 | import { Field, InputType } from '@nestjs/graphql'; 2 | import { IsString, Matches, MinLength } from 'class-validator'; 3 | 4 | @InputType() 5 | export class ChangePasswordDto { 6 | @Field() 7 | @IsString() 8 | @MinLength(8, { message: 'Old password must be at least 8 characters' }) 9 | @Matches(/^(?=.*[A-Za-z])(?=.*\d)[A-Za-z\d]{8,}$/, { 10 | message: 'Old password must contain at least one letter and one number', 11 | }) 12 | oldPassword: string; 13 | 14 | @Field() 15 | @IsString() 16 | @MinLength(8, { message: 'New password must be at least 8 characters' }) 17 | @Matches(/^(?=.*[A-Za-z])(?=.*\d)[A-Za-z\d]{8,}$/, { 18 | message: 'New password must contain at least one letter and one number', 19 | }) 20 | newPassword: string; 21 | } 22 | -------------------------------------------------------------------------------- /src/chats/dto/update-chat.dto.ts: -------------------------------------------------------------------------------- 
1 | import { Field, InputType } from '@nestjs/graphql'; 2 | import { IsBoolean, IsOptional, IsString, MaxLength, MinLength } from 'class-validator'; 3 | 4 | @InputType() 5 | export class UpdateChatDto { 6 | @Field({ nullable: true }) 7 | @IsOptional() 8 | @IsString() 9 | @MinLength(1, { message: 'Chat name must be at least 1 character' }) 10 | @MaxLength(100, { message: 'Chat name must be at most 100 characters' }) 11 | title?: string; 12 | 13 | @Field({ nullable: true }) 14 | @IsOptional() 15 | @IsBoolean() 16 | isPublic?: boolean; 17 | 18 | @Field({ nullable: true }) 19 | @IsOptional() 20 | @IsBoolean() 21 | archived?: boolean; 22 | 23 | @Field({ nullable: true }) 24 | @IsOptional() 25 | @IsBoolean() 26 | pinned?: boolean; 27 | } 28 | -------------------------------------------------------------------------------- /src/preferences/preferences.module.ts: -------------------------------------------------------------------------------- 1 | import { Module } from '@nestjs/common'; 2 | import { MongooseModule } from '@nestjs/mongoose'; 3 | 4 | import { WebsocketsModule } from '@/websockets/websockets.module'; 5 | import { PreferencesResolver } from './preferences.resolver'; 6 | import { PreferencesService } from './preferences.service'; 7 | import { PreferencesSchema, UserPreferences } from './schema/user-preference.schema'; 8 | 9 | @Module({ 10 | imports: [ 11 | MongooseModule.forFeature([ 12 | { 13 | name: UserPreferences.name, 14 | schema: PreferencesSchema, 15 | }, 16 | ]), 17 | WebsocketsModule, 18 | ], 19 | providers: [PreferencesService, PreferencesResolver], 20 | exports: [PreferencesService], 21 | }) 22 | export class PreferencesModule {} 23 | -------------------------------------------------------------------------------- /src/users/users.module.ts: -------------------------------------------------------------------------------- 1 | import { Module } from '@nestjs/common'; 2 | import { MongooseModule } from '@nestjs/mongoose'; 3 | 4 | import { 
PreferencesModule } from '@/preferences/preferences.module';
import { WebsocketsModule } from '@/websockets/websockets.module';
import { User, UserSchema } from './schemas/user.schema';
import { UsersResolver } from './users.resolver';
import { UsersService } from './users.service';

/**
 * Registers the `User` Mongoose model, wires the users resolver/service and
 * exports `UsersService` for other modules.
 */
@Module({
    imports: [
        MongooseModule.forFeature([
            {
                name: User.name,
                schema: UserSchema,
            },
        ]),
        PreferencesModule,
        WebsocketsModule,
    ],
    providers: [UsersService, UsersResolver],
    exports: [UsersService],
})
export class UsersModule {}

--------------------------------------------------------------------------------
/src/chats/dto/get-chat-dto.ts:
--------------------------------------------------------------------------------
import { Field, InputType } from '@nestjs/graphql';
import { IsBoolean, IsInt, IsOptional, IsString, Max, Min } from 'class-validator';

/**
 * Pagination and filter options for listing chats.
 */
@InputType()
export class GetManyChatsDto {
    /** Page size: a whole number between 1 and 100 (default 20). */
    @Field({ nullable: true })
    @IsOptional()
    // Whole numbers only: a fractional page size is meaningless for a query limit.
    @IsInt()
    @Min(1)
    @Max(100)
    limit?: number = 20;

    /** Pagination offset: a whole number >= 0 (default 0). */
    @Field({ nullable: true })
    @IsOptional()
    @IsInt()
    @Min(0)
    offset?: number = 0;

    /** Archived filter flag (default `false`). */
    @Field({ nullable: true })
    @IsOptional()
    @IsBoolean()
    archived?: boolean = false;

    /** Optional search string; matching semantics are defined by the service. */
    @Field({ nullable: true })
    @IsOptional()
    @IsString()
    search?: string;
}

/**
 * Identifies a single chat by id.
 */
@InputType()
export class GetChatDto {
    @Field()
    @IsString()
    chatId: string;
}

--------------------------------------------------------------------------------
/src/branches/dto/create-branch.dto.ts:
--------------------------------------------------------------------------------
import { Field, InputType } from '@nestjs/graphql';
import { IsInt, IsOptional, IsString, MaxLength, MinLength } from 'class-validator';
import { UpdateBranchModelConfigDto } from './update-branch.dto';

/**
 * Input for creating a chat branch.
 */
@InputType()
export class CreateBranchDto {
    /** Branch name, 1-50 characters. */
    @Field()
    @IsString()
    @MinLength(1, { message: 'Branch name must be at least 1 character' })
    @MaxLength(50, { message: 'Branch name must be at most 50 characters' })
    name: string;

    /** Optional id of the branch this one derives from. */
    @Field({ nullable: true })
    @IsOptional()
    @IsString()
    parentBranchId?: string;

    // Whole numbers only — assumes branchPoint is a message position/count
    // (ChatBranch stores it as a number defaulting to 0); TODO confirm.
    @Field({ nullable: true })
    @IsOptional()
    @IsInt()
    branchPoint?: number;

    // NOTE(review): without @ValidateNested() (plus a class-transformer @Type()),
    // class-validator does not run the validators declared on
    // UpdateBranchModelConfigDto for this nested object.
    @Field({ nullable: true })
    @IsOptional()
    modelConfig?: UpdateBranchModelConfigDto;
}

--------------------------------------------------------------------------------
/src/messages/messages.module.ts:
--------------------------------------------------------------------------------
import { Module } from '@nestjs/common';
import { MongooseModule } from '@nestjs/mongoose';

import { ChatBranch, ChatBranchSchema } from '@/branches/schemas/chat-branch.schema';
import { WebsocketsModule } from '@/websockets/websockets.module';
import { MessagesResolver } from './messages.resolver';
import { MessagesService } from './messages.service';
import { Message, MessageSchema } from './schemas/message.schema';

/**
 * Registers the ChatBranch and Message models and exports `MessagesService`.
 */
@Module({
    imports: [
        MongooseModule.forFeature([
            { name: ChatBranch.name, schema: ChatBranchSchema },
            { name: Message.name, schema: MessageSchema },
        ]),
        WebsocketsModule,
    ],
    providers: [MessagesService, MessagesResolver],
    exports: [MessagesService],
})
export class MessagesModule {}

--------------------------------------------------------------------------------
/src/auth/strategies/jwt.strategy.ts:
--------------------------------------------------------------------------------
import { Injectable } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { PassportStrategy } from '@nestjs/passport';
import { ExtractJwt, Strategy } from 'passport-jwt';
import {
AccessJwtPayload } from '../interfaces/jwt-payload.interface';

/**
 * Reads the JWT signing secret and fails fast with a descriptive error when it
 * is missing, instead of handing `undefined` to passport-jwt.
 */
function getJwtSecret(configService: ConfigService): string {
    const secret = configService.get<string>('JWT_SECRET');
    if (!secret) {
        throw new Error('JWT_SECRET is not configured');
    }
    return secret;
}

/**
 * Passport strategy that authenticates requests carrying a bearer token in the
 * Authorization header, verified with `JWT_SECRET`. Expired tokens are rejected.
 */
@Injectable()
export class JwtStrategy extends PassportStrategy(Strategy) {
    constructor(configService: ConfigService) {
        super({
            jwtFromRequest: ExtractJwt.fromAuthHeaderAsBearerToken(),
            ignoreExpiration: false,
            secretOrKey: getJwtSecret(configService),
        });
    }

    /**
     * Maps a verified token payload to the authenticated user object.
     * Only `sub`, `email` and `sessionId` are copied from the payload.
     */
    validate(payload: AccessJwtPayload): Pick<AccessJwtPayload, 'sub' | 'email' | 'sessionId'> {
        return {
            sub: payload.sub,
            email: payload.email,
            sessionId: payload.sessionId,
        };
    }
}

--------------------------------------------------------------------------------
/src/websockets/websockets.module.ts:
--------------------------------------------------------------------------------
import { Module } from '@nestjs/common';
import { ConfigModule, ConfigService } from '@nestjs/config';
import { JwtModule } from '@nestjs/jwt';

import { WebsocketGateway } from './websockets.gateway';
import { WebsocketsService } from './websockets.service';

/**
 * Provides the websocket gateway and exports `WebsocketsService`. Registers its
 * own JwtModule configured from `JWT_SECRET` / `JWT_EXPIRATION`.
 */
@Module({
    imports: [
        ConfigModule,
        JwtModule.registerAsync({
            imports: [ConfigModule],
            inject: [ConfigService],
            useFactory: (configService: ConfigService) => ({
                secret: configService.get('JWT_SECRET'),
                signOptions: {
                    expiresIn: configService.get('JWT_EXPIRATION'),
                },
            }),
        }),
    ],
    providers: [WebsocketGateway, WebsocketsService],
    exports: [WebsocketsService],
})
export class WebsocketsModule {}

--------------------------------------------------------------------------------
/src/preferences/dto/update-preferences.schema.ts:
--------------------------------------------------------------------------------
import { Field, InputType } from '@nestjs/graphql';
import { IsBoolean, IsOptional, IsString } from 'class-validator';

/**
 * Partial update of a user's preferences; every field is optional.
 */
@InputType()
export class UpdatePreferencesDto {
    @Field({ nullable: true })
    @IsOptional()
    @IsString()
    dateFormat?: string;

    @Field({ nullable: true })
    @IsOptional()
    @IsString()
    language?: string;

    @Field({ nullable: true })
    @IsOptional()
    @IsBoolean()
    use24HourFormat?: boolean;

    // UI Preferences
    @Field({ nullable: true })
    @IsOptional()
    @IsBoolean()
    showSidebar?: boolean;

    @Field({ nullable: true })
    @IsOptional()
    @IsBoolean()
    showTimestamps?: boolean;

    @Field({ nullable: true })
    @IsOptional()
    @IsString()
    theme?: string;
}

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# compiled output
/dist
/node_modules
/build

# Logs
logs
*.log
npm-debug.log*
pnpm-debug.log*
yarn-debug.log*
yarn-error.log*
lerna-debug.log*

# OS
.DS_Store

# Tests
/coverage
/.nyc_output

# IDEs and editors
/.idea
.project
.classpath
.c9/
*.launch
.settings/
*.sublime-workspace

# IDE - VSCode
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json

# dotenv environment variable files
.env
.env.development.local
.env.test.local
.env.production.local
.env.local

# temp directory
.temp
.tmp

# Runtime data
pids
*.pid
*.seed
*.pid.lock

# Diagnostic reports (https://nodejs.org/api/report.html)
report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
uploads

--------------------------------------------------------------------------------
/src/chats/dto/add-message.dto.ts:
--------------------------------------------------------------------------------
import { Field, InputType } from '@nestjs/graphql';
import { IsArray, IsBoolean, IsOptional,
IsString, MaxLength } from 'class-validator';

/**
 * Input for sending a prompt to a chat branch with a chosen model and API key.
 */
@InputType()
export class AddMessageDto {
    @Field()
    @IsString()
    branchId: string;

    /** Prompt text, at most 50000 characters. */
    @Field()
    @IsString()
    @MaxLength(50000, { message: 'Message content must be at most 50000 characters' })
    prompt: string;

    @Field()
    @IsString()
    modelId: string;

    // NOTE(review): presumably the client-held key used to decrypt the stored
    // API key — confirm against ChatService.
    @Field()
    @IsString()
    rawDecryptKey: string;

    @Field()
    @IsString()
    apiKeyId: string;

    /** Optional list of attachment ids (strings). */
    @Field(() => [String], { nullable: true })
    @IsOptional()
    @IsArray()
    @IsString({ each: true })
    attachments?: string[];

    @Field({ nullable: true })
    @IsOptional()
    @IsBoolean()
    useImageTool?: boolean;
}

--------------------------------------------------------------------------------
/src/main.ts:
--------------------------------------------------------------------------------
import { ConfigService } from '@nestjs/config';
import { NestFactory } from '@nestjs/core';
import { AppModule } from './app.module';

/**
 * Creates the Nest application, configures CORS and starts listening on
 * `PORT` (default 3000).
 */
async function bootstrap() {
    const app = await NestFactory.create(AppModule);

    const configService = app.get(ConfigService);
    const environment = configService.get('NODE_ENV');
    const corsOrigin = process.env.CORS_ORIGIN;

    if (!corsOrigin) {
        console.warn('CORS_ORIGIN is not set');
    } else {
        console.log('CORS_ORIGIN set to:', corsOrigin);
    }

    app.enableCors({
        // Browsers reject a wildcard `Access-Control-Allow-Origin: *` on credentialed
        // requests, so in development reflect the request origin (`true`) instead.
        origin: environment === 'development' ? true : corsOrigin,
        methods: ['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'],
        allowedHeaders: ['Content-Type', 'Authorization'],
        credentials: true,
    });

    await app.listen(configService.get('PORT') ?? 3000);
}

bootstrap().catch((error: unknown) => {
    console.error(error);
    // Exit non-zero so supervisors see the failed start instead of a process
    // that may linger with open handles after bootstrap failed.
    process.exit(1);
});

--------------------------------------------------------------------------------
/src/branches/branches.module.ts:
--------------------------------------------------------------------------------
import { Module } from '@nestjs/common';
import { MongooseModule } from '@nestjs/mongoose';

import { Chat, ChatSchema } from '@/chats/schemas/chat.schema';
import { ApiKeysModule } from '@/keys/api-key.module';
import { MessagesModule } from '@/messages/messages.module';
import { WebsocketsModule } from '@/websockets/websockets.module';
import { BranchesResolver } from './branches.resolver';
import { BranchesService } from './branches.service';
import { ChatBranch, ChatBranchSchema } from './schemas/chat-branch.schema';

/**
 * Registers the ChatBranch and Chat models and exports `BranchesService`.
 */
@Module({
    imports: [
        MongooseModule.forFeature([
            { name: ChatBranch.name, schema: ChatBranchSchema },
            { name: Chat.name, schema: ChatSchema },
        ]),
        MessagesModule,
        ApiKeysModule,
        WebsocketsModule,
    ],
    providers: [BranchesService, BranchesResolver],
    exports: [BranchesService],
})
export class BranchesModule {}

--------------------------------------------------------------------------------
/src/users/users.resolver.ts:
--------------------------------------------------------------------------------
import { UseGuards } from '@nestjs/common';
import { Args, Mutation, Query, Resolver } from '@nestjs/graphql';

import { CurrentUser } from '@/auth/decorators/current-user.decorator';
import { GqlAuthGuard } from '@/auth/guards/gql-auth.guard';
import { AccessJwtPayload } from '@/auth/interfaces/jwt-payload.interface';
import { UpdateUserDto } from './dto/update-user.dto';
import { User } from './schemas/user.schema';
import { UsersService } from './users.service';

/**
 * GraphQL operations on the authenticated user's own account.
 */
@Resolver()
export class UsersResolver {
    constructor(private readonly usersService: UsersService) {}

    /** Returns the user identified by the token's `sub`. */
    @UseGuards(GqlAuthGuard)
    @Query(() => User)
    async getUser(@CurrentUser() user: AccessJwtPayload): Promise<User> {
        return await this.usersService.findById(user.sub);
    }

    /** Applies `payload` to the user identified by the token's `sub`. */
    @UseGuards(GqlAuthGuard)
    @Mutation(() => User)
    async updateUser(
        @CurrentUser() user: AccessJwtPayload,
        @Args('payload') payload: UpdateUserDto
    ): Promise<User> {
        return await this.usersService.updateUser(user.sub, payload);
    }
}

--------------------------------------------------------------------------------
/src/messages/messages.resolver.ts:
--------------------------------------------------------------------------------
import { UnauthorizedException, UseGuards } from '@nestjs/common';
import { Args, Query, Resolver } from '@nestjs/graphql';

import { CurrentUser } from '@/auth/decorators/current-user.decorator';
import { GqlAuthGuard } from '@/auth/guards/gql-auth.guard';
import { AccessJwtPayload } from '@/auth/interfaces/jwt-payload.interface';
import { GetMessagesDto } from './dto/get-messages.dto';
import { MessagesService } from './messages.service';
import { Message, MessagesResponse } from './schemas/message.schema';

/**
 * GraphQL read access to branch messages.
 */
@Resolver(() => Message)
export class MessagesResolver {
    constructor(private readonly messagesService: MessagesService) {}

    /**
     * Fetches messages matching `queryOptions` on behalf of the current user.
     * @throws UnauthorizedException when the token carries no `sub`.
     */
    @UseGuards(GqlAuthGuard)
    @Query(() => MessagesResponse)
    async getChatMessages(
        @CurrentUser() user: AccessJwtPayload,
        @Args('query') queryOptions: GetMessagesDto
    ): Promise<MessagesResponse> {
        if (!user?.sub) {
            throw new UnauthorizedException("User doesn't exist");
        }

        return await this.messagesService.findByBranchId(queryOptions, user.sub);
    }
}

--------------------------------------------------------------------------------
/src/storage/schemas/file.schema.ts:
--------------------------------------------------------------------------------
import { Field, ObjectType } from '@nestjs/graphql';
import { Prop, Schema, SchemaFactory } from '@nestjs/mongoose';
import { Document } from 'mongoose';
import { v4 as uuidv4 } from 'uuid';

/**
 * Uploaded file metadata, keyed by a string UUID.
 */
@Schema({ timestamps: true })
@ObjectType()
export class File {
    @Field(() => String)
    @Prop({ type: String, default: uuidv4 }) // string UUID primary key instead of the default ObjectId
    _id: string;

    @Prop({ required: true })
    @Field()
    filename: string;

    @Prop({ required: true })
    @Field()
    mimetype: string;

    @Prop({ required: true })
    @Field()
    size: number;

    @Prop({ required: true })
    userId: string;

    @Prop()
    @Field({ nullable: true })
    uploadId?: string;

    @Prop()
    @Field({ nullable: true })
    clientToken?: string;

    // Populated by the schema's `timestamps` option.
    @Field()
    createdAt: Date;
}

/** Storage usage figures for a user. */
@ObjectType()
export class UserStorageStats {
    @Field()
    used: number;

    @Field()
    limit: number;

    @Field()
    remaining: number;
}

export const FileSchema = SchemaFactory.createForClass(File);
export type FileDocument = File & Document;

--------------------------------------------------------------------------------
/src/preferences/schema/user-preference.schema.ts:
--------------------------------------------------------------------------------
import { Field, ObjectType } from '@nestjs/graphql';
import { Prop, Schema, SchemaFactory } from '@nestjs/mongoose';
import { Document, Types } from 'mongoose';

/**
 * Per-user preferences document; at most one per user (unique `userId` index).
 */
@Schema({ timestamps: true })
@ObjectType()
export class UserPreferences {
    @Field(() => String)
    _id: Types.ObjectId;

    // NOTE(review): `lowercase: true` rewrites the stored value; if clients send
    // case-sensitive format tokens (e.g. "MM" vs "mm") this changes their meaning — confirm.
    @Field({ nullable: true })
    @Prop({ lowercase: true, trim: true })
    dateFormat?: string;

    @Field({ nullable: true })
    @Prop({ lowercase: true, trim: true })
    language?: string;

    @Field({ nullable: true })
    @Prop()
    use24HourFormat?: boolean;

    @Field({ nullable: true })
    @Prop()
    showSidebar?: boolean;

    @Field({ nullable: true })
    @Prop()
    showTimestamps?: boolean;

    @Field({ nullable: true })
    @Prop({ lowercase: true, trim: true })
    theme?: string;

    @Prop({ type: Types.ObjectId, ref: 'User', required: true })
    userId: Types.ObjectId;
}

export type PreferencesDocument = UserPreferences & Document;
export const PreferencesSchema = SchemaFactory.createForClass(UserPreferences);

PreferencesSchema.index({ userId: 1 }, { unique: true });

--------------------------------------------------------------------------------
/src/chats/chats.module.ts:
--------------------------------------------------------------------------------
import { Module } from '@nestjs/common';
import { MongooseModule } from '@nestjs/mongoose';

import { AIModule } from '@/ai/ai.module';
import { BranchesModule } from '@/branches/branches.module';
import { EncryptionModule } from '@/encryption/encryption.module';
import { ApiKeysModule } from '@/keys/api-key.module';
import { MessagesModule } from '@/messages/messages.module';
import { Message, MessageSchema } from '@/messages/schemas/message.schema';
import { StorageModule } from '@/storage/storage.module';
import { WebsocketsModule } from '@/websockets/websockets.module';
import { ChatsResolver } from './chats.resolver';
import { ChatService } from './chats.service';
import { Chat, ChatSchema } from './schemas/chat.schema';

/**
 * Registers the Chat and Message models, wires chat services and exports `ChatService`.
 */
@Module({
    imports: [
        MongooseModule.forFeature([
            { name: Chat.name, schema: ChatSchema },
            { name: Message.name, schema: MessageSchema },
        ]),
        AIModule,
        ApiKeysModule,
        EncryptionModule,
        BranchesModule,
        MessagesModule,
        StorageModule,
        WebsocketsModule,
    ],
    providers: [ChatService, ChatsResolver],
    exports: [ChatService],
})
export class ChatsModule {}

--------------------------------------------------------------------------------
/src/auth/auth.module.ts:
--------------------------------------------------------------------------------
import { Module } from '@nestjs/common';
import { ConfigModule, ConfigService } from '@nestjs/config';
import { JwtModule } from '@nestjs/jwt';
import { PassportModule } from '@nestjs/passport';

import { EncryptionModule } from '@/encryption/encryption.module';
import { SessionsModule } from '@/sessions/sessions.module';
import { UsersModule } from '@/users/users.module';
import { AuthResolver } from './auth.resolver';
import { AuthService } from './auth.service';
import { JwtAuthGuard } from './guards/jwt-auth.guard';
import { JwtStrategy } from './strategies/jwt.strategy';

/**
 * Authentication module: JWT signing (`JWT_SECRET` / `JWT_EXPIRATION`),
 * the passport JWT strategy and the auth resolver/service.
 */
@Module({
    imports: [
        ConfigModule,
        PassportModule,
        JwtModule.registerAsync({
            imports: [ConfigModule],
            inject: [ConfigService],
            useFactory: (configService: ConfigService) => ({
                secret: configService.get('JWT_SECRET'),
                signOptions: {
                    expiresIn: configService.get('JWT_EXPIRATION'),
                },
            }),
        }),
        EncryptionModule,
        SessionsModule,
        UsersModule,
    ],
    providers: [AuthService, AuthResolver, JwtStrategy, JwtAuthGuard],
})
export class AuthModule {}

--------------------------------------------------------------------------------
/src/keys/schemas/api-key.schema.ts:
--------------------------------------------------------------------------------
import { Field, ObjectType } from '@nestjs/graphql';
import { Prop, Schema, SchemaFactory } from '@nestjs/mongoose';
import { Document, Types } from 'mongoose';

import { AIProviderId } from '@/ai/interfaces/ai-provider.interface';

/**
 * A user's stored (encrypted) provider API key. `encryptedApiKey` and `userId`
 * are persisted but not exposed through GraphQL.
 */
@ObjectType()
@Schema({ timestamps: true })
export class ApiKey {
    @Field(() => String)
    _id: Types.ObjectId;

    @Prop({ type: Types.ObjectId, ref: 'User', required: true })
    userId: Types.ObjectId;

    @Field(() => String)
    @Prop({
        type: String,
        enum: AIProviderId,
        required: true,
    })
    provider: AIProviderId;

    @Field()
    @Prop({ required: true })
    alias: string;

    @Prop({ required: true })
    encryptedApiKey: string;

    @Prop({ type: Boolean, default: true })
    isActive: boolean;

    @Field({ nullable: true })
    @Prop({ type: Date })
    lastUsed?: Date;

    // NOTE(review): nullable in GraphQL and not required in Mongo, but typed as
    // always present; consider `lastRotated?: Date` once callers handle undefined.
    @Field({ nullable: true })
    @Prop({ type: Date })
    lastRotated: Date;

    // NOTE(review): same optionality mismatch as `lastRotated`.
    @Field({ nullable: true })
    @Prop({ type: Date })
    lastValidated: Date;
}

export type ApiKeyDocument = ApiKey & Document;
export const ApiKeySchema = SchemaFactory.createForClass(ApiKey);

ApiKeySchema.index({ userId: 1, provider: 1, alias: 1 }, { unique: true });
ApiKeySchema.index({ userId: 1, isActive: 1 });
ApiKeySchema.index({ lastValidated: 1 });

--------------------------------------------------------------------------------
/src/preferences/preferences.resolver.ts:
--------------------------------------------------------------------------------
import { UseGuards } from '@nestjs/common';
import { Args, Mutation, Query, Resolver } from '@nestjs/graphql';

import { CurrentUser } from '@/auth/decorators/current-user.decorator';
import { GqlAuthGuard } from '@/auth/guards/gql-auth.guard';
import { AccessJwtPayload } from '@/auth/interfaces/jwt-payload.interface';
import { UpdatePreferencesDto } from './dto/update-preferences.schema';
import { PreferencesService } from './preferences.service';
import { UserPreferences } from './schema/user-preference.schema';

/**
 * GraphQL operations on the authenticated user's preferences.
 */
@Resolver(() => UserPreferences)
export class PreferencesResolver {
    constructor(private readonly userPreferencesService: PreferencesService) {}

    @UseGuards(GqlAuthGuard)
    @Query(() => UserPreferences)
    async getPreferences(@CurrentUser() user: AccessJwtPayload): Promise<UserPreferences> {
        return await this.userPreferencesService.findByUserId(user.sub);
    }

    @UseGuards(GqlAuthGuard)
    @Mutation(() => UserPreferences)
    async updatePreferences(
        @CurrentUser() user: AccessJwtPayload,
        @Args('payload') payload: UpdatePreferencesDto
    ): Promise<UserPreferences> {
        return await this.userPreferencesService.updatePreferences(user.sub, payload);
    }

    @UseGuards(GqlAuthGuard)
    @Mutation(() => UserPreferences)
    async createPreferences(
        @CurrentUser() user: AccessJwtPayload
    ): Promise<UserPreferences> {
        return await this.userPreferencesService.createForUser(user.sub);
    }
}

--------------------------------------------------------------------------------
/src/websockets/decorators/rate-limit.decorator.ts:
--------------------------------------------------------------------------------
import { Socket } from 'socket.io';

/** Counter state for one (socket, handler) pair within its current window. */
interface RateLimitEntry {
    count: number;
    resetTime: number;
}

// Rate limiter storage, keyed by `${socketId}:${handlerName}`.
const rateLimits = new Map<string, RateLimitEntry>();

// How often (ms) expired entries are purged from `rateLimits`.
const SWEEP_INTERVAL_MS = 60_000;
let lastSweep = Date.now();

/**
 * Deletes entries whose window has elapsed. Without this, entries of sockets
 * that disconnected are never removed and the map grows without bound. An
 * expired entry would be reset on its next use anyway, so removing it does
 * not change any limiting decision.
 */
function sweepExpiredEntries(now: number): void {
    for (const [key, entry] of rateLimits) {
        if (now > entry.resetTime) {
            rateLimits.delete(key);
        }
    }
    lastSweep = now;
}

/**
 * Method decorator applying a fixed-window rate limit per socket and per
 * decorated handler.
 *
 * The decorated handler must receive the `Socket` as its first argument. When
 * the limit is exceeded the handler is not invoked: an `error` event with
 * `{ message, event }` is emitted to that socket and `undefined` is returned.
 *
 * @param maxRequests - Maximum calls allowed within one window.
 * @param windowSeconds - Window length in seconds.
 */
export function RateLimit(maxRequests: number, windowSeconds: number) {
    return function (_target: object, propertyKey: string, descriptor: PropertyDescriptor) {
        const originalMethod = descriptor.value;

        descriptor.value = function (this: unknown, ...args: any[]) {
            const client: Socket = args[0];
            const key = `${client.id}:${propertyKey}`;

            // Get current time
            const now = Date.now();

            if (now - lastSweep >= SWEEP_INTERVAL_MS) {
                sweepExpiredEntries(now);
            }

            // Get or create rate limit entry
            let limitData = rateLimits.get(key);
            if (!limitData || now > limitData.resetTime) {
                // Reset if window expired
                limitData = {
                    count: 0,
                    resetTime: now + windowSeconds * 1000,
                };
            }

            // Check if limit exceeded
            if (limitData.count >= maxRequests) {
                client.emit('error', {
                    message: 'Rate limit exceeded. Please try again later.',
                    event: propertyKey,
                });
                return;
            }

            // Increment counter
            limitData.count++;
            rateLimits.set(key, limitData);

            // Call original method
            return originalMethod.apply(this, args);
        };

        return descriptor;
    };
}

--------------------------------------------------------------------------------
/src/branches/dto/update-branch.dto.ts:
--------------------------------------------------------------------------------
import { Field, InputType } from '@nestjs/graphql';
import {
    IsMongoId,
    IsNumber,
    IsOptional,
    IsString,
    MaxLength,
    Min,
    MinLength,
} from 'class-validator';

/**
 * Optional model settings for a branch.
 */
@InputType()
export class UpdateBranchModelConfigDto {
    /** Model identifier, 1-64 characters. */
    @Field({ nullable: true })
    @IsOptional()
    @IsString()
    @MinLength(1, { message: 'Model ID must be at least 1 character' })
    @MaxLength(64, { message: 'Model ID must be at most 64 characters' })
    modelId?: string;

    /** Id of a stored API key; must be a valid Mongo ObjectId. */
    @Field({ nullable: true })
    @IsOptional()
    @IsString()
    @IsMongoId({ message: 'API Key ID is not valid' })
    apiKeyId?: string;

    /** Finite number; no range is enforced here. */
    @Field({ nullable: true })
    @IsOptional()
    @IsNumber(
        { allowNaN: false, allowInfinity: false },
        { message: 'Temperature must be a valid number' }
    )
    temperature?: number;

    /** Finite number >= 1. */
    @Field({ nullable: true })
    @IsOptional()
    @IsNumber(
        { allowNaN: false, allowInfinity: false },
        { message: 'Max tokens must be a valid number' }
    )
    @Min(1, { message: 'Max tokens must be at least 1' })
    maxTokens?: number;
}

/**
 * Partial update of a branch's name and/or model configuration.
 */
@InputType()
export class UpdateBranchDto {
    @Field({ nullable: true })
    @IsOptional()
    @IsString()
    @MinLength(1, { message: 'Branch name must be at least 1 character' })
    @MaxLength(50, { message: 'Branch name must be at most 50 characters' })
    name?: string;

    // NOTE(review): without @ValidateNested() (plus a class-transformer @Type()),
    // the validators on UpdateBranchModelConfigDto are not applied to this object.
    @Field({ nullable: true })
    @IsOptional()
    modelConfig?: UpdateBranchModelConfigDto;
}

--------------------------------------------------------------------------------
/src/sessions/schemas/session.schema.ts:
--------------------------------------------------------------------------------
import { Field, ObjectType } from '@nestjs/graphql';
import { Prop, Schema, SchemaFactory } from '@nestjs/mongoose';
import { Document, Types } from 'mongoose';

import { DeviceInfo } from '@/auth/interfaces/device-info.interface';
import { User } from '@/users/schemas/user.schema';

/**
 * A login session. Tokens are stored but not exposed through GraphQL; the
 * `expiresAt` TTL index lets MongoDB remove expired sessions.
 */
@Schema({ timestamps: true })
@ObjectType()
export class Session {
    @Field(() => String)
    _id: Types.ObjectId;

    @Prop({ type: Types.ObjectId, ref: 'User', required: true })
    userId: Types.ObjectId;

    // `unique: true` already creates the unique index on refreshToken.
    @Prop({ required: true, unique: true })
    refreshToken: string;

    @Prop({ required: true, unique: true })
    accessToken: string;

    @Prop({ type: Object, required: true })
    @Field()
    deviceInfo: DeviceInfo;

    @Prop({ type: Date, required: true })
    @Field()
    expiresAt: Date;

    @Prop({ type: Boolean, default: true })
    @Field()
    isActive: boolean;

    @Prop({ type: Date, default: Date.now })
    @Field()
    lastUsedAt: Date;
}

/** Tokens (and optionally the user and raw decrypt key) returned after authentication. */
@ObjectType()
export class SessionResponse {
    @Field()
    accessToken: string;
    @Field()
    refreshToken: string;
    @Field({ nullable: true })
    user?: User;
    @Field({ nullable: true })
    rawDecryptKey?: string;
}

export const SessionSchema = SchemaFactory.createForClass(Session);
export type SessionDocument = Session & Document;

SessionSchema.index({ userId: 1 });
// No separate refreshToken index: the `unique` prop option above already defines it,
// and declaring it twice triggers Mongoose's duplicate-index warning.
SessionSchema.index({ expiresAt: 1 }, { expireAfterSeconds: 0 });
SessionSchema.index({ isActive: 1 });
--------------------------------------------------------------------------------
/src/preferences/preferences.service.ts:
--------------------------------------------------------------------------------
import { Injectable, NotFoundException } from '@nestjs/common';
import { InjectModel } from '@nestjs/mongoose';
import { Model } from 'mongoose';

import { WebsocketsService } from '@/websockets/websockets.service';
import { UpdatePreferencesDto } from './dto/update-preferences.schema';
import { PreferencesDocument, UserPreferences } from './schema/user-preference.schema';

/**
 * Reads and writes per-user preferences and notifies websocket clients of changes.
 */
@Injectable()
export class PreferencesService {
    constructor(
        @InjectModel(UserPreferences.name)
        private readonly preferencesModel: Model<PreferencesDocument>,
        private readonly websocketsService: WebsocketsService
    ) {}

    /** Creates and saves a preferences document containing only `userId`. */
    async createForUser(userId: string): Promise<PreferencesDocument> {
        const preferences = new this.preferencesModel({ userId });
        return preferences.save();
    }

    /**
     * Loads the preferences document of `userId`.
     * @throws NotFoundException when the user has no preferences document.
     */
    async findByUserId(userId: string): Promise<PreferencesDocument> {
        const preferences = await this.preferencesModel.findOne({ userId });
        if (!preferences) {
            throw new NotFoundException('Preferences not found for this user');
        }
        return preferences;
    }

    /**
     * Applies the supplied fields to the user's preferences, saves them, and only
     * then emits the websocket update with the saved document.
     * @throws NotFoundException when the user has no preferences document.
     */
    async updatePreferences(
        userId: string,
        updateData: UpdatePreferencesDto
    ): Promise<PreferencesDocument> {
        const preferences = await this.findByUserId(userId);

        // Only apply fields that were actually provided; an `undefined` property
        // (e.g. an uninitialised DTO field) must not overwrite a stored value.
        const changes = Object.fromEntries(
            Object.entries(updateData).filter(([, value]) => value !== undefined)
        );
        Object.assign(preferences, changes);

        // Persist first so clients are never notified about a change that failed to save.
        const saved = await preferences.save();
        this.websocketsService.emitPreferencesUpdated(userId, saved);
        return saved;
    }

    /** Deletes the user's preferences; resolves to the removed document or null. */
    async deletePreferences(userId: string): Promise<PreferencesDocument | null> {
        return await this.preferencesModel.findOneAndDelete({ userId });
    }
}

--------------------------------------------------------------------------------
/src/branches/branches.resolver.ts:
--------------------------------------------------------------------------------
import { UseGuards } from '@nestjs/common';
import { Args,
Mutation, Query, Resolver } from '@nestjs/graphql';

import { CurrentUser } from '@/auth/decorators/current-user.decorator';
import { GqlAuthGuard } from '@/auth/guards/gql-auth.guard';
import { AccessJwtPayload } from '@/auth/interfaces/jwt-payload.interface';
import { BranchesService } from './branches.service';
import { ForkBranchDto } from './dto/fork-branch.dto';
import { UpdateBranchDto } from './dto/update-branch.dto';
import { ChatBranch } from './schemas/chat-branch.schema';

/**
 * GraphQL operations on chat branches. Every operation requires authentication
 * and passes the caller's id (`user.sub`) to `BranchesService`.
 */
@Resolver(() => ChatBranch)
export class BranchesResolver {
    constructor(private readonly branchesService: BranchesService) {}

    /** Lists the branches of `chatId` via the service. */
    @UseGuards(GqlAuthGuard)
    @Query(() => [ChatBranch])
    async getChatBranches(
        @CurrentUser() user: AccessJwtPayload,
        @Args('chatId') chatId: string
    ): Promise<ChatBranch[]> {
        return await this.branchesService.findByChatId(chatId, user.sub);
    }

    /** Applies `payload` to branch `branchId`. */
    @UseGuards(GqlAuthGuard)
    @Mutation(() => ChatBranch)
    async updateBranch(
        @CurrentUser() user: AccessJwtPayload,
        @Args('branchId') branchId: string,
        @Args('payload') payload: UpdateBranchDto
    ): Promise<ChatBranch> {
        return await this.branchesService.update(branchId, user.sub, payload);
    }

    /** Creates a new branch forked from `originalBranchId`. */
    @UseGuards(GqlAuthGuard)
    @Mutation(() => ChatBranch)
    async forkBranch(
        @CurrentUser() user: AccessJwtPayload,
        @Args('originalBranchId') originalBranchId: string,
        @Args('payload') payload: ForkBranchDto
    ): Promise<ChatBranch> {
        return await this.branchesService.forkBranch(user.sub, originalBranchId, payload);
    }
}

--------------------------------------------------------------------------------
/src/keys/api-key.resolver.ts:
--------------------------------------------------------------------------------
import { UseGuards } from '@nestjs/common';
import { Args, Mutation, Query, Resolver } from '@nestjs/graphql';

import { CurrentUser } from
'@/auth/decorators/current-user.decorator';
import { GqlAuthGuard } from '@/auth/guards/gql-auth.guard';
import { AccessJwtPayload } from '@/auth/interfaces/jwt-payload.interface';
import { ApiKeysService } from './api-key.service';
import { CreateApiKeyDto } from './dto/create-api-key.dto';
import { UpdateApiKeyDto } from './dto/update-api-key.dto';
import { ApiKey } from './schemas/api-key.schema';

/**
 * GraphQL CRUD for the authenticated user's provider API keys.
 */
@Resolver()
export class ApiKeyResolver {
    constructor(private readonly apiKeysService: ApiKeysService) {}

    /** Lists the current user's API keys. */
    @UseGuards(GqlAuthGuard)
    @Query(() => [ApiKey])
    async getApiKeys(@CurrentUser() user: AccessJwtPayload): Promise<ApiKey[]> {
        return await this.apiKeysService.findAll(user.sub);
    }

    /** Stores a new API key for the current user. */
    @UseGuards(GqlAuthGuard)
    @Mutation(() => ApiKey)
    async addApiKey(
        @Args('payload') payload: CreateApiKeyDto,
        @CurrentUser() user: AccessJwtPayload
    ): Promise<ApiKey> {
        return await this.apiKeysService.create(user.sub, payload);
    }

    /** Updates key `id` of the current user. */
    @UseGuards(GqlAuthGuard)
    @Mutation(() => ApiKey)
    async updateApiKey(
        @Args('id') id: string,
        @Args('payload') payload: UpdateApiKeyDto,
        @CurrentUser() user: AccessJwtPayload
    ): Promise<ApiKey> {
        return await this.apiKeysService.update(id, user.sub, payload);
    }

    /** Deletes key `id`; resolves to `true` once the service call completes. */
    @UseGuards(GqlAuthGuard)
    @Mutation(() => Boolean)
    async deleteApiKey(
        @Args('id') id: string,
        @CurrentUser() user: AccessJwtPayload
    ): Promise<boolean> {
        await this.apiKeysService.delete(id, user.sub);
        return true;
    }
}

--------------------------------------------------------------------------------
/src/branches/schemas/chat-branch.schema.ts:
--------------------------------------------------------------------------------
import { Field, ObjectType } from '@nestjs/graphql';
import { Prop, Schema, SchemaFactory } from '@nestjs/mongoose';
import { Document, Types } from 'mongoose';

/** Per-branch model settings exposed through GraphQL; every field is optional. */
@ObjectType()
export class ModelConfig {
    @Field(() => String, { nullable: true })
    modelId?: string;

    @Field(() => String, { nullable: true })
    apiKeyId?: string;

    @Field(() => Number, { nullable: true })
    temperature?: number;

    @Field(() => Number, { nullable: true })
    maxTokens?: number;
}

/**
 * A branch of a chat's conversation, optionally derived from a parent branch.
 */
@Schema({ timestamps: true })
@ObjectType()
export class ChatBranch {
    @Field(() => String)
    _id: Types.ObjectId;

    @Prop({ type: Types.ObjectId, ref: 'User', required: true })
    userId: Types.ObjectId;

    @Prop({ type: Types.ObjectId, ref: 'Chat', required: true })
    chatId: Types.ObjectId;

    @Prop({ required: true, maxlength: 50 })
    @Field()
    name: string;

    // NOTE(review): the stored value is an ObjectId but GraphQL exposes it as a
    // ChatBranch object; unless it is populated before being returned, selecting
    // sub-fields cannot resolve a real branch. `@Field(() => String)` would match
    // the data, but changes the public schema, so it is left as is here.
    @Prop({ type: Types.ObjectId, ref: 'ChatBranch' })
    @Field(() => ChatBranch, { nullable: true })
    parentBranchId?: Types.ObjectId;

    @Prop({ type: Number, default: 0 })
    @Field()
    branchPoint: number;

    @Prop({ type: Number, default: 0 })
    @Field()
    messageCount: number;

    @Prop({ default: false })
    @Field()
    isActive: boolean;

    // Free-form bag stored as a plain object; `any` keeps existing callers compiling.
    @Prop({ type: Object, default: {} })
    metadata: Record<string, any>;

    @Prop()
    @Field(() => ModelConfig, { nullable: true })
    modelConfig?: ModelConfig;
}

export type ChatBranchDocument = ChatBranch & Document;
export const ChatBranchSchema = SchemaFactory.createForClass(ChatBranch);

ChatBranchSchema.index({ chatId: 1, isActive: 1 });
ChatBranchSchema.index({ parentBranchId: 1 });
ChatBranchSchema.index({ chatId: 1, name: 1 });

--------------------------------------------------------------------------------
/src/users/schemas/user.schema.ts:
--------------------------------------------------------------------------------
import { Field, ObjectType } from '@nestjs/graphql';
import { Prop, Schema, SchemaFactory } from '@nestjs/mongoose';
import * as bcrypt from 'bcrypt';
4 | import { Document, Types } from 'mongoose'; 5 | 6 | import { UserPreferences } from '@/preferences/schema/user-preference.schema'; 7 | 8 | @Schema({ timestamps: true }) 9 | @ObjectType() 10 | export class User { 11 | @Field(() => String) 12 | _id: Types.ObjectId; 13 | 14 | @Prop({ required: true, unique: true, lowercase: true, trim: true }) 15 | @Field() 16 | email: string; 17 | 18 | @Prop({ required: true, default: false }) 19 | @Field() 20 | emailVerified: boolean; 21 | 22 | @Prop() 23 | emailVerificationCode?: string; 24 | 25 | @Prop({ required: true }) 26 | password: string; 27 | 28 | @Prop({ required: true, trim: true }) 29 | @Field() 30 | displayName: string; 31 | 32 | @Prop({ required: true }) 33 | @Field() 34 | encryptKey: string; 35 | 36 | @Prop({ required: true }) 37 | @Field() 38 | decryptKey: string; 39 | 40 | @Prop({ required: true, default: true }) 41 | isActive: boolean; 42 | 43 | @Prop({ types: Types.ObjectId, ref: 'UserPreferences' }) 44 | @Field(() => UserPreferences, { nullable: true }) 45 | preferences: Types.ObjectId; 46 | 47 | @Field() 48 | createdAt: Date; 49 | 50 | @Field() 51 | updatedAt: Date; 52 | } 53 | 54 | export const UserSchema = SchemaFactory.createForClass(User); 55 | export type UserDocument = User & Document; 56 | 57 | UserSchema.index({ email: 1 }, { unique: true }); 58 | UserSchema.index({ displayName: 1 }); 59 | UserSchema.index({ isActive: 1 }); 60 | UserSchema.index({ createdAt: -1 }); 61 | 62 | UserSchema.pre('save', async function (next) { 63 | if (this.isModified('password')) { 64 | this.password = await bcrypt.hash(this.password, 10); 65 | } 66 | next(); 67 | }); 68 | 69 | export function comparePassword(password: string, hashedPassword: string): Promise { 70 | return bcrypt.compare(password, hashedPassword); 71 | } 72 | -------------------------------------------------------------------------------- /src/chats/schemas/chat.schema.ts: -------------------------------------------------------------------------------- 1 
| import { Field, ObjectType } from '@nestjs/graphql'; 2 | import { Prop, Schema, SchemaFactory } from '@nestjs/mongoose'; 3 | import { Document, Types } from 'mongoose'; 4 | 5 | import { ChatBranch } from '@/branches/schemas/chat-branch.schema'; 6 | import { Message } from '@/messages/schemas/message.schema'; 7 | 8 | @Schema({ timestamps: true }) 9 | @ObjectType() 10 | export class Chat { 11 | @Field(() => String) 12 | _id: Types.ObjectId; 13 | 14 | @Prop({ type: Types.ObjectId, ref: 'User', required: true }) 15 | userId: Types.ObjectId; 16 | 17 | @Prop({ required: true, maxlength: 100 }) 18 | @Field() 19 | title: string; 20 | 21 | @Prop({ default: false }) 22 | @Field() 23 | isPublic: boolean; 24 | 25 | @Prop({ types: Types.ObjectId, ref: 'ChatBranch' }) 26 | @Field(() => ChatBranch, { nullable: true }) 27 | defaultBranch?: Types.ObjectId; 28 | 29 | @Prop({ type: Object, default: {} }) 30 | metadata: Record; 31 | 32 | @Prop({ type: Date, default: Date.now }) 33 | @Field() 34 | lastActivityAt: Date; 35 | 36 | @Prop({ default: false }) 37 | @Field() 38 | archived: boolean; 39 | 40 | @Prop({ default: false }) 41 | @Field() 42 | pinned: boolean; 43 | } 44 | 45 | export type ChatDocument = Chat & Document; 46 | export const ChatSchema = SchemaFactory.createForClass(Chat); 47 | 48 | @ObjectType() 49 | export class ChatsResponse { 50 | @Field(() => [Chat]) 51 | chats: Chat[]; 52 | 53 | @Field() 54 | total: number; 55 | 56 | @Field() 57 | hasMore: boolean; 58 | } 59 | 60 | @ObjectType() 61 | export class SingleChatResponse { 62 | @Field(() => Chat) 63 | chat: Chat; 64 | 65 | @Field(() => [ChatBranch]) 66 | branches: ChatBranch[]; 67 | 68 | @Field() 69 | totalMessages: number; 70 | } 71 | 72 | @ObjectType() 73 | export class PublicChatResponse { 74 | @Field(() => Chat) 75 | chat: Chat; 76 | 77 | @Field(() => [Message]) 78 | messages: Message[]; 79 | } 80 | 81 | ChatSchema.index({ userId: 1, lastActivityAt: -1 }); 82 | ChatSchema.index({ userId: 1, archived: 1, 
lastActivityAt: -1 }); 83 | ChatSchema.index({ isPublic: 1, lastActivityAt: -1 }); 84 | ChatSchema.index({ title: 'text' }); 85 | -------------------------------------------------------------------------------- /src/storage/storage.resolver.ts: -------------------------------------------------------------------------------- 1 | import { UseGuards } from '@nestjs/common'; 2 | import { Args, Mutation, Query, Resolver } from '@nestjs/graphql'; 3 | 4 | import { CurrentUser } from '@/auth/decorators/current-user.decorator'; 5 | import { GqlAuthGuard } from '@/auth/guards/gql-auth.guard'; 6 | import { AccessJwtPayload } from '@/auth/interfaces/jwt-payload.interface'; 7 | import { CompleteFileDto } from './dto/complete-file.dto'; 8 | import { CreateFileDto } from './dto/create-file.dto'; 9 | import { File, UserStorageStats } from './schemas/file.schema'; 10 | import { StorageService } from './storage.service'; 11 | 12 | @Resolver(() => File) 13 | @UseGuards(GqlAuthGuard) 14 | export class StorageResolver { 15 | constructor(private readonly storageService: StorageService) {} 16 | 17 | @Mutation(() => File) 18 | async createFile( 19 | @CurrentUser() user: AccessJwtPayload, 20 | @Args('payload') payload: CreateFileDto 21 | ) { 22 | const { filename, mimetype, size } = payload; 23 | return this.storageService.createFile(user.sub, filename, mimetype, size); 24 | } 25 | 26 | @Mutation(() => File) 27 | async completeFile( 28 | @CurrentUser() user: AccessJwtPayload, 29 | @Args('payload') payload: CompleteFileDto 30 | ) { 31 | return this.storageService.completeFileUpload(user.sub, payload.fileId, payload.parts); 32 | } 33 | 34 | @Mutation(() => Boolean, { name: 'deleteFile' }) 35 | async deleteFile( 36 | @Args('id') id: string, 37 | @CurrentUser() user: AccessJwtPayload 38 | ): Promise { 39 | return this.storageService.deleteFile(id, user.sub); 40 | } 41 | 42 | @Query(() => [File]) 43 | async getUserFiles(@CurrentUser() user: AccessJwtPayload): Promise { 44 | return 
this.storageService.getUserFiles(user.sub); 45 | } 46 | 47 | @Query(() => File) 48 | async getFileById( 49 | @Args('id') id: string, 50 | @CurrentUser() user: AccessJwtPayload 51 | ): Promise { 52 | return this.storageService.getFileById(id, user.sub); 53 | } 54 | 55 | @Query(() => UserStorageStats) 56 | async getUserStorageStats(@CurrentUser() user: AccessJwtPayload): Promise { 57 | return this.storageService.getUserStorageStats(user.sub); 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /eslint.config.mjs: -------------------------------------------------------------------------------- 1 | // @ts-check 2 | import eslint from '@eslint/js'; 3 | import eslintPluginPrettierRecommended from 'eslint-plugin-prettier/recommended'; 4 | import globals from 'globals'; 5 | import tseslint from 'typescript-eslint'; 6 | 7 | export default tseslint.config( 8 | { 9 | ignores: ['eslint.config.mjs', 'dist/**', 'build/**', 'node_modules/**'], 10 | }, 11 | eslint.configs.recommended, 12 | ...tseslint.configs.recommendedTypeChecked, 13 | eslintPluginPrettierRecommended, 14 | { 15 | languageOptions: { 16 | globals: { 17 | ...globals.node, 18 | ...globals.jest, 19 | }, 20 | sourceType: 'commonjs', 21 | parserOptions: { 22 | projectService: true, 23 | tsconfigRootDir: import.meta.dirname, 24 | }, 25 | }, 26 | }, 27 | { 28 | rules: { 29 | // TypeScript rules 30 | '@typescript-eslint/no-explicit-any': 'off', 31 | '@typescript-eslint/no-floating-promises': 'warn', 32 | '@typescript-eslint/no-unsafe-argument': 'off', 33 | '@typescript-eslint/no-unsafe-assignment': 'off', 34 | '@typescript-eslint/no-unsafe-call': 'off', 35 | '@typescript-eslint/no-unsafe-member-access': 'off', 36 | '@typescript-eslint/no-unsafe-return': 'off', 37 | 38 | // Formatting rules - delegated to Prettier 39 | indent: 'off', 40 | 'linebreak-style': 'off', 41 | quotes: 'off', 42 | semi: 'off', 43 | 'comma-dangle': 'off', 44 | 'object-curly-spacing': 'off', 45 | 
'array-bracket-spacing': 'off', 46 | 'space-before-function-paren': 'off', 47 | 48 | // Prettier integration 49 | 'prettier/prettier': [ 50 | 'error', 51 | { 52 | semi: true, 53 | singleQuote: true, 54 | tabWidth: 4, 55 | useTabs: false, 56 | printWidth: 100, 57 | trailingComma: 'es5', 58 | bracketSpacing: true, 59 | arrowParens: 'avoid', 60 | endOfLine: 'lf', 61 | }, 62 | ], 63 | }, 64 | } 65 | ); 66 | -------------------------------------------------------------------------------- /src/ai/ai.service.ts: -------------------------------------------------------------------------------- 1 | import { Injectable } from '@nestjs/common'; 2 | 3 | import { Message } from '@/messages/schemas/message.schema'; 4 | import { StorageService } from '@/storage/storage.service'; 5 | import { AnthropicClient } from './clients/anthropic.client'; 6 | import { GoogleClient } from './clients/google.client'; 7 | import { OpenAIClient } from './clients/openai.client'; 8 | import { OpenRouterClient } from './clients/openrouter.client'; 9 | import { 10 | AIModel, 11 | AIProviderCallbacks, 12 | AIProviderClient, 13 | AIProviderId, 14 | AIProviderOptions, 15 | } from './interfaces/ai-provider.interface'; 16 | 17 | @Injectable() 18 | export class AIService { 19 | private readonly clients: Record; 20 | 21 | constructor(private readonly storageService: StorageService) { 22 | this.clients = { 23 | anthropic: new AnthropicClient(), 24 | google: new GoogleClient(), 25 | openrouter: new OpenRouterClient(this.storageService), 26 | openai: new OpenAIClient(this.storageService), 27 | }; 28 | } 29 | 30 | async validateKeyFormat(providerId: AIProviderId, key: string): Promise { 31 | const provider = this.clients[providerId]; 32 | const models = await provider.getModels(key).catch(() => []); 33 | return models.length > 0; 34 | } 35 | 36 | async getModels(providerId: AIProviderId, key: string): Promise { 37 | return this.clients[providerId].getModels(key); 38 | } 39 | 40 | async countInputTokens( 41 | 
providerId: AIProviderId, 42 | key: string, 43 | modelId: string, 44 | messages: Message[], 45 | settings: AIProviderOptions 46 | ): Promise { 47 | return this.clients[providerId].countInputTokens(key, modelId, messages, settings); 48 | } 49 | 50 | sendMessage( 51 | providerId: AIProviderId, 52 | key: string, 53 | modelId: string, 54 | messages: Message[], 55 | settings: AIProviderOptions, 56 | callbacks: AIProviderCallbacks 57 | ): Promise { 58 | return this.clients[providerId].sendMessage(key, modelId, messages, settings, callbacks); 59 | } 60 | 61 | generateImage( 62 | providerId: AIProviderId, 63 | key: string, 64 | modelId: string, 65 | promptOrMessages: string | Message[], 66 | settings: AIProviderOptions, 67 | callbacks: AIProviderCallbacks 68 | ): Promise { 69 | return this.clients[providerId].generateImage( 70 | key, 71 | modelId, 72 | promptOrMessages, 73 | settings, 74 | callbacks 75 | ); 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /src/users/users.service.ts: -------------------------------------------------------------------------------- 1 | import { ConflictException, Injectable, NotFoundException } from '@nestjs/common'; 2 | import { InjectModel } from '@nestjs/mongoose'; 3 | import { Model } from 'mongoose'; 4 | 5 | import { PreferencesService } from '@/preferences/preferences.service'; 6 | import { WebsocketsService } from '@/websockets/websockets.service'; 7 | import { UpdateUserDto } from './dto/update-user.dto'; 8 | import { User, UserDocument } from './schemas/user.schema'; 9 | 10 | @Injectable() 11 | export class UsersService { 12 | constructor( 13 | @InjectModel(User.name) 14 | private userModel: Model, 15 | private readonly preferencesService: PreferencesService, 16 | private readonly websocketsService: WebsocketsService 17 | ) {} 18 | 19 | async create( 20 | userData: Omit< 21 | User, 22 | '_id' | 'isActive' | 'createdAt' | 'updatedAt' | 'emailVerified' | 'preferences' 23 | > 24 | ) { 
25 | const existingUser = await this.findByEmail(userData.email); 26 | if (existingUser) { 27 | throw new ConflictException('User with this email already exists'); 28 | } 29 | 30 | const user = new this.userModel({ 31 | ...userData, 32 | isActive: true, 33 | emailVerified: false, 34 | emailVerificationCode: 'dummy', // Todo: Add random generation. 35 | }); 36 | 37 | const preferences = await this.preferencesService.createForUser(user._id.toString()); 38 | 39 | user.preferences = preferences._id; 40 | await user.save(); 41 | return user; 42 | } 43 | 44 | async findById(id: string): Promise { 45 | const user = await this.userModel.findById(id).exec(); 46 | 47 | if (!user) { 48 | throw new NotFoundException('User not found'); 49 | } 50 | 51 | return user; 52 | } 53 | 54 | async findByEmail(email: string): Promise { 55 | return this.userModel.findOne({ email, isActive: true }).exec(); 56 | } 57 | 58 | async updateUser(userId: string, updateData: UpdateUserDto): Promise { 59 | if (updateData.email) { 60 | const existingUser = await this.findByEmail(updateData.email); 61 | if (existingUser) { 62 | throw new ConflictException('User with this email already exists'); 63 | } 64 | } 65 | 66 | const user = await this.userModel 67 | .findByIdAndUpdate(userId, { $set: updateData }, { new: true, runValidators: true }) 68 | .exec(); 69 | 70 | if (!user) { 71 | throw new NotFoundException('User not found'); 72 | } 73 | 74 | this.websocketsService.emitUserUpdated(userId, user); 75 | return user; 76 | } 77 | 78 | async updateLastLogin(userId: string) { 79 | return this.userModel.findByIdAndUpdate(userId, { lastLogin: new Date() }).exec(); 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /src/sessions/sessions.service.ts: -------------------------------------------------------------------------------- 1 | import { Injectable, NotFoundException } from '@nestjs/common'; 2 | import { InjectModel } from '@nestjs/mongoose'; 3 | import { 
Model, Types } from 'mongoose'; 4 | 5 | import { Session, SessionDocument } from './schemas/session.schema'; 6 | 7 | @Injectable() 8 | export class SessionsService { 9 | constructor( 10 | @InjectModel(Session.name) 11 | private sessionModel: Model 12 | ) {} 13 | 14 | async createSession( 15 | sessionData: Omit 16 | ): Promise { 17 | const session = new this.sessionModel({ 18 | ...sessionData, 19 | isActive: true, 20 | lastUsedAt: Date.now(), 21 | }); 22 | return session.save(); 23 | } 24 | 25 | async findByRefreshToken(refreshToken: string): Promise { 26 | return this.sessionModel 27 | .findOne({ refreshToken, isActive: true, expiresAt: { $gt: Date.now() } }) 28 | .exec(); 29 | } 30 | 31 | async findById(id: string): Promise { 32 | return this.sessionModel.findById(id).exec(); 33 | } 34 | 35 | async findByUserId(userId: string): Promise { 36 | return this.sessionModel 37 | .find({ 38 | userId: new Types.ObjectId(userId), 39 | isActive: true, 40 | expiresAt: { $gt: Date.now() }, 41 | }) 42 | .exec(); 43 | } 44 | 45 | async updateTokens( 46 | sessionId: string, 47 | refreshToken: string, 48 | accessToken: string 49 | ): Promise { 50 | const session = await this.sessionModel 51 | .findByIdAndUpdate( 52 | sessionId, 53 | { 54 | refreshToken, 55 | accessToken, 56 | lastUsedAt: new Date(), 57 | }, 58 | { new: true } 59 | ) 60 | .exec(); 61 | 62 | if (!session) { 63 | throw new NotFoundException('Session not found'); 64 | } 65 | 66 | return session; 67 | } 68 | 69 | async revokeSession(userId: string, sessionId: string): Promise { 70 | await this.sessionModel 71 | .findOneAndUpdate( 72 | { _id: sessionId, userId: new Types.ObjectId(userId) }, 73 | { isActive: false } 74 | ) 75 | .exec(); 76 | } 77 | 78 | async revokeAllUserSessions(userId: string) { 79 | await this.sessionModel 80 | .updateMany({ userId: new Types.ObjectId(userId) }, { isActive: false }) 81 | .exec(); 82 | } 83 | 84 | async cleanupExpiredSessions(): Promise { 85 | await this.sessionModel 86 | 
.deleteMany({ 87 | $or: [{ expiresAt: { $lt: Date.now() } }, { isActive: false }], 88 | }) 89 | .exec(); 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /src/messages/schemas/message.schema.ts: -------------------------------------------------------------------------------- 1 | import { Field, ID, ObjectType, registerEnumType } from '@nestjs/graphql'; 2 | import { Prop, Schema, SchemaFactory } from '@nestjs/mongoose'; 3 | import { Document, Types } from 'mongoose'; 4 | 5 | export enum MessageRole { 6 | system = 'system', 7 | user = 'user', 8 | assistant = 'assistant', 9 | function = 'function', 10 | tool = 'tool', 11 | } 12 | 13 | registerEnumType(MessageRole, { 14 | name: 'MessageRole', 15 | }); 16 | 17 | @ObjectType() 18 | export class MessageContent { 19 | @Field() 20 | type: string; 21 | 22 | @Field({ nullable: true }) 23 | id?: string; 24 | 25 | @Field({ nullable: true }) 26 | name?: string; 27 | 28 | @Field({ nullable: true }) 29 | text?: string; 30 | 31 | @Field({ nullable: true }) 32 | tool_use_id?: string; 33 | 34 | input?: unknown; 35 | } 36 | 37 | @Schema({ timestamps: true }) 38 | @ObjectType() 39 | export class Message { 40 | @Field(() => String) 41 | _id: Types.ObjectId; 42 | 43 | @Prop([{ type: String, ref: 'File' }]) 44 | @Field(() => [ID]) 45 | attachments: string[]; 46 | 47 | @Prop({ types: Types.ObjectId, ref: 'Chat', required: true }) 48 | @Field(() => String) 49 | chatId: Types.ObjectId; 50 | 51 | @Prop({ types: Types.ObjectId, ref: 'ChatBranch', required: true }) 52 | @Field(() => String) 53 | branchId: Types.ObjectId; 54 | 55 | @Prop({ required: true }) 56 | @Field() 57 | index: number; 58 | 59 | @Prop({ enum: MessageRole, required: true }) 60 | @Field(() => String) 61 | role: MessageRole; 62 | 63 | @Prop({ type: ObjectType, required: true }) 64 | @Field(() => [MessageContent]) 65 | content: MessageContent[]; 66 | 67 | @Prop() 68 | @Field({ nullable: true }) 69 | modelUsed?: string; 70 | 71 
| @Prop({ default: 0 }) 72 | @Field({ nullable: true }) 73 | tokens: number; 74 | 75 | @Prop({ type: Object, default: {} }) 76 | metadata: Record; 77 | 78 | @Prop({ type: Boolean, default: false }) 79 | @Field() 80 | isEdited: boolean; 81 | 82 | @Prop() 83 | @Field({ nullable: true }) 84 | editedAt?: Date; 85 | 86 | @Prop() 87 | @Field(() => [MessageContent], { nullable: true }) 88 | originalContent?: MessageContent[]; 89 | 90 | @Field() 91 | createdAt?: Date; 92 | } 93 | 94 | @ObjectType() 95 | export class MessagesResponse { 96 | @Field(() => [Message]) 97 | messages: Message[]; 98 | @Field() 99 | total: number; 100 | @Field() 101 | hasMore: boolean; 102 | } 103 | 104 | export type MessageDocument = Message & Document; 105 | export const MessageSchema = SchemaFactory.createForClass(Message); 106 | 107 | MessageSchema.index({ chatId: 1, createdAt: -1 }); 108 | MessageSchema.index({ chatId: 1, createdAt: 1 }); 109 | MessageSchema.index({ content: 'text' }); 110 | MessageSchema.index({ role: 1 }); 111 | MessageSchema.index({ attachments: 1 }); 112 | -------------------------------------------------------------------------------- /src/ai/interfaces/ai-provider.interface.ts: -------------------------------------------------------------------------------- 1 | import { Field, ObjectType, registerEnumType } from '@nestjs/graphql'; 2 | 3 | import { Message } from '@/messages/schemas/message.schema'; 4 | 5 | @ObjectType() 6 | export class AIModelCapabilities { 7 | @Field() 8 | textGeneration: boolean; 9 | @Field() 10 | imageGeneration: boolean; 11 | @Field() 12 | imageAnalysis: boolean; 13 | @Field() 14 | functionCalling: boolean; 15 | @Field() 16 | webBrowsing: boolean; 17 | @Field() 18 | codeExecution: boolean; 19 | @Field() 20 | fileAnalysis: boolean; 21 | } 22 | 23 | @ObjectType() 24 | export class AIModel { 25 | @Field() 26 | id: string; 27 | @Field() 28 | name: string; 29 | @Field() 30 | author: string; 31 | @Field(() => String) 32 | provider: AIProviderId; 33 | 
@Field() 34 | capabilities: AIModelCapabilities; 35 | @Field({ nullable: true }) 36 | enabled?: boolean; 37 | @Field({ nullable: true }) 38 | description?: string; 39 | @Field({ nullable: true }) 40 | category?: string; 41 | @Field(() => String, { nullable: true }) 42 | cost?: AIModelPropValue; 43 | @Field(() => String, { nullable: true }) 44 | speed?: AIModelPropValue; 45 | } 46 | 47 | export type AIProviderOptions = { 48 | maxTokens?: number; 49 | temperature?: number; 50 | imageGeneration?: { 51 | size?: '256x256' | '512x512' | '1024x1024' | '1792x1024' | '1024x1792'; 52 | quality?: 'standard' | 'hd'; 53 | style?: 'vivid' | 'natural'; 54 | n?: number; 55 | }; 56 | }; 57 | 58 | export type AIProviderCallbacks = { 59 | onError: (error: string) => Promise; 60 | onText: (text: string) => Promise; 61 | onEnd: () => Promise; 62 | 63 | onMediaGenStart?: (type: 'image' | 'audio' | 'video') => Promise; 64 | onMediaGenEnd?: ( 65 | mediaUrl: string, 66 | type: 'image' | 'audio' | 'video', 67 | metadata?: any 68 | ) => Promise; 69 | onMediaGenError?: (error: string, type: 'image' | 'audio' | 'video') => Promise; 70 | }; 71 | 72 | export interface AIProviderClient { 73 | getModels(key: string): Promise; 74 | 75 | countInputTokens( 76 | key: string, 77 | modelId: string, 78 | messages: Message[], 79 | settings: AIProviderOptions 80 | ): Promise; 81 | 82 | sendMessage( 83 | key: string, 84 | modelId: string, 85 | messages: Message[], 86 | settings: { maxTokens?: number; temperature?: number }, 87 | callbacks: AIProviderCallbacks 88 | ): Promise; 89 | 90 | generateImage( 91 | key: string, 92 | modelId: string, 93 | promptOrMessages: string | Message[], 94 | settings: AIProviderOptions, 95 | callbacks: AIProviderCallbacks 96 | ); 97 | } 98 | 99 | export enum AIProviderId { 100 | openai = 'openai', 101 | anthropic = 'anthropic', 102 | openrouter = 'openrouter', 103 | google = 'google', 104 | } 105 | 106 | registerEnumType(AIProviderId, { 107 | name: 'AIProviderId', 108 | }); 109 
| 110 | export enum AIModelPropValue { 111 | low = 'low', 112 | medium = 'medium', 113 | high = 'high', 114 | } 115 | 116 | registerEnumType(AIModelPropValue, { 117 | name: 'AIModelPropValue', 118 | }); 119 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "uoa-backend", 3 | "version": "0.0.1", 4 | "description": "", 5 | "author": "", 6 | "private": true, 7 | "license": "UNLICENSED", 8 | "scripts": { 9 | "build": "nest build", 10 | "format": "prettier --write \"src/**/*.ts\" \"test/**/*.ts\"", 11 | "start": "nest start", 12 | "start:dev": "nest start --watch", 13 | "start:debug": "nest start --debug --watch", 14 | "start:prod": "node dist/main", 15 | "lint": "eslint \"{src,apps,libs,test}/**/*.ts\" --fix", 16 | "test": "jest", 17 | "test:watch": "jest --watch", 18 | "test:cov": "jest --coverage", 19 | "test:debug": "node --inspect-brk -r tsconfig-paths/register -r ts-node/register node_modules/.bin/jest --runInBand", 20 | "test:e2e": "jest --config ./test/jest-e2e.json" 21 | }, 22 | "dependencies": { 23 | "@anthropic-ai/sdk": "^0.53.0", 24 | "@google/genai": "^1.4.0", 25 | "@nestjs/apollo": "^13.1.0", 26 | "@nestjs/bull": "^11.0.2", 27 | "@nestjs/common": "^11.0.1", 28 | "@nestjs/config": "^4.0.2", 29 | "@nestjs/core": "^11.0.1", 30 | "@nestjs/graphql": "^13.1.0", 31 | "@nestjs/jwt": "^11.0.0", 32 | "@nestjs/mongoose": "^11.0.3", 33 | "@nestjs/passport": "^11.0.5", 34 | "@nestjs/platform-express": "^11.0.1", 35 | "@nestjs/platform-socket.io": "^11.1.3", 36 | "@nestjs/serve-static": "^5.0.3", 37 | "@nestjs/websockets": "^11.1.3", 38 | "axios": "^1.9.0", 39 | "bcrypt": "^6.0.0", 40 | "busboy": "^1.6.0", 41 | "class-validator": "^0.14.2", 42 | "crypto": "^1.0.1", 43 | "graphql": "^16.11.0", 44 | "jwt": "^0.2.0", 45 | "mongoose": "^8.15.1", 46 | "node-fetch": "2.7.0", 47 | "openai": "^5.5.1", 48 | "passport": "^0.7.0", 49 | 
"passport-jwt": "^4.0.1", 50 | "passport-local": "^1.0.0", 51 | "pdf-parse": "^1.1.1", 52 | "reflect-metadata": "^0.2.2", 53 | "rxjs": "^7.8.1", 54 | "socket.io": "^4.8.1" 55 | }, 56 | "devDependencies": { 57 | "@eslint/eslintrc": "^3.2.0", 58 | "@eslint/js": "^9.18.0", 59 | "@nestjs/cli": "^11.0.0", 60 | "@nestjs/schematics": "^11.0.0", 61 | "@nestjs/testing": "^11.0.1", 62 | "@swc/cli": "^0.6.0", 63 | "@swc/core": "^1.10.7", 64 | "@types/bcrypt": "^5.0.2", 65 | "@types/busboy": "^1.5.4", 66 | "@types/express": "^5.0.0", 67 | "@types/jest": "^29.5.14", 68 | "@types/node": "^22.10.7", 69 | "@types/node-fetch": "^2.6.12", 70 | "@types/passport-jwt": "^4.0.1", 71 | "@types/passport-local": "^1.0.38", 72 | "@types/pdf-parse": "^1.1.5", 73 | "@types/supertest": "^6.0.2", 74 | "@types/uuid": "^10.0.0", 75 | "cross-env": "^7.0.3", 76 | "eslint": "^9.18.0", 77 | "eslint-config-prettier": "^10.0.1", 78 | "eslint-plugin-prettier": "^5.2.2", 79 | "globals": "^16.0.0", 80 | "jest": "^29.7.0", 81 | "prettier": "^3.4.2", 82 | "source-map-support": "^0.5.21", 83 | "supertest": "^7.0.0", 84 | "ts-jest": "^29.2.5", 85 | "ts-loader": "^9.5.2", 86 | "ts-node": "^10.9.2", 87 | "tsconfig-paths": "^4.2.0", 88 | "typescript": "^5.7.3", 89 | "typescript-eslint": "^8.20.0" 90 | }, 91 | "jest": { 92 | "moduleFileExtensions": [ 93 | "js", 94 | "json", 95 | "ts" 96 | ], 97 | "rootDir": "src", 98 | "testRegex": ".*\\.spec\\.ts$", 99 | "transform": { 100 | "^.+\\.(t|j)s$": "ts-jest" 101 | }, 102 | "collectCoverageFrom": [ 103 | "**/*.(t|j)s" 104 | ], 105 | "coverageDirectory": "../coverage", 106 | "testEnvironment": "node" 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /src/app.module.ts: -------------------------------------------------------------------------------- 1 | import { ApolloDriver, ApolloDriverConfig } from '@nestjs/apollo'; 2 | import { BullModule } from '@nestjs/bull'; 3 | import { Module } from '@nestjs/common'; 4 | import 
{ ConfigModule, ConfigService } from '@nestjs/config'; 5 | import { GraphQLModule } from '@nestjs/graphql'; 6 | import { MongooseModule } from '@nestjs/mongoose'; 7 | import * as path from 'path'; 8 | 9 | import { AIModule } from './ai/ai.module'; 10 | import { AuthModule } from './auth/auth.module'; 11 | import { BranchesModule } from './branches/branches.module'; 12 | import { ChatsModule } from './chats/chats.module'; 13 | import { EncryptionModule } from './encryption/encryption.module'; 14 | import { ApiKeysModule } from './keys/api-key.module'; 15 | import { MessagesModule } from './messages/messages.module'; 16 | import { PreferencesModule } from './preferences/preferences.module'; 17 | import { SessionsModule } from './sessions/sessions.module'; 18 | import { UsersModule } from './users/users.module'; 19 | import { WebsocketsModule } from './websockets/websockets.module'; 20 | 21 | import { AppController } from './app.controller'; 22 | import { AppService } from './app.service'; 23 | import { StorageModule } from './storage/storage.module'; 24 | 25 | @Module({ 26 | imports: [ 27 | // Configuration (Load .env file if exist) 28 | ConfigModule.forRoot({ 29 | isGlobal: true, 30 | envFilePath: '.env', 31 | }), 32 | 33 | // Database (MongoDB) 34 | MongooseModule.forRootAsync({ 35 | imports: [ConfigModule], 36 | inject: [ConfigService], 37 | useFactory: (config: ConfigService) => ({ 38 | uri: config.get('MONGODB_URI'), 39 | }), 40 | }), 41 | 42 | // GraphQL 43 | GraphQLModule.forRootAsync({ 44 | driver: ApolloDriver, 45 | imports: [ConfigModule], 46 | inject: [ConfigService], 47 | useFactory: (config: ConfigService) => ({ 48 | autoSchemaFile: path.join(process.cwd(), 'src/schema.gql'), 49 | sortSchema: true, 50 | playground: config.get('NODE_ENV') === 'development', 51 | introspection: config.get('NODE_ENV') === 'development', 52 | context: ({ req, res }: any) => ({ req, res }), 53 | subscriptions: { 54 | 'graphql-ws': true, 55 | 'subscriptions-transport-ws': 
true, 56 | }, 57 | }), 58 | }), 59 | 60 | // Bull Queue (Uses Redis) 61 | BullModule.forRootAsync({ 62 | imports: [ConfigModule], 63 | inject: [ConfigService], 64 | useFactory: (config: ConfigService) => ({ 65 | redis: { 66 | host: config.get('REDIS_HOST'), 67 | port: config.get('REDIS_PORT'), 68 | password: config.get('REDIS_PASSWORD'), 69 | tls: config.get('REDIS_TLS') === 'true' ? {} : undefined, 70 | }, 71 | defaultJobOptions: { 72 | removeOnComplete: 100, 73 | removeOnFail: 100, 74 | }, 75 | settings: { 76 | stalledInterval: 30 * 1000, // 30s 77 | maxStalledCount: 10, 78 | drainDelay: 300, // 300 ms 79 | }, 80 | }), 81 | }), 82 | 83 | // Feature Modules 84 | AIModule, 85 | ApiKeysModule, 86 | AuthModule, 87 | BranchesModule, 88 | ChatsModule, 89 | EncryptionModule, 90 | MessagesModule, 91 | SessionsModule, 92 | StorageModule, 93 | UsersModule, 94 | PreferencesModule, 95 | WebsocketsModule, 96 | ], 97 | controllers: [AppController], 98 | providers: [AppService], 99 | }) 100 | export class AppModule {} 101 | -------------------------------------------------------------------------------- /src/ai/clients/anthropic.client.ts: -------------------------------------------------------------------------------- 1 | import Anthropic from '@anthropic-ai/sdk'; 2 | import { MessageStreamParams } from '@anthropic-ai/sdk/resources/index'; 3 | import { MessageParam } from '@anthropic-ai/sdk/resources/messages'; 4 | 5 | import { Message } from '@/messages/schemas/message.schema'; 6 | import { 7 | AIModel, 8 | AIProviderCallbacks, 9 | AIProviderClient, 10 | AIProviderId, 11 | AIProviderOptions, 12 | } from '../interfaces/ai-provider.interface'; 13 | 14 | export class AnthropicClient implements AIProviderClient { 15 | async validateKeyFormat(key: string): Promise { 16 | const models = await this.getModels(key).catch(() => []); 17 | return models.length > 0; 18 | } 19 | 20 | async getModels(key: string): Promise { 21 | const client = new Anthropic({ apiKey: key }); 22 | const raws 
= await client.models.list(); 23 | 24 | return raws.data.map(model => ({ 25 | id: model.id, 26 | name: model.display_name, 27 | author: 'Anthropic', 28 | provider: AIProviderId.anthropic, 29 | enabled: true, 30 | capabilities: { 31 | codeExecution: false, 32 | fileAnalysis: false, 33 | functionCalling: false, 34 | imageAnalysis: false, 35 | imageGeneration: false, 36 | textGeneration: true, 37 | webBrowsing: false, 38 | }, 39 | })); 40 | } 41 | 42 | async countInputTokens( 43 | key: string, 44 | modelId: string, 45 | messages: Message[], 46 | settings: AIProviderOptions 47 | ): Promise { 48 | const client = new Anthropic({ apiKey: key }); 49 | 50 | const history: Array = messages.map(message => ({ 51 | role: message.role, 52 | content: message.content, 53 | })) as Array; 54 | 55 | const params: MessageStreamParams = { 56 | messages: history, 57 | model: modelId, 58 | max_tokens: settings.maxTokens || 1024, 59 | temperature: settings.temperature, 60 | }; 61 | 62 | return (await client.messages.countTokens(params)).input_tokens; 63 | } 64 | 65 | async sendMessage( 66 | key: string, 67 | modelId: string, 68 | messages: Message[], 69 | settings: AIProviderOptions, 70 | callbacks: AIProviderCallbacks 71 | ): Promise { 72 | const client = new Anthropic({ apiKey: key }); 73 | 74 | const history: Array = messages.map(message => ({ 75 | role: message.role, 76 | content: message.content, 77 | })) as Array; 78 | 79 | const params: MessageStreamParams = { 80 | messages: history, 81 | model: modelId, 82 | max_tokens: settings.maxTokens || 1024, 83 | temperature: settings.temperature, 84 | }; 85 | 86 | return new Promise(resolve => { 87 | client.messages 88 | .stream(params) 89 | .on('text', text => { 90 | void callbacks.onText(text); 91 | }) 92 | .on('end', () => { 93 | void callbacks.onEnd(); 94 | resolve(); 95 | }) 96 | .on('error', err => { 97 | void callbacks.onError(err.message); 98 | resolve(); 99 | }); 100 | }); 101 | } 102 | 103 | async generateImage( 104 | // 
eslint-disable-next-line @typescript-eslint/no-unused-vars 105 | key: string, 106 | // eslint-disable-next-line @typescript-eslint/no-unused-vars 107 | modelId: string, 108 | // eslint-disable-next-line @typescript-eslint/no-unused-vars 109 | promptOrMessages: string | Message[], 110 | // eslint-disable-next-line @typescript-eslint/no-unused-vars 111 | settings: AIProviderOptions, 112 | // eslint-disable-next-line @typescript-eslint/no-unused-vars 113 | callbacks: AIProviderCallbacks 114 | ) {} 115 | } 116 | -------------------------------------------------------------------------------- /src/auth/auth.resolver.ts: -------------------------------------------------------------------------------- 1 | import { Args, Context, Mutation, Query, Resolver } from '@nestjs/graphql'; 2 | 3 | import { UseGuards } from '@nestjs/common'; 4 | import { Session, SessionResponse } from 'src/sessions/schemas/session.schema'; 5 | import { User } from 'src/users/schemas/user.schema'; 6 | import { AuthService } from './auth.service'; 7 | import { CurrentUser } from './decorators/current-user.decorator'; 8 | import { ChangePasswordDto } from './dto/change-password.dto'; 9 | import { LoginDto } from './dto/login.dto'; 10 | import { RegisterDto } from './dto/register.dto'; 11 | import { GqlAuthGuard } from './guards/gql-auth.guard'; 12 | import { AccessJwtPayload } from './interfaces/jwt-payload.interface'; 13 | 14 | @Resolver(() => SessionResponse) 15 | export class AuthResolver { 16 | constructor(private readonly authService: AuthService) {} 17 | 18 | @Mutation(() => SessionResponse) 19 | async register( 20 | @Args('payload') payload: RegisterDto, 21 | @Context() context: any 22 | ): Promise { 23 | const req = context.req; 24 | const deviceInfo = this.createDeviceInfo(req); 25 | return await this.authService.register(payload, deviceInfo); 26 | } 27 | 28 | @Mutation(() => SessionResponse) 29 | async login( 30 | @Args('payload') payload: LoginDto, 31 | @Context() context: any 32 | ): 
Promise { 33 | const req = context.req; 34 | const deviceInfo = this.createDeviceInfo(req); 35 | 36 | return this.authService.login(payload, deviceInfo); 37 | } 38 | 39 | @UseGuards(GqlAuthGuard) 40 | @Mutation(() => User) 41 | async updatePassword( 42 | @CurrentUser() user: AccessJwtPayload, 43 | @Args('payload') payload: ChangePasswordDto 44 | ): Promise { 45 | return await this.authService.updatePassword(user.sub, payload); 46 | } 47 | 48 | @Mutation(() => SessionResponse) 49 | async refreshToken(@Args('refreshToken') refreshToken: string): Promise { 50 | return this.authService.refreshTokens(refreshToken); 51 | } 52 | 53 | @Query(() => [Session]) 54 | @UseGuards(GqlAuthGuard) 55 | async getSessions(@CurrentUser() user: AccessJwtPayload): Promise { 56 | return await this.authService.getUserSessions(user.sub); 57 | } 58 | 59 | @UseGuards(GqlAuthGuard) 60 | @Mutation(() => Boolean) 61 | async logout(@CurrentUser() user: AccessJwtPayload): Promise { 62 | await this.authService.logout(user.sub, user.sessionId); 63 | return true; 64 | } 65 | 66 | @UseGuards(GqlAuthGuard) 67 | @Mutation(() => Boolean) 68 | async revokeSession( 69 | @Args('sessionId') sessionId: string, 70 | @CurrentUser() user: AccessJwtPayload 71 | ): Promise { 72 | await this.authService.logout(user.sub, sessionId); 73 | return true; 74 | } 75 | 76 | @UseGuards(GqlAuthGuard) 77 | @Mutation(() => Boolean) 78 | async revokeAllSessions(@CurrentUser() user: AccessJwtPayload): Promise { 79 | await this.authService.logoutAll(user.sub); 80 | return true; 81 | } 82 | 83 | private createDeviceInfo(req: any) { 84 | const ua: string = req.headers['user-agent'] || 'unknown'; 85 | 86 | return { 87 | userAgent: ua, 88 | ip: req.ip || req.connection.remoteAddress || '0.0.0.0', 89 | platform: this.extractPlatform(ua), 90 | browser: this.extractBrowser(ua), 91 | }; 92 | } 93 | 94 | private extractPlatform(userAgent?: string): string { 95 | if (!userAgent) return 'unknown'; 96 | 97 | if 
(userAgent.includes('Windows')) return 'windows'; 98 | if (userAgent.includes('Mac')) return 'macOS'; 99 | if (userAgent.includes('Linux')) return 'linux'; 100 | if (userAgent.includes('Android')) return 'android'; 101 | if (userAgent.includes('iPhone') || userAgent.includes('ipad')) return 'ios'; 102 | 103 | return 'unknown'; 104 | } 105 | 106 | private extractBrowser(userAgent?: string): string { 107 | if (!userAgent) return 'unknown'; 108 | 109 | if (userAgent.includes('Chrome')) return 'chrome'; 110 | if (userAgent.includes('Firefox')) return 'firefox'; 111 | if (userAgent.includes('Safari')) return 'safari'; 112 | if (userAgent.includes('Edge')) return 'edge'; 113 | if (userAgent.includes('Opera')) return 'ppera'; 114 | 115 | return 'unknown'; 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /src/keys/api-key.service.ts: -------------------------------------------------------------------------------- 1 | import { 2 | BadRequestException, 3 | ConflictException, 4 | Injectable, 5 | NotFoundException, 6 | } from '@nestjs/common'; 7 | import { InjectModel } from '@nestjs/mongoose'; 8 | import { Model, Types } from 'mongoose'; 9 | 10 | import { AIService } from '@/ai/ai.service'; 11 | import { EncryptionService } from '@/encryption/encryption.service'; 12 | import { UsersService } from '@/users/users.service'; 13 | import { CreateApiKeyDto } from './dto/create-api-key.dto'; 14 | import { UpdateApiKeyDto } from './dto/update-api-key.dto'; 15 | import { ApiKey, ApiKeyDocument } from './schemas/api-key.schema'; 16 | 17 | @Injectable() 18 | export class ApiKeysService { 19 | constructor( 20 | @InjectModel(ApiKey.name) private apiKeyModel: Model, 21 | private encryptionService: EncryptionService, 22 | private aiService: AIService, 23 | private usersService: UsersService 24 | ) {} 25 | 26 | async create(userId: string, data: CreateApiKeyDto) { 27 | const { provider, alias, apiKey } = data; 28 | 29 | // Check if alias 
is already in use 30 | const existingApiKey = await this.apiKeyModel.findOne({ 31 | userId: new Types.ObjectId(userId), 32 | alias, 33 | isActive: true, 34 | }); 35 | 36 | if (existingApiKey) { 37 | throw new ConflictException(`API key with alias "${alias}" already exists`); 38 | } 39 | 40 | // Validate API key format and test connection 41 | const valid = await this.aiService.validateKeyFormat(provider, apiKey); 42 | if (!valid) { 43 | throw new BadRequestException(`Invalid ${provider} API key`); 44 | } 45 | 46 | // Get user 47 | const user = await this.usersService.findById(userId); 48 | 49 | // Encrypt the API key 50 | const encryptedApiKey = this.encryptionService.encryptWithKey(apiKey, user.encryptKey); 51 | 52 | // Create new API key document 53 | const newApiKey = new this.apiKeyModel({ 54 | userId: new Types.ObjectId(userId), 55 | provider, 56 | alias, 57 | encryptedApiKey, 58 | isActive: true, 59 | lastUsed: null, 60 | lastRotated: null, 61 | }); 62 | 63 | // Save to database 64 | return await newApiKey.save(); 65 | } 66 | 67 | async update( 68 | id: string, 69 | userId: string, 70 | updateApiKeyDto: UpdateApiKeyDto 71 | ): Promise { 72 | const apiKey = await this.findById(id, userId); 73 | 74 | // Update fields 75 | if (updateApiKeyDto.alias) { 76 | // Check if alias is already used by another key 77 | const existingKey = await this.apiKeyModel 78 | .findOne({ 79 | userId: new Types.ObjectId(userId), 80 | alias: updateApiKeyDto.alias, 81 | _id: { $ne: new Types.ObjectId(id) }, 82 | isActive: true, 83 | }) 84 | .exec(); 85 | 86 | if (existingKey) { 87 | throw new ConflictException( 88 | `API key with alias "${updateApiKeyDto.alias}" already exists` 89 | ); 90 | } 91 | 92 | apiKey.alias = updateApiKeyDto.alias; 93 | } 94 | 95 | if (updateApiKeyDto.isActive !== undefined) { 96 | apiKey.isActive = updateApiKeyDto.isActive; 97 | } 98 | 99 | return apiKey.save(); 100 | } 101 | 102 | async delete(id: string, userId: string): Promise { 103 | const apiKey = await 
this.findById(id, userId); 104 | 105 | // Soft delete by setting isActive to false 106 | apiKey.isActive = false; 107 | await apiKey.save(); 108 | } 109 | 110 | async findAll(userId: string): Promise { 111 | return this.apiKeyModel.find({ userId: new Types.ObjectId(userId), isActive: true }).exec(); 112 | } 113 | 114 | async findById(id: string, userId: string): Promise { 115 | if (!Types.ObjectId.isValid(id)) { 116 | throw new BadRequestException('Invalid API key ID'); 117 | } 118 | 119 | const apiKey = await this.apiKeyModel.findById(id).exec(); 120 | 121 | if (!apiKey || apiKey.userId.toString() !== userId) { 122 | throw new NotFoundException('API key not found'); 123 | } 124 | 125 | return apiKey; 126 | } 127 | } 128 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | Nest Logo 3 |

4 | 5 | [circleci-image]: https://img.shields.io/circleci/build/github/nestjs/nest/master?token=abc123def456 6 | [circleci-url]: https://circleci.com/gh/nestjs/nest 7 | 8 |

A progressive Node.js framework for building efficient and scalable server-side applications.

9 |

10 | NPM Version 11 | Package License 12 | NPM Downloads 13 | CircleCI 14 | Discord 15 | Backers on Open Collective 16 | Sponsors on Open Collective 17 | Donate to us 18 | Support us 19 | Follow us on Twitter 20 |

21 | 23 | 24 | ## Description 25 | 26 | [Nest](https://github.com/nestjs/nest) framework TypeScript starter repository. 27 | 28 | ## Project setup 29 | 30 | ```bash 31 | $ npm install 32 | ``` 33 | 34 | ## Compile and run the project 35 | 36 | ```bash 37 | # development 38 | $ npm run start 39 | 40 | # watch mode 41 | $ npm run start:dev 42 | 43 | # production mode 44 | $ npm run start:prod 45 | ``` 46 | 47 | ## Run tests 48 | 49 | ```bash 50 | # unit tests 51 | $ npm run test 52 | 53 | # e2e tests 54 | $ npm run test:e2e 55 | 56 | # test coverage 57 | $ npm run test:cov 58 | ``` 59 | 60 | ## Deployment 61 | 62 | When you're ready to deploy your NestJS application to production, there are some key steps you can take to ensure it runs as efficiently as possible. Check out the [deployment documentation](https://docs.nestjs.com/deployment) for more information. 63 | 64 | If you are looking for a cloud-based platform to deploy your NestJS application, check out [Mau](https://mau.nestjs.com), our official platform for deploying NestJS applications on AWS. Mau makes deployment straightforward and fast, requiring just a few simple steps: 65 | 66 | ```bash 67 | $ npm install -g @nestjs/mau 68 | $ mau deploy 69 | ``` 70 | 71 | With Mau, you can deploy your application in just a few clicks, allowing you to focus on building features rather than managing infrastructure. 72 | 73 | ## Resources 74 | 75 | Check out a few resources that may come in handy when working with NestJS: 76 | 77 | - Visit the [NestJS Documentation](https://docs.nestjs.com) to learn more about the framework. 78 | - For questions and support, please visit our [Discord channel](https://discord.gg/G7Qnnhy). 79 | - To dive deeper and get more hands-on experience, check out our official video [courses](https://courses.nestjs.com/). 80 | - Deploy your application to AWS with the help of [NestJS Mau](https://mau.nestjs.com) in just a few clicks. 
81 | - Visualize your application graph and interact with the NestJS application in real-time using [NestJS Devtools](https://devtools.nestjs.com). 82 | - Need help with your project (part-time to full-time)? Check out our official [enterprise support](https://enterprise.nestjs.com). 83 | - To stay in the loop and get updates, follow us on [X](https://x.com/nestframework) and [LinkedIn](https://linkedin.com/company/nestjs). 84 | - Looking for a job, or have a job to offer? Check out our official [Jobs board](https://jobs.nestjs.com). 85 | 86 | ## Support 87 | 88 | Nest is an MIT-licensed open source project. It can grow thanks to the sponsors and support from the amazing backers. If you'd like to join them, please [read more here](https://docs.nestjs.com/support). 89 | 90 | ## Stay in touch 91 | 92 | - Author - [Kamil Myśliwiec](https://twitter.com/kammysliwiec) 93 | - Website - [https://nestjs.com](https://nestjs.com/) 94 | - Twitter - [@nestframework](https://twitter.com/nestframework) 95 | 96 | ## License 97 | 98 | Nest is [MIT licensed](https://github.com/nestjs/nest/blob/master/LICENSE). 
99 | -------------------------------------------------------------------------------- /src/chats/chats.service.ts: -------------------------------------------------------------------------------- 1 | import { 2 | BadRequestException, 3 | ForbiddenException, 4 | Injectable, 5 | NotFoundException, 6 | } from '@nestjs/common'; 7 | import { InjectModel } from '@nestjs/mongoose'; 8 | import { Model, RootFilterQuery, Types } from 'mongoose'; 9 | 10 | import { BranchesService } from '@/branches/branches.service'; 11 | import { MessagesService } from '@/messages/messages.service'; 12 | import { WebsocketsService } from '@/websockets/websockets.service'; 13 | import { GetManyChatsDto } from './dto/get-chat-dto'; 14 | import { UpdateChatDto } from './dto/update-chat.dto'; 15 | import { Chat, ChatDocument, ChatsResponse } from './schemas/chat.schema'; 16 | 17 | @Injectable() 18 | export class ChatService { 19 | constructor( 20 | @InjectModel(Chat.name) private chatModel: Model, 21 | private readonly messagesService: MessagesService, 22 | private readonly branchService: BranchesService, 23 | private readonly websocketsService: WebsocketsService 24 | ) {} 25 | 26 | async createChat(userId: string): Promise { 27 | // Create chat 28 | const chat = new this.chatModel({ 29 | userId: new Types.ObjectId(userId), 30 | title: 'New Chat', 31 | isPublic: false, 32 | lastActivityAt: new Date(), 33 | }); 34 | 35 | await chat.save(); 36 | 37 | // Create default branch 38 | const defaultBranch = await this.branchService.create(userId, chat, { 39 | name: 'main', 40 | branchPoint: 0, 41 | }); 42 | 43 | // update chat with default branch 44 | chat.defaultBranch = defaultBranch._id; 45 | await chat.save(); 46 | 47 | const populated = await chat.populate('defaultBranch'); 48 | this.websocketsService.emitChatCreated(userId, populated); 49 | return populated; 50 | } 51 | 52 | async findById(chatId: string, userId?: string, populate = true): Promise { 53 | if (!Types.ObjectId.isValid(chatId)) { 54 | 
throw new BadRequestException('Invalid chat id'); 55 | } 56 | 57 | let operation = this.chatModel.findById(chatId); 58 | if (populate) { 59 | operation = operation.populate('defaultBranch'); 60 | } 61 | 62 | const chat = await operation.exec(); 63 | 64 | if (!chat) { 65 | throw new NotFoundException('Chat not found'); 66 | } 67 | 68 | // Check permissions 69 | if (userId && !chat.isPublic && chat.userId.toString() !== userId) { 70 | throw new ForbiddenException('You do not have permission to access this chat'); 71 | } 72 | 73 | return chat; 74 | } 75 | 76 | async findByUserId(userId: string, options: GetManyChatsDto): Promise { 77 | const { limit = 20, offset = 0, archived = false, search } = options; 78 | 79 | const query: RootFilterQuery = { 80 | userId: new Types.ObjectId(userId), 81 | archived, 82 | }; 83 | 84 | // Search functionality 85 | if (search) { 86 | query.$or = [ 87 | { title: { $regex: search, $options: 'i' } }, 88 | { 'metadata.description': { $regex: search, $options: 'i' } }, 89 | ]; 90 | } 91 | 92 | const [chats, total] = await Promise.all([ 93 | await this.chatModel 94 | .find(query) 95 | .sort({ lastActivityAt: -1 }) 96 | .skip(offset) 97 | .limit(limit) 98 | .populate('defaultBranch') 99 | .exec(), 100 | 101 | await this.chatModel.countDocuments(query).exec(), 102 | ]); 103 | 104 | return { 105 | chats, 106 | total, 107 | hasMore: offset + limit < total, 108 | }; 109 | } 110 | 111 | async update(chatId: string, userId: string, updateData: UpdateChatDto): Promise { 112 | const chat = await this.findById(chatId, userId); 113 | 114 | if (chat.userId.toString() !== userId) { 115 | throw new ForbiddenException('You do not have permission to update this chat'); 116 | } 117 | 118 | // Update fields 119 | chat.title = updateData.title ?? chat.title; 120 | chat.isPublic = updateData.isPublic ?? chat.isPublic; 121 | chat.archived = updateData.archived ?? chat.archived; 122 | chat.pinned = updateData.pinned ?? 
chat.pinned; 123 | 124 | const saved = await chat.save(); 125 | await saved.populate('defaultBranch'); 126 | this.websocketsService.emitChatUpdated(userId, saved); 127 | return saved; 128 | } 129 | 130 | async updateLastActivity(chatId: string): Promise { 131 | await this.chatModel.findByIdAndUpdate(chatId, { lastActivityAt: new Date() }).exec(); 132 | } 133 | 134 | async delete(chatId: string, userId: string): Promise { 135 | const chat = await this.findById(chatId, userId); 136 | 137 | if (chat.userId.toString() !== userId) { 138 | throw new ForbiddenException('You do not have permission to delete this chat'); 139 | } 140 | 141 | await this.chatModel.deleteOne({ _id: chatId }).exec(); 142 | await this.branchService.deleteAllByChatId(chatId, userId); 143 | await this.messagesService.deleteAllByChatId(chatId); 144 | this.websocketsService.emitChatDeleted(userId, chatId); 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /src/ai/clients/google.client.ts: -------------------------------------------------------------------------------- 1 | import { GoogleGenAI } from '@google/genai'; 2 | 3 | import { Message, MessageRole } from '@/messages/schemas/message.schema'; 4 | import { 5 | AIModel, 6 | AIProviderCallbacks, 7 | AIProviderClient, 8 | AIProviderId, 9 | AIProviderOptions, 10 | } from '../interfaces/ai-provider.interface'; 11 | 12 | export class GoogleClient implements AIProviderClient { 13 | async validateKeyFormat(key: string): Promise { 14 | const models = await this.getModels(key).catch(() => []); 15 | return models.length > 0; 16 | } 17 | 18 | async getModels(key: string): Promise { 19 | const client = new GoogleGenAI({ apiKey: key }); 20 | const raw = await client.models.list(); 21 | const models: AIModel[] = []; 22 | 23 | for (const model of raw.page) { 24 | const actions = model.supportedActions; 25 | const id = model.name?.replace('models/', ''); 26 | const name = model.displayName; 27 | 28 | if (!id || 
!name || !actions) { 29 | console.log('Invalid model:', model); 30 | continue; 31 | } 32 | 33 | models.push({ 34 | id, 35 | name, 36 | author: 'Google', 37 | provider: AIProviderId.google, 38 | enabled: true, 39 | capabilities: { 40 | codeExecution: false, 41 | fileAnalysis: false, 42 | functionCalling: false, 43 | imageAnalysis: false, 44 | imageGeneration: false, 45 | textGeneration: actions.includes('generateContent'), 46 | webBrowsing: false, 47 | }, 48 | }); 49 | } 50 | 51 | return models; 52 | } 53 | 54 | async countInputTokens( 55 | key: string, 56 | modelId: string, 57 | messages: Message[], 58 | settings: AIProviderOptions 59 | ): Promise { 60 | const client = new GoogleGenAI({ apiKey: key }); 61 | 62 | const history = messages.map(message => ({ 63 | role: message.role == MessageRole.user ? 'user' : 'model', 64 | parts: message.content 65 | .filter(part => part.text) 66 | .map(part => ({ 67 | text: part.text!, 68 | })), 69 | })); 70 | 71 | const chat = client.chats.create({ 72 | model: modelId, 73 | history: history, 74 | config: { 75 | temperature: settings.temperature, 76 | maxOutputTokens: settings.maxTokens, 77 | }, 78 | }); 79 | 80 | const countTokensResponse = await client.models.countTokens({ 81 | model: modelId, 82 | contents: chat.getHistory(), 83 | }); 84 | 85 | return countTokensResponse.totalTokens || 0; 86 | } 87 | 88 | async sendMessage( 89 | key: string, 90 | modelId: string, 91 | messages: Message[], 92 | settings: AIProviderOptions, 93 | callbacks: AIProviderCallbacks 94 | ) { 95 | const client = new GoogleGenAI({ apiKey: key }); 96 | const previousMessages = [...messages]; 97 | const lastMessage = previousMessages.pop(); 98 | 99 | const history = previousMessages.map(message => ({ 100 | role: message.role == MessageRole.user ? 
'user' : 'model', 101 | parts: message.content 102 | .filter(part => part.text) 103 | .map(part => ({ 104 | text: part.text!, 105 | })), 106 | })); 107 | 108 | const chat = client.chats.create({ 109 | model: modelId, 110 | history, 111 | config: { 112 | temperature: settings.temperature, 113 | maxOutputTokens: settings.maxTokens, 114 | }, 115 | }); 116 | 117 | const response = await chat.sendMessageStream({ 118 | config: { 119 | temperature: settings.temperature, 120 | maxOutputTokens: settings.maxTokens, 121 | }, 122 | message: 123 | lastMessage?.content 124 | .filter(part => part.text) 125 | .map(part => part.text) 126 | .join('\n') || '', 127 | }); 128 | 129 | try { 130 | for await (const chunk of response) { 131 | if (!chunk.text) { 132 | continue; 133 | } 134 | 135 | void callbacks.onText(chunk.text); 136 | } 137 | 138 | void callbacks.onEnd(); 139 | } catch (error) { 140 | const message = (error.message as string) || 'Unknown error'; 141 | void callbacks.onError(message); 142 | } 143 | } 144 | 145 | async generateImage( 146 | // eslint-disable-next-line @typescript-eslint/no-unused-vars 147 | key: string, 148 | // eslint-disable-next-line @typescript-eslint/no-unused-vars 149 | modelId: string, 150 | // eslint-disable-next-line @typescript-eslint/no-unused-vars 151 | promptOrMessages: string | Message[], 152 | // eslint-disable-next-line @typescript-eslint/no-unused-vars 153 | settings: AIProviderOptions, 154 | // eslint-disable-next-line @typescript-eslint/no-unused-vars 155 | callbacks: AIProviderCallbacks 156 | ) {} 157 | } 158 | -------------------------------------------------------------------------------- /src/storage/r2WorkerClient.ts: -------------------------------------------------------------------------------- 1 | import * as jwt from 'jsonwebtoken'; 2 | import fetch from 'node-fetch'; 3 | 4 | export interface CreateResponse { 5 | success: boolean; 6 | error?: string; 7 | uploadId: string; 8 | fileId: string; 9 | totalParts: number; 10 | 
chunkSize: number; 11 | clientToken: string; 12 | } 13 | 14 | export interface CompleteResponse { 15 | success: boolean; 16 | error?: string; 17 | etag: string; 18 | size: number; 19 | } 20 | 21 | export class R2WorkerClient { 22 | constructor( 23 | private readonly url: string, 24 | private readonly secret: string 25 | ) {} 26 | 27 | private generateJwtTicket(type: string, action: string) { 28 | const now = Math.floor(Date.now() / 1000); 29 | 30 | return jwt.sign( 31 | { 32 | type, 33 | action, 34 | iat: now - 60, // Fix for clock skew 35 | exp: now + 6 * 60, // 5 minutes - 1 minute for clock skew 36 | }, 37 | this.secret, 38 | { 39 | algorithm: 'HS256', 40 | } 41 | ); 42 | } 43 | private async apiFetch( 44 | method: string, 45 | path: string, 46 | action: string, 47 | body: any, 48 | headers = {} 49 | ): Promise { 50 | const jwtTicket = this.generateJwtTicket('backend', action); 51 | const response = await fetch(`${this.url}/${path}`, { 52 | method, 53 | body: JSON.stringify(body), 54 | headers: { 55 | Authorization: `Bearer ${jwtTicket}`, 56 | 'Content-Type': 'application/json', 57 | ...headers, 58 | }, 59 | }); 60 | return response.json() as T; 61 | } 62 | 63 | public async createUpload( 64 | fileId: string, 65 | fileSize: number, 66 | fileMimeType: string 67 | ): Promise { 68 | return this.apiFetch('POST', 'upload/create', 'create', { 69 | id: fileId, 70 | fileSize: fileSize, 71 | mimeType: fileMimeType, 72 | }); 73 | } 74 | 75 | public async completeUpload( 76 | uploadId: string, 77 | fileId: string, 78 | parts: { partNumber: number; etag: string }[] 79 | ): Promise { 80 | return this.apiFetch( 81 | 'POST', 82 | 'upload/complete?fileId=' + fileId, 83 | 'complete', 84 | { 85 | uploadId, 86 | parts, 87 | } 88 | ); 89 | } 90 | 91 | public async abortUpload(uploadId: string, fileId: string): Promise { 92 | const { success } = await this.apiFetch<{ success: boolean }>( 93 | 'POST', 94 | 'upload/abort/' + uploadId, 95 | 'abort', 96 | { 97 | fileId, 98 | } 99 | ); 
100 | return success; 101 | } 102 | 103 | public async deleteFile(fileId: string): Promise { 104 | const { success } = await this.apiFetch<{ success: boolean }>( 105 | 'DELETE', 106 | 'file/' + fileId, 107 | 'delete', 108 | {} 109 | ); 110 | return success; 111 | } 112 | 113 | public async uploadFileBuffer(clientToken: string, data: Buffer) { 114 | const parts: { partNumber: number; etag: string }[] = []; 115 | const chunks: Buffer[] = []; 116 | 117 | const chunkSize = 5 * 1024 * 1024; // 5MB 118 | for (let i = 0; i < data.length; i += chunkSize) { 119 | const chunk = data.slice(i, i + chunkSize); 120 | const partNumber = i / chunkSize + 1; 121 | chunks.push(chunk); 122 | parts.push({ partNumber, etag: '' }); 123 | } 124 | 125 | for (let i = 0; i < chunks.length; i++) { 126 | const chunk = chunks[i]; 127 | const partNumber = i + 1; 128 | 129 | let endpoint = '/upload/part/' + partNumber; 130 | if (i === chunks.length - 1) { 131 | endpoint += '?isLast=true'; 132 | } 133 | 134 | const res = await fetch(`${this.url}${endpoint}`, { 135 | method: 'PUT', 136 | body: chunk, 137 | headers: { 138 | Authorization: `Bearer ${clientToken}`, 139 | }, 140 | }); 141 | 142 | const json = (await res.json()) as { etag: string }; 143 | parts[i].etag = json.etag; 144 | } 145 | 146 | return parts; 147 | } 148 | 149 | public async createAndUploadFile(fileId: string, fileMimeType: string, data: Buffer) { 150 | const size = data.byteLength; 151 | const upload = await this.createUpload(fileId, size, fileMimeType); 152 | const parts = await this.uploadFileBuffer(upload.clientToken, data); 153 | return this.completeUpload(upload.uploadId, fileId, parts); 154 | } 155 | 156 | public getUrlForFile(fileId: string): string { 157 | return `${this.url}/file/${fileId}`; 158 | } 159 | 160 | public async readFileAsString(fileId: string): Promise { 161 | const response = await fetch(`${this.url}/file/${fileId}`); 162 | return response.text(); 163 | } 164 | 165 | public async readFileAsBuffer(fileId: 
string): Promise { 166 | const response = await fetch(`${this.url}/file/${fileId}`); 167 | const arrayBuffer = await response.arrayBuffer(); 168 | return Buffer.from(arrayBuffer); 169 | } 170 | 171 | public async readAsBase64URL(fileId: string, mimeType: string): Promise { 172 | const buffer = await this.readFileAsBuffer(fileId); 173 | return `data:${mimeType};base64,${buffer.toString('base64')}`; 174 | } 175 | 176 | public async readAsBase64Buffer(fileId: string): Promise { 177 | const buffer = await this.readFileAsBuffer(fileId); 178 | return buffer.toString('base64'); 179 | } 180 | } 181 | -------------------------------------------------------------------------------- /src/branches/branches.service.ts: -------------------------------------------------------------------------------- 1 | import { BadRequestException, Injectable, NotFoundException } from '@nestjs/common'; 2 | import { InjectModel } from '@nestjs/mongoose'; 3 | import { Model, Types } from 'mongoose'; 4 | 5 | import { Chat } from '@/chats/schemas/chat.schema'; 6 | import { ApiKeysService } from '@/keys/api-key.service'; 7 | import { MessagesService } from '@/messages/messages.service'; 8 | import { WebsocketsService } from '@/websockets/websockets.service'; 9 | import { CreateBranchDto } from './dto/create-branch.dto'; 10 | import { ForkBranchDto } from './dto/fork-branch.dto'; 11 | import { UpdateBranchDto } from './dto/update-branch.dto'; 12 | import { ChatBranch, ChatBranchDocument } from './schemas/chat-branch.schema'; 13 | 14 | @Injectable() 15 | export class BranchesService { 16 | constructor( 17 | @InjectModel(ChatBranch.name) private branchModel: Model, 18 | @InjectModel(Chat.name) private chatModel: Model, 19 | private readonly messageService: MessagesService, 20 | private readonly apiKeyService: ApiKeysService, 21 | private readonly websocketsService: WebsocketsService 22 | ) {} 23 | 24 | async create( 25 | userId: string, 26 | chatId: string | Chat, 27 | data: CreateBranchDto, 28 | 
messageCount = 0,
        isActive = true
    ): Promise {
        if (!Types.ObjectId.isValid(userId)) {
            throw new BadRequestException('Invalid user id');
        }

        const chat =
            typeof chatId === 'string' ? await this.chatModel.findById(chatId).exec() : chatId;

        if (!chat) {
            throw new NotFoundException('Chat not found');
        }

        if (chat.userId.toString() !== userId) {
            throw new NotFoundException('Chat not found');
        }

        const { parentBranchId, ...rest } = data;

        const branch = new this.branchModel({
            userId: chat.userId,
            chatId: chat._id,
            parentBranchId: parentBranchId ? new Types.ObjectId(parentBranchId) : undefined,
            messageCount,
            isActive,
            ...rest,
        });

        const saved = await branch.save();
        this.websocketsService.emitBranchCreated(userId, saved);
        return saved;
    }

    /**
     * Loads a single branch by id.
     *
     * @param branchId - Branch id; must be a valid ObjectId string.
     * @param userId - Optional owner id. When given, a branch owned by someone
     *   else is reported exactly like a missing one.
     * @returns The branch document (active or not).
     * @throws BadRequestException if `branchId` is not a valid ObjectId.
     * @throws NotFoundException if the branch does not exist or is not owned by `userId`.
     */
    async findById(branchId: string, userId?: string): Promise {
        if (!Types.ObjectId.isValid(branchId)) {
            throw new BadRequestException('Invalid branch id');
        }

        const branch = await this.branchModel.findById(branchId).exec();
        if (!branch) {
            throw new NotFoundException('Branch not found');
        }

        // Same exception and message as the "missing" case, so a foreign id is
        // indistinguishable from a nonexistent one.
        if (userId && branch.userId?.toString() !== userId) {
            throw new NotFoundException('Branch not found');
        }

        return branch;
    }

    /**
     * Lists the active branches of a chat, oldest first.
     *
     * @param chatId - Chat id; must be a valid ObjectId string.
     * @param userId - Optional owner id; when given, only that user's branches are returned.
     * @throws BadRequestException if either id is malformed.
     */
    async findByChatId(chatId: string, userId?: string): Promise {
        if (!Types.ObjectId.isValid(chatId)) {
            throw new BadRequestException('Invalid chat id');
        }

        if (userId && !Types.ObjectId.isValid(userId)) {
            throw new BadRequestException('Invalid user id');
        }

        const query = {
            chatId: new Types.ObjectId(chatId),
            isActive: true,
            ...(userId ? { userId: new Types.ObjectId(userId) } : {}),
        };

        return this.branchModel.find(query).sort({ createdAt: 1 }).exec();
    }

    /**
     * Applies a partial update to a branch owned by `userId`, persists it, and
     * then emits a branch-updated websocket event with the saved document.
     *
     * @param branchId - Branch to update.
     * @param userId - Required; must own the branch (and any API key being attached).
     * @param updateData - Fields to overwrite on the branch.
     * @returns The saved branch.
     * @throws BadRequestException if `userId` is missing or `branchId` is malformed.
     * @throws NotFoundException if the branch is missing or not owned by `userId`.
     */
    async update(
        branchId: string,
        userId: string,
        updateData: UpdateBranchDto
    ): Promise {
        if (!userId) {
            throw new BadRequestException('User id is required');
        }

        const branch = await this.findById(branchId, userId);

        // Validate the API key the caller is trying to attach, not the one the
        // branch already has; otherwise an arbitrary apiKeyId could be written.
        // NOTE(review): relies on apiKeyService.findById throwing for keys that
        // do not exist or are not owned by userId — confirm.
        const apiKeyId = updateData.modelConfig?.apiKeyId;
        if (apiKeyId) {
            await this.apiKeyService.findById(apiKeyId, userId);
        }

        Object.assign(branch, updateData);
        const saved = await branch.save();

        // Emit only after the write succeeded, so clients never see unsaved state.
        this.websocketsService.emitBranchUpdated(userId, saved);
        return saved;
    }

    /**
     * Overwrites a branch's stored message count. Performs no ownership check.
     *
     * @returns The updated branch.
     * @throws NotFoundException if no branch has this id.
     */
    async updateMessageCount(branchId: string, count: number): Promise {
        const branch = await this.branchModel
            .findByIdAndUpdate(branchId, { messageCount: count }, { new: true })
            .exec();
        if (!branch) {
            throw new NotFoundException('Branch not found');
        }
        return branch;
    }

    /**
     * Adds one to a branch's message count using `$inc`. Performs no ownership
     * check and does nothing if the branch does not exist.
     */
    async incrementMessageCount(branchId: string): Promise {
        await this.branchModel.findByIdAndUpdate(branchId, { $inc: { messageCount: 1 } }).exec();
    }

    /**
     * Deletes all messages of a branch and marks the branch itself inactive.
     *
     * @param userId - Optional owner id. When given, ownership is enforced and a
     *   branch-deleted websocket event is emitted to that user.
     * @returns The `messageCount` stored on the branch.
     */
    async delete(branchId: string, userId?: string): Promise {
        const branch = await this.findById(branchId, userId);

        // Delete all branch messages
        await this.messageService.deleteAllByBranchId(branchId);

        // Mark branch as inactive
        branch.isActive = false;
        await branch.save();

        if (userId) {
            this.websocketsService.emitBranchDeleted(userId, branchId);
        }

        return branch.messageCount;
    }

    /**
     * Deletes every active branch of a chat owned by `userId`.
     *
     * Ownership is enforced by the initial lookup. The per-branch deletes are
     * issued without `userId`, so no per-branch websocket event is emitted.
     *
     * @returns The branches as loaded before deletion.
     * @throws BadRequestException if `userId` is missing or an id is malformed.
     */
    async deleteAllByChatId(chatId: string, userId: string): Promise {
        if (!userId) {
            throw new BadRequestException('User id is required');
        }

        const branches = await this.findByChatId(chatId, userId);

        for (const branch of branches) {
            await this.delete(branch._id.toString());
        }

        return branches;
    }

    /**
     * Creates a new branch in the same chat as `originalBranchId`, optionally
     * copying all of its messages into the new branch.
     *
     * @param userId - Must own the original branch.
     * @param originalBranchId - Branch to fork from.
     * @param payload - `name` (defaults to "Fork of <original name>") and
     *   `cloneMessages`.
     * @returns The newly created branch.
     */
    async forkBranch(
        userId: string,
        originalBranchId: string,
        payload: ForkBranchDto
    ): Promise {
        const { name, cloneMessages } = payload;

        const originalBranch = await this.findById(originalBranchId, userId);

        const newBranch = await this.create(
            userId,
            originalBranch.chatId.toString(),
            {
                name: name || `Fork of ${originalBranch.name}`,
                branchPoint: originalBranch.branchPoint + 1,
                modelConfig: originalBranch.modelConfig,
                parentBranchId: originalBranch._id.toString(),
            },
            // Only a fork that actually copies the messages inherits their
            // count; an empty fork owns no messages and must start at zero.
            cloneMessages ? originalBranch.messageCount : 0,
            originalBranch.isActive
        );

        if (cloneMessages) {
            await this.messageService.cloneAllByBranchId(
                originalBranchId,
                newBranch._id.toString()
            );
        }

        return newBranch;
    }
}

--------------------------------------------------------------------------------
/src/auth/auth.service.ts:
--------------------------------------------------------------------------------
import { Injectable, UnauthorizedException } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { JwtService } from '@nestjs/jwt';
import * as crypto from 'crypto';

import { EncryptionService } from 'src/encryption/encryption.service';
import { UsersService } from 'src/users/users.service';
import { SessionResponse } from '../sessions/schemas/session.schema';
import { SessionsService } from '../sessions/sessions.service';
import { comparePassword, User } from '../users/schemas/user.schema';
import { ChangePasswordDto } from './dto/change-password.dto';
import { LoginDto } from './dto/login.dto';
import { RegisterDto } from './dto/register.dto';
import { DeviceInfo } from './interfaces/device-info.interface';
import { AccessJwtPayload } from './interfaces/jwt-payload.interface';

@Injectable()
export class AuthService {
    constructor(
        private usersService: UsersService,
        private
sessionService: SessionsService,
        private jwtService: JwtService,
        private configService: ConfigService,
        private encryptionService: EncryptionService
    ) {}

    /**
     * Creates an account and opens its first session.
     *
     * A key pair is generated for the user's key vault. The decrypt key is
     * stored encrypted with the user's password, and the raw decrypt key is
     * returned in the session response.
     *
     * @throws Error if a user with this email already exists.
     */
    async register(registerDto: RegisterDto, deviceInfo: DeviceInfo): Promise {
        const { email, password, displayName } = registerDto;

        // Check if user already exists.
        // NOTE(review): a plain Error is not an HttpException and carries no
        // client-facing status; ConflictException would be more precise.
        const existingUser = await this.usersService.findByEmail(email);
        if (existingUser) {
            throw new Error('User with this email already exists');
        }

        // Generate encryption keys for key vault
        const { encryptKey, decryptKey } = this.encryptionService.generateKeyPair();
        const encryptedDecryptKey = this.encryptionService.encrypt(decryptKey, password);

        // Create user
        const user = await this.usersService.create({
            email,
            password,
            displayName,
            encryptKey,
            decryptKey: encryptedDecryptKey,
        });

        // Create initial session
        return this.createUserSession(user, deviceInfo, decryptKey);
    }

    /**
     * Authenticates by email and password, records the login time, and opens a
     * new session that carries the user's decrypted vault key.
     *
     * @throws UnauthorizedException with the same message for an unknown email
     *   and a wrong password.
     */
    async login(loginDto: LoginDto, deviceInfo: DeviceInfo): Promise {
        const { email, password } = loginDto;

        // Check if user exists
        const user = await this.usersService.findByEmail(email);
        if (!user) {
            throw new UnauthorizedException('Invalid Credentials');
        }

        // Validate password
        const passwordMatch = await comparePassword(password, user.password);
        if (!passwordMatch) {
            throw new UnauthorizedException('Invalid Credentials');
        }

        // Update last login
        await this.usersService.updateLastLogin(user._id.toString());

        // Decrypt key
        const decryptKey = this.encryptionService.decrypt(user.decryptKey, password);

        // Create session
        return await this.createUserSession(user, deviceInfo, decryptKey);
    }

    /**
     * Exchanges a refresh token for a new access/refresh token pair and stores
     * the new pair on the session.
     *
     * @throws UnauthorizedException for any failure: bad signature, expired
     *   token, unknown/inactive/expired session, or a lookup error.
     */
    async refreshTokens(refreshToken: string): Promise {
        try {
            // Reject forged or expired tokens up front; looking the token up in
            // the session store alone never enforced JWT_REFRESH_EXPIRATION.
            this.jwtService.verify(refreshToken, {
                secret: this.configService.get('JWT_SECRET'),
            });

            // Find session
            const session = await this.sessionService.findByRefreshToken(refreshToken);

            if (!session || !session.isActive || session.expiresAt <= new Date()) {
                throw new UnauthorizedException('Invalid or expired refresh token');
            }

            // Find user
            const user = await this.usersService.findById(session.userId.toString());

            // Generate new tokens
            const newAccessToken = this.generateAccessToken(user, session._id.toString());
            const newRefreshToken = this.generateRefreshToken(
                user._id.toString(),
                session._id.toString()
            );

            // Update session
            await this.sessionService.updateTokens(
                session._id.toString(),
                newRefreshToken,
                newAccessToken
            );

            return {
                accessToken: newAccessToken,
                refreshToken: newRefreshToken,
                user,
            };
        } catch {
            throw new UnauthorizedException('Invalid or expired refresh token');
        }
    }

    /** Revokes a single session of `userId`. */
    async logout(userId: string, sessionId: string) {
        await this.sessionService.revokeSession(userId, sessionId);
    }

    /** Revokes every session of `userId`. */
    async logoutAll(userId: string) {
        await this.sessionService.revokeAllUserSessions(userId);
    }

    /** Lists the sessions of `userId`. */
    async getUserSessions(userId: string) {
        return this.sessionService.findByUserId(userId);
    }

    /**
     * Changes the user's password and re-encrypts the key-vault decrypt key
     * with the new password, so the vault stays unlockable on the next login.
     *
     * @returns The saved user.
     * @throws UnauthorizedException if `oldPassword` does not match.
     */
    async updatePassword(
        userId: string,
        { oldPassword, newPassword }: ChangePasswordDto
    ): Promise {
        const user = await this.usersService.findById(userId);
        const matchPassword = await comparePassword(oldPassword, user.password);

        if (!matchPassword) {
            throw new UnauthorizedException('Invalid Old Password');
        }

        const oldDecryptKey = this.encryptionService.decrypt(user.decryptKey, oldPassword);
        user.decryptKey = this.encryptionService.encrypt(oldDecryptKey, newPassword);

        // The credential must change together with the vault. Otherwise the stored
        // hash still matches the OLD password while the vault now needs the NEW
        // one, and neither password can complete a login.
        // NOTE(review): assumes the User schema hashes `password` on save (the
        // same path usersService.create takes with the raw password) — confirm.
        user.password = newPassword;
        await user.save();

        return user;
    }

    /** Signs an access token carrying the user id, email and session id. */
    generateAccessToken(user: User, sessionId: string): string {
        const payload: AccessJwtPayload = {
            sub: user._id.toString(),
            email: user.email,
            sessionId,
        };

        return this.jwtService.sign(payload, {
            secret: this.configService.get('JWT_SECRET'),
            expiresIn: this.configService.get('JWT_EXPIRATION'),
        });
    }

    /** Signs a refresh token carrying the user id and session id. */
    generateRefreshToken(userId: string, sessionId: string): string {
        return this.jwtService.sign(
            { sub: userId, sessionId },
            {
                secret: this.configService.get('JWT_SECRET'),
                expiresIn: this.configService.get('JWT_REFRESH_EXPIRATION'),
            }
        );
    }

    /**
     * Issues a token pair for `user`, persists a new session (expiring in 30
     * days), and returns the tokens together with `rawDecryptKey`.
     */
    async createUserSession(
        user: User,
        deviceInfo: DeviceInfo,
        rawDecryptKey?: string
    ): Promise {
        // Generate tokens.
        // NOTE(review): this UUID goes into both tokens but is never passed to
        // createSession, so it does not match the stored session's _id (which
        // refreshTokens later embeds). Operations keyed on the token's
        // sessionId (e.g. logout) likely miss until the first refresh.
        const sessionId = crypto.randomUUID();
        const accessToken = this.generateAccessToken(user, sessionId);
        const refreshToken = this.generateRefreshToken(user._id.toString(), sessionId);

        // Calculate expiration (30 days). Note this is independent of
        // JWT_REFRESH_EXPIRATION.
        const expiresAt = new Date();
        expiresAt.setDate(expiresAt.getDate() + 30);

        // Create session
        await this.sessionService.createSession({
            userId: user._id,
            refreshToken,
            accessToken,
            deviceInfo,
            expiresAt,
        });

        return {
            accessToken,
            refreshToken,
            user,
            rawDecryptKey,
        };
    }
}

--------------------------------------------------------------------------------
/src/storage/storage.service.ts:
--------------------------------------------------------------------------------
import { BadRequestException, Injectable, NotFoundException } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { InjectModel } from '@nestjs/mongoose';
import { Model } from 'mongoose';
// eslint-disable-next-line @typescript-eslint/no-require-imports
const pdf = require('pdf-parse');

import fetch from 'node-fetch';
import {
R2WorkerClient } from './r2WorkerClient';
import { File, FileDocument, UserStorageStats } from './schemas/file.schema';

/**
 * Keeps file metadata in MongoDB and file bytes in R2 (via the R2 worker),
 * enforcing a mimetype allow-list, a per-file size cap and a per-user quota.
 */
@Injectable()
export class StorageService {
    private readonly MAX_STORAGE_PER_USER = 50 * 1024 * 1024; // 50MB per user
    private readonly MAX_FILE_SIZE = 10 * 1024 * 1024; // 10MB per file
    private readonly ALLOWED_MIMETYPES = [
        'image/jpeg',
        'image/png',
        'image/gif',
        'image/webp',
        'application/pdf',
        'text/plain',
        'application/msword',
        'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
        'audio/mpeg',
        'audio/wav',
    ];

    private readonly r2: R2WorkerClient;

    constructor(
        @InjectModel(File.name)
        private fileModel: Model,
        private configService: ConfigService
    ) {
        const workerURL = this.configService.get('R2_WORKER_URL');
        const workerSecret = this.configService.get('R2_WORKER_SECRET');

        if (!workerURL || !workerSecret) {
            throw new Error('R2_WORKER_URL and R2_WORKER_SECRET must be set');
        }

        this.r2 = new R2WorkerClient(workerURL, workerSecret);
    }

    /**
     * Calculate total storage used by a user, in bytes: the sum of `size` over
     * all of the user's file records, including uploads not yet completed.
     */
    async getUserStorageUsage(userId: string): Promise {
        // find() casts the filter to the schema type (e.g. string -> ObjectId);
        // aggregate() does not, so `$match: { userId }` could silently match
        // nothing and report zero usage. This now counts exactly the records
        // getUserFiles returns.
        const files = await this.fileModel.find({ userId }, { size: 1 }).lean().exec();
        return files.reduce((total, file) => total + (file.size ?? 0), 0);
    }

    /**
     * Get used, limit and remaining storage (bytes) for a user.
     */
    async getUserStorageStats(userId: string): Promise {
        const used = await this.getUserStorageUsage(userId);
        const limit = this.MAX_STORAGE_PER_USER;
        const remaining = Math.max(0, limit - used);
        return { used, limit, remaining };
    }

    /**
     * Validate file mimetype against the allow-list.
     *
     * @throws BadRequestException if the mimetype is not allowed.
     */
    validateMimeType(mimetype: string): void {
        if (!this.ALLOWED_MIMETYPES.includes(mimetype)) {
            throw new BadRequestException('File type not allowed');
        }
    }

    /**
     * Get file by ID, optionally restricted to files owned by `userId`.
     *
     * @throws NotFoundException if no matching file exists.
     */
    async getFileById(id: string, userId?: string): Promise {
        const query = userId ? { _id: id, userId } : { _id: id };
        const file = await this.fileModel.findOne(query).exec();

        if (!file) {
            throw new NotFoundException('File with ID ' + id + ' not found');
        }

        return file;
    }

    /**
     * Get all files of a user, newest first.
     */
    async getUserFiles(userId: string): Promise {
        return this.fileModel.find({ userId }).sort({ createdAt: -1 }).exec();
    }

    /**
     * Registers a file record and opens an R2 upload for it.
     *
     * @param fileSize - Declared size in bytes; counted against the user's quota.
     * @returns The saved record carrying the `clientToken`/`uploadId` used to
     *   upload the parts.
     * @throws BadRequestException for a disallowed mimetype, an invalid or
     *   oversized size, an exceeded quota, or a failed R2 upload creation.
     */
    async createFile(userId: string, filename: string, mimetype: string, fileSize: number) {
        this.validateMimeType(mimetype);

        // A negative size would lower the user's computed usage and bypass the
        // quota check below; byte counts must be non-negative integers.
        if (!Number.isInteger(fileSize) || fileSize < 0) {
            throw new BadRequestException('Invalid file size');
        }

        if (fileSize > this.MAX_FILE_SIZE) {
            throw new BadRequestException('File too large');
        }

        // Check if user has enough storage.
        // NOTE(review): this trusts the declared size; completeFileUpload does
        // not compare it with the bytes actually uploaded.
        const currentUsage = await this.getUserStorageUsage(userId);
        if (currentUsage + fileSize > this.MAX_STORAGE_PER_USER) {
            throw new BadRequestException('Storage quota exceeded');
        }

        // Create file
        const file = new this.fileModel({ userId, filename, mimetype, size: fileSize });
        await file.save();

        // Create file using R2
        const upload = await this.r2.createUpload(file._id, fileSize, mimetype);
        if (!upload.success) {
            await file.deleteOne();
            throw new BadRequestException('File upload failed: ' + upload.error);
        }
        file.clientToken = upload.clientToken;
        file.uploadId = upload.uploadId;
        await file.save();

        return file;
    }

    /**
     * Completes the R2 upload of a user's file and clears its upload credentials.
     *
     * @throws BadRequestException if the upload was already completed or R2 rejects it.
     * @throws NotFoundException if the file is not found for this user.
     */
    async completeFileUpload(
        userId: string,
        fileId: string,
        parts: { partNumber: number; etag: string }[]
    ): Promise {
        const file = await this.getFileById(fileId, userId);

        if (!file.clientToken || !file.uploadId) {
            throw new BadRequestException('File already completed.');
        }

        const completed = await this.r2.completeUpload(file.uploadId, fileId, parts);

        if (!completed.success) {
            throw new BadRequestException('File upload failed: ' + completed.error);
        }

        file.clientToken = undefined;
        file.uploadId = undefined;
        await file.save();

        return file;
    }

    /**
     * Delete a file from R2 and then its record.
     *
     * @throws BadRequestException if R2 deletion fails (the record is kept).
     */
    async deleteFile(fileId: string, userId?: string): Promise {
        const file = await this.getFileById(fileId, userId);
        const success = await this.r2.deleteFile(file._id);
        if (!success) {
            throw new BadRequestException('File deletion failed');
        }

        await file.deleteOne();
        return true;
    }

    /**
     * Get absolute URL for file. No ownership check.
     */
    getUrlForFile(id: string): string {
        return this.r2.getUrlForFile(id);
    }

    /**
     * Get file content as a string. No ownership check.
     */
    async readFileAsPlainText(id: string): Promise {
        return await this.r2.readFileAsString(id);
    }

    /**
     * Get file content as a buffer. No ownership check.
     */
    async readFileAsBuffer(id: string): Promise {
        return await this.r2.readFileAsBuffer(id);
    }

    /**
     * Get file content as a base64 URL for the given mimetype. No ownership check.
     */
    async readFileAsBase64URL(id: string, mimetype: string): Promise {
        return await this.r2.readAsBase64URL(id, mimetype);
    }

    /**
     * Get file content in the base64 form returned by
     * R2WorkerClient.readAsBase64Buffer. No ownership check.
     */
    async readFileAsBase64Buffer(id: string): Promise {
        return await this.r2.readAsBase64Buffer(id);
    }

    /**
     * Extract the text of a PDF file using pdf-parse.
     */
    async readFileAsPDF(id: string): Promise {
        const buffer = await this.readFileAsBuffer(id);
        const pdfText = await pdf(buffer);
        return pdfText.text;
    }

    /**
     * Downloads `url` and stores it as a new file owned by `userId`.
     *
     * @returns The completed file record.
     * @throws BadRequestException if the download returns a non-2xx status, or
     *   for any createFile/upload failure (the partial record is removed).
     */
    async uploadFromURL(url: string, name: string, mimeType: string, userId: string) {
        const res = await fetch(url);
        if (!res.ok) {
            throw new BadRequestException(`Failed to download file: HTTP ${res.status}`);
        }
        const buffer = Buffer.from(await res.arrayBuffer());

        const file = await this.createFile(userId, name, mimeType, buffer.byteLength);
        try {
            const parts = await this.r2.uploadFileBuffer(file.clientToken!, buffer);
            return await this.completeFileUpload(userId, file.id, parts);
        } catch (error) {
            // Don't leave a half-created record counting against the user's quota.
            await file.deleteOne();
            throw error;
        }
    }
}

--------------------------------------------------------------------------------
/src/messages/messages.service.ts:
--------------------------------------------------------------------------------
import { BadRequestException, Injectable, NotFoundException } from '@nestjs/common';
import { InjectModel } from '@nestjs/mongoose';
import { Model, RootFilterQuery, Types } from 'mongoose';

import { ChatBranch } from '@/branches/schemas/chat-branch.schema';
import { WebsocketsService } from '@/websockets/websockets.service';
import { GetMessagesDto } from './dto/get-messages.dto';
import { UpdateMessageDto } from './dto/update-message.dto';
import { Message, MessageDocument, MessagesResponse } from './schemas/message.schema';

@Injectable()
export class MessagesService {
    constructor(
        @InjectModel(Message.name) private messageModel: Model,
        @InjectModel(ChatBranch.name) private branchModel: Model,
        private websocketsService: WebsocketsService
    ) {}

    /**
     * Appends a message to the end of a branch (index = last index + 1).
     *
     * @param messageData - Message fields; `branchId` must be a valid ObjectId.
     * @param userId - When given, the message is also sent as a `message:new`
     *   event through websocketsService.emitToBranch.
     * @throws BadRequestException if `branchId` is invalid.
     */
    async create(
        messageData: Omit,
        userId?: string
    ): Promise {
        const { branchId, chatId, content, role, attachments, modelUsed, tokens } =
messageData;

        if (!Types.ObjectId.isValid(branchId)) {
            throw new BadRequestException('Invalid branch id');
        }

        // Get next index for this branch.
        // NOTE(review): read-then-write, not atomic; concurrent creates on the
        // same branch can receive the same index.
        const nextIndex = await this.getNextMessageIndex(branchId.toString());

        const message = new this.messageModel({
            attachments,
            branchId: new Types.ObjectId(branchId),
            chatId,
            content,
            metadata: {},
            role,
            index: nextIndex,
            modelUsed,
            tokens,
            isEdited: false,
        });

        const saved = await message.save();

        // Notify only after the write succeeded, and send the persisted
        // document, so clients never receive a message that failed to save.
        if (userId) {
            this.websocketsService.emitToBranch(userId, branchId.toString(), 'message:new', saved);
        }

        return saved;
    }

    /**
     * Loads a message by id. No ownership check.
     *
     * @throws BadRequestException if the id is malformed.
     * @throws NotFoundException if no message has this id.
     */
    async findById(messageId: string): Promise {
        if (!Types.ObjectId.isValid(messageId)) {
            throw new BadRequestException('Invalid message id');
        }

        const message = await this.messageModel.findById(messageId).exec();
        if (!message) {
            throw new NotFoundException('Message not found');
        }

        return message;
    }

    /**
     * Pages through a branch's messages in index order.
     *
     * @param options - `branchId`, plus `limit` (default 50), `offset`
     *   (default 0) and an optional `fromIndex` lower bound on `index`.
     * @param userId - When given, the branch must belong to this user.
     * @returns The page, the total number of matching messages, and whether
     *   more messages follow this page.
     * @throws BadRequestException if `branchId` is malformed.
     * @throws NotFoundException if the branch is missing or not owned by `userId`.
     */
    async findByBranchId(options: GetMessagesDto, userId?: string): Promise {
        if (!Types.ObjectId.isValid(options.branchId)) {
            throw new BadRequestException('Invalid branch id');
        }

        const branch = await this.branchModel.findById(options.branchId).exec();
        if (!branch) {
            throw new NotFoundException('Branch not found');
        }

        if (userId && branch.userId?.toString() !== userId) {
            throw new NotFoundException('Branch not found');
        }

        const { limit = 50, offset = 0, fromIndex } = options;

        const query: RootFilterQuery = {
            branchId: new Types.ObjectId(options.branchId),
        };

        if (fromIndex !== undefined) {
            query.index = { $gte: fromIndex };
        }

        const [messages, total] = await Promise.all([
            this.messageModel.find(query).sort({ index: 1 }).skip(offset).limit(limit).exec(),

            this.messageModel.countDocuments(query),
        ]);

        return {
            messages,
            total,
            // `total` counts every match, so the rows skipped by `offset` must be
            // included; otherwise the last page reports hasMore whenever offset > 0.
            hasMore: offset + messages.length < total,
        };
    }

    /**
     * Replaces a message's content with a single text part. On the first edit
     * the previous content is kept in `originalContent`.
     *
     * @throws NotFoundException if the message is missing or its branch is not
     *   owned by `userId`.
     */
    async update(
        messageId: string,
        updateData: UpdateMessageDto,
        userId: string
    ): Promise {
        const message = await this.findById(messageId);

        // Get branch
        const branch = await this.branchModel.findById(message.branchId.toString());

        // Permission check
        if (!branch || branch.userId.toString() !== userId) {
            throw new NotFoundException('Message not found');
        }

        if (!message.isEdited) {
            message.originalContent = message.content;
        }

        message.content = [
            {
                type: 'text',
                text: updateData.content,
            },
        ];
        message.isEdited = true;
        message.editedAt = new Date();

        return await message.save();
    }

    /**
     * Deletes a message and renumbers the remaining messages of its branch.
     * Does not change the branch's stored messageCount.
     *
     * @throws NotFoundException if the message is missing or its branch is not
     *   owned by `userId`.
     */
    async delete(messageId: string, userId: string): Promise {
        const message = await this.findById(messageId);

        // Get branch
        const branch = await this.branchModel.findById(message.branchId.toString());

        // Permission check
        if (!branch || branch.userId.toString() !== userId) {
            throw new NotFoundException('Message not found');
        }

        await this.messageModel.findByIdAndDelete(messageId).exec();
        await this.reindexBranchMessages(message.branchId.toString());
        return true;
    }

    /** Deletes every message of a branch. No ownership check. */
    async deleteAllByBranchId(branchId: string) {
        await this.messageModel.deleteMany({ branchId: new Types.ObjectId(branchId) }).exec();
    }

    /** Deletes every message of a chat. No ownership check. */
    async deleteAllByChatId(chatId: string) {
        await this.messageModel.deleteMany({ chatId: new Types.ObjectId(chatId) }).exec();
    }

    /**
     * Renumbers a branch's messages to consecutive indexes 0..n-1, keeping
     * their current order.
     */
    async reindexBranchMessages(branchId: string) {
        // Only the ids are needed to issue the updates; the sort runs server-side.
        const messages = await this.messageModel
            .find({ branchId: new Types.ObjectId(branchId) })
            .sort({ index: 1 })
            .select('_id')
            .exec();

        const updates = messages.map((message, i) => ({
            updateOne: {
                filter: { _id: message._id },
                update: { index: i },
            },
        }));

        if (updates.length > 0) {
            await this.messageModel.bulkWrite(updates);
        }
    }

    /**
     * Returns the highest-index message of a branch, or null if it has none.
     *
     * @throws BadRequestException if `branchId` is malformed.
     */
    async getLastMessage(branchId: string): Promise {
        if (!Types.ObjectId.isValid(branchId)) {
            throw new BadRequestException('Invalid branch id');
        }

        return this.messageModel
            .findOne({ branchId: new Types.ObjectId(branchId) })
            .sort({ index: -1 })
            .exec();
    }

    /** Index the next appended message should get (0 for an empty branch). */
    async getNextMessageIndex(branchId: string): Promise {
        const lastMessage = await this.getLastMessage(branchId);
        return lastMessage ? lastMessage.index + 1 : 0;
    }

    /**
     * Copies every message of `branchId` into `newBranchId`. Copies get fresh
     * `_id`s and otherwise keep their fields, including `index`.
     */
    async cloneAllByBranchId(branchId: string, newBranchId: string) {
        const messages = await this.messageModel
            .find({ branchId: new Types.ObjectId(branchId) })
            .lean();

        const clonedMessages = messages.map(msg => {
            // eslint-disable-next-line @typescript-eslint/no-unused-vars
            const { _id, branchId, ...rest } = msg;

            return {
                ...rest,
                chatId: new Types.ObjectId(msg.chatId),
                branchId: new Types.ObjectId(newBranchId),
            };
        });

        await this.messageModel.insertMany(clonedMessages);
    }
}

--------------------------------------------------------------------------------
/src/schema.gql:
--------------------------------------------------------------------------------
# ------------------------------------------------------
# THIS FILE WAS AUTOMATICALLY GENERATED (DO NOT MODIFY)
# ------------------------------------------------------

type AIModel {
  author: String!
  capabilities: AIModelCapabilities!
  category: String
  cost: String
  description: String
  enabled: Boolean
  id: String!
  name: String!
  provider: String!
  speed: String
}

type AIModelCapabilities {
  codeExecution: Boolean!
  fileAnalysis: Boolean!
  functionCalling: Boolean!
  imageAnalysis: Boolean!
  imageGeneration: Boolean!
  textGeneration: Boolean!
  webBrowsing: Boolean!
}

enum AIProviderId {
  anthropic
  google
  openai
  openrouter
}

input AddMessageDto {
  apiKeyId: String!
  attachments: [String!]
  branchId: String!
  modelId: String!
  prompt: String!
  rawDecryptKey: String!
  useImageTool: Boolean
}

type ApiKey {
  _id: String!
  alias: String!
  lastRotated: DateTime
  lastUsed: DateTime
  lastValidated: DateTime
  provider: String!
}

input ChangePasswordDto {
  newPassword: String!
  oldPassword: String!
}

type Chat {
  _id: String!
  archived: Boolean!
  defaultBranch: ChatBranch
  isPublic: Boolean!
  lastActivityAt: DateTime!
  pinned: Boolean!
  title: String!
}

type ChatBranch {
  _id: String!
  branchPoint: Float!
  isActive: Boolean!
  messageCount: Float!
  modelConfig: ModelConfig
  name: String!
  parentBranchId: ChatBranch
}

type ChatsResponse {
  chats: [Chat!]!
  hasMore: Boolean!
  total: Float!
}

input CompleteFileDto {
  fileId: String!
  parts: [FilePart!]!
}

input CreateApiKeyDto {
  alias: String!
  apiKey: String!
  provider: AIProviderId!
}

input CreateFileDto {
  filename: String!
  mimetype: String!
  size: Float!
}

"""
A date-time string at UTC, such as 2019-12-03T09:54:33Z, compliant with the date-time format.
"""
scalar DateTime

type DeviceInfo {
  browser: String
  ip: String!
  platform: String
  userAgent: String!
}

type File {
  _id: String!
  clientToken: String
  createdAt: DateTime!
  filename: String!
  mimetype: String!
  size: Float!
  uploadId: String
}

input FilePart {
  etag: String!
  partNumber: Float!
}

input ForkBranchDto {
  cloneMessages: Boolean
  name: String
}

input GetChatDto {
  chatId: String!
}

input GetManyChatsDto {
  archived: Boolean = false
  limit: Float = 20
  offset: Float = 0
  search: String
}

input GetMessagesDto {
  branchId: String!
  fromIndex: Float = 0
  limit: Float = 50
  offset: Float = 0
}

input LoginDto {
  email: String!
  password: String!
}

type Message {
  _id: String!
  attachments: [ID!]!
  branchId: String!
  chatId: String!
  content: [MessageContent!]!
  createdAt: DateTime!
  editedAt: DateTime
  index: Float!
  isEdited: Boolean!
  modelUsed: String
  originalContent: [MessageContent!]
  role: String!
  tokens: Float
}

type MessageContent {
  id: String
  name: String
  text: String
  tool_use_id: String
  type: String!
}

type MessagesResponse {
  hasMore: Boolean!
  messages: [Message!]!
  total: Float!
}

type ModelConfig {
  apiKeyId: String
  maxTokens: Float
  modelId: String
  temperature: Float
}

type Mutation {
  addApiKey(payload: CreateApiKeyDto!): ApiKey!
  completeFile(payload: CompleteFileDto!): File!
  createChat: Chat!
  createFile(payload: CreateFileDto!): File!
  createPreferences: UserPreferences!
  deleteApiKey(id: String!): Boolean!
  deleteFile(id: String!): Boolean!
  forkBranch(originalBranchId: String!, payload: ForkBranchDto!): ChatBranch!
  login(payload: LoginDto!): SessionResponse!
  logout: Boolean!
  refreshToken(refreshToken: String!): SessionResponse!
  register(payload: RegisterDto!): SessionResponse!
  revokeAllSessions: Boolean!
  revokeSession(sessionId: String!): Boolean!
  sendMessage(payload: AddMessageDto!): Message!
  updateApiKey(id: String!, payload: UpdateApiKeyDto!): ApiKey!
  updateBranch(branchId: String!, payload: UpdateBranchDto!): ChatBranch!
  updateChat(id: String!, payload: UpdateChatDto!): Chat!
  updatePassword(payload: ChangePasswordDto!): User!
  updatePreferences(payload: UpdatePreferencesDto!): UserPreferences!
  updateUser(payload: UpdateUserDto!): User!
}

type PublicChatResponse {
  chat: Chat!
  messages: [Message!]!
}

type Query {
  getApiKeys: [ApiKey!]!
  getAvailableModels(rawDecryptKey: String!): [AIModel!]!
  getChat(query: GetChatDto!): SingleChatResponse!
  getChatBranches(chatId: String!): [ChatBranch!]!
  getChatMessages(query: GetMessagesDto!): MessagesResponse!
  getChats(query: GetManyChatsDto!): ChatsResponse!
  getFileById(id: String!): File!
  getPreferences: UserPreferences!
  getPublicChat(query: GetChatDto!): PublicChatResponse!
  getSessions: [Session!]!
  getUser: User!
  getUserFiles: [File!]!
  getUserStorageStats: UserStorageStats!
}

input RegisterDto {
  displayName: String!
  email: String!
  password: String!
}

type Session {
  _id: String!
  deviceInfo: DeviceInfo!
  expiresAt: DateTime!
  isActive: Boolean!
  lastUsedAt: DateTime!
}

type SessionResponse {
  accessToken: String!
  rawDecryptKey: String
  refreshToken: String!
  user: User
}

type SingleChatResponse {
  branches: [ChatBranch!]!
  chat: Chat!
  totalMessages: Float!
}

input UpdateApiKeyDto {
  alias: String
  isActive: Boolean
}

input UpdateBranchDto {
  modelConfig: UpdateBranchModelConfigDto
  name: String
}

input UpdateBranchModelConfigDto {
  apiKeyId: String
  maxTokens: Float
  modelId: String
  temperature: Float
}

input UpdateChatDto {
  archived: Boolean
  isPublic: Boolean
  pinned: Boolean
  title: String
}

input UpdatePreferencesDto {
  dateFormat: String
  language: String
  showSidebar: Boolean
  showTimestamps: Boolean
  theme: String
  use24HourFormat: Boolean
}

input UpdateUserDto {
  displayName: String
  email: String
}

type User {
  _id: String!
  createdAt: DateTime!
  decryptKey: String!
  displayName: String!
  email: String!
  emailVerified: Boolean!
  encryptKey: String!
  preferences: UserPreferences
  updatedAt: DateTime!
}

type UserPreferences {
  _id: String!
  dateFormat: String
  language: String
  showSidebar: Boolean
  showTimestamps: Boolean
  theme: String
  use24HourFormat: Boolean
}

type UserStorageStats {
  limit: Float!
  remaining: Float!
  used: Float!
}

--------------------------------------------------------------------------------
/src/websockets/websockets.gateway.ts:
--------------------------------------------------------------------------------
import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { JwtService } from '@nestjs/jwt';
import {
    ConnectedSocket,
    MessageBody,
    OnGatewayConnection,
    OnGatewayDisconnect,
    OnGatewayInit,
    SubscribeMessage,
    WebSocketGateway,
    WebSocketServer,
} from '@nestjs/websockets';
import { Namespace, Server, Socket } from 'socket.io';

import { AccessJwtPayload } from '@/auth/interfaces/jwt-payload.interface';
import { RateLimit } from './decorators/rate-limit.decorator';
import { WebsocketsService } from './websockets.service';

/**
 * Socket.IO gateway on the `ws` namespace. It authenticates each connection
 * from its handshake token and lets clients join or leave branch rooms.
 */
@Injectable()
@WebSocketGateway({
    namespace: 'ws',
    transports: ['websocket', 'polling'],
})
export class WebsocketGateway implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect {
    private readonly logger = new Logger(WebsocketGateway.name);

    @WebSocketServer()
    server: Server;

    constructor(
        private websocketService: WebsocketsService,
        private jwtService: JwtService,
        private configService: ConfigService
    ) {}

    afterInit(server: Server) {
        this.websocketService.setServer(server);
        this.configureCors(server);
        this.logger.log('WebSocket Gateway initialized');
    }

    /**
     * Origins allowed to open cross-origin sessions.
     *
     * Reads `CORS_ALLOWED_ORIGINS` (an array, or a comma-separated string as
     * environment variables are) and falls back to `CORS_ORIGIN`. Entries are
     * trimmed and empty ones dropped, so matching is exact per origin; calling
     * `includes` on an unsplit string would instead be a substring test.
     */
    private getAllowedOrigins(): string[] {
        const configured =
            this.configService.get('CORS_ALLOWED_ORIGINS') || this.configService.get('CORS_ORIGIN');
        const entries: unknown[] = Array.isArray(configured)
            ? configured
            : typeof configured === 'string'
              ? configured.split(',')
              : [];
        return entries
            .filter((entry): entry is string => typeof entry === 'string')
            .map(entry => entry.trim())
            .filter(entry => entry.length > 0);
    }

    /**
     * Reflects an allowed request `Origin` into the CORS response headers of
     * the Engine.IO HTTP responses.
     */
    private configureCors(server: Server) {
        const allowedOrigins = this.getAllowedOrigins();

        // For a gateway declared with `namespace`, NestJS hands afterInit the
        // Namespace rather than the root Server. Namespaces never emit
        // engine-level HTTP events, so the hook goes on the root Engine.IO server.
        const root = server instanceof Namespace ? server.server : server;
        const engine = root.engine;
        if (!engine) {
            this.logger.warn('Engine.IO server not available; WebSocket CORS headers not set');
            return;
        }

        // 'headers' fires for every HTTP response of a session (not only the
        // first), which long-polling needs on each request.
        engine.on(
            'headers',
            (headers: Record<string, string>, req: { headers: { origin?: string } }) => {
                const origin = req.headers.origin;
                if (origin && allowedOrigins.includes(origin)) {
                    headers['Access-Control-Allow-Origin'] = origin;
                    headers['Access-Control-Allow-Credentials'] = 'true';
                    // The response varies by Origin; keep caches from reusing it.
                    headers['Vary'] = 'Origin';
                }
            }
        );
    }

    /**
     * Registers the socket and authenticates it from its handshake token.
     * It emits `auth_success` on success. It emits `connection:error` or
     * `auth_error` and disconnects when the token is missing or invalid.
     */
    handleConnection(client: Socket) {
        try {
            this.logger.log(`Client connected: ${client.id}`);

            // Add client to service
            this.websocketService.addClient(client.id, client);

            // Extract token from handshake
            const token = this.extractTokenFromHandshake(client);

            if (!token) {
                this.logger.warn(`No token provided for socket ${client.id}`);
                client.emit('connection:error', { message: 'No authentication token provided' });
                client.disconnect();
                return;
            }

            // Validate token
            try {
                const payload = this.validateToken(token);

                // Associate socket with user through service
                this.websocketService.associateUserWithSocket(payload.sub, client.id);

                // Send success message
                client.emit('auth_success', {
                    message: 'Successfully connected to WebSocket server',
                    userId: payload.sub,
                });

                this.logger.log(
                    `Authenticated user ${payload.sub} connected with socket ${client.id}`
                );
            } catch (error) {
                this.logger.error(`Authentication failed for socket ${client.id}:`, error.message);
                client.emit('auth_error', { message: 'Authentication failed' });
                client.disconnect();
            }
        } catch (error) {
            this.logger.error(`Error handling connection:`, error);
            client.disconnect();
        }
    }

    /** Removes the socket from the service, which performs all cleanup. */
    handleDisconnect(client: Socket) {
        try {
            this.logger.log(`Client disconnected: ${client.id}`);

            // Remove client through service (handles all cleanup)
            this.websocketService.removeClient(client.id);
        } catch (error) {
            this.logger.error(`Error handling disconnection:`, error);
        }
    }

    // ===================
    // EVENT HANDLERS
    // ===================

    /**
     * Joins the caller's socket to a branch room via the service.
     * NOTE(review): TODO.md lists "Check branch/chat owner" as pending; confirm
     * joinBranchRoom verifies that the socket's user owns the branch.
     *
     * @returns The service result, or `{ success: false, error }` on failure.
     */
    @SubscribeMessage('join-branch')
    @RateLimit(10, 10) // 10 requests per 10 seconds
    async handleJoinBranch(
        @ConnectedSocket() client: Socket,
        @MessageBody() data: { branchId: string }
    ) {
        try {
            if (!data.branchId) {
                throw new Error('Branch ID is required');
            }

            // Delegate to service
            const result = await this.websocketService.joinBranchRoom(client.id, data.branchId);

            return result;
        } catch (error) {
            this.logger.error(`Error in join-branch handler:`, error.message);
            return {
                success: false,
                error: error.message,
            };
        }
    }

    /**
     * Removes the caller's socket from a branch room via the service.
     *
     * @returns The service result, or `{ success: false, error }` on failure.
     */
    @SubscribeMessage('leave-branch')
    @RateLimit(10, 10)
    async handleLeaveBranch(
        @ConnectedSocket() client: Socket,
        @MessageBody() data: { branchId: string }
    ) {
        try {
            if (!data.branchId) {
                throw new Error('Branch ID is required');
            }

            // Delegate to service
            const result = await this.websocketService.leaveBranchRoom(client.id, data.branchId);

            return result;
        } catch (error) {
            this.logger.error(`Error in leave-branch handler:`, error.message);
            return {
                success: false,
                error: error.message,
            };
        }
    }

    // ===================
    // UTILITY METHODS (moved to service, kept for backward compatibility)
    // ===================

    /**
     * @deprecated Use websocketService.getConnectedUserCount() instead
     */
    getConnectedUserCount(): number {
        return this.websocketService.getConnectedUserCount();
    }

    /**
     * @deprecated Use websocketService.isUserConnected() instead
     */
    isUserConnected(userId: string): boolean {
        return this.websocketService.isUserConnected(userId);
    }

    // ===================
    // PRIVATE METHODS
    // ===================

    private extractTokenFromHandshake(client: Socket): string | null {
        // Try to get token from handshake
auth 190 | const auth = client.handshake.auth; 191 | if (auth && auth.token) { 192 | return auth.token; 193 | } 194 | 195 | // Try to get token from query params as fallback 196 | const query = client.handshake.query; 197 | if (query && query.token) { 198 | return query.token as string; 199 | } 200 | 201 | // Try to get token from headers as fallback 202 | const headers = client.handshake.headers; 203 | if (headers.authorization) { 204 | const parts = headers.authorization.split(' '); 205 | if (parts.length === 2 && parts[0] === 'Bearer') { 206 | return parts[1]; 207 | } 208 | } 209 | 210 | return null; 211 | } 212 | 213 | private validateToken(token: string): AccessJwtPayload { 214 | try { 215 | // Verify token 216 | const payload = this.jwtService.verify(token, { 217 | secret: this.configService.get('JWT_SECRET'), 218 | }); 219 | 220 | return payload; 221 | } catch { 222 | throw new UnauthorizedException('Invalid token'); 223 | } 224 | } 225 | } 226 | -------------------------------------------------------------------------------- /src/encryption/encryption.service.ts: -------------------------------------------------------------------------------- 1 | import { Injectable } from '@nestjs/common'; 2 | import { ConfigService } from '@nestjs/config'; 3 | import * as crypto from 'crypto'; 4 | 5 | @Injectable() 6 | export class EncryptionService { 7 | private readonly algorithm = 'aes-256-gcm'; 8 | private readonly masterIv: Buffer; 9 | 10 | constructor(config: ConfigService) { 11 | // Initialize encryption keys from environment variables 12 | const encryptionIv = config.get('ENCRYPTION_IV'); 13 | 14 | if (!encryptionIv || encryptionIv.length !== 32) { 15 | throw new Error('ENCRYPTION_IV must be 16-byte (32 hex chars) key'); 16 | } 17 | 18 | // Convert hex strings to buffers 19 | this.masterIv = Buffer.from(encryptionIv, 'hex'); 20 | } 21 | 22 | /** 23 | * Encrypts text using AES-256-GCM 24 | * @param text The text to encrypt 25 | * @param userKey The user's key (raw 
string, will be derived) 26 | * @returns Encrypted text as a hex string with IV and auth tag prepended 27 | */ 28 | encrypt(text: string, userKey: string): string { 29 | try { 30 | // Generate a random IV for each encryption 31 | const iv = crypto.randomBytes(16); 32 | 33 | // Derive a key from the user's key 34 | const derivedKey = this.deriveKey(userKey); 35 | 36 | // Create cipher 37 | const cipher = crypto.createCipheriv(this.algorithm, derivedKey, iv); 38 | 39 | // Encrypt the text 40 | let encrypted = cipher.update(text, 'utf-8', 'hex'); 41 | encrypted += cipher.final('hex'); 42 | 43 | // Get the auth tag 44 | const authTag = cipher.getAuthTag(); 45 | 46 | // Combine IV, auth tag, and encrypted text 47 | // Format: iv(32 hex) + authTag(32 hex) + encrypted(variable hex) 48 | return iv.toString('hex') + authTag.toString('hex') + encrypted; 49 | } catch (error) { 50 | console.error('Encryption error:', error); 51 | throw new Error('Encryption failed'); 52 | } 53 | } 54 | 55 | /** 56 | * Decrypts text using AES-256-GCM 57 | * @param encryptedText The encrypted text (hex string with IV and auth tag) 58 | * @param userKey The user's key (raw string, will be derived) 59 | * @returns Decrypted text 60 | */ 61 | decrypt(encryptedText: string, userKey: string): string { 62 | try { 63 | // Validate minimum length (32 + 32 + at least some encrypted data) 64 | if (encryptedText.length < 66) { 65 | throw new Error('Invalid encrypted text format'); 66 | } 67 | 68 | // Extract IV (first 32 hex chars = 16 bytes) 69 | const iv = Buffer.from(encryptedText.slice(0, 32), 'hex'); 70 | 71 | // Extract auth tag (next 32 hex chars = 16 bytes) 72 | const authTag = Buffer.from(encryptedText.slice(32, 64), 'hex'); 73 | 74 | // Extract encrypted text (remaining chars) 75 | const encrypted = encryptedText.slice(64); 76 | 77 | // Derive key from the user's key 78 | const derivedKey = this.deriveKey(userKey); 79 | 80 | // Create decipher 81 | const decipher = 
crypto.createDecipheriv(this.algorithm, derivedKey, iv); 82 | decipher.setAuthTag(authTag); 83 | 84 | // Decrypt the text 85 | let decrypted = decipher.update(encrypted, 'hex', 'utf-8'); 86 | decrypted += decipher.final('utf-8'); 87 | 88 | return decrypted; 89 | } catch (error) { 90 | console.error('Decryption error:', error); 91 | throw new Error('Decryption failed'); 92 | } 93 | } 94 | 95 | /** 96 | * Derives a key from user input using PBKDF2 97 | * @param userKey The user's raw key 98 | * @returns Derived key buffer (32 bytes) 99 | */ 100 | private deriveKey(userKey: string): Buffer { 101 | return crypto.pbkdf2Sync( 102 | userKey, 103 | this.masterIv, // Using masterIv as salt 104 | 100000, // Increased iterations for better security 105 | 32, 106 | 'sha256' 107 | ); 108 | } 109 | 110 | /** 111 | * Generates a new pair of RSA encryption keys (asymmetric) 112 | * @returns Object containing public key (for encryption) and private key (for decryption) 113 | */ 114 | generateKeyPair(): { encryptKey: string; decryptKey: string } { 115 | try { 116 | // Generate RSA key pair (2048 bits for good security/performance balance) 117 | const { publicKey, privateKey } = crypto.generateKeyPairSync('rsa', { 118 | modulusLength: 2048, 119 | publicKeyEncoding: { 120 | type: 'spki', 121 | format: 'pem', 122 | }, 123 | privateKeyEncoding: { 124 | type: 'pkcs8', 125 | format: 'pem', 126 | }, 127 | }); 128 | 129 | return { 130 | encryptKey: publicKey, // Public key for encryption 131 | decryptKey: privateKey, // Private key for decryption 132 | }; 133 | } catch (error) { 134 | console.error('Key pair generation error:', error); 135 | throw new Error('Key pair generation failed'); 136 | } 137 | } 138 | 139 | /** 140 | * Encrypts text using RSA public key (asymmetric encryption) 141 | * @param text The text to encrypt 142 | * @param publicKey The RSA public key in PEM format 143 | * @returns Encrypted text as base64 string 144 | */ 145 | encryptWithKey(text: string, publicKey: 
string): string { 146 | try { 147 | // RSA has size limitations, so for large texts we use hybrid encryption 148 | // For small texts (< 190 bytes with 2048-bit key), use direct RSA 149 | const textBuffer = Buffer.from(text, 'utf-8'); 150 | 151 | if (textBuffer.length <= 190) { 152 | // Direct RSA encryption for small texts 153 | const encrypted = crypto.publicEncrypt( 154 | { 155 | key: publicKey, 156 | padding: crypto.constants.RSA_PKCS1_OAEP_PADDING, 157 | oaepHash: 'sha256', 158 | }, 159 | textBuffer 160 | ); 161 | return encrypted.toString('base64'); 162 | } else { 163 | // Hybrid encryption for larger texts 164 | // Generate random AES key 165 | const aesKey = crypto.randomBytes(32); 166 | const iv = crypto.randomBytes(16); 167 | 168 | // Encrypt text with AES 169 | const cipher = crypto.createCipheriv('aes-256-gcm', aesKey, iv); 170 | let encryptedText = cipher.update(text, 'utf-8', 'hex'); 171 | encryptedText += cipher.final('hex'); 172 | const authTag = cipher.getAuthTag(); 173 | 174 | // Encrypt AES key with RSA 175 | const encryptedAesKey = crypto.publicEncrypt( 176 | { 177 | key: publicKey, 178 | padding: crypto.constants.RSA_PKCS1_OAEP_PADDING, 179 | oaepHash: 'sha256', 180 | }, 181 | aesKey 182 | ); 183 | 184 | // Combine: encryptedAesKey + iv + authTag + encryptedText 185 | const result = { 186 | key: encryptedAesKey.toString('base64'), 187 | iv: iv.toString('hex'), 188 | authTag: authTag.toString('hex'), 189 | data: encryptedText, 190 | }; 191 | 192 | return Buffer.from(JSON.stringify(result)).toString('base64'); 193 | } 194 | } catch (error) { 195 | console.error('RSA encryption error:', error); 196 | throw new Error('RSA encryption failed'); 197 | } 198 | } 199 | 200 | /** 201 | * Decrypts text using RSA private key (asymmetric decryption) 202 | * @param encryptedText The encrypted text (base64 string) 203 | * @param privateKey The RSA private key in PEM format 204 | * @returns Decrypted text 205 | */ 206 | decryptWithKey(encryptedText: string, 
privateKey: string): string { 207 | try { 208 | const encryptedBuffer = Buffer.from(encryptedText, 'base64'); 209 | 210 | // Try direct RSA decryption first 211 | try { 212 | const decrypted = crypto.privateDecrypt( 213 | { 214 | key: privateKey, 215 | padding: crypto.constants.RSA_PKCS1_OAEP_PADDING, 216 | oaepHash: 'sha256', 217 | }, 218 | encryptedBuffer 219 | ); 220 | return decrypted.toString('utf-8'); 221 | } catch { 222 | // If direct decryption fails, try hybrid decryption 223 | const hybridData = JSON.parse(encryptedBuffer.toString('utf-8')); 224 | 225 | // Decrypt AES key with RSA 226 | const encryptedAesKey = Buffer.from(hybridData.key, 'base64'); 227 | const aesKey = crypto.privateDecrypt( 228 | { 229 | key: privateKey, 230 | padding: crypto.constants.RSA_PKCS1_OAEP_PADDING, 231 | oaepHash: 'sha256', 232 | }, 233 | encryptedAesKey 234 | ); 235 | 236 | // Decrypt text with AES 237 | const iv = Buffer.from(hybridData.iv, 'hex'); 238 | const authTag = Buffer.from(hybridData.authTag, 'hex'); 239 | 240 | const decipher = crypto.createDecipheriv('aes-256-gcm', aesKey, iv); 241 | decipher.setAuthTag(authTag); 242 | 243 | let decrypted = decipher.update(hybridData.data, 'hex', 'utf-8'); 244 | decrypted += decipher.final('utf-8'); 245 | 246 | return decrypted; 247 | } 248 | } catch (error) { 249 | console.error('RSA decryption error:', error); 250 | throw new Error('RSA decryption failed'); 251 | } 252 | } 253 | 254 | /** 255 | * Validates the integrity of a user's encryption keys 256 | * @param encryptKey Public key for encryption 257 | * @param decryptKey Private key for decryption 258 | * @returns Boolean indicating if the keys work together 259 | */ 260 | validateKeyIntegrity(encryptKey: string, decryptKey: string): boolean { 261 | try { 262 | // Test encryption/decryption with the keys 263 | const testMessage = 'encryption-test-' + Date.now(); 264 | const encrypted = this.encryptWithKey(testMessage, encryptKey); 265 | const decrypted = 
this.decryptWithKey(encrypted, decryptKey); 266 | 267 | return testMessage === decrypted; 268 | } catch (error) { 269 | console.error('Key validation error:', error); 270 | return false; 271 | } 272 | } 273 | } 274 | -------------------------------------------------------------------------------- /src/chats/chats.resolver.ts: -------------------------------------------------------------------------------- 1 | import { NotFoundException, UseGuards } from '@nestjs/common'; 2 | import { Args, Mutation, Query, Resolver } from '@nestjs/graphql'; 3 | import { Types } from 'mongoose'; 4 | 5 | import { AIService } from '@/ai/ai.service'; 6 | import { AIModel, AIProviderCallbacks } from '@/ai/interfaces/ai-provider.interface'; 7 | import { CurrentUser } from '@/auth/decorators/current-user.decorator'; 8 | import { GqlAuthGuard } from '@/auth/guards/gql-auth.guard'; 9 | import { AccessJwtPayload } from '@/auth/interfaces/jwt-payload.interface'; 10 | import { BranchesService } from '@/branches/branches.service'; 11 | import { EncryptionService } from '@/encryption/encryption.service'; 12 | import { ApiKeysService } from '@/keys/api-key.service'; 13 | import { MessagesService } from '@/messages/messages.service'; 14 | import { StorageService } from '@/storage/storage.service'; 15 | import { WebsocketsService } from '@/websockets/websockets.service'; 16 | import { Message, MessageRole } from '../messages/schemas/message.schema'; 17 | import { ChatService } from './chats.service'; 18 | import { AddMessageDto } from './dto/add-message.dto'; 19 | import { GetChatDto, GetManyChatsDto } from './dto/get-chat-dto'; 20 | import { UpdateChatDto } from './dto/update-chat.dto'; 21 | import { Chat, ChatsResponse, PublicChatResponse, SingleChatResponse } from './schemas/chat.schema'; 22 | 23 | @Resolver(() => Chat) 24 | export class ChatsResolver { 25 | constructor( 26 | private chatService: ChatService, 27 | private branchesService: BranchesService, 28 | private messagesService: 
MessagesService, 29 | private aiService: AIService, 30 | private apiKeyService: ApiKeysService, 31 | private encryptionService: EncryptionService, 32 | private websocketsService: WebsocketsService, 33 | private storageService: StorageService 34 | ) {} 35 | 36 | @UseGuards(GqlAuthGuard) 37 | @Query(() => [AIModel]) 38 | async getAvailableModels( 39 | @CurrentUser() user: AccessJwtPayload, 40 | @Args('rawDecryptKey') rawDecryptKey: string 41 | ): Promise { 42 | const models = new Map(); 43 | 44 | const apiKeys = await this.apiKeyService.findAll(user.sub); 45 | for (const apiKey of apiKeys) { 46 | const decryptedApiKey = this.encryptionService.decryptWithKey( 47 | apiKey.encryptedApiKey, 48 | rawDecryptKey 49 | ); 50 | const apiModels = await this.aiService.getModels(apiKey.provider, decryptedApiKey); 51 | 52 | for (const model of apiModels) { 53 | models.set(model.id, model); 54 | } 55 | } 56 | 57 | return Array.from(models.values()); 58 | } 59 | 60 | @UseGuards(GqlAuthGuard) 61 | @Query(() => ChatsResponse) 62 | async getChats( 63 | @CurrentUser() user: AccessJwtPayload, 64 | @Args('query') queryOptions: GetManyChatsDto 65 | ): Promise { 66 | const results = await this.chatService.findByUserId(user.sub, queryOptions); 67 | return results; 68 | } 69 | 70 | @UseGuards(GqlAuthGuard) 71 | @Query(() => SingleChatResponse) 72 | async getChat( 73 | @CurrentUser() user: AccessJwtPayload, 74 | @Args('query') queryOptions: GetChatDto 75 | ): Promise { 76 | const { chatId } = queryOptions; 77 | 78 | const chat = await this.chatService.findById(chatId, user.sub); 79 | const branches = await this.branchesService.findByChatId(chatId, user.sub); 80 | const totalMessages = branches.reduce((acc, branch) => acc + branch.messageCount, 0); 81 | 82 | return { 83 | chat, 84 | branches, 85 | totalMessages, 86 | }; 87 | } 88 | 89 | @Query(() => PublicChatResponse) 90 | async getPublicChat(@Args('query') queryOptions: GetChatDto): Promise { 91 | const { chatId } = queryOptions; 92 | 93 | 
const chat = await this.chatService.findById(chatId, 'none'); 94 | 95 | if (!chat.isPublic) { 96 | throw new NotFoundException('Chat not found'); 97 | } 98 | 99 | if (!chat.defaultBranch) { 100 | throw new NotFoundException('No default branch found'); 101 | } 102 | 103 | const messages = await this.messagesService.findByBranchId({ 104 | branchId: chat.defaultBranch?._id.toString(), 105 | }); 106 | 107 | return { 108 | chat, 109 | messages: messages.messages, 110 | }; 111 | } 112 | 113 | @UseGuards(GqlAuthGuard) 114 | @Mutation(() => Chat) 115 | async createChat(@CurrentUser() user: AccessJwtPayload) { 116 | return await this.chatService.createChat(user.sub); 117 | } 118 | 119 | @UseGuards(GqlAuthGuard) 120 | @Mutation(() => Chat) 121 | async updateChat( 122 | @CurrentUser() user: AccessJwtPayload, 123 | @Args('id') chatId: string, 124 | @Args('payload') payload: UpdateChatDto 125 | ) { 126 | return await this.chatService.update(chatId, user.sub, payload); 127 | } 128 | 129 | @UseGuards(GqlAuthGuard) 130 | @Mutation(() => Message) 131 | async sendMessage( 132 | @CurrentUser() user: AccessJwtPayload, 133 | @Args('payload') payload: AddMessageDto 134 | ) { 135 | const branch = await this.branchesService.findById(payload.branchId, user.sub); 136 | const attachments: string[] = []; 137 | 138 | for (const attachment of payload.attachments || []) { 139 | const queried = await this.storageService.getFileById(attachment, user.sub); 140 | attachments.push(queried._id); 141 | } 142 | 143 | // First, save the user message 144 | const userMessage = await this.messagesService.create( 145 | { 146 | attachments, 147 | branchId: new Types.ObjectId(payload.branchId), 148 | chatId: branch.chatId, 149 | content: [ 150 | { 151 | type: 'text', 152 | text: payload.prompt, 153 | }, 154 | ], 155 | metadata: {}, 156 | role: MessageRole.user, 157 | tokens: 0, 158 | }, 159 | user.sub 160 | ); 161 | 162 | // Update branch and chat message counts 163 | await 
this.branchesService.incrementMessageCount(payload.branchId); 164 | await this.chatService.updateLastActivity(branch.chatId.toString()); 165 | 166 | // Get API key 167 | const apiKey = await this.apiKeyService.findById(payload.apiKeyId, user.sub); 168 | 169 | // Get decrypt-key 170 | const key = this.encryptionService.decryptWithKey( 171 | apiKey.encryptedApiKey, 172 | payload.rawDecryptKey 173 | ); 174 | 175 | // Get previous history 176 | const chat = await this.messagesService.findByBranchId({ 177 | branchId: payload.branchId, 178 | }); 179 | 180 | let completedMessage = ''; 181 | 182 | this.websocketsService.emitToBranch(user.sub, payload.branchId, 'message:start', null); 183 | const responseAttachments: string[] = []; 184 | 185 | const callbacks: AIProviderCallbacks = { 186 | onEnd: async () => { 187 | console.log('Chat ended'); 188 | 189 | // Save AI message 190 | const message = await this.messagesService.create( 191 | { 192 | attachments: responseAttachments, 193 | branchId: new Types.ObjectId(payload.branchId), 194 | chatId: branch.chatId, 195 | content: [ 196 | { 197 | type: 'text', 198 | text: completedMessage, 199 | }, 200 | ], 201 | metadata: {}, 202 | role: MessageRole.assistant, 203 | tokens: 0, 204 | }, 205 | user.sub 206 | ); 207 | 208 | this.websocketsService.emitToBranch( 209 | user.sub, 210 | payload.branchId, 211 | 'message:end', 212 | message 213 | ); 214 | }, 215 | // eslint-disable-next-line @typescript-eslint/require-await 216 | onError: async error => { 217 | console.log('Chat error', error); 218 | 219 | this.websocketsService.emitToBranch( 220 | user.sub, 221 | payload.branchId, 222 | 'message:error', 223 | error 224 | ); 225 | }, 226 | // eslint-disable-next-line @typescript-eslint/require-await 227 | onText: async text => { 228 | completedMessage += text; 229 | console.log('Chat chunk', text); 230 | 231 | this.websocketsService.emitToBranch( 232 | user.sub, 233 | payload.branchId, 234 | 'message:chunk', 235 | text 236 | ); 237 | }, 238 
| // eslint-disable-next-line @typescript-eslint/require-await 239 | onMediaGenStart: async type => { 240 | this.websocketsService.emitToBranch( 241 | user.sub, 242 | payload.branchId, 243 | 'media:start', 244 | type 245 | ); 246 | }, 247 | 248 | onMediaGenEnd: async (url: string, type: string) => { 249 | const attachment = await this.storageService.uploadFromURL( 250 | url, 251 | 'generation', 252 | type, 253 | user.sub 254 | ); 255 | const attachmentId = attachment._id; 256 | if (attachmentId) responseAttachments.push(attachmentId); 257 | 258 | this.websocketsService.emitToBranch(user.sub, payload.branchId, 'media:end', { 259 | url, 260 | type, 261 | }); 262 | }, 263 | // eslint-disable-next-line @typescript-eslint/require-await 264 | onMediaGenError: async error => { 265 | this.websocketsService.emitToBranch( 266 | user.sub, 267 | payload.branchId, 268 | 'media:error', 269 | error 270 | ); 271 | }, 272 | }; 273 | 274 | if (payload.useImageTool) { 275 | this.aiService 276 | .generateImage(apiKey.provider, key, payload.modelId, chat.messages, {}, callbacks) 277 | .catch(error => { 278 | console.log('Internal image generation handling error', error); 279 | this.websocketsService.emitToBranch( 280 | user.sub, 281 | payload.branchId, 282 | 'message:error', 283 | error 284 | ); 285 | }); 286 | } else { 287 | this.aiService 288 | .sendMessage(apiKey.provider, key, payload.modelId, chat.messages, {}, callbacks) 289 | .catch(error => { 290 | console.log('Internal chat handling error', error); 291 | this.websocketsService.emitToBranch( 292 | user.sub, 293 | payload.branchId, 294 | 'message:error', 295 | error 296 | ); 297 | }); 298 | } 299 | 300 | return userMessage; 301 | } 302 | } 303 | -------------------------------------------------------------------------------- /src/ai/clients/base/base-openai-api.client.ts: -------------------------------------------------------------------------------- 1 | import OpenAI, { ClientOptions } from 'openai'; 2 | import { 
ChatCompletionContentPart, ChatCompletionMessageParam } from 'openai/resources/index'; 3 | 4 | import { 5 | AIModel, 6 | AIProviderCallbacks, 7 | AIProviderClient, 8 | AIProviderId, 9 | AIProviderOptions, 10 | } from '@/ai/interfaces/ai-provider.interface'; 11 | import { Message, MessageRole } from '@/messages/schemas/message.schema'; 12 | import { StorageService } from '@/storage/storage.service'; 13 | export function extractNameAndAuthor(modelId: string): { name: string; author: string } { 14 | if (modelId.includes('/')) { 15 | const [author, ...nameParts] = modelId.split('/'); 16 | const name = nameParts.join('/'); 17 | return { name, author }; 18 | } else { 19 | return { name: modelId, author: 'OpenAI' }; 20 | } 21 | } 22 | 23 | export function normalizeOpenAIModel( 24 | model: OpenAI.Models.Model, 25 | provider: AIProviderId 26 | ): AIModel | null { 27 | if (!model.id) { 28 | return null; 29 | } 30 | 31 | const { author, name } = extractNameAndAuthor(model.id); 32 | 33 | const imageAnalysis = model.id.includes('vision') || model.id.includes('gpt-4'); 34 | const functionCalling = model.id.includes('o1'); 35 | const imageGeneration = 36 | model.id.includes('dall-e') || 37 | model.id.includes('stable-diffusion') || 38 | model.id.includes('midjourney') || 39 | model.id.includes('flux'); 40 | const textGeneration = !model.id.includes('dall-e') && !model.id.includes('stable-diffusion'); 41 | 42 | return { 43 | id: model.id, 44 | name: name, 45 | author: author || 'Unknown', 46 | provider: provider, 47 | enabled: true, 48 | capabilities: { 49 | codeExecution: false, 50 | fileAnalysis: false, 51 | functionCalling, 52 | imageAnalysis, 53 | imageGeneration, 54 | textGeneration, 55 | webBrowsing: false, 56 | }, 57 | }; 58 | } 59 | 60 | export async function messageToOpenAI( 61 | storageService: StorageService, 62 | message: Message, 63 | isLast: boolean 64 | ): Promise { 65 | const role = message.role === MessageRole.user ? 
'user' : 'assistant';

    // Only the final user turn is sent with its attachments; every other message (and a final
    // user message without attachments) is flattened to plain text. `!length` also covers an
    // undefined `attachments`, which would otherwise reach the loop below and throw.
    if (!isLast || role !== 'user' || !message.attachments?.length) {
        const content = message.content
            .filter(part => part.text)
            .map(part => part.text!)
            .join('\n');
        return { role, content };
    }

    const content: ChatCompletionContentPart[] = [];

    for (const messageContent of message.content) {
        if (messageContent.text) {
            content.push({ type: 'text', text: messageContent.text });
        }
    }

    for (const attachmentId of message.attachments) {
        const file = await storageService.getFileById(attachmentId.toString());
        let url = storageService.getUrlForFile(file._id);
        // Presumably a remote provider cannot reach a localhost URL, hence the base64 fallback below.
        const isLocal = url.startsWith('http://localhost:') || url.startsWith('http://127.0.0.1:');

        // Text plain
        if (file.mimetype === 'text/plain') {
            const text = await storageService.readFileAsPlainText(file._id);
            content.push({
                type: 'text',
                text: `[Attached File]:\n${text}`,
            });
            continue;
        }

        if (file.mimetype === 'application/pdf') {
            const pdfText = await storageService.readFileAsPDF(file._id);
            content.push({
                type: 'text',
                text: `[Attached PDF File (Extracted Text)]:\n${pdfText}`,
            });
            continue;
        }

        // URLs
        if (file.mimetype.startsWith('image/')) {
            if (isLocal) {
                url = await storageService.readFileAsBase64URL(file._id, file.mimetype);
                console.warn('Local url detected, using base64 url...');
            }

            content.push({
                type: 'image_url',
                image_url: {
                    url,
                },
            });
            continue;
        }

        // Base64 data
        const data = await storageService.readFileAsBase64Buffer(file._id);

        if (file.mimetype.startsWith('audio/')) {
            // Only mp3 is recognized explicitly; every other audio type is labelled as wav.
            const isMp3 = file.mimetype === 'audio/mpeg';
            content.push({
                type: 'input_audio',
                input_audio: {
                    data,
                    format: isMp3 ? 'mp3' : 'wav',
                },
            });
            continue;
        }

        // Default: generic file part. `file_id` is intentionally omitted — the OpenAI API treats it
        // as the id of a file uploaded through its Files API, not our storage id.
        // NOTE(review): OpenAI examples pass `file_data` as a `data:<mime>;base64,...` URL; confirm
        // what readFileAsBase64Buffer returns.
        content.push({
            type: 'file',
            file: {
                file_data: data,
                filename: file.filename,
            },
        });
    }

    return { role, content };
}

/**
 * Shared client for providers exposing an OpenAI-compatible API.
 *
 * Concrete subclasses supply the provider id and, optionally, a base URL override
 * through the constructor; everything else (model listing, token estimation,
 * streaming chat, image generation) is implemented here on top of the `openai` SDK.
 */
export abstract class BaseOpenAIApiClient implements AIProviderClient {
    protected readonly provider: AIProviderId;
    protected readonly storageService: StorageService;
    protected readonly apiEndpoint?: string | undefined;

    /**
     * @param provider - Provider id, passed to normalizeOpenAIModel for every listed model.
     * @param storageService - Used to load message attachments; required.
     * @param apiEndpoint - Optional base URL for the OpenAI client (omit for the default).
     * @throws Error when `storageService` is missing.
     */
    constructor(provider: AIProviderId, storageService: StorageService, apiEndpoint?: string) {
        this.provider = provider;
        this.storageService = storageService;
        this.apiEndpoint = apiEndpoint;

        // Bound so the methods keep `this` when handed around as detached functions.
        this.sendMessage = this.sendMessage.bind(this);
        this.getModels = this.getModels.bind(this);
        this.countInputTokens = this.countInputTokens.bind(this);
        this.generateImage = this.generateImage.bind(this);

        if (!storageService) {
            throw new Error('Storage service is required for provider: ' + provider);
        }
    }

    /**
     * Build an SDK client for one request using the caller's API key.
     * `HTTP-Referer` / `X-Title` come from APP_URL / APP_NAME with local fallbacks.
     */
    private createClient(key: string): OpenAI {
        const settings: ClientOptions = {
            apiKey: key,
            defaultHeaders: {
                'HTTP-Referer': process.env.APP_URL || 'http://localhost:3000',
                'X-Title': process.env.APP_NAME || 'Your App Name',
            },
        };

        if (this.apiEndpoint) {
            settings.baseURL = this.apiEndpoint;
        }

        return new OpenAI(settings);
    }

    /**
     * List the provider's models, keeping only those normalizeOpenAIModel accepts
     * (a falsy result means the model is skipped).
     */
    async getModels(key: string): Promise<AIModel[]> {
        const client = this.createClient(key);
        const response = await client.models.list();
        const models: AIModel[] = [];

        for (const model of response.data) {
            const normalized = normalizeOpenAIModel(model, this.provider);
            if (normalized) {
                models.push(normalized);
            }
        }

        return models;
    }

    /**
     * Rough input-token estimate: total characters of all text parts divided by 4,
     * rounded up. No API call is made; `_key` and `_modelId` are unused.
     */
    // eslint-disable-next-line @typescript-eslint/require-await
    async countInputTokens(_key: string, _modelId: string, messages: Message[]): Promise<number> {
        // OpenAI doesn't have a direct token counting endpoint
        // We'll use a rough estimation: ~4 characters per token
        let totalChars = 0;

        for (const message of messages) {
            for (const part of message.content) {
                if (part.text) {
                    totalChars += part.text.length;
                }
            }
        }

        return Math.ceil(totalChars / 4);
    }

    /**
     * Stream a chat completion and report progress through `callbacks`.
     *
     * Text deltas are forwarded to `onText` without awaiting each one, so a slow
     * consumer does not stall the stream; they are all allowed to settle before
     * `onEnd` is invoked so the end signal cannot overtake pending text. Failures —
     * including loading attachments while converting messages — go to `onError`.
     * Errors thrown by the callbacks themselves are only logged.
     */
    async sendMessage(
        key: string,
        modelId: string,
        messages: Message[],
        settings: { maxTokens?: number; temperature?: number },
        callbacks: AIProviderCallbacks
    ): Promise<void> {
        const client = this.createClient(key);
        const lastIndex = messages.length - 1;
        const pendingText: Promise<void>[] = [];

        try {
            // Conversion may read attachments from storage, so its failures belong to onError too.
            const openAIMessages = await Promise.all(
                messages.map((message, index) =>
                    messageToOpenAI(this.storageService, message, index === lastIndex)
                )
            );

            const stream = await client.chat.completions.create({
                model: modelId,
                messages: openAIMessages,
                temperature: settings.temperature,
                max_tokens: settings.maxTokens,
                stream: true,
            });

            for await (const chunk of stream) {
                const content = chunk.choices[0]?.delta?.content;

                if (content) {
                    pendingText.push(
                        callbacks.onText(content).catch(error => {
                            console.error('Error in onText callback:', error);
                        })
                    );
                }
            }

            // Each pending promise already has its own catch, so this never rejects.
            await Promise.all(pendingText);

            callbacks.onEnd().catch(error => {
                console.error('Error in onEnd callback:', error);
            });
        } catch (error: any) {
            const message = error?.message || 'Unknown error';
            callbacks.onError(message).catch(err => {
                console.error('Error in onError callback:', err);
            });
        }
    }

    /**
     * Generate images from a prompt (or from the text of the last message) and
     * report each returned URL via `onMediaGenEnd`, followed by its revised prompt
     * via `onText`. Errors go to `onMediaGenError` (if provided) and `onError`.
     */
    async generateImage(
        key: string,
        modelId: string,
        promptOrMessages: string | Message[],
        settings: AIProviderOptions,
        callbacks: AIProviderCallbacks
    ): Promise<void> {
        const client = this.createClient(key);

        const prompt =
            typeof promptOrMessages === 'string'
                ? promptOrMessages
                : promptOrMessages[promptOrMessages.length - 1]?.content
                      .filter(part => part.text)
                      .map(part => part.text!)
                      .join('\n') || '';

        // Resolved once so the request and the reported metadata always agree.
        const imageSettings = settings.imageGeneration;
        const size = imageSettings?.size || '1024x1024';
        const quality = imageSettings?.quality || 'standard';
        const style = imageSettings?.style || 'vivid';

        try {
            if (callbacks.onMediaGenStart) {
                await callbacks.onMediaGenStart('image');
            }

            const response = await client.images.generate({
                model: modelId,
                prompt: prompt,
                n: imageSettings?.n || 1,
                size,
                quality,
                style,
            });

            for (const image of response?.data || []) {
                if (image.url) {
                    const metadata = {
                        prompt: prompt,
                        revisedPrompt: image.revised_prompt,
                        model: modelId,
                        size,
                        quality,
                        style,
                    };

                    if (callbacks.onMediaGenEnd) {
                        await callbacks.onMediaGenEnd(image.url, 'image', metadata);
                    }

                    await callbacks.onText(image.revised_prompt || '');
                }
            }

            await callbacks.onEnd();
        } catch (error: any) {
            const errorMessage = error?.message || 'Internal Error';

            if (callbacks.onMediaGenError) {
                await callbacks.onMediaGenError(errorMessage, 'image');
            }

            await callbacks.onError(errorMessage);
        }
    }
}

--------------------------------------------------------------------------------
/src/websockets/websockets.service.ts:
--------------------------------------------------------------------------------
import { ChatBranch } from
'@/branches/schemas/chat-branch.schema';
import { Chat } from '@/chats/schemas/chat.schema';
import { ApiKey } from '@/keys/schemas/api-key.schema';
import { UserPreferences } from '@/preferences/schema/user-preference.schema';
import { User } from '@/users/schemas/user.schema';
import { Injectable, Logger } from '@nestjs/common';
import { Server, Socket } from 'socket.io';

/**
 * In-memory registry of socket.io connections.
 *
 * Keeps three indexes in sync — socketId -> Socket, socketId -> userId and
 * userId -> socketIds — plus the room memberships made through this service,
 * and offers helpers to emit events to a user, a branch room, a single socket,
 * or every client.
 */
@Injectable()
export class WebsocketsService {
    private readonly logger = new Logger(WebsocketsService.name);
    /** Assigned through setServer(); until then server-wide emits are skipped. */
    private server?: Server;
    /** socketId -> Socket for every registered client. */
    public readonly connectedClients = new Map<string, Socket>();
    /** userId -> socketIds currently associated with that user. */
    private readonly userSockets = new Map<string, Set<string>>();
    /** socketId -> userId for sockets that have been associated with a user. */
    private readonly connectedUsers = new Map<string, string>();
    /** room -> socketIds that joined it through joinRoom(). */
    private roomClients: Map<string, Set<string>> = new Map();

    setServer(server: Server): void {
        this.server = server;
        this.logger.log('WebSocket server initialized');
    }

    /**
     * Add a client to the connected clients map
     */
    addClient(socketId: string, socket: Socket): void {
        this.connectedClients.set(socketId, socket);
        this.logger.debug(`Client connected: ${socketId}`);
    }

    /**
     * Forget a socket: drop it from the client map, from its user (emitting
     * `user:offline` if it was that user's last socket) and from every tracked
     * room, deleting rooms that become empty.
     */
    removeClient(socketId: string): void {
        // Read the owner before the mapping is deleted.
        const userId = this.connectedUsers.get(socketId);

        this.connectedClients.delete(socketId);
        this.connectedUsers.delete(socketId);

        if (userId) {
            this.removeUserFromSocket(userId, socketId);
        }

        // Deleting the current entry while iterating a Map is safe in JavaScript.
        for (const [room, clients] of this.roomClients.entries()) {
            if (clients.has(socketId)) {
                clients.delete(socketId);
                if (clients.size === 0) {
                    this.roomClients.delete(room);
                }
            }
        }

        this.logger.debug(`Client disconnected: ${socketId}`);
    }

    /**
     * Associate a user with a socket (after authentication).
     *
     * If the socket was previously associated with a different user it is first
     * detached from that user, so it stops receiving their events and their
     * online/offline state stays accurate. Emits `user:online` when this is the
     * user's first socket.
     */
    associateUserWithSocket(userId: string, socketId: string): void {
        const previousUserId = this.connectedUsers.get(socketId);
        if (previousUserId !== undefined && previousUserId !== userId) {
            this.removeUserFromSocket(previousUserId, socketId);
        }

        this.connectedUsers.set(socketId, userId);

        let sockets = this.userSockets.get(userId);
        if (!sockets) {
            sockets = new Set<string>();
            this.userSockets.set(userId, sockets);

            // First socket for this user: announce that they came online.
            if (this.server) {
                this.server.emit('user:online', { userId });
            }
        }

        sockets.add(socketId);
        this.logger.debug(`Associated user ${userId} with socket ${socketId}`);
    }

    /**
     * Remove a socket from a user's socket set; when the set becomes empty the
     * user entry is deleted and `user:offline` is emitted. Does not touch the
     * socketId -> userId mapping.
     */
    removeUserFromSocket(userId: string, socketId: string): void {
        const userSocketSet = this.userSockets.get(userId);
        if (userSocketSet) {
            userSocketSet.delete(socketId);

            if (userSocketSet.size === 0) {
                this.userSockets.delete(userId);

                if (this.server) {
                    this.server.emit('user:offline', { userId });
                }
            }
        }

        this.logger.debug(`User ${userId} removed from socket ${socketId}`);
    }

    /**
     * Get user ID from socket ID
     */
    getUserIdFromSocket(socketId: string): string | undefined {
        return this.connectedUsers.get(socketId);
    }

    /**
     * Get all socket IDs for a specific user (a copy, safe to iterate while sockets disconnect).
     */
    getUserSockets(userId: string): string[] {
        return Array.from(this.userSockets.get(userId) || []);
    }

    /**
     * Get all socket IDs for a specific user as a Set.
     * NOTE: returns the live internal set when the user is known — callers must not mutate it.
     */
    getUserSocketsSet(userId: string): Set<string> {
        return this.userSockets.get(userId) || new Set<string>();
    }

    /**
     * Get all connected users
     */
    getConnectedUsers(): string[] {
        return Array.from(this.userSockets.keys());
    }

    /**
     * Check if a user is online
     */
    isUserOnline(userId: string): boolean {
        const sockets = this.userSockets.get(userId);
        return !!sockets && sockets.size > 0;
    }

    /**
     * Check if user is connected (alias for backward compatibility)
     */
    isUserConnected(userId: string): boolean {
        return this.isUserOnline(userId);
    }

    /**
     * Get connected user count
     */
    getConnectedUserCount(): number {
        return this.userSockets.size;
    }

    /**
     * Get the total number of socket connections
     */
    getSocketConnectionCount(): number {
        return this.connectedClients.size;
    }

    // ===================
    // ROOM MANAGEMENT
    // ===================

    /**
     * Create branch room name with user isolation
     */
    createBranchRoom(userId: string, branchId: string): string {
        return `branch:${userId}:${branchId}`;
    }

    /**
     * Join `room` and record the membership. Silently does nothing if the socket is unknown.
     */
    async joinRoom(socketId: string, room: string): Promise<void> {
        const socket = this.connectedClients.get(socketId);
        if (socket) {
            await socket.join(room);

            let members = this.roomClients.get(room);
            if (!members) {
                members = new Set<string>();
                this.roomClients.set(room, members);
            }
            members.add(socketId);

            this.logger.debug(`Socket ${socketId} joined room ${room}`);
        }
    }

    /**
     * Leave `room` and update the membership record, deleting the room once empty.
     * Silently does nothing if the socket is unknown.
     */
    async leaveRoom(socketId: string, room: string): Promise<void> {
        const socket = this.connectedClients.get(socketId);
        if (socket) {
            await socket.leave(room);

            const roomClients = this.roomClients.get(room);
            if (roomClients) {
                roomClients.delete(socketId);
                if (roomClients.size === 0) {
                    this.roomClients.delete(room);
                }
            }

            this.logger.debug(`Socket ${socketId} left room ${room}`);
        }
    }

    /**
     * Join branch room with user validation. The room name is scoped to the
     * socket's own user, so a socket can only join its owner's branch rooms.
     */
    async joinBranchRoom(
        socketId: string,
        branchId: string
    ): Promise<{ success: boolean; room?: string; error?: string }> {
        try {
            const userId = this.getUserIdFromSocket(socketId);
            if (!userId) {
                return { success: false, error: 'User not authenticated' };
            }

            const room = this.createBranchRoom(userId, branchId);
            await this.joinRoom(socketId, room);

            return { success: true, room };
        } catch (error) {
            const message = this.describeError(error);
            this.logger.error(`Error joining branch room:`, message);
            return { success: false, error: message };
        }
    }

    /**
     * Leave branch room with user validation
     */
    async leaveBranchRoom(
        socketId: string,
        branchId: string
    ): Promise<{ success: boolean; error?: string }> {
        try {
            const userId = this.getUserIdFromSocket(socketId);
            if (!userId) {
                return { success: false, error: 'User not authenticated' };
            }

            const room = this.createBranchRoom(userId, branchId);
            await this.leaveRoom(socketId, room);

            return { success: true };
        } catch (error) {
            const message = this.describeError(error);
            this.logger.error(`Error leaving branch room:`, message);
            return { success: false, error: message };
        }
    }

    getRoomClients(room: string): string[] {
        return Array.from(this.roomClients.get(room) || []);
    }

    getRoomCount(room: string): number {
        return this.roomClients.get(room)?.size || 0;
    }

    /** Turn an unknown thrown value into a message string. */
    private describeError(error: unknown): string {
        return error instanceof Error ? error.message : String(error);
    }

    // ===================
    // EMISSION METHODS
    // ===================

    /**
     * Emit event to specific user (all their sockets)
     */
    emitToUser(userId: string, event: string, data: any): void {
        const userSockets = this.getUserSockets(userId);

        for (const socketId of userSockets) {
            const socket = this.connectedClients.get(socketId);
            if (socket) {
                socket.emit(event, data);
                this.logger.debug(`Event ${event} sent to user ${userId} on socket ${socketId}`);
            }
        }
    }

    /**
     * Emit event to specific branch room (no-op until the server is set)
     */
    emitToBranch(userId: string, branchId: string, event: string, data: any): void {
        const room = this.createBranchRoom(userId, branchId);
        if (this.server) {
            this.server.to(room).emit(event, data);
            this.logger.debug(`Event ${event} sent to branch ${branchId} for user ${userId}`);
        }
    }

    /**
     * Emit to a specific socket
     */
    emitToSocket(socketId: string, event: string, data: any): void {
        const socket = this.connectedClients.get(socketId);
        if (socket) {
            socket.emit(event, data);
            this.logger.debug(`Emitted event ${event} to socket ${socketId}`);
        } else {
            this.logger.warn(`Socket ${socketId} not found for event ${event}`);
        }
    }

    /**
     * Emit an event to all connected clients
     */
    emitToAll(event: string, data: any): void {
        if (!this.server) {
            this.logger.error('WebSocket server not initialized');
            return;
        }

        this.server.emit(event, data);
        this.logger.debug(`Event '${event}' emitted to all clients`);
    }

    // ===================
    // BUSINESS LOGIC EVENTS
    // ===================

    // Auth Events
    emitTokenRefresh(userId: string, newTokens: { accessToken: string; refreshToken: string }) {
        this.emitToUser(userId, 'auth:token_refreshed', newTokens);
    }

    /** Send `auth:logout` to every socket of the user, then disconnect each one. */
    emitLogout(userId: string) {
        const userSockets = this.getUserSockets(userId);

        for (const socketId of userSockets) {
            const socket = this.connectedClients.get(socketId);
            if (socket) {
                socket.emit('auth:logout');
                this.logger.debug(`Logout signal sent to user ${userId} on socket ${socketId}`);
                // Disconnect the socket after logout
                socket.disconnect();
            }
        }
    }

    // Chat Events
    emitChatCreated(userId: string, chat: Chat) {
        this.emitToUser(userId, 'chat:created', chat);
    }

    emitChatUpdated(userId: string, chat: Chat) {
        this.emitToUser(userId, 'chat:updated', chat);
    }

    emitChatDeleted(userId: string, chatId: string) {
        this.emitToUser(userId, 'chat:deleted', chatId);
    }

    // API Key Events
    emitApiKeyAdded(userId: string, apiKey: ApiKey) {
        this.emitToUser(userId, 'apikey:added', apiKey);
    }

    emitApiKeyUpdated(userId: string, apiKey: ApiKey) {
        this.emitToUser(userId, 'apikey:updated', apiKey);
    }

    emitApiKeyDeleted(userId: string, apiKeyId: string) {
        this.emitToUser(userId, 'apikey:deleted', apiKeyId);
    }

    // Message Events
    emitMessageAdd(userId: string, branchId: string, message: any) {
        this.emitToBranch(userId, branchId, 'message:added', message);
    }

    /** `messageId` is currently unused; kept so existing callers keep compiling. */
    emitMessageUpdated(userId: string, messageId: string, message: any) {
        this.emitToUser(userId, 'message:updated', message);
    }

    emitMessageDeleted(userId: string, messageId: string) {
        this.emitToUser(userId, 'message:deleted', messageId);
    }

    // Branch Events
    emitBranchCreated(userId: string, branch: ChatBranch) {
        this.emitToUser(userId, 'branch:created', branch);
    }

    emitBranchUpdated(userId: string, branch: ChatBranch) {
        this.emitToUser(userId, 'branch:updated', branch);
    }

    emitBranchDeleted(userId: string, branchId: string) {
        this.emitToUser(userId, 'branch:deleted', branchId);
    }

    // User Events
    emitUserUpdated(userId: string, user: User) {
        this.emitToUser(userId, 'user:updated', user);
    }

    // Preferences Events
    emitPreferencesUpdated(userId: string, preferences: UserPreferences) {
        this.emitToUser(userId, 'preferences:updated', preferences);
    }

    // ===================
    // UTILITY METHODS
    // ===================

    /**
     * Disconnect all sockets for a specific user
     */
    disconnectUser(userId: string): void {
        const userSocketSet = this.userSockets.get(userId);

        if (!userSocketSet || userSocketSet.size === 0) {
            this.logger.warn(`No sockets found for user ${userId} to disconnect`);
            return;
        }

        let disconnectedCount = 0;
        // Iterate a copy: disconnecting may remove entries from the live set.
        for (const socketId of Array.from(userSocketSet)) {
            const socket = this.connectedClients.get(socketId);
            if (socket) {
                socket.disconnect(true);
                disconnectedCount++;
            }
        }

        this.logger.log(`Disconnected ${disconnectedCount} sockets for user ${userId}`);
    }

    /**
     * Get connection statistics (average rounded to two decimals)
     */
    getConnectionStats(): {
        connectedUsers: number;
        totalConnections: number;
        averageConnectionsPerUser: number;
        userConnections: { userId: string; socketCount: number }[];
        rooms: { room: string; clientCount: number }[];
    } {
        const userConnections = Array.from(this.userSockets.entries()).map(([userId, sockets]) => ({
            userId,
            socketCount: sockets.size,
        }));

        const rooms = Array.from(this.roomClients.entries()).map(([room, clients]) => ({
            room,
            clientCount: clients.size,
        }));

        const totalConnections = this.connectedClients.size;
        const connectedUsers = this.userSockets.size;
        const averageConnectionsPerUser =
            connectedUsers > 0 ? totalConnections / connectedUsers : 0;

        return {
            connectedUsers,
            totalConnections,
            averageConnectionsPerUser: Math.round(averageConnectionsPerUser * 100) / 100,
            userConnections,
            rooms,
        };
    }

    /**
     * Clean up disconnected sockets (utility method for maintenance)
     */
    cleanupDisconnectedSockets(): number {
        let cleanedCount = 0;
        const socketsToRemove: string[] = [];

        // Collect first, then remove, so the map is not mutated mid-iteration.
        for (const [socketId, socket] of this.connectedClients) {
            if (!socket.connected) {
                socketsToRemove.push(socketId);
            }
        }

        for (const socketId of socketsToRemove) {
            this.removeClient(socketId);
            cleanedCount++;
        }

        if (cleanedCount > 0) {
            this.logger.log(`Cleaned up ${cleanedCount} disconnected sockets`);
        }

        return cleanedCount;
    }
}

--------------------------------------------------------------------------------