├── README.md ├── build.sbt ├── img ├── architecture.png ├── introduce.png ├── job_input_output.png ├── result_plot.png └── tmp ├── project ├── build.properties └── plugins.sbt └── src └── main └── scala └── com └── bigdata ├── CommonUtils.scala ├── PatternFinder.scala ├── Preprocessor.scala ├── StockPatternStream.scala └── StreamingManager.scala /README.md: -------------------------------------------------------------------------------- 1 | # Stock Pattern Stream 2 | ### 현재 주가 차트와 유사한 과거 차트를 실시간으로 찾아주는 분석 서비스 3 | 4 | 주식 종목의 가격과 거래량을 보고 금융 시장의 흐름을 예측하는 것을 "기술적 분석"이라 합니다. 5 | 종목의 주가 패턴은 때떄로 반복적으로 나타나기도 하며, 주식 거래를 하는 많은 사람들이 주가를 예측하기 위해 이러한 패턴을 분석하곤 합니다. 6 | 다음 그림처럼 만약 현재의 주가 차트가 과거의 어떤 시점과 매우 유사하다면, 해당 과거 시점의 추이가 현재 주가의 추이를 예측하는데 도움이 될 수도 있습니다. 7 | 8 |

9 | introduce 10 |

11 | 12 | 본 레포지터리는 *"종목의 과거 주가 차트에서 현재 차트와 가장 유사한 구간을 실시간으로 찾아주는 프로그램이 있다면 편리하지 않을까?"* 13 | 라는 아이디어를 Spark Streaming으로 구현한 것입니다. 14 | 15 | 16 | ## 1. Architecture 17 | 18 | ### 1.1. 실시간 과거 패턴 찾기 19 | 20 | 본 프로세스에선 종목의 과거 주가 차트에 window sliding을 적용하여 n 크기의 슬라이드들을 생성합니다. 21 | 22 | Spark Streaming Job은 종목의 "최근 n개 분봉"을 입력받으면, 23 | 미리 생성한 슬라이드들과 Pearson 상관계수를 계산하여 가장 유사한 슬라이드를 찾습니다. 24 | 25 |

26 | job_input_output 27 |

28 | 29 | 한 종목에 대하여 위와 같은 패턴 찾기 작업이 parallel하게 수행됩니다. 30 | 분석 대상 종목이 여러 개인 경우, 본 작업을 각 종목에 대해 순차적으로 수행합니다. 31 | 32 | ### 1.2. 전체 구성 33 |

34 | architecture 35 |

