├── .gitignore ├── Makefile ├── README.md ├── SplitModel.mlmodel ├── SplitModel.proto ├── SplitModel.pt ├── data.json ├── ios ├── Split.swift └── SplitBridge.m ├── predict.py ├── prediction.json ├── test-ui ├── index.html ├── main.tsx ├── package.json ├── regression.ts └── tsconfig.json └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | yarn.lock 3 | /env 4 | .cache 5 | dist 6 | -------------------------------------------------------------------------------- /Makefile: --------------------------------------------------------------------------------
# Build helpers: `make train` retrains the model, `make predict` regenerates
# prediction.json from data.json, `make dev` serves the test UI.
VIRTUALENV:=$(shell which virtualenv)
ENV=env
SITE_PACKAGES=$(ENV)/lib/python2.7/site-packages
PYTHON=/usr/bin/python
# Recipes run under /bin/sh, where `.` (not the bash-only `source`) is the
# portable way to activate the virtualenv.
LOAD_ENV=. $(ENV)/bin/activate

PYTORCH=http://download.pytorch.org/whl/torch-0.3.0.post4-cp27-none-macosx_10_6_x86_64.whl
ONNX_COREML=git+https://github.com/onnx/onnx-coreml.git
TNT=git+https://github.com/pytorch/tnt.git

dev:
	cd test-ui && ./node_modules/.bin/parcel index.html & open http://localhost:1234/
.PHONY: dev

$(ENV): $(VIRTUALENV)
	virtualenv $(ENV) --python=$(PYTHON)

# Each install needs the virtualenv to exist first. The prerequisite is
# order-only so a newer env directory does not force a reinstall.
$(SITE_PACKAGES)/torch: | $(ENV)
	$(LOAD_ENV) && pip install $(PYTORCH)

$(SITE_PACKAGES)/onnx_coreml: | $(ENV)
	$(LOAD_ENV) && pip install $(ONNX_COREML)

$(SITE_PACKAGES)/torchnet: | $(ENV)
	$(LOAD_ENV) && pip install $(TNT)

SplitModel.mlmodel: $(ENV) $(SITE_PACKAGES)/torch $(SITE_PACKAGES)/onnx_coreml $(SITE_PACKAGES)/torchnet train.py data.json
	$(LOAD_ENV) && python train.py

# Touching data.json makes SplitModel.mlmodel out of date, forcing a retrain.
# $(MAKE) (not a bare `make`) propagates flags such as -n and -j.
train:
	@touch data.json
	@$(MAKE) SplitModel.mlmodel
.PHONY: train

# Also rebuild when predict.py itself changes.
prediction.json: SplitModel.mlmodel predict.py
	$(LOAD_ENV) && python predict.py

