├── packages ├── graph │ ├── tests │ │ └── unit │ │ │ ├── data │ │ │ └── cluster-origin-data.json │ │ │ ├── utils.ts │ │ │ ├── detect-cycle-async-spec.ts │ │ │ ├── mst-spec.ts │ │ │ ├── cosineSimilarity-spec.ts │ │ │ ├── queue-spec.ts │ │ │ ├── stack-spec.ts │ │ │ ├── degree-async-spec.ts │ │ │ ├── util-spec.ts │ │ │ ├── pagerank-spec.ts │ │ │ ├── floydWarshall-spec.ts │ │ │ ├── pagerank-async-spec.ts │ │ │ ├── floydWarshall-async-spec.ts │ │ │ ├── label-propagation-spec.ts │ │ │ ├── connected-component-spec.ts │ │ │ ├── adjacent-matrix-spec.ts │ │ │ ├── linked-list-spec.ts │ │ │ ├── degree-spec.ts │ │ │ ├── adjacent-matrix-async-spec.ts │ │ │ ├── connected-component-async-spec.ts │ │ │ ├── find-path-spec.ts │ │ │ ├── label-propagation-async-spec.ts │ │ │ ├── louvain-spec.ts │ │ │ ├── louvain-async-spec.ts │ │ │ ├── nodesCosineSimilarity-spec.ts │ │ │ ├── find-path-async-spec.ts │ │ │ ├── dfs-spec.ts │ │ │ ├── bfs-spec.ts │ │ │ └── kMeans-spec.ts │ ├── .fatherrc.js │ ├── .prettierrc.js │ ├── src │ │ ├── constants │ │ │ └── time.ts │ │ ├── cosine-similarity.ts │ │ ├── workers │ │ │ ├── createWorker.ts │ │ │ ├── index.worker.ts │ │ │ ├── algorithm.ts │ │ │ ├── constant.ts │ │ │ └── index.ts │ │ ├── i-louvain.ts │ │ ├── structs │ │ │ ├── queue.ts │ │ │ ├── union-find.ts │ │ │ ├── stack.ts │ │ │ ├── binary-heap.ts │ │ │ └── linked-list.ts │ │ ├── floydWarshall.ts │ │ ├── adjacent-matrix.ts │ │ ├── k-core.ts │ │ ├── degree.ts │ │ ├── asyncIndex.ts │ │ ├── nodes-cosine-similarity.ts │ │ ├── util.ts │ │ ├── find-path.ts │ │ ├── dfs.ts │ │ ├── index.ts │ │ ├── bfs.ts │ │ ├── pageRank.ts │ │ ├── utils │ │ │ ├── node-properties.ts │ │ │ ├── data-preprocessing.ts │ │ │ └── vector.ts │ │ ├── gSpan │ │ │ └── struct.ts │ │ ├── dijkstra.ts │ │ ├── mts.ts │ │ ├── connected-component.ts │ │ ├── label-propagation.ts │ │ ├── types.ts │ │ └── k-means.ts │ ├── webpack.dev.config.js │ ├── README.md │ ├── .babelrc.js │ ├── jest.config.js │ ├── tsconfig.json │ ├── webpack.config.js │ └── 
package.json └── webgpu-graph │ ├── src │ ├── link-analysis │ │ ├── index.ts │ │ └── pageRank.ts │ ├── traversal │ │ ├── index.ts │ │ ├── bfs.ts │ │ └── sssp.ts │ ├── index.ts │ ├── types.ts │ ├── util.ts │ └── WebGPUGraph.ts │ ├── .fatherrc.js │ ├── .prettierrc.js │ ├── webpack.dev.config.js │ ├── .babelrc.js │ ├── tsconfig.json │ ├── README-zh_CN.md │ ├── webpack.config.js │ ├── README.md │ └── package.json ├── lerna.json ├── .gitignore ├── package.json ├── README-zh_CN.md ├── README.md └── CHANGELOG.md /packages/graph/tests/unit/data/cluster-origin-data.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /packages/webgpu-graph/src/link-analysis/index.ts: -------------------------------------------------------------------------------- 1 | export * from './pageRank'; 2 | -------------------------------------------------------------------------------- /packages/webgpu-graph/src/traversal/index.ts: -------------------------------------------------------------------------------- 1 | export * from './sssp'; 2 | export * from './bfs'; 3 | -------------------------------------------------------------------------------- /lerna.json: -------------------------------------------------------------------------------- 1 | { 2 | "packages": [ 3 | "packages/*" 4 | ], 5 | "version": "*", 6 | "npmClient": "tnpm" 7 | } 8 | -------------------------------------------------------------------------------- /packages/graph/.fatherrc.js: -------------------------------------------------------------------------------- 1 | export default { 2 | entry: './src/index.ts', 3 | esm: 'babel', 4 | cjs: 'babel' 5 | }; -------------------------------------------------------------------------------- /packages/webgpu-graph/.fatherrc.js: -------------------------------------------------------------------------------- 1 | export default { 2 | entry: './src/index.ts', 3 | esm: 
'babel', 4 | cjs: 'babel' 5 | }; -------------------------------------------------------------------------------- /packages/graph/.prettierrc.js: -------------------------------------------------------------------------------- 1 | const fabric = require('@umijs/fabric'); 2 | 3 | module.exports = { 4 | ...fabric.prettier, 5 | }; 6 | -------------------------------------------------------------------------------- /packages/webgpu-graph/.prettierrc.js: -------------------------------------------------------------------------------- 1 | const fabric = require('@umijs/fabric'); 2 | 3 | module.exports = { 4 | ...fabric.prettier, 5 | }; 6 | -------------------------------------------------------------------------------- /packages/webgpu-graph/src/index.ts: -------------------------------------------------------------------------------- 1 | export * from './WebGPUGraph'; 2 | export * from './link-analysis'; 3 | export * from './traversal'; 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .vscode 3 | .idea 4 | npm-debug.log 5 | yarn-error.log 6 | lerna-debug.log 7 | node_modules 8 | coverage 9 | lib 10 | es 11 | dist 12 | *.pem 13 | !mock-cert.pem 14 | -------------------------------------------------------------------------------- /packages/graph/src/constants/time.ts: -------------------------------------------------------------------------------- 1 | export const secondReg = /^(\d{1,4})(-|\/)(\d{1,2})\2(\d{1,2})$/; 2 | export const dateReg = /^(\d{1,4})(-|\/)(\d{1,2})\2(\d{1,2}) (\d{1,2}):(\d{1,2}):(\d{1,2})$/; 3 | -------------------------------------------------------------------------------- /packages/graph/webpack.dev.config.js: -------------------------------------------------------------------------------- 1 | const webpackConfig = require('./webpack.config'); 2 | 3 | module.exports = Object.assign( 4 | { 5 | devtool: 
'cheap-source-map', 6 | watch: true, 7 | watchOptions: { 8 | aggregateTimeout: 300, 9 | poll: 1000, 10 | ignored: /node_modules/ 11 | }, 12 | }, 13 | webpackConfig, 14 | ); 15 | -------------------------------------------------------------------------------- /packages/webgpu-graph/webpack.dev.config.js: -------------------------------------------------------------------------------- 1 | const webpackConfig = require('./webpack.config'); 2 | 3 | module.exports = Object.assign( 4 | { 5 | devtool: 'cheap-source-map', 6 | watch: true, 7 | watchOptions: { 8 | aggregateTimeout: 300, 9 | poll: 1000, 10 | ignored: /node_modules/ 11 | }, 12 | }, 13 | webpackConfig, 14 | ); 15 | -------------------------------------------------------------------------------- /packages/graph/README.md: -------------------------------------------------------------------------------- 1 | ### AntV Algorithm 2 | AntV 算法包,包括图算法及其他各类算法。 3 | 4 | graph 包下面包括的都是图算法。 5 | 6 | AntV 共支持以下图算法: 7 | - adjacentMatrix 8 | - connectedComponent 9 | - degree:in degree、out degree 10 | - detectCycle 11 | - dfs 12 | - dijkstra 13 | - findPath:shortest path、all path 14 | - floydWarshall 15 | - labelPropagation 16 | - louvain 17 | - pageRank 18 | - neighbors 19 | -------------------------------------------------------------------------------- /packages/graph/.babelrc.js: -------------------------------------------------------------------------------- 1 | module.exports = api => { 2 | api.cache(() => process.env.NODE_ENV); 3 | return { 4 | presets: [ 5 | [ 6 | '@babel/preset-env', 7 | { 8 | loose: true, 9 | modules: false, 10 | targets: { node: 'current' }, 11 | }, 12 | ], 13 | '@babel/preset-typescript', 14 | { 15 | plugins: ['@babel/plugin-proposal-class-properties'], 16 | }, 17 | ], 18 | }; 19 | }; 20 | -------------------------------------------------------------------------------- /packages/webgpu-graph/.babelrc.js: -------------------------------------------------------------------------------- 1 | module.exports = 
api => { 2 | api.cache(() => process.env.NODE_ENV); 3 | return { 4 | presets: [ 5 | [ 6 | '@babel/preset-env', 7 | { 8 | loose: true, 9 | modules: false, 10 | targets: { node: 'current' }, 11 | }, 12 | ], 13 | '@babel/preset-typescript', 14 | { 15 | plugins: ['@babel/plugin-proposal-class-properties'], 16 | }, 17 | ], 18 | }; 19 | }; 20 | -------------------------------------------------------------------------------- /packages/graph/jest.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | runner: 'jest-electron/runner', 3 | testEnvironment: 'jest-electron/environment', 4 | preset: 'ts-jest', 5 | collectCoverage: false, 6 | collectCoverageFrom: ['src/**/*.{ts,js}', '!**/node_modules/**', '!**/vendor/**'], 7 | testRegex: '/tests/.*-spec\\.ts?$', 8 | moduleDirectories: ['node_modules', 'src'], 9 | moduleFileExtensions: ['js', 'ts', 'json'], 10 | globals: { 11 | 'ts-jest': { 12 | diagnostics: false, 13 | }, 14 | }, 15 | }; 16 | -------------------------------------------------------------------------------- /packages/webgpu-graph/src/traversal/bfs.ts: -------------------------------------------------------------------------------- 1 | import type { WebGLRenderer } from '@antv/g-webgl'; 2 | import { Kernel, BufferUsage } from '@antv/g-plugin-gpgpu'; 3 | import { GraphData } from '../types'; 4 | import { convertGraphData2CSC } from '../util'; 5 | 6 | /** 7 | * Scalable GPU Graph Traversal 8 | * @see https://research.nvidia.com/publication/scalable-gpu-graph-traversal 9 | * @see https://github.com/rafalk342/bfs-cuda 10 | * @see https://github.com/kaletap/bfs-cuda-gpu 11 | */ 12 | export async function bfs(device: WebGLRenderer.Device, graphData: GraphData) { 13 | 14 | } -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "@antv/algorithm", 3 | "private": true, 4 | 
"scripts": { 5 | "build": "lerna build", 6 | "lint": "lerna run lint", 7 | "test": "lerna run test --no-private", 8 | "prettier": "prettier --write '**/*.{js,jsx,tsx,ts,less,md,json}'", 9 | "pretty-quick": "pretty-quick", 10 | "clean": "lerna clean", 11 | "clear": "lerna clean && lerna clean -y", 12 | "clean:modules": "rimraf node_modules", 13 | "bootstrap": "lerna bootstrap", 14 | "ls": "lerna list" 15 | }, 16 | "devDependencies": { 17 | "lerna": "^3.20.2" 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /packages/webgpu-graph/src/types.ts: -------------------------------------------------------------------------------- 1 | export interface NodeConfig { 2 | id: string; 3 | clusterId?: string; 4 | [key: string]: any; 5 | } 6 | 7 | export interface EdgeConfig { 8 | source: string; 9 | target: string; 10 | weight?: number; 11 | [key: string]: any; 12 | } 13 | 14 | export interface GraphData { 15 | nodes?: NodeConfig[]; 16 | edges?: EdgeConfig[]; 17 | } 18 | 19 | export interface CSC { 20 | V: number[]; 21 | E: number[]; 22 | I: number[]; 23 | From: number[]; 24 | To: number[]; 25 | nodeId2IndexMap: Record; 26 | edges: EdgeConfig[]; 27 | } 28 | -------------------------------------------------------------------------------- /packages/graph/tests/unit/utils.ts: -------------------------------------------------------------------------------- 1 | import algorithm from '../../src/'; 2 | 3 | type IAlgorithm = typeof algorithm; 4 | declare const window: Window & { 5 | Algorithm: IAlgorithm; 6 | }; 7 | 8 | export const getAlgorithm = () => 9 | new Promise(resolve => { 10 | if (window.Algorithm) { 11 | resolve(window.Algorithm); 12 | } 13 | const script = document.createElement('script'); 14 | script.type = 'text/javascript'; 15 | script.src = `${process.cwd()}/dist/index.min.js`; 16 | script.onload = function() { 17 | resolve(window.Algorithm); 18 | }; 19 | document.body.append(script); 20 | }); 
-------------------------------------------------------------------------------- /packages/graph/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "rootDir": "./", 4 | "module": "esnext", 5 | "declaration": true, 6 | "sourceMap": true, 7 | "target": "es5", 8 | "outDir": "lib", 9 | "importHelpers": true, 10 | "moduleResolution": "node", 11 | "allowSyntheticDefaultImports": true, 12 | "resolveJsonModule": true, 13 | "esModuleInterop": true, 14 | "lib": ["esnext", "dom"], 15 | "types": ["node", "jest"] 16 | }, 17 | "include": ["src", "tests"], 18 | "typedocOptions": { 19 | "mode": "modules", 20 | "out": "docs/api-ts", 21 | "excludePrivate": true, 22 | "excludeProtected": true, 23 | "excludeExternals": true, 24 | "plugin": "typedoc-plugin-markdown", 25 | "theme": "docusaurus2" 26 | } 27 | } -------------------------------------------------------------------------------- /packages/webgpu-graph/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "rootDir": "./", 4 | "module": "esnext", 5 | "declaration": true, 6 | "sourceMap": true, 7 | "target": "es5", 8 | "outDir": "lib", 9 | "importHelpers": true, 10 | "moduleResolution": "node", 11 | "allowSyntheticDefaultImports": true, 12 | "resolveJsonModule": true, 13 | "esModuleInterop": true, 14 | "lib": ["esnext", "dom"], 15 | "types": ["node", "jest"] 16 | }, 17 | "include": ["src", "tests"], 18 | "typedocOptions": { 19 | "mode": "modules", 20 | "out": "docs/api-ts", 21 | "excludePrivate": true, 22 | "excludeProtected": true, 23 | "excludeExternals": true, 24 | "plugin": "typedoc-plugin-markdown", 25 | "theme": "docusaurus2" 26 | } 27 | } -------------------------------------------------------------------------------- /packages/graph/src/cosine-similarity.ts: -------------------------------------------------------------------------------- 1 | import Vector from 
'./utils/vector'; 2 | /** 3 | * cosine-similarity算法 计算余弦相似度 4 | * @param item 元素 5 | * @param targetItem 目标元素 6 | */ 7 | const cosineSimilarity = ( 8 | item: number[], 9 | targetItem: number[], 10 | ): number => { 11 | // 目标元素向量 12 | const targetItemVector = new Vector(targetItem); 13 | // 目标元素向量的模长 14 | const targetNodeNorm2 = targetItemVector.norm2(); 15 | // 元素向量 16 | const itemVector = new Vector(item); 17 | // 元素向量的模长 18 | const itemNorm2 = itemVector.norm2(); 19 | // 计算元素向量和目标元素向量的点积 20 | const dot = targetItemVector.dot(itemVector); 21 | const norm2Product = targetNodeNorm2 * itemNorm2; 22 | // 计算元素向量和目标元素向量的余弦相似度 23 | const cosineSimilarity = norm2Product ? dot / norm2Product : 0; 24 | return cosineSimilarity; 25 | } 26 | 27 | export default cosineSimilarity; 28 | -------------------------------------------------------------------------------- /packages/graph/src/workers/createWorker.ts: -------------------------------------------------------------------------------- 1 | import { MESSAGE } from './constant'; 2 | import Worker from './index.worker'; 3 | 4 | interface Event { 5 | type: string; 6 | data: any; 7 | } 8 | 9 | /** 10 | * 创建一个在worker中运行的算法 11 | * @param type 算法类型 12 | */ 13 | const createWorker = (type: string) => (...data) => 14 | new Promise((resolve, reject) => { 15 | const worker = new Worker(); 16 | worker.postMessage({ 17 | _algorithmType:type, 18 | data, 19 | }); 20 | 21 | worker.onmessage = (event: Event) => { 22 | const { data, _algorithmType } = event.data; 23 | if (MESSAGE.SUCCESS === _algorithmType) { 24 | resolve(data); 25 | } else { 26 | reject(); 27 | } 28 | 29 | worker.terminate(); 30 | }; 31 | }); 32 | 33 | export default createWorker; 34 | -------------------------------------------------------------------------------- /packages/graph/src/workers/index.worker.ts: -------------------------------------------------------------------------------- 1 | import * as algorithm from './algorithm'; 2 | import { MESSAGE } from './constant'; 3 
| 4 | const ctx: Worker = (typeof self !== 'undefined') ? self : {} as any; 5 | 6 | interface Event { 7 | type: string; 8 | data: any; 9 | } 10 | 11 | ctx.onmessage = (event: Event) => { 12 | const { _algorithmType, data } = event.data; 13 | // 如果发送内容没有私有类型。说明不是自己发的。不管 14 | // fix: https://github.com/antvis/algorithm/issues/25 15 | if(!_algorithmType){ 16 | return; 17 | } 18 | if (typeof algorithm[_algorithmType] === 'function') { 19 | const result = algorithm[_algorithmType](...data); 20 | ctx.postMessage({ _algorithmType: MESSAGE.SUCCESS, data: result }); 21 | return; 22 | } 23 | ctx.postMessage({ _algorithmType: MESSAGE.FAILURE }); 24 | }; 25 | 26 | // https://stackoverflow.com/questions/50210416/webpack-worker-loader-fails-to-compile-typescript-worker 27 | export default null as any; 28 | -------------------------------------------------------------------------------- /packages/graph/src/workers/algorithm.ts: -------------------------------------------------------------------------------- 1 | export { default as getAdjMatrix } from '../adjacent-matrix'; 2 | export { default as breadthFirstSearch } from '../bfs'; 3 | export { default as connectedComponent } from '../connected-component'; 4 | export { default as getDegree } from '../degree'; 5 | export { getInDegree, getOutDegree } from '../degree'; 6 | export { default as detectCycle } from '../detect-cycle'; 7 | export { default as depthFirstSearch } from '../dfs'; 8 | export { default as dijkstra } from '../dijkstra'; 9 | export { findAllPath, findShortestPath } from '../find-path'; 10 | export { default as floydWarshall } from '../floydWarshall'; 11 | export { default as labelPropagation } from '../label-propagation'; 12 | export { default as louvain } from '../louvain'; 13 | export { default as minimumSpanningTree } from '../mts'; 14 | export { default as pageRank } from '../pageRank'; 15 | export { default as GADDI } from '../gaddi'; 16 | export { getNeighbors } from '../util'; 17 | 
-------------------------------------------------------------------------------- /README-zh_CN.md: -------------------------------------------------------------------------------- 1 | ### AntV Algorithm 2 | 3 | AntV 算法包,包括图算法及其他各类算法。 4 | 5 | graph 包下面包括的都是图算法。 6 | 7 | AntV 共支持以下图算法: 8 | - **社区发现** 9 | - k-core: K-Core社区发现算法 -- 找到符合指定核心度K的密切相关子图结构 10 | - louvain: LOUVAIN 算法 -- 根据模块度划分社区 11 | - i-louvain: I-LOUVAIN 算法 -- 根据模块度和惯性模块度(属性相似度)划分社区 12 | - labelPropagation: 标签传播算法 13 | - minimumSpanningTree: 图的最小生成树 14 | 15 | - **节点聚类** 16 | - k-means: K-Means算法 -- 根据节点之间的距离将节点分为K个簇 17 | 18 | - **相似性** 19 | - cosineSimilarity: 余弦相似度算法 -- 计算两个元素的余弦相似度 20 | - nodesCosineSimilarity: 节点余弦相似度算法 -- 计算节点与种子节点之间的余弦相似度 21 | 22 | 23 | - **中心性** 24 | - pageRank: 节点排序的页面排序算法 25 | - degree: 计算节点的入度、出度和总度 26 | 27 | - **路径** 28 | - dijkstra: Dijkstra 最短路径算法 29 | - findPath: 通过Dijkstra找到两个节点的最短路径和所有路径 30 | - floydWarshall: 弗洛伊德最短路径算法 31 | 32 | - **其它** 33 | - neighbors: 在图中查找节点的邻居 34 | - GADDI: 图结构和语义模式匹配算法 35 | - detectCycle: 环路检测 36 | - dfs: 深度优先遍历 37 | - adjacentMatrix: 邻接矩阵 38 | - connectedComponent: 连通子图 39 | 40 | 并支持在 web-worker 中计算上述算法 41 | -------------------------------------------------------------------------------- /packages/graph/src/i-louvain.ts: -------------------------------------------------------------------------------- 1 | import louvain from './louvain'; 2 | import type { ClusterData, GraphData } from './types'; 3 | 4 | /** 5 | * 社区发现 i-louvain 算法:模块度 + 惯性模块度(即节点属性相似性) 6 | * @param graphData 图数据 7 | * @param directed 是否有向图,默认为 false 8 | * @param weightPropertyName 权重的属性字段 9 | * @param threshold 差值阈值 10 | * @param propertyKey 属性的字段名 11 | * @param involvedKeys 参与计算的key集合 12 | * @param uninvolvedKeys 不参与计算的key集合 13 | * @param inertialWeight 惯性模块度权重 14 | */ 15 | const iLouvain = ( 16 | graphData: GraphData, 17 | directed: boolean = false, 18 | weightPropertyName: string = 'weight', 19 | threshold: number = 0.0001, 20 | propertyKey: string = undefined, 21 | involvedKeys: 
string[] = [], 22 | uninvolvedKeys: string[] = ['id'], 23 | inertialWeight: number = 1, 24 | ): ClusterData => { 25 | return louvain(graphData, directed, weightPropertyName, threshold, true, propertyKey, involvedKeys, uninvolvedKeys, inertialWeight); 26 | } 27 | 28 | export default iLouvain; 29 | -------------------------------------------------------------------------------- /packages/graph/src/structs/queue.ts: -------------------------------------------------------------------------------- 1 | import LinkedList from './linked-list'; 2 | 3 | export default class Queue { 4 | public linkedList: LinkedList; 5 | 6 | constructor() { 7 | this.linkedList = new LinkedList(); 8 | } 9 | 10 | /** 11 | * 队列是否为空 12 | */ 13 | public isEmpty() { 14 | return !this.linkedList.head; 15 | } 16 | 17 | /** 18 | * 读取队列头部的元素, 不删除队列中的元素 19 | */ 20 | public peek() { 21 | if (!this.linkedList.head) { 22 | return null; 23 | } 24 | return this.linkedList.head.value; 25 | } 26 | 27 | /** 28 | * 在队列的尾部新增一个元素 29 | * @param value 30 | */ 31 | public enqueue(value) { 32 | this.linkedList.append(value); 33 | } 34 | 35 | /** 36 | * 删除队列中的头部元素,如果队列为空,则返回 null 37 | */ 38 | public dequeue() { 39 | const removeHead = this.linkedList.deleteHead(); 40 | return removeHead ? 
removeHead.value : null; 41 | } 42 | 43 | public toString(callback?: any) { 44 | return this.linkedList.toString(callback); 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /packages/graph/src/workers/constant.ts: -------------------------------------------------------------------------------- 1 | export const ALGORITHM = { 2 | pageRank: 'pageRank', 3 | breadthFirstSearch: 'breadthFirstSearch', 4 | connectedComponent: 'connectedComponent', 5 | depthFirstSearch: 'depthFirstSearch', 6 | detectCycle: 'detectCycle', 7 | detectDirectedCycle: 'detectDirectedCycle', 8 | detectAllCycles: 'detectAllCycles', 9 | detectAllDirectedCycle: 'detectAllDirectedCycle', 10 | detectAllUndirectedCycle: 'detectAllUndirectedCycle', 11 | dijkstra: 'dijkstra', 12 | findAllPath: 'findAllPath', 13 | findShortestPath: 'findShortestPath', 14 | floydWarshall: 'floydWarshall', 15 | getAdjMatrix: 'getAdjMatrix', 16 | getDegree: 'getDegree', 17 | getInDegree: 'getInDegree', 18 | getNeighbors: 'getNeighbors', 19 | getOutDegree: 'getOutDegree', 20 | labelPropagation: 'labelPropagation', 21 | louvain: 'louvain', 22 | GADDI: 'GADDI', 23 | minimumSpanningTree: 'minimumSpanningTree', 24 | SUCCESS: 'SUCCESS', 25 | FAILURE: 'FAILURE', 26 | }; 27 | 28 | export const MESSAGE = { 29 | SUCCESS: 'SUCCESS', 30 | FAILURE: 'FAILURE', 31 | }; 32 | -------------------------------------------------------------------------------- /packages/graph/src/floydWarshall.ts: -------------------------------------------------------------------------------- 1 | import getAdjMatrix from "./adjacent-matrix"; 2 | import { GraphData, Matrix } from "./types"; 3 | 4 | const floydWarshall = (graphData: GraphData, directed?: boolean) => { 5 | const adjacentMatrix = getAdjMatrix(graphData, directed); 6 | 7 | const dist: Matrix[] = []; 8 | const size = adjacentMatrix.length; 9 | for (let i = 0; i < size; i += 1) { 10 | dist[i] = []; 11 | for (let j = 0; j < size; j += 1) { 12 | if (i === 
j) { 13 | dist[i][j] = 0; 14 | } else if (adjacentMatrix[i][j] === 0 || !adjacentMatrix[i][j]) { 15 | dist[i][j] = Infinity; 16 | } else { 17 | dist[i][j] = adjacentMatrix[i][j]; 18 | } 19 | } 20 | } 21 | // floyd 22 | for (let k = 0; k < size; k += 1) { 23 | for (let i = 0; i < size; i += 1) { 24 | for (let j = 0; j < size; j += 1) { 25 | if (dist[i][j] > dist[i][k] + dist[k][j]) { 26 | dist[i][j] = dist[i][k] + dist[k][j]; 27 | } 28 | } 29 | } 30 | } 31 | return dist; 32 | }; 33 | 34 | export default floydWarshall; 35 | -------------------------------------------------------------------------------- /packages/graph/src/adjacent-matrix.ts: -------------------------------------------------------------------------------- 1 | import { GraphData, Matrix } from "./types"; 2 | 3 | const adjMatrix = (graphData: GraphData, directed?: boolean): Matrix[] => { 4 | const { nodes, edges } = graphData; 5 | const matrix: Matrix[] = []; 6 | // map node with index in data.nodes 7 | const nodeMap: { 8 | [key: string]: number; 9 | } = {}; 10 | 11 | if (!nodes) { 12 | throw new Error("invalid nodes data!"); 13 | } 14 | 15 | if (nodes) { 16 | nodes.forEach((node, i) => { 17 | nodeMap[node.id] = i; 18 | const row: number[] = []; 19 | matrix.push(row); 20 | }); 21 | } 22 | 23 | if (edges) { 24 | edges.forEach((edge) => { 25 | const { source, target } = edge; 26 | const sIndex = nodeMap[source as string]; 27 | const tIndex = nodeMap[target as string]; 28 | if ((!sIndex && sIndex !== 0) || (!tIndex && tIndex !== 0)) return; 29 | matrix[sIndex][tIndex] = 1; 30 | if (!directed) { 31 | matrix[tIndex][sIndex] = 1; 32 | } 33 | }); 34 | } 35 | return matrix; 36 | }; 37 | 38 | export default adjMatrix; 39 | -------------------------------------------------------------------------------- /packages/graph/src/structs/union-find.ts: -------------------------------------------------------------------------------- 1 | /** 2 | * 并查集 Disjoint set to support quick union 3 | */ 4 | export default class 
UnionFind { 5 | count: number; 6 | 7 | parent: {}; 8 | 9 | constructor(items: (number | string)[]) { 10 | this.count = items.length; 11 | this.parent = {}; 12 | for (const i of items) { 13 | this.parent[i] = i; 14 | } 15 | } 16 | 17 | // find the root of the item 18 | find(item) { 19 | while (this.parent[item] !== item) { 20 | item = this.parent[item]; 21 | } 22 | return item; 23 | } 24 | 25 | union(a, b) { 26 | const rootA = this.find(a); 27 | const rootB = this.find(b); 28 | 29 | if (rootA === rootB) return; 30 | 31 | // make the element with smaller root the parent 32 | if (rootA < rootB) { 33 | if (this.parent[b] !== b) this.union(this.parent[b], a); 34 | this.parent[b] = this.parent[a]; 35 | } else { 36 | if (this.parent[a] !== a) this.union(this.parent[a], b); 37 | this.parent[a] = this.parent[b]; 38 | } 39 | } 40 | 41 | // whether a and b are connected, i.e. a and b have the same root 42 | connected(a, b) { 43 | return this.find(a) === this.find(b); 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /packages/graph/src/k-core.ts: -------------------------------------------------------------------------------- 1 | 2 | import { clone } from '@antv/util'; 3 | import degree from './degree'; 4 | import { GraphData } from './types'; 5 | /** 6 | * k-core算法 找出符合指定核心度的紧密关联的子图结构 7 | * @param graphData 图数据 8 | * @param k 核心度数 9 | */ 10 | const kCore = ( 11 | graphData: GraphData, 12 | k: number = 1, 13 | ): GraphData => { 14 | const data = clone(graphData); 15 | const { nodes = [] } = data; 16 | let { edges = [] } = data; 17 | while (true) { 18 | // 获取图中节点的度数 19 | const degrees = degree({ nodes, edges}); 20 | const nodeIds = Object.keys(degrees); 21 | // 按照度数进行排序 22 | nodeIds.sort((a, b) => degrees[a]?.degree - degrees[b]?.degree); 23 | const minIndexId = nodeIds[0]; 24 | if (!nodes.length || degrees[minIndexId]?.degree >= k) { 25 | break; 26 | } 27 | const originIndex = nodes.findIndex(node => node.id === minIndexId); 
28 | // 移除度数小于k的节点 29 | nodes.splice(originIndex, 1); 30 | // 移除度数小于k的节点相关的边 31 | edges = edges.filter(edge => !(edge.source === minIndexId || edge.target === minIndexId)); 32 | } 33 | 34 | return { nodes, edges }; 35 | } 36 | 37 | export default kCore; 38 | -------------------------------------------------------------------------------- /packages/webgpu-graph/src/util.ts: -------------------------------------------------------------------------------- 1 | import type { GraphData, CSC, EdgeConfig } from './types'; 2 | 3 | export function convertGraphData2CSC(graphData: GraphData): CSC { 4 | const V: number[] = []; 5 | const E: number[] = []; 6 | const From: number[] = []; 7 | const To: number[] = []; 8 | const I: number[] = []; 9 | const nodeId2IndexMap: Record = {}; 10 | const edges: EdgeConfig[] = [] 11 | graphData.nodes.forEach((node, i) => { 12 | nodeId2IndexMap[node.id] = i; 13 | V.push(i); 14 | }); 15 | 16 | let lastSource = ''; 17 | let counter = 0; 18 | // sort by source 19 | [...graphData.edges] 20 | .sort((a, b) => nodeId2IndexMap[a.source] - nodeId2IndexMap[b.source]) 21 | .forEach((edgeConfig) => { 22 | const { source, target } = edgeConfig; 23 | edges.push(edgeConfig); 24 | E.push(nodeId2IndexMap[target]); 25 | From.push(nodeId2IndexMap[source]); 26 | To.push(nodeId2IndexMap[target]); 27 | 28 | if (source !== lastSource) { 29 | I.push(counter); 30 | lastSource = source; 31 | } 32 | 33 | counter++; 34 | }); 35 | 36 | I.push(E.length); 37 | 38 | return { 39 | V, 40 | E, 41 | I, 42 | From, 43 | To, 44 | nodeId2IndexMap, 45 | edges, 46 | }; 47 | } 48 | -------------------------------------------------------------------------------- /packages/graph/src/degree.ts: -------------------------------------------------------------------------------- 1 | import { GraphData, DegreeType } from "./types"; 2 | 3 | const degree = (graphData: GraphData): DegreeType => { 4 | const degrees: DegreeType = {}; 5 | const { nodes = [], edges = [] } = graphData 6 | 7 | 
nodes.forEach((node) => { 8 | degrees[node.id] = { 9 | degree: 0, 10 | inDegree: 0, 11 | outDegree: 0, 12 | }; 13 | }); 14 | 15 | edges.forEach((edge) => { 16 | degrees[edge.source].degree++; 17 | degrees[edge.source].outDegree++; 18 | degrees[edge.target].degree++; 19 | degrees[edge.target].inDegree++; 20 | }); 21 | 22 | return degrees; 23 | }; 24 | 25 | export default degree; 26 | 27 | /** 28 | * 获取指定节点的入度 29 | * @param graphData 图数据 30 | * @param nodeId 节点ID 31 | */ 32 | export const getInDegree = (graphData: GraphData, nodeId: string): number => { 33 | const nodeDegree = degree(graphData) 34 | if (nodeDegree[nodeId]) { 35 | return degree(graphData)[nodeId].inDegree 36 | } 37 | return 0 38 | } 39 | 40 | /** 41 | * 获取指定节点的出度 42 | * @param graphData 图数据 43 | * @param nodeId 节点ID 44 | */ 45 | export const getOutDegree = (graphData: GraphData, nodeId: string): number => { 46 | const nodeDegree = degree(graphData) 47 | if (nodeDegree[nodeId]) { 48 | return degree(graphData)[nodeId].outDegree 49 | } 50 | return 0 51 | } 52 | -------------------------------------------------------------------------------- /packages/graph/src/structs/stack.ts: -------------------------------------------------------------------------------- 1 | import LinkedList from './linked-list'; 2 | 3 | export default class Stack { 4 | 5 | private linkedList: LinkedList; 6 | 7 | private maxStep: number; 8 | 9 | constructor(maxStep: number = 10) { 10 | this.linkedList = new LinkedList(); 11 | this.maxStep = maxStep; 12 | } 13 | 14 | get length() { 15 | return this.linkedList.toArray().length; 16 | } 17 | 18 | /** 19 | * 判断栈是否为空,如果链表中没有头部元素,则栈为空 20 | */ 21 | isEmpty() { 22 | return !this.linkedList.head; 23 | } 24 | 25 | /** 26 | * 是否到定义的栈的最大长度,如果达到最大长度后,不再允许入栈 27 | */ 28 | isMaxStack() { 29 | return this.toArray().length >= this.maxStep; 30 | } 31 | 32 | /** 33 | * 访问顶端元素 34 | */ 35 | peek() { 36 | if (this.isEmpty()) { 37 | return null; 38 | } 39 | 40 | // 返回头部元素,不删除元素 41 | return 
this.linkedList.head.value; 42 | } 43 | 44 | push(value) { 45 | this.linkedList.prepend(value); 46 | if (this.length > this.maxStep) { 47 | this.linkedList.deleteTail(); 48 | } 49 | } 50 | 51 | pop() { 52 | const removeHead = this.linkedList.deleteHead(); 53 | return removeHead ? removeHead.value : null; 54 | } 55 | 56 | toArray() { 57 | return this.linkedList.toArray().map((node) => node.value); 58 | } 59 | 60 | clear() { 61 | while (!this.isEmpty()) { 62 | this.pop(); 63 | } 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /packages/graph/tests/unit/detect-cycle-async-spec.ts: -------------------------------------------------------------------------------- 1 | import { getAlgorithm } from './utils'; 2 | 3 | const data = { 4 | nodes: [ 5 | { 6 | id: 'A', 7 | }, 8 | { 9 | id: 'B', 10 | }, 11 | { 12 | id: 'C', 13 | }, 14 | { 15 | id: 'D', 16 | }, 17 | { 18 | id: 'E', 19 | }, 20 | { 21 | id: 'F', 22 | }, 23 | { 24 | id: 'G', 25 | }, 26 | ], 27 | edges: [ 28 | { 29 | source: 'A', 30 | target: 'B', 31 | }, 32 | { 33 | source: 'B', 34 | target: 'C', 35 | }, 36 | { 37 | source: 'A', 38 | target: 'C', 39 | }, 40 | { 41 | source: 'D', 42 | target: 'A', 43 | }, 44 | { 45 | source: 'D', 46 | target: 'E', 47 | }, 48 | { 49 | source: 'E', 50 | target: 'F', 51 | }, 52 | ], 53 | }; 54 | 55 | describe('(Async) detectDirectedCycle', () => { 56 | it('should detect directed cycle', async () => { 57 | const { detectCycleAsync } = await getAlgorithm(); 58 | 59 | let result = await detectCycleAsync(data); 60 | expect(result).toBeNull(); 61 | 62 | data.edges.push({ 63 | source: 'F', 64 | target: 'D', 65 | }); 66 | 67 | result = await detectCycleAsync(data); 68 | expect(result).toEqual({ 69 | D: 'F', 70 | F: 'E', 71 | E: 'D', 72 | }); 73 | }); 74 | }); 75 | -------------------------------------------------------------------------------- /packages/webgpu-graph/README-zh_CN.md: 
-------------------------------------------------------------------------------- 1 | # AntV Graph Algorithm based on WebGPU 2 | 3 | 参考 [cuGraph](https://github.com/rapidsai/cugraph) 以及其他 CUDA 实现,基于 WebGPU 实现常见的图分析算法,实现大规模节点边数据量下并行加速的目的。 4 | 5 | [文档](https://g-next.antv.vision/zh/docs/api/gpgpu/webgpu-graph): 6 | 7 | - Link Analysis 8 | - [PageRank](https://g-next.antv.vision/zh/docs/api/gpgpu/webgpu-graph#pagerank) 9 | - Traversal 10 | - [SSSP](https://g-next.antv.vision/zh/docs/api/gpgpu/webgpu-graph#sssp) 11 | 12 | ## 前置条件 13 | 14 | - WebGPU 目前仅支持在 Chrome 94 版本以上运行,推荐升级到最新版。 15 | - 启用 Origin Trial 支持 WebGPU 特性(Chrome 100 以上不再需要): 16 | - [获取 Token](https://developer.chrome.com/origintrials/#/view_trial/118219490218475521) 17 | - 在页面中添加 `<meta>` 标签,附上上一步获取的 Token,例如通过 DOM API: 18 | 19 | ```js 20 | const tokenElement = document.createElement('meta'); 21 | tokenElement.httpEquiv = 'origin-trial'; 22 | tokenElement.content = 'AkIL...5fQ=='; 23 | document.head.appendChild(tokenElement); 24 | ``` 25 | 26 | ## 使用方法 27 | 28 | 均为异步调用,以 pageRank 为例: 29 | 30 | ```js 31 | import { pageRank, WebGPUGraph } from '@antv/webgpu-graph'; 32 | 33 | // 初始化 34 | const graph = new WebGPUGraph(); 35 | 36 | const result = await graph.pageRank(graph_data, eps, alpha, max_iter); 37 | ``` 38 | 39 | ## 性能测试 40 | 41 | | 算法名 | 节点 / 边 | CPU 耗时 | GPU 耗时 | Speed up | 42 | | -------- | --------------- | ----------- | --------- | -------- | 43 | | SSSP | 1k 节点 5k 边 | 27687.10 ms | 261.60 ms | ~100x | 44 | | PageRank | 1k 节点 500k 边 | 13641.50 ms | 130.20 ms | ~100x | 45 | -------------------------------------------------------------------------------- /packages/graph/tests/unit/mst-spec.ts: -------------------------------------------------------------------------------- 1 | import { minimumSpanningTree } from '../../src'; 2 | 3 | const data = { 4 | nodes: [ 5 | { 6 | id: 'A', 7 | }, 8 | { 9 | id: 'B', 10 | }, 11 | { 12 | id: 'C', 13 | }, 14 | { 15 | id: 'D', 16 | }, 17 | { 18 | id: 'E', 19 | }, 20 | { 21 | 
id: 'F', 22 | }, 23 | { 24 | id: 'G', 25 | }, 26 | ], 27 | edges: [ 28 | { 29 | source: 'A', 30 | target: 'B', 31 | weight: 1, 32 | }, 33 | { 34 | source: 'B', 35 | target: 'C', 36 | weight: 1, 37 | }, 38 | { 39 | source: 'A', 40 | target: 'C', 41 | weight: 2, 42 | }, 43 | { 44 | source: 'D', 45 | target: 'A', 46 | weight: 3, 47 | }, 48 | { 49 | source: 'D', 50 | target: 'E', 51 | weight: 4, 52 | }, 53 | { 54 | source: 'E', 55 | target: 'F', 56 | weight: 2, 57 | }, 58 | { 59 | source: 'F', 60 | target: 'D', 61 | weight: 3, 62 | }, 63 | ], 64 | }; 65 | 66 | describe('minimumSpanningTree', () => { 67 | it('test kruskal algorithm', () => { 68 | let result = minimumSpanningTree(data, 'weight'); 69 | let totalWeight = 0; 70 | for (let edge of result) { 71 | totalWeight += edge.weight; 72 | } 73 | expect(totalWeight).toEqual(10); 74 | }); 75 | 76 | it('test prim algorithm', () => { 77 | let result = minimumSpanningTree(data, 'weight', 'prim'); 78 | let totalWeight = 0; 79 | for (let edge of result) { 80 | totalWeight += edge.weight; 81 | } 82 | expect(totalWeight).toEqual(10); 83 | }); 84 | }); 85 | -------------------------------------------------------------------------------- /packages/graph/src/asyncIndex.ts: -------------------------------------------------------------------------------- 1 | import { 2 | getAdjMatrixAsync, 3 | connectedComponentAsync, 4 | getDegreeAsync, 5 | getInDegreeAsync, 6 | getOutDegreeAsync, 7 | detectCycleAsync, 8 | detectAllCyclesAsync, 9 | detectAllDirectedCycleAsync, 10 | detectAllUndirectedCycleAsync, 11 | dijkstraAsync, 12 | findAllPathAsync, 13 | findShortestPathAsync, 14 | floydWarshallAsync, 15 | labelPropagationAsync, 16 | louvainAsync, 17 | minimumSpanningTreeAsync, 18 | pageRankAsync, 19 | getNeighborsAsync, 20 | GADDIAsync, 21 | } from './workers/index'; 22 | 23 | const detectDirectedCycleAsync = detectCycleAsync; 24 | 25 | export { 26 | getAdjMatrixAsync, 27 | connectedComponentAsync, 28 | getDegreeAsync, 29 | getInDegreeAsync, 
30 | getOutDegreeAsync, 31 | detectCycleAsync, 32 | detectDirectedCycleAsync, 33 | detectAllCyclesAsync, 34 | detectAllDirectedCycleAsync, 35 | detectAllUndirectedCycleAsync, 36 | dijkstraAsync, 37 | findAllPathAsync, 38 | findShortestPathAsync, 39 | floydWarshallAsync, 40 | labelPropagationAsync, 41 | louvainAsync, 42 | minimumSpanningTreeAsync, 43 | pageRankAsync, 44 | getNeighborsAsync, 45 | GADDIAsync, 46 | }; 47 | 48 | export default { 49 | getAdjMatrixAsync, 50 | connectedComponentAsync, 51 | getDegreeAsync, 52 | getInDegreeAsync, 53 | getOutDegreeAsync, 54 | detectCycleAsync, 55 | detectDirectedCycleAsync, 56 | detectAllCyclesAsync, 57 | detectAllDirectedCycleAsync, 58 | detectAllUndirectedCycleAsync, 59 | dijkstraAsync, 60 | findAllPathAsync, 61 | findShortestPathAsync, 62 | floydWarshallAsync, 63 | labelPropagationAsync, 64 | louvainAsync, 65 | minimumSpanningTreeAsync, 66 | pageRankAsync, 67 | getNeighborsAsync, 68 | GADDIAsync, 69 | }; -------------------------------------------------------------------------------- /packages/graph/tests/unit/cosineSimilarity-spec.ts: -------------------------------------------------------------------------------- 1 | import { cosineSimilarity } from '../../src'; 2 | 3 | describe('cosineSimilarity abnormal demo: ', () => { 4 | it('item contains only zeros: ', () => { 5 | const item = [0, 0, 0]; 6 | const targetTtem = [3, 1, 1]; 7 | const cosineSimilarityValue = cosineSimilarity(item, targetTtem); 8 | expect(cosineSimilarityValue).toBe(0); 9 | }); 10 | it('targetTtem contains only zeros: ', () => { 11 | const item = [3, 5, 2]; 12 | const targetTtem = [0, 0, 0]; 13 | const cosineSimilarityValue = cosineSimilarity(item, targetTtem); 14 | expect(cosineSimilarityValue).toBe(0); 15 | }); 16 | it('item and targetTtem both contains only zeros: ', () => { 17 | const item = [0, 0, 0]; 18 | const targetTtem = [0, 0, 0]; 19 | const cosineSimilarityValue = cosineSimilarity(item, targetTtem); 20 | expect(cosineSimilarityValue).toBe(0); 
/**
 * Entry point for the WebGPU-backed graph algorithms.
 *
 * Creates a 1x1 `Canvas` rendered by a `Renderer` targeting `webgpu` with the GPGPU
 * plugin registered, and exposes the algorithms as async methods that run on the
 * device obtained from that renderer.
 */
export class WebGPUGraph {
  canvas: Canvas;
  renderer: Renderer;

  /**
   * @param options optional settings; `canvas` supplies the element to render into.
   *   When it is omitted a `<canvas>` element is created via `window.document`.
   */
  constructor(options: Partial<WebGPUGraphOptions> = {}) {
    const { canvas } = options;

    // FIXME: use OffscreenCanvas instead of a real DOM
    const $canvas = canvas || window.document.createElement('canvas');

    // use WebGPU
    this.renderer = new Renderer({ targets: ['webgpu'] });
    this.renderer.registerPlugin(new Plugin());

    // create a canvas
    this.canvas = new Canvas({
      canvas: $canvas,
      width: 1,
      height: 1,
      renderer: this.renderer,
    });
  }

  /**
   * Wait until the canvas reports ready, then return the renderer's device.
   */
  private async getDevice() {
    // wait for canvas' services ready
    await this.canvas.ready;
    // get GPU Device
    return this.renderer.getDevice();
  }

  /**
   * Run PageRank on the device.
   * @param graphData graph data
   * @param eps forwarded to `pageRank` (default 1e-5)
   * @param alpha forwarded to `pageRank` (default 0.85)
   * @param maxIteration forwarded to `pageRank` (default 1000)
   * @returns whatever the imported `pageRank` implementation resolves to
   */
  async pageRank(graphData: GraphData, eps = 1e-5, alpha = 0.85, maxIteration = 1000) {
    const device = await this.getDevice();
    return pageRank(device, graphData, eps, alpha, maxIteration);
  }

  /**
   * Run single-source shortest path on the device.
   * @param graphData graph data
   * @param sourceId id of the source node
   * @param weightPropertyName forwarded to `sssp` (default empty string)
   * @returns whatever the imported `sssp` implementation resolves to
   */
  async sssp(graphData: GraphData, sourceId: string, weightPropertyName: string = '') {
    const device = await this.getDevice();
    return sssp(device, graphData, sourceId, weightPropertyName);
  }

  /**
   * Destroy the underlying canvas, if one exists.
   */
  destroy() {
    if (this.canvas) {
      this.canvas.destroy();
    }
  }
}
degree, out degree, and total degree for nodes 21 | 22 | - **Path** 23 | - dijkstra: Dijkstra shortest path algorithm 24 | - findPath: Find the shortest paths and all paths for two nodes by Dijkstra 25 | - floydWarshall: Floyd Warshall shortest path algorithm 26 | 27 | - **Other** 28 | - neighbors: Find the neighbors for a node in the graph 29 | - GADDI: graph structural and semantic pattern matching algorithm 30 | - detectCycle: Detect the cycles of the graph data 31 | - dfs: Depth-First search algorithm 32 | - adjacentMatrix: calculate the adjacency matrix for graph data 33 | - connectedComponent: Calculate the connected components for graph data 34 | 35 | All the algorithms above can be run in a web worker. -------------------------------------------------------------------------------- /packages/graph/tests/unit/queue-spec.ts: -------------------------------------------------------------------------------- 1 | import Queue from '../../src/structs/queue'; 2 | 3 | describe('Queue', () => { 4 | it('should create empty queue', () => { 5 | const queue = new Queue(); 6 | expect(queue).not.toBeNull(); 7 | expect(queue.linkedList).not.toBeNull(); 8 | }); 9 | 10 | it('should enqueue data to queue', () => { 11 | const queue = new Queue(); 12 | 13 | queue.enqueue(1); 14 | queue.enqueue(2); 15 | 16 | expect(queue.toString()).toBe('1,2'); 17 | }); 18 | 19 | it('should be possible to enqueue/dequeue objects', () => { 20 | const queue = new Queue(); 21 | 22 | queue.enqueue({ value: 'test1', key: 'key1' }); 23 | queue.enqueue({ value: 'test2', key: 'key2' }); 24 | 25 | const stringifier = (value) => `${value.key}:${value.value}`; 26 | 27 | expect(queue.toString(stringifier)).toBe('key1:test1,key2:test2'); 28 | expect(queue.dequeue().value).toBe('test1'); 29 | expect(queue.dequeue().value).toBe('test2'); 30 | }); 31 | 32 | it('should peek data from queue', () => { 33 | const queue = new Queue(); 34 | 35 | expect(queue.peek()).toBeNull(); 36 | 37 | 
queue.enqueue(1); 38 | queue.enqueue(2); 39 | 40 | expect(queue.peek()).toBe(1); 41 | expect(queue.peek()).toBe(1); 42 | }); 43 | 44 | it('should check if queue is empty', () => { 45 | const queue = new Queue(); 46 | 47 | expect(queue.isEmpty()).toBe(true); 48 | 49 | queue.enqueue(1); 50 | 51 | expect(queue.isEmpty()).toBe(false); 52 | }); 53 | 54 | it('should dequeue from queue in FIFO order', () => { 55 | const queue = new Queue(); 56 | 57 | queue.enqueue(1); 58 | queue.enqueue(2); 59 | 60 | expect(queue.dequeue()).toBe(1); 61 | expect(queue.dequeue()).toBe(2); 62 | expect(queue.dequeue()).toBeNull(); 63 | expect(queue.isEmpty()).toBe(true); 64 | }); 65 | }); 66 | -------------------------------------------------------------------------------- /packages/graph/tests/unit/stack-spec.ts: -------------------------------------------------------------------------------- 1 | import Stack from '../../src/structs/stack'; 2 | 3 | describe('stack unit test', () => { 4 | it('init stack', () => { 5 | const stack = new Stack(); 6 | for (let i = 0; i < 4; i++) { 7 | stack.push({ 8 | nodes: [ 9 | { 10 | id: `node${i}`, 11 | }, 12 | ], 13 | }); 14 | } 15 | 16 | const result = stack.pop(); 17 | // console.log(stack.toArray()); 18 | expect(result).toEqual({ 19 | nodes: [ 20 | { 21 | id: 'node3', 22 | }, 23 | ], 24 | }); 25 | 26 | expect(stack.peek()).toEqual({ 27 | nodes: [ 28 | { 29 | id: 'node2', 30 | }, 31 | ], 32 | }); 33 | 34 | expect(stack.isMaxStack()).toBe(false); 35 | expect(stack.isEmpty()).toBe(false); 36 | 37 | stack.push({ 38 | nodes: [ 39 | { 40 | id: 'node5', 41 | }, 42 | ], 43 | }); 44 | stack.push({ 45 | nodes: [ 46 | { 47 | id: 'node6', 48 | }, 49 | ], 50 | }); 51 | 52 | expect(stack.isMaxStack()).toBe(false); 53 | stack.clear() 54 | expect(stack.length).toBe(0) 55 | }); 56 | 57 | it('init stack with maxStep', () => { 58 | const stack = new Stack(3); 59 | for (let i = 0; i < 5; i++) { 60 | stack.push({ 61 | nodes: [ 62 | { 63 | id: `node${i}`, 64 | }, 65 | ], 66 | 
}); 67 | } 68 | expect(stack.length).toBe(3); 69 | 70 | expect(stack.toArray()).toEqual([ 71 | { 72 | nodes: [{ id: 'node4'}] 73 | }, 74 | { 75 | nodes: [{ id: 'node3'}] 76 | }, 77 | { 78 | nodes: [{ id: 'node2'}] 79 | } 80 | ]) 81 | 82 | stack.clear() 83 | expect(stack.length).toBe(0) 84 | }); 85 | }); 86 | -------------------------------------------------------------------------------- /packages/webgpu-graph/webpack.config.js: -------------------------------------------------------------------------------- 1 | const webpack = require('webpack'); 2 | const resolve = require('path').resolve; 3 | 4 | module.exports = { 5 | entry: { 6 | index: './src/index.ts', 7 | }, 8 | output: { 9 | filename: '[name].min.js', 10 | library: 'WebGPUGraph', 11 | libraryTarget: 'umd', 12 | libraryExport: 'default', 13 | path: resolve(process.cwd(), 'dist/'), 14 | globalObject: 'this', 15 | publicPath: './dist', 16 | }, 17 | watchOptions: { 18 | ignored: /node_modules/ 19 | }, 20 | resolve: { 21 | extensions: ['.ts', '.js'], 22 | }, 23 | module: { 24 | rules: [ 25 | { 26 | test: /\.worker\.ts$/, 27 | exclude: /(node_modules)/, 28 | use: [ 29 | { 30 | loader: 'worker-loader', 31 | options: { 32 | inline: 'fallback', 33 | filename: 'index.worker.js', 34 | }, 35 | }, 36 | ], 37 | }, 38 | { 39 | test: /\.js$/, 40 | include: /node_modules/, 41 | use: { 42 | loader: 'babel-loader', 43 | options: { 44 | presets: [ 45 | [ 46 | '@babel/preset-env', 47 | { 48 | loose: true, 49 | modules: false, 50 | }, 51 | ], 52 | { 53 | plugins: ['@babel/plugin-proposal-class-properties'], 54 | }, 55 | ], 56 | }, 57 | }, 58 | }, 59 | { 60 | test: /\.ts$/, 61 | use: { 62 | loader: 'ts-loader', 63 | options: { 64 | transpileOnly: true, 65 | }, 66 | }, 67 | }, 68 | ], 69 | }, 70 | plugins: [new webpack.NoEmitOnErrorsPlugin(), new webpack.optimize.AggressiveMergingPlugin()], 71 | devtool: 'source-map', 72 | }; 73 | -------------------------------------------------------------------------------- 
/packages/graph/webpack.config.js: -------------------------------------------------------------------------------- 1 | const webpack = require('webpack'); 2 | const resolve = require('path').resolve; 3 | 4 | module.exports = { 5 | entry: { 6 | index: './src/index.ts', 7 | async: './src/asyncIndex.ts' 8 | }, 9 | output: { 10 | filename: '[name].min.js', 11 | library: 'Algorithm', 12 | libraryTarget: 'umd', 13 | libraryExport: 'default', 14 | path: resolve(process.cwd(), 'dist/'), 15 | globalObject: 'this', 16 | publicPath: './dist', 17 | }, 18 | watchOptions: { 19 | ignored: /node_modules/ 20 | }, 21 | resolve: { 22 | extensions: ['.ts', '.js'], 23 | }, 24 | module: { 25 | rules: [ 26 | { 27 | test: /\.worker\.ts$/, 28 | exclude: /(node_modules)/, 29 | use: [ 30 | { 31 | loader: 'worker-loader', 32 | options: { 33 | inline: 'fallback', 34 | filename: 'index.worker.js', 35 | }, 36 | }, 37 | ], 38 | }, 39 | { 40 | test: /\.js$/, 41 | include: /node_modules/, 42 | use: { 43 | loader: 'babel-loader', 44 | options: { 45 | presets: [ 46 | [ 47 | '@babel/preset-env', 48 | { 49 | loose: true, 50 | modules: false, 51 | }, 52 | ], 53 | { 54 | plugins: ['@babel/plugin-proposal-class-properties'], 55 | }, 56 | ], 57 | }, 58 | }, 59 | }, 60 | { 61 | test: /\.ts$/, 62 | use: { 63 | loader: 'ts-loader', 64 | options: { 65 | transpileOnly: true, 66 | }, 67 | }, 68 | }, 69 | ], 70 | }, 71 | plugins: [new webpack.NoEmitOnErrorsPlugin(), new webpack.optimize.AggressiveMergingPlugin()], 72 | devtool: 'source-map', 73 | }; 74 | -------------------------------------------------------------------------------- /packages/webgpu-graph/README.md: -------------------------------------------------------------------------------- 1 | # AntV Graph Algorithm based on WebGPU 2 | 3 | `webgpu-graph` is a GPU accelerated graph analytics library, with functionality like [WebGPU](https://www.w3.org/TR/webgpu/) which provides modern features such as compute shader(in [WGSL](https://www.w3.org/TR/WGSL/)). 
/**
 * Nodes cosine similarity: compute, from node properties, the cosine similarity between
 * a seed node and every other node (i.e. find the nodes most similar to the seed).
 * @param nodes graph node data
 * @param seedNode seed node
 * @param propertyKey field name holding the node properties
 * @param involvedKeys keys that take part in the calculation
 * @param uninvolvedKeys keys excluded from the calculation
 * @returns `allCosineSimilarity`: one value per non-seed node, in the order those nodes
 *   appear in `nodes`; `similarNodes`: cloned non-seed nodes, each tagged with a
 *   `cosineSimilarity` field and sorted by it in descending order
 */
const nodesCosineSimilarity = (
  nodes: NodeConfig[] = [],
  seedNode: NodeConfig,
  propertyKey: string = undefined,
  involvedKeys: string[] = [],
  uninvolvedKeys: string[] = [],
): {
  allCosineSimilarity: number[],
  similarNodes: NodeConfig[],
} => {
  // Record where each non-seed node sits in `nodes`. The feature vectors below are
  // looked up by position in `nodes` (the seed vector is read that way), so positions
  // in the filtered list must not be used: they shift by one after the seed node.
  const candidateIndexes: number[] = [];
  nodes.forEach((node, index) => {
    if (node.id !== seedNode.id) {
      candidateIndexes.push(index);
    }
  });
  const similarNodes = clone(candidateIndexes.map(index => nodes[index]));
  const seedNodeIndex = nodes.findIndex(node => node.id === seedNode.id);
  // Properties of all nodes
  const properties = getAllProperties(nodes, propertyKey);
  // One-hot feature vectors of all node properties
  const allPropertiesWeight = oneHot(properties, involvedKeys, uninvolvedKeys);
  // Feature vector of the seed node
  const seedNodeProperties = allPropertiesWeight[seedNodeIndex];

  const allCosineSimilarity: number[] = [];
  similarNodes.forEach((node, position) => {
    // Feature vector of this node, addressed by its original position in `nodes`
    const nodeProperties = allPropertiesWeight[candidateIndexes[position]];
    // Cosine similarity between this node's vector and the seed node's vector
    const cosineSimilarityValue = cosineSimilarity(nodeProperties, seedNodeProperties);
    allCosineSimilarity.push(cosineSimilarityValue);
    node.cosineSimilarity = cosineSimilarityValue;
  });

  // Return the nodes ordered from most to least similar
  similarNodes.sort((a, b) => b.cosineSimilarity - a.cosineSimilarity);
  return { allCosineSimilarity, similarNodes };
}
/**
 * Generate an ID of the form `<index>-<ten random decimal digits>`.
 * The suffix comes from `Math.random()`, so uniqueness is probabilistic, not guaranteed.
 * @param index sequence number used as the ID prefix
 * @returns the generated ID string
 */
export const uniqueId = (index: number = 0) => {
  // Derive five digits arithmetically rather than by slicing the printed number:
  // `String(Math.random())` has no '.' for values such as 0 or 1e-7, and a short
  // fraction such as 0.5 would give fewer than five digits. Adding 100000 and
  // dropping the leading '1' zero-pads the result to exactly five characters.
  const fiveRandomDigits = () => String(Math.floor(Math.random() * 100000) + 100000).slice(1);
  return `${index}-${fiveRandomDigits()}${fiveRandomDigits()}`;
};
/**
 * Look up the shortest-path data for `end` in the result of running `dijkstra`
 * from `start`.
 * @param graphData graph data
 * @param start id of the start node
 * @param end id of the end node
 * @param directed forwarded to `dijkstra`
 * @param weightPropertyName forwarded to `dijkstra`
 * @returns the `length`, `path` and `allPath` entries that `dijkstra` produced for `end`
 */
export const findShortestPath = (
  graphData: GraphData,
  start: string,
  end: string,
  directed?: boolean,
  weightPropertyName?: string
) => {
  const fromStart = dijkstra(graphData, start, directed, weightPropertyName);
  return {
    length: fromStart.length[end],
    path: fromStart.path[end],
    allPath: fromStart.allPath[end],
  };
};
/**
 * Build the full callback set used by the traversal, filling in defaults for any
 * callback the caller did not supply.
 *
 * The caller's object is copied, never modified. Previously the defaults were written
 * back onto it, so passing the same object to a second search reused the first
 * search's visited-set closure and wrongly skipped nodes.
 *
 * @param callbacks optional user callbacks (`allowTraversal`, `enter`, `leave`)
 * @returns a new callbacks object in which all three callbacks are defined
 */
function initCallbacks(callbacks: IAlgorithmCallbacks = {} as IAlgorithmCallbacks) {
  const initiatedCallback = { ...callbacks } as IAlgorithmCallbacks;

  const stubCallback = () => {};

  // Default traversal rule: allow each `next` node only the first time it is offered.
  const allowTraversalCallback = (() => {
    // Null-prototype map, so ids such as "constructor" or "toString" are not treated
    // as already seen through inherited Object.prototype members.
    const seen = Object.create(null);
    return ({ next }) => {
      if (!seen[next]) {
        seen[next] = true;
        return true;
      }
      return false;
    };
  })();

  initiatedCallback.allowTraversal = callbacks.allowTraversal || allowTraversalCallback;
  initiatedCallback.enter = callbacks.enter || stubCallback;
  initiatedCallback.leave = callbacks.leave || stubCallback;

  return initiatedCallback;
}
/**
 * Depth-first traversal of a graph.
 * @param graphData graph data
 * @param startNodeId id of the node the traversal starts from
 * @param callbacks optional traversal callbacks; passed through `initCallbacks` first
 * @param directed forwarded to the recursive traversal (default `true`)
 */
export default function depthFirstSearch(
  graphData: GraphData,
  startNodeId: string,
  callbacks?: IAlgorithmCallbacks,
  directed: boolean = true,
) {
  const preparedCallbacks = initCallbacks(callbacks);
  // The start node has no predecessor, so an empty id is passed as `previousNode`.
  const noPreviousNode = '';
  depthFirstSearchRecursive(graphData, startNodeId, noPreviousNode, preparedCallbacks, directed);
}
/**
 * Array-backed binary min-heap. The element for which `compareFn` ranks lowest is
 * kept at index 0 of `list`.
 */
export default class MinBinaryHeap {
  list: any[];

  compareFn: (a: any, b: any) => number;

  /**
   * @param compareFn ordering function; a positive result means the first argument
   *   ranks after the second. Defaults to `defaultCompare`.
   */
  constructor(compareFn = defaultCompare) {
    this.compareFn = compareFn;
    this.list = [];
  }

  /** Index of the left child of `index`. */
  getLeft(index) {
    return 2 * index + 1;
  }

  /** Index of the right child of `index`. */
  getRight(index) {
    return 2 * index + 2;
  }

  /** Index of the parent of `index`, or `null` for the root. */
  getParent(index) {
    if (index === 0) {
      return null;
    }
    return Math.floor((index - 1) / 2);
  }

  /** Whether the heap holds no elements. */
  isEmpty() {
    return this.list.length <= 0;
  }

  /** The minimum element without removing it, or `undefined` when the heap is empty. */
  top() {
    return this.isEmpty() ? undefined : this.list[0];
  }

  /**
   * Remove and return the minimum element.
   * @returns the removed element, or `undefined` when the heap is empty
   */
  delMin() {
    const top = this.top();
    const bottom = this.list.pop();
    if (this.list.length > 0) {
      // Move the last element to the root, then sink it back into position.
      this.list[0] = bottom;
      this.moveDown(0);
    }
    return top;
  }

  /**
   * Add a value to the heap.
   * @returns `true` when the value was inserted, `false` when it was `null`
   */
  insert(value) {
    if (value !== null) {
      this.list.push(value);
      const index = this.list.length - 1;
      this.moveUp(index);
      return true;
    }
    return false;
  }

  /** Swap the element at `index` with its parent until the parent no longer ranks after it. */
  moveUp(index) {
    let parent = this.getParent(index);
    while (index && index > 0 && this.compareFn(this.list[parent], this.list[index]) > 0) {
      // swap
      const tmp = this.list[parent];
      this.list[parent] = this.list[index];
      this.list[index] = tmp;
      index = parent;
      parent = this.getParent(index);
    }
  }

  /** Sink the element at `index` until neither child ranks before it. */
  moveDown(index) {
    let element = index;
    const left = this.getLeft(index);
    const right = this.getRight(index);
    const size = this.list.length;
    if (left < size && this.compareFn(this.list[element], this.list[left]) > 0) {
      element = left;
    }
    // Compare the right child against the current best candidate, not only when the
    // left child lost. The swap must go to the smaller child, otherwise the promoted
    // child can rank after its sibling and the heap order is broken.
    if (right < size && this.compareFn(this.list[element], this.list[right]) > 0) {
      element = right;
    }
    if (index !== element) {
      [this.list[index], this.list[element]] = [this.list[element], this.list[index]];
      this.moveDown(element);
    }
  }
}
| id: 'G', 25 | }, 26 | { 27 | id: 'H', 28 | }, 29 | ], 30 | edges: [ 31 | { 32 | source: 'A', 33 | target: 'B', 34 | }, 35 | { 36 | source: 'B', 37 | target: 'C', 38 | }, 39 | { 40 | source: 'A', 41 | target: 'C', 42 | }, 43 | { 44 | source: 'D', 45 | target: 'A', 46 | }, 47 | { 48 | source: 'D', 49 | target: 'E', 50 | }, 51 | { 52 | source: 'E', 53 | target: 'F', 54 | }, 55 | { 56 | source: 'F', 57 | target: 'D', 58 | }, 59 | { 60 | source: 'G', 61 | target: 'H', 62 | }, 63 | { 64 | source: 'H', 65 | target: 'G', 66 | }, 67 | ], 68 | }; 69 | 70 | describe('algorithm util method', () => { 71 | it('getNeighbors', () => { 72 | 73 | let result = getNeighbors('A', data.edges); 74 | expect(result).toEqual(['B', 'C', 'D']); 75 | 76 | result = getNeighbors('A', data.edges, 'source') 77 | expect(result).toEqual(['D']) 78 | 79 | result = getNeighbors('A', data.edges, 'target') 80 | expect(result).toEqual(['B', 'C']) 81 | }); 82 | 83 | it('getEdgesByNodeId', () => { 84 | const aEdges = [ 85 | { 86 | source: 'A', 87 | target: 'B', 88 | }, 89 | { 90 | source: 'A', 91 | target: 'C', 92 | }, 93 | { 94 | source: 'D', 95 | target: 'A', 96 | }, 97 | ] 98 | let result = getEdgesByNodeId('A', data.edges); 99 | expect(result).toEqual(aEdges); 100 | }); 101 | 102 | it('getOutEdgesNodeId', () => { 103 | const aEdges = [ 104 | { 105 | source: 'A', 106 | target: 'B', 107 | }, 108 | { 109 | source: 'A', 110 | target: 'C', 111 | } 112 | ] 113 | let result = getOutEdgesNodeId('A', data.edges); 114 | expect(result).toEqual(aEdges); 115 | }); 116 | }); 117 | -------------------------------------------------------------------------------- /packages/graph/src/index.ts: -------------------------------------------------------------------------------- 1 | import getAdjMatrix from './adjacent-matrix'; 2 | import breadthFirstSearch from './bfs'; 3 | import connectedComponent from './connected-component'; 4 | import getDegree from './degree'; 5 | import { getInDegree, getOutDegree } from './degree'; 
6 | import detectCycle, { detectAllCycles, detectAllDirectedCycle, detectAllUndirectedCycle } from './detect-cycle'; 7 | import depthFirstSearch from './dfs'; 8 | import dijkstra from './dijkstra'; 9 | import { findAllPath, findShortestPath } from './find-path'; 10 | import floydWarshall from './floydWarshall'; 11 | import labelPropagation from './label-propagation'; 12 | import louvain from './louvain'; 13 | import iLouvain from './i-louvain'; 14 | import kCore from './k-core'; 15 | import kMeans from './k-means'; 16 | import cosineSimilarity from './cosine-similarity'; 17 | import nodesCosineSimilarity from './nodes-cosine-similarity'; 18 | import minimumSpanningTree from './mts'; 19 | import pageRank from './pageRank'; 20 | import GADDI from './gaddi'; 21 | import Stack from './structs/stack'; 22 | import { getNeighbors } from './util'; 23 | import { IAlgorithm } from './types'; 24 | 25 | const detectDirectedCycle = detectCycle; 26 | 27 | export { 28 | getAdjMatrix, 29 | breadthFirstSearch, 30 | connectedComponent, 31 | getDegree, 32 | getInDegree, 33 | getOutDegree, 34 | detectCycle, 35 | detectDirectedCycle, 36 | detectAllCycles, 37 | detectAllDirectedCycle, 38 | detectAllUndirectedCycle, 39 | depthFirstSearch, 40 | dijkstra, 41 | findAllPath, 42 | findShortestPath, 43 | floydWarshall, 44 | labelPropagation, 45 | louvain, 46 | iLouvain, 47 | kCore, 48 | kMeans, 49 | cosineSimilarity, 50 | nodesCosineSimilarity, 51 | minimumSpanningTree, 52 | pageRank, 53 | getNeighbors, 54 | Stack, 55 | GADDI, 56 | IAlgorithm 57 | }; 58 | 59 | export default { 60 | getAdjMatrix, 61 | breadthFirstSearch, 62 | connectedComponent, 63 | getDegree, 64 | getInDegree, 65 | getOutDegree, 66 | detectCycle, 67 | detectDirectedCycle, 68 | detectAllCycles, 69 | detectAllDirectedCycle, 70 | detectAllUndirectedCycle, 71 | depthFirstSearch, 72 | dijkstra, 73 | findAllPath, 74 | findShortestPath, 75 | floydWarshall, 76 | labelPropagation, 77 | louvain, 78 | iLouvain, 79 | kCore, 80 | kMeans, 
81 | cosineSimilarity, 82 | nodesCosineSimilarity, 83 | minimumSpanningTree, 84 | pageRank, 85 | getNeighbors, 86 | Stack, 87 | GADDI, 88 | }; -------------------------------------------------------------------------------- /packages/graph/src/bfs.ts: -------------------------------------------------------------------------------- 1 | import Queue from './structs/queue' 2 | import { GraphData, IAlgorithmCallbacks } from './types'; 3 | import { getNeighbors } from './util'; 4 | 5 | /** 6 | * 7 | * @param callbacks 8 | * allowTraversal: 确定 BFS 是否从顶点沿着边遍历到其邻居,默认情况下,同一个节点只能遍历一次 9 | * enter: 当 BFS 访问某个节点时调用 10 | * leave: 当 BFS 访问结束某个节点时调用 11 | */ 12 | function initCallbacks(callbacks: IAlgorithmCallbacks = {} as IAlgorithmCallbacks) { 13 | const initiatedCallback = callbacks; 14 | 15 | const stubCallback = () => {}; 16 | 17 | const allowTraversalCallback = (() => { 18 | const seen = {}; 19 | return ({ next }) => { 20 | const id = next; 21 | if (!seen[id]) { 22 | seen[id] = true; 23 | return true; 24 | } 25 | return false; 26 | }; 27 | })(); 28 | 29 | initiatedCallback.allowTraversal = callbacks.allowTraversal || allowTraversalCallback; 30 | initiatedCallback.enter = callbacks.enter || stubCallback; 31 | initiatedCallback.leave = callbacks.leave || stubCallback; 32 | 33 | return initiatedCallback; 34 | } 35 | 36 | /** 37 | * 广度优先遍历图 38 | * @param graphData Graph 图数据 39 | * @param startNodeId 开始遍历的节点 40 | * @param originalCallbacks 回调 41 | */ 42 | const breadthFirstSearch = ( 43 | graphData: GraphData, 44 | startNodeId: string, 45 | originalCallbacks?: IAlgorithmCallbacks, 46 | directed: boolean = true 47 | ) => { 48 | const callbacks = initCallbacks(originalCallbacks); 49 | const nodeQueue = new Queue(); 50 | 51 | const { edges = [] } = graphData 52 | 53 | // 初始化队列元素 54 | nodeQueue.enqueue(startNodeId); 55 | 56 | let previousNode = ''; 57 | 58 | // 遍历队列中的所有顶点 59 | while (!nodeQueue.isEmpty()) { 60 | const currentNode: string = nodeQueue.dequeue(); 61 |
callbacks.enter({ 62 | current: currentNode, 63 | previous: previousNode, 64 | }); 65 | 66 | // 将所有邻居添加到队列中以便遍历 67 | getNeighbors(currentNode, edges, directed ? 'target' : undefined).forEach((nextNode) => { 68 | if ( 69 | callbacks.allowTraversal({ 70 | previous: previousNode, 71 | current: currentNode, 72 | next: nextNode, 73 | }) 74 | ) { 75 | nodeQueue.enqueue(nextNode); 76 | } 77 | }); 78 | 79 | callbacks.leave({ 80 | current: currentNode, 81 | previous: previousNode, 82 | }); 83 | 84 | // 下一次循环之前存储当前顶点 85 | previousNode = currentNode; 86 | } 87 | }; 88 | 89 | export default breadthFirstSearch; 90 | -------------------------------------------------------------------------------- /packages/graph/src/pageRank.ts: -------------------------------------------------------------------------------- 1 | import { GraphData } from "./types"; 2 | import degree from './degree' 3 | import { getNeighbors } from "./util"; 4 | 5 | /** 6 | * PageRank https://en.wikipedia.org/wiki/PageRank 7 | * refer: https://github.com/anvaka/ngraph.pagerank 8 | * @param graphData 9 | * @param epsilon 判断是否收敛的精度值,默认 0.000001 10 | * @param linkProb 阻尼系数(damping factor),指任意时刻,用户访问到某节点后继续访问该节点链接的下一个节点的概率,经验值 0.85 11 | */ 12 | const pageRank = (graphData: GraphData, epsilon?: number, linkProb?: number): { 13 | [key: string]: number; 14 | } => { 15 | if (typeof epsilon !== 'number') epsilon = 0.000001; 16 | if (typeof linkProb !== 'number') linkProb = 0.85; 17 | 18 | let distance = 1; 19 | let leakedRank = 0; 20 | let maxIterations = 1000; 21 | 22 | const { nodes = [], edges = [] } = graphData; 23 | const nodesCount = nodes.length; 24 | let currentRank; 25 | const curRanks = {}; 26 | const prevRanks = {} 27 | 28 | // Initialize pageranks 初始化 29 | for (let j = 0; j < nodesCount; ++j) { 30 | const node = nodes[j]; 31 | const nodeId = node.id; 32 | curRanks[nodeId] = (1 / nodesCount) 33 | prevRanks[nodeId] = (1 / nodesCount) 34 | } 35 | 36 | const nodeDegree = degree(graphData) 37 | while (maxIterations > 0
&& distance > epsilon) { 38 | leakedRank = 0; 39 | for (let j = 0; j < nodesCount; ++j) { 40 | const node = nodes[j]; 41 | const nodeId = node.id; 42 | currentRank = 0; 43 | if (nodeDegree[node.id].inDegree === 0) { 44 | curRanks[nodeId] = 0; 45 | } else { 46 | const neighbors = getNeighbors(nodeId, edges, 'source'); 47 | for (let i = 0; i < neighbors.length; ++i) { 48 | const neighbor = neighbors[i]; 49 | const outDegree: number = nodeDegree[neighbor].outDegree; 50 | if (outDegree > 0) currentRank += (prevRanks[neighbor] / outDegree); 51 | } 52 | curRanks[nodeId] = linkProb * currentRank; 53 | leakedRank += curRanks[nodeId]; 54 | } 55 | } 56 | 57 | leakedRank = (1 - leakedRank) / nodesCount; 58 | distance = 0; 59 | for (let j = 0; j < nodesCount; ++j) { 60 | const node = nodes[j]; 61 | const nodeId = node.id; 62 | currentRank = curRanks[nodeId] + leakedRank; 63 | distance += Math.abs(currentRank - prevRanks[nodeId]); 64 | prevRanks[nodeId] = currentRank; 65 | } 66 | maxIterations -= 1 67 | } 68 | 69 | return prevRanks; 70 | } 71 | 72 | export default pageRank 73 | -------------------------------------------------------------------------------- /packages/graph/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "@antv/algorithm", 3 | "version": "0.1.26", 4 | "description": "graph algorithm", 5 | "keywords": [ 6 | "graph", 7 | "algorithm", 8 | "antv", 9 | "G6" 10 | ], 11 | "files": [ 12 | "package.json", 13 | "es", 14 | "lib", 15 | "dist", 16 | "LICENSE", 17 | "README.md" 18 | ], 19 | "main": "lib/index.js", 20 | "module": "es/index.js", 21 | "types": "lib/index.d.ts", 22 | "unpkg": "dist/index.min.js", 23 | "scripts": { 24 | "build": "npm run clean && father build && npm run build:umd", 25 | "build:umd": "webpack --config webpack.config.js --mode production", 26 | "dev:umd": "webpack --config webpack.dev.config.js --mode development", 27 | "ci": "npm run build && npm run coverage", 28 | "clean": "rimraf 
es lib dist", 29 | "coverage": "jest --coverage", 30 | "lint": "eslint --ext .js,.jsx,.ts,.tsx --format=pretty \"./\"", 31 | "lint:src": "eslint --ext .ts --format=pretty \"./src\"", 32 | "prettier": "prettier -c --write \"**/*\"", 33 | "test": "npm run build:umd && jest", 34 | "test-live": "npm run build:umd && DEBUG_MODE=1 jest --watch ./tests/unit/louvain-spec.ts", 35 | "test-live:async": "npm run build:umd && DEBUG_MODE=1 jest --watch ./tests/unit/louvain-async-spec.ts", 36 | "lint-staged:js": "eslint --ext .js,.jsx,.ts,.tsx", 37 | "cdn": "antv-bin upload -n @antv/algorithm" 38 | }, 39 | "homepage": "https://g6.antv.vision", 40 | "bugs": { 41 | "url": "https://github.com/antvis/algorithm/issues" 42 | }, 43 | "repository": { 44 | "type": "git", 45 | "url": "https://github.com/antvis/algorithm" 46 | }, 47 | "license": "MIT", 48 | "author": "https://github.com/orgs/antvis/people", 49 | "devDependencies": { 50 | "@babel/core": "^7.12.10", 51 | "@babel/plugin-proposal-class-properties": "^7.12.1", 52 | "@babel/preset-env": "^7.12.7", 53 | "@babel/preset-typescript": "^7.12.7", 54 | "@types/jest": "^26.0.18", 55 | "@umijs/fabric": "^2.5.6", 56 | "babel-loader": "^8.2.2", 57 | "father": "^2.30.0", 58 | "jest": "^26.6.3", 59 | "jest-electron": "^0.1.11", 60 | "rimraf": "^3.0.2", 61 | "ts-jest": "^26.4.4", 62 | "ts-loader": "^8.0.14", 63 | "tslint": "^6.1.3", 64 | "typescript": "^4.1.3", 65 | "webpack": "^5.17.0", 66 | "webpack-cli": "^4.9.1", 67 | "worker-loader": "^3.0.7" 68 | }, 69 | "dependencies": { 70 | "@antv/util": "^2.0.13", 71 | "tslib": "^2.0.0" 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # ChangeLog 2 | 3 | #### 0.1.26 4 | 5 | - feat: dfs supports undirected graph data; 6 | 7 | #### 0.1.25 8 | 9 | - feat: Optimized data preprocessing coding - when the feature values are all numerical, use the original 
values (plus normalization), and do not use one-hot coding 10 | 11 | #### 0.1.24 12 | 13 | - fix: i-louvain without cluster problem; 14 | 15 | #### 0.1.23 16 | 17 | - perf: k-means algorithm: set K to minimum 18 | 19 | #### 0.1.22 20 | 21 | - fix: k-means algorithm, perf: louvain -- support specified parameters such as propertyKey, involvedKeys and uninvolvedKeys 22 | 23 | #### 0.1.21 24 | 25 | - perf: k-means algorithm -- Optimize parameters and return 26 | 27 | #### 0.1.20 28 | 29 | - feat: add k-means algorithm for nodes clustering 30 | 31 | #### 0.1.19 32 | 33 | - fix: GADDI matched failed problem; 34 | 35 | #### 0.1.18 36 | 37 | - feat: add one-hot data preprocessing; 38 | 39 | #### 0.1.17 40 | 41 | - feat: add cosine-similarity algorithm and nodes-cosine-similarity algorithm; 42 | 43 | #### 0.1.16 44 | 45 | - feat: add i-louvain based on louvain according to academic; 46 | 47 | #### 0.1.15 48 | 49 | - feat: k-core algorithm; 50 | 51 | #### 0.1.14 52 | 53 | - fix: GADDI with proper begin p node; 54 | - feat: louvain with property similarity measure 55 | 56 | #### 0.1.10 57 | 58 | - fix: GADDI with better accuracy; 59 | 60 | #### 0.1.9 61 | 62 | - chore: separate sync and async functions into different entries; 63 | 64 | #### 0.1.8 65 | 66 | - fix: CPU usage increases due to 0.1.3-beta ~ 0.1.3 with publicPath configuration; 67 | - fix: CPU usage increases due to 0.1.6 ~ 0.1.7 with browser output; 68 | - feat: export fix: export detectAllCycles, detectAllDirectedCycle, detectAllUndirectedCycle; 69 | 70 | #### 0.1.6 71 | 72 | - fix: louvain with increased clusterId and node with correct new clusterId; 73 | 74 | #### 0.1.5 75 | 76 | - fix: worker async problem; 77 | - chore: unify allPath and allPaths; 78 | 79 | #### 0.1.2 80 | 81 | - fix: failed to find result problem in dijkstra; 82 | 83 | #### 0.1.1 84 | 85 | - fix: shortestPath with wrong result; 86 | 87 | #### 0.1.0 88 | 89 | - fix: findShortestPath undefined is not iterable; 90 | 91 | #### 0.1.0-beta.3 92
| 93 | - fix: cannot read degree of undefined problem of GADDI; 94 | 95 | #### 0.1.0-beta 96 | 97 | - feat: worker for algorithms; 98 | - feat: gaddi for graph pattern matching; 99 | 100 | #### 0.0.7 101 | 102 | - feat: dijkstra supports finding multiple shortest paths; 103 | -------------------------------------------------------------------------------- /packages/graph/tests/unit/pagerank-spec.ts: -------------------------------------------------------------------------------- 1 | import { pageRank } from '../../src'; 2 | 3 | const data = { 4 | nodes: [ 5 | { 6 | id: 'A', 7 | label: 'A', 8 | }, 9 | { 10 | id: 'B', 11 | label: 'B', 12 | }, 13 | { 14 | id: 'C', 15 | label: 'C', 16 | }, 17 | { 18 | id: 'D', 19 | label: 'D', 20 | }, 21 | { 22 | id: 'E', 23 | label: 'E', 24 | }, 25 | { 26 | id: 'F', 27 | label: 'F', 28 | }, 29 | { 30 | id: 'G', 31 | label: 'G', 32 | }, 33 | { 34 | id: 'H', 35 | label: 'H', 36 | }, 37 | { 38 | id: 'I', 39 | label: 'I', 40 | }, 41 | { 42 | id: 'J', 43 | label: 'J', 44 | }, 45 | { 46 | id: 'K', 47 | label: 'K', 48 | } 49 | ], 50 | edges: [ 51 | { 52 | source: 'D', 53 | target: 'A', 54 | }, 55 | { 56 | source: 'D', 57 | target: 'B', 58 | }, 59 | { 60 | source: 'B', 61 | target: 'C', 62 | }, 63 | { 64 | source: 'C', 65 | target: 'B', 66 | }, 67 | { 68 | source: 'F', 69 | target: 'B', 70 | }, 71 | { 72 | source: 'F', 73 | target: 'E', 74 | }, 75 | { 76 | source: 'E', 77 | target: 'F', 78 | }, 79 | { 80 | source: 'E', 81 | target: 'D', 82 | }, 83 | { 84 | source: 'E', 85 | target: 'B', 86 | }, 87 | { 88 | source: 'K', 89 | target: 'E', 90 | }, 91 | { 92 | source: 'J', 93 | target: 'E', 94 | }, 95 | { 96 | source: 'I', 97 | target: 'E', 98 | }, 99 | { 100 | source: 'H', 101 | target: 'E', 102 | }, 103 | { 104 | source: 'G', 105 | target: 'E', 106 | }, 107 | { 108 | source: 'G', 109 | target: 'B', 110 | }, 111 | { 112 | source: 'H', 113 | target: 'B', 114 | }, 115 | { 116 | source: 'I', 117 | target: 'B', 118 | }, 119 | ], 120 | }; 121 | 122 | 
describe('Calculate pagerank of graph nodes', () => { 123 | 124 | it('calculate pagerank', () => { 125 | const result = pageRank(data); 126 | let maxNodeId; 127 | let maxVal = 0; 128 | for (let nodeId in result) { 129 | const val = result[nodeId]; 130 | if (val >= maxVal) { 131 | maxNodeId = nodeId; 132 | maxVal = val 133 | } 134 | } 135 | expect(maxNodeId).toBe('B') 136 | }); 137 | }); 138 | -------------------------------------------------------------------------------- /packages/graph/tests/unit/floydWarshall-spec.ts: -------------------------------------------------------------------------------- 1 | import { floydWarshall } from '../../src'; 2 | 3 | const data = { 4 | nodes: [ 5 | { 6 | id: 'A', 7 | label: '0', 8 | }, 9 | { 10 | id: 'B', 11 | label: '1', 12 | }, 13 | { 14 | id: 'C', 15 | label: '2', 16 | }, 17 | { 18 | id: 'D', 19 | label: '3', 20 | }, 21 | { 22 | id: 'E', 23 | label: '4', 24 | }, 25 | { 26 | id: 'F', 27 | label: '5', 28 | }, 29 | { 30 | id: 'G', 31 | label: '6', 32 | }, 33 | { 34 | id: 'H', 35 | label: '7', 36 | }, 37 | ], 38 | edges: [ 39 | { 40 | source: 'A', 41 | target: 'B', 42 | }, 43 | { 44 | source: 'B', 45 | target: 'C', 46 | }, 47 | { 48 | source: 'C', 49 | target: 'G', 50 | }, 51 | { 52 | source: 'A', 53 | target: 'D', 54 | }, 55 | { 56 | source: 'A', 57 | target: 'E', 58 | }, 59 | { 60 | source: 'E', 61 | target: 'F', 62 | }, 63 | { 64 | source: 'F', 65 | target: 'D', 66 | }, 67 | ], 68 | }; 69 | 70 | describe('Adjacency Matrix by Algorithm', () => { 71 | it('get graph shortestpath matrix', () => { 72 | const matrix = floydWarshall(data); 73 | expect(Object.keys(matrix).length).toBe(8); 74 | const node0 = matrix[0]; 75 | expect(node0.length).toBe(8); 76 | expect(node0[0]).toBe(0); 77 | expect(node0[1]).toBe(1); 78 | expect(node0[2]).toBe(2); 79 | expect(node0[3]).toBe(1); 80 | expect(node0[4]).toBe(1); 81 | expect(node0[5]).toBe(2); 82 | expect(node0[6]).toBe(3); 83 | expect(node0[7]).toBe(Infinity); 84 | 
expect(matrix[1][7]).toBe(Infinity); 85 | expect(matrix[2][7]).toBe(Infinity); 86 | expect(matrix[3][7]).toBe(Infinity); 87 | }); 88 | 89 | it('directed', () => { 90 | // directed 91 | const matrix = floydWarshall(data, true); 92 | expect(Object.keys(matrix).length).toBe(8); 93 | const node0 = matrix[0]; 94 | expect(node0.length).toBe(8); 95 | expect(node0[0]).toBe(0); 96 | expect(node0[1]).toBe(1); 97 | expect(node0[2]).toBe(2); 98 | expect(node0[3]).toBe(1); 99 | expect(node0[4]).toBe(1); 100 | expect(node0[5]).toBe(2); 101 | expect(node0[6]).toBe(3); 102 | expect(node0[7]).toBe(Infinity); 103 | const node8 = matrix[6]; 104 | expect(node8.length).toBe(8); 105 | expect(node8[1]).toBe(Infinity); 106 | expect(node8[6]).toBe(0); 107 | }); 108 | }); 109 | -------------------------------------------------------------------------------- /packages/graph/src/utils/node-properties.ts: -------------------------------------------------------------------------------- 1 | import { NodeConfig } from '../types'; 2 | import { secondReg, dateReg } from '../constants/time'; 3 | 4 | // 获取所有属性并排序 5 | export const getAllSortProperties = (nodes: NodeConfig[] = [], n: number = 100) => { 6 | const propertyKeyInfo = {}; 7 | nodes.forEach(node => { 8 | if (!node.properties) { 9 | return; 10 | } 11 | Object.keys(node.properties).forEach(propertyKey => { 12 | // 目前过滤只保留可以转成数值型的或日期型的, todo: 统一转成one-hot特征向量或者embedding 13 | if (propertyKey === 'id' || !`${node.properties[propertyKey]}`.match(secondReg) && 14 | !`${node.properties[propertyKey]}`.match(dateReg) && 15 | isNaN(Number(node.properties[propertyKey]))) { 16 | if (propertyKeyInfo.hasOwnProperty(propertyKey)) { 17 | delete propertyKeyInfo[propertyKey]; 18 | } 19 | return; 20 | } 21 | if (propertyKeyInfo.hasOwnProperty(propertyKey)) { 22 | propertyKeyInfo[propertyKey] += 1; 23 | } else { 24 | propertyKeyInfo[propertyKey] = 1; 25 | } 26 | }) 27 | }) 28 | 29 | // 取top50的属性 30 | const sortKeys = Object.keys(propertyKeyInfo).sort((a,b) => 
propertyKeyInfo[b] - propertyKeyInfo[a]); 31 | return sortKeys.length < n ? sortKeys : sortKeys.slice(0, n); 32 | } 33 | 34 | const processProperty = (properties, propertyKeys) => propertyKeys.map(key => { 35 | if (properties.hasOwnProperty(key)) { 36 | // // 可以转成数值的直接转成数值 37 | // if (!isNaN(Number(properties[key]))) { 38 | // return Number(properties[key]); 39 | // } 40 | // // 时间型的转成时间戳 41 | // if (properties[key].match(secondReg) || properties[key].match(dateReg)) { 42 | // // @ts-ignore 43 | // return Number(Date.parse(new Date(properties[key]))) / 1000; 44 | // } 45 | return properties[key]; 46 | } 47 | return 0; 48 | }) 49 | 50 | // 获取属性特征权重 51 | export const getPropertyWeight = (nodes: NodeConfig[]) => { 52 | const propertyKeys = getAllSortProperties(nodes); 53 | let allPropertiesWeight = []; 54 | for (let i = 0; i < nodes.length; i++) { 55 | allPropertiesWeight[i] = processProperty(nodes[i].properties, propertyKeys); 56 | } 57 | return allPropertiesWeight; 58 | } 59 | 60 | // 获取所有节点的属性集合 61 | export const getAllProperties = (nodes, key = undefined) => { 62 | const allProperties = []; 63 | nodes.forEach(node => { 64 | if (key === undefined) { 65 | allProperties.push(node); 66 | } 67 | if (node[key] !== undefined) { 68 | allProperties.push(node[key]); 69 | } 70 | }) 71 | return allProperties; 72 | } 73 | 74 | export default { 75 | getAllSortProperties, 76 | getPropertyWeight, 77 | getAllProperties 78 | } 79 | -------------------------------------------------------------------------------- /packages/graph/tests/unit/pagerank-async-spec.ts: -------------------------------------------------------------------------------- 1 | import { getAlgorithm } from './utils'; 2 | 3 | const data = { 4 | nodes: [ 5 | { 6 | id: 'A', 7 | label: 'A', 8 | }, 9 | { 10 | id: 'B', 11 | label: 'B', 12 | }, 13 | { 14 | id: 'C', 15 | label: 'C', 16 | }, 17 | { 18 | id: 'D', 19 | label: 'D', 20 | }, 21 | { 22 | id: 'E', 23 | label: 'E', 24 | }, 25 | { 26 | id: 'F', 27 | label: 'F', 28 | 
}, 29 | { 30 | id: 'G', 31 | label: 'G', 32 | }, 33 | { 34 | id: 'H', 35 | label: 'H', 36 | }, 37 | { 38 | id: 'I', 39 | label: 'I', 40 | }, 41 | { 42 | id: 'J', 43 | label: 'J', 44 | }, 45 | { 46 | id: 'K', 47 | label: 'K', 48 | } 49 | ], 50 | edges: [ 51 | { 52 | source: 'D', 53 | target: 'A', 54 | }, 55 | { 56 | source: 'D', 57 | target: 'B', 58 | }, 59 | { 60 | source: 'B', 61 | target: 'C', 62 | }, 63 | { 64 | source: 'C', 65 | target: 'B', 66 | }, 67 | { 68 | source: 'F', 69 | target: 'B', 70 | }, 71 | { 72 | source: 'F', 73 | target: 'E', 74 | }, 75 | { 76 | source: 'E', 77 | target: 'F', 78 | }, 79 | { 80 | source: 'E', 81 | target: 'D', 82 | }, 83 | { 84 | source: 'E', 85 | target: 'B', 86 | }, 87 | { 88 | source: 'K', 89 | target: 'E', 90 | }, 91 | { 92 | source: 'J', 93 | target: 'E', 94 | }, 95 | { 96 | source: 'I', 97 | target: 'E', 98 | }, 99 | { 100 | source: 'H', 101 | target: 'E', 102 | }, 103 | { 104 | source: 'G', 105 | target: 'E', 106 | }, 107 | { 108 | source: 'G', 109 | target: 'B', 110 | }, 111 | { 112 | source: 'H', 113 | target: 'B', 114 | }, 115 | { 116 | source: 'I', 117 | target: 'B', 118 | }, 119 | ], 120 | }; 121 | 122 | describe('(Async) Calculate pagerank of graph nodes', () => { 123 | 124 | it('calculate pagerank', async () => { 125 | const { pageRankAsync } = await getAlgorithm(); 126 | const result = await pageRankAsync(data); 127 | let maxNodeId; 128 | let maxVal = 0; 129 | for (let nodeId in result) { 130 | const val = result[nodeId]; 131 | if (val >= maxVal) { 132 | maxNodeId = nodeId; 133 | maxVal = val 134 | } 135 | } 136 | expect(maxNodeId).toBe('B') 137 | }); 138 | }); 139 | -------------------------------------------------------------------------------- /packages/webgpu-graph/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "@antv/webgpu-graph", 3 | "version": "1.0.0", 4 | "description": "provide common graph algorithms implemented with WebGPU", 5 | 
"keywords": [ 6 | "webgpu", 7 | "graph", 8 | "algorithm", 9 | "antv", 10 | "G6" 11 | ], 12 | "publishConfig": { 13 | "access": "public" 14 | }, 15 | "files": [ 16 | "package.json", 17 | "es", 18 | "lib", 19 | "dist", 20 | "LICENSE", 21 | "README.md", 22 | "README-zh_CN.md" 23 | ], 24 | "main": "lib/index.js", 25 | "module": "es/index.js", 26 | "types": "lib/index.d.ts", 27 | "unpkg": "dist/index.min.js", 28 | "scripts": { 29 | "build": "npm run clean && father build && npm run build:umd", 30 | "build:umd": "webpack --config webpack.config.js --mode production", 31 | "dev:umd": "webpack --config webpack.dev.config.js --mode development", 32 | "ci": "npm run build && npm run coverage", 33 | "clean": "rimraf es lib dist", 34 | "coverage": "jest --coverage", 35 | "lint": "eslint --ext .js,.jsx,.ts,.tsx --format=pretty \"./\"", 36 | "lint:src": "eslint --ext .ts --format=pretty \"./src\"", 37 | "prettier": "prettier -c --write \"**/*\"", 38 | "test": "npm run build:umd && jest", 39 | "test-live": "npm run build:umd && DEBUG_MODE=1 jest --watch ./tests/unit/louvain-spec.ts", 40 | "test-live:async": "npm run build:umd && DEBUG_MODE=1 jest --watch ./tests/unit/louvain-async-spec.ts", 41 | "lint-staged:js": "eslint --ext .js,.jsx,.ts,.tsx", 42 | "cdn": "antv-bin upload -n @antv/webgpu-graph" 43 | }, 44 | "homepage": "https://g6.antv.vision", 45 | "bugs": { 46 | "url": "https://github.com/antvis/algorithm/issues" 47 | }, 48 | "repository": { 49 | "type": "git", 50 | "url": "https://github.com/antvis/algorithm" 51 | }, 52 | "license": "MIT", 53 | "author": "https://github.com/orgs/antvis/people", 54 | "devDependencies": { 55 | "@babel/core": "^7.12.10", 56 | "@babel/plugin-proposal-class-properties": "^7.12.1", 57 | "@babel/preset-env": "^7.12.7", 58 | "@babel/preset-typescript": "^7.12.7", 59 | "@types/jest": "^26.0.18", 60 | "@umijs/fabric": "^2.5.6", 61 | "babel-loader": "^8.2.2", 62 | "father": "^2.30.0", 63 | "jest": "^26.6.3", 64 | "jest-electron": "^0.1.11", 65 | 
"rimraf": "^3.0.2", 66 | "ts-jest": "^26.4.4", 67 | "ts-loader": "^8.0.14", 68 | "tslint": "^6.1.3", 69 | "typescript": "^4.1.3", 70 | "webpack": "^5.17.0", 71 | "webpack-cli": "^4.9.1" 72 | }, 73 | "dependencies": { 74 | "@antv/g": "^5.0.16", 75 | "@antv/g-webgl": "^1.0.19", 76 | "@antv/g-plugin-gpgpu": "^1.0.14", 77 | "@types/offscreencanvas": "^2019.6.4", 78 | "@webgpu/types": "^0.1.6", 79 | "tslib": "^2.0.0" 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /packages/graph/tests/unit/floydWarshall-async-spec.ts: -------------------------------------------------------------------------------- 1 | import { getAlgorithm } from './utils'; 2 | 3 | const data = { 4 | nodes: [ 5 | { 6 | id: 'A', 7 | label: '0', 8 | }, 9 | { 10 | id: 'B', 11 | label: '1', 12 | }, 13 | { 14 | id: 'C', 15 | label: '2', 16 | }, 17 | { 18 | id: 'D', 19 | label: '3', 20 | }, 21 | { 22 | id: 'E', 23 | label: '4', 24 | }, 25 | { 26 | id: 'F', 27 | label: '5', 28 | }, 29 | { 30 | id: 'G', 31 | label: '6', 32 | }, 33 | { 34 | id: 'H', 35 | label: '7', 36 | }, 37 | ], 38 | edges: [ 39 | { 40 | source: 'A', 41 | target: 'B', 42 | }, 43 | { 44 | source: 'B', 45 | target: 'C', 46 | }, 47 | { 48 | source: 'C', 49 | target: 'G', 50 | }, 51 | { 52 | source: 'A', 53 | target: 'D', 54 | }, 55 | { 56 | source: 'A', 57 | target: 'E', 58 | }, 59 | { 60 | source: 'E', 61 | target: 'F', 62 | }, 63 | { 64 | source: 'F', 65 | target: 'D', 66 | }, 67 | ], 68 | }; 69 | 70 | describe('(Async) Adjacency Matrix by Algorithm', () => { 71 | it('get graph shortestpath matrix', async () => { 72 | const { floydWarshallAsync } = await getAlgorithm(); 73 | const matrix = await floydWarshallAsync(data); 74 | expect(Object.keys(matrix).length).toBe(8); 75 | const node0 = matrix[0]; 76 | expect(node0.length).toBe(8); 77 | expect(node0[0]).toBe(0); 78 | expect(node0[1]).toBe(1); 79 | expect(node0[2]).toBe(2); 80 | expect(node0[3]).toBe(1); 81 | expect(node0[4]).toBe(1); 82 | 
expect(node0[5]).toBe(2); 83 | expect(node0[6]).toBe(3); 84 | expect(node0[7]).toBe(Infinity); 85 | expect(matrix[1][7]).toBe(Infinity); 86 | expect(matrix[2][7]).toBe(Infinity); 87 | expect(matrix[3][7]).toBe(Infinity); 88 | }); 89 | 90 | it('directed', async () => { 91 | // directed 92 | const { floydWarshallAsync } = await getAlgorithm(); 93 | const matrix = await floydWarshallAsync(data, true); 94 | expect(Object.keys(matrix).length).toBe(8); 95 | const node0 = matrix[0]; 96 | expect(node0.length).toBe(8); 97 | expect(node0[0]).toBe(0); 98 | expect(node0[1]).toBe(1); 99 | expect(node0[2]).toBe(2); 100 | expect(node0[3]).toBe(1); 101 | expect(node0[4]).toBe(1); 102 | expect(node0[5]).toBe(2); 103 | expect(node0[6]).toBe(3); 104 | expect(node0[7]).toBe(Infinity); 105 | const node8 = matrix[6]; 106 | expect(node8.length).toBe(8); 107 | expect(node8[1]).toBe(Infinity); 108 | expect(node8[6]).toBe(0); 109 | }); 110 | }); 111 | -------------------------------------------------------------------------------- /packages/graph/tests/unit/label-propagation-spec.ts: -------------------------------------------------------------------------------- 1 | import { labelPropagation } from '../../src'; 2 | import { GraphData } from '../../src/types'; 3 | 4 | describe('label propagation', () => { 5 | it('simple label propagation', () => { 6 | const data: GraphData = { 7 | nodes: [ 8 | { id: '0' }, { id: '1' }, { id: '2' }, { id: '3' }, { id: '4' }, 9 | { id: '5' }, { id: '6' }, { id: '7' }, { id: '8' }, { id: '9' }, 10 | { id: '10' }, { id: '11' }, { id: '12' }, { id: '13' }, { id: '14' }, 11 | ], 12 | edges: [ 13 | { source: '0', target: '1' }, { source: '0', target: '2' }, { source: '0', target: '3' }, { source: '0', target: '4' }, 14 | { source: '1', target: '2' }, { source: '1', target: '3' }, { source: '1', target: '4' }, 15 | { source: '2', target: '3' }, { source: '2', target: '4' }, 16 | { source: '3', target: '4' }, 17 | { source: '0', target: '0' }, 18 | { source: '0', 
import { connectedComponent } from '../../src'

// Fixture: a triangle A-B-C, a directed cycle D->E->F->D attached to A through D->A,
// and a separate two-node cycle G<->H.
const data = {
  nodes: [
    { id: 'A' },
    { id: 'B' },
    { id: 'C' },
    { id: 'D' },
    { id: 'E' },
    { id: 'F' },
    { id: 'G' },
    { id: 'H' },
  ],
  edges: [
    { source: 'A', target: 'B' },
    { source: 'B', target: 'C' },
    { source: 'A', target: 'C' },
    { source: 'D', target: 'A' },
    { source: 'D', target: 'E' },
    { source: 'E', target: 'F' },
    { source: 'F', target: 'D' },
    { source: 'G', target: 'H' },
    { source: 'H', target: 'G' },
  ],
};

describe('find connected components', () => {

  it('detect strongly connected components in undirected graph', () => {
    let result = connectedComponent(data, false);
    expect(result.length).toEqual(2);
    expect(result[0].map((node) => node.id).sort()).toEqual(['A', 'B', 'C', 'D', 'E', 'F']);
    expect(result[1].map((node) => node.id).sort()).toEqual(['G', 'H']);
  });

  it('detect strongly connected components in directed graph', () => {
    let result = connectedComponent(data, true);
    expect(result.length).toEqual(5);
    expect(result[3].map((node) => node.id).sort()).toEqual(['D', 'E', 'F']);
    expect(result[4].map((node) => node.id).sort()).toEqual(['G', 'H']);
  });

  it('test connected components detection performance using large graph', () => {
    // The promise chain must be returned: otherwise Jest considers the test finished as soon
    // as this callback returns, and the assertions below never take part in the result.
    // NOTE(review): this test needs network access and may need a longer Jest timeout.
    return fetch('https://gw.alipayobjects.com/os/basement_prod/da5a1b47-37d6-44d7-8d10-f3e046dabf82.json')
      .then((res) => res.json())
      .then((data) => {
        // Decorate every node with a label and its total degree before running the algorithm.
        data.nodes.forEach((node) => {
          node.label = node.olabel;
          node.degree = 0;
          data.edges.forEach((edge) => {
            if (edge.source === node.id || edge.target === node.id) {
              node.degree++;
            }
          });
        });

        let directedComps = connectedComponent(data, true);
        let undirectedComps = connectedComponent(data, false);
        expect(directedComps.length).toEqual(1589);
        expect(undirectedComps.length).toEqual(396);
      });
  });
});
import LinkedList, { LinkedListNode } from '../../src/structs/linked-list'

// Spec for the LinkedList structure and its node type.
// Every `it` after the first one operates on the SAME `linkedList` instance, so the
// tests are order-dependent: each one relies on the list state left by the previous one.
describe('linked list struct', () => {
  it('init linedListNode', () => {
    const linkedNode1 = new LinkedListNode(1)
    const linkedNode2 = new LinkedListNode(2, linkedNode1)

    // A node built without a successor has `next === null`.
    expect(linkedNode1.value).toBe(1)
    expect(linkedNode1.next).toBe(null)
    expect(linkedNode1.toString()).toEqual('1')

    // The second constructor argument becomes the node's `next`.
    expect(linkedNode2.value).toBe(2)
    expect(linkedNode2.next).toBe(linkedNode1)
    expect(linkedNode2.toString()).toBe('2')
  })

  // Shared list mutated by all the tests below.
  const linkedList = new LinkedList()
  it('init linked list', () => {
    expect(linkedList).not.toBe(undefined)
  })

  it('find & append', () => {
    // Looking up a value in the still-empty list yields null.
    let node1 = linkedList.find({ value: 1 })
    expect(node1).toBe(null)

    // append node
    linkedList.append(1)
    node1 = linkedList.find({ value: 1 })

    expect(node1).not.toBe(null)
    expect(node1.value).toBe(1)
  })

  it('prepend', () => {
    linkedList.prepend(2)

    // List is now 2 -> 1.
    const node1 = linkedList.find({ value: 1 })
    const node2 = linkedList.find({ value: 2 })
    expect(linkedList.toArray()).toEqual([node2, node1])
    expect(linkedList.toString()).toEqual('2,1')
  })

  it('deleteHead', () => {
    // Removes node 2; the returned node still points at its former successor (node 1).
    const deleteHead = linkedList.deleteHead()
    expect(deleteHead).not.toBe(undefined)
    expect(deleteHead.value).toEqual(2)
    expect(deleteHead.next).toEqual({ next: null, value: 1 })
  })

  it('deleteTail', () => {
    // List becomes 3 -> 1, then the tail (1) is removed.
    linkedList.prepend(3)

    const deleteTail = linkedList.deleteTail()
    expect(deleteTail).not.toBe(undefined)
    expect(deleteTail.value).toBe(1)
    expect(deleteTail.next).toBe(null)
  })

  it('delete', () => {
    linkedList.append(5)
    linkedList.append(6)

    // List is now 3 -> 5 -> 6.
    const node3 = linkedList.find({ value: 3 })
    const node5 = linkedList.find({ value: 5 })
    const node6 = linkedList.find({ value: 6 })
    expect(linkedList.toArray()).toEqual([node3, node5, node6])
    expect(linkedList.toString()).toEqual('3,5,6')

    // Delete an element that does not exist
    let deleteNode = linkedList.delete(8)
    expect(deleteNode).toBe(null)

    // Delete an element that exists
    deleteNode = linkedList.delete(5)
    expect(deleteNode).not.toBe(null)
    expect(deleteNode.value).toBe(5)
    expect(deleteNode.next).toEqual({ next: null, value: 6 })

    // The deleted value can no longer be found.
    deleteNode = linkedList.find({ value: 5 })
    expect(deleteNode).toBe(null)
  })

  it('reverse', () => {
    // Depends on the state left by the 'delete' test: 3 -> 6.
    expect(linkedList.toString()).toEqual('3,6')
    linkedList.reverse()
    expect(linkedList.toString()).toEqual('6,3')
  })
})
| id: 'A', 7 | }, 8 | { 9 | id: 'B', 10 | }, 11 | { 12 | id: 'C', 13 | }, 14 | { 15 | id: 'D', 16 | }, 17 | { 18 | id: 'E', 19 | }, 20 | { 21 | id: 'F', 22 | }, 23 | { 24 | id: 'G', 25 | }, 26 | { 27 | id: 'H', 28 | }, 29 | ], 30 | edges: [ 31 | { 32 | source: 'A', 33 | target: 'B', 34 | }, 35 | { 36 | source: 'B', 37 | target: 'C', 38 | }, 39 | { 40 | source: 'A', 41 | target: 'C', 42 | }, 43 | { 44 | source: 'D', 45 | target: 'A', 46 | }, 47 | { 48 | source: 'D', 49 | target: 'E', 50 | }, 51 | { 52 | source: 'E', 53 | target: 'F', 54 | }, 55 | { 56 | source: 'F', 57 | target: 'D', 58 | }, 59 | { 60 | source: 'G', 61 | target: 'H', 62 | }, 63 | { 64 | source: 'H', 65 | target: 'G', 66 | }, 67 | ], 68 | }; 69 | 70 | describe('degree algorithm', () => { 71 | it('getDegree', () => { 72 | const degree = { 73 | A: { 74 | degree: 3, 75 | inDegree: 1, 76 | outDegree: 2 77 | }, 78 | B: { 79 | degree: 2, 80 | inDegree: 1, 81 | outDegree: 1 82 | }, 83 | C: { 84 | degree: 2, 85 | inDegree: 2, 86 | outDegree: 0 87 | }, 88 | D: { 89 | degree: 3, 90 | inDegree: 1, 91 | outDegree: 2 92 | }, 93 | E: { 94 | degree: 2, 95 | inDegree: 1, 96 | outDegree: 1 97 | }, 98 | F: { 99 | degree: 2, 100 | inDegree: 1, 101 | outDegree: 1 102 | }, 103 | G: { 104 | degree: 2, 105 | inDegree: 1, 106 | outDegree: 1 107 | }, 108 | H: { 109 | degree: 2, 110 | inDegree: 1, 111 | outDegree: 1 112 | } 113 | } 114 | let result = getDegree(data); 115 | expect(result).toEqual(degree); 116 | }); 117 | 118 | it('getInDegree', () => { 119 | let result = getInDegree(data, 'A'); 120 | expect(result).toBe(1); 121 | 122 | result = getInDegree(data, 'C') 123 | expect(result).toBe(2) 124 | 125 | result = getInDegree(data, 'E') 126 | expect(result).toBe(1) 127 | }); 128 | 129 | it('getOutDegree', () => { 130 | let result = getOutDegree(data, 'A'); 131 | expect(result).toEqual(2); 132 | 133 | result = getOutDegree(data, 'D'); 134 | expect(result).toEqual(2); 135 | 136 | result = getOutDegree(data, 'F'); 137 | 
import { getAlgorithm } from './utils';

// Fixture: eight nodes A..H (H is isolated) and seven edges.
const data = {
  nodes: [
    { id: 'A', label: '0' },
    { id: 'B', label: '1' },
    { id: 'C', label: '2' },
    { id: 'D', label: '3' },
    { id: 'E', label: '4' },
    { id: 'F', label: '5' },
    { id: 'G', label: '6' },
    { id: 'H', label: '7' },
  ],
  edges: [
    { source: 'A', target: 'B' },
    { source: 'B', target: 'C' },
    { source: 'C', target: 'G' },
    { source: 'A', target: 'D' },
    { source: 'A', target: 'E' },
    { source: 'E', target: 'F' },
    { source: 'F', target: 'D' },
  ],
};

describe('(Async) Adjacency Matrix', () => {
  /**
   * Check one matrix row: first its length, then each leading cell in index order.
   * `cells[i]` is the value expected at `row[i]` (`undefined` means "no edge recorded").
   */
  const expectRow = (row, length, cells) => {
    expect(row.length).toBe(length);
    cells.forEach((cell, column) => {
      expect(row[column]).toBe(cell);
    });
  };

  it('undirected', async () => {
    const algorithm = await getAlgorithm();
    const matrix = await algorithm.getAdjMatrixAsync(data);
    expect(Object.keys(matrix).length).toBe(8);

    // Row of A: neighbours B, D and E.
    expectRow(matrix[0], 5, [undefined, 1, undefined, 1, 1]);
    // Row of B: neighbours A and C.
    expectRow(matrix[1], 3, [1, undefined, 1]);
    // Row of F: neighbours D and E.
    expectRow(matrix[5], 5, [undefined, undefined, undefined, 1, 1]);
  });

  it('directed', async () => {
    const algorithm = await getAlgorithm();
    const matrix = await algorithm.getAdjMatrixAsync(data, true);
    expect(Object.keys(matrix).length).toBe(8);

    // Row of A: outgoing edges to B, D and E.
    expectRow(matrix[0], 5, [undefined, 1, undefined, 1, 1]);
    // Row of B: only the outgoing edge to C; the incoming edge from A is not recorded.
    expectRow(matrix[1], 3, [undefined, undefined, 1]);
  });
});
import { indexOf } from "@antv/util";

// Sentinel values used as defaults when an id or label has not been assigned.
export const VACANT_EDGE_ID = -1;
export const VACANT_NODE_ID = -1;
export const VACANT_EDGE_LABEL = "-1";
export const VACANT_NODE_LABEL = "-1";
export const VACANT_GRAPH_ID = -1;
export const AUTO_EDGE_ID = "-1";

/**
 * A labelled edge going from node `from` to node `to`.
 */
export class Edge {
  public id: number;
  public from: number;
  public to: number;
  public label: string;

  constructor(
    id = VACANT_EDGE_ID,
    from = VACANT_NODE_ID,
    to = VACANT_NODE_ID,
    label = VACANT_EDGE_LABEL
  ) {
    this.id = id;
    this.from = from;
    this.to = to;
    this.label = label;
  }
}

/**
 * A labelled node that keeps the edges attached to it, both as a list and keyed by edge id.
 */
export class Node {
  public id: number;
  public from: number;
  public to: number;
  public label: string;
  public edges: Edge[];
  public edgeMap: {};

  constructor(id = VACANT_NODE_ID, label = VACANT_NODE_LABEL) {
    this.id = id;
    this.label = label;
    this.edges = [];
    this.edgeMap = {};
  }

  /** Attach `edge` to this node and index it by its id. */
  addEdge(edge) {
    this.edges.push(edge);
    this.edgeMap[edge.id] = edge;
  }
}

/**
 * A labelled graph with lookup tables by node id, edge id, node label and edge label.
 */
export class Graph {
  public id: number;
  public from: number;
  public to: number;
  public label: string;
  public edgeIdAutoIncrease: boolean;
  public nodes: Node[];
  public edges: Edge[];
  public nodeMap: {};
  public edgeMap: {};
  public nodeLabelMap: {}; // key is a label, value is the array of node ids carrying it
  public edgeLabelMap: {}; // key is a label, value is the array of edges carrying it
  private counter: number; // auto-incremented, used to generate edge ids
  public directed: boolean;

  /**
   * @param id graph id (defaults to the vacant graph id)
   * @param edgeIdAutoIncrease when true, edge ids passed to `addEdge` are ignored and generated
   * @param directed when false, every edge is also registered on its target node in reverse
   */
  constructor(
    id = VACANT_GRAPH_ID,
    edgeIdAutoIncrease = true,
    directed = false
  ) {
    this.id = id;
    this.edgeIdAutoIncrease = edgeIdAutoIncrease;
    this.edges = [];
    this.nodes = [];
    this.nodeMap = {};
    this.edgeMap = {};
    this.nodeLabelMap = {};
    this.edgeLabelMap = {};
    this.counter = 0;
    this.directed = directed;
  }

  /** Number of nodes currently in the graph. */
  getNodeNum() {
    return this.nodes.length;
  }

  /**
   * Add a node; a second call with an id that already exists is a no-op.
   */
  addNode(id: number, label: string) {
    if (this.nodeMap[id]) return;
    const node = new Node(id, label);
    this.nodes.push(node);
    this.nodeMap[id] = node;
    if (!this.nodeLabelMap[label]) this.nodeLabelMap[label] = [];
    this.nodeLabelMap[label].push(id);
  }

  /**
   * Add an edge from `from` to `to`.
   *
   * The source node must already exist; for an undirected graph the target node must exist
   * too, because the reversed edge is registered on it. Endpoints are validated BEFORE any
   * state is touched, so a failing call leaves the graph (and the id counter) unchanged.
   *
   * @param id edge id; replaced by a generated one when auto-increase is on or `id` is undefined
   * @throws Error when a required endpoint has not been added to the graph
   */
  addEdge(id: number, from: number, to: number, label: string) {
    const fromNode = this.nodeMap[from];
    const toNode = this.nodeMap[to];
    if (!fromNode) {
      throw new Error(`Cannot add edge: source node ${from} does not exist`);
    }
    if (!this.directed && !toNode) {
      throw new Error(`Cannot add edge: target node ${to} does not exist`);
    }

    if (this.edgeIdAutoIncrease || id === undefined) id = this.counter++;
    // Skip an edge id that the target node already knows about.
    if (toNode && toNode.edgeMap[id]) return;

    const edge = new Edge(id, from, to, label);
    this.edges.push(edge);
    this.edgeMap[id] = edge;

    fromNode.addEdge(edge);

    if (!this.edgeLabelMap[label]) this.edgeLabelMap[label] = [];
    this.edgeLabelMap[label].push(edge);

    if (!this.directed) {
      // Undirected: the target node gets a mirrored edge sharing the same id.
      const rEdge = new Edge(id, to, from, label);
      toNode.addEdge(rEdge);
      this.edgeLabelMap[label].push(rEdge);
    }
  }
}
Path from source to target on graph', () => { 75 | it('find the shortest path', () => { 76 | const { length, path } = findShortestPath(data, 'A', 'C'); 77 | expect(length).toBe(2); 78 | expect(path).toStrictEqual(['A', 'B', 'C']); 79 | }); 80 | 81 | it('find all shortest paths', () => { 82 | const { length, allPath } = findShortestPath(data, 'A', 'F'); 83 | expect(length).toBe(2); 84 | expect(allPath[0]).toStrictEqual(['A', 'E', 'F']); 85 | expect(allPath[1]).toStrictEqual(['A', 'D', 'F']); 86 | 87 | const { 88 | length: directedLenght, 89 | path: directedPath, 90 | allPath: directedAllPath, 91 | } = findShortestPath(data, 'A', 'F', true); 92 | expect(directedLenght).toBe(2); 93 | expect(directedAllPath[0]).toStrictEqual(['A', 'E', 'F']); 94 | expect(directedPath).toStrictEqual(['A', 'E', 'F']); 95 | }); 96 | 97 | it('find all paths', () => { 98 | const allPath = findAllPath(data, 'A', 'E'); 99 | expect(allPath.length).toBe(3); 100 | expect(allPath[0]).toStrictEqual(['A', 'D', 'F', 'E']); 101 | expect(allPath[1]).toStrictEqual(['A', 'D', 'E']); 102 | expect(allPath[2]).toStrictEqual(['A', 'E']); 103 | }); 104 | 105 | it('find all paths in directed graph', () => { 106 | const allPath = findAllPath(data, 'A', 'E', true); 107 | expect(allPath.length).toStrictEqual(2); 108 | expect(allPath[0]).toStrictEqual(['A', 'D', 'E']); 109 | expect(allPath[1]).toStrictEqual(['A', 'E']); 110 | }); 111 | 112 | it('find all shortest paths in weighted graph', () => { 113 | data.edges.forEach((edge: any, i) => { 114 | edge.weight = ((i % 2) + 1) * 2; 115 | if (edge.source === 'F' && edge.target === 'D') edge.weight = 10; 116 | }); 117 | const { length, path, allPath } = findShortestPath(data, 'A', 'F', false, 'weight'); 118 | expect(length).toBe(6); 119 | expect(allPath[0]).toStrictEqual(['A', 'E', 'F']); 120 | expect(path).toStrictEqual(['A', 'E', 'F']); 121 | }); 122 | }); 123 | -------------------------------------------------------------------------------- 
/packages/graph/tests/unit/label-propagation-async-spec.ts: -------------------------------------------------------------------------------- 1 | import { getAlgorithm } from './utils'; 2 | import { GraphData } from '../../src/types'; 3 | 4 | describe('(Async) label propagation', () => { 5 | it('simple label propagation', async done => { 6 | const { labelPropagationAsync } = await getAlgorithm(); 7 | const data: GraphData = { 8 | nodes: [ 9 | { id: '0' }, 10 | { id: '1' }, 11 | { id: '2' }, 12 | { id: '3' }, 13 | { id: '4' }, 14 | { id: '5' }, 15 | { id: '6' }, 16 | { id: '7' }, 17 | { id: '8' }, 18 | { id: '9' }, 19 | { id: '10' }, 20 | { id: '11' }, 21 | { id: '12' }, 22 | { id: '13' }, 23 | { id: '14' }, 24 | ], 25 | edges: [ 26 | { source: '0', target: '1' }, 27 | { source: '0', target: '2' }, 28 | { source: '0', target: '3' }, 29 | { source: '0', target: '4' }, 30 | { source: '1', target: '2' }, 31 | { source: '1', target: '3' }, 32 | { source: '1', target: '4' }, 33 | { source: '2', target: '3' }, 34 | { source: '2', target: '4' }, 35 | { source: '3', target: '4' }, 36 | { source: '0', target: '0' }, 37 | { source: '0', target: '0' }, 38 | { source: '0', target: '0' }, 39 | 40 | { source: '5', target: '6', weight: 5 }, 41 | { source: '5', target: '7' }, 42 | { source: '5', target: '8' }, 43 | { source: '5', target: '9' }, 44 | { source: '6', target: '7' }, 45 | { source: '6', target: '8' }, 46 | { source: '6', target: '9' }, 47 | { source: '7', target: '8' }, 48 | { source: '7', target: '9' }, 49 | { source: '8', target: '9' }, 50 | 51 | { source: '10', target: '11' }, 52 | { source: '10', target: '12' }, 53 | { source: '10', target: '13' }, 54 | { source: '10', target: '14' }, 55 | { source: '11', target: '12' }, 56 | { source: '11', target: '13' }, 57 | { source: '11', target: '14' }, 58 | { source: '12', target: '13' }, 59 | { source: '12', target: '14' }, 60 | { source: '13', target: '14', weight: 5 }, 61 | 62 | { source: '0', target: '5' }, 63 | { source: 
import { isArray } from '@antv/util';
import { GraphData, NodeConfig, EdgeConfig } from './types';
import { getOutEdgesNodeId, getEdgesByNodeId } from './util';

/**
 * Find the not-yet-processed node with the smallest tentative distance.
 *
 * @param D tentative distance per node id
 * @param nodes all nodes of the graph
 * @param marks ids of nodes that were already processed
 * @returns the selected node, or `undefined` when every node id is already marked
 */
const minVertex = (
  D: { [key: string]: number },
  nodes: NodeConfig[],
  marks: { [key: string]: boolean },
): NodeConfig => {
  let minDis = Infinity;
  let minNode;
  for (let i = 0; i < nodes.length; i++) {
    const nodeId = nodes[i].id;
    // `<=` rather than `<`: an unmarked node is still returned when all remaining
    // distances are Infinity (unreachable nodes).
    if (!marks[nodeId] && D[nodeId] <= minDis) {
      minDis = D[nodeId];
      minNode = nodes[i];
    }
  }
  return minNode;
};

/**
 * Single-source shortest paths.
 *
 * @param graphData graph with `nodes` and `edges`
 * @param source id of the start node
 * @param directed when true only outgoing edges are followed, otherwise every incident edge
 * @param weightPropertyName edge property holding the weight; a missing or falsy weight counts as 1
 * @returns `length`: distance per node id (Infinity when unreachable);
 *          `path`: one shortest path per reachable node id;
 *          `allPath`: every shortest path per reachable node id
 */
const dijkstra = (
  graphData: GraphData,
  source: string,
  directed?: boolean,
  weightPropertyName?: string,
) => {
  const { nodes = [], edges = [] } = graphData;
  const marks: { [key: string]: boolean } = {};
  const D: { [key: string]: number } = {};
  // key: node id, value: its predecessors (there may be several equally short paths)
  const prevs: { [key: string]: string[] } = {};
  nodes.forEach(node => {
    const id = node.id;
    D[id] = Infinity;
    if (id === source) D[id] = 0;
  });

  const nodeNum = nodes.length;
  for (let i = 0; i < nodeNum; i++) {
    // Process the vertices
    const minNode = minVertex(D, nodes, marks);
    // Duplicated node ids leave fewer distinct ids than iterations; nothing is left to process.
    if (!minNode) break;
    const minNodeId = minNode.id;
    marks[minNodeId] = true;

    if (D[minNodeId] === Infinity) continue; // Unreachable vertices cannot be the intermediate point

    let relatedEdges: EdgeConfig[] = [];
    if (directed) relatedEdges = getOutEdgesNodeId(minNodeId, edges);
    else relatedEdges = getEdgesByNodeId(minNodeId, edges);

    relatedEdges.forEach(edge => {
      const edgeTarget = edge.target;
      const edgeSource = edge.source;
      // The endpoint of the edge that is not the node being processed.
      const w = edgeTarget === minNodeId ? edgeSource : edgeTarget;
      const weight = weightPropertyName && edge[weightPropertyName] ? edge[weightPropertyName] : 1;
      const candidate = D[minNodeId] + weight;
      if (D[w] > candidate) {
        D[w] = candidate;
        prevs[w] = [minNodeId];
      } else if (D[w] === candidate) {
        // Equally short alternative. Parallel edges must not register the same predecessor
        // twice, otherwise `allPath` would contain duplicated paths.
        if (!prevs[w]) prevs[w] = [];
        if (prevs[w].indexOf(minNodeId) === -1) prevs[w].push(minNodeId);
      }
    });
  }

  prevs[source] = [source];
  // Each node may have several shortest paths
  const paths = {};
  for (const target in D) {
    if (D[target] !== Infinity) {
      findAllPaths(source, target, prevs, paths);
    }
  }

  // Kept for compatibility with the earlier single-path result
  const path = {};
  for (const target in paths) {
    path[target] = paths[target][0];
  }
  return { length: D, path, allPath: paths };
};

export default dijkstra;

/**
 * Rebuild every shortest path from `source` to `target` by walking the predecessor lists
 * backwards, memoising the result per target in `foundPaths`.
 *
 * @returns `[source]` when `target` is the source itself (not stored in `foundPaths`),
 *          otherwise the array of paths, each an array of node ids
 */
function findAllPaths(source, target, prevs, foundPaths) {
  if (source === target) {
    return [source];
  }
  if (foundPaths[target]) {
    return foundPaths[target];
  }
  const paths = [];
  const predecessors = prevs[target] || [];
  for (let prev of predecessors) {
    const prevPaths = findAllPaths(source, prev, prevs, foundPaths);
    if (!prevPaths) return;
    for (let prePath of prevPaths) {
      // The base case yields bare ids instead of arrays, hence the two shapes.
      if (isArray(prePath)) paths.push([...prePath, target]);
      else paths.push([prePath, target]);
    }
  }
  foundPaths[target] = paths;
  return foundPaths[target];
}
edges: [ 15 | { source: '0', target: '1' }, { source: '0', target: '2' }, { source: '0', target: '3' }, { source: '0', target: '4' }, 16 | { source: '1', target: '2' }, { source: '1', target: '3' }, { source: '1', target: '4' }, 17 | { source: '2', target: '3' }, { source: '2', target: '4' }, 18 | { source: '3', target: '4' }, 19 | { source: '0', target: '0' }, 20 | { source: '0', target: '0' }, 21 | { source: '0', target: '0' }, 22 | 23 | { source: '5', target: '6', weight: 5 }, { source: '5', target: '7' }, { source: '5', target: '8' }, { source: '5', target: '9' }, 24 | { source: '6', target: '7' }, { source: '6', target: '8' }, { source: '6', target: '9' }, 25 | { source: '7', target: '8' }, { source: '7', target: '9' }, 26 | { source: '8', target: '9' }, 27 | 28 | { source: '10', target: '11' }, { source: '10', target: '12' }, { source: '10', target: '13' }, { source: '10', target: '14' }, 29 | { source: '11', target: '12' }, { source: '11', target: '13' }, { source: '11', target: '14' }, 30 | { source: '12', target: '13' }, { source: '12', target: '14' }, 31 | { source: '13', target: '14', weight: 5 }, 32 | 33 | { source: '0', target: '5' }, 34 | { source: '5', target: '10' }, 35 | { source: '10', target: '0' }, 36 | { source: '10', target: '0' }, 37 | ] 38 | } 39 | const clusteredData = louvain(data, false, 'weight'); 40 | expect(clusteredData.clusters.length).toBe(3); 41 | expect(clusteredData.clusters[0].sumTot).toBe(3); 42 | expect(clusteredData.clusters[1].sumTot).toBe(2); 43 | expect(clusteredData.clusterEdges.length).toBe(6); 44 | expect(clusteredData.clusterEdges[0].count).toBe(13); 45 | expect(clusteredData.clusterEdges[1].count).toBe(10); 46 | expect(clusteredData.clusterEdges[1].weight).toBe(14); 47 | }); 48 | 49 | it('louvain with large graph', () => { // https://gw.alipayobjects.com/os/antvdemo/assets/data/relations.json 50 | fetch('https://gw.alipayobjects.com/os/basement_prod/da5a1b47-37d6-44d7-8d10-f3e046dabf82.json') 51 | .then((res) => 
res.json()) 52 | .then((data) => { // 1589 nodes, 2747 edges 53 | const clusteredData = louvain(data, false, 'weight'); 54 | expect(clusteredData.clusters.length).toBe(495); 55 | expect(clusteredData.clusterEdges.length).toBe(505); 56 | }); 57 | }); 58 | 59 | it('louvain: add inertialModularity', () => { 60 | const clusteredData = iLouvain(propertiesGraphData as GraphData, false, 'weight', 0.01, 'properties'); 61 | expect(clusteredData.clusters.length).toBe(3); 62 | expect(clusteredData.clusters[0].sumTot).toBe(3); 63 | expect(clusteredData.clusters[1].sumTot).toBe(3); 64 | expect(clusteredData.clusters[2].sumTot).toBe(4); 65 | expect(clusteredData.clusterEdges.length).toBe(7); 66 | }); 67 | }); -------------------------------------------------------------------------------- /packages/graph/tests/unit/louvain-async-spec.ts: -------------------------------------------------------------------------------- 1 | import { louvain } from '../../src'; 2 | import { GraphData } from '../../src/types'; 3 | import propertiesGraphData from './data/cluster-origin-properties-data.json'; 4 | 5 | describe('(Async) louvain', () => { 6 | 7 | it('simple louvain', () => { 8 | const data: GraphData = { 9 | nodes: [ 10 | { id: '0' }, { id: '1' }, { id: '2' }, { id: '3' }, { id: '4' }, 11 | { id: '5' }, { id: '6' }, { id: '7' }, { id: '8' }, { id: '9' }, 12 | { id: '10' }, { id: '11' }, { id: '12' }, { id: '13' }, { id: '14' }, 13 | ], 14 | edges: [ 15 | { source: '0', target: '1' }, { source: '0', target: '2' }, { source: '0', target: '3' }, { source: '0', target: '4' }, 16 | { source: '1', target: '2' }, { source: '1', target: '3' }, { source: '1', target: '4' }, 17 | { source: '2', target: '3' }, { source: '2', target: '4' }, 18 | { source: '3', target: '4' }, 19 | { source: '0', target: '0' }, 20 | { source: '0', target: '0' }, 21 | { source: '0', target: '0' }, 22 | 23 | { source: '5', target: '6', weight: 5 }, { source: '5', target: '7' }, { source: '5', target: '8' }, { source: 
'5', target: '9' }, 24 | { source: '6', target: '7' }, { source: '6', target: '8' }, { source: '6', target: '9' }, 25 | { source: '7', target: '8' }, { source: '7', target: '9' }, 26 | { source: '8', target: '9' }, 27 | 28 | { source: '10', target: '11' }, { source: '10', target: '12' }, { source: '10', target: '13' }, { source: '10', target: '14' }, 29 | { source: '11', target: '12' }, { source: '11', target: '13' }, { source: '11', target: '14' }, 30 | { source: '12', target: '13' }, { source: '12', target: '14' }, 31 | { source: '13', target: '14', weight: 5 }, 32 | 33 | { source: '0', target: '5' }, 34 | { source: '5', target: '10' }, 35 | { source: '10', target: '0' }, 36 | { source: '10', target: '0' }, 37 | ] 38 | } 39 | const clusteredData = louvain(data, false, 'weight'); 40 | expect(clusteredData.clusters.length).toBe(3); 41 | expect(clusteredData.clusters[0].sumTot).toBe(3); 42 | expect(clusteredData.clusters[1].sumTot).toBe(2); 43 | expect(clusteredData.clusterEdges.length).toBe(6); 44 | expect(clusteredData.clusterEdges[0].count).toBe(13); 45 | expect(clusteredData.clusterEdges[1].count).toBe(10); 46 | expect(clusteredData.clusterEdges[1].weight).toBe(14); 47 | }); 48 | 49 | it('louvain with large graph', () => { // https://gw.alipayobjects.com/os/antvdemo/assets/data/relations.json 50 | fetch('https://gw.alipayobjects.com/os/basement_prod/da5a1b47-37d6-44d7-8d10-f3e046dabf82.json') 51 | .then((res) => res.json()) 52 | .then((data) => { // 1589 nodes, 2747 edges 53 | const clusteredData = louvain(data, false, 'weight'); 54 | expect(clusteredData.clusters.length).toBe(495); 55 | expect(clusteredData.clusterEdges.length).toBe(505); 56 | }); 57 | }); 58 | 59 | it('louvain: add inertialModularity', () => { 60 | const clusteredData = louvain(propertiesGraphData as GraphData, false, 'weight', 0.01, true, 'properties'); 61 | expect(clusteredData.clusters.length).toBe(3); 62 | expect(clusteredData.clusters[0].sumTot).toBe(3); 63 | 
/**
 * Collect, for every property key in the data set, the distinct non-empty
 * values that key takes.
 * @param dataList records to scan
 * @param involvedKeys when non-empty, only these keys are considered
 * @param uninvolvedKeys keys that must be left out of the result
 * @return map from key to the de-duplicated list of its values
 */
export const getAllKeyValueMap = (dataList: PlainObject[], involvedKeys?: string[], uninvolvedKeys?: string[]) => {
  // An explicit key list takes priority; otherwise gather every key present in the data.
  const candidateKeys: string[] = involvedKeys?.length
    ? involvedKeys
    : uniq(dataList.reduce<string[]>((collected, data) => collected.concat(Object.keys(data)), []));

  const allKeyValueMap: KeyValueMap = {};
  candidateKeys.forEach(key => {
    // Only values that are actually set count (neither undefined nor the empty string).
    const presentValues = dataList
      .map(data => data[key])
      .filter(value => value !== undefined && value !== '');
    if (presentValues.length === 0 || uninvolvedKeys?.includes(key)) {
      return;
    }
    allKeyValueMap[key] = uniq(presentValues);
  });

  return allKeyValueMap;
}

/**
 * One-hot encoding: turn each record into a numeric feature vector.
 * @param dataList records to encode
 * @param involvedKeys when non-empty, only these keys are encoded
 * @param uninvolvedKeys keys that must not be encoded
 * @return one code array per record (same index as in dataList); empty when no key qualifies
 */
export const oneHot = (dataList: PlainObject[], involvedKeys?: string[], uninvolvedKeys?: string[]) => {
  const allKeyValueMap = getAllKeyValueMap(dataList, involvedKeys, uninvolvedKeys);
  const featureKeys = Object.keys(allKeyValueMap);
  const oneHotCode = [];
  if (featureKeys.length === 0) {
    return oneHotCode;
  }

  // True only when every collected value of every feature is a number.
  const isAllNumber = Object.values(allKeyValueMap).every(values =>
    values.every(item => typeof item === 'number'));

  dataList.forEach((data, index) => {
    const code = [];
    featureKeys.forEach(key => {
      const keyValue = data[key];
      if (isAllNumber) {
        // Purely numeric features are used as-is (todo: normalise them so clustering converges faster).
        code.push(keyValue);
        return;
      }
      // Otherwise emit one slot per known value: 1 at the matching position, 0 elsewhere.
      const knownValues = allKeyValueMap[key];
      const matchedIndex = knownValues.findIndex(value => keyValue === value);
      for (let slot = 0; slot < knownValues.length; slot++) {
        code.push(slot === matchedIndex ? 1 : 0);
      }
    });
    oneHotCode[index] = code;
  });
  return oneHotCode;
}

/**
 * getDistance: distance between two elements.
 * @param item first element, passed to the Vector constructor
 * @param otherItem second element, passed to the Vector constructor
 * @param distanceType distance metric; only the Euclidean distance is implemented
 * @param graphData graph data (currently not read)
 * @return the Euclidean distance, or 0 for any other distance type
 */
export const getDistance = (item, otherItem, distanceType: DistanceType = DistanceType.EuclideanDistance, graphData?: GraphData) => {
  if (distanceType === DistanceType.EuclideanDistance) {
    return new Vector(item).euclideanDistance(new Vector(otherItem));
  }
  return 0;
}

export default {
  getAllKeyValueMap,
  oneHot,
  getDistance,
}
describe('nodesCosineSimilarity abnormal demo', () => { 6 | it('no properties demo: ', () => { 7 | const nodes = [ 8 | { 9 | id: 'node-0', 10 | }, 11 | { 12 | id: 'node-1', 13 | }, 14 | { 15 | id: 'node-2', 16 | }, 17 | { 18 | id: 'node-3', 19 | } 20 | ]; 21 | const { allCosineSimilarity, similarNodes } = nodesCosineSimilarity(nodes as NodeConfig[], nodes[0]); 22 | expect(allCosineSimilarity.length).toBe(3); 23 | expect(similarNodes.length).toBe(3); 24 | expect(allCosineSimilarity[0]).toBe(1); 25 | expect(allCosineSimilarity[1]).toBe(0); 26 | expect(allCosineSimilarity[2]).toBe(0); 27 | }); 28 | }); 29 | 30 | describe('nodesCosineSimilarity normal demo', () => { 31 | it('simple demo: ', () => { 32 | const nodes = [ 33 | { 34 | id: 'node-0', 35 | properties: { 36 | amount: 10, 37 | } 38 | }, 39 | { 40 | id: 'node-2', 41 | properties: { 42 | amount: 100, 43 | } 44 | }, 45 | { 46 | id: 'node-3', 47 | properties: { 48 | amount: 1000, 49 | } 50 | }, 51 | { 52 | id: 'node-4', 53 | properties: { 54 | amount: 50, 55 | } 56 | } 57 | ]; 58 | const { allCosineSimilarity, similarNodes } = nodesCosineSimilarity(nodes as NodeConfig[], nodes[0], 'properties'); 59 | expect(allCosineSimilarity.length).toBe(3); 60 | expect(similarNodes.length).toBe(3); 61 | allCosineSimilarity.forEach(data => { 62 | expect(data).toBeGreaterThanOrEqual(0); 63 | expect(data).toBeLessThanOrEqual(1); 64 | }) 65 | }); 66 | 67 | it('complex demo: ', () => { 68 | const { nodes } = propertiesGraphData; 69 | const { allCosineSimilarity, similarNodes } = nodesCosineSimilarity(nodes as NodeConfig[], nodes[16], 'properties'); 70 | expect(allCosineSimilarity.length).toBe(16); 71 | expect(similarNodes.length).toBe(16); 72 | allCosineSimilarity.forEach(data => { 73 | expect(data).toBeGreaterThanOrEqual(0); 74 | expect(data).toBeLessThanOrEqual(1); 75 | }) 76 | }); 77 | 78 | 79 | it('demo use involvedKeys: ', () => { 80 | const involvedKeys = ['amount', 'wifi']; 81 | const { nodes } = propertiesGraphData; 82 | 
/**
 * Prim's algorithm driven by a min-heap of candidate edges.
 * refer: https://en.wikipedia.org/wiki/Prim%27s_algorithm
 * @param graphData graph to span; the tree is grown starting from `nodes[0]`
 * @param weight name of the edge property used as the weight; when omitted all edges count as equally heavy
 * @return the edges selected for the spanning tree
 */
const primMST = (graphData: GraphData, weight?: string) => {
  const selectedEdges = [];
  const { nodes = [], edges = [] } = graphData;
  if (nodes.length === 0) {
    return selectedEdges;
  }

  // Heap order: by the named weight property, or "all equal" when no property is given.
  const compareWeight = (a: EdgeConfig, b: EdgeConfig) => {
    if (weight) {
      return a[weight] - b[weight];
    }
    return 0;
  };
  const edgeQueue = new MinBinaryHeap(compareWeight);

  // `visited` holds node ids, because edges refer to their endpoints by id.
  const visited = new Set<string>();
  const visit = (nodeId: string) => {
    visited.add(nodeId);
    getEdgesByNodeId(nodeId, edges).forEach((edge) => {
      edgeQueue.insert(edge);
    });
  };

  // Start from nodes[0]; seeding with the id (not the node object) keeps self-loops
  // on the start node from being selected and avoids enqueuing its edges twice.
  visit(nodes[0].id);

  while (!edgeQueue.isEmpty()) {
    // Cheapest edge that touches the tree built so far.
    const currEdge: EdgeConfig = edgeQueue.delMin();
    const source = currEdge.source;
    const target = currEdge.target;
    // Both ends already in the tree: the edge would close a cycle.
    if (visited.has(source) && visited.has(target)) continue;
    selectedEdges.push(currEdge);

    if (!visited.has(source)) {
      visit(source);
    }
    if (!visited.has(target)) {
      visit(target);
    }
  }
  return selectedEdges;
};

/**
 * Kruskal's algorithm.
 * refer: https://en.wikipedia.org/wiki/Kruskal%27s_algorithm
 * @param graphData graph to span
 * @param weight name of the edge property used as the weight; when omitted edges are taken in their given order
 * @return EdgeConfig[] the edges that make up the spanning tree (a forest if the graph is disconnected)
 */
const kruskalMST = (graphData: GraphData, weight?: string): EdgeConfig[] => {
  const selectedEdges = [];
  const { nodes = [], edges = [] } = graphData;
  if (nodes.length === 0) {
    return selectedEdges;
  }

  // Work on a copy so the caller's edge array is never reordered.
  const weightEdges = edges.map((edge) => edge);
  if (weight) {
    weightEdges.sort((a, b) => a[weight] - b[weight]);
  }
  const disjointSet = new UnionFind(nodes.map((n) => n.id));

  // Walk edges from lightest to heaviest and keep an edge only when its two
  // endpoints are not connected yet. Iterating (instead of shift()) keeps this linear.
  for (const curEdge of weightEdges) {
    const source = curEdge.source;
    const target = curEdge.target;
    if (!disjointSet.connected(source, target)) {
      selectedEdges.push(curEdge);
      disjointSet.union(source, target);
    }
  }
  return selectedEdges;
};

/**
 * Minimum spanning tree.
 * refer: https://en.wikipedia.org/wiki/Kruskal%27s_algorithm
 * @param graphData graph to span
 * @param weight name of the edge property used as the weight; when omitted all edges count as equally heavy
 * @param algo 'prim' | 'kruskal'; defaults to Kruskal when omitted
 * @return EdgeConfig[] the edges that make up the spanning tree
 * @throws TypeError when `algo` is given but is neither 'prim' nor 'kruskal'
 */
const minimumSpanningTree = (graphData: GraphData, weight?: string, algo?: string): EdgeConfig[] => {
  const algos = {
    prim: primMST,
    kruskal: kruskalMST,
  };
  if (!algo) return kruskalMST(graphData, weight);

  // Own-property check so names such as 'constructor' are rejected as well.
  if (!Object.prototype.hasOwnProperty.call(algos, algo)) {
    throw new TypeError(`Unknown MST algorithm "${algo}", expected 'prim' or 'kruskal'.`);
  }
  return algos[algo](graphData, weight);
}

export default minimumSpanningTree
findShortestPathAsync } = await getAlgorithm(); 77 | const { length, path } = await findShortestPathAsync(data, 'A', 'C'); 78 | expect(length).toBe(2); 79 | expect(path).toStrictEqual(['A', 'B', 'C']); 80 | done(); 81 | }); 82 | 83 | it('find all shortest paths', async done => { 84 | const { findShortestPathAsync } = await getAlgorithm(); 85 | const { length, allPath } = await findShortestPathAsync(data, 'A', 'F'); 86 | expect(length).toBe(2); 87 | expect(allPath[0]).toStrictEqual(['A', 'E', 'F']); 88 | expect(allPath[1]).toStrictEqual(['A', 'D', 'F']); 89 | 90 | const { 91 | length: directedLenght, 92 | path: directedPath, 93 | allPath: directedAllPath, 94 | } = await findShortestPathAsync(data, 'A', 'F', true); 95 | expect(directedLenght).toBe(2); 96 | expect(directedAllPath[0]).toStrictEqual(['A', 'E', 'F']); 97 | expect(directedPath).toStrictEqual(['A', 'E', 'F']); 98 | done(); 99 | }); 100 | 101 | it('find all paths', async done => { 102 | const { findAllPathAsync } = await getAlgorithm(); 103 | const allPath = await findAllPathAsync(data, 'A', 'E'); 104 | expect(allPath.length).toBe(3); 105 | expect(allPath[0]).toStrictEqual(['A', 'D', 'F', 'E']); 106 | expect(allPath[1]).toStrictEqual(['A', 'D', 'E']); 107 | expect(allPath[2]).toStrictEqual(['A', 'E']); 108 | done(); 109 | }); 110 | 111 | it('find all paths in directed graph', async done => { 112 | const { findAllPathAsync } = await getAlgorithm(); 113 | const allPath = await findAllPathAsync(data, 'A', 'E', true); 114 | expect(allPath.length).toStrictEqual(2); 115 | expect(allPath[0]).toStrictEqual(['A', 'D', 'E']); 116 | expect(allPath[1]).toStrictEqual(['A', 'E']); 117 | done(); 118 | }); 119 | 120 | it('find all shortest paths in weighted graph', async done => { 121 | const { findShortestPathAsync } = await getAlgorithm(); 122 | data.edges.forEach((edge: any, i) => { 123 | edge.weight = ((i % 2) + 1) * 2; 124 | if (edge.source === 'F' && edge.target === 'D') edge.weight = 10; 125 | }); 126 | const { 
/**
 * Generate all connected components for an undirected graph.
 *
 * Each component is found by a depth-first walk that starts at the first
 * not-yet-visited node (in `nodes` order) and follows `getNeighbors`.
 * Neighbor ids that do not match any entry of `nodes` are ignored.
 * @param graphData graph whose `nodes` and `edges` are inspected
 * @return one array of nodes per component; inside a component the nodes are
 *         listed in reverse order of discovery
 */
export const detectConnectedComponents = (graphData: GraphData): NodeConfig[][] => {
  const { nodes = [], edges = [] } = graphData
  const allComponents: NodeConfig[][] = [];
  // A Set (not a plain object) so ids such as 'constructor' or 'toString' are
  // not mistaken for already-visited nodes through inherited properties.
  const visited = new Set<string>();

  // id -> node lookup built once; the first node wins when ids are duplicated.
  const nodeById = new Map<string, NodeConfig>();
  nodes.forEach((node) => {
    if (!nodeById.has(node.id)) {
      nodeById.set(node.id, node);
    }
  });

  // Depth-first walk with an explicit stack of "neighbor list + cursor" frames,
  // so long chains cannot overflow the call stack.
  const collectComponent = (startNode: NodeConfig): NodeConfig[] => {
    const discovered: NodeConfig[] = [startNode];
    visited.add(startNode.id);
    const frames = [{ neighborIds: getNeighbors(startNode.id, edges), cursor: 0 }];

    while (frames.length > 0) {
      const frame = frames[frames.length - 1];
      if (frame.cursor >= frame.neighborIds.length) {
        frames.pop();
        continue;
      }
      const neighborId = frame.neighborIds[frame.cursor];
      frame.cursor += 1;
      if (visited.has(neighborId)) continue;

      const neighborNode = nodeById.get(neighborId);
      // The edge points at an id that is not in `nodes`.
      if (!neighborNode) continue;

      visited.add(neighborId);
      discovered.push(neighborNode);
      frames.push({ neighborIds: getNeighbors(neighborId, edges), cursor: 0 });
    }
    // Most recently discovered node first.
    return discovered.reverse();
  };

  for (let i = 0; i < nodes.length; i++) {
    const node = nodes[i];
    if (!visited.has(node.id)) {
      // Every walk started from an unvisited node yields exactly one component.
      allComponents.push(collectComponent(node));
    }
  }
  return allComponents;
}
48 | * refer: http://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm 49 | * @param graph 50 | * @return a list of strongly connected components 51 | */ 52 | export const detectStrongConnectComponents = (graphData: GraphData): NodeConfig[][] => { 53 | const { nodes = [], edges = [] } = graphData 54 | const nodeStack: NodeConfig[] = []; 55 | const inStack = {}; // 辅助判断是否已经在stack中,减少查找开销 56 | const indices = {}; 57 | const lowLink = {}; 58 | const allComponents: NodeConfig[][] = []; 59 | let index = 0; 60 | 61 | const getComponent = (node: NodeConfig) => { 62 | // Set the depth index for v to the smallest unused index 63 | indices[node.id] = index; 64 | lowLink[node.id] = index; 65 | index += 1; 66 | nodeStack.push(node); 67 | inStack[node.id] = true; 68 | 69 | // 考虑每个邻接点 70 | const neighbors = getNeighbors(node.id, edges, 'target').filter((n) => nodes.map(node => node.id).indexOf(n) > -1); 71 | for (let i = 0; i < neighbors.length; i++) { 72 | const targetNodeID = neighbors[i]; 73 | if (!indices[targetNodeID] && indices[targetNodeID] !== 0) { 74 | const targetNode = nodes.filter(node => node.id === targetNodeID) 75 | if (targetNode.length > 0) { 76 | getComponent(targetNode[0]); 77 | } 78 | // tree edge 79 | lowLink[node.id] = Math.min(lowLink[node.id], lowLink[targetNodeID]); 80 | } else if (inStack[targetNodeID]) { 81 | // back edge, target node is in the current SCC 82 | lowLink[node.id] = Math.min(lowLink[node.id], indices[targetNodeID]); 83 | } 84 | } 85 | 86 | // If node is a root node, generate an SCC 87 | if (lowLink[node.id] === indices[node.id]) { 88 | const component = []; 89 | while (nodeStack.length > 0) { 90 | const tmpNode = nodeStack.pop(); 91 | inStack[tmpNode.id] = false; 92 | component.push(tmpNode); 93 | if (tmpNode === node) break; 94 | } 95 | if (component.length > 0) { 96 | allComponents.push(component); 97 | } 98 | } 99 | }; 100 | 101 | for (const node of nodes) { 102 | if (!indices[node.id] && indices[node.id] !== 0) 
/**
 * Vector arithmetic on plain number arrays.
 *
 * Binary operations expect both vectors to have the same length. When the
 * lengths differ an error is logged and `undefined` is returned.
 */
class Vector {
  arr: number[];

  constructor(arr) {
    this.arr = arr;
  }

  /** The underlying array, or an empty array when none was supplied. */
  getArr() {
    return this.arr || [];
  }

  /**
   * Element-wise sum. If either side is empty the other side is returned
   * (wrapped in a new Vector that shares the same array).
   */
  add(otherVector) {
    const otherArr = otherVector.arr;
    if (!this.arr?.length) {
      return new Vector(otherArr);
    }
    if (!otherArr?.length) {
      return new Vector(this.arr);
    }
    if (this.arr.length !== otherArr.length) {
      console.error('The two vectors are unequal in length.');
      return undefined;
    }
    const res = [];
    for (let i = 0; i < this.arr.length; i++) {
      res[i] = this.arr[i] + otherArr[i];
    }
    return new Vector(res);
  }

  /**
   * Element-wise difference (this - other).
   * NOTE(review): when this vector is empty the other vector is returned
   * un-negated; kept as-is because callers may rely on it.
   */
  subtract(otherVector) {
    const otherArr = otherVector.arr;
    if (!this.arr?.length) {
      return new Vector(otherArr);
    }
    if (!otherArr?.length) {
      return new Vector(this.arr);
    }
    if (this.arr.length !== otherArr.length) {
      console.error('The two vectors are unequal in length.');
      return undefined;
    }
    const res = [];
    for (let i = 0; i < this.arr.length; i++) {
      res[i] = this.arr[i] - otherArr[i];
    }
    return new Vector(res);
  }

  /** Divide every component by `length`; yields an empty vector when `length` is 0. */
  avg(length) {
    const res = [];
    if (length !== 0) {
      const values = this.getArr();
      for (let i = 0; i < values.length; i++) {
        res[i] = values[i] / length;
      }
    }
    return new Vector(res);
  }

  /** Vector with every component negated. */
  negate() {
    const res = [];
    const values = this.getArr();
    for (let i = 0; i < values.length; i++) {
      res[i] = -values[i];
    }
    return new Vector(res);
  }

  /** Squared Euclidean distance; 0 when either vector is empty. */
  squareEuclideanDistance(otherVector) {
    const otherArr = otherVector.arr;
    if (!this.arr?.length || !otherArr?.length) {
      return 0;
    }
    if (this.arr.length !== otherArr.length) {
      console.error('The two vectors are unequal in length.');
      return undefined;
    }
    let res = 0;
    for (let i = 0; i < this.arr.length; i++) {
      res += Math.pow(this.arr[i] - otherArr[i], 2);
    }
    return res;
  }

  /** Euclidean distance; 0 when either vector is empty. */
  euclideanDistance(otherVector) {
    const squared = this.squareEuclideanDistance(otherVector);
    return squared === undefined ? undefined : Math.sqrt(squared);
  }

  /**
   * Min-max normalisation: rescale components to [0, 1].
   * A vector whose components are all equal maps to all zeros (instead of
   * NaN from a 0/0 division).
   */
  normalize() {
    const values = this.getArr();
    let min = Infinity;
    let max = -Infinity;
    for (let i = 0; i < values.length; i++) {
      if (values[i] < min) min = values[i];
      if (values[i] > max) max = values[i];
    }
    const range = max - min;
    const res = [];
    for (let i = 0; i < values.length; i++) {
      res[i] = range === 0 ? 0 : (values[i] - min) / range;
    }
    return new Vector(res);
  }

  /** 2-norm (length) of the vector; 0 for an empty vector. */
  norm2() {
    if (!this.arr?.length) {
      return 0;
    }
    let res = 0;
    for (let i = 0; i < this.arr.length; i++) {
      res += Math.pow(this.arr[i], 2);
    }
    return Math.sqrt(res);
  }

  /** Dot product; 0 when either vector is empty. */
  dot(otherVector) {
    const otherArr = otherVector.arr;
    if (!this.arr?.length || !otherArr?.length) {
      return 0;
    }
    if (this.arr.length !== otherArr.length) {
      console.error('The two vectors are unequal in length.');
      return undefined;
    }
    let res = 0;
    for (let i = 0; i < this.arr.length; i++) {
      res += this.arr[i] * otherArr[i];
    }
    return res;
  }

  /** True when both vectors have the same length and strictly equal components. */
  equal(otherVector) {
    const otherArr = otherVector.arr;
    if (this.arr?.length !== otherArr?.length) {
      return false;
    }
    const values = this.getArr();
    for (let i = 0; i < values.length; i++) {
      if (values[i] !== otherArr[i]) {
        return false;
      }
    }
    return true;
  }
}

export default Vector;
13 | * 14 | * @see https://github.com/princeofpython/PageRank-with-CUDA/blob/main/parallel.cu 15 | */ 16 | export async function pageRank(device: WebGLRenderer.Device, graphData: GraphData, eps = 1e-05, alpha = 0.85, maxIteration = 1000) { 17 | const BLOCK_SIZE = 1; 18 | const BLOCKS = 256; 19 | 20 | const { V, From, To } = convertGraphData2CSC(graphData); 21 | 22 | const n = V.length; 23 | const graph = new Float32Array(new Array(n * n).fill((1 - alpha) / n)); 24 | const r = new Float32Array(new Array(n).fill(1 / n)); 25 | 26 | From.forEach((from, i) => { 27 | graph[To[i] * n + from] += alpha * 1.0; 28 | }); 29 | 30 | for (let j = 0; j < n; j++) { 31 | let sum = 0.0; 32 | 33 | for (let i = 0; i < n; ++i) { 34 | sum += graph[i * n + j]; 35 | } 36 | 37 | for (let i = 0; i < n; ++i) { 38 | if (sum != 0.0) { 39 | graph[i * n + j] /= sum; 40 | } else { 41 | graph[i * n + j] = 1 / n; 42 | } 43 | } 44 | } 45 | 46 | const storeKernel = new Kernel(device, { 47 | computeShader: ` 48 | struct Buffer { 49 | data: array; 50 | }; 51 | 52 | @group(0) @binding(0) var r : Buffer; 53 | @group(0) @binding(1) var r_last : Buffer; 54 | 55 | @stage(compute) @workgroup_size(${BLOCKS}, ${BLOCK_SIZE}) 56 | fn main( 57 | @builtin(global_invocation_id) global_id : vec3 58 | ) { 59 | var index = global_id.x; 60 | if (index < ${V.length}u) { 61 | r_last.data[index] = r.data[index]; 62 | } 63 | }`, 64 | }); 65 | 66 | const matmulKernel = new Kernel(device, { 67 | computeShader: ` 68 | struct Buffer { 69 | data: array; 70 | }; 71 | 72 | @group(0) @binding(0) var graph : Buffer; 73 | @group(0) @binding(1) var r : Buffer; 74 | @group(0) @binding(2) var r_last : Buffer; 75 | 76 | @stage(compute) @workgroup_size(${BLOCKS}, ${BLOCK_SIZE}) 77 | fn main( 78 | @builtin(global_invocation_id) global_id : vec3 79 | ) { 80 | var index = global_id.x; 81 | if (index < ${V.length}u) { 82 | var sum = 0.0; 83 | for (var i = 0u; i < ${V.length}u; i = i + 1u) { 84 | sum = sum + r_last.data[i] * graph.data[index * 
${V.length}u + i]; 85 | } 86 | r.data[index] = sum; 87 | } 88 | } 89 | `, 90 | }); 91 | 92 | const rankDiffKernel = new Kernel(device, { 93 | computeShader: ` 94 | struct Buffer { 95 | data: array; 96 | }; 97 | 98 | @group(0) @binding(0) var r : Buffer; 99 | @group(0) @binding(1) var r_last : Buffer; 100 | 101 | @stage(compute) @workgroup_size(${BLOCKS}, ${BLOCK_SIZE}) 102 | fn main( 103 | @builtin(global_invocation_id) global_id : vec3 104 | ) { 105 | var index = global_id.x; 106 | if (index < ${V.length}u) { 107 | r_last.data[index] = abs(r_last.data[index] - r.data[index]); 108 | } 109 | } 110 | `, 111 | }); 112 | 113 | const rBuffer = device.createBuffer({ 114 | usage: BufferUsage.STORAGE | BufferUsage.COPY_SRC, 115 | viewOrSize: new Float32Array(r), 116 | }); 117 | const rLastBuffer = device.createBuffer({ 118 | usage: BufferUsage.STORAGE | BufferUsage.COPY_SRC, 119 | viewOrSize: new Float32Array(n), 120 | }); 121 | const graphBuffer = device.createBuffer({ 122 | usage: BufferUsage.STORAGE, 123 | viewOrSize: new Float32Array(graph), 124 | }); 125 | 126 | const readback = device.createReadback(); 127 | 128 | storeKernel.setBinding(0, rBuffer); 129 | storeKernel.setBinding(1, rLastBuffer); 130 | 131 | matmulKernel.setBinding(0, graphBuffer); 132 | matmulKernel.setBinding(1, rBuffer); 133 | matmulKernel.setBinding(2, rLastBuffer); 134 | 135 | rankDiffKernel.setBinding(0, rBuffer); 136 | rankDiffKernel.setBinding(1, rLastBuffer); 137 | 138 | const grids = Math.ceil(V.length / (BLOCKS * BLOCK_SIZE)); 139 | 140 | while (maxIteration--) { 141 | storeKernel.dispatch(grids, 1); 142 | matmulKernel.dispatch(grids, 1); 143 | rankDiffKernel.dispatch(grids, 1); 144 | 145 | const last = await readback.readBuffer(rLastBuffer) as Float32Array; 146 | const result = last.reduce((prev, cur) => prev + cur, 0); 147 | if (result < eps) { 148 | break; 149 | } 150 | } 151 | 152 | const out = await readback.readBuffer(rBuffer) as Float32Array; 153 | return Array.from(out) 154 | 
/**
 * Label propagation clustering.
 *
 * Every node starts in its own cluster and repeatedly adopts the cluster
 * that carries the largest total adjacency weight among its neighbors, until
 * nothing changes or `maxIteration` rounds have run. Ties are broken with
 * `Math.random()`, so results may differ between runs.
 *
 * Side effect: `clusterId` is written onto every node of `graphData.nodes`.
 *
 * @param graphData graph data
 * @param directed whether the graph is directed, false by default
 * @param weightPropertyName edge property that holds the weight used when aggregating cluster edges
 * @param maxIteration maximum number of propagation rounds
 * @return the non-empty clusters and the edges aggregated between clusters
 */
const labelPropagation = (
  graphData: GraphData,
  directed: boolean = false,
  weightPropertyName: string = 'weight',
  maxIteration: number = 1000
): ClusterData => {
  // the origin data
  const { nodes = [], edges = [] } = graphData;

  const clusters = {};
  // A Map keeps ids such as 'constructor' from resolving to inherited properties.
  const nodeMap = new Map();
  // init the clusters and nodeMap: one singleton cluster per node
  nodes.forEach((node) => {
    const cid: string = uniqueId();
    node.clusterId = cid;
    clusters[cid] = {
      id: cid,
      nodes: [node]
    };
    nodeMap.set(node.id, node);
  });

  // the adjacent matrix of the graph
  const adjMatrix = getAdjMatrix(graphData, directed);
  /**
   * neighbor nodes (id for key and matrix entry for value) for each node
   * neighbors = {
   *   id(node_id): { id(neighbor_1_id): entry, id(neighbor_2_id): entry, ... },
   *   ...
   * }
   */
  const neighbors = {};
  adjMatrix.forEach((row, i) => {
    const iid = nodes[i].id;
    neighbors[iid] = {};
    row.forEach((entry, j) => {
      if (!entry) return;
      neighbors[iid][nodes[j].id] = entry;
    });
  });

  for (let iter = 0; iter < maxIteration; iter++) {
    let changed = false;
    nodes.forEach(node => {
      // total neighbor weight per cluster, as seen from this node
      const neighborClusters = {};
      // `|| {}` covers a node for which the matrix supplied no row
      const nodeNeighbors = neighbors[node.id] || {};
      Object.keys(nodeNeighbors).forEach(neighborId => {
        const neighborClusterId = nodeMap.get(neighborId).clusterId;
        if (!neighborClusters[neighborClusterId]) neighborClusters[neighborClusterId] = 0;
        neighborClusters[neighborClusterId] += nodeNeighbors[neighborId];
      });

      // find the cluster(s) with max weight
      let maxWeight = -Infinity;
      let bestClusterIds = [];
      Object.keys(neighborClusters).forEach(clusterId => {
        if (maxWeight < neighborClusters[clusterId]) {
          maxWeight = neighborClusters[clusterId];
          bestClusterIds = [clusterId];
        } else if (maxWeight === neighborClusters[clusterId]) {
          bestClusterIds.push(clusterId);
        }
      });

      // already in the single best cluster: nothing to do
      if (bestClusterIds.length === 1 && bestClusterIds[0] === node.clusterId) return;
      // otherwise the node must move, so its own cluster is not a candidate
      const selfClusterIdx = bestClusterIds.indexOf(node.clusterId);
      if (selfClusterIdx >= 0) bestClusterIds.splice(selfClusterIdx, 1);
      if (bestClusterIds.length === 0) return;

      changed = true;

      // remove from origin cluster
      const selfCluster = clusters[node.clusterId as string];
      const nodeInSelfClusterIdx = selfCluster.nodes.indexOf(node);
      selfCluster.nodes.splice(nodeInSelfClusterIdx, 1);

      // move the node to one of the best clusters, chosen at random
      const randomIdx = Math.floor(Math.random() * bestClusterIds.length);
      const bestCluster = clusters[bestClusterIds[randomIdx]];
      bestCluster.nodes.push(node);
      node.clusterId = bestCluster.id;
    });
    if (!changed) break;
  }

  // get the cluster edges: original edges folded onto their endpoints' clusters
  const clusterEdges = [];
  const clusterEdgeMap = {};
  edges.forEach(edge => {
    const { source, target } = edge;
    const sourceNode = nodeMap.get(source);
    const targetNode = nodeMap.get(target);
    // skip edges whose endpoint is not part of `nodes` instead of crashing on them
    if (!sourceNode || !targetNode) return;

    const weight = edge[weightPropertyName] || 1;
    const sourceClusterId = sourceNode.clusterId;
    const targetClusterId = targetNode.clusterId;
    const newEdgeId = `${sourceClusterId}---${targetClusterId}`;
    if (clusterEdgeMap[newEdgeId]) {
      clusterEdgeMap[newEdgeId].weight += weight;
      clusterEdgeMap[newEdgeId].count++;
    } else {
      const newEdge = {
        source: sourceClusterId,
        target: targetClusterId,
        weight,
        count: 1
      };
      clusterEdgeMap[newEdgeId] = newEdge;
      clusterEdges.push(newEdge);
    }
  });

  // keep only the clusters that still own at least one node
  const clustersArray = Object.keys(clusters)
    .map(clusterId => clusters[clusterId])
    .filter(cluster => cluster.nodes && cluster.nodes.length);

  return {
    clusters: clustersArray,
    clusterEdges
  }
}

export default labelPropagation;
| source: 'B', 34 | target: 'C', 35 | }, 36 | { 37 | source: 'C', 38 | target: 'G', 39 | }, 40 | { 41 | source: 'A', 42 | target: 'D', 43 | }, 44 | { 45 | source: 'A', 46 | target: 'E', 47 | }, 48 | { 49 | source: 'E', 50 | target: 'F', 51 | }, 52 | { 53 | source: 'F', 54 | target: 'D', 55 | }, 56 | { 57 | source: 'D', 58 | target: 'G', 59 | }, 60 | ], 61 | }; 62 | 63 | describe('depthFirstSearch', () => { 64 | it('should perform DFS operation on graph', () => { 65 | 66 | const enterNodeCallback = jest.fn(); 67 | const leaveNodeCallback = jest.fn(); 68 | 69 | // Traverse graphs without callbacks first to check default ones. 70 | depthFirstSearch(data, 'A'); 71 | 72 | // Traverse graph with enterNode and leaveNode callbacks. 73 | depthFirstSearch(data, 'A', { 74 | enter: enterNodeCallback, 75 | leave: leaveNodeCallback, 76 | }); 77 | 78 | expect(enterNodeCallback).toHaveBeenCalledTimes(data.nodes.length); 79 | expect(leaveNodeCallback).toHaveBeenCalledTimes(data.nodes.length); 80 | 81 | const enterNodeParamsMap = [ 82 | { currentNode: 'A', previousNode: '' }, 83 | { currentNode: 'B', previousNode: 'A' }, 84 | { currentNode: 'C', previousNode: 'B' }, 85 | { currentNode: 'G', previousNode: 'C' }, 86 | { currentNode: 'D', previousNode: 'A' }, 87 | { currentNode: 'E', previousNode: 'A' }, 88 | { currentNode: 'F', previousNode: 'E' }, 89 | ]; 90 | 91 | for (let callIndex = 0; callIndex < data.nodes.length; callIndex += 1) { 92 | const params = enterNodeCallback.mock.calls[callIndex][0]; 93 | expect(params.current).toEqual(enterNodeParamsMap[callIndex].currentNode); 94 | expect(params.previous).toEqual( 95 | enterNodeParamsMap[callIndex].previousNode, 96 | ); 97 | } 98 | 99 | const leaveNodeParamsMap = [ 100 | { currentNode: 'G', previousNode: 'C' }, 101 | { currentNode: 'C', previousNode: 'B' }, 102 | { currentNode: 'B', previousNode: 'A' }, 103 | { currentNode: 'D', previousNode: 'A' }, 104 | { currentNode: 'F', previousNode: 'E' }, 105 | { currentNode: 'E', 
previousNode: 'A' }, 106 | { currentNode: 'A', previousNode: '' }, 107 | ]; 108 | 109 | for (let callIndex = 0; callIndex < data.nodes.length; callIndex += 1) { 110 | const params = leaveNodeCallback.mock.calls[callIndex][0]; 111 | expect(params.current).toEqual(leaveNodeParamsMap[callIndex].currentNode); 112 | expect(params.previous).toEqual( 113 | leaveNodeParamsMap[callIndex].previousNode, 114 | ); 115 | } 116 | }); 117 | 118 | it('allow users to redefine node visiting logic', () => { 119 | 120 | const enterNodeCallback = jest.fn(); 121 | const leaveNodeCallback = jest.fn(); 122 | 123 | depthFirstSearch(data, 'A', { 124 | enter: enterNodeCallback, 125 | leave: leaveNodeCallback, 126 | allowTraversal: ({ current: currentNode, next: nextNode }) => { 127 | return !(currentNode === 'A' && nextNode === 'B'); 128 | }, 129 | }); 130 | 131 | expect(enterNodeCallback).toHaveBeenCalledTimes(7); 132 | expect(leaveNodeCallback).toHaveBeenCalledTimes(7); 133 | 134 | const enterNodeParamsMap = [ 135 | { currentNode: 'A', previousNode: '' }, 136 | { currentNode: 'D', previousNode: 'A' }, 137 | { currentNode: 'G', previousNode: 'D' }, 138 | { currentNode: 'E', previousNode: 'A' }, 139 | { currentNode: 'F', previousNode: 'E' }, 140 | { currentNode: 'D', previousNode: 'F' }, 141 | { currentNode: 'G', previousNode: 'D' }, 142 | ]; 143 | 144 | for (let callIndex = 0; callIndex < data.nodes.length; callIndex += 1) { 145 | const params = enterNodeCallback.mock.calls[callIndex][0]; 146 | expect(params.current).toEqual(enterNodeParamsMap[callIndex].currentNode); 147 | expect(params.previous && params.previous).toEqual( 148 | enterNodeParamsMap[callIndex].previousNode, 149 | ); 150 | } 151 | 152 | const leaveNodeParamsMap = [ 153 | { currentNode: 'G', previousNode: 'D' }, 154 | { currentNode: 'D', previousNode: 'A' }, 155 | { currentNode: 'G', previousNode: 'D' }, 156 | { currentNode: 'D', previousNode: 'F' }, 157 | { currentNode: 'F', previousNode: 'E' }, 158 | { currentNode: 'E', 
previousNode: 'A' }, 159 | { currentNode: 'A', previousNode: '' }, 160 | ]; 161 | 162 | for (let callIndex = 0; callIndex < data.nodes.length; callIndex += 1) { 163 | const params = leaveNodeCallback.mock.calls[callIndex][0]; 164 | expect(params.current).toEqual(leaveNodeParamsMap[callIndex].currentNode); 165 | expect(params.previous).toEqual( 166 | leaveNodeParamsMap[callIndex].previousNode, 167 | ); 168 | } 169 | }); 170 | }); 171 | -------------------------------------------------------------------------------- /packages/graph/tests/unit/bfs-spec.ts: -------------------------------------------------------------------------------- 1 | import { breadthFirstSearch } from '../../src'; 2 | 3 | const data = { 4 | nodes: [ 5 | { 6 | id: 'A', 7 | }, 8 | { 9 | id: 'B', 10 | }, 11 | { 12 | id: 'C', 13 | }, 14 | { 15 | id: 'D', 16 | }, 17 | { 18 | id: 'E', 19 | }, 20 | { 21 | id: 'F', 22 | }, 23 | { 24 | id: 'G', 25 | }, 26 | { 27 | id: 'H', 28 | }, 29 | ], 30 | edges: [ 31 | { 32 | source: 'A', 33 | target: 'B', 34 | }, 35 | { 36 | source: 'B', 37 | target: 'C', 38 | }, 39 | { 40 | source: 'C', 41 | target: 'G', 42 | }, 43 | { 44 | source: 'A', 45 | target: 'D', 46 | }, 47 | { 48 | source: 'A', 49 | target: 'E', 50 | }, 51 | { 52 | source: 'E', 53 | target: 'F', 54 | }, 55 | { 56 | source: 'F', 57 | target: 'D', 58 | }, 59 | ], 60 | }; 61 | 62 | describe('breadthFirstSearch', () => { 63 | it('should perform BFS operation on graph', () => { 64 | const enterNodeCallback = jest.fn(); 65 | const leaveNodeCallback = jest.fn(); 66 | 67 | // Traverse graphs without callbacks first. 68 | breadthFirstSearch(data, 'A'); 69 | 70 | // Traverse graph with enterNode and leaveNode callbacks. 
71 | breadthFirstSearch(data, 'A', { 72 | enter: enterNodeCallback, 73 | leave: leaveNodeCallback, 74 | }); 75 | 76 | expect(enterNodeCallback).toHaveBeenCalledTimes(7); 77 | expect(leaveNodeCallback).toHaveBeenCalledTimes(7); 78 | 79 | const nodeA = 'A'; 80 | const nodeB = 'B'; 81 | const nodeC = 'C'; 82 | const nodeD = 'D'; 83 | const nodeE = 'E'; 84 | const nodeF = 'F'; 85 | const nodeG = 'G'; 86 | 87 | const enterNodeParamsMap = [ 88 | { currentNode: nodeA, previousNode: '' }, 89 | { currentNode: nodeB, previousNode: nodeA }, 90 | { currentNode: nodeD, previousNode: nodeB }, 91 | { currentNode: nodeE, previousNode: nodeD }, 92 | { currentNode: nodeC, previousNode: nodeE }, 93 | { currentNode: nodeF, previousNode: nodeC }, 94 | { currentNode: nodeG, previousNode: nodeF }, 95 | ]; 96 | 97 | for (let callIndex = 0; callIndex < 6; callIndex += 1) { 98 | const params = enterNodeCallback.mock.calls[callIndex][0]; 99 | expect(params.current).toEqual(enterNodeParamsMap[callIndex].currentNode); 100 | expect(params.previous).toEqual( 101 | enterNodeParamsMap[callIndex].previousNode && 102 | enterNodeParamsMap[callIndex].previousNode, 103 | ); 104 | } 105 | 106 | const leaveNodeParamsMap = [ 107 | { currentNode: nodeA, previousNode: '' }, 108 | { currentNode: nodeB, previousNode: nodeA }, 109 | { currentNode: nodeD, previousNode: nodeB }, 110 | { currentNode: nodeE, previousNode: nodeD }, 111 | { currentNode: nodeC, previousNode: nodeE }, 112 | { currentNode: nodeF, previousNode: nodeC }, 113 | { currentNode: nodeG, previousNode: nodeF }, 114 | ]; 115 | 116 | for (let callIndex = 0; callIndex < 6; callIndex += 1) { 117 | const params = leaveNodeCallback.mock.calls[callIndex][0]; 118 | expect(params.current).toEqual(leaveNodeParamsMap[callIndex].currentNode); 119 | expect(params.previous).toEqual( 120 | leaveNodeParamsMap[callIndex].previousNode && 121 | leaveNodeParamsMap[callIndex].previousNode, 122 | ); 123 | } 124 | }); 125 | 126 | it('should allow to create custom 
node visiting logic', () => { 127 | 128 | const enterNodeCallback = jest.fn(); 129 | const leaveNodeCallback = jest.fn(); 130 | 131 | // Traverse graph with enterNode and leaveNode callbacks. 132 | breadthFirstSearch(data, 'A', { 133 | enter: enterNodeCallback, 134 | leave: leaveNodeCallback, 135 | allowTraversal: ({ current, next }) => { 136 | return !(current === 'A' && next === 'B'); 137 | }, 138 | }); 139 | 140 | expect(enterNodeCallback).toHaveBeenCalledTimes(5); 141 | expect(leaveNodeCallback).toHaveBeenCalledTimes(5); 142 | 143 | const enterNodeParamsMap = [ 144 | { currentNode: 'A', previousNode: '' }, 145 | { currentNode: 'D', previousNode: 'A' }, 146 | { currentNode: 'E', previousNode: 'D' }, 147 | { currentNode: 'F', previousNode: 'E' }, 148 | { currentNode: 'D', previousNode: 'F' }, 149 | ]; 150 | 151 | for (let callIndex = 0; callIndex < 5; callIndex += 1) { 152 | const params = enterNodeCallback.mock.calls[callIndex][0]; 153 | expect(params.current).toEqual(enterNodeParamsMap[callIndex].currentNode); 154 | expect(params.previous).toEqual( 155 | enterNodeParamsMap[callIndex].previousNode, 156 | ); 157 | } 158 | 159 | const leaveNodeParamsMap = [ 160 | { currentNode: 'A', previousNode: '' }, 161 | { currentNode: 'D', previousNode: 'A' }, 162 | { currentNode: 'E', previousNode: 'D' }, 163 | { currentNode: 'F', previousNode: 'E' }, 164 | { currentNode: 'D', previousNode: 'F' }, 165 | ]; 166 | 167 | for (let callIndex = 0; callIndex < 5; callIndex += 1) { 168 | const params = leaveNodeCallback.mock.calls[callIndex][0]; 169 | expect(params.current).toEqual(leaveNodeParamsMap[callIndex].currentNode); 170 | expect(params.previous).toEqual( 171 | leaveNodeParamsMap[callIndex].previousNode, 172 | ); 173 | } 174 | }); 175 | }); 176 | -------------------------------------------------------------------------------- /packages/graph/src/structs/linked-list.ts: -------------------------------------------------------------------------------- 1 | const 
defaultComparator = (a, b) => { 2 | if (a === b) { 3 | return true; 4 | } 5 | 6 | return false; 7 | } 8 | 9 | /** 10 | * 链表中单个元素节点 11 | */ 12 | export class LinkedListNode { 13 | public value; 14 | 15 | public next: LinkedListNode; 16 | 17 | constructor(value, next: LinkedListNode = null) { 18 | this.value = value; 19 | this.next = next; 20 | } 21 | 22 | toString(callback?: any) { 23 | return callback ? callback(this.value) : `${this.value}`; 24 | } 25 | } 26 | 27 | export default class LinkedList { 28 | public head: LinkedListNode; 29 | 30 | public tail: LinkedListNode; 31 | 32 | public compare: Function; 33 | 34 | constructor(comparator = defaultComparator) { 35 | this.head = null; 36 | this.tail = null; 37 | this.compare = comparator; 38 | } 39 | 40 | /** 41 | * 将指定元素添加到链表头部 42 | * @param value 43 | */ 44 | prepend(value) { 45 | // 在头部添加一个节点 46 | const newNode = new LinkedListNode(value, this.head); 47 | this.head = newNode; 48 | 49 | if (!this.tail) { 50 | this.tail = newNode; 51 | } 52 | 53 | return this; 54 | } 55 | 56 | /** 57 | * 将指定元素添加到链表中 58 | * @param value 59 | */ 60 | append(value) { 61 | const newNode = new LinkedListNode(value); 62 | 63 | // 如果不存在头节点,则将创建的新节点作为头节点 64 | if (!this.head) { 65 | this.head = newNode; 66 | this.tail = newNode; 67 | 68 | return this; 69 | } 70 | 71 | // 将新节点附加到链表末尾 72 | this.tail.next = newNode; 73 | this.tail = newNode; 74 | 75 | return this; 76 | } 77 | 78 | /** 79 | * 删除指定元素 80 | * @param value 要删除的元素 81 | */ 82 | delete(value): LinkedListNode { 83 | if (!this.head) { 84 | return null; 85 | } 86 | 87 | let deleteNode = null; 88 | 89 | // 如果删除的是头部元素,则将next作为头元素 90 | while (this.head && this.compare(this.head.value, value)) { 91 | deleteNode = this.head; 92 | this.head = this.head.next; 93 | } 94 | 95 | let currentNode = this.head; 96 | 97 | if (currentNode !== null) { 98 | // 如果删除了节点以后,将next节点前移 99 | while (currentNode.next) { 100 | if (this.compare(currentNode.next.value, value)) { 101 | deleteNode = currentNode.next; 
102 | currentNode.next = currentNode.next.next; 103 | } else { 104 | currentNode = currentNode.next; 105 | } 106 | } 107 | } 108 | 109 | // 检查尾部节点是否被删除 110 | if (this.compare(this.tail.value, value)) { 111 | this.tail = currentNode; 112 | } 113 | 114 | return deleteNode; 115 | } 116 | 117 | /** 118 | * 查找指定的元素 119 | * @param param0 120 | */ 121 | find({ value = undefined, callback = undefined }): LinkedListNode { 122 | if (!this.head) { 123 | return null; 124 | } 125 | 126 | let currentNode = this.head; 127 | 128 | while (currentNode) { 129 | // 如果指定了 callback,则按指定的 callback 查找 130 | if (callback && callback(currentNode.value)) { 131 | return currentNode; 132 | } 133 | 134 | // 如果指定了 value,则按 value 查找 135 | if (value !== undefined && this.compare(currentNode.value, value)) { 136 | return currentNode; 137 | } 138 | 139 | currentNode = currentNode.next; 140 | } 141 | 142 | return null; 143 | } 144 | 145 | /** 146 | * 删除尾部节点 147 | */ 148 | deleteTail() { 149 | const deletedTail = this.tail; 150 | 151 | if (this.head === this.tail) { 152 | // 链表中只有一个元素 153 | this.head = null; 154 | this.tail = null; 155 | return deletedTail; 156 | } 157 | 158 | let currentNode = this.head; 159 | while (currentNode.next) { 160 | if (!currentNode.next.next) { 161 | currentNode.next = null; 162 | } else { 163 | currentNode = currentNode.next; 164 | } 165 | } 166 | 167 | this.tail = currentNode; 168 | 169 | return deletedTail; 170 | } 171 | 172 | /** 173 | * 删除头部节点 174 | */ 175 | deleteHead() { 176 | if (!this.head) { 177 | return null; 178 | } 179 | 180 | const deletedHead = this.head; 181 | 182 | if (this.head.next) { 183 | this.head = this.head.next; 184 | } else { 185 | this.head = null; 186 | this.tail = null; 187 | } 188 | 189 | return deletedHead; 190 | } 191 | 192 | /** 193 | * 将一组元素转成链表中的节点 194 | * @param values 链表中的元素 195 | */ 196 | fromArray(values) { 197 | values.forEach((value) => this.append(value)); 198 | return this; 199 | } 200 | 201 | /** 202 | * 将链表中的节点转成数组元素 203 | */ 
204 | toArray() { 205 | const nodes = []; 206 | 207 | let currentNode = this.head; 208 | 209 | while (currentNode) { 210 | nodes.push(currentNode); 211 | currentNode = currentNode.next; 212 | } 213 | 214 | return nodes; 215 | } 216 | 217 | /** 218 | * 反转链表中的元素节点 219 | */ 220 | reverse() { 221 | let currentNode = this.head; 222 | let prevNode = null; 223 | let nextNode = null; 224 | while (currentNode) { 225 | // 存储下一个元素节点 226 | nextNode = currentNode.next; 227 | 228 | // 更改当前节点的下一个节点,以便将它连接到上一个节点上 229 | currentNode.next = prevNode; 230 | 231 | // 将 prevNode 和 currentNode 向前移动一步 232 | prevNode = currentNode; 233 | currentNode = nextNode; 234 | } 235 | 236 | this.tail = this.head; 237 | this.head = prevNode; 238 | } 239 | 240 | toString(callback = undefined) { 241 | return this.toArray() 242 | .map((node) => node.toString(callback)) 243 | .toString(); 244 | } 245 | } 246 | -------------------------------------------------------------------------------- /packages/graph/src/types.ts: -------------------------------------------------------------------------------- 1 | 2 | export type Matrix = number[]; 3 | 4 | export interface NodeConfig { 5 | id: string; 6 | clusterId?: string; 7 | [key: string]: any; 8 | } 9 | 10 | export interface EdgeConfig { 11 | source: string; 12 | target: string; 13 | weight?: number; 14 | [key: string]: any; 15 | } 16 | 17 | export interface GraphData { 18 | nodes?: NodeConfig[]; 19 | edges?: EdgeConfig[]; 20 | } 21 | 22 | export interface Cluster { 23 | id: string; 24 | nodes: NodeConfig[]; 25 | sumTot?: number; 26 | } 27 | 28 | export interface ClusterData { 29 | clusters: Cluster[]; 30 | clusterEdges: EdgeConfig[]; 31 | } 32 | 33 | export interface ClusterMap { 34 | [key: string]: Cluster 35 | } 36 | 37 | // 图算法回调方法接口定义 38 | export interface IAlgorithmCallbacks { 39 | enter?: (param: { current: string; previous: string }) => void; 40 | leave?: (param: { current: string; previous?: string }) => void; 41 | allowTraversal?: (param: { previous?: 
string; current?: string; next: string }) => boolean; 42 | } 43 | 44 | export interface DegreeType { 45 | [key: string]: { 46 | degree: number; 47 | inDegree: number; 48 | outDegree: number; 49 | } 50 | } 51 | 52 | export enum DistanceType { 53 | EuclideanDistance = 'euclideanDistance', 54 | } 55 | 56 | export interface PlainObject { 57 | [key: string]: any; 58 | } 59 | 60 | // 数据集中属性/特征值分布的map 61 | export interface KeyValueMap { 62 | [key:string]: any[]; 63 | } 64 | 65 | export interface IAlgorithm { 66 | getAdjMatrix: (graphData: GraphData, directed?: boolean) => Matrix[], 67 | breadthFirstSearch: ( 68 | graphData: GraphData, 69 | startNodeId: string, 70 | originalCallbacks?: IAlgorithmCallbacks, 71 | directed?: boolean 72 | ) => void, 73 | connectedComponent: (graphData: GraphData, directed?: boolean) => NodeConfig[][], 74 | getDegree: (graphData: GraphData) => DegreeType, 75 | getInDegree: (graphData: GraphData, nodeId: string) => number, 76 | getOutDegree: (graphData: GraphData, nodeId: string) => number, 77 | detectCycle: (graphData: GraphData) => { [key: string]: string }, 78 | detectDirectedCycle: (graphData: GraphData) => { [key: string]: string }, 79 | detectAllCycles: (graphData: GraphData, directed?: boolean, nodeIds?: string[], include?: boolean) => any, 80 | detectAllDirectedCycle: (graphData: GraphData, nodeIds?: string[], include?: boolean) => any, 81 | detectAllUndirectedCycle: (graphData: GraphData, nodeIds?: string[], include?: boolean) => any, 82 | depthFirstSearch: (graphData: GraphData, startNodeId: string, callbacks?: IAlgorithmCallbacks) => void, 83 | dijkstra: (graphData: GraphData, source: string, directed?: boolean, weightPropertyName?: string) => { length: object, allPath: object, path: object}, 84 | findAllPath: (graphData: GraphData, start: string, end: string, directed?: boolean) => any, 85 | findShortestPath: (graphData: GraphData, start: string, end: string, directed?: boolean, weightPropertyName?: string) => any, 86 | 
floydWarshall: (graphData: GraphData, directed?: boolean) => Matrix[], 87 | labelPropagation: (graphData: GraphData, directed?: boolean, weightPropertyName?: string, maxIteration?: number) => ClusterData, 88 | louvain: (graphData: GraphData, directed: boolean, weightPropertyName: string, threshold: number) => ClusterData, 89 | minimumSpanningTree: (graphData: GraphData, weight?: string, algo?: string) => EdgeConfig[], 90 | pageRank: (graphData: GraphData, epsilon?: number, linkProb?: number) => { [key: string]: number}, 91 | getNeighbors: (nodeId: string, edges?: EdgeConfig[], type?: 'target' | 'source' | undefined) => string[], 92 | Stack: any, 93 | GADDI: (graphData: GraphData, pattern: GraphData, directed: boolean, k: number, length: number, nodeLabelProp: string, edgeLabelProp: string) => GraphData[], 94 | getAdjMatrixAsync: (graphData: GraphData, directed?: boolean) => Matrix[], 95 | connectedComponentAsync: (graphData: GraphData, directed?: boolean) => NodeConfig[][], 96 | getDegreeAsync: (graphData: GraphData) => DegreeType, 97 | getInDegreeAsync: (graphData: GraphData, nodeId: string) => number, 98 | getOutDegreeAsync: (graphData: GraphData, nodeId: string) => number, 99 | detectCycleAsync: (graphData: GraphData) => { [key: string]: string }, 100 | detectDirectedCycleAsync: (graphData: GraphData) => { [key: string]: string }, 101 | detectAllCyclesAsync: (graphData: GraphData, directed?: boolean, nodeIds?: string[], include?: boolean) => any, 102 | detectAllDirectedCycleAsync: (graphData: GraphData, nodeIds?: string[], include?: boolean) => any, 103 | detectAllUndirectedCycleAsync: (graphData: GraphData, nodeIds?: string[], include?: boolean) => any, 104 | dijkstraAsync: (graphData: GraphData, source: string, directed?: boolean, weightPropertyName?: string) => { length: object, allPath: object, path: object}, 105 | findAllPathAsync: (graphData: GraphData, start: string, end: string, directed?: boolean) => any, 106 | findShortestPathAsync: (graphData: 
GraphData, start: string, end: string, directed?: boolean, weightPropertyName?: string) => any, 107 | floydWarshallAsync: (graphData: GraphData, directed?: boolean) => Matrix[], 108 | labelPropagationAsync: (graphData: GraphData, directed?: boolean, weightPropertyName?: string, maxIteration?: number) => ClusterData, 109 | louvainAsync: (graphData: GraphData, directed: boolean, weightPropertyName: string, threshold: number) => ClusterData, 110 | minimumSpanningTreeAsync: (graphData: GraphData, weight?: string, algo?: string) => EdgeConfig[], 111 | pageRankAsync: (graphData: GraphData, epsilon?: number, linkProb?: number) => { [key: string]: number}, 112 | getNeighborsAsync: (nodeId: string, edges?: EdgeConfig[], type?: 'target' | 'source' | undefined) => string[], 113 | GADDIAsync: (graphData: GraphData, pattern: GraphData, directed: boolean, k: number, length: number, nodeLabelProp: string, edgeLabelProp: string) => GraphData[], 114 | } -------------------------------------------------------------------------------- /packages/webgpu-graph/src/traversal/sssp.ts: -------------------------------------------------------------------------------- 1 | import type { WebGLRenderer } from '@antv/g-webgl'; 2 | import { Kernel, BufferUsage } from '@antv/g-plugin-gpgpu'; 3 | import { GraphData } from '../types'; 4 | import { convertGraphData2CSC } from '../util'; 5 | 6 | /** 7 | * SSSP(Bellman-Ford) ported from CUDA 8 | * 9 | * @see https://www.lewuathe.com/illustration-of-distributed-bellman-ford-algorithm.html 10 | * @see https://github.com/sengorajkumar/gpu_graph_algorithms 11 | * @see https://docs.rapids.ai/api/cugraph/stable/api_docs/api/cugraph.traversal.sssp.sssp.html 12 | * compared with G6: 13 | * @see https://g6.antv.vision/zh/docs/api/Algorithm#findshortestpathgraphdata-start-end-directed-weightpropertyname 14 | */ 15 | export async function sssp(device: WebGLRenderer.Device, graphData: GraphData, sourceId: string, weightPropertyName: string = '', maxDistance = 
1000000) { 16 | // The total number of workgroup invocations (4096) exceeds the maximum allowed (256). 17 | const BLOCK_SIZE = 1; 18 | const BLOCKS = 256; 19 | const MAX_DISTANCE = maxDistance; 20 | 21 | const { V, E, I, nodeId2IndexMap, edges } = convertGraphData2CSC(graphData); 22 | let W: number[]; 23 | const sourceIdx = nodeId2IndexMap[sourceId]; 24 | if (weightPropertyName) { 25 | W = edges.map((edgeConfig) => Number(edgeConfig[weightPropertyName])); 26 | } else { 27 | // all the vertex has the same weight 28 | W = new Array(E.length).fill(1); 29 | } 30 | 31 | const relaxKernel = new Kernel(device, { 32 | computeShader: ` 33 | struct Buffer { 34 | data: array; 35 | }; 36 | struct AtomicBuffer { 37 | data: array>; 38 | }; 39 | 40 | @group(0) @binding(0) var d_in_E : Buffer; 41 | @group(0) @binding(1) var d_in_I : Buffer; 42 | @group(0) @binding(2) var d_in_W : Buffer; 43 | @group(0) @binding(3) var d_out_D : Buffer; 44 | @group(0) @binding(4) var d_out_Di : AtomicBuffer; 45 | 46 | @stage(compute) @workgroup_size(${BLOCKS}, ${BLOCK_SIZE}) 47 | fn main( 48 | @builtin(global_invocation_id) global_id : vec3 49 | ) { 50 | var index = global_id.x; 51 | if (index < ${V.length}u) { 52 | for (var j = d_in_I.data[index]; j < d_in_I.data[index + 1u]; j = j + 1) { 53 | var w = d_in_W.data[j]; 54 | var du = d_out_D.data[index]; 55 | var dv = d_out_D.data[d_in_E.data[j]]; 56 | var newDist = du + w; 57 | if (du == ${MAX_DISTANCE}) { 58 | newDist = ${MAX_DISTANCE}; 59 | } 60 | 61 | if (newDist < dv) { 62 | atomicMin(&d_out_Di.data[d_in_E.data[j]], newDist); 63 | } 64 | } 65 | } 66 | }`, 67 | }); 68 | 69 | const updateDistanceKernel = new Kernel(device, { 70 | computeShader: ` 71 | struct Buffer { 72 | data: array; 73 | }; 74 | 75 | @group(0) @binding(0) var d_out_D : Buffer; 76 | @group(0) @binding(1) var d_out_Di : Buffer; 77 | 78 | @stage(compute) @workgroup_size(${BLOCKS}, ${BLOCK_SIZE}) 79 | fn main( 80 | @builtin(global_invocation_id) global_id : vec3 81 | ) { 82 | var 
index = global_id.x; 83 | if (index < ${V.length}u) { 84 | if (d_out_D.data[index] > d_out_Di.data[index]) { 85 | d_out_D.data[index] = d_out_Di.data[index]; 86 | } 87 | d_out_Di.data[index] = d_out_D.data[index]; 88 | } 89 | } 90 | `, 91 | }); 92 | 93 | const updatePredKernel = new Kernel(device, { 94 | computeShader: ` 95 | struct Buffer { 96 | data: array; 97 | }; 98 | struct AtomicBuffer { 99 | data: array>; 100 | }; 101 | 102 | @group(0) @binding(0) var d_in_V : Buffer; 103 | @group(0) @binding(1) var d_in_E : Buffer; 104 | @group(0) @binding(2) var d_in_I : Buffer; 105 | @group(0) @binding(3) var d_in_W : Buffer; 106 | @group(0) @binding(4) var d_out_D : Buffer; 107 | @group(0) @binding(5) var d_out_P : AtomicBuffer; 108 | 109 | @stage(compute) @workgroup_size(${BLOCKS}, ${BLOCK_SIZE}) 110 | fn main( 111 | @builtin(global_invocation_id) global_id : vec3 112 | ) { 113 | var index = global_id.x; 114 | if (index < ${V.length}u) { 115 | for (var j = d_in_I.data[index]; j < d_in_I.data[index + 1u]; j = j + 1) { 116 | var u = d_in_V.data[index]; 117 | var w = d_in_W.data[j]; 118 | 119 | var dis_u = d_out_D.data[index]; 120 | var dis_v = d_out_D.data[d_in_E.data[j]]; 121 | if (dis_v == dis_u + w) { 122 | atomicMin(&d_out_P.data[d_in_E.data[j]], u); 123 | } 124 | } 125 | } 126 | } 127 | `, 128 | }); 129 | 130 | const VBuffer = device.createBuffer({ 131 | usage: BufferUsage.STORAGE, 132 | viewOrSize: new Int32Array(V), 133 | }); 134 | const EBuffer = device.createBuffer({ 135 | usage: BufferUsage.STORAGE, 136 | viewOrSize: new Int32Array(E), 137 | }); 138 | const IBuffer = device.createBuffer({ 139 | usage: BufferUsage.STORAGE, 140 | viewOrSize: new Int32Array(I), 141 | }); 142 | const WBuffer = device.createBuffer({ 143 | usage: BufferUsage.STORAGE, 144 | viewOrSize: new Int32Array(W), 145 | }); 146 | 147 | // mark source vertex 148 | const view = new Array(V.length).fill(MAX_DISTANCE); 149 | view[sourceIdx] = 0; 150 | 151 | const DOutBuffer = device.createBuffer({ 
152 | usage: BufferUsage.STORAGE | BufferUsage.COPY_SRC, 153 | viewOrSize: new Int32Array(view), 154 | }); 155 | const DiOutBuffer = device.createBuffer({ 156 | usage: BufferUsage.STORAGE | BufferUsage.COPY_SRC, 157 | viewOrSize: new Int32Array(view), 158 | }); 159 | 160 | // store predecessors 161 | const POutBuffer = device.createBuffer({ 162 | usage: BufferUsage.STORAGE | BufferUsage.COPY_SRC, 163 | viewOrSize: new Int32Array(view), 164 | }); 165 | const readback = device.createReadback(); 166 | 167 | relaxKernel.setBinding(0, EBuffer); 168 | relaxKernel.setBinding(1, IBuffer); 169 | relaxKernel.setBinding(2, WBuffer); 170 | relaxKernel.setBinding(3, DOutBuffer); 171 | relaxKernel.setBinding(4, DiOutBuffer); 172 | 173 | updateDistanceKernel.setBinding(0, DOutBuffer); 174 | updateDistanceKernel.setBinding(1, DiOutBuffer); 175 | 176 | updatePredKernel.setBinding(0, VBuffer); 177 | updatePredKernel.setBinding(1, EBuffer); 178 | updatePredKernel.setBinding(2, IBuffer); 179 | updatePredKernel.setBinding(3, WBuffer); 180 | updatePredKernel.setBinding(4, DOutBuffer); 181 | updatePredKernel.setBinding(5, POutBuffer); 182 | 183 | const grids = Math.ceil(V.length / (BLOCKS * BLOCK_SIZE)); 184 | for (let i = 1; i < V.length; i++) { 185 | relaxKernel.dispatch(grids, 1); 186 | updateDistanceKernel.dispatch(grids, 1); 187 | } 188 | updatePredKernel.dispatch(grids, 1); 189 | 190 | const out = await readback.readBuffer(DiOutBuffer) as Float32Array; 191 | const predecessor = await readback.readBuffer(POutBuffer); 192 | 193 | return Array.from(out).map((distance, i) => ({ 194 | target: graphData.nodes[V[i]].id, 195 | distance, 196 | predecessor: graphData.nodes[V[predecessor[i]]].id, 197 | })); 198 | } -------------------------------------------------------------------------------- /packages/graph/tests/unit/kMeans-spec.ts: -------------------------------------------------------------------------------- 1 | import { kMeans } from '../../src'; 2 | import { GraphData, NodeConfig } 
from '../../src/types';
import propertiesGraphData from './data/cluster-origin-properties-data.json';

/**
 * Expects every node listed in `indices` to carry the same clusterId as the
 * node at the first listed index.
 */
const expectSameCluster = (nodes: NodeConfig[], indices: number[]) => {
  const [anchor, ...others] = indices;
  others.forEach(index => {
    expect(nodes[anchor].clusterId).toEqual(nodes[index].clusterId);
  });
};

describe('kMeans abnormal demo', () => {
  it('no properties demo: ', () => {
    // None of the nodes has a property bag, so kMeans is expected to return
    // its single-cluster fallback.
    const noPropertiesData = {
      nodes: [{ id: 'node-0' }, { id: 'node-1' }, { id: 'node-2' }, { id: 'node-3' }],
      edges: [],
    };
    const result = kMeans(noPropertiesData, 2);
    expect(result.clusters.length).toBe(1);
    expect(result.clusterEdges.length).toBe(0);
  });
});

describe('kMeans normal demo', () => {
  const simpleGraphData = {
    nodes: [
      { id: 'node-0', properties: { amount: 10, city: '10001' } },
      { id: 'node-1', properties: { amount: 10000, city: '10002' } },
      { id: 'node-2', properties: { amount: 3000, city: '10003' } },
      { id: 'node-3', properties: { amount: 3200, city: '10003' } },
      { id: 'node-4', properties: { amount: 2000, city: '10003' } },
    ],
    edges: [
      { id: 'edge-0', source: 'node-0', target: 'node-1' },
      { id: 'edge-1', source: 'node-0', target: 'node-2' },
      { id: 'edge-4', source: 'node-3', target: 'node-2' },
      { id: 'edge-5', source: 'node-2', target: 'node-1' },
      { id: 'edge-6', source: 'node-4', target: 'node-1' },
    ],
  };

  it('simple data demo: ', () => {
    const nodes = simpleGraphData.nodes as NodeConfig[];
    const { clusters } = kMeans(simpleGraphData, 3, 'properties');
    expect(clusters.length).toBe(3);
    expectSameCluster(nodes, [2, 3, 4]);
  });

  it('complex data demo: ', () => {
    const nodes = propertiesGraphData.nodes as NodeConfig[];
    const { clusters } = kMeans(propertiesGraphData as GraphData, 3, 'properties');
    expect(clusters.length).toBe(3);
    expectSameCluster(nodes, [0, 1, 2, 3, 4]);
    expectSameCluster(nodes, [5, 6, 7, 8, 9, 10]);
    expectSameCluster(nodes, [11, 12, 13, 14, 15, 16]);
  });

  it('demo use involvedKeys: ', () => {
    const involvedKeys = ['amount'];
    const nodes = simpleGraphData.nodes as NodeConfig[];
    const { clusters } = kMeans(simpleGraphData, 3, 'properties', involvedKeys);
    expect(clusters.length).toBe(3);
    expectSameCluster(nodes, [2, 3, 4]);
  });

  it('demo use uninvolvedKeys: ', () => {
    const uninvolvedKeys = ['id', 'city'];
    const nodes = simpleGraphData.nodes as NodeConfig[];
    const { clusters } = kMeans(simpleGraphData, 3, 'properties', [], uninvolvedKeys);
    expect(clusters.length).toBe(3);
    expectSameCluster(nodes, [2, 3, 4]);
  });
});

describe('kMeans All properties values are numeric demo', () => {
  it('all properties values are numeric demo: ', () => {
    const allPropertiesValuesNumericData = {
      nodes: [
        { id: 'node-0', properties: { max: 1000000, mean: 900000, min: 800000 } },
        { id: 'node-1', properties: { max: 1600000, mean: 1100000, min: 600000 } },
        { id: 'node-2', properties: { max: 5000, mean: 3500, min: 2000 } },
        { id: 'node-3', properties: { max: 9000, mean: 7500, min: 6000 } },
      ],
      edges: [],
    };
    const result = kMeans(allPropertiesValuesNumericData, 2, 'properties');
    expect(result.clusters.length).toBe(2);
    expect(result.clusterEdges.length).toBe(0);
    const nodes = allPropertiesValuesNumericData.nodes as NodeConfig[];
    expectSameCluster(nodes, [0, 1]);
    expectSameCluster(nodes, [2, 3]);
  });

  it('only one property and the value are numeric demo: ', () => {
    const allPropertiesValuesNumericData = {
      nodes: [
        { id: 'node-0', properties: { num: 10 } },
        { id: 'node-1', properties: { num: 12 } },
        { id: 'node-2', properties: { num: 56 } },
        { id: 'node-3', properties: { num: 300 } },
        { id: 'node-4', properties: { num: 350 } },
      ],
      edges: [],
    };
    const result = kMeans(allPropertiesValuesNumericData, 2, 'properties');
    expect(result.clusters.length).toBe(2);
    expect(result.clusterEdges.length).toBe(0);
    const nodes = allPropertiesValuesNumericData.nodes as NodeConfig[];
    expectSameCluster(nodes, [0, 1, 2]);
    expectSameCluster(nodes, [3, 4]);
  });
});

--------------------------------------------------------------------------------
/packages/graph/src/workers/index.ts:
--------------------------------------------------------------------------------
import {
  GraphData,
  DegreeType,
  Matrix,
  ClusterData,
  EdgeConfig,
  NodeConfig,
} from '../types';
import createWorker from './createWorker';
import { ALGORITHM } from './constant';

// Every wrapper below obtains a runner from createWorker for one ALGORITHM key
// and forwards its arguments to that runner unchanged.

/**
 * Adjacency matrix of the graph.
 * @param graphData graph data
 * @param directed whether the graph is directed
 */
function getAdjMatrixAsync(graphData: GraphData, directed?: boolean) {
  return createWorker(ALGORITHM.getAdjMatrix)(graphData, directed);
}

/**
 * Connected components of the graph.
 * @param graphData graph data
 * @param directed whether the graph is directed
 */
function connectedComponentAsync(graphData: GraphData, directed?: boolean) {
  return createWorker(ALGORITHM.connectedComponent)(graphData, directed);
}

/**
 * Degree of the nodes.
 * @param graphData graph data
 */
function getDegreeAsync(graphData: GraphData) {
  return createWorker(ALGORITHM.getDegree)(graphData);
}

/**
 * In-degree of a node.
 * @param graphData graph data
 * @param nodeId node ID
 */
function getInDegreeAsync(graphData: GraphData, nodeId: string) {
  return createWorker(ALGORITHM.getInDegree)(graphData, nodeId);
}

/**
 * Out-degree of a node.
 * @param graphData graph data
 * @param nodeId node ID
 */
function getOutDegreeAsync(graphData: GraphData, nodeId: string) {
  return createWorker(ALGORITHM.getOutDegree)(graphData, nodeId);
}

/**
 * Detect a (directed) cycle in the graph.
 * @param graphData graph data
 */
function detectCycleAsync(graphData: GraphData) {
  return createWorker<{
    [key: string]: string;
  }>(ALGORITHM.detectCycle)(graphData);
}

/**
 * Detect cycles in the graph via ALGORITHM.detectAllCycles.
 * NOTE(review): the previous note described these as (undirected) cycles —
 * confirm against the detectAllCycles implementation.
 * @param graphData graph data
 */
function detectAllCyclesAsync(graphData: GraphData) {
  return createWorker<{
    [key: string]: string;
  }>(ALGORITHM.detectAllCycles)(graphData);
}

/**
 * Detect all (directed) cycles in the graph.
 * @param graphData graph data
 */
function detectAllDirectedCycleAsync(graphData: GraphData) {
  return createWorker<{
    [key: string]: string;
  }>(ALGORITHM.detectAllDirectedCycle)(graphData);
}

/**
 * Detect all (undirected) cycles in the graph.
 * @param graphData graph data
 */
function detectAllUndirectedCycleAsync(graphData: GraphData) {
  return createWorker<{
    [key: string]: string;
  }>(ALGORITHM.detectAllUndirectedCycle)(graphData);
}

/**
 * Dijkstra's algorithm, See {@link https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm}
 * @param graphData graph data
 * @param source presumably the ID of the start node — TODO confirm
 * @param directed whether the graph is directed
 * @param weightPropertyName name of the edge weight property
 */
function dijkstraAsync(
  graphData: GraphData,
  source: string,
  directed?: boolean,
  weightPropertyName?: string,
) {
  return createWorker<{
    length: number;
    path: any;
    allPath: any;
  }>(ALGORITHM.dijkstra)(graphData, source, directed, weightPropertyName);
}

/**
 * Find all paths between two nodes.
 * @param graphData graph data
 * @param start ID of the node the paths start from
 * @param end ID of the node the paths end at
 * @param directed whether the graph is directed
 */
function findAllPathAsync(graphData: GraphData, start: string, end: string, directed?: boolean) {
  return createWorker(ALGORITHM.findAllPath)(graphData, start, end, directed);
}

/**
 * Find the shortest path between two nodes.
 * @param graphData graph data
 * @param start ID of the node the path starts from
 * @param end ID of the node the path ends at
 * @param directed whether the graph is directed
 * @param weightPropertyName name of the edge weight property; when the data
 *   carries no weight, every edge weight defaults to 1
 */
function findShortestPathAsync(
  graphData: GraphData,
  start: string,
  end: string,
  directed?: boolean,
  weightPropertyName?: string,
) {
  return createWorker<{
    length: number;
    path: any;
    allPath: any;
  }>(ALGORITHM.findShortestPath)(graphData, start, end, directed, weightPropertyName);
}

/**
 * Floyd–Warshall algorithm, See {@link https://en.wikipedia.org/wiki/Floyd%E2%80%93Warshall_algorithm}
 * @param graphData graph data
 * @param directed whether the graph is directed
 */
function floydWarshallAsync(graphData: GraphData, directed?: boolean) {
  return createWorker(ALGORITHM.floydWarshall)(graphData, directed);
}

/**
 * Label propagation algorithm.
 * @param graphData graph data
 * @param directed whether the graph is directed
 * @param weightPropertyName property field holding the weight
 * @param maxIteration maximum number of iterations, 1000 when omitted
 */
function labelPropagationAsync(
  graphData: GraphData,
  directed: boolean,
  weightPropertyName: string,
  maxIteration: number = 1000,
) {
  return createWorker(ALGORITHM.labelPropagation)(
    graphData,
    directed,
    weightPropertyName,
    maxIteration,
  );
}

/**
 * Louvain community detection algorithm.
 * @param graphData graph data
 * @param directed whether the graph is directed
 * @param weightPropertyName property field holding the weight
 * @param threshold forwarded unchanged to the louvain algorithm
 */
function louvainAsync(
  graphData: GraphData,
  directed: boolean,
  weightPropertyName: string,
  threshold: number,
) {
  return createWorker(ALGORITHM.louvain)(graphData, directed, weightPropertyName, threshold);
}

/**
 * Minimum spanning tree, See {@link https://en.wikipedia.org/wiki/Kruskal%27s_algorithm}
 * @param graphData graph data
 * @param weight property used as the edge weight; when unspecified, all edges
 *   are treated as having the same weight.
 *   NOTE(review): typed as boolean here although described as a property — confirm.
 * @param algo 'prim' | 'kruskal' algorithm type
 * @return EdgeConfig[] the edges that make up the MST
 */
function minimumSpanningTreeAsync(graphData: GraphData, weight?: boolean, algo?: string) {
  return createWorker(ALGORITHM.minimumSpanningTree)(graphData, weight, algo);
}

/**
 * PageRank https://en.wikipedia.org/wiki/PageRank
 * refer: https://github.com/anvaka/ngraph.pagerank
 * @param graphData graph data
 * @param epsilon precision used to decide convergence, default 0.000001
 * @param linkProb damping factor: the probability that, at any moment, a user
 *   who reached a node keeps following a link to the next node; empirical value 0.85
 */
function pageRankAsync(graphData: GraphData, epsilon?: number, linkProb?: number) {
  return createWorker<{
    [key: string]: number;
  }>(ALGORITHM.pageRank)(graphData, epsilon, linkProb);
}

/**
 * Get all neighbors of the given node.
 * @param nodeId node ID
 * @param edges all edge data of the graph
 * @param type neighbor type
 */
function getNeighborsAsync(
  nodeId: string,
  edges: EdgeConfig[],
  type?: 'target' | 'source' | undefined,
) {
  return createWorker(ALGORITHM.getNeighbors)(nodeId, edges, type);
}

/**
 * GADDI graph pattern matching.
 * @param graphData data of the original graph
 * @param pattern data of the search graph (the pattern to look for in the original graph)
 * @param directed whether to treat the graph as directed, default false
 * @param k parameter k, meaning k-nearest neighbors
 * @param length parameter length
 * @param nodeLabelProp name of the node property holding the node label (category). Default: cluster
 * @param edgeLabelProp name of the edge property holding the edge label (category). Default: cluster
 */
function GADDIAsync(
  graphData: GraphData,
  pattern: GraphData,
  directed: boolean = false,
  k: number,
  length: number,
  nodeLabelProp: string = 'cluster',
  edgeLabelProp: string = 'cluster',
) {
  return createWorker(ALGORITHM.GADDI)(
    graphData,
    pattern,
    directed,
    k,
    length,
    nodeLabelProp,
    edgeLabelProp,
  );
}

export {
  getAdjMatrixAsync,
  connectedComponentAsync,
  getDegreeAsync,
  getInDegreeAsync,
  getOutDegreeAsync,
  detectCycleAsync,
  detectAllCyclesAsync,
  detectAllDirectedCycleAsync,
  detectAllUndirectedCycleAsync,
  dijkstraAsync,
  findAllPathAsync,
  findShortestPathAsync,
  floydWarshallAsync,
  labelPropagationAsync,
  louvainAsync,
  minimumSpanningTreeAsync,
  pageRankAsync,
  getNeighborsAsync,
  GADDIAsync,
};

--------------------------------------------------------------------------------
/packages/graph/src/k-means.ts:
--------------------------------------------------------------------------------
import { isEqual, uniq } from '@antv/util';
import { getAllProperties
} from './utils/node-properties';
import { oneHot, getDistance } from './utils/data-preprocessing';
import Vector from './utils/vector';
import { GraphData, ClusterData, DistanceType } from './types';

// Hard cap on assignment/update rounds so the main loop always terminates.
const MAX_ITERATIONS = 1000;

/**
 * Get the centroid vector that corresponds to the node at `index`.
 * For the Euclidean distance type this is the node's feature vector taken from
 * `allPropertiesWeight`; every other distance type yields an empty array.
 */
const getCentroid = (distanceType, allPropertiesWeight, index) => {
  if (distanceType === DistanceType.EuclideanDistance) {
    return allPropertiesWeight[index];
  }
  return [];
};

/**
 * Distance between a feature vector and a centroid. Only the Euclidean
 * distance type is computed (through getDistance); other types yield 0.
 */
const getNodeDistance = (distanceType, vector, centroid) => {
  if (distanceType === DistanceType.EuclideanDistance) {
    return getDistance(vector, centroid, distanceType);
  }
  return 0;
};

/**
 * k-means algorithm: clusters the nodes into K clusters according to the
 * distance between nodes.
 *
 * Side effects: every node in `data.nodes` gets `originIndex` (its position in
 * the node list) and `clusterId` (the stringified index of its cluster)
 * written onto it. A `clusterId` left over from an earlier run is ignored and
 * overwritten.
 *
 * The single-cluster fallback `{ clusters: [{ id: "0", nodes }], clusterEdges: [] }`
 * is returned when, for the Euclidean distance type, some node lacks
 * `propertyKey`, when no feature vectors could be built, or when the effective
 * cluster count is below 1.
 *
 * NOTE(review): on the regular path each entry of `clusters` is a plain array
 * of nodes, whereas the fallback entry is an `{ id, nodes }` object — confirm
 * which shape ClusterData consumers expect before unifying them.
 *
 * @param data graph data
 * @param k number of centroids (cluster centers)
 * @param propertyKey field name of the node property bag
 * @param involvedKeys keys that take part in the computation
 * @param uninvolvedKeys keys that do not take part in the computation
 * @param distanceType distance type, by default the Euclidean distance of node properties
 */
const kMeans = (
  data: GraphData,
  k: number = 3,
  propertyKey: string = undefined,
  involvedKeys: string[] = [],
  uninvolvedKeys: string[] = ['id'],
  distanceType: DistanceType = DistanceType.EuclideanDistance,
) : ClusterData => {
  const { nodes = [], edges = [] } = data;

  const defaultClusterInfo = {
    clusters: [
      {
        id: "0",
        nodes,
      }
    ],
    clusterEdges: []
  };

  const isEuclidean = distanceType === DistanceType.EuclideanDistance;

  // Euclidean distance needs a property bag on every node; bail out otherwise.
  if (isEuclidean && !nodes.every(node => Object.prototype.hasOwnProperty.call(node, propertyKey))) {
    return defaultClusterInfo;
  }

  // One feature vector per node (one-hot encoded properties), index-aligned with `nodes`.
  let allPropertiesWeight = [];
  if (isEuclidean) {
    const properties = getAllProperties(nodes, propertyKey);
    allPropertiesWeight = oneHot(properties, involvedKeys, uninvolvedKeys);
  }
  if (!allPropertiesWeight.length) {
    return defaultClusterInfo;
  }

  // Count distinct feature vectors. An explicit separator keeps vectors such as
  // [1, 11] and [11, 1] apart; joining with '' would merge them.
  const allPropertiesWeightUniq = uniq(allPropertiesWeight.map(item => item.join(',')));
  // When the node count or the number of distinct vectors is below k, shrink k to the smallest of them.
  const finalK = Math.min(k, nodes.length, allPropertiesWeightUniq.length);
  if (!(finalK >= 1)) {
    // Nothing to seed (e.g. k <= 0): there is no cluster a node could be assigned to.
    return defaultClusterInfo;
  }

  // Record each node's original index; it addresses allPropertiesWeight.
  for (let i = 0; i < nodes.length; i++) {
    nodes[i].originIndex = i;
  }

  // Seed the centroids (cluster centers).
  const centroids = [];
  const centroidIndexList = [];
  const clusters = [];
  for (let i = 0; i < finalK; i++) {
    let seedIndex = 0;
    if (i === 0) {
      // The first centroid is a randomly chosen node.
      seedIndex = Math.floor(Math.random() * nodes.length);
    } else {
      // Every further centroid is the node with the largest average distance to
      // the centroids chosen so far, skipping nodes whose vector already is a centroid.
      let maxDistance = -Infinity;
      for (let m = 0; m < nodes.length; m++) {
        if (centroidIndexList.includes(m)) {
          continue;
        }
        const originIndex = nodes[m].originIndex;
        let totalDistance = 0;
        for (let j = 0; j < centroids.length; j++) {
          totalDistance += getNodeDistance(distanceType, allPropertiesWeight[originIndex], centroids[j]);
        }
        const avgDistance = totalDistance / centroids.length;
        const candidate = getCentroid(distanceType, allPropertiesWeight, originIndex);
        if (avgDistance > maxDistance && !centroids.some(centroid => isEqual(centroid, candidate))) {
          maxDistance = avgDistance;
          seedIndex = m;
        }
      }
    }
    centroids[i] = getCentroid(distanceType, allPropertiesWeight, seedIndex);
    centroidIndexList.push(seedIndex);
    clusters[i] = [nodes[seedIndex]];
    nodes[seedIndex].clusterId = String(i);
  }

  let iterations = 0;
  while (true) {
    // Assignment step. In the first round the seed nodes keep their own cluster
    // and only the remaining nodes are appended. From the second round on every
    // node is re-assigned, so membership is rebuilt from scratch; this also
    // means a clusterId left over from an earlier run is never dereferenced.
    if (iterations > 0) {
      for (let c = 0; c < clusters.length; c++) {
        clusters[c] = [];
      }
    }
    for (let i = 0; i < nodes.length; i++) {
      if (iterations === 0 && centroidIndexList.includes(i)) {
        continue;
      }
      // Find the nearest centroid; ties go to the lowest cluster index.
      let minDistanceIndex = 0;
      let minDistance = Infinity;
      for (let j = 0; j < centroids.length; j++) {
        const distance = getNodeDistance(distanceType, allPropertiesWeight[i], centroids[j]);
        if (distance < minDistance) {
          minDistance = distance;
          minDistanceIndex = j;
        }
      }
      // Put the node into the cluster of its nearest centroid.
      nodes[i].clusterId = String(minDistanceIndex);
      clusters[minDistanceIndex].push(nodes[i]);
    }

    // Update step: move every centroid to the mean vector of its cluster and
    // remember whether any centroid actually moved.
    let centroidMoved = false;
    for (let i = 0; i < clusters.length; i++) {
      const clusterNodes = clusters[i];
      if (!clusterNodes.length) {
        // An empty cluster has no mean; leave its centroid where it is.
        continue;
      }
      let totalVector = new Vector([]);
      for (let j = 0; j < clusterNodes.length; j++) {
        totalVector = totalVector.add(new Vector(allPropertiesWeight[clusterNodes[j].originIndex]));
      }
      const avgVector = totalVector.avg(clusterNodes.length);
      if (!avgVector.equal(new Vector(centroids[i]))) {
        centroidMoved = true;
        centroids[i] = avgVector.getArr();
      }
    }

    iterations++;
    // Stop once every node belongs to a cluster and no centroid moved, or when
    // the iteration cap is reached.
    const allAssigned = nodes.every(node => node.clusterId !== undefined);
    if ((allAssigned && !centroidMoved) || iterations >= MAX_ITERATIONS) {
      break;
    }
  }

  // Get the cluster edges: collapse node-level edges into cluster-level edges
  // with an occurrence count.
  const clusterIdByNodeId = new Map();
  nodes.forEach(node => {
    // Keep the first node seen for an id, like a first-match lookup would.
    if (!clusterIdByNodeId.has(node.id)) {
      clusterIdByNodeId.set(node.id, node.clusterId);
    }
  });
  const clusterEdges = [];
  const clusterEdgeMap = {};
  edges.forEach(edge => {
    const { source, target } = edge;
    // Endpoints that match no node yield `undefined` cluster ids.
    const sourceClusterId = clusterIdByNodeId.get(source);
    const targetClusterId = clusterIdByNodeId.get(target);
    const newEdgeId = `${sourceClusterId}---${targetClusterId}`;
    if (clusterEdgeMap[newEdgeId]) {
      clusterEdgeMap[newEdgeId].count++;
    } else {
      const newEdge = {
        source: sourceClusterId,
        target: targetClusterId,
        count: 1
      };
      clusterEdgeMap[newEdgeId] = newEdge;
      clusterEdges.push(newEdge);
    }
  });

  return { clusters, clusterEdges };
}

export default kMeans;

--------------------------------------------------------------------------------