36 | 37 | - Kafka Producer 38 | - 키움 OpenAPI 등으로 종목 분봉 데이터를 실시간으로 수집하여 Kafka에 전송해주는 producer 프로세스 39 | - 분석 대상 종목들의 최근 n개 분봉을 전송 40 | - Kafka 41 | - 분봉 데이터 저장 Queue 42 | - 실시간 과거 패턴 찾기 43 | - 본 레포지터리의 구현 범위 44 | - Spark Streaming Job으로 Kafka의 데이터를 입력 받아 실시간으로 분석 45 | - 분석 결과를 FS에 저장 46 | - ELK Stack 47 | - Logstash, Elasticsearch, Kibana를 이용하여 실시간 분석 결과를 시각화 48 | - Logstash: 저장된 분석 결과를 Elasticsearch에 입력 49 | - Kibana: 실시간 분석 결과를 대시보드에 시각화 50 | 51 | 52 | ### 1.3. 분석 스펙 53 | 54 | - Spark Mode 55 | - Standalone (별도의 Hadoop cluster 없이 동작) 56 | - (Hadoop cluster 모드는 추후 고려 예정) 57 | - Streaming Perfomance 58 | - 10개 이내 종목에 대해 1분 단위 (batchInterval) 스트리밍 분석 59 | - default batchInterval은 5분으로 설정됨 60 | - 분석 데이터 길이 61 | - window slide 길이 n: 59 분봉 62 | - 과거치 분봉 데이터 길이: 629 분봉 63 | 64 | ### 1.4. Input/Output 데이터 구조 65 | 현재 일봉 데이터로 분봉을 대체하고 있으며, 추후 분봉으로 변경 예정임 66 | 67 | |input data| -> |output data| 68 | |:---:|:---:|:---:| 69 | |종목별 과거분봉 데이터
종목별 실시간 분봉 데이터
종목코드->종목명 매핑 테이블|(*spark streaming job*)|실시간 분석 결과| 70 | 71 | - (Input) 종목별 과거 분봉 데이터 72 | - 입력 방식: spark-submit argument 중 ```--hist-path [file path]```로 파일 경로 입력 73 | - 데이터 포맷: json 74 | ``` 75 | // {시간1:분봉가격, 시간2:분봉가격, ..., 시간629:분봉가격, symb:종목코드} 76 | {"20160201": 5630,"20160202": 5633,...,"symb":"A123456"},{...},... 77 | ``` 78 | - (Input) 종목별 실시간 분봉 데이터 79 | - 입력 방식: kafka consuming (message's value) 80 | - 데이터 포맷: json 81 | ``` 82 | // {시간1:분봉가격, 시간2:분봉가격, ..., 시간59:분봉가격, symb:종목코드} 83 | {"20160201": 5630,"20160202": 5633,...,"symb":"A123456"},{...},... 84 | ``` 85 | - (Input) 종목코드->종목명 매핑 테이블 86 | - 입력 방식: spark-submit argument 중 ```--symb2name-path [file path]```로 파일 경로 입력 87 | - 데이터 포맷: json 88 | ``` 89 | // [{"symb":종목코드,"name":종목명}, ...] 90 | [{"symb":"A001720","name":"신영증권"},...] 91 | ``` 92 | - (Output) 실시간 분석 결과 93 | - 저장 방식: spark-submit argument 중 ```--output-dir [directory path]``` 값 경로에 파일로 저장 94 | - 데이터 포맷: csv 95 | ``` 96 | // 종목코드,종목명,실시간분봉시간,실시간분봉가격,과거분봉시간,과거분봉가격,버전,상관계수 97 | A097950,CJ제일제당,20190705,294000.0,20190402,326000.0,20190930,0.949744 98 | A097950,CJ제일제당,20190708,289000.0,20190403,326000.0,20190930,0.949744 99 | A097950,CJ제일제당,20190709,284000.0,20190404,325000.0,20190930,0.949744 100 | ...(후략)... 101 | ``` 102 | - 데이터 시각화 예시 (CJ제일제당, 유사도 94%) 103 | - Now: 실시간분봉시간에 따른 실시간분봉가격 104 | - Then: 과거분봉시간에 따른 과거분봉가격 105 |

106 | result_plot 107 |

108 | 109 | 110 | 111 | ### 1.5. 주요 소스 코드 112 | 113 | - StockPatternStream.scala 114 | - 메인 싱글톤. arguments parsing 및 분석 함수 호출 115 | - Preprocessor.scala 116 | - DataFrame 전처리 수행 117 | - PatternFinder.scala 118 | - Spark ML을 이용한 Correlation 계산 119 | - StreamingManager.scala 120 | - Spark Streaming 정의 121 | - CommonUtils.scala 122 | - 유틸리티 함수 모음 123 | 124 | ## 2. How to use 125 | 126 | ### 2.1. Prerequisites 127 | |name|version| 128 | |:---|:---| 129 | |Scala|2.11.12| 130 | |SBT|1.3.10| 131 | |JDK|1.8.0| 132 | |Apache Spark|2.4.5| 133 | |spark-streaming-kafka|0.10.0| 134 | 135 | - Spark은 standalone으로 동작함 136 | 137 | ### 2.2. Usage 138 | 139 | create fat jar 140 | ``` 141 | sbt assembly 142 | ``` 143 | 144 | spark submit 145 | ``` 146 | spark-submit jarfile --hist-path [path1] --symb2name-path [path2] --output-dir [path3] 147 | --kafka-bootstrap-server [addr] --kafka-group-id [id] --kafka-topic [topic] --batch-interval [seconds] 148 | 149 | arguments 150 | --hist-path: 종목별 과거 분봉 데이터 경로 151 | --symb2name-path: 종목코드->종목명 매핑 테이블 경로 152 | --output-dir: 분석 결과 저장 경로 153 | --kafka-bootstrap-server: 카프카 부트스트랩 주소 (localhost:9092) 154 | --kafka-group-id: 컨슈머 그룹 ID 155 | --kafka-topic: 토픽명 156 | --batch-interval: Spark Streaming 배치 간격 (default: 300) 157 | ``` 158 | 159 | 160 | 161 | ## 3. Future tasks 162 | 163 | - Spark Streaming과 Elasticsearch 연동, Logstash 단계 제거 164 | - 일봉 데이터 -> 분봉 데이터 변경 (현재 일봉 데이터로 분봉을 대체함) 165 | - 싱글모드 -> 클러스터 전환 성능 실험 166 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | name := "StockPatternStream" 2 | 3 | version := "0.1" 4 | 5 | scalaVersion := "2.11.12" 6 | 7 | 8 | libraryDependencies += "org.apache.spark" %% "spark-core" % "2.4.5" % "provided" 9 | libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.4.5" % "provided" 10 | libraryDependencies += "org.apache.spark" %% "spark-mllib" % "2.4.5" % "provided" 11 | libraryDependencies += "org.apache.spark" %% "spark-streaming" % "2.4.5" % "provided" 12 | libraryDependencies += "org.apache.spark" %% "spark-streaming-kafka-0-10" % "2.4.5" % "provided" 13 | libraryDependencies += "org.apache.spark" %% "spark-hive" % "2.4.5" % "provided" 14 | libraryDependencies += "org.apache.kafka" % "kafka-clients" % "2.4.1" % "provided" 15 | libraryDependencies += "org.apache.kafka" %% "kafka" % "2.4.1" % "provided" 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /img/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhkdn9192/StockPatternStream/26358b671e0bb96a20706453618d3df02d9145e8/img/architecture.png -------------------------------------------------------------------------------- /img/introduce.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhkdn9192/StockPatternStream/26358b671e0bb96a20706453618d3df02d9145e8/img/introduce.png -------------------------------------------------------------------------------- /img/job_input_output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhkdn9192/StockPatternStream/26358b671e0bb96a20706453618d3df02d9145e8/img/job_input_output.png -------------------------------------------------------------------------------- /img/result_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhkdn9192/StockPatternStream/26358b671e0bb96a20706453618d3df02d9145e8/img/result_plot.png -------------------------------------------------------------------------------- /img/tmp: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version = 1.3.10 -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.10") 2 | -------------------------------------------------------------------------------- /src/main/scala/com/bigdata/CommonUtils.scala: -------------------------------------------------------------------------------- 1 | package com.bigdata 2 | 3 | import org.apache.spark.sql.{DataFrame, SparkSession} 4 | 5 | 6 | /** 7 | * 프로젝트 공통 유틸리티 함수 모음 8 | * 2020.05.13 by dhkim 9 | */ 10 | 11 | object CommonUtils { 12 | 13 | /** 14 | * SparkStreaming에서 broadcast할 변수들 담을 케이스 클래스 15 | * @param histSize: historical size 16 | * @param rtSize: real-time size 17 | * @param batchInterval: streaming interval 18 | * @param outputDir: save dir 19 | */ 20 | case class BroadcastItems(histSize: Int, rtSize: Int, batchInterval: Int, outputDir: String) {} 21 | 22 | 23 | def getHistPriceMap(spark: SparkSession, rawPriceDf: DataFrame, priceSize: Int, 24 | dates: Array[String]): Map[String, Map[String, Double]] = { 25 | import spark.implicits._ 26 | 27 | // Map[symb -> Map[date -> price]] 28 | val priceMap = rawPriceDf 29 | .map{ row => 30 | val symb = row.getString(priceSize) 31 | val prices = (0 until priceSize).map(row.getLong(_).toDouble) 32 | (symb, dates.zip(prices).toMap) 33 | } 34 | .collect 35 | .toMap 36 | 37 | priceMap 38 | } 39 | 40 | def getRtPriceMap(spark: SparkSession, rawPriceDf: DataFrame, priceSize: Int): Map[String, Seq[Double]] = { 41 | import spark.implicits._ 42 | 43 | val priceMap = rawPriceDf 44 | .map{ row => 45 | val symb = row.getString(priceSize) 46 | val prices = (0 until priceSize).map(row.getLong(_).toDouble) 47 | (symb, prices) 48 | } 49 | .collect 50 | .toMap 51 | 52 | priceMap 53 | } 54 | 55 | def getSymb2NameMap(spark: SparkSession, path: String): Map[String, String] = { 56 | import spark.implicits._ 57 | 58 | val symb2nameMapDf: DataFrame = spark.read.json(path) 59 | val symb2nameMap: Map[String, String] = symb2nameMapDf 60 | .map { r => 61 | val symbCol = r.getAs[String]("symb") 62 | val nameCol = r.getAs[String]("name") 63 | (symbCol, nameCol) 64 | } 65 | .collect 66 | .toMap 67 | 68 | // symb to name Map 69 | symb2nameMap 70 | } 71 | 72 | def getStockSymbols(spark: SparkSession, rawHistDf: DataFrame): Array[String] = { 73 | import spark.implicits._ 74 | 75 | val stockSymbols = rawHistDf 76 | .select($"symb") 77 | .collect 78 | .map(_(0).toString) 79 | 80 | stockSymbols 81 | } 82 | 83 | } 84 | -------------------------------------------------------------------------------- /src/main/scala/com/bigdata/PatternFinder.scala: -------------------------------------------------------------------------------- 1 | package com.bigdata 2 | 3 | import org.apache.spark.ml.linalg.{Matrix, Vectors} 4 | import org.apache.spark.ml.stat.Correlation 5 | import org.apache.spark.sql.{DataFrame, Row, SparkSession} 6 | import org.apache.spark.sql.functions.explode 7 | 8 | 9 | /** 10 | * 종목별 과거 유사 패턴찾기 위한 Correlation 기능 구현 11 | * 2020.05.13 by dhkim 12 | */ 13 | 14 | object PatternFinder { 15 | 16 | /** 17 | * 18 | * @param stkCode: 종목코드 19 | * @param mostSimDates: 과거 구간들 중 가장 유사한 구간의 날짜값들 20 | * @param coef: 유사도 21 | */ 22 | case class CorrRow(stkCode: String, mostSimDates: Array[String], coef: Double) {} 23 | 24 | /** 25 | * @param spark: SparkSession 26 | * @param scaledHistDf: window sliding 및 min-max 스케일 적용된 historical price df 27 | * @param scaledRtDf: min-max 스케일 적용된 real-time price df 28 | * @param stockSymbols: 종목코드 배열. ex) Array(A123456, A888888, ...) 29 | * @param histDates: scaledHistDf의 날짜들. ex) Array(20200301, 20200302, ...) 30 | * @param rtDates: scaledRtDf의 날짜들. 31 | * @param histSize: 종목별 historical price 전체 길이 32 | * @param rtSize: real-time으로 분석할 종가 구간 길이 33 | */ 34 | def getCorrelation(spark: SparkSession, scaledHistDf: DataFrame, scaledRtDf: DataFrame, stockSymbols: Array[String], 35 | histDates: Array[String], rtDates: Array[String], histSize: Int, rtSize: Int): Array[CorrRow] = { 36 | 37 | import spark.implicits._ 38 | 39 | val slideSize = histSize - rtSize + 1 40 | val coefSize = slideSize + 1 41 | 42 | // 두 df를 join 및 slide들을 여러 row들로 분리(explode) 43 | val explodedDf = scaledRtDf.as("l") 44 | .join(scaledRtDf.as("r"), $"l.symb" === $"r.symb") 45 | .drop($"l.symb") 46 | .map{ row => 47 | val slided = row.getAs[Seq[Seq[Double]]](0) 48 | val rt = row.getAs[Seq[Double]](1) 49 | val symb = row.getString(2) 50 | val combined = rt 51 | .zip(slided) 52 | .map(zipped => Seq(zipped._1) ++ zipped._2) 53 | (combined, symb) 54 | } 55 | .toDF("combined", "symb") 56 | .select(explode($"combined").as("exploded"), $"symb") // explode 함수로 row의 slide들을 여러 row들로 분리 57 | /* explodedDf.show() 58 | * +-----------------------------+--------+ 59 | * |exploded |symb | 60 | * +-----------------------------+--------+ 61 | * |[0.33, 1.0, 0.0, 0.0, ...] |A123456 | 62 | * |[0.66, 0.0, 0.14, 0.16, ...] |A123456 | 63 | * |... |... | 64 | * |[1.0, 0.0, 0.63, 1.0, ...] |A888888 | 65 | * |[1.0, 0.2, 1.0, 0.81, ...] |A888888 | 66 | * |... |... | 67 | * +-----------------------------+--------+ 68 | */ 69 | 70 | // correlation 계산을 위한 벡터화 71 | val vectorDf = explodedDf 72 | .rdd 73 | .map{ row => 74 | val vectCol = row.getAs[Seq[Double]](0).toArray 75 | val symb = row.getString(1) 76 | (Vectors.dense(vectCol), symb) 77 | } 78 | .toDF("vectors", "symb") 79 | /* vectorDf.show() 80 | * +-----------------------------+--------+ 81 | * |vectors |symb | 82 | * +-----------------------------+--------+ 83 | * |[0.33, 1.0, 0.0, 0.0, ...] |A123456 | 84 | * |[0.66, 0.0, 0.14, 0.16, ...] |A123456 | 85 | * |... |... | 86 | * |[1.0, 0.0, 0.63, 1.0, ...] |A888888 | 87 | * |[1.0, 0.2, 1.0, 0.81, ...] |A888888 | 88 | * |... |... | 89 | * +-----------------------------+--------+ 90 | */ 91 | 92 | val resCorrRows = stockSymbols 93 | .map{ symb => 94 | val filteredDf = vectorDf.filter($"symb" === symb) 95 | val Row(coefMat: Matrix) = Correlation.corr(filteredDf, "vectors").head 96 | val coefAry = coefMat.toArray.slice(1, coefSize) // 현재값 제외 97 | 98 | val mostSimilar = coefAry.zipWithIndex.maxBy(_._1) 99 | val mostSimilarDates = histDates.slice(mostSimilar._2, mostSimilar._2 + rtSize) 100 | 101 | // (종목코드, 상관구간 날짜들, 상관계수) 102 | CorrRow(symb, mostSimilarDates, mostSimilar._1) 103 | } 104 | /* resCorrRows 105 | * Array(CorrRow(A123456, Array(20100304, 20100305, ...), 0.4773), CorrRow(...), ...) 106 | */ 107 | 108 | resCorrRows 109 | } 110 | 111 | } 112 | -------------------------------------------------------------------------------- /src/main/scala/com/bigdata/Preprocessor.scala: -------------------------------------------------------------------------------- 1 | package com.bigdata 2 | 3 | import scala.math.floor 4 | import org.apache.spark.sql.{DataFrame, SparkSession} 5 | 6 | 7 | /** 8 | * DataFrame 로드 및 전처리 9 | * 2020.05.13 by dhkim 10 | */ 11 | 12 | object Preprocessor { 13 | 14 | def getRawDf(spark: SparkSession, path: String): DataFrame = { 15 | val rawDf = spark.read.json(path) 16 | rawDf 17 | } 18 | 19 | def spRound(inval: Double): Double = { 20 | val step = 1000000 21 | floor(inval*step) / step 22 | } 23 | 24 | def getScaledHistDf(spark: SparkSession, rawHistDf: DataFrame, histSize: Int, rtSize: Int): DataFrame = { 25 | /* 26 | * 윈도우 슬라이딩 및 rt 기준 슬라이드별 min-max 스케일링 적용된 historical price df 생성 27 | * > input 28 | * +--------+--------+--------+--------+ 29 | * |20100301|20100302|... |symb | 30 | * +--------+--------+--------+--------+ 31 | * |24.0 |22.0 |... |A123456 | 32 | * |114.0 |110.0 |... |A888888 | 33 | * +--------+--------+--------+--------+ 34 | * > output 35 | * +----------------------+--------+ 36 | * |scaledHist |symb | 37 | * +----------------------+--------+ 38 | * |[[1.0,0.93,...], ...] |A123456 | 39 | * |[[...], ...] |A888888 | 40 | * +----------------------+--------+ 41 | */ 42 | import spark.implicits._ 43 | 44 | val slideSize = histSize - rtSize + 1 45 | 46 | val slidedHistDf = rawHistDf 47 | .map{ row => 48 | // sliding 49 | val slidedHist = (0 until histSize) 50 | .map(row.getLong(_).toDouble) 51 | .toArray 52 | .sliding(slideSize) 53 | .map(_.toArray) 54 | .toArray 55 | 56 | // scaling 57 | val transposed = slidedHist.transpose 58 | val mins = transposed.map(_.min) 59 | val durations = transposed.map(i => i.max-i.min) // 스케일링 기준은 slide 가 아니라, 각 slide별 같은 위치값 60 | val scaledAry = slidedHist.map{ elem => 61 | val scaled = elem 62 | .indices 63 | .map(i => spRound((elem(i)-mins(i))/durations(i))) 64 | .toArray 65 | scaled 66 | } 67 | 68 | // symb column 69 | val symb = row.getString(histSize) 70 | (scaledAry, symb) 71 | } 72 | .toDF("scaledHist", "symb") 73 | 74 | slidedHistDf 75 | } 76 | 77 | def getScaledRtDf(spark: SparkSession, rawRtDf: DataFrame, rtSize: Int): DataFrame = { 78 | /* real-time으로 들어오는 종목별 price를 array 파싱 및 min-max 스케일ㄹ이하여 새로운 df 생성 79 | * > input 80 | * +--------+--------+--------+--------+ 81 | * |20200401|20200402|... |symb | 82 | * +--------+--------+--------+--------+ 83 | * |77.0 |78.0 |... |A123456 | 84 | * |7.0 |7.0 |... |A888888 | 85 | * +--------+--------+--------+--------+ 86 | * > output 87 | * +------------------+--------+ 88 | * |scaledRt |symb | 89 | * +------------------+--------+ 90 | * |[0.33, 0.66, ...] |A123456 | 91 | * |[1.0, 1.0, ...] |A888888 | 92 | * +------------------+--------+ 93 | */ 94 | import spark.implicits._ 95 | 96 | val scaledRtDf = rawRtDf 97 | .map{ row => 98 | // scaling 99 | val rt = (0 until rtSize).map(row.getLong(_).toDouble).toArray 100 | val duration = rt.max - rt.min 101 | val scaled = rt.map(v => spRound((v-rt.min)/duration)) 102 | 103 | // symb column 104 | val symb = row.getString(rtSize) 105 | (scaled, symb) 106 | } 107 | .toDF("scaledRt", "symb") 108 | 109 | scaledRtDf 110 | } 111 | 112 | } 113 | -------------------------------------------------------------------------------- /src/main/scala/com/bigdata/StockPatternStream.scala: -------------------------------------------------------------------------------- 1 | package com.bigdata 2 | 3 | import org.apache.spark.sql.SparkSession 4 | import com.bigdata.CommonUtils.BroadcastItems 5 | 6 | 7 | /** 8 | * 프로젝트 메인 9 | * 2020.05.13 by dhkim 10 | */ 11 | 12 | object StockPatternStream { 13 | 14 | /** 15 | * argument parse 16 | * @param map: scala Map 17 | * @param argList: argument 문자열 리스트 18 | * @return 키-밸류 Map 19 | */ 20 | @scala.annotation.tailrec 21 | def parseArgs(map : Map[Symbol, Any], argList: List[String]): Map[Symbol, Any] = { 22 | argList match { 23 | case Nil => map 24 | case "--hist-path" :: value :: tail => parseArgs(map ++ Map('histpath -> value.toString), tail) 25 | case "--symb2name-path" :: value :: tail => parseArgs(map ++ Map('symb2name -> value.toString), tail) 26 | case "--output-dir" :: value :: tail => parseArgs(map ++ Map('outputdir -> value.toString), tail) 27 | case "--kafka-bootstrap-server" :: value :: tail => parseArgs(map ++ Map('bootstrap -> value.toString), tail) 28 | case "--kafka-group-id" :: value :: tail => parseArgs(map ++ Map('groupid -> value.toString), tail) 29 | case "--kafka-topic" :: value :: tail => parseArgs(map ++ Map('topic -> value.toString), tail) 30 | case "--hist-size" :: value :: tail => parseArgs(map ++ Map('histsize -> value.toInt), tail) 31 | case "--rt-size" :: value :: tail => parseArgs(map ++ Map('rtsize -> value.toInt), tail) 32 | case "--batch-interval" :: value :: tail => parseArgs(map ++ Map('batchinterval -> value.toInt), tail) 33 | case option :: tail => println("Unknown argument " + option) 34 | sys.exit(1) 35 | } 36 | } 37 | 38 | def main(args: Array[String]): Unit = { 39 | 40 | val usage = 41 | """Usage: spark-submit jarfile --hist-path [path1] --symb2name-path [path2] --output-dir [path3] 42 | |--kafka-bootstrap-server [addr] --kafka-group-id [id] --kafka-topic [topic]""".stripMargin 43 | 44 | // arguments parsing 45 | if (args.length == 0) println(usage) 46 | val env = parseArgs(Map(), args.toList) 47 | val histPath = env.getOrElse('histpath, "/data/ailabHome/elasticHome/historyPattern/histQuotes.json").toString 48 | val symb2namePath = env.getOrElse('symb2name, "/data/ailabHome/elasticHome/historyPattern/symb2nameDic.json").toString 49 | val outputDir = env.getOrElse('outputdir, "/data/ailabHome/elasticHome/historyPattern/correlation").toString 50 | val kafkaBootstrapServers = env.getOrElse('bootstrap, "localhost:9092").toString 51 | val kafkaGroupId = env.getOrElse('groupid, "group01").toString 52 | val topic = env.getOrElse('topic, "topicA").toString 53 | val topics = Array(topic) 54 | val histSize = env.getOrElse('histsize, 629).asInstanceOf[Int] // historical 데이터의 일자 수 (630 - 1) 55 | val rtSize = env.getOrElse('rtsize, 59).asInstanceOf[Int] // real-time 데이터의 일자 수 (DStream의 각 rdd 일자 수) (60 - 1) 56 | val batchInterval = env.getOrElse('batchinterval, 60 * 5).asInstanceOf[Int] // Spark Streaming 배치 간격 57 | 58 | // create spark session 59 | val spark = SparkSession 60 | .builder() 61 | .master("local") 62 | .appName("StockPatternStream") 63 | .getOrCreate() 64 | 65 | // prepare data 66 | val rawHistDf = Preprocessor.getRawDf(spark, histPath) 67 | val scaledHistDf = Preprocessor.getScaledHistDf(spark, rawHistDf, histSize, rtSize) 68 | val histDates = rawHistDf.columns.slice(0, histSize) 69 | val stockSymbols = CommonUtils.getStockSymbols(spark, rawHistDf) 70 | val histPriceMap = CommonUtils.getHistPriceMap(spark, rawHistDf, histSize, histDates) 71 | val symb2nameMap = CommonUtils.getSymb2NameMap(spark, symb2namePath) 72 | val broadcastItems = BroadcastItems(histSize, rtSize, batchInterval, outputDir) 73 | 74 | // broadcasting 75 | val bcHistDates = spark.sparkContext.broadcast(histDates) 76 | val bcStockSymbols = spark.sparkContext.broadcast(stockSymbols) 77 | val bcHistPriceMap = spark.sparkContext.broadcast(histPriceMap) 78 | val bcSymb2nameMap = spark.sparkContext.broadcast(symb2nameMap) 79 | val bcBroadcastItems = spark.sparkContext.broadcast(broadcastItems) 80 | 81 | // spark streaming 82 | StreamingManager.process(spark, 83 | kafkaBootstrapServers, 84 | kafkaGroupId, 85 | topics, 86 | scaledHistDf, 87 | bcStockSymbols, 88 | bcHistDates, 89 | bcSymb2nameMap, 90 | bcHistPriceMap, 91 | bcBroadcastItems) 92 | } 93 | 94 | } 95 | -------------------------------------------------------------------------------- /src/main/scala/com/bigdata/StreamingManager.scala: -------------------------------------------------------------------------------- 1 | package com.bigdata 2 | 3 | import org.apache.kafka.common.serialization.StringDeserializer 4 | import org.apache.spark.broadcast.Broadcast 5 | import org.apache.spark.sql.{DataFrame, SparkSession} 6 | import org.apache.spark.streaming.{Seconds, StreamingContext} 7 | import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe 8 | import org.apache.spark.streaming.kafka010.KafkaUtils 9 | import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent 10 | import com.bigdata.CommonUtils.BroadcastItems 11 | 12 | 13 | /** 14 | * Spark Streaming 구현 15 | * 2020.05.13 by dhkim 16 | */ 17 | 18 | object StreamingManager { 19 | 20 | /** 21 | * SparkStreaming이 수행되는 프로세스 22 | * @param spark: SparkSession 23 | * @param kafkaBootstrapServers: kafka bootstrap server addr 24 | * @param kafkaGroupId: kafka consumer group id 25 | * @param topics: kafka topics 26 | * @param scaledHistDf: 전처리된 historical price df 27 | * @param bcStockSymbols: 종목코드 배열 28 | * @param bcHistDates: 날짜 배열 29 | * @param bcSymb2nameMap: 종목코드->이름 맵 30 | * @param bcHistPriceMap: 종목코드,날짜->주가 맵 31 | * @param broadcastItems: BroadcastItems(histSize: Int, rtSize: Int, batchInterval: Int, saveDir: String) 32 | */ 33 | def process(spark: SparkSession, 34 | kafkaBootstrapServers: String, 35 | kafkaGroupId: String, 36 | topics: Array[String], 37 | scaledHistDf: DataFrame, 38 | bcStockSymbols: Broadcast[Array[String]], 39 | bcHistDates: Broadcast[Array[String]], 40 | bcSymb2nameMap: Broadcast[Map[String, String]], 41 | bcHistPriceMap: Broadcast[Map[String, Map[String, Double]]], 42 | broadcastItems: Broadcast[BroadcastItems]): Unit = { 43 | 44 | import spark.implicits._ 45 | 46 | // get broadcasted items 47 | val stockSymbols = bcStockSymbols.value 48 | val histDates = bcHistDates.value 49 | val symb2nameMap = bcSymb2nameMap.value 50 | val histPriceMap = bcHistPriceMap.value 51 | val bcItems = broadcastItems.value 52 | 53 | // create streaming context 54 | val ssc = new StreamingContext(spark.sparkContext, Seconds(bcItems.batchInterval)) 55 | 56 | // kafka stream connection 57 | val kafkaParams = Map[String, Object]( 58 | "bootstrap.servers" -> kafkaBootstrapServers, 59 | "key.deserializer" -> classOf[StringDeserializer], 60 | "value.deserializer" -> classOf[StringDeserializer], 61 | "group.id" -> kafkaGroupId, 62 | "auto.offset.reset" -> "latest", 63 | "enable.auto.commit" -> (false: java.lang.Boolean) 64 | ) 65 | val kafkaStream = KafkaUtils.createDirectStream[String, String]( 66 | ssc, 67 | PreferConsistent, 68 | Subscribe[String, String](topics, kafkaParams) 69 | ) 70 | 71 | scaledHistDf.persist() 72 | /* scaledHistDf.show() 73 | * +----------------------+--------+ 74 | * |scaledHist |symb | 75 | * +----------------------+--------+ 76 | * |[[1.0,0.93,...], ...] |A123456 | 77 | * |[[...], ...] |A888888 | 78 | * +----------------------+--------+ 79 | */ 80 | 81 | // build spark stream 82 | kafkaStream 83 | .map(_.value) 84 | .foreachRDD { rdd => 85 | val rawRtDf = rdd.toDF() 86 | rawRtDf.persist() 87 | /* rawRtDf.show() 88 | * +--------+--------+--------+--------+ 89 | * |20200401|20200402|... |symb | 90 | * +--------+--------+--------+--------+ 91 | * |77.0 |78.0 |... |A123456 | 92 | * |7.0 |7.0 |... |A888888 | 93 | * +--------+--------+--------+--------+ 94 | */ 95 | 96 | if (!rawRtDf.head(1).isEmpty) { 97 | 98 | // extract dates and price dictionary 99 | val rtDates = rawRtDf.columns.slice(0, bcItems.rtSize) 100 | val lastDate = rtDates.max 101 | val rtPriceMap = CommonUtils.getRtPriceMap(spark, rawRtDf, bcItems.rtSize) 102 | 103 | // scaling rt df 104 | val scaledRtDf: DataFrame = Preprocessor.getScaledRtDf(spark, rawRtDf, bcItems.rtSize) 105 | /* scaledRtDf.show() 106 | * +------------------+--------+ 107 | * |scaledRt |symb | 108 | * +------------------+--------+ 109 | * |[0.33, 0.66, ...] |A123456 | 110 | * |[1.0, 1.0, ...] |A888888 | 111 | * +------------------+--------+ 112 | */ 113 | 114 | // correlation: 각 종목별로 현재와 가장 유사했던 과거 구간을 찾은 것 115 | val resCorr = PatternFinder.getCorrelation(spark, scaledHistDf, scaledRtDf, stockSymbols, 116 | histDates, rtDates, bcItems.histSize, bcItems.rtSize) 117 | /* resCorr 118 | * Array(CorrRow(A123456, Array(20100304, 20100305, ...), 0.4773), CorrRow(...), ...) 119 | */ 120 | 121 | // parsing corr items: 종목별로 찾은 과거 유사 구간을 일자별로 분리하여 각 row가 되도록 파싱 (elasticsearch에 쌓기) 122 | // Array((symb, name, rt date, rt price, hist date, hist price, version, similarity)) 123 | val resTuples = resCorr 124 | .flatMap { corrRow => 125 | // CorrRow(stkCode: String, mostSimDates: Array[String], coef: Double) 126 | val symb = corrRow.stkCode 127 | val name = symb2nameMap(symb) 128 | val dates = corrRow.mostSimDates 129 | val similarity = corrRow.coef 130 | val symbPriceMap = histPriceMap(symb) 131 | 132 | // 찾은 과거 구간을 일자별로 각 row가 되도록 쪼개기 (elasticsearch에서 일자별로 넣을 거니까) 133 | val parsedRows = rtDates 134 | .zip(rtPriceMap(symb)) 135 | .zip(dates) 136 | .map { r => 137 | val rtDateCol = r._1._1 138 | val rtPriceCol = r._1._2 139 | val histDateCol = r._2 140 | val histPriceCol = symbPriceMap(r._2) 141 | (symb, name, rtDateCol, rtPriceCol, histDateCol, histPriceCol, lastDate, similarity) 142 | } 143 | 144 | parsedRows 145 | } 146 | 147 | // save output as csv file 148 | spark 149 | .sparkContext 150 | .parallelize(resTuples) 151 | .toDF() 152 | .write 153 | .format("com.databricks.spark.csv") 154 | .mode("append") 155 | .option("header", "false") 156 | .save(bcItems.outputDir) 157 | } 158 | 159 | rawRtDf.unpersist() 160 | } 161 | 162 | // start streaming 163 | ssc.start() 164 | ssc.awaitTermination() 165 | 166 | scaledHistDf.unpersist() 167 | } 168 | 169 | } --------------------------------------------------------------------------------