predict: prediction.json
.PHONY: predict
-------------------------------------------------------------------------------- /README.md:
-------------------------------------------------------------------------------- 1 | PyTorch and CoreML example 2 | ========================== 3 | 4 | Companion code to the blog post: [How I Shipped a Neural Network on iOS with CoreML, PyTorch, and React Native](https://attardi.org/pytorch-and-coreml). -------------------------------------------------------------------------------- /SplitModel.mlmodel: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/steadicat/pytorch-coreml-example/1808b7b7e225215c2033586cf801443693f9c6eb/SplitModel.mlmodel -------------------------------------------------------------------------------- /SplitModel.proto: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/steadicat/pytorch-coreml-example/1808b7b7e225215c2033586cf801443693f9c6eb/SplitModel.proto -------------------------------------------------------------------------------- /SplitModel.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/steadicat/pytorch-coreml-example/1808b7b7e225215c2033586cf801443693f9c6eb/SplitModel.pt -------------------------------------------------------------------------------- /data.json: -------------------------------------------------------------------------------- 1 | 
[{"points":[[41,24],[163,116],[254,116],[319,103],[484,112],[533,84],[629,91]],"splits":[112,410]},{"points":[[43,33],[86,69],[152,94],[175,118],[221,156],[247,38],[279,61],[303,89],[329,34],[369,56],[392,76],[422,119],[461,128],[470,34],[500,57],[525,93],[542,114],[582,138]],"splits":[235,320,467]},{"points":[[62,44],[123,55],[201,74],[258,92],[385,142],[498,161],[527,16],[566,24],[603,34],[658,37],[687,61]],"splits":[511]},{"points":[[94,55],[96,56],[98,58],[98,58],[105,51],[105,58],[110,58],[167,75],[196,84],[213,105],[241,91],[296,103],[298,100],[300,108],[300,104],[308,97],[308,103],[382,120],[434,130],[472,29],[502,149],[567,170],[570,155],[572,157],[581,157],[595,160]],"splits":[]},{"points":[[56,132],[123,114],[235,95],[374,69]],"splits":[]},{"points":[[195,125],[300,99],[391,74],[489,51],[505,117]],"splits":[492]},{"points":[[79,101],[170,95],[295,88],[413,95],[519,89],[549,95],[646,100],[722,100]],"splits":[]},{"points":[[113,64],[115,144],[276,53],[279,131],[458,55],[458,137],[632,52],[640,141]],"splits":[]},{"points":[[122,54],[197,77],[234,81],[263,94],[292,100],[598,101],[630,112],[647,115],[661,126]],"splits":[450]},{"points":[[134,65],[149,129],[263,135],[315,43],[390,68],[473,133],[566,86],[635,145],[666,52]],"splits":[211,281,432,599]},{"points":[[25,26],[72,34],[110,35],[138,43],[157,57],[184,57],[224,58],[247,61],[269,64],[281,71],[305,75],[327,79],[342,83],[374,82],[404,84],[426,86],[483,94],[499,100],[521,105],[561,113],[605,126],[605,126],[649,132],[681,144],[702,147],[736,157],[764,165]],"splits":[]},{"points":[[86,149],[134,108],[191,53],[275,112],[305,82],[337,46],[457,178],[516,125],[670,18],[735,129]],"splits":[213,380,702]},{"points":[[37,97],[91,97],[152,95],[184,98],[558,99],[634,100],[666,101],[708,102]],"splits":[370]},{"points":[[36,44],[70,58],[86,78],[114,95],[115,77],[141,100],[651,54],[683,71],[700,90],[720,83],[734,106],[763,121]],"splits":[369]},{"points":[[116,157],[156,141],[182,125],[214,119],[238,109],[265,98],[299,83],[33
4,88],[375,87],[399,89],[432,90],[467,90],[501,92],[532,92],[567,93],[609,94],[636,96],[680,101]],"splits":[314]},{"points":[[167,52],[384,94],[600,142]],"splits":[]},{"points":[[181,53],[183,56],[188,56],[189,63],[190,48],[382,101],[387,109],[389,103],[389,98],[395,105],[587,150],[589,143],[596,150],[596,154],[601,143]],"splits":[]},{"points":[[67,59],[141,64],[191,71],[226,70],[268,78],[306,80],[340,85],[461,114],[492,112],[527,109],[562,110],[588,114],[633,107],[654,114],[711,114]],"splits":[397]},{"points":[[173,150],[190,139],[214,118],[235,113],[256,91],[292,84],[359,104],[386,81],[407,80],[412,68],[438,58],[514,116],[530,101],[547,97],[557,88],[589,71],[659,117],[698,97],[732,76]],"splits":[323,469,628]},{"points":[[122,123],[229,69],[405,73],[488,119],[635,110],[739,108]],"splits":[311,586]},{"points":[[102,75],[145,98],[187,115],[203,61],[295,91],[387,129],[416,50],[592,102],[741,153]],"splits":[193,403]},{"points":[[47,175],[57,170],[151,173],[162,166],[176,160],[275,173],[285,168],[378,177],[414,162],[445,145],[477,134],[513,119],[544,102],[589,174],[604,162],[666,177],[747,132],[766,120]],"splits":[99,210,318,569,632]},{"points":[[65,107],[99,116],[136,124],[162,127],[258,106],[278,98],[290,87],[314,83],[414,103],[439,110],[464,122],[492,126],[590,110],[641,103],[662,99],[711,92]],"splits":[204,365,543]},{"points":[[99,72],[119,89],[135,98],[152,111],[209,82],[229,95],[241,108],[255,112],[330,95],[344,104],[358,120],[365,124],[379,136],[387,142],[403,155],[457,79],[481,92],[489,105],[535,136],[572,158],[583,169],[614,56],[644,76],[661,88],[678,95],[686,117],[706,115]],"splits":[184,288,423,599]},{"points":[[130,80],[131,91],[132,100],[133,109],[300,119],[301,126],[302,136],[303,105],[450,147],[451,158],[452,169],[519,66],[520,73],[521,83],[602,93],[604,104],[604,84],[683,113],[684,121],[685,129]],"splits":[480]},{"points":[[30,139],[41,135],[51,130],[148,143],[171,133],[187,130],[295,145],[311,143],[323,134],[333,134],[349,125],[364,123],[488,146],[505,1
40],[515,134],[528,129],[583,146],[597,141],[608,136],[622,129],[686,149],[701,142],[715,139],[732,135]],"splits":[96,225,419,550,656]},{"points":[[27,162],[155,101],[306,28],[312,83],[408,114],[486,138],[549,164],[586,99],[646,97],[777,100]],"splits":[307,568]},{"points":[[33,144],[51,137],[331,30],[337,179],[594,80],[611,71],[710,23]],"splits":[331]},{"points":[[46,95],[154,96],[234,95],[336,98],[602,100],[767,101]],"splits":[]},{"points":[[50,96],[726,93]],"splits":[]},{"points":[[53,61],[174,60],[306,60],[597,65],[645,142],[677,144],[711,143],[746,143]],"splits":[622]},{"points":[[50,46],[464,54],[482,154],[773,158]],"splits":[471]},{"points":[[52,58],[451,61],[460,56],[465,61],[497,157],[501,150],[638,157],[772,161]],"splits":[480]},{"points":[[54,45],[78,55],[183,94],[198,100],[228,117],[292,147],[355,170],[382,39],[409,52],[424,63],[477,31],[513,58],[539,75],[563,85],[597,35],[661,76],[682,91],[698,105],[721,120],[747,134]],"splits":[363,445,584]},{"points":[[253,155],[302,129],[397,159],[491,105],[550,72]],"splits":[348]},{"points":[[126,159],[234,160],[245,153],[361,162],[377,153],[404,136],[456,115],[499,92],[547,63],[600,42]],"splits":[179,308]},{"points":[[19,28],[48,36],[87,46],[648,183]],"splits":[]},{"points":[[10,182],[425,183],[471,184],[617,182],[756,183]],"splits":[]},{"points":[[13,129],[35,130],[368,135],[382,103],[398,104],[523,111],[767,111]],"splits":[376]},{"points":[[56,65],[64,69],[73,69],[79,73],[88,73],[360,134],[371,136],[383,139],[393,144]],"splits":[]},{"points":[[87,65],[132,79],[171,91],[215,104],[252,116],[286,123],[327,131],[363,134],[407,136],[464,138],[522,139],[577,142],[635,144],[697,143]],"splits":[343]},{"points":[[137,48],[168,69],[220,91],[257,108],[306,123],[360,132],[422,140],[481,143],[531,138],[593,128],[644,114],[708,87],[766,57]],"splits":[186,389,570,678]},{"points":[[101,36],[118,97],[129,55],[150,72],[186,96],[224,119],[250,144],[280,164],[334,167],[376,167],[445,169],[493,170],[554,172],[616,171]],"splits":[307]}
,{"points":[[120,146],[152,128],[182,115],[221,102],[248,103],[327,102],[416,108],[484,106],[561,107],[620,73],[668,52],[680,44],[696,36]],"splits":[200,592]},{"points":[[158,67],[193,85],[232,91],[256,99],[273,110],[326,69],[360,70],[381,87],[405,94],[431,99],[449,107]],"splits":[294]},{"points":[[133,57],[215,57],[254,72],[305,51],[357,73],[428,104],[451,49],[528,81],[580,111],[644,144],[667,64],[718,93],[762,117]],"splits":[169,285,444,657]},{"points":[[128,73],[221,68]],"splits":[]},{"points":[[116,159],[145,142],[187,122],[204,110],[251,158],[271,143],[288,128],[370,158],[398,144],[410,133],[492,152],[512,137],[589,155]],"splits":[226,331,450,551]},{"points":[[114,134],[128,118],[146,97],[164,73],[185,116],[201,102],[216,88],[227,73],[253,121],[269,106],[286,87],[298,71],[342,115],[361,87],[377,75],[399,59],[429,115],[443,91],[454,70]],"splits":[171,240,320,415]},{"points":[[68,142],[101,96],[196,111],[287,87],[371,134],[420,67],[515,106],[580,119],[658,78],[735,105]],"splits":[]},{"points":[[87,47],[309,50],[483,44],[496,164],[775,167]],"splits":[490]},{"points":[[81,71],[130,83],[155,92],[176,96],[218,108],[246,119],[288,112],[301,100],[327,94],[340,86],[377,76],[392,68],[402,66],[446,74],[476,87],[503,94],[529,103],[553,110],[574,119],[594,132],[615,138],[648,123],[675,114],[694,107],[722,98],[733,94],[745,89]],"splits":[265,429,631]},{"points":[[83,84],[124,91],[171,98],[201,96],[223,95],[253,87],[275,80],[298,77],[350,74],[404,74],[451,91],[493,109],[521,116],[585,112],[612,110],[662,103],[695,92],[720,86],[747,91],[773,103]],"splits":[]},{"points":[[51,14],[78,33],[106,49],[133,70],[151,13],[176,32],[201,10],[245,48],[280,82],[316,108],[332,8],[366,38],[387,65],[430,105],[496,10],[565,11],[592,39],[631,10],[685,12],[753,83],[781,107]],"splits":[139,195,329,462,534,619,667]},{"points":[[49,172],[85,144],[188,174],[223,148],[268,122],[348,175],[492,175],[533,146],[582,106],[698,175],[790,93]],"splits":[141,313,414,640]},{"points":[[75,86],[119,98],[205,98],
[245,85],[299,84],[385,88],[448,108],[565,126],[633,111],[733,105]],"splits":[]},{"points":[[54,33],[92,51],[129,43],[142,68],[176,72],[217,78],[236,105],[270,108],[374,125],[436,130],[454,122],[494,130],[532,134],[575,131],[601,128],[656,133],[722,140],[753,135]],"splits":[316]},{"points":[[79,64],[119,88],[180,122],[300,66],[416,65],[469,90],[517,123],[583,65],[676,67],[735,100],[781,131]],"splits":[230,363,556,631]},{"points":[[75,139],[171,88],[263,41],[266,34],[343,147],[349,143],[458,148],[465,150],[571,150],[635,103],[639,98],[712,51]],"splits":[314,401,530]},{"points":[[57,97],[218,100],[356,99],[462,97],[545,103],[637,98],[747,101]],"splits":[]},{"points":[[92,119],[146,89],[238,124],[305,91],[362,128],[430,91],[491,49],[555,23],[627,53],[686,23],[756,55]],"splits":[455]},{"points":[[62,32],[62,34],[64,32],[64,32],[198,72],[198,72],[198,72],[198,72],[313,133],[313,133],[313,133],[313,133],[437,148],[437,148],[437,148],[437,148],[537,166],[537,166],[537,166],[537,166],[621,36]],"splits":[589]},{"points":[[73,71],[90,92],[264,101],[281,90],[462,134],[483,147],[650,179],[675,165],[739,31],[754,23]],"splits":[703]},{"points":[[39,166],[57,160],[125,140],[146,132],[173,122],[232,106],[310,82],[321,76],[399,48],[435,37],[462,29],[465,175],[533,150]],"splits":[463]},{"points":[[43,34],[65,44],[105,71],[144,93],[153,99],[164,103],[196,112],[228,133],[253,148],[262,158],[280,31],[315,47],[348,60],[384,84],[426,100],[448,110],[476,27],[558,71],[659,111],[783,178]],"splits":[265,463]},{"points":[[34,162],[64,154],[120,145],[142,141],[165,136],[185,130],[376,161],[414,149],[446,133],[470,123],[500,119],[526,114],[569,98],[595,87],[649,67],[697,48],[724,46],[767,166],[788,153]],"splits":[268,743]},{"points":[[77,43],[177,80],[293,113],[329,137],[408,80],[461,120],[513,135],[581,157],[638,60],[716,103],[778,132]],"splits":[366,612]},{"points":[[48,56],[111,63],[153,70],[185,76],[240,88],[267,95],[315,102],[359,105],[387,112],[410,63],[458,76],[487,79],[517,82],[543,85],[
580,97],[623,104],[657,106],[687,83],[724,91],[747,95],[761,99],[784,108]],"splits":[396,672]},{"points":[[25,61],[80,74],[104,79],[139,83],[149,88],[187,92],[323,50],[366,58],[405,66],[446,78],[480,85],[512,93],[543,101],[583,108],[617,117],[648,135],[680,138],[692,53],[722,63],[761,73],[772,77]],"splits":[259,684]},{"points":[[41,126],[95,106],[169,76],[230,54],[237,179],[262,160],[297,147],[319,135],[343,119],[527,36],[570,18],[597,12],[617,160],[658,143],[693,129],[738,121]],"splits":[232,607]},{"points":[[52,65],[117,96],[172,87],[220,99],[246,119],[259,127],[309,132],[353,137],[416,150],[457,160],[506,111],[542,82],[582,81],[615,73],[662,57],[685,39],[702,32],[744,23],[776,17]],"splits":[482]},{"points":[[376,53],[379,124],[383,81]],"splits":[]},{"points":[[287,91],[361,123],[423,81],[492,114],[527,84],[589,134],[651,84]],"splits":[]},{"points":[[21,71],[64,87],[82,95],[121,108],[133,67],[222,91],[285,102],[355,123],[413,137],[435,61],[500,72],[538,88],[598,114],[649,73],[696,92],[714,99],[766,120]],"splits":[126,419,636]},{"points":[[65,122],[130,102],[189,79],[258,68],[351,138],[393,98],[416,94],[476,68],[592,143],[648,111],[658,109],[720,71],[763,47]],"splits":[315,545]},{"points":[[101,69],[103,82],[108,97],[330,101],[336,120],[339,134],[493,155],[505,143],[521,158],[577,57],[586,67],[606,51],[712,119],[727,133],[740,117]],"splits":[548]},{"points":[[71,66],[75,74],[88,77],[98,84],[107,76],[196,100],[199,114],[206,95],[218,106],[296,128],[297,143],[312,125],[318,139],[322,119],[334,43],[352,50],[353,61],[372,45],[373,57],[559,100],[560,93],[579,89],[580,100],[581,97]],"splits":[327]},{"points":[[99,65],[114,78],[119,74],[430,163],[440,172],[453,164],[482,45],[494,50],[501,45],[595,88],[616,91],[632,106],[739,155],[746,164],[759,165]],"splits":[462]},{"points":[[18,172],[44,164],[64,152],[88,143],[111,138],[124,131],[197,107],[209,101],[217,99],[258,84],[265,79],[273,75],[279,73],[349,169],[363,165],[397,147],[412,134],[436,123],[484,108],[492,104],[511,96]
,[529,90],[589,63],[672,33],[693,168],[739,152]],"splits":[318,685]},{"points":[[33,42],[38,45],[48,49],[49,49],[55,51],[69,57],[75,59],[93,68],[93,68],[116,77],[129,82],[145,93],[149,93],[161,94],[173,99],[183,105],[198,115],[204,118],[225,127],[230,129],[245,136],[262,142],[445,62],[457,67],[470,71],[472,72],[486,78],[508,88],[520,91],[529,95],[543,102],[550,106],[554,108],[573,115],[587,121],[648,144],[706,76],[720,83],[735,90],[735,90],[742,96]],"splits":[336,684]},{"points":[[55,127],[83,118],[97,116],[126,107],[134,105],[138,104],[164,96],[175,96],[195,90],[220,84],[232,81],[244,80],[258,75],[271,75],[279,73],[300,68],[307,66],[339,139],[353,134],[363,131],[379,129],[393,125],[402,123],[417,120],[434,116],[450,114],[474,103],[492,101],[522,95],[549,92],[557,91],[601,81],[669,133]],"splits":[321,636]},{"points":[[42,143],[54,138],[67,135],[73,134],[97,127],[106,126],[117,124],[275,87],[285,81],[309,72],[325,69],[341,66],[371,60],[397,54],[398,54],[433,48],[500,150],[632,177],[640,171],[679,167],[679,167],[679,167],[721,151],[731,149],[755,141],[778,135]],"splits":[464,586]},{"points":[[35,101],[62,97],[80,101],[111,100],[125,99],[142,99],[167,101],[168,101],[175,101],[188,103],[209,101],[223,101],[232,101],[293,101],[326,99],[329,98],[347,101],[429,157],[437,158],[452,158],[471,158],[492,159],[509,158],[517,158],[524,158],[535,77],[559,78],[568,79],[580,76],[581,76],[605,78],[609,78],[615,80],[616,79],[726,130]],"splits":[388,530,677]},{"points":[[40,128],[84,127],[140,123],[162,126],[191,127],[283,126],[398,123],[458,126],[519,130],[646,32]],"splits":[584]},{"points":[[60,117],[109,126],[128,124],[165,132],[190,136],[239,139],[269,142],[306,147],[339,156],[376,161],[401,162],[434,166],[478,169],[507,173],[538,175],[569,117]],"splits":[551]},{"points":[[56,146],[173,121],[297,90],[428,52],[541,16],[549,183]],"splits":[545]},{"points":[[71,87],[98,85],[98,85],[121,85],[131,80],[144,78],[166,78],[181,76],[199,72],[213,71],[224,71],[249,68],[322,88],[333,82],[343,
84],[352,80],[357,77],[371,71],[381,68],[390,67],[399,65],[424,58],[503,139],[516,134],[529,135],[542,135],[551,134],[559,135],[570,138],[582,140],[593,138],[611,137],[620,137],[696,32]],"splits":[475,654,278]}] -------------------------------------------------------------------------------- /ios/Split.swift: --------------------------------------------------------------------------------
//
//  Split.swift
//  Movement
//
//  Created by Stefano J. Attardi on 1/19/18.
//  Copyright © 2018 Rational Creation. All rights reserved.
//

import CoreML

/// React Native native module that runs the CoreML `SplitModel` on a list of
/// points and reports the point indices at which the sequence should be split.
@objc(Split)
class Split: NSObject {

  /// Lazily created `SplitModel`. It is declared as `AnyObject` (presumably
  /// so the property can exist on targets below iOS 11, where the model type
  /// is unavailable) and is cast back before use.
  var model: AnyObject? = nil

  @objc static func requiresMainQueueSetup() -> Bool {
    return false
  }

  /// Predicts split indices for `points`.
  ///
  /// - Parameters:
  ///   - points: `[x, y]` pairs. They are normalized like `encode` in train.py:
  ///     x is mapped to [-0.5, 0.5], and y is scaled by the same x-range and
  ///     then shifted. At most as many points as the model input holds (100)
  ///     are encoded; the remaining columns stay zero.
  ///   - callback: called once with `[error, result]`:
  ///       - `[NSNull(), [Int]]` (sorted split indices) on success;
  ///       - `["coreml_error", NSNull()]` if the model or input array cannot be
  ///         created, or prediction throws;
  ///       - `["coreml_unavailable", NSNull()]` below iOS 11.
  ///     Fewer than two points yield `[NSNull(), []]` without running the model.
  @objc(split:callback:)
  func split(points: [[Float32]], callback: RCTResponseSenderBlock) {
    guard points.count >= 2 else {
      callback([NSNull(), []])
      return
    }
    if #available(iOS 11.0, *) {
      if self.model == nil {
        self.model = SplitModel()
      }
      guard let model = self.model as? SplitModel else {
        print("Failed to create model")
        callback(["coreml_error", NSNull()])
        return
      }

      let xs = points.map { $0[0] }
      let ys = points.map { $0[1] }
      let minX = xs.min()!
      let maxX = xs.max()!
      let minY = ys.min()!
      let maxY = ys.max()!
      // Both axes are divided by the x-range, which preserves the aspect ratio.
      // If every point shares one x, use 1 rather than dividing by zero and
      // feeding NaNs to the model.
      let xRange = maxX > minX ? maxX - minX : 1
      let yShift = ((maxY - minY) / xRange) / 2.0
      guard let data = try? MLMultiArray(shape: [1, 2, 100], dataType: .float32) else {
        print("Failed to create MLMultiArray")
        callback(["coreml_error", NSNull()])
        return
      }

      // Clear every element explicitly. Unused columns then match the
      // torch.zeros padding used in training, instead of depending on
      // whatever the freshly allocated buffer contains.
      for offset in 0..<data.count {
        data[offset] = NSNumber(value: Float32(0))
      }

      // Only the first `capacity` points fit in the input (train.py's encode
      // truncates the same way); writing past it would index out of bounds.
      let capacity = data.shape[2].intValue
      for (i, point) in points.prefix(capacity).enumerated() {
        let column = NSNumber(value: i)
        let x = Double((point[0] - minX) / xRange - 0.5)
        let y = Double((point[1] - minY) / xRange - yShift)
        data[[NSNumber(value: 0), NSNumber(value: 0), column]] = NSNumber(value: x)
        data[[NSNumber(value: 0), NSNumber(value: 1), column]] = NSNumber(value: y)
      }

      do {
        let start = CACurrentMediaTime()
        let prediction = try model.prediction(_1: data)._27
        print("ml time \(CACurrentMediaTime() - start)")
        var indices: [Int] = []
        for (index, prob) in prediction {
          // A split at the final point would leave nothing after it, so only
          // earlier indices are reported.
          if prob > 0.5 && index < points.count - 1 {
            indices.append(Int(index))
          }
        }
        callback([NSNull(), indices.sorted()])
        return
      } catch {
        print("Error running CoreML: \(error)")
        callback(["coreml_error", NSNull()])
        return
      }
    } else {
      callback(["coreml_unavailable", NSNull()])
    }
  }
}

