├── llm_rag ├── chunk.ipynb ├── cold_start.ipynb ├── multi_turn_rag.ipynb ├── fine_tune_embedding.ipynb ├── fine_tune_rerank.ipynb ├── app │ ├── init.sh │ ├── utils.py │ ├── app.py │ └── server.py ├── README.md └── incremental_update.ipynb ├── llm_batch_inference ├── README.md ├── server │ └── run_cluster.sh └── client │ └── vllm_ray_batch_inference.py ├── llm_fine_tuning ├── llm_as_judge.py ├── systnetic_data_oai.py ├── README.md └── llama-instruct-tuning-alpaca.ipynb ├── llm_base_inference ├── vllm_inference.py ├── README.md └── sgl_inference.py ├── llama_tutorials └── README.md ├── requirements.txt ├── assets └── sft-memory.jpg ├── chat_with_sql ├── nba_roster.db ├── README.md └── text2sql_demo.py ├── snippets ├── r_map │ ├── word_ggmap.R │ ├── China_map_ggmap.R │ ├── World_map_rworldmap.R │ ├── World_map_ggplot.R │ ├── China_map_bubble.R │ └── China_map_great_cicle.R ├── r_plot │ ├── binary_comparison_segment.R │ ├── stacked_area_plot.R │ └── multi bar plot.R ├── README.md ├── py_map │ ├── echarts-map.py │ ├── geopandas-map.py │ └── world_map.py ├── send_email │ └── send_email.py └── r_weather-data │ ├── weatherGet.R │ └── weather.R ├── bert_text_classification └── README.md ├── llm_agent ├── README.md └── openai_agent.py ├── .gitignore ├── time_series ├── README.md └── vehicle-sales-prediction-tensorflow-lstm.ipynb └── README.md /llm_rag/chunk.ipynb: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llm_rag/cold_start.ipynb: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llm_rag/multi_turn_rag.ipynb: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llm_batch_inference/README.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llm_fine_tuning/llm_as_judge.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llm_rag/fine_tune_embedding.ipynb: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llm_rag/fine_tune_rerank.ipynb: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llm_base_inference/vllm_inference.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llm_fine_tuning/systnetic_data_oai.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llama_tutorials/README.md: -------------------------------------------------------------------------------- 1 | # Llama models 2 | -------------------------------------------------------------------------------- /llm_rag/app/init.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | 
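5 | # This launcher is currently an empty placeholder. A minimal sketch of what it
6 | # could run, mirroring the commands documented in llm_rag/README.md (an
7 | # assumption, not the repo's actual startup script -- adjust ports as needed):
8 | #
9 | # uvicorn server:app --reload --port 8000 &
10 | # streamlit run app.py --server.port 8501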
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers
2 | numpy
3 | pandas
4 | xgboost
5 | 
--------------------------------------------------------------------------------
/assets/sft-memory.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hongyingyue/machine-learning-poc/HEAD/assets/sft-memory.jpg
--------------------------------------------------------------------------------
/chat_with_sql/nba_roster.db:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hongyingyue/machine-learning-poc/HEAD/chat_with_sql/nba_roster.db
--------------------------------------------------------------------------------
/snippets/r_map/word_ggmap.R:
--------------------------------------------------------------------------------
1 | library(maps)
2 | library(ggplot2)
3 | 
4 | map = map_data("world")
5 | ggplot(map, aes(long, lat, group=group)) +
6 |   geom_polygon(fill="white", colour="gray") +
7 |   ggtitle("Map of World")
--------------------------------------------------------------------------------
/bert_text_classification/README.md:
--------------------------------------------------------------------------------
1 | # Text classification
2 | 
3 | [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/code/brendayue/twitter-sentiment-analysis-in-pytorch)
4 | 
5 | 
6 | 
--------------------------------------------------------------------------------
/llm_agent/README.md:
--------------------------------------------------------------------------------
1 | # Agent
2 | 
3 | Read the [Agent tutorials](https://github.com/hongyingyue/LLM-agent-tutorials).
4 | 
5 | ## Quick start
6 | ```shell
7 | pip install litellm openai duckduckgo-search rich
8 | ```
9 | 
10 | ```shell
11 | export OPENAI_API_KEY="sk-xxxx"
12 | ```
13 | 
--------------------------------------------------------------------------------
/llm_base_inference/README.md:
--------------------------------------------------------------------------------
1 | # LLM Inference
2 | 
3 | ## [sglang](https://github.com/sgl-project/sglang)
4 | - uses Triton kernels
5 | 
6 | ## [vllm](https://github.com/vllm-project/vllm)
7 | - handles concurrent requests
8 | 
9 | ## General QA
10 | - warm up the service before taking traffic
11 | 
--------------------------------------------------------------------------------
/snippets/r_map/China_map_ggmap.R:
--------------------------------------------------------------------------------
1 | # setwd('./Visualization')
2 | # https://www.littlemissdata.com/blog/maps
3 | # NOTE: despite the file name, this example draws a USA map from the maps package
4 | 
5 | library(ggmap)
6 | library(ggplot2)
7 | library(maps)
8 | library(mapdata)
9 | 
10 | usa <- map_data("usa")
11 | ggplot() +
12 |   geom_polygon(data = usa, aes(x=long, y = lat, group = group), fill = NA, color = "red") +
13 |   coord_fixed(1.3)
--------------------------------------------------------------------------------
/llm_fine_tuning/README.md:
--------------------------------------------------------------------------------
1 | ## LLM SFT
2 | 
3 | ### 1. Synthetic Data Generation
4 | 
5 | 
6 | ### 2. Instruction tuning
7 | The CUDA memory footprint during training:
8 | ![](../assets/sft-memory.jpg)
9 | 
10 | [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/code/brendayue/llama-instruct-tuning-alpaca/notebook?scriptVersionId=240320418)
11 | 
12 | 
13 | ### 3. LLM evaluation
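14 | 
15 | One common recipe is LLM-as-a-judge (cf. the `llm_as_judge.py` placeholder). The sketch below is an assumption of how such a judge could be wired up with the OpenAI client; the rubric, model name, and 1-5 scale are illustrative, not this repo's actual implementation:
16 | 
17 | ```python
18 | from openai import OpenAI
19 | 
20 | client = OpenAI()  # reads OPENAI_API_KEY from the environment
21 | 
22 | JUDGE_PROMPT = """Rate the answer on a 1-5 scale for helpfulness and factual accuracy.
23 | Respond with only the number.
24 | 
25 | Question: {question}
26 | Answer: {answer}"""
27 | 
28 | def judge(question: str, answer: str) -> int:
29 |     # Ask the judge model for a single integer score
30 |     response = client.chat.completions.create(
31 |         model="gpt-4o-mini",
32 |         messages=[{"role": "user", "content": JUDGE_PROMPT.format(question=question, answer=answer)}],
33 |         temperature=0.0,
34 |     )
35 |     return int(response.choices[0].message.content.strip())
36 | ```
37 | 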
--------------------------------------------------------------------------------
/llm_rag/README.md:
--------------------------------------------------------------------------------
1 | # RAG
2 | 
3 | ## LangChain POC
4 | 
5 | 
6 | ## LlamaIndex POC
7 | 
8 | 
9 | ## Basic app
10 | ```shell
11 | pip install uvicorn fastapi python-multipart streamlit PyMuPDF
12 | ```
13 | 
14 | ```shell
15 | cd app
16 | 
17 | # Start FastAPI backend
18 | uvicorn server:app --reload --port 8000
19 | 
20 | # Start Streamlit frontend
21 | streamlit run app.py --server.port 8501
22 | ```
23 | 
--------------------------------------------------------------------------------
/llm_rag/incremental_update.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "id": "717b2451",
6 |    "metadata": {},
7 |    "source": [
8 |     "# How to Handle Incremental Updates to Indexed Data in RAG?\n",
9 |     "- https://github.com/microsoft/graphrag/issues/741"
10 |    ]
11 |   }
12 |  ],
13 |  "metadata": {
14 |   "language_info": {
15 |    "name": "python"
16 |   }
17 |  },
18 |  "nbformat": 4,
19 |  "nbformat_minor": 5
20 | }
21 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__/
2 | **/.ipynb_checkpoints/
3 | **/.idea/*
4 | **/.pytest_cache/
5 | **/.DS_Store
6 | /reference/*
7 | /backup/*
8 | **/nohup.out
9 | **/temp.py
10 | /data/logs/*
11 | /data/*.zip
12 | /data/raw/*
13 | /data/web/*
14 | 
15 | /weights/scaler.pkl
16 | /weights/saved_model.pb
17 | /weights/variables/*
18 | /weights/checkpoint
19 | /weights/checkpoint.data-00000-of-00001
20 | /weights/checkpoint.index
21 | 
22 | /llm_rag/app/uploaded_files
23 | 
24 | .env
--------------------------------------------------------------------------------
/chat_with_sql/README.md:
--------------------------------------------------------------------------------
1 | # text2sql
2 | 
3 | ## Quick start
4 | ```shell
5 | pip install langchain_community --quiet
6 | pip install "langchain[openai]" --quiet
7 | ```
8 | 
9 | ```shell
10 | python text2sql_demo.py
11 | ```
12 | 
13 | **Output**
14 | ```text
15 | SELECT "Team" FROM nba_roster WHERE "NAME" = 'Stephen Curry';
16 | 
17 | [('Golden State Warriors',)]
18 | 
19 | SELECT "SALARY" FROM nba_roster WHERE "NAME" = 'Stephen Curry';
20 | 
21 | [('$51,915,615',)]
22 | ```
23 | 
--------------------------------------------------------------------------------
/time_series/README.md:
--------------------------------------------------------------------------------
1 | # 📈 Time Series Prediction: Common Methods Overview
2 | 
3 | ## 1. ARIMA (AutoRegressive Integrated Moving Average)
4 | 
5 | 
6 | ## 2. Prophet
7 | 
8 | 
9 | ## 3. XGBoost / Gradient Boosting Trees
10 | - [production-ready XGBoost implementation](https://github.com/hongyingyue/Vehicle-sales-predictor)
11 | 
12 | 
13 | ## 4. 
Deep Learning (LSTM, GRU, Transformer, TCN) 14 | 15 | [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/code/brendayue/vehicle-sales-prediction-tensorflow) 16 | 17 | 18 | - [DeepAR] 19 | - [LSTM](./vehicle-sales-prediction-tensorflow-lstm.ipynb) 20 | - [Transformer] 21 | -------------------------------------------------------------------------------- /llm_rag/app/utils.py: -------------------------------------------------------------------------------- 1 | import fitz 2 | from dotenv import load_dotenv 3 | import os 4 | from openai import OpenAI 5 | 6 | def extract_text_from_pdf(file_path): 7 | doc = fitz.open(file_path) 8 | text = "\n".join(page.get_text() for page in doc) 9 | return text 10 | 11 | 12 | # OpenAI API 13 | load_dotenv() 14 | api_key = os.getenv("OPENAI_API_KEY") 15 | client = OpenAI() 16 | 17 | def get_completion(prompt, 18 | model="gpt-3.5-turbo", 19 | temperature = 0.5): 20 | messages = [{"role": "user", "content": prompt}] 21 | response = client.chat.completions.create( 22 | model = model, 23 | messages = messages, 24 | temperature = temperature 25 | ) 26 | return response.choices[0].message.content -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Machine learning projects 2 | [![LICENSE](https://img.shields.io/badge/license-Anti%20996-blue.svg)](https://github.com/996icu/996.ICU/blob/master/LICENSE)
3 | 4 | Some Proof-of-Concept projects for machine learning.
5 | Contributions are welcome.
6 | 
7 | 
8 | ## Get started
9 | 
10 | - [Text classification with BERT for Twitter](./bert_text_classification/)
11 | - [LLM RAG for arXiv](./llm_rag/)
12 | - [LLM Agent for code](./llm_agent/)
13 | - [LLM fine tuning](./llm_fine_tuning/)
14 | - [LLM text2sql](./chat_with_sql/)
15 | - [LLM batch inference for multiple nodes](./llm_batch_inference/)
16 | - [LLM application](./llm_app/)
17 | - [Llama tutorials](./llama_tutorials/)
18 | - [Time series LSTM/Transformer](./time_series/)
19 | - [Map with R/Python](./snippets/)
20 | 
--------------------------------------------------------------------------------
/snippets/r_plot/binary_comparison_segment.R:
--------------------------------------------------------------------------------
1 | 
2 | setwd('./00-Code/Visualization')
3 | library("ggplot2")
4 | library("ggthemes")
5 | library("reshape2")
6 | library("plyr")
7 | 
8 | Da<-read.csv("./data/binary_com.csv",sep=',',header=T)
9 | Da<-data.frame(Da)
10 | 
11 | p<-ggplot(Da)+coord_flip()
12 | p<-p+geom_segment(aes(Items,Label1,xend=Items,yend=Label2),linetype='solid',size=5,lineend='round',color='light blue')
13 | p<-p+geom_point(aes(x=Items,y=Label1),color='blue',size=5,shape=16)
14 | p<-p+geom_point(aes(x=Items,y=Label2),color='blue',size=5,shape=21,fill='white')
15 | p<-p+geom_text(aes(x=Items,y=Label1,label=Label1),hjust=2)
16 | p<-p+geom_text(aes(x=Items,y=Label2,label=Label2),hjust=-2)
17 | p<-p+theme_wsj(color="brown")
18 | ggsave("./outputs/Binary comparison.png", p, height=4.8, width=9.5,bg= "transparent")
19 | 
--------------------------------------------------------------------------------
/snippets/README.md:
--------------------------------------------------------------------------------
1 | # Code snippets in R/Python
2 | 
3 | 
4 | ## Map
5 | If you need shapefiles for specific Chinese provinces, you can download them from [NGCC](http://www.webmap.cn/mapDataAction.do?method=forw&resType=5&storeId=2&storeName=%E5%9B%BD%E5%AE%B6%E5%9F%BA%E7%A1%80%E5%9C%B0%E7%90%86%E4%BF%A1%E6%81%AF%E4%B8%AD%E5%BF%83)
6 | 
7 | ```bash
8 | Rscript China_map_bubble.R
9 | Rscript China_map_great_cicle.R
10 | ```
11 | 
12 | ## Weather
13 | - Get the data with the RNCEP package. Further reading: [this post](https://dominicroye.github.io/en/2018/access-to-climate-reanalysis-data-from-r/)
14 | 
15 | - Check the official [China weather API](http://data.cma.cn/Market/MarketList.html), or download directly over FTP from [NOAA](https://www.esrl.noaa.gov/psd/data/gridded/help.html#FTP)
16 | 
17 | ```r
18 | install.packages("RNCEP")
19 | # On macOS, install XQuartz first (`brew install --cask xquartz`) and restart
20 | ```
21 | 
--------------------------------------------------------------------------------
/snippets/r_map/World_map_rworldmap.R:
--------------------------------------------------------------------------------
1 | setwd('./Visualization')
2 | library(rworldmap)
3 | Market<-read.table("./data/map_R_world.csv",header=T,sep=',',skip=2)
4 | Market<-data.frame(Market)
5 | names(Market)<-c("Country_code","Market","Metrics","Claims","Production_number","Complaint_rate")
6 | Market$Production_number<-as.numeric(Market$Production_number)
7 | Market<-Market[order(-Market$Production_number),]
8 | #Market<-Market[1:20,]
9 | Market$Market<-gsub("Korea (South)","South Korea",Market$Market,fixed=TRUE) # fixed=TRUE: "(" and ")" are regex metacharacters, so match the literal string
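10 | # joinCountryData2Map() attaches the data frame to rworldmap's country polygons by name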
11 | 
12 | SPDF<-joinCountryData2Map(Market,joinCode='NAME',nameJoinColumn='Market')
13 | mapBubbles(SPDF, nameZSize="Complaint_rate",nameZColour="red",oceanCol="lightblue",landCol="beige")
14 | 
15 | identifyCountries(getMap(), nameColumnToPlot="category") # click a country to add its label
16 | 
17 | # TODO: rename the 'Salesmarket' and 'Complaint_rate' columns, convert the percentages to numbers, and rename to 'South Korea'
--------------------------------------------------------------------------------
/snippets/py_map/echarts-map.py:
--------------------------------------------------------------------------------
1 | 
2 | # https://github.com/pyecharts/pyecharts
3 | # An Internet connection is required: the map assets are fetched online.
4 | 
5 | from pyecharts.charts import Map, Geo
6 | from pyecharts import options as opts
7 | 
8 | 
9 | province_dict = {'河南': 45.23, '北京': 37.56, '河北': 21, '辽宁': 12, '江西': 6, '上海': 20, '安徽': 10, '江苏': 16, '湖南': 9,
10 |                  '浙江': 13, '海南': 2, '广东': 22, '湖北': 8, '黑龙江': 11, '澳门': 1, '陕西': 11, '四川': 7,
11 |                  '内蒙古': 3, '重庆': 3, '云南': 6, '贵州': 2, '吉林': 3, '山西': 12, '山东': 11, '福建': 4, '青海': 1,
12 |                  '天津': 1, '其他': 1}
13 | 
14 | province_char = [[item[0], item[1]] for item in province_dict.items()]
15 | print(province_char)
16 | 
17 | china_map = Map(init_opts=opts.InitOpts(width='1200px', height='800px'))
18 | china_map.set_global_opts(
19 |     title_opts=opts.TitleOpts(title="2019年"),
20 |     visualmap_opts=opts.VisualMapOpts(max_=50))
21 | china_map.add("China Map Example", data_pair=province_char, maptype='china', is_roam=True)
22 | china_map.render(path="中国地图.html")
--------------------------------------------------------------------------------
/snippets/send_email/send_email.py:
--------------------------------------------------------------------------------
1 | # Import smtplib for the actual sending function
2 | import smtplib
3 | from email.mime.text import MIMEText
4 | from email.header import Header
5 | from email.mime.image import MIMEImage
6 | from email.mime.multipart import MIMEMultipart
7 | 
8 | 
9 | def send_email(subject="No subject", content="I am boring"):
10 |     mail_host = "smtp.163.com"
11 |     mail_user = "yuetan@163.com"
12 |     mail_pw = "********"  # SMTP authorization code, not the account password
13 |     sender = "yuetan@163.com"
14 |     receiver = "yuetan@163.com"
15 | 
16 |     # Build a plain-text email message.
17 | msg = MIMEText(content, "plain", "utf-8") 18 | msg['Subject'] = subject 19 | msg['From'] = sender 20 | msg['To'] = receiver 21 | 22 | try: 23 | smtp = smtplib.SMTP_SSL(mail_host, 994) # 实例化smtp服务器 24 | smtp.login(mail_user, mail_pw) # 登录 25 | smtp.sendmail(sender, receiver, msg.as_string()) 26 | print("Email send successfully") 27 | except smtplib.SMTPException: 28 | print("Error: email send failed") 29 | 30 | 31 | if __name__ == '__main__': 32 | send_email(subject="Training finished", content="I am boring") 33 | -------------------------------------------------------------------------------- /snippets/r_weather-data/weatherGet.R: -------------------------------------------------------------------------------- 1 | install.packages("RNCEP") 2 | 3 | 4 | library(lubridate) 5 | library(RNCEP) 6 | 7 | 8 | get_weather=function(RYear,RMonth,Lon,Lat){ 9 | wx.extent <- NCEP.gather(variable='air',level=850,months.minmax=c(RMonth,RMonth),years.minmax=c(RYear,RYear),lat.southnorth=c(Lat,Lat), lon.westeast=c(Lon,Lon),reanalysis2=FALSE,return.units=TRUE)-273.15 10 | wx.ag <- NCEP.aggregate(wx.data=wx.extent, YEARS=TRUE, MONTHS=TRUE,DAYS=TRUE, HOURS=FALSE, fxn='mean') 11 | wx <- NCEP.array2df(wx.ag, var.names=NULL) 12 | w <- wx[1,4] 13 | return(w) 14 | } 15 | 16 | w=get_weather(2017,1,103,31) 17 | 18 | View(w) 19 | 20 | 21 | wx.extent1 <- NCEP.gather(variable='air', level=850, 22 | months.minmax=c(9,10), years.minmax=c(1998,1998), 23 | lat.southnorth=c(50,51), lon.westeast=c(5,6), 24 | reanalysis2 = FALSE, return.units = TRUE) 25 | 26 | View(wx.extent1) 27 | 28 | dimnames(wx.extent1)[[1]] 29 | dimnames(wx.extent1)[[2]] 30 | dimnames(wx.extent1)[[3]] 31 | 32 | class(wx.extent1) 33 | 34 | 35 | wx.df <- NCEP.array2df(wx.extent1[,,]) 36 | View(wx.df) 37 | -------------------------------------------------------------------------------- /llm_rag/app/app.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | import requests 3 | import streamlit as st 4 | 5 | BACKEND_URL = "http://localhost:8000" 6 | 7 | if "session_id" not in st.session_state: 8 | st.session_state.session_id = str(uuid.uuid4()) 9 | 10 | st.title("📁 Chat with Your File") 11 | 12 | # File Upload 13 | uploaded_file = st.file_uploader("Upload a file", type=["txt", "pdf", "csv", "docx"]) 14 | 15 | if uploaded_file: 16 | st.session_state["uploaded"] = True 17 | response = requests.post( 18 | f"{BACKEND_URL}/upload", 19 | files={"file": (uploaded_file.name, uploaded_file)}, 20 | data={"session_id": st.session_state.session_id} 21 | ) 22 | st.success(response.json()["message"]) 23 | 24 | # Chat 25 | if st.session_state.get("uploaded", False): 26 | user_input = st.text_input("Ask a question about the file:") 27 | if user_input: 28 | res = requests.post( 29 | f"{BACKEND_URL}/chat", 30 | data={"session_id": st.session_state.session_id, "user_input": user_input} 31 | ) 32 | if res.status_code == 200: 33 | for turn in res.json()["history"]: 34 | st.markdown(f"**You**: {turn['user']}") 35 | st.markdown(f"**Bot**: {turn['bot']}") 36 | else: 37 | st.error(res.json()["error"]) 38 | -------------------------------------------------------------------------------- /snippets/r_map/World_map_ggplot.R: -------------------------------------------------------------------------------- 1 | setwd('./Visualization') 2 | library("plyr") 3 | library(ggplot2) 4 | library(maptools) 5 | 6 | world_map <-readShapePoly("./data/World_countries_shp.shp") 7 | x <- world_map@data 8 | xs <- data.frame(x,id=seq(0:238)-1) 9 | 
world_map1 <- fortify(world_map) 10 | world_map_data <- join(world_map1, xs, type = "full") 11 | 12 | Mydata<-read.csv("./data/map_R_country.csv",sep=',',header=T) 13 | names(Mydata)<-c("NAME","Failures") 14 | #replace USA to United States 15 | Worlddata=join(Mydata, world_map_data, type = "full") 16 | 17 | theme_map <- list(theme(panel.grid.minor = element_blank(), 18 | panel.grid.major = element_blank(), 19 | panel.border = element_blank(), 20 | axis.line = element_blank(), 21 | axis.text.x = element_blank(), 22 | axis.text.y = element_blank(), 23 | axis.ticks = element_blank(), 24 | axis.title.x = element_blank(), 25 | axis.title.y = element_blank(), 26 | panel.background = element_blank(), 27 | plot.background = element_blank())) 28 | 29 | p<-ggplot(Worlddata, aes(x = long, y = lat, group = group,fill =Failures)) +geom_polygon(color = 'black',size=0.1,alpha=0.1) 30 | p<-p+scale_fill_brewer(palette="YlOrRd")+theme_map 31 | p<-p+theme(legend.position='none') 32 | ggsave("./outputs/World map ggplot.png", p, height=4.8, width=9.5) 33 | 34 | #unique(world_map_data$NAME) 35 | #+coord_cartesian(xlim=c(60,155),ylim=c(0,65)) #china near area -------------------------------------------------------------------------------- /snippets/py_map/geopandas-map.py: -------------------------------------------------------------------------------- 1 | # 矢量地图shp文件,注意完整文件包括同名的 shp,shx,dbf三个文件,名字和路径需要相同 2 | # https://tianchi.aliyun.com/notebook-ai/detail?spm=5176.12586969.1002.6.72e87f7doOyHCP&postId=63248 3 | 4 | import pandas as pd 5 | import geopandas as gp 6 | from matplotlib import pyplot as plt 7 | import matplotlib 8 | import seaborn as sns 9 | 10 | matplotlib.rc('figure', figsize=(14, 7)) 11 | matplotlib.rc('font', size=14) 12 | matplotlib.rc('axes', grid=False) 13 | matplotlib.rc('axes', facecolor='white') 14 | 15 | 16 | def geo_china(df, size_column='size', title='map'): 17 | china_geod = gp.GeoDataFrame.from_file("../../assets/province.shp", encoding='gb18030') 18 | 19 | ax = china_geod.plot(color='white', edgecolor='black') 20 | 21 | df.plot(ax=ax, color='red', markersize=df[size_column], alpha=0.5) 22 | # df.plot(ax=ax, column=size_column, cmap='Blues', linewidth=0.8, edgecolor='0.8') 23 | ax.set_axis_off() 24 | plt.show() 25 | 26 | 27 | if __name__ == '__main__': 28 | data = pd.read_csv("../../data/map_R_province.csv") 29 | data_agg = data.groupby(['City'])[['City']].size().reset_index(name='size') 30 | 31 | city_name_map = pd.read_csv(open("../../assets/city_geocode_lookup.csv", encoding='gbk')) 32 | 33 | data_agg = data_agg.merge(city_name_map, left_on='City', right_on='City', how='left') 34 | 35 | geo_data_agg = gp.GeoDataFrame(data_agg, geometry=gp.points_from_xy(data_agg.Lon, data_agg.Lat)) 36 | 37 | geo_china(geo_data_agg, size_column='size') 38 | -------------------------------------------------------------------------------- /snippets/r_plot/stacked_area_plot.R: -------------------------------------------------------------------------------- 1 | #To be updated: use + theme_bw() 2 | 3 | # setwd('./06-QFS_month') 4 | 5 | library(reshape2) 6 | library(ggplot2) 7 | library(gridExtra) 8 | 9 | Fre<-read.table("QFS_monthly_prognosis1.csv",sep=',',header=T) 10 | Fre$Time<-format(as.Date(paste(Fre$Time,".01"),"%Y%m .%d"),format="%y-%m") 11 | DA1<-melt(Fre,id.vars="Time",variable.name="Frequency",value.name="Complaints_per_hundred") 12 | DA1<-data.frame(DA1) 13 | colors1<-c("grey","blue","light green","pink") 14 | p1<-ggplot(DA1,aes(x=Time,y=Complaints_per_hundred,fill=Frequency,group=Frequency))+ 
15 |   geom_area(alpha=0.5,position='stack',color='black',size=0.1)+scale_fill_manual(values = colors1)
16 | p1<-p1+xlab("Complaint Month")
17 | p1<-p1+scale_y_continuous("Complaint Rate [%]",labels=scales::percent)
18 | p1<-p1+theme(axis.text.y = element_text(size = 12),axis.text.x = element_text(size = 12))
19 | p1<-p1+theme(axis.title.y = element_text(size = 12),axis.title.x = element_text(size = 12))
20 | p1<-p1+theme(axis.line.x = element_line(color="black", size = 0.3),axis.line.y = element_line(color="black", size = 0.3))
21 | p1<-p1+theme(legend.justification=c(0.04,0.95),legend.position=c(0.04,0.95),panel.grid.major=element_blank())
22 | p1<-p1+theme(legend.background = element_rect(fill=alpha('white', 0.1)),legend.text=element_text(size=6),legend.key.size=unit(0.4,'cm'),legend.key.width=unit(0.4,'cm'))
23 | p1<-p1+theme(legend.title=element_blank())+annotate("rect", xmin="18-12", xmax="19-06", ymin=0, ymax=Inf, alpha=0.22, fill="grey")
24 | ggsave("plot.png", p1, height=4.2, width=8.5)
25 | 
26 | 
27 | 
28 | 
29 | 
--------------------------------------------------------------------------------
/snippets/py_map/world_map.py:
--------------------------------------------------------------------------------
1 | # Which map visualization tools or Python libraries can draw true-scale scatter maps? - answer by Shan Ye (叶山) on Zhihu
2 | # https://www.zhihu.com/question/404165841/answer/1310033961
3 | 
4 | import numpy as np
5 | import pandas as pd
6 | import geopandas as gp
7 | from matplotlib import pyplot as plt
8 | import matplotlib
9 | import seaborn as sns
10 | matplotlib.rc('figure', figsize=(14, 7))
11 | matplotlib.rc('font', size=14)
12 | matplotlib.rc('axes', grid=False)
13 | matplotlib.rc('axes', facecolor='white')
14 | 
15 | 
16 | def geod_world(df, title, legend=False):
17 |     world_geod = gp.GeoDataFrame.from_file('../../assets/World_countries_shp.shp')
18 | 
19 |     data_geod = gp.GeoDataFrame(df)  # convert to a GeoDataFrame
20 |     da_merge = world_geod.merge(data_geod, on='NAME', how='left')  # merge on country name
21 |     sum(np.isnan(da_merge['NUM']))
22 |     da_merge['NUM'][np.isnan(da_merge['NUM'])] = 14.0  # fill missing values
23 |     da_merge.plot('NUM', k=20, cmap=plt.cm.Blues, alpha=1, legend=legend)
24 |     plt.title(title, fontsize=15)  # set the chart title
25 |     plt.gca().xaxis.set_major_locator(plt.NullLocator())  # remove x-axis ticks
26 |     plt.gca().yaxis.set_major_locator(plt.NullLocator())  # remove y-axis ticks
27 | 
28 | 
29 | country_dict = {'大陆': 'China', '美国': 'United States', '香港': 'Hong Kong'
30 |                 , '台湾': 'Taiwan, Province of China'
31 |                 , '日本': 'Japan', '韩国': 'Korea, Republic of', '英国': 'United Kingdom'
32 |                 , '法国': 'France', '德国': 'Germany'
33 |                 , '意大利': 'Italy', '西班牙': 'Spain', '印度': 'India', '泰国': 'Thailand'
34 |                 , '俄罗斯': 'Russian Federation'
35 |                 , '伊朗': 'Iran', '加拿大': 'Canada', '澳大利亚': 'Australia'
36 |                 , '爱尔兰': 'Ireland', '瑞典': 'Sweden'
37 |                 , '巴西': 'Brazil', '丹麦': 'Denmark'}
38 | 
39 | # `temp` is assumed to be a pandas Series of per-country counts computed upstream (not shown here)
40 | temp0 = temp.reset_index()
41 | df = pd.DataFrame({'NAME': temp0['index'].map(country_dict).tolist()
42 |                       , 'NUM': (np.log1p(temp0['数目']) * 100).tolist()})  # '数目' = 'count' column
43 | geod_world(df, 'Movie popularity around the world')
--------------------------------------------------------------------------------
/snippets/r_map/China_map_bubble.R:
--------------------------------------------------------------------------------
1 | # setwd('./Visualization')
2 | library(plyr)
3 | library(ggplot2)
4 | library(maptools)
5 | library(reshape2)
6 | library(gridExtra)
7 | library(rgdal)
8 | 
9 | # import the data
10 | Data<-as.data.frame(read.csv("../../data/map_R_province.csv",sep=';',header=T))
11 | 
12 | # aggregate: count complaints per city
13 | 
Data_agg<-ddply(Data,.(City),summarize,Complaints_value=length(City)) 14 | City_code<-read.table("../../assets/city_geocode_lookup.csv", sep=',', header=T) 15 | Da<-merge(Data_agg,City_code,by.x='City',by.y='City') 16 | 17 | #Import the map data 18 | province<-readShapePoly("../../assets/province.shp") 19 | chinamap<-fortify(province) 20 | provincedata<-data.frame(province@data,id=seq(0:924)-1) 21 | china_mapdata<-join(chinamap, provincedata, type = "full") 22 | 23 | #plot 24 | p1<-ggplot(data = chinamap) + geom_path(aes(x = long, y = lat, group = id),size=0.2, colour="black")+coord_map(ylim = c(14,55)) 25 | p1<-p1 + geom_polygon(aes(x=long,y=lat,group=id),fill = 'grey90',size=0.1, alpha=0.4) 26 | p1<-p1 + geom_point(data=Da, aes(x=Lon, y=Lat, size=Complaints_value),color='red',alpha=0.45)+scale_size(range = c(0,4)) 27 | p1<-p1 + annotate("text",x = 84,y =20, label = "YueTan",family = "serif", fontface = "italic", colour = "black", size = 4) 28 | 29 | theme_map <- list(theme(panel.grid.minor = element_blank(), 30 | panel.grid.major = element_blank(), 31 | panel.border = element_blank(), 32 | axis.line = element_blank(), 33 | axis.text.x = element_blank(), 34 | axis.text.y = element_blank(), 35 | axis.ticks = element_blank(), 36 | axis.title.x = element_blank(), 37 | axis.title.y = element_blank(), 38 | panel.background = element_rect(fill="transparent"), 39 | plot.background = element_rect(fill = "transparent"), 40 | legend.background = element_rect(fill = "transparent"), 41 | legend.box.background = element_rect(fill = "transparent"))) 42 | 43 | p1<-p1+theme_map+ggtitle('Welcome to follow')+theme(plot.title = element_text(family = 'Helvetica',face = "bold")) 44 | ggsave("../../assets/images/China map bubble plot.png", p1, height=4.8, width=9.5) 45 | -------------------------------------------------------------------------------- /llm_rag/app/server.py: -------------------------------------------------------------------------------- 1 | import os 2 | from fastapi import FastAPI, UploadFile, File, Form 3 | from fastapi.middleware.cors import CORSMiddleware 4 | from fastapi.responses import JSONResponse 5 | from utils import extract_text_from_pdf, get_completion 6 | 7 | app = FastAPI() 8 | UPLOAD_FOLDER = "uploaded_files" 9 | os.makedirs(UPLOAD_FOLDER, exist_ok=True) 10 | 11 | # Allow frontend requests 12 | app.add_middleware( 13 | CORSMiddleware, 14 | allow_origins=["*"], 15 | allow_methods=["*"], 16 | allow_headers=["*"], 17 | ) 18 | 19 | # Store chat histories per session (simple memory) 20 | chat_histories = {} 21 | 22 | 23 | @app.post("/upload") 24 | async def upload_file(file: UploadFile = File(...), session_id: str = Form(...)): 25 | try: 26 | file_path = os.path.join(UPLOAD_FOLDER, f"{session_id}_{file.filename}") 27 | with open(file_path, "wb") as f: 28 | f.write(await file.read()) 29 | chat_histories[session_id] = {"file": file_path, "history": []} 30 | return {"message": f"File '{file.filename}' uploaded successfully."} 31 | except Exception as e: 32 | return JSONResponse(status_code=500, content={"error": str(e)}) 33 | 34 | 35 | @app.post("/chat") 36 | async def chat(session_id: str = Form(...), user_input: str = Form(...)): 37 | history = chat_histories.get(session_id) 38 | if not history: 39 | return JSONResponse(status_code=400, content={"error": "Session not found or file not uploaded."}) 40 | 41 | file_path = history["file"] 42 | filename = os.path.basename(file_path) 43 | 44 | # Determine how to read the file 45 | if filename.endswith(".pdf"): 46 | content = 
extract_text_from_pdf(file_path) 47 | else: 48 | with open(file_path, "r", encoding="utf-8", errors="ignore") as f: 49 | content = f.read() 50 | 51 | # Dummy logic: replace with LLM logic 52 | # response = f"File has {len(content.splitlines())} lines. You asked: {user_input}" 53 | response = get_completion(user_input) 54 | history["history"].append({"user": user_input, "bot": response}) 55 | return {"response": response, "history": history["history"]} 56 | -------------------------------------------------------------------------------- /llm_batch_inference/server/run_cluster.sh: -------------------------------------------------------------------------------- 1 | 2 | # Extract the mandatory positional arguments and remove them from $@. 3 | DOCKER_IMAGE="$1" 4 | HEAD_NODE_ADDRESS="$2" 5 | NODE_TYPE="$3" # Should be --head or --worker. 6 | PATH_TO_HF_HOME="$4" 7 | shift 4 8 | 9 | # Preserve any extra arguments so they can be forwarded to Docker. 10 | ADDITIONAL_ARGS=("$@") 11 | 12 | # Validate the NODE_TYPE argument. 13 | if [ "${NODE_TYPE}" != "--head" ] && [ "${NODE_TYPE}" != "--worker" ]; then 14 | echo "Error: Node type must be --head or --worker" 15 | exit 1 16 | fi 17 | 18 | # Generate a unique container name with random suffix. 19 | # Docker container names must be unique on each host. 20 | # The random suffix allows multiple Ray containers to run simultaneously on the same machine, 21 | # for example, on a multi-GPU machine. 22 | CONTAINER_NAME="node-${RANDOM}" 23 | 24 | # Define a cleanup routine that removes the container when the script exits. 25 | # This prevents orphaned containers from accumulating if the script is interrupted. 26 | cleanup() { 27 | docker stop "${CONTAINER_NAME}" 28 | docker rm "${CONTAINER_NAME}" 29 | } 30 | trap cleanup EXIT 31 | 32 | # Build the Ray start command based on the node role. 33 | # The head node manages the cluster and accepts connections on port 6379, 34 | # while workers connect to the head's address. 35 | RAY_START_CMD="ray start --block" 36 | if [ "${NODE_TYPE}" == "--head" ]; then 37 | RAY_START_CMD+=" --head --port=6379" 38 | else 39 | RAY_START_CMD+=" --address=${HEAD_NODE_ADDRESS}:6379" 40 | fi 41 | 42 | # Launch the container with the assembled parameters. 
43 | # --network host: Allows Ray nodes to communicate directly via host networking 44 | # --shm-size 10.24g: Increases shared memory 45 | # --gpus all: Gives container access to all GPUs on the host 46 | # -v HF_HOME: Mounts HuggingFace cache to avoid re-downloading models 47 | docker run \ 48 | --entrypoint /bin/bash \ 49 | --network host \ 50 | --name "${CONTAINER_NAME}" \ 51 | --shm-size 10.24g \ 52 | --gpus all \ 53 | -v "${PATH_TO_HF_HOME}:/root/.cache/huggingface" \ 54 | "${ADDITIONAL_ARGS[@]}" \ 55 | "${DOCKER_IMAGE}" -c "${RAY_START_CMD}" 56 | -------------------------------------------------------------------------------- /snippets/r_weather-data/weather.R: -------------------------------------------------------------------------------- 1 | library(RNCEP) 2 | library(tidyverse) 3 | 4 | 5 | get_weather=function(RYear,RMonth,Lon,Lat){ 6 | wx.extent <- NCEP.gather(variable='air',level=850,months.minmax=c(RMonth,RMonth), 7 | years.minmax=c(RYear,RYear), 8 | lat.southnorth=c(Lat,Lat), lon.westeast=c(Lon,Lon), 9 | reanalysis2=FALSE,return.units=TRUE)-273.15 10 | wx.ag <- NCEP.aggregate(wx.data=wx.extent, YEARS=TRUE, MONTHS=TRUE,DAYS=TRUE, 11 | HOURS=TRUE, fxn='mean') 12 | wx <- NCEP.array2df(wx.ag, var.names="temp") 13 | #w <- wx[1,4] 14 | return(wx) 15 | } 16 | 17 | w=get_weather(2017,1,103,31) 18 | 19 | 20 | 21 | 22 | 23 | wx.ag.t1 <- NCEP.aggregate(wx.data=wx.t1, YEARS=TRUE, MONTHS=TRUE,DAYS=TRUE, 24 | HOURS=FALSE,fxn = 'mean') 25 | View(wx.ag.t1) 26 | 27 | flight <- NCEP.flight(beg.loc=c(58.00,7.00), 28 | end.loc=c(53.00,7.00), begin.dt='2007-10-01 18:00:00', 29 | flow.assist='NCEP.Tailwind', fa.args=list(airspeed=12), 30 | path='loxodrome', calibrate.dir=FALSE, calibrate.alt=FALSE, 31 | cutoff=0, when2stop='latitude', levels2consider=c(850,925), 32 | hours=12, evaluation.interval=60, id=1, land.if.bad=FALSE, 33 | reanalysis2 = FALSE, query=TRUE) 34 | 35 | 36 | wx.t1 <- NCEP.gather(variable='air.2m', level='gaussian', 37 | months.minmax = c(1,12),years.minmax = c(2017,2017), 38 | lat.southnorth = c(30,31), lon.westeast = c(104,104), 39 | reanalysis2 = FALSE, return.units = TRUE)-273.15 40 | 41 | wx.t2 <- NCEP.gather(variable='shum.2m', level='gaussian', 42 | months.minmax = c(1,12),years.minmax = c(2017,2017), 43 | lat.southnorth = c(30,31), lon.westeast = c(104,104), 44 | reanalysis2 = FALSE, return.units = TRUE) 45 | 46 | wx.df1 <- NCEP.array2df(wx.data=wx.t1, var.names='temperature') 47 | wx.df2 <- NCEP.array2df(wx.data=wx.t2, var.names='humidity') 48 | 49 | 50 | wx.df <- wx.df1 %>% 51 | inner_join(wx.df2) 52 | 53 | 54 | 55 | View(wx.df) 56 | write_csv(wx.df, "hum.csv") -------------------------------------------------------------------------------- /chat_with_sql/text2sql_demo.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import re 4 | import getpass 5 | from langchain_community.utilities import SQLDatabase 6 | from langchain.chat_models import init_chat_model 7 | 8 | if not os.environ.get("OPENAI_API_KEY"): 9 | os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter API key for OpenAI: ") 10 | 11 | 12 | def generate_base_sql_query(schema, question): 13 | """Generate initial SQL query from user question.""" 14 | prompt = f"""Based on the table schema below, write a SQL query that would answer the user's question. 15 | Return only the raw SQL query, without markdown formatting (e.g., no ```sql blocks) and explanations or additional text. 
16 | 
17 | Schema: {schema}
18 | Question: {question}
19 | 
20 | SQL Query:"""
21 |     return prompt
22 | 
23 | def generate_follow_up_sql_query(schema, context, follow_up_question, previous_query, previous_result):
24 |     """Generate SQL for follow-up questions using context from previous queries."""
25 |     prompt = f"""Based on the previous conversation context and table schema, write a SQL query for the follow-up question.
26 | Return only the raw SQL query, without markdown formatting (e.g., no ```sql blocks) and explanations or additional text.
27 | 
28 | Schema: {schema}
29 | Previous Question: {context}
30 | Previous SQL Query: {previous_query}
31 | Previous Result: {previous_result}
32 | Follow-up Question: {follow_up_question}
33 | 
34 | SQL Query:"""
35 |     return prompt
36 | 
37 | 
38 | def clean_sql_output(sql_text):
39 |     """Remove markdown formatting from a SQL query."""
40 |     return re.sub(r"^```sql\s*|\s*```$", "", sql_text.strip(), flags=re.IGNORECASE).strip()
41 | 
42 | 
43 | def execute_prompt(prompt):
44 |     answer = llm.invoke(prompt).content
45 |     cleaned_answer = clean_sql_output(answer)
46 |     print("\n\n")
47 |     print(cleaned_answer)
48 | 
49 |     result = db.run(cleaned_answer)  # run the cleaned SQL against the database
50 |     print("\n\n")
51 |     print(result)
52 |     return cleaned_answer, result
53 | 
54 | 
55 | db = SQLDatabase.from_uri("sqlite:///nba_roster.db", sample_rows_in_table_info=0)
56 | db_schema = db.get_table_info()
57 | 
58 | llm = init_chat_model("gpt-4o-mini", model_provider="openai")
59 | 
60 | question = "What team is Stephen Curry on?"
61 | prompt = generate_base_sql_query(db_schema, question)
62 | answer, result = execute_prompt(prompt)
63 | 
64 | follow_up_question = "What's his salary?"
65 | prompt = generate_follow_up_sql_query(db_schema, context=question, follow_up_question=follow_up_question, previous_query=answer, previous_result=result)
66 | answer, result = execute_prompt(prompt)
67 | 
--------------------------------------------------------------------------------
/snippets/r_map/China_map_great_cicle.R:
--------------------------------------------------------------------------------
1 | #setwd('./Visualization')
2 | library(ggplot2)
3 | library(maptools)
4 | library(geosphere)
5 | library(plyr)
6 | 
7 | # Sales data import and pre-processing
8 | MBsales<-read.csv("../../data/map_R_circle.csv",sep=';',header=T)
9 | MBsales2016<-MBsales[MBsales$Year==2016,]
10 | Csales2016<-MBsales[MBsales$Model=="C CLASS SEDAN"& MBsales$Year==2016 & MBsales$CBU.PbP=="PBP",]
11 | Csales2016<-data.frame(Csales2016)
12 | Sales<-ddply(Csales2016,.(City),summarize,Sales=sum(Total.Year))
13 | Sales<-Sales[order(-Sales$Sales),]
14 | Sales$id<-as.character(c(1:nrow(Sales)))
15 | Sales$City<-gsub("'","",Sales$City)
16 | Sales$City<-gsub("Nanning City","NANNING",Sales$City)
17 | Sales$City<-gsub("Fu2zhou","FUZHOU",Sales$City)
18 | Sales$City<-gsub("Liuzhou City","Liuzhou",Sales$City)
19 | Sales$City<-gsub("Foshan Nanhai","Foshan",Sales$City)
20 | Sales$City<-gsub("Foshan Shunde","Foshan",Sales$City)
21 | Sales$City<-toupper(Sales$City)
22 | 
23 | #China CityGeocode data import
24 | Citygeocode<-read.csv("../../assets/city_geocode_lookup.csv",sep=',',header=T)
25 | Citygeocode$City<-toupper(Citygeocode$City)
26 | #selected<-toupper(c("Beijing", "Shanghai", "Guangzhou", "Foshan", "Xi'an", "Chengdu", "Suzhou", "Dalian"))
27 | #selected<-Citygeocode[Citygeocode$City %in% selected,]
28 | 
29 | Sales<-merge(Sales,Citygeocode,by.x='City',by.y='City',all.x=TRUE)
30 | 
31 | BBACcode=c(116.3,39.9)
32 | Sales<-Sales[complete.cases(Sales),]
33 | routes = 
gcIntermediate(BBACcode,Sales[,c('Lon', 'Lat')], 300, breakAtDateLine=FALSE, addStartEnd=TRUE, sp=TRUE) 34 | 35 | fortify.SpatialLinesDataFrame = function(model, data, ...) {ldply(model@lines, fortify)} 36 | fortifiedroutes = fortify.SpatialLinesDataFrame(routes) 37 | 38 | greatcircles = merge(fortifiedroutes, Sales, all.x=T, by="id") 39 | 40 | ChinaProvince<-readShapePoly("../../assets/province.shp") 41 | Chinamap <- fortify(ChinaProvince) 42 | 43 | theme_map <- list(theme(panel.grid.minor = element_blank(), 44 | panel.grid.major = element_blank(), 45 | panel.border = element_blank(), 46 | axis.line = element_blank(), 47 | axis.text.x = element_blank(), 48 | axis.text.y = element_blank(), 49 | axis.ticks = element_blank(), 50 | axis.title.x = element_blank(), 51 | axis.title.y = element_blank(), 52 | panel.background = element_blank(), 53 | plot.background = element_blank())) 54 | 55 | p1<-ggplot(Sales)+ 56 | geom_path(aes(x = long, y = lat, group = id), size=0.2, data=Chinamap)+ 57 | geom_line(aes(long,lat,group=id), data=greatcircles, color='grey', alpha=0.25, size=0.15)+ 58 | geom_point(aes(Lon,Lat,group=id,alpha=Sales,size=Sales),color="red")+scale_size(range = c(0, 4))+ 59 | geom_text(aes(Lon,Lat,label=City),data=Sales[1:5,],hjust =-0.4,check_overlap = TRUE,size=2.5)+ 60 | scale_alpha_continuous(range = c(0.25, 0.8))+coord_map()+ylim(14,55)+theme_map 61 | 62 | ggsave("../../assets/images/China map great circle.png", p1, height=4.8, width=9.5) -------------------------------------------------------------------------------- /llm_base_inference/sgl_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import sglang as sgl 5 | 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser(description="SGL Language Demo") 9 | parser.add_argument("--model_name", type=str, default="Qwen/Qwen3-4B-Base", help="Model name or path") 10 | parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use for data parallelism") 11 | parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for sampling") 12 | parser.add_argument("--top_p", type=float, default=0.95, help="Top-p for sampling") 13 | parser.add_argument("--max_new_tokens", type=int, default=128, help="Maximum number of tokens to generate") 14 | return parser.parse_args() 15 | 16 | 17 | def generate(llm, prompts, sampling_params): 18 | outputs = llm.generate(prompts, sampling_params) 19 | responses = [output['text'] for output in outputs] 20 | return responses 21 | 22 | 23 | def get_test_prompts(): 24 | return [ 25 | { 26 | "prompt": "Complete this sentence: The capital of France is", 27 | "expected": "Paris" 28 | }, 29 | { 30 | "prompt": "What is 2+2?", 31 | "expected": "4" 32 | }, 33 | { 34 | "prompt": "Translate 'hello' to Spanish:", 35 | "expected": "hola" 36 | }, 37 | { 38 | "prompt": "Name a primary color:", 39 | "expected_options": ["red", "blue", "yellow"] 40 | } 41 | ] 42 | 43 | 44 | def evaluate_response(response, expected): 45 | """Simple evaluation function to check if response contains expected answer.""" 46 | if isinstance(expected, list): 47 | return any(exp.lower() in response.lower() for exp in expected) 48 | return expected.lower() in response.lower() 49 | 50 | 51 | if __name__ == "__main__": 52 | args = parse_args() 53 | 54 | llm = sgl.Engine( 55 | model_path=args.model_name, 56 | dp_size=args.num_gpus, 57 | tp_size=1, 58 | device="cuda", 59 | context_length=1000 60 | ) 61 | 62 | 
test_data = get_test_prompts() 63 | prompts = [item["prompt"] for item in test_data] 64 | 65 | 66 | sampling_params = { 67 | "temperature": args.temperature, 68 | "top_p": args.top_p, 69 | "max_new_tokens": args.max_new_tokens 70 | } 71 | 72 | # Generate responses and measure time 73 | t0 = time.time() 74 | responses = generate(llm, prompts, sampling_params) 75 | t1 = time.time() 76 | 77 | # Evaluate responses 78 | total_score = 0 79 | for i, (response, test_item) in enumerate(zip(responses, test_data)): 80 | expected = test_item.get("expected_options", test_item.get("expected")) 81 | is_correct = evaluate_response(response, expected) 82 | 83 | print(f"Prompt {i+1}: {test_item['prompt']}") 84 | print(f"Response: {response}") 85 | print(f"Expected: {expected}") 86 | print(f"Correct: {is_correct}") 87 | print("-" * 50) 88 | 89 | if is_correct: 90 | total_score += 1 91 | 92 | print(f"Accuracy: {total_score}/{len(prompts)} = {total_score / len(prompts):.2f}") 93 | print(f"Time taken: {t1 - t0:.4f} seconds") 94 | 95 | llm.shutdown() 96 | -------------------------------------------------------------------------------- /snippets/r_plot/multi bar plot.R: -------------------------------------------------------------------------------- 1 | library(readxl) 2 | library(plyr) 3 | library(ggplot2) 4 | library(ggthemes) 5 | library(gridExtra) 6 | # setwd(".\\06-QFS_month") 7 | 8 | 9 | MB=data.frame(read_excel("MB Online master sheet_201812.xlsx",sheet=1,col_names=TRUE)) 10 | MB$Complaint=paste(MB$QFS_Tier1, MB$QFS_Tier3,sep=" ") 11 | MBagg<-ddply(MB,.(Carline,Complaint),summarize,Value=length(Complaint)) 12 | MBagg<-MBagg[order(MBagg$Value,decreasing=T),] 13 | MBagg1<-by(MBagg,MBagg['Carline'],head,n=3) 14 | MBagg2<-Reduce(rbind,MBagg1) 15 | MBagg2<-MBagg2[MBagg2$Carline %in% c("V205","V213","X253","X156"),] 16 | MBagg2$Carline<-factor(MBagg2$Carline, levels=c("V205","V213","X253","X156")) 17 | 18 | max=max(MBagg2[,"Value"]) 19 | n=1 20 | 21 | p<-list() 22 | for (i in c("V205","V213","X253","X156")) 23 | { 24 | MBcar<-MBagg2[MBagg2$Carline==i,] 25 | p[[n]]<-ggplot(MBcar)+geom_bar(aes(x=reorder(Complaint,Value),y=Value),stat='identity',fill="#d3ba68",alpha=0.7)+ylim(0,max)+geom_text(aes(x=Complaint,y=0,label=Complaint),size=4,hjust=0)+coord_flip()+ theme_wsj()+theme(axis.text.y = element_blank(),axis.text.x = element_text(size = 7)) 26 | n=n+1 27 | } 28 | p2<-grid.arrange(p[[1]],p[[2]],p[[3]],p[[4]],ncol = 1, nrow = 4) 29 | 30 | 31 | 32 | 33 | 34 | BMW=data.frame(read_excel("BMW Online master sheet_201812.xlsx",sheet=1,col_names=TRUE)) 35 | BMW$Complaint=paste(BMW$QFS_Tier1, BMW$QFS_Tier3,sep=" ") 36 | BMWagg<-ddply(BMW,.(Type.Class,Complaint),summarize,Value=length(Complaint)) 37 | BMWagg<-BMWagg[order(BMWagg$Value,decreasing=T),] 38 | BMWagg1<-by(BMWagg,BMWagg['Type.Class'],head,n=3) 39 | BMWagg2<-Reduce(rbind,BMWagg1) 40 | BMWagg2<-BMWagg2[BMWagg2$Type.Class %in% c("3-Series","5-Series","X3","X1"),] 41 | BMWagg2$Type.Class<-factor(BMWagg2$Type.Class, levels=c("3-Series","5-Series","X3","X1")) 42 | max=max(BMWagg2[,"Value"]) 43 | n=5 44 | for (i in c("3-Series","5-Series","X3","X1")) 45 | { 46 | BMWcar<-BMWagg2[BMWagg2$Type.Class==i,] 47 | p[[n]]<-ggplot(BMWcar)+geom_bar(aes(x=reorder(Complaint,Value),y=Value),stat='identity',fill="#d3ba68",alpha=0.7)+ylim(0,max)+geom_text(aes(x=Complaint,y=0,label=Complaint),size=4,hjust=0)+coord_flip()+ theme_wsj()+theme(axis.text.y = element_blank(),axis.text.x = element_text(size = 7)) 48 | n=n+1} 49 | p2<-grid.arrange(p[[5]],p[[6]],p[[7]],p[[8]],ncol = 1, nrow = 4) 
50 | 51 | 52 | 53 | 54 | 55 | 56 | Audi=data.frame(read_excel("Audi Online master sheet_201812.xlsx",sheet=1,col_names=TRUE)) 57 | Audi$Complaint=paste(Audi$QFS_Tier1, Audi$QFS_Tier3,sep=" ") 58 | Audiagg<-ddply(Audi,.(Type.Class,Complaint),summarize,Value=length(Complaint)) 59 | Audiagg<-Audiagg[order(Audiagg$Value,decreasing=T),] 60 | Audiagg1<-by(Audiagg,Audiagg['Type.Class'],head,n=3) 61 | Audiagg2<-Reduce(rbind,Audiagg1) 62 | Audiagg2<-Audiagg2[Audiagg2$Type.Class %in% c("A4L","A6L","Q5","Q3"),] 63 | Audiagg2$Type.Class<-factor(Audiagg2$Type.Class, levels=c("A4L","A6L","Q5","Q3")) 64 | max=max(Audiagg2[,"Value"]) 65 | n=9 66 | for (i in c("A4L","A6L","Q5","Q3")) 67 | { 68 | Audicar<-Audiagg2[Audiagg2$Type.Class==i,] 69 | p[[n]]<-ggplot(Audicar)+geom_bar(aes(x=reorder(Complaint,Value),y=Value),stat='identity',fill="#d3ba68",alpha=0.7)+ylim(0,max)+geom_text(aes(x=Complaint,y=0,label=Complaint),size=4,hjust=0)+coord_flip()+ theme_wsj()+theme(axis.text.y = element_blank(),axis.text.x = element_text(size = 7)) 70 | n=n+1 71 | } 72 | p2<-grid.arrange(p[[9]],p[[10]],p[[11]],p[[12]],ncol = 1, nrow = 4) 73 | 74 | 75 | 76 | 77 | 78 | 79 | Lexus=data.frame(read_excel("LEXUS Online master sheet_201812.xlsx",sheet=1,col_names=TRUE)) 80 | Lexus$Complaint=paste(Lexus$QFS_Tier1, Lexus$QFS_Tier3,sep=" ") 81 | Lexusagg<-ddply(Lexus,.(Type.Class,Complaint),summarize,Value=length(Complaint)) 82 | Lexusagg<-Lexusagg[order(Lexusagg$Value,decreasing=T),] 83 | Lexusagg1<-by(Lexusagg,Lexusagg['Type.Class'],head,n=3) 84 | Lexusagg2<-Reduce(rbind,Lexusagg1) 85 | Lexusagg2<-Lexusagg2[Lexusagg2$Type.Class %in% c("IS","ES","NX","CT"),] 86 | Lexusagg2$Type.Class<-factor(Lexusagg2$Type.Class, levels=c("IS","ES","NX","CT")) 87 | max=max(Lexusagg2[,"Value"]) 88 | n=13 89 | for (i in c("IS","ES","NX","CT")) 90 | { 91 | Lexuscar<-Lexusagg2[Lexusagg2$Type.Class==i,] 92 | p[[n]]<-ggplot(Lexuscar)+geom_bar(aes(x=reorder(Complaint,Value),y=Value),stat='identity',fill="#d3ba68",alpha=0.7)+ylim(0,max)+geom_text(aes(x=Complaint,y=0,label=Complaint),size=4,hjust=0)+coord_flip()+ theme_wsj()+theme(axis.text.y = element_blank(),axis.text.x = element_text(size = 7)) 93 | n=n+1 94 | } 95 | p2<-grid.arrange(p[[13]],p[[14]],p[[15]],ncol = 1, nrow = 4) 96 | 97 | 98 | p3<-grid.arrange(p[[1]],p[[5]],p[[9]],p[[13]],p[[2]],p[[6]],p[[10]],p[[14]],p[[3]],p[[7]],p[[11]],p[[15]],p[[4]],p[[8]],p[[12]],nrow = 4,ncol = 4) 99 | 100 | 101 | 102 | 103 | #p<-ggplot(MBagg2)+geom_bar(aes(x=Complaint,y=Value),stat='identity',fill='yellow',alpha=0.5) 104 | #p<-p+facet_wrap(~Carline,nrow=4,scales="free")+ylim(0,25) 105 | #p<-p+geom_text(aes(x=Complaint,y=0,label=Complaint),hjust=0) 106 | #p+coord_flip()+ theme_wsj() -------------------------------------------------------------------------------- /llm_agent/openai_agent.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | from datetime import datetime 4 | from litellm import completion 5 | from duckduckgo_search import DDGS 6 | import json 7 | from rich.console import Console 8 | from rich.panel import Panel 9 | from rich.text import Text 10 | from rich.syntax import Syntax 11 | 12 | api_key = "sk-proj-xxxx" 13 | os.environ["OPENAI_API_KEY"] = api_key 14 | warnings.filterwarnings("ignore", category=UserWarning) 15 | console = Console() 16 | 17 | # Tool 1: DuckDuckGo Search 18 | def duckduckgo_search(query, max_results=5): 19 | results = [] 20 | with DDGS() as ddgs: 21 | for r in ddgs.text(query, max_results=max_results): 22 | 
results.append(f"{r['title']}: {r['href']}") 23 | return "\n".join(results) 24 | 25 | # Tool 2: Character Counter 26 | def count_character_count(text): 27 | return f"The input has {len(text)} characters." 28 | 29 | # Tool3: Character count in word 30 | def count_character_occurrence(word, character): 31 | if len(character) != 1: 32 | return "Error: Please provide a single character." 33 | count = word.count(character) 34 | return f"The character '{character}' appears {count} time(s) in '{word}'." 35 | 36 | # Tool 3: Current Datetime 37 | def get_current_datetime(): 38 | return f"The current date and time is: {datetime.now().isoformat()}" 39 | 40 | 41 | tool_functions = {} 42 | 43 | def register_tool(func): 44 | tool_functions[func.__name__] = func 45 | return func 46 | 47 | # Register tools 48 | register_tool(duckduckgo_search) 49 | register_tool(count_character_count) 50 | register_tool(count_character_occurrence) 51 | register_tool(get_current_datetime) 52 | 53 | 54 | # Tool registry 55 | tools = [ 56 | { 57 | "type": "function", 58 | "function": { 59 | "name": "duckduckgo_search", 60 | "description": "Searches the web for the latest info using DuckDuckGo", 61 | "parameters": { 62 | "type": "object", 63 | "properties": { 64 | "query": { 65 | "type": "string", 66 | "description": "The search query", 67 | } 68 | }, 69 | "required": ["query"], 70 | }, 71 | } 72 | }, 73 | { 74 | "type": "function", 75 | "function": { 76 | "name": "count_character_count", 77 | "description": "Counts the number of characters in a string", 78 | "parameters": { 79 | "type": "object", 80 | "properties": { 81 | "text": { 82 | "type": "string", 83 | "description": "The string to count characters in", 84 | } 85 | }, 86 | "required": ["text"], 87 | }, 88 | } 89 | }, 90 | { 91 | "type": "function", 92 | "function": { 93 | "name": "count_character_occurrence", 94 | "description": "Counts how many times a given character appears in a word", 95 | "parameters": { 96 | "type": "object", 97 | "properties": { 98 | "word": { 99 | "type": "string", 100 | "description": "The word in which to count the character", 101 | }, 102 | "character": { 103 | "type": "string", 104 | "description": "The single character to count", 105 | }, 106 | }, 107 | "required": ["word", "character"], 108 | }, 109 | } 110 | }, 111 | { 112 | "type": "function", 113 | "function": { 114 | "name": "get_current_datetime", 115 | "description": "Returns the current date and time", 116 | "parameters": { 117 | "type": "object", 118 | "properties": {} 119 | }, 120 | } 121 | } 122 | ] 123 | 124 | def run_agent(question): 125 | messages = [ 126 | {"role": "system", "content": "You are a helpful assistant."}, 127 | {"role": "user", "content": question} 128 | ] 129 | 130 | for step in range(3): 131 | console.rule(f"[bold green]Step {step}") 132 | console.print(Syntax(str(messages), "python", theme="monokai", word_wrap=True)) 133 | 134 | response = completion( 135 | model="gpt-4.1-mini", 136 | messages=messages, 137 | tools=tools, 138 | tool_choice="auto", 139 | ) 140 | 141 | msg = response['choices'][0]['message'] 142 | 143 | if msg.get("tool_calls"): 144 | tool_call = msg["tool_calls"][0] 145 | function_name = tool_call["function"]["name"] 146 | arguments = json.loads(tool_call["function"]["arguments"]) 147 | 148 | console.print(Panel.fit( 149 | f"[bold yellow]🔧 Tool requested[/bold yellow]: [cyan]{function_name}[/cyan]([magenta]{arguments}[/magenta])", 150 | title="Tool Call", style="yellow")) 151 | 152 | if function_name in tool_functions: 153 | try: 154 | result 
= tool_functions[function_name](**arguments)
155 |             except Exception as e:
156 |                 result = f"Error running tool '{function_name}': {e}"
157 |         else:
158 |             result = f"❌ Unknown tool: {function_name}"
159 | 
160 |         messages.append(msg)
161 |         messages.append({
162 |             "role": "tool",
163 |             "tool_call_id": tool_call["id"],
164 |             "name": function_name,
165 |             "content": result,
166 |         })
167 | 
168 |         console.print(Panel.fit(result, title="Tool Result", style="cyan"))
169 |     else:
170 |         final_answer = msg["content"]
171 |         console.print(Panel.fit(
172 |             f"[bold green]🤖 Final Answer:[/bold green]\n{final_answer}",
173 |             style="green", title="Answer"
174 |         ))
175 |         break
176 | 
177 | 
178 | if __name__ == "__main__":
179 |     while True:
180 |         user_question = input("\nAsk something (or type 'q' or 'exit' to quit): ")
181 |         if user_question.strip().lower() in {"exit", "q"}:
182 |             print("Goodbye!")
183 |             break
184 |         run_agent(user_question)
185 | 
--------------------------------------------------------------------------------
/llm_batch_inference/client/vllm_ray_batch_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import logging
4 | from typing import Dict, List
5 | import pandas as pd
6 | import polars as pl
7 | import torch
8 | from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor
9 | from vllm.transformers_utils.tokenizer import get_tokenizer
10 | import vllm
11 | import ray
12 | from openai import OpenAI
13 | 
14 | logging.basicConfig(
15 |     level=logging.INFO,
16 |     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
17 | )
18 | logger = logging.getLogger(__name__)
19 | 
20 | 
21 | PROMPT = """Classify the following text into one of these categories: positive, negative, or neutral.
22 | Respond with only the category name.
23 | 
24 | Text: {text}
25 | 
26 | Category:"""
27 | 
28 | 
29 | def build_df_data() -> pd.DataFrame:
30 |     """Create sample data for demonstration"""
31 |     sample_texts = [
32 |         "I love this product! It's amazing and works perfectly.",
33 |         "This is terrible. Worst purchase ever made.",
34 |         "The weather is okay today, nothing special.",
35 |         "Fantastic service and great customer support!",
36 |         "It's just an average product, nothing more.",
37 |         "I hate waiting in long lines at the store.",
38 |         "The book was quite interesting and well-written.",
39 |         "Meh, could be better but could be worse too.",
40 |         "Outstanding performance, exceeded my expectations!",
41 |         "This software is buggy and crashes frequently."
42 |     ]
43 | 
44 |     df = pd.DataFrame({
45 |         'text': sample_texts,
46 |         'id': range(len(sample_texts))
47 |     })
48 | 
49 |     logger.info(f"Created sample dataset with {len(df)} rows")
50 |     return df
51 | 
52 | 
53 | class PromptPreparer:
54 |     """Prepares prompts for text classification"""
55 | 
56 |     def __init__(self, model_name: str = "microsoft/DialoGPT-medium"):
57 |         self.model_name = model_name
58 |         self.tokenizer = None
59 | 
60 |     def __call__(self, batch: Dict[str, List]) -> Dict[str, List]:
61 |         """Process a batch of rows to prepare prompts"""
62 |         if self.tokenizer is None:
63 |             try:
64 |                 self.tokenizer = get_tokenizer(self.model_name)
65 |                 logger.info(f"Loaded tokenizer for {self.model_name}")
66 |             except Exception as e:
67 |                 logger.warning(f"Could not load tokenizer: {e}. 
93 | def generate_batch_response(batch: dict[str, list]) -> dict[str, list]:
94 |     """Generate responses for a batch of prompts using OpenAI API"""
95 | 
96 |     # Configuration - in production, use environment variables
97 |     api_base_url = os.getenv('OPENAI_API_BASE', 'https://api.openai.com/v1')
98 |     api_key = os.getenv('OPENAI_API_KEY', 'your-api-key-here')
99 |     model_name = os.getenv('MODEL_NAME', 'gpt-3.5-turbo')
100 |     max_retries = int(os.getenv('MAX_RETRIES', '3'))
101 | 
102 |     # Initialize OpenAI client
103 |     client = OpenAI(
104 |         base_url=api_base_url,
105 |         api_key=api_key
106 |     )
107 | 
108 |     responses = []
109 | 
110 |     for i, prompt in enumerate(batch['prompt']):
111 |         response_text = "error"
112 | 
113 |         for attempt in range(max_retries):
114 |             try:
115 |                 # Create chat completion
116 |                 response = client.chat.completions.create(
117 |                     model=model_name,
118 |                     messages=[{"role": "user", "content": prompt}],
119 |                     temperature=0.0,
120 |                     max_tokens=50,  # Reduced for classification task
121 |                     timeout=30
122 |                 )
123 | 
124 |                 response_text = response.choices[0].message.content.strip()
125 |                 logger.debug(f"Successfully processed item {i}")
126 |                 break
127 | 
128 |             except Exception as e:
129 |                 logger.warning(f"Attempt {attempt + 1} failed for item {i}: {str(e)}")
130 |                 if attempt == max_retries - 1:
131 |                     response_text = f"Error after {max_retries} attempts: {str(e)}"
132 |                 else:
133 |                     time.sleep(1)  # Brief delay before retry
134 | 
135 |         responses.append(response_text)
136 | 
137 |     batch['response'] = responses
138 |     return batch
139 | 
140 | 
141 | def main():
142 |     """Main execution function"""
143 |     start_time = time.time()
144 | 
145 |     # Initialize Ray
146 |     if not ray.is_initialized():
147 |         ray.init(ignore_reinit_error=True)
148 |         logger.info("Ray initialized")
149 | 
150 |     try:
151 |         # Build sample data
152 |         df = build_df_data()
153 | 
154 |         # Convert to Ray dataset
155 |         ds = ray.data.from_pandas(df)
156 |         logger.info("Created Ray dataset")
157 | 
158 |         # Prepare prompts
159 |         prompt_preparer = PromptPreparer(model_name="microsoft/DialoGPT-medium")
160 |         ds = ds.map_batches(prompt_preparer, batch_size=5)
161 |         logger.info("Prompts prepared")
162 | 
163 |         logger.info("Using OpenAI API for classification")
164 |         ds = ds.map_batches(generate_batch_response, batch_size=2)
165 | 
166 |         # Collect results
167 |         logger.info("Collecting results...")
168 |         results = ds.to_pandas()
169 | 
170 |         # Calculate processing time
171 |         end_time = time.time()
172 |         processing_time = end_time - start_time
173 | 
174 |         logger.info(f"Processing completed in {processing_time:.2f} seconds")
175 |         logger.info(f"Processed {len(results)} items")
176 | 
177 |         # Save results
178 |         output_file = "classification_results.csv"
179 |         results.to_csv(output_file, index=False)
180 |         logger.info(f"Results saved to {output_file}")
181 | 
182 |     except Exception as e:
183 |         logger.error(f"Error in main
execution: {str(e)}") 184 | raise 185 | 186 | finally: 187 | if ray.is_initialized(): 188 | ray.shutdown() 189 | logger.info("Ray shutdown complete") 190 | 191 | 192 | if __name__ == "__main__": 193 | os.environ.setdefault('MAX_RETRIES', '3') 194 | main() 195 | -------------------------------------------------------------------------------- /llm_fine_tuning/llama-instruct-tuning-alpaca.ipynb: -------------------------------------------------------------------------------- 1 | {"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.11.11","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[{"sourceId":120005,"sourceType":"modelInstanceVersion","isSourceIdPinned":true,"modelInstanceId":100936,"modelId":121027}],"dockerImageVersionId":31041,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"!pip install bitsandbytes --quiet\n!pip install transformers --quiet\n!pip install peft --quiet\n!pip install accelerate --quiet\n!pip install trl --quiet\n!pip install datasets --quiet","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","trusted":true,"execution":{"iopub.status.busy":"2025-05-17T21:19:10.352402Z","iopub.execute_input":"2025-05-17T21:19:10.352693Z","iopub.status.idle":"2025-05-17T21:20:46.665131Z","shell.execute_reply.started":"2025-05-17T21:19:10.352672Z","shell.execute_reply":"2025-05-17T21:20:46.664076Z"}},"outputs":[{"name":"stdout","text":"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m76.1/76.1 MB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m0:00:01\u001b[0m00:01\u001b[0m\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m0:00:01\u001b[0m00:01\u001b[0m\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m5.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m0:00:01\u001b[0m00:01\u001b[0m\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m30.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m13.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m2.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m0:00:01\u001b[0m00:01\u001b[0m\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m85.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m348.0/348.0 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m193.6/193.6 kB\u001b[0m \u001b[31m12.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\ncesium 0.12.4 requires numpy<3.0,>=2.0, but you have numpy 1.26.4 which is incompatible.\nbigframes 1.42.0 requires rich<14,>=12.4.4, but you have rich 14.0.0 which is incompatible.\ngcsfs 2025.3.2 requires fsspec==2025.3.2, but you have fsspec 2025.3.0 which is incompatible.\u001b[0m\u001b[31m\n\u001b[0m","output_type":"stream"}],"execution_count":1},{"cell_type":"code","source":"import os\nimport pandas as pd\nimport torch\nfrom datasets import load_dataset\nfrom trl import SFTTrainer\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments\nfrom peft import LoraConfig, PeftModel, prepare_model_for_kbit_training\n\nos.environ[\"WANDB_DISABLED\"] = \"true\"","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-17T21:20:46.666888Z","iopub.execute_input":"2025-05-17T21:20:46.667148Z","iopub.status.idle":"2025-05-17T21:21:18.147015Z","shell.execute_reply.started":"2025-05-17T21:20:46.667124Z","shell.execute_reply":"2025-05-17T21:21:18.146450Z"}},"outputs":[{"name":"stderr","text":"2025-05-17 21:21:01.450910: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\nWARNING: All log messages before absl::InitializeLog() is called are written to STDERR\nE0000 00:00:1747516861.668763 35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\nE0000 00:00:1747516861.730283 35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n","output_type":"stream"}],"execution_count":2},{"cell_type":"code","source":"dataset = load_dataset(\"tatsu-lab/alpaca\")[\"train\"]\n\ndataset[0]","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-17T21:21:18.147720Z","iopub.execute_input":"2025-05-17T21:21:18.148284Z","iopub.status.idle":"2025-05-17T21:21:20.322934Z","shell.execute_reply.started":"2025-05-17T21:21:18.148264Z","shell.execute_reply":"2025-05-17T21:21:20.322343Z"}},"outputs":[{"output_type":"display_data","data":{"text/plain":"README.md: 0%| | 0.00/7.47k [00:00 <|eot_id|>\n","output_type":"stream"}],"execution_count":4},{"cell_type":"code","source":"def format_chat_template(row):\n messages = [\n {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n {\"role\": \"user\", \"content\": row[\"instruction\"]},\n {\"role\": \"assistant\", \"content\": row[\"output\"]}\n ]\n \n chat = tokenizer.apply_chat_template(messages, tokenize=False)\n return {\"text\": chat}\n\n\nformat_chat_template(dataset[0])","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-17T21:21:20.905083Z","iopub.execute_input":"2025-05-17T21:21:20.905315Z","iopub.status.idle":"2025-05-17T21:21:20.916349Z","shell.execute_reply.started":"2025-05-17T21:21:20.905297Z","shell.execute_reply":"2025-05-17T21:21:20.915614Z"}},"outputs":[{"execution_count":5,"output_type":"execute_result","data":{"text/plain":"{'text': '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful 
assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\nGive three tips for staying healthy.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \\n2. Exercise regularly to keep your body active and strong. \\n3. Get enough sleep and maintain a consistent sleep schedule.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n'}"},"metadata":{}}],"execution_count":5},{"cell_type":"code","source":"dataset = dataset.shuffle(seed=42).select(range(11000))\n\nprocessed_dataset = dataset.map(\n format_chat_template,\n num_proc= os.cpu_count(),\n)\n\ndataset = processed_dataset.train_test_split(test_size=0.1)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-17T21:21:20.916941Z","iopub.execute_input":"2025-05-17T21:21:20.917153Z","iopub.status.idle":"2025-05-17T21:21:22.360738Z","shell.execute_reply.started":"2025-05-17T21:21:20.917136Z","shell.execute_reply":"2025-05-17T21:21:22.360076Z"}},"outputs":[{"output_type":"display_data","data":{"text/plain":"Map (num_proc=4): 0%| | 0/11000 [00:00","text/html":"\n
[200/200 1:04:48, Epoch 0/1]
Step | Training Loss
  10 | 2.871600
  20 | 2.576800
  30 | 2.196500
  40 | 2.024300
  50 | 1.784300
  60 | 1.618500
  70 | 1.524300
  80 | 1.448200
  90 | 1.450400
 100 | 1.406500
 110 | 1.392100
 120 | 1.427900
 130 | 1.417400
 140 | 1.402700
 150 | 1.402300
 160 | 1.356500
 170 | 1.375300
 180 | 1.382400
 190 | 1.426700
 200 | 1.352500
"},"metadata":{}},{"execution_count":10,"output_type":"execute_result","data":{"text/plain":"TrainOutput(global_step=200, training_loss=1.6418570566177368, metrics={'train_runtime': 3898.6987, 'train_samples_per_second': 0.821, 'train_steps_per_second': 0.051, 'total_flos': 9810597599281152.0, 'train_loss': 1.6418570566177368})"},"metadata":{}}],"execution_count":10},{"cell_type":"markdown","source":"## Inference","metadata":{}},{"cell_type":"code","source":"sample = dataset[\"test\"][0]\nprint(sample)\n\n\ndef format_chat_template(row):\n messages = [\n {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n {\"role\": \"user\", \"content\": row[\"instruction\"]}\n ] \n prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n return prompt\n\nprompt = format_chat_template(sample)\nprompt","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-17T22:27:03.428724Z","iopub.execute_input":"2025-05-17T22:27:03.428978Z","iopub.status.idle":"2025-05-17T22:27:03.436206Z","shell.execute_reply.started":"2025-05-17T22:27:03.428960Z","shell.execute_reply":"2025-05-17T22:27:03.435494Z"}},"outputs":[{"name":"stdout","text":"{'instruction': 'Provide the necessary materials for the given project.', 'input': 'Build a birdhouse', 'output': 'Materials Needed for Building a Birdhouse:\\n-Pieces of wood for the base, walls and roof of the birdhouse \\n-Saw \\n-Screws \\n-Screwdriver \\n-Nails \\n-Hammer \\n-Paint\\n-Paintbrushes \\n-Drill and bits \\n-Gravel (optional)', 'text': '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\nProvide the necessary materials for the given project.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\nMaterials Needed for Building a Birdhouse:\\n-Pieces of wood for the base, walls and roof of the birdhouse \\n-Saw \\n-Screws \\n-Screwdriver \\n-Nails \\n-Hammer \\n-Paint\\n-Paintbrushes \\n-Drill and bits \\n-Gravel (optional)<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n'}\n","output_type":"stream"},{"execution_count":11,"output_type":"execute_result","data":{"text/plain":"'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\nProvide the necessary materials for the given project.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n'"},"metadata":{}}],"execution_count":11},{"cell_type":"code","source":"# sft_model = AutoModelForCausalLM.from_pretrained(\"./llama3_sft/\")\n\nsft_model = AutoModelForCausalLM.from_pretrained(\"/kaggle/input/llama-3.2/transformers/3b-instruct/1\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-17T22:27:03.436908Z","iopub.execute_input":"2025-05-17T22:27:03.437240Z","iopub.status.idle":"2025-05-17T22:27:10.663130Z","shell.execute_reply.started":"2025-05-17T22:27:03.437221Z","shell.execute_reply":"2025-05-17T22:27:10.662463Z"}},"outputs":[{"output_type":"display_data","data":{"text/plain":"Loading checkpoint shards: 0%| | 0/2 [00:00\")[0]) ","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-17T22:27:10.663898Z","iopub.execute_input":"2025-05-17T22:27:10.664164Z","iopub.status.idle":"2025-05-17T22:27:17.590203Z","shell.execute_reply.started":"2025-05-17T22:27:10.664145Z","shell.execute_reply":"2025-05-17T22:27:17.589393Z"}},"outputs":[{"name":"stderr","text":"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`.\n/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py:745: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n return fn(*args, **kwargs)\n/usr/local/lib/python3.11/dist-packages/torch/utils/checkpoint.py:87: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n warnings.warn(\n","output_type":"stream"},{"name":"stdout","text":"The materials needed for this project are:\n- A piece of wood\n- A drill\n- A hammer\n- A saw\n- A sandpaper\n- Paint\n- Paintbrushes\n- A paint tray\n","output_type":"stream"}],"execution_count":13}]} -------------------------------------------------------------------------------- /time_series/vehicle-sales-prediction-tensorflow-lstm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "bb0d377a", 6 | "metadata": { 7 | "papermill": { 8 | "duration": 0.002831, 9 | "end_time": "2025-05-16T02:49:48.027348", 10 | "exception": false, 11 | "start_time": "2025-05-16T02:49:48.024517", 12 | "status": "completed" 13 | }, 14 | "tags": [] 15 | }, 16 | "source": [ 17 | "vehicle sales data\n", 18 | "- data in [kaggle dataset](https://www.kaggle.com/datasets/brendayue/china-vehicle-sales-data)" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "id": "af6c017a", 25 | "metadata": { 26 | "execution": { 27 | "iopub.execute_input": "2025-05-16T02:49:48.032819Z", 28 | "iopub.status.busy": "2025-05-16T02:49:48.032563Z", 29 | "iopub.status.idle": "2025-05-16T02:49:52.474282Z", 30 | "shell.execute_reply": "2025-05-16T02:49:52.473247Z" 31 | }, 32 | "papermill": { 33 | "duration": 4.446267, 34 | "end_time": "2025-05-16T02:49:52.475945", 35 | "exception": false, 36 | "start_time": "2025-05-16T02:49:48.029678", 37 | "status": "completed" 38 | }, 39 | "tags": [] 40 | }, 41 | "outputs": [ 42 | { 43 | "name": "stdout", 44 | "output_type": "stream", 45 | "text": [ 46 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m74.4/74.4 kB\u001b[0m \u001b[31m2.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", 47 | "\u001b[?25h" 48 | ] 49 | } 50 | ], 51 | "source": [ 52 | "!pip install tfts --quiet" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 2, 58 | "id": "cd93f2b1", 59 | "metadata": { 60 | "execution": { 61 | "iopub.execute_input": "2025-05-16T02:49:52.482353Z", 62 | "iopub.status.busy": "2025-05-16T02:49:52.482091Z", 63 | "iopub.status.idle": "2025-05-16T02:50:09.490987Z", 64 | "shell.execute_reply": "2025-05-16T02:50:09.489995Z" 65 | }, 66 | "papermill": { 67 | "duration": 17.014133, 68 | "end_time": "2025-05-16T02:50:09.492944", 69 | "exception": false, 70 | "start_time": "2025-05-16T02:49:52.478811", 71 | "status": "completed" 72 | }, 73 | "tags": [] 74 | }, 75 | "outputs": [ 76 | { 77 | "name": "stderr", 78 | "output_type": "stream", 79 | "text": [ 80 | "2025-05-16 02:49:55.547627: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", 81 | "WARNING: All log messages before absl::InitializeLog() is
called are written to STDERR\n", 82 | "E0000 00:00:1747363795.751419 19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", 83 | "E0000 00:00:1747363795.806600 19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", 84 | "I0000 00:00:1747363809.433713 19 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 15513 MB memory: -> device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0\n" 85 | ] 86 | } 87 | ], 88 | "source": [ 89 | "import logging\n", 90 | "from typing import List, Optional, Union\n", 91 | "import numpy as np\n", 92 | "import pandas as pd\n", 93 | "import tensorflow as tf\n", 94 | "from tfts import AutoModel, AutoConfig, KerasTrainer" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 3, 100 | "id": "b79c26aa", 101 | "metadata": { 102 | "execution": { 103 | "iopub.execute_input": "2025-05-16T02:50:09.502215Z", 104 | "iopub.status.busy": "2025-05-16T02:50:09.501552Z", 105 | "iopub.status.idle": "2025-05-16T02:50:09.506557Z", 106 | "shell.execute_reply": "2025-05-16T02:50:09.505778Z" 107 | }, 108 | "papermill": { 109 | "duration": 0.011339, 110 | "end_time": "2025-05-16T02:50:09.508436", 111 | "exception": false, 112 | "start_time": "2025-05-16T02:50:09.497097", 113 | "status": "completed" 114 | }, 115 | "tags": [] 116 | }, 117 | "outputs": [], 118 | "source": [ 119 | "class CFG:\n", 120 | " input_dir = \"/kaggle/input/china-vehicle-sales-data/china_vehicle_sales_data.csv\"\n", 121 | " train_sequence_length = 12\n", 122 | " predict_sequence_length = 3\n" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 4, 128 | "id": "2457152a", 129 | "metadata": { 130 | "execution": { 131 | "iopub.execute_input": "2025-05-16T02:50:09.518205Z", 132 | "iopub.status.busy": "2025-05-16T02:50:09.517910Z", 133 | "iopub.status.idle": "2025-05-16T02:50:09.628283Z", 134 | "shell.execute_reply": "2025-05-16T02:50:09.627424Z" 135 | }, 136 | "papermill": { 137 | "duration": 0.116531, 138 | "end_time": "2025-05-16T02:50:09.629616", 139 | "exception": false, 140 | "start_time": "2025-05-16T02:50:09.513085", 141 | "status": "completed" 142 | }, 143 | "tags": [] 144 | }, 145 | "outputs": [ 146 | { 147 | "data": { 148 | "text/html": [ 149 | "

\n", 150 | "\n", 163 | "\n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | "
DateprovinceprovinceIdpopularitymodelbodyTypesalesVolume
0201601Shanghai31000014793c974920a76ac9c1SUV292
1201601Yunnan53000015943c974920a76ac9c1SUV466
2201601Inner Mongolia15000014793c974920a76ac9c1SUV257
3201601Beijing11000023703c974920a76ac9c1SUV408
4201601Sichuan51000035623c974920a76ac9c1SUV610
\n", 229 | "
" 230 | ], 231 | "text/plain": [ 232 | " Date province provinceId popularity model bodyType \\\n", 233 | "0 201601 Shanghai 310000 1479 3c974920a76ac9c1 SUV \n", 234 | "1 201601 Yunnan 530000 1594 3c974920a76ac9c1 SUV \n", 235 | "2 201601 Inner Mongolia 150000 1479 3c974920a76ac9c1 SUV \n", 236 | "3 201601 Beijing 110000 2370 3c974920a76ac9c1 SUV \n", 237 | "4 201601 Sichuan 510000 3562 3c974920a76ac9c1 SUV \n", 238 | "\n", 239 | " salesVolume \n", 240 | "0 292 \n", 241 | "1 466 \n", 242 | "2 257 \n", 243 | "3 408 \n", 244 | "4 610 " 245 | ] 246 | }, 247 | "execution_count": 4, 248 | "metadata": {}, 249 | "output_type": "execute_result" 250 | } 251 | ], 252 | "source": [ 253 | "data = pd.read_csv(CFG.input_dir)\n", 254 | "\n", 255 | "data.head()" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 5, 261 | "id": "6b674127", 262 | "metadata": { 263 | "execution": { 264 | "iopub.execute_input": "2025-05-16T02:50:09.635994Z", 265 | "iopub.status.busy": "2025-05-16T02:50:09.635773Z", 266 | "iopub.status.idle": "2025-05-16T02:50:09.642094Z", 267 | "shell.execute_reply": "2025-05-16T02:50:09.641305Z" 268 | }, 269 | "papermill": { 270 | "duration": 0.01082, 271 | "end_time": "2025-05-16T02:50:09.643286", 272 | "exception": false, 273 | "start_time": "2025-05-16T02:50:09.632466", 274 | "status": "completed" 275 | }, 276 | "tags": [] 277 | }, 278 | "outputs": [], 279 | "source": [ 280 | "# https://github.com/hongyingyue/Vehicle-sales-predictor/blob/main/vehicle_ml/feature/ts_feature.py\n", 281 | "\n", 282 | "logger = logging.getLogger(__name__)\n", 283 | "\n", 284 | "def add_lagging_feature(\n", 285 | " data: pd.DataFrame,\n", 286 | " groupby_column: Union[str, List[str]],\n", 287 | " value_columns: List[str],\n", 288 | " lags: List[int],\n", 289 | " feature_columns: Optional[List[str]] = None,\n", 290 | "):\n", 291 | " # note that the data should be sorted by time already\n", 292 | " # the lagging feature could be further developed use f1 - f1_lag, or f1 / f1_lag\n", 293 | "\n", 294 | " if not isinstance(groupby_column, (str, list)):\n", 295 | " raise TypeError(f\"'groupby_column' must be a string or a list of strings, but got {type(groupby_column)}.\")\n", 296 | "\n", 297 | " if not isinstance(value_columns, (list, tuple)):\n", 298 | " raise TypeError(f\"'value_columns' must be a list of strings, but got {type(value_columns)}.\")\n", 299 | "\n", 300 | " feature_columns: List[str] = feature_columns if feature_columns is not None else []\n", 301 | " for column in value_columns:\n", 302 | " if column not in data.columns:\n", 303 | " raise ValueError(f\"Value column '{column}' not found in DataFrame.\")\n", 304 | "\n", 305 | " for lag in lags:\n", 306 | " feature_col_name = f\"{column}_lag{lag}\"\n", 307 | " feature_columns.append(feature_col_name)\n", 308 | " logger.debug(\n", 309 | " f\"Creating lagging feature: {feature_col_name} for column '{column}' with lag {lag} and groupby '{groupby_column}'.\"\n", 310 | " )\n", 311 | " data[feature_col_name] = data.groupby(groupby_column)[column].shift(lag)\n", 312 | " return data" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": 6, 318 | "id": "6c26a66d", 319 | "metadata": { 320 | "execution": { 321 | "iopub.execute_input": "2025-05-16T02:50:09.649177Z", 322 | "iopub.status.busy": "2025-05-16T02:50:09.648793Z", 323 | "iopub.status.idle": "2025-05-16T02:50:09.734527Z", 324 | "shell.execute_reply": "2025-05-16T02:50:09.733463Z" 325 | }, 326 | "papermill": { 327 | "duration": 0.090844, 328 | "end_time": 
"2025-05-16T02:50:09.736508", 329 | "exception": false, 330 | "start_time": "2025-05-16T02:50:09.645664", 331 | "status": "completed" 332 | }, 333 | "tags": [] 334 | }, 335 | "outputs": [ 336 | { 337 | "name": "stderr", 338 | "output_type": "stream", 339 | "text": [ 340 | "/usr/local/lib/python3.11/dist-packages/pandas/io/formats/format.py:1458: RuntimeWarning: invalid value encountered in greater\n", 341 | " has_large_values = (abs_vals > 1e6).any()\n", 342 | "/usr/local/lib/python3.11/dist-packages/pandas/io/formats/format.py:1459: RuntimeWarning: invalid value encountered in less\n", 343 | " has_small_values = ((abs_vals < 10 ** (-self.digits)) & (abs_vals > 0)).any()\n", 344 | "/usr/local/lib/python3.11/dist-packages/pandas/io/formats/format.py:1459: RuntimeWarning: invalid value encountered in greater\n", 345 | " has_small_values = ((abs_vals < 10 ** (-self.digits)) & (abs_vals > 0)).any()\n" 346 | ] 347 | }, 348 | { 349 | "data": { 350 | "text/html": [ 351 | "
\n", 352 | "\n", 365 | "\n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | "
DateprovinceprovinceIdpopularitymodelbodyTypesalesVolumesalesVolume_lag1salesVolume_lag2salesVolume_lag3salesVolume_lag4salesVolume_lag5salesVolume_lag6salesVolume_lag7salesVolume_lag8salesVolume_lag9salesVolume_lag10salesVolume_lag11
0201601Shanghai31000014793c974920a76ac9c1SUV292NaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
1201601Yunnan53000015943c974920a76ac9c1SUV466NaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
2201601Inner Mongolia15000014793c974920a76ac9c1SUV257NaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
3201601Beijing11000023703c974920a76ac9c1SUV408NaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
4201601Sichuan51000035623c974920a76ac9c1SUV610NaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
\n", 497 | "
" 498 | ], 499 | "text/plain": [ 500 | " Date province provinceId popularity model bodyType \\\n", 501 | "0 201601 Shanghai 310000 1479 3c974920a76ac9c1 SUV \n", 502 | "1 201601 Yunnan 530000 1594 3c974920a76ac9c1 SUV \n", 503 | "2 201601 Inner Mongolia 150000 1479 3c974920a76ac9c1 SUV \n", 504 | "3 201601 Beijing 110000 2370 3c974920a76ac9c1 SUV \n", 505 | "4 201601 Sichuan 510000 3562 3c974920a76ac9c1 SUV \n", 506 | "\n", 507 | " salesVolume salesVolume_lag1 salesVolume_lag2 salesVolume_lag3 \\\n", 508 | "0 292 NaN NaN NaN \n", 509 | "1 466 NaN NaN NaN \n", 510 | "2 257 NaN NaN NaN \n", 511 | "3 408 NaN NaN NaN \n", 512 | "4 610 NaN NaN NaN \n", 513 | "\n", 514 | " salesVolume_lag4 salesVolume_lag5 salesVolume_lag6 salesVolume_lag7 \\\n", 515 | "0 NaN NaN NaN NaN \n", 516 | "1 NaN NaN NaN NaN \n", 517 | "2 NaN NaN NaN NaN \n", 518 | "3 NaN NaN NaN NaN \n", 519 | "4 NaN NaN NaN NaN \n", 520 | "\n", 521 | " salesVolume_lag8 salesVolume_lag9 salesVolume_lag10 salesVolume_lag11 \n", 522 | "0 NaN NaN NaN NaN \n", 523 | "1 NaN NaN NaN NaN \n", 524 | "2 NaN NaN NaN NaN \n", 525 | "3 NaN NaN NaN NaN \n", 526 | "4 NaN NaN NaN NaN " 527 | ] 528 | }, 529 | "execution_count": 6, 530 | "metadata": {}, 531 | "output_type": "execute_result" 532 | } 533 | ], 534 | "source": [ 535 | "feature_columns = []\n", 536 | "\n", 537 | "data = add_lagging_feature(data, groupby_column=[\"provinceId\", \"model\"], value_columns=[\"salesVolume\"], lags=list(range(1, 12)), feature_columns=feature_columns)\n", 538 | "\n", 539 | "data.head()" 540 | ] 541 | }, 542 | { 543 | "cell_type": "code", 544 | "execution_count": 7, 545 | "id": "457d453e", 546 | "metadata": { 547 | "execution": { 548 | "iopub.execute_input": "2025-05-16T02:50:09.745696Z", 549 | "iopub.status.busy": "2025-05-16T02:50:09.745164Z", 550 | "iopub.status.idle": "2025-05-16T02:50:11.128034Z", 551 | "shell.execute_reply": "2025-05-16T02:50:11.127244Z" 552 | }, 553 | "papermill": { 554 | "duration": 1.3884, 555 | "end_time": "2025-05-16T02:50:11.129283", 556 | "exception": false, 557 | "start_time": "2025-05-16T02:50:09.740883", 558 | "status": "completed" 559 | }, 560 | "tags": [] 561 | }, 562 | "outputs": [ 563 | { 564 | "name": "stderr", 565 | "output_type": "stream", 566 | "text": [ 567 | "/tmp/ipykernel_19/3102104721.py:1: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. 
Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n", 568 | " grouped_sequence = data.groupby([\"provinceId\", \"model\"]).apply(\n" 569 | ] 570 | }, 571 | { 572 | "data": { 573 | "text/plain": [ 574 | "array([[[ 799., nan, nan, nan],\n", 575 | " [ 424., 799., nan, nan],\n", 576 | " [ 733., 424., 799., nan],\n", 577 | " ...,\n", 578 | " [ 544., 659., 630., 670.],\n", 579 | " [ 647., 544., 659., 630.],\n", 580 | " [ 640., 647., 544., 659.]],\n", 581 | "\n", 582 | " [[ 135., nan, nan, nan],\n", 583 | " [ 57., 135., nan, nan],\n", 584 | " [ 160., 57., 135., nan],\n", 585 | " ...,\n", 586 | " [ 105., 201., 120., 135.],\n", 587 | " [ 148., 105., 201., 120.],\n", 588 | " [ 112., 148., 105., 201.]],\n", 589 | "\n", 590 | " [[ 872., nan, nan, nan],\n", 591 | " [ 197., 872., nan, nan],\n", 592 | " [ 494., 197., 872., nan],\n", 593 | " ...,\n", 594 | " [ 152., 170., 181., 159.],\n", 595 | " [ 213., 152., 170., 181.],\n", 596 | " [ 226., 213., 152., 170.]],\n", 597 | "\n", 598 | " ...,\n", 599 | "\n", 600 | " [[ 181., nan, nan, nan],\n", 601 | " [ 60., 181., nan, nan],\n", 602 | " [ 111., 60., 181., nan],\n", 603 | " ...,\n", 604 | " [ 330., 297., 252., 199.],\n", 605 | " [ 178., 330., 297., 252.],\n", 606 | " [ 185., 178., 330., 297.]],\n", 607 | "\n", 608 | " [[1023., nan, nan, nan],\n", 609 | " [ 517., 1023., nan, nan],\n", 610 | " [ 513., 517., 1023., nan],\n", 611 | " ...,\n", 612 | " [1110., 991., 975., 798.],\n", 613 | " [ 967., 1110., 991., 975.],\n", 614 | " [1581., 967., 1110., 991.]],\n", 615 | "\n", 616 | " [[ 170., nan, nan, nan],\n", 617 | " [ 37., 170., nan, nan],\n", 618 | " [ 124., 37., 170., nan],\n", 619 | " ...,\n", 620 | " [ 229., 236., 208., 749.],\n", 621 | " [ 240., 229., 236., 208.],\n", 622 | " [ 337., 240., 229., 236.]]])" 623 | ] 624 | }, 625 | "execution_count": 7, 626 | "metadata": {}, 627 | "output_type": "execute_result" 628 | } 629 | ], 630 | "source": [ 631 | "grouped_sequence = data.groupby([\"provinceId\", \"model\"]).apply(\n", 632 | " lambda x: x.sort_values('Date')[[\"salesVolume\", \"salesVolume_lag1\", \"salesVolume_lag2\", \"salesVolume_lag3\"]].to_numpy()\n", 633 | ")\n", 634 | "\n", 635 | "data_3d = np.stack(grouped_sequence.values)\n", 636 | "\n", 637 | "data_3d" 638 | ] 639 | }, 640 | { 641 | "cell_type": "code", 642 | "execution_count": 8, 643 | "id": "b141753f", 644 | "metadata": { 645 | "execution": { 646 | "iopub.execute_input": "2025-05-16T02:50:11.136255Z", 647 | "iopub.status.busy": "2025-05-16T02:50:11.135797Z", 648 | "iopub.status.idle": "2025-05-16T02:50:11.144981Z", 649 | "shell.execute_reply": "2025-05-16T02:50:11.144469Z" 650 | }, 651 | "papermill": { 652 | "duration": 0.013747, 653 | "end_time": "2025-05-16T02:50:11.146090", 654 | "exception": false, 655 | "start_time": "2025-05-16T02:50:11.132343", 656 | "status": "completed" 657 | }, 658 | "tags": [] 659 | }, 660 | "outputs": [], 661 | "source": [ 662 | "from tensorflow.keras.utils import Sequence\n", 663 | "\n", 664 | "\n", 665 | "class TimeDataset(Sequence):\n", 666 | " def __init__(self, data, train_sequence_length, predict_sequence_length, batch_size: int = 64):\n", 667 | " self.data = data\n", 668 | " self.train_seq_len = train_sequence_length\n", 669 | " self.pred_seq_len = predict_sequence_length\n", 670 | " self.batch_size = batch_size\n", 671 | "\n", 672 | " self.num_ids = data.shape[0]\n", 673 | " self.max_seq_len = data.shape[1]\n", 674 | " self.feature_dim = data.shape[2]\n", 675 | "\n", 
676 | " self.samples_per_id = self.max_seq_len - self.train_seq_len - self.pred_seq_len + 1\n", 677 | " self.total_samples = self.num_ids * self.samples_per_id\n", 678 | "\n", 679 | " # Precompute all valid (id, start_idx) pairs\n", 680 | " self.indices = [\n", 681 | " (i, j)\n", 682 | " for i in range(self.num_ids)\n", 683 | " for j in range(self.samples_per_id)\n", 684 | " ]\n", 685 | " \n", 686 | " def __getitem__(self, index):\n", 687 | " # batch-wise item \n", 688 | " batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]\n", 689 | " \n", 690 | " x_batch = []\n", 691 | " y_batch = []\n", 692 | "\n", 693 | " for id_idx, start_idx in batch_indices:\n", 694 | " x = self.data[id_idx, start_idx:start_idx + self.train_seq_len, 1:]\n", 695 | " y = self.data[id_idx, start_idx + self.train_seq_len:start_idx + self.train_seq_len + self.pred_seq_len, 0]\n", 696 | " x_batch.append(x)\n", 697 | " y_batch.append(y)\n", 698 | "\n", 699 | " return np.nan_to_num(np.array(x_batch)), np.nan_to_num(np.array(y_batch))\n", 700 | " \n", 701 | " def __len__(self):\n", 702 | " # depends on how many samples you want to extract from 1 ID\n", 703 | " return int(np.ceil(len(self.indices) / self.batch_size))" 704 | ] 705 | }, 706 | { 707 | "cell_type": "code", 708 | "execution_count": 9, 709 | "id": "ac417c3d", 710 | "metadata": { 711 | "execution": { 712 | "iopub.execute_input": "2025-05-16T02:50:11.153097Z", 713 | "iopub.status.busy": "2025-05-16T02:50:11.152458Z", 714 | "iopub.status.idle": "2025-05-16T02:50:11.162165Z", 715 | "shell.execute_reply": "2025-05-16T02:50:11.161518Z" 716 | }, 717 | "papermill": { 718 | "duration": 0.014177, 719 | "end_time": "2025-05-16T02:50:11.163314", 720 | "exception": false, 721 | "start_time": "2025-05-16T02:50:11.149137", 722 | "status": "completed" 723 | }, 724 | "tags": [] 725 | }, 726 | "outputs": [ 727 | { 728 | "name": "stdout", 729 | "output_type": "stream", 730 | "text": [ 731 | "(64, 12, 3)\n", 732 | "(64, 3)\n" 733 | ] 734 | } 735 | ], 736 | "source": [ 737 | "train_dataset = TimeDataset(data_3d, CFG.train_sequence_length, CFG.predict_sequence_length)\n", 738 | "valid_dataset = TimeDataset(data_3d, CFG.train_sequence_length, CFG.predict_sequence_length)\n", 739 | "\n", 740 | "print(train_dataset[0][0].shape)\n", 741 | "print(train_dataset[0][1].shape)" 742 | ] 743 | }, 744 | { 745 | "cell_type": "code", 746 | "execution_count": 10, 747 | "id": "6a3cc9ee", 748 | "metadata": { 749 | "execution": { 750 | "iopub.execute_input": "2025-05-16T02:50:11.170327Z", 751 | "iopub.status.busy": "2025-05-16T02:50:11.169583Z", 752 | "iopub.status.idle": "2025-05-16T02:50:12.527408Z", 753 | "shell.execute_reply": "2025-05-16T02:50:12.526739Z" 754 | }, 755 | "papermill": { 756 | "duration": 1.36238, 757 | "end_time": "2025-05-16T02:50:12.528531", 758 | "exception": false, 759 | "start_time": "2025-05-16T02:50:11.166151", 760 | "status": "completed" 761 | }, 762 | "tags": [] 763 | }, 764 | "outputs": [ 765 | { 766 | "data": { 767 | "text/html": [ 768 | "
Model: \"functional\"\n",
769 |        "
\n" 770 | ], 771 | "text/plain": [ 772 | "\u001b[1mModel: \"functional\"\u001b[0m\n" 773 | ] 774 | }, 775 | "metadata": {}, 776 | "output_type": "display_data" 777 | }, 778 | { 779 | "data": { 780 | "text/html": [ 781 | "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n",
782 |        "┃ Layer (type)                          Output Shape                         Param # ┃\n",
783 |        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n",
784 |        "│ input_layer (InputLayer)             │ (None, 12, 3)               │               0 │\n",
785 |        "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
786 |        "│ encoder (Encoder)                    │ [(None, 12, 64), (None,     │               0 │\n",
787 |        "│                                      │ 128)]                       │                 │\n",
788 |        "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
789 |        "│ dense (Dense)                        │ (None, 128)                 │          16,512 │\n",
790 |        "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
791 |        "│ dense_1 (Dense)                      │ (None, 128)                 │          16,512 │\n",
792 |        "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
793 |        "│ dense_2 (Dense)                      │ (None, 1)                   │             129 │\n",
794 |        "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
795 |        "│ reshape (Reshape)                    │ (None, 1, 1)                │               0 │\n",
796 |        "└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n",
797 |        "
\n" 798 | ], 799 | "text/plain": [ 800 | "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n", 801 | "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", 802 | "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n", 803 | "│ input_layer (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", 804 | "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", 805 | "│ encoder (\u001b[38;5;33mEncoder\u001b[0m) │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m, \u001b[38;5;34m64\u001b[0m), (\u001b[38;5;45mNone\u001b[0m, │ \u001b[38;5;34m0\u001b[0m │\n", 806 | "│ │ \u001b[38;5;34m128\u001b[0m)] │ │\n", 807 | "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", 808 | "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m16,512\u001b[0m │\n", 809 | "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", 810 | "│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m16,512\u001b[0m │\n", 811 | "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", 812 | "│ dense_2 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m129\u001b[0m │\n", 813 | "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", 814 | "│ reshape (\u001b[38;5;33mReshape\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", 815 | "└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n" 816 | ] 817 | }, 818 | "metadata": {}, 819 | "output_type": "display_data" 820 | }, 821 | { 822 | "data": { 823 | "text/html": [ 824 | "
 Total params: 33,153 (129.50 KB)\n",
825 |        "
\n" 826 | ], 827 | "text/plain": [ 828 | "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m33,153\u001b[0m (129.50 KB)\n" 829 | ] 830 | }, 831 | "metadata": {}, 832 | "output_type": "display_data" 833 | }, 834 | { 835 | "data": { 836 | "text/html": [ 837 | "
 Trainable params: 33,153 (129.50 KB)\n",
838 |        "
\n" 839 | ], 840 | "text/plain": [ 841 | "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m33,153\u001b[0m (129.50 KB)\n" 842 | ] 843 | }, 844 | "metadata": {}, 845 | "output_type": "display_data" 846 | }, 847 | { 848 | "data": { 849 | "text/html": [ 850 | "
 Non-trainable params: 0 (0.00 B)\n",
851 |        "
\n" 852 | ], 853 | "text/plain": [ 854 | "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" 855 | ] 856 | }, 857 | "metadata": {}, 858 | "output_type": "display_data" 859 | } 860 | ], 861 | "source": [ 862 | "def build_model():\n", 863 | " inputs = tf.keras.Input(shape=(CFG.train_sequence_length, 3))\n", 864 | " \n", 865 | " config = AutoConfig()(\"rnn\")\n", 866 | " config.rnn_type = \"lstm\"\n", 867 | " backbone = AutoModel.from_config(config=config)\n", 868 | " \n", 869 | " outputs = backbone(inputs)\n", 870 | " model = tf.keras.Model(inputs=inputs, outputs=outputs)\n", 871 | " model.compile(loss=tf.keras.losses.MeanAbsoluteError(), optimizer=tf.keras.optimizers.Adam(), metrics = ['mae'])\n", 872 | " return model\n", 873 | "\n", 874 | "\n", 875 | "model = build_model()\n", 876 | "model.summary()" 877 | ] 878 | }, 879 | { 880 | "cell_type": "code", 881 | "execution_count": 11, 882 | "id": "d13f0ce1", 883 | "metadata": { 884 | "execution": { 885 | "iopub.execute_input": "2025-05-16T02:50:12.536919Z", 886 | "iopub.status.busy": "2025-05-16T02:50:12.536488Z", 887 | "iopub.status.idle": "2025-05-16T02:50:40.889410Z", 888 | "shell.execute_reply": "2025-05-16T02:50:40.888816Z" 889 | }, 890 | "papermill": { 891 | "duration": 28.358352, 892 | "end_time": "2025-05-16T02:50:40.890731", 893 | "exception": false, 894 | "start_time": "2025-05-16T02:50:12.532379", 895 | "status": "completed" 896 | }, 897 | "tags": [] 898 | }, 899 | "outputs": [ 900 | { 901 | "name": "stdout", 902 | "output_type": "stream", 903 | "text": [ 904 | "Epoch 1/10\n" 905 | ] 906 | }, 907 | { 908 | "name": "stderr", 909 | "output_type": "stream", 910 | "text": [ 911 | "/usr/local/lib/python3.11/dist-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:121: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. 
Do not pass these arguments to `fit()`, as they will be ignored.\n", 912 | " self._warn_if_super_not_called()\n", 913 | "I0000 00:00:1747363815.790123 65 cuda_dnn.cc:529] Loaded cuDNN version 90300\n" 914 | ] 915 | }, 916 | { 917 | "name": "stdout", 918 | "output_type": "stream", 919 | "text": [ 920 | "\u001b[1m282/282\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 9ms/step - loss: 324.2088 - mae: 324.2088 - val_loss: 197.0513 - val_mae: 197.0513\n", 921 | "Epoch 2/10\n", 922 | "\u001b[1m282/282\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 8ms/step - loss: 211.5362 - mae: 211.5362 - val_loss: 187.4890 - val_mae: 187.4890\n", 923 | "Epoch 3/10\n", 924 | "\u001b[1m282/282\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 9ms/step - loss: 174.5251 - mae: 174.5251 - val_loss: 162.3955 - val_mae: 162.3955\n", 925 | "Epoch 4/10\n", 926 | "\u001b[1m282/282\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 9ms/step - loss: 172.6773 - mae: 172.6773 - val_loss: 149.7986 - val_mae: 149.7986\n", 927 | "Epoch 5/10\n", 928 | "\u001b[1m282/282\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 8ms/step - loss: 168.0907 - mae: 168.0907 - val_loss: 148.2648 - val_mae: 148.2648\n", 929 | "Epoch 6/10\n", 930 | "\u001b[1m282/282\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 9ms/step - loss: 163.7792 - mae: 163.7792 - val_loss: 157.9741 - val_mae: 157.9741\n", 931 | "Epoch 7/10\n", 932 | "\u001b[1m282/282\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 8ms/step - loss: 166.4666 - mae: 166.4666 - val_loss: 166.7117 - val_mae: 166.7117\n", 933 | "Epoch 8/10\n", 934 | "\u001b[1m282/282\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 8ms/step - loss: 150.0828 - mae: 150.0828 - val_loss: 159.2332 - val_mae: 159.2332\n", 935 | "Epoch 9/10\n", 936 | "\u001b[1m282/282\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 8ms/step - loss: 140.8101 - mae: 140.8101 - val_loss: 187.0695 - val_mae: 187.0695\n", 937 | "Epoch 10/10\n", 938 | "\u001b[1m282/282\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 8ms/step - loss: 154.2695 - mae: 154.2695 - val_loss: 137.9109 - val_mae: 137.9109\n" 939 | ] 940 | } 941 | ], 942 | "source": [ 943 | "history = model.fit(train_dataset, validation_data=valid_dataset, epochs=10) \n", 944 | "model.save_weights('./sales_model.weights.h5')" 945 | ] 946 | } 947 | ], 948 | "metadata": { 949 | "kaggle": { 950 | "accelerator": "gpu", 951 | "dataSources": [ 952 | { 953 | "datasetId": 7421883, 954 | "sourceId": 11816396, 955 | "sourceType": "datasetVersion" 956 | } 957 | ], 958 | "dockerImageVersionId": 31041, 959 | "isGpuEnabled": true, 960 | "isInternetEnabled": true, 961 | "language": "python", 962 | "sourceType": "notebook" 963 | }, 964 | "kernelspec": { 965 | "display_name": "Python 3", 966 | "language": "python", 967 | "name": "python3" 968 | }, 969 | "language_info": { 970 | "codemirror_mode": { 971 | "name": "ipython", 972 | "version": 3 973 | }, 974 | "file_extension": ".py", 975 | "mimetype": "text/x-python", 976 | "name": "python", 977 | "nbconvert_exporter": "python", 978 | "pygments_lexer": "ipython3", 979 | "version": "3.11.11" 980 | }, 981 | "papermill": { 982 | "default_parameters": {}, 
983 | "duration": 60.266353, 984 | "end_time": "2025-05-16T02:50:44.081675", 985 | "environment_variables": {}, 986 | "exception": null, 987 | "input_path": "__notebook__.ipynb", 988 | "output_path": "__notebook__.ipynb", 989 | "parameters": {}, 990 | "start_time": "2025-05-16T02:49:43.815322", 991 | "version": "2.6.0" 992 | } 993 | }, 994 | "nbformat": 4, 995 | "nbformat_minor": 5 996 | } 997 | --------------------------------------------------------------------------------