├── 202406高校大数据挑战赛baseline
│   ├── baseline使用说明.md
│   ├── project
│   │   ├── 0629_2temp.pth
│   │   ├── 0629_2wind.pth
│   │   ├── train_mean.npy
│   │   ├── train_std.npy
│   │   ├── Model.py
│   │   └── index.py
│   └── 2024bdc-baseline-LB1-7098.ipynb
├── README.md
├── 202407TimeSeriesTransformer
│   └── 关于代码的说明.md
├── 2024kddWhoiswhotop37solution
│   ├── 2024kddcupwhoiswho赛道top37solution.md
│   └── fasttext-essay-category-80.ipynb
├── 202407科大讯飞短视频推荐baseline
│   └── 科大讯飞短视频推荐LB0.00033.ipynb
├── 202407Kagglejtseptop2solutionstudy
│   └── jtsep-top2-solution-study.ipynb
├── 202407chatglm6b微调
│   └── chatglm6b-huanhuan-finetune-inference.ipynb
├── 202404KDDcup-whoiswho-baseline
│   └── 202404-kdd-cup-whoiswho-ind-baseline.ipynb
└── 202406datacastle睡眠事件检测baseline
    └── 睡眠事件检测baseline(LB0.6251).ipynb

/202406高校大数据挑战赛baseline/baseline使用说明.md:
--------------------------------------------------------------------------------
Baseline usage notes

First run the baseline notebook to produce train_mean.npy, train_std.npy, temp.pth and wind.pth (stored in project/ as 0629_2temp.pth and 0629_2wind.pth), then use index.py and Model.py to make predictions on the test data.

--------------------------------------------------------------------------------
/202406高校大数据挑战赛baseline/project/0629_2temp.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yunsuxiaozi/AI-and-competition/HEAD/202406高校大数据挑战赛baseline/project/0629_2temp.pth

--------------------------------------------------------------------------------
/202406高校大数据挑战赛baseline/project/0629_2wind.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yunsuxiaozi/AI-and-competition/HEAD/202406高校大数据挑战赛baseline/project/0629_2wind.pth

--------------------------------------------------------------------------------
/202406高校大数据挑战赛baseline/project/train_mean.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yunsuxiaozi/AI-and-competition/HEAD/202406高校大数据挑战赛baseline/project/train_mean.npy
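The usage notes for the 202406 baseline say the notebook must first produce `train_mean.npy` and `train_std.npy`, which `index.py` later loads and applies as `(data-train_mean)/train_std`. The sketch below shows one way such statistics could be fitted, saved and reused. It assumes per-column statistics over flattened windows of 144 hours × 6 features (the feature count `index.py` builds) and uses random stand-in data, so it is not necessarily how the baseline notebook computes them; `fit_standardizer` is a hypothetical helper.

```python
import os
import tempfile

import numpy as np


def fit_standardizer(train_x):
    """Per-column mean and std of a 2-D (samples, features) training matrix."""
    mean = train_x.mean(axis=0)
    std = train_x.std(axis=0)
    std[std == 0] = 1.0  # guard added in this sketch: avoid dividing by zero for constant columns
    return mean, std


rng = np.random.default_rng(2024)
# Hypothetical stand-in for flattened training windows: 144 hours x 6 features = 864 columns
train_x = rng.normal(loc=10.0, scale=3.0, size=(1000, 144 * 6))
train_mean, train_std = fit_standardizer(train_x)

out_dir = tempfile.mkdtemp()  # stand-in for the project directory
np.save(os.path.join(out_dir, "train_mean.npy"), train_mean)
np.save(os.path.join(out_dir, "train_std.npy"), train_std)

# At inference time, reload the same statistics and apply them to new windows,
# mirroring the normalization step in index.py's predict()
loaded_mean = np.load(os.path.join(out_dir, "train_mean.npy"))
loaded_std = np.load(os.path.join(out_dir, "train_std.npy"))
test_x = rng.normal(loc=10.0, scale=3.0, size=(5, 144 * 6))
test_norm = (test_x - loaded_mean) / loaded_std
print(test_norm.shape)  # (5, 864)
```

Fitting the statistics on training data only and shipping them next to the model weights keeps training and inference on the same scale.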
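--------------------------------------------------------------------------------
/202406高校大数据挑战赛baseline/project/train_std.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yunsuxiaozi/AI-and-competition/HEAD/202406高校大数据挑战赛baseline/project/train_std.npy

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## How to use this repository

This repository mainly contains baselines for various competitions, a few toplines, and some deep learning projects that are independent of any competition.

A baseline is an entry-level guide to a competition: participants can use it to make their first submission. Baselines are relatively simple and easy to pick up, which makes them suitable for beginners.

A topline is a top-ranking solution from a competition. Toplines are more complex than baselines and take more effort to organize, so the repository currently contains only a limited number of them. Every topline here was refined by me from the original author's code: I fixed some of the original author's mistakes, removed unused code, and added comments to make the code easier to follow. If you want to study the toplines of these competitions, my repository should be easier to understand than the original authors' code.

If you learn something here, please don't forget to show this repository some support.

The repository is currently maintained by me alone, so oversights are inevitable. If you find any problems or have any suggestions, feel free to contact me.

My GitHub and Kaggle username is yunsuxiaozi (匀速小子).

--------------------------------------------------------------------------------
/202407TimeSeriesTransformer/关于代码的说明.md:
--------------------------------------------------------------------------------
About the code

This is my attempt to learn iTransformer, the new time-series Transformer released by Tsinghua University.

GitHub repository: https://github.com/thuml/iTransformer

arXiv paper: https://arxiv.org/abs/2310.06625

I tried it on Kaggle's introductory time-series competition: https://www.kaggle.com/competitions/store-sales-time-series-forecasting/overview

My Kaggle notebook: https://www.kaggle.com/code/yunsuxiaozi/store-sales-transformer

I did not follow the official model exactly. The overall architecture should be the same as the official one, but I changed some of the internal implementation details.

What surprised me is that using only the target variable, sales, as the model's input and output already reached a score of 0.51403, while the best score on the leaderboard is 0.37768 (lower is better on this leaderboard). Further hyperparameter tuning should improve the score further. This model scored better than my earlier LightGBM model with feature engineering (although that may just reflect my own limited skill).

--------------------------------------------------------------------------------
/2024kddWhoiswhotop37solution/2024kddcupwhoiswho赛道top37solution.md: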
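--------------------------------------------------------------------------------
## Preface

Thanks to the organizers for hosting this competition, and to all the participants for their efforts.

Competition website: https://www.biendata.xyz/competition/ind_kdd_2024/

Full competition dataset: https://www.kaggle.com/datasets/yunsuxiaozi/2024kddcupwhoiswho

Hello everyone, I'm yunsuxiaozi (匀速小子).

This was my first KDD Cup. I took part in the author name disambiguation task of the WhoIsWho track. I ranked $31/117\approx 26.5\%$ (top 27%) on leaderboard A and $37/53\approx 69.8\%$ (top 70%) on leaderboard B.

Since I did not finish near the top, what follows is not a particularly strong solution; these notes are mainly for my own reference.

I'm fairly satisfied with this ranking, given that it was my first KDD Cup. The people ranked ahead of me are well-known names in the competition community. Compared with them I have less competition experience and narrower knowledge, so it is not surprising that I couldn't beat them.

## Top-37 solution overview

I used a pure data-mining approach. The solution consists of 3 files: one trains a paper classifier, one builds the dataset, and one trains the model and runs inference.

My fastText file trains a paper classifier that assigns each paper to an academic discipline. The discussion at https://www.biendata.xyz/forum/view_post_category/1034618/ suggested using a large language model to summarize each paper's field as a feature. I'm not good with LLMs, so instead I used the dataset https://www.kaggle.com/datasets/Cornell-University/arxiv/data to train my own paper classifier and added a discipline feature to each paper based on its abstract. (Note: this may have been against the rules, but since I ranked low and was not in line for prize money anyway, I didn't worry about it. The trained classifier is only about 80% accurate; an LLM would probably be more accurate and work better, so readers who need this feature may want to use one.)

My data file builds features from the training and test data. Building the features took a long time, so I kept it separate from model training to avoid wasting Kaggle GPU quota. Most of the feature engineering happens in this file, and what it does is documented there in detail, if you are interested.

My model file handles training and inference, using two groups of LightGBM models. It also does some extra feature engineering at the start, because I kept modifying the code afterwards to try to get better results.

## Thoughts on a few open questions (personal opinions only)

### 1. Should early stopping be used when training the LightGBM models?

Answer: I don't think so. The evaluation metric computes an AUC for each author separately and then takes a weighted average. Although the organizers provide about 140,000 training rows, there are only just over 700 authors. With cross-validation, each fold's validation set contains only about 100 authors, which is too little data. These 100 authors cannot represent the endless variety of real-world samples, and early stopping on them may cause underfitting, so in my view early stopping should not be used.

### 2. Should sample weights be used during training, i.e. should the model be trained with the competition's evaluation metric?

Answer: Experiments showed this didn't work well, and I thought about why. After plotting each author's out-of-fold (OOF) AUC, I found that authors with fewer papers tend to have lower AUC, sometimes even below 0.5, and authors with fewer papers also tend to have fewer incorrect papers. So authors with many incorrect papers have many papers, and authors with many papers have good AUC. Weighting the training samples like the metric would therefore give large weights to authors whose AUC is already good and small weights to those whose AUC is poor, which is clearly unreasonable. In the end I trained the model without weights.

### 3. Should model ensembling be used?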

Answer: Experiments showed it didn't help. XGBoost and CatBoost performed worse than LightGBM. Although the ensemble improved the offline OOF score considerably, it still did not help after submission, probably because the blending weights overfit when they were chosen offline. In the end I used only LightGBM.

--------------------------------------------------------------------------------
/202406高校大数据挑战赛baseline/project/Model.py:
--------------------------------------------------------------------------------
import torch  # deep learning library (PyTorch)
import torch.nn as nn  # neural network modules
import torch.nn.functional as F  # functional neural network operations
import torch.optim as optim  # optimization algorithms

class BaselineModel(nn.Module):
    def __init__(self):
        super(BaselineModel, self).__init__()
        self.conv = nn.Sequential(
            #1*24*36->16*24*36
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            #16*24*36->16*12*18
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.GELU(),
            #16*12*18->32*12*18
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            #32*12*18->64*12*18
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(64),
            #64*12*18->64*6*9
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.GELU(),
            #64*6*9->128*6*9
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(128),
            #128*6*9->128*3*4 (max pooling rounds 9/2 down to 4)
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.GELU(),
        )
        self.head = nn.Sequential(
            nn.Linear(128*3*4, 128),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.GELU(),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.reshape(x.shape[0], -1)  # flatten to (batch, 128*3*4)
        return self.head(x)

--------------------------------------------------------------------------------
/202406高校大数据挑战赛baseline/project/index.py:
--------------------------------------------------------------------------------
import os
import numpy as np  # matrix operations and scientific computing
from sklearn.model_selection import train_test_split
import torch  # deep learning library (PyTorch)
import torch.nn as nn  # neural network modules

from Model import BaselineModel

def predict(feats, model_name='temp.pth', batch_size=128, train_mean=0, train_std=1):
    # test data
    x = 60  # 60 observation stations
    model = BaselineModel()
    # map_location lets weights trained on a GPU load on a CPU-only machine
    model.load_state_dict(torch.load(model_name, map_location='cpu'))
    model.eval()  # switch the model to evaluation mode
    # to predict hour i (i = 0..23), use feats[:, i:i-24], i.e. 144 hours x 6 features,
    # for each of the 71*x (week, station) pairs
    data = []
    for i in range(24):
        data.append(feats[:, i:i-24].reshape(feats.shape[0], -1))
    data = np.array(data)  # (24 hours, 71*x station-weeks, 144*6 values)
    # (24*71*x rows, 144*6 = 864 columns); each row is later reshaped to a 1x24x36 image
    data = data.reshape(-1, data.shape[-1])
    data = (data - train_mean) / train_std
    print(f"input.shape:{data.shape}")
    # one prediction per row: 24*71*x values
    test_preds = np.zeros(len(data))
    with torch.no_grad():  # no gradients are needed for inference
        for idx in range(0, len(data), batch_size):
            data1 = torch.Tensor(data[idx:idx+batch_size]).reshape(-1, 1, 24, 36)
            test_preds[idx:idx+batch_size] = model(data1).numpy().reshape(-1)
    test_preds = test_preds.reshape(24, 71, x, -1)
    # -> (71 weeks, 24 hours, x stations, 1)
    test_preds = test_preds.transpose(1, 0, 2, 3)
    return test_preds

def invoke(input_dir):
    date = '0629_2'
    np.random.seed(2024)
    # test data
    x = 60  # 60 observation stations
    # 71 non-consecutive weeks, 56 time steps (one every 3 hours), 4 observed features, 9 grid positions, x stations;
    # averaging over the 9 grid positions (axis=-2) gives (71, 56, 4, 1, x)
    cenn_data = np.load(os.path.join(input_dir, 'cenn_data.npy')).mean(axis=-2, keepdims=True)  # in the actual evaluation the data is loaded with np.load
    print(f"cenn_data.shape:{cenn_data.shape}")
    # go from one record every 3 hours to one per hour: (71, 168, 4, 1, x)
    cenn_data_hour = np.repeat(cenn_data, 3, axis=1)

    cenn_data_hour = cenn_data_hour.transpose(0, 4, 1, 2, 3)  # (71, x, 168, 4, 1)
    cenn_data_hour = cenn_data_hour.reshape(71*x, 168, 4)
    # cenn/temp_lookback.npy: 71 non-consecutive weeks, hourly temperature at x stations over the previous week
    temp_lookback = np.load(os.path.join(input_dir, 'temp_lookback.npy'))
    print(f"temp_lookback.shape:{temp_lookback.shape}")
    temp_lookback = temp_lookback.transpose(0, 2, 1, 3)  # (71, x, 168, 1)
    temp_lookback = temp_lookback.reshape(71*x, 168, 1)
    # cenn/wind_lookback.npy: 71 non-consecutive weeks, hourly wind speed at x stations over the previous week
    wind_lookback = np.load(os.path.join(input_dir, 'wind_lookback.npy'))
    print(f"wind_lookback.shape:{wind_lookback.shape}")
    wind_lookback = wind_lookback.transpose(0, 2, 1, 3)  # (71, x, 168, 1)
    wind_lookback = wind_lookback.reshape(71*x, 168, 1)
    # (71*x station-weeks, 168 hours, 4 + 1 + 1 = 6 features)
    total_feats = np.concatenate((cenn_data_hour, temp_lookback, wind_lookback), axis=-1)
    # model files and normalization statistics are read from, and predictions saved to, the project directory
    save_path = os.path.join('/home/mw', 'project')
    train_mean = np.load(os.path.join(save_path, 'train_mean.npy'))
    train_std = np.load(os.path.join(save_path, 'train_std.npy'))
    temp_predict = predict(feats=total_feats, model_name=os.path.join(save_path, f'{date}temp.pth'), batch_size=128, train_mean=train_mean, train_std=train_std)
    wind_predict = predict(feats=total_feats, model_name=os.path.join(save_path, f'{date}wind.pth'), batch_size=128, train_mean=train_mean, train_std=train_std)
    np.save(os.path.join(save_path, 'temp_predict.npy'), temp_predict)
    np.save(os.path.join(save_path, 'wind_predict.npy'), wind_predict)

--------------------------------------------------------------------------------
/2024kddWhoiswhotop37solution/fasttext-essay-category-80.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1094c368",
   "metadata": {
    "papermill": {
     "duration": 0.003479,
     "end_time": "2024-06-08T01:59:39.731806",
     "exception": false,
     "start_time": "2024-06-08T01:59:39.728327",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Created by yunsuxiaozi 2024/6/8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e85a5c21",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-08T01:59:39.739623Z",
     "iopub.status.busy": "2024-06-08T01:59:39.739176Z",
     "iopub.status.idle": "2024-06-08T01:59:40.904040Z",
     "shell.execute_reply":
"2024-06-08T01:59:40.902704Z"
    },
    "papermill": {
     "duration": 1.172733,
     "end_time": "2024-06-08T01:59:40.907500",
     "exception": false,
     "start_time": "2024-06-08T01:59:39.734767",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "#necessary\n",
    "import pandas as pd  # library for reading CSV files\n",
    "import numpy as np  # library for matrix operations\n",
    "import fasttext  # library for efficient word representations and sentence classification\n",
    "import csv  # library for CSV (comma-separated values) files\n",
    "import random  # functions for generating random numbers\n",
    "# set random seeds so that the model is reproducible\n",
    "def seed_everything(seed):\n",
    "    np.random.seed(seed)  # numpy random seed\n",
    "    random.seed(seed)  # Python built-in random seed\n",
    "seed_everything(seed=2024)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bec88780",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-08T01:59:40.915515Z",
     "iopub.status.busy": "2024-06-08T01:59:40.914945Z",
     "iopub.status.idle": "2024-06-08T02:02:44.626061Z",
     "shell.execute_reply": "2024-06-08T02:02:44.624470Z"
    },
    "papermill": {
     "duration": 183.72648,
     "end_time": "2024-06-08T02:02:44.636895",
     "exception": false,
     "start_time": "2024-06-08T01:59:40.910415",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "read_files\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       " label abstract\n",
       "1500000 6 Collisional threats posed by Near-Earth Obje...\n",
       "1500001 0 Recognizing the patient's emotions using dee...\n",
       "1500002 0 Prior work on diagnosing Alzheimer's disease...\n",
       "1500003 0 In this paper, we propose a joint radio and ...\n",
       "1500004 0 In recent years, recommendation systems have..."
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(\"read_files\")\n",
    "arxiv=pd.read_json(\"/kaggle/input/arxiv/arxiv-metadata-oai-snapshot.json\",lines=True)\n",
    "# these should be the top-level arXiv categories\n",
    "category=['cs','math','eess','stat','hep','cond-mat','astro','gr','nlin','q-bio',\n",
    "          'quant','nucl','q-fin','econ']\n",
    "# note: this is a substring match on the whole categories string, so the first entry of\n",
    "# `category` that occurs anywhere wins (e.g. 'physics.*' also matches 'cs')\n",
    "def get_category(c):\n",
    "    for i in range(len(category)):\n",
    "        if category[i] in c:\n",
    "            return i\n",
    "arxiv['label']=arxiv['categories'].apply(lambda x:get_category(x))\n",
    "train_feats=arxiv[['label','abstract']][:1500000]\n",
    "valid_feats=arxiv[['label','abstract']][1500000:]\n",
    "valid_feats.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "971b4172",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-08T02:02:44.646267Z",
     "iopub.status.busy": "2024-06-08T02:02:44.645800Z",
     "iopub.status.idle": "2024-06-08T02:08:42.246643Z",
     "shell.execute_reply": "2024-06-08T02:08:42.243428Z"
    },
    "papermill": {
     "duration": 357.61124,
     "end_time": "2024-06-08T02:08:42.251527",
     "exception": false,
     "start_time": "2024-06-08T02:02:44.640287",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Read 233M words\n",
      "Number of words: 3714332\n",
      "Number of labels: 14\n",
      "Progress: 100.0% words/sec/thread: 2864544 lr: 0.000000 avg.loss: 0.734348 ETA: 0h 0m 0s\n"
     ]
    }
   ],
   "source": [
    "train_text=train_feats['abstract'].values\n",
    "train_label=train_feats['label'].values\n",
    "# note: unlike in the validation cell, newlines inside the abstracts are not removed here\n",
    "train_data=[f'__label__{train_label[i]} '+train_text[i] for i in range(len(train_text))]\n",
    "data = pd.DataFrame(train_data)\n",
    "data.to_csv(\"train.txt\",  # output file name\n",
    "            index=False,  # do not write the row index (0,1,2,3,4,...)\n",
    "            sep=' ',  # use a space as the separator\n",
    "            header=False,  # do not write column names\n",
    "            quoting=csv.QUOTE_NONE,  # never wrap fields in quotes\n",
    "            quotechar=\"\",  # empty quote character (quoting is disabled anyway)\n",
    "            escapechar=\" \"  # with QUOTE_NONE, characters that need escaping (such as the space separator) are escaped with a space\n",
    "           )\n",
    "# train a supervised fastText classifier\n",
    "model = fasttext.train_supervised('train.txt',  # path to the training file\n",
    "                                  label_prefix='__label__',  # prefix that marks the labels\n",
    "                                  thread=4,  # use 4 threads to speed up training\n",
    "                                  epoch = 12,  # train for 12 epochs\n",
    "                                 )\n",
    "# save the model with fastText's save_model method\n",
    "model.save_model('fasttext_arxivcategory.model')\n",
    "## load the previously saved model\n",
    "#model = fasttext.load_model('fasttext_arxivcategory.model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2be8bf6e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-08T02:08:42.469356Z",
     "iopub.status.busy": "2024-06-08T02:08:42.468805Z",
     "iopub.status.idle": "2024-06-08T02:09:47.982732Z",
     "shell.execute_reply": "2024-06-08T02:09:47.981474Z"
    },
    "papermill": {
     "duration": 65.713624,
     "end_time": "2024-06-08T02:09:48.077260",
     "exception": false,
     "start_time": "2024-06-08T02:08:42.363636",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy:0.79926104897164\n"
     ]
    }
   ],
   "source": [
    "valid_text=list(valid_feats['abstract'].values)\n",
    "valid_text = [w.replace('\\n', '') for w in valid_text]\n",
    "preds,pro= model.predict(valid_text,k=len(model.labels))\n",
    "\n",
    "preds=np.array([int(pred[0][9:])for pred in preds])\n",
    "true=valid_feats['label'].values\n",
    "print(f\"accuracy:{np.mean(preds==true)}\")"
   ]
  }
 ],
 "metadata": {
  "kaggle": {
   "accelerator": "none",
   "dataSources": [
    {
     "datasetId": 612177,
     "sourceId": 8581546,
     "sourceType": "datasetVersion"
    }
   ],
   "dockerImageVersionId": 30732,
   "isGpuEnabled": false,
   "isInternetEnabled": true,
   "language": "python",
   "sourceType": "notebook"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 616.057567,
   "end_time": "2024-06-08T02:09:52.429199",
   "environment_variables": {},
   "exception": null,
   "input_path": "__notebook__.ipynb",
   "output_path": "__notebook__.ipynb",
   "parameters": {},
   "start_time": "2024-06-08T01:59:36.371632",
   "version": "2.5.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

--------------------------------------------------------------------------------
/202407科大讯飞短视频推荐baseline/科大讯飞短视频推荐LB0.00033.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8dae28d4",
   "metadata": {},
   "source": [
    "## Created by yunsuxiaozi 2024/7/24\n",
    "\n",
    "#### Official competition page: Short-Video Precise Recommendation Challenge (短视频精准推荐挑战赛). More baselines for other competitions can be found in AI and
competition.\n",
    "\n",
    "#### Note: the prize money for this competition is small. I entered because I had never taken part in a recommendation competition and wanted to broaden my experience; it was also the first time I had seen this competition's evaluation metric. This solution does not use any model at all. It currently ranks 4/10 on the leaderboard, so I'm open-sourcing it."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a174eeb",
   "metadata": {},
   "source": [
    "### 1. Read the log data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9b006821",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "len(logs):3487102\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       " uid vid cid playtime duration date rank playpercent\n",
       "0 100000 100870 100037 4966 9590 20200706 1 0.517831\n",
       "1 100000 101167 100024 5037 4949 20200706 2 1.017781\n",
       "2 100000 103608 100008 5137 14884 20200706 3 0.345136\n",
       "3 100000 100220 100084 5139 16784 20200706 4 0.306184\n",
       "4 100000 101674 100027 5149 5760 20200706 5 0.893924"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd  # library for reading CSV files\n",
    "import numpy as np  # matrix operations and scientific computing\n",
    "path=''  # replace this with the path to your own data\n",
    "logs=pd.read_csv(path+\"uid_click_log.csv\")\n",
    "print(f\"len(logs):{len(logs)}\")\n",
    "# watch time / video duration: below 1 means the user left before the end, above 1 means some parts were rewatched\n",
    "logs['playpercent']=logs['playtime']/logs['duration']\n",
    "logs.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f05af3d6",
   "metadata": {},
   "source": [
    "### A video watched fewer than 5 times (the code counts log records, not distinct users) is fairly niche and will not be recommended, so such videos are filtered out here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "83843fe9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "len(dislike_vids):1328\n",
      "len(logs):3484168\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       " uid vid cid playtime duration date rank playpercent\n",
       "0 100000 100870 100037 4966 9590 20200706 1 0.517831\n",
       "1 100000 101167 100024 5037 4949 20200706 2 1.017781\n",
       "2 100000 103608 100008 5137 14884 20200706 3 0.345136\n",
       "3 100000 100220 100084 5139 16784 20200706 4 0.306184\n",
       "4 100000 101674 100027 5149 5760 20200706 5 0.893924"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vid_count=logs['vid'].value_counts().to_dict()\n",
    "\n",
    "# a video watched fewer than 5 times is fairly niche, so it will not be recommended\n",
    "dislike_vids=[]\n",
    "for vid,count in vid_count.items():\n",
    "    if count<5:\n",
    "        dislike_vids.append(vid)\n",
    "print(f\"len(dislike_vids):{len(dislike_vids)}\") \n",
    "logs=logs[~logs['vid'].isin(dislike_vids)]\n",
    "print(f\"len(logs):{len(logs)}\")\n",
    "logs.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3510b470",
   "metadata": {},
   "source": [
    "### Count the views (count) of each vid within each video category (cid)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9a2a3a7e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       " cid vid count\n",
       "9 100000 100098 2085\n",
       "505 100000 103579 1909\n",
       "452 100000 103217 1904\n",
       "13 100000 100113 1895\n",
       "310 100000 102112 1895"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# view count of each vid within each cid\n",
    "cid_vid_count=logs.groupby(['cid','vid'])['uid'].count().reset_index().rename(columns={\"uid\":'count'})\n",
    "# within each cid, put the most-watched videos first\n",
    "cid_vid_count=cid_vid_count.sort_values(['cid','count'],ascending=[True,False])\n",
    "cid_vid_count.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3be4b6d7",
   "metadata": {},
   "source": [
    "### Total time each user (uid) spent watching each cid."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b3fbd318",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       " uid cid playtimesum\n",
       "0 100000 100000 1193095\n",
       "16 100000 100019 687648\n",
       "22 100000 100027 672526\n",
       "14 100000 100017 357325\n",
       "3 100000 100005 319041"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# which cids each uid likes most (total playtime per uid and cid)\n",
    "uid_cid_playtime=logs.groupby(['uid','cid'])['playtime'].sum().reset_index().rename(columns={\"playtime\":'playtimesum'})\n",
    "# sort each uid's cids by total playtime, favorite first\n",
    "uid_cid_playtime=uid_cid_playtime.sort_values(['uid','playtimesum'],ascending=[True,False])\n",
    "uid_cid_playtime.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f91f3f43",
   "metadata": {},
   "source": [
    "### For each user (uid), pick vids from their 3 favorite cids: the 2 most popular vids from each of the first two cids and 1 from the third."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c2a7a1d5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "len(submission):25195\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       " uid vid\n",
       "0 100000 100098\n",
       "1 100000 103579\n",
       "2 100000 100019\n",
       "3 100000 100515\n",
       "4 100000 100981"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "submission=pd.read_csv(path+\"example.csv\")\n",
    "print(f\"len(submission):{len(submission)}\")\n",
    "uids=submission['uid'].unique()\n",
    "test_videos=[]\n",
    "for i in range(len(uids)):\n",
    "    # initial fill: simply recommend the 5 videos watched by the most people\n",
    "    videos=[100592,102029,100423,100833,101428]\n",
    "    uid=uids[i]\n",
    "    # cids this uid likes, favorite first\n",
    "    cids=uid_cid_playtime[uid_cid_playtime['uid']==uid]['cid'].values\n",
    "    fill_idx=0  # start filling at position 0\n",
    "    # second fill: recommend the most popular videos in the uid's favorite cids; for variety,\n",
    "    # take 2 videos from each of the first two cids and 1 from the third\n",
    "    for cid in cids:\n",
    "        # most popular vids in this cid\n",
    "        vids=cid_vid_count[cid_vid_count['cid']==cid]['vid'].values\n",
    "        n=2 if fill_idx<4 else 1\n",
    "        top=list(vids[:n])\n",
    "        # overwrite only as many slots as there are vids, so the list always keeps 5 entries\n",
    "        videos[fill_idx:fill_idx+len(top)]=top\n",
    "        fill_idx+=2\n",
    "        if fill_idx>5:\n",
    "            break\n",
    "    test_videos+=videos\n",
    "submission['vid']=test_videos\n",
    "submission.to_csv(path+\"most_like.csv\",index=None)\n",
    "submission.head()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
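The recommendation notebook above ranks, for each user, the most-watched videos inside the user's favorite categories and falls back to a fixed list of popular videos. The sketch below restates that heuristic as a reusable pandas function on a tiny synthetic log. It is a variant, not the author's exact code: `recommend`, `k`, `min_views` and `per_cid` are hypothetical names, the fallback list is computed from the data instead of hard-coded, ties are broken by id, and duplicate videos are skipped.

```python
import pandas as pd


def recommend(logs, uids, k=5, min_views=5, per_cid=2):
    """Hypothetical helper: popular vids from each user's favorite cids, then a global fallback."""
    # 1. Drop niche videos with fewer than `min_views` log records
    counts = logs["vid"].value_counts()
    logs = logs[logs["vid"].map(counts) >= min_views]

    # 2. Rank videos inside each cid by number of views (ties broken by vid)
    cid_vid = (logs.groupby(["cid", "vid"]).size().rename("count").reset_index()
               .sort_values(["cid", "count", "vid"], ascending=[True, False, True]))
    top_vids_by_cid = cid_vid.groupby("cid")["vid"].agg(list).to_dict()

    # 3. Rank each user's cids by total playtime, favorite first
    uid_cid = (logs.groupby(["uid", "cid"])["playtime"].sum().reset_index()
               .sort_values(["uid", "playtime", "cid"], ascending=[True, False, True]))
    fav_cids = uid_cid.groupby("uid")["cid"].agg(list).to_dict()

    # Fallback: globally most-watched videos that survived the filter
    fallback = counts[counts >= min_views].index.tolist()

    result = {}
    for uid in uids:
        recs = []
        for cid in fav_cids.get(uid, []):
            if len(recs) >= k:
                break
            for vid in top_vids_by_cid[cid][:per_cid]:
                if len(recs) < k and int(vid) not in recs:
                    recs.append(int(vid))
        for vid in fallback:
            if len(recs) >= k:
                break
            if int(vid) not in recs:
                recs.append(int(vid))
        result[uid] = recs
    return result


# Tiny synthetic log with columns (uid, vid, cid, playtime)
rows = []
for u in range(2, 10):
    rows.append((u, 300, 30, 10))  # vid 300: 8 views
for u in range(2, 9):
    rows.append((u, 200, 20, 10))  # vid 200: 7 views
for u in range(1, 7):
    rows.append((u, 100, 10, 1000 if u == 1 else 10))  # vid 100: 6 views
for u in range(1, 6):
    rows.append((u, 101, 10, 500 if u == 1 else 10))  # vid 101: 5 views
    rows.append((u, 201, 20, 100 if u == 1 else 10))  # vid 201: 5 views
for u in range(2, 4):
    rows.append((u, 102, 10, 10))  # vid 102: only 2 views, filtered out
logs = pd.DataFrame(rows, columns=["uid", "vid", "cid", "playtime"])

recs = recommend(logs, uids=[1, 99])
print(recs[1])  # [100, 101, 200, 201, 300]
```

User 1 spends most time in cid 10, then cid 20, so it gets the top two videos of each and one fallback video; an unseen user (99) receives only the global fallback. Taking `per_cid=2` videos per category until `k=5` slots are filled reproduces the 2 + 2 + 1 split of the notebook when each category has at least two videos.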
-------------------------------------------------------------------------------- /202406高校大数据挑战赛baseline/2024bdc-baseline-LB1-7098.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "03f93274", 6 | "metadata": { 7 | "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", 8 | "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", 9 | "papermill": { 10 | "duration": 0.00697, 11 | "end_time": "2024-06-29T03:08:29.211294", 12 | "exception": false, 13 | "start_time": "2024-06-29T03:08:29.204324", 14 | "status": "completed" 15 | }, 16 | "tags": [] 17 | }, 18 | "source": [ 19 | "## Created by yunsuxiaozi 2024/06/29\n", 20 | "\n", 21 | "\n", 22 | "#### 这是2024年高校大数据挑战赛的baseline,你可以在AI and competition里获取更多比赛的baseline。本次比赛官网如下:2024bdc\n", 23 | "\n", 24 | "#### 本次比赛官方所给的baseline是最新的论文iTransformer,分数非常高,大部分的选手都已经采用,并且在它的基础上改进取得了更好的成绩。我这里使用2维的CNN来做一个简单的baseline,分数不高,仅供参考。\n", 25 | "\n", 26 | "\n", 27 | "#### 本次比赛的数据集已经被老师上传到Kaggle了,数据集链接如下:2024bdc dataset。这里也直接在Kaggle上运行程序,如果想使用我的baseline可以将数据的路径改成你自己的路径。" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "id": "29d6327e", 33 | "metadata": { 34 | "papermill": { 35 | "duration": 0.004427, 36 | "end_time": "2024-06-29T03:08:29.223014", 37 | "exception": false, 38 | "start_time": "2024-06-29T03:08:29.218587", 39 | "status": "completed" 40 | }, 41 | "tags": [] 42 | }, 43 | "source": [ 44 | "## 1.导入必要的python库,这里不多做解释,注释也已经写的很清楚了。固定随机种子是为了保证模型可以复现。" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 1, 50 | "id": "57c59ea4", 51 | "metadata": { 52 | "execution": { 53 | "iopub.execute_input": "2024-06-29T03:08:29.234153Z", 54 | "iopub.status.busy": "2024-06-29T03:08:29.233746Z", 55 | "iopub.status.idle": "2024-06-29T03:08:34.822077Z", 56 | "shell.execute_reply": "2024-06-29T03:08:34.820835Z" 57 | }, 58 | "papermill": { 59 | "duration": 5.597729, 60 | "end_time": "2024-06-29T03:08:34.825504", 61 | "exception": 
false, 62 | "start_time": "2024-06-29T03:08:29.227775", 63 | "status": "completed" 64 | }, 65 | "tags": [] 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "import numpy as np#matrix operations and scientific computing\n", 70 | "from sklearn.model_selection import train_test_split\n", 71 | "import torch#deep learning library, PyTorch\n", 72 | "import torch.nn as nn#neural network modules\n", 73 | "import torch.nn.functional as F#functional neural-network operations\n", 74 | "import torch.optim as optim#optimization algorithms\n", 75 | "import gc#garbage collection\n", 76 | "import warnings#used to silence ignorable warnings\n", 77 | "warnings.filterwarnings('ignore')#filterwarnings() sets a warning filter; 'ignore' suppresses all warnings\n", 78 | "\n", 79 | "#set random seeds\n", 80 | "import random\n", 81 | "def seed_everything(seed):\n", 82 | " torch.backends.cudnn.deterministic = True#make cuDNN select deterministic convolution algorithms\n", 83 | " torch.backends.cudnn.benchmark = False#disable cuDNN auto-tuning, which may pick different convolution algorithms from run to run and change the results\n", 84 | " torch.manual_seed(seed)#PyTorch seed\n", 85 | " np.random.seed(seed)#NumPy seed\n", 86 | " random.seed(seed)#Python built-in seed\n", 87 | "seed_everything(seed=2024)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "id": "2832a524", 93 | "metadata": { 94 | "papermill": { 95 | "duration": 0.00481, 96 | "end_time": "2024-06-29T03:08:34.835399", 97 | "exception": false, 98 | "start_time": "2024-06-29T03:08:34.830589", 99 | "status": "completed" 100 | }, 101 | "tags": [] 102 | }, 103 | "source": [ 104 | "## 2. Load the data. The data format deserves a detailed explanation.\n", 105 | "\n", 106 | "- global_data: shape (5848, 4, 9, 3850)\n", 107 | "\n", 108 | " 5848 is the time axis: one record every 3 hours from January 2019 to December 2020; the 2 years span 731 days with 8 records per day, so 731\\*8=5848\n", 109 | " \n", 110 | " 4 is the 4 features: the 10 m zonal wind component 10U, positive toward east (m/s); the 10 m meridional wind component 10V, positive toward north (m/s); the 2 m temperature T2M (℃); the mean sea-level pressure MSL (Pa)\n", 111 | " \n", 112 | " 9 is the 9 grid cells: top-left, top, top-right, left, center, right, bottom-left, bottom, bottom-right.\n", 113 | " \n", 114 | " 3850 is the 3850 stations.\n", 115 | " \n", 116 | "- temp: shape (17544, 3850, 1)\n", 117 | " \n", 118 | " 17544 = 5848\\*3, i.e. one record per hour\n", 119 | " \n", 120 | " 3850 is the 3850 stations\n", 121 | " \n", 122 | " the trailing 1 is a singleton dimension that seems completely redundant.\n", 123 | " \n", 124 | "- 
wind: shape (17544, 3850, 1), same explanation as temp.\n", 125 | "\n", 126 | "\n", 127 | "#### Data processing:\n", 128 | "\n", 129 | "- First, global_data has to be brought to hourly resolution. Since this is only a baseline, each 3-hourly record is simply repeated 3 times; a later improvement could try interpolation instead. The dimension of size 9 is the grid position; the 4 features measured at the 9 positions should be highly correlated, so they are simply averaged. This turns global_data into (17544,4,1,3850). Because the model later predicts the next hour from each station's previous 144 hours of features, the data is then rearranged to (3850,17544,4); see the code for details.\n", 130 | "\n", 131 | "- temp and wind are simply transposed from (17544,3850,1) to (3850,17544,1)\n" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 2, 137 | "id": "fc4bff99", 138 | "metadata": { 139 | "execution": { 140 | "iopub.execute_input": "2024-06-29T03:08:34.847716Z", 141 | "iopub.status.busy": "2024-06-29T03:08:34.846307Z", 142 | "iopub.status.idle": "2024-06-29T03:09:33.224982Z", 143 | "shell.execute_reply": "2024-06-29T03:09:33.223566Z" 144 | }, 145 | "papermill": { 146 | "duration": 58.387764, 147 | "end_time": "2024-06-29T03:09:33.227889", 148 | "exception": false, 149 | "start_time": "2024-06-29T03:08:34.840125", 150 | "status": "completed" 151 | }, 152 | "tags": [] 153 | }, 154 | "outputs": [ 155 | { 156 | "name": "stdout", 157 | "output_type": "stream", 158 | "text": [ 159 | "global_data.shape:(5848, 4, 1, 3850)\n", 160 | "global_data_hour.shape:(3850, 17544, 4)\n", 161 | "temp.shape:(3850, 17544, 1)\n", 162 | "wind.shape:(3850, 17544, 1)\n" 163 | ] 164 | } 165 | ], 166 | "source": [ 167 | "path='/kaggle/input/bigdata2024/global/'\n", 168 | "#every 3 hours: 4 features (10U, 10V, T2M, MSL), 9 grid positions (top-left ... bottom-right), x=3850 stations\n", 169 | "# (5848, 4, 9, x) -> average over the 9 positions -> (5848, 4, 1, x)\n", 170 | "global_data=np.load(path+\"global_data.npy\").mean(axis=-2,keepdims=True)\n", 171 | "print(f\"global_data.shape:{global_data.shape}\")\n", 172 | "#turn 3-hourly into hourly by repeating each record 3 times: (5848*3, 4, 1, x)\n", 173 | "global_data_hour=np.repeat(global_data, 3, axis=0)\n", 174 | " \n", 175 | "del global_data\n", 176 | "gc.collect()#manually trigger garbage collection to free the memory of deleted objects\n", 177 | "\n", 178 | "# (5848*3, 4, 1, x)->(x,5848*3,4,1)\n", 179 | "global_data_hour=global_data_hour.transpose(3,0,1,2)\n", 180 | "#(x,5848*3,4)\n", 
181 | "global_data_hour=global_data_hour.reshape(len(global_data_hour),-1,4)\n", 182 | "print(f\"global_data_hour.shape:{global_data_hour.shape}\")\n", 183 | "\n", 184 | "#hourly temperature at each station (17544, x, 1)->(x,17544,1)\n", 185 | "temp=np.load(path+\"temp.npy\").transpose(1,0,2)\n", 186 | "print(f\"temp.shape:{temp.shape}\")\n", 187 | "#hourly wind speed at each station (17544, x, 1)->(x,17544,1)\n", 188 | "wind=np.load(path+\"wind.npy\").transpose(1,0,2)\n", 189 | "print(f\"wind.shape:{wind.shape}\")" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "id": "31677f9b", 195 | "metadata": { 196 | "papermill": { 197 | "duration": 0.004886, 198 | "end_time": "2024-06-29T03:09:33.238090", 199 | "exception": false, 200 | "start_time": "2024-06-29T03:09:33.233204", 201 | "status": "completed" 202 | }, 203 | "tags": [] 204 | }, 205 | "source": [ 206 | "## 3. Sampling the data.\n", 207 | "\n", 208 | "#### So far we have global_data: (3850,17544,4) and temp and wind: (3850,17544,1). The idea is to use all features of the previous 144 time steps to predict temp and wind at the next time step, so we first concatenate everything into one feature array (3850,17544,6) and then build X, y1 and y2. Because the full dataset is huge, windows are sampled with probability 0.0125 (0.015 ran out of memory when I tried it). Standardization is a necessary preprocessing step for neural networks; train_mean and train_std must also be saved, because the test data has to be transformed the same way at submission time. After processing, X has shape (len(X),144\\*6) and y1, y2 have shape (len(X),1)" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 3, 214 | "id": "5984b296", 215 | "metadata": { 216 | "execution": { 217 | "iopub.execute_input": "2024-06-29T03:09:33.250303Z", 218 | "iopub.status.busy": "2024-06-29T03:09:33.249872Z", 219 | "iopub.status.idle": "2024-06-29T03:11:02.343053Z", 220 | "shell.execute_reply": "2024-06-29T03:11:02.338170Z" 221 | }, 222 | "papermill": { 223 | "duration": 89.139295, 224 | "end_time": "2024-06-29T03:11:02.382451", 225 | "exception": false, 226 | "start_time": "2024-06-29T03:09:33.243156", 227 | "status": "completed" 228 | }, 229 | "tags": [] 230 | }, 231 | "outputs": [ 232 | { 233 | "name": "stdout", 234 | "output_type": "stream", 235 | "text": [ 236 | "train_feats.shape:(3850, 17544, 6)\n", 237 | "X.shape:(836534, 
864),y1.shape:(836534, 1),y2.shape:(836534, 1)\n" 238 | ] 239 | }, 240 | { 241 | "data": { 242 | "text/plain": [ 243 | "0" 244 | ] 245 | }, 246 | "execution_count": 3, 247 | "metadata": {}, 248 | "output_type": "execute_result" 249 | } 250 | ], 251 | "source": [ 252 | "#(x,17544,6): 4 averaged global features + temp + wind\n", 253 | "train_feats=np.concatenate((global_data_hour,temp,wind),axis=-1)\n", 254 | "print(f\"train_feats.shape:{train_feats.shape}\")\n", 255 | "#labels: (x,17544,1),(x,17544,1)\n", 256 | "label1,label2=temp,wind\n", 257 | "\n", 258 | "def get_train_data(train_feats,label1,label2):#(x,17544,6),(x,17544,1),(x,17544,1)\n", 259 | " X,y1,y2=[],[],[]\n", 260 | " #for each station\n", 261 | " for si in range(train_feats.shape[0]):\n", 262 | " for ti in range(train_feats.shape[1]-144):\n", 263 | " if np.random.rand()<0.0125:#random sampling: keep each window with probability 0.0125\n", 264 | " #all features of station si over time steps ti:ti+144\n", 265 | " X.append(train_feats[si][ti:ti+144].reshape(-1))\n", 266 | " y1.append(label1[si][ti+144])\n", 267 | " y2.append(label2[si][ti+144])\n", 268 | " X,y1,y2=np.array(X),np.array(y1),np.array(y2)\n", 269 | " return X,y1,y2\n", 270 | "X,y1,y2=get_train_data(train_feats,label1,label2)\n", 271 | "train_mean=X.mean(axis=0)\n", 272 | "train_std=X.std(axis=0)\n", 273 | "np.save(\"train_mean.npy\",train_mean)\n", 274 | "np.save(\"train_std.npy\",train_std)\n", 275 | "X=(X-train_mean)/train_std\n", 276 | "print(f\"X.shape:{X.shape},y1.shape:{y1.shape},y2.shape:{y2.shape}\")\n", 277 | "del global_data_hour,temp,wind,train_feats\n", 278 | "gc.collect()#manually trigger garbage collection to free the memory of deleted objects" 279 | ] 280 | }, 281 | { 282 | "cell_type": "markdown", 283 | "id": "c2457e5b", 284 | "metadata": { 285 | "papermill": { 286 | "duration": 0.016848, 287 | "end_time": "2024-06-29T03:11:02.421534", 288 | "exception": false, 289 | "start_time": "2024-06-29T03:11:02.404686", 290 | "status": "completed" 291 | }, 292 | "tags": [] 293 | }, 294 | "source": [ 295 | "## 4.BaselineModel\n", 296 | "\n", 297 | "#### A simple CNN is built here as the baseline." 298 | ] 299 | }, 300 | { 301 | 
"cell_type": "code", 302 | "execution_count": 4, 303 | "id": "3c211ad7", 304 | "metadata": { 305 | "execution": { 306 | "iopub.execute_input": "2024-06-29T03:11:02.474862Z", 307 | "iopub.status.busy": "2024-06-29T03:11:02.469753Z", 308 | "iopub.status.idle": "2024-06-29T03:11:02.521214Z", 309 | "shell.execute_reply": "2024-06-29T03:11:02.515412Z" 310 | }, 311 | "papermill": { 312 | "duration": 0.087265, 313 | "end_time": "2024-06-29T03:11:02.530069", 314 | "exception": false, 315 | "start_time": "2024-06-29T03:11:02.442804", 316 | "status": "completed" 317 | }, 318 | "tags": [] 319 | }, 320 | "outputs": [], 321 | "source": [ 322 | "class BaselineModel(nn.Module):\n", 323 | " def __init__(self,):\n", 324 | " super(BaselineModel,self).__init__()\n", 325 | " self.conv=nn.Sequential(\n", 326 | " #1*24*36->16*24*36\n", 327 | " nn.Conv2d(in_channels=1,out_channels=16,kernel_size=3,stride=1,padding=1),\n", 328 | " nn.BatchNorm2d(16),\n", 329 | " #16*24*36->16*12*18\n", 330 | " nn.MaxPool2d(kernel_size=2,stride=2),\n", 331 | " nn.GELU(),\n", 332 | " #16*12*18->32*12*18\n", 333 | " nn.Conv2d(in_channels=16,out_channels=32,kernel_size=3,stride=1,padding=1),\n", 334 | " nn.BatchNorm2d(32),\n", 335 | " #32*12*18->64*12*18\n", 336 | " nn.Conv2d(in_channels=32,out_channels=64,kernel_size=5,stride=1,padding=2),\n", 337 | " nn.BatchNorm2d(64),\n", 338 | " #64*12*18->64*6*9\n", 339 | " nn.MaxPool2d(kernel_size=2,stride=2),\n", 340 | " nn.GELU(),\n", 341 | " #64*6*9->128*6*9\n", 342 | " nn.Conv2d(in_channels=64,out_channels=128,kernel_size=5,stride=1,padding=2),\n", 343 | " nn.BatchNorm2d(128),\n", 344 | " #128*6*9->128*3*4\n", 345 | " nn.MaxPool2d(kernel_size=2,stride=2),\n", 346 | " nn.GELU(),\n", 347 | " )\n", 348 | " self.head=nn.Sequential(\n", 349 | " nn.Linear(128*3*4,128),\n", 350 | " nn.BatchNorm1d(128),\n", 351 | " nn.GELU(),\n", 352 | " nn.Linear(128,256),\n", 353 | " nn.BatchNorm1d(256),\n", 354 | " nn.GELU(),\n", 355 | " nn.Linear(256,1)\n", 356 | " )\n", 357 | " \n", 
358 | " def forward(self,x):\n", 359 | " x=self.conv(x)\n", 360 | " x=x.reshape(x.shape[0],-1)\n", 361 | " return self.head(x)" 362 | ] 363 | }, 364 | { 365 | "cell_type": "markdown", 366 | "id": "7a9711e1", 367 | "metadata": { 368 | "papermill": { 369 | "duration": 0.019414, 370 | "end_time": "2024-06-29T03:11:02.567288", 371 | "exception": false, 372 | "start_time": "2024-06-29T03:11:02.547874", 373 | "status": "completed" 374 | }, 375 | "tags": [] 376 | }, 377 | "source": [ 378 | "## 5. Model training\n", 379 | "\n", 380 | "#### date='0629_2' marks this as the second submission on June 29. The 864 features are reshaped to (1,24,36) only to make the input as close to square as possible for the CNN. Training uses MSE and evaluation uses the official metric; since temp and wind are handled by 2 separate models, the combined final score was not computed. The offline metric is much lower than the leaderboard score and does not match it at all, possibly because the train_test_split validation set leaks information (with shuffle=False it is simply the last 20% of the sampled windows, i.e. roughly the last 20% of stations, covering the same dates as the training set)." 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": 5, 386 | "id": "bec8d8dd", 387 | "metadata": { 388 | "execution": { 389 | "iopub.execute_input": "2024-06-29T03:11:02.605931Z", 390 | "iopub.status.busy": "2024-06-29T03:11:02.605059Z", 391 | "iopub.status.idle": "2024-06-29T09:49:19.548480Z", 392 | "shell.execute_reply": "2024-06-29T09:49:19.543938Z" 393 | }, 394 | "papermill": { 395 | "duration": 23896.967762, 396 | "end_time": "2024-06-29T09:49:19.554212", 397 | "exception": false, 398 | "start_time": "2024-06-29T03:11:02.586450", 399 | "status": "completed" 400 | }, 401 | "tags": [] 402 | }, 403 | "outputs": [ 404 | { 405 | "name": "stdout", 406 | "output_type": "stream", 407 | "text": [ 408 | "epoch:0,name:temp\n", 409 | "train_loss:39.63254928588867\n", 410 | "valid_loss:6.957380294799805,metric:0.042841896378927866\n", 411 | "epoch:1,name:temp\n", 412 | "train_loss:6.404951572418213\n", 413 | "valid_loss:4.364864349365234,metric:0.026876498608155396\n", 414 | "epoch:2,name:temp\n", 415 | "train_loss:3.9735267162323\n", 416 | "valid_loss:3.3389687538146973,metric:0.020568538182798583\n", 417 | "epoch:3,name:temp\n", 418 | "train_loss:3.416158437728882\n", 419 | "valid_loss:3.1424412727355957,metric:0.019356603391767334\n", 420 | 
"epoch:4,name:temp\n", 421 | "train_loss:3.0788538455963135\n", 422 | "valid_loss:2.7104780673980713,metric:0.01669682685350834\n", 423 | "epoch:5,name:temp\n", 424 | "train_loss:2.764535903930664\n", 425 | "valid_loss:3.2275211811065674,metric:0.0198797962722501\n", 426 | "epoch:6,name:temp\n", 427 | "train_loss:2.61907696723938\n", 428 | "valid_loss:3.0715019702911377,metric:0.018926934101076896\n", 429 | "epoch:7,name:temp\n", 430 | "train_loss:2.4271206855773926\n", 431 | "valid_loss:2.6960549354553223,metric:0.016605464776724428\n", 432 | "epoch:8,name:temp\n", 433 | "train_loss:2.1649467945098877\n", 434 | "valid_loss:2.8013415336608887,metric:0.017257515111020413\n", 435 | "epoch:9,name:temp\n", 436 | "train_loss:2.076241970062256\n", 437 | "valid_loss:2.7667369842529297,metric:0.01704180363123001\n", 438 | "epoch:0,name:wind\n", 439 | "train_loss:2.634408950805664\n", 440 | "valid_loss:1.8498305082321167,metric:0.293155302079099\n", 441 | "epoch:1,name:wind\n", 442 | "train_loss:1.574277400970459\n", 443 | "valid_loss:1.508863091468811,metric:0.23910268708680157\n", 444 | "epoch:2,name:wind\n", 445 | "train_loss:1.4538973569869995\n", 446 | "valid_loss:1.4260954856872559,metric:0.22600743773084295\n", 447 | "epoch:3,name:wind\n", 448 | "train_loss:1.382198452949524\n", 449 | "valid_loss:1.442484736442566,metric:0.22861156092489276\n", 450 | "epoch:4,name:wind\n", 451 | "train_loss:1.3340543508529663\n", 452 | "valid_loss:1.3945986032485962,metric:0.22101595006495284\n", 453 | "epoch:5,name:wind\n", 454 | "train_loss:1.3144776821136475\n", 455 | "valid_loss:1.4613676071166992,metric:0.23158700964706677\n", 456 | "epoch:6,name:wind\n", 457 | "train_loss:1.3275376558303833\n", 458 | "valid_loss:1.4318352937698364,metric:0.22690022686209754\n", 459 | "epoch:7,name:wind\n", 460 | "train_loss:1.2891463041305542\n", 461 | "valid_loss:1.4204884767532349,metric:0.22511097144144243\n", 462 | "epoch:8,name:wind\n", 463 | "train_loss:1.273412823677063\n", 464 | 
"valid_loss:1.396796464920044,metric:0.22136685331401934\n", 465 | "epoch:9,name:wind\n", 466 | "train_loss:1.2555644512176514\n", 467 | "valid_loss:1.3794482946395874,metric:0.2186393584855981\n" 468 | ] 469 | } 470 | ], 471 | "source": [ 472 | "date='0629_2'\n", 473 | "def loss_fn(y_true,y_pred):#torch.tensor\n", 474 | " return torch.mean((y_true-y_pred)**2)\n", 475 | "def metric(y_true,y_pred):#np.array\n", 476 | " return np.mean((y_true-y_pred)**2)/np.var(y_true)\n", 477 | "\n", 478 | "def train(X,y,batch_size=1024,num_epochs=5,name='wind'):#inputs are np.array; name is 'wind' or 'temp'\n", 479 | " train_X, valid_X, train_y, valid_y = train_test_split(X, y, test_size=0.2, random_state=2024,shuffle=False)\n", 480 | " #model\n", 481 | " model=BaselineModel()\n", 482 | " #optimizer\n", 483 | " optimizer=optim.Adam(model.parameters(),lr=0.000025,betas=(0.5,0.999))\n", 484 | " for epoch in range(num_epochs):\n", 485 | " print(f\"epoch:{epoch},name:{name}\")\n", 486 | " #put the model in training mode\n", 487 | " model.train()\n", 488 | " #clear gradients\n", 489 | " optimizer.zero_grad()\n", 490 | " #shuffle the training samples before every epoch\n", 491 | " random_index=np.arange(len(train_X))\n", 492 | " np.random.shuffle(random_index)\n", 493 | " train_X,train_y=train_X[random_index],train_y[random_index]\n", 494 | " train_loss=0.0\n", 495 | " for idx in range(0,len(train_X),batch_size):\n", 496 | " train_X1=torch.Tensor(train_X[idx:idx+batch_size]).reshape(-1,1,24,36)\n", 497 | " train_y1=torch.Tensor(train_y[idx:idx+batch_size])\n", 498 | " train_pred=model(train_X1)\n", 499 | " loss=loss_fn(train_y1,train_pred)\n", 500 | " optimizer.zero_grad()#clear the previous batch's gradients so they do not accumulate\n", 501 | " loss.backward()#backpropagation\n", 502 | " #optimizer step (gradient descent to reduce the loss)\n", 503 | " optimizer.step()\n", 504 | " train_loss+=loss.item()#.item() keeps the running total a plain float instead of a tensor attached to the autograd graph\n", 505 | " print(f\"train_loss:{train_loss/(len(train_X)//batch_size)}\")\n", 506 | " #put the model in evaluation mode\n", 507 | " model.eval()\n", 508 | " with torch.no_grad():\n", 509 | " valid_loss=0.00\n", 510 | " valid_preds=np.zeros(len(valid_y))\n", 511 | " for idx in range(0,len(valid_X),batch_size):\n", 512 | " 
valid_X1=torch.Tensor(valid_X[idx:idx+batch_size]).reshape(-1,1,24,36)\n", 513 | " valid_y1=torch.Tensor(valid_y[idx:idx+batch_size])\n", 514 | " valid_pred=model(valid_X1)\n", 515 | " loss=loss_fn(valid_y1,valid_pred)\n", 516 | " valid_loss+=loss\n", 517 | " valid_preds[idx:idx+batch_size]=valid_pred.detach().numpy().reshape(-1)\n", 518 | " print(f\"valid_loss:{valid_loss/(len(valid_X)//batch_size)},metric:{metric(valid_y.reshape(-1),valid_preds)}\")\n", 519 | " torch.cuda.empty_cache()\n", 520 | " torch.save(model.state_dict(),f\"{date}{name}.pth\")\n", 521 | "train(X,y1,batch_size=128,num_epochs=10,name='temp')\n", 522 | "train(X,y2,batch_size=128,num_epochs=10,name='wind')" 523 | ] 524 | }, 525 | { 526 | "cell_type": "markdown", 527 | "id": "5ea200c4", 528 | "metadata": { 529 | "papermill": { 530 | "duration": 0.010082, 531 | "end_time": "2024-06-29T09:49:19.574455", 532 | "exception": false, 533 | "start_time": "2024-06-29T09:49:19.564373", 534 | "status": "completed" 535 | }, 536 | "tags": [] 537 | }, 538 | "source": [ 539 | "## 6. Prediction\n", 540 | "\n", 541 | "\n", 542 | "#### This cell runs the pipeline offline on random data with the same shape as the test data, to check that the model runs end to end; the same code can also go into the submitted index.py file.\n", 543 | "\n", 544 | "#### Because I kept changing the code, some comments may describe an earlier version; focus on the logic rather than on mistakes in the comments.\n" 545 | ] 546 | }, 547 | { 548 | "cell_type": "code", 549 | "execution_count": 6, 550 | "id": "b2fdbb11", 551 | "metadata": { 552 | "execution": { 553 | "iopub.execute_input": "2024-06-29T09:49:19.606817Z", 554 | "iopub.status.busy": "2024-06-29T09:49:19.603948Z", 555 | "iopub.status.idle": "2024-06-29T09:51:32.859984Z", 556 | "shell.execute_reply": "2024-06-29T09:51:32.858690Z" 557 | }, 558 | "papermill": { 559 | "duration": 133.286768, 560 | "end_time": "2024-06-29T09:51:32.873149", 561 | "exception": false, 562 | "start_time": "2024-06-29T09:49:19.586381", 563 | "status": "completed" 564 | }, 565 | "tags": [] 566 | }, 567 | "outputs": [ 568 | { 569 | "name": "stdout", 570 | "output_type": "stream", 571 | "text": [ 572 | "cenn_data.shape:(71, 
56, 4, 1, 60)\n", 573 | "temp_lookback.shape:(71, 168, 60, 1)\n", 574 | "wind_lookback.shape:(71, 168, 60, 1)\n", 575 | "input.shape:(102240, 864)\n", 576 | "input.shape:(102240, 864)\n" 577 | ] 578 | }, 579 | { 580 | "data": { 581 | "text/plain": [ 582 | "((71, 24, 60, 1), (71, 24, 60, 1))" 583 | ] 584 | }, 585 | "execution_count": 6, 586 | "metadata": {}, 587 | "output_type": "execute_result" 588 | } 589 | ], 590 | "source": [ 591 | "#test data (random placeholders with the real test shapes)\n", 592 | "x=60#60 observation stations\n", 593 | "#71 non-consecutive weeks, 56 time steps (one every 3 hours), 4 features, 9 grid positions, x stations\n", 594 | "cenn_data=np.random.randn(71,56,4,9,x).mean(axis=-2,keepdims=True)#in the real submission this is loaded with np.load\n", 595 | "print(f\"cenn_data.shape:{cenn_data.shape}\")\n", 596 | "#turn 3-hourly into hourly (71,168,4,1,x)\n", 597 | "cenn_data_hour=np.repeat(cenn_data, 3, axis=1)\n", 598 | "cenn_data_hour=cenn_data_hour.transpose(0,4,1,2,3)#71,x,168,4,1\n", 599 | "cenn_data_hour=cenn_data_hour.reshape(71*x,168,4)\n", 600 | "\n", 601 | "\n", 602 | "#cenn/temp_lookback.npy: 71 non-consecutive weeks, hourly, temperature of the previous week at x stations\n", 603 | "temp_lookback=np.random.randn(71,168,x,1)\n", 604 | "print(f\"temp_lookback.shape:{temp_lookback.shape}\")\n", 605 | "temp_lookback=temp_lookback.transpose(0,2,1,3)#71,x,168,1\n", 606 | "temp_lookback=temp_lookback.reshape(71*x,168,1)\n", 607 | "#cenn/wind_lookback.npy: 71 non-consecutive weeks, hourly, wind speed of the previous week at x stations\n", 608 | "wind_lookback=np.random.randn(71,168,x,1)\n", 609 | "print(f\"wind_lookback.shape:{wind_lookback.shape}\")\n", 610 | "wind_lookback=wind_lookback.transpose(0,2,1,3)#71,x,168,1\n", 611 | "wind_lookback=wind_lookback.reshape(71*x,168,1)\n", 612 | "\n", 613 | "#71*x stations, 168 hours, 6 features\n", 614 | "total_feats=np.concatenate((cenn_data_hour,temp_lookback,wind_lookback),axis=-1)\n", 615 | "\n", 616 | "def predict(feats,model_name='temp.pth',batch_size=128):\n", 617 | " model = BaselineModel().eval()#eval mode: BatchNorm layers must use their running statistics at inference\n", 618 | " model.load_state_dict(torch.load(model_name))\n", 619 | " #hour i (i=0..23) uses the 144-hour window i:i-24 (i.e. i:i+144) of all 6 features for the 71*x stations\n", 620 | " data=[]\n", 621 | " for i in range(24):\n", 622 | " 
data.append(feats[:,i:i-24,].reshape(feats.shape[0],-1))\n", 623 | " data=np.array(data)#(24 hours, 71*x stations, 144 time steps*6 features)\n", 624 | " #(24 hours*71*x stations, 144 time steps*6 features)\n", 625 | " data=data.reshape(-1,data.shape[-1])\n", 626 | " print(f\"input.shape:{data.shape}\")\n", 627 | " data=(data-train_mean)/train_std\n", 628 | " #test_preds: one prediction per (hour, station) pair, 24*71*x in total\n", 629 | " test_preds=np.zeros(len(data))\n", 630 | " for idx in range(0,len(data),batch_size):\n", 631 | " data1=torch.Tensor(data[idx:idx+batch_size]).reshape(-1,1,24,36)\n", 632 | " test_preds[idx:idx+batch_size]=model(data1).detach().numpy().reshape(-1)\n", 633 | " test_preds=test_preds.reshape(24,71,x,-1)\n", 634 | " #predictions for 71 weeks, 24 hours, x stations\n", 635 | " test_preds=test_preds.transpose(1,0,2,3)\n", 636 | " return test_preds\n", 637 | "test_preds1=predict(feats=total_feats,model_name=f'{date}temp.pth',batch_size=128)\n", 638 | "test_preds2=predict(feats=total_feats,model_name=f'{date}wind.pth',batch_size=128)\n", 639 | "test_preds1.shape,test_preds2.shape" 640 | ] 641 | } 642 | ], 643 | "metadata": { 644 | "kaggle": { 645 | "accelerator": "none", 646 | "dataSources": [ 647 | { 648 | "datasetId": 5226298, 649 | "sourceId": 8711844, 650 | "sourceType": "datasetVersion" 651 | } 652 | ], 653 | "dockerImageVersionId": 30732, 654 | "isGpuEnabled": false, 655 | "isInternetEnabled": true, 656 | "language": "python", 657 | "sourceType": "notebook" 658 | }, 659 | "kernelspec": { 660 | "display_name": "Python 3", 661 | "language": "python", 662 | "name": "python3" 663 | }, 664 | "language_info": { 665 | "codemirror_mode": { 666 | "name": "ipython", 667 | "version": 3 668 | }, 669 | "file_extension": ".py", 670 | "mimetype": "text/x-python", 671 | "name": "python", 672 | "nbconvert_exporter": "python", 673 | "pygments_lexer": "ipython3", 674 | "version": "3.10.13" 675 | }, 676 | "papermill": { 677 | "default_parameters": {}, 678 | "duration": 24192.07278, 679 | "end_time": "2024-06-29T09:51:37.464462", 680 | 
"environment_variables": {}, 681 | "exception": null, 682 | "input_path": "__notebook__.ipynb", 683 | "output_path": "__notebook__.ipynb", 684 | "parameters": {}, 685 | "start_time": "2024-06-29T03:08:25.391682", 686 | "version": "2.5.0" 687 | } 688 | }, 689 | "nbformat": 4, 690 | "nbformat_minor": 5 691 | } 692 | -------------------------------------------------------------------------------- /202407Kagglejtseptop2solutionstudy/jtsep-top2-solution-study.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "2164bca6", 6 | "metadata": { 7 | "papermill": { 8 | "duration": 0.005926, 9 | "end_time": "2024-07-22T03:07:34.667554", 10 | "exception": false, 11 | "start_time": "2024-07-22T03:07:34.661628", 12 | "status": "completed" 13 | }, 14 | "tags": [] 15 | }, 16 | "source": [ 17 | "## Created by yunsuxiaozi 2024/7/22\n", 18 | "\n", 19 | "#### Competition link: JPX Tokyo Stock Exchange Prediction. This solution is a bit different from the ones I usually analyse: after reading the code, I found the discussion post About the 2nd place solution, which points out that the code contains a bug and that the score drops once the bug is fixed. In other words, this code reached 2nd place by luck, thanks to the bug; the 1st-place solution reportedly has a bug too. Still, I had already finished reading the code and it would be a waste not to share it; this is also my first stock-prediction competition, and the code remains a useful reference for me, so I have organized it below." 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "id": "4d78a97e", 25 | "metadata": { 26 | "papermill": { 27 | "duration": 0.004465, 28 | "end_time": "2024-07-22T03:07:34.676748", 29 | "exception": false, 30 | "start_time": "2024-07-22T03:07:34.672283", 31 | "status": "completed" 32 | }, 33 | "tags": [] 34 | }, 35 | "source": [ 36 | "### 1. Import the required Python libraries and fix the random seeds." 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 1, 42 | "id": "e48eb1bb", 43 | "metadata": { 44 | "execution": { 45 | "iopub.execute_input": "2024-07-22T03:07:34.689478Z", 46 | "iopub.status.busy": "2024-07-22T03:07:34.688042Z", 47 | "iopub.status.idle": "2024-07-22T03:07:37.014758Z", 48 | "shell.execute_reply": "2024-07-22T03:07:37.013532Z" 49 | }, 50 | "papermill": { 51 | "duration": 2.3368, 52 | "end_time": 
"2024-07-22T03:07:37.018112", 53 | "exception": false, 54 | "start_time": "2024-07-22T03:07:34.681312", 55 | "status": "completed" 56 | }, 57 | "tags": [] 58 | }, 59 | "outputs": [ 60 | { 61 | "data": { 62 | "text/html": [ 63 | "\n" 84 | ], 85 | "text/plain": [ 86 | "" 87 | ] 88 | }, 89 | "metadata": {}, 90 | "output_type": "display_data" 91 | } 92 | ], 93 | "source": [ 94 | "import pandas as pd#loading CSV files\n", 95 | "import numpy as np#matrix operations and scientific computing\n", 96 | "from scipy import stats#statistical analysis\n", 97 | "import lightgbm as lgb#LightGBM model\n", 98 | "import jpx_tokyo_market_prediction#the official competition environment\n", 99 | "import warnings#used to silence ignorable warnings\n", 100 | "warnings.filterwarnings('ignore')#filterwarnings() sets a warning filter; 'ignore' suppresses all warnings\n", 101 | "\n", 102 | "import random#functions for generating random numbers\n", 103 | "#set random seeds so the model is reproducible\n", 104 | "def seed_everything(seed):\n", 105 | " np.random.seed(seed)#NumPy seed\n", 106 | " random.seed(seed)#Python built-in seed\n", 107 | "seed_everything(seed=2024)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "id": "326c6959", 113 | "metadata": { 114 | "papermill": { 115 | "duration": 0.004715, 116 | "end_time": "2024-07-22T03:07:37.028290", 117 | "exception": false, 118 | "start_time": "2024-07-22T03:07:37.023575", 119 | "status": "completed" 120 | }, 121 | "tags": [] 122 | }, 123 | "source": [ 124 | "### 2. Read the dataset." 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 2, 130 | "id": "df5a70ab", 131 | "metadata": { 132 | "execution": { 133 | "iopub.execute_input": "2024-07-22T03:07:37.040118Z", 134 | "iopub.status.busy": "2024-07-22T03:07:37.039674Z", 135 | "iopub.status.idle": "2024-07-22T03:07:46.299629Z", 136 | "shell.execute_reply": "2024-07-22T03:07:46.298072Z" 137 | }, 138 | "papermill": { 139 | "duration": 9.269484, 140 | "end_time": "2024-07-22T03:07:46.302743", 141 | "exception": false, 142 | "start_time": "2024-07-22T03:07:37.033259", 143 | "status": "completed" 144 | }, 145 | "tags": [] 146 | }, 147 | "outputs": [ 
148 | { 149 | "name": "stdout", 150 | "output_type": "stream", 151 | "text": [ 152 | "len(train):2324923\n" 153 | ] 154 | }, 155 | { 156 | "data": { 157 | "text/html": [ 158 | "
" 245 | ], 246 | "text/plain": [ 247 | " Date SecuritiesCode Open High Low Close Volume \\\n", 248 | "0 2017-01-04 1301 2734.0 2755.0 2730.0 2742.0 31400 \n", 249 | "1 2017-01-04 1332 568.0 576.0 563.0 571.0 2798500 \n", 250 | "2 2017-01-04 1333 3150.0 3210.0 3140.0 3210.0 270800 \n", 251 | "3 2017-01-04 1376 1510.0 1550.0 1510.0 1550.0 11300 \n", 252 | "4 2017-01-04 1377 3270.0 3350.0 3270.0 3330.0 150800 \n", 253 | "\n", 254 | " Target \n", 255 | "0 0.000730 \n", 256 | "1 0.012324 \n", 257 | "2 0.006154 \n", 258 | "3 0.011053 \n", 259 | "4 0.003026 " 260 | ] 261 | }, 262 | "execution_count": 2, 263 | "metadata": {}, 264 | "output_type": "execute_result" 265 | } 266 | ], 267 | "source": [ 268 | "#parse the \"Date\" column of the CSV file as datetime\n", 269 | "train = pd.read_csv(\"/kaggle/input/jpx-tokyo-stock-exchange-prediction/train_files/stock_prices.csv\",parse_dates=[\"Date\"])\n", 270 | "#RowId is just date + securities code, so drop the redundant column\n", 271 | "#ExpectedDividend: 99% of values are missing\n", 272 | "#AdjustmentFactor equals 1 in 99.96% of rows: the column is almost constant\n", 273 | "#'SupervisionFlag' is False in 99.97% of rows: the column is almost constant\n", 274 | "#dropna removes rows with missing values (linear interpolation might work better; not tested)\n", 275 | "train=train.drop(columns=['RowId','ExpectedDividend','AdjustmentFactor','SupervisionFlag']).dropna().reset_index(drop=True)\n", 276 | "print(f\"len(train):{len(train)}\")\n", 277 | "#apply the same preprocessing to the test data; every test row must be predicted, so no dropna here.\n", 278 | "test = pd.read_csv(\"/kaggle/input/jpx-tokyo-stock-exchange-prediction/supplemental_files/secondary_stock_prices.csv\",parse_dates=[\"Date\"])\n", 279 | "test=test.drop(columns=['RowId','ExpectedDividend','AdjustmentFactor','SupervisionFlag'])\n", 280 | "\n", 281 | "train.head()" 282 | ] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "id": "e3bd8804", 287 | "metadata": { 288 | "papermill": { 289 | "duration": 0.005583, 290 | "end_time": "2024-07-22T03:07:46.314061", 291 | "exception": false, 292 | "start_time": "2024-07-22T03:07:46.308478", 293 | "status": "completed" 294 | }, 295 | "tags": [] 296 | }, 297 | "source": [ 298 | "### 
3. Feature engineering. This is where the bug is: the lag and rolling features should be computed per stock with groupby('SecuritiesCode'), but here they are computed over fixed row offsets (e.g. 20 rows) across the whole table, which mixes different stocks." 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 3, 304 | "id": "6a3badd4", 305 | "metadata": { 306 | "execution": { 307 | "iopub.execute_input": "2024-07-22T03:07:46.326917Z", 308 | "iopub.status.busy": "2024-07-22T03:07:46.326469Z", 309 | "iopub.status.idle": "2024-07-22T03:08:10.159536Z", 310 | "shell.execute_reply": "2024-07-22T03:08:10.158257Z" 311 | }, 312 | "papermill": { 313 | "duration": 23.843066, 314 | "end_time": "2024-07-22T03:08:10.162700", 315 | "exception": false, 316 | "start_time": "2024-07-22T03:07:46.319634", 317 | "status": "completed" 318 | }, 319 | "tags": [] 320 | }, 321 | "outputs": [ 322 | { 323 | "data": { 324 | "text/html": [ 325 | "
" 466 | ], 467 | "text/plain": [ 468 | " Date SecuritiesCode Open High Low Close Volume \\\n", 469 | "0 2017-01-04 1301 2734.0 2755.0 2730.0 2742.0 31400 \n", 470 | "1 2017-01-04 1332 568.0 576.0 563.0 571.0 2798500 \n", 471 | "2 2017-01-04 1333 3150.0 3210.0 3140.0 3210.0 270800 \n", 472 | "3 2017-01-04 1376 1510.0 1550.0 1510.0 1550.0 11300 \n", 473 | "4 2017-01-04 1377 3270.0 3350.0 3270.0 3330.0 150800 \n", 474 | "\n", 475 | " Target return_1month return_2month return_3month volatility_1month \\\n", 476 | "0 0.000730 0.0 0.0 0.0 0.0 \n", 477 | "1 0.012324 0.0 0.0 0.0 0.0 \n", 478 | "2 0.006154 0.0 0.0 0.0 0.0 \n", 479 | "3 0.011053 0.0 0.0 0.0 0.0 \n", 480 | "4 0.003026 0.0 0.0 0.0 0.0 \n", 481 | "\n", 482 | " volatility_2month volatility_3month MA_gap_1month MA_gap_2month \\\n", 483 | "0 0.0 0.0 0.0 0.0 \n", 484 | "1 0.0 0.0 0.0 0.0 \n", 485 | "2 0.0 0.0 0.0 0.0 \n", 486 | "3 0.0 0.0 0.0 0.0 \n", 487 | "4 0.0 0.0 0.0 0.0 \n", 488 | "\n", 489 | " MA_gap_3month \n", 490 | "0 0.0 \n", 491 | "1 0.0 \n", 492 | "2 0.0 \n", 493 | "3 0.0 \n", 494 | "4 0.0 " 495 | ] 496 | }, 497 | "execution_count": 3, 498 | "metadata": {}, 499 | "output_type": "execute_result" 500 | } 501 | ], 502 | "source": [ 503 | "def add_features(feats):\n", 504 | " #Return over 1, 2 and 3 months (20, 40 and 60 rows): pct_change(n) = (v_t-v_{t-n})/v_{t-n}\n", 505 | " feats[\"return_1month\"] = feats[\"Close\"].pct_change(20)\n", 506 | " feats[\"return_2month\"] = feats[\"Close\"].pct_change(40)\n", 507 | " feats[\"return_3month\"] = feats[\"Close\"].pct_change(60)\n", 508 | " \n", 509 | " #Volatility: the log is probably taken because of the long-tailed distribution; diff takes differences, rolling looks at a window of rows, and std is the standard deviation\n", 510 | " feats[\"volatility_1month\"] = (\n", 511 | " np.log(feats[\"Close\"]).diff().rolling(20).std()\n", 512 | " )\n", 513 | " feats[\"volatility_2month\"] = (\n", 514 | " np.log(feats[\"Close\"]).diff().rolling(40).std()\n", 515 | " )\n", 516 | " 
feats[\"volatility_3month\"] = (\n", 517 | " np.log(feats[\"Close\"]).diff().rolling(60).std()\n", 518 | " )\n", 519 | " \n", 520 | " "
#Close price divided by its moving average over 1, 2 and 3 months (20, 40 and 60 rows)\n", 521 | " feats[\"MA_gap_1month\"] = feats[\"Close\"] / (\n", 522 | " feats[\"Close\"].rolling(20).mean()\n", 523 | " )\n", 524 | " feats[\"MA_gap_2month\"] = feats[\"Close\"] / (\n", 525 | " feats[\"Close\"].rolling(40).mean()\n", 526 | " )\n", 527 | " feats[\"MA_gap_3month\"] = feats[\"Close\"] / (\n", 528 | " feats[\"Close\"].rolling(60).mean()\n", 529 | " )\n", 530 | " \n", 531 | " return feats\n", 532 | "\n", 533 | "#Replace missing values, np.inf and -np.inf with 0\n", 534 | "def fill_nan_inf(df):\n", 535 | " df = df.fillna(0)\n", 536 | " df = df.replace([np.inf, -np.inf], 0)\n", 537 | " return df\n", 538 | "train = add_features(train)\n", 539 | "train=fill_nan_inf(train)\n", 540 | "test = add_features(test)\n", 541 | "test=fill_nan_inf(test)\n", 542 | "train.head()" 543 | ] 544 | }, 545 | { 546 | "cell_type": "markdown", 547 | "id": "63ead625", 548 | "metadata": { 549 | "papermill": { 550 | "duration": 0.005445, 551 | "end_time": "2024-07-22T03:08:10.174276", 552 | "exception": false, 553 | "start_time": "2024-07-22T03:08:10.168831", 554 | "status": "completed" 555 | }, 556 | "tags": [] 557 | }, 558 | "source": [ 559 | "### 4. Build the training and validation sets, with 1000 stocks in each." 560 | ] 561 | }, 562 | { 563 | "cell_type": "code", 564 | "execution_count": 4, 565 | "id": "18ceadc9", 566 | "metadata": { 567 | "execution": { 568 | "iopub.execute_input": "2024-07-22T03:08:10.188678Z", 569 | "iopub.status.busy": "2024-07-22T03:08:10.188163Z", 570 | "iopub.status.idle": "2024-07-22T03:08:11.019641Z", 571 | "shell.execute_reply": "2024-07-22T03:08:11.018413Z" 572 | }, 573 | "papermill": { 574 | "duration": 0.842294, 575 | "end_time": "2024-07-22T03:08:11.022846", 576 | "exception": false, 577 | "start_time": "2024-07-22T03:08:10.180552", 578 | "status": "completed" 579 | }, 580 | "tags": [] 581 | }, 582 | "outputs": [ 583 | { 584 | "name": "stdout", 585 | "output_type": "stream", 586 | 
"text": [ 587 | "len(train_securitiescode):1000,len(valid_securitiescode):1000\n" 588 | ] 589 | } 590 | ], 591
| "source": [ 592 | "#Target is the label, Date is a string, and SecuritiesCode is not useful as a feature\n", 593 | "features =list(train.drop(['Target','Date','SecuritiesCode'],axis=1).columns) \n", 594 | "\n", 595 | "#For each SecuritiesCode, compute max_target-min_target and sort in ascending order\n", 596 | "group=(train.groupby('SecuritiesCode')['Target'].max()-train.groupby('SecuritiesCode')['Target'].min()).sort_values()\n", 597 | "#The 1000 stocks with the smallest max_target-min_target (training set)\n", 598 | "list_spred_h=list(group[:1000].index)\n", 599 | "#The remaining stocks (validation set)\n", 600 | "list_spred_l=list(group[1000:].index)\n", 601 | "print(f\"len(train_securitiescode):{len(list_spred_h)},len(valid_securitiescode):{len(list_spred_l)}\")\n", 602 | "\n", 603 | "train_X=train[train['SecuritiesCode'].isin(list_spred_h)][features]\n", 604 | "train_y=train[train['SecuritiesCode'].isin(list_spred_h)][\"Target\"]\n", 605 | "valid_X=train[train['SecuritiesCode'].isin(list_spred_l)][features]\n", 606 | "valid_y=train[train['SecuritiesCode'].isin(list_spred_l)][\"Target\"]\n", 607 | "\n", 608 | "tr_dataset = lgb.Dataset(train_X,train_y,feature_name = features)\n", 609 | "vl_dataset = lgb.Dataset(valid_X,valid_y,feature_name = features)" 610 | ] 611 | }, 612 | { 613 | "cell_type": "markdown", 614 | "id": "8364e2dd", 615 | "metadata": { 616 | "papermill": { 617 | "duration": 0.006047, 618 | "end_time": "2024-07-22T03:08:11.034958", 619 | "exception": false, 620 | "start_time": "2024-07-22T03:08:11.028911", 621 | "status": "completed" 622 | }, 623 | "tags": [] 624 | }, 625 | "source": [ 626 | "### 5. Model training." 627 | ] 628 | }, 629 | { 630 | "cell_type": "code", 631 | "execution_count": 5, 632 | "id": "ee6ec113", 633 | "metadata": { 634 | "execution": { 635 | "iopub.execute_input": "2024-07-22T03:08:11.049176Z", 636 | "iopub.status.busy": "2024-07-22T03:08:11.048723Z", 637 | "iopub.status.idle": "2024-07-22T03:13:13.092347Z", 638 | "shell.execute_reply": "2024-07-22T03:13:13.090990Z" 639 | }, 640 | "papermill": { 641 | "duration": 302.055117, 642 | "end_time": "2024-07-22T03:13:13.096000", 643 | "exception": false,
644 | "start_time": "2024-07-22T03:08:11.040883", 645 | "status": "completed" 646 | }, 647 | "tags": [] 648 | }, 649 | "outputs": [ 650 | { 651 | "name": "stdout", 652 | "output_type": "stream", 653 | "text": [ 654 | "Training until validation scores don't improve for 300 rounds\n", 655 | "[100]\ttraining's pearsonr: 0.0564282\tvalid_1's pearsonr: 0.0108009\n", 656 | "[200]\ttraining's pearsonr: 0.0680563\tvalid_1's pearsonr: 0.0134107\n", 657 | "[300]\ttraining's pearsonr: 0.0761529\tvalid_1's pearsonr: 0.0142165\n", 658 | "[400]\ttraining's pearsonr: 0.082453\tvalid_1's pearsonr: 0.0146069\n", 659 | "[500]\ttraining's pearsonr: 0.0883774\tvalid_1's pearsonr: 0.0147148\n", 660 | "[600]\ttraining's pearsonr: 0.0938508\tvalid_1's pearsonr: 0.0148582\n", 661 | "[700]\ttraining's pearsonr: 0.0986576\tvalid_1's pearsonr: 0.0148196\n", 662 | "[800]\ttraining's pearsonr: 0.103034\tvalid_1's pearsonr: 0.0146873\n", 663 | "[900]\ttraining's pearsonr: 0.106989\tvalid_1's pearsonr: 0.0146611\n", 664 | "Early stopping, best iteration is:\n", 665 | "[606]\ttraining's pearsonr: 0.0942029\tvalid_1's pearsonr: 0.0148682\n" 666 | ] 667 | } 668 | ], 669 | "source": [ 670 | "def feval_pearsonr(y_pred, lgb_train):\n", 671 | " y_true = lgb_train.get_label()\n", 672 | " return 'pearsonr', stats.pearsonr(y_true, y_pred)[0], True\n", 673 | "\n", 674 | "#LightGBM model parameters\n", 675 | "params_lgb = {'learning_rate': 0.005,\n", 676 | " 'metric':'None',\n", 677 | " 'objective': 'regression',\n", 678 | " 'boosting': 'gbdt',\n", 679 | " 'verbosity': 0,\n", 680 | " 'n_jobs': -1,\n", 681 | " 'force_col_wise':True}\n", 682 | "\n", 683 | "model = lgb.train(params = params_lgb, \n", 684 | " train_set = tr_dataset, \n", 685 | " valid_sets = [tr_dataset, vl_dataset], \n", 686 | " num_boost_round = 1000, \n", 687 | " feval=feval_pearsonr,\n", 688 | " callbacks=[ lgb.early_stopping(stopping_rounds=300, verbose=True), \n", 689 | " lgb.log_evaluation(period=100)]) " 690 | ] 691 | }, 692 | { 693 | "cell_type":
"markdown", 694 | "id": "757552a8", 695 | "metadata": { 696 | "papermill": { 697 | "duration": 0.007233, 698 | "end_time": "2024-07-22T03:13:13.110997", 699 | "exception": false, 700 | "start_time": "2024-07-22T03:13:13.103764", 701 | "status": "completed" 702 | }, 703 | "tags": [] 704 | }, 705 | "source": [ 706 | "### 6. Inference. The predicted values have to be converted into ranks." 707 | ] 708 | }, 709 | { 710 | "cell_type": "code", 711 | "execution_count": 6, 712 | "id": "2319a94e", 713 | "metadata": { 714 | "execution": { 715 | "iopub.execute_input": "2024-07-22T03:13:13.127150Z", 716 | "iopub.status.busy": "2024-07-22T03:13:13.126684Z", 717 | "iopub.status.idle": "2024-07-22T03:13:13.629007Z", 718 | "shell.execute_reply": "2024-07-22T03:13:13.627724Z" 719 | }, 720 | "papermill": { 721 | "duration": 0.513646, 722 | "end_time": "2024-07-22T03:13:13.631831", 723 | "exception": false, 724 | "start_time": "2024-07-22T03:13:13.118185", 725 | "status": "completed" 726 | }, 727 | "tags": [] 728 | }, 729 | "outputs": [ 730 | { 731 | "name": "stdout", 732 | "output_type": "stream", 733 | "text": [ 734 | "This version of the API is not optimized and should not be used to estimate the runtime of your code on the hidden test set.\n" 735 | ] 736 | }, 737 | { 738 | "data": { 739 | "text/html": [ 740 | "
\n", 796 | "
" 797 | ], 798 | "text/plain": [ 799 | " Date SecuritiesCode Rank\n", 800 | "0 2021-12-07 1301 497\n", 801 | "1 2021-12-07 1332 498\n", 802 | "2 2021-12-07 1333 499\n", 803 | "3 2021-12-07 1375 500\n", 804 | "4 2021-12-07 1376 1916" 805 | ] 806 | }, 807 | "execution_count": 6, 808 | "metadata": {}, 809 | "output_type": "execute_result" 810 | } 811 | ], 812 | "source": [ 813 | "sample_submission = pd.read_csv(\"/kaggle/input/jpx-tokyo-stock-exchange-prediction/example_test_files/sample_submission.csv\")\n", 814 | "#Create the competition environment\n", 815 | "env = jpx_tokyo_market_prediction.make_env()\n", 816 | "#Iterator over the test data\n", 817 | "iter_test = env.iter_test()\n", 818 | "\n", 819 | "def add_rank(df):\n", 820 | " df[\"Rank\"] = df.groupby(\"Date\")[\"Target\"].rank(ascending=False, method=\"first\") - 1 \n", 821 | " df[\"Rank\"] = df[\"Rank\"].astype(\"int\")\n", 822 | " return df\n", 823 | "\n", 824 | "\n", 825 | "for (prices, options, financials, trades, secondary_prices, sample_prediction) in iter_test: \n", 826 | " prices = add_features(prices)\n", 827 | " prices['Target'] = model.predict(fill_nan_inf(prices)[features])\n", 828 | " prices['target_median']=prices.groupby(\"Date\")[\"Target\"].transform('median')\n", 829 | " prices.loc[prices['SecuritiesCode'].isin(list_spred_h),'Target']=prices['target_median']\n", 830 | " prices = add_rank(prices)\n", 831 | " sample_prediction['Rank'] = prices['Rank']\n", 832 | " env.predict(sample_prediction)\n", 833 | " \n", 834 | "sample_prediction.head()" 835 | ] 836 | } 837 | ], 838 | "metadata": { 839 | "kaggle": { 840 | "accelerator": "none", 841 | "dataSources": [ 842 | { 843 | "databundleVersionId": 3935619, 844 | "sourceId": 34349, 845 | "sourceType": "competition" 846 | } 847 | ], 848 | "dockerImageVersionId": 30301, 849 | "isGpuEnabled": false, 850 | "isInternetEnabled": false, 851 | "language": "python", 852 | "sourceType": "notebook" 853 | }, 854 | "kernelspec": { 855 | "display_name": "Python 3", 856 | "language": "python", 857
"name": "python3" 858 | }, 859 | "language_info": { 860 | "codemirror_mode": { 861 | "name": "ipython", 862 | "version": 3 863 | }, 864 | "file_extension": ".py", 865 | "mimetype": "text/x-python", 866 | "name": "python", 867 | "nbconvert_exporter": "python", 868 | "pygments_lexer": "ipython3", 869 | "version": "3.7.12" 870 | }, 871 | "papermill": { 872 | "default_parameters": {}, 873 | "duration": 352.314424, 874 | "end_time": "2024-07-22T03:13:14.464123", 875 | "environment_variables": {}, 876 | "exception": null, 877 | "input_path": "__notebook__.ipynb", 878 | "output_path": "__notebook__.ipynb", 879 | "parameters": {}, 880 | "start_time": "2024-07-22T03:07:22.149699", 881 | "version": "2.3.4" 882 | } 883 | }, 884 | "nbformat": 4, 885 | "nbformat_minor": 5 886 | } 887 | -------------------------------------------------------------------------------- /202407chatglm6b微调/chatglm6b-huanhuan-finetune-inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "52620c3c", 6 | "metadata": { 7 | "papermill": { 8 | "duration": 0.003538, 9 | "end_time": "2024-07-31T13:51:30.170739", 10 | "exception": false, 11 | "start_time": "2024-07-31T13:51:30.167201", 12 | "status": "completed" 13 | }, 14 | "tags": [] 15 | }, 16 | "source": [ 17 | "## Created by yunsuxiaozi 2024/7/31\n", 18 | "\n", 19 | "#### In the chatglm6b-huanhuan-finetune(training) notebook, we fine-tuned the ChatGLM-6B large language model on a dataset from Empresses in the Palace (甄嬛传) and trained our own personalized AI, chat_huanhuan. If you want to try chat_huanhuan, you can use this notebook; the example code is below.\n", 20 | "\n", 21 | "#### Note: all of the code here already appears in the fine-tuning notebook, so see that notebook for a detailed explanation of the code."
22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 1, 27 | "id": "212245e1", 28 | "metadata": { 29 | "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", 30 | "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", 31 | "execution": { 32 | "iopub.execute_input": "2024-07-31T13:51:30.178138Z", 33 | "iopub.status.busy": "2024-07-31T13:51:30.177835Z", 34 | "iopub.status.idle": "2024-07-31T13:51:32.730571Z", 35 | "shell.execute_reply": "2024-07-31T13:51:32.729659Z" 36 | }, 37 | "papermill": { 38 | "duration": 2.559049, 39 | "end_time": "2024-07-31T13:51:32.732886", 40 | "exception": false, 41 | "start_time": "2024-07-31T13:51:30.173837", 42 | "status": "completed" 43 | }, 44 | "tags": [] 45 | }, 46 | "outputs": [ 47 | { 48 | "name": "stdout", 49 | "output_type": "stream", 50 | "text": [ 51 | "Cloning into 'ChatGLM-6B'...\r\n", 52 | "remote: Enumerating objects: 1252, done.\u001b[K\r\n", 53 | "remote: Counting objects: 100% (17/17), done.\u001b[K\r\n", 54 | "remote: Compressing objects: 100% (11/11), done.\u001b[K\r\n", 55 | "remote: Total 1252 (delta 8), reused 11 (delta 6), pack-reused 1235\u001b[K\r\n", 56 | "Receiving objects: 100% (1252/1252), 9.15 MiB | 17.78 MiB/s, done.\r\n", 57 | "Resolving deltas: 100% (737/737), done.\r\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "!git clone https://github.com/THUDM/ChatGLM-6B.git" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 2, 68 | "id": "33bcd743", 69 | "metadata": { 70 | "execution": { 71 | "iopub.execute_input": "2024-07-31T13:51:32.742794Z", 72 | "iopub.status.busy": "2024-07-31T13:51:32.742045Z", 73 | "iopub.status.idle": "2024-07-31T13:52:02.312940Z", 74 | "shell.execute_reply": "2024-07-31T13:52:02.312030Z" 75 | }, 76 | "papermill": { 77 | "duration": 29.578182, 78 | "end_time": "2024-07-31T13:52:02.315209", 79 | "exception": false, 80 | "start_time": "2024-07-31T13:51:32.737027", 81 | "status": "completed" 82 | }, 83 | "tags": [] 84 | }, 85 | "outputs": [ 86 | { 87 | 
"name": "stdout", 88 | "output_type": "stream", 89 | "text": [ 90 | "Requirement already satisfied: protobuf in /opt/conda/lib/python3.10/site-packages (from -r ChatGLM-6B/requirements.txt (line 1)) (3.20.3)\r\n", 91 | "Collecting transformers==4.27.1 (from -r ChatGLM-6B/requirements.txt (line 2))\r\n", 92 | " Downloading transformers-4.27.1-py3-none-any.whl.metadata (106 kB)\r\n", 93 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m106.7/106.7 kB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 94 | "\u001b[?25hCollecting cpm_kernels (from -r ChatGLM-6B/requirements.txt (line 3))\r\n", 95 | " Downloading cpm_kernels-1.0.11-py3-none-any.whl.metadata (1.2 kB)\r\n", 96 | "Requirement already satisfied: torch>=1.10 in /opt/conda/lib/python3.10/site-packages (from -r ChatGLM-6B/requirements.txt (line 4)) (2.1.2)\r\n", 97 | "Collecting gradio (from -r ChatGLM-6B/requirements.txt (line 5))\r\n", 98 | " Downloading gradio-4.39.0-py3-none-any.whl.metadata (15 kB)\r\n", 99 | "Collecting mdtex2html (from -r ChatGLM-6B/requirements.txt (line 6))\r\n", 100 | " Downloading mdtex2html-1.3.0-py3-none-any.whl.metadata (4.1 kB)\r\n", 101 | "Requirement already satisfied: sentencepiece in /opt/conda/lib/python3.10/site-packages (from -r ChatGLM-6B/requirements.txt (line 7)) (0.2.0)\r\n", 102 | "Requirement already satisfied: accelerate in /opt/conda/lib/python3.10/site-packages (from -r ChatGLM-6B/requirements.txt (line 8)) (0.32.1)\r\n", 103 | "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from transformers==4.27.1->-r ChatGLM-6B/requirements.txt (line 2)) (3.13.1)\r\n", 104 | "Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /opt/conda/lib/python3.10/site-packages (from transformers==4.27.1->-r ChatGLM-6B/requirements.txt (line 2)) (0.23.4)\r\n", 105 | "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.10/site-packages (from transformers==4.27.1->-r 
ChatGLM-6B/requirements.txt (line 2)) (1.26.4)\r\n", 106 | "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.10/site-packages (from transformers==4.27.1->-r ChatGLM-6B/requirements.txt (line 2)) (21.3)\r\n", 107 | "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.10/site-packages (from transformers==4.27.1->-r ChatGLM-6B/requirements.txt (line 2)) (6.0.1)\r\n", 108 | "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.10/site-packages (from transformers==4.27.1->-r ChatGLM-6B/requirements.txt (line 2)) (2023.12.25)\r\n", 109 | "Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from transformers==4.27.1->-r ChatGLM-6B/requirements.txt (line 2)) (2.32.3)\r\n", 110 | "Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers==4.27.1->-r ChatGLM-6B/requirements.txt (line 2))\r\n", 111 | " Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)\r\n", 112 | "Requirement already satisfied: tqdm>=4.27 in /opt/conda/lib/python3.10/site-packages (from transformers==4.27.1->-r ChatGLM-6B/requirements.txt (line 2)) (4.66.4)\r\n", 113 | "Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.10/site-packages (from torch>=1.10->-r ChatGLM-6B/requirements.txt (line 4)) (4.9.0)\r\n", 114 | "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch>=1.10->-r ChatGLM-6B/requirements.txt (line 4)) (1.13.0)\r\n", 115 | "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch>=1.10->-r ChatGLM-6B/requirements.txt (line 4)) (3.2.1)\r\n", 116 | "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10->-r ChatGLM-6B/requirements.txt (line 4)) (3.1.2)\r\n", 117 | "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch>=1.10->-r ChatGLM-6B/requirements.txt (line 
4)) (2024.5.0)\r\n", 118 | "Requirement already satisfied: aiofiles<24.0,>=22.0 in /opt/conda/lib/python3.10/site-packages (from gradio->-r ChatGLM-6B/requirements.txt (line 5)) (22.1.0)\r\n", 119 | "Requirement already satisfied: anyio<5.0,>=3.0 in /opt/conda/lib/python3.10/site-packages (from gradio->-r ChatGLM-6B/requirements.txt (line 5)) (4.2.0)\r\n", 120 | "Requirement already satisfied: fastapi in /opt/conda/lib/python3.10/site-packages (from gradio->-r ChatGLM-6B/requirements.txt (line 5)) (0.108.0)\r\n", 121 | "Collecting ffmpy (from gradio->-r ChatGLM-6B/requirements.txt (line 5))\r\n", 122 | " Downloading ffmpy-0.4.0-py3-none-any.whl.metadata (2.9 kB)\r\n", 123 | "Collecting gradio-client==1.1.1 (from gradio->-r ChatGLM-6B/requirements.txt (line 5))\r\n", 124 | " Downloading gradio_client-1.1.1-py3-none-any.whl.metadata (7.1 kB)\r\n", 125 | "Requirement already satisfied: httpx>=0.24.1 in /opt/conda/lib/python3.10/site-packages (from gradio->-r ChatGLM-6B/requirements.txt (line 5)) (0.27.0)\r\n", 126 | "Requirement already satisfied: importlib-resources<7.0,>=1.3 in /opt/conda/lib/python3.10/site-packages (from gradio->-r ChatGLM-6B/requirements.txt (line 5)) (6.1.1)\r\n", 127 | "Requirement already satisfied: markupsafe~=2.0 in /opt/conda/lib/python3.10/site-packages (from gradio->-r ChatGLM-6B/requirements.txt (line 5)) (2.1.3)\r\n", 128 | "Requirement already satisfied: matplotlib~=3.0 in /opt/conda/lib/python3.10/site-packages (from gradio->-r ChatGLM-6B/requirements.txt (line 5)) (3.7.5)\r\n", 129 | "Requirement already satisfied: orjson~=3.0 in /opt/conda/lib/python3.10/site-packages (from gradio->-r ChatGLM-6B/requirements.txt (line 5)) (3.9.10)\r\n", 130 | "Requirement already satisfied: pandas<3.0,>=1.0 in /opt/conda/lib/python3.10/site-packages (from gradio->-r ChatGLM-6B/requirements.txt (line 5)) (2.2.2)\r\n", 131 | "Requirement already satisfied: pillow<11.0,>=8.0 in /opt/conda/lib/python3.10/site-packages (from gradio->-r 
ChatGLM-6B/requirements.txt (line 5)) (9.5.0)\r\n", 132 | "Requirement already satisfied: pydantic>=2.0 in /opt/conda/lib/python3.10/site-packages (from gradio->-r ChatGLM-6B/requirements.txt (line 5)) (2.5.3)\r\n", 133 | "Requirement already satisfied: pydub in /opt/conda/lib/python3.10/site-packages (from gradio->-r ChatGLM-6B/requirements.txt (line 5)) (0.25.1)\r\n", 134 | "Collecting python-multipart>=0.0.9 (from gradio->-r ChatGLM-6B/requirements.txt (line 5))\r\n", 135 | " Downloading python_multipart-0.0.9-py3-none-any.whl.metadata (2.5 kB)\r\n", 136 | "Collecting ruff>=0.2.2 (from gradio->-r ChatGLM-6B/requirements.txt (line 5))\r\n", 137 | " Downloading ruff-0.5.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (24 kB)\r\n", 138 | "Collecting semantic-version~=2.0 (from gradio->-r ChatGLM-6B/requirements.txt (line 5))\r\n", 139 | " Downloading semantic_version-2.10.0-py2.py3-none-any.whl.metadata (9.7 kB)\r\n", 140 | "Collecting tomlkit==0.12.0 (from gradio->-r ChatGLM-6B/requirements.txt (line 5))\r\n", 141 | " Downloading tomlkit-0.12.0-py3-none-any.whl.metadata (2.7 kB)\r\n", 142 | "Requirement already satisfied: typer<1.0,>=0.12 in /opt/conda/lib/python3.10/site-packages (from gradio->-r ChatGLM-6B/requirements.txt (line 5)) (0.12.3)\r\n", 143 | "Collecting urllib3~=2.0 (from gradio->-r ChatGLM-6B/requirements.txt (line 5))\r\n", 144 | " Downloading urllib3-2.2.2-py3-none-any.whl.metadata (6.4 kB)\r\n", 145 | "Requirement already satisfied: uvicorn>=0.14.0 in /opt/conda/lib/python3.10/site-packages (from gradio->-r ChatGLM-6B/requirements.txt (line 5)) (0.25.0)\r\n", 146 | "Collecting websockets<12.0,>=10.0 (from gradio-client==1.1.1->gradio->-r ChatGLM-6B/requirements.txt (line 5))\r\n", 147 | " Downloading websockets-11.0.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)\r\n", 148 | "Requirement already satisfied: markdown in /opt/conda/lib/python3.10/site-packages 
(from mdtex2html->-r ChatGLM-6B/requirements.txt (line 6)) (3.5.2)\r\n", 149 | "Collecting latex2mathml (from mdtex2html->-r ChatGLM-6B/requirements.txt (line 6))\r\n", 150 | " Downloading latex2mathml-3.77.0-py3-none-any.whl.metadata (14 kB)\r\n", 151 | "Requirement already satisfied: psutil in /opt/conda/lib/python3.10/site-packages (from accelerate->-r ChatGLM-6B/requirements.txt (line 8)) (5.9.3)\r\n", 152 | "Requirement already satisfied: safetensors>=0.3.1 in /opt/conda/lib/python3.10/site-packages (from accelerate->-r ChatGLM-6B/requirements.txt (line 8)) (0.4.3)\r\n", 153 | "Requirement already satisfied: idna>=2.8 in /opt/conda/lib/python3.10/site-packages (from anyio<5.0,>=3.0->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (3.6)\r\n", 154 | "Requirement already satisfied: sniffio>=1.1 in /opt/conda/lib/python3.10/site-packages (from anyio<5.0,>=3.0->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (1.3.0)\r\n", 155 | "Requirement already satisfied: exceptiongroup>=1.0.2 in /opt/conda/lib/python3.10/site-packages (from anyio<5.0,>=3.0->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (1.2.0)\r\n", 156 | "Requirement already satisfied: certifi in /opt/conda/lib/python3.10/site-packages (from httpx>=0.24.1->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (2024.7.4)\r\n", 157 | "Requirement already satisfied: httpcore==1.* in /opt/conda/lib/python3.10/site-packages (from httpx>=0.24.1->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (1.0.5)\r\n", 158 | "Requirement already satisfied: h11<0.15,>=0.13 in /opt/conda/lib/python3.10/site-packages (from httpcore==1.*->httpx>=0.24.1->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (0.14.0)\r\n", 159 | "Requirement already satisfied: contourpy>=1.0.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib~=3.0->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (1.2.0)\r\n", 160 | "Requirement already satisfied: cycler>=0.10 in /opt/conda/lib/python3.10/site-packages (from matplotlib~=3.0->gradio->-r 
ChatGLM-6B/requirements.txt (line 5)) (0.12.1)\r\n", 161 | "Requirement already satisfied: fonttools>=4.22.0 in /opt/conda/lib/python3.10/site-packages (from matplotlib~=3.0->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (4.47.0)\r\n", 162 | "Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib~=3.0->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (1.4.5)\r\n", 163 | "Requirement already satisfied: pyparsing>=2.3.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib~=3.0->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (3.1.1)\r\n", 164 | "Requirement already satisfied: python-dateutil>=2.7 in /opt/conda/lib/python3.10/site-packages (from matplotlib~=3.0->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (2.9.0.post0)\r\n", 165 | "Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.10/site-packages (from pandas<3.0,>=1.0->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (2023.3.post1)\r\n", 166 | "Requirement already satisfied: tzdata>=2022.7 in /opt/conda/lib/python3.10/site-packages (from pandas<3.0,>=1.0->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (2023.4)\r\n", 167 | "Requirement already satisfied: annotated-types>=0.4.0 in /opt/conda/lib/python3.10/site-packages (from pydantic>=2.0->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (0.6.0)\r\n", 168 | "Requirement already satisfied: pydantic-core==2.14.6 in /opt/conda/lib/python3.10/site-packages (from pydantic>=2.0->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (2.14.6)\r\n", 169 | "Requirement already satisfied: click>=8.0.0 in /opt/conda/lib/python3.10/site-packages (from typer<1.0,>=0.12->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (8.1.7)\r\n", 170 | "Requirement already satisfied: shellingham>=1.3.0 in /opt/conda/lib/python3.10/site-packages (from typer<1.0,>=0.12->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (1.5.4)\r\n", 171 | "Requirement already satisfied: rich>=10.11.0 in 
/opt/conda/lib/python3.10/site-packages (from typer<1.0,>=0.12->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (13.7.0)\r\n", 172 | "Requirement already satisfied: starlette<0.33.0,>=0.29.0 in /opt/conda/lib/python3.10/site-packages (from fastapi->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (0.32.0.post1)\r\n", 173 | "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->transformers==4.27.1->-r ChatGLM-6B/requirements.txt (line 2)) (3.3.2)\r\n", 174 | "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/conda/lib/python3.10/site-packages (from sympy->torch>=1.10->-r ChatGLM-6B/requirements.txt (line 4)) (1.3.0)\r\n", 175 | "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.10/site-packages (from python-dateutil>=2.7->matplotlib~=3.0->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (1.16.0)\r\n", 176 | "Requirement already satisfied: markdown-it-py>=2.2.0 in /opt/conda/lib/python3.10/site-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (3.0.0)\r\n", 177 | "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /opt/conda/lib/python3.10/site-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (2.17.2)\r\n", 178 | "Requirement already satisfied: mdurl~=0.1 in /opt/conda/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich>=10.11.0->typer<1.0,>=0.12->gradio->-r ChatGLM-6B/requirements.txt (line 5)) (0.1.2)\r\n", 179 | "Downloading transformers-4.27.1-py3-none-any.whl (6.7 MB)\r\n", 180 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.7/6.7 MB\u001b[0m \u001b[31m59.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 181 | "\u001b[?25hDownloading cpm_kernels-1.0.11-py3-none-any.whl (416 kB)\r\n", 182 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m416.6/416.6 kB\u001b[0m \u001b[31m26.5 MB/s\u001b[0m eta 
\u001b[36m0:00:00\u001b[0m\r\n", 183 | "\u001b[?25hDownloading gradio-4.39.0-py3-none-any.whl (12.4 MB)\r\n", 184 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.4/12.4 MB\u001b[0m \u001b[31m79.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 185 | "\u001b[?25hDownloading gradio_client-1.1.1-py3-none-any.whl (318 kB)\r\n", 186 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m318.2/318.2 kB\u001b[0m \u001b[31m20.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 187 | "\u001b[?25hDownloading tomlkit-0.12.0-py3-none-any.whl (37 kB)\r\n", 188 | "Downloading mdtex2html-1.3.0-py3-none-any.whl (13 kB)\r\n", 189 | "Downloading python_multipart-0.0.9-py3-none-any.whl (22 kB)\r\n", 190 | "Downloading ruff-0.5.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (10.1 MB)\r\n", 191 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.1/10.1 MB\u001b[0m \u001b[31m50.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 192 | "\u001b[?25hDownloading semantic_version-2.10.0-py2.py3-none-any.whl (15 kB)\r\n", 193 | "Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)\r\n", 194 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m61.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 195 | "\u001b[?25hDownloading urllib3-2.2.2-py3-none-any.whl (121 kB)\r\n", 196 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m121.4/121.4 kB\u001b[0m \u001b[31m9.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 197 | "\u001b[?25hDownloading ffmpy-0.4.0-py3-none-any.whl (5.8 kB)\r\n", 198 | "Downloading latex2mathml-3.77.0-py3-none-any.whl (73 kB)\r\n", 199 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m73.7/73.7 kB\u001b[0m \u001b[31m6.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 200 | 
"\u001b[?25hDownloading websockets-11.0.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (129 kB)\r\n", 201 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m129.9/129.9 kB\u001b[0m \u001b[31m11.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 202 | "\u001b[?25hInstalling collected packages: tokenizers, cpm_kernels, websockets, urllib3, tomlkit, semantic-version, ruff, python-multipart, latex2mathml, ffmpy, mdtex2html, transformers, gradio-client, gradio\r\n", 203 | " Attempting uninstall: tokenizers\r\n", 204 | " Found existing installation: tokenizers 0.19.1\r\n", 205 | " Uninstalling tokenizers-0.19.1:\r\n", 206 | " Successfully uninstalled tokenizers-0.19.1\r\n", 207 | " Attempting uninstall: websockets\r\n", 208 | " Found existing installation: websockets 12.0\r\n", 209 | " Uninstalling websockets-12.0:\r\n", 210 | " Successfully uninstalled websockets-12.0\r\n", 211 | " Attempting uninstall: urllib3\r\n", 212 | " Found existing installation: urllib3 1.26.18\r\n", 213 | " Uninstalling urllib3-1.26.18:\r\n", 214 | " Successfully uninstalled urllib3-1.26.18\r\n", 215 | " Attempting uninstall: tomlkit\r\n", 216 | " Found existing installation: tomlkit 0.12.5\r\n", 217 | " Uninstalling tomlkit-0.12.5:\r\n", 218 | " Successfully uninstalled tomlkit-0.12.5\r\n", 219 | " Attempting uninstall: transformers\r\n", 220 | " Found existing installation: transformers 4.42.3\r\n", 221 | " Uninstalling transformers-4.42.3:\r\n", 222 | " Successfully uninstalled transformers-4.42.3\r\n", 223 | "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\r\n", 224 | "tensorflow-decision-forests 1.8.1 requires wurlitzer, which is not installed.\r\n", 225 | "distributed 2024.5.1 requires dask==2024.5.1, but you have dask 2024.7.0 which is incompatible.\r\n", 226 | "kaggle-environments 1.14.15 requires transformers>=4.33.1, but you have transformers 4.27.1 which is incompatible.\r\n", 227 | "kfp 2.5.0 requires google-cloud-storage<3,>=2.2.1, but you have google-cloud-storage 1.44.0 which is incompatible.\r\n", 228 | "kfp 2.5.0 requires urllib3<2.0.0, but you have urllib3 2.2.2 which is incompatible.\r\n", 229 | "rapids-dask-dependency 24.6.0a0 requires dask==2024.5.1, but you have dask 2024.7.0 which is incompatible.\r\n", 230 | "tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.4.1 which is incompatible.\r\n", 231 | "ydata-profiling 4.6.4 requires numpy<1.26,>=1.16.0, but you have numpy 1.26.4 which is incompatible.\u001b[0m\u001b[31m\r\n", 232 | "\u001b[0mSuccessfully installed cpm_kernels-1.0.11 ffmpy-0.4.0 gradio-4.39.0 gradio-client-1.1.1 latex2mathml-3.77.0 mdtex2html-1.3.0 python-multipart-0.0.9 ruff-0.5.5 semantic-version-2.10.0 tokenizers-0.13.3 tomlkit-0.12.0 transformers-4.27.1 urllib3-2.1.0 websockets-11.0.3\r\n" 233 | ] 234 | } 235 | ], 236 | "source": [ 237 | "!pip install -r ChatGLM-6B/requirements.txt # install the libraries ChatGLM-6B depends on" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 3, 243 | "id": "a051a8f2", 244 | "metadata": { 245 | "execution": { 246 | "iopub.execute_input": "2024-07-31T13:52:02.334988Z", 247 | "iopub.status.busy": "2024-07-31T13:52:02.334686Z", 248 | "iopub.status.idle": "2024-07-31T13:52:15.233281Z", 249 | "shell.execute_reply": "2024-07-31T13:52:15.232081Z" 250 | }, 251 | "papermill": { 252 | "duration": 12.911303, 253 | "end_time": "2024-07-31T13:52:15.235864", 254 | "exception": false, 255 | "start_time": "2024-07-31T13:52:02.324561", 256 | "status": "completed" 257 | }, 258 |
"tags": [] 259 | }, 260 | "outputs": [], 261 | "source": [ 262 | "!pip install -q rouge_chinese nltk jieba datasets " 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": 4, 268 | "id": "69078d8f", 269 | "metadata": { 270 | "execution": { 271 | "iopub.execute_input": "2024-07-31T13:52:15.266596Z", 272 | "iopub.status.busy": "2024-07-31T13:52:15.266091Z", 273 | "iopub.status.idle": "2024-07-31T13:55:28.942716Z", 274 | "shell.execute_reply": "2024-07-31T13:55:28.941294Z" 275 | }, 276 | "papermill": { 277 | "duration": 193.697638, 278 | "end_time": "2024-07-31T13:55:28.945730", 279 | "exception": false, 280 | "start_time": "2024-07-31T13:52:15.248092", 281 | "status": "completed" 282 | }, 283 | "tags": [] 284 | }, 285 | "outputs": [ 286 | { 287 | "name": "stdout", 288 | "output_type": "stream", 289 | "text": [ 290 | "Cloning into 'chatglm-6b-int4'...\r\n", 291 | "remote: Enumerating objects: 137, done.\u001b[K\r\n", 292 | "remote: Total 137 (delta 0), reused 0 (delta 0), pack-reused 137 (from 1)\u001b[K\r\n", 293 | "Receiving objects: 100% (137/137), 62.10 KiB | 15.53 MiB/s, done.\r\n", 294 | "Resolving deltas: 100% (79/79), done.\r\n", 295 | "Filtering content: 100% (2/2), 3.62 GiB | 19.37 MiB/s, done.\r\n" 296 | ] 297 | } 298 | ], 299 | "source": [ 300 | "!git clone https://huggingface.co/THUDM/chatglm-6b-int4" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": 5, 306 | "id": "0ea30249", 307 | "metadata": { 308 | "execution": { 309 | "iopub.execute_input": "2024-07-31T13:55:29.036405Z", 310 | "iopub.status.busy": "2024-07-31T13:55:29.035622Z", 311 | "iopub.status.idle": "2024-07-31T13:55:45.765993Z", 312 | "shell.execute_reply": "2024-07-31T13:55:45.765205Z" 313 | }, 314 | "papermill": { 315 | "duration": 16.778985, 316 | "end_time": "2024-07-31T13:55:45.768167", 317 | "exception": false, 318 | "start_time": "2024-07-31T13:55:28.989182", 319 | "status": "completed" 320 | }, 321 | "tags": [] 322 | }, 323 | "outputs": 
[ 324 | { 325 | "name": "stderr", 326 | "output_type": "stream", 327 | "text": [ 328 | "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.\n", 329 | "Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.\n", 330 | "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.\n" 331 | ] 332 | }, 333 | { 334 | "name": "stdout", 335 | "output_type": "stream", 336 | "text": [ 337 | "No compiled kernel found.\n", 338 | "Compiling kernels : /root/.cache/huggingface/modules/transformers_modules/chatglm-6b-int4/quantization_kernels_parallel.c\n", 339 | "Compiling gcc -O3 -fPIC -pthread -fopenmp -std=c99 /root/.cache/huggingface/modules/transformers_modules/chatglm-6b-int4/quantization_kernels_parallel.c -shared -o /root/.cache/huggingface/modules/transformers_modules/chatglm-6b-int4/quantization_kernels_parallel.so\n", 340 | "Load kernel : /root/.cache/huggingface/modules/transformers_modules/chatglm-6b-int4/quantization_kernels_parallel.so\n", 341 | "Setting CPU quantization kernel threads to 2\n", 342 | "Parallel kernel is not recommended when parallel num < 4.\n", 343 | "Using quantization cache\n", 344 | "Applying quantization to glm layers\n" 345 | ] 346 | } 347 | ], 348 | "source": [ 349 | "# AutoTokenizer loads the tokenizer that matches the model; AutoModel loads the pretrained model\n", 350 | "from transformers import AutoTokenizer, AutoModel\n", 351 | "\n", 352 | "model_path = \"chatglm-6b-int4\"  # local directory holding the model files (cloned above)\n", 353 | "# load the pretrained tokenizer from the model path; trust_remote_code=True allows the repo's custom code to run\n", 354 | "tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n", 355 | "# load the pretrained model the same way; .half() casts it to FP16 and .cuda() moves it to the GPU\n", 356 | "model = AutoModel.from_pretrained(model_path, 
trust_remote_code=True).half().cuda()" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": 6, 362 | "id": "042a9bb1", 363 | "metadata": { 364 | "execution": { 365 | "iopub.execute_input": "2024-07-31T13:55:45.789803Z", 366 | "iopub.status.busy": "2024-07-31T13:55:45.789330Z", 367 | "iopub.status.idle": "2024-07-31T13:55:53.967142Z", 368 | "shell.execute_reply": "2024-07-31T13:55:53.966290Z" 369 | }, 370 | "papermill": { 371 | "duration": 8.191009, 372 | "end_time": "2024-07-31T13:55:53.969453", 373 | "exception": false, 374 | "start_time": "2024-07-31T13:55:45.778444", 375 | "status": "completed" 376 | }, 377 | "tags": [] 378 | }, 379 | "outputs": [ 380 | { 381 | "name": "stderr", 382 | "output_type": "stream", 383 | "text": [ 384 | "Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.\n", 385 | "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.\n" 386 | ] 387 | }, 388 | { 389 | "name": "stdout", 390 | "output_type": "stream", 391 | "text": [ 392 | "No compiled kernel found.\n", 393 | "Compiling kernels : /root/.cache/huggingface/modules/transformers_modules/chatglm-6b-int4/quantization_kernels_parallel.c\n", 394 | "Compiling gcc -O3 -fPIC -pthread -fopenmp -std=c99 /root/.cache/huggingface/modules/transformers_modules/chatglm-6b-int4/quantization_kernels_parallel.c -shared -o /root/.cache/huggingface/modules/transformers_modules/chatglm-6b-int4/quantization_kernels_parallel.so\n", 395 | "Load kernel : /root/.cache/huggingface/modules/transformers_modules/chatglm-6b-int4/quantization_kernels_parallel.so\n", 396 | "Setting CPU quantization kernel threads to 2\n", 397 | "Parallel kernel is not recommended when parallel num < 4.\n", 398 | "Using quantization cache\n", 399 | "Applying quantization to glm layers\n" 400 | ] 401 | }, 402 
| { 403 | "name": "stderr", 404 | "output_type": "stream", 405 | "text": [ 406 | "Some weights of ChatGLMForConditionalGeneration were not initialized from the model checkpoint at chatglm-6b-int4 and are newly initialized: ['transformer.prefix_encoder.embedding.weight']\n", 407 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 408 | ] 409 | } 410 | ], 411 | "source": [ 412 | "import torch  # PyTorch deep learning framework\n", 413 | "from transformers import AutoConfig  # loads a pretrained model's configuration\n", 414 | "\n", 415 | "# load the config from the model path (trust_remote_code=True); pre_seq_len=256 is the P-Tuning v2 prefix length and must match the value used in training\n", 416 | "config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, pre_seq_len=256)\n", 417 | "# load the model from the model path with this config (trust_remote_code=True)\n", 418 | "model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)\n", 419 | "\n", 420 | "# load the state dict saved at a fine-tuning checkpoint (binary file)\n", 421 | "prefix_state_dict = torch.load(\"/kaggle/input/chatglm6b-huanhuan-finetune-training/output/infer-chatglm-6b-int4-pt-256-5e-2/checkpoint-500/pytorch_model.bin\")\n", 422 | "# keep only the prefix-encoder weights and strip their \"transformer.prefix_encoder.\" prefix so the keys match prefix_encoder's own state_dict\n", 423 | "prefix = \"transformer.prefix_encoder.\"\n", 424 | "new_prefix_state_dict = {k[len(prefix):]: v for k, v in prefix_state_dict.items()\n", 425 | "                         if k.startswith(prefix)}\n", 426 | "model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)\n", 427 | "\n", 428 | "# cast to FP16 (.half()) and move to the GPU (.cuda())\n", 429 | "model = model.half().cuda()\n", 430 | "# keep the prefix_encoder parameters in full precision (float32)\n", 431 | "model.transformer.prefix_encoder.float()\n", 432 | "# switch the model to evaluation mode\n", 433 | "model = model.eval()" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 7, 439 | "id": "8d4a2fac", 440 | "metadata": { 441 | "execution": { 442 | "iopub.execute_input": "2024-07-31T13:55:53.991904Z", 443 | "iopub.status.busy": "2024-07-31T13:55:53.991624Z", 444 | "iopub.status.idle": "2024-07-31T13:56:08.678944Z", 445 |
"shell.execute_reply": "2024-07-31T13:56:08.677900Z" 446 | }, 447 | "papermill": { 448 | "duration": 14.700595, 449 | "end_time": "2024-07-31T13:56:08.681095", 450 | "exception": false, 451 | "start_time": "2024-07-31T13:55:53.980500", 452 | "status": "completed" 453 | }, 454 | "tags": [] 455 | }, 456 | "outputs": [ 457 | { 458 | "name": "stderr", 459 | "output_type": "stream", 460 | "text": [ 461 | "2024-07-31 13:56:00.003249: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", 462 | "2024-07-31 13:56:00.003347: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", 463 | "2024-07-31 13:56:00.131327: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" 464 | ] 465 | }, 466 | { 467 | "name": "stdout", 468 | "output_type": "stream", 469 | "text": [ 470 | "question:朕的后宫佳丽三千,朕最喜欢的就是你。\n", 471 | "response:皇上最喜我,我自然高兴。\n" 472 | ] 473 | } 474 | ], 475 | "source": [ 476 | "question='朕的后宫佳丽三千,朕最喜欢的就是你。'\n", 477 | "response, history = model.chat(tokenizer, question, history=[])\n", 478 | "print(f\"question:{question}\\nresponse:{response}\")" 479 | ] 480 | } 481 | ], 482 | "metadata": { 483 | "kaggle": { 484 | "accelerator": "nvidiaTeslaT4", 485 | "dataSources": [ 486 | { 487 | "datasetId": 5467682, 488 | "sourceId": 9065856, 489 | "sourceType": "datasetVersion" 490 | }, 491 | { 492 | "sourceId": 190568050, 493 | "sourceType": "kernelVersion" 494 | } 495 | ], 496 | "dockerImageVersionId": 30747, 497 | "isGpuEnabled": true, 498 | "isInternetEnabled": true, 499 | "language": "python", 500 | "sourceType": "notebook" 501 | }, 502 | "kernelspec": { 503 | "display_name": "Python 3", 
504 | "language": "python", 505 | "name": "python3" 506 | }, 507 | "language_info": { 508 | "codemirror_mode": { 509 | "name": "ipython", 510 | "version": 3 511 | }, 512 | "file_extension": ".py", 513 | "mimetype": "text/x-python", 514 | "name": "python", 515 | "nbconvert_exporter": "python", 516 | "pygments_lexer": "ipython3", 517 | "version": "3.10.13" 518 | }, 519 | "papermill": { 520 | "default_parameters": {}, 521 | "duration": 284.280021, 522 | "end_time": "2024-07-31T13:56:11.696753", 523 | "environment_variables": {}, 524 | "exception": null, 525 | "input_path": "__notebook__.ipynb", 526 | "output_path": "__notebook__.ipynb", 527 | "parameters": {}, 528 | "start_time": "2024-07-31T13:51:27.416732", 529 | "version": "2.5.0" 530 | } 531 | }, 532 | "nbformat": 4, 533 | "nbformat_minor": 5 534 | } 535 | -------------------------------------------------------------------------------- /202404KDDcup-whoiswho-baseline/202404-kdd-cup-whoiswho-ind-baseline.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "599b784e", 6 | "metadata": { 7 | "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", 8 | "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", 9 | "papermill": { 10 | "duration": 0.005529, 11 | "end_time": "2024-04-14T12:59:31.608353", 12 | "exception": false, 13 | "start_time": "2024-04-14T12:59:31.602824", 14 | "status": "completed" 15 | }, 16 | "tags": [] 17 | }, 18 | "source": [ 19 | "## Created by yunsuxiaozi 2024/4/14\n", 20 | "\n", 21 | "### Competition link: https://www.biendata.xyz/competition/ind_kdd_2024/\n", 22 | "\n", 23 | "### This is my first time taking part in the KDD Cup, so I'm documenting it here. The competition will probably call for knowledge-graph techniques; what I provide here is a data-mining baseline, and its score is reasonably good so far."
24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "id": "30e374c8", 29 | "metadata": { 30 | "papermill": { 31 | "duration": 0.004331, 32 | "end_time": "2024-04-14T12:59:31.617712", 33 | "exception": false, 34 | "start_time": "2024-04-14T12:59:31.613381", 35 | "status": "completed" 36 | }, 37 | "tags": [] 38 | }, 39 | "source": [ 40 | "### Import the required libraries" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 1, 46 | "id": "b2fbc4a1", 47 | "metadata": { 48 | "execution": { 49 | "iopub.execute_input": "2024-04-14T12:59:31.629805Z", 50 | "iopub.status.busy": "2024-04-14T12:59:31.629156Z", 51 | "iopub.status.idle": "2024-04-14T12:59:35.525995Z", 52 | "shell.execute_reply": "2024-04-14T12:59:35.524676Z" 53 | }, 54 | "papermill": { 55 | "duration": 3.906376, 56 | "end_time": "2024-04-14T12:59:35.528957", 57 | "exception": false, 58 | "start_time": "2024-04-14T12:59:31.622581", 59 | "status": "completed" 60 | }, 61 | "tags": [] 62 | }, 63 | "outputs": [], 64 | "source": [ 65 | "#necessary\n", 66 | "import pandas as pd  # DataFrames and CSV I/O\n", 67 | "import numpy as np  # array and matrix operations\n", 68 | "import json  # reading and writing JSON data\n", 69 | "\n", 70 | "# model: LightGBM classifier, plus callbacks for evaluation logging and early stopping (to limit overfitting)\n", 71 | "from lightgbm import LGBMClassifier,log_evaluation,early_stopping\n", 72 | "#metric\n", 73 | "from sklearn.metrics import roc_auc_score  # area under the ROC curve (AUC)\n", 74 | "# KFold simply splits the data into k folds; StratifiedKFold also keeps each class's proportion the same in every fold\n", 75 | "from sklearn.model_selection import StratifiedKFold" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "id": "2451b23a", 81 | "metadata": { 82 | "papermill": { 83 | "duration": 0.004596, 84 | "end_time": "2024-04-14T12:59:35.538428", 85 | "exception": false, 86 | "start_time": "2024-04-14T12:59:35.533832", 87 | "status": "completed" 88 | }, 89 | "tags": [] 90 | }, 91 | "source": [ 92 | "### Set the configuration parameters" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 2, 98 | "id": "9a90213c", 99 | "metadata": { 100 | "execution": { 101 | "iopub.execute_input": 
"2024-04-14T12:59:35.551070Z", 102 | "iopub.status.busy": "2024-04-14T12:59:35.549735Z", 103 | "iopub.status.idle": "2024-04-14T12:59:35.557369Z", 104 | "shell.execute_reply": "2024-04-14T12:59:35.556175Z" 105 | }, 106 | "papermill": { 107 | "duration": 0.016468, 108 | "end_time": "2024-04-14T12:59:35.560053", 109 | "exception": false, 110 | "start_time": "2024-04-14T12:59:35.543585", 111 | "status": "completed" 112 | }, 113 | "tags": [] 114 | }, 115 | "outputs": [], 116 | "source": [ 117 | "#config\n", 118 | "class Config():\n", 119 | "    seed=2024  # random seed\n", 120 | "    num_folds=10  # number of folds for K-fold cross-validation\n", 121 | "    TARGET_NAME ='label'  # name of the target column\n", 122 | "import random  # Python's built-in random number generator\n", 123 | "# fix the random seeds so results are reproducible\n", 124 | "def seed_everything(seed):\n", 125 | "    np.random.seed(seed)  # NumPy's global random seed\n", 126 | "    random.seed(seed)  # Python's built-in random seed\n", 127 | "seed_everything(Config.seed)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "id": "12332a6f", 133 | "metadata": { 134 | "papermill": { 135 | "duration": 0.004502, 136 | "end_time": "2024-04-14T12:59:35.569434", 137 | "exception": false, 138 | "start_time": "2024-04-14T12:59:35.564932", 139 | "status": "completed" 140 | }, 141 | "tags": [] 142 | }, 143 | "source": [ 144 | "### Load the datasets (the data is hosted on Kaggle here)."
145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 3, 150 | "id": "6a54f4c1", 151 | "metadata": { 152 | "execution": { 153 | "iopub.execute_input": "2024-04-14T12:59:35.583422Z", 154 | "iopub.status.busy": "2024-04-14T12:59:35.582977Z", 155 | "iopub.status.idle": "2024-04-14T13:00:07.938280Z", 156 | "shell.execute_reply": "2024-04-14T13:00:07.936950Z" 157 | }, 158 | "papermill": { 159 | "duration": 32.364963, 160 | "end_time": "2024-04-14T13:00:07.941386", 161 | "exception": false, 162 | "start_time": "2024-04-14T12:59:35.576423", 163 | "status": "completed" 164 | }, 165 | "tags": [] 166 | }, 167 | "outputs": [], 168 | "source": [ 169 | "path='/kaggle/input/'\n", 170 | "#sample: Iki037dt dict_keys(['name', 'normal_data', 'outliers'])\n", 171 | "with open(path+\"whoiswho-ind-kdd-2024/IND-WhoIsWho/train_author.json\") as f:\n", 172 | " train_author=json.load(f)\n", 173 | "#sample : 6IsfnuWU dict_keys(['id', 'title', 'authors', 'abstract', 'keywords', 'venue', 'year']) \n", 174 | "with open(path+\"whoiswho-ind-kdd-2024/IND-WhoIsWho/pid_to_info_all.json\") as f:\n", 175 | " pid_to_info=json.load(f)\n", 176 | "#efQ8FQ1i dict_keys(['name', 'papers'])\n", 177 | "with open(path+\"whoiswho-ind-kdd-2024/IND-WhoIsWho/ind_valid_author.json\") as f:\n", 178 | " valid_author=json.load(f)\n", 179 | "\n", 180 | "with open(path+\"whoiswho-ind-kdd-2024/IND-WhoIsWho/ind_valid_author_submit.json\") as f:\n", 181 | " submission=json.load(f)" 182 | ] 183 | }, 184 | { 185 | "cell_type": "markdown", 186 | "id": "a440a020", 187 | "metadata": { 188 | "papermill": { 189 | "duration": 0.004549, 190 | "end_time": "2024-04-14T13:00:07.950952", 191 | "exception": false, 192 | "start_time": "2024-04-14T13:00:07.946403", 193 | "status": "completed" 194 | }, 195 | "tags": [] 196 | }, 197 | "source": [ 198 | "### Simple feature engineering: field lengths and publication year for each paper."
199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 4, 204 | "id": "e5e31cea", 205 | "metadata": { 206 | "execution": { 207 | "iopub.execute_input": "2024-04-14T13:00:07.962527Z", 208 | "iopub.status.busy": "2024-04-14T13:00:07.962053Z", 209 | "iopub.status.idle": "2024-04-14T13:00:09.317361Z", 210 | "shell.execute_reply": "2024-04-14T13:00:09.316054Z" 211 | }, 212 | "papermill": { 213 | "duration": 1.364457, 214 | "end_time": "2024-04-14T13:00:09.320169", 215 | "exception": false, 216 | "start_time": "2024-04-14T13:00:07.955712", 217 | "status": "completed" 218 | }, 219 | "tags": [] 220 | }, 221 | "outputs": [ 222 | { 223 | "name": "stdout", 224 | "output_type": "stream", 225 | "text": [ 226 | "train_feats.shape:(148309, 6),labels.shape:(148309,)\n", 227 | "np.mean(labels):0.8834527911320283\n" 228 | ] 229 | }, 230 | { 231 | "data": { 232 | "text/html": [ 233 | "
\n", 234 | "\n", 247 | "\n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | "
\n", 313 | "
" 314 | ], 315 | "text/plain": [ 316 | " 0 1 2 3 4 5 label\n", 317 | "0 120 0 0 14 0 2010 1\n", 318 | "1 123 0 0 8 0 2011 1\n", 319 | "2 100 986 4 9 4 2001 1\n", 320 | "3 103 0 0 10 0 1987 1\n", 321 | "4 133 1629 5 10 5 2015 1" 322 | ] 323 | }, 324 | "execution_count": 4, 325 | "metadata": {}, 326 | "output_type": "execute_result" 327 | } 328 | ], 329 | "source": [ 330 | "train_feats=[]\n", 331 | "labels=[]\n", 332 | "for author_id,person_info in train_author.items():\n", 333 | "    for text_id in person_info['normal_data']:  # positive samples: papers that belong to this author\n", 334 | "        feat=pid_to_info[text_id]\n", 335 | "        # columns: len(title), len(abstract), len(keywords), len(authors), len(keywords) again ('venue' may have been intended), year\n", 336 | "        try:\n", 337 | "            train_feats.append(\n", 338 | "                [len(feat['title']),len(feat['abstract']),len(feat['keywords']),len(feat['authors'])\n", 339 | "                ,len(feat['keywords']),int(feat['year'])]\n", 340 | "            )\n", 341 | "        except (KeyError, ValueError, TypeError):  # missing or non-numeric year: fall back to 2000\n", 342 | "            train_feats.append(\n", 343 | "                [len(feat['title']),len(feat['abstract']),len(feat['keywords']),len(feat['authors'])\n", 344 | "                ,len(feat['keywords']),2000]\n", 345 | "            )\n", 346 | "        labels.append(1)\n", 347 | "    for text_id in person_info['outliers']:  # negative samples: papers wrongly assigned to this author\n", 348 | "        feat=pid_to_info[text_id]\n", 349 | "        # same six columns as above\n", 350 | "        try:\n", 351 | "            train_feats.append(\n", 352 | "                [len(feat['title']),len(feat['abstract']),len(feat['keywords']),len(feat['authors'])\n", 353 | "                ,len(feat['keywords']),int(feat['year'])]\n", 354 | "            )\n", 355 | "        except (KeyError, ValueError, TypeError):  # missing or non-numeric year: fall back to 2000\n", 356 | "            train_feats.append(\n", 357 | "                [len(feat['title']),len(feat['abstract']),len(feat['keywords']),len(feat['authors'])\n", 358 | "                ,len(feat['keywords']),2000]\n", 359 | "            )\n", 360 | "        labels.append(0)\n", 361 | "train_feats=np.array(train_feats)\n", 362 | "labels=np.array(labels)\n", 363 | "print(f\"train_feats.shape:{train_feats.shape},labels.shape:{labels.shape}\")\n", 364 | "print(f\"np.mean(labels):{np.mean(labels)}\")\n", 365 | "train_feats=pd.DataFrame(train_feats)\n", 366 |
"train_feats['label']=labels\n", 367 | "train_feats.head()" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": 5, 373 | "id": "1c74b8d8", 374 | "metadata": { 375 | "execution": { 376 | "iopub.execute_input": "2024-04-14T13:00:09.334530Z", 377 | "iopub.status.busy": "2024-04-14T13:00:09.333217Z", 378 | "iopub.status.idle": "2024-04-14T13:00:09.607485Z", 379 | "shell.execute_reply": "2024-04-14T13:00:09.605825Z" 380 | }, 381 | "papermill": { 382 | "duration": 0.2841, 383 | "end_time": "2024-04-14T13:00:09.610251", 384 | "exception": false, 385 | "start_time": "2024-04-14T13:00:09.326151", 386 | "status": "completed" 387 | }, 388 | "tags": [] 389 | }, 390 | "outputs": [ 391 | { 392 | "name": "stdout", 393 | "output_type": "stream", 394 | "text": [ 395 | "valid_feats.shape:(62229, 6)\n" 396 | ] 397 | }, 398 | { 399 | "data": { 400 | "text/html": [ 401 | "
\n", 402 | "\n", 415 | "\n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | "
\n", 475 | "
" 476 | ], 477 | "text/plain": [ 478 | " 0 1 2 3 4 5\n", 479 | "0 123 0 0 10 0 2015\n", 480 | "1 100 1060 0 11 0 2018\n", 481 | "2 57 0 0 9 0 2016\n", 482 | "3 109 761 7 6 7 2001\n", 483 | "4 108 1000 5 10 5 2020" 484 | ] 485 | }, 486 | "execution_count": 5, 487 | "metadata": {}, 488 | "output_type": "execute_result" 489 | } 490 | ], 491 | "source": [ 492 | "valid_feats=[]\n", 493 | "for author_id,person_info in valid_author.items():\n", 494 | "    for text_id in person_info['papers']:\n", 495 | "        feat=pid_to_info[text_id]\n", 496 | "        # same six columns as train_feats: len(title), len(abstract), len(keywords), len(authors), len(keywords) again, year\n", 497 | "        try:\n", 498 | "            valid_feats.append(\n", 499 | "                [len(feat['title']),len(feat['abstract']),len(feat['keywords']),len(feat['authors'])\n", 500 | "                ,len(feat['keywords']),int(feat['year'])]\n", 501 | "            )\n", 502 | "        except (KeyError, ValueError, TypeError):  # missing or non-numeric year: fall back to 2000\n", 503 | "            valid_feats.append(\n", 504 | "                [len(feat['title']),len(feat['abstract']),len(feat['keywords']),len(feat['authors'])\n", 505 | "                ,len(feat['keywords']),2000]\n", 506 | "            )\n", 507 | "valid_feats=np.array(valid_feats)\n", 508 | "print(f\"valid_feats.shape:{valid_feats.shape}\")\n", 509 | "valid_feats=pd.DataFrame(valid_feats)\n", 510 | "valid_feats.head()" 511 | ] 512 | }, 513 | { 514 | "cell_type": "markdown", 515 | "id": "db2c4fd5", 516 | "metadata": { 517 | "papermill": { 518 | "duration": 0.006176, 519 | "end_time": "2024-04-14T13:00:09.622302", 520 | "exception": false, 521 | "start_time": "2024-04-14T13:00:09.616126", 522 | "status": "completed" 523 | }, 524 | "tags": [] 525 | }, 526 | "source": [ 527 | "### Train LightGBM with 10-fold cross-validation."
528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": 6, 533 | "id": "989cec24", 534 | "metadata": { 535 | "execution": { 536 | "iopub.execute_input": "2024-04-14T13:00:09.636185Z", 537 | "iopub.status.busy": "2024-04-14T13:00:09.635708Z", 538 | "iopub.status.idle": "2024-04-14T13:08:45.403174Z", 539 | "shell.execute_reply": "2024-04-14T13:08:45.401548Z" 540 | }, 541 | "papermill": { 542 | "duration": 515.778142, 543 | "end_time": "2024-04-14T13:08:45.406401", 544 | "exception": false, 545 | "start_time": "2024-04-14T13:00:09.628259", 546 | "status": "completed" 547 | }, 548 | "tags": [] 549 | }, 550 | "outputs": [ 551 | { 552 | "name": "stdout", 553 | "output_type": "stream", 554 | "text": [ 555 | "name:lgb,fold:0\n", 556 | "Training until validation scores don't improve for 100 rounds\n", 557 | "[100]\tvalid_0's auc: 0.630437\n", 558 | "[200]\tvalid_0's auc: 0.640359\n", 559 | "[300]\tvalid_0's auc: 0.645533\n", 560 | "[400]\tvalid_0's auc: 0.647981\n", 561 | "[500]\tvalid_0's auc: 0.650869\n", 562 | "[600]\tvalid_0's auc: 0.654132\n", 563 | "[700]\tvalid_0's auc: 0.656083\n", 564 | "[800]\tvalid_0's auc: 0.657789\n", 565 | "[900]\tvalid_0's auc: 0.659147\n", 566 | "[1000]\tvalid_0's auc: 0.659898\n", 567 | "[1100]\tvalid_0's auc: 0.661268\n", 568 | "[1200]\tvalid_0's auc: 0.66205\n", 569 | "[1300]\tvalid_0's auc: 0.662649\n", 570 | "[1400]\tvalid_0's auc: 0.663043\n", 571 | "[1500]\tvalid_0's auc: 0.663341\n", 572 | "[1600]\tvalid_0's auc: 0.664103\n", 573 | "[1700]\tvalid_0's auc: 0.664395\n", 574 | "[1800]\tvalid_0's auc: 0.664894\n", 575 | "[1900]\tvalid_0's auc: 0.665527\n", 576 | "[2000]\tvalid_0's auc: 0.665736\n", 577 | "[2100]\tvalid_0's auc: 0.665996\n", 578 | "[2200]\tvalid_0's auc: 0.66617\n", 579 | "[2300]\tvalid_0's auc: 0.666477\n", 580 | "[2400]\tvalid_0's auc: 0.666573\n", 581 | "[2500]\tvalid_0's auc: 0.666878\n", 582 | "[2600]\tvalid_0's auc: 0.667085\n", 583 | "[2700]\tvalid_0's auc: 0.667168\n", 584 | "[2800]\tvalid_0's 
auc: 0.667214\n", 585 | "[2900]\tvalid_0's auc: 0.66736\n", 586 | "[3000]\tvalid_0's auc: 0.667613\n", 587 | "Did not meet early stopping. Best iteration is:\n", 588 | "[3062]\tvalid_0's auc: 0.667964\n", 589 | "name:lgb,fold:1\n", 590 | "Training until validation scores don't improve for 100 rounds\n", 591 | "[100]\tvalid_0's auc: 0.615511\n", 592 | "[200]\tvalid_0's auc: 0.626335\n", 593 | "[300]\tvalid_0's auc: 0.630686\n", 594 | "[400]\tvalid_0's auc: 0.635685\n", 595 | "[500]\tvalid_0's auc: 0.63864\n", 596 | "[600]\tvalid_0's auc: 0.641397\n", 597 | "[700]\tvalid_0's auc: 0.643005\n", 598 | "[800]\tvalid_0's auc: 0.645381\n", 599 | "[900]\tvalid_0's auc: 0.646585\n", 600 | "[1000]\tvalid_0's auc: 0.647435\n", 601 | "[1100]\tvalid_0's auc: 0.64839\n", 602 | "[1200]\tvalid_0's auc: 0.649324\n", 603 | "[1300]\tvalid_0's auc: 0.650197\n", 604 | "[1400]\tvalid_0's auc: 0.651116\n", 605 | "[1500]\tvalid_0's auc: 0.651392\n", 606 | "[1600]\tvalid_0's auc: 0.651963\n", 607 | "[1700]\tvalid_0's auc: 0.652477\n", 608 | "[1800]\tvalid_0's auc: 0.65271\n", 609 | "[1900]\tvalid_0's auc: 0.653317\n", 610 | "[2000]\tvalid_0's auc: 0.653459\n", 611 | "[2100]\tvalid_0's auc: 0.653363\n", 612 | "Early stopping, best iteration is:\n", 613 | "[2027]\tvalid_0's auc: 0.653549\n", 614 | "name:lgb,fold:2\n", 615 | "Training until validation scores don't improve for 100 rounds\n", 616 | "[100]\tvalid_0's auc: 0.619258\n", 617 | "[200]\tvalid_0's auc: 0.629621\n", 618 | "[300]\tvalid_0's auc: 0.635659\n", 619 | "[400]\tvalid_0's auc: 0.639289\n", 620 | "[500]\tvalid_0's auc: 0.64285\n", 621 | "[600]\tvalid_0's auc: 0.646\n", 622 | "[700]\tvalid_0's auc: 0.648501\n", 623 | "[800]\tvalid_0's auc: 0.650791\n", 624 | "[900]\tvalid_0's auc: 0.652709\n", 625 | "[1000]\tvalid_0's auc: 0.654113\n", 626 | "[1100]\tvalid_0's auc: 0.655698\n", 627 | "[1200]\tvalid_0's auc: 0.65647\n", 628 | "[1300]\tvalid_0's auc: 0.657273\n", 629 | "[1400]\tvalid_0's auc: 0.658648\n", 630 | "[1500]\tvalid_0's 
auc: 0.659571\n", 631 | "[1600]\tvalid_0's auc: 0.659918\n", 632 | "[1700]\tvalid_0's auc: 0.660675\n", 633 | "[1800]\tvalid_0's auc: 0.661075\n", 634 | "[1900]\tvalid_0's auc: 0.661686\n", 635 | "[2000]\tvalid_0's auc: 0.662102\n", 636 | "[2100]\tvalid_0's auc: 0.662378\n", 637 | "[2200]\tvalid_0's auc: 0.662684\n", 638 | "[2300]\tvalid_0's auc: 0.662885\n", 639 | "[2400]\tvalid_0's auc: 0.663099\n", 640 | "[2500]\tvalid_0's auc: 0.663343\n", 641 | "[2600]\tvalid_0's auc: 0.663543\n", 642 | "[2700]\tvalid_0's auc: 0.663964\n", 643 | "[2800]\tvalid_0's auc: 0.66409\n", 644 | "Early stopping, best iteration is:\n", 645 | "[2756]\tvalid_0's auc: 0.664184\n", 646 | "name:lgb,fold:3\n", 647 | "Training until validation scores don't improve for 100 rounds\n", 648 | "[100]\tvalid_0's auc: 0.631423\n", 649 | "[200]\tvalid_0's auc: 0.636674\n", 650 | "[300]\tvalid_0's auc: 0.64028\n", 651 | "[400]\tvalid_0's auc: 0.642409\n", 652 | "[500]\tvalid_0's auc: 0.643911\n", 653 | "[600]\tvalid_0's auc: 0.644709\n", 654 | "[700]\tvalid_0's auc: 0.645442\n", 655 | "[800]\tvalid_0's auc: 0.64645\n", 656 | "[900]\tvalid_0's auc: 0.646913\n", 657 | "[1000]\tvalid_0's auc: 0.647276\n", 658 | "[1100]\tvalid_0's auc: 0.64738\n", 659 | "[1200]\tvalid_0's auc: 0.647635\n", 660 | "[1300]\tvalid_0's auc: 0.647786\n", 661 | "[1400]\tvalid_0's auc: 0.647506\n", 662 | "Early stopping, best iteration is:\n", 663 | "[1304]\tvalid_0's auc: 0.647869\n", 664 | "name:lgb,fold:4\n", 665 | "Training until validation scores don't improve for 100 rounds\n", 666 | "[100]\tvalid_0's auc: 0.629972\n", 667 | "[200]\tvalid_0's auc: 0.642136\n", 668 | "[300]\tvalid_0's auc: 0.648501\n", 669 | "[400]\tvalid_0's auc: 0.651197\n", 670 | "[500]\tvalid_0's auc: 0.653725\n", 671 | "[600]\tvalid_0's auc: 0.655976\n", 672 | "[700]\tvalid_0's auc: 0.657406\n", 673 | "[800]\tvalid_0's auc: 0.658913\n", 674 | "[900]\tvalid_0's auc: 0.660715\n", 675 | "[1000]\tvalid_0's auc: 0.661973\n", 676 | "[1100]\tvalid_0's auc: 
0.6624\n", 677 | "[1200]\tvalid_0's auc: 0.663025\n", 678 | "[1300]\tvalid_0's auc: 0.663392\n", 679 | "[1400]\tvalid_0's auc: 0.663679\n", 680 | "[1500]\tvalid_0's auc: 0.663811\n", 681 | "[1600]\tvalid_0's auc: 0.664095\n", 682 | "[1700]\tvalid_0's auc: 0.664327\n", 683 | "[1800]\tvalid_0's auc: 0.66479\n", 684 | "[1900]\tvalid_0's auc: 0.664854\n", 685 | "[2000]\tvalid_0's auc: 0.664933\n", 686 | "Early stopping, best iteration is:\n", 687 | "[1993]\tvalid_0's auc: 0.665025\n", 688 | "name:lgb,fold:5\n", 689 | "Training until validation scores don't improve for 100 rounds\n", 690 | "[100]\tvalid_0's auc: 0.61695\n", 691 | "[200]\tvalid_0's auc: 0.629734\n", 692 | "[300]\tvalid_0's auc: 0.634656\n", 693 | "[400]\tvalid_0's auc: 0.638703\n", 694 | "[500]\tvalid_0's auc: 0.641658\n", 695 | "[600]\tvalid_0's auc: 0.644832\n", 696 | "[700]\tvalid_0's auc: 0.647198\n", 697 | "[800]\tvalid_0's auc: 0.648909\n", 698 | "[900]\tvalid_0's auc: 0.650742\n", 699 | "[1000]\tvalid_0's auc: 0.652068\n", 700 | "[1100]\tvalid_0's auc: 0.652734\n", 701 | "[1200]\tvalid_0's auc: 0.653562\n", 702 | "[1300]\tvalid_0's auc: 0.654184\n", 703 | "[1400]\tvalid_0's auc: 0.655004\n", 704 | "[1500]\tvalid_0's auc: 0.655415\n", 705 | "[1600]\tvalid_0's auc: 0.655829\n", 706 | "[1700]\tvalid_0's auc: 0.656131\n", 707 | "[1800]\tvalid_0's auc: 0.656706\n", 708 | "[1900]\tvalid_0's auc: 0.656883\n", 709 | "[2000]\tvalid_0's auc: 0.65711\n", 710 | "[2100]\tvalid_0's auc: 0.657309\n", 711 | "[2200]\tvalid_0's auc: 0.657863\n", 712 | "[2300]\tvalid_0's auc: 0.658154\n", 713 | "[2400]\tvalid_0's auc: 0.658361\n", 714 | "[2500]\tvalid_0's auc: 0.658688\n", 715 | "[2600]\tvalid_0's auc: 0.65898\n", 716 | "[2700]\tvalid_0's auc: 0.659136\n", 717 | "[2800]\tvalid_0's auc: 0.659502\n", 718 | "[2900]\tvalid_0's auc: 0.659542\n", 719 | "Early stopping, best iteration is:\n", 720 | "[2837]\tvalid_0's auc: 0.659651\n", 721 | "name:lgb,fold:6\n", 722 | "Training until validation scores don't improve for 100 
rounds\n", 723 | "[100]\tvalid_0's auc: 0.624821\n", 724 | "[200]\tvalid_0's auc: 0.634455\n", 725 | "[300]\tvalid_0's auc: 0.640654\n", 726 | "[400]\tvalid_0's auc: 0.644759\n", 727 | "[500]\tvalid_0's auc: 0.647785\n", 728 | "[600]\tvalid_0's auc: 0.65128\n", 729 | "[700]\tvalid_0's auc: 0.653475\n", 730 | "[800]\tvalid_0's auc: 0.655492\n", 731 | "[900]\tvalid_0's auc: 0.656274\n", 732 | "[1000]\tvalid_0's auc: 0.657067\n", 733 | "[1100]\tvalid_0's auc: 0.657795\n", 734 | "[1200]\tvalid_0's auc: 0.658408\n", 735 | "[1300]\tvalid_0's auc: 0.658986\n", 736 | "[1400]\tvalid_0's auc: 0.659539\n", 737 | "[1500]\tvalid_0's auc: 0.659802\n", 738 | "[1600]\tvalid_0's auc: 0.660182\n", 739 | "[1700]\tvalid_0's auc: 0.660413\n", 740 | "[1800]\tvalid_0's auc: 0.660799\n", 741 | "[1900]\tvalid_0's auc: 0.660616\n", 742 | "Early stopping, best iteration is:\n", 743 | "[1808]\tvalid_0's auc: 0.66084\n", 744 | "name:lgb,fold:7\n", 745 | "Training until validation scores don't improve for 100 rounds\n", 746 | "[100]\tvalid_0's auc: 0.627434\n", 747 | "[200]\tvalid_0's auc: 0.633438\n", 748 | "[300]\tvalid_0's auc: 0.636971\n", 749 | "[400]\tvalid_0's auc: 0.639303\n", 750 | "[500]\tvalid_0's auc: 0.641647\n", 751 | "[600]\tvalid_0's auc: 0.644262\n", 752 | "[700]\tvalid_0's auc: 0.645511\n", 753 | "[800]\tvalid_0's auc: 0.646853\n", 754 | "[900]\tvalid_0's auc: 0.64777\n", 755 | "[1000]\tvalid_0's auc: 0.648681\n", 756 | "[1100]\tvalid_0's auc: 0.649346\n", 757 | "[1200]\tvalid_0's auc: 0.649883\n", 758 | "[1300]\tvalid_0's auc: 0.650075\n", 759 | "[1400]\tvalid_0's auc: 0.650398\n", 760 | "[1500]\tvalid_0's auc: 0.650675\n", 761 | "[1600]\tvalid_0's auc: 0.651248\n", 762 | "[1700]\tvalid_0's auc: 0.651765\n", 763 | "[1800]\tvalid_0's auc: 0.651723\n", 764 | "Early stopping, best iteration is:\n", 765 | "[1748]\tvalid_0's auc: 0.651819\n", 766 | "name:lgb,fold:8\n", 767 | "Training until validation scores don't improve for 100 rounds\n", 768 | "[100]\tvalid_0's auc: 
0.619171\n", 769 | "[200]\tvalid_0's auc: 0.633663\n", 770 | "[300]\tvalid_0's auc: 0.63856\n", 771 | "[400]\tvalid_0's auc: 0.642654\n", 772 | "[500]\tvalid_0's auc: 0.645804\n", 773 | "[600]\tvalid_0's auc: 0.648194\n", 774 | "[700]\tvalid_0's auc: 0.649465\n", 775 | "[800]\tvalid_0's auc: 0.650915\n", 776 | "[900]\tvalid_0's auc: 0.651342\n", 777 | "[1000]\tvalid_0's auc: 0.651788\n", 778 | "[1100]\tvalid_0's auc: 0.652351\n", 779 | "[1200]\tvalid_0's auc: 0.652722\n", 780 | "[1300]\tvalid_0's auc: 0.652813\n", 781 | "Early stopping, best iteration is:\n", 782 | "[1231]\tvalid_0's auc: 0.653012\n", 783 | "name:lgb,fold:9\n", 784 | "Training until validation scores don't improve for 100 rounds\n", 785 | "[100]\tvalid_0's auc: 0.627058\n", 786 | "[200]\tvalid_0's auc: 0.635942\n", 787 | "[300]\tvalid_0's auc: 0.64046\n", 788 | "[400]\tvalid_0's auc: 0.644498\n", 789 | "[500]\tvalid_0's auc: 0.646563\n", 790 | "[600]\tvalid_0's auc: 0.649646\n", 791 | "[700]\tvalid_0's auc: 0.651235\n", 792 | "[800]\tvalid_0's auc: 0.652422\n", 793 | "[900]\tvalid_0's auc: 0.653831\n", 794 | "[1000]\tvalid_0's auc: 0.654403\n", 795 | "[1100]\tvalid_0's auc: 0.655255\n", 796 | "[1200]\tvalid_0's auc: 0.655841\n", 797 | "[1300]\tvalid_0's auc: 0.656601\n", 798 | "[1400]\tvalid_0's auc: 0.657443\n", 799 | "[1500]\tvalid_0's auc: 0.657733\n", 800 | "[1600]\tvalid_0's auc: 0.658352\n", 801 | "[1700]\tvalid_0's auc: 0.6591\n", 802 | "[1800]\tvalid_0's auc: 0.659819\n", 803 | "[1900]\tvalid_0's auc: 0.660032\n", 804 | "[2000]\tvalid_0's auc: 0.660598\n", 805 | "[2100]\tvalid_0's auc: 0.660693\n", 806 | "Early stopping, best iteration is:\n", 807 | "[2058]\tvalid_0's auc: 0.660833\n", 808 | "roc_auc:0.6584834336493429\n" 809 | ] 810 | } 811 | ], 812 | "source": [ 813 | "choose_cols=[col for col in valid_feats.columns]\n", 814 | "def fit_and_predict(model,train_feats=train_feats,test_feats=valid_feats,name=0):\n", 815 | " X=train_feats[choose_cols].copy()\n", 816 | " 
y=train_feats[Config.TARGET_NAME].copy()\n", 817 | " test_X=test_feats[choose_cols].copy()\n", 818 | " oof_pred_pro=np.zeros((len(X),2))\n", 819 | " test_pred_pro=np.zeros((Config.num_folds,len(test_X),2))\n", 820 | "\n", 821 | " #stratified 10-fold cross-validation\n", 822 | " skf = StratifiedKFold(n_splits=Config.num_folds,random_state=Config.seed, shuffle=True)\n", 823 | "\n", 824 | " for fold, (train_index, valid_index) in (enumerate(skf.split(X, y.astype(str)))):\n", 825 | " print(f\"name:{name},fold:{fold}\")\n", 826 | "\n", 827 | " X_train, X_valid = X.iloc[train_index], X.iloc[valid_index]\n", 828 | " y_train, y_valid = y.iloc[train_index], y.iloc[valid_index]\n", 829 | " \n", 830 | " model.fit(X_train,y_train,eval_set=[(X_valid, y_valid)],\n", 831 | " callbacks=[log_evaluation(100),early_stopping(100)]\n", 832 | " )\n", 833 | " \n", 834 | " oof_pred_pro[valid_index]=model.predict_proba(X_valid)\n", 835 | " #predict the whole test set with this fold's model\n", 836 | " test_pred_pro[fold]=model.predict_proba(test_X)\n", 837 | " print(f\"roc_auc:{roc_auc_score(y.values,oof_pred_pro[:,1])}\")\n", 838 | " \n", 839 | " return oof_pred_pro,test_pred_pro\n", 840 | "#parameter source: https://www.kaggle.com/code/daviddirethucus/home-credit-risk-lightgbm\n", 841 | "lgb_params={\n", 842 | " \"boosting_type\": \"gbdt\",\n", 843 | " \"objective\": \"binary\",\n", 844 | " \"metric\": \"auc\",\n", 845 | " \"max_depth\": 12,\n", 846 | " \"learning_rate\": 0.05,\n", 847 | " \"n_estimators\":3072,\n", 848 | " \"colsample_bytree\": 0.9,\n", 849 | " \"colsample_bynode\": 0.9,\n", 850 | " \"verbose\": -1,\n", 851 | " \"random_state\": Config.seed,\n", 852 | " \"reg_alpha\": 0.1,\n", 853 | " \"reg_lambda\": 10,\n", 854 | " \"extra_trees\":True,\n", 855 | " 'num_leaves':64,\n", 856 | " \"verbose\": -1,\n", 857 | " \"max_bin\":255,\n", 858 | " }\n", 859 | "\n", 860 | "\n", 861 | "lgb_oof_pred_pro,lgb_test_pred_pro=fit_and_predict(model= LGBMClassifier(**lgb_params),name='lgb'\n", 862 | " )\n", 863 | "test_preds=lgb_test_pred_pro.mean(axis=0)[:,1]" 864 | ]
865 | }, 866 | { 867 | "cell_type": "markdown", 868 | "id": "6d98b5fc", 869 | "metadata": { 870 | "papermill": { 871 | "duration": 0.026982, 872 | "end_time": "2024-04-14T13:08:45.462134", 873 | "exception": false, 874 | "start_time": "2024-04-14T13:08:45.435152", 875 | "status": "completed" 876 | }, 877 | "tags": [] 878 | }, 879 | "source": [ 880 | "### Save predictions as a JSON file." 881 | ] 882 | }, 883 | { 884 | "cell_type": "code", 885 | "execution_count": 7, 886 | "id": "66ef8368", 887 | "metadata": { 888 | "execution": { 889 | "iopub.execute_input": "2024-04-14T13:08:45.518083Z", 890 | "iopub.status.busy": "2024-04-14T13:08:45.517395Z", 891 | "iopub.status.idle": "2024-04-14T13:08:45.807668Z", 892 | "shell.execute_reply": "2024-04-14T13:08:45.806687Z" 893 | }, 894 | "papermill": { 895 | "duration": 0.32143, 896 | "end_time": "2024-04-14T13:08:45.810379", 897 | "exception": false, 898 | "start_time": "2024-04-14T13:08:45.488949", 899 | "status": "completed" 900 | }, 901 | "tags": [] 902 | }, 903 | "outputs": [], 904 | "source": [ 905 | "cnt=0\n", 906 | "for id,names in submission.items():\n", 907 | " for name in names:\n", 908 | " submission[id][name]=test_preds[cnt]\n", 909 | " cnt+=1\n", 910 | "with open('baseline.json', 'w', encoding='utf-8') as f:\n", 911 | " json.dump(submission, f, ensure_ascii=False, indent=4)" 912 | ] 913 | } 914 | ], 915 | "metadata": { 916 | "kaggle": { 917 | "accelerator": "none", 918 | "dataSources": [ 919 | { 920 | "datasetId": 4794042, 921 | "sourceId": 8114719, 922 | "sourceType": "datasetVersion" 923 | } 924 | ], 925 | "dockerImageVersionId": 30684, 926 | "isGpuEnabled": false, 927 | "isInternetEnabled": true, 928 | "language": "python", 929 | "sourceType": "notebook" 930 | }, 931 | "kernelspec": { 932 | "display_name": "Python 3", 933 | "language": "python", 934 | "name": "python3" 935 | }, 936 | "language_info": { 937 | "codemirror_mode": { 938 | "name": "ipython", 939 | "version": 3 940 | }, 941 | "file_extension": ".py", 942 | "mimetype":
"text/x-python", 943 | "name": "python", 944 | "nbconvert_exporter": "python", 945 | "pygments_lexer": "ipython3", 946 | "version": "3.10.13" 947 | }, 948 | "papermill": { 949 | "default_parameters": {}, 950 | "duration": 559.868762, 951 | "end_time": "2024-04-14T13:08:48.270548", 952 | "environment_variables": {}, 953 | "exception": null, 954 | "input_path": "__notebook__.ipynb", 955 | "output_path": "__notebook__.ipynb", 956 | "parameters": {}, 957 | "start_time": "2024-04-14T12:59:28.401786", 958 | "version": "2.5.0" 959 | } 960 | }, 961 | "nbformat": 4, 962 | "nbformat_minor": 5 963 | } 964 | -------------------------------------------------------------------------------- /202406datacastle睡眠事件检测baseline/睡眠事件检测baseline(LB0.6251).ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "11a620b1", 6 | "metadata": {}, 7 | "source": [ 8 | "## Created by yunsuxiaozi 2024/6/18" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "a5750a2f", 14 | "metadata": {}, 15 | "source": [ 16 | "### Libraries" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "id": "1f75552e", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "import pandas as pd#reading csv files\n", 27 | "import numpy as np#matrix operations\n", 28 | "from tqdm import tqdm#progress bars\n", 29 | "import warnings#used to silence ignorable warnings\n", 30 | "warnings.filterwarnings('ignore')#filterwarnings() installs a warning filter that controls which warnings are shown and how; 'ignore' hides them all.\n", 31 | "\n", 32 | "import random#Python's built-in random number functions\n", 33 | "#fix the random seeds so the results are reproducible\n", 34 | "def seed_everything(seed):\n", 35 | " np.random.seed(seed)#numpy random seed\n", 36 | " random.seed(seed)#Python built-in random seed\n", 37 | "seed_everything(seed=2024)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "id": "2a3aba0b", 43 | "metadata": {}, 44 | "source": [ 45 | "### read data" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 2, 51 | "id": "30477f64", 52 |
"metadata": { 53 | "scrolled": true 54 | }, 55 | "outputs": [ 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | "text": [ 60 | "train_X.shape:(37549, 2, 180)\n", 61 | "train_y.shape:(37549,)\n", 62 | "test_X.shape:(1155, 2, 180)\n" 63 | ] 64 | }, 65 | { 66 | "data": { 67 | "text/html": [ 68 | "
" 119 | ], 120 | "text/plain": [ 121 | " id label\n", 122 | "0 0 0\n", 123 | "1 1 0\n", 124 | "2 2 0\n", 125 | "3 3 0\n", 126 | "4 4 0" 127 | ] 128 | }, 129 | "execution_count": 2, 130 | "metadata": {}, 131 | "output_type": "execute_result" 132 | } 133 | ], 134 | "source": [ 135 | "#data shape: n_samples x 2 channels (SpO2 and heart rate) x 180 points; sampled at 3 Hz, 180 points span 60 seconds\n", 136 | "path=\"\"#change this to your own data directory\n", 137 | "train_X=np.load(path+\"训练集\\\\train_x.npy\")\n", 138 | "print(f\"train_X.shape:{train_X.shape}\")\n", 139 | "train_y=np.load(path+\"训练集\\\\train_y.npy\")\n", 140 | "print(f\"train_y.shape:{train_y.shape}\")\n", 141 | "test_X=np.load(path+\"测试集A\\\\test_x_A.npy\")\n", 142 | "print(f\"test_X.shape:{test_X.shape}\")\n", 143 | "submission=pd.read_csv(path+\"测试集A\\\\submit_example_A.csv\")\n", 144 | "submission.head()" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 3, 150 | "id": "1baa04dd", 151 | "metadata": {}, 152 | "outputs": [ 153 | { 154 | "name": "stdout", 155 | "output_type": "stream", 156 | "text": [ 157 | "len(total_index):12341\n", 158 | "label:0,4600\n", 159 | "label:1,3221\n", 160 | "label:2,4520\n" 161 | ] 162 | } 163 | ], 164 | "source": [ 165 | "#To mimic the test set's class mix and make offline CV a useful guide to test performance, randomly keep 4600 label-0 samples, slightly more than label 1 or label 2.\n", 166 | "zero_index=list(np.where(train_y==0)[0])\n", 167 | "np.random.shuffle(zero_index)\n", 168 | "total_index=zero_index[:4600]+list(np.where(train_y!=0)[0])\n", 169 | "train_X=train_X[total_index]\n", 170 | "train_y=train_y[total_index]\n", 171 | "print(f\"len(total_index):{len(total_index)}\")\n", 172 | "for i in range(3):\n", 173 | " print(f\"label:{i},{np.sum(train_y==i)}\")" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "id": "b866eac5", 179 | "metadata": {}, 180 | "source": [ 181 | "### Feature engineering" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 4, 187 | "id": "72e9873e", 188 | "metadata": {}, 189 | "outputs": [ 190 | { 191 | "name": "stderr", 192 | "output_type": "stream", 193 | "text": [ 194 |
"100%|████████████████████████████████████████████████████████████████████████████| 12341/12341 [05:21<00:00, 38.40it/s]\n", 195 | "100%|██████████████████████████████████████████████████████████████████████████████| 1155/1155 [00:29<00:00, 39.74it/s]\n" 196 | ] 197 | }, 198 | { 199 | "data": { 200 | "text/html": [ 201 | "
" 367 | ], 368 | "text/plain": [ 369 | " mean_血氧/秒 mean_心率/秒 mean_血氧/秒_shift1 mean_血氧/秒_gap1 mean_血氧/秒_shift2 \\\n", 370 | "0 92.683333 49.383333 92.627119 0.084746 92.586207 \n", 371 | "1 96.000000 63.011111 96.000000 0.000000 96.000000 \n", 372 | "2 95.000000 57.700000 95.000000 0.000000 95.000000 \n", 373 | "3 95.400000 53.250000 95.406780 0.050847 95.413793 \n", 374 | "4 97.183333 70.033333 97.169492 0.033898 97.155172 \n", 375 | "\n", 376 | " mean_血氧/秒_gap2 mean_血氧/秒_shift4 mean_血氧/秒_gap4 mean_血氧/秒_shift8 \\\n", 377 | "0 0.155172 92.500000 0.303571 92.346154 \n", 378 | "1 0.000000 96.000000 0.000000 96.000000 \n", 379 | "2 0.000000 95.000000 0.000000 95.000000 \n", 380 | "3 0.091954 95.428571 0.125000 95.461538 \n", 381 | "4 0.068966 97.125000 0.142857 97.038462 \n", 382 | "\n", 383 | " mean_血氧/秒_gap8 ... median_心率/秒_gap2 median_心率/秒_shift4 \\\n", 384 | "0 0.519231 ... 0.000000 50.000000 \n", 385 | "1 0.000000 ... 0.000000 63.000000 \n", 386 | "2 0.000000 ... 0.000000 58.000000 \n", 387 | "3 0.083333 ... 0.000000 52.833333 \n", 388 | "4 0.326923 ... 
-0.166667 70.000000 \n", 389 | "\n", 390 | " median_心率/秒_gap4 median_心率/秒_shift8 median_心率/秒_gap8 \\\n", 391 | "0 0.000000 50.666667 0.666667 \n", 392 | "1 0.000000 63.000000 0.666667 \n", 393 | "2 0.000000 58.166667 0.000000 \n", 394 | "3 0.000000 52.666667 0.166667 \n", 395 | "4 -0.333333 70.000000 0.000000 \n", 396 | "\n", 397 | " median_心率/秒_shift16 median_心率/秒_gap16 median_心率/秒_shift30 \\\n", 398 | "0 49.833333 1.500000 48.000000 \n", 399 | "1 63.000000 1.000000 62.166667 \n", 400 | "2 58.333333 -1.166667 58.166667 \n", 401 | "3 52.333333 0.666667 52.000000 \n", 402 | "4 68.833333 1.000000 67.500000 \n", 403 | "\n", 404 | " median_心率/秒_gap30 label \n", 405 | "0 3.833333 0 \n", 406 | "1 1.500000 0 \n", 407 | "2 0.000000 0 \n", 408 | "3 2.166667 0 \n", 409 | "4 4.833333 0 \n", 410 | "\n", 411 | "[5 rows x 131 columns]" 412 | ] 413 | }, 414 | "execution_count": 4, 415 | "metadata": {}, 416 | "output_type": "execute_result" 417 | } 418 | ], 419 | "source": [ 420 | "#build features from train_X and test_X\n", 421 | "def get_feats(data):\n", 422 | " feats=[]\n", 423 | " for i in tqdm(range(len(data))):\n", 424 | " #data[i] has shape (2,180): row 0 is SpO2 (血氧), row 1 is heart rate (心率)\n", 425 | " spo2,heart_rate=data[i][0],data[i][1]\n", 426 | " #the signal is sampled at 3 Hz, so average every 3 readings to get one value per second\n", 427 | " origin_feats=pd.DataFrame({\"血氧/秒\":spo2.reshape(-1,3).mean(axis=1),\"心率/秒\":heart_rate.reshape(-1,3).mean(axis=1)})\n", 428 | " for col in ['血氧/秒',\"心率/秒\"]:\n", 429 | " for gap in [1,2,4,8,16,30]:\n", 430 | " origin_feats[f\"{col}_shift{gap}\"]=origin_feats[col].shift(gap)\n", 431 | " origin_feats[f\"{col}_gap{gap}\"]=origin_feats[col]-origin_feats[f\"{col}_shift{gap}\"]\n", 432 | " feats.append(list(origin_feats.mean(axis=0).values)+list(origin_feats.max(axis=0).values)+\\\n", 433 | " list(origin_feats.min(axis=0).values)+list(origin_feats.std(axis=0).values)+\\\n", 434 | " list(origin_feats.median(axis=0).values)\n", 435 | " )\n", 436 | " feats=pd.DataFrame(feats)\n", 437 | " origin_cols=list(origin_feats.columns)\n", 438 | " feats.columns=[f\"mean_{col}\"for col in
origin_cols]+[f\"max_{col}\"for col in origin_cols]\\\n", 439 | " +[f\"min_{col}\"for col in origin_cols]+[f\"std_{col}\"for col in origin_cols]+\\\n", 440 | " [f\"median_{col}\"for col in origin_cols]\n", 441 | " return feats\n", 442 | "train_feats=get_feats(train_X)\n", 443 | "train_feats['label']=train_y\n", 444 | "test_feats=get_feats(test_X)\n", 445 | "train_feats.head()" 446 | ] 447 | }, 448 | { 449 | "cell_type": "markdown", 450 | "id": "4fc6dfa8", 451 | "metadata": {}, 452 | "source": [ 453 | "### Model training" 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": 5, 459 | "id": "d3d013da", 460 | "metadata": {}, 461 | "outputs": [ 462 | { 463 | "name": "stdout", 464 | "output_type": "stream", 465 | "text": [ 466 | "name lgb,fold:0\n", 467 | "Training until validation scores don't improve for 100 rounds\n", 468 | "[100]\tvalid_0's multi_logloss: 0.704301\n", 469 | "[200]\tvalid_0's multi_logloss: 0.686409\n", 470 | "[300]\tvalid_0's multi_logloss: 0.678001\n", 471 | "[400]\tvalid_0's multi_logloss: 0.673091\n", 472 | "[500]\tvalid_0's multi_logloss: 0.668517\n", 473 | "[600]\tvalid_0's multi_logloss: 0.665939\n", 474 | "[700]\tvalid_0's multi_logloss: 0.663869\n", 475 | "[800]\tvalid_0's multi_logloss: 0.662662\n", 476 | "[900]\tvalid_0's multi_logloss: 0.661497\n", 477 | "[1000]\tvalid_0's multi_logloss: 0.661133\n", 478 | "[1100]\tvalid_0's multi_logloss: 0.660287\n", 479 | "[1200]\tvalid_0's multi_logloss: 0.659414\n", 480 | "[1300]\tvalid_0's multi_logloss: 0.659015\n", 481 | "[1400]\tvalid_0's multi_logloss: 0.659026\n", 482 | "[1500]\tvalid_0's multi_logloss: 0.658531\n", 483 | "[1600]\tvalid_0's multi_logloss: 0.658765\n", 484 | "Early stopping, best iteration is:\n", 485 | "[1508]\tvalid_0's multi_logloss: 0.658277\n", 486 | "name lgb,fold:1\n", 487 | "Training until validation scores don't improve for 100 rounds\n", 488 | "[100]\tvalid_0's multi_logloss: 0.72123\n", 489 | "[200]\tvalid_0's multi_logloss: 0.702182\n", 490 | 
"[300]\tvalid_0's multi_logloss: 0.691399\n", 491 | "[400]\tvalid_0's multi_logloss: 0.683737\n", 492 | "[500]\tvalid_0's multi_logloss: 0.679033\n", 493 | "[600]\tvalid_0's multi_logloss: 0.674514\n", 494 | "[700]\tvalid_0's multi_logloss: 0.671347\n", 495 | "[800]\tvalid_0's multi_logloss: 0.667995\n", 496 | "[900]\tvalid_0's multi_logloss: 0.665037\n", 497 | "[1000]\tvalid_0's multi_logloss: 0.663276\n", 498 | "[1100]\tvalid_0's multi_logloss: 0.661124\n", 499 | "[1200]\tvalid_0's multi_logloss: 0.660037\n", 500 | "[1300]\tvalid_0's multi_logloss: 0.658191\n", 501 | "[1400]\tvalid_0's multi_logloss: 0.657359\n", 502 | "[1500]\tvalid_0's multi_logloss: 0.655476\n", 503 | "[1600]\tvalid_0's multi_logloss: 0.654549\n", 504 | "[1700]\tvalid_0's multi_logloss: 0.653858\n", 505 | "[1800]\tvalid_0's multi_logloss: 0.653394\n", 506 | "[1900]\tvalid_0's multi_logloss: 0.652039\n", 507 | "[2000]\tvalid_0's multi_logloss: 0.651329\n", 508 | "[2100]\tvalid_0's multi_logloss: 0.650739\n", 509 | "[2200]\tvalid_0's multi_logloss: 0.650396\n", 510 | "[2300]\tvalid_0's multi_logloss: 0.650024\n", 511 | "[2400]\tvalid_0's multi_logloss: 0.649636\n", 512 | "Early stopping, best iteration is:\n", 513 | "[2348]\tvalid_0's multi_logloss: 0.649555\n", 514 | "name lgb,fold:2\n", 515 | "Training until validation scores don't improve for 100 rounds\n", 516 | "[100]\tvalid_0's multi_logloss: 0.698135\n", 517 | "[200]\tvalid_0's multi_logloss: 0.678384\n", 518 | "[300]\tvalid_0's multi_logloss: 0.670606\n", 519 | "[400]\tvalid_0's multi_logloss: 0.665394\n", 520 | "[500]\tvalid_0's multi_logloss: 0.661957\n", 521 | "[600]\tvalid_0's multi_logloss: 0.659473\n", 522 | "[700]\tvalid_0's multi_logloss: 0.656928\n", 523 | "[800]\tvalid_0's multi_logloss: 0.654736\n", 524 | "[900]\tvalid_0's multi_logloss: 0.653377\n", 525 | "[1000]\tvalid_0's multi_logloss: 0.651986\n", 526 | "[1100]\tvalid_0's multi_logloss: 0.651069\n", 527 | "[1200]\tvalid_0's multi_logloss: 0.650048\n", 528 | 
"[1300]\tvalid_0's multi_logloss: 0.649362\n", 529 | "[1400]\tvalid_0's multi_logloss: 0.648406\n", 530 | "[1500]\tvalid_0's multi_logloss: 0.647711\n", 531 | "[1600]\tvalid_0's multi_logloss: 0.647653\n", 532 | "[1700]\tvalid_0's multi_logloss: 0.647045\n", 533 | "[1800]\tvalid_0's multi_logloss: 0.646682\n", 534 | "Early stopping, best iteration is:\n", 535 | "[1776]\tvalid_0's multi_logloss: 0.646413\n", 536 | "name lgb,fold:3\n", 537 | "Training until validation scores don't improve for 100 rounds\n", 538 | "[100]\tvalid_0's multi_logloss: 0.745294\n", 539 | "[200]\tvalid_0's multi_logloss: 0.723932\n", 540 | "[300]\tvalid_0's multi_logloss: 0.714146\n", 541 | "[400]\tvalid_0's multi_logloss: 0.705944\n", 542 | "[500]\tvalid_0's multi_logloss: 0.699579\n", 543 | "[600]\tvalid_0's multi_logloss: 0.694506\n", 544 | "[700]\tvalid_0's multi_logloss: 0.691188\n", 545 | "[800]\tvalid_0's multi_logloss: 0.688066\n", 546 | "[900]\tvalid_0's multi_logloss: 0.686456\n", 547 | "[1000]\tvalid_0's multi_logloss: 0.684809\n", 548 | "[1100]\tvalid_0's multi_logloss: 0.683228\n", 549 | "[1200]\tvalid_0's multi_logloss: 0.681649\n", 550 | "[1300]\tvalid_0's multi_logloss: 0.680944\n", 551 | "[1400]\tvalid_0's multi_logloss: 0.680139\n", 552 | "[1500]\tvalid_0's multi_logloss: 0.680116\n", 553 | "[1600]\tvalid_0's multi_logloss: 0.67919\n", 554 | "[1700]\tvalid_0's multi_logloss: 0.678082\n", 555 | "[1800]\tvalid_0's multi_logloss: 0.677826\n", 556 | "Early stopping, best iteration is:\n", 557 | "[1765]\tvalid_0's multi_logloss: 0.677701\n", 558 | "name lgb,fold:4\n", 559 | "Training until validation scores don't improve for 100 rounds\n", 560 | "[100]\tvalid_0's multi_logloss: 0.721801\n", 561 | "[200]\tvalid_0's multi_logloss: 0.702028\n", 562 | "[300]\tvalid_0's multi_logloss: 0.691476\n", 563 | "[400]\tvalid_0's multi_logloss: 0.683836\n", 564 | "[500]\tvalid_0's multi_logloss: 0.678426\n", 565 | "[600]\tvalid_0's multi_logloss: 0.67422\n", 566 | "[700]\tvalid_0's 
multi_logloss: 0.670884\n", 567 | "[800]\tvalid_0's multi_logloss: 0.667678\n", 568 | "[900]\tvalid_0's multi_logloss: 0.665105\n", 569 | "[1000]\tvalid_0's multi_logloss: 0.662746\n", 570 | "[1100]\tvalid_0's multi_logloss: 0.660758\n", 571 | "[1200]\tvalid_0's multi_logloss: 0.658943\n", 572 | "[1300]\tvalid_0's multi_logloss: 0.657223\n", 573 | "[1400]\tvalid_0's multi_logloss: 0.65647\n", 574 | "[1500]\tvalid_0's multi_logloss: 0.655771\n", 575 | "[1600]\tvalid_0's multi_logloss: 0.654806\n", 576 | "[1700]\tvalid_0's multi_logloss: 0.653854\n", 577 | "[1800]\tvalid_0's multi_logloss: 0.653436\n", 578 | "[1900]\tvalid_0's multi_logloss: 0.65327\n", 579 | "[2000]\tvalid_0's multi_logloss: 0.6527\n", 580 | "[2100]\tvalid_0's multi_logloss: 0.652072\n", 581 | "[2200]\tvalid_0's multi_logloss: 0.651597\n", 582 | "[2300]\tvalid_0's multi_logloss: 0.650908\n", 583 | "[2400]\tvalid_0's multi_logloss: 0.651039\n", 584 | "Early stopping, best iteration is:\n", 585 | "[2317]\tvalid_0's multi_logloss: 0.650671\n", 586 | "name lgb,fold:5\n", 587 | "Training until validation scores don't improve for 100 rounds\n", 588 | "[100]\tvalid_0's multi_logloss: 0.747448\n", 589 | "[200]\tvalid_0's multi_logloss: 0.729605\n", 590 | "[300]\tvalid_0's multi_logloss: 0.720136\n", 591 | "[400]\tvalid_0's multi_logloss: 0.713169\n", 592 | "[500]\tvalid_0's multi_logloss: 0.708154\n", 593 | "[600]\tvalid_0's multi_logloss: 0.704069\n", 594 | "[700]\tvalid_0's multi_logloss: 0.701452\n", 595 | "[800]\tvalid_0's multi_logloss: 0.699605\n", 596 | "[900]\tvalid_0's multi_logloss: 0.698541\n", 597 | "[1000]\tvalid_0's multi_logloss: 0.697277\n", 598 | "[1100]\tvalid_0's multi_logloss: 0.69559\n", 599 | "[1200]\tvalid_0's multi_logloss: 0.694796\n", 600 | "[1300]\tvalid_0's multi_logloss: 0.695679\n", 601 | "Early stopping, best iteration is:\n", 602 | "[1203]\tvalid_0's multi_logloss: 0.69468\n", 603 | "name lgb,fold:6\n", 604 | "Training until validation scores don't improve for 100 rounds\n", 
605 | "[100]\tvalid_0's multi_logloss: 0.706165\n", 606 | "[200]\tvalid_0's multi_logloss: 0.689015\n", 607 | "[300]\tvalid_0's multi_logloss: 0.680888\n", 608 | "[400]\tvalid_0's multi_logloss: 0.675696\n", 609 | "[500]\tvalid_0's multi_logloss: 0.671986\n", 610 | "[600]\tvalid_0's multi_logloss: 0.669646\n", 611 | "[700]\tvalid_0's multi_logloss: 0.667501\n", 612 | "[800]\tvalid_0's multi_logloss: 0.665484\n", 613 | "[900]\tvalid_0's multi_logloss: 0.663935\n", 614 | "[1000]\tvalid_0's multi_logloss: 0.662532\n", 615 | "[1100]\tvalid_0's multi_logloss: 0.662434\n", 616 | "Early stopping, best iteration is:\n", 617 | "[1043]\tvalid_0's multi_logloss: 0.66206\n", 618 | "name lgb,fold:7\n", 619 | "Training until validation scores don't improve for 100 rounds\n", 620 | "[100]\tvalid_0's multi_logloss: 0.705645\n", 621 | "[200]\tvalid_0's multi_logloss: 0.682818\n", 622 | "[300]\tvalid_0's multi_logloss: 0.672784\n", 623 | "[400]\tvalid_0's multi_logloss: 0.666212\n", 624 | "[500]\tvalid_0's multi_logloss: 0.661573\n", 625 | "[600]\tvalid_0's multi_logloss: 0.657639\n", 626 | "[700]\tvalid_0's multi_logloss: 0.655245\n", 627 | "[800]\tvalid_0's multi_logloss: 0.652771\n", 628 | "[900]\tvalid_0's multi_logloss: 0.651098\n", 629 | "[1000]\tvalid_0's multi_logloss: 0.649245\n", 630 | "[1100]\tvalid_0's multi_logloss: 0.647459\n", 631 | "[1200]\tvalid_0's multi_logloss: 0.646708\n", 632 | "[1300]\tvalid_0's multi_logloss: 0.646451\n", 633 | "[1400]\tvalid_0's multi_logloss: 0.645563\n", 634 | "[1500]\tvalid_0's multi_logloss: 0.644838\n", 635 | "[1600]\tvalid_0's multi_logloss: 0.644341\n", 636 | "[1700]\tvalid_0's multi_logloss: 0.644489\n", 637 | "Early stopping, best iteration is:\n", 638 | "[1614]\tvalid_0's multi_logloss: 0.644097\n", 639 | "name lgb,fold:8\n", 640 | "Training until validation scores don't improve for 100 rounds\n", 641 | "[100]\tvalid_0's multi_logloss: 0.712087\n", 642 | "[200]\tvalid_0's multi_logloss: 0.691694\n", 643 | "[300]\tvalid_0's 
multi_logloss: 0.682639\n", 644 | "[400]\tvalid_0's multi_logloss: 0.675665\n", 645 | "[500]\tvalid_0's multi_logloss: 0.671097\n", 646 | "[600]\tvalid_0's multi_logloss: 0.668422\n", 647 | "[700]\tvalid_0's multi_logloss: 0.665623\n", 648 | "[800]\tvalid_0's multi_logloss: 0.66405\n", 649 | "[900]\tvalid_0's multi_logloss: 0.662495\n", 650 | "[1000]\tvalid_0's multi_logloss: 0.662215\n", 651 | "[1100]\tvalid_0's multi_logloss: 0.661327\n", 652 | "Early stopping, best iteration is:\n", 653 | "[1091]\tvalid_0's multi_logloss: 0.660995\n", 654 | "name lgb,fold:9\n", 655 | "Training until validation scores don't improve for 100 rounds\n", 656 | "[100]\tvalid_0's multi_logloss: 0.717289\n", 657 | "[200]\tvalid_0's multi_logloss: 0.700601\n", 658 | "[300]\tvalid_0's multi_logloss: 0.692342\n", 659 | "[400]\tvalid_0's multi_logloss: 0.686667\n", 660 | "[500]\tvalid_0's multi_logloss: 0.683012\n", 661 | "[600]\tvalid_0's multi_logloss: 0.679597\n", 662 | "[700]\tvalid_0's multi_logloss: 0.676828\n", 663 | "[800]\tvalid_0's multi_logloss: 0.674905\n", 664 | "[900]\tvalid_0's multi_logloss: 0.674092\n", 665 | "[1000]\tvalid_0's multi_logloss: 0.673409\n", 666 | "[1100]\tvalid_0's multi_logloss: 0.672145\n", 667 | "[1200]\tvalid_0's multi_logloss: 0.670768\n", 668 | "[1300]\tvalid_0's multi_logloss: 0.670022\n", 669 | "[1400]\tvalid_0's multi_logloss: 0.668904\n", 670 | "[1500]\tvalid_0's multi_logloss: 0.668818\n" 671 | ] 672 | }, 673 | { 674 | "name": "stdout", 675 | "output_type": "stream", 676 | "text": [ 677 | "[1600]\tvalid_0's multi_logloss: 0.667612\n", 678 | "[1700]\tvalid_0's multi_logloss: 0.667209\n", 679 | "[1800]\tvalid_0's multi_logloss: 0.667317\n", 680 | "[1900]\tvalid_0's multi_logloss: 0.667087\n", 681 | "Early stopping, best iteration is:\n", 682 | "[1883]\tvalid_0's multi_logloss: 0.666779\n", 683 | "accuracy_score:0.7106393323069443\n", 684 | "lgb_test_pred[:10]:[2 2 2 0 1 0 1 1 2 1]\n" 685 | ] 686 | } 687 | ], 688 | "source": [ 689 | "#model 
LightGBM classifier, with logging and early stopping to curb overfitting\n", 690 | "from lightgbm import LGBMClassifier,log_evaluation,early_stopping\n", 691 | "#metric: accuracy\n", 692 | "from sklearn.metrics import accuracy_score\n", 693 | "#KFold just splits the data into k folds; StratifiedKFold also keeps each class's proportion in every fold\n", 694 | "from sklearn.model_selection import StratifiedKFold\n", 695 | "choose_cols=[col for col in test_feats.columns]\n", 696 | "def fit_and_predict(train_feats=train_feats,test_feats=test_feats,model=None,num_folds=10,seed=2024,name='lgb'):\n", 697 | " X=train_feats[choose_cols].copy()\n", 698 | " y=train_feats['label'].copy()\n", 699 | " oof_pred=np.zeros((len(X)))\n", 700 | " test_X=test_feats[choose_cols].copy()\n", 701 | " test_pred_pro=np.zeros((num_folds,len(test_X),3))#3 = num_classes\n", 702 | " \n", 703 | " #stratified k-fold cross-validation (10 folds by default); seed makes the split reproducible\n", 704 | " skf = StratifiedKFold(n_splits=num_folds,shuffle=True,random_state=seed)\n", 705 | " for fold, (train_index, valid_index) in (enumerate(skf.split(X,y))):\n", 706 | " print(f\"name {name},fold:{fold}\")\n", 707 | "\n", 708 | " X_train, X_valid = X.iloc[train_index], X.iloc[valid_index]\n", 709 | " y_train, y_valid = y.iloc[train_index], y.iloc[valid_index]\n", 710 | " \n", 711 | " model.fit(X_train,y_train,eval_set=[(X_valid, y_valid)],\n", 712 | " callbacks=[log_evaluation(100),early_stopping(100)]\n", 713 | " )\n", 714 | " \n", 715 | " oof_pred[valid_index]=model.predict(X_valid)\n", 716 | " test_pred_pro[fold]=model.predict_proba(test_X)\n", 717 | " \n", 718 | " print(f\"accuracy_score:{accuracy_score(y.values,oof_pred)}\")\n", 719 | " #average the fold probabilities -> shape (len(test_X),3)\n", 720 | " test_pred_pro=test_pred_pro.mean(axis=0)\n", 721 | " \n", 722 | " test_preds=np.argmax(test_pred_pro,axis=1)\n", 723 | " return oof_pred,test_preds\n", 724 | "lgb_params={\n", 725 | " \"boosting_type\": \"gbdt\",\n", 726 | " \"objective\": \"multiclass\",\n", 727 | " \"metric\": \"multi_logloss\",\n", 728 | " \"max_depth\": 6,\n", 729 | " \"learning_rate\": 0.05,\n", 730 | " \"n_estimators\":10000,\n", 731 | " \"colsample_bytree\": 0.2,\n", 732 | " "
\"colsample_bynode\": 0.2,\n", 733 | " \"verbose\": -1,\n", 734 | " \"random_state\": 2024,\n", 735 | " \"reg_alpha\": 0.1,\n", 736 | " \"reg_lambda\": 10,\n", 737 | " \"extra_trees\":True,\n", 738 | " 'num_leaves':127,\n", 740 | " \"max_bin\":225,\n", 741 | " }\n", 742 | "\n", 743 | "lgb_oof_pred,lgb_test_pred=fit_and_predict(model=LGBMClassifier(**lgb_params),num_folds=10,seed=2024,name='lgb')\n", 744 | "print(f\"lgb_test_pred[:10]:{lgb_test_pred[:10]}\")" 745 | ] 746 | }, 747 | { 748 | "cell_type": "markdown", 749 | "id": "401aa1d6", 750 | "metadata": {}, 751 | "source": [ 752 | "### Submission" 753 | ] 754 | }, 755 | { 756 | "cell_type": "code", 757 | "execution_count": 6, 758 | "id": "4b2f9f8c", 759 | "metadata": {}, 760 | "outputs": [ 761 | { 762 | "data": { 763 | "text/html": [ 764 | "
" 815 | ], 816 | "text/plain": [ 817 | " id label\n", 818 | "0 0 2\n", 819 | "1 1 2\n", 820 | "2 2 2\n", 821 | "3 3 0\n", 822 | "4 4 1" 823 | ] 824 | }, 825 | "execution_count": 6, 826 | "metadata": {}, 827 | "output_type": "execute_result" 828 | } 829 | ], 830 | "source": [ 831 | "submission['label']=lgb_test_pred\n", 832 | "submission.to_csv(path+\"baseline.csv\",index=False)\n", 833 | "submission.head()" 834 | ] 835 | }, 836 | { 837 | "cell_type": "markdown", 838 | "id": "ebf05fd0", 839 | "metadata": {}, 840 | "source": [ 841 | "### Directions for further improvement:\n", 842 | "\n", 843 | "#### 1. Use all of the data. The drawback is that the training and test data then follow different distributions, so the offline CV score is no longer a meaningful reference.\n", 844 | "\n", 845 | "#### 2. When building statistical features, add q25, q75, skew, kurt and similar statistics.\n", 846 | "\n", 847 | "#### 3. Consider building interaction features between blood oxygen and heart rate (sum, difference, product, ratio), and summarize those interaction features with the same statistical aggregations.\n", 848 | "\n", 849 | "#### 4. Try ensembling models (lgb, xgb, cat).\n", 850 | "\n", 851 | "#### 5. Model the problem with deep learning, informed by the background of the competition task." 852 | ] 853 | }, 854 | { 855 | "cell_type": "code", 856 | "execution_count": null, 857 | "id": "d4bda1db", 858 | "metadata": {}, 859 | "outputs": [], 860 | "source": [] 861 | } 862 | ], 863 | "metadata": { 864 | "kernelspec": { 865 | "display_name": "Python 3 (ipykernel)", 866 | "language": "python", 867 | "name": "python3" 868 | }, 869 | "language_info": { 870 | "codemirror_mode": { 871 | "name": "ipython", 872 | "version": 3 873 | }, 874 | "file_extension": ".py", 875 | "mimetype": "text/x-python", 876 | "name": "python", 877 | "nbconvert_exporter": "python", 878 | "pygments_lexer": "ipython3", 879 | "version": "3.8.5" 880 | } 881 | }, 882 | "nbformat": 4, 883 | "nbformat_minor": 5 884 | } 885 | --------------------------------------------------------------------------------
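Items 2 and 3 of the improvement list in the notebook above (q25/q75/skew/kurt statistics, plus sum/difference/product/ratio interaction features between blood oxygen and heart rate) could be sketched roughly as follows. This is a hypothetical illustration, not the author's code. It assumes each sample's blood-oxygen and heart-rate readings are already stacked into 2-D NumPy arrays `spo2` and `hr` of shape `(n_samples, seq_len)`. The helper names `seq_stats` and `build_features` are made up for this example, and the data in the usage lines is synthetic.

```python
import numpy as np
import pandas as pd


def seq_stats(x: np.ndarray, prefix: str) -> pd.DataFrame:
    """Row-wise summary statistics of a (n_samples, seq_len) array (hypothetical helper)."""
    df = pd.DataFrame(x)
    return pd.DataFrame({
        f"{prefix}_mean": df.mean(axis=1),
        f"{prefix}_std": df.std(axis=1),              # sample std (ddof=1)
        f"{prefix}_q25": df.quantile(0.25, axis=1),
        f"{prefix}_q75": df.quantile(0.75, axis=1),
        f"{prefix}_skew": df.skew(axis=1),            # bias-corrected sample skewness
        f"{prefix}_kurt": df.kurt(axis=1),            # bias-corrected excess kurtosis
    })


def build_features(spo2: np.ndarray, hr: np.ndarray, eps: float = 1e-6) -> pd.DataFrame:
    """Stats of both raw signals plus stats of their element-wise interactions.

    Assumption: spo2 and hr are aligned arrays of identical shape (n_samples, seq_len).
    """
    spo2 = np.asarray(spo2, dtype=float)
    hr = np.asarray(hr, dtype=float)
    crosses = {
        "spo2_plus_hr": spo2 + hr,
        "spo2_minus_hr": spo2 - hr,
        "spo2_times_hr": spo2 * hr,
        "spo2_div_hr": spo2 / (hr + eps),  # eps guards against division by zero
    }
    parts = [seq_stats(spo2, "spo2"), seq_stats(hr, "hr")]
    parts += [seq_stats(seq, name) for name, seq in crosses.items()]
    return pd.concat(parts, axis=1)


# Usage on synthetic data (not competition data)
rng = np.random.default_rng(0)
spo2_demo = rng.uniform(85, 100, size=(4, 60))
hr_demo = rng.uniform(50, 100, size=(4, 60))
feats = build_features(spo2_demo, hr_demo)
print(feats.shape)  # (4, 36): 6 signal groups x 6 statistics
```

Because `choose_cols` in the notebook is taken from `test_feats.columns`, any such features would need to be added to both `train_feats` and `test_feats` before `fit_and_predict` is called.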