-------------------------------------------------------------------------------- /ios/SplitBridge.m: -------------------------------------------------------------------------------- 1 | // 2 | // SplitBridge.m 3 | // Movement 4 | // 5 | // Created by Stefano J. Attardi on 1/19/18. 6 | // Copyright © 2018 Rational Creation. All rights reserved.
7 | // 8 | 9 | #import 10 | 11 | @interface RCT_EXTERN_MODULE(Split, NSObject) 12 | 13 | RCT_EXTERN_METHOD(split:(NSArray *> *)points callback:(RCTResponseSenderBlock *)callback) 14 | 15 | @end 16 | 17 | #import 18 | 19 | @interface RCTConvert (RCTConvertNSNumberArrayArray) 20 | @end 21 | 22 | @implementation RCTConvert (RCTConvertNSNumberArrayArray) 23 | + (NSArray *> *)NSNumberArrayArray:(id)json 24 | { 25 | return RCTConvertArrayValue(@selector(NSNumberArray:), json); 26 | } 27 | @end 28 | 29 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | import json 3 | import os 4 | import torch 5 | from torch.autograd import Variable 6 | from train import Model, encode 7 | 8 | current_dir = os.path.dirname(__file__) 9 | model_file = os.path.join(current_dir, 'SplitModel.pt') 10 | data_file = os.path.join(current_dir, 'data.json') 11 | output_file = os.path.join(current_dir, 'prediction.json') 12 | 13 | model = Model() 14 | model.load_state_dict(torch.load(model_file)) 15 | 16 | with open(data_file) as f: 17 | examples = json.load(f) 18 | 19 | data = Variable(torch.stack([encode(p['points']) for p in examples])) 20 | logits = model(data) 21 | prediction = [] 22 | for example, probs in zip(examples, logits): 23 | prediction.append([i for i, prob in enumerate(list(probs)) if float(prob) >= 0.5 and i < len(example['points'])]) 24 | 25 | with open(output_file, 'w') as f: 26 | json.dump(prediction, f) 27 | -------------------------------------------------------------------------------- /prediction.json: -------------------------------------------------------------------------------- 1 | [[3], [4, 7, 12], [5], [], [], [], [], [], [0, 4], [1, 2, 4], [], [2, 5], [3], [5], [], [], [], [6], [5, 10, 15], [], [5], [1, 4, 6, 12, 14], [3, 7, 11], [3, 7, 14, 20], [10, 13], [2, 5, 11, 15, 19], [], [2], [2], [], [3], [1], [3], [6, 9, 13], 
[1], [2], [], [], [2], [], [], [1, 5, 8, 10], [3, 7], [8], [4], [0, 2, 5, 9], [], [6, 9], [3, 7, 11], [], [], [5, 20], [], [3, 5, 9, 13, 17], [1, 4, 5, 8], [], [7], [2, 3, 6, 7], [3, 5, 7], [2], [5, 7], [11, 19], [7], [], [9], [5], [3, 7], [8], [5, 16], [3], [], [], [2], [3, 8, 12], [3, 7], [5, 8], [13], [5], [12], [21, 35], [16, 31], [6, 15, 16], [16, 24, 31, 33], [8], [], [], [11, 21, 32]] -------------------------------------------------------------------------------- /test-ui/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 |
6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /test-ui/main.tsx: -------------------------------------------------------------------------------- 1 | import * as React from 'react'; 2 | import * as ReactDOM from 'react-dom'; 3 | import {piecewiseRegression, piecewiseRegressionWithSplits} from './regression'; 4 | import * as dataJSON from '../data.json'; 5 | import * as prediction from '../prediction.json'; 6 | 7 | type Point = [number, number]; 8 | /* Returns [min, max] of `values`; [Infinity, -Infinity] when it is empty. */ 9 | export function findRange(values: number[]): [number, number] { 10 | return values.reduce<[number, number]>( 11 | ([min, max], value) => [Math.min(min, value), Math.max(max, value)], 12 | [Infinity, -Infinity], 13 | ); 14 | } 15 | /* Returns a linear map taking the domain [min, max] onto [start, end]. */ 16 | export function makeScale([start, end]: [number, number], [min, max]: [number, number]) { 17 | return (n: number) => start + (end - start) * (n - min) / (max - min); 18 | } 19 | /* Examples currently displayed; `history` stacks earlier versions of `data` for undo. */ 20 | let data: Array<{points: Point[]; splits: number[]}> = dataJSON; 21 | let useNN = true; 22 | const history: Array = []; 23 | /* Appends an empty example and switches the display to plain-math fitting. */ 24 | function addExample() { 25 | history.push(data); 26 | data = [...data, {points: [], splits: []}]; 27 | useNN = false; 28 | render(); 29 | } 30 | /* Restores the most recent entry from `history`, if there is one. */ 31 | function undo() { 32 | if (!history.length) return; 33 | data = history.pop(); 34 | render(); 35 | } 36 | /* Switches the display to plain-math (non-NN) trend lines and re-renders. */ 37 | function switchToPlain() { 38 | useNN = false; 39 | render(); 40 | } 41 | /* Switches to NN predictions, which are only available for the preset (unedited) data. */ 42 | function switchToNN() { 43 | if (history.length > 0) { 44 | alert( 45 | 'Running the Neural Network on dynamically added data is not implemented yet.
Undo or refresh to see NN output on preset data.', 46 | ); 47 | return; 48 | } 49 | useNN = true; 50 | render(); 51 | } 52 | 53 | const Button = ({onClick, children, style = {} as React.CSSProperties}) => ( 54 | 57 | ); 58 | 59 | function predict(points: Point[]) { 60 | return []; 61 | } 62 | 63 | function findTrendLines(points, useNN = false, exampleID: number = null) { 64 | if (useNN) { 65 | const splits = prediction[exampleID] || []; 66 | return piecewiseRegressionWithSplits(points, splits); 67 | } else { 68 | return piecewiseRegression(points); 69 | } 70 | } 71 | 72 | function coordinatesToSplits(xs: number[], points: Point[]): number[] { 73 | if (xs.length === 0) return []; 74 | const splits = []; 75 | let nextX = 0; 76 | for (let i = 0; i < points.length; i++) { 77 | if (points[i][0] > xs[nextX]) { 78 | splits.push(i - 1); 79 | nextX++; 80 | if (nextX >= xs.length) break; 81 | } 82 | } 83 | return splits; 84 | } 85 | 86 | type TrendLine = {start: number; end: number; slope: number; y: number}; 87 | 88 | function linesToSplits(lines: TrendLine[], points: Point[]): number[] { 89 | if (lines.length === 0) return []; 90 | const splits = new Set(); 91 | let nextLine = 0; 92 | let inLine = false; 93 | for (let i = 0; i < points.length; i++) { 94 | if (points[i][0] === lines[nextLine].start) { 95 | splits.add(i - 1); 96 | inLine = true; 97 | } else if (points[i][0] === lines[nextLine].end) { 98 | splits.add(i); 99 | inLine = false; 100 | nextLine++; 101 | if (nextLine >= lines.length) break; 102 | } else if (!inLine) { 103 | splits.add(i - 1); 104 | splits.add(i); 105 | } 106 | } 107 | splits.delete(-1); 108 | splits.delete(points.length - 1); 109 | const res = [...splits]; 110 | res.sort((a, b) => a - b); 111 | return res; 112 | } 113 | 114 | const App = ({ 115 | examples, 116 | width, 117 | height, 118 | useNN = false, 119 | }: { 120 | examples: Array<{points: Point[]; splits: number[]}>; 121 | width: number; 122 | height: number; 123 | useNN?: boolean; 124 | }) 
=> ( 125 |
126 |
137 | 138 |
139 | {useNN ? 'Using Neural Network' : 'Using Plain Math'} 140 |
141 | 146 | 149 |
150 | {examples.map(({points, splits}, i) => { /* One panel per example: fit trend lines (NN or plain math) and check them against the hand-labelled splits. */ 151 | const xRange: Point = [0, width]; // || findRange(points.map(p => p[0])); 152 | const yRange: Point = [0, height]; // || findRange(points.map(p => p[1])); 153 | const xScale = makeScale([0, width], xRange); 154 | const yScale = makeScale([height, 0], yRange); 155 | 156 | const xScaleInverse = makeScale(xRange, [0, width]); 157 | const yScaleInverse = makeScale(yRange, [height, 0]); 158 | const lines = findTrendLines(points, useNN, i); 159 | 160 | splits.sort((a, b) => a - b); /* `correct`: do the fitted lines reproduce the labelled splits? */ 161 | const correct = 162 | coordinatesToSplits(splits, points).join(',') === linesToSplits(lines, points).join(','); 163 | /* Adds a point at the mouse position, keeping points sorted by x. */ 164 | const addPoint = ({nativeEvent: {offsetX, offsetY}}: React.MouseEvent) => { 165 | history.push(JSON.parse(JSON.stringify(data))); 166 | points.push([xScaleInverse(offsetX), yScaleInverse(offsetY)]); 167 | points.sort((a, b) => a[0] - b[0]); 168 | switchToPlain(); 169 | }; /* Adds a split at the mouse x; stopPropagation presumably keeps a parent addPoint handler from also firing. */ 170 | const addSplit = (event: React.MouseEvent) => { 171 | event.stopPropagation(); 172 | history.push(JSON.parse(JSON.stringify(data))); 173 | splits.push(xScaleInverse(event.nativeEvent.offsetX)); 174 | /* splits holds plain x-coordinates, so compare the numbers directly; a[0] on a number is undefined (NaN comparator, and a strict-mode type error). */ splits.sort((a, b) => a - b); 175 | switchToPlain(); 176 | }; 177 | 178 | return ( 179 | 190 | 191 | {splits.map((x, i) => ( 192 | 193 | ))} 194 | {points.map(([x, y], i) => ( 195 | 196 | ))} 197 | {lines.map(({slope, y, start, end}, i) => ( 198 | 206 | ))} 207 | 208 | ); 209 | })} 210 |
211 | ); 212 | 213 | const app = document.getElementById('app'); 214 | 215 | function render() { 216 | ReactDOM.render(, app); 217 | } 218 | 219 | render(); 220 | -------------------------------------------------------------------------------- /test-ui/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "pytorch-coreml-example", 3 | "version": "1.0.0", 4 | "description": "Companion code to the blog post at https://attardi.org/pytorch-and-coreml", 5 | "main": "n/a", 6 | "repository": "https://github.com/steadicat/pytorch-coreml-example", 7 | "author": "Stefano J. Attardi ", 8 | "license": "MIT", 9 | "dependencies": { 10 | "react": "^16.2.0", 11 | "react-dom": "^16.2.0" 12 | }, 13 | "devDependencies": { 14 | "@types/react": "^16.0.36", 15 | "@types/react-dom": "^16.0.3", 16 | "parcel-bundler": "^1.5.1", 17 | "tslint": "^5.9.1", 18 | "typescript": "^2.7.1" 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /test-ui/regression.ts: --------------------------------------------------------------------------------
export type Point = [number, number];

/**
 * Ordinary least-squares line through `coords`.
 *
 * Returns the slope and the y-intercept (`y`). If every x is identical, the
 * x-variance term (`run`) is 0 and the slope is reported as 0. Expects at
 * least one point; an empty list yields a NaN intercept.
 */
export function regression(coords: Point[]) {
  let sumOfX = 0;
  let sumOfY = 0;
  let sumOfXX = 0;
  let sumOfXY = 0;
  for (const [x, y] of coords) {
    sumOfX += x;
    sumOfY += y;
    sumOfXX += x * x;
    sumOfXY += x * y;
  }

  const len = coords.length;
  const run = len * sumOfXX - sumOfX * sumOfX;
  const rise = len * sumOfXY - sumOfX * sumOfY;
  const slope = run === 0 ? 0 : rise / run;
  return {slope, y: sumOfY / len - slope * sumOfX / len};
}

/**
 * Weighted mean squared error of `coords` around the line `y = slope * x + y0`.
 *
 * Each point is weighted by its x-distance to the nearest neighbour, so points
 * in dense clusters count for less; a lone point gets weight 1. If the weights
 * sum to 0 (e.g. all points share one x), uniform weights are used instead of
 * dividing by zero and returning NaN.
 */
function mse({slope, y: y0}: {slope: number; y: number}, coords: Point[]) {
  const weights = coords.map((p, i) => {
    const prev = coords[i - 1];
    const next = coords[i + 1];
    if (!prev && !next) return 1;
    if (!prev) return next[0] - p[0];
    if (!next) return p[0] - prev[0];
    return Math.min(p[0] - prev[0], next[0] - p[0]);
  });
  let totalWeight = weights.reduce((a, b) => a + b, 0);
  if (totalWeight === 0) {
    weights.fill(1);
    totalWeight = weights.length;
  }

  let error = 0;
  coords.forEach(([x, y], i) => {
    error += Math.pow(x * slope + y0 - y, 2) * weights[i];
  });
  return error / totalWeight;
}

type TrendLine = {
  slope: number;
  y: number;
  start: number;
  end: number;
};

/**
 * Fits one trend line per segment defined by `splits`.
 *
 * Each split is the index of the last point of a segment; the next segment
 * starts at `split + 1`. Splits are expected in ascending order. Segments with
 * fewer than two points are dropped. Returns [] when `splits` is null.
 */
export function piecewiseRegressionWithSplits(
  coords: Point[],
  splits: number[] | null,
): TrendLine[] {
  if (!splits) return [];
  const parts = [];
  let start = 0;
  for (const split of splits) {
    parts.push(coords.slice(start, split + 1));
    start = split + 1;
  }
  // Only add a trailing segment when at least two points remain.
  if (start < coords.length - 1) {
    parts.push(coords.slice(start));
  }
  return parts.filter(part => part.length > 1).map(part => {
    const {slope, y} = regression(part);
    return {slope, y, start: part[0][0], end: part[part.length - 1][0]};
  });
}

const minGainPercentage = 0.1;

/**
 * Recursively fits piecewise-linear trend lines to `coords` (presumably sorted
 * by x; the test UI keeps points sorted).
 *
 * Algorithm:
 * - Try every boundary that leaves at least two points on each side.
 * - Keep the boundary whose two halves have the lowest average weighted MSE,
 *   i.e. the largest gain over a single line.
 * - Split there, and recurse into both halves, only if that gain exceeds
 *   `minGain`.
 *
 * When `minGain` is omitted it becomes 10% of the top-level single-line loss,
 * and that value is passed down unchanged to the recursive calls.
 */
export function piecewiseRegression(coords: Point[], minGain: number | null = null): TrendLine[] {
  if (coords.length <= 1) return [];
  if (coords.length <= 3)
    return [
      {
        ...regression(coords),
        start: coords[0][0],
        end: coords[coords.length - 1][0],
      },
    ];
  const originalLine = regression(coords);
  const originalLoss = mse(originalLine, coords);
  // Derive the threshold only when none was supplied; a `||` test would also
  // discard an explicit minGain of 0.
  if (minGain === null) minGain = originalLoss * minGainPercentage;
  const a = [...coords];
  const b: typeof a = [];
  b.unshift(a.pop()!);
  b.unshift(a.pop()!);
  let bestSplit = 1;
  let bestGain = 0;
  // Slide the boundary left one point at a time: `a` is the prefix, `b` the suffix.
  for (; a.length > 1; b.unshift(a.pop()!)) {
    const aLine = regression(a);
    const aLoss = mse(aLine, a);
    const bLine = regression(b);
    const bLoss = mse(bLine, b);
    const gain = originalLoss - (aLoss + bLoss) / 2;
    if (gain > bestGain) {
      bestGain = gain;
      bestSplit = a.length;
    }
  }
  if (bestGain > minGain) {
    return [
      ...piecewiseRegression(coords.slice(0, bestSplit), minGain),
      ...piecewiseRegression(coords.slice(bestSplit), minGain),
    ];
  } else {
    return [
      {
        ...regression(coords),
        start: coords[0][0],
        end: coords[coords.length - 1][0],
      },
    ];
  }
}
-------------------------------------------------------------------------------- /test-ui/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "es6", 4 | "noEmit": true, 5 | "jsx": "react", 6 | "moduleResolution": "node", 7 | "strict": true, 8 | "noUnusedLocals": true, 9 | "noUnusedParameters": true, 10 | "lib": ["es2015", "dom", "es2016"], 11 | "types": [] 12 | }, 13 | } -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | from torch.utils.data import DataLoader, Dataset 8 | from torchnet.dataset import SplitDataset, ShuffleDataset 9 | import torch.nn.functional as F 10 | from math import sqrt 11 | from onnx_coreml import convert 12 | import onnx 13 | 14 | number_of_points = 100 15 | number_of_channels = 2 16 | epochs = 1000 17 | batch_size = 10 18 | current_dir = os.path.dirname(__file__) 19 | data_file = os.path.join(current_dir, 'data.json') 20 | 21 | 22 | class Model(nn.Module): 23 | def __init__(self): 24 | super(Model,
self).__init__() 25 | 26 | channels = 32 27 | self.conv1 = nn.Sequential( 28 | nn.Conv2d(in_channels=number_of_channels, out_channels=channels, kernel_size=(1, 7), padding=(0, 3)), 29 | nn.ReLU(), 30 | nn.Dropout2d(), 31 | ) 32 | self.conv2 = nn.Sequential( 33 | nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 5), padding=(0, 2)), 34 | nn.ReLU(), 35 | nn.Dropout2d(), 36 | ) 37 | self.conv3 = nn.Sequential( 38 | nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 3), padding=(0, 1)), 39 | nn.ReLU(), 40 | nn.Dropout2d(), 41 | ) 42 | self.conv4 = nn.Sequential( 43 | nn.Conv2d(in_channels=channels, out_channels=1, kernel_size=(1, 3), padding=(0, 1)), 44 | nn.Sigmoid(), 45 | ) 46 | 47 | def forward(self, x): 48 | x = x.view(-1, x.size(1), 1, x.size(2)) 49 | x = self.conv1(x) 50 | x = self.conv2(x) 51 | x = self.conv3(x) 52 | x = self.conv4(x) 53 | x = x.view(-1, x.size(3)) 54 | return x 55 | 56 | 57 | def encode(points): 58 | xs = [p[0] for p in points] 59 | ys = [p[1] for p in points] 60 | min_x = min(xs) 61 | max_x = max(xs) 62 | min_y = min(ys) 63 | max_y = max(ys) 64 | y_shift = ((max_y - min_y) / (max_x - min_x)) / 2.0 65 | input_tensor = torch.zeros([number_of_channels, number_of_points]) 66 | 67 | def normalize_x(x): 68 | return (x - min_x) / (max_x - min_x) - 0.5 69 | def normalize_y(y): 70 | return (y - min_y) / (max_x - min_x) - y_shift 71 | 72 | for i in range(min(number_of_points, len(points))): 73 | x = points[i][0] * 1.0 74 | y = points[i][1] * 1.0 75 | input_tensor[0][i] = normalize_x(x) 76 | input_tensor[1][i] = normalize_y(y) 77 | continue 78 | return input_tensor 79 | 80 | 81 | class PointsDataset(Dataset): 82 | def __init__(self, csv_file): 83 | self.examples = json.load(open(csv_file)) 84 | 85 | def __len__(self): 86 | return len(self.examples) 87 | 88 | def __getitem__(self, idx): 89 | example = self.examples[idx] 90 | input_tensor = encode(example['points']) 91 | output_tensor = torch.zeros(number_of_points) 
92 | for split_position in example['splits']: 93 | index = next(i for i, point in enumerate(example['points']) if point[0] > split_position) - 1 94 | output_tensor[index] = 1 95 | return input_tensor, output_tensor 96 | 97 | def evaluate(model, data): 98 | inputs, target = data 99 | inputs = Variable(inputs) 100 | target = Variable(target) 101 | mask = inputs.eq(0).sum(dim=1).eq(0) 102 | logits = model(inputs) 103 | correct = int(logits.round().eq(target).mul(mask).sum().data) 104 | total = int(mask.sum()) 105 | accuracy = 100.0 * correct / total 106 | 107 | float_mask = mask.float() 108 | masked_logits = logits.mul(float_mask) 109 | masked_target = target.mul(float_mask) 110 | loss = F.binary_cross_entropy(masked_logits, masked_target) 111 | 112 | return float(loss), accuracy, correct, total 113 | 114 | def train(model, epochs=epochs, batch_size=batch_size): 115 | optimizer = torch.optim.Adam(model.parameters()) 116 | dataset = PointsDataset(data_file) 117 | dataset = SplitDataset(ShuffleDataset(dataset), {'train': 0.9, 'validation': 0.1}) 118 | loader = DataLoader(dataset, shuffle=True, batch_size=batch_size) 119 | 120 | model.train() 121 | 122 | for epoch in range(epochs): 123 | dataset.select('train') 124 | running_loss = 0.0 125 | 126 | for i, (inputs, target) in enumerate(loader): 127 | inputs = Variable(inputs) 128 | target = Variable(target) 129 | 130 | logits = model(inputs) 131 | mask = inputs.eq(0).sum(dim=1).eq(0) 132 | float_mask = mask.float() 133 | masked_logits = logits.mul(float_mask) 134 | masked_target = target.mul(float_mask) 135 | loss = F.binary_cross_entropy(masked_logits, masked_target) 136 | optimizer.zero_grad() 137 | loss.backward() 138 | optimizer.step() 139 | running_loss += loss.data[0] 140 | 141 | 142 | dataset.select('validation') 143 | validation_loss, validation_accuracy, correct, total = evaluate(model, next(iter(loader))) 144 | 145 | print '\r[{:4d}] - running loss: {:8.6f} - validation loss: {:8.6f} validation acc: {:7.3f}% 
({}/{})'.format( 146 | epoch + 1, 147 | running_loss, 148 | 149 | validation_loss, 150 | validation_accuracy, 151 | correct, 152 | total 153 | ), 154 | sys.stdout.flush() 155 | 156 | running_loss = 0.0 157 | 158 | print('\n') 159 | 160 | 161 | if __name__ == '__main__': 162 | model = Model() 163 | train(model) 164 | path = os.path.join(current_dir, 'SplitModel.proto') 165 | dummy_input = Variable(torch.FloatTensor(1, number_of_channels, number_of_points)) 166 | torch.save(model.state_dict(), os.path.join(current_dir, 'SplitModel.pt')) 167 | torch.onnx.export(model, dummy_input, path, verbose=True) 168 | model = onnx.load(os.path.join(os.path.dirname(__file__), 'SplitModel.proto')) 169 | coreml_model = convert( 170 | model, 171 | 'classifier', 172 | image_input_names=['input'], 173 | image_output_names=['output'], 174 | class_labels=[i for i in range(number_of_points)], 175 | ) 176 | coreml_model.save('SplitModel.mlmodel') 177 | --------------------------------------------------------------------------------