├── .gitignore ├── requirements.txt ├── native_process ├── README.md ├── walker.py ├── EGES_model.py └── run_EGES.py ├── gpu_process ├── run_EGES.py ├── stream_EGES.py ├── run_EGES_multi_gpu.py └── EGES_module.py ├── README.md ├── utils.py └── data_process.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | __pycache__/ 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.8.0 2 | torch-geometric>=2.0.0 3 | torch-scatter>=2.0.0 4 | torch-sparse>=0.6.0 5 | numpy>=1.19.0 6 | pandas>=1.3.0 7 | matplotlib>=3.4.0 8 | scikit-learn>=0.24.0 9 | tqdm>=4.62.0 10 | networkx>=2.6.0 11 | joblib>=1.0.0 12 | tensorboard>=2.7.0 -------------------------------------------------------------------------------- /native_process/README.md: -------------------------------------------------------------------------------- 1 | # EGES单机实现 2 | 3 | 这是EGES(Enhanced Graph Embedding with Side Information)模型的单机单GPU实现版本。该实现基于原始论文,使用PyTorch框架开发,支持商品推荐场景下的图嵌入学习。 4 | 5 | ## 功能特点 6 | 7 | - 支持商品会话序列的图构建 8 | - 实现基于Node2Vec的随机游走策略 9 | - 集成商品sideinfo(类别、品牌、店铺等) 10 | - 支持模型训练和检查点保存 11 | - 提供嵌入向量的保存和可视化功能 12 | 13 | ## 使用方法 14 | 15 | ### 基本用法 16 | 17 | ```bash 18 | python native_process/run_EGES.py 19 | ``` 20 | 21 | ### 使用自定义参数 22 | 23 | ```bash 24 | python native_process/run_EGES.py \ 25 | --data_path ./data/ \ 26 | --output_dir ./output/native/ \ 27 | --embedding_dim 128 \ 28 | --batch_size 8192 \ 29 | --epochs 10 \ 30 | --visualize 31 | ``` 32 | 33 | ## 参数说明 34 | 35 | - `--data_path`:数据文件路径,默认为 './data/' 36 | - `--output_dir`:输出目录,默认为 './output/native/' 37 | - `--p`:Node2Vec返回参数,默认为0.25 38 | - `--q`:Node2Vec进出参数,默认为2 39 | - `--num_walks`:每个节点的游走次数,默认为10 40 | - `--walk_length`:每次游走的长度,默认为10 41 | - `--window_size`:上下文窗口大小,默认为5 42 | - `--embedding_dim`:嵌入向量维度,默认为128 43 | - `--batch_size`:训练批次大小,默认为512 44 | - `--epochs`:训练轮数,默认为10 45 | - `--lr`:学习率,默认为0.001 46 | - `--seed`:随机种子,默认为42 47 | - `--visualize`:是否可视化嵌入向量 48 | 49 | ## 输出说明 50 | 51 | 训练过程会生成以下文件: 52 | 53 | - `checkpoints/`:模型检查点目录 54 | - `model_epoch_N.pt`:每5轮保存的模型 55 | - `model_final.pt`:最终模型 56 | - `embedding/`:嵌入向量目录 57 | - `node_embeddings.npy`:节点嵌入向量 58 | - `node_embeddings.txt`:文本格式的嵌入向量 59 | - `node_map.txt`:节点ID映射 60 | - `reverse_node_map.txt`:反向节点ID映射 61 | - `plots/`:可视化结果(如果启用) 62 | 63 | ## 性能优化 64 | 65 | - 使用多进程数据加载 66 | - 支持GPU加速 67 | - 实现了高效的随机游走算法 68 | - 优化了内存使用 69 | 70 | ## 常见问题 71 | 72 | 1. **内存不足** 73 | - 减小batch_size 74 | - 减少num_walks和walk_length 75 | - 关闭可视化功能 76 | 77 | 2. **训练速度慢** 78 | - 增加batch_size 79 | - 使用更快的GPU 80 | - 减少epochs数量 81 | 82 | 3. **模型效果不理想** 83 | - 调整p和q参数 84 | - 增加embedding_dim 85 | - 增加训练轮数 -------------------------------------------------------------------------------- /gpu_process/run_EGES.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import torch 4 | import time 5 | import argparse 6 | import os 7 | import sys 8 | from tqdm import tqdm 9 | 10 | # 添加项目根目录到系统路径 11 | current_dir = os.path.dirname(os.path.abspath(__file__)) 12 | project_root = os.path.dirname(current_dir) 13 | sys.path.append(project_root) 14 | 15 | # 导入集成版EGES 16 | from gpu_process.EGES_module import EGESTrainer 17 | from utils import set_seed, plot_loss_curve, write_embedding, visualize_embeddings 18 | from data_process import get_session 19 | 20 | 21 | def main(): 22 | parser = argparse.ArgumentParser(description='集成版EGES实现') 23 | parser.add_argument('--data_path', type=str, default='./data/', help='数据文件路径') 24 | parser.add_argument('--output_dir', type=str, default='./output/integrated/', help='输出目录') 25 | parser.add_argument('--p', type=float, default=0.25, help='返回参数') 26 | parser.add_argument('--q', type=float, default=2, help='进出参数') 27 | parser.add_argument('--walk_length', type=int, default=10, help='随机游走长度') 28 | parser.add_argument('--context_size', type=int, default=5, help='上下文窗口大小') 29 | parser.add_argument('--walks_per_node', type=int, default=10, help='每个节点的游走次数') 30 | parser.add_argument('--embedding_dim', type=int, default=128, help='嵌入维度') 31 | parser.add_argument('--batch_size', type=int, default=128, help='批次大小') 32 | parser.add_argument('--epochs', type=int, default=5, help='训练轮数') 33 | parser.add_argument('--lr', type=float, default=0.001, help='学习率') 34 | parser.add_argument('--seed', type=int, default=42, help='随机种子') 35 | parser.add_argument('--gpu', type=int, default=0, help='使用的GPU ID,-1表示使用CPU') 36 | parser.add_argument('--visualize', action='store_true', help='是否进行向量聚类可视化') 37 | parser.add_argument('--n_clusters', type=int, default=8, help='聚类的簇数量') 38 | parser.add_argument('--full_data', action='store_true', help='是否使用完整数据集进行训练') 39 | args = parser.parse_args() 40 | 41 | # 设置随机种子 42 | set_seed(args.seed) 43 | 44 | # 设置设备 45 | if args.gpu >= 0 and torch.cuda.is_available(): 46 | device = torch.device(f'cuda:{args.gpu}') 47 | print(f"使用GPU: {args.gpu}") 48 | else: 49 | device = torch.device('cpu') 50 | print("使用CPU") 51 | 52 | # 创建输出目录 53 | os.makedirs(args.output_dir, exist_ok=True) 54 | 55 | # 读取数据 56 | print("开始数据加载和预处理") 57 | start_time = time.time() 58 | 59 | # 根据参数决定使用完整数据集还是样本数据 60 | if args.full_data: 61 | action_file = 'action.csv' 62 | print("使用完整数据集进行训练") 63 | else: 64 | action_file = 'action_head.csv' 65 | print("使用样本数据集进行训练") 66 | 67 | action_data = pd.read_csv(os.path.join(args.data_path, action_file)) 68 | print(f"读取数据完成,耗时: {time.time() - start_time:.2f}秒") 69 | print(f"数据形状: {action_data.shape}") 70 | 71 | # 构建会话 72 | start_time = time.time() 73 | session_list = get_session(action_data) 74 | print(f"构建会话完成,耗时: {time.time() - start_time:.2f}秒") 75 | print(f"会话数量: {len(session_list)}") 76 | 77 | # 读取SKUsideinfo 78 | start_time = time.time() 79 | sku_info = pd.read_csv(os.path.join(args.data_path, 'jdata_product.csv')) 80 | side_info = sku_info[['sku_id', 'cate', 'brand', 'shop_id']].values 81 | print(f"读取SKUsideinfo完成,耗时: {time.time() - start_time:.2f}秒") 82 | 83 | # 创建训练器 84 | trainer = EGESTrainer( 85 | session_list=session_list, 86 | side_info=side_info, 87 | embedding_dim=args.embedding_dim, 88 | walk_length=args.walk_length, 89 | context_size=args.context_size, 90 | walks_per_node=args.walks_per_node, 91 | p=args.p, 92 | q=args.q, 93 | lr=args.lr, 94 | device=device 95 | ) 96 | 97 | trainer.train( 98 | epochs=args.epochs, 99 | batch_size=args.batch_size, 100 | output_dir=args.output_dir 101 | ) 102 | 103 | # 保存模型 104 | checkpoint_dir = os.path.join(args.output_dir, 'checkpoints') 105 | os.makedirs(checkpoint_dir, exist_ok=True) 106 | trainer.save_model(os.path.join(checkpoint_dir, 'model_final.pt')) 107 | 108 | # 获取嵌入 109 | print("保存嵌入...") 110 | embedding_dir = os.path.join(args.output_dir, 'embedding') 111 | os.makedirs(embedding_dir, exist_ok=True) 112 | 113 | # 获取节点嵌入 114 | node_embeddings = trainer.get_embeddings() 115 | 116 | # 保存嵌入到文件 117 | np.save(os.path.join(embedding_dir, "node_embeddings.npy"), node_embeddings) 118 | 119 | # 将嵌入写入文本文件 120 | embedding_file = os.path.join(embedding_dir, "node_embeddings.txt") 121 | write_embedding([node_embeddings[node_id] for node_id in sorted(node_embeddings.keys())], embedding_file) 122 | 123 | # 使用utils中的可视化函数 124 | if args.visualize: 125 | print(f"使用数据路径 {args.data_path} 进行可视化...") 126 | visualize_embeddings(node_embeddings, args.data_path, args.output_dir) 127 | 128 | print("训练完成!") 129 | 130 | 131 | if __name__ == "__main__": 132 | main() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # update 2025.3 重构项目,更新到Pytorch新版本 2 | 3 | # EGES (Enhanced Graph Embedding with Side Information) 4 | 5 | 这是一个基于PyTorch实现的EGES(Enhanced Graph Embedding with Side Information)模型。该模型通过结合图结构和节点的侧信息来学习更好的节点表示。 6 | 7 | ## 功能特点 8 | 9 | - 支持单GPU和多GPU分布式训练 10 | - 集成了Node2Vec随机游走采样 11 | - 支持节点侧信息的融合 12 | - 实现了高效的数据加载和批处理 13 | - 提供了嵌入可视化功能 14 | - 支持模型checkpointing和结果保存 15 | - 新增流式数据处理功能,适用于推荐系统场景 16 | 17 | ## 环境要求 18 | 19 | - Python 3.7+ 20 | - PyTorch 1.8+ 21 | - CUDA (推荐用于GPU训练) 22 | - PyTorch Geometric (用于图数据处理) 23 | 24 | ## 安装 25 | 26 | 1. 克隆仓库: 27 | ```bash 28 | git clone [repository_url] 29 | cd EGES 30 | ``` 31 | 32 | 2. 安装依赖: 33 | ```bash 34 | pip install -r requirements.txt 35 | ``` 36 | 37 | ## 数据格式 38 | 39 | 项目需要两个主要的输入文件: 40 | 41 | 1. `action_head.csv`:用户行为数据,包含以下列: 42 | - user_id: 用户ID 43 | - sku_id: 商品ID 44 | - action_time: 行为时间 45 | - module_id: 模块ID 46 | - type: 行为类型 47 | 48 | 2. `jdata_product.csv`:商品侧信息数据,包含以下列: 49 | - sku_id: 商品ID 50 | - cate: 类别ID 51 | - brand: 品牌ID 52 | - shop_id: 店铺ID 53 | 54 | ## 使用方法 55 | 56 | ### 单GPU训练 57 | 58 | ```bash 59 | python gpu_process/run_EGES.py \ 60 | --data_path ./data/ \ 61 | --output_dir ./output/single_gpu/ \ 62 | --epochs 2 \ 63 | --batch_size 128 \ 64 | --embedding_dim 128 \ 65 | --visualize 66 | ``` 67 | 68 | ### 多GPU分布式训练 69 | 70 | ```bash 71 | python gpu_process/run_EGES_multi_gpu.py \ 72 | --data_path ./data/ \ 73 | --output_dir ./output/multi_gpu/ \ 74 | --epochs 10 \ 75 | --batch_size 128 \ 76 | --gpus -1 \ 77 | --embedding_dim 128 \ 78 | --visualize 79 | ``` 80 | 81 | ### 流式数据处理和训练 82 | 83 | 新增流式处理功能,按时间戳划分数据进行增量训练,适用于推荐系统场景: 84 | 85 | ```bash 86 | python gpu_process/stream_EGES.py \ 87 | --data_path ./data/ \ 88 | --output_dir ./output/streaming/ \ 89 | --time_interval 1 \ 90 | --save_interval 10 \ 91 | --epochs 1 \ 92 | --batch_size 128 \ 93 | --embedding_dim 128 94 | ``` 95 | 96 | ### 主要参数说明 97 | 98 | #### 通用参数 99 | - `--data_path`: 数据文件路径 100 | - `--output_dir`: 输出目录 101 | - `--epochs`: 训练轮数 102 | - `--batch_size`: 批次大小 103 | - `--embedding_dim`: 嵌入维度 104 | - `--walk_length`: 随机游走长度 105 | - `--context_size`: 上下文窗口大小 106 | - `--walks_per_node`: 每个节点的游走次数 107 | - `--p`: 返回参数 108 | - `--q`: 进出参数 109 | - `--lr`: 学习率 110 | 111 | #### 多GPU训练参数 112 | - `--gpus`: 使用的GPU数量,-1表示使用所有可用GPU 113 | - `--sync_gradients`: 是否同步梯度(多GPU训练时) 114 | - `--sync_params`: 是否在每个epoch后同步模型参数(多GPU训练时) 115 | - `--visualize`: 是否可视化嵌入向量 116 | 117 | #### 流式处理参数 118 | - `--time_interval`: 时间窗口间隔(小时) 119 | - `--save_interval`: 每多少个窗口保存一次模型和嵌入 120 | - `--cpu`: 是否使用CPU进行训练(不指定则使用GPU) 121 | 122 | ## 输出说明 123 | 124 | ### 批量训练输出 125 | 训练完成后,模型会在指定的输出目录下生成以下文件: 126 | 127 | 1. `/checkpoints/` 128 | - `model_final.pt`: 训练完成的模型权重 129 | 130 | 2. `/embedding/` 131 | - `node_embeddings.npy`: NumPy格式的节点嵌入 132 | - `/plots/`: 嵌入可视化结果(如果启用可视化) 133 | - `cate_dist.png`: 按类别分布的可视化 134 | - `brand_dist.png`: 按品牌分布的可视化 135 | - `shop_dist.png`: 按店铺分布的可视化 136 | 137 | ### 流式处理输出 138 | 流式处理会在指定的输出目录下生成以下文件: 139 | 140 | 1. `/checkpoints/` 141 | - `checkpoint_{window_count}_{time}.pt`: 定期保存的模型检查点 142 | - `model_final.pt`: 最终的模型权重 143 | 144 | 2. `/embeddings_{window_count}_{time}/embedding/` 145 | - `node_embeddings.npy`: 定期保存的节点嵌入向量 146 | 147 | 3. `/final_{timestamp}/` 148 | - 流式处理结束时保存的最终模型和嵌入向量 149 | 150 | ## 项目结构 151 | 152 | ``` 153 | EGES/ 154 | ├── data/ # 数据目录 155 | ├── gpu_process/ # GPU训练相关代码 156 | │ ├── EGES_module.py # EGES模型核心实现 157 | │ ├── run_EGES.py # 单GPU训练脚本 158 | │ ├── run_EGES_multi_gpu.py # 多GPU训练脚本 159 | │ └── stream_EGES.py # 流式数据处理实现 160 | ├── native_process/ # CPU训练相关代码 161 | │ ├── EGES_model.py # 原生EGES模型实现 162 | │ ├── run_EGES.py # CPU训练脚本 163 | │ ├── walker.py # 随机游走实现 164 | │ └── README.md # CPU版本说明 165 | ├── utils.py # 工具函数 166 | ├── data_process.py # 数据处理函数 167 | ├── requirements.txt # 项目依赖 168 | └── README.md # 项目说明 169 | ``` 170 | 171 | ## 版本说明 172 | 173 | 1. **GPU单机版**: 使用PyTorch Geometric实现,支持高效的图数据处理和模型训练。 174 | 2. **GPU多机多卡版**: 支持分布式训练,可以利用多台机器的多个GPU加速训练过程。 175 | 3. **Native CPU版**: 使用原生PyTorch实现,适用于没有GPU资源的环境。 176 | 4. **流式处理版**: 按时间戳划分数据,支持增量训练,适用于推荐系统场景。 177 | 178 | ## 流式处理特点 179 | 180 | - 按时间戳将数据划分为固定时间间隔(默认1小时)的窗口 181 | - 逐个处理时间窗口,增量更新图结构和模型 182 | - 定期保存模型检查点和嵌入向量 183 | - 支持通过Ctrl+C安全中断处理流程 184 | - 处理完所有数据或中断时,保存最终模型 185 | 186 | ## 注意事项 187 | 188 | 1. 多GPU训练时,学习率会根据GPU数量自动调整 189 | 2. 建议在使用多GPU训练时启用梯度同步(--sync_gradients) 190 | 3. 可视化功能会消耗较多内存,对于大规模数据集可能需要调整batch_size 191 | 4. 确保数据文件格式正确,且列名与要求一致 192 | 5. 流式处理中的时间窗口大小需要根据数据特性和业务需求进行调整 193 | 194 | ## 引用 195 | 196 | 如果您使用了本项目的代码,请引用原始EGES论文: 197 | 198 | ```bibtex 199 | @inproceedings{wang2018billion, 200 | title={Billion-scale commodity embedding for e-commerce recommendation in alibaba}, 201 | author={Wang, Jizhe and Huang, Pipei and Zhao, Huan and Zhang, Zhibo and Zhao, Binqiang and Lee, Dik Lun}, 202 | booktitle={Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery \& Data Mining}, 203 | pages={839--848}, 204 | year={2018} 205 | } 206 | ``` -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from sklearn.manifold import TSNE 4 | import torch 5 | import os 6 | import random 7 | import socket 8 | import torch.distributed as dist 9 | import pandas as pd 10 | 11 | 12 | def set_seed(seed): 13 | """设置随机种子以确保结果可复现""" 14 | random.seed(seed) 15 | np.random.seed(seed) 16 | torch.manual_seed(seed) 17 | if torch.cuda.is_available(): 18 | torch.cuda.manual_seed_all(seed) 19 | torch.backends.cudnn.deterministic = True 20 | 21 | 22 | def setup(rank, world_size, master_addr='localhost', master_port='12355'): 23 | """ 24 | 设置分布式训练环境 25 | """ 26 | os.environ['MASTER_ADDR'] = master_addr 27 | os.environ['MASTER_PORT'] = master_port 28 | 29 | # 初始化进程组 30 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 31 | print(f"进程 {rank} 初始化完成") 32 | 33 | 34 | def cleanup(): 35 | """ 36 | 清理分布式训练环境 37 | """ 38 | dist.destroy_process_group() 39 | 40 | 41 | def get_free_port(): 42 | """ 43 | 获取一个可用的端口号 44 | """ 45 | sock = socket.socket() 46 | sock.bind(('', 0)) 47 | port = sock.getsockname()[1] 48 | sock.close() 49 | return str(port) 50 | 51 | 52 | def plot_embeddings(embebed_mat, side_info_mat, output_dir='./data_cache'): 53 | """ 54 | 使用t-SNE可视化嵌入向量 55 | """ 56 | os.makedirs(output_dir, exist_ok=True) 57 | 58 | model = TSNE(n_components=2) 59 | node_pos = model.fit_transform(embebed_mat) 60 | brand_color_idx, shop_color_idx, cate_color_idx = {}, {}, {} 61 | for i in range(len(node_pos)): 62 | brand_color_idx.setdefault(side_info_mat[i, 1], []) 63 | brand_color_idx[side_info_mat[i, 1]].append(i) 64 | shop_color_idx.setdefault(side_info_mat[i, 2], []) 65 | shop_color_idx[side_info_mat[i, 2]].append(i) 66 | cate_color_idx.setdefault(side_info_mat[i, 3], []) 67 | cate_color_idx[side_info_mat[i, 3]].append(i) 68 | 69 | plt.figure(figsize=(10, 8)) 70 | for c, idx in brand_color_idx.items(): 71 | plt.scatter(node_pos[idx, 0], node_pos[idx, 1], label=c) 72 | plt.title('Brand Distribution') 73 | plt.savefig(os.path.join(output_dir, 'brand_dist.png')) 74 | 75 | plt.figure(figsize=(10, 8)) 76 | for c, idx in shop_color_idx.items(): 77 | plt.scatter(node_pos[idx, 0], node_pos[idx, 1], label=c) 78 | plt.title('Shop Distribution') 79 | plt.savefig(os.path.join(output_dir, 'shop_dist.png')) 80 | 81 | plt.figure(figsize=(10, 8)) 82 | for c, idx in cate_color_idx.items(): 83 | plt.scatter(node_pos[idx, 0], node_pos[idx, 1], label=c) 84 | plt.title('Category Distribution') 85 | plt.savefig(os.path.join(output_dir, 'cate_dist.png')) 86 | 87 | 88 | def visualize_embeddings(node_embeddings, data_path, output_dir): 89 | """ 90 | 可视化嵌入向量 91 | 92 | 参数: 93 | node_embeddings: 节点嵌入字典,键为节点ID,值为嵌入向量 94 | data_path: 数据路径,用于读取SKUsideinfo 95 | output_dir: 输出目录,用于保存可视化结果 96 | """ 97 | print("可视化嵌入...") 98 | try: 99 | # 准备嵌入矩阵和sideinfo矩阵 100 | embebed_mat = [] 101 | side_info_mat = [] 102 | 103 | # 读取SKUsideinfo 104 | sku_info = pd.read_csv(data_path + 'jdata_product.csv') 105 | sku_info_dict = {row['sku_id']: row for _, row in sku_info.iterrows()} 106 | 107 | for node_id in sorted(node_embeddings.keys()): 108 | if node_id in sku_info_dict: 109 | embebed_mat.append(node_embeddings[node_id]) 110 | row = sku_info_dict[node_id] 111 | side_info_mat.append([node_id, row['brand'], row['shop_id'], row['cate']]) 112 | 113 | if len(embebed_mat) > 0: 114 | embebed_mat = np.array(embebed_mat) 115 | side_info_mat = np.array(side_info_mat) 116 | 117 | # 可视化嵌入 118 | plot_dir = os.path.join(output_dir, 'embedding', 'plots') 119 | os.makedirs(plot_dir, exist_ok=True) 120 | plot_embeddings(embebed_mat, side_info_mat, output_dir=plot_dir) 121 | print(f"嵌入可视化完成,结果保存在 {plot_dir}") 122 | else: 123 | print("没有找到匹配的节点进行可视化") 124 | except Exception as e: 125 | print(f"可视化嵌入时出错: {e}") 126 | 127 | 128 | def write_embedding(embedding_result, output_file): 129 | """ 130 | 将嵌入向量写入文件 131 | """ 132 | with open(output_file, 'w') as f: 133 | for i in range(len(embedding_result)): 134 | s = " ".join(str(val) for val in embedding_result[i].tolist()) 135 | f.write(s + "\n") 136 | 137 | 138 | def save_dict_to_file(dict_obj, output_file): 139 | """ 140 | 将字典保存到文件 141 | """ 142 | with open(output_file, 'w') as f: 143 | for key, value in dict_obj.items(): 144 | f.write(f"{key}\t{value}\n") 145 | 146 | 147 | def load_dict_from_file(input_file): 148 | """ 149 | 从文件加载字典 150 | """ 151 | dict_obj = {} 152 | with open(input_file, 'r') as f: 153 | for line in f: 154 | key, value = line.strip().split('\t') 155 | dict_obj[int(key)] = int(value) 156 | return dict_obj 157 | 158 | 159 | def plot_loss_curve(losses, output_dir='./output'): 160 | """ 161 | Plot the loss curve 162 | 163 | Parameters: 164 | losses: list of losses for each epoch 165 | output_dir: output directory for saving the loss curve 166 | """ 167 | try: 168 | plt.figure(figsize=(10, 6)) 169 | plt.plot(range(1, len(losses) + 1), losses, 'b-', marker='o', label='Training Loss') 170 | plt.title('Loss Curve') 171 | plt.xlabel('Epoch') 172 | plt.ylabel('Loss') 173 | plt.grid(True) 174 | plt.legend() 175 | 176 | # Ensure output directory exists 177 | os.makedirs(output_dir, exist_ok=True) 178 | 179 | # Save image 180 | loss_curve_path = os.path.join(output_dir, 'loss_curve.png') 181 | plt.savefig(loss_curve_path) 182 | plt.close() 183 | 184 | print(f"Loss curve saved to {loss_curve_path}") 185 | except Exception as e: 186 | print(f"Error plotting loss curve: {e}") 187 | 188 | 189 | def save_embeddings(embeddings, node_map, output_dir): 190 | """ 191 | 保存节点嵌入到文件 192 | 193 | 参数: 194 | embeddings: 嵌入矩阵 195 | node_map: 节点映射(id -> 原始节点) 196 | output_dir: 输出目录 197 | """ 198 | # 确保输出目录存在 199 | os.makedirs(output_dir, exist_ok=True) 200 | 201 | # 只保存numpy格式的嵌入 202 | np.save(os.path.join(output_dir, 'node_embeddings.npy'), embeddings) 203 | 204 | print(f"嵌入已保存到 {output_dir}") -------------------------------------------------------------------------------- /native_process/walker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import networkx as nx 3 | import torch 4 | import torch_geometric as pyg 5 | from torch_geometric.nn import Node2Vec 6 | from torch_geometric.utils import from_networkx 7 | import time 8 | from tqdm import tqdm 9 | 10 | 11 | class FastGraphWalker: 12 | def __init__(self, p=1, q=1, device=None): 13 | """ 14 | 初始化随机游走器 15 | 16 | 参数: 17 | p: 返回参数,控制立即重访节点的可能性 18 | q: 进出参数,允许搜索区分"向内"和"向外"节点 19 | device: 计算设备 20 | """ 21 | self.p = p 22 | self.q = q 23 | 24 | # 设置设备 25 | if device is None: 26 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 27 | else: 28 | self.device = device 29 | 30 | def build_graph(self, session_list): 31 | """ 32 | 构建图并直接在GPU上进行处理 33 | 34 | 参数: 35 | session_list: 会话列表 36 | 37 | 返回: 38 | pyg_data: PyG图数据 39 | node_maps: 节点映射元组 (node_map, reverse_node_map) 40 | """ 41 | print("构建图...") 42 | start_time = time.time() 43 | 44 | # 提取所有边 45 | edges = [] 46 | for session in session_list: 47 | if len(session) > 1: 48 | for i in range(len(session) - 1): 49 | # 确保节点是整数类型 50 | u = int(session[i]) 51 | v = int(session[i + 1]) 52 | edges.append((u, v)) 53 | 54 | print(f"提取的边数量: {len(edges)}") 55 | if len(edges) == 0: 56 | print("警告: 没有提取到边,无法构建图") 57 | return None, None 58 | 59 | # 创建NetworkX图 60 | G = nx.Graph() 61 | G.add_edges_from(edges) 62 | 63 | # 获取所有唯一节点 64 | nodes = list(G.nodes()) 65 | node_map = {node: i for i, node in enumerate(nodes)} 66 | reverse_node_map = {i: node for i, node in enumerate(nodes)} 67 | 68 | # 重新映射节点ID 69 | G_relabeled = nx.relabel_nodes(G, node_map) 70 | 71 | # 转换为PyG图 72 | pyg_data = from_networkx(G_relabeled).to(self.device) 73 | 74 | end_time = time.time() 75 | print(f"图构建完成,耗时: {end_time - start_time:.2f}秒") 76 | print(f"图包含 {G.number_of_nodes()} 个节点和 {G.number_of_edges()} 条边") 77 | 78 | return pyg_data, (node_map, reverse_node_map) 79 | 80 | 81 | def generate_walks(self, pyg_data, num_walks, walk_length, window_size): 82 | print(f"生成随机游走 (p={self.p}, q={self.q})...") 83 | start_time = time.time() 84 | 85 | # 使用PyG的Node2Vec 86 | model = Node2Vec( 87 | pyg_data.edge_index, 88 | embedding_dim=64, 89 | walk_length=walk_length, 90 | context_size=5, 91 | walks_per_node=num_walks, # 直接设置总游走次数 92 | p=self.p, 93 | q=self.q, 94 | sparse=True 95 | ).to(self.device) 96 | 97 | # 直接在GPU上生成游走 98 | loader = model.loader(batch_size=128, shuffle=True) 99 | all_walks = [] 100 | 101 | for walk_batch in tqdm(loader, desc="生成游走"): 102 | pos_rw, _ = walk_batch # 关键修正:解包元组 103 | all_walks.append(pos_rw.cpu().numpy()) 104 | 105 | # 合并结果 106 | all_walks = np.concatenate(all_walks, axis=0) 107 | 108 | # 生成上下文对(保持原有逻辑) 109 | print("生成上下文对...") 110 | all_pairs = [] 111 | for walk in all_walks: 112 | for i in range(len(walk)): 113 | for j in range(max(0, i-window_size), min(len(walk), i+window_size+1)): 114 | if i != j: 115 | all_pairs.append((walk[i], walk[j])) 116 | 117 | print(f"生成样本对数量: {len(all_pairs)}") 118 | return all_pairs 119 | 120 | 121 | class SimpleWalker: 122 | def __init__(self, p=1, q=1): 123 | """ 124 | 初始化简单随机游走器 125 | 126 | 参数: 127 | p: 返回参数 128 | q: 进出参数 129 | """ 130 | self.p = p 131 | self.q = q 132 | 133 | def build_graph(self, session_list): 134 | """ 135 | 构建图 136 | 137 | 参数: 138 | session_list: 会话列表 139 | 140 | 返回: 141 | G: NetworkX图 142 | node_maps: 节点映射元组 (node_map, reverse_node_map) 143 | """ 144 | print("构建图...") 145 | start_time = time.time() 146 | 147 | # 提取所有边 148 | edges = [] 149 | for session in session_list: 150 | if len(session) > 1: 151 | for i in range(len(session) - 1): 152 | u = int(session[i]) 153 | v = int(session[i + 1]) 154 | edges.append((u, v)) 155 | 156 | print(f"提取的边数量: {len(edges)}") 157 | if len(edges) == 0: 158 | print("警告: 没有提取到边,无法构建图") 159 | return None, None 160 | 161 | # 创建NetworkX图 162 | G = nx.Graph() 163 | G.add_edges_from(edges) 164 | 165 | # 获取所有唯一节点 166 | nodes = list(G.nodes()) 167 | node_map = {node: i for i, node in enumerate(nodes)} 168 | reverse_node_map = {i: node for i, node in enumerate(nodes)} 169 | 170 | end_time = time.time() 171 | print(f"图构建完成,耗时: {end_time - start_time:.2f}秒") 172 | print(f"图包含 {G.number_of_nodes()} 个节点和 {G.number_of_edges()} 条边") 173 | 174 | return G, (node_map, reverse_node_map) 175 | 176 | def generate_walks(self, G, num_walks, walk_length): 177 | """ 178 | 生成随机游走 179 | 180 | 参数: 181 | G: NetworkX图 182 | num_walks: 每个节点的游走次数 183 | walk_length: 每次游走的长度 184 | 185 | 返回: 186 | walks: 随机游走序列 187 | """ 188 | print(f"生成随机游走 (p={self.p}, q={self.q})...") 189 | start_time = time.time() 190 | 191 | walks = [] 192 | nodes = list(G.nodes()) 193 | 194 | for _ in range(num_walks): 195 | np.random.shuffle(nodes) 196 | for node in tqdm(nodes): 197 | walk = [node] 198 | for _ in range(walk_length - 1): 199 | curr = walk[-1] 200 | neighbors = list(G.neighbors(curr)) 201 | if len(neighbors) == 0: 202 | break 203 | walk.append(np.random.choice(neighbors)) 204 | walks.append(walk) 205 | 206 | end_time = time.time() 207 | print(f"随机游走完成,耗时: {end_time - start_time:.2f}秒") 208 | print(f"生成的游走序列数量: {len(walks)}") 209 | 210 | return walks -------------------------------------------------------------------------------- /data_process.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from itertools import chain 4 | import time 5 | import networkx as nx 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader 8 | import os 9 | from tqdm import tqdm 10 | import multiprocessing 11 | from datetime import datetime 12 | 13 | 14 | def cnt_session(data, time_cut=30, cut_type=2): 15 | """ 16 | 根据时间间隔和操作类型划分会话 17 | 18 | 参数: 19 | data: 包含用户行为数据的DataFrame 20 | time_cut: 时间间隔阈值(分钟) 21 | cut_type: 切分会话的行为类型 22 | 23 | 返回: 24 | session: 会话列表 25 | """ 26 | sku_list = data['sku_id'] 27 | time_list = data['action_time'] 28 | type_list = data['type'] 29 | 30 | # 确保时间列表中的元素是datetime对象 31 | time_list = [pd.to_datetime(t) if isinstance(t, str) else t for t in time_list] 32 | 33 | session = [] 34 | tmp_session = [] 35 | for i, item in enumerate(sku_list): 36 | # 检查是否是最后一个元素或者是否是切分类型 37 | if type_list[i] == cut_type or i == len(sku_list)-1: 38 | tmp_session.append(item) 39 | session.append(tmp_session) 40 | tmp_session = [] 41 | # 检查与下一个元素的时间间隔 42 | elif i < len(sku_list)-1: 43 | # 计算时间差(分钟) 44 | time_diff = (time_list[i+1] - time_list[i]).total_seconds() / 60 45 | if time_diff > time_cut: 46 | tmp_session.append(item) 47 | session.append(tmp_session) 48 | tmp_session = [] 49 | else: 50 | tmp_session.append(item) 51 | 52 | return session 53 | 54 | 55 | def get_session(action_data, use_type=None, verbose=True): 56 | """ 57 | 获取会话列表 58 | 59 | 参数: 60 | action_data: 用户行为数据 61 | use_type: 要使用的行为类型列表 62 | verbose: 是否打印详细信息 63 | 64 | 返回: 65 | flattened_sessions: 展平后的会话列表 66 | """ 67 | if verbose: 68 | print("原始数据形状:", action_data.shape) 69 | print("原始数据列:", action_data.columns) 70 | 71 | # 将时间字段转换为datetime类型 72 | action_data['action_time'] = pd.to_datetime(action_data['action_time']) 73 | 74 | if use_type is None: 75 | use_type = [1, 2, 3, 5] 76 | action_data = action_data[action_data['type'].isin(use_type)] 77 | if verbose: 78 | print("过滤后数据形状:", action_data.shape) 79 | 80 | action_data = action_data.sort_values(by=['user_id', 'action_time'], ascending=True) 81 | group_action_data = action_data.groupby('user_id').agg(list) 82 | if verbose: 83 | print("分组后数据形状:", group_action_data.shape) 84 | 85 | session_list = group_action_data.apply(cnt_session, axis=1) 86 | session_list = session_list.to_numpy() 87 | if verbose: 88 | print("会话列表长度:", len(session_list)) 89 | 90 | # 展平会话列表 91 | flattened_sessions = [] 92 | for sessions in session_list: 93 | flattened_sessions.extend(sessions) 94 | if verbose: 95 | print("处理后的会话数量:", len(flattened_sessions)) 96 | 97 | return flattened_sessions 98 | 99 | 100 | def get_graph_context_all_pairs(walks, window_size): 101 | """ 102 | 根据游走序列生成上下文对 103 | 使用并行处理提高效率 104 | 105 | 参数: 106 | walks: 随机游走序列 107 | window_size: 上下文窗口大小 108 | 109 | 返回: 110 | all_pairs: 所有上下文对 111 | """ 112 | # 确定CPU核心数 113 | num_cores = multiprocessing.cpu_count() 114 | # 将walks分成num_cores份 115 | chunk_size = max(1, len(walks) // num_cores) 116 | chunks = [walks[i:i+chunk_size] for i in range(0, len(walks), chunk_size)] 117 | 118 | # 创建进程池 119 | pool = multiprocessing.Pool(processes=num_cores) 120 | 121 | # 并行处理每个chunk 122 | results = pool.starmap(_process_walk_chunk, [(chunk, window_size) for chunk in chunks]) 123 | 124 | # 关闭进程池 125 | pool.close() 126 | pool.join() 127 | 128 | # 合并结果 129 | all_pairs = [] 130 | for result in results: 131 | all_pairs.extend(result) 132 | 133 | return all_pairs 134 | 135 | 136 | def _process_walk_chunk(walks_chunk, window_size): 137 | """ 138 | 处理一个walks chunk,生成上下文对 139 | 140 | 参数: 141 | walks_chunk: 随机游走序列的一部分 142 | window_size: 上下文窗口大小 143 | 144 | 返回: 145 | pairs: 上下文对列表 146 | """ 147 | pairs = [] 148 | for walk in walks_chunk: 149 | if len(walk) <= 1: 150 | continue 151 | 152 | for i in range(len(walk)): 153 | for j in range(max(0, i - window_size), min(len(walk), i + window_size + 1)): 154 | if i != j: 155 | pairs.append((walk[i], walk[j])) 156 | 157 | return pairs 158 | 159 | 160 | class GraphDataset(Dataset): 161 | """ 162 | 图数据集类,用于加载图嵌入训练数据 163 | """ 164 | def __init__(self, side_info, pairs, node_map): 165 | """ 166 | 初始化数据集 167 | 168 | 参数: 169 | side_info: 节点sideinfo 170 | pairs: 节点对 171 | node_map: 节点映射 172 | """ 173 | self.side_info = side_info 174 | self.pairs = pairs 175 | self.node_map = node_map 176 | 177 | # 创建SKU ID到索引的映射 178 | self.sku_to_idx = {} 179 | for i, row in enumerate(side_info): 180 | self.sku_to_idx[row[0]] = i 181 | 182 | def __len__(self): 183 | return len(self.pairs) 184 | 185 | def __getitem__(self, idx): 186 | node, context = self.pairs[idx] 187 | 188 | # 确保节点在side_info中 189 | if node in self.sku_to_idx: 190 | node_idx = self.sku_to_idx[node] 191 | node_features = self.side_info[node_idx] 192 | 193 | # 确保上下文节点在node_map中 194 | if context in self.node_map: 195 | context_idx = self.node_map[context] 196 | # 将特征转换为张量并确保是长整型 197 | return torch.tensor(node_features, dtype=torch.long), torch.tensor(context_idx, dtype=torch.long) 198 | 199 | # 如果节点不在side_info中或上下文节点不在node_map中,返回一个默认值 200 | # 使用第一个节点作为默认值 201 | return torch.tensor(self.side_info[0], dtype=torch.long), torch.tensor(0, dtype=torch.long) 202 | 203 | 204 | def create_dataloader(side_info, pairs, node_map, batch_size=512, num_workers=4, distributed=False, world_size=1, rank=0, drop_last=False): 205 | """ 206 | 创建数据加载器 207 | 支持分布式训练 208 | 209 | 参数: 210 | side_info: 节点sideinfo 211 | pairs: 节点对 212 | node_map: 节点映射 213 | batch_size: 批次大小 214 | num_workers: 工作进程数 215 | distributed: 是否使用分布式训练 216 | world_size: 总进程数 217 | rank: 当前进程的排名 218 | drop_last: 是否丢弃最后一个不完整的批次 219 | 220 | 返回: 221 | dataloader: 数据加载器 222 | """ 223 | dataset = GraphDataset(side_info, pairs, node_map) 224 | 225 | if distributed: 226 | # 创建分布式采样器 227 | from torch.utils.data.distributed import DistributedSampler 228 | sampler = DistributedSampler( 229 | dataset, 230 | num_replicas=world_size, 231 | rank=rank, 232 | shuffle=True 233 | ) 234 | 235 | dataloader = DataLoader( 236 | dataset, 237 | batch_size=batch_size, 238 | sampler=sampler, 239 | num_workers=num_workers, 240 | pin_memory=True, 241 | drop_last=drop_last 242 | ) 243 | else: 244 | dataloader = DataLoader( 245 | dataset, 246 | batch_size=batch_size, 247 | shuffle=True, 248 | num_workers=num_workers, 249 | pin_memory=True, 250 | drop_last=drop_last 251 | ) 252 | 253 | return dataloader -------------------------------------------------------------------------------- /native_process/EGES_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import torch.distributed as dist 6 | 7 | 8 | class EGES_Model(nn.Module): 9 | def __init__(self, num_nodes, num_feat, feature_lens, embedding_dim=128, lr=0.001): 10 | """ 11 | 初始化EGES模型 12 | 13 | 参数: 14 | num_nodes: 节点数量 15 | num_feat: 特征数量 16 | feature_lens: 每个特征的长度列表 17 | embedding_dim: 嵌入维度 18 | lr: 学习率 19 | """ 20 | super(EGES_Model, self).__init__() 21 | self.num_feat = num_feat 22 | self.feature_lens = feature_lens 23 | self.embedding_dim = embedding_dim 24 | self.num_nodes = num_nodes 25 | self.lr = lr 26 | 27 | # 初始化嵌入层 28 | self.embedding_layers = nn.ModuleList() 29 | for i in range(self.num_feat): 30 | embedding_layer = nn.Embedding(self.feature_lens[i], self.embedding_dim) 31 | # 使用Xavier初始化以提高收敛速度 32 | nn.init.xavier_uniform_(embedding_layer.weight) 33 | self.embedding_layers.append(embedding_layer) 34 | 35 | # 注意力网络 - 使用更高效的实现 36 | self.attention_network = nn.Sequential( 37 | nn.Linear(self.num_feat, self.num_feat), 38 | nn.ReLU(), 39 | nn.Linear(self.num_feat, self.num_feat), 40 | nn.Softmax(dim=1) 41 | ) 42 | 43 | # 输出层 - 使用嵌入矩阵共享权重 44 | self.node_embeddings = nn.Embedding(num_nodes, embedding_dim) 45 | nn.init.xavier_uniform_(self.node_embeddings.weight) 46 | 47 | # 学习率调度器 48 | self.scheduler = None 49 | 50 | def init_optimizer(self, lr=None, distributed=False): 51 | """ 52 | 初始化优化器,支持分布式训练 53 | 54 | 参数: 55 | lr: 学习率,如果为None则使用默认值 56 | distributed: 是否使用分布式训练 57 | """ 58 | if lr is not None: 59 | self.lr = lr 60 | 61 | # 优化器 - 使用带权重衰减的Adam 62 | self.optimizer = torch.optim.AdamW( 63 | self.parameters(), 64 | lr=self.lr, 65 | weight_decay=1e-5, # 添加L2正则化 66 | betas=(0.9, 0.999) # 使用默认动量参数 67 | ) 68 | 69 | # 学习率调度器 - 使用ReduceLROnPlateau 70 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 71 | self.optimizer, 72 | mode='min', 73 | factor=0.5, 74 | patience=2, 75 | verbose=True 76 | ) 77 | 78 | def forward(self, inputs, context_indices=None): 79 | """ 80 | 前向传播 81 | 82 | 参数: 83 | inputs: 包含多个特征列的列表 84 | context_indices: 上下文节点索引,如果提供则同时计算上下文嵌入 85 | 86 | 返回: 87 | node_embeddings: 节点嵌入 88 | context_embeddings: 上下文嵌入(如果提供了context_indices) 89 | """ 90 | batch_size = inputs[0].size(0) 91 | 92 | # 对每个特征进行嵌入查找 93 | embed_list = [] 94 | for i in range(self.num_feat): 95 | # 确保索引在有效范围内 96 | valid_indices = torch.clamp(inputs[i], 0, self.feature_lens[i] - 1) 97 | embed_list.append(self.embedding_layers[i](valid_indices)) 98 | 99 | # 堆叠嵌入 [batch_size, embedding_dim, num_feat] 100 | stacked_embeds = torch.stack(embed_list, dim=2) 101 | 102 | # 计算注意力权重 103 | # 为每个节点创建特征ID向量 104 | feature_ids = torch.arange(self.num_feat, device=inputs[0].device).expand(batch_size, self.num_feat) 105 | attention_weights = self.attention_network(feature_ids.float()) 106 | 107 | # 应用注意力权重 [batch_size, embedding_dim, 1] 108 | attention_weights = attention_weights.unsqueeze(1) 109 | weighted_embeds = torch.matmul(stacked_embeds, attention_weights.transpose(1, 2)) 110 | 111 | # 最终嵌入 [batch_size, embedding_dim] 112 | node_embeddings = weighted_embeds.squeeze(2) 113 | 114 | # 如果提供了上下文索引,同时计算上下文嵌入 115 | if context_indices is not None: 116 | # 确保上下文索引在有效范围内 117 | valid_context_indices = torch.clamp(context_indices, 0, self.num_nodes - 1) 118 | context_embeddings = self.node_embeddings(valid_context_indices) 119 | return node_embeddings, context_embeddings 120 | 121 | return node_embeddings 122 | 123 | def compute_loss(self, node_embeddings, context_embeddings, labels=None): 124 | """ 125 | 计算损失函数 126 | 使用负采样的Skip-gram模型 127 | 128 | 参数: 129 | node_embeddings: 节点嵌入 130 | context_embeddings: 上下文嵌入 131 | labels: 标签(可选) 132 | 133 | 返回: 134 | loss: 损失值 135 | """ 136 | batch_size = node_embeddings.size(0) 137 | 138 | # 计算正样本得分 139 | pos_score = torch.sum(node_embeddings * context_embeddings, dim=1) 140 | pos_score = torch.sigmoid(pos_score) 141 | 142 | # 计算负样本得分 143 | # 使用批次内其他样本作为负样本 144 | neg_score = torch.matmul(node_embeddings, context_embeddings.t()) 145 | neg_score = torch.sigmoid(neg_score) 146 | 147 | # 创建标签矩阵 148 | neg_mask = torch.ones((batch_size, batch_size), device=node_embeddings.device) - torch.eye(batch_size, device=node_embeddings.device) 149 | 150 | # 计算正样本损失 151 | pos_loss = -torch.log(pos_score + 1e-10).mean() 152 | 153 | # 计算负样本损失 154 | neg_loss = -torch.sum(torch.log(1 - neg_score + 1e-10) * neg_mask) / (batch_size * (batch_size - 1)) 155 | 156 | # 总损失 157 | loss = pos_loss + neg_loss 158 | 159 | return loss 160 | 161 | def get_embeddings(self, inputs): 162 | """ 163 | 获取节点嵌入 164 | 165 | 参数: 166 | inputs: 包含多个特征列的列表 167 | 168 | 返回: 169 | embeddings: 节点嵌入 170 | """ 171 | with torch.no_grad(): 172 | return self.forward(inputs) 173 | 174 | def train_step(self, inputs, labels): 175 | """ 176 | 训练一步 177 | 178 | 参数: 179 | inputs: 包含多个特征列的列表 180 | labels: 标签 181 | 182 | 返回: 183 | loss: 损失值 184 | """ 185 | self.train() 186 | self.optimizer.zero_grad() 187 | 188 | # 获取节点嵌入和上下文嵌入 189 | node_embeddings, context_embeddings = self.forward(inputs, labels) 190 | 191 | # 计算损失 192 | loss = self.compute_loss(node_embeddings, context_embeddings, labels) 193 | 194 | # 反向传播 195 | loss.backward() 196 | 197 | # 梯度裁剪,防止梯度爆炸 198 | torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=5.0) 199 | 200 | # 更新参数 201 | self.optimizer.step() 202 | 203 | return loss.item() 204 | 205 | def update_lr(self, val_loss): 206 | """ 207 | 更新学习率 208 | 209 | 参数: 210 | val_loss: 验证损失 211 | """ 212 | if self.scheduler is not None: 213 | self.scheduler.step(val_loss) 214 | 215 | def sync_parameters(self): 216 | """ 217 | 在分布式训练中同步模型参数 218 | """ 219 | for param in self.parameters(): 220 | dist.broadcast(param.data, 0) 221 | 222 | def reduce_gradients(self): 223 | """ 224 | 在分布式训练中归约梯度 225 | """ 226 | for param in self.parameters(): 227 | if param.requires_grad: 228 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 229 | param.grad.data /= dist.get_world_size() -------------------------------------------------------------------------------- /native_process/run_EGES.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import torch 4 | import time 5 | import argparse 6 | import os 7 | import sys 8 | 9 | # 添加项目根目录到系统路径 10 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 11 | 12 | # 导入公共模块 13 | from native_process.EGES_model import EGES_Model 14 | from utils import set_seed, plot_embeddings, write_embedding, save_dict_to_file, visualize_embeddings 15 | from data_process import get_session, create_dataloader, get_graph_context_all_pairs 16 | from native_process.walker import SimpleWalker 17 | from tqdm import tqdm 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser(description='EGES原生实现') 22 | parser.add_argument('--data_path', type=str, default='../data/') 23 | parser.add_argument('--output_dir', type=str, default='./') 24 | parser.add_argument('--p', type=float, default=0.25, help='返回参数') 25 | parser.add_argument('--q', type=float, default=2, help='进出参数') 26 | parser.add_argument('--num_walks', type=int, default=10, help='每个节点的游走次数') 27 | parser.add_argument('--walk_length', type=int, default=10, help='每次游走的长度') 28 | parser.add_argument('--window_size', type=int, default=5, help='上下文窗口大小') 29 | parser.add_argument('--embedding_dim', type=int, default=128, help='嵌入维度') 30 | parser.add_argument('--batch_size', type=int, default=8192, help='批次大小') 31 | parser.add_argument('--epochs', type=int, default=1, help='训练轮数') 32 | parser.add_argument('--lr', type=float, default=0.001, help='学习率') 33 | parser.add_argument('--seed', type=int, default=42, help='随机种子') 34 | parser.add_argument('--visualize', action='store_true', help='是否可视化嵌入向量') 35 | args = parser.parse_args() 36 | 37 | # 设置随机种子 38 | set_seed(args.seed) 39 | 40 | # 创建输出目录 41 | os.makedirs(args.output_dir, exist_ok=True) 42 | os.makedirs(os.path.join(args.output_dir, 'checkpoints'), exist_ok=True) 43 | os.makedirs(os.path.join(args.output_dir, 'embedding'), exist_ok=True) 44 | os.makedirs(os.path.join(args.output_dir, 'data_cache'), exist_ok=True) 45 | 46 | # 设置设备 47 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 48 | 49 | # 数据加载和预处理 50 | print("开始数据加载和预处理") 51 | 52 | # 读取数据 53 | print("读取数据...") 54 | start_time = time.time() 55 | action_data = pd.read_csv(args.data_path + 'action_head.csv') 56 | end_time = time.time() 57 | print(f"读取数据完成,耗时: {end_time - start_time:.2f}秒") 58 | 59 | # 构建会话 60 | print("构建会话...") 61 | start_time = time.time() 62 | session_list = get_session(action_data) 63 | end_time = time.time() 64 | print(f"构建会话完成,耗时: {end_time - start_time:.2f}秒") 65 | 66 | # 构建图并生成随机游走 67 | walker = SimpleWalker(p=args.p, q=args.q) 68 | G, node_maps = walker.build_graph(session_list) 69 | 70 | if G is None: 71 | print("图构建失败,退出程序") 72 | return 73 | 74 | node_map, reverse_node_map = node_maps 75 | 76 | # 生成随机游走 77 | walks = walker.generate_walks(G, args.num_walks, args.walk_length) 78 | 79 | # 生成上下文对 80 | print("生成上下文对...") 81 | start_time = time.time() 82 | all_pairs = get_graph_context_all_pairs(walks, args.window_size) 83 | end_time = time.time() 84 | print(f"生成上下文对完成,耗时: {end_time - start_time:.2f}秒") 85 | print(f"生成的样本对数量: {len(all_pairs)}") 86 | 87 | # 读取SKUsideinfo 88 | print("读取SKUsideinfo...") 89 | start_time = time.time() 90 | sku_info = pd.read_csv(args.data_path + 'jdata_product.csv') 91 | print(f"SKU信息形状: {sku_info.shape}") 92 | 93 | # 提取特征 94 | side_info = sku_info[['sku_id', 'cate', 'brand', 'shop_id']].values 95 | print(f"sideinfo形状: {side_info.shape}") 96 | 97 | # 创建特征长度列表 98 | feature_lens = [] 99 | for i in range(side_info.shape[1]): 100 | tmp_len = len(set(side_info[:, i])) + 1 # 加1是为了处理未知值 101 | feature_lens.append(tmp_len) 102 | 103 | end_time = time.time() 104 | print(f"读取SKUsideinfo完成,耗时: {end_time - start_time:.2f}秒") 105 | 106 | # 创建数据加载器 107 | dataloader = create_dataloader( 108 | side_info=side_info, 109 | pairs=all_pairs, 110 | node_map=node_map, 111 | batch_size=args.batch_size, 112 | num_workers=4 113 | ) 114 | 115 | # 创建模型 116 | model = EGES_Model( 117 | num_nodes=len(node_map), 118 | num_feat=side_info.shape[1], 119 | feature_lens=feature_lens, 120 | embedding_dim=args.embedding_dim, 121 | lr=args.lr 122 | ).to(device) 123 | 124 | # 初始化优化器 125 | model.init_optimizer() 126 | 127 | # 训练模型 128 | print("开始训练...") 129 | 130 | for epoch in range(args.epochs): 131 | model.train() 132 | total_loss = 0 133 | 134 | # 使用tqdm显示进度条 135 | pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.epochs}") 136 | 137 | for batch_idx, (features, contexts) in enumerate(pbar): 138 | # 将数据移到设备上 139 | features = features.to(device) 140 | contexts = contexts.to(device) 141 | 142 | # 将特征拆分为多个列 143 | feature_columns = [features[:, i] for i in range(features.size(1))] 144 | 145 | # 训练一步 146 | loss = model.train_step(feature_columns, contexts) 147 | 148 | total_loss += loss 149 | 150 | # 打印进度 151 | if (batch_idx + 1) % 100 == 0: 152 | print(f"Batch {batch_idx+1}/{len(dataloader)}, Loss: {loss:.4f}") 153 | 154 | # 打印每轮的平均损失 155 | avg_loss = total_loss / len(dataloader) 156 | print(f"Epoch {epoch+1}/{args.epochs}, Average Loss: {avg_loss:.4f}") 157 | 158 | # 更新学习率 159 | model.update_lr(avg_loss) 160 | 161 | # 保存模型 162 | if (epoch + 1) % 5 == 0 or epoch == args.epochs - 1: 163 | model_path = os.path.join(args.output_dir, 'checkpoints', f"model_epoch_{epoch+1}.pt") 164 | 165 | # 保存模型状态 166 | torch.save({ 167 | 'model_state_dict': model.state_dict(), 168 | 'optimizer_state_dict': model.optimizer.state_dict(), 169 | 'epoch': epoch, 170 | 'loss': avg_loss, 171 | 'node_map': node_map, 172 | 'reverse_node_map': reverse_node_map 173 | }, model_path) 174 | 175 | # 如果是最后一轮,保存为final模型 176 | if epoch == args.epochs - 1: 177 | final_model_path = os.path.join(args.output_dir, 'checkpoints', "model_final.pt") 178 | torch.save({ 179 | 'model_state_dict': model.state_dict(), 180 | 'optimizer_state_dict': model.optimizer.state_dict(), 181 | 'epoch': epoch, 182 | 'loss': avg_loss, 183 | 'node_map': node_map, 184 | 'reverse_node_map': reverse_node_map 185 | }, final_model_path) 186 | 187 | # 保存嵌入 188 | print("保存嵌入...") 189 | 190 | # 获取嵌入 191 | embeddings = model.node_embeddings.weight.detach().cpu().numpy() 192 | 193 | # 将嵌入映射回原始节点ID 194 | node_embeddings = {} 195 | for idx, node_id in reverse_node_map.items(): 196 | node_embeddings[node_id] = embeddings[idx] 197 | 198 | # 保存嵌入到文件 199 | np.save(os.path.join(args.output_dir, 'embedding', "node_embeddings.npy"), node_embeddings) 200 | 201 | # 保存节点映射 202 | save_dict_to_file(node_map, os.path.join(args.output_dir, 'embedding', "node_map.txt")) 203 | save_dict_to_file(reverse_node_map, os.path.join(args.output_dir, 'embedding', "reverse_node_map.txt")) 204 | 205 | # 将嵌入写入文本文件 206 | print("将嵌入写入文本文件...") 207 | embedding_file = os.path.join(args.output_dir, 'embedding', "node_embeddings.txt") 208 | write_embedding([node_embeddings[node_id] for node_id in sorted(node_embeddings.keys())], embedding_file) 209 | 210 | # 如果需要可视化,则调用可视化函数 211 | if args.visualize: 212 | visualize_embeddings(node_embeddings, args.data_path, args.output_dir) 213 | 214 | print("训练完成!") 215 | 216 | 217 | if __name__ == "__main__": 218 | main() -------------------------------------------------------------------------------- /gpu_process/stream_EGES.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import os 4 | import sys 5 | import time 6 | import torch 7 | import signal 8 | import argparse 9 | from datetime import datetime, timedelta 10 | from tqdm import tqdm 11 | from torch_geometric.utils import from_networkx 12 | import networkx as nx 13 | import threading 14 | 15 | # 添加项目根目录到系统路径 16 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 17 | 18 | from data_process import get_session 19 | from gpu_process.EGES_module import EGES, EGESTrainer 20 | from utils import save_embeddings 21 | 22 | # 全局变量,用于信号处理 23 | stop_streaming = False 24 | 25 | # 信号处理函数 26 | def signal_handler(sig, frame): 27 | global stop_streaming 28 | print("\n接收到停止信号,正在安全退出...") 29 | stop_streaming = True 30 | 31 | # 注册信号处理 32 | signal.signal(signal.SIGINT, signal_handler) 33 | 34 | class StreamingEGES: 35 | """ 36 | 流式EGES实现,按时间戳划分数据 37 | """ 38 | def __init__(self, args): 39 | self.args = args 40 | self.device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu') 41 | self.time_interval = args.time_interval # 时间间隔(小时) 42 | 43 | # 模型相关参数 44 | self.embedding_dim = args.embedding_dim 45 | self.walk_length = args.walk_length 46 | self.context_size = args.context_size 47 | self.walks_per_node = args.walks_per_node 48 | self.p = args.p 49 | self.q = args.q 50 | self.lr = args.lr 51 | self.batch_size = args.batch_size 52 | self.epochs = args.epochs 53 | 54 | # 初始化节点映射和图 55 | self.G = None 56 | self.node_map = {} 57 | self.reverse_node_map = {} 58 | self.side_info = None 59 | self.side_info_dict = {} 60 | self.feature_lens = None 61 | self.model = None 62 | self.trainer = None 63 | 64 | # 添加窗口计数器 65 | self.window_count = 0 66 | self.save_interval = args.save_interval # 每10个窗口保存一次 67 | 68 | # 初始化时间窗口 69 | self.current_time = None 70 | self.next_window_time = None 71 | 72 | # 创建输出目录 73 | self.output_dir = os.path.join(args.output_dir, 'streaming') 74 | os.makedirs(self.output_dir, exist_ok=True) 75 | 76 | # 加载商品侧信息 77 | self._load_side_info() 78 | 79 | def _load_side_info(self): 80 | """加载商品侧信息""" 81 | print("加载商品侧信息...") 82 | product_file = os.path.join(self.args.data_path, 'jdata_product.csv') 83 | if not os.path.exists(product_file): 84 | print(f"错误: 商品侧信息文件 {product_file} 不存在!") 85 | return 86 | 87 | side_info_df = pd.read_csv(product_file) 88 | side_info_df = side_info_df.fillna(0) 89 | 90 | # 将时间类型转换为时间戳 91 | if 'market_time' in side_info_df.columns: 92 | side_info_df['market_time'] = pd.to_datetime(side_info_df['market_time']) 93 | side_info_df['market_time'] = side_info_df['market_time'].astype(int) // 10**9 94 | 95 | # 保留数值类型特征列 96 | numeric_cols = ['sku_id', 'brand', 'shop_id', 'cate'] 97 | if all(col in side_info_df.columns for col in numeric_cols): 98 | self.side_info = side_info_df[numeric_cols].values.astype(np.int64) 99 | print(f"加载的侧信息形状: {self.side_info.shape}") 100 | else: 101 | print("警告: 侧信息列不完整,使用默认列") 102 | cols = [col for col in numeric_cols if col in side_info_df.columns] 103 | self.side_info = side_info_df[cols].values.astype(np.int64) 104 | 105 | def _initialize_model(self): 106 | """初始化或更新模型""" 107 | # 如果模型已存在,则保存旧模型的状态 108 | old_state_dict = None 109 | if self.model is not None: 110 | old_state_dict = self.model.state_dict() 111 | 112 | # 创建特征长度列表 113 | if self.feature_lens is None and self.side_info is not None: 114 | self.feature_lens = [] 115 | for i in range(self.side_info.shape[1]): 116 | tmp_len = len(set(self.side_info[:, i])) + 1 117 | self.feature_lens.append(tmp_len) 118 | 119 | # 创建PyG图 120 | if self.G is not None: 121 | pyg_data = from_networkx(self.G).to(self.device) 122 | 123 | # 创建EGESTrainer 124 | self.trainer = EGESTrainer( 125 | session_list=None, # 不需要再次构建图 126 | side_info=self.side_info, 127 | embedding_dim=self.embedding_dim, 128 | walk_length=self.walk_length, 129 | context_size=self.context_size, 130 | walks_per_node=self.walks_per_node, 131 | p=self.p, 132 | q=self.q, 133 | lr=self.lr, 134 | device=self.device, 135 | prefetch_factor=2, 136 | G=self.G, 137 | node_map=self.node_map, 138 | reverse_node_map=self.reverse_node_map, 139 | pyg_data=pyg_data 140 | ) 141 | 142 | self.model = self.trainer.model 143 | 144 | # 如果有旧模型的状态,加载到新模型 145 | if old_state_dict is not None: 146 | # 尝试加载兼容的层 147 | try: 148 | self.model.load_state_dict(old_state_dict, strict=False) 149 | print("已加载先前的模型权重") 150 | except Exception as e: 151 | print(f"加载先前的模型权重时出错: {e}") 152 | print("使用新初始化的模型继续") 153 | 154 | def _update_graph(self, session_list): 155 | """更新图结构""" 156 | print("更新图结构...") 157 | start_time = time.time() 158 | 159 | # 提取所有边 160 | edges = [] 161 | for session in session_list: 162 | if len(session) > 1: 163 | for i in range(len(session) - 1): 164 | u = int(session[i]) 165 | v = int(session[i + 1]) 166 | edges.append((u, v)) 167 | 168 | print(f"当前时间窗口的边数量: {len(edges)}") 169 | if len(edges) == 0: 170 | print("警告: 当前窗口没有边,无法更新图") 171 | return False 172 | 173 | # 如果是第一次构建图 174 | if self.G is None: 175 | self.G = nx.Graph() 176 | self.G.add_edges_from(edges) 177 | 178 | # 获取所有唯一节点 179 | nodes = list(self.G.nodes()) 180 | self.node_map = {node: i for i, node in enumerate(nodes)} 181 | self.reverse_node_map = {i: node for i, node in enumerate(nodes)} 182 | 183 | # 重新映射节点ID 184 | self.G = nx.relabel_nodes(self.G, self.node_map) 185 | else: 186 | # 添加新的边到现有图 187 | new_nodes = set() 188 | relabeled_edges = [] 189 | 190 | for u, v in edges: 191 | # 处理新节点 192 | if u not in self.node_map: 193 | new_nodes.add(u) 194 | if v not in self.node_map: 195 | new_nodes.add(v) 196 | 197 | # 为新节点分配ID 198 | current_max_id = max(self.node_map.values()) if self.node_map else -1 199 | for i, node in enumerate(new_nodes): 200 | node_id = current_max_id + 1 + i 201 | self.node_map[node] = node_id 202 | self.reverse_node_map[node_id] = node 203 | 204 | # 将边转换为内部ID 205 | for u, v in edges: 206 | relabeled_edges.append((self.node_map[u], self.node_map[v])) 207 | 208 | # 添加到现有图 209 | self.G.add_edges_from(relabeled_edges) 210 | 211 | end_time = time.time() 212 | print(f"图更新完成,耗时: {end_time - start_time:.2f}秒") 213 | print(f"当前图节点数: {self.G.number_of_nodes()}, 边数: {self.G.number_of_edges()}") 214 | return True 215 | 216 | def _process_time_window(self, action_data, window_start, window_end): 217 | """处理一个时间窗口的数据""" 218 | print(f"\n处理时间窗口: {window_start} 到 {window_end}") 219 | 220 | # 过滤当前时间窗口的数据 221 | window_data = action_data[(action_data['action_time'] >= window_start) & 222 | (action_data['action_time'] < window_end)] 223 | 224 | if window_data.empty: 225 | print(f"时间窗口 {window_start} 到 {window_end} 没有数据") 226 | return False 227 | 228 | print(f"当前窗口数据量: {len(window_data)}") 229 | 230 | # 获取会话列表 231 | session_list = get_session(window_data, verbose=False) 232 | 233 | if not session_list or len(session_list) == 0: 234 | print("警告: 当前窗口没有有效会话") 235 | return False 236 | 237 | # 更新图结构 238 | if not self._update_graph(session_list): 239 | return False 240 | 241 | # 初始化模型 242 | self._initialize_model() 243 | 244 | # 训练模型 245 | print(f"使用当前窗口数据训练模型 ({window_start} 到 {window_end})...") 246 | 247 | # 不再为每个时间窗口创建输出目录,只记录时间标记 248 | self.current_window_time = window_start 249 | 250 | self.model = self.trainer.train( 251 | epochs=self.epochs, 252 | batch_size=self.batch_size, 253 | output_dir=self.output_dir, # 使用流处理的主输出目录 254 | plot_loss=False # 禁用损失曲线绘制 255 | ) 256 | 257 | # 递增窗口计数 258 | self.window_count += 1 259 | 260 | # 每save_interval个窗口保存一次模型和嵌入 261 | if self.window_count % self.save_interval == 0: 262 | time_str = self.current_window_time.strftime('%Y%m%d%H') 263 | self._save_checkpoint(f"checkpoint_{self.window_count}_{time_str}") 264 | # 获取并保存嵌入 265 | embed_dir = os.path.join(self.output_dir, f"embeddings_{self.window_count}_{time_str}") 266 | self._save_embeddings(embed_dir) 267 | 268 | return True 269 | 270 | def _save_checkpoint(self, checkpoint_name): 271 | """保存模型检查点""" 272 | if self.model is None: 273 | return 274 | 275 | checkpoint_dir = os.path.join(self.output_dir, 'checkpoints') 276 | os.makedirs(checkpoint_dir, exist_ok=True) 277 | checkpoint_path = os.path.join(checkpoint_dir, f"{checkpoint_name}.pt") 278 | 279 | torch.save({ 280 | 'model_state_dict': self.model.state_dict(), 281 | 'window_count': self.window_count, 282 | 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S') 283 | }, checkpoint_path) 284 | 285 | print(f"模型检查点已保存到 {checkpoint_path}") 286 | 287 | def _save_embeddings(self, output_dir): 288 | """保存节点嵌入""" 289 | print("保存节点嵌入...") 290 | 291 | # 创建嵌入目录 292 | embed_dir = os.path.join(output_dir, 'embedding') 293 | os.makedirs(embed_dir, exist_ok=True) 294 | 295 | # 确保模型处于评估模式 296 | self.model.eval() 297 | 298 | # 获取所有节点的嵌入 299 | with torch.no_grad(): 300 | embeddings = self.model.forward() 301 | 302 | # 转换为CPU上的numpy数组 303 | embeddings_np = embeddings.detach().cpu().numpy() 304 | 305 | # 保存嵌入 306 | save_embeddings(embeddings_np, self.reverse_node_map, embed_dir) 307 | 308 | # 不再进行可视化 309 | 310 | print(f"嵌入已保存到 {embed_dir}") 311 | 312 | def stream_process(self): 313 | """流式处理数据""" 314 | global stop_streaming 315 | 316 | # 加载用户行为数据 317 | action_file = os.path.join(self.args.data_path, 'action_head.csv') 318 | if not os.path.exists(action_file): 319 | print(f"错误: 用户行为数据文件 {action_file} 不存在!") 320 | return 321 | 322 | # 读取用户行为数据 323 | print("加载用户行为数据...") 324 | action_data = pd.read_csv(action_file) 325 | 326 | # 确保action_time列是datetime类型 327 | action_data['action_time'] = pd.to_datetime(action_data['action_time']) 328 | 329 | # 获取数据的时间范围 330 | min_time = action_data['action_time'].min() 331 | max_time = action_data['action_time'].max() 332 | 333 | print(f"数据时间范围: {min_time} 到 {max_time}") 334 | 335 | # 初始化当前时间窗口 336 | self.current_time = min_time 337 | self.next_window_time = min_time + timedelta(hours=self.time_interval) 338 | 339 | # 循环处理每个时间窗口的数据,直到数据处理完或接收到停止信号 340 | while self.current_time < max_time and not stop_streaming: 341 | if self._process_time_window(action_data, self.current_time, self.next_window_time): 342 | print(f"时间窗口 {self.current_time} 到 {self.next_window_time} 处理完成") 343 | 344 | # 更新时间窗口 345 | self.current_time = self.next_window_time 346 | self.next_window_time = self.current_time + timedelta(hours=self.time_interval) 347 | 348 | if stop_streaming: 349 | print("流处理被用户中断") 350 | else: 351 | print("所有数据已处理完毕") 352 | 353 | # 保存最终模型 354 | if self.model is not None: 355 | # 使用当前时间作为最终模型的时间标记 356 | time_str = datetime.now().strftime('%Y%m%d%H%M') 357 | final_output_dir = os.path.join(self.output_dir, f'final_{time_str}') 358 | os.makedirs(final_output_dir, exist_ok=True) 359 | 360 | # 保存模型权重 361 | checkpoint_dir = os.path.join(final_output_dir, 'checkpoints') 362 | os.makedirs(checkpoint_dir, exist_ok=True) 363 | torch.save({ 364 | 'model_state_dict': self.model.state_dict(), 365 | 'window_count': self.window_count, 366 | 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S') 367 | }, os.path.join(checkpoint_dir, 'model_final.pt')) 368 | 369 | # 保存最终嵌入 370 | embed_dir = os.path.join(final_output_dir, 'embeddings') 371 | self._save_embeddings(embed_dir) 372 | 373 | print(f"最终模型已保存到 {final_output_dir}") 374 | 375 | 376 | def main(): 377 | parser = argparse.ArgumentParser(description='流式EGES实现') 378 | 379 | # 数据参数 380 | parser.add_argument('--data_path', type=str, default='./data/', 381 | help='数据文件路径') 382 | parser.add_argument('--output_dir', type=str, default='./output/', 383 | help='输出目录') 384 | parser.add_argument('--time_interval', type=int, default=1, 385 | help='时间窗口间隔(小时)') 386 | parser.add_argument('--save_interval', type=int, default=100, 387 | help='每10个窗口保存一次') 388 | # 模型参数 389 | parser.add_argument('--embedding_dim', type=int, default=128, 390 | help='嵌入维度') 391 | parser.add_argument('--walk_length', type=int, default=10, 392 | help='随机游走长度') 393 | parser.add_argument('--context_size', type=int, default=5, 394 | help='上下文窗口大小') 395 | parser.add_argument('--walks_per_node', type=int, default=10, 396 | help='每个节点的游走次数') 397 | parser.add_argument('--p', type=float, default=1.0, 398 | help='返回参数') 399 | parser.add_argument('--q', type=float, default=1.0, 400 | help='进出参数') 401 | 402 | # 训练参数 403 | parser.add_argument('--lr', type=float, default=0.001, 404 | help='学习率') 405 | parser.add_argument('--batch_size', type=int, default=128, 406 | help='批次大小') 407 | parser.add_argument('--epochs', type=int, default=1, 408 | help='每个时间窗口的训练轮数') 409 | parser.add_argument('--cpu', action='store_true', 410 | help='是否使用CPU训练') 411 | 412 | # 不再需要可视化参数 413 | # parser.add_argument('--visualize', action='store_true', 414 | # help='是否可视化嵌入向量') 415 | 416 | args = parser.parse_args() 417 | 418 | # 打印配置信息 419 | print("\n=== 流式EGES训练配置 ===") 420 | print(f"数据路径: {args.data_path}") 421 | print(f"输出目录: {args.output_dir}") 422 | print(f"时间窗口间隔: {args.time_interval}小时") 423 | print(f"嵌入维度: {args.embedding_dim}") 424 | print(f"使用设备: {'CPU' if args.cpu else 'GPU'}") 425 | print(f"每个窗口训练轮数: {args.epochs}") 426 | print("====================\n") 427 | 428 | # 创建并启动流式处理 429 | streaming_eges = StreamingEGES(args) 430 | streaming_eges.stream_process() 431 | 432 | 433 | if __name__ == "__main__": 434 | main() -------------------------------------------------------------------------------- /gpu_process/run_EGES_multi_gpu.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import torch 4 | import time 5 | import argparse 6 | import os 7 | import sys 8 | import torch.distributed as dist 9 | import torch.multiprocessing as mp 10 | from tqdm import tqdm 11 | import pickle 12 | import networkx as nx 13 | from torch_geometric.utils import from_networkx 14 | import io 15 | import logging 16 | 17 | # 添加项目根目录到系统路径 18 | current_dir = os.path.dirname(os.path.abspath(__file__)) 19 | project_root = os.path.dirname(current_dir) 20 | sys.path.append(project_root) 21 | 22 | # 导入集成版EGES 23 | from gpu_process.EGES_module import EGES, EGESTrainer 24 | from utils import set_seed, setup, cleanup, get_free_port, write_embedding, visualize_embeddings 25 | from data_process import get_session 26 | 27 | 28 | # 设置日志 29 | def setup_logger(rank, log_dir=None): 30 | """ 31 | 设置日志 32 | 33 | 参数: 34 | rank: 进程排名 35 | log_dir: 日志目录 36 | """ 37 | logger = logging.getLogger(f"rank_{rank}") 38 | logger.setLevel(logging.INFO) 39 | 40 | # 创建控制台处理器 41 | console_handler = logging.StreamHandler() 42 | console_handler.setLevel(logging.INFO) 43 | 44 | # 创建格式化器 45 | formatter = logging.Formatter(f'[Rank {rank}] %(asctime)s - %(levelname)s - %(message)s') 46 | console_handler.setFormatter(formatter) 47 | 48 | # 添加处理器 49 | logger.addHandler(console_handler) 50 | 51 | # 如果指定了日志目录,也添加文件处理器 52 | if log_dir: 53 | os.makedirs(log_dir, exist_ok=True) 54 | file_handler = logging.FileHandler(os.path.join(log_dir, f"rank_{rank}.log")) 55 | file_handler.setLevel(logging.INFO) 56 | file_handler.setFormatter(formatter) 57 | logger.addHandler(file_handler) 58 | 59 | return logger 60 | 61 | 62 | def build_global_graph(session_list, logger): 63 | """ 64 | 构建全局图,确保所有进程使用相同的图结构 65 | 66 | 参数: 67 | session_list: 会话列表 68 | logger: 日志记录器 69 | 70 | 返回: 71 | G: NetworkX图 72 | node_maps: 节点映射元组 (node_map, reverse_node_map) 73 | """ 74 | logger.info("构建全局图...") 75 | start_time = time.time() 76 | 77 | # 提取所有边 - 使用集合去重 78 | edge_set = set() 79 | for session in session_list: 80 | if len(session) > 1: 81 | for i in range(len(session) - 1): 82 | u = int(session[i]) 83 | v = int(session[i + 1]) 84 | edge_set.add((u, v)) 85 | 86 | edges = list(edge_set) 87 | logger.info(f"提取的边数量: {len(edges)}") 88 | if len(edges) == 0: 89 | logger.warning("警告: 没有提取到边,无法构建图") 90 | return None, None 91 | 92 | # 创建NetworkX图 93 | logger.info("创建NetworkX图...") 94 | G = nx.Graph() 95 | G.add_edges_from(edges) 96 | 97 | # 获取所有唯一节点 98 | logger.info("创建节点映射...") 99 | nodes = list(G.nodes()) 100 | node_map = {node: i for i, node in enumerate(nodes)} 101 | reverse_node_map = {i: node for i, node in enumerate(nodes)} 102 | 103 | # 重新映射节点ID 104 | logger.info("重新映射节点ID...") 105 | G_relabeled = nx.relabel_nodes(G, node_map) 106 | 107 | end_time = time.time() 108 | logger.info(f"全局图构建完成,耗时: {end_time - start_time:.2f}秒") 109 | logger.info(f"全局图包含 {G.number_of_nodes()} 个节点和 {G.number_of_edges()} 条边") 110 | 111 | return G_relabeled, (node_map, reverse_node_map) 112 | 113 | 114 | def fast_serialize(obj): 115 | """ 116 | 快速序列化对象 117 | 118 | 参数: 119 | obj: 要序列化的对象 120 | 121 | 返回: 122 | bytes: 序列化后的字节 123 | """ 124 | buffer = io.BytesIO() 125 | torch.save(obj, buffer) 126 | return buffer.getvalue() 127 | 128 | 129 | def fast_deserialize(bytes_obj): 130 | """ 131 | 快速反序列化对象 132 | 133 | 参数: 134 | bytes_obj: 序列化后的字节 135 | 136 | 返回: 137 | obj: 反序列化后的对象 138 | """ 139 | buffer = io.BytesIO(bytes_obj) 140 | return torch.load(buffer) 141 | 142 | 143 | def train_model(rank, world_size, args): 144 | """ 145 | 在指定GPU上训练模型 146 | 147 | 参数: 148 | rank: 当前进程的排名 149 | world_size: 总进程数 150 | args: 命令行参数 151 | """ 152 | # 设置日志 153 | log_dir = os.path.join(args.output_dir, 'logs') 154 | logger = setup_logger(rank, log_dir) 155 | 156 | # 设置分布式训练环境 157 | logger.info(f"初始化分布式环境...") 158 | start_time = time.time() 159 | setup(rank, world_size, args.master_addr, args.master_port) 160 | logger.info(f"分布式环境初始化完成,耗时: {time.time() - start_time:.2f}秒") 161 | 162 | # 设置随机种子 - 每个进程使用不同的种子以增加多样性 163 | set_seed(args.seed + rank) 164 | 165 | # 设置设备 166 | device = torch.device(f"cuda:{rank}") 167 | torch.cuda.set_device(device) 168 | 169 | # 数据加载和预处理 - 只在主进程进行 170 | if rank == 0: 171 | logger.info(f"开始数据加载和预处理") 172 | 173 | # 读取数据 174 | start_time = time.time() 175 | action_data = pd.read_csv(os.path.join(args.data_path, 'action_head.csv')) 176 | logger.info(f"读取数据完成,耗时: {time.time() - start_time:.2f}秒") 177 | 178 | # 构建会话 179 | start_time = time.time() 180 | session_list = get_session(action_data) 181 | logger.info(f"构建会话完成,耗时: {time.time() - start_time:.2f}秒") 182 | 183 | # 读取SKUsideinfo 184 | start_time = time.time() 185 | sku_info = pd.read_csv(os.path.join(args.data_path, 'jdata_product.csv')) 186 | side_info = sku_info[['sku_id', 'cate', 'brand', 'shop_id']].values 187 | logger.info(f"读取SKUsideinfo完成,耗时: {time.time() - start_time:.2f}秒") 188 | 189 | # 构建全局图 - 确保所有进程使用相同的图结构 190 | G, node_maps = build_global_graph(session_list, logger) 191 | node_map, reverse_node_map = node_maps 192 | 193 | # 转换为PyG图 194 | logger.info("转换为PyG图...") 195 | start_time = time.time() 196 | pyg_data = from_networkx(G) 197 | logger.info(f"PyG图转换完成,耗时: {time.time() - start_time:.2f}秒") 198 | 199 | # 分割数据 - 为每个进程准备一部分数据 200 | if world_size > 1: 201 | logger.info("分割会话数据...") 202 | start_time = time.time() 203 | # 计算每个进程的会话数量 204 | sessions_per_process = len(session_list) // world_size 205 | session_splits = [] 206 | 207 | # 分割会话列表 208 | for i in range(world_size): 209 | start_idx = i * sessions_per_process 210 | end_idx = (i + 1) * sessions_per_process if i < world_size - 1 else len(session_list) 211 | session_splits.append(session_list[start_idx:end_idx]) 212 | 213 | # 更新主进程的会话列表 214 | session_list = session_splits[0] 215 | logger.info(f"会话数据分割完成,耗时: {time.time() - start_time:.2f}秒") 216 | logger.info(f"每个进程的会话数量: {[len(split) for split in session_splits]}") 217 | 218 | # 创建训练器 219 | logger.info("创建训练器...") 220 | start_time = time.time() 221 | trainer = EGESTrainer( 222 | session_list=session_list, 223 | side_info=side_info, 224 | embedding_dim=args.embedding_dim, 225 | walk_length=args.walk_length, 226 | context_size=args.context_size, 227 | walks_per_node=args.walks_per_node, 228 | p=args.p, 229 | q=args.q, 230 | lr=args.lr, 231 | device=device, 232 | prefetch_factor=args.prefetch_factor, 233 | # 使用预构建的图和节点映射 234 | G=G, 235 | node_map=node_map, 236 | reverse_node_map=reverse_node_map, 237 | pyg_data=pyg_data 238 | ) 239 | logger.info(f"训练器创建完成,耗时: {time.time() - start_time:.2f}秒") 240 | 241 | # 准备要广播的数据 242 | if world_size > 1: 243 | logger.info("准备广播数据...") 244 | start_time = time.time() 245 | 246 | # 准备图数据 247 | graph_data = { 248 | 'edge_index': trainer.pyg_data.edge_index.cpu(), 249 | 'num_nodes': trainer.num_nodes, 250 | 'node_map': trainer.node_map, 251 | 'reverse_node_map': trainer.reverse_node_map, 252 | 'feature_lens': trainer.feature_lens 253 | } 254 | 255 | # 如果有侧信息,也保存 256 | if trainer.side_info is not None: 257 | graph_data['side_info_dict'] = {k: v.cpu() for k, v in trainer.side_info_dict.items()} 258 | 259 | # 保存模型状态 260 | model_state = trainer.model.state_dict() 261 | 262 | # 序列化数据 - 使用快速序列化 263 | graph_data_bytes = fast_serialize(graph_data) 264 | model_state_bytes = fast_serialize(model_state) 265 | side_info_bytes = fast_serialize(side_info) if side_info is not None else None 266 | session_splits_bytes = fast_serialize(session_splits) 267 | 268 | # 准备广播列表 269 | broadcast_data = [ 270 | graph_data_bytes, 271 | model_state_bytes, 272 | side_info_bytes, 273 | session_splits_bytes 274 | ] 275 | 276 | logger.info(f"广播数据准备完成,耗时: {time.time() - start_time:.2f}秒") 277 | logger.info(f"广播数据大小: {sum(len(x) if x else 0 for x in broadcast_data) / (1024 * 1024):.2f} MB") 278 | else: 279 | broadcast_data = None 280 | session_splits = None 281 | else: 282 | # 非主进程初始化为None 283 | trainer = None 284 | session_list = None 285 | broadcast_data = [None, None, None, None] 286 | side_info = None 287 | 288 | # 广播数据到所有进程 289 | if world_size > 1: 290 | logger.info("开始广播数据...") 291 | start_time = time.time() 292 | 293 | # 广播数据大小 294 | if rank == 0: 295 | sizes = [len(x) if x else 0 for x in broadcast_data] 296 | else: 297 | sizes = [0, 0, 0, 0] 298 | 299 | # 使用all_gather广播大小 300 | size_tensor = torch.tensor(sizes, dtype=torch.long, device=device) 301 | dist.broadcast(size_tensor, 0) 302 | sizes = size_tensor.tolist() 303 | 304 | # 根据大小创建接收缓冲区 305 | if rank != 0: 306 | broadcast_data = [ 307 | torch.zeros(sizes[0], dtype=torch.uint8, device=device) if sizes[0] > 0 else None, 308 | torch.zeros(sizes[1], dtype=torch.uint8, device=device) if sizes[1] > 0 else None, 309 | torch.zeros(sizes[2], dtype=torch.uint8, device=device) if sizes[2] > 0 else None, 310 | torch.zeros(sizes[3], dtype=torch.uint8, device=device) if sizes[3] > 0 else None 311 | ] 312 | else: 313 | # 转换为张量 314 | broadcast_data = [ 315 | torch.frombuffer(x, dtype=torch.uint8).to(device) if x else None 316 | for x in broadcast_data 317 | ] 318 | 319 | # 广播数据 320 | for i in range(4): 321 | if sizes[i] > 0: 322 | dist.broadcast(broadcast_data[i], 0) 323 | 324 | # 反序列化数据 325 | if rank != 0: 326 | # 转换为字节 327 | broadcast_data = [ 328 | x.cpu().numpy().tobytes() if x is not None else None 329 | for x in broadcast_data 330 | ] 331 | 332 | # 反序列化 333 | graph_data = fast_deserialize(broadcast_data[0]) 334 | model_state = fast_deserialize(broadcast_data[1]) 335 | side_info = fast_deserialize(broadcast_data[2]) if broadcast_data[2] else None 336 | session_splits = fast_deserialize(broadcast_data[3]) 337 | 338 | # 使用分配给当前进程的会话列表 339 | session_list = session_splits[rank] 340 | 341 | # 创建训练器 342 | logger.info("创建训练器...") 343 | start_time_trainer = time.time() 344 | 345 | # 从图数据中提取图和节点映射 346 | G = nx.Graph() 347 | edge_index = graph_data['edge_index'] 348 | edges = [(edge_index[0, i].item(), edge_index[1, i].item()) for i in range(edge_index.size(1))] 349 | G.add_edges_from(edges) 350 | 351 | # 创建训练器 - 使用预构建的图和节点映射 352 | trainer = EGESTrainer( 353 | session_list=session_list, 354 | side_info=side_info, 355 | embedding_dim=args.embedding_dim, 356 | walk_length=args.walk_length, 357 | context_size=args.context_size, 358 | walks_per_node=args.walks_per_node, 359 | p=args.p, 360 | q=args.q, 361 | lr=args.lr, 362 | device=device, 363 | prefetch_factor=args.prefetch_factor, 364 | # 使用预构建的图和节点映射,确保所有进程使用相同的图结构 365 | G=G, 366 | node_map=graph_data['node_map'], 367 | reverse_node_map=graph_data['reverse_node_map'], 368 | pyg_data=None # 稍后会更新 369 | ) 370 | 371 | # 更新图结构和节点映射 372 | trainer.num_nodes = graph_data['num_nodes'] 373 | trainer.node_map = graph_data['node_map'] 374 | trainer.reverse_node_map = graph_data['reverse_node_map'] 375 | trainer.feature_lens = graph_data['feature_lens'] 376 | 377 | # 更新PyG数据 378 | trainer.pyg_data = from_networkx(G).to(device) 379 | trainer.pyg_data.edge_index = edge_index.to(device) 380 | 381 | # 重新创建模型以确保结构一致 382 | trainer.model = EGES( 383 | edge_index=trainer.pyg_data.edge_index, 384 | num_nodes=trainer.num_nodes, 385 | feature_dim=len(trainer.feature_lens), 386 | feature_lens=trainer.feature_lens, 387 | embedding_dim=args.embedding_dim, 388 | walk_length=args.walk_length, 389 | context_size=args.context_size, 390 | walks_per_node=args.walks_per_node, 391 | p=args.p, 392 | q=args.q, 393 | lr=args.lr, 394 | device=device 395 | ) 396 | 397 | # 初始化优化器 398 | trainer.model.init_optimizer(args.lr) 399 | 400 | # 如果有侧信息,也更新 401 | if 'side_info_dict' in graph_data: 402 | trainer.side_info_dict = {k: v.to(device) for k, v in graph_data['side_info_dict'].items()} 403 | 404 | # 加载模型状态 405 | trainer.model.load_state_dict(model_state) 406 | 407 | logger.info(f"训练器创建完成,耗时: {time.time() - start_time_trainer:.2f}秒") 408 | 409 | logger.info(f"数据广播完成,耗时: {time.time() - start_time:.2f}秒") 410 | 411 | # 同步所有进程 412 | if world_size > 1: 413 | dist.barrier() 414 | logger.info("所有进程同步完成,准备开始训练") 415 | 416 | # 训练模型 417 | if rank == 0: 418 | logger.info("开始训练...") 419 | 420 | # 获取数据加载器 421 | loader = trainer.model.get_loader(batch_size=args.batch_size, shuffle=True) 422 | 423 | # 记录每个epoch的损失 424 | epoch_losses = [] 425 | 426 | for epoch in range(args.epochs): 427 | epoch_start_time = time.time() 428 | trainer.model.train() 429 | total_loss = 0 430 | batch_count = 0 431 | 432 | # 使用tqdm显示进度条(仅在主进程) 433 | if rank == 0: 434 | pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{args.epochs}") 435 | else: 436 | pbar = loader 437 | 438 | # 每个进程独立训练,不同步梯度 439 | for pos_rw, neg_rw in pbar: 440 | # 训练一步 441 | loss = trainer.model.train_step(pos_rw, neg_rw) 442 | 443 | # 累积损失和批次计数 444 | total_loss += loss 445 | batch_count += 1 446 | 447 | # 更新进度条(仅在主进程) 448 | if rank == 0 and isinstance(pbar, tqdm): 449 | pbar.set_postfix({'loss': f'{loss:.4f}'}) 450 | 451 | # 计算平均损失 452 | avg_loss = total_loss / batch_count 453 | 454 | # 同步损失 455 | if world_size > 1: 456 | avg_loss_tensor = torch.tensor([avg_loss], device=device) 457 | dist.all_reduce(avg_loss_tensor, op=dist.ReduceOp.SUM) 458 | avg_loss = avg_loss_tensor.item() / world_size 459 | 460 | # 记录损失 461 | epoch_losses.append(avg_loss) 462 | 463 | # 更新学习率 464 | trainer.model.update_lr(avg_loss) 465 | 466 | # 打印每轮的平均损失(仅在主进程) 467 | if rank == 0: 468 | logger.info(f"Epoch {epoch+1}/{args.epochs}, Average Loss: {avg_loss:.4f}") 469 | 470 | # 同步模型参数 - 使用模型平均 471 | if world_size > 1: 472 | logger.info(f"开始同步模型参数...") 473 | sync_start_time = time.time() 474 | 475 | # 收集所有进程的模型参数 476 | model_state = trainer.model.state_dict() 477 | 478 | # 对于每个参数,进行all_reduce操作 479 | for key in model_state: 480 | if 'embedding' in key or 'attention' in key: # 只平均嵌入和注意力参数 481 | # 确保参数在GPU上 482 | param = model_state[key].to(device).float() 483 | 484 | # 使用all_reduce操作平均参数 485 | dist.all_reduce(param, op=dist.ReduceOp.SUM) 486 | param.div_(world_size) 487 | 488 | # 更新模型状态 489 | model_state[key] = param 490 | 491 | # 将平均后的参数加载回模型 492 | trainer.model.load_state_dict(model_state) 493 | 494 | logger.info(f"模型参数同步完成,耗时: {time.time() - sync_start_time:.2f}秒") 495 | 496 | # 清理缓存并等待所有GPU操作完成 497 | if torch.cuda.is_available(): 498 | torch.cuda.synchronize(device) 499 | torch.cuda.empty_cache() 500 | 501 | logger.info(f"Epoch {epoch+1} 完成,总耗时: {time.time() - epoch_start_time:.2f}秒") 502 | 503 | # 保存模型和嵌入(仅在主进程) 504 | if rank == 0: 505 | # 保存模型 506 | checkpoint_dir = os.path.join(args.output_dir, 'checkpoints') 507 | os.makedirs(checkpoint_dir, exist_ok=True) 508 | trainer.save_model(os.path.join(checkpoint_dir, 'model_final.pt')) 509 | 510 | # 绘制损失曲线 511 | from utils import plot_loss_curve 512 | plot_loss_curve(epoch_losses, output_dir=args.output_dir) 513 | 514 | # 获取嵌入 515 | logger.info("保存嵌入...") 516 | embedding_dir = os.path.join(args.output_dir, 'embedding') 517 | os.makedirs(embedding_dir, exist_ok=True) 518 | 519 | # 获取节点嵌入 - 使用较大的批次大小 520 | node_embeddings = trainer.get_embeddings() 521 | 522 | # 使用numpy的高效操作处理嵌入 523 | sorted_node_ids = sorted(node_embeddings.keys()) 524 | embedding_array = np.stack([node_embeddings[node_id] for node_id in sorted_node_ids]) 525 | 526 | # 并行保存不同格式的嵌入 527 | np.save(os.path.join(embedding_dir, "node_embeddings.npy"), embedding_array) 528 | write_embedding(embedding_array, os.path.join(embedding_dir, "node_embeddings.txt")) 529 | 530 | # 如果需要可视化,则调用可视化函数 531 | if args.visualize: 532 | visualize_embeddings(node_embeddings, args.data_path, args.output_dir) 533 | 534 | logger.info("训练完成!") 535 | 536 | # 清理分布式环境 537 | cleanup() 538 | 539 | 540 | def main(): 541 | parser = argparse.ArgumentParser(description='多GPU版集成EGES实现') 542 | parser.add_argument('--data_path', type=str, default='./data/', help='数据文件路径') 543 | parser.add_argument('--output_dir', type=str, default='./output/integrated_multi_gpu/', help='输出目录') 544 | parser.add_argument('--p', type=float, default=0.25, help='返回参数') 545 | parser.add_argument('--q', type=float, default=2, help='进出参数') 546 | parser.add_argument('--walk_length', type=int, default=10, help='随机游走长度') 547 | parser.add_argument('--context_size', type=int, default=5, help='上下文窗口大小') 548 | parser.add_argument('--walks_per_node', type=int, default=10, help='每个节点的游走次数') 549 | parser.add_argument('--embedding_dim', type=int, default=128, help='嵌入维度') 550 | parser.add_argument('--batch_size', type=int, default=128, help='批次大小') 551 | parser.add_argument('--epochs', type=int, default=5, help='训练轮数') 552 | parser.add_argument('--lr', type=float, default=0.001, help='学习率') 553 | parser.add_argument('--seed', type=int, default=42, help='随机种子') 554 | parser.add_argument('--gpus', type=int, default=-1, help='使用的GPU数量,-1表示使用所有可用GPU') 555 | parser.add_argument('--master_addr', type=str, default='localhost', help='主节点地址') 556 | parser.add_argument('--master_port', type=str, default=None, help='主节点端口') 557 | parser.add_argument('--visualize', action='store_true', help='是否可视化嵌入向量') 558 | parser.add_argument('--prefetch_factor', type=int, default=2, help='数据预取因子') 559 | args = parser.parse_args() 560 | 561 | # 设置CUDA后端为NCCL 562 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, range(torch.cuda.device_count()))) 563 | torch.backends.cudnn.benchmark = True # 启用cuDNN自动调优 564 | 565 | # 设置随机种子 566 | set_seed(args.seed) 567 | 568 | # 确定可用的GPU数量 569 | num_gpus = torch.cuda.device_count() 570 | if num_gpus == 0: 571 | print("没有可用的GPU,将使用CPU训练") 572 | args.gpus = 0 573 | elif args.gpus == -1 or args.gpus > num_gpus: 574 | args.gpus = num_gpus 575 | print(f"将使用所有 {num_gpus} 个可用的GPU") 576 | else: 577 | print(f"将使用 {args.gpus} 个GPU") 578 | 579 | # 创建输出目录 580 | os.makedirs(args.output_dir, exist_ok=True) 581 | 582 | # 如果没有指定端口,获取一个可用端口 583 | if args.master_port is None: 584 | args.master_port = get_free_port() 585 | 586 | # 单GPU或CPU训练 587 | if args.gpus <= 1: 588 | train_model(0, 1, args) 589 | else: 590 | # 多GPU训练 591 | mp.spawn( 592 | train_model, 593 | args=(args.gpus, args), 594 | nprocs=args.gpus, 595 | join=True 596 | ) 597 | 598 | 599 | if __name__ == "__main__": 600 | main() -------------------------------------------------------------------------------- /gpu_process/EGES_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.distributed as dist 5 | import numpy as np 6 | import networkx as nx 7 | from torch_geometric.nn import Node2Vec 8 | from torch_geometric.utils import from_networkx 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | import time 12 | import os 13 | import sys 14 | 15 | # 添加项目根目录到系统路径 16 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 17 | 18 | from utils import plot_loss_curve, visualize_embeddings 19 | 20 | 21 | class EGES(nn.Module): 22 | """ 23 | 集成版EGES模型,将游走采样和模型训练整合在一起 24 | """ 25 | def __init__(self, edge_index, num_nodes, feature_dim, feature_lens, 26 | embedding_dim=128, walk_length=10, context_size=5, 27 | walks_per_node=10, p=1.0, q=1.0, lr=0.001, device=None): 28 | """ 29 | 初始化EGES模型 30 | 31 | 参数: 32 | edge_index: 图的边索引 33 | num_nodes: 节点数量 34 | feature_dim: 特征维度 35 | feature_lens: 每个特征的长度列表 36 | embedding_dim: 嵌入维度 37 | walk_length: 随机游走长度 38 | context_size: 上下文窗口大小 39 | walks_per_node: 每个节点的游走次数 40 | p: 返回参数 41 | q: 进出参数 42 | lr: 学习率 43 | device: 计算设备 44 | """ 45 | super(EGES, self).__init__() 46 | 47 | # 设置设备 48 | if device is None: 49 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 50 | else: 51 | self.device = device 52 | 53 | # 模型参数 54 | self.num_nodes = num_nodes 55 | self.feature_dim = feature_dim 56 | self.feature_lens = feature_lens 57 | self.embedding_dim = embedding_dim 58 | self.lr = lr 59 | 60 | # 确保边索引在正确的设备上 61 | if edge_index.device != self.device: 62 | edge_index = edge_index.to(self.device) 63 | 64 | # 初始化Node2Vec模型 65 | self.node2vec = Node2Vec( 66 | edge_index, 67 | embedding_dim=embedding_dim, 68 | walk_length=walk_length, 69 | context_size=context_size, 70 | walks_per_node=walks_per_node, 71 | p=p, 72 | q=q, 73 | sparse=True, 74 | num_negative_samples=5 # 增加负样本数量 75 | ).to(self.device) 76 | 77 | # 初始化特征嵌入层 78 | self.embedding_layers = nn.ModuleList() 79 | for i in range(self.feature_dim): 80 | embedding_layer = nn.Embedding(self.feature_lens[i], self.embedding_dim) 81 | nn.init.xavier_uniform_(embedding_layer.weight) 82 | self.embedding_layers.append(embedding_layer) 83 | 84 | # 注意力网络 85 | self.attention_network = nn.Sequential( 86 | nn.Linear(self.feature_dim, self.feature_dim), 87 | nn.ReLU(), 88 | nn.Linear(self.feature_dim, self.feature_dim), 89 | nn.Softmax(dim=1) 90 | ) 91 | 92 | # 将所有模块移动到指定设备 93 | self.to(self.device) 94 | 95 | # 初始化优化器 96 | self.optimizer = None 97 | self.scheduler = None 98 | 99 | def init_optimizer(self, lr=None): 100 | """ 101 | 初始化优化器 102 | 103 | 参数: 104 | lr: 学习率,如果为None则使用默认值 105 | """ 106 | if lr is not None: 107 | self.lr = lr 108 | 109 | # 使用SparseAdam优化器,它支持稀疏梯度 110 | self.optimizer = torch.optim.SparseAdam( 111 | self.parameters(), 112 | lr=self.lr, 113 | betas=(0.9, 0.999), 114 | eps=1e-8 115 | ) 116 | 117 | # 学习率调度器 118 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 119 | self.optimizer, 120 | mode='min', 121 | factor=0.5, 122 | patience=2, 123 | verbose=True 124 | ) 125 | 126 | def forward(self, features=None): 127 | """ 128 | 前向传播 129 | 130 | 参数: 131 | features: 节点特征,如果为None则只返回Node2Vec嵌入 132 | 133 | 返回: 134 | embeddings: 节点嵌入 135 | """ 136 | # 获取Node2Vec嵌入 137 | node2vec_emb = self.node2vec() 138 | 139 | # 如果没有提供特征,直接返回Node2Vec嵌入 140 | if features is None: 141 | return node2vec_emb 142 | 143 | # 确保特征在正确的设备上 144 | features = [f.to(self.device, non_blocking=True) for f in features] 145 | 146 | # 处理特征 147 | batch_size = features[0].size(0) 148 | 149 | # 对每个特征进行嵌入查找 150 | embed_list = [] 151 | for i in range(self.feature_dim): 152 | # 确保索引在有效范围内 153 | valid_indices = torch.clamp(features[i], 0, self.feature_lens[i] - 1) 154 | # 使用非阻塞传输 155 | valid_indices = valid_indices.to(self.device, non_blocking=True) 156 | embed_list.append(self.embedding_layers[i](valid_indices)) 157 | 158 | # 堆叠嵌入 [batch_size, embedding_dim, num_feat] 159 | stacked_embeds = torch.stack(embed_list, dim=2) 160 | 161 | # 计算注意力权重 162 | feature_ids = torch.arange(self.feature_dim, device=self.device).expand(batch_size, self.feature_dim) 163 | attention_weights = self.attention_network(feature_ids.float()) 164 | 165 | # 应用注意力权重 [batch_size, embedding_dim, 1] 166 | attention_weights = attention_weights.unsqueeze(1) 167 | weighted_embeds = torch.matmul(stacked_embeds, attention_weights.transpose(1, 2)) 168 | 169 | # 最终嵌入 [batch_size, embedding_dim] 170 | feature_emb = weighted_embeds.squeeze(2) 171 | 172 | # 获取对应节点的Node2Vec嵌入 173 | node_indices = torch.clamp(features[0], 0, self.num_nodes - 1) # 确保节点索引在有效范围内 174 | node_emb = node2vec_emb[node_indices] 175 | 176 | # 融合嵌入 177 | final_emb = node_emb + feature_emb 178 | 179 | return final_emb 180 | 181 | def get_loader(self, batch_size=128, shuffle=True): 182 | """ 183 | 获取数据加载器 184 | 185 | 参数: 186 | batch_size: 批次大小 187 | shuffle: 是否打乱数据 188 | 189 | 返回: 190 | loader: 数据加载器 191 | """ 192 | return self.node2vec.loader(batch_size=batch_size, shuffle=shuffle, pin_memory=True) 193 | 194 | 195 | def train_step(self, pos_rw, neg_rw=None, features=None, update_params=True): 196 | """ 197 | 训练一步 198 | 199 | 参数: 200 | pos_rw: 正样本随机游走 201 | neg_rw: 负样本随机游走,如果为None则自动生成 202 | features: 节点特征,如果为None则不使用特征 203 | update_params: 是否更新参数,默认为True 204 | 205 | 返回: 206 | loss: 损失值 207 | """ 208 | self.train() 209 | self.optimizer.zero_grad() 210 | 211 | # 确保数据在正确的设备上,并使用non_blocking=True 212 | pos_rw = pos_rw.to(self.device, non_blocking=True) 213 | 214 | # 如果没有提供负样本,在GPU上直接生成 215 | if neg_rw is None: 216 | neg_rw = torch.randint(0, self.num_nodes, pos_rw.size(), 217 | dtype=torch.long, device=self.device) 218 | else: 219 | neg_rw = neg_rw.to(self.device, non_blocking=True) 220 | 221 | # 计算Node2Vec损失 222 | loss = self.node2vec.loss(pos_rw, neg_rw) 223 | 224 | # 反向传播 225 | loss.backward() 226 | 227 | # 手动裁剪稀疏梯度 228 | for param in self.parameters(): 229 | if param.grad is not None: 230 | if param.grad.is_sparse: 231 | # 对于稀疏梯度,我们只能逐个处理非零元素 232 | grad_values = param.grad._values() 233 | grad_norm = torch.norm(grad_values) 234 | if grad_norm > 5.0: # max_norm 235 | grad_values.mul_(5.0 / grad_norm) 236 | else: 237 | # 对于稠密梯度,使用常规裁剪 238 | torch.nn.utils.clip_grad_norm_([param], max_norm=5.0) 239 | 240 | # 更新参数 241 | if update_params: 242 | self.optimizer.step() 243 | self.optimizer.zero_grad() 244 | 245 | return loss.item() 246 | 247 | def update_lr(self, val_loss): 248 | """ 249 | 更新学习率 250 | 251 | 参数: 252 | val_loss: 验证损失 253 | """ 254 | if self.scheduler is not None: 255 | self.scheduler.step(val_loss) 256 | 257 | def get_embeddings(self, features=None, indices=None): 258 | """ 259 | 获取节点嵌入 260 | 261 | 参数: 262 | features: 节点特征,如果为None则只返回Node2Vec嵌入 263 | indices: 要获取嵌入的节点索引,如果为None则获取所有节点 264 | 265 | 返回: 266 | embeddings: 节点嵌入 267 | """ 268 | with torch.no_grad(): 269 | # 如果没有提供特征,直接返回Node2Vec嵌入 270 | if features is None: 271 | embeddings = self.node2vec() 272 | if indices is not None: 273 | # 确保索引在有效范围内 274 | valid_indices = torch.clamp(indices, 0, self.num_nodes - 1) 275 | return embeddings[valid_indices] 276 | return embeddings 277 | 278 | # 确保特征在正确的设备上 279 | features = [f.to(self.device, non_blocking=True) for f in features] 280 | 281 | # 处理特征 282 | batch_size = features[0].size(0) 283 | 284 | # 对每个特征进行嵌入查找 285 | embed_list = [] 286 | for i in range(self.feature_dim): 287 | # 确保索引在有效范围内 288 | valid_indices = torch.clamp(features[i], 0, self.feature_lens[i] - 1) 289 | # 使用非阻塞传输 290 | valid_indices = valid_indices.to(self.device, non_blocking=True) 291 | embed_list.append(self.embedding_layers[i](valid_indices)) 292 | 293 | # 堆叠嵌入 [batch_size, embedding_dim, num_feat] 294 | stacked_embeds = torch.stack(embed_list, dim=2) 295 | 296 | # 计算注意力权重 297 | feature_ids = torch.arange(self.feature_dim, device=self.device).expand(batch_size, self.feature_dim) 298 | attention_weights = self.attention_network(feature_ids.float()) 299 | 300 | # 应用注意力权重 [batch_size, embedding_dim, 1] 301 | attention_weights = attention_weights.unsqueeze(1) 302 | weighted_embeds = torch.matmul(stacked_embeds, attention_weights.transpose(1, 2)) 303 | 304 | # 最终嵌入 [batch_size, embedding_dim] 305 | feature_emb = weighted_embeds.squeeze(2) 306 | 307 | # 获取对应节点的Node2Vec嵌入 308 | if indices is not None: 309 | # 确保索引在有效范围内 310 | valid_indices = torch.clamp(indices, 0, self.num_nodes - 1) 311 | node_emb = self.node2vec()[valid_indices] 312 | else: 313 | node_emb = self.node2vec()[torch.clamp(features[0], 0, self.num_nodes - 1)] 314 | 315 | # 融合嵌入 316 | final_emb = node_emb + feature_emb 317 | 318 | return final_emb 319 | 320 | 321 | class EGESTrainer: 322 | """ 323 | EGES模型训练器 324 | """ 325 | def __init__(self, session_list, side_info=None, embedding_dim=128, 326 | walk_length=10, context_size=5, walks_per_node=10, 327 | p=1.0, q=1.0, lr=0.001, device=None, prefetch_factor=2, 328 | G=None, node_map=None, reverse_node_map=None, pyg_data=None): 329 | """ 330 | 初始化训练器 331 | 332 | 参数: 333 | session_list: 会话列表 334 | side_info: 侧信息 335 | embedding_dim: 嵌入维度 336 | walk_length: 随机游走长度 337 | context_size: 上下文窗口大小 338 | walks_per_node: 每个节点的游走次数 339 | p: 返回参数 340 | q: 进出参数 341 | lr: 学习率 342 | device: 计算设备 343 | prefetch_factor: 预取因子,控制数据预加载的批次数 344 | G: 预构建的图 345 | node_map: 预构建的节点映射 346 | reverse_node_map: 预构建的反向节点映射 347 | pyg_data: 预构建的PyG数据 348 | """ 349 | # 设置设备 350 | if device is None: 351 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 352 | else: 353 | self.device = device 354 | 355 | self.prefetch_factor = prefetch_factor 356 | 357 | # 使用预构建的图和节点映射,或者构建新的图 358 | if G is not None and node_map is not None and reverse_node_map is not None: 359 | self.G = G 360 | self.node_map = node_map 361 | self.reverse_node_map = reverse_node_map 362 | self.num_nodes = self.G.number_of_nodes() 363 | 364 | # 使用预构建的PyG数据或者转换为PyG图 365 | if pyg_data is not None: 366 | self.pyg_data = pyg_data.to(self.device) 367 | else: 368 | self.pyg_data = from_networkx(self.G).to(self.device) 369 | else: 370 | # 构建图 371 | self.G, self.node_maps = self._build_graph(session_list) 372 | self.node_map, self.reverse_node_map = self.node_maps 373 | 374 | # 设置节点数量 375 | self.num_nodes = self.G.number_of_nodes() 376 | 377 | # 转换为PyG图,并直接放在GPU上 378 | self.pyg_data = from_networkx(self.G).to(self.device) 379 | 380 | # 处理侧信息 381 | self.side_info = side_info 382 | self.side_info_dict = {} 383 | 384 | if side_info is not None: 385 | # 创建特征长度列表 386 | self.feature_lens = [] 387 | for i in range(side_info.shape[1]): 388 | tmp_len = len(set(side_info[:, i])) + 1 389 | self.feature_lens.append(tmp_len) 390 | 391 | # 创建侧信息字典并预先将数据转移到GPU 392 | side_info_tensor = torch.tensor(side_info, dtype=torch.long, device=self.device) 393 | for i in range(len(side_info)): 394 | sku_id = side_info[i][0] 395 | if sku_id in self.node_map: 396 | self.side_info_dict[self.node_map[sku_id]] = side_info_tensor[i] 397 | else: 398 | self.feature_lens = [self.num_nodes + 1] 399 | 400 | # 创建模型 401 | self.model = EGES( 402 | edge_index=self.pyg_data.edge_index, 403 | num_nodes=self.num_nodes, 404 | feature_dim=len(self.feature_lens), 405 | feature_lens=self.feature_lens, 406 | embedding_dim=embedding_dim, 407 | walk_length=walk_length, 408 | context_size=context_size, 409 | walks_per_node=walks_per_node, 410 | p=p, 411 | q=q, 412 | lr=lr, 413 | device=self.device 414 | ) 415 | 416 | # 初始化优化器 417 | self.model.init_optimizer(lr) 418 | 419 | # 创建数据缓冲区 420 | self.data_buffer = [] 421 | 422 | def _build_graph(self, session_list): 423 | """ 424 | 构建图 425 | 426 | 参数: 427 | session_list: 会话列表 428 | 429 | 返回: 430 | G: NetworkX图 431 | node_maps: 节点映射元组 (node_map, reverse_node_map) 432 | """ 433 | print("构建图...") 434 | start_time = time.time() 435 | 436 | # 提取所有边 437 | edges = [] 438 | for session in session_list: 439 | if len(session) > 1: 440 | for i in range(len(session) - 1): 441 | u = int(session[i]) 442 | v = int(session[i + 1]) 443 | edges.append((u, v)) 444 | 445 | print(f"提取的边数量: {len(edges)}") 446 | if len(edges) == 0: 447 | print("警告: 没有提取到边,无法构建图") 448 | return None, None 449 | 450 | # 创建NetworkX图 451 | G = nx.Graph() 452 | G.add_edges_from(edges) 453 | 454 | # 获取所有唯一节点 455 | nodes = list(G.nodes()) 456 | node_map = {node: i for i, node in enumerate(nodes)} 457 | reverse_node_map = {i: node for i, node in enumerate(nodes)} 458 | 459 | # 重新映射节点ID 460 | G_relabeled = nx.relabel_nodes(G, node_map) 461 | 462 | end_time = time.time() 463 | print(f"图构建完成,耗时: {end_time - start_time:.2f}秒") 464 | print(f"图包含 {G.number_of_nodes()} 个节点和 {G.number_of_edges()} 条边") 465 | 466 | return G_relabeled, (node_map, reverse_node_map) 467 | 468 | def _prepare_features(self, nodes): 469 | """ 470 | 准备特征,优化版本直接在GPU上处理 471 | 472 | 参数: 473 | nodes: 原始节点ID列表 474 | 475 | 返回: 476 | features: 特征列表 477 | """ 478 | if self.side_info is None: 479 | # 将原始节点ID转换为内部索引 480 | node_indices = [self.node_map.get(node, 0) for node in nodes] 481 | nodes_tensor = torch.tensor(node_indices, dtype=torch.long, device=self.device) 482 | return [torch.clamp(nodes_tensor, 0, self.feature_lens[0] - 1)] 483 | 484 | # 将原始节点ID转换为内部索引 485 | node_indices = [self.node_map.get(node, 0) for node in nodes] 486 | nodes_tensor = torch.tensor(node_indices, dtype=torch.long, device=self.device) 487 | features = [torch.clamp(nodes_tensor, 0, self.feature_lens[0] - 1)] 488 | 489 | # 使用预计算的映射矩阵 490 | if not hasattr(self, '_feature_mapping'): 491 | # 第一次调用时创建映射矩阵 492 | self._feature_mapping = [] 493 | for i in range(1, len(self.feature_lens)): 494 | mapping = torch.zeros(self.num_nodes, dtype=torch.long, device=self.device) 495 | for node_id, info in self.side_info_dict.items(): 496 | if node_id < self.num_nodes: # 确保索引在有效范围内 497 | mapping[node_id] = info[i] 498 | self._feature_mapping.append(mapping) 499 | 500 | # 使用预计算的映射矩阵快速获取特征 501 | for i, mapping in enumerate(self._feature_mapping): 502 | # 使用高效的索引操作,同时确保索引在有效范围内 503 | valid_nodes = torch.clamp(nodes_tensor, 0, self.num_nodes - 1) 504 | feature_col = mapping[valid_nodes] 505 | # 确保特征索引在有效范围内 506 | feature_col = torch.clamp(feature_col, 0, self.feature_lens[i + 1] - 1) 507 | features.append(feature_col) 508 | 509 | return features 510 | 511 | def _prefetch_data(self, loader, num_batches): 512 | """ 513 | 预取数据到缓冲区 514 | 515 | 参数: 516 | loader: 数据加载器 517 | num_batches: 预取的批次数 518 | """ 519 | try: 520 | for _ in range(num_batches): 521 | batch = next(loader) 522 | self.data_buffer.append(batch) 523 | except StopIteration: 524 | pass 525 | 526 | def train(self, epochs=10, batch_size=128, output_dir='./output/integrated', plot_loss=True): 527 | """ 528 | 训练模型 529 | 530 | 参数: 531 | epochs: 训练轮数 532 | batch_size: 批次大小 533 | output_dir: 输出目录,用于保存中间结果 534 | plot_loss: 是否绘制损失曲线 535 | 536 | 返回: 537 | model: 训练好的模型 538 | """ 539 | print("开始训练...") 540 | 541 | # 确保输出目录存在 542 | os.makedirs(output_dir, exist_ok=True) 543 | 544 | # 获取数据加载器 545 | loader = self.model.get_loader(batch_size=batch_size, shuffle=True) 546 | 547 | # 记录每个epoch的损失 548 | epoch_losses = [] 549 | 550 | for epoch in range(epochs): 551 | self.model.train() 552 | total_loss = 0 553 | batch_count = 0 554 | 555 | # 创建迭代器 556 | loader_iter = iter(loader) 557 | 558 | # 预取数据 559 | self._prefetch_data(loader_iter, self.prefetch_factor) 560 | 561 | # 使用tqdm显示进度条 562 | pbar = tqdm(total=len(loader), desc=f"Epoch {epoch+1}/{epochs}") 563 | 564 | while True: 565 | # 如果缓冲区为空,重新预取数据 566 | if not self.data_buffer: 567 | self._prefetch_data(loader_iter, self.prefetch_factor) 568 | if not self.data_buffer: 569 | break 570 | 571 | # 获取一批数据 572 | pos_rw, neg_rw = self.data_buffer.pop(0) 573 | 574 | # 训练一步 - 不使用特征,简化训练过程 575 | loss = self.model.train_step(pos_rw, neg_rw) 576 | 577 | total_loss += loss 578 | batch_count += 1 579 | 580 | # 更新进度条 581 | pbar.set_postfix({'loss': f'{loss:.4f}'}) 582 | pbar.update(1) 583 | 584 | # 在处理当前批次时,异步预取下一批数据 585 | if len(self.data_buffer) < self.prefetch_factor: 586 | self._prefetch_data(loader_iter, 1) 587 | 588 | pbar.close() 589 | 590 | # 计算平均损失 591 | avg_loss = total_loss / batch_count 592 | epoch_losses.append(avg_loss) 593 | print(f"Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}") 594 | 595 | # 更新学习率 596 | self.model.update_lr(avg_loss) 597 | 598 | # 清空数据缓冲区 599 | self.data_buffer.clear() 600 | 601 | print("训练完成!") 602 | 603 | # 绘制损失下降曲线(如果需要) 604 | if plot_loss: 605 | plot_loss_curve(epoch_losses, output_dir=output_dir) 606 | 607 | # 存储损失值供后续使用 608 | self.epoch_losses = epoch_losses 609 | 610 | return self.model 611 | 612 | def get_embeddings(self): 613 | """ 614 | 获取所有节点的嵌入 615 | 616 | 返回: 617 | embeddings_dict: 节点嵌入字典,键为原始节点ID,值为嵌入向量 618 | """ 619 | self.model.eval() 620 | 621 | # 获取所有节点的嵌入 - 分批处理以减少内存压力 622 | batch_size = 1024 # 使用更大的批次大小 623 | num_nodes = len(self.reverse_node_map) 624 | embeddings_dict = {} 625 | 626 | with torch.no_grad(): 627 | # 分批处理节点 628 | for i in range(0, num_nodes, batch_size): 629 | batch_indices = list(range(i, min(i + batch_size, num_nodes))) 630 | 631 | # 准备特征(如果有) 632 | if self.side_info is not None: 633 | batch_nodes = [self.reverse_node_map[idx] for idx in batch_indices] # 获取原始节点ID 634 | features = self._prepare_features(batch_nodes) 635 | else: 636 | features = None 637 | 638 | # 获取这批节点的嵌入 639 | indices_tensor = torch.tensor(batch_indices, device=self.device) 640 | try: 641 | batch_embeddings = self.model.get_embeddings(features, indices=indices_tensor) 642 | 643 | # 将嵌入移到CPU并转换为NumPy 644 | batch_embeddings = batch_embeddings.cpu().numpy() 645 | 646 | # 将嵌入映射回原始节点ID 647 | for j, idx in enumerate(batch_indices): 648 | node_id = self.reverse_node_map[idx] 649 | embeddings_dict[node_id] = batch_embeddings[j] 650 | except Exception as e: 651 | print(f"处理批次 {i} 时出错: {str(e)}") 652 | continue 653 | 654 | return embeddings_dict 655 | 656 | def save_model(self, path): 657 | """ 658 | 保存模型 659 | 660 | 参数: 661 | path: 保存路径 662 | """ 663 | torch.save({ 664 | 'model_state_dict': self.model.state_dict(), 665 | 'optimizer_state_dict': self.model.optimizer.state_dict(), 666 | 'node_map': self.node_map, 667 | 'reverse_node_map': self.reverse_node_map 668 | }, path) 669 | 670 | def load_model(self, path): 671 | """ 672 | 加载模型 673 | 674 | 参数: 675 | path: 加载路径 676 | """ 677 | checkpoint = torch.load(path) 678 | self.model.load_state_dict(checkpoint['model_state_dict']) 679 | self.model.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 680 | self.node_map = checkpoint['node_map'] 681 | self.reverse_node_map = checkpoint['reverse_node_map'] --------------------------------------------------------------------------------