├── README.md
├── build.sbt
├── img
├── architecture.png
├── introduce.png
├── job_input_output.png
├── result_plot.png
└── tmp
├── project
├── build.properties
└── plugins.sbt
└── src
└── main
└── scala
└── com
└── bigdata
├── CommonUtils.scala
├── PatternFinder.scala
├── Preprocessor.scala
├── StockPatternStream.scala
└── StreamingManager.scala
/README.md:
--------------------------------------------------------------------------------
1 | # Stock Pattern Stream
2 | ### 현재 주가 차트와 유사한 과거 차트를 실시간으로 찾아주는 분석 서비스
3 |
4 | 주식 종목의 가격과 거래량을 보고 금융 시장의 흐름을 예측하는 것을 "기술적 분석"이라 합니다.
5 | 종목의 주가 패턴은 때떄로 반복적으로 나타나기도 하며, 주식 거래를 하는 많은 사람들이 주가를 예측하기 위해 이러한 패턴을 분석하곤 합니다.
6 | 다음 그림처럼 만약 현재의 주가 차트가 과거의 어떤 시점과 매우 유사하다면, 해당 과거 시점의 추이가 현재 주가의 추이를 예측하는데 도움이 될 수도 있습니다.
7 |
8 |
9 |
10 |
11 |
12 | 본 레포지터리는 *"종목의 과거 주가 차트에서 현재 차트와 가장 유사한 구간을 실시간으로 찾아주는 프로그램이 있다면 편리하지 않을까?"*
13 | 라는 아이디어를 Spark Streaming으로 구현한 것입니다.
14 |
15 |
16 | ## 1. Architecture
17 |
18 | ### 1.1. 실시간 과거 패턴 찾기
19 |
20 | 본 프로세스에선 종목의 과거 주가 차트에 window sliding을 적용하여 n 크기의 슬라이드들을 생성합니다.
21 |
22 | Spark Streaming Job은 종목의 "최근 n개 분봉"을 입력받으면,
23 | 미리 생성한 슬라이드들과 Pearson 상관계수를 계산하여 가장 유사한 슬라이드를 찾습니다.
24 |
25 |
26 |
27 |
28 |
29 | 한 종목에 대하여 위와 같은 패턴 찾기 작업이 parallel하게 수행됩니다.
30 | 분석 대상 종목이 여러 개인 경우, 본 작업을 각 종목에 대해 순차적으로 수행합니다.
31 |
32 | ### 1.2. 전체 구성
33 |
34 |
35 |
36 |
37 | - Kafka Producer
38 | - 키움 OpenAPI 등으로 종목 분봉 데이터를 실시간으로 수집하여 Kafka에 전송해주는 producer 프로세스
39 | - 분석 대상 종목들의 최근 n개 분봉을 전송
40 | - Kafka
41 | - 분봉 데이터 저장 Queue
42 | - 실시간 과거 패턴 찾기
43 | - 본 레포지터리의 구현 범위
44 | - Spark Streaming Job으로 Kafka의 데이터를 입력 받아 실시간으로 분석
45 | - 분석 결과를 FS에 저장
46 | - ELK Stack
47 | - Logstash, Elasticsearch, Kibana를 이용하여 실시간 분석 결과를 시각화
48 | - Logstash: 저장된 분석 결과를 Elasticsearch에 입력
49 | - Kibana: 실시간 분석 결과를 대시보드에 시각화
50 |
51 |
52 | ### 1.3. 분석 스펙
53 |
54 | - Spark Mode
55 | - Standalone (별도의 Hadoop cluster 없이 동작)
56 | - (Hadoop cluster 모드는 추후 고려 예정)
57 | - Streaming Perfomance
58 | - 10개 이내 종목에 대해 1분 단위 (batchInterval) 스트리밍 분석
59 | - default batchInterval은 5분으로 설정됨
60 | - 분석 데이터 길이
61 | - window slide 길이 n: 59 분봉
62 | - 과거치 분봉 데이터 길이: 629 분봉
63 |
64 | ### 1.4. Input/Output 데이터 구조
65 | 현재 일봉 데이터로 분봉을 대체하고 있으며, 추후 분봉으로 변경 예정임
66 |
67 | |input data| -> |output data|
68 | |:---:|:---:|:---:|
69 | |종목별 과거분봉 데이터종목별 실시간 분봉 데이터종목코드->종목명 매핑 테이블|(*spark streaming job*)|실시간 분석 결과|
70 |
71 | - (Input) 종목별 과거 분봉 데이터
72 | - 입력 방식: spark-submit argument 중 ```--hist-path [file path]```로 파일 경로 입력
73 | - 데이터 포맷: json
74 | ```
75 | // {시간1:분봉가격, 시간2:분봉가격, ..., 시간629:분봉가격, symb:종목코드}
76 | {"20160201": 5630,"20160202": 5633,...,"symb":"A123456"},{...},...
77 | ```
78 | - (Input) 종목별 실시간 분봉 데이터
79 | - 입력 방식: kafka consuming (message's value)
80 | - 데이터 포맷: json
81 | ```
82 | // {시간1:분봉가격, 시간2:분봉가격, ..., 시간59:분봉가격, symb:종목코드}
83 | {"20160201": 5630,"20160202": 5633,...,"symb":"A123456"},{...},...
84 | ```
85 | - (Input) 종목코드->종목명 매핑 테이블
86 | - 입력 방식: spark-submit argument 중 ```--symb2name-path [file path]```로 파일 경로 입력
87 | - 데이터 포맷: json
88 | ```
89 | // [{"symb":종목코드,"name":종목명}, ...]
90 | [{"symb":"A001720","name":"신영증권"},...]
91 | ```
92 | - (Output) 실시간 분석 결과
93 | - 저장 방식: spark-submit argument 중 ```--output-dir [directory path]``` 값 경로에 파일로 저장
94 | - 데이터 포맷: csv
95 | ```
96 | // 종목코드,종목명,실시간분봉시간,실시간분봉가격,과거분봉시간,과거분봉가격,버전,상관계수
97 | A097950,CJ제일제당,20190705,294000.0,20190402,326000.0,20190930,0.949744
98 | A097950,CJ제일제당,20190708,289000.0,20190403,326000.0,20190930,0.949744
99 | A097950,CJ제일제당,20190709,284000.0,20190404,325000.0,20190930,0.949744
100 | ...(후략)...
101 | ```
102 | - 데이터 시각화 예시 (CJ제일제당, 유사도 94%)
103 | - Now: 실시간분봉시간에 따른 실시간분봉가격
104 | - Then: 과거분봉시간에 따른 과거분봉가격
105 |
106 |
107 |
108 |
109 |
110 |
111 | ### 1.5. 주요 소스 코드
112 |
113 | - StockPatternStream.scala
114 | - 메인 싱글톤. arguments parsing 및 분석 함수 호출
115 | - Preprocessor.scala
116 | - DataFrame 전처리 수행
117 | - PatternFinder.scala
118 | - Spark ML을 이용한 Correlation 계산
119 | - StreamingManager.scala
120 | - Spark Streaming 정의
121 | - CommonUtils.scala
122 | - 유틸리티 함수 모음
123 |
124 | ## 2. How to use
125 |
126 | ### 2.1. Prerequisites
127 | |name|version|
128 | |:---|:---|
129 | |Scala|2.11.12|
130 | |SBT|1.3.10|
131 | |JDK|1.8.0|
132 | |Apache Spark|2.4.5|
133 | |spark-streaming-kafka|0.10.0|
134 |
135 | - Spark은 standalone으로 동작함
136 |
137 | ### 2.2. Usage
138 |
139 | create fat jar
140 | ```
141 | sbt assembly
142 | ```
143 |
144 | spark submit
145 | ```
146 | spark-submit jarfile --hist-path [path1] --symb2name-path [path2] --output-dir [path3]
147 | --kafka-bootstrap-server [addr] --kafka-group-id [id] --kafka-topic [topic] --batch-interval [seconds]
148 |
149 | arguments
150 | --hist-path: 종목별 과거 분봉 데이터 경로
151 | --symb2name-path: 종목코드->종목명 매핑 테이블 경로
152 | --output-dir: 분석 결과 저장 경로
153 | --kafka-bootstrap-server: 카프카 부트스트랩 주소 (localhost:9092)
154 | --kafka-group-id: 컨슈머 그룹 ID
155 | --kafka-topic: 토픽명
156 | --batch-interval: Spark Streaming 배치 간격 (default: 300)
157 | ```
158 |
159 |
160 |
161 | ## 3. Future tasks
162 |
163 | - Spark Streaming과 Elasticsearch 연동, Logstash 단계 제거
164 | - 일봉 데이터 -> 분봉 데이터 변경 (현재 일봉 데이터로 분봉을 대체함)
165 | - 싱글모드 -> 클러스터 전환 성능 실험
166 |
--------------------------------------------------------------------------------
/build.sbt:
--------------------------------------------------------------------------------
1 | name := "StockPatternStream"
2 |
3 | version := "0.1"
4 |
5 | scalaVersion := "2.11.12"
6 |
7 |
8 | libraryDependencies += "org.apache.spark" %% "spark-core" % "2.4.5" % "provided"
9 | libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.4.5" % "provided"
10 | libraryDependencies += "org.apache.spark" %% "spark-mllib" % "2.4.5" % "provided"
11 | libraryDependencies += "org.apache.spark" %% "spark-streaming" % "2.4.5" % "provided"
12 | libraryDependencies += "org.apache.spark" %% "spark-streaming-kafka-0-10" % "2.4.5" % "provided"
13 | libraryDependencies += "org.apache.spark" %% "spark-hive" % "2.4.5" % "provided"
14 | libraryDependencies += "org.apache.kafka" % "kafka-clients" % "2.4.1" % "provided"
15 | libraryDependencies += "org.apache.kafka" %% "kafka" % "2.4.1" % "provided"
16 |
17 |
18 |
19 |
--------------------------------------------------------------------------------
/img/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dhkdn9192/StockPatternStream/26358b671e0bb96a20706453618d3df02d9145e8/img/architecture.png
--------------------------------------------------------------------------------
/img/introduce.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dhkdn9192/StockPatternStream/26358b671e0bb96a20706453618d3df02d9145e8/img/introduce.png
--------------------------------------------------------------------------------
/img/job_input_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dhkdn9192/StockPatternStream/26358b671e0bb96a20706453618d3df02d9145e8/img/job_input_output.png
--------------------------------------------------------------------------------
/img/result_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dhkdn9192/StockPatternStream/26358b671e0bb96a20706453618d3df02d9145e8/img/result_plot.png
--------------------------------------------------------------------------------
/img/tmp:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version = 1.3.10
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.10")
2 |
--------------------------------------------------------------------------------
/src/main/scala/com/bigdata/CommonUtils.scala:
--------------------------------------------------------------------------------
1 | package com.bigdata
2 |
3 | import org.apache.spark.sql.{DataFrame, SparkSession}
4 |
5 |
6 | /**
7 | * 프로젝트 공통 유틸리티 함수 모음
8 | * 2020.05.13 by dhkim
9 | */
10 |
11 | object CommonUtils {
12 |
13 | /**
14 | * SparkStreaming에서 broadcast할 변수들 담을 케이스 클래스
15 | * @param histSize: historical size
16 | * @param rtSize: real-time size
17 | * @param batchInterval: streaming interval
18 | * @param outputDir: save dir
19 | */
20 | case class BroadcastItems(histSize: Int, rtSize: Int, batchInterval: Int, outputDir: String) {}
21 |
22 |
23 | def getHistPriceMap(spark: SparkSession, rawPriceDf: DataFrame, priceSize: Int,
24 | dates: Array[String]): Map[String, Map[String, Double]] = {
25 | import spark.implicits._
26 |
27 | // Map[symb -> Map[date -> price]]
28 | val priceMap = rawPriceDf
29 | .map{ row =>
30 | val symb = row.getString(priceSize)
31 | val prices = (0 until priceSize).map(row.getLong(_).toDouble)
32 | (symb, dates.zip(prices).toMap)
33 | }
34 | .collect
35 | .toMap
36 |
37 | priceMap
38 | }
39 |
40 | def getRtPriceMap(spark: SparkSession, rawPriceDf: DataFrame, priceSize: Int): Map[String, Seq[Double]] = {
41 | import spark.implicits._
42 |
43 | val priceMap = rawPriceDf
44 | .map{ row =>
45 | val symb = row.getString(priceSize)
46 | val prices = (0 until priceSize).map(row.getLong(_).toDouble)
47 | (symb, prices)
48 | }
49 | .collect
50 | .toMap
51 |
52 | priceMap
53 | }
54 |
55 | def getSymb2NameMap(spark: SparkSession, path: String): Map[String, String] = {
56 | import spark.implicits._
57 |
58 | val symb2nameMapDf: DataFrame = spark.read.json(path)
59 | val symb2nameMap: Map[String, String] = symb2nameMapDf
60 | .map { r =>
61 | val symbCol = r.getAs[String]("symb")
62 | val nameCol = r.getAs[String]("name")
63 | (symbCol, nameCol)
64 | }
65 | .collect
66 | .toMap
67 |
68 | // symb to name Map
69 | symb2nameMap
70 | }
71 |
72 | def getStockSymbols(spark: SparkSession, rawHistDf: DataFrame): Array[String] = {
73 | import spark.implicits._
74 |
75 | val stockSymbols = rawHistDf
76 | .select($"symb")
77 | .collect
78 | .map(_(0).toString)
79 |
80 | stockSymbols
81 | }
82 |
83 | }
84 |
--------------------------------------------------------------------------------
/src/main/scala/com/bigdata/PatternFinder.scala:
--------------------------------------------------------------------------------
1 | package com.bigdata
2 |
3 | import org.apache.spark.ml.linalg.{Matrix, Vectors}
4 | import org.apache.spark.ml.stat.Correlation
5 | import org.apache.spark.sql.{DataFrame, Row, SparkSession}
6 | import org.apache.spark.sql.functions.explode
7 |
8 |
9 | /**
10 | * 종목별 과거 유사 패턴찾기 위한 Correlation 기능 구현
11 | * 2020.05.13 by dhkim
12 | */
13 |
14 | object PatternFinder {
15 |
16 | /**
17 | *
18 | * @param stkCode: 종목코드
19 | * @param mostSimDates: 과거 구간들 중 가장 유사한 구간의 날짜값들
20 | * @param coef: 유사도
21 | */
22 | case class CorrRow(stkCode: String, mostSimDates: Array[String], coef: Double) {}
23 |
24 | /**
25 | * @param spark: SparkSession
26 | * @param scaledHistDf: window sliding 및 min-max 스케일 적용된 historical price df
27 | * @param scaledRtDf: min-max 스케일 적용된 real-time price df
28 | * @param stockSymbols: 종목코드 배열. ex) Array(A123456, A888888, ...)
29 | * @param histDates: scaledHistDf의 날짜들. ex) Array(20200301, 20200302, ...)
30 | * @param rtDates: scaledRtDf의 날짜들.
31 | * @param histSize: 종목별 historical price 전체 길이
32 | * @param rtSize: real-time으로 분석할 종가 구간 길이
33 | */
34 | def getCorrelation(spark: SparkSession, scaledHistDf: DataFrame, scaledRtDf: DataFrame, stockSymbols: Array[String],
35 | histDates: Array[String], rtDates: Array[String], histSize: Int, rtSize: Int): Array[CorrRow] = {
36 |
37 | import spark.implicits._
38 |
39 | val slideSize = histSize - rtSize + 1
40 | val coefSize = slideSize + 1
41 |
42 | // 두 df를 join 및 slide들을 여러 row들로 분리(explode)
43 | val explodedDf = scaledRtDf.as("l")
44 | .join(scaledRtDf.as("r"), $"l.symb" === $"r.symb")
45 | .drop($"l.symb")
46 | .map{ row =>
47 | val slided = row.getAs[Seq[Seq[Double]]](0)
48 | val rt = row.getAs[Seq[Double]](1)
49 | val symb = row.getString(2)
50 | val combined = rt
51 | .zip(slided)
52 | .map(zipped => Seq(zipped._1) ++ zipped._2)
53 | (combined, symb)
54 | }
55 | .toDF("combined", "symb")
56 | .select(explode($"combined").as("exploded"), $"symb") // explode 함수로 row의 slide들을 여러 row들로 분리
57 | /* explodedDf.show()
58 | * +-----------------------------+--------+
59 | * |exploded |symb |
60 | * +-----------------------------+--------+
61 | * |[0.33, 1.0, 0.0, 0.0, ...] |A123456 |
62 | * |[0.66, 0.0, 0.14, 0.16, ...] |A123456 |
63 | * |... |... |
64 | * |[1.0, 0.0, 0.63, 1.0, ...] |A888888 |
65 | * |[1.0, 0.2, 1.0, 0.81, ...] |A888888 |
66 | * |... |... |
67 | * +-----------------------------+--------+
68 | */
69 |
70 | // correlation 계산을 위한 벡터화
71 | val vectorDf = explodedDf
72 | .rdd
73 | .map{ row =>
74 | val vectCol = row.getAs[Seq[Double]](0).toArray
75 | val symb = row.getString(1)
76 | (Vectors.dense(vectCol), symb)
77 | }
78 | .toDF("vectors", "symb")
79 | /* vectorDf.show()
80 | * +-----------------------------+--------+
81 | * |vectors |symb |
82 | * +-----------------------------+--------+
83 | * |[0.33, 1.0, 0.0, 0.0, ...] |A123456 |
84 | * |[0.66, 0.0, 0.14, 0.16, ...] |A123456 |
85 | * |... |... |
86 | * |[1.0, 0.0, 0.63, 1.0, ...] |A888888 |
87 | * |[1.0, 0.2, 1.0, 0.81, ...] |A888888 |
88 | * |... |... |
89 | * +-----------------------------+--------+
90 | */
91 |
92 | val resCorrRows = stockSymbols
93 | .map{ symb =>
94 | val filteredDf = vectorDf.filter($"symb" === symb)
95 | val Row(coefMat: Matrix) = Correlation.corr(filteredDf, "vectors").head
96 | val coefAry = coefMat.toArray.slice(1, coefSize) // 현재값 제외
97 |
98 | val mostSimilar = coefAry.zipWithIndex.maxBy(_._1)
99 | val mostSimilarDates = histDates.slice(mostSimilar._2, mostSimilar._2 + rtSize)
100 |
101 | // (종목코드, 상관구간 날짜들, 상관계수)
102 | CorrRow(symb, mostSimilarDates, mostSimilar._1)
103 | }
104 | /* resCorrRows
105 | * Array(CorrRow(A123456, Array(20100304, 20100305, ...), 0.4773), CorrRow(...), ...)
106 | */
107 |
108 | resCorrRows
109 | }
110 |
111 | }
112 |
--------------------------------------------------------------------------------
/src/main/scala/com/bigdata/Preprocessor.scala:
--------------------------------------------------------------------------------
1 | package com.bigdata
2 |
3 | import scala.math.floor
4 | import org.apache.spark.sql.{DataFrame, SparkSession}
5 |
6 |
7 | /**
8 | * DataFrame 로드 및 전처리
9 | * 2020.05.13 by dhkim
10 | */
11 |
12 | object Preprocessor {
13 |
14 | def getRawDf(spark: SparkSession, path: String): DataFrame = {
15 | val rawDf = spark.read.json(path)
16 | rawDf
17 | }
18 |
19 | def spRound(inval: Double): Double = {
20 | val step = 1000000
21 | floor(inval*step) / step
22 | }
23 |
24 | def getScaledHistDf(spark: SparkSession, rawHistDf: DataFrame, histSize: Int, rtSize: Int): DataFrame = {
25 | /*
26 | * 윈도우 슬라이딩 및 rt 기준 슬라이드별 min-max 스케일링 적용된 historical price df 생성
27 | * > input
28 | * +--------+--------+--------+--------+
29 | * |20100301|20100302|... |symb |
30 | * +--------+--------+--------+--------+
31 | * |24.0 |22.0 |... |A123456 |
32 | * |114.0 |110.0 |... |A888888 |
33 | * +--------+--------+--------+--------+
34 | * > output
35 | * +----------------------+--------+
36 | * |scaledHist |symb |
37 | * +----------------------+--------+
38 | * |[[1.0,0.93,...], ...] |A123456 |
39 | * |[[...], ...] |A888888 |
40 | * +----------------------+--------+
41 | */
42 | import spark.implicits._
43 |
44 | val slideSize = histSize - rtSize + 1
45 |
46 | val slidedHistDf = rawHistDf
47 | .map{ row =>
48 | // sliding
49 | val slidedHist = (0 until histSize)
50 | .map(row.getLong(_).toDouble)
51 | .toArray
52 | .sliding(slideSize)
53 | .map(_.toArray)
54 | .toArray
55 |
56 | // scaling
57 | val transposed = slidedHist.transpose
58 | val mins = transposed.map(_.min)
59 | val durations = transposed.map(i => i.max-i.min) // 스케일링 기준은 slide 가 아니라, 각 slide별 같은 위치값
60 | val scaledAry = slidedHist.map{ elem =>
61 | val scaled = elem
62 | .indices
63 | .map(i => spRound((elem(i)-mins(i))/durations(i)))
64 | .toArray
65 | scaled
66 | }
67 |
68 | // symb column
69 | val symb = row.getString(histSize)
70 | (scaledAry, symb)
71 | }
72 | .toDF("scaledHist", "symb")
73 |
74 | slidedHistDf
75 | }
76 |
77 | def getScaledRtDf(spark: SparkSession, rawRtDf: DataFrame, rtSize: Int): DataFrame = {
78 | /* real-time으로 들어오는 종목별 price를 array 파싱 및 min-max 스케일ㄹ이하여 새로운 df 생성
79 | * > input
80 | * +--------+--------+--------+--------+
81 | * |20200401|20200402|... |symb |
82 | * +--------+--------+--------+--------+
83 | * |77.0 |78.0 |... |A123456 |
84 | * |7.0 |7.0 |... |A888888 |
85 | * +--------+--------+--------+--------+
86 | * > output
87 | * +------------------+--------+
88 | * |scaledRt |symb |
89 | * +------------------+--------+
90 | * |[0.33, 0.66, ...] |A123456 |
91 | * |[1.0, 1.0, ...] |A888888 |
92 | * +------------------+--------+
93 | */
94 | import spark.implicits._
95 |
96 | val scaledRtDf = rawRtDf
97 | .map{ row =>
98 | // scaling
99 | val rt = (0 until rtSize).map(row.getLong(_).toDouble).toArray
100 | val duration = rt.max - rt.min
101 | val scaled = rt.map(v => spRound((v-rt.min)/duration))
102 |
103 | // symb column
104 | val symb = row.getString(rtSize)
105 | (scaled, symb)
106 | }
107 | .toDF("scaledRt", "symb")
108 |
109 | scaledRtDf
110 | }
111 |
112 | }
113 |
--------------------------------------------------------------------------------
/src/main/scala/com/bigdata/StockPatternStream.scala:
--------------------------------------------------------------------------------
1 | package com.bigdata
2 |
3 | import org.apache.spark.sql.SparkSession
4 | import com.bigdata.CommonUtils.BroadcastItems
5 |
6 |
7 | /**
8 | * 프로젝트 메인
9 | * 2020.05.13 by dhkim
10 | */
11 |
12 | object StockPatternStream {
13 |
14 | /**
15 | * argument parse
16 | * @param map: scala Map
17 | * @param argList: argument 문자열 리스트
18 | * @return 키-밸류 Map
19 | */
20 | @scala.annotation.tailrec
21 | def parseArgs(map : Map[Symbol, Any], argList: List[String]): Map[Symbol, Any] = {
22 | argList match {
23 | case Nil => map
24 | case "--hist-path" :: value :: tail => parseArgs(map ++ Map('histpath -> value.toString), tail)
25 | case "--symb2name-path" :: value :: tail => parseArgs(map ++ Map('symb2name -> value.toString), tail)
26 | case "--output-dir" :: value :: tail => parseArgs(map ++ Map('outputdir -> value.toString), tail)
27 | case "--kafka-bootstrap-server" :: value :: tail => parseArgs(map ++ Map('bootstrap -> value.toString), tail)
28 | case "--kafka-group-id" :: value :: tail => parseArgs(map ++ Map('groupid -> value.toString), tail)
29 | case "--kafka-topic" :: value :: tail => parseArgs(map ++ Map('topic -> value.toString), tail)
30 | case "--hist-size" :: value :: tail => parseArgs(map ++ Map('histsize -> value.toInt), tail)
31 | case "--rt-size" :: value :: tail => parseArgs(map ++ Map('rtsize -> value.toInt), tail)
32 | case "--batch-interval" :: value :: tail => parseArgs(map ++ Map('batchinterval -> value.toInt), tail)
33 | case option :: tail => println("Unknown argument " + option)
34 | sys.exit(1)
35 | }
36 | }
37 |
38 | def main(args: Array[String]): Unit = {
39 |
40 | val usage =
41 | """Usage: spark-submit jarfile --hist-path [path1] --symb2name-path [path2] --output-dir [path3]
42 | |--kafka-bootstrap-server [addr] --kafka-group-id [id] --kafka-topic [topic]""".stripMargin
43 |
44 | // arguments parsing
45 | if (args.length == 0) println(usage)
46 | val env = parseArgs(Map(), args.toList)
47 | val histPath = env.getOrElse('histpath, "/data/ailabHome/elasticHome/historyPattern/histQuotes.json").toString
48 | val symb2namePath = env.getOrElse('symb2name, "/data/ailabHome/elasticHome/historyPattern/symb2nameDic.json").toString
49 | val outputDir = env.getOrElse('outputdir, "/data/ailabHome/elasticHome/historyPattern/correlation").toString
50 | val kafkaBootstrapServers = env.getOrElse('bootstrap, "localhost:9092").toString
51 | val kafkaGroupId = env.getOrElse('groupid, "group01").toString
52 | val topic = env.getOrElse('topic, "topicA").toString
53 | val topics = Array(topic)
54 | val histSize = env.getOrElse('histsize, 629).asInstanceOf[Int] // historical 데이터의 일자 수 (630 - 1)
55 | val rtSize = env.getOrElse('rtsize, 59).asInstanceOf[Int] // real-time 데이터의 일자 수 (DStream의 각 rdd 일자 수) (60 - 1)
56 | val batchInterval = env.getOrElse('batchinterval, 60 * 5).asInstanceOf[Int] // Spark Streaming 배치 간격
57 |
58 | // create spark session
59 | val spark = SparkSession
60 | .builder()
61 | .master("local")
62 | .appName("StockPatternStream")
63 | .getOrCreate()
64 |
65 | // prepare data
66 | val rawHistDf = Preprocessor.getRawDf(spark, histPath)
67 | val scaledHistDf = Preprocessor.getScaledHistDf(spark, rawHistDf, histSize, rtSize)
68 | val histDates = rawHistDf.columns.slice(0, histSize)
69 | val stockSymbols = CommonUtils.getStockSymbols(spark, rawHistDf)
70 | val histPriceMap = CommonUtils.getHistPriceMap(spark, rawHistDf, histSize, histDates)
71 | val symb2nameMap = CommonUtils.getSymb2NameMap(spark, symb2namePath)
72 | val broadcastItems = BroadcastItems(histSize, rtSize, batchInterval, outputDir)
73 |
74 | // broadcasting
75 | val bcHistDates = spark.sparkContext.broadcast(histDates)
76 | val bcStockSymbols = spark.sparkContext.broadcast(stockSymbols)
77 | val bcHistPriceMap = spark.sparkContext.broadcast(histPriceMap)
78 | val bcSymb2nameMap = spark.sparkContext.broadcast(symb2nameMap)
79 | val bcBroadcastItems = spark.sparkContext.broadcast(broadcastItems)
80 |
81 | // spark streaming
82 | StreamingManager.process(spark,
83 | kafkaBootstrapServers,
84 | kafkaGroupId,
85 | topics,
86 | scaledHistDf,
87 | bcStockSymbols,
88 | bcHistDates,
89 | bcSymb2nameMap,
90 | bcHistPriceMap,
91 | bcBroadcastItems)
92 | }
93 |
94 | }
95 |
--------------------------------------------------------------------------------
/src/main/scala/com/bigdata/StreamingManager.scala:
--------------------------------------------------------------------------------
1 | package com.bigdata
2 |
3 | import org.apache.kafka.common.serialization.StringDeserializer
4 | import org.apache.spark.broadcast.Broadcast
5 | import org.apache.spark.sql.{DataFrame, SparkSession}
6 | import org.apache.spark.streaming.{Seconds, StreamingContext}
7 | import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe
8 | import org.apache.spark.streaming.kafka010.KafkaUtils
9 | import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent
10 | import com.bigdata.CommonUtils.BroadcastItems
11 |
12 |
13 | /**
14 | * Spark Streaming 구현
15 | * 2020.05.13 by dhkim
16 | */
17 |
18 | object StreamingManager {
19 |
20 | /**
21 | * SparkStreaming이 수행되는 프로세스
22 | * @param spark: SparkSession
23 | * @param kafkaBootstrapServers: kafka bootstrap server addr
24 | * @param kafkaGroupId: kafka consumer group id
25 | * @param topics: kafka topics
26 | * @param scaledHistDf: 전처리된 historical price df
27 | * @param bcStockSymbols: 종목코드 배열
28 | * @param bcHistDates: 날짜 배열
29 | * @param bcSymb2nameMap: 종목코드->이름 맵
30 | * @param bcHistPriceMap: 종목코드,날짜->주가 맵
31 | * @param broadcastItems: BroadcastItems(histSize: Int, rtSize: Int, batchInterval: Int, saveDir: String)
32 | */
33 | def process(spark: SparkSession,
34 | kafkaBootstrapServers: String,
35 | kafkaGroupId: String,
36 | topics: Array[String],
37 | scaledHistDf: DataFrame,
38 | bcStockSymbols: Broadcast[Array[String]],
39 | bcHistDates: Broadcast[Array[String]],
40 | bcSymb2nameMap: Broadcast[Map[String, String]],
41 | bcHistPriceMap: Broadcast[Map[String, Map[String, Double]]],
42 | broadcastItems: Broadcast[BroadcastItems]): Unit = {
43 |
44 | import spark.implicits._
45 |
46 | // get broadcasted items
47 | val stockSymbols = bcStockSymbols.value
48 | val histDates = bcHistDates.value
49 | val symb2nameMap = bcSymb2nameMap.value
50 | val histPriceMap = bcHistPriceMap.value
51 | val bcItems = broadcastItems.value
52 |
53 | // create streaming context
54 | val ssc = new StreamingContext(spark.sparkContext, Seconds(bcItems.batchInterval))
55 |
56 | // kafka stream connection
57 | val kafkaParams = Map[String, Object](
58 | "bootstrap.servers" -> kafkaBootstrapServers,
59 | "key.deserializer" -> classOf[StringDeserializer],
60 | "value.deserializer" -> classOf[StringDeserializer],
61 | "group.id" -> kafkaGroupId,
62 | "auto.offset.reset" -> "latest",
63 | "enable.auto.commit" -> (false: java.lang.Boolean)
64 | )
65 | val kafkaStream = KafkaUtils.createDirectStream[String, String](
66 | ssc,
67 | PreferConsistent,
68 | Subscribe[String, String](topics, kafkaParams)
69 | )
70 |
71 | scaledHistDf.persist()
72 | /* scaledHistDf.show()
73 | * +----------------------+--------+
74 | * |scaledHist |symb |
75 | * +----------------------+--------+
76 | * |[[1.0,0.93,...], ...] |A123456 |
77 | * |[[...], ...] |A888888 |
78 | * +----------------------+--------+
79 | */
80 |
81 | // build spark stream
82 | kafkaStream
83 | .map(_.value)
84 | .foreachRDD { rdd =>
85 | val rawRtDf = rdd.toDF()
86 | rawRtDf.persist()
87 | /* rawRtDf.show()
88 | * +--------+--------+--------+--------+
89 | * |20200401|20200402|... |symb |
90 | * +--------+--------+--------+--------+
91 | * |77.0 |78.0 |... |A123456 |
92 | * |7.0 |7.0 |... |A888888 |
93 | * +--------+--------+--------+--------+
94 | */
95 |
96 | if (!rawRtDf.head(1).isEmpty) {
97 |
98 | // extract dates and price dictionary
99 | val rtDates = rawRtDf.columns.slice(0, bcItems.rtSize)
100 | val lastDate = rtDates.max
101 | val rtPriceMap = CommonUtils.getRtPriceMap(spark, rawRtDf, bcItems.rtSize)
102 |
103 | // scaling rt df
104 | val scaledRtDf: DataFrame = Preprocessor.getScaledRtDf(spark, rawRtDf, bcItems.rtSize)
105 | /* scaledRtDf.show()
106 | * +------------------+--------+
107 | * |scaledRt |symb |
108 | * +------------------+--------+
109 | * |[0.33, 0.66, ...] |A123456 |
110 | * |[1.0, 1.0, ...] |A888888 |
111 | * +------------------+--------+
112 | */
113 |
114 | // correlation: 각 종목별로 현재와 가장 유사했던 과거 구간을 찾은 것
115 | val resCorr = PatternFinder.getCorrelation(spark, scaledHistDf, scaledRtDf, stockSymbols,
116 | histDates, rtDates, bcItems.histSize, bcItems.rtSize)
117 | /* resCorr
118 | * Array(CorrRow(A123456, Array(20100304, 20100305, ...), 0.4773), CorrRow(...), ...)
119 | */
120 |
121 | // parsing corr items: 종목별로 찾은 과거 유사 구간을 일자별로 분리하여 각 row가 되도록 파싱 (elasticsearch에 쌓기)
122 | // Array((symb, name, rt date, rt price, hist date, hist price, version, similarity))
123 | val resTuples = resCorr
124 | .flatMap { corrRow =>
125 | // CorrRow(stkCode: String, mostSimDates: Array[String], coef: Double)
126 | val symb = corrRow.stkCode
127 | val name = symb2nameMap(symb)
128 | val dates = corrRow.mostSimDates
129 | val similarity = corrRow.coef
130 | val symbPriceMap = histPriceMap(symb)
131 |
132 | // 찾은 과거 구간을 일자별로 각 row가 되도록 쪼개기 (elasticsearch에서 일자별로 넣을 거니까)
133 | val parsedRows = rtDates
134 | .zip(rtPriceMap(symb))
135 | .zip(dates)
136 | .map { r =>
137 | val rtDateCol = r._1._1
138 | val rtPriceCol = r._1._2
139 | val histDateCol = r._2
140 | val histPriceCol = symbPriceMap(r._2)
141 | (symb, name, rtDateCol, rtPriceCol, histDateCol, histPriceCol, lastDate, similarity)
142 | }
143 |
144 | parsedRows
145 | }
146 |
147 | // save output as csv file
148 | spark
149 | .sparkContext
150 | .parallelize(resTuples)
151 | .toDF()
152 | .write
153 | .format("com.databricks.spark.csv")
154 | .mode("append")
155 | .option("header", "false")
156 | .save(bcItems.outputDir)
157 | }
158 |
159 | rawRtDf.unpersist()
160 | }
161 |
162 | // start streaming
163 | ssc.start()
164 | ssc.awaitTermination()
165 |
166 | scaledHistDf.unpersist()
167 | }
168 |
169 | }
--------------------------------------------------------------------------------